Skip to main content

modkit_auth/
validation.rs

1use crate::claims_error::ClaimsError;
2use crate::standard_claims::StandardClaim;
3use time::OffsetDateTime;
4use uuid::Uuid;
5
/// Settings that drive [`validate_claims`].
///
/// Both list-valued fields are allow-lists: an empty list switches the
/// corresponding check off entirely.
#[derive(Clone, Debug)]
pub struct ValidationConfig {
    /// Issuers (`iss`) a token may carry. Empty means the issuer is not checked.
    pub allowed_issuers: Vec<String>,

    /// Audiences (`aud`) of which a token must name at least one.
    /// Empty means the audience is not checked.
    pub allowed_audiences: Vec<String>,

    /// Clock-skew tolerance, in seconds, applied to the `exp` and `nbf` checks.
    pub leeway_seconds: i64,

    /// When `true` (the default), a token without an `exp` claim is rejected.
    /// Set to `false` to accept tokens that carry no expiration.
    pub require_exp: bool,
}
22
23impl Default for ValidationConfig {
24    fn default() -> Self {
25        Self {
26            allowed_issuers: vec![],
27            allowed_audiences: vec![],
28            leeway_seconds: 60,
29            require_exp: true,
30        }
31    }
32}
33
34/// Validate standard JWT claims in raw JSON against the given configuration.
35///
36/// Checks performed:
37/// 1. **Issuer** (`iss`) — must match one of `config.allowed_issuers` (skipped if empty)
38/// 2. **Audience** (`aud`) — at least one must match `config.allowed_audiences` (skipped if empty)
39/// 3. **Expiration** (`exp`) — required by default; must not be in the past (with leeway).
40///    Set `require_exp = false` to accept tokens without an `exp` claim.
41/// 4. **Not Before** (`nbf`) — must not be in the future (with leeway)
42///
43/// # Errors
44/// Returns `ClaimsError` if any validation check fails.
45pub fn validate_claims(
46    raw: &serde_json::Value,
47    config: &ValidationConfig,
48) -> Result<(), ClaimsError> {
49    // 0. Reject non-object payloads early
50    if !raw.is_object() {
51        return Err(ClaimsError::InvalidClaimFormat {
52            field: "claims".to_owned(),
53            reason: "must be a JSON object".to_owned(),
54        });
55    }
56
57    // 1. Validate issuer
58    if !config.allowed_issuers.is_empty() {
59        if let Some(iss_value) = raw.get(StandardClaim::ISS) {
60            let iss = iss_value
61                .as_str()
62                .ok_or_else(|| ClaimsError::InvalidClaimFormat {
63                    field: StandardClaim::ISS.to_owned(),
64                    reason: "must be a string".to_owned(),
65                })?;
66            if !config.allowed_issuers.iter().any(|a| a == iss) {
67                return Err(ClaimsError::InvalidIssuer {
68                    expected: config.allowed_issuers.clone(),
69                    actual: iss.to_owned(),
70                });
71            }
72        } else {
73            return Err(ClaimsError::MissingClaim(StandardClaim::ISS.to_owned()));
74        }
75    }
76
77    // 2. Validate audience (at least one must match)
78    if !config.allowed_audiences.is_empty() {
79        if let Some(aud_value) = raw.get(StandardClaim::AUD) {
80            let audiences = extract_audiences(aud_value)?;
81            let has_match = audiences
82                .iter()
83                .any(|a| config.allowed_audiences.contains(a));
84            if !has_match {
85                return Err(ClaimsError::InvalidAudience {
86                    expected: config.allowed_audiences.clone(),
87                    actual: audiences,
88                });
89            }
90        } else {
91            return Err(ClaimsError::MissingClaim(StandardClaim::AUD.to_owned()));
92        }
93    }
94
95    let now = OffsetDateTime::now_utc();
96    let leeway = time::Duration::seconds(config.leeway_seconds);
97
98    // 3. Validate expiration with leeway
99    if let Some(exp_value) = raw.get(StandardClaim::EXP) {
100        let exp = parse_timestamp(exp_value, StandardClaim::EXP)?;
101        let exp_with_leeway =
102            exp.checked_add(leeway)
103                .ok_or_else(|| ClaimsError::InvalidClaimFormat {
104                    field: StandardClaim::EXP.to_owned(),
105                    reason: "timestamp with leeway is out of range".to_owned(),
106                })?;
107        if now > exp_with_leeway {
108            return Err(ClaimsError::Expired);
109        }
110    } else if config.require_exp {
111        return Err(ClaimsError::MissingClaim(StandardClaim::EXP.to_owned()));
112    }
113
114    // 4. Validate not-before with leeway
115    if let Some(nbf_value) = raw.get(StandardClaim::NBF) {
116        let nbf = parse_timestamp(nbf_value, StandardClaim::NBF)?;
117        let nbf_with_leeway =
118            nbf.checked_sub(leeway)
119                .ok_or_else(|| ClaimsError::InvalidClaimFormat {
120                    field: StandardClaim::NBF.to_owned(),
121                    reason: "timestamp with leeway is out of range".to_owned(),
122                })?;
123        if now < nbf_with_leeway {
124            return Err(ClaimsError::NotYetValid);
125        }
126    }
127
128    Ok(())
129}
130
131/// Helper to parse a UUID from a JSON value.
132///
133/// # Errors
134/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a valid UUID string.
135pub fn parse_uuid_from_value(
136    value: &serde_json::Value,
137    field_name: &str,
138) -> Result<Uuid, ClaimsError> {
139    value
140        .as_str()
141        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
142            field: field_name.to_owned(),
143            reason: "must be a string".to_owned(),
144        })
145        .and_then(|s| {
146            Uuid::parse_str(s).map_err(|_| ClaimsError::InvalidClaimFormat {
147                field: field_name.to_owned(),
148                reason: "must be a valid UUID".to_owned(),
149            })
150        })
151}
152
153/// Helper to parse an array of UUIDs from a JSON value.
154///
155/// # Errors
156/// Returns `ClaimsError::InvalidClaimFormat` if the value is not an array of valid UUID strings.
157pub fn parse_uuid_array_from_value(
158    value: &serde_json::Value,
159    field_name: &str,
160) -> Result<Vec<Uuid>, ClaimsError> {
161    value
162        .as_array()
163        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
164            field: field_name.to_owned(),
165            reason: "must be an array".to_owned(),
166        })?
167        .iter()
168        .map(|v| parse_uuid_from_value(v, field_name))
169        .collect()
170}
171
172/// Helper to parse timestamp (seconds since epoch) into `OffsetDateTime`.
173///
174/// # Errors
175/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a valid unix timestamp.
176pub fn parse_timestamp(
177    value: &serde_json::Value,
178    field_name: &str,
179) -> Result<OffsetDateTime, ClaimsError> {
180    let ts = value
181        .as_i64()
182        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
183            field: field_name.to_owned(),
184            reason: "must be a number (unix timestamp)".to_owned(),
185        })?;
186
187    OffsetDateTime::from_unix_timestamp(ts).map_err(|_| ClaimsError::InvalidClaimFormat {
188        field: field_name.to_owned(),
189        reason: "invalid unix timestamp".to_owned(),
190    })
191}
192
193/// Helper to extract string from JSON value.
194///
195/// # Errors
196/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a string.
197pub fn extract_string(value: &serde_json::Value, field_name: &str) -> Result<String, ClaimsError> {
198    value
199        .as_str()
200        .map(ToString::to_string)
201        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
202            field: field_name.to_owned(),
203            reason: "must be a string".to_owned(),
204        })
205}
206
207/// Extract audiences from a JSON value.
208///
209/// Accepts a single string or an array of strings. Rejects non-string entries
210/// in arrays and non-string/non-array values.
211///
212/// # Errors
213/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a string,
214/// not an array of strings, or contains non-string entries.
215pub fn extract_audiences(value: &serde_json::Value) -> Result<Vec<String>, ClaimsError> {
216    match value {
217        serde_json::Value::String(s) => Ok(vec![s.clone()]),
218        serde_json::Value::Array(arr) => {
219            let mut out = Vec::with_capacity(arr.len());
220            for v in arr {
221                let s = v.as_str().ok_or_else(|| ClaimsError::InvalidClaimFormat {
222                    field: StandardClaim::AUD.to_owned(),
223                    reason: "must be a string or array of strings".to_owned(),
224                })?;
225                out.push(s.to_owned());
226            }
227            Ok(out)
228        }
229        _ => Err(ClaimsError::InvalidClaimFormat {
230            field: StandardClaim::AUD.to_owned(),
231            reason: "must be a string or array of strings".to_owned(),
232        }),
233    }
234}
235
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use serde_json::json;

    /// Unix timestamp for 9999-12-31T23:59:59Z, the latest instant in the `time`
    /// crate's default date range.
    const MAX_UNIX_TIMESTAMP: i64 = 253_402_300_799;
    /// Unix timestamp for -9999-01-01T00:00:00Z, the earliest instant in the `time`
    /// crate's default date range.
    const MIN_UNIX_TIMESTAMP: i64 = -377_705_116_800;

    /// Unix timestamp (seconds) for the current time shifted by `delta`.
    fn unix_from_now(delta: time::Duration) -> i64 {
        (time::OffsetDateTime::now_utc() + delta).unix_timestamp()
    }

    /// Default config, except that the audience is restricted to `audience`.
    fn audience_config(audience: &str) -> ValidationConfig {
        ValidationConfig {
            allowed_audiences: vec![audience.to_owned()],
            ..Default::default()
        }
    }

    /// Default config, except that the issuer is restricted to `issuer`.
    fn issuer_config(issuer: &str) -> ValidationConfig {
        ValidationConfig {
            allowed_issuers: vec![issuer.to_owned()],
            ..Default::default()
        }
    }

    /// Fail the test unless `err` is `MissingClaim(claim)`.
    fn assert_missing(err: ClaimsError, claim: &str) {
        match err {
            ClaimsError::MissingClaim(actual) => assert_eq!(actual, claim),
            other => panic!("expected MissingClaim({claim}), got {other:?}"),
        }
    }

    /// Fail the test unless `err` is `InvalidClaimFormat` with the given field and
    /// reason. `what` names the expectation in the panic message.
    fn assert_invalid_format(err: ClaimsError, field: &str, reason: &str, what: &str) {
        match err {
            ClaimsError::InvalidClaimFormat {
                field: actual_field,
                reason: actual_reason,
            } => {
                assert_eq!(actual_field, field);
                assert_eq!(actual_reason, reason);
            }
            other => panic!("expected {what}, got {other:?}"),
        }
    }

    #[test]
    fn test_valid_claims_pass() {
        let config = ValidationConfig {
            allowed_issuers: vec!["https://test.example.com".to_owned()],
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        let claims = json!({
            "iss": "https://test.example.com",
            "aud": "api",
            "exp": unix_from_now(time::Duration::hours(1)),
        });
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_invalid_issuer_fails() {
        let config = issuer_config("https://expected.example.com");
        let claims = json!({ "iss": "https://wrong.example.com" });
        match validate_claims(&claims, &config).unwrap_err() {
            ClaimsError::InvalidIssuer { expected, actual } => {
                assert_eq!(expected, vec!["https://expected.example.com"]);
                assert_eq!(actual, "https://wrong.example.com");
            }
            other => panic!("expected InvalidIssuer, got {other:?}"),
        }
    }

    #[test]
    fn test_missing_issuer_fails_when_required() {
        let config = issuer_config("https://expected.example.com");
        let claims = json!({ "sub": "user-1" });
        let err = validate_claims(&claims, &config).unwrap_err();
        assert_missing(err, StandardClaim::ISS);
    }

    #[test]
    fn test_invalid_audience_fails() {
        let config = audience_config("expected-api");
        let claims = json!({ "aud": "wrong-api" });
        match validate_claims(&claims, &config).unwrap_err() {
            ClaimsError::InvalidAudience { expected, actual } => {
                assert_eq!(expected, vec!["expected-api"]);
                assert_eq!(actual, vec!["wrong-api"]);
            }
            other => panic!("expected InvalidAudience, got {other:?}"),
        }
    }

    #[test]
    fn test_missing_audience_fails_when_required() {
        let claims = json!({ "sub": "user-1" });
        let err = validate_claims(&claims, &audience_config("api")).unwrap_err();
        assert_missing(err, StandardClaim::AUD);
    }

    #[test]
    fn test_expired_token_fails() {
        let claims = json!({ "exp": unix_from_now(time::Duration::hours(-1)) });
        let outcome = validate_claims(&claims, &ValidationConfig::default());
        assert!(matches!(outcome, Err(ClaimsError::Expired)));
    }

    #[test]
    fn test_not_yet_valid_fails() {
        let claims = json!({
            "exp": unix_from_now(time::Duration::hours(2)),
            "nbf": unix_from_now(time::Duration::hours(1)),
        });
        let outcome = validate_claims(&claims, &ValidationConfig::default());
        assert!(matches!(outcome, Err(ClaimsError::NotYetValid)));
    }

    #[test]
    fn test_leeway_allows_slightly_expired() {
        // 30 s past `exp` is still inside a 60 s leeway.
        let claims = json!({ "exp": unix_from_now(time::Duration::seconds(-30)) });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_default_config_accepts_valid_claims_with_exp() {
        let claims = json!({
            "sub": "anyone",
            "iss": "any-issuer",
            "exp": unix_from_now(time::Duration::hours(1)),
        });
        assert!(validate_claims(&claims, &ValidationConfig::default()).is_ok());
    }

    #[test]
    fn test_missing_exp_fails() {
        let claims = json!({ "sub": "user-1" });
        let err = validate_claims(&claims, &ValidationConfig::default()).unwrap_err();
        assert_missing(err, StandardClaim::EXP);
    }

    #[test]
    fn test_missing_exp_allowed_when_not_required() {
        let config = ValidationConfig {
            require_exp: false,
            ..Default::default()
        };
        let claims = json!({ "sub": "service-token", "iss": "internal" });
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_audience_array_match() {
        let claims = json!({
            "aud": ["api", "frontend"],
            "exp": unix_from_now(time::Duration::hours(1)),
        });
        assert!(validate_claims(&claims, &audience_config("api")).is_ok());
    }

    #[test]
    fn test_parse_uuid_from_value() {
        let expected = Uuid::new_v4();
        let parsed = parse_uuid_from_value(&json!(expected.to_string()), "test").unwrap();
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_uuid_from_value_invalid() {
        let err = parse_uuid_from_value(&json!("not-a-uuid"), "test").unwrap_err();
        assert_invalid_format(err, "test", "must be a valid UUID", "InvalidClaimFormat");
    }

    #[test]
    fn test_malformed_audience_array_rejected() {
        let claims = json!({ "aud": ["api", 123] });
        let err = validate_claims(&claims, &audience_config("api")).unwrap_err();
        assert_invalid_format(
            err,
            StandardClaim::AUD,
            "must be a string or array of strings",
            "InvalidClaimFormat for aud",
        );
    }

    #[test]
    fn test_malformed_audience_type_rejected() {
        let claims = json!({ "aud": 42 });
        let err = validate_claims(&claims, &audience_config("api")).unwrap_err();
        assert_invalid_format(
            err,
            StandardClaim::AUD,
            "must be a string or array of strings",
            "InvalidClaimFormat for aud",
        );
    }

    #[test]
    fn test_extract_audiences_string() {
        assert_eq!(extract_audiences(&json!("api")).unwrap(), vec!["api"]);
    }

    #[test]
    fn test_extract_audiences_array() {
        assert_eq!(
            extract_audiences(&json!(["api", "ui"])).unwrap(),
            vec!["api", "ui"]
        );
    }

    #[test]
    fn test_exp_overflow_returns_error() {
        // Adding a positive leeway to the latest representable instant overflows.
        let claims = json!({ "exp": MAX_UNIX_TIMESTAMP });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        assert_invalid_format(
            err,
            StandardClaim::EXP,
            "timestamp with leeway is out of range",
            "InvalidClaimFormat for exp overflow",
        );
    }

    #[test]
    fn test_nbf_overflow_returns_error() {
        // Subtracting a positive leeway from the earliest representable instant overflows.
        let claims = json!({
            "exp": unix_from_now(time::Duration::hours(1)),
            "nbf": MIN_UNIX_TIMESTAMP,
        });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        assert_invalid_format(
            err,
            StandardClaim::NBF,
            "timestamp with leeway is out of range",
            "InvalidClaimFormat for nbf overflow",
        );
    }

    #[test]
    fn test_non_object_payload_rejected() {
        let config = ValidationConfig::default();
        let payloads = [
            json!("string"),
            json!(42),
            json!(true),
            json!(null),
            json!([1, 2, 3]),
        ];
        for payload in &payloads {
            let err = validate_claims(payload, &config).unwrap_err();
            assert_invalid_format(
                err,
                "claims",
                "must be a JSON object",
                "InvalidClaimFormat for non-object",
            );
        }
    }

    #[test]
    fn test_extract_string_valid() {
        assert_eq!(extract_string(&json!("hello"), "field").unwrap(), "hello");
    }

    #[test]
    fn test_extract_string_non_string_returns_invalid_claim_format() {
        let non_strings = [json!(42), json!(true), json!({"a": 1}), json!([1, 2])];
        for value in &non_strings {
            let err = extract_string(value, "my_field").unwrap_err();
            assert_invalid_format(err, "my_field", "must be a string", "InvalidClaimFormat");
        }
    }
}