use async_trait::async_trait;
use chrono::{Duration, Utc};
use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
use uuid::Uuid;
use crate::{
error::{AuthError, AuthResult},
models::{AuthUser, Session},
traits::AuthProvider,
AuthConfig,
};
/// Authentication provider that runs raw SQL statements against a
/// `DatabaseConnection`.
///
/// The table names are stored as plain strings and spliced into the SQL text
/// with `format!` by the query methods, so they must come from trusted
/// configuration, never from user input.
pub struct DatabaseAuthProvider {
/// Connection on which every query of this provider is executed.
db: DatabaseConnection,
/// Name of the users table; the queries read `id`, `email`,
/// `password_hash`, `role` and `is_active` from it.
users_table: String,
/// Name of the sessions table; the queries use its `token`, `user_id`,
/// `expires_at` and `created_at` columns.
sessions_table: String,
/// Auth settings; `session_expiration_hours` is read when a session is created.
config: AuthConfig,
}
impl DatabaseAuthProvider {
pub fn new(
db: DatabaseConnection,
users_table: impl Into<String>,
sessions_table: impl Into<String>,
config: AuthConfig,
) -> Self {
Self {
db,
users_table: users_table.into(),
sessions_table: sessions_table.into(),
config,
}
}
}
#[async_trait]
impl AuthProvider for DatabaseAuthProvider {
async fn authenticate(&self, email: &str, password: &str) -> AuthResult<AuthUser> {
let sql = format!(
"SELECT id, email, password_hash, role FROM {} WHERE email = $1 AND is_active = true",
self.users_table
);
let stmt = Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
vec![email.into()],
);
let result = self
.db
.query_one(stmt)
.await
.map_err(|e| AuthError::Database(e.to_string()))?;
let user = result.ok_or(AuthError::InvalidCredentials)?;
let id: i64 = user
.try_get("", "id")
.map_err(|e| AuthError::Database(e.to_string()))?;
let stored_hash: String = user
.try_get("", "password_hash")
.map_err(|e| AuthError::Database(e.to_string()))?;
let role: String = user
.try_get("", "role")
.map_err(|e| AuthError::Database(e.to_string()))?;
let user_email: String = user
.try_get("", "email")
.map_err(|e| AuthError::Database(e.to_string()))?;
#[cfg(feature = "bcrypt")]
{
let is_valid = bcrypt::verify(password, &stored_hash)?;
if !is_valid {
return Err(AuthError::InvalidCredentials);
}
}
#[cfg(not(feature = "bcrypt"))]
{
if password != stored_hash {
return Err(AuthError::InvalidCredentials);
}
}
Ok(AuthUser::new(id, user_email, role))
}
async fn create_session(&self, user_id: i64) -> AuthResult<Session> {
let token = Uuid::new_v4().to_string();
let expires_at = Utc::now() + Duration::hours(self.config.session_expiration_hours);
let sql = format!(
"INSERT INTO {} (token, user_id, expires_at, created_at) VALUES ($1, $2, $3, $4)",
self.sessions_table
);
let stmt = Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
vec![
token.clone().into(),
user_id.into(),
expires_at.naive_utc().into(),
Utc::now().naive_utc().into(),
],
);
self.db
.execute(stmt)
.await
.map_err(|e| AuthError::Database(e.to_string()))?;
Ok(Session::new(token, user_id, expires_at.naive_utc()))
}
async fn validate_session(&self, token: &str) -> AuthResult<Option<AuthUser>> {
let sql = format!(
"SELECT s.user_id, s.expires_at, u.email, u.role
FROM {} s
JOIN {} u ON s.user_id = u.id
WHERE s.token = $1 AND s.expires_at > NOW()",
self.sessions_table, self.users_table
);
let stmt = Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
vec![token.into()],
);
let result = self
.db
.query_one(stmt)
.await
.map_err(|e| AuthError::Database(e.to_string()))?;
let Some(row) = result else {
return Ok(None);
};
let user_id: i64 = row
.try_get("", "user_id")
.map_err(|e| AuthError::Database(e.to_string()))?;
let email: String = row
.try_get("", "email")
.map_err(|e| AuthError::Database(e.to_string()))?;
let role: String = row
.try_get("", "role")
.map_err(|e| AuthError::Database(e.to_string()))?;
Ok(Some(AuthUser::new(user_id, email, role)))
}
async fn destroy_session(&self, token: &str) -> AuthResult<()> {
let sql = format!("DELETE FROM {} WHERE token = $1", self.sessions_table);
let stmt = Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
vec![token.into()],
);
self.db
.execute(stmt)
.await
.map_err(|e| AuthError::Database(e.to_string()))?;
Ok(())
}
async fn get_user(&self, user_id: i64) -> AuthResult<Option<AuthUser>> {
let sql = format!(
"SELECT id, email, role FROM {} WHERE id = $1 AND is_active = true",
self.users_table
);
let stmt = Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
vec![user_id.into()],
);
let result = self
.db
.query_one(stmt)
.await
.map_err(|e| AuthError::Database(e.to_string()))?;
let Some(row) = result else {
return Ok(None);
};
let id: i64 = row
.try_get("", "id")
.map_err(|e| AuthError::Database(e.to_string()))?;
let email: String = row
.try_get("", "email")
.map_err(|e| AuthError::Database(e.to_string()))?;
let role: String = row
.try_get("", "role")
.map_err(|e| AuthError::Database(e.to_string()))?;
Ok(Some(AuthUser::new(id, email, role)))
}
}