use super::ClaimMappings;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Selects which token validator is used and carries that validator's settings.
///
/// Serialized as an internally tagged object: the `"type"` field names the
/// variant and the remaining fields belong to the variant's config struct.
/// The tags are `"jwt"`, `"introspection"`, `"proxy"` and `"mock"` (all derived
/// by `rename_all = "snake_case"`), plus `"none"` for [`TokenValidatorConfig::Disabled`].
/// `Disabled` is also the [`Default`] variant.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TokenValidatorConfig {
    /// JWT validation settings (tag `"jwt"`).
    Jwt(JwtValidatorConfig),
    /// Introspection validator settings (tag `"introspection"`).
    Introspection(IntrospectionValidatorConfig),
    /// Proxy validator settings (tag `"proxy"`).
    Proxy(ProxyValidatorConfig),
    /// Mock validator settings (tag `"mock"`).
    Mock(MockValidatorConfig),
    /// No validator configured. Tagged `"none"`, overriding the `"disabled"`
    /// that `rename_all` would otherwise generate.
    #[serde(rename = "none")]
    #[default]
    Disabled,
}
impl TokenValidatorConfig {
pub fn jwt(issuer: impl Into<String>, audience: impl Into<String>) -> Self {
Self::Jwt(JwtValidatorConfig {
issuer: issuer.into(),
audience: audience.into(),
jwks_uri: None,
algorithms: default_algorithms(),
jwks_cache_ttl: default_jwks_ttl(),
claim_mappings: ClaimMappings::default(),
leeway_seconds: 60,
})
}
pub fn mock(default_user_id: impl Into<String>) -> Self {
Self::Mock(MockValidatorConfig {
default_user_id: default_user_id.into(),
default_tenant_id: None,
default_scopes: vec!["read".to_string(), "write".to_string()],
default_client_id: Some("mock-client".to_string()),
claims: serde_json::Value::Object(serde_json::Map::new()),
always_authenticated: true,
})
}
pub fn disabled() -> Self {
Self::Disabled
}
pub fn requires_auth(&self) -> bool {
!matches!(self, Self::Disabled)
}
}
/// Settings for the JWT validator variant of `TokenValidatorConfig`.
///
/// Field declaration order is also the serialized field order. Fields marked
/// `#[serde(default ...)]` may be omitted from configuration input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtValidatorConfig {
    /// Issuer URL. When `jwks_uri` is unset, the JWKS endpoint is derived from
    /// it (see `JwtValidatorConfig::jwks_uri()`).
    pub issuer: String,
    /// Audience value (presumably compared against the token's `aud` claim —
    /// confirm in the validator).
    pub audience: String,
    /// Explicit JWKS endpoint; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,
    /// Accepted algorithm names; `default_algorithms()` supplies `["RS256"]`
    /// when absent.
    #[serde(default = "default_algorithms")]
    pub algorithms: Vec<String>,
    /// JWKS cache lifetime in seconds (converted by `cache_ttl()`);
    /// `default_jwks_ttl()` supplies 3600 when absent.
    #[serde(default = "default_jwks_ttl")]
    pub jwks_cache_ttl: u64,
    /// Claim mappings; `ClaimMappings::default()` when absent.
    #[serde(default)]
    pub claim_mappings: ClaimMappings,
    /// Leeway in seconds; `default_leeway()` supplies 60 when absent.
    /// NOTE(review): presumably clock-skew tolerance for time-based claims — confirm.
    #[serde(default = "default_leeway")]
    pub leeway_seconds: u64,
}
impl JwtValidatorConfig {
    /// Returns the JWKS endpoint for this configuration.
    ///
    /// Uses the explicit `jwks_uri` when set. Otherwise derives
    /// `{issuer}/.well-known/jwks.json`, first trimming trailing `/` from the
    /// issuer so the join never produces `//`.
    pub fn jwks_uri(&self) -> String {
        self.jwks_uri.clone().unwrap_or_else(|| {
            format!(
                "{}/.well-known/jwks.json",
                self.issuer.trim_end_matches('/')
            )
        })
    }

    /// Returns `jwks_cache_ttl` (seconds) as a [`Duration`].
    pub fn cache_ttl(&self) -> Duration {
        Duration::from_secs(self.jwks_cache_ttl)
    }

