use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
/// Access level granted to an authenticated subject.
///
/// Variants are declared from least to most privileged, so the derived `Ord` ranks
/// `Reader < Writer < Admin`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Role {
    /// Lowest level: `can_write` and `can_admin` are both `false`. Also the fallback used by
    /// `AuthConfig::role_for` when nothing else applies.
    Reader,
    /// `can_write` is `true`, `can_admin` is `false`.
    Writer,
    /// Highest level: `can_write` and `can_admin` are both `true`.
    Admin,
}
impl Role {
    /// Returns `true` for roles that are allowed to write (`Writer` and `Admin`).
    pub fn can_write(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Role::Reader => false,
            Role::Writer | Role::Admin => true,
        }
    }

    /// Returns `true` only for `Admin`.
    pub fn can_admin(self) -> bool {
        self == Role::Admin
    }
}
/// Authentication and authorization settings.
///
/// `methods` lists the accepted authentication mechanisms. `roles` and `default_role`
/// together form the optional role-based access-control policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Methods tried in order by `validate_request`. When empty, `validate_request`
    /// returns `Ok(None)` without inspecting any credentials.
    pub methods: Vec<AuthMethodConfig>,
    /// Explicit subject-to-role assignments; an empty map when omitted from the config.
    #[serde(default)]
    pub roles: HashMap<String, Role>,
    /// Role for subjects not listed in `roles`; `None` when omitted from the config.
    #[serde(default)]
    pub default_role: Option<Role>,
}
impl AuthConfig {
    /// Reports whether an RBAC policy is configured: either a default role or at least
    /// one explicit role assignment.
    pub fn rbac_enabled(&self) -> bool {
        self.default_role.is_some() || !self.roles.is_empty()
    }

    /// Resolves the role for `subject`.
    ///
    /// Precedence: an explicit entry in `roles`, then `default_role`, and finally
    /// `Role::Reader` as the least-privileged fallback.
    pub fn role_for(&self, subject: &str) -> Role {
        match self.roles.get(subject) {
            Some(&assigned) => assigned,
            None => self.default_role.unwrap_or(Role::Reader),
        }
    }
}
/// One accepted authentication mechanism.
///
/// Uses serde's default (externally tagged) enum representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuthMethodConfig {
    /// Bearer-token validation against `{issuer_url}/userinfo`; the `sub` claim becomes
    /// the subject.
    Oidc { issuer_url: String },
    /// Client-certificate authentication; the certificate CN becomes the subject.
    ///
    /// NOTE(review): `validate_request` never reads `ca_cert`. Presumably chain
    /// verification against it happens in the TLS layer before the CN reaches this
    /// module. Confirm.
    Mtls { ca_cert: String },
}
impl Default for AuthConfig {
fn default() -> Self {
Self { methods: vec![], roles: HashMap::new(), default_role: None }
}
}
/// An authenticated caller, as returned by `validate_request`.
#[derive(Clone, Debug)]
pub struct AuthIdentity {
    /// Who the caller is: the OIDC `sub` claim or the client-certificate CN.
    pub subject: String,
    /// Which method authenticated the caller (`"oidc"` or `"mtls"` in this module).
    pub method: String,
}
/// Authenticates a request against the configured methods.
///
/// Methods are tried in configuration order, and the first one that accepts the request wins:
/// - an `Oidc` method needs `bearer_token` and a successful `validate_oidc_token` call;
///   a failed validation moves on to the next method;
/// - an `Mtls` method accepts any supplied `client_cert_cn` as-is.
///
/// Returns `Ok(None)` when no methods are configured (authentication disabled), and
/// `Ok(Some(identity))` when a method accepted the request.
///
/// # Errors
///
/// Returns `AuthError::Unauthorized` when methods are configured but none accepted the request.
pub async fn validate_request(
    config: &AuthConfig,
    bearer_token: Option<&str>,
    client_cert_cn: Option<&str>,
) -> Result<Option<AuthIdentity>, AuthError> {
    if config.methods.is_empty() {
        return Ok(None);
    }
    for candidate in config.methods.iter() {
        let identity = match candidate {
            AuthMethodConfig::Oidc { issuer_url } => {
                let Some(token) = bearer_token else { continue };
                let Ok(subject) = validate_oidc_token(issuer_url, token).await else {
                    continue;
                };
                AuthIdentity { subject, method: "oidc".into() }
            }
            AuthMethodConfig::Mtls { .. } => {
                let Some(cn) = client_cert_cn else { continue };
                AuthIdentity { subject: cn.to_owned(), method: "mtls".into() }
            }
        };
        return Ok(Some(identity));
    }
    Err(AuthError::Unauthorized)
}
/// Error returned when a request cannot be authenticated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The supplied credentials were missing or were rejected by every configured method.
    Unauthorized,
}

