use crate::auth::password::{hash_password, verify_password, PasswordError};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, Type};
use thiserror::Error;
use validator::Validate;
/// Errors produced while validating, storing, loading, or authenticating users.
///
/// `Display` text comes from the `#[error(...)]` attribute on each variant.
#[derive(Debug, Error)]
pub enum UserError {
/// The input failed the quick structural pre-check in `EmailAddress::parse`
/// (no `@` or no `.`); the payload is a human-readable reason.
#[error("Invalid email address: {0}")]
InvalidEmail(String),
/// The password was rejected by `validate_password_strength`; the payload names
/// the first rule that was not met.
#[error("Password does not meet requirements: {0}")]
WeakPassword(String),
/// The `validator` email rule rejected the input in `EmailAddress::parse`; the
/// payload carries the formatted validation errors.
#[error("Validation error: {0}")]
ValidationFailed(String),
/// A `PasswordError` from the password helpers; `#[from]` lets `?` convert it.
#[error("Password hashing failed: {0}")]
PasswordHashingFailed(#[from] PasswordError),
/// A `sqlx::Error` from a query; `#[from]` lets `?` convert it.
#[error("Database error: {0}")]
DatabaseError(#[from] sqlx::Error),
/// A lookup (`User::find_by_email`, `User::find_by_id`) matched no row.
#[error("User not found")]
NotFound,
/// Returned by `User::authenticate` when the login attempt is rejected. The message
/// does not say whether the email or the password was the problem.
#[error("Invalid email or password")]
InvalidCredentials,
}
/// An email address held as a `String` newtype.
///
/// Values built through `EmailAddress::parse` have passed validation and are stored
/// lower-cased. Equality and hashing are derived, so they compare the stored string.
/// The `transparent` attributes make serde and sqlx treat this type as its inner string.
// NOTE(review): the derived, transparent `Deserialize` looks like it reads the inner
// `String` directly without going through `parse`, so values deserialized as a bare
// `EmailAddress` may be neither validated nor lower-cased — confirm whether that is intended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Type)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct EmailAddress(String);
impl EmailAddress {
pub fn parse(email: impl Into<String>) -> Result<Self, UserError> {
#[derive(Validate)]
struct EmailValidator {
#[validate(email)]
email: String,
}
let email = email.into();
if !email.contains('@') || !email.contains('.') {
return Err(UserError::InvalidEmail(
"Email must contain @ and domain".to_string(),
));
}
let validator = EmailValidator {
email: email.clone(),
};
validator.validate().map_err(|e| {
UserError::ValidationFailed(format!("Invalid email format: {e}"))
})?;
Ok(Self(email.to_lowercase()))
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.0
}
#[must_use]
pub fn into_inner(self) -> String {
self.0
}
}
impl std::fmt::Display for EmailAddress {
    /// Writes the stored address verbatim; width and padding flags on the outer
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
impl std::str::FromStr for EmailAddress {
type Err = UserError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s)
}
}
/// A user account row, loadable from the database via `FromRow` and (de)serializable
/// with serde.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct User {
/// Primary key of the row.
pub id: i64,
/// Account email. Serialized as a plain string; deserialization goes through
/// `deserialize_email`, which re-validates the value with `EmailAddress::parse`.
#[serde(serialize_with = "serialize_email")]
#[serde(deserialize_with = "deserialize_email")]
pub email: EmailAddress,
/// Stored password hash. Omitted from serialized output (`skip_serializing`).
#[serde(skip_serializing)]
pub password_hash: String,
/// Role names attached to the account; `User::create` starts new accounts with `["user"]`.
pub roles: Vec<String>,
/// Permission names attached to the account; `User::create` starts with an empty list.
pub permissions: Vec<String>,
/// Whether the email has been verified; `User::create` inserts `false`.
pub email_verified: bool,
/// Creation timestamp as returned by the database.
pub created_at: DateTime<Utc>,
/// Last-update timestamp as returned by the database.
pub updated_at: DateTime<Utc>,
}
/// Serde `serialize_with` adapter: emits an [`EmailAddress`] as a plain string.
fn serialize_email<S: serde::Serializer>(
    email: &EmailAddress,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    let address: &str = email.as_str();
    serializer.serialize_str(address)
}
fn deserialize_email<'de, D>(deserializer: D) -> Result<EmailAddress, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
EmailAddress::parse(s).map_err(serde::de::Error::custom)
}
impl User {
    /// Checks `password` against this user's stored hash.
    ///
    /// # Errors
    ///
    /// Returns the [`PasswordError`] reported by the password helper.
    pub fn verify_password(&self, password: &str) -> Result<bool, PasswordError> {
        verify_password(password, &self.password_hash)
    }

