1use std::{
14 mem::size_of,
15 pin::Pin,
16 task::{Context, Poll},
17};
18
19use futures::{Sink, Stream};
20use serde::{Deserialize, Serialize};
21use tarpc::transport::channel::ChannelError;
22use tokio::sync::mpsc;
23
24pub mod agent;
25pub mod control;
26pub mod error;
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub enum MuxMessage<Parent, Child> {
30 Parent(Parent),
31 Child(Child),
32}
33
/// Declares the `MuxedMessageIncoming` and `MuxedMessageOutgoing` type aliases for one end of a
/// multiplexed tarpc connection.
///
/// The first token selects the end (`parent` or `child`). It is followed by the parent's
/// request/response types and then the child's request/response types:
///
/// ```ignore
/// define_rpc_mux!(parent; ParentReq => ParentRes; ChildReq => ChildRes);
/// ```
///
/// The two arms mirror each other. The `parent` arm's `MuxedMessageIncoming` is the same type as
/// the `child` arm's `MuxedMessageOutgoing`, and vice versa.
///
/// `MuxMessage` is resolved through `$crate`, so the macro also works inside this crate and under
/// a renamed dependency. The tarpc types are named as `::tarpc`, so the invoking crate must depend
/// on `tarpc` directly.
#[macro_export]
macro_rules! define_rpc_mux {
    ( parent ; $parent_req:ty => $parent_res:ty ; $child_req:ty => $child_res:ty $(;)? ) => {
        /// Messages the parent receives: client messages carrying parent requests, and responses
        /// carrying child results.
        pub type MuxedMessageIncoming = $crate::rpc::MuxMessage<
            ::tarpc::ClientMessage<$parent_req>,
            ::tarpc::Response<$child_res>,
        >;

        /// Messages the parent sends: responses carrying parent results, and client messages
        /// carrying child requests.
        pub type MuxedMessageOutgoing = $crate::rpc::MuxMessage<
            ::tarpc::Response<$parent_res>,
            ::tarpc::ClientMessage<$child_req>,
        >;
    };
    ( child ; $parent_req:ty => $parent_res:ty ; $child_req:ty => $child_res:ty $(;)? ) => {
        /// Messages the child receives: responses carrying parent results, and client messages
        /// carrying child requests.
        pub type MuxedMessageIncoming = $crate::rpc::MuxMessage<
            ::tarpc::Response<$parent_res>,
            ::tarpc::ClientMessage<$child_req>,
        >;

        /// Messages the child sends: client messages carrying parent requests, and responses
        /// carrying child results.
        pub type MuxedMessageOutgoing = $crate::rpc::MuxMessage<
            ::tarpc::ClientMessage<$parent_req>,
            ::tarpc::Response<$child_res>,
        >;
    };
}
63
/// Byte length of one `u128` plus one `u32` (16 + 4 = 20 bytes).
///
/// NOTE(review): presumably the size of an encoded ping message; confirm against the ping
/// encoder/decoder.
pub const PING_LENGTH: usize = size_of::<u128>() + size_of::<u32>();

/// Ping interval; the `_SEC` suffix indicates the unit is seconds.
pub const PING_INTERVAL_SEC: u64 = 10;
66
67pub struct RpcTransport<In, Out> {
68 tx: mpsc::UnboundedSender<Out>,
69 rx: mpsc::UnboundedReceiver<In>,
70}
71
72impl<In, Out> RpcTransport<In, Out> {
73 pub fn new() -> (
78 mpsc::UnboundedSender<In>,
79 Self,
80 mpsc::UnboundedReceiver<Out>,
81 ) {
82 let (tx1, rx1) = mpsc::unbounded_channel();
83 let (tx2, rx2) = mpsc::unbounded_channel();
84 (tx1, Self { tx: tx2, rx: rx1 }, rx2)
85 }
86}
87
88impl<In, Out> Stream for RpcTransport<In, Out> {
89 type Item = Result<In, ChannelError>;
90
91 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92 self.rx
93 .poll_recv(cx)
94 .map(|o| o.map(Ok))
95 .map_err(ChannelError::Receive)
96 }
97}
98
99const CLOSED_MESSAGE: &str = "the channel is closed";
100
101impl<In, Out> Sink<Out> for RpcTransport<In, Out> {
102 type Error = ChannelError;
103
104 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105 Poll::Ready(if self.tx.is_closed() {
106 Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
107 } else {
108 Ok(())
109 })
110 }
111
112 fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
113 self.tx
114 .send(item)
115 .map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
116 }
117
118 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119 Poll::Ready(Ok(()))
120 }
121
122 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123 Poll::Ready(Ok(()))
124 }
125}