use axum::{
extract::{Extension, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{debug, info, instrument};
use crate::{
auth::AuthUser,
error::{ApiError, ApiResult},
jwt::{Claims, JwtManager, TokenPair},
responses::{ok, ApiResponse},
};
/// Shared state handed to every authentication handler.
///
/// The [`JwtManager`] is kept behind an [`Arc`], so cloning this state only bumps a
/// reference count instead of copying the manager.
///
/// NOTE(review): `Debug` is derived, so formatting this state also formats the inner
/// `JwtManager`. Confirm that its `Debug` output does not include the signing secret.
#[derive(Debug, Clone)]
pub struct AuthHandlerState {
    jwt_manager: Arc<JwtManager>,
}

impl AuthHandlerState {
    /// Takes ownership of `jwt_manager` and places it in shared (`Arc`) storage.
    pub fn new(jwt_manager: JwtManager) -> Self {
        let shared = Arc::new(jwt_manager);
        Self { jwt_manager: shared }
    }

    /// Borrows the JWT manager used to mint and refresh tokens.
    pub fn jwt_manager(&self) -> &JwtManager {
        Arc::as_ref(&self.jwt_manager)
    }
}
/// Credentials submitted to the login endpoint.
///
/// `Debug` is implemented by hand rather than derived. A derived implementation would print
/// the plaintext password whenever the request is formatted with `{:?}` (logs, tracing
/// spans, panic messages). The password is always rendered as `"<redacted>"` instead.
#[derive(Deserialize, Serialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Successful login payload.
///
/// The token pair's fields are flattened into the top level of the JSON object, next to
/// a nested `user` object.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoginResponse {
    /// Issued tokens, serialized inline (flattened) rather than as a nested object.
    #[serde(flatten)]
    pub token_pair: TokenPair,
    /// Profile of the user that just logged in.
    pub user: UserInfo,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UserInfo {
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub roles: Vec<String>,
}
impl UserInfo {
pub fn from_claims(claims: &Claims) -> Self {
Self {
id: claims.sub.clone(),
email: claims.email.clone(),
roles: claims.roles.clone(),
}
}
}
/// Body of a token-refresh request.
///
/// The refresh token is a bearer credential: anyone holding it can mint new access tokens.
/// `Debug` is therefore implemented by hand so the token is rendered as `"<redacted>"`
/// rather than leaking into logs or tracing output via `{:?}`.
#[derive(Deserialize, Serialize)]
pub struct RefreshTokenRequest {
    pub refresh_token: String,
}

impl std::fmt::Debug for RefreshTokenRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshTokenRequest")
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
/// Successful refresh payload: the new token pair, flattened into the top-level JSON object.
#[derive(Serialize, Deserialize, Debug)]
pub struct RefreshTokenResponse {
    /// Newly issued tokens.
    #[serde(flatten)]
    pub token_pair: TokenPair,
}
/// Logs a user in and issues an access/refresh token pair.
///
/// SECURITY NOTE(review): `request.password` is only checked for non-emptiness. It is never
/// compared against any credential store, so any non-empty password is accepted for any
/// username. Replace this with real credential verification before production use.
///
/// The returned `user` object is built from locally constructed claims: the email is
/// synthesized as `<username>@example.com` and the only role is `"user"`. Those claims are
/// NOT passed to `generate_token_pair`, which receives only the username, so the issued
/// tokens may not carry the same email/roles. NOTE(review): confirm against
/// `JwtManager::generate_token_pair`.
///
/// `state` and `request` are skipped by `#[instrument]` so the signing configuration and the
/// password are never recorded in the tracing span.
///
/// # Errors
/// - `ApiError::bad_request` if the username is empty or whitespace-only, or the password is
///   empty.
/// - `ApiError::internal_server_error` if token generation fails.
#[instrument(skip(state, request))]
pub async fn login(
    State(state): State<AuthHandlerState>,
    Json(request): Json<LoginRequest>,
) -> ApiResult<(StatusCode, Json<ApiResponse<LoginResponse>>)> {
    info!("Login attempt for user: {}", request.username);
    // A whitespace-only username is as meaningless as an empty one. Reject it instead of
    // minting tokens whose subject is effectively blank.
    if request.username.trim().is_empty() || request.password.is_empty() {
        return Err(ApiError::bad_request("Username and password are required"));
    }

    let jwt = state.jwt_manager();
    let claims = Claims::new(
        &request.username,
        jwt.config.issuer.clone(),
        jwt.config.audience.clone(),
        jwt.config.expiration_seconds,
    )
    .with_email(format!("{}@example.com", request.username))
    .with_role("user");

    let token_pair = jwt
        .generate_token_pair(&request.username)
        .map_err(|e| ApiError::internal_server_error(format!("Failed to generate token: {}", e)))?;

    let response = LoginResponse {
        token_pair,
        user: UserInfo::from_claims(&claims),
    };
    info!("User logged in successfully: {}", request.username);
    Ok((StatusCode::OK, Json(ok(response))))
}
#[instrument(skip(state, request))]
pub async fn refresh_token(
State(state): State<AuthHandlerState>,
Json(request): Json<RefreshTokenRequest>,
) -> ApiResult<Json<ApiResponse<RefreshTokenResponse>>> {
debug!("Token refresh requested");
let token_pair = state
.jwt_manager()
.refresh_access_token(&request.refresh_token)
.map_err(|e| match e {
crate::jwt::TokenError::Expired => {
ApiError::unauthorized("Refresh token has expired")
}
crate::jwt::TokenError::InvalidClaims(_) => {
ApiError::bad_request("Invalid refresh token")
}
_ => ApiError::unauthorized("Invalid refresh token"),
})?;
let response = RefreshTokenResponse { token_pair };
debug!("Token refreshed successfully");
Ok(Json(ok(response)))
}
/// Returns the profile of the caller, derived from the claims on the request's [`AuthUser`].
///
/// NOTE(review): presumably the `AuthUser` extension is inserted by upstream auth middleware.
/// If it is missing, axum's `Extension` extractor rejects the request before this body runs.
#[instrument(skip(user))]
pub async fn me(
    Extension(user): Extension<AuthUser>,
) -> ApiResult<Json<ApiResponse<UserInfo>>> {
    debug!("Current user info requested");
    Ok(Json(ok(UserInfo::from_claims(&user.claims))))
}
/// Acknowledges a logout for the authenticated caller.
///
/// This handler only logs the event and returns a confirmation message. It does not revoke,
/// blacklist or otherwise invalidate any token, so previously issued tokens stay usable until
/// they expire. The client is responsible for discarding them.
#[instrument(skip(user))]
pub async fn logout(
    Extension(user): Extension<AuthUser>,
) -> ApiResult<Json<ApiResponse<LogoutResponse>>> {
    info!("User logout: {}", user.user_id());
    let message = String::from("Logged out successfully");
    Ok(Json(ok(LogoutResponse { message })))
}
/// Payload returned by the logout endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct LogoutResponse {
    /// Human-readable confirmation text.
    pub message: String,
}
/// Issues a long-lived (30-day) API-key token for the authenticated caller.
///
/// Only callers holding the `admin` or `developer` role may use this endpoint. The key is a
/// JWT signed by the shared [`JwtManager`]. It copies the caller's roles and carries the
/// custom claim `"api_key": true`.
///
/// `state` is skipped by `#[instrument]`, as in the other handlers. Otherwise its `Debug`
/// output, including the inner `JwtManager`, would be recorded into the tracing span.
///
/// # Errors
/// - `ApiError::forbidden` if the caller has neither the `admin` nor the `developer` role.
/// - `ApiError::internal_server_error` if signing the token fails.
#[instrument(skip(state, user))]
pub async fn generate_api_key(
    State(state): State<AuthHandlerState>,
    Extension(user): Extension<AuthUser>,
) -> ApiResult<Json<ApiResponse<ApiKeyResponse>>> {
    info!("Generating API key for user: {}", user.user_id());
    if !user.has_role("admin") && !user.has_role("developer") {
        return Err(ApiError::forbidden(
            "Only admin or developer roles can generate API keys",
        ));
    }

    // API keys live for 30 days. The expiry argument is in seconds, matching
    // `config.expiration_seconds` as used by `login`.
    let api_key_ttl_seconds = 30 * 24 * 60 * 60;
    let jwt = state.jwt_manager();
    let claims = Claims::new(
        user.user_id(),
        jwt.config.issuer.clone(),
        jwt.config.audience.clone(),
        api_key_ttl_seconds,
    )
    .with_roles(user.claims.roles.clone())
    .with_custom("api_key", serde_json::json!(true));

    let api_key = jwt
        .generate_token_with_claims(claims)
        .map_err(|e| ApiError::internal_server_error(format!("Failed to generate API key: {}", e)))?;

    let response = ApiKeyResponse { api_key };
    info!("API key generated for user: {}", user.user_id());
    Ok(Json(ok(response)))
}
/// Payload returned by the API-key endpoint.
///
/// NOTE(review): `api_key` is a bearer credential and `Debug` is derived, so formatting this
/// value with `{:?}` prints the key. Avoid logging it.
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiKeyResponse {
    /// The signed API-key token.
    pub api_key: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwt::JwtConfig;

    /// Builds handler state backed by a JWT manager with a fixed test secret,
    /// issuer and audience.
    fn create_test_state() -> AuthHandlerState {
        let jwt_manager = JwtManager::new(
            JwtConfig::new("test-secret")
                .with_issuer("test")
                .with_audience("test"),
        )
        .expect("test JWT configuration must be valid");
        AuthHandlerState::new(jwt_manager)
    }

    #[test]
    fn test_user_info_from_claims() {
        let info = UserInfo::from_claims(
            &Claims::new("user123", "test", "test", 3600)
                .with_email("user@example.com")
                .with_role("admin"),
        );
        assert_eq!(info.id, "user123");
        assert_eq!(info.email.as_deref(), Some("user@example.com"));
        assert_eq!(info.roles, vec!["admin"]);
    }

    #[tokio::test]
    async fn test_login_request_validation() {
        let request = LoginRequest {
            username: String::new(),
            password: String::from("password"),
        };
        let outcome = login(State(create_test_state()), Json(request)).await;
        assert!(outcome.is_err(), "an empty username must be rejected");
    }
}