1use crate::sql::arrow_sql_gen::statement::{CreateTableBuilder, IndexBuilder, InsertBuilder};
2use crate::sql::db_connection_pool::dbconnection::{self, get_schema, AsyncDbConnection};
3use crate::sql::db_connection_pool::sqlitepool::SqliteConnectionPoolFactory;
4use crate::sql::db_connection_pool::DbInstanceKey;
5use crate::sql::db_connection_pool::{
6 self,
7 dbconnection::{sqliteconn::SqliteConnection, DbConnection},
8 sqlitepool::SqliteConnectionPool,
9 DbConnectionPool, Mode,
10};
11use crate::sql::sql_provider_datafusion;
12use crate::util::schema::SchemaValidator;
13use crate::UnsupportedTypeAction;
14use arrow::array::{Int64Array, StringArray};
15use arrow::{array::RecordBatch, datatypes::SchemaRef};
16use async_trait::async_trait;
17use datafusion::catalog::Session;
18use datafusion::{
19 catalog::TableProviderFactory,
20 common::Constraints,
21 datasource::TableProvider,
22 error::{DataFusionError, Result as DataFusionResult},
23 logical_expr::CreateExternalTable,
24 sql::TableReference,
25};
26use futures::TryStreamExt;
27use rusqlite::{ToSql, Transaction};
28use snafu::prelude::*;
29use sql_table::SQLiteTable;
30use std::collections::HashSet;
31use std::time::Duration;
32use std::{collections::HashMap, sync::Arc};
33use tokio::sync::Mutex;
34use tokio_rusqlite::Connection;
35
36use crate::util::{
37 self,
38 column_reference::{self, ColumnReference},
39 constraints::{self, get_primary_keys_from_constraints},
40 indexes::IndexType,
41 on_conflict::{self, OnConflict},
42};
43
44use self::write::SqliteTableWriter;
45
46#[cfg(feature = "sqlite-federation")]
47pub mod federation;
48
49#[cfg(feature = "sqlite-federation")]
50pub mod sqlite_interval;
51
52pub mod sql_table;
53pub mod write;
54
55#[derive(Debug, Snafu)]
56pub enum Error {
57 #[snafu(display("DbConnectionError: {source}"))]
58 DbConnectionError {
59 source: db_connection_pool::dbconnection::GenericError,
60 },
61
62 #[snafu(display("DbConnectionPoolError: {source}"))]
63 DbConnectionPoolError { source: db_connection_pool::Error },
64
65 #[snafu(display("Unable to downcast DbConnection to SqliteConnection"))]
66 UnableToDowncastDbConnection {},
67
68 #[snafu(display("Unable to construct SQLTable instance: {source}"))]
69 UnableToConstuctSqlTableProvider {
70 source: sql_provider_datafusion::Error,
71 },
72
73 #[snafu(display("Unable to create table in Sqlite: {source}"))]
74 UnableToCreateTable { source: tokio_rusqlite::Error },
75
76 #[snafu(display("Unable to insert data into the Sqlite table: {source}"))]
77 UnableToInsertIntoTable { source: rusqlite::Error },
78
79 #[snafu(display("Unable to insert data into the Sqlite table: {source}"))]
80 UnableToInsertIntoTableAsync { source: tokio_rusqlite::Error },
81
82 #[snafu(display("Unable to insert data into the Sqlite table. The disk is full."))]
83 DiskFull {},
84
85 #[snafu(display("Unable to deleta all table data in Sqlite: {source}"))]
86 UnableToDeleteAllTableData { source: rusqlite::Error },
87
88 #[snafu(display("There is a dangling reference to the Sqlite struct in TableProviderFactory.create. This is a bug."))]
89 DanglingReferenceToSqlite,
90
91 #[snafu(display("Constraint Violation: {source}"))]
92 ConstraintViolation { source: constraints::Error },
93
94 #[snafu(display("Error parsing column reference: {source}"))]
95 UnableToParseColumnReference { source: column_reference::Error },
96
97 #[snafu(display("Error parsing on_conflict: {source}"))]
98 UnableToParseOnConflict { source: on_conflict::Error },
99
100 #[snafu(display("Unable to infer schema: {source}"))]
101 UnableToInferSchema { source: dbconnection::Error },
102
103 #[snafu(display("Invalid SQLite busy_timeout value"))]
104 InvalidBusyTimeoutValue { value: String },
105
106 #[snafu(display(
107 "Unable to parse SQLite busy_timeout parameter, ensure it is a valid duration"
108 ))]
109 UnableToParseBusyTimeoutParameter { source: fundu::ParseError },
110
111 #[snafu(display(
112 "Failed to create '{table_name}': creating a table with a schema is not supported"
113 ))]
114 TableWithSchemaCreationNotSupported { table_name: String },
115}
116
117type Result<T, E = Error> = std::result::Result<T, E>;
118
/// [`TableProviderFactory`] that creates SQLite-backed tables.
///
/// Connection pools are cached per [`DbInstanceKey`] (one key per database file, a single
/// shared key for in-memory mode), and later requests for the same key receive a
/// `try_clone` of the cached pool.
#[derive(Debug)]
pub struct SqliteTableProviderFactory {
    // Guarded by an async (tokio) mutex because `get_or_init_instance` awaits pool
    // construction while holding the lock.
    instances: Arc<Mutex<HashMap<DbInstanceKey, SqliteConnectionPool>>>,
}

