1use std::{
2 any::Any,
3 collections::HashMap,
4 rc::Rc,
5 sync::{
6 Arc,
7 atomic::{AtomicI64, Ordering},
8 },
9};
10
11use agent_client_protocol_schema::{Error, Result};
12use derive_more::Display;
13use futures::{
14 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
15 StreamExt as _,
16 channel::{
17 mpsc::{self, UnboundedReceiver, UnboundedSender},
18 oneshot,
19 },
20 future::LocalBoxFuture,
21 io::BufReader,
22 select_biased,
23};
24use parking_lot::Mutex;
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use serde_json::value::RawValue;
27
28use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
29
30pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
31 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
32 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
33 next_id: AtomicI64,
34 broadcast: StreamBroadcast,
35}
36
37struct PendingResponse {
38 deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
39 respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
40}
41
42impl<Local, Remote> RpcConnection<Local, Remote>
43where
44 Local: Side + 'static,
45 Remote: Side + 'static,
46{
47 pub(crate) fn new<Handler>(
48 handler: Handler,
49 outgoing_bytes: impl Unpin + AsyncWrite,
50 incoming_bytes: impl Unpin + AsyncRead,
51 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
52 ) -> (Self, impl futures::Future<Output = Result<()>>)
53 where
54 Handler: MessageHandler<Local> + 'static,
55 {
56 let (incoming_tx, incoming_rx) = mpsc::unbounded();
57 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
58
59 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
60 let (broadcast_tx, broadcast) = StreamBroadcast::new();
61
62 let io_task = {
63 let pending_responses = pending_responses.clone();
64 async move {
65 let result = Self::handle_io(
66 incoming_tx,
67 outgoing_rx,
68 outgoing_bytes,
69 incoming_bytes,
70 pending_responses.clone(),
71 broadcast_tx,
72 )
73 .await;
74 pending_responses.lock().clear();
75 result
76 }
77 };
78
79 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
80
81 let this = Self {
82 outgoing_tx,
83 pending_responses,
84 next_id: AtomicI64::new(0),
85 broadcast,
86 };
87
88 (this, io_task)
89 }
90
91 pub(crate) fn subscribe(&self) -> StreamReceiver {
92 self.broadcast.receiver()
93 }
94
95 pub(crate) fn notify(
96 &self,
97 method: impl Into<Arc<str>>,
98 params: Option<Remote::InNotification>,
99 ) -> Result<()> {
100 self.outgoing_tx
101 .unbounded_send(OutgoingMessage::Notification {
102 method: method.into(),
103 params,
104 })
105 .map_err(|_| Error::internal_error().with_data("failed to send notification"))
106 }
107
108 pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
109 &self,
110 method: impl Into<Arc<str>>,
111 params: Option<Remote::InRequest>,
112 ) -> impl Future<Output = Result<Out>> {
113 let (tx, rx) = oneshot::channel();
114 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
115 let id = RequestId::Number(id);
116 self.pending_responses.lock().insert(
117 id.clone(),
118 PendingResponse {
119 deserialize: |value| {
120 serde_json::from_str::<Out>(value.get())
121 .map(|out| Box::new(out) as _)
122 .map_err(|_| {
123 Error::internal_error().with_data("failed to deserialize response")
124 })
125 },
126 respond: tx,
127 },
128 );
129
130 if self
131 .outgoing_tx
132 .unbounded_send(OutgoingMessage::Request {
133 id: id.clone(),
134 method: method.into(),
135 params,
136 })
137 .is_err()
138 {
139 self.pending_responses.lock().remove(&id);
140 }
141 async move {
142 let result = rx
143 .await
144 .map_err(|_| Error::internal_error().with_data("server shut down unexpectedly"))??
145 .downcast::<Out>()
146 .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
147
148 Ok(*result)
149 }
150 }
151
152 async fn handle_io(
153 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
154 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
155 mut outgoing_bytes: impl Unpin + AsyncWrite,
156 incoming_bytes: impl Unpin + AsyncRead,
157 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
158 broadcast: StreamSender,
159 ) -> Result<()> {
160 let mut input_reader = BufReader::new(incoming_bytes);
162 let mut outgoing_line = Vec::new();
163 let mut incoming_line = String::new();
164 loop {
165 select_biased! {
166 message = outgoing_rx.next() => {
167 if let Some(message) = message {
168 outgoing_line.clear();
169 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
170 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
171 outgoing_line.push(b'\n');
172 outgoing_bytes.write_all(&outgoing_line).await.ok();
173 broadcast.outgoing(&message);
174 } else {
175 break;
176 }
177 }
178 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
179 if bytes_read.map_err(Error::into_internal_error)? == 0 {
180 break
181 }
182 log::trace!("recv: {}", &incoming_line);
183
184 match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
185 Ok(message) => {
186 if let Some(id) = message.id {
187 if let Some(method) = message.method {
188 match Local::decode_request(method, message.params) {
190 Ok(request) => {
191 broadcast.incoming_request(id.clone(), method, &request);
192 incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
193 }
194 Err(err) => {
195 outgoing_line.clear();
196 let error_response = OutgoingMessage::<Local, Remote>::Response {
197 id,
198 result: ResponseResult::Error(err),
199 };
200
201 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
202 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
203 outgoing_line.push(b'\n');
204 outgoing_bytes.write_all(&outgoing_line).await.ok();
205 broadcast.outgoing(&error_response);
206 }
207 }
208 } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
209 if let Some(result_value) = message.result {
211 broadcast.incoming_response(id, Ok(Some(result_value)));
212
213 let result = (pending_response.deserialize)(result_value);
214 pending_response.respond.send(result).ok();
215 } else if let Some(error) = message.error {
216 broadcast.incoming_response(id, Err(&error));
217
218 pending_response.respond.send(Err(error)).ok();
219 } else {
220 broadcast.incoming_response(id, Ok(None));
221
222 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
223 pending_response.respond.send(result).ok();
224 }
225 } else {
226 log::error!("received response for unknown request id: {id:?}");
227 }
228 } else if let Some(method) = message.method {
229 match Local::decode_notification(method, message.params) {
231 Ok(notification) => {
232 broadcast.incoming_notification(method, ¬ification);
233 incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
234 }
235 Err(err) => {
236 log::error!("failed to decode {:?}: {err}", message.params);
237 }
238 }
239 } else {
240 log::error!("received message with neither id nor method");
241 }
242 }
243 Err(error) => {
244 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
245 }
246 }
247 incoming_line.clear();
248 }
249 }
250 }
251 Ok(())
252 }
253
254 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
255 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
256 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
257 handler: Handler,
258 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
259 ) {
260 let spawn = Rc::new(spawn);
261 let handler = Rc::new(handler);
262 spawn({
263 let spawn = spawn.clone();
264 async move {
265 while let Some(message) = incoming_rx.next().await {
266 match message {
267 IncomingMessage::Request { id, request } => {
268 let outgoing_tx = outgoing_tx.clone();
269 let handler = handler.clone();
270 spawn(
271 async move {
272 let result = handler.handle_request(request).await.into();
273 outgoing_tx
274 .unbounded_send(OutgoingMessage::Response { id, result })
275 .ok();
276 }
277 .boxed_local(),
278 );
279 }
280 IncomingMessage::Notification { notification } => {
281 let handler = handler.clone();
282 spawn(
283 async move {
284 if let Err(err) =
285 handler.handle_notification(notification).await
286 {
287 log::error!("failed to handle notification: {err:?}");
288 }
289 }
290 .boxed_local(),
291 );
292 }
293 }
294 }
295 }
296 .boxed_local()
297 });
298 }
299}
300
301#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord, Display)]
303#[serde(deny_unknown_fields)]
304#[serde(untagged)]
305pub enum RequestId {
306 #[display("null")]
307 Null,
308 Number(i64),
309 Str(String),
310}
311
312#[derive(Deserialize)]
313pub struct RawIncomingMessage<'a> {
314 id: Option<RequestId>,
315 method: Option<&'a str>,
316 params: Option<&'a RawValue>,
317 result: Option<&'a RawValue>,
318 error: Option<Error>,
319}
320
321pub enum IncomingMessage<Local: Side> {
322 Request {
323 id: RequestId,
324 request: Local::InRequest,
325 },
326 Notification {
327 notification: Local::InNotification,
328 },
329}
330
331#[derive(Serialize, Deserialize, Clone)]
332#[serde(untagged)]
333pub enum OutgoingMessage<Local: Side, Remote: Side> {
334 Request {
335 id: RequestId,
336 method: Arc<str>,
337 #[serde(skip_serializing_if = "Option::is_none")]
338 params: Option<Remote::InRequest>,
339 },
340 Response {
341 id: RequestId,
342 #[serde(flatten)]
343 result: ResponseResult<Local::OutResponse>,
344 },
345 Notification {
346 method: Arc<str>,
347 #[serde(skip_serializing_if = "Option::is_none")]
348 params: Option<Remote::InNotification>,
349 },
350}
351
352#[derive(Debug, Serialize, Deserialize)]
357pub struct JsonRpcMessage<M> {
358 jsonrpc: &'static str,
359 #[serde(flatten)]
360 message: M,
361}
362
363impl<M> JsonRpcMessage<M> {
364 pub const VERSION: &'static str = "2.0";
368
369 #[must_use]
372 pub fn wrap(message: M) -> Self {
373 Self {
374 jsonrpc: Self::VERSION,
375 message,
376 }
377 }
378}
379
380#[derive(Debug, Serialize, Deserialize, Clone)]
381#[serde(rename_all = "snake_case")]
382pub enum ResponseResult<Res> {
383 Result(Res),
384 Error(Error),
385}
386
387impl<T> From<Result<T>> for ResponseResult<T> {
388 fn from(result: Result<T>) -> Self {
389 match result {
390 Ok(value) => ResponseResult::Result(value),
391 Err(error) => ResponseResult::Error(error),
392 }
393 }
394}
395
396pub trait Side: Clone {
397 type InRequest: Clone + Serialize + DeserializeOwned + 'static;
398 type OutResponse: Clone + Serialize + DeserializeOwned + 'static;
399 type InNotification: Clone + Serialize + DeserializeOwned + 'static;
400
401 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
402
403 fn decode_notification(method: &str, params: Option<&RawValue>)
404 -> Result<Self::InNotification>;
405}
406
407pub trait MessageHandler<Local: Side> {
408 fn handle_request(
409 &self,
410 request: Local::InRequest,
411 ) -> impl Future<Output = Result<Local::OutResponse>>;
412
413 fn handle_notification(
414 &self,
415 notification: Local::InNotification,
416 ) -> impl Future<Output = Result<()>>;
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// JSON values paired with the `RequestId` each of them corresponds to.
    fn id_cases() -> [(Value, RequestId); 4] {
        [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (json, expected) in id_cases() {
            let id = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(id, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (expected, id) in id_cases() {
            let json = serde_json::to_value(id).unwrap();
            assert_eq!(json, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }
}