//! axol 0.1.5
//!
//! Axol Web Framework documentation.
use std::{
    net::SocketAddr,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use anyhow::Result;
use futures::Stream;
use hyper::server::{
    accept::Accept,
    conn::{AddrIncoming, AddrStream},
};
use log::{error, warn};
use rustls::{server::Acceptor, ServerConfig};
use tokio::sync::{mpsc, watch};
use tokio_rustls::{server::TlsStream, LazyConfigAcceptor};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

/// TLS-terminating listener built on hyper's [`AddrIncoming`].
///
/// Create it with [`TlsIncoming::new`] (configuration supplied through a
/// `watch` channel) or [`TlsIncoming::new_static`] (fixed configuration). Then
/// turn it into a hyper acceptor with [`TlsIncoming::start`].
pub struct TlsIncoming {
    /// Underlying TCP listener, wrapped so it can be consumed as a `Stream`.
    incoming: StreamWrapper,
    /// Current TLS configuration. It is read once per accepted TCP connection;
    /// `None` means no certificates are loaded, and such connections are dropped.
    tls_config: watch::Receiver<Option<Arc<ServerConfig>>>,
}

pin_project_lite::pin_project! {
    pub struct AcceptWrapper<S: Stream<Item = Result<TlsStream<AddrStream>, std::io::Error>>> {
        #[pin]
        stream: S,
    }
}

impl<S: Stream<Item = Result<TlsStream<AddrStream>, std::io::Error>>> Accept for AcceptWrapper<S> {
    type Conn = TlsStream<AddrStream>;

    type Error = std::io::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        self.project().stream.poll_next(cx)
    }
}

/// Presents hyper's [`AddrIncoming`] listener as a [`Stream`] of accepted TCP
/// connections, so it can be driven with `StreamExt::next`.
struct StreamWrapper(AddrIncoming);

impl Stream for StreamWrapper {
    type Item = Result<AddrStream, std::io::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Unpinning relies on `AddrIncoming: Unpin`; the listener is simply re-pinned
        // for the `Accept::poll_accept` call.
        let listener = &mut self.get_mut().0;
        Pin::new(listener).poll_accept(cx)
    }
}

impl TlsIncoming {
    pub fn new(
        listen: SocketAddr,
        nodelay: bool,
        keepalive: Option<Duration>,
        tls_config: watch::Receiver<Option<Arc<ServerConfig>>>,
    ) -> Result<Self, hyper::Error> {
        let mut incoming = AddrIncoming::bind(&listen)?;
        incoming.set_nodelay(nodelay);
        incoming.set_keepalive(keepalive);

        Ok(Self {
            incoming: StreamWrapper(incoming),
            tls_config,
        })
    }

    pub fn new_static(
        listen: SocketAddr,
        nodelay: bool,
        keepalive: Option<Duration>,
        tls_config: ServerConfig,
    ) -> Result<Self, hyper::Error> {
        Self::new(
            listen,
            nodelay,
            keepalive,
            watch::channel(Some(Arc::new(tls_config))).1,
        )
    }

    pub fn start(
        mut self,
    ) -> AcceptWrapper<impl Stream<Item = Result<TlsStream<AddrStream>, std::io::Error>>> {
        let (sender, receiver) = mpsc::channel::<Result<TlsStream<AddrStream>, std::io::Error>>(10);
        tokio::spawn(async move {
            loop {
                let client = match self.incoming.next().await {
                    Some(Ok(x)) => x,
                    Some(Err(e)) => {
                        error!("error during accepting TCP client: {e}");
                        continue;
                    }
                    None => break,
                };
                let Some(server_config) = self.tls_config.borrow().clone() else {
                    warn!("inbound TLS connection dropped (no certificates loaded, but were configured)");
                    continue
                };

                let lazy = LazyConfigAcceptor::new(Acceptor::default(), client);
                let sender = sender.clone();
                tokio::spawn(async move {
                    let accepted = match lazy.await {
                        Ok(x) => x,
                        Err(e) => {
                            error!("error during TLS init: {e}");
                            return;
                        }
                    };
                    let tls_stream = accepted.into_stream(server_config).await;
                    if sender.send(tls_stream).await.is_err() {
                        error!("TLS acceptor hung");
                    }
                });
            }
        });
        AcceptWrapper {
            stream: ReceiverStream::new(receiver),
        }
    }
}