Skip to main content

commonware_stream/utils/
codec.rs

1use crate::encrypted::Error;
2use commonware_codec::{
3    varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4    Encode, EncodeSize, Write,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufMut, IoBufs, Sink, Stream};
7
8/// Validates the frame size and assembles the frame via the caller's closure.
9///
10/// The `assemble` closure receives the varint prefix and must combine it with
11/// the payload. This allows callers to choose between:
12/// - Chunked: prepend the prefix as a separate buffer
13/// - Contiguous: write the prefix directly into a pre-allocated buffer
14///
15/// Returns an error if the message is too large.
16pub(crate) fn build_frame<T>(
17    payload_len: usize,
18    max_message_size: u32,
19    assemble: impl FnOnce(UInt<u32>) -> Result<T, Error>,
20) -> Result<T, Error> {
21    if payload_len > max_message_size as usize {
22        return Err(Error::SendTooLarge(payload_len));
23    }
24    let prefix = UInt(payload_len as u32);
25    assemble(prefix)
26}
27
28/// Returns the total size of a length-prefixed frame.
29pub(crate) fn framed_len(payload_len: usize, max_message_size: u32) -> Result<usize, Error> {
30    build_frame(payload_len, max_message_size, |prefix| {
31        Ok(prefix.encode_size() + payload_len)
32    })
33}
34
35/// Appends one length-prefixed frame to a contiguous output buffer.
36///
37/// The callback receives the offset of the frame payload, which is useful when
38/// callers need to operate on the payload bytes after copying them.
39pub(crate) fn append_frame(
40    frame: &mut IoBufMut,
41    payload_len: usize,
42    max_message_size: u32,
43    append_payload: impl FnOnce(&mut IoBufMut, usize) -> Result<(), Error>,
44) -> Result<usize, Error> {
45    build_frame(payload_len, max_message_size, |prefix| {
46        let start = frame.len();
47        prefix.write(frame);
48        let payload_offset = frame.len();
49        append_payload(frame, payload_offset)?;
50        assert_eq!(frame.len() - payload_offset, payload_len);
51        Ok(frame.len() - start)
52    })
53}
54
55/// Sends data to the sink with a varint length prefix.
56///
57/// The varint length prefix is prepended to the buffer(s), which results in a
58/// chunked `IoBufs`.
59///
60/// Returns an error if the message is too large or the sink is closed.
61pub async fn send_frame<S: Sink>(
62    sink: &mut S,
63    bufs: impl Into<IoBufs> + Send,
64    max_message_size: u32,
65) -> Result<(), Error> {
66    let mut bufs = bufs.into();
67
68    let frame = build_frame(bufs.len(), max_message_size, |prefix| {
69        bufs.prepend(IoBuf::from(prefix.encode()));
70        Ok(bufs)
71    })?;
72    sink.send(frame).await.map_err(Error::SendFailed)
73}
74
75/// Receives data from the stream with a varint length prefix.
76/// Returns an error if the message is too large, the varint is invalid, or the
77/// stream is closed.
78pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
79    let (len, skip) = recv_length(stream).await?;
80    if len > max_message_size as usize {
81        return Err(Error::RecvTooLarge(len));
82    }
83
84    stream
85        .recv(skip + len)
86        .await
87        .map(|mut bufs| {
88            bufs.advance(skip);
89            bufs
90        })
91        .map_err(Error::RecvFailed)
92}
93
94/// Receives and decodes the varint length prefix from the stream.
95/// Returns (payload_len, bytes_to_skip) where bytes_to_skip is:
96/// - varint_len if decoded from peek buffer (bytes not yet consumed)
97/// - 0 if decoded via recv (bytes already consumed)
98async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
99    let mut decoder = Decoder::<u32>::new();
100
101    // Fast path: decode from peek buffer without blocking
102    let peeked = {
103        let peeked = stream.peek(MAX_U32_VARINT_SIZE);
104        for (i, byte) in peeked.iter().enumerate() {
105            match decoder.feed(*byte) {
106                Ok(Some(len)) => return Ok((len as usize, i + 1)),
107                Ok(None) => continue,
108                Err(_) => return Err(Error::InvalidVarint),
109            }
110        }
111        peeked.len()
112    };
113
114    // Slow path: fetch bytes one at a time (skipping already-decoded peek bytes)
115    let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
116    buf.advance(peeked);
117
118    loop {
119        match decoder.feed(buf.get_u8()) {
120            Ok(Some(len)) => return Ok((len as usize, 0)),
121            Ok(None) => {}
122            Err(_) => return Err(Error::InvalidVarint),
123        }
124        buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
    use rand::Rng;

    /// Frame size limit used by every test in this module.
    const MAX_MESSAGE_SIZE: u32 = 1024;

    /// A payload of exactly the maximum size survives a send/recv round trip.
    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            let sent = send_frame(&mut sink, payload.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(sent.is_ok());

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), payload.len());
            assert_eq!(received.coalesce(), payload);
        });
    }

    /// Two frames sent back to back are received in order with intact contents.
    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut first = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut second = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut first);
            context.fill(&mut second);

            // Queue both messages (full-size, then half-size) before reading anything.
            let sent = send_frame(&mut sink, first.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(sent.is_ok());
            let sent = send_frame(&mut sink, second.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(sent.is_ok());

            // They must come back in the order they were sent.
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), first.len());
            assert_eq!(received.coalesce(), first);
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), second.len());
            assert_eq!(received.coalesce(), second);
        });
    }

    /// `send_frame` puts the varint prefix followed by the payload on the wire.
    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            let sent = send_frame(&mut sink, payload.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(sent.is_ok());

            // Read the raw bytes by hand instead of going through recv_frame.
            // 1024 (MAX_MESSAGE_SIZE) is the two-byte varint [0x80, 0x08].
            let prefix = stream.recv(2).await.unwrap();
            assert_eq!(prefix.coalesce(), &[0x80, 0x08]);
            let body = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
            assert_eq!(body.coalesce(), payload);
        });
    }

    /// An error returned by the assemble closure is passed through untouched.
    #[test]
    fn test_build_frame_closure_error() {
        let outcome: Result<IoBufs, _> = build_frame(10, MAX_MESSAGE_SIZE, |_prefix| {
            Err(Error::HandshakeError(
                commonware_cryptography::handshake::Error::EncryptionFailed,
            ))
        });
        assert!(matches!(&outcome, Err(Error::HandshakeError(_))));
    }

    /// A payload one byte over the limit is rejected before the closure runs.
    #[test]
    fn test_build_frame_too_large() {
        let outcome: Result<IoBufs, _> = build_frame(
            MAX_MESSAGE_SIZE as usize + 1,
            MAX_MESSAGE_SIZE,
            |_prefix| unreachable!(),
        );
        assert!(
            matches!(&outcome, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
        );
    }

    /// `send_frame` reports the offending size when the payload exceeds the limit.
    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            let outcome = send_frame(&mut sink, payload.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&outcome, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    /// `recv_frame` decodes a frame whose bytes were written by hand.
    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            // Build the frame manually rather than through send_frame.
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            // 1024 (MAX_MESSAGE_SIZE) is the two-byte varint [0x80, 0x08].
            let mut wire = IoBufMut::with_capacity(2 + payload.len());
            wire.put_slice(&[0x80, 0x08]);
            wire.put_slice(&payload);
            sink.send(wire.freeze()).await.unwrap();

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(received.coalesce(), payload);
        });
    }

    /// `recv_frame` rejects a length prefix that is larger than the limit.
    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            // Hand-write a prefix announcing MAX_MESSAGE_SIZE bytes.
            // 1024 (MAX_MESSAGE_SIZE) is the two-byte varint [0x80, 0x08].
            let mut wire = IoBufMut::with_capacity(2);
            wire.put_slice(&[0x80, 0x08]);
            sink.send(wire.freeze()).await.unwrap();

            let outcome = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&outcome, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    /// A prefix cut off mid-varint surfaces as a receive failure.
    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            // A single byte with the continuation bit set promises more bytes
            // that never arrive.
            let mut wire = IoBufMut::with_capacity(1);
            wire.put_u8(0x80);

            sink.send(wire.freeze()).await.unwrap();
            drop(sink); // Closing the sink simulates a closed stream.

            // The varint can never be completed, so receiving must fail.
            let outcome = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&outcome, Err(Error::RecvFailed(_))));
        });
    }

    /// A varint that does not fit in a `u32` is reported as invalid.
    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            // Five bytes that all carry the continuation bit (the fifth also
            // has overflow bits set), followed by a sixth byte.
            let mut wire = IoBufMut::with_capacity(6);
            wire.put_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]);

            sink.send(wire.freeze()).await.unwrap();

            // The encoded value overflows u32, so decoding must be rejected.
            let outcome = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&outcome, Err(Error::InvalidVarint)));
        });
    }

    /// Covers prefix decoding with full, empty, and partial peek buffers.
    #[test]
    fn test_recv_frame_peek_paths() {
        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            // 300 is the two-byte varint [0xAC, 0x02].
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Fast path: the whole varint is visible through peek.
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.coalesce(), &payload[..]);

            // Slow path: peek exposes nothing.
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.coalesce(), &payload[..]);

            // Slow path: peek exposes only part of the varint.
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.coalesce(), &payload[..]);
        });
    }
}