1use std::collections::{HashMap, HashSet};
6
7use jwt_simple::{
8 algorithms::RSAPublicKeyLike,
9 prelude::{
10 Token,
11 Serialize,
12 VerificationOptions,
13 RS256PublicKey,
14 JWTClaims,
15 },
16};
17use chrono::{Duration, DateTime, Utc};
18use serde::Deserialize;
19use serde_with::{
20 serde_as,
21 base64::{Base64, UrlSafe},
22 formats::Unpadded,
23};
24
25
26mod error;
27pub use error::JwkClientErr;
28
29#[derive(Debug, Clone)]
34pub struct JwkClient {
35 jwks_uri: String,
36 issuer: String,
37 audience: String,
38 public_keys: HashMap<String, PublicKey>, auto_refresh_interval: Duration,
42 retry_rate_limit: Duration,
44 last_refresh: Option<DateTime<Utc>>,
46 last_retry: Option<DateTime<Utc>>,
48}
49
50#[derive(Debug, Clone)]
51struct PublicKey {
52 key: RS256PublicKey,
53 not_before: Option<DateTime<Utc>>,
54}
55
56impl PublicKey {
57 fn is_valid(&self) -> bool {
60 self.not_before.is_none_or(|nbf| nbf <= Utc::now())
61 }
62
63 fn valid_key(&self) -> Option<&RS256PublicKey> {
65 self.is_valid().then_some(&self.key)
66 }
67}
68
69impl JwkClient {
70 pub fn new(
71 jwks_uri: impl Into<String>,
72 issuer: impl Into<String>,
73 audience: impl Into<String>,
74 ) -> Self {
75 Self {
76 jwks_uri: jwks_uri.into(),
77 issuer: issuer.into(),
78 audience: audience.into(),
79 public_keys: HashMap::new(),
80 auto_refresh_interval: Duration::hours(1),
81 retry_rate_limit: Duration::minutes(5),
82 last_refresh: None,
83 last_retry: None,
84 }
85 }
86
87 pub fn set_auto_refresh_interval(&mut self, duration: Duration) {
88 self.auto_refresh_interval = duration;
89 }
90
91 pub fn set_retry_rate_limit(&mut self, duration: Duration) {
92 self.retry_rate_limit = duration;
93 }
94
95 fn keys_are_stale(&self) -> bool {
96 self.last_refresh
97 .map(|t| Utc::now() - t > self.auto_refresh_interval)
98 .unwrap_or(true)
99 }
100
101 fn can_retry_on_failure(&self) -> bool {
102 self.last_retry
103 .map(|t| Utc::now() - t > self.retry_rate_limit)
104 .unwrap_or(true)
105 }
106
107 async fn refresh_public_keys(&mut self) -> Result<(), JwkClientErr> {
108 let public_keys: Result<_, _> = reqwest::get(&self.jwks_uri)
109 .await?
110 .json::<JwkRawArray>()
111 .await?
112 .keys
113 .into_iter()
114 .map(|jwk| {
115 let key = RS256PublicKey::from_components(&jwk.modulus, &jwk.exponent)?;
116 Ok::<(std::string::String, PublicKey), JwkClientErr>((jwk.key_id, PublicKey {
117 key,
118 not_before: jwk.not_before,
119 }))
120 })
121 .collect();
122
123 self.public_keys = public_keys?;
124 self.last_refresh = Some(Utc::now());
125
126 Ok(())
127 }
128
129 fn get_valid_key(&self, key_id: &str) -> Option<&RS256PublicKey> {
130 self.public_keys
131 .get(key_id)?
132 .valid_key()
133 }
134
135 pub async fn validate_token<T>(&mut self, token: &str) -> Result<JWTClaims<T>, JwkClientErr>
136 where
137 for<'de> T: Serialize + Deserialize<'de>,
138 {
139 if self.keys_are_stale() {
140 self.refresh_public_keys().await?;
141 }
142
143 match self.validate_token_impl(token).await {
144 Err(_) if self.can_retry_on_failure() => {
146 self.refresh_public_keys().await?;
147 self.last_retry = Some(Utc::now());
148 self.validate_token_impl(token).await
149 },
150 result => result,
152 }
153 }
154
155 async fn validate_token_impl<T>(
156 &mut self,
157 token: &str,
158 ) -> Result<JWTClaims<T>, JwkClientErr>
159 where
160 for<'de> T: Serialize + Deserialize<'de>,
161 {
162 let verification_options = VerificationOptions {
163 allowed_issuers: Some(HashSet::from([self.issuer.clone()])),
164 allowed_audiences: Some(HashSet::from([self.audience.clone()])),
165 ..Default::default()
166 };
167
168 let metadata = Token::decode_metadata(token)?;
169
170 let key_id = metadata
171 .key_id()
172 .ok_or(JwkClientErr::Other("token is missing public key id `kid`".to_string()))?;
173
174 let key = self.get_valid_key(key_id)
175 .ok_or(JwkClientErr::Other("token's public key id `kid` not found".to_string()))?;
176
177 key.verify_token::<T>(token, Some(verification_options))
178 .map_err(JwkClientErr::from)
179 }
180
181}
182
183
184#[derive(Debug, Deserialize)]
185struct JwkRawArray {
186 keys: Vec<JwkRaw>,
187}
188
189#[serde_as]
190#[derive(Debug, Deserialize, Clone)]
191struct JwkRaw {
192 #[serde(rename = "kid")]
193 key_id: String,
194
195 #[serde(rename = "nbf", with = "chrono::serde::ts_seconds_option")]
202 not_before: Option<DateTime<Utc>>,
203
204 #[serde_as(as = "Base64<UrlSafe, Unpadded>")]
205 #[serde(rename = "e")]
206 exponent: Vec<u8>,
207
208 #[serde_as(as = "Base64<UrlSafe, Unpadded>")]
209 #[serde(rename = "n")]
210 modulus: Vec<u8>,
211}