// datafusion_table_providers/duckdb.rs

1use crate::sql::sql_provider_datafusion;
2use crate::util::{
3    self,
4    column_reference::{self, ColumnReference},
5    constraints,
6    indexes::IndexType,
7    on_conflict::{self, OnConflict},
8};
9use crate::{
10    sql::db_connection_pool::{
11        self,
12        dbconnection::{
13            duckdbconn::{
14                flatten_table_function_name, is_table_function, DuckDBParameter, DuckDbConnection,
15            },
16            get_schema, DbConnection,
17        },
18        duckdbpool::{DuckDbConnectionPool, DuckDbConnectionPoolBuilder},
19        DbConnectionPool, DbInstanceKey, Mode,
20    },
21    UnsupportedTypeAction,
22};
23use arrow::datatypes::SchemaRef;
24use async_trait::async_trait;
25use creator::TableManager;
26use datafusion::sql::unparser::dialect::{Dialect, DuckDBDialect};
27use datafusion::{
28    catalog::{Session, TableProviderFactory},
29    common::Constraints,
30    datasource::TableProvider,
31    error::{DataFusionError, Result as DataFusionResult},
32    logical_expr::CreateExternalTable,
33    sql::TableReference,
34};
35use duckdb::{AccessMode, DuckdbConnectionManager};
36use itertools::Itertools;
37use snafu::prelude::*;
38use std::{collections::HashMap, sync::Arc};
39use tokio::sync::Mutex;
40use write::DuckDBTableWriterBuilder;
41
42pub use self::settings::{
43    DuckDBSetting, DuckDBSettingScope, DuckDBSettingsRegistry, MemoryLimitSetting,
44    PreserveInsertionOrderSetting, TempDirectorySetting,
45};
46use self::sql_table::DuckDBTable;
47
48#[cfg(feature = "duckdb-federation")]
49mod federation;
50
51mod creator;
52mod settings;
53mod sql_table;
54pub mod write;
55pub use creator::{RelationName, TableDefinition};
56
/// Errors produced by the DuckDB table provider: connection/pool failures,
/// DDL/DML failures surfaced from the `duckdb` crate, option-parsing errors,
/// and writer-builder validation errors.
#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("DbConnectionError: {source}"))]
    DbConnectionError {
        source: db_connection_pool::dbconnection::GenericError,
    },

    #[snafu(display("DbConnectionPoolError: {source}"))]
    DbConnectionPoolError { source: db_connection_pool::Error },

    #[snafu(display("DuckDBDataFusionError: {source}"))]
    DuckDBDataFusion {
        source: sql_provider_datafusion::Error,
    },

    #[snafu(display("Unable to downcast DbConnection to DuckDbConnection"))]
    UnableToDowncastDbConnection {},

    #[snafu(display("Unable to drop duckdb table: {source}"))]
    UnableToDropDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to create duckdb table: {source}"))]
    UnableToCreateDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to create index on duckdb table: {source}"))]
    UnableToCreateIndexOnDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to retrieve existing primary keys from DuckDB table: {source}"))]
    UnableToGetPrimaryKeysOnDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to drop index on duckdb table: {source}"))]
    UnableToDropIndexOnDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to rename duckdb table: {source}"))]
    UnableToRenameDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to insert into duckdb table: {source}"))]
    UnableToInsertToDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to get appender to duckdb table: {source}"))]
    UnableToGetAppenderToDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to delete data from the duckdb table: {source}"))]
    UnableToDeleteDuckdbData { source: duckdb::Error },

    #[snafu(display("Unable to query data from the duckdb table: {source}"))]
    UnableToQueryData { source: duckdb::Error },

    #[snafu(display("Unable to commit transaction: {source}"))]
    UnableToCommitTransaction { source: duckdb::Error },

    #[snafu(display("Unable to begin duckdb transaction: {source}"))]
    UnableToBeginTransaction { source: duckdb::Error },

    #[snafu(display("Unable to rollback transaction: {source}"))]
    UnableToRollbackTransaction { source: duckdb::Error },

    #[snafu(display("Unable to delete all data from the DuckDB table: {source}"))]
    UnableToDeleteAllTableData { source: duckdb::Error },

    #[snafu(display("Unable to insert data into the DuckDB table: {source}"))]
    UnableToInsertIntoTableAsync { source: duckdb::Error },

    #[snafu(display("The table '{table_name}' doesn't exist in the DuckDB server"))]
    TableDoesntExist { table_name: String },

    #[snafu(display("Constraint Violation: {source}"))]
    ConstraintViolation { source: constraints::Error },

    #[snafu(display("Error parsing column reference: {source}"))]
    UnableToParseColumnReference { source: column_reference::Error },

    #[snafu(display("Error parsing on_conflict: {source}"))]
    UnableToParseOnConflict { source: on_conflict::Error },

    #[snafu(display(
        "Failed to create '{table_name}': creating a table with a schema is not supported"
    ))]
    TableWithSchemaCreationNotSupported { table_name: String },

    #[snafu(display("Failed to parse memory_limit value '{value}': {source}\nProvide a valid value, e.g. '2GB', '512MiB' (expected: KB, MB, GB, TB for 1000^i units or KiB, MiB, GiB, TiB for 1024^i units)"))]
    UnableToParseMemoryLimit {
        value: String,
        source: byte_unit::ParseError,
    },

    #[snafu(display("Unable to add primary key to table: {source}"))]
    UnableToAddPrimaryKey { source: duckdb::Error },

    #[snafu(display("Failed to get system time since epoch: {source}"))]
    UnableToGetSystemTime { source: std::time::SystemTimeError },

    #[snafu(display("Failed to parse the system time: {source}"))]
    UnableToParseSystemTime { source: std::num::ParseIntError },

    #[snafu(display("A read provider is required to create a DuckDBTableWriter"))]
    MissingReadProvider,

    #[snafu(display("A pool is required to create a DuckDBTableWriter"))]
    MissingPool,

    #[snafu(display("A table definition is required to create a DuckDBTableWriter"))]
    MissingTableDefinition,

    #[snafu(display("Failed to register Arrow scan view for DuckDB ingestion: {source}"))]
    UnableToRegisterArrowScanView { source: duckdb::Error },

    #[snafu(display("Failed to register Arrow scan view to build table creation statement: {source}"))]
    UnableToRegisterArrowScanViewForTableCreation { source: duckdb::Error },

    #[snafu(display("Failed to drop Arrow scan view for DuckDB ingestion: {source}"))]
    UnableToDropArrowScanView { source: duckdb::Error },
}
170
/// Module-local result alias defaulting the error type to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;

