agent_client_protocol/
rpc.rs

1use std::{
2    any::Any,
3    collections::HashMap,
4    rc::Rc,
5    sync::{
6        Arc,
7        atomic::{AtomicI32, Ordering},
8    },
9};
10
11use anyhow::Result;
12use futures::{
13    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
14    StreamExt as _,
15    channel::{
16        mpsc::{self, UnboundedReceiver, UnboundedSender},
17        oneshot,
18    },
19    future::LocalBoxFuture,
20    io::BufReader,
21    select_biased,
22};
23use parking_lot::Mutex;
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use serde_json::value::RawValue;
26
27use crate::Error;
28
29pub struct RpcConnection<Local: Side, Remote: Side> {
30    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
31    pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
32    next_id: AtomicI32,
33}
34
35struct PendingResponse {
36    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>, Error>,
37    respond: oneshot::Sender<Result<Box<dyn Any + Send>, Error>>,
38}
39
40impl<Local, Remote> RpcConnection<Local, Remote>
41where
42    Local: Side + 'static,
43    Remote: Side + 'static,
44{
45    pub fn new<Handler>(
46        handler: Handler,
47        outgoing_bytes: impl Unpin + AsyncWrite,
48        incoming_bytes: impl Unpin + AsyncRead,
49        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
50    ) -> (Self, impl futures::Future<Output = Result<()>>)
51    where
52        Handler: MessageHandler<Local> + 'static,
53    {
54        let (incoming_tx, incoming_rx) = mpsc::unbounded();
55        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
56
57        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
58
59        let io_task = Self::handle_io(
60            incoming_tx,
61            outgoing_rx,
62            outgoing_bytes,
63            incoming_bytes,
64            pending_responses.clone(),
65        );
66
67        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
68
69        let this = Self {
70            outgoing_tx,
71            pending_responses,
72            next_id: AtomicI32::new(0),
73        };
74
75        (this, io_task)
76    }
77
78    pub fn notify(
79        &self,
80        method: &'static str,
81        params: Option<Remote::InNotification>,
82    ) -> Result<(), Error> {
83        self.outgoing_tx
84            .unbounded_send(OutgoingMessage::Notification { method, params })
85            .map_err(|_| Error::internal_error().with_data("failed to send notification"))
86    }
87
88    pub fn request<Out: DeserializeOwned + Send + 'static>(
89        &self,
90        method: &'static str,
91        params: Option<Remote::InRequest>,
92    ) -> impl Future<Output = Result<Out, Error>> {
93        let (tx, rx) = oneshot::channel();
94        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
95        self.pending_responses.lock().insert(
96            id,
97            PendingResponse {
98                deserialize: |value| {
99                    serde_json::from_str::<Out>(value.get())
100                        .map(|out| Box::new(out) as _)
101                        .map_err(|_| {
102                            Error::internal_error().with_data("failed to deserialize response")
103                        })
104                },
105                respond: tx,
106            },
107        );
108
109        if self
110            .outgoing_tx
111            .unbounded_send(OutgoingMessage::Request { id, method, params })
112            .is_err()
113        {
114            self.pending_responses.lock().remove(&id);
115        }
116        async move {
117            let result = rx
118                .await
119                .map_err(|e| Error::internal_error().with_data(e.to_string()))??
120                .downcast::<Out>()
121                .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
122
123            Ok(*result)
124        }
125    }
126
127    async fn handle_io(
128        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
129        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
130        mut outgoing_bytes: impl Unpin + AsyncWrite,
131        incoming_bytes: impl Unpin + AsyncRead,
132        pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
133    ) -> Result<()> {
134        let mut input_reader = BufReader::new(incoming_bytes);
135        let mut outgoing_line = Vec::new();
136        let mut incoming_line = String::new();
137        loop {
138            select_biased! {
139                message = outgoing_rx.next() => {
140                    if let Some(message) = message {
141                        outgoing_line.clear();
142                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
143                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
144                        outgoing_line.push(b'\n');
145                        outgoing_bytes.write_all(&outgoing_line).await.ok();
146                    } else {
147                        break;
148                    }
149                }
150                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
151                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
152                        break
153                    }
154                    log::trace!("recv: {}", &incoming_line);
155
156                    match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
157                        Ok(message) => {
158                            if let Some(id) = message.id {
159                                if let Some(method) = message.method {
160                                    // Request
161                                    match Local::decode_request(method, message.params) {
162                                        Ok(request) => {
163                                            incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
164                                        }
165                                        Err(err) => {
166                                            outgoing_line.clear();
167                                            let error_response = OutgoingMessage::<Local, Remote>::Response {
168                                                id,
169                                                result: ResponseResult::Error(err),
170                                            };
171
172                                            serde_json::to_writer(&mut outgoing_line, &error_response)?;
173                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
174                                            outgoing_line.push(b'\n');
175                                            outgoing_bytes.write_all(&outgoing_line).await.ok();
176                                        }
177                                    }
178                                } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
179                                    // Response
180                                    if let Some(result) = message.result {
181                                        let result = (pending_response.deserialize)(result);
182                                        pending_response.respond.send(result).ok();
183                                    } else if let Some(error) = message.error {
184                                        pending_response.respond.send(Err(error)).ok();
185                                    } else {
186                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
187                                        pending_response.respond.send(result).ok();
188                                    }
189                                } else {
190                                    log::error!("received response for unknown request id: {id}");
191                                }
192                            } else if let Some(method) = message.method {
193                                // Notification
194                                match Local::decode_notification(method, message.params) {
195                                    Ok(notification) => {
196                                        incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
197                                    }
198                                    Err(err) => {
199                                        log::error!("failed to decode notification: {err}");
200                                    }
201                                }
202                            } else {
203                                log::error!("received message with neither id nor method");
204                            }
205                        }
206                        Err(error) => {
207                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
208                        }
209                    }
210                    incoming_line.clear();
211                }
212            }
213        }
214        Ok(())
215    }
216
217    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
218        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
219        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
220        handler: Handler,
221        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
222    ) {
223        let spawn = Rc::new(spawn);
224        let handler = Rc::new(handler);
225        spawn({
226            let spawn = spawn.clone();
227            async move {
228                while let Some(message) = incoming_rx.next().await {
229                    match message {
230                        IncomingMessage::Request { id, request } => {
231                            let outgoing_tx = outgoing_tx.clone();
232                            let handler = handler.clone();
233                            spawn(
234                                async move {
235                                    let result = handler.handle_request(request).await.into();
236                                    outgoing_tx
237                                        .unbounded_send(OutgoingMessage::Response { id, result })
238                                        .ok();
239                                }
240                                .boxed_local(),
241                            )
242                        }
243                        IncomingMessage::Notification { notification } => {
244                            let handler = handler.clone();
245                            spawn(
246                                async move {
247                                    if let Err(err) =
248                                        handler.handle_notification(notification).await
249                                    {
250                                        log::error!("failed to handle notification: {err:?}");
251                                    }
252                                }
253                                .boxed_local(),
254                            )
255                        }
256                    }
257                }
258            }
259            .boxed_local()
260        })
261    }
262}
263
/// Loosely-typed view of one incoming JSON-RPC line, used only to classify it.
///
/// The I/O loop routes on which fields are present:
/// - `id` and `method`: a request from the peer;
/// - `id` without `method`: a response to one of our requests (`result` / `error`);
/// - `method` without `id`: a notification.
///
/// `method`, `params`, and `result` borrow from the input line instead of copying.
/// NOTE(review): because `method` is a borrowed `&str`, a method name containing
/// JSON escape sequences presumably fails to deserialize; ids are limited to
/// `i32`, so string ids are rejected as well — confirm this matches the protocol.
#[derive(Deserialize)]
struct RawIncomingMessage<'a> {
    id: Option<i32>,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
