//! sqlite-rwc 0.4.0
//!
//! Reader/writer concurrency setup for SQLite 3: the `rusqlite`-backed driver.
use crate::drivers::Driver;
use rusqlite::OpenFlags;
use std::path::Path;

/// [`Driver`] implementation backed by the `rusqlite` crate.
///
/// Connections handed out by this driver are [`rusqlite::Connection`]s and
/// every failure is reported as a [`rusqlite::Error`].
#[derive(Debug)]
pub struct RusqliteDriver;

// ---- Synchronous aliases -------------------------------------------------

/// Synchronous connection pool specialised to [`RusqliteDriver`].
pub type RusqliteConnectionPool = crate::SyncConnectionPool<RusqliteDriver>;
/// Synchronous pooled connection specialised to [`RusqliteDriver`].
pub type RusqlitePooledConnection = crate::SyncPooledConnection<RusqliteDriver>;
/// Pool error type whose inner error is the driver's [`Driver::Error`].
pub type RusqliteConnectionPoolError =
    crate::ConnectionPoolError<<RusqliteDriver as Driver>::Error>;
/// Read/write transaction specialised to [`RusqliteDriver`].
pub type RusqliteTransaction<'t> = crate::Transaction<'t, RusqliteDriver>;
/// Read-only transaction specialised to [`RusqliteDriver`].
pub type RusqliteReadTransaction<'t> = crate::ReadTransaction<'t, RusqliteDriver>;

// ---- Asynchronous aliases (only with the `async` feature) ----------------

/// Asynchronous connection pool specialised to [`RusqliteDriver`].
#[cfg(feature = "async")]
pub type RusqliteAsyncConnectionPool = crate::AsyncConnectionPool<RusqliteDriver>;
/// Asynchronous pooled connection specialised to [`RusqliteDriver`].
#[cfg(feature = "async")]
pub type RusqliteAsyncPooledConnection = crate::AsyncPooledConnection<RusqliteDriver>;
/// Asynchronous read/write transaction specialised to [`RusqliteDriver`].
#[cfg(feature = "async")]
pub type RusqliteAsyncTransaction<'t> = crate::AsyncTransaction<'t, RusqliteDriver>;
/// Asynchronous read-only transaction specialised to [`RusqliteDriver`].
#[cfg(feature = "async")]
pub type RusqliteAsyncReadTransaction<'t> = crate::AsyncReadTransaction<'t, RusqliteDriver>;
/// Async connection error type whose inner error is the driver's [`Driver::Error`].
#[cfg(feature = "async")]
pub type RusqliteAsyncConnectionError =
    crate::AsyncConnectionError<<RusqliteDriver as Driver>::Error>;

impl RusqliteDriver {
    /// `(pragma, value)` pairs applied, in this order, to every freshly opened
    /// connection: WAL journaling, foreign-key enforcement and `FULL`
    /// synchronous mode.
    const CONNECTION_PRAGMAS: [(&'static str, &'static str); 3] = [
        ("journal_mode", "wal"),
        ("foreign_keys", "ON"),
        ("synchronous", "FULL"),
    ];

    /// Opens the database at `path` with `open_flags`, then applies
    /// [`Self::CONNECTION_PRAGMAS`].
    ///
    /// # Errors
    /// Returns the first error raised while opening the database or while
    /// setting one of the pragmas; later pragmas are not attempted.
    fn new_connection(
        path: &Path,
        open_flags: OpenFlags,
    ) -> rusqlite::Result<rusqlite::Connection> {
        let connection = rusqlite::Connection::open_with_flags(path, open_flags)?;
        for (pragma, value) in Self::CONNECTION_PRAGMAS {
            connection.pragma_update(None, pragma, value)?;
        }
        Ok(connection)
    }
}
impl Driver for RusqliteDriver {
    type Connection = rusqlite::Connection;
    type Error = rusqlite::Error;
    type ConnectionError = rusqlite::Error;
    type ConnectionRef<'c> = &'c rusqlite::Connection;

    /// Opens a read-only connection to `path`.
    ///
    /// The connection is opened without SQLite's internal mutex and with URI
    /// filenames enabled.
    fn new_read_connection(path: &Path) -> Result<Self::Connection, Self::ConnectionError> {
        let common = OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_URI;
        Self::new_connection(path, OpenFlags::SQLITE_OPEN_READ_ONLY | common)
    }

