use axum::{
Json,
extract::FromRequestParts,
http::{StatusCode, header::AUTHORIZATION, request::Parts},
response::{IntoResponse, Response},
};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use std::env;
/// Returns the HMAC secret used to sign and verify JWTs.
///
/// `JWT_SECRET` is read on every call. If the variable is unset, is not valid
/// Unicode, or is empty/whitespace-only, the built-in development secret is
/// returned instead. This means a blank variable can never produce an empty
/// signing key.
///
/// SECURITY: the fallback value is visible in source. Any deployment that does
/// not set `JWT_SECRET` will accept tokens forged with it.
fn get_jwt_secret() -> String {
    env::var("JWT_SECRET")
        .ok()
        .filter(|secret| !secret.trim().is_empty())
        .unwrap_or_else(|| "your-secret-key-change-in-production".to_string())
}
/// JWT payload written by `generate_jwt` and decoded by the `AuthUser` extractor.
///
/// serde serializes fields in declaration order, so reordering them changes
/// the encoded token bytes.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Subject: the user id; becomes `AuthUser::id` after decoding.
    pub sub: String,
    /// Expiry time in seconds since the Unix epoch (`generate_jwt` sets iat + 86400).
    pub exp: usize,
    /// Issued-at time in seconds since the Unix epoch.
    pub iat: usize,
    /// Role name; becomes `AuthUser::role` after decoding.
    pub role: String,
}
/// API key record returned as JSON by `generate_api_key`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiKey {
    /// The key string. `generate_api_key` produces `sk_live_` followed by a v4 UUID with hyphens removed.
    pub key: String,
    /// Human-readable label (`generate_api_key` uses "Key for <user id>").
    pub name: String,
    /// Role granted to the key (`generate_api_key` always sets "api_user").
    pub role: String,
    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,
}
/// An authenticated caller, produced by this type's `FromRequestParts` impl.
#[derive(Debug, Clone)]
pub struct AuthUser {
    /// The JWT `sub` claim, or an `api_key_…` id derived from the API key.
    pub id: String,
    /// The JWT `role` claim, or "api_user" for API-key callers.
    pub role: String,
}
/// Authentication/authorization failure, serialized as the JSON response body.
#[derive(Debug, Serialize)]
pub struct AuthError {
    /// Machine-readable code, e.g. "invalid_token", "no_auth", "insufficient_permissions".
    pub error: String,
    /// Human-readable description.
    pub message: String,
}
impl IntoResponse for AuthError {
    /// Renders the error as a JSON body with `error` and `message` fields.
    ///
    /// The HTTP status is chosen from the `error` code:
    /// - `insufficient_permissions` → 403 Forbidden. The caller authenticated
    ///   but lacks the required role, so re-sending credentials won't help.
    /// - `token_generation_failed` → 500 Internal Server Error. This is a
    ///   server-side failure, not a problem with the client's credentials.
    /// - anything else → 401 Unauthorized.
    fn into_response(self) -> Response {
        let status = match self.error.as_str() {
            "insufficient_permissions" => StatusCode::FORBIDDEN,
            "token_generation_failed" => StatusCode::INTERNAL_SERVER_ERROR,
            _ => StatusCode::UNAUTHORIZED,
        };
        (status, Json(self)).into_response()
    }
}
impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    /// Authenticates a request from its headers.
    ///
    /// Credentials are tried in this order:
    /// 1. `Authorization: Bearer <jwt>`. The token is decoded with the secret
    ///    from `get_jwt_secret` and `Validation::default()`. If a bearer token
    ///    is present but fails to decode, the request is rejected immediately;
    ///    it does not fall back to the API key.
    /// 2. `X-API-Key: <key>`, checked with `validate_api_key`. A match yields
    ///    role "api_user".
    ///
    /// # Errors
    /// - `invalid_token`: the bearer token failed to decode or validate.
    /// - `invalid_api_key`: the `X-API-Key` value is not visible ASCII.
    /// - `no_auth`: no credential was supplied, or the API key is unknown.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        if let Some(token) = bearer_token(parts) {
            let token_data = decode::<Claims>(
                token,
                &DecodingKey::from_secret(get_jwt_secret().as_ref()),
                &Validation::default(),
            )
            .map_err(|_| AuthError {
                error: "invalid_token".to_string(),
                message: "Invalid or expired JWT token".to_string(),
            })?;
            return Ok(AuthUser {
                id: token_data.claims.sub,
                role: token_data.claims.role,
            });
        }

        if let Some(api_key) = parts.headers.get("X-API-Key") {
            let key = api_key.to_str().map_err(|_| AuthError {
                error: "invalid_api_key".to_string(),
                message: "Invalid API key format".to_string(),
            })?;
            if validate_api_key(key) {
                return Ok(AuthUser {
                    id: api_key_user_id(key),
                    role: "api_user".to_string(),
                });
            }
        }

        Err(AuthError {
            error: "no_auth".to_string(),
            message: "No valid authentication provided".to_string(),
        })
    }
}

/// Returns the token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, as RFC 7235 §2.1 requires.
/// Returns `None` in three cases: the header is missing, its value is not
/// visible ASCII, or it uses a different scheme.
fn bearer_token(parts: &Parts) -> Option<&str> {
    let value = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(token)
}

