use std::sync::Arc;
use argon2::password_hash::SaltString;
use argon2::{Argon2, PasswordHasher};
use async_trait::async_trait;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
#[cfg(test)]
use mockall::automock;
use pbkdf2::password_hash::rand_core::OsRng;
use tracing::{debug, info};
use super::authentication::DbUserAuthenticationRepository;
use super::authorization::{self, ACTION};
use crate::config::{Configuration, PasswordConstraints};
use crate::databases::database::{Database, Error};
use crate::errors::ServiceError;
use crate::mailer;
use crate::mailer::VerifyClaims;
use crate::models::user::{UserCompact, UserId, UserProfile, Username};
use crate::services::authentication::verify_password;
use crate::utils::validation::validate_email_address;
use crate::web::api::server::v1::contexts::user::forms::{ChangePasswordForm, RegistrationForm};
/// Value stored as the e-mail address when a user registers without one
/// (an empty string).
fn no_email() -> String {
    String::default()
}
pub struct RegistrationService {
configuration: Arc<Configuration>,
mailer: Arc<mailer::Service>,
user_repository: Arc<Box<dyn Repository>>,
user_profile_repository: Arc<DbUserProfileRepository>,
}
impl RegistrationService {
#[must_use]
pub fn new(
configuration: Arc<Configuration>,
mailer: Arc<mailer::Service>,
user_repository: Arc<Box<dyn Repository>>,
user_profile_repository: Arc<DbUserProfileRepository>,
) -> Self {
Self {
configuration,
mailer,
user_repository,
user_profile_repository,
}
}
pub async fn register_user(&self, registration_form: &RegistrationForm, api_base_url: &str) -> Result<UserId, ServiceError> {
info!("registering user: {}", registration_form.username);
let settings = self.configuration.settings.read().await;
match &settings.registration {
Some(registration) => {
let Ok(username) = registration_form.username.parse::<Username>() else {
return Err(ServiceError::UsernameInvalid);
};
let opt_email = match ®istration.email {
Some(email) => {
if email.required && registration_form.email.is_none() {
return Err(ServiceError::EmailMissing);
}
match ®istration_form.email {
Some(email) => {
if email.trim() == String::new() {
None
} else {
Some(email.clone())
}
}
None => None,
}
}
None => None,
};
if let Some(email) = &opt_email {
if !validate_email_address(email) {
return Err(ServiceError::EmailInvalid);
}
}
let password_constraints = PasswordConstraints {
min_password_length: settings.auth.password_constraints.min_password_length,
max_password_length: settings.auth.password_constraints.max_password_length,
};
validate_password_constraints(
®istration_form.password,
®istration_form.confirm_password,
&password_constraints,
)?;
let password_hash = hash_password(®istration_form.password)?;
let user_id = self
.user_repository
.add(
&username.to_string(),
&opt_email.clone().unwrap_or(no_email()),
&password_hash,
)
.await?;
if user_id == 1 {
drop(self.user_repository.grant_admin_role(&user_id).await);
}
if let Some(email) = ®istration.email {
if email.verification_required {
if let Some(email) = opt_email {
let mail_res = self
.mailer
.send_verification_mail(&email, ®istration_form.username, user_id, api_base_url)
.await;
if mail_res.is_err() {
drop(self.user_repository.delete(&user_id).await);
return Err(ServiceError::FailedToSendVerificationEmail);
}
}
}
}
Ok(user_id)
}
None => Err(ServiceError::ClosedForRegistration),
}
}
pub async fn verify_email(&self, token: &str) -> Result<bool, ServiceError> {
let settings = self.configuration.settings.read().await;
let token_data = match decode::<VerifyClaims>(
token,
&DecodingKey::from_secret(settings.auth.user_claim_token_pepper.as_bytes()),
&Validation::new(Algorithm::HS256),
) {
Ok(token_data) => {
if !token_data.claims.iss.eq("email-verification") {
return Ok(false);
}
token_data.claims
}
Err(_) => return Ok(false),
};
drop(settings);
let user_id = token_data.sub;
if self.user_profile_repository.verify_email(&user_id).await.is_err() {
return Err(ServiceError::DatabaseError);
};
Ok(true)
}
}
/// Operations a logged-in user performs on their own profile.
pub struct ProfileService {
    configuration: Arc<Configuration>,
    user_authentication_repository: Arc<DbUserAuthenticationRepository>,
    authorization_service: Arc<authorization::Service>,
}

impl ProfileService {
    /// Creates a profile service. `user_repository` is stored as the
    /// authentication repository used to read and update password hashes.
    #[must_use]
    pub fn new(
        configuration: Arc<Configuration>,
        user_repository: Arc<DbUserAuthenticationRepository>,
        authorization_service: Arc<authorization::Service>,
    ) -> Self {
        Self {
            configuration,
            user_authentication_repository: user_repository,
            authorization_service,
        }
    }

