// nexus_net/tls/codec.rs
use std::io::{self, BufRead, Read, Write};

use rustls::ClientConnection;
use rustls::pki_types::ServerName;

use super::{TlsConfig, TlsError};

8/// Sans-IO TLS codec. Decrypts inbound bytes, encrypts outbound bytes.
9///
10/// Wraps a rustls `ClientConnection` with an API shaped for nexus-net.
11/// The codec is a pure state machine: callers drive IO and buffering;
12/// the codec only transforms bytes.
13///
14/// # API at a glance
15///
16/// - **Inbound:** [`read_tls`](Self::read_tls) feeds buffered ciphertext
17/// one packet step at a time; [`read_tls_from`](Self::read_tls_from)
18/// drives a sync [`Read`] source directly;
19/// [`read_and_process_tls`](Self::read_and_process_tls) loops over
20/// bounded input.
21/// - **Drain plaintext:** [`read_plaintext`](Self::read_plaintext) into
22/// a slice; [`drain_plaintext_into`](Self::drain_plaintext_into) feeds
23/// any [`ParserSink`](crate::ParserSink) (e.g. `FrameReader`) with
24/// one fewer copy.
25/// - **Outbound:** [`encrypt`](Self::encrypt) returns bytes accepted
26/// (chunked); [`write_tls_to`](Self::write_tls_to) drains ciphertext
27/// to a writer.
28/// - **Shutdown:** [`send_close_notify`](Self::send_close_notify)
29/// queues the alert; flush via `write_tls_to` before transport close.
30pub struct TlsCodec {
31 inner: ClientConnection,
32}
33
34impl TlsCodec {
35 /// Create a new TLS codec for the given hostname.
36 ///
37 /// The hostname is used for SNI (Server Name Indication) and
38 /// certificate verification.
39 pub fn new(config: &TlsConfig, hostname: &str) -> Result<Self, TlsError> {
40 let server_name = ServerName::try_from(hostname.to_owned())
41 .map_err(|_| TlsError::InvalidHostname(hostname.to_owned()))?;
42
43 let conn = ClientConnection::new(config.inner.clone(), server_name)?;
44
45 Ok(Self { inner: conn })
46 }
47
48 // =========================================================================
49 // Inbound (socket → TLS → FrameReader)
50 // =========================================================================
51
52 /// Advance the codec by a single TLS packet step: one read + one
53 /// `process_new_packets` pair.
54 ///
55 /// Returns the number of ciphertext bytes consumed from `src`. The
56 /// caller drains any plaintext between calls (via
57 /// [`read_plaintext`](Self::read_plaintext) or
58 /// [`drain_plaintext_into`](Self::drain_plaintext_into)) — feeding
59 /// more ciphertext while plaintext is queued can overflow rustls's
60 /// internal plaintext buffer. This is the canonical primitive for
61 /// streaming app-data adapters (poll socket → step codec → drain
62 /// plaintext → repeat).
63 ///
64 /// For bounded input that fits in rustls's plaintext queue
65 /// (handshake bytes, in-memory tests), use the drain-loop helper
66 /// [`read_and_process_tls`](Self::read_and_process_tls).
67 ///
68 /// # Returns
69 ///
70 /// `Ok(0)` if `src` is empty, or if rustls's deframer cannot
71 /// progress on the input alone (matches `Read::read` idiom — the
72 /// caller's loop is responsible for detecting stuck state).
73 /// Otherwise `Ok(n)` where `n > 0` is bytes consumed (always
74 /// `<= src.len()`; rustls's deframer caps each call at its
75 /// internal `READ_SIZE`).
76 ///
77 /// # Errors
78 ///
79 /// Any rustls error from the read or process step (alerts,
80 /// decryption failures, plaintext-buffer overflow, protocol
81 /// violations).
82 #[inline]
83 pub fn read_tls(&mut self, src: &[u8]) -> Result<usize, TlsError> {
84 if src.is_empty() {
85 return Ok(0);
86 }
87 let mut cursor = io::Cursor::new(src);
88 let consumed = self.inner.read_tls(&mut cursor)?;
89 if consumed > 0 {
90 self.inner.process_new_packets()?;
91 }
92 Ok(consumed)
93 }
94
95 /// Feed buffered TLS bytes through rustls in a loop until the
96 /// entire slice is consumed.
97 ///
98 /// **Use only for bounded input** that fits in rustls's plaintext
99 /// queue — in-memory tests, custom adapters that pre-buffer a
100 /// known-bounded byte sequence. Do **not** use for streaming app
101 /// data: large ciphertext slices fed without intervening plaintext
102 /// drains overflow rustls's internal plaintext buffer
103 /// (`received plaintext buffer full`). For streaming adapters,
104 /// drive [`read_tls`](Self::read_tls) step-by-step yourself.
105 ///
106 /// **No production callers in this crate** — kept as a
107 /// user-facing safety helper. The async `TlsInner::connect`
108 /// (nexus-async-web) and sync `TlsStream::connect` drive their
109 /// own loops over [`read_tls`](Self::read_tls) /
110 /// [`read_tls_from`](Self::read_tls_from). External adapter
111 /// authors who pre-buffer ciphertext can reach for this helper
112 /// to avoid reimplementing the consume-loop.
113 ///
114 /// # Why this exists
115 ///
116 /// `rustls::Connection::read_tls` is not guaranteed to consume the
117 /// full provided slice on a single call. The naive pattern
118 /// `codec.read_tls(&buf)?` silently drops the unconsumed tail
119 /// (issue #200 — a TLS handshake against a server that splits its
120 /// response into multiple records inside one TCP segment fails
121 /// because the unconsumed bytes vanish). This helper encodes the
122 /// correct loop so naive callers don't reintroduce the bug.
123 ///
124 /// # Returns
125 ///
126 /// `Ok(src.len())` when the entire slice has been consumed.
127 ///
128 /// # Errors
129 ///
130 /// - `TlsError::Io(InvalidData)` if rustls's deframer can't make
131 /// progress (returned 0 bytes consumed) — malformed input.
132 /// - Any rustls error from the underlying read/process steps.
133 pub fn read_and_process_tls(&mut self, src: &[u8]) -> Result<usize, TlsError> {
134 let mut consumed = 0;
135 while consumed < src.len() {
136 let n = self.read_tls(&src[consumed..])?;
137 if n == 0 {
138 return Err(TlsError::Io(io::Error::new(
139 io::ErrorKind::InvalidData,
140 "TLS codec stopped before consuming buffered input \
141 (rustls deframer cannot make progress)",
142 )));
143 }
144 consumed += n;
145 }
146 Ok(consumed)
147 }
148
149 /// Drive a sync [`Read`] source: read up to rustls's internal
150 /// `READ_SIZE` from `src`, then process the records.
151 ///
152 /// Equivalent to one `read_tls` step but pulls bytes from a
153 /// `Read` source instead of a buffer. Returns the bytes read from
154 /// `src`, or 0 on EOF / no bytes available. The caller's loop
155 /// handles the rest.
156 pub fn read_tls_from<R: Read>(&mut self, src: &mut R) -> Result<usize, TlsError> {
157 let n = self.inner.read_tls(src)?;
158 if n > 0 {
159 self.inner.process_new_packets()?;
160 }
161 Ok(n)
162 }
163
164 /// Drain decrypted plaintext into a [`ParserSink`](crate::ParserSink).
165 ///
166 /// Direct-feed path: uses `BufRead::fill_buf` to borrow rustls's
167 /// internal plaintext queue and copy directly into `sink.spare()`,
168 /// skipping the intermediate `&mut [u8]` that the
169 /// [`read_plaintext`](Self::read_plaintext) shape requires. Returns
170 /// the number of plaintext bytes delivered.
171 ///
172 /// Implements the zero-copy seam between rustls and parsers
173 /// (`FrameReader` for WebSocket framing, `ResponseReader` for
174 /// HTTP). Used by adapters' `WireStream::poll_fill_into` to fold
175 /// plaintext draining into the same call that drives ciphertext
176 /// reads.
177 pub fn drain_plaintext_into<P: crate::ParserSink>(
178 &mut self,
179 sink: &mut P,
180 ) -> Result<usize, TlsError> {
181 let mut rd = self.inner.reader();
182 let chunk = match std::io::BufRead::fill_buf(&mut rd) {
183 Ok(chunk) => chunk,
184 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(0),
185 Err(e) => return Err(TlsError::Io(e)),
186 };
187 if chunk.is_empty() {
188 return Ok(0);
189 }
190 let spare = sink.spare();
191 let n = chunk.len().min(spare.len());
192 if n == 0 {
193 // Sink has no room; caller must drain the parser before
194 // we can deliver more plaintext.
195 return Ok(0);
196 }
197 spare[..n].copy_from_slice(&chunk[..n]);
198 sink.filled(n);
199 std::io::BufRead::consume(&mut rd, n);
200 Ok(n)
201 }
202
203 /// Read decrypted plaintext into a buffer (sans-IO path).
204 ///
205 /// For users who want to feed bytes into FrameReader manually
206 /// or use a different parser.
207 #[inline]
208 pub fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize, TlsError> {
209 match self.inner.reader().read(dst) {
210 Ok(n) => Ok(n),
211 Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
212 Err(e) => Err(TlsError::Io(e)),
213 }
214 }
215
216 // =========================================================================
217 // Outbound (FrameWriter → TLS → socket)
218 // =========================================================================
219
220 /// Encrypt up to `plaintext.len()` bytes, returning the number of
221 /// bytes actually accepted by rustls's outbound plaintext queue.
222 ///
223 /// Chunked semantics — the caller's `write_all` (or equivalent)
224 /// handles re-driving on partial acceptance. This is the
225 /// `AsyncWrite::poll_write` contract: surface backpressure as a
226 /// partial count, not a hard error.
227 ///
228 /// # Returns
229 ///
230 /// `Ok(0)` if rustls's queue is full and cannot accept any bytes
231 /// (caller should drain ciphertext to the socket and retry).
232 /// Otherwise `Ok(n)` where `n > 0` is plaintext bytes queued for
233 /// encryption. `n` may be less than `plaintext.len()`.
234 ///
235 /// # Errors
236 ///
237 /// Any rustls writer error other than `WriteZero` (which is
238 /// translated to `Ok(0)` so callers treat queue-full as
239 /// backpressure rather than a hard failure).
240 #[inline]
241 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<usize, TlsError> {
242 match self.inner.writer().write(plaintext) {
243 Ok(n) => Ok(n),
244 Err(e) if e.kind() == io::ErrorKind::WriteZero => Ok(0),
245 Err(e) => Err(TlsError::Io(e)),
246 }
247 }
248
249 /// Set rustls's outbound plaintext queue limit. `None` for
250 /// unlimited (rustls accepts as much plaintext as memory allows;
251 /// pair with a caller-side bound).
252 ///
253 /// Default is rustls's `DEFAULT_BUFFER_LIMIT = 64 KiB`. Trading
254 /// workloads with small messages typically don't need to change
255 /// this. Bulk-transfer workloads (large snapshots, file uploads
256 /// over TLS) may benefit from raising it to reduce drain/refill
257 /// cycles in [`encrypt`](Self::encrypt).
258 pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
259 self.inner.set_buffer_limit(limit);
260 }
261
262 /// Queue a TLS `close_notify` alert.
263 ///
264 /// Subsequent calls to [`wants_write`](Self::wants_write) will
265 /// return true until the alert ciphertext has been written via
266 /// [`write_tls_to`](Self::write_tls_to).
267 ///
268 /// Idempotent: rustls tracks whether close_notify has been sent
269 /// and no-ops on duplicate calls.
270 ///
271 /// Use in `AsyncWrite::poll_shutdown` (or equivalent) before
272 /// closing the underlying transport. Without close_notify, the
273 /// peer sees TCP FIN as a potential truncation and may error its
274 /// read loop mid-stream.
275 #[inline]
276 pub fn send_close_notify(&mut self) {
277 self.inner.send_close_notify();
278 }
279
280 /// Flush encrypted bytes to a socket.
281 ///
282 /// Returns the number of bytes written. Call in a loop or when
283 /// [`wants_write`](Self::wants_write) returns true.
284 pub fn write_tls_to<W: Write>(&mut self, dst: &mut W) -> io::Result<usize> {
285 self.inner.write_tls(dst)
286 }
287
288 // =========================================================================
289 // State
290 // =========================================================================
291
292 /// Whether the TLS handshake is still in progress.
293 #[inline]
294 pub fn is_handshaking(&self) -> bool {
295 self.inner.is_handshaking()
296 }
297
298 /// Whether the codec has buffered TLS data to read.
299 #[inline]
300 pub fn wants_read(&self) -> bool {
301 self.inner.wants_read()
302 }
303
304 /// Whether the codec has encrypted data to write.
305 #[inline]
306 pub fn wants_write(&self) -> bool {
307 self.inner.wants_write()
308 }
309}
310
311impl std::fmt::Debug for TlsCodec {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 f.debug_struct("TlsCodec")
314 .field("handshaking", &self.inner.is_handshaking())
315 .finish()
316 }
317}
318
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::sync::Arc;

    use super::*;

    // -------------------------------------------------------------------------
    // In-memory handshake scaffolding (lifted from examples/perf_tls.rs).
    // -------------------------------------------------------------------------

    /// Self-signed certificate for `localhost` plus its private key (DER).
    fn generate_self_signed() -> (Vec<rustls::pki_types::CertificateDer<'static>>, Vec<u8>) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
            .expect("cert generation");
        (
            vec![rustls::pki_types::CertificateDer::from(
                cert.cert.der().to_vec(),
            )],
            cert.key_pair.serialize_der(),
        )
    }

    /// Generate an N-cert ECDSA-P256 chain whose serialized DER pushes
    /// the TLS 1.3 server's first handshake burst past rustls's
    /// `READ_SIZE = 4096` per-call deframer cap. ECDSA keygen is
    /// microseconds (vs RSA-4096's ~1.5s per key) so this stays cheap
    /// even at chain depth 10.
    ///
    /// Why a deep chain instead of one big RSA cert: chain depth scales
    /// the Certificate message linearly without paying for slow RSA
    /// keygen. 10 P-256 certs ≈ 5KB of cert bytes, comfortably over
    /// 4096. Each link is signed by its parent — a real CA-style chain.
    ///
    /// Returns `(chain_in_send_order, leaf_key_der)`. The chain is
    /// `[leaf, intermediate_n, ..., intermediate_1, root]` — the order
    /// rustls sends in the Certificate message.
    fn generate_oversize_ecdsa_chain() -> (Vec<rustls::pki_types::CertificateDer<'static>>, Vec<u8>)
    {
        use rcgen::{BasicConstraints, CertificateParams, IsCa, KeyPair};

        const CHAIN_DEPTH: usize = 10;

        // Generate the root + intermediates + leaf. Each non-leaf is a
        // CA-flagged cert that signs the next link.
        let mut keys: Vec<KeyPair> = Vec::with_capacity(CHAIN_DEPTH);
        let mut certs: Vec<rcgen::Certificate> = Vec::with_capacity(CHAIN_DEPTH);

        // Root.
        let root_key = KeyPair::generate().expect("root key");
        let mut root_params = CertificateParams::new(Vec::<String>::new()).expect("root params");
        root_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        let root_cert = root_params.self_signed(&root_key).expect("root self-sign");
        keys.push(root_key);
        certs.push(root_cert);

        // Intermediates (CHAIN_DEPTH - 2 of them, all CA-flagged).
        for _ in 0..(CHAIN_DEPTH - 2) {
            let key = KeyPair::generate().expect("int key");
            let mut params = CertificateParams::new(Vec::<String>::new()).expect("int params");
            params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
            let parent_cert = certs.last().expect("parent");
            let parent_key = keys.last().expect("parent key");
            let cert = params
                .signed_by(&key, parent_cert, parent_key)
                .expect("int signed");
            keys.push(key);
            certs.push(cert);
        }

        // Leaf (signed by the deepest intermediate, SAN=localhost).
        let leaf_key = KeyPair::generate().expect("leaf key");
        let leaf_params =
            CertificateParams::new(vec!["localhost".to_string()]).expect("leaf params");
        let parent_cert = certs.last().expect("parent");
        let parent_key = keys.last().expect("parent key");
        let leaf_cert = leaf_params
            .signed_by(&leaf_key, parent_cert, parent_key)
            .expect("leaf signed");

        // Server sends [leaf, intermediates_descending, root] in the
        // Certificate message. We built `certs` as [root, int_1, ...,
        // int_n], so reverse + prepend leaf.
        let mut chain: Vec<rustls::pki_types::CertificateDer<'static>> =
            Vec::with_capacity(CHAIN_DEPTH);
        chain.push(rustls::pki_types::CertificateDer::from(
            leaf_cert.der().to_vec(),
        ));
        for cert in certs.iter().rev() {
            chain.push(rustls::pki_types::CertificateDer::from(cert.der().to_vec()));
        }

        (chain, leaf_key.serialize_der())
    }

    /// In-memory pipe for handshake bytes.
    struct MemPipe {
        buf: Vec<u8>,
    }

    impl MemPipe {
        fn new() -> Self {
            Self { buf: Vec::new() }
        }

        fn write_to(&mut self, data: &[u8]) {
            self.buf.extend_from_slice(data);
        }

        fn read_from(&mut self, dst: &mut [u8]) -> usize {
            let n = dst.len().min(self.buf.len());
            dst[..n].copy_from_slice(&self.buf[..n]);
            self.buf.drain(..n);
            n
        }

        fn len(&self) -> usize {
            self.buf.len()
        }
    }

    /// Feed `bytes` to the server, looping until rustls has consumed all
    /// of them. This is the server-side analogue of
    /// `TlsCodec::read_and_process_tls`: a single `read_tls` call may
    /// take only a prefix, and dropping the tail is exactly issue #200.
    fn feed_server(server: &mut rustls::ServerConnection, bytes: &[u8]) {
        let mut offset = 0;
        while offset < bytes.len() {
            let n = server
                .read_tls(&mut Cursor::new(&bytes[offset..]))
                .expect("server read_tls");
            assert!(n > 0, "server read_tls made no progress at offset {offset}");
            server
                .process_new_packets()
                .expect("server process_new_packets");
            offset += n;
        }
    }

    /// Build the server side and capture its first multi-record handshake
    /// burst (ServerHello + EncryptedExtensions + Certificate + CertVerify +
    /// Finished under TLS 1.3 — several records pushed back-to-back). The
    /// returned `server_out` is the slice we feed to the client `TlsCodec`
    /// to exercise the partial-consumption surface.
    fn setup_and_capture_server_burst(
        cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
        key_der: Vec<u8>,
    ) -> (TlsCodec, rustls::ServerConnection, Vec<u8>) {
        let key = rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap();
        let server_config = Arc::new(
            rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(cert_chain, key)
                .unwrap(),
        );
        let mut server = rustls::ServerConnection::new(server_config).unwrap();

        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();

        let mut c2s = MemPipe::new();
        let mut s2c = MemPipe::new();

        // Client writes ClientHello.
        // Loop `while wants_write()` (mirroring the server side below)
        // for defense-in-depth — if a future rustls or cert config splits
        // the ClientHello across multiple write batches, a single
        // write_tls_to call would leave bytes pending in the codec.
        while client.wants_write() {
            let mut cursor = Cursor::new(Vec::new());
            client.write_tls_to(&mut cursor).unwrap();
            c2s.write_to(cursor.get_ref());
        }

        // Server consumes the whole ClientHello (sized to the pipe, so a
        // large ClientHello is never truncated).
        let mut client_hello = vec![0u8; c2s.len()];
        let n = c2s.read_from(&mut client_hello);
        feed_server(&mut server, &client_hello[..n]);

        // Server writes its multi-record burst.
        while server.wants_write() {
            let mut cursor = Cursor::new(Vec::new());
            server.write_tls(&mut cursor).unwrap();
            s2c.write_to(cursor.get_ref());
        }

        let mut server_out = vec![0u8; s2c.len()];
        let n = s2c.read_from(&mut server_out);
        assert!(n > 0, "server should have produced handshake bytes");
        server_out.truncate(n);

        (client, server, server_out)
    }

    // -------------------------------------------------------------------------
    // Tests
    // -------------------------------------------------------------------------

    /// Regression test for issue #200.
    ///
    /// Pre-fix: `read_tls(&buf)` may consume only part of `buf`. Calling
    /// code in nexus-async-web + nexus-net's tls/stream.rs ignored the
    /// returned consumed count, dropping the unconsumed tail and stalling
    /// the TLS handshake. Post-fix: `read_and_process_tls` loops until the
    /// entire slice is consumed.
    #[test]
    fn read_and_process_tls_consumes_full_slice() {
        let (chain, key) = generate_self_signed();
        let (mut client, _server, server_out) = setup_and_capture_server_burst(chain, key);

        let consumed = client
            .read_and_process_tls(&server_out)
            .expect("helper must consume the full slice");

        assert_eq!(
            consumed,
            server_out.len(),
            "helper must consume every byte (issue #200)"
        );
        assert!(
            client.wants_write(),
            "client should have produced its handshake response"
        );
    }

    /// Stricter exercise: feed the captured server bytes one byte per
    /// `read_and_process_tls` call. Catches a class of bugs where the
    /// helper itself drops bytes between calls or skips the
    /// `process_new_packets` step in some iterations.
    #[test]
    fn read_and_process_tls_byte_at_a_time() {
        let (chain, key) = generate_self_signed();
        let (mut client, _server, server_out) = setup_and_capture_server_burst(chain, key);

        for byte in &server_out {
            client
                .read_and_process_tls(std::slice::from_ref(byte))
                .expect("byte-at-a-time must succeed");
        }

        assert!(
            client.wants_write(),
            "client should have produced its handshake response \
             after byte-at-a-time consumption"
        );
    }

    /// **The actual end-to-end regression test for issue #200.**
    ///
    /// The other tests in this module don't exercise the helper's
    /// multi-iteration loop: `read_and_process_tls_consumes_full_slice`
    /// uses a small burst that consumes in one inner iteration, and
    /// `read_and_process_tls_byte_at_a_time` invokes the helper many times
    /// with 1-byte slices, so each invocation has a 1-iteration loop.
    ///
    /// This test uses a 10-cert ECDSA-P256 chain to push the server's
    /// first handshake burst past rustls's `READ_SIZE = 4096` per-call
    /// cap. Chain depth (not key size) provides the bytes — keeps
    /// keygen fast. The helper is fed the whole burst in ONE call; its
    /// internal loop must iterate multiple times to consume everything.
    /// This is exactly the shape birch hit against polymarket.
    #[test]
    fn read_and_process_tls_handles_oversize_burst() {
        let (chain, key) = generate_oversize_ecdsa_chain();
        let (mut client, _server, server_out) = setup_and_capture_server_burst(chain, key);

        // Confirm the test is actually exercising the partial-consumption
        // path. If this assertion fails, future contributors investigating
        // know the burst-size assumption broke (e.g., rustls raised
        // READ_SIZE, or the cert chain shrank). Raise `CHAIN_DEPTH` in
        // `generate_oversize_ecdsa_chain` to restore.
        assert!(
            server_out.len() > 4096,
            "burst must exceed READ_SIZE to exercise multi-iteration loop, \
             got {} bytes — bump cert chain in generate_oversize_ecdsa_chain",
            server_out.len()
        );

        let consumed = client
            .read_and_process_tls(&server_out)
            .expect("helper must consume the full slice across multiple iterations");

        assert_eq!(
            consumed,
            server_out.len(),
            "helper must consume every byte across the multi-iteration loop \
             (issue #200 — the actual partial-consumption surface)"
        );
        assert!(
            client.wants_write(),
            "client should have produced its handshake response after \
             consuming the oversize burst"
        );
    }

    /// Drive an in-memory TLS 1.3 handshake to completion.
    /// Returns the connected client codec + server connection ready for
    /// app-data exchange. Used by `read_tls` step tests that need a
    /// post-handshake codec.
    fn connected_pair() -> (TlsCodec, rustls::ServerConnection) {
        let (cert_chain, key_der) = generate_self_signed();
        let key = rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap();
        let server_config = Arc::new(
            rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(cert_chain, key)
                .unwrap(),
        );
        let mut server = rustls::ServerConnection::new(server_config).unwrap();

        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();

        let mut c2s = Vec::new();
        let mut s2c = Vec::new();

        for _ in 0..64 {
            while client.wants_write() {
                client.write_tls_to(&mut c2s).unwrap();
            }
            if !c2s.is_empty() {
                // Consume every client byte before clearing the buffer.
                feed_server(&mut server, &c2s);
                c2s.clear();
            }
            while server.wants_write() {
                server.write_tls(&mut s2c).unwrap();
            }
            if !s2c.is_empty() {
                client.read_and_process_tls(&s2c).unwrap();
                s2c.clear();
            }
            if !client.is_handshaking() && !server.is_handshaking() {
                return (client, server);
            }
        }
        panic!("TLS handshake did not complete");
    }

    /// Encrypt `payload` from the server side and capture the resulting
    /// ciphertext.
    fn encrypt_server_payload(server: &mut rustls::ServerConnection, payload: &[u8]) -> Vec<u8> {
        use std::io::Write as _;
        server.writer().write_all(payload).unwrap();
        let mut ciphertext = Vec::new();
        while server.wants_write() {
            server.write_tls(&mut ciphertext).unwrap();
        }
        ciphertext
    }

    /// Empty input is a cheap no-op, not an error.
    #[test]
    fn read_tls_empty_input_returns_zero() {
        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();

        let n = client.read_tls(&[]).expect("empty input must not error");
        assert_eq!(n, 0);
    }

    /// Happy path: feed a small ciphertext prefix, get a non-zero
    /// consumed count back, drain the resulting plaintext.
    #[test]
    fn read_tls_normal_step() {
        let (mut client, mut server) = connected_pair();
        let payload = b"hello, world";
        let ciphertext = encrypt_server_payload(&mut server, payload);
        assert!(!ciphertext.is_empty());

        let consumed = client
            .read_tls(&ciphertext)
            .expect("step must succeed on fresh ciphertext");
        assert!(consumed > 0, "must consume at least one byte");
        assert!(consumed <= ciphertext.len());

        let mut dst = vec![0u8; payload.len()];
        let n = client.read_plaintext(&mut dst).unwrap();
        assert_eq!(n, payload.len());
        assert_eq!(&dst[..n], payload);
    }

    /// Documentation pin: when a caller feeds ciphertext via repeated
    /// `TlsCodec::read_tls` calls without ever draining plaintext between
    /// steps, rustls's internal plaintext buffer eventually overflows
    /// and surfaces as `received plaintext buffer full`. This is the
    /// constraint that motivates the streaming pattern (alternate
    /// step + drain) and the existence of `TlsCodec::read_and_process_tls`
    /// for the bounded-input case where overflow can't happen.
    #[test]
    fn read_tls_rejects_when_caller_does_not_drain() {
        let (mut client, mut server) = connected_pair();
        let payload = vec![b'x'; 64 * 1024];
        let ciphertext = encrypt_server_payload(&mut server, &payload);

        // Drive read_tls in a loop without ever calling read_plaintext.
        // Plaintext queues until rustls's cap is hit, then the next
        // read step surfaces the overflow.
        let mut consumed = 0;
        let error = loop {
            match client.read_tls(&ciphertext[consumed..]) {
                Ok(0) => panic!("unexpected stuck state; consumed={consumed}"),
                Ok(n) => consumed += n,
                Err(e) => break e,
            }
            assert!(
                consumed < ciphertext.len(),
                "expected error before consuming entire slice"
            );
        };
        assert!(
            error.to_string().contains("received plaintext buffer full"),
            "unexpected error: {error}"
        );
    }

    /// `encrypt` returns the partial accepted count when rustls's
    /// outbound plaintext queue can't hold the full input. Lower the
    /// queue limit explicitly so the test doesn't depend on rustls's
    /// internal default.
    #[test]
    fn encrypt_returns_partial_when_queue_fills() {
        let (mut client, _server) = connected_pair();
        client.set_buffer_limit(Some(4096));

        // First 4 KiB fits.
        let n1 = client.encrypt(&[b'a'; 4096]).unwrap();
        assert_eq!(n1, 4096);

        // Next chunk: queue is full. encrypt accepts 0.
        let n2 = client.encrypt(&[b'b'; 4096]).unwrap();
        assert_eq!(n2, 0, "queue full → encrypt must report 0 accepted");
    }

    /// `set_buffer_limit(None)` lifts the cap entirely — `encrypt`
    /// accepts everything in one shot.
    #[test]
    fn set_buffer_limit_none_unlimits_queue() {
        let (mut client, _server) = connected_pair();
        client.set_buffer_limit(None);

        // Heap-allocated to avoid a 256 KiB stack frame in this test.
        let payload = vec![b'x'; 256 * 1024];
        let n = client.encrypt(&payload).unwrap();
        assert_eq!(
            n,
            256 * 1024,
            "unlimited queue must accept the entire payload"
        );
    }

    /// `drain_plaintext_into` direct-feeds a `ParserSink` without the
    /// intermediate slice that `read_plaintext` needs (a single copy from
    /// rustls's queue into the sink). Pins the path that
    /// `WireStream::poll_fill_into` uses on TLS adapters.
    #[test]
    fn drain_plaintext_into_zero_copy_path() {
        struct CaptureSink {
            buf: Vec<u8>,
            committed: usize,
        }
        impl crate::ParserSink for CaptureSink {
            fn spare(&mut self) -> &mut [u8] {
                &mut self.buf[self.committed..]
            }
            fn filled(&mut self, n: usize) {
                self.committed += n;
            }
        }

        let (mut client, mut server) = connected_pair();
        let payload = b"hello-frames";
        let ciphertext = encrypt_server_payload(&mut server, payload);

        // Step the codec until plaintext is queued. Guard against a
        // zero-progress step so a regression fails instead of hanging.
        let mut consumed = 0;
        while consumed < ciphertext.len() {
            let n = client.read_tls(&ciphertext[consumed..]).unwrap();
            assert!(n > 0, "read_tls stalled after {consumed} bytes");
            consumed += n;
        }

        let mut sink = CaptureSink {
            buf: vec![0u8; 64],
            committed: 0,
        };
        let n = client
            .drain_plaintext_into(&mut sink)
            .expect("drain_plaintext_into must succeed");
        assert_eq!(n, payload.len(), "must feed all queued plaintext");
        assert_eq!(&sink.buf[..n], payload);

        // Idempotent on empty queue.
        let n = client.drain_plaintext_into(&mut sink).unwrap();
        assert_eq!(n, 0, "no more plaintext → Ok(0)");
    }

    /// `read_tls_from` drives a sync `Read` source: reads up to one
    /// `READ_SIZE` from the source, processes packets, returns bytes
    /// pulled. Verifies the read+process pair fold (caller no longer
    /// has to call `process_new_packets` after).
    #[test]
    fn read_tls_from_drives_sync_read_source() {
        let (mut client, mut server) = connected_pair();
        let payload = b"hello-from-source";
        let ciphertext = encrypt_server_payload(&mut server, payload);

        let mut cursor = Cursor::new(ciphertext);
        let mut total = 0;
        while total < cursor.get_ref().len() {
            let n = client.read_tls_from(&mut cursor).unwrap();
            if n == 0 {
                break;
            }
            total += n;
        }
        assert!(total > 0);

        let mut dst = vec![0u8; payload.len()];
        let n = client.read_plaintext(&mut dst).unwrap();
        assert_eq!(n, payload.len());
        assert_eq!(&dst[..n], payload);
    }
}