// jsonrpc_reactor/reactor.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time;
4
5use tokio::sync::{self, mpsc, oneshot};
6
7pub use serde_json::Value;
8
9use crate::{Id, Notification, Params, Request, Response, RpcError};
10
11#[derive(Debug)]
12struct PendingRequest {
13    sender: oneshot::Sender<Result<Value, RpcError>>,
14    moment: time::Instant,
15    timeout: Option<time::Duration>,
16}
17
18#[derive(Debug)]
19pub struct Reactor {
20    capacity: usize,
21    request_id: i64,
22    requests: mpsc::Sender<Request>,
23    notifications: mpsc::Sender<Notification>,
24    pending: Arc<sync::RwLock<HashMap<Id, PendingRequest>>>,
25}
26
27impl Reactor {
28    pub fn spawn(
29        capacity: usize,
30        requests: mpsc::Sender<Request>,
31        notifications: mpsc::Sender<Notification>,
32    ) -> (Self, mpsc::Sender<Response>) {
33        let request_id = 0;
34        let (responses_tx, mut responses) = mpsc::channel(capacity);
35
36        let pending = HashMap::with_capacity(capacity);
37        let pending = sync::RwLock::new(pending);
38        let pending = Arc::new(pending);
39        let pending_thr = Arc::clone(&pending);
40
41        tokio::spawn(async move {
42            while let Some(Response { id, result }) = responses.recv().await {
43                let mut pending = pending_thr.write().await;
44
45                if let Some(PendingRequest { sender, .. }) = pending.remove(&id) {
46                    sender.send(result).ok();
47                }
48            }
49        });
50
51        let slf = Self {
52            capacity,
53            request_id,
54            requests,
55            notifications,
56            pending,
57        };
58
59        (slf, responses_tx)
60    }
61
62    pub async fn notify<M, P>(
63        &mut self,
64        method: M,
65        params: P,
66        timeout: Option<time::Duration>,
67    ) -> bool
68    where
69        M: AsRef<str>,
70        P: Into<Params>,
71    {
72        let method = method.as_ref().to_string();
73        let params = params.into();
74        let notification = Notification { method, params };
75
76        match timeout {
77            Some(t) => self
78                .notifications
79                .send_timeout(notification, t)
80                .await
81                .is_ok(),
82
83            None => self.notifications.send(notification).await.is_ok(),
84        }
85    }
86
87    pub async fn request<M, P>(
88        &mut self,
89        method: M,
90        params: P,
91        timeout: Option<time::Duration>,
92    ) -> Option<oneshot::Receiver<Result<Value, RpcError>>>
93    where
94        M: AsRef<str>,
95        P: Into<Params>,
96    {
97        let id = self.request_id;
98
99        self.request_id = id.wrapping_add(1);
100
101        self.request_with_id(Id::Number(id), method, params, timeout)
102            .await
103    }
104
105    pub async fn request_with_id<M, P>(
106        &mut self,
107        id: Id,
108        method: M,
109        params: P,
110        timeout: Option<time::Duration>,
111    ) -> Option<oneshot::Receiver<Result<Value, RpcError>>>
112    where
113        M: AsRef<str>,
114        P: Into<Params>,
115    {
116        let method = method.as_ref().to_string();
117        let params = params.into();
118        let request = Request {
119            id: id.clone(),
120            method,
121            params,
122        };
123
124        let sent = match &timeout {
125            Some(t) => self.requests.send_timeout(request, *t).await.is_ok(),
126            None => self.requests.send(request).await.is_ok(),
127        };
128
129        if !sent {
130            return None;
131        }
132
133        let (sender, receiver) = oneshot::channel();
134        let pending = PendingRequest {
135            sender,
136            moment: time::Instant::now(),
137            timeout,
138        };
139
140        let mut queue = self.pending.write().await;
141
142        queue.insert(id, pending);
143
144        // attempt to clean expired pending responses
145        if self.capacity < queue.len() {
146            let now = time::Instant::now();
147
148            let expired = queue
149                .iter()
150                .filter_map(|(id, pending)| {
151                    pending.timeout.and_then(|t| {
152                        let diff = now.duration_since(pending.moment);
153
154                        (t < diff).then_some(id)
155                    })
156                })
157                .cloned()
158                .collect::<Vec<_>>();
159
160            for id in expired {
161                if let Some(pending) = queue.remove(&id) {
162                    let response = Err(RpcError {
163                        code: -1,
164                        message: String::from("response timeout"),
165                        data: Value::Null,
166                    });
167
168                    pending.sender.send(response).ok();
169                }
170            }
171        }
172
173        Some(receiver)
174    }
175}