use std::collections::HashMap;
use super::{Role, User};
/// Settings controlling how bearer JWTs are accepted and mapped to identities on a listener.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Master switch; when `false`, `OAuthValidator::validate` fails with `OAuthError::Disabled`.
    pub enabled: bool,
    /// Expected `iss` claim, compared for exact string equality.
    pub issuer: String,
    /// Audience that must appear (exactly) among the token's `aud` values.
    pub audience: String,
    /// Location of the JSON Web Key Set. Not read by the code in this file —
    /// presumably used by whatever fetches keys and calls `OAuthValidator::set_jwks`; TODO confirm.
    pub jwks_url: String,
    /// Which claim supplies the authenticated username.
    pub identity_mode: OAuthIdentityMode,
    /// Claim holding the role name. When `None`, or when the token lacks the claim,
    /// `default_role` is used; a present claim that does not parse as a `Role` is rejected.
    pub role_claim: Option<String>,
    /// Claim copied into `OAuthIdentity::tenant`; an empty value is treated as absent.
    /// When `None`, no tenant is extracted.
    pub tenant_claim: Option<String>,
    /// Role assigned when no role claim is configured or the token does not carry it.
    pub default_role: Role,
    /// When `true`, the username is looked up through the caller-supplied `lookup_user`
    /// and a found user's stored role takes precedence; unknown users fall back to the
    /// claim-derived role.
    pub map_to_existing_users: bool,
    /// When `false`, `OAuthValidator::extract_bearer` always returns `None`.
    pub accept_bearer: bool,
}
impl Default for OAuthConfig {
fn default() -> Self {
Self {
enabled: false,
issuer: String::new(),
audience: String::new(),
jwks_url: String::new(),
identity_mode: OAuthIdentityMode::SubClaim,
role_claim: None,
tenant_claim: None,
default_role: Role::Read,
map_to_existing_users: true,
accept_bearer: true,
}
}
}
/// Selects which JWT claim supplies the authenticated username.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuthIdentityMode {
    /// Use the `sub` (subject) claim.
    SubClaim,
    /// Use the named claim, resolved through `JwtClaims::claim`
    /// (the `iss`/`sub` fields, otherwise the `extra` map).
    ClaimField(String),
}
/// A JWT already split into its decoded parts, as consumed by `OAuthValidator::validate`.
///
/// Decoding happens outside this file; holding a `DecodedJwt` says nothing about
/// whether its signature has been checked.
#[derive(Debug, Clone)]
pub struct DecodedJwt {
    /// Decoded JOSE header.
    pub header: JwtHeader,
    /// Decoded payload claims.
    pub claims: JwtClaims,
    /// Signature bytes passed to the `JwtVerifier`
    /// (presumably already base64url-decoded — confirm with the decoder).
    pub signature: Vec<u8>,
    /// Bytes the signature covers, passed verbatim to the `JwtVerifier`
    /// (presumably the `header.payload` segments of the compact token — confirm).
    pub signing_input: Vec<u8>,
}
/// The header fields used to select a verification key.
#[derive(Debug, Clone)]
pub struct JwtHeader {
    /// Signature algorithm name (e.g. `RS256`); must equal the selected `Jwk::alg`.
    pub alg: String,
    /// Key id. `OAuthValidator::validate` never matches a token without one to any key.
    pub kid: Option<String>,
}
/// Payload claims of a decoded JWT.
#[derive(Debug, Clone, Default)]
pub struct JwtClaims {
    /// Issuer (`iss`).
    pub iss: Option<String>,
    /// Subject (`sub`).
    pub sub: Option<String>,
    /// Audience values (`aud`).
    pub aud: Vec<String>,
    /// Expiration time (`exp`); compared against a unix-seconds "now" by the validator.
    pub exp: Option<i64>,
    /// Not-before time (`nbf`); compared against a unix-seconds "now" by the validator.
    pub nbf: Option<i64>,
    /// Issued-at time (`iat`); not checked by the code in this file.
    pub iat: Option<i64>,
    /// Any other claims, keyed by claim name with string values (how non-string JSON
    /// values end up here is up to the decoder — confirm).
    pub extra: HashMap<String, String>,
}
impl JwtClaims {
    /// Looks up a claim by name as a string slice.
    ///
    /// `iss` and `sub` are answered from their dedicated fields only, never from
    /// `extra`. Every other name is looked up in `extra`, including the other
    /// registered ones such as `aud` or `exp`. Returns `None` when the claim is absent.
    pub fn claim(&self, key: &str) -> Option<&str> {
        let dedicated = match key {
            "iss" => Some(&self.iss),
            "sub" => Some(&self.sub),
            _ => None,
        };
        match dedicated {
            Some(field) => field.as_deref(),
            None => self.extra.get(key).map(String::as_str),
        }
    }
}
/// The authenticated principal produced by a successful `OAuthValidator::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthIdentity {
    /// Username taken from the claim selected by `OAuthConfig::identity_mode`.
    pub username: String,
    /// Value of `OAuthConfig::tenant_claim`; `None` when unconfigured, absent, or empty.
    pub tenant: Option<String>,
    /// Role from the auth store (when mapped) or derived from the token's claims.
    pub role: Role,
    /// The configured issuer, which the token's `iss` matched.
    pub issuer: String,
    /// The token's `sub` claim, even when the username came from another claim.
    pub subject: Option<String>,
    /// The token's `exp` claim, if it had one.
    pub expires_at_unix_secs: Option<i64>,
}
/// Reasons an OAuth bearer token can be rejected.
#[derive(Debug, Clone)]
pub enum OAuthError {
    /// `OAuthConfig::enabled` is `false`.
    Disabled,
    /// No bearer token was supplied. Not produced in this file — presumably returned by
    /// request-handling code; confirm.
    MissingToken,
    /// The token is structurally unusable; `validate` uses it for a missing `iss`.
    Malformed(String),
    /// The token's `iss` differs from the configured issuer.
    WrongIssuer {
        /// Configured issuer.
        expected: String,
        /// Issuer found in the token.
        actual: String,
    },
    /// The configured audience is not among the token's `aud` values.
    WrongAudience {
        /// Configured audience.
        expected: String,
        /// Audiences found in the token.
        actual: Vec<String>,
    },
    /// The token's `exp` is at or before the validation time.
    Expired {
        /// The token's `exp`, unix seconds.
        exp: i64,
    },
    /// The token's `nbf` is after the validation time.
    NotYetValid {
        /// The token's `nbf`, unix seconds.
        nbf: i64,
    },
    /// No JWK matched the header's `kid`/`alg`, or the verifier rejected the signature;
    /// carries a human-readable detail.
    BadSignature(String),
    /// The claim selected by the carried identity mode was not usable as a username.
    MissingIdentityClaim(OAuthIdentityMode),
    /// The configured role claim (named here) did not parse as a `Role`.
    MissingOrInvalidRole(String),
    /// The named user is not in the auth store. Not produced by the code in this file.
    UnknownUser(String),
    /// Fetching the key set failed. Not produced by the code in this file.
    JwksFetch(String),
}
impl std::fmt::Display for OAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OAuthError::Disabled => write!(f, "OAuth disabled on this listener"),
OAuthError::MissingToken => write!(f, "no Bearer token"),
OAuthError::Malformed(m) => write!(f, "malformed JWT: {m}"),
OAuthError::WrongIssuer { expected, actual } => {
write!(f, "issuer mismatch: expected {expected}, got {actual}")
}
OAuthError::WrongAudience { expected, actual } => {
write!(
f,
"audience mismatch: expected {expected}, got {:?}",
actual
)
}
OAuthError::Expired { exp } => write!(f, "token expired at unix {exp}"),
OAuthError::NotYetValid { nbf } => {
write!(f, "token not valid before unix {nbf}")
}
OAuthError::BadSignature(m) => write!(f, "signature verification failed: {m}"),
OAuthError::MissingIdentityClaim(mode) => {
write!(f, "identity claim missing for mode {:?}", mode)
}
OAuthError::MissingOrInvalidRole(c) => {
write!(f, "role claim '{c}' missing or not a valid Role")
}
OAuthError::UnknownUser(u) => write!(f, "OAuth user '{u}' not in auth store"),
OAuthError::JwksFetch(m) => write!(f, "JWKS fetch failed: {m}"),
}
}
}
impl std::error::Error for OAuthError {}
/// One verification key from the key set, in the form the `JwtVerifier` consumes.
#[derive(Debug, Clone)]
pub struct Jwk {
    /// Key id, matched exactly against the token header's `kid`.
    pub kid: String,
    /// Algorithm this key is for; must equal the token header's `alg`.
    pub alg: String,
    /// Key material handed to the verifier; its encoding is whatever the installed
    /// `JwtVerifier` expects — confirm with the verifier implementation.
    pub key_bytes: Vec<u8>,
}
/// Pluggable signature check, called as `verifier(jwk, signing_input, signature)`.
///
/// Returns `Ok(())` for a valid signature. An `Err` message is surfaced as
/// `OAuthError::BadSignature`. The `Send + Sync` bounds let the boxed closure be shared
/// across threads.
pub type JwtVerifier = Box<dyn Fn(&Jwk, &[u8], &[u8]) -> Result<(), String> + Send + Sync>;
/// Validates decoded JWTs against an `OAuthConfig` and a replaceable key set.
pub struct OAuthValidator {
    /// Validation settings; private and never mutated in this file after construction.
    config: OAuthConfig,
    /// Current signing keys, behind a lock so `set_jwks` can replace them through `&self`.
    jwks: parking_lot::RwLock<Vec<Jwk>>,
    /// Signature-checking callback supplied at construction.
    verifier: JwtVerifier,
}
impl OAuthValidator {
    /// Creates a validator for `config` that checks signatures with `verifier`.
    ///
    /// The key set starts empty. Until [`OAuthValidator::set_jwks`] installs keys,
    /// every token is rejected with [`OAuthError::BadSignature`].
    pub fn with_verifier(config: OAuthConfig, verifier: JwtVerifier) -> Self {
        Self {
            config,
            jwks: parking_lot::RwLock::new(Vec::new()),
            verifier,
        }
    }

