commonware_stream/utils/codec.rs

1use crate::Error;
2use bytes::{Bytes, BytesMut};
3use commonware_codec::{
4    varint::{Decoder, UInt},
5    EncodeSize as _, Write as _,
6};
7use commonware_runtime::{Sink, Stream};
8use commonware_utils::StableBuf;
9
10/// Sends data to the sink with a varint length prefix.
11/// Returns an error if the message is too large or the stream is closed.
12pub async fn send_frame<S: Sink>(
13    sink: &mut S,
14    buf: &[u8],
15    max_message_size: u32,
16) -> Result<(), Error> {
17    // Validate frame size
18    let n = buf.len();
19    if n > max_message_size as usize {
20        return Err(Error::SendTooLarge(n));
21    }
22
23    // Prefix `buf` with its varint-encoded length and send it
24    let len = UInt(n as u32);
25    let mut prefixed_buf = BytesMut::with_capacity(len.encode_size() + buf.len());
26    len.write(&mut prefixed_buf);
27    prefixed_buf.extend_from_slice(buf);
28    sink.send(prefixed_buf).await.map_err(Error::SendFailed)
29}
30
31/// Receives data from the stream with a varint length prefix.
32/// Returns an error if the message is too large, the varint is invalid, or the
33/// stream is closed.
34pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<Bytes, Error> {
35    // Read and decode the varint length prefix byte-by-byte
36    let mut decoder = Decoder::<u32>::new();
37    let mut buf = StableBuf::from(vec![0u8; 1]);
38    let len = loop {
39        buf = stream.recv(buf).await.map_err(Error::RecvFailed)?;
40        match decoder.feed(buf[0]) {
41            Ok(Some(len)) => break len as usize,
42            Ok(None) => continue,
43            Err(_) => return Err(Error::InvalidVarint),
44        }
45    };
46
47    // Validate frame size
48    if len > max_message_size as usize {
49        return Err(Error::RecvTooLarge(len));
50    }
51
52    // Read the rest of the message
53    let read = stream.recv(vec![0; len]).await.map_err(Error::RecvFailed)?;
54    Ok(read.into())
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BufMut;
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: u32 = 1024;

    /// Varint encoding of `MAX_MESSAGE_SIZE` (1024). The low 7 bits are zero, so the first
    /// byte is `0x80` (continuation bit only). The second byte is `1024 >> 7 = 0x08`.
    const MAX_MESSAGE_SIZE_VARINT: [u8; 2] = [0x80, 0x08];

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE)
                .await
                .expect("send at the size limit should succeed");

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            send_frame(&mut sink, &buf1, MAX_MESSAGE_SIZE)
                .await
                .expect("first send should succeed");
            send_frame(&mut sink, &buf2, MAX_MESSAGE_SIZE)
                .await
                .expect("second send should succeed");

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data, Bytes::from(buf1.to_vec()));
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data, Bytes::from(buf2.to_vec()));
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE)
                .await
                .expect("send at the size limit should succeed");

            // Do the reading manually without using recv_frame
            let read = stream.recv(vec![0; 2]).await.unwrap();
            assert_eq!(read.as_ref(), &MAX_MESSAGE_SIZE_VARINT);
            let read = stream
                .recv(vec![0; MAX_MESSAGE_SIZE as usize])
                .await
                .unwrap();
            assert_eq!(read.as_ref(), buf);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize),
                "unexpected result: {result:?}"
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            let mut buf = BytesMut::with_capacity(MAX_MESSAGE_SIZE_VARINT.len() + msg.len());
            buf.put_slice(&MAX_MESSAGE_SIZE_VARINT);
            buf.extend_from_slice(&msg);
            sink.send(buf).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data, msg.as_ref());
        });
    }

    #[test]
    fn test_recv_frame_split_across_sends() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // Deliver each prefix byte and the payload as separate writes, so the reader
            // must reassemble the frame across sends.
            let chunks = [
                &MAX_MESSAGE_SIZE_VARINT[..1],
                &MAX_MESSAGE_SIZE_VARINT[1..],
                &msg[..],
            ];
            for chunk in chunks {
                let mut buf = BytesMut::with_capacity(chunk.len());
                buf.extend_from_slice(chunk);
                sink.send(buf).await.unwrap();
            }

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data, msg.as_ref());
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a prefix that declares MAX_MESSAGE_SIZE as the frame size
            let mut buf = BytesMut::with_capacity(MAX_MESSAGE_SIZE_VARINT.len());
            buf.put_slice(&MAX_MESSAGE_SIZE_VARINT);
            sink.send(buf).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize),
                "unexpected result: {result:?}"
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send incomplete varint (continuation bit set but no following byte)
            let mut buf = BytesMut::with_capacity(1);
            buf.put_u8(0x80); // Continuation bit set, expects more bytes

            sink.send(buf).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error because varint is incomplete
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(
                matches!(&result, Err(Error::RecvFailed(_))),
                "unexpected result: {result:?}"
            );
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send a 6-byte varint: five bytes with the continuation bit set, then a
            // terminator. Such a value cannot fit in a u32.
            let mut buf = BytesMut::with_capacity(6);
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 5th byte with overflow bits set + continue
            buf.put_u8(0x01); // 6th byte

            sink.send(buf).await.unwrap();

            // Expect an error because varint overflows u32
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(
                matches!(&result, Err(Error::InvalidVarint)),
                "unexpected result: {result:?}"
            );
        });
    }
}