use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use arrow::array::{Array, AsArray, BooleanBuilder, RecordBatch};
use arrow::datatypes::{
DataType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use datafusion::common::GetExt;
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::provider::DefaultTableFactory;
use datafusion::datasource::MemTable;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::*;
use sqlparser::ast::{
AssignmentTarget, BinaryOperator, Expr, FromTable, SetExpr, Statement, TableFactor,
TableObject, UnaryOperator, Value,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use tokio::sync::RwLock;
use tracing::debug;
use vortex::session::VortexSession;
use vortex::VortexSessionDefault;
use vortex_datafusion::VortexFormat;
use vortex_datafusion::VortexFormatFactory;
use crate::error::DfOlapError;
use crate::storage::{StorageMode, VortexLocation};
#[cfg(feature = "cloud-storage")]
use url::Url;
/// State kept for one table in in-memory storage mode.
struct TableData {
/// Arrow schema the table currently has.
schema: SchemaRef,
/// Record batches currently held for the table; empty for a freshly created table.
batches: Vec<RecordBatch>,
}
/// Metadata kept for one table in Vortex storage mode.
struct VortexTableMeta {
/// Arrow schema the listing table is registered with.
schema: SchemaRef,
/// Listing URL (as produced by `table_listing_url`) under which the table's files live.
table_url: String,
}
/// Builds a DataFusion [`SessionContext`] that is aware of the Vortex file format.
///
/// The session state is created with the default features, a `DefaultTableFactory`
/// keyed by the upper-cased Vortex file extension, and the Vortex format factory
/// appended to the builder's file-format list (when the builder exposes one).
/// URL tables are enabled on the resulting context.
///
/// With the `cloud-storage` feature, an S3 location additionally gets an
/// `AmazonS3Builder::from_env()` object store for its bucket registered under
/// `s3://<bucket>`.
///
/// # Errors
/// Only the `cloud-storage` path can fail: bucket extraction, building the
/// object store, or parsing the base URL.
fn build_vortex_session_context(location: &VortexLocation) -> Result<SessionContext, DfOlapError> {
    let vortex_factory = Arc::new(VortexFormatFactory::new());
    let factory_key = vortex_factory.get_ext().to_uppercase();

    let mut builder = SessionStateBuilder::new()
        .with_default_features()
        .with_table_factory(factory_key, Arc::new(DefaultTableFactory::new()));
    if let Some(formats) = builder.file_formats() {
        formats.push(vortex_factory.clone() as _);
    }

    let state = builder.build();
    let ctx = SessionContext::new_with_state(state).enable_url_table();

    #[cfg(feature = "cloud-storage")]
    if let VortexLocation::S3 { url } = location {
        let bucket = parse_s3_bucket(url)?;
        let s3 = object_store::aws::AmazonS3Builder::from_env()
            .with_bucket_name(&bucket)
            .build()
            .map_err(DfOlapError::ObjectStore)?;
        let store: Arc<dyn object_store::ObjectStore> = Arc::new(s3);
        let base_url = Url::parse(&format!("s3://{bucket}")).map_err(DfOlapError::UrlParse)?;
        ctx.runtime_env().register_object_store(&base_url, store);
        tracing::info!(bucket, "registered S3 object store for Vortex");
    }

    // Keeps `location` used when the `cloud-storage` feature is compiled out.
    let _ = location;
    Ok(ctx)
}
/// Extracts the bucket name (the URL host) from an `s3://bucket/...` URL.
///
/// # Errors
/// * `DfOlapError::UrlParse` when `url` is not a valid URL.
/// * `DfOlapError::StorageConfig` when the scheme is not `s3` or the URL has no host.
#[cfg(feature = "cloud-storage")]
fn parse_s3_bucket(url: &str) -> Result<String, DfOlapError> {
    let parsed = Url::parse(url).map_err(DfOlapError::UrlParse)?;
    if parsed.scheme() != "s3" {
        return Err(DfOlapError::StorageConfig(format!(
            "expected s3:// URL, got '{url}'"
        )));
    }
    match parsed.host_str() {
        Some(host) => Ok(host.to_string()),
        None => Err(DfOlapError::StorageConfig(format!(
            "missing bucket name in URL '{url}'"
        ))),
    }
}
/// Returns the listing URL (always ending in `/`) under which a table's files live.
///
/// * Local storage: `file://<absolute dir>/`, where the directory is
///   `base_path/table_name`. A relative directory is resolved against the current
///   working directory first: in `file://data/t/` the segment `data` would be parsed
///   as the URL *host*, so the listing table would point at `/t/` instead of the
///   directory that is actually created and cleared for the table. If the working
///   directory cannot be determined the path is used unchanged (previous behaviour).
/// * S3 storage (`cloud-storage` feature): `<url without trailing '/'>/<table_name>/`.
fn table_listing_url(location: &VortexLocation, table_name: &str) -> String {
    match location {
        VortexLocation::Local { base_path } => {
            let dir = base_path.join(table_name);
            let dir = if dir.is_absolute() {
                dir
            } else {
                match std::env::current_dir() {
                    Ok(cwd) => cwd.join(&dir),
                    Err(_) => dir,
                }
            };
            format!("file://{}/", dir.to_string_lossy())
        }
        #[cfg(feature = "cloud-storage")]
        VortexLocation::S3 { url } => {
            let base = url.trim_end_matches('/');
            format!("{base}/{table_name}/")
        }
    }
}
/// Registers (or re-registers) `table_name` in `ctx` as a listing table over
/// the `.vortex` files found under `listing_url`, using `schema` as the table schema.
///
/// Any table already registered under the same name is deregistered first; the
/// result of that deregistration is ignored.
///
/// # Errors
/// Fails when the listing URL cannot be parsed, the listing table cannot be
/// constructed, or the registration itself is rejected.
async fn register_vortex_listing_table(
    ctx: &SessionContext,
    table_name: &str,
    schema: &SchemaRef,
    listing_url: &str,
) -> Result<(), DfOlapError> {
    let session = <VortexSession as VortexSessionDefault>::default();
    let format = Arc::new(VortexFormat::new(session));
    let options = ListingOptions::new(format as _)
        .with_file_extension("vortex")
        .with_session_config_options(ctx.state().config());

    let config = ListingTableConfig::new(ListingTableUrl::parse(listing_url)?)
        .with_listing_options(options)
        .with_schema(schema.clone());
    let provider = Arc::new(ListingTable::try_new(config)?);

    let _ = ctx.deregister_table(table_name);
    ctx.register_table(table_name, provider)?;
    Ok(())
}
/// OLAP engine backed by a DataFusion `SessionContext`, storing table data either
/// in memory or as Vortex files depending on the configured `StorageMode`.
pub struct DataFusionEngine {
/// DataFusion session in which tables are registered and queries are run.
ctx: RwLock<SessionContext>,
/// Per-table schema and batches used in in-memory mode.
tables: RwLock<HashMap<String, TableData>>,
/// Per-table schema and listing URL used in Vortex mode.
vortex_tables: RwLock<HashMap<String, VortexTableMeta>>,
/// Result of `StorageMode::classify`; `None` when the mode has no Vortex location.
vortex_location: Option<VortexLocation>,
/// Storage mode the engine was constructed with.
storage_mode: StorageMode,
/// Counter used to generate unique names for temporary load tables.
tmp_counter: AtomicU64,
}
impl DataFusionEngine {
/// Creates an engine for the given storage mode.
///
/// The mode is classified into an optional Vortex location. A local location
/// gets its base directory created. When a location is present the session is
/// built with Vortex support, otherwise a plain `SessionContext` is used.
///
/// # Errors
/// * `DfOlapError::StorageConfig` when `mode.classify()` fails.
/// * An I/O error when the local base directory cannot be created.
/// * Any error from `build_vortex_session_context`.
pub fn with_storage(mode: StorageMode) -> Result<Self, DfOlapError> {
    let vortex_location = mode.classify().map_err(DfOlapError::StorageConfig)?;

    if let Some(loc) = &vortex_location {
        match loc {
            VortexLocation::Local { base_path } => std::fs::create_dir_all(base_path)?,
            // Remote locations have no local directory to prepare.
            #[cfg(feature = "cloud-storage")]
            VortexLocation::S3 { .. } => {}
        }
    }

    let ctx = match &vortex_location {
        Some(loc) => build_vortex_session_context(loc)?,
        None => SessionContext::new(),
    };

    Ok(Self {
        ctx: RwLock::new(ctx),
        tables: RwLock::new(HashMap::new()),
        vortex_tables: RwLock::new(HashMap::new()),
        vortex_location,
        storage_mode: mode,
        tmp_counter: AtomicU64::new(0),
    })
}
/// Creates an engine that keeps all table data in memory.
///
/// # Panics
/// Panics if `with_storage(StorageMode::InMemory)` returns an error.
pub fn new() -> Self {
Self::with_storage(StorageMode::InMemory).expect("in-memory mode cannot fail")
}
/// Returns the storage mode this engine was constructed with.
pub fn storage_mode(&self) -> &StorageMode {
&self.storage_mode
}
/// Returns the resolved Vortex location, if the engine has one.
fn location(&self) -> Option<&VortexLocation> {
self.vortex_location.as_ref()
}
/// Rebuilds the `MemTable` for `name` from the stored batches and re-registers
/// it in the session, replacing any previous registration.
///
/// # Errors
/// `DfOlapError::TableNotFound` when `name` has no in-memory entry; otherwise
/// any error from building or registering the `MemTable`.
async fn refresh_table_mem(&self, name: &str) -> Result<(), DfOlapError> {
    let registry = self.tables.read().await;
    let Some(entry) = registry.get(name) else {
        return Err(DfOlapError::TableNotFound(name.to_string()));
    };

    // Always hand over exactly one partition; for an empty table that is an
    // empty partition (cloning an empty batch list yields an empty vector).
    let partitions = vec![entry.batches.clone()];
    let provider = MemTable::try_new(entry.schema.clone(), partitions)?;

    let ctx = self.ctx.write().await;
    let _ = ctx.deregister_table(name);
    ctx.register_table(name, Arc::new(provider))?;
    Ok(())
}
/// Re-registers the Vortex listing table for `name` from its stored metadata.
///
/// The metadata lock is released before the session lock is taken.
///
/// # Errors
/// `DfOlapError::TableNotFound` when `name` has no Vortex metadata; otherwise
/// any error from `register_vortex_listing_table`.
async fn refresh_table_vortex(&self, name: &str) -> Result<(), DfOlapError> {
    let (schema, listing_url) = {
        let registry = self.vortex_tables.read().await;
        let meta = registry
            .get(name)
            .ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;
        (meta.schema.clone(), meta.table_url.clone())
    };

    let ctx = self.ctx.read().await;
    register_vortex_listing_table(&ctx, name, &schema, &listing_url).await
}
/// Looks up the schema recorded for a Vortex table.
///
/// # Errors
/// `DfOlapError::TableNotFound` when the table has no Vortex metadata.
async fn vortex_table_schema(&self, table_name: &str) -> Result<SchemaRef, DfOlapError> {
    let registry = self.vortex_tables.read().await;
    match registry.get(table_name) {
        Some(meta) => Ok(meta.schema.clone()),
        None => Err(DfOlapError::TableNotFound(table_name.to_string())),
    }
}
/// Reads every row of a Vortex table via `SELECT *` and returns it together
/// with the schema recorded in the table metadata.
///
/// # Errors
/// `DfOlapError::TableNotFound` when the table is unknown; otherwise any
/// error from planning or executing the query.
async fn read_all_batches_vortex(
    &self,
    table_name: &str,
) -> Result<(SchemaRef, Vec<RecordBatch>), DfOlapError> {
    let schema = self.vortex_table_schema(table_name).await?;

    let select_all = format!("SELECT * FROM \"{table_name}\"");
    let ctx = self.ctx.read().await;
    let frame = ctx.sql(&select_all).await?;
    let rows = frame.collect().await?;

    Ok((schema, rows))
}
/// Deletes every `.vortex` file stored for `table_name`.
///
/// * Local storage: removes `*.vortex` files directly inside
///   `base_path/table_name`. A missing directory is treated as "nothing to
///   clear". All filesystem work runs on the blocking thread pool.
/// * S3 storage (`cloud-storage` feature): delegates to `clear_s3_table_prefix`.
///
/// # Errors
/// * `DfOlapError::Other` when the engine has no Vortex location.
/// * An I/O error when the directory or one of its entries cannot be read, or a
///   file cannot be removed. Entry errors are propagated rather than skipped: a
///   skipped entry would leave a stale data file behind, which a subsequent
///   table rewrite would then read back alongside the new data.
/// * A join error (via `DfOlapError::from_join`) when the blocking task fails.
async fn clear_table_storage(&self, table_name: &str) -> Result<(), DfOlapError> {
    let loc = self
        .location()
        .ok_or_else(|| DfOlapError::Other("expected Vortex location".into()))?;
    match loc {
        VortexLocation::Local { base_path } => {
            let dir = base_path.join(table_name);
            tokio::task::spawn_blocking(move || -> Result<(), DfOlapError> {
                // Opening the directory doubles as the existence check, which
                // avoids a separate (blocking, racy) `exists()` call.
                let entries = match std::fs::read_dir(&dir) {
                    Ok(entries) => entries,
                    Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
                    Err(err) => return Err(DfOlapError::from(err)),
                };

                // Collect first, delete afterwards, so the directory is not
                // modified while it is being iterated.
                let mut stale = Vec::new();
                for entry in entries {
                    let path = entry?.path();
                    if path.extension().is_some_and(|ext| ext == "vortex") {
                        stale.push(path);
                    }
                }
                for path in stale {
                    std::fs::remove_file(path)?;
                }
                Ok(())
            })
            .await
            .map_err(DfOlapError::from_join)?
        }
        #[cfg(feature = "cloud-storage")]
        VortexLocation::S3 { url } => self.clear_s3_table_prefix(table_name, url).await,
    }
}
/// Deletes every object with a `.vortex` extension (case-insensitive) under the
/// table's prefix in the S3 bucket named by `url`.
///
/// The prefix is `<url path>/<table_name>/`, or `<table_name>/` when the URL
/// has no path. The object store is looked up in the session's runtime
/// environment under `s3://<bucket>/`.
///
/// # Errors
/// Fails when the URL cannot be parsed, the object store is not registered, or
/// listing/deleting objects fails.
#[cfg(feature = "cloud-storage")]
async fn clear_s3_table_prefix(&self, table_name: &str, url: &str) -> Result<(), DfOlapError> {
use futures::StreamExt;
#[allow(unused_imports)]
use object_store::{ObjectStore, ObjectStoreExt};
let bucket = parse_s3_bucket(url)?;
// Build the key prefix for this table from the URL path plus the table name.
let table_prefix = {
let parsed = Url::parse(url).map_err(DfOlapError::UrlParse)?;
let trimmed = parsed.path().trim_start_matches('/').trim_end_matches('/');
if trimmed.is_empty() {
format!("{table_name}/")
} else {
format!("{trimmed}/{table_name}/")
}
};
// Resolve the object store registered for this bucket.
let osu_str = format!("s3://{bucket}/");
let osu =
datafusion::execution::object_store::ObjectStoreUrl::parse(&osu_str).map_err(|e| {
DfOlapError::Other(format!("invalid object-store URL '{osu_str}': {e}"))
})?;
let store = self
.ctx
.read()
.await
.runtime_env()
.object_store(osu)
.map_err(|e| {
DfOlapError::Other(format!(
"object store for s3://{bucket} not registered: {e}"
))
})?;
let prefix = object_store::path::Path::from(table_prefix.as_str());
let mut list = store.list(Some(&prefix));
// Finish listing before deleting anything.
let mut to_delete: Vec<object_store::path::Path> = Vec::new();
while let Some(meta) = list.next().await {
let meta = meta.map_err(DfOlapError::ObjectStore)?;
if meta
.location
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("vortex"))
{
to_delete.push(meta.location);
}
}
// Objects are deleted one at a time; the first failure aborts the loop.
for path in to_delete {
store
.delete(&path)
.await
.map_err(DfOlapError::ObjectStore)?;
}
Ok(())
}
/// Appends `batches` to the Vortex table `table_name`.
///
/// The batches are exposed to DataFusion through a temporary `MemTable`
/// (`__tmp_load_<n>`, with `n` taken from `tmp_counter`) and copied with
/// `INSERT INTO "<table>" SELECT * FROM "<tmp>"`.
///
/// The temporary table is deregistered whether or not the INSERT succeeds, so
/// a failed load does not leave the batches pinned in the session under a
/// name nobody will reference again.
///
/// Returns the total number of rows in `batches` (0 for an empty slice, in
/// which case nothing is registered or executed).
///
/// # Errors
/// Any error from building/registering the temporary table or from planning
/// or executing the INSERT.
async fn insert_arrow_into_vortex(
    &self,
    table_name: &str,
    schema: &SchemaRef,
    batches: &[RecordBatch],
) -> Result<u64, DfOlapError> {
    if batches.is_empty() {
        return Ok(0);
    }
    let total_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
    let tmp_name = format!(
        "__tmp_load_{}",
        self.tmp_counter.fetch_add(1, Ordering::Relaxed)
    );
    let mem_table = MemTable::try_new(schema.clone(), vec![batches.to_vec()])?;

    let ctx = self.ctx.read().await;
    let _ = ctx.deregister_table(&tmp_name);
    ctx.register_table(&tmp_name, Arc::new(mem_table))?;

    let insert_sql = format!("INSERT INTO \"{table_name}\" SELECT * FROM \"{tmp_name}\"");
    // Capture the outcome instead of returning early, so cleanup always runs.
    let outcome = match ctx.sql(&insert_sql).await {
        Ok(frame) => frame.collect().await.map(|_| ()),
        Err(err) => Err(err),
    };
    let _ = ctx.deregister_table(&tmp_name);
    outcome?;

    Ok(total_rows)
}
/// Replaces the stored contents of a Vortex table with `batches`.
///
/// The table's files are cleared and the listing table is re-registered; when
/// `batches` is non-empty they are then inserted and the table is
/// re-registered once more.
async fn rewrite_vortex_table(
    &self,
    table_name: &str,
    schema: &SchemaRef,
    batches: &[RecordBatch],
) -> Result<(), DfOlapError> {
    self.clear_table_storage(table_name).await?;
    self.refresh_table_vortex(table_name).await?;

    if batches.is_empty() {
        return Ok(());
    }

    self.insert_arrow_into_vortex(table_name, schema, batches)
        .await?;
    self.refresh_table_vortex(table_name).await
}
/// Plans and runs `sql` in the session and collects all result batches.
async fn execute_sql(&self, sql: &str) -> Result<Vec<RecordBatch>, DfOlapError> {
    let session = self.ctx.read().await;
    let frame = session.sql(sql).await?;
    Ok(frame.collect().await?)
}
/// Executes an `INSERT ... VALUES` statement against an in-memory table.
///
/// The parsed rows are aligned to the table schema, appended to the stored
/// batches, and the table is re-registered. Returns the number of rows inserted.
///
/// # Errors
/// Parse errors, `DfOlapError::TableNotFound`, alignment errors, or any error
/// from refreshing the table.
async fn execute_insert_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
    let (table_name, col_names, batches) = parse_insert_values(sql)?;

    let inserted = {
        let mut registry = self.tables.write().await;
        let Some(entry) = registry.get_mut(&table_name) else {
            return Err(DfOlapError::TableNotFound(table_name.clone()));
        };
        let target_schema = entry.schema.clone();
        let (aligned, row_count) = align_batches_to_schema(&target_schema, &col_names, &batches)?;
        entry.batches.extend(aligned);
        row_count
        // The write lock is released here, before the refresh takes its own locks.
    };

    self.refresh_table_mem(&table_name).await?;
    Ok(inserted)
}
/// Executes an `UPDATE` statement against an in-memory table.
///
/// When `flatten_batches` yields rows, the update is applied to them and the
/// result replaces the stored batches. The table is re-registered in either
/// case. Returns the count reported by `apply_update` (0 when there were no rows).
async fn execute_update_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
    let (table_name, assignments, where_clause) = parse_update(sql)?;

    let updated_count = {
        let mut registry = self.tables.write().await;
        let entry = registry
            .get_mut(&table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
        let schema = entry.schema.clone();
        let flattened = flatten_batches(&entry.batches, &schema)?;
        match flattened {
            Some(all_rows) => {
                let (rewritten, count) =
                    apply_update(&all_rows, &schema, &assignments, &where_clause)?;
                entry.batches = vec![rewritten];
                count
            }
            None => 0u64,
        }
        // The write lock is released here, before the refresh takes its own locks.
    };

    self.refresh_table_mem(&table_name).await?;
    Ok(updated_count)
}
/// Executes a `DELETE` statement against an in-memory table.
///
/// When `flatten_batches` yields no rows, nothing is changed and 0 is
/// returned. Otherwise the surviving rows replace the stored batches (an empty
/// result clears them), the table is re-registered, and the count reported by
/// `apply_delete` is returned.
async fn execute_delete_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
    let (table_name, where_clause) = parse_delete(sql)?;

    let mut registry = self.tables.write().await;
    let entry = registry
        .get_mut(&table_name)
        .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
    let schema = entry.schema.clone();

    let Some(all_rows) = flatten_batches(&entry.batches, &schema)? else {
        return Ok(0);
    };

    let (survivors, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
    entry.batches = match survivors.num_rows() {
        0 => Vec::new(),
        _ => vec![survivors],
    };

    // Release the write lock before the refresh takes its own locks.
    drop(registry);
    self.refresh_table_mem(&table_name).await?;
    Ok(deleted_count)
}
/// Executes an `INSERT ... VALUES` statement against a Vortex table.
///
/// The parsed rows are aligned to the recorded table schema, written through
/// `insert_arrow_into_vortex`, and the listing table is re-registered.
/// Returns the number of aligned rows.
async fn execute_insert_vortex(&self, sql: &str) -> Result<u64, DfOlapError> {
    let (table_name, col_names, parsed) = parse_insert_values(sql)?;
    let schema = self.vortex_table_schema(&table_name).await?;

    let (aligned, row_count) = align_batches_to_schema(&schema, &col_names, &parsed)?;
    self.insert_arrow_into_vortex(&table_name, &schema, &aligned)
        .await?;
    self.refresh_table_vortex(&table_name).await?;

    Ok(row_count)
}
/// Executes an `UPDATE` statement against a Vortex table by reading all rows,
/// applying the update, and rewriting the table's storage.
///
/// Returns the count reported by `apply_update`, or 0 (without rewriting)
/// when `flatten_batches` yields no rows.
async fn execute_update_vortex(&self, sql: &str) -> Result<u64, DfOlapError> {
    let (table_name, assignments, where_clause) = parse_update(sql)?;
    let (schema, existing) = self.read_all_batches_vortex(&table_name).await?;

    let Some(all_rows) = flatten_batches(&existing, &schema)? else {
        return Ok(0);
    };

    let (rewritten, count) = apply_update(&all_rows, &schema, &assignments, &where_clause)?;
    let replacement = match rewritten.num_rows() {
        0 => Vec::new(),
        _ => vec![rewritten],
    };
    self.rewrite_vortex_table(&table_name, &schema, &replacement)
        .await?;

    Ok(count)
}
/// Executes a `DELETE` statement against a Vortex table by reading all rows,
/// filtering them, and rewriting the table's storage with the survivors.
///
/// Returns the count reported by `apply_delete`, or 0 (without rewriting)
/// when `flatten_batches` yields no rows.
async fn execute_delete_vortex(&self, sql: &str) -> Result<u64, DfOlapError> {
    let (table_name, where_clause) = parse_delete(sql)?;
    let (schema, existing) = self.read_all_batches_vortex(&table_name).await?;

    let Some(all_rows) = flatten_batches(&existing, &schema)? else {
        return Ok(0);
    };

    let (survivors, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
    let replacement = match survivors.num_rows() {
        0 => Vec::new(),
        _ => vec![survivors],
    };
    self.rewrite_vortex_table(&table_name, &schema, &replacement)
        .await?;

    Ok(deleted_count)
}
/// Runs an `INSERT` statement using the implementation for the configured storage mode.
async fn execute_insert(&self, sql: &str) -> Result<u64, DfOlapError> {
match &self.storage_mode {
StorageMode::InMemory => self.execute_insert_mem(sql).await,
StorageMode::Vortex { .. } => self.execute_insert_vortex(sql).await,
}
}
/// Runs an `UPDATE` statement using the implementation for the configured storage mode.
async fn execute_update(&self, sql: &str) -> Result<u64, DfOlapError> {
match &self.storage_mode {
StorageMode::InMemory => self.execute_update_mem(sql).await,
StorageMode::Vortex { .. } => self.execute_update_vortex(sql).await,
}
}
/// Runs a `DELETE` statement using the implementation for the configured storage mode.
async fn execute_delete(&self, sql: &str) -> Result<u64, DfOlapError> {
match &self.storage_mode {
StorageMode::InMemory => self.execute_delete_mem(sql).await,
StorageMode::Vortex { .. } => self.execute_delete_vortex(sql).await,
}
}
}
/// The default engine is the in-memory engine returned by `DataFusionEngine::new`.
impl Default for DataFusionEngine {
fn default() -> Self {
Self::new()
}
}
impl rhei_core::OlapEngine for DataFusionEngine {
type Error = DfOlapError;
/// Runs `sql` and returns all result batches, logging the statement at debug level.
async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
debug!(sql, "DataFusion query");
self.execute_sql(sql).await
}
/// Runs `sql` and returns the results as a boxed stream of record batches.
///
/// The DataFusion stream is wrapped in `StreamAdapter`, which boxes each
/// error item.
async fn query_stream(
    &self,
    sql: &str,
) -> Result<rhei_core::RecordBatchBoxStream, Self::Error> {
    debug!(sql, "DataFusion query_stream");
    let session = self.ctx.read().await;
    let frame = session.sql(sql).await?;
    let adapter = StreamAdapter(frame.execute_stream().await?);
    let boxed = Box::pin(adapter);
    Ok(boxed)
}
/// Executes a statement and returns the number of affected rows.
///
/// The statement is trimmed and routed by its leading keyword, compared
/// case-insensitively:
/// * `INSERT` / `UPDATE` / `DELETE` go to the engine's own implementations.
/// * `BEGIN` / `COMMIT` / `ROLLBACK` are accepted as no-ops returning 0.
/// * Anything else is run through DataFusion and reports 0.
async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
    debug!(sql, "DataFusion execute");
    let trimmed = sql.trim();
    let upper = trimmed.to_ascii_uppercase();

    if upper.starts_with("INSERT") {
        return self.execute_insert(trimmed).await;
    }
    if upper.starts_with("UPDATE") {
        return self.execute_update(trimmed).await;
    }
    if upper.starts_with("DELETE") {
        return self.execute_delete(trimmed).await;
    }

    let is_transaction_control = ["BEGIN", "COMMIT", "ROLLBACK"]
        .iter()
        .any(|keyword| upper.starts_with(*keyword));
    if is_transaction_control {
        return Ok(0);
    }

    let session = self.ctx.read().await;
    let frame = session.sql(trimmed).await?;
    let _ = frame.collect().await?;
    Ok(0)
}
/// Appends Arrow record batches to `table` and returns the number of rows loaded.
///
/// An empty `batches` slice is a no-op returning 0 (the table name is not
/// validated in that case).
///
/// In in-memory mode every batch is checked against the table schema *before*
/// anything is stored. Otherwise a single mismatching batch would be appended
/// first and then make this and every later rebuild of the `MemTable` fail,
/// leaving the table unusable. The check is `Schema::contains` with the table
/// schema as the receiver.
///
/// In Vortex mode the batches are written through `insert_arrow_into_vortex`
/// and the listing table is re-registered.
///
/// # Errors
/// * `DfOlapError::Other` when `table` is not a valid identifier.
/// * `DfOlapError::TableNotFound` when the table does not exist.
/// * `DfOlapError::SchemaMismatch` when a batch does not fit the in-memory
///   table's schema.
/// * Any error from refreshing or inserting.
async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
    if batches.is_empty() {
        return Ok(0);
    }
    debug!(table, batch_count = batches.len(), "DataFusion load_arrow");
    rhei_core::validate_identifier(table).map_err(|e| DfOlapError::Other(e.to_string()))?;
    let total_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
    match &self.storage_mode {
        StorageMode::InMemory => {
            let mut tables = self.tables.write().await;
            let table_data = tables
                .get_mut(table)
                .ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;

            // Validate up front so a bad batch never reaches the stored state.
            let mismatch = batches
                .iter()
                .position(|batch| !table_data.schema.contains(&batch.schema()));
            if let Some(pos) = mismatch {
                return Err(DfOlapError::SchemaMismatch(format!(
                    "batch {pos} does not match the schema of table '{table}'"
                )));
            }

            table_data.batches.extend(batches.iter().cloned());
            drop(tables);
            self.refresh_table_mem(table).await?;
        }
        StorageMode::Vortex { .. } => {
            let schema = self.vortex_table_schema(table).await?;
            self.insert_arrow_into_vortex(table, &schema, batches)
                .await?;
            self.refresh_table_vortex(table).await?;
        }
    }
    Ok(total_rows)
}
/// Creates `table_name` with the given schema if it does not exist yet.
///
/// The table name and every field name must pass `rhei_core::validate_identifier`.
/// Creating a table that already exists is a no-op. The primary key is ignored.
///
/// * In-memory mode: records an empty table and registers its `MemTable`.
/// * Vortex mode: creates the table directory for local storage, registers the
///   listing table, and only then records the table metadata. Recording after a
///   successful registration means a failed registration cannot leave a
///   phantom entry that later calls would treat as "already exists".
///
/// # Errors
/// * `DfOlapError::Other` for invalid identifiers, or when Vortex mode has no
///   resolved location (the same error `clear_table_storage` reports).
/// * I/O errors from creating the table directory.
/// * Any error from registering the table.
async fn create_table(
    &self,
    table_name: &str,
    schema: &SchemaRef,
    _primary_key: &[String],
) -> Result<(), Self::Error> {
    rhei_core::validate_identifier(table_name)
        .map_err(|e| DfOlapError::Other(e.to_string()))?;
    for field in schema.fields() {
        rhei_core::validate_identifier(field.name())
            .map_err(|e| DfOlapError::Other(e.to_string()))?;
    }
    debug!(
        table = table_name,
        storage = ?self.storage_mode,
        "DataFusion create_table"
    );
    match &self.storage_mode {
        StorageMode::InMemory => {
            let mut tables = self.tables.write().await;
            if tables.contains_key(table_name) {
                return Ok(());
            }
            tables.insert(
                table_name.to_string(),
                TableData {
                    schema: schema.clone(),
                    batches: vec![],
                },
            );
            drop(tables);
            self.refresh_table_mem(table_name).await?;
        }
        StorageMode::Vortex { .. } => {
            let loc = self
                .location()
                .ok_or_else(|| DfOlapError::Other("expected Vortex location".into()))?;
            {
                let vortex_tables = self.vortex_tables.read().await;
                if vortex_tables.contains_key(table_name) {
                    return Ok(());
                }
            }
            let listing_url = table_listing_url(loc, table_name);

            match loc {
                VortexLocation::Local { base_path } => {
                    tokio::fs::create_dir_all(base_path.join(table_name)).await?;
                }
                // Remote locations have no directory to create.
                #[cfg(feature = "cloud-storage")]
                VortexLocation::S3 { .. } => {}
            }

            // Register first; the metadata entry is only written on success.
            {
                let ctx = self.ctx.read().await;
                register_vortex_listing_table(&ctx, table_name, schema, &listing_url).await?;
            }

            let mut vortex_tables = self.vortex_tables.write().await;
            vortex_tables.insert(
                table_name.to_string(),
                VortexTableMeta {
                    schema: schema.clone(),
                    table_url: listing_url,
                },
            );
        }
    }
    Ok(())
}
/// Reports whether `table_name` is known to the engine, consulting the table
/// map that belongs to the configured storage mode.
async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
    let known = match &self.storage_mode {
        StorageMode::InMemory => self.tables.read().await.contains_key(table_name),
        StorageMode::Vortex { .. } => self.vortex_tables.read().await.contains_key(table_name),
    };
    Ok(known)
}
/// Appends a nullable column of `data_type` to `table_name`; existing rows get
/// NULL in the new column.
///
/// * In-memory mode: schema and batches are replaced in place and the table is
///   re-registered.
/// * Vortex mode: all rows are read, widened, the recorded schema is updated,
///   and the table's storage is rewritten with the widened rows.
///
/// # Errors
/// `DfOlapError::Other` for invalid identifiers, `DfOlapError::TableNotFound`
/// for unknown tables, and any error from rebuilding batches or storage.
async fn add_column(
    &self,
    table_name: &str,
    column_name: &str,
    data_type: &DataType,
) -> Result<(), Self::Error> {
    for ident in [table_name, column_name] {
        rhei_core::validate_identifier(ident).map_err(|e| DfOlapError::Other(e.to_string()))?;
    }
    debug!(
        table = table_name,
        column = column_name,
        "DataFusion add_column"
    );

    match &self.storage_mode {
        StorageMode::InMemory => {
            {
                let mut registry = self.tables.write().await;
                let entry = registry
                    .get_mut(table_name)
                    .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
                let widened_schema = append_field(&entry.schema, column_name, data_type);
                let widened_batches =
                    extend_batches_with_null_column(&entry.batches, &widened_schema, data_type)?;
                entry.schema = widened_schema;
                entry.batches = widened_batches;
            }
            self.refresh_table_mem(table_name).await?;
        }
        StorageMode::Vortex { .. } => {
            let (old_schema, existing) = self.read_all_batches_vortex(table_name).await?;
            let widened_schema = append_field(&old_schema, column_name, data_type);
            let widened_batches =
                extend_batches_with_null_column(&existing, &widened_schema, data_type)?;
            {
                let mut registry = self.vortex_tables.write().await;
                if let Some(meta) = registry.get_mut(table_name) {
                    meta.schema = widened_schema.clone();
                }
            }
            // Clear, re-register, and (for non-empty data) insert + re-register.
            self.rewrite_vortex_table(table_name, &widened_schema, &widened_batches)
                .await?;
        }
    }
    Ok(())
}
/// Removes `column_name` from `table_name`, dropping the column's data.
///
/// * In-memory mode: schema and batches are replaced in place and the table is
///   re-registered.
/// * Vortex mode: all rows are read, narrowed, the recorded schema is updated,
///   and the table's storage is rewritten with the narrowed rows.
///
/// # Errors
/// `DfOlapError::Other` for invalid identifiers or an unknown column,
/// `DfOlapError::TableNotFound` for unknown tables, and any error from
/// rebuilding batches or storage.
async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
    for ident in [table_name, column_name] {
        rhei_core::validate_identifier(ident).map_err(|e| DfOlapError::Other(e.to_string()))?;
    }
    debug!(
        table = table_name,
        column = column_name,
        "DataFusion drop_column"
    );

    match &self.storage_mode {
        StorageMode::InMemory => {
            {
                let mut registry = self.tables.write().await;
                let entry = registry
                    .get_mut(table_name)
                    .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
                let position = find_column_index(&entry.schema, column_name, table_name)?;
                let narrowed_schema = remove_field(&entry.schema, position);
                let narrowed_batches =
                    remove_column_from_batches(&entry.batches, &narrowed_schema, position)?;
                entry.schema = narrowed_schema;
                entry.batches = narrowed_batches;
            }
            self.refresh_table_mem(table_name).await?;
        }
        StorageMode::Vortex { .. } => {
            let (old_schema, existing) = self.read_all_batches_vortex(table_name).await?;
            let position = find_column_index(&old_schema, column_name, table_name)?;
            let narrowed_schema = remove_field(&old_schema, position);
            let narrowed_batches =
                remove_column_from_batches(&existing, &narrowed_schema, position)?;
            {
                let mut registry = self.vortex_tables.write().await;
                if let Some(meta) = registry.get_mut(table_name) {
                    meta.schema = narrowed_schema.clone();
                }
            }
            // Clear, re-register, and (for non-empty data) insert + re-register.
            self.rewrite_vortex_table(table_name, &narrowed_schema, &narrowed_batches)
                .await?;
        }
    }
    Ok(())
}
}
/// Returns a new schema made of `schema`'s fields followed by one nullable
/// field named `column_name` of type `data_type`.
fn append_field(schema: &SchemaRef, column_name: &str, data_type: &DataType) -> SchemaRef {
    let extra = arrow::datatypes::Field::new(column_name, data_type.clone(), true);
    let fields: Vec<arrow::datatypes::Field> = schema
        .fields()
        .iter()
        .map(|existing| existing.as_ref().clone())
        .chain(std::iter::once(extra))
        .collect();
    Arc::new(arrow::datatypes::Schema::new(fields))
}
/// Returns a new schema with the field at `col_idx` left out. An index past
/// the last field leaves the field list unchanged.
fn remove_field(schema: &SchemaRef, col_idx: usize) -> SchemaRef {
    let mut fields: Vec<arrow::datatypes::Field> = schema
        .fields()
        .iter()
        .map(|existing| existing.as_ref().clone())
        .collect();
    if col_idx < fields.len() {
        fields.remove(col_idx);
    }
    Arc::new(arrow::datatypes::Schema::new(fields))
}
/// Returns the position of the first field named exactly `column_name`.
///
/// # Errors
/// `DfOlapError::Other` naming the column and table when no field matches.
fn find_column_index(
    schema: &SchemaRef,
    column_name: &str,
    table_name: &str,
) -> Result<usize, DfOlapError> {
    for (idx, field) in schema.fields().iter().enumerate() {
        if field.name() == column_name {
            return Ok(idx);
        }
    }
    Err(DfOlapError::Other(format!(
        "column '{}' not found in table '{}'",
        column_name, table_name
    )))
}
fn extend_batches_with_null_column(
batches: &[RecordBatch],
new_schema: &SchemaRef,
data_type: &DataType,
) -> Result<Vec<RecordBatch>, DfOlapError> {
let mut new_batches = Vec::with_capacity(batches.len());
for batch in batches {
let null_array = arrow::array::new_null_array(data_type, batch.num_rows());
let mut columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
.map(|i| batch.column(i).clone())
.collect();
columns.push(null_array);
new_batches.push(RecordBatch::try_new(new_schema.clone(), columns)?);
}
Ok(new_batches)
}
fn remove_column_from_batches(
batches: &[RecordBatch],
new_schema: &SchemaRef,
col_idx: usize,
) -> Result<Vec<RecordBatch>, DfOlapError> {
let mut new_batches = Vec::with_capacity(batches.len());
for batch in batches {
let columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
.filter(|i| *i != col_idx)
.map(|i| batch.column(i).clone())
.collect();
new_batches.push(RecordBatch::try_new(new_schema.clone(), columns)?);
}
Ok(new_batches)
}
/// Reorders and casts the columns of each batch so that it matches `table_schema`.
///
/// `col_names[i]` names the i-th column of every input batch. For each table
/// field the same-named input column is picked and cast when its type differs
/// from the field's type. Returns the aligned batches and their total row count.
///
/// # Errors
/// * `DfOlapError::SchemaMismatch` when a table field is absent from `col_names`
///   (only detected when there is at least one batch).
/// * Any error from casting or from building the aligned batch.
fn align_batches_to_schema(
    table_schema: &SchemaRef,
    col_names: &[String],
    batches: &[RecordBatch],
) -> Result<(Vec<RecordBatch>, u64), DfOlapError> {
    let mut aligned_batches = Vec::with_capacity(batches.len());
    let mut total_rows = 0u64;

    for batch in batches {
        let columns = table_schema
            .fields()
            .iter()
            .map(|field| -> Result<Arc<dyn Array>, DfOlapError> {
                let source_idx = col_names
                    .iter()
                    .position(|c| c == field.name())
                    .ok_or_else(|| {
                        DfOlapError::SchemaMismatch(format!(
                            "column '{}' not in INSERT column list",
                            field.name()
                        ))
                    })?;
                let source = batch.column(source_idx);
                if source.data_type() == field.data_type() {
                    Ok(source.clone())
                } else {
                    Ok(arrow::compute::cast(source, field.data_type())?)
                }
            })
            .collect::<Result<Vec<_>, _>>()?;

        let aligned = RecordBatch::try_new(table_schema.clone(), columns)?;
        total_rows += aligned.num_rows() as u64;
        aligned_batches.push(aligned);
    }

    Ok((aligned_batches, total_rows))
}
/// Wraps a DataFusion record-batch stream so that its error items are boxed
/// as `Box<dyn std::error::Error + Send + Sync>`.
struct StreamAdapter(datafusion::physical_plan::SendableRecordBatchStream);
impl futures_core::Stream for StreamAdapter {
    type Item = Result<RecordBatch, Box<dyn std::error::Error + Send + Sync>>;

