use std::sync::Arc;
use http::StatusCode;
use rustauth_core::api::{parse_request_body, ApiRequest, ApiResponse};
use rustauth_core::context::AuthContext;
use rustauth_core::error::RustAuthError;
use rustauth_core::options::PasswordResetPayload;
use rustauth_core::user::CreateCredentialAccountInput;
use serde::Deserialize;
use super::helpers::{resolve_otp, send_email, validated_email, verify_otp};
use super::otp;
use super::response;
use super::types::{EmailOtpOptions, EmailOtpType};
/// Request body that carries a single email address.
///
/// Field names are mapped with serde's `camelCase` renaming rule.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct EmailBody {
    /// Address exactly as submitted by the client (not yet validated).
    email: String,
}
/// Request body for completing an email-OTP password reset.
///
/// Field names are mapped with serde's `camelCase` renaming rule.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct ResetPasswordBody {
    /// Address exactly as submitted by the client (not yet validated).
    email: String,
    /// One-time code the client received for this reset.
    otp: String,
    /// New plaintext password; hashed before it is stored.
    password: String,
}
pub(super) async fn request_password_reset(
context: AuthContext,
request: ApiRequest,
options: Arc<EmailOtpOptions>,
) -> Result<ApiResponse, RustAuthError> {
let body: EmailBody = parse_request_body(&request)?;
let email = match validated_email(&body.email)? {
Ok(email) => email,
Err(response) => return Ok(response),
};
let identifier = otp::identifier(EmailOtpType::ForgetPassword, &email);
let otp = resolve_otp(
&context,
&options,
&context.secret_config,
&email,
EmailOtpType::ForgetPassword,
&identifier,
)
.await?;
if context.users()?.find_user_by_email(&email).await?.is_none() {
context
.verifications()?
.delete_verification(&identifier)
.await?;
return response::success();
}
if let Some(response) = send_email(
&context,
&options,
&email,
otp,
EmailOtpType::ForgetPassword,
Some(&request),
)? {
return Ok(response);
}
response::success()
}
pub(super) async fn reset_password(
context: AuthContext,
request: ApiRequest,
options: Arc<EmailOtpOptions>,
) -> Result<ApiResponse, RustAuthError> {
let body: ResetPasswordBody = parse_request_body(&request)?;
let email = match validated_email(&body.email)? {
Ok(email) => email,
Err(response) => return Ok(response),
};
if body.password.len() < context.password.config.min_password_length {
return response::error(
StatusCode::BAD_REQUEST,
"PASSWORD_TOO_SHORT",
"Password too short",
);
}
if body.password.len() > context.password.config.max_password_length {
return response::error(
StatusCode::BAD_REQUEST,
"PASSWORD_TOO_LONG",
"Password too long",
);
}
if let Some(response) = verify_otp(
&context,
&options,
&context.secret_config,
&otp::identifier(EmailOtpType::ForgetPassword, &email),
&body.otp,
true,
)
.await?
{
return Ok(response);
}
let users = context.users()?;
let Some(user) = users.find_user_by_email(&email).await? else {
return response::error(StatusCode::BAD_REQUEST, "USER_NOT_FOUND", "User not found");
};
let password_hash = (context.password.hash)(&body.password)?;
if users.find_credential_account(&user.id).await?.is_some() {
users
.update_credential_password(&user.id, &password_hash)
.await?;
} else {
users
.create_credential_account(CreateCredentialAccountInput::new(&user.id, password_hash))
.await?;
}
let user = if !user.email_verified {
users
.update_user_email_verified(&user.id, true)
.await?
.unwrap_or(user)
} else {
user
};
if let Some(callback) = &context.options.password.on_password_reset {
callback.on_password_reset(PasswordResetPayload { user: user.clone() }, Some(&request))?;
}
if context.options.password.revoke_sessions_on_password_reset {
context.sessions()?.delete_user_sessions(&user.id).await?;
}
response::success()
}