// forge_core/auth/claims.rs
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Claims {
8 pub sub: String,
10 pub iat: i64,
12 pub exp: i64,
14 #[serde(default)]
16 pub roles: Vec<String>,
17 #[serde(flatten)]
19 pub custom: HashMap<String, serde_json::Value>,
20}
21
22impl Claims {
23 pub fn user_id(&self) -> Option<Uuid> {
25 Uuid::parse_str(&self.sub).ok()
26 }
27
28 pub fn is_expired(&self) -> bool {
30 let now = chrono::Utc::now().timestamp();
31 self.exp < now
32 }
33
34 pub fn has_role(&self, role: &str) -> bool {
36 self.roles.iter().any(|r| r == role)
37 }
38
39 const RESERVED_CLAIMS: &'static [&'static str] =
41 &["iss", "aud", "nbf", "jti", "sub", "iat", "exp", "roles"];
42
43 pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
48 if Self::RESERVED_CLAIMS.contains(&key) {
49 return None;
50 }
51 self.custom.get(key)
52 }
53
54 pub fn sanitized_custom(&self) -> HashMap<String, serde_json::Value> {
59 self.custom
60 .iter()
61 .filter(|(k, _)| !Self::RESERVED_CLAIMS.contains(&k.as_str()))
62 .map(|(k, v)| (k.clone(), v.clone()))
63 .collect()
64 }
65
66 pub fn tenant_id(&self) -> Option<Uuid> {
68 self.custom
69 .get("tenant_id")
70 .and_then(|v| v.as_str())
71 .and_then(|s| Uuid::parse_str(s).ok())
72 }
73
74 pub fn builder() -> ClaimsBuilder {
76 ClaimsBuilder::new()
77 }
78}
79
80#[derive(Debug, Default)]
82pub struct ClaimsBuilder {
83 sub: Option<String>,
84 roles: Vec<String>,
85 custom: HashMap<String, serde_json::Value>,
86 duration_secs: i64,
87}
88
89impl ClaimsBuilder {
90 pub fn new() -> Self {
92 Self {
93 sub: None,
94 roles: Vec::new(),
95 custom: HashMap::new(),
96 duration_secs: 3600, }
98 }
99
100 pub fn subject(mut self, sub: impl Into<String>) -> Self {
102 self.sub = Some(sub.into());
103 self
104 }
105
106 pub fn user_id(mut self, id: Uuid) -> Self {
108 self.sub = Some(id.to_string());
109 self
110 }
111
112 pub fn role(mut self, role: impl Into<String>) -> Self {
114 self.roles.push(role.into());
115 self
116 }
117
118 pub fn roles(mut self, roles: Vec<String>) -> Self {
120 self.roles = roles;
121 self
122 }
123
124 pub fn claim(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
126 self.custom.insert(key.into(), value);
127 self
128 }
129
130 pub fn tenant_id(mut self, id: Uuid) -> Self {
132 self.custom
133 .insert("tenant_id".to_string(), serde_json::json!(id.to_string()));
134 self
135 }
136
137 pub fn duration_secs(mut self, secs: i64) -> Self {
139 self.duration_secs = secs;
140 self
141 }
142
143 pub fn build(self) -> Result<Claims, String> {
145 let sub = self.sub.ok_or("Subject is required")?;
146 let now = chrono::Utc::now().timestamp();
147
148 Ok(Claims {
149 sub,
150 iat: now,
151 exp: now + self.duration_secs,
152 roles: self.roles,
153 custom: self.custom,
154 })
155 }
156}
157
#[cfg(test)]
// Tests fail loudly by design, so `unwrap` is acceptable here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// The builder exposes the subject, roles, custom claims and a future expiry.
    #[test]
    fn test_claims_builder() {
        let id = Uuid::new_v4();
        let built = Claims::builder()
            .user_id(id)
            .role("admin")
            .role("user")
            .claim("org_id", serde_json::json!("org-123"))
            .duration_secs(7200)
            .build()
            .unwrap();

        assert_eq!(built.user_id(), Some(id));
        for granted in ["admin", "user"] {
            assert!(built.has_role(granted));
        }
        assert!(!built.has_role("superadmin"));

        let expected_org = serde_json::json!("org-123");
        assert_eq!(built.get_claim("org_id"), Some(&expected_org));
        assert!(!built.is_expired());
    }

    /// Claims whose `exp` lies far in the past report themselves as expired.
    #[test]
    fn test_claims_expiration() {
        let stale = Claims {
            sub: String::from("user-1"),
            iat: 0,
            exp: 1,
            roles: Vec::new(),
            custom: HashMap::new(),
        };

        assert!(stale.is_expired());
    }

    /// A JSON round-trip preserves the subject and roles.
    #[test]
    fn test_claims_serialization() {
        let original = Claims::builder()
            .subject("user-1")
            .role("admin")
            .build()
            .unwrap();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<Claims>(&encoded).unwrap();

        assert_eq!(decoded.sub, original.sub);
        assert_eq!(decoded.roles, original.roles);
    }
}