use std::{fmt, future::Future, str::FromStr};
use sqlx::SqlitePool;
use super::Error;
/// Result alias used by this module's repository API, with the error type
/// fixed to the parent module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Role assigned to an admin account.
///
/// Persisted as text: the `strum::IntoStaticStr` derive supplies the string
/// form (exposed through `as_str`), and the `FromStr` impl parses it back.
#[derive(strum::IntoStaticStr, Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Serialized as `"admin"`; also the `Default` role.
    #[strum(serialize = "admin")]
    #[default]
    Admin,
}
impl Role {
    /// Returns the textual form of the role (e.g. `"admin"`), as produced by
    /// the enum's `strum::IntoStaticStr` derive.
    pub fn as_str(&self) -> &'static str {
        <&'static str>::from(self)
    }
}
impl fmt::Display for Role {
    /// Writes the same text as [`Role::as_str`]; width/fill flags on the
    /// outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
impl FromStr for Role {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"admin" => Ok(Self::Admin),
other => Err(Error::Decode(format!("unknown role value: {other:?}"))),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AdminUser {
pub id: i64,
pub username: String,
pub password_hash: String,
pub role: Role,
pub created_at: i64,
pub updated_at: i64,
}
/// Raw `admin_users` row as materialized by the `query_as!` calls in this module.
///
/// `role` stays as text here; it is turned into a [`Role`] by the
/// `TryFrom<AdminUserRow> for AdminUser` conversion.
struct AdminUserRow {
    /// Row id.
    id: i64,
    /// Login name.
    username: String,
    /// Stored password hash.
    password_hash: String,
    /// Role in its stored textual form.
    role: String,
    /// Creation time column.
    created_at: i64,
    /// Last-update time column.
    updated_at: i64,
}
impl TryFrom<AdminUserRow> for AdminUser {
type Error = Error;
fn try_from(row: AdminUserRow) -> Result<Self> {
Ok(AdminUser {
id: row.id,
username: row.username,
password_hash: row.password_hash,
role: row.role.parse()?,
created_at: row.created_at,
updated_at: row.updated_at,
})
}
}
/// Storage operations for admin accounts.
///
/// Every method returns a future that resolves to this module's `Result`.
// NOTE(review): the returned futures carry no `Send` bound, so code that is
// generic over this trait cannot hand them to a multi-threaded executor
// (e.g. `tokio::spawn`) without knowing the concrete type — confirm intended.
pub trait AdminUserRepository {
    /// Returns how many admin accounts exist.
    fn count(&self) -> impl Future<Output = Result<i64>>;

    /// Looks up the account with the given username; `Ok(None)` if there is none.
    fn find_by_username(
        &self,
        username: &str,
    ) -> impl Future<Output = Result<Option<AdminUser>>>;

    /// Inserts a new account and returns it.
    fn create(
        &self,
        username: &str,
        password_hash: &str,
    ) -> impl Future<Output = Result<AdminUser>>;

