jwt_rustcrypto/
validation.rs

1use std::collections::HashSet;
2#[cfg(not(target_arch = "wasm32"))]
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use serde_json::map::Map;
6use serde_json::Value;
7
8use crate::Algorithm;
9use crate::Error;
10use crate::Header;
11
// JavaScript imports, only compiled when targeting `wasm32`.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
extern "C" {
    /// Binding to JavaScript's `console.log`.
    ///
    /// NOTE(review): not referenced anywhere in this module — candidate for removal.
    #[wasm_bindgen(js_namespace = console)]
    fn log(message: &str);

    /// Binding to JavaScript's `Date.now()`, returned as an `f64`.
    #[wasm_bindgen(js_namespace = Date)]
    fn now() -> f64;
}
28
29#[derive(Debug, Clone, PartialEq)]
30pub struct ValidationOptions {
31    /// General leeway (in seconds) applied to all time-related claims like `exp`, `nbf`, and `iat`.
32    pub leeway: u64,
33    /// Validate the expiration time (`exp` claim).
34    pub validate_exp: bool,
35    /// Validate the not-before time (`nbf` claim).
36    pub validate_nbf: bool,
37    /// Set of acceptable audience members.
38    pub audiences: Option<HashSet<String>>,
39    /// Expected issuer.
40    pub issuer: Option<String>,
41    /// Expected subject.
42    pub subject: Option<String>,
43    /// Allowed signing algorithms for the JWT.
44    pub algorithms: HashSet<Algorithm>,
45    /// Required claims.
46    pub required_claims: Option<HashSet<String>>,
47}
48
49impl ValidationOptions {
50    /// Create a new set of `ValidationOptions` with a specific algorithm.
51    pub fn new(alg: Algorithm) -> Self {
52        Self {
53            algorithms: HashSet::from([alg]),
54            ..Self::default()
55        }
56    }
57
58    /// Disable expiration (`exp`) validation.
59    pub fn without_expiry(self) -> Self {
60        Self {
61            validate_exp: false,
62            ..Self::default()
63        }
64    }
65
66    /// Set acceptable audience members as a HashSet of strings.
67    pub fn with_audiences<T: ToString>(self, audiences: &[T]) -> Self {
68        Self {
69            audiences: Some(audiences.iter().map(ToString::to_string).collect()),
70            ..self
71        }
72    }
73
74    /// Set a single audience member as the only acceptable value.
75    pub fn with_audience<T: ToString>(self, audience: T) -> Self {
76        Self {
77            audiences: Some(HashSet::from([audience.to_string()])),
78            ..self
79        }
80    }
81
82    /// Set the issuer claim to validate.
83    pub fn with_issuer<T: ToString>(self, issuer: T) -> Self {
84        Self {
85            issuer: Some(issuer.to_string()),
86            ..self
87        }
88    }
89
90    /// Set the subject claim to validate.
91    pub fn with_subject<T: ToString>(self, subject: T) -> Self {
92        Self {
93            subject: Some(subject.to_string()),
94            ..self
95        }
96    }
97
98    /// Set leeway for time-related claims (`exp`, `nbf`, `iat`).
99    pub fn with_leeway(self, leeway: u64) -> Self {
100        Self { leeway, ..self }
101    }
102
103    /// Add an allowed signing algorithm.
104    pub fn with_algorithm(mut self, alg: Algorithm) -> Self {
105        self.algorithms.insert(alg);
106        self
107    }
108
109    /// Add a required claim.
110    pub fn with_required_claim<T: ToString>(mut self, claim: T) -> Self {
111        if let Some(ref mut required_claims) = self.required_claims {
112            required_claims.insert(claim.to_string());
113        } else {
114            self.required_claims = Some(HashSet::from([claim.to_string()]));
115        }
116        self
117    }
118}
119
120impl Default for ValidationOptions {
121    fn default() -> Self {
122        Self {
123            leeway: 0,
124            validate_exp: true,
125            validate_nbf: false,
126            audiences: None,
127            issuer: None,
128            subject: None,
129            algorithms: HashSet::new(),
130            required_claims: None,
131        }
132    }
133}
134
135/// Validates the header of a JWT using the given `ValidationOptions`.
136pub(crate) fn validate_header(
137    header: &Header,
138    validation_options: &ValidationOptions,
139) -> Result<(), Error> {
140    if !validation_options.algorithms.is_empty()
141        && !validation_options.algorithms.contains(&header.alg)
142    {
143        return Err(Error::InvalidAlgorithm);
144    }
145    Ok(())
146}
147
148/// Validates the claims within a JWT using the given `ValidationOptions`.
149pub(crate) fn validate(
150    claims: &Map<String, Value>,
151    options: &ValidationOptions,
152) -> Result<(), Error> {
153    let now = current_timestamp();
154
155    let validate_time_claim = |claim_value: Option<&Value>,
156                               validate: bool,
157                               validation_predicate: &dyn Fn(u64) -> bool,
158                               validation_error: Error,
159                               missing_claim_error: Error|
160     -> Result<(), Error> {
161        if validate {
162            if let Some(value) = claim_value.and_then(|v| v.as_u64()) {
163                if !validation_predicate(value) {
164                    return Err(validation_error);
165                }
166            } else {
167                return Err(missing_claim_error);
168            }
169        }
170        Ok(())
171    };
172
173    validate_time_claim(
174        claims.get("exp"),
175        options.validate_exp,
176        &|timestamp| now <= timestamp + options.leeway,
177        Error::ExpiredSignature,
178        Error::InvalidClaim("Missing exp claim".to_string()),
179    )?;
180
181    validate_time_claim(
182        claims.get("nbf"),
183        options.validate_nbf,
184        &|timestamp| now >= timestamp - options.leeway,
185        Error::ImmatureSignature,
186        Error::InvalidClaim("Missing nbf claim".to_string()),
187    )?;
188
189    let validate_str_claim = |claim_value: Option<&Value>,
190                              expected_value: &Option<String>,
191                              validation_error: Error|
192     -> Result<(), Error> {
193        if let Some(expected) = expected_value {
194            if let Some(actual) = claim_value.and_then(|v| v.as_str()) {
195                if actual != expected {
196                    return Err(validation_error);
197                }
198            } else {
199                return Err(validation_error);
200            }
201        }
202        Ok(())
203    };
204
205    validate_str_claim(claims.get("iss"), &options.issuer, Error::InvalidIssuer)?;
206    validate_str_claim(claims.get("sub"), &options.subject, Error::InvalidSubject)?;
207
208    let validate_audiences = |aud_claim: Option<&Value>,
209                              expected_audiences: &Option<HashSet<String>>|
210     -> Result<(), Error> {
211        if let Some(expected) = expected_audiences {
212            match aud_claim {
213                Some(Value::String(aud)) => {
214                    if !expected.contains(aud) {
215                        return Err(Error::InvalidAudience);
216                    }
217                }
218                Some(Value::Array(aud_array)) => {
219                    let provided: HashSet<String> = aud_array
220                        .iter()
221                        .filter_map(|val| val.as_str().map(String::from))
222                        .collect();
223                    if provided.is_disjoint(expected) {
224                        return Err(Error::InvalidAudience);
225                    }
226                }
227                _ => return Err(Error::InvalidAudience),
228            }
229        }
230        Ok(())
231    };
232
233    validate_audiences(claims.get("aud"), &options.audiences)?;
234
235    if let Some(ref required_claims) = options.required_claims {
236        for claim in required_claims {
237            if !claims.contains_key(claim) {
238                return Err(Error::InvalidClaim(format!(
239                    "Missing required claim: {}",
240                    claim
241                )));
242            }
243        }
244    }
245
246    Ok(())
247}
248
/// Returns the number of whole seconds elapsed since the UNIX epoch, per the system clock.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the UNIX epoch.
#[cfg(not(target_arch = "wasm32"))]
fn current_timestamp() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("SystemTime before UNIX EPOCH");
    since_epoch.as_secs()
}
257
/// Returns the current time as whole seconds since the UNIX epoch, derived from the
/// JavaScript `Date.now()` binding (which reports milliseconds).
///
/// The final `as u64` conversion truncates toward zero and saturates, so a negative or NaN
/// reading yields `0`.
///
/// NOTE(review): this wasm variant is `pub` while the native variant is private — confirm
/// whether the differing visibility is intended.
#[cfg(target_arch = "wasm32")]
pub fn current_timestamp() -> u64 {
    let epoch_millis: f64 = now();
    let epoch_secs = epoch_millis / 1000.0;
    epoch_secs as u64
}
262
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, to_value};

    /// Returns a claim set whose `exp` is one hour in the future, so the default expiration
    /// check passes and each test can focus on a single other claim.
    fn unexpired_claims() -> Map<String, Value> {
        let mut claims = Map::new();
        claims.insert(
            "exp".to_string(),
            to_value(current_timestamp() + 3600).unwrap(),
        );
        claims
    }

    /// Default options with `nbf` validation switched on (it is off by default).
    fn nbf_options() -> ValidationOptions {
        ValidationOptions {
            validate_nbf: true,
            ..ValidationOptions::default()
        }
    }

    #[test]
    fn test_expiration_validation() {
        let result = validate(&unexpired_claims(), &ValidationOptions::default());
        assert!(result.is_ok(), "unexpected error: {:?}", result);
    }

    #[test]
    fn test_expiration_validation_fail() {
        let mut claims = Map::new();
        claims.insert(
            "exp".to_string(),
            to_value(current_timestamp() - 60).unwrap(),
        );

        let result = validate(&claims, &ValidationOptions::default());
        assert!(matches!(result, Err(Error::ExpiredSignature)));
    }

    #[test]
    fn test_expiration_validation_within_leeway() {
        let mut claims = Map::new();
        claims.insert(
            "exp".to_string(),
            to_value(current_timestamp() - 60).unwrap(),
        );

        let options = ValidationOptions::default().with_leeway(120);
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_expiration_validation_missing_exp() {
        let result = validate(&Map::new(), &ValidationOptions::default());
        assert!(matches!(result, Err(Error::InvalidClaim(_))));
    }

    #[test]
    fn test_expiration_not_checked_when_disabled() {
        let options = ValidationOptions::default().without_expiry();
        assert!(validate(&Map::new(), &options).is_ok());
    }

    #[test]
    fn test_not_before_validation() {
        let mut claims = unexpired_claims();
        claims.insert("nbf".to_string(), to_value(current_timestamp()).unwrap());

        let result = validate(&claims, &nbf_options());
        assert!(result.is_ok(), "unexpected error: {:?}", result);
    }

    #[test]
    fn test_not_before_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert(
            "nbf".to_string(),
            to_value(current_timestamp() + 3600).unwrap(),
        );

        let result = validate(&claims, &nbf_options());
        assert!(matches!(result, Err(Error::ImmatureSignature)));
    }

    #[test]
    fn test_not_before_validation_within_leeway() {
        let mut claims = unexpired_claims();
        claims.insert(
            "nbf".to_string(),
            to_value(current_timestamp() + 30).unwrap(),
        );

        let options = nbf_options().with_leeway(60);
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_not_before_validation_missing_nbf() {
        let result = validate(&unexpired_claims(), &nbf_options());
        assert!(matches!(result, Err(Error::InvalidClaim(_))));
    }

    #[test]
    fn test_not_before_ignored_by_default() {
        let mut claims = unexpired_claims();
        claims.insert(
            "nbf".to_string(),
            to_value(current_timestamp() + 3600).unwrap(),
        );

        assert!(validate(&claims, &ValidationOptions::default()).is_ok());
    }

    #[test]
    fn test_issuer_validation() {
        let mut claims = unexpired_claims();
        claims.insert("iss".to_string(), json!("valid_issuer"));

        let options = ValidationOptions::default().with_issuer("valid_issuer");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_issuer_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert("iss".to_string(), json!("invalid_issuer"));

        let options = ValidationOptions::default().with_issuer("valid_issuer");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidIssuer)));
    }

    #[test]
    fn test_issuer_validation_missing() {
        let options = ValidationOptions::default().with_issuer("valid_issuer");
        let result = validate(&unexpired_claims(), &options);
        assert!(matches!(result, Err(Error::InvalidIssuer)));
    }

    #[test]
    fn test_subject_validation() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("valid_subject"));

        let options = ValidationOptions::default().with_subject("valid_subject");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_subject_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("invalid_subject"));

        let options = ValidationOptions::default().with_subject("valid_subject");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidSubject)));
    }

    #[test]
    fn test_audience_validation() {
        let mut claims = unexpired_claims();
        claims.insert("aud".to_string(), json!("valid_audience"));

        let options = ValidationOptions::default().with_audience("valid_audience");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_audience_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert("aud".to_string(), json!("invalid_audience"));

        let options = ValidationOptions::default().with_audience("valid_audience");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidAudience)));
    }

    #[test]
    fn test_audience_validation_array() {
        let mut claims = unexpired_claims();
        claims.insert(
            "aud".to_string(),
            json!(["valid_audience", "another_audience"]),
        );

        let options = ValidationOptions::default().with_audience("valid_audience");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_audience_validation_array_fail() {
        let mut claims = unexpired_claims();
        claims.insert(
            "aud".to_string(),
            json!(["invalid_audience", "another_audience"]),
        );

        let options = ValidationOptions::default().with_audience("valid_audience");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidAudience)));
    }

    #[test]
    fn test_audience_validation_wrong_type() {
        let mut claims = unexpired_claims();
        claims.insert("aud".to_string(), json!(42));

        let options = ValidationOptions::default().with_audience("valid_audience");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidAudience)));
    }

    #[test]
    fn test_algorithm_validation() {
        let header = Header {
            alg: Algorithm::HS256,
            ..Header::default()
        };

        let options = ValidationOptions::default().with_algorithm(Algorithm::HS256);
        assert!(validate_header(&header, &options).is_ok());
    }

    #[test]
    fn test_algorithm_validation_fail_in_header() {
        let header = Header {
            alg: Algorithm::HS256,
            ..Header::default()
        };

        let options = ValidationOptions::default().with_algorithm(Algorithm::HS384);
        let result = validate_header(&header, &options);
        assert!(matches!(result, Err(Error::InvalidAlgorithm)));
    }

    #[test]
    fn test_algorithm_validation_empty_allow_list_accepts_any() {
        let header = Header {
            alg: Algorithm::HS384,
            ..Header::default()
        };

        assert!(validate_header(&header, &ValidationOptions::default()).is_ok());
    }

    #[test]
    fn test_required_claims() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("required_subject"));

        let options = ValidationOptions::default().with_required_claim("sub");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_required_claims_fail() {
        let options = ValidationOptions::default().with_required_claim("sub");
        let result = validate(&unexpired_claims(), &options);
        assert!(matches!(result, Err(Error::InvalidClaim(_))));
    }

    #[test]
    fn test_required_claims_multiple() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("required_subject"));
        claims.insert("aud".to_string(), json!("required_audience"));

        let options = ValidationOptions::default()
            .with_required_claim("sub")
            .with_required_claim("aud");
        assert!(validate(&claims, &options).is_ok());
    }
}