1use crate::utils::codec::{append_frame, framed_len, recv_frame, send_frame};
59use commonware_codec::{DecodeExt, Encode as _, Error as CodecError, FixedSize};
60use commonware_cryptography::{
61 handshake::{
62 self, dial_end, dial_start, listen_end, listen_start, Ack, Context,
63 Error as HandshakeError, RecvCipher, SendCipher, Syn, SynAck,
64 },
65 transcript::Transcript,
66 Signer,
67};
68use commonware_formatting::hex;
69use commonware_macros::select;
70use commonware_runtime::{
71 BufMut, BufferPool, BufferPooler, Clock, Error as RuntimeError, IoBuf, IoBufMut, IoBufs, Sink,
72 Stream,
73};
74use commonware_utils::SystemTimeExt;
75use rand_core::CryptoRngCore;
76use std::{future::Future, ops::Range, time::Duration};
77use thiserror::Error;
78
79const TAG_SIZE: u32 = {
80 assert!(handshake::TAG_SIZE <= u32::MAX as usize);
81 handshake::TAG_SIZE as u32
82};
83
84#[derive(Error, Debug)]
86pub enum Error {
87 #[error("handshake error: {0}")]
88 HandshakeError(HandshakeError),
89 #[error("unable to decode: {0}")]
90 UnableToDecode(CodecError),
91 #[error("peer rejected: {}", hex(_0))]
92 PeerRejected(Vec<u8>),
93 #[error("recv failed")]
94 RecvFailed(RuntimeError),
95 #[error("recv too large: {0} bytes")]
96 RecvTooLarge(usize),
97 #[error("invalid varint length prefix")]
98 InvalidVarint,
99 #[error("send failed")]
100 SendFailed(RuntimeError),
101 #[error("send zero size")]
102 SendZeroSize,
103 #[error("send too large: {0} bytes")]
104 SendTooLarge(usize),
105 #[error("connection closed")]
106 StreamClosed,
107 #[error("handshake timed out")]
108 HandshakeTimeout,
109}
110
111impl From<CodecError> for Error {
112 fn from(value: CodecError) -> Self {
113 Self::UnableToDecode(value)
114 }
115}
116
117impl From<HandshakeError> for Error {
118 fn from(value: HandshakeError) -> Self {
119 Self::HandshakeError(value)
120 }
121}
122
/// Settings shared by [`dial`] and [`listen`] when establishing an encrypted connection.
#[derive(Clone)]
pub struct Config<S> {
    /// Long-term signing key for this side of the handshake. Its public key is sent as the
    /// first frame by [`dial`], and the key itself is passed to the handshake `Context`.
    pub signing_key: S,

    /// Bytes used to seed the handshake `Transcript`.
    /// NOTE(review): presumably must equal the peer's namespace for the handshake to
    /// succeed — confirm against the handshake module.
    pub namespace: Vec<u8>,

    /// Largest plaintext message, in bytes, that [`Sender`] will send and [`Receiver`] will
    /// accept (the authentication tag is allowed on top of this). Also used as the size limit
    /// when sending handshake frames.
    pub max_message_size: u32,

    /// How far ahead of the local clock a handshake timestamp may be. It becomes the
    /// (exclusive) upper bound of the range returned by `Config::time_information`.
    pub synchrony_bound: Duration,

    /// How far behind the local clock a handshake timestamp may be. It becomes the lower
    /// bound of the range returned by `Config::time_information`.
    pub max_handshake_age: Duration,

