use std::path::PathBuf;
use super::{Role, User};
/// Configuration for authenticating clients by their TLS client certificate.
#[derive(Debug, Clone)]
pub struct CertAuthConfig {
    /// Whether certificate authentication is accepted at all;
    /// `CertAuthenticator::validate` rejects every certificate when `false`.
    pub enabled: bool,
    /// Path to the client CA trust bundle. Not read by `CertAuthenticator`
    /// itself — presumably loaded by the TLS layer; TODO confirm.
    pub trust_bundle: PathBuf,
    /// Which certificate field supplies the username.
    pub identity_mode: CertIdentityMode,
    /// OID of a certificate extension whose value is read as UTF-8 text naming
    /// the role. When `None`, cert-derived roles fall back to `default_role`.
    pub role_oid: Option<String>,
    /// Role assigned when the role is derived from the certificate and
    /// `role_oid` is `None`.
    pub default_role: Role,
    /// When `true`, the username is looked up in the user store first and the
    /// stored role wins; on a miss (or when `false`) the role is derived from
    /// the certificate.
    pub map_to_existing_users: bool,
}
impl Default for CertAuthConfig {
fn default() -> Self {
Self {
enabled: false,
trust_bundle: PathBuf::from("./certs/client-ca.pem"),
identity_mode: CertIdentityMode::CommonName,
role_oid: None,
default_role: Role::Read,
map_to_existing_users: true,
}
}
}
/// Selects which certificate field `CertAuthenticator::validate` uses as the
/// username.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertIdentityMode {
    /// Use the subject Common Name (`ParsedClientCert::common_name`).
    CommonName,
    /// Use the first entry of `ParsedClientCert::san_rfc822` (rfc822Name
    /// Subject Alternative Names, i.e. e-mail addresses).
    SanRfc822Name,
}
/// Identity produced by a successful `CertAuthenticator::validate` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertIdentity {
    /// Username taken from the field selected by the configured `CertIdentityMode`.
    pub username: String,
    /// Role from the user store, or derived from the certificate when no stored
    /// user was used.
    pub role: Role,
    /// Copied verbatim from `ParsedClientCert::subject_dn`.
    pub subject_dn: String,
    /// Copied verbatim from `ParsedClientCert::serial_hex`.
    pub serial_hex: String,
    /// Certificate expiry in Unix seconds, copied from the certificate.
    pub not_after_unix_secs: i64,
}
/// Reasons certificate authentication can fail.
#[derive(Debug, Clone)]
pub enum CertAuthError {
    /// The field selected by this identity mode was not usable as a username.
    MissingIdentity(CertIdentityMode),
    /// A `role_oid` is configured but the certificate lacks that extension;
    /// the payload is the OID.
    MissingRoleExtension(String),
    /// The username is not in the auth store. Not constructed by
    /// `CertAuthenticator` itself.
    UnknownUser(String),
    /// The certificate's `not_after` is earlier than the validation time.
    Expired { not_after_unix_secs: i64 },
    /// Trust-bundle failure with a message. Not constructed by
    /// `CertAuthenticator` itself.
    TrustBundle(String),
    /// Malformed certificate data (e.g. a bad role extension). `validate` also
    /// returns this variant when cert auth is disabled.
    Parse(String),
}
impl std::fmt::Display for CertAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CertAuthError::MissingIdentity(mode) => {
write!(f, "client cert missing {:?} identity field", mode)
}
CertAuthError::MissingRoleExtension(oid) => {
write!(f, "client cert missing role extension {oid}")
}
CertAuthError::UnknownUser(u) => write!(f, "cert user '{u}' not in auth store"),
CertAuthError::Expired {
not_after_unix_secs,
} => write!(f, "client cert expired at unix {not_after_unix_secs}"),
CertAuthError::TrustBundle(m) => write!(f, "trust bundle error: {m}"),
CertAuthError::Parse(m) => write!(f, "cert parse error: {m}"),
}
}
}
impl std::error::Error for CertAuthError {}
/// Client-certificate fields consumed by `CertAuthenticator`.
/// Presumably filled in by an X.509 parser elsewhere — TODO confirm.
#[derive(Debug, Clone)]
pub struct ParsedClientCert {
    /// Subject distinguished name; copied verbatim into `CertIdentity`.
    pub subject_dn: String,
    /// Subject Common Name, if present; the username in `CommonName` mode.
    pub common_name: Option<String>,
    /// rfc822Name SAN entries; the first one is the username in
    /// `SanRfc822Name` mode.
    pub san_rfc822: Vec<String>,
    /// Serial number rendered as a hex string (exact formatting is up to the
    /// producer); copied verbatim into `CertIdentity`.
    pub serial_hex: String,
    /// Expiry time in Unix seconds; compared against the validation time.
    pub not_after_unix_secs: i64,
    /// Extension values keyed by OID string. The role extension's bytes are
    /// read directly as UTF-8 text.
    /// NOTE(review): confirm whether the producer strips the DER wrapper
    /// (tag/length) before storing values here.
    pub extensions: std::collections::HashMap<String, Vec<u8>>,
}
/// Validates parsed client certificates according to a `CertAuthConfig`.
pub struct CertAuthenticator {
    /// Settings applied on every `validate` call.
    config: CertAuthConfig,
}
impl CertAuthenticator {
    /// Creates an authenticator that applies `config` to every validation.
    pub fn new(config: CertAuthConfig) -> Self {
        Self { config }
    }

