Skip to main content

commonware_stream/utils/
codec.rs

1use crate::encrypted::Error;
2use commonware_codec::{
3    varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4    Encode, EncodeSize, Write,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufMut, IoBufs, Sink, Stream};
7
8/// Validates the frame size and assembles the frame via the caller's closure.
9///
10/// The `assemble` closure receives the varint prefix and must combine it with
11/// the payload. This allows callers to choose between:
12/// - Chunked: prepend the prefix as a separate buffer
13/// - Contiguous: write the prefix directly into a pre-allocated buffer
14///
15/// Returns an error if the message is too large.
16pub(crate) fn build_frame<T>(
17    payload_len: usize,
18    max_message_size: u32,
19    assemble: impl FnOnce(UInt<u32>) -> Result<T, Error>,
20) -> Result<T, Error> {
21    if payload_len > max_message_size as usize {
22        return Err(Error::SendTooLarge(payload_len));
23    }
24    let prefix = UInt(payload_len as u32);
25    assemble(prefix)
26}
27
28/// Returns the total size of a length-prefixed frame.
29pub(crate) fn framed_len(payload_len: usize, max_message_size: u32) -> Result<usize, Error> {
30    build_frame(payload_len, max_message_size, |prefix| {
31        Ok(prefix.encode_size() + payload_len)
32    })
33}
34
35/// Appends one length-prefixed frame to a contiguous output buffer.
36///
37/// The callback receives the offset of the frame payload, which is useful when
38/// callers need to operate on the payload bytes after copying them.
39pub(crate) fn append_frame(
40    frame: &mut IoBufMut,
41    payload_len: usize,
42    max_message_size: u32,
43    append_payload: impl FnOnce(&mut IoBufMut, usize) -> Result<(), Error>,
44) -> Result<usize, Error> {
45    build_frame(payload_len, max_message_size, |prefix| {
46        let start = frame.len();
47        prefix.write(frame);
48        let payload_offset = frame.len();
49        append_payload(frame, payload_offset)?;
50        assert_eq!(frame.len() - payload_offset, payload_len);
51        Ok(frame.len() - start)
52    })
53}
54
55/// Sends data to the sink with a varint length prefix.
56///
57/// The varint length prefix is prepended to the buffer(s), which results in a
58/// chunked `IoBufs`.
59///
60/// Returns an error if the message is too large or the sink is closed.
61pub async fn send_frame<S: Sink>(
62    sink: &mut S,
63    bufs: impl Into<IoBufs> + Send,
64    max_message_size: u32,
65) -> Result<(), Error> {
66    let mut bufs = bufs.into();
67
68    let frame = build_frame(bufs.len(), max_message_size, |prefix| {
69        bufs.prepend(IoBuf::from(prefix.encode()));
70        Ok(bufs)
71    })?;
72    sink.send(frame).await.map_err(Error::SendFailed)
73}
74
75/// Receives data from the stream with a varint length prefix.
76/// Returns an error if the message is too large, the varint is invalid, or the
77/// stream is closed.
78pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
79    let (len, skip) = recv_length(stream).await?;
80    if len > max_message_size as usize {
81        return Err(Error::RecvTooLarge(len));
82    }
83
84    stream
85        .recv(skip + len)
86        .await
87        .map(|mut bufs| {
88            bufs.advance(skip);
89            bufs
90        })
91        .map_err(Error::RecvFailed)
92}
93
94/// Receives and decodes the varint length prefix from the stream.
95/// Returns (payload_len, bytes_to_skip) where bytes_to_skip is:
96/// - varint_len if decoded from peek buffer (bytes not yet consumed)
97/// - 0 if decoded via recv (bytes already consumed)
98async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
99    let mut decoder = Decoder::<u32>::new();
100
101    // Fast path: decode from peek buffer without blocking
102    let peeked = {
103        let peeked = stream.peek(MAX_U32_VARINT_SIZE);
104        for (i, byte) in peeked.iter().enumerate() {
105            match decoder.feed(*byte) {
106                Ok(Some(len)) => return Ok((len as usize, i + 1)),
107                Ok(None) => continue,
108                Err(_) => return Err(Error::InvalidVarint),
109            }
110        }
111        peeked.len()
112    };
113
114    // Slow path: fetch bytes one at a time (skipping already-decoded peek bytes)
115    let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
116    buf.advance(peeked);
117
118    loop {
119        match decoder.feed(buf.get_u8()) {
120            Ok(Some(len)) => return Ok((len as usize, 0)),
121            Ok(None) => {}
122            Err(_) => return Err(Error::InvalidVarint),
123        }
124        buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{
        deterministic, mocks, BufMut, IoBufMut, Runner, Spawner, Supervisor as _,
    };
    use rand::Rng;

    const MAX_MESSAGE_SIZE: u32 = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            // `unwrap` (rather than `assert!(result.is_ok())`) reports the error on failure.
            send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    #[test]
    fn test_send_recv_empty() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A zero-length payload is framed as just the one-byte varint 0x00.
            send_frame(&mut sink, Vec::<u8>::new(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), 0);
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            // Do the reading manually without using recv_frame
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] (2 bytes)
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]); // 1024 as varint
            let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    #[test]
    fn test_build_frame_closure_error() {
        let result: Result<IoBufs, _> = build_frame(10, MAX_MESSAGE_SIZE, |_prefix| {
            Err(Error::HandshakeError(
                commonware_cryptography::handshake::Error::EncryptionFailed,
            ))
        });
        assert!(matches!(&result, Err(Error::HandshakeError(_))));
    }

    #[test]
    fn test_build_frame_too_large() {
        let result: Result<IoBufs, _> = build_frame(
            MAX_MESSAGE_SIZE as usize + 1,
            MAX_MESSAGE_SIZE,
            |_prefix| unreachable!(),
        );
        assert!(
            matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
        );
    }

    #[test]
    fn test_framed_len() {
        // The varint prefix takes 1 byte for lengths up to 127 and 2 bytes up to 16383.
        assert!(matches!(framed_len(0, MAX_MESSAGE_SIZE), Ok(1)));
        assert!(matches!(framed_len(127, MAX_MESSAGE_SIZE), Ok(128)));
        assert!(matches!(framed_len(128, MAX_MESSAGE_SIZE), Ok(130)));
        assert!(matches!(
            framed_len(MAX_MESSAGE_SIZE as usize, MAX_MESSAGE_SIZE),
            Ok(n) if n == MAX_MESSAGE_SIZE as usize + 2
        ));
        assert!(matches!(
            framed_len(MAX_MESSAGE_SIZE as usize + 1, MAX_MESSAGE_SIZE),
            Err(Error::SendTooLarge(n)) if n == MAX_MESSAGE_SIZE as usize + 1
        ));
    }

    #[test]
    fn test_append_frame_roundtrip() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut payload = [0u8; 300];
            context.fill(&mut payload);

            let expected_len = framed_len(payload.len(), MAX_MESSAGE_SIZE).unwrap();
            let mut frame = IoBufMut::with_capacity(expected_len);
            let written = append_frame(
                &mut frame,
                payload.len(),
                MAX_MESSAGE_SIZE,
                |buf, offset| {
                    // 300 encodes as the 2-byte varint [0xAC, 0x02].
                    assert_eq!(offset, 2);
                    buf.put_slice(&payload);
                    Ok(())
                },
            )
            .unwrap();
            assert_eq!(written, expected_len);
            assert_eq!(frame.len(), expected_len);

            // The contiguous frame must be readable by the regular receive path.
            sink.send(frame.freeze()).await.unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send incomplete varint (continuation bit set but no following byte)
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80); // Continuation bit set, expects more bytes

            sink.send(buf.freeze()).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error because varint is incomplete
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send a varint that overflows u32 (more than 5 bytes with continuation bits)
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 5th byte with overflow bits set + continue
            buf.put_u8(0x01); // 6th byte

            sink.send(buf.freeze()).await.unwrap();

            // Expect an error because varint overflows u32
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 encodes as [0xAC, 0x02] (2-byte varint)
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Fast path: peek returns complete varint
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns empty (buffer_size=0 means send always
            // blocks, so send and recv must run concurrently).
            let (mut sink, mut stream) = mocks::Channel::init_with_buffer_size(0);
            let payload2 = payload.clone();
            let send_handle = context.child("sender_empty_peek").spawn(|_| async move {
                send_frame(&mut sink, payload2, MAX_MESSAGE_SIZE)
                    .await
                    .unwrap();
            });
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
            send_handle.await.unwrap();

            // Slow path: peek returns partial varint
            let (mut sink, mut stream) = mocks::Channel::init_with_buffer_size(1);
            let payload2 = payload.clone();
            let send_handle = context.child("sender_partial_peek").spawn(|_| async move {
                send_frame(&mut sink, payload2, MAX_MESSAGE_SIZE)
                    .await
                    .unwrap();
            });
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
            send_handle.await.unwrap();
        });
    }
}