use anyhow::Result;
use sqlx::{postgres::PgPoolOptions, PgPool};
/// Handle to the application's Postgres database.
///
/// Wraps a single [`PgPool`]; cloning a `Database` clones that inner pool handle.
#[derive(Debug, Clone)]
pub struct Database {
    pool: PgPool,
}
impl Database {
    /// Opens a Postgres connection pool for `database_url`.
    ///
    /// The maximum pool size is read from the `DATABASE_MAX_CONNECTIONS` environment variable.
    /// It falls back to 20 when the variable is unset or cannot be used. A value cannot be
    /// used when it is not a valid `u32` or is zero; in that case a warning is logged.
    ///
    /// # Errors
    ///
    /// Returns an error if the pool cannot be connected to `database_url`.
    pub async fn connect(database_url: &str) -> Result<Self> {
        const DEFAULT_MAX_CONNECTIONS: u32 = 20;

        let max_connections = match std::env::var("DATABASE_MAX_CONNECTIONS") {
            Err(_) => DEFAULT_MAX_CONNECTIONS,
            Ok(raw) => match raw.trim().parse::<u32>() {
                // A zero-sized pool could never hand out a connection, so treat it as invalid.
                Ok(n) if n > 0 => n,
                _ => {
                    tracing::warn!(
                        "Ignoring invalid DATABASE_MAX_CONNECTIONS={:?}; using default of {}",
                        raw,
                        DEFAULT_MAX_CONNECTIONS
                    );
                    DEFAULT_MAX_CONNECTIONS
                }
            },
        };

        let pool = PgPoolOptions::new()
            .max_connections(max_connections)
            .connect(database_url)
            .await?;
        Ok(Self { pool })
    }