/// Option key holding the path of the DuckDB database file to open.
const DUCKDB_DB_PATH_PARAM: &str = "open";
/// Option key for the base folder used to derive the default database file path.
const DUCKDB_DB_BASE_FOLDER_PARAM: &str = "data_directory";
/// Option key listing `;`-separated databases to attach to read connections.
const DUCKDB_ATTACH_DATABASES_PARAM: &str = "attach_databases";
176
/// [`TableProviderFactory`] that creates DuckDB-backed table providers.
///
/// Connection pools are cached per database instance (file path or in-memory)
/// in `instances`, so repeated `CREATE EXTERNAL TABLE` calls share pools.
pub struct DuckDBTableProviderFactory {
    /// Access mode applied to every pool built by this factory.
    access_mode: AccessMode,
    /// Cache of connection pools keyed by database instance.
    instances: Arc<Mutex<HashMap<DbInstanceKey, DuckDbConnectionPool>>>,
    /// How to react when a column type is not supported.
    unsupported_type_action: UnsupportedTypeAction,
    /// SQL dialect used when unparsing DataFusion plans.
    dialect: Arc<dyn Dialect>,
    /// Registry of DuckDB settings recognized in table options.
    settings_registry: DuckDBSettingsRegistry,
}
184
185// Dialect trait does not implement Debug so we implement Debug manually
186impl std::fmt::Debug for DuckDBTableProviderFactory {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        f.debug_struct("DuckDBTableProviderFactory")
189            .field("access_mode", &self.access_mode)
190            .field("instances", &self.instances)
191            .field("unsupported_type_action", &self.unsupported_type_action)
192            .finish()
193    }
194}
195
196impl DuckDBTableProviderFactory {
    /// Create a factory with the given [`AccessMode`].
    ///
    /// Defaults: error on unsupported types, DuckDB dialect, stock settings
    /// registry, and an empty pool-instance cache.
    #[must_use]
    pub fn new(access_mode: AccessMode) -> Self {
        Self {
            access_mode,
            instances: Arc::new(Mutex::new(HashMap::new())),
            unsupported_type_action: UnsupportedTypeAction::Error,
            dialect: Arc::new(DuckDBDialect::new()),
            settings_registry: DuckDBSettingsRegistry::new(),
        }
    }

    /// Set how unsupported column types are handled (builder style).
    #[must_use]
    pub fn with_unsupported_type_action(
        mut self,
        unsupported_type_action: UnsupportedTypeAction,
    ) -> Self {
        self.unsupported_type_action = unsupported_type_action;
        self
    }

    /// Override the SQL dialect used when unparsing plans (builder style).
    #[must_use]
    pub fn with_dialect(mut self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
        self.dialect = dialect;
        self
    }

    /// Replace the DuckDB settings registry (builder style).
    #[must_use]
    pub fn with_settings_registry(mut self, settings_registry: DuckDBSettingsRegistry) -> Self {
        self.settings_registry = settings_registry;
        self
    }

    /// Shared access to the settings registry.
    #[must_use]
    pub fn settings_registry(&self) -> &DuckDBSettingsRegistry {
        &self.settings_registry
    }

    /// Mutable access to the settings registry.
    #[must_use]
    pub fn settings_registry_mut(&mut self) -> &mut DuckDBSettingsRegistry {
        &mut self.settings_registry
    }
238
239    #[must_use]
240    pub fn attach_databases(&self, options: &HashMap<String, String>) -> Vec<Arc<str>> {
241        options
242            .get(DUCKDB_ATTACH_DATABASES_PARAM)
243            .map(|attach_databases| {
244                attach_databases
245                    .split(';')
246                    .map(Arc::from)
247                    .collect::<Vec<Arc<str>>>()
248            })
249            .unwrap_or_default()
250    }
251
    /// Get the path to the DuckDB file database.
    ///
    /// The path is taken from the `open` option (optionally `duckdb.`-prefixed)
    /// when present; otherwise it defaults to `<data_directory>/<name>.db`,
    /// where `data_directory` defaults to the current directory.
    ///
    /// NOTE(review): despite the `Result` return type, this currently performs
    /// no validation and always returns `Ok` — path-escape checks (`./`, `../`,
    /// `/`) are not implemented here. Confirm whether validation is expected.
    pub fn duckdb_file_path(
        &self,
        name: &str,
        options: &mut HashMap<String, String>,
    ) -> Result<String, Error> {
        // Options may arrive with a `duckdb_` prefix; strip it before lookup.
        let options = util::remove_prefix_from_hashmap_keys(options.clone(), "duckdb_");

        let db_base_folder = options
            .get(DUCKDB_DB_BASE_FOLDER_PARAM)
            .cloned()
            .unwrap_or(".".to_string()); // default to the current directory
        let default_filepath = &format!("{db_base_folder}/{name}.db");

        let filepath = options
            .get(DUCKDB_DB_PATH_PARAM)
            .unwrap_or(default_filepath);

        Ok(filepath.to_string())
    }
276
277    pub async fn get_or_init_memory_instance(&self) -> Result<DuckDbConnectionPool> {
278        let pool_builder = DuckDbConnectionPoolBuilder::memory();
279        self.get_or_init_instance_with_builder(pool_builder).await
280    }
281
282    pub async fn get_or_init_file_instance(
283        &self,
284        db_path: impl Into<Arc<str>>,
285    ) -> Result<DuckDbConnectionPool> {
286        let db_path: Arc<str> = db_path.into();
287        let pool_builder = DuckDbConnectionPoolBuilder::file(&db_path);
288
289        self.get_or_init_instance_with_builder(pool_builder).await
290    }
291
    /// Return the cached pool for the builder's database instance, or build
    /// and cache a new one.
    ///
    /// The cache key is the database identity (file path, or a single shared
    /// in-memory key). Newly built pools get the factory's access mode and
    /// unsupported-type action applied.
    pub async fn get_or_init_instance_with_builder(
        &self,
        pool_builder: DuckDbConnectionPoolBuilder,
    ) -> Result<DuckDbConnectionPool> {
        let mode = pool_builder.get_mode();
        let key = match mode {
            Mode::File => {
                let path = pool_builder.get_path();
                DbInstanceKey::file(path.into())
            }
            Mode::Memory => DbInstanceKey::memory(),
        };

        // Manual variant-by-variant copy — `AccessMode` appears not to
        // implement `Clone` (TODO confirm against the duckdb crate).
        let access_mode = match &self.access_mode {
            AccessMode::ReadOnly => AccessMode::ReadOnly,
            AccessMode::ReadWrite => AccessMode::ReadWrite,
            AccessMode::Automatic => AccessMode::Automatic,
        };
        let pool_builder = pool_builder.with_access_mode(access_mode);

        // Hold the lock across the check-and-insert so concurrent callers
        // cannot build duplicate pools for the same instance.
        let mut instances = self.instances.lock().await;

        if let Some(instance) = instances.get(&key) {
            return Ok(instance.clone());
        }

        let pool = pool_builder
            .build()
            .context(DbConnectionPoolSnafu)?
            .with_unsupported_type_action(self.unsupported_type_action);

        instances.insert(key, pool.clone());

        Ok(pool)
    }
327}
328
/// Trait-object alias for a DuckDB connection pool, as consumed by the
/// generic SQL table plumbing.
type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
    + Send
    + Sync;
