1use std::collections::HashSet;
2use std::sync::Mutex;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use hmac::{Hmac, Mac};
6use sha2::Sha256;
7
8use crate::Plugin;
9
/// HMAC instantiated with SHA-256; `hmac_sha256` uses it to compute token signatures.
type HmacSha256 = Hmac<Sha256>;
11
/// Issues and verifies HS256-signed JWTs and tracks redeemed refresh tokens.
pub struct JwtPlugin {
    /// Shared secret used as the HMAC key when signing and verifying tokens.
    secret: String,
    /// Lifetime, in seconds, of tokens created by `issue`.
    expiry_secs: u64,
    /// Refresh tokens that `refresh` has already accepted once. Held in memory
    /// behind a mutex; nothing in this file persists it.
    used_refresh_tokens: Mutex<HashSet<String>>,
}
22
/// Claims read from the payload of a token that passed `JwtPlugin::verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Claims {
    /// Subject: the user id the token was issued for.
    pub sub: String,
    /// Issued-at time, in seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in seconds since the Unix epoch.
    pub exp: u64,
    /// Token kind. Tokens minted in this file carry `"access"` or `"refresh"`;
    /// `None` when the payload has no `kind` string claim.
    pub kind: Option<String>,
}
33
/// An access token and a refresh token issued together, with their lifetimes.
#[derive(Debug, Clone)]
pub struct TokenPair {
    /// Signed token of kind `"access"`.
    pub access_token: String,
    /// Signed token of kind `"refresh"`.
    pub refresh_token: String,
    /// Lifetime of `access_token`, in seconds.
    pub access_expires_in: u64,
    /// Lifetime of `refresh_token`, in seconds.
    pub refresh_expires_in: u64,
}
42
43impl JwtPlugin {
44 pub fn new(secret: &str, expiry_secs: u64) -> Self {
45 Self {
46 secret: secret.to_string(),
47 expiry_secs,
48 used_refresh_tokens: Mutex::new(HashSet::new()),
49 }
50 }
51
52 pub fn issue(&self, user_id: &str) -> String {
54 self.issue_with_kind(user_id, "access", self.expiry_secs)
55 }
56
57 pub fn issue_with_kind(&self, user_id: &str, kind: &str, expiry_secs: u64) -> String {
59 let now = SystemTime::now()
60 .duration_since(UNIX_EPOCH)
61 .unwrap_or_default()
62 .as_secs();
63
64 let header = base64url_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
65 let payload = base64url_encode(
66 format!(
67 "{{\"sub\":\"{}\",\"iat\":{},\"exp\":{},\"kind\":\"{}\"}}",
68 user_id,
69 now,
70 now + expiry_secs,
71 kind,
72 )
73 .as_bytes(),
74 );
75
76 let signing_input = format!("{header}.{payload}");
77 let signature = base64url_encode(&hmac_sha256(&self.secret, &signing_input));
78
79 format!("{signing_input}.{signature}")
80 }
81
82 pub fn issue_pair(&self, user_id: &str, refresh_expiry_secs: u64) -> TokenPair {
86 let access_token = self.issue(user_id);
87 let refresh_token = self.issue_with_kind(user_id, "refresh", refresh_expiry_secs);
88 TokenPair {
89 access_token,
90 refresh_token,
91 access_expires_in: self.expiry_secs,
92 refresh_expires_in: refresh_expiry_secs,
93 }
94 }
95
96 pub fn refresh(&self, refresh_token: &str) -> Result<TokenPair, String> {
111 let claims = self.verify(refresh_token)?;
112
113 match claims.kind.as_deref() {
114 Some("refresh") => {}
115 _ => return Err("Token is not a refresh token".into()),
116 }
117
118 {
119 let mut used = self
120 .used_refresh_tokens
121 .lock()
122 .map_err(|_| "Lock poisoned")?;
123 if used.contains(refresh_token) {
124 return Err("Refresh token already used".into());
125 }
126 used.insert(refresh_token.to_string());
127 }
128
129 Ok(self.issue_pair(&claims.sub, 86400 * 7))
130 }
131
132 pub fn verify(&self, token: &str) -> Result<Claims, String> {
135 let parts: Vec<&str> = token.split('.').collect();
136 if parts.len() != 3 {
137 return Err("Invalid JWT format".into());
138 }
139
140 let signing_input = format!("{}.{}", parts[0], parts[1]);
141 let expected_sig = base64url_encode(&hmac_sha256(&self.secret, &signing_input));
142
143 if !pylon_auth::constant_time_eq(parts[2].as_bytes(), expected_sig.as_bytes()) {
144 return Err("Invalid signature".into());
145 }
146
147 let payload_bytes = base64url_decode(parts[1])?;
148 let payload_str = String::from_utf8(payload_bytes).map_err(|_| "Invalid payload")?;
149
150 let sub = extract_json_string(&payload_str, "sub").ok_or("Missing sub claim")?;
152 let iat = extract_json_number(&payload_str, "iat").ok_or("Missing iat claim")?;
153 let exp = extract_json_number(&payload_str, "exp").ok_or("Missing exp claim")?;
154 let kind = extract_json_string(&payload_str, "kind");
155
156 let now = SystemTime::now()
157 .duration_since(UNIX_EPOCH)
158 .unwrap_or_default()
159 .as_secs();
160
161 if now > exp {
162 return Err("Token expired".into());
163 }
164
165 Ok(Claims {
166 sub,
167 iat,
168 exp,
169 kind,
170 })
171 }
172
173 pub fn resolve_user(&self, token: &str) -> Option<String> {
175 self.verify(token).ok().map(|c| c.sub)
176 }
177}
178
impl Plugin for JwtPlugin {
    /// Returns this plugin's name, the constant string `"jwt"`.
    fn name(&self) -> &str {
        "jwt"
    }
}
184
/// Encodes `data` as unpadded base64url (alphabet `A-Z a-z 0-9 - _`).
///
/// Each group of up to three input bytes becomes `len + 1` output characters;
/// no `=` padding is appended.
fn base64url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    // Picks the 6-bit digit located `shift` bits from the bottom of `group`.
    let digit = |group: u32, shift: u32| ALPHABET[((group >> shift) & 63) as usize] as char;

    let mut encoded = String::new();
    for triple in data.chunks(3) {
        // Pack the bytes big-endian into the top of a 24-bit group; missing
        // trailing bytes stay zero.
        let group = triple
            .iter()
            .enumerate()
            .fold(0u32, |acc, (pos, &byte)| acc | (u32::from(byte) << (16 - 8 * pos)));

        // 1 byte -> 2 chars, 2 bytes -> 3 chars, 3 bytes -> 4 chars.
        for idx in 0..=triple.len() {
            encoded.push(digit(group, 18 - 6 * idx as u32));
        }
    }
    encoded
}
206
/// Decodes unpadded base64url text (alphabet `A-Z a-z 0-9 - _`) into bytes.
///
/// # Errors
///
/// Returns an error when `data` contains a character outside the base64url
/// alphabet (including `=` padding), or when its length leaves a single
/// dangling character (`len % 4 == 1`), which no byte sequence encodes to.
fn base64url_decode(data: &str) -> Result<Vec<u8>, String> {
    /// Maps one base64url character to its 6-bit value.
    fn val(c: u8) -> Result<u8, String> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'-' => Ok(62),
            b'_' => Ok(63),
            _ => Err(format!("Invalid base64url character: {}", c as char)),
        }
    }

    let bytes = data.as_bytes();
    // A lone trailing character carries only 6 bits and cannot form a byte.
    if bytes.len() % 4 == 1 {
        return Err("Invalid base64url length".into());
    }

    let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
    for quad in bytes.chunks(4) {
        // Pack up to four 6-bit digits into the top of a 24-bit group.
        let mut group = 0u32;
        for (pos, &c) in quad.iter().enumerate() {
            group |= u32::from(val(c)?) << (18 - 6 * pos);
        }

        // 2 chars -> 1 byte, 3 chars -> 2 bytes, 4 chars -> 3 bytes.
        out.push((group >> 16) as u8);
        if quad.len() > 2 {
            out.push((group >> 8) as u8);
        }
        if quad.len() > 3 {
            out.push(group as u8);
        }
    }
    Ok(out)
}
252
253fn hmac_sha256(key: &str, data: &str) -> Vec<u8> {
256 let mut mac =
257 HmacSha256::new_from_slice(key.as_bytes()).expect("HMAC can take key of any size");
258 mac.update(data.as_bytes());
259 mac.finalize().into_bytes().to_vec()
260}
261
/// Returns the raw text of the string value stored under `key` in `json`.
///
/// This is a minimal scanner, not a JSON parser: it looks for the first
/// occurrence of `"key":"` (no whitespace allowed) and returns everything up
/// to the next `"`. Escape sequences are not interpreted. Yields `None` when
/// the pattern or the closing quote is missing.
fn extract_json_string(json: &str, key: &str) -> Option<String> {
    let opener = format!("\"{}\":\"", key);
    let (_, after_opener) = json.split_once(opener.as_str())?;
    let (value, _) = after_opener.split_once('"')?;
    Some(value.to_string())
}
269
/// Returns the unsigned integer stored under `key` in `json`.
///
/// This is a minimal scanner, not a JSON parser: it looks for the first
/// occurrence of `"key":` (no whitespace allowed) and parses the run of ASCII
/// digits that follows. Yields `None` when the key is absent, no digit
/// follows it, or the digits do not fit in a `u64`.
fn extract_json_number(json: &str, key: &str) -> Option<u64> {
    let opener = format!("\"{}\":", key);
    let (_, after_opener) = json.split_once(opener.as_str())?;
    let digit_count = after_opener
        .bytes()
        .take_while(u8::is_ascii_digit)
        .count();
    after_opener[..digit_count].parse().ok()
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a correctly signed token around an arbitrary JSON payload so
    /// tests can exercise claims that `issue` cannot produce.
    fn sign_payload(secret: &str, payload_json: &str) -> String {
        let header = base64url_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
        let payload = base64url_encode(payload_json.as_bytes());
        let signing_input = format!("{header}.{payload}");
        let signature = base64url_encode(&hmac_sha256(secret, &signing_input));
        format!("{signing_input}.{signature}")
    }

    #[test]
    fn issue_and_verify() {
        let jwt = JwtPlugin::new("test-secret", 3600);
        let token = jwt.issue("user-1");

        assert!(!token.is_empty());
        assert_eq!(token.split('.').count(), 3);

        let claims = jwt.verify(&token).unwrap();
        assert_eq!(claims.sub, "user-1");
        assert!(claims.exp > claims.iat);
        assert_eq!(claims.kind, Some("access".into()));
    }

    #[test]
    fn wrong_secret_fails() {
        let jwt1 = JwtPlugin::new("secret-1", 3600);
        let jwt2 = JwtPlugin::new("secret-2", 3600);

        let token = jwt1.issue("user-1");
        let result = jwt2.verify(&token);
        assert!(result.is_err());
    }

    #[test]
    fn expired_token_rejected() {
        // A zero-lifetime token from `issue` expires in the same second it is
        // created, which makes the outcome timing-dependent. Sign a payload
        // whose `exp` is far in the past so the rejection is deterministic.
        let jwt = JwtPlugin::new("secret", 3600);
        let token = sign_payload(
            "secret",
            "{\"sub\":\"user-1\",\"iat\":1,\"exp\":2,\"kind\":\"access\"}",
        );

        assert_eq!(jwt.verify(&token), Err("Token expired".to_string()));
        assert_eq!(jwt.resolve_user(&token), None);
    }

    #[test]
    fn invalid_format_rejected() {
        let jwt = JwtPlugin::new("secret", 3600);
        assert!(jwt.verify("not.a.jwt.token").is_err());
        assert!(jwt.verify("invalid").is_err());
        assert!(jwt.verify("").is_err());
    }

    #[test]
    fn tampered_payload_rejected() {
        let jwt = JwtPlugin::new("secret", 3600);
        let token = jwt.issue("user-1");
        let parts: Vec<&str> = token.split('.').collect();

        // Swap in a different payload while keeping the original signature.
        let forged_payload = base64url_encode(
            b"{\"sub\":\"admin\",\"iat\":1,\"exp\":99999999999,\"kind\":\"access\"}",
        );
        let forged = format!("{}.{}.{}", parts[0], forged_payload, parts[2]);

        assert_eq!(jwt.verify(&forged), Err("Invalid signature".to_string()));
    }

    #[test]
    fn resolve_user() {
        let jwt = JwtPlugin::new("secret", 3600);
        let token = jwt.issue("alice");

        assert_eq!(jwt.resolve_user(&token), Some("alice".into()));
        assert_eq!(jwt.resolve_user("invalid"), None);
    }

    #[test]
    fn different_users_different_tokens() {
        let jwt = JwtPlugin::new("secret", 3600);
        let t1 = jwt.issue("user-1");
        let t2 = jwt.issue("user-2");
        assert_ne!(t1, t2);
    }

    #[test]
    fn hmac_sha256_produces_32_bytes() {
        let sig = hmac_sha256("key", "data");
        assert_eq!(sig.len(), 32);
    }

    #[test]
    fn hmac_sha256_different_keys_different_output() {
        let s1 = hmac_sha256("key1", "data");
        let s2 = hmac_sha256("key2", "data");
        assert_ne!(s1, s2);
    }

    #[test]
    fn hmac_sha256_different_data_different_output() {
        let s1 = hmac_sha256("key", "data1");
        let s2 = hmac_sha256("key", "data2");
        assert_ne!(s1, s2);
    }

    #[test]
    fn base64url_round_trip() {
        for input in [&b""[..], b"f", b"fo", b"foo", b"foob", &[0xfb, 0xff, 0x00]] {
            let encoded = base64url_encode(input);
            assert!(!encoded.contains('='));
            assert_eq!(base64url_decode(&encoded).unwrap(), input.to_vec());
        }
    }

    #[test]
    fn issue_pair_creates_two_distinct_tokens() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        assert_ne!(pair.access_token, pair.refresh_token);
        assert_eq!(pair.access_expires_in, 300);
        assert_eq!(pair.refresh_expires_in, 86400 * 7);

        let access_claims = jwt.verify(&pair.access_token).unwrap();
        assert_eq!(access_claims.sub, "user-1");
        assert_eq!(access_claims.kind, Some("access".into()));

        let refresh_claims = jwt.verify(&pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.sub, "user-1");
        assert_eq!(refresh_claims.kind, Some("refresh".into()));
    }

    #[test]
    fn refresh_returns_new_pair() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        let new_pair = jwt.refresh(&pair.refresh_token).unwrap();

        let access_claims = jwt.verify(&new_pair.access_token).unwrap();
        assert_eq!(access_claims.sub, "user-1");
        assert_eq!(access_claims.kind, Some("access".into()));

        let refresh_claims = jwt.verify(&new_pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.sub, "user-1");
        assert_eq!(refresh_claims.kind, Some("refresh".into()));

        // The redeemed token must not be accepted a second time.
        let err = jwt.refresh(&pair.refresh_token).unwrap_err();
        assert!(err.contains("already used"));
    }

    #[test]
    fn used_refresh_token_rejected() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        assert!(jwt.refresh(&pair.refresh_token).is_ok());

        let err = jwt.refresh(&pair.refresh_token).unwrap_err();
        assert!(err.contains("already used"));
    }

    #[test]
    fn access_token_cannot_be_used_as_refresh() {
        let jwt = JwtPlugin::new("secret", 300);
        let pair = jwt.issue_pair("user-1", 86400 * 7);

        let err = jwt.refresh(&pair.access_token).unwrap_err();
        assert!(err.contains("not a refresh token"));
    }

    #[test]
    fn issue_with_kind_sets_kind_field() {
        let jwt = JwtPlugin::new("secret", 3600);
        let token = jwt.issue_with_kind("user-1", "refresh", 86400);
        let claims = jwt.verify(&token).unwrap();
        assert_eq!(claims.kind, Some("refresh".into()));
    }
}