use std::collections::HashMap;
use std::sync::RwLock;
use base64::Engine;
use rand::Rng;
use subtle::ConstantTimeEq;
/// An access token and the refresh token issued together with it.
#[derive(Clone, Debug)]
pub struct JwtTokenPair {
    /// Encoded token carrying the `"access"` type claim.
    pub access: String,
    /// Encoded token carrying the `"refresh"` type claim.
    pub refresh: String,
}
/// Claims extracted from a decoded token.
///
/// The four reserved claims are exposed as typed fields; every other key of
/// the payload ends up in `custom`.
#[derive(Clone, Debug)]
pub struct JwtClaims {
    /// Subject: the user id the token was issued for.
    pub sub: i64,
    /// Expiry as a Unix timestamp in seconds.
    pub exp: i64,
    /// Unique token identifier, used as the blacklist key.
    pub jti: String,
    /// Token type (`"access"` or `"refresh"` for tokens issued by this module).
    pub typ: String,
    /// Caller-supplied claims, i.e. the payload minus the reserved names.
    pub custom: serde_json::Map<String, serde_json::Value>,
}
impl JwtClaims {
    /// Looks up the custom claim `key` and deserializes it into `T`.
    ///
    /// Returns `None` when the claim is absent or when its JSON value cannot
    /// be deserialized as `T`.
    #[must_use]
    pub fn get_custom<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        let value = self.custom.get(key)?;
        // `&serde_json::Value` is itself a `Deserializer`, so the stored value
        // can be read in place instead of deep-cloning it first.
        <T as serde::Deserialize>::deserialize(value).ok()
    }

    /// Returns the raw JSON value stored under the custom claim `key`, if any.
    #[must_use]
    pub fn custom_value(&self, key: &str) -> Option<&serde_json::Value> {
        self.custom.get(key)
    }
}
/// Claim names that are always written by the issuer and therefore may not
/// appear in a caller-supplied custom payload.
pub const RESERVED_CLAIM_NAMES: &[&str] = &[
    "sub",
    "exp",
    "jti",
    "typ",
];
#[derive(Debug, thiserror::Error)]
pub enum JwtIssueError {
#[error("reserved claim `{0}` cannot be set in custom payload")]
ReservedClaim(String),
}
/// Value of the `typ` claim on access tokens.
const ACCESS_TYP: &str = "access";
/// Value of the `typ` claim on refresh tokens.
const REFRESH_TYP: &str = "refresh";
/// Default access-token lifetime: 15 minutes, in seconds.
pub const DEFAULT_ACCESS_TTL_SECS: i64 = 15 * 60;
/// Default refresh-token lifetime: 7 days, in seconds.
pub const DEFAULT_REFRESH_TTL_SECS: i64 = 7 * 24 * 60 * 60;
/// Issues, verifies, rotates and revokes signed access/refresh tokens.
///
/// Revocation state is kept in an in-memory blacklist guarded by an `RwLock`.
pub struct JwtLifecycle {
    /// Key used to sign and verify tokens.
    secret: Vec<u8>,
    /// Lifetime of newly issued access tokens, in seconds.
    pub access_ttl_secs: i64,
    /// Lifetime of newly issued refresh tokens, in seconds.
    pub refresh_ttl_secs: i64,
    /// Revoked token ids (`jti`) mapped to the expiry timestamp of the token.
    blacklist: RwLock<HashMap<String, i64>>,
}
impl JwtLifecycle {
/// Creates a lifecycle manager that signs tokens with `secret`.
///
/// Both TTLs start at [`DEFAULT_ACCESS_TTL_SECS`] / [`DEFAULT_REFRESH_TTL_SECS`]
/// and the blacklist starts empty.
#[must_use]
pub fn new(secret: Vec<u8>) -> Self {
    let blacklist = RwLock::new(HashMap::new());
    Self {
        secret,
        access_ttl_secs: DEFAULT_ACCESS_TTL_SECS,
        refresh_ttl_secs: DEFAULT_REFRESH_TTL_SECS,
        blacklist,
    }
}
#[must_use]
pub fn with_access_ttl(mut self, secs: i64) -> Self {
self.access_ttl_secs = secs;
self
}
#[must_use]
pub fn with_refresh_ttl(mut self, secs: i64) -> Self {
self.refresh_ttl_secs = secs;
self
}
/// Issues an access/refresh pair for `user_id` without any custom claims.
///
/// Delegates to [`Self::issue_pair_with`] with an empty map, which contains
/// no reserved names, so the `expect` is an invariant rather than a
/// reachable panic.
pub fn issue_pair(&self, user_id: i64) -> JwtTokenPair {
    let no_custom = serde_json::Map::new();
    let issued = self.issue_pair_with(user_id, no_custom);
    issued.expect("empty custom map cannot trigger ReservedClaim")
}
/// Issues an access/refresh pair for `user_id`, embedding `custom` in both
/// tokens.
///
/// # Errors
///
/// Returns [`JwtIssueError::ReservedClaim`] if `custom` contains any of
/// [`RESERVED_CLAIM_NAMES`]; no token is issued in that case.
pub fn issue_pair_with(
    &self,
    user_id: i64,
    custom: serde_json::Map<String, serde_json::Value>,
) -> Result<JwtTokenPair, JwtIssueError> {
    check_reserved(&custom)?;
    let mint = |typ: &str, ttl_secs: i64| self.issue_token_inner(user_id, typ, ttl_secs, &custom);
    // Struct fields are evaluated in written order: access first, then refresh.
    Ok(JwtTokenPair {
        access: mint(ACCESS_TYP, self.access_ttl_secs),
        refresh: mint(REFRESH_TYP, self.refresh_ttl_secs),
    })
}
/// Issues a single access token for `user_id` carrying `custom`.
///
/// # Errors
///
/// Returns [`JwtIssueError::ReservedClaim`] if `custom` contains any of
/// [`RESERVED_CLAIM_NAMES`].
pub fn issue_access_with(
    &self,
    user_id: i64,
    custom: serde_json::Map<String, serde_json::Value>,
) -> Result<String, JwtIssueError> {
    check_reserved(&custom)
        .map(|()| self.issue_token_inner(user_id, ACCESS_TYP, self.access_ttl_secs, &custom))
}
/// Verifies `token` and returns its claims only if it is an access token.
///
/// Returns `None` when the token fails [`Self::verify_token`] or its `typ`
/// claim is not `"access"`.
#[must_use]
pub fn verify_access(&self, token: &str) -> Option<JwtClaims> {
    self.verify_token(token)
        .filter(|claims| claims.typ == ACCESS_TYP)
}
/// Verifies `token` and returns its claims only if it is a refresh token.
///
/// Returns `None` when the token fails [`Self::verify_token`] or its `typ`
/// claim is not `"refresh"`.
#[must_use]
pub fn verify_refresh(&self, token: &str) -> Option<JwtClaims> {
    self.verify_token(token)
        .filter(|claims| claims.typ == REFRESH_TYP)
}
/// Exchanges a valid refresh token for a new pair, consuming the old token.
///
/// The custom claims carried by the old refresh token are copied onto both
/// new tokens. Returns `None` when the token fails [`Self::verify_refresh`]
/// or its `jti` has already been consumed.
///
/// # Panics
///
/// Panics if the blacklist lock is poisoned.
pub fn refresh(&self, refresh_token: &str) -> Option<JwtTokenPair> {
    let claims = self.verify_refresh(refresh_token)?;
    {
        // Consume the jti under one write lock. Verification only took a read
        // lock, so two concurrent callers can both get this far with the same
        // token; `insert` reporting a previous entry means another caller won.
        let mut blacklist = self.blacklist.write().expect("blacklist poisoned");
        if blacklist.insert(claims.jti.clone(), claims.exp).is_some() {
            return None;
        }
    } // guard released here; `prune_blacklist` takes the write lock itself
    self.prune_blacklist();
    self.issue_pair_with(claims.sub, claims.custom).ok()
}
pub fn refresh_with(
&self,
refresh_token: &str,
new_custom: serde_json::Map<String, serde_json::Value>,
) -> Result<Option<JwtTokenPair>, JwtIssueError> {
let Some(claims) = self.verify_refresh(refresh_token) else {
return Ok(None);
};
self.blacklist_jti(&claims.jti, claims.exp);
self.issue_pair_with(claims.sub, new_custom).map(Some)
}
/// Blacklists the `jti` of `token` until the token's own expiry.
///
/// Returns `false` when the token cannot be decoded (malformed or bad
/// signature) and `true` otherwise. Expiry and current blacklist state are
/// not consulted, because decoding goes through `decode_unchecked`.
pub fn revoke(&self, token: &str) -> bool {
    match self.decode_unchecked(token) {
        Some(claims) => {
            self.blacklist_jti(&claims.jti, claims.exp);
            true
        }
        None => false,
    }
}
/// Returns the number of blacklist entries after pruning expired ones.
///
/// # Panics
///
/// Panics if the blacklist lock is poisoned.
#[must_use]
pub fn blacklist_size(&self) -> usize {
    self.prune_blacklist();
    let entries = self.blacklist.read().expect("blacklist poisoned");
    entries.len()
}
/// Builds and signs one token of the form `<payload_b64>.<signature_b64>`.
///
/// The payload is `custom` plus the reserved claims `sub`, `exp`, `jti` and
/// `typ`, serialized as JSON and encoded as unpadded URL-safe base64. The
/// signature is computed over the encoded payload text.
fn issue_token_inner(
    &self,
    user_id: i64,
    typ: &str,
    ttl_secs: i64,
    custom: &serde_json::Map<String, serde_json::Value>,
) -> String {
    let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
    // The TTL fields are public and unchecked, so a very large value must not
    // overflow: that would panic in debug builds and wrap to an
    // already-expired timestamp in release builds.
    let exp = chrono::Utc::now().timestamp().saturating_add(ttl_secs);
    let jti = random_jti();
    // Start from the custom claims; the reserved claims are inserted last so
    // they always take precedence.
    let mut payload = custom.clone();
    payload.insert("sub".into(), serde_json::Value::from(user_id));
    payload.insert("exp".into(), serde_json::Value::from(exp));
    payload.insert("jti".into(), serde_json::Value::String(jti));
    payload.insert("typ".into(), serde_json::Value::String(typ.into()));
    // A serialization error falls back to an empty payload.
    let payload_b64 = b64.encode(serde_json::to_vec(&payload).unwrap_or_default());
    let sig = self.sign(payload_b64.as_bytes());
    let sig_b64 = b64.encode(sig);
    format!("{payload_b64}.{sig_b64}")
}
/// Decodes `token` and additionally rejects it when it has expired or its
/// `jti` is blacklisted. The token type is not checked here.
fn verify_token(&self, token: &str) -> Option<JwtClaims> {
    let claims = self.decode_unchecked(token)?;
    // A token is expired from the exact second of its `exp` onwards.
    let expired = claims.exp <= chrono::Utc::now().timestamp();
    if expired || self.is_blacklisted(&claims.jti) {
        None
    } else {
        Some(claims)
    }
}
/// Checks the signature of `token` and parses its claims.
///
/// "Unchecked" refers to expiry and blacklist state, which are not examined
/// here; the signature is always verified. Returns `None` when the token has
/// no `.` separator, the signature is not valid base64 or does not match,
/// the payload is not valid base64/JSON, or a reserved claim is missing or
/// has the wrong JSON type.
fn decode_unchecked(&self, token: &str) -> Option<JwtClaims> {
    let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let (body_b64, tag_b64) = token.split_once('.')?;

    // Authenticate the encoded payload before parsing any of it.
    let expected_tag = self.sign(body_b64.as_bytes());
    let provided_tag = b64.decode(tag_b64).ok()?;
    let signature_ok: bool = expected_tag.ct_eq(&provided_tag[..]).into();
    if !signature_ok {
        return None;
    }

    let body = b64.decode(body_b64).ok()?;
    let mut fields: serde_json::Map<String, serde_json::Value> =
        serde_json::from_slice(&body).ok()?;

    let sub = fields.get("sub").and_then(serde_json::Value::as_i64)?;
    let exp = fields.get("exp").and_then(serde_json::Value::as_i64)?;
    let jti = fields.get("jti").and_then(serde_json::Value::as_str)?.to_owned();
    let typ = fields.get("typ").and_then(serde_json::Value::as_str)?.to_owned();

    // Whatever is left after stripping the reserved names is the custom payload.
    RESERVED_CLAIM_NAMES.iter().for_each(|name| {
        fields.remove(*name);
    });

    Some(JwtClaims {
        sub,
        exp,
        jti,
        typ,
        custom: fields,
    })
}
/// Computes the HMAC-SHA256 tag of `msg` under the configured secret.
fn sign(&self, msg: &[u8]) -> Vec<u8> {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;
    type HmacSha256 = Hmac<Sha256>;

    let mut hasher = HmacSha256::new_from_slice(&self.secret).expect("HMAC accepts any key");
    hasher.update(msg);
    let tag = hasher.finalize();
    tag.into_bytes().to_vec()
}
/// Records `jti` as revoked until `expires_at`, then prunes expired entries.
///
/// # Panics
///
/// Panics if the blacklist lock is poisoned.
fn blacklist_jti(&self, jti: &str, expires_at: i64) {
    {
        let mut entries = self.blacklist.write().expect("blacklist poisoned");
        entries.insert(jti.to_owned(), expires_at);
    } // write guard must be released before `prune_blacklist` locks again
    self.prune_blacklist();
}
/// Reports whether `jti` has a blacklist entry that has not yet expired.
///
/// # Panics
///
/// Panics if the blacklist lock is poisoned.
fn is_blacklisted(&self, jti: &str) -> bool {
    let now = chrono::Utc::now().timestamp();
    let entries = self.blacklist.read().expect("blacklist poisoned");
    matches!(entries.get(jti), Some(&expires_at) if expires_at > now)
}
/// Drops every blacklist entry whose expiry is not strictly in the future.
///
/// # Panics
///
/// Panics if the blacklist lock is poisoned.
fn prune_blacklist(&self) {
    let cutoff = chrono::Utc::now().timestamp();
    self.blacklist
        .write()
        .expect("blacklist poisoned")
        .retain(|_jti, expires_at| *expires_at > cutoff);
}
}
/// Generates a token id: 16 random bytes from the thread-local RNG, encoded
/// as unpadded URL-safe base64.
fn random_jti() -> String {
    let raw = rand::thread_rng().gen::<[u8; 16]>();
    let encoder = base64::engine::general_purpose::URL_SAFE_NO_PAD;
    encoder.encode(raw)
}
/// Rejects a custom payload that uses any of [`RESERVED_CLAIM_NAMES`].
///
/// Names are checked in the order they appear in that list, and the first
/// one found is the one reported in the error.
fn check_reserved(
    custom: &serde_json::Map<String, serde_json::Value>,
) -> Result<(), JwtIssueError> {
    let offending = RESERVED_CLAIM_NAMES
        .iter()
        .find(|name| custom.contains_key(**name));
    match offending {
        Some(name) => Err(JwtIssueError::ReservedClaim((*name).to_owned())),
        None => Ok(()),
    }
}
// Unit tests for the token lifecycle: issuing, verifying, refreshing,
// revoking, and custom-claim handling.
#[cfg(test)]
mod tests {
use super::*;
// Shared fixture: a lifecycle with a fixed secret and default TTLs.
fn jwt() -> JwtLifecycle {
JwtLifecycle::new(b"test-secret".to_vec())
}
// A freshly issued access token verifies and carries the expected claims.
#[test]
fn issue_and_verify_access() {
let j = jwt();
let pair = j.issue_pair(42);
let claims = j.verify_access(&pair.access).expect("access verifies");
assert_eq!(claims.sub, 42);
assert_eq!(claims.typ, "access");
assert!(!claims.jti.is_empty());
}
// A freshly issued refresh token verifies and carries the expected claims.
#[test]
fn issue_and_verify_refresh() {
let j = jwt();
let pair = j.issue_pair(42);
let claims = j.verify_refresh(&pair.refresh).expect("refresh verifies");
assert_eq!(claims.sub, 42);
assert_eq!(claims.typ, "refresh");
}
// The two token types are not interchangeable.
#[test]
fn access_token_rejected_as_refresh() {
let j = jwt();
let pair = j.issue_pair(1);
assert!(j.verify_refresh(&pair.access).is_none());
}
#[test]
fn refresh_token_rejected_as_access() {
let j = jwt();
let pair = j.issue_pair(1);
assert!(j.verify_access(&pair.refresh).is_none());
}
// Refreshing yields a different pair for the same subject.
#[test]
fn refresh_returns_new_pair() {
let j = jwt();
let pair = j.issue_pair(7);
let new_pair = j.refresh(&pair.refresh).expect("refresh succeeds");
assert_ne!(pair.access, new_pair.access);
assert_ne!(pair.refresh, new_pair.refresh);
let claims = j.verify_access(&new_pair.access).unwrap();
assert_eq!(claims.sub, 7);
}
// A refresh token is single-use: the second attempt must fail.
#[test]
fn refresh_blacklists_old_refresh_token() {
let j = jwt();
let pair = j.issue_pair(7);
let _new = j.refresh(&pair.refresh).unwrap();
assert!(j.refresh(&pair.refresh).is_none());
assert!(j.verify_refresh(&pair.refresh).is_none());
}
// Revoking a token makes later verification fail.
#[test]
fn revoke_invalidates_access_token() {
let j = jwt();
let pair = j.issue_pair(1);
assert!(j.verify_access(&pair.access).is_some());
assert!(j.revoke(&pair.access));
assert!(j.verify_access(&pair.access).is_none());
}
// Revoking something that is not a decodable token reports `false`.
#[test]
fn revoke_invalid_token_returns_false() {
let j = jwt();
assert!(!j.revoke("not-a-valid-token"));
}
// Flipping one bit of the last character of the token breaks verification.
#[test]
fn tampered_signature_fails_verification() {
let j = jwt();
let pair = j.issue_pair(1);
let mut bytes = pair.access.into_bytes();
let last = bytes.len() - 1;
bytes[last] ^= 0x01;
let tampered = String::from_utf8(bytes).unwrap();
assert!(j.verify_access(&tampered).is_none());
}
// A token signed with one secret does not verify under another.
#[test]
fn wrong_secret_fails_verification() {
let j1 = jwt();
let j2 = JwtLifecycle::new(b"different-secret".to_vec());
let pair = j1.issue_pair(5);
assert!(j2.verify_access(&pair.access).is_none());
}
// Each issuance gets its own `jti`, even for the same user.
#[test]
fn unique_jti_per_issuance() {
let j = jwt();
let pair1 = j.issue_pair(1);
let pair2 = j.issue_pair(1);
let c1 = j.verify_access(&pair1.access).unwrap();
let c2 = j.verify_access(&pair2.access).unwrap();
assert_ne!(c1.jti, c2.jti);
}
// The builder methods override the default TTLs.
#[test]
fn custom_ttls() {
let j = JwtLifecycle::new(b"k".to_vec())
.with_access_ttl(60)
.with_refresh_ttl(3600);
assert_eq!(j.access_ttl_secs, 60);
assert_eq!(j.refresh_ttl_secs, 3600);
}
// Helper: turns a `json!({...})` object literal into a claims map.
// Panics if `value` is not a JSON object.
fn map(value: serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
value.as_object().unwrap().clone()
}
// Custom claims survive an issue/verify round trip and deserialize to typed values.
#[test]
fn issue_pair_with_embeds_custom_claims() {
let j = jwt();
let pair = j
.issue_pair_with(
42,
map(serde_json::json!({"roles": ["admin", "editor"], "tenant": "acme"})),
)
.unwrap();
let claims = j.verify_access(&pair.access).unwrap();
assert_eq!(claims.sub, 42);
let roles: Vec<String> = claims.get_custom("roles").unwrap();
assert_eq!(roles, vec!["admin", "editor"]);
let tenant: String = claims.get_custom("tenant").unwrap();
assert_eq!(tenant, "acme");
}
// Without custom claims the decoded `custom` map is empty.
#[test]
fn issue_pair_no_custom_returns_empty_custom_map() {
let j = jwt();
let pair = j.issue_pair(7);
let claims = j.verify_access(&pair.access).unwrap();
assert!(claims.custom.is_empty());
let missing: Option<String> = claims.get_custom("anything");
assert!(missing.is_none());
}
// Every reserved claim name is rejected when supplied as a custom claim.
#[test]
fn issue_pair_with_rejects_reserved_claims() {
let j = jwt();
for reserved in RESERVED_CLAIM_NAMES {
let custom = map(serde_json::json!({ *reserved: "evil" }));
let r = j.issue_pair_with(1, custom);
assert!(matches!(r, Err(JwtIssueError::ReservedClaim(_))), "should reject {reserved}");
}
}
// `refresh` carries the old custom claims over to both new tokens.
#[test]
fn refresh_preserves_custom_claims() {
let j = jwt();
let pair = j
.issue_pair_with(7, map(serde_json::json!({"scope": "read:posts write:posts"})))
.unwrap();
let new_pair = j.refresh(&pair.refresh).unwrap();
let new_access_claims = j.verify_access(&new_pair.access).unwrap();
let new_refresh_claims = j.verify_refresh(&new_pair.refresh).unwrap();
assert_eq!(new_access_claims.sub, 7);
let scope: String = new_access_claims.get_custom("scope").unwrap();
assert_eq!(scope, "read:posts write:posts");
let scope_r: String = new_refresh_claims.get_custom("scope").unwrap();
assert_eq!(scope_r, "read:posts write:posts");
}
// `refresh_with` replaces the old custom claims with the supplied ones.
#[test]
fn refresh_with_substitutes_new_custom_claims() {
let j = jwt();
let pair = j
.issue_pair_with(7, map(serde_json::json!({"roles": ["admin"]})))
.unwrap();
let new_pair = j
.refresh_with(&pair.refresh, map(serde_json::json!({"roles": ["viewer"]})))
.unwrap()
.unwrap();
let claims = j.verify_access(&new_pair.access).unwrap();
let roles: Vec<String> = claims.get_custom("roles").unwrap();
assert_eq!(roles, vec!["viewer"]);
}
// An invalid refresh token is reported as `Ok(None)`, not as an error.
#[test]
fn refresh_with_invalid_token_returns_ok_none() {
let j = jwt();
let r = j.refresh_with("not-a-token", map(serde_json::json!({}))).unwrap();
assert!(r.is_none());
}
// Reserved names in the replacement claims are reported as an error.
#[test]
fn refresh_with_rejects_reserved_claims() {
let j = jwt();
let pair = j.issue_pair(1);
let r = j.refresh_with(&pair.refresh, map(serde_json::json!({"sub": 999})));
assert!(matches!(r, Err(JwtIssueError::ReservedClaim(_))));
}
// `issue_access_with` produces a lone access token with the custom claims.
#[test]
fn issue_access_with_returns_single_token() {
let j = jwt();
let token = j
.issue_access_with(42, map(serde_json::json!({"key_id": "abc"})))
.unwrap();
let claims = j.verify_access(&token).unwrap();
assert_eq!(claims.sub, 42);
assert_eq!(claims.typ, "access");
let key_id: String = claims.get_custom("key_id").unwrap();
assert_eq!(key_id, "abc");
}
// `custom_value` exposes the raw JSON, including nested objects.
#[test]
fn custom_value_returns_raw_json() {
let j = jwt();
let token = j
.issue_access_with(1, map(serde_json::json!({"nested": {"x": 1}})))
.unwrap();
let claims = j.verify_access(&token).unwrap();
let raw = claims.custom_value("nested").unwrap();
assert_eq!(raw["x"], 1);
}
// Single-use semantics also hold for refresh tokens that carry custom claims.
#[test]
fn refresh_blacklists_old_refresh_even_with_custom_claims() {
let j = jwt();
let pair = j
.issue_pair_with(7, map(serde_json::json!({"role": "admin"})))
.unwrap();
let _new = j.refresh(&pair.refresh).unwrap();
assert!(j.refresh(&pair.refresh).is_none());
}
}