elbus/
rpc.rs

1use crate::borrow::Cow;
2use crate::client::AsyncClient;
3use crate::EventChannel;
4use crate::{Error, Frame, FrameKind, OpConfirm, QoS};
5
6use std::collections::BTreeMap;
7use std::fmt;
8use std::sync::atomic;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::oneshot;
12use tokio::sync::Mutex;
13use tokio::task::JoinHandle;
14
15use log::{error, trace, warn};
16
17use async_trait::async_trait;
18
/// RPC block code of a notification (no reply is expected)
pub const RPC_NOTIFICATION: u8 = 0x00;
/// RPC block code of a call request
pub const RPC_REQUEST: u8 = 0x01;
/// RPC block code of a successful call reply
pub const RPC_REPLY: u8 = 0x11;
/// RPC block code of an error call reply
pub const RPC_ERROR: u8 = 0x12;

/// Error code: data could not be parsed (same value as the JSON-RPC 2.0 "Parse error" code;
/// used e.g. for `rmp_serde` decode errors)
pub const RPC_ERROR_CODE_PARSE: i16 = -32700;
/// Error code: invalid request (same value as the JSON-RPC 2.0 "Invalid Request" code)
pub const RPC_ERROR_CODE_INVALID_REQUEST: i16 = -32600;
/// Error code: the requested method is not found (same value as the JSON-RPC 2.0 "Method not
/// found" code)
pub const RPC_ERROR_CODE_METHOD_NOT_FOUND: i16 = -32601;
/// Error code: invalid method params (same value as the JSON-RPC 2.0 "Invalid params" code)
pub const RPC_ERROR_CODE_INVALID_METHOD_PARAMS: i16 = -32602;
/// Error code: internal error (same value as the JSON-RPC 2.0 "Internal error" code; used e.g.
/// for I/O and `rmp_serde` encode errors)
pub const RPC_ERROR_CODE_INTERNAL: i16 = -32603;
29
/// Options of the RPC client
///
/// By default, RPC frame and notification handlers are launched in the background, which
/// allows non-blocking event processing, however events may be processed in random order.
///
/// The options allow launching these handlers in blocking mode. In this case the handlers must
/// process events as fast as possible (e.g. send them to processing channels) and avoid using
/// any RPC client functions from inside.
///
/// WARNING: when handling frames in blocking mode, it is forbidden to use the current RPC client
/// directly or with any kind of bounded channels, otherwise the RPC client may get stuck!
///
/// RPC call (request) handlers are always launched in the background.
///
/// See <https://elbus.readthedocs.io/en/latest/rpc_blocking.html>
#[derive(Default, Clone, Debug)]
pub struct Options {
    blocking_notifications: bool,
    blocking_frames: bool,
}

