use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use rand::Rng;
use serde::Deserialize;
use std::sync::Arc;
use std::time::Duration;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::callback::AuthCallback;
use crate::errors::AppError;
use crate::models::MessageResponse;
use crate::repositories::{
default_expiry, generate_verification_token, hash_verification_token, AuditEventType, TokenType,
};
use crate::services::EmailService;
use crate::AppState;
/// Sleeps for a uniformly random 50–150 ms (both ends inclusive).
///
/// `forgot_password` awaits this on its early-return branches (unknown email,
/// or an account without a password) so those responses are not trivially
/// faster than the branch that performs real work.
async fn add_timing_normalization_delay() {
    // Build the duration in its own statement so the non-`Send` thread-local
    // RNG handle is dropped before the await point.
    let jitter = Duration::from_millis(rand::thread_rng().gen_range(50..=150));
    tokio::time::sleep(jitter).await;
}
/// JSON body accepted by the forgot-password endpoint.
///
/// Keys are camelCase on the wire, e.g. `{"email": "user@example.com"}`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ForgotPasswordRequest {
    /// Email address to start the reset flow for, exactly as submitted.
    pub email: String,
}
/// JSON body accepted by the reset-password endpoint.
///
/// Both fields are secrets: `token` is the plaintext reset token and
/// `new_password` is the user's chosen password. The struct is wiped on drop
/// (`ZeroizeOnDrop`), and `Debug` is implemented by hand to redact both fields
/// so the values cannot leak through `{:?}` formatting or tracing output.
#[derive(Deserialize, Zeroize, ZeroizeOnDrop)]
#[serde(rename_all = "camelCase")]
pub struct ResetPasswordRequest {
    /// Plaintext reset token from the emailed link (`token` on the wire).
    pub token: String,
    /// Replacement password (`newPassword` on the wire).
    pub new_password: String,
}

