commonware_stream/utils/codec.rs

1use crate::Error;
2use bytes::{BufMut as _, Bytes, BytesMut};
3use commonware_runtime::{Sink, Stream};
4
5/// Sends data to the sink with a 4-byte length prefix.
6/// Returns an error if the message is too large or the stream is closed.
7pub async fn send_frame<S: Sink>(
8    sink: &mut S,
9    buf: &[u8],
10    max_message_size: usize,
11) -> Result<(), Error> {
12    // Validate frame size
13    let n = buf.len();
14    if n == 0 {
15        return Err(Error::SendZeroSize);
16    }
17    if n > max_message_size {
18        return Err(Error::SendTooLarge(n));
19    }
20
21    // Prefix `buf` with its length and send it
22    let mut prefixed_buf = BytesMut::with_capacity(4 + buf.len());
23    let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?;
24    prefixed_buf.put_u32(len);
25    prefixed_buf.extend_from_slice(buf);
26    sink.send(prefixed_buf).await.map_err(Error::SendFailed)
27}
28
29/// Receives data from the stream with a 4-byte length prefix.
30/// Returns an error if the message is too large or the stream is closed.
31pub async fn recv_frame<T: Stream>(
32    stream: &mut T,
33    max_message_size: usize,
34) -> Result<Bytes, Error> {
35    // Read the first 4 bytes to get the length of the message
36    let len_buf = stream.recv(vec![0; 4]).await.map_err(Error::RecvFailed)?;
37
38    // Validate frame size
39    let len = u32::from_be_bytes(len_buf.as_ref()[..4].try_into().unwrap()) as usize;
40    if len > max_message_size {
41        return Err(Error::RecvTooLarge(len));
42    }
43    if len == 0 {
44        return Err(Error::StreamClosed);
45    }
46
47    // Read the rest of the message
48    let read = stream.recv(vec![0; len]).await.map_err(Error::RecvFailed)?;
49    Ok(read.into())
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: usize = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            // `unwrap` (rather than `assert!(is_ok())`) surfaces the error on failure
            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE];
            let mut buf2 = [0u8; MAX_MESSAGE_SIZE / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            send_frame(&mut sink, &buf1, MAX_MESSAGE_SIZE).await.unwrap();
            send_frame(&mut sink, &buf2, MAX_MESSAGE_SIZE).await.unwrap();

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data, Bytes::from(buf1.to_vec()));
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data, Bytes::from(buf2.to_vec()));
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            // Do the reading manually without using recv_frame
            let read = stream.recv(vec![0; 4]).await.unwrap();
            assert_eq!(read.as_ref(), (buf.len() as u32).to_be_bytes());
            let read = stream.recv(vec![0; MAX_MESSAGE_SIZE]).await.unwrap();
            assert_eq!(read.as_ref(), buf);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            // One byte over the allowed maximum must be rejected before sending
            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    #[test]
    fn test_send_zero_size() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let buf = [];
            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::SendZeroSize)));
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut msg);

            let mut buf = BytesMut::with_capacity(4 + msg.len());
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            buf.extend_from_slice(&msg);
            sink.send(buf).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE);
            assert_eq!(data, msg.as_ref());
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            sink.send(buf).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    #[test]
    fn test_read_zero_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives zero as the size
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(0);
            sink.send(buf).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::StreamClosed)));
        });
    }

    #[test]
    fn test_recv_frame_short_length_prefix() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame with a short length prefix
            let mut buf = BytesMut::with_capacity(3);
            buf.put_u8(0x00);
            buf.put_u8(0x00);
            buf.put_u8(0x00);

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error rather than a panic
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_truncated_payload() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Advertise an 8-byte payload but deliver only 4 bytes of it
            let mut buf = BytesMut::with_capacity(8);
            buf.put_u32(8);
            buf.put_u32(0xdead_beef);

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink so the payload read cannot complete

            // The length prefix is read fine; the short payload must be an error
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }
}