1extern crate mco_http;
2extern crate rustls;
3extern crate webpki_roots;
4
5use mco_http::net::{HttpStream, NetworkStream};
6use std::convert::{TryInto};
7
8use std::{io};
9use std::fmt::{Debug, Display, Formatter};
10use std::io::{BufReader, Cursor, Read, Write};
11use std::net::{Shutdown, SocketAddr};
12use std::sync::Arc;
13use std::time::Duration;
14
15use rustls::{ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, ServerConfig, ServerConnection, StreamOwned, WantsVerifier};
16use rustls::client::WantsClientCert;
17use mco_http::runtime::Mutex;
18
19
20pub enum Connection {
21 Client(StreamOwned<ClientConnection, HttpStream>),
22 Server(StreamOwned<ServerConnection, HttpStream>),
23}
24
25impl Read for Connection {
26 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
27 match self {
28 Connection::Client(c) => {
29 c.read(buf)
30 }
31 Connection::Server(c) => {
32 c.read(buf)
33 }
34 }
35 }
36}
37
38impl Write for Connection {
39 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
40 match self {
41 Connection::Client(c) => {
42 c.write(buf)
43 }
44 Connection::Server(c) => {
45 c.write(buf)
46 }
47 }
48 }
49
50 fn flush(&mut self) -> io::Result<()> {
51 match self {
52 Connection::Client(c) => {
53 c.flush()
54 }
55 Connection::Server(c) => {
56 c.flush()
57 }
58 }
59 }
60}
61
62
63pub struct TlsStream {
64 sess: Box<Connection>,
65 tls_error: Option<rustls::Error>,
66}
67
68impl TlsStream {
69 fn promote_tls_error(&mut self) -> io::Result<()> {
70 match self.tls_error.take() {
71 Some(err) => {
72 return Err(io::Error::new(io::ErrorKind::ConnectionAborted, err));
73 }
74 None => return Ok(()),
75 };
76 }
77}
78
79impl NetworkStream for TlsStream {
80 fn peer_addr(&mut self) -> io::Result<SocketAddr> {
81 match self.sess.as_mut() {
82 Connection::Client(c) => {
83 c.sock.peer_addr()
84 }
85 Connection::Server(c) => {
86 c.sock.peer_addr()
87 }
88 }
89 }
90
91 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
92 match self.sess.as_ref() {
93 Connection::Client(c) => {
94 c.sock.set_read_timeout(dur)
95 }
96 Connection::Server(c) => {
97 c.sock.set_read_timeout(dur)
98 }
99 }
100 }
101
102 fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
103 match self.sess.as_ref() {
104 Connection::Client(c) => {
105 c.sock.set_write_timeout(dur)
106 }
107 Connection::Server(c) => {
108 c.sock.set_write_timeout(dur)
109 }
110 }
111 }
112}
113
114impl io::Read for TlsStream {
115 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
116 loop {
117 self.promote_tls_error()?;
118 match self.sess.as_mut().read(buf) {
119 Ok(0) => continue,
120 Ok(n) => return Ok(n),
121 Err(e) => return Err(e),
122 }
123 }
124 }
125}
126
127impl io::Write for TlsStream {
128 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
129 let len = self.sess.write(buf)?;
130 self.promote_tls_error()?;
131 Ok(len)
132 }
133
134 fn flush(&mut self) -> io::Result<()> {
135 let rc = self.sess.flush();
136 self.promote_tls_error()?;
137 rc
138 }
139}
140
141#[derive(Clone)]
142pub struct WrappedStream(Arc<Mutex<TlsStream>>);
143
144impl WrappedStream {
145 fn lock(&self) -> mco_http::runtime::MutexGuard<TlsStream> {
146 self.0.lock().unwrap_or_else(|e| e.into_inner())
147 }
148}
149
150impl io::Read for WrappedStream {
151 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
152 self.lock().read(buf)
153 }
154}
155
156impl io::Write for WrappedStream {
157 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
158 self.lock().write(buf)
159 }
160
161 fn flush(&mut self) -> io::Result<()> {
162 self.lock().flush()
163 }
164}
165
166impl NetworkStream for WrappedStream {
167 fn peer_addr(&mut self) -> io::Result<SocketAddr> {
168 self.lock().peer_addr()
169 }
170
171 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
172 self.lock().set_read_timeout(dur)
173 }
174
175 fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
176 self.lock().set_write_timeout(dur)
177 }
178
179 fn close(&mut self, how: Shutdown) -> io::Result<()> {
180 self.lock().close(how)
181 }
182}
183
184pub struct TlsClient {
185 pub cfg: Arc<rustls::ClientConfig>,
186}
187
188impl TlsClient {
189 pub fn new() -> TlsClient {
190 return Self::new_ca(None).expect("crate TlsClient fail")
191 }
192
193 pub fn new_ca(mut ca:Option<&mut dyn io::BufRead>)-> Result<TlsClient,io::Error> {
194 let tls = match ca {
196 Some(ref mut rd) => {
197 let certs = rustls_pemfile::certs(rd).collect::<Result<Vec<_>, _>>()?;
199 let mut roots = RootCertStore::empty();
200 roots.add_parsable_certificates(certs);
201 ClientConfig::builder()
203 .with_root_certificates(roots)
204 .with_no_client_auth()
205 }
206 None => ClientConfig::builder()
208 .with_native_roots()?
209 .with_no_client_auth(),
210 };
211 Ok(Self{
212 cfg: Arc::new(tls),
213 })
214 }
215}
216
217
218pub trait ConfigBuilderExt {
223 fn with_native_roots(self) -> std::io::Result<ConfigBuilder<ClientConfig, WantsClientCert>>;
230
231 fn with_webpki_roots(self) -> ConfigBuilder<ClientConfig, WantsClientCert>;
235}
236
237impl ConfigBuilderExt for ConfigBuilder<ClientConfig, WantsVerifier> {
238 fn with_native_roots(self) -> std::io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
241 let mut roots = rustls::RootCertStore::empty();
242 let mut valid_count = 0;
243 let mut invalid_count = 0;
244
245 for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
246 {
247 match roots.add(cert) {
248 Ok(_) => valid_count += 1,
249 Err(err) => {
250 log::debug!("certificate parsing failed: {:?}", err);
251 invalid_count += 1
252 }
253 }
254 }
255 log::debug!(
256 "with_native_roots processed {} valid and {} invalid certs",
257 valid_count,
258 invalid_count
259 );
260 if roots.is_empty() {
261 log::debug!("no valid native root CA certificates found");
262 Err(std::io::Error::new(
263 std::io::ErrorKind::NotFound,
264 format!("no valid native root CA certificates found ({invalid_count} invalid)"),
265 ))?
266 }
267
268 Ok(self.with_root_certificates(roots))
269 }
270
271 fn with_webpki_roots(self) -> ConfigBuilder<ClientConfig, WantsClientCert> {
273 let mut roots = rustls::RootCertStore::empty();
274 roots.extend(
275 webpki_roots::TLS_SERVER_ROOTS
276 .iter()
277 .cloned(),
278 );
279 self.with_root_certificates(roots)
280 }
281}
282
283
284
285
286
/// Plain-text error type; its `Display` output is the `inner` message verbatim.
#[derive(Debug)]
pub struct DNSError {
    /// Human-readable error message.
    pub inner: String,
}

