use std::collections::HashMap;
/// Identity produced by a successful authentication.
///
/// `name` is the principal (a SASL user, a JWT `sub`, a certificate subject, or `"anonymous"`),
/// `groups` its group memberships in insertion order, and `claims` free-form key/value
/// attributes. The `with_*` builders consume the subject and hand back the updated value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AuthSubject {
    /// Principal name.
    pub name: String,
    /// Group memberships, in the order they were added.
    pub groups: Vec<String>,
    /// Extra attributes; inserting an existing key replaces its value.
    pub claims: HashMap<String, String>,
}

impl AuthSubject {
    /// The identity used when no authentication is performed: named `"anonymous"`, with no
    /// groups and no claims.
    #[must_use]
    pub fn anonymous() -> Self {
        Self::new("anonymous")
    }

    /// Creates a subject called `name` with empty group and claim collections.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Self::default()
        }
    }

    /// Appends group `g` and returns the updated subject.
    #[must_use]
    pub fn with_group(self, g: impl Into<String>) -> Self {
        let mut subject = self;
        subject.groups.push(g.into());
        subject
    }

    /// Sets claim `k` to `v` (overwriting any previous value) and returns the updated subject.
    #[must_use]
    pub fn with_claim(self, k: impl Into<String>, v: impl Into<String>) -> Self {
        let mut subject = self;
        let (key, value) = (k.into(), v.into());
        subject.claims.insert(key, value);
        subject
    }
}
/// Reason an authentication attempt failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The credential this mode reads (header or SASL blob) was not supplied at all.
    MissingCredentials,
    /// A credential was supplied but could not be parsed; the payload says what was wrong.
    MalformedCredentials(String),
    /// A well-formed credential was not accepted; the payload gives the reason.
    Rejected(String),
    /// The authenticator's configuration is unusable; the payload gives details.
    Misconfigured(String),
}

impl core::fmt::Display for AuthError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every variant except `MissingCredentials` renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::MissingCredentials => return f.write_str("missing credentials"),
            Self::MalformedCredentials(m) => ("malformed credentials", m),
            Self::Rejected(m) => ("credentials rejected", m),
            Self::Misconfigured(m) => ("auth misconfigured", m),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for AuthError {}
