async_web_server/tcp.rs

1use crate::h1::HttpIncoming;
2use crate::tls::TlsIncoming;
3use crate::{AcmeIncoming, TcpOrTlsIncoming};
4use async_io::{Async, ReadableOwned};
5use futures::prelude::*;
6use futures::stream::FusedStream;
7use futures::FutureExt;
8use rustls_acme::caches::DirCache;
9use rustls_acme::futures_rustls::pki_types::{CertificateDer, PrivateKeyDer};
10use rustls_acme::futures_rustls::rustls::server::ClientHello;
11use rustls_acme::futures_rustls::rustls::ServerConfig;
12use rustls_acme::AcmeConfig;
13use std::fmt::Debug;
14use std::io;
15use std::net::SocketAddr;
16use std::path::Path;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21pub type TcpStream = async_net::TcpStream;
22
23pub struct TcpIncoming {
24    listener: Arc<Async<std::net::TcpListener>>,
25    readable: Pin<Box<ReadableOwned<std::net::TcpListener>>>,
26}
27
28impl TcpIncoming {
29    pub fn bind(addr: impl Into<SocketAddr>) -> io::Result<Self> {
30        let listener = Arc::new(Async::<std::net::TcpListener>::bind(addr)?);
31        let readable = Box::pin(listener.clone().readable_owned());
32        Ok(Self { listener, readable })
33    }
34    pub fn tls_with_config<F: FnMut(&ClientHello) -> Arc<ServerConfig>>(
35        self,
36        f: F,
37    ) -> TlsIncoming<F> {
38        TlsIncoming::new(self, f)
39    }
40    pub fn tls(
41        self,
42        cert_chain: Vec<CertificateDer<'static>>,
43        key_der: PrivateKeyDer<'static>,
44    ) -> Result<
45        TlsIncoming<impl FnMut(&ClientHello) -> Arc<ServerConfig>>,
46        rustls_acme::futures_rustls::rustls::Error,
47    > {
48        let config = ServerConfig::builder()
49            .with_no_client_auth()
50            .with_single_cert(cert_chain, key_der)?;
51        let config = Arc::new(config);
52        Ok(TlsIncoming::new(self, move |_| config.clone()))
53    }
54    pub fn tls_acme<EC: Debug, EA: Debug>(
55        self,
56        config: AcmeConfig<EC, EA>,
57    ) -> AcmeIncoming<EC, EA> {
58        AcmeIncoming::new(self, config)
59    }
60    // TODO: add rate limit warning for production
61    pub fn tls_lets_encrypt(
62        self,
63        domains: impl IntoIterator<Item = impl AsRef<str>>,
64        contact: impl IntoIterator<Item = impl AsRef<str>>,
65        cache_dir: impl AsRef<Path> + Send + Sync + 'static,
66        production: bool,
67    ) -> AcmeIncoming<io::Error, io::Error> {
68        let config = AcmeConfig::new(domains)
69            .contact(contact)
70            .cache(DirCache::new(cache_dir))
71            .directory_lets_encrypt(production);
72        self.tls_acme(config)
73    }
74    pub fn or_tls(self) -> TcpOrTlsIncoming {
75        let mut tcp_or_tls = TcpOrTlsIncoming::new();
76        tcp_or_tls.push(self);
77        tcp_or_tls
78    }
79    pub fn http(self) -> HttpIncoming<TcpStream, Self> {
80        HttpIncoming::new(self)
81    }
82    pub fn local_addr(&self) -> io::Result<SocketAddr> {
83        self.listener.get_ref().local_addr()
84    }
85}
86
87impl Stream for TcpIncoming {
88    type Item = TcpStream;
89
90    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91        let listener = self.listener.clone();
92        loop {
93            match Box::pin(listener.accept()).poll_unpin(cx) {
94                Poll::Ready(result) => match result {
95                    Ok((stream, _)) => return Poll::Ready(Some(stream.into())),
96                    Err(err) => log::debug!("tcp accept error: {:?}", err),
97                },
98                Poll::Pending => match self.readable.as_mut().poll(cx) {
99                    Poll::Pending => return Poll::Pending,
100                    Poll::Ready(_) => self.readable = Box::pin(listener.clone().readable_owned()),
101                },
102            }
103        }
104    }
105}
106
107impl FusedStream for TcpIncoming {
108    fn is_terminated(&self) -> bool {
109        false
110    }
111}