use crate::drivers::Driver;
use rusqlite::OpenFlags;
use std::path::Path;
/// [`Driver`] implementation backed by the `rusqlite` crate.
///
/// Every connection it opens gets `journal_mode = wal`, `foreign_keys = ON` and
/// `synchronous = FULL` applied right after opening. Read connections are opened
/// with `SQLITE_OPEN_READ_ONLY`; write connections with read-write + create.
#[derive(Debug)]
pub struct RusqliteDriver;
/// Alias of [`crate::SyncConnectionPool`] specialised to [`RusqliteDriver`].
pub type RusqliteConnectionPool = crate::SyncConnectionPool<RusqliteDriver>;
/// Alias of [`crate::SyncPooledConnection`] specialised to [`RusqliteDriver`].
pub type RusqlitePooledConnection = crate::SyncPooledConnection<RusqliteDriver>;
/// Alias of [`crate::ConnectionPoolError`] carrying the driver's error type
/// (`rusqlite::Error` for [`RusqliteDriver`]).
pub type RusqliteConnectionPoolError =
crate::ConnectionPoolError<<RusqliteDriver as Driver>::Error>;
/// Alias of [`crate::Transaction`] specialised to [`RusqliteDriver`].
pub type RusqliteTransaction<'t> = crate::Transaction<'t, RusqliteDriver>;
/// Alias of [`crate::ReadTransaction`] specialised to [`RusqliteDriver`].
pub type RusqliteReadTransaction<'t> = crate::ReadTransaction<'t, RusqliteDriver>;
/// Alias of [`crate::AsyncConnectionPool`] specialised to [`RusqliteDriver`].
/// Only available with the `async` feature.
#[cfg(feature = "async")]
pub type RusqliteAsyncConnectionPool = crate::AsyncConnectionPool<RusqliteDriver>;
/// Alias of [`crate::AsyncPooledConnection`] specialised to [`RusqliteDriver`].
/// Only available with the `async` feature.
#[cfg(feature = "async")]
pub type RusqliteAsyncPooledConnection = crate::AsyncPooledConnection<RusqliteDriver>;
/// Alias of [`crate::AsyncTransaction`] specialised to [`RusqliteDriver`].
/// Only available with the `async` feature.
#[cfg(feature = "async")]
pub type RusqliteAsyncTransaction<'t> = crate::AsyncTransaction<'t, RusqliteDriver>;
/// Alias of [`crate::AsyncReadTransaction`] specialised to [`RusqliteDriver`].
/// Only available with the `async` feature.
#[cfg(feature = "async")]
pub type RusqliteAsyncReadTransaction<'t> = crate::AsyncReadTransaction<'t, RusqliteDriver>;
/// Alias of [`crate::AsyncConnectionError`] carrying the driver's error type
/// (`rusqlite::Error` for [`RusqliteDriver`]). Only available with the `async` feature.
#[cfg(feature = "async")]
pub type RusqliteAsyncConnectionError =
crate::AsyncConnectionError<<RusqliteDriver as Driver>::Error>;
impl RusqliteDriver {
    /// Opens the database at `path` with `open_flags`, then applies the pragmas
    /// shared by read and write connections.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while opening the file or while setting any
    /// of the pragmas; later pragmas are not attempted after a failure.
    fn new_connection(
        path: &Path,
        open_flags: OpenFlags,
    ) -> rusqlite::Result<rusqlite::Connection> {
        // (pragma name, value), applied in this exact order.
        const PRAGMAS: [(&str, &str); 3] = [
            ("journal_mode", "wal"),
            ("foreign_keys", "ON"),
            ("synchronous", "FULL"),
        ];

        let connection = rusqlite::Connection::open_with_flags(path, open_flags)?;
        for (pragma, value) in PRAGMAS {
            connection.pragma_update(None, pragma, value)?;
        }
        Ok(connection)
    }
}
impl Driver for RusqliteDriver {
    type Connection = rusqlite::Connection;
    type Error = rusqlite::Error;
    type ConnectionError = rusqlite::Error;
    type ConnectionRef<'c> = &'c rusqlite::Connection;

    /// Opens a read-only connection to the database at `path`.
    ///
    /// `SQLITE_OPEN_CREATE` is not passed, so this never creates the file.
    fn new_read_connection(path: &Path) -> Result<Self::Connection, Self::ConnectionError> {
        // Flags shared with write connections: no per-connection mutex, URI filenames allowed.
        let common = OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_URI;
        let mode = OpenFlags::SQLITE_OPEN_READ_ONLY;
        Self::new_connection(path, mode | common)
    }

