1use std::{
2 any::Any,
3 collections::HashMap,
4 rc::Rc,
5 sync::{
6 Arc,
7 atomic::{AtomicI64, Ordering},
8 },
9};
10
11use agent_client_protocol_schema::{
12 Error, JsonRpcMessage, OutgoingMessage, RequestId, ResponseResult, Result, Side,
13};
14use futures::{
15 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
16 StreamExt as _,
17 channel::{
18 mpsc::{self, UnboundedReceiver, UnboundedSender},
19 oneshot,
20 },
21 future::LocalBoxFuture,
22 io::BufReader,
23 select_biased,
24};
25use parking_lot::Mutex;
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::value::RawValue;
28
29use super::stream_broadcast::{StreamBroadcast, StreamReceiver, StreamSender};
30
/// One end of a newline-delimited JSON-RPC connection.
///
/// Messages received from the peer are decoded with `Local` (`Local::decode_request`,
/// `Local::decode_notification`); requests and notifications sent to the peer use
/// `Remote::InRequest` / `Remote::InNotification`.
pub(crate) struct RpcConnection<Local: Side, Remote: Side> {
    /// Queue drained by the I/O loop, which serializes each message and writes it to the peer.
    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
    /// Requests sent to the peer that are still waiting for a response, keyed by request id.
    /// Shared with the I/O loop, which completes (or, on shutdown, clears) the entries.
    pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
    /// Counter used to assign ids to outgoing requests.
    next_id: AtomicI64,
    /// Broadcast handle from which `subscribe` creates receivers; the I/O loop reports sent
    /// and received messages to the matching sender half.
    broadcast: StreamBroadcast,
}
37
/// Bookkeeping for one request that has been sent and is awaiting the peer's response.
struct PendingResponse {
    /// Decodes the raw `result` JSON into the type the caller of `request` asked for,
    /// type-erased as `Box<dyn Any + Send>` so entries of different types share one map.
    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>>,
    /// Delivers the decoded result (or the peer's error) to the future returned by `request`.
    /// Dropping it without sending makes that future fail.
    respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
}
42
43impl<Local, Remote> RpcConnection<Local, Remote>
44where
45 Local: Side + 'static,
46 Remote: Side + 'static,
47{
48 pub(crate) fn new<Handler>(
49 handler: Handler,
50 outgoing_bytes: impl Unpin + AsyncWrite,
51 incoming_bytes: impl Unpin + AsyncRead,
52 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
53 ) -> (Self, impl futures::Future<Output = Result<()>>)
54 where
55 Handler: MessageHandler<Local> + 'static,
56 {
57 let (incoming_tx, incoming_rx) = mpsc::unbounded();
58 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
59
60 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
61 let (broadcast_tx, broadcast) = StreamBroadcast::new();
62
63 let io_task = {
64 let pending_responses = pending_responses.clone();
65 async move {
66 let result = Self::handle_io(
67 incoming_tx,
68 outgoing_rx,
69 outgoing_bytes,
70 incoming_bytes,
71 pending_responses.clone(),
72 broadcast_tx,
73 )
74 .await;
75 pending_responses.lock().clear();
76 result
77 }
78 };
79
80 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
81
82 let this = Self {
83 outgoing_tx,
84 pending_responses,
85 next_id: AtomicI64::new(0),
86 broadcast,
87 };
88
89 (this, io_task)
90 }
91
92 pub(crate) fn subscribe(&self) -> StreamReceiver {
93 self.broadcast.receiver()
94 }
95
96 pub(crate) fn notify(
97 &self,
98 method: impl Into<Arc<str>>,
99 params: Option<Remote::InNotification>,
100 ) -> Result<()> {
101 self.outgoing_tx
102 .unbounded_send(OutgoingMessage::Notification {
103 method: method.into(),
104 params,
105 })
106 .map_err(|_| Error::internal_error().with_data("failed to send notification"))
107 }
108
109 pub(crate) fn request<Out: DeserializeOwned + Send + 'static>(
110 &self,
111 method: impl Into<Arc<str>>,
112 params: Option<Remote::InRequest>,
113 ) -> impl Future<Output = Result<Out>> {
114 let (tx, rx) = oneshot::channel();
115 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
116 let id = RequestId::Number(id);
117 self.pending_responses.lock().insert(
118 id.clone(),
119 PendingResponse {
120 deserialize: |value| {
121 serde_json::from_str::<Out>(value.get())
122 .map(|out| Box::new(out) as _)
123 .map_err(|_| {
124 Error::internal_error().with_data("failed to deserialize response")
125 })
126 },
127 respond: tx,
128 },
129 );
130
131 if self
132 .outgoing_tx
133 .unbounded_send(OutgoingMessage::Request {
134 id: id.clone(),
135 method: method.into(),
136 params,
137 })
138 .is_err()
139 {
140 self.pending_responses.lock().remove(&id);
141 }
142 async move {
143 let result = rx
144 .await
145 .map_err(|_| Error::internal_error().with_data("server shut down unexpectedly"))??
146 .downcast::<Out>()
147 .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
148
149 Ok(*result)
150 }
151 }
152
153 async fn handle_io(
154 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
155 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
156 mut outgoing_bytes: impl Unpin + AsyncWrite,
157 incoming_bytes: impl Unpin + AsyncRead,
158 pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
159 broadcast: StreamSender,
160 ) -> Result<()> {
161 let mut input_reader = BufReader::new(incoming_bytes);
163 let mut outgoing_line = Vec::new();
164 let mut incoming_line = String::new();
165 loop {
166 select_biased! {
167 message = outgoing_rx.next() => {
168 if let Some(message) = message {
169 outgoing_line.clear();
170 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
171 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
172 outgoing_line.push(b'\n');
173 outgoing_bytes.write_all(&outgoing_line).await.ok();
174 broadcast.outgoing(&message);
175 } else {
176 break;
177 }
178 }
179 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
180 if bytes_read.map_err(Error::into_internal_error)? == 0 {
181 break
182 }
183 log::trace!("recv: {}", &incoming_line);
184
185 match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
186 Ok(message) => {
187 if let Some(id) = message.id {
188 if let Some(method) = message.method {
189 match Local::decode_request(method, message.params) {
191 Ok(request) => {
192 broadcast.incoming_request(id.clone(), method, &request);
193 incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
194 }
195 Err(err) => {
196 outgoing_line.clear();
197 let error_response = OutgoingMessage::<Local, Remote>::Response {
198 id,
199 result: ResponseResult::Error(err),
200 };
201
202 serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
203 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
204 outgoing_line.push(b'\n');
205 outgoing_bytes.write_all(&outgoing_line).await.ok();
206 broadcast.outgoing(&error_response);
207 }
208 }
209 } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
210 if let Some(result_value) = message.result {
212 broadcast.incoming_response(id, Ok(Some(result_value)));
213
214 let result = (pending_response.deserialize)(result_value);
215 pending_response.respond.send(result).ok();
216 } else if let Some(error) = message.error {
217 broadcast.incoming_response(id, Err(&error));
218
219 pending_response.respond.send(Err(error)).ok();
220 } else {
221 broadcast.incoming_response(id, Ok(None));
222
223 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
224 pending_response.respond.send(result).ok();
225 }
226 } else {
227 log::error!("received response for unknown request id: {id:?}");
228 }
229 } else if let Some(method) = message.method {
230 match Local::decode_notification(method, message.params) {
232 Ok(notification) => {
233 broadcast.incoming_notification(method, ¬ification);
234 incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
235 }
236 Err(err) => {
237 log::error!("failed to decode {:?}: {err}", message.params);
238 }
239 }
240 } else {
241 log::error!("received message with neither id nor method");
242 }
243 }
244 Err(error) => {
245 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
246 }
247 }
248 incoming_line.clear();
249 }
250 }
251 }
252 Ok(())
253 }
254
255 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
256 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
257 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
258 handler: Handler,
259 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
260 ) {
261 let spawn = Rc::new(spawn);
262 let handler = Rc::new(handler);
263 spawn({
264 let spawn = spawn.clone();
265 async move {
266 while let Some(message) = incoming_rx.next().await {
267 match message {
268 IncomingMessage::Request { id, request } => {
269 let outgoing_tx = outgoing_tx.clone();
270 let handler = handler.clone();
271 spawn(
272 async move {
273 let result = handler.handle_request(request).await.into();
274 outgoing_tx
275 .unbounded_send(OutgoingMessage::Response { id, result })
276 .ok();
277 }
278 .boxed_local(),
279 );
280 }
281 IncomingMessage::Notification { notification } => {
282 let handler = handler.clone();
283 spawn(
284 async move {
285 if let Err(err) =
286 handler.handle_notification(notification).await
287 {
288 log::error!("failed to handle notification: {err:?}");
289 }
290 }
291 .boxed_local(),
292 );
293 }
294 }
295 }
296 }
297 .boxed_local()
298 });
299 }
300}
301
/// Borrowed view of one incoming JSON-RPC message.
///
/// It is used to tell requests (`id` + `method`), responses (`id` only) and notifications
/// (`method` only) apart before their payloads are decoded. `params` and `result` stay as raw
/// JSON borrowed from the input line, so they can be decoded later into the concrete type.
#[derive(Deserialize)]
pub struct RawIncomingMessage<'a> {
    id: Option<RequestId>,
    // NOTE(review): serde_json can only deserialize a borrowed `&str` when the JSON string
    // contains no escape sequences. A peer that escapes characters in the method name
    // (e.g. `"session\/prompt"`) would make the whole message fail to parse. Consider
    // `Option<Cow<'a, str>>` with `#[serde(borrow)]`, updating the call sites that use `method`.
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
310
/// A decoded request or notification from the peer, queued for the task that invokes the
/// `MessageHandler`.
pub enum IncomingMessage<Local: Side> {
    /// A request; the handler's result is sent back as the response carrying the same `id`.
    Request {
        id: RequestId,
        request: Local::InRequest,
    },
    /// A notification; no response is sent for it.
    Notification {
        notification: Local::InNotification,
    },
}
320
/// Application logic that `RpcConnection` invokes for each incoming request and notification.
///
/// `RpcConnection` spawns every call on its own task, so calls may overlap.
pub trait MessageHandler<Local: Side> {
    /// Handles one request from the peer.
    ///
    /// The returned value or error is converted into the response for that request's id.
    fn handle_request(
        &self,
        request: Local::InRequest,
    ) -> impl Future<Output = Result<Local::OutResponse>>;

    /// Handles one notification from the peer.
    ///
    /// An error is only logged by `RpcConnection`; nothing is sent back to the peer.
    fn handle_notification(
        &self,
        notification: Local::InNotification,
    ) -> impl Future<Output = Result<()>>;
}