Skip to main content

mas_jose/
claims.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref};
16
17use base64ct::{Base64UrlUnpadded, Encoding};
18use mas_iana::jose::JsonWebSignatureAlg;
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use sha2::{Digest, Sha256, Sha384, Sha512};
21use thiserror::Error;
22
23#[derive(Debug, Error)]
24pub enum ClaimError {
25    #[error("missing claim {0:?}")]
26    MissingClaim(&'static str),
27
28    #[error("invalid claim {0:?}")]
29    InvalidClaim(&'static str),
30
31    #[error("could not validate claim {claim:?}")]
32    ValidationError {
33        claim: &'static str,
34        #[source]
35        source: Box<dyn std::error::Error + Send + Sync + 'static>,
36    },
37}
38
/// A check run against a claim value once it has been deserialized.
pub trait Validator<T> {
    /// Error produced when a value is rejected.
    type Error;

    /// Checks `value`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when `value` does not pass the check.
    fn validate(&self, value: &T) -> Result<(), Self::Error>;
}

/// The unit validator accepts every value; its error type is uninhabited.
impl<T> Validator<T> for () {
    type Error = Infallible;

    fn validate(&self, _: &T) -> Result<(), Infallible> {
        Ok(())
    }
}
58
59pub struct Claim<T, V = ()> {
60    claim: &'static str,
61    t: PhantomData<T>,
62    v: PhantomData<V>,
63}
64
65impl<T, V> Claim<T, V>
66where
67    V: Validator<T>,
68{
69    #[must_use]
70    pub const fn new(claim: &'static str) -> Self {
71        Self {
72            claim,
73            t: PhantomData,
74            v: PhantomData,
75        }
76    }
77
78    /// Insert a claim into the given claims map.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the value failed to serialize.
83    pub fn insert<I>(
84        &self,
85        claims: &mut HashMap<String, serde_json::Value>,
86        value: I,
87    ) -> Result<(), ClaimError>
88    where
89        I: Into<T>,
90        T: Serialize,
91    {
92        let value = value.into();
93        let value: serde_json::Value =
94            serde_json::to_value(&value).map_err(|_| ClaimError::InvalidClaim(self.claim))?;
95        claims.insert(self.claim.to_owned(), value);
96
97        Ok(())
98    }
99
100    /// Extract a claim from the given claims map.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if the value failed to deserialize, if its value is
105    /// invalid or if the claim is missing.
106    pub fn extract_required(
107        &self,
108        claims: &mut HashMap<String, serde_json::Value>,
109    ) -> Result<T, ClaimError>
110    where
111        T: DeserializeOwned,
112        V: Default,
113        V::Error: std::error::Error + Send + Sync + 'static,
114    {
115        let validator = V::default();
116        self.extract_required_with_options(claims, validator)
117    }
118
119    /// Extract a claim from the given claims map, with the given options.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if the value failed to deserialize, if its value is
124    /// invalid or if the claim is missing.
125    pub fn extract_required_with_options<I>(
126        &self,
127        claims: &mut HashMap<String, serde_json::Value>,
128        validator: I,
129    ) -> Result<T, ClaimError>
130    where
131        T: DeserializeOwned,
132        I: Into<V>,
133        V::Error: std::error::Error + Send + Sync + 'static,
134    {
135        let validator: V = validator.into();
136        let claim = claims
137            .remove(self.claim)
138            .ok_or(ClaimError::MissingClaim(self.claim))?;
139
140        let res =
141            serde_json::from_value(claim).map_err(|_| ClaimError::InvalidClaim(self.claim))?;
142        validator
143            .validate(&res)
144            .map_err(|source| ClaimError::ValidationError {
145                claim: self.claim,
146                source: Box::new(source),
147            })?;
148        Ok(res)
149    }
150
151    /// Extract a claim from the given claims map, if it exists.
152    ///
153    /// # Errors
154    ///
155    /// Returns an error if the value failed to deserialize or if its value is
156    /// invalid.
157    pub fn extract_optional(
158        &self,
159        claims: &mut HashMap<String, serde_json::Value>,
160    ) -> Result<Option<T>, ClaimError>
161    where
162        T: DeserializeOwned,
163        V: Default,
164        V::Error: std::error::Error + Send + Sync + 'static,
165    {
166        let validator = V::default();
167        self.extract_optional_with_options(claims, validator)
168    }
169
170    /// Extract a claim from the given claims map, if it exists, with the given
171    /// options.
172    ///
173    /// # Errors
174    ///
175    /// Returns an error if the value failed to deserialize or if its value is
176    /// invalid.
177    pub fn extract_optional_with_options<I>(
178        &self,
179        claims: &mut HashMap<String, serde_json::Value>,
180        validator: I,
181    ) -> Result<Option<T>, ClaimError>
182    where
183        T: DeserializeOwned,
184        I: Into<V>,
185        V::Error: std::error::Error + Send + Sync + 'static,
186    {
187        match self.extract_required_with_options(claims, validator) {
188            Ok(v) => Ok(Some(v)),
189            Err(ClaimError::MissingClaim(_)) => Ok(None),
190            Err(e) => Err(e),
191        }
192    }
193}
194
195#[derive(Debug, Clone)]
196pub struct TimeOptions {
197    when: chrono::DateTime<chrono::Utc>,
198    leeway: chrono::Duration,
199}
200
201impl TimeOptions {
202    #[must_use]
203    pub fn new(when: chrono::DateTime<chrono::Utc>) -> Self {
204        Self {
205            when,
206            leeway: chrono::Duration::microseconds(5 * 60 * 1000 * 1000),
207        }
208    }
209
210    #[must_use]
211    pub fn leeway(mut self, leeway: chrono::Duration) -> Self {
212        self.leeway = leeway;
213        self
214    }
215}
216
217#[derive(Debug, Clone, Copy, Error)]
218#[error("Current time is too far away")]
219pub struct TimeTooFarError;
220
221#[derive(Debug, Clone)]
222pub struct TimeNotAfter(TimeOptions);
223
224impl Validator<Timestamp> for TimeNotAfter {
225    type Error = TimeTooFarError;
226    fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
227        if self.0.when <= value.0 + self.0.leeway {
228            Ok(())
229        } else {
230            Err(TimeTooFarError)
231        }
232    }
233}
234
235impl From<TimeOptions> for TimeNotAfter {
236    fn from(opt: TimeOptions) -> Self {
237        Self(opt)
238    }
239}
240
241impl From<&TimeOptions> for TimeNotAfter {
242    fn from(opt: &TimeOptions) -> Self {
243        opt.clone().into()
244    }
245}
246
247#[derive(Debug, Clone)]
248pub struct TimeNotBefore(TimeOptions);
249
250impl Validator<Timestamp> for TimeNotBefore {
251    type Error = TimeTooFarError;
252
253    fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
254        if self.0.when >= value.0 - self.0.leeway {
255            Ok(())
256        } else {
257            Err(TimeTooFarError)
258        }
259    }
260}
261
262impl From<TimeOptions> for TimeNotBefore {
263    fn from(opt: TimeOptions) -> Self {
264        Self(opt)
265    }
266}
267
268impl From<&TimeOptions> for TimeNotBefore {
269    fn from(opt: &TimeOptions) -> Self {
270        opt.clone().into()
271    }
272}
273
274/// Hash the given token with the given algorithm for an ID Token claim.
275///
276/// According to the [OpenID Connect Core 1.0 specification].
277///
278/// # Errors
279///
280/// Returns an error if the algorithm is not supported.
281///
282/// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
283pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result<String, TokenHashError> {
284    let bits = match alg {
285        JsonWebSignatureAlg::Hs256
286        | JsonWebSignatureAlg::Rs256
287        | JsonWebSignatureAlg::Es256
288        | JsonWebSignatureAlg::Ps256
289        | JsonWebSignatureAlg::Es256K => {
290            let mut hasher = Sha256::new();
291            hasher.update(token);
292            let hash: [u8; 32] = hasher.finalize().into();
293            // Left-most half
294            hash[..16].to_owned()
295        }
296        JsonWebSignatureAlg::Hs384
297        | JsonWebSignatureAlg::Rs384
298        | JsonWebSignatureAlg::Es384
299        | JsonWebSignatureAlg::Ps384 => {
300            let mut hasher = Sha384::new();
301            hasher.update(token);
302            let hash: [u8; 48] = hasher.finalize().into();
303            // Left-most half
304            hash[..24].to_owned()
305        }
306        JsonWebSignatureAlg::Hs512
307        | JsonWebSignatureAlg::Rs512
308        | JsonWebSignatureAlg::Es512
309        | JsonWebSignatureAlg::Ps512 => {
310            let mut hasher = Sha512::new();
311            hasher.update(token);
312            let hash: [u8; 64] = hasher.finalize().into();
313            // Left-most half
314            hash[..32].to_owned()
315        }
316        _ => return Err(TokenHashError::UnsupportedAlgorithm),
317    };
318
319    Ok(Base64UrlUnpadded::encode_string(&bits))
320}
321
322#[derive(Debug, Clone, Copy, Error)]
323pub enum TokenHashError {
324    #[error("Hashes don't match")]
325    HashMismatch,
326
327    #[error("Unsupported algorithm for hashing")]
328    UnsupportedAlgorithm,
329}
330
331#[derive(Debug, Clone)]
332pub struct TokenHash<'a> {
333    alg: &'a JsonWebSignatureAlg,
334    token: &'a str,
335}
336
337impl<'a> TokenHash<'a> {
338    /// Creates a new `TokenHash` validator for the given algorithm and token.
339    #[must_use]
340    pub fn new(alg: &'a JsonWebSignatureAlg, token: &'a str) -> Self {
341        Self { alg, token }
342    }
343}
344
345impl<'a> Validator<String> for TokenHash<'a> {
346    type Error = TokenHashError;
347    fn validate(&self, value: &String) -> Result<(), Self::Error> {
348        if hash_token(self.alg, self.token)? == *value {
349            Ok(())
350        } else {
351            Err(TokenHashError::HashMismatch)
352        }
353    }
354}
355
356#[derive(Debug, Clone, Copy, Error)]
357#[error("Values don't match")]
358pub struct EqualityError;
359
360#[derive(Debug, Clone)]
361pub struct Equality<'a, T: ?Sized> {
362    value: &'a T,
363}
364
365impl<'a, T: ?Sized> Equality<'a, T> {
366    /// Creates a new `Equality` validator for the given value.
367    #[must_use]
368    pub fn new(value: &'a T) -> Self {
369        Self { value }
370    }
371}
372
373impl<'a, T1, T2> Validator<T1> for Equality<'a, T2>
374where
375    T2: PartialEq<T1> + ?Sized,
376{
377    type Error = EqualityError;
378    fn validate(&self, value: &T1) -> Result<(), Self::Error> {
379        if *self.value == *value {
380            Ok(())
381        } else {
382            Err(EqualityError)
383        }
384    }
385}
386
387impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> {
388    fn from(value: &'a T) -> Self {
389        Self::new(value)
390    }
391}
392
393#[derive(Debug, Clone)]
394pub struct Contains<'a, T> {
395    value: &'a T,
396}
397
398impl<'a, T> Contains<'a, T> {
399    /// Creates a new `Contains` validator for the given value.
400    #[must_use]
401    pub fn new(value: &'a T) -> Self {
402        Self { value }
403    }
404}
405
406#[derive(Debug, Clone, Copy, Error)]
407#[error("OneOrMany doesn't contain value")]
408pub struct ContainsError;
409
410impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T>
411where
412    T: PartialEq,
413{
414    type Error = ContainsError;
415    fn validate(&self, value: &OneOrMany<T>) -> Result<(), Self::Error> {
416        if value.contains(self.value) {
417            Ok(())
418        } else {
419            Err(ContainsError)
420        }
421    }
422}
423
424impl<'a, T> From<&'a T> for Contains<'a, T> {
425    fn from(value: &'a T) -> Self {
426        Self::new(value)
427    }
428}
429
430#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
431#[serde(transparent)]
432pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>);
433
434impl Deref for Timestamp {
435    type Target = chrono::DateTime<chrono::Utc>;
436
437    fn deref(&self) -> &Self::Target {
438        &self.0
439    }
440}
441
442impl From<chrono::DateTime<chrono::Utc>> for Timestamp {
443    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
444        Timestamp(value)
445    }
446}
447
448#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
449#[serde(
450    transparent,
451    bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>")
452)]
453pub struct OneOrMany<T>(
454    // serde_as seems to not work properly with #[serde(transparent)]
455    // We have use plain old #[serde(with = ...)] with serde_with's utilities, which is a bit
456    // verbose but works
457    #[serde(
458        with = "serde_with::As::<serde_with::OneOrMany<serde_with::Same, serde_with::formats::PreferOne>>"
459    )]
460    Vec<T>,
461);
462
463impl<T> Deref for OneOrMany<T> {
464    type Target = Vec<T>;
465
466    fn deref(&self) -> &Self::Target {
467        &self.0
468    }
469}
470
471impl<T> From<Vec<T>> for OneOrMany<T> {
472    fn from(value: Vec<T>) -> Self {
473        Self(value)
474    }
475}
476
477impl<T> From<T> for OneOrMany<T> {
478    fn from(value: T) -> Self {
479        Self(vec![value])
480    }
481}
482
483/// Claims defined in RFC7519 sec. 4.1
484/// <https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1>
485mod rfc7519 {
486    use super::{Claim, Contains, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp};
487
488    pub const ISS: Claim<String, Equality<str>> = Claim::new("iss");
489    pub const SUB: Claim<String> = Claim::new("sub");
490    pub const AUD: Claim<OneOrMany<String>, Contains<String>> = Claim::new("aud");
491    pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf");
492    pub const EXP: Claim<Timestamp, TimeNotAfter> = Claim::new("exp");
493    pub const IAT: Claim<Timestamp, TimeNotBefore> = Claim::new("iat");
494    pub const JTI: Claim<String> = Claim::new("jti");
495}
496
497/// Claims defined in OIDC.Core sec. 2 and sec. 5.1
498/// <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>
499/// <https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims>
500mod oidc_core {
501    use url::Url;
502
503    use super::{Claim, Equality, Timestamp, TokenHash};
504
505    pub const AUTH_TIME: Claim<Timestamp> = Claim::new("auth_time");
506    pub const NONCE: Claim<String, Equality<str>> = Claim::new("nonce");
507    pub const AT_HASH: Claim<String, TokenHash> = Claim::new("at_hash");
508    pub const C_HASH: Claim<String, TokenHash> = Claim::new("c_hash");
509
510    pub const NAME: Claim<String> = Claim::new("name");
511    pub const GIVEN_NAME: Claim<String> = Claim::new("given_name");
512    pub const FAMILY_NAME: Claim<String> = Claim::new("family_name");
513    pub const MIDDLE_NAME: Claim<String> = Claim::new("middle_name");
514    pub const NICKNAME: Claim<String> = Claim::new("nickname");
515    pub const PREFERRED_USERNAME: Claim<String> = Claim::new("preferred_username");
516    pub const PROFILE: Claim<Url> = Claim::new("profile");
517    pub const PICTURE: Claim<Url> = Claim::new("picture");
518    pub const WEBSITE: Claim<Url> = Claim::new("website");
519    // TODO: email type?
520    pub const EMAIL: Claim<String> = Claim::new("email");
521    pub const EMAIL_VERIFIED: Claim<bool> = Claim::new("email_verified");
522    pub const GENDER: Claim<String> = Claim::new("gender");
523    // TODO: date type
524    pub const BIRTHDATE: Claim<String> = Claim::new("birthdate");
525    // TODO: timezone type
526    pub const ZONEINFO: Claim<String> = Claim::new("zoneinfo");
527    // TODO: locale type
528    pub const LOCALE: Claim<String> = Claim::new("locale");
529    // TODO: phone number type
530    pub const PHONE_NUMBER: Claim<String> = Claim::new("phone_number");
531    pub const PHONE_NUMBER_VERIFIED: Claim<bool> = Claim::new("phone_number_verified");
532    // TODO: pub const ADDRESS: Claim<Timestamp> = Claim::new("address");
533    pub const UPDATED_AT: Claim<Timestamp> = Claim::new("updated_at");
534}
535
536pub use self::{oidc_core::*, rfc7519::*};
537
// Unit tests for claim (de)serialization, extraction and validation.
#[cfg(test)]
mod tests {
    use chrono::TimeZone;