332
#[async_trait]
impl TableProviderFactory for DuckDBTableProviderFactory {
    /// Create a DuckDB-backed read/write [`TableProvider`] from a
    /// `CREATE EXTERNAL TABLE` command.
    ///
    /// Recognized options (each also accepted with a `duckdb.` prefix):
    /// `mode` (`file`/`memory`), `indexes`, `on_conflict`,
    /// `attach_databases`, plus any settings known to the settings registry.
    #[allow(clippy::too_many_lines)]
    async fn create(
        &self,
        _state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        // Schema-qualified table names are rejected up front.
        if cmd.name.schema().is_some() {
            TableWithSchemaCreationNotSupportedSnafu {
                table_name: cmd.name.to_string(),
            }
            .fail()
            .map_err(to_datafusion_error)?;
        }

        let name = cmd.name.to_string();
        let mut options = cmd.options.clone();
        let mode = remove_option(&mut options, "mode").unwrap_or_default();
        let mode: Mode = mode.as_str().into();

        // Parse the `indexes` option into (column reference, index type) pairs.
        let indexes_option_str = remove_option(&mut options, "indexes");
        let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
            Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
            None => HashMap::new(),
        };

        let unparsed_indexes = unparsed_indexes
            .into_iter()
            .map(|(key, value)| {
                let columns = ColumnReference::try_from(key.as_str())
                    .context(UnableToParseColumnReferenceSnafu)
                    .map_err(to_datafusion_error);
                (columns, value)
            })
            .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();

