1use std::{
2 any::Any,
3 borrow::Cow,
4 collections::HashMap,
5 rc::Rc,
6 sync::{
7 Arc, Mutex,
8 atomic::{AtomicI64, Ordering},
9 },
10};
11
12use agent_client_protocol_schema::{
13 Error, JsonRpcMessage, Notification, OutgoingMessage, Request, RequestId, Response, Result,
14 Side,
15};
16use futures::{
17 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
18 StreamExt as _,
19 channel::{
20 mpsc::{self, UnboundedReceiver, UnboundedSender},
21 oneshot,
22 },
23 future::LocalBoxFuture,
24 io::BufReader,
25 select_biased,
26};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::value::RawValue;
29
30use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
31
/// One end of a JSON-RPC connection over a newline-delimited byte stream.
///
/// `Local` decodes the requests and notifications received on this connection, while `Remote`
/// defines the request/notification parameter types sent to the peer. The actual reading and
/// writing is done by the future returned from [`RpcConnection::new`]; this handle only enqueues
/// outgoing messages and tracks requests that are still awaiting a response.
#[derive(Debug)]
pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
    /// Queue of messages to be serialized and written by the I/O future.
    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
    /// Requests sent from this side that have not been answered yet, keyed by request id.
    /// Shared with the I/O future, which completes and removes entries as responses arrive.
    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
    /// Counter used to allocate ids for outgoing requests (incremented once per request).
    next_id: AtomicI64,
    /// Source of receivers that observe the messages flowing over this connection.
    broadcast: StreamBroadcast,
}
39
/// Bookkeeping for one outgoing request that is waiting for its response.
#[derive(Debug)]
struct PendingResponse {
    /// Deserializes the raw `result` JSON into the output type the caller asked for, boxed as
    /// `Any` so that requests with different output types can live in the same map; the
    /// requester downcasts it back.
    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
    /// Completes the future returned by `RpcConnection::request`. Dropping this sender without
    /// sending makes that future resolve with an error.
    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
}
45
46impl<Local, Remote> RpcConnection<Local, Remote>
47where
48 Local: Side + 'static,
49 Remote: Side + 'static,
50{
51 pub(crate) fn new<Handler>(
52 handler: Handler,
53 outgoing_bytes: impl Unpin + AsyncWrite,
54 incoming_bytes: impl Unpin + AsyncRead,
55 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
56 ) -> (Self, impl futures::Future<Output = Result<()>>)
57 where
58 Handler: MessageHandler<Local> + 'static,
59 {
60 let (incoming_tx, incoming_rx) = mpsc::unbounded();
61 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
62
63 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
64 let (broadcast_tx, broadcast) = StreamBroadcast::new();
65
66 let io_task = {
67 let pending_responses = pending_responses.clone();
68 async move {
69 let result = Self::handle_io(
70 incoming_tx,
71 outgoing_rx,
72 outgoing_bytes,
73 incoming_bytes,
74 pending_responses.clone(),
75 broadcast_tx,
76 )
77 .await;
78 pending_responses.lock().unwrap().clear();
79 result
80 }
81 };
82
83 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
84
85 let this = Self {
86 outgoing_tx,
87 pending_responses,
88 next_id: AtomicI64::new(0),
89 broadcast,
90 };
91
92 (this, io_task)
93 }
94
95 pub(crate) fn subscribe(&self) -> StreamReceiver {
96 self.broadcast.receiver()
97 }
98
99 pub(crate) fn notify(
100 &self,
101 method: impl Into<Arc<str>>,
102 params: Option<Remote::InNotification>,
103 ) -> Result<()> {
104 self.outgoing_tx
105 .unbounded_send(OutgoingMessage::Notification(Notification {
106 method: method.into(),
107 params,
108 }))
109 .map_err(|_| Error::internal_error().data("failed to send notification"))
110 }
111
112 pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
113 &self,
114 method: impl Into<Arc<str>>,
115 params: Option<Remote::InRequest>,
116 ) -> impl Future<Output = Result<Out>> {
117 let (tx, rx) = oneshot::channel();
118 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
119 let id = RequestId::Number(id);
120 self.pending_responses.lock().unwrap().insert(
121 id.clone(),
122 PendingResponse {
123 deserialize: |value| {
124 serde_json::from_str::<Out>(value.get())
125 .map(|out| Box::new(out) as _)
126 .map_err(|_| Error::internal_error().data("failed to deserialize response"))
127 },
128 respond: tx,
129 },
130 );
131
132 if self
133 .outgoing_tx
134 .unbounded_send(OutgoingMessage::Request(Request {
135 id: id.clone(),
136 method: method.into(),
137 params,
138 }))
139 .is_err()
140 {
141 self.pending_responses.lock().unwrap().remove(&id);
142 }
143 async move {
144 let result = rx
145 .await
146 .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
147 .downcast::<Out>()
148 .map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
149
150 Ok(*result)
151 }
152 }
153
154 async fn handle_io(
155 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
156 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
157 mut outgoing_bytes: impl Unpin + AsyncWrite,
158 incoming_bytes: impl Unpin + AsyncRead,
159 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
160 broadcast: StreamSender,
161 ) -> Result<()> {
162 let mut input_reader = BufReader::new(incoming_bytes);
164 let mut outgoing_line = Vec::new();
165 let mut incoming_line = String::new();
166 loop {
167 select_biased! {
168 message = outgoing_rx.next() => {
169 if let Some(message) = message {
170 outgoing_line.clear();
171 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
172 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
173 outgoing_line.push(b'\n');
174 outgoing_bytes.write_all(&outgoing_line).await.ok();
175 broadcast.outgoing(&message);
176 } else {
177 break;
178 }
179 }
180 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
181 if bytes_read.map_err(Error::into_internal_error)? == 0 {
182 break
183 }
184 log::trace!("recv: {}", &incoming_line);
185
186 match serde_json::from_str::<RawIncomingMessage<'_>>(&incoming_line) {
187 Ok(message) => {
188 if let Some(id) = message.id {
189 if let Some(method) = message.method {
190 match Local::decode_request(&method, message.params) {
192 Ok(request) => {
193 broadcast.incoming_request(id.clone(), &*method, &request);
194 incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
195 }
196 Err(error) => {
197 outgoing_line.clear();
198 let error_response = OutgoingMessage::<Local, Remote>::Response(Response::Error {
199 id,
200 error,
201 });
202
203 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
204 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
205 outgoing_line.push(b'\n');
206 outgoing_bytes.write_all(&outgoing_line).await.ok();
207 broadcast.outgoing(&error_response);
208 }
209 }
210 } else if let Some(pending_response) = pending_responses.lock().unwrap().remove(&id) {
211 if let Some(result_value) = message.result {
213 broadcast.incoming_response(id, Ok(Some(result_value)));
214
215 let result = (pending_response.deserialize)(result_value);
216 pending_response.respond.send(result).ok();
217 } else if let Some(error) = message.error {
218 broadcast.incoming_response(id, Err(&error));
219
220 pending_response.respond.send(Err(error)).ok();
221 } else {
222 broadcast.incoming_response(id, Ok(None));
223
224 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
225 pending_response.respond.send(result).ok();
226 }
227 } else {
228 log::error!("received response for unknown request id: {id:?}");
229 }
230 } else if let Some(method) = message.method {
231 match Local::decode_notification(&method, message.params) {
233 Ok(notification) => {
234 broadcast.incoming_notification(&*method, ¬ification);
235 incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
236 }
237 Err(err) => {
238 log::error!("failed to decode {:?}: {err}", message.params);
239 }
240 }
241 } else {
242 log::error!("received message with neither id nor method");
243 }
244 }
245 Err(error) => {
246 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
247 }
248 }
249 incoming_line.clear();
250 }
251 }
252 }
253 Ok(())
254 }
255
256 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
257 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
258 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
259 handler: Handler,
260 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
261 ) {
262 let spawn = Rc::new(spawn);
263 let handler = Rc::new(handler);
264 spawn({
265 let spawn = spawn.clone();
266 async move {
267 while let Some(message) = incoming_rx.next().await {
268 match message {
269 IncomingMessage::Request { id, request } => {
270 let outgoing_tx = outgoing_tx.clone();
271 let handler = handler.clone();
272 spawn(
273 async move {
274 let result = handler.handle_request(request).await;
275 outgoing_tx
276 .unbounded_send(OutgoingMessage::Response(Response::new(
277 id, result,
278 )))
279 .ok();
280 }
281 .boxed_local(),
282 );
283 }
284 IncomingMessage::Notification { notification } => {
285 let handler = handler.clone();
286 spawn(
287 async move {
288 if let Err(err) =
289 handler.handle_notification(notification).await
290 {
291 log::error!("failed to handle notification: {err:?}");
292 }
293 }
294 .boxed_local(),
295 );
296 }
297 }
298 }
299 }
300 .boxed_local()
301 });
302 }
303}
304
/// Loosely-typed view of one incoming JSON-RPC message, used to classify it before decoding.
///
/// The I/O loop routes on which fields are present: `id` together with `method` is a request,
/// `id` alone is a response (carrying `result` or `error`), and `method` alone is a
/// notification. `params` and `result` borrow the raw JSON text from the input line so they can
/// be decoded later into the concrete types.
#[derive(Debug, Deserialize)]
pub struct RawIncomingMessage<'a> {
    id: Option<RequestId>,
    // `Cow` rather than `&str`: a method name containing JSON escapes (e.g. `session\/update`)
    // cannot be borrowed from the input and has to be unescaped into an owned string.
    #[serde(borrow)]
    method: Option<Cow<'a, str>>,
    #[serde(borrow)]
    params: Option<&'a RawValue>,
    #[serde(borrow)]
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
316
/// A decoded incoming message, passed from the I/O loop to the handler loop.
#[derive(Debug)]
pub enum IncomingMessage<Local: Side> {
    /// A request; the handler's result is sent back as a response carrying the same `id`.
    Request {
        id: RequestId,
        request: Local::InRequest,
    },
    /// A notification; no response is sent for it.
    Notification {
        notification: Local::InNotification,
    },
}
327
/// Application logic invoked for each incoming request and notification.
///
/// The connection runs every call on its own future handed to its `spawn` function. Calls may
/// therefore run concurrently with one another, depending on the executor behind `spawn`.
pub trait MessageHandler<Local: Side> {
    /// Handles a request. The returned value or error is sent to the peer as the response.
    fn handle_request(
        &self,
        request: Local::InRequest,
    ) -> impl Future<Output = Result<Local::OutResponse>>;

