use axum::{extract::State, Json};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::state::AppState;
use dashmap::DashSet;
use once_cell::sync::Lazy;
/// HMAC key used to sign and verify every JWT handled by this module.
///
/// Read from the `AINGLE_JWT_SECRET` environment variable the first time the
/// static is dereferenced.
///
/// # Panics
/// Panics on first use when the variable is unset, not valid Unicode, or
/// empty/whitespace-only.
static JWT_SECRET: Lazy<Vec<u8>> = Lazy::new(|| {
    let secret = std::env::var("AINGLE_JWT_SECRET").expect(
        "AINGLE_JWT_SECRET environment variable must be set. \
         Generate one with: openssl rand -base64 64",
    );
    // An empty key is a *known* key: anyone could mint tokens that verify.
    // Refuse to run rather than silently sign with it.
    assert!(
        !secret.trim().is_empty(),
        "AINGLE_JWT_SECRET must not be empty. \
         Generate one with: openssl rand -base64 64"
    );
    secret.into_bytes()
});
/// `jti` values of refresh tokens that have already been redeemed.
/// `refresh_token` inserts the presented token's `jti` here and rejects the
/// request when the value was already present.
// NOTE(review): nothing in the visible code ever removes entries, so this set
// appears to grow for the lifetime of the process — confirm whether pruning
// of expired ids happens elsewhere.
static REVOKED_TOKENS: Lazy<DashSet<String>> = Lazy::new(DashSet::new);
/// Lifetime of an access token, in hours.
const TOKEN_EXPIRATION_HOURS: i64 = 24;
/// Lifetime of a refresh token, in days.
const REFRESH_TOKEN_EXPIRATION_DAYS: i64 = 7;
/// JWT claim set shared by access and refresh tokens.
///
/// The three `Option` fields are left out of the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
/// Subject: the id of the user the token was issued for.
pub sub: String,
/// Username, present only when the issuing constructor was given one.
#[serde(skip_serializing_if = "Option::is_none")]
pub username: Option<String>,
/// Expiry time as a Unix timestamp in seconds.
pub exp: i64,
/// Issue time as a Unix timestamp in seconds.
pub iat: i64,
/// Role names granted to the subject; the constructors leave this empty for refresh tokens.
pub roles: Vec<String>,
/// Token kind; the constructors in this file write `"access"` or `"refresh"`.
pub token_type: String,
/// Optional namespace attached to the token.
#[serde(skip_serializing_if = "Option::is_none")]
pub namespace: Option<String>,
/// Unique token id; the constructors in this file set it only for refresh tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub jti: Option<String>,
}
impl Claims {
    /// Shared builder behind the three public access-token constructors.
    ///
    /// Stamps `iat` with the current time, sets `exp` to
    /// `TOKEN_EXPIRATION_HOURS` later, marks the token as `"access"` and
    /// leaves `jti` unset.
    fn build_access(
        user_id: &str,
        username: Option<&str>,
        roles: Vec<String>,
        namespace: Option<String>,
    ) -> Self {
        let now = Utc::now();
        Self {
            sub: user_id.to_string(),
            username: username.map(String::from),
            exp: (now + Duration::hours(TOKEN_EXPIRATION_HOURS)).timestamp(),
            iat: now.timestamp(),
            roles,
            token_type: "access".to_string(),
            namespace,
            jti: None,
        }
    }

    /// Access-token claims for `user_id` carrying `roles`, with no username
    /// and no namespace.
    pub fn new_access(user_id: &str, roles: Vec<String>) -> Self {
        Self::build_access(user_id, None, roles, None)
    }

    /// Access-token claims for `user_id` that also record `username`.
    pub fn new_access_with_username(user_id: &str, username: &str, roles: Vec<String>) -> Self {
        Self::build_access(user_id, Some(username), roles, None)
    }

    /// Access-token claims for `user_id` that record both `username` and
    /// `namespace`.
    pub fn new_access_with_namespace(
        user_id: &str,
        username: &str,
        roles: Vec<String>,
        namespace: String,
    ) -> Self {
        Self::build_access(user_id, Some(username), roles, Some(namespace))
    }

