datafusion_table_providers/
duckdb.rs

1use crate::sql::sql_provider_datafusion;
2use crate::util::{
3    self,
4    column_reference::{self, ColumnReference},
5    constraints,
6    indexes::IndexType,
7    on_conflict::{self, OnConflict},
8};
9use crate::{
10    sql::db_connection_pool::{
11        self,
12        dbconnection::{
13            duckdbconn::{
14                flatten_table_function_name, is_table_function, DuckDBParameter, DuckDbConnection,
15            },
16            get_schema, DbConnection,
17        },
18        duckdbpool::{DuckDbConnectionPool, DuckDbConnectionPoolBuilder},
19        DbConnectionPool, DbInstanceKey, Mode,
20    },
21    UnsupportedTypeAction,
22};
23use arrow::datatypes::SchemaRef;
24use async_trait::async_trait;
25use creator::TableManager;
26use datafusion::sql::unparser::dialect::{Dialect, DuckDBDialect};
27use datafusion::{
28    catalog::{Session, TableProviderFactory},
29    common::Constraints,
30    datasource::TableProvider,
31    error::{DataFusionError, Result as DataFusionResult},
32    logical_expr::CreateExternalTable,
33    sql::TableReference,
34};
35use duckdb::{AccessMode, DuckdbConnectionManager};
36use itertools::Itertools;
37use snafu::prelude::*;
38use std::{collections::HashMap, sync::Arc};
39use tokio::sync::Mutex;
40use write::DuckDBTableWriterBuilder;
41
42pub use self::settings::{
43    DuckDBSetting, DuckDBSettingScope, DuckDBSettingsRegistry, MemoryLimitSetting,
44    PreserveInsertionOrderSetting, TempDirectorySetting,
45};
46use self::sql_table::DuckDBTable;
47
48#[cfg(feature = "duckdb-federation")]
49mod federation;
50
51mod creator;
52mod settings;
53mod sql_table;
54pub mod write;
55pub use creator::{RelationName, TableDefinition};
56
/// Errors produced by the DuckDB table provider.
///
/// Each variant's `Display` message names the failing operation; where
/// applicable the underlying `duckdb`, pool, or parsing error is carried as
/// `source` (snafu context selectors are generated per variant).
#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("DbConnectionError: {source}"))]
    DbConnectionError {
        source: db_connection_pool::dbconnection::GenericError,
    },

    #[snafu(display("DbConnectionPoolError: {source}"))]
    DbConnectionPoolError { source: db_connection_pool::Error },

    #[snafu(display("DuckDBDataFusionError: {source}"))]
    DuckDBDataFusion {
        source: sql_provider_datafusion::Error,
    },

    #[snafu(display("Unable to downcast DbConnection to DuckDbConnection"))]
    UnableToDowncastDbConnection {},

    #[snafu(display("Unable to drop duckdb table: {source}"))]
    UnableToDropDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to create duckdb table: {source}"))]
    UnableToCreateDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to create index on duckdb table: {source}"))]
    UnableToCreateIndexOnDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to retrieve existing primary keys from DuckDB table: {source}"))]
    UnableToGetPrimaryKeysOnDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to drop index on duckdb table: {source}"))]
    UnableToDropIndexOnDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to rename duckdb table: {source}"))]
    UnableToRenameDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to insert into duckdb table: {source}"))]
    UnableToInsertToDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to get appender to duckdb table: {source}"))]
    UnableToGetAppenderToDuckDBTable { source: duckdb::Error },

    #[snafu(display("Unable to delete data from the duckdb table: {source}"))]
    UnableToDeleteDuckdbData { source: duckdb::Error },

    #[snafu(display("Unable to query data from the duckdb table: {source}"))]
    UnableToQueryData { source: duckdb::Error },

    #[snafu(display("Unable to commit transaction: {source}"))]
    UnableToCommitTransaction { source: duckdb::Error },

    #[snafu(display("Unable to begin duckdb transaction: {source}"))]
    UnableToBeginTransaction { source: duckdb::Error },

    #[snafu(display("Unable to rollback transaction: {source}"))]
    UnableToRollbackTransaction { source: duckdb::Error },

    #[snafu(display("Unable to delete all data from the DuckDB table: {source}"))]
    UnableToDeleteAllTableData { source: duckdb::Error },

    #[snafu(display("Unable to insert data into the DuckDB table: {source}"))]
    UnableToInsertIntoTableAsync { source: duckdb::Error },

    #[snafu(display("The table '{table_name}' doesn't exist in the DuckDB server"))]
    TableDoesntExist { table_name: String },

    #[snafu(display("Constraint Violation: {source}"))]
    ConstraintViolation { source: constraints::Error },

    #[snafu(display("Error parsing column reference: {source}"))]
    UnableToParseColumnReference { source: column_reference::Error },

    #[snafu(display("Error parsing on_conflict: {source}"))]
    UnableToParseOnConflict { source: on_conflict::Error },

    #[snafu(display(
        "Failed to create '{table_name}': creating a table with a schema is not supported"
    ))]
    TableWithSchemaCreationNotSupported { table_name: String },

    #[snafu(display("Failed to parse memory_limit value '{value}': {source}\nProvide a valid value, e.g. '2GB', '512MiB' (expected: KB, MB, GB, TB for 1000^i units or KiB, MiB, GiB, TiB for 1024^i units)"))]
    UnableToParseMemoryLimit {
        value: String,
        source: byte_unit::ParseError,
    },

    #[snafu(display("Unable to add primary key to table: {source}"))]
    UnableToAddPrimaryKey { source: duckdb::Error },

    #[snafu(display("Failed to get system time since epoch: {source}"))]
    UnableToGetSystemTime { source: std::time::SystemTimeError },

    #[snafu(display("Failed to parse the system time: {source}"))]
    UnableToParseSystemTime { source: std::num::ParseIntError },

    // The three `Missing*` variants below are builder-validation errors
    // raised by `DuckDBTableWriterBuilder::build`.
    #[snafu(display("A read provider is required to create a DuckDBTableWriter"))]
    MissingReadProvider,

    #[snafu(display("A pool is required to create a DuckDBTableWriter"))]
    MissingPool,

    #[snafu(display("A table definition is required to create a DuckDBTableWriter"))]
    MissingTableDefinition,

    #[snafu(display("Failed to register Arrow scan view for DuckDB ingestion: {source}"))]
    UnableToRegisterArrowScanView { source: duckdb::Error },

    #[snafu(display(
        "Failed to register Arrow scan view to build table creation statement: {source}"
    ))]
    UnableToRegisterArrowScanViewForTableCreation { source: duckdb::Error },

    #[snafu(display("Failed to drop Arrow scan view for DuckDB ingestion: {source}"))]
    UnableToDropArrowScanView { source: duckdb::Error },
}
172
/// Module-local result alias defaulting the error type to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;

