1use std::{
2 any::Any,
3 borrow::Cow,
4 collections::HashMap,
5 rc::Rc,
6 sync::{
7 Arc, Mutex,
8 atomic::{AtomicI64, Ordering},
9 },
10};
11
12use agent_client_protocol_schema::{
13 Error, JsonRpcMessage, Notification, OutgoingMessage, Request, RequestId, Response, Result,
14 Side,
15};
16use futures::{
17 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
18 StreamExt as _,
19 channel::{
20 mpsc::{self, UnboundedReceiver, UnboundedSender},
21 oneshot,
22 },
23 future::LocalBoxFuture,
24 io::BufReader,
25 select_biased,
26};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::value::RawValue;
29
30use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
31
32#[derive(Debug)]
33pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
34 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
35 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
36 next_id: AtomicI64,
37 broadcast: StreamBroadcast,
38}
39
40#[derive(Debug)]
41struct PendingResponse {
42 deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
43 respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
44}
45
46impl<Local, Remote> RpcConnection<Local, Remote>
47where
48 Local: Side + 'static,
49 Remote: Side + 'static,
50{
51 pub(crate) fn new<Handler>(
52 handler: Handler,
53 outgoing_bytes: impl Unpin + AsyncWrite,
54 incoming_bytes: impl Unpin + AsyncRead,
55 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
56 ) -> (Self, impl futures::Future<Output = Result<()>>)
57 where
58 Handler: MessageHandler<Local> + 'static,
59 {
60 let (incoming_tx, incoming_rx) = mpsc::unbounded();
61 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
62
63 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
64 let (broadcast_tx, broadcast) = StreamBroadcast::new();
65
66 let io_task = {
67 let pending_responses = pending_responses.clone();
68 async move {
69 let result = Self::handle_io(
70 incoming_tx,
71 outgoing_rx,
72 outgoing_bytes,
73 incoming_bytes,
74 pending_responses.clone(),
75 broadcast_tx,
76 )
77 .await;
78 pending_responses.lock().unwrap().clear();
79 result
80 }
81 };
82
83 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
84
85 let this = Self {
86 outgoing_tx,
87 pending_responses,
88 next_id: AtomicI64::new(0),
89 broadcast,
90 };
91
92 (this, io_task)
93 }
94
95 pub(crate) fn subscribe(&self) -> StreamReceiver {
96 self.broadcast.receiver()
97 }
98
99 pub(crate) fn notify(
100 &self,
101 method: impl Into<Arc<str>>,
102 params: Option<Remote::InNotification>,
103 ) -> Result<()> {
104 self.outgoing_tx
105 .unbounded_send(OutgoingMessage::Notification(Notification {
106 method: method.into(),
107 params,
108 }))
109 .map_err(|_| Error::internal_error().data("failed to send notification"))
110 }
111
112 pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
113 &self,
114 method: impl Into<Arc<str>>,
115 params: Option<Remote::InRequest>,
116 ) -> Result<impl Future<Output = Result<Out>>> {
117 let (tx, rx) = oneshot::channel();
118 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
119 let id = RequestId::Number(id);
120 self.pending_responses.lock().unwrap().insert(
121 id.clone(),
122 PendingResponse {
123 deserialize: |value| {
124 serde_json::from_str::<Out>(value.get())
125 .map(|out| Box::new(out) as _)
126 .map_err(|_| Error::internal_error().data("failed to deserialize response"))
127 },
128 respond: tx,
129 },
130 );
131
132 if self
133 .outgoing_tx
134 .unbounded_send(OutgoingMessage::Request(Request {
135 id: id.clone(),
136 method: method.into(),
137 params,
138 }))
139 .is_err()
140 {
141 self.pending_responses.lock().unwrap().remove(&id);
142 return Err(
143 Error::internal_error().data("connection closed before request could be sent")
144 );
145 }
146 Ok(async move {
147 let result = rx
148 .await
149 .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
150 .downcast::<Out>()
151 .map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
152
153 Ok(*result)
154 })
155 }
156
157 async fn handle_io(
158 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
159 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
160 mut outgoing_bytes: impl Unpin + AsyncWrite,
161 incoming_bytes: impl Unpin + AsyncRead,
162 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
163 broadcast: StreamSender,
164 ) -> Result<()> {
165 let mut input_reader = BufReader::new(incoming_bytes);
167 let mut outgoing_line = Vec::new();
168 let mut incoming_line = String::new();
169 loop {
170 select_biased! {
171 message = outgoing_rx.next() => {
172 if let Some(message) = message {
173 outgoing_line.clear();
174 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
175 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
176 outgoing_line.push(b'\n');
177 if let Err(e) = outgoing_bytes.write_all(&outgoing_line).await {
178 log::warn!("failed to send message to peer: {e}");
179 }
180 broadcast.outgoing(&message);
181 } else {
182 break;
183 }
184 }
185 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
186 if bytes_read.map_err(Error::into_internal_error)? == 0 {
187 break
188 }
189 log::trace!("recv: {}", &incoming_line);
190
191 match serde_json::from_str::<RawIncomingMessage<'_>>(&incoming_line) {
192 Ok(message) => {
193 if let Some(id) = message.id {
194 if let Some(method) = message.method {
195 match Local::decode_request(&method, message.params) {
197 Ok(request) => {
198 broadcast.incoming_request(id.clone(), &*method, &request);
199 if let Err(e) = incoming_tx.unbounded_send(IncomingMessage::Request { id, request }) {
200 log::warn!("failed to send request to handler, channel full: {e:?}");
201 }
202 }
203 Err(error) => {
204 outgoing_line.clear();
205 let error_response = OutgoingMessage::<Local, Remote>::Response(Response::Error {
206 id,
207 error,
208 });
209
210 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
211 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
212 outgoing_line.push(b'\n');
213 if let Err(e) = outgoing_bytes.write_all(&outgoing_line).await {
214 log::warn!("failed to send error response to peer: {e}");
215 }
216 broadcast.outgoing(&error_response);
217 }
218 }
219 } else if let Some(pending_response) = pending_responses.lock().unwrap().remove(&id) {
220 if let Some(result_value) = message.result {
222 broadcast.incoming_response(id, Ok(Some(result_value)));
223
224 let result = (pending_response.deserialize)(result_value);
225 pending_response.respond.send(result).ok();
226 } else if let Some(error) = message.error {
227 broadcast.incoming_response(id, Err(&error));
228
229 pending_response.respond.send(Err(error)).ok();
230 } else {
231 broadcast.incoming_response(id, Ok(None));
232
233 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
234 pending_response.respond.send(result).ok();
235 }
236 } else {
237 log::error!("received response for unknown request id: {id:?}");
238 }
239 } else if let Some(method) = message.method {
240 match Local::decode_notification(&method, message.params) {
242 Ok(notification) => {
243 broadcast.incoming_notification(&*method, ¬ification);
244 if let Err(e) = incoming_tx.unbounded_send(IncomingMessage::Notification { notification }) {
245 log::warn!("failed to send notification to handler, channel full: {e:?}");
246 }
247 }
248 Err(err) => {
249 log::error!("failed to decode {:?}: {err}", message.params);
250 }
251 }
252 } else {
253 log::error!("received message with neither id nor method");
254 }
255 }
256 Err(error) => {
257 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
258 }
259 }
260 incoming_line.clear();
261 }
262 }
263 }
264 Ok(())
265 }
266
267 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
268 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
269 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
270 handler: Handler,
271 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
272 ) {
273 let spawn = Rc::new(spawn);
274 let handler = Rc::new(handler);
275 spawn({
276 let spawn = spawn.clone();
277 async move {
278 while let Some(message) = incoming_rx.next().await {
279 match message {
280 IncomingMessage::Request { id, request } => {
281 let outgoing_tx = outgoing_tx.clone();
282 let handler = handler.clone();
283 spawn(
284 async move {
285 let result = handler.handle_request(request).await;
286 outgoing_tx
287 .unbounded_send(OutgoingMessage::Response(Response::new(
288 id, result,
289 )))
290 .ok();
291 }
292 .boxed_local(),
293 );
294 }
295 IncomingMessage::Notification { notification } => {
296 let handler = handler.clone();
297 spawn(
298 async move {
299 if let Err(err) =
300 handler.handle_notification(notification).await
301 {
302 log::error!("failed to handle notification: {err:?}");
303 }
304 }
305 .boxed_local(),
306 );
307 }
308 }
309 }
310 }
311 .boxed_local()
312 });
313 }
314}
315
316#[derive(Debug, Deserialize)]
317pub struct RawIncomingMessage<'a> {
318 id: Option<RequestId>,
319 #[serde(borrow)]
320 method: Option<Cow<'a, str>>,
321 #[serde(borrow)]
322 params: Option<&'a RawValue>,
323 #[serde(borrow)]
324 result: Option<&'a RawValue>,
325 error: Option<Error>,
326}
327
328#[derive(Debug)]
329pub enum IncomingMessage<Local: Side> {
330 Request {
331 id: RequestId,
332 request: Local::InRequest,
333 },
334 Notification {
335 notification: Local::InNotification,
336 },
337}
338
339pub trait MessageHandler<Local: Side> {
340 fn handle_request(
341 &self,
342 request: Local::InRequest,
343 ) -> impl Future<Output = Result<Local::OutResponse>>;
344
345 fn handle_notification(
346 &self,
347 notification: Local::InNotification,
348 ) -> impl Future<Output = Result<()>>;
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` and checks that it is a `session/update` call with empty params.
    fn assert_parses_as_session_update(json: &str) {
        let parsed: RawIncomingMessage<'_> = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.method.unwrap(), "session/update");
        assert_eq!(parsed.params.unwrap().to_string(), "{}");
    }

    #[test]
    fn test_raw_incoming_message_with_escaped_slash() {
        // `\/` is a legal JSON escape. The decoded method name differs from the raw input
        // text, so it must still parse correctly into the `Cow` field.
        assert_parses_as_session_update(
            r#"{"jsonrpc":"2.0","id":1,"method":"session\/update","params":{}}"#,
        );
    }

    #[test]
    fn test_raw_incoming_message_without_escape() {
        assert_parses_as_session_update(
            r#"{"jsonrpc":"2.0","id":2,"method":"session/update","params":{}}"#,
        );
    }
}