/// Derives the `AuthUser::id` for an API-key caller.
///
/// The id uses the first 8 characters after the key's last `_`, falling back
/// to the whole key when there is no non-empty suffix. Skipping the type
/// prefix is required: keys like `sk_live_…`/`sk_test_…` begin with an
/// 8-character prefix. Taking the first 8 characters of the whole key would
/// therefore give every such key the same id.
/// The id embeds part of the key material, so treat it as sensitive.
fn api_key_user_id(key: &str) -> String {
    let secret_part = match key.rsplit_once('_') {
        Some((_, tail)) if !tail.is_empty() => tail,
        _ => key,
    };
    let fingerprint: String = secret_part.chars().take(8).collect();
    format!("api_key_{}", fingerprint)
}
/// Issues a signed JWT for `user_id` carrying `role`.
///
/// The token's `iat` is the current time and its `exp` is 24 hours later,
/// both in seconds since the Unix epoch. It is signed with
/// `Header::default()` and the secret from `get_jwt_secret`.
///
/// # Errors
/// Propagates any error returned by `jsonwebtoken::encode`.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
pub fn generate_jwt(user_id: &str, role: &str) -> Result<String, jsonwebtoken::errors::Error> {
    // How long an issued token stays valid, in seconds (24 hours).
    const TOKEN_TTL_SECS: usize = 86400;

    let issued_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs() as usize;

    let secret = get_jwt_secret();
    let signing_key = EncodingKey::from_secret(secret.as_bytes());

    let claims = Claims {
        sub: user_id.to_owned(),
        exp: issued_at + TOKEN_TTL_SECS,
        iat: issued_at,
        role: role.to_owned(),
    };

    encode(&Header::default(), &claims, &signing_key)
}
/// Returns `true` when `key` exactly matches one of the configured API tokens.
///
/// Tokens come from the comma-separated `API_TOKENS` environment variable.
/// Each entry is trimmed, and empty entries are ignored. If the variable is
/// unset or not valid Unicode, two built-in sample tokens are used instead.
///
/// `key` comes straight from a request header. To limit timing side channels:
/// - each comparison does not stop at the first differing byte;
/// - every configured token is checked, even after a match.
///
/// This is best-effort, because Rust does not guarantee constant-time codegen.
/// Key and token lengths can also still affect timing.
fn validate_api_key(key: &str) -> bool {
    /// Byte-wise equality that accumulates differences instead of short-circuiting.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    let api_tokens = env::var("API_TOKENS")
        .unwrap_or_else(|_| "sk_live_123456789,sk_test_987654321".to_string());

    api_tokens
        .split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        // Non-short-circuiting `|` so a match does not end the scan early.
        .fold(false, |found, token| {
            found | constant_time_eq(token.as_bytes(), key.as_bytes())
        })
}
/// Builds a guard that admits users whose role is `required_role` or `"admin"`.
///
/// The returned closure borrows `required_role`. The `"admin"` role passes
/// every check.
///
/// # Errors
/// The closure returns an `insufficient_permissions` error when the user's
/// role matches neither.
pub fn require_role(required_role: &str) -> impl Fn(&AuthUser) -> Result<(), AuthError> + '_ {
    move |user: &AuthUser| {
        let role = user.role.as_str();
        let permitted = role == required_role || role == "admin";
        if !permitted {
            return Err(AuthError {
                error: "insufficient_permissions".to_string(),
                message: format!("This endpoint requires {} role", required_role),
            });
        }
        Ok(())
    }
}
/// JSON request body accepted by `login`.
// NOTE(review): this also derives Serialize, so the password would appear in
// any serialized form. Confirm that Serialize is actually needed.
#[derive(Serialize, Deserialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
}
/// JSON response body returned by `login` on success.
#[derive(Serialize)]
pub struct LoginResponse {
    /// Signed JWT produced by `generate_jwt`.
    pub token: String,
    /// Id of the authenticated user (also the token's `sub` claim).
    pub user_id: String,
    /// Role of the authenticated user (also the token's `role` claim).
    pub role: String,
}
/// Exchanges a username/password pair for a signed JWT.
///
/// NOTE(review): the only accepted accounts are two hard-coded demo
/// credentials (admin/admin123 and user/user123). Replace this with a real
/// user store and password hashing before production use.
///
/// # Errors
/// - `invalid_credentials`: the pair matches no known account.
/// - `token_generation_failed`: `generate_jwt` returned an error.
pub async fn login(Json(payload): Json<LoginRequest>) -> Result<Json<LoginResponse>, AuthError> {
    let (user_id, role) = match (payload.username.as_str(), payload.password.as_str()) {
        ("admin", "admin123") => ("admin_001", "admin"),
        ("user", "user123") => ("user_001", "user"),
        _ => {
            return Err(AuthError {
                error: "invalid_credentials".to_string(),
                message: "Invalid username or password".to_string(),
            });
        }
    };

    let token = generate_jwt(user_id, role).map_err(|_| AuthError {
        error: "token_generation_failed".to_string(),
        message: "Failed to generate authentication token".to_string(),
    })?;

    Ok(Json(LoginResponse {
        token,
        user_id: user_id.to_owned(),
        role: role.to_owned(),
    }))
}
/// Mints a new `sk_live_` API key. Only callers with the `admin` role may call it.
///
/// The key is `sk_live_` followed by a v4 UUID with its hyphens removed, and
/// it is granted the "api_user" role.
///
/// NOTE(review): the key is only returned, never stored. `validate_api_key`
/// accepts only keys listed in `API_TOKENS` (or its built-in defaults), so a
/// key minted here will not authenticate until it is added there.
///
/// # Errors
/// Returns `insufficient_permissions` for non-admin callers.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
pub async fn generate_api_key(user: AuthUser) -> Result<Json<ApiKey>, AuthError> {
    let admin_only = require_role("admin");
    admin_only(&user)?;

    let random_part: String = uuid::Uuid::new_v4().to_string().replace("-", "");
    let created_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    Ok(Json(ApiKey {
        key: format!("sk_live_{}", random_part),
        name: format!("Key for {}", user.id),
        role: "api_user".to_string(),
        created_at,
    }))
}