1use std::{
2 collections::{HashMap, VecDeque},
3 io,
4 sync::{
5 atomic::{AtomicBool, AtomicUsize, Ordering},
6 Arc,
7 },
8};
9
10use futures::{lock::Mutex, Future, Sink, SinkExt, Stream, StreamExt};
11
12use futures_map::KeyWaitMap;
13
14use crate::{Request, Response, Version};
15
16pub trait JsonRpcClientSender<E>: Sink<Vec<u8>, Error = E> + Unpin
17where
18 E: ToString,
19{
20 fn send_request<S, P, R, D>(
21 &mut self,
22 request: Request<S, P>,
23 ) -> impl Future<Output = io::Result<()>>
24 where
25 S: AsRef<str> + serde::Serialize,
26 P: serde::Serialize,
27 {
28 async move {
29 let data = serde_json::to_vec(&request)?;
30
31 self.send(data)
32 .await
33 .map_err(|err| io::Error::new(io::ErrorKind::BrokenPipe, err.to_string()))?;
34
35 Ok(())
36 }
37 }
38}
39
40impl<T, E> JsonRpcClientSender<E> for T
41where
42 T: Sink<Vec<u8>, Error = E> + Unpin,
43 E: ToString,
44{
45}
46
47pub trait JsonRpcClientReceiver: Stream<Item = Vec<u8>> + Unpin {
48 fn next_response<R, D>(&mut self) -> impl Future<Output = io::Result<Response<String, R, D>>>
49 where
50 for<'a> R: serde::Deserialize<'a>,
51 for<'a> D: serde::Deserialize<'a>,
52 {
53 async move {
54 let buf = self.next().await.ok_or(io::Error::new(
55 io::ErrorKind::BrokenPipe,
56 "JSONRPC client receive stream broken",
57 ))?;
58
59 Ok(serde_json::from_slice(&buf)?)
60 }
61 }
62}
63
64impl<T> JsonRpcClientReceiver for T where T: Stream<Item = Vec<u8>> + Unpin {}
65
66type InnerResponse = Response<String, serde_json::Value, serde_json::Value>;
67
68#[derive(Default)]
69struct RawJsonRpcClient {
70 max_send_queue_size: usize,
71 send_queue: VecDeque<(usize, Vec<u8>)>,
72 received_resps: HashMap<usize, InnerResponse>,
73}
74
75impl RawJsonRpcClient {
76 fn new(max_send_queue_size: usize) -> Self {
77 Self {
78 max_send_queue_size,
79 ..Default::default()
80 }
81 }
82
83 fn cache_send(&mut self, id: usize, data: Vec<u8>) -> Option<(usize, Vec<u8>)> {
84 if self.send_queue.len() == self.max_send_queue_size {
85 return Some((id, data));
86 }
87
88 self.send_queue.push_back((id, data));
89
90 None
91 }
92
93 fn send_one(&mut self) -> Option<(usize, Vec<u8>)> {
94 self.send_queue.pop_front()
95 }
96}
97
/// Wake-up keys used with the client's `KeyWaitMap`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum JsonRpcClientEvent {
    /// Signalled by the forwarding side when it finds the send queue empty;
    /// callers blocked on a full queue wait for it.
    Send,
    /// Signalled by a caller after it queues a packet; the forwarding side waits for it.
    Forward,
    /// Signalled by `recv` once the response for this request id has been stored.
    Response(usize),
}
104
105struct RawJsonRpcClientState {
106 is_closed: AtomicBool,
107 next_id: AtomicUsize,
108 raw: Mutex<RawJsonRpcClient>,
109 wait_map: KeyWaitMap<JsonRpcClientEvent, ()>,
110}
111
112#[derive(Clone)]
114pub struct JsonRpcClientState(Arc<RawJsonRpcClientState>);
115
116impl JsonRpcClientState {
117 pub fn new(max_send_queue_size: usize) -> Self {
119 Self(Arc::new(RawJsonRpcClientState {
120 is_closed: Default::default(),
121 next_id: Default::default(),
122 raw: Mutex::new(RawJsonRpcClient::new(max_send_queue_size)),
123 wait_map: KeyWaitMap::new(),
124 }))
125 }
126
127 pub async fn call<M, P, R>(&self, method: M, params: P) -> std::io::Result<R>
129 where
130 M: AsRef<str>,
131 P: serde::Serialize,
132 for<'a> R: serde::Deserialize<'a>,
133 {
134 let id = self.0.next_id.fetch_add(1, Ordering::Relaxed);
135
136 let request = Request {
137 id: Some(id),
138 jsonrpc: Version::default(),
139 method: method.as_ref(),
140 params,
141 };
142
143 let packet = serde_json::to_vec(&request)?;
144
145 let mut send_data = Some((id, packet));
146
147 while let Some((id, data)) = send_data {
148 if self.is_closed() {
149 return Err(std::io::Error::new(
150 std::io::ErrorKind::BrokenPipe,
151 "JsonRpcClient is closed",
152 ));
153 }
154
155 let mut raw = self.0.raw.lock().await;
156
157 send_data = raw.cache_send(id, data);
158
159 if send_data.is_some() {
160 self.0.wait_map.wait(&JsonRpcClientEvent::Send, raw).await;
161 } else {
162 self.0.wait_map.insert(JsonRpcClientEvent::Forward, ());
163 }
164 }
165
166 if let Some(_) = self
167 .0
168 .wait_map
169 .wait(&JsonRpcClientEvent::Response(id), ())
170 .await
171 {
172 if self.is_closed() {
173 return Err(std::io::Error::new(
174 std::io::ErrorKind::BrokenPipe,
175 "JsonRpcClient is closed",
176 ));
177 }
178
179 let mut raw = self.0.raw.lock().await;
180
181 let resp = raw
182 .received_resps
183 .remove(&id)
184 .expect("consistency guarantee");
185
186 if let Some(err) = resp.error {
187 return Err(io::Error::new(io::ErrorKind::Other, err));
188 }
189
190 Ok(serde_json::from_value(serde_json::to_value(resp.result)?)?)
191 } else {
192 Err(io::Error::new(io::ErrorKind::Other, "jsonrpc canceled."))
193 }
194 }
195
196 pub async fn send(&self) -> std::io::Result<(usize, Vec<u8>)> {
198 loop {
199 let mut raw = self.0.raw.lock().await;
200
201 if self.is_closed() {
202 return Err(std::io::Error::new(
203 std::io::ErrorKind::BrokenPipe,
204 "JsonRpcClient is closed",
205 ));
206 }
207
208 if let Some(packet) = raw.send_one() {
209 return Ok(packet);
210 }
211
212 self.0.wait_map.insert(JsonRpcClientEvent::Send, ());
213
214 self.0
215 .wait_map
216 .wait(&JsonRpcClientEvent::Forward, raw)
217 .await;
218 }
219 }
220
221 pub async fn recv<V: AsRef<[u8]>>(&self, packet: V) -> std::io::Result<()> {
223 if self.is_closed() {
224 return Err(std::io::Error::new(
225 std::io::ErrorKind::BrokenPipe,
226 "JsonRpcClient is closed",
227 ));
228 }
229
230 let resp: Response<String, serde_json::Value, serde_json::Value> =
231 serde_json::from_slice(packet.as_ref())?;
232
233 let mut raw = self.0.raw.lock().await;
234
235 let id = resp.id;
236
237 raw.received_resps.insert(resp.id, resp);
238
239 self.0.wait_map.insert(JsonRpcClientEvent::Response(id), ());
240
241 Ok(())
242 }
243
244 pub fn close(&self) {
246 self.0.is_closed.store(true, Ordering::SeqCst);
247 }
248
249 pub fn is_closed(&self) -> bool {
251 self.0.is_closed.load(Ordering::SeqCst)
252 }
253}
254
255pub struct JsonRpcClient(JsonRpcClientState);
257
258impl Drop for JsonRpcClient {
259 fn drop(&mut self) {
260 self.0.close();
261 }
262}
263
264impl Default for JsonRpcClient {
265 fn default() -> Self {
266 Self::new(128)
267 }
268}
269
270impl JsonRpcClient {
271 pub fn new(max_send_queue_size: usize) -> Self {
273 Self(JsonRpcClientState::new(max_send_queue_size))
274 }
275 pub async fn call<M, P, R>(&self, method: M, params: P) -> std::io::Result<R>
277 where
278 M: AsRef<str>,
279 P: serde::Serialize,
280 for<'a> R: serde::Deserialize<'a>,
281 {
282 self.0.call(method, params).await
283 }
284
285 pub fn to_state(&self) -> JsonRpcClientState {
287 self.0.clone()
288 }
289}
290
#[cfg(test)]
mod tests {

