use std::io;
use figment::value::magic::{Either, RelativePathBuf};
use rustls_pki_types::{pem::PemObject, CertificateDer};
use serde::{Deserialize, Serialize};
use crate::tls::{Error, Result};
/// Configuration for mutual TLS: where the CA certificates come from and
/// whether the `mandatory` flag is set.
///
/// The type is (de)serializable with `serde`. A value without a `ca_certs`
/// key fails to deserialize, while `mandatory` may be omitted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MtlsConfig {
    /// Source of the CA certificates: either a path to a file or the raw
    /// certificate bytes held in memory.
    pub(crate) ca_certs: Either<RelativePathBuf, Vec<u8>>,
    /// The `mandatory` flag. Falls back to `false` when the key is absent
    /// and is parsed by figment's `bool_from_str_or_int` helper.
    // NOTE(review): presumably controls whether a client certificate is
    // required during the handshake; confirm against the TLS setup code.
    #[serde(default, deserialize_with = "figment::util::bool_from_str_or_int")]
    pub mandatory: bool,
}
impl MtlsConfig {
pub fn from_path<C: AsRef<std::path::Path>>(ca_certs: C) -> Self {
MtlsConfig {
ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()),
mandatory: Default::default(),
}
}
pub fn from_bytes(ca_certs: &[u8]) -> Self {
MtlsConfig {
ca_certs: Either::Right(ca_certs.to_vec()),
mandatory: Default::default(),
}
}
pub fn mandatory(mut self, mandatory: bool) -> Self {
self.mandatory = mandatory;
self
}
pub fn ca_certs(&self) -> either::Either<std::path::PathBuf, &[u8]> {
match &self.ca_certs {
Either::Left(path) => either::Either::Left(path.relative()),
Either::Right(bytes) => either::Either::Right(bytes),
}
}
#[inline(always)]
pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
crate::tls::config::to_reader(&self.ca_certs)
}
pub(crate) fn load_ca_certs(&self) -> Result<rustls::RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
for cert in CertificateDer::pem_reader_iter(&mut self.ca_certs_reader()?) {
roots
.add(cert.map_err(std::io::Error::other)?)
.map_err(Error::CertAuth)?;
}
Ok(roots)
}
}
#[cfg(test)]
mod tests {
    use std::path::Path;

    use figment::{
        providers::{Format, Toml},
        Figment,
    };

    use crate::mtls::MtlsConfig;

    /// Checks extraction of `MtlsConfig` from TOML: a missing `ca_certs` key
    /// is an error, `mandatory` defaults to `false`, and a relative path
    /// resolves against the jail directory holding the config file.
    #[test]
    fn test_mtls_config() {
        figment::Jail::expect_with(|jail| {
            // No `ca_certs` key at all: extraction must fail.
            jail.create_file(
                "MTLS.toml",
                r#"
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
"#,
            )?;

            // Every scenario overwrites the same file, so it is re-read on each call.
            let load = || Figment::from(Toml::file("MTLS.toml")).extract::<MtlsConfig>();
            load().expect_err("no ca");

            // Absolute path, `mandatory` omitted.
            jail.create_file(
                "MTLS.toml",
                r#"
ca_certs = "/ssl/ca.pem"
"#,
            )?;

            let config = load()?;
            assert_eq!(config.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
            assert!(!config.mandatory);

            // Absolute path, `mandatory` set explicitly.
            jail.create_file(
                "MTLS.toml",
                r#"
ca_certs = "/ssl/ca.pem"
mandatory = true
"#,
            )?;

            let config = load()?;
            assert_eq!(config.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
            assert!(config.mandatory);

            // Relative path: expected to be joined onto the jail directory.
            jail.create_file(
                "MTLS.toml",
                r#"
ca_certs = "relative/ca.pem"
"#,
            )?;

            let config = load()?;
            let expected = jail.directory().join("relative/ca.pem");
            assert_eq!(config.ca_certs().unwrap_left(), expected);

            Ok(())
        });
    }
}