// datafusion_table_providers/duckdb/creator.rs

1use crate::sql::arrow_sql_gen::statement::IndexBuilder;
2use crate::sql::db_connection_pool::dbconnection::duckdbconn::DuckDbConnection;
3use crate::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool;
4use crate::util::on_conflict::OnConflict;
5use arrow::{
6    array::{RecordBatch, RecordBatchIterator, RecordBatchReader},
7    datatypes::SchemaRef,
8    ffi_stream::FFI_ArrowArrayStream,
9};
10use datafusion::common::utils::quote_identifier;
11use datafusion::common::Constraints;
12use datafusion::sql::TableReference;
13use duckdb::Transaction;
14use itertools::Itertools;
15use snafu::prelude::*;
16use std::collections::HashSet;
17use std::fmt::Display;
18use std::sync::Arc;
19
20use super::DuckDB;
21use crate::util::{
22    column_reference::ColumnReference, constraints::get_primary_keys_from_constraints,
23    indexes::IndexType,
24};
25
/// A newtype for a relation name, to better control the inputs for the `TableDefinition`, `TableCreator`, and `ViewCreator`.
///
/// Stores the raw (unquoted) name string; `Display` renders it verbatim.
#[derive(Debug, Clone, PartialEq)]
pub struct RelationName(String);
29
30impl Display for RelationName {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        write!(f, "{}", self.0)
33    }
34}
35
36impl RelationName {
37    #[must_use]
38    pub fn new(name: impl Into<String>) -> Self {
39        Self(name.into())
40    }
41}
42
43impl From<TableReference> for RelationName {
44    fn from(table_ref: TableReference) -> Self {
45        RelationName(table_ref.to_string())
46    }
47}
48
/// A table definition, which includes the table name, schema, constraints, and indexes.
/// This is used to store the definition of a table for a dataset, and can be re-used to create one or more tables (like internal data tables).
#[derive(Debug, Clone, PartialEq)]
pub struct TableDefinition {
    // Logical name of the table; internal table names are derived from it.
    name: RelationName,
    // Arrow schema the table is created from.
    schema: SchemaRef,
    // Optional constraints (e.g. primary key / unique) applied at creation time.
    constraints: Option<Constraints>,
    // Index specs: referenced column(s) plus whether the index is unique.
    indexes: Vec<(ColumnReference, IndexType)>,
}
58
59impl TableDefinition {
60    #[must_use]
61    pub(crate) fn new(name: RelationName, schema: SchemaRef) -> Self {
62        Self {
63            name,
64            schema,
65            constraints: None,
66            indexes: Vec::new(),
67        }
68    }
69
70    #[must_use]
71    pub(crate) fn with_constraints(mut self, constraints: Constraints) -> Self {
72        self.constraints = Some(constraints);
73        self
74    }
75
76    #[must_use]
77    pub(crate) fn with_indexes(mut self, indexes: Vec<(ColumnReference, IndexType)>) -> Self {
78        self.indexes = indexes;
79        self
80    }
81
82    #[must_use]
83    pub fn name(&self) -> &RelationName {
84        &self.name
85    }
86
87    #[cfg(test)]
88    pub(crate) fn schema(&self) -> SchemaRef {
89        Arc::clone(&self.schema)
90    }
91
92    /// For an internal table, generate a unique name based on the table definition name and the current system time.
93    pub(crate) fn generate_internal_name(&self) -> super::Result<RelationName> {
94        let unix_ms = std::time::SystemTime::now()
95            .duration_since(std::time::UNIX_EPOCH)
96            .context(super::UnableToGetSystemTimeSnafu)?
97            .as_millis();
98        Ok(RelationName(format!(
99            "__data_{table_name}_{unix_ms}",
100            table_name = self.name,
101        )))
102    }
103
104    pub(crate) fn constraints(&self) -> Option<&Constraints> {
105        self.constraints.as_ref()
106    }
107
108    /// Returns true if this table definition has a base table matching the exact `RelationName` of the definition
109    ///
110    /// # Errors
111    ///
112    /// If the transaction fails to query for whether the table exists.
113    pub fn has_table(&self, tx: &Transaction<'_>) -> super::Result<bool> {
114        let mut stmt = tx
115            .prepare("SELECT 1 FROM duckdb_tables() WHERE table_name = ?")
116            .context(super::UnableToQueryDataSnafu)?;
117        let mut rows = stmt
118            .query([self.name.to_string()])
119            .context(super::UnableToQueryDataSnafu)?;
120
121        Ok(rows
122            .next()
123            .context(super::UnableToQueryDataSnafu)?
124            .is_some())
125    }
126
127    /// List all internal tables related to this table definition.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the internal tables cannot be listed.
132    pub fn list_internal_tables(
133        &self,
134        tx: &Transaction<'_>,
135    ) -> super::Result<Vec<(RelationName, u64)>> {
136        // list all related internal tables, based on the table definition name
137        let sql = format!(
138            "select table_name from duckdb_tables() where table_name LIKE '__data_{table_name}%'",
139            table_name = self.name
140        );
141        let mut stmt = tx.prepare(&sql).context(super::UnableToQueryDataSnafu)?;
142        let mut rows = stmt.query([]).context(super::UnableToQueryDataSnafu)?;
143
144        let mut table_names = Vec::new();
145        while let Some(row) = rows.next().context(super::UnableToQueryDataSnafu)? {
146            let table_name = row
147                .get::<usize, String>(0)
148                .context(super::UnableToQueryDataSnafu)?;
149            // __data_{table_name}% could be a subset of another table name, so we need to check if the table name starts with the table definition name
150            let inner_name = table_name.replace("__data_", "");
151            let mut parts = inner_name.split('_');
152            let Some(timestamp) = parts.next_back() else {
153                continue; // skip invalid table names
154            };
155
156            let inner_name = parts.join("_");
157            if inner_name != self.name.to_string() {
158                continue;
159            }
160
161            let timestamp = timestamp
162                .parse::<u64>()
163                .context(super::UnableToParseSystemTimeSnafu)?;
164
165            table_names.push((table_name, timestamp));
166        }
167
168        table_names.sort_by(|a, b| a.1.cmp(&b.1));
169
170        Ok(table_names
171            .into_iter()
172            .map(|(name, time_created)| (RelationName(name), time_created))
173            .collect())
174    }
175}
176
/// A table creator, which is used to create, delete, and manage tables based on a `TableDefinition`.
#[derive(Debug, Clone)]
pub(crate) struct TableManager {
    // Shared definition this manager instantiates tables from.
    table_definition: Arc<TableDefinition>,
    // When `Some`, this manager points at a timestamped internal table
    // rather than the base table named by the definition.
    internal_name: Option<RelationName>,
}
183
184impl TableManager {
185    pub(crate) fn new(table_definition: Arc<TableDefinition>) -> Self {
186        Self {
187            table_definition,
188            internal_name: None,
189        }
190    }
191
192    /// Set the internal flag for the table creator.
193    pub(crate) fn with_internal(mut self, is_internal: bool) -> super::Result<Self> {
194        if is_internal {
195            self.internal_name = Some(self.table_definition.generate_internal_name()?);
196        } else {
197            self.internal_name = None;
198        }
199
200        Ok(self)
201    }
202
    /// Returns the name from the underlying `TableDefinition`, regardless of
    /// whether this manager currently points at an internal table.
    pub(crate) fn definition_name(&self) -> &RelationName {
        &self.table_definition.name
    }
206
207    /// Returns the canonical name for this table, which is the internal name if the table is internal, or the table name if it is not.
208    pub(crate) fn table_name(&self) -> &RelationName {
209        self.internal_name
210            .as_ref()
211            .unwrap_or_else(|| &self.table_definition.name)
212    }
213
214    /// Searches if a table by the name specified in the table definition exists in the database.
215    /// Returns None if the table does not exist, or an instance of a `TableCreator` for the base table if it does.
216    #[tracing::instrument(level = "debug", skip_all)]
217    pub(crate) fn base_table(&self, tx: &Transaction<'_>) -> super::Result<Option<Self>> {
218        let mut stmt = tx
219            .prepare("SELECT 1 FROM duckdb_tables() WHERE table_name = ?")
220            .context(super::UnableToQueryDataSnafu)?;
221        let mut rows = stmt
222            .query([self.definition_name().to_string()])
223            .context(super::UnableToQueryDataSnafu)?;
224
225        if rows
226            .next()
227            .context(super::UnableToQueryDataSnafu)?
228            .is_some()
229        {
230            let base_table = self.clone();
231            Ok(Some(base_table.with_internal(false)?))
232        } else {
233            Ok(None)
234        }
235    }
236
237    pub(crate) fn indexes_vec(&self) -> Vec<(Vec<&str>, IndexType)> {
238        self.table_definition
239            .indexes
240            .iter()
241            .map(|(key, ty)| (key.iter().collect(), *ty))
242            .collect()
243    }
244
245    /// Creates the table for this `TableManager`. Does not create indexes - use `TableManager::create_indexes` to apply indexes.
246    #[tracing::instrument(level = "debug", skip_all)]
247    pub(crate) fn create_table(
248        &self,
249        pool: Arc<DuckDbConnectionPool>,
250        tx: &Transaction<'_>,
251    ) -> super::Result<()> {
252        let mut db_conn = pool.connect_sync().context(super::DbConnectionPoolSnafu)?;
253        let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn)?;
254
255        // create the table with the supplied table name, or a generated internal name
256        let mut create_stmt = self.get_table_create_statement(duckdb_conn)?;
257        tracing::debug!("{create_stmt}");
258
259        let primary_keys = if let Some(constraints) = &self.table_definition.constraints {
260            get_primary_keys_from_constraints(constraints, &self.table_definition.schema)
261        } else {
262            Vec::new()
263        };
264
265        if !primary_keys.is_empty() && !create_stmt.contains("PRIMARY KEY") {
266            let primary_key_clause = format!(", PRIMARY KEY ({}));", primary_keys.join(", "));
267            create_stmt = create_stmt.replace(");", &primary_key_clause);
268        }
269
270        tx.execute(&create_stmt, [])
271            .context(super::UnableToCreateDuckDBTableSnafu)?;
272
273        Ok(())
274    }
275
276    /// Drops indexes from the table, then drops the table itself.
277    #[tracing::instrument(level = "debug", skip_all)]
278    pub(crate) fn delete_table(&self, tx: &Transaction<'_>) -> super::Result<()> {
279        // drop indexes first
280        self.drop_indexes(tx)?;
281        self.drop_table(tx)?;
282
283        Ok(())
284    }
285
286    #[tracing::instrument(level = "debug", skip_all)]
287    fn drop_table(&self, tx: &Transaction<'_>) -> super::Result<()> {
288        // drop this table
289        tx.execute(
290            &format!(r#"DROP TABLE IF EXISTS "{}""#, self.table_name()),
291            [],
292        )
293        .context(super::UnableToDropDuckDBTableSnafu)?;
294
295        Ok(())
296    }
297
298    /// Inserts data from this table into the target table.
299    #[tracing::instrument(level = "debug", skip_all)]
300    pub(crate) fn insert_into(
301        &self,
302        table: &TableManager,
303        tx: &Transaction<'_>,
304        on_conflict: Option<&OnConflict>,
305    ) -> super::Result<u64> {
306        // insert from this table, into the target table
307        let mut insert_sql = format!(
308            r#"INSERT INTO "{}" SELECT * FROM "{}""#,
309            table.table_name(),
310            self.table_name()
311        );
312
313        if let Some(on_conflict) = on_conflict {
314            let on_conflict_sql =
315                on_conflict.build_on_conflict_statement(&self.table_definition.schema);
316            insert_sql.push_str(&format!(" {on_conflict_sql}"));
317        }
318        tracing::debug!("{insert_sql}");
319
320        let rows = tx
321            .execute(&insert_sql, [])
322            .context(super::UnableToInsertToDuckDBTableSnafu)?;
323
324        Ok(rows as u64)
325    }
326
327    fn get_index_name(table_name: &RelationName, index: &(Vec<&str>, IndexType)) -> String {
328        let index_builder = IndexBuilder::new(&table_name.to_string(), index.0.clone());
329        index_builder.index_name()
330    }
331
332    #[tracing::instrument(level = "debug", skip_all)]
333    fn create_index(
334        &self,
335        tx: &Transaction<'_>,
336        index: (Vec<&str>, IndexType),
337    ) -> super::Result<()> {
338        let table_name = self.table_name();
339
340        let unique = index.1 == IndexType::Unique;
341        let columns = index.0;
342        let mut index_builder = IndexBuilder::new(&table_name.to_string(), columns);
343        if unique {
344            index_builder = index_builder.unique();
345        }
346        let sql = index_builder.build_postgres();
347        tracing::debug!("Creating index: {sql}");
348
349        tx.execute(&sql, [])
350            .context(super::UnableToCreateIndexOnDuckDBTableSnafu)?;
351
352        Ok(())
353    }
354
355    #[tracing::instrument(level = "debug", skip_all)]
356    pub(crate) fn create_indexes(&self, tx: &Transaction<'_>) -> super::Result<()> {
357        // create indexes on this table
358        for index in self.indexes_vec() {
359            self.create_index(tx, index)?;
360        }
361        Ok(())
362    }
363
364    #[tracing::instrument(level = "debug", skip_all)]
365    fn drop_index(&self, tx: &Transaction<'_>, index: (Vec<&str>, IndexType)) -> super::Result<()> {
366        let table_name = self.table_name();
367        let index_name = TableManager::get_index_name(table_name, &index);
368
369        let sql = format!(r#"DROP INDEX IF EXISTS "{index_name}""#);
370        tracing::debug!("{sql}");
371
372        tx.execute(&sql, [])
373            .context(super::UnableToDropIndexOnDuckDBTableSnafu)?;
374
375        Ok(())
376    }
377
378    pub(crate) fn drop_indexes(&self, tx: &Transaction<'_>) -> super::Result<()> {
379        // drop indexes on this table
380        for index in self.indexes_vec() {
381            self.drop_index(tx, index)?;
382        }
383
384        Ok(())
385    }
386
387    /// DuckDB CREATE TABLE statements aren't supported by sea-query - so we create a temporary table
388    /// from an Arrow schema and ask DuckDB for the CREATE TABLE statement.
389    #[tracing::instrument(level = "debug", skip_all)]
390    fn get_table_create_statement(
391        &self,
392        duckdb_conn: &mut DuckDbConnection,
393    ) -> super::Result<String> {
394        let tx = duckdb_conn
395            .conn
396            .transaction()
397            .context(super::UnableToBeginTransactionSnafu)?;
398        let table_name = self.table_name();
399        let record_batch_reader =
400            create_empty_record_batch_reader(Arc::clone(&self.table_definition.schema));
401        let stream = FFI_ArrowArrayStream::new(Box::new(record_batch_reader));
402
403        let current_ts = std::time::SystemTime::now()
404            .duration_since(std::time::UNIX_EPOCH)
405            .context(super::UnableToGetSystemTimeSnafu)?
406            .as_millis();
407
408        let view_name = format!("__scan_{}_{current_ts}", table_name);
409        tx.register_arrow_scan_view(&view_name, &stream)
410            .context(super::UnableToRegisterArrowScanViewForTableCreationSnafu)?;
411
412        let sql =
413            format!(r#"CREATE TABLE IF NOT EXISTS "{table_name}" AS SELECT * FROM "{view_name}""#,);
414        tracing::debug!("{sql}");
415
416        tx.execute(&sql, [])
417            .context(super::UnableToCreateDuckDBTableSnafu)?;
418
419        let create_stmt = tx
420            .query_row(
421                &format!("select sql from duckdb_tables() where table_name = '{table_name}'",),
422                [],
423                |r| r.get::<usize, String>(0),
424            )
425            .context(super::UnableToQueryDataSnafu)?;
426
427        // DuckDB doesn't add IF NOT EXISTS to CREATE TABLE statements, so we add it here.
428        let create_stmt = create_stmt.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS");
429
430        tx.rollback()
431            .context(super::UnableToRollbackTransactionSnafu)?;
432
433        Ok(create_stmt)
434    }
435
436    /// List all internal tables related to this table manager's table definition.
437    /// Excludes itself from the list of tables, if created.
438    #[tracing::instrument(level = "debug", skip_all)]
439    pub(crate) fn list_other_internal_tables(
440        &self,
441        tx: &Transaction<'_>,
442    ) -> super::Result<Vec<(Self, u64)>> {
443        let tables = self.table_definition.list_internal_tables(tx)?;
444
445        Ok(tables
446            .into_iter()
447            .filter_map(|(name, time_created)| {
448                if let Some(internal_name) = &self.internal_name {
449                    if name == *internal_name {
450                        return None;
451                    }
452                }
453
454                let internal_table = TableManager {
455                    table_definition: Arc::clone(&self.table_definition),
456                    internal_name: Some(name),
457                };
458                Some((internal_table, time_created))
459            })
460            .collect())
461    }
462
463    /// If this table is an internal table, creates a view with the table definition name targeting this table.
464    #[tracing::instrument(level = "debug", skip_all)]
465    pub(crate) fn create_view(&self, tx: &Transaction<'_>) -> super::Result<()> {
466        if self.internal_name.is_none() {
467            return Ok(());
468        }
469
470        tx.execute(
471            &format!(
472                "CREATE OR REPLACE VIEW {base_table} AS SELECT * FROM {internal_table}",
473                base_table = quote_identifier(&self.definition_name().to_string()),
474                internal_table = quote_identifier(&self.table_name().to_string())
475            ),
476            [],
477        )
478        .context(super::UnableToCreateDuckDBTableSnafu)?;
479
480        Ok(())
481    }
482
483    /// Returns the current primary keys in database for this table.
484    #[tracing::instrument(level = "debug", skip_all)]
485    pub(crate) fn current_primary_keys(
486        &self,
487        tx: &Transaction<'_>,
488    ) -> super::Result<HashSet<String>> {
489        // DuckDB provides convenient queryable 'pragma_table_info' table function
490        // Complex table name with schema as part of the name must be quoted as
491        // '"<name>"', otherwise it will be parsed to schema and table name
492        let sql = format!(
493            "SELECT name FROM pragma_table_info('{table_name}') WHERE pk = true",
494            table_name = quote_identifier(&self.table_name().to_string())
495        );
496        tracing::debug!("{sql}");
497
498        let mut stmt = tx
499            .prepare(&sql)
500            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
501
502        let primary_keys_iter = stmt
503            .query_map([], |row| row.get::<usize, String>(0))
504            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
505
506        let mut primary_keys = HashSet::new();
507        for pk in primary_keys_iter {
508            primary_keys.insert(pk.context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?);
509        }
510
511        Ok(primary_keys)
512    }
513
514    /// Returns the current indexes in database for this table.
515    #[tracing::instrument(level = "debug", skip_all)]
516    pub(crate) fn current_indexes(&self, tx: &Transaction<'_>) -> super::Result<HashSet<String>> {
517        let sql = format!(
518            "SELECT index_name FROM duckdb_indexes WHERE table_name = '{table_name}'",
519            table_name = &self.table_name().to_string()
520        );
521
522        tracing::debug!("{sql}");
523
524        let mut stmt = tx
525            .prepare(&sql)
526            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
527
528        let indexes_iter = stmt
529            .query_map([], |row| row.get::<usize, String>(0))
530            .context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?;
531
532        let mut indexes = HashSet::new();
533        for index in indexes_iter {
534            indexes.insert(index.context(super::UnableToGetPrimaryKeysOnDuckDBTableSnafu)?);
535        }
536
537        Ok(indexes)
538    }
539
    /// Test-only constructor: wraps an arbitrary existing table name as if it
    /// were this definition's internal table.
    #[cfg(test)]
    pub(crate) fn from_table_name(
        table_definition: Arc<TableDefinition>,
        table_name: RelationName,
    ) -> Self {
        Self {
            table_definition,
            internal_name: Some(table_name),
        }
    }
550
551    /// Verifies that the primary keys match between this table creator and another table creator.
552    pub(crate) fn verify_primary_keys_match(
553        &self,
554        other_table: &TableManager,
555        tx: &Transaction<'_>,
556    ) -> super::Result<bool> {
557        let expected_pk_keys_str_map =
558            if let Some(constraints) = self.table_definition.constraints.as_ref() {
559                get_primary_keys_from_constraints(constraints, &self.table_definition.schema)
560                    .into_iter()
561                    .collect()
562            } else {
563                HashSet::new()
564            };
565
566        let actual_pk_keys_str_map = other_table.current_primary_keys(tx)?;
567
568        tracing::debug!(
569            "Expected primary keys: {:?}\nActual primary keys: {:?}",
570            expected_pk_keys_str_map,
571            actual_pk_keys_str_map
572        );
573
574        let missing_in_actual = expected_pk_keys_str_map
575            .difference(&actual_pk_keys_str_map)
576            .collect::<Vec<_>>();
577        let extra_in_actual = actual_pk_keys_str_map
578            .difference(&expected_pk_keys_str_map)
579            .collect::<Vec<_>>();
580
581        if !missing_in_actual.is_empty() {
582            tracing::warn!(
583                "Missing primary key(s) detected for the table '{name}': {:?}.",
584                missing_in_actual.iter().join(", "),
585                name = self.table_name()
586            );
587        }
588
589        if !extra_in_actual.is_empty() {
590            tracing::warn!(
591                "The table '{name}' has unexpected primary key(s) not defined in the configuration: {:?}.",
592                extra_in_actual.iter().join(", "),
593                name = self.table_name()
594            );
595        }
596
597        Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
598    }
599
600    /// Verifies that the indexes match between this table creator and another table creator.
601    pub(crate) fn verify_indexes_match(
602        &self,
603        other_table: &TableManager,
604        tx: &Transaction<'_>,
605    ) -> super::Result<bool> {
606        let expected_indexes_str_map: HashSet<String> = self
607            .indexes_vec()
608            .iter()
609            .map(|index| TableManager::get_index_name(self.table_name(), index))
610            .collect();
611
612        let actual_indexes_str_map = other_table.current_indexes(tx)?;
613
614        // replace table names for each index with nothing, as table names could be internal and have unique timestamps
615        let expected_indexes_str_map = expected_indexes_str_map
616            .iter()
617            .map(|index| index.replace(&self.table_name().to_string(), ""))
618            .collect::<HashSet<_>>();
619
620        let actual_indexes_str_map = actual_indexes_str_map
621            .iter()
622            .map(|index| index.replace(&other_table.table_name().to_string(), ""))
623            .collect::<HashSet<_>>();
624
625        tracing::debug!(
626            "Expected indexes: {:?}\nActual indexes: {:?}",
627            expected_indexes_str_map,
628            actual_indexes_str_map
629        );
630
631        let missing_in_actual = expected_indexes_str_map
632            .difference(&actual_indexes_str_map)
633            .collect::<Vec<_>>();
634        let extra_in_actual = actual_indexes_str_map
635            .difference(&expected_indexes_str_map)
636            .collect::<Vec<_>>();
637
638        if !missing_in_actual.is_empty() {
639            tracing::warn!(
640                "Missing index(es) detected for the table '{name}': {:?}.",
641                missing_in_actual.iter().join(", "),
642                name = self.table_name()
643            );
644        }
645        if !extra_in_actual.is_empty() {
646            tracing::warn!(
647                "Unexpected index(es) detected in table '{name}': {}.\nThese indexes are not defined in the configuration.",
648                extra_in_actual.iter().join(", "),
649                name = self.table_name()
650            );
651        }
652
653        Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
654    }
655
656    /// Returns the current schema in database for this table.
657    pub(crate) fn current_schema(&self, tx: &Transaction<'_>) -> super::Result<SchemaRef> {
658        let sql = format!(
659            "SELECT * FROM {table_name} LIMIT 0",
660            table_name = quote_identifier(&self.table_name().to_string())
661        );
662        let mut stmt = tx.prepare(&sql).context(super::UnableToQueryDataSnafu)?;
663        let result: duckdb::Arrow<'_> = stmt
664            .query_arrow([])
665            .context(super::UnableToQueryDataSnafu)?;
666        Ok(result.get_schema())
667    }
668
669    pub(crate) fn get_row_count(&self, tx: &Transaction<'_>) -> super::Result<u64> {
670        let sql = format!(
671            "SELECT COUNT(1) FROM {table_name}",
672            table_name = quote_identifier(&self.table_name().to_string())
673        );
674        let count = tx
675            .query_row(&sql, [], |r| r.get::<usize, u64>(0))
676            .context(super::UnableToQueryDataSnafu)?;
677
678        Ok(count)
679    }
680}
681
682fn create_empty_record_batch_reader(schema: SchemaRef) -> impl RecordBatchReader {
683    let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
684    let batches = vec![empty_batch];
685    RecordBatchIterator::new(batches.into_iter().map(Ok), schema)
686}
687
/// Manages a DuckDB view by name: supports inserting its rows into a table
/// and dropping the view.
#[derive(Debug, Clone)]
pub(crate) struct ViewCreator {
    // Name of the view being managed.
    name: RelationName,
}
692
693impl ViewCreator {
694    #[must_use]
695    pub(crate) fn from_name(name: RelationName) -> Self {
696        Self { name }
697    }
698
699    pub(crate) fn insert_into(
700        &self,
701        table: &TableManager,
702        tx: &Transaction<'_>,
703        on_conflict: Option<&OnConflict>,
704    ) -> super::Result<u64> {
705        // insert from this view, into the target table
706        let mut insert_sql = format!(
707            r#"INSERT INTO "{table_name}" SELECT * FROM "{view_name}""#,
708            view_name = self.name,
709            table_name = table.table_name()
710        );
711
712        if let Some(on_conflict) = on_conflict {
713            let on_conflict_sql =
714                on_conflict.build_on_conflict_statement(&table.table_definition.schema);
715            insert_sql.push_str(&format!(" {on_conflict_sql}"));
716        }
717        tracing::debug!("{insert_sql}");
718
719        let rows = tx
720            .execute(&insert_sql, [])
721            .context(super::UnableToInsertToDuckDBTableSnafu)?;
722
723        Ok(rows as u64)
724    }
725
726    pub(crate) fn drop(&self, tx: &Transaction<'_>) -> super::Result<()> {
727        // drop this view
728        tx.execute(
729            &format!(
730                r#"DROP VIEW IF EXISTS "{view_name}""#,
731                view_name = self.name
732            ),
733            [],
734        )
735        .context(super::UnableToDropDuckDBTableSnafu)?;
736
737        Ok(())
738    }
739}
740
741#[cfg(test)]
742pub(crate) mod tests {
743    use crate::{
744        duckdb::make_initial_table,
745        sql::db_connection_pool::{
746            dbconnection::duckdbconn::DuckDbConnection, duckdbpool::DuckDbConnectionPool,
747        },
748    };
749    use datafusion::{arrow::array::RecordBatch, datasource::sink::DataSink};
750    use datafusion::{
751        common::SchemaExt,
752        execution::{SendableRecordBatchStream, TaskContext},
753        logical_expr::dml::InsertOp,
754        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
755        physical_plan::memory::MemoryStream,
756    };
757    use tracing::subscriber::DefaultGuard;
758    use tracing_subscriber::EnvFilter;
759
760    use crate::{
761        duckdb::write::DuckDBDataSink,
762        util::constraints::tests::{get_pk_constraints, get_unique_constraints},
763    };
764
765    use super::*;
766
767    pub(crate) fn get_mem_duckdb() -> Arc<DuckDbConnectionPool> {
768        Arc::new(
769            DuckDbConnectionPool::new_memory().expect("to get a memory duckdb connection pool"),
770        )
771    }
772
    /// Downloads a real-world parquet file (Ethereum logs) and decodes it into
    /// record batches for use as test data.
    ///
    /// NOTE(review): requires network access at test time; failures here are
    /// environmental rather than product bugs.
    async fn get_logs_batches() -> Vec<RecordBatch> {
        let parquet_bytes = reqwest::get("https://public-data.spiceai.org/eth.recent_logs.parquet")
            .await
            .expect("to get parquet file")
            .bytes()
            .await
            .expect("to get parquet bytes");

        let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(parquet_bytes)
            .expect("to get parquet reader builder")
            .build()
            .expect("to build parquet reader");

        parquet_reader
            .collect::<Result<Vec<_>, datafusion::arrow::error::ArrowError>>()
            .expect("to get records")
    }
790
    /// Wraps in-memory batches in a `SendableRecordBatchStream`.
    ///
    /// Panics if `batches` is empty — the schema is taken from the first batch.
    fn get_stream_from_batches(batches: Vec<RecordBatch>) -> SendableRecordBatchStream {
        let schema = batches[0].schema();
        Box::pin(MemoryStream::try_new(batches, schema, None).expect("to get stream"))
    }
795
796    pub(crate) fn init_tracing(default_level: Option<&str>) -> DefaultGuard {
797        let filter = match default_level {
798            Some(level) => EnvFilter::new(level),
799            _ => EnvFilter::new("INFO,datafusion_table_providers=TRACE"),
800        };
801
802        let subscriber = tracing_subscriber::FmtSubscriber::builder()
803            .with_env_filter(filter)
804            .with_ansi(true)
805            .finish();
806        tracing::subscriber::set_default(subscriber)
807    }
808
809    pub(crate) fn get_basic_table_definition() -> Arc<TableDefinition> {
810        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
811            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
812            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
813        ]));
814
815        Arc::new(TableDefinition::new(
816            RelationName::new("test_table"),
817            schema,
818        ))
819    }
820
    /// End-to-end test of `DuckDBDataSink` writing through a `TableManager`
    /// for a table with a composite unique constraint and two indexes,
    /// exercised under both `InsertOp::Append` and `InsertOp::Overwrite`.
    /// Verifies the reported row count, that no primary keys exist (unique
    /// constraints are not PKs), and the generated index names.
    #[tokio::test]
    async fn test_table_creator() {
        let _guard = init_tracing(None);
        let batches = get_logs_batches().await;

        let schema = batches[0].schema();

        // Definition with a composite unique constraint plus one plain index
        // and one unique (composite) index.
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("eth.logs"), Arc::clone(&schema))
                .with_constraints(get_unique_constraints(
                    &["log_index", "transaction_hash"],
                    Arc::clone(&schema),
                ))
                .with_indexes(vec![
                    (
                        ColumnReference::try_from("block_number").expect("valid column ref"),
                        IndexType::Enabled,
                    ),
                    (
                        ColumnReference::try_from("(log_index, transaction_hash)")
                            .expect("valid column ref"),
                        IndexType::Unique,
                    ),
                ]),
        );

        for overwrite in &[InsertOp::Append, InsertOp::Overwrite] {
            let pool = get_mem_duckdb();

            make_initial_table(Arc::clone(&table_definition), &pool)
                .expect("to make initial table");

            let duckdb_sink = DuckDBDataSink::new(
                Arc::clone(&pool),
                Arc::clone(&table_definition),
                *overwrite,
                None,
                table_definition.schema(),
            );
            let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
            let rows_written = data_sink
                .write_all(
                    get_stream_from_batches(batches.clone()),
                    &Arc::new(TaskContext::default()),
                )
                .await
                .expect("to write all");

            let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
            let conn = pool_conn
                .as_any_mut()
                .downcast_mut::<DuckDbConnection>()
                .expect("to downcast to duckdb connection");
            // The sink's reported row count must match what actually landed.
            let num_rows = conn
                .get_underlying_conn_mut()
                .query_row(r#"SELECT COUNT(1) FROM "eth.logs""#, [], |r| {
                    r.get::<usize, u64>(0)
                })
                .expect("to get count");

            assert_eq!(num_rows, rows_written);

            let tx = conn
                .get_underlying_conn_mut()
                .transaction()
                .expect("should begin transaction");
            // Overwrite leaves exactly one internal table to inspect; append
            // writes straight into the base table — pick the manager accordingly.
            let table_creator = if matches!(overwrite, InsertOp::Overwrite) {
                let internal_tables: Vec<(RelationName, u64)> = table_definition
                    .list_internal_tables(&tx)
                    .expect("should list internal tables");
                assert_eq!(internal_tables.len(), 1);

                let internal_table = internal_tables.first().expect("to get internal table");
                let internal_table = internal_table.0.clone();

                TableManager::from_table_name(Arc::clone(&table_definition), internal_table.clone())
            } else {
                let table_creator = TableManager::new(Arc::clone(&table_definition))
                    .with_internal(false)
                    .expect("to create table creator");

                let base_table = table_creator.base_table(&tx).expect("to get base table");
                assert!(base_table.is_some());
                table_creator
            };

            let primary_keys = table_creator
                .current_primary_keys(&tx)
                .expect("should get primary keys");

            // Unique constraints do not become primary keys, so none are expected.
            assert_eq!(primary_keys, HashSet::<String>::new());

            let created_indexes_str_map = table_creator
                .current_indexes(&tx)
                .expect("should get indexes");

            assert_eq!(
                created_indexes_str_map,
                vec![
                    format!(
                        "i_{table_name}_block_number",
                        table_name = table_creator.table_name()
                    ),
                    format!(
                        "i_{table_name}_log_index_transaction_hash",
                        table_name = table_creator.table_name()
                    )
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
                "Indexes must match"
            );

            tx.rollback().expect("should rollback transaction");
        }
    }