impl std::fmt::Debug for ResetPasswordRequest {
    /// Prints the struct shape only; field values are always `[REDACTED]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResetPasswordRequest")
            .field("token", &"[REDACTED]")
            .field("new_password", &"[REDACTED]")
            .finish()
    }
}
/// Starts the password-reset flow for the address in `req.email`.
///
/// Every non-error outcome returns `200 OK` with the same generic message,
/// whether or not a matching account exists. For an existing account that has a
/// password, this handler:
/// - deletes earlier reset tokens;
/// - stores the hash of a freshly generated token;
/// - queues a reset email to the account's stored address;
/// - records an audit event (audit failures are logged as warnings, not returned).
///
/// # Errors
///
/// - `AppError::NotFound` if email authentication is disabled.
/// - `AppError::TooManyRequests` / `AppError::RateLimited` if the per-email
///   throttle is locked.
/// - Repository or email-queue failures on the existing-account path.
///   NOTE(review): these surface only for real accounts, so they differ from
///   the generic response — consider whether that matters for enumeration.
pub async fn forgot_password<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(req): Json<ForgotPasswordRequest>,
) -> Result<(StatusCode, Json<MessageResponse>), AppError> {
    if !state.config.email.enabled {
        return Err(AppError::NotFound("Email auth disabled".into()));
    }

    // One body for every successful outcome, so the response does not reveal
    // whether the account exists.
    let response = (
        StatusCode::OK,
        Json(MessageResponse {
            message: "If an account exists, a reset email has been sent".to_string(),
        }),
    );

    // Every request (not only failures) is counted against a per-email key. The
    // key is trimmed and lowercased so trivial spelling variants of one address
    // share a counter.
    let throttle_key = format!("password_reset:{}", req.email.trim().to_lowercase());
    let throttle_status = state
        .login_attempt_repo
        .record_failed_attempt_atomic(None, &throttle_key, None, &state.login_attempt_config)
        .await?;
    if throttle_status.is_locked {
        if let Some(remaining) = throttle_status.lockout_remaining_secs {
            return Err(AppError::TooManyRequests(format!(
                "Too many password reset requests. Try again in {} seconds",
                remaining
            )));
        }
        return Err(AppError::RateLimited);
    }

    // Unknown email, or an account with no password set: return the generic
    // response after a random delay so these branches are not trivially faster
    // than the real path below.
    let user = match state.user_repo.find_by_email(&req.email).await? {
        Some(u) if u.password_hash.is_some() => u,
        _ => {
            add_timing_normalization_delay().await;
            return Ok(response);
        }
    };

    // Drop any earlier reset tokens for this user before issuing a new one.
    state
        .verification_repo
        .delete_for_user(user.id, TokenType::PasswordReset)
        .await?;

    // The plaintext token only goes into the email; only its hash is stored.
    let token = generate_verification_token();
    let token_hash = hash_verification_token(&token);
    state
        .verification_repo
        .create(
            user.id,
            &token_hash,
            TokenType::PasswordReset,
            default_expiry(TokenType::PasswordReset),
        )
        .await
        .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to create token: {}", e)))?;

    // Mail the address stored on the account, never the spelling supplied in
    // the request. Depending on how `find_by_email` compares addresses (case
    // folding, Unicode normalization), the submitted string may differ from the
    // account's real mailbox. Mailing the submitted string could hand a valid
    // reset link to someone else.
    state
        .comms_service
        .queue_password_reset_email(&user.email, user.name.as_deref(), &token, Some(user.id))
        .await?;

    // Best effort: a failed audit write must not fail the request.
    if let Err(e) = state
        .audit_service
        .log_password_event(
            AuditEventType::PasswordResetRequested,
            user.id,
            Some(&headers),
        )
        .await
    {
        tracing::warn!(error = %e, user_id = %user.id, "Failed to log password reset requested audit event");
    }

    Ok(response)
}
pub async fn reset_password<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
Json(req): Json<ResetPasswordRequest>,
) -> Result<(StatusCode, Json<MessageResponse>), AppError> {
if !state.config.email.enabled {
return Err(AppError::NotFound("Email auth disabled".into()));
}
state.password_service.validate(&req.new_password)?;
let token_hash = hash_verification_token(&req.token);
let token = state
.verification_repo
.consume_if_valid(&token_hash)
.await
.map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to consume token: {}", e)))?
.ok_or_else(|| AppError::Validation("Invalid or expired token".to_string()))?;
if token.token_type != TokenType::PasswordReset {
return Err(AppError::Validation("Invalid token type".to_string()));
}
let password_hash = state
.password_service
.hash(req.new_password.clone())
.await?;
state
.user_repo
.update_password(token.user_id, &password_hash)
.await?;
state
.session_repo
.revoke_all_for_user_with_reason(token.user_id, "password_reset")
.await?;
if let Err(e) = state
.audit_service
.log_password_event(
AuditEventType::PasswordResetCompleted,
token.user_id,
Some(&headers),
)
.await
{
tracing::warn!(error = %e, user_id = %token.user_id, "Failed to log password reset completed audit event");
}
Ok((
StatusCode::OK,
Json(MessageResponse {
message: "Password reset successfully".to_string(),
}),
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ForgotPasswordRequest` reads a plain `email` key.
    #[test]
    fn test_forgot_password_request_deserialize() {
        let parsed: ForgotPasswordRequest =
            serde_json::from_str(r#"{"email": "test@example.com"}"#).unwrap();
        assert_eq!(parsed.email, "test@example.com");
    }

    /// `ResetPasswordRequest` maps the camelCase `newPassword` key onto
    /// `new_password`.
    #[test]
    fn test_reset_password_request_deserialize() {
        let body = r#"{"token": "abc123", "newPassword": "NewPassword1!"}"#;
        let parsed = serde_json::from_str::<ResetPasswordRequest>(body).unwrap();
        assert_eq!(parsed.token, "abc123");
        assert_eq!(parsed.new_password, "NewPassword1!");
    }
}