    /// Opens a read-write connection to the database at `path`, creating the
    /// file if it does not exist yet.
    fn new_write_connection(path: &Path) -> Result<Self::Connection, Self::ConnectionError> {
        // Flags shared with read connections: no per-connection mutex, URI filenames allowed.
        let common = OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_URI;
        let mode = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE;
        Self::new_connection(path, mode | common)
    }

    /// Executes the caller-supplied transaction-opening statement(s) in `sql` as a batch.
    fn begin_transaction(connection: &mut Self::Connection, sql: &str) -> Result<(), Self::Error> {
        connection.execute_batch(sql)
    }

    /// Executes `COMMIT` on `connection`.
    fn commit_transaction(connection: &mut Self::Connection) -> Result<(), Self::Error> {
        connection.execute_batch("COMMIT")
    }

    /// Executes `ROLLBACK` on `connection`.
    fn rollback_transaction(connection: &mut Self::Connection) -> Result<(), Self::Error> {
        connection.execute_batch("ROLLBACK")
    }
}
#[cfg(test)]
mod tests {
    // Tests for `RusqliteDriver` against the sync pool, the async pool (`async`
    // feature) and the table watcher (`watcher` feature). Feature-specific items are
    // gated so the module compiles cleanly with any feature combination.

    use super::*;
    #[cfg(feature = "async")]
    use crate::async_adapter::{AsyncConnectionPool, AsyncRuntime};
    use crate::pool::{ConnectionPool, ConnectionPoolConfig};
    #[cfg(feature = "async")]
    use crate::AsyncConnectionError;
    use crate::{ReadTransaction, SyncConnectionPool, Transaction};
    // Only referenced from `#[cfg(feature = "watcher")]` code.
    #[cfg(feature = "watcher")]
    use sqlite_watcher::watcher::Watcher;
    use std::sync::Arc;
    use tempdir::TempDir;

    /// Creates table `foo` and inserts a single row with `id = 1`.
    fn create_table(tx: &mut Transaction<'_, RusqliteDriver>) -> rusqlite::Result<()> {
        tx.execute("CREATE TABLE foo (id INTEGER PRIMARY KEY)", [])?;
        tx.execute("INSERT INTO foo (id) VALUES (1)", [])?;
        Ok(())
    }

    /// Returns the `id` column of the first row of `foo`.
    fn read_value(conn: &rusqlite::Connection) -> rusqlite::Result<u32> {
        conn.query_row("SELECT id FROM foo", [], |r| r.get(0))
    }

    /// [`read_value`] adapted to a write transaction (only used by the async tests).
    #[cfg(feature = "async")]
    fn read_value_tx(tx: &mut Transaction<'_, RusqliteDriver>) -> rusqlite::Result<u32> {
        read_value(tx)
    }

    /// [`read_value`] adapted to a read transaction.
    fn read_value_read_tx(tx: &mut ReadTransaction<'_, RusqliteDriver>) -> rusqlite::Result<u32> {
        read_value(tx)
    }

