#[cfg(feature = "oidc")]
mod oidc_impl {
use jsonwebtoken::{
decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData, Validation,
};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use thiserror::Error;
use tracing::{debug, info};
/// Errors returned by OIDC provider configuration and token validation.
#[derive(Error, Debug)]
pub enum OidcError {
/// No registered provider matches the requested provider name or the token's `iss`.
#[error("Provider not found: {0}")]
ProviderNotFound(String),
/// The token could not be parsed or was rejected; the payload carries the reason.
#[error("Invalid token: {0}")]
InvalidToken(String),
/// The token's `exp` has passed, after allowing the provider's clock skew.
#[error("Token expired")]
TokenExpired,
/// The token's `iss` does not match the provider's configured issuer.
/// `actual` is `"unknown"` when the token's issuer was not available.
#[error("Invalid issuer: expected {expected}, got {actual}")]
InvalidIssuer { expected: String, actual: String },
/// The token's `aud` does not contain the provider's audience (carried in the payload).
#[error("Invalid audience: {0}")]
InvalidAudience(String),
/// A claim needed to build the identity is absent or has the wrong type, or a
/// configured required claim is absent.
#[error("Missing claim: {0}")]
MissingClaim(String),
/// Downloading the provider's JWKS failed (transport error or non-success status).
#[error("JWKS fetch error: {0}")]
JwksFetch(String),
/// The JWKS has no key matching the token's `kid`, or contains no keys at all.
#[error("Key not found for kid: {0}")]
KeyNotFound(String),
/// The authenticator could not be set up (e.g. the HTTP client failed to build).
#[error("Configuration error: {0}")]
Configuration(String),
/// Generic HTTP failure.
/// NOTE(review): not constructed in this module; JWKS download failures use `JwksFetch`.
#[error("HTTP error: {0}")]
Http(String),
/// A response body was not the expected JSON (used for malformed JWKS documents).
#[error("JSON error: {0}")]
Json(String),
}
/// Result alias used throughout the OIDC module.
pub type OidcResult<T> = Result<T, OidcError>;
/// Configuration for one OpenID Connect identity provider.
///
/// Deserializable with serde; omitted optional fields take the defaults noted below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcProviderConfig {
/// Unique name under which the provider is registered and its JWKS cached.
pub name: String,
/// Expected `iss` claim; tokens are routed to providers by exact match on this value.
pub issuer: String,
/// Expected `aud` claim.
pub audience: String,
/// Claim used as the identity's username; `.` separates nested object keys.
/// Default: `sub`.
#[serde(default = "default_username_claim")]
pub username_claim: String,
/// Optional claim (`.`-separated path) from which group names are read.
pub groups_claim: Option<String>,
/// Optional claim (`.`-separated path) from which role names are read.
pub roles_claim: Option<String>,
/// Claims that must be present in a validated token. Default: none.
#[serde(default)]
pub required_claims: Vec<String>,
/// Leeway in seconds applied to time-based claim checks. Default: 60.
#[serde(default = "default_clock_skew")]
pub clock_skew_seconds: u64,
/// How long a downloaded JWKS is reused, in seconds. Default: 3600.
#[serde(default = "default_jwks_cache_ttl")]
pub jwks_cache_ttl_seconds: u64,
/// Accepted JWS algorithm names (e.g. `RS256`, `ES256`); unrecognised names are
/// ignored. Default: `["RS256"]`.
#[serde(default = "default_algorithms")]
pub algorithms: Vec<String>,
/// Explicit JWKS endpoint; when `None` it is derived from `issuer` (see `jwks_url`).
pub jwks_uri: Option<String>,
/// When `false`, `OidcAuthenticator::add_provider` ignores this config. Default: `true`.
#[serde(default = "default_enabled")]
pub enabled: bool,
}
/// Serde default for `username_claim`: the standard `sub` claim.
fn default_username_claim() -> String {
    String::from("sub")
}
/// Serde default for `clock_skew_seconds`: one minute of leeway.
fn default_clock_skew() -> u64 {
    60
}
/// Serde default for `jwks_cache_ttl_seconds`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}
/// Serde default for `algorithms`: RS256 only.
fn default_algorithms() -> Vec<String> {
    vec![String::from("RS256")]
}
/// Serde default for `enabled`: providers are active unless switched off.
fn default_enabled() -> bool {
    true
}
impl OidcProviderConfig {
    /// Builds an enabled provider configuration for `issuer` / `audience`.
    ///
    /// Every optional field takes the same value its serde default would: username from
    /// `sub`, no groups/roles claims, no extra required claims, default clock skew, JWKS
    /// cache TTL and algorithm list, and a JWKS URL derived from the issuer.
    pub fn new(
        name: impl Into<String>,
        issuer: impl Into<String>,
        audience: impl Into<String>,
    ) -> Self {
        let name = name.into();
        let issuer = issuer.into();
        let audience = audience.into();
        Self {
            name,
            issuer,
            audience,
            username_claim: default_username_claim(),
            groups_claim: None,
            roles_claim: None,
            required_claims: Vec::new(),
            clock_skew_seconds: default_clock_skew(),
            jwks_cache_ttl_seconds: default_jwks_cache_ttl(),
            algorithms: default_algorithms(),
            jwks_uri: None,
            enabled: true,
        }
    }

