datafusion_table_providers/duckdb/write.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2use std::{any::Any, fmt, sync::Arc};
3
4use crate::duckdb::DuckDB;
5use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
6use crate::util::{
7    constraints,
8    on_conflict::OnConflict,
9    retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error},
10};
11use arrow::array::RecordBatchReader;
12use arrow::ffi_stream::FFI_ArrowArrayStream;
13use arrow::{array::RecordBatch, datatypes::SchemaRef};
14use arrow_schema::ArrowError;
15use async_trait::async_trait;
16use datafusion::catalog::Session;
17use datafusion::common::{Constraints, SchemaExt};
18use datafusion::datasource::sink::{DataSink, DataSinkExec};
19use datafusion::logical_expr::dml::InsertOp;
20use datafusion::{
21    datasource::{TableProvider, TableType},
22    error::DataFusionError,
23    execution::{SendableRecordBatchStream, TaskContext},
24    logical_expr::Expr,
25    physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
26};
27use duckdb::Transaction;
28use futures::StreamExt;
29use snafu::prelude::*;
30use tokio::sync::mpsc::{self, Receiver, Sender};
31use tokio::task::JoinHandle;
32
33use super::creator::{TableDefinition, TableManager, ViewCreator};
34use super::{to_datafusion_error, RelationName};
35
36// checking schema equivalence is disabled because it incorrectly marks single-level list fields as different when only the field names differ,
37// e.g. List(Field { name: 'a', data_type: Int32 }) != List(Field { name: 'b', data_type: Int32 }).
38// In that case the schemas are actually equivalent, because the list item field name does not matter for the schema.
39// related: https://github.com/apache/arrow-rs/issues/6733#issuecomment-2482582556
40const SCHEMA_EQUIVALENCE_ENABLED: bool = false;
41
42#[derive(Default)]
43pub struct DuckDBTableWriterBuilder {
44    read_provider: Option<Arc<dyn TableProvider>>,
45    pool: Option<Arc<DuckDbConnectionPool>>,
46    on_conflict: Option<OnConflict>,
47    table_definition: Option<TableDefinition>,
48}
49
50impl DuckDBTableWriterBuilder {
51    #[must_use]
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    #[must_use]
57    pub fn with_read_provider(mut self, read_provider: Arc<dyn TableProvider>) -> Self {
58        self.read_provider = Some(read_provider);
59        self
60    }
61
62    #[must_use]
63    pub fn with_pool(mut self, pool: Arc<DuckDbConnectionPool>) -> Self {
64        self.pool = Some(pool);
65        self
66    }
67
68    #[must_use]
69    pub fn set_on_conflict(mut self, on_conflict: Option<OnConflict>) -> Self {
70        self.on_conflict = on_conflict;
71        self
72    }
73
74    #[must_use]
75    pub fn with_table_definition(mut self, table_definition: TableDefinition) -> Self {
76        self.table_definition = Some(table_definition);
77        self
78    }
79
80    /// Builds a `DuckDBTableWriter` from the provided configuration.
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if any of the required fields are missing:
85    /// - `read_provider`
86    /// - `pool`
87    /// - `table_definition`
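    ///
    /// # Example
    ///
    /// A minimal sketch of intended usage (assumes `read_provider`, `pool`, and
    /// `table_definition` have already been constructed elsewhere):
    ///
    /// ```rust,ignore
    /// let writer = DuckDBTableWriterBuilder::new()
    ///     .with_read_provider(read_provider)
    ///     .with_pool(pool)
    ///     .with_table_definition(table_definition)
    ///     .set_on_conflict(None)
    ///     .build()?;
    /// ```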
88    pub fn build(self) -> super::Result<DuckDBTableWriter> {
89        let Some(read_provider) = self.read_provider else {
90            return Err(super::Error::MissingReadProvider);
91        };
92
93        let Some(pool) = self.pool else {
94            return Err(super::Error::MissingPool);
95        };
96
97        let Some(table_definition) = self.table_definition else {
98            return Err(super::Error::MissingTableDefinition);
99        };
100
101        Ok(DuckDBTableWriter {
102            read_provider,
103            on_conflict: self.on_conflict,
104            table_definition: Arc::new(table_definition),
105            pool,
106        })
107    }
108}
109
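/// `TableProvider` that serves reads through the wrapped `read_provider` and
/// handles `INSERT INTO` by planning a `DuckDBDataSink` over the connection pool.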
110#[derive(Clone)]
111pub struct DuckDBTableWriter {
112    pub read_provider: Arc<dyn TableProvider>,
113    pool: Arc<DuckDbConnectionPool>,
114    table_definition: Arc<TableDefinition>,
115    on_conflict: Option<OnConflict>,
116}
117
118impl std::fmt::Debug for DuckDBTableWriter {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(f, "DuckDBTableWriter")
121    }
122}
123
124impl DuckDBTableWriter {
125    #[must_use]
126    pub fn pool(&self) -> Arc<DuckDbConnectionPool> {
127        Arc::clone(&self.pool)
128    }
129
130    #[must_use]
131    pub fn table_definition(&self) -> Arc<TableDefinition> {
132        Arc::clone(&self.table_definition)
133    }
134}
135
136#[async_trait]
137impl TableProvider for DuckDBTableWriter {
138    fn as_any(&self) -> &dyn Any {
139        self
140    }
141
142    fn schema(&self) -> SchemaRef {
143        self.read_provider.schema()
144    }
145
146    fn table_type(&self) -> TableType {
147        TableType::Base
148    }
149
150    fn constraints(&self) -> Option<&Constraints> {
151        self.table_definition.constraints()
152    }
153
154    async fn scan(
155        &self,
156        state: &dyn Session,
157        projection: Option<&Vec<usize>>,
158        filters: &[Expr],
159        limit: Option<usize>,
160    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
161        self.read_provider
162            .scan(state, projection, filters, limit)
163            .await
164    }
165
166    async fn insert_into(
167        &self,
168        _state: &dyn Session,
169        input: Arc<dyn ExecutionPlan>,
170        op: InsertOp,
171    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
172        Ok(Arc::new(DataSinkExec::new(
173            input,
174            Arc::new(DuckDBDataSink::new(
175                Arc::clone(&self.pool),
176                Arc::clone(&self.table_definition),
177                op,
178                self.on_conflict.clone(),
179                self.schema(),
180            )),
181            None,
182        )) as _)
183    }
184}
185
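/// `DataSink` that streams `RecordBatch`es from DataFusion into DuckDB,
/// dispatching to an append load or a full overwrite based on the `InsertOp`.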
186#[derive(Clone)]
187pub(crate) struct DuckDBDataSink {
188    pool: Arc<DuckDbConnectionPool>,
189    table_definition: Arc<TableDefinition>,
190    overwrite: InsertOp,
191    on_conflict: Option<OnConflict>,
192    schema: SchemaRef,
193}
194
195#[async_trait]
196impl DataSink for DuckDBDataSink {
197    fn as_any(&self) -> &dyn Any {
198        self
199    }
200
201    fn metrics(&self) -> Option<MetricsSet> {
202        None
203    }
204
205    fn schema(&self) -> &SchemaRef {
206        &self.schema
207    }
208
209    async fn write_all(
210        &self,
211        mut data: SendableRecordBatchStream,
212        _context: &Arc<TaskContext>,
213    ) -> datafusion::common::Result<u64> {
214        let pool = Arc::clone(&self.pool);
215        let table_definition = Arc::clone(&self.table_definition);
216        let overwrite = self.overwrite;
217        let on_conflict = self.on_conflict.clone();
218
219        // Cap the channel at 100 queued RecordBatches for cases where DuckDB is slower than the writer stream,
220        // so that memory usage does not grow unbounded. Once the maximum number of RecordBatches is queued, the writer stream
221        // waits until DuckDB is able to process more data.
222        let (batch_tx, batch_rx): (Sender<RecordBatch>, Receiver<RecordBatch>) = mpsc::channel(100);
223
224        // Because the main task/stream can be dropped or fail mid-write, a oneshot channel signals the writer that all data has been received and the transaction should be committed.
225        let (notify_commit_transaction, on_commit_transaction) = tokio::sync::oneshot::channel();
226
227        let schema = data.schema();
228
229        let duckdb_write_handle: JoinHandle<datafusion::common::Result<u64>> =
230            tokio::task::spawn_blocking(move || {
231                let num_rows = match overwrite {
232                    InsertOp::Overwrite => insert_overwrite(
233                        pool,
234                        &table_definition,
235                        batch_rx,
236                        on_conflict.as_ref(),
237                        on_commit_transaction,
238                        schema,
239                    )?,
240                    InsertOp::Append | InsertOp::Replace => insert_append(
241                        pool,
242                        &table_definition,
243                        batch_rx,
244                        on_conflict.as_ref(),
245                        on_commit_transaction,
246                        schema,
247                    )?,
248                };
249
250                Ok(num_rows)
251            });
252
253        while let Some(batch) = data.next().await {
254            let batch = batch.map_err(check_and_mark_retriable_error)?;
255
256            if let Some(constraints) = self.table_definition.constraints() {
257                constraints::validate_batch_with_constraints(&[batch.clone()], constraints)
258                    .await
259                    .context(super::ConstraintViolationSnafu)
260                    .map_err(to_datafusion_error)?;
261            }
262
263            if let Err(send_error) = batch_tx.send(batch).await {
264                match duckdb_write_handle.await {
265                    Err(join_error) => {
266                        return Err(DataFusionError::Execution(format!(
267                            "Error writing to DuckDB: {join_error}"
268                        )));
269                    }
270                    Ok(Err(datafusion_error)) => {
271                        return Err(datafusion_error);
272                    }
273                    _ => {
274                        return Err(DataFusionError::Execution(format!(
275                            "Unable to send RecordBatch to DuckDB writer: {send_error}"
276                        )))
277                    }
278                };
279            }
280        }
281
282        if notify_commit_transaction.send(()).is_err() {
283            return Err(DataFusionError::Execution(
284                "Unable to send message to commit transaction to DuckDB writer.".to_string(),
285            ));
286        };
287
288        // Drop the sender to signal the receiver that no more data is coming
289        drop(batch_tx);
290
291        match duckdb_write_handle.await {
292            Ok(result) => result,
293            Err(e) => Err(DataFusionError::Execution(format!(
294                "Error writing to DuckDB: {e}"
295            ))),
296        }
297    }
298}
299
300impl DuckDBDataSink {
301    pub(crate) fn new(
302        pool: Arc<DuckDbConnectionPool>,
303        table_definition: Arc<TableDefinition>,
304        overwrite: InsertOp,
305        on_conflict: Option<OnConflict>,
306        schema: SchemaRef,
307    ) -> Self {
308        Self {
309            pool,
310            table_definition,
311            overwrite,
312            on_conflict,
313            schema,
314        }
315    }
316}
317
318impl std::fmt::Debug for DuckDBDataSink {
319    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
320        write!(f, "DuckDBDataSink")
321    }
322}
323
324impl DisplayAs for DuckDBDataSink {
325    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
326        write!(f, "DuckDBDataSink")
327    }
328}
329
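/// Appends the streamed `RecordBatch`es to the existing table for this table
/// definition inside a single transaction, committing only after the caller
/// signals completion through `on_commit_transaction`. If the table was empty
/// and had no indexes yet, the configured indexes are created afterwards in a
/// follow-up transaction.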
330fn insert_append(
331    pool: Arc<DuckDbConnectionPool>,
332    table_definition: &Arc<TableDefinition>,
333    batch_rx: Receiver<RecordBatch>,
334    on_conflict: Option<&OnConflict>,
335    mut on_commit_transaction: tokio::sync::oneshot::Receiver<()>,
336    schema: SchemaRef,
337) -> datafusion::common::Result<u64> {
338    let mut db_conn = pool
339        .connect_sync()
340        .context(super::DbConnectionPoolSnafu)
341        .map_err(to_retriable_data_write_error)?;
342
343    let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;
344
345    let tx = duckdb_conn
346        .conn
347        .transaction()
348        .context(super::UnableToBeginTransactionSnafu)
349        .map_err(to_retriable_data_write_error)?;
350
351    let append_table = TableManager::new(Arc::clone(table_definition))
352        .with_internal(false)
353        .map_err(to_retriable_data_write_error)?;
354
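    // Apply the configured indexes after this load only if the table currently has none and is
    // still empty, i.e. it looks like a freshly created table that has not been loaded before.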
355    let should_have_indexes = !append_table.indexes_vec().is_empty();
356    let has_indexes = !append_table
357        .current_indexes(&tx)
358        .map_err(to_retriable_data_write_error)?
359        .is_empty();
360    let is_empty_table = append_table
361        .get_row_count(&tx)
362        .map_err(to_retriable_data_write_error)?
363        == 0;
364    let should_apply_indexes = should_have_indexes && !has_indexes && is_empty_table;
365
366    let append_table_schema = append_table
367        .current_schema(&tx)
368        .map_err(to_retriable_data_write_error)?;
369
370    if SCHEMA_EQUIVALENCE_ENABLED && !schema.equivalent_names_and_types(&append_table_schema) {
371        return Err(DataFusionError::Execution(
372            "Schema of the append table does not match the schema of the new append data."
373                .to_string(),
374        ));
375    }
376
377    tracing::debug!(
378        "Append load for {table_name}",
379        table_name = append_table.table_name()
380    );
381    let num_rows = write_to_table(&append_table, &tx, schema, batch_rx, on_conflict)
382        .map_err(to_retriable_data_write_error)?;
383
384    on_commit_transaction
385        .try_recv()
386        .map_err(to_retriable_data_write_error)?;
387
388    tx.commit()
389        .context(super::UnableToCommitTransactionSnafu)
390        .map_err(to_retriable_data_write_error)?;
391
392    let tx = duckdb_conn
393        .conn
394        .transaction()
395        .context(super::UnableToBeginTransactionSnafu)
396        .map_err(to_datafusion_error)?;
397
398    // apply indexes only if the table is newly created (empty and without existing indexes)
399    if should_apply_indexes {
400        tracing::debug!(
401            "Load for table {table_name} complete, applying constraints and indexes.",
402            table_name = append_table.table_name()
403        );
404
405        append_table
406            .create_indexes(&tx)
407            .map_err(to_retriable_data_write_error)?;
408    }
409
410    let primary_keys_match = append_table
411        .verify_primary_keys_match(&append_table, &tx)
412        .map_err(to_retriable_data_write_error)?;
413    let indexes_match = append_table
414        .verify_indexes_match(&append_table, &tx)
415        .map_err(to_retriable_data_write_error)?;
416
417    if !primary_keys_match {
418        return Err(DataFusionError::Execution(
419            "Primary keys do not match between the new table and the existing table.\nEnsure primary key configuration is the same as the existing table, or manually migrate the table.".to_string(),
420        ));
421    }
422
423    if !indexes_match {
424        return Err(DataFusionError::Execution(
425            "Indexes do not match between the new table and the existing table.\nEnsure index configuration is the same as the existing table, or manually migrate the table.".to_string(),
426        ));
427    }
428
429    tx.commit()
430        .context(super::UnableToCommitTransactionSnafu)
431        .map_err(to_retriable_data_write_error)?;
432
433    Ok(num_rows)
434}
435
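/// Performs an overwrite load: the streamed `RecordBatch`es are written into a
/// brand-new internal table, the public view is re-pointed at that table, and
/// the old base table plus any previous internal tables are dropped. Indexes
/// are created on the new table in a follow-up transaction after the data load.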
436#[allow(clippy::too_many_lines)]
437fn insert_overwrite(
438    pool: Arc<DuckDbConnectionPool>,
439    table_definition: &Arc<TableDefinition>,
440    batch_rx: Receiver<RecordBatch>,
441    on_conflict: Option<&OnConflict>,
442    mut on_commit_transaction: tokio::sync::oneshot::Receiver<()>,
443    schema: SchemaRef,
444) -> datafusion::common::Result<u64> {
445    let cloned_pool = Arc::clone(&pool);
446    let mut db_conn = pool
447        .connect_sync()
448        .context(super::DbConnectionPoolSnafu)
449        .map_err(to_retriable_data_write_error)?;
450
451    let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;
452
453    let tx = duckdb_conn
454        .conn
455        .transaction()
456        .context(super::UnableToBeginTransactionSnafu)
457        .map_err(to_retriable_data_write_error)?;
458
459    let new_table = TableManager::new(Arc::clone(table_definition))
460        .with_internal(true)
461        .map_err(to_retriable_data_write_error)?;
462
463    new_table
464        .create_table(cloned_pool, &tx)
465        .map_err(to_retriable_data_write_error)?;
466
467    let existing_tables = new_table
468        .list_other_internal_tables(&tx)
469        .map_err(to_retriable_data_write_error)?;
470    let base_table = new_table
471        .base_table(&tx)
472        .map_err(to_retriable_data_write_error)?;
473    let last_table = match (existing_tables.last(), base_table.as_ref()) {
474        (Some(internal_table), Some(base_table)) => {
475            return Err(DataFusionError::Execution(
476                format!("Failed to insert data for DuckDB - both an internal table and definition base table were found.\nManual table migration is required - delete the table '{internal_table}' or '{base_table}' and try again.",
477                internal_table = internal_table.0.table_name(),
478                base_table = base_table.table_name())));
479        }
480        (Some((table, _)), None) | (None, Some(table)) => Some(table),
481        (None, None) => None,
482    };
483
484    if let Some(last_table) = last_table {
485        let should_have_indexes = !last_table.indexes_vec().is_empty();
486        let has_indexes = !last_table
487            .current_indexes(&tx)
488            .map_err(to_retriable_data_write_error)?
489            .is_empty();
490        let is_empty_table = last_table
491            .get_row_count(&tx)
492            .map_err(to_retriable_data_write_error)?
493            == 0;
494        let should_apply_indexes = should_have_indexes && !has_indexes && is_empty_table;
495
496        let last_table_schema = last_table
497            .current_schema(&tx)
498            .map_err(to_retriable_data_write_error)?;
499        let new_table_schema = new_table
500            .current_schema(&tx)
501            .map_err(to_retriable_data_write_error)?;
502
503        if SCHEMA_EQUIVALENCE_ENABLED
504            && !new_table_schema.equivalent_names_and_types(&last_table_schema)
505        {
506            return Err(DataFusionError::Execution(
507                "Schema does not match between the new table and the existing table.".to_string(),
508            ));
509        }
510
511        if !should_apply_indexes {
512            // compare indexes and primary keys
513            let primary_keys_match = new_table
514                .verify_primary_keys_match(last_table, &tx)
515                .map_err(to_retriable_data_write_error)?;
516            let indexes_match = new_table
517                .verify_indexes_match(last_table, &tx)
518                .map_err(to_retriable_data_write_error)?;
519
520            if !primary_keys_match {
521                return Err(DataFusionError::Execution(
522                    "Primary keys do not match between the new table and the existing table.\nEnsure primary key configuration is the same as the existing table, or manually migrate the table."
523                        .to_string(),
524                ));
525            }
526
527            if !indexes_match {
528                return Err(DataFusionError::Execution(
529                    "Indexes do not match between the new table and the existing table.\nEnsure index configuration is the same as the existing table, or manually migrate the table.".to_string(),
530                ));
531            }
532        }
533    }
534
535    tracing::debug!("Initial load for {}", new_table.table_name());
536    let num_rows = write_to_table(&new_table, &tx, schema, batch_rx, on_conflict)
537        .map_err(to_retriable_data_write_error)?;
538
539    on_commit_transaction
540        .try_recv()
541        .map_err(to_retriable_data_write_error)?;
542
543    if let Some(base_table) = base_table {
544        base_table
545            .delete_table(&tx)
546            .map_err(to_retriable_data_write_error)?;
547    }
548
549    new_table
550        .create_view(&tx)
551        .map_err(to_retriable_data_write_error)?;
552
553    tx.commit()
554        .context(super::UnableToCommitTransactionSnafu)
555        .map_err(to_retriable_data_write_error)?;
556
557    tracing::debug!(
558        "Load for table {table_name} complete, applying constraints and indexes.",
559        table_name = new_table.table_name()
560    );
561
562    let tx = duckdb_conn
563        .conn
564        .transaction()
565        .context(super::UnableToBeginTransactionSnafu)
566        .map_err(to_datafusion_error)?;
567
568    for (table, _) in existing_tables {
569        table
570            .delete_table(&tx)
571            .map_err(to_retriable_data_write_error)?;
572    }
573
574    // Apply constraints and indexes.
575    new_table
576        .create_indexes(&tx)
577        .map_err(to_retriable_data_write_error)?;
578
579    tx.commit()
580        .context(super::UnableToCommitTransactionSnafu)
581        .map_err(to_retriable_data_write_error)?;
582
583    Ok(num_rows)
584}
585
586#[allow(clippy::doc_markdown)]
587/// Writes a stream of ``RecordBatch``es to a DuckDB table.
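///
/// The stream is exposed to DuckDB as a temporary Arrow scan view named
/// `__scan_<table>_<timestamp>`; rows are then inserted from that view into
/// the target table via `ViewCreator::insert_into`, and the view is dropped.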
588fn write_to_table(
589    table: &TableManager,
590    tx: &Transaction<'_>,
591    schema: SchemaRef,
592    data_batches: Receiver<RecordBatch>,
593    on_conflict: Option<&OnConflict>,
594) -> datafusion::common::Result<u64> {
595    let stream = FFI_ArrowArrayStream::new(Box::new(RecordBatchReaderFromStream::new(
596        data_batches,
597        schema,
598    )));
599
600    let current_ts = SystemTime::now()
601        .duration_since(UNIX_EPOCH)
602        .context(super::UnableToGetSystemTimeSnafu)
603        .map_err(to_datafusion_error)?
604        .as_millis();
605
606    let view_name = format!("__scan_{}_{current_ts}", table.table_name());
607    tx.register_arrow_scan_view(&view_name, &stream)
608        .context(super::UnableToRegisterArrowScanViewSnafu)
609        .map_err(to_datafusion_error)?;
610
611    let view = ViewCreator::from_name(RelationName::new(view_name));
612    let rows = view
613        .insert_into(table, tx, on_conflict)
614        .map_err(to_datafusion_error)?;
615    view.drop(tx).map_err(to_datafusion_error)?;
616
617    Ok(rows as u64)
618}
619
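/// Adapts a bounded `tokio` mpsc receiver of `RecordBatch`es into a blocking
/// `RecordBatchReader`, so batches streamed by `write_all` can be handed to
/// DuckDB through an `FFI_ArrowArrayStream`.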
620struct RecordBatchReaderFromStream {
621    stream: Receiver<RecordBatch>,
622    schema: SchemaRef,
623}
624
625impl RecordBatchReaderFromStream {
626    fn new(stream: Receiver<RecordBatch>, schema: SchemaRef) -> Self {
627        Self { stream, schema }
628    }
629}
630
631impl Iterator for RecordBatchReaderFromStream {
632    type Item = Result<RecordBatch, ArrowError>;
633
634    fn next(&mut self) -> Option<Self::Item> {
635        self.stream.blocking_recv().map(Ok)
636    }
637}
638
639impl RecordBatchReader for RecordBatchReaderFromStream {
640    fn schema(&self) -> SchemaRef {
641        Arc::clone(&self.schema)
642    }
643}
644
645#[cfg(test)]
646mod test {
647    use arrow::array::{Int64Array, StringArray};
648    use datafusion::physical_plan::memory::MemoryStream;
649
650    use super::*;
651    use crate::{
652        duckdb::creator::tests::{get_basic_table_definition, get_mem_duckdb, init_tracing},
653        util::{column_reference::ColumnReference, indexes::IndexType},
654    };
655
656    #[tokio::test]
657    async fn test_write_to_table_overwrite_without_previous_table() {
658        // Test scenario: Write to a table with overwrite mode without a previous table
659        // Expected behavior: Data sink creates a new internal table, writes data to it, and creates a view with the table definition name
660
661        let _guard = init_tracing(None);
662        let pool = get_mem_duckdb();
663
664        let table_definition = get_basic_table_definition();
665
666        let duckdb_sink = DuckDBDataSink::new(
667            Arc::clone(&pool),
668            Arc::clone(&table_definition),
669            InsertOp::Overwrite,
670            None,
671            table_definition.schema(),
672        );
673        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
674
675        // id, name
676        // 1, "a"
677        // 2, "b"
678        let batches = vec![RecordBatch::try_new(
679            Arc::clone(&table_definition.schema()),
680            vec![
681                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
682                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
683            ],
684        )
685        .expect("should create a record batch")];
686
687        let stream = Box::pin(
688            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
689        );
690
691        data_sink
692            .write_all(stream, &Arc::new(TaskContext::default()))
693            .await
694            .expect("to write all");
695
696        let mut conn = pool.connect_sync().expect("to connect");
697        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
698        let tx = duckdb.conn.transaction().expect("to begin transaction");
699        let mut internal_tables = table_definition
700            .list_internal_tables(&tx)
701            .expect("to list internal tables");
702        assert_eq!(internal_tables.len(), 1);
703
704        let table_name = internal_tables.pop().expect("should have a table").0;
705
706        let rows = tx
707            .query_row(&format!("SELECT COUNT(1) FROM {table_name}"), [], |row| {
708                row.get::<_, i64>(0)
709            })
710            .expect("to get count");
711        assert_eq!(rows, 2);
712
713        // expect a view to be created with the table definition name
714        let view_rows = tx
715            .query_row(
716                &format!(
717                    "SELECT COUNT(1) FROM {view_name}",
718                    view_name = table_definition.name()
719                ),
720                [],
721                |row| row.get::<_, i64>(0),
722            )
723            .expect("to get count");
724
725        assert_eq!(view_rows, 2);
726
727        tx.rollback().expect("to rollback");
728    }
729
730    #[tokio::test]
731    async fn test_write_to_table_overwrite_with_previous_base_table() {
732        // Test scenario: Write to a table with overwrite mode with a previous base table
733        // Expected behavior: Data sink creates a new internal table and writes data to it.
734        // Before the view is created, the base table must be dropped because the view uses the same name.
735
736        let _guard = init_tracing(None);
737        let pool = get_mem_duckdb();
738
739        let table_definition = get_basic_table_definition();
740
741        let cloned_pool = Arc::clone(&pool);
742        let mut conn = cloned_pool.connect_sync().expect("to connect");
743        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
744        let tx = duckdb.conn.transaction().expect("to begin transaction");
745
746        // make an existing table to overwrite
747        let overwrite_table = TableManager::new(Arc::clone(&table_definition))
748            .with_internal(false)
749            .expect("to create table");
750
751        overwrite_table
752            .create_table(Arc::clone(&pool), &tx)
753            .expect("to create table");
754
755        tx.execute(
756            &format!(
757                "INSERT INTO {table_name} VALUES (3, 'c')",
758                table_name = overwrite_table.table_name()
759            ),
760            [],
761        )
762        .expect("to insert");
763
764        tx.commit().expect("to commit");
765
766        let duckdb_sink = DuckDBDataSink::new(
767            Arc::clone(&pool),
768            Arc::clone(&table_definition),
769            InsertOp::Overwrite,
770            None,
771            table_definition.schema(),
772        );
773        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
774
775        // id, name
776        // 1, "a"
777        // 2, "b"
778        let batches = vec![RecordBatch::try_new(
779            Arc::clone(&table_definition.schema()),
780            vec![
781                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
782                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
783            ],
784        )
785        .expect("should create a record batch")];
786
787        let stream = Box::pin(
788            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
789        );
790
791        data_sink
792            .write_all(stream, &Arc::new(TaskContext::default()))
793            .await
794            .expect("to write all");
795
796        let mut conn = pool.connect_sync().expect("to connect");
797        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
798        let tx = duckdb.conn.transaction().expect("to begin transaction");
799        let mut internal_tables = table_definition
800            .list_internal_tables(&tx)
801            .expect("to list internal tables");
802        assert_eq!(internal_tables.len(), 1);
803
804        let table_name = internal_tables.pop().expect("should have a table").0;
805
806        let rows = tx
807            .query_row(&format!("SELECT COUNT(1) FROM {table_name}"), [], |row| {
808                row.get::<_, i64>(0)
809            })
810            .expect("to get count");
811        assert_eq!(rows, 2);
812
813        let table_creator =
814            TableManager::from_table_name(Arc::clone(&table_definition), table_name);
815        let base_table = table_creator.base_table(&tx).expect("to get base table");
816
817        assert!(base_table.is_none()); // base table should get deleted
818
819        // expect a view to be created with the table definition name
820        let view_rows = tx
821            .query_row(
822                &format!(
823                    "SELECT COUNT(1) FROM {view_name}",
824                    view_name = table_definition.name()
825                ),
826                [],
827                |row| row.get::<_, i64>(0),
828            )
829            .expect("to get count");
830
831        assert_eq!(view_rows, 2);
832
833        tx.rollback().expect("to rollback");
834    }
835
836    #[tokio::test]
837    async fn test_write_to_table_overwrite_with_previous_internal_table() {
838        // Test scenario: Write to a table with overwrite mode with a previous internal table
839        // Expected behavior: Data sink creates a new internal table, writes data to it, and recreates the view on top of it.
840        // The previous internal table is dropped once the new view is in place.
841
842        let _guard = init_tracing(None);
843        let pool = get_mem_duckdb();
844
845        let table_definition = get_basic_table_definition();
846
847        let cloned_pool = Arc::clone(&pool);
848        let mut conn = cloned_pool.connect_sync().expect("to connect");
849        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
850        let tx = duckdb.conn.transaction().expect("to begin transaction");
851
852        // make an existing table to overwrite
853        let overwrite_table = TableManager::new(Arc::clone(&table_definition))
854            .with_internal(true)
855            .expect("to create table");
856
857        overwrite_table
858            .create_table(Arc::clone(&pool), &tx)
859            .expect("to create table");
860
861        tx.execute(
862            &format!(
863                "INSERT INTO {table_name} VALUES (3, 'c')",
864                table_name = overwrite_table.table_name()
865            ),
866            [],
867        )
868        .expect("to insert");
869
870        tx.commit().expect("to commit");
871
872        let duckdb_sink = DuckDBDataSink::new(
873            Arc::clone(&pool),
874            Arc::clone(&table_definition),
875            InsertOp::Overwrite,
876            None,
877            table_definition.schema(),
878        );
879        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
880
881        // id, name
882        // 1, "a"
883        // 2, "b"
884        let batches = vec![RecordBatch::try_new(
885            Arc::clone(&table_definition.schema()),
886            vec![
887                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
888                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
889            ],
890        )
891        .expect("should create a record batch")];
892
893        let stream = Box::pin(
894            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
895        );
896
897        data_sink
898            .write_all(stream, &Arc::new(TaskContext::default()))
899            .await
900            .expect("to write all");
901
902        let mut conn = pool.connect_sync().expect("to connect");
903        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
904        let tx = duckdb.conn.transaction().expect("to begin transaction");
905        let mut internal_tables = table_definition
906            .list_internal_tables(&tx)
907            .expect("to list internal tables");
908        assert_eq!(internal_tables.len(), 1);
909
910        let table_name = internal_tables.pop().expect("should have a table").0;
911
912        let rows = tx
913            .query_row(&format!("SELECT COUNT(1) FROM {table_name}"), [], |row| {
914                row.get::<_, i64>(0)
915            })
916            .expect("to get count");
917        assert_eq!(rows, 2);
918
919        // expect a view to be created with the table definition name
920        let view_rows = tx
921            .query_row(
922                &format!(
923                    "SELECT COUNT(1) FROM {view_name}",
924                    view_name = table_definition.name()
925                ),
926                [],
927                |row| row.get::<_, i64>(0),
928            )
929            .expect("to get count");
930
931        assert_eq!(view_rows, 2);
932
933        tx.rollback().expect("to rollback");
934    }
935
936    #[tokio::test]
937    async fn test_write_to_table_append_with_previous_table() {
938        // Test scenario: Write to a table with append mode with a previous table
939        // Expected behavior: Data sink appends data to the existing table. No new internal table should be created.
940        // The existing table is re-used.
941
942        let _guard = init_tracing(None);
943        let pool = get_mem_duckdb();
944
945        let cloned_pool = Arc::clone(&pool);
946        let mut conn = cloned_pool.connect_sync().expect("to connect");
947        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
948        let tx = duckdb.conn.transaction().expect("to begin transaction");
949
950        let table_definition = get_basic_table_definition();
951
952        // make an existing table to append from
953        let append_table = TableManager::new(Arc::clone(&table_definition))
954            .with_internal(false)
955            .expect("to create table");
956
957        append_table
958            .create_table(Arc::clone(&pool), &tx)
959            .expect("to create table");
960
961        tx.execute(
962            &format!(
963                "INSERT INTO {table_name} VALUES (3, 'c')",
964                table_name = append_table.table_name()
965            ),
966            [],
967        )
968        .expect("to insert");
969
970        tx.commit().expect("to commit");
971
972        let duckdb_sink = DuckDBDataSink::new(
973            Arc::clone(&pool),
974            Arc::clone(&table_definition),
975            InsertOp::Append,
976            None,
977            table_definition.schema(),
978        );
979        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
980
981        // id, name
982        // 1, "a"
983        // 2, "b"
984        let batches = vec![RecordBatch::try_new(
985            Arc::clone(&table_definition.schema()),
986            vec![
987                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
988                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
989            ],
990        )
991        .expect("should create a record batch")];
992
993        let stream = Box::pin(
994            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
995        );
996
997        data_sink
998            .write_all(stream, &Arc::new(TaskContext::default()))
999            .await
1000            .expect("to write all");
1001
1002        let tx = duckdb.conn.transaction().expect("to begin transaction");
1003
1004        let internal_tables = table_definition
1005            .list_internal_tables(&tx)
1006            .expect("to list internal tables");
1007        assert_eq!(internal_tables.len(), 0);
1008
1009        let base_table = append_table
1010            .base_table(&tx)
1011            .expect("to get base table")
1012            .expect("should have a base table");
1013
1014        let rows = tx
1015            .query_row(
1016                &format!(
1017                    "SELECT COUNT(1) FROM {table_name}",
1018                    table_name = base_table.table_name()
1019                ),
1020                [],
1021                |row| row.get::<_, i64>(0),
1022            )
1023            .expect("to get count");
1024        assert_eq!(rows, 3);
1025
1026        tx.rollback().expect("to rollback");
1027    }
1028
1029    #[tokio::test]
1030    async fn test_write_to_table_append_with_previous_table_needs_indexes() {
1031        // Test scenario: Write to a table with append mode to a previous empty table whose indexes are configured but not yet applied
1032        // Expected behavior: Data sink appends data to the existing table. No new internal table should be created.
1033        // The existing table is re-used and the missing indexes are applied after the load completes.
1034
1035        let _guard = init_tracing(None);
1036        let pool = get_mem_duckdb();
1037
1038        let cloned_pool = Arc::clone(&pool);
1039        let mut conn = cloned_pool.connect_sync().expect("to connect");
1040        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
1041        let tx = duckdb.conn.transaction().expect("to begin transaction");
1042
1043        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
1044            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
1045            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
1046        ]));
1047
1048        let table_definition = Arc::new(
1049            TableDefinition::new(RelationName::new("test_table"), Arc::clone(&schema))
1050                .with_indexes(
1051                    vec![(
1052                        ColumnReference::try_from("id").expect("valid column ref"),
1053                        IndexType::Enabled,
1054                    )]
1055                    .into_iter()
1056                    .collect(),
1057                ),
1058        );
1059
1060        // make an existing table to append from
1061        let append_table = TableManager::new(Arc::clone(&table_definition))
1062            .with_internal(false)
1063            .expect("to create table");
1064
1065        append_table
1066            .create_table(Arc::clone(&pool), &tx)
1067            .expect("to create table");
1068
1069        // don't apply indexes, and leave the table empty to simulate a new table from TableProviderFactory::create()
1070
1071        tx.commit().expect("to commit");
1072
1073        let duckdb_sink = DuckDBDataSink::new(
1074            Arc::clone(&pool),
1075            Arc::clone(&table_definition),
1076            InsertOp::Append,
1077            None,
1078            table_definition.schema(),
1079        );
1080        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
1081
1082        // id, name
1083        // 1, "a"
1084        // 2, "b"
1085        let batches = vec![RecordBatch::try_new(
1086            Arc::clone(&table_definition.schema()),
1087            vec![
1088                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
1089                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
1090            ],
1091        )
1092        .expect("should create a record batch")];
1093
1094        let stream = Box::pin(
1095            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
1096        );
1097
1098        data_sink
1099            .write_all(stream, &Arc::new(TaskContext::default()))
1100            .await
1101            .expect("to write all");
1102
1103        let tx = duckdb.conn.transaction().expect("to begin transaction");
1104
1105        let internal_tables = table_definition
1106            .list_internal_tables(&tx)
1107            .expect("to list internal tables");
1108        assert_eq!(internal_tables.len(), 0);
1109
1110        let base_table = append_table
1111            .base_table(&tx)
1112            .expect("to get base table")
1113            .expect("should have a base table");
1114
1115        let rows = tx
1116            .query_row(
1117                &format!(
1118                    "SELECT COUNT(1) FROM {table_name}",
1119                    table_name = base_table.table_name()
1120                ),
1121                [],
1122                |row| row.get::<_, i64>(0),
1123            )
1124            .expect("to get count");
1125        assert_eq!(rows, 2);
1126
1127        // at this point, indexes should be applied
1128        let indexes = append_table.current_indexes(&tx).expect("to get indexes");
1129        assert_eq!(indexes.len(), 1);
1130
1131        tx.rollback().expect("to rollback");
1132    }
1133}