use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Claim set whose `sub`, `iat` and `exp` entries (names matching the RFC 7519
/// registered claims) are typed fields, plus an optional `roles` list and every other
/// top-level key collected into `custom`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject claim; [`Claims::user_id`] attempts to parse it as a UUID.
    pub sub: String,
    /// Issued-at time. `ClaimsBuilder::build` fills it from
    /// `chrono::Utc::now().timestamp()`, i.e. seconds since the Unix epoch.
    pub iat: i64,
    /// Expiration time, in the same seconds-since-epoch units as `iat`;
    /// [`Claims::is_expired`] compares it against the current timestamp.
    pub exp: i64,
    /// Role names; an absent `roles` key deserializes to an empty list.
    #[serde(default)]
    pub roles: Vec<String>,
    /// All remaining top-level keys. Because of `#[serde(flatten)]` they are written
    /// next to the fields above rather than nested under a `custom` key, so an entry
    /// named `sub`, `iat`, `exp` or `roles` would be emitted as a duplicate key.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
impl Claims {
    /// Parses the `sub` claim as a UUID.
    ///
    /// Returns `None` when `sub` is not a valid UUID string.
    pub fn user_id(&self) -> Option<Uuid> {
        Uuid::parse_str(&self.sub).ok()
    }

    /// Reports whether these claims have expired according to the system clock.
    ///
    /// RFC 7519 §4.1.4 says a token must not be accepted *on or after* its `exp`
    /// time, so claims whose `exp` equals the current second are already expired.
    /// No clock-skew leeway is applied.
    pub fn is_expired(&self) -> bool {
        let now = chrono::Utc::now().timestamp();
        now >= self.exp
    }

    /// Returns `true` if `role` is present in `roles` (exact, case-sensitive match).
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|r| r == role)
    }

    /// Claim names the custom-claim accessors never return: the RFC 7519 registered
    /// names (`iss`, `sub`, `aud`, `exp`, `nbf`, `iat`, `jti`) plus `roles`.
    const RESERVED_CLAIMS: &'static [&'static str] =
        &["iss", "aud", "nbf", "jti", "sub", "iat", "exp", "roles"];

    /// Looks up a custom claim by key.
    ///
    /// Returns `None` for any name in `RESERVED_CLAIMS`, even when such a key was
    /// captured into `custom` during deserialization (e.g. `iss`).
    pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
        if Self::RESERVED_CLAIMS.contains(&key) {
            return None;
        }
        self.custom.get(key)
    }

    /// Returns an owned copy of `custom` with every reserved claim name removed.
    pub fn sanitized_custom(&self) -> HashMap<String, serde_json::Value> {
        self.custom
            .iter()
            .filter(|(k, _)| !Self::RESERVED_CLAIMS.contains(&k.as_str()))
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect()
    }

    /// Reads the custom `tenant_id` claim and parses it as a UUID.
    ///
    /// Returns `None` if the claim is missing, is not a JSON string, or is not a
    /// valid UUID.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.custom
            .get("tenant_id")
            .and_then(|v| v.as_str())
            .and_then(|s| Uuid::parse_str(s).ok())
    }

    /// Starts a new [`ClaimsBuilder`].
    pub fn builder() -> ClaimsBuilder {
        ClaimsBuilder::new()
    }
}
/// Fluent builder for [`Claims`]; obtain one with [`Claims::builder`],
/// [`ClaimsBuilder::new`] or `Default::default()`.
///
/// All three start from the same state: no subject, no roles, no custom claims and a
/// lifetime of [`ClaimsBuilder::DEFAULT_DURATION_SECS`] seconds.
#[derive(Debug)]
#[must_use = "a ClaimsBuilder does nothing until `build` is called"]
pub struct ClaimsBuilder {
    sub: Option<String>,
    roles: Vec<String>,
    custom: HashMap<String, serde_json::Value>,
    /// Lifetime in seconds, added to the issue time to compute `exp`.
    duration_secs: i64,
}

impl Default for ClaimsBuilder {
    // Hand-written rather than derived: a derived impl would start with
    // `duration_secs == 0`, disagreeing with `new()` and producing claims that
    // expire the moment they are issued.
    fn default() -> Self {
        Self {
            sub: None,
            roles: Vec::new(),
            custom: HashMap::new(),
            duration_secs: Self::DEFAULT_DURATION_SECS,
        }
    }
}

impl ClaimsBuilder {
    /// Lifetime used when [`ClaimsBuilder::duration_secs`] is never called (one hour).
    pub const DEFAULT_DURATION_SECS: i64 = 3600;

