Skip to main content

http2_proto/transport/
tls.rs

1//! TLS transport using rustls.
2
3use super::{Transport, TransportState};
4use bytes::BytesMut;
5use rustls::pki_types::ServerName;
6use std::io::{self, Read, Write};
7use std::sync::Arc;
8
9/// TLS configuration for client connections.
10pub struct TlsConfig {
11    /// rustls client configuration.
12    config: Arc<rustls::ClientConfig>,
13}
14
15impl TlsConfig {
16    /// Create a new TLS configuration with default root certificates.
17    pub fn new() -> io::Result<Self> {
18        let root_store =
19            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
20
21        let config = rustls::ClientConfig::builder()
22            .with_root_certificates(root_store)
23            .with_no_client_auth();
24
25        Ok(Self {
26            config: Arc::new(config),
27        })
28    }
29
30    /// Create a TLS configuration with ALPN protocols.
31    pub fn with_alpn(protocols: Vec<Vec<u8>>) -> io::Result<Self> {
32        let root_store =
33            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
34
35        let mut config = rustls::ClientConfig::builder()
36            .with_root_certificates(root_store)
37            .with_no_client_auth();
38
39        config.alpn_protocols = protocols;
40
41        Ok(Self {
42            config: Arc::new(config),
43        })
44    }
45
46    /// Create a TLS configuration for HTTP/2.
47    pub fn http2() -> io::Result<Self> {
48        Self::with_alpn(vec![b"h2".to_vec()])
49    }
50}
51
52impl Default for TlsConfig {
53    fn default() -> Self {
54        Self::new().expect("failed to create default TLS config")
55    }
56}
57
58/// TLS transport using rustls.
59///
60/// This wraps a rustls `ClientConnection` and provides a completion-based
61/// interface.
62pub struct TlsTransport {
63    /// The TLS connection state.
64    conn: rustls::ClientConnection,
65    /// Current transport state.
66    state: TransportState,
67    /// Buffer for incoming encrypted data from the socket.
68    incoming: BytesMut,
69    /// Buffer for outgoing encrypted data to the socket.
70    outgoing: BytesMut,
71    /// How much of outgoing has been sent.
72    outgoing_pos: usize,
73    /// Buffer for decrypted application data.
74    plaintext: BytesMut,
75}
76
77impl TlsTransport {
78    /// Create a new TLS transport for the given server name.
79    pub fn new(config: &TlsConfig, server_name: &str) -> io::Result<Self> {
80        let name = ServerName::try_from(server_name.to_string())
81            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
82
83        let conn =
84            rustls::ClientConnection::new(config.config.clone(), name).map_err(io::Error::other)?;
85
86        let mut transport = Self {
87            conn,
88            state: TransportState::Handshaking,
89            incoming: BytesMut::with_capacity(16384),
90            outgoing: BytesMut::with_capacity(16384),
91            outgoing_pos: 0,
92            plaintext: BytesMut::with_capacity(16384),
93        };
94
95        // Generate initial handshake data
96        transport.flush_tls_to_outgoing()?;
97
98        Ok(transport)
99    }
100
101    /// Process TLS state and move data between buffers.
102    fn process_tls(&mut self) -> io::Result<()> {
103        // Feed incoming encrypted data to TLS
104        if !self.incoming.is_empty() {
105            let mut cursor = io::Cursor::new(&self.incoming[..]);
106            match self.conn.read_tls(&mut cursor) {
107                Ok(n) => {
108                    let _ = self.incoming.split_to(n);
109                }
110                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
111                Err(e) => return Err(e),
112            }
113        }
114
115        // Process TLS messages
116        match self.conn.process_new_packets() {
117            Ok(state) => {
118                // Read any decrypted data
119                if state.plaintext_bytes_to_read() > 0 {
120                    let mut buf = vec![0u8; state.plaintext_bytes_to_read()];
121                    let n = self.conn.reader().read(&mut buf)?;
122                    self.plaintext.extend_from_slice(&buf[..n]);
123                }
124
125                // Check if handshake completed
126                if self.state == TransportState::Handshaking && !self.conn.is_handshaking() {
127                    self.state = TransportState::Ready;
128                }
129            }
130            Err(e) => {
131                self.state = TransportState::Error;
132                return Err(io::Error::other(e));
133            }
134        }
135
136        // Flush any pending TLS output
137        self.flush_tls_to_outgoing()?;
138
139        Ok(())
140    }
141
142    /// Move pending TLS output to the outgoing buffer.
143    fn flush_tls_to_outgoing(&mut self) -> io::Result<()> {
144        if self.conn.wants_write() {
145            let mut buf = Vec::with_capacity(4096);
146            loop {
147                match self.conn.write_tls(&mut buf) {
148                    Ok(0) => break,
149                    Ok(_) => {}
150                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
151                    Err(e) => return Err(e),
152                }
153            }
154            self.outgoing.extend_from_slice(&buf);
155        }
156        Ok(())
157    }
158
159    /// Get the negotiated ALPN protocol, if any.
160    pub fn alpn_protocol(&self) -> Option<&[u8]> {
161        self.conn.alpn_protocol()
162    }
163}
164
165impl Transport for TlsTransport {
166    fn state(&self) -> TransportState {
167        self.state
168    }
169
170    fn send(&mut self, data: &[u8]) -> io::Result<usize> {
171        if self.state != TransportState::Ready {
172            return Err(io::Error::new(
173                io::ErrorKind::NotConnected,
174                "TLS handshake not complete",
175            ));
176        }
177
178        // Write plaintext to TLS
179        let n = self.conn.writer().write(data)?;
180
181        // Encrypt and move to outgoing buffer
182        self.flush_tls_to_outgoing()?;
183
184        Ok(n)
185    }
186
187    fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
188        if self.plaintext.is_empty() {
189            return Err(io::Error::from(io::ErrorKind::WouldBlock));
190        }
191
192        let n = std::cmp::min(buf.len(), self.plaintext.len());
193        buf[..n].copy_from_slice(&self.plaintext[..n]);
194
195        // Remove read bytes
196        let _ = self.plaintext.split_to(n);
197
198        Ok(n)
199    }
200
201    fn on_recv(&mut self, data: &[u8]) -> io::Result<()> {
202        self.incoming.extend_from_slice(data);
203        self.process_tls()
204    }
205
206    fn pending_send(&self) -> &[u8] {
207        &self.outgoing[self.outgoing_pos..]
208    }
209
210    fn advance_send(&mut self, n: usize) {
211        self.outgoing_pos += n;
212
213        // If all data sent, clear the buffer
214        if self.outgoing_pos >= self.outgoing.len() {
215            self.outgoing.clear();
216            self.outgoing_pos = 0;
217        }
218    }
219
220    fn shutdown(&mut self) -> io::Result<()> {
221        self.conn.send_close_notify();
222        self.flush_tls_to_outgoing()?;
223        self.state = TransportState::Closed;
224        Ok(())
225    }
226}