use std::time::{Duration, SystemTime, UNIX_EPOCH};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::{json, Map, Value};
use subtle::ConstantTimeEq;
#[derive(Debug, thiserror::Error)]
pub enum JwtError {
#[error("malformed token: expected three base64url segments")]
Malformed,
#[error("unsupported algorithm: {0} (only HS256)")]
UnsupportedAlg(String),
#[error("signature mismatch")]
BadSignature,
#[error("token expired (exp={0})")]
Expired(u64),
#[error("token not yet valid (nbf={0})")]
NotYetValid(u64),
#[error("decode error: {0}")]
Decode(String),
}
/// A JWT claim set: the JSON object carried in a token's payload segment.
///
/// Stored as an untyped `serde_json` map, so arbitrary custom claims survive a round trip.
/// The registered claims (`sub`, `iat`, `exp`, `nbf`, `iss`, `aud`, `jti`) have dedicated
/// helpers on this type.
#[derive(Clone, Debug, Default)]
pub struct Claims {
    /// Claim name → JSON value, serialized verbatim as the payload object.
    inner: Map<String, Value>,
}
impl Claims {
    /// Creates a claim set for `subject`, stamping `sub` and `iat` (issued-at = now).
    #[must_use]
    pub fn new(subject: impl Into<String>) -> Self {
        let mut c = Self::default();
        c.inner.insert("sub".into(), Value::String(subject.into()));
        c.inner.insert("iat".into(), Value::from(now_secs()));
        c
    }

    /// Creates a claim set with no claims at all (not even `sub` or `iat`).
    #[must_use]
    pub fn empty() -> Self {
        Self::default()
    }

    /// Inserts or replaces the claim `name` with `value` converted to JSON.
    ///
    /// Best-effort: if `serde_json` cannot convert `value`, nothing is inserted and any
    /// previous value of `name` is kept.
    pub fn set<T: Serialize>(&mut self, name: impl Into<String>, value: T) {
        if let Ok(v) = serde_json::to_value(value) {
            self.inner.insert(name.into(), v);
        }
    }

    /// Returns claim `name` deserialized as `T`.
    ///
    /// Returns `None` if the claim is absent or its JSON shape does not fit `T`.
    pub fn get<T: DeserializeOwned>(&self, name: &str) -> Option<T> {
        // `&Value` is itself a serde `Deserializer`, so the claim need not be cloned first.
        self.inner.get(name).and_then(|v| T::deserialize(v).ok())
    }

    /// Returns the `sub` claim, if present and a string.
    #[must_use]
    pub fn subject(&self) -> Option<&str> {
        self.inner.get("sub").and_then(Value::as_str)
    }

    /// Sets the `iss` (issuer) claim.
    #[must_use]
    pub fn issuer(mut self, iss: impl Into<String>) -> Self {
        self.inner.insert("iss".into(), Value::String(iss.into()));
        self
    }

    /// Sets the `aud` (audience) claim to a single string.
    #[must_use]
    pub fn audience(mut self, aud: impl Into<String>) -> Self {
        self.inner.insert("aud".into(), Value::String(aud.into()));
        self
    }

    /// Re-stamps `iat` to now and sets `exp` to now + `ttl`.
    ///
    /// `ttl` is truncated to whole seconds. The sum saturates at `u64::MAX` instead of
    /// overflowing, which previously panicked in debug builds and wrapped in release.
    #[must_use]
    pub fn ttl(mut self, ttl: Duration) -> Self {
        let now = now_secs();
        self.inner.insert("iat".into(), Value::from(now));
        self.inner
            .insert("exp".into(), Value::from(now.saturating_add(ttl.as_secs())));
        self
    }

    /// Sets `exp` to an absolute time in Unix seconds.
    #[must_use]
    pub fn expires_at(mut self, unix_secs: u64) -> Self {
        self.inner.insert("exp".into(), Value::from(unix_secs));
        self
    }

    /// Sets `nbf` (not-before) to an absolute time in Unix seconds.
    #[must_use]
    pub fn not_before(mut self, unix_secs: u64) -> Self {
        self.inner.insert("nbf".into(), Value::from(unix_secs));
        self
    }

    /// Sets the `jti` (token id) claim.
    #[must_use]
    pub fn jti(mut self, jti: impl Into<String>) -> Self {
        self.inner.insert("jti".into(), Value::String(jti.into()));
        self
    }

