1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum KeyClass {
7 Secret,
9 Publishable,
11}
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct Limits {
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub max_connections: Option<u32>,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 pub max_subscriptions: Option<u32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub max_snapshot_rows: Option<u32>,
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub max_messages_per_minute: Option<u32>,
28 #[serde(skip_serializing_if = "Option::is_none")]
30 pub max_bytes_per_minute: Option<u64>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct SessionClaims {
36 pub iss: String,
38 pub sub: String,
40 pub aud: String,
42 pub iat: u64,
44 pub nbf: u64,
46 pub exp: u64,
48 pub jti: String,
50 pub scope: String,
52 pub metering_key: String,
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub deployment_id: Option<String>,
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub origin: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none", rename = "client_ip")]
62 pub client_ip: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub limits: Option<Limits>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub plan: Option<String>,
69 #[serde(rename = "key_class")]
71 pub key_class: KeyClass,
72}
73
74impl SessionClaims {
75 pub fn builder(
77 iss: impl Into<String>,
78 sub: impl Into<String>,
79 aud: impl Into<String>,
80 ) -> SessionClaimsBuilder {
81 SessionClaimsBuilder::new(iss, sub, aud)
82 }
83
84 pub fn is_expired(&self, now: u64) -> bool {
86 self.exp <= now
87 }
88
89 pub fn is_valid(&self, now: u64) -> bool {
91 self.nbf <= now && self.iat <= now
92 }
93}
94
95pub struct SessionClaimsBuilder {
97 iss: String,
98 sub: String,
99 aud: String,
100 iat: u64,
101 nbf: u64,
102 exp: u64,
103 jti: String,
104 scope: String,
105 metering_key: String,
106 deployment_id: Option<String>,
107 origin: Option<String>,
108 client_ip: Option<String>,
109 limits: Option<Limits>,
110 plan: Option<String>,
111 key_class: KeyClass,
112}
113
114impl SessionClaimsBuilder {
115 fn new(iss: impl Into<String>, sub: impl Into<String>, aud: impl Into<String>) -> Self {
116 use std::time::{SystemTime, UNIX_EPOCH};
117 let now = SystemTime::now()
118 .duration_since(UNIX_EPOCH)
119 .expect("time should not be before epoch")
120 .as_secs();
121
122 Self {
123 iss: iss.into(),
124 sub: sub.into(),
125 aud: aud.into(),
126 iat: now,
127 nbf: now,
128 exp: now + crate::DEFAULT_SESSION_TTL_SECONDS,
129 jti: uuid::Uuid::new_v4().to_string(),
130 scope: "read".to_string(),
131 metering_key: String::new(),
132 deployment_id: None,
133 origin: None,
134 client_ip: None,
135 limits: None,
136 plan: None,
137 key_class: KeyClass::Publishable,
138 }
139 }
140
141 pub fn with_ttl(mut self, ttl_seconds: u64) -> Self {
142 self.exp = self.iat + ttl_seconds;
143 self
144 }
145
146 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
147 self.scope = scope.into();
148 self
149 }
150
151 pub fn with_metering_key(mut self, key: impl Into<String>) -> Self {
152 self.metering_key = key.into();
153 self
154 }
155
156 pub fn with_deployment_id(mut self, id: impl Into<String>) -> Self {
157 self.deployment_id = Some(id.into());
158 self
159 }
160
161 pub fn with_origin(mut self, origin: impl Into<String>) -> Self {
162 self.origin = Some(origin.into());
163 self
164 }
165
166 pub fn with_client_ip(mut self, client_ip: impl Into<String>) -> Self {
167 self.client_ip = Some(client_ip.into());
168 self
169 }
170
171 pub fn with_limits(mut self, limits: Limits) -> Self {
172 self.limits = Some(limits);
173 self
174 }
175
176 pub fn with_plan(mut self, plan: impl Into<String>) -> Self {
177 self.plan = Some(plan.into());
178 self
179 }
180
181 pub fn with_key_class(mut self, key_class: KeyClass) -> Self {
182 self.key_class = key_class;
183 self
184 }
185
186 pub fn with_jti(mut self, jti: impl Into<String>) -> Self {
187 self.jti = jti.into();
188 self
189 }
190
191 pub fn build(self) -> SessionClaims {
192 SessionClaims {
193 iss: self.iss,
194 sub: self.sub,
195 aud: self.aud,
196 iat: self.iat,
197 nbf: self.nbf,
198 exp: self.exp,
199 jti: self.jti,
200 scope: self.scope,
201 metering_key: self.metering_key,
202 deployment_id: self.deployment_id,
203 origin: self.origin,
204 client_ip: self.client_ip,
205 limits: self.limits,
206 plan: self.plan,
207 key_class: self.key_class,
208 }
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct AuthContext {
215 pub subject: String,
217 pub issuer: String,
219 pub key_class: KeyClass,
221 pub metering_key: String,
223 pub deployment_id: Option<String>,
225 pub expires_at: u64,
227 pub scope: String,
229 pub limits: Limits,
231 pub plan: Option<String>,
233 pub origin: Option<String>,
235 pub client_ip: Option<String>,
237 pub jti: String,
239}
240
241impl AuthContext {
242 pub fn from_claims(claims: SessionClaims) -> Self {
244 Self {
245 subject: claims.sub,
246 issuer: claims.iss,
247 key_class: claims.key_class,
248 metering_key: claims.metering_key,
249 deployment_id: claims.deployment_id,
250 expires_at: claims.exp,
251 scope: claims.scope,
252 limits: claims.limits.unwrap_or_default(),
253 plan: claims.plan,
254 origin: claims.origin,
255 client_ip: claims.client_ip,
256 jti: claims.jti,
257 }
258 }
259}