Skip to main content

forge_core/auth/claims.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use uuid::Uuid;
4
5/// JWT claims structure.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Claims {
8    /// Subject (user ID).
9    pub sub: String,
10    /// Issued at (Unix timestamp).
11    pub iat: i64,
12    /// Expiration time (Unix timestamp).
13    pub exp: i64,
14    /// User roles.
15    #[serde(default)]
16    pub roles: Vec<String>,
17    /// Custom claims.
18    #[serde(flatten)]
19    pub custom: HashMap<String, serde_json::Value>,
20}
21
22impl Claims {
23    /// Get the user ID as UUID.
24    pub fn user_id(&self) -> Option<Uuid> {
25        Uuid::parse_str(&self.sub).ok()
26    }
27
28    /// Check if the token is expired.
29    pub fn is_expired(&self) -> bool {
30        let now = chrono::Utc::now().timestamp();
31        self.exp < now
32    }
33
34    /// Check if the user has a role.
35    pub fn has_role(&self, role: &str) -> bool {
36        self.roles.iter().any(|r| r == role)
37    }
38
39    /// Reserved JWT claim names that should not be treated as custom claims.
40    const RESERVED_CLAIMS: &'static [&'static str] =
41        &["iss", "aud", "nbf", "jti", "sub", "iat", "exp", "roles"];
42
43    /// Get a custom claim value.
44    ///
45    /// Returns `None` for reserved JWT claims (iss, aud, nbf, jti, etc.)
46    /// to prevent claim injection via `#[serde(flatten)]`.
47    pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
48        if Self::RESERVED_CLAIMS.contains(&key) {
49            return None;
50        }
51        self.custom.get(key)
52    }
53
54    /// Get custom claims with reserved JWT claims filtered out.
55    ///
56    /// Prevents claim injection where standard JWT claims like `iss`, `aud`,
57    /// or `jti` end up in the custom claims map via `#[serde(flatten)]`.
58    pub fn sanitized_custom(&self) -> HashMap<String, serde_json::Value> {
59        self.custom
60            .iter()
61            .filter(|(k, _)| !Self::RESERVED_CLAIMS.contains(&k.as_str()))
62            .map(|(k, v)| (k.clone(), v.clone()))
63            .collect()
64    }
65
66    /// Get the tenant ID if present in claims.
67    pub fn tenant_id(&self) -> Option<Uuid> {
68        self.custom
69            .get("tenant_id")
70            .and_then(|v| v.as_str())
71            .and_then(|s| Uuid::parse_str(s).ok())
72    }
73
74    /// Create a builder for constructing claims.
75    pub fn builder() -> ClaimsBuilder {
76        ClaimsBuilder::new()
77    }
78}
79
80/// Builder for JWT claims.
81#[derive(Debug, Default)]
82pub struct ClaimsBuilder {
83    sub: Option<String>,
84    roles: Vec<String>,
85    custom: HashMap<String, serde_json::Value>,
86    duration_secs: i64,
87}
88
89impl ClaimsBuilder {
90    /// Create a new builder.
91    pub fn new() -> Self {
92        Self {
93            sub: None,
94            roles: Vec::new(),
95            custom: HashMap::new(),
96            duration_secs: 3600, // 1 hour default
97        }
98    }
99
100    /// Set the subject (user ID).
101    pub fn subject(mut self, sub: impl Into<String>) -> Self {
102        self.sub = Some(sub.into());
103        self
104    }
105
106    /// Set the user ID from UUID.
107    pub fn user_id(mut self, id: Uuid) -> Self {
108        self.sub = Some(id.to_string());
109        self
110    }
111
112    /// Add a role.
113    pub fn role(mut self, role: impl Into<String>) -> Self {
114        self.roles.push(role.into());
115        self
116    }
117
118    /// Set multiple roles.
119    pub fn roles(mut self, roles: Vec<String>) -> Self {
120        self.roles = roles;
121        self
122    }
123
124    /// Add a custom claim.
125    pub fn claim(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
126        self.custom.insert(key.into(), value);
127        self
128    }
129
130    /// Set the tenant ID.
131    pub fn tenant_id(mut self, id: Uuid) -> Self {
132        self.custom
133            .insert("tenant_id".to_string(), serde_json::json!(id.to_string()));
134        self
135    }
136
137    /// Set token duration in seconds.
138    pub fn duration_secs(mut self, secs: i64) -> Self {
139        self.duration_secs = secs;
140        self
141    }
142
143    /// Build the claims.
144    pub fn build(self) -> Result<Claims, String> {
145        let sub = self.sub.ok_or("Subject is required")?;
146        let now = chrono::Utc::now().timestamp();
147
148        Ok(Claims {
149            sub,
150            iat: now,
151            exp: now + self.duration_secs,
152            roles: self.roles,
153            custom: self.custom,
154        })
155    }
156}
157
#[cfg(test)]
// `unwrap` and indexing are acceptable here: a panic is how a test reports failure.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_claims_builder() {
        let user_id = Uuid::new_v4();
        let claims = Claims::builder()
            .user_id(user_id)
            .role("admin")
            .role("user")
            .claim("org_id", serde_json::json!("org-123"))
            .duration_secs(7200)
            .build()
            .unwrap();

        assert_eq!(claims.user_id(), Some(user_id));
        assert!(claims.has_role("admin"));
        assert!(claims.has_role("user"));
        assert!(!claims.has_role("superadmin"));
        assert_eq!(
            claims.get_claim("org_id"),
            Some(&serde_json::json!("org-123"))
        );
        assert!(!claims.is_expired());
    }

    #[test]
    fn test_claims_expiration() {
        let claims = Claims {
            sub: "user-1".to_string(),
            iat: 0,
            exp: 1, // Expired timestamp
            roles: vec![],
            custom: HashMap::new(),
        };

        assert!(claims.is_expired());
    }

    #[test]
    fn test_claims_serialization() {
        let claims = Claims::builder()
            .subject("user-1")
            .role("admin")
            .claim("org_id", serde_json::json!("org-123"))
            .build()
            .unwrap();

        let json = serde_json::to_string(&claims).unwrap();
        let deserialized: Claims = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.sub, claims.sub);
        assert_eq!(deserialized.iat, claims.iat);
        assert_eq!(deserialized.exp, claims.exp);
        assert_eq!(deserialized.roles, claims.roles);
        assert_eq!(deserialized.custom, claims.custom);
    }

    #[test]
    fn test_reserved_claims_hidden_from_custom_accessors() {
        let mut custom = HashMap::new();
        custom.insert("iss".to_string(), serde_json::json!("https://evil.example"));
        custom.insert("jti".to_string(), serde_json::json!("abc"));
        custom.insert("org_id".to_string(), serde_json::json!("org-123"));
        let claims = Claims {
            sub: "user-1".to_string(),
            iat: 0,
            exp: i64::MAX,
            roles: vec![],
            custom,
        };

        assert_eq!(claims.get_claim("iss"), None);
        assert_eq!(claims.get_claim("jti"), None);
        assert_eq!(
            claims.get_claim("org_id"),
            Some(&serde_json::json!("org-123"))
        );

        let sanitized = claims.sanitized_custom();
        assert_eq!(sanitized.len(), 1);
        assert_eq!(
            sanitized.get("org_id"),
            Some(&serde_json::json!("org-123"))
        );
    }

    #[test]
    fn test_tenant_id_round_trip() {
        let tenant = Uuid::new_v4();
        let claims = Claims::builder()
            .subject("user-1")
            .tenant_id(tenant)
            .build()
            .unwrap();

        assert_eq!(claims.tenant_id(), Some(tenant));
    }

    #[test]
    fn test_deserialize_collects_unknown_members_into_custom() {
        let json = r#"{"sub":"user-1","iat":0,"exp":1,"iss":"issuer","org_id":"org-123"}"#;
        let claims: Claims = serde_json::from_str(json).unwrap();

        assert!(claims.roles.is_empty());
        assert_eq!(claims.custom.get("iss"), Some(&serde_json::json!("issuer")));
        assert_eq!(claims.get_claim("iss"), None);
        assert_eq!(
            claims.get_claim("org_id"),
            Some(&serde_json::json!("org-123"))
        );
    }
}