1use crate::borrow::Cow;
2use crate::client::AsyncClient;
3use crate::EventChannel;
4use crate::{Error, Frame, FrameKind, OpConfirm, QoS};
5
6use std::collections::BTreeMap;
7use std::fmt;
8use std::sync::atomic;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::oneshot;
12use tokio::sync::Mutex;
13use tokio::task::JoinHandle;
14
15use log::{error, trace, warn};
16
17use async_trait::async_trait;
18
// RPC frame type markers: the first byte of the RPC header/payload selects the
// message kind (see `TryFrom<Frame> for RpcEvent`).

/// Frame type marker: notification (carries no call id).
pub const RPC_NOTIFICATION: u8 = 0x00;
/// Frame type marker: method call request.
pub const RPC_REQUEST: u8 = 0x01;
/// Frame type marker: successful reply to a request.
pub const RPC_REPLY: u8 = 0x11;
/// Frame type marker: error reply to a request.
pub const RPC_ERROR: u8 = 0x12;

// Error codes carried in error replies.
// NOTE(review): the values appear to mirror the JSON-RPC 2.0 pre-defined error
// codes -- confirm before relying on interoperability.

/// Error code: the request data could not be parsed.
pub const RPC_ERROR_CODE_PARSE: i16 = -32700;
/// Error code: the request itself is invalid.
pub const RPC_ERROR_CODE_INVALID_REQUEST: i16 = -32600;
/// Error code: the requested method does not exist.
pub const RPC_ERROR_CODE_METHOD_NOT_FOUND: i16 = -32601;
/// Error code: the method parameters are invalid.
pub const RPC_ERROR_CODE_INVALID_METHOD_PARAMS: i16 = -32602;
/// Error code: internal error on the handler side.
pub const RPC_ERROR_CODE_INTERNAL: i16 = -32603;
29
/// Dispatch options for the RPC event processor.
///
/// Both flags default to `false`. When a flag is `false` the matching handler
/// is run in a freshly spawned task; when it is `true` the handler is awaited
/// inline, so the next incoming frame is not processed until it returns.
#[derive(Default, Clone, Debug)]
pub struct Options {
    /// Await `handle_notification` inline instead of spawning a task for it.
    blocking_notifications: bool,
    /// Await `handle_frame` inline instead of spawning a task for it.
    blocking_frames: bool,
}