    use std::task::Poll;

    use futures::poll;
    use serde_json::json;

    use crate::{Error, ErrorCode};

    use super::*;

    /// Dropping the owning `JsonRpcClient` must close state handles obtained from it.
    #[futures_test::test]
    async fn test_client_drop() {
        let client = JsonRpcClient::default();

        let state = client.to_state();

        drop(client);

        assert!(state.is_closed());
    }

    /// Drives `call` by hand with `poll!` through three round trips: a response with no
    /// `result` decoded as `()`, a `result` decoded as `i32`, and an `error` response
    /// surfaced as `Err`.
    #[futures_test::test]
    async fn test_empty_return() {
        let client = JsonRpcClient::default();

        // Shadowing does not drop the `JsonRpcClient`; it stays alive (and the state stays
        // open) until the end of the test.
        let client = client.to_state();

        let call_client = client.clone();

        let mut call = Box::pin(call_client.call("echo", ("hello", 1)));

        // First poll queues the request and parks on the response.
        let poll_result: Poll<io::Result<()>> = poll!(&mut call);

        assert!(poll_result.is_pending());

        // Play the forwarding side: pull the queued packet (ids start at 0).
        let (_, buf) = client.send().await.unwrap();

        let json = json!({"id":0,"jsonrpc":"2.0","method":"echo","params":["hello",1]}).to_string();

        assert_eq!(json.as_bytes(), buf);

        // A response with neither `result` nor `error` completes a `()` call.
        client
            .recv(
                json!({
                    "id":0,"jsonrpc":"2.0"
                })
                .to_string(),
            )
            .await
            .unwrap();

        let poll_result: Poll<io::Result<()>> = poll!(&mut call);

        assert!(matches!(poll_result, Poll::Ready(Ok(()))));

        let call_client = client.clone();

        // The second request is never pulled with `send()`; the response alone completes it.
        let mut call = Box::pin(call_client.call("echo", ("hello", 1)));

        let poll_result: Poll<io::Result<i32>> = poll!(&mut call);

        assert!(poll_result.is_pending());

        client
            .recv(
                json!({
                    "id":1,"jsonrpc":"2.0","result":1
                })
                .to_string(),
            )
            .await
            .unwrap();

        let poll_result = poll!(&mut call);

        assert!(matches!(poll_result, Poll::Ready(Ok(1))));

        let call_client = client.clone();

        let mut call = Box::pin(call_client.call("echo", ("hello", 1)));

        let poll_result: Poll<io::Result<i32>> = poll!(&mut call);

        assert!(poll_result.is_pending());

        // An `error` member in the response must turn into `Err` for the caller.
        client
            .recv(
                json!({
                    "id":2,"jsonrpc":"2.0","error": Error {
                        code: ErrorCode::InternalError,
                        message: "",
                        data: None::<()>
                    }
                })
                .to_string(),
            )
            .await
            .unwrap();

        let poll_result = poll!(&mut call);

        assert!(matches!(poll_result, Poll::Ready(Err(_))));
    }
}