commonware_p2p/authenticated/
channels.rs1use super::{actors::Messenger, Error};
2use crate::{Channel, Message, Recipients};
3use bytes::Bytes;
4use commonware_utils::Array;
5use futures::{channel::mpsc, StreamExt};
6use governor::Quota;
7use std::collections::BTreeMap;
8use zstd::bulk::{compress, decompress};
9
10#[derive(Clone, Debug)]
13pub struct Sender<P: Array> {
14 channel: Channel,
15 max_size: usize,
16 compression: Option<i32>,
17 messenger: Messenger<P>,
18}
19
20impl<P: Array> Sender<P> {
21 pub(super) fn new(
22 channel: Channel,
23 max_size: usize,
24 compression: Option<i32>,
25 messenger: Messenger<P>,
26 ) -> Self {
27 Self {
28 channel,
29 max_size,
30 compression,
31 messenger,
32 }
33 }
34}
35
36impl<P: Array> crate::Sender for Sender<P> {
37 type Error = Error;
38 type PublicKey = P;
39
40 async fn send(
60 &mut self,
61 recipients: Recipients<Self::PublicKey>,
62 mut message: Bytes,
63 priority: bool,
64 ) -> Result<Vec<Self::PublicKey>, Error> {
65 if let Some(level) = self.compression {
67 let compressed = compress(&message, level).map_err(|_| Error::CompressionFailed)?;
68 message = compressed.into();
69 }
70
71 let message_len = message.len();
73 if message_len > self.max_size {
74 return Err(Error::MessageTooLarge(message_len));
75 }
76
77 Ok(self
79 .messenger
80 .content(recipients, self.channel, message, priority)
81 .await)
82 }
83}
84
85#[derive(Debug)]
87pub struct Receiver<P: Array> {
88 max_size: usize,
89 compression: bool,
90 receiver: mpsc::Receiver<Message<P>>,
91}
92
93impl<P: Array> Receiver<P> {
94 pub(super) fn new(
95 max_size: usize,
96 compression: bool,
97 receiver: mpsc::Receiver<Message<P>>,
98 ) -> Self {
99 Self {
100 max_size,
101 compression,
102 receiver,
103 }
104 }
105}
106
107impl<P: Array> crate::Receiver for Receiver<P> {
108 type Error = Error;
109 type PublicKey = P;
110
111 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
116 let (sender, mut message) = self.receiver.next().await.ok_or(Error::NetworkClosed)?;
117
118 if self.compression {
120 let buf =
121 decompress(&message, self.max_size).map_err(|_| Error::DecompressionFailed)?;
122 message = buf.into();
123 }
124
125 Ok((sender, message))
128 }
129}
130
131#[derive(Clone)]
132pub struct Channels<P: Array> {
133 messenger: Messenger<P>,
134 max_size: usize,
135 receivers: BTreeMap<Channel, (Quota, mpsc::Sender<Message<P>>)>,
136}
137
138impl<P: Array> Channels<P> {
139 pub fn new(messenger: Messenger<P>, max_size: usize) -> Self {
140 Self {
141 messenger,
142 max_size,
143 receivers: BTreeMap::new(),
144 }
145 }
146
147 pub fn register(
148 &mut self,
149 channel: Channel,
150 rate: governor::Quota,
151 backlog: usize,
152 compression: Option<i32>,
153 ) -> (Sender<P>, Receiver<P>) {
154 let (sender, receiver) = mpsc::channel(backlog);
155 if self.receivers.insert(channel, (rate, sender)).is_some() {
156 panic!("duplicate channel registration: {}", channel);
157 }
158 (
159 Sender::new(channel, self.max_size, compression, self.messenger.clone()),
160 Receiver::new(self.max_size, compression.is_some(), receiver),
161 )
162 }
163
164 pub fn collect(self) -> BTreeMap<u32, (Quota, mpsc::Sender<Message<P>>)> {
165 self.receivers
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a small payload through zstd compression and decompression.
    #[test]
    fn test_compression() {
        let original = b"hello world";

        // Compress at level 3, then decompress with exactly enough capacity.
        let packed = compress(original, 3).unwrap();
        let restored = decompress(&packed, original.len()).unwrap();

        assert_eq!(original, &restored[..]);
    }
}