1use crate::utils::codec::{append_frame, framed_len, recv_frame, send_frame};
59use commonware_codec::{DecodeExt, Encode as _, Error as CodecError, FixedSize};
60use commonware_cryptography::{
61 handshake::{
62 self, dial_end, dial_start, listen_end, listen_start, Ack, Context,
63 Error as HandshakeError, RecvCipher, SendCipher, Syn, SynAck,
64 },
65 transcript::Transcript,
66 Signer,
67};
68use commonware_macros::select;
69use commonware_runtime::{
70 BufMut, BufferPool, BufferPooler, Clock, Error as RuntimeError, IoBuf, IoBufMut, IoBufs, Sink,
71 Stream,
72};
73use commonware_utils::{hex, SystemTimeExt};
74use rand_core::CryptoRngCore;
75use std::{future::Future, ops::Range, time::Duration};
76use thiserror::Error;
77
78const TAG_SIZE: u32 = {
79 assert!(handshake::TAG_SIZE <= u32::MAX as usize);
80 handshake::TAG_SIZE as u32
81};
82
83#[derive(Error, Debug)]
85pub enum Error {
86 #[error("handshake error: {0}")]
87 HandshakeError(HandshakeError),
88 #[error("unable to decode: {0}")]
89 UnableToDecode(CodecError),
90 #[error("peer rejected: {}", hex(_0))]
91 PeerRejected(Vec<u8>),
92 #[error("recv failed")]
93 RecvFailed(RuntimeError),
94 #[error("recv too large: {0} bytes")]
95 RecvTooLarge(usize),
96 #[error("invalid varint length prefix")]
97 InvalidVarint,
98 #[error("send failed")]
99 SendFailed(RuntimeError),
100 #[error("send zero size")]
101 SendZeroSize,
102 #[error("send too large: {0} bytes")]
103 SendTooLarge(usize),
104 #[error("connection closed")]
105 StreamClosed,
106 #[error("handshake timed out")]
107 HandshakeTimeout,
108}
109
110impl From<CodecError> for Error {
111 fn from(value: CodecError) -> Self {
112 Self::UnableToDecode(value)
113 }
114}
115
116impl From<HandshakeError> for Error {
117 fn from(value: HandshakeError) -> Self {
118 Self::HandshakeError(value)
119 }
120}
121
/// Parameters for establishing an encrypted connection with [`dial`] or [`listen`].
#[derive(Clone)]
pub struct Config<S> {
    /// Long-term identity key of this side.
    ///
    /// [`dial`] sends its public key as the first frame. Both [`dial`] and [`listen`]
    /// hand the key itself to the handshake context.
    pub signing_key: S,

    /// Bytes used to seed the handshake transcript.
    ///
    /// Peers presumably need matching namespaces for the handshake to succeed.
    pub namespace: Vec<u8>,

    /// Largest plaintext message, in bytes, this side will send or accept.
    ///
    /// It also caps every handshake frame this side sends.
    pub max_message_size: u32,

    /// How far ahead of the local clock a peer's handshake timestamp may be.
    pub synchrony_bound: Duration,

    /// How far behind the local clock a peer's handshake timestamp may be.
    pub max_handshake_age: Duration,