272
/// A decoded request or notification from the peer, passed from the I/O task to
/// the dispatch loop that invokes the `MessageHandler`.
enum IncomingMessage<Local: Side> {
    /// A request; the handler's outcome is sent back as the response for `id`.
    Request { id: i32, request: Local::InRequest },
    /// A notification; no response is sent for it.
    Notification { notification: Local::InNotification },
}
277
/// A message queued for the peer; the I/O task writes each one as a single line
/// of JSON.
///
/// `#[serde(untagged)]` means no variant name is written — only the variant's
/// fields. In `Response`, the flattened `ResponseResult` contributes either a
/// `"result"` or an `"error"` key next to `"id"`.
///
/// NOTE(review): no variant carries a `"jsonrpc": "2.0"` member; if peers expect
/// strict JSON-RPC 2.0 framing this would need adding — confirm.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request to the peer; its response is matched back by `id`.
    Request {
        id: i32,
        method: &'static str,
        // Omitted entirely (rather than written as `null`) when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InRequest>,
    },
    /// Our answer to the peer's request `id`.
    Response {
        id: i32,
        #[serde(flatten)]
        result: ResponseResult<Local::OutResponse>,
    },
    /// A one-way notification to the peer.
    Notification {
        method: &'static str,
        // Omitted entirely (rather than written as `null`) when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InNotification>,
    },
}
298
/// Outcome of handling a request, in the shape it takes on the wire.
///
/// With `rename_all = "lowercase"` and serde's default externally tagged
/// representation, this serializes as `{"result": ...}` or `{"error": ...}`;
/// `OutgoingMessage::Response` flattens it alongside the `id`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ResponseResult<Res> {
    /// Successful outcome, written under the `"result"` key.
    Result(Res),
    /// Failed outcome, written under the `"error"` key.
    Error(Error),
}
305
306impl<T> From<Result<T, Error>> for ResponseResult<T> {
307    fn from(result: Result<T, Error>) -> Self {
308        match result {
309            Ok(value) => ResponseResult::Result(value),
310            Err(error) => ResponseResult::Error(error),
311        }
312    }
313}
314
/// Describes the messages one side of a connection receives and sends.
///
/// An `RpcConnection<Local, Remote>` decodes incoming traffic with `Local`'s
/// types and encodes outgoing requests and notifications with `Remote`'s.
pub trait Side {
    /// Requests this side receives.
    type InRequest: Serialize + DeserializeOwned + 'static;
    /// Responses this side sends back for the requests it receives.
    type OutResponse: Serialize + DeserializeOwned + 'static;
    /// Notifications this side receives.
    type InNotification: Serialize + DeserializeOwned + 'static;

