#![allow(clippy::unreachable)]
use async_trait::async_trait;
use sqlx::Transaction;
use std::str::FromStr;
use sqlx::any::{self, AnyPoolOptions};
use sqlx::AnyPool;
use sqlx::ConnectOptions;
use tracing::instrument;
pub mod init;
pub mod models;
/// Asynchronous constructor for a pooled [`DatabaseConnection`].
#[async_trait]
pub trait Db {
    /// Opens a [`DatabaseConnection`] for the database addressed by `db_url`.
    ///
    /// # Errors
    ///
    /// Any failure is reported through the returned [`anyhow::Result`].
    async fn connect(db_url: &str) -> anyhow::Result<DatabaseConnection>;
}
/// Lifecycle of a database transaction: begin, then either commit or roll back.
#[async_trait]
pub trait Tx {
    /// Starts a new [`DatabaseTransaction`] on a connection obtained from `pool`.
    async fn begin(pool: AnyPool) -> anyhow::Result<DatabaseTransaction>;

    /// Commits the transaction; takes `self` by value, so it cannot be reused.
    async fn commit(self) -> anyhow::Result<()>;

    /// Rolls the transaction back; takes `self` by value, so it cannot be reused.
    async fn rollback(self) -> anyhow::Result<()>;
}
/// The database backend a [`DatabaseConnection`] is talking to.
///
/// This is a plain fieldless tag, so it is `Copy` and can be compared with `==`.
/// It can also be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseKind {
    /// SQLite, selected for `sqlite:` URLs.
    Sqlite,
    /// PostgreSQL, selected for `postgres://` / `postgresql://` URLs.
    Postgres,
}
/// A pooled database handle paired with the backend it is connected to.
///
/// Cloning copies the `AnyPool` handle held in `pool` together with `kind`.
#[derive(Clone, Debug)]
pub struct DatabaseConnection {
    /// Connection pool using sqlx's `Any` driver.
    pub pool: AnyPool,
    /// Which backend `pool` talks to.
    pub kind: DatabaseKind,
}
/// An open transaction on the `Any` driver, driven through the [`Tx`] trait.
///
/// NOTE(review): sqlx documents that dropping a `Transaction` without committing rolls
/// it back — confirm for the sqlx version in use before relying on that.
#[derive(Debug)]
pub struct DatabaseTransaction {
    /// The underlying sqlx transaction. It is `'static` because it owns its pooled
    /// connection rather than borrowing one.
    pub tx: Transaction<'static, sqlx::Any>,
}
#[async_trait]
impl Db for DatabaseConnection {
#[instrument(level = "trace")]
async fn connect(db_url: &str) -> anyhow::Result<Self> {
any::install_default_drivers();
let options = any::AnyConnectOptions::from_str(db_url)?.disable_statement_logging();
let pool = AnyPoolOptions::new()
.max_connections(50)
.connect_with(options)
.await?;
let connection = match db_url {
url if url.starts_with("sqlite:///") => Self {
pool,
kind: DatabaseKind::Sqlite,
},
url if url.starts_with("postgres://") => Self {
pool,
kind: DatabaseKind::Postgres,
},
_ => anyhow::bail!("Unsupported database URL: {}", db_url),
};
Ok(connection)
}
}
#[async_trait]
impl Tx for DatabaseTransaction {
async fn begin(pool: AnyPool) -> anyhow::Result<Self> {
let tx = pool.begin().await?;
Ok(Self { tx })
}
async fn commit(self) -> anyhow::Result<()> {
self.tx.commit().await?;
Ok(())
}
async fn rollback(self) -> anyhow::Result<()> {
self.tx.rollback().await?;
Ok(())
}
}