busrt/rpc/
async_client.rs

1use super::{
2    prepare_call_payload, RpcError, RpcEvent, RpcEventKind, RpcResult, RPC_ERROR,
3    RPC_ERROR_CODE_METHOD_NOT_FOUND, RPC_NOTIFICATION, RPC_REPLY,
4};
5use crate::borrow::Cow;
6use crate::client::AsyncClient;
7use crate::EventChannel;
8use crate::{Error, Frame, FrameKind, OpConfirm, QoS};
9use async_trait::async_trait;
10use log::{error, trace, warn};
11#[cfg(not(feature = "rt"))]
12use parking_lot::Mutex as SyncMutex;
13#[cfg(feature = "rt")]
14use parking_lot_rt::Mutex as SyncMutex;
15use std::collections::BTreeMap;
16use std::sync::atomic;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::oneshot;
20use tokio::sync::Mutex;
21use tokio::task::JoinHandle;
22use tokio_task_pool::{Pool, Task};
23
/// Options for how [`RpcClient`] processes incoming events.
///
/// By default, RPC frame and notification handlers are launched in background, which allows
/// non-blocking event processing, however events can be processed in random order. RPC call
/// (request) handlers are always launched in background, regardless of these options.
///
/// These options allow launching frame and/or notification handlers in blocking mode. In this
/// case handlers must process events as fast as possible (e.g. send them to processing
/// channels) and avoid using any RPC client functions from inside.
///
/// WARNING: when handling frames in blocking mode, it is forbidden to use the current RPC client
/// directly or with any kind of bounded channels, otherwise the RPC client may get stuck!
///
/// See <https://busrt.readthedocs.io/en/latest/rpc_blocking.html>
#[derive(Default, Clone, Debug)]
pub struct Options {
    // await notification handlers in place instead of spawning them
    blocking_notifications: bool,
    // await handlers of non-message frames in place instead of spawning them
    blocking_frames: bool,
    // when set, background handlers are spawned in this pool instead of `tokio::spawn`
    task_pool: Option<Arc<Pool>>,
}
41
42impl Options {
43    #[inline]
44    pub fn new() -> Self {
45        Self::default()
46    }
47    #[inline]
48    pub fn blocking_notifications(mut self) -> Self {
49        self.blocking_notifications = true;
50        self
51    }
52    #[inline]
53    pub fn blocking_frames(mut self) -> Self {
54        self.blocking_frames = true;
55        self
56    }
57    #[inline]
58    /// See <https://crates.io/crates/tokio-task-pool>
59    pub fn with_task_pool(mut self, pool: Pool) -> Self {
60        self.task_pool = Some(Arc::new(pool));
61        self
62    }
63}
64
/// Callbacks invoked by [`RpcClient`] for incoming traffic.
///
/// Every method has a default implementation, so implementors override only what they need.
#[allow(clippy::module_name_repetitions)]
#[async_trait]
pub trait RpcHandlers {
    /// Handles an incoming RPC call (request) and returns its result.
    ///
    /// The result is sent back to the caller only when the request carries a non-zero id.
    /// The default implementation returns `RpcError::method(None)`
    /// (NOTE(review): presumably a "method not found" error — confirm in `RpcError::method`).
    #[allow(unused_variables)]
    async fn handle_call(&self, event: RpcEvent) -> RpcResult {
        Err(RpcError::method(None))
    }
    /// Handles an incoming RPC notification. The default implementation ignores it.
    #[allow(unused_variables)]
    async fn handle_notification(&self, event: RpcEvent) {}
    /// Handles an incoming frame which is not a message (`FrameKind::Message` frames are
    /// parsed as RPC events instead). The default implementation ignores it.
    #[allow(unused_variables)]
    async fn handle_frame(&self, frame: Frame) {}
}
77
78pub struct DummyHandlers {}
79
80#[async_trait]
81impl RpcHandlers for DummyHandlers {
82    async fn handle_call(&self, _event: RpcEvent) -> RpcResult {
83        Err(RpcError::new(
84            RPC_ERROR_CODE_METHOD_NOT_FOUND,
85            Some("RPC handler is not implemented".as_bytes().to_vec()),
86        ))
87    }
88}
89
/// Pending outgoing calls: call id -> sender which completes the awaiting `call()` when the
/// matching reply (or error reply) is received by the frames processor.
type CallMap = Arc<SyncMutex<BTreeMap<u32, oneshot::Sender<RpcEvent>>>>;
91
/// Client-side RPC operations.
#[async_trait]
pub trait Rpc {
    /// When created, the busrt client is wrapped with `Arc<Mutex<_>>` to let it be sent into
    /// the incoming frames handler future.
    ///
    /// This method allows getting the contained client back, to call its methods directly
    /// (manage pub/sub and send broadcast messages).
    fn client(&self) -> Arc<Mutex<dyn AsyncClient + 'static>>;
    /// Sends an RPC notification with the given payload to the target
    async fn notify(
        &self,
        target: &str,
        data: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Calls the method, no response is required
    async fn call0(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Calls the method and gets the response
    async fn call(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<RpcEvent, RpcError>;
    /// Returns the connection state of the underlying client
    fn is_connected(&self) -> bool;
}
124
/// RPC client built on top of a busrt [`AsyncClient`].
///
/// Incoming frames are dispatched to [`RpcHandlers`] by a background processor task; when the
/// underlying client has a timeout set, a background pinger task is run as well.
#[allow(clippy::module_name_repetitions)]
pub struct RpcClient {
    // last issued call id; `call` wraps it from u32::MAX back to 1, so 0 (used by `call0`)
    // is never issued
    call_id: SyncMutex<u32>,
    // timeout of the underlying client: limits sending calls and is the pinger interval
    timeout: Option<Duration>,
    // underlying client, shared with the processor and the pinger tasks
    client: Arc<Mutex<dyn AsyncClient>>,
    // incoming frames processor task, shared with the pinger so it can abort it
    processor_fut: Arc<SyncMutex<JoinHandle<()>>>,
    // pinger task, started only when `timeout` is set
    pinger_fut: Option<JoinHandle<()>>,
    // calls which are waiting for replies
    calls: CallMap,
    // connection beacon of the underlying client, if it provides one
    connected: Option<Arc<atomic::AtomicBool>>,
}
135
136#[allow(clippy::too_many_lines)]
137async fn processor<C, H>(
138    rx: EventChannel,
139    processor_client: Arc<Mutex<C>>,
140    calls: CallMap,
141    handlers: Arc<H>,
142    opts: Options,
143) where
144    C: AsyncClient + 'static,
145    H: RpcHandlers + Send + Sync + 'static,
146{
147    macro_rules! spawn {
148        ($task_id: expr, $fut: expr) => {
149            if let Some(ref pool) = opts.task_pool {
150                let task = Task::new($fut).with_id($task_id);
151                if let Err(e) = pool.spawn_task(task).await {
152                    error!("Unable to spawn RPC task: {}", e);
153                }
154            } else {
155                tokio::spawn($fut);
156            }
157        };
158    }
159    while let Ok(frame) = rx.recv().await {
160        if frame.kind() == FrameKind::Message {
161            match RpcEvent::try_from(frame) {
162                Ok(event) => match event.kind() {
163                    RpcEventKind::Notification => {
164                        trace!("RPC notification from {}", event.frame().sender());
165                        if opts.blocking_notifications {
166                            handlers.handle_notification(event).await;
167                        } else {
168                            let h = handlers.clone();
169                            spawn!("rpc.notification", async move {
170                                h.handle_notification(event).await;
171                            });
172                        }
173                    }
174                    RpcEventKind::Request => {
175                        let id = event.id();
176                        trace!(
177                            "RPC request from {}, id: {}, method: {:?}",
178                            event.frame().sender(),
179                            id,
180                            event.method()
181                        );
182                        let ev = if id > 0 {
183                            Some((event.frame().sender().to_owned(), processor_client.clone()))
184                        } else {
185                            None
186                        };
187                        let h = handlers.clone();
188                        spawn!("rpc.request", async move {
189                            let qos = if event.frame().is_realtime() {
190                                QoS::RealtimeProcessed
191                            } else {
192                                QoS::Processed
193                            };
194                            let res = h.handle_call(event).await;
195                            if let Some((target, cl)) = ev {
196                                macro_rules! send_reply {
197                                    ($payload: expr, $result: expr) => {{
198                                        let mut client = cl.lock().await;
199                                        if let Some(result) = $result {
200                                            client
201                                                .zc_send(&target, $payload, result.into(), qos)
202                                                .await
203                                        } else {
204                                            client
205                                                .zc_send(&target, $payload, (&[][..]).into(), qos)
206                                                .await
207                                        }
208                                    }};
209                                }
210                                match res {
211                                    Ok(v) => {
212                                        trace!("Sending RPC reply id {} to {}", id, target);
213                                        let mut payload = Vec::with_capacity(5);
214                                        payload.push(RPC_REPLY);
215                                        payload.extend_from_slice(&id.to_le_bytes());
216                                        let _r = send_reply!(payload.into(), v);
217                                    }
218                                    Err(e) => {
219                                        trace!(
220                                            "Sending RPC error {} reply id {} to {}",
221                                            e.code,
222                                            id,
223                                            target,
224                                        );
225                                        let mut payload = Vec::with_capacity(7);
226                                        payload.push(RPC_ERROR);
227                                        payload.extend_from_slice(&id.to_le_bytes());
228                                        payload.extend_from_slice(&e.code.to_le_bytes());
229                                        let _r = send_reply!(payload.into(), e.data);
230                                    }
231                                }
232                            }
233                        });
234                    }
235                    RpcEventKind::Reply | RpcEventKind::ErrorReply => {
236                        let id = event.id();
237                        trace!(
238                            "RPC {} from {}, id: {}",
239                            event.kind(),
240                            event.frame().sender(),
241                            id
242                        );
243                        if let Some(tx) = { calls.lock().remove(&id) } {
244                            let _r = tx.send(event);
245                        } else {
246                            warn!("orphaned RPC response: {}", id);
247                        }
248                    }
249                },
250                Err(e) => {
251                    error!("{}", e);
252                }
253            }
254        } else if opts.blocking_frames {
255            handlers.handle_frame(frame).await;
256        } else {
257            let h = handlers.clone();
258            spawn!("rpc.frame", async move {
259                h.handle_frame(frame).await;
260            });
261        }
262    }
263}
264
265impl RpcClient {
266    /// creates RPC client with the specified handlers and the default options
267    pub fn new<H>(client: impl AsyncClient + 'static, handlers: H) -> Self
268    where
269        H: RpcHandlers + Send + Sync + 'static,
270    {
271        Self::init(client, handlers, Options::default())
272    }
273
274    /// creates RPC client with dummy handlers and the default options
275    pub fn new0(client: impl AsyncClient + 'static) -> Self {
276        Self::init(client, DummyHandlers {}, Options::default())
277    }
278
279    /// creates RPC client
280    pub fn create<H>(client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
281    where
282        H: RpcHandlers + Send + Sync + 'static,
283    {
284        Self::init(client, handlers, opts)
285    }
286
287    /// creates RPC client with dummy handlers
288    pub fn create0(client: impl AsyncClient + 'static, opts: Options) -> Self {
289        Self::init(client, DummyHandlers {}, opts)
290    }
291
292    fn init<H>(mut client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
293    where
294        H: RpcHandlers + Send + Sync + 'static,
295    {
296        let timeout = client.get_timeout();
297        let rx = { client.take_event_channel().unwrap() };
298        let connected = client.get_connected_beacon();
299        let client = Arc::new(Mutex::new(client));
300        let calls: CallMap = <_>::default();
301        let processor_fut = Arc::new(SyncMutex::new(tokio::spawn(processor(
302            rx,
303            client.clone(),
304            calls.clone(),
305            Arc::new(handlers),
306            opts,
307        ))));
308        let pinger_client = client.clone();
309        let pfut = processor_fut.clone();
310        let pinger_fut = timeout.map(|t| {
311            tokio::spawn(async move {
312                loop {
313                    if let Err(e) = pinger_client.lock().await.ping().await {
314                        error!("{}", e);
315                        pfut.lock().abort();
316                        break;
317                    }
318                    tokio::time::sleep(t).await;
319                }
320            })
321        });
322        Self {
323            call_id: SyncMutex::new(0),
324            timeout,
325            client,
326            processor_fut,
327            pinger_fut,
328            calls,
329            connected,
330        }
331    }
332}
333
334#[async_trait]
335impl Rpc for RpcClient {
336    #[inline]
337    fn client(&self) -> Arc<Mutex<dyn AsyncClient + 'static>> {
338        self.client.clone()
339    }
340    #[inline]
341    async fn notify(
342        &self,
343        target: &str,
344        data: Cow<'async_trait>,
345        qos: QoS,
346    ) -> Result<OpConfirm, Error> {
347        self.client
348            .lock()
349            .await
350            .zc_send(target, (&[RPC_NOTIFICATION][..]).into(), data, qos)
351            .await
352    }
353    async fn call0(
354        &self,
355        target: &str,
356        method: &str,
357        params: Cow<'async_trait>,
358        qos: QoS,
359    ) -> Result<OpConfirm, Error> {
360        let payload = prepare_call_payload(method, &[0, 0, 0, 0]);
361        self.client
362            .lock()
363            .await
364            .zc_send(target, payload.into(), params, qos)
365            .await
366    }
367    /// # Panics
368    ///
369    /// Will panic on poisoned mutex
370    async fn call(
371        &self,
372        target: &str,
373        method: &str,
374        params: Cow<'async_trait>,
375        qos: QoS,
376    ) -> Result<RpcEvent, RpcError> {
377        let call_id = {
378            let mut ci = self.call_id.lock();
379            let mut call_id = *ci;
380            if call_id == u32::MAX {
381                call_id = 1;
382            } else {
383                call_id += 1;
384            }
385            *ci = call_id;
386            call_id
387        };
388        let payload = prepare_call_payload(method, &call_id.to_le_bytes());
389        let (tx, rx) = oneshot::channel();
390        self.calls.lock().insert(call_id, tx);
391        macro_rules! unwrap_or_cancel {
392            ($result: expr) => {
393                match $result {
394                    Ok(v) => v,
395                    Err(e) => {
396                        self.calls.lock().remove(&call_id);
397                        return Err(Into::<Error>::into(e).into());
398                    }
399                }
400            };
401        }
402        let opc = {
403            let mut client = self.client.lock().await;
404            let fut = client.zc_send(target, payload.into(), params, qos);
405            if let Some(timeout) = self.timeout {
406                unwrap_or_cancel!(unwrap_or_cancel!(tokio::time::timeout(timeout, fut).await))
407            } else {
408                unwrap_or_cancel!(fut.await)
409            }
410        };
411        if let Some(c) = opc {
412            unwrap_or_cancel!(unwrap_or_cancel!(c.await));
413        }
414        let result = rx.await.map_err(Into::<Error>::into)?;
415        if let Ok(e) = RpcError::try_from(&result) {
416            Err(e)
417        } else {
418            Ok(result)
419        }
420    }
421    fn is_connected(&self) -> bool {
422        self.connected
423            .as_ref()
424            .is_none_or(|b| b.load(atomic::Ordering::Relaxed))
425    }
426}
427
428impl Drop for RpcClient {
429    fn drop(&mut self) {
430        self.pinger_fut.as_ref().map(JoinHandle::abort);
431        self.processor_fut.lock().abort();
432    }
433}