1use crate::drivers::{Driver, DriverMutConnectionDeref};
2use parking_lot::{Condvar, Mutex, MutexGuard};
3#[cfg(feature = "watcher")]
4use sqlite_watcher::connection::State;
5#[cfg(feature = "watcher")]
6use sqlite_watcher::watcher::Watcher;
7use std::mem::ManuallyDrop;
8use std::ops::{Deref, DerefMut};
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12
/// Settings used by [`ConnectionPool::new`] to open and manage the pool's connections.
pub struct ConnectionPoolConfig {
    /// Number of read connections opened up front by [`ConnectionPool::new`].
    /// With `0`, [`ConnectionPool::connection`] can never hand out a connection
    /// (it waits until the timeout, or forever).
    pub max_read_connection_count: usize,
    /// Path passed to the driver when opening the write connection and every read connection.
    pub file_path: PathBuf,
    /// Timeout used by [`ConnectionPool::connection`] while waiting for an idle read
    /// connection; `None` waits indefinitely.
    pub connection_acquire_timeout: Option<Duration>,
    /// Watcher that write transactions sync with before `BEGIN IMMEDIATE` and publish
    /// changes to after a successful commit.
    #[cfg(feature = "watcher")]
    pub watcher: Arc<Watcher>,
}
20
/// A pool holding one write connection and a fixed set of read connections.
///
/// Read connections are checked out with [`ConnectionPool::connection`]. Write access goes
/// through a single mutex-guarded connection that transactions lock for their lifetime.
///
/// NOTE: fields are dropped in declaration order.
pub struct ConnectionPool<T: Driver, A: ConnectionAdapter<T>> {
    /// Idle read connections. `connection` pops from here; dropping a `PooledConnection`
    /// pushes the connection back.
    read_connections: Mutex<Vec<A>>,
    /// Notified (one waiter per returned connection) whenever a read connection is returned.
    reader_condvar: Condvar,
    /// The single write connection; holding its lock serializes write transactions.
    write_connection: Mutex<WatchedConnection<T>>,
    /// Configuration the pool was created with.
    config: ConnectionPoolConfig,
}
27
/// Errors returned by [`ConnectionPool`] operations.
///
/// `E` is the driver's error type: `T::ConnectionError` for [`ConnectionPool::new`] and
/// `T::Error` for [`ConnectionPool::connection`].
#[derive(Debug, thiserror::Error)]
pub enum ConnectionPoolError<E> {
    /// An error reported by the underlying driver (produced by `?` through the `From` impl).
    #[error(transparent)]
    Driver(#[from] E),
    /// No idle read connection became available before `connection_acquire_timeout` elapsed.
    #[error("Failed to acquire connection in time")]
    ConnectionAcquireTimeout,
    /// Change tracking could not be enabled on the write connection (`watcher` feature).
    /// The underlying driver error is logged, not stored.
    #[error("Failed to setup connection watcher")]
    WatcherSetup,
    /// Any other boxed error.
    #[error(transparent)]
    Other(Box<dyn std::error::Error + Send + Sync>),
}
39
40impl<T: Driver, A: ConnectionAdapter<T>> ConnectionPool<T, A> {
41 pub fn new(
49 config: ConnectionPoolConfig,
50 ) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
51 let watched_connection = T::new_write_connection(&config.file_path)
52 .inspect_err(|e| tracing::error!("Failed to create write connection: {e:?}"))?;
53 #[cfg(feature = "watcher")]
54 let watched_connection = WatchedConnection::new(watched_connection).map_err(|e| {
55 tracing::error!("Failed to setup connection watcher: {e:?}");
56 ConnectionPoolError::WatcherSetup
57 })?;
58 #[cfg(not(feature = "watcher"))]
59 let watched_connection = WatchedConnection::new(watched_connection);
60
61 let mut read_connections = Vec::with_capacity(config.max_read_connection_count);
62 for _ in 0..config.max_read_connection_count {
63 read_connections.push(A::from_driver_connection(
64 T::new_read_connection(&config.file_path)
65 .inspect_err(|e| tracing::error!("Failed to create read connection: {e:?}"))?,
66 ));
67 }
68 Ok(Arc::new(Self {
69 write_connection: Mutex::new(watched_connection),
70 read_connections: Mutex::new(read_connections),
71 reader_condvar: Condvar::new(),
72 config,
73 }))
74 }
75
76 pub fn connection(
87 self: &Arc<Self>,
88 ) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
89 let mut rd_connections = self.read_connections.lock();
90 loop {
91 if let Some(rd_connection) = rd_connections.pop() {
92 return Ok(PooledConnection::new(self.clone(), rd_connection));
93 } else if let Some(duration) = self.config.connection_acquire_timeout {
94 if self
95 .reader_condvar
96 .wait_for(&mut rd_connections, duration)
97 .timed_out()
98 {
99 return Err(ConnectionPoolError::ConnectionAcquireTimeout);
100 }
101 } else {
102 self.reader_condvar.wait(&mut rd_connections);
103 }
104 }
105 }
106
107 pub(crate) fn transaction_closure<F, R, E>(&self, closure: F) -> Result<R, E>
108 where
109 F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E>,
110 E: From<T::Error>,
111 {
112 let mut tx = self.transaction()?;
113 let result = closure(&mut tx);
114 if result.is_ok() {
115 tx.commit()?;
116 } else {
117 tx.rollback()?;
118 }
119 result
120 }
121
122 pub(crate) fn transaction(&self) -> Result<Transaction<'_, T>, T::Error> {
123 let writer = self.write_connection.lock();
124 Transaction::new(
125 writer,
126 #[cfg(feature = "watcher")]
127 &self.config.watcher,
128 )
129 }
130
131 fn return_to_pool(&self, conn: A) {
132 let mut read_connections = self.read_connections.lock();
133 read_connections.push(conn);
134 drop(read_connections);
135 self.reader_condvar.notify_one();
136 }
137
138 #[cfg(feature = "watcher")]
139 pub fn watcher(&self) -> &Arc<Watcher> {
140 &self.config.watcher
141 }
142}
143
/// Converts a driver's raw connection into the type the pool hands out for reads.
///
/// [`ConnectionPool::new`] calls this once for every read connection it opens.
pub trait ConnectionAdapter<T: Driver> {
    /// Wraps `connection`, freshly opened via `T::new_read_connection`.
    fn from_driver_connection(connection: T::Connection) -> Self;
}
147
/// A read connection checked out of a [`ConnectionPool`]; it is returned to the pool on drop.
pub struct PooledConnection<T: Driver, A: ConnectionAdapter<T>> {
    /// Pool the connection is returned to.
    pub(crate) pool: Arc<ConnectionPool<T, A>>,
    /// Always `Some` until `Drop` takes it out to hand it back to `pool`.
    conn: Option<A>,
}
152
153impl<T: Driver, A: ConnectionAdapter<T>> Drop for PooledConnection<T, A> {
154 fn drop(&mut self) {
155 let conn = self.conn.take().expect("Connection should be set");
156 self.pool.return_to_pool(conn);
157 }
158}
159
160impl<T: Driver, A: ConnectionAdapter<T>> PooledConnection<T, A> {
161 fn new(pool: Arc<ConnectionPool<T, A>>, connection: A) -> PooledConnection<T, A> {
162 Self {
163 pool,
164 conn: Some(connection),
165 }
166 }
167
168 pub(crate) fn connection(&self) -> &A {
169 self.conn.as_ref().expect("Connection should be set")
170 }
171
172 pub(crate) fn connection_mut(&mut self) -> &mut A {
173 self.conn.as_mut().expect("Connection should be set")
174 }
175}
176
/// The pool's write connection. With the `watcher` feature, it is paired with the
/// `sqlite_watcher` tracking state kept in sync for it.
struct WatchedConnection<T>
where
    T: Driver,
{
    /// Underlying driver connection.
    connection: T::Connection,
    /// Tracking state used to sync tables with, and publish changes to, the watcher.
    #[cfg(feature = "watcher")]
    state: State,
}
185
#[cfg(feature = "watcher")]
impl<T> WatchedConnection<T>
where
    T: Driver,
{
    /// Enables change tracking on `connection` and pairs it with fresh watcher state.
    ///
    /// # Errors
    ///
    /// The driver error if setting the tracking pragmas or starting tracking fails.
    fn new(mut connection: T::Connection) -> Result<Self, <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;

        // Pragmas must be applied before tracking is started.
        State::set_pragmas().execute_mut(&mut connection)?;
        State::start_tracking().execute_mut(&mut connection)?;
        let state = State::new();
        Ok(WatchedConnection { connection, state })
    }

    /// Runs the table-sync statement produced by the tracking state, if there is one.
    ///
    /// # Errors
    ///
    /// The driver error if executing the sync statement fails.
    fn sync_changes(&mut self, watcher: &Watcher) -> Result<(), T::Error> {
        use sqlite_watcher::statement::Statement;

        let Some(sync_stmt) = self.state.sync_tables(watcher) else {
            // `sync_tables` produced no statement, so there is nothing to execute.
            return Ok(());
        };
        sync_stmt.execute_mut(&mut self.connection)?;
        Ok(())
    }

    /// Publishes tracked changes to `watcher`; a failure is logged rather than returned.
    fn publish_changes(&mut self, watcher: &Watcher) {
        use sqlite_watcher::statement::Statement;

        if let Some(e) = self
            .state
            .publish_changes(watcher)
            .execute_mut(&mut self.connection)
            .err()
        {
            tracing::error!("Failed to publish updates to watcher: {e:?}");
        }
    }
}
219
220#[cfg(not(feature = "watcher"))]
221impl<T> WatchedConnection<T>
222where
223 T: Driver,
224{
225 fn new(connection: T::Connection) -> Self {
226 Self { connection }
227 }
228}
229
/// A write transaction that holds the pool's write-connection lock for its whole lifetime.
///
/// Finish it with [`Transaction::commit`] or [`Transaction::rollback`]. If it is dropped
/// without either succeeding, it is rolled back.
pub struct Transaction<'c, T: Driver> {
    /// Lock guard for the write connection. It is wrapped in `ManuallyDrop` so that `commit`,
    /// `rollback` and `Drop` control exactly when the lock is released: only after the
    /// transaction has been ended.
    conn: ManuallyDrop<MutexGuard<'c, WatchedConnection<T>>>,
    /// Watcher that committed changes are published to.
    #[cfg(feature = "watcher")]
    watcher: &'c Watcher,
}
237
impl<'c, T: Driver> Transaction<'c, T> {
    /// Begins a `BEGIN IMMEDIATE` transaction on the locked write connection `conn`.
    ///
    /// With the `watcher` feature, the connection's tracking state is first synced with
    /// `watcher`.
    ///
    /// # Errors
    ///
    /// The driver error if syncing or `BEGIN IMMEDIATE` fails. `conn` is then dropped here,
    /// which releases the write lock.
    fn new(
        mut conn: MutexGuard<'c, WatchedConnection<T>>,
        #[cfg(feature = "watcher")] watcher: &'c Watcher,
    ) -> Result<Self, <T as Driver>::Error> {
        #[cfg(feature = "watcher")]
        conn.sync_changes(watcher)?;
        T::begin_transaction(&mut conn.connection, "BEGIN IMMEDIATE")?;
        Ok(Self {
            conn: ManuallyDrop::new(conn),
            #[cfg(feature = "watcher")]
            watcher,
        })
    }