// Option keys recognized by the factory (a `duckdb.` prefix is also accepted,
// see `remove_option` / `duckdb_file_path`).
const DUCKDB_DB_PATH_PARAM: &str = "open";
const DUCKDB_DB_BASE_FOLDER_PARAM: &str = "data_directory";
const DUCKDB_ATTACH_DATABASES_PARAM: &str = "attach_databases";
178
/// Factory that builds DuckDB-backed [`TableProvider`]s for
/// `CREATE EXTERNAL TABLE` statements.
pub struct DuckDBTableProviderFactory {
    // Access mode applied to every pool this factory creates.
    access_mode: AccessMode,
    // Cache of connection pools keyed by database location (file path or
    // in-memory), so repeated CREATEs share one pool per database.
    instances: Arc<Mutex<HashMap<DbInstanceKey, DuckDbConnectionPool>>>,
    // What to do when a column type cannot be mapped (default: error).
    unsupported_type_action: UnsupportedTypeAction,
    // SQL dialect used when unparsing DataFusion plans for DuckDB.
    dialect: Arc<dyn Dialect>,
    // Registry of recognized DuckDB settings (e.g. memory_limit, temp_directory).
    settings_registry: DuckDBSettingsRegistry,
}
186
187// Dialect trait does not implement Debug so we implement Debug manually
188impl std::fmt::Debug for DuckDBTableProviderFactory {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("DuckDBTableProviderFactory")
191            .field("access_mode", &self.access_mode)
192            .field("instances", &self.instances)
193            .field("unsupported_type_action", &self.unsupported_type_action)
194            .finish()
195    }
196}
197
198impl DuckDBTableProviderFactory {
199    #[must_use]
200    pub fn new(access_mode: AccessMode) -> Self {
201        Self {
202            access_mode,
203            instances: Arc::new(Mutex::new(HashMap::new())),
204            unsupported_type_action: UnsupportedTypeAction::Error,
205            dialect: Arc::new(DuckDBDialect::new()),
206            settings_registry: DuckDBSettingsRegistry::new(),
207        }
208    }
209
210    #[must_use]
211    pub fn with_unsupported_type_action(
212        mut self,
213        unsupported_type_action: UnsupportedTypeAction,
214    ) -> Self {
215        self.unsupported_type_action = unsupported_type_action;
216        self
217    }
218
219    #[must_use]
220    pub fn with_dialect(mut self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
221        self.dialect = dialect;
222        self
223    }
224
225    #[must_use]
226    pub fn with_settings_registry(mut self, settings_registry: DuckDBSettingsRegistry) -> Self {
227        self.settings_registry = settings_registry;
228        self
229    }
230
231    #[must_use]
232    pub fn settings_registry(&self) -> &DuckDBSettingsRegistry {
233        &self.settings_registry
234    }
235
236    #[must_use]
237    pub fn settings_registry_mut(&mut self) -> &mut DuckDBSettingsRegistry {
238        &mut self.settings_registry
239    }
240
241    #[must_use]
242    pub fn attach_databases(&self, options: &HashMap<String, String>) -> Vec<Arc<str>> {
243        options
244            .get(DUCKDB_ATTACH_DATABASES_PARAM)
245            .map(|attach_databases| {
246                attach_databases
247                    .split(';')
248                    .map(Arc::from)
249                    .collect::<Vec<Arc<str>>>()
250            })
251            .unwrap_or_default()
252    }
253
254    /// Get the path to the DuckDB file database.
255    ///
256    /// ## Errors
257    ///
258    /// - If the path includes absolute sequences to escape the current directory, like `./`, `../`, or `/`.
259    pub fn duckdb_file_path(
260        &self,
261        name: &str,
262        options: &mut HashMap<String, String>,
263    ) -> Result<String, Error> {
264        let options = util::remove_prefix_from_hashmap_keys(options.clone(), "duckdb_");
265
266        let db_base_folder = options
267            .get(DUCKDB_DB_BASE_FOLDER_PARAM)
268            .cloned()
269            .unwrap_or(".".to_string()); // default to the current directory
270        let default_filepath = &format!("{db_base_folder}/{name}.db");
271
272        let filepath = options
273            .get(DUCKDB_DB_PATH_PARAM)
274            .unwrap_or(default_filepath);
275
276        Ok(filepath.to_string())
277    }
278
279    pub async fn get_or_init_memory_instance(&self) -> Result<DuckDbConnectionPool> {
280        let pool_builder = DuckDbConnectionPoolBuilder::memory();
281        self.get_or_init_instance_with_builder(pool_builder).await
282    }
283
284    pub async fn get_or_init_file_instance(
285        &self,
286        db_path: impl Into<Arc<str>>,
287    ) -> Result<DuckDbConnectionPool> {
288        let db_path: Arc<str> = db_path.into();
289        let pool_builder = DuckDbConnectionPoolBuilder::file(&db_path);
290
291        self.get_or_init_instance_with_builder(pool_builder).await
292    }
293
294    pub async fn get_or_init_instance_with_builder(
295        &self,
296        pool_builder: DuckDbConnectionPoolBuilder,
297    ) -> Result<DuckDbConnectionPool> {
298        let mode = pool_builder.get_mode();
299        let key = match mode {
300            Mode::File => {
301                let path = pool_builder.get_path();
302                DbInstanceKey::file(path.into())
303            }
304            Mode::Memory => DbInstanceKey::memory(),
305        };
306
307        let access_mode = match &self.access_mode {
308            AccessMode::ReadOnly => AccessMode::ReadOnly,
309            AccessMode::ReadWrite => AccessMode::ReadWrite,
310            AccessMode::Automatic => AccessMode::Automatic,
311        };
312        let pool_builder = pool_builder.with_access_mode(access_mode);
313
314        let mut instances = self.instances.lock().await;
315
316        if let Some(instance) = instances.get(&key) {
317            return Ok(instance.clone());
318        }
319
320        let pool = pool_builder
321            .build()
322            .context(DbConnectionPoolSnafu)?
323            .with_unsupported_type_action(self.unsupported_type_action);
324
325        instances.insert(key, pool.clone());
326
327        Ok(pool)
328    }
329}
330
/// Trait-object alias for the DuckDB connection pool as handed to
/// [`DuckDBTable`] for query execution.
type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
    + Send
    + Sync;
