Skip to main content

hyperstack_auth/
claims.rs

1use serde::{Deserialize, Serialize};
2
3/// Key classification for metering and policy
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum KeyClass {
7    /// Secret API key - long-lived, high trust
8    Secret,
9    /// Publishable key - safe for browsers, constrained
10    Publishable,
11}
12
13/// Resource limits for a session
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct Limits {
16    /// Maximum concurrent connections for this subject
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub max_connections: Option<u32>,
19    /// Maximum subscriptions per connection
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub max_subscriptions: Option<u32>,
22    /// Maximum snapshot rows per request
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub max_snapshot_rows: Option<u32>,
25    /// Maximum messages per minute
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub max_messages_per_minute: Option<u32>,
28    /// Maximum egress bytes per minute
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub max_bytes_per_minute: Option<u64>,
31}
32
33/// Session token claims
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct SessionClaims {
36    /// Issuer - who issued this token
37    pub iss: String,
38    /// Subject - who this token is for
39    pub sub: String,
40    /// Audience - intended recipient (e.g., deployment ID)
41    pub aud: String,
42    /// Issued at (Unix timestamp)
43    pub iat: u64,
44    /// Not valid before (Unix timestamp)
45    pub nbf: u64,
46    /// Expiration time (Unix timestamp)
47    pub exp: u64,
48    /// JWT ID - unique identifier for this token
49    pub jti: String,
50    /// Scope - permissions granted
51    pub scope: String,
52    /// Metering key - for usage attribution
53    pub metering_key: String,
54    /// Deployment ID (optional)
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub deployment_id: Option<String>,
57    /// Origin binding (optional, defense-in-depth)
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub origin: Option<String>,
60    /// Client IP binding (optional, for high-security scenarios)
61    #[serde(skip_serializing_if = "Option::is_none", rename = "client_ip")]
62    pub client_ip: Option<String>,
63    /// Resource limits
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub limits: Option<Limits>,
66    /// Plan identifier (optional)
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub plan: Option<String>,
69    /// Key class (secret vs publishable)
70    #[serde(rename = "key_class")]
71    pub key_class: KeyClass,
72}
73
74impl SessionClaims {
75    /// Create a new session claims builder
76    pub fn builder(
77        iss: impl Into<String>,
78        sub: impl Into<String>,
79        aud: impl Into<String>,
80    ) -> SessionClaimsBuilder {
81        SessionClaimsBuilder::new(iss, sub, aud)
82    }
83
84    /// Check if the token is expired
85    pub fn is_expired(&self, now: u64) -> bool {
86        self.exp <= now
87    }
88
89    /// Check if the token is valid (not before issued)
90    pub fn is_valid(&self, now: u64) -> bool {
91        self.nbf <= now && self.iat <= now
92    }
93}
94
95/// Builder for SessionClaims
96pub struct SessionClaimsBuilder {
97    iss: String,
98    sub: String,
99    aud: String,
100    iat: u64,
101    nbf: u64,
102    exp: u64,
103    jti: String,
104    scope: String,
105    metering_key: String,
106    deployment_id: Option<String>,
107    origin: Option<String>,
108    client_ip: Option<String>,
109    limits: Option<Limits>,
110    plan: Option<String>,
111    key_class: KeyClass,
112}
113
114impl SessionClaimsBuilder {
115    fn new(iss: impl Into<String>, sub: impl Into<String>, aud: impl Into<String>) -> Self {
116        use std::time::{SystemTime, UNIX_EPOCH};
117        let now = SystemTime::now()
118            .duration_since(UNIX_EPOCH)
119            .expect("time should not be before epoch")
120            .as_secs();
121
122        Self {
123            iss: iss.into(),
124            sub: sub.into(),
125            aud: aud.into(),
126            iat: now,
127            nbf: now,
128            exp: now + crate::DEFAULT_SESSION_TTL_SECONDS,
129            jti: uuid::Uuid::new_v4().to_string(),
130            scope: "read".to_string(),
131            metering_key: String::new(),
132            deployment_id: None,
133            origin: None,
134            client_ip: None,
135            limits: None,
136            plan: None,
137            key_class: KeyClass::Publishable,
138        }
139    }
140
141    pub fn with_ttl(mut self, ttl_seconds: u64) -> Self {
142        self.exp = self.iat + ttl_seconds;
143        self
144    }
145
146    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
147        self.scope = scope.into();
148        self
149    }
150
151    pub fn with_metering_key(mut self, key: impl Into<String>) -> Self {
152        self.metering_key = key.into();
153        self
154    }
155
156    pub fn with_deployment_id(mut self, id: impl Into<String>) -> Self {
157        self.deployment_id = Some(id.into());
158        self
159    }
160
161    pub fn with_origin(mut self, origin: impl Into<String>) -> Self {
162        self.origin = Some(origin.into());
163        self
164    }
165
166    pub fn with_client_ip(mut self, client_ip: impl Into<String>) -> Self {
167        self.client_ip = Some(client_ip.into());
168        self
169    }
170
171    pub fn with_limits(mut self, limits: Limits) -> Self {
172        self.limits = Some(limits);
173        self
174    }
175
176    pub fn with_plan(mut self, plan: impl Into<String>) -> Self {
177        self.plan = Some(plan.into());
178        self
179    }
180
181    pub fn with_key_class(mut self, key_class: KeyClass) -> Self {
182        self.key_class = key_class;
183        self
184    }
185
186    pub fn with_jti(mut self, jti: impl Into<String>) -> Self {
187        self.jti = jti.into();
188        self
189    }
190
191    pub fn build(self) -> SessionClaims {
192        SessionClaims {
193            iss: self.iss,
194            sub: self.sub,
195            aud: self.aud,
196            iat: self.iat,
197            nbf: self.nbf,
198            exp: self.exp,
199            jti: self.jti,
200            scope: self.scope,
201            metering_key: self.metering_key,
202            deployment_id: self.deployment_id,
203            origin: self.origin,
204            client_ip: self.client_ip,
205            limits: self.limits,
206            plan: self.plan,
207            key_class: self.key_class,
208        }
209    }
210}
211
212/// Auth context extracted from a verified token
213#[derive(Debug, Clone)]
214pub struct AuthContext {
215    /// Subject identifier
216    pub subject: String,
217    /// Issuer
218    pub issuer: String,
219    /// Key class (secret vs publishable)
220    pub key_class: KeyClass,
221    /// Metering key for usage attribution
222    pub metering_key: String,
223    /// Deployment ID binding
224    pub deployment_id: Option<String>,
225    /// Token expiration time
226    pub expires_at: u64,
227    /// Granted scope
228    pub scope: String,
229    /// Resource limits
230    pub limits: Limits,
231    /// Plan or access tier associated with the session
232    pub plan: Option<String>,
233    /// Origin binding
234    pub origin: Option<String>,
235    /// Client IP binding
236    pub client_ip: Option<String>,
237    /// JWT ID
238    pub jti: String,
239}
240
241impl AuthContext {
242    /// Create AuthContext from verified claims
243    pub fn from_claims(claims: SessionClaims) -> Self {
244        Self {
245            subject: claims.sub,
246            issuer: claims.iss,
247            key_class: claims.key_class,
248            metering_key: claims.metering_key,
249            deployment_id: claims.deployment_id,
250            expires_at: claims.exp,
251            scope: claims.scope,
252            limits: claims.limits.unwrap_or_default(),
253            plan: claims.plan,
254            origin: claims.origin,
255            client_ip: claims.client_ip,
256            jti: claims.jti,
257        }
258    }
259}