jwks_client/
keyset.rs

1use std::time::{Duration, SystemTime};
2
3use base64::{decode_config, URL_SAFE_NO_PAD};
4use regex::Regex;
5use reqwest;
6use reqwest::Response;
7use ring::signature::{RsaPublicKeyComponents, RSA_PKCS1_2048_8192_SHA256};
8use serde::{
9    de::DeserializeOwned,
10    {Deserialize, Serialize},
11};
12use serde_json::Value;
13
14use crate::error::*;
15use crate::jwt::*;
16
17type HeaderBody = String;
18pub type Signature = String;
19
20#[derive(Debug, Serialize, Deserialize)]
21pub struct JwtKey {
22    #[serde(default)] // https://github.com/jfbilodeau/jwks-client/issues/1
23    pub e: String,
24    pub kty: String,
25    pub alg: String,
26    #[serde(default)] // https://github.com/jfbilodeau/jwks-client/issues/1
27    pub n: String,
28    pub kid: String,
29}
30
31impl JwtKey {
32    pub fn new(kid: &str, n: &str, e: &str) -> JwtKey {
33        JwtKey {
34            e: e.to_owned(),
35            kty: "JTW".to_string(),
36            alg: "RS256".to_string(),
37            n: n.to_owned(),
38            kid: kid.to_owned(),
39        }
40    }
41}
42
43impl Clone for JwtKey {
44    fn clone(&self) -> Self {
45        JwtKey {
46            e: self.e.clone(),
47            kty: self.kty.clone(),
48            alg: self.alg.clone(),
49            n: self.n.clone(),
50            kid: self.kid.clone(),
51        }
52    }
53}
54
55pub struct KeyStore {
56    key_url: String,
57    keys: Vec<JwtKey>,
58    refresh_interval: f64,
59    load_time: Option<SystemTime>,
60    expire_time: Option<SystemTime>,
61    refresh_time: Option<SystemTime>,
62}
63
64impl KeyStore {
65    pub fn new() -> KeyStore {
66        let key_store = KeyStore {
67            key_url: "".to_owned(),
68            keys: vec![],
69            refresh_interval: 0.5,
70            load_time: None,
71            expire_time: None,
72            refresh_time: None,
73        };
74
75        key_store
76    }
77
78    pub async fn new_from(jkws_url: &str) -> Result<KeyStore, Error> {
79        let mut key_store = KeyStore::new();
80
81        key_store.key_url = jkws_url.to_string();
82
83        key_store.load_keys().await?;
84
85        Ok(key_store)
86    }
87
88    pub fn clear_keys(&mut self) {
89        self.keys.clear();
90    }
91
92    pub fn key_set_url(&self) -> &str {
93        &self.key_url
94    }
95
96    pub async fn load_keys_from(&mut self, url: &str) -> Result<(), Error> {
97        self.key_url = url.to_owned();
98
99        self.load_keys().await?;
100
101        Ok(())
102    }
103
104    pub async fn load_keys(&mut self) -> Result<(), Error> {
105        #[derive(Deserialize)]
106        pub struct JwtKeys {
107            pub keys: Vec<JwtKey>,
108        }
109
110        let mut response = reqwest::get(&self.key_url).await.map_err(|_| err_con("Could not download JWKS"))?;
111
112        let load_time = SystemTime::now();
113        self.load_time = Some(load_time);
114
115        let result = KeyStore::cache_max_age(&mut response);
116
117        if let Ok(value) = result {
118            let expire = load_time + Duration::new(value, 0);
119            self.expire_time = Some(expire);
120            let refresh_time = (value as f64 * self.refresh_interval) as u64;
121            let refresh = load_time + Duration::new(refresh_time, 0);
122            self.refresh_time = Some(refresh);
123        }
124
125        let jwks = response.json::<JwtKeys>().await.map_err(|_| err_int("Failed to parse keys"))?;
126
127        jwks.keys.iter().for_each(|k| self.add_key(k));
128
129        Ok(())
130    }
131
132    fn cache_max_age(response: &mut Response) -> Result<u64, ()> {
133        let header = response.headers().get("cache-control").ok_or(())?;
134
135        let header_text = header.to_str().map_err(|_| ())?;
136
137        let re = Regex::new("max-age\\s*=\\s*(\\d+)").map_err(|_| ())?;
138
139        let captures = re.captures(header_text).ok_or(())?;
140
141        let capture = captures.get(1).ok_or(())?;
142
143        let text = capture.as_str();
144
145        let value = text.parse::<u64>().map_err(|_| ())?;
146
147        Ok(value)
148    }
149
150    /// Fetch a key by key id (KID)
151    pub fn key_by_id(&self, kid: &str) -> Option<&JwtKey> {
152        self.keys.iter().find(|k| k.kid == kid)
153    }
154
155    /// Number of keys in keystore
156    pub fn keys_len(&self) -> usize {
157        self.keys.len()
158    }
159
160    /// Manually add a key to the keystore
161    pub fn add_key(&mut self, key: &JwtKey) {
162        self.keys.push(key.clone());
163    }
164
165    fn decode_segments(&self, token: &str) -> Result<(Header, Payload, Signature, HeaderBody), Error> {
166        let raw_segments: Vec<&str> = token.split(".").collect();
167        if raw_segments.len() != 3 {
168            return Err(err_inv("JWT does not have 3 segments"));
169        }
170
171        let header_segment = raw_segments[0];
172        let payload_segment = raw_segments[1];
173        let signature_segment = raw_segments[2].to_string();
174
175        let header = Header::new(decode_segment::<Value>(header_segment).or(Err(err_hea("Failed to decode header")))?);
176        let payload = Payload::new(decode_segment::<Value>(payload_segment).or(Err(err_pay("Failed to decode payload")))?);
177
178        let body = format!("{}.{}", header_segment, payload_segment);
179
180        Ok((header, payload, signature_segment, body))
181    }
182
183    pub fn decode(&self, token: &str) -> Result<Jwt, Error> {
184        let (header, payload, signature, _) = self.decode_segments(token)?;
185
186        Ok(Jwt::new(header, payload, signature))
187    }
188
189    pub fn verify_time(&self, token: &str, time: SystemTime) -> Result<Jwt, Error> {
190        let (header, payload, signature, body) = self.decode_segments(token)?;
191
192        if header.alg() != Some("RS256") {
193            return Err(err_inv("Unsupported algorithm"));
194        }
195
196        let kid = header.kid().ok_or(err_key("No key id"))?;
197
198        let key = self.key_by_id(kid).ok_or(err_key("JWT key does not exists"))?;
199
200        let e = decode_config(&key.e, URL_SAFE_NO_PAD).or(Err(err_cer("Failed to decode exponent")))?;
201        let n = decode_config(&key.n, URL_SAFE_NO_PAD).or(Err(err_cer("Failed to decode modulus")))?;
202
203        verify_signature(&e, &n, &body, &signature)?;
204
205        let jwt = Jwt::new(header, payload, signature);
206
207        if jwt.expired_time(time).unwrap_or(false) {
208            return Err(err_exp("Token expired"));
209        }
210        if jwt.early_time(time).unwrap_or(false) {
211            return Err(err_nbf("Too early to use token (nbf)"));
212        }
213
214        Ok(jwt)
215    }
216
217    /// Verify a JWT token.
218    /// If the token is valid, it is returned.
219    ///
220    /// A token is considered valid if:
221    /// * Is well formed
222    /// * Has a `kid` field that matches a public signature `kid
223    /// * Signature matches public key
224    /// * It is not expired
225    /// * The `nbf` is not set to before now
226    pub fn verify(&self, token: &str) -> Result<Jwt, Error> {
227        self.verify_time(token, SystemTime::now())
228    }
229
230    /// Time at which the keys were last refreshed
231    pub fn last_load_time(&self) -> Option<SystemTime> {
232        self.load_time
233    }
234
235    /// True if the keys are expired and should be refreshed
236    ///
237    /// None if keys do not have an expiration time
238    pub fn keys_expired(&self) -> Option<bool> {
239        match self.expire_time {
240            Some(expire) => Some(expire <= SystemTime::now()),
241            None => None,
242        }
243    }
244
245    /// Specifies the interval (as a fraction) when the key store should refresh it's key.
246    ///
247    /// The default is 0.5, meaning that keys should be refreshed when we are halfway through the expiration time (similar to DHCP).
248    ///
249    /// This method does _not_ update the refresh time. Call `load_keys` to force an update on the refresh time property.
250    pub fn set_refresh_interval(&mut self, interval: f64) {
251        self.refresh_interval = interval;
252    }
253
254    /// Get the current fraction time to check for token refresh time.
255    pub fn refresh_interval(&self) -> f64 {
256        self.refresh_interval
257    }
258
259    /// The time at which the keys were loaded
260    /// None if the keys were never loaded via `load_keys` or `load_keys_from`.
261    pub fn load_time(&self) -> Option<SystemTime> {
262        self.load_time
263    }
264
265    /// Get the time at which the keys are considered expired
266    pub fn expire_time(&self) -> Option<SystemTime> {
267        self.expire_time
268    }
269
270    /// time at which keys should be refreshed.
271    pub fn refresh_time(&self) -> Option<SystemTime> {
272        self.refresh_time
273    }
274
275    /// Returns `Option<true>` if keys should be refreshed based on the given `current_time`.
276    ///
277    /// None is returned if the key store does not have a refresh time available. For example, the
278    /// `load_keys` function was not called or the HTTP server did not provide a  
279    pub fn should_refresh_time(&self, current_time: SystemTime) -> Option<bool> {
280        if let Some(refresh_time) = self.refresh_time {
281            return Some(refresh_time <= current_time);
282        }
283
284        None
285    }
286
287    /// Returns `Option<true>` if keys should be refreshed based on the system time.
288    ///
289    /// None is returned if the key store does not have a refresh time available. For example, the
290    /// `load_keys` function was not called or the HTTP server did not provide a  
291    pub fn should_refresh(&self) -> Option<bool> {
292        self.should_refresh_time(SystemTime::now())
293    }
294}
295
296fn verify_signature(e: &Vec<u8>, n: &Vec<u8>, message: &str, signature: &str) -> Result<(), Error> {
297    let pkc = RsaPublicKeyComponents { e, n };
298
299    let message_bytes = &message.as_bytes().to_vec();
300    let signature_bytes = decode_config(&signature, URL_SAFE_NO_PAD).or(Err(err_sig("Could not base64 decode signature")))?;
301
302    let result = pkc.verify(&RSA_PKCS1_2048_8192_SHA256, &message_bytes, &signature_bytes);
303
304    result.or(Err(err_cer("Signature does not match certificate")))
305}
306
307fn decode_segment<T: DeserializeOwned>(segment: &str) -> Result<T, Error> {
308    let raw = decode_config(segment, base64::URL_SAFE_NO_PAD).or(Err(err_inv("Failed to decode segment")))?;
309    let slice = String::from_utf8_lossy(&raw);
310    let decoded: T = serde_json::from_str(&slice).or(Err(err_inv("Failed to decode segment")))?;
311
312    Ok(decoded)
313}