    #[test]
    fn read_scope_query() {
        let (pool, _dir) = new_pool();
        let mut conn = pool.connection().unwrap();
        conn.transaction_closure(create_table).unwrap();
        let id: u32 = conn.read_transaction_closure(read_value_read_tx).unwrap();
        assert_eq!(id, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_scope_query_async() {
        let (pool, _dir) = new_pool_async().await;
        let mut conn = pool.connection_async::<TokioRuntime>().await.unwrap();
        conn.transaction_closure(create_table).await.unwrap();
        let id: u32 = conn
            .read_transaction_closure(read_value_read_tx)
            .await
            .unwrap();
        assert_eq!(id, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn transaction_async() {
        let (pool, _dir) = new_pool_async().await;
        let mut conn = pool.connection_async::<TokioRuntime>().await.unwrap();
        let id = conn
            .transaction_closure_async(async |tx| {
                tx.run(create_table).await?;
                tx.run(read_value_tx).await
            })
            .await
            .unwrap();
        assert_eq!(id, 1);
        let id: u32 = conn
            .read_transaction_closure_async(async |tx| tx.run(read_value_read_tx).await)
            .await
            .unwrap();
        assert_eq!(id, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn transaction_async_reports_actual_error_on_failure() {
        let (pool, _dir) = new_pool_async().await;
        let mut conn = pool.connection_async::<TokioRuntime>().await.unwrap();
        // `foo` was never created, so the query itself must fail and that SQLite
        // error (not a wrapper error) must surface to the caller.
        let err = conn
            .transaction_closure_async(async |tx| tx.run(read_value_tx).await)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            AsyncConnectionError::Connection(rusqlite::Error::SqliteFailure(_, _))
        ));
    }

    #[cfg(feature = "watcher")]
    #[test]
    fn watcher() {
        use std::collections::BTreeSet;

        // Observer that signals the paired receiver whenever `foo` changes.
        struct Observer {
            sender: flume::Sender<()>,
        }
        impl Observer {
            fn new() -> (Self, flume::Receiver<()>) {
                let (tx, rx) = flume::bounded(1);
                (Self { sender: tx }, rx)
            }
        }
        impl sqlite_watcher::watcher::TableObserver for Observer {
            fn tables(&self) -> Vec<String> {
                vec!["foo".to_owned()]
            }
            fn on_tables_changed(&self, _: &BTreeSet<String>) {
                // Never block the watcher's callback: a full slot already means a
                // pending signal, which is all the test waits for.
                let _ = self.sender.try_send(());
            }
        }

        let (pool, _dir) = new_pool();
        let (observer, rx) = Observer::new();
        let mut conn = pool.connection().unwrap();
        conn.transaction_closure(create_table).unwrap();
        pool.watcher().add_observer(Box::new(observer)).unwrap();
        conn.transaction_closure(|tx| tx.execute("INSERT INTO foo (id) VALUES (30)", []))
            .unwrap();
        // Bounded wait so a missing notification fails the test instead of hanging it.
        rx.recv_timeout(std::time::Duration::from_secs(10))
            .expect("observer was not notified of the insert into `foo`");
    }

    #[test]
    #[should_panic(
        expected = "SqliteFailure(Error { code: ReadOnly, extended_code: 8 }, Some(\"attempt to write a readonly database\"))"
    )]
    fn panic_on_write_in_read_scope() {
        let (pool, _dir) = new_pool();
        let mut conn = pool.connection().unwrap();
        conn.transaction_closure(|tx| {
            tx.execute("CREATE TABLE foo (id INTEGER PRIMARY KEY)", [])
                .unwrap();
            Ok::<_, rusqlite::Error>(())
        })
        .unwrap();
        let _: u32 = conn
            .read_transaction_closure(|conn| {
                conn.execute("INSERT INTO foo (id) VALUES (1)", [])?;
                conn.query_row("SELECT id FROM foo", [], |r| r.get(0))
            })
            .unwrap();
    }

    /// Builds a sync pool over `sqlite.db` in a fresh temporary directory.
    ///
    /// The returned [`TempDir`] must be kept alive for as long as the pool is used;
    /// dropping it deletes the directory.
    fn new_pool() -> (Arc<SyncConnectionPool<RusqliteDriver>>, TempDir) {
        let dir = tempdir::TempDir::new("rusqlite-test").unwrap();
        let pool = ConnectionPool::new(ConnectionPoolConfig {
            max_read_connection_count: 4,
            file_path: dir.path().join("sqlite.db"),
            connection_acquire_timeout: None,
            #[cfg(feature = "watcher")]
            watcher: Watcher::new().unwrap(),
        })
        .unwrap();
        (pool, dir)
    }

    /// [`AsyncRuntime`] backed by tokio's blocking thread pool.
    #[cfg(feature = "async")]
    struct TokioRuntime;

    #[cfg(feature = "async")]
    impl AsyncRuntime for TokioRuntime {
        type JoinError = tokio::task::JoinError;
        type JoinHandle<T: Send + 'static> = tokio::task::JoinHandle<T>;
        fn spawn_blocking<F, T>(closure: F) -> Self::JoinHandle<T>
        where
            F: FnOnce() -> T + Send + 'static,
            T: Send + 'static,
        {
            tokio::task::spawn_blocking(closure)
        }
    }

    /// Async counterpart of [`new_pool`]; the same [`TempDir`] lifetime rule applies.
    #[cfg(feature = "async")]
    async fn new_pool_async() -> (Arc<AsyncConnectionPool<RusqliteDriver>>, TempDir) {
        let dir = tempdir::TempDir::new("rusqlite-test").unwrap();
        let pool = AsyncConnectionPool::<RusqliteDriver>::new_async::<TokioRuntime>(
            ConnectionPoolConfig {
                max_read_connection_count: 4,
                file_path: dir.path().join("sqlite.db"),
                connection_acquire_timeout: None,
                #[cfg(feature = "watcher")]
                watcher: Watcher::new().unwrap(),
            },
        )
        .await
        .unwrap();
        (pool, dir)
    }
}