    /// Validates a parsed client certificate and resolves the caller's identity.
    ///
    /// Checks run in this order:
    /// 1. Cert auth must be enabled in the config.
    /// 2. The certificate must not be expired. It is rejected only when
    ///    `not_after_unix_secs < now_unix_secs`, so the `not_after` second
    ///    itself is still accepted.
    /// 3. The username is taken from the field selected by `identity_mode`.
    ///    An absent, empty, or whitespace-only value is rejected.
    /// 4. The role comes from the user store when `map_to_existing_users` is
    ///    set and `lookup_user` finds the username. Otherwise it is derived
    ///    from the certificate (the `role_oid` extension or `default_role`).
    ///
    /// `lookup_user` is called at most once, so any `FnOnce` is accepted.
    /// Every `Fn` closure still works.
    ///
    /// # Errors
    /// - `CertAuthError::Parse` when cert auth is disabled (there is no
    ///   dedicated variant) or when the role extension is malformed.
    /// - `CertAuthError::Expired` when the certificate is past `not_after`.
    /// - `CertAuthError::MissingIdentity` when the selected identity field is
    ///   absent or blank.
    /// - `CertAuthError::MissingRoleExtension` when `role_oid` is set but the
    ///   certificate lacks that extension.
    pub fn validate<F>(
        &self,
        cert: &ParsedClientCert,
        now_unix_secs: i64,
        lookup_user: F,
    ) -> Result<CertIdentity, CertAuthError>
    where
        F: FnOnce(&str) -> Option<User>,
    {
        if !self.config.enabled {
            // Reuses `Parse` because the error enum has no "disabled" variant.
            return Err(CertAuthError::Parse(
                "cert auth disabled on this listener".into(),
            ));
        }
        if cert.not_after_unix_secs < now_unix_secs {
            return Err(CertAuthError::Expired {
                not_after_unix_secs: cert.not_after_unix_secs,
            });
        }
        let username = self.extract_username(cert)?;

        // A stored user's role wins; otherwise fall back to the certificate.
        let stored_role = if self.config.map_to_existing_users {
            lookup_user(&username).map(|user| user.role)
        } else {
            None
        };
        let role = match stored_role {
            Some(role) => role,
            None => self.derive_role_from_cert(cert)?,
        };

        Ok(CertIdentity {
            username,
            role,
            subject_dn: cert.subject_dn.clone(),
            serial_hex: cert.serial_hex.clone(),
            not_after_unix_secs: cert.not_after_unix_secs,
        })
    }

