agent_client_protocol/
rpc.rs

1use std::{
2    any::Any,
3    collections::HashMap,
4    rc::Rc,
5    sync::{
6        Arc,
7        atomic::{AtomicI64, Ordering},
8    },
9};
10
11use agent_client_protocol_schema::{
12    Error, JsonRpcMessage, OutgoingMessage, RequestId, ResponseResult, Result, Side,
13};
14use futures::{
15    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
16    StreamExt as _,
17    channel::{
18        mpsc::{self, UnboundedReceiver, UnboundedSender},
19        oneshot,
20    },
21    future::LocalBoxFuture,
22    io::BufReader,
23    select_biased,
24};
25use parking_lot::Mutex;
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::value::RawValue;
28
29use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
30
31pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
32    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
33    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
34    next_id: AtomicI64,
35    broadcast: StreamBroadcast,
36}
37
38struct PendingResponse {
39    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
40    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
41}
42
43impl<Local, Remote> RpcConnection<Local, Remote>
44where
45    Local: Side + 'static,
46    Remote: Side + 'static,
47{
48    pub(crate) fn new<Handler>(
49        handler: Handler,
50        outgoing_bytes: impl Unpin + AsyncWrite,
51        incoming_bytes: impl Unpin + AsyncRead,
52        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
53    ) -> (Self, impl futures::Future<Output = Result<()>>)
54    where
55        Handler: MessageHandler<Local> + 'static,
56    {
57        let (incoming_tx, incoming_rx) = mpsc::unbounded();
58        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
59
60        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
61        let (broadcast_tx, broadcast) = StreamBroadcast::new();
62
63        let io_task = {
64            let pending_responses = pending_responses.clone();
65            async move {
66                let result = Self::handle_io(
67                    incoming_tx,
68                    outgoing_rx,
69                    outgoing_bytes,
70                    incoming_bytes,
71                    pending_responses.clone(),
72                    broadcast_tx,
73                )
74                .await;
75                pending_responses.lock().clear();
76                result
77            }
78        };
79
80        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
81
82        let this = Self {
83            outgoing_tx,
84            pending_responses,
85            next_id: AtomicI64::new(0),
86            broadcast,
87        };
88
89        (this, io_task)
90    }
91
92    pub(crate) fn subscribe(&self) -> StreamReceiver {
93        self.broadcast.receiver()
94    }
95
96    pub(crate) fn notify(
97        &self,
98        method: impl Into<Arc<str>>,
99        params: Option<Remote::InNotification>,
100    ) -> Result<()> {
101        self.outgoing_tx
102            .unbounded_send(OutgoingMessage::Notification {
103                method: method.into(),
104                params,
105            })
106            .map_err(|_| Error::internal_error().with_data("failed to send notification"))
107    }
108
109    pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
110        &self,
111        method: impl Into<Arc<str>>,
112        params: Option<Remote::InRequest>,
113    ) -> impl Future<Output = Result<Out>> {
114        let (tx, rx) = oneshot::channel();
115        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
116        let id = RequestId::Number(id);
117        self.pending_responses.lock().insert(
118            id.clone(),
119            PendingResponse {
120                deserialize: |value| {
121                    serde_json::from_str::<Out>(value.get())
122                        .map(|out| Box::new(out) as _)
123                        .map_err(|_| {
124                            Error::internal_error().with_data("failed to deserialize response")
125                        })
126                },
127                respond: tx,
128            },
129        );
130
131        if self
132            .outgoing_tx
133            .unbounded_send(OutgoingMessage::Request {
134                id: id.clone(),
135                method: method.into(),
136                params,
137            })
138            .is_err()
139        {
140            self.pending_responses.lock().remove(&id);
141        }
142        async move {
143            let result = rx
144                .await
145                .map_err(|_| Error::internal_error().with_data("server shut down unexpectedly"))??
146                .downcast::<Out>()
147                .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
148
149            Ok(*result)
150        }
151    }
152
153    async fn handle_io(
154        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
155        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
156        mut outgoing_bytes: impl Unpin + AsyncWrite,
157        incoming_bytes: impl Unpin + AsyncRead,
158        pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
159        broadcast: StreamSender,
160    ) -> Result<()> {
161        // TODO: Create nicer abstraction for broadcast
162        let mut input_reader = BufReader::new(incoming_bytes);
163        let mut outgoing_line = Vec::new();
164        let mut incoming_line = String::new();
165        loop {
166            select_biased! {
167                message = outgoing_rx.next() => {
168                    if let Some(message) = message {
169                        outgoing_line.clear();
170                        serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
171                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
172                        outgoing_line.push(b'\n');
173                        outgoing_bytes.write_all(&outgoing_line).await.ok();
174                        broadcast.outgoing(&message);
175                    } else {
176                        break;
177                    }
178                }
179                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
180                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
181                        break
182                    }
183                    log::trace!("recv: {}", &incoming_line);
184
185                    match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
186                        Ok(message) => {
187                            if let Some(id) = message.id {
188                                if let Some(method) = message.method {
189                                    // Request
190                                    match Local::decode_request(method, message.params) {
191                                        Ok(request) => {
192                                            broadcast.incoming_request(id.clone(), method, &request);
193                                            incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
194                                        }
195                                        Err(err) => {
196                                            outgoing_line.clear();
197                                            let error_response = OutgoingMessage::<Local, Remote>::Response {
198                                                id,
199                                                result: ResponseResult::Error(err),
200                                            };
201
202                                            serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
203                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
204                                            outgoing_line.push(b'\n');
205                                            outgoing_bytes.write_all(&outgoing_line).await.ok();
206                                            broadcast.outgoing(&error_response);
207                                        }
208                                    }
209                                } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
210                                    // Response
211                                    if let Some(result_value) = message.result {
212                                        broadcast.incoming_response(id, Ok(Some(result_value)));
213
214                                        let result = (pending_response.deserialize)(result_value);
215                                        pending_response.respond.send(result).ok();
216                                    } else if let Some(error) = message.error {
217                                        broadcast.incoming_response(id, Err(&error));
218
219                                        pending_response.respond.send(Err(error)).ok();
220                                    } else {
221                                        broadcast.incoming_response(id, Ok(None));
222
223                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
224                                        pending_response.respond.send(result).ok();
225                                    }
226                                } else {
227                                    log::error!("received response for unknown request id: {id:?}");
228                                }
229                            } else if let Some(method) = message.method {
230                                // Notification
231                                match Local::decode_notification(method, message.params) {
232                                    Ok(notification) => {
233                                        broadcast.incoming_notification(method, &notification);
234                                        incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
235                                    }
236                                    Err(err) => {
237                                        log::error!("failed to decode {:?}: {err}", message.params);
238                                    }
239                                }
240                            } else {
241                                log::error!("received message with neither id nor method");
242                            }
243                        }
244                        Err(error) => {
245                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
246                        }
247                    }
248                    incoming_line.clear();
249                }
250            }
251        }
252        Ok(())
253    }
254
255    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
256        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
257        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
258        handler: Handler,
259        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
260    ) {
261        let spawn = Rc::new(spawn);
262        let handler = Rc::new(handler);
263        spawn({
264            let spawn = spawn.clone();
265            async move {
266                while let Some(message) = incoming_rx.next().await {
267                    match message {
268                        IncomingMessage::Request { id, request } => {
269                            let outgoing_tx = outgoing_tx.clone();
270                            let handler = handler.clone();
271                            spawn(
272                                async move {
273                                    let result = handler.handle_request(request).await.into();
274                                    outgoing_tx
275                                        .unbounded_send(OutgoingMessage::Response { id, result })
276                                        .ok();
277                                }
278                                .boxed_local(),
279                            );
280                        }
281                        IncomingMessage::Notification { notification } => {
282                            let handler = handler.clone();
283                            spawn(
284                                async move {
285                                    if let Err(err) =
286                                        handler.handle_notification(notification).await
287                                    {
288                                        log::error!("failed to handle notification: {err:?}");
289                                    }
290                                }
291                                .boxed_local(),
292                            );
293                        }
294                    }
295                }
296            }
297            .boxed_local()
298        });
299    }
300}
301
302#[derive(Deserialize)]
303pub struct RawIncomingMessage<'a> {
304    id: Option<RequestId>,
305    method: Option<&'a str>,
306    params: Option<&'a RawValue>,
307    result: Option<&'a RawValue>,
308    error: Option<Error>,
309}
310
311pub enum IncomingMessage<Local: Side> {
312    Request {
313        id: RequestId,
314        request: Local::InRequest,
315    },
316    Notification {
317        notification: Local::InNotification,
318    },
319}
320
321pub trait MessageHandler<Local: Side> {
322    fn handle_request(
323        &self,
324        request: Local::InRequest,
325    ) -> impl Future<Output = Result<Local::OutResponse>>;
326
327    fn handle_notification(
328        &self,
329        notification: Local::InNotification,
330    ) -> impl Future<Output = Result<()>>;
331}