1use std::collections::HashSet;
2use std::convert::TryInto;
3
4use coarsetime::{Clock, Duration, UnixTimeStamp};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6
7use crate::common::VerificationOptions;
8use crate::ensure;
9use crate::error::*;
10use crate::serde_additions;
11
/// Default clock-skew tolerance, in seconds: fifteen minutes.
pub const DEFAULT_TIME_TOLERANCE_SECS: u64 = 15 * 60;
13
14#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
16pub struct NoCustomClaims {}
17
/// Value of the `aud` claim: either a set of audiences or a single audience string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Audiences {
    /// Several audiences, kept as a set.
    AsSet(HashSet<String>),
    /// A single audience, kept as a plain string.
    AsString(String),
}
25
26impl Audiences {
27 pub fn is_set(&self) -> bool {
29 matches!(self, Audiences::AsSet(_))
30 }
31
32 pub fn is_string(&self) -> bool {
34 matches!(self, Audiences::AsString(_))
35 }
36
37 pub fn contains(&self, allowed_audiences: &HashSet<String>) -> bool {
40 match self {
41 Audiences::AsString(audience) => allowed_audiences.contains(audience),
42 Audiences::AsSet(audiences) => {
43 audiences.intersection(allowed_audiences).next().is_some()
44 }
45 }
46 }
47
48 pub fn into_set(self) -> HashSet<String> {
50 match self {
51 Audiences::AsSet(audiences_set) => audiences_set,
52 Audiences::AsString(audiences) => {
53 let mut audiences_set = HashSet::new();
54 if !audiences.is_empty() {
55 audiences_set.insert(audiences);
56 }
57 audiences_set
58 }
59 }
60 }
61
62 pub fn into_string(self) -> Result<String, JWTError> {
66 match self {
67 Audiences::AsString(audiences_str) => Ok(audiences_str),
68 Audiences::AsSet(audiences) => {
69 if audiences.len() > 1 {
70 return Err(JWTError::TooManyAudiences);
71 }
72 Ok(audiences
73 .iter()
74 .next()
75 .map(|x| x.to_string())
76 .unwrap_or_default())
77 }
78 }
79 }
80}
81
82impl TryInto<String> for Audiences {
83 type Error = JWTError;
84
85 fn try_into(self) -> Result<String, JWTError> {
86 self.into_string()
87 }
88}
89
90impl From<Audiences> for HashSet<String> {
91 fn from(audiences: Audiences) -> HashSet<String> {
92 audiences.into_set()
93 }
94}
95
96impl<T: ToString> From<T> for Audiences {
97 fn from(audience: T) -> Self {
98 Audiences::AsString(audience.to_string())
99 }
100}
101
/// A JWT claim set: the standard claims handled by this module plus
/// application-defined `custom` claims.
///
/// Every standard claim is optional. `None` values are skipped when serializing,
/// and missing keys deserialize to `None` (`default`). Fields are serialized in
/// declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JWTClaims<CustomClaims> {
    /// Time the claims were issued at (`iat`).
    #[serde(
        rename = "iat",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::unix_timestamp"
    )]
    pub issued_at: Option<UnixTimeStamp>,

    /// Expiration time (`exp`).
    #[serde(
        rename = "exp",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::unix_timestamp"
    )]
    pub expires_at: Option<UnixTimeStamp>,

    /// "Not before" time (`nbf`). `validate` checks it unless the verifier sets
    /// `accept_future`.
    #[serde(
        rename = "nbf",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::unix_timestamp"
    )]
    pub invalid_before: Option<UnixTimeStamp>,

    /// Issuer (`iss`).
    #[serde(rename = "iss", default, skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// Subject (`sub`).
    #[serde(rename = "sub", default, skip_serializing_if = "Option::is_none")]
    pub subject: Option<String>,

    /// Audience(s) (`aud`): a single string or a set of strings. Encoding is
    /// delegated to `serde_additions::audiences`.
    #[serde(
        rename = "aud",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::audiences"
    )]
    pub audiences: Option<Audiences>,

    /// JWT identifier (`jti`).
    #[serde(rename = "jti", default, skip_serializing_if = "Option::is_none")]
    pub jwt_id: Option<String>,

    /// Nonce (`nonce`). `validate` compares it with the verifier's `required_nonce`
    /// when one is configured.
    #[serde(rename = "nonce", default, skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,

    /// Application-defined claims, flattened into the same object as the standard
    /// claims above.
    #[serde(flatten)]
    pub custom: CustomClaims,
}
172
173impl<CustomClaims> JWTClaims<CustomClaims> {
174 pub(crate) fn validate(&self, options: &VerificationOptions) -> Result<(), JWTError> {
175 let now = Clock::now_since_epoch();
176 let time_tolerance = options.time_tolerance.unwrap_or_default();
177
178 if let Some(reject_before) = options.reject_before {
179 if now > reject_before {
180 return Err(JWTError::OldTokenReused);
181 }
182 }
183 if let Some(time_issued) = self.issued_at {
184 ensure!(time_issued <= now + time_tolerance, JWTError::ClockDrift);
185 if let Some(max_validity) = options.max_validity {
186 ensure!(
187 now <= time_issued || now - time_issued <= max_validity,
188 JWTError::TokenIsTooOld
189 );
190 }
191 }
192 if !options.accept_future {
193 if let Some(invalid_before) = self.invalid_before {
194 ensure!(
195 now + time_tolerance >= invalid_before,
196 JWTError::TokenNotValidYet
197 );
198 }
199 }
200 if let Some(expires_at) = self.expires_at {
201 ensure!(
202 now - time_tolerance <= expires_at,
203 JWTError::TokenHasExpired
204 );
205 }
206 if let Some(allowed_issuers) = &options.allowed_issuers {
207 if let Some(issuer) = &self.issuer {
208 ensure!(
209 allowed_issuers.contains(issuer),
210 JWTError::RequiredIssuerMismatch
211 );
212 } else {
213 return Err(JWTError::RequiredIssuerMissing);
214 }
215 }
216 if let Some(required_subject) = &options.required_subject {
217 if let Some(subject) = &self.subject {
218 ensure!(
219 subject == required_subject,
220 JWTError::RequiredSubjectMismatch
221 );
222 } else {
223 return Err(JWTError::RequiredSubjectMissing);
224 }
225 }
226 if let Some(required_nonce) = &options.required_nonce {
227 if let Some(nonce) = &self.nonce {
228 ensure!(nonce == required_nonce, JWTError::RequiredNonceMismatch);
229 } else {
230 return Err(JWTError::RequiredNonceMissing);
231 }
232 }
233 if let Some(allowed_audiences) = &options.allowed_audiences {
234 if let Some(audiences) = &self.audiences {
235 ensure!(
236 audiences.contains(allowed_audiences),
237 JWTError::RequiredAudienceMismatch
238 );
239 } else {
240 return Err(JWTError::RequiredAudienceMissing);
241 }
242 }
243 Ok(())
244 }
245
246 pub fn invalid_before(mut self, unix_timestamp: UnixTimeStamp) -> Self {
248 self.invalid_before = Some(unix_timestamp);
249 self
250 }
251
252 pub fn with_issuer(mut self, issuer: impl ToString) -> Self {
254 self.issuer = Some(issuer.to_string());
255 self
256 }
257
258 pub fn with_subject(mut self, subject: impl ToString) -> Self {
260 self.subject = Some(subject.to_string());
261 self
262 }
263
264 pub fn with_audiences(mut self, audiences: HashSet<impl ToString>) -> Self {
267 self.audiences = Some(Audiences::AsSet(
268 audiences.iter().map(|x| x.to_string()).collect(),
269 ));
270 self
271 }
272
273 pub fn with_audience(mut self, audience: impl ToString) -> Self {
275 self.audiences = Some(Audiences::AsString(audience.to_string()));
276 self
277 }
278
279 pub fn with_jwt_id(mut self, jwt_id: impl ToString) -> Self {
281 self.jwt_id = Some(jwt_id.to_string());
282 self
283 }
284
285 pub fn with_nonce(mut self, nonce: impl ToString) -> Self {
287 self.nonce = Some(nonce.to_string());
288 self
289 }
290}
291
292pub struct Claims;
293
294impl Claims {
295 pub fn create(valid_for: Duration) -> JWTClaims<NoCustomClaims> {
298 let now = Clock::now_since_epoch();
299 JWTClaims {
300 issued_at: Some(now),
301 expires_at: Some(now + valid_for),
302 invalid_before: Some(now),
303 audiences: None,
304 issuer: None,
305 jwt_id: None,
306 subject: None,
307 nonce: None,
308 custom: NoCustomClaims {},
309 }
310 }
311
312 pub fn with_custom_claims<CustomClaims: Serialize + DeserializeOwned>(
314 custom_claims: CustomClaims,
315 valid_for: Duration,
316 ) -> JWTClaims<CustomClaims> {
317 let now = Clock::now_since_epoch();
318 JWTClaims {
319 issued_at: Some(now),
320 expires_at: Some(now + valid_for),
321 invalid_before: Some(now),
322 audiences: None,
323 issuer: None,
324 jwt_id: None,
325 subject: None,
326 nonce: None,
327 custom: custom_claims,
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_set_standard_claims() {
        let expected_audiences: HashSet<String> = ["audience1", "audience2"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        let claims = Claims::create(Duration::from_mins(10))
            .with_audiences(expected_audiences.clone())
            .with_issuer("issuer")
            .with_jwt_id("jwt_id")
            .with_nonce("nonce")
            .with_subject("subject");

        assert_eq!(
            claims.audiences,
            Some(Audiences::AsSet(expected_audiences))
        );
        assert_eq!(claims.issuer.as_deref(), Some("issuer"));
        assert_eq!(claims.jwt_id.as_deref(), Some("jwt_id"));
        assert_eq!(claims.nonce.as_deref(), Some("nonce"));
        assert_eq!(claims.subject.as_deref(), Some("subject"));
    }

    #[test]
    fn parse_floating_point_unix_time() {
        // A fractional `exp` is accepted and truncated to whole seconds.
        let json = r#"{"exp":1617757825.8}"#;
        let claims: JWTClaims<()> = serde_json::from_str(json).unwrap();
        assert_eq!(
            claims.expires_at,
            Some(UnixTimeStamp::from_secs(1617757825))
        );
    }
}
364}