postrust_auth/
claims.rs

1//! JWT claims handling.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// Parsed JWT claims for use in requests.
7#[derive(Clone, Debug, Default, Serialize, Deserialize)]
8pub struct Claims {
9    /// All claims as key-value pairs
10    pub values: HashMap<String, serde_json::Value>,
11}
12
13impl Claims {
14    /// Create empty claims.
15    pub fn new() -> Self {
16        Self::default()
17    }
18
19    /// Get a claim value as a string.
20    pub fn get_str(&self, key: &str) -> Option<&str> {
21        self.values.get(key).and_then(|v| v.as_str())
22    }
23
24    /// Get a claim value as an integer.
25    pub fn get_i64(&self, key: &str) -> Option<i64> {
26        self.values.get(key).and_then(|v| v.as_i64())
27    }
28
29    /// Get a claim value as a boolean.
30    pub fn get_bool(&self, key: &str) -> Option<bool> {
31        self.values.get(key).and_then(|v| v.as_bool())
32    }
33
34    /// Get a claim value.
35    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
36        self.values.get(key)
37    }
38
39    /// Set a claim value.
40    pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
41        self.values.insert(key.into(), value);
42    }
43
44    /// Convert to JSON string for GUC.
45    pub fn to_json(&self) -> String {
46        serde_json::to_string(&self.values).unwrap_or_else(|_| "{}".to_string())
47    }
48
49    /// Get claims for a specific prefix (e.g., "request.jwt.claims.").
50    pub fn prefixed_entries(&self, prefix: &str) -> Vec<(String, String)> {
51        self.values
52            .iter()
53            .map(|(k, v)| {
54                let key = format!("{}{}", prefix, k);
55                let value = match v {
56                    serde_json::Value::String(s) => s.clone(),
57                    other => other.to_string(),
58                };
59                (key, value)
60            })
61            .collect()
62    }
63}
64
65impl From<HashMap<String, serde_json::Value>> for Claims {
66    fn from(values: HashMap<String, serde_json::Value>) -> Self {
67        Self { values }
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up the value text for `key` in a `prefixed_entries` result.
    fn value_of<'a>(entries: &'a [(String, String)], key: &str) -> Option<&'a str> {
        entries
            .iter()
            .find(|(k, _)| k == key)
            .map(|(_, v)| v.as_str())
    }

    #[test]
    fn test_claims_get_str() {
        let mut claims = Claims::new();
        claims.set("role", serde_json::Value::String("admin".into()));

        assert_eq!(claims.get_str("role"), Some("admin"));
        assert_eq!(claims.get_str("missing"), None);
    }

    #[test]
    fn test_claims_get_str_wrong_type() {
        let mut claims = Claims::new();
        claims.set("user_id", serde_json::Value::Number(42.into()));

        // Typed getters do not coerce between JSON types.
        assert_eq!(claims.get_str("user_id"), None);
    }

    #[test]
    fn test_claims_get_i64() {
        let mut claims = Claims::new();
        claims.set("user_id", serde_json::Value::Number(42.into()));
        claims.set("name", serde_json::Value::String("42".into()));

        assert_eq!(claims.get_i64("user_id"), Some(42));
        assert_eq!(claims.get_i64("name"), None);
        assert_eq!(claims.get_i64("missing"), None);
    }

    #[test]
    fn test_claims_get_bool() {
        let mut claims = Claims::new();
        claims.set("admin", serde_json::Value::Bool(true));
        claims.set("role", serde_json::Value::String("true".into()));

        assert_eq!(claims.get_bool("admin"), Some(true));
        assert_eq!(claims.get_bool("role"), None);
        assert_eq!(claims.get_bool("missing"), None);
    }

    #[test]
    fn test_claims_get_and_set_overwrite() {
        let mut claims = Claims::new();
        claims.set("role", serde_json::Value::String("user".into()));
        claims.set("role", serde_json::Value::String("admin".into()));

        assert_eq!(
            claims.get("role"),
            Some(&serde_json::Value::String("admin".into()))
        );
        assert_eq!(claims.values.len(), 1);
        assert!(claims.get("missing").is_none());
    }

    #[test]
    fn test_claims_to_json() {
        let mut claims = Claims::new();
        claims.set("role", serde_json::Value::String("user".into()));
        claims.set("id", serde_json::Value::Number(123.into()));

        // Parse the output rather than substring-matching, so the test checks the
        // actual structure and does not depend on key order.
        let parsed: serde_json::Value =
            serde_json::from_str(&claims.to_json()).expect("to_json must emit valid JSON");
        assert_eq!(parsed["role"].as_str(), Some("user"));
        assert_eq!(parsed["id"].as_i64(), Some(123));
        assert_eq!(parsed.as_object().map(|o| o.len()), Some(2));
    }

    #[test]
    fn test_claims_to_json_empty() {
        assert_eq!(Claims::new().to_json(), "{}");
    }

    #[test]
    fn test_claims_prefixed_entries() {
        let mut claims = Claims::new();
        claims.set("role", serde_json::Value::String("admin".into()));
        claims.set("email", serde_json::Value::String("test@example.com".into()));

        let entries = claims.prefixed_entries("request.jwt.claims.");
        assert_eq!(entries.len(), 2);
        assert!(entries.iter().any(|(k, _)| k == "request.jwt.claims.role"));
        assert_eq!(
            value_of(&entries, "request.jwt.claims.email"),
            Some("test@example.com")
        );
    }

    #[test]
    fn test_claims_prefixed_entries_value_formatting() {
        let mut claims = Claims::new();
        claims.set("role", serde_json::Value::String("admin".into()));
        claims.set("id", serde_json::Value::Number(7.into()));
        claims.set("active", serde_json::Value::Bool(true));
        claims.set("meta", serde_json::json!({ "a": 1 }));

        let entries = claims.prefixed_entries("p.");

        // Strings come out without JSON quotes; everything else as compact JSON.
        assert_eq!(value_of(&entries, "p.role"), Some("admin"));
        assert_eq!(value_of(&entries, "p.id"), Some("7"));
        assert_eq!(value_of(&entries, "p.active"), Some("true"));
        assert_eq!(value_of(&entries, "p.meta"), Some(r#"{"a":1}"#));
    }

    #[test]
    fn test_claims_prefixed_entries_empty() {
        assert!(Claims::new().prefixed_entries("p.").is_empty());
    }

    #[test]
    fn test_claims_from_map() {
        let mut map = std::collections::HashMap::new();
        map.insert("sub".to_string(), serde_json::Value::String("u1".into()));

        let claims = Claims::from(map);
        assert_eq!(claims.get_str("sub"), Some("u1"));
        assert_eq!(claims.values.len(), 1);
    }
}