commonware_stream/utils/codec.rs

1use crate::Error;
2use bytes::{Buf, Bytes, BytesMut};
3use commonware_codec::{
4    varint::{Decoder, UInt},
5    Encode,
6};
7use commonware_runtime::{Sink, Stream};
8
9/// Sends data to the sink with a varint length prefix.
10/// Returns an error if the message is too large or the stream is closed.
11pub async fn send_frame<S: Sink>(
12    sink: &mut S,
13    buf: impl Buf + Send,
14    max_message_size: u32,
15) -> Result<(), Error> {
16    // Validate frame size
17    let n = buf.remaining();
18    if n > max_message_size as usize {
19        return Err(Error::SendTooLarge(n));
20    }
21
22    // Prefix `buf` with its varint-encoded length and send it
23    let len = UInt(n as u32);
24    let data = len.encode().chain(buf);
25    sink.send(data).await.map_err(Error::SendFailed)
26}
27
28/// Receives data from the stream with a varint length prefix.
29/// Returns an error if the message is too large, the varint is invalid, or the
30/// stream is closed.
31pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<Bytes, Error> {
32    // Read and decode the varint length prefix byte-by-byte
33    let mut decoder = Decoder::<u32>::new();
34    let mut buf = [0u8; 1];
35    let len = loop {
36        stream.recv(&mut buf[..]).await.map_err(Error::RecvFailed)?;
37        match decoder.feed(buf[0]) {
38            Ok(Some(len)) => break len as usize,
39            Ok(None) => continue,
40            Err(_) => return Err(Error::InvalidVarint),
41        }
42    };
43    if len > max_message_size as usize {
44        return Err(Error::RecvTooLarge(len));
45    }
46
47    // Read the rest of the message
48    let mut read = BytesMut::zeroed(len);
49    stream
50        .recv(&mut read[..])
51        .await
52        .map_err(Error::RecvFailed)?;
53    Ok(read.freeze())
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{BufMut, BytesMut};
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    /// Frame-size limit used throughout these tests.
    /// 1024 encodes as the two-byte varint `[0x80, 0x08]`.
    const MAX_MESSAGE_SIZE: u32 = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.as_slice(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            let result = send_frame(&mut sink, buf1.as_slice(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());
            let result = send_frame(&mut sink, buf2.as_slice(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Read both messages back in order; the byte-at-a-time prefix read
            // must not consume bytes belonging to the next frame.
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data, Bytes::from(buf1.to_vec()));
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data, Bytes::from(buf2.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_empty() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A zero-length payload is a valid frame (prefix only).
            let result = send_frame(&mut sink, Bytes::new(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert!(data.is_empty());
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.as_slice(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Do the reading manually without using recv_frame
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] (2 bytes)
            let mut read = [0u8; 2];
            stream.recv(&mut read[..]).await.unwrap();
            assert_eq!(read.as_ref(), &[0x80, 0x08]); // 1024 as varint
            let mut read = vec![0u8; MAX_MESSAGE_SIZE as usize];
            stream.recv(&mut read[..]).await.unwrap();
            assert_eq!(read, buf);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.as_slice(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = BytesMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.extend_from_slice(&msg);
            sink.send(buf).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data, msg.as_ref());
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = BytesMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf).await.unwrap();

            // The length is rejected before any payload is read, so no payload
            // bytes need to be sent.
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send incomplete varint (continuation bit set but no following byte)
            let mut buf = BytesMut::with_capacity(1);
            buf.put_u8(0x80); // Continuation bit set, expects more bytes

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error because varint is incomplete
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_truncated_payload() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Announce a 1024-byte payload ([0x80, 0x08]) but deliver only 10 bytes
            let mut buf = BytesMut::with_capacity(2 + 10);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.extend_from_slice(&[0xAB; 10]);

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink so the payload can never be completed

            // Expect an error from the payload read, not the prefix read
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A u32 varint is at most 5 bytes, and the 5th byte may only carry
            // the top 4 bits with no continuation bit. Here the 5th byte has
            // overflow bits and the continuation bit set, so the decoder should
            // reject the prefix by the 5th byte. The 6th byte is never needed.
            let mut buf = BytesMut::with_capacity(6);
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 5th byte with overflow bits set + continue
            buf.put_u8(0x01); // 6th byte

            sink.send(buf).await.unwrap();

            // Expect an error because varint overflows u32
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }
}