use std::any::Any;
use std::borrow::Cow;
use std::fmt::{Display, Formatter};
use async_trait::async_trait;
use cot::db::Auto;
use cot_macros::AdminModel;
use hmac::{Hmac, Mac};
use sha2::Sha512;
use thiserror::Error;
use crate::App;
use crate::admin::{AdminModelManager, DefaultAdminModelManager};
use crate::auth::{
AuthBackend, AuthError, PasswordHash, PasswordVerificationResult, Result, SessionAuthHash,
User, UserId,
};
use crate::common_types::Password;
use crate::config::SecretKey;
use crate::db::migrations::SyncDynMigration;
use crate::db::{Database, DatabaseBackend, LimitedString, Model, model, query};
use crate::form::Form;
pub mod migrations;
/// Maximum length of a [`DatabaseUser`] username; used as the limit of the
/// `LimitedString` type that stores it.
pub(crate) const MAX_USERNAME_LENGTH: u32 = 255;
/// A user account stored in the database.
///
/// Only a hash of the password is kept (see [`PasswordHash`]); the [`Display`]
/// implementation prints the username.
#[derive(Debug, Clone, Form, AdminModel)]
#[model]
pub struct DatabaseUser {
    // Primary key. `Auto` until the row is inserted (presumably filled in by `insert`);
    // `DatabaseUser::id` panics if it is still `Auto`.
    #[model(primary_key)]
    id: Auto<i64>,
    // Unique login name, capped at `MAX_USERNAME_LENGTH`.
    #[model(unique)]
    username: LimitedString<MAX_USERNAME_LENGTH>,
    // Hash of the user's password; also the input to the session auth HMAC.
    password: PasswordHash,
}
/// Error raised when a username cannot be used for a [`DatabaseUser`].
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum CreateUserError {
    /// The username is longer than [`MAX_USERNAME_LENGTH`]; carries the length of the
    /// rejected username.
    #[error("username is too long (max {MAX_USERNAME_LENGTH} characters, got {0})")]
    UsernameTooLong(usize),
}
impl DatabaseUser {
#[must_use]
fn new(
id: Auto<i64>,
username: LimitedString<MAX_USERNAME_LENGTH>,
password: &Password,
) -> Self {
Self {
id,
username,
password: PasswordHash::from_password(password),
}
}
pub async fn create_user<DB: DatabaseBackend, T: Into<String>, U: Into<Password>>(
db: &DB,
username: T,
password: U,
) -> Result<Self> {
let username = username.into();
let username_length = username.len();
let username = LimitedString::<MAX_USERNAME_LENGTH>::new(username).map_err(|_| {
AuthError::backend_error(CreateUserError::UsernameTooLong(username_length))
})?;
let mut user = Self::new(Auto::auto(), username, &password.into());
user.insert(db).await.map_err(AuthError::backend_error)?;
Ok(user)
}
pub async fn get_by_id<DB: DatabaseBackend>(db: &DB, id: i64) -> Result<Option<Self>> {
let db_user = query!(DatabaseUser, $id == id)
.get(db)
.await
.map_err(AuthError::backend_error)?;
Ok(db_user)
}
pub async fn get_by_username<DB: DatabaseBackend>(
db: &DB,
username: &str,
) -> Result<Option<Self>> {
let username = LimitedString::<MAX_USERNAME_LENGTH>::new(username).map_err(|_| {
AuthError::backend_error(CreateUserError::UsernameTooLong(username.len()))
})?;
let db_user = query!(DatabaseUser, $username == username)
.get(db)
.await
.map_err(AuthError::backend_error)?;
Ok(db_user)
}
pub async fn authenticate<DB: DatabaseBackend>(
db: &DB,
credentials: &DatabaseUserCredentials,
) -> Result<Option<Self>> {
let username = credentials.username();
let username_limited = LimitedString::<MAX_USERNAME_LENGTH>::new(username.to_string())
.map_err(|_| {
AuthError::backend_error(CreateUserError::UsernameTooLong(username.len()))
})?;
let user = query!(DatabaseUser, $username == username_limited)
.get(db)
.await
.map_err(AuthError::backend_error)?;
if let Some(mut user) = user {
let password_hash = &user.password;
match password_hash.verify(credentials.password()) {
PasswordVerificationResult::Ok => Ok(Some(user)),
PasswordVerificationResult::OkObsolete(new_hash) => {
user.password = new_hash;
user.save(db).await.map_err(AuthError::backend_error)?;
Ok(Some(user))
}
PasswordVerificationResult::Invalid => Ok(None),
}
} else {
let dummy_hash = PasswordHash::from_password(credentials.password());
if let PasswordVerificationResult::Invalid = dummy_hash.verify(credentials.password()) {
unreachable!(
"Password hash verification should never fail for a newly generated hash"
);
}
Ok(None)
}
}
#[must_use]
pub fn id(&self) -> i64 {
match self.id {
Auto::Fixed(id) => id,
Auto::Auto => unreachable!("DatabaseUser constructed with an unknown ID"),
}
}
#[must_use]
pub fn username(&self) -> &str {
&self.username
}
}
/// HMAC-SHA512; keyed with the app's secret key to derive session auth hashes from
/// password hashes.
type SessionAuthHmac = Hmac<Sha512>;
impl User for DatabaseUser {
fn id(&self) -> Option<UserId> {
Some(UserId::Int(self.id()))
}
fn username(&self) -> Option<Cow<'_, str>> {
Some(Cow::from(self.username.as_str()))
}
fn is_active(&self) -> bool {
true
}
fn is_authenticated(&self) -> bool {
true
}
fn session_auth_hash(&self, secret_key: &SecretKey) -> Option<SessionAuthHash> {
let mut mac = SessionAuthHmac::new_from_slice(secret_key.as_bytes())
.expect("HMAC can take key of any size");
mac.update(self.password.as_str().as_bytes());
let hmac_data = mac.finalize().into_bytes();
Some(SessionAuthHash::new(&hmac_data))
}
}
impl Display for DatabaseUser {
    /// Displays the user as their username.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let username = &self.username;
        write!(f, "{username}")
    }
}
/// A username/password pair, accepted by [`DatabaseUser::authenticate`] and by
/// [`DatabaseUserBackend`]'s [`AuthBackend::authenticate`].
#[derive(Debug, Clone)]
pub struct DatabaseUserCredentials {
    // Username to look up.
    username: String,
    // Password to verify against the stored hash.
    password: Password,
}
impl DatabaseUserCredentials {
    /// Bundles `username` and `password` into credentials for
    /// [`DatabaseUser::authenticate`].
    #[must_use]
    pub fn new(username: String, password: Password) -> Self {
        Self { password, username }
    }

