commonware_stream/utils/
codec.rs

1use crate::Error;
2use bytes::Bytes;
3use commonware_runtime::{Sink, Stream};
4
5/// Sends data to the sink with a 4-byte length prefix.
6/// Returns an error if the message is too large or the stream is closed.
7pub async fn send_frame<S: Sink>(
8    sink: &mut S,
9    buf: &[u8],
10    max_message_size: usize,
11) -> Result<(), Error> {
12    // Validate frame size
13    let n = buf.len();
14    if n == 0 {
15        return Err(Error::SendZeroSize);
16    }
17    if n > max_message_size {
18        return Err(Error::SendTooLarge(n));
19    }
20    let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?;
21
22    // Send the length of the message
23    let f: [u8; 4] = len.to_be_bytes();
24    sink.send(&f).await.map_err(|_| Error::SendFailed)?;
25
26    // Send the rest of the message
27    sink.send(buf).await.map_err(|_| Error::SendFailed)?;
28
29    Ok(())
30}
31
32/// Receives data from the stream with a 4-byte length prefix.
33/// Returns an error if the message is too large or the stream is closed.
34pub async fn recv_frame<T: Stream>(
35    stream: &mut T,
36    max_message_size: usize,
37) -> Result<Bytes, Error> {
38    // Read the first 4 bytes to get the length of the message
39    let mut buf = [0u8; 4];
40    stream.recv(&mut buf).await.map_err(|_| Error::RecvFailed)?;
41
42    // Validate frame size
43    let len = u32::from_be_bytes(buf) as usize;
44    if len > max_message_size {
45        return Err(Error::RecvTooLarge(len));
46    }
47    if len == 0 {
48        return Err(Error::StreamClosed);
49    }
50
51    // Read the rest of the message
52    let mut buf = vec![0u8; len];
53    stream.recv(&mut buf).await.map_err(|_| Error::RecvFailed)?;
54
55    Ok(Bytes::from(buf))
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic::Executor, mocks, Runner};
    use rand::Rng;

    /// Frame size limit shared by every test in this module.
    const MAX_MESSAGE_SIZE: usize = 1024;

    /// A payload of exactly the maximum size round-trips through
    /// `send_frame` and `recv_frame`.
    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let (executor, mut context, _) = Executor::default();
        executor.start(async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut payload);

            assert!(send_frame(&mut sink, &payload, MAX_MESSAGE_SIZE)
                .await
                .is_ok());

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), payload.len());
            assert_eq!(received, Bytes::from(payload.to_vec()));
        });
    }

    /// Two frames of different sizes are received in the order they were sent.
    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let (executor, mut context, _) = Executor::default();
        executor.start(async move {
            let mut first = [0u8; MAX_MESSAGE_SIZE];
            let mut second = [0u8; MAX_MESSAGE_SIZE / 2];
            context.fill(&mut first);
            context.fill(&mut second);

            // Queue both frames before reading anything back.
            assert!(send_frame(&mut sink, &first, MAX_MESSAGE_SIZE)
                .await
                .is_ok());
            assert!(send_frame(&mut sink, &second, MAX_MESSAGE_SIZE)
                .await
                .is_ok());

            // Frames must come out in the order they went in.
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), first.len());
            assert_eq!(received, Bytes::from(first.to_vec()));

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), second.len());
            assert_eq!(received, Bytes::from(second.to_vec()));
        });
    }

    /// `send_frame` emits a big-endian `u32` length followed by the payload,
    /// verified by reading the raw stream without `recv_frame`.
    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let (executor, mut context, _) = Executor::default();
        executor.start(async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut payload);

            assert!(send_frame(&mut sink, &payload, MAX_MESSAGE_SIZE)
                .await
                .is_ok());

            // Decode by hand: first the 4-byte prefix ...
            let mut raw_prefix = [0u8; 4];
            stream.recv(&mut raw_prefix).await.unwrap();
            assert_eq!(raw_prefix, (payload.len() as u32).to_be_bytes());

            // ... then the payload bytes.
            let mut raw_payload = [0u8; MAX_MESSAGE_SIZE];
            stream.recv(&mut raw_payload).await.unwrap();
            assert_eq!(raw_payload, payload);
        });
    }

    /// A payload one byte over the limit is rejected with `SendTooLarge`
    /// carrying the payload length.
    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let (executor, mut context, _) = Executor::default();
        executor.start(async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut payload);

            let outcome = send_frame(&mut sink, &payload, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&outcome, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    /// An empty payload is rejected with `SendZeroSize`.
    #[test]
    fn test_send_zero_size() {
        let (mut sink, _) = mocks::Channel::init();

        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let empty: [u8; 0] = [];
            let outcome = send_frame(&mut sink, &empty, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&outcome, Err(Error::SendZeroSize)));
        });
    }

    /// `recv_frame` decodes a frame that was written to the sink by hand.
    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let (executor, mut context, _) = Executor::default();
        executor.start(async move {
            // Encode by hand, bypassing send_frame.
            let mut payload = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut payload);
            let prefix = (MAX_MESSAGE_SIZE as u32).to_be_bytes();
            sink.send(&prefix).await.unwrap();
            sink.send(&payload).await.unwrap();

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), payload.len());
            assert_eq!(received, Bytes::from(payload.to_vec()));
        });
    }

    /// A length prefix above the receiver's limit yields `RecvTooLarge`
    /// carrying the announced length.
    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let (executor, _, _) = Executor::default();
        executor.start(async move {
            // Only the prefix is written; it announces MAX_MESSAGE_SIZE bytes.
            let prefix = (MAX_MESSAGE_SIZE as u32).to_be_bytes();
            sink.send(&prefix).await.unwrap();

            let outcome = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&outcome, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    /// A zero length prefix is reported as `StreamClosed`.
    #[test]
    fn test_read_zero_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let (executor, _, _) = Executor::default();
        executor.start(async move {
            // Only the prefix is written; it announces a zero-length frame.
            let prefix = (0u32).to_be_bytes();
            sink.send(&prefix).await.unwrap();

            let outcome = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&outcome, Err(Error::StreamClosed)));
        });
    }
}