impl Display for DNSError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `pad` is exactly what `str`'s own `Display` uses, so width/precision flags apply.
        f.pad(&self.inner)
    }
}

impl std::error::Error for DNSError {}
299
300impl mco_http::net::SslClient for TlsClient {
301 type Stream = WrappedStream;
302
303 fn wrap_client(&self, stream: HttpStream, host: &str) -> mco_http::Result<WrappedStream> {
304 let c = ClientConnection::new(
305 self.cfg.clone(),
306 host.to_string().try_into().unwrap(),
307 ).map_err(|e| mco_http::Error::Ssl(Box::new(e)))?;
308 let tls = TlsStream {
309 sess: Box::new(Connection::Client(StreamOwned::new(c, stream))),
310 tls_error: None,
311 };
312
313 Ok(WrappedStream(Arc::new(Mutex::new(tls))))
314 }
315}
316
317
318pub use SSLServer as TlsServer;
319
320#[derive(Clone)]
321pub struct SSLServer {
322 pub cfg: Arc<rustls::ServerConfig>,
323}
324
325impl SSLServer {
326
327 pub fn new(certs: Vec<Vec<u8>>, key: Vec<u8>) -> SSLServer {
329 let flattened_data: Vec<u8> = certs.into_iter().flatten().collect();
330 let mut reader = BufReader::new(Cursor::new(flattened_data));
331 let certs = rustls_pemfile::certs(&mut reader).map(|result| result.unwrap())
332 .collect();
333 let private_key=rustls_pemfile::private_key(&mut BufReader::new(Cursor::new(key.clone()))).expect("rustls_pemfile::private_key() read fail");
334 if private_key.is_none() {
335 panic!("load keys is empty")
336 }
337 let config = rustls::ServerConfig::builder()
338 .with_no_client_auth()
339 .with_single_cert(certs, private_key.unwrap()).unwrap();
340
341 SSLServer {
342 cfg: Arc::new(config),
343 }
344 }
345
346 pub fn with_tls_config(self, config: ServerConfig) -> SSLServer {
348 SSLServer {
349 cfg: Arc::new(config),
350 }
351 }
352
353}
354
355impl mco_http::net::SslServer for SSLServer {
356 type Stream = WrappedStream;
357
358 fn wrap_server(&self, stream: HttpStream) -> mco_http::Result<WrappedStream> {
359 let conn = ServerConnection::new(self.cfg.clone()).unwrap();
360 let stream = rustls::StreamOwned::new(conn, stream);
361
362 let tls = TlsStream {
363 sess: Box::new(Connection::Server(stream)),
364 tls_error: None,
365 };
366
367 Ok(WrappedStream(Arc::new(Mutex::new(tls))))
368 }
369}