937
    /// Same flow as `test_table_creator`, but with a composite PRIMARY KEY
    /// constraint instead of a unique constraint: verifies the generated
    /// `CREATE TABLE` DDL contains the PK clause, that the primary keys are
    /// reported back, and that the single index is created.
    #[allow(clippy::too_many_lines)]
    #[tokio::test]
    async fn test_table_creator_primary_key() {
        let _guard = init_tracing(None);
        let batches = get_logs_batches().await;

        let schema = batches[0].schema();
        // Composite PK on (log_index, transaction_hash), plus one plain index.
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("eth.logs"), Arc::clone(&schema))
                .with_constraints(get_pk_constraints(
                    &["log_index", "transaction_hash"],
                    Arc::clone(&schema),
                ))
                .with_indexes(
                    vec![(
                        ColumnReference::try_from("block_number").expect("valid column ref"),
                        IndexType::Enabled,
                    )]
                    .into_iter()
                    .collect(),
                ),
        );

        for overwrite in &[InsertOp::Append, InsertOp::Overwrite] {
            let pool = get_mem_duckdb();

            make_initial_table(Arc::clone(&table_definition), &pool)
                .expect("to make initial table");

            let duckdb_sink = DuckDBDataSink::new(
                Arc::clone(&pool),
                Arc::clone(&table_definition),
                *overwrite,
                None,
                table_definition.schema(),
            );
            let data_sink: Arc<dyn DataSink> = Arc::new(duckdb_sink);
            let rows_written = data_sink
                .write_all(
                    get_stream_from_batches(batches.clone()),
                    &Arc::new(TaskContext::default()),
                )
                .await
                .expect("to write all");

            let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
            let conn = pool_conn
                .as_any_mut()
                .downcast_mut::<DuckDbConnection>()
                .expect("to downcast to duckdb connection");
            let num_rows = conn
                .get_underlying_conn_mut()
                .query_row(r#"SELECT COUNT(1) FROM "eth.logs""#, [], |r| {
                    r.get::<usize, u64>(0)
                })
                .expect("to get count");

            assert_eq!(num_rows, rows_written);

            let tx = conn
                .get_underlying_conn_mut()
                .transaction()
                .expect("should begin transaction");

            // Overwrite writes to an internal table; append to the base table.
            let table_creator = if matches!(overwrite, InsertOp::Overwrite) {
                let internal_tables: Vec<(RelationName, u64)> = table_definition
                    .list_internal_tables(&tx)
                    .expect("should list internal tables");
                assert_eq!(internal_tables.len(), 1);

                let internal_table = internal_tables.first().expect("to get internal table");
                let internal_table = internal_table.0.clone();

                TableManager::from_table_name(Arc::clone(&table_definition), internal_table.clone())
            } else {
                let table_creator = TableManager::new(Arc::clone(&table_definition))
                    .with_internal(false)
                    .expect("to create table creator");

                let base_table = table_creator.base_table(&tx).expect("to get base table");
                assert!(base_table.is_some());
                table_creator
            };

            // Pull the DDL DuckDB recorded for the table and check the exact
            // column list and PRIMARY KEY clause.
            let create_stmt = tx
                .query_row(
                    "select sql from duckdb_tables() where table_name = ?",
                    [table_creator.table_name().to_string()],
                    |r| r.get::<usize, String>(0),
                )
                .expect("to get create table statement");

            assert_eq!(
                create_stmt,
                format!(
                    r#"CREATE TABLE "{table_name}"(log_index BIGINT, transaction_hash VARCHAR, transaction_index BIGINT, address VARCHAR, "data" VARCHAR, topics VARCHAR[], block_timestamp BIGINT, block_hash VARCHAR, block_number BIGINT, PRIMARY KEY(log_index, transaction_hash));"#,
                    table_name = table_creator.table_name(),
                )
            );

            let primary_keys = table_creator
                .current_primary_keys(&tx)
                .expect("should get primary keys");

            assert_eq!(
                primary_keys,
                vec!["log_index".to_string(), "transaction_hash".to_string()]
                    .into_iter()
                    .collect::<HashSet<_>>()
            );

            let created_indexes_str_map = table_creator
                .current_indexes(&tx)
                .expect("should get indexes");

            assert_eq!(
                created_indexes_str_map,
                vec![format!(
                    "i_{table_name}_block_number",
                    table_name = table_creator.table_name()
                )]
                .into_iter()
                .collect::<HashSet<_>>(),
                "Indexes must match"
            );

            tx.rollback().expect("should rollback transaction");
        }
    }