334
#[async_trait]
impl TableProviderFactory for DuckDBTableProviderFactory {
    /// Creates a read/write DuckDB table provider for `CREATE EXTERNAL TABLE`.
    ///
    /// Recognized options (a `duckdb.` prefix is also accepted): `mode`,
    /// `indexes`, `on_conflict`, `attach_databases`, plus any settings known
    /// to the factory's [`DuckDBSettingsRegistry`].
    #[allow(clippy::too_many_lines)]
    async fn create(
        &self,
        _state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        // Schema-qualified table names (e.g. `my_schema.my_table`) are rejected.
        if cmd.name.schema().is_some() {
            TableWithSchemaCreationNotSupportedSnafu {
                table_name: cmd.name.to_string(),
            }
            .fail()
            .map_err(to_datafusion_error)?;
        }

        let name = cmd.name.to_string();
        // Options are consumed destructively; whatever remains is later fed to
        // the settings registry.
        let mut options = cmd.options.clone();
        let mode = remove_option(&mut options, "mode").unwrap_or_default();
        let mode: Mode = mode.as_str().into();

        // `indexes` is an option string mapping column reference -> index type.
        let indexes_option_str = remove_option(&mut options, "indexes");
        let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
            Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
            None => HashMap::new(),
        };

        let unparsed_indexes = unparsed_indexes
            .into_iter()
            .map(|(key, value)| {
                let columns = ColumnReference::try_from(key.as_str())
                    .context(UnableToParseColumnReferenceSnafu)
                    .map_err(to_datafusion_error);
                (columns, value)
            })
            .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();