/// Authentication strategy. `AuthMode::validate` decides, per variant, which field of an
/// `AuthInput` is read and how it is checked.
#[derive(Debug, Clone)]
pub enum AuthMode {
    /// No authentication: every caller becomes the anonymous subject.
    None,
    /// Static bearer tokens, each mapped to the subject it authenticates as.
    Bearer {
        tokens: HashMap<String, AuthSubject>,
    },
    /// RS256-signed JWTs presented as bearer tokens.
    Jwt {
        /// RSA public key as DER (the name indicates PKCS#1 `RSAPublicKey` encoding), used to
        /// verify token signatures.
        pkcs1_pubkey_der: Vec<u8>,
        /// When set, the token's `iss` claim must equal this value.
        expected_issuer: Option<String>,
    },
    /// Accept the client-certificate subject the caller placed in the input.
    Mtls,
    /// SASL PLAIN against an in-memory map of user name to plaintext password.
    SaslPlain {
        users: HashMap<String, String>,
    },
}
/// Credentials gathered by the caller, borrowed for the duration of one validation.
///
/// Every field is optional; each `AuthMode` variant reads only the field it needs, and
/// `AuthInput::default()` carries nothing.
#[derive(Debug, Clone, Default)]
pub struct AuthInput<'a> {
    /// Raw `Authorization` header value (read by the `Bearer` and `Jwt` modes).
    pub authorization_header: Option<&'a str>,
    /// SASL PLAIN message bytes (read by the `SaslPlain` mode).
    pub sasl_plain_blob: Option<&'a [u8]>,
    /// Subject derived from the client certificate, if the caller has one (read by `Mtls`).
    pub mtls_subject: Option<AuthSubject>,
}
impl AuthMode {
    /// Authenticates `input` under this mode and returns the resulting subject.
    ///
    /// Which part of `input` is consulted depends on the variant:
    /// * [`AuthMode::None`] ignores `input` and yields [`AuthSubject::anonymous`].
    /// * [`AuthMode::Bearer`] reads `authorization_header` and looks the token up in `tokens`.
    /// * [`AuthMode::Jwt`] reads `authorization_header` and verifies the token as an RS256 JWT.
    /// * [`AuthMode::Mtls`] returns a clone of `mtls_subject`.
    /// * [`AuthMode::SaslPlain`] decodes `sasl_plain_blob` and checks the password in `users`.
    ///
    /// # Errors
    /// * `MissingCredentials` when the header or SASL blob the mode needs is absent.
    /// * `MalformedCredentials` when the header is not `Bearer …` or the SASL blob is malformed.
    /// * `Rejected` when a well-formed credential is unknown or wrong, or when `Mtls` has no
    ///   subject.
    /// * For `Jwt`, errors from JWT verification are returned unchanged.
    pub fn validate(&self, input: &AuthInput<'_>) -> Result<AuthSubject, AuthError> {
        match self {
            Self::None => Ok(AuthSubject::anonymous()),
            Self::Bearer { tokens } => {
                let hdr = input
                    .authorization_header
                    .ok_or(AuthError::MissingCredentials)?;
                let token = strip_bearer(hdr)?;
                tokens
                    .get(token)
                    .cloned()
                    .ok_or_else(|| AuthError::Rejected("unknown bearer token".to_string()))
            }
            Self::Jwt {
                pkcs1_pubkey_der,
                expected_issuer,
            } => {
                let hdr = input
                    .authorization_header
                    .ok_or(AuthError::MissingCredentials)?;
                let token = strip_bearer(hdr)?;
                validate_jwt_rs256(token, pkcs1_pubkey_der, expected_issuer.as_deref())
            }
            Self::Mtls => input
                .mtls_subject
                .clone()
                .ok_or_else(|| AuthError::Rejected("mTLS expected client cert".to_string())),
            Self::SaslPlain { users } => {
                let blob = input.sasl_plain_blob.ok_or(AuthError::MissingCredentials)?;
                let (user, pass) = parse_sasl_plain(blob)?;
                // NOTE(review): "unknown user" and "password mismatch" differ, so anyone who sees
                // the error text can tell whether `user` exists.
                let stored = users
                    .get(user)
                    .ok_or_else(|| AuthError::Rejected("unknown user".to_string()))?;
                // Compare without an early exit so response timing does not reveal how much of
                // a guessed password was correct.
                if Self::constant_time_eq(stored.as_bytes(), pass.as_bytes()) {
                    Ok(AuthSubject::new(user))
                } else {
                    Err(AuthError::Rejected("password mismatch".to_string()))
                }
            }
        }
    }

