agent_client_protocol/role/
acp.rs1use std::{fmt::Debug, hash::Hash};
2
3use agent_client_protocol_schema::{NewSessionRequest, NewSessionResponse, SessionId};
4
5use crate::jsonrpc::{Builder, handlers::NullHandler, run::NullRun};
6use crate::role::{HasPeer, RemoteStyle};
7use crate::schema::{InitializeProxyRequest, InitializeRequest, METHOD_INITIALIZE_PROXY};
8use crate::util::MatchDispatchFrom;
9use crate::{ConnectTo, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, Role, RoleId};
10
/// Marker type for the client role.
///
/// A zero-sized unit struct; the `Role` implementation in this file pairs it
/// with [`Agent`] as its counterpart.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Client;
16
17impl Role for Client {
18 type Counterpart = Agent;
19
20 async fn default_handle_dispatch_from(
21 &self,
22 message: Dispatch,
23 _connection: ConnectionTo<Client>,
24 ) -> Result<Handled<Dispatch>, crate::Error> {
25 Ok(Handled::No {
26 message,
27 retry: false,
28 })
29 }
30
31 fn role_id(&self) -> RoleId {
32 RoleId::from_singleton(self)
33 }
34
35 fn counterpart(&self) -> Self::Counterpart {
36 Agent
37 }
38}
39
40impl Client {
41 pub fn builder(self) -> Builder<Client, NullHandler, NullRun> {
43 Builder::new(self)
44 }
45
46 pub async fn connect_with<R>(
51 self,
52 agent: impl ConnectTo<Client>,
53 main_fn: impl AsyncFnOnce(ConnectionTo<Agent>) -> Result<R, crate::Error>,
54 ) -> Result<R, crate::Error> {
55 self.builder().connect_with(agent, main_fn).await
56 }
57}
58
59impl HasPeer<Client> for Client {
60 fn remote_style(&self, _peer: Client) -> RemoteStyle {
61 RemoteStyle::Counterpart
62 }
63}
64
/// Marker type for the agent role.
///
/// A zero-sized unit struct; the `Role` implementation in this file pairs it
/// with [`Client`] as its counterpart.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Agent;
71
72impl Role for Agent {
73 type Counterpart = Client;
74
75 fn role_id(&self) -> RoleId {
76 RoleId::from_singleton(self)
77 }
78
79 fn counterpart(&self) -> Self::Counterpart {
80 Client
81 }
82
83 async fn default_handle_dispatch_from(
84 &self,
85 message: Dispatch,
86 connection: ConnectionTo<Agent>,
87 ) -> Result<Handled<Dispatch>, crate::Error> {
88 MatchDispatchFrom::new(message, &connection)
89 .if_message_from(Agent, async |message: Dispatch| {
90 let retry = message.has_session_id();
98 Ok(Handled::No { message, retry })
99 })
100 .await
101 .done()
102 }
103}
104
105impl Agent {
106 pub fn builder(self) -> Builder<Agent, NullHandler, NullRun> {
108 Builder::new(self)
109 }
110}
111
112impl HasPeer<Agent> for Agent {
113 fn remote_style(&self, _peer: Agent) -> RemoteStyle {
114 RemoteStyle::Counterpart
115 }
116}
117
/// Marker type for the proxy role.
///
/// A zero-sized unit struct; the `Role` implementation in this file pairs it
/// with [`Conductor`] as its counterpart.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Proxy;
128
129impl Role for Proxy {
130 type Counterpart = Conductor;
131
132 async fn default_handle_dispatch_from(
133 &self,
134 message: crate::Dispatch,
135 _connection: crate::ConnectionTo<Self>,
136 ) -> Result<crate::Handled<crate::Dispatch>, crate::Error> {
137 Ok(Handled::No {
138 message,
139 retry: false,
140 })
141 }
142
143 fn role_id(&self) -> RoleId {
144 RoleId::from_singleton(self)
145 }
146
147 fn counterpart(&self) -> Self::Counterpart {
148 Conductor
149 }
150}
151
152impl Proxy {
153 pub fn builder(self) -> Builder<Proxy, NullHandler, NullRun> {
155 Builder::new(self)
156 }
157}
158
159impl HasPeer<Proxy> for Proxy {
160 fn remote_style(&self, _peer: Proxy) -> RemoteStyle {
161 RemoteStyle::Counterpart
162 }
163}
164
/// Marker type for the conductor role.
///
/// A zero-sized unit struct; the `Role` implementation in this file pairs it
/// with [`Proxy`] as its counterpart, and its `HasPeer` implementations treat
/// [`Client`] as the predecessor and [`Agent`] as the successor.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Conductor;
171
172impl Role for Conductor {
173 type Counterpart = Proxy;
174
175 fn role_id(&self) -> RoleId {
176 RoleId::from_singleton(self)
177 }
178
179 fn counterpart(&self) -> Self::Counterpart {
180 Proxy
181 }
182
183 async fn default_handle_dispatch_from(
184 &self,
185 message: Dispatch,
186 cx: ConnectionTo<Conductor>,
187 ) -> Result<Handled<Dispatch>, crate::Error> {
188 MatchDispatchFrom::new(message, &cx)
190 .if_request_from(Client, async |_req: InitializeRequest, responder| {
191 responder.respond_with_error(crate::Error::invalid_request().data(format!(
192 "proxies must be initialized with `{METHOD_INITIALIZE_PROXY}`"
193 )))
194 })
195 .await
196 .if_request_from(
199 Client,
200 async |request: InitializeProxyRequest, responder| {
201 let InitializeProxyRequest { initialize } = request;
202 cx.send_request_to(Agent, initialize)
203 .forward_response_to(responder)
204 },
205 )
206 .await
207 .if_request_from(Client, async |request: NewSessionRequest, responder| {
210 cx.send_request_to(Agent, request).on_receiving_result({
211 let cx = cx.clone();
212 async move |result| {
213 if let Ok(NewSessionResponse { session_id, .. }) = &result {
214 cx.add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))?
215 .run_indefinitely();
216 }
217 responder.respond_with_result(result)
218 }
219 })
220 })
221 .await
222 .if_message_from(Client, async |message: Dispatch| {
224 cx.send_proxied_message_to(Agent, message)
225 })
226 .await
227 .if_message_from(Agent, async |message: Dispatch| {
229 cx.send_proxied_message_to(Client, message)
230 })
231 .await
232 .done()
233 }
234}
235
236impl Conductor {
237 pub fn builder(self) -> Builder<Conductor, NullHandler, NullRun> {
239 Builder::new(self)
240 }
241}
242
243impl HasPeer<Client> for Conductor {
244 fn remote_style(&self, _peer: Client) -> RemoteStyle {
245 RemoteStyle::Predecessor
246 }
247}
248
249impl HasPeer<Agent> for Conductor {
250 fn remote_style(&self, _peer: Agent) -> RemoteStyle {
251 RemoteStyle::Successor
252 }
253}
254
255pub(crate) struct ProxySessionMessages {
260 session_id: SessionId,
261}
262
263impl ProxySessionMessages {
264 pub fn new(session_id: SessionId) -> Self {
266 Self { session_id }
267 }
268}
269
270impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for ProxySessionMessages
271where
272 Counterpart: HasPeer<Agent> + HasPeer<Client>,
273{
274 async fn handle_dispatch_from(
275 &mut self,
276 message: Dispatch,
277 connection: ConnectionTo<Counterpart>,
278 ) -> Result<Handled<Dispatch>, crate::Error> {
279 MatchDispatchFrom::new(message, &connection)
280 .if_message_from(Agent, async |message| {
281 if let Some(session_id) = message.get_session_id()?
283 && session_id == self.session_id
284 {
285 connection.send_proxied_message_to(Client, message)?;
286 return Ok(Handled::Yes);
287 }
288
289 Ok(Handled::No {
291 message,
292 retry: false,
293 })
294 })
295 .await
296 .done()
297 }
298
299 fn describe_chain(&self) -> impl std::fmt::Debug {
300 format!("ProxySessionMessages({})", self.session_id)
301 }
302}