use crate::drivers::{Driver, DriverMutConnectionDeref};
use parking_lot::{Condvar, Mutex, MutexGuard};
#[cfg(feature = "watcher")]
use sqlite_watcher::connection::State;
#[cfg(feature = "watcher")]
use sqlite_watcher::watcher::Watcher;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Settings used by `ConnectionPool::new` to open and manage connections.
pub struct ConnectionPoolConfig {
    /// Number of read connections opened up front by `ConnectionPool::new`.
    pub max_read_connection_count: usize,
    /// Database file on which both the write connection and the read connections are opened.
    pub file_path: PathBuf,
    /// Timeout used while `ConnectionPool::connection` waits for a free read connection;
    /// `None` waits indefinitely.
    pub connection_acquire_timeout: Option<Duration>,
    /// Watcher that write transactions sync with before `BEGIN` and publish changes to
    /// after a successful commit.
    #[cfg(feature = "watcher")]
    pub watcher: Arc<Watcher>,
}
/// Pool holding a fixed set of read connections and a single, mutex-guarded write
/// connection to the same database file.
pub struct ConnectionPool<T: Driver, A: ConnectionAdapter<T>> {
    /// Idle read connections; popped by `connection()` and pushed back when a
    /// `PooledConnection` is dropped.
    read_connections: Mutex<Vec<A>>,
    /// Notified (one waiter at a time) whenever a read connection is returned.
    reader_condvar: Condvar,
    /// The only write connection. A `Transaction` holds this lock for its whole
    /// lifetime, so write transactions are serialised.
    write_connection: Mutex<WatchedConnection<T>>,
    config: ConnectionPoolConfig,
}
/// Errors produced by `ConnectionPool` operations, generic over the driver error `E`.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionPoolError<E> {
    /// Error reported by the underlying driver (converted via `?`).
    #[error(transparent)]
    Driver(#[from] E),
    /// No read connection became available before the acquire timeout.
    #[error("Failed to acquire connection in time")]
    ConnectionAcquireTimeout,
    /// Change tracking could not be enabled on the write connection (`watcher` feature).
    #[error("Failed to setup connection watcher")]
    WatcherSetup,
    /// Catch-all for other boxed errors.
    /// NOTE(review): not constructed in this section; confirm where it is produced.
    #[error(transparent)]
    Other(Box<dyn std::error::Error + Send + Sync>),
}
impl<T: Driver, A: ConnectionAdapter<T>> ConnectionPool<T, A> {
pub fn new(
config: ConnectionPoolConfig,
) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
let watched_connection = T::new_write_connection(&config.file_path)
.inspect_err(|e| tracing::error!("Failed to create write connection: {e:?}"))?;
#[cfg(feature = "watcher")]
let watched_connection = WatchedConnection::new(watched_connection).map_err(|e| {
tracing::error!("Failed to setup connection watcher: {e:?}");
ConnectionPoolError::WatcherSetup
})?;
#[cfg(not(feature = "watcher"))]
let watched_connection = WatchedConnection::new(watched_connection);
let mut read_connections = Vec::with_capacity(config.max_read_connection_count);
for _ in 0..config.max_read_connection_count {
read_connections.push(A::from_driver_connection(
T::new_read_connection(&config.file_path)
.inspect_err(|e| tracing::error!("Failed to create read connection: {e:?}"))?,
));
}
Ok(Arc::new(Self {
write_connection: Mutex::new(watched_connection),
read_connections: Mutex::new(read_connections),
reader_condvar: Condvar::new(),
config,
}))
}
pub fn connection(
self: &Arc<Self>,
) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
let mut rd_connections = self.read_connections.lock();
loop {
if let Some(rd_connection) = rd_connections.pop() {
return Ok(PooledConnection::new(self.clone(), rd_connection));
} else if let Some(duration) = self.config.connection_acquire_timeout {
if self
.reader_condvar
.wait_for(&mut rd_connections, duration)
.timed_out()
{
return Err(ConnectionPoolError::ConnectionAcquireTimeout);
}
} else {
self.reader_condvar.wait(&mut rd_connections);
}
}
}
pub(crate) fn transaction_closure<F, R, E>(&self, closure: F) -> Result<R, E>
where
F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E>,
E: From<T::Error>,
{
let mut tx = self.transaction()?;
let result = closure(&mut tx);
if result.is_ok() {
tx.commit()?;
} else {
tx.rollback()?;
}
result
}
pub(crate) fn transaction(&self) -> Result<Transaction<'_, T>, T::Error> {
let writer = self.write_connection.lock();
Transaction::new(
writer,
#[cfg(feature = "watcher")]
&self.config.watcher,
)
}
fn return_to_pool(&self, conn: A) {
let mut read_connections = self.read_connections.lock();
read_connections.push(conn);
drop(read_connections);
self.reader_condvar.notify_one();
}
#[cfg(feature = "watcher")]
pub fn watcher(&self) -> &Arc<Watcher> {
&self.config.watcher
}
}
/// Converts a raw driver connection into the value stored in (and handed out by) the
/// pool's read-connection set.
pub trait ConnectionAdapter<T: Driver> {
    /// Wraps a read connection that `ConnectionPool::new` has just opened with
    /// `T::new_read_connection`.
    fn from_driver_connection(connection: T::Connection) -> Self;
}
/// A read connection checked out of a `ConnectionPool`; returned to the pool on drop.
pub struct PooledConnection<T: Driver, A: ConnectionAdapter<T>> {
    /// Pool the connection is returned to. Holding the `Arc` also keeps the pool alive.
    pub(crate) pool: Arc<ConnectionPool<T, A>>,
    /// Always `Some` until `Drop` takes it to hand it back to the pool.
    conn: Option<A>,
}
/// Hands the read connection back to its pool, which wakes one waiting caller.
impl<T: Driver, A: ConnectionAdapter<T>> Drop for PooledConnection<T, A> {
    fn drop(&mut self) {
        // `conn` is only ever taken here, so it is still present.
        let Some(connection) = self.conn.take() else {
            panic!("Connection should be set");
        };
        self.pool.return_to_pool(connection);
    }
}
impl<T: Driver, A: ConnectionAdapter<T>> PooledConnection<T, A> {
fn new(pool: Arc<ConnectionPool<T, A>>, connection: A) -> PooledConnection<T, A> {
Self {
pool,
conn: Some(connection),
}
}
pub(crate) fn connection(&self) -> &A {
self.conn.as_ref().expect("Connection should be set")
}
pub(crate) fn connection_mut(&mut self) -> &mut A {
self.conn.as_mut().expect("Connection should be set")
}
}
/// The write connection together with, under the `watcher` feature, its
/// change-tracking state.
struct WatchedConnection<T>
where
    T: Driver,
{
    /// Underlying driver connection.
    connection: T::Connection,
    /// Tracking state used to build the sync and publish statements for this connection.
    #[cfg(feature = "watcher")]
    state: State,
}
#[cfg(feature = "watcher")]
impl<T> WatchedConnection<T>
where
    T: Driver,
{
    /// Prepares `connection` for change tracking.
    ///
    /// Executes the watcher's pragma statement and then its start-tracking statement on
    /// `connection`, in that order, and pairs the connection with a fresh `State`.
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever statement fails first.
    fn new(mut connection: T::Connection) -> Result<Self, <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;
        State::set_pragmas().execute_mut(&mut connection)?;
        State::start_tracking().execute_mut(&mut connection)?;
        Ok(Self {
            connection,
            state: State::new(),
        })
    }

    /// Executes the table-sync statement that `self.state` produces for `watcher`, if it
    /// produces one; otherwise does nothing.
    ///
    /// # Errors
    ///
    /// Propagates the error from executing the sync statement.
    fn sync_changes(&mut self, watcher: &Watcher) -> Result<(), T::Error> {
        use sqlite_watcher::statement::Statement;
        if let Some(stmt) = self.state.sync_tables(watcher) {
            stmt.execute_mut(&mut self.connection)?;
        }
        Ok(())
    }

    /// Executes the statement that publishes tracked changes to `watcher`.
    ///
    /// Best-effort: a failure is logged with `tracing::error!` and not returned.
    fn publish_changes(&mut self, watcher: &Watcher) {
        use sqlite_watcher::statement::Statement;
        if let Err(e) = self
            .state
            .publish_changes(watcher)
            .execute_mut(&mut self.connection)
        {
            tracing::error!("Failed to publish updates to watcher: {e:?}");
        }
    }
}
#[cfg(not(feature = "watcher"))]
impl<T> WatchedConnection<T>
where
T: Driver,
{
fn new(connection: T::Connection) -> Self {
Self { connection }
}
}
/// A `BEGIN IMMEDIATE` write transaction on the pool's write connection.
///
/// Holds the write-connection lock for its whole lifetime. Dropping it without a
/// successful `commit` or `rollback` rolls the transaction back.
pub struct Transaction<'c, T: Driver> {
    /// Lock guard on the write connection. Wrapped in `ManuallyDrop` so that
    /// `commit`/`rollback` can release it explicitly and then `mem::forget` `self`,
    /// bypassing `Drop`.
    conn: ManuallyDrop<MutexGuard<'c, WatchedConnection<T>>>,
    /// Watcher synced with before `BEGIN` and published to after a successful commit.
    #[cfg(feature = "watcher")]
    watcher: &'c Watcher,
}
impl<'c, T: Driver> Transaction<'c, T> {
    /// Starts a write transaction on the already-locked write connection.
    ///
    /// With the `watcher` feature, this first executes any pending table-sync statement
    /// for `watcher`. It then issues `BEGIN IMMEDIATE`.
    ///
    /// # Errors
    ///
    /// Returns the driver error from the sync or from `BEGIN IMMEDIATE`. In that case
    /// `conn` is dropped normally, which releases the write lock.
    fn new(
        mut conn: MutexGuard<'c, WatchedConnection<T>>,
        #[cfg(feature = "watcher")] watcher: &'c Watcher,
    ) -> Result<Self, <T as Driver>::Error> {
        #[cfg(feature = "watcher")]
        conn.sync_changes(watcher)?;
        T::begin_transaction(&mut conn.connection, "BEGIN IMMEDIATE")?;
        Ok(Self {
            conn: ManuallyDrop::new(conn),
            #[cfg(feature = "watcher")]
            watcher,
        })
    }

