Skip to main content

agent_client_protocol/
rpc.rs

1use std::{
2    any::Any,
3    borrow::Cow,
4    collections::HashMap,
5    rc::Rc,
6    sync::{
7        Arc, Mutex,
8        atomic::{AtomicI64, Ordering},
9    },
10};
11
12use agent_client_protocol_schema::{
13    Error, JsonRpcMessage, Notification, OutgoingMessage, Request, RequestId, Response, Result,
14    Side,
15};
16use futures::{
17    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
18    StreamExt as _,
19    channel::{
20        mpsc::{self, UnboundedReceiver, UnboundedSender},
21        oneshot,
22    },
23    future::LocalBoxFuture,
24    io::BufReader,
25    select_biased,
26};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::value::RawValue;
29
30use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
31
/// One end of a JSON-RPC connection: `Local` is this side (it decodes what the
/// peer sends) and `Remote` is the peer (it defines what this side sends).
///
/// Outgoing messages are only queued here; they are serialized and written by
/// the I/O future returned from [`RpcConnection::new`].
#[derive(Debug)]
pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
    /// Queue of messages for the I/O task to serialize and write to the peer.
    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
    /// Requests sent by this side that are still awaiting a response, keyed by
    /// request id. Shared with the I/O task, which resolves (or clears) them.
    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
    /// Counter used to allocate the numeric id of each outgoing request.
    next_id: AtomicI64,
    /// Hands out receivers via `subscribe`; the I/O task reports traffic to it.
    broadcast: StreamBroadcast,
}
39
/// Bookkeeping for an outgoing request that has not yet received a response.
#[derive(Debug)]
struct PendingResponse {
    /// Turns the raw `result` JSON into the type the caller of
    /// `RpcConnection::request` asked for, type-erased as `dyn Any` so requests
    /// with different response types can share one map.
    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
    /// Completes the future returned by `RpcConnection::request`.
    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
}
45
46impl<Local, Remote> RpcConnection<Local, Remote>
47where
48    Local: Side + 'static,
49    Remote: Side + 'static,
50{
51    pub(crate) fn new<Handler>(
52        handler: Handler,
53        outgoing_bytes: impl Unpin + AsyncWrite,
54        incoming_bytes: impl Unpin + AsyncRead,
55        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
56    ) -> (Self, impl futures::Future<Output = Result<()>>)
57    where
58        Handler: MessageHandler<Local> + 'static,
59    {
60        let (incoming_tx, incoming_rx) = mpsc::unbounded();
61        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
62
63        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
64        let (broadcast_tx, broadcast) = StreamBroadcast::new();
65
66        let io_task = {
67            let pending_responses = pending_responses.clone();
68            async move {
69                let result = Self::handle_io(
70                    incoming_tx,
71                    outgoing_rx,
72                    outgoing_bytes,
73                    incoming_bytes,
74                    pending_responses.clone(),
75                    broadcast_tx,
76                )
77                .await;
78                pending_responses.lock().unwrap().clear();
79                result
80            }
81        };
82
83        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
84
85        let this = Self {
86            outgoing_tx,
87            pending_responses,
88            next_id: AtomicI64::new(0),
89            broadcast,
90        };
91
92        (this, io_task)
93    }
94
95    pub(crate) fn subscribe(&self) -> StreamReceiver {
96        self.broadcast.receiver()
97    }
98
99    pub(crate) fn notify(
100        &self,
101        method: impl Into<Arc<str>>,
102        params: Option<Remote::InNotification>,
103    ) -> Result<()> {
104        self.outgoing_tx
105            .unbounded_send(OutgoingMessage::Notification(Notification {
106                method: method.into(),
107                params,
108            }))
109            .map_err(|_| Error::internal_error().data("failed to send notification"))
110    }
111
112    pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
113        &self,
114        method: impl Into<Arc<str>>,
115        params: Option<Remote::InRequest>,
116    ) -> Result<impl Future<Output = Result<Out>>> {
117        let (tx, rx) = oneshot::channel();
118        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
119        let id = RequestId::Number(id);
120        self.pending_responses.lock().unwrap().insert(
121            id.clone(),
122            PendingResponse {
123                deserialize: |value| {
124                    serde_json::from_str::<Out>(value.get())
125                        .map(|out| Box::new(out) as _)
126                        .map_err(|_| Error::internal_error().data("failed to deserialize response"))
127                },
128                respond: tx,
129            },
130        );
131
132        if self
133            .outgoing_tx
134            .unbounded_send(OutgoingMessage::Request(Request {
135                id: id.clone(),
136                method: method.into(),
137                params,
138            }))
139            .is_err()
140        {
141            self.pending_responses.lock().unwrap().remove(&id);
142            return Err(
143                Error::internal_error().data("connection closed before request could be sent")
144            );
145        }
146        Ok(async move {
147            let result = rx
148                .await
149                .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
150                .downcast::<Out>()
151                .map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
152
153            Ok(*result)
154        })
155    }
156
157    async fn handle_io(
158        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
159        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
160        mut outgoing_bytes: impl Unpin + AsyncWrite,
161        incoming_bytes: impl Unpin + AsyncRead,
162        pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
163        broadcast: StreamSender,
164    ) -> Result<()> {
165        // TODO: Create nicer abstraction for broadcast
166        let mut input_reader = BufReader::new(incoming_bytes);
167        let mut outgoing_line = Vec::new();
168        let mut incoming_line = String::new();
169        loop {
170            select_biased! {
171                message = outgoing_rx.next() => {
172                    if let Some(message) = message {
173                        outgoing_line.clear();
174                        serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
175                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
176                        outgoing_line.push(b'\n');
177                        if let Err(e) = outgoing_bytes.write_all(&outgoing_line).await {
178                            log::warn!("failed to send message to peer: {e}");
179                        }
180                        broadcast.outgoing(&message);
181                    } else {
182                        break;
183                    }
184                }
185                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
186                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
187                        break
188                    }
189                    log::trace!("recv: {}", &incoming_line);
190
191                    match serde_json::from_str::<RawIncomingMessage<'_>>(&incoming_line) {
192                        Ok(message) => {
193                            if let Some(id) = message.id {
194                                if let Some(method) = message.method {
195                                    // Request
196                                    match Local::decode_request(&method, message.params) {
197                                        Ok(request) => {
198                                            broadcast.incoming_request(id.clone(), &*method, &request);
199                                            if let Err(e) = incoming_tx.unbounded_send(IncomingMessage::Request { id, request }) {
200                                                log::warn!("failed to send request to handler, channel full: {e:?}");
201                                            }
202                                        }
203                                        Err(error) => {
204                                            outgoing_line.clear();
205                                            let error_response = OutgoingMessage::<Local, Remote>::Response(Response::Error {
206                                                id,
207                                                error,
208                                            });
209
210                                            serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
211                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
212                                            outgoing_line.push(b'\n');
213                                            if let Err(e) = outgoing_bytes.write_all(&outgoing_line).await {
214                                                log::warn!("failed to send error response to peer: {e}");
215                                            }
216                                            broadcast.outgoing(&error_response);
217                                        }
218                                    }
219                                } else if let Some(pending_response) = pending_responses.lock().unwrap().remove(&id) {
220                                    // Response
221                                    if let Some(result_value) = message.result {
222                                        broadcast.incoming_response(id, Ok(Some(result_value)));
223
224                                        let result = (pending_response.deserialize)(result_value);
225                                        pending_response.respond.send(result).ok();
226                                    } else if let Some(error) = message.error {
227                                        broadcast.incoming_response(id, Err(&error));
228
229                                        pending_response.respond.send(Err(error)).ok();
230                                    } else {
231                                        broadcast.incoming_response(id, Ok(None));
232
233                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
234                                        pending_response.respond.send(result).ok();
235                                    }
236                                } else {
237                                    log::error!("received response for unknown request id: {id:?}");
238                                }
239                            } else if let Some(method) = message.method {
240                                // Notification
241                                match Local::decode_notification(&method, message.params) {
242                                    Ok(notification) => {
243                                        broadcast.incoming_notification(&*method, &notification);
244                                        if let Err(e) = incoming_tx.unbounded_send(IncomingMessage::Notification { notification }) {
245                                            log::warn!("failed to send notification to handler, channel full: {e:?}");
246                                        }
247                                    }
248                                    Err(err) => {
249                                        log::error!("failed to decode {:?}: {err}", message.params);
250                                    }
251                                }
252                            } else {
253                                log::error!("received message with neither id nor method");
254                            }
255                        }
256                        Err(error) => {
257                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
258                        }
259                    }
260                    incoming_line.clear();
261                }
262            }
263        }
264        Ok(())
265    }
266
267    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
268        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
269        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
270        handler: Handler,
271        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
272    ) {
273        let spawn = Rc::new(spawn);
274        let handler = Rc::new(handler);
275        spawn({
276            let spawn = spawn.clone();
277            async move {
278                while let Some(message) = incoming_rx.next().await {
279                    match message {
280                        IncomingMessage::Request { id, request } => {
281                            let outgoing_tx = outgoing_tx.clone();
282                            let handler = handler.clone();
283                            spawn(
284                                async move {
285                                    let result = handler.handle_request(request).await;
286                                    outgoing_tx
287                                        .unbounded_send(OutgoingMessage::Response(Response::new(
288                                            id, result,
289                                        )))
290                                        .ok();
291                                }
292                                .boxed_local(),
293                            );
294                        }
295                        IncomingMessage::Notification { notification } => {
296                            let handler = handler.clone();
297                            spawn(
298                                async move {
299                                    if let Err(err) =
300                                        handler.handle_notification(notification).await
301                                    {
302                                        log::error!("failed to handle notification: {err:?}");
303                                    }
304                                }
305                                .boxed_local(),
306                            );
307                        }
308                    }
309                }
310            }
311            .boxed_local()
312        });
313    }
314}
315
/// Loosely typed view of one incoming line. It is used to classify the line as
/// a request, response, or notification before its payload is decoded.
///
/// `params` and `result` borrow the raw JSON straight from the input line.
/// `method` is a `Cow` so that a name containing JSON escapes (e.g. `\/`) can be
/// unescaped into an owned string instead of failing to deserialize. Keys not
/// listed here (such as `jsonrpc`) are ignored, because the derive does not
/// deny unknown fields.
#[derive(Debug, Deserialize)]
pub struct RawIncomingMessage<'a> {
    /// Present on requests and responses. A message without an id is treated
    /// as a notification.
    id: Option<RequestId>,
    /// Present on requests and notifications; absent on responses.
    #[serde(borrow)]
    method: Option<Cow<'a, str>>,
    /// Undecoded request or notification parameters.
    #[serde(borrow)]
    params: Option<&'a RawValue>,
    /// Undecoded `result` of a successful response.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
    /// `error` member of a failed response.
    error: Option<Error>,
}
327
/// A decoded message from the peer, passed from the I/O task to the handler
/// dispatcher.
#[derive(Debug)]
pub enum IncomingMessage<Local: Side> {
    /// A request. Its `id` is attached to the response built from the
    /// handler's result.
    Request {
        id: RequestId,
        request: Local::InRequest,
    },
    /// A notification, which is not answered.
    Notification {
        notification: Local::InNotification,
    },
}
338
/// Application logic that `RpcConnection` invokes for messages received from
/// the peer. Each call runs in its own spawned task.
pub trait MessageHandler<Local: Side> {
    /// Handles an incoming request. The returned value, or error, is sent back
    /// to the peer as the response to that request.
    fn handle_request(
        &self,
        request: Local::InRequest,
    ) -> impl Future<Output = Result<Local::OutResponse>>;