    /// Changes the password of the user identified by `maybe_user_id`.
    ///
    /// The current password must verify against the stored credentials. The
    /// new password must match its confirmation and satisfy the configured
    /// length constraints.
    ///
    /// # Errors
    ///
    /// - `ServiceError::UnauthorizedActionForGuests` if `maybe_user_id` is `None`.
    /// - Any error from authorization, loading the stored credentials,
    ///   `verify_password`, `validate_password_constraints`, hashing, or
    ///   persisting the new hash.
    pub async fn change_password(
        &self,
        maybe_user_id: Option<UserId>,
        change_password_form: &ChangePasswordForm,
    ) -> Result<(), ServiceError> {
        let Some(user_id) = maybe_user_id else {
            return Err(ServiceError::UnauthorizedActionForGuests);
        };

        self.authorization_service
            .authorize(ACTION::ChangePassword, maybe_user_id)
            .await?;

        info!("changing user password for user ID: {}", user_id);

        // Copy the constraints out so the settings lock is not held across the
        // database awaits below.
        let password_constraints = {
            let settings = self.configuration.settings.read().await;
            PasswordConstraints {
                min_password_length: settings.auth.password_constraints.min_password_length,
                max_password_length: settings.auth.password_constraints.max_password_length,
            }
        };

        let user_authentication = self
            .user_authentication_repository
            .get_user_authentication_from_id(&user_id)
            .await?;

        verify_password(change_password_form.current_password.as_bytes(), &user_authentication)?;

        validate_password_constraints(
            &change_password_form.password,
            &change_password_form.confirm_password,
            &password_constraints,
        )?;

        let password_hash = hash_password(&change_password_form.password)?;

        self.user_authentication_repository
            .change_password(user_id, &password_hash)
            .await?;

        Ok(())
    }
}
/// Bans users by username on behalf of an authorized user.
pub struct BanService {
    user_profile_repository: Arc<DbUserProfileRepository>,
    banned_user_list: Arc<DbBannedUserList>,
    authorization_service: Arc<authorization::Service>,
}

impl BanService {
    /// Creates a ban service from its collaborators.
    #[must_use]
    pub fn new(
        user_profile_repository: Arc<DbUserProfileRepository>,
        banned_user_list: Arc<DbBannedUserList>,
        authorization_service: Arc<authorization::Service>,
    ) -> Self {
        Self {
            user_profile_repository,
            banned_user_list,
            authorization_service,
        }
    }

    /// Bans the user named `username_to_be_banned`, acting as `maybe_user_id`.
    ///
    /// # Errors
    ///
    /// - `ServiceError::UnauthorizedActionForGuests` if `maybe_user_id` is `None`.
    /// - Any error from the `ACTION::BanUser` authorization check, from
    ///   looking up the target profile, or from recording the ban.
    pub async fn ban_user(&self, username_to_be_banned: &str, maybe_user_id: Option<UserId>) -> Result<(), ServiceError> {
        let acting_user_id = match maybe_user_id {
            Some(id) => id,
            None => return Err(ServiceError::UnauthorizedActionForGuests),
        };

        self.authorization_service
            .authorize(ACTION::BanUser, maybe_user_id)
            .await?;

        debug!("user with ID {} banning username: {username_to_be_banned}", acting_user_id);

        let target_profile = self
            .user_profile_repository
            .get_user_profile_from_username(username_to_be_banned)
            .await?;

        self.banned_user_list.add(&target_profile.user_id).await?;

        Ok(())
    }
}
/// Persistence operations on user accounts needed by the user services.
///
/// A mock implementation is generated via `mockall` in test builds.
#[cfg_attr(test, automock)]
#[async_trait]
pub trait Repository: Sync + Send {
    /// Inserts a new user and returns its ID.
    async fn add(&self, username: &str, email: &str, password_hash: &str) -> Result<UserId, Error>;

    /// Fetches the compact representation of the user with `user_id`.
    async fn get_compact(&self, user_id: &UserId) -> Result<UserCompact, ServiceError>;

    /// Grants the administrator role to the user with `user_id`.
    async fn grant_admin_role(&self, user_id: &UserId) -> Result<(), Error>;

    /// Deletes the user with `user_id`.
    async fn delete(&self, user_id: &UserId) -> Result<(), Error>;
}
/// [`Repository`] implementation backed by the shared [`Database`] handle.
pub struct DbUserRepository {
    database: Arc<Box<dyn Database>>,
}

impl DbUserRepository {
    /// Wraps the shared database handle.
    #[must_use]
    pub fn new(database: Arc<Box<dyn Database>>) -> Self {
        Self { database }
    }
}

