jsonrpc_reactor/
reactor.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time;
4
5use tokio::sync::{self, mpsc, oneshot};
6
7pub use serde_json::Value;
8
9use crate::{Id, Notification, Params, Request, Response, RpcError};
10
11#[derive(Debug)]
12struct PendingRequest {
13 sender: oneshot::Sender<Result<Value, RpcError>>,
14 moment: time::Instant,
15 timeout: Option<time::Duration>,
16}
17
18#[derive(Debug)]
19pub struct Reactor {
20 capacity: usize,
21 request_id: i64,
22 requests: mpsc::Sender<Request>,
23 notifications: mpsc::Sender<Notification>,
24 pending: Arc<sync::RwLock<HashMap<Id, PendingRequest>>>,
25}
26
27impl Reactor {
28 pub fn spawn(
29 capacity: usize,
30 requests: mpsc::Sender<Request>,
31 notifications: mpsc::Sender<Notification>,
32 ) -> (Self, mpsc::Sender<Response>) {
33 let request_id = 0;
34 let (responses_tx, mut responses) = mpsc::channel(capacity);
35
36 let pending = HashMap::with_capacity(capacity);
37 let pending = sync::RwLock::new(pending);
38 let pending = Arc::new(pending);
39 let pending_thr = Arc::clone(&pending);
40
41 tokio::spawn(async move {
42 while let Some(Response { id, result }) = responses.recv().await {
43 let mut pending = pending_thr.write().await;
44
45 if let Some(PendingRequest { sender, .. }) = pending.remove(&id) {
46 sender.send(result).ok();
47 }
48 }
49 });
50
51 let slf = Self {
52 capacity,
53 request_id,
54 requests,
55 notifications,
56 pending,
57 };
58
59 (slf, responses_tx)
60 }
61
62 pub async fn notify<M, P>(
63 &mut self,
64 method: M,
65 params: P,
66 timeout: Option<time::Duration>,
67 ) -> bool
68 where
69 M: AsRef<str>,
70 P: Into<Params>,
71 {
72 let method = method.as_ref().to_string();
73 let params = params.into();
74 let notification = Notification { method, params };
75
76 match timeout {
77 Some(t) => self
78 .notifications
79 .send_timeout(notification, t)
80 .await
81 .is_ok(),
82
83 None => self.notifications.send(notification).await.is_ok(),
84 }
85 }
86
87 pub async fn request<M, P>(
88 &mut self,
89 method: M,
90 params: P,
91 timeout: Option<time::Duration>,
92 ) -> Option<oneshot::Receiver<Result<Value, RpcError>>>
93 where
94 M: AsRef<str>,
95 P: Into<Params>,
96 {
97 let id = self.request_id;
98
99 self.request_id = id.wrapping_add(1);
100
101 self.request_with_id(Id::Number(id), method, params, timeout)
102 .await
103 }
104
105 pub async fn request_with_id<M, P>(
106 &mut self,
107 id: Id,
108 method: M,
109 params: P,
110 timeout: Option<time::Duration>,
111 ) -> Option<oneshot::Receiver<Result<Value, RpcError>>>
112 where
113 M: AsRef<str>,
114 P: Into<Params>,
115 {
116 let method = method.as_ref().to_string();
117 let params = params.into();
118 let request = Request {
119 id: id.clone(),
120 method,
121 params,
122 };
123
124 let sent = match &timeout {
125 Some(t) => self.requests.send_timeout(request, *t).await.is_ok(),
126 None => self.requests.send(request).await.is_ok(),
127 };
128
129 if !sent {
130 return None;
131 }
132
133 let (sender, receiver) = oneshot::channel();
134 let pending = PendingRequest {
135 sender,
136 moment: time::Instant::now(),
137 timeout,
138 };
139
140 let mut queue = self.pending.write().await;
141
142 queue.insert(id, pending);
143
144 if self.capacity < queue.len() {
146 let now = time::Instant::now();
147
148 let expired = queue
149 .iter()
150 .filter_map(|(id, pending)| {
151 pending.timeout.and_then(|t| {
152 let diff = now.duration_since(pending.moment);
153
154 (t < diff).then_some(id)
155 })
156 })
157 .cloned()
158 .collect::<Vec<_>>();
159
160 for id in expired {
161 if let Some(pending) = queue.remove(&id) {
162 let response = Err(RpcError {
163 code: -1,
164 message: String::from("response timeout"),
165 data: Value::Null,
166 });
167
168 pending.sender.send(response).ok();
169 }
170 }
171 }
172
173 Some(receiver)
174 }
175}