    /// Applies the embedded `./migrations` with `sqlx::migrate!`.
    ///
    /// The migration run is serialized across processes with a Postgres advisory lock.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - the lock connection cannot be acquired, or the lock cannot be taken;
    /// - any migration fails.
    ///
    /// A failure to release the lock is logged but not returned; the migration result wins.
    pub async fn migrate(&self) -> Result<()> {
        const MIGRATION_LOCK_ID: i64 = 8675309;

        // `pg_advisory_lock` locks belong to the database session that took them, and only that
        // session can release them. Issuing lock and unlock through the pool could route them
        // to different connections: the unlock would then be a no-op and the lock would stay
        // held by an idle pooled connection. Pin one connection for the whole sequence.
        let mut conn = self.pool.acquire().await?;

        tracing::info!("Acquiring advisory lock for database migrations...");
        sqlx::query("SELECT pg_advisory_lock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut *conn)
            .await?;
        tracing::info!("Advisory lock acquired, running migrations...");

        let result = sqlx::migrate!("./migrations")
            .run(&mut *conn)
            .await
            .map_err(|e| -> anyhow::Error {
                // Match the typed variant rather than the Display text, which is not a
                // stable API.
                if let sqlx::migrate::MigrateError::VersionMissing(version) = &e {
                    tracing::error!(
                        "sqlx refused to migrate: the DB's `_sqlx_migrations` table has an \
                         applied row (version {}) whose matching file is missing from the repo. \
                         This blocks ALL subsequent migrations from running. To fix, either \
                         restore the missing file or remove the orphaned tracking row \
                         manually (psql: `DELETE FROM _sqlx_migrations WHERE version = {};`). \
                         Full error: {:?}",
                        version,
                        version,
                        e
                    );
                }
                e.into()
            });

        // `pg_advisory_unlock` reports `false` (rather than erroring) when this session does
        // not hold the lock, so the returned flag must be checked.
        let unlock_result = sqlx::query_scalar::<_, bool>("SELECT pg_advisory_unlock($1)")
            .bind(MIGRATION_LOCK_ID)
            .fetch_one(&mut *conn)
            .await;
        match unlock_result {
            Ok(true) => tracing::info!("Migration advisory lock released"),
            Ok(false) => tracing::warn!(
                "Migration advisory lock was not held by this session when releasing it"
            ),
            Err(unlock_err) => {
                tracing::error!("Failed to release migration advisory lock: {}", unlock_err);
                // Ending the session is the only other way to release a session-level
                // advisory lock. Keep this connection out of the pool and drop it.
                drop(conn.detach());
            }
        }

        result
    }

    /// Returns the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns the number of rows in the `plugins` table.
    pub async fn get_total_plugins(&self) -> Result<i64> {
        let count: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM plugins").fetch_one(&self.pool).await?;
        Ok(count.0)
    }

    /// Returns the sum of `plugins.downloads_total`, or 0 when the table is empty.
    ///
    /// The `::BIGINT` cast keeps the result decodable as `i64`. Presumably this is because
    /// `downloads_total` is a BIGINT column whose `SUM` is NUMERIC; confirm against the schema.
    pub async fn get_total_downloads(&self) -> Result<i64> {
        let total: (Option<i64>,) =
            sqlx::query_as("SELECT COALESCE(SUM(downloads_total)::BIGINT, 0) FROM plugins")
                .fetch_one(&self.pool)
                .await?;
        Ok(total.0.unwrap_or(0))
    }

    /// Returns the number of rows in the `users` table.
    pub async fn get_total_users(&self) -> Result<i64> {
        let count: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(&self.pool).await?;
        Ok(count.0)
    }

    /// Records a newly issued refresh token's `jti` so it can later be checked or revoked.
    ///
    /// Re-inserting an existing `jti` is a no-op (`ON CONFLICT (jti) DO NOTHING`).
    pub async fn store_refresh_token_jti(
        &self,
        jti: &str,
        user_id: uuid::Uuid,
        expires_at: chrono::DateTime<chrono::Utc>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            INSERT INTO token_revocations (jti, user_id, expires_at)
            VALUES ($1, $2, $3)
            ON CONFLICT (jti) DO NOTHING
            "#,
        )
        .bind(jti)
        .bind(user_id)
        .bind(expires_at)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Reports whether the token identified by `jti` must be rejected.
    ///
    /// Returns `true` when the token's `revoked_at` is set. Also returns `true` (fail closed)
    /// when no row exists for `jti`. That covers tokens never stored via
    /// [`Self::store_refresh_token_jti`] and rows already purged by
    /// [`Self::cleanup_expired_tokens`].
    pub async fn is_token_revoked(&self, jti: &str) -> Result<bool> {
        let result: Option<(Option<chrono::DateTime<chrono::Utc>>,)> = sqlx::query_as(
            r#"
            SELECT revoked_at FROM token_revocations WHERE jti = $1
            "#,
        )
        .bind(jti)
        .fetch_optional(&self.pool)
        .await?;
        match result {
            Some((Some(_),)) => Ok(true),
            Some((None,)) => Ok(false),
            // Unknown tokens are treated as revoked.
            None => Ok(true),
        }
    }

    /// Marks the token `jti` as revoked with `reason`, if it is not already revoked.
    ///
    /// An unknown or already-revoked `jti` is not an error; nothing is updated.
    pub async fn revoke_token(&self, jti: &str, reason: &str) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE token_revocations
            SET revoked_at = NOW(), revocation_reason = $2
            WHERE jti = $1 AND revoked_at IS NULL
            "#,
        )
        .bind(jti)
        .bind(reason)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Revokes every not-yet-revoked token of `user_id` with `reason`.
    ///
    /// Returns the number of tokens newly revoked by this call.
    pub async fn revoke_all_user_tokens(&self, user_id: uuid::Uuid, reason: &str) -> Result<u64> {
        let result = sqlx::query(
            r#"
            UPDATE token_revocations
            SET revoked_at = NOW(), revocation_reason = $2
            WHERE user_id = $1 AND revoked_at IS NULL
            "#,
        )
        .bind(user_id)
        .bind(reason)
        .execute(&self.pool)
        .await?;
        Ok(result.rows_affected())
    }