#[async_trait]
impl Repository for DbUserRepository {
    /// Loads the compact user record. Every database error, whatever its
    /// cause, is reported as `ServiceError::UserNotFound`.
    async fn get_compact(&self, user_id: &UserId) -> Result<UserCompact, ServiceError> {
        let id = *user_id;
        match self.database.get_user_compact_from_id(id).await {
            Ok(user) => Ok(user),
            Err(_) => Err(ServiceError::UserNotFound),
        }
    }

    async fn grant_admin_role(&self, user_id: &UserId) -> Result<(), Error> {
        let id = *user_id;
        self.database.grant_admin_role(id).await
    }

    async fn delete(&self, user_id: &UserId) -> Result<(), Error> {
        let id = *user_id;
        self.database.delete_user(id).await
    }

    async fn add(&self, username: &str, email: &str, password_hash: &str) -> Result<UserId, Error> {
        let db = &self.database;
        db.insert_user_and_get_id(username, email, password_hash).await
    }
}
/// Thin wrapper over the database for user-profile queries and updates.
pub struct DbUserProfileRepository {
    database: Arc<Box<dyn Database>>,
}

impl DbUserProfileRepository {
    /// Wraps the shared database handle.
    #[must_use]
    pub fn new(database: Arc<Box<dyn Database>>) -> Self {
        Self { database }
    }

    /// Marks the e-mail of the user with `user_id` as verified.
    ///
    /// # Errors
    ///
    /// Returns the database error unchanged.
    pub async fn verify_email(&self, user_id: &UserId) -> Result<(), Error> {
        let id = *user_id;
        self.database.verify_email(id).await
    }

    /// Looks up a user profile by username.
    ///
    /// # Errors
    ///
    /// Returns the database error unchanged.
    pub async fn get_user_profile_from_username(&self, username: &str) -> Result<UserProfile, Error> {
        let db = &self.database;
        db.get_user_profile_from_username(username).await
    }
}
/// Database-backed list of banned users.
pub struct DbBannedUserList {
    database: Arc<Box<dyn Database>>,
}

impl DbBannedUserList {
    /// Wraps the shared database handle.
    #[must_use]
    pub fn new(database: Arc<Box<dyn Database>>) -> Self {
        Self { database }
    }

    /// Bans the user with `user_id` with the reason "no reason" and an
    /// expiry of 9999-01-01 00:00:00.
    ///
    /// # Errors
    ///
    /// Returns the database error unchanged.
    ///
    /// # Panics
    ///
    /// Only if the constant expiry timestamp failed to parse, which would be a bug.
    pub async fn add(&self, user_id: &UserId) -> Result<(), Error> {
        const EXPIRY_TIMESTAMP: &str = "9999-01-01 00:00:00";
        const EXPIRY_FORMAT: &str = "%Y-%m-%d %H:%M:%S";

        let date_expiry = chrono::NaiveDateTime::parse_from_str(EXPIRY_TIMESTAMP, EXPIRY_FORMAT)
            .expect("Could not parse date from 9999-01-01 00:00:00.");
        let reason = String::from("no reason");

        self.database.ban_user(*user_id, &reason, date_expiry).await
    }
}
/// Checks a new password against its confirmation and the configured length limits.
///
/// Length is measured in Unicode scalar values (`char`s), not UTF-8 bytes.
/// A non-ASCII password is therefore not counted as longer than the number
/// of characters the user typed.
///
/// # Errors
///
/// - `ServiceError::PasswordsDontMatch` if `password != confirm_password`.
/// - `ServiceError::PasswordTooShort` if it has fewer than `min_password_length` characters.
/// - `ServiceError::PasswordTooLong` if it has more than `max_password_length` characters.
fn validate_password_constraints(
    password: &str,
    confirm_password: &str,
    password_rules: &PasswordConstraints,
) -> Result<(), ServiceError> {
    if password != confirm_password {
        return Err(ServiceError::PasswordsDontMatch);
    }

    // `str::len` counts bytes; count characters so the limits mean what users expect.
    let password_length = password.chars().count();

    if password_length < password_rules.min_password_length {
        return Err(ServiceError::PasswordTooShort);
    }

    if password_length > password_rules.max_password_length {
        return Err(ServiceError::PasswordTooLong);
    }

    Ok(())
}
/// Hashes `password` with Argon2 (default parameters) and a freshly generated
/// random salt, returning the encoded hash string.
///
/// # Errors
///
/// Propagates any Argon2 hashing error, converted into `ServiceError`.
fn hash_password(password: &str) -> Result<String, ServiceError> {
    let hasher = Argon2::default();
    let fresh_salt = SaltString::generate(&mut OsRng);
    Ok(hasher.hash_password(password.as_bytes(), &fresh_salt)?.to_string())
}