Skip to main content

pylon_auth/
jwt.rs

//! Stateless JWT sessions — an alternative to opaque session tokens.
//!
//! By default Pylon mints opaque random `pylon_…` tokens that must
//! be looked up in the session store on every request. For deploys
//! that can't tolerate that round-trip (edge runtimes, CDN-backed
//! routes, multi-region read replicas), Pylon can mint **JWT-shaped**
//! sessions instead, verified locally against the shared secret with
//! no DB hit.
//!
//! Trade-offs:
//!   - **Pro**: stateless verification (no DB read on every request)
//!   - **Pro**: clients can decode their own claims (without verifying)
//!     for UI personalization without a `/me` round-trip
//!   - **Con**: the flip side of the above — claims are signed, not
//!     encrypted, so anyone holding the token can read them; never put
//!     secrets in claims
//!   - **Con**: revocation requires either a denylist or a short TTL —
//!     a leaked JWT stays valid until its `exp`
//!   - **Con**: secret rotation needs both the old and new keys to coexist
//!     for at least one session lifetime (`verify` takes a single secret,
//!     so the caller must try each live key in turn)
//!
//! Pylon uses HS256 (HMAC-SHA256). It is symmetric: the same secret signs
//! and verifies, so there is no public key to publish — but every verifier
//! must hold the secret. Apps that need RS256 / asymmetric verification
//! across services should use the OIDC discovery / JWKS path (Wave 5)
//! instead.
//!
//! Specs: RFC 7519 (JWT) <https://www.rfc-editor.org/rfc/rfc7519> and
//! RFC 7515 (JWS) <https://www.rfc-editor.org/rfc/rfc7515>.
23
24use crate::apple_jwt::base64_url;
25use hmac::{Hmac, Mac};
26use sha2::Sha256;
27use std::time::{SystemTime, UNIX_EPOCH};
28
29type HmacSha256 = Hmac<Sha256>;
30
31/// Standard claims pylon mints. Apps that want extra claims can
32/// extend via the Custom variant.
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct JwtClaims {
35    /// Subject — the user_id.
36    pub sub: String,
37    /// Issued at (Unix seconds).
38    pub iat: u64,
39    /// Expiry (Unix seconds). Pylon defaults to 30d for parity with
40    /// opaque sessions; apps can override.
41    pub exp: u64,
42    /// Issuer — `PYLON_JWT_ISSUER` if set, else `pylon`.
43    pub iss: String,
44    /// Optional tenant id (Pylon-specific extension claim
45    /// `https://pylonsync.com/tenant`).
46    pub tenant_id: Option<String>,
47    /// Optional roles array.
48    pub roles: Vec<String>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub enum JwtError {
53    /// Token doesn't have three `.`-separated segments.
54    Malformed,
55    /// Header / claims base64 decode failed.
56    BadEncoding,
57    /// Header alg isn't `HS256` (we only mint that).
58    UnsupportedAlg,
59    /// Signature didn't match the secret.
60    BadSignature,
61    /// `exp` is in the past.
62    Expired,
63    /// `iss` doesn't match expected issuer.
64    WrongIssuer,
65}
66
67impl std::fmt::Display for JwtError {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.write_str(match self {
70            Self::Malformed => "JWT malformed",
71            Self::BadEncoding => "JWT base64/JSON decode failed",
72            Self::UnsupportedAlg => "JWT alg not supported (expected HS256)",
73            Self::BadSignature => "JWT signature mismatch",
74            Self::Expired => "JWT expired",
75            Self::WrongIssuer => "JWT issuer mismatch",
76        })
77    }
78}
79
80/// Mint a JWT-shaped session token. The output is the
81/// `header.claims.sig` triplet, ready to be returned in
82/// `Authorization: Bearer …` form. Client doesn't need to know the
83/// difference from an opaque session token.
84///
85/// Panics in debug if `claims.exp <= claims.iat` — programmer error
86/// (the token would be instantly expired). Release builds let it
87/// through; the verifier would then reject as `Expired`.
88pub fn mint(secret: &[u8], claims: &JwtClaims) -> String {
89    debug_assert!(
90        claims.exp > claims.iat,
91        "JWT exp ({}) must be > iat ({})",
92        claims.exp,
93        claims.iat
94    );
95    let header = serde_json::json!({"alg": "HS256", "typ": "JWT"});
96    let mut claims_obj = serde_json::Map::new();
97    claims_obj.insert("sub".into(), claims.sub.clone().into());
98    claims_obj.insert("iat".into(), claims.iat.into());
99    claims_obj.insert("exp".into(), claims.exp.into());
100    claims_obj.insert("iss".into(), claims.iss.clone().into());
101    if let Some(t) = &claims.tenant_id {
102        claims_obj.insert("https://pylonsync.com/tenant".into(), t.clone().into());
103    }
104    if !claims.roles.is_empty() {
105        claims_obj.insert(
106            "https://pylonsync.com/roles".into(),
107            serde_json::Value::Array(claims.roles.iter().cloned().map(Into::into).collect()),
108        );
109    }
110    let header_b64 = base64_url(serde_json::to_vec(&header).unwrap());
111    let claims_b64 = base64_url(serde_json::to_vec(&claims_obj).unwrap());
112    let signing_input = format!("{header_b64}.{claims_b64}");
113    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
114    mac.update(signing_input.as_bytes());
115    let sig = mac.finalize().into_bytes();
116    let sig_b64 = base64_url(sig);
117    format!("{signing_input}.{sig_b64}")
118}
119
120/// Verify + decode a JWT. Checks signature, alg, expiry, and issuer
121/// (when supplied). Returns the parsed claims or a structured error.
122pub fn verify(
123    token: &str,
124    secret: &[u8],
125    expected_issuer: Option<&str>,
126) -> Result<JwtClaims, JwtError> {
127    let mut parts = token.split('.');
128    let header_b64 = parts.next().ok_or(JwtError::Malformed)?;
129    let claims_b64 = parts.next().ok_or(JwtError::Malformed)?;
130    let sig_b64 = parts.next().ok_or(JwtError::Malformed)?;
131    if parts.next().is_some() {
132        return Err(JwtError::Malformed);
133    }
134
135    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
136    let header_bytes = URL_SAFE_NO_PAD
137        .decode(header_b64)
138        .map_err(|_| JwtError::BadEncoding)?;
139    let header: serde_json::Value =
140        serde_json::from_slice(&header_bytes).map_err(|_| JwtError::BadEncoding)?;
141    if header.get("alg").and_then(|v| v.as_str()) != Some("HS256") {
142        return Err(JwtError::UnsupportedAlg);
143    }
144
145    let signing_input = format!("{header_b64}.{claims_b64}");
146    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
147    mac.update(signing_input.as_bytes());
148    let expected_sig = mac.finalize().into_bytes();
149    let provided_sig = URL_SAFE_NO_PAD
150        .decode(sig_b64)
151        .map_err(|_| JwtError::BadEncoding)?;
152    if !crate::constant_time_eq(&expected_sig, &provided_sig) {
153        return Err(JwtError::BadSignature);
154    }
155
156    let claims_bytes = URL_SAFE_NO_PAD
157        .decode(claims_b64)
158        .map_err(|_| JwtError::BadEncoding)?;
159    let claims: serde_json::Value =
160        serde_json::from_slice(&claims_bytes).map_err(|_| JwtError::BadEncoding)?;
161
162    let now = SystemTime::now()
163        .duration_since(UNIX_EPOCH)
164        .map(|d| d.as_secs())
165        .unwrap_or(0);
166    let exp = claims.get("exp").and_then(|v| v.as_u64()).unwrap_or(0);
167    if exp <= now {
168        return Err(JwtError::Expired);
169    }
170    let iss = claims
171        .get("iss")
172        .and_then(|v| v.as_str())
173        .unwrap_or_default()
174        .to_string();
175    if let Some(want) = expected_issuer {
176        if iss != want {
177            return Err(JwtError::WrongIssuer);
178        }
179    }
180
181    let sub = claims
182        .get("sub")
183        .and_then(|v| v.as_str())
184        .ok_or(JwtError::BadEncoding)?
185        .to_string();
186    let iat = claims.get("iat").and_then(|v| v.as_u64()).unwrap_or(0);
187    let tenant_id = claims
188        .get("https://pylonsync.com/tenant")
189        .and_then(|v| v.as_str())
190        .map(String::from);
191    let roles = claims
192        .get("https://pylonsync.com/roles")
193        .and_then(|v| v.as_array())
194        .map(|arr| {
195            arr.iter()
196                .filter_map(|v| v.as_str().map(String::from))
197                .collect()
198        })
199        .unwrap_or_default();
200
201    Ok(JwtClaims {
202        sub,
203        iat,
204        exp,
205        iss,
206        tenant_id,
207        roles,
208    })
209}
210
211/// Convenience: detect whether a bearer token looks like a JWT
212/// (three `.`-separated base64url segments) so the dispatcher can
213/// route between session store and JWT verifier without trying both.
214pub fn looks_like_jwt(token: &str) -> bool {
215    let mut parts = token.split('.');
216    let a = parts.next();
217    let b = parts.next();
218    let c = parts.next();
219    let extra = parts.next();
220    matches!((a, b, c, extra), (Some(a), Some(b), Some(c), None) if !a.is_empty() && !b.is_empty() && !c.is_empty())
221}
222
223#[cfg(test)]
224mod tests {
225    use super::*;
226
227    fn fixture_claims(exp_secs_from_now: i64) -> JwtClaims {
228        let now = SystemTime::now()
229            .duration_since(UNIX_EPOCH)
230            .unwrap()
231            .as_secs();
232        JwtClaims {
233            sub: "user-1".into(),
234            iat: now,
235            exp: (now as i64 + exp_secs_from_now) as u64,
236            iss: "pylon-test".into(),
237            tenant_id: None,
238            roles: vec![],
239        }
240    }
241
242    #[test]
243    fn round_trip_minimal_claims() {
244        let secret = b"super-secret-pylon-key";
245        let claims = fixture_claims(3600);
246        let token = mint(secret, &claims);
247        let decoded = verify(&token, secret, Some("pylon-test")).unwrap();
248        assert_eq!(decoded.sub, "user-1");
249        assert_eq!(decoded.iss, "pylon-test");
250    }
251
252    #[test]
253    fn round_trip_with_tenant_and_roles() {
254        let secret = b"k";
255        let mut claims = fixture_claims(3600);
256        claims.tenant_id = Some("acme".into());
257        claims.roles = vec!["admin".into(), "billing".into()];
258        let token = mint(secret, &claims);
259        let decoded = verify(&token, secret, None).unwrap();
260        assert_eq!(decoded.tenant_id.as_deref(), Some("acme"));
261        assert_eq!(decoded.roles, vec!["admin", "billing"]);
262    }
263
264    #[test]
265    fn expired_token_rejected() {
266        let secret = b"k";
267        // Use a clock-drift scenario: token minted with future iat but
268        // also future-then-now-then-past exp. We mint a token far in
269        // the future, then verify after the OS clock has moved past
270        // exp. Easier: hand-craft the encoded JWT directly to bypass
271        // mint's debug_assert.
272        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
273        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"HS256","typ":"JWT"}"#);
274        let claims =
275            URL_SAFE_NO_PAD.encode(br#"{"sub":"user-1","iat":1,"exp":2,"iss":"pylon-test"}"#);
276        let signing_input = format!("{header}.{claims}");
277        use hmac::{Hmac, Mac};
278        use sha2::Sha256;
279        let mut mac = Hmac::<Sha256>::new_from_slice(secret).unwrap();
280        mac.update(signing_input.as_bytes());
281        let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
282        let token = format!("{signing_input}.{sig}");
283        assert_eq!(verify(&token, secret, None), Err(JwtError::Expired));
284    }
285
286    #[test]
287    #[should_panic(expected = "JWT exp")]
288    #[cfg(debug_assertions)]
289    fn mint_panics_on_exp_le_iat_in_debug() {
290        let secret = b"k";
291        let mut claims = fixture_claims(0);
292        claims.exp = claims.iat;
293        let _ = mint(secret, &claims);
294    }
295
296    #[test]
297    fn wrong_secret_rejected() {
298        let secret = b"k";
299        let claims = fixture_claims(3600);
300        let token = mint(secret, &claims);
301        assert_eq!(
302            verify(&token, b"different-secret", None),
303            Err(JwtError::BadSignature)
304        );
305    }
306
307    #[test]
308    fn wrong_issuer_rejected() {
309        let secret = b"k";
310        let claims = fixture_claims(3600);
311        let token = mint(secret, &claims);
312        assert_eq!(
313            verify(&token, secret, Some("different-issuer")),
314            Err(JwtError::WrongIssuer)
315        );
316    }
317
318    #[test]
319    fn alg_none_rejected() {
320        // Critical security check — RFC 7519 famously had the "alg:none"
321        // bypass class. Hand-craft a token with `alg: none` and assert
322        // verify rejects it.
323        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
324        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none","typ":"JWT"}"#);
325        let claims = URL_SAFE_NO_PAD.encode(br#"{"sub":"attacker","exp":99999999999}"#);
326        let token = format!("{header}.{claims}.");
327        let result = verify(&token, b"any-secret", None);
328        assert_eq!(result, Err(JwtError::UnsupportedAlg));
329    }
330
331    #[test]
332    fn malformed_token_rejected() {
333        assert_eq!(
334            verify("not.a.jwt.too-many-parts", b"k", None),
335            Err(JwtError::Malformed)
336        );
337        assert_eq!(
338            verify("only-one-part", b"k", None),
339            Err(JwtError::Malformed)
340        );
341        assert_eq!(verify("", b"k", None), Err(JwtError::Malformed));
342    }
343
344    #[test]
345    fn looks_like_jwt_classifies() {
346        assert!(looks_like_jwt("aaa.bbb.ccc"));
347        assert!(!looks_like_jwt("pylon_abcdef"));
348        assert!(!looks_like_jwt("aaa.bbb"));
349        assert!(!looks_like_jwt(""));
350        assert!(!looks_like_jwt("aaa..ccc"));
351        // NOTE: `pk.key_abc.secret` has three nonempty segments and
352        // would superficially look like a JWT — that's why the
353        // dispatcher in server.rs MUST check the `pk.` prefix
354        // BEFORE looks_like_jwt. Documented for whoever changes that
355        // dispatcher next.
356        assert!(looks_like_jwt("pk.key_abc.secret"));
357    }
358
359    /// Codex Wave-5 P0-3 regression. The auth-token dispatcher in
360    /// server.rs uses `t.starts_with("pk.")` BEFORE `looks_like_jwt`
361    /// because `pk.…` tokens have three dot-separated nonempty
362    /// segments and would otherwise fall through to JWT verify.
363    /// This test pins the contract at the predicate level: any
364    /// `pk.…` token must classify as both an api-key AND a JWT
365    /// shape, so the dispatcher MUST disambiguate via the prefix.
366    #[test]
367    fn pk_token_overlaps_jwt_shape_dispatcher_must_check_prefix_first() {
368        let pk_like =
369            "pk.key_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
370        assert!(pk_like.starts_with("pk."));
371        assert!(looks_like_jwt(pk_like));
372    }
373}