impl Options {
    /// Creates the default options: all handlers are launched in the background
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Enables blocking mode for notification handlers
    #[inline]
    #[must_use]
    pub fn blocking_notifications(mut self) -> Self {
        self.blocking_notifications = true;
        self
    }
    /// Enables blocking mode for handlers of non-RPC frames
    #[inline]
    #[must_use]
    pub fn blocking_frames(mut self) -> Self {
        self.blocking_frames = true;
        self
    }
}
63
/// Kind of an RPC event, as encoded in the first byte of the RPC block
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
pub enum RpcEventKind {
    /// One-way notification, no reply is sent
    Notification = RPC_NOTIFICATION,
    /// Method call request
    Request = RPC_REQUEST,
    /// Successful reply to a call
    Reply = RPC_REPLY,
    /// Error reply to a call
    ErrorReply = RPC_ERROR,
}
73
/// Converts any displayable value into RPC error data (the UTF-8 bytes of its string form),
/// ready to be passed to the `RpcError` constructors
#[allow(clippy::module_name_repetitions)]
#[inline]
pub fn rpc_err_str(v: impl fmt::Display) -> Option<Vec<u8>> {
    // `into_bytes` reuses the String buffer instead of copying it into a new Vec
    Some(v.to_string().into_bytes())
}
79
80impl fmt::Display for RpcEventKind {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        write!(
83            f,
84            "{}",
85            match self {
86                RpcEventKind::Notification => "notifcation",
87                RpcEventKind::Request => "request",
88                RpcEventKind::Reply => "reply",
89                RpcEventKind::ErrorReply => "error reply",
90            }
91        )
92    }
93}
94
/// Incoming RPC event, parsed from a bus frame
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct RpcEvent {
    /// Event kind
    kind: RpcEventKind,
    /// The frame the event has been parsed from
    frame: Frame,
    /// Offset of the RPC payload (notification data, call params or reply data) inside the
    /// frame payload
    payload_pos: usize,
    /// If true, the RPC block (id, method, error code) is read from the frame header instead of
    /// the frame payload
    use_header: bool,
}
103
104impl RpcEvent {
105    #[inline]
106    pub fn kind(&self) -> RpcEventKind {
107        self.kind
108    }
109    #[inline]
110    pub fn frame(&self) -> &Frame {
111        &self.frame
112    }
113    #[inline]
114    pub fn sender(&self) -> &str {
115        self.frame.sender()
116    }
117    #[inline]
118    pub fn primary_sender(&self) -> &str {
119        self.frame.primary_sender()
120    }
121    #[inline]
122    pub fn payload(&self) -> &[u8] {
123        &self.frame().payload()[self.payload_pos..]
124    }
125    /// # Panics
126    ///
127    /// Should not panic
128    #[inline]
129    pub fn id(&self) -> u32 {
130        u32::from_le_bytes(
131            if self.use_header {
132                &self.frame.header().unwrap()[1..5]
133            } else {
134                &self.frame.payload()[1..5]
135            }
136            .try_into()
137            .unwrap(),
138        )
139    }
140    #[inline]
141    pub fn is_response_required(&self) -> bool {
142        self.id() != 0
143    }
144    /// # Panics
145    ///
146    /// Should not panic
147    #[inline]
148    pub fn method(&self) -> &[u8] {
149        if self.use_header {
150            let header = self.frame.header.as_ref().unwrap();
151            &header[5..header.len() - 1]
152        } else {
153            &self.frame().payload()[5..self.payload_pos - 1]
154        }
155    }
156    #[inline]
157    pub fn parse_method(&self) -> Result<&str, Error> {
158        std::str::from_utf8(self.method()).map_err(Into::into)
159    }
160    /// # Panics
161    ///
162    /// Should not panic
163    #[inline]
164    pub fn code(&self) -> i16 {
165        if self.kind == RpcEventKind::ErrorReply {
166            i16::from_le_bytes(
167                if self.use_header {
168                    &self.frame.header().unwrap()[5..7]
169                } else {
170                    &self.frame.payload()[5..7]
171                }
172                .try_into()
173                .unwrap(),
174            )
175        } else {
176            0
177        }
178    }
179}
180
181impl TryFrom<Frame> for RpcEvent {
182    type Error = Error;
183    fn try_from(frame: Frame) -> Result<Self, Self::Error> {
184        let (body, use_header) = frame
185            .header()
186            .map_or_else(|| (frame.payload(), false), |h| (h, true));
187        if body.is_empty() {
188            Err(Error::data("Empty RPC frame"))
189        } else {
190            macro_rules! check_len {
191                ($len: expr) => {
192                    if body.len() < $len {
193                        return Err(Error::data("Invalid RPC frame"));
194                    }
195                };
196            }
197            match body[0] {
198                RPC_NOTIFICATION => Ok(RpcEvent {
199                    kind: RpcEventKind::Notification,
200                    frame,
201                    payload_pos: if use_header { 0 } else { 1 },
202                    use_header: false,
203                }),
204                RPC_REQUEST => {
205                    check_len!(6);
206                    if use_header {
207                        Ok(RpcEvent {
208                            kind: RpcEventKind::Request,
209                            frame,
210                            payload_pos: 0,
211                            use_header: true,
212                        })
213                    } else {
214                        let mut sp = body[5..].splitn(2, |c| *c == 0);
215                        let method = sp.next().ok_or_else(|| Error::data("No RPC method"))?;
216                        let payload_pos = 6 + method.len();
217                        sp.next()
218                            .ok_or_else(|| Error::data("No RPC params block"))?;
219                        Ok(RpcEvent {
220                            kind: RpcEventKind::Request,
221                            frame,
222                            payload_pos,
223                            use_header: false,
224                        })
225                    }
226                }
227                RPC_REPLY => {
228                    check_len!(5);
229                    Ok(RpcEvent {
230                        kind: RpcEventKind::Reply,
231                        frame,
232                        payload_pos: if use_header { 0 } else { 5 },
233                        use_header,
234                    })
235                }
236                RPC_ERROR => {
237                    check_len!(7);
238                    Ok(RpcEvent {
239                        kind: RpcEventKind::ErrorReply,
240                        frame,
241                        payload_pos: if use_header { 0 } else { 7 },
242                        use_header,
243                    })
244                }
245                v => Err(Error::data(format!("Unsupported RPC frame code {}", v))),
246            }
247        }
248    }
249}
250
/// Handlers of incoming RPC events and other frames, called by the RPC client frame processor
#[allow(clippy::module_name_repetitions)]
#[async_trait]
pub trait RpcHandlers {
    /// Handles an incoming RPC call (request)
    ///
    /// Always launched in a background task. Unless the call id is 0 (no response required),
    /// the result is sent back to the caller: `Ok(data)` as a reply (with an empty payload for
    /// `None`), `Err` as an error reply carrying the error code and data.
    async fn handle_call(&self, event: RpcEvent) -> RpcResult;
    /// Handles an incoming RPC notification
    ///
    /// Launched in a background task unless blocking notifications are enabled in `Options`
    async fn handle_notification(&self, event: RpcEvent);
    /// Handles an incoming frame which is not a message (`FrameKind::Message`) frame
    ///
    /// Launched in a background task unless blocking frames are enabled in `Options`
    async fn handle_frame(&self, frame: Frame);
}
258
259pub struct DummyHandlers {}
260
261#[async_trait]
262impl RpcHandlers for DummyHandlers {
263    async fn handle_call(&self, _event: RpcEvent) -> RpcResult {
264        Err(RpcError::new(
265            RPC_ERROR_CODE_METHOD_NOT_FOUND,
266            Some("RPC handler is not implemented".as_bytes().to_vec()),
267        ))
268    }
269    async fn handle_notification(&self, _event: RpcEvent) {}
270    async fn handle_frame(&self, _frame: Frame) {}
271}
272
/// Pending RPC calls: call id -> channel which delivers the reply (or error reply) to the caller
type CallMap = Arc<parking_lot::Mutex<BTreeMap<u32, oneshot::Sender<RpcEvent>>>>;
274
/// High-level RPC interface: notifications and method calls on top of an elbus client
#[async_trait]
pub trait Rpc {
    /// When created, the elbus client is wrapped with Arc<Mutex<_>> to let it be sent into
    /// the incoming frames handler future
    ///
    /// This method allows getting the contained client back, to call its methods directly
    /// (manage pub/sub and send broadcast messages)
    fn client(&self) -> Arc<Mutex<(dyn AsyncClient + 'static)>>;
    /// Send an RPC notification with the given data to the target, no response is expected
    async fn notify(
        &self,
        target: &str,
        data: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Call the method, no response is required
    ///
    /// (`RpcClient` sends such calls with id 0, for which the remote RPC processor sends no
    /// reply)
    async fn call0(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Call the method and get the response
    ///
    /// An error reply from the remote side is returned as `Err(RpcError)`
    async fn call(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<RpcEvent, RpcError>;
    /// Returns the connection state of the underlying client
    fn is_connected(&self) -> bool;
}
307
/// RPC client on top of an elbus client
///
/// Incoming frames are handled by a background processor task; if the underlying client has a
/// timeout set, a pinger task is started as well. Both tasks are aborted when the RPC client is
/// dropped.
#[allow(clippy::module_name_repetitions)]
pub struct RpcClient {
    /// Last used call id (0 is reserved for calls which require no response)
    call_id: parking_lot::Mutex<u32>,
    /// Timeout of the underlying client: limits sending of calls and is the ping interval
    timeout: Option<Duration>,
    /// The underlying client, shared with the background tasks
    client: Arc<Mutex<dyn AsyncClient>>,
    /// Incoming frame processor task
    processor_fut: Arc<parking_lot::Mutex<JoinHandle<()>>>,
    /// Pinger task (started only if the client has a timeout)
    pinger_fut: Option<JoinHandle<()>>,
    /// Pending calls waiting for replies
    calls: CallMap,
    /// Connection state beacon of the underlying client, if it provides one
    connected: Option<Arc<atomic::AtomicBool>>,
}
318
319#[allow(clippy::too_many_lines)]
320async fn processor<C, H>(
321    rx: EventChannel,
322    processor_client: Arc<Mutex<C>>,
323    calls: CallMap,
324    handlers: Arc<H>,
325    opts: Options,
326) where
327    C: AsyncClient + 'static,
328    H: RpcHandlers + Send + Sync + 'static,
329{
330    while let Ok(frame) = rx.recv().await {
331        if frame.kind() == FrameKind::Message {
332            match RpcEvent::try_from(frame) {
333                Ok(event) => match event.kind() {
334                    RpcEventKind::Notification => {
335                        trace!("RPC notification from {}", event.frame().sender());
336                        if opts.blocking_notifications {
337                            handlers.handle_notification(event).await;
338                        } else {
339                            let h = handlers.clone();
340                            tokio::spawn(async move {
341                                h.handle_notification(event).await;
342                            });
343                        }
344                    }
345                    RpcEventKind::Request => {
346                        let id = event.id();
347                        trace!(
348                            "RPC request from {}, id: {}, method: {:?}",
349                            event.frame().sender(),
350                            id,
351                            event.method()
352                        );
353                        let ev = if id > 0 {
354                            Some((event.frame().sender().to_owned(), processor_client.clone()))
355                        } else {
356                            None
357                        };
358                        let h = handlers.clone();
359                        tokio::spawn(async move {
360                            let qos = if event.frame().is_realtime() {
361                                QoS::RealtimeProcessed
362                            } else {
363                                QoS::Processed
364                            };
365                            let res = h.handle_call(event).await;
366                            if let Some((target, cl)) = ev {
367                                macro_rules! send_reply {
368                                    ($payload: expr, $result: expr) => {{
369                                        let mut client = cl.lock().await;
370                                        if let Some(result) = $result {
371                                            client
372                                                .zc_send(&target, $payload, result.into(), qos)
373                                                .await
374                                        } else {
375                                            client
376                                                .zc_send(&target, $payload, (&[][..]).into(), qos)
377                                                .await
378                                        }
379                                    }};
380                                }
381                                match res {
382                                    Ok(v) => {
383                                        trace!("Sending RPC reply id {} to {}", id, target);
384                                        let mut payload = Vec::with_capacity(5);
385                                        payload.push(RPC_REPLY);
386                                        payload.extend_from_slice(&id.to_le_bytes());
387                                        let _r = send_reply!(payload.into(), v);
388                                    }
389                                    Err(e) => {
390                                        trace!(
391                                            "Sending RPC error {} reply id {} to {}",
392                                            e.code,
393                                            id,
394                                            target,
395                                        );
396                                        let mut payload = Vec::with_capacity(7);
397                                        payload.push(RPC_ERROR);
398                                        payload.extend_from_slice(&id.to_le_bytes());
399                                        payload.extend_from_slice(&e.code.to_le_bytes());
400                                        let _r = send_reply!(payload.into(), e.data);
401                                    }
402                                }
403                            }
404                        });
405                    }
406                    RpcEventKind::Reply | RpcEventKind::ErrorReply => {
407                        let id = event.id();
408                        trace!(
409                            "RPC {} from {}, id: {}",
410                            event.kind(),
411                            event.frame().sender(),
412                            id
413                        );
414                        if let Some(tx) = { calls.lock().remove(&id) } {
415                            let _r = tx.send(event);
416                        } else {
417                            warn!("orphaned RPC response: {}", id);
418                        }
419                    }
420                },
421                Err(e) => {
422                    error!("{}", e);
423                }
424            }
425        } else if opts.blocking_frames {
426            handlers.handle_frame(frame).await;
427        } else {
428            let h = handlers.clone();
429            tokio::spawn(async move {
430                h.handle_frame(frame).await;
431            });
432        }
433    }
434}
435
436#[inline]
437fn prepare_call_payload(method: &str, id_bytes: &[u8]) -> Vec<u8> {
438    let m = method.as_bytes();
439    let mut payload = Vec::with_capacity(m.len() + 6);
440    payload.push(RPC_REQUEST);
441    payload.extend(id_bytes);
442    payload.extend(m);
443    payload.push(0x00);
444    payload
445}
446
447impl RpcClient {
448    /// creates RPC client with the specified handlers and the default options
449    pub fn new<H>(client: impl AsyncClient + 'static, handlers: H) -> Self
450    where
451        H: RpcHandlers + Send + Sync + 'static,
452    {
453        Self::init(client, handlers, Options::default())
454    }
455
456    /// creates RPC client with dummy handlers and the default options
457    pub fn new0(client: impl AsyncClient + 'static) -> Self {
458        Self::init(client, DummyHandlers {}, Options::default())
459    }
460
461    /// creates RPC client
462    pub fn create<H>(client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
463    where
464        H: RpcHandlers + Send + Sync + 'static,
465    {
466        Self::init(client, handlers, opts)
467    }
468
469    /// creates RPC client with dummy handlers
470    pub fn create0(client: impl AsyncClient + 'static, opts: Options) -> Self {
471        Self::init(client, DummyHandlers {}, opts)
472    }
473
474    fn init<H>(mut client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
475    where
476        H: RpcHandlers + Send + Sync + 'static,
477    {
478        let timeout = client.get_timeout();
479        let rx = { client.take_event_channel().unwrap() };
480        let connected = client.get_connected_beacon();
481        let client = Arc::new(Mutex::new(client));
482        let calls: CallMap = <_>::default();
483        let processor_fut = Arc::new(parking_lot::Mutex::new(tokio::spawn(processor(
484            rx,
485            client.clone(),
486            calls.clone(),
487            Arc::new(handlers),
488            opts,
489        ))));
490        let pinger_client = client.clone();
491        let pfut = processor_fut.clone();
492        let pinger_fut = timeout.map(|t| {
493            tokio::spawn(async move {
494                loop {
495                    if let Err(e) = pinger_client.lock().await.ping().await {
496                        error!("{}", e);
497                        pfut.lock().abort();
498                        break;
499                    }
500                    tokio::time::sleep(t).await;
501                }
502            })
503        });
504        Self {
505            call_id: parking_lot::Mutex::new(0),
506            timeout,
507            client,
508            processor_fut,
509            pinger_fut,
510            calls,
511            connected,
512        }
513    }
514}
515
516#[async_trait]
517impl Rpc for RpcClient {
518    #[inline]
519    fn client(&self) -> Arc<Mutex<(dyn AsyncClient + 'static)>> {
520        self.client.clone()
521    }
522    #[inline]
523    async fn notify(
524        &self,
525        target: &str,
526        data: Cow<'async_trait>,
527        qos: QoS,
528    ) -> Result<OpConfirm, Error> {
529        self.client
530            .lock()
531            .await
532            .zc_send(target, (&[RPC_NOTIFICATION][..]).into(), data, qos)
533            .await
534    }
535    async fn call0(
536        &self,
537        target: &str,
538        method: &str,
539        params: Cow<'async_trait>,
540        qos: QoS,
541    ) -> Result<OpConfirm, Error> {
542        let payload = prepare_call_payload(method, &[0, 0, 0, 0]);
543        self.client
544            .lock()
545            .await
546            .zc_send(target, payload.into(), params, qos)
547            .await
548    }
549    /// # Panics
550    ///
551    /// Will panic on poisoned mutex
552    async fn call(
553        &self,
554        target: &str,
555        method: &str,
556        params: Cow<'async_trait>,
557        qos: QoS,
558    ) -> Result<RpcEvent, RpcError> {
559        let call_id = {
560            let mut ci = self.call_id.lock();
561            let mut call_id = *ci;
562            if call_id == u32::MAX {
563                call_id = 1;
564            } else {
565                call_id += 1;
566            }
567            *ci = call_id;
568            call_id
569        };
570        let payload = prepare_call_payload(method, &call_id.to_le_bytes());
571        let (tx, rx) = oneshot::channel();
572        self.calls.lock().insert(call_id, tx);
573        macro_rules! unwrap_or_cancel {
574            ($result: expr) => {
575                match $result {
576                    Ok(v) => v,
577                    Err(e) => {
578                        self.calls.lock().remove(&call_id);
579                        return Err(Into::<Error>::into(e).into());
580                    }
581                }
582            };
583        }
584        let opc = {
585            let mut client = self.client.lock().await;
586            let fut = client.zc_send(target, payload.into(), params, qos);
587            if let Some(timeout) = self.timeout {
588                unwrap_or_cancel!(unwrap_or_cancel!(tokio::time::timeout(timeout, fut).await))
589            } else {
590                unwrap_or_cancel!(fut.await)
591            }
592        };
593        if let Some(c) = opc {
594            unwrap_or_cancel!(unwrap_or_cancel!(c.await));
595        }
596        let result = rx.await.map_err(Into::<Error>::into)?;
597        if let Ok(e) = RpcError::try_from(&result) {
598            Err(e)
599        } else {
600            Ok(result)
601        }
602    }
603    fn is_connected(&self) -> bool {
604        self.connected
605            .as_ref()
606            .map_or(true, |b| b.load(atomic::Ordering::SeqCst))
607    }
608}
609
610impl Drop for RpcClient {
611    fn drop(&mut self) {
612        self.pinger_fut.as_ref().map(JoinHandle::abort);
613        self.processor_fut.lock().abort();
614    }
615}
616
/// RPC error: an error code plus optional error data
///
/// Returned by `RpcHandlers::handle_call` to send an error reply, and by `Rpc::call` when the
/// call fails or the remote side replies with an error
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct RpcError {
    code: i16,
    data: Option<Vec<u8>>,
}
623
624impl TryFrom<&RpcEvent> for RpcError {
625    type Error = Error;
626    #[inline]
627    fn try_from(event: &RpcEvent) -> Result<Self, Self::Error> {
628        if event.kind() == RpcEventKind::ErrorReply {
629            Ok(RpcError::new(event.code(), Some(event.payload().to_vec())))
630        } else {
631            Err(Error::data("not a RPC error"))
632        }
633    }
634}
635
636impl RpcError {
637    #[inline]
638    pub fn new(code: i16, data: Option<Vec<u8>>) -> Self {
639        Self { code, data }
640    }
641    #[inline]
642    pub fn code(&self) -> i16 {
643        self.code
644    }
645    #[inline]
646    pub fn data(&self) -> Option<&[u8]> {
647        self.data.as_deref()
648    }
649    #[inline]
650    pub fn method(err: Option<Vec<u8>>) -> Self {
651        Self {
652            code: RPC_ERROR_CODE_METHOD_NOT_FOUND,
653            data: err,
654        }
655    }
656    #[inline]
657    pub fn params(err: Option<Vec<u8>>) -> Self {
658        Self {
659            code: RPC_ERROR_CODE_INVALID_METHOD_PARAMS,
660            data: err,
661        }
662    }
663    #[inline]
664    pub fn parse(err: Option<Vec<u8>>) -> Self {
665        Self {
666            code: RPC_ERROR_CODE_PARSE,
667            data: err,
668        }
669    }
670    #[inline]
671    pub fn invalid(err: Option<Vec<u8>>) -> Self {
672        Self {
673            code: RPC_ERROR_CODE_INVALID_REQUEST,
674            data: err,
675        }
676    }
677    #[inline]
678    pub fn internal(err: Option<Vec<u8>>) -> Self {
679        Self {
680            code: RPC_ERROR_CODE_INTERNAL,
681            data: err,
682        }
683    }
684    /// Converts displayable to Vec<u8>
685    #[inline]
686    pub fn convert_data(v: impl fmt::Display) -> Vec<u8> {
687        v.to_string().as_bytes().to_vec()
688    }
689}
690
691impl From<Error> for RpcError {
692    #[inline]
693    fn from(e: Error) -> RpcError {
694        RpcError {
695            code: -32000 - e.kind() as i16,
696            data: None,
697        }
698    }
699}
700
701impl From<rmp_serde::encode::Error> for RpcError {
702    #[inline]
703    fn from(e: rmp_serde::encode::Error) -> RpcError {
704        RpcError {
705            code: RPC_ERROR_CODE_INTERNAL,
706            data: Some(e.to_string().as_bytes().to_vec()),
707        }
708    }
709}
710
711impl From<std::io::Error> for RpcError {
712    #[inline]
713    fn from(e: std::io::Error) -> RpcError {
714        RpcError {
715            code: RPC_ERROR_CODE_INTERNAL,
716            data: Some(e.to_string().as_bytes().to_vec()),
717        }
718    }
719}
720
721impl From<rmp_serde::decode::Error> for RpcError {
722    #[inline]
723    fn from(e: rmp_serde::decode::Error) -> RpcError {
724        RpcError {
725            code: RPC_ERROR_CODE_PARSE,
726            data: Some(e.to_string().as_bytes().to_vec()),
727        }
728    }
729}
730
731impl fmt::Display for RpcError {
732    #[inline]
733    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
734        write!(f, "rpc error code: {}", self.code)
735    }
736}
737
/// Result of an RPC call handler: `Ok(Some(data))` is sent as the reply payload, `Ok(None)` as
/// an empty reply, `Err` as an error reply
#[allow(clippy::module_name_repetitions)]
pub type RpcResult = Result<Option<Vec<u8>>, RpcError>;