    /// Deadline for the whole handshake; exceeding it yields [`Error::HandshakeTimeout`].
    pub handshake_timeout: Duration,
}
154
155impl<S> Config<S> {
156 pub fn time_information(&self, ctx: &impl Clock) -> (u64, Range<u64>) {
158 fn duration_to_u64(d: Duration) -> u64 {
159 u64::try_from(d.as_millis()).expect("duration ms should fit in an u64")
160 }
161 let current_time_ms = duration_to_u64(ctx.current().epoch());
162 let ok_timestamps = (current_time_ms
163 .saturating_sub(duration_to_u64(self.max_handshake_age)))
164 ..(current_time_ms.saturating_add(duration_to_u64(self.synchrony_bound)));
165 (current_time_ms, ok_timestamps)
166 }
167}
168
169async fn recv_handshake_frame<M, T>(stream: &mut T) -> Result<M, Error>
172where
173 M: DecodeExt<()> + FixedSize,
174 T: Stream,
175{
176 let frame = recv_frame(
177 stream,
178 u32::try_from(M::SIZE).expect("handshake frame should fit in u32"),
179 )
180 .await?;
181 Ok(M::decode(frame)?)
182}
183
184pub async fn dial<R: BufferPooler + CryptoRngCore + Clock, S: Signer, I: Stream, O: Sink>(
187 mut ctx: R,
188 config: Config<S>,
189 peer: S::PublicKey,
190 mut stream: I,
191 mut sink: O,
192) -> Result<(Sender<O>, Receiver<I>), Error> {
193 let pool = ctx.network_buffer_pool().clone();
194 let timeout = ctx.sleep(config.handshake_timeout);
195 let inner_routine = async move {
196 send_frame(
197 &mut sink,
198 config.signing_key.public_key().encode(),
199 config.max_message_size,
200 )
201 .await?;
202
203 let (current_time, ok_timestamps) = config.time_information(&ctx);
204 let (state, syn) = dial_start(
205 &mut ctx,
206 Context::new(
207 &Transcript::new(&config.namespace),
208 current_time,
209 ok_timestamps,
210 config.signing_key,
211 peer,
212 ),
213 );
214 send_frame(&mut sink, syn.encode(), config.max_message_size).await?;
215
216 let syn_ack = recv_handshake_frame::<SynAck<S::Signature>, _>(&mut stream).await?;
217
218 let (ack, send, recv) = dial_end(state, syn_ack)?;
219 send_frame(&mut sink, ack.encode(), config.max_message_size).await?;
220
221 Ok((
222 Sender {
223 cipher: send,
224 sink,
225 max_message_size: config.max_message_size,
226 pool: pool.clone(),
227 },
228 Receiver {
229 cipher: recv,
230 stream,
231 max_message_size: config.max_message_size,
232 pool,
233 },
234 ))
235 };
236
237 select! {
238 x = inner_routine => x,
239 _ = timeout => Err(Error::HandshakeTimeout),
240 }
241}
242
243pub async fn listen<
246 R: BufferPooler + CryptoRngCore + Clock,
247 S: Signer,
248 I: Stream,
249 O: Sink,
250 Fut: Future<Output = bool>,
251 F: FnOnce(S::PublicKey) -> Fut,
252>(
253 mut ctx: R,
254 bouncer: F,
255 config: Config<S>,
256 mut stream: I,
257 mut sink: O,
258) -> Result<(S::PublicKey, Sender<O>, Receiver<I>), Error> {
259 let pool = ctx.network_buffer_pool().clone();
260 let timeout = ctx.sleep(config.handshake_timeout);
261 let inner_routine = async move {
262 let peer = recv_handshake_frame::<S::PublicKey, _>(&mut stream).await?;
263 if !bouncer(peer.clone()).await {
264 return Err(Error::PeerRejected(peer.encode().to_vec()));
265 }
266
267 let msg1 = recv_handshake_frame::<Syn<S::Signature>, _>(&mut stream).await?;
268
269 let (current_time, ok_timestamps) = config.time_information(&ctx);
270 let (state, syn_ack) = listen_start(
271 &mut ctx,
272 Context::new(
273 &Transcript::new(&config.namespace),
274 current_time,
275 ok_timestamps,
276 config.signing_key,
277 peer.clone(),
278 ),
279 msg1,
280 )?;
281 send_frame(&mut sink, syn_ack.encode(), config.max_message_size).await?;
282
283 let ack = recv_handshake_frame::<Ack, _>(&mut stream).await?;
284
285 let (send, recv) = listen_end(state, ack)?;
286
287 Ok((
288 peer,
289 Sender {
290 cipher: send,
291 sink,
292 max_message_size: config.max_message_size,
293 pool: pool.clone(),
294 },
295 Receiver {
296 cipher: recv,
297 stream,
298 max_message_size: config.max_message_size,
299 pool,
300 },
301 ))
302 };
303
304 select! {
305 x = inner_routine => x,
306 _ = timeout => Err(Error::HandshakeTimeout),
307 }
308}
309
310pub struct Sender<O> {
312 cipher: SendCipher,
313 sink: O,
314 max_message_size: u32,
315 pool: BufferPool,
316}
317
318struct ChunkPlan {
320 messages: Vec<IoBufs>,
321 total_len: usize,
322}
323
324impl<O: Sink> Sender<O> {
325 fn encrypted_frame_len(&self, plaintext_len: usize) -> Result<usize, Error> {
329 framed_len(
330 plaintext_len + TAG_SIZE as usize,
331 self.max_message_size.saturating_add(TAG_SIZE),
332 )
333 }
334
335 fn append_encrypted_frame(
341 &mut self,
342 chunk: &mut IoBufMut,
343 mut bufs: IoBufs,
344 ) -> Result<(), Error> {
345 append_frame(
346 chunk,
347 bufs.len() + TAG_SIZE as usize,
348 self.max_message_size.saturating_add(TAG_SIZE),
349 |chunk, plaintext_offset| {
350 chunk.put(&mut bufs);
352
353 let tag = self
355 .cipher
356 .send_in_place(&mut chunk.as_mut()[plaintext_offset..])?;
357 chunk.put_slice(&tag);
358 Ok(())
359 },
360 )?;
361 Ok(())
362 }
363
364 fn build_chunk<I>(&mut self, messages: I, total_len: usize) -> Result<IoBuf, Error>
369 where
370 I: IntoIterator<Item = IoBufs>,
371 {
372 let mut chunk = self.pool.alloc(total_len);
373 for msg in messages {
374 self.append_encrypted_frame(&mut chunk, msg)?;
375 }
376 assert_eq!(chunk.len(), total_len);
377 Ok(chunk.freeze())
378 }
379
380 fn plan_chunks<B, I>(&self, bufs: I) -> Result<Vec<ChunkPlan>, Error>
385 where
386 B: Into<IoBufs>,
387 I: IntoIterator<Item = B>,
388 {
389 let bufs = bufs.into_iter();
390 let (lower, _) = bufs.size_hint();
391 let mut chunks = Vec::with_capacity(lower.max(1));
392 let mut batch = Vec::new();
393 let mut batch_total = 0usize;
394 let max_batch_size = self.pool.config().max_size.get();
395
396 for buf in bufs {
397 let msg = buf.into();
398 let frame_len = self.encrypted_frame_len(msg.len())?;
399
400 if frame_len > max_batch_size {
403 if !batch.is_empty() {
404 chunks.push(ChunkPlan {
405 messages: std::mem::take(&mut batch),
406 total_len: batch_total,
407 });
408 batch_total = 0;
409 }
410 chunks.push(ChunkPlan {
411 messages: vec![msg],
412 total_len: frame_len,
413 });
414 continue;
415 }
416
417 if batch_total.saturating_add(frame_len) > max_batch_size {
420 chunks.push(ChunkPlan {
421 messages: std::mem::take(&mut batch),
422 total_len: batch_total,
423 });
424 batch_total = 0;
425 }
426
427 batch_total += frame_len;
428 batch.push(msg);
429 }
430
431 if !batch.is_empty() {
432 chunks.push(ChunkPlan {
433 messages: batch,
434 total_len: batch_total,
435 });
436 }
437
438 Ok(chunks)
439 }
440
441 pub async fn send(&mut self, bufs: impl Into<IoBufs>) -> Result<(), Error> {
446 let bufs = bufs.into();
447 let frame_len = self.encrypted_frame_len(bufs.len())?;
448 let chunk = self.build_chunk(std::iter::once(bufs), frame_len)?;
449 self.sink.send(chunk).await.map_err(Error::SendFailed)
450 }
451
452 pub async fn send_many<B, I>(&mut self, bufs: I) -> Result<(), Error>
460 where
461 B: Into<IoBufs>,
462 I: IntoIterator<Item = B>,
463 {
464 let plans = self.plan_chunks(bufs)?;
465 if plans.is_empty() {
466 return Ok(());
467 }
468
469 let mut chunks = Vec::with_capacity(plans.len());
470 for plan in plans {
471 chunks.push(self.build_chunk(plan.messages, plan.total_len)?);
472 }
473
474 self.sink
475 .send(IoBufs::from(chunks))
476 .await
477 .map_err(Error::SendFailed)
478 }
479}
480
481pub struct Receiver<I> {
483 cipher: RecvCipher,
484 stream: I,
485 max_message_size: u32,
486 pool: BufferPool,
487}
488
489impl<I: Stream> Receiver<I> {
490 pub async fn recv(&mut self) -> Result<IoBufs, Error> {
495 let mut encrypted = recv_frame(
496 &mut self.stream,
497 self.max_message_size.saturating_add(TAG_SIZE),
498 )
499 .await?;
500 let ciphertext_len = encrypted.len();
501
502 let mut decryption_buf = self.pool.alloc(ciphertext_len);
504
505 decryption_buf.put(&mut encrypted);
507
508 let plaintext_len = self.cipher.recv_in_place(decryption_buf.as_mut())?;
510
511 decryption_buf.truncate(plaintext_len);
513
514 Ok(decryption_buf.freeze().into())
515 }
516}
517
#[cfg(test)]
mod test {
    use super::*;
    use commonware_codec::varint::UInt;
    use commonware_cryptography::{ed25519::PrivateKey, Signer};
    use commonware_runtime::{
        deterministic, mocks, BufferPoolConfig, Error as RuntimeError, IoBuf, IoBufs, Runner as _,
        Spawner as _,
    };
    use commonware_utils::{sync::Mutex, NZUsize};
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    const NAMESPACE: &[u8] = b"fuzz_transport";
    const MAX_MESSAGE_SIZE: u32 = 64 * 1024;

