use time::{Duration, OffsetDateTime};
use crate::api::plugin_pipeline::run_password_validators;
use crate::api::{request_base_url, ApiRequest};
use crate::context::AuthContext;
use crate::crypto::random::generate_random_string;
use crate::db::{Session, User};
use crate::error::RustAuthError;
use crate::options::{PasswordResetEmail, PasswordResetPayload};
use crate::outbound::dispatch_outbound;
use crate::plugin::PluginPasswordValidationRejection;
use crate::session::CreateSessionInput;
use crate::user::CreateCredentialAccountInput;
use crate::verification::CreateVerificationInput;
/// Lifetime in seconds (one day) of the replacement session that [`change_password`] creates
/// when the caller sets `dont_remember`.
const DONT_REMEMBER_SESSION_EXPIRES_IN: i64 = 24 * 60 * 60;
/// Request body for [`change_password`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::api) struct ChangePasswordInput {
    /// Plaintext password the caller currently has; checked against the stored hash.
    pub(in crate::api) current_password: String,
    /// Plaintext replacement password; length-checked and run through the plugin validators.
    pub(in crate::api) new_password: String,
    /// When `true`, every session of the user is deleted and a freshly created one is returned.
    pub(in crate::api) revoke_other_sessions: bool,
    /// Only consulted together with `revoke_other_sessions`: gives the fresh session a
    /// one-day lifetime instead of the configured session lifetime.
    pub(in crate::api) dont_remember: bool,
}
/// Request body for [`set_password`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::api) struct SetPasswordInput {
    /// Plaintext password to store for a user that does not have one yet.
    pub(in crate::api) new_password: String,
}
/// Request body for [`verify_password`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::api) struct VerifyPasswordInput {
    /// Plaintext password compared against the user's stored credential hash.
    pub(in crate::api) password: String,
}
/// Request body for [`request_password_reset`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::api) struct RequestPasswordResetInput {
    /// Address of the account to reset; an unknown address still yields success.
    pub(in crate::api) email: String,
    /// Value placed (encoded) in the reset link's `callbackURL` query parameter; `/` when absent.
    pub(in crate::api) redirect_to: Option<String>,
}
/// Request body for [`reset_password`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::api) struct ResetPasswordInput {
    /// Token from the reset link, without the `reset-password:` storage prefix.
    pub(in crate::api) token: String,
    /// Plaintext replacement password.
    pub(in crate::api) new_password: String,
}
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub(in crate::api) enum PasswordServiceError {
#[error("credential account not found")]
CredentialAccountNotFound,
#[error("invalid password")]
InvalidPassword,
#[error("invalid token")]
InvalidToken,
#[error("password already set")]
PasswordAlreadySet,
#[error("password is too long")]
PasswordTooLong,
#[error("password is too short")]
PasswordTooShort,
#[error("password validation rejected the request")]
PasswordValidation(PluginPasswordValidationRejection),
}
/// Replaces the signed-in user's password after checking the current one.
///
/// Order of operations:
/// 1. length check on `new_password`;
/// 2. verification of `current_password` against the stored credential hash;
/// 3. plugin validators (path `/change-password`);
/// 4. storing the new hash.
///
/// With `revoke_other_sessions`, every session of the user is then deleted and a fresh session
/// is created and returned as `Some`. Otherwise the result is `Ok(None)`.
///
/// # Errors
/// - `CredentialAccountNotFound` when there is no credential account or it holds no hash.
/// - `InvalidPassword` when `current_password` does not match.
/// - The length/validation variants when `new_password` is rejected.
/// - Any storage or hashing error.
pub(in crate::api) async fn change_password(
    context: &AuthContext,
    user: &User,
    input: ChangePasswordInput,
) -> Result<Option<Session>, PasswordServiceErrorOrRustAuth> {
    validate_password_length(context, &input.new_password)?;

    let users = context.users()?;
    let account = users
        .find_credential_account(&user.id)
        .await?
        .ok_or(PasswordServiceError::CredentialAccountNotFound)?;
    // A credential row without a hash is reported the same way as a missing row.
    let stored_hash = account
        .password
        .as_deref()
        .ok_or(PasswordServiceError::CredentialAccountNotFound)?;

    let current_matches = (context.password.verify)(stored_hash, &input.current_password)?;
    if !current_matches {
        return Err(PasswordServiceError::InvalidPassword.into());
    }

    run_password_validators(context, "/change-password", &input.new_password)
        .await
        .map_err(PasswordServiceError::PasswordValidation)?;

    let replacement_hash = (context.password.hash)(&input.new_password)?;
    users
        .update_credential_password(&user.id, &replacement_hash)
        .await?;

    if input.revoke_other_sessions {
        let sessions = context.sessions()?;
        // Deletes all of the user's sessions, including the caller's; a new one replaces it.
        sessions.delete_user_sessions(&user.id).await?;
        let lifetime_secs = if input.dont_remember {
            DONT_REMEMBER_SESSION_EXPIRES_IN
        } else {
            context.session_config.expires_in.whole_seconds()
        };
        let expires_at = OffsetDateTime::now_utc() + Duration::seconds(lifetime_secs);
        let fresh_session = sessions
            .create_session(CreateSessionInput::new(&user.id, expires_at))
            .await?;
        Ok(Some(fresh_session))
    } else {
        Ok(None)
    }
}
pub(in crate::api) async fn set_password(
context: &AuthContext,
user: &User,
input: SetPasswordInput,
) -> Result<(), PasswordServiceErrorOrRustAuth> {
validate_password_length(context, &input.new_password)?;
let users = context.users()?;
if users.find_credential_account(&user.id).await?.is_some() {
return Err(PasswordServiceError::PasswordAlreadySet.into());
}
let hash = (context.password.hash)(&input.new_password)?;
users
.create_credential_account(CreateCredentialAccountInput::new(&user.id, hash))
.await?;
Ok(())
}
/// Checks `input.password` against the user's stored credential hash.
///
/// A missing credential account, a missing hash, and a mismatch are all reported as
/// `InvalidPassword`, so callers cannot tell them apart.
///
/// # Errors
/// `InvalidPassword` as described above; any storage or verification error is propagated.
pub(in crate::api) async fn verify_password(
    context: &AuthContext,
    user: &User,
    input: VerifyPasswordInput,
) -> Result<(), PasswordServiceErrorOrRustAuth> {
    let users = context.users()?;
    let account = users.find_credential_account(&user.id).await?;
    let stored_hash = account
        .as_ref()
        .and_then(|credential| credential.password.as_deref());
    let verified = match stored_hash {
        Some(hash) => (context.password.verify)(hash, &input.password)?,
        None => false,
    };
    if verified {
        Ok(())
    } else {
        Err(PasswordServiceError::InvalidPassword.into())
    }
}
pub(in crate::api) async fn request_password_reset(
context: &AuthContext,
request: Option<&ApiRequest>,
input: RequestPasswordResetInput,
) -> Result<(), RustAuthError> {
let Some(user) = context.users()?.find_user_by_email(&input.email).await? else {
return Ok(());
};
let token = generate_random_string(24);
let expires_in = context
.options
.password
.reset_password_token_expires_in
.unwrap_or(time::Duration::hours(1));
context
.verifications()?
.create_verification(CreateVerificationInput::new(
format!("reset-password:{token}"),
user.id.clone(),
OffsetDateTime::now_utc() + expires_in,
))
.await?;
if let Some(sender) = &context.options.password.send_reset_password {
let url = password_reset_url(context, request, &token, input.redirect_to.as_deref());
let payload = PasswordResetEmail { user, url, token };
let send = sender.send_reset_password(payload, request);
dispatch_outbound(context, send);
}
Ok(())
}
/// Completes a password reset started by [`request_password_reset`].
///
/// Steps:
/// 1. Validate the new password's length.
/// 2. Look up the `reset-password:<token>` verification, rejecting it when missing or expired.
/// 3. Run the plugin validators (path `/reset-password`).
/// 4. Store the new hash, updating the existing credential account or creating one.
/// 5. Delete the verification.
///
/// When `revoke_sessions_on_password_reset` is set, the user's sessions are deleted *before*
/// the optional `on_password_reset` callback runs. By then the new password is committed and
/// the token consumed. An error returned by the callback must therefore not be able to skip
/// the revocation and leave pre-reset sessions valid.
///
/// # Errors
/// - `InvalidToken` for an unknown or expired token, or one whose user no longer exists.
/// - The length/validation variants when the password is rejected.
/// - Storage, hashing, or callback errors.
pub(in crate::api) async fn reset_password(
    context: &AuthContext,
    request: Option<&ApiRequest>,
    input: ResetPasswordInput,
) -> Result<(), PasswordServiceErrorOrRustAuth> {
    validate_password_length(context, &input.new_password)?;
    let identifier = format!("reset-password:{}", input.token);
    let verifications = context.verifications()?;
    let Some(verification) = verifications.find_verification(&identifier).await? else {
        return Err(PasswordServiceError::InvalidToken.into());
    };
    if verification.expires_at <= OffsetDateTime::now_utc() {
        return Err(PasswordServiceError::InvalidToken.into());
    }
    run_password_validators(context, "/reset-password", &input.new_password)
        .await
        .map_err(PasswordServiceError::PasswordValidation)?;
    let user_id = verification.value;
    let users = context.users()?;
    let Some(user) = users.find_user_by_id(&user_id).await? else {
        // The token points at a user that no longer exists: consume it so it cannot be retried.
        verifications.delete_verification(&identifier).await?;
        return Err(PasswordServiceError::InvalidToken.into());
    };
    let new_hash = (context.password.hash)(&input.new_password)?;
    // Update the existing credential account; create one if there was nothing to update.
    if users
        .update_credential_password(&user_id, &new_hash)
        .await?
        .is_none()
    {
        users
            .create_credential_account(CreateCredentialAccountInput::new(&user_id, new_hash))
            .await?;
    }
    verifications.delete_verification(&identifier).await?;
    // Security-relevant cleanup first: the password change is already committed.
    if context.options.password.revoke_sessions_on_password_reset {
        context.sessions()?.delete_user_sessions(&user.id).await?;
    }
    if let Some(callback) = &context.options.password.on_password_reset {
        callback.on_password_reset(PasswordResetPayload { user }, request)?;
    }
    Ok(())
}
/// Reports whether `token` names a stored password-reset verification that has not expired.
///
/// Only looks the verification up; unlike [`reset_password`], it does not consume it.
///
/// # Errors
/// Storage errors from the verification lookup.
pub(in crate::api) async fn reset_password_callback_token_is_valid(
    context: &AuthContext,
    token: &str,
) -> Result<bool, RustAuthError> {
    let verifications = context.verifications()?;
    let stored = verifications
        .find_verification(&format!("reset-password:{token}"))
        .await?;
    let still_valid = match stored {
        Some(entry) => entry.expires_at > OffsetDateTime::now_utc(),
        None => false,
    };
    Ok(still_valid)
}
#[derive(Debug, thiserror::Error)]
pub(in crate::api) enum PasswordServiceErrorOrRustAuth {
#[error(transparent)]
Service(#[from] PasswordServiceError),
#[error(transparent)]
RustAuth(#[from] RustAuthError),
}
/// Checks `password` against the configured length bounds.
///
/// The minimum is measured in characters (Unicode scalar values). A short password made of
/// multi-byte characters (e.g. four "ä" = eight bytes) can therefore no longer satisfy it by
/// byte count alone. The maximum remains a byte bound, which caps the amount of input passed
/// to the hasher. Character count never exceeds byte count, so this only rejects passwords
/// the byte-based minimum used to accept.
///
/// # Errors
/// - [`PasswordServiceError::PasswordTooShort`] when fewer than `min_password_length` characters.
/// - [`PasswordServiceError::PasswordTooLong`] when more than `max_password_length` bytes.
fn validate_password_length(
    context: &AuthContext,
    password: &str,
) -> Result<(), PasswordServiceError> {
    let config = &context.password.config;
    if password.chars().count() < config.min_password_length {
        return Err(PasswordServiceError::PasswordTooShort);
    }
    if password.len() > config.max_password_length {
        return Err(PasswordServiceError::PasswordTooLong);
    }
    Ok(())
}
/// Builds the link sent in the reset email:
/// `<base>/reset-password/<token>?callbackURL=<encoded redirect>`.
///
/// `redirect_to` defaults to `/` and is encoded with [`percent_encode`].
fn password_reset_url(
    context: &AuthContext,
    request: Option<&ApiRequest>,
    token: &str,
    redirect_to: Option<&str>,
) -> String {
    let base = request_base_url(context, request);
    let callback = percent_encode(redirect_to.unwrap_or("/"));
    format!("{base}/reset-password/{token}?callbackURL={callback}")
}
/// Encodes `value` for use as a query-parameter value.
///
/// Uses `application/x-www-form-urlencoded` serialization, so a space becomes `+` and
/// reserved characters such as `/` and `?` are `%XX`-escaped.
fn percent_encode(value: &str) -> String {
    let pieces = url::form_urlencoded::byte_serialize(value.as_bytes());
    pieces.collect::<String>()
}