1use std::borrow::Cow;
2use std::collections::HashSet;
3use std::fmt;
4use std::marker::PhantomData;
5
6use serde::de::{self, Visitor};
7use serde::{Deserialize, Deserializer};
8
9use crate::algorithms::Algorithm;
10use crate::errors::{new_error, ErrorKind, Result};
11
/// Configuration describing which registered claims of a decoded token are checked, and how.
///
/// Build one with [`Validation::new`] (or `Default`, which uses `Algorithm::HS256`), then adjust
/// the public fields or use the setters before validating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Validation {
    /// Registered claim names that must be present (and of the expected type) in the token.
    /// Defaults to `{"exp"}`. Only `exp`, `sub`, `iss`, `aud` and `nbf` are checked; any other
    /// name in this set is silently ignored by the validation in this module.
    pub required_spec_claims: HashSet<String>,
    /// Clock-skew allowance, in seconds, applied when comparing `exp` and `nbf` with the
    /// current time. Defaults to 60.
    pub leeway: u64,
    /// Additionally reject tokens whose `exp` is less than this many seconds in the future.
    /// Only applied when `validate_exp` is set. Defaults to 0.
    pub reject_tokens_expiring_in_less_than: u64,
    /// Whether `exp` is compared with the current time. Defaults to `true`.
    /// Whether `exp` must be *present* is controlled separately by `required_spec_claims`.
    pub validate_exp: bool,
    /// Whether `nbf` is compared with the current time. Defaults to `false`.
    pub validate_nbf: bool,
    /// Whether the `aud` claim is checked. Defaults to `true`. When enabled, a token carrying a
    /// well-formed `aud` claim is rejected unless `aud` below is set and overlaps with it.
    pub validate_aud: bool,
    /// Accepted audiences. A token passes when at least one of its `aud` values is in this set.
    pub aud: Option<HashSet<String>>,
    /// Accepted issuers. When set, a token's `iss` (string or array of strings) must share at
    /// least one value with this set. A missing `iss` is only rejected when `"iss"` is listed
    /// in `required_spec_claims`.
    pub iss: Option<HashSet<String>>,
    /// Expected subject. When set, a token's `sub` must be exactly equal to it.
    pub sub: Option<String>,
    /// Algorithms accepted for this validation; `new` seeds it with a single algorithm.
    /// NOTE(review): not read anywhere in this file — presumably matched against the token
    /// header by the decoding code; confirm.
    pub algorithms: Vec<Algorithm>,

    /// Whether the signature is verified. Cleared only by
    /// `insecure_disable_signature_validation`; consumed outside this file.
    pub(crate) validate_signature: bool,
}
110
111impl Validation {
112 pub fn new(alg: Algorithm) -> Validation {
114 let mut required_claims = HashSet::with_capacity(1);
115 required_claims.insert("exp".to_owned());
116
117 Validation {
118 required_spec_claims: required_claims,
119 algorithms: vec![alg],
120 leeway: 60,
121 reject_tokens_expiring_in_less_than: 0,
122
123 validate_exp: true,
124 validate_nbf: false,
125 validate_aud: true,
126
127 iss: None,
128 sub: None,
129 aud: None,
130
131 validate_signature: true,
132 }
133 }
134
135 pub fn set_audience<T: ToString>(&mut self, items: &[T]) {
138 self.aud = Some(items.iter().map(|x| x.to_string()).collect())
139 }
140
141 pub fn set_issuer<T: ToString>(&mut self, items: &[T]) {
144 self.iss = Some(items.iter().map(|x| x.to_string()).collect())
145 }
146
147 pub fn set_required_spec_claims<T: ToString>(&mut self, items: &[T]) {
153 self.required_spec_claims = items.iter().map(|x| x.to_string()).collect();
154 }
155
156 pub fn insecure_disable_signature_validation(&mut self) {
160 self.validate_signature = false;
161 }
162}
163
164impl Default for Validation {
165 fn default() -> Self {
166 Self::new(Algorithm::HS256)
167 }
168}
169
/// Returns the current Unix time in whole seconds, read from the system clock.
///
/// # Panics
/// Panics with "Time went backwards" if the system clock reports a time before the Unix epoch.
#[cfg(not(all(target_arch = "wasm32", not(any(target_os = "emscripten", target_os = "wasi")))))]
#[must_use]
pub fn get_current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_secs()
}
177
/// Returns the current Unix time in whole seconds on browser-style `wasm32` targets.
///
/// `Date.getTime()` reports milliseconds since the Unix epoch as a float; it is truncated to an
/// integer first and then divided down to seconds.
#[cfg(all(target_arch = "wasm32", not(any(target_os = "emscripten", target_os = "wasi"))))]
#[must_use]
pub fn get_current_timestamp() -> u64 {
    let millis_since_epoch = js_sys::Date::new_0().get_time() as u64;
    millis_since_epoch / 1000
}
184
/// Borrowed, lenient view of the registered claims that `validate` inspects.
///
/// Every field is a `TryParse`, so a claim that is absent or has an unexpected type is recorded
/// as `NotPresent` / `FailedToParse` instead of being reported as a deserialization error.
/// String data is borrowed from the input where serde allows it (`#[serde(borrow)]`).
#[derive(Deserialize)]
pub(crate) struct ClaimsForValidation<'a> {
    // NumericDate claims go through `numeric_type` (integers as-is, finite non-negative floats
    // rounded). `default` supplies `TryParse::default()` (`NotPresent`) when the field is absent,
    // since a `deserialize_with` field would otherwise be reported as missing.
    #[serde(deserialize_with = "numeric_type", default)]
    exp: TryParse<u64>,
    #[serde(deserialize_with = "numeric_type", default)]
    nbf: TryParse<u64>,
    // NOTE(review): no `default` on the fields below; an absent field presumably ends up as
    // `NotPresent` because serde's missing-field path feeds `Option::deserialize` a `None` —
    // confirm against the serde version in use.
    #[serde(borrow)]
    sub: TryParse<Cow<'a, str>>,
    #[serde(borrow)]
    iss: TryParse<Issuer<'a>>,
    #[serde(borrow)]
    aud: TryParse<Audience<'a>>,
}
/// Outcome of deserializing one optional claim, keeping "absent" and "present but of the wrong
/// type" apart so validation can decide how strict to be about each.
#[derive(Debug)]
enum TryParse<T> {
    /// The claim was present and deserialized as `T`.
    Parsed(T),
    /// The claim was present but could not be deserialized as `T`; the underlying
    /// deserialization error is discarded.
    FailedToParse,
    /// The claim was absent, or `Option::<T>::deserialize` produced `None` for it.
    NotPresent,
}
204impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse<T> {
205 fn deserialize<D: serde::Deserializer<'de>>(
206 deserializer: D,
207 ) -> std::result::Result<Self, D::Error> {
208 Ok(match Option::<T>::deserialize(deserializer) {
209 Ok(Some(value)) => TryParse::Parsed(value),
210 Ok(None) => TryParse::NotPresent,
211 Err(_) => TryParse::FailedToParse,
212 })
213 }
214}
215impl<T> Default for TryParse<T> {
216 fn default() -> Self {
217 Self::NotPresent
218 }
219}
220
/// The `aud` claim, which a token may carry as a single string or as an array of strings.
///
/// `#[serde(untagged)]` tries the variants in declaration order: a string becomes `Single`,
/// an array of strings becomes `Multiple` (duplicate entries collapse in the set).
#[derive(Deserialize)]
#[serde(untagged)]
enum Audience<'a> {
    Single(#[serde(borrow)] Cow<'a, str>),
    Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}
