sqlite_rwc/
pool.rs

1use crate::drivers::{Driver, DriverMutConnectionDeref};
2use parking_lot::{Condvar, Mutex, MutexGuard};
3#[cfg(feature = "watcher")]
4use sqlite_watcher::connection::State;
5#[cfg(feature = "watcher")]
6use sqlite_watcher::watcher::Watcher;
7use std::mem::ManuallyDrop;
8use std::ops::{Deref, DerefMut};
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12
/// Settings used by [`ConnectionPool::new`] to open and manage the pool's connections.
pub struct ConnectionPoolConfig {
    /// Number of read connections opened up front by [`ConnectionPool::new`].
    ///
    /// With `0`, [`ConnectionPool::connection`] can never hand out a connection: it waits until
    /// the acquire timeout elapses, or forever if no timeout is set.
    pub max_read_connection_count: usize,
    /// Path of the database file passed to the driver for every connection in the pool.
    pub file_path: PathBuf,
    /// Upper bound on how long [`ConnectionPool::connection`] waits for an idle read
    /// connection; `None` waits indefinitely. Locking the write connection is not bounded by
    /// this timeout.
    pub connection_acquire_timeout: Option<Duration>,
    /// Watcher that write transactions sync with before they begin and publish changes to
    /// after they commit.
    #[cfg(feature = "watcher")]
    pub watcher: Arc<Watcher>,
}
20
21pub struct ConnectionPool<T: Driver, A: ConnectionAdapter<T>> {
22    read_connections: Mutex<Vec<A>>,
23    reader_condvar: Condvar,
24    write_connection: Mutex<WatchedConnection<T>>,
25    config: ConnectionPoolConfig,
26}
27
28#[derive(Debug, thiserror::Error)]
29pub enum ConnectionPoolError<E> {
30    #[error(transparent)]
31    Driver(#[from] E),
32    #[error("Failed to acquire connection in time")]
33    ConnectionAcquireTimeout,
34    #[error("Failed to setup connection watcher")]
35    WatcherSetup,
36    #[error(transparent)]
37    Other(Box<dyn std::error::Error + Send + Sync>),
38}
39
40impl<T: Driver, A: ConnectionAdapter<T>> ConnectionPool<T, A> {
41    /// Create a new connection pool with the given `config`.
42    ///
43    /// The write connection is created first, followed by a read connection for every reader.
44    ///
45    /// # Errors
46    ///
47    /// Returns error if the connections could not be initialized.
48    pub fn new(
49        config: ConnectionPoolConfig,
50    ) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
51        let watched_connection = T::new_write_connection(&config.file_path)
52            .inspect_err(|e| tracing::error!("Failed to create write connection: {e:?}"))?;
53        #[cfg(feature = "watcher")]
54        let watched_connection = WatchedConnection::new(watched_connection).map_err(|e| {
55            tracing::error!("Failed to setup connection watcher: {e:?}");
56            ConnectionPoolError::WatcherSetup
57        })?;
58        #[cfg(not(feature = "watcher"))]
59        let watched_connection = WatchedConnection::new(watched_connection);
60
61        let mut read_connections = Vec::with_capacity(config.max_read_connection_count);
62        for _ in 0..config.max_read_connection_count {
63            read_connections.push(A::from_driver_connection(
64                T::new_read_connection(&config.file_path)
65                    .inspect_err(|e| tracing::error!("Failed to create read connection: {e:?}"))?,
66            ));
67        }
68        Ok(Arc::new(Self {
69            write_connection: Mutex::new(watched_connection),
70            read_connections: Mutex::new(read_connections),
71            reader_condvar: Condvar::new(),
72            config,
73        }))
74    }
75
76    /// Retrieve a connection from the pool.
77    ///
78    /// If all the connections are currently in use, we will wait until one is returned to the
79    /// pool. If `ConnectionPoolConfig.connection_acquire_timeout` has no value, this method will
80    /// block indefinitely.
81    ///
82    /// # Errors
83    ///
84    /// Return error if we could not retrieve a connection from the pool before the timeout
85    /// triggered.
86    pub fn connection(
87        self: &Arc<Self>,
88    ) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
89        let mut rd_connections = self.read_connections.lock();
90        loop {
91            if let Some(rd_connection) = rd_connections.pop() {
92                return Ok(PooledConnection::new(self.clone(), rd_connection));
93            } else if let Some(duration) = self.config.connection_acquire_timeout {
94                if self
95                    .reader_condvar
96                    .wait_for(&mut rd_connections, duration)
97                    .timed_out()
98                {
99                    return Err(ConnectionPoolError::ConnectionAcquireTimeout);
100                }
101            } else {
102                self.reader_condvar.wait(&mut rd_connections);
103            }
104        }
105    }
106
107    pub(crate) fn transaction_closure<F, R, E>(&self, closure: F) -> Result<R, E>
108    where
109        F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E>,
110        E: From<T::Error>,
111    {
112        let mut tx = self.transaction()?;
113        let result = closure(&mut tx);
114        if result.is_ok() {
115            tx.commit()?;
116        } else {
117            tx.rollback()?;
118        }
119        result
120    }
121
122    pub(crate) fn transaction(&self) -> Result<Transaction<'_, T>, T::Error> {
123        let writer = self.write_connection.lock();
124        Transaction::new(
125            writer,
126            #[cfg(feature = "watcher")]
127            &self.config.watcher,
128        )
129    }
130
131    fn return_to_pool(&self, conn: A) {
132        let mut read_connections = self.read_connections.lock();
133        read_connections.push(conn);
134        drop(read_connections);
135        self.reader_condvar.notify_one();
136    }
137
138    #[cfg(feature = "watcher")]
139    pub fn watcher(&self) -> &Arc<Watcher> {
140        &self.config.watcher
141    }
142}
143
144pub trait ConnectionAdapter<T: Driver> {
145    fn from_driver_connection(connection: T::Connection) -> Self;
146}
147
148pub struct PooledConnection<T: Driver, A: ConnectionAdapter<T>> {
149    pub(crate) pool: Arc<ConnectionPool<T, A>>,
150    conn: Option<A>,
151}
152
153impl<T: Driver, A: ConnectionAdapter<T>> Drop for PooledConnection<T, A> {
154    fn drop(&mut self) {
155        let conn = self.conn.take().expect("Connection should be set");
156        self.pool.return_to_pool(conn);
157    }
158}
159
160impl<T: Driver, A: ConnectionAdapter<T>> PooledConnection<T, A> {
161    fn new(pool: Arc<ConnectionPool<T, A>>, connection: A) -> PooledConnection<T, A> {
162        Self {
163            pool,
164            conn: Some(connection),
165        }
166    }
167
168    pub(crate) fn connection(&self) -> &A {
169        self.conn.as_ref().expect("Connection should be set")
170    }
171
172    pub(crate) fn connection_mut(&mut self) -> &mut A {
173        self.conn.as_mut().expect("Connection should be set")
174    }
175}
176
177struct WatchedConnection<T>
178where
179    T: Driver,
180{
181    connection: T::Connection,
182    #[cfg(feature = "watcher")]
183    state: State,
184}
185
#[cfg(feature = "watcher")]
impl<T: Driver> WatchedConnection<T> {
    /// Prepare `conn` for change tracking.
    ///
    /// Runs the statements from `State::set_pragmas` and `State::start_tracking` on it, then
    /// pairs it with a fresh `State`.
    fn new(mut conn: T::Connection) -> Result<Self, T::Error> {
        use sqlite_watcher::statement::Statement;
        State::set_pragmas().execute_mut(&mut conn)?;
        State::start_tracking().execute_mut(&mut conn)?;
        let state = State::new();
        Ok(Self {
            connection: conn,
            state,
        })
    }

