use axum::{extract::State, Json};
use chrono::{Duration, Utc};
use serde::{Deserialize, Serialize};
use crate::{
auth::{
create_token_pair, hash_password, verify_password, verify_refresh_token,
REFRESH_TOKEN_EXPIRY_DAYS,
},
error::{ApiError, ApiResult},
middleware::AuthUser,
models::organization::Plan,
AppState,
};
/// JSON body accepted by [`register`].
#[derive(Deserialize, Debug)]
pub struct RegisterRequest {
    /// Requested username; `register` rejects it if already taken.
    pub username: String,
    /// Account email; `register` rejects it if already registered.
    pub email: String,
    /// Plain-text password; hashed with `hash_password` before storage.
    pub password: String,
}
/// JSON body accepted by [`login`].
#[derive(Deserialize, Debug)]
pub struct LoginRequest {
    /// Email address used to look up the account.
    pub email: String,
    /// Plain-text password, checked against the stored hash.
    pub password: String,
    /// TOTP or backup code; only consulted when the account has 2FA enabled.
    pub two_factor_code: Option<String>,
}
/// Single-token authentication response.
///
/// NOTE(review): none of the handlers shown here return this type (they use
/// [`AuthResponseV2`]); presumably kept for older clients — confirm before removing.
#[derive(Serialize, Debug)]
pub struct AuthResponse {
    /// Bearer token for the authenticated user.
    pub token: String,
    /// The user's id, rendered as a string.
    pub user_id: String,
    /// The user's username.
    pub username: String,
}
/// Response returned by [`register`] and [`login`]: an access/refresh token
/// pair plus the user's identity.
#[derive(Serialize, Debug)]
pub struct AuthResponseV2 {
    /// Short-lived bearer token for API calls.
    pub access_token: String,
    /// Token exchanged at the refresh endpoint for a new pair.
    pub refresh_token: String,
    /// Expiry of `access_token`; presumably a Unix timestamp — confirm units.
    pub access_token_expires_at: i64,
    /// Expiry of `refresh_token`; presumably a Unix timestamp — confirm units.
    pub refresh_token_expires_at: i64,
    /// The user's id, rendered as a string.
    pub user_id: String,
    /// The user's username.
    pub username: String,
}
/// Registers a new account, creates its personal organization, and signs the
/// new user in by issuing an access/refresh token pair.
///
/// Validation rules:
/// - the username must be at least 3 characters and the password at least 8
///   characters, both counted as Unicode scalar values (`chars()`), not bytes;
/// - the email must contain `@`;
/// - neither the email nor the username may already be registered.
///
/// Failure to create the personal organization is logged and otherwise
/// ignored, so registration still succeeds without it.
///
/// # Errors
/// - `ApiError::InvalidRequest` for validation failures or duplicates.
/// - `ApiError::Internal` if hashing, token creation, expiry calculation, or
///   refresh-token persistence fails.
pub async fn register(
    State(state): State<AppState>,
    Json(request): Json<RegisterRequest>,
) -> ApiResult<Json<AuthResponseV2>> {
    // Count characters rather than bytes: `len()` would let a 3-character
    // non-ASCII password (9+ bytes) satisfy the 8-character minimum.
    if request.username.chars().count() < 3 {
        return Err(ApiError::InvalidRequest("Username must be at least 3 characters".to_string()));
    }
    if request.password.chars().count() < 8 {
        return Err(ApiError::InvalidRequest("Password must be at least 8 characters".to_string()));
    }
    // Minimal sanity check that rejects obviously malformed addresses such as "".
    if !request.email.contains('@') {
        return Err(ApiError::InvalidRequest("Invalid email address".to_string()));
    }
    // NOTE(review): these uniqueness checks and `create_user` are not atomic;
    // presumably the store also enforces uniqueness — confirm.
    if state.store.find_user_by_email(&request.email).await?.is_some() {
        return Err(ApiError::InvalidRequest("Email already registered".to_string()));
    }
    if state.store.find_user_by_username(&request.username).await?.is_some() {
        return Err(ApiError::InvalidRequest("Username already taken".to_string()));
    }
    let password_hash = hash_password(&request.password).map_err(ApiError::Internal)?;
    let user = state
        .store
        .create_user(&request.username, &request.email, &password_hash)
        .await?;
    // Best-effort personal organization: a failure here must not fail signup.
    let org_slug = format!("{}-personal", request.username.to_lowercase().replace(' ', "-"));
    if let Err(e) = state
        .store
        .create_organization(&format!("{}'s Org", request.username), &org_slug, user.id, Plan::Free)
        .await
    {
        tracing::warn!("Failed to create personal org for user {}: {}", user.id, e);
    }
    let (token_pair, jti) = create_token_pair(&user.id.to_string(), &state.config.jwt_secret)
        .map_err(ApiError::Internal)?;
    // Persist the refresh token's JTI with its expiry so it can later be
    // looked up and revoked.
    let expires_at = Utc::now()
        .checked_add_signed(Duration::days(REFRESH_TOKEN_EXPIRY_DAYS))
        .ok_or_else(|| ApiError::Internal(anyhow::anyhow!("Failed to calculate token expiry")))?;
    state.db.store_refresh_token_jti(&jti, user.id, expires_at).await.map_err(|e| {
        tracing::warn!("Failed to store refresh token JTI: {}", e);
        ApiError::Internal(e)
    })?;
    Ok(Json(AuthResponseV2 {
        access_token: token_pair.access_token,
        refresh_token: token_pair.refresh_token,
        access_token_expires_at: token_pair.access_token_expires_at,
        refresh_token_expires_at: token_pair.refresh_token_expires_at,
        user_id: user.id.to_string(),
        username: user.username,
    }))
}
pub async fn login(
State(state): State<AppState>,
Json(request): Json<LoginRequest>,
) -> ApiResult<Json<AuthResponseV2>> {
let user = state
.store
.find_user_by_email(&request.email)
.await?
.ok_or_else(|| ApiError::InvalidRequest("Invalid email or password".to_string()))?;
let valid =
verify_password(&request.password, &user.password_hash).map_err(ApiError::Internal)?;
if !valid {
return Err(ApiError::InvalidRequest("Invalid email or password".to_string()));
}
if user.two_factor_enabled {
let code = request
.two_factor_code
.ok_or_else(|| ApiError::InvalidRequest("2FA code is required".to_string()))?;
let secret = user.two_factor_secret.ok_or_else(|| {
ApiError::Internal(anyhow::anyhow!("2FA enabled but no secret found"))
})?;
use crate::two_factor::verify_totp_code;
let totp_valid = verify_totp_code(&secret, &code, Some(1))
.map_err(|e| ApiError::Internal(anyhow::anyhow!("TOTP verification error: {}", e)))?;
if !totp_valid {
let mut backup_valid = false;
if let Some(backup_codes) = &user.two_factor_backup_codes {
use crate::two_factor::verify_backup_code;
for (index, hashed_code) in backup_codes.iter().enumerate() {
if verify_backup_code(&code, hashed_code).map_err(|e| {
ApiError::Internal(anyhow::anyhow!("Backup code verification error: {}", e))
})? {
state.store.remove_user_backup_code(user.id, index).await?;
backup_valid = true;
break;
}
}
}
if !backup_valid {
return Err(ApiError::InvalidRequest("Invalid 2FA code".to_string()));
}
}
state.store.update_user_2fa_verified(user.id).await?;
}
let (token_pair, jti) = create_token_pair(&user.id.to_string(), &state.config.jwt_secret)
.map_err(ApiError::Internal)?;
let expires_at = Utc::now()
.checked_add_signed(Duration::days(REFRESH_TOKEN_EXPIRY_DAYS))
.ok_or_else(|| ApiError::Internal(anyhow::anyhow!("Failed to calculate token expiry")))?;
state.db.store_refresh_token_jti(&jti, user.id, expires_at).await.map_err(|e| {
tracing::warn!("Failed to store refresh token JTI: {}", e);
ApiError::Internal(e)
})?;
Ok(Json(AuthResponseV2 {
access_token: token_pair.access_token,
refresh_token: token_pair.refresh_token,
access_token_expires_at: token_pair.access_token_expires_at,
refresh_token_expires_at: token_pair.refresh_token_expires_at,
user_id: user.id.to_string(),
username: user.username,
}))
}
/// JSON body accepted by [`refresh_token`].
#[derive(Deserialize, Debug)]
pub struct RefreshTokenRequest {
    /// Refresh token previously issued by login, registration, or a prior refresh.
    pub refresh_token: String,
}
/// Freshly rotated token pair returned by [`refresh_token`].
#[derive(Serialize, Debug)]
pub struct RefreshTokenResponse {
    /// New short-lived bearer token.
    pub access_token: String,
    /// Replacement refresh token; the presented one is revoked.
    pub refresh_token: String,
    /// Expiry of `access_token`; presumably a Unix timestamp — confirm units.
    pub access_token_expires_at: i64,
    /// Expiry of `refresh_token`; presumably a Unix timestamp — confirm units.
    pub refresh_token_expires_at: i64,
}
/// Rotates a refresh token.
///
/// The handler verifies the presented token and rejects it if its JTI has been
/// revoked. It then revokes that JTI, issues a new access/refresh pair, and
/// records the new JTI.
///
/// NOTE(review): the revocation check and the revoke are separate calls, so two
/// concurrent requests with the same token could both succeed — confirm
/// whether the store serialises this.
///
/// # Errors
/// - `ApiError::InvalidRequest` if the token fails verification, is revoked,
///   carries a non-UUID `sub`, or names a user that no longer exists.
/// - `ApiError::Internal` if the revocation store, token creation, or expiry
///   calculation fails.
pub async fn refresh_token(
    State(state): State<AppState>,
    Json(request): Json<RefreshTokenRequest>,
) -> ApiResult<Json<RefreshTokenResponse>> {
    let verified = verify_refresh_token(&request.refresh_token, &state.config.jwt_secret);
    let (claims, old_jti) = match verified {
        Ok(parts) => parts,
        Err(e) => {
            tracing::debug!("Refresh token validation failed: {}", e);
            return Err(ApiError::InvalidRequest("Invalid or expired refresh token".to_string()));
        }
    };

    let already_revoked = match state.db.is_token_revoked(&old_jti).await {
        Ok(flag) => flag,
        Err(e) => {
            tracing::warn!("Failed to check token revocation status: {}", e);
            return Err(ApiError::Internal(e));
        }
    };
    if already_revoked {
        tracing::warn!("Attempt to use revoked refresh token: jti={}", old_jti);
        return Err(ApiError::InvalidRequest("Refresh token has been revoked".to_string()));
    }

    let owner_id = match uuid::Uuid::parse_str(&claims.sub) {
        Ok(id) => id,
        Err(_) => return Err(ApiError::InvalidRequest("Invalid user ID".to_string())),
    };
    let owner = match state.store.find_user_by_id(owner_id).await? {
        Some(found) => found,
        None => return Err(ApiError::InvalidRequest("User not found".to_string())),
    };

    // Rotation: the presented token is invalidated before a new one is minted.
    if let Err(e) = state.db.revoke_token(&old_jti, "refresh").await {
        tracing::warn!("Failed to revoke old refresh token: {}", e);
        return Err(ApiError::Internal(e));
    }

    let (pair, fresh_jti) = create_token_pair(&owner.id.to_string(), &state.config.jwt_secret)
        .map_err(ApiError::Internal)?;
    let refresh_expiry = Utc::now()
        .checked_add_signed(Duration::days(REFRESH_TOKEN_EXPIRY_DAYS))
        .ok_or_else(|| ApiError::Internal(anyhow::anyhow!("Failed to calculate token expiry")))?;
    if let Err(e) = state
        .db
        .store_refresh_token_jti(&fresh_jti, owner.id, refresh_expiry)
        .await
    {
        tracing::warn!("Failed to store new refresh token JTI: {}", e);
        return Err(ApiError::Internal(e));
    }

    Ok(Json(RefreshTokenResponse {
        access_token: pair.access_token,
        refresh_token: pair.refresh_token,
        access_token_expires_at: pair.access_token_expires_at,
        refresh_token_expires_at: pair.refresh_token_expires_at,
    }))
}
use crate::email::EmailService;
/// JSON body accepted by [`request_password_reset`].
#[derive(Deserialize, Debug)]
pub struct PasswordResetRequest {
    /// Email address of the account whose password should be reset.
    pub email: String,
}
/// Response from [`request_password_reset`]; deliberately identical whether
/// or not the email belongs to an account.
#[derive(Serialize, Debug)]
pub struct PasswordResetRequestResponse {
    /// Always `true` as returned by `request_password_reset`.
    pub success: bool,
    /// Generic, non-revealing status message.
    pub message: String,
}
/// Starts a password reset by emailing the account owner a reset link.
///
/// The same generic response is returned whether or not the email is
/// registered and whether or not the mailer is available, so the endpoint
/// cannot be used to probe which addresses have accounts. The email is sent
/// on a spawned task; delivery failures are only logged.
///
/// NOTE(review): reset tokens come from `create_verification_token`. Confirm
/// that other kinds of verification token cannot be redeemed at the
/// reset-confirmation endpoint.
///
/// # Errors
/// Only store failures (user lookup, token creation, or setting the expiry)
/// are propagated.
pub async fn request_password_reset(
    State(state): State<AppState>,
    Json(request): Json<PasswordResetRequest>,
) -> ApiResult<Json<PasswordResetRequestResponse>> {
    /// The single, non-revealing response used for every outcome.
    fn generic_response() -> Json<PasswordResetRequestResponse> {
        Json(PasswordResetRequestResponse {
            success: true,
            message: "If an account with that email exists, a password reset link has been sent."
                .to_string(),
        })
    }

    let user = match state.store.find_user_by_email(&request.email).await? {
        Some(user) => user,
        None => return Ok(generic_response()),
    };
    // Resolve the mailer before minting a token, so a live reset token is never
    // written to the store when it could not be delivered anyway.
    let email_service = match EmailService::from_env() {
        Ok(service) => service,
        Err(e) => {
            tracing::warn!("Failed to create email service: {}", e);
            return Ok(generic_response());
        }
    };
    let reset_token = state.store.create_verification_token(user.id).await?;
    // Reset links get a short lifetime (1, per `set_verification_token_expiry_hours`).
    state.store.set_verification_token_expiry_hours(reset_token.id, 1).await?;
    let reset_email = EmailService::generate_password_reset_email(
        &user.username,
        &user.email,
        &reset_token.token,
    );
    tokio::spawn(async move {
        if let Err(e) = email_service.send(reset_email).await {
            tracing::warn!("Failed to send password reset email: {}", e);
        }
    });
    // Log only the user id; the email address is personal data.
    tracing::info!("Password reset requested: user_id={}", user.id);
    Ok(generic_response())
}
/// JSON body accepted by [`confirm_password_reset`].
#[derive(Deserialize, Debug)]
pub struct PasswordResetConfirmRequest {
    /// Reset token delivered in the password-reset email.
    pub token: String,
    /// Plain-text replacement password; hashed before storage.
    pub new_password: String,
}
/// Response returned by [`confirm_password_reset`] on success.
#[derive(Serialize, Debug)]
pub struct PasswordResetConfirmResponse {
    /// `true` when the password was reset.
    pub success: bool,
    /// Human-readable confirmation.
    pub message: String,
}
/// Completes a password reset using the token from the reset email.
///
/// The handler:
/// 1. Validates the new password (at least 8 characters, counted with `chars()`).
/// 2. Checks that the token exists and `is_valid()`.
/// 3. Stores the new password hash.
/// 4. Revokes all of the user's tokens.
/// 5. Marks the reset token as used.
///
/// NOTE(review): the token is consumed last, so a failure part-way leaves it
/// reusable for a retry. Two concurrent submissions could also both pass
/// `is_valid()`. Consider consuming it atomically in the store.
///
/// # Errors
/// - `ApiError::InvalidRequest` for a short password, or an unknown,
///   expired, or used token, or a missing user.
/// - `ApiError::Internal` if hashing or token revocation fails.
pub async fn confirm_password_reset(
    State(state): State<AppState>,
    Json(request): Json<PasswordResetConfirmRequest>,
) -> ApiResult<Json<PasswordResetConfirmResponse>> {
    // Count characters, not bytes, so the rule matches the error message.
    if request.new_password.chars().count() < 8 {
        return Err(ApiError::InvalidRequest("Password must be at least 8 characters".to_string()));
    }
    let reset_token = state
        .store
        .find_verification_token_by_token(&request.token)
        .await?
        .ok_or_else(|| ApiError::InvalidRequest("Invalid or expired reset token".to_string()))?;
    if !reset_token.is_valid() {
        return Err(ApiError::InvalidRequest(
            "Reset token has expired or already been used".to_string(),
        ));
    }
    let user = state
        .store
        .find_user_by_id(reset_token.user_id)
        .await?
        .ok_or_else(|| ApiError::InvalidRequest("User not found".to_string()))?;
    let password_hash = hash_password(&request.new_password).map_err(ApiError::Internal)?;
    state.store.update_user_password_hash(user.id, &password_hash).await?;
    // Sign out every existing session so a compromised session does not
    // survive the reset.
    let revoked_count =
        state.db.revoke_all_user_tokens(user.id, "password_reset").await.map_err(|e| {
            tracing::warn!("Failed to revoke user tokens on password reset: {}", e);
            ApiError::Internal(e)
        })?;
    tracing::info!(
        "Revoked {} refresh tokens for user {} on password reset",
        revoked_count,
        user.id
    );
    state.store.mark_verification_token_used(reset_token.id).await?;
    // Log only the user id; the email address is personal data.
    tracing::info!("Password reset completed: user_id={}", user.id);
    Ok(Json(PasswordResetConfirmResponse {
        success: true,
        message: "Password has been reset successfully. You can now log in with your new password."
            .to_string(),
    }))
}
/// Identity information returned by [`verify_token`].
#[derive(Serialize, Debug)]
pub struct VerifyTokenResponse {
    /// `true` whenever `verify_token` returns successfully.
    pub valid: bool,
    /// The user's id, rendered as a string.
    pub user_id: String,
    /// The user's username.
    pub username: String,
    /// The user's email address.
    pub email: String,
}
/// Confirms that the caller is authenticated and returns the basic identity
/// of the user the token belongs to.
///
/// The user id comes from the `AuthUser` extractor, which presumably rejects
/// unauthenticated requests before this handler runs. This handler only
/// resolves the user record.
///
/// # Errors
/// `ApiError::InvalidRequest("User not found")` if the id no longer maps to a
/// stored user.
pub async fn verify_token(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
) -> ApiResult<Json<VerifyTokenResponse>> {
    match state.store.find_user_by_id(user_id).await? {
        Some(account) => Ok(Json(VerifyTokenResponse {
            valid: true,
            user_id: account.id.to_string(),
            username: account.username,
            email: account.email,
        })),
        None => Err(ApiError::InvalidRequest("User not found".to_string())),
    }
}
#[derive(Debug, Serialize)]
pub struct MeResponse {
pub user_id: String,
pub username: String,
pub email: String,
pub is_verified: bool,
pub is_admin: bool,
pub two_factor_enabled: bool,
pub email_notifications: bool,
pub security_alerts: bool,
pub preferences: serde_json::Value,
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Returns the authenticated user's profile and settings.
///
/// # Errors
/// `ApiError::InvalidRequest("User not found")` if the authenticated id has no
/// matching user record.
pub async fn me(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
) -> ApiResult<Json<MeResponse>> {
    let lookup = state.store.find_user_by_id(user_id).await?;
    let profile = match lookup {
        Some(found) => found,
        None => return Err(ApiError::InvalidRequest("User not found".to_string())),
    };
    let body = MeResponse {
        user_id: profile.id.to_string(),
        username: profile.username,
        email: profile.email,
        is_verified: profile.is_verified,
        is_admin: profile.is_admin,
        two_factor_enabled: profile.two_factor_enabled,
        email_notifications: profile.email_notifications,
        security_alerts: profile.security_alerts,
        preferences: profile.preferences,
        created_at: profile.created_at,
    };
    Ok(Json(body))
}
/// JSON body accepted by [`change_password`].
#[derive(Deserialize, Debug)]
pub struct ChangePasswordRequest {
    /// The user's current password, re-verified before any change.
    pub current_password: String,
    /// Plain-text replacement password; must differ from the current one.
    pub new_password: String,
}
/// Response returned by [`change_password`] on success.
#[derive(Serialize, Debug)]
pub struct ChangePasswordResponse {
    /// `true` when the password was changed.
    pub success: bool,
    /// Human-readable confirmation.
    pub message: String,
}
pub async fn change_password(
State(state): State<AppState>,
AuthUser(user_id): AuthUser,
Json(request): Json<ChangePasswordRequest>,
) -> ApiResult<Json<ChangePasswordResponse>> {
if request.new_password.len() < 8 {
return Err(ApiError::InvalidRequest("Password must be at least 8 characters".to_string()));
}
if request.new_password == request.current_password {
return Err(ApiError::InvalidRequest(
"New password must differ from the current password".to_string(),
));
}
let user = state
.store
.find_user_by_id(user_id)
.await?
.ok_or_else(|| ApiError::InvalidRequest("User not found".to_string()))?;
if !verify_password(&request.current_password, &user.password_hash)
.map_err(ApiError::Internal)?
{
return Err(ApiError::InvalidRequest("Current password is incorrect".to_string()));
}
let password_hash = hash_password(&request.new_password).map_err(ApiError::Internal)?;
state.store.update_user_password_hash(user.id, &password_hash).await?;
let revoked_count = state
.db
.revoke_all_user_tokens(user.id, "password_changed")
.await
.map_err(|e| {
tracing::warn!("Failed to revoke user tokens on password change: {}", e);
ApiError::Internal(e)
})?;
tracing::info!(
"Password changed: user_id={}, revoked {} refresh tokens",
user.id,
revoked_count
);
if user.security_alerts {
if let Ok(email_service) = EmailService::from_env() {
let msg = EmailService::generate_security_alert_email(
&user.username,
&user.email,
"Your password was changed",
"If you did not perform this change, reset your password immediately and contact support.",
);
if let Err(e) = email_service.send(msg).await {
tracing::warn!("Failed to send password-change security alert: {}", e);
}
}
}
Ok(Json(ChangePasswordResponse {
success: true,
message: "Password changed successfully. Other sessions have been signed out.".to_string(),
}))
}