datafusion_table_providers/duckdb/
creator.rs

1use crate::sql::arrow_sql_gen::statement::IndexBuilder;
2use crate::sql::db_connection_pool::dbconnection::duckdbconn::DuckDbConnection;
3use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
4use crate::util::on_conflict::OnConflict;
5use arrow::{
6    array::{RecordBatch, RecordBatchIterator, RecordBatchReader},
7    datatypes::SchemaRef,
8    ffi_stream::FFI_ArrowArrayStream,
9};
10use datafusion::common::utils::quote_identifier;
11use datafusion::common::Constraints;
12use datafusion::sql::TableReference;
13use duckdb::Transaction;
14use itertools::Itertools;
15use snafu::prelude::*;
16use std::collections::HashSet;
17use std::fmt::Display;
18use std::sync::Arc;
19
20use super::DuckDB;
21use crate::util::{
22    column_reference::ColumnReference, constraints::get_primary_keys_from_constraints,
23    indexes::IndexType,
24};
25
/// A newtype for a relation name, to better control the inputs for the `TableDefinition`, `TableCreator`, and `ViewCreator`.
///
/// Wraps the rendered relation name as a plain string (see the `From<TableReference>` impl).
#[derive(Debug, Clone, PartialEq)]
pub struct RelationName(String);
29
30impl Display for RelationName {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        write!(f, "{}", self.0)
33    }
34}
35
36impl RelationName {
37    #[must_use]
38    pub fn new(name: impl Into<String>) -> Self {
39        Self(name.into())
40    }
41}
42
43impl From<TableReference> for RelationName {
44    fn from(table_ref: TableReference) -> Self {
45        RelationName(table_ref.to_string())
46    }
47}
48
/// A table definition, which includes the table name, schema, constraints, and indexes.
/// This is used to store the definition of a table for a dataset, and can be re-used to create one or more tables (like internal data tables).
#[derive(Debug, Clone, PartialEq)]
pub struct TableDefinition {
    // Logical (definition) name; internal data tables derive their names from it.
    name: RelationName,
    // Arrow schema used when creating the physical DuckDB table.
    schema: SchemaRef,
    // Optional constraints (e.g. primary keys) applied at creation time.
    constraints: Option<Constraints>,
    // Column references paired with the index type to create on them.
    indexes: Vec<(ColumnReference, IndexType)>,
}
58
59impl TableDefinition {
60    #[must_use]
61    pub(crate) fn new(name: RelationName, schema: SchemaRef) -> Self {
62        Self {
63            name,
64            schema,
65            constraints: None,
66            indexes: Vec::new(),
67        }
68    }
69
70    #[must_use]
71    pub(crate) fn with_constraints(mut self, constraints: Constraints) -> Self {
72        self.constraints = Some(constraints);
73        self
74    }
75
76    #[must_use]
77    pub(crate) fn with_indexes(mut self, indexes: Vec<(ColumnReference, IndexType)>) -> Self {
78        self.indexes = indexes;
79        self
80    }
81
82    #[must_use]
83    pub fn name(&self) -> &RelationName {
84        &self.name
85    }
86
87    #[cfg(test)]
88    pub(crate) fn schema(&self) -> SchemaRef {
89        Arc::clone(&self.schema)
90    }
91
92    /// For an internal table, generate a unique name based on the table definition name and the current system time.
93    pub(crate) fn generate_internal_name(&self) -> super::Result<RelationName> {
94        let unix_ms = std::time::SystemTime::now()
95            .duration_since(std::time::UNIX_EPOCH)
96            .context(super::UnableToGetSystemTimeSnafu)?
97            .as_millis();
98        Ok(RelationName(format!(
99            "__data_{table_name}_{unix_ms}",
100            table_name = self.name,
101        )))
102    }
103
104    pub(crate) fn constraints(&self) -> Option<&Constraints> {
105        self.constraints.as_ref()
106    }
107
108    /// Returns true if this table definition has a base table matching the exact `RelationName` of the definition
109    ///
110    /// # Errors
111    ///
112    /// If the transaction fails to query for whether the table exists.
113    pub fn has_table(&self, tx: &Transaction<'_>) -> super::Result<bool> {
114        let mut stmt = tx
115            .prepare("SELECT 1 FROM duckdb_tables() WHERE table_name = ?")
116            .context(super::UnableToQueryDataSnafu)?;
117        let mut rows = stmt
118            .query([self.name.to_string()])
119            .context(super::UnableToQueryDataSnafu)?;
120
121        Ok(rows
122            .next()
123            .context(super::UnableToQueryDataSnafu)?
124            .is_some())
125    }
126
127    /// List all internal tables related to this table definition.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the internal tables cannot be listed.
132    pub fn list_internal_tables(
133        &self,
134        tx: &Transaction<'_>,
135    ) -> super::Result<Vec<(RelationName, u64)>> {
136        // list all related internal tables, based on the table definition name
137        let sql = format!(
138            "select table_name from duckdb_tables() where table_name LIKE '__data_{table_name}%'",
139            table_name = self.name
140        );
141        let mut stmt = tx.prepare(&sql).context(super::UnableToQueryDataSnafu)?;
142        let mut rows = stmt.query([]).context(super::UnableToQueryDataSnafu)?;
143
144        let mut table_names = Vec::new();
145        while let Some(row) = rows.next().context(super::UnableToQueryDataSnafu)? {
146            let table_name = row
147                .get::<usize, String>(0)
148                .context(super::UnableToQueryDataSnafu)?;
149            // __data_{table_name}% could be a subset of another table name, so we need to check if the table name starts with the table definition name
150            let inner_name = table_name.replace("__data_", "");
151            let mut parts = inner_name.split('_');
152            let Some(timestamp) = parts.next_back() else {
153                continue; // skip invalid table names
154            };
155
156            let inner_name = parts.join("_");
157            if inner_name != self.name.to_string() {
158                continue;
159            }
160
161            let timestamp = timestamp
162                .parse::<u64>()
163                .context(super::UnableToParseSystemTimeSnafu)?;
164
165            table_names.push((table_name, timestamp));
166        }
167
168        table_names.sort_by(|a, b| a.1.cmp(&b.1));
169
170        Ok(table_names
171            .into_iter()
172            .map(|(name, time_created)| (RelationName(name), time_created))
173            .collect())
174    }
175}
176
/// A table creator, which is used to create, delete, and manage tables based on a `TableDefinition`.
#[derive(Debug, Clone)]
pub(crate) struct TableManager {
    // Shared definition this manager creates/manages physical tables for.
    table_definition: Arc<TableDefinition>,
    // When `Some`, this manager targets a time-suffixed internal table rather than the base table.
    internal_name: Option<RelationName>,
}
183
184impl TableManager {
185    pub(crate) fn new(table_definition: Arc<TableDefinition>) -> Self {
186        Self {
187            table_definition,
188            internal_name: None,
189        }
190    }
191
192    /// Set the internal flag for the table creator.
193    pub(crate) fn with_internal(mut self, is_internal: bool) -> super::Result<Self> {
194        if is_internal {
195            self.internal_name = Some(self.table_definition.generate_internal_name()?);
196        } else {
197            self.internal_name = None;
198        }
199
200        Ok(self)
201    }
202
203    pub(crate) fn definition_name(&self) -> &RelationName {
204        &self.table_definition.name
205    }
206
207    /// Returns the canonical name for this table, which is the internal name if the table is internal, or the table name if it is not.
208    pub(crate) fn table_name(&self) -> &RelationName {
209        self.internal_name
210            .as_ref()
211            .unwrap_or_else(|| &self.table_definition.name)
212    }
213
214    /// Searches if a table by the name specified in the table definition exists in the database.
215    /// Returns None if the table does not exist, or an instance of a `TableCreator` for the base table if it does.
216    #[tracing::instrument(level = "debug", skip_all)]
217    pub(crate) fn base_table(&self, tx: &Transaction<'_>) -> super::Result<Option<Self>> {
218        let mut stmt = tx
219            .prepare("SELECT 1 FROM duckdb_tables() WHERE table_name = ?")
220            .context(super::UnableToQueryDataSnafu)?;
221        let mut rows = stmt
222            .query([self.definition_name().to_string()])
223            .context(super::UnableToQueryDataSnafu)?;
224
225        if rows
226            .next()
227            .context(super::UnableToQueryDataSnafu)?
228            .is_some()
229        {
230            let base_table = self.clone();
231            Ok(Some(base_table.with_internal(false)?))
232        } else {
233            Ok(None)
234        }
235    }
236
237    pub(crate) fn indexes_vec(&self) -> Vec<(Vec<&str>, IndexType)> {
238        self.table_definition
239            .indexes
240            .iter()
241            .map(|(key, ty)| (key.iter().collect(), *ty))
242            .collect()
243    }
244
245    /// Creates the table for this `TableManager`. Does not create indexes - use `TableManager::create_indexes` to apply indexes.
246    #[tracing::instrument(level = "debug", skip_all)]
247    pub(crate) fn create_table(
248        &self,
249        pool: Arc<DuckDbConnectionPool>,
250        tx: &Transaction<'_>,
251    ) -> super::Result<()> {
252        let mut db_conn = pool.connect_sync().context(super::DbConnectionPoolSnafu)?;
253        let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn)?;
254
255        // create the table with the supplied table name, or a generated internal name
256        let mut create_stmt = self.get_table_create_statement(duckdb_conn)?;
257        tracing::debug!("{create_stmt}");
258
259        let primary_keys = if let Some(constraints) = &self.table_definition.constraints {
260            get_primary_keys_from_constraints(constraints, &self.table_definition.schema)
261        } else {
262            Vec::new()
263        };
264
265        if !primary_keys.is_empty() && !create_stmt.contains("PRIMARY KEY") {
266            let primary_key_clause = format!(", PRIMARY KEY ({}));", primary_keys.join(", "));
267            create_stmt = create_stmt.replace(");", &primary_key_clause);
268        }
269
270        tx.execute(&create_stmt, [])
271            .context(super::UnableToCreateDuckDBTableSnafu)?;
272
273        Ok(())
274    }
275
276    /// Drops indexes from the table, then drops the table itself.
277    #[tracing::instrument(level = "debug", skip_all)]
278    pub(crate) fn delete_table(&self, tx: &Transaction<'_>) -> super::Result<()> {
279        // drop indexes first
280        self.drop_indexes(tx)?;
281        self.drop_table(tx)?;
282
283        Ok(())
284    }
285
286    #[tracing::instrument(level = "debug", skip_all)]
287    fn drop_table(&self, tx: &Transaction<'_>) -> super::Result<()> {
288        // drop this table
289        tx.execute(
290            &format!(r#"DROP TABLE IF EXISTS "{}""#, self.table_name()),
291            [],
292        )
293        .context(super::UnableToDropDuckDBTableSnafu)?;
294
295        Ok(())
296    }
297
298    /// Inserts data from this table into the target table.
299    #[tracing::instrument(level = "debug", skip_all)]
300    pub(crate) fn insert_into(
301        &self,
302        table: &TableManager,
303        tx: &Transaction<'_>,
304        on_conflict: Option<&OnConflict>,
305    ) -> super::Result<u64> {
306        // insert from this table, into the target table
307        let mut insert_sql = format!(
308            r#"INSERT INTO "{}" SELECT * FROM "{}""#,
309            table.table_name(),
310            self.table_name()
311        );
312
313        if let Some(on_conflict) = on_conflict {
314            let on_conflict_sql =
315                on_conflict.build_on_conflict_statement(&self.table_definition.schema);
316            insert_sql.push_str(&format!(" {on_conflict_sql}"));
317        }
318        tracing::debug!("{insert_sql}");
319
320        let rows = tx
321            .execute(&insert_sql, [])
322            .context(super::UnableToInsertToDuckDBTableSnafu)?;
323
324        Ok(rows as u64)
325    }
326
327    fn get_index_name(table_name: &RelationName, index: &(Vec<&str>, IndexType)) -> String {
328        let index_builder = IndexBuilder::new(&table_name.to_string(), index.0.clone());
329        index_builder.index_name()
330    }
331
332    #[tracing::instrument(level = "debug", skip_all)]
333    fn create_index(
334        &self,
335        tx: &Transaction<'_>,
336        index: (Vec<&str>, IndexType),
337    ) -> super::Result<()> {
338        let table_name = self.table_name();
339
340        let unique = index.1 == IndexType::Unique;
341        let columns = index.0;
342        let mut index_builder = IndexBuilder::new(&table_name.to_string(), columns);
343        if unique {
344            index_builder = index_builder.unique();
345        }
346        let sql = index_builder.build_postgres();
347        tracing::debug!("Creating index: {sql}");
348
349        tx.execute(&sql, [])
350            .context(super::UnableToCreateIndexOnDuckDBTableSnafu)?;
351
352        Ok(())
353    }
354
355    #[tracing::instrument(level = "debug", skip_all)]
356    pub(crate) fn create_indexes(&self, tx: &Transaction<'_>) -> super::Result<()> {
357        // create indexes on this table
358        for index in self.indexes_vec() {
359            self.create_index(tx, index)?;
360        }
361        Ok(())
362    }
363
364    #[tracing::instrument(level = "debug", skip_all)]
365    fn drop_index(&self, tx: &Transaction<'_>, index: (Vec<&str>, IndexType)) -> super::Result<()> {
366        let table_name = self.table_name();
367        let index_name = TableManager::get_index_name(table_name, &index);
368
369        let sql = format!(r#"DROP INDEX IF EXISTS "{index_name}""#);
370        tracing::debug!("{sql}");
371
372        tx.execute(&sql, [])
373            .context(super::UnableToDropIndexOnDuckDBTableSnafu)?;
374
375        Ok(())
376    }
377
378    pub(crate) fn drop_indexes(&self, tx: &Transaction<'_>) -> super::Result<()> {
379        // drop indexes on this table
380        for index in self.indexes_vec() {
381            self.drop_index(tx, index)?;
382        }
383
384        Ok(())
385    }
386
387    /// DuckDB CREATE TABLE statements aren't supported by sea-query - so we create a temporary table
388    /// from an Arrow schema and ask DuckDB for the CREATE TABLE statement.
389    #[tracing::instrument(level = "debug", skip_all)]
390    fn get_table_create_statement(
391        &self,
392        duckdb_conn: &mut DuckDbConnection,
393    ) -> super::Result<String> {
394        let tx = duckdb_conn
395            .conn
396            .transaction()
397            .context(super::UnableToBeginTransactionSnafu)?;
398        let table_name = self.table_name();
399        let record_batch_reader =
400            create_empty_record_batch_reader(Arc::clone(&self.table_definition.schema));
401        let stream = FFI_ArrowArrayStream::new(Box::new(record_batch_reader));
402
403        let current_ts = std::time::SystemTime::now()
404            .duration_since(std::time::UNIX_EPOCH)
405            .context(super::UnableToGetSystemTimeSnafu)?
406            .as_millis();
407
408        let view_name = format!("__scan_{}_{current_ts}", table_name);
409        tx.register_arrow_scan_view(&view_name, &stream)
410            .context(super::UnableToRegisterArrowScanViewForTableCreationSnafu)?;
411
412        let sql = format!(
413            r#"CREATE TABLE IF NOT EXISTS "{table_name}" AS SELECT * FROM "{view_name}""#,
414        );
415        tracing::debug!("{sql}");
416
417        tx.execute(&sql, [])
418            .context(super::UnableToCreateDuckDBTableSnafu)?;
419
420        let create_stmt = tx
421            .query_row(
422                &format!("select sql from duckdb_tables() where table_name = '{table_name}'",),
423                [],
424                |r| r.get::<usize, String>(0),
425            )
426            .context(super::UnableToQueryDataSnafu)?;
427
428        // DuckDB doesn't add IF NOT EXISTS to CREATE TABLE statements, so we add it here.
429        let create_stmt = create_stmt.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS");
430
431        tx.rollback()
432            .context(super::UnableToRollbackTransactionSnafu)?;
433
434        Ok(create_stmt)
435    }
436
437    /// List all internal tables related to this table manager's table definition.
438    /// Excludes itself from the list of tables, if created.
439    #[tracing::instrument(level = "debug", skip_all)]
440    pub(crate) fn list_other_internal_tables(
441        &self,
442        tx: &Transaction<'_>,
443    ) -> super::Result<Vec<(Self, u64)>> {
444        let tables = self.table_definition.list_internal_tables(tx)?;
445
446        Ok(tables
447            .into_iter()
448            .filter_map(|(name, time_created)| {
449                if let Some(internal_name) = &self.internal_name {
450                    if name == *internal_name {
451                        return None;
452                    }
453                }
454
455                let internal_table = TableManager {
456                    table_definition: Arc::clone(&self.table_definition),
457                    internal_name: Some(name),
458                };
459                Some((internal_table, time_created))
460            })
461            .collect())
462    }
463
464    /// If this table is an internal table, creates a view with the table definition name targeting this table.
465    #[tracing::instrument(level = "debug", skip_all)]
466    pub(crate) fn create_view(&self, tx: &Transaction<'_>) -> super::Result<()> {
467        if self.internal_name.is_none() {
468            return Ok(());
469        }
470
471        tx.execute(
472            &format!(
473                "CREATE OR REPLACE VIEW {base_table} AS SELECT * FROM {internal_table}",
474                base_table = quote_identifier(&self.definition_name().to_string()),
475                internal_table = quote_identifier(&self.table_name().to_string())
476            ),
477            [],
478        )
479        .context(super::UnableToCreateDuckDBTableSnafu)?;
480
481        Ok(())
482    }
483
484    /// Returns the current primary keys in database for this table.
485    #[tracing::instrument(level = "debug", skip_all)]
486    pub(crate) fn current_primary_keys(
487        &self,
488        tx: &Transaction<'_>,
489    ) -> super::Result<HashSet<String>> {
490        // DuckDB provides convenient queryable 'pragma_table_info' table function
491        // Complex table name with schema as part of the name must be quoted as
492        // '"<name>"', otherwise it will be parsed to schema and table name
493        let sql = format!(
494            "SELECT name FROM pragma_table_info('{table_name}') WHERE pk = true",
495            table_name = quote_identifier(&self.table_name().to_string())
496        );
497        tracing::debug!("{sql}");
498
499        let mut stmt = tx
500            .prepare(&sql)
501            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
502
503        let primary_keys_iter = stmt
504            .query_map([], |row| row.get::<usize, String>(0))
505            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
506
507        let mut primary_keys = HashSet::new();
508        for pk in primary_keys_iter {
509            primary_keys.insert(pk.context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?);
510        }
511
512        Ok(primary_keys)
513    }
514
515    /// Returns the current indexes in database for this table.
516    #[tracing::instrument(level = "debug", skip_all)]
517    pub(crate) fn current_indexes(&self, tx: &Transaction<'_>) -> super::Result<HashSet<String>> {
518        let sql = format!(
519            "SELECT index_name FROM duckdb_indexes WHERE table_name = '{table_name}'",
520            table_name = &self.table_name().to_string()
521        );
522
523        tracing::debug!("{sql}");
524
525        let mut stmt = tx
526            .prepare(&sql)
527            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
528
529        let indexes_iter = stmt
530            .query_map([], |row| row.get::<usize, String>(0))
531            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
532
533        let mut indexes = HashSet::new();
534        for index in indexes_iter {
535            indexes.insert(index.context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?);
536        }
537
538        Ok(indexes)
539    }
540
541    #[cfg(test)]
542    pub(crate) fn from_table_name(
543        table_definition: Arc<TableDefinition>,
544        table_name: RelationName,
545    ) -> Self {
546        Self {
547            table_definition,
548            internal_name: Some(table_name),
549        }
550    }
551
552    /// Verifies that the primary keys match between this table creator and another table creator.
553    pub(crate) fn verify_primary_keys_match(
554        &self,
555        other_table: &TableManager,
556        tx: &Transaction<'_>,
557    ) -> super::Result<bool> {
558        let expected_pk_keys_str_map =
559            if let Some(constraints) = self.table_definition.constraints.as_ref() {
560                get_primary_keys_from_constraints(constraints, &self.table_definition.schema)
561                    .into_iter()
562                    .collect()
563            } else {
564                HashSet::new()
565            };
566
567        let actual_pk_keys_str_map = other_table.current_primary_keys(tx)?;
568
569        tracing::debug!(
570            "Expected primary keys: {:?}\nActual primary keys: {:?}",
571            expected_pk_keys_str_map,
572            actual_pk_keys_str_map
573        );
574
575        let missing_in_actual = expected_pk_keys_str_map
576            .difference(&actual_pk_keys_str_map)
577            .collect::<Vec<_>>();
578        let extra_in_actual = actual_pk_keys_str_map
579            .difference(&expected_pk_keys_str_map)
580            .collect::<Vec<_>>();
581
582        if !missing_in_actual.is_empty() {
583            tracing::warn!(
584                "Missing primary key(s) detected for the table '{name}': {:?}.",
585                missing_in_actual.iter().join(", "),
586                name = self.table_name()
587            );
588        }
589
590        if !extra_in_actual.is_empty() {
591            tracing::warn!(
592                "The table '{name}' has unexpected primary key(s) not defined in the configuration: {:?}.",
593                extra_in_actual.iter().join(", "),
594                name = self.table_name()
595            );
596        }
597
598        Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
599    }
600
601    /// Verifies that the indexes match between this table creator and another table creator.
602    pub(crate) fn verify_indexes_match(
603        &self,
604        other_table: &TableManager,
605        tx: &Transaction<'_>,
606    ) -> super::Result<bool> {
607        let expected_indexes_str_map: HashSet<String> = self
608            .indexes_vec()
609            .iter()
610            .map(|index| TableManager::get_index_name(self.table_name(), index))
611            .collect();
612
613        let actual_indexes_str_map = other_table.current_indexes(tx)?;
614
615        // replace table names for each index with nothing, as table names could be internal and have unique timestamps
616        let expected_indexes_str_map = expected_indexes_str_map
617            .iter()
618            .map(|index| index.replace(&self.table_name().to_string(), ""))
619            .collect::<HashSet<_>>();
620
621        let actual_indexes_str_map = actual_indexes_str_map
622            .iter()
623            .map(|index| index.replace(&other_table.table_name().to_string(), ""))
624            .collect::<HashSet<_>>();
625
626        tracing::debug!(
627            "Expected indexes: {:?}\nActual indexes: {:?}",
628            expected_indexes_str_map,
629            actual_indexes_str_map
630        );
631
632        let missing_in_actual = expected_indexes_str_map
633            .difference(&actual_indexes_str_map)
634            .collect::<Vec<_>>();
635        let extra_in_actual = actual_indexes_str_map
636            .difference(&expected_indexes_str_map)
637            .collect::<Vec<_>>();
638
639        if !missing_in_actual.is_empty() {
640            tracing::warn!(
641                "Missing index(es) detected for the table '{name}': {:?}.",
642                missing_in_actual.iter().join(", "),
643                name = self.table_name()
644            );
645        }
646        if !extra_in_actual.is_empty() {
647            tracing::warn!(
648                "Unexpected index(es) detected in table '{name}': {}.\nThese indexes are not defined in the configuration.",
649                extra_in_actual.iter().join(", "),
650                name = self.table_name()
651            );
652        }
653
654        Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
655    }
656
657    /// Returns the current schema in database for this table.
658    pub(crate) fn current_schema(&self, tx: &Transaction<'_>) -> super::Result<SchemaRef> {
659        let sql = format!(
660            "SELECT * FROM {table_name} LIMIT 0",
661            table_name = quote_identifier(&self.table_name().to_string())
662        );
663        let mut stmt = tx.prepare(&sql).context(super::UnableToQueryDataSnafu)?;
664        let result: duckdb::Arrow<'_> = stmt
665            .query_arrow([])
666            .context(super::UnableToQueryDataSnafu)?;
667        Ok(result.get_schema())
668    }
669
670    pub(crate) fn get_row_count(&self, tx: &Transaction<'_>) -> super::Result<u64> {
671        let sql = format!(
672            "SELECT COUNT(1) FROM {table_name}",
673            table_name = quote_identifier(&self.table_name().to_string())
674        );
675        let count = tx
676            .query_row(&sql, [], |r| r.get::<usize, u64>(0))
677            .context(super::UnableToQueryDataSnafu)?;
678
679        Ok(count)
680    }
681}
682
683fn create_empty_record_batch_reader(schema: SchemaRef) -> impl RecordBatchReader {
684    let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
685    let batches = vec![empty_batch];
686    RecordBatchIterator::new(batches.into_iter().map(Ok), schema)
687}
688
/// Creates and drops DuckDB views over data tables, identified purely by name.
#[derive(Debug, Clone)]
pub(crate) struct ViewCreator {
    // Name of the view this creator manages.
    name: RelationName,
}
693
694impl ViewCreator {
695    #[must_use]
696    pub(crate) fn from_name(name: RelationName) -> Self {
697        Self { name }
698    }
699
700    pub(crate) fn insert_into(
701        &self,
702        table: &TableManager,
703        tx: &Transaction<'_>,
704        on_conflict: Option<&OnConflict>,
705    ) -> super::Result<u64> {
706        // insert from this view, into the target table
707        let mut insert_sql = format!(
708            r#"INSERT INTO "{table_name}" SELECT * FROM "{view_name}""#,
709            view_name = self.name,
710            table_name = table.table_name()
711        );
712
713        if let Some(on_conflict) = on_conflict {
714            let on_conflict_sql =
715                on_conflict.build_on_conflict_statement(&table.table_definition.schema);
716            insert_sql.push_str(&format!(" {on_conflict_sql}"));
717        }
718        tracing::debug!("{insert_sql}");
719
720        let rows = tx
721            .execute(&insert_sql, [])
722            .context(super::UnableToInsertToDuckDBTableSnafu)?;
723
724        Ok(rows as u64)
725    }
726
727    pub(crate) fn drop(&self, tx: &Transaction<'_>) -> super::Result<()> {
728        // drop this view
729        tx.execute(
730            &format!(
731                r#"DROP VIEW IF EXISTS "{view_name}""#,
732                view_name = self.name
733            ),
734            [],
735        )
736        .context(super::UnableToDropDuckDBTableSnafu)?;
737
738        Ok(())
739    }
740}
741
742#[cfg(test)]
743pub(crate) mod tests {
744    use crate::{
745        duckdb::make_initial_table,
746        sql::db_connection_pool::{
747            dbconnection::duckdbconn::DuckDbConnection, duckdbpool::DuckDbConnectionPool,
748        },
749    };
750    use datafusion::{arrow::array::RecordBatch, datasource::sink::DataSink};
751    use datafusion::{
752        common::SchemaExt,
753        execution::{SendableRecordBatchStream, TaskContext},
754        logical_expr::dml::InsertOp,
755        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
756        physical_plan::memory::MemoryStream,
757    };
758    use tracing::subscriber::DefaultGuard;
759    use tracing_subscriber::EnvFilter;
760
761    use crate::{
762        duckdb::write::DuckDBDataSink,
763        util::constraints::tests::{get_pk_constraints, get_unique_constraints},
764    };
765
766    use super::*;
767
768    pub(crate) fn get_mem_duckdb() -> Arc<DuckDbConnectionPool> {
769        Arc::new(
770            DuckDbConnectionPool::new_memory().expect("to get a memory duckdb connection pool"),
771        )
772    }
773
774    async fn get_logs_batches() -> Vec<RecordBatch> {
775        let parquet_bytes = reqwest::get("https://public-data.spiceai.org/eth.recent_logs.parquet")
776            .await
777            .expect("to get parquet file")
778            .bytes()
779            .await
780            .expect("to get parquet bytes");
781
782        let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(parquet_bytes)
783            .expect("to get parquet reader builder")
784            .build()
785            .expect("to build parquet reader");
786
787        parquet_reader
788            .collect::<Result<Vec<_>, datafusion::arrow::error::ArrowError>>()
789            .expect("to get records")
790    }
791
792    fn get_stream_from_batches(batches: Vec<RecordBatch>) -> SendableRecordBatchStream {
793        let schema = batches[0].schema();
794        Box::pin(MemoryStream::try_new(batches, schema, None).expect("to get stream"))
795    }
796
797    pub(crate) fn init_tracing(default_level: Option<&str>) -> DefaultGuard {
798        let filter = match default_level {
799            Some(level) => EnvFilter::new(level),
800            _ => EnvFilter::new("INFO,datafusion_table_providers=TRACE"),
801        };
802
803        let subscriber = tracing_subscriber::FmtSubscriber::builder()
804            .with_env_filter(filter)
805            .with_ansi(true)
806            .finish();
807        tracing::subscriber::set_default(subscriber)
808    }
809
810    pub(crate) fn get_basic_table_definition() -> Arc<TableDefinition> {
811        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
812            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
813            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
814        ]));
815
816        Arc::new(TableDefinition::new(
817            RelationName::new("test_table"),
818            schema,
819        ))
820    }
821
    /// End-to-end test of the table-creation + data-sink flow against an
    /// in-memory DuckDB, for both `Append` and `Overwrite` insert modes.
    ///
    /// Asserts that:
    /// - all rows reported written by `DuckDBDataSink` are queryable from the
    ///   `"eth.logs"` relation,
    /// - `Overwrite` leaves exactly one internal table behind while `Append`
    ///   targets the base table,
    /// - the unique constraint does not surface as a primary key (expected PK
    ///   set is empty), and
    /// - both configured indexes exist under the expected generated names.
    #[tokio::test]
    async fn test_table_creator() {
        let _guard = init_tracing(None);
        let batches = get_logs_batches().await;

        let schema = batches[0].schema();

        // Definition with a unique constraint on (log_index, transaction_hash),
        // one plain index, and one unique multi-column index.
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("eth.logs"), Arc::clone(&schema))
                .with_constraints(get_unique_constraints(
                    &["log_index", "transaction_hash"],
                    Arc::clone(&schema),
                ))
                .with_indexes(vec![
                    (
                        ColumnReference::try_from("block_number").expect("valid column ref"),
                        IndexType::Enabled,
                    ),
                    (
                        ColumnReference::try_from("(log_index, transaction_hash)")
                            .expect("valid column ref"),
                        IndexType::Unique,
                    ),
                ]),
        );

        for overwrite in &[InsertOp::Append, InsertOp::Overwrite] {
            // Fresh in-memory database per mode so the modes don't interact.
            let pool = get_mem_duckdb();

            make_initial_table(Arc::clone(&table_definition), &pool)
                .expect("to make initial table");

            // Push all log batches through the sink under the current mode.
            let duckdb_sink = DuckDBDataSink::new(
                Arc::clone(&pool),
                Arc::clone(&table_definition),
                *overwrite,
                None,
                table_definition.schema(),
            );
            let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
            let rows_written = data_sink
                .write_all(
                    get_stream_from_batches(batches.clone()),
                    &Arc::new(TaskContext::default()),
                )
                .await
                .expect("to write all");

            // The sink's reported row count must match what is queryable.
            let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
            let conn = pool_conn
                .as_any_mut()
                .downcast_mut::<DuckDbConnection>()
                .expect("to downcast to duckdb connection");
            let num_rows = conn
                .get_underlying_conn_mut()
                .query_row(r#"SELECT COUNT(1) FROM "eth.logs""#, [], |r| {
                    r.get::<usize, u64>(0)
                })
                .expect("to get count");

            assert_eq!(num_rows, rows_written);

            let tx = conn
                .get_underlying_conn_mut()
                .transaction()
                .expect("should begin transaction");
            // Resolve the table that was actually written to: Overwrite writes
            // into a single internal table, Append into the base table.
            let table_creator = if matches!(overwrite, InsertOp::Overwrite) {
                let internal_tables: Vec<(RelationName, u64)> = table_definition
                    .list_internal_tables(&tx)
                    .expect("should list internal tables");
                assert_eq!(internal_tables.len(), 1);

                let internal_table = internal_tables.first().expect("to get internal table");
                let internal_table = internal_table.0.clone();

                TableManager::from_table_name(Arc::clone(&table_definition), internal_table.clone())
            } else {
                let table_creator = TableManager::new(Arc::clone(&table_definition))
                    .with_internal(false)
                    .expect("to create table creator");

                let base_table = table_creator.base_table(&tx).expect("to get base table");
                assert!(base_table.is_some());
                table_creator
            };

            // A unique constraint alone must not produce any primary keys.
            let primary_keys = table_creator
                .current_primary_keys(&tx)
                .expect("should get primary keys");

            assert_eq!(primary_keys, HashSet::<String>::new());

            // Index names follow the `i_<table>_<joined columns>` convention.
            let created_indexes_str_map = table_creator
                .current_indexes(&tx)
                .expect("should get indexes");

            assert_eq!(
                created_indexes_str_map,
                vec![
                    format!(
                        "i_{table_name}_block_number",
                        table_name = table_creator.table_name()
                    ),
                    format!(
                        "i_{table_name}_log_index_transaction_hash",
                        table_name = table_creator.table_name()
                    )
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
                "Indexes must match"
            );

            tx.rollback().expect("should rollback transaction");
        }
    }
