agent_client_protocol/
rpc.rs

1use std::{
2    any::Any,
3    collections::HashMap,
4    rc::Rc,
5    sync::{
6        Arc,
7        atomic::{AtomicI32, Ordering},
8    },
9};
10
11use anyhow::Result;
12use futures::{
13    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
14    StreamExt as _,
15    channel::{
16        mpsc::{self, UnboundedReceiver, UnboundedSender},
17        oneshot,
18    },
19    future::LocalBoxFuture,
20    io::BufReader,
21    select_biased,
22};
23use parking_lot::Mutex;
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use serde_json::value::RawValue;
26
27use crate::Error;
28
29pub struct RpcConnection<Local: Side, Remote: Side> {
30    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
31    pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
32    next_id: AtomicI32,
33}
34
35struct PendingResponse {
36    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>, Error>,
37    respond: oneshot::Sender<Result<Box<dyn Any + Send>, Error>>,
38}
39
40impl<Local, Remote> RpcConnection<Local, Remote>
41where
42    Local: Side + 'static,
43    Remote: Side + 'static,
44{
45    pub fn new<Handler>(
46        handler: Handler,
47        outgoing_bytes: impl Unpin + AsyncWrite,
48        incoming_bytes: impl Unpin + AsyncRead,
49        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
50    ) -> (Self, impl futures::Future<Output = Result<()>>)
51    where
52        Handler: MessageHandler<Local> + 'static,
53    {
54        let (incoming_tx, incoming_rx) = mpsc::unbounded();
55        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
56
57        let pending_responses = Arc::new(Mutex::new(HashMap::default()));
58
59        let io_task = {
60            let pending_responses = pending_responses.clone();
61            async move {
62                let result = Self::handle_io(
63                    incoming_tx,
64                    outgoing_rx,
65                    outgoing_bytes,
66                    incoming_bytes,
67                    pending_responses.clone(),
68                )
69                .await;
70                pending_responses.lock().clear();
71                result
72            }
73        };
74
75        Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
76
77        let this = Self {
78            outgoing_tx,
79            pending_responses,
80            next_id: AtomicI32::new(0),
81        };
82
83        (this, io_task)
84    }
85
86    pub fn notify(
87        &self,
88        method: &'static str,
89        params: Option<Remote::InNotification>,
90    ) -> Result<(), Error> {
91        self.outgoing_tx
92            .unbounded_send(OutgoingMessage::Notification { method, params })
93            .map_err(|_| Error::internal_error().with_data("failed to send notification"))
94    }
95
96    pub fn request<Out: DeserializeOwned + Send + 'static>(
97        &self,
98        method: &'static str,
99        params: Option<Remote::InRequest>,
100    ) -> impl Future<Output = Result<Out, Error>> {
101        let (tx, rx) = oneshot::channel();
102        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
103        self.pending_responses.lock().insert(
104            id,
105            PendingResponse {
106                deserialize: |value| {
107                    serde_json::from_str::<Out>(value.get())
108                        .map(|out| Box::new(out) as _)
109                        .map_err(|_| {
110                            Error::internal_error().with_data("failed to deserialize response")
111                        })
112                },
113                respond: tx,
114            },
115        );
116
117        if self
118            .outgoing_tx
119            .unbounded_send(OutgoingMessage::Request { id, method, params })
120            .is_err()
121        {
122            self.pending_responses.lock().remove(&id);
123        }
124        async move {
125            let result = rx
126                .await
127                .map_err(|_| Error::internal_error().with_data("server shut down unexpectedly"))??
128                .downcast::<Out>()
129                .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
130
131            Ok(*result)
132        }
133    }
134
135    async fn handle_io(
136        incoming_tx: UnboundedSender<IncomingMessage<Local>>,
137        mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
138        mut outgoing_bytes: impl Unpin + AsyncWrite,
139        incoming_bytes: impl Unpin + AsyncRead,
140        pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
141    ) -> Result<()> {
142        let mut input_reader = BufReader::new(incoming_bytes);
143        let mut outgoing_line = Vec::new();
144        let mut incoming_line = String::new();
145        loop {
146            select_biased! {
147                message = outgoing_rx.next() => {
148                    if let Some(message) = message {
149                        outgoing_line.clear();
150                        serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
151                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
152                        outgoing_line.push(b'\n');
153                        outgoing_bytes.write_all(&outgoing_line).await.ok();
154                    } else {
155                        break;
156                    }
157                }
158                bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
159                    if bytes_read.map_err(Error::into_internal_error)? == 0 {
160                        break
161                    }
162                    log::trace!("recv: {}", &incoming_line);
163
164                    match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
165                        Ok(message) => {
166                            if let Some(id) = message.id {
167                                if let Some(method) = message.method {
168                                    // Request
169                                    match Local::decode_request(method, message.params) {
170                                        Ok(request) => {
171                                            incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
172                                        }
173                                        Err(err) => {
174                                            outgoing_line.clear();
175                                            let error_response = OutgoingMessage::<Local, Remote>::Response {
176                                                id,
177                                                result: ResponseResult::Error(err),
178                                            };
179
180                                            serde_json::to_writer(&mut outgoing_line, &error_response)?;
181                                            log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
182                                            outgoing_line.push(b'\n');
183                                            outgoing_bytes.write_all(&outgoing_line).await.ok();
184                                        }
185                                    }
186                                } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
187                                    // Response
188                                    if let Some(result) = message.result {
189                                        let result = (pending_response.deserialize)(result);
190                                        pending_response.respond.send(result).ok();
191                                    } else if let Some(error) = message.error {
192                                        pending_response.respond.send(Err(error)).ok();
193                                    } else {
194                                        let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
195                                        pending_response.respond.send(result).ok();
196                                    }
197                                } else {
198                                    log::error!("received response for unknown request id: {id}");
199                                }
200                            } else if let Some(method) = message.method {
201                                // Notification
202                                match Local::decode_notification(method, message.params) {
203                                    Ok(notification) => {
204                                        incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
205                                    }
206                                    Err(err) => {
207                                        log::error!("failed to decode notification: {err}");
208                                    }
209                                }
210                            } else {
211                                log::error!("received message with neither id nor method");
212                            }
213                        }
214                        Err(error) => {
215                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
216                        }
217                    }
218                    incoming_line.clear();
219                }
220            }
221        }
222        Ok(())
223    }
224
225    fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
226        outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
227        mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
228        handler: Handler,
229        spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
230    ) {
231        let spawn = Rc::new(spawn);
232        let handler = Rc::new(handler);
233        spawn({
234            let spawn = spawn.clone();
235            async move {
236                while let Some(message) = incoming_rx.next().await {
237                    match message {
238                        IncomingMessage::Request { id, request } => {
239                            let outgoing_tx = outgoing_tx.clone();
240                            let handler = handler.clone();
241                            spawn(
242                                async move {
243                                    let result = handler.handle_request(request).await.into();
244                                    outgoing_tx
245                                        .unbounded_send(OutgoingMessage::Response { id, result })
246                                        .ok();
247                                }
248                                .boxed_local(),
249                            )
250                        }
251                        IncomingMessage::Notification { notification } => {
252                            let handler = handler.clone();
253                            spawn(
254                                async move {
255                                    if let Err(err) =
256                                        handler.handle_notification(notification).await
257                                    {
258                                        log::error!("failed to handle notification: {err:?}");
259                                    }
260                                }
261                                .boxed_local(),
262                            )
263                        }
264                    }
265                }
266            }
267            .boxed_local()
268        })
269    }
270}
271
/// Wire shape of any message received from the peer, before it is classified.
///
/// Every member is optional. Which of `id` and `method` are present decides whether the
/// message is handled as a request, a response, or a notification. The borrowed fields
/// point into the text being parsed, so a value of this type cannot outlive that text.
///
/// NOTE(review): a borrowed `&str` cannot represent a JSON string containing escape
/// sequences, so such a `method` value would make the whole message fail to parse.
#[derive(Deserialize)]
struct RawIncomingMessage<'a> {
    /// Request/response id; absent for notifications.
    id: Option<i32>,
    /// Method name; present for requests and notifications, absent for responses.
    method: Option<&'a str>,
    /// Unparsed request/notification parameters.
    params: Option<&'a RawValue>,
    /// Unparsed payload of a successful response.
    result: Option<&'a RawValue>,
    /// Payload of an error response.
    error: Option<Error>,
}
280
281enum IncomingMessage<Local: Side> {
282    Request { id: i32, request: Local::InRequest },
283    Notification { notification: Local::InNotification },
284}
285
/// A message destined for the peer; the transport writes each one as a single JSON line.
///
/// `#[serde(untagged)]` means no variant name is written, so each variant serializes as
/// just its fields:
/// - `Request`: `id`, `method`, and `params` (omitted when `None`);
/// - `Response`: `id` plus a flattened `result` or `error` member;
/// - `Notification`: `method` and `params` (omitted when `None`).
///
/// NOTE(review): no `"jsonrpc": "2.0"` member is emitted; confirm the peer does not
/// require it.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request to the peer, carrying the remote side's request type.
    Request {
        id: i32,
        method: &'static str,
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InRequest>,
    },
    /// This side's answer to a request previously received with the same `id`.
    Response {
        id: i32,
        #[serde(flatten)]
        result: ResponseResult<Local::OutResponse>,
    },
    /// A notification to the peer, carrying the remote side's notification type.
    Notification {
        method: &'static str,
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InNotification>,
    },
}
306
/// Outcome of handling a request.
///
/// `rename_all = "snake_case"` names the variants `result` and `error`. When flattened
/// into `OutgoingMessage::Response`, the outcome appears as a top-level `result` or
/// `error` member next to `id`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseResult<Res> {
    /// Successful response payload.
    Result(Res),
    /// Error response payload.
    Error(Error),
}
313
314impl<T> From<Result<T, Error>> for ResponseResult<T> {
315    fn from(result: Result<T, Error>) -> Self {
316        match result {
317            Ok(value) => ResponseResult::Result(value),
318            Err(error) => ResponseResult::Error(error),
319        }
320    }
321}
322
323pub trait Side {
324    type InRequest: Serialize + DeserializeOwned + 'static;
325    type OutResponse: Serialize + DeserializeOwned + 'static;
326    type InNotification: Serialize + DeserializeOwned + 'static;
327
328    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest, Error>;
329
330    fn decode_notification(
331        method: &str,
332        params: Option<&RawValue>,
333    ) -> Result<Self::InNotification, Error>;
334}
335
336pub trait MessageHandler<Local: Side> {
337    fn handle_request(
338        &self,
339        request: Local::InRequest,
340    ) -> impl Future<Output = Result<Local::OutResponse, Error>>;
341
342    fn handle_notification(
343        &self,
344        notification: Local::InNotification,
345    ) -> impl Future<Output = Result<(), Error>>;
346}
347
348// pub trait Dispatcher {
349//     type Notification: DeserializeOwned;
350
351//     fn request(&self, id: i32, method: &str, params: Option<&RawValue>) -> Result<(), Error>;
352//     fn notification(&self, method: &str, params: Option<&RawValue>) -> Result<(), Error>;
353// }
354
/// Decodes `$params` as `$request_type` and starts `$method(&$base.delegate, arguments)`.
/// It then spawns, through `$base.spawn`, a task that awaits the result, maps it with
/// `$response_wrapper`, and queues it on `$base.outgoing_tx` as the response to `$id`.
///
/// The macro evaluates to `Ok(())` once the task is spawned, or to an `invalid_params`
/// error (carrying the decode error text) when the params fail to decode. When `$params`
/// is `None`, the *enclosing function* returns `Err(invalid_params)` early.
#[macro_export]
macro_rules! dispatch_request {
    ($base:expr, $id:expr, $params:expr, $request_type:ty, $method:expr, $response_wrapper:expr) => {{
        let raw_params = match $params {
            Some(raw_params) => raw_params,
            // Nothing to decode: bail out of the calling function altogether.
            None => return Err($crate::Error::invalid_params()),
        };

        let decoded = serde_json::from_str::<$request_type>(raw_params.get());
        match decoded {
            Err(decode_error) => {
                Err($crate::Error::invalid_params().with_data(decode_error.to_string()))
            }
            Ok(arguments) => {
                let response_future = $method(&$base.delegate, arguments);
                let response_tx = $base.outgoing_tx.clone();
                ($base.spawn)(::futures::FutureExt::boxed_local(async move {
                    let request_id = $id;
                    let outcome = response_future.await.map($response_wrapper).into();
                    response_tx
                        .unbounded_send($crate::rpc::OutgoingMessage::Response {
                            id: request_id,
                            result: outcome,
                        })
                        .ok();
                }));
                Ok(())
            }
        }
    }};
}
381
/// Decodes `$params` as `$params_type` and evaluates to `$handler(arguments)`.
///
/// A decode failure evaluates to an `invalid_params` error carrying the decode error text.
/// When `$params` is `None`, the *enclosing function* returns `Err(invalid_params)` early.
/// `$method` is accepted for call-site symmetry but is not used.
#[macro_export]
macro_rules! dispatch_notification {
    ($method:expr, $params:expr, $params_type:ty, $handler:expr) => {{
        let raw_params = match $params {
            Some(raw_params) => raw_params,
            // Nothing to decode: bail out of the calling function altogether.
            None => return Err($crate::Error::invalid_params()),
        };

        let decoded = serde_json::from_str::<$params_type>(raw_params.get());
        match decoded {
            Err(decode_error) => {
                Err($crate::Error::invalid_params().with_data(decode_error.to_string()))
            }
            Ok(arguments) => $handler(arguments),
        }
    }};
}