    /// Reads group names from `claim` (a `.`-separated claim path).
    pub fn with_groups_claim(self, claim: impl Into<String>) -> Self {
        Self {
            groups_claim: Some(claim.into()),
            ..self
        }
    }

    /// Reads role names from `claim` (a `.`-separated claim path).
    pub fn with_roles_claim(self, claim: impl Into<String>) -> Self {
        Self {
            roles_claim: Some(claim.into()),
            ..self
        }
    }

    /// Uses `claim` instead of `sub` as the identity's username.
    pub fn with_username_claim(self, claim: impl Into<String>) -> Self {
        Self {
            username_claim: claim.into(),
            ..self
        }
    }

    /// Returns the JWKS endpoint: `jwks_uri` when set, otherwise
    /// `<issuer without trailing slashes>/.well-known/jwks.json`.
    ///
    /// NOTE(review): `/.well-known/jwks.json` is a common convention, not part of OIDC
    /// Discovery (which publishes `jwks_uri` in `/.well-known/openid-configuration`);
    /// providers that serve keys elsewhere need `jwks_uri` set explicitly.
    pub fn jwks_url(&self) -> String {
        match &self.jwks_uri {
            Some(explicit) => explicit.clone(),
            None => {
                let base = self.issuer.trim_end_matches('/');
                format!("{}/.well-known/jwks.json", base)
            }
        }
    }
}
/// Registered JWT claims read from every token, plus all remaining claims.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardClaims {
/// Subject identifier (`sub`); required for deserialization to succeed.
pub sub: String,
/// Issuer (`iss`); required for deserialization to succeed.
pub iss: String,
/// Audience (`aud`): a string or a list; a missing claim becomes an empty list.
#[serde(default)]
pub aud: ClaimValue,
/// Expiration time (`exp`, seconds since the Unix epoch); required.
pub exp: u64,
/// Issued-at time (`iat`), if present.
#[serde(default)]
pub iat: Option<u64>,
/// Not-before time (`nbf`), if present.
#[serde(default)]
pub nbf: Option<u64>,
/// Token identifier (`jti`), if present.
#[serde(default)]
pub jti: Option<String>,
/// Every other claim, keyed by name. Because of `flatten`, the claims captured by the
/// fields above do not appear here.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// A claim that may be encoded either as one string or as an array of strings, as the
/// JWT `aud` claim is allowed to be.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ClaimValue {
    /// The claim was a single JSON string.
    Single(String),
    /// The claim was a JSON array of strings.
    Multiple(Vec<String>),
}
impl Default for ClaimValue {
    /// An empty list, which is what a missing claim becomes under `#[serde(default)]`.
    fn default() -> Self {
        Self::Multiple(Vec::new())
    }
}
impl ClaimValue {
    /// Returns `true` if `value` equals the single value or any element of the list.
    pub fn contains(&self, value: &str) -> bool {
        let values: &[String] = match self {
            Self::Single(one) => std::slice::from_ref(one),
            Self::Multiple(many) => many,
        };
        values.iter().any(|candidate| candidate == value)
    }
}
/// The authenticated principal produced from a successfully validated token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcIdentity {
/// Value of the provider's configured username claim.
pub username: String,
/// The token's `sub` claim.
pub subject: String,
/// The token's `iss` claim.
pub issuer: String,
/// Name of the provider configuration that validated the token.
pub provider: String,
/// Group names from the provider's groups claim; empty when none were found.
pub groups: Vec<String>,
/// Role names from the provider's roles claim; empty when none were found.
pub roles: Vec<String>,
/// The `email` claim, when present as a string.
pub email: Option<String>,
/// The `name` claim, when present as a string.
pub name: Option<String>,
/// The token's `exp` claim (seconds since the Unix epoch).
pub expires_at: u64,
/// All claims other than the registered ones held in `StandardClaims` fields.
pub claims: HashMap<String, serde_json::Value>,
}
impl OidcIdentity {
    /// Returns `true` if `group` is one of the identity's groups (exact, case-sensitive).
    pub fn has_group(&self, group: &str) -> bool {
        self.groups.iter().any(|candidate| candidate.as_str() == group)
    }