938
    /// Same flow as `test_table_creator`, but using a real composite PRIMARY
    /// KEY on (log_index, transaction_hash) instead of a unique constraint.
    ///
    /// Additionally pins the exact generated `CREATE TABLE` DDL (column types
    /// and PK clause) as recorded by DuckDB, and asserts the PK columns are
    /// reported back by `current_primary_keys`.
    #[allow(clippy::too_many_lines)]
    #[tokio::test]
    async fn test_table_creator_primary_key() {
        let _guard = init_tracing(None);
        let batches = get_logs_batches().await;

        let schema = batches[0].schema();
        // Definition with a composite primary key and a single plain index.
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("eth.logs"), Arc::clone(&schema))
                .with_constraints(get_pk_constraints(
                    &["log_index", "transaction_hash"],
                    Arc::clone(&schema),
                ))
                .with_indexes(
                    vec![(
                        ColumnReference::try_from("block_number").expect("valid column ref"),
                        IndexType::Enabled,
                    )]
                    .into_iter()
                    .collect(),
                ),
        );

        for overwrite in &[InsertOp::Append, InsertOp::Overwrite] {
            // Fresh in-memory database per mode so the modes don't interact.
            let pool = get_mem_duckdb();

            make_initial_table(Arc::clone(&table_definition), &pool)
                .expect("to make initial table");

            // Push all log batches through the sink under the current mode.
            let duckdb_sink = DuckDBDataSink::new(
                Arc::clone(&pool),
                Arc::clone(&table_definition),
                *overwrite,
                None,
                table_definition.schema(),
            );
            let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
            let rows_written = data_sink
                .write_all(
                    get_stream_from_batches(batches.clone()),
                    &Arc::new(TaskContext::default()),
                )
                .await
                .expect("to write all");

            // The sink's reported row count must match what is queryable.
            let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
            let conn = pool_conn
                .as_any_mut()
                .downcast_mut::<DuckDbConnection>()
                .expect("to downcast to duckdb connection");
            let num_rows = conn
                .get_underlying_conn_mut()
                .query_row(r#"SELECT COUNT(1) FROM "eth.logs""#, [], |r| {
                    r.get::<usize, u64>(0)
                })
                .expect("to get count");

            assert_eq!(num_rows, rows_written);

            let tx = conn
                .get_underlying_conn_mut()
                .transaction()
                .expect("should begin transaction");

            // Resolve the table that was actually written to: Overwrite writes
            // into a single internal table, Append into the base table.
            let table_creator = if matches!(overwrite, InsertOp::Overwrite) {
                let internal_tables: Vec<(RelationName, u64)> = table_definition
                    .list_internal_tables(&tx)
                    .expect("should list internal tables");
                assert_eq!(internal_tables.len(), 1);

                let internal_table = internal_tables.first().expect("to get internal table");
                let internal_table = internal_table.0.clone();

                TableManager::from_table_name(Arc::clone(&table_definition), internal_table.clone())
            } else {
                let table_creator = TableManager::new(Arc::clone(&table_definition))
                    .with_internal(false)
                    .expect("to create table creator");

                let base_table = table_creator.base_table(&tx).expect("to get base table");
                assert!(base_table.is_some());
                table_creator
            };

            // Pin the exact DDL DuckDB recorded, including the PK clause and
            // the quoting of the reserved-ish column name "data".
            let create_stmt = tx
                .query_row(
                    "select sql from duckdb_tables() where table_name = ?",
                    [table_creator.table_name().to_string()],
                    |r| r.get::<usize, String>(0),
                )
                .expect("to get create table statement");

            assert_eq!(
                create_stmt,
                format!(
                    r#"CREATE TABLE "{table_name}"(log_index BIGINT, transaction_hash VARCHAR, transaction_index BIGINT, address VARCHAR, "data" VARCHAR, topics VARCHAR[], block_timestamp BIGINT, block_hash VARCHAR, block_number BIGINT, PRIMARY KEY(log_index, transaction_hash));"#,
                    table_name = table_creator.table_name(),
                )
            );

            // Both PK columns must be reported back.
            let primary_keys = table_creator
                .current_primary_keys(&tx)
                .expect("should get primary keys");

            assert_eq!(
                primary_keys,
                vec!["log_index".to_string(), "transaction_hash".to_string()]
                    .into_iter()
                    .collect::<HashSet<_>>()
            );

            // Only the single configured index should exist.
            let created_indexes_str_map = table_creator
                .current_indexes(&tx)
                .expect("should get indexes");

            assert_eq!(
                created_indexes_str_map,
                vec![format!(
                    "i_{table_name}_block_number",
                    table_name = table_creator.table_name()
                )]
                .into_iter()
                .collect::<HashSet<_>>(),
                "Indexes must match"
            );

            tx.rollback().expect("should rollback transaction");
        }
    }
