use std::collections::BTreeSet;
/// A case-insensitive, exact-match set of peer names that may connect.
///
/// Entries are normalised once at construction: surrounding whitespace is
/// trimmed, ASCII letters are lower-cased, and entries that are empty after
/// trimming are dropped. They are stored in a `BTreeSet`, so duplicates
/// collapse and iteration is sorted. Lookups apply the same normalisation to
/// the probe. No wildcard or pattern semantics exist: `"*.example"` only
/// matches the literal string `"*.example"`.
#[derive(Clone, Debug, Default)]
pub struct PeerAllowlist {
    /// Normalised (trimmed, ASCII-lower-cased, non-empty) entries.
    allowed: BTreeSet<String>,
}
impl PeerAllowlist {
    /// Builds an allowlist from any iterable of string-like names.
    ///
    /// Each name is trimmed and ASCII-lower-cased. Names that are blank after
    /// trimming are skipped.
    pub fn new<I, S>(names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut allowed = BTreeSet::new();
        for raw in names {
            let trimmed = raw.as_ref().trim();
            if !trimmed.is_empty() {
                allowed.insert(trimmed.to_ascii_lowercase());
            }
        }
        PeerAllowlist { allowed }
    }

    /// Number of distinct normalised entries.
    pub fn len(&self) -> usize {
        self.allowed.len()
    }

    /// `true` when no entries survived normalisation.
    pub fn is_empty(&self) -> bool {
        self.allowed.is_empty()
    }

    /// Whether `name`, after trimming and ASCII lower-casing, is an entry.
    pub fn contains(&self, name: &str) -> bool {
        let key = name.trim().to_ascii_lowercase();
        self.allowed.contains(key.as_str())
    }

