use crate::utils::error::gateway_error::{GatewayError, Result};
use sea_orm::*;
use tracing::debug;
use super::super::entities::{self, password_reset_token, user};
use super::types::SeaOrmDatabase;
impl SeaOrmDatabase {
pub async fn store_password_reset_token(
&self,
user_id: uuid::Uuid,
token: &str,
expires_at: chrono::DateTime<chrono::Utc>,
) -> Result<()> {
debug!("Storing password reset token for user: {}", user_id);
let txn = self.db.begin().await.map_err(GatewayError::from)?;
entities::PasswordResetToken::delete_many()
.filter(password_reset_token::Column::UserId.eq(user_id))
.exec(&txn)
.await
.map_err(GatewayError::from)?;
let active_model = password_reset_token::ActiveModel {
id: NotSet,
user_id: Set(user_id),
token: Set(token.to_string()),
expires_at: Set(expires_at.into()),
created_at: Set(chrono::Utc::now().into()),
used_at: Set(None),
};
entities::PasswordResetToken::insert(active_model)
.exec(&txn)
.await
.map_err(GatewayError::from)?;
txn.commit().await.map_err(GatewayError::from)?;
Ok(())
}
pub async fn verify_password_reset_token(&self, token: &str) -> Result<Option<uuid::Uuid>> {
debug!("Verifying password reset token");
let txn = self.db.begin().await.map_err(GatewayError::from)?;
let token_model = entities::PasswordResetToken::find()
.filter(password_reset_token::Column::Token.eq(token))
.filter(password_reset_token::Column::UsedAt.is_null())
.filter(password_reset_token::Column::ExpiresAt.gt(chrono::Utc::now()))
.one(&txn)
.await
.map_err(GatewayError::from)?;
if let Some(token_model) = token_model {
let user_id = token_model.user_id;
let mut active_model: password_reset_token::ActiveModel = token_model.into();
active_model.used_at = Set(Some(chrono::Utc::now().into()));
active_model
.update(&txn)
.await
.map_err(GatewayError::from)?;
txn.commit().await.map_err(GatewayError::from)?;
Ok(Some(user_id))
} else {
txn.rollback().await.map_err(GatewayError::from)?;
Ok(None)
}
}
pub async fn invalidate_password_reset_token(&self, token: &str) -> Result<()> {
debug!("Invalidating password reset token");
let token_model = entities::PasswordResetToken::find()
.filter(password_reset_token::Column::Token.eq(token))
.one(&self.db)
.await
.map_err(GatewayError::from)?;
if let Some(token_model) = token_model {
let mut active_model: password_reset_token::ActiveModel = token_model.into();
active_model.used_at = Set(Some(chrono::Utc::now().into()));
active_model
.update(&self.db)
.await
.map_err(GatewayError::from)?;
}
Ok(())
}
pub async fn cleanup_expired_tokens(&self) -> Result<u64> {
debug!("Cleaning up expired password reset tokens");
let result = entities::PasswordResetToken::delete_many()
.filter(password_reset_token::Column::ExpiresAt.lt(chrono::Utc::now()))
.exec(&self.db)
.await
.map_err(GatewayError::from)?;
Ok(result.rows_affected)
}
pub async fn reset_password_with_token(
&self,
token: &str,
password_hash: &str,
) -> Result<bool> {
debug!("Atomically consuming password reset token and updating password");
let txn = self.db.begin().await.map_err(GatewayError::from)?;
let token_model = entities::PasswordResetToken::find()
.filter(password_reset_token::Column::Token.eq(token))
.filter(password_reset_token::Column::UsedAt.is_null())
.filter(password_reset_token::Column::ExpiresAt.gt(chrono::Utc::now()))
.one(&txn)
.await
.map_err(GatewayError::from)?;
let token_model = match token_model {
Some(m) => m,
None => {
txn.rollback().await.map_err(GatewayError::from)?;
return Ok(false);
}
};
let user_id = token_model.user_id;
let mut token_active: password_reset_token::ActiveModel = token_model.into();
token_active.used_at = Set(Some(chrono::Utc::now().into()));
token_active
.update(&txn)
.await
.map_err(GatewayError::from)?;
let user_model = entities::User::find_by_id(user_id)
.one(&txn)
.await
.map_err(GatewayError::from)?
.ok_or_else(|| GatewayError::NotFound("User not found".to_string()))?;
let mut user_active: user::ActiveModel = user_model.into();
user_active.password_hash = Set(password_hash.to_string());
user_active.updated_at = Set(chrono::Utc::now().into());
user_active.update(&txn).await.map_err(GatewayError::from)?;
txn.commit().await.map_err(GatewayError::from)?;
Ok(true)
}
}