keygate_jwt/
claims.rs

1use std::collections::HashSet;
2use std::convert::TryInto;
3
4use coarsetime::{Clock, Duration, UnixTimeStamp};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6
7use crate::common::VerificationOptions;
8use crate::ensure;
9use crate::error::*;
10use crate::serde_additions;
11
/// Default time tolerance, in seconds (15 minutes).
///
/// NOTE(review): presumably the default for `VerificationOptions::time_tolerance`,
/// which widens the `iat`/`nbf`/`exp` windows in `JWTClaims::validate` — confirm.
pub const DEFAULT_TIME_TOLERANCE_SECS: u64 = 15 * 60;
13
14/// Type representing the fact that no application-defined claims is necessary.
15#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
16pub struct NoCustomClaims {}
17
/// Audience (`aud`) claim.
///
/// Applications disagree on whether the audience is a set of identifiers or a
/// single string, so both shapes are representable.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Audiences {
    /// Zero or more audiences, held as a set.
    AsSet(HashSet<String>),
    /// A single audience string.
    AsString(String),
}
25
26impl Audiences {
27    /// Return `true` if the audiences are represented as a set.
28    pub fn is_set(&self) -> bool {
29        matches!(self, Audiences::AsSet(_))
30    }
31
32    /// Return `true` if the audiences are represented as a string.
33    pub fn is_string(&self) -> bool {
34        matches!(self, Audiences::AsString(_))
35    }
36
37    /// Return `true` if the audiences include any of the `allowed_audiences`
38    /// entries
39    pub fn contains(&self, allowed_audiences: &HashSet<String>) -> bool {
40        match self {
41            Audiences::AsString(audience) => allowed_audiences.contains(audience),
42            Audiences::AsSet(audiences) => {
43                audiences.intersection(allowed_audiences).next().is_some()
44            }
45        }
46    }
47
48    /// Get the audiences as a set
49    pub fn into_set(self) -> HashSet<String> {
50        match self {
51            Audiences::AsSet(audiences_set) => audiences_set,
52            Audiences::AsString(audiences) => {
53                let mut audiences_set = HashSet::new();
54                if !audiences.is_empty() {
55                    audiences_set.insert(audiences);
56                }
57                audiences_set
58            }
59        }
60    }
61
62    /// Get the audiences as a string.
63    /// If it was originally serialized as a set, it can be only converted to a
64    /// string if it contains at most one element.
65    pub fn into_string(self) -> Result<String, JWTError> {
66        match self {
67            Audiences::AsString(audiences_str) => Ok(audiences_str),
68            Audiences::AsSet(audiences) => {
69                if audiences.len() > 1 {
70                    return Err(JWTError::TooManyAudiences);
71                }
72                Ok(audiences
73                    .iter()
74                    .next()
75                    .map(|x| x.to_string())
76                    .unwrap_or_default())
77            }
78        }
79    }
80}
81
82impl TryInto<String> for Audiences {
83    type Error = JWTError;
84
85    fn try_into(self) -> Result<String, JWTError> {
86        self.into_string()
87    }
88}
89
90impl From<Audiences> for HashSet<String> {
91    fn from(audiences: Audiences) -> HashSet<String> {
92        audiences.into_set()
93    }
94}
95
96impl<T: ToString> From<T> for Audiences {
97    fn from(audience: T) -> Self {
98        Audiences::AsString(audience.to_string())
99    }
100}
101
102/// A set of JWT claims.
103///
104/// The `CustomClaims` parameter can be set to `NoCustomClaims` if only standard
105/// claims are used, or to a user-defined type that must be `serde`-serializable
106/// if custom claims are required.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct JWTClaims<CustomClaims> {
109    /// Time the claims were created at
110    #[serde(
111        rename = "iat",
112        default,
113        skip_serializing_if = "Option::is_none",
114        with = "self::serde_additions::unix_timestamp"
115    )]
116    pub issued_at: Option<UnixTimeStamp>,
117
118    /// Time the claims expire at
119    #[serde(
120        rename = "exp",
121        default,
122        skip_serializing_if = "Option::is_none",
123        with = "self::serde_additions::unix_timestamp"
124    )]
125    pub expires_at: Option<UnixTimeStamp>,
126
127    /// Time the claims will be invalid until
128    #[serde(
129        rename = "nbf",
130        default,
131        skip_serializing_if = "Option::is_none",
132        with = "self::serde_additions::unix_timestamp"
133    )]
134    pub invalid_before: Option<UnixTimeStamp>,
135
136    /// Issuer - This can be set to anything application-specific
137    #[serde(rename = "iss", default, skip_serializing_if = "Option::is_none")]
138    pub issuer: Option<String>,
139
140    /// Subject - This can be set to anything application-specific
141    #[serde(rename = "sub", default, skip_serializing_if = "Option::is_none")]
142    pub subject: Option<String>,
143
144    /// Audience
145    #[serde(
146        rename = "aud",
147        default,
148        skip_serializing_if = "Option::is_none",
149        with = "self::serde_additions::audiences"
150    )]
151    pub audiences: Option<Audiences>,
152
153    /// JWT identifier
154    ///
155    /// That property was originally designed to avoid replay attacks, but
156    /// keeping all previously sent JWT token IDs is unrealistic.
157    ///
158    /// Replay attacks are better addressed by keeping only the timestamp of the
159    /// last valid token for a user, and rejecting anything older in future
160    /// tokens.
161    #[serde(rename = "jti", default, skip_serializing_if = "Option::is_none")]
162    pub jwt_id: Option<String>,
163
164    /// Nonce
165    #[serde(rename = "nonce", default, skip_serializing_if = "Option::is_none")]
166    pub nonce: Option<String>,
167
168    /// Custom (application-defined) claims
169    #[serde(flatten)]
170    pub custom: CustomClaims,
171}
172
173impl<CustomClaims> JWTClaims<CustomClaims> {
174    pub(crate) fn validate(&self, options: &VerificationOptions) -> Result<(), JWTError> {
175        let now = Clock::now_since_epoch();
176        let time_tolerance = options.time_tolerance.unwrap_or_default();
177
178        if let Some(reject_before) = options.reject_before {
179            if now > reject_before {
180                return Err(JWTError::OldTokenReused);
181            }
182        }
183        if let Some(time_issued) = self.issued_at {
184            ensure!(time_issued <= now + time_tolerance, JWTError::ClockDrift);
185            if let Some(max_validity) = options.max_validity {
186                ensure!(
187                    now <= time_issued || now - time_issued <= max_validity,
188                    JWTError::TokenIsTooOld
189                );
190            }
191        }
192        if !options.accept_future {
193            if let Some(invalid_before) = self.invalid_before {
194                ensure!(
195                    now + time_tolerance >= invalid_before,
196                    JWTError::TokenNotValidYet
197                );
198            }
199        }
200        if let Some(expires_at) = self.expires_at {
201            ensure!(
202                now - time_tolerance <= expires_at,
203                JWTError::TokenHasExpired
204            );
205        }
206        if let Some(allowed_issuers) = &options.allowed_issuers {
207            if let Some(issuer) = &self.issuer {
208                ensure!(
209                    allowed_issuers.contains(issuer),
210                    JWTError::RequiredIssuerMismatch
211                );
212            } else {
213                return Err(JWTError::RequiredIssuerMissing);
214            }
215        }
216        if let Some(required_subject) = &options.required_subject {
217            if let Some(subject) = &self.subject {
218                ensure!(
219                    subject == required_subject,
220                    JWTError::RequiredSubjectMismatch
221                );
222            } else {
223                return Err(JWTError::RequiredSubjectMissing);
224            }
225        }
226        if let Some(required_nonce) = &options.required_nonce {
227            if let Some(nonce) = &self.nonce {
228                ensure!(nonce == required_nonce, JWTError::RequiredNonceMismatch);
229            } else {
230                return Err(JWTError::RequiredNonceMissing);
231            }
232        }
233        if let Some(allowed_audiences) = &options.allowed_audiences {
234            if let Some(audiences) = &self.audiences {
235                ensure!(
236                    audiences.contains(allowed_audiences),
237                    JWTError::RequiredAudienceMismatch
238                );
239            } else {
240                return Err(JWTError::RequiredAudienceMissing);
241            }
242        }
243        Ok(())
244    }
245
246    /// Set the token as not being valid until `unix_timestamp`
247    pub fn invalid_before(mut self, unix_timestamp: UnixTimeStamp) -> Self {
248        self.invalid_before = Some(unix_timestamp);
249        self
250    }
251
252    /// Set the issuer
253    pub fn with_issuer(mut self, issuer: impl ToString) -> Self {
254        self.issuer = Some(issuer.to_string());
255        self
256    }
257
258    /// Set the subject
259    pub fn with_subject(mut self, subject: impl ToString) -> Self {
260        self.subject = Some(subject.to_string());
261        self
262    }
263
264    /// Register one or more audiences (optional recipient identifiers), as a
265    /// set
266    pub fn with_audiences(mut self, audiences: HashSet<impl ToString>) -> Self {
267        self.audiences = Some(Audiences::AsSet(
268            audiences.iter().map(|x| x.to_string()).collect(),
269        ));
270        self
271    }
272
273    /// Set a unique audience (an optional recipient identifier), as a string
274    pub fn with_audience(mut self, audience: impl ToString) -> Self {
275        self.audiences = Some(Audiences::AsString(audience.to_string()));
276        self
277    }
278
279    /// Set the JWT identifier
280    pub fn with_jwt_id(mut self, jwt_id: impl ToString) -> Self {
281        self.jwt_id = Some(jwt_id.to_string());
282        self
283    }
284
285    /// Set the nonce
286    pub fn with_nonce(mut self, nonce: impl ToString) -> Self {
287        self.nonce = Some(nonce.to_string());
288        self
289    }
290}
291
/// Namespace for constructors of fresh `JWTClaims` values.
pub struct Claims;
293
294impl Claims {
295    /// Create a new set of claims, without custom data, expiring in
296    /// `valid_for`.
297    pub fn create(valid_for: Duration) -> JWTClaims<NoCustomClaims> {
298        let now = Clock::now_since_epoch();
299        JWTClaims {
300            issued_at: Some(now),
301            expires_at: Some(now + valid_for),
302            invalid_before: Some(now),
303            audiences: None,
304            issuer: None,
305            jwt_id: None,
306            subject: None,
307            nonce: None,
308            custom: NoCustomClaims {},
309        }
310    }
311
312    /// Create a new set of claims, with custom data, expiring in `valid_for`.
313    pub fn with_custom_claims<CustomClaims: Serialize + DeserializeOwned>(
314        custom_claims: CustomClaims,
315        valid_for: Duration,
316    ) -> JWTClaims<CustomClaims> {
317        let now = Clock::now_since_epoch();
318        JWTClaims {
319            issued_at: Some(now),
320            expires_at: Some(now + valid_for),
321            invalid_before: Some(now),
322            audiences: None,
323            issuer: None,
324            jwt_id: None,
325            subject: None,
326            nonce: None,
327            custom: custom_claims,
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_set_standard_claims() {
        let audiences: HashSet<String> = ["audience1", "audience2"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        let claims = Claims::create(Duration::from_mins(10))
            .with_audiences(audiences.clone())
            .with_issuer("issuer")
            .with_jwt_id("jwt_id")
            .with_nonce("nonce")
            .with_subject("subject");

        assert_eq!(claims.audiences, Some(Audiences::AsSet(audiences)));
        assert_eq!(claims.issuer.as_deref(), Some("issuer"));
        assert_eq!(claims.jwt_id.as_deref(), Some("jwt_id"));
        assert_eq!(claims.nonce.as_deref(), Some("nonce"));
        assert_eq!(claims.subject.as_deref(), Some("subject"));
    }

    #[test]
    fn parse_floating_point_unix_time() {
        // A fractional `exp` is accepted; the fractional part is dropped.
        let json = r#"{"exp":1617757825.8}"#;
        let claims: JWTClaims<()> = serde_json::from_str(json).unwrap();
        let expected = UnixTimeStamp::from_secs(1617757825);
        assert_eq!(claims.expires_at, Some(expected));
    }
}