    /// Commits the transaction and releases the write lock.
    ///
    /// With the `watcher` feature, changes are published to the watcher after a successful
    /// commit and before the lock is released. A publish failure is logged, not returned.
    ///
    /// # Errors
    ///
    /// The driver error if the commit fails. `self` is then dropped normally, so the `Drop`
    /// impl attempts a rollback and releases the lock.
    #[allow(clippy::missing_panics_doc)]
    pub fn commit(mut self) -> Result<(), <T as Driver>::Error> {
        T::commit_transaction(&mut self.conn.connection)?;
        #[cfg(feature = "watcher")]
        self.conn.publish_changes(self.watcher);
        // SAFETY: `self.conn` still holds a live guard. It is only ever dropped right before
        // `self` is forgotten, or in `Drop`. `mem::forget(self)` below guarantees `Drop`
        // never touches it again, so the guard is dropped exactly once.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
        // Skip `Drop`: the transaction is committed and the lock already released.
        std::mem::forget(self);

        Ok(())
    }

    /// Rolls the transaction back and releases the write lock.
    ///
    /// # Errors
    ///
    /// The driver error if the rollback fails. `self` is then dropped normally, so the `Drop`
    /// impl attempts the rollback once more and releases the lock.
    #[allow(clippy::missing_panics_doc)]
    pub fn rollback(mut self) -> Result<(), <T as Driver>::Error> {
        T::rollback_transaction(&mut self.conn.connection)?;
        // SAFETY: `self.conn` still holds a live guard, and `mem::forget(self)` below
        // prevents `Drop` from dropping it a second time.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
        // Skip `Drop`: the transaction is rolled back and the lock already released.
        std::mem::forget(self);
        Ok(())
    }
}
286
impl<T: Driver> Drop for Transaction<'_, T> {
    /// Rolls back a transaction that was not ended by a successful `commit` or `rollback`,
    /// then releases the write lock. A rollback failure is logged, not propagated.
    fn drop(&mut self) {
        if let Err(e) = T::rollback_transaction(&mut self.conn.connection) {
            tracing::error!("Failed to rollback transaction: {e:?}");
        }
        // SAFETY: `commit` and `rollback` `mem::forget` the transaction right after dropping
        // the guard, so this only runs while `self.conn` still holds a live guard. `self` is
        // not used again after `drop` returns.
        unsafe {
            ManuallyDrop::drop(&mut self.conn);
        }
    }
}
297
298impl<T: Driver> Deref for Transaction<'_, T> {
299 type Target = T::Connection;
300
301 fn deref(&self) -> &Self::Target {
302 &self.conn.connection
303 }
304}
305
306impl<T: DriverMutConnectionDeref> DerefMut for Transaction<'_, T> {
307 fn deref_mut(&mut self) -> &mut Self::Target {
308 &mut self.conn.connection
309 }
310}
311
/// A transaction opened with a plain `BEGIN` on a borrowed connection.
///
/// It is rolled back on drop unless [`ReadTransaction::commit`] or
/// [`ReadTransaction::rollback`] succeeds first.
pub struct ReadTransaction<'c, T: Driver> {
    /// The connection the transaction runs on, exclusively borrowed for its lifetime.
    conn: &'c mut T::Connection,
}
316
317impl<'c, T: Driver> ReadTransaction<'c, T> {
318 pub(crate) fn new(conn: &'c mut T::Connection) -> Result<Self, <T as Driver>::Error> {
319 T::begin_transaction(conn, "BEGIN")?;
320 Ok(Self { conn })
321 }
322
323 #[allow(clippy::missing_panics_doc)]
329 pub fn commit(self) -> Result<(), <T as Driver>::Error> {
330 T::commit_transaction(self.conn)?;
331 std::mem::forget(self);
332
333 Ok(())
334 }
335
336 #[allow(clippy::missing_panics_doc)]
342 pub fn rollback(self) -> Result<(), <T as Driver>::Error> {
343 T::rollback_transaction(self.conn)?;
344 std::mem::forget(self);
345 Ok(())
346 }
347
348 pub(crate) fn scoped<F, R, E>(conn: &mut T::Connection, closure: F) -> Result<R, E>
349 where
350 F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E>,
351 E: From<T::Error>,
352 {
353 let mut tx = ReadTransaction::new(conn)?;
354 let r = closure(&mut tx);
355 tx.commit()?;
356 r
357 }
358}
359
360impl<T: Driver> Drop for ReadTransaction<'_, T> {
361 fn drop(&mut self) {
362 if let Err(e) = T::rollback_transaction(self.conn) {
363 tracing::error!("Failed to rollback transaction: {e:?}");
364 }
365 }
366}
367impl<T: Driver> Deref for ReadTransaction<'_, T> {
368 type Target = T::Connection;
369
370 fn deref(&self) -> &Self::Target {
371 self.conn
372 }
373}
374
375impl<T: DriverMutConnectionDeref> DerefMut for ReadTransaction<'_, T> {
376 fn deref_mut(&mut self) -> &mut Self::Target {
377 self.conn
378 }
379}