    /// Handles a notification. Notifications get no response, so an error returned here is
    /// only logged by the connection.
    fn handle_notification(
        &self,
        notification: Local::InNotification,
    ) -> impl Future<Output = Result<()>>;
}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json_str` as a [`RawIncomingMessage`], panicking on malformed input.
    fn parse(json_str: &str) -> RawIncomingMessage<'_> {
        serde_json::from_str(json_str).unwrap()
    }

    #[test]
    fn test_raw_incoming_message_with_escaped_slash() {
        let json_str = r#"{"jsonrpc":"2.0","id":1,"method":"session\/update","params":{}}"#;
        let parsed = parse(json_str);
        assert_eq!(parsed.method.unwrap(), "session/update");
        assert_eq!(parsed.params.unwrap().to_string(), "{}");
    }

    #[test]
    fn test_raw_incoming_message_without_escape() {
        let json_str = r#"{"jsonrpc":"2.0","id":2,"method":"session/update","params":{}}"#;
        let parsed = parse(json_str);
        assert_eq!(parsed.method.unwrap(), "session/update");
        assert_eq!(parsed.params.unwrap().to_string(), "{}");
    }

    #[test]
    fn test_raw_incoming_message_response_with_result() {
        let json_str = r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#;
        let parsed = parse(json_str);
        assert!(parsed.id.is_some());
        assert!(parsed.method.is_none());
        assert!(parsed.error.is_none());
        // The raw result text is preserved verbatim for later typed deserialization.
        assert_eq!(parsed.result.unwrap().get(), r#"{"ok":true}"#);
    }

    #[test]
    fn test_raw_incoming_message_response_with_null_result() {
        // `"result": null` parses the same as a missing `result`; the connection answers both by
        // deserializing the expected output type from JSON `null`.
        let json_str = r#"{"jsonrpc":"2.0","id":4,"result":null}"#;
        let parsed = parse(json_str);
        assert!(parsed.id.is_some());
        assert!(parsed.method.is_none());
        assert!(parsed.result.is_none());
        assert!(parsed.error.is_none());
    }

    #[test]
    fn test_raw_incoming_message_notification() {
        let json_str = r#"{"jsonrpc":"2.0","method":"session/cancel","params":null}"#;
        let parsed = parse(json_str);
        assert!(parsed.id.is_none());
        assert_eq!(parsed.method.unwrap(), "session/cancel");
        assert!(parsed.params.is_none());
    }
}