use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use crate::common::error::JwtError;
/// OpenID Connect provider metadata, as served at `/.well-known/openid-configuration`.
///
/// Only `issuer` and `jwks_uri` are required when deserializing; every optional
/// field below becomes `None` when it is absent from the JSON, and any member not
/// mapped to a named field is kept in `additional_fields`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryDocument {
    /// Issuer identifier claimed by the provider.
    pub issuer: String,
    /// URL of the provider's JSON Web Key Set.
    pub jwks_uri: String,
    /// `authorization_endpoint` metadata, if present.
    #[serde(default)]
    pub authorization_endpoint: Option<String>,
    /// `token_endpoint` metadata, if present.
    #[serde(default)]
    pub token_endpoint: Option<String>,
    /// `userinfo_endpoint` metadata, if present.
    #[serde(default)]
    pub userinfo_endpoint: Option<String>,
    /// `end_session_endpoint` metadata, if present.
    #[serde(default)]
    pub end_session_endpoint: Option<String>,
    /// `response_types_supported` metadata, if present.
    #[serde(default)]
    pub response_types_supported: Option<Vec<String>>,
    /// `subject_types_supported` metadata, if present.
    #[serde(default)]
    pub subject_types_supported: Option<Vec<String>>,
    /// `id_token_signing_alg_values_supported` metadata, if present.
    #[serde(default)]
    pub id_token_signing_alg_values_supported: Option<Vec<String>>,
    /// All other top-level members of the document, kept as raw JSON values
    /// (flattened in and out of the surrounding object).
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
impl DiscoveryDocument {
pub fn new(issuer: &str, jwks_uri: &str) -> Self {
Self {
issuer: issuer.to_string(),
jwks_uri: jwks_uri.to_string(),
authorization_endpoint: None,
token_endpoint: None,
userinfo_endpoint: None,
end_session_endpoint: None,
response_types_supported: None,
subject_types_supported: None,
id_token_signing_alg_values_supported: None,
additional_fields: HashMap::new(),
}
}
pub fn with_endpoints(
issuer: &str,
jwks_uri: &str,
authorization_endpoint: Option<&str>,
token_endpoint: Option<&str>,
userinfo_endpoint: Option<&str>,
) -> Self {
Self {
issuer: issuer.to_string(),
jwks_uri: jwks_uri.to_string(),
authorization_endpoint: authorization_endpoint.map(|s| s.to_string()),
token_endpoint: token_endpoint.map(|s| s.to_string()),
userinfo_endpoint: userinfo_endpoint.map(|s| s.to_string()),
end_session_endpoint: None,
response_types_supported: None,
subject_types_supported: None,
id_token_signing_alg_values_supported: None,
additional_fields: HashMap::new(),
}
}
pub fn validate(&self, expected_issuer: &str) -> Result<(), JwtError> {
if self.issuer != expected_issuer {
return Err(JwtError::ConfigurationError {
parameter: Some("issuer".to_string()),
error: format!(
"Issuer mismatch: expected {}, got {}",
expected_issuer, self.issuer
),
});
}
if self.jwks_uri.is_empty() {
return Err(JwtError::ConfigurationError {
parameter: Some("jwks_uri".to_string()),
error: "JWKS URI is empty".to_string(),
});
}
if !self.jwks_uri.starts_with("http://") && !self.jwks_uri.starts_with("https://") {
return Err(JwtError::ConfigurationError {
parameter: Some("jwks_uri".to_string()),
error: "JWKS URI must start with http:// or https://".to_string(),
});
}
Ok(())
}
}
/// A discovery document together with the moment it was put into the cache.
#[derive(Debug)]
struct CachedDiscoveryDocument {
    /// The cached document.
    document: DiscoveryDocument,
    /// When the entry was stored; its age (`inserted_at.elapsed()`) is compared
    /// with the configured cache duration to decide whether it is still fresh.
    inserted_at: Instant,
}
/// Resolves OpenID Connect discovery documents and caches them per issuer.
///
/// Clones share one cache (it lives behind an `Arc<RwLock<..>>`), while each
/// clone carries its own copy of `cache_duration`.
#[derive(Clone, Debug)]
pub struct OidcDiscovery {
    /// HTTP client used to fetch `<issuer>/.well-known/openid-configuration`.
    client: Client,
    /// Cached documents keyed by the issuer string supplied by the caller.
    cache: Arc<RwLock<HashMap<String, CachedDiscoveryDocument>>>,
    /// How long a cached document is served before it is fetched again.
    cache_duration: Duration,
}
impl OidcDiscovery {
pub fn new(cache_duration: Duration) -> Self {
Self {
client: Client::new(),
cache: Arc::new(RwLock::new(HashMap::new())),
cache_duration,
}
}
pub fn with_client(client: Client, cache_duration: Duration) -> Self {
Self {
client,
cache: Arc::new(RwLock::new(HashMap::new())),
cache_duration,
}
}
pub async fn discover(&self, issuer: &str) -> Result<DiscoveryDocument, JwtError> {
{
let cache = self.cache.read().unwrap();
if let Some(cached) = cache.get(issuer) {
if cached.inserted_at.elapsed() < self.cache_duration {
debug!("Using cached discovery document for issuer: {}", issuer);
return Ok(cached.document.clone());
}
debug!(
"Cached discovery document for issuer {} has expired",
issuer
);
}
}
debug!("Fetching discovery document for issuer: {}", issuer);
let document = self.fetch_discovery_document(issuer).await?;
{
let mut cache = self.cache.write().unwrap();
cache.insert(
issuer.to_string(),
CachedDiscoveryDocument {
document: document.clone(),
inserted_at: Instant::now(),
},
);
debug!("Cached discovery document for issuer: {}", issuer);
}
Ok(document)
}
pub async fn discover_with_fallback(
&self,
issuer: &str,
jwks_url: Option<&str>,
) -> Result<DiscoveryDocument, JwtError> {
match self.discover(issuer).await {
Ok(document) => {
debug!(
"Successfully discovered OIDC configuration for issuer: {}",
issuer
);
Ok(document)
}
Err(err) => {
if let Some(jwks_url) = jwks_url {
warn!(
"Discovery failed for issuer {}, using fallback JWKS URL: {}",
issuer, jwks_url
);
let document = DiscoveryDocument::new(issuer, jwks_url);
{
let mut cache = self.cache.write().unwrap();
cache.insert(
issuer.to_string(),
CachedDiscoveryDocument {
document: document.clone(),
inserted_at: Instant::now(),
},
);
debug!("Cached manual discovery document for issuer: {}", issuer);
}
Ok(document)
} else {
Err(err)
}
}
}
}
async fn fetch_discovery_document(&self, issuer: &str) -> Result<DiscoveryDocument, JwtError> {
let well_known_url = format!(
"{}/.well-known/openid-configuration",
issuer.trim_end_matches('/')
);
debug!("Fetching discovery document from URL: {}", well_known_url);
let response = self.client.get(&well_known_url).send().await.map_err(|e| {
JwtError::JwksFetchError {
url: Some(well_known_url.clone()),
error: format!("Failed to fetch discovery document: {}", e),
}
})?;
if !response.status().is_success() {
return Err(JwtError::JwksFetchError {
url: Some(well_known_url),
error: format!("HTTP {}", response.status()),
});
}
let document: DiscoveryDocument =
response
.json()
.await
.map_err(|e| JwtError::JwksFetchError {
url: Some(well_known_url.clone()),
error: format!("Failed to parse discovery document: {}", e),
})?;
document
.validate(issuer)
.map_err(|e| JwtError::JwksFetchError {
url: Some(well_known_url),
error: format!("Invalid discovery document: {}", e),
})?;
debug!(
"Successfully fetched discovery document for issuer: {}",
issuer
);
info!("JWKS URI for issuer {}: {}", issuer, document.jwks_uri);
Ok(document)
}
pub fn clear_cache(&self, issuer: &str) {
let mut cache = self.cache.write().unwrap();
if cache.remove(issuer).is_some() {
debug!("Cleared cache for issuer: {}", issuer);
}
}
pub fn clear_all_cache(&self) {
let mut cache = self.cache.write().unwrap();
let count = cache.len();
cache.clear();
debug!("Cleared cache for {} issuers", count);
}
pub fn get_cached_issuers(&self) -> Vec<String> {
let cache = self.cache.read().unwrap();
cache.keys().cloned().collect()
}
pub fn is_cached(&self, issuer: &str) -> bool {
let cache = self.cache.read().unwrap();
cache.contains_key(issuer)
}
pub fn get_cache_duration(&self) -> Duration {
self.cache_duration
}
pub fn set_cache_duration(&mut self, duration: Duration) {
self.cache_duration = duration;
}
pub fn add_to_cache(&self, issuer: &str, document: DiscoveryDocument) {
let mut cache = self.cache.write().unwrap();
cache.insert(
issuer.to_string(),
CachedDiscoveryDocument {
document,
inserted_at: Instant::now(),
},
);
debug!(
"Manually added discovery document to cache for issuer: {}",
issuer
);
}
}
#[cfg(test)]
mod tests {
    // Tests for discovery fetching, caching, fallback handling and document validation.
    // HTTP tests run against a local mockito server.
    use super::*;
    use mockito::Server;

