commonware_stream/utils/
codec.rs1use crate::Error;
2use bytes::{BufMut as _, Bytes, BytesMut};
3use commonware_runtime::{Sink, Stream};
4
5pub async fn send_frame<S: Sink>(
8 sink: &mut S,
9 buf: &[u8],
10 max_message_size: usize,
11) -> Result<(), Error> {
12 let n = buf.len();
14 if n == 0 {
15 return Err(Error::SendZeroSize);
16 }
17 if n > max_message_size {
18 return Err(Error::SendTooLarge(n));
19 }
20
21 let mut prefixed_buf = BytesMut::with_capacity(4 + buf.len());
23 let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?;
24 prefixed_buf.put_u32(len);
25 prefixed_buf.extend_from_slice(buf);
26 sink.send(prefixed_buf).await.map_err(Error::SendFailed)
27}
28
29pub async fn recv_frame<T: Stream>(
32 stream: &mut T,
33 max_message_size: usize,
34) -> Result<Bytes, Error> {
35 let len_buf = stream.recv(vec![0; 4]).await.map_err(Error::RecvFailed)?;
37
38 let len = u32::from_be_bytes(len_buf.as_ref()[..4].try_into().unwrap()) as usize;
40 if len > max_message_size {
41 return Err(Error::RecvTooLarge(len));
42 }
43 if len == 0 {
44 return Err(Error::StreamClosed);
45 }
46
47 let read = stream.recv(vec![0; len]).await.map_err(Error::RecvFailed)?;
49 Ok(read.into())
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: usize = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            // `unwrap` (rather than `assert!(is_ok())`) reports the error on failure.
            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data, Bytes::from(buf.to_vec()));
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE];
            let mut buf2 = [0u8; MAX_MESSAGE_SIZE / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            send_frame(&mut sink, &buf1, MAX_MESSAGE_SIZE).await.unwrap();
            send_frame(&mut sink, &buf2, MAX_MESSAGE_SIZE).await.unwrap();

            // Frames must come back in order and with their boundaries intact.
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data, Bytes::from(buf1.to_vec()));
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data, Bytes::from(buf2.to_vec()));
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await.unwrap();

            // Raw wire check: 4-byte big-endian length, then the payload verbatim.
            let read = stream.recv(vec![0; 4]).await.unwrap();
            assert_eq!(read.as_ref(), (buf.len() as u32).to_be_bytes());
            let read = stream.recv(vec![0; MAX_MESSAGE_SIZE]).await.unwrap();
            assert_eq!(read.as_ref(), buf);
        });
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut buf);

            // A limit one byte below the payload size must be rejected.
            let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    #[test]
    fn test_send_zero_size() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = send_frame(&mut sink, &[], MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::SendZeroSize)));
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE];
            context.fill(&mut msg);

            // Hand-build a frame so `recv_frame` is tested independently of `send_frame`.
            let mut buf = BytesMut::with_capacity(4 + msg.len());
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            buf.extend_from_slice(&msg);
            sink.send(buf).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE);
            assert_eq!(data, msg.as_ref());
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Only the prefix is sent: the limit check must fire before any payload read.
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(MAX_MESSAGE_SIZE as u32);
            sink.send(buf).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
        });
    }

    #[test]
    fn test_read_zero_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut buf = BytesMut::with_capacity(4);
            buf.put_u32(0);
            sink.send(buf).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::StreamClosed)));
        });
    }

    #[test]
    fn test_recv_frame_short_length_prefix() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Three bytes is one short of a full length prefix.
            let mut buf = BytesMut::with_capacity(3);
            buf.put_u8(0x00);
            buf.put_u8(0x00);
            buf.put_u8(0x00);
            sink.send(buf).await.unwrap();

            // Closing the sink makes the incomplete read fail instead of waiting forever.
            drop(sink);

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_short_payload() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A valid prefix announcing 4 bytes, but only 2 payload bytes follow.
            let mut buf = BytesMut::with_capacity(4 + 2);
            buf.put_u32(4);
            buf.put_u8(0x01);
            buf.put_u8(0x02);
            sink.send(buf).await.unwrap();

            // Closing the sink makes the incomplete payload read fail.
            drop(sink);

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }
}