    /// Builds a transport config with one-second time bounds for `signing_key`.
    fn transport_config(signing_key: PrivateKey) -> Config<PrivateKey> {
        Config {
            signing_key,
            namespace: NAMESPACE.to_vec(),
            max_message_size: MAX_MESSAGE_SIZE,
            synchrony_bound: Duration::from_secs(1),
            max_handshake_age: Duration::from_secs(1),
            handshake_timeout: Duration::from_secs(1),
        }
    }

    /// Encodes a varint length prefix claiming one byte more than `message` encodes to.
    fn oversized_handshake_prefix(message: &impl commonware_codec::Encode) -> IoBuf {
        let size = u32::try_from(message.encode().len()).expect("message length should fit in u32");
        IoBuf::from(UInt(size + 1).encode())
    }

    /// Sink wrapper that records each `send` call and how many buffers it carried.
    struct CountingSink<S> {
        inner: S,
        sends: Arc<AtomicUsize>,
        chunk_counts: Arc<Mutex<Vec<usize>>>,
    }

    impl<S> CountingSink<S> {
        fn new(inner: S, sends: Arc<AtomicUsize>, chunk_counts: Arc<Mutex<Vec<usize>>>) -> Self {
            Self {
                inner,
                sends,
                chunk_counts,
            }
        }
    }

