axum_oidc_layer/
validation.rs

1//! High-level token validation and OIDC configuration management.
2
3use std::sync::Arc;
4
5use jsonwebtoken::decode_header;
6use serde::de::DeserializeOwned;
7
8use crate::{
9    cache::{ConfigCacheKey, JwkCacheKey, JwksCache, TokenCacheKey},
10    config::{AuthenticationConfigProvider, OidcConfiguration},
11    error::OidcError,
12    jwks::{fetch_and_cache_jwks, validate_token_with_jwk},
13    token::extract_token_ttl,
14};
15
16/// Fetches or retrieves OIDC configuration from cache.
17///
18/// This function first checks the cache for existing OIDC configuration.
19/// If not found or expired, it fetches the configuration from the provider's
20/// well-known endpoint and caches it for future use.
21///
22/// # Arguments
23/// * `cache` - The cache implementation to use
24/// * `config_url` - The URL to fetch the OIDC configuration from
25/// * `config` - Configuration provider for cache TTL settings
26///
27/// # Returns
28/// The OIDC configuration containing the JWKS URI and other provider information
29pub async fn get_oidc_config(
30    cache: &Arc<dyn JwksCache>,
31    config_url: &str,
32    config: &(impl AuthenticationConfigProvider + Send + Sync),
33) -> Result<OidcConfiguration, OidcError> {
34    let config_cache_key = ConfigCacheKey::from_url(config_url);
35    
36    // Try cache first
37    if let Some(cached_config) = cache.get(config_cache_key.as_str()) {
38        return serde_json::from_str(&cached_config)
39            .map_err(|e| OidcError::CacheError(format!("Failed to deserialize cached config: {e}")));
40    }
41
42    // Cache miss - fetch from remote
43    let response = reqwest::get(config_url)
44        .await
45        .map_err(|e| OidcError::ConfigurationError(format!("Failed to fetch OIDC config: {e}")))?;
46
47    let config_doc: OidcConfiguration = response
48        .json()
49        .await
50        .map_err(|e| OidcError::ConfigurationError(format!("Failed to parse OIDC config: {e}")))?;
51
52    // Cache the result
53    if let Ok(serialized) = serde_json::to_string(&config_doc) {
54        cache.set(config_cache_key.as_str(), serialized, config.get_config_cache_ttl());
55    }
56
57    Ok(config_doc)
58}
59
60/// Validates a JWT token and caches the validation result.
61///
62/// This is the main validation orchestration function that:
63/// 1. Checks if the token's claims are already cached
64/// 2. If not cached, extracts the key ID from the token header
65/// 3. Retrieves the appropriate JWK (from cache or by fetching JWKS)
66/// 4. Validates the token and caches the result
67///
68/// # Arguments
69/// * `cache` - The cache implementation to use
70/// * `token` - The JWT token to validate
71/// * `jwks_uri` - The URI where JWKs can be fetched
72/// * `config` - Configuration provider for cache TTL settings
73///
74/// # Returns
75/// The validated and deserialized token claims
76pub async fn validate_and_cache_token<T>(
77    cache: &Arc<dyn JwksCache>,
78    token: &str,
79    jwks_uri: &str,
80    config: &(impl AuthenticationConfigProvider + Send + Sync),
81) -> Result<T, OidcError>
82where
83    T: DeserializeOwned + serde::Serialize + Clone + Send + Sync + 'static,
84{
85    let token_cache_key = TokenCacheKey::from_token(token);
86    
87    // Check token cache first
88    if let Some(cached_claims) = cache.get(token_cache_key.as_str()) {
89        return serde_json::from_str(&cached_claims)
90            .map_err(|e| OidcError::CacheError(format!("Failed to deserialize cached token: {e}")));
91    }
92
93    // Extract kid from token header
94    let header = decode_header(token)
95        .map_err(|e| OidcError::InvalidToken(format!("Failed to decode JWT header: {e}")))?;
96
97    let kid = header.kid.ok_or(OidcError::MissingKid)?;
98
99    // Check if we have this specific JWK key cached
100    let jwk_cache_key = JwkCacheKey::from_jwks_uri_and_kid(jwks_uri, &kid);
101    let jwk = match cache.get(jwk_cache_key.as_str()) {
102        Some(cached_jwk) => serde_json::from_str(&cached_jwk)
103            .map_err(|e| OidcError::CacheError(format!("Failed to deserialize cached JWK: {e}")))?,
104        None => {
105            // Cache miss - fetch full JWKS and cache individual keys
106            fetch_and_cache_jwks(cache, jwks_uri, config, &kid).await?
107        }
108    };
109
110    // Validate token with JWK
111    let validated_claims = validate_token_with_jwk::<T>(token, &jwk, header.alg)
112        .map_err(OidcError::ValidationError)?;
113
114    // Cache the validated token
115    if let Ok(serialized) = serde_json::to_string(&validated_claims) {
116        let token_ttl = extract_token_ttl(token);
117        cache.set(token_cache_key.as_str(), serialized, token_ttl);
118    }
119
120    Ok(validated_claims)
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// The same token must always map to the same cache key; otherwise cached claims could
    /// never be found again.
    #[test]
    fn test_token_cache_key_consistency() {
        let token = "test.token.here";
        assert_eq!(TokenCacheKey::from_token(token), TokenCacheKey::from_token(token));
    }

    /// Different tokens must map to different cache keys. A collision would let one
    /// token's cached claims be returned for another token without verification.
    #[test]
    fn test_token_cache_key_distinguishes_tokens() {
        assert_ne!(
            TokenCacheKey::from_token("aaaa.bbbb.cccc"),
            TokenCacheKey::from_token("xxxx.yyyy.zzzz"),
        );
    }
}