    /// Handles an incoming notification. No reply is sent. `RpcConnection`
    /// only logs a returned error.
    fn handle_notification(
        &self,
        notification: Local::InNotification,
    ) -> impl Future<Output = Result<()>>;
}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a [`RawIncomingMessage`], panicking on failure.
    fn parse(json: &str) -> RawIncomingMessage<'_> {
        serde_json::from_str(json).unwrap()
    }

    /// Checks that `message` names the `session/update` method and carries an
    /// empty-object `params`.
    fn assert_session_update(message: RawIncomingMessage<'_>) {
        assert_eq!(message.method.unwrap(), "session/update");
        assert_eq!(message.params.unwrap().to_string(), "{}");
    }

    #[test]
    fn test_raw_incoming_message_with_escaped_slash() {
        // RFC 8259 allows `\/` as an escape for `/`, and some encoders (for
        // example ones behind WebSocket proxies) emit it. The unescaped name no
        // longer matches the input bytes, so `method` must be able to own its
        // data. `Cow<str>` allows that; a plain `&str` field would reject it.
        assert_session_update(parse(
            r#"{"jsonrpc":"2.0","id":1,"method":"session\/update","params":{}}"#,
        ));
    }

    #[test]
    fn test_raw_incoming_message_without_escape() {
        // Unescaped names keep working; they can be borrowed as `Cow::Borrowed`.
        assert_session_update(parse(
            r#"{"jsonrpc":"2.0","id":2,"method":"session/update","params":{}}"#,
        ));
    }
}