// nexus_net/tls/codec.rs

1use std::io::{self, Read, Write};
2
3use rustls::ClientConnection;
4use rustls::pki_types::ServerName;
5
6use super::{TlsConfig, TlsError};
7
8/// Sans-IO TLS codec. Decrypts inbound bytes, encrypts outbound bytes.
9///
10/// Wraps a rustls `ClientConnection` with an API shaped for nexus-net.
11/// The codec is a pure state machine: callers drive IO and buffering;
12/// the codec only transforms bytes.
13///
14/// # API at a glance
15///
16/// - **Inbound:** [`read_tls`](Self::read_tls) feeds buffered ciphertext
17///   one packet step at a time; [`read_tls_from`](Self::read_tls_from)
18///   drives a sync [`Read`] source directly;
19///   [`read_and_process_tls`](Self::read_and_process_tls) loops over
20///   bounded input.
21/// - **Drain plaintext:** [`read_plaintext`](Self::read_plaintext) into
22///   a slice; [`drain_plaintext_into`](Self::drain_plaintext_into) feeds
23///   any [`ParserSink`](crate::ParserSink) (e.g. `FrameReader`) with
24///   one fewer copy.
25/// - **Outbound:** [`encrypt`](Self::encrypt) returns bytes accepted
26///   (chunked); [`write_tls_to`](Self::write_tls_to) drains ciphertext
27///   to a writer.
28/// - **Shutdown:** [`send_close_notify`](Self::send_close_notify)
29///   queues the alert; flush via `write_tls_to` before transport close.
30pub struct TlsCodec {
31    inner: ClientConnection,
32}
33
34impl TlsCodec {
35    /// Create a new TLS codec for the given hostname.
36    ///
37    /// The hostname is used for SNI (Server Name Indication) and
38    /// certificate verification.
39    pub fn new(config: &TlsConfig, hostname: &str) -> Result<Self, TlsError> {
40        let server_name = ServerName::try_from(hostname.to_owned())
41            .map_err(|_| TlsError::InvalidHostname(hostname.to_owned()))?;
42
43        let conn = ClientConnection::new(config.inner.clone(), server_name)?;
44
45        Ok(Self { inner: conn })
46    }
47
48    // =========================================================================
49    // Inbound (socket → TLS → FrameReader)
50    // =========================================================================
51
52    /// Advance the codec by a single TLS packet step: one read + one
53    /// `process_new_packets` pair.
54    ///
55    /// Returns the number of ciphertext bytes consumed from `src`. The
56    /// caller drains any plaintext between calls (via
57    /// [`read_plaintext`](Self::read_plaintext) or
58    /// [`drain_plaintext_into`](Self::drain_plaintext_into)) — feeding
59    /// more ciphertext while plaintext is queued can overflow rustls's
60    /// internal plaintext buffer. This is the canonical primitive for
61    /// streaming app-data adapters (poll socket → step codec → drain
62    /// plaintext → repeat).
63    ///
64    /// For bounded input that fits in rustls's plaintext queue
65    /// (handshake bytes, in-memory tests), use the drain-loop helper
66    /// [`read_and_process_tls`](Self::read_and_process_tls).
67    ///
68    /// # Returns
69    ///
70    /// `Ok(0)` if `src` is empty, or if rustls's deframer cannot
71    /// progress on the input alone (matches `Read::read` idiom — the
72    /// caller's loop is responsible for detecting stuck state).
73    /// Otherwise `Ok(n)` where `n > 0` is bytes consumed (always
74    /// `<= src.len()`; rustls's deframer caps each call at its
75    /// internal `READ_SIZE`).
76    ///
77    /// # Errors
78    ///
79    /// Any rustls error from the read or process step (alerts,
80    /// decryption failures, plaintext-buffer overflow, protocol
81    /// violations).
82    #[inline]
83    pub fn read_tls(&mut self, src: &[u8]) -> Result<usize, TlsError> {
84        if src.is_empty() {
85            return Ok(0);
86        }
87        let mut cursor = io::Cursor::new(src);
88        let consumed = self.inner.read_tls(&mut cursor)?;
89        if consumed > 0 {
90            self.inner.process_new_packets()?;
91        }
92        Ok(consumed)
93    }
94
95    /// Feed buffered TLS bytes through rustls in a loop until the
96    /// entire slice is consumed.
97    ///
98    /// **Use only for bounded input** that fits in rustls's plaintext
99    /// queue — in-memory tests, custom adapters that pre-buffer a
100    /// known-bounded byte sequence. Do **not** use for streaming app
101    /// data: large ciphertext slices fed without intervening plaintext
102    /// drains overflow rustls's internal plaintext buffer
103    /// (`received plaintext buffer full`). For streaming adapters,
104    /// drive [`read_tls`](Self::read_tls) step-by-step yourself.
105    ///
106    /// **No production callers in this crate** — kept as a
107    /// user-facing safety helper. The async `TlsInner::connect`
108    /// (nexus-async-web) and sync `TlsStream::connect` drive their
109    /// own loops over [`read_tls`](Self::read_tls) /
110    /// [`read_tls_from`](Self::read_tls_from). External adapter
111    /// authors who pre-buffer ciphertext can reach for this helper
112    /// to avoid reimplementing the consume-loop.
113    ///
114    /// # Why this exists
115    ///
116    /// `rustls::Connection::read_tls` is not guaranteed to consume the
117    /// full provided slice on a single call. The naive pattern
118    /// `codec.read_tls(&buf)?` silently drops the unconsumed tail
119    /// (issue #200 — a TLS handshake against a server that splits its
120    /// response into multiple records inside one TCP segment fails
121    /// because the unconsumed bytes vanish). This helper encodes the
122    /// correct loop so naive callers don't reintroduce the bug.
123    ///
124    /// # Returns
125    ///
126    /// `Ok(src.len())` when the entire slice has been consumed.
127    ///
128    /// # Errors
129    ///
130    /// - `TlsError::Io(InvalidData)` if rustls's deframer can't make
131    ///   progress (returned 0 bytes consumed) — malformed input.
132    /// - Any rustls error from the underlying read/process steps.
133    pub fn read_and_process_tls(&mut self, src: &[u8]) -> Result<usize, TlsError> {
134        let mut consumed = 0;
135        while consumed < src.len() {
136            let n = self.read_tls(&src[consumed..])?;
137            if n == 0 {
138                return Err(TlsError::Io(io::Error::new(
139                    io::ErrorKind::InvalidData,
140                    "TLS codec stopped before consuming buffered input \
141                     (rustls deframer cannot make progress)",
142                )));
143            }
144            consumed += n;
145        }
146        Ok(consumed)
147    }
148
149    /// Drive a sync [`Read`] source: read up to rustls's internal
150    /// `READ_SIZE` from `src`, then process the records.
151    ///
152    /// Equivalent to one `read_tls` step but pulls bytes from a
153    /// `Read` source instead of a buffer. Returns the bytes read from
154    /// `src`, or 0 on EOF / no bytes available. The caller's loop
155    /// handles the rest.
156    pub fn read_tls_from<R: Read>(&mut self, src: &mut R) -> Result<usize, TlsError> {
157        let n = self.inner.read_tls(src)?;
158        if n > 0 {
159            self.inner.process_new_packets()?;
160        }
161        Ok(n)
162    }
163
164    /// Drain decrypted plaintext into a [`ParserSink`](crate::ParserSink).
165    ///
166    /// Direct-feed path: uses `BufRead::fill_buf` to borrow rustls's
167    /// internal plaintext queue and copy directly into `sink.spare()`,
168    /// skipping the intermediate `&mut [u8]` that the
169    /// [`read_plaintext`](Self::read_plaintext) shape requires. Returns
170    /// the number of plaintext bytes delivered.
171    ///
172    /// Implements the zero-copy seam between rustls and parsers
173    /// (`FrameReader` for WebSocket framing, `ResponseReader` for
174    /// HTTP). Used by adapters' `WireStream::poll_fill_into` to fold
175    /// plaintext draining into the same call that drives ciphertext
176    /// reads.
177    pub fn drain_plaintext_into<P: crate::ParserSink>(
178        &mut self,
179        sink: &mut P,
180    ) -> Result<usize, TlsError> {
181        let mut rd = self.inner.reader();
182        let chunk = match std::io::BufRead::fill_buf(&mut rd) {
183            Ok(chunk) => chunk,
184            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(0),
185            Err(e) => return Err(TlsError::Io(e)),
186        };
187        if chunk.is_empty() {
188            return Ok(0);
189        }
190        let spare = sink.spare();
191        let n = chunk.len().min(spare.len());
192        if n == 0 {
193            // Sink has no room; caller must drain the parser before
194            // we can deliver more plaintext.
195            return Ok(0);
196        }
197        spare[..n].copy_from_slice(&chunk[..n]);
198        sink.filled(n);
199        std::io::BufRead::consume(&mut rd, n);
200        Ok(n)
201    }
202
203    /// Read decrypted plaintext into a buffer (sans-IO path).
204    ///
205    /// For users who want to feed bytes into FrameReader manually
206    /// or use a different parser.
207    #[inline]
208    pub fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize, TlsError> {
209        match self.inner.reader().read(dst) {
210            Ok(n) => Ok(n),
211            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
212            Err(e) => Err(TlsError::Io(e)),
213        }
214    }
215
216    // =========================================================================
217    // Outbound (FrameWriter → TLS → socket)
218    // =========================================================================
219
220    /// Encrypt up to `plaintext.len()` bytes, returning the number of
221    /// bytes actually accepted by rustls's outbound plaintext queue.
222    ///
223    /// Chunked semantics — the caller's `write_all` (or equivalent)
224    /// handles re-driving on partial acceptance. This is the
225    /// `AsyncWrite::poll_write` contract: surface backpressure as a
226    /// partial count, not a hard error.
227    ///
228    /// # Returns
229    ///
230    /// `Ok(0)` if rustls's queue is full and cannot accept any bytes
231    /// (caller should drain ciphertext to the socket and retry).
232    /// Otherwise `Ok(n)` where `n > 0` is plaintext bytes queued for
233    /// encryption. `n` may be less than `plaintext.len()`.
234    ///
235    /// # Errors
236    ///
237    /// Any rustls writer error other than `WriteZero` (which is
238    /// translated to `Ok(0)` so callers treat queue-full as
239    /// backpressure rather than a hard failure).
240    #[inline]
241    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<usize, TlsError> {
242        match self.inner.writer().write(plaintext) {
243            Ok(n) => Ok(n),
244            Err(e) if e.kind() == io::ErrorKind::WriteZero => Ok(0),
245            Err(e) => Err(TlsError::Io(e)),
246        }
247    }
248
249    /// Set rustls's outbound plaintext queue limit. `None` for
250    /// unlimited (rustls accepts as much plaintext as memory allows;
251    /// pair with a caller-side bound).
252    ///
253    /// Default is rustls's `DEFAULT_BUFFER_LIMIT = 64 KiB`. Trading
254    /// workloads with small messages typically don't need to change
255    /// this. Bulk-transfer workloads (large snapshots, file uploads
256    /// over TLS) may benefit from raising it to reduce drain/refill
257    /// cycles in [`encrypt`](Self::encrypt).
258    pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
259        self.inner.set_buffer_limit(limit);
260    }
261
262    /// Queue a TLS `close_notify` alert.
263    ///
264    /// Subsequent calls to [`wants_write`](Self::wants_write) will
265    /// return true until the alert ciphertext has been written via
266    /// [`write_tls_to`](Self::write_tls_to).
267    ///
268    /// Idempotent: rustls tracks whether close_notify has been sent
269    /// and no-ops on duplicate calls.
270    ///
271    /// Use in `AsyncWrite::poll_shutdown` (or equivalent) before
272    /// closing the underlying transport. Without close_notify, the
273    /// peer sees TCP FIN as a potential truncation and may error its
274    /// read loop mid-stream.
275    #[inline]
276    pub fn send_close_notify(&mut self) {
277        self.inner.send_close_notify();
278    }
279
280    /// Flush encrypted bytes to a socket.
281    ///
282    /// Returns the number of bytes written. Call in a loop or when
283    /// [`wants_write`](Self::wants_write) returns true.
284    pub fn write_tls_to<W: Write>(&mut self, dst: &mut W) -> io::Result<usize> {
285        self.inner.write_tls(dst)
286    }
287
288    // =========================================================================
289    // State
290    // =========================================================================
291
292    /// Whether the TLS handshake is still in progress.
293    #[inline]
294    pub fn is_handshaking(&self) -> bool {
295        self.inner.is_handshaking()
296    }
297
298    /// Whether the codec has buffered TLS data to read.
299    #[inline]
300    pub fn wants_read(&self) -> bool {
301        self.inner.wants_read()
302    }
303
304    /// Whether the codec has encrypted data to write.
305    #[inline]
306    pub fn wants_write(&self) -> bool {
307        self.inner.wants_write()
308    }
309}
310
311impl std::fmt::Debug for TlsCodec {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.debug_struct("TlsCodec")
314            .field("handshaking", &self.inner.is_handshaking())
315            .finish()
316    }
317}
318
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::sync::Arc;

    use super::*;

    // -------------------------------------------------------------------------
    // In-memory handshake scaffolding (lifted from examples/perf_tls.rs).
    // -------------------------------------------------------------------------

    /// Generate a self-signed `localhost` certificate. Returns
    /// `(cert_chain, key_der)` for `rustls::ServerConfig`.
    fn generate_self_signed() -> (Vec<rustls::pki_types::CertificateDer<'static>>, Vec<u8>) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
            .expect("cert generation");
        (
            vec![rustls::pki_types::CertificateDer::from(
                cert.cert.der().to_vec(),
            )],
            cert.key_pair.serialize_der(),
        )
    }

    /// Generate an N-cert ECDSA-P256 chain whose serialized DER pushes
    /// the TLS 1.3 server's first handshake burst past rustls's
    /// `READ_SIZE = 4096` per-call deframer cap. ECDSA keygen is
    /// microseconds (vs RSA-4096's ~1.5s per key) so this stays cheap
    /// even at chain depth 10.
    ///
    /// Why a deep chain instead of one big RSA cert: chain depth scales
    /// the Certificate message linearly without paying for slow RSA
    /// keygen. 10 P-256 certs ≈ 5KB of cert bytes, comfortably over
    /// 4096. Each link is signed by its parent — a real CA-style chain.
    ///
    /// Returns `(chain_in_send_order, leaf_key_der)`. The chain is
    /// `[leaf, intermediate_n, ..., intermediate_1, root]` — the order
    /// rustls sends in the Certificate message.
    fn generate_oversize_ecdsa_chain() -> (Vec<rustls::pki_types::CertificateDer<'static>>, Vec<u8>)
    {
        use rcgen::{BasicConstraints, CertificateParams, IsCa, KeyPair};

        const CHAIN_DEPTH: usize = 10;

        // Generate the root + intermediates + leaf. Each non-leaf is a
        // CA-flagged cert that signs the next link.
        let mut keys: Vec<KeyPair> = Vec::with_capacity(CHAIN_DEPTH);
        let mut certs: Vec<rcgen::Certificate> = Vec::with_capacity(CHAIN_DEPTH);

        // Root.
        let root_key = KeyPair::generate().expect("root key");
        let mut root_params = CertificateParams::new(Vec::<String>::new()).expect("root params");
        root_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        let root_cert = root_params.self_signed(&root_key).expect("root self-sign");
        keys.push(root_key);
        certs.push(root_cert);

        // Intermediates (CHAIN_DEPTH - 2 of them, all CA-flagged).
        for _ in 0..(CHAIN_DEPTH - 2) {
            let key = KeyPair::generate().expect("int key");
            let mut params = CertificateParams::new(Vec::<String>::new()).expect("int params");
            params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
            let parent_cert = certs.last().expect("parent");
            let parent_key = keys.last().expect("parent key");
            let cert = params
                .signed_by(&key, parent_cert, parent_key)
                .expect("int signed");
            keys.push(key);
            certs.push(cert);
        }

        // Leaf (signed by the deepest intermediate, SAN=localhost).
        let leaf_key = KeyPair::generate().expect("leaf key");
        let leaf_params =
            CertificateParams::new(vec!["localhost".to_string()]).expect("leaf params");
        let parent_cert = certs.last().expect("parent");
        let parent_key = keys.last().expect("parent key");
        let leaf_cert = leaf_params
            .signed_by(&leaf_key, parent_cert, parent_key)
            .expect("leaf signed");

        // Server sends [leaf, intermediates_descending, root] in the
        // Certificate message. We built `certs` as [root, int_1, ...,
        // int_n], so reverse + prepend leaf.
        let mut chain: Vec<rustls::pki_types::CertificateDer<'static>> =
            Vec::with_capacity(CHAIN_DEPTH);
        chain.push(rustls::pki_types::CertificateDer::from(
            leaf_cert.der().to_vec(),
        ));
        for cert in certs.iter().rev() {
            chain.push(rustls::pki_types::CertificateDer::from(cert.der().to_vec()));
        }

        (chain, leaf_key.serialize_der())
    }

    /// In-memory pipe for handshake bytes.
    struct MemPipe {
        buf: Vec<u8>,
    }

    impl MemPipe {
        fn new() -> Self {
            Self { buf: Vec::new() }
        }

        fn write_to(&mut self, data: &[u8]) {
            self.buf.extend_from_slice(data);
        }

        fn read_from(&mut self, dst: &mut [u8]) -> usize {
            let n = dst.len().min(self.buf.len());
            dst[..n].copy_from_slice(&self.buf[..n]);
            self.buf.drain(..n);
            n
        }

        fn len(&self) -> usize {
            self.buf.len()
        }
    }

    /// Feed `bytes` to the rustls server until every byte is consumed,
    /// processing packets after each read. `ServerConnection::read_tls`
    /// may consume only part of its input per call, so a single call
    /// would silently drop the tail. That is the issue #200 pattern,
    /// and the scaffolding must not reintroduce it.
    fn server_read_all(server: &mut rustls::ServerConnection, mut bytes: &[u8]) {
        while !bytes.is_empty() {
            // `&[u8]: Read` advances `bytes` past whatever rustls consumed.
            let n = server.read_tls(&mut bytes).expect("server read_tls");
            assert!(n > 0, "server read_tls made no progress");
            server
                .process_new_packets()
                .expect("server process_new_packets");
        }
    }

    /// Build the server side and capture its first multi-record handshake
    /// burst (ServerHello + EncryptedExtensions + Certificate + CertVerify +
    /// Finished under TLS 1.3 — several records pushed back-to-back). The
    /// returned `server_out` is the slice we feed to the client `TlsCodec`
    /// to exercise the partial-consumption surface.
    fn setup_and_capture_server_burst(
        cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
        key_der: Vec<u8>,
    ) -> (TlsCodec, rustls::ServerConnection, Vec<u8>) {
        let key = rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap();
        let server_config = Arc::new(
            rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(cert_chain, key)
                .unwrap(),
        );
        let mut server = rustls::ServerConnection::new(server_config).unwrap();

        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();

        let mut c2s = MemPipe::new();
        let mut s2c = MemPipe::new();

        // Client writes ClientHello.
        // Loop `while wants_write()` (mirroring the server side below)
        // for defense-in-depth — if a future rustls or cert config splits
        // the ClientHello across multiple write batches, a single
        // write_tls_to call would leave bytes pending in the codec.
        while client.wants_write() {
            let mut cursor = Cursor::new(Vec::new());
            client.write_tls_to(&mut cursor).unwrap();
            c2s.write_to(cursor.get_ref());
        }

        // Server consumes ClientHello. Take the whole pipe (not a
        // fixed-size scratch buffer) and loop until rustls has consumed
        // every byte.
        let mut client_hello = vec![0u8; c2s.len()];
        let n = c2s.read_from(&mut client_hello);
        client_hello.truncate(n);
        server_read_all(&mut server, &client_hello);

        // Server writes its multi-record burst.
        while server.wants_write() {
            let mut cursor = Cursor::new(Vec::new());
            server.write_tls(&mut cursor).unwrap();
            s2c.write_to(cursor.get_ref());
        }

        let mut server_out = vec![0u8; s2c.len()];
        let n = s2c.read_from(&mut server_out);
        assert!(n > 0, "server should have produced handshake bytes");
        server_out.truncate(n);

        (client, server, server_out)
    }

    // -------------------------------------------------------------------------
    // Tests
    // -------------------------------------------------------------------------

    /// Regression test for issue #200.
    ///
    /// Pre-fix: `read_tls(&buf)` may consume only part of `buf`. Calling
    /// code in nexus-async-web + nexus-net's tls/stream.rs ignored the
    /// returned consumed count, dropping the unconsumed tail and stalling
    /// the TLS handshake. Post-fix: `read_and_process_tls` loops until the
    /// entire slice is consumed.
    #[test]
    fn read_and_process_tls_consumes_full_slice() {
        let (chain, key) = generate_self_signed();
        let (mut client, _server, server_out) = setup_and_capture_server_burst(chain, key);

        let consumed = client
            .read_and_process_tls(&server_out)
            .expect("helper must consume the full slice");

        assert_eq!(
            consumed,
            server_out.len(),
            "helper must consume every byte (issue #200)"
        );
        assert!(
            client.wants_write(),
            "client should have produced its handshake response"
        );
    }

    /// Stricter exercise: feed the captured server bytes one byte per
    /// `read_and_process_tls` call. Catches a class of bugs where the
    /// helper itself drops bytes between calls or skips the
    /// `process_new_packets` step in some iterations.
    #[test]
    fn read_and_process_tls_byte_at_a_time() {
        let (chain, key) = generate_self_signed();
        let (mut client, _server, server_out) = setup_and_capture_server_burst(chain, key);

        for byte in &server_out {
            client
                .read_and_process_tls(std::slice::from_ref(byte))
                .expect("byte-at-a-time must succeed");
        }

        assert!(
            client.wants_write(),
            "client should have produced its handshake response \
             after byte-at-a-time consumption"
        );
    }

    /// **The actual end-to-end regression test for issue #200.**
    ///
    /// The other tests in this module don't exercise the helper's
    /// multi-iteration loop. `read_and_process_tls_consumes_full_slice`
    /// uses a small burst that is consumed in one inner iteration.
    /// `read_and_process_tls_byte_at_a_time` invokes the helper many
    /// times with 1-byte slices, so each invocation runs a 1-iteration
    /// loop.
    ///
    /// This test uses a 10-cert ECDSA-P256 chain to push the server's
    /// first handshake burst past rustls's `READ_SIZE = 4096` per-call
    /// cap. Chain depth (not key size) provides the bytes — keeps
    /// keygen fast. The helper is fed the whole burst in ONE call; its
    /// internal loop must iterate multiple times to consume everything.
    /// This is exactly the shape birch hit against polymarket.
    #[test]
    fn read_and_process_tls_handles_oversize_burst() {
        let (chain, key) = generate_oversize_ecdsa_chain();
        let (mut client, _server, server_out) = setup_and_capture_server_burst(chain, key);

        // Confirm the test is actually exercising the partial-consumption
        // path. If this assertion fails, future contributors investigating
        // know the burst-size assumption broke (e.g., rustls raised
        // READ_SIZE, or the cert chain shrank). Bump the chain size or
        // the key size in `generate_oversize_ecdsa_chain` to restore.
        assert!(
            server_out.len() > 4096,
            "burst must exceed READ_SIZE to exercise multi-iteration loop, \
             got {} bytes — bump cert chain in generate_oversize_ecdsa_chain",
            server_out.len()
        );

        let consumed = client
            .read_and_process_tls(&server_out)
            .expect("helper must consume the full slice across multiple iterations");

        assert_eq!(
            consumed,
            server_out.len(),
            "helper must consume every byte across the multi-iteration loop \
             (issue #200 — the actual partial-consumption surface)"
        );
        assert!(
            client.wants_write(),
            "client should have produced its handshake response after \
             consuming the oversize burst"
        );
    }

    /// Drive an in-memory TLS 1.3 handshake to completion.
    /// Returns the connected client codec + server connection ready for
    /// app-data exchange. Used by `read_tls_step` tests that need a
    /// post-handshake codec.
    fn connected_pair() -> (TlsCodec, rustls::ServerConnection) {
        let (cert_chain, key_der) = generate_self_signed();
        let key = rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap();
        let server_config = Arc::new(
            rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(cert_chain, key)
                .unwrap(),
        );
        let mut server = rustls::ServerConnection::new(server_config).unwrap();

        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();

        let mut c2s = Vec::new();
        let mut s2c = Vec::new();

        for _ in 0..64 {
            while client.wants_write() {
                client.write_tls_to(&mut c2s).unwrap();
            }
            if !c2s.is_empty() {
                server_read_all(&mut server, &c2s);
                c2s.clear();
            }
            while server.wants_write() {
                server.write_tls(&mut s2c).unwrap();
            }
            if !s2c.is_empty() {
                client.read_and_process_tls(&s2c).unwrap();
                s2c.clear();
            }
            if !client.is_handshaking() && !server.is_handshaking() {
                return (client, server);
            }
        }
        panic!("TLS handshake did not complete");
    }

    /// Encrypt `payload` from the server side and capture the resulting
    /// ciphertext.
    fn encrypt_server_payload(server: &mut rustls::ServerConnection, payload: &[u8]) -> Vec<u8> {
        use std::io::Write as _;
        server.writer().write_all(payload).unwrap();
        let mut ciphertext = Vec::new();
        while server.wants_write() {
            server.write_tls(&mut ciphertext).unwrap();
        }
        ciphertext
    }

    /// Empty input is a cheap no-op, not an error.
    #[test]
    fn read_tls_empty_input_returns_zero() {
        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();

        let n = client.read_tls(&[]).expect("empty input must not error");
        assert_eq!(n, 0);
    }

    /// Happy path: feed a small ciphertext prefix, get a non-zero
    /// consumed count back, drain the resulting plaintext.
    #[test]
    fn read_tls_normal_step() {
        let (mut client, mut server) = connected_pair();
        let payload = b"hello, world";
        let ciphertext = encrypt_server_payload(&mut server, payload);
        assert!(!ciphertext.is_empty());

        let consumed = client
            .read_tls(&ciphertext)
            .expect("step must succeed on fresh ciphertext");
        assert!(consumed > 0, "must consume at least one byte");
        assert!(consumed <= ciphertext.len());

        let mut dst = vec![0u8; payload.len()];
        let n = client.read_plaintext(&mut dst).unwrap();
        assert_eq!(n, payload.len());
        assert_eq!(&dst[..n], payload);
    }

    /// Documentation pin: when a caller feeds ciphertext via repeated
    /// [`read_tls`] calls without ever draining plaintext between
    /// steps, rustls's internal plaintext buffer eventually overflows
    /// and surfaces as `received plaintext buffer full`. This is the
    /// constraint that motivates the streaming pattern (alternate
    /// step + drain) and the existence of [`read_and_process_tls`]
    /// for the bounded-input case where overflow can't happen.
    #[test]
    fn read_tls_rejects_when_caller_does_not_drain() {
        let (mut client, mut server) = connected_pair();
        let payload = vec![b'x'; 64 * 1024];
        let ciphertext = encrypt_server_payload(&mut server, &payload);

        // Drive read_tls in a loop without ever calling read_plaintext.
        // Plaintext queues until rustls's cap is hit. The codec's next
        // step then surfaces the overflow, either from rustls's read or
        // its process step; the test only pins the message.
        let mut consumed = 0;
        let error = loop {
            match client.read_tls(&ciphertext[consumed..]) {
                Ok(0) => panic!("unexpected stuck state; consumed={consumed}"),
                Ok(n) => consumed += n,
                Err(e) => break e,
            }
            assert!(
                consumed < ciphertext.len(),
                "expected error before consuming entire slice"
            );
        };
        assert!(
            error.to_string().contains("received plaintext buffer full"),
            "unexpected error: {error}"
        );
    }

    /// `encrypt` returns the partial accepted count when rustls's
    /// outbound plaintext queue can't hold the full input. Lower the
    /// queue limit explicitly so the test doesn't depend on rustls's
    /// internal default.
    #[test]
    fn encrypt_returns_partial_when_queue_fills() {
        let (mut client, _server) = connected_pair();
        client.set_buffer_limit(Some(4096));

        // First 4 KiB fits.
        let n1 = client.encrypt(&[b'a'; 4096]).unwrap();
        assert_eq!(n1, 4096);

        // Next chunk: queue is full. encrypt accepts 0.
        let n2 = client.encrypt(&[b'b'; 4096]).unwrap();
        assert_eq!(n2, 0, "queue full → encrypt must report 0 accepted");
    }

    /// `set_buffer_limit(None)` lifts the cap entirely — `encrypt`
    /// accepts everything in one shot.
    #[test]
    fn set_buffer_limit_none_unlimits_queue() {
        let (mut client, _server) = connected_pair();
        client.set_buffer_limit(None);

        // Heap-allocated to avoid a 256 KiB stack frame in this test.
        let payload = vec![b'x'; 256 * 1024];
        let n = client.encrypt(&payload).unwrap();
        assert_eq!(
            n,
            256 * 1024,
            "unlimited queue must accept the entire payload"
        );
    }

    /// `drain_plaintext_into` direct-feeds a [`ParserSink`] without
    /// the intermediate slice copy. Pins the zero-copy path that
    /// `WireStream::poll_fill_into` uses on TLS adapters.
    #[test]
    fn drain_plaintext_into_zero_copy_path() {
        struct CaptureSink {
            buf: Vec<u8>,
            committed: usize,
        }
        impl crate::ParserSink for CaptureSink {
            fn spare(&mut self) -> &mut [u8] {
                &mut self.buf[self.committed..]
            }
            fn filled(&mut self, n: usize) {
                self.committed += n;
            }
        }

        let (mut client, mut server) = connected_pair();
        let payload = b"hello-frames";
        let ciphertext = encrypt_server_payload(&mut server, payload);

        // Step the codec until plaintext is queued. Guard against a
        // zero-progress step so a regression fails instead of hanging.
        let mut consumed = 0;
        while consumed < ciphertext.len() {
            let n = client.read_tls(&ciphertext[consumed..]).unwrap();
            assert!(n > 0, "read_tls made no progress; consumed={consumed}");
            consumed += n;
        }

        let mut sink = CaptureSink {
            buf: vec![0u8; 64],
            committed: 0,
        };
        let n = client
            .drain_plaintext_into(&mut sink)
            .expect("drain_plaintext_into must succeed");
        assert_eq!(n, payload.len(), "must feed all queued plaintext");
        assert_eq!(&sink.buf[..n], payload);

        // Idempotent on empty queue.
        let n = client.drain_plaintext_into(&mut sink).unwrap();
        assert_eq!(n, 0, "no more plaintext → Ok(0)");
    }

    /// `read_tls_from` drives a sync [`Read`] source: reads up to one
    /// `READ_SIZE` from the source, processes packets, returns bytes
    /// pulled. Verifies the read+process pair fold (caller no longer
    /// has to call `process_new_packets` after).
    #[test]
    fn read_tls_from_drives_sync_read_source() {
        let (mut client, mut server) = connected_pair();
        let payload = b"hello-from-source";
        let ciphertext = encrypt_server_payload(&mut server, payload);

        let mut cursor = Cursor::new(ciphertext);
        let mut total = 0;
        while total < cursor.get_ref().len() {
            let n = client.read_tls_from(&mut cursor).unwrap();
            if n == 0 {
                break;
            }
            total += n;
        }
        assert!(total > 0);

        let mut dst = vec![0u8; payload.len()];
        let n = client.read_plaintext(&mut dst).unwrap();
        assert_eq!(n, payload.len());
        assert_eq!(&dst[..n], payload);
    }
}