Skip to main content

nexus_net/tls/
stream.rs

1//! TLS stream wrapper — implements `Read + Write` over the sans-IO codec.
2//!
3//! Wraps a transport stream `S` and a [`TlsCodec`] into a single type
4//! that transparently encrypts/decrypts. Sync only — async TLS lives
5//! in `nexus-async-web::maybe_tls` (which drives the same sans-IO
6//! codec at the poll level).
7
8use std::io::{self, Read, Write};
9
10use super::codec::TlsCodec;
11
12/// A stream that transparently encrypts and decrypts via [`TlsCodec`].
13///
14/// Implements `Read` and `Write` by routing through the TLS codec.
15/// The inner stream `S` carries raw ciphertext; callers see plaintext.
16///
17/// Construct via [`connect`](Self::connect) — the handshake is driven
18/// to completion before the value is returned.
19pub struct TlsStream<S> {
20    stream: S,
21    codec: TlsCodec,
22}
23
24impl<S> TlsStream<S> {
25    /// Access the underlying transport stream.
26    pub fn stream(&self) -> &S {
27        &self.stream
28    }
29
30    /// Mutable access to the underlying transport stream.
31    pub fn stream_mut(&mut self) -> &mut S {
32        &mut self.stream
33    }
34
35    /// Access the TLS codec.
36    pub fn codec(&self) -> &TlsCodec {
37        &self.codec
38    }
39
40    /// Mutable access to the TLS codec.
41    pub fn codec_mut(&mut self) -> &mut TlsCodec {
42        &mut self.codec
43    }
44
45    /// Decompose into the inner stream and codec.
46    pub fn into_parts(self) -> (S, TlsCodec) {
47        (self.stream, self.codec)
48    }
49
50    /// Set rustls's outbound plaintext queue limit. Convenience
51    /// pass-through to [`TlsCodec::set_buffer_limit`].
52    ///
53    /// Default is rustls's `DEFAULT_BUFFER_LIMIT = 64 KiB`. Bulk-
54    /// transfer workloads (large snapshots, file uploads over TLS)
55    /// may benefit from raising it. `None` for unlimited (caller
56    /// is responsible for not encrypting more than memory allows).
57    pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
58        self.codec.set_buffer_limit(limit);
59    }
60}
61
62impl<S: Read + Write> TlsStream<S> {
63    /// Wrap a transport stream and drive the TLS handshake to
64    /// completion. Returns a stream ready for plaintext I/O.
65    pub fn connect(stream: S, codec: TlsCodec) -> Result<Self, super::TlsError> {
66        let mut s = Self { stream, codec };
67        s.handshake()?;
68        Ok(s)
69    }
70
71    /// Drive the TLS handshake to completion (blocking).
72    fn handshake(&mut self) -> Result<(), super::TlsError> {
73        while self.codec.is_handshaking() {
74            while self.codec.wants_write() {
75                self.codec.write_tls_to(&mut self.stream)?;
76            }
77            if self.codec.wants_read() {
78                // read_tls_from drives one per-call read against the
79                // Read trait and processes the resulting records.
80                // Ok(0) means the peer closed mid-handshake — surface
81                // explicitly so we don't loop forever with
82                // is_handshaking() still true.
83                let n = self.codec.read_tls_from(&mut self.stream)?;
84                if n == 0 {
85                    return Err(super::TlsError::Io(io::Error::new(
86                        io::ErrorKind::UnexpectedEof,
87                        "connection closed during TLS handshake",
88                    )));
89                }
90            }
91        }
92        // Flush any remaining handshake data (client Finished, etc).
93        while self.codec.wants_write() {
94            self.codec.write_tls_to(&mut self.stream)?;
95        }
96        Ok(())
97    }
98}
99
100impl<S: Read + Write> Read for TlsStream<S> {
101    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
102        // Try reading plaintext that's already buffered.
103        let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
104        if n > 0 {
105            return Ok(n);
106        }
107
108        // Need more ciphertext from the transport.
109        // TLS may consume records without producing plaintext (session
110        // tickets, key updates). Loop until we get plaintext or EOF.
111        loop {
112            let tls_n = self
113                .codec
114                .read_tls_from(&mut self.stream)
115                .map_err(tls_to_io)?;
116            if tls_n == 0 {
117                return Ok(0); // EOF
118            }
119            let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
120            if n > 0 {
121                return Ok(n);
122            }
123        }
124    }
125}
126
127impl<S: Read + Write> Write for TlsStream<S> {
128    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
129        // Sync `Write` is all-or-nothing by trait contract, so loop
130        // until rustls accepts everything. Typical sync writes are
131        // well under rustls's plaintext queue cap (64 KiB by default),
132        // so this loop is one iteration in practice.
133        let mut written = 0;
134        while written < buf.len() {
135            let n = self.codec.encrypt(&buf[written..]).map_err(tls_to_io)?;
136            if n == 0 {
137                // rustls's plaintext queue is full. Drain it to the
138                // socket — that produces ciphertext from the queued
139                // plaintext, freeing space for more `encrypt` calls.
140                // Without this, retrying would just hit the same wall.
141                while self.codec.wants_write() {
142                    self.codec.write_tls_to(&mut self.stream)?;
143                }
144                // Queue should now be drained; retry encrypt. If still
145                // zero, the queue limit is genuinely smaller than the
146                // remaining input — surface explicitly.
147                let n2 = self.codec.encrypt(&buf[written..]).map_err(tls_to_io)?;
148                if n2 == 0 {
149                    return Err(io::Error::new(
150                        io::ErrorKind::WriteZero,
151                        "rustls plaintext queue limit is smaller than \
152                         the remaining input — the buffer-limit may \
153                         have been set too low (raise via \
154                         TlsCodec::set_buffer_limit or \
155                         TlsBufferCapacities::rustls_plaintext_limit), \
156                         or chunk the write into smaller pieces",
157                    ));
158                }
159                written += n2;
160            } else {
161                written += n;
162            }
163        }
164        while self.codec.wants_write() {
165            self.codec.write_tls_to(&mut self.stream)?;
166        }
167        Ok(buf.len())
168    }
169
170    fn flush(&mut self) -> io::Result<()> {
171        while self.codec.wants_write() {
172            self.codec.write_tls_to(&mut self.stream)?;
173        }
174        self.stream.flush()
175    }
176}
177
178fn tls_to_io(e: super::TlsError) -> io::Error {
179    match e {
180        super::TlsError::Io(io) => io,
181        other => io::Error::other(other),
182    }
183}