    /// Returns `true` if `role` is one of the identity's roles (exact, case-sensitive).
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|candidate| candidate.as_str() == role)
    }

    /// Returns `true` if at least one of `groups` is held; `false` for an empty slice.
    pub fn has_any_group(&self, groups: &[&str]) -> bool {
        groups.iter().copied().any(|group| self.has_group(group))
    }

    /// Returns `true` if every one of `groups` is held; vacuously `true` when empty.
    pub fn has_all_groups(&self, groups: &[&str]) -> bool {
        !groups.iter().copied().any(|group| !self.has_group(group))
    }
}
/// A downloaded JWKS together with when it was fetched and how long it may be reused.
#[derive(Debug)]
struct CachedJwks {
    jwks: JwkSet,
    fetched_at: Instant,
    ttl: Duration,
}
impl CachedJwks {
    /// Returns `true` once strictly more than `ttl` has elapsed since `fetched_at`.
    fn is_expired(&self) -> bool {
        let age = self.fetched_at.elapsed();
        self.ttl < age
    }
}
/// Validates bearer JWTs against a set of registered OIDC providers, downloading each
/// provider's signing keys (JWKS) over HTTP and caching them per provider name.
pub struct OidcAuthenticator {
// Registered providers keyed by `OidcProviderConfig::name`.
providers: HashMap<String, OidcProviderConfig>,
// Downloaded key sets keyed by provider name. This is a synchronous `parking_lot`
// lock: its guards must not be held across `.await`.
jwks_cache: RwLock<HashMap<String, CachedJwks>>,
// Client used for JWKS downloads.
http_client: reqwest::Client,
}
impl OidcAuthenticator {
pub fn new() -> OidcResult<Self> {
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.map_err(|e| {
OidcError::Configuration(format!("Failed to create OIDC HTTP client: {}", e))
})?;
Ok(Self {
providers: HashMap::new(),
jwks_cache: RwLock::new(HashMap::new()),
http_client,
})
}
pub fn add_provider(&mut self, config: OidcProviderConfig) {
if config.enabled {
info!("Added OIDC provider: {} ({})", config.name, config.issuer);
self.providers.insert(config.name.clone(), config);
}
}
pub fn provider_names(&self) -> Vec<&str> {
self.providers.keys().map(|s| s.as_str()).collect()
}
pub async fn validate_token(&self, token: &str) -> OidcResult<OidcIdentity> {
let header = decode_header(token)
.map_err(|e| OidcError::InvalidToken(format!("Invalid header: {}", e)))?;
let insecure_claims = self.peek_claims(token)?;
let provider = self
.providers
.values()
.find(|p| p.issuer == insecure_claims.iss)
.ok_or_else(|| OidcError::ProviderNotFound(insecure_claims.iss.clone()))?;
self.validate_with_provider(token, provider, &header.kid)
.await
}
pub async fn validate_with_provider_name(
&self,
token: &str,
provider_name: &str,
) -> OidcResult<OidcIdentity> {
let provider = self
.providers
.get(provider_name)
.ok_or_else(|| OidcError::ProviderNotFound(provider_name.to_string()))?;
let header = decode_header(token)
.map_err(|e| OidcError::InvalidToken(format!("Invalid header: {}", e)))?;
self.validate_with_provider(token, provider, &header.kid)
.await
}
fn peek_claims(&self, token: &str) -> OidcResult<StandardClaims> {
let mut validation = Validation::default();
validation.insecure_disable_signature_validation();
validation.validate_exp = false;
validation.validate_aud = false;
let data: TokenData<StandardClaims> =
decode(token, &DecodingKey::from_secret(&[]), &validation)
.map_err(|e| OidcError::InvalidToken(format!("Cannot parse token: {}", e)))?;
Ok(data.claims)
}
async fn validate_with_provider(
&self,
token: &str,
provider: &OidcProviderConfig,
kid: &Option<String>,
) -> OidcResult<OidcIdentity> {
let jwks = self.get_jwks(provider).await?;
let key = if let Some(kid) = kid {
jwks.find(kid)
.ok_or_else(|| OidcError::KeyNotFound(kid.clone()))?
} else {
jwks.keys
.first()
.ok_or_else(|| OidcError::KeyNotFound("no keys in JWKS".to_string()))?
};
let decoding_key = DecodingKey::from_jwk(key)
.map_err(|e| OidcError::InvalidToken(format!("Invalid JWK: {}", e)))?;
let algorithms: Vec<Algorithm> = provider
.algorithms
.iter()
.filter_map(|a| match a.as_str() {
"RS256" => Some(Algorithm::RS256),
"RS384" => Some(Algorithm::RS384),
"RS512" => Some(Algorithm::RS512),
"ES256" => Some(Algorithm::ES256),
"ES384" => Some(Algorithm::ES384),
"PS256" => Some(Algorithm::PS256),
"PS384" => Some(Algorithm::PS384),
"PS512" => Some(Algorithm::PS512),
_ => None,
})
.collect();
let primary_alg = algorithms.first().copied().unwrap_or(Algorithm::RS256);
let mut validation = Validation::new(primary_alg);
if algorithms.len() > 1 {
validation.algorithms = algorithms;
}
validation.set_issuer(&[&provider.issuer]);
validation.set_audience(&[&provider.audience]);
validation.leeway = provider.clock_skew_seconds;
let data: TokenData<StandardClaims> = decode(token, &decoding_key, &validation)
.map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::TokenExpired,
jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
expected: provider.issuer.clone(),
actual: "unknown".to_string(),
},
jsonwebtoken::errors::ErrorKind::InvalidAudience => {
OidcError::InvalidAudience(provider.audience.clone())
}
_ => OidcError::InvalidToken(format!("{}", e)),
})?;
let claims = data.claims;
let username = self.extract_claim(&claims, &provider.username_claim)?;
let groups = provider
.groups_claim
.as_ref()
.and_then(|claim| self.extract_string_array(&claims, claim))
.unwrap_or_default();
let roles = provider
.roles_claim
.as_ref()
.and_then(|claim| self.extract_string_array(&claims, claim))
.unwrap_or_default();
let email = claims
.extra
.get("email")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let name = claims
.extra
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
for claim in &provider.required_claims {
if !claims.extra.contains_key(claim) {
return Err(OidcError::MissingClaim(claim.clone()));
}
}
debug!(
"Validated OIDC token for user '{}' from provider '{}'",
username, provider.name
);
Ok(OidcIdentity {
username,
subject: claims.sub,
issuer: claims.iss,
provider: provider.name.clone(),
groups,
roles,
email,
name,
expires_at: claims.exp,
claims: claims.extra,
})
}
async fn get_jwks(&self, provider: &OidcProviderConfig) -> OidcResult<JwkSet> {
{
let cache = self.jwks_cache.read();
if let Some(cached) = cache.get(&provider.name) {
if !cached.is_expired() {
return Ok(cached.jwks.clone());
}
}
}
let jwks_url = provider.jwks_url();
debug!("Fetching JWKS from: {}", jwks_url);
let response = self
.http_client
.get(&jwks_url)
.send()
.await
.map_err(|e| OidcError::JwksFetch(format!("HTTP error: {}", e)))?;
if !response.status().is_success() {
return Err(OidcError::JwksFetch(format!(
"HTTP {} from {}",
response.status(),
jwks_url
)));
}
let jwks: JwkSet = response
.json()
.await
.map_err(|e| OidcError::Json(format!("Invalid JWKS: {}", e)))?;
{
let mut cache = self.jwks_cache.write();
cache.insert(
provider.name.clone(),
CachedJwks {
jwks: jwks.clone(),
fetched_at: Instant::now(),
ttl: Duration::from_secs(provider.jwks_cache_ttl_seconds),
},
);
}
info!(
"Fetched and cached JWKS for provider '{}' ({} keys)",
provider.name,
jwks.keys.len()
);
Ok(jwks)
}
fn extract_claim(&self, claims: &StandardClaims, path: &str) -> OidcResult<String> {
match path {
"sub" => return Ok(claims.sub.clone()),
"iss" => return Ok(claims.iss.clone()),
_ => {}
}
let parts: Vec<&str> = path.split('.').collect();
let mut current: &serde_json::Value = claims
.extra
.get(parts[0])
.ok_or_else(|| OidcError::MissingClaim(path.to_string()))?;
for part in &parts[1..] {
current = current
.get(part)
.ok_or_else(|| OidcError::MissingClaim(path.to_string()))?;
}
current
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| OidcError::MissingClaim(format!("{} is not a string", path)))
}
fn extract_string_array(&self, claims: &StandardClaims, path: &str) -> Option<Vec<String>> {
let parts: Vec<&str> = path.split('.').collect();
let mut current: Option<&serde_json::Value> = claims.extra.get(parts[0]);
for part in &parts[1..] {
current = current.and_then(|v| v.get(part));
}
current.and_then(|v| {
v.as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
})
}
pub fn clear_cache(&self) {
let mut cache = self.jwks_cache.write();
cache.clear();
info!("Cleared OIDC JWKS cache");
}
pub fn clear_provider_cache(&self, provider_name: &str) {
let mut cache = self.jwks_cache.write();
cache.remove(provider_name);
info!("Cleared OIDC JWKS cache for provider: {}", provider_name);
}
}
impl Default for OidcAuthenticator {
fn default() -> Self {
Self::new().unwrap_or_else(|e| {
tracing::error!(
"OIDC authenticator default creation failed: {}. Using no-op.",
e
);
Self {
providers: HashMap::new(),
jwks_cache: RwLock::new(HashMap::new()),
http_client: reqwest::Client::new(),
}
})
}
}
/// SASL mechanism name for OAuth 2.0 bearer-token authentication (RFC 7628).
pub const OIDC_SASL_MECHANISM: &str = "OAUTHBEARER";
/// Extracts the bearer token from a client's OAUTHBEARER initial response (RFC 7628).
///
/// The message is a GS2 header followed by `\x01`-separated `key=value` pairs. The
/// token comes from the `auth` pair, whose value has the form `Bearer <token>`. The
/// scheme name is matched case-insensitively, as HTTP authentication schemes are, and
/// whitespace around the token is ignored.
///
/// # Errors
///
/// [`OidcError::InvalidToken`] if `data` is not UTF-8, or if no `auth` pair carries the
/// `Bearer` scheme with a non-empty token.
pub fn parse_oauthbearer_message(data: &[u8]) -> OidcResult<String> {
    let message = std::str::from_utf8(data)
        .map_err(|_| OidcError::InvalidToken("Invalid UTF-8 in SASL message".to_string()))?;
    message
        .split('\x01')
        .filter_map(|pair| pair.strip_prefix("auth="))
        .find_map(|value| {
            let (scheme, token) = value.split_once(' ')?;
            let token = token.trim();
            // An empty token is not a credential; keep looking instead of returning "".
            (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty())
                .then(|| token.to_string())
        })
        .ok_or_else(|| {
            OidcError::InvalidToken("No bearer token found in SASL message".to_string())
        })
}
/// Server payload for a successful OAUTHBEARER exchange, which is empty.
pub fn build_oauthbearer_success() -> Vec<u8> {
    Vec::new()
}
/// Builds the JSON payload a server sends when OAUTHBEARER authentication fails
/// (RFC 7628), e.g. `{"status":"invalid_token","scope":"read"}`.
///
/// `status` and the optional `scope` are written as escaped JSON strings, so values
/// containing quotes, backslashes or control characters cannot corrupt the payload or
/// inject extra members.
pub fn build_oauthbearer_error(status: &str, scope: Option<&str>) -> Vec<u8> {
    let mut response = String::from("{\"status\":");
    push_json_string(&mut response, status);
    if let Some(scope) = scope {
        response.push_str(",\"scope\":");
        push_json_string(&mut response, scope);
    }
    response.push('}');
    response.into_bytes()
}