    /// Run the table-sync statement produced by the tracking state for `watcher`, if any.
    fn sync_changes(&mut self, watcher: &Watcher) -> Result<(), T::Error> {
        use sqlite_watcher::statement::Statement;
        let Some(stmt) = self.state.sync_tables(watcher) else {
            return Ok(());
        };
        stmt.execute_mut(&mut self.connection)?;
        Ok(())
    }

    /// Publish recorded changes to `watcher`.
    ///
    /// Best effort: a failure is only logged.
    fn publish_changes(&mut self, watcher: &Watcher) {
        use sqlite_watcher::statement::Statement;
        let stmt = self.state.publish_changes(watcher);
        if let Err(e) = stmt.execute_mut(&mut self.connection) {
            tracing::error!("Failed to publish updates to watcher: {e:?}");
        }
    }
}
219
220#[cfg(not(feature = "watcher"))]
221impl<T> WatchedConnection<T>
222where
223    T: Driver,
224{
225    fn new(connection: T::Connection) -> Self {
226        Self { connection }
227    }
228}
229
230/// Even though some implementations have their own transaction type (e.g.: rusqlite), they
231/// are consumed on commit/rollback. We want to run some extra code after commit and rollback.
232pub struct Transaction<'c, T: Driver> {
233    conn: ManuallyDrop<MutexGuard<'c, WatchedConnection<T>>>,
234    #[cfg(feature = "watcher")]
235    watcher: &'c Watcher,
236}
237
238impl<'c, T: Driver> Transaction<'c, T> {
239    fn new(
240        mut conn: MutexGuard<'c, WatchedConnection<T>>,
241        #[cfg(feature = "watcher")] watcher: &'c Watcher,
242    ) -> Result<Self, <T as Driver>::Error> {
243        #[cfg(feature = "watcher")]
244        conn.sync_changes(watcher)?;
245        T::begin_transaction(&mut conn.connection, "BEGIN IMMEDIATE")?;
246        Ok(Self {
247            conn: ManuallyDrop::new(conn),
248            #[cfg(feature = "watcher")]
249            watcher,
250        })
251    }
252
253    /// Commit the transaction
254    ///
255    /// # Errors
256    ///
257    /// Returns error if the commit failed.
258    #[allow(clippy::missing_panics_doc)]
259    pub fn commit(mut self) -> Result<(), <T as Driver>::Error> {
260        T::commit_transaction(&mut self.conn.connection)?;
261        #[cfg(feature = "watcher")]
262        self.conn.publish_changes(self.watcher);
263        unsafe {
264            ManuallyDrop::drop(&mut self.conn);
265        }
266        std::mem::forget(self);
267
268        Ok(())
269    }
270
271    /// Rollback the transaction
272    ///
273    /// # Errors
274    ///
275    /// Returns errors if the operation failed.
276    #[allow(clippy::missing_panics_doc)]
277    pub fn rollback(mut self) -> Result<(), <T as Driver>::Error> {
278        T::rollback_transaction(&mut self.conn.connection)?;
279        unsafe {
280            ManuallyDrop::drop(&mut self.conn);
281        }
282        std::mem::forget(self);
283        Ok(())
284    }
285}
286
287impl<T: Driver> Drop for Transaction<'_, T> {
288    fn drop(&mut self) {
289        if let Err(e) = T::rollback_transaction(&mut self.conn.connection) {
290            tracing::error!("Failed to rollback transaction: {e:?}");
291        }
292        unsafe {
293            ManuallyDrop::drop(&mut self.conn);
294        }
295    }
296}
297
298impl<T: Driver> Deref for Transaction<'_, T> {
299    type Target = T::Connection;
300
301    fn deref(&self) -> &Self::Target {
302        &self.conn.connection
303    }
304}
305
306impl<T: DriverMutConnectionDeref> DerefMut for Transaction<'_, T> {
307    fn deref_mut(&mut self) -> &mut Self::Target {
308        &mut self.conn.connection
309    }
310}
311
312/// Performs an explicit read transaction using `BEGIN` and `END` sql statements.
313pub struct ReadTransaction<'c, T: Driver> {
314    conn: &'c mut T::Connection,
315}
316
317impl<'c, T: Driver> ReadTransaction<'c, T> {
318    pub(crate) fn new(conn: &'c mut T::Connection) -> Result<Self, <T as Driver>::Error> {
319        T::begin_transaction(conn, "BEGIN")?;
320        Ok(Self { conn })
321    }
322
323    /// Commit the transaction
324    ///
325    /// # Errors
326    ///
327    /// Returns error if the commit failed.
328    #[allow(clippy::missing_panics_doc)]
329    pub fn commit(self) -> Result<(), <T as Driver>::Error> {
330        T::commit_transaction(self.conn)?;
331        std::mem::forget(self);
332
333        Ok(())
334    }
335
336    /// Rollback the transaction
337    ///
338    /// # Errors
339    ///
340    /// Returns errors if the operation failed.
341    #[allow(clippy::missing_panics_doc)]
342    pub fn rollback(self) -> Result<(), <T as Driver>::Error> {
343        T::rollback_transaction(self.conn)?;
344        std::mem::forget(self);
345        Ok(())
346    }
347
348    pub(crate) fn scoped<F, R, E>(conn: &mut T::Connection, closure: F) -> Result<R, E>
349    where
350        F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E>,
351        E: From<T::Error>,
352    {
353        let mut tx = ReadTransaction::new(conn)?;
354        let r = closure(&mut tx);
355        tx.commit()?;
356        r
357    }
358}
359
360impl<T: Driver> Drop for ReadTransaction<'_, T> {
361    fn drop(&mut self) {
362        if let Err(e) = T::rollback_transaction(self.conn) {
363            tracing::error!("Failed to rollback transaction: {e:?}");
364        }
365    }
366}
367impl<T: Driver> Deref for ReadTransaction<'_, T> {
368    type Target = T::Connection;
369
370    fn deref(&self) -> &Self::Target {
371        self.conn
372    }
373}
374
375impl<T: DriverMutConnectionDeref> DerefMut for ReadTransaction<'_, T> {
376    fn deref_mut(&mut self) -> &mut Self::Target {
377        self.conn
378    }
379}