Skip to main content

commonware_stream/
encrypted.rs

1//! Encrypted stream implementation using ChaCha20-Poly1305.
2//!
3//! # Design
4//!
5//! ## Handshake
6//!
//! cf. [commonware_cryptography::handshake]. One difference here is that the listener does not
8//! know the dialer's public key in advance. Instead, the dialer tells the listener its public key
9//! in the first message. The listener has an opportunity to reject the connection if it does not
10//! wish to connect ([listen] takes in an arbitrary function to implement this).
11//!
12//! ## Encryption
13//!
14//! All traffic is encrypted using ChaCha20-Poly1305. A shared secret is established using an
15//! ephemeral X25519 Diffie-Hellman key exchange. This secret, combined with the handshake
16//! transcript, is used to derive keys for both the handshake's key confirmation messages and
17//! the post-handshake data traffic. Binding the derived keys to the handshake transcript prevents
18//! man-in-the-middle and transcript substitution attacks.
19//!
20//! Each directional cipher uses a 12-byte nonce derived from a counter that is incremented for each
21//! message sent. This counter has sufficient cardinality for over 2.5 trillion years of continuous
22//! communication at a rate of 1 billion messages per second - sufficient for all practical use cases.
23//! This ensures that well-behaving peers can remain connected indefinitely as long as they both
24//! remain online (maximizing p2p network stability). In the unlikely case of counter overflow, the
25//! connection will be terminated and a new connection should be established. This method prevents
26//! nonce reuse (which would compromise message confidentiality) while saving bandwidth (as there is
27//! no need to transmit nonces explicitly).
28//!
29//! # Security
30//!
31//! ## Requirements
32//!
33//! - **Pre-Shared Namespace**: Peers must agree on a unique, application-specific namespace
34//!   out-of-band to prevent cross-application replay attacks.
35//! - **Time Synchronization**: Peer clocks must be synchronized to within the `synchrony_bound`
36//!   to correctly validate timestamps.
37//!
38//! ## Provided
39//!
40//! - **Mutual Authentication**: Both parties prove ownership of their static private keys through
41//!   signatures.
42//! - **Forward Secrecy**: Ephemeral encryption keys ensure that any compromise of long-term static keys
43//!   doesn't expose the contents of previous sessions.
//! - **Session Uniqueness**: A listener's [commonware_cryptography::handshake::SynAck] is bound to
//!   the dialer's [commonware_cryptography::handshake::Syn] message, and
//!   [commonware_cryptography::handshake::Ack]s are bound to the complete handshake transcript,
//!   preventing replay attacks and ensuring message integrity.
47//! - **Handshake Timeout**: A configurable deadline is enforced for handshake completion to protect
48//!   against malicious peers that create connections but abandon handshakes.
49//!
50//! ## Not Provided
51//!
//! - **Anonymity**: Peer identities are not hidden from network observers (active or passive)
//!   during handshakes.
54//! - **Padding**: Messages are encrypted as-is, allowing an attacker to perform traffic analysis.
55//! - **Future Secrecy**: If a peer's static private key is compromised, future sessions will be exposed.
56//! - **0-RTT**: The protocol does not support 0-RTT handshakes (resumed sessions).
57
58use crate::utils::codec::{append_frame, framed_len, recv_frame, send_frame};
59use commonware_codec::{DecodeExt, Encode as _, Error as CodecError, FixedSize};
60use commonware_cryptography::{
61    handshake::{
62        self, dial_end, dial_start, listen_end, listen_start, Ack, Context,
63        Error as HandshakeError, RecvCipher, SendCipher, Syn, SynAck,
64    },
65    transcript::Transcript,
66    Signer,
67};
68use commonware_macros::select;
69use commonware_runtime::{
70    BufMut, BufferPool, BufferPooler, Clock, Error as RuntimeError, IoBuf, IoBufMut, IoBufs, Sink,
71    Stream,
72};
73use commonware_utils::{hex, SystemTimeExt};
74use rand_core::CryptoRngCore;
75use std::{future::Future, ops::Range, time::Duration};
76use thiserror::Error;
77
78const TAG_SIZE: u32 = {
79    assert!(handshake::TAG_SIZE <= u32::MAX as usize);
80    handshake::TAG_SIZE as u32
81};
82
83/// Errors that can occur when interacting with a stream.
84#[derive(Error, Debug)]
85pub enum Error {
86    #[error("handshake error: {0}")]
87    HandshakeError(HandshakeError),
88    #[error("unable to decode: {0}")]
89    UnableToDecode(CodecError),
90    #[error("peer rejected: {}", hex(_0))]
91    PeerRejected(Vec<u8>),
92    #[error("recv failed")]
93    RecvFailed(RuntimeError),
94    #[error("recv too large: {0} bytes")]
95    RecvTooLarge(usize),
96    #[error("invalid varint length prefix")]
97    InvalidVarint,
98    #[error("send failed")]
99    SendFailed(RuntimeError),
100    #[error("send zero size")]
101    SendZeroSize,
102    #[error("send too large: {0} bytes")]
103    SendTooLarge(usize),
104    #[error("connection closed")]
105    StreamClosed,
106    #[error("handshake timed out")]
107    HandshakeTimeout,
108}
109
110impl From<CodecError> for Error {
111    fn from(value: CodecError) -> Self {
112        Self::UnableToDecode(value)
113    }
114}
115
116impl From<HandshakeError> for Error {
117    fn from(value: HandshakeError) -> Self {
118        Self::HandshakeError(value)
119    }
120}
121
/// Configuration for a connection.
///
/// # Warning
///
/// Synchronize this configuration across all peers.
/// Mismatched configurations may cause dropped connections or parsing errors.
#[derive(Clone)]
pub struct Config<S> {
    /// The private key used for signing messages.
    ///
    /// This proves our own identity to other peers.
    pub signing_key: S,