    /// Whether any of `names` is in the allowlist (see [`Self::contains`]).
    ///
    /// Returns `false` for an empty iterator and stops at the first match.
    pub fn contains_any<I, S>(&self, names: I) -> bool
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        for candidate in names {
            if self.contains(candidate.as_ref()) {
                return true;
            }
        }
        false
    }

    /// Iterates over the normalised entries in sorted order.
    pub fn iter(&self) -> impl Iterator<Item = &str> {
        self.allowed.iter().map(|entry| entry.as_str())
    }
}
#[cfg(feature = "tls-rustls")]
/// Decodes the DER length field at the start of `data`.
///
/// Returns `(length, header_len)`, where `header_len` is the number of bytes
/// the length field itself occupies. The short form (first byte `< 0x80`) is
/// one byte. The long form `0x80 | n` is followed by `n` big-endian length
/// bytes, and only `n` in `1..=4` is accepted.
///
/// Returns `None` for empty input, for the indefinite form (`0x80`), for
/// `n > 4`, or when fewer than `n` length bytes follow. Non-minimal long-form
/// encodings are not rejected.
fn der_decode_len(data: &[u8]) -> Option<(usize, usize)> {
    let (&lead, tail) = data.split_first()?;
    if lead & 0x80 == 0 {
        // Short form: the byte is the length.
        return Some((usize::from(lead), 1));
    }
    let count = usize::from(lead & 0x7F);
    if !(1..=4).contains(&count) {
        return None;
    }
    let length_bytes = tail.get(..count)?;
    let value = length_bytes
        .iter()
        .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte));
    Some((value, 1 + count))
}
#[cfg(feature = "tls-rustls")]
/// Splits one DER tag-length-value element off the front of `data`.
///
/// Returns `(tag, value, rest)`: the single-byte tag, the value bytes, and
/// everything after the element. Only low-tag-number (single-byte) tags are
/// supported.
///
/// Returns `None` if `data` is empty, if the length field is malformed (see
/// `der_decode_len`), or if the declared length runs past the end of `data`.
/// Never panics, whatever the input.
fn der_tlv(data: &[u8]) -> Option<(u8, &[u8], &[u8])> {
    let (&tag, after_tag) = data.split_first()?;
    let (len, consumed) = der_decode_len(after_tag)?;
    // `der_decode_len` guarantees `consumed <= after_tag.len()`; `get` keeps
    // this panic-free regardless.
    let body = after_tag.get(consumed..)?;
    // `len` comes from the input and may be as large as 2^32 - 1. Compare it
    // against the bytes actually remaining rather than computing
    // `start + len`, which can overflow `usize` on 32-bit targets and then
    // panic on slicing.
    if len > body.len() {
        return None;
    }
    let (value, rest) = body.split_at(len);
    Some((tag, value, rest))
}
#[cfg(feature = "tls-rustls")]
/// Returns the ASCII-lower-cased subject commonName and subjectAltName
/// dNSName values found in a DER-encoded certificate.
///
/// Returns an empty `Vec` if the certificate cannot be parsed.
pub(crate) fn extract_cert_names(cert_der: &[u8]) -> Vec<String> {
    // A parse failure is not an error here. An unreadable certificate just
    // contributes no names, so an allowlist check over the result fails.
    let parsed = try_extract_cert_names(cert_der);
    parsed.unwrap_or_default()
}
#[cfg(feature = "tls-rustls")]
/// Public wrapper around the crate-private `extract_cert_names`.
///
/// Returns exactly what `extract_cert_names` returns for `cert_der`: the
/// lower-cased subject CN / SAN dNSName values, or an empty `Vec` when the
/// DER cannot be parsed.
///
/// NOTE(review): the name suggests this exists so that tests outside the
/// crate can exercise name extraction. Confirm this, and consider marking it
/// `#[doc(hidden)]`.
pub fn extract_cert_names_for_test(cert_der: &[u8]) -> Vec<String> {
    extract_cert_names(cert_der)
}
#[cfg(feature = "tls-rustls")]
/// Best-effort extraction of identity names from a DER-encoded X.509
/// certificate.
///
/// Collects the following values, in certificate order and lower-cased with
/// `to_ascii_lowercase`:
///
/// * every subject `commonName` value (attribute OID 2.5.4.3), then
/// * every `dNSName` entry (context tag `[2]`, byte `0x82`) of the
///   subjectAltName extension (OID 2.5.29.17).
///
/// This is a minimal DER walker, not a validating X.509 parser. Attribute
/// values are accepted whatever their ASN.1 string tag and decoded with
/// `str::from_utf8`. Values that fail to decode, or decode to an empty
/// string, are skipped.
///
/// Returns `None` if the DER cannot be split into elements up to and
/// including the subject. Otherwise returns `Some(names)`, which may be empty
/// (for example when an outer tag is not the expected SEQUENCE). Once the
/// subject has been read, a malformed later field only ends the walk early;
/// the names already collected are still returned.
fn try_extract_cert_names(cert_der: &[u8]) -> Option<Vec<String>> {
    let mut names: Vec<String> = Vec::new();
    // Certificate ::= SEQUENCE { tbsCertificate, signatureAlgorithm, signatureValue }
    let (0x30, cert_body, _) = der_tlv(cert_der)? else {
        return Some(names);
    };
    let (0x30, tbs, _) = der_tlv(cert_body)? else {
        return Some(names);
    };
    // TBSCertificate starts with an optional `[0] EXPLICIT` version, then
    // serialNumber (INTEGER), then signature, issuer and validity (each a
    // SEQUENCE), then subject.
    let mut p = tbs;
    if let Some((0xA0, _, rest)) = der_tlv(p) {
        p = rest;
    }
    let (0x02, _, rest) = der_tlv(p)? else {
        return Some(names);
    };
    p = rest;
    // signature, issuer, validity
    for _ in 0..3 {
        let (0x30, _, rest) = der_tlv(p)? else {
            return Some(names);
        };
        p = rest;
    }
    let (0x30, subject, after_subject) = der_tlv(p)? else {
        return Some(names);
    };
    collect_subject_common_names(subject, &mut names);
    // `None` only means the SAN walk stopped early; the subject names (and any
    // SAN names found before the stop) are kept.
    let _ = collect_san_dns_names(after_subject, &mut names);
    Some(names)
}

