use async_trait::async_trait;
use sqlx::Transaction;
use std::str::FromStr as _;
use std::time;
use std::thread::available_parallelism;
use sqlx::any::{self, AnyPoolOptions};
use sqlx::AnyPool;
use sqlx::ConnectOptions as _;
use tracing::{instrument, log};
pub mod init;
pub mod models;
#[async_trait]
pub trait Db {
async fn connect(url: &str) -> anyhow::Result<DatabaseConnection>;
}
#[async_trait]
pub trait Tx {
async fn begin(pool: AnyPool) -> anyhow::Result<DatabaseTransaction>;
async fn commit(self) -> anyhow::Result<()>;
async fn rollback(self) -> anyhow::Result<()>;
}
#[derive(Debug, Clone)]
pub enum DatabaseKind {
Sqlite,
}
#[derive(Debug, Clone)]
pub struct DatabaseConnection {
pub pool: AnyPool,
pub kind: DatabaseKind,
}
pub struct DatabaseTransaction {
pub tx: Transaction<'static, sqlx::Any>,
}
#[async_trait]
impl Db for DatabaseConnection {
#[instrument(level = "trace")]
async fn connect(db_url: &str) -> anyhow::Result<Self> {
any::install_default_drivers();
let num_cpus = match available_parallelism() {
Ok(cpus) => u32::try_from(cpus.get())?,
Err(_) => 4,
};
let options = any::AnyConnectOptions::from_str(db_url)?
.log_slow_statements(log::LevelFilter::Warn, time::Duration::from_secs(1));
let pool = AnyPoolOptions::new()
.max_connections(2 * num_cpus)
.connect_with(options)
.await?;
let connection = match db_url {
url if url.starts_with("sqlite:///") => Self {
pool,
kind: DatabaseKind::Sqlite,
},
_ => anyhow::bail!("Unsupported database URL: {}", db_url),
};
sqlx::query("PRAGMA journal_mode=WAL;")
.execute(&connection.pool)
.await?;
Ok(connection)
}
}
#[async_trait]
impl Tx for DatabaseTransaction {
async fn begin(pool: AnyPool) -> anyhow::Result<Self> {
let tx = pool.begin().await?;
Ok(Self { tx })
}
async fn commit(self) -> anyhow::Result<()> {
self.tx.commit().await?;
Ok(())
}
async fn rollback(self) -> anyhow::Result<()> {
self.tx.rollback().await?;
Ok(())
}
}