    /// Inserts a new user and returns the stored row.
    ///
    /// The password is strength-checked and hashed before the insert. New accounts
    /// start with the single role `"user"`, no permissions, and an unverified email.
    ///
    /// # Errors
    ///
    /// * [`UserError::WeakPassword`] if the password fails the strength rules.
    /// * [`UserError::PasswordHashingFailed`] if hashing fails.
    /// * [`UserError::DatabaseError`] if the insert fails.
    #[cfg(feature = "postgres")]
    pub async fn create(
        data: CreateUser,
        pool: &sqlx::PgPool,
    ) -> Result<Self, UserError> {
        validate_password_strength(&data.password)?;
        let password_hash = hash_password(&data.password)?;

        // Defaults for a freshly created account.
        let initial_roles = vec!["user".to_string()];
        let initial_permissions = Vec::<String>::new();
        let email_verified = false;

        let user = sqlx::query_as::<_, Self>(
            r"
INSERT INTO users (email, password_hash, roles, permissions, email_verified)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, email, password_hash, roles, permissions, email_verified, created_at, updated_at
",
        )
        .bind(data.email.as_str())
        .bind(&password_hash)
        .bind(initial_roles)
        .bind(initial_permissions)
        .bind(email_verified)
        .fetch_one(pool)
        .await?;
        Ok(user)
    }

    /// Loads the user whose email equals `email`.
    ///
    /// # Errors
    ///
    /// * [`UserError::NotFound`] if no row matches.
    /// * [`UserError::DatabaseError`] if the query fails.
    #[cfg(feature = "postgres")]
    pub async fn find_by_email(
        email: &EmailAddress,
        pool: &sqlx::PgPool,
    ) -> Result<Self, UserError> {
        let row = sqlx::query_as::<_, Self>(
            r"
SELECT id, email, password_hash, roles, permissions, email_verified, created_at, updated_at
FROM users
WHERE email = $1
",
        )
        .bind(email.as_str())
        .fetch_optional(pool)
        .await?;
        row.ok_or(UserError::NotFound)
    }

    /// Loads the user with primary key `id`.
    ///
    /// # Errors
    ///
    /// * [`UserError::NotFound`] if no row matches.
    /// * [`UserError::DatabaseError`] if the query fails.
    #[cfg(feature = "postgres")]
    pub async fn find_by_id(id: i64, pool: &sqlx::PgPool) -> Result<Self, UserError> {
        let row = sqlx::query_as::<_, Self>(
            r"
SELECT id, email, password_hash, roles, permissions, email_verified, created_at, updated_at
FROM users
WHERE id = $1
",
        )
        .bind(id)
        .fetch_optional(pool)
        .await?;
        row.ok_or(UserError::NotFound)
    }

