use std::collections::HashSet;
use std::time::Duration;
use crate::claims::ClaimValidator;
use crate::common::error::{ErrorVerbosity, JwtError};
use crate::oidc::discovery::{DiscoveryDocument, OidcDiscovery};
/// Kind of token a provider configuration accepts.
///
/// The wire representation (see [`TokenUse::as_str`]) is `"id"` or `"access"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenUse {
    /// Wire value `"id"`.
    Id,
    /// Wire value `"access"`.
    Access,
}

impl TokenUse {
    /// Returns the wire representation: `"id"` or `"access"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            TokenUse::Id => "id",
            TokenUse::Access => "access",
        }
    }

    /// Parses the wire representation.
    ///
    /// Matching is exact and case-sensitive. Returns `None` for anything other than `"id"` or
    /// `"access"`.
    ///
    /// Kept as an inherent method for backward compatibility. New code can use
    /// `s.parse::<TokenUse>()`, which goes through the [`std::str::FromStr`] impl and reports
    /// the rejected input.
    // Kept for backward compatibility; the std trait is implemented below as well.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        // `str::parse` dispatches to the `FromStr` trait impl, never back to this method.
        s.parse().ok()
    }
}

impl std::fmt::Display for TokenUse {
    /// Writes the wire representation (same text as [`TokenUse::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

/// Error returned when a string is not a recognized [`TokenUse`] wire value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTokenUseError {
    value: String,
}

impl ParseTokenUseError {
    /// The input that failed to parse.
    pub fn value(&self) -> &str {
        &self.value
    }
}

impl std::fmt::Display for ParseTokenUseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "invalid token use {:?}: expected \"id\" or \"access\"",
            self.value
        )
    }
}

impl std::error::Error for ParseTokenUseError {}

impl std::str::FromStr for TokenUse {
    type Err = ParseTokenUseError;