    /// Serializes the claim set as compact JSON for the payload segment.
    ///
    /// Serializing a string-keyed map of `Value`s should not fail in practice. The `{}`
    /// fallback is purely defensive.
    fn to_json(&self) -> Vec<u8> {
        serde_json::to_vec(&self.inner).unwrap_or_else(|_| b"{}".to_vec())
    }

    /// Parses a decoded payload segment.
    ///
    /// The payload must be a JSON object; anything else maps to [`JwtError::Decode`].
    fn from_json(json: &[u8]) -> Result<Self, JwtError> {
        let inner: Map<String, Value> = serde_json::from_slice(json)
            .map_err(|e| JwtError::Decode(format!("claims: {e}")))?;
        Ok(Self { inner })
    }
}
/// Signs `claims` with HMAC-SHA256 and returns the compact token `header.payload.signature`.
///
/// The header is always `{"alg":"HS256","typ":"JWT"}`. Every segment is base64url without
/// padding, and the signature covers the exact bytes `header.payload`.
///
/// # Errors
///
/// Returns [`JwtError::Decode`] if `secret` is empty.
pub fn encode(claims: &Claims, secret: &[u8]) -> Result<String, JwtError> {
    if secret.is_empty() {
        return Err(JwtError::Decode("HMAC secret must not be empty".into()));
    }
    let header_json =
        serde_json::to_vec(&json!({"alg": "HS256", "typ": "JWT"})).expect("header serialize");

    // Grow the token in place: after the payload is appended it *is* the signing input.
    let mut token = URL_SAFE_NO_PAD.encode(header_json);
    token.push('.');
    token.push_str(&URL_SAFE_NO_PAD.encode(claims.to_json()));

    let signature = hmac_sha256(secret, token.as_bytes());
    token.push('.');
    token.push_str(&URL_SAFE_NO_PAD.encode(signature));
    Ok(token)
}
/// Verifies `token` against `secret` at the current system time.
///
/// This is [`decode_at`] with `now` read from the wall clock. See that function for the
/// checks performed.
///
/// # Errors
///
/// Same as [`decode_at`].
pub fn decode(token: &str, secret: &[u8]) -> Result<Claims, JwtError> {
    let current_time = now_secs();
    decode_at(token, secret, current_time)
}
/// Verifies `token` against `secret` as if the current time were `now` (Unix seconds).
///
/// The checks run in this order:
/// 1. the three-segment shape;
/// 2. the HMAC-SHA256 signature, compared in constant time;
/// 3. that the header's `alg` is `HS256`;
/// 4. the `exp`/`nbf` time window.
///
/// The signature is checked before the header is parsed, so an attacker-chosen `alg` cannot
/// influence verification.
///
/// `exp`/`nbf` may be any JSON number, since RFC 7519 allows fractional NumericDates. The
/// token is expired once `now >= exp` and not yet valid while `now < nbf`.
///
/// # Errors
///
/// - [`JwtError::Decode`] if `secret` is empty, the header/payload is not valid
///   base64url/JSON, or `exp`/`nbf` is present but not a JSON number.
/// - [`JwtError::Malformed`] unless the token has exactly three `.`-separated segments.
/// - [`JwtError::BadSignature`] if the signature does not decode or does not match.
/// - [`JwtError::UnsupportedAlg`] if the header's `alg` is anything but `HS256`.
/// - [`JwtError::Expired`] / [`JwtError::NotYetValid`] outside the validity window.
pub fn decode_at(token: &str, secret: &[u8], now: u64) -> Result<Claims, JwtError> {
    // Mirror `encode`: with an empty HMAC key anyone can compute valid signatures.
    if secret.is_empty() {
        return Err(JwtError::Decode("HMAC secret must not be empty".into()));
    }
    let mut it = token.split('.');
    let header_b = it.next().ok_or(JwtError::Malformed)?;
    let payload_b = it.next().ok_or(JwtError::Malformed)?;
    let sig_b = it.next().ok_or(JwtError::Malformed)?;
    if it.next().is_some() {
        return Err(JwtError::Malformed);
    }
    let signing_input = format!("{header_b}.{payload_b}");
    let expected = hmac_sha256(secret, signing_input.as_bytes());
    let provided = URL_SAFE_NO_PAD
        .decode(sig_b.as_bytes())
        .map_err(|_| JwtError::BadSignature)?;
    if expected.ct_eq(&provided).unwrap_u8() == 0 {
        return Err(JwtError::BadSignature);
    }
    let header_bytes = URL_SAFE_NO_PAD
        .decode(header_b.as_bytes())
        .map_err(|e| JwtError::Decode(format!("header b64: {e}")))?;
    let header: Value = serde_json::from_slice(&header_bytes)
        .map_err(|e| JwtError::Decode(format!("header json: {e}")))?;
    let alg = header.get("alg").and_then(Value::as_str).unwrap_or("");
    if alg != "HS256" {
        return Err(JwtError::UnsupportedAlg(alg.to_owned()));
    }
    let payload_bytes = URL_SAFE_NO_PAD
        .decode(payload_b.as_bytes())
        .map_err(|e| JwtError::Decode(format!("payload b64: {e}")))?;
    let claims = Claims::from_json(&payload_bytes)?;
    // RFC 7519 §4.1.4: the token MUST NOT be accepted on or after `exp`, hence `>=`.
    if let Some(exp) = numeric_date(&claims, "exp")? {
        if now >= exp {
            return Err(JwtError::Expired(exp));
        }
    }
    // RFC 7519 §4.1.5: the token MUST NOT be accepted before `nbf`.
    if let Some(nbf) = numeric_date(&claims, "nbf")? {
        if now < nbf {
            return Err(JwtError::NotYetValid(nbf));
        }
    }
    Ok(claims)
}

