datafusion_table_providers/sql/db_connection_pool/
duckdbpool.rs1use async_trait::async_trait;
2use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager};
3use snafu::{prelude::*, ResultExt};
4use std::sync::Arc;
5
6use super::{
7 dbconnection::duckdbconn::{DuckDBAttachments, DuckDBParameter},
8 DbConnectionPool, Mode, Result,
9};
10use crate::{
11 sql::db_connection_pool::{
12 dbconnection::{duckdbconn::DuckDbConnection, DbConnection, SyncDbConnection},
13 JoinPushDown,
14 },
15 UnsupportedTypeAction,
16};
17
18#[derive(Debug, Snafu)]
19pub enum Error {
20 #[snafu(display("DuckDB connection failed.\n{source}\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/"))]
21 DuckDBConnectionError { source: duckdb::Error },
22
23 #[snafu(display(
24 "DuckDB connection failed.\n{source}\nAdjust the DuckDB connection pool parameters for sufficient capacity."
25 ))]
26 ConnectionPoolError { source: r2d2::Error },
27
28 #[snafu(display(
29 "Invalid DuckDB file path: {path}. Ensure it contains a valid database name."
30 ))]
31 UnableToExtractDatabaseNameFromPath { path: Arc<str> },
32}
33
34pub struct DuckDbConnectionPoolBuilder {
35 path: String,
36 max_size: Option<u32>,
37 access_mode: AccessMode,
38 min_idle: Option<u32>,
39 mode: Mode,
40 connection_setup_queries: Vec<Arc<str>>,
41}
42
43impl DuckDbConnectionPoolBuilder {
44 pub fn memory() -> Self {
45 Self {
46 path: String::default(),
47 max_size: None,
48 access_mode: AccessMode::ReadWrite,
49 min_idle: None,
50 mode: Mode::Memory,
51 connection_setup_queries: Vec::new(),
52 }
53 }
54
55 pub fn file(path: &str) -> Self {
56 Self {
57 path: path.to_string(),
58 max_size: None,
59 access_mode: AccessMode::ReadWrite,
60 min_idle: None,
61 mode: Mode::File,
62 connection_setup_queries: Vec::new(),
63 }
64 }
65
66 pub fn get_path(&self) -> String {
67 self.path.clone()
68 }
69
70 pub fn get_mode(&self) -> Mode {
71 self.mode
72 }
73
74 pub fn with_max_size(mut self, size: Option<u32>) -> Self {
75 self.max_size = size;
76 self
77 }
78
79 pub fn with_access_mode(mut self, access_mode: AccessMode) -> Self {
80 self.access_mode = access_mode;
81 self
82 }
83
84 pub fn with_min_idle(mut self, min_idle: Option<u32>) -> Self {
85 self.min_idle = min_idle;
86 self
87 }
88
89 pub fn with_connection_setup_query(mut self, query: impl Into<Arc<str>>) -> Self {
90 self.connection_setup_queries.push(query.into());
91 self
92 }
93
94 fn build_memory_pool(self) -> Result<DuckDbConnectionPool> {
95 let config = get_config(&AccessMode::ReadWrite)?;
96 let manager =
97 DuckdbConnectionManager::memory_with_flags(config).context(DuckDBConnectionSnafu)?;
98
99 let mut pool_builder = r2d2::Pool::builder();
100
101 if let Some(size) = self.max_size {
102 pool_builder = pool_builder.max_size(size)
103 }
104 if self.min_idle.is_some() {
105 pool_builder = pool_builder.min_idle(self.min_idle)
106 }
107
108 let pool = Arc::new(pool_builder.build(manager).context(ConnectionPoolSnafu)?);
109
110 let conn = pool.get().context(ConnectionPoolSnafu)?;
111 conn.register_table_function::<ArrowVTab>("arrow")
112 .context(DuckDBConnectionSnafu)?;
113
114 test_connection(&conn)?;
115
116 Ok(DuckDbConnectionPool {
117 path: ":memory:".into(),
118 pool,
119 join_push_down: JoinPushDown::AllowedFor(":memory:".to_string()),
120 attached_databases: Vec::new(),
121 mode: Mode::Memory,
122 unsupported_type_action: UnsupportedTypeAction::Error,
123 connection_setup_queries: self.connection_setup_queries,
124 })
125 }
126
127 fn build_file_pool(self) -> Result<DuckDbConnectionPool> {
128 let config = get_config(&self.access_mode)?;
129 let manager = DuckdbConnectionManager::file_with_flags(&self.path, config)
130 .context(DuckDBConnectionSnafu)?;
131
132 let mut pool_builder = r2d2::Pool::builder();
133
134 if let Some(size) = self.max_size {
135 pool_builder = pool_builder.max_size(size)
136 }
137 if self.min_idle.is_some() {
138 pool_builder = pool_builder.min_idle(self.min_idle)
139 }
140
141 let pool = Arc::new(pool_builder.build(manager).context(ConnectionPoolSnafu)?);
142
143 let conn = pool.get().context(ConnectionPoolSnafu)?;
144 conn.register_table_function::<ArrowVTab>("arrow")
145 .context(DuckDBConnectionSnafu)?;
146
147 test_connection(&conn)?;
148
149 Ok(DuckDbConnectionPool {
150 path: self.path.as_str().into(),
151 pool,
152 join_push_down: JoinPushDown::AllowedFor(self.path),
154 attached_databases: Vec::new(),
155 mode: Mode::File,
156 unsupported_type_action: UnsupportedTypeAction::Error,
157 connection_setup_queries: self.connection_setup_queries,
158 })
159 }
160
161 pub fn build(self) -> Result<DuckDbConnectionPool> {
162 match self.mode {
163 Mode::Memory => self.build_memory_pool(),
164 Mode::File => self.build_file_pool(),
165 }
166 }
167}
168
169#[derive(Clone)]
170pub struct DuckDbConnectionPool {
171 path: Arc<str>,
172 pool: Arc<r2d2::Pool<DuckdbConnectionManager>>,
173 join_push_down: JoinPushDown,
174 attached_databases: Vec<Arc<str>>,
175 mode: Mode,
176 unsupported_type_action: UnsupportedTypeAction,
177 connection_setup_queries: Vec<Arc<str>>,
178}
179
180impl std::fmt::Debug for DuckDbConnectionPool {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("DuckDbConnectionPool")
183 .field("path", &self.path)
184 .field("join_push_down", &self.join_push_down)
185 .field("attached_databases", &self.attached_databases)
186 .field("mode", &self.mode)
187 .field("unsupported_type_action", &self.unsupported_type_action)
188 .finish()
189 }
190}
191
192impl DuckDbConnectionPool {
193 pub fn db_path(&self) -> &str {
195 self.path.as_ref()
196 }
197
198 pub fn new_memory() -> Result<Self> {
213 DuckDbConnectionPoolBuilder::memory().build()
214 }
215
216 pub fn new_file(path: &str, access_mode: &AccessMode) -> Result<Self> {
232 let access_mode = match access_mode {
233 AccessMode::Automatic => AccessMode::Automatic,
234 AccessMode::ReadOnly => AccessMode::ReadOnly,
235 AccessMode::ReadWrite => AccessMode::ReadWrite,
236 };
237 DuckDbConnectionPoolBuilder::file(path)
238 .with_access_mode(access_mode)
239 .build()
240 }
241
242 #[must_use]
243 pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
244 self.unsupported_type_action = action;
245 self
246 }
247
248 #[must_use]
249 pub fn set_attached_databases(mut self, databases: &[Arc<str>]) -> Self {
250 self.attached_databases = databases.to_vec();
251
252 if !databases.is_empty() {
253 let mut paths = self.attached_databases.clone();
254 paths.push(Arc::clone(&self.path));
255 paths.sort();
256 let push_down_context = paths.join(";");
257 self.join_push_down = JoinPushDown::AllowedFor(push_down_context);
258 }
259
260 self
261 }
262
263 #[must_use]
264 pub fn with_connection_setup_queries(mut self, queries: Vec<Arc<str>>) -> Self {
265 self.connection_setup_queries = queries;
266 self
267 }
268
269 pub fn connect_sync(
275 self: Arc<Self>,
276 ) -> Result<
277 Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
278 > {
279 let pool = Arc::clone(&self.pool);
280 let conn: r2d2::PooledConnection<DuckdbConnectionManager> =
281 pool.get().context(ConnectionPoolSnafu)?;
282
283 let attachments = self.get_attachments()?;
284
285 for query in self.connection_setup_queries.iter() {
286 tracing::debug!("DuckDB connection setup: {}", query);
287 conn.execute(query, []).context(DuckDBConnectionSnafu)?;
288 }
289
290 Ok(Box::new(
291 DuckDbConnection::new(conn)
292 .with_attachments(attachments)
293 .with_connection_setup_queries(self.connection_setup_queries.clone())
294 .with_unsupported_type_action(self.unsupported_type_action),
295 ))
296 }
297
298 #[must_use]
299 pub fn mode(&self) -> Mode {
300 self.mode
301 }
302
303 pub fn get_attachments(&self) -> Result<Option<Arc<DuckDBAttachments>>> {
304 if self.attached_databases.is_empty() {
305 Ok(None)
306 } else {
307 #[cfg(not(feature = "duckdb-federation"))]
308 return Ok(None);
309
310 #[cfg(feature = "duckdb-federation")]
311 Ok(Some(Arc::new(DuckDBAttachments::new(
312 &extract_db_name(Arc::clone(&self.path))?,
313 &self.attached_databases,
314 ))))
315 }
316 }
317}
318
319#[async_trait]
320impl DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
321 for DuckDbConnectionPool
322{
323 async fn connect(
324 &self,
325 ) -> Result<
326 Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
327 > {
328 let pool = Arc::clone(&self.pool);
329 let conn: r2d2::PooledConnection<DuckdbConnectionManager> =
330 pool.get().context(ConnectionPoolSnafu)?;
331
332 let attachments = self.get_attachments()?;
333
334 for query in self.connection_setup_queries.iter() {
335 tracing::debug!("DuckDB connection setup: {}", query);
336 conn.execute(query, []).context(DuckDBConnectionSnafu)?;
337 }
338
339 Ok(Box::new(
340 DuckDbConnection::new(conn)
341 .with_attachments(attachments)
342 .with_connection_setup_queries(self.connection_setup_queries.clone())
343 .with_unsupported_type_action(self.unsupported_type_action),
344 ))
345 }
346
347 fn join_push_down(&self) -> JoinPushDown {
348 self.join_push_down.clone()
349 }
350}
351
352fn test_connection(conn: &r2d2::PooledConnection<DuckdbConnectionManager>) -> Result<()> {
353 conn.execute("SELECT 1", [])
354 .context(DuckDBConnectionSnafu)?;
355 Ok(())
356}
357
358fn get_config(access_mode: &AccessMode) -> Result<duckdb::Config> {
359 let config = duckdb::Config::default()
360 .access_mode(match access_mode {
361 AccessMode::ReadOnly => duckdb::AccessMode::ReadOnly,
362 AccessMode::ReadWrite => duckdb::AccessMode::ReadWrite,
363 AccessMode::Automatic => duckdb::AccessMode::Automatic,
364 })
365 .context(DuckDBConnectionSnafu)?;
366
367 Ok(config)
368}
369
370fn extract_db_name(file_path: Arc<str>) -> Result<String> {
372 let path = std::path::Path::new(file_path.as_ref());
373
374 let db_name = match path.file_stem().and_then(|name| name.to_str()) {
375 Some(name) => name,
376 None => {
377 return Err(Box::new(Error::UnableToExtractDatabaseNameFromPath {
378 path: file_path,
379 }))
380 }
381 };
382
383 Ok(db_name.to_string())
384}
385
#[cfg(test)]
mod test {
    use super::*;
    use crate::sql::db_connection_pool::DbConnectionPool;

