1use crate::encrypted::Error;
2use commonware_codec::{
3 varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
4 Encode, EncodeSize, Write,
5};
6use commonware_runtime::{Buf, IoBuf, IoBufMut, IoBufs, Sink, Stream};
7
8pub(crate) fn build_frame<T>(
17 payload_len: usize,
18 max_message_size: u32,
19 assemble: impl FnOnce(UInt<u32>) -> Result<T, Error>,
20) -> Result<T, Error> {
21 if payload_len > max_message_size as usize {
22 return Err(Error::SendTooLarge(payload_len));
23 }
24 let prefix = UInt(payload_len as u32);
25 assemble(prefix)
26}
27
28pub(crate) fn framed_len(payload_len: usize, max_message_size: u32) -> Result<usize, Error> {
30 build_frame(payload_len, max_message_size, |prefix| {
31 Ok(prefix.encode_size() + payload_len)
32 })
33}
34
35pub(crate) fn append_frame(
40 frame: &mut IoBufMut,
41 payload_len: usize,
42 max_message_size: u32,
43 append_payload: impl FnOnce(&mut IoBufMut, usize) -> Result<(), Error>,
44) -> Result<usize, Error> {
45 build_frame(payload_len, max_message_size, |prefix| {
46 let start = frame.len();
47 prefix.write(frame);
48 let payload_offset = frame.len();
49 append_payload(frame, payload_offset)?;
50 assert_eq!(frame.len() - payload_offset, payload_len);
51 Ok(frame.len() - start)
52 })
53}
54
55pub async fn send_frame<S: Sink>(
62 sink: &mut S,
63 bufs: impl Into<IoBufs> + Send,
64 max_message_size: u32,
65) -> Result<(), Error> {
66 let mut bufs = bufs.into();
67
68 let frame = build_frame(bufs.len(), max_message_size, |prefix| {
69 bufs.prepend(IoBuf::from(prefix.encode()));
70 Ok(bufs)
71 })?;
72 sink.send(frame).await.map_err(Error::SendFailed)
73}
74
75pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
79 let (len, skip) = recv_length(stream).await?;
80 if len > max_message_size as usize {
81 return Err(Error::RecvTooLarge(len));
82 }
83
84 stream
85 .recv(skip + len)
86 .await
87 .map(|mut bufs| {
88 bufs.advance(skip);
89 bufs
90 })
91 .map_err(Error::RecvFailed)
92}
93
/// Reads the varint length prefix of the next frame from `stream`.
///
/// Returns `(len, skip)`:
/// - `len` is the decoded payload length.
/// - `skip` is the number of prefix bytes still sitting unconsumed in `stream`.
///   The caller must read and discard them before the payload.
///
/// There are two paths:
/// - Fast path: the prefix is decoded entirely from bytes exposed by
///   `stream.peek`, which leaves them in the stream. `skip` is then the prefix
///   length.
/// - Slow path: the peeked bytes do not contain a complete varint. The peeked
///   bytes are consumed with `recv` and the rest of the prefix is read one byte
///   at a time. The whole prefix is consumed, so `skip` is `0`.
///
/// # Errors
///
/// - [`Error::InvalidVarint`] if the bytes do not form a valid `u32` varint.
/// - [`Error::RecvFailed`] if the stream fails while reading, e.g. when it is
///   closed mid-prefix.
async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
    // The decoder keeps its state across both paths, so each byte is fed exactly once.
    let mut decoder = Decoder::<u32>::new();

    // Fast path: try to decode the prefix from already-buffered bytes without
    // consuming them.
    let peeked = {
        // May yield fewer than `MAX_U32_VARINT_SIZE` bytes, possibly none.
        let peeked = stream.peek(MAX_U32_VARINT_SIZE);
        for (i, byte) in peeked.iter().enumerate() {
            match decoder.feed(*byte) {
                // The prefix is the first `i + 1` peeked bytes. They are still in
                // the stream, so report them as bytes to skip.
                Ok(Some(len)) => return Ok((len as usize, i + 1)),
                Ok(None) => continue,
                Err(_) => return Err(Error::InvalidVarint),
            }
        }
        // Number of bytes already fed to `decoder` but not yet consumed.
        peeked.len()
    };

    // Slow path: consume the already-fed bytes plus the next unseen byte. Then
    // skip the already-fed ones so that only the new byte remains in `buf`.
    let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
    buf.advance(peeked);

    // Feed one byte at a time until the varint completes. Every prefix byte has
    // now been consumed from the stream, so there is nothing left for the
    // caller to skip.
    loop {
        match decoder.feed(buf.get_u8()) {
            Ok(Some(len)) => return Ok((len as usize, 0)),
            Ok(None) => {}
            Err(_) => return Err(Error::InvalidVarint),
        }
        buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
    }
}
127
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{
        deterministic, mocks, BufMut, IoBufMut, Runner, Spawner, Supervisor as _,
    };
    use rand::Rng;

    const MAX_MESSAGE_SIZE: u32 = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send both frames back to back before reading either.
            let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());
            let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Frames must come back in order with their boundaries intact.
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Varint encoding of 1024 (MAX_MESSAGE_SIZE).
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]);
            let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    #[test]
    fn test_build_frame_closure_error() {
        let result: Result<IoBufs, _> = build_frame(10, MAX_MESSAGE_SIZE, |_prefix| {
            Err(Error::HandshakeError(
                commonware_cryptography::handshake::Error::EncryptionFailed,
            ))
        });
        assert!(matches!(&result, Err(Error::HandshakeError(_))));
    }

    #[test]
    fn test_build_frame_too_large() {
        let result: Result<IoBufs, _> = build_frame(
            MAX_MESSAGE_SIZE as usize + 1,
            MAX_MESSAGE_SIZE,
            |_prefix| unreachable!(),
        );
        assert!(
            matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
        );
    }

    #[test]
    fn test_framed_len() {
        // A length that fits in 7 bits needs a 1-byte prefix.
        assert!(matches!(&framed_len(0, MAX_MESSAGE_SIZE), Ok(n) if *n == 1));
        assert!(matches!(&framed_len(127, MAX_MESSAGE_SIZE), Ok(n) if *n == 128));
        // Lengths from 128 up need a 2-byte prefix.
        assert!(matches!(&framed_len(128, MAX_MESSAGE_SIZE), Ok(n) if *n == 130));
        assert!(matches!(
            &framed_len(MAX_MESSAGE_SIZE as usize, MAX_MESSAGE_SIZE),
            Ok(n) if *n == MAX_MESSAGE_SIZE as usize + 2
        ));
        // Anything over the limit is rejected.
        assert!(matches!(
            &framed_len(MAX_MESSAGE_SIZE as usize + 1, MAX_MESSAGE_SIZE),
            Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1
        ));
    }

    #[test]
    fn test_append_frame() {
        let payload = [1u8, 2, 3];
        let mut frame = IoBufMut::with_capacity(16);
        // Pre-existing content must be preserved and excluded from the returned length.
        frame.put_slice(&[0xAA, 0xBB]);
        let start = frame.len();

        let written = append_frame(&mut frame, payload.len(), MAX_MESSAGE_SIZE, |buf, offset| {
            // A length of 3 needs a 1-byte prefix, so the payload starts one byte in.
            assert_eq!(offset, start + 1);
            buf.put_slice(&payload);
            Ok(())
        })
        .unwrap();

        assert_eq!(written, 1 + payload.len());
        assert_eq!(frame.len(), start + written);
    }

    #[test]
    fn test_append_frame_payload_error() {
        let mut frame = IoBufMut::with_capacity(16);
        let result = append_frame(&mut frame, 3, MAX_MESSAGE_SIZE, |_buf, _offset| {
            Err(Error::HandshakeError(
                commonware_cryptography::handshake::Error::EncryptionFailed,
            ))
        });
        assert!(matches!(&result, Err(Error::HandshakeError(_))));
    }

    #[test]
    fn test_append_frame_too_large() {
        let mut frame = IoBufMut::with_capacity(4);
        let before = frame.len();
        let result = append_frame(
            &mut frame,
            MAX_MESSAGE_SIZE as usize + 1,
            MAX_MESSAGE_SIZE,
            |_buf, _offset| unreachable!(),
        );
        assert!(
            matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
        );
        // Nothing, not even the prefix, is written when the size check fails.
        assert_eq!(frame.len(), before);
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // Hand-build a frame: varint(1024) followed by the payload.
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Only the prefix (1024) is sent. It must be rejected before any payload is read.
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A lone continuation byte: the prefix never completes.
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80);
            sink.send(buf.freeze()).await.unwrap();

            // Close the sender so that reading the rest of the prefix fails.
            drop(sink);
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // More continuation bytes than a u32 varint can hold.
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0xFF);
            buf.put_u8(0x01);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 bytes gives a 2-byte length prefix.
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Whole frame already buffered before the receive.
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Zero-sized buffer with a concurrent sender. Per the task name, this
            // targets the case where peek yields no bytes.
            let (mut sink, mut stream) = mocks::Channel::init_with_buffer_size(0);
            let payload2 = payload.clone();
            let send_handle = context.child("sender_empty_peek").spawn(|_| async move {
                send_frame(&mut sink, payload2, MAX_MESSAGE_SIZE)
                    .await
                    .unwrap();
            });
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
            send_handle.await.unwrap();

            // One-byte buffer. Per the task name, this targets the case where
            // peek yields only part of the prefix.
            let (mut sink, mut stream) = mocks::Channel::init_with_buffer_size(1);
            let payload2 = payload.clone();
            let send_handle = context.child("sender_partial_peek").spawn(|_| async move {
                send_frame(&mut sink, payload2, MAX_MESSAGE_SIZE)
                    .await
                    .unwrap();
            });
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
            send_handle.await.unwrap();
        });
    }
}