//! sqlite-rwc 0.4.0
//!
//! Reader/writer concurrency setup for SQLite 3.
use crate::drivers::{Driver, DriverMutConnectionDeref};
use parking_lot::{Condvar, Mutex, MutexGuard};
#[cfg(feature = "watcher")]
use sqlite_watcher::connection::State;
#[cfg(feature = "watcher")]
use sqlite_watcher::watcher::Watcher;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

/// Settings consumed by [`ConnectionPool::new`] when opening and managing connections.
pub struct ConnectionPoolConfig {
    /// Number of read connections opened up front and kept in the pool.
    pub max_read_connection_count: usize,
    /// Database file every connection (read and write) is opened against.
    pub file_path: PathBuf,
    /// Upper bound on how long [`ConnectionPool::connection`] waits for a free read
    /// connection; `None` means wait indefinitely.
    pub connection_acquire_timeout: Option<Duration>,
    /// Watcher the write connection syncs its tracking state with and publishes changes to.
    #[cfg(feature = "watcher")]
    pub watcher: Arc<Watcher>,
}

/// A pool made of one mutex-guarded write connection and a fixed set of read connections.
///
/// Read connections are checked out with [`ConnectionPool::connection`] and come back to the
/// pool when the returned [`PooledConnection`] is dropped. The write connection is only used
/// through [`Transaction`]s, which hold its lock for their whole lifetime.
pub struct ConnectionPool<T, A>
where
    T: Driver,
    A: ConnectionAdapter<T>,
{
    /// Idle read connections waiting to be checked out.
    read_connections: Mutex<Vec<A>>,
    /// Notified each time a read connection is handed back.
    reader_condvar: Condvar,
    /// The single write connection; its mutex serialises writers.
    write_connection: Mutex<WatchedConnection<T>>,
    /// Configuration the pool was built from.
    config: ConnectionPoolConfig,
}

/// Errors produced by [`ConnectionPool`] operations.
///
/// `E` is the driver error type of the operation that failed.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionPoolError<E> {
    /// The underlying driver reported an error.
    #[error(transparent)]
    Driver(#[from] E),
    /// No read connection became available before the configured acquire timeout.
    #[error("Failed to acquire connection in time")]
    ConnectionAcquireTimeout,
    /// The write connection could not be prepared for change tracking.
    #[error("Failed to setup connection watcher")]
    WatcherSetup,
    /// Any other boxed error.
    #[error(transparent)]
    Other(Box<dyn std::error::Error + Send + Sync>),
}