    /// Unique prefix for all signed messages. Should be application-specific.
    /// Prevents replay attacks across different applications using the same keys.
    ///
    /// Used to seed the handshake [`Transcript`].
    pub namespace: Vec<u8>,

    /// Maximum message size (in bytes). Prevents memory exhaustion DoS attacks.
    ///
    /// For post-handshake traffic this bounds the plaintext; the AEAD tag is added on top
    /// when computing the frame limit.
    ///
    /// Fixed-size handshake frames use their protocol-defined sizes instead of
    /// inheriting this limit.
    pub max_message_size: u32,

    /// Maximum time drift allowed for future timestamps. Handles clock skew.
    ///
    /// Converted to whole milliseconds by [`Config::time_information`] to form the upper
    /// (exclusive) end of the accepted timestamp window.
    pub synchrony_bound: Duration,

    /// Maximum age of handshake messages before rejection.
    ///
    /// Converted to whole milliseconds by [`Config::time_information`] to form the lower
    /// end of the accepted timestamp window.
    pub max_handshake_age: Duration,

    /// The allotted time for the handshake to complete.
    ///
    /// When exceeded, [`dial`] and [`listen`] return [`Error::HandshakeTimeout`].
    pub handshake_timeout: Duration,
}
154
155impl<S> Config<S> {
156    /// Computes current time and acceptable timestamp range.
157    pub fn time_information(&self, ctx: &impl Clock) -> (u64, Range<u64>) {
158        fn duration_to_u64(d: Duration) -> u64 {
159            u64::try_from(d.as_millis()).expect("duration ms should fit in an u64")
160        }
161        let current_time_ms = duration_to_u64(ctx.current().epoch());
162        let ok_timestamps = (current_time_ms
163            .saturating_sub(duration_to_u64(self.max_handshake_age)))
164            ..(current_time_ms.saturating_add(duration_to_u64(self.synchrony_bound)));
165        (current_time_ms, ok_timestamps)
166    }
167}
168
169// Handshake frames are fixed-size protocol messages, so we cap receives to
170// their exact encoded length instead of the application message limit.
171async fn recv_handshake_frame<M, T>(stream: &mut T) -> Result<M, Error>
172where
173    M: DecodeExt<()> + FixedSize,
174    T: Stream,
175{
176    let frame = recv_frame(
177        stream,
178        u32::try_from(M::SIZE).expect("handshake frame should fit in u32"),
179    )
180    .await?;
181    Ok(M::decode(frame)?)
182}
183
184/// Establishes an authenticated connection to a peer as the dialer.
185/// Returns sender and receiver for encrypted communication.
186pub async fn dial<R: BufferPooler + CryptoRngCore + Clock, S: Signer, I: Stream, O: Sink>(
187    mut ctx: R,
188    config: Config<S>,
189    peer: S::PublicKey,
190    mut stream: I,
191    mut sink: O,
192) -> Result<(Sender<O>, Receiver<I>), Error> {
193    let pool = ctx.network_buffer_pool().clone();
194    let timeout = ctx.sleep(config.handshake_timeout);
195    let inner_routine = async move {
196        send_frame(
197            &mut sink,
198            config.signing_key.public_key().encode(),
199            config.max_message_size,
200        )
201        .await?;
202
203        let (current_time, ok_timestamps) = config.time_information(&ctx);
204        let (state, syn) = dial_start(
205            &mut ctx,
206            Context::new(
207                &Transcript::new(&config.namespace),
208                current_time,
209                ok_timestamps,
210                config.signing_key,
211                peer,
212            ),
213        );
214        send_frame(&mut sink, syn.encode(), config.max_message_size).await?;
215
216        let syn_ack = recv_handshake_frame::<SynAck<S::Signature>, _>(&mut stream).await?;
217
218        let (ack, send, recv) = dial_end(state, syn_ack)?;
219        send_frame(&mut sink, ack.encode(), config.max_message_size).await?;
220
221        Ok((
222            Sender {
223                cipher: send,
224                sink,
225                max_message_size: config.max_message_size,
226                pool: pool.clone(),
227            },
228            Receiver {
229                cipher: recv,
230                stream,
231                max_message_size: config.max_message_size,
232                pool,
233            },
234        ))
235    };
236
237    select! {
238        x = inner_routine => x,
239        _ = timeout => Err(Error::HandshakeTimeout),
240    }
241}
242
243/// Accepts an authenticated connection from a peer as the listener.
244/// Returns the peer's identity, sender, and receiver for encrypted communication.
245pub async fn listen<
246    R: BufferPooler + CryptoRngCore + Clock,
247    S: Signer,
248    I: Stream,
249    O: Sink,
250    Fut: Future<Output = bool>,
251    F: FnOnce(S::PublicKey) -> Fut,
252>(
253    mut ctx: R,
254    bouncer: F,
255    config: Config<S>,
256    mut stream: I,
257    mut sink: O,
258) -> Result<(S::PublicKey, Sender<O>, Receiver<I>), Error> {
259    let pool = ctx.network_buffer_pool().clone();
260    let timeout = ctx.sleep(config.handshake_timeout);
261    let inner_routine = async move {
262        let peer = recv_handshake_frame::<S::PublicKey, _>(&mut stream).await?;
263        if !bouncer(peer.clone()).await {
264            return Err(Error::PeerRejected(peer.encode().to_vec()));
265        }
266
267        let msg1 = recv_handshake_frame::<Syn<S::Signature>, _>(&mut stream).await?;
268
269        let (current_time, ok_timestamps) = config.time_information(&ctx);
270        let (state, syn_ack) = listen_start(
271            &mut ctx,
272            Context::new(
273                &Transcript::new(&config.namespace),
274                current_time,
275                ok_timestamps,
276                config.signing_key,
277                peer.clone(),
278            ),
279            msg1,
280        )?;
281        send_frame(&mut sink, syn_ack.encode(), config.max_message_size).await?;
282
283        let ack = recv_handshake_frame::<Ack, _>(&mut stream).await?;
284
285        let (send, recv) = listen_end(state, ack)?;
286
287        Ok((
288            peer,
289            Sender {
290                cipher: send,
291                sink,
292                max_message_size: config.max_message_size,
293                pool: pool.clone(),
294            },
295            Receiver {
296                cipher: recv,
297                stream,
298                max_message_size: config.max_message_size,
299                pool,
300            },
301        ))
302    };
303
304    select! {
305        x = inner_routine => x,
306        _ = timeout => Err(Error::HandshakeTimeout),
307    }
308}
309
/// Sends encrypted messages to a peer.
///
/// Created by [`dial`] or [`listen`] once the handshake completes.
pub struct Sender<O> {
    /// Directional cipher for outgoing traffic, produced by the handshake.
    cipher: SendCipher,
    /// Underlying transport sink that framed ciphertext is written to.
    sink: O,
    /// Maximum plaintext size per message; the AEAD tag is added on top for the frame limit.
    max_message_size: u32,
    /// Network buffer pool from which outgoing chunk buffers are allocated.
    pool: BufferPool,
}
317
/// Describes one contiguous sink chunk made up of one or more encrypted frames.
///
/// Produced by `Sender::plan_chunks` before any encryption happens.
struct ChunkPlan {
    /// Plaintext messages to encrypt into this chunk, in send order.
    messages: Vec<IoBufs>,
    /// Exact encoded length of the chunk: the sum of each message's framed size
    /// (length prefix, ciphertext, and AEAD tag).
    total_len: usize,
}
323
/// Framing and encryption for outgoing messages.
///
/// Nonces are derived from a per-direction counter and are never transmitted. Every frame
/// produced here therefore consumes the next send nonce, and frames must reach the peer in
/// exactly the order they were encrypted.
impl<O: Sink> Sender<O> {
    /// Returns the total encoded size of one encrypted frame.
    ///
    /// The returned size includes the length prefix, ciphertext, and AEAD tag.
    ///
    /// # Errors
    ///
    /// Returns whatever `framed_len` reports for an unacceptable payload. This is presumably
    /// [`Error::SendTooLarge`] when the plaintext exceeds `max_message_size`; confirm in
    /// `utils::codec`.
    fn encrypted_frame_len(&self, plaintext_len: usize) -> Result<usize, Error> {
        // Both the payload and the limit include the tag, so `max_message_size` bounds
        // the plaintext alone.
        framed_len(
            plaintext_len + TAG_SIZE as usize,
            self.max_message_size.saturating_add(TAG_SIZE),
        )
    }