// `Display` + `Error` let callers format the error and propagate it with `?` into
// `Box<dyn Error>` / `anyhow`-style error types.
impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthError::Unauthorized => f.write_str("unauthorized"),
        }
    }
}

impl std::error::Error for AuthError {}
/// Default lifetime of a cached OIDC validation result.
const OIDC_CACHE_TTL: Duration = Duration::from_secs(120);

/// Returns the OIDC cache TTL.
///
/// Reads `HOLGER_OIDC_CACHE_TTL_SECS` (whole seconds, surrounding whitespace ignored) on
/// every call. Falls back to `OIDC_CACHE_TTL` when the variable is unset, not valid Unicode,
/// or not a `u64`.
fn oidc_cache_ttl() -> Duration {
    match std::env::var("HOLGER_OIDC_CACHE_TTL_SECS") {
        Ok(raw) => raw
            .trim()
            .parse::<u64>()
            .map_or(OIDC_CACHE_TTL, Duration::from_secs),
        Err(_) => OIDC_CACHE_TTL,
    }
}
/// Shared HTTP client for OIDC `/userinfo` calls, built once per process.
///
/// Configured with a 10 s request timeout and a 90 s idle-connection timeout. If the builder
/// fails, it falls back to `reqwest::Client::new()`, which does not apply those timeouts.
fn oidc_client() -> &'static reqwest::Client {
    static SHARED: OnceLock<reqwest::Client> = OnceLock::new();
    SHARED.get_or_init(|| {
        let builder = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .pool_idle_timeout(Duration::from_secs(90));
        match builder.build() {
            Ok(client) => client,
            Err(_) => reqwest::Client::new(),
        }
    })
}
/// Cache key for a validated OIDC token: `(issuer_url, BLAKE3 digest of the token)`.
type OidcCacheKey = (String, [u8; 32]);