1068
1069    #[tokio::test]
1070    async fn test_list_related_tables_from_definition() {
1071        let _guard = init_tracing(None);
1072        let pool = get_mem_duckdb();
1073
1074        let table_definition = get_basic_table_definition();
1075
1076        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
1077        let conn = pool_conn
1078            .as_any_mut()
1079            .downcast_mut::<DuckDbConnection>()
1080            .expect("to downcast to duckdb connection");
1081        let tx = conn
1082            .get_underlying_conn_mut()
1083            .transaction()
1084            .expect("should begin transaction");
1085
1086        // make 3 internal tables
1087        for _ in 0..3 {
1088            TableManager::new(Arc::clone(&table_definition))
1089                .with_internal(true)
1090                .expect("to create table creator")
1091                .create_table(Arc::clone(&pool), &tx)
1092                .expect("to create table");
1093        }
1094
1095        // using the table definition, list the names of the internal tables
1096        let table_name = table_definition.name.clone();
1097        let internal_tables = table_definition
1098            .list_internal_tables(&tx)
1099            .expect("should list internal tables");
1100
1101        assert_eq!(internal_tables.len(), 3);
1102
1103        // validate the first table is the oldest, and the last table is the newest
1104        let first_table = internal_tables.first().expect("to get first table");
1105        let last_table = internal_tables.last().expect("to get last table");
1106        assert!(first_table.1 < last_table.1);
1107
1108        // validate none of the internal tables are the same, they are not equal to the base table
1109        let mut seen_tables = vec![];
1110        for (internal_table, _) in internal_tables {
1111            let internal_name = internal_table.clone();
1112            assert_ne!(&internal_name, &table_name);
1113            assert!(!seen_tables.contains(&internal_name));
1114            seen_tables.push(internal_name);
1115        }
1116
1117        tx.rollback().expect("should rollback transaction");
1118    }
1119
    /// Verifies `TableManager::list_other_internal_tables`: a manager sees the
    /// other internal tables for its definition — never itself and never the
    /// base definition name — and deleted tables disappear from the listing.
    #[tokio::test]
    async fn test_list_related_tables_from_creator() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 3 internal tables
        for _ in 0..3 {
            TableManager::new(Arc::clone(&table_definition))
                .with_internal(true)
                .expect("to create table creator")
                .create_table(Arc::clone(&pool), &tx)
                .expect("to create table");
        }

        // instantiate a new table creator, make it, and list the internal tables
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // Only the 3 earlier tables are expected — the lister excludes itself.
        let internal_tables = table_creator
            .list_other_internal_tables(&tx)
            .expect("should list internal tables");

        assert_eq!(internal_tables.len(), 3);

        // validate none of the internal tables are the same, they are not equal to the base table, and they are not equal to the internal table that listed them
        let mut seen_tables = vec![];
        for (internal_table, _) in &internal_tables {
            let table_name = internal_table.table_name().clone();
            assert_ne!(&table_name, table_creator.definition_name());
            assert_ne!(Some(&table_name), table_creator.internal_name.as_ref());
            assert!(!seen_tables.contains(&table_name));
            seen_tables.push(table_name);
        }

        // drop all of the listed internal tables (the lister's own table remains,
        // since it was never part of the listing)
        for (internal_table, _) in internal_tables {
            internal_table.delete_table(&tx).expect("to delete table");
        }

        // list the internal tables again — everything listed before is now gone
        let internal_tables = table_creator
            .list_other_internal_tables(&tx)
            .expect("should list internal tables");

        assert_eq!(internal_tables.len(), 0);

        tx.rollback().expect("should rollback transaction");
    }