    /// Returns whether `a` and `b` are equal, examining every byte instead of stopping at the
    /// first difference. The lengths are compared first, so only the length can leak via timing.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        // Best effort: hide the accumulated value from the optimizer so it cannot rewrite the
        // fold into a short-circuiting comparison.
        std::hint::black_box(diff) == 0
    }
}
/// Extracts the credential from an `Authorization` header value of the form `Bearer <cred>`.
///
/// Leading/trailing whitespace of the header is ignored, the `Bearer ` scheme is matched
/// ASCII-case-insensitively, and whitespace around the credential is trimmed. The returned slice
/// borrows from `hdr`.
///
/// # Errors
/// `MalformedCredentials` when the trimmed header does not start with `Bearer `.
fn strip_bearer(hdr: &str) -> Result<&str, AuthError> {
    const SCHEME: &str = "Bearer ";
    let header = hdr.trim();
    // `get` yields `None` both when the header is too short and when the cut would land inside
    // a multi-byte character, so no separate length check is needed.
    match header.get(..SCHEME.len()) {
        Some(head) if head.eq_ignore_ascii_case(SCHEME) => Ok(header[SCHEME.len()..].trim()),
        _ => Err(AuthError::MalformedCredentials(
            "expected `Bearer …`".to_string(),
        )),
    }
}
/// Decodes a SASL PLAIN message (RFC 4616): `[authzid] NUL authcid NUL passwd`.
///
/// Returns `(authcid, passwd)` borrowed from `blob`. The optional authorization identity is
/// accepted but ignored, so callers authenticate as `authcid`.
///
/// # Errors
/// `MalformedCredentials` when the blob does not contain exactly two NUL separators, when the
/// user or password is not valid UTF-8, or when either of them is empty (RFC 4616 requires at
/// least one character for both).
fn parse_sasl_plain(blob: &[u8]) -> Result<(&str, &str), AuthError> {
    let mut parts = blob.split(|b| *b == 0);
    // `split` always yields at least one (possibly empty) segment, so the authzid slot can never
    // be missing; only the later segments can.
    let _authzid = parts.next();
    let authcid = parts
        .next()
        .ok_or_else(|| AuthError::MalformedCredentials("sasl-plain no user".into()))?;
    let passwd = parts
        .next()
        .ok_or_else(|| AuthError::MalformedCredentials("sasl-plain no pass".into()))?;
    // RFC 4616 does not allow NUL inside passwd, so any further segment makes the message
    // malformed (previously it was silently folded into the password).
    if parts.next().is_some() {
        return Err(AuthError::MalformedCredentials(
            "sasl-plain too many segments".into(),
        ));
    }
    let user = core::str::from_utf8(authcid)
        .map_err(|_| AuthError::MalformedCredentials("sasl-plain user utf8".into()))?;
    let pass = core::str::from_utf8(passwd)
        .map_err(|_| AuthError::MalformedCredentials("sasl-plain pass utf8".into()))?;
    if user.is_empty() {
        return Err(AuthError::MalformedCredentials(
            "sasl-plain empty user".into(),
        ));
    }
    if pass.is_empty() {
        return Err(AuthError::MalformedCredentials(
            "sasl-plain empty pass".into(),
        ));
    }
    Ok((user, pass))
}
/// Verifies a compact-serialized RS256 JWT and builds an [`AuthSubject`] from its claims.
///
/// The token must have exactly three base64url (unpadded) segments. The header's `alg` must be
/// `RS256`, and the signature over `header.payload` must verify against `pkcs1_pubkey_der`
/// with RSASSA-PKCS1-v1_5 / SHA-256 (2048–8192-bit keys). The payload is then checked:
/// * `sub` is required and becomes the subject name;
/// * `iss` must equal `expected_issuer` when one is given;
/// * `exp` and `nbf`, when present, must bracket the current time (60 s clock-skew leeway);
/// * entries of a `groups` array become the subject's groups.
///
/// # Errors
/// * `MalformedCredentials` — wrong segment count, bad base64url, or non-UTF-8 header/payload.
/// * `Rejected` — wrong `alg`, bad signature, missing `sub`/`iss`, issuer mismatch, a
///   non-numeric `exp`/`nbf`, or a token that is expired or not yet valid.
/// * `Misconfigured` — the system clock reads earlier than the Unix epoch.
fn validate_jwt_rs256(
    token: &str,
    pkcs1_pubkey_der: &[u8],
    expected_issuer: Option<&str>,
) -> Result<AuthSubject, AuthError> {
    use base64::Engine as _;
    // Leeway for clock skew between issuer and verifier when evaluating `exp`/`nbf`;
    // RFC 7519 §4.1.4 allows "some small leeway, usually no more than a few minutes".
    const CLOCK_SKEW_SECS: f64 = 60.0;
    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let mut parts = token.split('.');
    let h_b64 = parts
        .next()
        .ok_or_else(|| AuthError::MalformedCredentials("jwt: no header".into()))?;
    let p_b64 = parts
        .next()
        .ok_or_else(|| AuthError::MalformedCredentials("jwt: no payload".into()))?;
    let s_b64 = parts
        .next()
        .ok_or_else(|| AuthError::MalformedCredentials("jwt: no sig".into()))?;
    if parts.next().is_some() {
        return Err(AuthError::MalformedCredentials(
            "jwt: too many segments".into(),
        ));
    }
    let header_bytes = engine
        .decode(h_b64)
        .map_err(|e| AuthError::MalformedCredentials(format!("jwt header b64: {e}")))?;
    let payload_bytes = engine
        .decode(p_b64)
        .map_err(|e| AuthError::MalformedCredentials(format!("jwt payload b64: {e}")))?;
    let sig_bytes = engine
        .decode(s_b64)
        .map_err(|e| AuthError::MalformedCredentials(format!("jwt sig b64: {e}")))?;
    let header_str = core::str::from_utf8(&header_bytes)
        .map_err(|_| AuthError::MalformedCredentials("jwt header utf8".into()))?;
    if !json_field_eq(header_str, "alg", "RS256") {
        return Err(AuthError::Rejected("jwt: alg must be RS256".into()));
    }
    // The signature covers the ASCII bytes of the first two segments exactly as transmitted.
    let signed = {
        let mut v = Vec::with_capacity(h_b64.len() + 1 + p_b64.len());
        v.extend_from_slice(h_b64.as_bytes());
        v.push(b'.');
        v.extend_from_slice(p_b64.as_bytes());
        v
    };
    let pubkey = ring::signature::UnparsedPublicKey::new(
        &ring::signature::RSA_PKCS1_2048_8192_SHA256,
        pkcs1_pubkey_der,
    );
    pubkey
        .verify(&signed, &sig_bytes)
        .map_err(|_| AuthError::Rejected("jwt: signature invalid".into()))?;
    let payload_str = core::str::from_utf8(&payload_bytes)
        .map_err(|_| AuthError::MalformedCredentials("jwt payload utf8".into()))?;
    let sub = json_field(payload_str, "sub")
        .ok_or_else(|| AuthError::Rejected("jwt: no sub claim".into()))?;
    if let Some(expected) = expected_issuer {
        let iss = json_field(payload_str, "iss")
            .ok_or_else(|| AuthError::Rejected("jwt: no iss claim".into()))?;
        if iss != expected {
            return Err(AuthError::Rejected(format!("jwt: iss != {expected}")));
        }
    }
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|_| AuthError::Misconfigured("system clock is before the Unix epoch".into()))?
        .as_secs_f64();
    // NumericDate claims (RFC 7519 §2): absent means "no constraint"; present but not a finite
    // number means the token is unusable.
    let numeric_date = |claim: &str| -> Result<Option<f64>, AuthError> {
        let Some(raw) = json_field(payload_str, claim) else {
            return Ok(None);
        };
        match raw.parse::<f64>() {
            Ok(secs) if secs.is_finite() => Ok(Some(secs)),
            _ => Err(AuthError::Rejected(format!(
                "jwt: {claim} is not a NumericDate"
            ))),
        }
    };
    if let Some(exp) = numeric_date("exp")? {
        if now >= exp + CLOCK_SKEW_SECS {
            return Err(AuthError::Rejected("jwt: token expired".into()));
        }
    }
    if let Some(nbf) = numeric_date("nbf")? {
        if now + CLOCK_SKEW_SECS < nbf {
            return Err(AuthError::Rejected("jwt: token not yet valid".into()));
        }
    }
    let mut subj = AuthSubject::new(sub);
    subj.groups
        .extend(json_array(payload_str, "groups").unwrap_or_default());
    Ok(subj)
}
/// Returns the value of the top-level member `key` of the JSON object `src`.
///
/// String values are returned with their JSON escapes decoded. Any other value (number,
/// `true`/`false`/`null`, nested array or object) is returned as its raw JSON text. Only
/// members of the outermost object are considered, so a key nested inside another value never
/// matches, and neither does a string *value* that happens to spell `key`. If `key` occurs more
/// than once the last occurrence wins, as RFC 7519 §4 requires of JWT parsers that do not
/// reject duplicates.
///
/// Returns `None` when `key` is absent, when `src` cannot be scanned as a JSON object, or when
/// the string value contains an invalid escape.
fn json_field(src: &str, key: &str) -> Option<String> {
    let raw = json_member_raw(src, key)?;
    match raw.strip_prefix('"') {
        Some(quoted) => json_unescape(&quoted[..quoted.len() - 1]),
        None => Some(raw.to_string()),
    }
}

