Skip to main content

forge_core/auth/
claims.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use uuid::Uuid;
4
5/// JWT claims structure.
6///
7/// Fields are intentionally crate-private. Construct via [`Claims::builder`]
8/// and read via the accessor methods. The `custom` map is gated by
9/// [`Claims::get_claim`] / [`Claims::sanitized_custom`] so reserved JWT
10/// claim names (`iss`, `aud`, `nbf`, `jti`, …) can never be retrieved as
11/// custom data, even when serde's `#[serde(flatten)]` lets them in.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[non_exhaustive]
14pub struct Claims {
15    pub(crate) sub: String,
16    pub(crate) iat: i64,
17    pub(crate) exp: i64,
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub(crate) aud: Option<String>,
20    #[serde(default)]
21    pub(crate) roles: Vec<String>,
22    /// Reserved JWT claims are filtered out on read; use [`Claims::get_claim`] / [`Claims::sanitized_custom`].
23    #[serde(flatten)]
24    pub(crate) custom: HashMap<String, serde_json::Value>,
25}
26
27impl Claims {
28    /// Get the subject (raw `sub` claim).
29    pub fn sub(&self) -> &str {
30        &self.sub
31    }
32
33    /// Get the issued-at Unix timestamp.
34    pub fn iat(&self) -> i64 {
35        self.iat
36    }
37
38    /// Get the expiration Unix timestamp.
39    pub fn exp(&self) -> i64 {
40        self.exp
41    }
42
43    /// Get the audience (`aud` claim), if set.
44    pub fn audience(&self) -> Option<&str> {
45        self.aud.as_deref()
46    }
47
48    /// Get the user roles.
49    pub fn roles(&self) -> &[String] {
50        &self.roles
51    }
52
53    /// Consume the claims and return the owned roles vector.
54    pub fn into_roles(self) -> Vec<String> {
55        self.roles
56    }
57
58    /// Consume the claims and return the owned subject string.
59    pub fn into_sub(self) -> String {
60        self.sub
61    }
62
63    /// Get the user ID as UUID.
64    pub fn user_id(&self) -> Option<Uuid> {
65        Uuid::parse_str(&self.sub).ok()
66    }
67
68    /// Check if the token is expired.
69    pub fn is_expired(&self) -> bool {
70        let now = chrono::Utc::now().timestamp();
71        self.exp < now
72    }
73
74    /// Check if the user has a role.
75    pub fn has_role(&self, role: &str) -> bool {
76        self.roles.iter().any(|r| r == role)
77    }
78
79    /// Reserved JWT claim names that should not be treated as custom claims.
80    const RESERVED_CLAIMS: &'static [&'static str] =
81        &["iss", "aud", "nbf", "jti", "sub", "iat", "exp", "roles"];
82
83    /// Get a custom claim value.
84    ///
85    /// Returns `None` for reserved JWT claims (iss, aud, nbf, jti, etc.)
86    /// to prevent claim injection via `#[serde(flatten)]`.
87    pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
88        if Self::RESERVED_CLAIMS.contains(&key) {
89            return None;
90        }
91        self.custom.get(key)
92    }
93
94    /// Get custom claims with reserved JWT claims filtered out.
95    ///
96    /// Prevents claim injection where standard JWT claims like `iss`, `aud`,
97    /// or `jti` end up in the custom claims map via `#[serde(flatten)]`.
98    pub fn sanitized_custom(&self) -> HashMap<String, serde_json::Value> {
99        self.custom
100            .iter()
101            .filter(|(k, _)| !Self::RESERVED_CLAIMS.contains(&k.as_str()))
102            .map(|(k, v)| (k.clone(), v.clone()))
103            .collect()
104    }
105
106    /// Get the tenant ID if present in claims.
107    pub fn tenant_id(&self) -> Option<Uuid> {
108        self.custom
109            .get("tenant_id")
110            .and_then(|v| v.as_str())
111            .and_then(|s| Uuid::parse_str(s).ok())
112    }
113
114    /// Create a builder for constructing claims.
115    pub fn builder() -> ClaimsBuilder {
116        ClaimsBuilder::new()
117    }
118}
119
120/// Builder for JWT claims.
121#[derive(Debug, Default)]
122pub struct ClaimsBuilder {
123    sub: Option<String>,
124    aud: Option<String>,
125    roles: Vec<String>,
126    custom: HashMap<String, serde_json::Value>,
127    duration_secs: i64,
128}
129
130impl ClaimsBuilder {
131    /// Create a new builder.
132    pub fn new() -> Self {
133        Self {
134            sub: None,
135            aud: None,
136            roles: Vec::new(),
137            custom: HashMap::new(),
138            duration_secs: 3600,
139        }
140    }
141
142    /// Set the subject (user ID).
143    pub fn subject(mut self, sub: impl Into<String>) -> Self {
144        self.sub = Some(sub.into());
145        self
146    }
147
148    /// Set the user ID from UUID.
149    pub fn user_id(mut self, id: Uuid) -> Self {
150        self.sub = Some(id.to_string());
151        self
152    }
153
154    /// Add a role.
155    pub fn role(mut self, role: impl Into<String>) -> Self {
156        self.roles.push(role.into());
157        self
158    }
159
160    /// Set multiple roles.
161    pub fn roles(mut self, roles: Vec<String>) -> Self {
162        self.roles = roles;
163        self
164    }
165
166    /// Add a custom claim.
167    ///
168    /// Rejects reserved JWT claim names to prevent duplicate-keyed tokens where
169    /// structural fields (`sub`, `exp`, …) and a flattened custom key both serialize
170    /// under the same JSON key — some validators read one, `ctx.claim()` reads the other.
171    ///
172    /// Use the typed setters instead:
173    /// - `sub` / `iat` / `exp` → `.subject()` / `.user_id()` / `.duration_secs()`
174    /// - `roles` → `.role()` / `.roles()`
175    /// - `aud` → `.audience()`
176    /// - `nbf`, `jti`, `iss` are not supported by this builder
177    pub fn claim(
178        mut self,
179        key: impl Into<String>,
180        value: serde_json::Value,
181    ) -> crate::Result<Self> {
182        let key = key.into();
183        if Claims::RESERVED_CLAIMS.contains(&key.as_str()) {
184            return Err(crate::ForgeError::InvalidArgument(format!(
185                "'{key}' is a reserved JWT claim name; use the typed setter instead"
186            )));
187        }
188        self.custom.insert(key, value);
189        Ok(self)
190    }
191
192    /// Set the token audience (`aud` claim).
193    pub fn audience(mut self, aud: impl Into<String>) -> Self {
194        self.aud = Some(aud.into());
195        self
196    }
197
198    /// Set the tenant ID.
199    pub fn tenant_id(mut self, id: Uuid) -> Self {
200        self.custom
201            .insert("tenant_id".to_string(), serde_json::json!(id.to_string()));
202        self
203    }
204
205    /// Set token duration in seconds.
206    pub fn duration_secs(mut self, secs: i64) -> Self {
207        self.duration_secs = secs;
208        self
209    }
210
211    /// Build the claims.
212    pub fn build(self) -> Result<Claims, String> {
213        let sub = self.sub.ok_or("Subject is required")?;
214        let now = chrono::Utc::now().timestamp();
215
216        Ok(Claims {
217            sub,
218            iat: now,
219            exp: now + self.duration_secs,
220            aud: self.aud,
221            roles: self.roles,
222            custom: self.custom,
223        })
224    }
225}
226
// Unit tests for `Claims` and `ClaimsBuilder`.
#[cfg(test)]
// Tests may panic on failure by design, so the unwrap/indexing restriction
// lints are relaxed for this module.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    // End-to-end builder usage: subject from UUID, roles, a custom claim and a
    // custom lifetime all survive `build()`.
    #[test]
    fn test_claims_builder() {
        let user_id = Uuid::new_v4();
        let claims = Claims::builder()
            .user_id(user_id)
            .role("admin")
            .role("user")
            .claim("org_id", serde_json::json!("org-123"))
            .unwrap()
            .duration_secs(7200)
            .build()
            .unwrap();

        assert_eq!(claims.user_id(), Some(user_id));
        assert!(claims.has_role("admin"));
        assert!(claims.has_role("user"));
        assert!(!claims.has_role("superadmin"));
        assert_eq!(
            claims.get_claim("org_id"),
            Some(&serde_json::json!("org-123"))
        );
        assert!(!claims.is_expired());
    }

    // The builder's `claim()` must refuse every reserved name.
    #[test]
    fn claim_rejects_reserved_names() {
        for reserved in Claims::RESERVED_CLAIMS {
            let result = Claims::builder()
                .subject("user-1")
                .claim(*reserved, serde_json::json!("value"));
            assert!(
                result.is_err(),
                "Expected '{reserved}' to be rejected but it was accepted"
            );
        }
    }

    // Non-reserved names are accepted.
    #[test]
    fn claim_accepts_custom_names() {
        let result = Claims::builder()
            .subject("user-1")
            .claim("org_id", serde_json::json!("org-123"));
        assert!(result.is_ok());
    }

    // An `exp` far in the past (Unix second 1) is expired.
    #[test]
    fn test_claims_expiration() {
        let claims = Claims {
            sub: "user-1".to_string(),
            iat: 0,
            exp: 1,
            aud: None,
            roles: vec![],
            custom: HashMap::new(),
        };

        assert!(claims.is_expired());
    }

    // JSON round trip preserves subject and roles.
    #[test]
    fn test_claims_serialization() {
        let claims = Claims::builder()
            .subject("user-1")
            .role("admin")
            .build()
            .unwrap();

        let json = serde_json::to_string(&claims).unwrap();
        let deserialized: Claims = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.sub, claims.sub);
        assert_eq!(deserialized.roles, claims.roles);
    }

    // `build()` without a subject fails with the documented message.
    #[test]
    fn build_errors_when_subject_missing() {
        let result = Claims::builder().role("user").build();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Subject is required"));
    }

    // `exp - iat` equals the configured duration (both derive from one `now`).
    #[test]
    fn duration_secs_sets_exp_offset_from_iat() {
        let claims = Claims::builder()
            .subject("u")
            .duration_secs(120)
            .build()
            .unwrap();
        assert_eq!(claims.exp() - claims.iat(), 120);
    }

    // Builders from `Claims::builder()` default to a one-hour lifetime.
    #[test]
    fn default_duration_secs_is_one_hour() {
        let claims = Claims::builder().subject("u").build().unwrap();
        assert_eq!(claims.exp() - claims.iat(), 3600);
    }

    // An `exp` an hour ahead is not expired.
    #[test]
    fn is_expired_false_for_future_exp() {
        let now = chrono::Utc::now().timestamp();
        let claims = Claims {
            sub: "u".into(),
            iat: now,
            exp: now + 3600,
            aud: None,
            roles: vec![],
            custom: HashMap::new(),
        };
        assert!(!claims.is_expired());
    }

    // Non-UUID subjects are kept verbatim but yield no `user_id`.
    #[test]
    fn user_id_returns_none_for_non_uuid_subject() {
        let claims = Claims::builder().subject("not-a-uuid").build().unwrap();
        assert!(claims.user_id().is_none());
        assert_eq!(claims.sub(), "not-a-uuid");
    }

    // `user_id()` on the builder stores the UUID string in `sub`.
    #[test]
    fn user_id_set_via_builder_round_trips_through_sub() {
        let id = Uuid::new_v4();
        let claims = Claims::builder().user_id(id).build().unwrap();
        assert_eq!(claims.user_id(), Some(id));
        assert_eq!(claims.sub(), id.to_string());
    }

    // The consuming `into_*` accessors hand back the owned fields.
    #[test]
    fn into_methods_consume_owned_values() {
        let claims = Claims::builder()
            .subject("user-x")
            .role("a")
            .role("b")
            .build()
            .unwrap();
        let roles = claims.clone().into_roles();
        assert_eq!(roles, vec!["a".to_string(), "b".to_string()]);
        let sub = claims.into_sub();
        assert_eq!(sub, "user-x");
    }

    // `roles(..)` replaces, rather than extends, earlier `role(..)` calls.
    #[test]
    fn roles_setter_replaces_prior_calls() {
        let claims = Claims::builder()
            .subject("u")
            .role("first")
            .roles(vec!["one".into(), "two".into()])
            .build()
            .unwrap();
        assert_eq!(claims.roles(), &["one".to_string(), "two".to_string()]);
    }

    #[test]
    fn get_claim_returns_none_for_reserved_names_even_if_present() {
        // Reserved names can leak in via deserialization (#[serde(flatten)] +
        // duplicate keys). Construct directly to bypass the builder guard.
        let mut custom = HashMap::new();
        custom.insert("iss".to_string(), serde_json::json!("evil"));
        custom.insert("jti".to_string(), serde_json::json!("evil"));
        custom.insert("safe".to_string(), serde_json::json!("ok"));
        let claims = Claims {
            sub: "u".into(),
            iat: 0,
            exp: i64::MAX,
            aud: None,
            roles: vec![],
            custom,
        };
        assert!(claims.get_claim("iss").is_none());
        assert!(claims.get_claim("jti").is_none());
        assert_eq!(claims.get_claim("safe"), Some(&serde_json::json!("ok")));
    }

    // Absent non-reserved keys also yield `None`.
    #[test]
    fn get_claim_returns_none_for_missing_custom_key() {
        let claims = Claims::builder().subject("u").build().unwrap();
        assert!(claims.get_claim("nope").is_none());
    }

    // Every reserved name smuggled into `custom` is stripped; only `org_id` remains.
    #[test]
    fn sanitized_custom_filters_reserved_names() {
        let mut custom = HashMap::new();
        for reserved in Claims::RESERVED_CLAIMS {
            custom.insert((*reserved).to_string(), serde_json::json!("smuggled"));
        }
        custom.insert("org_id".into(), serde_json::json!("o1"));
        let claims = Claims {
            sub: "u".into(),
            iat: 0,
            exp: i64::MAX,
            aud: None,
            roles: vec![],
            custom,
        };
        let safe = claims.sanitized_custom();
        assert_eq!(safe.len(), 1);
        assert_eq!(safe.get("org_id"), Some(&serde_json::json!("o1")));
        for reserved in Claims::RESERVED_CLAIMS {
            assert!(
                !safe.contains_key(*reserved),
                "{reserved} should be filtered out"
            );
        }
    }

    // A tenant set on the builder is readable back as a UUID.
    #[test]
    fn tenant_id_round_trips_via_builder() {
        let tenant = Uuid::new_v4();
        let claims = Claims::builder()
            .subject("u")
            .tenant_id(tenant)
            .build()
            .unwrap();
        assert_eq!(claims.tenant_id(), Some(tenant));
    }

    // Malformed `tenant_id` values (non-string, non-UUID string) yield `None`.
    #[test]
    fn tenant_id_returns_none_when_value_is_not_string_or_uuid() {
        // Not a string: numeric.
        let mut custom = HashMap::new();
        custom.insert("tenant_id".to_string(), serde_json::json!(42));
        let claims = Claims {
            sub: "u".into(),
            iat: 0,
            exp: i64::MAX,
            aud: None,
            roles: vec![],
            custom,
        };
        assert!(claims.tenant_id().is_none());

        // String but not UUID.
        let mut custom = HashMap::new();
        custom.insert("tenant_id".to_string(), serde_json::json!("garbage"));
        let claims = Claims {
            sub: "u".into(),
            iat: 0,
            exp: i64::MAX,
            aud: None,
            roles: vec![],
            custom,
        };
        assert!(claims.tenant_id().is_none());
    }

    // `audience()` populates the typed `aud` field, not the custom map.
    #[test]
    fn audience_round_trips_through_typed_field() {
        let claims = Claims::builder()
            .subject("u")
            .audience("my-service")
            .build()
            .unwrap();
        assert_eq!(claims.audience(), Some("my-service"));
        // Serializes into the JWT as "aud"
        let json = serde_json::to_value(&claims).unwrap();
        assert_eq!(json.get("aud"), Some(&serde_json::json!("my-service")));
        // Does not leak into the custom map
        assert!(!claims.custom.contains_key("aud"));
    }

    // `aud` survives a serialize → deserialize round trip.
    #[test]
    fn audience_deserializes_from_jwt() {
        let claims = Claims::builder()
            .subject("u")
            .audience("svc-1")
            .build()
            .unwrap();
        let json = serde_json::to_string(&claims).unwrap();
        let restored: Claims = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.audience(), Some("svc-1"));
    }

    #[test]
    fn reserved_claims_set_matches_documented_list() {
        // Lock the reserved list so future additions are intentional code review.
        let expected: std::collections::HashSet<&str> =
            ["iss", "aud", "nbf", "jti", "sub", "iat", "exp", "roles"]
                .into_iter()
                .collect();
        let actual: std::collections::HashSet<&str> =
            Claims::RESERVED_CLAIMS.iter().copied().collect();
        assert_eq!(actual, expected);
    }
}