        // Surface the first column-reference parse error, if any.
        let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
        for (columns, index_type) in unparsed_indexes {
            let columns = columns?;
            indexes.push((columns, index_type));
        }

        let mut on_conflict: Option<OnConflict> = None;
        if let Some(on_conflict_str) = remove_option(&mut options, "on_conflict") {
            on_conflict = Some(
                OnConflict::try_from(on_conflict_str.as_str())
                    .context(UnableToParseOnConflictSnafu)
                    .map_err(to_datafusion_error)?,
            );
        }

        // Resolve (or lazily create) the connection pool for the chosen mode.
        let pool: DuckDbConnectionPool = match &mode {
            Mode::File => {
                // open duckdb at given path or create a new one
                let db_path = self
                    .duckdb_file_path(&name, &mut options)
                    .map_err(to_datafusion_error)?;

                self.get_or_init_file_instance(db_path)
                    .await
                    .map_err(to_datafusion_error)?
            }
            Mode::Memory => self
                .get_or_init_memory_instance()
                .await
                .map_err(to_datafusion_error)?,
        };

        // In file mode the read pool additionally attaches any databases
        // listed in `attach_databases`.
        let read_pool = match &mode {
            Mode::File => {
                let read_pool = pool.clone();

                read_pool.set_attached_databases(&self.attach_databases(&options))
            }
            Mode::Memory => pool.clone(),
        };

        // Get local DuckDB SET statements to use as setup queries on the pool
        let local_settings = self
            .settings_registry
            .get_setting_statements(&options, DuckDBSettingScope::Local);

        let read_pool = read_pool.with_connection_setup_queries(local_settings);

