use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::OnceLock;
use crate::Result;
use crate::config::AuthConfig;
/// Process-wide fallback signing secret, filled in at most once by
/// [`get_or_generate_jwt_secret`].
static AUTO_JWT_SECRET: OnceLock<String> = OnceLock::new();

/// Returns the fallback JWT secret, generating it on first use.
///
/// The first call draws 32 random bytes, hex-encodes them (64 lowercase hex
/// characters), logs a warning, and stores the result; every later call
/// returns that same stored string.
fn get_or_generate_jwt_secret() -> &'static str {
    AUTO_JWT_SECRET
        .get_or_init(|| {
            use rand::RngCore;

            let mut raw = [0u8; 32];
            rand::thread_rng().fill_bytes(&mut raw);

            // Two hex digits per byte.
            let mut hex = String::with_capacity(raw.len() * 2);
            for byte in &raw {
                hex.push_str(&format!("{:02x}", byte));
            }

            tracing::warn!(
                "No jwt_secret configured — generated a random secret. \
                 JWTs signed by external services will fail validation. \
                 Set [auth] jwt_secret in what.toml or WHAT_AUTH_JWT_SECRET env var."
            );

            hex
        })
        .as_str()
}
/// Claims read from (or written into) a JWT payload.
///
/// The three named fields are optional and default to `None` when absent.
/// Every other key of the payload is gathered into `custom` through
/// `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
/// `exp` claim. `is_expired` compares it against the current time in
/// seconds since the Unix epoch.
#[serde(default)]
pub exp: Option<u64>,
/// `iat` claim. Stored as-is; the non-test code in this file does not read it.
#[serde(default)]
pub iat: Option<u64>,
/// `sub` claim. Copied into the context map by `to_context` when present.
#[serde(default)]
pub sub: Option<String>,
/// All remaining payload keys, flattened into the same JSON object as the
/// named fields above.
#[serde(flatten)]
pub custom: HashMap<String, Value>,
}
impl JwtClaims {
    /// Builds the claim map that is exposed to the rest of the application.
    ///
    /// `sub` and `exp` are included whenever they are present. Of the custom
    /// claims, only those whose key appears in `claim_names` are copied;
    /// names with no matching custom claim are skipped.
    pub fn to_context(&self, claim_names: &[String]) -> HashMap<String, Value> {
        let mut exposed: HashMap<String, Value> = HashMap::new();

        exposed.extend(
            self.sub
                .as_ref()
                .map(|subject| ("sub".to_string(), json!(subject))),
        );
        exposed.extend(self.exp.map(|expiry| ("exp".to_string(), json!(expiry))));

        // Added last, so a whitelisted custom claim replaces an entry with the same key.
        exposed.extend(claim_names.iter().filter_map(|wanted| {
            self.custom
                .get(wanted)
                .map(|claim| (wanted.clone(), claim.clone()))
        }));

        exposed
    }