    /// Keys that [`Claims`] serializes as its own fields. `Claims::custom` is
    /// `#[serde(flatten)]`-ed, so a custom claim with one of these names would be
    /// emitted as a duplicate JSON key and fail to deserialize; `build` rejects them.
    const FIELD_CLAIMS: &'static [&'static str] = &["sub", "iat", "exp", "roles"];

    /// Creates a builder with the default state (see the type-level docs).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `sub` claim to an arbitrary string.
    pub fn subject(mut self, sub: impl Into<String>) -> Self {
        self.sub = Some(sub.into());
        self
    }

    /// Sets the `sub` claim to the string form of `id`.
    pub fn user_id(mut self, id: Uuid) -> Self {
        self.sub = Some(id.to_string());
        self
    }

    /// Appends one role to the role list.
    pub fn role(mut self, role: impl Into<String>) -> Self {
        self.roles.push(role.into());
        self
    }

    /// Replaces the whole role list.
    pub fn roles(mut self, roles: Vec<String>) -> Self {
        self.roles = roles;
        self
    }

    /// Adds (or overwrites) a custom claim.
    ///
    /// The keys `sub`, `iat`, `exp` and `roles` are reserved for the typed fields of
    /// [`Claims`]; using one here makes [`ClaimsBuilder::build`] return an error.
    pub fn claim(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.custom.insert(key.into(), value);
        self
    }

    /// Stores `id` as the string-valued custom claim `tenant_id`.
    pub fn tenant_id(mut self, id: Uuid) -> Self {
        self.custom
            .insert("tenant_id".to_string(), serde_json::json!(id.to_string()));
        self
    }

    /// Sets the lifetime in seconds. Negative values yield already-expired claims.
    pub fn duration_secs(mut self, secs: i64) -> Self {
        self.duration_secs = secs;
        self
    }

    /// Consumes the builder and produces [`Claims`] issued at the current time.
    ///
    /// `iat` is `chrono::Utc::now().timestamp()` and `exp` is `iat + duration_secs`,
    /// saturating at the `i64` bounds instead of overflowing.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no subject was set, or if a custom claim uses one of the
    /// field names `sub`, `iat`, `exp` or `roles`.
    pub fn build(self) -> Result<Claims, String> {
        // Destructure up front so the field checks below borrow locals, not a
        // partially moved `self`.
        let Self {
            sub,
            roles,
            custom,
            duration_secs,
        } = self;
        let sub = sub.ok_or("Subject is required")?;
        if let Some(key) = Self::FIELD_CLAIMS
            .iter()
            .find(|key| custom.contains_key(**key))
        {
            return Err(format!(
                "Custom claim `{key}` conflicts with a built-in claim field"
            ));
        }
        let now = chrono::Utc::now().timestamp();
        Ok(Claims {
            sub,
            iat: now,
            exp: now.saturating_add(duration_secs),
            roles,
            custom,
        })
    }
}
#[cfg(test)]
// Tests may unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every builder setter should surface through the matching `Claims` accessor.
    #[test]
    fn test_claims_builder() {
        let id = Uuid::new_v4();
        let built = Claims::builder()
            .user_id(id)
            .role("admin")
            .role("user")
            .claim("org_id", json!("org-123"))
            .duration_secs(7200)
            .build()
            .unwrap();

        assert_eq!(built.user_id(), Some(id));
        for granted in ["admin", "user"] {
            assert!(built.has_role(granted));
        }
        assert!(!built.has_role("superadmin"));

        let expected_org = json!("org-123");
        assert_eq!(built.get_claim("org_id"), Some(&expected_org));

        // A two-hour lifetime cannot have elapsed during the test.
        assert!(!built.is_expired());
    }

    /// Claims whose `exp` lies in the distant past must report as expired.
    #[test]
    fn test_claims_expiration() {
        // exp = 1 is one second after the Unix epoch.
        let stale = Claims {
            sub: String::from("user-1"),
            iat: 0,
            exp: 1,
            roles: Vec::new(),
            custom: HashMap::new(),
        };
        assert!(stale.is_expired());
    }

    /// A JSON round trip must preserve the subject and the roles.
    #[test]
    fn test_claims_serialization() {
        let original = Claims::builder()
            .subject("user-1")
            .role("admin")
            .build()
            .unwrap();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<Claims>(&encoded).unwrap();

        assert_eq!(decoded.sub, original.sub);
        assert_eq!(decoded.roles, original.roles);
    }
}