1use crate::encrypted::Error;
2use commonware_codec::{
3 varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4 Encode,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufs, Sink, Stream};
7
8pub(crate) async fn send_frame_with<S: Sink>(
17 sink: &mut S,
18 payload_len: usize,
19 max_message_size: u32,
20 assemble: impl FnOnce(UInt<u32>) -> Result<IoBufs, Error>,
21) -> Result<(), Error> {
22 if payload_len > max_message_size as usize {
24 return Err(Error::SendTooLarge(payload_len));
25 }
26 let prefix = UInt(payload_len as u32);
27
28 let frame = assemble(prefix)?;
29 sink.send(frame).await.map_err(Error::SendFailed)
30}
31
32pub async fn send_frame<S: Sink>(
39 sink: &mut S,
40 bufs: impl Into<IoBufs> + Send,
41 max_message_size: u32,
42) -> Result<(), Error> {
43 let mut bufs = bufs.into();
44
45 send_frame_with(sink, bufs.len(), max_message_size, |prefix| {
46 bufs.prepend(IoBuf::from(prefix.encode()));
48 Ok(bufs)
49 })
50 .await
51}
52
53pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
57 let (len, skip) = recv_length(stream).await?;
58 if len > max_message_size as usize {
59 return Err(Error::RecvTooLarge(len));
60 }
61
62 stream
63 .recv(skip + len)
64 .await
65 .map(|mut bufs| {
66 bufs.advance(skip);
67 bufs
68 })
69 .map_err(Error::RecvFailed)
70}
71
72async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
77 let mut decoder = Decoder::<u32>::new();
78
79 let peeked = {
81 let peeked = stream.peek(MAX_U32_VARINT_SIZE);
82 for (i, byte) in peeked.iter().enumerate() {
83 match decoder.feed(*byte) {
84 Ok(Some(len)) => return Ok((len as usize, i + 1)),
85 Ok(None) => continue,
86 Err(_) => return Err(Error::InvalidVarint),
87 }
88 }
89 peeked.len()
90 };
91
92 let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
94 buf.advance(peeked);
95
96 loop {
97 match decoder.feed(buf.get_u8()) {
98 Ok(Some(len)) => return Ok((len as usize, 0)),
99 Ok(None) => {}
100 Err(_) => return Err(Error::InvalidVarint),
101 }
102 buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
103 }
104}
105
106#[cfg(test)]
107mod tests {
108 use super::*;
109 use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
110 use rand::Rng;
111
112 const MAX_MESSAGE_SIZE: u32 = 1024;
113
114 #[test]
115 fn test_send_recv_at_max_message_size() {
116 let (mut sink, mut stream) = mocks::Channel::init();
117
118 let executor = deterministic::Runner::default();
119 executor.start(|mut context| async move {
120 let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
121 context.fill(&mut buf);
122
123 let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
124 assert!(result.is_ok());
125
126 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
127 assert_eq!(data.len(), buf.len());
128 assert_eq!(data.coalesce(), buf);
129 });
130 }
131
132 #[test]
133 fn test_send_recv_multiple() {
134 let (mut sink, mut stream) = mocks::Channel::init();
135
136 let executor = deterministic::Runner::default();
137 executor.start(|mut context| async move {
138 let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
139 let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
140 context.fill(&mut buf1);
141 context.fill(&mut buf2);
142
143 let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await;
145 assert!(result.is_ok());
146 let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await;
147 assert!(result.is_ok());
148
149 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
151 assert_eq!(data.len(), buf1.len());
152 assert_eq!(data.coalesce(), buf1);
153 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
154 assert_eq!(data.len(), buf2.len());
155 assert_eq!(data.coalesce(), buf2);
156 });
157 }
158
159 #[test]
160 fn test_send_frame() {
161 let (mut sink, mut stream) = mocks::Channel::init();
162
163 let executor = deterministic::Runner::default();
164 executor.start(|mut context| async move {
165 let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
166 context.fill(&mut buf);
167
168 let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
169 assert!(result.is_ok());
170
171 let read = stream.recv(2).await.unwrap();
174 assert_eq!(read.coalesce(), &[0x80, 0x08]); let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
176 assert_eq!(read.coalesce(), buf);
177 });
178 }
179
180 #[test]
181 fn test_send_frame_with_closure_error() {
182 let (mut sink, _) = mocks::Channel::init();
183
184 let executor = deterministic::Runner::default();
185 executor.start(|_| async move {
186 let result = send_frame_with(&mut sink, 10, MAX_MESSAGE_SIZE, |_prefix| {
187 Err(Error::HandshakeError(
188 commonware_cryptography::handshake::Error::EncryptionFailed,
189 ))
190 })
191 .await;
192 assert!(matches!(&result, Err(Error::HandshakeError(_))));
193 });
194 }
195
196 #[test]
197 fn test_send_frame_with_too_large() {
198 let (mut sink, _) = mocks::Channel::init();
199
200 let executor = deterministic::Runner::default();
201 executor.start(|_| async move {
202 let result = send_frame_with(
203 &mut sink,
204 MAX_MESSAGE_SIZE as usize + 1,
205 MAX_MESSAGE_SIZE,
206 |_prefix| unreachable!(),
207 )
208 .await;
209 assert!(
210 matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
211 );
212 });
213 }
214
215 #[test]
216 fn test_send_frame_too_large() {
217 let (mut sink, _) = mocks::Channel::init();
218
219 let executor = deterministic::Runner::default();
220 executor.start(|mut context| async move {
221 let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
222 context.fill(&mut buf);
223
224 let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
225 assert!(
226 matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
227 );
228 });
229 }
230
231 #[test]
232 fn test_read_frame() {
233 let (mut sink, mut stream) = mocks::Channel::init();
234
235 let executor = deterministic::Runner::default();
236 executor.start(|mut context| async move {
237 let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
239 context.fill(&mut msg);
240
241 let mut buf = IoBufMut::with_capacity(2 + msg.len());
243 buf.put_u8(0x80);
244 buf.put_u8(0x08);
245 buf.put_slice(&msg);
246 sink.send(buf.freeze()).await.unwrap();
247
248 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
249 assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
250 assert_eq!(data.coalesce(), msg);
251 });
252 }
253
254 #[test]
255 fn test_read_frame_too_large() {
256 let (mut sink, mut stream) = mocks::Channel::init();
257
258 let executor = deterministic::Runner::default();
259 executor.start(|_| async move {
260 let mut buf = IoBufMut::with_capacity(2);
263 buf.put_u8(0x80);
264 buf.put_u8(0x08);
265 sink.send(buf.freeze()).await.unwrap();
266
267 let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
268 assert!(
269 matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
270 );
271 });
272 }
273
274 #[test]
275 fn test_recv_frame_incomplete_varint() {
276 let (mut sink, mut stream) = mocks::Channel::init();
277
278 let executor = deterministic::Runner::default();
279 executor.start(|_| async move {
280 let mut buf = IoBufMut::with_capacity(1);
282 buf.put_u8(0x80); sink.send(buf.freeze()).await.unwrap();
285 drop(sink); let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
289 assert!(matches!(&result, Err(Error::RecvFailed(_))));
290 });
291 }
292
293 #[test]
294 fn test_recv_frame_invalid_varint_overflow() {
295 let (mut sink, mut stream) = mocks::Channel::init();
296
297 let executor = deterministic::Runner::default();
298 executor.start(|_| async move {
299 let mut buf = IoBufMut::with_capacity(6);
301 buf.put_u8(0xFF); buf.put_u8(0xFF); buf.put_u8(0xFF); buf.put_u8(0xFF); buf.put_u8(0xFF); buf.put_u8(0x01); sink.send(buf.freeze()).await.unwrap();
309
310 let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
312 assert!(matches!(&result, Err(Error::InvalidVarint)));
313 });
314 }
315
316 #[test]
317 fn test_recv_frame_peek_paths() {
318 let executor = deterministic::Runner::default();
319 executor.start(|mut context| async move {
320 let mut payload = vec![0u8; 300];
322 context.fill(&mut payload[..]);
323
324 let (mut sink, mut stream) = mocks::Channel::init();
326 send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
327 .await
328 .unwrap();
329 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
330 assert_eq!(data.coalesce(), &payload[..]);
331
332 let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
334 send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
335 .await
336 .unwrap();
337 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
338 assert_eq!(data.coalesce(), &payload[..]);
339
340 let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
342 send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
343 .await
344 .unwrap();
345 let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
346 assert_eq!(data.coalesce(), &payload[..]);
347 });
348 }
349}