agent_client_protocol/role/
acp.rs1use std::{fmt::Debug, hash::Hash};
2
3use agent_client_protocol_schema::{NewSessionRequest, NewSessionResponse, SessionId};
4
5use crate::jsonrpc::{Builder, handlers::NullHandler, run::NullRun};
6use crate::role::{HasPeer, RemoteStyle};
7use crate::schema::{InitializeProxyRequest, InitializeRequest, METHOD_INITIALIZE_PROXY};
8use crate::util::MatchDispatchFrom;
9use crate::{ConnectTo, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, Role, RoleId};
10
/// Zero-sized role marker for the ACP client role.
///
/// Its [`Role`] implementation names [`Agent`] as the counterpart.
#[derive(Clone, Copy, Debug, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Client;
16
17impl Role for Client {
18 type Counterpart = Agent;
19
20 fn builder(self) -> Builder<Self> {
21 Builder::new(self).v1_client()
22 }
23
24 async fn default_handle_dispatch_from(
25 &self,
26 message: Dispatch,
27 _connection: ConnectionTo<Client>,
28 ) -> Result<Handled<Dispatch>, crate::Error> {
29 Ok(Handled::No {
30 message,
31 retry: false,
32 })
33 }
34
35 fn role_id(&self) -> RoleId {
36 RoleId::from_singleton(self)
37 }
38
39 fn counterpart(&self) -> Self::Counterpart {
40 Agent
41 }
42}
43
44impl Client {
45 pub fn builder(self) -> Builder<Client, NullHandler, NullRun> {
47 <Self as Role>::builder(self)
48 }
49
50 #[cfg(feature = "unstable_protocol_v2")]
58 pub fn v2(self) -> Builder<Client, NullHandler, NullRun> {
59 self.builder().v2_client()
60 }
61
62 pub async fn connect_with<R>(
67 self,
68 agent: impl ConnectTo<Client>,
69 main_fn: impl AsyncFnOnce(ConnectionTo<Agent>) -> Result<R, crate::Error>,
70 ) -> Result<R, crate::Error> {
71 self.builder().connect_with(agent, main_fn).await
72 }
73}
74
75impl HasPeer<Client> for Client {
76 fn remote_style(&self, _peer: Client) -> RemoteStyle {
77 RemoteStyle::Counterpart
78 }
79}
80
/// Zero-sized role marker for the ACP agent role.
///
/// Its [`Role`] implementation names [`Client`] as the counterpart.
#[derive(Clone, Copy, Debug, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Agent;
87
88impl Role for Agent {
89 type Counterpart = Client;
90
91 fn builder(self) -> Builder<Self> {
92 Builder::new(self).v1_agent()
93 }
94
95 fn role_id(&self) -> RoleId {
96 RoleId::from_singleton(self)
97 }
98
99 fn counterpart(&self) -> Self::Counterpart {
100 Client
101 }
102
103 async fn default_handle_dispatch_from(
104 &self,
105 message: Dispatch,
106 connection: ConnectionTo<Agent>,
107 ) -> Result<Handled<Dispatch>, crate::Error> {
108 MatchDispatchFrom::new(message, &connection)
109 .if_message_from(Agent, async |message: Dispatch| {
110 let retry = message.has_session_id();
118 Ok(Handled::No { message, retry })
119 })
120 .await
121 .done()
122 }
123}
124
125impl Agent {
126 pub fn builder(self) -> Builder<Agent, NullHandler, NullRun> {
128 <Self as Role>::builder(self)
129 }
130
131 #[cfg(feature = "unstable_protocol_v2")]
139 pub fn v2(self) -> Builder<Agent, NullHandler, NullRun> {
140 self.builder().v2_agent()
141 }
142}
143
144impl HasPeer<Agent> for Agent {
145 fn remote_style(&self, _peer: Agent) -> RemoteStyle {
146 RemoteStyle::Counterpart
147 }
148}
149
/// Zero-sized role marker for a proxy.
///
/// Its [`Role`] implementation names [`Conductor`] as the counterpart.
#[derive(Clone, Copy, Debug, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Proxy;
160
161impl Role for Proxy {
162 type Counterpart = Conductor;
163
164 async fn default_handle_dispatch_from(
165 &self,
166 message: crate::Dispatch,
167 _connection: crate::ConnectionTo<Self>,
168 ) -> Result<crate::Handled<crate::Dispatch>, crate::Error> {
169 Ok(Handled::No {
170 message,
171 retry: false,
172 })
173 }
174
175 fn role_id(&self) -> RoleId {
176 RoleId::from_singleton(self)
177 }
178
179 fn counterpart(&self) -> Self::Counterpart {
180 Conductor
181 }
182}
183
184impl Proxy {
185 pub fn builder(self) -> Builder<Proxy, NullHandler, NullRun> {
187 Builder::new(self)
188 }
189}
190
191impl HasPeer<Proxy> for Proxy {
192 fn remote_style(&self, _peer: Proxy) -> RemoteStyle {
193 RemoteStyle::Counterpart
194 }
195}
196
/// Zero-sized role marker for the conductor.
///
/// Its [`Role`] implementation names [`Proxy`] as the counterpart. Through its
/// `HasPeer` impls it also reaches [`Client`] (as predecessor) and [`Agent`]
/// (as successor).
#[derive(Clone, Copy, Debug, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Conductor;
203
204impl Role for Conductor {
205 type Counterpart = Proxy;
206
207 fn role_id(&self) -> RoleId {
208 RoleId::from_singleton(self)
209 }
210
211 fn counterpart(&self) -> Self::Counterpart {
212 Proxy
213 }
214
215 async fn default_handle_dispatch_from(
216 &self,
217 message: Dispatch,
218 cx: ConnectionTo<Conductor>,
219 ) -> Result<Handled<Dispatch>, crate::Error> {
220 MatchDispatchFrom::new(message, &cx)
222 .if_request_from(Client, async |_req: InitializeRequest, responder| {
223 responder.respond_with_error(crate::Error::invalid_request().data(format!(
224 "proxies must be initialized with `{METHOD_INITIALIZE_PROXY}`"
225 )))
226 })
227 .await
228 .if_request_from(
231 Client,
232 async |request: InitializeProxyRequest, responder| {
233 let InitializeProxyRequest { initialize } = request;
234 cx.send_request_to(Agent, initialize)
235 .forward_response_to(responder)
236 },
237 )
238 .await
239 .if_request_from(Client, async |request: NewSessionRequest, responder| {
242 cx.send_request_to(Agent, request).on_receiving_result({
243 let cx = cx.clone();
244 async move |result| {
245 if let Ok(NewSessionResponse { session_id, .. }) = &result {
246 cx.add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))?
247 .run_indefinitely();
248 }
249 responder.respond_with_result(result)
250 }
251 })
252 })
253 .await
254 .if_message_from(Client, async |message: Dispatch| {
256 cx.send_proxied_message_to(Agent, message)
257 })
258 .await
259 .if_message_from(Agent, async |message: Dispatch| {
261 cx.send_proxied_message_to(Client, message)
262 })
263 .await
264 .done()
265 }
266}
267
268impl Conductor {
269 pub fn builder(self) -> Builder<Conductor, NullHandler, NullRun> {
271 Builder::new(self)
272 }
273}
274
275impl HasPeer<Client> for Conductor {
276 fn remote_style(&self, _peer: Client) -> RemoteStyle {
277 RemoteStyle::Predecessor
278 }
279}
280
281impl HasPeer<Agent> for Conductor {
282 fn remote_style(&self, _peer: Agent) -> RemoteStyle {
283 RemoteStyle::Successor
284 }
285}
286
287pub(crate) struct ProxySessionMessages {
292 session_id: SessionId,
293}
294
295impl ProxySessionMessages {
296 pub fn new(session_id: SessionId) -> Self {
298 Self { session_id }
299 }
300}
301
302impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for ProxySessionMessages
303where
304 Counterpart: HasPeer<Agent> + HasPeer<Client>,
305{
306 async fn handle_dispatch_from(
307 &mut self,
308 message: Dispatch,
309 connection: ConnectionTo<Counterpart>,
310 ) -> Result<Handled<Dispatch>, crate::Error> {
311 MatchDispatchFrom::new(message, &connection)
312 .if_message_from(Agent, async |message| {
313 if let Some(session_id) = message.get_session_id()?
315 && session_id == self.session_id
316 {
317 connection.send_proxied_message_to(Client, message)?;
318 return Ok(Handled::Yes);
319 }
320
321 Ok(Handled::No {
323 message,
324 retry: false,
325 })
326 })
327 .await
328 .done()
329 }
330
331 fn describe_chain(&self) -> impl std::fmt::Debug {
332 format!("ProxySessionMessages({})", self.session_id)
333 }
334}