1use std::time::{Duration, SystemTime};
2
3use base64::{decode_config, URL_SAFE_NO_PAD};
4use regex::Regex;
5use reqwest;
6use reqwest::Response;
7use ring::signature::{RsaPublicKeyComponents, RSA_PKCS1_2048_8192_SHA256};
8use serde::{
9 de::DeserializeOwned,
10 {Deserialize, Serialize},
11};
12use serde_json::Value;
13
14use crate::error::*;
15use crate::jwt::*;
16
17type HeaderBody = String;
18pub type Signature = String;
19
20#[derive(Debug, Serialize, Deserialize)]
21pub struct JwtKey {
22 #[serde(default)] pub e: String,
24 pub kty: String,
25 pub alg: String,
26 #[serde(default)] pub n: String,
28 pub kid: String,
29}
30
31impl JwtKey {
32 pub fn new(kid: &str, n: &str, e: &str) -> JwtKey {
33 JwtKey {
34 e: e.to_owned(),
35 kty: "JTW".to_string(),
36 alg: "RS256".to_string(),
37 n: n.to_owned(),
38 kid: kid.to_owned(),
39 }
40 }
41}
42
43impl Clone for JwtKey {
44 fn clone(&self) -> Self {
45 JwtKey {
46 e: self.e.clone(),
47 kty: self.kty.clone(),
48 alg: self.alg.clone(),
49 n: self.n.clone(),
50 kid: self.kid.clone(),
51 }
52 }
53}
54
55pub struct KeyStore {
56 key_url: String,
57 keys: Vec<JwtKey>,
58 refresh_interval: f64,
59 load_time: Option<SystemTime>,
60 expire_time: Option<SystemTime>,
61 refresh_time: Option<SystemTime>,
62}
63
64impl KeyStore {
65 pub fn new() -> KeyStore {
66 let key_store = KeyStore {
67 key_url: "".to_owned(),
68 keys: vec![],
69 refresh_interval: 0.5,
70 load_time: None,
71 expire_time: None,
72 refresh_time: None,
73 };
74
75 key_store
76 }
77
78 pub async fn new_from(jkws_url: &str) -> Result<KeyStore, Error> {
79 let mut key_store = KeyStore::new();
80
81 key_store.key_url = jkws_url.to_string();
82
83 key_store.load_keys().await?;
84
85 Ok(key_store)
86 }
87
88 pub fn clear_keys(&mut self) {
89 self.keys.clear();
90 }
91
92 pub fn key_set_url(&self) -> &str {
93 &self.key_url
94 }
95
96 pub async fn load_keys_from(&mut self, url: &str) -> Result<(), Error> {
97 self.key_url = url.to_owned();
98
99 self.load_keys().await?;
100
101 Ok(())
102 }
103
104 pub async fn load_keys(&mut self) -> Result<(), Error> {
105 #[derive(Deserialize)]
106 pub struct JwtKeys {
107 pub keys: Vec<JwtKey>,
108 }
109
110 let mut response = reqwest::get(&self.key_url).await.map_err(|_| err_con("Could not download JWKS"))?;
111
112 let load_time = SystemTime::now();
113 self.load_time = Some(load_time);
114
115 let result = KeyStore::cache_max_age(&mut response);
116
117 if let Ok(value) = result {
118 let expire = load_time + Duration::new(value, 0);
119 self.expire_time = Some(expire);
120 let refresh_time = (value as f64 * self.refresh_interval) as u64;
121 let refresh = load_time + Duration::new(refresh_time, 0);
122 self.refresh_time = Some(refresh);
123 }
124
125 let jwks = response.json::<JwtKeys>().await.map_err(|_| err_int("Failed to parse keys"))?;
126
127 jwks.keys.iter().for_each(|k| self.add_key(k));
128
129 Ok(())
130 }
131
132 fn cache_max_age(response: &mut Response) -> Result<u64, ()> {
133 let header = response.headers().get("cache-control").ok_or(())?;
134
135 let header_text = header.to_str().map_err(|_| ())?;
136
137 let re = Regex::new("max-age\\s*=\\s*(\\d+)").map_err(|_| ())?;
138
139 let captures = re.captures(header_text).ok_or(())?;
140
141 let capture = captures.get(1).ok_or(())?;
142
143 let text = capture.as_str();
144
145 let value = text.parse::<u64>().map_err(|_| ())?;
146
147 Ok(value)
148 }
149
150 pub fn key_by_id(&self, kid: &str) -> Option<&JwtKey> {
152 self.keys.iter().find(|k| k.kid == kid)
153 }
154
155 pub fn keys_len(&self) -> usize {
157 self.keys.len()
158 }
159
160 pub fn add_key(&mut self, key: &JwtKey) {
162 self.keys.push(key.clone());
163 }
164
165 fn decode_segments(&self, token: &str) -> Result<(Header, Payload, Signature, HeaderBody), Error> {
166 let raw_segments: Vec<&str> = token.split(".").collect();
167 if raw_segments.len() != 3 {
168 return Err(err_inv("JWT does not have 3 segments"));
169 }
170
171 let header_segment = raw_segments[0];
172 let payload_segment = raw_segments[1];
173 let signature_segment = raw_segments[2].to_string();
174
175 let header = Header::new(decode_segment::<Value>(header_segment).or(Err(err_hea("Failed to decode header")))?);
176 let payload = Payload::new(decode_segment::<Value>(payload_segment).or(Err(err_pay("Failed to decode payload")))?);
177
178 let body = format!("{}.{}", header_segment, payload_segment);
179
180 Ok((header, payload, signature_segment, body))
181 }
182
183 pub fn decode(&self, token: &str) -> Result<Jwt, Error> {
184 let (header, payload, signature, _) = self.decode_segments(token)?;
185
186 Ok(Jwt::new(header, payload, signature))
187 }
188
189 pub fn verify_time(&self, token: &str, time: SystemTime) -> Result<Jwt, Error> {
190 let (header, payload, signature, body) = self.decode_segments(token)?;
191
192 if header.alg() != Some("RS256") {
193 return Err(err_inv("Unsupported algorithm"));
194 }
195
196 let kid = header.kid().ok_or(err_key("No key id"))?;
197
198 let key = self.key_by_id(kid).ok_or(err_key("JWT key does not exists"))?;
199
200 let e = decode_config(&key.e, URL_SAFE_NO_PAD).or(Err(err_cer("Failed to decode exponent")))?;
201 let n = decode_config(&key.n, URL_SAFE_NO_PAD).or(Err(err_cer("Failed to decode modulus")))?;
202
203 verify_signature(&e, &n, &body, &signature)?;
204
205 let jwt = Jwt::new(header, payload, signature);
206
207 if jwt.expired_time(time).unwrap_or(false) {
208 return Err(err_exp("Token expired"));
209 }
210 if jwt.early_time(time).unwrap_or(false) {
211 return Err(err_nbf("Too early to use token (nbf)"));
212 }
213
214 Ok(jwt)
215 }
216
217 pub fn verify(&self, token: &str) -> Result<Jwt, Error> {
227 self.verify_time(token, SystemTime::now())
228 }
229
230 pub fn last_load_time(&self) -> Option<SystemTime> {
232 self.load_time
233 }
234
235 pub fn keys_expired(&self) -> Option<bool> {
239 match self.expire_time {
240 Some(expire) => Some(expire <= SystemTime::now()),
241 None => None,
242 }
243 }
244
245 pub fn set_refresh_interval(&mut self, interval: f64) {
251 self.refresh_interval = interval;
252 }
253
254 pub fn refresh_interval(&self) -> f64 {
256 self.refresh_interval
257 }
258
259 pub fn load_time(&self) -> Option<SystemTime> {
262 self.load_time
263 }
264
265 pub fn expire_time(&self) -> Option<SystemTime> {
267 self.expire_time
268 }
269
270 pub fn refresh_time(&self) -> Option<SystemTime> {
272 self.refresh_time
273 }
274
275 pub fn should_refresh_time(&self, current_time: SystemTime) -> Option<bool> {
280 if let Some(refresh_time) = self.refresh_time {
281 return Some(refresh_time <= current_time);
282 }
283
284 None
285 }
286
287 pub fn should_refresh(&self) -> Option<bool> {
292 self.should_refresh_time(SystemTime::now())
293 }
294}
295
296fn verify_signature(e: &Vec<u8>, n: &Vec<u8>, message: &str, signature: &str) -> Result<(), Error> {
297 let pkc = RsaPublicKeyComponents { e, n };
298
299 let message_bytes = &message.as_bytes().to_vec();
300 let signature_bytes = decode_config(&signature, URL_SAFE_NO_PAD).or(Err(err_sig("Could not base64 decode signature")))?;
301
302 let result = pkc.verify(&RSA_PKCS1_2048_8192_SHA256, &message_bytes, &signature_bytes);
303
304 result.or(Err(err_cer("Signature does not match certificate")))
305}
306
307fn decode_segment<T: DeserializeOwned>(segment: &str) -> Result<T, Error> {
308 let raw = decode_config(segment, base64::URL_SAFE_NO_PAD).or(Err(err_inv("Failed to decode segment")))?;
309 let slice = String::from_utf8_lossy(&raw);
310 let decoded: T = serde_json::from_str(&slice).or(Err(err_inv("Failed to decode segment")))?;
311
312 Ok(decoded)
313}