    /// Returns the username from the field selected by `identity_mode`.
    ///
    /// Blank (empty or whitespace-only) values are treated as missing, so a
    /// certificate can never authenticate as an empty username. Non-blank
    /// values are returned unmodified.
    fn extract_username(&self, cert: &ParsedClientCert) -> Result<String, CertAuthError> {
        let mode = self.config.identity_mode;
        let candidate = match mode {
            CertIdentityMode::CommonName => cert.common_name.as_deref(),
            CertIdentityMode::SanRfc822Name => cert.san_rfc822.first().map(String::as_str),
        };
        match candidate {
            Some(name) if !name.trim().is_empty() => Ok(name.to_owned()),
            _ => Err(CertAuthError::MissingIdentity(mode)),
        }
    }

    /// Derives the role from the certificate alone.
    ///
    /// Without a configured `role_oid`, this is `default_role`. Otherwise the
    /// extension stored under that OID must exist and be valid UTF-8. After
    /// surrounding whitespace is trimmed, it must name a role that
    /// `Role::from_str` accepts.
    fn derive_role_from_cert(&self, cert: &ParsedClientCert) -> Result<Role, CertAuthError> {
        let Some(oid) = &self.config.role_oid else {
            return Ok(self.config.default_role);
        };
        let bytes = cert
            .extensions
            .get(oid)
            .ok_or_else(|| CertAuthError::MissingRoleExtension(oid.clone()))?;
        let name = std::str::from_utf8(bytes)
            .map_err(|e| CertAuthError::Parse(format!("role extension not valid UTF-8: {e}")))?;
        Role::from_str(name.trim())
            .ok_or_else(|| CertAuthError::Parse(format!("unknown role '{name}'")))
    }

    /// Returns the configuration this authenticator was built with.
    pub fn config(&self) -> &CertAuthConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A valid, unexpired certificate for "alice" with no extensions.
    fn base_cert() -> ParsedClientCert {
        ParsedClientCert {
            subject_dn: "CN=alice,O=reddb,C=BR".to_string(),
            common_name: Some("alice".to_string()),
            san_rfc822: vec!["alice@example.com".to_string()],
            serial_hex: "ABCDEF".to_string(),
            not_after_unix_secs: 2_000_000_000,
            extensions: HashMap::new(),
        }
    }

    /// Enabled config with the given identity mode and all other defaults.
    fn cfg(mode: CertIdentityMode) -> CertAuthConfig {
        CertAuthConfig {
            enabled: true,
            identity_mode: mode,
            ..CertAuthConfig::default()
        }
    }

    #[test]
    fn common_name_maps_to_username() {
        let auth = CertAuthenticator::new(cfg(CertIdentityMode::CommonName));
        let id = auth
            .validate(&base_cert(), 1_000_000_000, |_| None)
            .unwrap();
        assert_eq!(id.username, "alice");
        assert_eq!(id.role, Role::Read);
    }

    #[test]
    fn san_rfc822_maps_to_email() {
        let auth = CertAuthenticator::new(cfg(CertIdentityMode::SanRfc822Name));
        let id = auth
            .validate(&base_cert(), 1_000_000_000, |_| None)
            .unwrap();
        assert_eq!(id.username, "alice@example.com");
    }

    #[test]
    fn missing_cn_field_rejected() {
        let mut cert = base_cert();
        cert.common_name = None;
        let auth = CertAuthenticator::new(cfg(CertIdentityMode::CommonName));
        let err = auth.validate(&cert, 1_000_000_000, |_| None).unwrap_err();
        assert!(matches!(err, CertAuthError::MissingIdentity(_)));
    }

    #[test]
    fn missing_san_rejected() {
        let mut cert = base_cert();
        cert.san_rfc822.clear();
        let auth = CertAuthenticator::new(cfg(CertIdentityMode::SanRfc822Name));
        let err = auth.validate(&cert, 1_000_000_000, |_| None).unwrap_err();
        assert!(matches!(
            err,
            CertAuthError::MissingIdentity(CertIdentityMode::SanRfc822Name)
        ));
    }

    #[test]
    fn disabled_listener_rejects_certificates() {
        let auth = CertAuthenticator::new(CertAuthConfig::default());
        let err = auth
            .validate(&base_cert(), 1_000_000_000, |_| None)
            .unwrap_err();
        assert!(matches!(err, CertAuthError::Parse(_)));
    }

