use rusqlite::{Connection, Transaction, TransactionBehavior};
use tokio_rusqlite::Connection as AsyncConnection;
/// Decision returned by a `SqliteConnection::write_flow` closure: how the transaction
/// should end, plus the value handed back to the caller in either case.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TxOutcome<T> {
    /// Commit the transaction, then return the carried value.
    Commit(T),
    /// Roll the transaction back, then return the carried value.
    Rollback(T),
}
/// Busy timeout for writable connections, in milliseconds (15 seconds).
///
/// Feeds the connection's busy handler, the `PRAGMA busy_timeout` batch, and the overall
/// retry deadline used when switching the journal mode to WAL.
pub(crate) const BUSY_TIMEOUT_MS: u32 = 15 * 1_000;
/// Builds the pragma batch that the writable constructors execute after opening.
///
/// The batch sets `busy_timeout` (to `BUSY_TIMEOUT_MS`), `synchronous=NORMAL` and
/// `foreign_keys=ON`. The statements are joined by their trailing `;` only, with no
/// whitespace between them.
fn open_pragmas() -> String {
    let timeout_ms = BUSY_TIMEOUT_MS;
    format!("PRAGMA busy_timeout={timeout_ms};PRAGMA synchronous=NORMAL;PRAGMA foreign_keys=ON;")
}
/// Switches `c` to WAL journal mode, retrying while `is_busy` reports the failure as
/// busy/locked.
///
/// The pause between attempts starts at 1 ms, doubles each time, and is capped at 50 ms.
/// Retrying stops once `BUSY_TIMEOUT_MS` has elapsed since this call started; the last
/// error is then returned. Any error that is not busy/locked is returned immediately.
///
/// Sleeps with `std::thread::sleep`, so the calling thread is blocked while waiting.
///
/// NOTE(review): `pragma_update` does not surface the journal mode SQLite reports back;
/// if WAL must be guaranteed, consider `pragma_update_and_check` and compare the result.
fn set_wal_journal_mode(c: &Connection) -> rusqlite::Result<()> {
    use std::time::{Duration, Instant};

    const MAX_PAUSE: Duration = Duration::from_millis(50);

    let give_up_at = Instant::now() + Duration::from_millis(u64::from(BUSY_TIMEOUT_MS));
    let mut pause = Duration::from_millis(1);
    loop {
        let failure = match c.pragma_update(None, "journal_mode", "WAL") {
            Ok(()) => return Ok(()),
            Err(failure) => failure,
        };
        // Give up on non-contention errors, or once the deadline has passed.
        if !is_busy(&failure) || Instant::now() >= give_up_at {
            return Err(failure);
        }
        std::thread::sleep(pause);
        pause = (pause * 2).min(MAX_PAUSE);
    }
}
fn is_busy(err: &rusqlite::Error) -> bool {
matches!(
err,
rusqlite::Error::SqliteFailure(e, _)
if e.code == rusqlite::ErrorCode::DatabaseBusy
|| e.code == rusqlite::ErrorCode::DatabaseLocked
)
}
#[derive(Clone)]
pub(crate) struct SqliteConnection {
inner: AsyncConnection,
}
impl SqliteConnection {
pub(crate) async fn open(path: &std::path::Path) -> tokio_rusqlite::Result<Self> {
let inner = AsyncConnection::open(path).await?;
let pragmas = open_pragmas();
inner
.call(move |c| {
c.busy_timeout(std::time::Duration::from_millis(BUSY_TIMEOUT_MS as u64))?;
set_wal_journal_mode(c)?;
c.execute_batch(&pragmas)?;
Ok(())
})
.await?;
Ok(Self { inner })
}
pub(crate) async fn open_in_memory() -> tokio_rusqlite::Result<Self> {
let inner = AsyncConnection::open_in_memory().await?;
let pragmas = open_pragmas();
inner
.call(move |c| {
c.busy_timeout(std::time::Duration::from_millis(BUSY_TIMEOUT_MS as u64))?;
c.execute_batch(&pragmas)?;
Ok(())
})
.await?;
Ok(Self { inner })
}
pub(crate) async fn open_readonly(path: &std::path::Path) -> tokio_rusqlite::Result<Self> {
let path = path.to_path_buf();
let inner = AsyncConnection::open_with_flags(
path,
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
)
.await?;
inner
.call(move |c| {
c.busy_timeout(std::time::Duration::from_secs(1))?;
c.execute_batch("PRAGMA cache_size = -500;")?;
Ok(())
})
.await?;
Ok(Self { inner })
}
pub(crate) async fn call<T, F>(&self, f: F) -> rusqlite::Result<T>
where
T: Send + 'static,
F: FnOnce(&mut Connection) -> rusqlite::Result<T> + Send + 'static,
{
flatten(self.inner.call(move |c| Ok(f(c))).await)
}
pub(crate) async fn write<T, F>(&self, f: F) -> rusqlite::Result<T>
where
T: Send + 'static,
F: FnOnce(&Transaction<'_>) -> rusqlite::Result<T> + Send + 'static,
{
flatten(
self.inner
.call(move |c| {
let tx = c.transaction_with_behavior(TransactionBehavior::Immediate)?;
let value = f(&tx)?;
tx.commit()?;
Ok(Ok(value))
})
.await,
)
}
pub(crate) async fn write_flow<T, F>(&self, f: F) -> rusqlite::Result<T>
where
T: Send + 'static,
F: FnOnce(&Transaction<'_>) -> rusqlite::Result<TxOutcome<T>> + Send + 'static,
{
flatten(
self.inner
.call(move |c| {
let tx = c.transaction_with_behavior(TransactionBehavior::Immediate)?;
let outcome = f(&tx)?;
let value = match outcome {
TxOutcome::Commit(value) => {
tx.commit()?;
value
}
TxOutcome::Rollback(value) => {
tx.rollback()?;
value
}
};
Ok(Ok(value))
})
.await,
)
}
}
fn flatten<T>(result: tokio_rusqlite::Result<rusqlite::Result<T>>) -> rusqlite::Result<T> {
match result {
Ok(inner) => inner,
Err(tokio_rusqlite::Error::Error(err)) => Err(err),
Err(other) => Err(rusqlite::Error::ToSqlConversionFailure(Box::new(
std::io::Error::other(other.to_string()),
))),
}
}