Skip to main content

pylon_plugin/builtin/
jwt.rs

1use std::collections::HashSet;
2use std::sync::Mutex;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use hmac::{Hmac, Mac};
6use sha2::Sha256;
7
8use crate::Plugin;
9
/// HMAC keyed hash instantiated with SHA-256; the MAC used for the `HS256`
/// tokens this module signs and verifies.
type HmacSha256 = Hmac<Sha256>;
11
/// HS256 (HMAC-SHA256) JSON Web Token issuer and verifier.
///
/// Signing uses real cryptographic primitives from the `hmac` and `sha2`
/// crates, but claim encoding and parsing are hand-rolled and minimal. For
/// production, consider a full JWT library.
pub struct JwtPlugin {
    /// Lifetime, in seconds, of the access tokens produced by `issue`.
    expiry_secs: u64,
    /// HMAC key used both to sign and to verify tokens.
    secret: String,
    /// Refresh tokens that have already been exchanged; each may be redeemed once.
    used_refresh_tokens: Mutex<HashSet<String>>,
}
22
/// Claims decoded from a verified JWT payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Claims {
    /// Subject: the user ID the token was issued for.
    pub sub: String,
    /// Issued-at time, in seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in seconds since the Unix epoch.
    pub exp: u64,
    /// Token kind, `"access"` or `"refresh"`; `None` for legacy tokens minted
    /// before the kind field existed.
    pub kind: Option<String>,
}
33
/// Access and refresh tokens minted together for the same user.
#[derive(Clone, Debug)]
pub struct TokenPair {
    /// Short-lived token carrying `kind = "access"`.
    pub access_token: String,
    /// Longer-lived token carrying `kind = "refresh"`, redeemable via `JwtPlugin::refresh`.
    pub refresh_token: String,
    /// Lifetime of `access_token`, in seconds.
    pub access_expires_in: u64,
    /// Lifetime of `refresh_token`, in seconds.
    pub refresh_expires_in: u64,
}
42
43impl JwtPlugin {
44    pub fn new(secret: &str, expiry_secs: u64) -> Self {
45        Self {
46            secret: secret.to_string(),
47            expiry_secs,
48            used_refresh_tokens: Mutex::new(HashSet::new()),
49        }
50    }
51
52    /// Issue a short-lived access JWT for a user ID.
53    pub fn issue(&self, user_id: &str) -> String {
54        self.issue_with_kind(user_id, "access", self.expiry_secs)
55    }
56
57    /// Issue a JWT with an explicit kind and expiry.
58    pub fn issue_with_kind(&self, user_id: &str, kind: &str, expiry_secs: u64) -> String {
59        let now = SystemTime::now()
60            .duration_since(UNIX_EPOCH)
61            .unwrap_or_default()
62            .as_secs();
63
64        let header = base64url_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
65        let payload = base64url_encode(
66            format!(
67                "{{\"sub\":\"{}\",\"iat\":{},\"exp\":{},\"kind\":\"{}\"}}",
68                user_id,
69                now,
70                now + expiry_secs,
71                kind,
72            )
73            .as_bytes(),
74        );
75
76        let signing_input = format!("{header}.{payload}");
77        let signature = base64url_encode(&hmac_sha256(&self.secret, &signing_input));
78
79        format!("{signing_input}.{signature}")
80    }
81
82    /// Issue a token pair: a short-lived access token and a long-lived refresh
83    /// token. The access token uses the plugin's configured expiry; the refresh
84    /// token uses the provided `refresh_expiry_secs`.
85    pub fn issue_pair(&self, user_id: &str, refresh_expiry_secs: u64) -> TokenPair {
86        let access_token = self.issue(user_id);
87        let refresh_token = self.issue_with_kind(user_id, "refresh", refresh_expiry_secs);
88        TokenPair {
89            access_token,
90            refresh_token,
91            access_expires_in: self.expiry_secs,
92            refresh_expires_in: refresh_expiry_secs,
93        }
94    }
95
96    /// Consume a refresh token and issue a new token pair.
97    ///
98    /// Order of operations matters for security:
99    ///   1. Cryptographically verify the token FIRST. If we inserted into the
100    ///      replay cache before verification, an attacker could pollute the
101    ///      cache by posting random garbage, growing it unbounded. Worse, a
102    ///      real token presented alongside that garbage would get "burned"
103    ///      before we knew whether it was even valid.
104    ///   2. Then check the replay cache and atomically insert.
105    ///
106    /// The window between `verify()` and `insert()` is a TOCTOU where two
107    /// concurrent refreshes of the same token could both succeed. The Mutex
108    /// around `used_refresh_tokens` is the serialization point — the check +
109    /// insert happens under the same lock.
110    pub fn refresh(&self, refresh_token: &str) -> Result<TokenPair, String> {
111        let claims = self.verify(refresh_token)?;
112
113        match claims.kind.as_deref() {
114            Some("refresh") => {}
115            _ => return Err("Token is not a refresh token".into()),
116        }
117
118        {
119            let mut used = self
120                .used_refresh_tokens
121                .lock()
122                .map_err(|_| "Lock poisoned")?;
123            if used.contains(refresh_token) {
124                return Err("Refresh token already used".into());
125            }
126            used.insert(refresh_token.to_string());
127        }
128
129        Ok(self.issue_pair(&claims.sub, 86400 * 7))
130    }
131
132    /// Verify and decode a JWT. Returns claims if valid and not expired.
133    /// Uses constant-time comparison for the signature to prevent timing attacks.
134    pub fn verify(&self, token: &str) -> Result<Claims, String> {
135        let parts: Vec<&str> = token.split('.').collect();
136        if parts.len() != 3 {
137            return Err("Invalid JWT format".into());
138        }
139
140        let signing_input = format!("{}.{}", parts[0], parts[1]);
141        let expected_sig = base64url_encode(&hmac_sha256(&self.secret, &signing_input));
142
143        if !pylon_auth::constant_time_eq(parts[2].as_bytes(), expected_sig.as_bytes()) {
144            return Err("Invalid signature".into());
145        }
146
147        let payload_bytes = base64url_decode(parts[1])?;
148        let payload_str = String::from_utf8(payload_bytes).map_err(|_| "Invalid payload")?;
149
150        // Parse claims manually (no serde dependency in this minimal impl).
151        let sub = extract_json_string(&payload_str, "sub").ok_or("Missing sub claim")?;
152        let iat = extract_json_number(&payload_str, "iat").ok_or("Missing iat claim")?;
153        let exp = extract_json_number(&payload_str, "exp").ok_or("Missing exp claim")?;
154        let kind = extract_json_string(&payload_str, "kind");
155
156        let now = SystemTime::now()
157            .duration_since(UNIX_EPOCH)
158            .unwrap_or_default()
159            .as_secs();
160
161        if now > exp {
162            return Err("Token expired".into());
163        }
164
165        Ok(Claims {
166            sub,
167            iat,
168            exp,
169            kind,
170        })
171    }
172
173    /// Resolve a JWT to a user ID. Returns None if invalid.
174    pub fn resolve_user(&self, token: &str) -> Option<String> {
175        self.verify(token).ok().map(|c| c.sub)
176    }
177}
178
/// Exposes `JwtPlugin` through the crate's `Plugin` trait.
impl Plugin for JwtPlugin {
    /// Identifier for this plugin: always `"jwt"`.
    fn name(&self) -> &str {
        "jwt"
    }
}
184
185// -- Minimal base64url encoding/decoding --
186
/// Encode bytes as unpadded base64url (RFC 4648 §5 alphabet, no `=`).
///
/// Each group of `k` input bytes (k = 1..=3) produces `k + 1` characters.
fn base64url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    const SHIFTS: [u32; 4] = [18, 12, 6, 0];

    let mut encoded = String::new();
    for group in data.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits; missing
        // trailing bytes stay zero.
        let packed = group
            .iter()
            .zip([16u32, 8, 0])
            .fold(0u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));
        for &shift in &SHIFTS[..=group.len()] {
            encoded.push(char::from(ALPHABET[((packed >> shift) & 0x3F) as usize]));
        }
    }
    encoded
}
206
/// Decode unpadded base64url (RFC 4648 §5 alphabet) into bytes.
///
/// # Errors
/// Returns an error for any character outside the base64url alphabet
/// (including `=` padding). Also rejects inputs whose length is one more than
/// a multiple of 4: a lone trailing character holds only 6 bits and cannot
/// encode a byte. The previous loop silently fabricated one from it.
fn base64url_decode(data: &str) -> Result<Vec<u8>, String> {
    /// Map one alphabet character to its 6-bit value.
    fn val(c: u8) -> Result<u8, String> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'-' => Ok(62),
            b'_' => Ok(63),
            _ => Err(format!("Invalid base64url character: {}", c as char)),
        }
    }

    let bytes = data.as_bytes();
    if bytes.len() % 4 == 1 {
        return Err("Invalid base64url length".into());
    }

    let mut out = Vec::with_capacity(bytes.len() / 4 * 3 + 2);
    for quad in bytes.chunks(4) {
        // Pack up to four sextets big-endian into the low 24 bits.
        let mut n = 0u32;
        for (i, &c) in quad.iter().enumerate() {
            n |= u32::from(val(c)?) << (18 - 6 * i);
        }
        // k characters (k = 2..=4 after the length check) carry k - 1 bytes;
        // byte 0 of the big-endian u32 is always zero.
        out.extend_from_slice(&n.to_be_bytes()[1..quad.len()]);
    }
    Ok(out)
}
252
253// -- HMAC-SHA256 signing --
254
255fn hmac_sha256(key: &str, data: &str) -> Vec<u8> {
256    let mut mac =
257        HmacSha256::new_from_slice(key.as_bytes()).expect("HMAC can take key of any size");
258    mac.update(data.as_bytes());
259    mac.finalize().into_bytes().to_vec()
260}
261
/// Return the raw text of the string value stored under `key` in a compact
/// JSON object (no whitespace around the colon).
///
/// Locates the first `"key":"` and returns everything up to the next `"`.
/// Escape sequences are returned verbatim, not decoded, and an escaped `\"`
/// would end the value early. Returns `None` when the key is absent or the
/// value is unterminated.
fn extract_json_string(json: &str, key: &str) -> Option<String> {
    let needle = format!("\"{key}\":\"");
    let (_, tail) = json.split_once(needle.as_str())?;
    let (value, _) = tail.split_once('"')?;
    Some(value.to_owned())
}
269
/// Read the unsigned integer stored under `key` in a compact JSON object.
///
/// Only the run of ASCII digits immediately following `"key":` is parsed, so
/// no whitespace or sign is accepted. Returns `None` if the key is absent, no
/// digits follow it, or the number overflows `u64`.
fn extract_json_number(json: &str, key: &str) -> Option<u64> {
    let needle = format!("\"{key}\":");
    let (_, after) = json.split_once(needle.as_str())?;
    // ASCII digits are one byte each, so the byte count is also the slice end.
    let digit_count = after.bytes().take_while(u8::is_ascii_digit).count();
    after[..digit_count].parse().ok()
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Sign an arbitrary JSON payload with `secret` exactly as `issue` would,
    /// so tests can craft tokens with chosen claims (e.g. an `exp` in the past).
    fn sign_raw(secret: &str, payload_json: &str) -> String {
        let header = base64url_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
        let payload = base64url_encode(payload_json.as_bytes());
        let signing_input = format!("{header}.{payload}");
        let signature = base64url_encode(&hmac_sha256(secret, &signing_input));
        format!("{signing_input}.{signature}")
    }

    #[test]
    fn issue_and_verify() {
        let jwt = JwtPlugin::new("test-secret", 3600);
        let token = jwt.issue("user-1");

        assert!(!token.is_empty());
        assert_eq!(token.split('.').count(), 3);

        let claims = jwt.verify(&token).unwrap();
        assert_eq!(claims.sub, "user-1");
        assert!(claims.exp > claims.iat);
        assert_eq!(claims.kind, Some("access".into()));
    }

    #[test]
    fn wrong_secret_fails() {
        let jwt1 = JwtPlugin::new("secret-1", 3600);
        let jwt2 = JwtPlugin::new("secret-2", 3600);

        let token = jwt1.issue("user-1");
        let result = jwt2.verify(&token);
        assert!(result.is_err());
    }

    #[test]
    fn expired_token_rejected() {
        let jwt = JwtPlugin::new("secret", 3600);
        // A correctly signed token whose exp (1) is decades in the past, so the
        // assertion does not depend on wall-clock timing or sleeping.
        let token = sign_raw(
            "secret",
            "{\"sub\":\"user-1\",\"iat\":0,\"exp\":1,\"kind\":\"access\"}",
        );

        let err = jwt.verify(&token).unwrap_err();
        assert!(err.contains("expired"));
    }

    #[test]
    fn tampered_payload_rejected() {
        let jwt = JwtPlugin::new("secret", 3600);
        let token = jwt.issue("user-1");
        let parts: Vec<&str> = token.split('.').collect();

        // Swap in a forged payload while keeping the original signature.
        let forged_payload = base64url_encode(
            b"{\"sub\":\"admin\",\"iat\":0,\"exp\":99999999999,\"kind\":\"access\"}",
        );
        let forged = format!("{}.{}.{}", parts[0], forged_payload, parts[2]);

        let err = jwt.verify(&forged).unwrap_err();
        assert!(err.contains("signature"));
    }

    #[test]
    fn invalid_format_rejected() {
        let jwt = JwtPlugin::new("secret", 3600);
        assert!(jwt.verify("not.a.jwt.token").is_err());
        assert!(jwt.verify("invalid").is_err());
        assert!(jwt.verify("").is_err());
    }

    #[test]
    fn resolve_user() {
        let jwt = JwtPlugin::new("secret", 3600);
        let token = jwt.issue("alice");

        assert_eq!(jwt.resolve_user(&token), Some("alice".into()));
        assert_eq!(jwt.resolve_user("invalid"), None);
    }

    #[test]
    fn different_users_different_tokens() {
        let jwt = JwtPlugin::new("secret", 3600);
        let t1 = jwt.issue("user-1");
        let t2 = jwt.issue("user-2");
        assert_ne!(t1, t2);
    }

    #[test]
    fn base64url_round_trip() {
        for len in 0u8..=7 {
            let data: Vec<u8> = (0..len)
                .map(|i| i.wrapping_mul(37).wrapping_add(250))
                .collect();
            let encoded = base64url_encode(&data);
            assert!(!encoded.contains('='));
            assert_eq!(base64url_decode(&encoded).unwrap(), data);
        }
    }

    #[test]
    fn hmac_sha256_produces_32_bytes() {
        let sig = hmac_sha256("key", "data");
        assert_eq!(sig.len(), 32);
    }

    #[test]
    fn hmac_sha256_different_keys_different_output() {
        let s1 = hmac_sha256("key1", "data");
        let s2 = hmac_sha256("key2", "data");
        assert_ne!(s1, s2);
    }

    #[test]
    fn hmac_sha256_different_data_different_output() {
        let s1 = hmac_sha256("key", "data1");
        let s2 = hmac_sha256("key", "data2");
        assert_ne!(s1, s2);
    }

    // -- Token pair tests --

    #[test]
    fn issue_pair_creates_two_distinct_tokens() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        assert_ne!(pair.access_token, pair.refresh_token);
        assert_eq!(pair.access_expires_in, 300);
        assert_eq!(pair.refresh_expires_in, 86400 * 7);

        let access_claims = jwt.verify(&pair.access_token).unwrap();
        assert_eq!(access_claims.sub, "user-1");
        assert_eq!(access_claims.kind, Some("access".into()));

        let refresh_claims = jwt.verify(&pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.sub, "user-1");
        assert_eq!(refresh_claims.kind, Some("refresh".into()));
    }

    #[test]
    fn refresh_returns_new_pair() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        let new_pair = jwt.refresh(&pair.refresh_token).unwrap();

        // The new pair should contain valid tokens for the same user.
        let access_claims = jwt.verify(&new_pair.access_token).unwrap();
        assert_eq!(access_claims.sub, "user-1");
        assert_eq!(access_claims.kind, Some("access".into()));

        let refresh_claims = jwt.verify(&new_pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.sub, "user-1");
        assert_eq!(refresh_claims.kind, Some("refresh".into()));

        // The old refresh token must now be rejected (one-time use).
        let err = jwt.refresh(&pair.refresh_token).unwrap_err();
        assert!(err.contains("already used"));
    }

    #[test]
    fn used_refresh_token_rejected() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        // First use succeeds.
        assert!(jwt.refresh(&pair.refresh_token).is_ok());

        // Second use is rejected (replay protection).
        let err = jwt.refresh(&pair.refresh_token).unwrap_err();
        assert!(err.contains("already used"));
    }

    #[test]
    fn access_token_cannot_be_used_as_refresh() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        let err = jwt.refresh(&pair.access_token).unwrap_err();
        assert!(err.contains("not a refresh token"));
    }

    #[test]
    fn issue_with_kind_sets_kind_field() {
        let jwt = JwtPlugin::new("secret", 3600);
        let token = jwt.issue_with_kind("user-1", "refresh", 86400);
        let claims = jwt.verify(&token).unwrap();
        assert_eq!(claims.kind, Some("refresh".into()));
    }
}