    /// Reports whether `exp` lies strictly before the current time.
    ///
    /// Claims without `exp` are never considered expired. If the system clock
    /// reads earlier than the Unix epoch, "now" is taken to be `0`.
    pub fn is_expired(&self) -> bool {
        match self.exp {
            None => false,
            Some(expires_at) => {
                let now_secs = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map_or(0, |elapsed| elapsed.as_secs());
                now_secs > expires_at
            }
        }
    }
}
/// Authentication helper built around an `AuthConfig`.
///
/// Its methods answer whether a path is protected, expose the configured
/// login/logout settings, read and build the JWT cookie, and decode tokens.
#[derive(Clone)]
pub struct AuthHandler {
/// Settings that every method of the handler reads from.
config: AuthConfig,
}
impl AuthHandler {
pub fn new(config: AuthConfig) -> Self {
Self { config }
}
pub fn from_config_with_env(mut config: AuthConfig) -> Self {
if let Ok(val) = std::env::var("WHAT_AUTH_ENABLED") {
config.enabled = val.parse().unwrap_or(config.enabled);
}
if let Ok(val) = std::env::var("WHAT_AUTH_LOGIN_ENDPOINT") {
config.login_endpoint = Some(val);
}
if let Ok(val) = std::env::var("WHAT_AUTH_LOGOUT_ENDPOINT") {
config.logout_endpoint = Some(val);
}
if let Ok(val) = std::env::var("WHAT_AUTH_JWT_SECRET") {
config.jwt_secret = Some(val);
}
if let Ok(val) = std::env::var("WHAT_AUTH_JWT_COOKIE_NAME") {
config.jwt_cookie_name = val;
}
if let Ok(val) = std::env::var("WHAT_AUTH_LOGIN_PATH") {
config.login_path = val;
}
if let Ok(val) = std::env::var("WHAT_AUTH_AFTER_LOGIN") {
config.after_login = val;
}
Self { config }
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn is_protected(&self, path: &str) -> bool {
if !self.config.enabled {
return false;
}
for pattern in &self.config.protected_paths {
if pattern_matches(pattern, path) {
return true;
}
}
false
}
pub fn login_path(&self) -> &str {
&self.config.login_path
}
pub fn after_login_path(&self) -> &str {
&self.config.after_login
}
pub fn login_endpoint(&self) -> Option<&str> {
self.config.login_endpoint.as_deref()
}
pub fn logout_endpoint(&self) -> Option<&str> {
self.config.logout_endpoint.as_deref()
}
pub fn jwt_cookie_name(&self) -> &str {
&self.config.jwt_cookie_name
}
pub fn jwt_claims(&self) -> &[String] {
&self.config.jwt_claims
}
pub fn parse_jwt_cookie(&self, cookie_header: Option<&str>) -> Option<String> {
cookie_header.and_then(|header| {
header
.split(';')
.map(|s| s.trim())
.find(|s| s.starts_with(&format!("{}=", self.config.jwt_cookie_name)))
.map(|s| s[self.config.jwt_cookie_name.len() + 1..].to_string())
})
}
pub fn decode_jwt(&self, token: &str) -> Result<JwtClaims> {
let secret = match self.config.jwt_secret {
Some(ref s) => s.as_str(),
None => get_or_generate_jwt_secret(),
};
let key = DecodingKey::from_secret(secret.as_bytes());
let validation = Validation::new(Algorithm::HS256);
let token_data = decode::<JwtClaims>(token, &key, &validation)?;
Ok(token_data.claims)
}
pub fn build_jwt_cookie(&self, token: &str, max_age: i64, secure: bool) -> String {
let mut cookie = format!(
"{}={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
self.config.jwt_cookie_name, token, max_age
);
if secure {
cookie.push_str("; Secure");
}
cookie
}
pub fn build_clear_cookie(&self) -> String {
format!(
"{}=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0",
self.config.jwt_cookie_name
)
}
}
/// Tests `path` against one protected-path `pattern`.
///
/// Three pattern shapes are understood:
///
/// * `<base>/**` matches `<base>` itself and everything below `<base>/`, at
///   any depth. The match respects the segment boundary, so `/api/**` does
///   not match `/apixyz` or `/api-docs`.
/// * `<base>/*` matches exactly one further segment below `<base>/` (the
///   remainder must not contain another `/`); `<base>` itself does not match.
/// * anything else must equal `path` exactly.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    if let Some(base) = pattern.strip_suffix("/**") {
        match path.strip_prefix(base) {
            // Either the base itself, or a path continuing at a `/` boundary.
            Some(rest) => rest.is_empty() || rest.starts_with('/'),
            None => false,
        }
    } else if let Some(base) = pattern.strip_suffix("/*") {
        match path.strip_prefix(base).and_then(|rest| rest.strip_prefix('/')) {
            Some(segment) => !segment.contains('/'),
            None => false,
        }
    } else {
        pattern == path
    }
}
/// The current user's authentication state plus the claims attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
/// `true` for contexts built by `from_claims`, `false` for `unauthenticated`.
pub authenticated: bool,
/// Claim values, flattened into the same JSON object as `authenticated`
/// when the struct is serialized or deserialized.
#[serde(flatten)]
pub claims: HashMap<String, Value>,
}
impl UserContext {
    /// Context for a visitor without a session: not authenticated, no claims.
    pub fn unauthenticated() -> Self {
        Self {
            authenticated: false,
            claims: HashMap::new(),
        }
    }

    /// Context for an authenticated user carrying `claims`.
    pub fn from_claims(claims: HashMap<String, Value>) -> Self {
        Self {
            authenticated: true,
            claims,
        }
    }

    /// Renders the context as a JSON object: the `authenticated` flag plus
    /// every claim as a top-level key.
    ///
    /// `authenticated` is reserved. A claim with that name is left out so
    /// that token contents can never overwrite the real authentication flag.
    pub fn to_context(&self) -> Value {
        let mut map = serde_json::Map::new();
        map.insert("authenticated".to_string(), json!(self.authenticated));
        for (key, value) in &self.claims {
            if key != "authenticated" {
                map.insert(key.clone(), value.clone());
            }
        }
        Value::Object(map)
    }

