datafusion_table_providers/sql/db_connection_pool/
sqlitepool.rs1use std::{sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use snafu::{prelude::*, ResultExt};
5use tokio_rusqlite::{Connection, ToSql};
6
7use super::{DbConnectionPool, Result};
8use crate::sql::db_connection_pool::{
9 dbconnection::{sqliteconn::SqliteConnection, AsyncDbConnection, DbConnection},
10 JoinPushDown, Mode,
11};
12
13#[derive(Debug, Snafu)]
14pub enum Error {
15 #[snafu(display("ConnectionPoolError: {source}"))]
16 ConnectionPoolError { source: tokio_rusqlite::Error },
17
18 #[snafu(display("No path provided for SQLite connection"))]
19 NoPathError {},
20
21 #[snafu(display("Database to attach does not exist: {path}"))]
22 DatabaseDoesNotExist { path: String },
23}
24
25pub struct SqliteConnectionPoolFactory {
26 path: Arc<str>,
27 mode: Mode,
28 attach_databases: Option<Vec<Arc<str>>>,
29 busy_timeout: Duration,
30}
31
32impl SqliteConnectionPoolFactory {
33 pub fn new(path: &str, mode: Mode, busy_timeout: Duration) -> Self {
34 SqliteConnectionPoolFactory {
35 path: path.into(),
36 mode,
37 attach_databases: None,
38 busy_timeout,
39 }
40 }
41
42 #[must_use]
43 pub fn with_databases(mut self, attach_databases: Option<Vec<Arc<str>>>) -> Self {
44 self.attach_databases = attach_databases;
45 self
46 }
47
48 pub async fn build(&self) -> Result<SqliteConnectionPool> {
49 let join_push_down = match (self.mode, &self.attach_databases) {
50 (Mode::File, Some(attach_databases)) => {
51 if attach_databases.is_empty() {
52 JoinPushDown::AllowedFor(self.path.to_string())
53 } else {
54 let mut attach_databases = attach_databases.clone();
55
56 for database in &attach_databases {
57 if std::fs::metadata(database.as_ref()).is_err() {
59 return Err(Error::DatabaseDoesNotExist {
60 path: database.to_string(),
61 }
62 .into());
63 }
64 }
65
66 if !attach_databases.contains(&self.path) {
67 attach_databases.push(Arc::clone(&self.path));
68 }
69
70 attach_databases.sort();
71
72 JoinPushDown::AllowedFor(attach_databases.join(";")) }
74 }
75 (Mode::File, None) => JoinPushDown::AllowedFor(self.path.to_string()),
76 (Mode::Memory, _) => JoinPushDown::AllowedFor("memory".to_string()),
77 };
78
79 let attach_databases = if let Some(attach_databases) = &self.attach_databases {
80 attach_databases.clone()
81 } else {
82 vec![]
83 };
84
85 let pool = SqliteConnectionPool::new(
86 &self.path,
87 self.mode,
88 join_push_down,
89 attach_databases,
90 self.busy_timeout,
91 )
92 .await?;
93
94 pool.setup().await?;
95
96 Ok(pool)
97 }
98}
99
100#[derive(Debug)]
101pub struct SqliteConnectionPool {
102 conn: Connection,
103 join_push_down: JoinPushDown,
104 mode: Mode,
105 path: Arc<str>,
106 attach_databases: Vec<Arc<str>>,
107 busy_timeout: Duration,
108}
109
110impl SqliteConnectionPool {
111 #[allow(clippy::needless_pass_by_value)]
120 pub async fn new(
121 path: &str,
122 mode: Mode,
123 join_push_down: JoinPushDown,
124 attach_databases: Vec<Arc<str>>,
125 busy_timeout: Duration,
126 ) -> Result<Self> {
127 let conn = match mode {
128 Mode::Memory => Connection::open_in_memory()
129 .await
130 .context(ConnectionPoolSnafu)?,
131
132 Mode::File => Connection::open(path.to_string())
133 .await
134 .context(ConnectionPoolSnafu)?,
135 };
136
137 Ok(SqliteConnectionPool {
138 conn,
139 join_push_down,
140 mode,
141 attach_databases,
142 path: path.into(),
143 busy_timeout,
144 })
145 }
146
147 pub async fn init(path: &str, mode: Mode) -> Result<()> {
150 if mode == Mode::File {
151 Connection::open(path.to_string())
152 .await
153 .context(ConnectionPoolSnafu)?;
154 }
155
156 Ok(())
157 }
158
159 pub async fn setup(&self) -> Result<()> {
160 let conn = self.conn.clone();
161 let busy_timeout = self.busy_timeout;
162
163 if self.mode == Mode::File {
165 conn.call(move |conn| {
168 conn.pragma_update(None, "journal_mode", "WAL")?;
169 conn.pragma_update(None, "synchronous", "NORMAL")?;
170 conn.pragma_update(None, "cache_size", "-20000")?;
171 conn.pragma_update(None, "foreign_keys", "true")?;
172 conn.pragma_update(None, "temp_store", "memory")?;
173 conn.busy_timeout(busy_timeout)?;
177
178 Ok(())
179 })
180 .await
181 .context(ConnectionPoolSnafu)?;
182
183 #[cfg(feature = "sqlite-federation")]
185 {
186 let attach_databases = self
187 .attach_databases
188 .iter()
189 .enumerate()
190 .map(|(i, db)| format!("ATTACH DATABASE '{db}' AS attachment_{i}"));
191
192 for attachment in attach_databases {
193 if attachment == *self.path {
194 continue;
195 }
196
197 conn.call(move |conn| {
198 conn.execute(&attachment, [])?;
199 Ok(())
200 })
201 .await
202 .context(ConnectionPoolSnafu)?;
203 }
204
205 Ok::<(), super::Error>(())
206 }?;
207 }
208
209 Ok(())
210 }
211
212 #[must_use]
213 pub fn connect_sync(&self) -> Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>> {
214 Box::new(SqliteConnection::new(self.conn.clone()))
215 }
216
217 pub async fn try_clone(&self) -> Result<Self> {
223 match self.mode {
224 Mode::Memory => Ok(SqliteConnectionPool {
225 conn: self.conn.clone(),
226 join_push_down: self.join_push_down.clone(),
227 mode: self.mode,
228 path: Arc::clone(&self.path),
229 attach_databases: self.attach_databases.clone(),
230 busy_timeout: self.busy_timeout,
231 }),
232 Mode::File => {
233 let attach_databases = if self.attach_databases.is_empty() {
234 None
235 } else {
236 Some(self.attach_databases.clone())
237 };
238
239 SqliteConnectionPoolFactory::new(&self.path, self.mode, self.busy_timeout)
240 .with_databases(attach_databases)
241 .build()
242 .await
243 }
244 }
245 }
246}
247
248#[async_trait]
249impl DbConnectionPool<Connection, &'static (dyn ToSql + Sync)> for SqliteConnectionPool {
250 async fn connect(
251 &self,
252 ) -> Result<Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>>> {
253 let conn = self.conn.clone();
254
255 Ok(Box::new(SqliteConnection::new(conn)))
256 }
257
258 fn join_push_down(&self) -> JoinPushDown {
259 self.join_push_down.clone()
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::db_connection_pool::Mode;
    use rand::Rng;
    use rstest::rstest;
    use std::time::Duration;

    /// Builds a throwaway relative database path: `./<10 lowercase letters>.sqlite`.
    fn random_db_name() -> String {
        let mut rng = rand::rng();
        let stem: String = (0..10)
            .map(|_| char::from(rng.random_range(b'a'..=b'z')))
            .collect();

        format!("./{stem}.sqlite")
    }

    #[rstest]
    #[tokio::test]
    async fn test_sqlite_connection_pool_factory() {
        let db_name = random_db_name();

        let pool = SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_secs(5))
            .build()
            .await
            .unwrap();

        // Without attachments the push-down context is just the main path.
        assert_eq!(pool.join_push_down, JoinPushDown::AllowedFor(db_name.clone()));
        assert_eq!(pool.mode, Mode::File);
        assert_eq!(&*pool.path, db_name.as_str());

        pool.conn.close().await.unwrap();

        std::fs::remove_file(&db_name).unwrap();
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_with_attachments() {
        let mut db_names = [random_db_name(), random_db_name(), random_db_name()];
        db_names.sort();
        let [main_db, first_attachment, second_attachment] = &db_names;

        let factory =
            SqliteConnectionPoolFactory::new(main_db, Mode::File, Duration::from_millis(5000))
                .with_databases(Some(vec![
                    first_attachment.as_str().into(),
                    second_attachment.as_str().into(),
                ]));

        // The attachments must exist on disk before the factory validates them.
        for attachment in [first_attachment, second_attachment] {
            SqliteConnectionPool::init(attachment, Mode::File)
                .await
                .unwrap();
        }

        let pool = factory.build().await.unwrap();

        // Every database (main included), sorted and `;`-joined.
        assert_eq!(
            pool.join_push_down,
            JoinPushDown::AllowedFor(db_names.join(";"))
        );
        assert_eq!(pool.mode, Mode::File);
        assert_eq!(&*pool.path, main_db.as_str());

        pool.conn.close().await.unwrap();

        db_names
            .iter()
            .for_each(|db| std::fs::remove_file(db).unwrap());
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_with_empty_attachments() {
        let db_name = random_db_name();

        let pool =
            SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_millis(5000))
                .with_databases(Some(Vec::new()))
                .build()
                .await
                .unwrap();

        // An empty attachment list behaves like no attachments at all.
        assert_eq!(pool.join_push_down, JoinPushDown::AllowedFor(db_name.clone()));
        assert_eq!(pool.mode, Mode::File);
        assert_eq!(&*pool.path, db_name.as_str());

        pool.conn.close().await.unwrap();

        std::fs::remove_file(&db_name).unwrap();
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_memory_with_attachments() {
        let pool = SqliteConnectionPoolFactory::new(
            "./test.sqlite",
            Mode::Memory,
            Duration::from_millis(5000),
        )
        .with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()]))
        .build()
        .await
        .unwrap();

        assert_eq!(
            pool.join_push_down,
            JoinPushDown::AllowedFor("memory".to_string())
        );
        assert_eq!(pool.mode, Mode::Memory);
        assert_eq!(&*pool.path, "./test.sqlite");

        pool.conn.close().await.unwrap();

        // Memory mode must not create the main file or any attachment on disk.
        for path in ["./test.sqlite", "./test1.sqlite", "./test2.sqlite"] {
            assert!(std::fs::metadata(path).is_err());
        }
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_errors_with_missing_attachments() {
        let mut db_names = [random_db_name(), random_db_name(), random_db_name()];
        db_names.sort();

        let result =
            SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000))
                .with_databases(Some(vec![
                    db_names[1].clone().into(),
                    db_names[2].clone().into(),
                ]))
                .build()
                .await;

        let Err(err) = result else {
            panic!("building a pool with missing attachments should fail");
        };

        // The first missing attachment is the one reported.
        let expected = format!("Database to attach does not exist: {}", db_names[1]);
        assert!(err.to_string().contains(&expected));
    }
}