/// Appends `value` to `out` as a double-quoted JSON string literal, escaping quotes,
/// backslashes and all control characters below U+0020 (RFC 8259 §7).
fn push_json_string(out: &mut String, value: &str) {
    out.push('"');
    for ch in value.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{08}' => out.push_str("\\b"),
            '\u{0C}' => out.push_str("\\f"),
            c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push('"');
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config() {
        let cfg = OidcProviderConfig::new("test", "https://auth.example.com", "test-client")
            .with_groups_claim("groups")
            .with_roles_claim("realm_access.roles");

        assert_eq!(cfg.name, "test");
        assert_eq!(cfg.issuer, "https://auth.example.com");
        assert_eq!(cfg.audience, "test-client");
        assert_eq!(cfg.groups_claim.as_deref(), Some("groups"));
        assert_eq!(cfg.roles_claim.as_deref(), Some("realm_access.roles"));
        // No explicit jwks_uri, so the URL is derived from the issuer.
        assert_eq!(
            cfg.jwks_url(),
            "https://auth.example.com/.well-known/jwks.json"
        );
    }

    #[test]
    fn test_claim_value() {
        let single = ClaimValue::Single("test".into());
        assert!(single.contains("test"));
        assert!(!single.contains("other"));

        let multiple = ClaimValue::Multiple(vec!["a".into(), "b".into()]);
        assert!(["a", "b"].iter().all(|v| multiple.contains(v)));
        assert!(!multiple.contains("c"));
    }

    #[test]
    fn test_identity_methods() {
        let identity = OidcIdentity {
            username: "alice".into(),
            subject: "alice-uuid".into(),
            issuer: "https://auth.example.com".into(),
            provider: "test".into(),
            groups: vec!["admins".into(), "developers".into()],
            roles: vec!["admin".into()],
            email: Some("alice@example.com".into()),
            name: Some("Alice".into()),
            expires_at: 9_999_999_999,
            claims: HashMap::new(),
        };

        for held in ["admins", "developers"].iter() {
            assert!(identity.has_group(held));
        }
        assert!(!identity.has_group("users"));

        assert!(identity.has_role("admin"));
        assert!(!identity.has_role("user"));

        assert!(identity.has_any_group(&["admins", "users"]));
        assert!(identity.has_all_groups(&["admins", "developers"]));
        assert!(!identity.has_all_groups(&["admins", "users"]));
    }

    #[test]
    fn test_parse_oauthbearer() {
        let raw: &[u8] = b"n,,\x01auth=Bearer eyJhbGciOiJSUzI1NiJ9.test\x01\x01";
        let token = parse_oauthbearer_message(raw).expect("well-formed message");
        assert_eq!(token, "eyJhbGciOiJSUzI1NiJ9.test");
    }

    #[test]
    fn test_parse_oauthbearer_invalid() {
        let raw: &[u8] = b"n,,\x01invalid\x01\x01";
        assert!(parse_oauthbearer_message(raw).is_err());
    }
}
}
#[cfg(feature = "oidc")]
pub use oidc_impl::*;
#[cfg(not(feature = "oidc"))]
mod no_oidc {
    //! Stand-ins compiled when the `oidc` feature is disabled. They mirror the public
    //! surface of the real implementation so callers compile either way, and every
    //! validation entry point fails with [`OidcError::NotEnabled`].
    //!
    //! NOTE(review): `OidcAuthenticator::new` returns `Self` here but `OidcResult<Self>`
    //! with the feature enabled; callers of `new` still need feature-specific handling.
    use std::collections::HashMap;
    use thiserror::Error;