1067
    /// Creates several internal tables for one definition and verifies
    /// `TableDefinition::list_internal_tables` returns all of them, ordered
    /// oldest-first, with distinct names that differ from the base name.
    #[tokio::test]
    async fn test_list_related_tables_from_definition() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 3 internal tables
        for _ in 0..3 {
            TableManager::new(Arc::clone(&table_definition))
                .with_internal(true)
                .expect("to create table creator")
                .create_table(Arc::clone(&pool), &tx)
                .expect("to create table");
        }

        // using the table definition, list the names of the internal tables
        let table_name = table_definition.name.clone();
        let internal_tables = table_definition
            .list_internal_tables(&tx)
            .expect("should list internal tables");

        assert_eq!(internal_tables.len(), 3);

        // validate the first table is the oldest, and the last table is the newest
        // (the second tuple element looks like a creation timestamp/ordinal — the
        // ordering assertion below relies on it increasing over time)
        let first_table = internal_tables.first().expect("to get first table");
        let last_table = internal_tables.last().expect("to get last table");
        assert!(first_table.1 < last_table.1);

        // validate none of the internal tables are the same, they are not equal to the base table
        let mut seen_tables = vec![];
        for (internal_table, _) in internal_tables {
            let internal_name = internal_table.clone();
            assert_ne!(&internal_name, &table_name);
            assert!(!seen_tables.contains(&internal_name));
            seen_tables.push(internal_name);
        }

        tx.rollback().expect("should rollback transaction");
    }
