Skip to main content

pylon_auth/
jwt.rs

1//! Stateless JWT sessions — alternative to opaque session tokens.
2//!
3//! By default Pylon mints opaque random `pylon_…` tokens that must
4//! be looked up in the session store on every request. For deploys
5//! that can't tolerate that round-trip (edge runtimes, CDN-backed
6//! routes, multi-region read replicas), Pylon can mint **JWT-shaped**
7//! sessions instead — verified by the local secret with no DB hit.
8//!
9//! Trade-offs:
10//!   - **Pro**: stateless verification (no DB read on every request)
11//!   - **Pro**: clients can decode their own claims (without verifying)
12//!     for UI personalization without a `/me` round-trip
13//!   - **Con**: revocation requires either a denylist or a short TTL —
14//!     a leaked JWT stays valid until its `exp`
15//!   - **Con**: secret rotation needs both old + new keys to coexist
16//!     for at least one session lifetime
17//!
18//! Pylon uses HS256 (HMAC-SHA256) — symmetric, no key distribution.
19//! Apps that need RS256 / asymmetric verification across services
20//! should use the OIDC discovery / JWKS path on Wave 5.
21//!
22//! Spec: <https://www.rfc-editor.org/rfc/rfc7519> + RFC 7515 (JWS).
23
24use crate::apple_jwt::base64_url;
25use hmac::{Hmac, Mac};
26use sha2::Sha256;
27use std::time::{SystemTime, UNIX_EPOCH};
28
29type HmacSha256 = Hmac<Sha256>;
30
/// The claim set carried by a Pylon-minted session JWT.
///
/// `sub`, `iat`, `exp` and `iss` are the registered RFC 7519 claims;
/// `tenant_id` and `roles` are Pylon extension claims stored under
/// namespaced keys. The set of fields is fixed: this struct has no
/// slot for additional app-defined claims.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct JwtClaims {
    /// `sub` — the user_id the session belongs to.
    pub sub: String,
    /// `iat` — issue time, in Unix seconds.
    pub iat: u64,
    /// `exp` — expiry time, in Unix seconds. Pylon defaults to 30d
    /// for parity with opaque sessions; apps can override.
    pub exp: u64,
    /// `iss` — `PYLON_JWT_ISSUER` if set, else `pylon`.
    pub iss: String,
    /// Tenant id, when the session is tenant-scoped (extension claim
    /// `https://pylonsync.com/tenant`).
    pub tenant_id: Option<String>,
    /// Role names (extension claim `https://pylonsync.com/roles`).
    /// May be empty.
    pub roles: Vec<String>,
}
50
/// Reasons a token can be rejected during verification.
///
/// Implements [`std::fmt::Display`] with a short human-readable
/// message per variant and [`std::error::Error`], so it can be boxed
/// as `dyn Error` or propagated with `?` into error types that accept
/// any standard error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtError {
    /// Token doesn't have three `.`-separated segments.
    Malformed,
    /// A segment failed base64url decoding, or the decoded header /
    /// claims were not usable JSON.
    BadEncoding,
    /// Header alg isn't `HS256` (we only mint that).
    UnsupportedAlg,
    /// Signature didn't match the secret.
    BadSignature,
    /// `exp` is in the past.
    Expired,
    /// `iss` doesn't match expected issuer.
    WrongIssuer,
}

impl std::fmt::Display for JwtError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message here.
        let message = match self {
            Self::Malformed => "JWT malformed",
            Self::BadEncoding => "JWT base64/JSON decode failed",
            Self::UnsupportedAlg => "JWT alg not supported (expected HS256)",
            Self::BadSignature => "JWT signature mismatch",
            Self::Expired => "JWT expired",
            Self::WrongIssuer => "JWT issuer mismatch",
        };
        f.write_str(message)
    }
}

