commonware_stream/utils/codec.rs

1use crate::Error;
2use bytes::{BufMut as _, Bytes, BytesMut};
3use commonware_runtime::{Sink, Stream};
4
5/// Sends data to the sink with a 4-byte length prefix.
6/// Returns an error if the message is too large or the stream is closed.
7pub async fn send_frame<S: Sink>(
8    sink: &mut S,
9    buf: &[u8],
10    max_message_size: usize,
11) -> Result<(), Error> {
12    // Validate frame size
13    let n = buf.len();
14    if n > max_message_size {
15        return Err(Error::SendTooLarge(n));
16    }
17
18    // Prefix `buf` with its length and send it
19    let mut prefixed_buf = BytesMut::with_capacity(4 + buf.len());
20    let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?;
21    prefixed_buf.put_u32(len);
22    prefixed_buf.extend_from_slice(buf);
23    sink.send(prefixed_buf).await.map_err(Error::SendFailed)
24}
25
26/// Receives data from the stream with a 4-byte length prefix.
27/// Returns an error if the message is too large or the stream is closed.
28pub async fn recv_frame<T: Stream>(
29    stream: &mut T,
30    max_message_size: usize,
31) -> Result<Bytes, Error> {
32    // Read the first 4 bytes to get the length of the message
33    let len_buf = stream.recv(vec![0; 4]).await.map_err(Error::RecvFailed)?;
34
35    // Validate frame size
36    let len = u32::from_be_bytes(len_buf.as_ref()[..4].try_into().unwrap()) as usize;
37    if len > max_message_size {
38        return Err(Error::RecvTooLarge(len));
39    }
40
41    // Read the rest of the message
42    let read = stream.recv(vec![0; len]).await.map_err(Error::RecvFailed)?;
43    Ok(read.into())
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: usize = 1024;

    /// A payload of exactly `MAX_MESSAGE_SIZE` bytes round-trips intact.
    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            // `unwrap` (rather than `assert!(is_ok())`) prints the error on failure.
            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    /// Consecutive frames of different sizes are received in order and intact.
    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE];
            let mut buf2 = [0u8; MAX_MESSAGE_SIZE / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            send_frame(&mut sink, &buf1, MAX_MESSAGE_SIZE).await.unwrap();
            send_frame(&mut sink, &buf2, MAX_MESSAGE_SIZE).await.unwrap();

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data, Bytes::from(buf1.to_vec()));
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data, Bytes::from(buf2.to_vec()));
        });
    }

    /// `send_frame` writes a big-endian `u32` length prefix followed by the raw payload.
    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            // Do the reading manually without using recv_frame
            let read = stream.recv(vec![0; 4]).await.unwrap();
            assert_eq!(read.as_ref(), (buf.len() as u32).to_be_bytes());
            let read = stream.recv(vec![0; MAX_MESSAGE_SIZE]).await.unwrap();
            assert_eq!(read.as_ref(), buf);
        });
    }

    /// `send_frame` rejects a payload one byte over the limit and reports its length.
    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    /// `recv_frame` decodes a frame that was built by hand.
    #[test]
    fn test_recv_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut msg);

            let mut buf = BytesMut::with_capacity(4 + msg.len());
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            buf.extend_from_slice(&msg);
            sink.send(buf).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE);
            assert_eq!(data, msg.as_ref());
        });
    }

    /// `recv_frame` rejects a length prefix one byte over the limit.
    #[test]
    fn test_recv_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            sink.send(buf).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    /// A stream that closes part-way through the length prefix yields `RecvFailed`.
    #[test]
    fn test_recv_frame_short_length_prefix() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame with a short length prefix
            let mut buf = BytesMut::with_capacity(3);
            buf.put_u8(0x00);
            buf.put_u8(0x00);
            buf.put_u8(0x00);

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error rather than a panic
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    /// A stream that closes part-way through the payload yields `RecvFailed`.
    #[test]
    fn test_recv_frame_truncated_payload() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Announce a full-size payload but only deliver half of it
            let mut buf = BytesMut::with_capacity(4 + MAX_MESSAGE_SIZE / 2);
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            buf.extend_from_slice(&[0u8; MAX_MESSAGE_SIZE / 2]);

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink before the payload is complete

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }
}
191}