    /// Role names taken from the `roles` claim, or from `role` when `roles`
    /// is absent.
    ///
    /// * An array contributes its string elements; non-string elements are
    ///   skipped.
    /// * A string is split on `,`; each entry is trimmed and empty entries
    ///   (from `""`, `"a,,b"` or a trailing comma) are dropped.
    /// * Any other value, or no claim at all, yields an empty list.
    pub fn roles(&self) -> Vec<String> {
        let claim = self
            .claims
            .get("roles")
            .or_else(|| self.claims.get("role"));

        match claim {
            Some(Value::Array(items)) => items
                .iter()
                .filter_map(|item| item.as_str().map(String::from))
                .collect(),
            Some(Value::String(list)) => list
                .split(',')
                .map(str::trim)
                // `"".split(',')` still yields one empty item; an empty role
                // name is never meaningful.
                .filter(|role| !role.is_empty())
                .map(String::from)
                .collect(),
            _ => Vec::new(),
        }
    }

    /// The `sub` claim, when present and a JSON string.
    pub fn sub(&self) -> Option<String> {
        self.claims
            .get("sub")
            .and_then(|value| value.as_str().map(String::from))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Current time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Claims in which only `exp` may be populated.
    fn claims_expiring_at(exp: Option<u64>) -> JwtClaims {
        JwtClaims {
            exp,
            iat: None,
            sub: None,
            custom: HashMap::new(),
        }
    }

    /// Encodes `claims` with the default header, signed with `secret`.
    fn sign(claims: &JwtClaims, secret: &[u8]) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret),
        )
        .unwrap()
    }

    /// Enabled handler whose JWT cookie is called `w_token`.
    fn w_token_handler() -> AuthHandler {
        AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_cookie_name: "w_token".to_string(),
            ..Default::default()
        })
    }

    /// Enabled handler with the given (optional) JWT secret.
    fn handler_with_secret(jwt_secret: Option<&str>) -> AuthHandler {
        AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_secret: jwt_secret.map(str::to_string),
            ..Default::default()
        })
    }

    #[test]
    fn test_pattern_matches() {
        let matching = [
            ("/admin", "/admin"),
            ("/admin/*", "/admin/users"),
            ("/admin/*", "/admin/settings"),
            ("/api/**", "/api/v1"),
            ("/api/**", "/api/v1/users"),
            ("/api/**", "/api/v1/users/123"),
        ];
        for (pattern, path) in matching.iter() {
            assert!(
                pattern_matches(pattern, path),
                "{} should match {}",
                pattern,
                path
            );
        }

        let rejected = [
            ("/admin", "/admin/users"),
            ("/admin/*", "/admin/users/123"),
            ("/admin/*", "/admin"),
        ];
        for (pattern, path) in rejected.iter() {
            assert!(
                !pattern_matches(pattern, path),
                "{} should not match {}",
                pattern,
                path
            );
        }
    }

    #[test]
    fn test_jwt_claims_to_context() {
        let mut custom = HashMap::new();
        custom.insert("email".to_string(), json!("user@example.com"));
        custom.insert("full_name".to_string(), json!("John Doe"));
        custom.insert("role".to_string(), json!("admin"));

        let claims = JwtClaims {
            exp: Some(1234567890),
            iat: Some(1234567800),
            sub: Some("user123".to_string()),
            custom,
        };

        let wanted = ["email".to_string(), "full_name".to_string()];
        let context = claims.to_context(&wanted);

        assert_eq!(context.get("email"), Some(&json!("user@example.com")));
        assert_eq!(context.get("full_name"), Some(&json!("John Doe")));
        assert_eq!(context.get("sub"), Some(&json!("user123")));
        // `role` was not requested, so it must not leak into the context.
        assert!(!context.contains_key("role"));
    }

    #[test]
    fn test_user_context() {
        assert!(!UserContext::unauthenticated().authenticated);

        let mut claims = HashMap::new();
        claims.insert("email".to_string(), json!("user@example.com"));
        let authenticated = UserContext::from_claims(claims);
        assert!(authenticated.authenticated);

        let rendered = authenticated.to_context();
        assert_eq!(rendered.get("authenticated"), Some(&json!(true)));
        assert_eq!(rendered.get("email"), Some(&json!("user@example.com")));
    }

    #[test]
    fn test_auth_handler_parse_jwt_cookie() {
        let handler = w_token_handler();

        assert_eq!(
            handler.parse_jwt_cookie(Some("w_token=abc123; other_cookie=xyz")),
            Some("abc123".to_string())
        );
        assert_eq!(
            handler.parse_jwt_cookie(Some("other=value; w_token=def456")),
            Some("def456".to_string())
        );
        assert!(handler.parse_jwt_cookie(Some("other_cookie=xyz")).is_none());
        assert!(handler.parse_jwt_cookie(None).is_none());
    }

    #[test]
    fn test_auth_handler_is_protected() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            protected_paths: vec![
                "/admin".to_string(),
                "/admin/*".to_string(),
                "/api/**".to_string(),
            ],
            ..Default::default()
        });

        let protected = [
            "/admin",
            "/admin/users",
            "/admin/settings",
            "/api/v1",
            "/api/v1/users",
            "/api/v1/users/123",
        ];
        for path in protected.iter() {
            assert!(handler.is_protected(path), "{} should be protected", path);
        }

        let open = ["/admin/users/123", "/", "/public", "/login"];
        for path in open.iter() {
            assert!(
                !handler.is_protected(path),
                "{} should not be protected",
                path
            );
        }
    }

    #[test]
    fn test_auth_handler_disabled() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: false,
            protected_paths: vec!["/admin/**".to_string()],
            ..Default::default()
        });

        assert!(!handler.is_protected("/admin"));
        assert!(!handler.is_protected("/admin/users"));
        assert!(!handler.is_enabled());
    }

    #[test]
    fn test_build_jwt_cookie() {
        let handler = w_token_handler();

        let plain = handler.build_jwt_cookie("test_token_123", 3600, false);
        for fragment in [
            "w_token=test_token_123",
            "HttpOnly",
            "SameSite=Strict",
            "Path=/",
            "Max-Age=3600",
        ]
        .iter()
        {
            assert!(plain.contains(fragment), "missing {}", fragment);
        }
        assert!(!plain.contains("Secure"));

        let secured = handler.build_jwt_cookie("test_token_123", 3600, true);
        assert!(secured.contains("Secure"));
    }

    #[test]
    fn test_build_clear_cookie() {
        let cleared = w_token_handler().build_clear_cookie();

        for fragment in [
            "w_token=",
            "Max-Age=0",
            "HttpOnly",
            "SameSite=Strict",
            "Path=/",
        ]
        .iter()
        {
            assert!(cleared.contains(fragment), "missing {}", fragment);
        }
    }

    #[test]
    fn test_jwt_claims_is_expired() {
        let still_valid = claims_expiring_at(Some(unix_now() + 3600));
        assert!(!still_valid.is_expired());

        let expired_claims = claims_expiring_at(Some(unix_now() - 3600));
        assert!(expired_claims.is_expired());

        let no_exp_claims = claims_expiring_at(None);
        assert!(!no_exp_claims.is_expired());
    }

    #[test]
    fn test_decode_jwt_with_configured_secret() {
        let secret = "test_secret_123";
        let handler = handler_with_secret(Some(secret));

        let mut custom = HashMap::new();
        custom.insert("email".to_string(), json!("a@b.com"));
        let claims = JwtClaims {
            exp: Some(unix_now() + 3600),
            iat: None,
            sub: Some("user1".to_string()),
            custom,
        };

        let token = sign(&claims, secret.as_bytes());
        let decoded = handler.decode_jwt(&token).unwrap();

        assert_eq!(decoded.sub, Some("user1".to_string()));
        assert_eq!(decoded.custom.get("email"), Some(&json!("a@b.com")));
    }

    #[test]
    fn test_decode_jwt_rejects_wrong_secret() {
        let handler = handler_with_secret(Some("correct_secret"));
        let claims = claims_expiring_at(Some(unix_now() + 3600));

        let token = sign(&claims, b"wrong_secret");

        assert!(
            handler.decode_jwt(&token).is_err(),
            "Should reject JWT signed with wrong secret"
        );
    }

    #[test]
    fn test_decode_jwt_no_secret_uses_auto_generated() {
        let handler = handler_with_secret(None);
        let claims = claims_expiring_at(Some(unix_now() + 3600));

        let token = sign(&claims, b"attacker_secret");

        assert!(
            handler.decode_jwt(&token).is_err(),
            "Should reject JWT when no secret is configured (auto-generated secret won't match)"
        );
    }

    #[test]
    fn test_auth_handler_getters() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            login_path: "/login".to_string(),
            after_login: "/dashboard".to_string(),
            login_endpoint: Some("https://api.example.com/login".to_string()),
            logout_endpoint: Some("https://api.example.com/logout".to_string()),
            jwt_cookie_name: "auth_token".to_string(),
            jwt_claims: vec!["email".to_string(), "name".to_string()],
            ..Default::default()
        });

        assert!(handler.is_enabled());
        assert_eq!(handler.login_path(), "/login");
        assert_eq!(handler.after_login_path(), "/dashboard");
        assert_eq!(
            handler.login_endpoint(),
            Some("https://api.example.com/login")
        );
        assert_eq!(
            handler.logout_endpoint(),
            Some("https://api.example.com/logout")
        );
        assert_eq!(handler.jwt_cookie_name(), "auth_token");
        assert_eq!(
            handler.jwt_claims(),
            &["email".to_string(), "name".to_string()]
        );
    }
}