    /// Budget for the whole of [`dial`] / [`listen`]. When it elapses first, they return
    /// [`Error::HandshakeTimeout`].
    pub handshake_timeout: Duration,
}
155
156impl<S> Config<S> {
157 pub fn time_information(&self, ctx: &impl Clock) -> (u64, Range<u64>) {
159 fn duration_to_u64(d: Duration) -> u64 {
160 u64::try_from(d.as_millis()).expect("duration ms should fit in an u64")
161 }
162 let current_time_ms = duration_to_u64(ctx.current().epoch());
163 let ok_timestamps = (current_time_ms
164 .saturating_sub(duration_to_u64(self.max_handshake_age)))
165 ..(current_time_ms.saturating_add(duration_to_u64(self.synchrony_bound)));
166 (current_time_ms, ok_timestamps)
167 }
168}
169
170async fn recv_handshake_frame<M, T>(stream: &mut T) -> Result<M, Error>
173where
174 M: DecodeExt<()> + FixedSize,
175 T: Stream,
176{
177 let frame = recv_frame(
178 stream,
179 u32::try_from(M::SIZE).expect("handshake frame should fit in u32"),
180 )
181 .await?;
182 Ok(M::decode(frame)?)
183}
184
185pub async fn dial<R: BufferPooler + CryptoRngCore + Clock, S: Signer, I: Stream, O: Sink>(
188 ctx: R,
189 config: Config<S>,
190 peer: S::PublicKey,
191 mut stream: I,
192 mut sink: O,
193) -> Result<(Sender<O>, Receiver<I>), Error> {
194 let pool = ctx.network_buffer_pool().clone();
195 let timeout = ctx.sleep(config.handshake_timeout);
196 let inner_routine = async move {
197 send_frame(
198 &mut sink,
199 config.signing_key.public_key().encode(),
200 config.max_message_size,
201 )
202 .await?;
203
204 let (current_time, ok_timestamps) = config.time_information(&ctx);
205 let (state, syn) = dial_start(
206 ctx,
207 Context::new(
208 &Transcript::new(&config.namespace),
209 current_time,
210 ok_timestamps,
211 config.signing_key,
212 peer,
213 ),
214 );
215 send_frame(&mut sink, syn.encode(), config.max_message_size).await?;
216
217 let syn_ack = recv_handshake_frame::<SynAck<S::Signature>, _>(&mut stream).await?;
218
219 let (ack, send, recv) = dial_end(state, syn_ack)?;
220 send_frame(&mut sink, ack.encode(), config.max_message_size).await?;
221
222 Ok((
223 Sender {
224 cipher: send,
225 sink,
226 max_message_size: config.max_message_size,
227 pool: pool.clone(),
228 },
229 Receiver {
230 cipher: recv,
231 stream,
232 max_message_size: config.max_message_size,
233 pool,
234 },
235 ))
236 };
237
238 select! {
239 x = inner_routine => x,
240 _ = timeout => Err(Error::HandshakeTimeout),
241 }
242}
243
244pub async fn listen<
247 R: BufferPooler + CryptoRngCore + Clock,
248 S: Signer,
249 I: Stream,
250 O: Sink,
251 Fut: Future<Output = bool>,
252 F: FnOnce(S::PublicKey) -> Fut,
253>(
254 ctx: R,
255 bouncer: F,
256 config: Config<S>,
257 mut stream: I,
258 mut sink: O,
259) -> Result<(S::PublicKey, Sender<O>, Receiver<I>), Error> {
260 let pool = ctx.network_buffer_pool().clone();
261 let timeout = ctx.sleep(config.handshake_timeout);
262 let inner_routine = async move {
263 let peer = recv_handshake_frame::<S::PublicKey, _>(&mut stream).await?;
264 if !bouncer(peer.clone()).await {
265 return Err(Error::PeerRejected(peer.encode().to_vec()));
266 }
267
268 let msg1 = recv_handshake_frame::<Syn<S::Signature>, _>(&mut stream).await?;
269
270 let (current_time, ok_timestamps) = config.time_information(&ctx);
271 let (state, syn_ack) = listen_start(
272 ctx,
273 Context::new(
274 &Transcript::new(&config.namespace),
275 current_time,
276 ok_timestamps,
277 config.signing_key,
278 peer.clone(),
279 ),
280 msg1,
281 )?;
282 send_frame(&mut sink, syn_ack.encode(), config.max_message_size).await?;
283
284 let ack = recv_handshake_frame::<Ack, _>(&mut stream).await?;
285
286 let (send, recv) = listen_end(state, ack)?;
287
288 Ok((
289 peer,
290 Sender {
291 cipher: send,
292 sink,
293 max_message_size: config.max_message_size,
294 pool: pool.clone(),
295 },
296 Receiver {
297 cipher: recv,
298 stream,
299 max_message_size: config.max_message_size,
300 pool,
301 },
302 ))
303 };
304
305 select! {
306 x = inner_routine => x,
307 _ = timeout => Err(Error::HandshakeTimeout),
308 }
309}
310
/// Encrypting half of a connection produced by [`dial`] or [`listen`].
pub struct Sender<O> {
    /// Cipher that encrypts each outgoing message in place and produces its tag.
    cipher: SendCipher,
    /// Underlying transport that receives the framed ciphertext.
    sink: O,
    /// Largest plaintext accepted per message (the tag is allowed on top).
    max_message_size: u32,
    /// Buffer pool used to allocate outgoing frames; its `max_size` also caps how many bytes
    /// `send_many` packs into one buffer.
    pool: BufferPool,
}
318
/// A group of plaintext messages that `Sender::send_many` encrypts into one contiguous buffer.
struct ChunkPlan {
    /// Messages to encrypt, in send order.
    messages: Vec<IoBufs>,
    /// Exact number of framed, encrypted bytes the messages will occupy.
    total_len: usize,
}
324
325impl<O: Sink> Sender<O> {
326 fn encrypted_frame_len(&self, plaintext_len: usize) -> Result<usize, Error> {
330 framed_len(
331 plaintext_len + TAG_SIZE as usize,
332 self.max_message_size.saturating_add(TAG_SIZE),
333 )
334 }
335
336 fn append_encrypted_frame(
342 &mut self,
343 chunk: &mut IoBufMut,
344 mut bufs: IoBufs,
345 ) -> Result<(), Error> {
346 append_frame(
347 chunk,
348 bufs.len() + TAG_SIZE as usize,
349 self.max_message_size.saturating_add(TAG_SIZE),
350 |chunk, plaintext_offset| {
351 chunk.put(&mut bufs);
353
354 let tag = self
356 .cipher
357 .send_in_place(&mut chunk.as_mut()[plaintext_offset..])?;
358 chunk.put_slice(&tag);
359 Ok(())
360 },
361 )?;
362 Ok(())
363 }
364
365 fn build_chunk<I>(&mut self, messages: I, total_len: usize) -> Result<IoBuf, Error>
370 where
371 I: IntoIterator<Item = IoBufs>,
372 {
373 let mut chunk = self.pool.alloc(total_len);
374 for msg in messages {
375 self.append_encrypted_frame(&mut chunk, msg)?;
376 }
377 assert_eq!(chunk.len(), total_len);
378 Ok(chunk.freeze())
379 }
380
381 fn plan_chunks<B, I>(&self, bufs: I) -> Result<Vec<ChunkPlan>, Error>
386 where
387 B: Into<IoBufs>,
388 I: IntoIterator<Item = B>,
389 {
390 let bufs = bufs.into_iter();
391 let (lower, _) = bufs.size_hint();
392 let mut chunks = Vec::with_capacity(lower.max(1));
393 let mut batch = Vec::new();
394 let mut batch_total = 0usize;
395 let max_batch_size = self.pool.config().max_size.get();
396
397 for buf in bufs {
398 let msg = buf.into();
399 let frame_len = self.encrypted_frame_len(msg.len())?;
400
401 if frame_len > max_batch_size {
404 if !batch.is_empty() {
405 chunks.push(ChunkPlan {
406 messages: std::mem::take(&mut batch),
407 total_len: batch_total,
408 });
409 batch_total = 0;
410 }
411 chunks.push(ChunkPlan {
412 messages: vec![msg],
413 total_len: frame_len,
414 });
415 continue;
416 }
417
418 if batch_total.saturating_add(frame_len) > max_batch_size {
421 chunks.push(ChunkPlan {
422 messages: std::mem::take(&mut batch),
423 total_len: batch_total,
424 });
425 batch_total = 0;
426 }
427
428 batch_total += frame_len;
429 batch.push(msg);
430 }
431
432 if !batch.is_empty() {
433 chunks.push(ChunkPlan {
434 messages: batch,
435 total_len: batch_total,
436 });
437 }
438
439 Ok(chunks)
440 }
441
442 pub async fn send(&mut self, bufs: impl Into<IoBufs>) -> Result<(), Error> {
447 let bufs = bufs.into();
448 let frame_len = self.encrypted_frame_len(bufs.len())?;
449 let chunk = self.build_chunk(std::iter::once(bufs), frame_len)?;
450 self.sink.send(chunk).await.map_err(Error::SendFailed)
451 }
452
453 pub async fn send_many<B, I>(&mut self, bufs: I) -> Result<(), Error>
461 where
462 B: Into<IoBufs>,
463 I: IntoIterator<Item = B>,
464 {
465 let plans = self.plan_chunks(bufs)?;
466 if plans.is_empty() {
467 return Ok(());
468 }
469
470 let mut chunks = Vec::with_capacity(plans.len());
471 for plan in plans {
472 chunks.push(self.build_chunk(plan.messages, plan.total_len)?);
473 }
474
475 self.sink
476 .send(IoBufs::from(chunks))
477 .await
478 .map_err(Error::SendFailed)
479 }
480}
481
/// Decrypting half of a connection produced by [`dial`] or [`listen`].
pub struct Receiver<I> {
    /// Cipher that authenticates and decrypts each incoming frame in place.
    cipher: RecvCipher,
    /// Underlying transport that frames are read from.
    stream: I,
    /// Largest plaintext accepted per message; frames may be at most this plus `TAG_SIZE`.
    max_message_size: u32,
    /// Buffer pool used for the decryption buffer of each received message.
    pool: BufferPool,
}
489
490impl<I: Stream> Receiver<I> {
491 pub async fn recv(&mut self) -> Result<IoBufs, Error> {
496 let mut encrypted = recv_frame(
497 &mut self.stream,
498 self.max_message_size.saturating_add(TAG_SIZE),
499 )
500 .await?;
501 let ciphertext_len = encrypted.len();
502
503 let mut decryption_buf = self.pool.alloc(ciphertext_len);
505
506 decryption_buf.put(&mut encrypted);
508
509 let plaintext_len = self.cipher.recv_in_place(decryption_buf.as_mut())?;
511
512 decryption_buf.truncate(plaintext_len);
514
515 Ok(decryption_buf.freeze().into())
516 }
517}
518
/// Tests for the handshake (`dial` / `listen`) and the encrypted `Sender` / `Receiver`
/// halves, run on the deterministic runtime over in-memory mock channels.
#[cfg(test)]
mod test {
    use super::*;
    use commonware_codec::varint::UInt;
    use commonware_cryptography::{ed25519::PrivateKey, Signer};
    use commonware_runtime::{
        deterministic, mocks, BufferPoolConfig, Error as RuntimeError, IoBuf, IoBufs, Runner as _,
        Spawner as _, Supervisor as _,
    };
    use commonware_utils::{sync::Mutex, NZUsize};
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    const NAMESPACE: &[u8] = b"fuzz_transport";
    const MAX_MESSAGE_SIZE: u32 = 64 * 1024;