    /// Deletes token records whose `expires_at` is more than one day in the past.
    ///
    /// Returns the number of rows deleted.
    pub async fn cleanup_expired_tokens(&self) -> Result<u64> {
        let result = sqlx::query(
            r#"
            DELETE FROM token_revocations
            WHERE expires_at < NOW() - INTERVAL '1 day'
            "#,
        )
        .execute(&self.pool)
        .await?;
        Ok(result.rows_affected())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used by the `#[ignore]`d integration tests; override with `TEST_DATABASE_URL`.
    fn test_database_url() -> String {
        std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgresql://test:test@localhost/test_db".to_string())
    }

    /// Connects to the integration-test database.
    ///
    /// Returns `None` (after printing why) so an unreachable database skips the test.
    async fn connect_or_skip() -> Option<Database> {
        match Database::connect(&test_database_url()).await {
            Ok(db) => Some(db),
            Err(err) => {
                eprintln!("skipping: could not connect to test database: {err:?}");
                None
            }
        }
    }

    /// Like `connect_or_skip`, but also applies migrations and fails the test if they error.
    async fn migrated_or_skip() -> Option<Database> {
        let db = connect_or_skip().await?;
        db.migrate().await.expect("migrations should apply cleanly");
        Some(db)
    }

    #[test]
    fn test_database_clone() {
        fn requires_clone<T: Clone>() {}
        requires_clone::<Database>();
    }

    #[tokio::test]
    async fn test_database_connect() {
        let database_url = "postgresql://test:test@localhost/test_db";
        let result = Database::connect(database_url).await;
        assert!(
            result.is_err(),
            "expected connection to fail without a running database, but got: {result:?}"
        );
    }

    #[test]
    fn test_database_pool_type() {
        fn check_pool_method(_db: &Database) -> &PgPool {
            _db.pool()
        }
        let _: fn(&Database) -> &PgPool = check_pool_method;
    }

    #[test]
    fn test_total_plugins_query_structure() {
        let query = "SELECT COUNT(*) FROM plugins";
        assert!(query.contains("SELECT"));
        assert!(query.contains("COUNT(*)"));
        assert!(query.contains("FROM plugins"));
    }

    #[test]
    fn test_total_downloads_query_structure() {
        let query = "SELECT COALESCE(SUM(downloads_total)::BIGINT, 0) FROM plugins";
        assert!(query.contains("SELECT"));
        assert!(query.contains("COALESCE"));
        assert!(query.contains("SUM(downloads_total)"));
        assert!(query.contains("FROM plugins"));
        assert!(query.contains("::BIGINT"));
    }

    #[test]
    fn test_total_users_query_structure() {
        let query = "SELECT COUNT(*) FROM users";
        assert!(query.contains("SELECT"));
        assert!(query.contains("COUNT(*)"));
        assert!(query.contains("FROM users"));
    }

    #[test]
    fn test_migration_error_handling() {
        let error_msg = "previously applied but is missing";
        assert!(error_msg.contains("previously applied"));
        assert!(error_msg.contains("missing"));
    }

    #[tokio::test]
    #[ignore]
    async fn test_database_migration() {
        if let Some(db) = connect_or_skip().await {
            db.migrate().await.expect("migrations should apply cleanly");
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_total_plugins() {
        if let Some(db) = migrated_or_skip().await {
            let count = db.get_total_plugins().await.expect("plugin count query should succeed");
            assert!(count >= 0);
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_total_downloads() {
        if let Some(db) = migrated_or_skip().await {
            let total =
                db.get_total_downloads().await.expect("download total query should succeed");
            assert!(total >= 0);
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_total_users() {
        if let Some(db) = migrated_or_skip().await {
            let count = db.get_total_users().await.expect("user count query should succeed");
            assert!(count >= 0);
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_pool_reuse() {
        if let Some(db) = connect_or_skip().await {
            let pool1 = db.pool();
            let pool2 = db.pool();
            assert!(std::ptr::eq(pool1, pool2));
        }
    }

    #[test]
    fn test_database_connection_string_validation() {
        let valid_urls = vec![
            "postgresql://user:pass@localhost/db",
            "postgresql://user:pass@localhost:5432/db",
            "postgresql://localhost/db",
            "postgres://user:pass@host:5432/database?sslmode=require",
        ];
        for url in valid_urls {
            assert!(url.starts_with("postgres"));
            assert!(url.contains("://"));
        }
    }

    #[test]
    fn test_max_connections_config() {
        let max_connections = 20;
        assert!(max_connections > 0);
        assert!(max_connections <= 100);
    }

    /// Running migrations twice must succeed both times.
    ///
    /// A second run would hang if the first run leaked its advisory lock.
    #[tokio::test]
    #[ignore]
    async fn test_migration_idempotency() {
        if let Some(db) = connect_or_skip().await {
            db.migrate().await.expect("first migration run should succeed");
            db.migrate().await.expect("second migration run should be a no-op");
        }
    }

    #[test]
    fn test_query_return_types() {
        fn check_total_plugins_type(_: i64) {}
        fn check_total_downloads_type(_: i64) {}
        fn check_total_users_type(_: i64) {}
        let _: fn(i64) = check_total_plugins_type;
        let _: fn(i64) = check_total_downloads_type;
        let _: fn(i64) = check_total_users_type;
    }

    #[test]
    fn test_database_error_types() {
        use anyhow::Result;
        fn returns_result() -> Result<()> {
            Ok(())
        }
        let result = returns_result();
        assert!(result.is_ok());
    }
}