axum_oidc_layer/
jwks.rs

1//! JWKS (JSON Web Key Set) fetching and JWT validation operations.
2
3use std::sync::Arc;
4
5use jsonwebtoken::{jwk::JwkSet, Algorithm, DecodingKey, Validation};
6use serde::de::DeserializeOwned;
7
8use crate::{
9    cache::{JwkCacheKey, JwksCache},
10    config::AuthenticationConfigProvider,
11    error::OidcError,
12};
13
14/// Validates a JWT token using a specific JWK key.
15///
16/// This function performs the actual cryptographic validation of the JWT
17/// using the provided JWK and algorithm. It returns the decoded claims
18/// if validation succeeds.
19pub fn validate_token_with_jwk<T: DeserializeOwned>(
20    token: &str,
21    jwk: &jsonwebtoken::jwk::Jwk,
22    alg: Algorithm,
23) -> Result<T, String> {
24    let mut validation = Validation::new(alg);
25    validation.validate_aud = false;
26
27    let decoding_key = DecodingKey::from_jwk(jwk)
28        .map_err(|e| format!("Failed to create DecodingKey from JWK: {e}"))?;
29
30    let token_data = jsonwebtoken::decode::<T>(token, &decoding_key, &validation)
31        .map_err(|e| format!("JWT validation failed: {e}"))?;
32
33    Ok(token_data.claims)
34}
35
36/// Fetches `JWKS` from the URI and caches individual JWK keys.
37///
38/// This function retrieves the complete JWKS from the provider, caches each
39/// individual key separately for efficient future lookups, and returns the
40/// specific JWK needed for the current token validation.
41///
42/// # Arguments
43/// * `cache` - The cache implementation to store JWK keys
44/// * `jwks_uri` - The URI to fetch the JWKS from
45/// * `config` - Configuration provider for cache TTL settings
46/// * `requested_kid` - The specific key ID needed for validation
47///
48/// # Returns
49/// The JWK corresponding to the requested key ID, or an error if not found
50pub async fn fetch_and_cache_jwks(
51    cache: &Arc<dyn JwksCache>,
52    jwks_uri: &str,
53    config: &(impl AuthenticationConfigProvider + Send + Sync),
54    requested_kid: &str,
55) -> Result<jsonwebtoken::jwk::Jwk, OidcError> {
56    // Fetch the complete JWKS from the provider
57    let response = reqwest::get(jwks_uri)
58        .await
59        .map_err(|e| OidcError::JwksError(format!("Failed to fetch JWKS: {e}")))?;
60
61    let jwks: JwkSet = response
62        .json()
63        .await
64        .map_err(|e| OidcError::JwksError(format!("Failed to parse JWKS: {e}")))?;
65
66    // Cache each individual JWK key for future use
67    let jwks_cache_ttl = config.get_jwks_cache_ttl();
68    for jwk in &jwks.keys {
69        if let Some(kid) = &jwk.common.key_id {
70            let jwk_cache_key = JwkCacheKey::from_jwks_uri_and_kid(jwks_uri, kid);
71            if let Ok(serialized) = serde_json::to_string(jwk) {
72                cache.set(jwk_cache_key.as_str(), serialized, jwks_cache_ttl);
73            }
74        }
75    }
76
77    // Return the specific JWK requested
78    jwks.find(requested_kid)
79        .cloned()
80        .ok_or_else(|| OidcError::JwksError(format!("No JWK found for kid: {requested_kid}")))
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::jwk::Jwk;
    use jsonwebtoken::{Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        sub: String,
        exp: usize,
    }

    /// HMAC secret used both to sign tokens and (base64url-encoded) as the `oct` JWK.
    const SECRET: &[u8] = b"secret";

    /// 2100-01-01T00:00:00Z, far enough ahead that `exp` validation never trips.
    const FAR_FUTURE_EXP: usize = 4_102_444_800;

    /// Builds a JWK from its JSON wire form, the same shape a JWKS endpoint returns.
    ///
    /// Going through serde avoids depending on `Jwk` struct internals, which
    /// do not offer a `Default` constructor.
    fn jwk_from_json(json: &str) -> Jwk {
        serde_json::from_str(json).expect("test JWK JSON must deserialize")
    }

    /// `oct` (symmetric) JWK whose `k` is `SECRET` as unpadded base64url ("c2VjcmV0").
    fn hmac_jwk() -> Jwk {
        jwk_from_json(r#"{"kty":"oct","k":"c2VjcmV0"}"#)
    }

    /// Signs `claims` as an HS256 JWT with the given secret.
    fn sign_hs256(claims: &TestClaims, secret: &[u8]) -> String {
        jsonwebtoken::encode(
            &Header::new(Algorithm::HS256),
            claims,
            &EncodingKey::from_secret(secret),
        )
        .expect("HS256 signing must succeed")
    }

    #[test]
    fn test_validate_token_with_invalid_jwk() {
        // A structurally valid (but meaningless) RSA JWK paired with a token
        // that is not a real JWT: validation must fail, not panic.
        let jwk = jwk_from_json(r#"{"kty":"RSA","n":"AQAB","e":"AQAB"}"#);

        let result =
            validate_token_with_jwk::<TestClaims>("invalid.token.here", &jwk, Algorithm::RS256);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_token_with_matching_jwk_returns_claims() {
        let claims = TestClaims {
            sub: "user-123".to_string(),
            exp: FAR_FUTURE_EXP,
        };
        let token = sign_hs256(&claims, SECRET);

        let result = validate_token_with_jwk::<TestClaims>(&token, &hmac_jwk(), Algorithm::HS256);
        assert_eq!(result, Ok(claims));
    }

    #[test]
    fn test_validate_token_signed_with_other_key_is_rejected() {
        let claims = TestClaims {
            sub: "user-123".to_string(),
            exp: FAR_FUTURE_EXP,
        };
        let token = sign_hs256(&claims, b"a-different-secret");

        let result = validate_token_with_jwk::<TestClaims>(&token, &hmac_jwk(), Algorithm::HS256);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_expired_token_is_rejected() {
        let claims = TestClaims {
            sub: "user-123".to_string(),
            exp: 1,
        };
        let token = sign_hs256(&claims, SECRET);

        let result = validate_token_with_jwk::<TestClaims>(&token, &hmac_jwk(), Algorithm::HS256);
        assert!(result.is_err());
    }
}
106}