canary/channel/encrypted/bidirectional.rs

1use std::sync::Arc;
2
3use derive_more::From;
4use serde::{de::DeserializeOwned, Serialize};
5use snow::StatelessTransportState;
6
7use crate::{
8    async_snow::RefDividedSnow,
9    channel::raw::{
10        joint::unformatted::RefUnformattedRawChannel,
11        unified::unformatted::UnformattedRawUnifiedChannel,
12    },
13    serialization::formats::{Format, ReadFormat, SendFormat},
14    Result,
15};
16
17use super::{
18    bipartite::{BipartiteChannel, UnformattedBipartiteChannel},
19    receive_channel::{ReceiveChannel, UnformattedReceiveChannel},
20    send_channel::{SendChannel, UnformattedSendChannel},
21    snowwith::WithCipher,
22    unified::{UnformattedUnifiedChannel, UnifiedChannel},
23};
24
25#[derive(From)]
26/// Reference unformatted bidirectional channel, may be encrypted
27pub enum RefUnformattedBidirectionalChannel<'a> {
28    /// Unencrypted channel
29    Raw(RefUnformattedRawChannel<'a>),
30    /// Encrypted channel
31    Encrypted(
32        RefUnformattedRawChannel<'a>,
33        &'a StatelessTransportState,
34        &'a mut u32,
35    ),
36}
37
38#[derive(From)]
39/// Unformatted bidirectional channel which may be unified or bipartite
40pub enum UnformattedBidirectionalChannel {
41    /// Channel has not been split
42    Unified(UnformattedUnifiedChannel),
43    /// Channel has been split
44    Bipartite(UnformattedBipartiteChannel),
45}
46
47#[derive(From)]
48/// Reference channel with formats, similar to `&Channel`
49pub struct RefChannel<'a, R = Format, W = Format> {
50    /// Inner channel
51    channel: RefUnformattedBidirectionalChannel<'a>,
52    /// Inner receive format
53    receive_format: R,
54    /// Inner send format
55    send_format: W,
56}
57
58#[derive(From)]
59/// Channel with formats
60pub enum Channel<R = Format, W = Format> {
61    /// Channel has not been split
62    Unified(UnifiedChannel<R, W>),
63    /// Channel has been split
64    Bipartite(BipartiteChannel<R, W>),
65}
66
67impl<'a, R, W> RefChannel<'a, R, W> {
68    /// Send an object through the channel
69    /// ```no_run
70    /// chan.send("Hello world!").await?;
71    /// ```
72    pub async fn send<T: Serialize>(&mut self, obj: T) -> Result<usize>
73    where
74        W: SendFormat,
75    {
76        self.channel.send(obj, &mut self.send_format).await
77    }
78    /// Receive an object sent through the channel
79    /// ```no_run
80    /// let string: String = chan.receive().await?;
81    /// ```
82    pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<T>
83    where
84        R: ReadFormat,
85    {
86        self.channel.receive(&mut self.receive_format).await
87    }
88}
89
90impl<R, W> Channel<R, W> {
91    pub(crate) fn from_raw(
92        raw: impl Into<UnformattedRawUnifiedChannel>,
93        receive_format: R,
94        send_format: W,
95    ) -> Self {
96        Self::Unified(UnifiedChannel {
97            channel: UnformattedUnifiedChannel::Raw(raw.into()),
98            receive_format,
99            send_format,
100        })
101    }
102
103    /// Try to encrypt channel using the provided transport.
104    /// Will return an error if channel is already encrypted.
105    /// To turn `Arc<StatelessTransportState>` into the inner transport state
106    /// use `Arc::try_unwrap(transport)`.
107    pub fn encrypt(
108        &mut self,
109        transport: StatelessTransportState,
110    ) -> Result<(), Arc<StatelessTransportState>> {
111        match self {
112            Channel::Unified(unified) => unified.encrypt(transport).map_err(Arc::new),
113            Channel::Bipartite(bipartite) => bipartite.encrypt(Arc::new(transport)),
114        }
115    }
116
117    /// Send an object through the channel
118    /// ```no_run
119    /// chan.send("Hello world!").await?;
120    /// ```
121    pub async fn send<T: Serialize>(&mut self, obj: T) -> Result<usize>
122    where
123        W: SendFormat,
124    {
125        match self {
126            Channel::Unified(chan) => chan.send(obj).await,
127            Channel::Bipartite(chan) => chan.send(obj).await,
128        }
129    }
130    /// Receive an object sent through the channel
131    /// ```no_run
132    /// let string: String = chan.receive().await?;
133    /// ```
134    pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<T>
135    where
136        R: ReadFormat,
137    {
138        match self {
139            Channel::Unified(chan) => chan.receive().await,
140            Channel::Bipartite(chan) => chan.receive().await,
141        }
142    }
143    #[must_use]
144    /// Split channel into its send and receive components
145    pub fn split(self) -> (SendChannel<W>, ReceiveChannel<R>) {
146        match self {
147            Channel::Unified(chan) => chan.split(),
148            Channel::Bipartite(chan) => chan.split(),
149        }
150    }
151    /// Join send and receive channels into a channel
152    pub fn join(send: SendChannel<W>, receive: ReceiveChannel<R>) -> Self {
153        Self::Bipartite(BipartiteChannel {
154            receive_channel: receive,
155            send_channel: send,
156        })
157    }
158}
159
160impl<'a> RefUnformattedBidirectionalChannel<'a> {
161    /// Send an object through the channel serialized with format
162    /// ```no_run
163    /// chan.send("Hello world!", &mut Format::Bincode).await?;
164    /// ```
165    pub async fn send<T: Serialize, F: SendFormat>(
166        &mut self,
167        obj: T,
168        format: &mut F,
169    ) -> Result<usize> {
170        match self {
171            Self::Raw(chan) => chan.send(obj, format).await,
172            Self::Encrypted(chan, snow, nonce) => {
173                let ref mut snow = RefDividedSnow {
174                    transport: snow,
175                    nonce,
176                };
177                let mut with = WithCipher { snow, format };
178                chan.send(obj, &mut with).await
179            }
180        }
181    }
182    /// Receive an object sent through the channel with format
183    /// ```no_run
184    /// let string: String = chan.receive(&mut Format::Bincode).await?;
185    /// ```
186    pub async fn receive<T: DeserializeOwned, F: ReadFormat>(
187        &mut self,
188        format: &mut F,
189    ) -> Result<T> {
190        match self {
191            Self::Raw(chan) => chan.receive(format).await,
192            Self::Encrypted(chan, snow, nonce) => {
193                let ref mut snow = RefDividedSnow {
194                    transport: snow,
195                    nonce,
196                };
197                let mut with = WithCipher { snow, format };
198                chan.receive(&mut with).await
199            }
200        }
201    }
202
203    /// Returns `true` if the ref unformatted bidirectional channel is [`Encrypted`].
204    ///
205    /// [`Encrypted`]: RefUnformattedBidirectionalChannel::Encrypted
206    #[must_use]
207    pub fn is_encrypted(&self) -> bool {
208        matches!(self, Self::Encrypted(..))
209    }
210}
211
212impl UnformattedBidirectionalChannel {
213    /// Send an object through the channel serialized with format
214    /// ```no_run
215    /// chan.send("Hello world!", &mut Format::Bincode).await?;
216    /// ```
217    pub async fn send<T: Serialize, F: SendFormat>(
218        &mut self,
219        obj: T,
220        format: &mut F,
221    ) -> Result<usize> {
222        match self {
223            Self::Unified(chan) => chan.send(obj, format).await,
224            Self::Bipartite(chan) => chan.send(obj, format).await,
225        }
226    }
227    /// Receive an object sent through the channel with format
228    /// ```no_run
229    /// let string: String = chan.receive(&mut Format::Bincode).await?;
230    /// ```
231    pub async fn receive<T: DeserializeOwned, F: ReadFormat>(
232        &mut self,
233        format: &mut F,
234    ) -> Result<T> {
235        match self {
236            UnformattedBidirectionalChannel::Unified(chan) => chan.receive(format).await,
237            UnformattedBidirectionalChannel::Bipartite(chan) => chan.receive(format).await,
238        }
239    }
240    #[must_use]
241    /// Split channel into its send and receive components
242    pub fn split(self) -> (UnformattedSendChannel, UnformattedReceiveChannel) {
243        match self {
244            UnformattedBidirectionalChannel::Unified(chan) => chan.split(),
245            UnformattedBidirectionalChannel::Bipartite(chan) => chan.split(),
246        }
247    }
248}