        // Surface the first column-reference parse error, if any.
        let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
        for (columns, index_type) in unparsed_indexes {
            let columns = columns?;
            indexes.push((columns, index_type));
        }

        let mut on_conflict: Option<OnConflict> = None;
        if let Some(on_conflict_str) = remove_option(&mut options, "on_conflict") {
            on_conflict = Some(
                OnConflict::try_from(on_conflict_str.as_str())
                    .context(UnableToParseOnConflictSnafu)
                    .map_err(to_datafusion_error)?,
            );
        }

        // Pools are cached per database, so repeated CREATEs reuse instances.
        let pool: DuckDbConnectionPool = match &mode {
            Mode::File => {
                // open duckdb at given path or create a new one
                let db_path = self
                    .duckdb_file_path(&name, &mut options)
                    .map_err(to_datafusion_error)?;

                self.get_or_init_file_instance(db_path)
                    .await
                    .map_err(to_datafusion_error)?
            }
            Mode::Memory => self
                .get_or_init_memory_instance()
                .await
                .map_err(to_datafusion_error)?,
        };

        // Reads use a pool clone; in file mode it additionally attaches the
        // databases listed in the `attach_databases` option.
        let read_pool = match &mode {
            Mode::File => {
                let read_pool = pool.clone();

                read_pool.set_attached_databases(&self.attach_databases(&options))
            }
            Mode::Memory => pool.clone(),
        };

        // Get local DuckDB SET statements to use as setup queries on the pool
        let local_settings = self
            .settings_registry
            .get_setting_statements(&options, DuckDBSettingScope::Local);

        let read_pool = read_pool.with_connection_setup_queries(local_settings);

