arcly_http/auth/jwt.rs
1//! JWT authentication service — sign, decode, and validate JSON Web Tokens.
2//!
3//! ## Usage
4//!
//! Provide a `JwtService` instance from an `ArclyPlugin::on_init` hook:
6//!
7//! ```ignore
8//! ctx.provide(JwtService::new(JwtConfig {
9//! secret: "change-in-prod".to_string(),
10//! access_ttl_secs: 900,
11//! refresh_ttl_secs: 604_800,
12//! ..Default::default()
13//! }));
14//! ```
15//!
16//! Once provided, the HTTP and WebSocket boundaries automatically decode the
17//! `Authorization: Bearer <token>` header and populate `RequestContext::claims()`
18//! on every request — no per-handler boilerplate needed. Protect routes with
19//! `JWT_AUTH.check(&ctx)?` or `RoleGuard::require("admin").check(&ctx)?`.
20
21use std::sync::Arc;
22use std::time::{SystemTime, UNIX_EPOCH};
23
24use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
25use serde::{Deserialize, Serialize};
26
27use crate::web::context::Claims;
28
29// ─── Configuration ────────────────────────────────────────────────────────────
30
31/// Configuration for `JwtService`. Build once at startup and provide via DI.
32pub struct JwtConfig {
33 /// HMAC secret (HS256 / HS384 / HS512) or PEM-encoded key for RS/ES algorithms.
34 pub secret: String,
35 /// Signing algorithm. Defaults to `HS256`.
36 pub algorithm: Algorithm,
37 /// Lifetime of access tokens in seconds. Defaults to 900 (15 min).
38 pub access_ttl_secs: u64,
39 /// Lifetime of refresh tokens in seconds. Defaults to 604 800 (7 days).
40 pub refresh_ttl_secs: u64,
41}
42
43impl Default for JwtConfig {
44 fn default() -> Self {
45 Self {
46 secret: "change-me-in-production".to_string(),
47 algorithm: Algorithm::HS256,
48 access_ttl_secs: 900,
49 refresh_ttl_secs: 604_800,
50 }
51 }
52}
53
54// ─── Internal claims struct ───────────────────────────────────────────────────
55
56/// Private claims struct used for `encode` / `decode`.
57/// Decoded into a `serde_json::Map` before being stored on `RequestContext`.
58#[derive(Debug, Serialize, Deserialize)]
59struct JwtClaims {
60 sub: String,
61 /// Omitted from refresh tokens (always empty there — no point encoding them).
62 #[serde(skip_serializing_if = "String::is_empty", default)]
63 role: String,
64 #[serde(skip_serializing_if = "String::is_empty", default)]
65 email: String,
66 /// "access" or "refresh"
67 #[serde(rename = "type")]
68 kind: String,
69 /// JWT ID — unique token identifier, used for refresh token rotation.
70 jti: String,
71 iat: u64,
72 exp: u64,
73 /// Fine-grained permissions (e.g. `["users:*", "orders:read"]`).
74 /// Omitted from refresh tokens and when no permissions are set.
75 #[serde(skip_serializing_if = "Vec::is_empty", default)]
76 perms: Vec<String>,
77 /// Home tenant of the principal. `TenantGuard` cross-checks this against
78 /// the request's resolved tenant, which (a) blocks forged tenant headers
79 /// and (b) makes dropping the header useless: a token bound to tenant A
80 /// can never act as the fallback tenant.
81 #[serde(skip_serializing_if = "String::is_empty", default)]
82 tenant: String,
83}
84
85// ─── JwtService ───────────────────────────────────────────────────────────────
86
87/// Live signing/verification keys. Swapped atomically as one bundle on
88/// rotation so readers never observe a half-rotated state. `verify` keeps the
89/// previous key so tokens signed before rotation stay valid through their TTL.
90struct JwtKeyMaterial {
91 encoding: EncodingKey,
92 verify: Vec<DecodingKey>, // [current, previous?]
93 version: u64,
94}
95
96impl JwtKeyMaterial {
97 fn from_secret(secret: &[u8], version: u64, previous: Option<DecodingKey>) -> Self {
98 let mut verify = vec![DecodingKey::from_secret(secret)];
99 verify.extend(previous);
100 Self {
101 encoding: EncodingKey::from_secret(secret),
102 verify,
103 version,
104 }
105 }
106}
107
108/// Signs and validates JWTs. Provide this into the DI container so the framework
109/// boundaries (`boundary.rs`, `ws.rs`) can auto-populate `RequestContext::claims()`
110/// on every incoming request.
111///
112/// ## Secret rotation
113///
114/// Keys live behind [`Rotating`](crate::auth::secrets::Rotating) (an
115/// `ArcSwap`): the request path pays one atomic pointer load, while
116/// [`rotate_secret`](Self::rotate_secret) — typically driven by a
117/// `SecretSource` watcher — swaps in a new bundle with no restart. The
118/// previous key is retained for verification, so live tokens (≤ TTL old)
119/// keep validating through the grace window.
120pub struct JwtService {
121 keys: crate::auth::secrets::Rotating<JwtKeyMaterial>,
122 header: Header,
123 validation: Validation,
124 config: JwtConfig,
125}
126
127impl JwtService {
128 pub fn new(config: JwtConfig) -> Self {
129 let keys = crate::auth::secrets::Rotating::new(JwtKeyMaterial::from_secret(
130 config.secret.as_bytes(),
131 1,
132 None,
133 ));
134 let header = Header::new(config.algorithm);
135 let mut validation = Validation::new(config.algorithm);
136 validation.validate_exp = true;
137 Self {
138 keys,
139 header,
140 validation,
141 config,
142 }
143 }
144
145 /// Hot-swap the signing secret — no restart, no token mass-invalidation.
146 ///
147 /// New tokens sign with the new key immediately; tokens signed with the
148 /// previous key keep verifying until natural expiry. Versions are
149 /// monotonic: a stale (≤ current) version is ignored, making concurrent
150 /// watchers and duplicate delivery harmless.
151 pub fn rotate_secret(&self, new_secret: &[u8], version: u64) {
152 let current = self.keys.load();
153 if version <= current.version {
154 tracing::warn!(
155 current = current.version,
156 offered = version,
157 "ignoring stale JWT secret rotation",
158 );
159 return;
160 }
161 let previous = current.verify.first().cloned();
162 self.keys
163 .store(JwtKeyMaterial::from_secret(new_secret, version, previous));
164 tracing::info!(version, "JwtService signing key rotated");
165 }
166
167 fn now() -> u64 {
168 SystemTime::now()
169 .duration_since(UNIX_EPOCH)
170 .map(|d| d.as_secs())
171 .unwrap_or(0)
172 }
173
174 /// Issue a signed **access token**.
175 ///
176 /// Claims: `sub` = user ID, `role`, `email`, `type = "access"`, `jti`, `iat`, `exp`.
177 pub fn issue_access(&self, sub: &str, role: &str, email: &str) -> String {
178 self.issue_access_with_perms(sub, role, email, &[])
179 }
180
181 /// Like [`Self::issue_access`] but embeds a `perms` claim (array of permission strings).
182 ///
183 /// Use this when the app maintains a permission map so that `PermissionGuard`
184 /// can do a zero-latency lookup without hitting the store on each request.
185 pub fn issue_access_with_perms(
186 &self,
187 sub: &str,
188 role: &str,
189 email: &str,
190 perms: &[String],
191 ) -> String {
192 self.issue_access_bound(sub, role, email, perms, None)
193 }
194
195 /// Like [`Self::issue_access_with_perms`] but additionally **binds the token to
196 /// a tenant** via the `tenant` claim. `TenantGuard` then enforces that
197 /// requests carrying this token resolve to the same tenant — omitting or
198 /// forging the tenant header yields `403`, so a suspended tenant's users
199 /// cannot ride the fallback pool by dropping the header.
200 pub fn issue_access_bound(
201 &self,
202 sub: &str,
203 role: &str,
204 email: &str,
205 perms: &[String],
206 tenant: Option<&str>,
207 ) -> String {
208 let now = Self::now();
209 let claims = JwtClaims {
210 sub: sub.to_owned(),
211 role: role.to_owned(),
212 email: email.to_owned(),
213 kind: "access".to_owned(),
214 jti: new_jti(),
215 iat: now,
216 exp: now + self.config.access_ttl_secs,
217 perms: perms.to_vec(),
218 tenant: tenant.unwrap_or("").to_owned(),
219 };
220 encode(&self.header, &claims, &self.keys.load().encoding)
221 .expect("JWT encode failed — check signing key")
222 }
223
224 /// Issue a signed **refresh token** with a unique `jti`.
225 ///
226 /// Claims: `sub`, `type = "refresh"`, `jti`, `iat`, `exp`.
227 /// The `jti` is returned alongside the token so the caller can persist it.
228 pub fn issue_refresh(&self, sub: &str) -> (String, String) {
229 let now = Self::now();
230 let jti = new_jti();
231 let claims = JwtClaims {
232 sub: sub.to_owned(),
233 role: String::new(),
234 email: String::new(),
235 kind: "refresh".to_owned(),
236 jti: jti.clone(),
237 iat: now,
238 exp: now + self.config.refresh_ttl_secs,
239 perms: Vec::new(),
240 tenant: String::new(),
241 };
242 let token = encode(&self.header, &claims, &self.keys.load().encoding)
243 .expect("JWT encode failed — check signing key");
244 (token, jti)
245 }
246
247 /// Validate signature + expiry and return the decoded claims as a JSON map.
248 ///
249 /// Returns `None` for any invalid token (expired, bad signature, malformed).
250 /// Does NOT enforce token type — use [`Self::decode_access`] at request boundaries.
251 pub fn decode(&self, token: &str) -> Option<Arc<Claims>> {
252 // Try current key first, then the retained previous key (rotation
253 // grace window). Bundle is one atomic load — keys can't mix versions.
254 let keys = self.keys.load();
255 let data = keys
256 .verify
257 .iter()
258 .find_map(|k| decode::<serde_json::Value>(token, k, &self.validation).ok())?;
259 let obj = data.claims.as_object()?.clone();
260 Some(Arc::new(obj))
261 }
262
263 /// Like [`decode`] but additionally requires `"type" == "access"`.
264 ///
265 /// Use this at request boundaries so refresh tokens cannot be passed as
266 /// access tokens to authenticate protected routes.
267 pub fn decode_access(&self, token: &str) -> Option<Arc<Claims>> {
268 let claims = self.decode(token)?;
269 if claims.get("type").and_then(|v| v.as_str()) != Some("access") {
270 return None;
271 }
272 Some(claims)
273 }
274
275 /// Validate a **refresh token** specifically.
276 ///
277 /// Returns `(subject, jti)` on success, `None` otherwise.
278 /// Callers must verify that the `jti` exists in their token store before
279 /// issuing a new pair.
280 pub fn validate_refresh(&self, token: &str) -> Option<(String, String)> {
281 let claims = self.decode(token)?;
282 if claims.get("type")?.as_str()? != "refresh" {
283 return None;
284 }
285 let sub = claims.get("sub")?.as_str()?.to_owned();
286 let jti = claims.get("jti")?.as_str()?.to_owned();
287 Some((sub, jti))
288 }
289
290 /// Lifetime of access tokens in seconds (used in `TokenResponse.expires_in`).
291 pub fn access_ttl_secs(&self) -> u64 {
292 self.config.access_ttl_secs
293 }
294
295 /// Lifetime of refresh tokens (used by token store for TTL).
296 pub fn refresh_ttl_secs(&self) -> u64 {
297 self.config.refresh_ttl_secs
298 }
299}
300
301/// Extract and decode an **access** Bearer token from request headers.
302///
303/// Shared by all three request boundaries (HTTP macro routes, plugin routes,
304/// WebSocket handshake) so a security fix here applies everywhere at once.
305///
306/// Returns `None` when:
307/// - No `JwtService` is registered in the container.
308/// - The `Authorization` header is absent or not valid UTF-8.
309/// - The token is missing, expired, or has an invalid signature.
310/// - The token `"type"` claim is not `"access"` (i.e. refresh tokens are rejected).
311pub fn decode_bearer_token(
312 headers: &axum::http::HeaderMap,
313 container: &crate::core::engine::FrozenDiContainer,
314) -> Option<Arc<Claims>> {
315 let raw = headers.get("authorization")?.to_str().ok()?;
316 let token = raw.strip_prefix("Bearer ").unwrap_or(raw).trim();
317 if token.is_empty() {
318 return None;
319 }
320 container.try_get::<JwtService>()?.decode_access(token)
321}
322
/// Generate a collision-resistant JWT ID without external deps.
///
/// Returns 32 lowercase hex characters: two 64-bit SipHash digests of
/// (current time, process-wide monotonic counter, process ID, thread ID).
/// Each digest is keyed by a fresh `RandomState`, which std initialises with
/// random keys. Together these cover three cases:
/// - The counter keeps IDs distinct within a process even when the clock does
///   not advance between calls (coarse clock resolution).
/// - The random keys and the PID keep separate processes or server instances
///   that share a token store from producing the same ID sequence.
/// - Tokens issued concurrently on different threads differ by thread ID.
///
/// This is not a secret: uniqueness is what matters, because the JWT signature
/// (not the `jti`) is what prevents forgery.
fn new_jti() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    let now = SystemTime::now();
    let pid = std::process::id();
    let tid = std::thread::current().id();

    // Each call builds its own randomly keyed hasher; `salt` additionally
    // separates the two halves of the ID.
    let digest = |salt: u8| {
        let mut hasher = RandomState::new().build_hasher();
        salt.hash(&mut hasher);
        now.hash(&mut hasher);
        seq.hash(&mut hasher);
        pid.hash(&mut hasher);
        tid.hash(&mut hasher);
        hasher.finish()
    };

    format!("{:016x}{:016x}", digest(0), digest(1))
}