        let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into());

        let table_definition =
            TableDefinition::new(RelationName::new(name.clone()), Arc::clone(&schema))
                .with_constraints(cmd.constraints.clone())
                .with_indexes(indexes.clone());

        let pool = Arc::new(pool);
        // Ensure the physical table exists before handing out providers.
        make_initial_table(Arc::new(table_definition.clone()), &pool)?;

        let table_writer_builder = DuckDBTableWriterBuilder::new()
            .with_table_definition(table_definition)
            .with_pool(pool)
            .set_on_conflict(on_conflict);

        let dyn_pool: Arc<DynDuckDbConnectionPool> = Arc::new(read_pool);

        // Global settings need a synchronous connection to apply.
        let db_conn = dyn_pool.connect().await?;
        let Some(conn) = db_conn.as_sync() else {
            return Err(DataFusionError::External(Box::new(
                Error::DbConnectionError {
                    source: "Failed to get sync DuckDbConnection using DbConnection".into(),
                },
            )));
        };

        // Apply DuckDB global settings
        self.settings_registry
            .apply_settings(conn, &options, DuckDBSettingScope::Global)?;

        let read_provider = Arc::new(DuckDBTable::new_with_schema(
            &dyn_pool,
            Arc::clone(&schema),
            TableReference::bare(name.clone()),
            None,
            Some(self.dialect.clone()),
        ));

        // Optionally wrap the read provider for query federation.
        #[cfg(feature = "duckdb-federation")]
        let read_provider: Arc<dyn TableProvider> =
            Arc::new(read_provider.create_federated_table_provider()?);

        Ok(Arc::new(
            table_writer_builder
                .with_read_provider(read_provider)
                .build()
                .map_err(to_datafusion_error)?,
        ))
    }
}
468
/// Wrap a module [`Error`] into a [`DataFusionError::External`].
fn to_datafusion_error(error: Error) -> DataFusionError {
    DataFusionError::External(Box::new(error))
}
472
/// Handle to an existing DuckDB table: its name, Arrow schema, constraints,
/// and the connection pool used to reach it.
pub struct DuckDB {
    /// Name of the table inside DuckDB.
    table_name: String,
    /// Pool providing connections to the owning database.
    pool: Arc<DuckDbConnectionPool>,
    /// Arrow schema of the table.
    schema: SchemaRef,
    /// Constraints (e.g. primary keys) declared on the table.
    constraints: Constraints,
}
479
480impl std::fmt::Debug for DuckDB {
481    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482        f.debug_struct("DuckDB")
483            .field("table_name", &self.table_name)
484            .field("schema", &self.schema)
485            .field("constraints", &self.constraints)
486            .finish()
487    }
488}
489
impl DuckDB {
    /// Build a handle for a table that already exists in DuckDB.
    #[must_use]
    pub fn existing_table(
        table_name: String,
        pool: Arc<DuckDbConnectionPool>,
        schema: SchemaRef,
        constraints: Constraints,
    ) -> Self {
        Self {
            table_name,
            pool,
            schema,
            constraints,
        }
    }

    /// Name of the underlying DuckDB table.
    #[must_use]
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    /// Constraints declared on the table.
    #[must_use]
    pub fn constraints(&self) -> &Constraints {
        &self.constraints
    }

    /// Open a synchronous connection from the pool.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DbConnectionError`] when the pool cannot produce a
    /// connection.
    pub fn connect_sync(
        &self,
    ) -> Result<
        Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
    > {
        Arc::clone(&self.pool)
            .connect_sync()
            .context(DbConnectionSnafu)
    }

    /// Downcast a generic [`DbConnection`] to a concrete [`DuckDbConnection`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnableToDowncastDbConnection`] when the boxed
    /// connection is not a DuckDB connection.
    pub fn duckdb_conn(
        db_connection: &mut Box<
            dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
        >,
    ) -> Result<&mut DuckDbConnection> {
        db_connection
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .context(UnableToDowncastDbConnectionSnafu)
    }
}
537
/// Remove `key` from `options`, also accepting the `duckdb.`-prefixed
/// spelling of the same key as a fallback.
///
/// Returns the removed value, or `None` when neither spelling is present.
fn remove_option(options: &mut HashMap<String, String>, key: &str) -> Option<String> {
    // Prefer the bare key; fall back to the namespaced `duckdb.<key>` form.
    if let Some(value) = options.remove(key) {
        return Some(value);
    }
    options.remove(&format!("duckdb.{key}"))
}
543
/// Factory producing read-only or read/write table providers over an
/// existing DuckDB connection pool.
pub struct DuckDBTableFactory {
    /// Pool shared by every provider this factory creates.
    pool: Arc<DuckDbConnectionPool>,
    /// SQL dialect used when unparsing DataFusion plans.
    dialect: Arc<dyn Dialect>,
}
548
impl DuckDBTableFactory {
    /// Create a factory over an existing pool, defaulting to the DuckDB
    /// dialect.
    #[must_use]
    pub fn new(pool: Arc<DuckDbConnectionPool>) -> Self {
        Self {
            pool,
            dialect: Arc::new(DuckDBDialect::new()),
        }
    }

