Skip to main content

commonware_stream/utils/
codec.rs

1use crate::encrypted::Error;
2use commonware_codec::{
3    varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4    Encode,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufs, Sink, Stream};
7
8/// Sends data to the sink with a varint length prefix.
9/// Returns an error if the message is too large or the stream is closed.
10pub async fn send_frame<S: Sink>(
11    sink: &mut S,
12    buf: impl Into<IoBufs> + Send,
13    max_message_size: u32,
14) -> Result<(), Error> {
15    let mut bufs = buf.into();
16
17    // Validate frame size
18    let n = bufs.remaining();
19    if n > max_message_size as usize {
20        return Err(Error::SendTooLarge(n));
21    }
22
23    // Prepend varint-encoded length
24    let len = UInt(n as u32);
25    bufs.prepend(IoBuf::from(len.encode()));
26    sink.send(bufs).await.map_err(Error::SendFailed)
27}
28
29/// Receives data from the stream with a varint length prefix.
30/// Returns an error if the message is too large, the varint is invalid, or the
31/// stream is closed.
32pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
33    let (len, skip) = recv_length(stream).await?;
34    if len > max_message_size {
35        return Err(Error::RecvTooLarge(len as usize));
36    }
37
38    stream
39        .recv(skip as u64 + len as u64)
40        .await
41        .map(|mut bufs| {
42            bufs.advance(skip as usize);
43            bufs
44        })
45        .map_err(Error::RecvFailed)
46}
47
48/// Receives and decodes the varint length prefix from the stream.
49/// Returns (payload_len, bytes_to_skip) where bytes_to_skip is:
50/// - varint_len if decoded from peek buffer (bytes not yet consumed)
51/// - 0 if decoded via recv (bytes already consumed)
52async fn recv_length<T: Stream>(stream: &mut T) -> Result<(u32, u32), Error> {
53    let mut decoder = Decoder::<u32>::new();
54
55    // Fast path: decode from peek buffer without blocking
56    let peeked = {
57        let peeked = stream.peek(MAX_U32_VARINT_SIZE as u64);
58        for (i, byte) in peeked.iter().enumerate() {
59            match decoder.feed(*byte) {
60                Ok(Some(len)) => return Ok((len, i as u32 + 1)),
61                Ok(None) => continue,
62                Err(_) => return Err(Error::InvalidVarint),
63            }
64        }
65        peeked.len()
66    };
67
68    // Slow path: fetch bytes one at a time (skipping already-decoded peek bytes)
69    let mut buf = stream
70        .recv(peeked as u64 + 1)
71        .await
72        .map_err(Error::RecvFailed)?;
73    buf.advance(peeked);
74
75    loop {
76        match decoder.feed(buf.get_u8()) {
77            Ok(Some(len)) => return Ok((len, 0)),
78            Ok(None) => {}
79            Err(_) => return Err(Error::InvalidVarint),
80        }
81        buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
82    }
83}
84
// Round-trip and error-path tests for `send_frame` / `recv_frame`. All tests
// run on the deterministic runtime over an in-memory mock channel.
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
    use rand::Rng;

    // Frame-size limit shared by the tests. 1024 encodes as the 2-byte varint
    // [0x80, 0x08], which several tests write or check by hand.
    const MAX_MESSAGE_SIZE: u32 = 1024;

    // A payload of exactly the limit must be accepted by both send and receive.
    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    // Back-to-back frames of different sizes must be separated correctly,
    // i.e. the first recv must not consume bytes belonging to the second frame.
    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());
            let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    // Pins the on-wire format written by `send_frame`: varint prefix, then the
    // raw payload bytes.
    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Do the reading manually without using recv_frame
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] (2 bytes)
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]); // 1024 as varint
            let read = stream.recv(MAX_MESSAGE_SIZE as u64).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    // A payload one byte over the limit is rejected, and the error carries the
    // actual payload length.
    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    // `recv_frame` must parse a hand-built frame (prefix and payload delivered
    // in a single send), independently of `send_frame`.
    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    // An advertised length above the limit is rejected as soon as the prefix is
    // decoded. Only the prefix is sent, so this also shows the payload is never
    // awaited.
    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    // A prefix cut off by the peer closing the channel surfaces as a receive
    // failure rather than a decode error.
    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send incomplete varint (continuation bit set but no following byte)
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80); // Continuation bit set, expects more bytes

            sink.send(buf.freeze()).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error because varint is incomplete
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    // A prefix that cannot fit in a u32 is reported as `InvalidVarint`.
    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send a varint that overflows u32 (more than 5 bytes with continuation bits)
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 5th byte with overflow bits set + continue
            buf.put_u8(0x01); // 6th byte

            sink.send(buf.freeze()).await.unwrap();

            // Expect an error because varint overflows u32
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    // Exercises both branches of `recv_length` by varying the mock's read
    // buffer size. NOTE(review): this assumes `init_with_read_buffer_size(k)`
    // caps how many bytes `peek` can return at `k`; confirm against `mocks`.
    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 encodes as [0xAC, 0x02] (2-byte varint)
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Fast path: peek returns complete varint
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns empty
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns partial varint
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
        });
    }
}