1use std::{
2 any::Any,
3 collections::HashMap,
4 rc::Rc,
5 sync::{
6 Arc,
7 atomic::{AtomicI32, Ordering},
8 },
9};
10
11use anyhow::Result;
12use futures::{
13 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
14 StreamExt as _,
15 channel::{
16 mpsc::{self, UnboundedReceiver, UnboundedSender},
17 oneshot,
18 },
19 future::LocalBoxFuture,
20 io::BufReader,
21 select_biased,
22};
23use parking_lot::Mutex;
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use serde_json::value::RawValue;
26
27use crate::Error;
28
29pub struct RpcConnection<Local: Side, Remote: Side> {
30 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
31 pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
32 next_id: AtomicI32,
33}
34
35struct PendingResponse {
36 deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>, Error>,
37 respond: oneshot::Sender<Result<Box<dyn Any + Send>, Error>>,
38}
39
40impl<Local, Remote> RpcConnection<Local, Remote>
41where
42 Local: Side + 'static,
43 Remote: Side + 'static,
44{
45 pub fn new<Handler>(
46 handler: Handler,
47 outgoing_bytes: impl Unpin + AsyncWrite,
48 incoming_bytes: impl Unpin + AsyncRead,
49 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
50 ) -> (Self, impl futures::Future<Output = Result<()>>)
51 where
52 Handler: MessageHandler<Local> + 'static,
53 {
54 let (incoming_tx, incoming_rx) = mpsc::unbounded();
55 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
56
57 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
58
59 let io_task = Self::handle_io(
60 incoming_tx,
61 outgoing_rx,
62 outgoing_bytes,
63 incoming_bytes,
64 pending_responses.clone(),
65 );
66
67 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
68
69 let this = Self {
70 outgoing_tx,
71 pending_responses,
72 next_id: AtomicI32::new(0),
73 };
74
75 (this, io_task)
76 }
77
78 pub fn notify(
79 &self,
80 method: &'static str,
81 params: Option<Remote::InNotification>,
82 ) -> Result<(), Error> {
83 self.outgoing_tx
84 .unbounded_send(OutgoingMessage::Notification { method, params })
85 .map_err(|_| Error::internal_error().with_data("failed to send notification"))
86 }
87
88 pub fn request<Out: DeserializeOwned + Send + 'static>(
89 &self,
90 method: &'static str,
91 params: Option<Remote::InRequest>,
92 ) -> impl Future<Output = Result<Out, Error>> {
93 let (tx, rx) = oneshot::channel();
94 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
95 self.pending_responses.lock().insert(
96 id,
97 PendingResponse {
98 deserialize: |value| {
99 serde_json::from_str::<Out>(value.get())
100 .map(|out| Box::new(out) as _)
101 .map_err(|_| {
102 Error::internal_error().with_data("failed to deserialize response")
103 })
104 },
105 respond: tx,
106 },
107 );
108
109 if self
110 .outgoing_tx
111 .unbounded_send(OutgoingMessage::Request { id, method, params })
112 .is_err()
113 {
114 self.pending_responses.lock().remove(&id);
115 }
116 async move {
117 let result = rx
118 .await
119 .map_err(|e| Error::internal_error().with_data(e.to_string()))??
120 .downcast::<Out>()
121 .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
122
123 Ok(*result)
124 }
125 }
126
127 async fn handle_io(
128 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
129 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
130 mut outgoing_bytes: impl Unpin + AsyncWrite,
131 incoming_bytes: impl Unpin + AsyncRead,
132 pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
133 ) -> Result<()> {
134 let mut input_reader = BufReader::new(incoming_bytes);
135 let mut outgoing_line = Vec::new();
136 let mut incoming_line = String::new();
137 loop {
138 select_biased! {
139 message = outgoing_rx.next() => {
140 if let Some(message) = message {
141 outgoing_line.clear();
142 serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
143 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
144 outgoing_line.push(b'\n');
145 outgoing_bytes.write_all(&outgoing_line).await.ok();
146 } else {
147 break;
148 }
149 }
150 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
151 if bytes_read.map_err(Error::into_internal_error)? == 0 {
152 break
153 }
154 log::trace!("recv: {}", &incoming_line);
155
156 match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
157 Ok(message) => {
158 if let Some(id) = message.id {
159 if let Some(method) = message.method {
160 match Local::decode_request(method, message.params) {
162 Ok(request) => {
163 incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
164 }
165 Err(err) => {
166 outgoing_line.clear();
167 let error_response = OutgoingMessage::<Local, Remote>::Response {
168 id,
169 result: ResponseResult::Error(err),
170 };
171
172 serde_json::to_writer(&mut outgoing_line, &error_response)?;
173 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
174 outgoing_line.push(b'\n');
175 outgoing_bytes.write_all(&outgoing_line).await.ok();
176 }
177 }
178 } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
179 if let Some(result) = message.result {
181 let result = (pending_response.deserialize)(result);
182 pending_response.respond.send(result).ok();
183 } else if let Some(error) = message.error {
184 pending_response.respond.send(Err(error)).ok();
185 } else {
186 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
187 pending_response.respond.send(result).ok();
188 }
189 } else {
190 log::error!("received response for unknown request id: {id}");
191 }
192 } else if let Some(method) = message.method {
193 match Local::decode_notification(method, message.params) {
195 Ok(notification) => {
196 incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
197 }
198 Err(err) => {
199 log::error!("failed to decode notification: {err}");
200 }
201 }
202 } else {
203 log::error!("received message with neither id nor method");
204 }
205 }
206 Err(error) => {
207 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
208 }
209 }
210 incoming_line.clear();
211 }
212 }
213 }
214 Ok(())
215 }
216
217 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
218 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
219 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
220 handler: Handler,
221 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
222 ) {
223 let spawn = Rc::new(spawn);
224 let handler = Rc::new(handler);
225 spawn({
226 let spawn = spawn.clone();
227 async move {
228 while let Some(message) = incoming_rx.next().await {
229 match message {
230 IncomingMessage::Request { id, request } => {
231 let outgoing_tx = outgoing_tx.clone();
232 let handler = handler.clone();
233 spawn(
234 async move {
235 let result = handler.handle_request(request).await.into();
236 outgoing_tx
237 .unbounded_send(OutgoingMessage::Response { id, result })
238 .ok();
239 }
240 .boxed_local(),
241 )
242 }
243 IncomingMessage::Notification { notification } => {
244 let handler = handler.clone();
245 spawn(
246 async move {
247 if let Err(err) =
248 handler.handle_notification(notification).await
249 {
250 log::error!("failed to handle notification: {err:?}");
251 }
252 }
253 .boxed_local(),
254 )
255 }
256 }
257 }
258 }
259 .boxed_local()
260 })
261 }
262}
263
/// Loosely typed view of one incoming line, borrowing from the input text.
///
/// The I/O loop routes a message by which fields are present:
/// - `id` + `method` is a request;
/// - `id` alone is a response to one of our requests;
/// - `method` alone is a notification.
#[derive(Deserialize)]
struct RawIncomingMessage<'a> {
    /// NOTE(review): only integer ids are accepted. A message whose id is a
    /// string fails to parse as a whole; confirm peers never send string ids.
    id: Option<i32>,
    /// NOTE(review): a borrowed `&str` cannot hold a JSON string containing
    /// escape sequences, so such a method name fails to parse; confirm acceptable.
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    /// Note that `"result": null` deserializes to `None` here, not `Some(null)`.
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
272
273enum IncomingMessage<Local: Side> {
274 Request { id: i32, request: Local::InRequest },
275 Notification { notification: Local::InNotification },
276}
277
/// A message this side writes to the peer, as one JSON object per line.
///
/// Because of `#[serde(untagged)]`, no variant tag is emitted; the object's
/// shape alone identifies the message:
/// - `Request`: `{"id", "method", "params"?}`
/// - `Response`: `{"id", "result" | "error"}` (via the flattened `ResponseResult`)
/// - `Notification`: `{"method", "params"?}`
///
/// NOTE(review): no `"jsonrpc": "2.0"` member is emitted; confirm peers do not require it.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request to the peer; `params` is omitted entirely when `None`.
    Request {
        id: i32,
        method: &'static str,
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InRequest>,
    },
    /// A reply to a peer request with the same `id`.
    Response {
        id: i32,
        #[serde(flatten)]
        result: ResponseResult<Local::OutResponse>,
    },
    /// A fire-and-forget message; `params` is omitted entirely when `None`.
    Notification {
        method: &'static str,
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InNotification>,
    },
}
298
/// Outcome of handling a request.
///
/// Variants are serialized under lowercase keys (`"result"` / `"error"`).
/// `OutgoingMessage::Response` flattens this type next to its `id`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ResponseResult<Res> {
    /// Successful result payload.
    Result(Res),
    /// Error reported to the peer.
    Error(Error),
}
305
306impl<T> From<Result<T, Error>> for ResponseResult<T> {
307 fn from(result: Result<T, Error>) -> Self {
308 match result {
309 Ok(value) => ResponseResult::Result(value),
310 Err(error) => ResponseResult::Error(error),
311 }
312 }
313}
314
315pub trait Side {
316 type InRequest: Serialize + DeserializeOwned + 'static;
317 type OutResponse: Serialize + DeserializeOwned + 'static;
318 type InNotification: Serialize + DeserializeOwned + 'static;
319
320 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest, Error>;
321
322 fn decode_notification(
323 method: &str,
324 params: Option<&RawValue>,
325 ) -> Result<Self::InNotification, Error>;
326}
327
328pub trait MessageHandler<Local: Side> {
329 fn handle_request(
330 &self,
331 request: Local::InRequest,
332 ) -> impl Future<Output = Result<Local::OutResponse, Error>>;
333
334 fn handle_notification(
335 &self,
336 notification: Local::InNotification,
337 ) -> impl Future<Output = Result<(), Error>>;
338}
339
/// Decodes `$params` as `$request_type` and calls `$method(&$base.delegate, args)`.
///
/// It then spawns a task via `$base.spawn`. That task awaits the call, maps a
/// success with `$response_wrapper`, and sends the outcome through
/// `$base.outgoing_tx` as the response to request `$id`.
///
/// The macro evaluates to `Ok(())` once the task is spawned, or to an
/// `invalid_params` error carrying the serde message if the parameters fail
/// to parse. When `$params` is `None`, it `return`s an `invalid_params`
/// error from the enclosing function.
#[macro_export]
macro_rules! dispatch_request {
    ($base:expr, $id:expr, $params:expr, $request_type:ty, $method:expr, $response_wrapper:expr) => {{
        let Some(params) = $params else {
            return Err($crate::Error::invalid_params());
        };

        match serde_json::from_str::<$request_type>(params.get()) {
            Ok(arguments) => {
                // Evaluate `$base` and `$id` exactly once, up front. The caller's
                // expressions are then neither re-run nor captured by the `'static` task.
                let base = &$base;
                let request_id = $id;
                let fut = $method(&base.delegate, arguments);
                let outgoing_tx = base.outgoing_tx.clone();
                (base.spawn)(::futures::FutureExt::boxed_local(async move {
                    outgoing_tx
                        .unbounded_send($crate::rpc::OutgoingMessage::Response {
                            id: request_id,
                            result: fut.await.map($response_wrapper).into(),
                        })
                        .ok();
                }));

                Ok(())
            }
            Err(err) => Err($crate::Error::invalid_params().with_data(err.to_string())),
        }
    }};
}
373
/// Decodes `$params` as `$params_type` and passes the value to `$handler`.
///
/// The macro evaluates to the handler's result, or to an `invalid_params`
/// error carrying the serde message when decoding fails. When `$params` is
/// `None`, it `return`s an `invalid_params` error from the enclosing function.
/// The `$method` argument is accepted but not used by the expansion.
#[macro_export]
macro_rules! dispatch_notification {
    ($method:expr, $params:expr, $params_type:ty, $handler:expr) => {{
        let Some(raw_params) = $params else {
            return Err($crate::Error::invalid_params());
        };

        let decoded = serde_json::from_str::<$params_type>(raw_params.get());
        match decoded {
            Err(err) => Err($crate::Error::invalid_params().with_data(err.to_string())),
            Ok(arguments) => $handler(arguments),
        }
    }};
}