http2_proto/transport/
tls.rs1use super::{Transport, TransportState};
4use bytes::BytesMut;
5use rustls::pki_types::ServerName;
6use std::io::{self, Read, Write};
7use std::sync::Arc;
8
/// Client-side TLS settings from which [`TlsTransport`] connections are created.
pub struct TlsConfig {
    /// The rustls client configuration. Held behind an `Arc` so that every connection
    /// created from this value can take its own handle to the same configuration.
    config: Arc<rustls::ClientConfig>,
}
14
15impl TlsConfig {
16 pub fn new() -> io::Result<Self> {
18 let root_store =
19 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
20
21 let config = rustls::ClientConfig::builder()
22 .with_root_certificates(root_store)
23 .with_no_client_auth();
24
25 Ok(Self {
26 config: Arc::new(config),
27 })
28 }
29
30 pub fn with_alpn(protocols: Vec<Vec<u8>>) -> io::Result<Self> {
32 let root_store =
33 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
34
35 let mut config = rustls::ClientConfig::builder()
36 .with_root_certificates(root_store)
37 .with_no_client_auth();
38
39 config.alpn_protocols = protocols;
40
41 Ok(Self {
42 config: Arc::new(config),
43 })
44 }
45
46 pub fn http2() -> io::Result<Self> {
48 Self::with_alpn(vec![b"h2".to_vec()])
49 }
50}
51
52impl Default for TlsConfig {
53 fn default() -> Self {
54 Self::new().expect("failed to create default TLS config")
55 }
56}
57
/// A TLS client session driven entirely through in-memory buffers.
///
/// The code in this file performs no socket I/O itself: bytes received from the peer are
/// handed in through `on_recv`, and bytes that must be sent to the peer are exposed
/// through `pending_send` / `advance_send`.
pub struct TlsTransport {
    /// The underlying rustls client connection.
    conn: rustls::ClientConnection,
    /// Current lifecycle state; starts as `Handshaking` and is updated as the session
    /// progresses, fails, or is shut down.
    state: TransportState,
    /// Bytes received from the peer that have not yet been consumed by rustls.
    incoming: BytesMut,
    /// TLS bytes produced by rustls that are waiting to be sent to the peer.
    outgoing: BytesMut,
    /// Offset into `outgoing` of the first byte not yet acknowledged via `advance_send`.
    outgoing_pos: usize,
    /// Decrypted application data waiting to be handed out by `recv`.
    plaintext: BytesMut,
}
76
77impl TlsTransport {
78 pub fn new(config: &TlsConfig, server_name: &str) -> io::Result<Self> {
80 let name = ServerName::try_from(server_name.to_string())
81 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
82
83 let conn =
84 rustls::ClientConnection::new(config.config.clone(), name).map_err(io::Error::other)?;
85
86 let mut transport = Self {
87 conn,
88 state: TransportState::Handshaking,
89 incoming: BytesMut::with_capacity(16384),
90 outgoing: BytesMut::with_capacity(16384),
91 outgoing_pos: 0,
92 plaintext: BytesMut::with_capacity(16384),
93 };
94
95 transport.flush_tls_to_outgoing()?;
97
98 Ok(transport)
99 }
100
101 fn process_tls(&mut self) -> io::Result<()> {
103 if !self.incoming.is_empty() {
105 let mut cursor = io::Cursor::new(&self.incoming[..]);
106 match self.conn.read_tls(&mut cursor) {
107 Ok(n) => {
108 let _ = self.incoming.split_to(n);
109 }
110 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
111 Err(e) => return Err(e),
112 }
113 }
114
115 match self.conn.process_new_packets() {
117 Ok(state) => {
118 if state.plaintext_bytes_to_read() > 0 {
120 let mut buf = vec![0u8; state.plaintext_bytes_to_read()];
121 let n = self.conn.reader().read(&mut buf)?;
122 self.plaintext.extend_from_slice(&buf[..n]);
123 }
124
125 if self.state == TransportState::Handshaking && !self.conn.is_handshaking() {
127 self.state = TransportState::Ready;
128 }
129 }
130 Err(e) => {
131 self.state = TransportState::Error;
132 return Err(io::Error::other(e));
133 }
134 }
135
136 self.flush_tls_to_outgoing()?;
138
139 Ok(())
140 }
141
142 fn flush_tls_to_outgoing(&mut self) -> io::Result<()> {
144 if self.conn.wants_write() {
145 let mut buf = Vec::with_capacity(4096);
146 loop {
147 match self.conn.write_tls(&mut buf) {
148 Ok(0) => break,
149 Ok(_) => {}
150 Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
151 Err(e) => return Err(e),
152 }
153 }
154 self.outgoing.extend_from_slice(&buf);
155 }
156 Ok(())
157 }
158
159 pub fn alpn_protocol(&self) -> Option<&[u8]> {
161 self.conn.alpn_protocol()
162 }
163}
164
165impl Transport for TlsTransport {
166 fn state(&self) -> TransportState {
167 self.state
168 }
169
170 fn send(&mut self, data: &[u8]) -> io::Result<usize> {
171 if self.state != TransportState::Ready {
172 return Err(io::Error::new(
173 io::ErrorKind::NotConnected,
174 "TLS handshake not complete",
175 ));
176 }
177
178 let n = self.conn.writer().write(data)?;
180
181 self.flush_tls_to_outgoing()?;
183
184 Ok(n)
185 }
186
187 fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
188 if self.plaintext.is_empty() {
189 return Err(io::Error::from(io::ErrorKind::WouldBlock));
190 }
191
192 let n = std::cmp::min(buf.len(), self.plaintext.len());
193 buf[..n].copy_from_slice(&self.plaintext[..n]);
194
195 let _ = self.plaintext.split_to(n);
197
198 Ok(n)
199 }
200
201 fn on_recv(&mut self, data: &[u8]) -> io::Result<()> {
202 self.incoming.extend_from_slice(data);
203 self.process_tls()
204 }
205
206 fn pending_send(&self) -> &[u8] {
207 &self.outgoing[self.outgoing_pos..]
208 }
209
210 fn advance_send(&mut self, n: usize) {
211 self.outgoing_pos += n;
212
213 if self.outgoing_pos >= self.outgoing.len() {
215 self.outgoing.clear();
216 self.outgoing_pos = 0;
217 }
218 }
219
220 fn shutdown(&mut self) -> io::Result<()> {
221 self.conn.send_close_notify();
222 self.flush_tls_to_outgoing()?;
223 self.state = TransportState::Closed;
224 Ok(())
225 }
226}