1185
1186    #[tokio::test]
1187    async fn test_create_view() {
1188        let _guard = init_tracing(None);
1189        let pool = get_mem_duckdb();
1190
1191        let table_definition = get_basic_table_definition();
1192
1193        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
1194        let conn = pool_conn
1195            .as_any_mut()
1196            .downcast_mut::<DuckDbConnection>()
1197            .expect("to downcast to duckdb connection");
1198        let tx = conn
1199            .get_underlying_conn_mut()
1200            .transaction()
1201            .expect("should begin transaction");
1202
1203        // make a table
1204        let table_creator = TableManager::new(Arc::clone(&table_definition))
1205            .with_internal(true)
1206            .expect("to create table creator");
1207
1208        table_creator
1209            .create_table(Arc::clone(&pool), &tx)
1210            .expect("to create table");
1211
1212        // create a view from the internal table
1213        table_creator.create_view(&tx).expect("to create view");
1214
1215        // check if the view exists
1216        let view_exists = tx
1217            .query_row(
1218                "from duckdb_views() select 1 where view_name = ? and not internal",
1219                [table_creator.definition_name().to_string()],
1220                |r| r.get::<usize, i32>(0),
1221            )
1222            .expect("to get view");
1223
1224        assert_eq!(view_exists, 1);
1225
1226        tx.rollback().expect("should rollback transaction");
1227    }
1228
1229    #[tokio::test]
1230    async fn test_insert_into_tables() {
1231        let _guard = init_tracing(None);
1232        let pool = get_mem_duckdb();
1233
1234        let table_definition = get_basic_table_definition();
1235
1236        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
1237        let conn = pool_conn
1238            .as_any_mut()
1239            .downcast_mut::<DuckDbConnection>()
1240            .expect("to downcast to duckdb connection");
1241        let tx = conn
1242            .get_underlying_conn_mut()
1243            .transaction()
1244            .expect("should begin transaction");
1245
1246        // make a base table
1247        let base_table = TableManager::new(Arc::clone(&table_definition))
1248            .with_internal(false)
1249            .expect("to create table creator");
1250
1251        base_table
1252            .create_table(Arc::clone(&pool), &tx)
1253            .expect("to create table");
1254
1255        // make an internal table
1256        let internal_table = TableManager::new(Arc::clone(&table_definition))
1257            .with_internal(true)
1258            .expect("to create table creator");
1259
1260        internal_table
1261            .create_table(Arc::clone(&pool), &tx)
1262            .expect("to create table");
1263
1264        // insert some rows directly into the base table
1265        let insert_stmt = format!(
1266            r#"INSERT INTO "{base_table}" VALUES (1, 'test'), (2, 'test2')"#,
1267            base_table = base_table.table_name()
1268        );
1269
1270        tx.execute(&insert_stmt, [])
1271            .expect("to insert into base table");
1272
1273        // insert from the base table into the internal table
1274        base_table
1275            .insert_into(&internal_table, &tx, None)
1276            .expect("to insert into internal table");
1277
1278        // check if the rows were inserted
1279        let rows = tx
1280            .query_row(
1281                &format!(
1282                    r#"SELECT COUNT(1) FROM "{internal_table}""#,
1283                    internal_table = internal_table.table_name()
1284                ),
1285                [],
1286                |r| r.get::<usize, u64>(0),
1287            )
1288            .expect("to get count");
1289
1290        assert_eq!(rows, 2);
1291
1292        tx.rollback().expect("should rollback transaction");
1293    }
1294
    /// `TableManager::base_table` resolves the definition's base (non-internal)
    /// table regardless of whether the caller is itself a base manager or an
    /// internal one.
    #[tokio::test]
    async fn test_lists_base_table_from_definition() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make a base table
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(false)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // list the base table from another base table
        // NOTE(review): despite the name, this manager is non-internal
        // (`with_internal(false)`) — it is a second *base* manager.
        let internal_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(false)
            .expect("to create table creator");

        let base_table = internal_table.base_table(&tx).expect("to get base table");

        assert!(base_table.is_some());
        assert_eq!(
            base_table.expect("to be some").table_definition,
            table_creator.table_definition
        );

        // list the base table from an internal table
        let internal_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        let base_table = internal_table.base_table(&tx).expect("to get base table");

        assert!(base_table.is_some());
        assert_eq!(
            base_table.expect("to be some").table_definition,
            table_creator.table_definition
        );

        tx.rollback().expect("should rollback transaction");
    }
