sqlite_rwc/
pool.rs

1use crate::drivers::Driver;
2use parking_lot::{Condvar, Mutex};
3#[cfg(feature = "watcher")]
4use sqlite_watcher::connection::State;
5#[cfg(feature = "watcher")]
6use sqlite_watcher::watcher::Watcher;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10
/// Configuration consumed by [`ConnectionPool::new`].
pub struct ConnectionPoolConfig {
    /// Number of read connections opened eagerly when the pool is created.
    ///
    /// With `0`, [`ConnectionPool::connection`] can never hand out a connection.
    pub max_read_connection_count: usize,
    /// Path used to open both the write connection and every read connection.
    pub file_path: PathBuf,
    /// How long [`ConnectionPool::connection`] may wait for a read connection to be returned.
    /// `None` waits indefinitely.
    pub connection_acquire_timeout: Option<Duration>,
    /// Watcher that write transactions sync with and publish changes to.
    #[cfg(feature = "watcher")]
    pub watcher: Arc<Watcher>,
}
18
19pub struct ConnectionPool<T: Driver, A: ConnectionAdapter<T>> {
20    read_connections: Mutex<Vec<A>>,
21    reader_condvar: Condvar,
22    write_connection: Mutex<WatchedConnection<T>>,
23    config: ConnectionPoolConfig,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum ConnectionPoolError<E> {
28    #[error(transparent)]
29    Driver(#[from] E),
30    #[error("Failed to acquire connection in time")]
31    ConnectionAcquireTimeout,
32    #[error("Failed to setup connection watcher")]
33    WatcherSetup,
34}
35
36impl<T: Driver, A: ConnectionAdapter<T>> ConnectionPool<T, A> {
37    /// Create a new connection pool with the given `config`.
38    ///
39    /// The write connection is created first, followed by a read connection for every reader.
40    ///
41    /// # Errors
42    ///
43    /// Returns error if the connections could not be initialized.
44    pub fn new(
45        config: ConnectionPoolConfig,
46    ) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
47        let watched_connection = T::new_write_connection(&config.file_path)
48            .inspect_err(|e| tracing::error!("Failed to create write connection: {e:?}"))?;
49        #[cfg(feature = "watcher")]
50        let watched_connection = WatchedConnection::new(watched_connection).map_err(|e| {
51            tracing::error!("Failed to setup connection watcher: {e:?}");
52            ConnectionPoolError::WatcherSetup
53        })?;
54        #[cfg(not(feature = "watcher"))]
55        let watched_connection = WatchedConnection::new(watched_connection);
56
57        let mut read_connections = Vec::with_capacity(config.max_read_connection_count);
58        for _ in 0..config.max_read_connection_count {
59            read_connections.push(A::from_driver_connection(
60                T::new_read_connection(&config.file_path)
61                    .inspect_err(|e| tracing::error!("Failed to create read connection: {e:?}"))?,
62            ));
63        }
64        Ok(Arc::new(Self {
65            write_connection: Mutex::new(watched_connection),
66            read_connections: Mutex::new(read_connections),
67            reader_condvar: Condvar::new(),
68            config,
69        }))
70    }
71
72    /// Retrieve a connection from the pool.
73    ///
74    /// If all the connections are currently in use, we will wait until one is returned to the
75    /// pool. If `ConnectionPoolConfig.connection_acquire_timeout` has no value, this method will
76    /// block indefinitely.
77    ///
78    /// # Errors
79    ///
80    /// Return error if we could not retrieve a connection from the pool before the timeout
81    /// triggered.
82    pub fn connection(
83        self: &Arc<Self>,
84    ) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
85        let mut rd_connections = self.read_connections.lock();
86        loop {
87            if let Some(rd_connection) = rd_connections.pop() {
88                return Ok(PooledConnection::new(self.clone(), rd_connection));
89            } else if let Some(duration) = self.config.connection_acquire_timeout {
90                if self
91                    .reader_condvar
92                    .wait_for(&mut rd_connections, duration)
93                    .timed_out()
94                {
95                    return Err(ConnectionPoolError::ConnectionAcquireTimeout);
96                }
97            } else {
98                self.reader_condvar.wait(&mut rd_connections);
99            }
100        }
101    }
102
103    pub(crate) fn transaction<F, R, E>(&self, closure: F) -> Result<R, E>
104    where
105        F: FnOnce(&mut T::Transaction<'_>) -> Result<R, E>,
106        E: From<T::Error>,
107    {
108        let mut writer_connection = self.write_connection.lock();
109        #[cfg(feature = "watcher")]
110        {
111            writer_connection.transaction(closure, &self.config.watcher)
112        }
113        #[cfg(not(feature = "watcher"))]
114        {
115            writer_connection.transaction(closure)
116        }
117    }
118
119    fn return_to_pool(&self, conn: A) {
120        let mut read_connections = self.read_connections.lock();
121        read_connections.push(conn);
122        drop(read_connections);
123        self.reader_condvar.notify_one();
124    }
125
126    #[cfg(feature = "watcher")]
127    pub fn watcher(&self) -> &Arc<Watcher> {
128        &self.config.watcher
129    }
130}
131
132pub trait ConnectionAdapter<T: Driver> {
133    fn from_driver_connection(connection: T::Connection) -> Self;
134}
135
136pub struct PooledConnection<T: Driver, A: ConnectionAdapter<T>> {
137    pub(crate) pool: Arc<ConnectionPool<T, A>>,
138    conn: Option<A>,
139}
140
141impl<T: Driver, A: ConnectionAdapter<T>> Drop for PooledConnection<T, A> {
142    fn drop(&mut self) {
143        let conn = self.conn.take().expect("Connection should be set");
144        self.pool.return_to_pool(conn);
145    }
146}
147
148impl<T: Driver, A: ConnectionAdapter<T>> PooledConnection<T, A> {
149    fn new(pool: Arc<ConnectionPool<T, A>>, connection: A) -> PooledConnection<T, A> {
150        Self {
151            pool,
152            conn: Some(connection),
153        }
154    }
155
156    pub(crate) fn connection(&self) -> &A {
157        self.conn.as_ref().expect("Connection should be set")
158    }
159
160    pub(crate) fn connection_mut(&mut self) -> &mut A {
161        self.conn.as_mut().expect("Connection should be set")
162    }
163}
164
165struct WatchedConnection<T>
166where
167    T: Driver,
168{
169    connection: T::Connection,
170    #[cfg(feature = "watcher")]
171    state: State,
172}
173
#[cfg(feature = "watcher")]
impl<T: Driver> WatchedConnection<T> {
    /// Wrap the write `connection` and start `sqlite_watcher` change tracking on it.
    ///
    /// # Errors
    ///
    /// Returns the driver error if executing the start-tracking statement fails.
    fn new(mut connection: T::Connection) -> Result<Self, <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;
        State::start_tracking().execute_mut(&mut connection)?;
        let state = State::new();
        Ok(Self { connection, state })
    }

    /// Run `closure` on the write connection, keeping `watcher` in step around it.
    ///
    /// Tracked tables are synced with `watcher` first; an error there aborts before writing.
    /// Afterwards changes are published whether or not the write succeeded. A publishing failure
    /// is only logged, so it never replaces the write's own result.
    fn transaction<F, R, E>(&mut self, closure: F, watcher: &Watcher) -> Result<R, E>
    where
        F: FnOnce(&mut T::Transaction<'_>) -> Result<R, E>,
        E: From<T::Error>,
    {
        self.before_write(watcher)?;
        let outcome = T::write(&mut self.connection, closure);
        // Best effort: the write's result takes precedence over a failed publish.
        let _ = self
            .after_write(watcher)
            .inspect_err(|e| tracing::error!("Failed to publish updates to watcher: {e:?}"));
        outcome
    }

    /// Execute the table-sync statement for `watcher`, if `sync_tables` produced one.
    fn before_write(&mut self, watcher: &Watcher) -> Result<(), <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;
        let Some(stmt) = self.state.sync_tables(watcher) else {
            // No sync statement to run.
            return Ok(());
        };
        stmt.execute_mut(&mut self.connection)?;
        Ok(())
    }

    /// Execute the statement that publishes tracked changes to `watcher`.
    fn after_write(&mut self, watcher: &Watcher) -> Result<(), <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;
        self.state
            .publish_changes(watcher)
            .execute_mut(&mut self.connection)?;
        Ok(())
    }
}
217
218#[cfg(not(feature = "watcher"))]
219impl<T> WatchedConnection<T>
220where
221    T: Driver,
222{
223    fn new(connection: T::Connection) -> Self {
224        Self { connection }
225    }
226
227    fn transaction<F, R, E>(&mut self, closure: F) -> Result<R, E>
228    where
229        F: FnOnce(&mut T::Transaction<'_>) -> Result<R, E>,
230        E: From<T::Error>,
231    {
232        T::write(&mut self.connection, closure)
233    }
234}