1use std::{
2 any::Any,
3 collections::HashMap,
4 rc::Rc,
5 sync::{
6 Arc, Mutex,
7 atomic::{AtomicI64, Ordering},
8 },
9};
10
11use agent_client_protocol_schema::{
12 Error, JsonRpcMessage, Notification, OutgoingMessage, Request, RequestId, Response, Result,
13 Side,
14};
15use futures::{
16 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
17 StreamExt as _,
18 channel::{
19 mpsc::{self, UnboundedReceiver, UnboundedSender},
20 oneshot,
21 },
22 future::LocalBoxFuture,
23 io::BufReader,
24 select_biased,
25};
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::value::RawValue;
28
29use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
30
31#[derive(Debug)]
32pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
33 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
34 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
35 next_id: AtomicI64,
36 broadcast: StreamBroadcast,
37}
38
39#[derive(Debug)]
40struct PendingResponse {
41 deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
42 respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
43}
44
45impl<Local, Remote> RpcConnection<Local, Remote>
46where
47 Local: Side + 'static,
48 Remote: Side + 'static,
49{
50 pub(crate) fn new<Handler>(
51 handler: Handler,
52 outgoing_bytes: impl Unpin + AsyncWrite,
53 incoming_bytes: impl Unpin + AsyncRead,
54 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
55 ) -> (Self, impl futures::Future<Output = Result<()>>)
56 where
57 Handler: MessageHandler<Local> + 'static,
58 {
59 let (incoming_tx, incoming_rx) = mpsc::unbounded();
60 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
61
62 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
63 let (broadcast_tx, broadcast) = StreamBroadcast::new();
64
65 let io_task = {
66 let pending_responses = pending_responses.clone();
67 async move {
68 let result = Self::handle_io(
69 incoming_tx,
70 outgoing_rx,
71 outgoing_bytes,
72 incoming_bytes,
73 pending_responses.clone(),
74 broadcast_tx,
75 )
76 .await;
77 pending_responses.lock().unwrap().clear();
78 result
79 }
80 };
81
82 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
83
84 let this = Self {
85 outgoing_tx,
86 pending_responses,
87 next_id: AtomicI64::new(0),
88 broadcast,
89 };
90
91 (this, io_task)
92 }
93
94 pub(crate) fn subscribe(&self) -> StreamReceiver {
95 self.broadcast.receiver()
96 }
97
98 pub(crate) fn notify(
99 &self,
100 method: impl Into<Arc<str>>,
101 params: Option<Remote::InNotification>,
102 ) -> Result<()> {
103 self.outgoing_tx
104 .unbounded_send(OutgoingMessage::Notification(Notification {
105 method: method.into(),
106 params,
107 }))
108 .map_err(|_| Error::internal_error().data("failed to send notification"))
109 }
110
111 pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
112 &self,
113 method: impl Into<Arc<str>>,
114 params: Option<Remote::InRequest>,
115 ) -> impl Future<Output = Result<Out>> {
116 let (tx, rx) = oneshot::channel();
117 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
118 let id = RequestId::Number(id);
119 self.pending_responses.lock().unwrap().insert(
120 id.clone(),
121 PendingResponse {
122 deserialize: |value| {
123 serde_json::from_str::<Out>(value.get())
124 .map(|out| Box::new(out) as _)
125 .map_err(|_| Error::internal_error().data("failed to deserialize response"))
126 },
127 respond: tx,
128 },
129 );
130
131 if self
132 .outgoing_tx
133 .unbounded_send(OutgoingMessage::Request(Request {
134 id: id.clone(),
135 method: method.into(),
136 params,
137 }))
138 .is_err()
139 {
140 self.pending_responses.lock().unwrap().remove(&id);
141 }
142 async move {
143 let result = rx
144 .await
145 .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
146 .downcast::<Out>()
147 .map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
148
149 Ok(*result)
150 }
151 }
152
153 async fn handle_io(
154 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
155 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
156 mut outgoing_bytes: impl Unpin + AsyncWrite,
157 incoming_bytes: impl Unpin + AsyncRead,
158 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
159 broadcast: StreamSender,
160 ) -> Result<()> {
161 let mut input_reader = BufReader::new(incoming_bytes);
163 let mut outgoing_line = Vec::new();
164 let mut incoming_line = String::new();
165 loop {
166 select_biased! {
167 message = outgoing_rx.next() => {
168 if let Some(message) = message {
169 outgoing_line.clear();
170 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
171 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
172 outgoing_line.push(b'\n');
173 outgoing_bytes.write_all(&outgoing_line).await.ok();
174 broadcast.outgoing(&message);
175 } else {
176 break;
177 }
178 }
179 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
180 if bytes_read.map_err(Error::into_internal_error)? == 0 {
181 break
182 }
183 log::trace!("recv: {}", &incoming_line);
184
185 match serde_json::from_str::<RawIncomingMessage<'_>>(&incoming_line) {
186 Ok(message) => {
187 if let Some(id) = message.id {
188 if let Some(method) = message.method {
189 match Local::decode_request(method, message.params) {
191 Ok(request) => {
192 broadcast.incoming_request(id.clone(), method, &request);
193 incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
194 }
195 Err(error) => {
196 outgoing_line.clear();
197 let error_response = OutgoingMessage::<Local, Remote>::Response(Response::Error {
198 id,
199 error,
200 });
201
202 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
203 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
204 outgoing_line.push(b'\n');
205 outgoing_bytes.write_all(&outgoing_line).await.ok();
206 broadcast.outgoing(&error_response);
207 }
208 }
209 } else if let Some(pending_response) = pending_responses.lock().unwrap().remove(&id) {
210 if let Some(result_value) = message.result {
212 broadcast.incoming_response(id, Ok(Some(result_value)));
213
214 let result = (pending_response.deserialize)(result_value);
215 pending_response.respond.send(result).ok();
216 } else if let Some(error) = message.error {
217 broadcast.incoming_response(id, Err(&error));
218
219 pending_response.respond.send(Err(error)).ok();
220 } else {
221 broadcast.incoming_response(id, Ok(None));
222
223 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
224 pending_response.respond.send(result).ok();
225 }
226 } else {
227 log::error!("received response for unknown request id: {id:?}");
228 }
229 } else if let Some(method) = message.method {
230 match Local::decode_notification(method, message.params) {
232 Ok(notification) => {
233 broadcast.incoming_notification(method, ¬ification);
234 incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
235 }
236 Err(err) => {
237 log::error!("failed to decode {:?}: {err}", message.params);
238 }
239 }
240 } else {
241 log::error!("received message with neither id nor method");
242 }
243 }
244 Err(error) => {
245 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
246 }
247 }
248 incoming_line.clear();
249 }
250 }
251 }
252 Ok(())
253 }
254
255 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
256 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
257 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
258 handler: Handler,
259 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
260 ) {
261 let spawn = Rc::new(spawn);
262 let handler = Rc::new(handler);
263 spawn({
264 let spawn = spawn.clone();
265 async move {
266 while let Some(message) = incoming_rx.next().await {
267 match message {
268 IncomingMessage::Request { id, request } => {
269 let outgoing_tx = outgoing_tx.clone();
270 let handler = handler.clone();
271 spawn(
272 async move {
273 let result = handler.handle_request(request).await;
274 outgoing_tx
275 .unbounded_send(OutgoingMessage::Response(Response::new(
276 id, result,
277 )))
278 .ok();
279 }
280 .boxed_local(),
281 );
282 }
283 IncomingMessage::Notification { notification } => {
284 let handler = handler.clone();
285 spawn(
286 async move {
287 if let Err(err) =
288 handler.handle_notification(notification).await
289 {
290 log::error!("failed to handle notification: {err:?}");
291 }
292 }
293 .boxed_local(),
294 );
295 }
296 }
297 }
298 }
299 .boxed_local()
300 });
301 }
302}
303
304#[derive(Debug, Deserialize)]
305pub struct RawIncomingMessage<'a> {
306 id: Option<RequestId>,
307 method: Option<&'a str>,
308 params: Option<&'a RawValue>,
309 result: Option<&'a RawValue>,
310 error: Option<Error>,
311}
312
313#[derive(Debug)]
314pub enum IncomingMessage<Local: Side> {
315 Request {
316 id: RequestId,
317 request: Local::InRequest,
318 },
319 Notification {
320 notification: Local::InNotification,
321 },
322}
323
324pub trait MessageHandler<Local: Side> {
325 fn handle_request(
326 &self,
327 request: Local::InRequest,
328 ) -> impl Future<Output = Result<Local::OutResponse>>;
329
330 fn handle_notification(
331 &self,
332 notification: Local::InNotification,
333 ) -> impl Future<Output = Result<()>>;
334}