// No underlying cause to expose, so the default `source()` (None) is right.
impl std::error::Error for JwtError {}
79
80/// Mint a JWT-shaped session token. The output is the
81/// `header.claims.sig` triplet, ready to be returned in
82/// `Authorization: Bearer …` form. Client doesn't need to know the
83/// difference from an opaque session token.
84///
85/// Panics in debug if `claims.exp <= claims.iat` — programmer error
86/// (the token would be instantly expired). Release builds let it
87/// through; the verifier would then reject as `Expired`.
88pub fn mint(secret: &[u8], claims: &JwtClaims) -> String {
89    debug_assert!(
90        claims.exp > claims.iat,
91        "JWT exp ({}) must be > iat ({})",
92        claims.exp,
93        claims.iat
94    );
95    let header = serde_json::json!({"alg": "HS256", "typ": "JWT"});
96    let mut claims_obj = serde_json::Map::new();
97    claims_obj.insert("sub".into(), claims.sub.clone().into());
98    claims_obj.insert("iat".into(), claims.iat.into());
99    claims_obj.insert("exp".into(), claims.exp.into());
100    claims_obj.insert("iss".into(), claims.iss.clone().into());
101    if let Some(t) = &claims.tenant_id {
102        claims_obj.insert("https://pylonsync.com/tenant".into(), t.clone().into());
103    }
104    if !claims.roles.is_empty() {
105        claims_obj.insert(
106            "https://pylonsync.com/roles".into(),
107            serde_json::Value::Array(
108                claims.roles.iter().cloned().map(Into::into).collect(),
109            ),
110        );
111    }
112    let header_b64 = base64_url(serde_json::to_vec(&header).unwrap());
113    let claims_b64 = base64_url(serde_json::to_vec(&claims_obj).unwrap());
114    let signing_input = format!("{header_b64}.{claims_b64}");
115    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
116    mac.update(signing_input.as_bytes());
117    let sig = mac.finalize().into_bytes();
118    let sig_b64 = base64_url(sig);
119    format!("{signing_input}.{sig_b64}")
120}
121
122/// Verify + decode a JWT. Checks signature, alg, expiry, and issuer
123/// (when supplied). Returns the parsed claims or a structured error.
124pub fn verify(token: &str, secret: &[u8], expected_issuer: Option<&str>) -> Result<JwtClaims, JwtError> {
125    let mut parts = token.split('.');
126    let header_b64 = parts.next().ok_or(JwtError::Malformed)?;
127    let claims_b64 = parts.next().ok_or(JwtError::Malformed)?;
128    let sig_b64 = parts.next().ok_or(JwtError::Malformed)?;
129    if parts.next().is_some() {
130        return Err(JwtError::Malformed);
131    }
132
133    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
134    let header_bytes = URL_SAFE_NO_PAD
135        .decode(header_b64)
136        .map_err(|_| JwtError::BadEncoding)?;
137    let header: serde_json::Value =
138        serde_json::from_slice(&header_bytes).map_err(|_| JwtError::BadEncoding)?;
139    if header.get("alg").and_then(|v| v.as_str()) != Some("HS256") {
140        return Err(JwtError::UnsupportedAlg);
141    }
142
143    let signing_input = format!("{header_b64}.{claims_b64}");
144    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
145    mac.update(signing_input.as_bytes());
146    let expected_sig = mac.finalize().into_bytes();
147    let provided_sig = URL_SAFE_NO_PAD
148        .decode(sig_b64)
149        .map_err(|_| JwtError::BadEncoding)?;
150    if !crate::constant_time_eq(&expected_sig, &provided_sig) {
151        return Err(JwtError::BadSignature);
152    }
153
154    let claims_bytes = URL_SAFE_NO_PAD
155        .decode(claims_b64)
156        .map_err(|_| JwtError::BadEncoding)?;
157    let claims: serde_json::Value =
158        serde_json::from_slice(&claims_bytes).map_err(|_| JwtError::BadEncoding)?;
159
160    let now = SystemTime::now()
161        .duration_since(UNIX_EPOCH)
162        .map(|d| d.as_secs())
163        .unwrap_or(0);
164    let exp = claims.get("exp").and_then(|v| v.as_u64()).unwrap_or(0);
165    if exp <= now {
166        return Err(JwtError::Expired);
167    }
168    let iss = claims
169        .get("iss")
170        .and_then(|v| v.as_str())
171        .unwrap_or_default()
172        .to_string();
173    if let Some(want) = expected_issuer {
174        if iss != want {
175            return Err(JwtError::WrongIssuer);
176        }
177    }
178
179    let sub = claims
180        .get("sub")
181        .and_then(|v| v.as_str())
182        .ok_or(JwtError::BadEncoding)?
183        .to_string();
184    let iat = claims.get("iat").and_then(|v| v.as_u64()).unwrap_or(0);
185    let tenant_id = claims
186        .get("https://pylonsync.com/tenant")
187        .and_then(|v| v.as_str())
188        .map(String::from);
189    let roles = claims
190        .get("https://pylonsync.com/roles")
191        .and_then(|v| v.as_array())
192        .map(|arr| {
193            arr.iter()
194                .filter_map(|v| v.as_str().map(String::from))
195                .collect()
196        })
197        .unwrap_or_default();
198
199    Ok(JwtClaims {
200        sub,
201        iat,
202        exp,
203        iss,
204        tenant_id,
205        roles,
206    })
207}
208
/// Cheap shape test: does `token` consist of exactly three non-empty
/// `.`-separated segments?
///
/// Lets the dispatcher choose between the session store and the JWT
/// verifier without trying both. The check is purely structural:
/// segment contents are neither decoded nor validated, so any other
/// three-part dotted token (for example `pk.key_abc.secret`) also
/// returns `true` and must be ruled out by the caller first.
pub fn looks_like_jwt(token: &str) -> bool {
    let mut seen = 0usize;
    for segment in token.split('.') {
        // An empty segment, or a fourth segment, disqualifies the token.
        if segment.is_empty() || seen == 3 {
            return false;
        }
        seen += 1;
    }
    seen == 3
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    /// Claims for `user-1` issued now by `pylon-test`, expiring
    /// `offset_secs` seconds from now.
    fn fixture_claims(offset_secs: i64) -> JwtClaims {
        let issued_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let expires_at = (issued_at as i64 + offset_secs) as u64;
        JwtClaims {
            sub: "user-1".into(),
            iat: issued_at,
            exp: expires_at,
            iss: "pylon-test".into(),
            tenant_id: None,
            roles: vec![],
        }
    }

    /// Sign raw header / claims JSON with HS256 without going through
    /// `mint`, so tests can fabricate arbitrary payloads.
    fn hand_signed(secret: &[u8], header_json: &[u8], claims_json: &[u8]) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        let header = URL_SAFE_NO_PAD.encode(header_json);
        let claims = URL_SAFE_NO_PAD.encode(claims_json);
        let signing_input = format!("{header}.{claims}");
        let mut mac = Hmac::<Sha256>::new_from_slice(secret).unwrap();
        mac.update(signing_input.as_bytes());
        let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
        format!("{signing_input}.{sig}")
    }

    #[test]
    fn round_trip_minimal_claims() {
        let secret = b"super-secret-pylon-key";
        let token = mint(secret, &fixture_claims(3600));

        let decoded = verify(&token, secret, Some("pylon-test")).unwrap();

        assert_eq!(decoded.sub, "user-1");
        assert_eq!(decoded.iss, "pylon-test");
    }

    #[test]
    fn round_trip_with_tenant_and_roles() {
        let secret = b"k";
        let claims = JwtClaims {
            tenant_id: Some("acme".into()),
            roles: vec!["admin".into(), "billing".into()],
            ..fixture_claims(3600)
        };

        let token = mint(secret, &claims);
        let decoded = verify(&token, secret, None).unwrap();

        assert_eq!(decoded.tenant_id.as_deref(), Some("acme"));
        assert_eq!(decoded.roles, vec!["admin", "billing"]);
    }

    #[test]
    fn expired_token_rejected() {
        let secret = b"k";
        // Hand-sign a fixed payload whose `exp` (2) is long past, so the
        // signature is valid and only the expiry check can fail.
        let token = hand_signed(
            secret,
            br#"{"alg":"HS256","typ":"JWT"}"#,
            br#"{"sub":"user-1","iat":1,"exp":2,"iss":"pylon-test"}"#,
        );
        assert_eq!(verify(&token, secret, None), Err(JwtError::Expired));
    }

    #[test]
    #[should_panic(expected = "JWT exp")]
    #[cfg(debug_assertions)]
    fn mint_panics_on_exp_le_iat_in_debug() {
        let secret = b"k";
        let base = fixture_claims(0);
        let claims = JwtClaims {
            exp: base.iat,
            ..base
        };
        let _ = mint(secret, &claims);
    }

    #[test]
    fn wrong_secret_rejected() {
        let token = mint(b"k", &fixture_claims(3600));
        let outcome = verify(&token, b"different-secret", None);
        assert_eq!(outcome, Err(JwtError::BadSignature));
    }

    #[test]
    fn wrong_issuer_rejected() {
        let secret = b"k";
        let token = mint(secret, &fixture_claims(3600));
        let outcome = verify(&token, secret, Some("different-issuer"));
        assert_eq!(outcome, Err(JwtError::WrongIssuer));
    }

    #[test]
    fn alg_none_rejected() {
        // Critical security check: the "alg: none" bypass. An unsigned
        // token that declares `alg: none` must be refused on the header
        // alone, whatever its claims say.
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none","typ":"JWT"}"#);
        let claims = URL_SAFE_NO_PAD.encode(br#"{"sub":"attacker","exp":99999999999}"#);
        let token = format!("{header}.{claims}.");

        assert_eq!(
            verify(&token, b"any-secret", None),
            Err(JwtError::UnsupportedAlg)
        );
    }

    #[test]
    fn malformed_token_rejected() {
        for bad in ["not.a.jwt.too-many-parts", "only-one-part", ""] {
            assert_eq!(verify(bad, b"k", None), Err(JwtError::Malformed));
        }
    }

    #[test]
    fn looks_like_jwt_classifies() {
        assert!(looks_like_jwt("aaa.bbb.ccc"));
        for not_jwt in ["pylon_abcdef", "aaa.bbb", "", "aaa..ccc"] {
            assert!(!looks_like_jwt(not_jwt));
        }
        // NOTE: `pk.key_abc.secret` has three non-empty segments and so
        // passes the shape test. That is why the dispatcher in server.rs
        // MUST check the `pk.` prefix BEFORE calling looks_like_jwt.
        // Kept here as a warning to whoever touches that dispatcher next.
        assert!(looks_like_jwt("pk.key_abc.secret"));
    }

    /// Codex Wave-5 P0-3 regression. `pk.…` API-key tokens have three
    /// non-empty dot-separated segments, so they satisfy
    /// `looks_like_jwt` as well. The auth-token dispatcher in server.rs
    /// therefore tests `t.starts_with("pk.")` first; otherwise such
    /// tokens would fall through to JWT verification. This test pins
    /// the overlap at the predicate level so the ordering requirement
    /// stays visible.
    #[test]
    fn pk_token_overlaps_jwt_shape_dispatcher_must_check_prefix_first() {
        let pk_like = "pk.key_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        assert!(pk_like.starts_with("pk."));
        assert!(looks_like_jwt(pk_like));
    }
}