use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context};
use rustls::server::ResolvesServerCertUsingSni;
use rustls::{Certificate, PrivateKey, ServerConfig};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
/// A certificate chain together with its private key, as returned by [`read_cert`] and
/// [`get_selfsigned`].
pub type Cert = (Vec<Certificate>, PrivateKey);
pub fn read_cert(
private_key_path: impl AsRef<Path>,
certs_path: impl AsRef<Path>,
) -> anyhow::Result<Cert> {
let mut private_key_file = std::io::BufReader::new(std::fs::File::open(private_key_path)?);
let private_key =
std::iter::from_fn(|| rustls_pemfile::read_one(&mut private_key_file).transpose())
.filter_map(|x| match x {
Err(e) => Some(Err(e)),
Ok(rustls_pemfile::Item::RSAKey(key)) => Some(Ok(key)),
Ok(rustls_pemfile::Item::PKCS8Key(key)) => Some(Ok(key)),
Ok(rustls_pemfile::Item::ECKey(key)) => Some(Ok(key)),
Ok(_) => None,
})
.next()
.ok_or(anyhow!("no private key"))??;
let private_key = PrivateKey(private_key);
let mut certs_file = std::io::BufReader::new(std::fs::File::open(certs_path)?);
let certs = rustls_pemfile::certs(&mut certs_file)?
.into_iter()
.map(|x| Certificate(x))
.collect::<Vec<_>>();
Ok((certs, private_key))
}
pub fn get_selfsigned(names: Vec<String>) -> anyhow::Result<Cert> {
let cert = rcgen::generate_simple_self_signed(names)?;
let chain = rustls_pemfile::certs(&mut cert.serialize_pem()?.as_bytes())?
.into_iter()
.map(|v| Certificate(v))
.collect::<Vec<_>>();
if chain.is_empty() {
bail!("sanity");
}
let key = PrivateKey(
rustls_pemfile::pkcs8_private_keys(&mut cert.serialize_private_key_pem().as_bytes())?
.get(0)
.ok_or(anyhow!("key"))?
.to_vec(),
);
Ok((chain, key))
}
/// Location on disk of one host's certificate material.
///
/// The field names are the (de)serialized config keys, so renaming them changes the format.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TlsHostDiskConfig {
    // Despite the `_pem` suffix, this is a *path* to the PEM private-key file; it is opened
    // as a file by `read_cert`, not parsed as PEM text.
    key_pem: String,
    // Path to the PEM certificate-chain file, likewise opened by `read_cert`.
    fullchain_pem: String,
}
/// Where a host's certificate and private key come from.
///
/// Internally tagged: the variant name is stored in a `source` field next to the variant's
/// own fields (in JSON, e.g. `{"source": "Disk", "key_pem": "...", "fullchain_pem": "..."}`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(tag = "source")]
pub enum TlsHostConfig {
    /// Load the certificate chain and key from PEM files on disk.
    Disk(TlsHostDiskConfig),
}
/// TLS settings: one certificate source per SNI hostname.
///
/// Also used as the key of the module's `ServerConfig` cache, which is why it must be
/// `Eq + Hash`.
///
/// NOTE(review): `#[serde(tag = "")]` on a struct makes serde emit an extra field with an
/// empty name holding the struct's name when serializing; confirm this is intended.
#[derive(Deserialize, Clone, PartialEq, Eq, Hash, Debug, Serialize)]
#[serde(tag = "")]
pub struct TlsConfig {
    // (De)serialized through `tuple_vec_map`, i.e. as a hostname -> host-config map, while
    // being stored as an ordered Vec of pairs.
    #[serde(with = "tuple_vec_map")]
    hosts: Vec<(String, TlsHostConfig)>,
}
/// ALPN protocol identifiers as raw bytes (e.g. `b"h2"`), assigned to
/// `ServerConfig::alpn_protocols` by [`get_rustls_config`].
pub type AlpnProtocols = Vec<Vec<u8>>;
lazy_static::lazy_static! {
    /// Process-wide cache of built [`ServerConfig`]s, keyed by the [`TlsConfig`] they were
    /// built from. ALPN is applied per call by [`get_rustls_config`] and is not part of the key.
    ///
    /// Within this file, entries are only removed by [`clear_cache`]; certificate files are
    /// not re-read until then. A config built with a self-signed fallback (because the
    /// on-disk certificate could not be read) is cached too.
    ///
    /// An async (`tokio::sync`) mutex is used because [`get_rustls_config`] holds the lock
    /// across the `.await` that builds a missing entry.
    static ref CACHE: Mutex<HashMap<TlsConfig, ServerConfig>> = {
        Default::default()
    };
}
pub async fn clear_cache() {
*(*CACHE).lock().await = HashMap::new();
}
/// Builds a fresh rustls [`ServerConfig`] that picks a certificate by SNI hostname.
///
/// For every `(hostname, host_cfg)` pair in `cfg`, the certificate chain and key are loaded
/// as described by `host_cfg`. If loading fails, the failure (with its cause) is logged and
/// a self-signed certificate for `hostname` is generated instead, so one broken certificate
/// does not take down the whole config.
///
/// The returned config has no ALPN protocols set; [`get_rustls_config`] applies those.
///
/// # Errors
///
/// Returns an error if the blocking certificate-loading task panics or is cancelled, if the
/// self-signed fallback cannot be generated, if the private key type is not supported by
/// rustls, or if rustls rejects the certificate for the hostname.
async fn new_rustls_config(cfg: TlsConfig) -> anyhow::Result<ServerConfig> {
    let mut store = ResolvesServerCertUsingSni::new();
    for (hostname, host_cfg) in cfg.hosts {
        let loaded = match &host_cfg {
            TlsHostConfig::Disk(disk_cfg) => {
                let key_path = disk_cfg.key_pem.clone();
                let chain_path = disk_cfg.fullchain_pem.clone();
                // File reads are blocking; keep them off the async executor.
                tokio::task::spawn_blocking(move || read_cert(key_path, chain_path))
                    .await
                    .with_context(|| format!("loading cert for {} did not complete", hostname))?
            }
        };
        let (chain, key) = match loaded {
            Ok(cert) => cert,
            Err(err) => {
                // Keep the reason: previously it was discarded and only a generic
                // message without the hostname was logged.
                log::warn!(
                    "couldn't read cert for {}: {:#}; using self-signed cert",
                    hostname,
                    err
                );
                get_selfsigned(vec![hostname.clone()])?
            }
        };
        let signing_key = rustls::sign::any_supported_type(&key)
            .with_context(|| format!("unsupported private key type for domain {}", hostname))?;
        store
            .add(
                hostname.as_str(),
                rustls::sign::CertifiedKey::new(chain, signing_key),
            )
            .with_context(|| {
                format!("could not add cert {:?} for domain {}", host_cfg, hostname)
            })?;
    }
    Ok(ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(store)))
}
/// Returns a rustls [`ServerConfig`] for `cfg` with `alpn` as its ALPN protocol list.
///
/// The expensive part (loading certificates and building the SNI resolver) is memoised in
/// the module cache keyed by `cfg`. Each call gets its own clone with `alpn` applied, so
/// callers can use different ALPN lists over the same certificates.
///
/// The cache lock is held for the whole call, including while a missing entry is built, so
/// concurrent calls are serialised rather than building the same config twice.
///
/// # Errors
///
/// Propagates any error from building a new config; a failed build is not cached.
pub async fn get_rustls_config(
    alpn: AlpnProtocols,
    cfg: TlsConfig,
) -> anyhow::Result<ServerConfig> {
    let mut guard = CACHE.lock().await;
    let cached = guard.get(&cfg).cloned();
    let mut server_config = match cached {
        Some(hit) => hit,
        None => {
            let built = new_rustls_config(cfg.clone()).await?;
            guard.insert(cfg, built.clone());
            built
        }
    };
    server_config.alpn_protocols = alpn;
    Ok(server_config)
}