Skip to main content

commonware_stream/utils/
codec.rs

1use crate::encrypted::Error;
2use commonware_codec::{
3    varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4    Encode,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufs, Sink, Stream};
7
8/// Sends a frame with a varint length prefix, delegating frame assembly to the caller.
9///
10/// The `assemble` closure receives the varint prefix and must combine it with
11/// the payload. This allows callers to choose between:
12/// - Chunked: prepend the prefix as a separate buffer
13/// - Contiguous: write the prefix directly into a pre-allocated buffer
14///
15/// Returns an error if the message is too large or the sink is closed.
16pub(crate) async fn send_frame_with<S: Sink>(
17    sink: &mut S,
18    payload_len: usize,
19    max_message_size: u32,
20    assemble: impl FnOnce(UInt<u32>) -> Result<IoBufs, Error>,
21) -> Result<(), Error> {
22    // Validate frame size
23    if payload_len > max_message_size as usize {
24        return Err(Error::SendTooLarge(payload_len));
25    }
26    let prefix = UInt(payload_len as u32);
27
28    let frame = assemble(prefix)?;
29    sink.send(frame).await.map_err(Error::SendFailed)
30}
31
32/// Sends data to the sink with a varint length prefix.
33///
34/// The varint length prefix is prepended to the buffer(s), which results in a
35/// chunked `IoBufs`.
36///
37/// Returns an error if the message is too large or the sink is closed.
38pub async fn send_frame<S: Sink>(
39    sink: &mut S,
40    bufs: impl Into<IoBufs> + Send,
41    max_message_size: u32,
42) -> Result<(), Error> {
43    let mut bufs = bufs.into();
44
45    send_frame_with(sink, bufs.len(), max_message_size, |prefix| {
46        // Prepend varint-encoded length
47        bufs.prepend(IoBuf::from(prefix.encode()));
48        Ok(bufs)
49    })
50    .await
51}
52
53/// Receives data from the stream with a varint length prefix.
54/// Returns an error if the message is too large, the varint is invalid, or the
55/// stream is closed.
56pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
57    let (len, skip) = recv_length(stream).await?;
58    if len > max_message_size as usize {
59        return Err(Error::RecvTooLarge(len));
60    }
61
62    stream
63        .recv(skip + len)
64        .await
65        .map(|mut bufs| {
66            bufs.advance(skip);
67            bufs
68        })
69        .map_err(Error::RecvFailed)
70}
71
72/// Receives and decodes the varint length prefix from the stream.
73/// Returns (payload_len, bytes_to_skip) where bytes_to_skip is:
74/// - varint_len if decoded from peek buffer (bytes not yet consumed)
75/// - 0 if decoded via recv (bytes already consumed)
76async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
77    let mut decoder = Decoder::<u32>::new();
78
79    // Fast path: decode from peek buffer without blocking
80    let peeked = {
81        let peeked = stream.peek(MAX_U32_VARINT_SIZE);
82        for (i, byte) in peeked.iter().enumerate() {
83            match decoder.feed(*byte) {
84                Ok(Some(len)) => return Ok((len as usize, i + 1)),
85                Ok(None) => continue,
86                Err(_) => return Err(Error::InvalidVarint),
87            }
88        }
89        peeked.len()
90    };
91
92    // Slow path: fetch bytes one at a time (skipping already-decoded peek bytes)
93    let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
94    buf.advance(peeked);
95
96    loop {
97        match decoder.feed(buf.get_u8()) {
98            Ok(Some(len)) => return Ok((len as usize, 0)),
99            Ok(None) => {}
100            Err(_) => return Err(Error::InvalidVarint),
101        }
102        buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
103    }
104}
105
// Unit tests for the varint length-prefixed framing helpers. Each test runs on
// the deterministic runtime against an in-memory `mocks::Channel` sink/stream pair.
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
    use rand::Rng;

    // Frame size limit used throughout; 1024 encodes as the 2-byte varint [0x80, 0x08].
    const MAX_MESSAGE_SIZE: u32 = 1024;

    /// A payload of exactly `MAX_MESSAGE_SIZE` bytes is accepted (the limit is
    /// inclusive) and round-trips unchanged.
    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    /// Two frames sent back-to-back are received in order, with each frame's
    /// boundary recovered from its own length prefix.
    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());
            let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    /// Pins the wire format written by `send_frame`: the varint prefix bytes
    /// followed directly by the raw payload.
    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Do the reading manually without using recv_frame
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] (2 bytes)
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]); // 1024 as varint
            let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    /// An error returned by the `assemble` closure is propagated unchanged.
    #[test]
    fn test_send_frame_with_closure_error() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = send_frame_with(&mut sink, 10, MAX_MESSAGE_SIZE, |_prefix| {
                Err(Error::HandshakeError(
                    commonware_cryptography::handshake::Error::EncryptionFailed,
                ))
            })
            .await;
            assert!(matches!(&result, Err(Error::HandshakeError(_))));
        });
    }

    /// An oversize `payload_len` is rejected before `assemble` runs (the closure
    /// would panic via `unreachable!` if it were called).
    #[test]
    fn test_send_frame_with_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = send_frame_with(
                &mut sink,
                MAX_MESSAGE_SIZE as usize + 1,
                MAX_MESSAGE_SIZE,
                |_prefix| unreachable!(),
            )
            .await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
            );
        });
    }

    /// `send_frame` rejects a payload one byte over the limit and reports the
    /// actual payload length.
    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    /// `recv_frame` parses a hand-built frame (prefix plus payload in a single
    /// contiguous buffer).
    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    /// An oversize declared length is rejected from the prefix alone. Only the
    /// prefix is sent, so the test shows no payload bytes are awaited first.
    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    /// A stream that closes in the middle of the varint prefix surfaces as
    /// `RecvFailed`, not `InvalidVarint`.
    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send incomplete varint (continuation bit set but no following byte)
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80); // Continuation bit set, expects more bytes

            sink.send(buf.freeze()).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error because varint is incomplete
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    /// A prefix that cannot fit in a u32 is reported as `InvalidVarint`.
    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send a varint that overflows u32 (more than 5 bytes with continuation bits)
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 5th byte with overflow bits set + continue
            buf.put_u8(0x01); // 6th byte

            sink.send(buf.freeze()).await.unwrap();

            // Expect an error because varint overflows u32
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    /// Exercises each branch of `recv_length`. The comments below describe what
    /// each read-buffer size is expected to make `peek` return: the whole prefix
    /// (fast path), nothing, or only part of the prefix (both slow path).
    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 encodes as [0xAC, 0x02] (2-byte varint)
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Fast path: peek returns complete varint
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns empty
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns partial varint
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
        });
    }
}