tokn-router 0.2.0-rc.3

Routing, relay, and proxy orchestration across providers for tokn gateway
Documentation
use anyhow::{Context, Result};
use parking_lot::Mutex;
use rcgen::{BasicConstraints, CertificateParams, CertifiedIssuer, IsCa, Issuer, KeyPair};
use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use time::{Duration as TimeDuration, OffsetDateTime};

const CA_CERT_FILE: &str = "ca.crt";
const CA_KEY_FILE: &str = "ca.key";
const CA_BUNDLE_FILE: &str = "ca-bundle.crt";

pub fn load_or_generate_ca(dir: &Path, force_regenerate: bool) -> Result<ProxyCa> {
  std::fs::create_dir_all(dir).with_context(|| format!("create ca dir {}", dir.display()))?;
  let cert_path = dir.join(CA_CERT_FILE);
  let key_path = dir.join(CA_KEY_FILE);

  if force_regenerate || !cert_path.exists() || !key_path.exists() {
    return generate_ca(dir);
  }

  let cert_pem = std::fs::read_to_string(&cert_path).with_context(|| format!("read {}", cert_path.display()))?;
  let key_pem = std::fs::read_to_string(&key_path).with_context(|| format!("read {}", key_path.display()))?;
  let signing_key = KeyPair::from_pem(&key_pem).context("parse CA private key")?;
  let issuer = Issuer::new(ca_params(), signing_key);
  Ok(ProxyCa {
    dir: dir.to_path_buf(),
    cert_pem,
    issuer: Arc::new(issuer),
    cert_cache: Arc::new(Mutex::new(HashMap::new())),
  })
}

fn generate_ca(dir: &Path) -> Result<ProxyCa> {
  let params = ca_params();
  let key = KeyPair::generate().context("generate CA key")?;
  let issuer = CertifiedIssuer::self_signed(params, key).context("generate CA certificate")?;

  let cert_pem = issuer.pem();
  let key_pem = issuer.key().serialize_pem();
  write_ca_file(&dir.join(CA_CERT_FILE), cert_pem.as_bytes(), 0o644)?;
  write_ca_file(&dir.join(CA_KEY_FILE), key_pem.as_bytes(), 0o600)?;

  Ok(ProxyCa {
    dir: dir.to_path_buf(),
    cert_pem,
    issuer: Arc::new(Issuer::new(ca_params(), KeyPair::from_pem(&key_pem)?)),
    cert_cache: Arc::new(Mutex::new(HashMap::new())),
  })
}

fn ca_params() -> CertificateParams {
  let mut params = CertificateParams::default();
  params
    .distinguished_name
    .push(rcgen::DnType::CommonName, "tokn-router local proxy");
  params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
  params.not_before = OffsetDateTime::now_utc() - TimeDuration::days(1);
  params.not_after = OffsetDateTime::now_utc() + TimeDuration::days(3650);
  params.key_usages = vec![
    rcgen::KeyUsagePurpose::KeyCertSign,
    rcgen::KeyUsagePurpose::DigitalSignature,
    rcgen::KeyUsagePurpose::CrlSign,
  ];
  params
}

fn write_ca_file(path: &Path, bytes: &[u8], mode: u32) -> Result<()> {
  std::fs::write(path, bytes).with_context(|| format!("write {}", path.display()))?;
  #[cfg(unix)]
  {
    use std::os::unix::fs::PermissionsExt;
    std::fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
      .with_context(|| format!("chmod {}", path.display()))?;
  }
  #[cfg(not(unix))]
  let _ = mode;
  Ok(())
}

#[derive(Clone)]
pub struct ProxyCa {
  dir: PathBuf,
  cert_pem: String,
  issuer: Arc<Issuer<'static, KeyPair>>,
  cert_cache: Arc<Mutex<HashMap<String, Arc<CertifiedKey>>>>,
}

impl ProxyCa {
  pub fn cert_path(&self) -> PathBuf {
    self.dir.join(CA_CERT_FILE)
  }