1349
    /// `TableManager::verify_primary_keys_match` returns true for two tables
    /// built from the same PK-bearing definition, false when compared against
    /// a table from a PK-less definition, and true for two PK-less tables.
    #[tokio::test]
    async fn test_primary_keys_match() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
        ]));

        // Definition with a primary key on `id`.
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("test_table"), Arc::clone(&schema))
                .with_constraints(get_pk_constraints(&["id"], Arc::clone(&schema))),
        );

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 2 internal tables which should have the same indexes
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let table_creator2 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator2
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // Same definition on both sides -> PKs must match.
        let primary_keys_match = table_creator
            .verify_primary_keys_match(&table_creator2, &tx)
            .expect("to verify primary keys match");

        assert!(primary_keys_match);

        // make another table that does not match
        // (get_basic_table_definition has no constraints, hence no PK).
        let table_definition = get_basic_table_definition();

        let table_creator3 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator3
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // PK on one side only -> mismatch.
        let primary_keys_match = table_creator
            .verify_primary_keys_match(&table_creator3, &tx)
            .expect("to verify primary keys match");

        assert!(!primary_keys_match);

        // validate that 2 empty tables return true
        let table_creator4 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator4
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // No PKs on either side -> vacuously matching.
        let primary_keys_match = table_creator3
            .verify_primary_keys_match(&table_creator4, &tx)
            .expect("to verify primary keys match");

        assert!(primary_keys_match);

        tx.rollback().expect("should rollback transaction");
    }
