mco_http_rustls/
lib.rs

1extern crate mco_http;
2extern crate rustls;
3extern crate webpki_roots;
4
5use mco_http::net::{HttpStream, NetworkStream};
6use std::convert::{TryInto};
7
8use std::{io};
9use std::fmt::{Debug, Display, Formatter};
10use std::io::{BufReader, Cursor, Read, Write};
11use std::net::{Shutdown, SocketAddr};
12use std::sync::Arc;
13use std::time::Duration;
14
15use rustls::{ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, ServerConfig, ServerConnection, StreamOwned, WantsVerifier};
16use rustls::client::WantsClientCert;
17use mco_http::runtime::Mutex;
18
19
/// A rustls session paired with the [`HttpStream`] it owns, in either TLS role.
///
/// Both variants wrap a `StreamOwned`, so the enum can forward `Read`/`Write`
/// calls without the caller having to know which side of the handshake it is.
pub enum Connection {
    /// Client-side session over an `HttpStream`.
    Client(StreamOwned<ClientConnection, HttpStream>),
    /// Server-side session over an `HttpStream`.
    Server(StreamOwned<ServerConnection, HttpStream>),
}
24
25impl Read for Connection {
26    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
27        match self {
28            Connection::Client(c) => {
29                c.read(buf)
30            }
31            Connection::Server(c) => {
32                c.read(buf)
33            }
34        }
35    }
36}
37
38impl Write for Connection {
39    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
40        match self {
41            Connection::Client(c) => {
42                c.write(buf)
43            }
44            Connection::Server(c) => {
45                c.write(buf)
46            }
47        }
48    }
49
50    fn flush(&mut self) -> io::Result<()> {
51        match self {
52            Connection::Client(c) => {
53                c.flush()
54            }
55            Connection::Server(c) => {
56                c.flush()
57            }
58        }
59    }
60}
61
62
/// A TLS stream: a boxed [`Connection`] plus a slot for a deferred TLS error.
pub struct TlsStream {
    /// The TLS session together with the socket it owns.
    sess: Box<Connection>,
    /// Pending TLS error; `promote_tls_error` takes it out and reports it as
    /// an `io::Error`.
    // NOTE(review): the visible code only ever initialises this to `None` —
    // confirm whether anything still stores an error here.
    tls_error: Option<rustls::Error>,
}
67
68impl TlsStream {
69    fn promote_tls_error(&mut self) -> io::Result<()> {
70        match self.tls_error.take() {
71            Some(err) => {
72                return Err(io::Error::new(io::ErrorKind::ConnectionAborted, err));
73            }
74            None => return Ok(()),
75        };
76    }
77}
78
79impl NetworkStream for TlsStream {
80    fn peer_addr(&mut self) -> io::Result<SocketAddr> {
81        match self.sess.as_mut() {
82            Connection::Client(c) => {
83                c.sock.peer_addr()
84            }
85            Connection::Server(c) => {
86                c.sock.peer_addr()
87            }
88        }
89    }
90
91    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
92        match self.sess.as_ref() {
93            Connection::Client(c) => {
94                c.sock.set_read_timeout(dur)
95            }
96            Connection::Server(c) => {
97                c.sock.set_read_timeout(dur)
98            }
99        }
100    }
101
102    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
103        match self.sess.as_ref() {
104            Connection::Client(c) => {
105                c.sock.set_write_timeout(dur)
106            }
107            Connection::Server(c) => {
108                c.sock.set_write_timeout(dur)
109            }
110        }
111    }
112}
113
114impl io::Read for TlsStream {
115    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
116        loop {
117            self.promote_tls_error()?;
118            match self.sess.as_mut().read(buf) {
119                Ok(0) => continue,
120                Ok(n) => return Ok(n),
121                Err(e) => return Err(e),
122            }
123        }
124    }
125}
126
127impl io::Write for TlsStream {
128    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
129        let len = self.sess.write(buf)?;
130        self.promote_tls_error()?;
131        Ok(len)
132    }
133
134    fn flush(&mut self) -> io::Result<()> {
135        let rc = self.sess.flush();
136        self.promote_tls_error()?;
137        rc
138    }
139}
140
/// A cloneable handle to a [`TlsStream`].
///
/// The stream sits behind `Arc<Mutex<..>>`, so every clone refers to the same
/// underlying TLS session and access goes through the mutex.
#[derive(Clone)]
pub struct WrappedStream(Arc<Mutex<TlsStream>>);
143
144impl WrappedStream {
145    fn lock(&self) -> mco_http::runtime::MutexGuard<TlsStream> {
146        self.0.lock().unwrap_or_else(|e| e.into_inner())
147    }
148}
149
150impl io::Read for WrappedStream {
151    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
152        self.lock().read(buf)
153    }
154}
155
156impl io::Write for WrappedStream {
157    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
158        self.lock().write(buf)
159    }
160
161    fn flush(&mut self) -> io::Result<()> {
162        self.lock().flush()
163    }
164}
165
166impl NetworkStream for WrappedStream {
167    fn peer_addr(&mut self) -> io::Result<SocketAddr> {
168        self.lock().peer_addr()
169    }
170
171    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
172        self.lock().set_read_timeout(dur)
173    }
174
175    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
176        self.lock().set_write_timeout(dur)
177    }
178
179    fn close(&mut self, how: Shutdown) -> io::Result<()> {
180        self.lock().close(how)
181    }
182}
183
/// TLS client factory: holds the rustls client configuration used to wrap
/// outgoing connections.
pub struct TlsClient {
    /// Shared rustls client configuration; cloned (by `Arc`) for each new
    /// connection.
    pub cfg: Arc<rustls::ClientConfig>,
}
187
188impl TlsClient {
189    pub fn new() -> TlsClient {
190        return Self::new_ca(None).expect("crate TlsClient fail")
191    }
192
193    pub fn new_ca(mut ca:Option<&mut dyn io::BufRead>)-> Result<TlsClient,io::Error> {
194        // Prepare the TLS client config
195        let tls = match ca {
196            Some(ref mut rd) => {
197                // Read trust roots
198                let certs = rustls_pemfile::certs(rd).collect::<Result<Vec<_>, _>>()?;
199                let mut roots = RootCertStore::empty();
200                roots.add_parsable_certificates(certs);
201                // TLS client config using the custom CA store for lookups
202                ClientConfig::builder()
203                    .with_root_certificates(roots)
204                    .with_no_client_auth()
205            }
206            // Default TLS client config with native roots
207            None => ClientConfig::builder()
208                .with_native_roots()?
209                .with_no_client_auth(),
210        };
211        Ok(Self{
212            cfg: Arc::new(tls),
213        })
214    }
215}
216
217
218/// Methods for configuring roots
219///
220/// This adds methods (gated by crate features) for easily configuring
221/// TLS server roots a rustls ClientConfig will trust.
222pub trait ConfigBuilderExt {
223    /// This configures the platform's trusted certs, as implemented by
224    /// rustls-native-certs
225    ///
226    /// This will return an error if no valid certs were found. In that case,
227    /// it's recommended to use `with_webpki_roots`.
228    //#[cfg(feature = "rustls-native-certs")]
229    fn with_native_roots(self) -> std::io::Result<ConfigBuilder<ClientConfig, WantsClientCert>>;
230
231    /// This configures the webpki roots, which are Mozilla's set of
232    /// trusted roots as packaged by webpki-roots.
233    //#[cfg(feature = "webpki-roots")]
234    fn with_webpki_roots(self) -> ConfigBuilder<ClientConfig, WantsClientCert>;
235}
236
237impl ConfigBuilderExt for ConfigBuilder<ClientConfig, WantsVerifier> {
238    //#[cfg(feature = "rustls-native-certs")]
239    //#[cfg_attr(not(feature = "logging"), allow(unused_variables))]
240    fn with_native_roots(self) -> std::io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
241        let mut roots = rustls::RootCertStore::empty();
242        let mut valid_count = 0;
243        let mut invalid_count = 0;
244
245        for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
246        {
247            match roots.add(cert) {
248                Ok(_) => valid_count += 1,
249                Err(err) => {
250                    log::debug!("certificate parsing failed: {:?}", err);
251                    invalid_count += 1
252                }
253            }
254        }
255        log::debug!(
256            "with_native_roots processed {} valid and {} invalid certs",
257            valid_count,
258            invalid_count
259        );
260        if roots.is_empty() {
261            log::debug!("no valid native root CA certificates found");
262            Err(std::io::Error::new(
263                std::io::ErrorKind::NotFound,
264                format!("no valid native root CA certificates found ({invalid_count} invalid)"),
265            ))?
266        }
267
268        Ok(self.with_root_certificates(roots))
269    }
270
271    //#[cfg(feature = "webpki-roots")]
272    fn with_webpki_roots(self) -> ConfigBuilder<ClientConfig, WantsClientCert> {
273        let mut roots = rustls::RootCertStore::empty();
274        roots.extend(
275            webpki_roots::TLS_SERVER_ROOTS
276                .iter()
277                .cloned(),
278        );
279        self.with_root_certificates(roots)
280    }
281}
282
283
284
285
286
/// Error carrying a human-readable DNS failure message.
#[derive(Debug)]
pub struct DNSError {
    /// The message; this is exactly what `Display` prints.
    pub inner: String,
}