/// Reports whether the top-level member `key` of `src` exists and equals `expected`
/// (as decoded by `json_field`).
fn json_field_eq(src: &str, key: &str, expected: &str) -> bool {
    json_field(src, key).is_some_and(|v| v == expected)
}

/// Returns the elements of the top-level array member `key` of the JSON object `src`.
///
/// String elements are unescaped, and empty strings are skipped. Other elements are returned
/// as raw JSON text. Element boundaries follow JSON syntax, so a comma or bracket inside a
/// string element stays part of that element. Returns `None` if the member is absent, is not
/// an array, or cannot be scanned.
fn json_array(src: &str, key: &str) -> Option<Vec<String>> {
    let raw = json_member_raw(src, key)?;
    let bytes = raw.as_bytes();
    if bytes.first() != Some(&b'[') {
        return None;
    }
    let mut out = Vec::new();
    let mut i = json_skip_ws(bytes, 1);
    if bytes.get(i) == Some(&b']') {
        return Some(out);
    }
    loop {
        let end = json_value_end(bytes, i)?;
        let item = &raw[i..end];
        let value = match item.strip_prefix('"') {
            Some(quoted) => json_unescape(&quoted[..quoted.len() - 1])?,
            None => item.to_string(),
        };
        if !value.is_empty() {
            out.push(value);
        }
        i = json_skip_ws(bytes, end);
        match *bytes.get(i)? {
            b',' => i = json_skip_ws(bytes, i + 1),
            b']' => return Some(out),
            _ => return None,
        }
    }
}