    /// Replaces the whole key set used to find signing keys.
    ///
    /// Takes `&self`: the keys live behind an `RwLock`, so they can be swapped while
    /// the validator is shared.
    pub fn set_jwks(&self, keys: Vec<Jwk>) {
        *self.jwks.write() = keys;
    }

    /// Returns the configuration this validator was built with.
    pub fn config(&self) -> &OAuthConfig {
        &self.config
    }

    /// Extracts the token from an `Authorization` header value of the form
    /// `Bearer <token>`.
    ///
    /// The value is trimmed, and the `Bearer ` prefix is matched ASCII
    /// case-insensitively. The remainder is then trimmed again. Returns `None` when
    /// `accept_bearer` is off in the config, when the scheme is not `Bearer`, or when
    /// nothing follows the scheme. Never panics, whatever the (untrusted) header contains.
    pub fn extract_bearer(&self, header_value: &str) -> Option<String> {
        const PREFIX: &str = "Bearer ";
        if !self.config.accept_bearer {
            return None;
        }
        let trimmed = header_value.trim();
        if trimmed.len() <= PREFIX.len() {
            return None;
        }
        // Compare raw bytes instead of slicing the `str`: `&trimmed[..7]` panics when a
        // multi-byte UTF-8 character straddles byte 7 (e.g. "Beareré..."), and this
        // header comes straight from the client.
        let scheme = &trimmed.as_bytes()[..PREFIX.len()];
        if !scheme.eq_ignore_ascii_case(PREFIX.as_bytes()) {
            return None;
        }
        // The matched prefix is pure ASCII, so `PREFIX.len()` is a char boundary.
        Some(trimmed[PREFIX.len()..].trim().to_string())
    }