    /// Refresh-token claims for `user_id`.
    ///
    /// The token carries no roles, expires `REFRESH_TOKEN_EXPIRATION_DAYS`
    /// from now, and gets a freshly generated v4 UUID as its `jti`.
    pub fn new_refresh(user_id: &str) -> Self {
        let now = Utc::now();
        Self {
            sub: user_id.to_string(),
            username: None,
            exp: (now + Duration::days(REFRESH_TOKEN_EXPIRATION_DAYS)).timestamp(),
            iat: now.timestamp(),
            roles: vec![],
            token_type: "refresh".to_string(),
            namespace: None,
            jti: Some(uuid::Uuid::new_v4().to_string()),
        }
    }

    /// Returns `true` once the current time has reached `exp`.
    ///
    /// RFC 7519 defines `exp` as the time "on or after which" the token must
    /// not be accepted, so the boundary second itself counts as expired.
    pub fn is_expired(&self) -> bool {
        Utc::now().timestamp() >= self.exp
    }

    /// Returns `true` when `role` appears in `roles` (exact, case-sensitive match).
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|r| r == role)
    }
}
/// JSON body accepted by `create_token`.
#[derive(Deserialize)]
pub struct CreateTokenRequest {
    /// Login name to authenticate.
    pub username: String,
    /// Password as submitted by the client.
    pub password: String,
}

/// Hand-written so the password can never end up in logs or panic messages
/// through `{:?}` formatting; every other field prints as the derive would.
impl std::fmt::Debug for CreateTokenRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CreateTokenRequest")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// JSON body returned by `create_token` and `refresh_token`.
#[derive(Debug, Serialize)]
pub struct TokenResponse {
/// Encoded access-token JWT.
pub access_token: String,
/// Encoded refresh-token JWT.
pub refresh_token: String,
/// Token scheme; both handlers in this file set it to `"Bearer"`.
pub token_type: String,
/// Access-token lifetime in seconds (`TOKEN_EXPIRATION_HOURS * 3600` in both handlers).
pub expires_in: i64,
}
pub async fn create_token(
State(state): State<AppState>,
Json(req): Json<CreateTokenRequest>,
) -> Result<Json<TokenResponse>> {
let user = state
.user_store
.validate_credentials(&req.username, &req.password)
.map_err(|_| Error::AuthError("Invalid credentials".into()))?;
let access_claims =
Claims::new_access_with_username(&user.id, &user.username, user.roles.clone());
let refresh_claims = Claims::new_refresh(&user.id);
let access_token = encode(
&Header::default(),
&access_claims,
&EncodingKey::from_secret(&JWT_SECRET),
)
.map_err(|e| Error::Internal(format!("Failed to create access token: {}", e)))?;
let refresh_token = encode(
&Header::default(),
&refresh_claims,
&EncodingKey::from_secret(&JWT_SECRET),
)
.map_err(|e| Error::Internal(format!("Failed to create refresh token: {}", e)))?;
Ok(Json(TokenResponse {
access_token,
refresh_token,
token_type: "Bearer".to_string(),
expires_in: TOKEN_EXPIRATION_HOURS * 3600,
}))
}
/// JSON body accepted by `refresh_token`.
#[derive(Debug, Deserialize)]
pub struct RefreshTokenRequest {
/// Encoded JWT; the handler rejects it unless its `token_type` claim is `"refresh"`.
pub refresh_token: String,
}
pub async fn refresh_token(
State(_state): State<AppState>,
Json(req): Json<RefreshTokenRequest>,
) -> Result<Json<TokenResponse>> {
let claims = decode::<Claims>(
&req.refresh_token,
&DecodingKey::from_secret(&JWT_SECRET),
&Validation::new(Algorithm::HS256),
)
.map_err(|e| Error::AuthError(format!("Invalid refresh token: {}", e)))?;
if claims.claims.token_type != "refresh" {
return Err(Error::AuthError("Invalid token type".to_string()));
}
if claims.claims.is_expired() {
return Err(Error::AuthError("Refresh token expired".to_string()));
}
if let Some(ref jti) = claims.claims.jti {
if !REVOKED_TOKENS.insert(jti.clone()) {
return Err(Error::AuthError("Refresh token already used".to_string()));
}
} else {
return Err(Error::AuthError("Refresh token missing JTI".to_string()));
}
let roles = vec!["user".to_string()];
let access_claims = Claims::new_access(&claims.claims.sub, roles);
let refresh_claims = Claims::new_refresh(&claims.claims.sub);
let access_token = encode(
&Header::default(),
&access_claims,
&EncodingKey::from_secret(&JWT_SECRET),
)
.map_err(|e| Error::Internal(format!("Failed to create access token: {}", e)))?;
let refresh_token = encode(
&Header::default(),
&refresh_claims,
&EncodingKey::from_secret(&JWT_SECRET),
)
.map_err(|e| Error::Internal(format!("Failed to create refresh token: {}", e)))?;
Ok(Json(TokenResponse {
access_token,
refresh_token,
token_type: "Bearer".to_string(),
expires_in: TOKEN_EXPIRATION_HOURS * 3600,
}))
}
/// JSON body accepted by `verify_token_endpoint`.
#[derive(Debug, Deserialize)]
pub struct VerifyTokenRequest {
/// Encoded JWT to check.
pub token: String,
}
/// JSON body returned by `verify_token_endpoint`.
///
/// The optional fields are `Some` exactly when `valid` is `true`; `None`
/// values are serialized (no `skip_serializing_if` is applied).
#[derive(Debug, Serialize)]
pub struct VerifyTokenResponse {
/// Whether the token passed verification.
pub valid: bool,
/// The token's `sub` claim.
pub user_id: Option<String>,
/// The token's `roles` claim.
pub roles: Option<Vec<String>>,
/// The token's `exp` claim.
pub expires_at: Option<i64>,
}
/// Introspection handler: reports whether the supplied token verifies and,
/// if so, echoes its subject, roles and expiry.
///
/// Never returns an error: any verification failure is reported as
/// `valid: false` with all other fields `None`, and the failure reason is
/// dropped.
pub async fn verify_token_endpoint(
    Json(req): Json<VerifyTokenRequest>,
) -> Json<VerifyTokenResponse> {
    let response = verify_token(&req.token).map_or_else(
        |_| VerifyTokenResponse {
            valid: false,
            user_id: None,
            roles: None,
            expires_at: None,
        },
        |claims| VerifyTokenResponse {
            valid: true,
            user_id: Some(claims.sub),
            roles: Some(claims.roles),
            expires_at: Some(claims.exp),
        },
    );
    Json(response)
}
pub fn verify_token(token: &str) -> Result<Claims> {
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(&JWT_SECRET),
&Validation::new(Algorithm::HS256),
)
.map_err(|e| Error::AuthError(format!("Invalid token: {}", e)))?;
if token_data.claims.is_expired() {
return Err(Error::AuthError("Token expired".to_string()));
}
Ok(token_data.claims)
}
/// JSON body accepted by `register`.
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Login name for the new account.
    pub username: String,
    /// Password as submitted by the client.
    pub password: String,
}