/// Walks the members of the JSON object `src` and returns the raw text of the value bound to
/// `key` (string values keep their quotes). The last duplicate wins.
fn json_member_raw<'a>(src: &'a str, key: &str) -> Option<&'a str> {
    let bytes = src.as_bytes();
    let mut i = json_skip_ws(bytes, 0);
    if bytes.get(i) != Some(&b'{') {
        return None;
    }
    i = json_skip_ws(bytes, i + 1);
    if bytes.get(i) == Some(&b'}') {
        return None;
    }
    let mut found = None;
    loop {
        if bytes.get(i) != Some(&b'"') {
            return None;
        }
        let name_end = json_string_end(bytes, i)?;
        let name = json_unescape(&src[i + 1..name_end - 1])?;
        i = json_skip_ws(bytes, name_end);
        if bytes.get(i) != Some(&b':') {
            return None;
        }
        let value_start = json_skip_ws(bytes, i + 1);
        let value_end = json_value_end(bytes, value_start)?;
        if name == key {
            found = Some(&src[value_start..value_end]);
        }
        i = json_skip_ws(bytes, value_end);
        match *bytes.get(i)? {
            b',' => i = json_skip_ws(bytes, i + 1),
            b'}' => return found,
            _ => return None,
        }
    }
}

/// Index of the first byte at or after `i` that is not JSON whitespace (space, tab, LF, CR).
fn json_skip_ws(bytes: &[u8], mut i: usize) -> usize {
    while matches!(bytes.get(i), Some(b' ' | b'\t' | b'\n' | b'\r')) {
        i += 1;
    }
    i
}

/// Given `bytes[start] == b'"'`, returns the index just past the matching closing quote,
/// skipping backslash escapes. `None` if the string is unterminated.
fn json_string_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut i = start + 1;
    loop {
        match *bytes.get(i)? {
            b'"' => return Some(i + 1),
            b'\\' => i += 2,
            _ => i += 1,
        }
    }
}