    /// Error returned by every OIDC operation when the feature is disabled.
    #[derive(Error, Debug)]
    pub enum OidcError {
        /// OIDC support was not compiled in.
        #[error("OIDC not enabled. Build with 'oidc' feature.")]
        NotEnabled,
    }

    /// Result alias matching the feature-enabled `OidcResult`.
    pub type OidcResult<T> = Result<T, OidcError>;

    /// SASL mechanism name, kept so references to it compile without the feature.
    pub const OIDC_SASL_MECHANISM: &str = "OAUTHBEARER";

    /// Minimal provider configuration; without the feature it is never used.
    #[derive(Debug, Clone)]
    pub struct OidcProviderConfig {
        pub name: String,
        pub issuer: String,
        pub audience: String,
    }

    impl OidcProviderConfig {
        /// Mirrors the feature-enabled constructor.
        pub fn new(
            name: impl Into<String>,
            issuer: impl Into<String>,
            audience: impl Into<String>,
        ) -> Self {
            Self {
                name: name.into(),
                issuer: issuer.into(),
                audience: audience.into(),
            }
        }

        /// Accepted for API compatibility; the claim name is discarded.
        pub fn with_groups_claim(self, _claim: impl Into<String>) -> Self {
            self
        }

        /// Accepted for API compatibility; the claim name is discarded.
        pub fn with_roles_claim(self, _claim: impl Into<String>) -> Self {
            self
        }