    /// Decodes a request from its `method` name and raw `params`.
    ///
    /// The connection sends an `Err` back to the peer as the error response
    /// for that request.
    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest, Error>;

    /// Decodes a notification from its `method` name and raw `params`.
    ///
    /// The connection only logs an `Err`; nothing is sent to the peer.
    fn decode_notification(
        method: &str,
        params: Option<&RawValue>,
    ) -> Result<Self::InNotification, Error>;
}
327
/// Application logic for the messages a `Local` side receives.
///
/// The connection runs each call in its own task, created with the `spawn`
/// function passed to `RpcConnection::new`.
pub trait MessageHandler<Local: Side> {
    /// Handles one request. `Ok` is sent back as the response's `result`,
    /// `Err` as its `error`.
    fn handle_request(
        &self,
        request: Local::InRequest,
    ) -> impl Future<Output = Result<Local::OutResponse, Error>>;

    /// Handles one notification. An `Err` is logged by the connection and
    /// otherwise ignored.
    fn handle_notification(
        &self,
        notification: Local::InNotification,
    ) -> impl Future<Output = Result<(), Error>>;
}
339
340// pub trait Dispatcher {
341//     type Notification: DeserializeOwned;
342
343//     fn request(&self, id: i32, method: &str, params: Option<&RawValue>) -> Result<(), Error>;
344//     fn notification(&self, method: &str, params: Option<&RawValue>) -> Result<(), Error>;
345// }
346
/// Decodes `$params` as `$request_type`, calls `$method(&$base.delegate, ..)`,
/// and hands `$base.spawn` a task that sends the outcome (success mapped through
/// `$response_wrapper`) back as the response to request `$id`.
///
/// Evaluates to `Ok(())` once the task is spawned, or to an `invalid_params`
/// error carrying the serde message when decoding fails. Missing params make the
/// *enclosing function* return `Err(invalid_params())` early.
///
/// NOTE(review): the expansion names `serde_json` and `::futures` directly, so
/// the crate invoking the macro presumably needs both as dependencies.
#[macro_export]
macro_rules! dispatch_request {
    ($base:expr, $id:expr, $params:expr, $request_type:ty, $method:expr, $response_wrapper:expr) => {{
        let Some(raw_params) = $params else {
            return Err($crate::Error::invalid_params());
        };

        match serde_json::from_str::<$request_type>(raw_params.get()) {
            Err(err) => Err($crate::Error::invalid_params().with_data(err.to_string())),
            Ok(request) => {
                let response = $method(&$base.delegate, request);
                let sender = $base.outgoing_tx.clone();
                ($base.spawn)(::futures::FutureExt::boxed_local(async move {
                    sender
                        .unbounded_send($crate::rpc::OutgoingMessage::Response {
                            id: $id,
                            result: response.await.map($response_wrapper).into(),
                        })
                        .ok();
                }));

                Ok(())
            }
        }
    }};
}
373
/// Decodes `$params` as `$params_type` and passes the value to `$handler`,
/// evaluating to whatever `$handler` returns.
///
/// A decode failure evaluates to an `invalid_params` error carrying the serde
/// message; missing params make the *enclosing function* return
/// `Err(invalid_params())` early. `$method` is accepted but not used by the
/// expansion.
#[macro_export]
macro_rules! dispatch_notification {
    ($method:expr, $params:expr, $params_type:ty, $handler:expr) => {{
        let Some(raw_params) = $params else {
            return Err($crate::Error::invalid_params());
        };

        let decoded = serde_json::from_str::<$params_type>(raw_params.get());
        match decoded {
            Err(err) => Err($crate::Error::invalid_params().with_data(err.to_string())),
            Ok(notification) => $handler(notification),
        }
    }};
}