1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::Jwk};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use std::time::Duration;
8use zeroize::Zeroize;
9
10use crate::default_http_client;
11use auths_oidc_port::{JwksClient, JwtValidator, OidcError, OidcValidationConfig};
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
25pub struct OidcTokenClaims {
26 pub sub: String,
28 pub iss: String,
30 pub aud: String,
32 pub exp: i64,
34 #[serde(default)]
36 pub iat: i64,
37 #[serde(default)]
39 pub nbf: Option<i64>,
40 #[serde(default)]
42 pub jti: Option<String>,
43 #[serde(flatten)]
45 pub extra: serde_json::Map<String, serde_json::Value>,
46}
47
48pub struct HttpJwtValidator {
58 jwks_client: Arc<dyn JwksClient>,
59}
60
61impl HttpJwtValidator {
62 pub fn new(jwks_client: Arc<dyn JwksClient>) -> Self {
68 Self { jwks_client }
69 }
70}
71
72#[async_trait]
73impl JwtValidator for HttpJwtValidator {
74 async fn validate(
75 &self,
76 token: &str,
77 config: &OidcValidationConfig,
78 now: DateTime<Utc>,
79 ) -> Result<serde_json::Value, OidcError> {
80 let mut token_mut = token.to_string();
81
82 let header = decode_header(&token_mut).map_err(|e| {
83 let error_msg = format!("{}", e);
84 if error_msg.contains("unknown variant") && error_msg.contains("expected one of") {
86 OidcError::AlgorithmMismatch {
87 expected: "RS256, RS384, RS512, ES256, ES384, PS256, PS384, PS512, or EdDSA"
88 .to_string(),
89 got: "unsupported algorithm".to_string(),
90 }
91 } else {
92 OidcError::JwtDecode(format!("failed to decode JWT header: {}", e))
93 }
94 })?;
95
96 let kid = header
97 .kid
98 .ok_or_else(|| OidcError::JwtDecode("JWT header missing 'kid' field".to_string()))?;
99
100 let alg_str = format!("{:?}", header.alg);
101 if alg_str.to_uppercase() == "NONE" {
102 return Err(OidcError::AlgorithmMismatch {
103 expected: "RS256 or ES256".to_string(),
104 got: "none".to_string(),
105 });
106 }
107
108 if !config
109 .allowed_algorithms
110 .iter()
111 .any(|allowed| allowed.to_uppercase() == alg_str.to_uppercase())
112 {
113 return Err(OidcError::AlgorithmMismatch {
114 expected: config.allowed_algorithms.join(", "),
115 got: alg_str.clone(),
116 });
117 }
118
119 let jwks = self.jwks_client.fetch_jwks(&config.issuer).await?;
120
121 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
122 OidcError::JwksResolutionFailed("JWKS response missing 'keys' array".to_string())
123 })?;
124
125 let key_obj = keys
126 .iter()
127 .find(|key| {
128 key.get("kid")
129 .and_then(|k| k.as_str())
130 .map(|k| k == kid)
131 .unwrap_or(false)
132 })
133 .ok_or_else(|| OidcError::UnknownKeyId(kid.clone()))?;
134
135 let jwk: Jwk = serde_json::from_value(key_obj.clone()).map_err(|e| {
136 OidcError::JwksResolutionFailed(format!(
137 "failed to parse JWKS key for kid {}: {}",
138 kid, e
139 ))
140 })?;
141
142 let decoding_key = DecodingKey::from_jwk(&jwk).map_err(|e| {
143 OidcError::JwksResolutionFailed(format!(
144 "failed to create decoding key for kid {}: {}",
145 kid, e
146 ))
147 })?;
148
149 let now_secs = now.timestamp();
150 let leeway = config.max_clock_skew_secs as u64;
151
152 let algorithm = match alg_str.to_uppercase().as_str() {
153 "RS256" => Algorithm::RS256,
154 "ES256" => Algorithm::ES256,
155 _ => {
156 return Err(OidcError::AlgorithmMismatch {
157 expected: "RS256 or ES256".to_string(),
158 got: alg_str,
159 });
160 }
161 };
162
163 let mut validation = Validation::new(algorithm);
164
165 validation.set_issuer(&[&config.issuer]);
166 validation.set_audience(&[&config.audience]);
167 validation.leeway = leeway;
168 validation.validate_exp = true;
169 validation.set_required_spec_claims(&["exp", "iss", "aud", "sub"]);
170
171 let token_data = decode::<OidcTokenClaims>(&token_mut, &decoding_key, &validation)
172 .map_err(|e| {
173 let error_msg = format!("{}", e);
174 if error_msg.contains("ExpiredSignature") || error_msg.contains("InvalidIssuedAt") {
175 OidcError::ClockSkewExceeded {
176 token_exp: 0,
177 current_time: now_secs,
178 leeway: leeway as i64,
179 }
180 } else if error_msg.contains("InvalidSignature") {
181 OidcError::SignatureVerificationFailed
182 } else if error_msg.contains("InvalidIssuer") {
183 OidcError::ClaimsValidationFailed {
184 claim: "iss".to_string(),
185 reason: "issuer mismatch".to_string(),
186 }
187 } else if error_msg.contains("InvalidAudience") {
188 OidcError::ClaimsValidationFailed {
189 claim: "aud".to_string(),
190 reason: "audience mismatch".to_string(),
191 }
192 } else {
193 OidcError::JwtDecode(format!("JWT validation failed: {}", e))
194 }
195 })?;
196
197 token_mut.zeroize();
198
199 let mut json = serde_json::json!(token_data.claims);
200 if let Some(obj) = json.as_object_mut() {
201 for (k, v) in token_data.claims.extra.iter() {
202 obj.insert(k.clone(), v.clone());
203 }
204 }
205
206 Ok(json)
207 }
208}
209
210pub struct HttpJwksClient {
215 cache: Arc<RwLock<JwksCache>>,
216}
217
218struct JwksCache {
219 data: Option<serde_json::Value>,
220 expires_at: Option<DateTime<Utc>>,
221 ttl: Duration,
222}
223
224impl HttpJwksClient {
225 pub fn new(ttl: Duration) -> Self {
231 Self {
232 cache: Arc::new(RwLock::new(JwksCache {
233 data: None,
234 expires_at: None,
235 ttl,
236 })),
237 }
238 }
239
240 pub fn with_default_ttl() -> Self {
242 Self::new(Duration::from_secs(3600))
243 }
244}
245
246#[async_trait]
247impl JwksClient for HttpJwksClient {
248 async fn fetch_jwks(&self, issuer_url: &str) -> Result<serde_json::Value, OidcError> {
249 #[allow(clippy::disallowed_methods)] let now = Utc::now();
251 {
252 let cache = self.cache.read();
253 if let Some(data) = &cache.data
254 && let Some(expires_at) = cache.expires_at
255 && now < expires_at
256 {
257 return Ok(data.clone());
258 }
259 }
260
261 let jwks_url = format!(
262 "{}{}",
263 issuer_url.trim_end_matches('/'),
264 "/.well-known/jwks.json"
265 );
266
267 let client = default_http_client();
268 let response = client.get(&jwks_url).send().await.map_err(|e| {
269 OidcError::JwksResolutionFailed(format!(
270 "failed to fetch JWKS from {}: {}",
271 jwks_url, e
272 ))
273 })?;
274
275 let jwks: serde_json::Value = response.json().await.map_err(|e| {
276 OidcError::JwksResolutionFailed(format!(
277 "failed to parse JWKS response from {}: {}",
278 jwks_url, e
279 ))
280 })?;
281
282 let mut cache = self.cache.write();
283 cache.data = Some(jwks.clone());
284 #[allow(clippy::expect_used)]
286 let duration_offset = chrono::Duration::from_std(cache.ttl).expect("cache TTL overflow");
287 cache.expires_at = Some(now + duration_offset);
288
289 Ok(jwks)
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use auths_oidc_port::OidcValidationConfig;

    /// JWKS client stub that always returns one fixed RSA key (`test-key-1`).
    struct MockJwksClient;

    impl MockJwksClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait]
    impl JwksClient for MockJwksClient {
        async fn fetch_jwks(&self, _issuer_url: &str) -> Result<serde_json::Value, OidcError> {
            Ok(serde_json::json!({
                "keys": [
                    {
                        "kty": "RSA",
                        "kid": "test-key-1",
                        "use": "sig",
                        "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                        "e": "AQAB"
                    }
                ]
            }))
        }
    }

    /// Validator wired to the mock JWKS client.
    fn mock_validator() -> HttpJwtValidator {
        HttpJwtValidator::new(Arc::new(MockJwksClient::new()))
    }

    /// Validation config shared by the tests below.
    fn test_config() -> OidcValidationConfig {
        OidcValidationConfig::builder()
            .issuer("https://example.com")
            .audience("test")
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_http_jwt_validator_missing_kid() {
        let validator = mock_validator();
        let config = test_config();

        // Header is {"alg":"RS256","typ":"JWT"}: well-formed, but without `kid`.
        let invalid_token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

        #[allow(clippy::disallowed_methods)]
        let result = validator.validate(invalid_token, &config, Utc::now()).await;

        // Pin the specific failure so an unrelated error cannot make this pass.
        assert!(matches!(result, Err(OidcError::JwtDecode(_))));
    }

    #[tokio::test]
    async fn test_algorithm_none_rejected() {
        let validator = mock_validator();
        let config = test_config();

        // Header is {"alg":"none","kid":"abc"} with an empty signature part.
        let token_none = "eyJhbGciOiJub25lIiwia2lkIjoiYWJjIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0.";

        #[allow(clippy::disallowed_methods)]
        let result = validator.validate(token_none, &config, Utc::now()).await;

        assert!(matches!(result, Err(OidcError::AlgorithmMismatch { .. })));
    }
}