    /// Commits the transaction and releases the write lock.
    ///
    /// With the `watcher` feature, changes are published to the watcher after a
    /// successful commit. Publish failures are only logged.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the commit fails. `self` is then dropped normally, so
    /// `Drop` attempts a rollback and releases the lock.
    #[allow(clippy::missing_panics_doc)]
    pub fn commit(mut self) -> Result<(), <T as Driver>::Error> {
        T::commit_transaction(&mut self.conn.connection)?;
        #[cfg(feature = "watcher")]
        self.conn.publish_changes(self.watcher);
        // SAFETY: `self.conn` is dropped exactly once here. The `mem::forget(self)`
        // immediately below prevents `Drop for Transaction` from dropping it again, and
        // nothing in between touches `self.conn`.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
        // Skip `Drop`, which would otherwise roll back and double-drop the guard.
        std::mem::forget(self);
        Ok(())
    }

    /// Rolls the transaction back and releases the write lock.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the rollback fails. `self` is then dropped normally,
    /// so `Drop` retries the rollback (logging any failure) and releases the lock.
    #[allow(clippy::missing_panics_doc)]
    pub fn rollback(mut self) -> Result<(), <T as Driver>::Error> {
        T::rollback_transaction(&mut self.conn.connection)?;
        // SAFETY: `self.conn` is dropped exactly once here. The `mem::forget(self)`
        // immediately below prevents `Drop for Transaction` from dropping it again.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
        // Skip `Drop`, which would otherwise roll back again and double-drop the guard.
        std::mem::forget(self);
        Ok(())
    }
}
/// Rolls back a transaction that was not successfully committed or rolled back, then
/// releases the write lock.
///
/// Rollback failures are logged, not propagated.
impl<T: Driver> Drop for Transaction<'_, T> {
    fn drop(&mut self) {
        if let Err(e) = T::rollback_transaction(&mut self.conn.connection) {
            tracing::error!("Failed to rollback transaction: {e:?}");
        }
        // SAFETY: `drop` runs at most once, and `commit`/`rollback` `mem::forget` `self`
        // right after dropping `conn`. So when we get here `conn` is still live, and it
        // is not used after this point.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
    }
}
impl<T: Driver> Deref for Transaction<'_, T> {
type Target = T::Connection;
fn deref(&self) -> &Self::Target {
&self.conn.connection
}
}
impl<T: DriverMutConnectionDeref> DerefMut for Transaction<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.conn.connection
}
}
/// A deferred (`BEGIN`) transaction on a mutably borrowed connection.
///
/// Dropping it without a successful `commit` or `rollback` rolls the transaction back.
pub struct ReadTransaction<'c, T: Driver> {
    conn: &'c mut T::Connection,
}
impl<'c, T: Driver> ReadTransaction<'c, T> {
    /// Starts a deferred (`BEGIN`) transaction on `conn`.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `BEGIN` fails.
    pub(crate) fn new(conn: &'c mut T::Connection) -> Result<Self, <T as Driver>::Error> {
        T::begin_transaction(conn, "BEGIN")?;
        Ok(Self { conn })
    }