    /// Returns a relative path `./<10 random lowercase letters>.duckdb`.
    ///
    /// Gated on `duckdb-federation` because only the federation test uses it;
    /// without the gate, test builds lacking that feature would warn about
    /// dead code and an unused `rand::Rng` import.
    #[cfg(feature = "duckdb-federation")]
    fn random_db_name() -> String {
        use rand::Rng;

        let mut rng = rand::rng();
        let mut name = String::with_capacity(10);

        for _ in 0..10 {
            name.push(rng.random_range(b'a'..=b'z') as char);
        }

        format!("./{name}.duckdb")
    }

    /// An in-memory pool hands out a working synchronous connection.
    #[tokio::test]
    async fn test_duckdb_connection_pool() {
        let pool =
            DuckDbConnectionPool::new_memory().expect("DuckDB connection pool to be created");
        let conn = pool
            .connect()
            .await
            .expect("DuckDB connection should be established");
        let conn = conn
            .as_sync()
            .expect("DuckDB connection should be synchronous");

        conn.execute("CREATE TABLE test (a INTEGER, b VARCHAR)", &[])
            .expect("Table should be created");
        conn.execute("INSERT INTO test VALUES (1, 'a')", &[])
            .expect("Data should be inserted");

        conn.query_arrow("SELECT * FROM test", &[], None)
            .expect("Query should be successful");
    }