1118
    /// Verifies `TableManager::list_other_internal_tables` excludes the
    /// manager's own table and the base name, and that deleting the listed
    /// tables empties a subsequent listing.
    #[tokio::test]
    async fn test_list_related_tables_from_creator() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 3 internal tables
        for _ in 0..3 {
            TableManager::new(Arc::clone(&table_definition))
                .with_internal(true)
                .expect("to create table creator")
                .create_table(Arc::clone(&pool), &tx)
                .expect("to create table");
        }

        // instantiate a new table creator, make it, and list the internal tables
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // "other" internal tables: the creator's own (4th) table must not appear
        let internal_tables = table_creator
            .list_other_internal_tables(&tx)
            .expect("should list internal tables");

        assert_eq!(internal_tables.len(), 3);

        // validate none of the internal tables are the same, they are not equal to the base table, and they are not equal to the internal table that listed them
        let mut seen_tables = vec![];
        for (internal_table, _) in &internal_tables {
            let table_name = internal_table.table_name().clone();
            assert_ne!(&table_name, table_creator.definition_name());
            assert_ne!(Some(&table_name), table_creator.internal_name.as_ref());
            assert!(!seen_tables.contains(&table_name));
            seen_tables.push(table_name);
        }

        // drop every listed internal table; the creator's own table (created
        // last) is not in the list, so it survives
        for (internal_table, _) in internal_tables {
            internal_table.delete_table(&tx).expect("to delete table");
        }

        // list the internal tables again — only the creator's own table remains,
        // and it is excluded from its own listing
        let internal_tables = table_creator
            .list_other_internal_tables(&tx)
            .expect("should list internal tables");

        assert_eq!(internal_tables.len(), 0);

        tx.rollback().expect("should rollback transaction");
    }
