async_web_server/
tls.rs

use crate::tcp::TcpIncoming;
use crate::{HttpIncoming, TcpOrTlsIncoming, TcpStream};
use futures::prelude::*;
use futures::stream::{FusedStream, FuturesUnordered};
use futures::StreamExt;
use rustls_acme::futures_rustls::pki_types::{
    CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, PrivateSec1KeyDer,
};
use rustls_acme::futures_rustls::rustls::server::{Acceptor, ClientHello};
use rustls_acme::futures_rustls::rustls::ServerConfig;
use rustls_acme::futures_rustls::{Accept, LazyConfigAcceptor};
use rustls_pemfile::Item;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
15
/// A server-side TLS stream layered over the crate's [`TcpStream`].
pub type TlsStream = rustls_acme::futures_rustls::server::TlsStream<TcpStream>;
17
/// A stream of server-side TLS connections accepted from a [`TcpIncoming`] listener.
///
/// Each accepted TCP connection passes through two asynchronous phases before it is yielded.
/// First, a [`LazyConfigAcceptor`] waits for the client's [`ClientHello`]. Then `f` picks the
/// [`ServerConfig`] for that connection and the handshake is driven to completion.
pub struct TlsIncoming<F: FnMut(&ClientHello) -> Arc<ServerConfig>> {
    /// Source of raw TCP connections; replaced with `None` once it has been exhausted.
    tcp_incoming: Option<TcpIncoming>,
    /// Chooses the server configuration for a connection from its `ClientHello`.
    f: F,
    /// Newly accepted connections whose `ClientHello` has not been received yet.
    start_accepts: FuturesUnordered<LazyConfigAcceptor<TcpStream>>,
    /// Connections whose handshake is running with the configuration chosen by `f`.
    accepts: FuturesUnordered<Accept<TcpStream>>,
}
24
25impl<F: FnMut(&ClientHello) -> Arc<ServerConfig>> TlsIncoming<F> {
26    pub fn new(tcp_incoming: TcpIncoming, f: F) -> Self {
27        let start_accepts = FuturesUnordered::new();
28        let accepts = FuturesUnordered::new();
29        TlsIncoming {
30            tcp_incoming: Some(tcp_incoming),
31            f,
32            start_accepts,
33            accepts,
34        }
35    }
36    pub fn http(self) -> HttpIncoming<TlsStream, Self> {
37        HttpIncoming::new(self)
38    }
39}
40
41impl<F: FnMut(&ClientHello) -> Arc<ServerConfig> + 'static> TlsIncoming<F> {
42    pub fn or_tcp(self) -> TcpOrTlsIncoming {
43        let mut tcp_or_tls = TcpOrTlsIncoming::new();
44        tcp_or_tls.push(self);
45        tcp_or_tls
46    }
47}
48
49impl<F: FnMut(&ClientHello) -> Arc<ServerConfig>> Unpin for TlsIncoming<F> {}
50
51impl<F: FnMut(&ClientHello) -> Arc<ServerConfig>> Stream for TlsIncoming<F> {
52    type Item = TlsStream;
53
54    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
55        loop {
56            match self.accepts.poll_next_unpin(cx) {
57                Poll::Ready(Some(Ok(tls_stream))) => return Poll::Ready(Some(tls_stream)),
58                Poll::Ready(Some(Err(err))) => log::debug!("tls accept error: {:?}", err),
59                Poll::Ready(None) | Poll::Pending => match self.start_accepts.poll_next_unpin(cx) {
60                    Poll::Ready(Some(Ok(start_handshake))) => {
61                        let config = (self.f)(&start_handshake.client_hello());
62                        let accept_fut = start_handshake.into_stream(config);
63                        self.accepts.push(accept_fut);
64                    }
65                    Poll::Ready(Some(Err(err))) => log::debug!("tls accept error: {:?}", err),
66                    Poll::Ready(None) | Poll::Pending => match &mut self.tcp_incoming {
67                        None => match self.is_terminated() {
68                            true => return Poll::Ready(None),
69                            false => return Poll::Pending,
70                        },
71                        Some(tcp_incoming) => match tcp_incoming.poll_next_unpin(cx) {
72                            Poll::Ready(Some(tcp_stream)) => {
73                                let acceptor = Acceptor::default();
74                                let acceptor_fut = LazyConfigAcceptor::new(acceptor, tcp_stream);
75                                self.start_accepts.push(acceptor_fut);
76                            }
77                            Poll::Ready(None) => drop(self.tcp_incoming.take()),
78                            Poll::Pending => return Poll::Pending,
79                        },
80                    },
81                },
82            }
83        }
84    }
85}
86
87impl<F: FnMut(&ClientHello) -> Arc<ServerConfig>> FusedStream for TlsIncoming<F> {
88    fn is_terminated(&self) -> bool {
89        self.tcp_incoming.is_none()
90            && self.accepts.is_terminated()
91            && self.start_accepts.is_terminated()
92    }
93}
94
95pub fn parse_pem(
96    pem: impl AsRef<[u8]>,
97) -> io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
98    let mut buf = pem.as_ref();
99    let pem = rustls_pemfile::read_all(&mut buf)?;
100
101    let (mut cert_chain, mut private_key) = (Vec::new(), None);
102    for item in pem.into_iter() {
103        match item {
104            Item::X509Certificate(b) => cert_chain.push(CertificateDer::from(b)),
105            Item::RSAKey(v) | Item::PKCS8Key(v) | Item::ECKey(v) => {
106                if private_key.is_none() {
107                    private_key = Some(PrivatePkcs8KeyDer::from(v).into());
108                }
109            }
110            _ => {}
111        }
112    }
113
114    let private_key = match private_key {
115        None => {
116            return Err(io::Error::new(
117                io::ErrorKind::InvalidData,
118                "missing private key",
119            ))
120        }
121        Some(private_key) => private_key,
122    };
123    if cert_chain.len() == 0 {
124        return Err(io::Error::new(
125            io::ErrorKind::InvalidData,
126            "missing certificates",
127        ));
128    }
129    Ok((cert_chain, private_key))
130}