1use std::collections::HashSet;
2#[cfg(not(target_arch = "wasm32"))]
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use serde_json::map::Map;
6use serde_json::Value;
7
8use crate::Algorithm;
9use crate::Error;
10use crate::Header;
11
12#[cfg(target_arch = "wasm32")]
13use wasm_bindgen::prelude::*;
14
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
extern "C" {
    /// Binding to the JavaScript `console.log` function (wasm32 builds only).
    // NOTE(review): `log` is not referenced anywhere in this module chunk —
    // confirm it is used elsewhere (e.g. for debugging) or consider removing it.
    #[wasm_bindgen(js_namespace = console)]
    fn log(value: &str);
}
21
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
extern "C" {
    /// Binding to the JavaScript `Date.now` function (wasm32 builds only).
    ///
    /// `current_timestamp` divides the returned value by 1000 to obtain
    /// seconds, i.e. the result is treated as milliseconds since the Unix epoch.
    #[wasm_bindgen(js_namespace = Date)]
    fn now() -> f64;
}
28
/// Settings controlling which header algorithms and payload claims are accepted
/// by `validate_header` and `validate`.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationOptions {
    /// Clock-skew allowance, in seconds, applied to the `exp` and `nbf` checks.
    pub leeway: u64,
    /// When `true`, a non-negative integer `exp` claim is required and the token
    /// is rejected once the current time is later than `exp + leeway`.
    pub validate_exp: bool,
    /// When `true`, a non-negative integer `nbf` claim is required and the token
    /// is rejected while the current time is earlier than `nbf - leeway`.
    pub validate_nbf: bool,
    /// If set, the `aud` claim must be a string contained in this set, or an
    /// array with at least one string element contained in this set.
    pub audiences: Option<HashSet<String>>,
    /// If set, the `iss` claim must be a string equal to this value.
    pub issuer: Option<String>,
    /// If set, the `sub` claim must be a string equal to this value.
    pub subject: Option<String>,
    /// Header algorithms that are accepted; an empty set accepts any algorithm.
    pub algorithms: HashSet<Algorithm>,
    /// Claim names that must be present in the payload (only presence is
    /// checked, not the value).
    pub required_claims: Option<HashSet<String>>,
}
48
49impl ValidationOptions {
50 pub fn new(alg: Algorithm) -> Self {
52 Self {
53 algorithms: HashSet::from([alg]),
54 ..Self::default()
55 }
56 }
57
58 pub fn without_expiry(self) -> Self {
60 Self {
61 validate_exp: false,
62 ..Self::default()
63 }
64 }
65
66 pub fn with_audiences<T: ToString>(self, audiences: &[T]) -> Self {
68 Self {
69 audiences: Some(audiences.iter().map(ToString::to_string).collect()),
70 ..self
71 }
72 }
73
74 pub fn with_audience<T: ToString>(self, audience: T) -> Self {
76 Self {
77 audiences: Some(HashSet::from([audience.to_string()])),
78 ..self
79 }
80 }
81
82 pub fn with_issuer<T: ToString>(self, issuer: T) -> Self {
84 Self {
85 issuer: Some(issuer.to_string()),
86 ..self
87 }
88 }
89
90 pub fn with_subject<T: ToString>(self, subject: T) -> Self {
92 Self {
93 subject: Some(subject.to_string()),
94 ..self
95 }
96 }
97
98 pub fn with_leeway(self, leeway: u64) -> Self {
100 Self { leeway, ..self }
101 }
102
103 pub fn with_algorithm(mut self, alg: Algorithm) -> Self {
105 self.algorithms.insert(alg);
106 self
107 }
108
109 pub fn with_required_claim<T: ToString>(mut self, claim: T) -> Self {
111 if let Some(ref mut required_claims) = self.required_claims {
112 required_claims.insert(claim.to_string());
113 } else {
114 self.required_claims = Some(HashSet::from([claim.to_string()]));
115 }
116 self
117 }
118}
119
120impl Default for ValidationOptions {
121 fn default() -> Self {
122 Self {
123 leeway: 0,
124 validate_exp: true,
125 validate_nbf: false,
126 audiences: None,
127 issuer: None,
128 subject: None,
129 algorithms: HashSet::new(),
130 required_claims: None,
131 }
132 }
133}
134
135pub(crate) fn validate_header(
137 header: &Header,
138 validation_options: &ValidationOptions,
139) -> Result<(), Error> {
140 if !validation_options.algorithms.is_empty()
141 && !validation_options.algorithms.contains(&header.alg)
142 {
143 return Err(Error::InvalidAlgorithm);
144 }
145 Ok(())
146}
147
148pub(crate) fn validate(
150 claims: &Map<String, Value>,
151 options: &ValidationOptions,
152) -> Result<(), Error> {
153 let now = current_timestamp();
154
155 let validate_time_claim = |claim_value: Option<&Value>,
156 validate: bool,
157 validation_predicate: &dyn Fn(u64) -> bool,
158 validation_error: Error,
159 missing_claim_error: Error|
160 -> Result<(), Error> {
161 if validate {
162 if let Some(value) = claim_value.and_then(|v| v.as_u64()) {
163 if !validation_predicate(value) {
164 return Err(validation_error);
165 }
166 } else {
167 return Err(missing_claim_error);
168 }
169 }
170 Ok(())
171 };
172
173 validate_time_claim(
174 claims.get("exp"),
175 options.validate_exp,
176 &|timestamp| now <= timestamp + options.leeway,
177 Error::ExpiredSignature,
178 Error::InvalidClaim("Missing exp claim".to_string()),
179 )?;
180
181 validate_time_claim(
182 claims.get("nbf"),
183 options.validate_nbf,
184 &|timestamp| now >= timestamp - options.leeway,
185 Error::ImmatureSignature,
186 Error::InvalidClaim("Missing nbf claim".to_string()),
187 )?;
188
189 let validate_str_claim = |claim_value: Option<&Value>,
190 expected_value: &Option<String>,
191 validation_error: Error|
192 -> Result<(), Error> {
193 if let Some(expected) = expected_value {
194 if let Some(actual) = claim_value.and_then(|v| v.as_str()) {
195 if actual != expected {
196 return Err(validation_error);
197 }
198 } else {
199 return Err(validation_error);
200 }
201 }
202 Ok(())
203 };
204
205 validate_str_claim(claims.get("iss"), &options.issuer, Error::InvalidIssuer)?;
206 validate_str_claim(claims.get("sub"), &options.subject, Error::InvalidSubject)?;
207
208 let validate_audiences = |aud_claim: Option<&Value>,
209 expected_audiences: &Option<HashSet<String>>|
210 -> Result<(), Error> {
211 if let Some(expected) = expected_audiences {
212 match aud_claim {
213 Some(Value::String(aud)) => {
214 if !expected.contains(aud) {
215 return Err(Error::InvalidAudience);
216 }
217 }
218 Some(Value::Array(aud_array)) => {
219 let provided: HashSet<String> = aud_array
220 .iter()
221 .filter_map(|val| val.as_str().map(String::from))
222 .collect();
223 if provided.is_disjoint(expected) {
224 return Err(Error::InvalidAudience);
225 }
226 }
227 _ => return Err(Error::InvalidAudience),
228 }
229 }
230 Ok(())
231 };
232
233 validate_audiences(claims.get("aud"), &options.audiences)?;
234
235 if let Some(ref required_claims) = options.required_claims {
236 for claim in required_claims {
237 if !claims.contains_key(claim) {
238 return Err(Error::InvalidClaim(format!(
239 "Missing required claim: {}",
240 claim
241 )));
242 }
243 }
244 }
245
246 Ok(())
247}
248
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
#[cfg(not(target_arch = "wasm32"))]
fn current_timestamp() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("SystemTime before UNIX EPOCH");
    elapsed.as_secs()
}
257
/// Current wall-clock time as whole seconds since the Unix epoch (wasm32 builds).
///
/// The JavaScript `Date.now` reading is treated as milliseconds and divided by
/// 1000. The float-to-integer `as` cast truncates toward zero and saturates, so a
/// negative or NaN reading becomes `0`.
#[cfg(target_arch = "wasm32")]
pub fn current_timestamp() -> u64 {
    let millis = now();
    let seconds = millis / 1000.0;
    seconds as u64
}
262
263#[cfg(test)]
264mod tests {
265 use super::*;
266 use serde_json::{json, to_value};
267
268 #[test]
269 fn test_expiration_validation() {
270 let mut claims = Map::new();
271 claims.insert(
272 "exp".to_string(),
273 to_value(current_timestamp() + 3600).unwrap(),
274 );
275
276 let result = validate(&claims, &ValidationOptions::default());
277 if result.is_err() {
278 println!("{:?}", result);
279 }
280 assert!(result.is_ok());
281 }
282
283 #[test]
284 fn test_expiration_validation_fail() {
285 let mut claims = Map::new();
286 claims.insert(
287 "exp".to_string(),
288 to_value(current_timestamp() - 60).unwrap(),
289 );
290
291 let result = validate(&claims, &ValidationOptions::default());
292 assert!(matches!(result, Err(Error::ExpiredSignature)));
293 }
294
295 #[test]
296 fn test_not_before_validation() {
297 let mut claims = Map::new();
298 claims.insert(
299 "exp".to_string(),
300 to_value(current_timestamp() + 3600).unwrap(),
301 );
302 claims.insert("nbf".to_string(), to_value(current_timestamp()).unwrap());
303
304 let options = ValidationOptions::default();
305 let result = validate(&claims, &options);
306 assert!(result.is_ok());
307 }
308
309 #[test]
310 fn test_issuer_validation() {
311 let mut claims = Map::new();
312 claims.insert(
313 "exp".to_string(),
314 to_value(current_timestamp() + 3600).unwrap(),
315 );
316 claims.insert("iss".to_string(), json!("valid_issuer"));
317
318 let options = ValidationOptions::default().with_issuer("valid_issuer");
319 let result = validate(&claims, &options);
320 assert!(result.is_ok());
321 }
322
323 #[test]
324 fn test_issuer_validation_fail() {
325 let mut claims = Map::new();
326 claims.insert(
327 "exp".to_string(),
328 to_value(current_timestamp() + 3600).unwrap(),
329 );
330 claims.insert("iss".to_string(), json!("invalid_issuer"));
331
332 let options = ValidationOptions::default().with_issuer("valid_issuer");
333 let result = validate(&claims, &options);
334 assert!(matches!(result, Err(Error::InvalidIssuer)));
335 }
336
337 #[test]
338 fn test_subject_validation() {
339 let mut claims = Map::new();
340 claims.insert(
341 "exp".to_string(),
342 to_value(current_timestamp() + 3600).unwrap(),
343 );
344 claims.insert("sub".to_string(), json!("valid_subject"));
345
346 let options = ValidationOptions::default().with_subject("valid_subject");
347 let result = validate(&claims, &options);
348 assert!(result.is_ok());
349 }
350
351 #[test]
352 fn test_subject_validation_fail() {
353 let mut claims = Map::new();
354 claims.insert(
355 "exp".to_string(),
356 to_value(current_timestamp() + 3600).unwrap(),
357 );
358 claims.insert("sub".to_string(), json!("invalid_subject"));
359
360 let options = ValidationOptions::default().with_subject("valid_subject");
361 let result = validate(&claims, &options);
362 assert!(matches!(result, Err(Error::InvalidSubject)));
363 }
364
365 #[test]
366 fn test_audience_validation() {
367 let mut claims = Map::new();
368 claims.insert(
369 "exp".to_string(),
370 to_value(current_timestamp() + 3600).unwrap(),
371 );
372 claims.insert("aud".to_string(), json!("valid_audience"));
373
374 let options = ValidationOptions::default().with_audience("valid_audience");
375 let result = validate(&claims, &options);
376 assert!(result.is_ok());
377 }
378
379 #[test]
380 fn test_audience_validation_fail() {
381 let mut claims = Map::new();
382 claims.insert(
383 "exp".to_string(),
384 to_value(current_timestamp() + 3600).unwrap(),
385 );
386 claims.insert("aud".to_string(), json!("invalid_audience"));
387
388 let options = ValidationOptions::default().with_audience("valid_audience");
389 let result = validate(&claims, &options);
390 assert!(matches!(result, Err(Error::InvalidAudience)));
391 }
392
393 #[test]
394 fn test_audience_validation_array() {
395 let mut claims = Map::new();
396 claims.insert(
397 "exp".to_string(),
398 to_value(current_timestamp() + 3600).unwrap(),
399 );
400 claims.insert(
401 "aud".to_string(),
402 json!(["valid_audience", "another_audience"]),
403 );
404
405 let options = ValidationOptions::default().with_audience("valid_audience");
406 let result = validate(&claims, &options);
407 assert!(result.is_ok());
408 }
409
410 #[test]
411 fn test_audience_validation_array_fail() {
412 let mut claims = Map::new();
413 claims.insert(
414 "exp".to_string(),
415 to_value(current_timestamp() + 3600).unwrap(),
416 );
417 claims.insert(
418 "aud".to_string(),
419 json!(["invalid_audience", "another_audience"]),
420 );
421
422 let options = ValidationOptions::default().with_audience("valid_audience");
423 let result = validate(&claims, &options);
424 assert!(matches!(result, Err(Error::InvalidAudience)));
425 }
426
427 #[test]
428 fn test_algorithm_validation() {
429 let header = Header {
430 alg: Algorithm::HS256,
431 ..Header::default()
432 };
433 let mut claims = Map::new();
434 claims.insert(
435 "exp".to_string(),
436 to_value(current_timestamp() + 3600).unwrap(),
437 );
438
439 let options = ValidationOptions::default().with_algorithm(Algorithm::HS256);
440 let result = validate_header(&header, &options);
441 assert!(result.is_ok());
442 }
443
444 #[test]
445 fn test_algorithm_validation_fail_in_header() {
446 let header = Header {
447 alg: Algorithm::HS256,
448 ..Header::default()
449 };
450 let mut claims = Map::new();
451 claims.insert(
452 "exp".to_string(),
453 to_value(current_timestamp() + 3600).unwrap(),
454 );
455
456 let options = ValidationOptions::default().with_algorithm(Algorithm::HS384);
457 let result = validate_header(&header, &options);
458 assert!(matches!(result, Err(Error::InvalidAlgorithm)));
459 }
460
461 #[test]
462 fn test_required_claims() {
463 let mut claims = Map::new();
464 claims.insert(
465 "exp".to_string(),
466 to_value(current_timestamp() + 3600).unwrap(),
467 );
468 claims.insert("sub".to_string(), json!("required_subject"));
469
470 let options = ValidationOptions::default().with_required_claim("sub");
471 let result = validate(&claims, &options);
472 assert!(result.is_ok());
473 }
474
475 #[test]
476 fn test_required_claims_fail() {
477 let mut claims = Map::new();
478 claims.insert(
479 "exp".to_string(),
480 to_value(current_timestamp() + 3600).unwrap(),
481 );
482
483 let options = ValidationOptions::default().with_required_claim("sub");
484 let result = validate(&claims, &options);
485 assert!(matches!(result, Err(Error::InvalidClaim(_))));
486 }
487
488 #[test]
489 fn test_required_claims_multiple() {
490 let mut claims = Map::new();
491 claims.insert(
492 "exp".to_string(),
493 to_value(current_timestamp() + 3600).unwrap(),
494 );
495 claims.insert("sub".to_string(), json!("required_subject"));
496 claims.insert("aud".to_string(), json!("required_audience"));
497
498 let options = ValidationOptions::default()
499 .with_required_claim("sub")
500 .with_required_claim("aud");
501 let result = validate(&claims, &options);
502 assert!(result.is_ok());
503 }
504}