    /// Commits the transaction.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the commit fails. `self` is then dropped normally, so
    /// `Drop` attempts a rollback.
    #[allow(clippy::missing_panics_doc)]
    pub fn commit(self) -> Result<(), <T as Driver>::Error> {
        T::commit_transaction(self.conn)?;
        // Skip `Drop`, which would otherwise roll back the finished transaction.
        std::mem::forget(self);
        Ok(())
    }

    /// Rolls the transaction back.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the rollback fails. `self` is then dropped normally,
    /// so `Drop` retries the rollback.
    #[allow(clippy::missing_panics_doc)]
    pub fn rollback(self) -> Result<(), <T as Driver>::Error> {
        T::rollback_transaction(self.conn)?;
        // Skip `Drop`, which would otherwise roll back a second time.
        std::mem::forget(self);
        Ok(())
    }

    /// Runs `closure` inside a deferred transaction on `conn`.
    ///
    /// Mirrors `ConnectionPool::transaction_closure`: the transaction is committed when
    /// `closure` returns `Ok` and rolled back when it returns `Err`. Anything the closure
    /// changed through `DerefMut` is therefore not persisted on failure.
    ///
    /// # Errors
    ///
    /// Returns an error if `BEGIN` fails, if the closure fails, or if the commit/rollback
    /// fails. A rollback failure replaces the closure's error.
    pub(crate) fn scoped<F, R, E>(conn: &mut T::Connection, closure: F) -> Result<R, E>
    where
        F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E>,
        E: From<T::Error>,
    {
        let mut tx = ReadTransaction::new(conn)?;
        let result = closure(&mut tx);
        if result.is_ok() {
            tx.commit()?;
        } else {
            tx.rollback()?;
        }
        result
    }
}
/// Rolls back a transaction that was not successfully committed or rolled back.
/// Failures are logged rather than propagated.
impl<T: Driver> Drop for ReadTransaction<'_, T> {
    fn drop(&mut self) {
        let outcome = T::rollback_transaction(&mut *self.conn);
        if let Err(e) = outcome {
            tracing::error!("Failed to rollback transaction: {e:?}");
        }
    }
}
/// Exposes the borrowed connection for the duration of the transaction.
impl<T: Driver> Deref for ReadTransaction<'_, T> {
    type Target = T::Connection;
    fn deref(&self) -> &Self::Target {
        &*self.conn
    }
}
/// Mutable access to the borrowed connection, for drivers whose connections support it.
impl<T: DriverMutConnectionDeref> DerefMut for ReadTransaction<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.conn
    }
}