    /// Opens a read/write connection to `path`, creating the file if needed.
    ///
    /// The connection is opened without SQLite's internal mutex and with URI
    /// filenames enabled.
    fn new_write_connection(path: &Path) -> Result<Self::Connection, Self::ConnectionError> {
        let common = OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_URI;
        let access = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE;
        Self::new_connection(path, access | common)
    }

    /// Starts a transaction by executing the caller-supplied `sql` verbatim.
    fn begin_transaction(connection: &mut Self::Connection, sql: &str) -> Result<(), Self::Error> {
        connection.execute_batch(sql)
    }

    /// Ends the current transaction with `COMMIT`.
    fn commit_transaction(connection: &mut Self::Connection) -> Result<(), Self::Error> {
        let statement = "COMMIT";
        connection.execute_batch(statement)
    }

    /// Ends the current transaction with `ROLLBACK`.
    fn rollback_transaction(connection: &mut Self::Connection) -> Result<(), Self::Error> {
        let statement = "ROLLBACK";
        connection.execute_batch(statement)
    }
}

#[cfg(test)]
mod tests {
    // Each test builds a fresh pool over `sqlite.db` inside its own temporary
    // directory. The `TempDir` is returned alongside the pool so the directory
    // is not removed while the test is still using it.
    use super::*;
    use crate::pool::{ConnectionPool, ConnectionPoolConfig};
    use crate::{ReadTransaction, SyncConnectionPool, Transaction};
    use std::sync::Arc;
    use tempdir::TempDir;

    // Async-only items are gated on the `async` feature, matching the crate's
    // async type aliases, so the synchronous tests still build without it.
    #[cfg(feature = "async")]
    use crate::async_adapter::{AsyncConnectionPool, AsyncRuntime};
    #[cfg(feature = "async")]
    use crate::AsyncConnectionError;

    // Only referenced by `#[cfg(feature = "watcher")]` code.
    #[cfg(feature = "watcher")]
    use sqlite_watcher::watcher::Watcher;
    #[cfg(feature = "watcher")]
    use std::collections::BTreeSet;

    /// Creates table `foo` and inserts a single row with `id = 1`.
    fn create_table(tx: &mut Transaction<'_, RusqliteDriver>) -> rusqlite::Result<()> {
        tx.execute("CREATE TABLE foo (id INTEGER PRIMARY KEY)", [])?;
        tx.execute("INSERT INTO foo (id) VALUES (1)", [])?;
        Ok(())
    }

    /// Returns `id` from the first row of `foo`.
    fn read_value(conn: &rusqlite::Connection) -> rusqlite::Result<u32> {
        conn.query_row("SELECT id FROM foo", [], |r| r.get(0))
    }

    /// [`read_value`] adapted to a read/write transaction (used by async tests only).
    #[cfg(feature = "async")]
    fn read_value_tx(tx: &mut Transaction<'_, RusqliteDriver>) -> rusqlite::Result<u32> {
        read_value(tx)
    }

    /// [`read_value`] adapted to a read-only transaction.
    fn read_value_read_tx(tx: &mut ReadTransaction<'_, RusqliteDriver>) -> rusqlite::Result<u32> {
        read_value(tx)
    }