/// Returns the index just past the JSON value beginning at `bytes[start]`. Strings, arrays and
/// objects are matched structurally (brackets inside strings are ignored). Any other value runs
/// to the next `,`, `}`, `]` or whitespace. `None` for an empty or unterminated value.
fn json_value_end(bytes: &[u8], start: usize) -> Option<usize> {
    match *bytes.get(start)? {
        b'"' => json_string_end(bytes, start),
        b'{' | b'[' => {
            let mut depth = 0usize;
            let mut i = start;
            loop {
                match *bytes.get(i)? {
                    b'"' => {
                        i = json_string_end(bytes, i)?;
                        continue;
                    }
                    b'{' | b'[' => depth += 1,
                    b'}' | b']' => {
                        // `depth >= 1` here: the first byte scanned was an opening bracket.
                        depth -= 1;
                        if depth == 0 {
                            return Some(i + 1);
                        }
                    }
                    _ => {}
                }
                i += 1;
            }
        }
        _ => {
            let len = bytes[start..]
                .iter()
                .position(|b| matches!(b, b',' | b'}' | b']' | b' ' | b'\t' | b'\n' | b'\r'))
                .unwrap_or(bytes.len() - start);
            (len > 0).then_some(start + len)
        }
    }
}

/// Decodes the body of a JSON string literal (quotes already removed), including `\uXXXX`
/// escapes and UTF-16 surrogate pairs. `None` on an invalid or truncated escape.
fn json_unescape(raw: &str) -> Option<String> {
    if !raw.contains('\\') {
        return Some(raw.to_string());
    }
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next()? {
            '"' => out.push('"'),
            '\\' => out.push('\\'),
            '/' => out.push('/'),
            'b' => out.push('\u{8}'),
            'f' => out.push('\u{c}'),
            'n' => out.push('\n'),
            'r' => out.push('\r'),
            't' => out.push('\t'),
            'u' => {
                let high = json_hex4(&mut chars)?;
                let code = if (0xD800..0xDC00).contains(&high) {
                    // High surrogate: must be followed by a `\uXXXX` low surrogate.
                    if chars.next()? != '\\' || chars.next()? != 'u' {
                        return None;
                    }
                    let low = json_hex4(&mut chars)?;
                    if !(0xDC00..0xE000).contains(&low) {
                        return None;
                    }
                    0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00)
                } else {
                    high
                };
                // Rejects lone low surrogates, which are not Unicode scalar values.
                out.push(char::from_u32(code)?);
            }
            _ => return None,
        }
    }
    Some(out)
}

