// datafusion_table_providers/sql/db_connection_pool/duckdbpool.rs

1use async_trait::async_trait;
2use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager};
3use snafu::{prelude::*, ResultExt};
4use std::sync::Arc;
5
6use super::{
7    dbconnection::duckdbconn::{DuckDBAttachments, DuckDBParameter},
8    DbConnectionPool, Mode, Result,
9};
10use crate::{
11    sql::db_connection_pool::{
12        dbconnection::{duckdbconn::DuckDbConnection, DbConnection, SyncDbConnection},
13        JoinPushDown,
14    },
15    UnsupportedTypeAction,
16};
17
18#[derive(Debug, Snafu)]
19pub enum Error {
20    #[snafu(display("DuckDB connection failed.\n{source}\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/"))]
21    DuckDBConnectionError { source: duckdb::Error },
22
23    #[snafu(display(
24        "DuckDB connection failed.\n{source}\nAdjust the DuckDB connection pool parameters for sufficient capacity."
25    ))]
26    ConnectionPoolError { source: r2d2::Error },
27
28    #[snafu(display(
29        "Invalid DuckDB file path: {path}. Ensure it contains a valid database name."
30    ))]
31    UnableToExtractDatabaseNameFromPath { path: Arc<str> },
32}
33
34pub struct DuckDbConnectionPoolBuilder {
35    path: String,
36    max_size: Option<u32>,
37    access_mode: AccessMode,
38    min_idle: Option<u32>,
39    mode: Mode,
40    connection_setup_queries: Vec<Arc<str>>,
41}
42
43impl DuckDbConnectionPoolBuilder {
44    pub fn memory() -> Self {
45        Self {
46            path: String::default(),
47            max_size: None,
48            access_mode: AccessMode::ReadWrite,
49            min_idle: None,
50            mode: Mode::Memory,
51            connection_setup_queries: Vec::new(),
52        }
53    }
54
55    pub fn file(path: &str) -> Self {
56        Self {
57            path: path.to_string(),
58            max_size: None,
59            access_mode: AccessMode::ReadWrite,
60            min_idle: None,
61            mode: Mode::File,
62            connection_setup_queries: Vec::new(),
63        }
64    }
65
66    pub fn get_path(&self) -> String {
67        self.path.clone()
68    }
69
70    pub fn get_mode(&self) -> Mode {
71        self.mode
72    }
73
74    pub fn with_max_size(mut self, size: Option<u32>) -> Self {
75        self.max_size = size;
76        self
77    }
78
79    pub fn with_access_mode(mut self, access_mode: AccessMode) -> Self {
80        self.access_mode = access_mode;
81        self
82    }
83
84    pub fn with_min_idle(mut self, min_idle: Option<u32>) -> Self {
85        self.min_idle = min_idle;
86        self
87    }
88
89    pub fn with_connection_setup_query(mut self, query: impl Into<Arc<str>>) -> Self {
90        self.connection_setup_queries.push(query.into());
91        self
92    }
93
94    fn build_memory_pool(self) -> Result<DuckDbConnectionPool> {
95        let config = get_config(&AccessMode::ReadWrite)?;
96        let manager =
97            DuckdbConnectionManager::memory_with_flags(config).context(DuckDBConnectionSnafu)?;
98
99        let mut pool_builder = r2d2::Pool::builder();
100
101        if let Some(size) = self.max_size {
102            pool_builder = pool_builder.max_size(size)
103        }
104        if self.min_idle.is_some() {
105            pool_builder = pool_builder.min_idle(self.min_idle)
106        }
107
108        let pool = Arc::new(pool_builder.build(manager).context(ConnectionPoolSnafu)?);
109
110        let conn = pool.get().context(ConnectionPoolSnafu)?;
111        conn.register_table_function::<ArrowVTab>("arrow")
112            .context(DuckDBConnectionSnafu)?;
113
114        test_connection(&conn)?;
115
116        Ok(DuckDbConnectionPool {
117            path: ":memory:".into(),
118            pool,
119            join_push_down: JoinPushDown::AllowedFor(":memory:".to_string()),
120            attached_databases: Vec::new(),
121            mode: Mode::Memory,
122            unsupported_type_action: UnsupportedTypeAction::Error,
123            connection_setup_queries: self.connection_setup_queries,
124        })
125    }
126
127    fn build_file_pool(self) -> Result<DuckDbConnectionPool> {
128        let config = get_config(&self.access_mode)?;
129        let manager = DuckdbConnectionManager::file_with_flags(&self.path, config)
130            .context(DuckDBConnectionSnafu)?;
131
132        let mut pool_builder = r2d2::Pool::builder();
133
134        if let Some(size) = self.max_size {
135            pool_builder = pool_builder.max_size(size)
136        }
137        if self.min_idle.is_some() {
138            pool_builder = pool_builder.min_idle(self.min_idle)
139        }
140
141        let pool = Arc::new(pool_builder.build(manager).context(ConnectionPoolSnafu)?);
142
143        let conn = pool.get().context(ConnectionPoolSnafu)?;
144        conn.register_table_function::<ArrowVTab>("arrow")
145            .context(DuckDBConnectionSnafu)?;
146
147        test_connection(&conn)?;
148
149        Ok(DuckDbConnectionPool {
150            path: self.path.as_str().into(),
151            pool,
152            // Allow join-push down for any other instances that connect to the same underlying file.
153            join_push_down: JoinPushDown::AllowedFor(self.path),
154            attached_databases: Vec::new(),
155            mode: Mode::File,
156            unsupported_type_action: UnsupportedTypeAction::Error,
157            connection_setup_queries: self.connection_setup_queries,
158        })
159    }
160
161    pub fn build(self) -> Result<DuckDbConnectionPool> {
162        match self.mode {
163            Mode::Memory => self.build_memory_pool(),
164            Mode::File => self.build_file_pool(),
165        }
166    }
167}
168
169#[derive(Clone)]
170pub struct DuckDbConnectionPool {
171    path: Arc<str>,
172    pool: Arc<r2d2::Pool<DuckdbConnectionManager>>,
173    join_push_down: JoinPushDown,
174    attached_databases: Vec<Arc<str>>,
175    mode: Mode,
176    unsupported_type_action: UnsupportedTypeAction,
177    connection_setup_queries: Vec<Arc<str>>,
178}
179
180impl std::fmt::Debug for DuckDbConnectionPool {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("DuckDbConnectionPool")
183            .field("path", &self.path)
184            .field("join_push_down", &self.join_push_down)
185            .field("attached_databases", &self.attached_databases)
186            .field("mode", &self.mode)
187            .field("unsupported_type_action", &self.unsupported_type_action)
188            .finish()
189    }
190}
191
192impl DuckDbConnectionPool {
193    /// Get the dataset path. Returns `:memory:` if the in memory database is used.
194    pub fn db_path(&self) -> &str {
195        self.path.as_ref()
196    }
197
198    /// Create a new `DuckDbConnectionPool` from memory.
199    ///
200    /// # Arguments
201    ///
202    /// * `access_mode` - The access mode for the connection pool
203    ///
204    /// # Returns
205    ///
206    /// * A new `DuckDbConnectionPool`
207    ///
208    /// # Errors
209    ///
210    /// * `DuckDBConnectionSnafu` - If there is an error creating the connection pool
211    /// * `ConnectionPoolSnafu` - If there is an error creating the connection pool
212    pub fn new_memory() -> Result<Self> {
213        DuckDbConnectionPoolBuilder::memory().build()
214    }
215
216    /// Create a new `DuckDbConnectionPool` from a file.
217    ///
218    /// # Arguments
219    ///
220    /// * `path` - The path to the file
221    /// * `access_mode` - The access mode for the connection pool
222    ///
223    /// # Returns
224    ///
225    /// * A new `DuckDbConnectionPool`
226    ///
227    /// # Errors
228    ///
229    /// * `DuckDBConnectionSnafu` - If there is an error creating the connection pool
230    /// * `ConnectionPoolSnafu` - If there is an error creating the connection pool
231    pub fn new_file(path: &str, access_mode: &AccessMode) -> Result<Self> {
232        let access_mode = match access_mode {
233            AccessMode::Automatic => AccessMode::Automatic,
234            AccessMode::ReadOnly => AccessMode::ReadOnly,
235            AccessMode::ReadWrite => AccessMode::ReadWrite,
236        };
237        DuckDbConnectionPoolBuilder::file(path)
238            .with_access_mode(access_mode)
239            .build()
240    }
241
242    #[must_use]
243    pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
244        self.unsupported_type_action = action;
245        self
246    }
247
248    #[must_use]
249    pub fn set_attached_databases(mut self, databases: &[Arc<str>]) -> Self {
250        self.attached_databases = databases.to_vec();
251
252        if !databases.is_empty() {
253            let mut paths = self.attached_databases.clone();
254            paths.push(Arc::clone(&self.path));
255            paths.sort();
256            let push_down_context = paths.join(";");
257            self.join_push_down = JoinPushDown::AllowedFor(push_down_context);
258        }
259
260        self
261    }
262
263    #[must_use]
264    pub fn with_connection_setup_queries(mut self, queries: Vec<Arc<str>>) -> Self {
265        self.connection_setup_queries = queries;
266        self
267    }
268
269    /// Create a new `DuckDbConnectionPool` from a database URL.
270    ///
271    /// # Errors
272    ///
273    /// * `DuckDBConnectionSnafu` - If there is an error creating the connection pool
274    pub fn connect_sync(
275        self: Arc<Self>,
276    ) -> Result<
277        Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
278    > {
279        let pool = Arc::clone(&self.pool);
280        let conn: r2d2::PooledConnection<DuckdbConnectionManager> =
281            pool.get().context(ConnectionPoolSnafu)?;
282
283        let attachments = self.get_attachments()?;
284
285        for query in self.connection_setup_queries.iter() {
286            tracing::debug!("DuckDB connection setup: {}", query);
287            conn.execute(query, []).context(DuckDBConnectionSnafu)?;
288        }
289
290        Ok(Box::new(
291            DuckDbConnection::new(conn)
292                .with_attachments(attachments)
293                .with_connection_setup_queries(self.connection_setup_queries.clone())
294                .with_unsupported_type_action(self.unsupported_type_action),
295        ))
296    }
297
298    #[must_use]
299    pub fn mode(&self) -> Mode {
300        self.mode
301    }
302
303    pub fn get_attachments(&self) -> Result<Option<Arc<DuckDBAttachments>>> {
304        if self.attached_databases.is_empty() {
305            Ok(None)
306        } else {
307            #[cfg(not(feature = "duckdb-federation"))]
308            return Ok(None);
309
310            #[cfg(feature = "duckdb-federation")]
311            Ok(Some(Arc::new(DuckDBAttachments::new(
312                &extract_db_name(Arc::clone(&self.path))?,
313                &self.attached_databases,
314            ))))
315        }
316    }
317}
318
319#[async_trait]
320impl DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
321    for DuckDbConnectionPool
322{
323    async fn connect(
324        &self,
325    ) -> Result<
326        Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
327    > {
328        let pool = Arc::clone(&self.pool);
329        let conn: r2d2::PooledConnection<DuckdbConnectionManager> =
330            pool.get().context(ConnectionPoolSnafu)?;
331
332        let attachments = self.get_attachments()?;
333
334        for query in self.connection_setup_queries.iter() {
335            tracing::debug!("DuckDB connection setup: {}", query);
336            conn.execute(query, []).context(DuckDBConnectionSnafu)?;
337        }
338
339        Ok(Box::new(
340            DuckDbConnection::new(conn)
341                .with_attachments(attachments)
342                .with_connection_setup_queries(self.connection_setup_queries.clone())
343                .with_unsupported_type_action(self.unsupported_type_action),
344        ))
345    }
346
347    fn join_push_down(&self) -> JoinPushDown {
348        self.join_push_down.clone()
349    }
350}
351
352fn test_connection(conn: &r2d2::PooledConnection<DuckdbConnectionManager>) -> Result<()> {
353    conn.execute("SELECT 1", [])
354        .context(DuckDBConnectionSnafu)?;
355    Ok(())
356}
357
358fn get_config(access_mode: &AccessMode) -> Result<duckdb::Config> {
359    let config = duckdb::Config::default()
360        .access_mode(match access_mode {
361            AccessMode::ReadOnly => duckdb::AccessMode::ReadOnly,
362            AccessMode::ReadWrite => duckdb::AccessMode::ReadWrite,
363            AccessMode::Automatic => duckdb::AccessMode::Automatic,
364        })
365        .context(DuckDBConnectionSnafu)?;
366
367    Ok(config)
368}
369
370// Helper function to extract the duckdb database name from the duckdb file path
371fn extract_db_name(file_path: Arc<str>) -> Result<String> {
372    let path = std::path::Path::new(file_path.as_ref());
373
374    let db_name = match path.file_stem().and_then(|name| name.to_str()) {
375        Some(name) => name,
376        None => {
377            return Err(Box::new(Error::UnableToExtractDatabaseNameFromPath {
378                path: file_path,
379            }))
380        }
381    };
382
383    Ok(db_name.to_string())
384}
385
#[cfg(test)]
mod test {
    use super::*;
    use crate::sql::db_connection_pool::DbConnectionPool;