/// Reads the NumericDate claim `name` as whole Unix seconds.
///
/// Returns `Ok(None)` when the claim is absent. A fractional value is rounded *up*: for
/// whole-second `t`, `t >= date` holds exactly when `t >= date.ceil()`, so the `exp`/`nbf`
/// comparisons in `decode_at` stay exact. Negative values saturate to 0.
///
/// # Errors
///
/// Returns [`JwtError::Decode`] when the claim is present but not a JSON number. This
/// ensures a malformed date can never silently switch its time check off.
fn numeric_date(claims: &Claims, name: &str) -> Result<Option<u64>, JwtError> {
    let value = match claims.inner.get(name) {
        None => return Ok(None),
        Some(v) => v,
    };
    if let Some(secs) = value.as_u64() {
        // Exact path for ordinary integer timestamps (no detour through f64).
        return Ok(Some(secs));
    }
    match value.as_f64() {
        // Float-to-int `as` saturates: negatives become 0, oversized values u64::MAX.
        Some(secs) => Ok(Some(secs.ceil() as u64)),
        None => Err(JwtError::Decode(format!(
            "claims: `{name}` is not a NumericDate"
        ))),
    }
}
/// Parses the claims of `token` WITHOUT verifying anything.
///
/// The signature, `alg`, `exp` and `nbf` are all ignored. Treat the result as untrusted,
/// attacker-controllable input and never base authorization on it; use [`decode`] or
/// [`decode_at`] for that.
///
/// # Errors
///
/// - [`JwtError::Malformed`] unless the token has exactly three `.`-separated segments.
///   This is the same shape rule [`decode_at`] enforces.
/// - [`JwtError::Decode`] if the payload is not a base64url-encoded JSON object.
pub fn decode_unverified(token: &str) -> Result<Claims, JwtError> {
    let mut it = token.split('.');
    let _header = it.next().ok_or(JwtError::Malformed)?;
    let payload_b = it.next().ok_or(JwtError::Malformed)?;
    let _sig = it.next().ok_or(JwtError::Malformed)?;
    // Reject trailing segments ("a.b.c.d"), matching `decode_at`.
    if it.next().is_some() {
        return Err(JwtError::Malformed);
    }
    let payload_bytes = URL_SAFE_NO_PAD
        .decode(payload_b.as_bytes())
        .map_err(|e| JwtError::Decode(format!("payload b64: {e}")))?;
    Claims::from_json(&payload_bytes)
}
use crate::crypto::hmac_sha256;
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 instead of panicking if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared HMAC key for every test that signs or verifies.
    const SECRET: &[u8] = b"test-shared-secret-32-byte-string";

    /// Signs `claims` with [`SECRET`]; every caller here expects signing to succeed.
    fn sign(claims: &Claims) -> String {
        encode(claims, SECRET).unwrap()
    }

    #[test]
    fn round_trip_encode_decode() {
        let mut claims = Claims::new("user-42");
        claims.set("role", "admin");
        claims.set("count", 7_i64);
        let decoded = decode(&sign(&claims), SECRET).unwrap();
        assert_eq!(decoded.subject(), Some("user-42"));
        assert_eq!(decoded.get::<String>("role").as_deref(), Some("admin"));
        assert_eq!(decoded.get::<i64>("count"), Some(7));
    }

    #[test]
    fn empty_secret_rejected() {
        let outcome = encode(&Claims::new("x"), b"");
        assert!(matches!(outcome, Err(JwtError::Decode(_))));
    }

    #[test]
    fn token_format_is_three_segments() {
        let token = sign(&Claims::new("x"));
        assert_eq!(token.matches('.').count(), 2);
    }

    #[test]
    fn wrong_secret_fails_signature_check() {
        let token = sign(&Claims::new("x"));
        let err = decode(&token, b"wrong-secret-bytes").unwrap_err();
        assert!(matches!(err, JwtError::BadSignature));
    }

    #[test]
    fn payload_tampering_fails_signature_check() {
        let token = sign(&Claims::new("alice"));
        let segments: Vec<&str> = token.split('.').collect();
        let forged_payload = URL_SAFE_NO_PAD.encode(b"{\"sub\":\"bob\"}");
        let tampered = [segments[0], forged_payload.as_str(), segments[2]].join(".");
        assert!(matches!(decode(&tampered, SECRET), Err(JwtError::BadSignature)));
    }

    #[test]
    fn malformed_token_rejected() {
        for bad in ["only.two", "a.b.c.d"] {
            assert!(matches!(decode(bad, SECRET), Err(JwtError::Malformed)), "{bad}");
        }
    }

    #[test]
    fn expired_token_rejected_at_decode_time() {
        let token = sign(&Claims::new("x").expires_at(now_secs() - 100));
        let err = decode(&token, SECRET).unwrap_err();
        assert!(matches!(err, JwtError::Expired(_)));
    }

    #[test]
    fn ttl_helper_sets_iat_and_exp() {
        let claims = Claims::new("x").ttl(Duration::from_secs(3600));
        assert!(claims.get::<u64>("iat").is_some());
        let exp: u64 = claims.get("exp").unwrap();
        assert!(exp > now_secs(), "exp must be in the future");
    }

    #[test]
    fn not_before_rejected_when_future() {
        let token = sign(&Claims::new("x").not_before(now_secs() + 3600));
        assert!(matches!(decode(&token, SECRET), Err(JwtError::NotYetValid(_))));
    }

    #[test]
    fn decode_at_specific_time_lets_us_test_clock_window() {
        let token = sign(&Claims::new("x").expires_at(1000));
        let early = decode_at(&token, SECRET, 500).unwrap();
        assert_eq!(early.subject(), Some("x"));
        assert!(matches!(
            decode_at(&token, SECRET, 2000),
            Err(JwtError::Expired(1000))
        ));
    }

    #[test]
    fn alg_other_than_hs256_rejected() {
        // Correctly signed, so only the header's `alg` can cause the rejection.
        let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\",\"typ\":\"JWT\"}");
        let payload = URL_SAFE_NO_PAD.encode(b"{\"sub\":\"x\"}");
        let signing_input = format!("{header}.{payload}");
        let signature = URL_SAFE_NO_PAD.encode(hmac_sha256(SECRET, signing_input.as_bytes()));
        let err = decode(&format!("{signing_input}.{signature}"), SECRET).unwrap_err();
        assert!(matches!(err, JwtError::UnsupportedAlg(_)));
    }

    #[test]
    fn issuer_audience_jti_round_trip() {
        let claims = Claims::new("x")
            .issuer("api.example.com")
            .audience("client.example.com")
            .jti("token-1");
        let decoded = decode(&sign(&claims), SECRET).unwrap();
        let expected = [
            ("iss", "api.example.com"),
            ("aud", "client.example.com"),
            ("jti", "token-1"),
        ];
        for (name, want) in expected {
            assert_eq!(decoded.get::<String>(name).as_deref(), Some(want));
        }
    }

    #[test]
    fn decode_unverified_skips_signature_and_exp() {
        let token = sign(&Claims::new("x").expires_at(now_secs() - 100));
        assert!(decode(&token, SECRET).is_err());
        let unverified = decode_unverified(&token).unwrap();
        assert_eq!(unverified.subject(), Some("x"));
    }

    #[test]
    fn empty_claims_round_trip_when_no_sub() {
        let decoded = decode(&sign(&Claims::empty()), SECRET).unwrap();
        assert_eq!(decoded.subject(), None);
    }
}