agent_client_protocol/
rpc.rs

1use std::{
2    any::Any,
3    collections::HashMap,
4    rc::Rc,
5    sync::{
6        Arc, Mutex,
7        atomic::{AtomicI64, Ordering},
8    },
9};
10
11use agent_client_protocol_schema::{
12    Error, JsonRpcMessage, Notification, OutgoingMessage, Request, RequestId, Response, Result,
13    Side,
14};
15use futures::{
16    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
17    StreamExt as _,
18    channel::{
19        mpsc::{self, UnboundedReceiver, UnboundedSender},
20        oneshot,
21    },
22    future::LocalBoxFuture,
23    io::BufReader,
24    select_biased,
25};
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::value::RawValue;
28
29use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
30
31#[derive(Debug)]
32pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
33    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
34    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
35    next_id: AtomicI64,
36    broadcast: StreamBroadcast,
37}
38
39#[derive(Debug)]
40struct PendingResponse {
41    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
42    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
43}
44
45impl<Local, Remote> RpcConnection<Local, Remote>
46where
47    Local: Side + 'static,
48    Remote: Side + 'static,
49{
50    pub(crate) fn new<Handler>(
51        handler: Handler,
52        outgoing_bytes: impl Unpin + AsyncWrite,
53        incoming_bytes: impl Unpin + AsyncRead,
54        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
55    ) -> (Self, impl futures::Future<Output = Result<()>>)
56    where
57        Handler: MessageHandler<Local> + 'static,
58    {
59        let (incoming_tx, incoming_rx) = mpsc::unbounded();
60        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
61
62        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
63        let (broadcast_tx, broadcast) = StreamBroadcast::new();
64
65        let io_task = {
66            let pending_responses = pending_responses.clone();
67            async move {
68                let result = Self::handle_io(
69                    incoming_tx,
70                    outgoing_rx,
71                    outgoing_bytes,
72                    incoming_bytes,
73                    pending_responses.clone(),
74                    broadcast_tx,
75                )
76                .await;
77                pending_responses.lock().unwrap().clear();
78                result
79            }
80        };
81
82        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
83
84        let this = Self {
85            outgoing_tx,
86            pending_responses,
87            next_id: AtomicI64::new(0),
88            broadcast,
89        };
90
91        (this, io_task)
92    }
93
94    pub(crate) fn subscribe(&self) -> StreamReceiver {
95        self.broadcast.receiver()
96    }
97
98    pub(crate) fn notify(
99        &self,
100        method: impl Into<Arc<str>>,
101        params: Option<Remote::InNotification>,
102    ) -> Result<()> {
103        self.outgoing_tx
104            .unbounded_send(OutgoingMessage::Notification(Notification {
105                method: method.into(),
106                params,
107            }))
108            .map_err(|_| Error::internal_error().data("failed to send notification"))
109    }
110
111    pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
112        &self,
113        method: impl Into<Arc<str>>,
114        params: Option<Remote::InRequest>,
115    ) -> impl Future<Output = Result<Out>> {
116        let (tx, rx) = oneshot::channel();
117        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
118        let id = RequestId::Number(id);
119        self.pending_responses.lock().unwrap().insert(
120            id.clone(),
121            PendingResponse {
122                deserialize: |value| {
123                    serde_json::from_str::<Out>(value.get())
124                        .map(|out| Box::new(out) as _)
125                        .map_err(|_| Error::internal_error().data("failed to deserialize response"))
126                },
127                respond: tx,
128            },
129        );
130
131        if self
132            .outgoing_tx
133            .unbounded_send(OutgoingMessage::Request(Request {
134                id: id.clone(),
135                method: method.into(),
136                params,
137            }))
138            .is_err()
139        {
140            self.pending_responses.lock().unwrap().remove(&id);
141        }
142        async move {
143            let result = rx
144                .await
145                .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
146                .downcast::<Out>()
147                .map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
148
149            Ok(*result)
150        }
151    }
152
153    async fn handle_io(
154        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
155        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
156        mut outgoing_bytes: impl Unpin + AsyncWrite,
157        incoming_bytes: impl Unpin + AsyncRead,
158        pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
159        broadcast: StreamSender,
160    ) -> Result<()> {
161        // TODO: Create nicer abstraction for broadcast
162        let mut input_reader = BufReader::new(incoming_bytes);
163        let mut outgoing_line = Vec::new();
164        let mut incoming_line = String::new();
165        loop {
166            select_biased! {
167                message = outgoing_rx.next() => {
168                    if let Some(message) = message {
169                        outgoing_line.clear();
170                        serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
171                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
172                        outgoing_line.push(b'\n');
173                        outgoing_bytes.write_all(&outgoing_line).await.ok();
174                        broadcast.outgoing(&message);
175                    } else {
176                        break;
177                    }
178                }
179                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
180                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
181                        break
182                    }
183                    log::trace!("recv: {}", &incoming_line);
184
185                    match serde_json::from_str::<RawIncomingMessage<'_>>(&incoming_line) {
186                        Ok(message) => {
187                            if let Some(id) = message.id {
188                                if let Some(method) = message.method {
189                                    // Request
190                                    match Local::decode_request(method, message.params) {
191                                        Ok(request) => {
192                                            broadcast.incoming_request(id.clone(), method, &request);
193                                            incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
194                                        }
195                                        Err(error) => {
196                                            outgoing_line.clear();
197                                            let error_response = OutgoingMessage::<Local, Remote>::Response(Response::Error {
198                                                id,
199                                                error,
200                                            });
201
202                                            serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
203                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
204                                            outgoing_line.push(b'\n');
205                                            outgoing_bytes.write_all(&outgoing_line).await.ok();
206                                            broadcast.outgoing(&error_response);
207                                        }
208                                    }
209                                } else if let Some(pending_response) = pending_responses.lock().unwrap().remove(&id) {
210                                    // Response
211                                    if let Some(result_value) = message.result {
212                                        broadcast.incoming_response(id, Ok(Some(result_value)));
213
214                                        let result = (pending_response.deserialize)(result_value);
215                                        pending_response.respond.send(result).ok();
216                                    } else if let Some(error) = message.error {
217                                        broadcast.incoming_response(id, Err(&error));
218
219                                        pending_response.respond.send(Err(error)).ok();
220                                    } else {
221                                        broadcast.incoming_response(id, Ok(None));
222
223                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
224                                        pending_response.respond.send(result).ok();
225                                    }
226                                } else {
227                                    log::error!("received response for unknown request id: {id:?}");
228                                }
229                            } else if let Some(method) = message.method {
230                                // Notification
231                                match Local::decode_notification(method, message.params) {
232                                    Ok(notification) => {
233                                        broadcast.incoming_notification(method, &notification);
234                                        incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
235                                    }
236                                    Err(err) => {
237                                        log::error!("failed to decode {:?}: {err}", message.params);
238                                    }
239                                }
240                            } else {
241                                log::error!("received message with neither id nor method");
242                            }
243                        }
244                        Err(error) => {
245                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
246                        }
247                    }
248                    incoming_line.clear();
249                }
250            }
251        }
252        Ok(())
253    }
254
255    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
256        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
257        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
258        handler: Handler,
259        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
260    ) {
261        let spawn = Rc::new(spawn);
262        let handler = Rc::new(handler);
263        spawn({
264            let spawn = spawn.clone();
265            async move {
266                while let Some(message) = incoming_rx.next().await {
267                    match message {
268                        IncomingMessage::Request { id, request } => {
269                            let outgoing_tx = outgoing_tx.clone();
270                            let handler = handler.clone();
271                            spawn(
272                                async move {
273                                    let result = handler.handle_request(request).await;
274                                    outgoing_tx
275                                        .unbounded_send(OutgoingMessage::Response(Response::new(
276                                            id, result,
277                                        )))
278                                        .ok();
279                                }
280                                .boxed_local(),
281                            );
282                        }
283                        IncomingMessage::Notification { notification } => {
284                            let handler = handler.clone();
285                            spawn(
286                                async move {
287                                    if let Err(err) =
288                                        handler.handle_notification(notification).await
289                                    {
290                                        log::error!("failed to handle notification: {err:?}");
291                                    }
292                                }
293                                .boxed_local(),
294                            );
295                        }
296                    }
297                }
298            }
299            .boxed_local()
300        });
301    }
302}
303
304#[derive(Debug, Deserialize)]
305pub struct RawIncomingMessage<'a> {
306    id: Option<RequestId>,
307    method: Option<&'a str>,
308    params: Option<&'a RawValue>,
309    result: Option<&'a RawValue>,
310    error: Option<Error>,
311}
312
313#[derive(Debug)]
314pub enum IncomingMessage<Local: Side> {
315    Request {
316        id: RequestId,
317        request: Local::InRequest,
318    },
319    Notification {
320        notification: Local::InNotification,
321    },
322}
323
324pub trait MessageHandler<Local: Side> {
325    fn handle_request(
326        &self,
327        request: Local::InRequest,
328    ) -> impl Future<Output = Result<Local::OutResponse>>;
329
330    fn handle_notification(
331        &self,
332        notification: Local::InNotification,
333    ) -> impl Future<Output = Result<()>>;
334}