    #[test]
    fn read_scope_query() {
        let (pool, _dir) = new_pool();
        let mut conn = pool.connection().unwrap();
        conn.transaction_closure(create_table).unwrap();

        let id: u32 = conn.read_transaction_closure(read_value_read_tx).unwrap();
        assert_eq!(id, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_scope_query_async() {
        let (pool, _dir) = new_pool_async().await;
        let mut conn = pool.connection_async::<TokioRuntime>().await.unwrap();
        conn.transaction_closure(create_table).await.unwrap();

        let id: u32 = conn
            .read_transaction_closure(read_value_read_tx)
            .await
            .unwrap();
        assert_eq!(id, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn transaction_async() {
        let (pool, _dir) = new_pool_async().await;
        let mut conn = pool.connection_async::<TokioRuntime>().await.unwrap();
        let id = conn
            .transaction_closure_async(async |tx| {
                tx.run(create_table).await?;
                tx.run(read_value_tx).await
            })
            .await
            .unwrap();

        assert_eq!(id, 1);

        let id: u32 = conn
            .read_transaction_closure_async(async |tx| tx.run(read_value_read_tx).await)
            .await
            .unwrap();
        assert_eq!(id, 1);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn transaction_async_reports_actual_error_on_failure() {
        let (pool, _dir) = new_pool_async().await;
        let mut conn = pool.connection_async::<TokioRuntime>().await.unwrap();
        // `foo` was never created, so the query must surface SQLite's own error.
        let err = conn
            .transaction_closure_async(async |tx| tx.run(read_value_tx).await)
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            AsyncConnectionError::Connection(rusqlite::Error::SqliteFailure(_, _))
        ));
    }

    #[cfg(feature = "watcher")]
    #[test]
    fn watcher() {
        use std::time::Duration;

        struct Observer {
            sender: flume::Sender<()>,
        }

        impl Observer {
            fn new() -> (Self, flume::Receiver<()>) {
                let (tx, rx) = flume::bounded(1);
                (Self { sender: tx }, rx)
            }
        }
        impl sqlite_watcher::watcher::TableObserver for Observer {
            fn tables(&self) -> Vec<String> {
                vec!["foo".to_owned()]
            }

            fn on_tables_changed(&self, _: &BTreeSet<String>) {
                // `try_send` rather than `send`: the channel holds one signal,
                // and a blocking send would stall whichever thread delivers the
                // notification if a second change arrives before the test
                // drains the first. One pending signal is all the test needs.
                let _ = self.sender.try_send(());
            }
        }

        let (pool, _dir) = new_pool();
        let (observer, rx) = Observer::new();
        let mut conn = pool.connection().unwrap();
        conn.transaction_closure(create_table).unwrap();
        pool.watcher().add_observer(Box::new(observer)).unwrap();

        conn.transaction_closure(|tx| tx.execute("INSERT INTO foo (id) VALUES (30)", []))
            .unwrap();

        // Bounded wait: a missed notification fails the test instead of hanging it.
        rx.recv_timeout(Duration::from_secs(10))
            .expect("observer was not notified of the change to `foo`");
    }

    #[test]
    #[should_panic(
        expected = "SqliteFailure(Error { code: ReadOnly, extended_code: 8 }, Some(\"attempt to write a readonly database\"))"
    )]
    fn panic_on_write_in_read_scope() {
        let (pool, _dir) = new_pool();
        let mut conn = pool.connection().unwrap();

        conn.transaction_closure(|tx| {
            tx.execute("CREATE TABLE foo (id INTEGER PRIMARY KEY)", [])?;
            Ok::<_, rusqlite::Error>(())
        })
        .unwrap();

        // Writing through a read transaction must fail with SQLITE_READONLY;
        // the `unwrap` below turns that error into the expected panic.
        let _: u32 = conn
            .read_transaction_closure(|conn| {
                conn.execute("INSERT INTO foo (id) VALUES (1)", [])?;
                conn.query_row("SELECT id FROM foo", [], |r| r.get(0))
            })
            .unwrap();
    }

    /// Creates a synchronous pool over a database file in a new temporary directory.
    fn new_pool() -> (Arc<SyncConnectionPool<RusqliteDriver>>, TempDir) {
        let dir = TempDir::new("rusqlite-test").unwrap();
        let pool = ConnectionPool::new(ConnectionPoolConfig {
            max_read_connection_count: 4,
            file_path: dir.path().join("sqlite.db"),
            connection_acquire_timeout: None,
            #[cfg(feature = "watcher")]
            watcher: Watcher::new().unwrap(),
        })
        .unwrap();
        (pool, dir)
    }

    /// Tokio-backed [`AsyncRuntime`] used to drive the async adapter in tests.
    #[cfg(feature = "async")]
    struct TokioRuntime;

    #[cfg(feature = "async")]
    impl AsyncRuntime for TokioRuntime {
        type JoinError = tokio::task::JoinError;
        type JoinHandle<T: Send + 'static> = tokio::task::JoinHandle<T>;

        fn spawn_blocking<F, T>(closure: F) -> Self::JoinHandle<T>
        where
            F: FnOnce() -> T + Send + 'static,
            T: Send + 'static,
        {
            tokio::task::spawn_blocking(closure)
        }
    }

    /// Creates an async pool over a database file in a new temporary directory.
    #[cfg(feature = "async")]
    async fn new_pool_async() -> (Arc<AsyncConnectionPool<RusqliteDriver>>, TempDir) {
        let dir = TempDir::new("rusqlite-test").unwrap();
        let pool = AsyncConnectionPool::<RusqliteDriver>::new_async::<TokioRuntime>(
            ConnectionPoolConfig {
                max_read_connection_count: 4,
                file_path: dir.path().join("sqlite.db"),
                connection_acquire_timeout: None,
                #[cfg(feature = "watcher")]
                watcher: Watcher::new().unwrap(),
            },
        )
        .await
        .unwrap();
        (pool, dir)
    }
}