    /// Polls the wrapped stream and forwards its state, boxing any error item.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        match std::pin::Pin::new(&mut self.0).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(Ok(batch))),
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(
                Box::new(err) as Box<dyn std::error::Error + Send + Sync>
            ))),
        }
    }
}
/// Cloneable handle to a `DataFusionEngine` held in an `Arc`; clones refer to
/// the same engine.
#[derive(Clone)]
pub struct SharedDataFusionEngine(pub Arc<DataFusionEngine>);
impl SharedDataFusionEngine {
/// Wraps `engine` in an `Arc` so it can be shared by cloning the handle.
pub fn new(engine: DataFusionEngine) -> Self {
Self(Arc::new(engine))
}
}
/// Gives access to the wrapped `DataFusionEngine` through the handle.
impl std::ops::Deref for SharedDataFusionEngine {
type Target = DataFusionEngine;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl rhei_core::OlapEngine for SharedDataFusionEngine {
type Error = DfOlapError;
async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
self.0.query(sql).await
}
async fn query_stream(
&self,
sql: &str,
) -> Result<rhei_core::RecordBatchBoxStream, Self::Error> {
self.0.query_stream(sql).await
}
async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
self.0.execute(sql).await
}
async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
self.0.load_arrow(table, batches).await
}
async fn create_table(
&self,
table_name: &str,
schema: &SchemaRef,
primary_key: &[String],
) -> Result<(), Self::Error> {
self.0.create_table(table_name, schema, primary_key).await
}
async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
self.0.table_exists(table_name).await
}
async fn add_column(
&self,
table_name: &str,
column_name: &str,
data_type: &DataType,
) -> Result<(), Self::Error> {
self.0.add_column(table_name, column_name, data_type).await
}
async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
self.0.drop_column(table_name, column_name).await
}
}
/// Renders a literal expression as SQL literal text.
///
/// Supported forms: numbers (text kept as parsed), single-quoted strings
/// (re-quoted with embedded `'` doubled), booleans (`TRUE` / `FALSE`), `NULL`,
/// and a unary minus applied directly to a number.
///
/// # Errors
/// `DfOlapError::Other` for any other value or expression.
fn expr_to_sql_literal(expr: &Expr) -> Result<String, DfOlapError> {
    match expr {
        Expr::Value(v) => match &v.value {
            Value::Number(digits, _) => Ok(digits.clone()),
            Value::SingleQuotedString(text) => {
                let escaped = text.replace('\'', "''");
                Ok(format!("'{}'", escaped))
            }
            Value::Boolean(true) => Ok("TRUE".into()),
            Value::Boolean(false) => Ok("FALSE".into()),
            Value::Null => Ok("NULL".into()),
            other => Err(DfOlapError::Other(format!(
                "unsupported value literal: {other:?}"
            ))),
        },
        Expr::UnaryOp {
            op: UnaryOperator::Minus,
            expr: inner,
        } => {
            // Only `-<number>` is accepted; anything else under the minus is rejected.
            let negated = match inner.as_ref() {
                Expr::Value(v) => match &v.value {
                    Value::Number(n, _) => Some(format!("-{n}")),
                    _ => None,
                },
                _ => None,
            };
            negated.ok_or_else(|| {
                DfOlapError::Other(format!("unsupported unary expression: {expr}"))
            })
        }
        other => Err(DfOlapError::Other(format!(
            "unsupported expression in VALUES: {other}"
        ))),
    }
}
/// Extracts a column name from an identifier expression. For a compound
/// identifier (`a.b.c`) the last part is used.
///
/// # Errors
/// `DfOlapError::Other` for an empty compound identifier or any other
/// expression kind.
fn ident_from_expr(expr: &Expr) -> Result<String, DfOlapError> {
    let name = match expr {
        Expr::Identifier(ident) => &ident.value,
        Expr::CompoundIdentifier(parts) => match parts.last() {
            Some(last) => &last.value,
            None => return Err(DfOlapError::Other("empty compound identifier".into())),
        },
        other => {
            return Err(DfOlapError::Other(format!(
                "expected column name, got: {other}"
            )));
        }
    };
    Ok(name.clone())
}
/// Flattens a WHERE expression into `(column, value)` pairs, in left-to-right order.
///
/// Supported shapes are `AND` combinations (optionally parenthesised) of:
/// * `col = <literal>`: the value is the rendered literal;
/// * `col IS NULL`: the value is `"NULL"`;
/// * `col IS NOT NULL`: the value is the marker `"__IS_NOT_NULL__"`.
///
/// # Errors
/// `DfOlapError::Other` for any other expression, or when a column or literal
/// cannot be extracted.
fn extract_where_conditions(expr: &Expr) -> Result<Vec<(String, String)>, DfOlapError> {
    /// Appends the conditions found in `node` to `out`.
    fn walk(node: &Expr, out: &mut Vec<(String, String)>) -> Result<(), DfOlapError> {
        match node {
            Expr::BinaryOp {
                left,
                op: BinaryOperator::And,
                right,
            } => {
                walk(left, out)?;
                walk(right, out)
            }
            Expr::BinaryOp {
                left,
                op: BinaryOperator::Eq,
                right,
            } => {
                let column = ident_from_expr(left)?;
                let literal = expr_to_sql_literal(right)?;
                out.push((column, literal));
                Ok(())
            }
            Expr::IsNull(inner) => {
                out.push((ident_from_expr(inner)?, "NULL".into()));
                Ok(())
            }
            Expr::IsNotNull(inner) => {
                out.push((ident_from_expr(inner)?, "__IS_NOT_NULL__".into()));
                Ok(())
            }
            Expr::Nested(inner) => walk(inner, out),
            other => Err(DfOlapError::Other(format!(
                "unsupported WHERE expression: {other}"
            ))),
        }
    }

    let mut conditions = Vec::new();
    walk(expr, &mut conditions)?;
    Ok(conditions)
}
/// Parses an `INSERT INTO <table> (<cols>) VALUES (...), ...` statement.
///
/// Returns the table name (last part of a qualified name), the listed column
/// names, and the rows as a single record batch whose column types are
/// inferred from the literals. An INSERT without a source, or with an empty
/// VALUES list, yields no batches.
///
/// Every row must supply exactly one value per listed column. Previously a
/// short row was silently padded with NULLs and surplus values were silently
/// dropped, so malformed statements inserted data the caller never wrote.
///
/// # Errors
/// * `DfOlapError::Other` when the SQL does not parse, is not an INSERT, targets
///   a table function, has an invalid table name, uses a non-VALUES source,
///   contains an unsupported literal, or has no explicit column list.
/// * `DfOlapError::SchemaMismatch` when a row's value count differs from the
///   number of listed columns.
fn parse_insert_values(sql: &str) -> Result<(String, Vec<String>, Vec<RecordBatch>), DfOlapError> {
    let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
        .map_err(|e| DfOlapError::Other(format!("failed to parse INSERT: {e}")))?;
    let stmt = stmts
        .pop()
        .ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;
    let insert = match stmt {
        Statement::Insert(ins) => ins,
        other => {
            return Err(DfOlapError::Other(format!(
                "expected INSERT statement, got: {other:?}"
            )));
        }
    };
    let table_name = match &insert.table {
        TableObject::TableName(obj_name) => obj_name
            .0
            .last()
            .and_then(|p| p.as_ident())
            .map(|id| id.value.clone())
            .ok_or_else(|| DfOlapError::Other("empty table name in INSERT".into()))?,
        TableObject::TableFunction(_) => {
            return Err(DfOlapError::Other(
                "INSERT INTO TABLE FUNCTION not supported".into(),
            ));
        }
    };
    rhei_core::validate_identifier(&table_name).map_err(|e| DfOlapError::Other(e.to_string()))?;
    let col_name_strings: Vec<String> = insert.columns.iter().map(|id| id.value.clone()).collect();

    let Some(source) = insert.source else {
        return Ok((table_name, col_name_strings, vec![]));
    };
    let values = match *source.body {
        SetExpr::Values(v) => v,
        other => {
            return Err(DfOlapError::Other(format!(
                "INSERT source is not a VALUES clause: {other:?}"
            )));
        }
    };
    if values.rows.is_empty() {
        return Ok((table_name, col_name_strings, vec![]));
    }

    // Render every literal to its SQL text; type inference happens later.
    let rows: Vec<Vec<String>> = values
        .rows
        .iter()
        .map(|row| {
            row.iter()
                .map(expr_to_sql_literal)
                .collect::<Result<_, _>>()
        })
        .collect::<Result<_, _>>()?;

    let col_name_refs: Vec<&str> = col_name_strings.iter().map(|s| s.as_str()).collect();
    let num_cols = col_name_refs.len();
    if num_cols == 0 {
        return Err(DfOlapError::Other(format!(
            "INSERT INTO {table_name} requires an explicit column list; `VALUES (...)` without columns is not supported"
        )));
    }

    // Reject rows whose arity does not match the column list.
    if let Some((row_idx, row)) = rows
        .iter()
        .enumerate()
        .find(|(_, row)| row.len() != num_cols)
    {
        return Err(DfOlapError::SchemaMismatch(format!(
            "INSERT INTO {table_name}: row {} has {} value(s) but {} column(s) were listed",
            row_idx + 1,
            row.len(),
            num_cols
        )));
    }

    let batch = build_record_batch_from_values(&col_name_refs, &rows, num_cols)?;
    Ok((table_name, col_name_strings, vec![batch]))
}
fn build_record_batch_from_values(
col_names: &[&str],
rows: &[Vec<String>],
num_cols: usize,
) -> Result<RecordBatch, DfOlapError> {
use arrow::array::*;
use arrow::datatypes::{Field, Schema};
let mut types = vec![DataType::Utf8; num_cols];
for col_idx in 0..num_cols {
for row in rows {
if col_idx < row.len() {
let val = &row[col_idx];
let upper = val.to_ascii_uppercase();
if upper == "NULL" {
continue;
}
if upper == "TRUE" || upper == "FALSE" {
types[col_idx] = DataType::Boolean;
break;
}
if val.starts_with('\'') {
types[col_idx] = DataType::Utf8;
break;
}
if val.contains('.') {
if val.parse::<f64>().is_ok() {
types[col_idx] = DataType::Float64;
break;
}
} else if val.parse::<i64>().is_ok() {
types[col_idx] = DataType::Int64;
break;
}
break;
}
}
}
let fields: Vec<Field> = col_names
.iter()
.zip(types.iter())
.map(|(name, dt)| Field::new(*name, dt.clone(), true))
.collect();
let schema = Arc::new(Schema::new(fields));
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(num_cols);
for col_idx in 0..num_cols {
let col_values: Vec<&str> = rows
.iter()
.map(|row| {
if col_idx < row.len() {
row[col_idx].as_str()
} else {
"NULL"
}
})
.collect();
columns.push(build_array(&types[col_idx], &col_values)?);
}
let batch = RecordBatch::try_new(schema, columns)?;
Ok(batch)
}
/// Builds an Arrow array of type `dt` from SQL literal text.
///
/// `NULL` (any case) becomes a null entry for every type.
/// * Int64 / Float64: each value is parsed with `str::parse`.
/// * Boolean: `TRUE` (any case) is true, every other non-NULL value is false.
/// * Any other type: a string array is produced; one pair of enclosing single
///   quotes is stripped and doubled quotes (`''`) are collapsed to `'`.
///
/// # Errors
/// `DfOlapError::Other` when a numeric value fails to parse.
fn build_array(dt: &DataType, values: &[&str]) -> Result<Arc<dyn Array>, DfOlapError> {
    use arrow::array::*;

    let is_null = |raw: &str| raw.eq_ignore_ascii_case("NULL");

    match dt {
        DataType::Int64 => {
            let mut builder = Int64Builder::new();
            for &raw in values {
                if is_null(raw) {
                    builder.append_null();
                    continue;
                }
                let parsed = raw
                    .parse::<i64>()
                    .map_err(|e| DfOlapError::Other(format!("parse i64: {e}")))?;
                builder.append_value(parsed);
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Float64 => {
            let mut builder = Float64Builder::new();
            for &raw in values {
                if is_null(raw) {
                    builder.append_null();
                    continue;
                }
                let parsed = raw
                    .parse::<f64>()
                    .map_err(|e| DfOlapError::Other(format!("parse f64: {e}")))?;
                builder.append_value(parsed);
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Boolean => {
            let mut builder = BooleanBuilder::new();
            for &raw in values {
                if is_null(raw) {
                    builder.append_null();
                } else {
                    builder.append_value(raw.eq_ignore_ascii_case("TRUE"));
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        _ => {
            let mut builder = StringBuilder::new();
            for &raw in values {
                if is_null(raw) {
                    builder.append_null();
                    continue;
                }
                // Strip one pair of enclosing quotes; a lone `'` stays as is.
                let unquoted = raw
                    .strip_prefix('\'')
                    .and_then(|rest| rest.strip_suffix('\''))
                    .unwrap_or(raw);
                builder.append_value(unquoted.replace("''", "'"));
            }
            Ok(Arc::new(builder.finish()))
        }
    }
}
/// A `(column name, textual SQL literal)` pair, used for both `SET`
/// assignments and `WHERE` conditions.
type ColVal = (String, String);

/// Parses an `UPDATE` statement into `(table, assignments, where_conditions)`.
///
/// The SQL is parsed with the SQLite dialect and the last parsed statement is
/// inspected. Only the final identifier of the table name and of each assigned
/// column is kept. Without a `WHERE` clause the condition list is empty.
///
/// # Errors
///
/// Returns `DfOlapError::Other` when the SQL does not parse, is empty, is not
/// an `UPDATE`, targets something other than a plain table, or uses a tuple
/// assignment; errors from `expr_to_sql_literal` and
/// `extract_where_conditions` are propagated.
fn parse_update(sql: &str) -> Result<(String, Vec<ColVal>, Vec<ColVal>), DfOlapError> {
    let statements = Parser::parse_sql(&SQLiteDialect {}, sql)
        .map_err(|e| DfOlapError::Other(format!("failed to parse UPDATE: {e}")))?;

    let update = match statements.into_iter().next_back() {
        None => return Err(DfOlapError::Other("empty SQL statement".into())),
        Some(Statement::Update(upd)) => upd,
        Some(other) => {
            return Err(DfOlapError::Other(format!(
                "expected UPDATE statement, got: {other:?}"
            )));
        }
    };

    let relation = &update.table.relation;
    let TableFactor::Table { name, .. } = relation else {
        return Err(DfOlapError::Other(format!(
            "unexpected table factor in UPDATE: {relation:?}"
        )));
    };
    let table_name = name
        .0
        .last()
        .and_then(|part| part.as_ident())
        .map(|ident| ident.value.clone())
        .ok_or_else(|| DfOlapError::Other("empty table name in UPDATE".into()))?;

    let mut assignments: Vec<ColVal> = Vec::new();
    for assignment in update.assignments.iter() {
        let column = match &assignment.target {
            AssignmentTarget::Tuple(_) => {
                return Err(DfOlapError::Other(
                    "tuple assignments in SET not supported".into(),
                ));
            }
            AssignmentTarget::ColumnName(obj) => obj
                .0
                .last()
                .and_then(|part| part.as_ident())
                .map(|ident| ident.value.clone())
                .ok_or_else(|| DfOlapError::Other("empty column name in SET".into()))?,
        };
        let literal = expr_to_sql_literal(&assignment.value)?;
        assignments.push((column, literal));
    }

    let where_clause = if let Some(expr) = &update.selection {
        extract_where_conditions(expr)?
    } else {
        vec![]
    };

    Ok((table_name, assignments, where_clause))
}
/// Parses a `DELETE` statement into `(table, where_conditions)`.
///
/// The SQL is parsed with the SQLite dialect and the last parsed statement is
/// inspected. The table name is the final identifier of the first `FROM`
/// entry. Without a `WHERE` clause the condition list is empty.
///
/// # Errors
///
/// Returns `DfOlapError::Other` when the SQL does not parse, is empty, is not
/// a `DELETE`, or has no plain table as its first target; errors from
/// `extract_where_conditions` are propagated.
fn parse_delete(sql: &str) -> Result<(String, Vec<(String, String)>), DfOlapError> {
    let statements = Parser::parse_sql(&SQLiteDialect {}, sql)
        .map_err(|e| DfOlapError::Other(format!("failed to parse DELETE: {e}")))?;

    let delete = match statements.into_iter().next_back() {
        None => return Err(DfOlapError::Other("empty SQL statement".into())),
        Some(Statement::Delete(del)) => del,
        Some(other) => {
            return Err(DfOlapError::Other(format!(
                "expected DELETE statement, got: {other:?}"
            )));
        }
    };

    // Both spellings (`DELETE FROM t` and `DELETE t`) carry the same list.
    let targets = match &delete.from {
        FromTable::WithFromKeyword(list) | FromTable::WithoutKeyword(list) => list,
    };

    let first_table = match targets.first().map(|twj| &twj.relation) {
        Some(TableFactor::Table { name, .. }) => name
            .0
            .last()
            .and_then(|part| part.as_ident())
            .map(|ident| ident.value.clone()),
        _ => None,
    };
    let table_name =
        first_table.ok_or_else(|| DfOlapError::Other("missing table name in DELETE".into()))?;

    let where_clause = if let Some(expr) = &delete.selection {
        extract_where_conditions(expr)?
    } else {
        vec![]
    };

    Ok((table_name, where_clause))
}
/// Concatenates `batches` into a single batch with the given `schema`.
///
/// Returns `Ok(None)` when there are no batches or when none of them holds a
/// row; otherwise returns the concatenation of all batches.
///
/// # Errors
///
/// Propagates the error from `arrow::compute::concat_batches`.
fn flatten_batches(
    batches: &[RecordBatch],
    schema: &SchemaRef,
) -> Result<Option<RecordBatch>, DfOlapError> {
    // An empty slice trivially satisfies `all`, so both "no batches" and
    // "only empty batches" take this exit.
    let nothing_to_merge = batches.iter().all(|b| b.num_rows() == 0);
    if nothing_to_merge {
        return Ok(None);
    }
    let merged = arrow::compute::concat_batches(schema, batches)?;
    Ok(Some(merged))
}
fn apply_update(
batch: &RecordBatch,
schema: &SchemaRef,
assignments: &[(String, String)],
where_conditions: &[(String, String)],
) -> Result<(RecordBatch, u64), DfOlapError> {
let matching = find_matching_rows(batch, schema, where_conditions)?;
let updated_count = matching.iter().filter(|&&m| m).count() as u64;
let mut new_columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
for (col_idx, field) in schema.fields().iter().enumerate() {
let assignment = assignments.iter().find(|(col, _)| col == field.name());
if let Some((_, new_val)) = assignment {
let original = batch.column(col_idx);
new_columns.push(apply_value_to_matching(
original,
&matching,
new_val,
field.data_type(),
)?);
} else {
new_columns.push(batch.column(col_idx).clone());
}
}
let new_batch = RecordBatch::try_new(schema.clone(), new_columns)?;
Ok((new_batch, updated_count))
}
/// Removes every row of `batch` that satisfies all `where_conditions`.
///
/// Returns the batch of surviving rows (built with `schema`) together with
/// the number of rows that were removed.
///
/// # Errors
///
/// Propagates errors from `find_matching_rows`, from `arrow::compute::filter`
/// (wrapped as `DfOlapError::Arrow`) and from `RecordBatch::try_new`.
fn apply_delete(
    batch: &RecordBatch,
    schema: &SchemaRef,
    where_conditions: &[(String, String)],
) -> Result<(RecordBatch, u64), DfOlapError> {
    let doomed = find_matching_rows(batch, schema, where_conditions)?;

    // The filter mask keeps exactly the rows that did NOT match.
    let mut keep_mask = BooleanBuilder::new();
    let mut deleted_count: u64 = 0;
    for &is_doomed in doomed.iter() {
        if is_doomed {
            deleted_count += 1;
        }
        keep_mask.append_value(!is_doomed);
    }
    let keep_mask = keep_mask.finish();

    let mut survivors: Vec<Arc<dyn Array>> = Vec::new();
    for col_idx in 0..batch.num_columns() {
        let kept = arrow::compute::filter(batch.column(col_idx), &keep_mask)
            .map_err(DfOlapError::Arrow)?;
        survivors.push(kept);
    }

    let remaining = RecordBatch::try_new(schema.clone(), survivors)?;
    Ok((remaining, deleted_count))
}
/// Computes, for each row of `batch`, whether it satisfies every condition.
///
/// Conditions are `(column, expected literal)` pairs combined with logical
/// AND; an empty condition list marks every row as matching. Columns are
/// located by name in `schema`, and each cell is compared with
/// `value_matches`.
///
/// # Errors
///
/// Returns `DfOlapError::Other("column not found: ...")` when a condition
/// names a column that is not in `schema`.
fn find_matching_rows(
    batch: &RecordBatch,
    schema: &SchemaRef,
    conditions: &[(String, String)],
) -> Result<Vec<bool>, DfOlapError> {
    let mut matching = vec![true; batch.num_rows()];

    for (col_name, expected_val) in conditions {
        let position = schema
            .fields()
            .iter()
            .position(|field| field.name() == col_name);
        let col_idx = match position {
            Some(idx) => idx,
            None => {
                return Err(DfOlapError::Other(format!("column not found: {col_name}")));
            }
        };
        let col = batch.column(col_idx);

        // Rows already ruled out by an earlier condition are not re-tested.
        for (row_idx, still_matching) in matching.iter_mut().enumerate() {
            if *still_matching {
                *still_matching = value_matches(col, row_idx, expected_val);
            }
        }
    }

    Ok(matching)
}
fn value_matches(array: &dyn Array, row_idx: usize, expected: &str) -> bool {
if expected == "__IS_NOT_NULL__" {
return !array.is_null(row_idx);
}
if array.is_null(row_idx) {
return expected.eq_ignore_ascii_case("NULL");
}
match array.data_type() {
DataType::Int8 => {
expected.parse::<i8>().ok() == Some(array.as_primitive::<Int8Type>().value(row_idx))
}
DataType::Int16 => {
expected.parse::<i16>().ok() == Some(array.as_primitive::<Int16Type>().value(row_idx))
}
DataType::Int32 => {
expected.parse::<i32>().ok() == Some(array.as_primitive::<Int32Type>().value(row_idx))
}
DataType::Int64 => {
expected.parse::<i64>().ok() == Some(array.as_primitive::<Int64Type>().value(row_idx))
}
DataType::UInt8 => {
expected.parse::<u8>().ok() == Some(array.as_primitive::<UInt8Type>().value(row_idx))
}
DataType::UInt16 => {
expected.parse::<u16>().ok() == Some(array.as_primitive::<UInt16Type>().value(row_idx))
}
DataType::UInt32 => {
expected.parse::<u32>().ok() == Some(array.as_primitive::<UInt32Type>().value(row_idx))
}
DataType::UInt64 => {
expected.parse::<u64>().ok() == Some(array.as_primitive::<UInt64Type>().value(row_idx))
}
DataType::Float32 => {
expected.parse::<f32>().ok() == Some(array.as_primitive::<Float32Type>().value(row_idx))
}
DataType::Float64 => {
expected.parse::<f64>().ok() == Some(array.as_primitive::<Float64Type>().value(row_idx))
}
DataType::Utf8 => {
let arr = array.as_string::<i32>();
let stripped =
if expected.starts_with('\'') && expected.ends_with('\'') && expected.len() >= 2 {
&expected[1..expected.len() - 1]
} else {
expected
};
arr.value(row_idx) == stripped
}
DataType::Boolean => {
let arr = array.as_boolean();
match expected.to_ascii_uppercase().as_str() {
"TRUE" => arr.value(row_idx),
"FALSE" => !arr.value(row_idx),
_ => false,
}
}
_ => false,
}
}
fn apply_value_to_matching(
original: &dyn Array,
matching: &[bool],
new_val: &str,
dt: &DataType,
) -> Result<Arc<dyn Array>, DfOlapError> {
use arrow::array::*;
match dt {
DataType::Int64 => {
let orig = original.as_primitive::<Int64Type>();
let parsed: i64 = new_val
.parse()
.map_err(|e| DfOlapError::Other(format!("parse i64: {e}")))?;
let mut builder = Int64Builder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(parsed);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Float64 => {
let orig = original.as_primitive::<Float64Type>();
let parsed: f64 = new_val
.parse()
.map_err(|e| DfOlapError::Other(format!("parse f64: {e}")))?;
let mut builder = Float64Builder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(parsed);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Utf8 => {
let orig = original.as_string::<i32>();
let stripped =
if new_val.starts_with('\'') && new_val.ends_with('\'') && new_val.len() >= 2 {
&new_val[1..new_val.len() - 1]
} else {
new_val
};
let unescaped = stripped.replace("''", "'");
let mut builder = StringBuilder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(&unescaped);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Boolean => {
let orig = original.as_boolean();
let parsed = new_val.eq_ignore_ascii_case("TRUE");
let mut builder = BooleanBuilder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(parsed);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
_ => {
let orig = original.as_string::<i32>();
let mut builder = StringBuilder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(new_val);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
}
}
// Tests for the engine: a shared suite instantiated once per storage mode,
// followed by mode-specific and parser-level tests.
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::{Field, Schema};
use rhei_core::OlapEngine;
// Schema shared by most tests: non-nullable `id`, nullable `name` and `age`.
fn users_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
]))
}
// Engine factory for the shared suite; the path argument is ignored.
fn make_in_memory(_: &std::path::Path) -> DataFusionEngine {
DataFusionEngine::new()
}
// Engine factory for the shared suite using Vortex storage under
// `<tmp>/vortex_olap`.
fn make_vortex(tmp: &std::path::Path) -> DataFusionEngine {
DataFusionEngine::with_storage(StorageMode::Vortex {
url: tmp.join("vortex_olap").to_string_lossy().to_string(),
})
.unwrap()
}
// Generates a module named `$mod_name` containing the shared test suite,
// with every engine created through `$make_engine(tempdir_path)`.
macro_rules! storage_mode_tests {
($mod_name:ident, $make_engine:ident) => {
mod $mod_name {
use super::*;
// A freshly created table is reported as existing; an unknown one is not.
#[tokio::test]
async fn create_and_query_empty() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
assert!(engine.table_exists("users").await.unwrap());
assert!(!engine.table_exists("nonexistent").await.unwrap());
}
// Two single-row INSERTs are both visible to a SELECT.
#[tokio::test]
async fn insert_and_query() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.await
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.await
.unwrap();
let batches = engine
.query("SELECT * FROM users ORDER BY id")
.await
.unwrap();
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 2);
}
// UPDATE reports one affected row and the new value is readable.
#[tokio::test]
async fn update() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.await
.unwrap();
let updated = engine
.execute("UPDATE users SET age = 31 WHERE id = 1")
.await
.unwrap();
assert_eq!(updated, 1);
let batches = engine
.query("SELECT age FROM users WHERE id = 1")
.await
.unwrap();
let age = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(age, 31);
}
// DELETE reports one affected row and leaves the other row in place.
#[tokio::test]
async fn delete() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)")
.await
.unwrap();
let deleted = engine
.execute("DELETE FROM users WHERE id = 1")
.await
.unwrap();
assert_eq!(deleted, 1);
let batches = engine
.query("SELECT COUNT(*) as cnt FROM users")
.await
.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 1);
}
// `load_arrow` ingests a pre-built three-row batch and returns its row count.
#[tokio::test]
async fn load_arrow_bulk() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
Arc::new(arrow::array::StringArray::from(vec![
"Alice", "Bob", "Charlie",
])),
Arc::new(arrow::array::Int64Array::from(vec![30, 25, 35])),
],
)
.unwrap();
let loaded = engine.load_arrow("users", &[batch]).await.unwrap();
assert_eq!(loaded, 3);
let batches = engine
.query("SELECT COUNT(*) as cnt FROM users")
.await
.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 3);
}
// AVG over ages 30, 25 and 35 is 30 (compared with a small tolerance).
#[tokio::test]
async fn aggregate() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute(
"INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25), (3, 'Charlie', 35)",
)
.await
.unwrap();
let batches = engine
.query("SELECT AVG(age) as avg_age FROM users")
.await
.unwrap();
let avg = batches[0].column(0).as_primitive::<Float64Type>().value(0);
assert!((avg - 30.0).abs() < 0.01);
}
}
};
}
// Instantiate the shared suite for both storage modes.
storage_mode_tests!(in_memory, make_in_memory);
storage_mode_tests!(vortex_local, make_vortex);
// Rows written through one Vortex engine are visible to a second engine
// opened on the same directory after the first one is dropped.
#[tokio::test]
async fn vortex_local_persist_restart() {
let tmp = tempfile::tempdir().unwrap();
let base = tmp.path().join("restart_test");
let schema = users_schema();
{
let engine = DataFusionEngine::with_storage(StorageMode::Vortex {
url: base.to_string_lossy().to_string(),
})
.unwrap();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute(
"INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)",
)
.await
.unwrap();
}
{
let engine2 = DataFusionEngine::with_storage(StorageMode::Vortex {
url: base.to_string_lossy().to_string(),
})
.unwrap();
// The second engine calls `create_table` again before querying.
engine2.create_table("users", &schema, &[]).await.unwrap();
let batches = engine2
.query("SELECT COUNT(*) as cnt FROM users")
.await
.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 2, "data should survive engine restart");
}
}
// A comma inside a quoted string must not be treated as a value separator.
#[tokio::test]
async fn insert_string_with_comma() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice, B', 30)")
.await
.unwrap();
let batches = engine
.query("SELECT name FROM users WHERE id = 1")
.await
.unwrap();
let name_arr = batches[0].column(0).as_string::<i32>();
assert_eq!(name_arr.value(0), "Alice, B");
}
// An explicit NULL literal is stored as a null cell.
#[tokio::test]
async fn insert_null_value() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, NULL, 30)")
.await
.unwrap();
let batches = engine
.query("SELECT name FROM users WHERE id = 1")
.await
.unwrap();
assert!(batches[0].column(0).is_null(0));
}
// An UPDATE with `a AND b` in WHERE touches only the row matching both.
#[tokio::test]
async fn update_where_and() {
let engine = DataFusionEngine::new();
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("id", DataType::Int64, false),
arrow::datatypes::Field::new("name", DataType::Utf8, true),
arrow::datatypes::Field::new("status", DataType::Utf8, true),
]));
engine.create_table("t", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO t (id, name, status) VALUES (1, 'x', 'active')")
.await
.unwrap();
engine
.execute("INSERT INTO t (id, name, status) VALUES (2, 'y', 'inactive')")
.await
.unwrap();
let updated = engine
.execute("UPDATE t SET name = 'updated' WHERE id = 1 AND status = 'active'")
.await
.unwrap();
assert_eq!(updated, 1);
let batches = engine
.query("SELECT name FROM t WHERE id = 1")
.await
.unwrap();
assert_eq!(batches[0].column(0).as_string::<i32>().value(0), "updated");
let batches2 = engine
.query("SELECT name FROM t WHERE id = 2")
.await
.unwrap();
assert_eq!(batches2[0].column(0).as_string::<i32>().value(0), "y");
}
// A double-quoted table identifier in DELETE resolves to the same table.
#[tokio::test]
async fn delete_quoted_identifier() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.await
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.await
.unwrap();
let deleted = engine
.execute(r#"DELETE FROM "users" WHERE id = 1"#)
.await
.unwrap();
assert_eq!(deleted, 1);
let batches = engine.query("SELECT COUNT(*) FROM users").await.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 1);
}
// A doubled single quote in a string literal is stored as one quote.
#[tokio::test]
async fn insert_escaped_single_quote() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'O''Brien', 42)")
.await
.unwrap();
let batches = engine
.query("SELECT name FROM users WHERE id = 1")
.await
.unwrap();
assert_eq!(batches[0].column(0).as_string::<i32>().value(0), "O'Brien");
}
// A multi-row VALUES list is parsed into a single batch with one row each.
#[test]
fn parse_insert_multi_row() {
let (table, cols, batches) =
parse_insert_values("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')")
.unwrap();
assert_eq!(table, "users");
assert_eq!(cols, vec!["id", "name"]);
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].num_rows(), 2);
}
// `parse_update` keeps string literals quoted and numeric literals bare.
#[test]
fn parse_update_basic() {
let (table, assignments, where_clause) =
parse_update("UPDATE users SET name = 'Alice' WHERE id = 1").unwrap();
assert_eq!(table, "users");
assert_eq!(
assignments,
vec![("name".to_string(), "'Alice'".to_string())]
);
assert_eq!(where_clause, vec![("id".to_string(), "1".to_string())]);
}
// A DELETE without WHERE yields an empty condition list.
#[test]
fn parse_delete_no_where() {
let (table, conditions) = parse_delete("DELETE FROM logs").unwrap();
assert_eq!(table, "logs");
assert!(conditions.is_empty());
}
// A plain filesystem path is classified as local, not cloud, storage.
#[test]
fn vortex_url_local_path_classified() {
let mode = StorageMode::Vortex {
url: "/tmp/rhei".to_string(),
};
assert!(!mode.is_cloud());
assert!(mode.local_base_path().is_some());
}
// An `s3://` URL is classified as cloud storage and returned unchanged.
#[cfg(feature = "cloud-storage")]
#[test]
fn vortex_url_s3_classified() {
let mode = StorageMode::Vortex {
url: "s3://my-bucket/prefix".to_string(),
};
assert!(mode.is_cloud());
assert_eq!(mode.cloud_base_url(), Some("s3://my-bucket/prefix"));
}
// End-to-end INSERT / SELECT / UPDATE against S3. Returns immediately unless
// the environment variable RHEI_TEST_S3 is set to "1".
#[cfg(feature = "cloud-storage")]
#[tokio::test]
async fn vortex_s3_round_trip() {
if std::env::var("RHEI_TEST_S3").as_deref() != Ok("1") {
return; }
use std::time::{SystemTime, UNIX_EPOCH};
// The sub-second nanosecond count of the current time is used as the
// per-run prefix.
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.subsec_nanos();
let prefix = format!("test_{:08x}", ts);
let base_url = format!("s3://pixai-rec-sys/dev/rhei/{prefix}");
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("id", DataType::Int64, false),
arrow::datatypes::Field::new("val", DataType::Utf8, true),
]));
let engine = DataFusionEngine::with_storage(StorageMode::Vortex {
url: base_url.clone(),
})
.expect("S3 engine construction should succeed with AWS credentials");
engine.create_table("s3test", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO s3test (id, val) VALUES (1, 'hello'), (2, 'world')")
.await
.unwrap();
let batches = engine
.query("SELECT COUNT(*) as cnt FROM s3test")
.await
.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 2, "S3 round-trip INSERT+SELECT should return 2 rows");
let updated = engine
.execute("UPDATE s3test SET val = 'updated' WHERE id = 1")
.await
.unwrap();
assert_eq!(updated, 1);
let batches2 = engine
.query("SELECT val FROM s3test WHERE id = 1")
.await
.unwrap();
assert_eq!(batches2[0].column(0).as_string::<i32>().value(0), "updated");
// The test does not delete the objects it wrote; it only logs the prefix.
tracing::warn!(prefix, "S3 test data not cleaned up — remove manually");
}
}