        let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into());

        let table_definition =
            TableDefinition::new(RelationName::new(name.clone()), Arc::clone(&schema))
                .with_constraints(cmd.constraints.clone())
                .with_indexes(indexes.clone());

        let pool = Arc::new(pool);
        // Create the backing table up front unless it already exists.
        make_initial_table(Arc::new(table_definition.clone()), &pool)?;

        let table_writer_builder = DuckDBTableWriterBuilder::new()
            .with_table_definition(table_definition)
            .with_pool(pool)
            .set_on_conflict(on_conflict);

        let dyn_pool: Arc<DynDuckDbConnectionPool> = Arc::new(read_pool);

        let db_conn = dyn_pool.connect().await?;
        let Some(conn) = db_conn.as_sync() else {
            return Err(DataFusionError::External(Box::new(
                Error::DbConnectionError {
                    source: "Failed to get sync DuckDbConnection using DbConnection".into(),
                },
            )));
        };

        // Apply DuckDB global settings
        self.settings_registry
            .apply_settings(conn, &options, DuckDBSettingScope::Global)?;

        let read_provider = Arc::new(DuckDBTable::new_with_schema(
            &dyn_pool,
            Arc::clone(&schema),
            TableReference::bare(name.clone()),
            None,
            Some(self.dialect.clone()),
        ));

        // When the federation feature is on, shadow `read_provider` with its
        // federated wrapper before handing it to the writer builder.
        #[cfg(feature = "duckdb-federation")]
        let read_provider: Arc<dyn TableProvider> =
            Arc::new(read_provider.create_federated_table_provider()?);

        Ok(Arc::new(
            table_writer_builder
                .with_read_provider(read_provider)
                .build()
                .map_err(to_datafusion_error)?,
        ))
    }
}
470
/// Wraps a module [`Error`] in [`DataFusionError::External`] so it can be
/// propagated through DataFusion APIs.
fn to_datafusion_error(error: Error) -> DataFusionError {
    DataFusionError::External(Box::new(error))
}
474
/// Handle to an existing DuckDB table: its name, schema, constraints, and the
/// pool used to reach it.
pub struct DuckDB {
    table_name: String,
    // Not shown in the Debug impl below.
    pool: Arc<DuckDbConnectionPool>,
    schema: SchemaRef,
    constraints: Constraints,
}
481
482impl std::fmt::Debug for DuckDB {
483    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484        f.debug_struct("DuckDB")
485            .field("table_name", &self.table_name)
486            .field("schema", &self.schema)
487            .field("constraints", &self.constraints)
488            .finish()
489    }
490}
491
492impl DuckDB {
493    #[must_use]
494    pub fn existing_table(
495        table_name: String,
496        pool: Arc<DuckDbConnectionPool>,
497        schema: SchemaRef,
498        constraints: Constraints,
499    ) -> Self {
500        Self {
501            table_name,
502            pool,
503            schema,
504            constraints,
505        }
506    }
507
508    #[must_use]
509    pub fn table_name(&self) -> &str {
510        &self.table_name
511    }
512
513    #[must_use]
514    pub fn constraints(&self) -> &Constraints {
515        &self.constraints
516    }
517
518    pub fn connect_sync(
519        &self,
520    ) -> Result<
521        Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
522    > {
523        Arc::clone(&self.pool)
524            .connect_sync()
525            .context(DbConnectionSnafu)
526    }
527
528    pub fn duckdb_conn(
529        db_connection: &mut Box<
530            dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
531        >,
532    ) -> Result<&mut DuckDbConnection> {
533        db_connection
534            .as_any_mut()
535            .downcast_mut::<DuckDbConnection>()
536            .context(UnableToDowncastDbConnectionSnafu)
537    }
538}
539
/// Removes and returns `key` from `options`, falling back to the same key
/// with a `duckdb.` prefix. The unprefixed key wins when both are present.
fn remove_option(options: &mut HashMap<String, String>, key: &str) -> Option<String> {
    if let Some(value) = options.remove(key) {
        return Some(value);
    }
    options.remove(&format!("duckdb.{key}"))
}
545
/// Factory for building providers over tables that already exist in a DuckDB
/// database reachable through `pool`.
pub struct DuckDBTableFactory {
    pool: Arc<DuckDbConnectionPool>,
    // Dialect used when unparsing DataFusion plans to DuckDB SQL.
    dialect: Arc<dyn Dialect>,
}
550
impl DuckDBTableFactory {
    /// Creates a factory over an existing connection pool with the default
    /// [`DuckDBDialect`].
    #[must_use]
    pub fn new(pool: Arc<DuckDbConnectionPool>) -> Self {
        Self {
            pool,
            dialect: Arc::new(DuckDBDialect::new()),
        }
    }

    /// Overrides the SQL dialect used when unparsing plans.
    #[must_use]
    pub fn with_dialect(mut self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
        self.dialect = dialect;
        self
    }

