commonware_p2p/authenticated/
channels.rs

1use super::{actors::Messenger, Error};
2use crate::{Channel, Message, Recipients};
3use bytes::Bytes;
4use commonware_utils::Array;
5use futures::{channel::mpsc, StreamExt};
6use governor::Quota;
7use std::collections::BTreeMap;
8use zstd::bulk::{compress, decompress};
9
10/// Sender is the mechanism used to send arbitrary bytes to
11/// a set of recipients over a pre-defined channel.
12#[derive(Clone, Debug)]
13pub struct Sender<P: Array> {
14    channel: Channel,
15    max_size: usize,
16    compression: Option<i32>,
17    messenger: Messenger<P>,
18}
19
20impl<P: Array> Sender<P> {
21    pub(super) fn new(
22        channel: Channel,
23        max_size: usize,
24        compression: Option<i32>,
25        messenger: Messenger<P>,
26    ) -> Self {
27        Self {
28            channel,
29            max_size,
30            compression,
31            messenger,
32        }
33    }
34}
35
36impl<P: Array> crate::Sender for Sender<P> {
37    type Error = Error;
38    type PublicKey = P;
39
40    /// Sends a message to a set of recipients.
41    ///
42    /// # Offline Recipients
43    ///
44    /// If a recipient is offline at the time a message is sent, the message will be dropped.
45    /// It is up to the application to handle retries (if necessary).
46    ///
47    /// # Parameters
48    ///
49    /// * `recipients` - The set of recipients to send the message to.
50    /// * `message` - The message to send.
51    /// * `priority` - Whether the message should be sent with priority (across
52    ///   all channels).
53    ///
54    /// # Returns
55    ///
56    /// If the message can be compressed (if enabled) and the message is `< max_size`, The set of recipients
57    /// that the message was sent to. Note, a successful send does not mean that the recipient will
58    /// receive the message (connection may no longer be active and we may not know that yet).
59    async fn send(
60        &mut self,
61        recipients: Recipients<Self::PublicKey>,
62        mut message: Bytes,
63        priority: bool,
64    ) -> Result<Vec<Self::PublicKey>, Error> {
65        // If compression is enabled, compress the message before sending.
66        if let Some(level) = self.compression {
67            let compressed = compress(&message, level).map_err(|_| Error::CompressionFailed)?;
68            message = compressed.into();
69        }
70
71        // Ensure message isn't too large
72        let message_len = message.len();
73        if message_len > self.max_size {
74            return Err(Error::MessageTooLarge(message_len));
75        }
76
77        // Wait for messenger to let us know who we sent to
78        Ok(self
79            .messenger
80            .content(recipients, self.channel, message, priority)
81            .await)
82    }
83}
84
85/// Channel to asynchronously receive messages from a channel.
86#[derive(Debug)]
87pub struct Receiver<P: Array> {
88    max_size: usize,
89    compression: bool,
90    receiver: mpsc::Receiver<Message<P>>,
91}
92
93impl<P: Array> Receiver<P> {
94    pub(super) fn new(
95        max_size: usize,
96        compression: bool,
97        receiver: mpsc::Receiver<Message<P>>,
98    ) -> Self {
99        Self {
100            max_size,
101            compression,
102            receiver,
103        }
104    }
105}
106
107impl<P: Array> crate::Receiver for Receiver<P> {
108    type Error = Error;
109    type PublicKey = P;
110
111    /// Receives a message from the channel.
112    ///
113    /// This method will block until a message is received or the underlying
114    /// network shuts down.
115    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
116        let (sender, mut message) = self.receiver.next().await.ok_or(Error::NetworkClosed)?;
117
118        // If compression is enabled, decompress the message before returning.
119        if self.compression {
120            let buf =
121                decompress(&message, self.max_size).map_err(|_| Error::DecompressionFailed)?;
122            message = buf.into();
123        }
124
125        // We don't check that the message is too large here because we already enforce
126        // that on the network layer.
127        Ok((sender, message))
128    }
129}
130
131#[derive(Clone)]
132pub struct Channels<P: Array> {
133    messenger: Messenger<P>,
134    max_size: usize,
135    receivers: BTreeMap<Channel, (Quota, mpsc::Sender<Message<P>>)>,
136}
137
138impl<P: Array> Channels<P> {
139    pub fn new(messenger: Messenger<P>, max_size: usize) -> Self {
140        Self {
141            messenger,
142            max_size,
143            receivers: BTreeMap::new(),
144        }
145    }
146
147    pub fn register(
148        &mut self,
149        channel: Channel,
150        rate: governor::Quota,
151        backlog: usize,
152        compression: Option<i32>,
153    ) -> (Sender<P>, Receiver<P>) {
154        let (sender, receiver) = mpsc::channel(backlog);
155        if self.receivers.insert(channel, (rate, sender)).is_some() {
156            panic!("duplicate channel registration: {}", channel);
157        }
158        (
159            Sender::new(channel, self.max_size, compression, self.messenger.clone()),
160            Receiver::new(self.max_size, compression.is_some(), receiver),
161        )
162    }
163
164    pub fn collect(self) -> BTreeMap<u32, (Quota, mpsc::Sender<Message<P>>)> {
165        self.receivers
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressing and then decompressing a payload must return the original bytes.
    #[test]
    fn test_compression() {
        let original = b"hello world";
        let packed = compress(original, 3).unwrap();
        // The original length is enough capacity for the decompressed output.
        let unpacked = decompress(&packed, original.len()).unwrap();
        assert_eq!(original, unpacked.as_slice());
    }
}