// arcly_http/auth/jwt.rs

//! JWT authentication service — sign, decode, and validate JSON Web Tokens.
//!
//! ## Usage
//!
//! Provide a `JwtService` instance in an `ArclyPlugin::on_init`:
//!
//! ```ignore
//! ctx.provide(JwtService::new(JwtConfig {
//!     secret: "change-in-prod".to_string(),
//!     access_ttl_secs:  900,
//!     refresh_ttl_secs: 604_800,
//!     ..Default::default()
//! }));
//! ```
//!
//! Once provided, the HTTP and WebSocket boundaries automatically decode the
//! `Authorization: Bearer <token>` header and populate `RequestContext::claims()`
//! on every request — no per-handler boilerplate needed. Protect routes with
//! `JWT_AUTH.check(&ctx)?` or `RoleGuard::require("admin").check(&ctx)?`.

21use std::sync::Arc;
22use std::time::{SystemTime, UNIX_EPOCH};
23
24use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
25use serde::{Deserialize, Serialize};
26
27use crate::web::context::Claims;
28
29// ─── Configuration ────────────────────────────────────────────────────────────
30
31/// Configuration for `JwtService`. Build once at startup and provide via DI.
32/// Token signing failed — malformed key material (typically a bad rotation
33/// payload). Map to a 500 at the route; the process must keep serving.
34#[derive(Debug)]
35pub struct JwtSignError(pub jsonwebtoken::errors::Error);
36
37impl std::fmt::Display for JwtSignError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "JWT signing failed: {}", self.0)
40    }
41}
42impl std::error::Error for JwtSignError {}
43
44pub struct JwtConfig {
45    /// HMAC secret (HS256 / HS384 / HS512) or PEM-encoded key for RS/ES algorithms.
46    pub secret: String,
47    /// Signing algorithm. Defaults to `HS256`.
48    pub algorithm: Algorithm,
49    /// Lifetime of access tokens in seconds. Defaults to 900 (15 min).
50    pub access_ttl_secs: u64,
51    /// Lifetime of refresh tokens in seconds. Defaults to 604 800 (7 days).
52    pub refresh_ttl_secs: u64,
53}
54
55impl Default for JwtConfig {
56    fn default() -> Self {
57        Self {
58            secret: "change-me-in-production".to_string(),
59            algorithm: Algorithm::HS256,
60            access_ttl_secs: 900,
61            refresh_ttl_secs: 604_800,
62        }
63    }
64}
65
66// ─── Internal claims struct ───────────────────────────────────────────────────
67
68/// Private claims struct used for `encode` / `decode`.
69/// Decoded into a `serde_json::Map` before being stored on `RequestContext`.
70#[derive(Debug, Serialize, Deserialize)]
71struct JwtClaims {
72    sub: String,
73    /// Omitted from refresh tokens (always empty there — no point encoding them).
74    #[serde(skip_serializing_if = "String::is_empty", default)]
75    role: String,
76    #[serde(skip_serializing_if = "String::is_empty", default)]
77    email: String,
78    /// "access" or "refresh"
79    #[serde(rename = "type")]
80    kind: String,
81    /// JWT ID — unique token identifier, used for refresh token rotation.
82    jti: String,
83    iat: u64,
84    exp: u64,
85    /// Fine-grained permissions (e.g. `["users:*", "orders:read"]`).
86    /// Omitted from refresh tokens and when no permissions are set.
87    #[serde(skip_serializing_if = "Vec::is_empty", default)]
88    perms: Vec<String>,
89    /// Home tenant of the principal. `TenantGuard` cross-checks this against
90    /// the request's resolved tenant, which (a) blocks forged tenant headers
91    /// and (b) makes dropping the header useless: a token bound to tenant A
92    /// can never act as the fallback tenant.
93    #[serde(skip_serializing_if = "String::is_empty", default)]
94    tenant: String,
95}
96
97// ─── JwtService ───────────────────────────────────────────────────────────────
98
99/// Live signing/verification keys. Swapped atomically as one bundle on
100/// rotation so readers never observe a half-rotated state. `verify` keeps the
101/// previous key so tokens signed before rotation stay valid through their TTL.
102struct JwtKeyMaterial {
103    encoding: EncodingKey,
104    verify: Vec<DecodingKey>, // [current, previous?]
105    version: u64,
106}
107
108impl JwtKeyMaterial {
109    fn from_secret(secret: &[u8], version: u64, previous: Option<DecodingKey>) -> Self {
110        let mut verify = vec![DecodingKey::from_secret(secret)];
111        verify.extend(previous);
112        Self {
113            encoding: EncodingKey::from_secret(secret),
114            verify,
115            version,
116        }
117    }
118}
119
120/// Signs and validates JWTs. Provide this into the DI container so the framework
121/// boundaries (`boundary.rs`, `ws.rs`) can auto-populate `RequestContext::claims()`
122/// on every incoming request.
123///
124/// ## Secret rotation
125///
126/// Keys live behind [`Rotating`](crate::auth::secrets::Rotating) (an
127/// `ArcSwap`): the request path pays one atomic pointer load, while
128/// [`rotate_secret`](Self::rotate_secret) — typically driven by a
129/// `SecretSource` watcher — swaps in a new bundle with no restart. The
130/// previous key is retained for verification, so live tokens (≤ TTL old)
131/// keep validating through the grace window.
132///
133/// Token *signing* returns `Result<_, JwtSignError>` — a malformed key from
134/// a bad rotation payload must surface as a 500 on the affected request,
135/// never as a process panic.
136pub struct JwtService {
137    keys: crate::auth::secrets::Rotating<JwtKeyMaterial>,
138    header: Header,
139    validation: Validation,
140    config: JwtConfig,
141}
142
143impl JwtService {
144    pub fn new(config: JwtConfig) -> Self {
145        let keys = crate::auth::secrets::Rotating::new(JwtKeyMaterial::from_secret(
146            config.secret.as_bytes(),
147            1,
148            None,
149        ));
150        let header = Header::new(config.algorithm);
151        let mut validation = Validation::new(config.algorithm);
152        validation.validate_exp = true;
153        Self {
154            keys,
155            header,
156            validation,
157            config,
158        }
159    }
160
161    /// Hot-swap the signing secret — no restart, no token mass-invalidation.
162    ///
163    /// New tokens sign with the new key immediately; tokens signed with the
164    /// previous key keep verifying until natural expiry. Versions are
165    /// monotonic: a stale (≤ current) version is ignored, making concurrent
166    /// watchers and duplicate delivery harmless.
167    pub fn rotate_secret(&self, new_secret: &[u8], version: u64) {
168        let current = self.keys.load();
169        if version <= current.version {
170            tracing::warn!(
171                current = current.version,
172                offered = version,
173                "ignoring stale JWT secret rotation",
174            );
175            return;
176        }
177        let previous = current.verify.first().cloned();
178        self.keys
179            .store(JwtKeyMaterial::from_secret(new_secret, version, previous));
180        tracing::info!(version, "JwtService signing key rotated");
181    }
182
183    fn now() -> u64 {
184        SystemTime::now()
185            .duration_since(UNIX_EPOCH)
186            .map(|d| d.as_secs())
187            .unwrap_or(0)
188    }
189
190    /// Issue a signed **access token**.
191    ///
192    /// Claims: `sub` = user ID, `role`, `email`, `type = "access"`, `jti`, `iat`, `exp`.
193    ///
194    /// Signing can only fail on malformed key material (e.g. a bad rotation
195    /// payload) — propagate the error instead of panicking mid-traffic.
196    pub fn issue_access(&self, sub: &str, role: &str, email: &str) -> Result<String, JwtSignError> {
197        self.issue_access_with_perms(sub, role, email, &[])
198    }
199
200    /// Like [`Self::issue_access`] but embeds a `perms` claim (array of permission strings).
201    ///
202    /// Use this when the app maintains a permission map so that `PermissionGuard`
203    /// can do a zero-latency lookup without hitting the store on each request.
204    pub fn issue_access_with_perms(
205        &self,
206        sub: &str,
207        role: &str,
208        email: &str,
209        perms: &[String],
210    ) -> Result<String, JwtSignError> {
211        self.issue_access_bound(sub, role, email, perms, None)
212    }
213
214    /// Like [`Self::issue_access_with_perms`] but additionally **binds the token to
215    /// a tenant** via the `tenant` claim. `TenantGuard` then enforces that
216    /// requests carrying this token resolve to the same tenant — omitting or
217    /// forging the tenant header yields `403`, so a suspended tenant's users
218    /// cannot ride the fallback pool by dropping the header.
219    pub fn issue_access_bound(
220        &self,
221        sub: &str,
222        role: &str,
223        email: &str,
224        perms: &[String],
225        tenant: Option<&str>,
226    ) -> Result<String, JwtSignError> {
227        let now = Self::now();
228        let claims = JwtClaims {
229            sub: sub.to_owned(),
230            role: role.to_owned(),
231            email: email.to_owned(),
232            kind: "access".to_owned(),
233            jti: new_jti(),
234            iat: now,
235            exp: now + self.config.access_ttl_secs,
236            perms: perms.to_vec(),
237            tenant: tenant.unwrap_or("").to_owned(),
238        };
239        encode(&self.header, &claims, &self.keys.load().encoding).map_err(JwtSignError)
240    }
241
242    /// Issue a signed **refresh token** with a unique `jti`.
243    ///
244    /// Claims: `sub`, `type = "refresh"`, `jti`, `iat`, `exp`.
245    /// The `jti` is returned alongside the token so the caller can persist it.
246    pub fn issue_refresh(&self, sub: &str) -> Result<(String, String), JwtSignError> {
247        let now = Self::now();
248        let jti = new_jti();
249        let claims = JwtClaims {
250            sub: sub.to_owned(),
251            role: String::new(),
252            email: String::new(),
253            kind: "refresh".to_owned(),
254            jti: jti.clone(),
255            iat: now,
256            exp: now + self.config.refresh_ttl_secs,
257            perms: Vec::new(),
258            tenant: String::new(),
259        };
260        let token =
261            encode(&self.header, &claims, &self.keys.load().encoding).map_err(JwtSignError)?;
262        Ok((token, jti))
263    }
264
265    /// Validate signature + expiry and return the decoded claims as a JSON map.
266    ///
267    /// Returns `None` for any invalid token (expired, bad signature, malformed).
268    /// Does NOT enforce token type — use [`Self::decode_access`] at request boundaries.
269    pub fn decode(&self, token: &str) -> Option<Arc<Claims>> {
270        // Try current key first, then the retained previous key (rotation
271        // grace window). Bundle is one atomic load — keys can't mix versions.
272        let keys = self.keys.load();
273        let data = keys
274            .verify
275            .iter()
276            .find_map(|k| decode::<serde_json::Value>(token, k, &self.validation).ok())?;
277        let obj = data.claims.as_object()?.clone();
278        Some(Arc::new(obj))
279    }
280
281    /// Like [`decode`] but additionally requires `"type" == "access"`.
282    ///
283    /// Use this at request boundaries so refresh tokens cannot be passed as
284    /// access tokens to authenticate protected routes.
285    pub fn decode_access(&self, token: &str) -> Option<Arc<Claims>> {
286        let claims = self.decode(token)?;
287        if claims.get("type").and_then(|v| v.as_str()) != Some("access") {
288            return None;
289        }
290        Some(claims)
291    }
292
293    /// Validate a **refresh token** specifically.
294    ///
295    /// Returns `(subject, jti)` on success, `None` otherwise.
296    /// Callers must verify that the `jti` exists in their token store before
297    /// issuing a new pair.
298    pub fn validate_refresh(&self, token: &str) -> Option<(String, String)> {
299        let claims = self.decode(token)?;
300        if claims.get("type")?.as_str()? != "refresh" {
301            return None;
302        }
303        let sub = claims.get("sub")?.as_str()?.to_owned();
304        let jti = claims.get("jti")?.as_str()?.to_owned();
305        Some((sub, jti))
306    }
307
308    /// Lifetime of access tokens in seconds (used in `TokenResponse.expires_in`).
309    pub fn access_ttl_secs(&self) -> u64 {
310        self.config.access_ttl_secs
311    }
312
313    /// Lifetime of refresh tokens (used by token store for TTL).
314    pub fn refresh_ttl_secs(&self) -> u64 {
315        self.config.refresh_ttl_secs
316    }
317}
318
319/// Extract and decode an **access** Bearer token from request headers.
320///
321/// Shared by all three request boundaries (HTTP macro routes, plugin routes,
322/// WebSocket handshake) so a security fix here applies everywhere at once.
323///
324/// Returns `None` when:
325/// - No `JwtService` is registered in the container.
326/// - The `Authorization` header is absent or not valid UTF-8.
327/// - The token is missing, expired, or has an invalid signature.
328/// - The token `"type"` claim is not `"access"` (i.e. refresh tokens are rejected).
329pub fn decode_bearer_token(
330    headers: &axum::http::HeaderMap,
331    container: &crate::core::engine::FrozenDiContainer,
332) -> Option<Arc<Claims>> {
333    let raw = headers.get("authorization")?.to_str().ok()?;
334    let token = raw.strip_prefix("Bearer ").unwrap_or(raw).trim();
335    if token.is_empty() {
336        return None;
337    }
338    container.try_get::<JwtService>()?.decode_access(token)
339}
340
/// Generate a collision-resistant JWT ID without external deps.
///
/// The result is two 64-bit SipHash digests, hex-encoded (32 lowercase hex
/// characters):
/// - the first hashes the current time and a process-wide monotonic counter;
/// - the second hashes the process ID, the current thread ID and the counter.
///
/// The counter makes every hash input within a process distinct. Two tokens
/// issued concurrently on different threads in the same nanosecond therefore
/// still get distinct JTIs.
///
/// Both hashers are keyed through `RandomState`, which is seeded randomly.
/// `DefaultHasher::new()`, by contrast, uses fixed keys. Two replicas of the
/// same binary sharing one token store thus do not produce the same JTI
/// sequence, and a JTI cannot be recomputed from a timestamp plus counter.
///
/// Uniqueness, not secrecy, is the goal. This is not a CSPRNG-backed nonce.
fn new_jti() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    // Relaxed is enough: only the atomicity of `fetch_add` matters (each
    // caller gets a distinct value); no other memory is published through it.
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);

    let keys = RandomState::new();

    let mut h1 = keys.build_hasher();
    SystemTime::now().hash(&mut h1);
    seq.hash(&mut h1);

    let mut h2 = keys.build_hasher();
    std::process::id().hash(&mut h2);
    std::thread::current().id().hash(&mut h2);
    seq.hash(&mut h2);

    format!("{:016x}{:016x}", h1.finish(), h2.finish())
}