use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use crate::web::context::Claims;
/// Failure to sign a token; wraps the underlying `jsonwebtoken` error.
#[derive(Debug)]
pub struct JwtSignError(pub jsonwebtoken::errors::Error);

impl std::fmt::Display for JwtSignError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let JwtSignError(inner) = self;
        f.write_fmt(format_args!("JWT signing failed: {}", inner))
    }
}

impl std::error::Error for JwtSignError {}
/// Settings for [`JwtService`]: signing secret, algorithm and token lifetimes.
pub struct JwtConfig {
    /// Secret bytes used to derive the signing/verification keys (via `from_secret`).
    pub secret: String,
    /// Algorithm written into token headers and required when validating.
    pub algorithm: Algorithm,
    /// Lifetime of access tokens, in seconds.
    pub access_ttl_secs: u64,
    /// Lifetime of refresh tokens, in seconds.
    pub refresh_ttl_secs: u64,
}

impl Default for JwtConfig {
    /// HS256 with 15-minute access tokens and 7-day refresh tokens.
    ///
    /// The default secret is a well-known placeholder; it must be overridden in any real
    /// deployment, otherwise anyone can mint valid tokens.
    fn default() -> Self {
        let access_ttl_secs = 15 * 60;
        let refresh_ttl_secs = 7 * 24 * 60 * 60;
        Self {
            secret: String::from("change-me-in-production"),
            algorithm: Algorithm::HS256,
            access_ttl_secs,
            refresh_ttl_secs,
        }
    }
}
/// Payload written into every issued token (both access and refresh).
///
/// Serde serializes fields in declaration order, so this order is the order of keys in the
/// encoded JSON payload; do not reorder casually.
#[derive(Debug, Serialize, Deserialize)]
struct JwtClaims {
    /// Subject the token was issued for.
    sub: String,
    /// Role; omitted from the payload when empty (refresh tokens always leave it empty).
    #[serde(skip_serializing_if = "String::is_empty", default)]
    role: String,
    /// Email; omitted from the payload when empty.
    #[serde(skip_serializing_if = "String::is_empty", default)]
    email: String,
    /// Token kind, serialized under the key `"type"`; this file writes `"access"` or `"refresh"`.
    #[serde(rename = "type")]
    kind: String,
    /// Token id produced by `new_jti`.
    jti: String,
    /// Issued-at time in Unix seconds.
    iat: u64,
    /// Expiry time in Unix seconds.
    exp: u64,
    /// Permission strings; omitted from the payload when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    perms: Vec<String>,
    /// Tenant binding; omitted from the payload when empty (i.e. no tenant).
    #[serde(skip_serializing_if = "String::is_empty", default)]
    tenant: String,
}
/// One generation of key material: the key that signs new tokens, the keys accepted when
/// verifying, and the version number it was installed under.
struct JwtKeyMaterial {
    /// Key used to sign newly issued tokens.
    encoding: EncodingKey,
    /// Keys tried, in order, when verifying a token.
    verify: Vec<DecodingKey>,
    /// Version this material was installed under.
    version: u64,
}

