Skip to main content

modkit_auth/
validation.rs

1use crate::{claims::Claims, claims_error::ClaimsError};
2use time::OffsetDateTime;
3use uuid::Uuid;
4
/// Settings that control which checks [`validate_claims`] performs.
///
/// Start from `ValidationConfig::default()` and use struct-update syntax to override
/// only the settings you need.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Issuers accepted by `validate_claims`. An empty list disables the issuer check.
    pub allowed_issuers: Vec<String>,

    /// Audiences accepted by `validate_claims`. Claims pass when at least one of their
    /// audiences is listed here. An empty list disables the audience check.
    pub allowed_audiences: Vec<String>,

    /// Clock-skew tolerance, in seconds, applied to both the expiration (`exp`) and the
    /// not-before (`nbf`) checks.
    pub leeway_seconds: i64,

    /// When `true`, `validate_claims` rejects a nil (all-zero) subject UUID.
    pub require_uuid_subject: bool,

    /// When `true`, `validate_claims` rejects a nil (all-zero) tenant ID.
    pub require_uuid_tenants: bool,
}

/// Clock-skew tolerance used by `ValidationConfig::default()`.
const DEFAULT_LEEWAY_SECONDS: i64 = 60;

impl Default for ValidationConfig {
    /// Accepts any issuer and any audience, tolerates 60 seconds of clock skew, and
    /// rejects nil subject and tenant IDs.
    fn default() -> Self {
        Self {
            allowed_issuers: Vec::new(),
            allowed_audiences: Vec::new(),
            leeway_seconds: DEFAULT_LEEWAY_SECONDS,
            require_uuid_subject: true,
            require_uuid_tenants: true,
        }
    }
}
35
36/// Perform common validation checks on claims.
37///
38/// # Errors
39/// Returns `ClaimsError` if any validation check fails (issuer, audience, expiration, etc.).
40pub fn validate_claims(claims: &Claims, config: &ValidationConfig) -> Result<(), ClaimsError> {
41    // 1. Validate issuer
42    if !config.allowed_issuers.is_empty() && !config.allowed_issuers.contains(&claims.issuer) {
43        return Err(ClaimsError::InvalidIssuer {
44            expected: config.allowed_issuers.clone(),
45            actual: claims.issuer.clone(),
46        });
47    }
48
49    // 2. Validate audience (at least one must match)
50    if !config.allowed_audiences.is_empty() {
51        let has_valid_audience = claims
52            .audiences
53            .iter()
54            .any(|aud| config.allowed_audiences.contains(aud));
55
56        if !has_valid_audience {
57            return Err(ClaimsError::InvalidAudience {
58                expected: config.allowed_audiences.clone(),
59                actual: claims.audiences.clone(),
60            });
61        }
62    }
63
64    // 3. Validate expiration with leeway
65    if let Some(exp) = claims.expires_at {
66        let now = OffsetDateTime::now_utc();
67        let leeway = time::Duration::seconds(config.leeway_seconds);
68
69        if now > exp + leeway {
70            return Err(ClaimsError::Expired);
71        }
72    }
73
74    // 4. Validate not-before with leeway
75    if let Some(nbf) = claims.not_before {
76        let now = OffsetDateTime::now_utc();
77        let leeway = time::Duration::seconds(config.leeway_seconds);
78
79        if now < nbf - leeway {
80            return Err(ClaimsError::NotYetValid);
81        }
82    }
83
84    // 5. Validate subject is UUID (already validated during normalization, but double-check)
85    if config.require_uuid_subject && claims.subject.is_nil() {
86        // subject is already a Uuid type, so this is guaranteed
87        // Just a safety check for future-proofing
88        return Err(ClaimsError::InvalidClaimFormat {
89            field: "subject".to_owned(),
90            reason: "subject cannot be nil UUID".to_owned(),
91        });
92    }
93
94    // 6. Validate tenant_id is UUID (already validated during normalization)
95    if config.require_uuid_tenants && claims.tenant_id.is_nil() {
96        return Err(ClaimsError::InvalidClaimFormat {
97            field: "tenant_id".to_owned(),
98            reason: "tenant ID cannot be nil UUID".to_owned(),
99        });
100    }
101
102    Ok(())
103}
104
105/// Helper to parse a UUID from a JSON value.
106///
107/// # Errors
108/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a valid UUID string.
109pub fn parse_uuid_from_value(
110    value: &serde_json::Value,
111    field_name: &str,
112) -> Result<Uuid, ClaimsError> {
113    value
114        .as_str()
115        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
116            field: field_name.to_owned(),
117            reason: "must be a string".to_owned(),
118        })
119        .and_then(|s| {
120            Uuid::parse_str(s).map_err(|_| ClaimsError::InvalidClaimFormat {
121                field: field_name.to_owned(),
122                reason: "must be a valid UUID".to_owned(),
123            })
124        })
125}
126
127/// Helper to parse an array of UUIDs from a JSON value.
128///
129/// # Errors
130/// Returns `ClaimsError::InvalidClaimFormat` if the value is not an array of valid UUID strings.
131pub fn parse_uuid_array_from_value(
132    value: &serde_json::Value,
133    field_name: &str,
134) -> Result<Vec<Uuid>, ClaimsError> {
135    value
136        .as_array()
137        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
138            field: field_name.to_owned(),
139            reason: "must be an array".to_owned(),
140        })?
141        .iter()
142        .map(|v| parse_uuid_from_value(v, field_name))
143        .collect()
144}
145
146/// Helper to parse timestamp (seconds since epoch) into `OffsetDateTime`.
147///
148/// # Errors
149/// Returns `ClaimsError::InvalidClaimFormat` if the value is not a valid unix timestamp.
150pub fn parse_timestamp(
151    value: &serde_json::Value,
152    field_name: &str,
153) -> Result<OffsetDateTime, ClaimsError> {
154    let ts = value
155        .as_i64()
156        .ok_or_else(|| ClaimsError::InvalidClaimFormat {
157            field: field_name.to_owned(),
158            reason: "must be a number (unix timestamp)".to_owned(),
159        })?;
160
161    OffsetDateTime::from_unix_timestamp(ts).map_err(|_| ClaimsError::InvalidClaimFormat {
162        field: field_name.to_owned(),
163        reason: "invalid unix timestamp".to_owned(),
164    })
165}
166
167/// Helper to extract string from JSON value.
168///
169/// # Errors
170/// Returns `ClaimsError::MissingClaim` if the value is not a string.
171pub fn extract_string(value: &serde_json::Value, field_name: &str) -> Result<String, ClaimsError> {
172    value
173        .as_str()
174        .map(ToString::to_string)
175        .ok_or_else(|| ClaimsError::MissingClaim(field_name.to_owned()))
176}
177
178/// Helper to extract string array from JSON value (handles both string and array)
179#[must_use]
180pub fn extract_audiences(value: &serde_json::Value) -> Vec<String> {
181    match value {
182        serde_json::Value::String(s) => vec![s.clone()],
183        serde_json::Value::Array(arr) => arr
184            .iter()
185            .filter_map(|v| v.as_str().map(ToString::to_string))
186            .collect(),
187        _ => vec![],
188    }
189}
190
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds claims that pass `validate_claims` under `ValidationConfig::default()`:
    /// - random non-nil subject and tenant;
    /// - an `api` audience;
    /// - an expiry one hour in the future.
    fn create_test_claims() -> Claims {
        Claims {
            issuer: "https://test.example.com".to_owned(),
            subject: Uuid::new_v4(),
            audiences: vec!["api".to_owned()],
            expires_at: Some(OffsetDateTime::now_utc() + time::Duration::hours(1)),
            not_before: None,
            issued_at: None,
            jwt_id: None,
            tenant_id: Uuid::new_v4(),
            permissions: vec![],
            extras: serde_json::Map::new(),
        }
    }

    #[test]
    fn test_valid_claims_pass() {
        let claims = create_test_claims();
        let config = ValidationConfig {
            allowed_issuers: vec!["https://test.example.com".to_owned()],
            allowed_audiences: vec!["api".to_owned()],
            ..Default::default()
        };

        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_empty_allow_lists_accept_any_issuer_and_audience() {
        let mut claims = create_test_claims();
        claims.issuer = "https://anything.example.com".to_owned();
        claims.audiences = vec![];

        assert!(validate_claims(&claims, &ValidationConfig::default()).is_ok());
    }

    #[test]
    fn test_invalid_issuer_fails() {
        let claims = create_test_claims();
        let config = ValidationConfig {
            allowed_issuers: vec!["https://other.example.com".to_owned()],
            allowed_audiences: vec![],
            ..Default::default()
        };

        let result = validate_claims(&claims, &config);
        assert!(matches!(result, Err(ClaimsError::InvalidIssuer { .. })));
    }

    #[test]
    fn test_invalid_audience_fails() {
        let claims = create_test_claims();
        let config = ValidationConfig {
            allowed_issuers: vec![],
            allowed_audiences: vec!["other-api".to_owned()],
            ..Default::default()
        };

        let result = validate_claims(&claims, &config);
        assert!(matches!(result, Err(ClaimsError::InvalidAudience { .. })));
    }

    #[test]
    fn test_expired_token_fails() {
        let mut claims = create_test_claims();
        claims.expires_at = Some(OffsetDateTime::now_utc() - time::Duration::hours(1));

        let config = ValidationConfig::default();
        let result = validate_claims(&claims, &config);
        assert!(matches!(result, Err(ClaimsError::Expired)));
    }

    #[test]
    fn test_not_yet_valid_fails() {
        let mut claims = create_test_claims();
        claims.not_before = Some(OffsetDateTime::now_utc() + time::Duration::hours(1));

        let config = ValidationConfig::default();
        let result = validate_claims(&claims, &config);
        assert!(matches!(result, Err(ClaimsError::NotYetValid)));
    }

    #[test]
    fn test_leeway_allows_expired() {
        let mut claims = create_test_claims();
        claims.expires_at = Some(OffsetDateTime::now_utc() - time::Duration::seconds(30));

        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };

        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_leeway_allows_not_before_slightly_in_future() {
        let mut claims = create_test_claims();
        claims.not_before = Some(OffsetDateTime::now_utc() + time::Duration::seconds(30));

        let config = ValidationConfig {
            leeway_seconds: 60,
            ..Default::default()
        };

        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_nil_subject_fails_when_required() {
        let mut claims = create_test_claims();
        claims.subject = Uuid::nil();

        let result = validate_claims(&claims, &ValidationConfig::default());
        assert!(matches!(
            &result,
            Err(ClaimsError::InvalidClaimFormat { field, .. }) if field == "subject"
        ));
    }

    #[test]
    fn test_nil_subject_allowed_when_not_required() {
        let mut claims = create_test_claims();
        claims.subject = Uuid::nil();

        let config = ValidationConfig {
            require_uuid_subject: false,
            ..Default::default()
        };

        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_nil_tenant_fails_when_required() {
        let mut claims = create_test_claims();
        claims.tenant_id = Uuid::nil();

        let result = validate_claims(&claims, &ValidationConfig::default());
        assert!(matches!(
            &result,
            Err(ClaimsError::InvalidClaimFormat { field, .. }) if field == "tenant_id"
        ));
    }

    #[test]
    fn test_nil_tenant_allowed_when_not_required() {
        let mut claims = create_test_claims();
        claims.tenant_id = Uuid::nil();

        let config = ValidationConfig {
            require_uuid_tenants: false,
            ..Default::default()
        };

        assert!(validate_claims(&claims, &config).is_ok());
    }

    #[test]
    fn test_parse_uuid_from_value() {
        let uuid = Uuid::new_v4();
        let value = json!(uuid.to_string());

        let result = parse_uuid_from_value(&value, "test");
        assert_eq!(result.unwrap(), uuid);
    }

    #[test]
    fn test_parse_uuid_from_value_invalid() {
        let value = json!("not-a-uuid");
        let result = parse_uuid_from_value(&value, "test");
        assert!(matches!(
            result,
            Err(ClaimsError::InvalidClaimFormat { .. })
        ));
    }

    #[test]
    fn test_parse_uuid_from_value_non_string() {
        let result = parse_uuid_from_value(&json!(42), "test");
        assert!(matches!(
            result,
            Err(ClaimsError::InvalidClaimFormat { .. })
        ));
    }

    #[test]
    fn test_parse_uuid_array_from_value() {
        let ids = vec![Uuid::new_v4(), Uuid::new_v4()];
        let value = json!(ids.iter().map(ToString::to_string).collect::<Vec<_>>());

        assert_eq!(parse_uuid_array_from_value(&value, "tenants").unwrap(), ids);
    }

    #[test]
    fn test_parse_uuid_array_from_value_rejects_bad_input() {
        // Not an array at all.
        assert!(matches!(
            parse_uuid_array_from_value(&json!("not-an-array"), "tenants"),
            Err(ClaimsError::InvalidClaimFormat { .. })
        ));

        // One malformed element poisons the whole array.
        let mixed = json!([Uuid::new_v4().to_string(), "not-a-uuid"]);
        assert!(matches!(
            parse_uuid_array_from_value(&mixed, "tenants"),
            Err(ClaimsError::InvalidClaimFormat { .. })
        ));
    }

    #[test]
    fn test_parse_timestamp_integer() {
        let parsed = parse_timestamp(&json!(1_700_000_000), "exp").unwrap();
        assert_eq!(parsed.unix_timestamp(), 1_700_000_000);
    }

    #[test]
    fn test_parse_timestamp_rejects_non_number() {
        let result = parse_timestamp(&json!("1700000000"), "exp");
        assert!(matches!(
            result,
            Err(ClaimsError::InvalidClaimFormat { .. })
        ));
    }

    #[test]
    fn test_extract_string() {
        assert_eq!(extract_string(&json!("abc"), "iss").unwrap(), "abc");
        assert!(matches!(
            extract_string(&json!(1), "iss"),
            Err(ClaimsError::MissingClaim(_))
        ));
    }

    #[test]
    fn test_extract_audiences_string() {
        let value = json!("api");
        let audiences = extract_audiences(&value);
        assert_eq!(audiences, vec!["api"]);
    }

    #[test]
    fn test_extract_audiences_array() {
        let value = json!(["api", "ui"]);
        let audiences = extract_audiences(&value);
        assert_eq!(audiences, vec!["api", "ui"]);
    }

    #[test]
    fn test_extract_audiences_skips_non_strings_and_other_types() {
        assert_eq!(
            extract_audiences(&json!(["api", 1, null, "ui"])),
            vec!["api", "ui"]
        );
        assert!(extract_audiences(&json!(null)).is_empty());
        assert!(extract_audiences(&json!({"aud": "api"})).is_empty());
    }
}