use crate::sql::arrow_sql_gen::statement::{
table_reference_to_sea_table_ref, CreateTableBuilder, IndexBuilder, InsertBuilder,
};
use crate::sql::db_connection_pool::dbconnection::{self, get_schema, AsyncDbConnection};
use crate::sql::db_connection_pool::sqlitepool::SqliteConnectionPoolFactory;
use crate::sql::db_connection_pool::DbInstanceKey;
use crate::sql::db_connection_pool::{
self,
dbconnection::{sqliteconn::SqliteConnection, DbConnection},
sqlitepool::SqliteConnectionPool,
DbConnectionPool, Mode,
};
use crate::sql::sql_provider_datafusion;
use crate::util::schema::SchemaValidator;
use crate::UnsupportedTypeAction;
use arrow::array::{Int64Array, StringArray};
use arrow::{array::RecordBatch, datatypes::SchemaRef};
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::{
catalog::TableProviderFactory,
common::Constraints,
datasource::TableProvider,
error::{DataFusionError, Result as DataFusionResult},
logical_expr::CreateExternalTable,
sql::TableReference,
};
use futures::TryStreamExt;
use rusqlite::{ToSql, Transaction};
use snafu::prelude::*;
use sql_table::SQLiteTable;
use std::collections::HashSet;
use std::time::Duration;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
use tokio_rusqlite::Connection;
use crate::util::{
self,
column_reference::{self, ColumnReference},
constraints::{self, get_primary_keys_from_constraints},
indexes::IndexType,
on_conflict::{self, OnConflict},
};
use self::write::SqliteTableWriter;
#[cfg(feature = "sqlite-federation")]
pub mod federation;
#[cfg(feature = "sqlite-federation")]
pub mod sqlite_interval;
pub mod sql_table;
pub mod write;
#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display("DbConnectionError: {source}"))]
DbConnectionError {
source: db_connection_pool::dbconnection::GenericError,
},
#[snafu(display("DbConnectionPoolError: {source}"))]
DbConnectionPoolError { source: db_connection_pool::Error },
#[snafu(display("Unable to downcast DbConnection to SqliteConnection"))]
UnableToDowncastDbConnection {},
#[snafu(display("Unable to construct SQLTable instance: {source}"))]
UnableToConstuctSqlTableProvider {
source: sql_provider_datafusion::Error,
},
#[snafu(display("Unable to create table in Sqlite: {source}"))]
UnableToCreateTable {
source: tokio_rusqlite::Error<rusqlite::Error>,
},
#[snafu(display("Unable to insert data into the Sqlite table: {source}"))]
UnableToInsertIntoTable { source: rusqlite::Error },
#[snafu(display("Unable to insert data into the Sqlite table: {source}"))]
UnableToInsertIntoTableAsync {
source: tokio_rusqlite::Error<rusqlite::Error>,
},
#[snafu(display("Unable to insert data into the Sqlite table. The disk is full."))]
DiskFull {},
#[snafu(display("Unable to deleta all table data in Sqlite: {source}"))]
UnableToDeleteAllTableData { source: rusqlite::Error },
#[snafu(display("There is a dangling reference to the Sqlite struct in TableProviderFactory.create. This is a bug."))]
DanglingReferenceToSqlite,
#[snafu(display("Constraint Violation: {source}"))]
ConstraintViolation { source: constraints::Error },
#[snafu(display("Error parsing column reference: {source}"))]
UnableToParseColumnReference { source: column_reference::Error },
#[snafu(display("Error parsing on_conflict: {source}"))]
UnableToParseOnConflict { source: on_conflict::Error },
#[snafu(display("Unable to infer schema: {source}"))]
UnableToInferSchema { source: dbconnection::Error },
#[snafu(display("Invalid SQLite busy_timeout value"))]
InvalidBusyTimeoutValue { value: String },
#[snafu(display(
"Unable to parse SQLite busy_timeout parameter, ensure it is a valid duration"
))]
UnableToParseBusyTimeoutParameter { source: fundu::ParseError },
#[snafu(display(
"Failed to create '{table_name}': creating a table with a schema is not supported"
))]
TableWithSchemaCreationNotSupported { table_name: String },
}
type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug)]
pub struct SqliteTableProviderFactory {
instances: Arc<Mutex<HashMap<DbInstanceKey, SqliteConnectionPool>>>,
batch_insert_use_prepared_statements: bool,
}
const SQLITE_DB_PATH_PARAM: &str = "file";
const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory";
const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases";
const SQLITE_BUSY_TIMEOUT_PARAM: &str = "busy_timeout";
impl SqliteTableProviderFactory {
#[must_use]
pub fn new() -> Self {
Self {
instances: Arc::new(Mutex::new(HashMap::new())),
batch_insert_use_prepared_statements: false,
}
}
#[must_use]
pub fn with_batch_insert_use_prepared_statements(mut self, use_prepared: bool) -> Self {
self.batch_insert_use_prepared_statements = use_prepared;
self
}
#[must_use]
pub fn attach_databases(&self, options: &HashMap<String, String>) -> Option<Vec<Arc<str>>> {
options.get(SQLITE_ATTACH_DATABASES_PARAM).map(|databases| {
databases
.split(';')
.map(Arc::from)
.collect::<Vec<Arc<str>>>()
})
}
pub fn sqlite_file_path(
&self,
name: &str,
options: &HashMap<String, String>,
) -> Result<String, Error> {
let options = util::remove_prefix_from_hashmap_keys(options.clone(), "sqlite_");
let db_base_folder = options
.get(SQLITE_DB_BASE_FOLDER_PARAM)
.cloned()
.unwrap_or(".".to_string()); let default_filepath = &format!("{db_base_folder}/{name}_sqlite.db");
let filepath = options
.get(SQLITE_DB_PATH_PARAM)
.unwrap_or(default_filepath);
Ok(filepath.to_string())
}
pub fn sqlite_busy_timeout(&self, options: &HashMap<String, String>) -> Result<Duration> {
let busy_timeout = options.get(SQLITE_BUSY_TIMEOUT_PARAM).cloned();
match busy_timeout {
Some(busy_timeout) => {
let duration = fundu::parse_duration(&busy_timeout)
.context(UnableToParseBusyTimeoutParameterSnafu)?;
Ok(duration)
}
None => Ok(Duration::from_millis(5000)),
}
}
pub async fn get_or_init_instance(
&self,
db_path: impl Into<Arc<str>>,
mode: Mode,
busy_timeout: Duration,
) -> Result<SqliteConnectionPool> {
let db_path = db_path.into();
let key = match mode {
Mode::Memory => DbInstanceKey::memory(),
Mode::File => DbInstanceKey::file(Arc::clone(&db_path)),
};
let mut instances = self.instances.lock().await;
if let Some(instance) = instances.get(&key) {
return instance.try_clone().await.context(DbConnectionPoolSnafu);
}
let pool = SqliteConnectionPoolFactory::new(&db_path, mode, busy_timeout)
.build()
.await
.context(DbConnectionPoolSnafu)?;
instances.insert(key, pool.try_clone().await.context(DbConnectionPoolSnafu)?);
Ok(pool)
}
}
impl Default for SqliteTableProviderFactory {
fn default() -> Self {
Self::new()
}
}
pub type DynSqliteConnectionPool =
dyn DbConnectionPool<Connection, &'static (dyn ToSql + Sync)> + Send + Sync;
#[async_trait]
impl TableProviderFactory for SqliteTableProviderFactory {
#[allow(clippy::too_many_lines)]
async fn create(
&self,
_state: &dyn Session,
cmd: &CreateExternalTable,
) -> DataFusionResult<Arc<dyn TableProvider>> {
if cmd.name.schema().is_some() {
TableWithSchemaCreationNotSupportedSnafu {
table_name: cmd.name.to_string(),
}
.fail()
.map_err(to_datafusion_error)?;
}
let name = cmd.name.clone();
let mut options = cmd.options.clone();
let mode = options.remove("mode").unwrap_or_default();
let mode: Mode = mode.as_str().into();
let indexes_option_str = options.remove("indexes");
let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
None => HashMap::new(),
};
let unparsed_indexes = unparsed_indexes
.into_iter()
.map(|(key, value)| {
let columns = ColumnReference::try_from(key.as_str())
.context(UnableToParseColumnReferenceSnafu)
.map_err(to_datafusion_error);
(columns, value)
})
.collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();
let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
for (columns, index_type) in unparsed_indexes {
let columns = columns?;
indexes.push((columns, index_type));
}
let mut on_conflict: Option<OnConflict> = None;
if let Some(on_conflict_str) = options.remove("on_conflict") {
on_conflict = Some(
OnConflict::try_from(on_conflict_str.as_str())
.context(UnableToParseOnConflictSnafu)
.map_err(to_datafusion_error)?,
);
}
let busy_timeout = self
.sqlite_busy_timeout(&cmd.options)
.map_err(to_datafusion_error)?;
let db_path: Arc<str> = self
.sqlite_file_path(name.table(), &cmd.options)
.map_err(to_datafusion_error)?
.into();
let pool: Arc<SqliteConnectionPool> = Arc::new(
self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout)
.await
.map_err(to_datafusion_error)?,
);
let read_pool = if mode == Mode::Memory {
Arc::clone(&pool)
} else {
Arc::new(
self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout)
.await
.map_err(to_datafusion_error)?,
)
};
let schema: SchemaRef = Arc::new(cmd.schema.as_ref().as_arrow().clone());
let schema: SchemaRef =
SqliteConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
.map_err(|e| DataFusionError::External(e.into()))?;
let sqlite = Arc::new(
Sqlite::new(
name.clone(),
Arc::clone(&schema),
Arc::clone(&pool),
cmd.constraints.clone(),
)
.with_batch_insert_use_prepared_statements(self.batch_insert_use_prepared_statements),
);
let mut db_conn = sqlite.connect().await.map_err(to_datafusion_error)?;
let sqlite_conn = Sqlite::sqlite_conn(&mut db_conn).map_err(to_datafusion_error)?;
let primary_keys = get_primary_keys_from_constraints(&cmd.constraints, &schema);
let table_exists = sqlite.table_exists(sqlite_conn).await;
if !table_exists {
let sqlite_in_conn = Arc::clone(&sqlite);
sqlite_conn
.conn
.call(move |conn| {
let transaction = conn.transaction()?;
sqlite_in_conn.create_table(&transaction, primary_keys)?;
for index in indexes {
sqlite_in_conn.create_index(
&transaction,
index.0.iter().collect(),
index.1 == IndexType::Unique,
)?;
}
transaction.commit()?;
Ok(())
})
.await
.context(UnableToCreateTableSnafu)
.map_err(to_datafusion_error)?;
} else {
let mut table_definition_matches = true;
table_definition_matches &= sqlite.verify_indexes_match(sqlite_conn, &indexes).await?;
table_definition_matches &= sqlite
.verify_primary_keys_match(sqlite_conn, &primary_keys)
.await?;
if !table_definition_matches {
tracing::warn!(
"The local table definition at '{db_path}' for '{name}' does not match the expected configuration. To fix this, drop the existing local copy. A new table with the correct schema will be automatically created upon first access.",
name = name
);
}
}
let dyn_pool: Arc<DynSqliteConnectionPool> = read_pool;
let read_provider = Arc::new(SQLiteTable::new_with_schema(
&dyn_pool,
Arc::clone(&schema),
name,
));
let sqlite = Arc::into_inner(sqlite)
.context(DanglingReferenceToSqliteSnafu)
.map_err(to_datafusion_error)?;
#[cfg(feature = "sqlite-federation")]
let read_provider: Arc<dyn TableProvider> =
Arc::new(read_provider.create_federated_table_provider()?);
Ok(SqliteTableWriter::create(
read_provider,
sqlite,
on_conflict,
))
}
}
pub struct SqliteTableFactory {
pool: Arc<SqliteConnectionPool>,
batch_insert_use_prepared_statements: bool,
}
impl SqliteTableFactory {
#[must_use]
pub fn new(pool: Arc<SqliteConnectionPool>) -> Self {
Self {
pool,
batch_insert_use_prepared_statements: false,
}
}
#[must_use]
pub fn with_batch_insert_use_prepared_statements(mut self, use_prepared: bool) -> Self {
self.batch_insert_use_prepared_statements = use_prepared;
self
}
pub async fn table_provider(
&self,
table_reference: TableReference,
) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
let pool = Arc::clone(&self.pool);
let conn = pool.connect().await.context(DbConnectionSnafu)?;
let schema = get_schema(conn, &table_reference)
.await
.context(UnableToInferSchemaSnafu)?;
let dyn_pool: Arc<DynSqliteConnectionPool> = pool;
let read_provider = Arc::new(SQLiteTable::new_with_schema(
&dyn_pool,
Arc::clone(&schema),
table_reference,
));
Ok(read_provider)
}
}
fn to_datafusion_error(error: Error) -> DataFusionError {
DataFusionError::External(Box::new(error))
}
#[derive(Clone)]
pub struct Sqlite {
table: TableReference,
schema: SchemaRef,
pool: Arc<SqliteConnectionPool>,
constraints: Constraints,
batch_insert_use_prepared_statements: bool,
}
impl std::fmt::Debug for Sqlite {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Sqlite")
.field("table_name", &self.table)
.field("schema", &self.schema)
.field("constraints", &self.constraints)
.finish()
}
}
impl Sqlite {
#[must_use]
pub fn new(
table: TableReference,
schema: SchemaRef,
pool: Arc<SqliteConnectionPool>,
constraints: Constraints,
) -> Self {
Self {
table,
schema,
pool,
constraints,
batch_insert_use_prepared_statements: false,
}
}
#[must_use]
pub fn with_batch_insert_use_prepared_statements(mut self, use_prepared: bool) -> Self {
self.batch_insert_use_prepared_statements = use_prepared;
self
}
#[must_use]
pub fn table_name(&self) -> &str {
self.table.table()
}
#[must_use]
pub fn constraints(&self) -> &Constraints {
&self.constraints
}
pub async fn connect(
&self,
) -> Result<Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>>> {
self.pool.connect().await.context(DbConnectionSnafu)
}
pub fn sqlite_conn<'a>(
db_connection: &'a mut Box<dyn DbConnection<Connection, &'static (dyn ToSql + Sync)>>,
) -> Result<&'a mut SqliteConnection> {
db_connection
.as_any_mut()
.downcast_mut::<SqliteConnection>()
.ok_or_else(|| UnableToDowncastDbConnectionSnafu {}.build())
}
async fn table_exists(&self, sqlite_conn: &mut SqliteConnection) -> bool {
let sql = format!(
"SELECT EXISTS (
SELECT 1
FROM sqlite_master
WHERE type='table'
AND name = '{name}'
)",
name = self.table
);
tracing::trace!("{sql}");
sqlite_conn
.conn
.call(move |conn| {
let mut stmt = conn.prepare(&sql)?;
let exists = stmt.query_row([], |row| row.get(0))?;
Ok::<bool, rusqlite::Error>(exists)
})
.await
.unwrap_or(false)
}
#[allow(dead_code)]
#[deprecated(note = "Use insert_batch_prepared instead for better performance")]
fn insert_batch(
&self,
transaction: &Transaction<'_>,
batch: RecordBatch,
on_conflict: Option<&OnConflict>,
) -> rusqlite::Result<()> {
let batches = vec![batch];
let insert_table_builder = InsertBuilder::new(&self.table, &batches);
let sea_query_on_conflict =
on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));
let sql = insert_table_builder
.build_sqlite(sea_query_on_conflict)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(e.into()))?;
transaction.execute(&sql, [])?;
Ok(())
}
#[allow(clippy::too_many_lines)]
fn insert_batch_prepared(
&self,
transaction: &Transaction<'_>,
batch: RecordBatch,
on_conflict: Option<&OnConflict>,
) -> rusqlite::Result<()> {
use arrow::array::*;
use arrow::datatypes::DataType;
if batch.num_rows() == 0 {
return Ok(());
}
let schema = batch.schema();
let column_names: Vec<String> = schema
.fields()
.iter()
.map(|f| format!("\"{}\"", f.name()))
.collect();
let placeholders: Vec<String> = (0..schema.fields().len())
.map(|_| "?")
.map(String::from)
.collect();
let mut sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
self.table.to_quoted_string(),
column_names.join(", "),
placeholders.join(", ")
);
if let Some(oc) = on_conflict {
use sea_query::{Alias, Query, SqliteQueryBuilder};
let sea_query_on_conflict = oc.build_sea_query_on_conflict(&self.schema);
let table_ref = table_reference_to_sea_table_ref(&self.table);
let mut dummy_insert = Query::insert();
dummy_insert.into_table(table_ref);
dummy_insert.columns(vec![Alias::new("dummy")]);
dummy_insert.on_conflict(sea_query_on_conflict);
let full_sql = dummy_insert.to_string(SqliteQueryBuilder);
if let Some(idx) = full_sql.find("ON CONFLICT") {
sql.push(' ');
sql.push_str(&full_sql[idx..]);
}
}
let mut stmt = transaction.prepare_cached(&sql)?;
for row_idx in 0..batch.num_rows() {
let mut params: Vec<Box<dyn ToSql>> = Vec::with_capacity(batch.num_columns());
for col_idx in 0..batch.num_columns() {
let column = batch.column(col_idx);
let data_type = column.data_type();
match data_type {
DataType::Int8 => {
let array = column.as_any().downcast_ref::<Int8Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::Int16 => {
let array = column.as_any().downcast_ref::<Int16Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::Int32 => {
let array = column.as_any().downcast_ref::<Int32Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::Int64 => {
let array = column.as_any().downcast_ref::<Int64Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::UInt8 => {
let array = column.as_any().downcast_ref::<UInt8Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::UInt16 => {
let array = column.as_any().downcast_ref::<UInt16Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::UInt32 => {
let array = column.as_any().downcast_ref::<UInt32Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx) as i64));
}
}
DataType::UInt64 => {
let array = column.as_any().downcast_ref::<UInt64Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx) as i64));
}
}
DataType::Float32 => {
let array = column.as_any().downcast_ref::<Float32Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::Float64 => {
let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::Utf8 => {
let array = column.as_any().downcast_ref::<StringArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_string()));
}
}
DataType::LargeUtf8 => {
let array = column.as_any().downcast_ref::<LargeStringArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_string()));
}
}
DataType::Boolean => {
let array = column.as_any().downcast_ref::<BooleanArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx)));
}
}
DataType::Binary => {
let array = column.as_any().downcast_ref::<BinaryArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_vec()));
}
}
DataType::LargeBinary => {
let array = column.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_vec()));
}
}
DataType::Date32 => {
let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
let days = array.value(row_idx);
let timestamp = i64::from(days) * 86_400;
params.push(Box::new(timestamp));
}
}
DataType::Date64 => {
let array = column.as_any().downcast_ref::<Date64Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
let millis = array.value(row_idx);
let timestamp = millis / 1000;
params.push(Box::new(timestamp));
}
}
DataType::Timestamp(unit, _) => {
if column.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
match unit {
arrow::datatypes::TimeUnit::Second => {
let array = column
.as_any()
.downcast_ref::<TimestampSecondArray>()
.unwrap();
params.push(Box::new(array.value(row_idx)));
}
arrow::datatypes::TimeUnit::Millisecond => {
let array = column
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap();
let value = array.value(row_idx) / 1000; params.push(Box::new(value));
}
arrow::datatypes::TimeUnit::Microsecond => {
let array = column
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.unwrap();
let value = array.value(row_idx) / 1_000_000; params.push(Box::new(value));
}
arrow::datatypes::TimeUnit::Nanosecond => {
let array = column
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.unwrap();
let value = array.value(row_idx) / 1_000_000_000; params.push(Box::new(value));
}
}
}
}
DataType::Time32(unit) => {
if column.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
let value = match unit {
arrow::datatypes::TimeUnit::Second => {
let array = column
.as_any()
.downcast_ref::<Time32SecondArray>()
.unwrap();
array.value(row_idx)
}
arrow::datatypes::TimeUnit::Millisecond => {
let array = column
.as_any()
.downcast_ref::<Time32MillisecondArray>()
.unwrap();
array.value(row_idx)
}
_ => 0,
};
params.push(Box::new(value));
}
}
DataType::Time64(unit) => {
let value = if column.is_null(row_idx) {
None
} else {
match unit {
arrow::datatypes::TimeUnit::Microsecond => {
let array = column
.as_any()
.downcast_ref::<Time64MicrosecondArray>()
.unwrap();
Some(array.value(row_idx))
}
arrow::datatypes::TimeUnit::Nanosecond => {
let array = column
.as_any()
.downcast_ref::<Time64NanosecondArray>()
.unwrap();
Some(array.value(row_idx))
}
_ => None,
}
};
if let Some(v) = value {
params.push(Box::new(v));
} else {
params.push(Box::new(rusqlite::types::Null));
}
}
DataType::Duration(unit) => {
if column.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
match unit {
arrow::datatypes::TimeUnit::Second => {
let array = column
.as_any()
.downcast_ref::<DurationSecondArray>()
.unwrap();
params.push(Box::new(array.value(row_idx)));
}
arrow::datatypes::TimeUnit::Millisecond => {
let array = column
.as_any()
.downcast_ref::<DurationMillisecondArray>()
.unwrap();
params.push(Box::new(array.value(row_idx)));
}
arrow::datatypes::TimeUnit::Microsecond => {
let array = column
.as_any()
.downcast_ref::<DurationMicrosecondArray>()
.unwrap();
params.push(Box::new(array.value(row_idx)));
}
arrow::datatypes::TimeUnit::Nanosecond => {
let array = column
.as_any()
.downcast_ref::<DurationNanosecondArray>()
.unwrap();
params.push(Box::new(array.value(row_idx)));
}
}
}
}
DataType::Interval(_) => {
use arrow::util::display::{ArrayFormatter, FormatOptions};
let formatter =
ArrayFormatter::try_new(column.as_ref(), &FormatOptions::default())
.map_err(|e| {
rusqlite::Error::ToSqlConversionFailure(Box::new(e))
})?;
let value_str = formatter.value(row_idx).to_string();
params.push(Box::new(value_str));
}
DataType::BinaryView => {
let array = column.as_any().downcast_ref::<BinaryViewArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_vec()));
}
}
DataType::Utf8View => {
let array = column.as_any().downcast_ref::<StringViewArray>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_string()));
}
}
DataType::FixedSizeBinary(_) => {
let array = column
.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_vec()));
}
}
DataType::Float16 => {
let array = column.as_any().downcast_ref::<Float16Array>().unwrap();
if array.is_null(row_idx) {
params.push(Box::new(rusqlite::types::Null));
} else {
params.push(Box::new(array.value(row_idx).to_f32()));
}
}
DataType::Null => {
params.push(Box::new(rusqlite::types::Null));
}
DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::List(_)
| DataType::LargeList(_)
| DataType::ListView(_)
| DataType::LargeListView(_)
| DataType::FixedSizeList(_, _)
| DataType::Struct(_)
| DataType::Map(_, _)
| DataType::Union(_, _)
| DataType::Dictionary(_, _)
| DataType::RunEndEncoded(_, _) => {
use arrow::util::display::{ArrayFormatter, FormatOptions};
let formatter =
ArrayFormatter::try_new(column.as_ref(), &FormatOptions::default())
.map_err(|e| {
rusqlite::Error::ToSqlConversionFailure(Box::new(e))
})?;
let value_str = formatter.value(row_idx).to_string();
params.push(Box::new(value_str));
}
}
}
let params_refs: Vec<&dyn ToSql> = params.iter().map(|p| p.as_ref()).collect();
stmt.execute(params_refs.as_slice())?;
}
Ok(())
}
fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> rusqlite::Result<()> {
transaction.execute(
format!("DELETE FROM {}", self.table.to_quoted_string()).as_str(),
[],
)?;
Ok(())
}
fn create_table(
&self,
transaction: &Transaction<'_>,
primary_keys: Vec<String>,
) -> rusqlite::Result<()> {
let create_table_statement =
CreateTableBuilder::new(Arc::clone(&self.schema), self.table.table())
.primary_keys(primary_keys);
let sql = create_table_statement.build_sqlite();
transaction.execute(&sql, [])?;
Ok(())
}
fn create_index(
&self,
transaction: &Transaction<'_>,
columns: Vec<&str>,
unique: bool,
) -> rusqlite::Result<()> {
let mut index_builder = IndexBuilder::new(self.table.table(), columns);
if unique {
index_builder = index_builder.unique();
}
let sql = index_builder.build_sqlite();
transaction.execute(&sql, [])?;
Ok(())
}
async fn get_indexes(
&self,
sqlite_conn: &mut SqliteConnection,
) -> DataFusionResult<HashSet<String>> {
let query_result = sqlite_conn
.query_arrow(
format!("PRAGMA index_list({name})", name = self.table).as_str(),
&[],
None,
)
.await?;
let mut indexes = HashSet::new();
query_result
.try_collect::<Vec<RecordBatch>>()
.await
.into_iter()
.flatten()
.for_each(|batch| {
if let Some(name_array) = batch
.column_by_name("name")
.and_then(|col| col.as_any().downcast_ref::<StringArray>())
{
for index_name in name_array.iter().flatten() {
if !index_name.starts_with("sqlite_autoindex_") {
indexes.insert(index_name.to_string());
}
}
}
});
Ok(indexes)
}
async fn get_primary_keys(
&self,
sqlite_conn: &mut SqliteConnection,
) -> DataFusionResult<HashSet<String>> {
let query_result = sqlite_conn
.query_arrow(
format!("PRAGMA table_info({name})", name = self.table).as_str(),
&[],
None,
)
.await?;
let mut primary_keys = HashSet::new();
query_result
.try_collect::<Vec<RecordBatch>>()
.await
.into_iter()
.flatten()
.for_each(|batch| {
if let (Some(name_array), Some(pk_array)) = (
batch
.column_by_name("name")
.and_then(|col| col.as_any().downcast_ref::<StringArray>()),
batch
.column_by_name("pk")
.and_then(|col| col.as_any().downcast_ref::<Int64Array>()),
) {
for (name, pk) in name_array.iter().flatten().zip(pk_array.iter().flatten()) {
if pk > 0 {
primary_keys.insert(name.to_string());
}
}
}
});
Ok(primary_keys)
}
async fn verify_indexes_match(
&self,
sqlite_conn: &mut SqliteConnection,
indexes: &[(ColumnReference, IndexType)],
) -> DataFusionResult<bool> {
let expected_indexes_str_map: HashSet<String> = indexes
.iter()
.map(|(col, _)| {
IndexBuilder::new(self.table.table(), col.iter().collect()).index_name()
})
.collect();
let actual_indexes_str_map = self.get_indexes(sqlite_conn).await?;
let missing_in_actual = expected_indexes_str_map
.difference(&actual_indexes_str_map)
.collect::<Vec<_>>();
let extra_in_actual = actual_indexes_str_map
.difference(&expected_indexes_str_map)
.collect::<Vec<_>>();
if !missing_in_actual.is_empty() {
tracing::warn!(
"Missing indexes detected for the table '{name}': {:?}.",
missing_in_actual,
name = self.table
);
}
if !extra_in_actual.is_empty() {
tracing::warn!(
"The table '{name}' contains unexpected indexes not presented in the configuration: {:?}.",
extra_in_actual,
name = self.table
);
}
Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
}
async fn verify_primary_keys_match(
&self,
sqlite_conn: &mut SqliteConnection,
primary_keys: &[String],
) -> DataFusionResult<bool> {
let expected_pk_keys_str_map: HashSet<String> = primary_keys.iter().cloned().collect();
let actual_pk_keys_str_map = self.get_primary_keys(sqlite_conn).await?;
let missing_in_actual = expected_pk_keys_str_map
.difference(&actual_pk_keys_str_map)
.collect::<Vec<_>>();
let extra_in_actual = actual_pk_keys_str_map
.difference(&expected_pk_keys_str_map)
.collect::<Vec<_>>();
if !missing_in_actual.is_empty() {
tracing::warn!(
"Missing primary keys detected for the table '{name}': {:?}.",
missing_in_actual,
name = self.table
);
}
if !extra_in_actual.is_empty() {
tracing::warn!(
"The table '{name}' contains unexpected primary keys not presented in the configuration: {:?}.",
extra_in_actual,
name = self.table
);
}
Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty())
}
}
#[cfg(test)]
pub(crate) mod tests {
use arrow::datatypes::{DataType, Schema};
use datafusion::{
common::{Constraint, ToDFSchema},
prelude::SessionContext,
};
use super::*;
#[tokio::test]
async fn test_sqlite_table_creation_with_indexes() {
let schema = Arc::new(Schema::new(vec![
arrow::datatypes::Field::new("first_name", DataType::Utf8, false),
arrow::datatypes::Field::new("last_name", DataType::Utf8, false),
arrow::datatypes::Field::new("id", DataType::Int64, false),
]));
let options: HashMap<String, String> = [(
"indexes".to_string(),
"id:enabled;(first_name, last_name):unique".to_string(),
)]
.iter()
.cloned()
.collect();
let expected_indexes: HashSet<String> = [
"i_test_table_id".to_string(),
"i_test_table_first_name_last_name".to_string(),
]
.iter()
.cloned()
.collect();
let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema");
let primary_keys_constraints = {
let schema = Arc::clone(&schema);
let indices: Vec<usize> = ["id"]
.iter()
.filter_map(|&col_name| schema.column_with_name(col_name).map(|(index, _)| index))
.collect();
Constraints::new_unverified(vec![Constraint::PrimaryKey(indices)])
};
let external_table = CreateExternalTable {
schema: df_schema,
name: TableReference::bare("test_table"),
location: String::new(),
file_type: String::new(),
table_partition_cols: vec![],
if_not_exists: true,
definition: None,
order_exprs: vec![],
unbounded: false,
options,
constraints: primary_keys_constraints,
column_defaults: HashMap::default(),
temporary: false,
or_replace: false,
};
let ctx = SessionContext::new();
let table = SqliteTableProviderFactory::default()
.create(&ctx.state(), &external_table)
.await
.expect("table should be created");
let sqlite = table
.as_any()
.downcast_ref::<SqliteTableWriter>()
.expect("downcast to SqliteTableWriter")
.sqlite();
let mut db_conn = sqlite.connect().await.expect("should connect to db");
let sqlite_conn =
Sqlite::sqlite_conn(&mut db_conn).expect("should create sqlite connection");
let retrieved_indexes = sqlite
.get_indexes(sqlite_conn)
.await
.expect("should get indexes");
assert_eq!(retrieved_indexes, expected_indexes);
let retrieved_primary_keys = sqlite
.get_primary_keys(sqlite_conn)
.await
.expect("should get primary keys");
assert_eq!(
retrieved_primary_keys,
vec!["id".to_string()]
.into_iter()
.collect::<HashSet<String>>()
);
}
}