/// Process-wide cache of successful OIDC validations.
///
/// Maps each key to the validated subject and the instant it was stored. It is created
/// lazily on first use.
fn oidc_cache() -> &'static Mutex<HashMap<OidcCacheKey, (String, Instant)>> {
    static ENTRIES: OnceLock<Mutex<HashMap<OidcCacheKey, (String, Instant)>>> = OnceLock::new();
    ENTRIES.get_or_init(Mutex::default)
}
/// BLAKE3 digest of `token`. It forms part of the OIDC cache key, so the raw token itself is
/// never stored in the cache.
fn token_digest(token: &str) -> [u8; 32] {
    let digest = blake3::hash(token.as_bytes());
    *digest.as_bytes()
}
/// Validates an OIDC bearer token by calling the issuer's `/userinfo` endpoint and returns
/// the `sub` claim from the JSON response.
///
/// Successful validations are cached per `(issuer_url, BLAKE3(token))` for `oidc_cache_ttl()`.
/// Repeated requests with the same token therefore skip the network round-trip. A TTL of
/// zero disables the cache entirely.
///
/// # Errors
///
/// Returns `AuthError::Unauthorized` in any of these cases:
/// - the HTTP request fails;
/// - the endpoint answers with a non-success status;
/// - the body is not JSON;
/// - `sub` is missing, not a string, or empty.
async fn validate_oidc_token(issuer_url: &str, token: &str) -> Result<String, AuthError> {
    let ttl = oidc_cache_ttl();
    // With a zero TTL every entry is stale on arrival, so reading or writing the cache is wasted work.
    let use_cache = !ttl.is_zero();
    let key: OidcCacheKey = (issuer_url.to_string(), token_digest(token));

    if use_cache {
        // Recover a poisoned lock instead of bypassing the cache forever. The map only holds
        // owned strings and timestamps, so a panic elsewhere cannot leave it half-updated.
        // The guard is dropped at the end of this block, before the first `.await` below.
        let cache = oidc_cache().lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some((subject, inserted)) = cache.get(&key) {
            if inserted.elapsed() < ttl {
                return Ok(subject.clone());
            }
        }
    }

    let userinfo_url = format!("{}/userinfo", issuer_url.trim_end_matches('/'));
    let resp = oidc_client()
        .get(&userinfo_url)
        .header("Authorization", format!("Bearer {}", token))
        .send()
        .await
        .map_err(|_| AuthError::Unauthorized)?;
    if !resp.status().is_success() {
        return Err(AuthError::Unauthorized);
    }
    let body: serde_json::Value = resp.json().await.map_err(|_| AuthError::Unauthorized)?;
    // An empty `sub` identifies nobody. Treat it like a missing claim rather than letting it
    // resolve to a role through the default-role fallback.
    let subject = body
        .get("sub")
        .and_then(|v| v.as_str())
        .filter(|s| !s.is_empty())
        .map(str::to_owned)
        .ok_or(AuthError::Unauthorized)?;

    if use_cache {
        let mut cache = oidc_cache().lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        // Evicting expired entries on each insert keeps the map bounded by the number of
        // distinct tokens validated within one TTL window.
        cache.retain(|_, (_, inserted)| inserted.elapsed() < ttl);
        cache.insert(key, (subject.clone(), Instant::now()));
    }
    Ok(subject)
}
#[cfg(test)]
mod tests {
// Unit tests for role logic, RBAC resolution, and the OIDC validation cache.
// NOTE: the OIDC cache is a process-wide static shared by every test in this module, so each
// OIDC test uses token strings that no other test uses. The hit-count assertions also assume
// HOLGER_OIDC_CACHE_TTL_SECS is unset or non-zero; with a zero TTL nothing is ever a cache hit.
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
// Starts a minimal fake OIDC provider on an ephemeral localhost port and returns its base URL,
// for use as the issuer URL. Each accepted connection:
// - increments `hits`;
// - has up to 1024 request bytes read and discarded;
// - receives a fixed `200 OK` response with body `{"sub":"alice"}` and `Connection: close`.
// The request path and headers are not inspected.
async fn spawn_userinfo(hits: Arc<AtomicUsize>) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
while let Ok((mut sock, _)) = listener.accept().await {
hits.fetch_add(1, Ordering::SeqCst);
let mut buf = [0u8; 1024];
// Consume (and ignore) the request, then prepare the canned userinfo body.
let _ = sock.read(&mut buf).await; let body = br#"{"sub":"alice"}"#;
let head = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
body.len()
);
let _ = sock.write_all(head.as_bytes()).await;
let _ = sock.write_all(body).await;
let _ = sock.flush().await;
}
});
format!("http://{addr}")
}
// Validating the same token twice must reach the fake /userinfo only once: the second call is
// served from the cache.
#[tokio::test]
async fn oidc_validation_caches_and_skips_repeat_userinfo_call() {
let hits = Arc::new(AtomicUsize::new(0));
let token = "tok-w1-cache-test-7f3a9c";
let issuer = spawn_userinfo(hits.clone()).await;
let s1 = validate_oidc_token(&issuer, token).await.expect("first validate ok");
let s2 = validate_oidc_token(&issuer, token).await.expect("second validate ok");
assert_eq!(s1, "alice");
assert_eq!(s2, "alice");
assert_eq!(
hits.load(Ordering::SeqCst),
1,
"second validation must come from cache — /userinfo should be hit once, not twice"
);
}
// The cache key includes the token digest, so two different tokens against the same issuer
// must each trigger their own /userinfo request.
#[tokio::test]
async fn oidc_distinct_token_is_not_a_cache_hit() {
let hits = Arc::new(AtomicUsize::new(0));
let issuer = spawn_userinfo(hits.clone()).await;
let _ = validate_oidc_token(&issuer, "tok-w1-distinct-A-11").await.expect("A ok");
let _ = validate_oidc_token(&issuer, "tok-w1-distinct-B-22").await.expect("B ok");
assert_eq!(
hits.load(Ordering::SeqCst),
2,
"two distinct tokens must each validate against /userinfo"
);
}
// Pins the derived ordering (declaration order) and the capability helpers.
#[test]
fn role_privilege_ordering_and_capabilities() {
assert!(Role::Reader < Role::Writer && Role::Writer < Role::Admin);
assert!(!Role::Reader.can_write());
assert!(Role::Writer.can_write());
assert!(Role::Admin.can_write());
assert!(!Role::Writer.can_admin());
assert!(Role::Admin.can_admin());
}
// RBAC counts as enabled once either an explicit role map or a default role is present.
#[test]
fn rbac_is_off_until_policy_is_configured() {
let open = AuthConfig::default();
assert!(!open.rbac_enabled());
let mut roles = HashMap::new();
roles.insert("alice".to_string(), Role::Writer);
let cfg = AuthConfig { methods: vec![], roles, default_role: None };
assert!(cfg.rbac_enabled());
let cfg2 = AuthConfig { methods: vec![], roles: HashMap::new(), default_role: Some(Role::Reader) };
assert!(cfg2.rbac_enabled());
}
// Resolution precedence: explicit mapping, then default_role, then Role::Reader as the
// least-privileged fallback.
#[test]
fn role_resolution_explicit_then_default_then_least_privilege() {
let mut roles = HashMap::new();
roles.insert("alice".to_string(), Role::Admin);
roles.insert("bot".to_string(), Role::Writer);
let cfg = AuthConfig { methods: vec![], roles, default_role: Some(Role::Reader) };
assert_eq!(cfg.role_for("alice"), Role::Admin);
assert_eq!(cfg.role_for("bot"), Role::Writer);
assert_eq!(cfg.role_for("stranger"), Role::Reader);
let cfg_nodef = AuthConfig {
methods: vec![],
roles: HashMap::from([("bot".to_string(), Role::Writer)]),
default_role: None,
};
assert_eq!(cfg_nodef.role_for("stranger"), Role::Reader);
assert!(!cfg_nodef.role_for("stranger").can_write(), "unmapped identity cannot write");
}
}