1use std::{
2 any::Any,
3 collections::HashMap,
4 rc::Rc,
5 sync::{
6 Arc,
7 atomic::{AtomicI32, Ordering},
8 },
9};
10
11use anyhow::Result;
12use futures::{
13 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
14 StreamExt as _,
15 channel::{
16 mpsc::{self, UnboundedReceiver, UnboundedSender},
17 oneshot,
18 },
19 future::LocalBoxFuture,
20 io::BufReader,
21 select_biased,
22};
23use parking_lot::Mutex;
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use serde_json::value::RawValue;
26
27use crate::Error;
28
/// A JSON-RPC connection in which this process plays the `Local` side and the peer plays the
/// `Remote` side.
///
/// Requests this side receives are decoded with `Local`'s types and answered with
/// `Local::OutResponse`; requests and notifications this side sends use `Remote`'s types.
pub struct RpcConnection<Local: Side, Remote: Side> {
    /// Queue of messages waiting to be serialized and written to the peer by the I/O task.
    outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
    /// Requests sent to the peer that are still awaiting a response, keyed by request id.
    /// Shared with the I/O task, which completes entries as responses arrive and clears the
    /// map when it stops.
    pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
    /// Id to assign to the next outgoing request.
    next_id: AtomicI32,
}
34
/// Bookkeeping for one outgoing request that is waiting for the peer's response.
struct PendingResponse {
    /// Decodes the raw `result` payload into the type the caller of `request` asked for,
    /// boxed as `dyn Any` so requests with different output types can share one map.
    deserialize: fn(&serde_json::value::RawValue) -> Result<Box<dyn Any + Send>, Error>,
    /// Completes the future returned by `RpcConnection::request`. Dropping it without sending
    /// makes that future resolve to an error.
    respond: oneshot::Sender<Result<Box<dyn Any + Send>, Error>>,
}
39
40impl<Local, Remote> RpcConnection<Local, Remote>
41where
42 Local: Side + 'static,
43 Remote: Side + 'static,
44{
45 pub fn new<Handler>(
46 handler: Handler,
47 outgoing_bytes: impl Unpin + AsyncWrite,
48 incoming_bytes: impl Unpin + AsyncRead,
49 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
50 ) -> (Self, impl futures::Future<Output = Result<()>>)
51 where
52 Handler: MessageHandler<Local> + 'static,
53 {
54 let (incoming_tx, incoming_rx) = mpsc::unbounded();
55 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
56
57 let pending_responses = Arc::new(Mutex::new(HashMap::default()));
58
59 let io_task = {
60 let pending_responses = pending_responses.clone();
61 async move {
62 let result = Self::handle_io(
63 incoming_tx,
64 outgoing_rx,
65 outgoing_bytes,
66 incoming_bytes,
67 pending_responses.clone(),
68 )
69 .await;
70 pending_responses.lock().clear();
71 result
72 }
73 };
74
75 Self::handle_incoming(outgoing_tx.clone(), incoming_rx, handler, spawn);
76
77 let this = Self {
78 outgoing_tx,
79 pending_responses,
80 next_id: AtomicI32::new(0),
81 };
82
83 (this, io_task)
84 }
85
86 pub fn notify(
87 &self,
88 method: &'static str,
89 params: Option<Remote::InNotification>,
90 ) -> Result<(), Error> {
91 self.outgoing_tx
92 .unbounded_send(OutgoingMessage::Notification { method, params })
93 .map_err(|_| Error::internal_error().with_data("failed to send notification"))
94 }
95
96 pub fn request<Out: DeserializeOwned + Send + 'static>(
97 &self,
98 method: &'static str,
99 params: Option<Remote::InRequest>,
100 ) -> impl Future<Output = Result<Out, Error>> {
101 let (tx, rx) = oneshot::channel();
102 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
103 self.pending_responses.lock().insert(
104 id,
105 PendingResponse {
106 deserialize: |value| {
107 serde_json::from_str::<Out>(value.get())
108 .map(|out| Box::new(out) as _)
109 .map_err(|_| {
110 Error::internal_error().with_data("failed to deserialize response")
111 })
112 },
113 respond: tx,
114 },
115 );
116
117 if self
118 .outgoing_tx
119 .unbounded_send(OutgoingMessage::Request { id, method, params })
120 .is_err()
121 {
122 self.pending_responses.lock().remove(&id);
123 }
124 async move {
125 let result = rx
126 .await
127 .map_err(|_| Error::internal_error().with_data("server shut down unexpectedly"))??
128 .downcast::<Out>()
129 .map_err(|_| Error::internal_error().with_data("failed to deserialize response"))?;
130
131 Ok(*result)
132 }
133 }
134
135 async fn handle_io(
136 incoming_tx: UnboundedSender<IncomingMessage<Local>>,
137 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Local, Remote>>,
138 mut outgoing_bytes: impl Unpin + AsyncWrite,
139 incoming_bytes: impl Unpin + AsyncRead,
140 pending_responses: Arc<Mutex<HashMap<i32, PendingResponse>>>,
141 ) -> Result<()> {
142 let mut input_reader = BufReader::new(incoming_bytes);
143 let mut outgoing_line = Vec::new();
144 let mut incoming_line = String::new();
145 loop {
146 select_biased! {
147 message = outgoing_rx.next() => {
148 if let Some(message) = message {
149 outgoing_line.clear();
150 serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
151 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
152 outgoing_line.push(b'\n');
153 outgoing_bytes.write_all(&outgoing_line).await.ok();
154 } else {
155 break;
156 }
157 }
158 bytes_read = input_reader.read_line(&mut incoming_line).fuse() => {
159 if bytes_read.map_err(Error::into_internal_error)? == 0 {
160 break
161 }
162 log::trace!("recv: {}", &incoming_line);
163
164 match serde_json::from_str::<RawIncomingMessage>(&incoming_line) {
165 Ok(message) => {
166 if let Some(id) = message.id {
167 if let Some(method) = message.method {
168 match Local::decode_request(method, message.params) {
170 Ok(request) => {
171 incoming_tx.unbounded_send(IncomingMessage::Request { id, request }).ok();
172 }
173 Err(err) => {
174 outgoing_line.clear();
175 let error_response = OutgoingMessage::<Local, Remote>::Response {
176 id,
177 result: ResponseResult::Error(err),
178 };
179
180 serde_json::to_writer(&mut outgoing_line, &error_response)?;
181 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
182 outgoing_line.push(b'\n');
183 outgoing_bytes.write_all(&outgoing_line).await.ok();
184 }
185 }
186 } else if let Some(pending_response) = pending_responses.lock().remove(&id) {
187 if let Some(result) = message.result {
189 let result = (pending_response.deserialize)(result);
190 pending_response.respond.send(result).ok();
191 } else if let Some(error) = message.error {
192 pending_response.respond.send(Err(error)).ok();
193 } else {
194 let result = (pending_response.deserialize)(&RawValue::from_string("null".into()).unwrap());
195 pending_response.respond.send(result).ok();
196 }
197 } else {
198 log::error!("received response for unknown request id: {id}");
199 }
200 } else if let Some(method) = message.method {
201 match Local::decode_notification(method, message.params) {
203 Ok(notification) => {
204 incoming_tx.unbounded_send(IncomingMessage::Notification { notification }).ok();
205 }
206 Err(err) => {
207 log::error!("failed to decode notification: {err}");
208 }
209 }
210 } else {
211 log::error!("received message with neither id nor method");
212 }
213 }
214 Err(error) => {
215 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
216 }
217 }
218 incoming_line.clear();
219 }
220 }
221 }
222 Ok(())
223 }
224
225 fn handle_incoming<Handler: MessageHandler<Local> + 'static>(
226 outgoing_tx: UnboundedSender<OutgoingMessage<Local, Remote>>,
227 mut incoming_rx: UnboundedReceiver<IncomingMessage<Local>>,
228 handler: Handler,
229 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
230 ) {
231 let spawn = Rc::new(spawn);
232 let handler = Rc::new(handler);
233 spawn({
234 let spawn = spawn.clone();
235 async move {
236 while let Some(message) = incoming_rx.next().await {
237 match message {
238 IncomingMessage::Request { id, request } => {
239 let outgoing_tx = outgoing_tx.clone();
240 let handler = handler.clone();
241 spawn(
242 async move {
243 let result = handler.handle_request(request).await.into();
244 outgoing_tx
245 .unbounded_send(OutgoingMessage::Response { id, result })
246 .ok();
247 }
248 .boxed_local(),
249 )
250 }
251 IncomingMessage::Notification { notification } => {
252 let handler = handler.clone();
253 spawn(
254 async move {
255 if let Err(err) =
256 handler.handle_notification(notification).await
257 {
258 log::error!("failed to handle notification: {err:?}");
259 }
260 }
261 .boxed_local(),
262 )
263 }
264 }
265 }
266 }
267 .boxed_local()
268 })
269 }
270}
271
/// Envelope of one incoming JSON line, parsed without decoding the payload.
///
/// The combination of present fields decides how the connection's I/O loop routes it:
/// - `id` and `method`: a request.
/// - `id` alone: a response.
/// - `method` alone: a notification.
#[derive(Deserialize)]
struct RawIncomingMessage<'a> {
    /// Request id; present on requests and on responses.
    id: Option<i32>,
    /// Method name; present on requests and notifications.
    /// NOTE(review): this is borrowed straight from the input, so serde cannot deserialize a
    /// method name containing JSON escape sequences into this `&str`.
    method: Option<&'a str>,
    /// Raw, not-yet-decoded parameters of a request or notification.
    params: Option<&'a RawValue>,
    /// Raw success payload of a response. An explicit `null` deserializes to `None` here.
    result: Option<&'a RawValue>,
    /// Error payload of a failed response.
    error: Option<Error>,
}
280
/// A decoded request or notification, handed from the I/O loop to the handler dispatcher.
enum IncomingMessage<Local: Side> {
    /// A request; its handler's result is sent back as a response carrying the same `id`.
    Request { id: i32, request: Local::InRequest },
    /// A notification; no response is sent.
    Notification { notification: Local::InNotification },
}
285
/// A message written to the peer, serialized as a single JSON object.
///
/// `untagged` means no variant name appears on the wire; the fields alone identify the kind of
/// message. NOTE: with `untagged`, deserialization tries variants in declaration order, so
/// their order is significant.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request to the peer; the peer's reply is matched back to it by `id`.
    Request {
        id: i32,
        method: &'static str,
        /// Omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InRequest>,
    },
    /// This side's reply to a request it received.
    ///
    /// `result` is flattened, so the object carries either a `result` or an `error` key next
    /// to `id`.
    Response {
        id: i32,
        #[serde(flatten)]
        result: ResponseResult<Local::OutResponse>,
    },
    /// A one-way message to the peer; it has no `id`.
    Notification {
        method: &'static str,
        /// Omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        params: Option<Remote::InNotification>,
    },
}
306
/// Outcome of a request as it appears on the wire.
///
/// The enum is externally tagged with `snake_case` variant names, so it serializes as
/// `{"result": ...}` or `{"error": ...}`. When flattened into `OutgoingMessage::Response`, this
/// becomes the `result` or `error` key of the response object.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseResult<Res> {
    /// Successful response payload.
    Result(Res),
    /// Error response payload.
    Error(Error),
}
313
314impl<T> From<Result<T, Error>> for ResponseResult<T> {
315 fn from(result: Result<T, Error>) -> Self {
316 match result {
317 Ok(value) => ResponseResult::Result(value),
318 Err(error) => ResponseResult::Error(error),
319 }
320 }
321}
322
323pub trait Side {
324 type InRequest: Serialize + DeserializeOwned + 'static;
325 type OutResponse: Serialize + DeserializeOwned + 'static;
326 type InNotification: Serialize + DeserializeOwned + 'static;
327
328 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest, Error>;
329
330 fn decode_notification(
331 method: &str,
332 params: Option<&RawValue>,
333 ) -> Result<Self::InNotification, Error>;
334}
335
336pub trait MessageHandler<Local: Side> {
337 fn handle_request(
338 &self,
339 request: Local::InRequest,
340 ) -> impl Future<Output = Result<Local::OutResponse, Error>>;
341
342 fn handle_notification(
343 &self,
344 notification: Local::InNotification,
345 ) -> impl Future<Output = Result<(), Error>>;
346}
347
/// Decodes a request's raw params and answers the request asynchronously.
///
/// - If `$params` is `None`, the macro executes `return Err(Error::invalid_params())`, which
///   leaves the *enclosing* function.
/// - Otherwise `params.get()` is parsed into `$request_type` with `serde_json`. A parse failure
///   makes the macro evaluate to `Err(Error::invalid_params())` carrying the parse error text.
/// - On success, `$method(&$base.delegate, arguments)` is called and the macro evaluates to
///   `Ok(())`. A future is handed to `$base.spawn` that awaits the result, maps success
///   through `$response_wrapper`, and queues `OutgoingMessage::Response { id: $id, .. }` on a
///   clone of `$base.outgoing_tx`.
///
/// NOTE(review): `$base` is expanded several times, so it should be a cheap, side-effect-free
/// expression. The expansion also names `serde_json` and `::futures` directly, so the calling
/// crate must depend on both.
#[macro_export]
macro_rules! dispatch_request {
    ($base:expr, $id:expr, $params:expr, $request_type:ty, $method:expr, $response_wrapper:expr) => {{
        let Some(raw_params) = $params else {
            return Err($crate::Error::invalid_params());
        };

        let decoded = serde_json::from_str::<$request_type>(raw_params.get());
        match decoded {
            Ok(parsed) => {
                let response_future = $method(&$base.delegate, parsed);
                let response_tx = $base.outgoing_tx.clone();
                ($base.spawn)(::futures::FutureExt::boxed_local(async move {
                    let response = $crate::rpc::OutgoingMessage::Response {
                        id: $id,
                        result: response_future.await.map($response_wrapper).into(),
                    };
                    response_tx.unbounded_send(response).ok();
                }));

                Ok(())
            }
            Err(parse_error) => {
                Err($crate::Error::invalid_params().with_data(parse_error.to_string()))
            }
        }
    }};
}
381
/// Decodes a notification's raw params and passes them to `$handler`.
///
/// - If `$params` is `None`, the macro executes `return Err(Error::invalid_params())`, which
///   leaves the *enclosing* function.
/// - Otherwise `params.get()` is parsed into `$params_type` with `serde_json`. The macro then
///   evaluates to `$handler(arguments)`, or to `Err(Error::invalid_params())` with the parse
///   error text attached.
///
/// `$method` is accepted but not used by the expansion.
#[macro_export]
macro_rules! dispatch_notification {
    ($method:expr, $params:expr, $params_type:ty, $handler:expr) => {{
        let Some(raw_params) = $params else {
            return Err($crate::Error::invalid_params());
        };

        let decoded = serde_json::from_str::<$params_type>(raw_params.get());
        match decoded {
            Ok(parsed) => $handler(parsed),
            Err(parse_error) => {
                Err($crate::Error::invalid_params().with_data(parse_error.to_string()))
            }
        }
    }};
}