    /// Builds a read-only provider for `table_reference`.
    ///
    /// Table-function references (e.g. `read_parquet('…')`) are queried
    /// through a generated view name plus a CTE mapping that view name back
    /// to the original function call; plain tables are used as-is.
    pub async fn table_provider(
        &self,
        table_reference: TableReference,
    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
        let pool = Arc::clone(&self.pool);
        let conn = Arc::clone(&pool).connect().await?;
        let dyn_pool: Arc<DynDuckDbConnectionPool> = pool;

        // Schema is discovered from the live database, not declared by the caller.
        let schema = get_schema(conn, &table_reference).await?;
        let (tbl_ref, cte) = if is_table_function(&table_reference) {
            let tbl_ref_view = create_table_function_view_name(&table_reference);
            (
                tbl_ref_view.clone(),
                Some(HashMap::from_iter(vec![(
                    tbl_ref_view.to_string(),
                    table_reference.table().to_string(),
                )])),
            )
        } else {
            (table_reference.clone(), None)
        };

        let table_provider = Arc::new(DuckDBTable::new_with_schema(
            &dyn_pool,
            schema,
            tbl_ref,
            cte,
            Some(self.dialect.clone()),
        ));

        // When the federation feature is on, shadow the provider with its
        // federated wrapper before returning it.
        #[cfg(feature = "duckdb-federation")]
        let table_provider: Arc<dyn TableProvider> =
            Arc::new(table_provider.create_federated_table_provider()?);

        Ok(table_provider)
    }

    /// Builds a read/write provider: the read provider from
    /// [`Self::table_provider`] wrapped in a DuckDB table writer.
    pub async fn read_write_table_provider(
        &self,
        table_reference: TableReference,
    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
        let read_provider = Self::table_provider(self, table_reference.clone()).await?;
        let schema = read_provider.schema();

        let table_name = RelationName::from(table_reference);
        let table_definition = TableDefinition::new(table_name, Arc::clone(&schema));
        let table_writer_builder = DuckDBTableWriterBuilder::new()
            .with_read_provider(read_provider)
            .with_pool(Arc::clone(&self.pool))
            .with_table_definition(table_definition);

        Ok(Arc::new(table_writer_builder.build()?))
    }
}
620
621/// For a [`TableReference`] that is a table function, create a name for a view on the original [`TableReference`]
622///
623/// ### Example
624///
625/// ```rust,ignore
626/// use datafusion_table_providers::duckdb::create_table_function_view_name;
627/// use datafusion::common::TableReference;
628///
629/// let table_reference = TableReference::from("catalog.schema.read_parquet('cleaned_sales_data.parquet')");
630/// let view_name = create_table_function_view_name(&table_reference);
631/// assert_eq!(view_name.to_string(), "catalog.schema.read_parquet_cleaned_sales_dataparquet__view");
632/// ```
633fn create_table_function_view_name(table_reference: &TableReference) -> TableReference {
634    let tbl_ref_view = [
635        table_reference.catalog(),
636        table_reference.schema(),
637        Some(&flatten_table_function_name(table_reference)),
638    ]
639    .iter()
640    .flatten()
641    .join(".");
642    TableReference::from(&tbl_ref_view)
643}
644
/// Creates the backing DuckDB table for `table_definition` unless the table,
/// or any of its internal tables, already exists.
pub(crate) fn make_initial_table(
    table_definition: Arc<TableDefinition>,
    pool: &Arc<DuckDbConnectionPool>,
) -> DataFusionResult<()> {
    let cloned_pool = Arc::clone(pool);
    let mut db_conn = Arc::clone(&cloned_pool)
        .connect_sync()
        .context(DbConnectionPoolSnafu)
        .map_err(to_datafusion_error)?;

    let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_datafusion_error)?;

    // Existence checks and the CREATE all run inside one transaction.
    let tx = duckdb_conn
        .conn
        .transaction()
        .context(UnableToBeginTransactionSnafu)
        .map_err(to_datafusion_error)?;

    let has_table = table_definition
        .has_table(&tx)
        .map_err(to_datafusion_error)?;
    let internal_tables = table_definition
        .list_internal_tables(&tx)
        .map_err(to_datafusion_error)?;

    // Nothing to create. Returning here drops `tx`, which rolls back a
    // transaction that has only performed reads — intentional and harmless.
    if has_table || !internal_tables.is_empty() {
        return Ok(());
    }

    let table_manager = TableManager::new(table_definition);

    table_manager
        .create_table(cloned_pool, &tx)
        .map_err(to_datafusion_error)?;

    tx.commit()
        .context(UnableToCommitTransactionSnafu)
        .map_err(to_datafusion_error)?;

    Ok(())
}
686
687#[cfg(test)]
688pub(crate) mod tests {
689    use crate::duckdb::write::DuckDBTableWriter;
690
691    use super::*;
692    use arrow::datatypes::{DataType, Field, Schema};
693    use datafusion::common::{Constraints, ToDFSchema};
694    use datafusion::logical_expr::CreateExternalTable;
695    use datafusion::prelude::SessionContext;
696    use datafusion::sql::TableReference;
697    use std::collections::HashMap;
698    use std::sync::Arc;
699
    // End-to-end check that the `memory_limit` option set via CREATE EXTERNAL
    // TABLE is applied to the underlying DuckDB instance.
    #[tokio::test]
    async fn test_create_with_memory_limit() {
        let table_name = TableReference::bare("test_table");
        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);

        let mut options = HashMap::new();
        options.insert("mode".to_string(), "memory".to_string());
        options.insert("memory_limit".to_string(), "123MiB".to_string());

        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
        let ctx = SessionContext::new();
        let cmd = CreateExternalTable {
            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
            name: table_name,
            location: "".to_string(),
            file_type: "".to_string(),
            table_partition_cols: vec![],
            if_not_exists: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
            temporary: false,
        };

        let table_provider = factory
            .create(&ctx.state(), &cmd)
            .await
            .expect("table provider created");

        let writer = table_provider
            .as_any()
            .downcast_ref::<DuckDBTableWriter>()
            .expect("cast to DuckDBTableWriter");

        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");

        // DuckDB exposes the effective setting through duckdb_settings().
        let mut stmt = conn
            .conn
            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'memory_limit'")
            .expect("to prepare statement");

        let memory_limit = stmt
            .query_row([], |row| row.get::<usize, String>(0))
            .expect("to query memory limit");

        println!("Memory limit: {memory_limit}");

        // DuckDB normalizes "123MiB" to "123.0 MiB" in its settings output.
        assert_eq!(
            memory_limit, "123.0 MiB",
            "Memory limit must be set to 123.0 MiB"
        );
    }