    /// Appends one encrypted frame directly into caller-provided storage.
    ///
    /// This lets chunk builders append multiple independently framed
    /// ciphertexts into a single contiguous allocation without staging each
    /// frame in its own buffer first.
    ///
    /// Advances the send nonce by one on success.
    fn append_encrypted_frame(
        &mut self,
        chunk: &mut IoBufMut,
        mut bufs: IoBufs,
    ) -> Result<(), Error> {
        append_frame(
            chunk,
            bufs.len() + TAG_SIZE as usize,
            self.max_message_size.saturating_add(TAG_SIZE),
            // NOTE(review): `plaintext_offset` appears to be where this frame's payload
            // starts inside `chunk` (after the length prefix written by `append_frame`);
            // confirm in `utils::codec`.
            |chunk, plaintext_offset| {
                // Copy the plaintext directly into the frame.
                chunk.put(&mut bufs);

                // Encrypt in-place and append the tag to the frame.
                let tag = self
                    .cipher
                    .send_in_place(&mut chunk.as_mut()[plaintext_offset..])?;
                chunk.put_slice(&tag);
                Ok(())
            },
        )?;
        Ok(())
    }

    /// Builds one contiguous chunk containing one or more encrypted frames.
    ///
    /// Callers compute `total_len` up front so this helper can allocate once,
    /// append each framed ciphertext in order, and freeze the result.
    ///
    /// # Panics
    ///
    /// Panics if the appended frames do not add up to exactly `total_len` bytes.
    fn build_chunk<I>(&mut self, messages: I, total_len: usize) -> Result<IoBuf, Error>
    where
        I: IntoIterator<Item = IoBufs>,
    {
        let mut chunk = self.pool.alloc(total_len);
        for msg in messages {
            self.append_encrypted_frame(&mut chunk, msg)?;
        }
        // Guards the agreement between the planning pass and the actual encoding.
        assert_eq!(chunk.len(), total_len);
        Ok(chunk.freeze())
    }

