use std::{future::Future, path::Path, time::Duration};
use sea_orm::{
ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr, ExecResult,
QueryResult, Statement, TransactionTrait,
};
use crate::{pool, retry, retry::IsSqliteBusy};
/// Newtype wrapper around a SeaORM [`DatabaseConnection`] used for the read
/// side of the database.
///
/// Its `ConnectionTrait` implementation forwards every call straight to the
/// wrapped connection, without any retry wrapping.
#[derive(Debug, Clone)]
pub struct DbReadConnection(DatabaseConnection);
/// Newtype wrapper around a SeaORM [`DatabaseConnection`] used for the write
/// side of the database.
///
/// `DbWriteConnection::open` requests a pool with a maximum of one
/// connection, and the `ConnectionTrait` implementation routes statement
/// execution through `retry::retry_on_busy`.
#[derive(Debug, Clone)]
pub struct DbWriteConnection(DatabaseConnection);
impl DbReadConnection {
pub fn new(inner: DatabaseConnection) -> Self {
Self(inner)
}
pub async fn open(
db_path: &Path,
max_connections: u32,
connect_timeout: Duration,
busy_timeout: Duration,
) -> Result<Self, sqlx::Error> {
let conn = pool::build_pool(
db_path,
max_connections,
connect_timeout,
busy_timeout,
false,
)
.await?;
Ok(Self(conn))
}
pub fn inner(&self) -> &DatabaseConnection {
&self.0
}
}
impl DbWriteConnection {
pub fn new(inner: DatabaseConnection) -> Self {
Self(inner)
}
pub async fn open(
db_path: &Path,
connect_timeout: Duration,
busy_timeout: Duration,
) -> Result<Self, sqlx::Error> {
let conn = pool::build_pool(db_path, 1, connect_timeout, busy_timeout, true).await?;
Ok(Self(conn))
}
pub fn inner(&self) -> &DatabaseConnection {
&self.0
}
pub async fn transaction<F, Fut, T, E>(&self, f: F) -> Result<T, E>
where
F: Fn(DatabaseTransaction) -> Fut,
Fut: Future<Output = Result<(DatabaseTransaction, T), E>> + Send,
T: Send,
E: From<DbErr> + IsSqliteBusy,
{
retry::retry_on_busy(|| async {
let txn = self.0.begin().await?;
let (txn, value) = f(txn).await?;
txn.commit().await?;
Ok(value)
})
.await
}
}
/// Pass-through [`ConnectionTrait`] implementation: every method forwards to
/// the wrapped connection and returns its result unchanged.
#[async_trait::async_trait]
impl ConnectionTrait for DbReadConnection {
    /// Reports the database backend of the wrapped connection.
    fn get_database_backend(&self) -> DbBackend {
        self.inner().get_database_backend()
    }

    /// Executes `stmt` on the wrapped connection.
    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        let conn = self.inner();
        conn.execute(stmt).await
    }

    /// Executes the raw `sql` text on the wrapped connection.
    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        let conn = self.inner();
        conn.execute_unprepared(sql).await
    }

    /// Runs `stmt` on the wrapped connection and returns at most one row.
    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        let conn = self.inner();
        conn.query_one(stmt).await
    }

    /// Runs `stmt` on the wrapped connection and returns all rows.
    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        let conn = self.inner();
        conn.query_all(stmt).await
    }

    /// Forwards the wrapped connection's `support_returning` answer.
    fn support_returning(&self) -> bool {
        self.inner().support_returning()
    }

    /// Forwards the wrapped connection's `is_mock_connection` answer.
    fn is_mock_connection(&self) -> bool {
        self.inner().is_mock_connection()
    }
}
/// [`ConnectionTrait`] implementation for the write side.
///
/// The statement-executing methods hand each attempt to
/// `retry::retry_on_busy`; the purely informational methods forward directly
/// to the wrapped connection.
#[async_trait::async_trait]
impl ConnectionTrait for DbWriteConnection {
    /// Reports the database backend of the wrapped connection.
    fn get_database_backend(&self) -> DbBackend {
        self.inner().get_database_backend()
    }

    /// Executes `stmt` through `retry::retry_on_busy`.
    ///
    /// The statement is cloned per attempt because each call on the wrapped
    /// connection takes its `Statement` by value.
    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        let conn = self.inner();
        let attempt = || async { conn.execute(stmt.clone()).await };
        retry::retry_on_busy(attempt).await
    }

    /// Executes the raw `sql` text through `retry::retry_on_busy`.
    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        let conn = self.inner();
        let attempt = || async { conn.execute_unprepared(sql).await };
        retry::retry_on_busy(attempt).await
    }

    /// Runs `stmt` through `retry::retry_on_busy` and returns at most one row.
    ///
    /// The statement is cloned per attempt.
    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        let conn = self.inner();
        let attempt = || async { conn.query_one(stmt.clone()).await };
        retry::retry_on_busy(attempt).await
    }

    /// Runs `stmt` through `retry::retry_on_busy` and returns all rows.
    ///
    /// The statement is cloned per attempt.
    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        let conn = self.inner();
        let attempt = || async { conn.query_all(stmt.clone()).await };
        retry::retry_on_busy(attempt).await
    }

    /// Forwards the wrapped connection's `support_returning` answer.
    fn support_returning(&self) -> bool {
        self.inner().support_returning()
    }

    /// Forwards the wrapped connection's `is_mock_connection` answer.
    fn is_mock_connection(&self) -> bool {
        self.inner().is_mock_connection()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Value used for both the connect timeout and the busy timeout in these tests.
    const TIMEOUT: Duration = Duration::from_secs(5);

    /// Creates a fresh temporary directory and returns it together with the
    /// path of a not-yet-existing `catalog.db` inside it.
    ///
    /// The returned directory guard must be kept alive for as long as the
    /// path is in use.
    fn scratch_catalog() -> (tempfile::TempDir, std::path::PathBuf) {
        let guard = tempfile::tempdir().unwrap();
        let catalog = guard.path().join("catalog.db");
        (guard, catalog)
    }

    /// Opening the read side against a missing file must fail and must not
    /// leave a database file behind.
    #[tokio::test]
    async fn read_open_does_not_create_db() {
        let (_guard, catalog) = scratch_catalog();
        let outcome = DbReadConnection::open(&catalog, 1, TIMEOUT, TIMEOUT).await;
        assert!(outcome.is_err(), "read open should fail on a missing db");
        assert!(
            !catalog.exists(),
            "read open must not create the catalog db file"
        );
    }

    /// Opening the write side against a missing file must succeed and create
    /// the database file.
    #[tokio::test]
    async fn write_open_creates_db() {
        let (_guard, catalog) = scratch_catalog();
        let opened = DbWriteConnection::open(&catalog, TIMEOUT, TIMEOUT).await;
        assert!(opened.is_ok(), "write open should succeed");
        assert!(
            catalog.exists(),
            "write open should create the catalog db file"
        );
    }
}