    /// Override the SQL dialect used for unparsing (builder style).
    #[must_use]
    pub fn with_dialect(mut self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
        self.dialect = dialect;
        self
    }

    /// Build a read-only [`TableProvider`] for `table_reference`.
    ///
    /// Table-function references (e.g. `read_parquet(...)`) are rewritten to
    /// a flattened view name with a CTE mapping back to the original
    /// function text, so they can be queried like plain tables.
    pub async fn table_provider(
        &self,
        table_reference: TableReference,
    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
        let pool = Arc::clone(&self.pool);
        let conn = Arc::clone(&pool).connect().await?;
        let dyn_pool: Arc<DynDuckDbConnectionPool> = pool;

        // Introspect the schema, then pick the reference (and optional CTE)
        // depending on whether this is a table function.
        let schema = get_schema(conn, &table_reference).await?;
        let (tbl_ref, cte) = if is_table_function(&table_reference) {
            let tbl_ref_view = create_table_function_view_name(&table_reference);
            (
                tbl_ref_view.clone(),
                Some(HashMap::from_iter(vec![(
                    tbl_ref_view.to_string(),
                    table_reference.table().to_string(),
                )])),
            )
        } else {
            (table_reference.clone(), None)
        };

        let table_provider = Arc::new(DuckDBTable::new_with_schema(
            &dyn_pool,
            schema,
            tbl_ref,
            cte,
            Some(self.dialect.clone()),
        ));

        // Optionally wrap the provider for query federation.
        #[cfg(feature = "duckdb-federation")]
        let table_provider: Arc<dyn TableProvider> =
            Arc::new(table_provider.create_federated_table_provider()?);

        Ok(table_provider)
    }

