commonware_stream/utils/
codec.rs1use crate::Error;
2use bytes::{Buf, Bytes, BytesMut};
3use commonware_codec::{
4 varint::{Decoder, UInt},
5 Encode,
6};
7use commonware_runtime::{Sink, Stream};
8
9pub async fn send_frame<S: Sink>(
12 sink: &mut S,
13 buf: impl Buf + Send,
14 max_message_size: u32,
15) -> Result<(), Error> {
16 let n = buf.remaining();
18 if n > max_message_size as usize {
19 return Err(Error::SendTooLarge(n));
20 }
21
22 let len = UInt(n as u32);
24 let data = len.encode().chain(buf);
25 sink.send(data).await.map_err(Error::SendFailed)
26}
27
28pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<Bytes, Error> {
32 let mut decoder = Decoder::<u32>::new();
34 let mut buf = [0u8; 1];
35 let len = loop {
36 stream.recv(&mut buf[..]).await.map_err(Error::RecvFailed)?;
37 match decoder.feed(buf[0]) {
38 Ok(Some(len)) => break len as usize,
39 Ok(None) => continue,
40 Err(_) => return Err(Error::InvalidVarint),
41 }
42 };
43 if len > max_message_size as usize {
44 return Err(Error::RecvTooLarge(len));
45 }
46
47 let mut read = BytesMut::zeroed(len);
49 stream
50 .recv(&mut read[..])
51 .await
52 .map_err(Error::RecvFailed)?;
53 Ok(read.freeze())
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{BufMut, BytesMut};
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    /// Frame-size limit used by every test below.
    const MAX_MESSAGE_SIZE: u32 = 1024;

    /// A payload of exactly `MAX_MESSAGE_SIZE` bytes survives a send/recv round trip.
    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            let sent = send_frame(&mut sink, payload.as_slice(), MAX_MESSAGE_SIZE).await;
            assert!(sent.is_ok());

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), payload.len());
            assert_eq!(received, Bytes::from(payload.to_vec()));
        });
    }

    /// Back-to-back frames of different sizes are delivered in order and intact.
    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut first = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut second = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut first);
            context.fill(&mut second);

            // Queue both frames before reading either one.
            for payload in [first.as_slice(), second.as_slice()] {
                let sent = send_frame(&mut sink, payload, MAX_MESSAGE_SIZE).await;
                assert!(sent.is_ok());
            }

            for expected in [first.as_slice(), second.as_slice()] {
                let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
                assert_eq!(received.len(), expected.len());
                assert_eq!(received, Bytes::from(expected.to_vec()));
            }
        });
    }

    /// On the wire, a frame is a varint length prefix followed by the raw payload.
    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            let sent = send_frame(&mut sink, payload.as_slice(), MAX_MESSAGE_SIZE).await;
            assert!(sent.is_ok());

            // 1024 as a varint: low 7 bits (0) plus the continuation bit, then 1024 >> 7 = 8.
            let mut prefix = [0u8; 2];
            stream.recv(&mut prefix[..]).await.unwrap();
            assert_eq!(prefix.as_ref(), &[0x80, 0x08]);

            let mut body = vec![0u8; MAX_MESSAGE_SIZE as usize];
            stream.recv(&mut body[..]).await.unwrap();
            assert_eq!(body, payload);
        });
    }

    /// Payloads over the limit are rejected, and the error reports the offending size.
    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut payload = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut payload);

            let sent = send_frame(&mut sink, payload.as_slice(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&sent, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    /// A hand-built frame (varint prefix + payload) is decoded by `recv_frame`.
    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            let mut frame = BytesMut::with_capacity(2 + msg.len());
            frame.put_slice(&[0x80, 0x08]); // varint(1024)
            frame.extend_from_slice(&msg);
            sink.send(frame).await.unwrap();

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(received.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(received, msg.as_ref());
        });
    }

    /// A length prefix over the limit is rejected without any payload being sent.
    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let mut frame = BytesMut::with_capacity(2);
            frame.put_slice(&[0x80, 0x08]); // varint(1024); no payload follows
            sink.send(frame).await.unwrap();

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&received, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    /// A channel that ends partway through the varint prefix surfaces as a receive failure.
    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            // 0x80 has the continuation bit set, so the decoder expects another byte.
            let mut frame = BytesMut::with_capacity(1);
            frame.put_u8(0x80);
            sink.send(frame).await.unwrap();

            // Drop the sink so that the second prefix byte can never arrive.
            drop(sink);
            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&received, Err(Error::RecvFailed(_))));
        });
    }

    /// A varint that encodes more bits than a `u32` can hold is rejected as invalid.
    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            // Five continuation bytes already carry 35 bits (5 x 7), more than 32.
            let mut frame = BytesMut::with_capacity(6);
            frame.put_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]);
            sink.send(frame).await.unwrap();

            let received = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&received, Err(Error::InvalidVarint)));
        });
    }
}