agent_client_protocol/role/
acp.rs1use std::{fmt::Debug, hash::Hash};
2
3use crate::jsonrpc::{Builder, handlers::NullHandler, run::NullRun};
4use crate::role::{HasPeer, RemoteStyle};
5use crate::schema::v1::{InitializeRequest, NewSessionRequest, NewSessionResponse, SessionId};
6use crate::schema::{InitializeProxyRequest, METHOD_INITIALIZE_PROXY};
7use crate::util::MatchDispatchFrom;
8use crate::{ConnectTo, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, Role, RoleId};
9
/// Marker type for the client side of an ACP connection.
///
/// It is zero-sized; its [`Role`] implementation names [`Agent`] as the
/// counterpart.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Client;
15
16impl Role for Client {
17 type Counterpart = Agent;
18
19 fn builder(self) -> Builder<Self> {
20 Builder::new(self).v1_client()
21 }
22
23 async fn default_handle_dispatch_from(
24 &self,
25 message: Dispatch,
26 _connection: ConnectionTo<Client>,
27 ) -> Result<Handled<Dispatch>, crate::Error> {
28 Ok(Handled::No {
29 message,
30 retry: false,
31 })
32 }
33
34 fn role_id(&self) -> RoleId {
35 RoleId::from_singleton(self)
36 }
37
38 fn counterpart(&self) -> Self::Counterpart {
39 Agent
40 }
41}
42
43impl Client {
44 pub fn builder(self) -> Builder<Client, NullHandler, NullRun> {
46 <Self as Role>::builder(self)
47 }
48
49 #[cfg(feature = "unstable_protocol_v2")]
57 pub fn v2(self) -> Builder<Client, NullHandler, NullRun> {
58 self.builder().v2_client()
59 }
60
61 pub async fn connect_with<R>(
66 self,
67 agent: impl ConnectTo<Client>,
68 main_fn: impl AsyncFnOnce(ConnectionTo<Agent>) -> Result<R, crate::Error>,
69 ) -> Result<R, crate::Error> {
70 self.builder().connect_with(agent, main_fn).await
71 }
72}
73
74impl HasPeer<Client> for Client {
75 fn remote_style(&self, _peer: Client) -> RemoteStyle {
76 RemoteStyle::Counterpart
77 }
78}
79
/// Marker type for the agent side of an ACP connection.
///
/// It is zero-sized; its [`Role`] implementation names [`Client`] as the
/// counterpart.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Agent;
86
87impl Role for Agent {
88 type Counterpart = Client;
89
90 fn builder(self) -> Builder<Self> {
91 Builder::new(self).v1_agent()
92 }
93
94 fn role_id(&self) -> RoleId {
95 RoleId::from_singleton(self)
96 }
97
98 fn counterpart(&self) -> Self::Counterpart {
99 Client
100 }
101
102 async fn default_handle_dispatch_from(
103 &self,
104 message: Dispatch,
105 connection: ConnectionTo<Agent>,
106 ) -> Result<Handled<Dispatch>, crate::Error> {
107 MatchDispatchFrom::new(message, &connection)
108 .if_message_from(Agent, async |message: Dispatch| {
109 let retry = message.has_session_id();
117 Ok(Handled::No { message, retry })
118 })
119 .await
120 .done()
121 }
122}
123
124impl Agent {
125 pub fn builder(self) -> Builder<Agent, NullHandler, NullRun> {
127 <Self as Role>::builder(self)
128 }
129
130 #[cfg(feature = "unstable_protocol_v2")]
138 pub fn v2(self) -> Builder<Agent, NullHandler, NullRun> {
139 self.builder().v2_agent()
140 }
141}
142
143impl HasPeer<Agent> for Agent {
144 fn remote_style(&self, _peer: Agent) -> RemoteStyle {
145 RemoteStyle::Counterpart
146 }
147}
148
/// Marker type for a proxy component.
///
/// It is zero-sized; its [`Role`] implementation names [`Conductor`] as the
/// counterpart.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Proxy;
159
160impl Role for Proxy {
161 type Counterpart = Conductor;
162
163 async fn default_handle_dispatch_from(
164 &self,
165 message: crate::Dispatch,
166 _connection: crate::ConnectionTo<Self>,
167 ) -> Result<crate::Handled<crate::Dispatch>, crate::Error> {
168 Ok(Handled::No {
169 message,
170 retry: false,
171 })
172 }
173
174 fn role_id(&self) -> RoleId {
175 RoleId::from_singleton(self)
176 }
177
178 fn counterpart(&self) -> Self::Counterpart {
179 Conductor
180 }
181}
182
183impl Proxy {
184 pub fn builder(self) -> Builder<Proxy, NullHandler, NullRun> {
186 Builder::new(self)
187 }
188}
189
190impl HasPeer<Proxy> for Proxy {
191 fn remote_style(&self, _peer: Proxy) -> RemoteStyle {
192 RemoteStyle::Counterpart
193 }
194}
195
/// Marker type for the conductor that a [`Proxy`] connects to.
///
/// Through a conductor, [`Client`] is reachable as the predecessor and
/// [`Agent`] as the successor (see its `HasPeer` implementations).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Conductor;
202
203impl Role for Conductor {
204 type Counterpart = Proxy;
205
206 fn role_id(&self) -> RoleId {
207 RoleId::from_singleton(self)
208 }
209
210 fn counterpart(&self) -> Self::Counterpart {
211 Proxy
212 }
213
214 async fn default_handle_dispatch_from(
215 &self,
216 message: Dispatch,
217 cx: ConnectionTo<Conductor>,
218 ) -> Result<Handled<Dispatch>, crate::Error> {
219 MatchDispatchFrom::new(message, &cx)
221 .if_request_from(Client, async |_req: InitializeRequest, responder| {
222 responder.respond_with_error(crate::Error::invalid_request().data(format!(
223 "proxies must be initialized with `{METHOD_INITIALIZE_PROXY}`"
224 )))
225 })
226 .await
227 .if_request_from(
230 Client,
231 async |request: InitializeProxyRequest, responder| {
232 let InitializeProxyRequest { initialize } = request;
233 cx.send_request_to(Agent, initialize)
234 .forward_response_to(responder)
235 },
236 )
237 .await
238 .if_request_from(Client, async |request: NewSessionRequest, responder| {
241 let sent = cx.send_request_to(Agent, request);
242 #[cfg(feature = "unstable_cancel_request")]
247 let sent = sent.forward_cancellation_from(responder.cancellation());
248 sent.on_receiving_result({
249 let cx = cx.clone();
250 async move |result| {
251 if let Ok(NewSessionResponse { session_id, .. }) = &result {
252 cx.add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))?
253 .run_indefinitely();
254 }
255 responder.respond_with_result(result)
256 }
257 })
258 })
259 .await
260 .if_message_from(Client, async |message: Dispatch| {
262 cx.send_proxied_message_to(Agent, message)
263 })
264 .await
265 .if_message_from(Agent, async |message: Dispatch| {
267 cx.send_proxied_message_to(Client, message)
268 })
269 .await
270 .done()
271 }
272}
273
274impl Conductor {
275 pub fn builder(self) -> Builder<Conductor, NullHandler, NullRun> {
277 Builder::new(self)
278 }
279}
280
281impl HasPeer<Client> for Conductor {
282 fn remote_style(&self, _peer: Client) -> RemoteStyle {
283 RemoteStyle::Predecessor
284 }
285}
286
287impl HasPeer<Agent> for Conductor {
288 fn remote_style(&self, _peer: Agent) -> RemoteStyle {
289 RemoteStyle::Successor
290 }
291}
292
293pub(crate) struct ProxySessionMessages {
298 session_id: SessionId,
299}
300
301impl ProxySessionMessages {
302 pub fn new(session_id: SessionId) -> Self {
304 Self { session_id }
305 }
306}
307
308impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for ProxySessionMessages
309where
310 Counterpart: HasPeer<Agent> + HasPeer<Client>,
311{
312 async fn handle_dispatch_from(
313 &mut self,
314 message: Dispatch,
315 connection: ConnectionTo<Counterpart>,
316 ) -> Result<Handled<Dispatch>, crate::Error> {
317 MatchDispatchFrom::new(message, &connection)
318 .if_message_from(Agent, async |message| {
319 if let Some(session_id) = message.get_session_id()?
321 && session_id == self.session_id
322 {
323 connection.send_proxied_message_to(Client, message)?;
324 return Ok(Handled::Yes);
325 }
326
327 Ok(Handled::No {
329 message,
330 retry: false,
331 })
332 })
333 .await
334 .done()
335 }
336
337 fn describe_chain(&self) -> impl std::fmt::Debug {
338 format!("ProxySessionMessages({})", self.session_id)
339 }
340}