impl<T: Driver, A: ConnectionAdapter<T>> ConnectionPool<T, A> {
    /// Create a new connection pool with the given `config`.
    ///
    /// The write connection is created first, followed by one read connection per reader
    /// (`config.max_read_connection_count` of them).
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionPoolError::Driver`] if any connection could not be opened. With the
    /// `watcher` feature, returns [`ConnectionPoolError::WatcherSetup`] if the write connection
    /// could not be prepared for change tracking.
    pub fn new(
        config: ConnectionPoolConfig,
    ) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
        let watched_connection = T::new_write_connection(&config.file_path)
            .inspect_err(|e| tracing::error!("Failed to create write connection: {e:?}"))?;
        #[cfg(feature = "watcher")]
        let watched_connection = WatchedConnection::new(watched_connection).map_err(|e| {
            tracing::error!("Failed to setup connection watcher: {e:?}");
            ConnectionPoolError::WatcherSetup
        })?;
        #[cfg(not(feature = "watcher"))]
        let watched_connection = WatchedConnection::new(watched_connection);

        let mut read_connections = Vec::with_capacity(config.max_read_connection_count);
        for _ in 0..config.max_read_connection_count {
            read_connections.push(A::from_driver_connection(
                T::new_read_connection(&config.file_path)
                    .inspect_err(|e| tracing::error!("Failed to create read connection: {e:?}"))?,
            ));
        }
        Ok(Arc::new(Self {
            write_connection: Mutex::new(watched_connection),
            read_connections: Mutex::new(read_connections),
            reader_condvar: Condvar::new(),
            config,
        }))
    }

    /// Retrieve a read connection from the pool.
    ///
    /// If all the connections are currently in use, we wait until one is returned to the
    /// pool. `ConnectionPoolConfig.connection_acquire_timeout` bounds the *total* wait,
    /// measured from the moment this method is called. If it has no value (or is too large to
    /// be represented as a deadline), this method blocks indefinitely.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionPoolError::ConnectionAcquireTimeout`] if no connection became
    /// available before the deadline.
    pub fn connection(
        self: &Arc<Self>,
    ) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
        // Fix the deadline once. Waking up only to find that another caller already took the
        // returned connection must not restart the full timeout, otherwise the configured bound
        // can be exceeded without limit. A timeout that overflows `Instant` means "no deadline".
        let deadline = self
            .config
            .connection_acquire_timeout
            .and_then(|timeout| std::time::Instant::now().checked_add(timeout));

        let mut rd_connections = self.read_connections.lock();
        loop {
            if let Some(rd_connection) = rd_connections.pop() {
                return Ok(PooledConnection::new(Arc::clone(self), rd_connection));
            }
            match deadline {
                Some(deadline) => {
                    if self
                        .reader_condvar
                        .wait_until(&mut rd_connections, deadline)
                        .timed_out()
                    {
                        // A connection may have been returned right as the deadline passed.
                        return rd_connections
                            .pop()
                            .map(|conn| PooledConnection::new(Arc::clone(self), conn))
                            .ok_or(ConnectionPoolError::ConnectionAcquireTimeout);
                    }
                }
                None => self.reader_condvar.wait(&mut rd_connections),
            }
        }
    }

    /// Run `closure` inside a write transaction.
    ///
    /// The transaction is committed when `closure` returns `Ok` and rolled back otherwise.
    /// The closure's result is returned, unless committing or rolling back fails; then that
    /// driver error is returned instead.
    pub(crate) fn transaction_closure<F, R, E>(&self, closure: F) -> Result<R, E>
    where
        F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E>,
        E: From<T::Error>,
    {
        let mut tx = self.transaction()?;
        let result = closure(&mut tx);
        if result.is_ok() {
            tx.commit()?;
        } else {
            tx.rollback()?;
        }
        result
    }

    /// Lock the write connection and begin a write transaction on it.
    ///
    /// Blocks until the write connection's mutex is free. The lock is held until the returned
    /// [`Transaction`] is committed, rolled back or dropped.
    pub(crate) fn transaction(&self) -> Result<Transaction<'_, T>, T::Error> {
        let writer = self.write_connection.lock();
        Transaction::new(
            writer,
            #[cfg(feature = "watcher")]
            &self.config.watcher,
        )
    }

    /// Put `conn` back into the idle set and wake one caller waiting in `connection`.
    fn return_to_pool(&self, conn: A) {
        let mut read_connections = self.read_connections.lock();
        read_connections.push(conn);
        // Release the lock before notifying so the woken waiter can grab it immediately.
        drop(read_connections);
        self.reader_condvar.notify_one();
    }

    /// The watcher this pool's write transactions sync with and publish changes to.
    #[cfg(feature = "watcher")]
    pub fn watcher(&self) -> &Arc<Watcher> {
        &self.config.watcher
    }
}

/// Converts a raw driver read connection into the type the pool stores and hands out.
pub trait ConnectionAdapter<T>
where
    T: Driver,
{
    /// Wrap a freshly opened driver read connection.
    fn from_driver_connection(connection: T::Connection) -> Self;
}

/// A read connection checked out of a [`ConnectionPool`].
///
/// Dropping this value hands the connection back to the pool.
pub struct PooledConnection<T, A>
where
    T: Driver,
    A: ConnectionAdapter<T>,
{
    /// Pool the connection goes back to on drop.
    pub(crate) pool: Arc<ConnectionPool<T, A>>,
    /// The checked-out connection; cleared only while being handed back in `Drop`.
    conn: Option<A>,
}

impl<T, A> Drop for PooledConnection<T, A>
where
    T: Driver,
    A: ConnectionAdapter<T>,
{
    /// Hand the read connection back to its pool so a waiting caller can use it.
    fn drop(&mut self) {
        match self.conn.take() {
            Some(conn) => self.pool.return_to_pool(conn),
            None => panic!("Connection should be set"),
        }
    }
}

impl<T, A> PooledConnection<T, A>
where
    T: Driver,
    A: ConnectionAdapter<T>,
{
    /// Wrap `connection`, which was just checked out of `pool`.
    fn new(pool: Arc<ConnectionPool<T, A>>, connection: A) -> Self {
        let conn = Some(connection);
        Self { pool, conn }
    }

    /// Shared access to the checked-out connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been handed back; the field is only cleared
    /// in `Drop`.
    pub(crate) fn connection(&self) -> &A {
        let Some(conn) = self.conn.as_ref() else {
            panic!("Connection should be set");
        };
        conn
    }

    /// Exclusive access to the checked-out connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been handed back; the field is only cleared
    /// in `Drop`.
    pub(crate) fn connection_mut(&mut self) -> &mut A {
        let Some(conn) = self.conn.as_mut() else {
            panic!("Connection should be set");
        };
        conn
    }
}