impl JwtKeyMaterial {
    /// Derives the signing and verification keys from `secret`.
    ///
    /// The verification list is the key for `secret` followed by `previous` (when given), so
    /// tokens signed by the prior generation still verify after a rotation.
    fn from_secret(secret: &[u8], version: u64, previous: Option<DecodingKey>) -> Self {
        let verify: Vec<DecodingKey> = std::iter::once(DecodingKey::from_secret(secret))
            .chain(previous)
            .collect();
        let encoding = EncodingKey::from_secret(secret);
        Self {
            encoding,
            verify,
            version,
        }
    }
}
/// Issues and verifies access/refresh JWTs, with support for rotating the signing secret.
pub struct JwtService {
    /// Current key material; replaced wholesale by `rotate_secret`.
    keys: crate::auth::secrets::Rotating<JwtKeyMaterial>,
    /// Header attached to every issued token (algorithm taken from the config).
    header: Header,
    /// Rules applied when decoding (algorithm from the config, expiry checked).
    validation: Validation,
    /// Configuration the service was built with; token TTLs are read from here.
    config: JwtConfig,
}
impl JwtService {
pub fn new(config: JwtConfig) -> Self {
let keys = crate::auth::secrets::Rotating::new(JwtKeyMaterial::from_secret(
config.secret.as_bytes(),
1,
None,
));
let header = Header::new(config.algorithm);
let mut validation = Validation::new(config.algorithm);
validation.validate_exp = true;
Self {
keys,
header,
validation,
config,
}
}
pub fn rotate_secret(&self, new_secret: &[u8], version: u64) {
let current = self.keys.load();
if version <= current.version {
tracing::warn!(
current = current.version,
offered = version,
"ignoring stale JWT secret rotation",
);
return;
}
let previous = current.verify.first().cloned();
self.keys
.store(JwtKeyMaterial::from_secret(new_secret, version, previous));
tracing::info!(version, "JwtService signing key rotated");
}
fn now() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}
pub fn issue_access(&self, sub: &str, role: &str, email: &str) -> Result<String, JwtSignError> {
self.issue_access_with_perms(sub, role, email, &[])
}
pub fn issue_access_with_perms(
&self,
sub: &str,
role: &str,
email: &str,
perms: &[String],
) -> Result<String, JwtSignError> {
self.issue_access_bound(sub, role, email, perms, None)
}
pub fn issue_access_bound(
&self,
sub: &str,
role: &str,
email: &str,
perms: &[String],
tenant: Option<&str>,
) -> Result<String, JwtSignError> {
let now = Self::now();
let claims = JwtClaims {
sub: sub.to_owned(),
role: role.to_owned(),
email: email.to_owned(),
kind: "access".to_owned(),
jti: new_jti(),
iat: now,
exp: now + self.config.access_ttl_secs,
perms: perms.to_vec(),
tenant: tenant.unwrap_or("").to_owned(),
};
encode(&self.header, &claims, &self.keys.load().encoding).map_err(JwtSignError)
}
pub fn issue_refresh(&self, sub: &str) -> Result<(String, String), JwtSignError> {
let now = Self::now();
let jti = new_jti();
let claims = JwtClaims {
sub: sub.to_owned(),
role: String::new(),
email: String::new(),
kind: "refresh".to_owned(),
jti: jti.clone(),
iat: now,
exp: now + self.config.refresh_ttl_secs,
perms: Vec::new(),
tenant: String::new(),
};
let token =
encode(&self.header, &claims, &self.keys.load().encoding).map_err(JwtSignError)?;
Ok((token, jti))
}
pub fn decode(&self, token: &str) -> Option<Arc<Claims>> {
let keys = self.keys.load();
let data = keys
.verify
.iter()
.find_map(|k| decode::<serde_json::Value>(token, k, &self.validation).ok())?;
let obj = data.claims.as_object()?.clone();
Some(Arc::new(obj))
}
pub fn decode_access(&self, token: &str) -> Option<Arc<Claims>> {
let claims = self.decode(token)?;
if claims.get("type").and_then(|v| v.as_str()) != Some("access") {
return None;
}
Some(claims)
}
pub fn validate_refresh(&self, token: &str) -> Option<(String, String)> {
let claims = self.decode(token)?;
if claims.get("type")?.as_str()? != "refresh" {
return None;
}
let sub = claims.get("sub")?.as_str()?.to_owned();
let jti = claims.get("jti")?.as_str()?.to_owned();
Some((sub, jti))
}
pub fn access_ttl_secs(&self) -> u64 {
self.config.access_ttl_secs
}
pub fn refresh_ttl_secs(&self) -> u64 {
self.config.refresh_ttl_secs
}
}
/// Reads the `Authorization` header and decodes its bearer token as an access token.
///
/// The `Bearer` scheme is matched case-insensitively, as HTTP auth schemes are
/// case-insensitive (RFC 7235 §2.1). A header value without a `Bearer` scheme is treated as
/// the bare token, preserving the existing lenient behavior.
///
/// Returns `None` in any of these cases:
/// - the header is missing or cannot be read as a string;
/// - the token is empty;
/// - the container yields no `JwtService`;
/// - the token is not a valid access token.
pub fn decode_bearer_token(
    headers: &axum::http::HeaderMap,
    container: &crate::core::engine::FrozenDiContainer,
) -> Option<Arc<Claims>> {
    let raw = headers.get("authorization")?.to_str().ok()?;
    // Accept `bearer`/`BEARER` as well as `Bearer`; anything else falls back to the raw value.
    let token = match raw.split_once(' ') {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("bearer") => rest,
        _ => raw,
    }
    .trim();
    if token.is_empty() {
        return None;
    }
    container.try_get::<JwtService>()?.decode_access(token)
}
/// Generates a 32-character lowercase-hex token id (`jti`).
///
/// The id concatenates two 64-bit hash digests:
/// - one over the current time and a process-wide sequence number;
/// - one over the process id, the thread id and that sequence number (plus one).
///
/// Both hashers come from a freshly constructed `RandomState`, which is randomly keyed. Ids
/// from different processes (e.g. replicas sharing a signing key) therefore do not collide
/// just because their clocks and counters line up. `DefaultHasher::new()` uses fixed keys and
/// could not guarantee that.
///
/// NOTE: the goal is uniqueness (RFC 7519 §4.1.7), not secrecy. This is not a CSPRNG, and the
/// id must not be relied on as an unguessable value.
fn new_jti() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    // Randomly keyed; hashers from distinct `RandomState`s are unlikely to agree.
    let state = RandomState::new();

    let mut h1 = state.build_hasher();
    SystemTime::now().hash(&mut h1);
    seq.hash(&mut h1);

    let mut h2 = state.build_hasher();
    std::process::id().hash(&mut h2);
    std::thread::current().id().hash(&mut h2);
    seq.wrapping_add(1).hash(&mut h2);

    format!("{:016x}{:016x}", h1.finish(), h2.finish())
}