    /// Builds a config with a shared namespace, a 64 KiB message limit and 1-second time
    /// bounds and timeout.
    fn transport_config(signing_key: PrivateKey) -> Config<PrivateKey> {
        Config {
            signing_key,
            namespace: NAMESPACE.to_vec(),
            max_message_size: MAX_MESSAGE_SIZE,
            synchrony_bound: Duration::from_secs(1),
            max_handshake_age: Duration::from_secs(1),
            handshake_timeout: Duration::from_secs(1),
        }
    }

    /// Returns a bare varint length prefix announcing one byte more than `message`'s encoded
    /// length, to exercise the fixed-size limit on handshake frames.
    fn oversized_handshake_prefix(message: &impl commonware_codec::Encode) -> IoBuf {
        let size = u32::try_from(message.encode().len()).expect("message length should fit in u32");
        IoBuf::from(UInt(size + 1).encode())
    }

    /// Sink wrapper that counts `send` calls and records each call's chunk count before
    /// forwarding to the inner sink.
    struct CountingSink<S> {
        inner: S,
        sends: Arc<AtomicUsize>,
        chunk_counts: Arc<Mutex<Vec<usize>>>,
    }

    impl<S> CountingSink<S> {
        fn new(inner: S, sends: Arc<AtomicUsize>, chunk_counts: Arc<Mutex<Vec<usize>>>) -> Self {
            Self {
                inner,
                sends,
                chunk_counts,
            }
        }
    }