        /// Accepted for API compatibility; the claim name is discarded.
        pub fn with_username_claim(self, _claim: impl Into<String>) -> Self {
            self
        }
    }

    /// Identity type mirroring the feature-enabled one; never produced by this stub.
    #[derive(Debug, Clone)]
    pub struct OidcIdentity {
        pub username: String,
        pub subject: String,
        pub issuer: String,
        pub provider: String,
        pub groups: Vec<String>,
        pub roles: Vec<String>,
        pub email: Option<String>,
        pub name: Option<String>,
        pub expires_at: u64,
        pub claims: HashMap<String, serde_json::Value>,
    }

    impl OidcIdentity {
        /// Returns `true` if `group` is one of the identity's groups.
        pub fn has_group(&self, group: &str) -> bool {
            self.groups.iter().any(|g| g == group)
        }

        /// Returns `true` if `role` is one of the identity's roles.
        pub fn has_role(&self, role: &str) -> bool {
            self.roles.iter().any(|r| r == role)
        }

        /// Returns `true` if at least one of `groups` is held.
        pub fn has_any_group(&self, groups: &[&str]) -> bool {
            groups.iter().any(|g| self.has_group(g))
        }

        /// Returns `true` if every one of `groups` is held.
        pub fn has_all_groups(&self, groups: &[&str]) -> bool {
            groups.iter().all(|g| self.has_group(g))
        }
    }