    use super::*;

    // `Timestamp` round-trips through JSON as integer seconds since the epoch.
    #[test]
    fn timestamp_serde() {
        let datetime = Timestamp(
            chrono::Utc
                .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
                .unwrap(),
        );
        let timestamp = serde_json::Value::Number(1_516_239_022.into());

        assert_eq!(datetime, serde_json::from_value(timestamp.clone()).unwrap());
        assert_eq!(timestamp, serde_json::to_value(&datetime).unwrap());
    }

    // `OneOrMany` accepts both a bare value and an array, and serializes a
    // single element as a bare value.
    #[test]
    fn one_or_many_serde() {
        let one = OneOrMany(vec!["one".to_owned()]);
        let many = OneOrMany(vec!["one".to_owned(), "two".to_owned()]);

        assert_eq!(
            one,
            serde_json::from_value(serde_json::json!("one")).unwrap()
        );
        assert_eq!(
            one,
            serde_json::from_value(serde_json::json!(["one"])).unwrap()
        );
        assert_eq!(
            many,
            serde_json::from_value(serde_json::json!(["one", "two"])).unwrap()
        );
        assert_eq!(
            serde_json::to_value(&one).unwrap(),
            serde_json::json!("one")
        );
        assert_eq!(
            serde_json::to_value(&many).unwrap(),
            serde_json::json!(["one", "two"])
        );
    }

