libdoh 0.9.7

DoH and Oblivious DoH library for the rust-doh app
Documentation
use std::fs::File;
use std::io::{self, BufReader, Cursor, Read};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use futures::{future::FutureExt, join, select};
use hyper::server::conn::Http;
use tokio::{
    net::TcpListener,
    sync::mpsc::{self, Receiver},
};
use tokio_rustls::{
    rustls::{Certificate, PrivateKey, ServerConfig},
    TlsAcceptor,
};

use crate::constants::CERTS_WATCH_DELAY_SECS;
use crate::errors::*;
use crate::{DoH, LocalExecutor};

pub fn create_tls_acceptor<P, P2>(certs_path: P, certs_keys_path: P2) -> io::Result<TlsAcceptor>
where
    P: AsRef<Path>,
    P2: AsRef<Path>,
{
    let certs: Vec<_> = {
        let certs_path_str = certs_path.as_ref().display().to_string();
        let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
            io::Error::new(
                e.kind(),
                format!(
                    "Unable to load the certificates [{certs_path_str}]: {e}"
                ),
            )
        })?);
        rustls_pemfile::certs(&mut reader).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "Unable to parse the certificates",
            )
        })?
    }
    .drain(..)
    .map(Certificate)
    .collect();
    let certs_keys: Vec<_> = {
        let certs_keys_path_str = certs_keys_path.as_ref().display().to_string();
        let encoded_keys = {
            let mut encoded_keys = vec![];
            File::open(certs_keys_path)
                .map_err(|e| {
                    io::Error::new(
                        e.kind(),
                        format!(
                            "Unable to load the certificate keys [{certs_keys_path_str}]: {e}"
                        ),
                    )
                })?
                .read_to_end(&mut encoded_keys)?;
            encoded_keys
        };
        let mut reader = Cursor::new(encoded_keys);
        let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "Unable to parse the certificates private keys (PKCS8)",
            )
        })?;
        reader.set_position(0);
        let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "Unable to parse the certificates private keys (RSA)",
            )
        })?;
        let mut keys = pkcs8_keys;
        keys.append(&mut rsa_keys);
        if keys.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "No private keys found - Make sure that they are in PKCS#8/PEM format",
            ));
        }
        keys.drain(..).map(PrivateKey).collect()
    };

    let mut server_config = certs_keys
        .into_iter()
        .find_map(|certs_key| {
            let server_config_builder = ServerConfig::builder()
                .with_safe_defaults()
                .with_no_client_auth();
            if let Ok(found_config) =
                server_config_builder.with_single_cert(certs.clone(), certs_key)
            {
                Some(found_config)
            } else {
                None
            }
        })
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "Unable to find a valid certificate and key",
            )
        })?;
    server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(TlsAcceptor::from(Arc::new(server_config)))
}

impl DoH {
    /// Runs the HTTPS accept loop.
    ///
    /// Each iteration waits for either a new TCP connection on `listener` or a
    /// replacement [`TlsAcceptor`] on `tls_acceptor_receiver`.
    ///
    /// * A connection is only processed once an acceptor has been received;
    ///   before that, and whenever `accept` fails, the connection is dropped.
    ///   After a successful TLS handshake the stream is passed to
    ///   `client_serve`; a failed handshake is ignored.
    /// * A received acceptor replaces the current one for later connections.
    /// * The loop ends, and `Ok(())` is returned, when the acceptor channel is
    ///   closed.
    async fn start_https_service(
        self,
        mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
        listener: TcpListener,
        server: Http<LocalExecutor>,
    ) -> Result<(), DoHError> {
        // Nothing can be served until the first acceptor arrives on the channel.
        let mut current_acceptor: Option<TlsAcceptor> = None;
        loop {
            select! {
                incoming = listener.accept().fuse() => {
                    // NOTE(review): the handshake is awaited inside the accept
                    // loop, so the next `select!` round only starts once it
                    // has finished - confirm this is intended.
                    if let (Some(acceptor), Ok((raw_stream, _client_addr))) =
                        (current_acceptor.as_ref(), incoming)
                    {
                        if let Ok(tls_stream) = acceptor.accept(raw_stream).await {
                            self.clone().client_serve(tls_stream, server.clone()).await
                        }
                    }
                }
                update = tls_acceptor_receiver.recv().fuse() => {
                    match update {
                        Some(fresh_acceptor) => current_acceptor = Some(fresh_acceptor),
                        // The sending side is gone: stop serving.
                        None => break,
                    }
                }
                complete => break
            }
        }
        Ok(())
    }

    /// Serves HTTPS on `listener`, periodically reloading the TLS material.
    ///
    /// Two futures are driven concurrently:
    /// * the accept loop ([`Self::start_https_service`]);
    /// * a reload loop that calls [`create_tls_acceptor`] on the configured
    ///   certificate and key paths, sends the resulting acceptor to the accept
    ///   loop, then sleeps for `CERTS_WATCH_DELAY_SECS` seconds. A loading
    ///   failure is printed to stderr and the loop carries on; the loop stops
    ///   once the acceptor can no longer be sent.
    ///
    /// The returned value is the result of the accept loop.
    ///
    /// # Panics
    ///
    /// Panics if `globals.tls_cert_path` or `globals.tls_cert_key_path` is `None`.
    pub async fn start_with_tls(
        self,
        listener: TcpListener,
        server: Http<LocalExecutor>,
    ) -> Result<(), DoHError> {
        // Copy the paths out before `self` is moved into the accept loop.
        let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone();
        let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone();

        let (acceptor_tx, acceptor_rx) = mpsc::channel(1);
        let https_service = self.start_https_service(acceptor_rx, listener, server);

        let reload_period = Duration::from_secs(CERTS_WATCH_DELAY_SECS.into());
        let cert_service = async {
            loop {
                match create_tls_acceptor(&certs_path, &certs_keys_path) {
                    Err(e) => eprintln!("TLS certificates error: {e}"),
                    Ok(acceptor) => {
                        // A send error means the receiving side has been dropped.
                        if acceptor_tx.send(acceptor).await.is_err() {
                            break;
                        }
                    }
                }
                tokio::time::sleep(reload_period).await;
            }
            Ok::<_, DoHError>(())
        };

        let (https_result, _cert_result) = join!(https_service, cert_service);
        https_result
    }
}