756
    // End-to-end check that the `temp_directory` option set via CREATE
    // EXTERNAL TABLE is applied to the underlying DuckDB instance.
    #[tokio::test]
    async fn test_create_with_temp_directory() {
        let table_name = TableReference::bare("test_table_temp_dir");
        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);

        let test_temp_directory = "/tmp/duckdb_test_temp";
        let mut options = HashMap::new();
        options.insert("mode".to_string(), "memory".to_string());
        options.insert(
            "temp_directory".to_string(),
            test_temp_directory.to_string(),
        );

        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
        let ctx = SessionContext::new();
        let cmd = CreateExternalTable {
            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
            name: table_name,
            location: "".to_string(),
            file_type: "".to_string(),
            table_partition_cols: vec![],
            if_not_exists: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
            temporary: false,
        };

        let table_provider = factory
            .create(&ctx.state(), &cmd)
            .await
            .expect("table provider created");

        let writer = table_provider
            .as_any()
            .downcast_ref::<DuckDBTableWriter>()
            .expect("cast to DuckDBTableWriter");

        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");

        // DuckDB exposes the effective setting through duckdb_settings().
        let mut stmt = conn
            .conn
            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'temp_directory'")
            .expect("to prepare statement");

        let temp_directory = stmt
            .query_row([], |row| row.get::<usize, String>(0))
            .expect("to query temp directory");

        println!("Temp directory: {temp_directory}");

        assert_eq!(
            temp_directory, test_temp_directory,
            "Temp directory must be set to {test_temp_directory}"
        );
    }