    /// The username these credentials refer to.
    #[must_use]
    pub fn username(&self) -> &str {
        self.username.as_str()
    }

    /// The password to be checked against the stored hash.
    #[must_use]
    pub fn password(&self) -> &Password {
        &self.password
    }
}
/// An [`AuthBackend`] that authenticates and loads [`DatabaseUser`]s from a [`Database`].
#[derive(Debug, Clone)]
pub struct DatabaseUserBackend {
    // Database the users are read from (and written to when a password hash is upgraded).
    database: Database,
}
impl DatabaseUserBackend {
#[must_use]
pub fn new(database: Database) -> Self {
Self { database }
}
}
#[async_trait]
impl AuthBackend for DatabaseUserBackend {
async fn authenticate(
&self,
credentials: &(dyn Any + Send + Sync),
) -> Result<Option<Box<dyn User + Send + Sync>>> {
if let Some(credentials) = credentials.downcast_ref::<DatabaseUserCredentials>() {
#[expect(trivial_casts)] Ok(DatabaseUser::authenticate(&self.database, credentials)
.await
.map(|user| user.map(|user| Box::new(user) as Box<dyn User + Send + Sync>))?)
} else {
Err(AuthError::CredentialsTypeNotSupported)
}
}
async fn get_by_id(&self, id: UserId) -> Result<Option<Box<dyn User + Send + Sync>>> {
let UserId::Int(id) = id else {
return Err(AuthError::UserIdTypeNotSupported);
};
#[expect(trivial_casts)] Ok(DatabaseUser::get_by_id(&self.database, id)
.await?
.map(|user| Box::new(user) as Box<dyn User + Send + Sync>))
}
}
/// An [`App`] that provides the [`DatabaseUser`] model: its admin panel registration and
/// its migrations.
#[derive(Debug, Copy, Clone)]
pub struct DatabaseUserApp;