#[cfg(feature = "tls-rustls")]
/// Appends the ASCII-lower-cased form of `raw` to `names` if `raw` is
/// non-empty valid UTF-8.
fn push_cert_name(raw: &[u8], names: &mut Vec<String>) {
    if let Ok(s) = std::str::from_utf8(raw)
        && !s.is_empty()
    {
        names.push(s.to_ascii_lowercase());
    }
}

#[cfg(feature = "tls-rustls")]
/// Appends every `commonName` value in a subject `Name` (the contents of its
/// outer SEQUENCE) to `names`.
///
/// `Name` is a SEQUENCE OF RelativeDistinguishedName (SET, `0x31`). Each RDN
/// is a SET OF AttributeTypeAndValue (SEQUENCE `{ type OID, value }`).
/// Iteration at each level stops at the first element that has an unexpected
/// tag or cannot be parsed.
fn collect_subject_common_names(subject: &[u8], names: &mut Vec<String>) {
    // 2.5.4.3 (id-at-commonName)
    const OID_COMMON_NAME: [u8; 3] = [0x55, 0x04, 0x03];
    let mut rdns = subject;
    while let Some((0x31, rdn, rest)) = der_tlv(rdns) {
        rdns = rest;
        let mut atvs = rdn;
        while let Some((0x30, atv, rest)) = der_tlv(atvs) {
            atvs = rest;
            if let Some((0x06, oid, value_tlv)) = der_tlv(atv)
                && oid == OID_COMMON_NAME
                && let Some((_value_tag, value, _)) = der_tlv(value_tlv)
            {
                push_cert_name(value, names);
            }
        }
    }
}

