Skip to main content

modkit_auth/
validation.rs

1use crate::claims_error::ClaimsError;
2use crate::standard_claims::StandardClaim;
3use time::OffsetDateTime;
4use uuid::Uuid;
5
/// Configuration for [`validate_claims`].
///
/// The default accepts any issuer and any audience and tolerates
/// [`ValidationConfig::DEFAULT_LEEWAY_SECONDS`] of clock skew on `exp`/`nbf`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationConfig {
    /// Allowed issuers. If empty, any issuer is accepted and `iss` is not required.
    pub allowed_issuers: Vec<String>,

    /// Allowed audiences. If empty, any audience is accepted and `aud` is not required.
    pub allowed_audiences: Vec<String>,

    /// Leeway in seconds for time-based validations (`exp`, `nbf`).
    ///
    /// The leeway is added to `exp` and subtracted from `nbf`, so a negative value
    /// shrinks the acceptance window instead of widening it.
    pub leeway_seconds: i64,
}

impl ValidationConfig {
    /// Leeway used by [`ValidationConfig::default`]: one minute of tolerated clock skew.
    pub const DEFAULT_LEEWAY_SECONDS: i64 = 60;
}

impl Default for ValidationConfig {
    fn default() -> Self {
        Self {
            allowed_issuers: Vec::new(),
            allowed_audiences: Vec::new(),
            leeway_seconds: Self::DEFAULT_LEEWAY_SECONDS,
        }
    }
}
28
29/// Validate standard JWT claims in raw JSON against the given configuration.
30///
31/// Checks performed:
32/// 1. **Issuer** (`iss`) — must match one of `config.allowed_issuers` (skipped if empty)
33/// 2. **Audience** (`aud`) — at least one must match `config.allowed_audiences` (skipped if empty)
34/// 3. **Expiration** (`exp`) — must not be in the past (with leeway)
35/// 4. **Not Before** (`nbf`) — must not be in the future (with leeway)
36///
37/// # Errors
38/// Returns `ClaimsError` if any validation check fails.
39pub fn validate_claims(
40    raw: &serde_json::Value,
41    config: &ValidationConfig,
42) -> Result<(), ClaimsError> {
43    // 0. Reject non-object payloads early
44    if !raw.is_object() {
45        return Err(ClaimsError::InvalidClaimFormat {
46            field: "claims".to_owned(),
47            reason: "must be a JSON object".to_owned(),
48        });
49    }
50
51    // 1. Validate issuer
52    if !config.allowed_issuers.is_empty() {
53        if let Some(iss_value) = raw.get(StandardClaim::ISS) {
54            let iss = iss_value
55                .as_str()
56                .ok_or_else(|| ClaimsError::InvalidClaimFormat {
57                    field: StandardClaim::ISS.to_owned(),
58                    reason: "must be a string".to_owned(),
59                })?;
60            if !config.allowed_issuers.iter().any(|a| a == iss) {
61                return Err(ClaimsError::InvalidIssuer {
62                    expected: config.allowed_issuers.clone(),
63                    actual: iss.to_owned(),
64                });
65            }
66        } else {
67            return Err(ClaimsError::MissingClaim(StandardClaim::ISS.to_owned()));
68        }
69    }
70
71    // 2. Validate audience (at least one must match)
72    if !config.allowed_audiences.is_empty() {
73        if let Some(aud_value) = raw.get(StandardClaim::AUD) {
74            let audiences = extract_audiences(aud_value)?;
75            let has_match = audiences
76                .iter()
77                .any(|a| config.allowed_audiences.contains(a));
78            if !has_match {
79                return Err(ClaimsError::InvalidAudience {
80                    expected: config.allowed_audiences.clone(),
81                    actual: audiences,
82                });
83            }
84        } else {
85            return Err(ClaimsError::MissingClaim(StandardClaim::AUD.to_owned()));
86        }
87    }
88
89    let now = OffsetDateTime::now_utc();
90    let leeway = time::Duration::seconds(config.leeway_seconds);
91
92    // 3. Validate expiration with leeway
93    if let Some(exp_value) = raw.get(StandardClaim::EXP) {
94        let exp = parse_timestamp(exp_value, StandardClaim::EXP)?;
95        let exp_with_leeway =
96            exp.checked_add(leeway)
97                .ok_or_else(|| ClaimsError::InvalidClaimFormat {
98                    field: StandardClaim::EXP.to_owned(),
99                    reason: "timestamp with leeway is out of range".to_owned(),
100                })?;
101        if now > exp_with_leeway {
102            return Err(ClaimsError::Expired);
103        }
104    }
105
106    // 4. Validate not-before with leeway
107    if let Some(nbf_value) = raw.get(StandardClaim::NBF) {
108        let nbf = parse_timestamp(nbf_value, StandardClaim::NBF)?;
109        let nbf_with_leeway =
110            nbf.checked_sub(leeway)
111                .ok_or_else(|| ClaimsError::InvalidClaimFormat {
112                    field: StandardClaim::NBF.to_owned(),
113                    reason: "timestamp with leeway is out of range".to_owned(),
114                })?;
115        if now < nbf_with_leeway {
116            return Err(ClaimsError::NotYetValid);
117        }
118    }
119
120    Ok(())
121}
122
123/// Helper to parse a UUID from a JSON value.
124///
125/// # Errors
126/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a valid UUID string.
127pub fn parse_uuid_from_value(
128    value: &serde_json::Value,
129    field_name: &str,
130) -> Result<Uuid, ClaimsError> {
131    value
132        .as_str()
133        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
134            field: field_name.to_owned(),
135            reason: "must be a string".to_owned(),
136        })
137        .and_then(|s| {
138            Uuid::parse_str(s).map_err(|_| ClaimsError::InvalidClaimFormat {
139                field: field_name.to_owned(),
140                reason: "must be a valid UUID".to_owned(),
141            })
142        })
143}
144
145/// Helper to parse an array of UUIDs from a JSON value.
146///
147/// # Errors
148/// Returns `ClaimsError::InvalidClaimFormat` if the value is not an array of valid UUID strings.
149pub fn parse_uuid_array_from_value(
150    value: &serde_json::Value,
151    field_name: &str,
152) -> Result<Vec<Uuid>, ClaimsError> {
153    value
154        .as_array()
155        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
156            field: field_name.to_owned(),
157            reason: "must be an array".to_owned(),
158        })?
159        .iter()
160        .map(|v| parse_uuid_from_value(v, field_name))
161        .collect()
162}
163
164/// Helper to parse timestamp (seconds since epoch) into `OffsetDateTime`.
165///
166/// # Errors
167/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a valid unix timestamp.
168pub fn parse_timestamp(
169    value: &serde_json::Value,
170    field_name: &str,
171) -> Result<OffsetDateTime, ClaimsError> {
172    let ts = value
173        .as_i64()
174        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
175            field: field_name.to_owned(),
176            reason: "must be a number (unix timestamp)".to_owned(),
177        })?;
178
179    OffsetDateTime::from_unix_timestamp(ts).map_err(|_| ClaimsError::InvalidClaimFormat {
180        field: field_name.to_owned(),
181        reason: "invalid unix timestamp".to_owned(),
182    })
183}
184
185/// Helper to extract string from JSON value.
186///
187/// # Errors
188/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a string.
189pub fn extract_string(value: &serde_json::Value, field_name: &str) -> Result<String, ClaimsError> {
190    value
191        .as_str()
192        .map(ToString::to_string)
193        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
194            field: field_name.to_owned(),
195            reason: "must be a string".to_owned(),
196        })
197}
198
199/// Extract audiences from a JSON value.
200///
201/// Accepts a single string or an array of strings. Rejects non-string entries
202/// in arrays and non-string/non-array values.
203///
204/// # Errors
205/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a string,
206/// not an array of strings, or contains non-string entries.
207pub fn extract_audiences(value: &serde_json::Value) -> Result<Vec<String>, ClaimsError> {
208    match value {
209        serde_json::Value::String(s) => Ok(vec![s.clone()]),
210        serde_json::Value::Array(arr) => {
211            let mut out = Vec::with_capacity(arr.len());
212            for v in arr {
213                let s = v.as_str().ok_or_else(|| ClaimsError::InvalidClaimFormat {
214                    field: StandardClaim::AUD.to_owned(),
215                    reason: "must be a string or array of strings".to_owned(),
216                })?;
217                out.push(s.to_owned());
218            }
219            Ok(out)
220        }
221        _ => Err(ClaimsError::InvalidClaimFormat {
222            field: StandardClaim::AUD.to_owned(),
223            reason: "must be a string or array of strings".to_owned(),
224        }),
225    }
226}
227
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    //! Unit tests for claim validation and the JSON value helpers.

    use super::*;
    use serde_json::json;

    /// Unix timestamp for 9999-12-31T23:59:59Z — max representable date in `time` crate default range.
    const MAX_UNIX_TIMESTAMP: i64 = 253_402_300_799;
    /// Unix timestamp for -9999-01-01T00:00:00Z — min representable date in `time` crate default range.
    const MIN_UNIX_TIMESTAMP: i64 = -377_705_116_800;

    #[test]
    fn test_valid_claims_pass() {
        let now = time::OffsetDateTime::now_utc();
        let claims = json!({
            "iss": "https://test.example.com",
            "aud": "api",
            "exp": (now + time::Duration::hours(1)).unix_timestamp(),
        });
        let config = ValidationConfig {
            allowed_issuers: vec!["https://test.example.com".to_owned()],
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_invalid_issuer_fails() {
        let claims = json!({ "iss": "https://wrong.example.com" });
        let config = ValidationConfig {
            allowed_issuers: vec!["https://expected.example.com".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidIssuer { expected, actual } => {
                assert_eq!(expected, vec!["https://expected.example.com"]);
                assert_eq!(actual, "https://wrong.example.com");
            }
            other => panic!("expected InvalidIssuer, got {other:?}"),
        }
    }

    #[test]
    fn test_missing_issuer_fails_when_required() {
        let claims = json!({ "sub": "user-1" });
        let config = ValidationConfig {
            allowed_issuers: vec!["https://expected.example.com".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::MissingClaim(claim) => assert_eq!(claim, StandardClaim::ISS),
            other => panic!("expected MissingClaim(iss), got {other:?}"),
        }
    }

    #[test]
    fn test_non_string_issuer_rejected() {
        let claims = json!({ "iss": 123 });
        let config = ValidationConfig {
            allowed_issuers: vec!["https://expected.example.com".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, StandardClaim::ISS);
                assert_eq!(reason, "must be a string");
            }
            other => panic!("expected InvalidClaimFormat for iss, got {other:?}"),
        }
    }

    #[test]
    fn test_issuer_matches_any_allowed_entry() {
        let claims = json!({ "iss": "https://second.example.com" });
        let config = ValidationConfig {
            allowed_issuers: vec![
                "https://first.example.com".to_owned(),
                "https://second.example.com".to_owned(),
            ],
            ..Default::default()
        };
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_invalid_audience_fails() {
        let claims = json!({ "aud": "wrong-api" });
        let config = ValidationConfig {
            allowed_audiences: vec!["expected-api".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidAudience { expected, actual } => {
                assert_eq!(expected, vec!["expected-api"]);
                assert_eq!(actual, vec!["wrong-api"]);
            }
            other => panic!("expected InvalidAudience, got {other:?}"),
        }
    }

    #[test]
    fn test_missing_audience_fails_when_required() {
        let claims = json!({ "sub": "user-1" });
        let config = ValidationConfig {
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::MissingClaim(claim) => assert_eq!(claim, StandardClaim::AUD),
            other => panic!("expected MissingClaim(aud), got {other:?}"),
        }
    }

    #[test]
    fn test_empty_audience_array_fails_when_required() {
        let claims = json!({ "aud": [] });
        let config = ValidationConfig {
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidAudience { expected, actual } => {
                assert_eq!(expected, vec!["api"]);
                assert!(actual.is_empty());
            }
            other => panic!("expected InvalidAudience, got {other:?}"),
        }
    }

    #[test]
    fn test_expired_token_fails() {
        let now = time::OffsetDateTime::now_utc();
        let claims = json!({
            "exp": (now - time::Duration::hours(1)).unix_timestamp(),
        });
        let config = ValidationConfig::default();
        assert!(matches!(
            validate_claims(&claims, &config),
            Err(ClaimsError::Expired)
        ));
    }

    #[test]
    fn test_not_yet_valid_fails() {
        let now = time::OffsetDateTime::now_utc();
        let claims = json!({
            "nbf": (now + time::Duration::hours(1)).unix_timestamp(),
        });
        let config = ValidationConfig::default();
        assert!(matches!(
            validate_claims(&claims, &config),
            Err(ClaimsError::NotYetValid)
        ));
    }

    #[test]
    fn test_leeway_allows_slightly_expired() {
        let now = time::OffsetDateTime::now_utc();
        let claims = json!({
            "exp": (now - time::Duration::seconds(30)).unix_timestamp(),
        });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_leeway_allows_slightly_early_nbf() {
        let now = time::OffsetDateTime::now_utc();
        let claims = json!({
            "nbf": (now + time::Duration::seconds(30)).unix_timestamp(),
        });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_zero_leeway_rejects_recently_expired() {
        let now = time::OffsetDateTime::now_utc();
        let claims = json!({
            "exp": (now - time::Duration::seconds(5)).unix_timestamp(),
        });
        let config = ValidationConfig {
            leeway_seconds: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_claims(&claims, &config),
            Err(ClaimsError::Expired)
        ));
    }

    #[test]
    fn test_empty_config_accepts_anything() {
        let claims = json!({ "sub": "anyone", "iss": "any-issuer" });
        let config = ValidationConfig::default();
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_audience_array_match() {
        let claims = json!({ "aud": ["api", "frontend"] });
        let config = ValidationConfig {
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_parse_uuid_from_value() {
        let uuid = Uuid::new_v4();
        let value = json!(uuid.to_string());

        let result = parse_uuid_from_value(&value, "test");
        assert_eq!(result.unwrap(), uuid);
    }

    #[test]
    fn test_parse_uuid_from_value_invalid() {
        let value = json!("not-a-uuid");
        let err = parse_uuid_from_value(&value, "test").unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, "test");
                assert_eq!(reason, "must be a valid UUID");
            }
            other => panic!("expected InvalidClaimFormat, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_uuid_from_value_non_string() {
        let err = parse_uuid_from_value(&json!(42), "test").unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, "test");
                assert_eq!(reason, "must be a string");
            }
            other => panic!("expected InvalidClaimFormat, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_uuid_array_from_value() {
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        let value = json!([first.to_string(), second.to_string()]);
        assert_eq!(
            parse_uuid_array_from_value(&value, "ids").unwrap(),
            vec![first, second]
        );
        assert!(parse_uuid_array_from_value(&json!([]), "ids")
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_parse_uuid_array_from_value_rejects_bad_input() {
        let uuid = Uuid::new_v4().to_string();
        for (value, expected_reason) in [
            (json!(uuid), "must be an array"),
            (json!([uuid, "not-a-uuid"]), "must be a valid UUID"),
            (json!([7]), "must be a string"),
        ] {
            let err = parse_uuid_array_from_value(&value, "ids").unwrap_err();
            match err {
                ClaimsError::InvalidClaimFormat { field, reason } => {
                    assert_eq!(field, "ids");
                    assert_eq!(reason, expected_reason);
                }
                other => panic!("expected InvalidClaimFormat, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_parse_timestamp_epoch() {
        let ts = parse_timestamp(&json!(0), "iat").unwrap();
        assert_eq!(ts, time::OffsetDateTime::UNIX_EPOCH);
    }

    #[test]
    fn test_parse_timestamp_rejects_non_number() {
        for value in [json!("1700000000"), json!(null), json!(true), json!([1])] {
            let err = parse_timestamp(&value, "iat").unwrap_err();
            match err {
                ClaimsError::InvalidClaimFormat { field, reason } => {
                    assert_eq!(field, "iat");
                    assert_eq!(reason, "must be a number (unix timestamp)");
                }
                other => panic!("expected InvalidClaimFormat, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_parse_timestamp_rejects_out_of_range() {
        for value in [json!(MAX_UNIX_TIMESTAMP + 1), json!(i64::MAX)] {
            let err = parse_timestamp(&value, "exp").unwrap_err();
            match err {
                ClaimsError::InvalidClaimFormat { field, reason } => {
                    assert_eq!(field, "exp");
                    assert_eq!(reason, "invalid unix timestamp");
                }
                other => panic!("expected InvalidClaimFormat, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_malformed_audience_array_rejected() {
        let claims = json!({ "aud": ["api", 123] });
        let config = ValidationConfig {
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, StandardClaim::AUD);
                assert_eq!(reason, "must be a string or array of strings");
            }
            other => panic!("expected InvalidClaimFormat for aud, got {other:?}"),
        }
    }

    #[test]
    fn test_malformed_audience_type_rejected() {
        let claims = json!({ "aud": 42 });
        let config = ValidationConfig {
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, StandardClaim::AUD);
                assert_eq!(reason, "must be a string or array of strings");
            }
            other => panic!("expected InvalidClaimFormat for aud, got {other:?}"),
        }
    }

    #[test]
    fn test_extract_audiences_string() {
        let value = json!("api");
        let audiences = extract_audiences(&value).unwrap();
        assert_eq!(audiences, vec!["api"]);
    }

    #[test]
    fn test_extract_audiences_array() {
        let value = json!(["api", "ui"]);
        let audiences = extract_audiences(&value).unwrap();
        assert_eq!(audiences, vec!["api", "ui"]);
    }

    #[test]
    fn test_exp_overflow_returns_error() {
        let claims = json!({ "exp": MAX_UNIX_TIMESTAMP });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, StandardClaim::EXP);
                assert_eq!(reason, "timestamp with leeway is out of range");
            }
            other => panic!("expected InvalidClaimFormat for exp overflow, got {other:?}"),
        }
    }

    #[test]
    fn test_nbf_overflow_returns_error() {
        let claims = json!({ "nbf": MIN_UNIX_TIMESTAMP });
        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };
        let err = validate_claims(&claims, &config).unwrap_err();
        match err {
            ClaimsError::InvalidClaimFormat { field, reason } => {
                assert_eq!(field, StandardClaim::NBF);
                assert_eq!(reason, "timestamp with leeway is out of range");
            }
            other => panic!("expected InvalidClaimFormat for nbf overflow, got {other:?}"),
        }
    }

    #[test]
    fn test_non_object_payload_rejected() {
        let config = ValidationConfig::default();
        for value in [
            json!("string"),
            json!(42),
            json!(true),
            json!(null),
            json!([1, 2, 3]),
        ] {
            let err = validate_claims(&value, &config).unwrap_err();
            match err {
                ClaimsError::InvalidClaimFormat { field, reason } => {
                    assert_eq!(field, "claims");
                    assert_eq!(reason, "must be a JSON object");
                }
                other => panic!("expected InvalidClaimFormat for non-object, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_extract_string_valid() {
        let value = json!("hello");
        assert_eq!(extract_string(&value, "field").unwrap(), "hello");
    }

    #[test]
    fn test_extract_string_non_string_returns_invalid_claim_format() {
        for value in [json!(42), json!(true), json!({"a": 1}), json!([1, 2])] {
            let err = extract_string(&value, "my_field").unwrap_err();
            match err {
                ClaimsError::InvalidClaimFormat { field, reason } => {
                    assert_eq!(field, "my_field");
                    assert_eq!(reason, "must be a string");
                }
                other => panic!("expected InvalidClaimFormat, got {other:?}"),
            }
        }
    }
}