impl DatabaseUserApp {
    /// Creates the app; it carries no configuration.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for DatabaseUserApp {
    /// Same as [`DatabaseUserApp::new`].
    fn default() -> Self {
        DatabaseUserApp::new()
    }
}
impl App for DatabaseUserApp {
    /// The app's identifier, `cot_db_user`.
    fn name(&self) -> &'static str {
        "cot_db_user"
    }

    /// Exposes [`DatabaseUser`] in the admin panel through the default model manager.
    fn admin_model_managers(&self) -> Vec<Box<dyn AdminModelManager>> {
        let user_manager = DefaultAdminModelManager::<DatabaseUser>::new();
        vec![Box::new(user_manager)]
    }

    /// The migrations defined in this module's `migrations` submodule.
    fn migrations(&self) -> Vec<Box<SyncDynMigration>> {
        cot::db::migrations::wrap_migrations(migrations::MIGRATIONS)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SecretKey;
    use crate::db::MockDatabaseBackend;

    /// The fixture user most tests use: ID 1, username `testuser`, password `password123`.
    fn sample_user() -> DatabaseUser {
        DatabaseUser::new(
            Auto::fixed(1),
            LimitedString::new("testuser").unwrap(),
            &Password::new("password123"),
        )
    }

    /// A mock database whose `get::<DatabaseUser>` always yields a clone of `stored`.
    fn mock_db_with(stored: Option<DatabaseUser>) -> MockDatabaseBackend {
        let mut db = MockDatabaseBackend::new();
        db.expect_get::<DatabaseUser>()
            .returning(move |_| Ok(stored.clone()));
        db
    }

    /// Credentials for `testuser` with the given password.
    fn testuser_credentials(password: &str) -> DatabaseUserCredentials {
        DatabaseUserCredentials::new("testuser".to_string(), Password::new(password))
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn session_auth_hash() {
        let user = sample_user();
        let key = SecretKey::new(b"supersecretkey");
        assert!(user.session_auth_hash(&key).is_some());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn database_user_traits() {
        let user = sample_user();
        let as_trait: &dyn User = &user;
        assert_eq!(as_trait.id(), Some(UserId::Int(1)));
        assert_eq!(as_trait.username(), Some(Cow::from("testuser")));
        assert!(as_trait.is_active());
        assert!(as_trait.is_authenticated());
        let key = SecretKey::new(b"supersecretkey");
        assert!(as_trait.session_auth_hash(&key).is_some());
    }

    #[cot::test]
    #[cfg_attr(miri, ignore)]
    async fn create_user() {
        let mut db = MockDatabaseBackend::new();
        db.expect_insert::<DatabaseUser>().returning(|_| Ok(()));
        let username = String::from("testuser");
        let password = Password::new("password123");
        let created = DatabaseUser::create_user(&db, username.clone(), &password)
            .await
            .unwrap();
        assert_eq!(created.username(), username);
    }

    #[cot::test]
    #[cfg_attr(miri, ignore)]
    async fn get_by_id() {
        let db = mock_db_with(Some(sample_user()));
        let found = DatabaseUser::get_by_id(&db, 1).await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().username(), "testuser");
    }

    #[cot::test]
    #[cfg_attr(miri, ignore)]
    async fn authenticate() {
        let db = mock_db_with(Some(sample_user()));
        let credentials = testuser_credentials("password123");
        let outcome = DatabaseUser::authenticate(&db, &credentials)
            .await
            .unwrap();
        assert!(outcome.is_some());
        assert_eq!(outcome.unwrap().username(), "testuser");
    }

    #[cot::test]
    #[cfg_attr(miri, ignore)]
    async fn authenticate_non_existing() {
        let db = mock_db_with(None);
        let credentials = testuser_credentials("password123");
        let outcome = DatabaseUser::authenticate(&db, &credentials)
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    #[cot::test]
    #[cfg_attr(miri, ignore)]
    async fn authenticate_invalid_password() {
        let db = mock_db_with(Some(sample_user()));
        let credentials = testuser_credentials("invalid");
        let outcome = DatabaseUser::authenticate(&db, &credentials)
            .await
            .unwrap();
        assert!(outcome.is_none());
    }
}