use sqlx::SqlitePool;
use super::Error;
pub type Result<T> = std::result::Result<T, Error>;
/// A session row as read back from the `sessions` table.
///
/// `created_at` is never supplied by the insert path; NOTE(review): presumably
/// filled by a column default in the schema — confirm against the migration.
/// Timestamps are plain `i64`s (presumably Unix seconds — confirm with callers).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionRecord {
    /// Session identifier, used as the lookup key.
    pub id: String,
    /// Stored hash of the session token (the raw token is not kept here).
    pub token_hash: String,
    /// Id of the admin user owning this session.
    pub user_id: i64,
    /// Creation timestamp as stored by the database.
    pub created_at: i64,
    /// Expiry timestamp; rows with `expires_at <= now` count as expired.
    pub expires_at: i64,
}
/// Input for creating a session via `SessionRepository::insert`.
///
/// There is no `created_at` field: the SQLite insert does not bind one, so that
/// value comes from the database side. Derives match [`SessionRecord`] so
/// values can be compared directly (e.g. in tests or in-memory fakes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NewSession {
    /// Session identifier.
    pub id: String,
    /// Hash of the session token to persist.
    pub token_hash: String,
    /// Id of the admin user owning the session.
    pub user_id: i64,
    /// Expiry timestamp (presumably Unix seconds — confirm with callers).
    pub expires_at: i64,
}
/// Storage operations for login sessions.
///
/// Timestamps (`expires_at`, `now`) are plain `i64`s; NOTE(review): presumably
/// Unix seconds — confirm with the callers that compute them.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot require the returned futures to be `Send`.
// NOTE(review): allowed on the assumption that no generic caller needs `Send`
// futures (e.g. spawning work over `impl SessionRepository` on a multi-threaded
// runtime) — confirm, or desugar to `fn ... -> impl Future<Output = _> + Send`.
#[allow(async_fn_in_trait)]
pub trait SessionRepository {
/// Persists a newly created session.
async fn insert(&self, session: &NewSession) -> Result<()>;
/// Looks up a session by id, returning `Ok(None)` when no row matches.
async fn find(&self, id: &str) -> Result<Option<SessionRecord>>;
/// Replaces the session's `expires_at` (sliding expiry).
/// The SQLite implementation returns `Ok(())` even if no row matches `id`.
async fn touch(&self, id: &str, expires_at: i64) -> Result<()>;
/// Removes a single session by id (the SQLite impl succeeds if none matched).
async fn delete(&self, id: &str) -> Result<()>;
/// Removes every session owned by `user_id`.
async fn delete_for_user(&self, user_id: i64) -> Result<()>;
/// Removes sessions whose `expires_at <= now`; returns how many were removed.
async fn delete_expired(&self, now: i64) -> Result<u64>;
}
/// [`SessionRepository`] backed by a SQLite connection pool.
///
/// Cloning is cheap: `SqlitePool` is a reference-counted handle, so clones
/// share the same underlying connections rather than opening new ones.
#[derive(Debug, Clone)]
pub struct SqliteSessionRepo {
    pool: SqlitePool,
}

