use std::time::{SystemTime, UNIX_EPOCH};
use std::{any::Any, fmt, sync::Arc};

use crate::duckdb::DuckDB;
use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
use crate::util::{
    constraints,
    on_conflict::OnConflict,
    retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error},
};
use arrow::array::RecordBatchReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::{array::RecordBatch, datatypes::SchemaRef};
use arrow_schema::ArrowError;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::common::{Constraints, SchemaExt};
use datafusion::datasource::sink::{DataSink, DataSinkExec};
use datafusion::logical_expr::dml::InsertOp;
use datafusion::{
    datasource::{TableProvider, TableType},
    error::DataFusionError,
    execution::{SendableRecordBatchStream, TaskContext},
    logical_expr::Expr,
    physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
};
use duckdb::Transaction;
use futures::StreamExt;
use snafu::prelude::*;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::task::JoinHandle;

use super::creator::{TableDefinition, TableManager, ViewCreator};
use super::{to_datafusion_error, RelationName};

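// Schema equivalence checks are currently disabled. When enabled, incoming data must match the
// existing table's column names and types before an append or overwrite is allowed to proceed.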
const SCHEMA_EQUIVALENCE_ENABLED: bool = false;

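/// Builder for a [`DuckDBTableWriter`].
///
/// The read provider, connection pool, and table definition are required; the
/// [`OnConflict`] behavior is optional.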
#[derive(Default)]
pub struct DuckDBTableWriterBuilder {
    read_provider: Option<Arc<dyn TableProvider>>,
    pool: Option<Arc<DuckDbConnectionPool>>,
    on_conflict: Option<OnConflict>,
    table_definition: Option<TableDefinition>,
}

impl DuckDBTableWriterBuilder {
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    #[must_use]
    pub fn with_read_provider(mut self, read_provider: Arc<dyn TableProvider>) -> Self {
        self.read_provider = Some(read_provider);
        self
    }

    #[must_use]
    pub fn with_pool(mut self, pool: Arc<DuckDbConnectionPool>) -> Self {
        self.pool = Some(pool);
        self
    }

    #[must_use]
    pub fn set_on_conflict(mut self, on_conflict: Option<OnConflict>) -> Self {
        self.on_conflict = on_conflict;
        self
    }

    #[must_use]
    pub fn with_table_definition(mut self, table_definition: TableDefinition) -> Self {
        self.table_definition = Some(table_definition);
        self
    }

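    /// Builds the configured [`DuckDBTableWriter`].
    ///
    /// # Errors
    ///
    /// Returns an error if the read provider, connection pool, or table definition is missing.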
    pub fn build(self) -> super::Result<DuckDBTableWriter> {
        let Some(read_provider) = self.read_provider else {
            return Err(super::Error::MissingReadProvider);
        };

        let Some(pool) = self.pool else {
            return Err(super::Error::MissingPool);
        };

        let Some(table_definition) = self.table_definition else {
            return Err(super::Error::MissingTableDefinition);
        };

        Ok(DuckDBTableWriter {
            read_provider,
            on_conflict: self.on_conflict,
            table_definition: Arc::new(table_definition),
            pool,
        })
    }
}

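/// A writable [`TableProvider`] for a DuckDB table: reads are served by the wrapped read
/// provider, while inserts are written through the connection pool.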
#[derive(Clone)]
pub struct DuckDBTableWriter {
    pub read_provider: Arc<dyn TableProvider>,
    pool: Arc<DuckDbConnectionPool>,
    table_definition: Arc<TableDefinition>,
    on_conflict: Option<OnConflict>,
}

impl std::fmt::Debug for DuckDBTableWriter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "DuckDBTableWriter")
    }
}

impl DuckDBTableWriter {
    #[must_use]
    pub fn pool(&self) -> Arc<DuckDbConnectionPool> {
        Arc::clone(&self.pool)
    }

    #[must_use]
    pub fn table_definition(&self) -> Arc<TableDefinition> {
        Arc::clone(&self.table_definition)
    }
}

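// Scans delegate to the wrapped read provider; `insert_into` plans a `DataSinkExec` backed by a
// `DuckDBDataSink` so incoming batches are written to DuckDB.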
#[async_trait]
impl TableProvider for DuckDBTableWriter {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.read_provider.schema()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    fn constraints(&self) -> Option<&Constraints> {
        self.table_definition.constraints()
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        self.read_provider
            .scan(state, projection, filters, limit)
            .await
    }

    async fn insert_into(
        &self,
        _state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        op: InsertOp,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(DataSinkExec::new(
            input,
            Arc::new(DuckDBDataSink::new(
                Arc::clone(&self.pool),
                Arc::clone(&self.table_definition),
                op,
                self.on_conflict.clone(),
                self.schema(),
            )),
            None,
        )) as _)
    }
}

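/// [`DataSink`] that streams [`RecordBatch`]es into DuckDB, either appending to the existing
/// table or overwriting it through a newly created internal table.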
#[derive(Clone)]
pub(crate) struct DuckDBDataSink {
    pool: Arc<DuckDbConnectionPool>,
    table_definition: Arc<TableDefinition>,
    overwrite: InsertOp,
    on_conflict: Option<OnConflict>,
    schema: SchemaRef,
}

#[async_trait]
impl DataSink for DuckDBDataSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn metrics(&self) -> Option<MetricsSet> {
        None
    }

    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

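    // The write runs on two tasks: this async task validates each batch against the table's
    // constraints and forwards it over a bounded channel, while a blocking task owns the DuckDB
    // connection and performs the actual insert. A oneshot message signals that every batch has
    // been received and the blocking task may commit its transaction.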
    async fn write_all(
        &self,
        mut data: SendableRecordBatchStream,
        _context: &Arc<TaskContext>,
    ) -> datafusion::common::Result<u64> {
        let pool = Arc::clone(&self.pool);
        let table_definition = Arc::clone(&self.table_definition);
        let overwrite = self.overwrite;
        let on_conflict = self.on_conflict.clone();

        let (batch_tx, batch_rx): (Sender<RecordBatch>, Receiver<RecordBatch>) =
            mpsc::channel(100);

        let (notify_commit_transaction, on_commit_transaction) = tokio::sync::oneshot::channel();

        let schema = data.schema();

        let duckdb_write_handle: JoinHandle<datafusion::common::Result<u64>> =
            tokio::task::spawn_blocking(move || {
                let num_rows = match overwrite {
                    InsertOp::Overwrite => insert_overwrite(
                        pool,
                        &table_definition,
                        batch_rx,
                        on_conflict.as_ref(),
                        on_commit_transaction,
                        schema,
                    )?,
                    InsertOp::Append | InsertOp::Replace => insert_append(
                        pool,
                        &table_definition,
                        batch_rx,
                        on_conflict.as_ref(),
                        on_commit_transaction,
                        schema,
                    )?,
                };

                Ok(num_rows)
            });

        while let Some(batch) = data.next().await {
            let batch = batch.map_err(check_and_mark_retriable_error)?;

            if let Some(constraints) = self.table_definition.constraints() {
                constraints::validate_batch_with_constraints(&[batch.clone()], constraints)
                    .await
                    .context(super::ConstraintViolationSnafu)
                    .map_err(to_datafusion_error)?;
            }

            if let Err(send_error) = batch_tx.send(batch).await {
                match duckdb_write_handle.await {
                    Err(join_error) => {
                        return Err(DataFusionError::Execution(format!(
                            "Error writing to DuckDB: {join_error}"
                        )));
                    }
                    Ok(Err(datafusion_error)) => {
                        return Err(datafusion_error);
                    }
                    _ => {
                        return Err(DataFusionError::Execution(format!(
                            "Unable to send RecordBatch to DuckDB writer: {send_error}"
                        )))
                    }
                };
            }
        }

        if notify_commit_transaction.send(()).is_err() {
            return Err(DataFusionError::Execution(
                "Unable to send message to commit transaction to DuckDB writer.".to_string(),
            ));
        };

        drop(batch_tx);

        match duckdb_write_handle.await {
            Ok(result) => result,
            Err(e) => Err(DataFusionError::Execution(format!(
                "Error writing to DuckDB: {e}"
            ))),
        }
    }
}

impl DuckDBDataSink {
    pub(crate) fn new(
        pool: Arc<DuckDbConnectionPool>,
        table_definition: Arc<TableDefinition>,
        overwrite: InsertOp,
        on_conflict: Option<OnConflict>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            pool,
            table_definition,
            overwrite,
            on_conflict,
            schema,
        }
    }
}

impl std::fmt::Debug for DuckDBDataSink {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "DuckDBDataSink")
    }
}

impl DisplayAs for DuckDBDataSink {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
        write!(f, "DuckDBDataSink")
    }
}

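/// Appends the received batches to the existing table in a single transaction. If the table is
/// empty and is missing its configured indexes, they are created after the data is committed.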
fn insert_append(
    pool: Arc<DuckDbConnectionPool>,
    table_definition: &Arc<TableDefinition>,
    batch_rx: Receiver<RecordBatch>,
    on_conflict: Option<&OnConflict>,
    mut on_commit_transaction: tokio::sync::oneshot::Receiver<()>,
    schema: SchemaRef,
) -> datafusion::common::Result<u64> {
    let mut db_conn = pool
        .connect_sync()
        .context(super::DbConnectionPoolSnafu)
        .map_err(to_retriable_data_write_error)?;

    let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;

    let tx = duckdb_conn
        .conn
        .transaction()
        .context(super::UnableToBeginTransactionSnafu)
        .map_err(to_retriable_data_write_error)?;

    let append_table = TableManager::new(Arc::clone(table_definition))
        .with_internal(false)
        .map_err(to_retriable_data_write_error)?;

    let should_have_indexes = !append_table.indexes_vec().is_empty();
    let has_indexes = !append_table
        .current_indexes(&tx)
        .map_err(to_retriable_data_write_error)?
        .is_empty();
    let is_empty_table = append_table
        .get_row_count(&tx)
        .map_err(to_retriable_data_write_error)?
        == 0;
    let should_apply_indexes = should_have_indexes && !has_indexes && is_empty_table;

    let append_table_schema = append_table
        .current_schema(&tx)
        .map_err(to_retriable_data_write_error)?;

    if SCHEMA_EQUIVALENCE_ENABLED && !schema.equivalent_names_and_types(&append_table_schema) {
        return Err(DataFusionError::Execution(
            "Schema of the append table does not match the schema of the new append data."
                .to_string(),
        ));
    }

    tracing::debug!(
        "Append load for {table_name}",
        table_name = append_table.table_name()
    );
    let num_rows = write_to_table(&append_table, &tx, schema, batch_rx, on_conflict)
        .map_err(to_retriable_data_write_error)?;

    on_commit_transaction
        .try_recv()
        .map_err(to_retriable_data_write_error)?;

    tx.commit()
        .context(super::UnableToCommitTransactionSnafu)
        .map_err(to_retriable_data_write_error)?;

    let tx = duckdb_conn
        .conn
        .transaction()
        .context(super::UnableToBeginTransactionSnafu)
        .map_err(to_datafusion_error)?;

    if should_apply_indexes {
        tracing::debug!(
            "Load for table {table_name} complete, applying constraints and indexes.",
            table_name = append_table.table_name()
        );

        append_table
            .create_indexes(&tx)
            .map_err(to_retriable_data_write_error)?;
    }

    let primary_keys_match = append_table
        .verify_primary_keys_match(&append_table, &tx)
        .map_err(to_retriable_data_write_error)?;
    let indexes_match = append_table
        .verify_indexes_match(&append_table, &tx)
        .map_err(to_retriable_data_write_error)?;

    if !primary_keys_match {
        return Err(DataFusionError::Execution(
            "Primary keys do not match between the new table and the existing table.\nEnsure primary key configuration is the same as the existing table, or manually migrate the table.".to_string(),
        ));
    }

    if !indexes_match {
        return Err(DataFusionError::Execution(
            "Indexes do not match between the new table and the existing table.\nEnsure index configuration is the same as the existing table, or manually migrate the table.".to_string(),
        ));
    }

    tx.commit()
        .context(super::UnableToCommitTransactionSnafu)
        .map_err(to_retriable_data_write_error)?;

    Ok(num_rows)
}

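/// Overwrites the table by loading the batches into a new internal table, pointing the public
/// view at it once the load commits, and then dropping older internal tables and creating
/// indexes in a follow-up transaction.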
#[allow(clippy::too_many_lines)]
fn insert_overwrite(
    pool: Arc<DuckDbConnectionPool>,
    table_definition: &Arc<TableDefinition>,
    batch_rx: Receiver<RecordBatch>,
    on_conflict: Option<&OnConflict>,
    mut on_commit_transaction: tokio::sync::oneshot::Receiver<()>,
    schema: SchemaRef,
) -> datafusion::common::Result<u64> {
    let cloned_pool = Arc::clone(&pool);
    let mut db_conn = pool
        .connect_sync()
        .context(super::DbConnectionPoolSnafu)
        .map_err(to_retriable_data_write_error)?;

    let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;

    let tx = duckdb_conn
        .conn
        .transaction()
        .context(super::UnableToBeginTransactionSnafu)
        .map_err(to_retriable_data_write_error)?;

    let new_table = TableManager::new(Arc::clone(table_definition))
        .with_internal(true)
        .map_err(to_retriable_data_write_error)?;

    new_table
        .create_table(cloned_pool, &tx)
        .map_err(to_retriable_data_write_error)?;

    let existing_tables = new_table
        .list_other_internal_tables(&tx)
        .map_err(to_retriable_data_write_error)?;
    let base_table = new_table
        .base_table(&tx)
        .map_err(to_retriable_data_write_error)?;
    let last_table = match (existing_tables.last(), base_table.as_ref()) {
        (Some(internal_table), Some(base_table)) => {
            return Err(DataFusionError::Execution(
                format!("Failed to insert data for DuckDB - both an internal table and definition base table were found.\nManual table migration is required - delete the table '{internal_table}' or '{base_table}' and try again.",
                internal_table = internal_table.0.table_name(),
                base_table = base_table.table_name())));
        }
        (Some((table, _)), None) | (None, Some(table)) => Some(table),
        (None, None) => None,
    };

    if let Some(last_table) = last_table {
        let should_have_indexes = !last_table.indexes_vec().is_empty();
        let has_indexes = !last_table
            .current_indexes(&tx)
            .map_err(to_retriable_data_write_error)?
            .is_empty();
        let is_empty_table = last_table
            .get_row_count(&tx)
            .map_err(to_retriable_data_write_error)?
            == 0;
        let should_apply_indexes = should_have_indexes && !has_indexes && is_empty_table;

        let last_table_schema = last_table
            .current_schema(&tx)
            .map_err(to_retriable_data_write_error)?;
        let new_table_schema = new_table
            .current_schema(&tx)
            .map_err(to_retriable_data_write_error)?;

        if SCHEMA_EQUIVALENCE_ENABLED
            && !new_table_schema.equivalent_names_and_types(&last_table_schema)
        {
            return Err(DataFusionError::Execution(
                "Schema does not match between the new table and the existing table.".to_string(),
            ));
        }

        if !should_apply_indexes {
            let primary_keys_match = new_table
                .verify_primary_keys_match(last_table, &tx)
                .map_err(to_retriable_data_write_error)?;
            let indexes_match = new_table
                .verify_indexes_match(last_table, &tx)
                .map_err(to_retriable_data_write_error)?;

            if !primary_keys_match {
                return Err(DataFusionError::Execution(
                    "Primary keys do not match between the new table and the existing table.\nEnsure primary key configuration is the same as the existing table, or manually migrate the table."
                        .to_string(),
                ));
            }

            if !indexes_match {
                return Err(DataFusionError::Execution(
                    "Indexes do not match between the new table and the existing table.\nEnsure index configuration is the same as the existing table, or manually migrate the table.".to_string(),
                ));
            }
        }
    }

    tracing::debug!("Initial load for {}", new_table.table_name());
    let num_rows = write_to_table(&new_table, &tx, schema, batch_rx, on_conflict)
        .map_err(to_retriable_data_write_error)?;

    on_commit_transaction
        .try_recv()
        .map_err(to_retriable_data_write_error)?;

    if let Some(base_table) = base_table {
        base_table
            .delete_table(&tx)
            .map_err(to_retriable_data_write_error)?;
    }

    new_table
        .create_view(&tx)
        .map_err(to_retriable_data_write_error)?;

    tx.commit()
        .context(super::UnableToCommitTransactionSnafu)
        .map_err(to_retriable_data_write_error)?;

    tracing::debug!(
        "Load for table {table_name} complete, applying constraints and indexes.",
        table_name = new_table.table_name()
    );

    let tx = duckdb_conn
        .conn
        .transaction()
        .context(super::UnableToBeginTransactionSnafu)
        .map_err(to_datafusion_error)?;

    for (table, _) in existing_tables {
        table
            .delete_table(&tx)
            .map_err(to_retriable_data_write_error)?;
    }

    new_table
        .create_indexes(&tx)
        .map_err(to_retriable_data_write_error)?;

    tx.commit()
        .context(super::UnableToCommitTransactionSnafu)
        .map_err(to_retriable_data_write_error)?;

    Ok(num_rows)
}

#[allow(clippy::doc_markdown)]
fn write_to_table(
    table: &TableManager,
    tx: &Transaction<'_>,
    schema: SchemaRef,
    data_batches: Receiver<RecordBatch>,
    on_conflict: Option<&OnConflict>,
) -> datafusion::common::Result<u64> {
    let stream = FFI_ArrowArrayStream::new(Box::new(RecordBatchReaderFromStream::new(
        data_batches,
        schema,
    )));

    let current_ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .context(super::UnableToGetSystemTimeSnafu)
        .map_err(to_datafusion_error)?
        .as_millis();

    let view_name = format!("__scan_{}_{current_ts}", table.table_name());
    tx.register_arrow_scan_view(&view_name, &stream)
        .context(super::UnableToRegisterArrowScanViewSnafu)
        .map_err(to_datafusion_error)?;

    let view = ViewCreator::from_name(RelationName::new(view_name));
    let rows = view
        .insert_into(table, tx, on_conflict)
        .map_err(to_datafusion_error)?;
    view.drop(tx).map_err(to_datafusion_error)?;

    Ok(rows as u64)
}

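/// Adapts the receiving half of the batch channel into a blocking [`RecordBatchReader`] so the
/// batches can be handed to DuckDB as an [`FFI_ArrowArrayStream`].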
struct RecordBatchReaderFromStream {
    stream: Receiver<RecordBatch>,
    schema: SchemaRef,
}

impl RecordBatchReaderFromStream {
    fn new(stream: Receiver<RecordBatch>, schema: SchemaRef) -> Self {
        Self { stream, schema }
    }
}

impl Iterator for RecordBatchReaderFromStream {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.stream.blocking_recv().map(Ok)
    }
}

impl RecordBatchReader for RecordBatchReaderFromStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod test {
    use arrow::array::{Int64Array, StringArray};
    use datafusion::physical_plan::memory::MemoryStream;

    use super::*;
    use crate::{
        duckdb::creator::tests::{get_basic_table_definition, get_mem_duckdb, init_tracing},
        util::{column_reference::ColumnReference, indexes::IndexType},
    };

    #[tokio::test]
    async fn test_write_to_table_overwrite_without_previous_table() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let duckdb_sink = DuckDBDataSink::new(
            Arc::clone(&pool),
            Arc::clone(&table_definition),
            InsertOp::Overwrite,
            None,
            table_definition.schema(),
        );
        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);

        let batches = vec![RecordBatch::try_new(
            Arc::clone(&table_definition.schema()),
            vec![
                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
            ],
        )
        .expect("should create a record batch")];

        let stream = Box::pin(
            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
        );

        data_sink
            .write_all(stream, &Arc::new(TaskContext::default()))
            .await
            .expect("to write all");

        let mut conn = pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");
        let mut internal_tables = table_definition
            .list_internal_tables(&tx)
            .expect("to list internal tables");
        assert_eq!(internal_tables.len(), 1);

        let table_name = internal_tables.pop().expect("should have a table").0;

        let rows = tx
            .query_row(&format!("SELECT COUNT(1) FROM {table_name}"), [], |row| {
                row.get::<_, i64>(0)
            })
            .expect("to get count");
        assert_eq!(rows, 2);

        let view_rows = tx
            .query_row(
                &format!(
                    "SELECT COUNT(1) FROM {view_name}",
                    view_name = table_definition.name()
                ),
                [],
                |row| row.get::<_, i64>(0),
            )
            .expect("to get count");

        assert_eq!(view_rows, 2);

        tx.rollback().expect("to rollback");
    }

    #[tokio::test]
    async fn test_write_to_table_overwrite_with_previous_base_table() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let cloned_pool = Arc::clone(&pool);
        let mut conn = cloned_pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");

        let overwrite_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(false)
            .expect("to create table");

        overwrite_table
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        tx.execute(
            &format!(
                "INSERT INTO {table_name} VALUES (3, 'c')",
                table_name = overwrite_table.table_name()
            ),
            [],
        )
        .expect("to insert");

        tx.commit().expect("to commit");

        let duckdb_sink = DuckDBDataSink::new(
            Arc::clone(&pool),
            Arc::clone(&table_definition),
            InsertOp::Overwrite,
            None,
            table_definition.schema(),
        );
        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);

        let batches = vec![RecordBatch::try_new(
            Arc::clone(&table_definition.schema()),
            vec![
                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
            ],
        )
        .expect("should create a record batch")];

        let stream = Box::pin(
            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
        );

        data_sink
            .write_all(stream, &Arc::new(TaskContext::default()))
            .await
            .expect("to write all");

        let mut conn = pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");
        let mut internal_tables = table_definition
            .list_internal_tables(&tx)
            .expect("to list internal tables");
        assert_eq!(internal_tables.len(), 1);

        let table_name = internal_tables.pop().expect("should have a table").0;

        let rows = tx
            .query_row(&format!("SELECT COUNT(1) FROM {table_name}"), [], |row| {
                row.get::<_, i64>(0)
            })
            .expect("to get count");
        assert_eq!(rows, 2);

        let table_creator =
            TableManager::from_table_name(Arc::clone(&table_definition), table_name);
        let base_table = table_creator.base_table(&tx).expect("to get base table");

        assert!(base_table.is_none());

        let view_rows = tx
            .query_row(
                &format!(
                    "SELECT COUNT(1) FROM {view_name}",
                    view_name = table_definition.name()
                ),
                [],
                |row| row.get::<_, i64>(0),
            )
            .expect("to get count");

        assert_eq!(view_rows, 2);

        tx.rollback().expect("to rollback");
    }

    #[tokio::test]
    async fn test_write_to_table_overwrite_with_previous_internal_table() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let cloned_pool = Arc::clone(&pool);
        let mut conn = cloned_pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");

        let overwrite_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table");

        overwrite_table
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        tx.execute(
            &format!(
                "INSERT INTO {table_name} VALUES (3, 'c')",
                table_name = overwrite_table.table_name()
            ),
            [],
        )
        .expect("to insert");

        tx.commit().expect("to commit");

        let duckdb_sink = DuckDBDataSink::new(
            Arc::clone(&pool),
            Arc::clone(&table_definition),
            InsertOp::Overwrite,
            None,
            table_definition.schema(),
        );
        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);

        let batches = vec![RecordBatch::try_new(
            Arc::clone(&table_definition.schema()),
            vec![
                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
            ],
        )
        .expect("should create a record batch")];

        let stream = Box::pin(
            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
        );

        data_sink
            .write_all(stream, &Arc::new(TaskContext::default()))
            .await
            .expect("to write all");

        let mut conn = pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");
        let mut internal_tables = table_definition
            .list_internal_tables(&tx)
            .expect("to list internal tables");
        assert_eq!(internal_tables.len(), 1);

        let table_name = internal_tables.pop().expect("should have a table").0;

        let rows = tx
            .query_row(&format!("SELECT COUNT(1) FROM {table_name}"), [], |row| {
                row.get::<_, i64>(0)
            })
            .expect("to get count");
        assert_eq!(rows, 2);

        let view_rows = tx
            .query_row(
                &format!(
                    "SELECT COUNT(1) FROM {view_name}",
                    view_name = table_definition.name()
                ),
                [],
                |row| row.get::<_, i64>(0),
            )
            .expect("to get count");

        assert_eq!(view_rows, 2);

        tx.rollback().expect("to rollback");
    }

    #[tokio::test]
    async fn test_write_to_table_append_with_previous_table() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let cloned_pool = Arc::clone(&pool);
        let mut conn = cloned_pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");

        let table_definition = get_basic_table_definition();

        let append_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(false)
            .expect("to create table");

        append_table
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        tx.execute(
            &format!(
                "INSERT INTO {table_name} VALUES (3, 'c')",
                table_name = append_table.table_name()
            ),
            [],
        )
        .expect("to insert");

        tx.commit().expect("to commit");

        let duckdb_sink = DuckDBDataSink::new(
            Arc::clone(&pool),
            Arc::clone(&table_definition),
            InsertOp::Append,
            None,
            table_definition.schema(),
        );
        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);

        let batches = vec![RecordBatch::try_new(
            Arc::clone(&table_definition.schema()),
            vec![
                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
            ],
        )
        .expect("should create a record batch")];

        let stream = Box::pin(
            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
        );

        data_sink
            .write_all(stream, &Arc::new(TaskContext::default()))
            .await
            .expect("to write all");

        let tx = duckdb.conn.transaction().expect("to begin transaction");

        let internal_tables = table_definition
            .list_internal_tables(&tx)
            .expect("to list internal tables");
        assert_eq!(internal_tables.len(), 0);

        let base_table = append_table
            .base_table(&tx)
            .expect("to get base table")
            .expect("should have a base table");

        let rows = tx
            .query_row(
                &format!(
                    "SELECT COUNT(1) FROM {table_name}",
                    table_name = base_table.table_name()
                ),
                [],
                |row| row.get::<_, i64>(0),
            )
            .expect("to get count");
        assert_eq!(rows, 3);

        tx.rollback().expect("to rollback");
    }

    #[tokio::test]
    async fn test_write_to_table_append_with_previous_table_needs_indexes() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let cloned_pool = Arc::clone(&pool);
        let mut conn = cloned_pool.connect_sync().expect("to connect");
        let duckdb = DuckDB::duckdb_conn(&mut conn).expect("to get duckdb conn");
        let tx = duckdb.conn.transaction().expect("to begin transaction");

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
        ]));

        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("test_table"), Arc::clone(&schema))
                .with_indexes(
                    vec![(
                        ColumnReference::try_from("id").expect("valid column ref"),
                        IndexType::Enabled,
                    )]
                    .into_iter()
                    .collect(),
                ),
        );

        let append_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(false)
            .expect("to create table");

        append_table
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        tx.commit().expect("to commit");

        let duckdb_sink = DuckDBDataSink::new(
            Arc::clone(&pool),
            Arc::clone(&table_definition),
            InsertOp::Append,
            None,
            table_definition.schema(),
        );
        let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);

        let batches = vec![RecordBatch::try_new(
            Arc::clone(&table_definition.schema()),
            vec![
                Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
            ],
        )
        .expect("should create a record batch")];

        let stream = Box::pin(
            MemoryStream::try_new(batches, table_definition.schema(), None).expect("to get stream"),
        );

        data_sink
            .write_all(stream, &Arc::new(TaskContext::default()))
            .await
            .expect("to write all");

        let tx = duckdb.conn.transaction().expect("to begin transaction");

        let internal_tables = table_definition
            .list_internal_tables(&tx)
            .expect("to list internal tables");
        assert_eq!(internal_tables.len(), 0);

        let base_table = append_table
            .base_table(&tx)
            .expect("to get base table")
            .expect("should have a base table");

        let rows = tx
            .query_row(
                &format!(
                    "SELECT COUNT(1) FROM {table_name}",
                    table_name = base_table.table_name()
                ),
                [],
                |row| row.get::<_, i64>(0),
            )
            .expect("to get count");
        assert_eq!(rows, 2);

        let indexes = append_table.current_indexes(&tx).expect("to get indexes");
        assert_eq!(indexes.len(), 1);

        tx.rollback().expect("to rollback");
    }
}