    /// Two file-backed pools that attach each other can each query the other's table.
    #[tokio::test]
    #[cfg(feature = "duckdb-federation")]
    async fn test_duckdb_connection_pool_with_attached_databases() {
        let db_base_name = random_db_name();
        let db_attached_name = random_db_name();
        let pool = DuckDbConnectionPool::new_file(&db_base_name, &AccessMode::ReadWrite)
            .expect("DuckDB connection pool to be created")
            .set_attached_databases(&[Arc::from(db_attached_name.as_str())]);

        let pool_attached =
            DuckDbConnectionPool::new_file(&db_attached_name, &AccessMode::ReadWrite)
                .expect("DuckDB connection pool to be created")
                .set_attached_databases(&[Arc::from(db_base_name.as_str())]);

        // Populate each database through a raw pooled connection first.
        let conn = pool
            .pool
            .get()
            .expect("DuckDB connection should be established");

        conn.execute("CREATE TABLE test_one (a INTEGER, b VARCHAR)", [])
            .expect("Table should be created");
        conn.execute("INSERT INTO test_one VALUES (1, 'a')", [])
            .expect("Data should be inserted");

        let conn_attached = pool_attached
            .pool
            .get()
            .expect("DuckDB connection should be established");

        conn_attached
            .execute("CREATE TABLE test_two (a INTEGER, b VARCHAR)", [])
            .expect("Table should be created");
        conn_attached
            .execute("INSERT INTO test_two VALUES (1, 'a')", [])
            .expect("Data should be inserted");

        // Re-acquire connections through the pool API so attachments are applied.
        let conn = pool
            .connect()
            .await
            .expect("DuckDB connection should be established");
        let conn = conn
            .as_sync()
            .expect("DuckDB connection should be synchronous");

        let conn_attached = pool_attached
            .connect()
            .await
            .expect("DuckDB connection should be established");
        let conn_attached = conn_attached
            .as_sync()
            .expect("DuckDB connection should be synchronous");

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Each connection sees its own table...
        conn.query_arrow("SELECT * FROM test_one", &[], None)
            .expect("Query should be successful");

        conn_attached
            .query_arrow("SELECT * FROM test_two", &[], None)
            .expect("Query should be successful");

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // ...and the table that lives in the other, attached database.
        conn_attached
            .query_arrow("SELECT * FROM test_one", &[], None)
            .expect("Query should be successful");

        conn.query_arrow("SELECT * FROM test_two", &[], None)
            .expect("Query should be successful");

        std::fs::remove_file(&db_base_name).expect("File should be removed");
        std::fs::remove_file(&db_attached_name).expect("File should be removed");
    }
}