    /// Looks the user up by email and verifies `password` against the stored hash.
    ///
    /// # Errors
    ///
    /// * [`UserError::InvalidCredentials`] if the email is unknown, the password does
    ///   not match, or password verification itself reports an error.
    /// * [`UserError::DatabaseError`] if the lookup query fails. Infrastructure
    ///   failures are surfaced as such instead of being reported as bad credentials.
    #[cfg(feature = "postgres")]
    pub async fn authenticate(
        email: &EmailAddress,
        password: &str,
        pool: &sqlx::PgPool,
    ) -> Result<Self, UserError> {
        let user = match Self::find_by_email(email, pool).await {
            Ok(user) => user,
            // Unknown address: answer exactly as for a wrong password so the
            // response does not reveal which accounts exist.
            Err(UserError::NotFound) => return Err(UserError::InvalidCredentials),
            // Anything else (e.g. a database failure) is not a credential problem
            // and must not be presented to the caller as one.
            Err(other) => return Err(other),
        };

        // NOTE(review): an unknown email returns before any hash verification runs,
        // so response time may still differ between existing and missing accounts.
        match user.verify_password(password) {
            Ok(true) => Ok(user),
            // Verification errors stay masked as invalid credentials, as before.
            Ok(false) | Err(_) => Err(UserError::InvalidCredentials),
        }
    }
}
/// Input for `User::create`: an already-parsed email plus the plaintext password.
// NOTE(review): within this file `User::create` does not call `validate()` on this
// struct; it relies on `validate_password_strength` instead — confirm whether callers
// are expected to run the derived validation themselves.
#[derive(Debug, Clone, Validate)]
pub struct CreateUser {
/// Email for the new account, validated when the `EmailAddress` was constructed via `parse`.
pub email: EmailAddress,
/// Plaintext password; the declared rule requires a length of at least 8.
#[validate(length(min = 8, message = "Password must be at least 8 characters"))]
pub password: String,
}
/// Checks that `password` satisfies the minimum strength rules.
///
/// Rules, checked in this order (the first failure is reported):
/// 1. at least 8 characters,
/// 2. at least one uppercase letter,
/// 3. at least one lowercase letter,
/// 4. at least one ASCII digit.
///
/// # Errors
///
/// Returns [`UserError::WeakPassword`] carrying the message for the first unmet rule.
fn validate_password_strength(password: &str) -> Result<(), UserError> {
    const MIN_CHARS: usize = 8;

    let weak = |reason: &str| UserError::WeakPassword(reason.to_string());

    // One pass gathers everything the rules need.
    let mut char_count = 0usize;
    let mut has_uppercase = false;
    let mut has_lowercase = false;
    let mut has_digit = false;
    for c in password.chars() {
        char_count += 1;
        has_uppercase |= c.is_uppercase();
        has_lowercase |= c.is_lowercase();
        has_digit |= c.is_ascii_digit();
    }

    // Length is measured in characters, not UTF-8 bytes: `str::len` would let a
    // short password made of multi-byte letters slip past the 8-character rule
    // that the error message promises.
    if char_count < MIN_CHARS {
        return Err(weak("Password must be at least 8 characters"));
    }
    if !has_uppercase {
        return Err(weak("Password must contain at least one uppercase letter"));
    }
    if !has_lowercase {
        return Err(weak("Password must contain at least one lowercase letter"));
    }
    if !has_digit {
        return Err(weak("Password must contain at least one digit"));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: unverified user #1 with the single `user` role and the given hash.
    fn sample_user(password_hash: String) -> User {
        User {
            id: 1,
            email: EmailAddress::parse("test@example.com").unwrap(),
            password_hash,
            roles: vec!["user".to_string()],
            permissions: vec![],
            email_verified: false,
            created_at: Utc::now(),
            updated_at: Utc::now(),
        }
    }

    /// Well-formed addresses parse; structurally broken ones are rejected.
    #[test]
    fn test_email_address_parsing() {
        let accepted = [
            "user@example.com",
            "user.name@example.co.uk",
            "user+tag@example.com",
        ];
        for &candidate in &accepted {
            assert!(EmailAddress::parse(candidate).is_ok());
        }

        let rejected = ["not-an-email", "@example.com", "user@", "user"];
        for &candidate in &rejected {
            assert!(EmailAddress::parse(candidate).is_err());
        }
    }

    /// Parsing lower-cases the address, so differently-cased inputs compare equal.
    #[test]
    fn test_email_normalization() {
        let mixed_case = EmailAddress::parse("User@Example.COM").unwrap();
        let lower_case = EmailAddress::parse("user@example.com").unwrap();
        assert_eq!(mixed_case, lower_case);
        assert_eq!(mixed_case.as_str(), "user@example.com");
    }

    /// Strong passwords pass; each missing ingredient yields `WeakPassword`.
    #[test]
    fn test_password_strength_validation() {
        for &strong in &["SecurePass123", "MyP@ssw0rd"] {
            assert!(validate_password_strength(strong).is_ok());
        }

        assert!(validate_password_strength("Pass1").is_err());

        for &weak in &["password123", "PASSWORD123", "PasswordOnly"] {
            assert!(matches!(
                validate_password_strength(weak),
                Err(UserError::WeakPassword(_))
            ));
        }
    }

    /// A user verifies against the password it was hashed from and nothing else.
    #[test]
    fn test_user_password_verification() {
        let password = "TestPassword123";
        let user = sample_user(hash_password(password).expect("Failed to hash password"));

        let right_accepted = user.verify_password(password).expect("Verification failed");
        assert!(right_accepted);

        let wrong_accepted = user
            .verify_password("wrong-password")
            .expect("Verification failed");
        assert!(!wrong_accepted);
    }

    /// An `EmailAddress` round-trips through JSON as a bare string.
    #[test]
    fn test_email_serialization() {
        let email = EmailAddress::parse("test@example.com").unwrap();

        let json = serde_json::to_string(&email).expect("Failed to serialize");
        assert_eq!(json, r#""test@example.com""#);

        let round_tripped =
            serde_json::from_str::<EmailAddress>(&json).expect("Failed to deserialize");
        assert_eq!(round_tripped, email);
    }

    /// Serialized users expose the email but never the password hash field.
    #[test]
    fn test_user_serialization_skips_password() {
        let user = sample_user("hash".to_string());
        let json = serde_json::to_string(&user).expect("Failed to serialize");
        assert!(!json.contains("password_hash"));
        assert!(json.contains("test@example.com"));
    }
}