/// Reads exactly four hex digits from `chars` as a code unit.
fn json_hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
    (0..4).try_fold(0u32, |acc, _| Some(acc * 16 + chars.next()?.to_digit(16)?))
}
// Unit tests covering every `AuthMode` variant plus the JSON helpers used by JWT validation.
#[cfg(test)]
// unwrap/expect are acceptable here: a panic is exactly how a failing test reports.
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
use super::*;
// `AuthMode::None` ignores the (empty) input and returns the anonymous subject.
#[test]
fn none_mode_yields_anonymous() {
let m = AuthMode::None;
let s = m.validate(&AuthInput::default()).unwrap();
assert_eq!(s.name, "anonymous");
}
// A known bearer token resolves to the subject stored for it.
#[test]
fn bearer_valid_token_accepted() {
let mut tokens = HashMap::new();
tokens.insert("secret123".to_string(), AuthSubject::new("alice"));
let m = AuthMode::Bearer { tokens };
let s = m
.validate(&AuthInput {
authorization_header: Some("Bearer secret123"),
..Default::default()
})
.unwrap();
assert_eq!(s.name, "alice");
}
// A well-formed but unknown bearer token is `Rejected`.
#[test]
fn bearer_invalid_token_rejected() {
let m = AuthMode::Bearer {
tokens: HashMap::new(),
};
let err = m
.validate(&AuthInput {
authorization_header: Some("Bearer wrong"),
..Default::default()
})
.unwrap_err();
assert!(matches!(err, AuthError::Rejected(_)));
}
// No Authorization header at all maps to `MissingCredentials`.
#[test]
fn bearer_missing_header_returns_missing() {
let m = AuthMode::Bearer {
tokens: HashMap::new(),
};
let err = m.validate(&AuthInput::default()).unwrap_err();
assert!(matches!(err, AuthError::MissingCredentials));
}
// A header using another scheme (`Basic`) fails `strip_bearer` with `MalformedCredentials`.
#[test]
fn bearer_malformed_header_returns_malformed() {
let m = AuthMode::Bearer {
tokens: HashMap::new(),
};
let err = m
.validate(&AuthInput {
authorization_header: Some("Basic xx"),
..Default::default()
})
.unwrap_err();
assert!(matches!(err, AuthError::MalformedCredentials(_)));
}
// mTLS mode passes the caller-supplied certificate subject straight through.
#[test]
fn mtls_with_subject_accepted() {
let m = AuthMode::Mtls;
let s = m
.validate(&AuthInput {
mtls_subject: Some(AuthSubject::new("CN=alice")),
..Default::default()
})
.unwrap();
assert_eq!(s.name, "CN=alice");
}
// mTLS mode without a subject is `Rejected` (not `MissingCredentials`).
#[test]
fn mtls_without_subject_rejected() {
let m = AuthMode::Mtls;
let err = m.validate(&AuthInput::default()).unwrap_err();
assert!(matches!(err, AuthError::Rejected(_)));
}
// SASL PLAIN with an empty authzid and a matching user/password pair succeeds.
#[test]
fn sasl_plain_valid_pair_accepted() {
let mut users = HashMap::new();
users.insert("alice".to_string(), "wonderland".to_string());
let m = AuthMode::SaslPlain { users };
let blob = b"\0alice\0wonderland";
let s = m
.validate(&AuthInput {
sasl_plain_blob: Some(blob),
..Default::default()
})
.unwrap();
assert_eq!(s.name, "alice");
}
// SASL PLAIN with a known user but the wrong password is `Rejected`.
#[test]
fn sasl_plain_wrong_password_rejected() {
let mut users = HashMap::new();
users.insert("alice".to_string(), "wonderland".to_string());
let m = AuthMode::SaslPlain { users };
let blob = b"\0alice\0wrong";
let err = m
.validate(&AuthInput {
sasl_plain_blob: Some(blob),
..Default::default()
})
.unwrap_err();
assert!(matches!(err, AuthError::Rejected(_)));
}
// `json_field` returns the unquoted string value of a top-level key.
#[test]
fn json_field_extracts_string() {
let s = r#"{"alg":"RS256","typ":"JWT"}"#;
assert_eq!(json_field(s, "alg").as_deref(), Some("RS256"));
}
// `json_array` returns the string elements of a top-level array in order.
#[test]
fn json_array_extracts_groups() {
let s = r#"{"sub":"a","groups":["eng","ops"]}"#;
let g = json_array(s, "groups").unwrap();
assert_eq!(g, vec!["eng".to_string(), "ops".to_string()]);
}
// Header `{"alg":"RS256"}`, payload `{"sub":"a"}`, three zero signature bytes: with a
// 32-byte all-zero "key" the signature cannot verify, so the result is `Rejected`.
#[test]
fn jwt_invalid_signature_rejected() {
let m = AuthMode::Jwt {
pkcs1_pubkey_der: vec![0u8; 32],
expected_issuer: None,
};
let token = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhIn0.AAAA";
let err = m
.validate(&AuthInput {
authorization_header: Some(&format!("Bearer {token}")),
..Default::default()
})
.unwrap_err();
assert!(matches!(err, AuthError::Rejected(_)));
}
// Header `{"alg":"HS256"}` fails the RS256-only `alg` check before any signature work.
#[test]
fn jwt_wrong_alg_rejected() {
let m = AuthMode::Jwt {
pkcs1_pubkey_der: vec![0u8; 32],
expected_issuer: None,
};
let token = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJhIn0.AAAA";
let err = m
.validate(&AuthInput {
authorization_header: Some(&format!("Bearer {token}")),
..Default::default()
})
.unwrap_err();
assert!(matches!(err, AuthError::Rejected(_)));
}
}