    /// Config for an AWS Cognito user pool.
    ///
    /// The issuer is `https://cognito-idp.{region}.amazonaws.com/{user_pool_id}`
    /// and `client_id` becomes the audience. `jwks_uri` is left unset, so
    /// [`Self::jwks_uri`] derives `{issuer}/.well-known/jwks.json`.
    pub fn cognito(region: &str, user_pool_id: &str, client_id: &str) -> Self {
        Self {
            issuer: format!(
                "https://cognito-idp.{}.amazonaws.com/{}",
                region, user_pool_id
            ),
            audience: client_id.to_string(),
            jwks_uri: None,
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::cognito(),
            leeway_seconds: default_leeway(),
        }
    }

    /// Config for a Microsoft Entra tenant using the v2.0 endpoints.
    pub fn entra(tenant_id: &str, audience: &str) -> Self {
        Self {
            issuer: format!("https://login.microsoftonline.com/{}/v2.0", tenant_id),
            audience: audience.to_string(),
            jwks_uri: Some(format!(
                "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
                tenant_id
            )),
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::entra(),
            leeway_seconds: default_leeway(),
        }
    }

    /// Config for Google-issued tokens; `client_id` becomes the audience.
    pub fn google(client_id: &str) -> Self {
        Self {
            issuer: "https://accounts.google.com".to_string(),
            audience: client_id.to_string(),
            jwks_uri: Some("https://www.googleapis.com/oauth2/v3/certs".to_string()),
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::google(),
            leeway_seconds: default_leeway(),
        }
    }

    /// Config for an Okta domain.
    ///
    /// `domain` may be a bare host (`dev-123.okta.com`) or carry an
    /// `http(s)://` scheme and trailing slashes; both produce issuer
    /// `https://{host}` and JWKS `https://{host}/oauth2/v1/keys`.
    /// NOTE(review): the issuer has no `/oauth2/{server}` path, so custom
    /// authorization servers are not covered by this constructor.
    pub fn okta(domain: &str, audience: &str) -> Self {
        let domain = Self::bare_domain(domain);
        Self {
            issuer: format!("https://{}", domain),
            audience: audience.to_string(),
            jwks_uri: Some(format!("https://{}/oauth2/v1/keys", domain)),
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::okta(),
            leeway_seconds: default_leeway(),
        }
    }

    /// Config for an Auth0 tenant domain.
    ///
    /// `domain` is normalized like in [`Self::okta`]. The issuer keeps its
    /// trailing slash (`https://{host}/`); the JWKS endpoint is
    /// `https://{host}/.well-known/jwks.json`.
    pub fn auth0(domain: &str, audience: &str) -> Self {
        let domain = Self::bare_domain(domain);
        Self {
            issuer: format!("https://{}/", domain),
            audience: audience.to_string(),
            jwks_uri: Some(format!("https://{}/.well-known/jwks.json", domain)),
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::auth0(),
            leeway_seconds: default_leeway(),
        }
    }

    /// Reduces a user-supplied domain to its bare host.
    ///
    /// Trims whitespace, strips an optional `https://` / `http://` scheme and
    /// removes trailing `/`, so URLs are not built as `https://https://host/`.
    fn bare_domain(domain: &str) -> &str {
        let domain = domain.trim();
        let domain = domain
            .strip_prefix("https://")
            .or_else(|| domain.strip_prefix("http://"))
            .unwrap_or(domain);
        domain.trim_end_matches('/')
    }
}
/// Settings for the introspection validator variant of `TokenValidatorConfig`.
///
/// NOTE(review): presumably targets an OAuth 2.0 token-introspection endpoint
/// (RFC 7662) — confirm against the validator implementation.
///
/// `Debug` is implemented by hand so that `client_secret` and header values
/// (which may carry credentials) never appear in `{:?}` log output.
#[derive(Clone, Serialize, Deserialize)]
pub struct IntrospectionValidatorConfig {
    /// Endpoint URL.
    pub url: String,
    /// Optional client id; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Optional client secret; omitted from serialized output when `None` and
    /// redacted from `Debug` output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Extra headers (presumably sent with each request — confirm); empty when
    /// absent. Only the header names appear in `Debug` output.
    #[serde(default)]
    pub headers: std::collections::HashMap<String, String>,
    /// Timeout in seconds; `default_timeout()` supplies 10 when absent.
    #[serde(default = "default_timeout")]
    pub timeout_seconds: u64,
    /// Claim mappings; `ClaimMappings::default()` when absent.
    #[serde(default)]
    pub claim_mappings: ClaimMappings,
}