1184
    /// Creates an internal table and a view over it, then checks the view is
    /// registered in `duckdb_views()` under the definition's name.
    #[tokio::test]
    async fn test_create_view() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make a table
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // create a view from the internal table
        table_creator.create_view(&tx).expect("to create view");

        // check if the view exists (query_row errors if no row matches, so the
        // expect below doubles as the existence check)
        let view_exists = tx
            .query_row(
                "from duckdb_views() select 1 where view_name = ? and not internal",
                [table_creator.definition_name().to_string()],
                |r| r.get::<usize, i32>(0),
            )
            .expect("to get view");

        assert_eq!(view_exists, 1);

        tx.rollback().expect("should rollback transaction");
    }
1227
    /// Inserts rows into a base table via raw SQL, then copies them into an
    /// internal table with `TableManager::insert_into` and verifies the count.
    #[tokio::test]
    async fn test_insert_into_tables() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make a base table
        let base_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(false)
            .expect("to create table creator");

        base_table
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // make an internal table
        let internal_table = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        internal_table
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        // insert some rows directly into the base table
        let insert_stmt = format!(
            r#"INSERT INTO "{base_table}" VALUES (1, 'test'), (2, 'test2')"#,
            base_table = base_table.table_name()
        );

        tx.execute(&insert_stmt, [])
            .expect("to insert into base table");

        // insert from the base table into the internal table
        base_table
            .insert_into(&internal_table, &tx, None)
            .expect("to insert into internal table");

        // check if the rows were inserted
        let rows = tx
            .query_row(
                &format!(
                    r#"SELECT COUNT(1) FROM "{internal_table}""#,
                    internal_table = internal_table.table_name()
                ),
                [],
                |r| r.get::<usize, u64>(0),
            )
            .expect("to get count");

        assert_eq!(rows, 2);

        tx.rollback().expect("should rollback transaction");
    }
