nexus_net/tls/stream.rs
1//! TLS stream wrapper — implements `Read + Write` over the sans-IO codec.
2//!
3//! Wraps a transport stream `S` and a [`TlsCodec`] into a single type
4//! that transparently encrypts/decrypts. Sync only — async TLS lives
5//! in `nexus-async-web::maybe_tls` (which drives the same sans-IO
6//! codec at the poll level).
7
8use std::io::{self, Read, Write};
9
10use super::codec::TlsCodec;
11
12/// A stream that transparently encrypts and decrypts via [`TlsCodec`].
13///
14/// Implements `Read` and `Write` by routing through the TLS codec.
15/// The inner stream `S` carries raw ciphertext; callers see plaintext.
16///
17/// Construct via [`connect`](Self::connect) — the handshake is driven
18/// to completion before the value is returned.
19pub struct TlsStream<S> {
20 stream: S,
21 codec: TlsCodec,
22}
23
24impl<S> TlsStream<S> {
25 /// Access the underlying transport stream.
26 pub fn stream(&self) -> &S {
27 &self.stream
28 }
29
30 /// Mutable access to the underlying transport stream.
31 pub fn stream_mut(&mut self) -> &mut S {
32 &mut self.stream
33 }
34
35 /// Access the TLS codec.
36 pub fn codec(&self) -> &TlsCodec {
37 &self.codec
38 }
39
40 /// Mutable access to the TLS codec.
41 pub fn codec_mut(&mut self) -> &mut TlsCodec {
42 &mut self.codec
43 }
44
45 /// Decompose into the inner stream and codec.
46 pub fn into_parts(self) -> (S, TlsCodec) {
47 (self.stream, self.codec)
48 }
49
50 /// Set rustls's outbound plaintext queue limit. Convenience
51 /// pass-through to [`TlsCodec::set_buffer_limit`].
52 ///
53 /// Default is rustls's `DEFAULT_BUFFER_LIMIT = 64 KiB`. Bulk-
54 /// transfer workloads (large snapshots, file uploads over TLS)
55 /// may benefit from raising it. `None` for unlimited (caller
56 /// is responsible for not encrypting more than memory allows).
57 pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
58 self.codec.set_buffer_limit(limit);
59 }
60}
61
62impl<S: Read + Write> TlsStream<S> {
63 /// Wrap a transport stream and drive the TLS handshake to
64 /// completion. Returns a stream ready for plaintext I/O.
65 pub fn connect(stream: S, codec: TlsCodec) -> Result<Self, super::TlsError> {
66 let mut s = Self { stream, codec };
67 s.handshake()?;
68 Ok(s)
69 }
70
71 /// Drive the TLS handshake to completion (blocking).
72 fn handshake(&mut self) -> Result<(), super::TlsError> {
73 while self.codec.is_handshaking() {
74 while self.codec.wants_write() {
75 self.codec.write_tls_to(&mut self.stream)?;
76 }
77 if self.codec.wants_read() {
78 // read_tls_from drives one per-call read against the
79 // Read trait and processes the resulting records.
80 // Ok(0) means the peer closed mid-handshake — surface
81 // explicitly so we don't loop forever with
82 // is_handshaking() still true.
83 let n = self.codec.read_tls_from(&mut self.stream)?;
84 if n == 0 {
85 return Err(super::TlsError::Io(io::Error::new(
86 io::ErrorKind::UnexpectedEof,
87 "connection closed during TLS handshake",
88 )));
89 }
90 }
91 }
92 // Flush any remaining handshake data (client Finished, etc).
93 while self.codec.wants_write() {
94 self.codec.write_tls_to(&mut self.stream)?;
95 }
96 Ok(())
97 }
98}
99
100impl<S: Read + Write> Read for TlsStream<S> {
101 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
102 // Try reading plaintext that's already buffered.
103 let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
104 if n > 0 {
105 return Ok(n);
106 }
107
108 // Need more ciphertext from the transport.
109 // TLS may consume records without producing plaintext (session
110 // tickets, key updates). Loop until we get plaintext or EOF.
111 loop {
112 let tls_n = self
113 .codec
114 .read_tls_from(&mut self.stream)
115 .map_err(tls_to_io)?;
116 if tls_n == 0 {
117 return Ok(0); // EOF
118 }
119 let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
120 if n > 0 {
121 return Ok(n);
122 }
123 }
124 }
125}
126
127impl<S: Read + Write> Write for TlsStream<S> {
128 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
129 // Sync `Write` is all-or-nothing by trait contract, so loop
130 // until rustls accepts everything. Typical sync writes are
131 // well under rustls's plaintext queue cap (64 KiB by default),
132 // so this loop is one iteration in practice.
133 let mut written = 0;
134 while written < buf.len() {
135 let n = self.codec.encrypt(&buf[written..]).map_err(tls_to_io)?;
136 if n == 0 {
137 // rustls's plaintext queue is full. Drain it to the
138 // socket — that produces ciphertext from the queued
139 // plaintext, freeing space for more `encrypt` calls.
140 // Without this, retrying would just hit the same wall.
141 while self.codec.wants_write() {
142 self.codec.write_tls_to(&mut self.stream)?;
143 }
144 // Queue should now be drained; retry encrypt. If still
145 // zero, the queue limit is genuinely smaller than the
146 // remaining input — surface explicitly.
147 let n2 = self.codec.encrypt(&buf[written..]).map_err(tls_to_io)?;
148 if n2 == 0 {
149 return Err(io::Error::new(
150 io::ErrorKind::WriteZero,
151 "rustls plaintext queue limit is smaller than \
152 the remaining input — the buffer-limit may \
153 have been set too low (raise via \
154 TlsCodec::set_buffer_limit or \
155 TlsBufferCapacities::rustls_plaintext_limit), \
156 or chunk the write into smaller pieces",
157 ));
158 }
159 written += n2;
160 } else {
161 written += n;
162 }
163 }
164 while self.codec.wants_write() {
165 self.codec.write_tls_to(&mut self.stream)?;
166 }
167 Ok(buf.len())
168 }
169
170 fn flush(&mut self) -> io::Result<()> {
171 while self.codec.wants_write() {
172 self.codec.write_tls_to(&mut self.stream)?;
173 }
174 self.stream.flush()
175 }
176}
177
178fn tls_to_io(e: super::TlsError) -> io::Error {
179 match e {
180 super::TlsError::Io(io) => io,
181 other => io::Error::other(other),
182 }
183}