#[cfg(feature = "tls-rustls")]
/// Walks the TBSCertificate fields after the subject and appends every
/// `dNSName` from the subjectAltName extension to `names`.
///
/// `after_subject` must start at subjectPublicKeyInfo. The walk skips the
/// optional issuerUniqueID (`0x81`) and subjectUniqueID (`0x82`), then
/// examines only the first `[3]` extensions wrapper (`0xA3`).
///
/// Returns `None` when the walk stops because an element could not be
/// parsed. Names pushed before that point stay in `names`.
fn collect_san_dns_names(after_subject: &[u8], names: &mut Vec<String>) -> Option<()> {
    // 2.5.29.17 (id-ce-subjectAltName)
    const OID_SUBJECT_ALT_NAME: [u8; 3] = [0x55, 0x1D, 0x11];
    // GeneralName dNSName: [2] IMPLICIT IA5String
    const TAG_DNS_NAME: u8 = 0x82;

    // subjectPublicKeyInfo
    let (0x30, _, rest) = der_tlv(after_subject)? else {
        return Some(());
    };
    let mut p = rest;
    // issuerUniqueID [1] IMPLICIT, optional
    if let Some((0x81, _, rest)) = der_tlv(p) {
        p = rest;
    }
    // subjectUniqueID [2] IMPLICIT, optional
    if let Some((0x82, _, rest)) = der_tlv(p) {
        p = rest;
    }
    // Find the first `[3] EXPLICIT` extensions wrapper.
    let exts_wrapper = loop {
        let (tag, val, rest) = der_tlv(p)?;
        p = rest;
        if tag == 0xA3 {
            break val;
        }
    };
    let (0x30, exts_body, _) = der_tlv(exts_wrapper)? else {
        return Some(());
    };
    // Extension ::= SEQUENCE { extnID OID, critical BOOLEAN DEFAULT FALSE, extnValue OCTET STRING }
    let mut ext_p = exts_body;
    while let Some((0x30, ext, rest)) = der_tlv(ext_p) {
        ext_p = rest;
        let (0x06, oid, ext_rest) = der_tlv(ext)? else {
            continue;
        };
        if oid != OID_SUBJECT_ALT_NAME {
            continue;
        }
        // Skip the optional `critical` BOOLEAN (tag 0x01) before extnValue.
        let octet_tlv = if ext_rest.first() == Some(&0x01) {
            der_tlv(ext_rest).map_or(ext_rest, |(_, _, r)| r)
        } else {
            ext_rest
        };
        let (0x04, octet_val, _) = der_tlv(octet_tlv)? else {
            continue;
        };
        let (0x30, general_names, _) = der_tlv(octet_val)? else {
            continue;
        };
        let mut gn = general_names;
        while let Some((tag, value, rest)) = der_tlv(gn) {
            gn = rest;
            if tag == TAG_DNS_NAME {
                push_cert_name(value, names);
            }
        }
    }
    Some(())
}
#[cfg(feature = "tls-rustls")]
/// rustls client-certificate verifier that adds an exact-name
/// [`PeerAllowlist`] check on top of a delegate verifier.
///
/// A client is accepted only if `inner` accepts its certificate *and* at
/// least one name extracted from the end-entity certificate is in
/// `allowlist`.
pub(crate) struct PeerAllowlistVerifier {
    /// Names that may be admitted once `inner` has accepted the certificate.
    allowlist: PeerAllowlist,
    /// Delegate verifier; `new` builds it as a `WebPkiClientVerifier` over
    /// the supplied root store.
    inner: std::sync::Arc<dyn rustls::server::danger::ClientCertVerifier>,
}
#[cfg(feature = "tls-rustls")]
impl PeerAllowlistVerifier {
    /// Builds a verifier that runs WebPKI client-certificate validation
    /// against `root_store` (using the `ring` crypto provider) and then
    /// applies `allowlist`.
    ///
    /// # Errors
    ///
    /// Returns `RepError::ConfigError` if `allowlist` is empty, or if the
    /// underlying `WebPkiClientVerifier` builder reports an error.
    pub(crate) fn new(
        root_store: std::sync::Arc<rustls::RootCertStore>,
        allowlist: PeerAllowlist,
    ) -> crate::error::Result<Self> {
        // An empty allowlist would reject every peer, so fail at
        // configuration time instead of at the first handshake.
        if allowlist.is_empty() {
            let msg = "PeerAllowlistVerifier requires a non-empty allowlist; an \
                 empty allowlist means no peer is authorised, which is almost \
                 certainly a misconfiguration. Add at least one expected peer \
                 subject name.";
            return Err(crate::error::RepError::ConfigError(msg.into()));
        }

        let crypto = std::sync::Arc::new(rustls::crypto::ring::default_provider());
        let builder =
            rustls::server::WebPkiClientVerifier::builder_with_provider(root_store, crypto);
        let inner = match builder.build() {
            Ok(verifier) => verifier,
            Err(e) => {
                return Err(crate::error::RepError::ConfigError(format!(
                    "PeerAllowlistVerifier: WebPkiClientVerifier build \
                     failed: {e}"
                )));
            }
        };

        Ok(Self { inner, allowlist })
    }
}
#[cfg(feature = "tls-rustls")]
impl std::fmt::Debug for PeerAllowlistVerifier {
    /// Prints only the number of allowlist entries. Neither the entries
    /// themselves nor the inner verifier appear in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.allowlist.len();
        let mut out = f.debug_struct("PeerAllowlistVerifier");
        out.field("allowlist_len", &entry_count);
        out.finish()
    }
}
#[cfg(feature = "tls-rustls")]
impl rustls::server::danger::ClientCertVerifier for PeerAllowlistVerifier {
    /// Always asks the client for a certificate.
    fn offer_client_auth(&self) -> bool {
        true
    }