    /// Extracts the `parameter` of a `ConfigurationError`, or `None` for any other outcome.
    fn config_error_parameter(result: Result<(), JwtError>) -> Option<String> {
        match result {
            Err(JwtError::ConfigurationError { parameter, .. }) => parameter,
            _ => None,
        }
    }

    #[tokio::test]
    async fn test_fetch_discovery_document_success() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(format!(
                r#"{{
"issuer": "{}",
"jwks_uri": "{}/.well-known/jwks.json",
"authorization_endpoint": "{}/oauth2/authorize",
"token_endpoint": "{}/oauth2/token"
}}"#,
                server.url(),
                server.url(),
                server.url(),
                server.url()
            ))
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let document = discovery
            .fetch_discovery_document(&server.url())
            .await
            .unwrap();
        assert_eq!(document.issuer, server.url());
        assert_eq!(
            document.jwks_uri,
            format!("{}/.well-known/jwks.json", server.url())
        );
        assert_eq!(
            document.authorization_endpoint,
            Some(format!("{}/oauth2/authorize", server.url()))
        );
        assert_eq!(
            document.token_endpoint,
            Some(format!("{}/oauth2/token", server.url()))
        );
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_fetch_discovery_document_http_error() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(404)
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let result = discovery.fetch_discovery_document(&server.url()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            JwtError::JwksFetchError { url, .. } => {
                assert_eq!(
                    url,
                    Some(format!("{}/.well-known/openid-configuration", server.url()))
                );
            }
            _ => panic!("Expected JwksFetchError"),
        }
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_fetch_discovery_document_invalid_json() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("invalid json")
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let result = discovery.fetch_discovery_document(&server.url()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            JwtError::JwksFetchError { url, .. } => {
                assert_eq!(
                    url,
                    Some(format!("{}/.well-known/openid-configuration", server.url()))
                );
            }
            _ => panic!("Expected JwksFetchError"),
        }
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_fetch_discovery_document_issuer_mismatch() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(format!(
                r#"{{
"issuer": "https://wrong-issuer.com",
"jwks_uri": "{}/.well-known/jwks.json"
}}"#,
                server.url()
            ))
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let result = discovery.fetch_discovery_document(&server.url()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            JwtError::JwksFetchError { url, error } => {
                assert_eq!(
                    url,
                    Some(format!("{}/.well-known/openid-configuration", server.url()))
                );
                assert!(error.contains("Issuer mismatch"));
            }
            _ => panic!("Expected JwksFetchError"),
        }
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_discover_caching() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(format!(
                r#"{{
"issuer": "{}",
"jwks_uri": "{}/.well-known/jwks.json"
}}"#,
                server.url(),
                server.url()
            ))
            .expect(1)
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let document1 = discovery.discover(&server.url()).await.unwrap();
        let document2 = discovery.discover(&server.url()).await.unwrap();
        assert_eq!(document1.issuer, document2.issuer);
        assert_eq!(document1.jwks_uri, document2.jwks_uri);
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_discover_cache_expiration() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(format!(
                r#"{{
"issuer": "{}",
"jwks_uri": "{}/.well-known/jwks.json"
}}"#,
                server.url(),
                server.url()
            ))
            .expect(2)
            .create();
        let discovery = OidcDiscovery::new(Duration::from_millis(1));
        let _ = discovery.discover(&server.url()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        let _ = discovery.discover(&server.url()).await.unwrap();
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_discover_with_fallback_uses_and_caches_manual_jwks_url() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(500)
            .expect(1)
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let fallback = format!("{}/keys", server.url());

        let document = discovery
            .discover_with_fallback(&server.url(), Some(fallback.as_str()))
            .await
            .unwrap();
        assert_eq!(document.issuer, server.url());
        assert_eq!(document.jwks_uri, fallback);
        assert!(discovery.is_cached(&server.url()));

        // The fallback document is cached, so this call must not hit the server again.
        let again = discovery.discover(&server.url()).await.unwrap();
        assert_eq!(again.jwks_uri, fallback);
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_discover_with_fallback_without_url_returns_error() {
        let mut server = Server::new_async().await;
        let mock_server = server
            .mock("GET", "/.well-known/openid-configuration")
            .with_status(500)
            .create();
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let result = discovery
            .discover_with_fallback(&server.url(), None)
            .await;
        assert!(matches!(result, Err(JwtError::JwksFetchError { .. })));
        assert!(!discovery.is_cached(&server.url()));
        mock_server.assert();
    }

    #[tokio::test]
    async fn test_manual_cache_management() {
        let discovery = OidcDiscovery::new(Duration::from_secs(3600));
        let issuer_a = "https://a.example.com";
        let issuer_b = "https://b.example.com";
        discovery.add_to_cache(
            issuer_a,
            DiscoveryDocument::new(issuer_a, "https://a.example.com/jwks"),
        );
        discovery.add_to_cache(
            issuer_b,
            DiscoveryDocument::new(issuer_b, "https://b.example.com/jwks"),
        );

        let mut issuers = discovery.get_cached_issuers();
        issuers.sort();
        assert_eq!(issuers, vec![issuer_a.to_string(), issuer_b.to_string()]);

        // A fresh, manually added entry is served without any network access.
        let document = discovery.discover(issuer_a).await.unwrap();
        assert_eq!(document.jwks_uri, "https://a.example.com/jwks");

        discovery.clear_cache(issuer_a);
        assert!(!discovery.is_cached(issuer_a));
        assert!(discovery.is_cached(issuer_b));

        discovery.clear_all_cache();
        assert!(discovery.get_cached_issuers().is_empty());
    }

    #[test]
    fn test_validate_checks_issuer_and_jwks_uri() {
        let issuer = "https://issuer.example.com";
        let jwks = "https://issuer.example.com/jwks";

        assert!(DiscoveryDocument::new(issuer, jwks).validate(issuer).is_ok());
        assert_eq!(
            config_error_parameter(
                DiscoveryDocument::new(issuer, jwks).validate("https://other.example.com")
            ),
            Some("issuer".to_string())
        );
        assert_eq!(
            config_error_parameter(DiscoveryDocument::new(issuer, "").validate(issuer)),
            Some("jwks_uri".to_string())
        );
        assert_eq!(
            config_error_parameter(
                DiscoveryDocument::new(issuer, "ftp://issuer.example.com/jwks").validate(issuer)
            ),
            Some("jwks_uri".to_string())
        );
    }
}