    /// Build a read/write provider: the read path from
    /// [`Self::table_provider`] plus a writer over the same pool.
    pub async fn read_write_table_provider(
        &self,
        table_reference: TableReference,
    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
        let read_provider = Self::table_provider(self, table_reference.clone()).await?;
        let schema = read_provider.schema();

        let table_name = RelationName::from(table_reference);
        let table_definition = TableDefinition::new(table_name, Arc::clone(&schema));
        let table_writer_builder = DuckDBTableWriterBuilder::new()
            .with_read_provider(read_provider)
            .with_pool(Arc::clone(&self.pool))
            .with_table_definition(table_definition);

        Ok(Arc::new(table_writer_builder.build()?))
    }
}
618
619/// For a [`TableReference`] that is a table function, create a name for a view on the original [`TableReference`]
620///
621/// ### Example
622///
623/// ```rust,ignore
624/// use datafusion_table_providers::duckdb::create_table_function_view_name;
625/// use datafusion::common::TableReference;
626///
627/// let table_reference = TableReference::from("catalog.schema.read_parquet('cleaned_sales_data.parquet')");
628/// let view_name = create_table_function_view_name(&table_reference);
629/// assert_eq!(view_name.to_string(), "catalog.schema.read_parquet_cleaned_sales_dataparquet__view");
630/// ```
631fn create_table_function_view_name(table_reference: &TableReference) -> TableReference {
632    let tbl_ref_view = [
633        table_reference.catalog(),
634        table_reference.schema(),
635        Some(&flatten_table_function_name(table_reference)),
636    ]
637    .iter()
638    .flatten()
639    .join(".");
640    TableReference::from(&tbl_ref_view)
641}
642
/// Create the table described by `table_definition` unless it (or one of its
/// internal tables) already exists.
///
/// All checks and the creation run inside a single transaction. On the
/// "already exists" early return the transaction is dropped without commit,
/// which rolls it back — harmless here since only reads have occurred by
/// that point (assuming `has_table`/`list_internal_tables` are read-only;
/// TODO confirm).
pub(crate) fn make_initial_table(
    table_definition: Arc<TableDefinition>,
    pool: &Arc<DuckDbConnectionPool>,
) -> DataFusionResult<()> {
    // Two Arc handles are needed: one consumed by `connect_sync`, one kept
    // for `create_table` below.
    let cloned_pool = Arc::clone(pool);
    let mut db_conn = Arc::clone(&cloned_pool)
        .connect_sync()
        .context(DbConnectionPoolSnafu)
        .map_err(to_datafusion_error)?;

    let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_datafusion_error)?;

    let tx = duckdb_conn
        .conn
        .transaction()
        .context(UnableToBeginTransactionSnafu)
        .map_err(to_datafusion_error)?;

    let has_table = table_definition
        .has_table(&tx)
        .map_err(to_datafusion_error)?;
    let internal_tables = table_definition
        .list_internal_tables(&tx)
        .map_err(to_datafusion_error)?;

    // Nothing to do — dropping `tx` rolls back the read-only transaction.
    if has_table || !internal_tables.is_empty() {
        return Ok(());
    }

    let table_manager = TableManager::new(table_definition);

    table_manager
        .create_table(cloned_pool, &tx)
        .map_err(to_datafusion_error)?;

    tx.commit()
        .context(UnableToCommitTransactionSnafu)
        .map_err(to_datafusion_error)?;

    Ok(())
}
684
685#[cfg(test)]
686pub(crate) mod tests {
687    use crate::duckdb::write::DuckDBTableWriter;
688
689    use super::*;
690    use arrow::datatypes::{DataType, Field, Schema};
691    use datafusion::common::{Constraints, ToDFSchema};
692    use datafusion::logical_expr::CreateExternalTable;
693    use datafusion::prelude::SessionContext;
694    use datafusion::sql::TableReference;
695    use std::collections::HashMap;
696    use std::sync::Arc;
697
    /// The `memory_limit` option on a memory-mode table must be applied to
    /// the underlying DuckDB instance (visible via `duckdb_settings()`).
    #[tokio::test]
    async fn test_create_with_memory_limit() {
        let table_name = TableReference::bare("test_table");
        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);

        let mut options = HashMap::new();
        options.insert("mode".to_string(), "memory".to_string());
        options.insert("memory_limit".to_string(), "123MiB".to_string());

        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
        let ctx = SessionContext::new();
        let cmd = CreateExternalTable {
            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
            name: table_name,
            location: "".to_string(),
            file_type: "".to_string(),
            table_partition_cols: vec![],
            if_not_exists: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::empty(),
            column_defaults: HashMap::new(),
            temporary: false,
        };

        let table_provider = factory
            .create(&ctx.state(), &cmd)
            .await
            .expect("table provider created");

        let writer = table_provider
            .as_any()
            .downcast_ref::<DuckDBTableWriter>()
            .expect("cast to DuckDBTableWriter");

        // Read the effective setting back through a raw connection.
        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");

        let mut stmt = conn
            .conn
            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'memory_limit'")
            .expect("to prepare statement");

        let memory_limit = stmt
            .query_row([], |row| row.get::<usize, String>(0))
            .expect("to query memory limit");

        println!("Memory limit: {memory_limit}");

        // DuckDB normalizes the value, hence "123.0 MiB" rather than "123MiB".
        assert_eq!(
            memory_limit, "123.0 MiB",
            "Memory limit must be set to 123.0 MiB"
        );
    }
754
    /// The `temp_directory` option on a memory-mode table must be applied to
    /// the underlying DuckDB instance (visible via `duckdb_settings()`).
    #[tokio::test]
    async fn test_create_with_temp_directory() {
        let table_name = TableReference::bare("test_table_temp_dir");
        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);

        let test_temp_directory = "/tmp/duckdb_test_temp";
        let mut options = HashMap::new();
        options.insert("mode".to_string(), "memory".to_string());
        options.insert(
            "temp_directory".to_string(),
            test_temp_directory.to_string(),
        );

        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
        let ctx = SessionContext::new();
        let cmd = CreateExternalTable {
            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
            name: table_name,
            location: "".to_string(),
            file_type: "".to_string(),
            table_partition_cols: vec![],
            if_not_exists: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::empty(),
            column_defaults: HashMap::new(),
            temporary: false,
        };

        let table_provider = factory
            .create(&ctx.state(), &cmd)
            .await
            .expect("table provider created");

        let writer = table_provider
            .as_any()
            .downcast_ref::<DuckDBTableWriter>()
            .expect("cast to DuckDBTableWriter");

        // Read the effective setting back through a raw connection.
        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");

        let mut stmt = conn
            .conn
            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'temp_directory'")
            .expect("to prepare statement");

        let temp_directory = stmt
            .query_row([], |row| row.get::<usize, String>(0))
            .expect("to query temp directory");

        println!("Temp directory: {temp_directory}");

        assert_eq!(
            temp_directory, test_temp_directory,
            "Temp directory must be set to {test_temp_directory}"
        );
    }