impl Options {
    /// Creates options with every flag disabled (same as `Options::default()`).
    #[inline]
    pub fn new() -> Self {
        Default::default()
    }
    /// Returns the options with inline (blocking) notification handling enabled.
    #[inline]
    pub fn blocking_notifications(self) -> Self {
        Self {
            blocking_notifications: true,
            ..self
        }
    }
    /// Returns the options with inline (blocking) handling of non-RPC frames
    /// enabled.
    #[inline]
    pub fn blocking_frames(self) -> Self {
        Self {
            blocking_frames: true,
            ..self
        }
    }
}
63
/// Kind of an RPC event.
///
/// The discriminants are the wire values of the frame type marker byte
/// (`RPC_NOTIFICATION`, `RPC_REQUEST`, `RPC_REPLY`, `RPC_ERROR`).
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
pub enum RpcEventKind {
    /// Notification message.
    Notification = RPC_NOTIFICATION,
    /// Method call request.
    Request = RPC_REQUEST,
    /// Successful reply to a request.
    Reply = RPC_REPLY,
    /// Error reply to a request.
    ErrorReply = RPC_ERROR,
}
73
/// Renders any displayable value into the optional byte payload used as RPC
/// error data.
///
/// Always returns `Some`; the bytes are the UTF-8 encoding of `v`'s `Display`
/// output. The formatted `String` is converted in place, without a second
/// allocation and copy.
#[allow(clippy::module_name_repetitions)]
#[inline]
pub fn rpc_err_str(v: impl fmt::Display) -> Option<Vec<u8>> {
    Some(v.to_string().into_bytes())
}
79
80impl fmt::Display for RpcEventKind {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 write!(
83 f,
84 "{}",
85 match self {
86 RpcEventKind::Notification => "notifcation",
87 RpcEventKind::Request => "request",
88 RpcEventKind::Reply => "reply",
89 RpcEventKind::ErrorReply => "error reply",
90 }
91 )
92 }
93}
94
95#[allow(clippy::module_name_repetitions)]
96#[derive(Debug)]
97pub struct RpcEvent {
98 kind: RpcEventKind,
99 frame: Frame,
100 payload_pos: usize,
101 use_header: bool,
102}
103
104impl RpcEvent {
105 #[inline]
106 pub fn kind(&self) -> RpcEventKind {
107 self.kind
108 }
109 #[inline]
110 pub fn frame(&self) -> &Frame {
111 &self.frame
112 }
113 #[inline]
114 pub fn sender(&self) -> &str {
115 self.frame.sender()
116 }
117 #[inline]
118 pub fn primary_sender(&self) -> &str {
119 self.frame.primary_sender()
120 }
121 #[inline]
122 pub fn payload(&self) -> &[u8] {
123 &self.frame().payload()[self.payload_pos..]
124 }
125 #[inline]
129 pub fn id(&self) -> u32 {
130 u32::from_le_bytes(
131 if self.use_header {
132 &self.frame.header().unwrap()[1..5]
133 } else {
134 &self.frame.payload()[1..5]
135 }
136 .try_into()
137 .unwrap(),
138 )
139 }
140 #[inline]
141 pub fn is_response_required(&self) -> bool {
142 self.id() != 0
143 }
144 #[inline]
148 pub fn method(&self) -> &[u8] {
149 if self.use_header {
150 let header = self.frame.header.as_ref().unwrap();
151 &header[5..header.len() - 1]
152 } else {
153 &self.frame().payload()[5..self.payload_pos - 1]
154 }
155 }
156 #[inline]
157 pub fn parse_method(&self) -> Result<&str, Error> {
158 std::str::from_utf8(self.method()).map_err(Into::into)
159 }
160 #[inline]
164 pub fn code(&self) -> i16 {
165 if self.kind == RpcEventKind::ErrorReply {
166 i16::from_le_bytes(
167 if self.use_header {
168 &self.frame.header().unwrap()[5..7]
169 } else {
170 &self.frame.payload()[5..7]
171 }
172 .try_into()
173 .unwrap(),
174 )
175 } else {
176 0
177 }
178 }
179}
180
181impl TryFrom<Frame> for RpcEvent {
182 type Error = Error;
183 fn try_from(frame: Frame) -> Result<Self, Self::Error> {
184 let (body, use_header) = frame
185 .header()
186 .map_or_else(|| (frame.payload(), false), |h| (h, true));
187 if body.is_empty() {
188 Err(Error::data("Empty RPC frame"))
189 } else {
190 macro_rules! check_len {
191 ($len: expr) => {
192 if body.len() < $len {
193 return Err(Error::data("Invalid RPC frame"));
194 }
195 };
196 }
197 match body[0] {
198 RPC_NOTIFICATION => Ok(RpcEvent {
199 kind: RpcEventKind::Notification,
200 frame,
201 payload_pos: if use_header { 0 } else { 1 },
202 use_header: false,
203 }),
204 RPC_REQUEST => {
205 check_len!(6);
206 if use_header {
207 Ok(RpcEvent {
208 kind: RpcEventKind::Request,
209 frame,
210 payload_pos: 0,
211 use_header: true,
212 })
213 } else {
214 let mut sp = body[5..].splitn(2, |c| *c == 0);
215 let method = sp.next().ok_or_else(|| Error::data("No RPC method"))?;
216 let payload_pos = 6 + method.len();
217 sp.next()
218 .ok_or_else(|| Error::data("No RPC params block"))?;
219 Ok(RpcEvent {
220 kind: RpcEventKind::Request,
221 frame,
222 payload_pos,
223 use_header: false,
224 })
225 }
226 }
227 RPC_REPLY => {
228 check_len!(5);
229 Ok(RpcEvent {
230 kind: RpcEventKind::Reply,
231 frame,
232 payload_pos: if use_header { 0 } else { 5 },
233 use_header,
234 })
235 }
236 RPC_ERROR => {
237 check_len!(7);
238 Ok(RpcEvent {
239 kind: RpcEventKind::ErrorReply,
240 frame,
241 payload_pos: if use_header { 0 } else { 7 },
242 use_header,
243 })
244 }
245 v => Err(Error::data(format!("Unsupported RPC frame code {}", v))),
246 }
247 }
248 }
249}
250
/// Callbacks invoked by the RPC event processor for incoming traffic.
#[allow(clippy::module_name_repetitions)]
#[async_trait]
pub trait RpcHandlers {
    /// Handles an incoming RPC request.
    ///
    /// The processor runs each call in its own spawned task. When the request
    /// id is non-zero, the returned result is sent back to the caller as a
    /// reply (`Ok`) or an error reply (`Err`).
    async fn handle_call(&self, event: RpcEvent) -> RpcResult;
    /// Handles an incoming RPC notification; nothing is sent back.
    async fn handle_notification(&self, event: RpcEvent);
    /// Handles an incoming frame whose kind is not `FrameKind::Message`
    /// (i.e. a frame that is not parsed as RPC).
    async fn handle_frame(&self, frame: Frame);
}
258
259pub struct DummyHandlers {}
260
261#[async_trait]
262impl RpcHandlers for DummyHandlers {
263 async fn handle_call(&self, _event: RpcEvent) -> RpcResult {
264 Err(RpcError::new(
265 RPC_ERROR_CODE_METHOD_NOT_FOUND,
266 Some("RPC handler is not implemented".as_bytes().to_vec()),
267 ))
268 }
269 async fn handle_notification(&self, _event: RpcEvent) {}
270 async fn handle_frame(&self, _frame: Frame) {}
271}
272
/// Shared registry of in-flight calls: call id -> one-shot sender used to hand
/// the matching reply (or error reply) event to the waiting caller.
type CallMap = Arc<parking_lot::Mutex<BTreeMap<u32, oneshot::Sender<RpcEvent>>>>;
274
/// Client-side RPC operations.
///
/// The per-method notes below describe the behavior of the `RpcClient`
/// implementation in this file; other implementors may differ.
#[async_trait]
pub trait Rpc {
    /// Returns a shared handle to the underlying bus client.
    fn client(&self) -> Arc<Mutex<(dyn AsyncClient + 'static)>>;
    /// Sends an RPC notification carrying `data` to `target`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying client's send operation.
    async fn notify(
        &self,
        target: &str,
        data: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Calls `method` on `target` without requesting a reply (the request is
    /// sent with call id 0).
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying client's send operation.
    async fn call0(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<OpConfirm, Error>;
    /// Calls `method` on `target` and waits for the reply.
    ///
    /// # Errors
    ///
    /// Returns an `RpcError` when sending fails or when the remote side
    /// answers with an error reply.
    async fn call(
        &self,
        target: &str,
        method: &str,
        params: Cow<'async_trait>,
        qos: QoS,
    ) -> Result<RpcEvent, RpcError>;
    /// Reports whether the underlying client is connected.
    fn is_connected(&self) -> bool;
}
307
/// RPC layer on top of an asynchronous bus client.
///
/// Owns two background tasks: the event processor (dispatches incoming frames)
/// and, when the client has a timeout configured, a pinger. Both are aborted
/// when the `RpcClient` is dropped.
#[allow(clippy::module_name_repetitions)]
pub struct RpcClient {
    /// Last call id handed out by `call`.
    call_id: parking_lot::Mutex<u32>,
    /// Timeout obtained from the client at construction; also used as the
    /// ping interval.
    timeout: Option<Duration>,
    /// The underlying bus client, shared with the background tasks.
    client: Arc<Mutex<dyn AsyncClient>>,
    /// Handle of the event processor task (shared with the pinger so it can
    /// abort the processor when a ping fails).
    processor_fut: Arc<parking_lot::Mutex<JoinHandle<()>>>,
    /// Handle of the pinger task, present only when a timeout is configured.
    pinger_fut: Option<JoinHandle<()>>,
    /// In-flight calls awaiting a reply.
    calls: CallMap,
    /// Connection flag provided by the client, if it exposes one.
    connected: Option<Arc<atomic::AtomicBool>>,
}
318
/// Event loop of the RPC client: receives frames from `rx` until the channel
/// reports an error and dispatches each one.
///
/// * Message frames are parsed as RPC events:
///   * notifications go to `handlers.handle_notification` (inline when
///     `opts.blocking_notifications` is set, otherwise in a spawned task);
///   * requests go to `handlers.handle_call` in a spawned task; if the request
///     id is non-zero, the result is sent back to the sender via
///     `processor_client`;
///   * replies and error replies are forwarded to the pending call registered
///     in `calls` under the same id; unmatched ones are logged as orphaned;
///   * frames that fail to parse are logged and dropped.
/// * All other frames go to `handlers.handle_frame` (inline when
///   `opts.blocking_frames` is set, otherwise in a spawned task).
#[allow(clippy::too_many_lines)]
async fn processor<C, H>(
    rx: EventChannel,
    processor_client: Arc<Mutex<C>>,
    calls: CallMap,
    handlers: Arc<H>,
    opts: Options,
) where
    C: AsyncClient + 'static,
    H: RpcHandlers + Send + Sync + 'static,
{
    while let Ok(frame) = rx.recv().await {
        if frame.kind() == FrameKind::Message {
            match RpcEvent::try_from(frame) {
                Ok(event) => match event.kind() {
                    RpcEventKind::Notification => {
                        trace!("RPC notification from {}", event.frame().sender());
                        if opts.blocking_notifications {
                            handlers.handle_notification(event).await;
                        } else {
                            let h = handlers.clone();
                            tokio::spawn(async move {
                                h.handle_notification(event).await;
                            });
                        }
                    }
                    RpcEventKind::Request => {
                        let id = event.id();
                        trace!(
                            "RPC request from {}, id: {}, method: {:?}",
                            event.frame().sender(),
                            id,
                            event.method()
                        );
                        // Reply context is captured only when the caller asked
                        // for a response (non-zero id); the event itself is
                        // moved into the handler below.
                        let ev = if id > 0 {
                            Some((event.frame().sender().to_owned(), processor_client.clone()))
                        } else {
                            None
                        };
                        let h = handlers.clone();
                        tokio::spawn(async move {
                            // The reply uses the realtime QoS variant when the
                            // request frame was realtime.
                            let qos = if event.frame().is_realtime() {
                                QoS::RealtimeProcessed
                            } else {
                                QoS::Processed
                            };
                            let res = h.handle_call(event).await;
                            if let Some((target, cl)) = ev {
                                macro_rules! send_reply {
                                    ($payload: expr, $result: expr) => {{
                                        let mut client = cl.lock().await;
                                        if let Some(result) = $result {
                                            client
                                                .zc_send(&target, $payload, result.into(), qos)
                                                .await
                                        } else {
                                            client
                                                .zc_send(&target, $payload, (&[][..]).into(), qos)
                                                .await
                                        }
                                    }};
                                }
                                match res {
                                    Ok(v) => {
                                        trace!("Sending RPC reply id {} to {}", id, target);
                                        // Reply header: type byte + call id.
                                        let mut payload = Vec::with_capacity(5);
                                        payload.push(RPC_REPLY);
                                        payload.extend_from_slice(&id.to_le_bytes());
                                        // The send result is intentionally
                                        // discarded here.
                                        let _r = send_reply!(payload.into(), v);
                                    }
                                    Err(e) => {
                                        trace!(
                                            "Sending RPC error {} reply id {} to {}",
                                            e.code,
                                            id,
                                            target,
                                        );
                                        // Error header: type byte + call id +
                                        // error code.
                                        let mut payload = Vec::with_capacity(7);
                                        payload.push(RPC_ERROR);
                                        payload.extend_from_slice(&id.to_le_bytes());
                                        payload.extend_from_slice(&e.code.to_le_bytes());
                                        // The send result is intentionally
                                        // discarded here.
                                        let _r = send_reply!(payload.into(), e.data);
                                    }
                                }
                            }
                        });
                    }
                    RpcEventKind::Reply | RpcEventKind::ErrorReply => {
                        let id = event.id();
                        trace!(
                            "RPC {} from {}, id: {}",
                            event.kind(),
                            event.frame().sender(),
                            id
                        );
                        // The registry lock is released before the event is
                        // handed over; a failed hand-over (receiver already
                        // gone) is ignored.
                        if let Some(tx) = { calls.lock().remove(&id) } {
                            let _r = tx.send(event);
                        } else {
                            warn!("orphaned RPC response: {}", id);
                        }
                    }
                },
                Err(e) => {
                    error!("{}", e);
                }
            }
        } else if opts.blocking_frames {
            handlers.handle_frame(frame).await;
        } else {
            let h = handlers.clone();
            tokio::spawn(async move {
                h.handle_frame(frame).await;
            });
        }
    }
}
435
436#[inline]
437fn prepare_call_payload(method: &str, id_bytes: &[u8]) -> Vec<u8> {
438 let m = method.as_bytes();
439 let mut payload = Vec::with_capacity(m.len() + 6);
440 payload.push(RPC_REQUEST);
441 payload.extend(id_bytes);
442 payload.extend(m);
443 payload.push(0x00);
444 payload
445}
446
447impl RpcClient {
448 pub fn new<H>(client: impl AsyncClient + 'static, handlers: H) -> Self
450 where
451 H: RpcHandlers + Send + Sync + 'static,
452 {
453 Self::init(client, handlers, Options::default())
454 }
455
456 pub fn new0(client: impl AsyncClient + 'static) -> Self {
458 Self::init(client, DummyHandlers {}, Options::default())
459 }
460
461 pub fn create<H>(client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
463 where
464 H: RpcHandlers + Send + Sync + 'static,
465 {
466 Self::init(client, handlers, opts)
467 }
468
469 pub fn create0(client: impl AsyncClient + 'static, opts: Options) -> Self {
471 Self::init(client, DummyHandlers {}, opts)
472 }
473
474 fn init<H>(mut client: impl AsyncClient + 'static, handlers: H, opts: Options) -> Self
475 where
476 H: RpcHandlers + Send + Sync + 'static,
477 {
478 let timeout = client.get_timeout();
479 let rx = { client.take_event_channel().unwrap() };
480 let connected = client.get_connected_beacon();
481 let client = Arc::new(Mutex::new(client));
482 let calls: CallMap = <_>::default();
483 let processor_fut = Arc::new(parking_lot::Mutex::new(tokio::spawn(processor(
484 rx,
485 client.clone(),
486 calls.clone(),
487 Arc::new(handlers),
488 opts,
489 ))));
490 let pinger_client = client.clone();
491 let pfut = processor_fut.clone();
492 let pinger_fut = timeout.map(|t| {
493 tokio::spawn(async move {
494 loop {
495 if let Err(e) = pinger_client.lock().await.ping().await {
496 error!("{}", e);
497 pfut.lock().abort();
498 break;
499 }
500 tokio::time::sleep(t).await;
501 }
502 })
503 });
504 Self {
505 call_id: parking_lot::Mutex::new(0),
506 timeout,
507 client,
508 processor_fut,
509 pinger_fut,
510 calls,
511 connected,
512 }
513 }
514}
515
516#[async_trait]
517impl Rpc for RpcClient {
518 #[inline]
519 fn client(&self) -> Arc<Mutex<(dyn AsyncClient + 'static)>> {
520 self.client.clone()
521 }
522 #[inline]
523 async fn notify(
524 &self,
525 target: &str,
526 data: Cow<'async_trait>,
527 qos: QoS,
528 ) -> Result<OpConfirm, Error> {
529 self.client
530 .lock()
531 .await
532 .zc_send(target, (&[RPC_NOTIFICATION][..]).into(), data, qos)
533 .await
534 }
535 async fn call0(
536 &self,
537 target: &str,
538 method: &str,
539 params: Cow<'async_trait>,
540 qos: QoS,
541 ) -> Result<OpConfirm, Error> {
542 let payload = prepare_call_payload(method, &[0, 0, 0, 0]);
543 self.client
544 .lock()
545 .await
546 .zc_send(target, payload.into(), params, qos)
547 .await
548 }
549 async fn call(
553 &self,
554 target: &str,
555 method: &str,
556 params: Cow<'async_trait>,
557 qos: QoS,
558 ) -> Result<RpcEvent, RpcError> {
559 let call_id = {
560 let mut ci = self.call_id.lock();
561 let mut call_id = *ci;
562 if call_id == u32::MAX {
563 call_id = 1;
564 } else {
565 call_id += 1;
566 }
567 *ci = call_id;
568 call_id
569 };
570 let payload = prepare_call_payload(method, &call_id.to_le_bytes());
571 let (tx, rx) = oneshot::channel();
572 self.calls.lock().insert(call_id, tx);
573 macro_rules! unwrap_or_cancel {
574 ($result: expr) => {
575 match $result {
576 Ok(v) => v,
577 Err(e) => {
578 self.calls.lock().remove(&call_id);
579 return Err(Into::<Error>::into(e).into());
580 }
581 }
582 };
583 }
584 let opc = {
585 let mut client = self.client.lock().await;
586 let fut = client.zc_send(target, payload.into(), params, qos);
587 if let Some(timeout) = self.timeout {
588 unwrap_or_cancel!(unwrap_or_cancel!(tokio::time::timeout(timeout, fut).await))
589 } else {
590 unwrap_or_cancel!(fut.await)
591 }
592 };
593 if let Some(c) = opc {
594 unwrap_or_cancel!(unwrap_or_cancel!(c.await));
595 }
596 let result = rx.await.map_err(Into::<Error>::into)?;
597 if let Ok(e) = RpcError::try_from(&result) {
598 Err(e)
599 } else {
600 Ok(result)
601 }
602 }
603 fn is_connected(&self) -> bool {
604 self.connected
605 .as_ref()
606 .map_or(true, |b| b.load(atomic::Ordering::SeqCst))
607 }
608}
609
610impl Drop for RpcClient {
611 fn drop(&mut self) {
612 self.pinger_fut.as_ref().map(JoinHandle::abort);
613 self.processor_fut.lock().abort();
614 }
615}
616
/// Error returned by RPC calls and by RPC handlers.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct RpcError {
    /// Numeric error code (sent as a little-endian `i16` in error replies).
    code: i16,
    /// Optional error payload; the helpers in this file fill it with a
    /// UTF-8 message, but any bytes are allowed.
    data: Option<Vec<u8>>,
}
623
624impl TryFrom<&RpcEvent> for RpcError {
625 type Error = Error;
626 #[inline]
627 fn try_from(event: &RpcEvent) -> Result<Self, Self::Error> {
628 if event.kind() == RpcEventKind::ErrorReply {
629 Ok(RpcError::new(event.code(), Some(event.payload().to_vec())))
630 } else {
631 Err(Error::data("not a RPC error"))
632 }
633 }
634}
635
636impl RpcError {
637 #[inline]
638 pub fn new(code: i16, data: Option<Vec<u8>>) -> Self {
639 Self { code, data }
640 }
641 #[inline]
642 pub fn code(&self) -> i16 {
643 self.code
644 }
645 #[inline]
646 pub fn data(&self) -> Option<&[u8]> {
647 self.data.as_deref()
648 }
649 #[inline]
650 pub fn method(err: Option<Vec<u8>>) -> Self {
651 Self {
652 code: RPC_ERROR_CODE_METHOD_NOT_FOUND,
653 data: err,
654 }
655 }
656 #[inline]
657 pub fn params(err: Option<Vec<u8>>) -> Self {
658 Self {
659 code: RPC_ERROR_CODE_INVALID_METHOD_PARAMS,
660 data: err,
661 }
662 }
663 #[inline]
664 pub fn parse(err: Option<Vec<u8>>) -> Self {
665 Self {
666 code: RPC_ERROR_CODE_PARSE,
667 data: err,
668 }
669 }
670 #[inline]
671 pub fn invalid(err: Option<Vec<u8>>) -> Self {
672 Self {
673 code: RPC_ERROR_CODE_INVALID_REQUEST,
674 data: err,
675 }
676 }
677 #[inline]
678 pub fn internal(err: Option<Vec<u8>>) -> Self {
679 Self {
680 code: RPC_ERROR_CODE_INTERNAL,
681 data: err,
682 }
683 }
684 #[inline]
686 pub fn convert_data(v: impl fmt::Display) -> Vec<u8> {
687 v.to_string().as_bytes().to_vec()
688 }
689}
690
691impl From<Error> for RpcError {
692 #[inline]
693 fn from(e: Error) -> RpcError {
694 RpcError {
695 code: -32000 - e.kind() as i16,
696 data: None,
697 }
698 }
699}
700
701impl From<rmp_serde::encode::Error> for RpcError {
702 #[inline]
703 fn from(e: rmp_serde::encode::Error) -> RpcError {
704 RpcError {
705 code: RPC_ERROR_CODE_INTERNAL,
706 data: Some(e.to_string().as_bytes().to_vec()),
707 }
708 }
709}
710
711impl From<std::io::Error> for RpcError {
712 #[inline]
713 fn from(e: std::io::Error) -> RpcError {
714 RpcError {
715 code: RPC_ERROR_CODE_INTERNAL,
716 data: Some(e.to_string().as_bytes().to_vec()),
717 }
718 }
719}
720
721impl From<rmp_serde::decode::Error> for RpcError {
722 #[inline]
723 fn from(e: rmp_serde::decode::Error) -> RpcError {
724 RpcError {
725 code: RPC_ERROR_CODE_PARSE,
726 data: Some(e.to_string().as_bytes().to_vec()),
727 }
728 }
729}
730
731impl fmt::Display for RpcError {
732 #[inline]
733 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
734 write!(f, "rpc error code: {}", self.code)
735 }
736}
737
/// Result of an RPC handler call: an optional reply payload on success, an
/// [`RpcError`] on failure.
#[allow(clippy::module_name_repetitions)]
pub type RpcResult = Result<Option<Vec<u8>>, RpcError>;