    impl<S: commonware_runtime::Sink> commonware_runtime::Sink for CountingSink<S> {
        async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), RuntimeError> {
            let bufs = bufs.into();
            self.sends.fetch_add(1, Ordering::Relaxed);
            self.chunk_counts.lock().push(bufs.chunk_count());
            self.inner.send(bufs).await
        }
    }

    /// Full handshake followed by a ping-pong of small messages in both directions.
    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config = transport_config(dialer_crypto.clone());
            let listener_config = transport_config(listener_crypto.clone());

            let listener_handle = context.child("listener").spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_sender, mut dialer_receiver) = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                dialer_sink,
            )
            .await?;

            let (listener_peer, mut listener_sender, mut listener_receiver) =
                listener_handle.await.unwrap()?;
            assert_eq!(listener_peer, dialer_crypto.public_key());
            let messages: Vec<&'static [u8]> = vec![b"A", b"B", b"C"];
            for msg in &messages {
                // NOTE(review): `syn_ack` / `ack` are just received application messages
                // here, not handshake messages; the names are misleading.
                dialer_sender.send(&msg[..]).await?;
                let syn_ack = listener_receiver.recv().await?;
                assert_eq!(syn_ack.coalesce(), *msg);
                listener_sender.send(&msg[..]).await?;
                let ack = dialer_receiver.recv().await?;
                assert_eq!(ack.coalesce(), *msg);
            }
            Ok(())
        })
    }

    /// Small messages passed to `send_many` are packed into one buffer and one sink `send`.
    #[test]
    fn test_send_many_uses_single_runtime_send() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();
            let sends = Arc::new(AtomicUsize::new(0));
            let chunk_counts = Arc::new(Mutex::new(Vec::new()));

            let dialer_config = transport_config(dialer_crypto.clone());
            let listener_config = transport_config(listener_crypto.clone());

            let listener_handle = context.child("listener").spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_sender, _dialer_receiver) = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
            )
            .await?;

            let (_listener_peer, _listener_sender, mut listener_receiver) =
                listener_handle.await.unwrap()?;
            // Discard the handshake's own sends so only `send_many` is measured.
            sends.store(0, Ordering::Relaxed);
            chunk_counts.lock().clear();

            dialer_sender
                .send_many(vec![
                    IoBufs::from(IoBuf::from(b"alpha")),
                    IoBufs::from(IoBuf::from(b"beta")),
                    IoBufs::from(IoBuf::from(b"gamma")),
                ])
                .await?;

            assert_eq!(sends.load(Ordering::Relaxed), 1);
            assert_eq!(*chunk_counts.lock(), vec![1]);
            assert_eq!(
                listener_receiver.recv().await?.coalesce(),
                IoBuf::from(b"alpha")
            );
            assert_eq!(
                listener_receiver.recv().await?.coalesce(),
                IoBuf::from(b"beta")
            );
            assert_eq!(
                listener_receiver.recv().await?.coalesce(),
                IoBuf::from(b"gamma")
            );
            Ok(())
        })
    }

    /// With a 256-byte pool item limit, only two ~100-byte frames fit per buffer, so three
    /// messages become two buffers, still delivered in a single sink `send`.
    #[test]
    fn test_send_many_flushes_at_network_pool_item_max() -> Result<(), Error> {
        let executor = deterministic::Runner::new(
            deterministic::Config::new().with_network_buffer_pool_config(
                BufferPoolConfig::for_network()
                    .with_pool_min_size(256)
                    .with_min_size(NZUsize!(256))
                    .with_max_size(NZUsize!(256)),
            ),
        );
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();
            let sends = Arc::new(AtomicUsize::new(0));
            let chunk_counts = Arc::new(Mutex::new(Vec::new()));

            let dialer_config = transport_config(dialer_crypto.clone());
            let listener_config = transport_config(listener_crypto.clone());

            let listener_handle = context.child("listener").spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_sender, _dialer_receiver) = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
            )
            .await?;

            let (_listener_peer, _listener_sender, mut listener_receiver) =
                listener_handle.await.unwrap()?;
            // Discard the handshake's own sends so only `send_many` is measured.
            sends.store(0, Ordering::Relaxed);
            chunk_counts.lock().clear();

            // Each framed message is 100 bytes + tag + length prefix: two fit under 256 bytes,
            // three do not.
            let payload = vec![7u8; 100];
            dialer_sender
                .send_many(vec![
                    IoBufs::from(IoBuf::from(payload.clone())),
                    IoBufs::from(IoBuf::from(payload.clone())),
                    IoBufs::from(IoBuf::from(payload.clone())),
                ])
                .await?;

            assert_eq!(sends.load(Ordering::Relaxed), 1);
            assert_eq!(*chunk_counts.lock(), vec![2]);
            for _ in 0..3 {
                assert_eq!(
                    listener_receiver.recv().await?.coalesce(),
                    payload.as_slice()
                );
            }
            Ok(())
        })
    }

    /// A message whose frame exceeds the pool item limit gets its own buffer, and the
    /// following small message starts a new one.
    #[test]
    fn test_send_many_sends_oversized_single_message_alone() -> Result<(), Error> {
        let executor = deterministic::Runner::new(
            deterministic::Config::new().with_network_buffer_pool_config(
                BufferPoolConfig::for_network()
                    .with_pool_min_size(128)
                    .with_min_size(NZUsize!(128))
                    .with_max_size(NZUsize!(128)),
            ),
        );
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();
            let sends = Arc::new(AtomicUsize::new(0));
            let chunk_counts = Arc::new(Mutex::new(Vec::new()));

            let dialer_config = transport_config(dialer_crypto.clone());
            let listener_config = transport_config(listener_crypto.clone());

            let listener_handle = context.child("listener").spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_sender, _dialer_receiver) = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
            )
            .await?;

            let (_listener_peer, _listener_sender, mut listener_receiver) =
                listener_handle.await.unwrap()?;
            // Discard the handshake's own sends so only `send_many` is measured.
            sends.store(0, Ordering::Relaxed);
            chunk_counts.lock().clear();

            // `large` alone exceeds the 128-byte pool item limit; `small` fits easily.
            let large = vec![3u8; 200];
            let small = vec![9u8; 16];
            dialer_sender
                .send_many(vec![
                    IoBufs::from(IoBuf::from(large.clone())),
                    IoBufs::from(IoBuf::from(small.clone())),
                ])
                .await?;

            assert_eq!(sends.load(Ordering::Relaxed), 1);
            assert_eq!(*chunk_counts.lock(), vec![2]);
            assert_eq!(listener_receiver.recv().await?.coalesce(), large.as_slice());
            assert_eq!(listener_receiver.recv().await?.coalesce(), small.as_slice());
            Ok(())
        })
    }

    /// An oversized message makes `send_many` fail before anything is encrypted or sent, so
    /// the sender (and its cipher) stays usable afterwards.
    #[test]
    fn test_send_many_too_large_preserves_sender_state() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();
            let sends = Arc::new(AtomicUsize::new(0));
            let chunk_counts = Arc::new(Mutex::new(Vec::new()));

            let dialer_config = transport_config(dialer_crypto.clone());
            let listener_config = transport_config(listener_crypto.clone());

            let listener_handle = context.child("listener").spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_sender, _dialer_receiver) = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
            )
            .await?;

            let (_listener_peer, _listener_sender, mut listener_receiver) =
                listener_handle.await.unwrap()?;
            // Discard the handshake's own sends so only `send_many` is measured.
            sends.store(0, Ordering::Relaxed);
            chunk_counts.lock().clear();

            let valid = vec![7u8; 32];
            let oversized = vec![9u8; MAX_MESSAGE_SIZE as usize + 1];
            assert!(matches!(
                dialer_sender
                    .send_many(vec![
                        IoBufs::from(IoBuf::from(valid)),
                        IoBufs::from(IoBuf::from(oversized)),
                    ])
                    .await,
                Err(Error::SendTooLarge(_))
            ));

            assert_eq!(sends.load(Ordering::Relaxed), 0);
            assert!(chunk_counts.lock().is_empty());

            // If the failed batch had advanced the cipher, the listener could not decrypt this.
            let recovered = b"recovered";
            dialer_sender.send(&recovered[..]).await?;
            assert_eq!(sends.load(Ordering::Relaxed), 1);
            assert_eq!(listener_receiver.recv().await?.coalesce(), recovered);
            Ok(())
        })
    }

    /// The listener limits the first (public key) frame to the key's fixed size, even when
    /// `max_message_size` is much larger.
    #[test]
    fn test_listen_rejects_oversized_fixed_size_peer_key_frame() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);
            let peer = dialer_crypto.public_key();

            let (mut dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, _dialer_stream) = mocks::Channel::init();

            // A generous message limit shows the rejection comes from the fixed handshake
            // frame size, not from `max_message_size`.
            let mut listener_config = transport_config(listener_crypto);
            listener_config.max_message_size = 1024 * 1024;

            // Only a length prefix (key size + 1) is sent; no body follows.
            dialer_sink
                .send(oversized_handshake_prefix(&peer))
                .await
                .unwrap();

            let result = listen(
                context,
                |_| async { true },
                listener_config,
                listener_stream,
                listener_sink,
            )
            .await;

            assert!(matches!(result, Err(Error::RecvTooLarge(n)) if n == peer.encode().len() + 1));
        });
    }

    /// The dialer limits the `SynAck` frame to its fixed size, even when `max_message_size`
    /// is much larger.
    #[test]
    fn test_dial_rejects_oversized_fixed_size_syn_ack_frame() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, _listener_stream) = mocks::Channel::init();
            let (mut listener_sink, dialer_stream) = mocks::Channel::init();

            // A generous message limit shows the rejection comes from the fixed handshake
            // frame size, not from `max_message_size`.
            let mut dialer_config = transport_config(dialer_crypto);
            dialer_config.max_message_size = 1024 * 1024;

            // Run the handshake primitives directly to obtain a real `SynAck`, so its
            // encoded size is the one the dialer expects.
            let (current_time, ok_timestamps) = dialer_config.time_information(&context);
            let listener_public_key = listener_crypto.public_key();
            let dialer_public_key = dialer_config.signing_key.public_key();
            let (_, syn) = dial_start(
                context.child("dialer"),
                Context::new(
                    &Transcript::new(&dialer_config.namespace),
                    current_time,
                    ok_timestamps.clone(),
                    dialer_config.signing_key.clone(),
                    listener_public_key.clone(),
                ),
            );
            let (_, syn_ack) = listen_start(
                context.child("listener"),
                Context::new(
                    &Transcript::new(&dialer_config.namespace),
                    current_time,
                    ok_timestamps,
                    listener_crypto,
                    dialer_public_key,
                ),
                syn,
            )
            .expect("mock handshake should produce a valid syn_ack");

            // Only a length prefix (SynAck size + 1) is sent; no body follows.
            listener_sink
                .send(oversized_handshake_prefix(&syn_ack))
                .await
                .unwrap();

            let result = dial(
                context,
                dialer_config,
                listener_public_key,
                dialer_stream,
                dialer_sink,
            )
            .await;

            // The dialer's own public key and Syn go into the unread channel; the failure
            // comes from reading the oversized SynAck prefix.
            assert!(matches!(
                result,
                Err(Error::RecvTooLarge(n))
                    if n == syn_ack.encode().len() + 1
            ));
        });
    }
}