Skip to main content

auths_infra_http/
oidc_validator.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::Jwk};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use std::time::Duration;
8use zeroize::Zeroize;
9
10use crate::default_http_client;
11use auths_oidc_port::{JwksClient, JwtValidator, OidcError, OidcValidationConfig};
12
/// OIDC token claims: the registered JWT fields plus any additional claims.
///
/// Registered claims are typed fields. Every other claim in the payload is
/// collected into [`extra`](Self::extra). Because `extra` is
/// `#[serde(flatten)]`, serializing this struct emits those additional claims at
/// the top level, next to the typed ones.
///
/// NOTE(review): RFC 7519 allows `aud` to be an array of strings. A token whose
/// `aud` is an array will fail to deserialize into the `String` field below.
///
/// # Usage
///
/// `HttpJwtValidator::validate` decodes a token into this struct and returns it
/// serialized as a `serde_json::Value`:
///
/// ```ignore
/// use auths_infra_http::HttpJwtValidator;
/// use chrono::Utc;
///
/// let validator = HttpJwtValidator::new(jwks_client);
/// let claims = validator.validate(token, &config, Utc::now()).await?;
/// ```
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OidcTokenClaims {
    /// Subject (`sub`): the user or service identity the token was issued for.
    pub sub: String,
    /// Issuer (`iss`).
    pub iss: String,
    /// Audience (`aud`), accepted only as a single string.
    pub aud: String,
    /// Expiration time (`exp`), an RFC 7519 NumericDate (seconds since the epoch).
    pub exp: i64,
    /// Issued-at time (`iat`). Deserializes to `0` when the claim is absent.
    #[serde(default)]
    pub iat: i64,
    /// Not-before time (`nbf`). `None` when the claim is absent.
    #[serde(default)]
    pub nbf: Option<i64>,
    /// JWT ID (`jti`), usable for replay detection. `None` when absent.
    #[serde(default)]
    pub jti: Option<String>,
    /// Every claim not matched by a field above, kept verbatim.
    #[serde(flatten)]
    pub extra: serde_json::Map<String, serde_json::Value>,
}
47
/// HTTP-based implementation of [`JwtValidator`] built on the `jsonwebtoken` crate.
///
/// Validates JWT tokens by:
/// 1. Decoding the JWT header to obtain `kid` and `alg`
/// 2. Checking `alg` against the configured allow-list
/// 3. Fetching the issuer's JWKS through the injected [`JwksClient`]
/// 4. Selecting the JWK whose `kid` matches the header
/// 5. Verifying the signature plus the issuer, audience and expiry claims
/// 6. Returning the claims as a `serde_json::Value`
pub struct HttpJwtValidator {
    /// Source of the issuer's JWKS documents.
    jwks_client: Arc<dyn JwksClient>,
}