817
    // End-to-end check that `preserve_insertion_order = true` set via CREATE
    // EXTERNAL TABLE is applied to the underlying DuckDB instance.
    #[tokio::test]
    async fn test_create_with_preserve_insertion_order_true() {
        let table_name = TableReference::bare("test_table_preserve_order_true");
        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);

        let mut options = HashMap::new();
        options.insert("mode".to_string(), "memory".to_string());
        options.insert("preserve_insertion_order".to_string(), "true".to_string());

        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
        let ctx = SessionContext::new();
        let cmd = CreateExternalTable {
            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
            name: table_name,
            location: "".to_string(),
            file_type: "".to_string(),
            table_partition_cols: vec![],
            if_not_exists: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
            temporary: false,
        };

        let table_provider = factory
            .create(&ctx.state(), &cmd)
            .await
            .expect("table provider created");

        let writer = table_provider
            .as_any()
            .downcast_ref::<DuckDBTableWriter>()
            .expect("cast to DuckDBTableWriter");

        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");

        // DuckDB exposes the effective setting through duckdb_settings().
        let mut stmt = conn
            .conn
            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'preserve_insertion_order'")
            .expect("to prepare statement");

        let preserve_order = stmt
            .query_row([], |row| row.get::<usize, String>(0))
            .expect("to query preserve_insertion_order");

        assert_eq!(
            preserve_order, "true",
            "preserve_insertion_order must be set to true"
        );
    }
872
873    #[tokio::test]
874    async fn test_create_with_preserve_insertion_order_false() {
875        let table_name = TableReference::bare("test_table_preserve_order_false");
876        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);
877
878        let mut options = HashMap::new();
879        options.insert("mode".to_string(), "memory".to_string());
880        options.insert("preserve_insertion_order".to_string(), "false".to_string());
881
882        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
883        let ctx = SessionContext::new();
884        let cmd = CreateExternalTable {
885            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
886            name: table_name,
887            location: "".to_string(),
888            file_type: "".to_string(),
889            table_partition_cols: vec![],
890            if_not_exists: false,
891            definition: None,
892            order_exprs: vec![],
893            unbounded: false,
894            options,
895            constraints: Constraints::default(),
896            column_defaults: HashMap::new(),
897            temporary: false,
898        };
899
900        let table_provider = factory
901            .create(&ctx.state(), &cmd)
902            .await
903            .expect("table provider created");
904
905        let writer = table_provider
906            .as_any()
907            .downcast_ref::<DuckDBTableWriter>()
908            .expect("cast to DuckDBTableWriter");
909
910        let mut conn_box = writer.pool().connect_sync().expect("to get connection");
911        let conn = DuckDB::duckdb_conn(&mut conn_box).expect("to get DuckDB connection");
912
913        let mut stmt = conn
914            .conn
915            .prepare("SELECT value FROM duckdb_settings() WHERE name = 'preserve_insertion_order'")
916            .expect("to prepare statement");
917
918        let preserve_order = stmt
919            .query_row([], |row| row.get::<usize, String>(0))
920            .expect("to query preserve_insertion_order");
921
922        assert_eq!(
923            preserve_order, "false",
924            "preserve_insertion_order must be set to false"
925        );
926    }
927
928    #[tokio::test]
929    async fn test_create_with_invalid_preserve_insertion_order() {
930        let table_name = TableReference::bare("test_table_preserve_order_invalid");
931        let schema = Schema::new(vec![Field::new("dummy", DataType::Int32, false)]);
932
933        let mut options = HashMap::new();
934        options.insert("mode".to_string(), "memory".to_string());
935        options.insert(
936            "preserve_insertion_order".to_string(),
937            "invalid".to_string(),
938        );
939
940        let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite);
941        let ctx = SessionContext::new();
942        let cmd = CreateExternalTable {
943            schema: Arc::new(schema.to_dfschema().expect("to df schema")),
944            name: table_name,
945            location: "".to_string(),
946            file_type: "".to_string(),
947            table_partition_cols: vec![],
948            if_not_exists: false,
949            definition: None,
950            order_exprs: vec![],
951            unbounded: false,
952            options,
953            constraints: Constraints::default(),
954            column_defaults: HashMap::new(),
955            temporary: false,
956        };
957
958        let result = factory.create(&ctx.state(), &cmd).await;
959        assert!(
960            result.is_err(),
961            "Should fail with invalid preserve_insertion_order value"
962        );
963        if let Err(e) = result {
964            assert_eq!(e.to_string(), "External error: Query execution failed.\nInvalid Input Error: Failed to cast value: Could not convert string 'invalid' to BOOL\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/");
965        }
966    }
967}