    /// Validates an already-decoded JWT and maps it to an [`OAuthIdentity`].
    ///
    /// Checks run in this order:
    /// 1. listener enabled;
    /// 2. signing key lookup (by both `kid` and `alg`) and signature verification;
    /// 3. issuer and audience;
    /// 4. `exp` and `nbf`;
    /// 5. identity claim;
    /// 6. role.
    ///
    /// `exp`/`nbf` are compared against `now_unix_secs` with no leeway. A token without
    /// `exp`/`nbf` skips that check.
    ///
    /// When `map_to_existing_users` is set, `lookup_user` is called with the username,
    /// and a found user's stored role wins. Otherwise, or when the user is unknown, the
    /// role comes from the configured role claim or the default role.
    ///
    /// # Errors
    ///
    /// - [`OAuthError::Disabled`] when OAuth is disabled in the config.
    /// - [`OAuthError::BadSignature`] when no key matches the header or the verifier
    ///   rejects the signature.
    /// - [`OAuthError::Malformed`] when `iss` is missing; [`OAuthError::WrongIssuer`] when
    ///   it differs from the configured issuer.
    /// - [`OAuthError::WrongAudience`] when the configured audience is not in `aud`.
    /// - [`OAuthError::Expired`] / [`OAuthError::NotYetValid`] for time-window failures.
    /// - [`OAuthError::MissingIdentityClaim`] when the identity claim is absent or empty.
    /// - [`OAuthError::MissingOrInvalidRole`] when the role claim does not parse.
    pub fn validate<F>(
        &self,
        token: &DecodedJwt,
        now_unix_secs: i64,
        lookup_user: F,
    ) -> Result<OAuthIdentity, OAuthError>
    where
        F: Fn(&str) -> Option<User>,
    {
        if !self.config.enabled {
            return Err(OAuthError::Disabled);
        }

        // Resolve the key under a short-lived read lock and clone it out, so the
        // verifier runs without holding the lock. A token without `kid` matches nothing;
        // `alg` must also agree so a key is never used under a different algorithm.
        let jwk = {
            let jwks = self.jwks.read();
            let kid = token.header.kid.as_deref();
            jwks.iter()
                .find(|j| kid == Some(j.kid.as_str()) && j.alg == token.header.alg)
                .cloned()
        };
        let Some(jwk) = jwk else {
            return Err(OAuthError::BadSignature(format!(
                "no JWK for kid {:?} alg {}",
                token.header.kid, token.header.alg
            )));
        };
        // Nothing in the claims can be trusted until the signature checks out.
        (self.verifier)(&jwk, &token.signing_input, &token.signature)
            .map_err(OAuthError::BadSignature)?;

        match &token.claims.iss {
            Some(iss) if iss == &self.config.issuer => {}
            Some(iss) => {
                return Err(OAuthError::WrongIssuer {
                    expected: self.config.issuer.clone(),
                    actual: iss.clone(),
                });
            }
            None => {
                return Err(OAuthError::Malformed("missing iss".into()));
            }
        }
        if !token.claims.aud.iter().any(|a| a == &self.config.audience) {
            return Err(OAuthError::WrongAudience {
                expected: self.config.audience.clone(),
                actual: token.claims.aud.clone(),
            });
        }
        if let Some(exp) = token.claims.exp {
            if exp <= now_unix_secs {
                return Err(OAuthError::Expired { exp });
            }
        }
        if let Some(nbf) = token.claims.nbf {
            if nbf > now_unix_secs {
                return Err(OAuthError::NotYetValid { nbf });
            }
        }

        let identity_claim = match &self.config.identity_mode {
            OAuthIdentityMode::SubClaim => token.claims.sub.as_deref(),
            OAuthIdentityMode::ClaimField(name) => token.claims.claim(name),
        };
        // An empty identity claim names nobody. Reject it like an absent one rather
        // than authenticating the empty username; the tenant claim below likewise
        // treats "" as absent.
        let username = identity_claim
            .filter(|s| !s.is_empty())
            .map(str::to_string)
            .ok_or_else(|| OAuthError::MissingIdentityClaim(self.config.identity_mode.clone()))?;

        // The role recorded in the auth store wins; otherwise derive it from the claims.
        let stored_role = if self.config.map_to_existing_users {
            lookup_user(&username).map(|user| user.role)
        } else {
            None
        };
        let role = match stored_role {
            Some(role) => role,
            None => self.derive_role_from_claims(&token.claims)?,
        };

        let tenant = self
            .config
            .tenant_claim
            .as_deref()
            .and_then(|name| token.claims.claim(name).map(|s| s.to_string()))
            .filter(|s| !s.is_empty());

        Ok(OAuthIdentity {
            username,
            tenant,
            role,
            issuer: self.config.issuer.clone(),
            subject: token.claims.sub.clone(),
            expires_at_unix_secs: token.claims.exp,
        })
    }