1432
    /// `TableManager::verify_indexes_match` returns true for two tables built
    /// from the same indexed definition, false when compared against a table
    /// from an index-less definition, and true for two index-less tables.
    #[tokio::test]
    async fn test_indexes_match() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
        ]));

        // Definition with a single plain index on `id`.
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("test_table"), Arc::clone(&schema))
                .with_indexes(
                    vec![(
                        ColumnReference::try_from("id").expect("valid column ref"),
                        IndexType::Enabled,
                    )]
                    .into_iter()
                    .collect(),
                ),
        );

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 2 internal tables which should have the same indexes
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator
            .create_indexes(&tx)
            .expect("to create indexes");

        let table_creator2 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator2
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator2
            .create_indexes(&tx)
            .expect("to create indexes");

        // Same definition on both sides -> indexes must match.
        let indexes_match = table_creator
            .verify_indexes_match(&table_creator2, &tx)
            .expect("to verify indexes match");

        assert!(indexes_match);

        // make another table that does not match
        // (get_basic_table_definition declares no indexes).
        let table_definition = get_basic_table_definition();

        let table_creator3 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator3
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator3
            .create_indexes(&tx)
            .expect("to create indexes");

        // Index on one side only -> mismatch.
        let indexes_match = table_creator
            .verify_indexes_match(&table_creator3, &tx)
            .expect("to verify indexes match");

        assert!(!indexes_match);

        // validate that 2 empty tables return true
        let table_creator4 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator4
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator4
            .create_indexes(&tx)
            .expect("to create indexes");

        // No indexes on either side -> vacuously matching.
        let indexes_match = table_creator3
            .verify_indexes_match(&table_creator4, &tx)
            .expect("to verify indexes match");

        assert!(indexes_match);

        tx.rollback().expect("should rollback transaction");
    }