1293
1294    #[tokio::test]
1295    async fn test_lists_base_table_from_definition() {
1296        let _guard = init_tracing(None);
1297        let pool = get_mem_duckdb();
1298
1299        let table_definition = get_basic_table_definition();
1300
1301        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
1302        let conn = pool_conn
1303            .as_any_mut()
1304            .downcast_mut::<DuckDbConnection>()
1305            .expect("to downcast to duckdb connection");
1306        let tx = conn
1307            .get_underlying_conn_mut()
1308            .transaction()
1309            .expect("should begin transaction");
1310
1311        // make a base table
1312        let table_creator = TableManager::new(Arc::clone(&table_definition))
1313            .with_internal(false)
1314            .expect("to create table creator");
1315
1316        table_creator
1317            .create_table(Arc::clone(&pool), &tx)
1318            .expect("to create table");
1319
1320        // list the base table from another base table
1321        let internal_table = TableManager::new(Arc::clone(&table_definition))
1322            .with_internal(false)
1323            .expect("to create table creator");
1324
1325        let base_table = internal_table.base_table(&tx).expect("to get base table");
1326
1327        assert!(base_table.is_some());
1328        assert_eq!(
1329            base_table.expect("to be some").table_definition,
1330            table_creator.table_definition
1331        );
1332
1333        // list the base table from an internal table
1334        let internal_table = TableManager::new(Arc::clone(&table_definition))
1335            .with_internal(true)
1336            .expect("to create table creator");
1337
1338        let base_table = internal_table.base_table(&tx).expect("to get base table");
1339
1340        assert!(base_table.is_some());
1341        assert_eq!(
1342            base_table.expect("to be some").table_definition,
1343            table_creator.table_definition
1344        );
1345
1346        tx.rollback().expect("should rollback transaction");
1347    }
1348
    /// Checks `verify_primary_keys_match`: two tables from the same
    /// PK-bearing definition match, a table from a PK-less definition does
    /// not match the first, and two PK-less tables match each other.
    #[tokio::test]
    async fn test_primary_keys_match() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
        ]));

        // Definition with a primary key on "id".
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("test_table"), Arc::clone(&schema))
                .with_constraints(get_pk_constraints(&["id"], Arc::clone(&schema))),
        );

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 2 internal tables which should have the same indexes
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let table_creator2 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator2
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let primary_keys_match = table_creator
            .verify_primary_keys_match(&table_creator2, &tx)
            .expect("to verify primary keys match");

        assert!(primary_keys_match);

        // make another table that does not match (basic definition has no PK)
        let table_definition = get_basic_table_definition();

        let table_creator3 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator3
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let primary_keys_match = table_creator
            .verify_primary_keys_match(&table_creator3, &tx)
            .expect("to verify primary keys match");

        assert!(!primary_keys_match);

        // validate that 2 empty tables return true
        let table_creator4 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator4
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let primary_keys_match = table_creator3
            .verify_primary_keys_match(&table_creator4, &tx)
            .expect("to verify primary keys match");

        assert!(primary_keys_match);

        tx.rollback().expect("should rollback transaction");
    }