impl std::fmt::Debug for IntrospectionValidatorConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IntrospectionValidatorConfig")
            .field("url", &self.url)
            .field("client_id", &self.client_id)
            // Show only whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            // Header values may be API keys or bearer tokens; list names only.
            .field("headers", &self.headers.keys().collect::<Vec<_>>())
            .field("timeout_seconds", &self.timeout_seconds)
            .field("claim_mappings", &self.claim_mappings)
            .finish()
    }
}
/// Settings for the proxy validator variant of `TokenValidatorConfig`.
///
/// Field declaration order is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyValidatorConfig {
    /// Endpoint URL (presumably the service that performs validation — confirm).
    pub url: String,
    /// Header names (presumably forwarded to `url` — confirm); empty when absent.
    #[serde(default)]
    pub forward_headers: Vec<String>,
    /// Timeout in seconds; `default_timeout()` supplies 10 when absent.
    #[serde(default = "default_timeout")]
    pub timeout_seconds: u64,
}
/// Settings for the mock validator variant of `TokenValidatorConfig`.
///
/// Field declaration order is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MockValidatorConfig {
    /// User id reported by the mock; required in configuration input.
    pub default_user_id: String,
    /// Optional tenant id; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_tenant_id: Option<String>,
    /// Scopes reported by the mock. Empty when absent from input.
    /// NOTE(review): `MockValidatorConfig::default()` uses `["read", "write"]`
    /// instead — confirm which is intended for deserialized configs.
    #[serde(default)]
    pub default_scopes: Vec<String>,
    /// Optional client id; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_client_id: Option<String>,
    /// Extra claims. Defaults to an empty JSON object when absent. A plain
    /// `#[serde(default)]` would give `Value::Null`, disagreeing with
    /// `Default::default()` and `TokenValidatorConfig::mock`.
    #[serde(default = "default_mock_claims")]
    pub claims: serde_json::Value,
    /// `default_always_auth()` supplies `true` when absent.
    #[serde(default = "default_always_auth")]
    pub always_authenticated: bool,
}

/// Serde default for `MockValidatorConfig::claims`: an empty JSON object.
fn default_mock_claims() -> serde_json::Value {
    serde_json::Value::Object(serde_json::Map::new())
}
impl Default for MockValidatorConfig {
    /// Mock identity `mock-user` with client `mock-client`.
    ///
    /// Holds the `read` and `write` scopes, has no tenant and an empty claims
    /// object, and sets `always_authenticated` to `true`.
    fn default() -> Self {
        let default_scopes: Vec<String> = ["read", "write"]
            .iter()
            .copied()
            .map(String::from)
            .collect();
        let claims = serde_json::Value::Object(serde_json::Map::default());

        Self {
            always_authenticated: true,
            claims,
            default_client_id: Some(String::from("mock-client")),
            default_scopes,
            default_tenant_id: None,
            default_user_id: String::from("mock-user"),
        }
    }
}
/// Serde default for `JwtValidatorConfig::algorithms`: only `RS256`.
fn default_algorithms() -> Vec<String> {
    vec![String::from("RS256")]
}

/// Serde default for `JwtValidatorConfig::jwks_cache_ttl`: one hour, in seconds.
fn default_jwks_ttl() -> u64 {
    60 * 60
}

/// Serde default for `JwtValidatorConfig::leeway_seconds`: one minute.
fn default_leeway() -> u64 {
    60
}

/// Serde default for `timeout_seconds` on the introspection and proxy configs.
fn default_timeout() -> u64 {
    10
}