    #[test]
    fn expired_cert_rejected() {
        let mut cert = base_cert();
        cert.not_after_unix_secs = 500;
        let auth = CertAuthenticator::new(cfg(CertIdentityMode::CommonName));
        let err = auth.validate(&cert, 1_000, |_| None).unwrap_err();
        assert!(matches!(err, CertAuthError::Expired { .. }));
    }

    #[test]
    fn cert_is_valid_through_its_not_after_second() {
        let cert = base_cert();
        let auth = CertAuthenticator::new(cfg(CertIdentityMode::CommonName));
        assert!(auth
            .validate(&cert, cert.not_after_unix_secs, |_| None)
            .is_ok());
        let err = auth
            .validate(&cert, cert.not_after_unix_secs + 1, |_| None)
            .unwrap_err();
        assert!(matches!(
            err,
            CertAuthError::Expired { not_after_unix_secs } if not_after_unix_secs == cert.not_after_unix_secs
        ));
    }

    #[test]
    fn role_extension_overrides_default_role() {
        let mut cert = base_cert();
        cert.extensions
            .insert("1.3.6.1.4.1.99999.1".to_string(), b"admin".to_vec());
        let mut config = cfg(CertIdentityMode::CommonName);
        config.role_oid = Some("1.3.6.1.4.1.99999.1".to_string());
        config.map_to_existing_users = false;
        let auth = CertAuthenticator::new(config);
        let id = auth.validate(&cert, 1_000_000_000, |_| None).unwrap();
        assert_eq!(id.role, Role::Admin);
    }

    #[test]
    fn role_extension_whitespace_is_trimmed() {
        let mut cert = base_cert();
        cert.extensions
            .insert("1.3.6.1.4.1.99999.1".to_string(), b" admin\n".to_vec());
        let mut config = cfg(CertIdentityMode::CommonName);
        config.role_oid = Some("1.3.6.1.4.1.99999.1".to_string());
        config.map_to_existing_users = false;
        let auth = CertAuthenticator::new(config);
        let id = auth.validate(&cert, 1_000_000_000, |_| None).unwrap();
        assert_eq!(id.role, Role::Admin);
    }

    #[test]
    fn role_extension_ignored_without_role_oid() {
        let mut cert = base_cert();
        cert.extensions
            .insert("1.3.6.1.4.1.99999.1".to_string(), b"admin".to_vec());
        let mut config = cfg(CertIdentityMode::CommonName);
        config.map_to_existing_users = false;
        let auth = CertAuthenticator::new(config);
        let id = auth.validate(&cert, 1_000_000_000, |_| None).unwrap();
        assert_eq!(id.role, Role::Read);
    }

    #[test]
    fn non_utf8_role_extension_is_parse_error() {
        let mut cert = base_cert();
        cert.extensions
            .insert("1.2.3".to_string(), vec![0xff, 0xfe, 0xfd]);
        let mut config = cfg(CertIdentityMode::CommonName);
        config.role_oid = Some("1.2.3".to_string());
        config.map_to_existing_users = false;
        let auth = CertAuthenticator::new(config);
        let err = auth.validate(&cert, 1_000_000_000, |_| None).unwrap_err();
        assert!(matches!(err, CertAuthError::Parse(_)));
    }

    #[test]
    fn unknown_role_name_is_parse_error() {
        let mut cert = base_cert();
        cert.extensions
            .insert("1.2.3".to_string(), b"not-a-real-role".to_vec());
        let mut config = cfg(CertIdentityMode::CommonName);
        config.role_oid = Some("1.2.3".to_string());
        config.map_to_existing_users = false;
        let auth = CertAuthenticator::new(config);
        let err = auth.validate(&cert, 1_000_000_000, |_| None).unwrap_err();
        assert!(matches!(err, CertAuthError::Parse(_)));
    }

    #[test]
    fn missing_role_extension_errors_when_configured() {
        let mut config = cfg(CertIdentityMode::CommonName);
        config.role_oid = Some("1.2.3".to_string());
        config.map_to_existing_users = false;
        let auth = CertAuthenticator::new(config);
        let err = auth
            .validate(&base_cert(), 1_000_000_000, |_| None)
            .unwrap_err();
        assert!(matches!(err, CertAuthError::MissingRoleExtension(_)));
    }
}