1431
    /// Checks `verify_indexes_match`: two tables created from the same
    /// indexed definition match, a table from an index-less definition does
    /// not match the first, and two index-less tables match each other.
    #[tokio::test]
    async fn test_indexes_match() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("id", arrow::datatypes::DataType::Int64, false),
            arrow::datatypes::Field::new("name", arrow::datatypes::DataType::Utf8, false),
        ]));

        // Definition with a single plain index on "id".
        let table_definition = Arc::new(
            TableDefinition::new(RelationName::new("test_table"), Arc::clone(&schema))
                .with_indexes(
                    vec![(
                        ColumnReference::try_from("id").expect("valid column ref"),
                        IndexType::Enabled,
                    )]
                    .into_iter()
                    .collect(),
                ),
        );

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        // make 2 internal tables which should have the same indexes
        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator
            .create_indexes(&tx)
            .expect("to create indexes");

        let table_creator2 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator2
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator2
            .create_indexes(&tx)
            .expect("to create indexes");

        let indexes_match = table_creator
            .verify_indexes_match(&table_creator2, &tx)
            .expect("to verify indexes match");

        assert!(indexes_match);

        // make another table that does not match (basic definition has no indexes)
        let table_definition = get_basic_table_definition();

        let table_creator3 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator3
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator3
            .create_indexes(&tx)
            .expect("to create indexes");

        let indexes_match = table_creator
            .verify_indexes_match(&table_creator3, &tx)
            .expect("to verify indexes match");

        assert!(!indexes_match);

        // validate that 2 empty tables return true
        let table_creator4 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator4
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        table_creator4
            .create_indexes(&tx)
            .expect("to create indexes");

        let indexes_match = table_creator3
            .verify_indexes_match(&table_creator4, &tx)
            .expect("to verify indexes match");

        assert!(indexes_match);

        tx.rollback().expect("should rollback transaction");
    }
