// agent_client_protocol/rpc.rs

1use std::{
2    any::Any,
3    collections::HashMap,
4    rc::Rc,
5    sync::{
6        Arc,
7        atomic::{AtomicI64, Ordering},
8    },
9};
10
11use agent_client_protocol_schema::{Error, Result};
12use derive_more::Display;
13use futures::{
14    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
15    StreamExt as _,
16    channel::{
17        mpsc::{self, UnboundedReceiver, UnboundedSender},
18        oneshot,
19    },
20    future::LocalBoxFuture,
21    io::BufReader,
22    select_biased,
23};
24use parking_lot::Mutex;
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use serde_json::value::RawValue;
27
28use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
29
30pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
31    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
32    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
33    next_id: AtomicI64,
34    broadcast: StreamBroadcast,
35}
36
37struct PendingResponse {
38    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
39    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
40}
41
42impl<Local, Remote> RpcConnection<Local, Remote>
43where
44    Local: Side + 'static,
45    Remote: Side + 'static,
46{
47    pub(crate) fn new<Handler>(
48        handler: Handler,
49        outgoing_bytes: impl Unpin + AsyncWrite,
50        incoming_bytes: impl Unpin + AsyncRead,
51        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
52    ) -> (Self, impl futures::Future<Output = Result<()>>)
53    where
54        Handler: MessageHandler<Local> + 'static,
55    {
56        let (incoming_tx, incoming_rx) = mpsc::unbounded();
57        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
58
59        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
60        let (broadcast_tx, broadcast) = StreamBroadcast::new();
61
62        let io_task = {
63            let pending_responses = pending_responses.clone();
64            async move {
65                let result = Self::handle_io(
66                    incoming_tx,
67                    outgoing_rx,
68                    outgoing_bytes,
69                    incoming_bytes,
70                    pending_responses.clone(),
71                    broadcast_tx,
72                )
73                .await;
74                pending_responses.lock().clear();
75                result
76            }
77        };
78
79        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
80
81        let this = Self {
82            outgoing_tx,
83            pending_responses,
84            next_id: AtomicI64::new(0),
85            broadcast,
86        };
87
88        (this, io_task)
89    }
90
91    pub(crate) fn subscribe(&self) -> StreamReceiver {
92        self.broadcast.receiver()
93    }
94
95    pub(crate) fn notify(
96        &self,
97        method: impl Into<Arc<str>>,
98        params: Option<Remote::InNotification>,
99    ) -> Result<()> {
100        self.outgoing_tx
101            .unbounded_send(OutgoingMessage::Notification {
102                method: method.into(),
103                params,
104            })
105            .map_err(|_| Error::internal_error().with_data("failed to send notification"))
106    }
107
108    pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
109        &self,
110        method: impl Into<Arc<str>>,
111        params: Option<Remote::InRequest>,
112    ) -> impl Future<Output = Result<Out>> {
113        let (tx, rx) = oneshot::channel();
114        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
115        let id = RequestId::Number(id);
116        self.pending_responses.lock().insert(
117            id.clone(),
118            PendingResponse {
119                deserialize: |value| {
120                    serde_json::from_str::<Out>(value.get())
121                        .map(|out| Box::new(out) as _)
122                        .map_err(|_| {
123                            Error::internal_error().with_data("failed to deserialize response")
124                        })
125                },
126                respond: tx,
127            },
128        );
129
130        if self
131            .outgoing_tx
132            .unbounded_send(OutgoingMessage::Request {
133                id: id.clone(),
134                method: method.into(),
135                params,
136            })
137            .is_err()
138        {
139            self.pending_responses.lock().remove(&id);
140        }
141        async move {
142            let result = rx
143                .await
144                .map_err(|_| Error::internal_error().with_data("server shut down unexpectedly"))??
145                .downcast::<Out>()
146                .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
147
148            Ok(*result)
149        }
150    }
151
152    async fn handle_io(
153        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
154        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
155        mut outgoing_bytes: impl Unpin + AsyncWrite,
156        incoming_bytes: impl Unpin + AsyncRead,
157        pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
158        broadcast: StreamSender,
159    ) -> Result<()> {
160        // TODO: Create nicer abstraction for broadcast
161        let mut input_reader = BufReader::new(incoming_bytes);
162        let mut outgoing_line = Vec::new();
163        let mut incoming_line = String::new();
164        loop {
165            select_biased! {
166                message = outgoing_rx.next() => {
167                    if let Some(message) = message {
168                        outgoing_line.clear();
169                        serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
170                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
171                        outgoing_line.push(b'\n');
172                        outgoing_bytes.write_all(&outgoing_line).await.ok();
173                        broadcast.outgoing(&message);
174                    } else {
175                        break;
176                    }
177                }
178                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
179                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
180                        break
181                    }
182                    log::trace!("recv: {}", &incoming_line);
183
184                    match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
185                        Ok(message) => {
186                            if let Some(id) = message.id {
187                                if let Some(method) = message.method {
188                                    // Request
189                                    match Local::decode_request(method, message.params) {
190                                        Ok(request) => {
191                                            broadcast.incoming_request(id.clone(), method, &request);
192                                            incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
193                                        }
194                                        Err(err) => {
195                                            outgoing_line.clear();
196                                            let error_response = OutgoingMessage::<Local, Remote>::Response {
197                                                id,
198                                                result: ResponseResult::Error(err),
199                                            };
200
201                                            serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
202                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
203                                            outgoing_line.push(b'\n');
204                                            outgoing_bytes.write_all(&outgoing_line).await.ok();
205                                            broadcast.outgoing(&error_response);
206                                        }
207                                    }
208                                } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
209                                    // Response
210                                    if let Some(result_value) = message.result {
211                                        broadcast.incoming_response(id, Ok(Some(result_value)));
212
213                                        let result = (pending_response.deserialize)(result_value);
214                                        pending_response.respond.send(result).ok();
215                                    } else if let Some(error) = message.error {
216                                        broadcast.incoming_response(id, Err(&error));
217
218                                        pending_response.respond.send(Err(error)).ok();
219                                    } else {
220                                        broadcast.incoming_response(id, Ok(None));
221
222                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
223                                        pending_response.respond.send(result).ok();
224                                    }
225                                } else {
226                                    log::error!("received response for unknown request id: {id:?}");
227                                }
228                            } else if let Some(method) = message.method {
229                                // Notification
230                                match Local::decode_notification(method, message.params) {
231                                    Ok(notification) => {
232                                        broadcast.incoming_notification(method, &notification);
233                                        incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
234                                    }
235                                    Err(err) => {
236                                        log::error!("failed to decode {:?}: {err}", message.params);
237                                    }
238                                }
239                            } else {
240                                log::error!("received message with neither id nor method");
241                            }
242                        }
243                        Err(error) => {
244                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
245                        }
246                    }
247                    incoming_line.clear();
248                }
249            }
250        }
251        Ok(())
252    }
253
254    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
255        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
256        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
257        handler: Handler,
258        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
259    ) {
260        let spawn = Rc::new(spawn);
261        let handler = Rc::new(handler);
262        spawn({
263            let spawn = spawn.clone();
264            async move {
265                while let Some(message) = incoming_rx.next().await {
266                    match message {
267                        IncomingMessage::Request { id, request } => {
268                            let outgoing_tx = outgoing_tx.clone();
269                            let handler = handler.clone();
270                            spawn(
271                                async move {
272                                    let result = handler.handle_request(request).await.into();
273                                    outgoing_tx
274                                        .unbounded_send(OutgoingMessage::Response { id, result })
275                                        .ok();
276                                }
277                                .boxed_local(),
278                            );
279                        }
280                        IncomingMessage::Notification { notification } => {
281                            let handler = handler.clone();
282                            spawn(
283                                async move {
284                                    if let Err(err) =
285                                        handler.handle_notification(notification).await
286                                    {
287                                        log::error!("failed to handle notification: {err:?}");
288                                    }
289                                }
290                                .boxed_local(),
291                            );
292                        }
293                    }
294                }
295            }
296            .boxed_local()
297        });
298    }
299}
300
301/// JSON RPC Request Id
302#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display)]
303#[serde(deny_unknown_fields)]
304#[serde(untagged)]
305pub enum RequestId {
306    #[display("null")]
307    Null,
308    Number(i64),
309    Str(String),
310}
311
312#[derive(Deserialize)]
313pub struct RawIncomingMessage<'a> {
314    id: Option<RequestId>,
315    method: Option<&'a str>,
316    params: Option<&'a RawValue>,
317    result: Option<&'a RawValue>,
318    error: Option<Error>,
319}
320
321pub enum IncomingMessage<Local: Side> {
322    Request {
323        id: RequestId,
324        request: Local::InRequest,
325    },
326    Notification {
327        notification: Local::InNotification,
328    },
329}
330
331#[derive(Serialize, Deserialize, Clone)]
332#[serde(untagged)]
333pub enum OutgoingMessage<Local: Side, Remote: Side> {
334    Request {
335        id: RequestId,
336        method: Arc<str>,
337        #[serde(skip_serializing_if = "Option::is_none")]
338        params: Option<Remote::InRequest>,
339    },
340    Response {
341        id: RequestId,
342        #[serde(flatten)]
343        result: ResponseResult<Local::OutResponse>,
344    },
345    Notification {
346        method: Arc<str>,
347        #[serde(skip_serializing_if = "Option::is_none")]
348        params: Option<Remote::InNotification>,
349    },
350}
351
352/// Either [`OutgoingMessage`] or [`IncomingMessage`] with `"jsonrpc": "2.0"` specified as
353/// [required by JSON-RPC 2.0 Specification][1].
354///
355/// [1]: https://www.jsonrpc.org/specification#compatibility
356#[derive(Debug, Serialize, Deserialize)]
357pub struct JsonRpcMessage<M> {
358    jsonrpc: &'static str,
359    #[serde(flatten)]
360    message: M,
361}
362
363impl<M> JsonRpcMessage<M> {
364    /// Used version of [JSON-RPC protocol].
365    ///
366    /// [JSON-RPC]: https://www.jsonrpc.org
367    pub const VERSION: &'static str = "2.0";
368
369    /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
370    /// [`JsonRpcMessage`].
371    #[must_use]
372    pub fn wrap(message: M) -> Self {
373        Self {
374            jsonrpc: Self::VERSION,
375            message,
376        }
377    }
378}
379
380#[derive(Debug, Serialize, Deserialize, Clone)]
381#[serde(rename_all = "snake_case")]
382pub enum ResponseResult<Res> {
383    Result(Res),
384    Error(Error),
385}
386
387impl<T> From<Result<T>> for ResponseResult<T> {
388    fn from(result: Result<T>) -> Self {
389        match result {
390            Ok(value) => ResponseResult::Result(value),
391            Err(error) => ResponseResult::Error(error),
392        }
393    }
394}
395
396pub trait Side: Clone {
397    type InRequest: Clone + Serialize + DeserializeOwned + 'static;
398    type OutResponse: Clone + Serialize + DeserializeOwned + 'static;
399    type InNotification: Clone + Serialize + DeserializeOwned + 'static;
400
401    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
402
403    fn decode_notification(method: &str, params: Option<&RawValue>)
404    -> Result<Self::InNotification>;
405}
406
407pub trait MessageHandler<Local: Side> {
408    fn handle_request(
409        &self,
410        request: Local::InRequest,
411    ) -> impl Future<Output = Result<Local::OutResponse>>;
412
413    fn handle_notification(
414        &self,
415        notification: Local::InNotification,
416    ) -> impl Future<Output = Result<()>>;
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Value, json};

    /// Each `RequestId` variant paired with its JSON form.
    fn json_cases() -> Vec<(RequestId, Value)> {
        vec![
            (RequestId::Null, Value::Null),
            (RequestId::Number(1), json!(1)),
            (RequestId::Number(-1), json!(-1)),
            (RequestId::Str("id".to_owned()), json!("id")),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (expected, value) in json_cases() {
            let parsed = serde_json::from_value::<RequestId>(value).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (id, expected) in json_cases() {
            let serialized = serde_json::to_value(id).unwrap();
            assert_eq!(serialized, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}