1538
    /// `TableManager::current_schema` reads the live table schema back from
    /// DuckDB; it must be name/type-equivalent to the definition's Arrow
    /// schema, and consistent across tables created from the same definition.
    #[tokio::test]
    async fn test_current_schema() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // Round-trip: schema read from DuckDB matches the Arrow definition.
        let schema = table_creator
            .current_schema(&tx)
            .expect("to get current schema");

        assert!(schema.equivalent_names_and_types(&table_definition.schema));

        // schemas between different tables are equivalent
        let table_creator2 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator2
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let schema2 = table_creator2
            .current_schema(&tx)
            .expect("to get current schema");

        assert!(schema.equivalent_names_and_types(&schema2));

        tx.rollback().expect("should rollback transaction");
    }
1587
1588    #[tokio::test]
1589    async fn test_internal_tables_exclude_subsets_of_other_tables() {
1590        let _guard = init_tracing(None);
1591        let pool = get_mem_duckdb();
1592
1593        let table_definition = get_basic_table_definition();
1594        let other_definition = Arc::new(TableDefinition::new(
1595            RelationName::new("test_table_second"),
1596            Arc::clone(&table_definition.schema),
1597        ));
1598
1599        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
1600        let conn = pool_conn
1601            .as_any_mut()
1602            .downcast_mut::<DuckDbConnection>()
1603            .expect("to downcast to duckdb connection");
1604
1605        let tx = conn
1606            .get_underlying_conn_mut()
1607            .transaction()
1608            .expect("should begin transaction");
1609
1610        // make an internal table for each definition
1611        let table_creator = TableManager::new(Arc::clone(&table_definition))
1612            .with_internal(true)
1613            .expect("to create table creator");
1614
1615        table_creator
1616            .create_table(Arc::clone(&pool), &tx)
1617            .expect("to create table");
1618
1619        let other_table_creator = TableManager::new(Arc::clone(&other_definition))
1620            .with_internal(true)
1621            .expect("to create table creator");
1622
1623        other_table_creator
1624            .create_table(Arc::clone(&pool), &tx)
1625            .expect("to create table");
1626
1627        // each table should not list the other as an internal table
1628        let first_tables = table_definition
1629            .list_internal_tables(&tx)
1630            .expect("should list internal tables");
1631        let second_tables = other_definition
1632            .list_internal_tables(&tx)
1633            .expect("should list internal tables");
1634
1635        assert_eq!(first_tables.len(), 1);
1636        assert_eq!(second_tables.len(), 1);
1637
1638        assert_ne!(
1639            first_tables.first().expect("should have a table").0,
1640            second_tables.first().expect("should have a table").0
1641        );
1642    }
1643}