1use std::collections::HashSet;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use serde_json::map::Map;
5use serde_json::{from_value, Value};
6
7use crate::algorithms::Algorithm;
8use crate::errors::{new_error, ErrorKind, Result};
9
10#[derive(Debug, Clone, PartialEq)]
29pub struct Validation {
30 pub leeway: u64,
35 pub validate_exp: bool,
41 pub validate_nbf: bool,
47 pub aud: Option<HashSet<String>>,
52 pub iss: Option<String>,
57 pub sub: Option<String>,
62 pub algorithms: Vec<Algorithm>,
67}
68
69impl Validation {
70 pub fn new(alg: Algorithm) -> Validation {
72 Validation { algorithms: vec![alg], ..Default::default() }
73 }
74
75 pub fn set_audience<T: ToString>(&mut self, items: &[T]) {
77 self.aud = Some(items.iter().map(|x| x.to_string()).collect())
78 }
79}
80
81impl Default for Validation {
82 fn default() -> Validation {
83 Validation {
84 leeway: 0,
85
86 validate_exp: true,
87 validate_nbf: false,
88
89 iss: None,
90 sub: None,
91 aud: None,
92
93 algorithms: Vec::new(),
94 }
95 }
96}
97
/// Returns the current Unix time in whole seconds.
///
/// # Panics
///
/// Panics with "Time went backwards" if the system clock reports a time before the
/// Unix epoch.
fn get_current_timestamp() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_secs()
}
102
103pub fn validate(claims: &Map<String, Value>, options: &Validation) -> Result<()> {
104 let now = get_current_timestamp();
105
106 if options.validate_exp {
107 if let Some(exp) = claims.get("exp") {
108 if from_value::<u64>(exp.clone())? < now - options.leeway {
109 return Err(new_error(ErrorKind::ExpiredSignature));
110 }
111 } else {
112 return Err(new_error(ErrorKind::ExpiredSignature));
113 }
114 }
115
116 if options.validate_nbf {
117 if let Some(nbf) = claims.get("nbf") {
118 if from_value::<u64>(nbf.clone())? > now + options.leeway {
119 return Err(new_error(ErrorKind::ImmatureSignature));
120 }
121 } else {
122 return Err(new_error(ErrorKind::ImmatureSignature));
123 }
124 }
125
126 if let Some(ref correct_iss) = options.iss {
127 if let Some(iss) = claims.get("iss") {
128 if from_value::<String>(iss.clone())? != *correct_iss {
129 return Err(new_error(ErrorKind::InvalidIssuer));
130 }
131 } else {
132 return Err(new_error(ErrorKind::InvalidIssuer));
133 }
134 }
135
136 if let Some(ref correct_sub) = options.sub {
137 if let Some(sub) = claims.get("sub") {
138 if from_value::<String>(sub.clone())? != *correct_sub {
139 return Err(new_error(ErrorKind::InvalidSubject));
140 }
141 } else {
142 return Err(new_error(ErrorKind::InvalidSubject));
143 }
144 }
145
146 if let Some(ref correct_aud) = options.aud {
147 if let Some(aud) = claims.get("aud") {
148 match aud {
149 Value::String(aud_found) => {
150 if !correct_aud.contains(aud_found) {
151 return Err(new_error(ErrorKind::InvalidAudience));
152 }
153 }
154 Value::Array(_) => {
155 let provided_aud: HashSet<String> = from_value(aud.clone())?;
156 if provided_aud.intersection(correct_aud).count() == 0 {
157 return Err(new_error(ErrorKind::InvalidAudience));
158 }
159 }
160 _ => return Err(new_error(ErrorKind::InvalidAudience)),
161 };
162 } else {
163 return Err(new_error(ErrorKind::InvalidAudience));
164 }
165 }
166
167 Ok(())
168}
169
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use serde_json::map::Map;
    use serde_json::to_value;

    use super::{get_current_timestamp, validate, Validation};

    use crate::errors::ErrorKind;

    #[test]
    fn exp_in_future_ok() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() + 10000).unwrap());
        let res = validate(&claims, &Validation::default());
        assert!(res.is_ok());
    }

    #[test]
    fn exp_in_past_fails() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() - 100000).unwrap());
        let res = validate(&claims, &Validation::default());
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ExpiredSignature => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn exp_in_past_but_in_leeway_ok() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() - 500).unwrap());
        let validation = Validation { leeway: 1000 * 60, ..Default::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    // A missing `exp` must be rejected, not silently accepted.
    #[test]
    fn validation_called_even_if_field_is_empty() {
        let claims = Map::new();
        let res = validate(&claims, &Validation::default());
        assert!(res.is_err());
        match res.unwrap_err().kind() {
            ErrorKind::ExpiredSignature => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn nbf_in_past_ok() {
        let mut claims = Map::new();
        claims.insert("nbf".to_string(), to_value(get_current_timestamp() - 10000).unwrap());
        let validation =
            Validation { validate_exp: false, validate_nbf: true, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_in_future_fails() {
        let mut claims = Map::new();
        claims.insert("nbf".to_string(), to_value(get_current_timestamp() + 100000).unwrap());
        let validation =
            Validation { validate_exp: false, validate_nbf: true, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ImmatureSignature => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn nbf_in_future_but_in_leeway_ok() {
        let mut claims = Map::new();
        claims.insert("nbf".to_string(), to_value(get_current_timestamp() + 500).unwrap());
        let validation = Validation {
            leeway: 1000 * 60,
            validate_nbf: true,
            validate_exp: false,
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_missing_fails() {
        let claims = Map::new();
        let validation =
            Validation { validate_exp: false, validate_nbf: true, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ImmatureSignature => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn iss_ok() {
        let mut claims = Map::new();
        claims.insert("iss".to_string(), to_value("Keats").unwrap());
        let validation = Validation {
            validate_exp: false,
            iss: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("iss".to_string(), to_value("Hacked").unwrap());
        let validation = Validation {
            validate_exp: false,
            iss: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidIssuer => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn iss_missing_fails() {
        let claims = Map::new();
        let validation = Validation {
            validate_exp: false,
            iss: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidIssuer => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn sub_ok() {
        let mut claims = Map::new();
        claims.insert("sub".to_string(), to_value("Keats").unwrap());
        let validation = Validation {
            validate_exp: false,
            sub: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn sub_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("sub".to_string(), to_value("Hacked").unwrap());
        let validation = Validation {
            validate_exp: false,
            sub: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidSubject => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn sub_missing_fails() {
        let claims = Map::new();
        let validation = Validation {
            validate_exp: false,
            sub: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidSubject => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    // Exercises the `Value::String` branch of the audience check (a plain string claim).
    #[test]
    fn aud_string_ok() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value("Everyone").unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["Everyone"]);
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_string_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value("Everyone").unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn aud_array_of_string_ok() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["UserA", "UserB"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_type_mismatch_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    // An `aud` that is neither a string nor an array is rejected as an invalid audience.
    #[test]
    fn aud_non_string_json_type_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(42).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["42"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn aud_correct_type_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["None"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    #[test]
    fn aud_missing_fails() {
        let claims = Map::new();
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["None"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    // `iss` is checked before `sub`, so a token failing both reports the issuer error.
    #[test]
    fn does_validation_in_right_order() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() + 10000).unwrap());
        let v = Validation {
            leeway: 5,
            validate_exp: true,
            iss: Some("iss no check".to_string()),
            sub: Some("sub no check".to_string()),
            ..Validation::default()
        };
        let res = validate(&claims, &v);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidIssuer => (),
            other => panic!("unexpected error kind: {:?}", other),
        };
    }

    // Building the audience set directly (instead of via `set_audience`) works too.
    #[test]
    fn aud_use_validation_struct() {
        let mut claims = Map::new();
        claims.insert(
            "aud".to_string(),
            to_value("my-googleclientid1234.apps.googleusercontent.com").unwrap(),
        );

        let aud = "my-googleclientid1234.apps.googleusercontent.com".to_string();
        let mut aud_hashset = HashSet::new();
        aud_hashset.insert(aud);

        let validation =
            Validation { aud: Some(aud_hashset), validate_exp: false, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_ok(), "{:?}", res);
    }
}