    // Extracting every RFC 7519 claim returns the expected values and leaves
    // the map empty, since extraction removes each claim.
    #[test]
    fn extract_claims() {
        let now = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();
        let expiration = now + chrono::Duration::microseconds(5 * 60 * 1000 * 1000);
        let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());

        let claims = serde_json::json!({
            "iss": "https://foo.com",
            "sub": "johndoe",
            "aud": ["abcd-efgh"],
            "iat": 1_516_239_022,
            "nbf": 1_516_239_022,
            "exp": 1_516_239_322,
            "jti": "1122-3344-5566-7788",
        });
        let mut claims = serde_json::from_value(claims).unwrap();

        let iss = ISS
            .extract_required_with_options(&mut claims, "https://foo.com")
            .unwrap();
        let sub = SUB.extract_optional(&mut claims).unwrap();
        let aud = AUD
            .extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned())
            .unwrap();
        let nbf = NBF
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let exp = EXP
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let iat = IAT
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let jti = JTI.extract_optional(&mut claims).unwrap();

        assert_eq!(iss, "https://foo.com".to_owned());
        assert_eq!(sub, Some("johndoe".to_owned()));
        assert_eq!(aud.as_deref(), Some(&vec!["abcd-efgh".to_owned()]));
        assert_eq!(iat.as_deref(), Some(&now));
        assert_eq!(nbf.as_deref(), Some(&now));
        assert_eq!(exp.as_deref(), Some(&expiration));
        assert_eq!(jti, Some("1122-3344-5566-7788".to_owned()));

        assert!(claims.is_empty());
    }

    // Checks iat/nbf/exp validation as the reference time moves around the
    // claimed window, with and without leeway.
    #[test]
    fn time_validation() {
        let now = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();

        let claims = serde_json::json!({
            "iat": 1_516_239_022,
            "nbf": 1_516_239_022,
            "exp": 1_516_239_322,
        });
        let claims: HashMap<String, serde_json::Value> = serde_json::from_value(claims).unwrap();

        // Everything should be fine at this point, the claims iat & nbf == now
        {
            let mut claims = claims.clone();

            // so no leeway should be fine as well here
            let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }

        // Let's go back in time a bit
        let now = now - chrono::Duration::microseconds(60 * 1000 * 1000);

        {
            // There is now a time variance between the two parties...
            let mut claims = claims.clone();

            // but no time variance is allowed. "iat" and "nbf" validation will fail
            let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
            assert!(matches!(
                IAT.extract_required_with_options(&mut claims, &time_options),
                Err(ClaimError::ValidationError { claim: "iat", .. }),
            ));
            assert!(matches!(
                NBF.extract_required_with_options(&mut claims, &time_options),
                Err(ClaimError::ValidationError { claim: "nbf", .. }),
            ));
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }

        {
            // This time, there is a two minute leeway, they all should be fine
            let mut claims = claims.clone();

            // the one minute variance is within the leeway, so nothing fails
            let time_options =
                TimeOptions::new(now).leeway(chrono::Duration::microseconds(2 * 60 * 1000 * 1000));
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }

        // Let's wait some time so it expires
        let now = now + chrono::Duration::microseconds((1 + 6) * 60 * 1000 * 1000);

        {
            // At this point, the claims expired one minute ago
            let mut claims = claims.clone();

            // but no time variance is allowed. "exp" validation will fail
            let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(matches!(
                EXP.extract_required_with_options(&mut claims, &time_options),
                Err(ClaimError::ValidationError { claim: "exp", .. }),
            ));
        }

        {
            let mut claims = claims;

            // Same, but with a 2 minutes leeway should be fine then
            let time_options =
                TimeOptions::new(now).leeway(chrono::Duration::try_minutes(2).unwrap());
            assert!(IAT
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(NBF
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
            assert!(EXP
                .extract_required_with_options(&mut claims, &time_options)
                .is_ok());
        }
    }

    // Claims whose JSON type doesn't match the expected Rust type are
    // reported as `InvalidClaim`.
    #[test]
    fn invalid_claims() {
        let now = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();
        let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());

        let claims = serde_json::json!({
            "iss": 123,
            "sub": 456,
            "aud": 789,
            "iat": "123",
            "nbf": "456",
            "exp": "789",
            "jti": 123,
        });
        let mut claims = serde_json::from_value(claims).unwrap();

        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://foo.com"),
            Err(ClaimError::InvalidClaim("iss"))
        ));
        assert!(matches!(
            SUB.extract_required(&mut claims),
            Err(ClaimError::InvalidClaim("sub"))
        ));
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Err(ClaimError::InvalidClaim("aud"))
        ));
        assert!(matches!(
            NBF.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("nbf"))
        ));
        assert!(matches!(
            EXP.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("exp"))
        ));
        assert!(matches!(
            IAT.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("iat"))
        ));
        assert!(matches!(
            JTI.extract_required(&mut claims),
            Err(ClaimError::InvalidClaim("jti"))
        ));
    }

    // Absent claims are `MissingClaim` for required extraction and `Ok(None)`
    // for optional extraction.
    #[test]
    fn missing_claims() {
        // Empty claim set
        let mut claims = HashMap::new();

        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://foo.com"),
            Err(ClaimError::MissingClaim("iss"))
        ));
        assert!(matches!(
            SUB.extract_required(&mut claims),
            Err(ClaimError::MissingClaim("sub"))
        ));
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Err(ClaimError::MissingClaim("aud"))
        ));

        assert!(matches!(
            ISS.extract_optional_with_options(&mut claims, "https://foo.com"),
            Ok(None)
        ));
        assert!(matches!(SUB.extract_optional(&mut claims), Ok(None)));
        assert!(matches!(
            AUD.extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Ok(None)
        ));
    }

    // `Equality` accepts the matching issuer and rejects a different one.
    #[test]
    fn string_eq_validation() {
        let claims = serde_json::json!({
            "iss": "https://foo.com",
        });
        let mut claims: HashMap<String, serde_json::Value> =
            serde_json::from_value(claims).unwrap();

        ISS.extract_required_with_options(&mut claims.clone(), "https://foo.com")
            .unwrap();

        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://bar.com"),
            Err(ClaimError::ValidationError { claim: "iss", .. }),
        ));
    }

    // `Contains` accepts an audience that includes the value and rejects one
    // that does not.
    #[test]
    fn contains_validation() {
        let claims = serde_json::json!({
            "aud": "abcd-efgh",
        });
        let mut claims: HashMap<String, serde_json::Value> =
            serde_json::from_value(claims).unwrap();

        AUD.extract_required_with_options(&mut claims.clone(), &"abcd-efgh".to_owned())
            .unwrap();

        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"wxyz".to_owned()),
            Err(ClaimError::ValidationError { claim: "aud", .. }),
        ));
    }
}