  pub fn bundle_path(&self) -> PathBuf {
    self.dir.join(CA_BUNDLE_FILE)
  }

  pub fn key_path(&self) -> PathBuf {
    self.dir.join(CA_KEY_FILE)
  }

  pub fn fingerprint_sha256(&self) -> String {
    let digest = Sha256::digest(self.cert_pem.as_bytes());
    hexify(&digest)
  }

  pub fn ensure_bundle(&self) -> Result<PathBuf> {
    let bundle_path = self.bundle_path();
    let mut bundle = match detect_system_ca_bundle() {
      Some(path) => std::fs::read_to_string(&path).with_context(|| format!("read {}", path.display()))?,
      None => String::new(),
    };
    if !bundle.is_empty() && !bundle.ends_with('\n') {
      bundle.push('\n');
    }
    if !bundle.contains(&self.cert_pem) {
      bundle.push_str(&self.cert_pem);
      if !bundle.ends_with('\n') {
        bundle.push('\n');
      }
    }
    write_ca_file(&bundle_path, bundle.as_bytes(), 0o644)?;
    Ok(bundle_path)
  }

  pub(super) fn certified_key_for(&self, host: &str) -> Result<Arc<CertifiedKey>> {
    if let Some(existing) = self.cert_cache.lock().get(host).cloned() {
      return Ok(existing);
    }

    let mut params = CertificateParams::new(vec![host.to_string()]).context("build leaf certificate params")?;
    params.distinguished_name.push(rcgen::DnType::CommonName, host);
    params.not_before = OffsetDateTime::now_utc() - TimeDuration::days(1);
    params.not_after = OffsetDateTime::now_utc() + TimeDuration::days(7);
    params.is_ca = IsCa::NoCa;
    params.key_usages = vec![
      rcgen::KeyUsagePurpose::DigitalSignature,
      rcgen::KeyUsagePurpose::KeyEncipherment,
    ];
    params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];

    let leaf_key = KeyPair::generate().context("generate leaf key")?;
    let cert = params
      .signed_by(&leaf_key, self.issuer.as_ref())
      .context("sign leaf certificate")?;
    let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
    let certified = Arc::new(
      CertifiedKey::from_der(
        vec![cert.der().clone()],
        private_key,
        &rustls::crypto::ring::default_provider(),
      )
      .context("build rustls certified key")?,
    );
    self.cert_cache.lock().insert(host.to_string(), certified.clone());
    Ok(certified)
  }
}

fn detect_system_ca_bundle() -> Option<PathBuf> {
  let env_path = std::env::var_os("SSL_CERT_FILE").map(PathBuf::from);
  let mut candidates = env_path.into_iter().chain(
    [
      "/etc/ssl/certs/ca-certificates.crt",
      "/etc/pki/tls/certs/ca-bundle.crt",
      "/etc/ssl/ca-bundle.pem",
      "/etc/pki/tls/cacert.pem",
      "/etc/ssl/cert.pem",
    ]
    .into_iter()
    .map(PathBuf::from),
  );
  candidates.find(|path| path.is_file())
}

impl fmt::Debug for ProxyCa {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.debug_struct("ProxyCa")
      .field("dir", &self.dir)
      .field("cert_path", &self.cert_path())
      .field("key_path", &self.key_path())
      .field("key_pem", &"***")
      .finish()
  }
}

pub(super) fn hexify(bytes: &[u8]) -> String {
  let mut out = String::with_capacity(bytes.len() * 2);
  for b in bytes {
    use std::fmt::Write as _;
    let _ = write!(out, "{b:02x}");
  }
  out
}

#[derive(Debug)]
pub(super) struct DynamicResolver {
  pub(super) ca: Arc<ProxyCa>,
  pub(super) fallback_host: String,
}

impl ResolvesServerCert for DynamicResolver {
  fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
    let host = client_hello.server_name().unwrap_or(&self.fallback_host);
    self.ca.certified_key_for(host).ok()
  }
}