    /// Inserts an account only when no accounts exist yet.
    ///
    /// Resolves to `Ok(None)` — and writes nothing — when at least one
    /// account is already stored.
    fn create_initial(
        &self,
        username: &str,
        password_hash: &str,
    ) -> impl Future<Output = Result<Option<AdminUser>>>;
}
/// [`AdminUserRepository`] backed by an SQLite connection pool.
///
/// `Clone` is derived because the only field is a pool handle, and sqlx
/// documents cloning a pool as cheap: clones share the same underlying
/// connections. The repository can therefore be copied into shared
/// application state.
#[derive(Debug, Clone)]
pub struct SqliteAdminUserRepo {
    /// Pool that all queries in this repository run against.
    pool: SqlitePool,
}
impl SqliteAdminUserRepo {
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
}
impl AdminUserRepository for SqliteAdminUserRepo {
    /// Counts every row in `admin_users`.
    async fn count(&self) -> Result<i64> {
        let count = sqlx::query_scalar!(r#"SELECT COUNT(*) AS "count!" FROM admin_users"#)
            .fetch_one(&self.pool)
            .await?;
        Ok(count)
    }

    /// Fetches the row whose `username` column equals `username`.
    ///
    /// Returns `Ok(None)` when no row matches. A stored role that fails to
    /// parse surfaces as the error from `AdminUser::try_from`.
    async fn find_by_username(&self, username: &str) -> Result<Option<AdminUser>> {
        let row = sqlx::query_as!(
            AdminUserRow,
            r#"SELECT
            id AS "id!",
            username,
            password_hash,
            role,
            created_at AS "created_at!",
            updated_at AS "updated_at!"
            FROM admin_users
            WHERE username = ?"#,
            username,
        )
        .fetch_optional(&self.pool)
        .await?;
        row.map(AdminUser::try_from).transpose()
    }

    /// Inserts a new admin account and returns the row exactly as stored.
    ///
    /// The role is written explicitly, matching `create_initial`, instead of
    /// relying on a schema default. Every field is read back via `RETURNING`,
    /// so the returned value reflects the database rather than values
    /// assumed on the Rust side.
    ///
    /// # Errors
    /// Database errors propagate via `?`. Per this module's tests, that
    /// includes the rejection of a duplicate username.
    async fn create(&self, username: &str, password_hash: &str) -> Result<AdminUser> {
        let row = sqlx::query_as!(
            AdminUserRow,
            r#"INSERT INTO admin_users (username, password_hash, role)
            VALUES (?, ?, 'admin')
            RETURNING
            id AS "id!",
            username AS "username!",
            password_hash AS "password_hash!",
            role AS "role!",
            created_at AS "created_at!",
            updated_at AS "updated_at!""#,
            username,
            password_hash,
        )
        .fetch_one(&self.pool)
        .await?;
        AdminUser::try_from(row)
    }

    /// Inserts the first admin account, but only while `admin_users` is empty.
    ///
    /// The emptiness check (`WHERE NOT EXISTS`) and the insert are a single
    /// SQL statement, with no separate read-then-write round trip. Returns
    /// `Ok(None)` when the table already had rows, in which case nothing is
    /// inserted.
    async fn create_initial(
        &self,
        username: &str,
        password_hash: &str,
    ) -> Result<Option<AdminUser>> {
        let row = sqlx::query_as!(
            AdminUserRow,
            r#"INSERT INTO admin_users (username, password_hash, role)
            SELECT ?, ?, 'admin'
            WHERE NOT EXISTS (SELECT 1 FROM admin_users)
            RETURNING
            id AS "id!",
            username,
            password_hash,
            role,
            created_at AS "created_at!",
            updated_at AS "updated_at!""#,
            username,
            password_hash,
        )
        .fetch_optional(&self.pool)
        .await?;
        row.map(AdminUser::try_from).transpose()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a repository over a fresh temporary database.
    ///
    /// The `TempDir` is handed back so each test holds it for its whole body
    /// (`tempfile` removes the directory when the handle is dropped). Tests
    /// therefore bind it to a named `_tmp`, never to `_`.
    async fn open_repo() -> (TempDir, SqliteAdminUserRepo) {
        let (tmp_dir, database) = crate::test_support::temp_db().await;
        let repo = SqliteAdminUserRepo::new(database.pool().clone());
        (tmp_dir, repo)
    }

    #[test]
    fn role_round_trips_through_text() {
        // Serialization and parsing must agree; unknown names must be rejected.
        assert_eq!(Role::Admin.as_str(), "admin");
        let parsed: Role = "admin".parse().expect("parse");
        assert_eq!(parsed, Role::Admin);
        let unknown = "root".parse::<Role>();
        assert!(unknown.is_err());
    }

    #[tokio::test]
    async fn fresh_db_has_no_admin_users() {
        let (_tmp, repo) = open_repo().await;
        let total = repo.count().await.expect("count");
        assert_eq!(total, 0);
    }

    #[tokio::test]
    async fn create_then_find_round_trips() {
        let (_tmp, repo) = open_repo().await;

        let created = repo
            .create("admin", "$argon2id$dummy")
            .await
            .expect("create");
        assert!(created.id > 0);
        assert_eq!(created.username, "admin");
        assert_eq!(created.role, Role::Admin);
        assert!(created.created_at > 0);
        assert_eq!(repo.count().await.expect("count"), 1);

        let lookup = repo.find_by_username("admin").await.expect("find");
        let found = lookup.expect("present");
        assert_eq!(found, created);
    }

    #[tokio::test]
    async fn find_unknown_returns_none() {
        let (_tmp, repo) = open_repo().await;
        let lookup = repo.find_by_username("nobody").await.expect("find");
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn duplicate_username_errors() {
        let (_tmp, repo) = open_repo().await;
        repo.create("admin", "$h1").await.expect("first");
        let err = repo.create("admin", "$h2").await;
        assert!(
            matches!(err, Err(Error::Sqlx(_))),
            "duplicate username must surface as Sqlx error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn create_initial_only_inserts_when_table_is_empty() {
        let (_tmp, repo) = open_repo().await;

        let first_attempt = repo
            .create_initial("admin", "$h1")
            .await
            .expect("create initial");
        let first = first_attempt.expect("first setup should insert");
        let second = repo
            .create_initial("other", "$h2")
            .await
            .expect("second create initial");

        assert_eq!(first.username, "admin");
        assert!(second.is_none(), "second setup must not insert");
        assert_eq!(repo.count().await.expect("count"), 1);

        let other = repo.find_by_username("other").await.expect("find");
        assert!(other.is_none());
    }
}