/// Serde default for `MockValidatorConfig::always_authenticated`.
fn default_always_auth() -> bool {
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jwt_config_creation() {
        let config = TokenValidatorConfig::jwt(
            "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx",
            "my-client-id",
        );
        match config {
            TokenValidatorConfig::Jwt(jwt) => {
                assert_eq!(
                    jwt.issuer,
                    "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx"
                );
                assert_eq!(jwt.audience, "my-client-id");
                assert!(jwt.jwks_uri.is_none());
            },
            _ => panic!("Expected JWT config"),
        }
    }

    #[test]
    fn test_jwt_constructor_uses_default_leeway() {
        match TokenValidatorConfig::jwt("https://issuer.example.com", "audience") {
            TokenValidatorConfig::Jwt(jwt) => {
                assert_eq!(jwt.leeway_seconds, default_leeway());
                assert_eq!(jwt.algorithms, default_algorithms());
                assert_eq!(jwt.jwks_cache_ttl, default_jwks_ttl());
            },
            _ => panic!("Expected JWT config"),
        }
    }

    #[test]
    fn test_mock_config_creation() {
        let config = TokenValidatorConfig::mock("test-user");
        match config {
            TokenValidatorConfig::Mock(mock) => {
                assert_eq!(mock.default_user_id, "test-user");
                assert!(mock.always_authenticated);
                assert_eq!(mock.default_scopes, vec!["read", "write"]);
                assert_eq!(mock.default_client_id.as_deref(), Some("mock-client"));
                assert!(mock.claims.is_object());
            },
            _ => panic!("Expected Mock config"),
        }
    }

    #[test]
    fn test_requires_auth() {
        assert!(TokenValidatorConfig::mock("u").requires_auth());
        assert!(TokenValidatorConfig::jwt("https://i.example.com", "a").requires_auth());
        assert!(!TokenValidatorConfig::disabled().requires_auth());
        assert!(!TokenValidatorConfig::default().requires_auth());
    }

    #[test]
    fn test_jwks_uri_derivation() {
        let config = JwtValidatorConfig {
            issuer: "https://issuer.example.com".to_string(),
            audience: "audience".to_string(),
            jwks_uri: None,
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::default(),
            leeway_seconds: default_leeway(),
        };
        assert_eq!(
            config.jwks_uri(),
            "https://issuer.example.com/.well-known/jwks.json"
        );
    }

    #[test]
    fn test_jwks_uri_trailing_slash_and_override() {
        let mut config = JwtValidatorConfig {
            issuer: "https://issuer.example.com/".to_string(),
            audience: "audience".to_string(),
            jwks_uri: None,
            algorithms: default_algorithms(),
            jwks_cache_ttl: default_jwks_ttl(),
            claim_mappings: ClaimMappings::default(),
            leeway_seconds: default_leeway(),
        };
        assert_eq!(
            config.jwks_uri(),
            "https://issuer.example.com/.well-known/jwks.json"
        );
        config.jwks_uri = Some("https://keys.example.com/jwks".to_string());
        assert_eq!(config.jwks_uri(), "https://keys.example.com/jwks");
        assert_eq!(config.cache_ttl(), Duration::from_secs(3600));
    }

    #[test]
    fn test_provider_specific_configs() {
        let cognito = JwtValidatorConfig::cognito("us-east-1", "us-east-1_xxxxx", "client-id");
        assert!(cognito.issuer.contains("cognito-idp"));
        let entra = JwtValidatorConfig::entra("tenant-id", "api://my-api");
        assert!(entra.issuer.contains("microsoftonline"));
        let google = JwtValidatorConfig::google("client-id.apps.googleusercontent.com");
        assert_eq!(google.issuer, "https://accounts.google.com");
    }

    #[test]
    fn test_cognito_okta_auth0_uris() {
        let cognito = JwtValidatorConfig::cognito("us-east-1", "us-east-1_xxxxx", "client-id");
        assert_eq!(
            cognito.jwks_uri(),
            "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/.well-known/jwks.json"
        );

        let okta = JwtValidatorConfig::okta("dev-123.okta.com", "api://default");
        assert_eq!(okta.issuer, "https://dev-123.okta.com");
        assert_eq!(okta.jwks_uri(), "https://dev-123.okta.com/oauth2/v1/keys");

        let auth0 = JwtValidatorConfig::auth0("tenant.auth0.com", "https://api.example.com");
        assert_eq!(auth0.issuer, "https://tenant.auth0.com/");
        assert_eq!(
            auth0.jwks_uri(),
            "https://tenant.auth0.com/.well-known/jwks.json"
        );
    }

    #[test]
    fn test_config_serialization() {
        let config = TokenValidatorConfig::jwt("https://issuer.example.com", "audience");
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"jwt\""));
        let deserialized: TokenValidatorConfig = serde_json::from_str(&json).unwrap();
        match deserialized {
            TokenValidatorConfig::Jwt(jwt) => {
                assert_eq!(jwt.issuer, "https://issuer.example.com");
            },
            _ => panic!("Expected JWT config"),
        }
    }

    #[test]
    fn test_disabled_serializes_as_none_tag() {
        let json = serde_json::to_string(&TokenValidatorConfig::Disabled).unwrap();
        assert_eq!(json, r#"{"type":"none"}"#);
        let back: TokenValidatorConfig = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, TokenValidatorConfig::Disabled));
    }

    #[test]
    fn test_jwt_deserialization_applies_defaults() {
        let json = r#"{"type":"jwt","issuer":"https://i.example.com","audience":"a"}"#;
        match serde_json::from_str::<TokenValidatorConfig>(json).unwrap() {
            TokenValidatorConfig::Jwt(jwt) => {
                assert!(jwt.jwks_uri.is_none());
                assert_eq!(jwt.algorithms, vec!["RS256".to_string()]);
                assert_eq!(jwt.jwks_cache_ttl, 3600);
                assert_eq!(jwt.leeway_seconds, 60);
            },
            _ => panic!("Expected JWT config"),
        }
    }
}