/// Hand-written so the password can never end up in logs or panic messages
/// through `{:?}` formatting; every other field prints as the derive would.
impl std::fmt::Debug for RegisterRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegisterRequest")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// JSON body returned by `register` on success.
#[derive(Debug, Serialize)]
pub struct RegisterResponse {
/// Id of the newly created user.
pub user_id: String,
/// Username of the newly created user, as returned by the user store.
pub username: String,
/// Human-readable confirmation; `register` sets it to "User registered successfully".
pub message: String,
}
pub async fn register(
State(state): State<AppState>,
Json(req): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>> {
let user = state
.user_store
.create_user(&req.username, &req.password, vec!["user".into()])
.map_err(Error::InvalidInput)?;
Ok(Json(RegisterResponse {
user_id: user.id,
username: user.username,
message: "User registered successfully".to_string(),
}))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Freshly built access claims carry the given subject and role and are
    /// not yet expired.
    #[test]
    fn test_claims_creation() {
        let granted = vec!["admin".to_string()];
        let claims = Claims::new_access("user123", granted);

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.token_type, "access");
        assert!(claims.has_role("admin"));
        assert!(!claims.is_expired());
    }

    /// A token signed with the module secret verifies and yields the same
    /// subject.
    #[test]
    fn test_token_roundtrip() {
        // The secret is read lazily, so the variable must be set before the
        // first dereference of `JWT_SECRET` below.
        std::env::set_var("AINGLE_JWT_SECRET", "test-secret-only-do-not-use-in-production-64bytes-pad");

        let issued = Claims::new_access("user123", vec!["user".to_string()]);
        let header = Header::default();
        let signing_key = EncodingKey::from_secret(&JWT_SECRET);
        let token = encode(&header, &issued, &signing_key).unwrap();

        let verified = verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user123");
    }
}