datafusion_table_providers/sql/db_connection_pool/
sqlitepool.rs

1use std::{sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use snafu::{prelude::*, ResultExt};
5use tokio_rusqlite::{Connection, ToSql};
6
7use super::{DbConnectionPool, Result};
8use crate::sql::db_connection_pool::{
9    dbconnection::{sqliteconn::SqliteConnection, AsyncDbConnection, DbConnection},
10    JoinPushDown, Mode,
11};
12
13#[derive(Debug, Snafu)]
14pub enum Error {
15    #[snafu(display("ConnectionPoolError: {source}"))]
16    ConnectionPoolError { source: tokio_rusqlite::Error },
17
18    #[snafu(display("No path provided for SQLite connection"))]
19    NoPathError {},
20
21    #[snafu(display("Database to attach does not exist: {path}"))]
22    DatabaseDoesNotExist { path: String },
23}
24
25pub struct SqliteConnectionPoolFactory {
26    path: Arc<str>,
27    mode: Mode,
28    attach_databases: Option<Vec<Arc<str>>>,
29    busy_timeout: Duration,
30}
31
32impl SqliteConnectionPoolFactory {
33    pub fn new(path: &str, mode: Mode, busy_timeout: Duration) -> Self {
34        SqliteConnectionPoolFactory {
35            path: path.into(),
36            mode,
37            attach_databases: None,
38            busy_timeout,
39        }
40    }
41
42    #[must_use]
43    pub fn with_databases(mut self, attach_databases: Option<Vec<Arc<str>>>) -> Self {
44        self.attach_databases = attach_databases;
45        self
46    }
47
48    pub async fn build(&self) -> Result<SqliteConnectionPool> {
49        let join_push_down = match (self.mode, &self.attach_databases) {
50            (Mode::File, Some(attach_databases)) => {
51                if attach_databases.is_empty() {
52                    JoinPushDown::AllowedFor(self.path.to_string())
53                } else {
54                    let mut attach_databases = attach_databases.clone();
55
56                    for database in &attach_databases {
57                        // check if the database file exists
58                        if std::fs::metadata(database.as_ref()).is_err() {
59                            return Err(Error::DatabaseDoesNotExist {
60                                path: database.to_string(),
61                            }
62                            .into());
63                        }
64                    }
65
66                    if !attach_databases.contains(&self.path) {
67                        attach_databases.push(Arc::clone(&self.path));
68                    }
69
70                    attach_databases.sort();
71
72                    JoinPushDown::AllowedFor(attach_databases.join(";")) // push down is allowed cross-database when they're attached together
73                }
74            }
75            (Mode::File, None) => JoinPushDown::AllowedFor(self.path.to_string()),
76            (Mode::Memory, _) => JoinPushDown::AllowedFor("memory".to_string()),
77        };
78
79        let attach_databases = if let Some(attach_databases) = &self.attach_databases {
80            attach_databases.clone()
81        } else {
82            vec![]
83        };
84
85        let pool = SqliteConnectionPool::new(
86            &self.path,
87            self.mode,
88            join_push_down,
89            attach_databases,
90            self.busy_timeout,
91        )
92        .await?;
93
94        pool.setup().await?;
95
96        Ok(pool)
97    }
98}
99
100#[derive(Debug)]
101pub struct SqliteConnectionPool {
102    conn: Connection,
103    join_push_down: JoinPushDown,
104    mode: Mode,
105    path: Arc<str>,
106    attach_databases: Vec<Arc<str>>,
107    busy_timeout: Duration,
108}
109
110impl SqliteConnectionPool {
111    /// Creates a new instance of `SqliteConnectionPool`.
112    ///
113    /// NOTE: The `SqliteConnectionPool` currently does no connection pooling, it simply creates a new connection
114    /// and clones it on each call to `connect()`.
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if there is a problem creating the connection pool.
119    #[allow(clippy::needless_pass_by_value)]
120    pub async fn new(
121        path: &str,
122        mode: Mode,
123        join_push_down: JoinPushDown,
124        attach_databases: Vec<Arc<str>>,
125        busy_timeout: Duration,
126    ) -> Result<Self> {
127        let conn = match mode {
128            Mode::Memory => Connection::open_in_memory()
129                .await
130                .context(ConnectionPoolSnafu)?,
131
132            Mode::File => Connection::open(path.to_string())
133                .await
134                .context(ConnectionPoolSnafu)?,
135        };
136
137        Ok(SqliteConnectionPool {
138            conn,
139            join_push_down,
140            mode,
141            attach_databases,
142            path: path.into(),
143            busy_timeout,
144        })
145    }
146
147    /// Initializes an SQLite database on-disk without creating a connection pool.
148    /// No-op if the database is in-memory.
149    pub async fn init(path: &str, mode: Mode) -> Result<()> {
150        if mode == Mode::File {
151            Connection::open(path.to_string())
152                .await
153                .context(ConnectionPoolSnafu)?;
154        }
155
156        Ok(())
157    }
158
159    pub async fn setup(&self) -> Result<()> {
160        let conn = self.conn.clone();
161        let busy_timeout = self.busy_timeout;
162
163        // these configuration options are only applicable for file-mode databases
164        if self.mode == Mode::File {
165            // change transaction mode to Write-Ahead log instead of default atomic rollback journal: https://www.sqlite.org/wal.html
166            // NOTE: This is a no-op if the database is in-memory, as only MEMORY or OFF are supported: https://www.sqlite.org/pragma.html#pragma_journal_mode
167            conn.call(move |conn| {
168                conn.pragma_update(None, "journal_mode", "WAL")?;
169                conn.pragma_update(None, "synchronous", "NORMAL")?;
170                conn.pragma_update(None, "cache_size", "-20000")?;
171                conn.pragma_update(None, "foreign_keys", "true")?;
172                conn.pragma_update(None, "temp_store", "memory")?;
173                // conn.set_transaction_behavior(TransactionBehavior::Immediate); introduced in rustqlite 0.32.1, but tokio-rusqlite is still on 0.31.0
174
175                // Set user configurable connection timeout
176                conn.busy_timeout(busy_timeout)?;
177
178                Ok(())
179            })
180            .await
181            .context(ConnectionPoolSnafu)?;
182
183            // database attachments are only supported for file-mode databases
184            #[cfg(feature = "sqlite-federation")]
185            {
186                let attach_databases = self
187                    .attach_databases
188                    .iter()
189                    .enumerate()
190                    .map(|(i, db)| format!("ATTACH DATABASE '{db}' AS attachment_{i}"));
191
192                for attachment in attach_databases {
193                    if attachment == *self.path {
194                        continue;
195                    }
196
197                    conn.call(move |conn| {
198                        conn.execute(&attachment, [])?;
199                        Ok(())
200                    })
201                    .await
202                    .context(ConnectionPoolSnafu)?;
203                }
204
205                Ok::<(), super::Error>(())
206            }?;
207        }
208
209        Ok(())
210    }
211
212    #[must_use]
213    pub fn connect_sync(&self) -> Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>> {
214        Box::new(SqliteConnection::new(self.conn.clone()))
215    }
216
217    /// Will attempt to clone the connection pool. This will always succeed for in-memory mode.
218    /// For file-mode, it will attempt to create a new connection pool with the same configuration.
219    ///
220    /// Due to the way the connection pool is implemented, it doesn't allow multiple concurrent reads/writes
221    /// using the same connection pool instance.
222    pub async fn try_clone(&self) -> Result<Self> {
223        match self.mode {
224            Mode::Memory => Ok(SqliteConnectionPool {
225                conn: self.conn.clone(),
226                join_push_down: self.join_push_down.clone(),
227                mode: self.mode,
228                path: Arc::clone(&self.path),
229                attach_databases: self.attach_databases.clone(),
230                busy_timeout: self.busy_timeout,
231            }),
232            Mode::File => {
233                let attach_databases = if self.attach_databases.is_empty() {
234                    None
235                } else {
236                    Some(self.attach_databases.clone())
237                };
238
239                SqliteConnectionPoolFactory::new(&self.path, self.mode, self.busy_timeout)
240                    .with_databases(attach_databases)
241                    .build()
242                    .await
243            }
244        }
245    }
246}
247
248#[async_trait]
249impl DbConnectionPool<Connection, &'static (dyn ToSql + Sync)> for SqliteConnectionPool {
250    async fn connect(
251        &self,
252    ) -> Result<Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>>> {
253        let conn = self.conn.clone();
254
255        Ok(Box::new(SqliteConnection::new(conn)))
256    }
257
258    fn join_push_down(&self) -> JoinPushDown {
259        self.join_push_down.clone()
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::db_connection_pool::Mode;
    use rand::Rng;
    use std::time::Duration;

    /// Returns a random relative path of the form `./<10 lowercase letters>.sqlite`.
    fn random_db_name() -> String {
        let mut rng = rand::rng();
        let name: String = (0..10)
            .map(|_| char::from(rng.random_range(b'a'..=b'z')))
            .collect();

        format!("./{name}.sqlite")
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory() {
        let db_name = random_db_name();
        let factory =
            SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_secs(5));
        let pool = factory.build().await.unwrap();

        assert_eq!(pool.join_push_down, JoinPushDown::AllowedFor(db_name.clone()));
        assert_eq!(pool.mode, Mode::File);
        assert_eq!(pool.path, db_name.clone().into());

        drop(pool);

        // cleanup
        std::fs::remove_file(&db_name).unwrap();
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_with_attachments() {
        let mut db_names = [random_db_name(), random_db_name(), random_db_name()];
        db_names.sort();

        let factory =
            SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000))
                .with_databases(Some(vec![
                    db_names[1].clone().into(),
                    db_names[2].clone().into(),
                ]));

        SqliteConnectionPool::init(&db_names[1], Mode::File)
            .await
            .unwrap();
        SqliteConnectionPool::init(&db_names[2], Mode::File)
            .await
            .unwrap();

        let pool = factory.build().await.unwrap();

        let push_down = db_names.join(";");

        assert_eq!(pool.join_push_down, JoinPushDown::AllowedFor(push_down));
        assert_eq!(pool.mode, Mode::File);
        assert_eq!(pool.path, db_names[0].clone().into());

        drop(pool);

        // cleanup
        for db in &db_names {
            std::fs::remove_file(db).unwrap();
        }
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_with_empty_attachments() {
        let db_name = random_db_name();
        let factory =
            SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_millis(5000))
                .with_databases(Some(vec![]));

        let pool = factory.build().await.unwrap();

        assert_eq!(pool.join_push_down, JoinPushDown::AllowedFor(db_name.clone()));
        assert_eq!(pool.mode, Mode::File);
        assert_eq!(pool.path, db_name.clone().into());

        drop(pool);

        // cleanup
        std::fs::remove_file(&db_name).unwrap();
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_memory_with_attachments() {
        let factory = SqliteConnectionPoolFactory::new(
            "./test.sqlite",
            Mode::Memory,
            Duration::from_millis(5000),
        )
        .with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()]));
        let pool = factory.build().await.unwrap();

        assert_eq!(pool.join_push_down, JoinPushDown::AllowedFor("memory".to_string()));
        assert_eq!(pool.mode, Mode::Memory);
        assert_eq!(pool.path, "./test.sqlite".into());

        drop(pool);

        // in memory mode, attachments are not created and nothing happens
        assert!(std::fs::metadata("./test.sqlite").is_err());
        assert!(std::fs::metadata("./test1.sqlite").is_err());
        assert!(std::fs::metadata("./test2.sqlite").is_err());
    }

    #[tokio::test]
    async fn test_sqlite_connection_pool_factory_errors_with_missing_attachments() {
        let mut db_names = [random_db_name(), random_db_name(), random_db_name()];
        db_names.sort();

        let factory =
            SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000))
                .with_databases(Some(vec![
                    db_names[1].clone().into(),
                    db_names[2].clone().into(),
                ]));

        let err = factory
            .build()
            .await
            .expect_err("building with missing attachments should fail");

        assert!(err.to_string().contains(&format!(
            "Database to attach does not exist: {}",
            db_names[1]
        )));
    }
}