    impl<S: commonware_runtime::Sink> commonware_runtime::Sink for CountingSink<S> {
        async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), RuntimeError> {
            let bufs = bufs.into();
            self.sends.fetch_add(1, Ordering::Relaxed);
            self.chunk_counts.lock().push(bufs.chunk_count());
            self.inner.send(bufs).await
        }
    }

    /// Both halves of a handshaken dialer/listener connection over mock channels.
    ///
    /// The dialer's sink is wrapped in a [`CountingSink`]. Its counters are reset after the
    /// handshake, so they reflect application traffic only.
    struct Connection {
        dialer_sender: Sender<CountingSink<mocks::Sink>>,
        dialer_receiver: Receiver<mocks::Stream>,
        listener_sender: Sender<mocks::Sink>,
        listener_receiver: Receiver<mocks::Stream>,
        sends: Arc<AtomicUsize>,
        chunk_counts: Arc<Mutex<Vec<usize>>>,
    }

    /// Runs a full dial/listen handshake between seeds 42 (dialer) and 24 (listener).
    ///
    /// Also asserts that the listener authenticated the dialer's key.
    async fn connect(context: deterministic::Context) -> Result<Connection, Error> {
        let dialer_crypto = PrivateKey::from_seed(42);
        let listener_crypto = PrivateKey::from_seed(24);

        let (dialer_sink, listener_stream) = mocks::Channel::init();
        let (listener_sink, dialer_stream) = mocks::Channel::init();
        let sends = Arc::new(AtomicUsize::new(0));
        let chunk_counts = Arc::new(Mutex::new(Vec::new()));

        let listener_config = transport_config(listener_crypto.clone());
        let listener_handle = context.clone().spawn(move |context| async move {
            listen(
                context,
                |_| async { true },
                listener_config,
                listener_stream,
                listener_sink,
            )
            .await
        });

        let (dialer_sender, dialer_receiver) = dial(
            context,
            transport_config(dialer_crypto.clone()),
            listener_crypto.public_key(),
            dialer_stream,
            CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
        )
        .await?;

        let (listener_peer, listener_sender, listener_receiver) =
            listener_handle.await.unwrap()?;
        assert_eq!(listener_peer, dialer_crypto.public_key());

        // Forget the handshake frames so tests only observe application sends.
        sends.store(0, Ordering::Relaxed);
        chunk_counts.lock().clear();

        Ok(Connection {
            dialer_sender,
            dialer_receiver,
            listener_sender,
            listener_receiver,
            sends,
            chunk_counts,
        })
    }

    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let Connection {
                mut dialer_sender,
                mut dialer_receiver,
                mut listener_sender,
                mut listener_receiver,
                ..
            } = connect(context).await?;

            let messages: Vec<&'static [u8]> = vec![b"A", b"B", b"C"];
            for msg in &messages {
                dialer_sender.send(&msg[..]).await?;
                let received = listener_receiver.recv().await?;
                assert_eq!(received.coalesce(), *msg);

                listener_sender.send(&msg[..]).await?;
                let echoed = dialer_receiver.recv().await?;
                assert_eq!(echoed.coalesce(), *msg);
            }
            Ok(())
        })
    }

    #[test]
    fn test_send_many_uses_single_runtime_send() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut conn = connect(context).await?;

            conn.dialer_sender
                .send_many(vec![
                    IoBufs::from(IoBuf::from(b"alpha")),
                    IoBufs::from(IoBuf::from(b"beta")),
                    IoBufs::from(IoBuf::from(b"gamma")),
                ])
                .await?;

            // All three small frames fit in one pooled buffer, written in one call.
            assert_eq!(conn.sends.load(Ordering::Relaxed), 1);
            assert_eq!(*conn.chunk_counts.lock(), vec![1]);

            let expected: [&[u8]; 3] = [b"alpha", b"beta", b"gamma"];
            for want in expected {
                assert_eq!(conn.listener_receiver.recv().await?.coalesce(), want);
            }
            Ok(())
        })
    }

    #[test]
    fn test_send_many_flushes_at_network_pool_item_max() -> Result<(), Error> {
        let executor = deterministic::Runner::new(
            deterministic::Config::new().with_network_buffer_pool_config(
                BufferPoolConfig::for_network()
                    .with_pool_min_size(256)
                    .with_min_size(NZUsize!(256))
                    .with_max_size(NZUsize!(256)),
            ),
        );
        executor.start(|context| async move {
            let mut conn = connect(context).await?;

            // Each frame is a little over 100 bytes, so two fit in a 256-byte pool item
            // and the third spills into a second chunk.
            let payload = vec![7u8; 100];
            conn.dialer_sender
                .send_many(vec![
                    IoBufs::from(IoBuf::from(payload.clone())),
                    IoBufs::from(IoBuf::from(payload.clone())),
                    IoBufs::from(IoBuf::from(payload.clone())),
                ])
                .await?;

            assert_eq!(conn.sends.load(Ordering::Relaxed), 1);
            assert_eq!(*conn.chunk_counts.lock(), vec![2]);
            for _ in 0..3 {
                assert_eq!(
                    conn.listener_receiver.recv().await?.coalesce(),
                    payload.as_slice()
                );
            }
            Ok(())
        })
    }

    #[test]
    fn test_send_many_sends_oversized_single_message_alone() -> Result<(), Error> {
        let executor = deterministic::Runner::new(
            deterministic::Config::new().with_network_buffer_pool_config(
                BufferPoolConfig::for_network()
                    .with_pool_min_size(128)
                    .with_min_size(NZUsize!(128))
                    .with_max_size(NZUsize!(128)),
            ),
        );
        executor.start(|context| async move {
            let mut conn = connect(context).await?;

            // `large` exceeds a 128-byte pool item on its own; `small` goes in a second chunk.
            let large = vec![3u8; 200];
            let small = vec![9u8; 16];
            conn.dialer_sender
                .send_many(vec![
                    IoBufs::from(IoBuf::from(large.clone())),
                    IoBufs::from(IoBuf::from(small.clone())),
                ])
                .await?;

            assert_eq!(conn.sends.load(Ordering::Relaxed), 1);
            assert_eq!(*conn.chunk_counts.lock(), vec![2]);
            assert_eq!(
                conn.listener_receiver.recv().await?.coalesce(),
                large.as_slice()
            );
            assert_eq!(
                conn.listener_receiver.recv().await?.coalesce(),
                small.as_slice()
            );
            Ok(())
        })
    }

    #[test]
    fn test_send_many_too_large_preserves_sender_state() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut conn = connect(context).await?;

            let valid = vec![7u8; 32];
            let oversized = vec![9u8; MAX_MESSAGE_SIZE as usize + 1];
            assert!(matches!(
                conn.dialer_sender
                    .send_many(vec![
                        IoBufs::from(IoBuf::from(valid)),
                        IoBufs::from(IoBuf::from(oversized)),
                    ])
                    .await,
                Err(Error::SendTooLarge(_))
            ));

            // Nothing hit the wire ...
            assert_eq!(conn.sends.load(Ordering::Relaxed), 0);
            assert!(conn.chunk_counts.lock().is_empty());

            // ... and the cipher is still in sync with the receiver.
            let recovered = b"recovered";
            conn.dialer_sender.send(&recovered[..]).await?;
            assert_eq!(conn.sends.load(Ordering::Relaxed), 1);
            assert_eq!(
                conn.listener_receiver.recv().await?.coalesce(),
                recovered
            );
            Ok(())
        })
    }

    #[test]
    fn test_listen_rejects_oversized_fixed_size_peer_key_frame() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);
            let peer = dialer_crypto.public_key();

            let (mut dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, _dialer_stream) = mocks::Channel::init();

            // A generous message limit shows the rejection comes from the fixed handshake
            // frame size, not from `max_message_size`.
            let mut listener_config = transport_config(listener_crypto);
            listener_config.max_message_size = 1024 * 1024;

            dialer_sink
                .send(oversized_handshake_prefix(&peer))
                .await
                .unwrap();

            let result = listen(
                context,
                |_| async { true },
                listener_config,
                listener_stream,
                listener_sink,
            )
            .await;

            assert!(matches!(result, Err(Error::RecvTooLarge(n)) if n == peer.encode().len() + 1));
        });
    }

    #[test]
    fn test_dial_rejects_oversized_fixed_size_syn_ack_frame() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            let (dialer_sink, _listener_stream) = mocks::Channel::init();
            let (mut listener_sink, dialer_stream) = mocks::Channel::init();

            // As above: the large limit shows the fixed-size check is what rejects the frame.
            let mut dialer_config = transport_config(dialer_crypto);
            dialer_config.max_message_size = 1024 * 1024;

            // Produce a genuine SynAck out of band so its encoded size is exact.
            let (current_time, ok_timestamps) = dialer_config.time_information(&context);
            let mut listener_rng = context.clone();
            let (_, syn) = dial_start(
                context.clone(),
                Context::new(
                    &Transcript::new(&dialer_config.namespace),
                    current_time,
                    ok_timestamps.clone(),
                    dialer_config.signing_key.clone(),
                    listener_crypto.public_key(),
                ),
            );
            let (_, syn_ack) = listen_start(
                &mut listener_rng,
                Context::new(
                    &Transcript::new(&dialer_config.namespace),
                    current_time,
                    ok_timestamps,
                    listener_crypto.clone(),
                    dialer_config.signing_key.public_key(),
                ),
                syn,
            )
            .expect("mock handshake should produce a valid syn_ack");

            listener_sink
                .send(oversized_handshake_prefix(&syn_ack))
                .await
                .unwrap();

            let result = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                dialer_sink,
            )
            .await;

            assert!(matches!(
                result,
                Err(Error::RecvTooLarge(n))
                    if n == syn_ack.encode().len() + 1
            ));
        });
    }
}