/// The pool's write connection and, with the `watcher` feature, its change-tracking state.
struct WatchedConnection<T: Driver> {
    /// Underlying driver connection.
    connection: T::Connection,
    /// Tracking state used to sync with and publish to the watcher.
    #[cfg(feature = "watcher")]
    state: State,
}

#[cfg(feature = "watcher")]
impl<T: Driver> WatchedConnection<T> {
    /// Enable change tracking on `connection` and wrap it with a fresh tracking state.
    ///
    /// Runs `State::set_pragmas()` and then `State::start_tracking()` on the connection.
    ///
    /// # Errors
    ///
    /// Returns the driver error if either setup statement fails.
    fn new(mut connection: T::Connection) -> Result<Self, <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;

        State::set_pragmas().execute_mut(&mut connection)?;
        State::start_tracking().execute_mut(&mut connection)?;
        let state = State::new();
        Ok(WatchedConnection { connection, state })
    }

    /// Sync the tracking state with `watcher`. A statement is executed only when
    /// `State::sync_tables` returns one.
    fn sync_changes(&mut self, watcher: &Watcher) -> Result<(), T::Error> {
        use sqlite_watcher::statement::Statement;

        let Some(stmt) = self.state.sync_tables(watcher) else {
            return Ok(());
        };
        stmt.execute_mut(&mut self.connection)?;
        Ok(())
    }

    /// Publish tracked changes to `watcher`.
    ///
    /// This is best effort: failures are logged and not propagated.
    fn publish_changes(&mut self, watcher: &Watcher) {
        use sqlite_watcher::statement::Statement;

        let statement = self.state.publish_changes(watcher);
        if let Err(e) = statement.execute_mut(&mut self.connection) {
            tracing::error!("Failed to publish updates to watcher: {e:?}");
        }
    }
}

#[cfg(not(feature = "watcher"))]
impl<T> WatchedConnection<T>
where
    T: Driver,
{
    fn new(connection: T::Connection) -> Self {
        Self { connection }
    }
}

/// A write transaction on the pool's write connection.
///
/// Some drivers (e.g. rusqlite) have their own transaction type, but it is consumed by
/// commit/rollback, which leaves no place to run extra code afterwards. This wrapper keeps the
/// write connection locked for its whole lifetime. It runs the extra steps itself: with the
/// `watcher` feature, it publishes tracked changes after a successful commit.
pub struct Transaction<'c, T>
where
    T: Driver,
{
    /// Lock guard on the write connection. It is kept in `ManuallyDrop` so that
    /// commit/rollback/`Drop` decide exactly when the lock is released.
    conn: ManuallyDrop<MutexGuard<'c, WatchedConnection<T>>>,
    /// Watcher that changes are published to after a commit.
    #[cfg(feature = "watcher")]
    watcher: &'c Watcher,
}

impl<'c, T: Driver> Transaction<'c, T> {
    /// Begin a write transaction on the already-locked write connection `conn`.
    ///
    /// With the `watcher` feature, `sync_changes` is run against `watcher` first. The
    /// transaction is then opened with `BEGIN IMMEDIATE`. If either step fails, `conn` is
    /// dropped here as an ordinary guard, which releases the write lock.
    fn new(
        mut conn: MutexGuard<'c, WatchedConnection<T>>,
        #[cfg(feature = "watcher")] watcher: &'c Watcher,
    ) -> Result<Self, <T as Driver>::Error> {
        #[cfg(feature = "watcher")]
        conn.sync_changes(watcher)?;
        T::begin_transaction(&mut conn.connection, "BEGIN IMMEDIATE")?;
        Ok(Self {
            conn: ManuallyDrop::new(conn),
            #[cfg(feature = "watcher")]
            watcher,
        })
    }

    /// Commit the transaction
    ///
    /// On success, and with the `watcher` feature, tracked changes are published to the
    /// watcher before the write lock is released. If the commit fails, `self` is dropped
    /// normally on the early return. `Drop` then attempts a rollback and releases the lock.
    ///
    /// # Errors
    ///
    /// Returns error if the commit failed.
    #[allow(clippy::missing_panics_doc)]
    pub fn commit(mut self) -> Result<(), <T as Driver>::Error> {
        T::commit_transaction(&mut self.conn.connection)?;
        #[cfg(feature = "watcher")]
        self.conn.publish_changes(self.watcher);
        // SAFETY: `self.conn` has not been dropped yet and is not used after this point.
        // The `std::mem::forget(self)` below stops `Drop for Transaction` from dropping it again.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
        // Skip `Drop`: the transaction is committed and the guard has already been released.
        std::mem::forget(self);

        Ok(())
    }