1537
    /// Verifies `TableManager::current_schema` reflects the definition's
    /// Arrow schema, and that two tables created from the same definition
    /// report equivalent schemas.
    #[tokio::test]
    async fn test_current_schema() {
        let _guard = init_tracing(None);
        let pool = get_mem_duckdb();

        let table_definition = get_basic_table_definition();

        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
        let conn = pool_conn
            .as_any_mut()
            .downcast_mut::<DuckDbConnection>()
            .expect("to downcast to duckdb connection");
        let tx = conn
            .get_underlying_conn_mut()
            .transaction()
            .expect("should begin transaction");

        let table_creator = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let schema = table_creator
            .current_schema(&tx)
            .expect("to get current schema");

        // Round-trip through DuckDB must preserve column names and types.
        assert!(schema.equivalent_names_and_types(&table_definition.schema));

        // schemas between different tables are equivalent
        let table_creator2 = TableManager::new(Arc::clone(&table_definition))
            .with_internal(true)
            .expect("to create table creator");

        table_creator2
            .create_table(Arc::clone(&pool), &tx)
            .expect("to create table");

        let schema2 = table_creator2
            .current_schema(&tx)
            .expect("to get current schema");

        assert!(schema.equivalent_names_and_types(&schema2));

        tx.rollback().expect("should rollback transaction");
    }
1586
1587    #[tokio::test]
1588    async fn test_internal_tables_exclude_subsets_of_other_tables() {
1589        let _guard = init_tracing(None);
1590        let pool = get_mem_duckdb();
1591
1592        let table_definition = get_basic_table_definition();
1593        let other_definition = Arc::new(TableDefinition::new(
1594            RelationName::new("test_table_second"),
1595            Arc::clone(&table_definition.schema),
1596        ));
1597
1598        let mut pool_conn = Arc::clone(&pool).connect_sync().expect("to get connection");
1599        let conn = pool_conn
1600            .as_any_mut()
1601            .downcast_mut::<DuckDbConnection>()
1602            .expect("to downcast to duckdb connection");
1603
1604        let tx = conn
1605            .get_underlying_conn_mut()
1606            .transaction()
1607            .expect("should begin transaction");
1608
1609        // make an internal table for each definition
1610        let table_creator = TableManager::new(Arc::clone(&table_definition))
1611            .with_internal(true)
1612            .expect("to create table creator");
1613
1614        table_creator
1615            .create_table(Arc::clone(&pool), &tx)
1616            .expect("to create table");
1617
1618        let other_table_creator = TableManager::new(Arc::clone(&other_definition))
1619            .with_internal(true)
1620            .expect("to create table creator");
1621
1622        other_table_creator
1623            .create_table(Arc::clone(&pool), &tx)
1624            .expect("to create table");
1625
1626        // each table should not list the other as an internal table
1627        let first_tables = table_definition
1628            .list_internal_tables(&tx)
1629            .expect("should list internal tables");
1630        let second_tables = other_definition
1631            .list_internal_tables(&tx)
1632            .expect("should list internal tables");
1633
1634        assert_eq!(first_tables.len(), 1);
1635        assert_eq!(second_tables.len(), 1);
1636
1637        assert_ne!(
1638            first_tables.first().expect("should have a table").0,
1639            second_tables.first().expect("should have a table").0
1640        );
1641    }
1642}