    /// Parses `"id"` or `"access"` exactly (case-sensitive).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "id" => Ok(TokenUse::Id),
            "access" => Ok(TokenUse::Access),
            other => Err(ParseTokenUseError {
                value: other.to_string(),
            }),
        }
    }
}
/// Configuration for validating tokens issued by a single OpenID Connect provider.
///
/// Build one with [`OidcProviderConfig::new`] or [`OidcProviderConfig::with_discovery`], then
/// adjust it with the `with_*` builder methods.
pub struct OidcProviderConfig {
    /// Issuer URL. `new` rejects empty values and values not starting with `http://` or
    /// `https://`.
    pub issuer: String,
    /// Explicit JWKS endpoint. When `None`, `new` enables discovery.
    pub jwks_url: Option<String>,
    /// Client IDs configured for this provider.
    pub client_ids: Vec<String>,
    /// Token kinds accepted. `new` defaults this to both `Id` and `Access`.
    pub allowed_token_uses: Vec<TokenUse>,
    /// Clock-skew tolerance. `new` defaults this to 60 seconds.
    pub clock_skew: Duration,
    /// How long fetched JWKs are cached. `new` defaults this to 24 hours.
    pub jwk_cache_duration: Duration,
    /// Cache duration passed to the `OidcDiscovery` built by `create_discovery`.
    /// `new` defaults this to 24 hours.
    pub discovery_cache_duration: Duration,
    /// Claim names that must be present. `new` defaults this to `sub`, `iss`, `aud`, `exp`
    /// and `iat`.
    pub required_claims: HashSet<String>,
    /// Additional claim validators.
    ///
    /// NOTE: boxed trait objects cannot be cloned, so `Clone` for this struct does NOT copy
    /// them. A clone starts with an empty list.
    // Silences clippy's type_complexity lint for this nested boxed trait-object type.
    #[allow(clippy::type_complexity)]
    pub custom_validators: Vec<Box<dyn ClaimValidator + Send + Sync>>,
    /// Error detail level. `new` defaults this to `ErrorVerbosity::Standard`.
    pub error_verbosity: ErrorVerbosity,
    /// Whether `discover` fetches the provider's discovery document.
    /// `new` sets this to `true` exactly when `jwks_url` is `None`.
    pub use_discovery: bool,
}
impl std::fmt::Debug for OidcProviderConfig {
    /// Formats every field.
    ///
    /// `custom_validators` is printed only as a count, because the boxed validators are not
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct fails to compile here until
        // it is also printed. `allowed_token_uses` was previously omitted.
        let Self {
            issuer,
            jwks_url,
            client_ids,
            allowed_token_uses,
            clock_skew,
            jwk_cache_duration,
            discovery_cache_duration,
            required_claims,
            custom_validators,
            error_verbosity,
            use_discovery,
        } = self;
        f.debug_struct("OidcProviderConfig")
            .field("issuer", issuer)
            .field("jwks_url", jwks_url)
            .field("client_ids", client_ids)
            .field("allowed_token_uses", allowed_token_uses)
            .field("clock_skew", clock_skew)
            .field("jwk_cache_duration", jwk_cache_duration)
            .field("discovery_cache_duration", discovery_cache_duration)
            .field("required_claims", required_claims)
            .field(
                "custom_validators",
                &format!("[{} validators]", custom_validators.len()),
            )
            .field("error_verbosity", error_verbosity)
            .field("use_discovery", use_discovery)
            .finish()
    }
}
impl Clone for OidcProviderConfig {
    /// Copies every setting EXCEPT `custom_validators`.
    ///
    /// WARNING: `Box<dyn ClaimValidator + Send + Sync>` cannot be cloned, so the returned
    /// config starts with an empty validator list. Any custom checks must be re-registered on
    /// the clone, or they will silently not run.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: a newly added field fails to compile here until it is
        // handled.
        let Self {
            issuer,
            jwks_url,
            client_ids,
            allowed_token_uses,
            clock_skew,
            jwk_cache_duration,
            discovery_cache_duration,
            required_claims,
            custom_validators: _,
            error_verbosity,
            use_discovery,
        } = self;
        Self {
            issuer: issuer.clone(),
            jwks_url: jwks_url.clone(),
            client_ids: client_ids.clone(),
            allowed_token_uses: allowed_token_uses.clone(),
            clock_skew: *clock_skew,
            jwk_cache_duration: *jwk_cache_duration,
            discovery_cache_duration: *discovery_cache_duration,
            required_claims: required_claims.clone(),
            custom_validators: Vec::new(),
            error_verbosity: *error_verbosity,
            use_discovery: *use_discovery,
        }
    }
}
impl OidcProviderConfig {
pub fn new(
issuer: &str,
jwks_url: Option<&str>,
client_ids: &[String],
token_uses: Option<Vec<TokenUse>>,
) -> Result<Self, JwtError> {
if issuer.is_empty() {
return Err(JwtError::ConfigurationError {
parameter: Some("issuer".to_string()),
error: "Issuer URL cannot be empty".to_string(),
});
}
if !issuer.starts_with("http://") && !issuer.starts_with("https://") {
return Err(JwtError::ConfigurationError {
parameter: Some("issuer".to_string()),
error: "Issuer URL must start with http:// or https://".to_string(),
});
}
if let Some(url) = jwks_url {
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err(JwtError::ConfigurationError {
parameter: Some("jwks_url".to_string()),
error: "JWKS URL must start with http:// or https://".to_string(),
});
}
}
let token_uses = match token_uses {
None => vec![TokenUse::Id, TokenUse::Access],
Some(tu) => tu,
};
Ok(Self {
issuer: issuer.to_string(),
jwks_url: jwks_url.map(|s| s.to_string()),
client_ids: client_ids.to_vec(),
allowed_token_uses: token_uses,
clock_skew: Duration::from_secs(60), jwk_cache_duration: Duration::from_secs(3600 * 24), discovery_cache_duration: Duration::from_secs(3600 * 24), required_claims: HashSet::from([
"sub".to_string(),
"iss".to_string(),
"aud".to_string(),
"exp".to_string(),
"iat".to_string(),
]),
custom_validators: Vec::new(),
error_verbosity: ErrorVerbosity::Standard,
use_discovery: jwks_url.is_none(), })
}
pub fn with_discovery(issuer: &str, client_ids: &[String]) -> Result<Self, JwtError> {
let mut config = Self::new(issuer, None, client_ids, None)?;
config.use_discovery = true;
Ok(config)
}
pub fn set_discovery_enabled(mut self, use_discovery: bool) -> Self {
self.use_discovery = use_discovery;
self
}
pub fn with_clock_skew(mut self, skew: Duration) -> Self {
self.clock_skew = skew;
self
}
pub fn with_cache_duration(mut self, duration: Duration) -> Self {
self.jwk_cache_duration = duration;
self
}
pub fn with_discovery_cache_duration(mut self, duration: Duration) -> Self {
self.discovery_cache_duration = duration;
self
}
pub fn with_required_claim(mut self, claim: &str) -> Self {
self.required_claims.insert(claim.to_string());
self
}
pub fn with_custom_validator(
mut self,
validator: Box<dyn ClaimValidator + Send + Sync>,
) -> Self {
self.custom_validators.push(validator);
self
}
pub fn with_error_verbosity(mut self, verbosity: ErrorVerbosity) -> Self {
self.error_verbosity = verbosity;
self
}
pub fn get_well_known_url(&self) -> String {
format!(
"{}/.well-known/openid-configuration",
self.issuer.trim_end_matches('/')
)
}
pub async fn discover_jwks_url(&self, discovery: &OidcDiscovery) -> Result<String, JwtError> {
if let Some(url) = &self.jwks_url {
if !self.use_discovery {
return Ok(url.clone());
}
}
let document = discovery
.discover_with_fallback(&self.issuer, self.jwks_url.as_deref())
.await?;
Ok(document.jwks_uri.clone())
}
pub async fn discover(&self, discovery: &OidcDiscovery) -> Result<DiscoveryDocument, JwtError> {
if self.use_discovery {
discovery
.discover_with_fallback(&self.issuer, self.jwks_url.as_deref())
.await
} else if let Some(jwks_url) = &self.jwks_url {
Ok(DiscoveryDocument::new(&self.issuer, jwks_url))
} else {
Err(JwtError::ConfigurationError {
parameter: Some("jwks_url".to_string()),
error: "JWKS URL is required when auto-discovery is disabled".to_string(),
})
}
}
pub fn create_discovery(&self) -> OidcDiscovery {
OidcDiscovery::new(self.discovery_cache_duration)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "https://accounts.example.com";
    const JWKS_URL: &str = "https://accounts.example.com/.well-known/jwks.json";
    const WELL_KNOWN: &str = "https://accounts.example.com/.well-known/openid-configuration";

    /// Client-ID list shared by every test.
    fn clients() -> Vec<String> {
        vec!["client1".to_string()]
    }

    /// Valid config for `ISSUER` with no explicit JWKS URL.
    fn discovery_config() -> OidcProviderConfig {
        OidcProviderConfig::new(ISSUER, None, &clients(), None)
            .expect("a valid issuer without a JWKS URL must be accepted")
    }

    /// Asserts that `result` is a `ConfigurationError` reported against `expected_param`.
    fn assert_config_error(result: Result<OidcProviderConfig, JwtError>, expected_param: &str) {
        match result {
            Err(JwtError::ConfigurationError { parameter, .. }) => {
                assert_eq!(parameter.as_deref(), Some(expected_param));
            }
            Err(_) => panic!("Expected ConfigurationError"),
            Ok(config) => panic!("expected an error, got {:?}", config),
        }
    }

    #[test]
    fn test_new_valid_config() {
        let config = OidcProviderConfig::new(ISSUER, Some(JWKS_URL), &clients(), None)
            .expect("config with an explicit JWKS URL must be accepted");
        assert_eq!(config.issuer, ISSUER);
        assert_eq!(config.jwks_url.as_deref(), Some(JWKS_URL));
        assert_eq!(config.client_ids, clients());
        assert!(!config.use_discovery);
    }

    #[test]
    fn test_new_with_discovery() {
        let config = discovery_config();
        assert_eq!(config.issuer, ISSUER);
        assert!(config.jwks_url.is_none());
        assert!(config.use_discovery);
    }

    #[test]
    fn test_with_discovery_explicit() {
        let config = OidcProviderConfig::with_discovery(ISSUER, &clients())
            .expect("with_discovery must accept a valid issuer");
        assert_eq!(config.issuer, ISSUER);
        assert!(config.jwks_url.is_none());
        assert!(config.use_discovery);
    }

    #[test]
    fn test_new_empty_issuer() {
        assert_config_error(
            OidcProviderConfig::new("", Some(JWKS_URL), &clients(), None),
            "issuer",
        );
    }

    #[test]
    fn test_new_invalid_issuer_url() {
        assert_config_error(
            OidcProviderConfig::new("invalid-url", Some(JWKS_URL), &clients(), None),
            "issuer",
        );
    }

    #[test]
    fn test_new_invalid_jwks_url() {
        assert_config_error(
            OidcProviderConfig::new(ISSUER, Some("invalid-url"), &clients(), None),
            "jwks_url",
        );
    }

    #[test]
    fn test_with_clock_skew() {
        let skew = Duration::from_secs(120);
        let config = discovery_config().with_clock_skew(skew);
        assert_eq!(config.clock_skew, skew);
    }

    #[test]
    fn test_with_cache_duration() {
        let twelve_hours = Duration::from_secs(3600 * 12);
        let config = discovery_config().with_cache_duration(twelve_hours);
        assert_eq!(config.jwk_cache_duration, twelve_hours);
    }

    #[test]
    fn test_with_discovery_cache_duration() {
        let six_hours = Duration::from_secs(3600 * 6);
        let config = discovery_config().with_discovery_cache_duration(six_hours);
        assert_eq!(config.discovery_cache_duration, six_hours);
    }

    #[test]
    fn test_with_required_claim() {
        let config = discovery_config().with_required_claim("nonce");
        assert!(config.required_claims.contains("nonce"));
    }

    #[test]
    fn test_with_discovery_flag() {
        let enabled = OidcProviderConfig::new(ISSUER, Some(JWKS_URL), &clients(), None)
            .expect("config with an explicit JWKS URL must be accepted")
            .set_discovery_enabled(true);
        assert!(enabled.use_discovery);
        let disabled = enabled.set_discovery_enabled(false);
        assert!(!disabled.use_discovery);
    }

    #[test]
    fn test_get_well_known_url() {
        assert_eq!(discovery_config().get_well_known_url(), WELL_KNOWN);

        let with_trailing_slash =
            OidcProviderConfig::new("https://accounts.example.com/", None, &clients(), None)
                .expect("issuer with a trailing slash must be accepted");
        assert_eq!(with_trailing_slash.get_well_known_url(), WELL_KNOWN);
    }

    #[test]
    fn test_create_discovery() {
        let two_hours = Duration::from_secs(7200);
        let discovery = discovery_config()
            .with_discovery_cache_duration(two_hours)
            .create_discovery();
        assert_eq!(discovery.get_cache_duration(), two_hours);
    }
}