impl Display for DNSError {
    /// Writes `inner`, honouring width, fill, alignment and precision flags
    /// just as formatting the `String` itself would.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.pad(self.inner.as_str())
    }
}

/// Marker impl: no source error, the description comes from `Display`.
impl std::error::Error for DNSError {}
299
300impl mco_http::net::SslClient for TlsClient {
301    type Stream = WrappedStream;
302
303    fn wrap_client(&self, stream: HttpStream, host: &str) -> mco_http::Result<WrappedStream> {
304        let c = ClientConnection::new(
305            self.cfg.clone(),
306            host.to_string().try_into().unwrap(),
307        ).map_err(|e| mco_http::Error::Ssl(Box::new(e)))?;
308        let tls = TlsStream {
309            sess: Box::new(Connection::Client(StreamOwned::new(c, stream))),
310            tls_error: None,
311        };
312
313        Ok(WrappedStream(Arc::new(Mutex::new(tls))))
314    }
315}
316
317
318pub use SSLServer as TlsServer;
319
/// TLS server factory: holds the rustls server configuration used to wrap
/// accepted connections. Cloning shares the same configuration.
#[derive(Clone)]
pub struct SSLServer {
    /// Shared rustls server configuration; cloned (by `Arc`) for each new
    /// connection.
    pub cfg: Arc<rustls::ServerConfig>,
}
324
325impl SSLServer {
326
327    /// new with with_single_cert
328    pub fn new(certs: Vec<Vec<u8>>, key: Vec<u8>) -> SSLServer {
329        let flattened_data: Vec<u8> = certs.into_iter().flatten().collect();
330        let mut reader = BufReader::new(Cursor::new(flattened_data));
331        let certs = rustls_pemfile::certs(&mut reader).map(|result| result.unwrap())
332            .collect();
333        let private_key=rustls_pemfile::private_key(&mut BufReader::new(Cursor::new(key.clone()))).expect("rustls_pemfile::private_key() read fail");
334        if private_key.is_none() {
335            panic!("load keys is empty")
336        }
337        let config = rustls::ServerConfig::builder()
338            .with_no_client_auth()
339            .with_single_cert(certs, private_key.unwrap()).unwrap();
340
341        SSLServer {
342            cfg: Arc::new(config),
343        }
344    }
345
346    /// new with_tls_config, Passes a rustls [`ServerConfig`] to configure the TLS connection
347    pub fn with_tls_config(self, config: ServerConfig) -> SSLServer {
348        SSLServer {
349            cfg: Arc::new(config),
350        }
351    }
352
353}
354
355impl mco_http::net::SslServer for SSLServer {
356    type Stream = WrappedStream;
357
358    fn wrap_server(&self, stream: HttpStream) -> mco_http::Result<WrappedStream> {
359        let conn = ServerConnection::new(self.cfg.clone()).unwrap();
360        let stream = rustls::StreamOwned::new(conn, stream);
361
362        let tls = TlsStream {
363            sess: Box::new(Connection::Server(stream)),
364            tls_error: None,
365        };
366
367        Ok(WrappedStream(Arc::new(Mutex::new(tls))))
368    }
369}