impl HttpJwtValidator {
    /// Create a new HttpJwtValidator with the given JWKS client.
    ///
    /// # Args
    ///
    /// * `jwks_client`: JWKS client used to fetch (and possibly cache) issuer public keys
    pub fn new(jwks_client: Arc<dyn JwksClient>) -> Self {
        Self { jwks_client }
    }
}
71
72#[async_trait]
73impl JwtValidator for HttpJwtValidator {
74    async fn validate(
75        &self,
76        token: &str,
77        config: &OidcValidationConfig,
78        now: DateTime<Utc>,
79    ) -> Result<serde_json::Value, OidcError> {
80        let mut token_mut = token.to_string();
81
82        let header = decode_header(&token_mut).map_err(|e| {
83            let error_msg = format!("{}", e);
84            // Check if the error is due to an unknown algorithm variant
85            if error_msg.contains("unknown variant") && error_msg.contains("expected one of") {
86                OidcError::AlgorithmMismatch {
87                    expected: "RS256, RS384, RS512, ES256, ES384, PS256, PS384, PS512, or EdDSA"
88                        .to_string(),
89                    got: "unsupported algorithm".to_string(),
90                }
91            } else {
92                OidcError::JwtDecode(format!("failed to decode JWT header: {}", e))
93            }
94        })?;
95
96        let kid = header
97            .kid
98            .ok_or_else(|| OidcError::JwtDecode("JWT header missing 'kid' field".to_string()))?;
99
100        let alg_str = format!("{:?}", header.alg);
101        if alg_str.to_uppercase() == "NONE" {
102            return Err(OidcError::AlgorithmMismatch {
103                expected: "RS256 or ES256".to_string(),
104                got: "none".to_string(),
105            });
106        }
107
108        if !config
109            .allowed_algorithms
110            .iter()
111            .any(|allowed| allowed.to_uppercase() == alg_str.to_uppercase())
112        {
113            return Err(OidcError::AlgorithmMismatch {
114                expected: config.allowed_algorithms.join(", "),
115                got: alg_str.clone(),
116            });
117        }
118
119        let jwks = self.jwks_client.fetch_jwks(&config.issuer).await?;
120
121        let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
122            OidcError::JwksResolutionFailed("JWKS response missing 'keys' array".to_string())
123        })?;
124
125        let key_obj = keys
126            .iter()
127            .find(|key| {
128                key.get("kid")
129                    .and_then(|k| k.as_str())
130                    .map(|k| k == kid)
131                    .unwrap_or(false)
132            })
133            .ok_or_else(|| OidcError::UnknownKeyId(kid.clone()))?;
134
135        let jwk: Jwk = serde_json::from_value(key_obj.clone()).map_err(|e| {
136            OidcError::JwksResolutionFailed(format!(
137                "failed to parse JWKS key for kid {}: {}",
138                kid, e
139            ))
140        })?;
141
142        let decoding_key = DecodingKey::from_jwk(&jwk).map_err(|e| {
143            OidcError::JwksResolutionFailed(format!(
144                "failed to create decoding key for kid {}: {}",
145                kid, e
146            ))
147        })?;
148
149        let now_secs = now.timestamp();
150        let leeway = config.max_clock_skew_secs as u64;
151
152        let algorithm = match alg_str.to_uppercase().as_str() {
153            "RS256" => Algorithm::RS256,
154            "ES256" => Algorithm::ES256,
155            _ => {
156                return Err(OidcError::AlgorithmMismatch {
157                    expected: "RS256 or ES256".to_string(),
158                    got: alg_str,
159                });
160            }
161        };
162
163        let mut validation = Validation::new(algorithm);
164
165        validation.set_issuer(&[&config.issuer]);
166        validation.set_audience(&[&config.audience]);
167        validation.leeway = leeway;
168        validation.validate_exp = true;
169        validation.set_required_spec_claims(&["exp", "iss", "aud", "sub"]);
170
171        let token_data = decode::<OidcTokenClaims>(&token_mut, &decoding_key, &validation)
172            .map_err(|e| {
173                let error_msg = format!("{}", e);
174                if error_msg.contains("ExpiredSignature") || error_msg.contains("InvalidIssuedAt") {
175                    OidcError::ClockSkewExceeded {
176                        token_exp: 0,
177                        current_time: now_secs,
178                        leeway: leeway as i64,
179                    }
180                } else if error_msg.contains("InvalidSignature") {
181                    OidcError::SignatureVerificationFailed
182                } else if error_msg.contains("InvalidIssuer") {
183                    OidcError::ClaimsValidationFailed {
184                        claim: "iss".to_string(),
185                        reason: "issuer mismatch".to_string(),
186                    }
187                } else if error_msg.contains("InvalidAudience") {
188                    OidcError::ClaimsValidationFailed {
189                        claim: "aud".to_string(),
190                        reason: "audience mismatch".to_string(),
191                    }
192                } else {
193                    OidcError::JwtDecode(format!("JWT validation failed: {}", e))
194                }
195            })?;
196
197        token_mut.zeroize();
198
199        let mut json = serde_json::json!(token_data.claims);
200        if let Some(obj) = json.as_object_mut() {
201            for (k, v) in token_data.claims.extra.iter() {
202                obj.insert(k.clone(), v.clone());
203            }
204        }
205
206        Ok(json)
207    }
208}
209
210/// HTTP-based implementation of JwksClient with built-in caching.
211///
212/// Caches JWKS responses with configurable TTL to avoid repeated network calls.
213/// Implements refresh-ahead pattern to reduce cache misses.
214pub struct HttpJwksClient {
215    cache: Arc<RwLock<JwksCache>>,
216}
217
218struct JwksCache {
219    data: Option<serde_json::Value>,
220    expires_at: Option<DateTime<Utc>>,
221    ttl: Duration,
222}
223
224impl HttpJwksClient {
225    /// Create a new HttpJwksClient with the given cache TTL.
226    ///
227    /// # Args
228    ///
229    /// * `ttl`: Cache time-to-live duration
230    pub fn new(ttl: Duration) -> Self {
231        Self {
232            cache: Arc::new(RwLock::new(JwksCache {
233                data: None,
234                expires_at: None,
235                ttl,
236            })),
237        }
238    }
239
240    /// Create a new HttpJwksClient with default TTL of 1 hour.
241    pub fn with_default_ttl() -> Self {
242        Self::new(Duration::from_secs(3600))
243    }
244}
245
246#[async_trait]
247impl JwksClient for HttpJwksClient {
248    async fn fetch_jwks(&self, issuer_url: &str) -> Result<serde_json::Value, OidcError> {
249        #[allow(clippy::disallowed_methods)] // Cache refresh: needs current time
250        let now = Utc::now();
251        {
252            let cache = self.cache.read();
253            if let Some(data) = &cache.data
254                && let Some(expires_at) = cache.expires_at
255                && now < expires_at
256            {
257                return Ok(data.clone());
258            }
259        }
260
261        let jwks_url = format!(
262            "{}{}",
263            issuer_url.trim_end_matches('/'),
264            "/.well-known/jwks.json"
265        );
266
267        let client = default_http_client();
268        let response = client.get(&jwks_url).send().await.map_err(|e| {
269            OidcError::JwksResolutionFailed(format!(
270                "failed to fetch JWKS from {}: {}",
271                jwks_url, e
272            ))
273        })?;
274
275        let jwks: serde_json::Value = response.json().await.map_err(|e| {
276            OidcError::JwksResolutionFailed(format!(
277                "failed to parse JWKS response from {}: {}",
278                jwks_url, e
279            ))
280        })?;
281
282        let mut cache = self.cache.write();
283        cache.data = Some(jwks.clone());
284        // INVARIANT: cache.ttl is always a valid Duration (max 1 hour)
285        #[allow(clippy::expect_used)]
286        let duration_offset = chrono::Duration::from_std(cache.ttl).expect("cache TTL overflow");
287        cache.expires_at = Some(now + duration_offset);
288
289        Ok(jwks)
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use auths_oidc_port::OidcValidationConfig;

    /// Minimal config shared by the tests: only issuer and audience are set.
    fn test_config() -> OidcValidationConfig {
        OidcValidationConfig::builder()
            .issuer("https://example.com")
            .audience("test")
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_http_jwt_validator_missing_kid() {
        let validator = HttpJwtValidator::new(Arc::new(MockJwksClient::new()));

        // Header decodes to {"alg":"RS256","typ":"JWT"}: well-formed, but no `kid`.
        let invalid_token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let config = test_config();

        #[allow(clippy::disallowed_methods)] // Test boundary
        let result = validator.validate(invalid_token, &config, Utc::now()).await;
        assert!(matches!(result, Err(OidcError::JwtDecode(_))));
    }

    #[tokio::test]
    async fn test_algorithm_none_rejected() {
        let validator = HttpJwtValidator::new(Arc::new(MockJwksClient::new()));

        // Header decodes to {"alg":"none","kid":"abc"}, with an empty signature.
        let token_none = "eyJhbGciOiJub25lIiwia2lkIjoiYWJjIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0.";
        let config = test_config();

        #[allow(clippy::disallowed_methods)] // Test boundary
        let result = validator.validate(token_none, &config, Utc::now()).await;
        assert!(matches!(result, Err(OidcError::AlgorithmMismatch { .. })));
    }

    #[tokio::test]
    async fn test_malformed_token_rejected() {
        let validator = HttpJwtValidator::new(Arc::new(MockJwksClient::new()));
        let config = test_config();

        // No '.' separators at all, so the header cannot even be located.
        #[allow(clippy::disallowed_methods)] // Test boundary
        let result = validator.validate("not-a-jwt", &config, Utc::now()).await;
        assert!(matches!(result, Err(OidcError::JwtDecode(_))));
    }

    /// JWKS client that always returns a fixed single-key RSA JWKS.
    struct MockJwksClient;

    impl MockJwksClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait]
    impl JwksClient for MockJwksClient {
        async fn fetch_jwks(&self, _issuer_url: &str) -> Result<serde_json::Value, OidcError> {
            Ok(serde_json::json!({
                "keys": [
                    {
                        "kty": "RSA",
                        "kid": "test-key-1",
                        "use": "sig",
                        "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                        "e": "AQAB"
                    }
                ]
            }))
        }
    }
}