Skip to main content

agent_client_protocol/jsonrpc/
handlers.rs

1use crate::jsonrpc::{HandleDispatchFrom, Handled, IntoHandled, JsonRpcResponse};
2
3use crate::role::{HasPeer, Role, handle_incoming_dispatch};
4use crate::{ConnectionTo, Dispatch, JsonRpcNotification, JsonRpcRequest, UntypedMessage};
5// Types re-exported from crate root
6use super::Responder;
7use std::marker::PhantomData;
8use std::ops::AsyncFnMut;
9
/// Null handler that accepts no messages.
///
/// Its `HandleDispatchFrom` implementation returns every dispatch unchanged as
/// `Handled::No` with `retry: false`.
#[derive(Debug, Default)]
pub struct NullHandler;

impl NullHandler {
    /// Creates a new null handler.
    ///
    /// Equivalent to `NullHandler::default()`; the type carries no state.
    #[must_use]
    pub fn new() -> Self {
        NullHandler
    }
}
27
28impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for NullHandler {
29    fn describe_chain(&self) -> impl std::fmt::Debug {
30        "(null)"
31    }
32
33    async fn handle_dispatch_from(
34        &mut self,
35        message: Dispatch,
36        _cx: ConnectionTo<Counterpart>,
37    ) -> Result<Handled<Dispatch>, crate::Error> {
38        Ok(Handled::No {
39            message,
40            retry: false,
41        })
42    }
43}
44
45/// Handler for typed request messages
46pub struct RequestHandler<
47    Counterpart: Role,
48    Peer: Role,
49    Req: JsonRpcRequest = UntypedMessage,
50    F = (),
51    ToFut = (),
52> {
53    counterpart: Counterpart,
54    peer: Peer,
55    handler: F,
56    to_future_hack: ToFut,
57    phantom: PhantomData<fn(Req)>,
58}
59
60impl<Counterpart: Role, Peer: Role, Req: JsonRpcRequest, F, ToFut>
61    RequestHandler<Counterpart, Peer, Req, F, ToFut>
62{
63    /// Creates a new request handler
64    pub fn new(counterpart: Counterpart, peer: Peer, handler: F, to_future_hack: ToFut) -> Self {
65        Self {
66            counterpart,
67            peer,
68            handler,
69            to_future_hack,
70            phantom: PhantomData,
71        }
72    }
73}
74
75impl<Counterpart: Role, Peer: Role, Req, F, T, ToFut> HandleDispatchFrom<Counterpart>
76    for RequestHandler<Counterpart, Peer, Req, F, ToFut>
77where
78    Counterpart: HasPeer<Peer>,
79    Req: JsonRpcRequest,
80    F: AsyncFnMut(
81            Req,
82            Responder<Req::Response>,
83            ConnectionTo<Counterpart>,
84        ) -> Result<T, crate::Error>
85        + Send,
86    T: crate::IntoHandled<(Req, Responder<Req::Response>)>,
87    ToFut: Fn(
88            &mut F,
89            Req,
90            Responder<Req::Response>,
91            ConnectionTo<Counterpart>,
92        ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
93        + Send
94        + Sync,
95{
96    fn describe_chain(&self) -> impl std::fmt::Debug {
97        std::any::type_name::<Req>()
98    }
99
100    async fn handle_dispatch_from(
101        &mut self,
102        dispatch: Dispatch,
103        connection: ConnectionTo<Counterpart>,
104    ) -> Result<Handled<Dispatch>, crate::Error> {
105        handle_incoming_dispatch(
106            self.counterpart.clone(),
107            self.peer.clone(),
108            dispatch,
109            connection,
110            async |dispatch, connection| {
111                match dispatch {
112                    Dispatch::Request(message, responder) => {
113                        tracing::debug!(
114                            request_type = std::any::type_name::<Req>(),
115                            message = ?message,
116                            "RequestHandler::handle_request"
117                        );
118                        if Req::matches_method(&message.method) {
119                            match Req::parse_message(&message.method, &message.params) {
120                                Ok(req) => {
121                                    tracing::trace!(
122                                        ?req,
123                                        "RequestHandler::handle_request: parse completed"
124                                    );
125                                    let typed_responder = responder.cast();
126                                    let result = (self.to_future_hack)(
127                                        &mut self.handler,
128                                        req,
129                                        typed_responder,
130                                        connection,
131                                    )
132                                    .await?;
133                                    match result.into_handled() {
134                                        Handled::Yes => Ok(Handled::Yes),
135                                        Handled::No {
136                                            message: (request, responder),
137                                            retry,
138                                        } => {
139                                            // Handler returned the request back, convert to untyped
140                                            let untyped = request.to_untyped_message()?;
141                                            Ok(Handled::No {
142                                                message: Dispatch::Request(
143                                                    untyped,
144                                                    responder.erase_to_json(),
145                                                ),
146                                                retry,
147                                            })
148                                        }
149                                    }
150                                }
151                                Err(err) => {
152                                    tracing::trace!(
153                                        ?err,
154                                        "RequestHandler::handle_request: parse errored"
155                                    );
156                                    Err(err)
157                                }
158                            }
159                        } else {
160                            tracing::trace!("RequestHandler::handle_request: method doesn't match");
161                            Ok(Handled::No {
162                                message: Dispatch::Request(message, responder),
163                                retry: false,
164                            })
165                        }
166                    }
167
168                    Dispatch::Notification(..) | Dispatch::Response(..) => Ok(Handled::No {
169                        message: dispatch,
170                        retry: false,
171                    }),
172                }
173            },
174        )
175        .await
176    }
177}
178
179/// Handler for typed notification messages
180pub struct NotificationHandler<
181    Counterpart: Role,
182    Peer: Role,
183    Notif: JsonRpcNotification = UntypedMessage,
184    F = (),
185    ToFut = (),
186> {
187    counterpart: Counterpart,
188    peer: Peer,
189    handler: F,
190    to_future_hack: ToFut,
191    phantom: PhantomData<fn(Notif)>,
192}
193
194impl<Counterpart: Role, Peer: Role, Notif: JsonRpcNotification, F, ToFut>
195    NotificationHandler<Counterpart, Peer, Notif, F, ToFut>
196{
197    /// Creates a new notification handler
198    pub fn new(counterpart: Counterpart, peer: Peer, handler: F, to_future_hack: ToFut) -> Self {
199        Self {
200            counterpart,
201            peer,
202            handler,
203            to_future_hack,
204            phantom: PhantomData,
205        }
206    }
207}
208
209impl<Counterpart: Role, Peer: Role, Notif, F, T, ToFut> HandleDispatchFrom<Counterpart>
210    for NotificationHandler<Counterpart, Peer, Notif, F, ToFut>
211where
212    Counterpart: HasPeer<Peer>,
213    Notif: JsonRpcNotification,
214    F: AsyncFnMut(Notif, ConnectionTo<Counterpart>) -> Result<T, crate::Error> + Send,
215    T: crate::IntoHandled<(Notif, ConnectionTo<Counterpart>)>,
216    ToFut: Fn(
217            &mut F,
218            Notif,
219            ConnectionTo<Counterpart>,
220        ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
221        + Send
222        + Sync,
223{
224    fn describe_chain(&self) -> impl std::fmt::Debug {
225        std::any::type_name::<Notif>()
226    }
227
228    async fn handle_dispatch_from(
229        &mut self,
230        dispatch: Dispatch,
231        connection: ConnectionTo<Counterpart>,
232    ) -> Result<Handled<Dispatch>, crate::Error> {
233        handle_incoming_dispatch(
234            self.counterpart.clone(),
235            self.peer.clone(),
236            dispatch,
237            connection,
238            async |dispatch, connection| {
239                match dispatch {
240                    Dispatch::Notification(message) => {
241                        tracing::debug!(
242                            request_type = std::any::type_name::<Notif>(),
243                            message = ?message,
244                            "NotificationHandler::handle_dispatch"
245                        );
246                        if Notif::matches_method(&message.method) {
247                            match Notif::parse_message(&message.method, &message.params) {
248                                Ok(notif) => {
249                                    tracing::trace!(
250                                        ?notif,
251                                        "NotificationHandler::handle_notification: parse completed"
252                                    );
253                                    let result =
254                                        (self.to_future_hack)(&mut self.handler, notif, connection)
255                                            .await?;
256                                    match result.into_handled() {
257                                        Handled::Yes => Ok(Handled::Yes),
258                                        Handled::No {
259                                            message: (notification, _cx),
260                                            retry,
261                                        } => {
262                                            // Handler returned the notification back, convert to untyped
263                                            let untyped = notification.to_untyped_message()?;
264                                            Ok(Handled::No {
265                                                message: Dispatch::Notification(untyped),
266                                                retry,
267                                            })
268                                        }
269                                    }
270                                }
271                                Err(err) => {
272                                    tracing::trace!(
273                                        ?err,
274                                        "NotificationHandler::handle_notification: parse errored"
275                                    );
276                                    Err(err)
277                                }
278                            }
279                        } else {
280                            tracing::trace!(
281                                "NotificationHandler::handle_notification: method doesn't match"
282                            );
283                            Ok(Handled::No {
284                                message: Dispatch::Notification(message),
285                                retry: false,
286                            })
287                        }
288                    }
289
290                    Dispatch::Request(..) | Dispatch::Response(..) => Ok(Handled::No {
291                        message: dispatch,
292                        retry: false,
293                    }),
294                }
295            },
296        )
297        .await
298    }
299}
300
301/// Handler that handles both requests and notifications of specific types.
302pub struct MessageHandler<
303    Counterpart: Role,
304    Peer: Role,
305    Req: JsonRpcRequest = UntypedMessage,
306    Notif: JsonRpcNotification = UntypedMessage,
307    F = (),
308    ToFut = (),
309> {
310    counterpart: Counterpart,
311    peer: Peer,
312    handler: F,
313    to_future_hack: ToFut,
314    phantom: PhantomData<fn(Dispatch<Req, Notif>)>,
315}
316
317impl<Counterpart: Role, Peer: Role, Req: JsonRpcRequest, Notif: JsonRpcNotification, F, ToFut>
318    MessageHandler<Counterpart, Peer, Req, Notif, F, ToFut>
319{
320    /// Creates a new message handler
321    pub fn new(counterpart: Counterpart, peer: Peer, handler: F, to_future_hack: ToFut) -> Self {
322        Self {
323            counterpart,
324            peer,
325            handler,
326            to_future_hack,
327            phantom: PhantomData,
328        }
329    }
330}
331
332impl<Counterpart: Role, Peer: Role, Req: JsonRpcRequest, Notif: JsonRpcNotification, F, T, ToFut>
333    HandleDispatchFrom<Counterpart> for MessageHandler<Counterpart, Peer, Req, Notif, F, ToFut>
334where
335    Counterpart: HasPeer<Peer>,
336    F: AsyncFnMut(Dispatch<Req, Notif>, ConnectionTo<Counterpart>) -> Result<T, crate::Error>
337        + Send,
338    T: IntoHandled<Dispatch<Req, Notif>>,
339    ToFut: Fn(
340            &mut F,
341            Dispatch<Req, Notif>,
342            ConnectionTo<Counterpart>,
343        ) -> crate::BoxFuture<'_, Result<T, crate::Error>>
344        + Send
345        + Sync,
346{
347    fn describe_chain(&self) -> impl std::fmt::Debug {
348        format!(
349            "({}, {})",
350            std::any::type_name::<Req>(),
351            std::any::type_name::<Notif>()
352        )
353    }
354
355    async fn handle_dispatch_from(
356        &mut self,
357        dispatch: Dispatch,
358        connection: ConnectionTo<Counterpart>,
359    ) -> Result<Handled<Dispatch>, crate::Error> {
360        handle_incoming_dispatch(
361            self.counterpart.clone(),
362            self.peer.clone(),
363            dispatch,
364            connection,
365            async |dispatch, connection| match dispatch.into_typed_dispatch::<Req, Notif>()? {
366                Ok(typed_dispatch) => {
367                    let result =
368                        (self.to_future_hack)(&mut self.handler, typed_dispatch, connection)
369                            .await?;
370                    match result.into_handled() {
371                        Handled::Yes => Ok(Handled::Yes),
372                        Handled::No {
373                            message: Dispatch::Request(request, responder),
374                            retry,
375                        } => {
376                            let untyped = request.to_untyped_message()?;
377                            Ok(Handled::No {
378                                message: Dispatch::Request(untyped, responder.erase_to_json()),
379                                retry,
380                            })
381                        }
382                        Handled::No {
383                            message: Dispatch::Notification(notification),
384                            retry,
385                        } => {
386                            let untyped = notification.to_untyped_message()?;
387                            Ok(Handled::No {
388                                message: Dispatch::Notification(untyped),
389                                retry,
390                            })
391                        }
392                        Handled::No {
393                            message: Dispatch::Response(result, responder),
394                            retry,
395                        } => {
396                            let method = responder.method();
397                            let untyped_result = match result {
398                                Ok(response) => response.into_json(method).map(Ok),
399                                Err(err) => Ok(Err(err)),
400                            }?;
401                            Ok(Handled::No {
402                                message: Dispatch::Response(
403                                    untyped_result,
404                                    responder.erase_to_json(),
405                                ),
406                                retry,
407                            })
408                        }
409                    }
410                }
411
412                Err(dispatch) => Ok(Handled::No {
413                    message: dispatch,
414                    retry: false,
415                }),
416            },
417        )
418        .await
419    }
420}
421
/// Wraps a handler with an optional name for tracing/debugging.
pub struct NamedHandler<H> {
    /// Optional label. When present, each dispatch future is wrapped with
    /// `crate::util::instrumented_with_connection_name` using this name.
    name: Option<String>,
    /// The wrapped handler that every dispatch is forwarded to.
    handler: H,
}

impl<H> NamedHandler<H> {
    /// Creates a new named handler.
    ///
    /// Passing `None` as `name` forwards dispatches to `handler` without instrumentation.
    pub fn new(name: Option<String>, handler: H) -> Self {
        Self { handler, name }
    }
}
434
435impl<Counterpart: Role, H: HandleDispatchFrom<Counterpart>> HandleDispatchFrom<Counterpart>
436    for NamedHandler<H>
437{
438    fn describe_chain(&self) -> impl std::fmt::Debug {
439        format!(
440            "NamedHandler({:?}, {:?})",
441            self.name,
442            self.handler.describe_chain()
443        )
444    }
445
446    async fn handle_dispatch_from(
447        &mut self,
448        message: Dispatch,
449        connection: ConnectionTo<Counterpart>,
450    ) -> Result<Handled<Dispatch>, crate::Error> {
451        if let Some(name) = &self.name {
452            crate::util::instrumented_with_connection_name(
453                name.clone(),
454                self.handler.handle_dispatch_from(message, connection),
455            )
456            .await
457        } else {
458            self.handler.handle_dispatch_from(message, connection).await
459        }
460    }
461}
462
/// Chains two handlers together, trying the first handler and falling back to the second.
pub struct ChainedHandler<H1, H2> {
    /// Handler offered each dispatch first.
    handler1: H1,
    /// Handler offered a dispatch only after `handler1` declines it.
    handler2: H2,
}

impl<H1, H2> ChainedHandler<H1, H2> {
    /// Creates a new chain that consults `handler1` before `handler2`.
    pub fn new(handler1: H1, handler2: H2) -> Self {
        Self { handler2, handler1 }
    }
}
475
476impl<Counterpart: Role, H1, H2> HandleDispatchFrom<Counterpart> for ChainedHandler<H1, H2>
477where
478    H1: HandleDispatchFrom<Counterpart>,
479    H2: HandleDispatchFrom<Counterpart>,
480{
481    fn describe_chain(&self) -> impl std::fmt::Debug {
482        format!(
483            "{:?}, {:?}",
484            self.handler1.describe_chain(),
485            self.handler2.describe_chain()
486        )
487    }
488
489    async fn handle_dispatch_from(
490        &mut self,
491        message: Dispatch,
492        connection: ConnectionTo<Counterpart>,
493    ) -> Result<Handled<Dispatch>, crate::Error> {
494        match self
495            .handler1
496            .handle_dispatch_from(message, connection.clone())
497            .await?
498        {
499            Handled::Yes => Ok(Handled::Yes),
500            Handled::No {
501                message,
502                retry: retry1,
503            } => match self
504                .handler2
505                .handle_dispatch_from(message, connection)
506                .await?
507            {
508                Handled::Yes => Ok(Handled::Yes),
509                Handled::No {
510                    message,
511                    retry: retry2,
512                } => Ok(Handled::No {
513                    message,
514                    retry: retry1 | retry2,
515                }),
516            },
517        }
518    }
519}