    /// Generate a random `./<ten lowercase letters>.duckdb` path in the working directory.
    ///
    /// Only the federation test uses it, so it is gated to avoid dead-code warnings.
    #[cfg(feature = "duckdb-federation")]
    fn random_db_name() -> String {
        use rand::Rng;

        let mut rng = rand::rng();
        let mut name = String::new();

        for _ in 0..10 {
            name.push(rng.random_range(b'a'..=b'z') as char);
        }

        format!("./{name}.duckdb")
    }

    #[tokio::test]
    async fn test_duckdb_connection_pool() {
        let pool =
            DuckDbConnectionPool::new_memory().expect("DuckDB connection pool to be created");
        let conn = pool
            .connect()
            .await
            .expect("DuckDB connection should be established");
        let conn = conn
            .as_sync()
            .expect("DuckDB connection should be synchronous");

        conn.execute("CREATE TABLE test (a INTEGER, b VARCHAR)", &[])
            .expect("Table should be created");
        conn.execute("INSERT INTO test VALUES (1, 'a')", &[])
            .expect("Data should be inserted");

        conn.query_arrow("SELECT * FROM test", &[], None)
            .expect("Query should be successful");
    }

    #[tokio::test]
    #[cfg(feature = "duckdb-federation")]
    async fn test_duckdb_connection_pool_with_attached_databases() {
        /// Deletes the listed files when dropped.
        ///
        /// This cleans up the database files even when an assertion fails part-way.
        struct RemoveFilesOnDrop(Vec<String>);

        impl Drop for RemoveFilesOnDrop {
            fn drop(&mut self) {
                for path in &self.0 {
                    // Best effort: the file may not exist if the test failed before creating it.
                    let _ = std::fs::remove_file(path);
                }
            }
        }

        let db_base_name = random_db_name();
        let db_attached_name = random_db_name();
        // Declared before the pools and connections so it is dropped after them
        // (locals drop in reverse declaration order).
        let _cleanup = RemoveFilesOnDrop(vec![db_base_name.clone(), db_attached_name.clone()]);

        let pool = DuckDbConnectionPool::new_file(&db_base_name, &AccessMode::ReadWrite)
            .expect("DuckDB connection pool to be created")
            .set_attached_databases(&[Arc::from(db_attached_name.as_str())]);

        let pool_attached =
            DuckDbConnectionPool::new_file(&db_attached_name, &AccessMode::ReadWrite)
                .expect("DuckDB connection pool to be created")
                .set_attached_databases(&[Arc::from(db_base_name.as_str())]);

        let conn = pool
            .pool
            .get()
            .expect("DuckDB connection should be established");

        conn.execute("CREATE TABLE test_one (a INTEGER, b VARCHAR)", [])
            .expect("Table should be created");
        conn.execute("INSERT INTO test_one VALUES (1, 'a')", [])
            .expect("Data should be inserted");

        let conn_attached = pool_attached
            .pool
            .get()
            .expect("DuckDB connection should be established");

        conn_attached
            .execute("CREATE TABLE test_two (a INTEGER, b VARCHAR)", [])
            .expect("Table should be created");
        conn_attached
            .execute("INSERT INTO test_two VALUES (1, 'a')", [])
            .expect("Data should be inserted");

        let conn = pool
            .connect()
            .await
            .expect("DuckDB connection should be established");
        let conn = conn
            .as_sync()
            .expect("DuckDB connection should be synchronous");

        let conn_attached = pool_attached
            .connect()
            .await
            .expect("DuckDB connection should be established");
        let conn_attached = conn_attached
            .as_sync()
            .expect("DuckDB connection should be synchronous");

        // sleep to let writes clear
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        conn.query_arrow("SELECT * FROM test_one", &[], None)
            .expect("Query should be successful");

        conn_attached
            .query_arrow("SELECT * FROM test_two", &[], None)
            .expect("Query should be successful");

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        conn_attached
            .query_arrow("SELECT * FROM test_one", &[], None)
            .expect("Query should be successful");

        conn.query_arrow("SELECT * FROM test_two", &[], None)
            .expect("Query should be successful");
    }
}
499}