httproxide 0.2.0

HTTP router and reverse proxy written in Rust
Documentation
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

use anyhow::{anyhow, bail, Context};
use rustls::server::ResolvesServerCertUsingSni;
use rustls::{Certificate, PrivateKey, ServerConfig};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;

/// A certificate chain paired with its private key.
///
/// Produced by `read_cert` (from PEM files on disk) and `get_selfsigned`
/// (freshly generated), and consumed when building the SNI resolver.
pub type Cert = (Vec<Certificate>, PrivateKey);

pub fn read_cert(
    private_key_path: impl AsRef<Path>,
    certs_path: impl AsRef<Path>,
) -> anyhow::Result<Cert> {
    let mut private_key_file = std::io::BufReader::new(std::fs::File::open(private_key_path)?);

    let private_key =
        std::iter::from_fn(|| rustls_pemfile::read_one(&mut private_key_file).transpose())
            .filter_map(|x| match x {
                Err(e) => Some(Err(e)),
                Ok(rustls_pemfile::Item::RSAKey(key)) => Some(Ok(key)),
                Ok(rustls_pemfile::Item::PKCS8Key(key)) => Some(Ok(key)),
                Ok(rustls_pemfile::Item::ECKey(key)) => Some(Ok(key)),
                Ok(_) => None,
            })
            .next()
            .ok_or(anyhow!("no private key"))??;

    let private_key = PrivateKey(private_key);

    let mut certs_file = std::io::BufReader::new(std::fs::File::open(certs_path)?);
    let certs = rustls_pemfile::certs(&mut certs_file)?
        .into_iter()
        .map(|x| Certificate(x))
        .collect::<Vec<_>>();
    Ok((certs, private_key))
}

/// Generates a fresh self-signed certificate whose subject alternative names
/// are `names`.
///
/// `new_rustls_config` uses this as a fallback when a host's configured
/// certificate cannot be read. The returned chain holds exactly one
/// certificate, and the key is rcgen's PKCS#8 DER serialization.
///
/// # Errors
///
/// Fails if rcgen cannot generate or serialize the certificate.
pub fn get_selfsigned(names: Vec<String>) -> anyhow::Result<Cert> {
    let cert = rcgen::generate_simple_self_signed(names)
        .context("could not generate self-signed certificate")?;

    // rustls consumes DER directly, so skip the PEM encode/parse round trip
    // (and the failure modes it introduced).
    let chain = vec![Certificate(
        cert.serialize_der()
            .context("could not serialize self-signed certificate")?,
    )];
    let key = PrivateKey(cert.serialize_private_key_der());

    Ok((chain, key))
}

/// Locations of a host's PEM key and certificate chain on disk.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TlsHostDiskConfig {
    /// Path to the PEM file holding the private key (passed to `read_cert`
    /// as `private_key_path`).
    key_pem: String,
    /// Path to the PEM file holding the certificate chain (passed to
    /// `read_cert` as `certs_path`).
    fullchain_pem: String,
}

/// Where a single host's certificate comes from.
///
/// Serde uses the internally-tagged form: the `source` field names the
/// variant (e.g. `source = "Disk"`), and the variant's fields sit alongside it.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(tag = "source")]
pub enum TlsHostConfig {
    /// Load the key and chain from PEM files.
    Disk(TlsHostDiskConfig),
}

/// TLS settings for a listener: one certificate source per SNI hostname.
///
/// Also the key of the built-config cache, hence the `Eq`/`Hash` derives.
#[derive(Deserialize, Clone, PartialEq, Eq, Hash, Debug, Serialize)]
// NOTE(review): per serde's container-attribute docs, `tag` on a struct adds the
// struct's name under the given key when serializing — here an empty-string key.
// Confirm this is intentional.
#[serde(tag = "")]
pub struct TlsConfig {
    // Written as a map in the config via `tuple_vec_map`, but kept as a Vec of
    // (hostname, source) pairs in memory — presumably to preserve the configured
    // order; verify against the crate's docs.
    #[serde(with = "tuple_vec_map")]
    hosts: Vec<(String, TlsHostConfig)>,
}

