Skip to main content

agent_client_protocol/
rpc.rs

1use std::{
2    any::Any,
3    borrow::Cow,
4    collections::HashMap,
5    rc::Rc,
6    sync::{
7        Arc, Mutex,
8        atomic::{AtomicI64, Ordering},
9    },
10};
11
12use agent_client_protocol_schema::{
13    Error, JsonRpcMessage, Notification, OutgoingMessage, Request, RequestId, Response, Result,
14    Side,
15};
16use futures::{
17    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
18    StreamExt as _,
19    channel::{
20        mpsc::{self, UnboundedReceiver, UnboundedSender},
21        oneshot,
22    },
23    future::LocalBoxFuture,
24    io::BufReader,
25    select_biased,
26};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::value::RawValue;
29
30use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
31
32#[derive(Debug)]
33pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
34    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
35    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
36    next_id: AtomicI64,
37    broadcast: StreamBroadcast,
38}
39
40#[derive(Debug)]
41struct PendingResponse {
42    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
43    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
44}
45
46impl<Local, Remote> RpcConnection<Local, Remote>
47where
48    Local: Side + 'static,
49    Remote: Side + 'static,
50{
51    pub(crate) fn new<Handler>(
52        handler: Handler,
53        outgoing_bytes: impl Unpin + AsyncWrite,
54        incoming_bytes: impl Unpin + AsyncRead,
55        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
56    ) -> (Self, impl futures::Future<Output = Result<()>>)
57    where
58        Handler: MessageHandler<Local> + 'static,
59    {
60        let (incoming_tx, incoming_rx) = mpsc::unbounded();
61        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
62
63        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
64        let (broadcast_tx, broadcast) = StreamBroadcast::new();
65
66        let io_task = {
67            let pending_responses = pending_responses.clone();
68            async move {
69                let result = Self::handle_io(
70                    incoming_tx,
71                    outgoing_rx,
72                    outgoing_bytes,
73                    incoming_bytes,
74                    pending_responses.clone(),
75                    broadcast_tx,
76                )
77                .await;
78                pending_responses.lock().unwrap().clear();
79                result
80            }
81        };
82
83        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
84
85        let this = Self {
86            outgoing_tx,
87            pending_responses,
88            next_id: AtomicI64::new(0),
89            broadcast,
90        };
91
92        (this, io_task)
93    }
94
95    pub(crate) fn subscribe(&self) -> StreamReceiver {
96        self.broadcast.receiver()
97    }
98
99    pub(crate) fn notify(
100        &self,
101        method: impl Into<Arc<str>>,
102        params: Option<Remote::InNotification>,
103    ) -> Result<()> {
104        self.outgoing_tx
105            .unbounded_send(OutgoingMessage::Notification(Notification {
106                method: method.into(),
107                params,
108            }))
109            .map_err(|_| Error::internal_error().data("failed to send notification"))
110    }
111
112    pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
113        &self,
114        method: impl Into<Arc<str>>,
115        params: Option<Remote::InRequest>,
116    ) -> impl Future<Output = Result<Out>> {
117        let (tx, rx) = oneshot::channel();
118        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
119        let id = RequestId::Number(id);
120        self.pending_responses.lock().unwrap().insert(
121            id.clone(),
122            PendingResponse {
123                deserialize: |value| {
124                    serde_json::from_str::<Out>(value.get())
125                        .map(|out| Box::new(out) as _)
126                        .map_err(|_| Error::internal_error().data("failed to deserialize response"))
127                },
128                respond: tx,
129            },
130        );
131
132        if self
133            .outgoing_tx
134            .unbounded_send(OutgoingMessage::Request(Request {
135                id: id.clone(),
136                method: method.into(),
137                params,
138            }))
139            .is_err()
140        {
141            self.pending_responses.lock().unwrap().remove(&id);
142        }
143        async move {
144            let result = rx
145                .await
146                .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
147                .downcast::<Out>()
148                .map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
149
150            Ok(*result)
151        }
152    }
153
154    async fn handle_io(
155        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
156        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
157        mut outgoing_bytes: impl Unpin + AsyncWrite,
158        incoming_bytes: impl Unpin + AsyncRead,
159        pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
160        broadcast: StreamSender,
161    ) -> Result<()> {
162        // TODO: Create nicer abstraction for broadcast
163        let mut input_reader = BufReader::new(incoming_bytes);
164        let mut outgoing_line = Vec::new();
165        let mut incoming_line = String::new();
166        loop {
167            select_biased! {
168                message = outgoing_rx.next() => {
169                    if let Some(message) = message {
170                        outgoing_line.clear();
171                        serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
172                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
173                        outgoing_line.push(b'\n');
174                        outgoing_bytes.write_all(&outgoing_line).await.ok();
175                        broadcast.outgoing(&message);
176                    } else {
177                        break;
178                    }
179                }
180                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
181                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
182                        break
183                    }
184                    log::trace!("recv: {}", &incoming_line);
185
186                    match serde_json::from_str::<RawIncomingMessage<'_>>(&incoming_line) {
187                        Ok(message) => {
188                            if let Some(id) = message.id {
189                                if let Some(method) = message.method {
190                                    // Request
191                                    match Local::decode_request(&method, message.params) {
192                                        Ok(request) => {
193                                            broadcast.incoming_request(id.clone(), &*method, &request);
194                                            incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
195                                        }
196                                        Err(error) => {
197                                            outgoing_line.clear();
198                                            let error_response = OutgoingMessage::<Local, Remote>::Response(Response::Error {
199                                                id,
200                                                error,
201                                            });
202
203                                            serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
204                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
205                                            outgoing_line.push(b'\n');
206                                            outgoing_bytes.write_all(&outgoing_line).await.ok();
207                                            broadcast.outgoing(&error_response);
208                                        }
209                                    }
210                                } else if let Some(pending_response) = pending_responses.lock().unwrap().remove(&id) {
211                                    // Response
212                                    if let Some(result_value) = message.result {
213                                        broadcast.incoming_response(id, Ok(Some(result_value)));
214
215                                        let result = (pending_response.deserialize)(result_value);
216                                        pending_response.respond.send(result).ok();
217                                    } else if let Some(error) = message.error {
218                                        broadcast.incoming_response(id, Err(&error));
219
220                                        pending_response.respond.send(Err(error)).ok();
221                                    } else {
222                                        broadcast.incoming_response(id, Ok(None));
223
224                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
225                                        pending_response.respond.send(result).ok();
226                                    }
227                                } else {
228                                    log::error!("received response for unknown request id: {id:?}");
229                                }
230                            } else if let Some(method) = message.method {
231                                // Notification
232                                match Local::decode_notification(&method, message.params) {
233                                    Ok(notification) => {
234                                        broadcast.incoming_notification(&*method, &notification);
235                                        incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
236                                    }
237                                    Err(err) => {
238                                        log::error!("failed to decode {:?}: {err}", message.params);
239                                    }
240                                }
241                            } else {
242                                log::error!("received message with neither id nor method");
243                            }
244                        }
245                        Err(error) => {
246                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
247                        }
248                    }
249                    incoming_line.clear();
250                }
251            }
252        }
253        Ok(())
254    }
255
256    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
257        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
258        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
259        handler: Handler,
260        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
261    ) {
262        let spawn = Rc::new(spawn);
263        let handler = Rc::new(handler);
264        spawn({
265            let spawn = spawn.clone();
266            async move {
267                while let Some(message) = incoming_rx.next().await {
268                    match message {
269                        IncomingMessage::Request { id, request } => {
270                            let outgoing_tx = outgoing_tx.clone();
271                            let handler = handler.clone();
272                            spawn(
273                                async move {
274                                    let result = handler.handle_request(request).await;
275                                    outgoing_tx
276                                        .unbounded_send(OutgoingMessage::Response(Response::new(
277                                            id, result,
278                                        )))
279                                        .ok();
280                                }
281                                .boxed_local(),
282                            );
283                        }
284                        IncomingMessage::Notification { notification } => {
285                            let handler = handler.clone();
286                            spawn(
287                                async move {
288                                    if let Err(err) =
289                                        handler.handle_notification(notification).await
290                                    {
291                                        log::error!("failed to handle notification: {err:?}");
292                                    }
293                                }
294                                .boxed_local(),
295                            );
296                        }
297                    }
298                }
299            }
300            .boxed_local()
301        });
302    }
303}
304
305#[derive(Debug, Deserialize)]
306pub struct RawIncomingMessage<'a> {
307    id: Option<RequestId>,
308    #[serde(borrow)]
309    method: Option<Cow<'a, str>>,
310    #[serde(borrow)]
311    params: Option<&'a RawValue>,
312    #[serde(borrow)]
313    result: Option<&'a RawValue>,
314    error: Option<Error>,
315}
316
317#[derive(Debug)]
318pub enum IncomingMessage<Local: Side> {
319    Request {
320        id: RequestId,
321        request: Local::InRequest,
322    },
323    Notification {
324        notification: Local::InNotification,
325    },
326}
327
328pub trait MessageHandler<Local: Side> {
329    fn handle_request(
330        &self,
331        request: Local::InRequest,
332    ) -> impl Future<Output = Result<Local::OutResponse>>;
333
334    fn handle_notification(
335        &self,
336        notification: Local::InNotification,
337    ) -> impl Future<Output = Result<()>>;
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a [`RawIncomingMessage`] and returns its method name and
    /// the raw text of its params, both owned.
    fn method_and_params(json: &str) -> (String, String) {
        let parsed: RawIncomingMessage<'_> = serde_json::from_str(json).unwrap();
        let method = parsed.method.unwrap().into_owned();
        let params = parsed.params.unwrap().to_string();
        (method, params)
    }

    #[test]
    fn test_raw_incoming_message_with_escaped_slash() {
        // `\/` is a valid JSON escape for `/` (RFC 8259), and some encoders emit it,
        // especially behind WebSocket proxies. Unescaping yields text that differs
        // from the source bytes, so it cannot be borrowed as `&str`. The
        // `Cow<str>` field lets serde allocate an owned string instead.
        let (method, params) = method_and_params(
            r#"{"jsonrpc":"2.0","id":1,"method":"session\/update","params":{}}"#,
        );
        assert_eq!(method, "session/update");
        assert_eq!(params, "{}");
    }

    #[test]
    fn test_raw_incoming_message_without_escape() {
        // Without escapes the method name is still accepted (borrowed via `Cow::Borrowed`).
        let (method, params) = method_and_params(
            r#"{"jsonrpc":"2.0","id":2,"method":"session/update","params":{}}"#,
        );
        assert_eq!(method, "session/update");
        assert_eq!(params, "{}");
    }
}
367}