pub mod fts_exec;
pub mod fts_table_function;
pub mod knn_exec;
pub mod knn_table_function;
pub mod vec_to_binary;
pub use fts_table_function::register_sqlite_fts_udtf;
pub use knn_table_function::{SqliteEntry, register_sqlite_knn_udtf};
pub use vec_to_binary::register_vec_to_binary_udf;
use anyhow::{Context, Result};
use arrow::array::{
ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array, Float64Array,
Int64Array, ListArray, RecordBatch, RecordBatchOptions, StringArray, UInt64Array,
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::common::{Constraints, ScalarValue};
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, dml::InsertOp};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
execution_plan::{Boundedness, EmissionType},
};
use datafusion::prelude::SessionContext;
use datafusion::sql::TableReference;
use datafusion::sql::unparser::Unparser;
use datafusion::sql::unparser::dialect::SqliteDialect;
use futures::{StreamExt, stream};
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio_rusqlite::Connection;
use crate::sources::DataSourceType;
use crate::sources::hierarchy::{
HierarchyLevel, SourceLabel, build_catalog, parse_allowed_schemas, retry_with_timeout,
};
use crate::sources::providers::{DatasetEntry, DatasetRegistry};
const DEFAULT_READ_POOL_SIZE: usize = 4;
const DEFAULT_BUSY_TIMEOUT_MS: u64 = 5000;
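/// Convenience constructor: a read-only provider for one table with the
/// default busy timeout, default read-pool size, and no extensions loaded.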
pub async fn create_sqlite_table_provider(
db_path: &str,
table_name: &str,
) -> Result<Arc<dyn TableProvider>> {
let table_reference = TableReference::bare(table_name);
let provider = SqliteTableProvider::new(
db_path,
table_reference,
        DEFAULT_BUSY_TIMEOUT_MS,
DEFAULT_READ_POOL_SIZE,
false,
&[],
)
.await?;
Ok(Arc::new(provider))
}
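/// Registers a SQLite data source on `session_ctx`, either as a single table
/// (`HierarchyLevel::Table`, which requires the `table` option) or as a full
/// catalog of all user tables and views (`HierarchyLevel::Catalog`).
///
/// Recognized options: `table`, `busy_timeout_ms`, `read_pool_size`,
/// `extensions` (comma-separated absolute paths), `extensions_env`
/// (comma-separated env var names, each holding an extension path), and
/// `allowed_schemas` (catalog level only; SQLite tables always live in `main`).
///
/// A minimal usage sketch; the source name, database path, and table name are
/// illustrative:
///
/// ```ignore
/// let mut options = HashMap::new();
/// options.insert("table".to_string(), "events".to_string());
/// register_sqlite_tables(
///     &mut ctx,
///     "events",          // registration name
///     "/data/app.db",    // illustrative path
///     Some(&options),
///     false,             // read-only
///     None,              // no dataset registry
///     HierarchyLevel::Table,
/// )
/// .await?;
/// ```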
pub async fn register_sqlite_tables(
session_ctx: &mut SessionContext,
name: &str,
db_path: &str,
options: Option<&HashMap<String, String>>,
read_write: bool,
registry: Option<&DatasetRegistry>,
hierarchy_level: HierarchyLevel,
) -> Result<()> {
let mode_str = if read_write {
"read-write"
} else {
"read-only"
};
match hierarchy_level {
HierarchyLevel::Catalog => {
register_sqlite_catalog(
session_ctx,
name,
db_path,
options,
read_write,
mode_str,
registry,
)
.await
}
HierarchyLevel::Table => {
register_single_sqlite_table(
session_ctx,
name,
db_path,
options,
read_write,
mode_str,
registry,
)
.await
}
}
}
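/// Registers a single SQLite table as `name`. When a [`DatasetRegistry`] is
/// supplied, the table's connection and column types are also recorded as a
/// [`SqliteEntry`].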
async fn register_single_sqlite_table(
session_ctx: &mut SessionContext,
name: &str,
db_path: &str,
options: Option<&HashMap<String, String>>,
read_write: bool,
mode_str: &str,
registry: Option<&DatasetRegistry>,
) -> Result<()> {
tracing::info!(
"Registering SQLite table: {} with path: {} ({})",
name,
db_path,
mode_str
);
tracing::debug!("Options: {:?}", options);
let table_name = options
.and_then(|opts| opts.get("table"))
.ok_or_else(|| anyhow::anyhow!("SQLite data source '{}' requires 'table' option", name))?;
let busy_timeout_ms: u64 = options
.and_then(|opts| opts.get("busy_timeout_ms"))
.and_then(|v| v.parse().ok())
        .unwrap_or(DEFAULT_BUSY_TIMEOUT_MS);
let read_pool_size: usize = options
.and_then(|opts| opts.get("read_pool_size"))
.and_then(|v| v.parse().ok())
.unwrap_or(DEFAULT_READ_POOL_SIZE);
let mut extensions: Vec<String> = options
.and_then(|opts| opts.get("extensions"))
.map(|v| v.split(',').map(|s| s.trim().to_string()).collect())
.unwrap_or_default();
if let Some(env_key) = options.and_then(|opts| opts.get("extensions_env")) {
for key in env_key.split(',') {
let key = key.trim();
if let Ok(val) = std::env::var(key) {
extensions.push(val);
} else {
tracing::warn!("SQLite extension env var '{}' not set, skipping", key);
}
}
}
tracing::debug!(
"Connecting to SQLite table: {} in database '{}' as '{}'",
table_name,
db_path,
name
);
let table_reference = TableReference::bare(table_name.as_str());
let provider = SqliteTableProvider::new(
db_path,
table_reference.clone(),
busy_timeout_ms,
read_pool_size,
read_write,
&extensions,
)
.await
.with_context(|| {
format!(
"Failed to create SQLite table provider for '{}'",
table_name
)
})?;
if let Some(registry) = registry {
let columns: Vec<(String, DataType)> = provider
.schema
.fields()
.iter()
.map(|f| (f.name().clone(), f.data_type().clone()))
.collect();
let entry = SqliteEntry {
conn: Arc::clone(&provider.read_pool[0]),
table_name: table_name.clone(),
columns,
};
let mut reg = registry
.write()
.map_err(|e| anyhow::anyhow!("sqlite registry lock error: {}", e))?;
reg.insert(name.to_string(), DatasetEntry::Sqlite(entry));
tracing::debug!("Registered SQLite table '{}' in dataset registry", name);
}
session_ctx
.register_table(name, Arc::new(provider))
.map_err(|e| {
tracing::error!("Failed to register table with DataFusion: {:?}", e);
e
})
.with_context(|| format!("Failed to register SQLite table '{}' with DataFusion", name))?;
tracing::info!(
"Successfully registered SQLite table '{}' as '{}' ({})",
table_reference,
name,
mode_str
);
Ok(())
}
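/// Introspects `sqlite_master` and registers every user table and view under
/// `<catalog_name>.main.<table>`. All providers share one read pool (and one
/// optional write connection) per database file.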
async fn register_sqlite_catalog(
session_ctx: &mut SessionContext,
catalog_name: &str,
db_path: &str,
options: Option<&HashMap<String, String>>,
read_write: bool,
mode_str: &str,
registry: Option<&DatasetRegistry>,
) -> Result<()> {
tracing::info!(
"Registering SQLite catalog: {} ({})",
catalog_name,
mode_str
);
let busy_timeout_ms: u64 = options
.and_then(|opts| opts.get("busy_timeout_ms"))
.and_then(|v| v.parse().ok())
        .unwrap_or(DEFAULT_BUSY_TIMEOUT_MS);
let read_pool_size: usize = options
.and_then(|opts| opts.get("read_pool_size"))
.and_then(|v| v.parse().ok())
.unwrap_or(DEFAULT_READ_POOL_SIZE);
let mut extensions: Vec<String> = options
.and_then(|opts| opts.get("extensions"))
.map(|v| v.split(',').map(|s| s.trim().to_string()).collect())
.unwrap_or_default();
if let Some(env_key) = options.and_then(|opts| opts.get("extensions_env")) {
for key in env_key.split(',') {
let key = key.trim();
if let Ok(val) = std::env::var(key) {
extensions.push(val);
} else {
tracing::warn!("SQLite extension env var '{}' not set, skipping", key);
}
}
}
let label = SourceLabel::new(
DataSourceType::Sqlite,
HierarchyLevel::Catalog,
catalog_name,
);
let (read_pool, write_conn) = open_sqlite_pool(
db_path,
busy_timeout_ms,
read_pool_size,
read_write,
&extensions,
)
.await
.with_context(|| {
format!(
"Failed to open shared SQLite pool for catalog '{}'",
catalog_name
)
})?;
let intro_conn = Arc::clone(&read_pool[0]);
let mut schema_tables = retry_with_timeout(label, "sqlite_master introspection", || async {
list_sqlite_tables_in_catalog(&intro_conn).await
})
.await?;
let allowed_schemas = parse_allowed_schemas(options);
if let Some(ref allowed) = allowed_schemas {
if !allowed.iter().any(|s| s == "main") {
tracing::warn!(
"SQLite catalog '{}' has allowed_schemas={:?} which excludes 'main'; \
all SQLite tables live in schema 'main' so no tables will be registered",
catalog_name,
allowed
);
}
schema_tables.retain(|(schema, _)| allowed.iter().any(|s| s == schema));
}
if schema_tables.is_empty() {
tracing::warn!(
"No tables found in SQLite catalog for source '{}'",
catalog_name
);
}
let table_count = schema_tables.len();
let shared_read_pool = Arc::new(read_pool);
let shared_write_conn = write_conn;
let catalog_name_owned = catalog_name.to_string();
build_catalog(
session_ctx,
catalog_name,
schema_tables,
|schema, table_name| {
let read_pool_c = Arc::clone(&shared_read_pool);
let write_conn_c = shared_write_conn.clone();
let registry_c = registry.map(Arc::clone);
let catalog_c = catalog_name_owned.clone();
async move {
let provider = SqliteTableProvider::from_shared_pool(
(*read_pool_c).clone(),
write_conn_c,
TableReference::bare(table_name.as_str()),
read_write,
)
.await
.with_context(|| {
format!(
"Failed to create SQLite table provider for '{}.{}'",
schema, table_name
)
})?;
if let Some(registry) = registry_c {
let columns: Vec<(String, DataType)> = provider
.schema
.fields()
.iter()
.map(|f| (f.name().clone(), f.data_type().clone()))
.collect();
let entry = SqliteEntry {
conn: Arc::clone(&provider.read_pool[0]),
table_name: table_name.clone(),
columns,
};
let key = format!("{}.{}.{}", catalog_c, schema, table_name);
let mut reg = registry
.write()
.map_err(|e| anyhow::anyhow!("sqlite registry lock error: {}", e))?;
reg.insert(key, DatasetEntry::Sqlite(entry));
}
Ok(Arc::new(provider) as Arc<dyn TableProvider>)
}
},
)
.await
.with_context(|| format!("Failed to build SQLite catalog '{}'", catalog_name))?;
tracing::info!(
"Registered SQLite catalog '{}' with {} table(s) ({})",
catalog_name,
table_count,
mode_str
);
Ok(())
}
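/// Lists user tables and views in the `main` schema, skipping `sqlite_%`
/// internal tables and entries without DDL (`sql IS NOT NULL`).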
async fn list_sqlite_tables_in_catalog(conn: &Connection) -> Result<Vec<(String, String)>> {
let names: Vec<String> = conn
.call(
move |conn| -> std::result::Result<Vec<String>, tokio_rusqlite::rusqlite::Error> {
let mut stmt = conn.prepare(
"SELECT name FROM sqlite_master \
WHERE type IN ('table', 'view') \
AND name NOT LIKE 'sqlite_%' \
AND sql IS NOT NULL \
ORDER BY name",
)?;
stmt.query_map([], |row| row.get::<_, String>(0))?
.collect::<std::result::Result<Vec<_>, _>>()
},
)
.await
.map_err(|e| anyhow::anyhow!("Failed to list tables in SQLite catalog: {}", e))?;
Ok(names.into_iter().map(|t| ("main".to_string(), t)).collect())
}
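/// Opens `read_pool_size.max(1)` read connections, plus one write connection
/// when `read_write` is set, running [`init_connection`] on each.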
async fn open_sqlite_pool(
db_path: &str,
busy_timeout_ms: u64,
read_pool_size: usize,
read_write: bool,
extensions: &[String],
) -> Result<(Vec<Arc<Connection>>, Option<Arc<Connection>>)> {
let pool_size = read_pool_size.max(1);
let mut read_pool = Vec::with_capacity(pool_size);
for _ in 0..pool_size {
let conn = Connection::open(db_path)
.await
.with_context(|| format!("Failed to open SQLite read connection: {}", db_path))?;
init_connection(&conn, busy_timeout_ms, extensions).await?;
read_pool.push(Arc::new(conn));
}
let write_conn = if read_write {
let conn = Connection::open(db_path)
.await
.with_context(|| format!("Failed to open SQLite write connection: {}", db_path))?;
init_connection(&conn, busy_timeout_ms, extensions).await?;
Some(Arc::new(conn))
} else {
None
};
Ok((read_pool, write_conn))
}
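/// [`TableProvider`] backed by SQLite connections: scans draw round-robin from
/// `read_pool`, while INSERT/UPDATE/DELETE go through the single optional
/// `write_conn`.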
struct SqliteTableProvider {
read_pool: Vec<Arc<Connection>>,
read_pool_idx: AtomicUsize,
write_conn: Option<Arc<Connection>>,
table_reference: TableReference,
schema: SchemaRef,
read_write: bool,
}
impl fmt::Debug for SqliteTableProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SqliteTableProvider")
.field("table_reference", &self.table_reference)
.field("read_pool_size", &self.read_pool.len())
.field("read_write", &self.read_write)
.finish()
}
}
impl SqliteTableProvider {
async fn new(
db_path: &str,
table_reference: TableReference,
busy_timeout_ms: u64,
read_pool_size: usize,
read_write: bool,
extensions: &[String],
) -> Result<Self> {
let (read_pool, write_conn) = open_sqlite_pool(
db_path,
busy_timeout_ms,
read_pool_size,
read_write,
extensions,
)
.await?;
Self::from_shared_pool(read_pool, write_conn, table_reference, read_write).await
}
async fn from_shared_pool(
read_pool: Vec<Arc<Connection>>,
write_conn: Option<Arc<Connection>>,
table_reference: TableReference,
read_write: bool,
) -> Result<Self> {
if read_pool.is_empty() {
anyhow::bail!("SqliteTableProvider::from_shared_pool requires a non-empty read pool");
}
let schema = read_schema_from_pragma(&read_pool[0], table_reference.table()).await?;
if schema.fields().is_empty() {
tracing::warn!(
"PRAGMA table_info returned empty schema for '{}' — table may not exist",
table_reference.table()
);
}
Ok(Self {
read_pool,
read_pool_idx: AtomicUsize::new(0),
write_conn,
table_reference,
schema,
read_write,
})
}
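    /// Returns the next read connection, round-robin over the pool.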
fn next_read_conn(&self) -> Arc<Connection> {
let idx = self.read_pool_idx.fetch_add(1, Ordering::Relaxed) % self.read_pool.len();
Arc::clone(&self.read_pool[idx])
}
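    /// Returns the write connection, or a plan error when the table was
    /// registered read-only.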
fn write_conn(&self) -> DataFusionResult<Arc<Connection>> {
self.write_conn.as_ref().cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"Table '{}' is registered as read-only",
self.table_reference
))
})
}
}
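/// Restricts loadable extensions to absolute paths without `..` components,
/// so only explicitly configured libraries can be loaded.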
fn validate_extension_path(path: &str) -> Result<()> {
use std::path::{Component, Path};
let p = Path::new(path);
if !p.is_absolute() {
anyhow::bail!(
"SQLite extension path must be absolute, got '{}'. \
Use an absolute path to the extension shared library.",
path
);
}
if p.components().any(|c| matches!(c, Component::ParentDir)) {
anyhow::bail!(
"SQLite extension path must not contain '..' components, got '{}'",
path
);
}
Ok(())
}
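/// Per-connection setup: switches to WAL journal mode, applies the busy
/// timeout, and loads any configured extensions (extension loading is enabled
/// only for the duration of the loads).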
async fn init_connection(
conn: &Connection,
busy_timeout_ms: u64,
extensions: &[String],
) -> Result<()> {
let timeout = busy_timeout_ms;
for ext_path in extensions {
validate_extension_path(ext_path)?;
}
let exts = extensions.to_vec();
conn.call(
move |conn| -> std::result::Result<(), tokio_rusqlite::rusqlite::Error> {
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "busy_timeout", timeout)?;
if !exts.is_empty() {
unsafe { conn.load_extension_enable()? };
for ext_path in &exts {
unsafe { conn.load_extension(ext_path, None::<&str>)? };
}
conn.load_extension_disable()?;
}
Ok(())
},
)
.await
.map_err(|e| anyhow::anyhow!("Failed to initialize SQLite connection: {}", e))
}
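/// Derives an Arrow schema from `PRAGMA table_info`. FTS tables report their
/// content columns with an empty declared type, which is mapped to Utf8;
/// everything else goes through [`sqlite_type_to_arrow`].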
pub(crate) async fn read_schema_from_pragma(
conn: &Connection,
table_name: &str,
) -> Result<SchemaRef> {
let tbl = table_name.to_string();
let fields: Vec<Field> = conn
.call(
move |conn| -> std::result::Result<Vec<Field>, tokio_rusqlite::rusqlite::Error> {
let is_fts = is_fts_table(conn, &tbl);
                let mut stmt =
                    conn.prepare(&format!("PRAGMA table_info({})", quote_sqlite_ident(&tbl)))?;
let rows = stmt.query_map([], |row| {
let col_name: String = row.get(1)?;
let col_type: String = row.get(2)?;
let not_null: bool = row.get(3)?;
let is_pk: bool = row.get::<_, i32>(5)? != 0;
Ok((col_name, col_type, not_null, is_pk))
})?;
let mut fields = Vec::new();
for row in rows {
let (col_name, col_type, not_null, is_pk) = row?;
let data_type = if is_fts && col_type.is_empty() && !is_pk {
DataType::Utf8
} else {
sqlite_type_to_arrow(&col_type, is_pk)
};
fields.push(Field::new(col_name, data_type, !not_null));
}
Ok(fields)
},
)
.await
.map_err(|e| anyhow::anyhow!("PRAGMA table_info failed: {}", e))?;
Ok(Arc::new(Schema::new(fields)))
}
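/// Detects FTS3/FTS4/FTS5 tables by looking for the module name in the
/// table's DDL in `sqlite_master`.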
fn is_fts_table(conn: &tokio_rusqlite::rusqlite::Connection, table_name: &str) -> bool {
conn.query_row(
"SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?1",
[table_name],
|row| row.get::<_, Option<String>>(0),
)
.ok()
.flatten()
.map(|sql| {
let upper = sql.to_uppercase();
upper.contains("FTS5") || upper.contains("FTS4") || upper.contains("FTS3")
})
.unwrap_or(false)
}
#[async_trait]
impl TableProvider for SqliteTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn table_type(&self) -> TableType {
TableType::Base
}
fn constraints(&self) -> Option<&Constraints> {
None
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
}
async fn scan(
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
let conn = self.next_read_conn();
Ok(Arc::new(SqliteScanExec::new(
conn,
self.table_reference.clone(),
Arc::clone(&self.schema),
projection.cloned(),
filters.to_vec(),
limit,
)))
}
async fn insert_into(
&self,
_state: &dyn Session,
input: Arc<dyn ExecutionPlan>,
op: InsertOp,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
let conn = self.write_conn()?;
Ok(Arc::new(SqliteInsertExec::new(
conn,
self.table_reference.clone(),
Arc::clone(&self.schema),
input,
op,
)))
}
async fn delete_from(
&self,
_state: &dyn Session,
filters: Vec<Expr>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
let conn = self.write_conn()?;
let table = quote_sqlite_table(&self.table_reference);
let where_clause = build_sqlite_where_clause(&filters)?;
let sql = format!("DELETE FROM {table}{where_clause}");
Ok(Arc::new(SqliteDmlExec::new(conn, sql)))
}
async fn update(
&self,
_state: &dyn Session,
assignments: Vec<(String, Expr)>,
filters: Vec<Expr>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if assignments.is_empty() {
return Err(DataFusionError::Plan(
"UPDATE requires at least one assignment".to_string(),
));
}
let unparser = Unparser::new(&SqliteDialect {});
let set_clause = assignments
.iter()
.map(|(col, expr)| {
let val = unparser
.expr_to_sql(expr)
.map_err(|e| {
DataFusionError::Plan(format!(
"Failed to unparse assignment expression for column '{col}': {e}"
))
})?
.to_string();
Ok(format!("{} = {val}", quote_sqlite_ident(col)))
})
.collect::<DataFusionResult<Vec<_>>>()?
.join(", ");
let conn = self.write_conn()?;
let table = quote_sqlite_table(&self.table_reference);
let where_clause = build_sqlite_where_clause(&filters)?;
let sql = format!("UPDATE {table} SET {set_clause}{where_clause}");
Ok(Arc::new(SqliteDmlExec::new(conn, sql)))
}
}
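/// Leaf [`ExecutionPlan`] that pushes projection, (inexact) filters, and limit
/// down to SQLite as a single `SELECT`, materializing the result as one Arrow
/// batch.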
struct SqliteScanExec {
conn: Arc<Connection>,
table_reference: TableReference,
table_schema: SchemaRef,
output_schema: SchemaRef,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
limit: Option<usize>,
properties: PlanProperties,
}
impl SqliteScanExec {
fn new(
conn: Arc<Connection>,
table_reference: TableReference,
table_schema: SchemaRef,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
limit: Option<usize>,
) -> Self {
let output_schema = if let Some(ref proj) = projection {
Arc::new(table_schema.project(proj).expect("valid projection"))
} else {
Arc::clone(&table_schema)
};
let properties = PlanProperties::new(
EquivalenceProperties::new(Arc::clone(&output_schema)),
Partitioning::UnknownPartitioning(1),
EmissionType::Final,
Boundedness::Bounded,
);
Self {
conn,
table_reference,
table_schema,
output_schema,
projection,
filters,
limit,
properties,
}
}
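    /// Renders the pushed-down projection, filters, and limit as a SQLite
    /// `SELECT`. An empty projection degrades to `SELECT 1` so row counts
    /// (e.g. `COUNT(*)` pushdown) still come back correct.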
fn build_sql(&self) -> DataFusionResult<String> {
let columns: Vec<String> = if let Some(ref proj) = self.projection {
proj.iter()
.map(|&i| quote_sqlite_ident(self.table_schema.field(i).name()))
.collect()
} else {
self.table_schema
.fields()
.iter()
.map(|f| quote_sqlite_ident(f.name()))
.collect()
};
let table = quote_sqlite_table(&self.table_reference);
let where_clause = build_sqlite_where_clause(&self.filters)?;
let limit_clause = self
.limit
.map(|n| format!(" LIMIT {n}"))
.unwrap_or_default();
let projection_clause = if columns.is_empty() {
"1".to_string()
} else {
columns.join(", ")
};
Ok(format!(
"SELECT {projection_clause} FROM {table}{where_clause}{limit_clause}"
))
}
}
impl fmt::Debug for SqliteScanExec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SqliteScanExec(table={})", self.table_reference)
}
}
impl DisplayAs for SqliteScanExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "SqliteScanExec: table={}", self.table_reference)
}
}
impl ExecutionPlan for SqliteScanExec {
fn name(&self) -> &str {
"SqliteScanExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.output_schema)
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let conn = Arc::clone(&self.conn);
let sql = self.build_sql()?;
let output_schema = Arc::clone(&self.output_schema);
let num_cols = output_schema.fields().len();
let field_types: Vec<DataType> = output_schema
.fields()
.iter()
.map(|f| f.data_type().clone())
.collect();
let future = async move {
            let (columns, row_count): (Vec<Vec<tokio_rusqlite::rusqlite::types::Value>>, usize) =
conn.call(
move |conn| -> std::result::Result<_, tokio_rusqlite::rusqlite::Error> {
let mut stmt = conn.prepare(&sql)?;
let mut col_values: Vec<Vec<tokio_rusqlite::rusqlite::types::Value>> =
(0..num_cols).map(|_| Vec::new()).collect();
let mut row_count: usize = 0;
let mut rows = stmt.query([])?;
while let Some(row) = rows.next()? {
for col_idx in 0..num_cols {
let val: tokio_rusqlite::rusqlite::types::Value =
row.get(col_idx)?;
col_values[col_idx].push(val);
}
row_count += 1;
}
Ok((col_values, row_count))
},
)
.await
.map_err(|e| DataFusionError::Execution(format!("SQLite scan error: {e}")))?;
            let arrays: Vec<ArrayRef> = columns
.into_iter()
.zip(field_types.iter())
.map(|(values, data_type)| sqlite_values_to_arrow(&values, data_type))
.collect();
if num_cols == 0 {
let options = RecordBatchOptions::new().with_row_count(Some(row_count));
RecordBatch::try_new_with_options(output_schema, arrays, &options)
.map_err(DataFusionError::from)
} else {
RecordBatch::try_new(output_schema, arrays).map_err(DataFusionError::from)
}
};
Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.output_schema),
stream::once(future),
)))
}
}
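/// Converts one column of SQLite values into an Arrow array of `data_type`.
/// Integers widen to Float64 where the schema asks for it, mismatched values
/// become NULL, and any type without a dedicated arm is rendered as Utf8.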
pub(crate) fn sqlite_values_to_arrow(
values: &[tokio_rusqlite::rusqlite::types::Value],
data_type: &DataType,
) -> ArrayRef {
use tokio_rusqlite::rusqlite::types::Value;
match data_type {
DataType::Int64 => {
let arr: Int64Array = values
.iter()
.map(|v| match v {
Value::Integer(i) => Some(*i),
Value::Null => None,
_ => None,
})
.collect();
Arc::new(arr)
}
DataType::Float64 => {
let arr: Float64Array = values
.iter()
.map(|v| match v {
Value::Real(f) => Some(*f),
Value::Integer(i) => Some(*i as f64),
Value::Null => None,
_ => None,
})
.collect();
Arc::new(arr)
}
DataType::Boolean => {
let arr: BooleanArray = values
.iter()
.map(|v| match v {
Value::Integer(i) => Some(*i != 0),
Value::Null => None,
_ => None,
})
.collect();
Arc::new(arr)
}
DataType::Binary => {
let arr: BinaryArray = values
.iter()
.map(|v| match v {
Value::Blob(b) => Some(b.as_slice()),
Value::Null => None,
_ => None,
})
.collect();
Arc::new(arr)
}
_ => {
let strings: Vec<Option<String>> = values
.iter()
.map(|v| match v {
Value::Text(s) => Some(s.clone()),
Value::Integer(i) => Some(i.to_string()),
Value::Real(f) => Some(f.to_string()),
Value::Null => None,
_ => None,
})
.collect();
let arr: StringArray = strings.iter().map(|v| v.as_deref()).collect();
Arc::new(arr)
}
}
}
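/// [`ExecutionPlan`] for `INSERT INTO`: buffers the input stream, then writes
/// all rows in one transaction (clearing the table first for
/// `InsertOp::Overwrite`) and emits a single `count` row.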
struct SqliteInsertExec {
conn: Arc<Connection>,
table_reference: TableReference,
table_schema: SchemaRef,
input: Arc<dyn ExecutionPlan>,
op: InsertOp,
output_schema: SchemaRef,
properties: PlanProperties,
}
impl SqliteInsertExec {
fn new(
conn: Arc<Connection>,
table_reference: TableReference,
table_schema: SchemaRef,
input: Arc<dyn ExecutionPlan>,
op: InsertOp,
) -> Self {
let output_schema = Arc::new(Schema::new(vec![Field::new(
"count",
DataType::UInt64,
false,
)]));
let properties = PlanProperties::new(
EquivalenceProperties::new(Arc::clone(&output_schema)),
Partitioning::UnknownPartitioning(1),
EmissionType::Final,
Boundedness::Bounded,
);
Self {
conn,
table_reference,
table_schema,
input,
op,
output_schema,
properties,
}
}
}
impl fmt::Debug for SqliteInsertExec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SqliteInsertExec(table={})", self.table_reference)
}
}
impl DisplayAs for SqliteInsertExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "SqliteInsertExec")
}
}
impl ExecutionPlan for SqliteInsertExec {
fn name(&self) -> &str {
"SqliteInsertExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.output_schema)
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn required_input_distribution(&self) -> Vec<Distribution> {
vec![Distribution::SinglePartition]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(DataFusionError::Internal(
"SqliteInsertExec expects exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new(
Arc::clone(&self.conn),
self.table_reference.clone(),
Arc::clone(&self.table_schema),
children.into_iter().next().expect("len == 1 checked above"),
self.op,
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let conn = Arc::clone(&self.conn);
let table_ref = self.table_reference.clone();
let op = self.op;
let output_schema = Arc::clone(&self.output_schema);
let mut input_stream = self.input.execute(partition, context)?;
let input_schema = self.input.schema();
let future = async move {
let mut batches: Vec<RecordBatch> = Vec::new();
while let Some(batch_result) = input_stream.next().await {
let batch = batch_result?;
if batch.num_rows() > 0 {
batches.push(batch);
}
}
let total_rows: u64 = conn
.call(
move |conn| -> std::result::Result<u64, tokio_rusqlite::rusqlite::Error> {
let tx = conn.transaction()?;
if matches!(op, InsertOp::Overwrite) {
let table = quote_sqlite_table(&table_ref);
tx.execute(&format!("DELETE FROM {table}"), [])?;
}
let col_names: Vec<String> = input_schema
.fields()
.iter()
.map(|f| quote_sqlite_ident(f.name()))
.collect();
let table = quote_sqlite_table(&table_ref);
let placeholders: Vec<&str> = vec!["?"; col_names.len()];
let insert_sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
table,
col_names.join(", "),
placeholders.join(", ")
);
let mut total: u64 = 0;
for batch in &batches {
let num_rows = batch.num_rows();
let num_cols = batch.num_columns();
for row_idx in 0..num_rows {
let params: Vec<tokio_rusqlite::rusqlite::types::Value> = (0
..num_cols)
.map(|col_idx| arrow_value_to_sqlite(batch, row_idx, col_idx))
.collect();
let param_refs: Vec<&dyn tokio_rusqlite::rusqlite::types::ToSql> =
params
.iter()
.map(|v| v as &dyn tokio_rusqlite::rusqlite::types::ToSql)
.collect();
tx.execute(&insert_sql, param_refs.as_slice())?;
total += 1;
}
}
tx.commit()?;
Ok(total)
},
)
.await
.map_err(|e| DataFusionError::Execution(format!("SQLite INSERT error: {e}")))?;
let count_array = Arc::new(UInt64Array::from(vec![total_rows]));
RecordBatch::try_new(output_schema, vec![count_array]).map_err(DataFusionError::from)
};
Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.output_schema),
stream::once(future),
)))
}
}
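/// Converts one Arrow cell to a SQLite value. `List<Float32>` and
/// `FixedSizeList<Float32>` cells are serialized as little-endian f32 blobs;
/// types without a dedicated arm fall back to text.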
fn arrow_value_to_sqlite(
batch: &RecordBatch,
row: usize,
col: usize,
) -> tokio_rusqlite::rusqlite::types::Value {
use arrow::array::{Array, AsArray};
use tokio_rusqlite::rusqlite::types::Value;
let array = batch.column(col);
if array.is_null(row) {
return Value::Null;
}
match array.data_type() {
DataType::Int8 => Value::Integer(
array
.as_primitive::<arrow::datatypes::Int8Type>()
.value(row) as i64,
),
DataType::Int16 => Value::Integer(
array
.as_primitive::<arrow::datatypes::Int16Type>()
.value(row) as i64,
),
DataType::Int32 => Value::Integer(
array
.as_primitive::<arrow::datatypes::Int32Type>()
.value(row) as i64,
),
DataType::Int64 => Value::Integer(
array
.as_primitive::<arrow::datatypes::Int64Type>()
.value(row),
),
DataType::UInt8 => Value::Integer(
array
.as_primitive::<arrow::datatypes::UInt8Type>()
.value(row) as i64,
),
DataType::UInt16 => Value::Integer(
array
.as_primitive::<arrow::datatypes::UInt16Type>()
.value(row) as i64,
),
DataType::UInt32 => Value::Integer(
array
.as_primitive::<arrow::datatypes::UInt32Type>()
.value(row) as i64,
),
DataType::UInt64 => Value::Integer(
array
.as_primitive::<arrow::datatypes::UInt64Type>()
.value(row) as i64,
),
DataType::Float16 => Value::Real(
array
.as_primitive::<arrow::datatypes::Float16Type>()
.value(row)
.to_f64(),
),
DataType::Float32 => Value::Real(
array
.as_primitive::<arrow::datatypes::Float32Type>()
.value(row) as f64,
),
DataType::Float64 => Value::Real(
array
.as_primitive::<arrow::datatypes::Float64Type>()
.value(row),
),
DataType::Boolean => Value::Integer(if array.as_boolean().value(row) { 1 } else { 0 }),
DataType::Utf8 => Value::Text(array.as_string::<i32>().value(row).to_string()),
DataType::LargeUtf8 => Value::Text(array.as_string::<i64>().value(row).to_string()),
DataType::Binary => Value::Blob(array.as_binary::<i32>().value(row).to_vec()),
DataType::LargeBinary => Value::Blob(array.as_binary::<i64>().value(row).to_vec()),
DataType::List(field) if *field.data_type() == DataType::Float32 => {
let list = array
.as_any()
.downcast_ref::<ListArray>()
.expect("DataType::List guarantees ListArray");
let values = list.value(row);
let f32_arr = values
.as_any()
.downcast_ref::<Float32Array>()
.expect("List<Float32> guarantees Float32Array values");
let blob: Vec<u8> = f32_arr
.values()
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
Value::Blob(blob)
}
DataType::FixedSizeList(field, _) if *field.data_type() == DataType::Float32 => {
let list = array
.as_any()
.downcast_ref::<FixedSizeListArray>()
.expect("DataType::FixedSizeList guarantees FixedSizeListArray");
let values = list.value(row);
let f32_arr = values
.as_any()
.downcast_ref::<Float32Array>()
.expect("FixedSizeList<Float32> guarantees Float32Array values");
let blob: Vec<u8> = f32_arr
.values()
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
Value::Blob(blob)
}
        // Fallback for types without a dedicated mapping: render just this
        // row's cell as text; values that cannot be formatted become NULL.
        _ => match arrow::util::display::array_value_to_string(array, row) {
            Ok(s) => Value::Text(s),
            Err(_) => Value::Null,
        },
}
}
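/// [`ExecutionPlan`] that runs a pre-rendered `UPDATE`/`DELETE` statement and
/// emits the affected-row count as a single `count` row.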
struct SqliteDmlExec {
conn: Arc<Connection>,
sql: String,
schema: SchemaRef,
properties: PlanProperties,
}
impl SqliteDmlExec {
fn new(conn: Arc<Connection>, sql: String) -> Self {
let schema = Arc::new(Schema::new(vec![Field::new(
"count",
DataType::UInt64,
false,
)]));
let properties = PlanProperties::new(
EquivalenceProperties::new(Arc::clone(&schema)),
Partitioning::UnknownPartitioning(1),
EmissionType::Final,
Boundedness::Bounded,
);
Self {
conn,
sql,
schema,
properties,
}
}
}
impl fmt::Debug for SqliteDmlExec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SqliteDmlExec(sql={})", self.sql)
}
}
impl DisplayAs for SqliteDmlExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "SqliteDmlExec")
}
}
impl ExecutionPlan for SqliteDmlExec {
fn name(&self) -> &str {
"SqliteDmlExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let conn = Arc::clone(&self.conn);
let sql = self.sql.clone();
let schema = Arc::clone(&self.schema);
let future = async move {
let rows_affected: u64 = conn
.call(
move |conn| -> Result<u64, tokio_rusqlite::rusqlite::Error> {
let affected = conn.execute(&sql, [])?;
Ok(affected as u64)
},
)
.await
.map_err(|e| {
DataFusionError::Execution(format!("SQLite DML execute error: {e}"))
})?;
let count_array = Arc::new(UInt64Array::from(vec![rows_affected]));
RecordBatch::try_new(Arc::clone(&schema), vec![count_array])
.map_err(DataFusionError::from)
};
Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.schema),
stream::once(future),
)))
}
}
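/// Maps a declared SQLite column type to Arrow using affinity-style substring
/// matching (`INT` → Int64, `REAL`/`FLOAT`/`DOUBLE` → Float64, `BLOB` →
/// Binary, `BOOL` → Boolean, anything else Utf8). An empty declared type, as
/// produced by some virtual tables, maps to Int64 for primary keys and Binary
/// otherwise.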
pub(crate) fn sqlite_type_to_arrow(sqlite_type: &str, is_pk: bool) -> DataType {
let upper = sqlite_type.to_uppercase();
if upper.is_empty() {
if is_pk {
DataType::Int64
} else {
DataType::Binary
}
} else if upper.contains("INT") {
DataType::Int64
} else if upper.contains("REAL") || upper.contains("FLOAT") || upper.contains("DOUBLE") {
DataType::Float64
} else if upper.contains("BLOB") {
DataType::Binary
} else if upper.contains("BOOL") {
DataType::Boolean
} else {
DataType::Utf8
}
}
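/// Unparses the pushed-down filters with the SQLite dialect and joins them
/// with `AND` into a leading-space ` WHERE ...` clause; empty input yields an
/// empty string.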
fn build_sqlite_where_clause(filters: &[Expr]) -> DataFusionResult<String> {
if filters.is_empty() {
return Ok(String::new());
}
let unparser = Unparser::new(&SqliteDialect {});
let parts = filters
.iter()
.map(|e| {
unparser.expr_to_sql(e).map(|s| s.to_string()).map_err(|e| {
DataFusionError::Plan(format!("Failed to unparse filter expression: {e}"))
})
})
.collect::<DataFusionResult<Vec<_>>>()?;
Ok(format!(" WHERE {}", parts.join(" AND ")))
}
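/// Double-quotes an identifier, doubling any embedded quotes: `my"table`
/// becomes `"my""table"`.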
pub(crate) fn quote_sqlite_ident(s: &str) -> String {
format!("\"{}\"", s.replace('"', "\"\""))
}
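/// Best-effort unparsing of an expression to SQLite SQL; `None` when the
/// expression cannot be rendered.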
pub(crate) fn expr_to_sqlite_sql(expr: &Expr) -> Option<String> {
let unparser = Unparser::new(&SqliteDialect {});
unparser.expr_to_sql(expr).ok().map(|ast| ast.to_string())
}
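/// Extracts a string-literal argument, treating a NULL literal as the empty
/// string; anything else is a plan error.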
pub(crate) fn extract_string(expr: &Expr, name: &str) -> DataFusionResult<String> {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(s)), _)
| Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()),
Expr::Literal(ScalarValue::Null, _) => Ok(String::new()),
_ => Err(DataFusionError::Plan(format!(
"sqlite: '{}' must be a string literal",
name
))),
}
}
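/// Quotes only the bare table name; any catalog or schema qualifiers are
/// dropped, since the connection already addresses a single database file.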
fn quote_sqlite_table(tbl: &TableReference) -> String {
quote_sqlite_ident(tbl.table())
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Array as _, Int64Array};
#[test]
fn test_empty_type_maps_to_binary_for_non_pk() {
assert_eq!(sqlite_type_to_arrow("", false), DataType::Binary);
}
#[test]
fn test_empty_type_maps_to_int64_for_pk() {
assert_eq!(sqlite_type_to_arrow("", true), DataType::Int64);
}
#[test]
fn test_quote_sqlite_ident() {
assert_eq!(quote_sqlite_ident("users"), "\"users\"");
assert_eq!(quote_sqlite_ident("my\"table"), "\"my\"\"table\"");
}
#[test]
fn test_quote_sqlite_table() {
let reference = TableReference::bare("users");
assert_eq!(quote_sqlite_table(&reference), "\"users\"");
}
#[test]
fn test_missing_table_option() {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let mut session_ctx = SessionContext::new();
let result = register_sqlite_tables(
&mut session_ctx,
"test_table",
"/tmp/test.db",
None,
false,
None,
HierarchyLevel::Table,
)
.await;
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("requires 'table' option"));
});
}
async fn create_test_db() -> tempfile::TempPath {
let tmp = tempfile::NamedTempFile::new().expect("create temp file");
let path = tmp.into_temp_path();
let db_path = path.to_str().unwrap().to_string();
let conn = Connection::open(&db_path).await.expect("open temp sqlite");
conn.call(|conn| -> Result<(), tokio_rusqlite::rusqlite::Error> {
conn.execute_batch(
"CREATE TABLE test_items (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
value INTEGER NOT NULL
);
INSERT INTO test_items (id, name, value) VALUES (1, 'alice', 10);
INSERT INTO test_items (id, name, value) VALUES (2, 'bob', 20);
INSERT INTO test_items (id, name, value) VALUES (3, 'carol', 30);",
)?;
Ok(())
})
.await
.expect("seed table");
conn.close().await.expect("close seed connection");
path
}
async fn register_test_table(ctx: &mut SessionContext, db_path: &str) {
let mut options = HashMap::new();
options.insert("table".to_string(), "test_items".to_string());
register_sqlite_tables(
ctx,
"test_items",
db_path,
Some(&options),
true,
None,
HierarchyLevel::Table,
)
.await
.expect("register sqlite table");
}
async fn query_all(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
let df = ctx.sql(sql).await.expect("parse sql");
df.collect().await.expect("collect results")
}
fn total_rows(batches: &[RecordBatch]) -> usize {
batches.iter().map(|b| b.num_rows()).sum()
}
#[tokio::test]
#[ignore]
async fn test_scan_all_rows() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT id, name, value FROM test_items ORDER BY id").await;
assert_eq!(total_rows(&batches), 3);
let ids = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(ids.value(0), 1);
assert_eq!(ids.value(1), 2);
assert_eq!(ids.value(2), 3);
}
#[tokio::test]
#[ignore]
async fn test_scan_with_projection() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT name FROM test_items ORDER BY id").await;
assert_eq!(total_rows(&batches), 3);
assert_eq!(batches[0].num_columns(), 1);
let names = batches[0]
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "alice");
}
#[tokio::test]
#[ignore]
async fn test_scan_with_filter() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT id, name FROM test_items WHERE id = 2").await;
assert_eq!(total_rows(&batches), 1);
let names = batches[0]
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "bob");
}
#[tokio::test]
#[ignore]
async fn test_scan_with_limit() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT id FROM test_items LIMIT 2").await;
assert_eq!(total_rows(&batches), 2);
}
#[tokio::test]
#[ignore]
async fn test_count_star_pushdown() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT count(*) FROM test_items").await;
assert_eq!(total_rows(&batches), 1);
let counts = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(counts.value(0), 3);
}
#[tokio::test]
#[ignore]
async fn test_count_star_with_filter() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT count(*) FROM test_items WHERE id > 1").await;
assert_eq!(total_rows(&batches), 1);
let counts = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(counts.value(0), 2);
}
#[tokio::test]
#[ignore]
async fn test_count_star_empty_table() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("DELETE FROM test_items")
.await
.expect("parse delete")
.collect()
.await
.expect("execute delete");
let batches = query_all(&ctx, "SELECT count(*) FROM test_items").await;
assert_eq!(total_rows(&batches), 1);
let counts = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(counts.value(0), 0);
}
#[tokio::test]
#[ignore]
async fn test_insert_into() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("INSERT INTO test_items (id, name, value) VALUES (4, 'dave', 40)")
.await
.expect("parse insert")
.collect()
.await
.expect("execute insert");
let batches = query_all(&ctx, "SELECT id, name, value FROM test_items ORDER BY id").await;
assert_eq!(total_rows(&batches), 4);
let last_batch = &batches[batches.len() - 1];
let ids = last_batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert!(ids.values().iter().any(|&v| v == 4), "id 4 should exist");
}
#[tokio::test]
#[ignore]
async fn test_insert_multi_row_values() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql(
"INSERT INTO test_items (id, name, value) VALUES \
(10, 'eve', 100), (11, 'frank', 110), (12, 'gina', 120)",
)
.await
.expect("parse multi-row insert")
.collect()
.await
.expect("execute multi-row insert");
let batches = query_all(
&ctx,
"SELECT id, name FROM test_items WHERE id >= 10 ORDER BY id",
)
.await;
assert_eq!(total_rows(&batches), 3);
}
#[tokio::test]
#[ignore]
async fn test_delete_with_filter() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("DELETE FROM test_items WHERE id > 1")
.await
.expect("parse delete")
.collect()
.await
.expect("execute delete");
let batches = query_all(&ctx, "SELECT id, name FROM test_items").await;
assert_eq!(total_rows(&batches), 1);
let ids = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(ids.value(0), 1);
}
#[tokio::test]
#[ignore]
async fn test_delete_all_rows() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("DELETE FROM test_items")
.await
.expect("parse delete all")
.collect()
.await
.expect("execute delete all");
let batches = query_all(&ctx, "SELECT id FROM test_items").await;
assert_eq!(total_rows(&batches), 0);
}
#[tokio::test]
#[ignore]
async fn test_delete_no_matching_rows() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("DELETE FROM test_items WHERE id = 999")
.await
.expect("parse delete")
.collect()
.await
.expect("execute delete");
let batches = query_all(&ctx, "SELECT id FROM test_items").await;
assert_eq!(total_rows(&batches), 3);
}
#[tokio::test]
#[ignore]
async fn test_update_single_column_with_filter() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("UPDATE test_items SET value = 200 WHERE id = 2")
.await
.expect("parse update")
.collect()
.await
.expect("execute update");
let batches = query_all(&ctx, "SELECT value FROM test_items WHERE id = 2").await;
assert_eq!(total_rows(&batches), 1);
let values = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(values.value(0), 200);
}
#[tokio::test]
#[ignore]
async fn test_update_multiple_columns() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("UPDATE test_items SET name = 'charlie', value = 300 WHERE id = 3")
.await
.expect("parse update")
.collect()
.await
.expect("execute update");
let batches = query_all(&ctx, "SELECT name, value FROM test_items WHERE id = 3").await;
assert_eq!(total_rows(&batches), 1);
let names = batches[0]
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let values = batches[0]
.column(1)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(names.value(0), "charlie");
assert_eq!(values.value(0), 300);
}
#[tokio::test]
#[ignore]
async fn test_update_all_rows() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("UPDATE test_items SET value = 0")
.await
.expect("parse update all")
.collect()
.await
.expect("execute update all");
let batches = query_all(&ctx, "SELECT value FROM test_items").await;
assert_eq!(total_rows(&batches), 3);
for batch in &batches {
let values = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
for i in 0..values.len() {
assert_eq!(values.value(i), 0);
}
}
}
#[tokio::test]
#[ignore]
async fn test_update_no_matching_rows() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("UPDATE test_items SET value = 999 WHERE id = 999")
.await
.expect("parse update")
.collect()
.await
.expect("execute update");
let batches = query_all(&ctx, "SELECT value FROM test_items ORDER BY id").await;
assert_eq!(total_rows(&batches), 3);
let values = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(values.value(0), 10);
assert_eq!(values.value(1), 20);
assert_eq!(values.value(2), 30);
}
#[tokio::test]
#[ignore]
async fn test_insert_update_delete_round_trip() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("INSERT INTO test_items (id, name, value) VALUES (4, 'dave', 40)")
.await
.unwrap()
.collect()
.await
.unwrap();
let batches = query_all(&ctx, "SELECT id FROM test_items").await;
assert_eq!(total_rows(&batches), 4);
ctx.sql("UPDATE test_items SET value = 44 WHERE id = 4")
.await
.unwrap()
.collect()
.await
.unwrap();
let batches = query_all(&ctx, "SELECT value FROM test_items WHERE id = 4").await;
let values = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(values.value(0), 44);
ctx.sql("DELETE FROM test_items WHERE id <= 2")
.await
.unwrap()
.collect()
.await
.unwrap();
let batches = query_all(&ctx, "SELECT id FROM test_items ORDER BY id").await;
assert_eq!(total_rows(&batches), 2);
let ids = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
        assert_eq!(ids.value(0), 3);
        assert_eq!(ids.value(1), 4);
    }
async fn create_multi_table_db() -> tempfile::TempPath {
let tmp = tempfile::NamedTempFile::new().expect("create temp file");
let path = tmp.into_temp_path();
let db_path = path.to_str().unwrap().to_string();
let conn = Connection::open(&db_path).await.expect("open temp sqlite");
conn.call(|conn| -> Result<(), tokio_rusqlite::rusqlite::Error> {
conn.execute_batch(
"CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT UNIQUE NOT NULL
);
CREATE TABLE orders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
product TEXT NOT NULL,
amount REAL NOT NULL
);
CREATE TABLE user_order_stats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL UNIQUE,
user_name TEXT NOT NULL,
user_email TEXT NOT NULL,
total_orders INTEGER NOT NULL,
total_spent REAL NOT NULL,
last_order_date TEXT
);
INSERT INTO users (name, email) VALUES
('Alice Smith', 'alice@example.com'),
('Bob Johnson', 'bob@example.com');
INSERT INTO orders (user_id, product, amount) VALUES
(1, 'Laptop', 999.99),
(1, 'Mouse', 29.99),
(2, 'Keyboard', 79.99);
",
)?;
Ok(())
})
.await
.expect("seed multi-table db");
conn.close().await.expect("close seed connection");
path
}
async fn register_multi_tables(ctx: &mut SessionContext, db_path: &str) {
for (reg_name, table_name) in [
("users", "users"),
("orders", "orders"),
("user_order_stats", "user_order_stats"),
] {
let mut options = HashMap::new();
options.insert("table".to_string(), table_name.to_string());
register_sqlite_tables(
ctx,
reg_name,
db_path,
Some(&options),
true,
None,
HierarchyLevel::Table,
)
.await
.unwrap_or_else(|e| panic!("register {} failed: {}", reg_name, e));
}
}
#[tokio::test]
#[ignore]
async fn test_empty_table_schema_from_pragma() {
let db_path = create_multi_table_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_multi_tables(&mut ctx, db).await;
let catalog = ctx.catalog("datafusion").unwrap();
let schema = catalog.schema("public").unwrap();
let table = schema.table("user_order_stats").await.unwrap().unwrap();
let table_schema = table.schema();
let fields: Vec<&str> = table_schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
assert_eq!(
fields,
vec![
"id",
"user_id",
"user_name",
"user_email",
"total_orders",
"total_spent",
"last_order_date"
]
);
}
#[tokio::test]
#[ignore]
async fn test_insert_select_from_same_db() {
let db_path = create_multi_table_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_multi_tables(&mut ctx, db).await;
ctx.sql(
"INSERT INTO user_order_stats (user_id, user_name, user_email, total_orders, total_spent, last_order_date)
SELECT
u.id,
u.name,
u.email,
CAST(COUNT(o.id) AS BIGINT),
CAST(SUM(o.amount) AS DOUBLE),
CAST('N/A' AS TEXT)
FROM users u
INNER JOIN orders o ON u.id = o.user_id
WHERE u.name = 'Alice Smith'
GROUP BY u.id, u.name, u.email",
)
.await
.expect("parse insert-select")
.collect()
.await
.expect("execute insert-select");
let batches = query_all(
&ctx,
"SELECT user_id, user_name, total_orders, total_spent FROM user_order_stats",
)
.await;
assert_eq!(total_rows(&batches), 1);
let user_ids = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
        assert_eq!(user_ids.value(0), 1);
    }
#[tokio::test]
#[ignore]
async fn test_insert_select_multiple_users() {
let db_path = create_multi_table_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_multi_tables(&mut ctx, db).await;
ctx.sql(
"INSERT INTO user_order_stats (user_id, user_name, user_email, total_orders, total_spent, last_order_date)
SELECT
u.id,
u.name,
u.email,
CAST(COUNT(o.id) AS BIGINT),
CAST(SUM(o.amount) AS DOUBLE),
CAST('N/A' AS TEXT)
FROM users u
INNER JOIN orders o ON u.id = o.user_id
GROUP BY u.id, u.name, u.email",
)
.await
.expect("parse insert-select all")
.collect()
.await
.expect("execute insert-select all");
let batches = query_all(
&ctx,
"SELECT user_id, user_name, total_orders FROM user_order_stats ORDER BY user_id",
)
.await;
assert_eq!(total_rows(&batches), 2);
}
#[tokio::test]
#[ignore]
async fn test_schema_visibility_across_tables() {
let db_path = create_multi_table_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_multi_tables(&mut ctx, db).await;
let batches = query_all(&ctx, "SELECT id, name, email FROM users ORDER BY id").await;
assert_eq!(total_rows(&batches), 2);
let batches = query_all(
&ctx,
"SELECT id, user_id, product, amount FROM orders ORDER BY id",
)
.await;
assert_eq!(total_rows(&batches), 3);
let batches = query_all(
&ctx,
"SELECT u.name, o.product, o.amount
FROM users u
INNER JOIN orders o ON u.id = o.user_id
ORDER BY o.id",
)
.await;
assert_eq!(total_rows(&batches), 3);
}
#[tokio::test]
#[ignore]
async fn test_read_own_write_insert_then_select() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("INSERT INTO test_items (id, name, value) VALUES (4, 'dave', 40)")
.await
.unwrap()
.collect()
.await
.unwrap();
let batches = query_all(&ctx, "SELECT id, name FROM test_items WHERE id = 4").await;
assert_eq!(
total_rows(&batches),
1,
"inserted row must be visible immediately"
);
let names = batches[0]
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "dave");
}
#[tokio::test]
#[ignore]
async fn test_read_own_write_delete_then_select() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("DELETE FROM test_items WHERE id = 1")
.await
.unwrap()
.collect()
.await
.unwrap();
let batches = query_all(&ctx, "SELECT id FROM test_items WHERE id = 1").await;
assert_eq!(total_rows(&batches), 0, "deleted row must not be visible");
let batches = query_all(&ctx, "SELECT id FROM test_items").await;
assert_eq!(total_rows(&batches), 2);
}
#[tokio::test]
#[ignore]
async fn test_read_own_write_update_then_select() {
let db_path = create_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_test_table(&mut ctx, db).await;
ctx.sql("UPDATE test_items SET value = 999 WHERE id = 2")
.await
.unwrap()
.collect()
.await
.unwrap();
let batches = query_all(&ctx, "SELECT value FROM test_items WHERE id = 2").await;
assert_eq!(total_rows(&batches), 1);
let values = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(
values.value(0),
999,
"updated value must be visible immediately"
);
}
async fn create_empty_type_db() -> tempfile::TempPath {
let tmp = tempfile::NamedTempFile::new().expect("create temp file");
let path = tmp.into_temp_path();
let db_path = path.to_str().unwrap().to_string();
let conn = Connection::open(&db_path).await.expect("open temp sqlite");
conn.call(|conn| -> Result<(), tokio_rusqlite::rusqlite::Error> {
conn.execute_batch(
"CREATE TABLE blob_items (
id INTEGER PRIMARY KEY,
data \"\" -- empty type, like vec0 virtual tables
);",
)?;
let vecs: Vec<(i64, Vec<f32>)> =
vec![(1, vec![1.0, 0.0, 0.0]), (2, vec![0.0, 1.0, 0.0])];
let mut stmt = conn.prepare("INSERT INTO blob_items (id, data) VALUES (?1, ?2)")?;
for (id, vec) in &vecs {
let blob: Vec<u8> = vec.iter().flat_map(|f| f.to_le_bytes()).collect();
stmt.execute(tokio_rusqlite::rusqlite::params![id, blob])?;
}
Ok(())
})
.await
.expect("seed blob_items table");
conn.close().await.expect("close seed connection");
path
}
#[tokio::test]
#[ignore]
async fn test_empty_type_column_read_as_binary() {
let db_path = create_empty_type_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
let mut options = HashMap::new();
options.insert("table".to_string(), "blob_items".to_string());
register_sqlite_tables(
&mut ctx,
"blob_items",
db,
Some(&options),
false,
None,
HierarchyLevel::Table,
)
.await
.expect("register blob_items");
let batches = query_all(&ctx, "SELECT id, data FROM blob_items ORDER BY id").await;
assert_eq!(total_rows(&batches), 2);
let schema = batches[0].schema();
let data_field = schema.field_with_name("data").unwrap();
assert_eq!(
*data_field.data_type(),
DataType::Binary,
"empty-type column should map to Binary"
);
let data_col = batches[0]
.column(1)
.as_any()
.downcast_ref::<BinaryArray>()
.expect("data column should be BinaryArray");
let blob = data_col.value(0);
assert_eq!(blob.len(), 12, "3 × f32 = 12 bytes");
let floats: Vec<f32> = blob
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
assert_eq!(floats, vec![1.0f32, 0.0, 0.0]);
}
#[tokio::test]
#[ignore]
async fn test_empty_type_column_subquery_extracts_blob() {
let db_path = create_empty_type_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
let mut options = HashMap::new();
options.insert("table".to_string(), "blob_items".to_string());
register_sqlite_tables(
&mut ctx,
"blob_items",
db,
Some(&options),
false,
None,
HierarchyLevel::Table,
)
.await
.expect("register blob_items");
let batches = query_all(&ctx, "SELECT data FROM blob_items WHERE id = 1").await;
assert_eq!(total_rows(&batches), 1);
let data_col = batches[0]
.column(0)
.as_any()
.downcast_ref::<BinaryArray>()
.expect("subquery should return BinaryArray");
assert!(!data_col.is_null(0), "BLOB should not be null");
assert_eq!(data_col.value(0).len(), 12);
}
async fn create_catalog_test_db() -> tempfile::TempPath {
let tmp = tempfile::NamedTempFile::new().expect("create temp file");
let path = tmp.into_temp_path();
let db_path = path.to_str().unwrap().to_string();
let conn = Connection::open(&db_path).await.expect("open temp sqlite");
conn.call(|conn| -> Result<(), tokio_rusqlite::rusqlite::Error> {
conn.execute_batch(
"CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL
);
INSERT INTO users (name) VALUES ('alice'), ('bob'), ('carol');
CREATE TABLE orders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
amount INTEGER NOT NULL
);
INSERT INTO orders (user_id, amount) VALUES (1, 100), (1, 50), (2, 200);
CREATE VIEW user_totals AS
SELECT u.name AS name, SUM(o.amount) AS total
FROM users u
JOIN orders o ON u.id = o.user_id
GROUP BY u.id, u.name;",
)?;
Ok(())
})
.await
.expect("seed catalog db");
conn.close().await.expect("close seed connection");
path
}
#[tokio::test]
#[ignore]
async fn test_list_sqlite_tables_filters_internal_and_shadow_tables() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let conn = Connection::open(db).await.expect("open");
init_connection(&conn, 5000, &[]).await.expect("init");
let tables = list_sqlite_tables_in_catalog(&conn)
.await
.expect("list tables");
let names: Vec<&str> = tables.iter().map(|(_, t)| t.as_str()).collect();
assert!(names.contains(&"users"), "expected users in {:?}", names);
assert!(names.contains(&"orders"), "expected orders in {:?}", names);
assert!(
names.contains(&"user_totals"),
"expected user_totals view in {:?}",
names
);
assert!(
!names.iter().any(|n| n.starts_with("sqlite_")),
"sqlite_* internal tables must be filtered: {:?}",
names
);
for (schema, _) in &tables {
assert_eq!(schema, "main");
}
}
#[tokio::test]
#[ignore]
async fn test_catalog_registers_all_user_tables_and_views() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_sqlite_tables(
&mut ctx,
"demo",
db,
None,
false,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register catalog");
let users = query_all(&ctx, "SELECT id FROM demo.main.users ORDER BY id").await;
assert_eq!(total_rows(&users), 3);
let orders = query_all(&ctx, "SELECT id FROM demo.main.orders ORDER BY id").await;
assert_eq!(total_rows(&orders), 3);
let totals = query_all(
&ctx,
"SELECT name, total FROM demo.main.user_totals ORDER BY name",
)
.await;
assert_eq!(total_rows(&totals), 2);
}
#[tokio::test]
#[ignore]
async fn test_catalog_cross_table_join() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_sqlite_tables(
&mut ctx,
"demo",
db,
None,
false,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register catalog");
let batches = query_all(
&ctx,
"SELECT u.name, SUM(o.amount) AS total \
FROM demo.main.users u \
JOIN demo.main.orders o ON u.id = o.user_id \
GROUP BY u.name \
ORDER BY u.name",
)
.await;
assert_eq!(total_rows(&batches), 2);
}
#[tokio::test]
#[ignore]
async fn test_catalog_excludes_sqlite_sequence() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_sqlite_tables(
&mut ctx,
"demo",
db,
None,
false,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register catalog");
        let result = match ctx.sql("SELECT * FROM demo.main.sqlite_sequence").await {
            Ok(df) => df.collect().await,
            Err(e) => Err(e),
        };
assert!(
result.is_err(),
"sqlite_sequence should not be registered in the catalog"
);
}
#[tokio::test]
#[ignore]
async fn test_catalog_allowed_schemas_main_includes_all() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
let mut options = HashMap::new();
options.insert("allowed_schemas".to_string(), "main".to_string());
register_sqlite_tables(
&mut ctx,
"demo",
db,
Some(&options),
false,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register catalog");
let users = query_all(&ctx, "SELECT id FROM demo.main.users ORDER BY id").await;
assert_eq!(total_rows(&users), 3);
}
#[tokio::test]
#[ignore]
async fn test_catalog_allowed_schemas_non_main_is_empty() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
let mut options = HashMap::new();
options.insert("allowed_schemas".to_string(), "public".to_string());
register_sqlite_tables(
&mut ctx,
"demo",
db,
Some(&options),
false,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register catalog");
let result = ctx.sql("SELECT * FROM demo.main.users").await;
assert!(
result.is_err(),
"allowed_schemas=public should register no tables under SQLite's 'main' schema"
);
}
#[tokio::test]
#[ignore]
async fn test_catalog_read_write_mode_allows_inserts() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let mut ctx = SessionContext::new();
register_sqlite_tables(
&mut ctx,
"demo",
db,
None,
true,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register catalog rw");
ctx.sql("INSERT INTO demo.main.users (name) VALUES ('dave')")
.await
.expect("parse insert")
.collect()
.await
.expect("execute insert");
let batches = query_all(&ctx, "SELECT name FROM demo.main.users WHERE name = 'dave'").await;
assert_eq!(total_rows(&batches), 1);
}
#[tokio::test]
#[ignore]
async fn test_catalog_empty_database_is_ok() {
let tmp = tempfile::NamedTempFile::new().expect("create temp file");
let path = tmp.into_temp_path();
let db = path.to_str().unwrap();
let conn = Connection::open(db).await.expect("open empty db");
conn.close().await.expect("close empty db");
let mut ctx = SessionContext::new();
register_sqlite_tables(
&mut ctx,
"demo",
db,
None,
false,
None,
HierarchyLevel::Catalog,
)
.await
.expect("register empty catalog");
}
#[tokio::test]
#[ignore]
async fn test_catalog_shares_single_read_pool() {
let db_path = create_catalog_test_db().await;
let db = db_path.to_str().unwrap();
let (read_pool, write_conn) = open_sqlite_pool(db, 5000, 2, false, &[])
.await
.expect("open shared pool");
assert_eq!(read_pool.len(), 2);
assert!(write_conn.is_none());
let users_provider = SqliteTableProvider::from_shared_pool(
read_pool.clone(),
None,
TableReference::bare("users"),
false,
)
.await
.expect("users provider");
let orders_provider = SqliteTableProvider::from_shared_pool(
read_pool.clone(),
None,
TableReference::bare("orders"),
false,
)
.await
.expect("orders provider");
for (u, o) in users_provider
.read_pool
.iter()
.zip(orders_provider.read_pool.iter())
{
assert!(
Arc::ptr_eq(u, o),
"catalog providers must share the same connections"
);
}
}
}