commonware_stream/utils/
codec.rs1use crate::Error;
2use bytes::{Bytes, BytesMut};
3use commonware_codec::{
4 varint::{Decoder, UInt},
5 EncodeSize as _, Write as _,
6};
7use commonware_runtime::{Sink, Stream};
8use commonware_utils::StableBuf;
9
10pub async fn send_frame<S: Sink>(
13 sink: &mut S,
14 buf: &[u8],
15 max_message_size: u32,
16) -> Result<(), Error> {
17 let n = buf.len();
19 if n > max_message_size as usize {
20 return Err(Error::SendTooLarge(n));
21 }
22
23 let len = UInt(n as u32);
25 let mut prefixed_buf = BytesMut::with_capacity(len.encode_size() + buf.len());
26 len.write(&mut prefixed_buf);
27 prefixed_buf.extend_from_slice(buf);
28 sink.send(prefixed_buf).await.map_err(Error::SendFailed)
29}
30
31pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<Bytes, Error> {
35 let mut decoder = Decoder::<u32>::new();
37 let mut buf = StableBuf::from(vec![0u8; 1]);
38 let len = loop {
39 buf = stream.recv(buf).await.map_err(Error::RecvFailed)?;
40 match decoder.feed(buf[0]) {
41 Ok(Some(len)) => break len as usize,
42 Ok(None) => continue,
43 Err(_) => return Err(Error::InvalidVarint),
44 }
45 };
46
47 if len > max_message_size as usize {
49 return Err(Error::RecvTooLarge(len));
50 }
51
52 let read = stream.recv(vec![0; len]).await.map_err(Error::RecvFailed)?;
54 Ok(read.into())
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BufMut;
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    /// Frame-size limit used throughout these tests.
    const MAX_MESSAGE_SIZE: u32 = 1024;

    /// Wire bytes of the varint prefix for a `MAX_MESSAGE_SIZE` (1024-byte) payload:
    /// the low seven bits (all zero) with the continuation bit set, then `1024 >> 7 == 8`.
    const MAX_MESSAGE_SIZE_PREFIX: [u8; 2] = [0x80, 0x08];

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            send_frame(&mut sink, &payload, MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), payload.len());
            assert_eq!(received, Bytes::copy_from_slice(&payload));
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|mut context| async move {
            let mut full = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut half = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut full);
            context.fill(&mut half);
            let payloads: [&[u8]; 2] = [&full, &half];

            // Send both frames before reading either, so the receiver has to delimit
            // them using the length prefixes alone.
            for &payload in &payloads {
                send_frame(&mut sink, payload, MAX_MESSAGE_SIZE)
                    .await
                    .unwrap();
            }
            for &payload in &payloads {
                let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
                assert_eq!(received.len(), payload.len());
                assert_eq!(received, Bytes::copy_from_slice(payload));
            }
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            send_frame(&mut sink, &payload, MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            // Read the prefix and the payload separately to pin down the exact wire layout.
            let prefix = stream
                .recv(vec![0u8; MAX_MESSAGE_SIZE_PREFIX.len()])
                .await
                .unwrap();
            assert_eq!(prefix.as_ref(), &MAX_MESSAGE_SIZE_PREFIX);

            let body = stream
                .recv(vec![0u8; MAX_MESSAGE_SIZE as usize])
                .await
                .unwrap();
            assert_eq!(body.as_ref(), payload);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        deterministic::Runner::default().start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            // A limit one byte below the payload size must be rejected, reporting the
            // offending payload length.
            let result = send_frame(&mut sink, &payload, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // Hand-build the frame (varint prefix, then raw payload) without `send_frame`.
            let mut frame = BytesMut::with_capacity(MAX_MESSAGE_SIZE_PREFIX.len() + msg.len());
            frame.put_slice(&MAX_MESSAGE_SIZE_PREFIX);
            frame.put_slice(&msg);
            sink.send(frame).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data, &msg[..]);
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|_| async move {
            // Only the prefix is sent: the size check must fire before any payload is read.
            let mut prefix = BytesMut::with_capacity(MAX_MESSAGE_SIZE_PREFIX.len());
            prefix.put_slice(&MAX_MESSAGE_SIZE_PREFIX);
            sink.send(prefix).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|_| async move {
            // A lone continuation byte promises further varint bytes that never arrive.
            let mut partial = BytesMut::with_capacity(1);
            partial.put_u8(0x80);
            sink.send(partial).await.unwrap();

            // Dropping the sink closes the channel; the pending read is expected to fail.
            drop(sink);
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        deterministic::Runner::default().start(|_| async move {
            // Six varint bytes (five with the continuation bit) exceed what a `u32` can hold.
            let mut frame = BytesMut::with_capacity(6);
            frame.put_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]);
            sink.send(frame).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }
}