// Option key for an explicit database file path (a `sqlite_` prefix is also accepted by
// `sqlite_file_path`).
const SQLITE_DB_PATH_PARAM: &str = "file";
// Option key for the directory used to build the default `<dir>/<name>_sqlite.db` path.
const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory";
// Option key holding a `;`-separated list of databases (read by `attach_databases`).
const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases";
// Option key for the busy timeout, parsed as a duration string by `sqlite_busy_timeout`.
const SQLITE_BUSY_TIMEOUT_PARAM: &str = "busy_timeout";
128
129impl SqliteTableProviderFactory {
130 #[must_use]
131 pub fn new() -> Self {
132 Self {
133 instances: Arc::new(Mutex::new(HashMap::new())),
134 }
135 }
136
137 #[must_use]
138 pub fn attach_databases(&self, options: &HashMap<String, String>) -> Option<Vec<Arc<str>>> {
139 options.get(SQLITE_ATTACH_DATABASES_PARAM).map(|databases| {
140 databases
141 .split(';')
142 .map(Arc::from)
143 .collect::<Vec<Arc<str>>>()
144 })
145 }
146
147 pub fn sqlite_file_path(
153 &self,
154 name: &str,
155 options: &HashMap<String, String>,
156 ) -> Result<String, Error> {
157 let options = util::remove_prefix_from_hashmap_keys(options.clone(), "sqlite_");
158
159 let db_base_folder = options
160 .get(SQLITE_DB_BASE_FOLDER_PARAM)
161 .cloned()
162 .unwrap_or(".".to_string()); let default_filepath = &format!("{db_base_folder}/{name}_sqlite.db");
164
165 let filepath = options
166 .get(SQLITE_DB_PATH_PARAM)
167 .unwrap_or(default_filepath);
168
169 Ok(filepath.to_string())
170 }
171
172 pub fn sqlite_busy_timeout(&self, options: &HashMap<String, String>) -> Result<Duration> {
173 let busy_timeout = options.get(SQLITE_BUSY_TIMEOUT_PARAM).cloned();
174 match busy_timeout {
175 Some(busy_timeout) => {
176 let duration = fundu::parse_duration(&busy_timeout)
177 .context(UnableToParseBusyTimeoutParameterSnafu)?;
178 Ok(duration)
179 }
180 None => Ok(Duration::from_millis(5000)),
181 }
182 }
183
184 pub async fn get_or_init_instance(
185 &self,
186 db_path: impl Into<Arc<str>>,
187 mode: Mode,
188 busy_timeout: Duration,
189 ) -> Result<SqliteConnectionPool> {
190 let db_path = db_path.into();
191 let key = match mode {
192 Mode::Memory => DbInstanceKey::memory(),
193 Mode::File => DbInstanceKey::file(Arc::clone(&db_path)),
194 };
195 let mut instances = self.instances.lock().await;
196
197 if let Some(instance) = instances.get(&key) {
198 return instance.try_clone().await.context(DbConnectionPoolSnafu);
199 }
200
201 let pool = SqliteConnectionPoolFactory::new(&db_path, mode, busy_timeout)
202 .build()
203 .await
204 .context(DbConnectionPoolSnafu)?;
205
206 instances.insert(key, pool.try_clone().await.context(DbConnectionPoolSnafu)?);
207
208 Ok(pool)
209 }
210}
211
212impl Default for SqliteTableProviderFactory {
213 fn default() -> Self {
214 Self::new()
215 }
216}
217
218pub type DynSqliteConnectionPool =
219 dyn DbConnectionPool<Connection, &'static (dyn ToSql + Sync)> + Send + Sync;
220
221#[async_trait]
222impl TableProviderFactory for SqliteTableProviderFactory {
223 #[allow(clippy::too_many_lines)]
224 async fn create(
225 &self,
226 _state: &dyn Session,
227 cmd: &CreateExternalTable,
228 ) -> DataFusionResult<Arc<dyn TableProvider>> {
229 if cmd.name.schema().is_some() {
230 TableWithSchemaCreationNotSupportedSnafu {
231 table_name: cmd.name.to_string(),
232 }
233 .fail()
234 .map_err(to_datafusion_error)?;
235 }
236
237 let name = cmd.name.clone();
238 let mut options = cmd.options.clone();
239 let mode = options.remove("mode").unwrap_or_default();
240 let mode: Mode = mode.as_str().into();
241
242 let indexes_option_str = options.remove("indexes");
243 let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
244 Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
245 None => HashMap::new(),
246 };
247
248 let unparsed_indexes = unparsed_indexes
249 .into_iter()
250 .map(|(key, value)| {
251 let columns = ColumnReference::try_from(key.as_str())
252 .context(UnableToParseColumnReferenceSnafu)
253 .map_err(to_datafusion_error);
254 (columns, value)
255 })
256 .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();
257
258 let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
259 for (columns, index_type) in unparsed_indexes {
260 let columns = columns?;
261 indexes.push((columns, index_type));
262 }
263
264 let mut on_conflict: Option<OnConflict> = None;
265 if let Some(on_conflict_str) = options.remove("on_conflict") {
266 on_conflict = Some(
267 OnConflict::try_from(on_conflict_str.as_str())
268 .context(UnableToParseOnConflictSnafu)
269 .map_err(to_datafusion_error)?,
270 );
271 }
272
273 let busy_timeout = self
274 .sqlite_busy_timeout(&cmd.options)
275 .map_err(to_datafusion_error)?;
276 let db_path: Arc<str> = self
277 .sqlite_file_path(name.table(), &cmd.options)
278 .map_err(to_datafusion_error)?
279 .into();
280
281 let pool: Arc<SqliteConnectionPool> = Arc::new(
282 self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout)
283 .await
284 .map_err(to_datafusion_error)?,
285 );
286
287 let read_pool = if mode == Mode::Memory {
288 Arc::clone(&pool)
289 } else {
290 Arc::new(
294 self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout)
295 .await
296 .map_err(to_datafusion_error)?,
297 )
298 };
299
300 let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into());
301 let schema: SchemaRef =
302 SqliteConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
303 .map_err(|e| DataFusionError::External(e.into()))?;
304
305 let sqlite = Arc::new(Sqlite::new(
306 name.clone(),
307 Arc::clone(&schema),
308 Arc::clone(&pool),
309 cmd.constraints.clone(),
310 ));
311
312 let mut db_conn = sqlite.connect().await.map_err(to_datafusion_error)?;
313 let sqlite_conn = Sqlite::sqlite_conn(&mut db_conn).map_err(to_datafusion_error)?;
314
315 let primary_keys = get_primary_keys_from_constraints(&cmd.constraints, &schema);
316
317 let table_exists = sqlite.table_exists(sqlite_conn).await;
318 if !table_exists {
319 let sqlite_in_conn = Arc::clone(&sqlite);
320 sqlite_conn
321 .conn
322 .call(move |conn| {
323 let transaction = conn.transaction()?;
324 sqlite_in_conn.create_table(&transaction, primary_keys)?;
325 for index in indexes {
326 sqlite_in_conn.create_index(
327 &transaction,
328 index.0.iter().collect(),
329 index.1 == IndexType::Unique,
330 )?;
331 }
332 transaction.commit()?;
333 Ok(())
334 })
335 .await
336 .context(UnableToCreateTableSnafu)
337 .map_err(to_datafusion_error)?;
338 } else {
339 let mut table_definition_matches = true;
340
341 table_definition_matches &= sqlite.verify_indexes_match(sqlite_conn, &indexes).await?;
342 table_definition_matches &= sqlite
343 .verify_primary_keys_match(sqlite_conn, &primary_keys)
344 .await?;
345
346 if !table_definition_matches {
347 tracing::warn!(
348 "The local table definition at '{db_path}' for '{name}' does not match the expected configuration. To fix this, drop the existing local copy. A new table with the correct schema will be automatically created upon first access.",
349 name = name
350 );
351 }
352 }
353
354 let dyn_pool: Arc<DynSqliteConnectionPool> = read_pool;
355
356 let read_provider = Arc::new(SQLiteTable::new_with_schema(
357 &dyn_pool,
358 Arc::clone(&schema),
359 name,
360 ));
361
362 let sqlite = Arc::into_inner(sqlite)
363 .context(DanglingReferenceToSqliteSnafu)
364 .map_err(to_datafusion_error)?;
365
366 #[cfg(feature = "sqlite-federation")]
367 let read_provider: Arc<dyn TableProvider> =
368 Arc::new(read_provider.create_federated_table_provider()?);
369
370 Ok(SqliteTableWriter::create(
371 read_provider,
372 sqlite,
373 on_conflict,
374 ))
375 }
376}
377
/// Builds read-only [`TableProvider`]s for existing tables reachable through a shared
/// SQLite connection pool, inferring each table's schema from the database.
pub struct SqliteTableFactory {
    pool: Arc<SqliteConnectionPool>,
}
381
382impl SqliteTableFactory {
383 #[must_use]
384 pub fn new(pool: Arc<SqliteConnectionPool>) -> Self {
385 Self { pool }
386 }
387
388 pub async fn table_provider(
389 &self,
390 table_reference: TableReference,
391 ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
392 let pool = Arc::clone(&self.pool);
393
394 let conn = pool.connect().await.context(DbConnectionSnafu)?;
395 let schema = get_schema(conn, &table_reference)
396 .await
397 .context(UnableToInferSchemaSnafu)?;
398
399 let dyn_pool: Arc<DynSqliteConnectionPool> = pool;
400
401 let read_provider = Arc::new(SQLiteTable::new_with_schema(
402 &dyn_pool,
403 Arc::clone(&schema),
404 table_reference,
405 ));
406
407 Ok(read_provider)
408 }
409}
410
411fn to_datafusion_error(error: Error) -> DataFusionError {
412 DataFusionError::External(Box::new(error))
413}
414
/// Handle to one SQLite table: its reference, Arrow schema, connection pool and constraints.
/// Cloning shares the pool via `Arc`.
#[derive(Clone)]
pub struct Sqlite {
    table: TableReference,
    schema: SchemaRef,
    pool: Arc<SqliteConnectionPool>,
    constraints: Constraints,
}
422
423impl std::fmt::Debug for Sqlite {
424 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425 f.debug_struct("Sqlite")
426 .field("table_name", &self.table)
427 .field("schema", &self.schema)
428 .field("constraints", &self.constraints)
429 .finish()
430 }
431}
432
433impl Sqlite {
434 #[must_use]
435 pub fn new(
436 table: TableReference,
437 schema: SchemaRef,
438 pool: Arc<SqliteConnectionPool>,
439 constraints: Constraints,
440 ) -> Self {
441 Self {
442 table,
443 schema,
444 pool,
445 constraints,
446 }
447 }
448
449 #[must_use]
450 pub fn table_name(&self) -> &str {
451 self.table.table()
452 }
453
454 #[must_use]
455 pub fn constraints(&self) -> &Constraints {
456 &self.constraints
457 }
458
459 pub async fn connect(
460 &self,
461 ) -> Result<Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>>> {
462 self.pool.connect().await.context(DbConnectionSnafu)
463 }
464
465 pub fn sqlite_conn<'a>(
466 db_connection: &'a mut Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>>,
467 ) -> Result<&'a mut SqliteConnection> {
468 db_connection
469 .as_any_mut()
470 .downcast_mut::<SqliteConnection>()
471 .ok_or_else(|| UnableToDowncastDbConnectionSnafu {}.build())
472 }
473
474 async fn table_exists(&self, sqlite_conn: &mut SqliteConnection) -> bool {
475 let sql = format!(
476 r#"SELECT EXISTS (
477 SELECT 1
478 FROM sqlite_master
479 WHERE type='table'
480 AND name = '{name}'
481 )"#,
482 name = self.table
483 );
484 tracing::trace!("{sql}");
485
486 sqlite_conn
487 .conn
488 .call(move |conn| {
489 let mut stmt = conn.prepare(&sql)?;
490 let exists = stmt.query_row([], |row| row.get(0))?;
491 Ok(exists)
492 })
493 .await
494 .unwrap_or(false)
495 }
496
497 fn insert_batch(
498 &self,
499 transaction: &Transaction<'_>,
500 batch: RecordBatch,
501 on_conflict: Option<&OnConflict>,
502 ) -> rusqlite::Result<()> {
503 let insert_table_builder = InsertBuilder::new(&self.table, vec![batch]);
504
505 let sea_query_on_conflict =
506 on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));
507
508 let sql = insert_table_builder
509 .build_sqlite(sea_query_on_conflict)
510 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
511
512 transaction.execute(&sql, [])?;
513
514 Ok(())
515 }
516
517 fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> rusqlite::Result<()> {
518 transaction.execute(
519 format!(r#"DELETE FROM {}"#, self.table.to_quoted_string()).as_str(),
520 [],
521 )?;
522
523 Ok(())
524 }
525
526 fn create_table(
527 &self,
528 transaction: &Transaction<'_>,
529 primary_keys: Vec<String>,
530 ) -> rusqlite::Result<()> {
531 let create_table_statement =
532 CreateTableBuilder::new(Arc::clone(&self.schema), self.table.table())
533 .primary_keys(primary_keys);
534 let sql = create_table_statement.build_sqlite();
535
536 transaction.execute(&sql, [])?;
537
538 Ok(())
539 }
540
541 fn create_index(
542 &self,
543 transaction: &Transaction<'_>,
544 columns: Vec<&str>,
545 unique: bool,
546 ) -> rusqlite::Result<()> {
547 let mut index_builder = IndexBuilder::new(self.table.table(), columns);
548 if unique {
549 index_builder = index_builder.unique();
550 }
551 let sql = index_builder.build_sqlite();
552
553 transaction.execute(&sql, [])?;
554
555 Ok(())
556 }
557
558 async fn get_indexes(
559 &self,
560 sqlite_conn: &mut SqliteConnection,
561 ) -> DataFusionResult<HashSet<String>> {
562 let query_result = sqlite_conn
563 .query_arrow(
564 format!("PRAGMA index_list({name})", name = self.table).as_str(),
565 &[],
566 None,
567 )
568 .await?;
569
570 let mut indexes = HashSet::new();
571
572 query_result
573 .try_collect::<Vec<RecordBatch>>()
574 .await
575 .into_iter()
576 .flatten()
577 .for_each(|batch| {
578 if let Some(name_array) = batch
579 .column_by_name("name")
580 .and_then(|col| col.as_any().downcast_ref::<StringArray>())
581 {
582 for index_name in name_array.iter().flatten() {
583 if !index_name.starts_with("sqlite_autoindex_") {
585 indexes.insert(index_name.to_string());
586 }
587 }
588 }
589 });
590
591 Ok(indexes)
592 }
593
594 async fn get_primary_keys(
595 &self,
596 sqlite_conn: &mut SqliteConnection,
597 ) -> DataFusionResult<HashSet<String>> {
598 let query_result = sqlite_conn
599 .query_arrow(
600 format!("PRAGMA table_info({name})", name = self.table).as_str(),
601 &[],
602 None,
603 )
604 .await?;
605
606 let mut primary_keys = HashSet::new();
607
608 query_result
609 .try_collect::<Vec<RecordBatch>>()
610 .await
611 .into_iter()
612 .flatten()
613 .for_each(|batch| {
614 if let (Some(name_array), Some(pk_array)) = (
615 batch
616 .column_by_name("name")
617 .and_then(|col| col.as_any().downcast_ref::<StringArray>()),
618 batch
619 .column_by_name("pk")
620 .and_then(|col| col.as_any().downcast_ref::<Int64Array>()),
621 ) {
622 for (name, pk) in name_array.iter().flatten().zip(pk_array.iter().flatten()) {
624 if pk > 0 {
625 primary_keys.insert(name.to_string());
627 }
628 }
629 }
630 });
631
632 Ok(primary_keys)
633 }
634
635 async fn verify_indexes_match(
636 &self,
637 sqlite_conn: &mut SqliteConnection,
638 indexes: &[(ColumnReference, IndexType)],
639 ) -> DataFusionResult<bool> {
640 let expected_indexes_str_map: HashSet<String> = indexes
641 .iter()
642 .map(|(col, _)| {
643 IndexBuilder::new(self.table.table(), col.iter().collect()).index_name()
644 })
645 .collect();
646
647 let actual_indexes_str_map = self.get_indexes(sqlite_conn).await?;
648
649 let missing_in_actual = expected_indexes_str_map
650 .difference(&actual_indexes_str_map)
651 .collect::<Vec<_>>();
652 let extra_in_actual = actual_indexes_str_map
653 .difference(&expected_indexes_str_map)
654 .collect::<Vec<_>>();
655
656 if !missing_in_actual.is_empty() {
657 tracing::warn!(
658 "Missing indexes detected for the table '{name}': {:?}.",
659 missing_in_actual,
660 name = self.table
661 );
662 }
663 if !extra_in_actual.is_empty() {
664 tracing::warn!(
665 "The table '{name}' contains unexpected indexes not presented in the configuration: {:?}.",
666 extra_in_actual,
667 name = self.table
668 );
669 }
670
671 Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
672 }
673
674 async fn verify_primary_keys_match(
675 &self,
676 sqlite_conn: &mut SqliteConnection,
677 primary_keys: &[String],
678 ) -> DataFusionResult<bool> {
679 let expected_pk_keys_str_map: HashSet<String> = primary_keys.iter().cloned().collect();
680
681 let actual_pk_keys_str_map = self.get_primary_keys(sqlite_conn).await?;
682
683 let missing_in_actual = expected_pk_keys_str_map
684 .difference(&actual_pk_keys_str_map)
685 .collect::<Vec<_>>();
686 let extra_in_actual = actual_pk_keys_str_map
687 .difference(&expected_pk_keys_str_map)
688 .collect::<Vec<_>>();
689
690 if !missing_in_actual.is_empty() {
691 tracing::warn!(
692 "Missing primary keys detected for the table '{name}': {:?}.",
693 missing_in_actual,
694 name = self.table
695 );
696 }
697 if !extra_in_actual.is_empty() {
698 tracing::warn!(
699 "The table '{name}' contains unexpected primary keys not presented in the configuration: {:?}.",
700 extra_in_actual,
701 name = self.table
702 );
703 }
704
705 Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
706 }
707}
708
#[cfg(test)]
pub(crate) mod tests {
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        common::{Constraint, ToDFSchema},
        prelude::SessionContext,
    };

    use super::*;

    /// Creating a table through the factory should materialise both the configured indexes
    /// and the primary key declared through the table constraints.
    #[tokio::test]
    async fn test_sqlite_table_creation_with_indexes() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("first_name", DataType::Utf8, false),
            Field::new("last_name", DataType::Utf8, false),
            Field::new("id", DataType::Int64, false),
        ]));

        let options: HashMap<String, String> = HashMap::from([(
            "indexes".to_string(),
            "id:enabled;(first_name, last_name):unique".to_string(),
        )]);

        let expected_indexes: HashSet<String> = HashSet::from([
            "i_test_table_id".to_string(),
            "i_test_table_first_name_last_name".to_string(),
        ]);

        let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema");

        // Primary key on `id`, expressed as column positions within `schema`.
        let pk_indices: Vec<usize> = ["id"]
            .iter()
            .filter_map(|col_name| schema.column_with_name(col_name).map(|(index, _)| index))
            .collect();
        let primary_keys_constraints =
            Constraints::new_unverified(vec![Constraint::PrimaryKey(pk_indices)]);

        let external_table = CreateExternalTable {
            schema: df_schema,
            name: TableReference::bare("test_table"),
            location: String::new(),
            file_type: String::new(),
            table_partition_cols: vec![],
            if_not_exists: true,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: primary_keys_constraints,
            column_defaults: HashMap::default(),
            temporary: false,
        };

        let ctx = SessionContext::new();
        let table = SqliteTableProviderFactory::default()
            .create(&ctx.state(), &external_table)
            .await
            .expect("table should be created");

        let writer = table
            .as_any()
            .downcast_ref::<SqliteTableWriter>()
            .expect("downcast to SqliteTableWriter");
        let sqlite = writer.sqlite();

        let mut db_conn = sqlite.connect().await.expect("should connect to db");
        let sqlite_conn =
            Sqlite::sqlite_conn(&mut db_conn).expect("should create sqlite connection");

        let retrieved_indexes = sqlite
            .get_indexes(sqlite_conn)
            .await
            .expect("should get indexes");
        assert_eq!(retrieved_indexes, expected_indexes);

        let retrieved_primary_keys = sqlite
            .get_primary_keys(sqlite_conn)
            .await
            .expect("should get primary keys");
        assert_eq!(
            retrieved_primary_keys,
            HashSet::from(["id".to_string()])
        );
    }
}