1use std::sync::Arc;
4
5use jsonwebtoken::{jwk::JwkSet, Algorithm, DecodingKey, Validation};
6use serde::de::DeserializeOwned;
7
8use crate::{
9 cache::{JwkCacheKey, JwksCache},
10 config::AuthenticationConfigProvider,
11 error::OidcError,
12};
13
14pub fn validate_token_with_jwk<T: DeserializeOwned>(
20 token: &str,
21 jwk: &jsonwebtoken::jwk::Jwk,
22 alg: Algorithm,
23) -> Result<T, String> {
24 let mut validation = Validation::new(alg);
25 validation.validate_aud = false;
26
27 let decoding_key = DecodingKey::from_jwk(jwk)
28 .map_err(|e| format!("Failed to create DecodingKey from JWK: {e}"))?;
29
30 let token_data = jsonwebtoken::decode::<T>(token, &decoding_key, &validation)
31 .map_err(|e| format!("JWT validation failed: {e}"))?;
32
33 Ok(token_data.claims)
34}
35
36pub async fn fetch_and_cache_jwks(
51 cache: &Arc<dyn JwksCache>,
52 jwks_uri: &str,
53 config: &(impl AuthenticationConfigProvider + Send + Sync),
54 requested_kid: &str,
55) -> Result<jsonwebtoken::jwk::Jwk, OidcError> {
56 let response = reqwest::get(jwks_uri)
58 .await
59 .map_err(|e| OidcError::JwksError(format!("Failed to fetch JWKS: {e}")))?;
60
61 let jwks: JwkSet = response
62 .json()
63 .await
64 .map_err(|e| OidcError::JwksError(format!("Failed to parse JWKS: {e}")))?;
65
66 let jwks_cache_ttl = config.get_jwks_cache_ttl();
68 for jwk in &jwks.keys {
69 if let Some(kid) = &jwk.common.key_id {
70 let jwk_cache_key = JwkCacheKey::from_jwks_uri_and_kid(jwks_uri, kid);
71 if let Ok(serialized) = serde_json::to_string(jwk) {
72 cache.set(jwk_cache_key.as_str(), serialized, jwks_cache_ttl);
73 }
74 }
75 }
76
77 jwks.find(requested_kid)
79 .cloned()
80 .ok_or_else(|| OidcError::JwksError(format!("No JWK found for kid: {requested_kid}")))
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::Algorithm;
    use serde::{Deserialize, Serialize};

    /// Minimal claim set used as the deserialization target in these tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        sub: String,
        exp: usize,
    }

    /// A malformed token checked against a default JWK (only its key type set to RSA)
    /// must be rejected with an error rather than yielding claims.
    #[test]
    fn test_validate_token_with_invalid_jwk() {
        let token = "invalid.token.here";
        let mut jwk = jsonwebtoken::jwk::Jwk::default();
        jwk.common.key_type = Some(jsonwebtoken::jwk::KeyType::RSA);

        let result = validate_token_with_jwk::<TestClaims>(token, &jwk, Algorithm::RS256);
        assert!(result.is_err());
    }
}