/// ALPN protocol names, each as raw bytes, in the order they are offered.
///
/// Assigned verbatim to `ServerConfig::alpn_protocols` by `get_rustls_config`.
pub type AlpnProtocols = Vec<Vec<u8>>;

lazy_static::lazy_static! {
    /// Built TLS server configs, keyed by the `TlsConfig` they came from.
    ///
    /// Filled by `get_rustls_config` and emptied by `clear_cache`. Entries are
    /// stored without ALPN protocols; callers receive a clone with ALPN applied.
    static ref CACHE: Mutex<HashMap<TlsConfig, ServerConfig>> = Mutex::new(HashMap::new());
}

/// Discards every cached TLS server config.
///
/// The next `get_rustls_config` call for a given `TlsConfig` then rebuilds it,
/// re-reading certificates from disk.
pub async fn clear_cache() {
    let mut guard = CACHE.lock().await;
    *guard = HashMap::default();
}

/// Builds a rustls [`ServerConfig`] that picks a certificate by SNI hostname.
///
/// For each `(hostname, host_cfg)` in `cfg.hosts`, the configured certificate
/// is loaded on a blocking thread. If loading fails, the error is logged and a
/// self-signed certificate for `hostname` is used instead, so one bad entry does
/// not prevent the config from being built. The result has no client auth and
/// no ALPN protocols set.
///
/// # Errors
///
/// Fails if the blocking load task panics or is cancelled, if self-signed
/// generation fails, if rustls does not support the key type, or if rustls
/// rejects the hostname/certificate pair.
async fn new_rustls_config(cfg: TlsConfig) -> anyhow::Result<ServerConfig> {
    let mut store = ResolvesServerCertUsingSni::new();
    for (hostname, host_cfg) in cfg.hosts {
        // Cloned because `host_cfg` is still needed for the error context below.
        let loaded = match host_cfg.clone() {
            TlsHostConfig::Disk(disk_cfg) => {
                // File I/O is blocking; keep it off the async worker threads.
                tokio::task::spawn_blocking(move || {
                    read_cert(disk_cfg.key_pem, disk_cfg.fullchain_pem)
                })
                .await?
            }
        };

        // Deliberate best-effort: fall back to a self-signed cert, but record
        // why, so operators can fix the configuration.
        let (chain, key) = match loaded {
            Ok(cert) => cert,
            Err(e) => {
                log::warn!(
                    "couldn't read cert for {}: {:#}; using self-signed cert",
                    hostname,
                    e
                );
                get_selfsigned(vec![hostname.clone()])?
            }
        };

        let signing_key = rustls::sign::any_supported_type(&key)
            .with_context(|| format!("unsupported private key type for domain {}", hostname))?;
        store
            .add(
                hostname.as_str(),
                rustls::sign::CertifiedKey::new(chain, signing_key),
            )
            .with_context(|| {
                format!("could not add cert {:?} for domain {}", host_cfg, hostname)
            })?;
    }

    Ok(ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(store)))
}

/// Returns a rustls [`ServerConfig`] for `cfg`, with `alpn` as its ALPN list.
///
/// The ALPN-independent part is built once per distinct `cfg` by
/// `new_rustls_config` and memoised in `CACHE`. Each call returns a clone with
/// `alpn_protocols` overwritten. The cache lock stays held while building, so
/// concurrent callers wait rather than building the same config twice.
///
/// # Errors
///
/// Propagates any error from `new_rustls_config`; failed builds are not cached.
pub async fn get_rustls_config(
    alpn: AlpnProtocols,
    cfg: TlsConfig,
) -> anyhow::Result<ServerConfig> {
    let mut guard = CACHE.lock().await;
    let mut server_config = match guard.get(&cfg) {
        Some(cached) => cached.clone(),
        None => {
            let built = new_rustls_config(cfg.clone()).await?;
            guard.insert(cfg, built.clone());
            built
        }
    };

    server_config.alpn_protocols = alpn;
    Ok(server_config)
}