/// Reasons why a server certificate fails identity verification.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DnMatchError {
    /// The subject distinguished name on the certificate differs from the
    /// configured one.
    CertDnMismatch {
        /// The DN text the caller expected.
        expected_dn: String,
    },
    /// None of the DNS names on the certificate matched the expected server
    /// name.
    NameMismatch {
        /// The server name the caller expected.
        expected_name: String,
    },
}

impl core::fmt::Display for DnMatchError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            DnMatchError::NameMismatch { expected_name } => {
                write!(
                    f,
                    "the server name \"{expected_name}\" does not match the names in the \
                     server certificate"
                )
            }
            DnMatchError::CertDnMismatch { expected_dn } => {
                write!(
                    f,
                    "the distinguished name (DN) on the server certificate does not match \
                     the expected value \"{expected_dn}\""
                )
            }
        }
    }
}

impl std::error::Error for DnMatchError {}
/// Parses a comma-separated distinguished name into sorted `(attribute, value)` pairs.
///
/// Recognised syntax:
///
/// * components are separated by `,`; spaces following a separator are ignored;
/// * an attribute name is a non-empty run of ASCII uppercase letters followed
///   immediately by `=`. A component that does not start that way is skipped
///   entirely, up to the next comma. This covers lower- or mixed-case names,
///   digits, and a space before `=`;
/// * a value starting with `"` is quoted. It runs to the next lone `"`, a doubled
///   `""` stands for one literal quote, and commas inside are kept. An
///   unterminated quote consumes the rest of the input;
/// * any other value runs up to the next `,`. Trailing spaces are removed;
///   leading spaces are kept.
///
/// The pairs are sorted, so two DNs listing the same components in a different
/// order compare equal. Repeated attributes are all kept.
#[must_use]
pub fn parse_dn(dn: &str) -> Vec<(String, String)> {
    /// Reads the body of a quoted value (the opening quote is already consumed).
    /// Returns the unescaped text and whatever follows the closing quote.
    fn take_quoted(mut text: &str) -> (String, &str) {
        let mut value = String::new();
        while let Some(quote_at) = text.find('"') {
            value.push_str(&text[..quote_at]);
            let after = &text[quote_at + 1..];
            match after.strip_prefix('"') {
                // `""` is an escaped literal quote; keep scanning.
                Some(next) => {
                    value.push('"');
                    text = next;
                }
                None => return (value, after),
            }
        }
        // Unterminated quote: the value runs to the end of the input.
        value.push_str(text);
        (value, "")
    }

    let mut components = Vec::new();
    let mut rest = dn;

    loop {
        // Consume at most one separator comma, then any run of spaces.
        rest = rest.strip_prefix(',').unwrap_or(rest).trim_start_matches(' ');
        if rest.is_empty() {
            break;
        }

        let name_len = rest.bytes().take_while(u8::is_ascii_uppercase).count();
        let (name, after_name) = rest.split_at(name_len);
        let value_text = match after_name.strip_prefix('=') {
            Some(text) if !name.is_empty() => text,
            _ => {
                // Not a `NAME=` component: drop everything up to (not including)
                // the next comma so the separator logic above sees it.
                let skip = rest.find(',').unwrap_or(rest.len());
                rest = &rest[skip..];
                continue;
            }
        };

        let (value, remainder) = match value_text.strip_prefix('"') {
            Some(quoted) => take_quoted(quoted),
            None => {
                let end = value_text.find(',').unwrap_or(value_text.len());
                let (raw, tail) = value_text.split_at(end);
                (raw.trim_end_matches(' ').to_owned(), tail)
            }
        };
        components.push((name.to_owned(), value));
        rest = remainder;
    }

    // Equal tuples are indistinguishable, so an unstable sort yields the same result.
    components.sort_unstable();
    components
}
/// Verifies that the server certificate's subject DN equals `expected_dn`.
///
/// Both strings are normalised with [`parse_dn`], so component order and spaces
/// after commas do not matter. Every recognised component must appear on both
/// sides with the same value, so an extra attribute on either side is a mismatch.
/// Components that `parse_dn` skips are ignored on both sides.
///
/// # Errors
///
/// Returns [`DnMatchError::CertDnMismatch`] when the parsed component lists
/// differ. The error carries the caller's `expected_dn` text unchanged.
pub fn check_cert_dn(expected_dn: &str, server_subject_dn: &str) -> Result<(), DnMatchError> {
    let same_components = parse_dn(expected_dn) == parse_dn(server_subject_dn);
    if !same_components {
        return Err(DnMatchError::CertDnMismatch {
            expected_dn: expected_dn.to_owned(),
        });
    }
    Ok(())
}
/// Returns `true` if `name_to_check` (the host name the client expects) matches
/// `cert_name` (a DNS name taken from the server certificate).
///
/// Comparison is ASCII case-insensitive, as RFC 6125 §6.4.1 prescribes for DNS
/// names. Non-ASCII characters are compared exactly, so Unicode case folding
/// cannot make two different names collide.
///
/// An exact match always succeeds. Otherwise both names must have a non-empty
/// left-most label followed by at least one more label. Everything from the
/// first `.` onwards must match exactly, and the left-most label of `cert_name`
/// may contain a `*` wildcard:
///
/// * `*`        — matches any left-most label;
/// * `*suffix`  — the label must end with `suffix`;
/// * `prefix*`  — the label must start with `prefix`;
/// * `pre*post` — the label must start with `pre` and end with `post`, and must
///   be long enough that the two parts do not overlap.
///
/// A `*` outside the left-most label is compared literally. A wildcard never
/// matches across a `.`.
#[must_use]
pub fn name_matches(name_to_check: &str, cert_name: &str) -> bool {
    let cert_name = cert_name.to_ascii_lowercase();
    let name_to_check = name_to_check.to_ascii_lowercase();
    if name_to_check == cert_name {
        return true;
    }

    // Both names need a non-empty left-most label followed by at least one more label.
    let (Some(check_pos), Some(cert_pos)) = (name_to_check.find('.'), cert_name.find('.'))
    else {
        return false;
    };
    if check_pos == 0 || cert_pos == 0 {
        return false;
    }

    // Everything right of the left-most label (including the dot) must match exactly.
    let (check_label, check_rest) = name_to_check.split_at(check_pos);
    let (cert_label, cert_rest) = cert_name.split_at(cert_pos);
    if check_rest != cert_rest {
        return false;
    }

    if cert_label == "*" {
        true
    } else if let Some(suffix) = cert_label.strip_prefix('*') {
        check_label.ends_with(suffix)
    } else if let Some(prefix) = cert_label.strip_suffix('*') {
        check_label.starts_with(prefix)
    } else if let Some((pre, post)) = cert_label.split_once('*') {
        // Without the length guard, "a*a" would accept the one-character label "a":
        // the prefix and suffix checks would both succeed on the same byte.
        check_label.len() >= pre.len() + post.len()
            && check_label.starts_with(pre)
            && check_label.ends_with(post)
    } else {
        // The left-most label has no wildcard and the names are not equal.
        false
    }
}
/// Checks that `expected_name` matches one of the DNS names presented by the
/// server certificate, using [`name_matches`] for each candidate.
///
/// `san_dns_names` are the certificate's subjectAltName DNS entries and
/// `common_names` its subject CN values. The common names are consulted only
/// when the certificate presents no SAN DNS names at all (RFC 2818 §3.1,
/// RFC 6125 §6.4.4). When SANs are present they are authoritative, so a
/// stray CN cannot add an identity the SAN list does not vouch for.
///
/// # Errors
///
/// Returns [`DnMatchError::NameMismatch`] when no candidate name matches.
pub fn check_server_name(
    expected_name: &str,
    san_dns_names: &[String],
    common_names: &[String],
) -> Result<(), DnMatchError> {
    let candidates = if san_dns_names.is_empty() {
        common_names
    } else {
        san_dns_names
    };
    if candidates
        .iter()
        .any(|name| name_matches(expected_name, name))
    {
        Ok(())
    } else {
        Err(DnMatchError::NameMismatch {
            expected_name: expected_name.to_string(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned `(attribute, value)` pair.
    fn pair(name: &str, value: &str) -> (String, String) {
        (name.to_owned(), value.to_owned())
    }

    #[test]
    fn parse_dn_simple() {
        let expected = vec![
            pair("C", "US"),
            pair("CN", "db.example.com"),
            pair("O", "Example"),
        ];
        assert_eq!(parse_dn("CN=db.example.com,O=Example,C=US"), expected);
    }

    #[test]
    fn parse_dn_order_independent_equality() {
        assert_eq!(parse_dn("CN=x,O=y"), parse_dn("O=y,CN=x"));
    }

    #[test]
    fn parse_dn_comma_space_separator() {
        let expected = vec![pair("C", "Z"), pair("CN", "x"), pair("O", "y")];
        assert_eq!(parse_dn("CN=x, O=y, C=Z"), expected);
    }

    #[test]
    fn parse_dn_quoted_value() {
        let parsed = parse_dn(r#"CN="Acme, Inc.",C=US"#);
        for wanted in [pair("CN", "Acme, Inc."), pair("C", "US")] {
            assert!(parsed.contains(&wanted));
        }
    }

    #[test]
    fn check_cert_dn_accept_exact() {
        assert_eq!(check_cert_dn("CN=x,O=y", "O=y,CN=x"), Ok(()));
    }

    #[test]
    fn check_cert_dn_reject_diff() {
        let outcome = check_cert_dn("CN=x,O=y", "CN=z,O=y");
        assert!(matches!(outcome, Err(DnMatchError::CertDnMismatch { .. })));
    }

    #[test]
    fn check_cert_dn_reject_extra_attr() {
        let outcome = check_cert_dn("CN=x", "CN=x,O=y");
        assert!(matches!(outcome, Err(DnMatchError::CertDnMismatch { .. })));
    }

    #[test]
    fn name_matches_full_case_insensitive() {
        assert!(name_matches("DB.example.com", "db.example.COM"));
    }

    #[test]
    fn name_matches_leading_wildcard() {
        let pattern = "*.example.com";
        assert!(name_matches("host.example.com", pattern));
        assert!(!name_matches("host.sub.example.com", pattern));
    }

    #[test]
    fn name_matches_prefix_wildcard_label() {
        let pattern = "web*.example.com";
        assert!(name_matches("webserver.example.com", pattern));
        assert!(!name_matches("appserver.example.com", pattern));
    }

    #[test]
    fn name_matches_suffix_wildcard_label() {
        assert!(name_matches("serverweb.example.com", "*web.example.com"));
    }

    #[test]
    fn name_matches_rejects_single_label() {
        assert!(!name_matches("localhost", "*"));
    }

    #[test]
    fn check_server_name_san_first() {
        let sans = ["db.example.com".to_owned()];
        assert!(check_server_name("db.example.com", &sans, &[]).is_ok());
    }

    #[test]
    fn check_server_name_falls_back_to_cn() {
        let cns = ["db.example.com".to_owned()];
        assert!(check_server_name("db.example.com", &[], &cns).is_ok());
    }

    #[test]
    fn check_server_name_rejects_unknown() {
        let sans = ["db.example.com".to_owned()];
        let outcome = check_server_name("evil.example.com", &sans, &[]);
        assert!(matches!(outcome, Err(DnMatchError::NameMismatch { .. })));
    }
}