1use super::{
2 prepare_call_payload, RpcError, RpcEvent, RpcEventKind, RpcResult, RPC_ERROR,
3 RPC_ERROR_CODE_METHOD_NOT_FOUND, RPC_NOTIFICATION, RPC_REPLY,
4};
5use crate::borrow::Cow;
6use crate::client::AsyncClient;
7use crate::EventChannel;
8use crate::{Error, Frame, FrameKind, OpConfirm, QoS};
9use async_trait::async_trait;
10use log::{error, trace, warn};
11#[cfg(not(feature = "rt"))]
12use parking_lot::Mutex as SyncMutex;
13#[cfg(feature = "rt")]
14use parking_lot_rt::Mutex as SyncMutex;
15use std::collections::BTreeMap;
16use std::sync::atomic;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::oneshot;
20use tokio::sync::Mutex;
21use tokio::task::JoinHandle;
22use tokio_task_pool::{Pool, Task};
23
24#[derive(Default, Clone, Debug)]
36pub struct Options {
37 blocking_notifications: bool,
38 blocking_frames: bool,
39 task_pool: Option<Arc<Pool>>,
40}
41
42impl Options {
43 #[inline]
44 pub fn new() -> Self {
45 Self::default()
46 }
47 #[inline]
48 pub fn blocking_notifications(mut self) -> Self {
49 self.blocking_notifications = true;
50 self
51 }
52 #[inline]
53 pub fn blocking_frames(mut self) -> Self {
54 self.blocking_frames = true;
55 self
56 }
57 #[inline]
58 pub fn with_task_pool(mut self, pool: Pool) -> Self {
60 self.task_pool = Some(Arc::new(pool));
61 self
62 }
63}
64
65#[allow(clippy::module_name_repetitions)]
66#[async_trait]
67pub trait RpcHandlers {
68 #[allow(unused_variables)]
69 async fn handle_call(&self, event: RpcEvent) -> RpcResult {
70 Err(RpcError::method(None))
71 }
72 #[allow(unused_variables)]
73 async fn handle_notification(&self, event: RpcEvent) {}
74 #[allow(unused_variables)]
75 async fn handle_frame(&self, frame: Frame) {}
76}
77
78pub struct DummyHandlers {}
79
80#[async_trait]
81impl RpcHandlers for DummyHandlers {
82 async fn handle_call(&self, _event: RpcEvent) -> RpcResult {
83 Err(RpcError::new(
84 RPC_ERROR_CODE_METHOD_NOT_FOUND,
85 Some("RPC handler is not implemented".as_bytes().to_vec()),
86 ))
87 }
88}
89
/// Outgoing calls awaiting a reply, keyed by call id.
///
/// `RpcClient::call` inserts the sender; the processor removes the entry and
/// forwards the matching reply or error reply through it.
type CallMap = Arc<SyncMutex<BTreeMap<u32, oneshot::Sender<RpcEvent>>>>;
91
/// RPC operations available on top of a bus client.
#[async_trait]
pub trait Rpc {
    /// Returns a shared handle to the underlying bus client.
    fn client(&self) -> Arc<Mutex<dyn AsyncClient + 'static>>;
    /// Sends an RPC notification carrying `data` to `target`.
    ///
    /// No reply is expected; the returned `OpConfirm` only reflects delivery
    /// as far as the chosen `qos` provides it.
    async fn notify(
        &self,
        target: &str,
        data: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Calls `method` on `target` without waiting for a result.
    ///
    /// In `RpcClient` this uses call id 0, which the remote processor treats as
    /// "do not send a reply".
    async fn call0(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Calls `method` on `target` and waits for the reply.
    ///
    /// An error reply from the remote side is returned as `Err(RpcError)`.
    async fn call(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<RpcEvent, RpcError>;
    /// Reports whether the underlying client is currently connected.
    fn is_connected(&self) -> bool;
}
124
/// RPC layer over an [`AsyncClient`]: sends calls and notifications and runs a
/// background processor that dispatches incoming traffic to [`RpcHandlers`].
#[allow(clippy::module_name_repetitions)]
pub struct RpcClient {
    // Last used call id; 0 is reserved for "no reply expected" (call0).
    call_id: SyncMutex<u32>,
    // Client timeout, used for sending calls and as the pinger interval.
    timeout: Option<Duration>,
    client: Arc<Mutex<dyn AsyncClient>>,
    // Shared with the pinger so it can abort the processor on ping failure.
    processor_fut: Arc<SyncMutex<JoinHandle<()>>>,
    // Present only when the client reports a timeout.
    pinger_fut: Option<JoinHandle<()>>,
    calls: CallMap,
    // Connection flag provided by the client, if it has one.
    connected: Option<Arc<atomic::AtomicBool>>,
}
135
136#[allow(clippy::too_many_lines)]
137async fn processor<C, H>(
138 rx: EventChannel,
139 processor_client: Arc<Mutex<C>>,
140 calls: CallMap,
141 handlers: Arc<H>,
142 opts: Options,
143) where
144 C: AsyncClient + 'static,
145 H: RpcHandlers + Send + Sync + 'static,
146{
147 macro_rules! spawn {
148 ($task_id: expr, $fut: expr) => {
149 if let Some(ref pool) = opts.task_pool {
150 let task = Task::new($fut).with_id($task_id);
151 if let Err(e) = pool.spawn_task(task).await {
152 error!("Unable to spawn RPC task: {}", e);
153 }
154 } else {
155 tokio::spawn($fut);
156 }
157 };
158 }
159 while let Ok(frame) = rx.recv().await {
160 if frame.kind() == FrameKind::Message {
161 match RpcEvent::try_from(frame) {
162 Ok(event) => match event.kind() {
163 RpcEventKind::Notification => {
164 trace!("RPC notification from {}", event.frame().sender());
165 if opts.blocking_notifications {
166 handlers.handle_notification(event).await;
167 } else {
168 let h = handlers.clone();
169 spawn!("rpc.notification", async move {
170 h.handle_notification(event).await;
171 });
172 }
173 }
174 RpcEventKind::Request => {
175 let id = event.id();
176 trace!(
177 "RPC request from {}, id: {}, method: {:?}",
178 event.frame().sender(),
179 id,
180 event.method()
181 );
182 let ev = if id > 0 {
183 Some((event.frame().sender().to_owned(), processor_client.clone()))
184 } else {
185 None
186 };
187 let h = handlers.clone();
188 spawn!("rpc.request", async move {
189 let qos = if event.frame().is_realtime() {
190 QoS::RealtimeProcessed
191 } else {
192 QoS::Processed
193 };
194 let res = h.handle_call(event).await;
195 if let Some((target, cl)) = ev {
196 macro_rules! send_reply {
197 ($payload: expr, $result: expr) => {{
198 let mut client = cl.lock().await;
199 if let Some(result) = $result {
200 client
201 .zc_send(&target, $payload, result.into(), qos)
202 .await
203 } else {
204 client
205 .zc_send(&target, $payload, (&[][..]).into(), qos)
206 .await
207 }
208 }};
209 }
210 match res {
211 Ok(v) => {
212 trace!("Sending RPC reply id {} to {}", id, target);
213 let mut payload = Vec::with_capacity(5);
214 payload.push(RPC_REPLY);
215 payload.extend_from_slice(&id.to_le_bytes());
216 let _r = send_reply!(payload.into(), v);
217 }
218 Err(e) => {
219 trace!(
220 "Sending RPC error {} reply id {} to {}",
221 e.code,
222 id,
223 target,
224 );
225 let mut payload = Vec::with_capacity(7);
226 payload.push(RPC_ERROR);
227 payload.extend_from_slice(&id.to_le_bytes());
228 payload.extend_from_slice(&e.code.to_le_bytes());
229 let _r = send_reply!(payload.into(), e.data);
230 }
231 }
232 }
233 });
234 }
235 RpcEventKind::Reply | RpcEventKind::ErrorReply => {
236 let id = event.id();
237 trace!(
238 "RPC {} from {}, id: {}",
239 event.kind(),
240 event.frame().sender(),
241 id
242 );
243 if let Some(tx) = { calls.lock().remove(&id) } {
244 let _r = tx.send(event);
245 } else {
246 warn!("orphaned RPC response: {}", id);
247 }
248 }
249 },
250 Err(e) => {
251 error!("{}", e);
252 }
253 }
254 } else if opts.blocking_frames {
255 handlers.handle_frame(frame).await;
256 } else {
257 let h = handlers.clone();
258 spawn!("rpc.frame", async move {
259 h.handle_frame(frame).await;
260 });
261 }
262 }
263}
264
265impl RpcClient {
266 pub fn new<H>(client: impl AsyncClient + 'static, handlers: H) -> Self
268 where
269 H: RpcHandlers + Send + Sync + 'static,
270 {
271 Self::init(client, handlers, Options::default())
272 }
273
274 pub fn new0(client: impl AsyncClient + 'static) -> Self {
276 Self::init(client, DummyHandlers {}, Options::default())
277 }
278
279 pub fn create<H>(client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
281 where
282 H: RpcHandlers + Send + Sync + 'static,
283 {
284 Self::init(client, handlers, opts)
285 }
286
287 pub fn create0(client: impl AsyncClient + 'static, opts: Options) -> Self {
289 Self::init(client, DummyHandlers {}, opts)
290 }
291
292 fn init<H>(mut client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
293 where
294 H: RpcHandlers + Send + Sync + 'static,
295 {
296 let timeout = client.get_timeout();
297 let rx = { client.take_event_channel().unwrap() };
298 let connected = client.get_connected_beacon();
299 let client = Arc::new(Mutex::new(client));
300 let calls: CallMap = <_>::default();
301 let processor_fut = Arc::new(SyncMutex::new(tokio::spawn(processor(
302 rx,
303 client.clone(),
304 calls.clone(),
305 Arc::new(handlers),
306 opts,
307 ))));
308 let pinger_client = client.clone();
309 let pfut = processor_fut.clone();
310 let pinger_fut = timeout.map(|t| {
311 tokio::spawn(async move {
312 loop {
313 if let Err(e) = pinger_client.lock().await.ping().await {
314 error!("{}", e);
315 pfut.lock().abort();
316 break;
317 }
318 tokio::time::sleep(t).await;
319 }
320 })
321 });
322 Self {
323 call_id: SyncMutex::new(0),
324 timeout,
325 client,
326 processor_fut,
327 pinger_fut,
328 calls,
329 connected,
330 }
331 }
332}
333
334#[async_trait]
335impl Rpc for RpcClient {
336 #[inline]
337 fn client(&self) -> Arc<Mutex<dyn AsyncClient + 'static>> {
338 self.client.clone()
339 }
340 #[inline]
341 async fn notify(
342 &self,
343 target: &str,
344 data: Cow<'async_trait>,
345 qos: QoS,
346 ) -> Result<OpConfirm, Error> {
347 self.client
348 .lock()
349 .await
350 .zc_send(target, (&[RPC_NOTIFICATION][..]).into(), data, qos)
351 .await
352 }
353 async fn call0(
354 &self,
355 target: &str,
356 method: &str,
357 params: Cow<'async_trait>,
358 qos: QoS,
359 ) -> Result<OpConfirm, Error> {
360 let payload = prepare_call_payload(method, &[0, 0, 0, 0]);
361 self.client
362 .lock()
363 .await
364 .zc_send(target, payload.into(), params, qos)
365 .await
366 }
367 async fn call(
371 &self,
372 target: &str,
373 method: &str,
374 params: Cow<'async_trait>,
375 qos: QoS,
376 ) -> Result<RpcEvent, RpcError> {
377 let call_id = {
378 let mut ci = self.call_id.lock();
379 let mut call_id = *ci;
380 if call_id == u32::MAX {
381 call_id = 1;
382 } else {
383 call_id += 1;
384 }
385 *ci = call_id;
386 call_id
387 };
388 let payload = prepare_call_payload(method, &call_id.to_le_bytes());
389 let (tx, rx) = oneshot::channel();
390 self.calls.lock().insert(call_id, tx);
391 macro_rules! unwrap_or_cancel {
392 ($result: expr) => {
393 match $result {
394 Ok(v) => v,
395 Err(e) => {
396 self.calls.lock().remove(&call_id);
397 return Err(Into::<Error>::into(e).into());
398 }
399 }
400 };
401 }
402 let opc = {
403 let mut client = self.client.lock().await;
404 let fut = client.zc_send(target, payload.into(), params, qos);
405 if let Some(timeout) = self.timeout {
406 unwrap_or_cancel!(unwrap_or_cancel!(tokio::time::timeout(timeout, fut).await))
407 } else {
408 unwrap_or_cancel!(fut.await)
409 }
410 };
411 if let Some(c) = opc {
412 unwrap_or_cancel!(unwrap_or_cancel!(c.await));
413 }
414 let result = rx.await.map_err(Into::<Error>::into)?;
415 if let Ok(e) = RpcError::try_from(&result) {
416 Err(e)
417 } else {
418 Ok(result)
419 }
420 }
421 fn is_connected(&self) -> bool {
422 self.connected
423 .as_ref()
424 .is_none_or(|b| b.load(atomic::Ordering::Relaxed))
425 }
426}
427
428impl Drop for RpcClient {
429 fn drop(&mut self) {
430 self.pinger_fut.as_ref().map(JoinHandle::abort);
431 self.processor_fut.lock().abort();
432 }
433}