    /// Rollback the transaction
    ///
    /// If the rollback fails, `self` is dropped normally on the early return. `Drop` then
    /// tries the rollback once more (logging any failure) and releases the lock.
    ///
    /// # Errors
    ///
    /// Returns errors if the operation failed.
    #[allow(clippy::missing_panics_doc)]
    pub fn rollback(mut self) -> Result<(), <T as Driver>::Error> {
        T::rollback_transaction(&mut self.conn.connection)?;
        // SAFETY: `self.conn` has not been dropped yet and is not used after this point.
        // The `std::mem::forget(self)` below stops `Drop for Transaction` from dropping it again.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
        // Skip `Drop`: the transaction is rolled back and the guard has already been released.
        std::mem::forget(self);
        Ok(())
    }
}

impl<T: Driver> Drop for Transaction<'_, T> {
    /// Roll back a transaction that was not finished explicitly, then release the write lock.
    ///
    /// A successful `commit`/`rollback` forgets the transaction, so this runs only for
    /// abandoned transactions or after a failed `commit`/`rollback`. Rollback errors are
    /// logged rather than propagated.
    fn drop(&mut self) {
        if let Err(e) = T::rollback_transaction(&mut self.conn.connection) {
            tracing::error!("Failed to rollback transaction: {e:?}");
        }
        // SAFETY: `self.conn` has not been dropped yet. The only other places that drop it
        // (`commit`/`rollback`) forget the transaction right afterwards, so this destructor never
        // runs on an already-dropped guard. Nothing accesses `self.conn` after this.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
    }
}

impl<T: Driver> Deref for Transaction<'_, T> {
    type Target = T::Connection;

    fn deref(&self) -> &Self::Target {
        &self.conn.connection
    }
}

impl<T: DriverMutConnectionDeref> DerefMut for Transaction<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.conn.connection
    }
}

/// An explicit read transaction. It is opened with `BEGIN` and finished through the driver's
/// commit or rollback.
///
/// If neither [`ReadTransaction::commit`] nor [`ReadTransaction::rollback`] completes, the
/// transaction is rolled back on drop.
pub struct ReadTransaction<'c, T>
where
    T: Driver,
{
    /// Connection the transaction runs on, mutably borrowed for the transaction's lifetime.
    conn: &'c mut T::Connection,
}

impl<'c, T: Driver> ReadTransaction<'c, T> {
    pub(crate) fn new(conn: &'c mut T::Connection) -> Result<Self, <T as Driver>::Error> {
        T::begin_transaction(conn, "BEGIN")?;
        Ok(Self { conn })
    }

    /// Commit the transaction
    ///
    /// # Errors
    ///
    /// Returns error if the commit failed.
    #[allow(clippy::missing_panics_doc)]
    pub fn commit(self) -> Result<(), <T as Driver>::Error> {
        T::commit_transaction(self.conn)?;
        std::mem::forget(self);

        Ok(())
    }

    /// Rollback the transaction
    ///
    /// # Errors
    ///
    /// Returns errors if the operation failed.
    #[allow(clippy::missing_panics_doc)]
    pub fn rollback(self) -> Result<(), <T as Driver>::Error> {
        T::rollback_transaction(self.conn)?;
        std::mem::forget(self);
        Ok(())
    }

    pub(crate) fn scoped<F, R, E>(conn: &mut T::Connection, closure: F) -> Result<R, E>
    where
        F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E>,
        E: From<T::Error>,
    {
        let mut tx = ReadTransaction::new(conn)?;
        let r = closure(&mut tx);
        tx.commit()?;
        r
    }
}

impl<T> Drop for ReadTransaction<'_, T>
where
    T: Driver,
{
    /// Roll back a read transaction that was not explicitly committed or rolled back.
    ///
    /// `drop` cannot return an error, so a failed rollback is logged.
    fn drop(&mut self) {
        match T::rollback_transaction(self.conn) {
            Ok(_) => {}
            Err(e) => {
                tracing::error!("Failed to rollback transaction: {e:?}");
            }
        }
    }
}
impl<T> Deref for ReadTransaction<'_, T>
where
    T: Driver,
{
    type Target = T::Connection;

    /// Borrow the connection this read transaction runs on.
    fn deref(&self) -> &Self::Target {
        &*self.conn
    }
}

impl<T> DerefMut for ReadTransaction<'_, T>
where
    T: DriverMutConnectionDeref,
{
    /// Mutably borrow the connection this read transaction runs on.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.conn
    }
}