impl SqliteSessionRepo {
    /// Wraps an existing pool. Only stores the handle; no query is issued
    /// until one of the repository methods is called.
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }
}
impl SessionRepository for SqliteSessionRepo {
async fn insert(&self, session: &NewSession) -> Result<()> {
sqlx::query!(
r#"INSERT INTO sessions (id, token_hash, user_id, expires_at)
VALUES (?, ?, ?, ?)"#,
session.id,
session.token_hash,
session.user_id,
session.expires_at,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn find(&self, id: &str) -> Result<Option<SessionRecord>> {
let row = sqlx::query_as!(
SessionRecord,
r#"SELECT
id AS "id!",
token_hash AS "token_hash!",
user_id AS "user_id!",
created_at AS "created_at!",
expires_at AS "expires_at!"
FROM sessions
WHERE id = ?"#,
id,
)
.fetch_optional(&self.pool)
.await?;
Ok(row)
}
async fn touch(&self, id: &str, expires_at: i64) -> Result<()> {
sqlx::query!(
"UPDATE sessions SET expires_at = ? WHERE id = ?",
expires_at,
id,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn delete(&self, id: &str) -> Result<()> {
sqlx::query!("DELETE FROM sessions WHERE id = ?", id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn delete_for_user(&self, user_id: i64) -> Result<()> {
sqlx::query!("DELETE FROM sessions WHERE user_id = ?", user_id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn delete_expired(&self, now: i64) -> Result<u64> {
let result = sqlx::query!("DELETE FROM sessions WHERE expires_at <= ?", now)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::{Db, admin_users::AdminUserRepository, admin_users::SqliteAdminUserRepo};
    use tempfile::TempDir;

    /// Opens a fresh database in a temp dir and seeds one admin user.
    ///
    /// Returns the `TempDir` guard (callers must keep it alive for the test's
    /// duration), a session repo on that database, and the seeded user's id.
    async fn open_repo() -> (TempDir, SqliteSessionRepo, i64) {
        let dir = TempDir::new().expect("temp dir");
        let db = Db::connect(dir.path().join("test.db"))
            .await
            .expect("connect");
        let user = SqliteAdminUserRepo::new(db.pool().clone())
            .create("admin", "$argon2id$dummy")
            .await
            .expect("create user");
        (dir, SqliteSessionRepo::new(db.pool().clone()), user.id)
    }

    /// Builds a `NewSession` whose token hash is derived from `id`, so each
    /// fixture is distinguishable.
    fn new_session(id: &str, user_id: i64, expires_at: i64) -> NewSession {
        NewSession {
            id: id.to_owned(),
            token_hash: format!("hash-of-{id}"),
            user_id,
            expires_at,
        }
    }

    #[tokio::test]
    async fn insert_then_find_round_trips() {
        let (_dir, repo, uid) = open_repo().await;
        repo.insert(&new_session("sess-a", uid, 2_000_000_000))
            .await
            .expect("insert");
        let found = repo.find("sess-a").await.expect("find").expect("present");
        assert_eq!(found.id, "sess-a");
        assert_eq!(found.token_hash, "hash-of-sess-a");
        assert_eq!(found.user_id, uid);
        assert_eq!(found.expires_at, 2_000_000_000);
        assert!(found.created_at > 0);
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let (_dir, repo, _uid) = open_repo().await;
        assert!(repo.find("nope").await.expect("find").is_none());
    }

    #[tokio::test]
    async fn touch_slides_expiry() {
        let (_dir, repo, uid) = open_repo().await;
        repo.insert(&new_session("s", uid, 1000))
            .await
            .expect("ins");
        repo.touch("s", 5000).await.expect("touch");
        let found = repo.find("s").await.expect("find").expect("present");
        assert_eq!(found.expires_at, 5000);
    }

    /// `touch` is an UPDATE: an unknown id must succeed without creating a row.
    #[tokio::test]
    async fn touch_missing_session_is_noop() {
        let (_dir, repo, _uid) = open_repo().await;
        repo.touch("ghost", 5000).await.expect("touch");
        assert!(repo.find("ghost").await.expect("find").is_none());
    }

    #[tokio::test]
    async fn delete_removes_session() {
        let (_dir, repo, uid) = open_repo().await;
        repo.insert(&new_session("s", uid, 1000))
            .await
            .expect("ins");
        repo.delete("s").await.expect("delete");
        assert!(repo.find("s").await.expect("find").is_none());
    }

    #[tokio::test]
    async fn delete_for_user_removes_all() {
        let (_dir, repo, uid) = open_repo().await;
        repo.insert(&new_session("a", uid, 1000)).await.expect("a");
        repo.insert(&new_session("b", uid, 1000)).await.expect("b");
        repo.delete_for_user(uid).await.expect("delete_for_user");
        assert!(repo.find("a").await.expect("find").is_none());
        assert!(repo.find("b").await.expect("find").is_none());
    }

    /// Logging one user out everywhere must not touch another user's sessions.
    #[tokio::test]
    async fn delete_for_user_leaves_other_users_sessions() {
        let (_dir, repo, uid) = open_repo().await;
        // The test module is a child of the repo's module, so the private
        // pool is reachable here; reuse it to seed a second user.
        let other = SqliteAdminUserRepo::new(repo.pool.clone())
            .create("other", "$argon2id$dummy")
            .await
            .expect("create other user");
        repo.insert(&new_session("mine", uid, 1000))
            .await
            .expect("mine");
        repo.insert(&new_session("theirs", other.id, 1000))
            .await
            .expect("theirs");
        repo.delete_for_user(uid).await.expect("delete_for_user");
        assert!(repo.find("mine").await.expect("find").is_none());
        let kept = repo.find("theirs").await.expect("find").expect("present");
        assert_eq!(kept.user_id, other.id);
    }

    #[tokio::test]
    async fn delete_expired_removes_only_stale() {
        let (_dir, repo, uid) = open_repo().await;
        repo.insert(&new_session("old", uid, 100))
            .await
            .expect("old");
        repo.insert(&new_session("new", uid, 10_000))
            .await
            .expect("new");
        let removed = repo.delete_expired(1000).await.expect("delete_expired");
        assert_eq!(removed, 1);
        assert!(repo.find("old").await.expect("find").is_none());
        assert!(repo.find("new").await.expect("find").is_some());
    }

    /// A session expiring exactly at `now` counts as expired (`<=` boundary).
    #[tokio::test]
    async fn delete_expired_includes_exact_boundary() {
        let (_dir, repo, uid) = open_repo().await;
        repo.insert(&new_session("edge", uid, 1000))
            .await
            .expect("edge");
        let removed = repo.delete_expired(1000).await.expect("delete_expired");
        assert_eq!(removed, 1);
        assert!(repo.find("edge").await.expect("find").is_none());
    }
}