    /// Client authentication is mandatory, never optional.
    fn client_auth_mandatory(&self) -> bool {
        true
    }

    /// Delegates to the inner verifier.
    fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
        self.inner.root_hint_subjects()
    }

    /// Accepts the client only if the inner verifier accepts the certificate
    /// and at least one name extracted from `end_entity` is in the allowlist.
    ///
    /// Errors from the inner verifier are returned unchanged. An allowlist
    /// miss is logged at `warn` and reported as `rustls::Error::General`.
    fn verify_client_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<
        rustls::server::danger::ClientCertVerified,
        rustls::Error,
    > {
        // Ordinary chain validation comes first; the allowlist only narrows
        // the set of already-accepted certificates.
        self.inner.verify_client_cert(end_entity, intermediates, now)?;

        let names = extract_cert_names(end_entity.as_ref());
        if self.allowlist.contains_any(&names) {
            log::debug!("mTLS: peer cert names {:?} admitted by allowlist", names);
            return Ok(rustls::server::danger::ClientCertVerified::assertion());
        }

        let peer_names = match names.as_slice() {
            [] => String::from("<no names found in cert>"),
            found => found.join(", "),
        };
        log::warn!(
            "mTLS: rejecting peer — cert names [{}] not in allowlist",
            peer_names
        );
        Err(rustls::Error::General(format!(
            "peer certificate names [{peer_names}] do not match any \
             entry in the configured peer_allowlist"
        )))
    }

    /// Delegates TLS 1.2 handshake-signature checking to the inner verifier.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<
        rustls::client::danger::HandshakeSignatureValid,
        rustls::Error,
    > {
        self.inner.verify_tls12_signature(message, cert, dss)
    }

    /// Delegates TLS 1.3 handshake-signature checking to the inner verifier.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<
        rustls::client::danger::HandshakeSignatureValid,
        rustls::Error,
    > {
        self.inner.verify_tls13_signature(message, cert, dss)
    }

    /// Advertises whatever signature schemes the inner verifier supports.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn empty_allowlist_admits_no_one() {
        let al = PeerAllowlist::default();
        assert!(al.is_empty());
        assert!(!al.contains("anyone"));
        assert!(!al.contains_any(["a", "b", "c"]));
    }
    #[test]
    fn case_insensitive_match() {
        let al = PeerAllowlist::new(["node-1.cluster.example"]);
        assert!(al.contains("node-1.cluster.example"));
        assert!(al.contains("Node-1.Cluster.Example"));
        assert!(al.contains("NODE-1.CLUSTER.EXAMPLE"));
        assert!(!al.contains("node-2.cluster.example"));
    }
    #[test]
    fn whitespace_and_empties_filtered() {
        let al = PeerAllowlist::new([" node-1 ", "", " ", "node-2"]);
        assert_eq!(al.len(), 2);
        assert!(al.contains("node-1"));
        assert!(al.contains("node-2"));
    }
    #[test]
    fn no_wildcard_match() {
        let al = PeerAllowlist::new(["*.cluster.example"]);
        assert!(!al.contains("node-7.cluster.example"));
        assert!(al.contains("*.cluster.example"));
    }
    #[test]
    fn duplicates_collapsed() {
        let al = PeerAllowlist::new(["node-1", "NODE-1", " node-1 "]);
        assert_eq!(al.len(), 1);
    }
    #[test]
    fn contains_any_admits_first_matching() {
        let al = PeerAllowlist::new(["node-2"]);
        assert!(al.contains_any(["nope", "node-2", "another"]));
        assert!(!al.contains_any(["nope", "another"]));
    }
    #[test]
    fn iter_yields_sorted_lowercase_entries() {
        let al = PeerAllowlist::new(["beta", "ALPHA", "Charlie"]);
        let v: Vec<&str> = al.iter().collect();
        assert_eq!(v, vec!["alpha", "beta", "charlie"]);
    }
    #[test]
    fn contains_trims_input_whitespace() {
        let al = PeerAllowlist::new(["node-1"]);
        assert!(al.contains(" node-1 "));
        assert!(al.contains("\tnode-1\n"));
    }
    #[test]
    fn allowlist_clone_is_independent() {
        let al1 = PeerAllowlist::new(["a", "b"]);
        let al2 = al1.clone();
        assert_eq!(al1.len(), al2.len());
        assert!(al1.contains("a") && al2.contains("a"));
    }

    /// Encodes one DER TLV with a single-byte `tag`, using the short or
    /// minimal long length form as required.
    #[cfg(feature = "tls-rustls")]
    fn tlv(tag: u8, body: &[u8]) -> Vec<u8> {
        let mut out = vec![tag];
        let len = body.len();
        if len < 0x80 {
            out.push(len as u8);
        } else {
            let be = len.to_be_bytes();
            let skip = be.iter().take_while(|&&b| b == 0).count();
            let significant = &be[skip..];
            out.push(0x80 | significant.len() as u8);
            out.extend_from_slice(significant);
        }
        out.extend_from_slice(body);
        out
    }

    /// A structurally valid (but unsigned, placeholder-filled) certificate
    /// with subject CN `Node-1.Example` and a subjectAltName extension that
    /// holds dNSName `Alt.Example` plus an iPAddress entry.
    #[cfg(feature = "tls-rustls")]
    fn synthetic_cert(critical_san: bool) -> Vec<u8> {
        let cn_atv = tlv(
            0x30,
            &[tlv(0x06, &[0x55, 0x04, 0x03]), tlv(0x0C, b"Node-1.Example")].concat(),
        );
        let subject = tlv(0x30, &tlv(0x31, &cn_atv));
        let general_names = tlv(
            0x30,
            &[tlv(0x82, b"Alt.Example"), tlv(0x87, &[10, 0, 0, 1])].concat(),
        );
        let mut ext_body = tlv(0x06, &[0x55, 0x1D, 0x11]);
        if critical_san {
            ext_body.extend(tlv(0x01, &[0xFF]));
        }
        ext_body.extend(tlv(0x04, &general_names));
        let extensions = tlv(0xA3, &tlv(0x30, &tlv(0x30, &ext_body)));
        let tbs = tlv(
            0x30,
            &[
                tlv(0xA0, &tlv(0x02, &[0x02])), // version v3
                tlv(0x02, &[0x01]),             // serialNumber
                tlv(0x30, &[]),                 // signature (placeholder)
                tlv(0x30, &[]),                 // issuer (placeholder)
                tlv(0x30, &[]),                 // validity (placeholder)
                subject,
                tlv(0x30, &[]), // subjectPublicKeyInfo (placeholder)
                extensions,
            ]
            .concat(),
        );
        tlv(0x30, &[tbs, tlv(0x30, &[]), tlv(0x03, &[0x00])].concat())
    }

    #[cfg(feature = "tls-rustls")]
    #[test]
    fn extracts_subject_cn_and_san_dns_names() {
        for critical in [false, true] {
            let names = extract_cert_names_for_test(&synthetic_cert(critical));
            assert_eq!(names, ["node-1.example", "alt.example"]);
            assert!(PeerAllowlist::new(["ALT.example"]).contains_any(&names));
            assert!(!PeerAllowlist::new(["10.0.0.1"]).contains_any(&names));
        }
    }

    #[cfg(feature = "tls-rustls")]
    #[test]
    fn malformed_der_yields_no_names() {
        assert!(extract_cert_names_for_test(&[]).is_empty());
        // Declared length runs past the end of the input.
        assert!(extract_cert_names_for_test(&[0x30, 0x05, 0x00]).is_empty());
        // Outer element is not a SEQUENCE.
        assert!(extract_cert_names_for_test(&[0x04, 0x00]).is_empty());
    }
}