227
/// The `iss` claim, accepted either as a single string or as an array of strings.
///
/// `#[serde(untagged)]` tries the variants in declaration order: a string becomes `Single`,
/// an array of strings becomes `Multiple` (duplicate entries collapse in the set).
#[derive(Deserialize)]
#[serde(untagged)]
enum Issuer<'a> {
    Single(#[serde(borrow)] Cow<'a, str>),
    Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}
234
/// Newtype around `Cow<str>` used as the element type of the `aud`/`iss` claim sets.
///
/// Its `Borrow<str>` impl lets those sets be queried with a plain `&str`. NOTE(review): the
/// newtype presumably exists so `#[serde(borrow)]` can apply to each set element (a bare
/// `Cow<str>` inside a collection would deserialize as owned) — confirm.
#[derive(Deserialize, PartialEq, Eq, Hash)]
struct BorrowedCowIfPossible<'a>(#[serde(borrow)] Cow<'a, str>);
240impl std::borrow::Borrow<str> for BorrowedCowIfPossible<'_> {
241 fn borrow(&self) -> &str {
242 &self.0
243 }
244}
245
246fn is_subset(reference: &HashSet<String>, given: &HashSet<BorrowedCowIfPossible<'_>>) -> bool {
247 if reference.len() < given.len() {
249 reference.iter().any(|a| given.contains(&**a))
250 } else {
251 given.iter().any(|a| reference.contains(&*a.0))
252 }
253}
254
255pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> {
256 for required_claim in &options.required_spec_claims {
257 let present = match required_claim.as_str() {
258 "exp" => matches!(claims.exp, TryParse::Parsed(_)),
259 "sub" => matches!(claims.sub, TryParse::Parsed(_)),
260 "iss" => matches!(claims.iss, TryParse::Parsed(_)),
261 "aud" => matches!(claims.aud, TryParse::Parsed(_)),
262 "nbf" => matches!(claims.nbf, TryParse::Parsed(_)),
263 _ => continue,
264 };
265
266 if !present {
267 return Err(new_error(ErrorKind::MissingRequiredClaim(required_claim.clone())));
268 }
269 }
270
271 if options.validate_exp || options.validate_nbf {
272 let now = get_current_timestamp();
273
274 if matches!(claims.exp, TryParse::Parsed(exp) if options.validate_exp
275 && exp - options.reject_tokens_expiring_in_less_than < now - options.leeway )
276 {
277 return Err(new_error(ErrorKind::ExpiredSignature));
278 }
279
280 if matches!(claims.nbf, TryParse::Parsed(nbf) if options.validate_nbf && nbf > now + options.leeway)
281 {
282 return Err(new_error(ErrorKind::ImmatureSignature));
283 }
284 }
285
286 if let (TryParse::Parsed(sub), Some(correct_sub)) = (claims.sub, options.sub.as_deref()) {
287 if sub != correct_sub {
288 return Err(new_error(ErrorKind::InvalidSubject));
289 }
290 }
291
292 match (claims.iss, options.iss.as_ref()) {
293 (TryParse::Parsed(Issuer::Single(iss)), Some(correct_iss)) => {
294 if !correct_iss.contains(&*iss) {
295 return Err(new_error(ErrorKind::InvalidIssuer));
296 }
297 }
298 (TryParse::Parsed(Issuer::Multiple(iss)), Some(correct_iss)) => {
299 if !is_subset(correct_iss, &iss) {
300 return Err(new_error(ErrorKind::InvalidIssuer));
301 }
302 }
303 _ => {}
304 }
305
306 if !options.validate_aud {
307 return Ok(());
308 }
309 match (claims.aud, options.aud.as_ref()) {
310 (TryParse::Parsed(_), None) => {
316 return Err(new_error(ErrorKind::InvalidAudience));
317 }
318 (TryParse::Parsed(Audience::Single(aud)), Some(correct_aud)) => {
319 if !correct_aud.contains(&*aud) {
320 return Err(new_error(ErrorKind::InvalidAudience));
321 }
322 }
323 (TryParse::Parsed(Audience::Multiple(aud)), Some(correct_aud)) => {
324 if !is_subset(correct_aud, &aud) {
325 return Err(new_error(ErrorKind::InvalidAudience));
326 }
327 }
328 _ => {}
329 }
330
331 Ok(())
332}
333
334fn numeric_type<'de, D>(deserializer: D) -> std::result::Result<TryParse<u64>, D::Error>
335where
336 D: Deserializer<'de>,
337{
338 struct NumericType(PhantomData<fn() -> TryParse<u64>>);
339
340 impl<'de> Visitor<'de> for NumericType {
341 type Value = TryParse<u64>;
342
343 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
344 formatter.write_str("A NumericType that can be reasonably coerced into a u64")
345 }
346
347 fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
348 where
349 E: de::Error,
350 {
351 if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) {
352 Ok(TryParse::Parsed(value.round() as u64))
353 } else {
354 Err(serde::de::Error::custom("NumericType must be representable as a u64"))
355 }
356 }
357
358 fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
359 where
360 E: de::Error,
361 {
362 Ok(TryParse::Parsed(value))
363 }
364 }
365
366 match deserializer.deserialize_any(NumericType(PhantomData)) {
367 Ok(ok) => Ok(ok),
368 Err(_) => Ok(TryParse::FailedToParse),
369 }
370}
371
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{get_current_timestamp, validate, ClaimsForValidation, Validation};

    use crate::errors::ErrorKind;
    use crate::Algorithm;
    use std::collections::HashSet;

    /// Deserializes a JSON value into the borrowed claims view consumed by `validate`.
    fn deserialize_claims(claims: &serde_json::Value) -> ClaimsForValidation {
        serde::Deserialize::deserialize(claims).unwrap()
    }

    #[test]
    fn exp_in_future_ok() {
        let claims = json!({ "exp": get_current_timestamp() + 10000 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_ok());
    }

    #[test]
    fn exp_in_future_but_in_rejection_period_fails() {
        let claims = json!({ "exp": get_current_timestamp() + 500 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.leeway = 0;
        validation.reject_tokens_expiring_in_less_than = 501;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());
    }

    #[test]
    fn exp_float_in_future_ok() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) + 10000.123 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_ok());
    }

    #[test]
    fn exp_float_in_future_but_in_rejection_period_fails() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) + 500.123 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.leeway = 0;
        validation.reject_tokens_expiring_in_less_than = 501;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());
    }

    #[test]
    fn exp_in_past_fails() {
        let claims = json!({ "exp": get_current_timestamp() - 100000 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ExpiredSignature => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn exp_float_in_past_fails() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ExpiredSignature => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn exp_in_past_but_in_leeway_ok() {
        let claims = json!({ "exp": get_current_timestamp() - 500 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.leeway = 1000 * 60;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn validate_required_fields_are_present() {
        for spec_claim in ["exp", "nbf", "aud", "iss", "sub"] {
            let claims = json!({});
            let mut validation = Validation::new(Algorithm::HS256);
            validation.set_required_spec_claims(&[spec_claim]);
            let res = validate(deserialize_claims(&claims), &validation).unwrap_err();
            assert_eq!(res.kind(), &ErrorKind::MissingRequiredClaim(spec_claim.to_owned()));
        }
    }

    #[test]
    fn exp_validated_but_not_required_ok() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn exp_validated_but_not_required_fails() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());
    }

    #[test]
    fn exp_required_but_not_validated_ok() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["exp"]);
        validation.validate_exp = false;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn exp_required_but_not_validated_fails() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["exp"]);
        validation.validate_exp = false;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());
    }

    #[test]
    fn nbf_in_past_ok() {
        let claims = json!({ "nbf": get_current_timestamp() - 10000 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_float_in_past_ok() {
        let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_in_future_fails() {
        let claims = json!({ "nbf": get_current_timestamp() + 100000 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ImmatureSignature => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn nbf_in_future_but_in_leeway_ok() {
        let claims = json!({ "nbf": get_current_timestamp() + 500 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        validation.leeway = 1000 * 60;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_string_ok() {
        // A plain string exercises the `Issuer::Single` path (arrays are covered below).
        let claims = json!({"iss": "Keats"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_array_of_string_ok() {
        let claims = json!({"iss": ["UserA", "UserB"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["UserA", "UserB"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_array_with_one_matching_entry_ok() {
        // One overlapping entry is enough for an array-valued `iss`.
        let claims = json!({"iss": ["Other", "Keats"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_not_matching_fails() {
        let claims = json!({"iss": "Hacked"});

        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidIssuer => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn iss_missing_fails() {
        let claims = json!({});

        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["iss"]);
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "iss"),
            _ => unreachable!(),
        };
    }

    #[test]
    fn sub_ok() {
        let claims = json!({"sub": "Keats"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.sub = Some("Keats".to_owned());
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn sub_not_matching_fails() {
        let claims = json!({"sub": "Hacked"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.sub = Some("Keats".to_owned());
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidSubject => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn sub_missing_fails() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.set_required_spec_claims(&["sub"]);
        validation.sub = Some("Keats".to_owned());
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "sub"),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_string_ok() {
        let claims = json!({"aud": "Everyone"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["Everyone"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_array_of_string_ok() {
        let claims = json!({"aud": ["UserA", "UserB"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_array_with_one_matching_entry_ok() {
        // One overlapping entry is enough for an array-valued `aud`.
        let claims = json!({"aud": ["Someone", "Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["Everyone"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_type_mismatch_fails() {
        let claims = json!({"aud": ["Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_correct_type_not_matching_fails() {
        let claims = json!({"aud": ["Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["None"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_none_fails() {
        let claims = json!({"aud": ["Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.aud = None;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_validation_skipped() {
        let claims = json!({"aud": ["Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.validate_aud = false;
        validation.required_spec_claims = HashSet::new();
        validation.aud = None;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_missing_fails() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.set_required_spec_claims(&["aud"]);
        validation.set_audience(&["None"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "aud"),
            _ => unreachable!(),
        };
    }

    #[test]
    fn does_validation_in_right_order() {
        let claims = json!({ "exp": get_current_timestamp() + 10000 });

        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.leeway = 5;
        validation.set_issuer(&["iss no check"]);
        validation.set_audience(&["iss no check"]);

        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "iss"),
            t => panic!("{:?}", t),
        };
    }

    #[test]
    fn aud_use_validation_struct() {
        let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"});

        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]);

        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }
}