commonware_stream/utils/
codec.rs1use crate::encrypted::Error;
2use commonware_codec::{
3 varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4 Encode,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufs, Sink, Stream};
7
8pub async fn send_frame<S: Sink>(
11 sink: &mut S,
12 buf: impl Into<IoBufs> + Send,
13 max_message_size: u32,
14) -> Result<(), Error> {
15 let mut bufs = buf.into();
16
17 let n = bufs.remaining();
19 if n > max_message_size as usize {
20 return Err(Error::SendTooLarge(n));
21 }
22
23 let len = UInt(n as u32);
25 bufs.prepend(IoBuf::from(len.encode()));
26 sink.send(bufs).await.map_err(Error::SendFailed)
27}
28
29pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
33 let (len, skip) = recv_length(stream).await?;
34 if len > max_message_size {
35 return Err(Error::RecvTooLarge(len as usize));
36 }
37
38 stream
39 .recv(skip as u64 + len as u64)
40 .await
41 .map(|mut bufs| {
42 bufs.advance(skip as usize);
43 bufs
44 })
45 .map_err(Error::RecvFailed)
46}
47
48async fn recv_length<T: Stream>(stream: &mut T) -> Result<(u32, u32), Error> {
53 let mut decoder = Decoder::<u32>::new();
54
55 let peeked = {
57 let peeked = stream.peek(MAX_U32_VARINT_SIZE as u64);
58 for (i, byte) in peeked.iter().enumerate() {
59 match decoder.feed(*byte) {
60 Ok(Some(len)) => return Ok((len, i as u32 + 1)),
61 Ok(None) => continue,
62 Err(_) => return Err(Error::InvalidVarint),
63 }
64 }
65 peeked.len()
66 };
67
68 let mut buf = stream
70 .recv(peeked as u64 + 1)
71 .await
72 .map_err(Error::RecvFailed)?;
73 buf.advance(peeked);
74
75 loop {
76 match decoder.feed(buf.get_u8()) {
77 Ok(Some(len)) => return Ok((len, 0)),
78 Ok(None) => {}
79 Err(_) => return Err(Error::InvalidVarint),
80 }
81 buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
    use rand::Rng;

    /// Frame size limit used throughout these tests.
    const MAX_MESSAGE_SIZE: u32 = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Two frames back to back must be delimited correctly.
            let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());
            let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // 1024 as a varint: low 7 bits (0) with the continuation bit set, then 1024 >> 7 = 8.
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]);
            let read = stream.recv(MAX_MESSAGE_SIZE as u64).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // Hand-build a frame: varint(1024) followed by the payload.
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Only the prefix is sent: the size check must fire before any payload read.
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A lone continuation byte, then the sender goes away.
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80);
            sink.send(buf.freeze()).await.unwrap();
            drop(sink);

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0x01);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_all_read_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Read-buffer sizes chosen so the malformed prefix is rejected on the
            // byte-by-byte path (0), with a tiny peek window (1), and after a
            // partial peek that must fall back to the slow path (3).
            for read_buffer_size in [0, 1, 3] {
                let (mut sink, mut stream) =
                    mocks::Channel::init_with_read_buffer_size(read_buffer_size);

                let mut buf = IoBufMut::with_capacity(5);
                buf.put_slice(&[0xFF; 5]);
                sink.send(buf.freeze()).await.unwrap();

                let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
                assert!(matches!(&result, Err(Error::InvalidVarint)));
            }
        });
    }

    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 bytes needs a two-byte varint prefix.
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Default channel.
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // No read buffer: peek sees nothing, so the slow path is taken.
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // One-byte read buffer: the prefix is split across peek and recv.
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
        });
    }
}