jwk_box/
lib.rs

//! # A basic JWK client.
//!
//! Fetches public keys from a `jwks_uri` to validate JWTs. Keys are refreshed automatically.
4
5use std::collections::{HashMap, HashSet};
6
7use jwt_simple::{
8    algorithms::RSAPublicKeyLike,
9    prelude::{
10        Token,
11        Serialize,
12        VerificationOptions,
13        RS256PublicKey,
14        JWTClaims,
15    },
16};
17use chrono::{Duration, DateTime, Utc};
18use serde::Deserialize;
19use serde_with::{
20    serde_as,
21    base64::{Base64, UrlSafe},
22    formats::Unpadded,
23};
24
25
26mod error;
27pub use error::JwkClientErr;
28
29/// # Defaults
30///
31/// - If public keys are older than `auto_refresh_interval`, the keys are refreshed before token validation. Defaults to an hour.
32/// - Reactively refreshes public keys and retries token validation on validation failure, limited to once per `retry_rate_limit`. Defaults to 5 minutes.
33#[derive(Debug, Clone)]
34pub struct JwkClient {
35    jwks_uri: String,
36    issuer: String,
37    audience: String,
38    public_keys: HashMap<String, PublicKey>, // `kid` -> PublicKey
39    // how often JWK will be fetched proactively before token validation, i.e. how
40    // long before JWK will be considered stale
41    auto_refresh_interval: Duration,
42    // limit how often JWK will be fetched reactively after failed token validation
43    retry_rate_limit: Duration,
44    // last time JWK were fetched proactively before token validation
45    last_refresh: Option<DateTime<Utc>>,
46    // last time JWK were fetched reactively after failed token validation
47    last_retry: Option<DateTime<Utc>>,
48}
49
50#[derive(Debug, Clone)]
51struct PublicKey {
52    key: RS256PublicKey,
53    not_before: Option<DateTime<Utc>>,
54}
55
56impl PublicKey {
57    /// Check if key is valid (not_before is either None or in the past)
58    /// Returns true if the key is currently valid
59    fn is_valid(&self) -> bool {
60        self.not_before.is_none_or(|nbf| nbf <= Utc::now())
61    }
62
63    /// Returns the key if it's currently valid, None otherwise
64    fn valid_key(&self) -> Option<&RS256PublicKey> {
65        self.is_valid().then_some(&self.key)
66    }
67}
68
69impl JwkClient {
70    pub fn new(
71        jwks_uri: impl Into<String>,
72        issuer: impl Into<String>,
73        audience: impl Into<String>,
74    ) -> Self {
75        Self {
76            jwks_uri: jwks_uri.into(),
77            issuer: issuer.into(),
78            audience: audience.into(),
79            public_keys: HashMap::new(),
80            auto_refresh_interval: Duration::hours(1),
81            retry_rate_limit: Duration::minutes(5),
82            last_refresh: None,
83            last_retry: None,
84        }
85    }
86
87    pub fn set_auto_refresh_interval(&mut self, duration: Duration) {
88        self.auto_refresh_interval = duration;
89    }
90
91    pub fn set_retry_rate_limit(&mut self, duration: Duration) {
92        self.retry_rate_limit = duration;
93    }
94
95    fn keys_are_stale(&self) -> bool {
96        self.last_refresh
97            .map(|t| Utc::now() - t > self.auto_refresh_interval)
98            .unwrap_or(true)
99    }
100
101    fn can_retry_on_failure(&self) -> bool {
102        self.last_retry
103            .map(|t| Utc::now() - t > self.retry_rate_limit)
104            .unwrap_or(true)
105    }
106
107    async fn refresh_public_keys(&mut self) -> Result<(), JwkClientErr> {
108        let public_keys: Result<_, _> = reqwest::get(&self.jwks_uri)
109            .await?
110            .json::<JwkRawArray>()
111            .await?
112            .keys
113            .into_iter()
114            .map(|jwk| {
115                let key = RS256PublicKey::from_components(&jwk.modulus, &jwk.exponent)?;
116                Ok::<(std::string::String, PublicKey), JwkClientErr>((jwk.key_id, PublicKey {
117                    key,
118                    not_before: jwk.not_before,
119                }))
120            })
121            .collect();
122
123        self.public_keys = public_keys?;
124        self.last_refresh = Some(Utc::now());
125
126        Ok(())
127    }
128
129    fn get_valid_key(&self, key_id: &str) -> Option<&RS256PublicKey> {
130        self.public_keys
131            .get(key_id)?
132            .valid_key()
133    }
134
135    pub async fn validate_token<T>(&mut self, token: &str) -> Result<JWTClaims<T>, JwkClientErr>
136    where
137        for<'de> T: Serialize + Deserialize<'de>,
138    {
139        if self.keys_are_stale() {
140            self.refresh_public_keys().await?;
141        }
142
143        match self.validate_token_impl(token).await {
144            // Retry if we haven't retried recently
145            Err(_) if self.can_retry_on_failure() => {
146                self.refresh_public_keys().await?;
147                self.last_retry = Some(Utc::now());
148                self.validate_token_impl(token).await
149            },
150            // Otherwise, return the first result
151            result => result,
152        }
153    }
154
155    async fn validate_token_impl<T>(
156        &mut self,
157        token: &str,
158    ) -> Result<JWTClaims<T>, JwkClientErr>
159    where
160        for<'de> T: Serialize + Deserialize<'de>,
161    {
162        let verification_options = VerificationOptions {
163            allowed_issuers: Some(HashSet::from([self.issuer.clone()])),
164            allowed_audiences: Some(HashSet::from([self.audience.clone()])),
165            ..Default::default()
166        };
167
168        let metadata = Token::decode_metadata(token)?;
169
170        let key_id = metadata
171            .key_id()
172            .ok_or(JwkClientErr::Other("token is missing public key id `kid`".to_string()))?;
173
174        let key = self.get_valid_key(key_id)
175            .ok_or(JwkClientErr::Other("token's public key id `kid` not found".to_string()))?;
176
177        key.verify_token::<T>(token, Some(verification_options))
178            .map_err(JwkClientErr::from)
179    }
180
181}
182
183
184#[derive(Debug, Deserialize)]
185struct JwkRawArray {
186    keys: Vec<JwkRaw>,
187}
188
189#[serde_as]
190#[derive(Debug, Deserialize, Clone)]
191struct JwkRaw {
192    #[serde(rename = "kid")]
193    key_id: String,
194
195    // #[serde(rename = "use")]
196    // key_use: String, // e.g. "sig"
197
198    // #[serde(rename = "kty")]
199    // key_type: String, // e.g. "RSA"
200
201    #[serde(rename = "nbf", with = "chrono::serde::ts_seconds_option")]
202    not_before: Option<DateTime<Utc>>,
203
204    #[serde_as(as = "Base64<UrlSafe, Unpadded>")]
205    #[serde(rename = "e")]
206    exponent: Vec<u8>,
207
208    #[serde_as(as = "Base64<UrlSafe, Unpadded>")]
209    #[serde(rename = "n")]
210    modulus: Vec<u8>,
211}