commonware_stream/utils/
codec.rs

1use crate::Error;
2use bytes::{BufMut as _, Bytes, BytesMut};
3use commonware_runtime::{Sink, Stream};
4
5/// Sends data to the sink with a 4-byte length prefix.
6/// Returns an error if the message is too large or the stream is closed.
7pub async fn send_frame<S: Sink>(
8    sink: &mut S,
9    buf: &[u8],
10    max_message_size: usize,
11) -> Result<(), Error> {
12    // Validate frame size
13    let n = buf.len();
14    if n == 0 {
15        return Err(Error::SendZeroSize);
16    }
17    if n > max_message_size {
18        return Err(Error::SendTooLarge(n));
19    }
20
21    // Prefix `buf` with its length and send it
22    let mut prefixed_buf = BytesMut::with_capacity(4 + buf.len());
23    let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?;
24    prefixed_buf.put_u32(len);
25    prefixed_buf.extend_from_slice(buf);
26    sink.send(prefixed_buf.freeze())
27        .await
28        .map_err(Error::SendFailed)
29}
30
31/// Receives data from the stream with a 4-byte length prefix.
32/// Returns an error if the message is too large or the stream is closed.
33pub async fn recv_frame<T: Stream>(
34    stream: &mut T,
35    max_message_size: usize,
36) -> Result<Bytes, Error> {
37    // Read the first 4 bytes to get the length of the message
38    let len_buf = stream.recv(vec![0; 4]).await.map_err(Error::RecvFailed)?;
39
40    // Validate frame size
41    let len = u32::from_be_bytes(len_buf[..4].try_into().unwrap()) as usize;
42    if len > max_message_size {
43        return Err(Error::RecvTooLarge(len));
44    }
45    if len == 0 {
46        return Err(Error::StreamClosed);
47    }
48
49    // Read the rest of the message
50    let read = stream.recv(vec![0; len]).await.map_err(Error::RecvFailed)?;
51    Ok(read.into())
52}
53
#[cfg(test)]
mod tests {
    //! Round-trip and boundary tests for `send_frame` / `recv_frame` over an
    //! in-memory mock channel.

    use super::*;
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: usize = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_single_byte() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Smallest valid frame: one byte of payload
            let buf = [0xABu8];
            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), 1);
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE];
            let mut buf2 = [0u8; MAX_MESSAGE_SIZE / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            send_frame(&mut sink, &buf1, MAX_MESSAGE_SIZE).await.unwrap();
            send_frame(&mut sink, &buf2, MAX_MESSAGE_SIZE).await.unwrap();

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data, Bytes::from(buf1.to_vec()));
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data, Bytes::from(buf2.to_vec()));
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            // Do the reading manually without using recv_frame
            let read = stream.recv(vec![0; 4]).await.unwrap();
            assert_eq!(read, (buf.len() as u32).to_be_bytes());
            let read = stream.recv(vec![0; MAX_MESSAGE_SIZE]).await.unwrap();
            assert_eq!(read, buf);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    #[test]
    fn test_send_zero_size() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let buf = [];
            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::SendZeroSize)));
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut msg);

            let mut buf = BytesMut::with_capacity(4 + msg.len());
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            buf.extend_from_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE);
            assert_eq!(data, msg.as_ref());
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    #[test]
    fn test_read_zero_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives zero as the size
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(0);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::StreamClosed)));
        });
    }
}