815
    /// `preserve_insertion_order=true` must be applied to the underlying
    /// DuckDB instance (visible via `duckdb_settings()`).
    #[tokio::test]
    async fn test_create_with_preserve_insertion_order_true() {
        let table_name = TableReference::bare("test_table_preserve_order_true");
        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);

        let mut options = HashMap::new();
        options.insert("mode".to_string(), "memory".to_string());
        options.insert("preserve_insertion_order".to_string(), "true".to_string());

        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
        let ctx = SessionContext::new();
        let cmd = CreateExternalTable {
            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
            name: table_name,
            location: "".to_string(),
            file_type: "".to_string(),
            table_partition_cols: vec![],
            if_not_exists: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::empty(),
            column_defaults: HashMap::new(),
            temporary: false,
        };

        let table_provider = factory
            .create(&ctx.state(), &cmd)
            .await
            .expect("table provider created");

        let writer = table_provider
            .as_any()
            .downcast_ref::<DuckDBTableWriter>()
            .expect("cast to DuckDBTableWriter");

        // Read the effective setting back through a raw connection.
        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");

        let mut stmt = conn
            .conn
            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'preserve_insertion_order'")
            .expect("to prepare statement");

        let preserve_order = stmt
            .query_row([], |row| row.get::<usize, String>(0))
            .expect("to query preserve_insertion_order");

        assert_eq!(
            preserve_order, "true",
            "preserve_insertion_order must be set to true"
        );
    }
870
871    #[tokio::test]
872    async fn test_create_with_preserve_insertion_order_false() {
873        let table_name = TableReference::bare("test_table_preserve_order_false");
874        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);
875
876        let mut options = HashMap::new();
877        options.insert("mode".to_string(), "memory".to_string());
878        options.insert("preserve_insertion_order".to_string(), "false".to_string());
879
880        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
881        let ctx = SessionContext::new();
882        let cmd = CreateExternalTable {
883            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
884            name: table_name,
885            location: "".to_string(),
886            file_type: "".to_string(),
887            table_partition_cols: vec![],
888            if_not_exists: false,
889            definition: None,
890            order_exprs: vec![],
891            unbounded: false,
892            options,
893            constraints: Constraints::empty(),
894            column_defaults: HashMap::new(),
895            temporary: false,
896        };
897
898        let table_provider = factory
899            .create(&ctx.state(), &cmd)
900            .await
901            .expect("table provider created");
902
903        let writer = table_provider
904            .as_any()
905            .downcast_ref::<DuckDBTableWriter>()
906            .expect("cast to DuckDBTableWriter");
907
908        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
909        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");
910
911        let mut stmt = conn
912            .conn
913            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'preserve_insertion_order'")
914            .expect("to prepare statement");
915
916        let preserve_order = stmt
917            .query_row([], |row| row.get::<usize, String>(0))
918            .expect("to query preserve_insertion_order");
919
920        assert_eq!(
921            preserve_order, "false",
922            "preserve_insertion_order must be set to false"
923        );
924    }
925
926    #[tokio::test]
927    async fn test_create_with_invalid_preserve_insertion_order() {
928        let table_name = TableReference::bare("test_table_preserve_order_invalid");
929        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);
930
931        let mut options = HashMap::new();
932        options.insert("mode".to_string(), "memory".to_string());
933        options.insert(
934            "preserve_insertion_order".to_string(),
935            "invalid".to_string(),
936        );
937
938        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
939        let ctx = SessionContext::new();
940        let cmd = CreateExternalTable {
941            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
942            name: table_name,
943            location: "".to_string(),
944            file_type: "".to_string(),
945            table_partition_cols: vec![],
946            if_not_exists: false,
947            definition: None,
948            order_exprs: vec![],
949            unbounded: false,
950            options,
951            constraints: Constraints::empty(),
952            column_defaults: HashMap::new(),
953            temporary: false,
954        };
955
956        let result = factory.create(&ctx.state(), &cmd).await;
957        assert!(
958            result.is_err(),
959            "Should fail with invalid preserve_insertion_order value"
960        );
961        if let Err(e) = result {
962            assert_eq!(e.to_string(), "External error: Query execution failed.\nInvalid Input Error: Failed to cast value: Could not convert string 'invalid' to BOOL\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/");
963        }
964    }
965}