    /// Plans `send_many` chunk boundaries without consuming cipher state.
    ///
    /// This validation pass ensures any oversize error is reported before
    /// encryption advances nonces, so the sender remains usable after failure.
    ///
    /// Messages are packed greedily, in order, into chunks of at most one network
    /// buffer-pool item (`max_size`). A message whose frame alone exceeds that cap becomes
    /// a chunk of its own.
    fn plan_chunks<B, I>(&self, bufs: I) -> Result<Vec<ChunkPlan>, Error>
    where
        B: Into<IoBufs>,
        I: IntoIterator<Item = B>,
    {
        let bufs = bufs.into_iter();
        let (lower, _) = bufs.size_hint();
        // Every chunk holds at least one message, so the message count bounds the chunk count.
        let mut chunks = Vec::with_capacity(lower.max(1));
        let mut batch = Vec::new();
        let mut batch_total = 0usize;
        let max_batch_size = self.pool.config().max_size.get();

        for buf in bufs {
            let msg = buf.into();
            let frame_len = self.encrypted_frame_len(msg.len())?;

            // If one framed message is larger than the pooled batch cap, keep
            // current chunks intact and send that message as its own chunk.
            if frame_len > max_batch_size {
                if !batch.is_empty() {
                    chunks.push(ChunkPlan {
                        messages: std::mem::take(&mut batch),
                        total_len: batch_total,
                    });
                    batch_total = 0;
                }
                chunks.push(ChunkPlan {
                    messages: vec![msg],
                    total_len: frame_len,
                });
                continue;
            }

            // Close the current chunk before it would exceed one network
            // buffer-pool item.
            if batch_total.saturating_add(frame_len) > max_batch_size {
                chunks.push(ChunkPlan {
                    messages: std::mem::take(&mut batch),
                    total_len: batch_total,
                });
                batch_total = 0;
            }

            batch_total += frame_len;
            batch.push(msg);
        }

        if !batch.is_empty() {
            chunks.push(ChunkPlan {
                messages: batch,
                total_len: batch_total,
            });
        }

        Ok(chunks)
    }

    /// Encrypts and sends a message to the peer.
    ///
    /// Allocates a buffer from the pool, copies plaintext, encrypts in-place,
    /// and sends the ciphertext.
    ///
    /// # Errors
    ///
    /// Size errors from framing are reported before any encryption. Cipher errors are
    /// propagated, and sink failures are returned as [`Error::SendFailed`].
    pub async fn send(&mut self, bufs: impl Into<IoBufs>) -> Result<(), Error> {
        let bufs = bufs.into();
        let frame_len = self.encrypted_frame_len(bufs.len())?;
        let chunk = self.build_chunk(std::iter::once(bufs), frame_len)?;
        self.sink.send(chunk).await.map_err(Error::SendFailed)
    }

