sql_middleware/sqlite/
config.rs1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::thread;
5
6use bb8::{ManageConnection, Pool, PooledConnection};
7use crossbeam_channel::{Sender, unbounded};
8
9use crate::middleware::{MiddlewarePoolOptions, SqlMiddlewareDbError};
10
11pub type SqlitePooledConnection = PooledConnection<'static, SqliteManager>;
13
14pub type SharedSqliteConnection = Arc<SqliteWorker>;
16
/// Test helper: checks a connection out of `pool` and issues `ROLLBACK;` on it.
///
/// The checkout is kept alive until the rollback has finished, and only then returned to the
/// pool.
///
/// NOTE(review): this only affects whichever pooled connection `bb8` hands out.
///
/// # Errors
/// Returns `SqlMiddlewareDbError::ConnectionError` if no connection can be checked out, or
/// `SqlMiddlewareDbError::SqliteError` if executing `ROLLBACK;` fails.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
pub async fn rollback_for_tests(pool: &Pool<SqliteManager>) -> Result<(), SqlMiddlewareDbError> {
    let checked_out = match pool.get_owned().await {
        Ok(pooled) => pooled,
        Err(err) => {
            return Err(SqlMiddlewareDbError::ConnectionError(format!(
                "sqlite cleanup checkout error: {err}"
            )));
        }
    };

    // Copy the worker handle out of the guard; `checked_out` stays alive (and checked out)
    // until this function returns.
    let worker = Arc::clone(&*checked_out);

    crate::sqlite::connection::run_blocking(worker, |db| {
        db.execute_batch("ROLLBACK;")
            .map_err(SqlMiddlewareDbError::SqliteError)
    })
    .await
}
31
32enum SqliteWorkerMessage {
33 Execute(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
34 Shutdown,
35}
36
37#[derive(Debug)]
38pub struct SqliteWorker {
39 sender: Sender<SqliteWorkerMessage>,
40 broken: Arc<AtomicBool>,
41 force_rollback_busy_for_tests: AtomicBool,
42}
43
44impl SqliteWorker {
45 pub(crate) fn start(conn: rusqlite::Connection) -> Arc<Self> {
46 let (sender, receiver) = unbounded::<SqliteWorkerMessage>();
47 let broken = Arc::new(AtomicBool::new(false));
48 let broken_flag = Arc::clone(&broken);
49 let mut conn = Some(conn);
50 let _ = thread::Builder::new()
52 .name("sql-middleware-sqlite-worker".into())
53 .spawn(move || {
54 let mut conn = conn
55 .take()
56 .expect("sqlite worker missing connection at start");
57 for msg in &receiver {
58 match msg {
59 SqliteWorkerMessage::Execute(job) => {
60 let result =
63 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
64 job(&mut conn);
65 }));
66 if result.is_err() {
67 broken_flag.store(true, Ordering::Relaxed);
68 break;
69 }
70 }
71 SqliteWorkerMessage::Shutdown => break,
72 }
73 }
74 broken_flag.store(true, Ordering::Relaxed);
75 });
76
77 Arc::new(Self {
78 sender,
79 broken,
80 force_rollback_busy_for_tests: AtomicBool::new(false),
81 })
82 }
83
84 pub(crate) fn execute<F>(&self, func: F) -> Result<(), SqlMiddlewareDbError>
85 where
86 F: FnOnce(&mut rusqlite::Connection) + Send + 'static,
87 {
88 self.sender
89 .send(SqliteWorkerMessage::Execute(Box::new(func)))
90 .map_err(|_| {
91 SqlMiddlewareDbError::ExecutionError(
92 "sqlite worker channel unexpectedly closed".into(),
93 )
94 })
95 }
96
97 pub(crate) fn execute_blocking<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
98 where
99 F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
100 R: Send + 'static,
101 {
102 let (resp_tx, resp_rx) = crossbeam_channel::bounded(1);
103 self.sender
104 .send(SqliteWorkerMessage::Execute(Box::new(move |conn| {
105 let _ = resp_tx.send(func(conn));
106 })))
107 .map_err(|_| {
108 SqlMiddlewareDbError::ExecutionError(
109 "sqlite worker channel unexpectedly closed".into(),
110 )
111 })?;
112 resp_rx.recv().map_err(|_| {
113 SqlMiddlewareDbError::ExecutionError(
114 "sqlite worker response channel unexpectedly closed".into(),
115 )
116 })?
117 }
118
119 #[must_use]
120 pub(crate) fn is_broken(&self) -> bool {
121 self.broken.load(Ordering::Relaxed)
122 }
123
124 #[cfg(test)]
125 #[must_use]
126 pub fn is_broken_for_tests(&self) -> bool {
127 self.is_broken()
128 }
129
130 pub(crate) fn mark_broken(&self) {
131 self.broken.store(true, Ordering::Relaxed);
132 }
133
134 #[doc(hidden)]
135 pub fn set_force_rollback_busy_for_tests(&self, force: bool) {
136 self.force_rollback_busy_for_tests
137 .store(force, Ordering::Relaxed);
138 }
139
140 pub(crate) fn force_rollback_busy_for_tests(&self) -> bool {
141 self.force_rollback_busy_for_tests.load(Ordering::Relaxed)
142 }
143}
144
145impl Drop for SqliteWorker {
146 fn drop(&mut self) {
147 let _ = self.sender.send(SqliteWorkerMessage::Shutdown);
148 }
149}
150
151pub struct SqliteManager {
153 db_path: PathBuf,
154 pool_options: MiddlewarePoolOptions,
155 statement_cache_capacity: Option<usize>,
156}
157
158impl SqliteManager {
159 #[must_use]
160 pub fn new(db_path: String) -> Self {
161 Self {
162 db_path: db_path.into(),
163 pool_options: MiddlewarePoolOptions::default(),
164 statement_cache_capacity: None,
165 }
166 }
167
168 #[must_use]
169 pub fn from_path(db_path: impl Into<PathBuf>) -> Self {
170 Self {
171 db_path: db_path.into(),
172 pool_options: MiddlewarePoolOptions::default(),
173 statement_cache_capacity: None,
174 }
175 }
176
177 #[must_use]
178 pub fn with_pool_options(mut self, pool_options: MiddlewarePoolOptions) -> Self {
179 self.pool_options = pool_options;
180 self
181 }
182
183 #[must_use]
184 pub fn with_statement_cache_capacity(mut self, capacity: Option<usize>) -> Self {
185 self.statement_cache_capacity = capacity;
186 self
187 }
188
189 pub async fn build_pool(self) -> Result<Pool<SqliteManager>, SqlMiddlewareDbError> {
194 self.pool_options
195 .apply_to(Pool::builder())
196 .build(self)
197 .await
198 .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("sqlite pool error: {e}")))
199 }
200}
201
202impl ManageConnection for SqliteManager {
203 type Connection = SharedSqliteConnection;
204 type Error = SqlMiddlewareDbError;
205
206 fn connect(
207 &self,
208 ) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
209 let path = self.db_path.clone();
210 let statement_cache_capacity = self.statement_cache_capacity;
211 async move {
212 let conn =
213 rusqlite::Connection::open(path).map_err(SqlMiddlewareDbError::SqliteError)?;
214 if let Some(capacity) = statement_cache_capacity {
215 conn.set_prepared_statement_cache_capacity(capacity);
216 }
217 Ok(SqliteWorker::start(conn))
218 }
219 }
220
221 fn is_valid(
222 &self,
223 conn: &mut Self::Connection,
224 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
225 let conn = Arc::clone(conn);
226 async move {
227 crate::sqlite::connection::run_blocking(conn, |guard| {
228 guard
229 .query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
230 .map_err(SqlMiddlewareDbError::SqliteError)
231 })
232 .await
233 }
234 }
235
236 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
237 conn.is_broken()
238 }
239}