    /// Derives a role from the token's claims when no stored user role applies.
    ///
    /// Falls back to `default_role` when no role claim is configured or the token lacks
    /// it. A present claim is trimmed and parsed with `Role::from_str`. A value that
    /// does not parse yields [`OAuthError::MissingOrInvalidRole`], naming the claim.
    fn derive_role_from_claims(&self, claims: &JwtClaims) -> Result<Role, OAuthError> {
        let Some(name) = &self.config.role_claim else {
            return Ok(self.config.default_role);
        };
        let Some(raw) = claims.claim(name) else {
            return Ok(self.config.default_role);
        };
        Role::from_str(raw.trim()).ok_or_else(|| OAuthError::MissingOrInvalidRole(name.clone()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed "current time" (unix seconds) every test token is built and validated at.
    const NOW: i64 = 1_700_000_000;

    /// Verifier that accepts any signature, so these tests exercise claim handling only.
    fn noop_verifier() -> JwtVerifier {
        Box::new(|_jwk, _input, _sig| Ok(()))
    }

    /// Enabled config expecting issuer `https://id.example.com` and audience `reddb`.
    ///
    /// It takes the identity from `sub`, reads no role or tenant claim (so `Role::Read`
    /// applies), and does not consult the auth store.
    fn base_config() -> OAuthConfig {
        OAuthConfig {
            enabled: true,
            issuer: "https://id.example.com".to_string(),
            audience: "reddb".to_string(),
            jwks_url: String::new(),
            identity_mode: OAuthIdentityMode::SubClaim,
            role_claim: None,
            tenant_claim: None,
            default_role: Role::Read,
            map_to_existing_users: false,
            accept_bearer: true,
        }
    }

    /// Token for key `k1`/`RS256` whose claims satisfy `base_config`.
    ///
    /// The subject is `alice`, and the token is valid from `now - 60` to `now + 3600`.
    fn base_token(now: i64) -> DecodedJwt {
        let header = JwtHeader {
            alg: "RS256".to_string(),
            kid: Some("k1".to_string()),
        };
        let claims = JwtClaims {
            iss: Some("https://id.example.com".to_string()),
            sub: Some("alice".to_string()),
            aud: vec!["reddb".to_string()],
            exp: Some(now + 3600),
            nbf: Some(now - 60),
            iat: Some(now),
            extra: HashMap::new(),
        };
        DecodedJwt {
            header,
            claims,
            signature: vec![0u8; 8],
            signing_input: b"header.payload".to_vec(),
        }
    }

    /// Validator for `config` whose key set holds the single `k1`/`RS256` key that
    /// `base_token` references.
    fn validator_with(config: OAuthConfig) -> OAuthValidator {
        let validator = OAuthValidator::with_verifier(config, noop_verifier());
        validator.set_jwks(vec![Jwk {
            kid: "k1".to_string(),
            alg: "RS256".to_string(),
            key_bytes: Vec::new(),
        }]);
        validator
    }

    /// `validator_with(base_config())`.
    fn seeded_validator() -> OAuthValidator {
        validator_with(base_config())
    }

    /// Validates `token` at [`NOW`] against an auth store that knows no users.
    fn run(validator: &OAuthValidator, token: &DecodedJwt) -> Result<OAuthIdentity, OAuthError> {
        validator.validate(token, NOW, |_| None)
    }

    #[test]
    fn extract_bearer_case_insensitive() {
        let validator = seeded_validator();
        assert_eq!(
            validator.extract_bearer("Bearer abc.def.ghi").as_deref(),
            Some("abc.def.ghi")
        );
        assert_eq!(validator.extract_bearer("bearer xyz").as_deref(), Some("xyz"));
        assert_eq!(validator.extract_bearer("Basic QQ=="), None);
    }

    #[test]
    fn valid_token_yields_sub_identity() {
        let identity = run(&seeded_validator(), &base_token(NOW)).unwrap();
        assert_eq!(identity.username, "alice");
        assert_eq!(identity.role, Role::Read);
    }

    #[test]
    fn issuer_mismatch_rejected() {
        let mut token = base_token(NOW);
        token.claims.iss = Some("https://evil.example.com".to_string());
        let outcome = run(&seeded_validator(), &token);
        assert!(matches!(outcome, Err(OAuthError::WrongIssuer { .. })));
    }

    #[test]
    fn audience_mismatch_rejected() {
        let mut token = base_token(NOW);
        token.claims.aud = vec!["other".to_string()];
        let outcome = run(&seeded_validator(), &token);
        assert!(matches!(outcome, Err(OAuthError::WrongAudience { .. })));
    }

    #[test]
    fn expired_token_rejected() {
        let mut token = base_token(NOW);
        token.claims.exp = Some(1_600_000_000);
        let outcome = run(&seeded_validator(), &token);
        assert!(matches!(outcome, Err(OAuthError::Expired { .. })));
    }

    #[test]
    fn not_yet_valid_rejected() {
        let mut token = base_token(NOW);
        token.claims.nbf = Some(1_800_000_000);
        let outcome = run(&seeded_validator(), &token);
        assert!(matches!(outcome, Err(OAuthError::NotYetValid { .. })));
    }

    #[test]
    fn missing_jwk_fails_signature() {
        // Deliberately no `set_jwks`: the key set stays empty.
        let unseeded = OAuthValidator::with_verifier(base_config(), noop_verifier());
        let outcome = run(&unseeded, &base_token(NOW));
        assert!(matches!(outcome, Err(OAuthError::BadSignature(_))));
    }

    #[test]
    fn role_claim_parses_from_extra() {
        let validator = validator_with(OAuthConfig {
            role_claim: Some("role".to_string()),
            ..base_config()
        });
        let mut token = base_token(NOW);
        token.claims.extra.insert("role".to_string(), "admin".to_string());
        let identity = run(&validator, &token).unwrap();
        assert_eq!(identity.role, Role::Admin);
    }

    #[test]
    fn claim_field_identity_mode() {
        let validator = validator_with(OAuthConfig {
            identity_mode: OAuthIdentityMode::ClaimField("preferred_username".into()),
            ..base_config()
        });
        let mut token = base_token(NOW);
        token
            .claims
            .extra
            .insert("preferred_username".into(), "alice.smith".into());
        let identity = run(&validator, &token).unwrap();
        assert_eq!(identity.username, "alice.smith");
    }

    #[test]
    fn tenant_claim_extracted_when_configured() {
        let validator = validator_with(OAuthConfig {
            tenant_claim: Some("tenant".into()),
            ..base_config()
        });
        let mut token = base_token(NOW);
        token.claims.extra.insert("tenant".into(), "acme".into());
        let identity = run(&validator, &token).unwrap();
        assert_eq!(identity.tenant.as_deref(), Some("acme"));
    }

    #[test]
    fn tenant_absent_when_claim_unconfigured() {
        let mut token = base_token(NOW);
        token.claims.extra.insert("tenant".into(), "acme".into());
        let identity = run(&seeded_validator(), &token).unwrap();
        assert!(identity.tenant.is_none());
    }

    #[test]
    fn tenant_claim_custom_name() {
        let validator = validator_with(OAuthConfig {
            tenant_claim: Some("org_id".into()),
            ..base_config()
        });
        let mut token = base_token(NOW);
        token.claims.extra.insert("org_id".into(), "globex".into());
        let identity = run(&validator, &token).unwrap();
        assert_eq!(identity.tenant.as_deref(), Some("globex"));
    }
}