    /// Encrypts and sends multiple messages in a single sink call.
    ///
    /// Each message is framed independently so receivers still observe the
    /// original message boundaries. Aggregate writes are broken into contiguous
    /// chunks capped to one network buffer-pool item, then submitted together as
    /// a chunked `IoBufs`. An individual message larger than that cap is still
    /// sent as its own chunk.
    ///
    /// An empty input sends nothing and succeeds.
    ///
    /// # Errors
    ///
    /// Size errors for any message are reported before any encryption, leaving the sender
    /// usable. Sink failures are returned as [`Error::SendFailed`].
    pub async fn send_many<B, I>(&mut self, bufs: I) -> Result<(), Error>
    where
        B: Into<IoBufs>,
        I: IntoIterator<Item = B>,
    {
        let plans = self.plan_chunks(bufs)?;
        if plans.is_empty() {
            return Ok(());
        }

        // NOTE(review): if a cipher error occurs while building a later chunk, the earlier
        // chunks have already consumed nonces but are dropped unsent.
        let mut chunks = Vec::with_capacity(plans.len());
        for plan in plans {
            chunks.push(self.build_chunk(plan.messages, plan.total_len)?);
        }

        self.sink
            .send(IoBufs::from(chunks))
            .await
            .map_err(Error::SendFailed)
    }
}
480
/// Receives encrypted messages from a peer.
///
/// Created by [`dial`] or [`listen`] once the handshake completes.
pub struct Receiver<I> {
    /// Directional cipher for incoming traffic, produced by the handshake.
    cipher: RecvCipher,
    /// Underlying transport stream that framed ciphertext is read from.
    stream: I,
    /// Maximum plaintext size per message; the AEAD tag is added on top for the frame limit.
    max_message_size: u32,
    /// Network buffer pool from which decryption buffers are allocated.
    pool: BufferPool,
}
488
489impl<I: Stream> Receiver<I> {
490    /// Receives and decrypts a message from the peer.
491    ///
492    /// Receives ciphertext, allocates a buffer from the pool, copies ciphertext,
493    /// and decrypts in-place.
494    pub async fn recv(&mut self) -> Result<IoBufs, Error> {
495        let mut encrypted = recv_frame(
496            &mut self.stream,
497            self.max_message_size.saturating_add(TAG_SIZE),
498        )
499        .await?;
500        let ciphertext_len = encrypted.len();
501
502        // Allocate buffer from pool for decryption.
503        let mut decryption_buf = self.pool.alloc(ciphertext_len);
504
505        // Copy ciphertext into buffer.
506        decryption_buf.put(&mut encrypted);
507
508        // Decrypt in-place, get plaintext length back.
509        let plaintext_len = self.cipher.recv_in_place(decryption_buf.as_mut())?;
510
511        // Truncate to remove tag bytes, keeping only plaintext.
512        decryption_buf.truncate(plaintext_len);
513
514        Ok(decryption_buf.freeze().into())
515    }
516}
517
518#[cfg(test)]
519mod test {
520    use super::*;
521    use commonware_codec::varint::UInt;
522    use commonware_cryptography::{ed25519::PrivateKey, Signer};
523    use commonware_runtime::{
524        deterministic, mocks, BufferPoolConfig, Error as RuntimeError, IoBuf, IoBufs, Runner as _,
525        Spawner as _,
526    };
527    use commonware_utils::{sync::Mutex, NZUsize};
528    use std::{
529        sync::{
530            atomic::{AtomicUsize, Ordering},
531            Arc,
532        },
533        time::Duration,
534    };
535
536    const NAMESPACE: &[u8] = b"fuzz_transport";
537    const MAX_MESSAGE_SIZE: u32 = 64 * 1024; // 64KB buffer
538
539    fn transport_config(signing_key: PrivateKey) -> Config<PrivateKey> {
540        Config {
541            signing_key,
542            namespace: NAMESPACE.to_vec(),
543            max_message_size: MAX_MESSAGE_SIZE,
544            synchrony_bound: Duration::from_secs(1),
545            max_handshake_age: Duration::from_secs(1),
546            handshake_timeout: Duration::from_secs(1),
547        }
548    }
549
550    fn oversized_handshake_prefix(message: &impl commonware_codec::Encode) -> IoBuf {
551        let size = u32::try_from(message.encode().len()).expect("message length should fit in u32");
552        IoBuf::from(UInt(size + 1).encode())
553    }
554
555    struct CountingSink<S> {
556        inner: S,
557        sends: Arc<AtomicUsize>,
558        chunk_counts: Arc<Mutex<Vec<usize>>>,
559    }
560
561    impl<S> CountingSink<S> {
562        fn new(inner: S, sends: Arc<AtomicUsize>, chunk_counts: Arc<Mutex<Vec<usize>>>) -> Self {
563            Self {
564                inner,
565                sends,
566                chunk_counts,
567            }
568        }
569    }
570
571    impl<S: commonware_runtime::Sink> commonware_runtime::Sink for CountingSink<S> {
572        async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), RuntimeError> {
573            let bufs = bufs.into();
574            self.sends.fetch_add(1, Ordering::Relaxed);
575            self.chunk_counts.lock().push(bufs.chunk_count());
576            self.inner.send(bufs).await
577        }
578    }
579
580    #[test]
581    fn test_can_setup_and_send_messages() -> Result<(), Error> {
582        let executor = deterministic::Runner::default();
583        executor.start(|context| async move {
584            let dialer_crypto = PrivateKey::from_seed(42);
585            let listener_crypto = PrivateKey::from_seed(24);
586
587            let (dialer_sink, listener_stream) = mocks::Channel::init();
588            let (listener_sink, dialer_stream) = mocks::Channel::init();
589
590            let dialer_config = transport_config(dialer_crypto.clone());
591            let listener_config = transport_config(listener_crypto.clone());
592
593            let listener_handle = context.clone().spawn(move |context| async move {
594                listen(
595                    context,
596                    |_| async { true },
597                    listener_config,
598                    listener_stream,
599                    listener_sink,
600                )
601                .await
602            });
603
604            let (mut dialer_sender, mut dialer_receiver) = dial(
605                context,
606                dialer_config,
607                listener_crypto.public_key(),
608                dialer_stream,
609                dialer_sink,
610            )
611            .await?;
612
613            let (listener_peer, mut listener_sender, mut listener_receiver) =
614                listener_handle.await.unwrap()?;
615            assert_eq!(listener_peer, dialer_crypto.public_key());
616            let messages: Vec<&'static [u8]> = vec![b"A", b"B", b"C"];
617            for msg in &messages {
618                dialer_sender.send(&msg[..]).await?;
619                let syn_ack = listener_receiver.recv().await?;
620                assert_eq!(syn_ack.coalesce(), *msg);
621                listener_sender.send(&msg[..]).await?;
622                let ack = dialer_receiver.recv().await?;
623                assert_eq!(ack.coalesce(), *msg);
624            }
625            Ok(())
626        })
627    }
628
629    #[test]
630    fn test_send_many_uses_single_runtime_send() -> Result<(), Error> {
631        let executor = deterministic::Runner::default();
632        executor.start(|context| async move {
633            let dialer_crypto = PrivateKey::from_seed(42);
634            let listener_crypto = PrivateKey::from_seed(24);
635
636            let (dialer_sink, listener_stream) = mocks::Channel::init();
637            let (listener_sink, dialer_stream) = mocks::Channel::init();
638            let sends = Arc::new(AtomicUsize::new(0));
639            let chunk_counts = Arc::new(Mutex::new(Vec::new()));
640
641            let dialer_config = transport_config(dialer_crypto.clone());
642            let listener_config = transport_config(listener_crypto.clone());
643
644            let listener_handle = context.clone().spawn(move |context| async move {
645                listen(
646                    context,
647                    |_| async { true },
648                    listener_config,
649                    listener_stream,
650                    listener_sink,
651                )
652                .await
653            });
654
655            let (mut dialer_sender, _dialer_receiver) = dial(
656                context,
657                dialer_config,
658                listener_crypto.public_key(),
659                dialer_stream,
660                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
661            )
662            .await?;
663
664            let (_listener_peer, _listener_sender, mut listener_receiver) =
665                listener_handle.await.unwrap()?;
666            sends.store(0, Ordering::Relaxed);
667            chunk_counts.lock().clear();
668
669            // Three small messages should fit in one pooled chunk, so `send_many`
670            // still reaches the runtime as a single single-chunk send call.
671            dialer_sender
672                .send_many(vec![
673                    IoBufs::from(IoBuf::from(b"alpha")),
674                    IoBufs::from(IoBuf::from(b"beta")),
675                    IoBufs::from(IoBuf::from(b"gamma")),
676                ])
677                .await?;
678
679            assert_eq!(sends.load(Ordering::Relaxed), 1);
680            assert_eq!(*chunk_counts.lock(), vec![1]);
681            assert_eq!(
682                listener_receiver.recv().await?.coalesce(),
683                IoBuf::from(b"alpha")
684            );
685            assert_eq!(
686                listener_receiver.recv().await?.coalesce(),
687                IoBuf::from(b"beta")
688            );
689            assert_eq!(
690                listener_receiver.recv().await?.coalesce(),
691                IoBuf::from(b"gamma")
692            );
693            Ok(())
694        })
695    }
696
697    #[test]
698    fn test_send_many_flushes_at_network_pool_item_max() -> Result<(), Error> {
699        let executor = deterministic::Runner::new(
700            deterministic::Config::new().with_network_buffer_pool_config(
701                BufferPoolConfig::for_network()
702                    .with_pool_min_size(256)
703                    .with_min_size(NZUsize!(256))
704                    .with_max_size(NZUsize!(256)),
705            ),
706        );
707        executor.start(|context| async move {
708            let dialer_crypto = PrivateKey::from_seed(42);
709            let listener_crypto = PrivateKey::from_seed(24);
710
711            let (dialer_sink, listener_stream) = mocks::Channel::init();
712            let (listener_sink, dialer_stream) = mocks::Channel::init();
713            let sends = Arc::new(AtomicUsize::new(0));
714            let chunk_counts = Arc::new(Mutex::new(Vec::new()));
715
716            let dialer_config = transport_config(dialer_crypto.clone());
717            let listener_config = transport_config(listener_crypto.clone());
718
719            let listener_handle = context.clone().spawn(move |context| async move {
720                listen(
721                    context,
722                    |_| async { true },
723                    listener_config,
724                    listener_stream,
725                    listener_sink,
726                )
727                .await
728            });
729
730            let (mut dialer_sender, _dialer_receiver) = dial(
731                context,
732                dialer_config,
733                listener_crypto.public_key(),
734                dialer_stream,
735                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
736            )
737            .await?;
738
739            let (_listener_peer, _listener_sender, mut listener_receiver) =
740                listener_handle.await.unwrap()?;
741            sends.store(0, Ordering::Relaxed);
742            chunk_counts.lock().clear();
743
744            // The first two framed messages fit together under the 256-byte cap,
745            // but the third must spill into a second chunk. We still hand the
746            // runtime one chunked `IoBufs`, so there is only one sink call.
747            let payload = vec![7u8; 100];
748            dialer_sender
749                .send_many(vec![
750                    IoBufs::from(IoBuf::from(payload.clone())),
751                    IoBufs::from(IoBuf::from(payload.clone())),
752                    IoBufs::from(IoBuf::from(payload.clone())),
753                ])
754                .await?;
755
756            assert_eq!(sends.load(Ordering::Relaxed), 1);
757            assert_eq!(*chunk_counts.lock(), vec![2]);
758            for _ in 0..3 {
759                assert_eq!(
760                    listener_receiver.recv().await?.coalesce(),
761                    payload.as_slice()
762                );
763            }
764            Ok(())
765        })
766    }
767
768    #[test]
769    fn test_send_many_sends_oversized_single_message_alone() -> Result<(), Error> {
770        let executor = deterministic::Runner::new(
771            deterministic::Config::new().with_network_buffer_pool_config(
772                BufferPoolConfig::for_network()
773                    .with_pool_min_size(128)
774                    .with_min_size(NZUsize!(128))
775                    .with_max_size(NZUsize!(128)),
776            ),
777        );
778        executor.start(|context| async move {
779            let dialer_crypto = PrivateKey::from_seed(42);
780            let listener_crypto = PrivateKey::from_seed(24);
781
782            let (dialer_sink, listener_stream) = mocks::Channel::init();
783            let (listener_sink, dialer_stream) = mocks::Channel::init();
784            let sends = Arc::new(AtomicUsize::new(0));
785            let chunk_counts = Arc::new(Mutex::new(Vec::new()));
786
787            let dialer_config = transport_config(dialer_crypto.clone());
788            let listener_config = transport_config(listener_crypto.clone());
789
790            let listener_handle = context.clone().spawn(move |context| async move {
791                listen(
792                    context,
793                    |_| async { true },
794                    listener_config,
795                    listener_stream,
796                    listener_sink,
797                )
798                .await
799            });
800
801            let (mut dialer_sender, _dialer_receiver) = dial(
802                context,
803                dialer_config,
804                listener_crypto.public_key(),
805                dialer_stream,
806                CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
807            )
808            .await?;
809
810            let (_listener_peer, _listener_sender, mut listener_receiver) =
811                listener_handle.await.unwrap()?;
812            sends.store(0, Ordering::Relaxed);
813            chunk_counts.lock().clear();
814
815            // A single framed message larger than the cap still goes out, but it
816            // must occupy its own chunk instead of being rejected or merged.
817            let large = vec![3u8; 200];
818            let small = vec![9u8; 16];
819            dialer_sender
820                .send_many(vec![
821                    IoBufs::from(IoBuf::from(large.clone())),
822                    IoBufs::from(IoBuf::from(small.clone())),
823                ])
824                .await?;
825
826            assert_eq!(sends.load(Ordering::Relaxed), 1);
827            assert_eq!(*chunk_counts.lock(), vec![2]);
828            assert_eq!(listener_receiver.recv().await?.coalesce(), large.as_slice());
829            assert_eq!(listener_receiver.recv().await?.coalesce(), small.as_slice());
830            Ok(())
831        })
832    }
833
/// Verifies that `send_many` validates an entire batch before committing any of it.
///
/// A batch that contains one oversized message must:
/// - fail with `Error::SendTooLarge`,
/// - write nothing to the sink (not even the valid message that precedes it), and
/// - leave the sender usable, so a later `send` is still accepted and decrypted
///   by the listener.
#[test]
fn test_send_many_too_large_preserves_sender_state() -> Result<(), Error> {
    let executor = deterministic::Runner::default();
    executor.start(|context| async move {
        let dialer_crypto = PrivateKey::from_seed(42);
        let listener_crypto = PrivateKey::from_seed(24);

        let (dialer_sink, listener_stream) = mocks::Channel::init();
        let (listener_sink, dialer_stream) = mocks::Channel::init();
        // Shared with the `CountingSink` wrapper around the dialer's sink so the
        // test can observe how many sink writes happened and their chunk counts.
        let sends = Arc::new(AtomicUsize::new(0));
        let chunk_counts = Arc::new(Mutex::new(Vec::new()));

        let dialer_config = transport_config(dialer_crypto.clone());
        let listener_config = transport_config(listener_crypto.clone());

        let listener_handle = context.clone().spawn(move |context| async move {
            listen(
                context,
                |_| async { true },
                listener_config,
                listener_stream,
                listener_sink,
            )
            .await
        });

        let (mut dialer_sender, _dialer_receiver) = dial(
            context,
            dialer_config,
            listener_crypto.public_key(),
            dialer_stream,
            CountingSink::new(dialer_sink, sends.clone(), chunk_counts.clone()),
        )
        .await?;

        let (_listener_peer, _listener_sender, mut listener_receiver) =
            listener_handle.await.unwrap()?;
        // Forget the handshake traffic so the assertions below only see data sends.
        sends.store(0, Ordering::Relaxed);
        chunk_counts.lock().clear();

        let valid = vec![7u8; 32];
        let oversized = vec![9u8; MAX_MESSAGE_SIZE as usize + 1];
        let result = dialer_sender
            .send_many(vec![
                IoBufs::from(IoBuf::from(valid)),
                IoBufs::from(IoBuf::from(oversized)),
            ])
            .await;
        // Match explicitly so a failure reports what actually happened instead of
        // a bare `matches!` assertion message.
        match result {
            Err(Error::SendTooLarge(_)) => {}
            Err(err) => panic!("expected SendTooLarge, got {err:?}"),
            Ok(_) => panic!("send_many accepted a batch containing an oversized message"),
        }

        // Nothing from the rejected batch, including the valid message, may reach the sink.
        assert_eq!(sends.load(Ordering::Relaxed), 0);
        assert!(chunk_counts.lock().is_empty());

        // Nonces come from per-direction counters and are never transmitted, so
        // if the rejected batch had advanced the sender's counter, the listener
        // could not decrypt this message. Receiving "recovered" as the very next
        // message also proves the valid half of the batch was never delivered.
        let recovered = b"recovered";
        dialer_sender.send(&recovered[..]).await?;
        assert_eq!(sends.load(Ordering::Relaxed), 1);
        assert_eq!(listener_receiver.recv().await?.coalesce(), recovered);
        Ok(())
    })
}
896
/// Verifies that the listener bounds the unauthenticated first frame (the
/// dialer's encoded public key) by the fixed public-key size, not by the
/// configured application `max_message_size`.
#[test]
fn test_listen_rejects_oversized_fixed_size_peer_key_frame() {
    let executor = deterministic::Runner::default();
    executor.start(|context| async move {
        let dialer_crypto = PrivateKey::from_seed(42);
        let listener_crypto = PrivateKey::from_seed(24);
        let peer = dialer_crypto.public_key();
        // The advertised frame is exactly one byte longer than an encoded key.
        let expected_len = peer.encode().len() + 1;

        let (mut dialer_sink, listener_stream) = mocks::Channel::init();
        // Bound to a name (not `_`) so the dialer-side stream stays alive until
        // the test ends; a `_` pattern would drop it immediately.
        let (listener_sink, _dialer_stream) = mocks::Channel::init();

        // Even with a large application limit, the listener must bound the
        // unauthenticated peer-key frame to the fixed public-key size.
        let mut listener_config = transport_config(listener_crypto);
        listener_config.max_message_size = 1024 * 1024;

        // Advertise a frame one byte larger than the encoded public key and send
        // no payload. A check against `max_message_size` alone would accept this
        // prefix and then wait for body bytes that never arrive.
        dialer_sink
            .send(oversized_handshake_prefix(&peer))
            .await
            .expect("mock channel send should succeed");

        let result = listen(
            context,
            |_| async { true },
            listener_config,
            listener_stream,
            listener_sink,
        )
        .await;

        // The listener should reject immediately on the fixed-size bound instead
        // of waiting for more bytes or allocating for the larger application limit.
        // Match explicitly so a failure shows the observed length or error.
        match result {
            Err(Error::RecvTooLarge(n)) => assert_eq!(n, expected_len),
            Err(err) => panic!("expected RecvTooLarge({expected_len}), got {err:?}"),
            Ok(_) => panic!("listener accepted an oversized peer-key frame"),
        }
    });
}
936
/// Verifies that the dialer bounds the listener's SynAck frame by the SynAck's
/// fixed encoded size, not by the configured application `max_message_size`.
#[test]
fn test_dial_rejects_oversized_fixed_size_syn_ack_frame() {
    let executor = deterministic::Runner::default();
    executor.start(|context| async move {
        let dialer_crypto = PrivateKey::from_seed(42);
        let listener_crypto = PrivateKey::from_seed(24);

        // `_listener_stream` is bound to a name (not `_`) so it stays alive until
        // the test ends; presumably this keeps the dialer's own handshake send
        // from failing on a closed channel before it reads the SynAck.
        let (dialer_sink, _listener_stream) = mocks::Channel::init();
        let (mut listener_sink, dialer_stream) = mocks::Channel::init();

        // Use a large application limit to make sure this path is guarded by the
        // fixed SynAck size rather than by post-handshake settings.
        let mut dialer_config = transport_config(dialer_crypto);
        dialer_config.max_message_size = 1024 * 1024;

        // Build a valid SynAck only to learn its true encoded size; the SynAck
        // itself is never delivered to the dialer.
        let (current_time, ok_timestamps) = dialer_config.time_information(&context);
        let mut listener_rng = context.clone();
        let (_, syn) = dial_start(
            context.clone(),
            Context::new(
                &Transcript::new(&dialer_config.namespace),
                current_time,
                ok_timestamps.clone(),
                dialer_config.signing_key.clone(),
                listener_crypto.public_key(),
            ),
        );
        let (_, syn_ack) = listen_start(
            &mut listener_rng,
            Context::new(
                &Transcript::new(&dialer_config.namespace),
                current_time,
                ok_timestamps,
                listener_crypto.clone(),
                dialer_config.signing_key.public_key(),
            ),
            syn,
        )
        .expect("mock handshake should produce a valid syn_ack");
        // The advertised frame is exactly one byte longer than the real SynAck.
        let expected_len = syn_ack.encode().len() + 1;

        // Send only a length prefix claiming that oversized frame, with no payload.
        listener_sink
            .send(oversized_handshake_prefix(&syn_ack))
            .await
            .expect("mock channel send should succeed");

        let result = dial(
            context,
            dialer_config,
            listener_crypto.public_key(),
            dialer_stream,
            dialer_sink,
        )
        .await;

        // The dialer should reject on the fixed handshake bound before any larger
        // application-sized receive path is considered. Match explicitly so a
        // failure shows the observed length or error.
        match result {
            Err(Error::RecvTooLarge(n)) => assert_eq!(n, expected_len),
            Err(err) => panic!("expected RecvTooLarge({expected_len}), got {err:?}"),
            Ok(_) => panic!("dialer accepted an oversized SynAck frame"),
        }
    });
}
1004}