    /// Authenticator stub: registers nothing and rejects every token.
    pub struct OidcAuthenticator;

    impl OidcAuthenticator {
        /// Creates the stub authenticator.
        pub fn new() -> Self {
            Self
        }

        /// Accepted for API compatibility; the configuration is discarded.
        pub fn add_provider(&mut self, _config: OidcProviderConfig) {}

        /// Always empty: no providers can be registered without the feature.
        pub fn provider_names(&self) -> Vec<&str> {
            Vec::new()
        }

        /// Always fails with [`OidcError::NotEnabled`].
        pub async fn validate_token(&self, _token: &str) -> OidcResult<OidcIdentity> {
            Err(OidcError::NotEnabled)
        }

        /// Always fails with [`OidcError::NotEnabled`].
        pub async fn validate_with_provider_name(
            &self,
            _token: &str,
            _provider_name: &str,
        ) -> OidcResult<OidcIdentity> {
            Err(OidcError::NotEnabled)
        }

        /// No-op: there is no cache without the feature.
        pub fn clear_cache(&self) {}

        /// No-op: there is no cache without the feature.
        pub fn clear_provider_cache(&self, _provider_name: &str) {}
    }

    impl Default for OidcAuthenticator {
        fn default() -> Self {
            Self::new()
        }
    }
}
#[cfg(not(feature = "oidc"))]
pub use no_oidc::*;