use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use arrow::array::{Array, AsArray, BooleanBuilder, RecordBatch};
use arrow::datatypes::{
DataType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use datafusion::datasource::file_format::arrow::ArrowFormat;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::file_format::FileFormat;
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::prelude::*;
use sqlparser::ast::{
AssignmentTarget, BinaryOperator, Expr, FromTable, SetExpr, Statement, TableFactor,
TableObject, UnaryOperator, Value,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use tokio::sync::RwLock;
use tracing::debug;
use crate::error::DfOlapError;
use crate::storage::StorageMode;
#[cfg(feature = "cloud-storage")]
use object_store::ObjectStore;
#[cfg(feature = "cloud-storage")]
use url::Url;
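/// In-memory table state: the table schema plus every RecordBatch appended so far.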
struct TableData {
schema: SchemaRef,
batches: Vec<RecordBatch>,
}
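/// Metadata for a table persisted as Arrow IPC or Parquet files in a per-table directory.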
struct FileTableMeta {
schema: SchemaRef,
dir: PathBuf,
}
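/// Metadata for a table persisted as Parquet objects under a per-table URL prefix in S3 or GCS.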
#[cfg(feature = "cloud-storage")]
struct CloudTableMeta {
schema: SchemaRef,
table_url: String,
}
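/// OLAP engine backed by Apache DataFusion.
///
/// Tables live either entirely in memory, as local Arrow IPC or Parquet files, or
/// (with the `cloud-storage` feature) as Parquet objects in S3/GCS. SELECT queries are
/// delegated to DataFusion; INSERT/UPDATE/DELETE are parsed with sqlparser and applied
/// to the stored batches directly before the table is re-registered.
///
/// A minimal usage sketch (mirrors the tests at the bottom of this file; assumes the
/// `rhei_core::OlapEngine` trait is in scope and an async context where `?` works):
///
/// ```ignore
/// let engine = DataFusionEngine::new();
/// engine.create_table("users", &schema, &[]).await?;
/// engine.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)").await?;
/// let batches = engine.query("SELECT * FROM users").await?;
/// ```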
pub struct DataFusionEngine {
ctx: RwLock<SessionContext>,
tables: RwLock<HashMap<String, TableData>>,
file_tables: RwLock<HashMap<String, FileTableMeta>>,
#[cfg(feature = "cloud-storage")]
cloud_tables: RwLock<HashMap<String, CloudTableMeta>>,
#[cfg(feature = "cloud-storage")]
cloud_store: Option<Arc<dyn ObjectStore>>,
storage_mode: StorageMode,
file_counter: AtomicU64,
}
impl DataFusionEngine {
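    /// Builds an engine for the given storage mode. For file-backed modes this creates the
    /// base directory and resumes the file sequence counter from any files already on disk.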
pub fn with_storage(mode: StorageMode) -> Result<Self, DfOlapError> {
let start_counter = if let Some(base_path) = mode.base_path() {
std::fs::create_dir_all(base_path)?;
Self::scan_max_file_seq(base_path, mode.file_extension())
} else {
0
};
#[cfg(feature = "cloud-storage")]
let (ctx, cloud_store) = Self::build_session_context(&mode)?;
#[cfg(not(feature = "cloud-storage"))]
let ctx = SessionContext::new();
Ok(Self {
ctx: RwLock::new(ctx),
tables: RwLock::new(HashMap::new()),
file_tables: RwLock::new(HashMap::new()),
#[cfg(feature = "cloud-storage")]
cloud_tables: RwLock::new(HashMap::new()),
#[cfg(feature = "cloud-storage")]
cloud_store,
storage_mode: mode,
file_counter: AtomicU64::new(start_counter),
})
}
#[cfg(feature = "cloud-storage")]
fn build_session_context(
mode: &StorageMode,
) -> Result<(SessionContext, Option<Arc<dyn ObjectStore>>), DfOlapError> {
let ctx = SessionContext::new();
let mut cloud_store: Option<Arc<dyn ObjectStore>> = None;
match mode {
StorageMode::S3Parquet { url } => {
let bucket = Self::parse_bucket(url, "s3")?;
let store: Arc<dyn ObjectStore> = Arc::new(
object_store::aws::AmazonS3Builder::from_env()
.with_bucket_name(&bucket)
.build()?,
);
let base_url =
Url::parse(&format!("s3://{bucket}")).map_err(DfOlapError::UrlParse)?;
ctx.runtime_env()
.register_object_store(&base_url, store.clone());
cloud_store = Some(store);
tracing::info!(bucket, "registered S3 object store");
}
StorageMode::GcsParquet { url } => {
let bucket = Self::parse_bucket(url, "gs")?;
let store: Arc<dyn ObjectStore> = Arc::new(
object_store::gcp::GoogleCloudStorageBuilder::from_env()
.with_bucket_name(&bucket)
.build()?,
);
let base_url =
Url::parse(&format!("gs://{bucket}")).map_err(DfOlapError::UrlParse)?;
ctx.runtime_env()
.register_object_store(&base_url, store.clone());
cloud_store = Some(store);
tracing::info!(bucket, "registered GCS object store");
}
_ => {}
}
Ok((ctx, cloud_store))
}
#[cfg(feature = "cloud-storage")]
fn parse_bucket(url: &str, expected_scheme: &str) -> Result<String, DfOlapError> {
let parsed = Url::parse(url).map_err(DfOlapError::UrlParse)?;
if parsed.scheme() != expected_scheme {
return Err(DfOlapError::Other(format!(
"expected {expected_scheme}:// URL, got '{url}'"
)));
}
parsed
.host_str()
.map(|h| h.to_string())
.ok_or_else(|| DfOlapError::Other(format!("missing bucket name in URL '{url}'")))
}
#[cfg(feature = "cloud-storage")]
fn cloud_table_url(base_url: &str, table_name: &str) -> String {
let base = base_url.trim_end_matches('/');
format!("{base}/{table_name}/")
}
pub fn new() -> Self {
Self::with_storage(StorageMode::InMemory).expect("in-memory mode cannot fail")
}
pub fn storage_mode(&self) -> &StorageMode {
&self.storage_mode
}
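    /// Walks the per-table directories under `base_path` and returns the next file sequence
    /// number (max existing `_NNNNNN` suffix + 1), so a restarted engine never reuses a name.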
fn scan_max_file_seq(base_path: &std::path::Path, ext: &str) -> u64 {
let mut max_seq: u64 = 0;
let Ok(entries) = std::fs::read_dir(base_path) else {
return 0;
};
for entry in entries.flatten() {
if !entry.path().is_dir() {
continue;
}
let Ok(files) = std::fs::read_dir(entry.path()) else {
continue;
};
for file in files.flatten() {
let path = file.path();
if path.extension().and_then(|x| x.to_str()) != Some(ext) {
continue;
}
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
if let Some(seq_str) = stem.rsplit('_').next() {
if let Ok(seq) = seq_str.parse::<u64>() {
max_seq = max_seq.max(seq + 1);
}
}
}
}
}
max_seq
}
fn next_file_name(&self, table_name: &str) -> String {
let seq = self.file_counter.fetch_add(1, Ordering::Relaxed);
let ext = self.storage_mode.file_extension();
format!("{table_name}_{seq:06}.{ext}")
}
fn table_dir(&self, table_name: &str) -> Option<PathBuf> {
self.storage_mode
.base_path()
.map(|base| base.join(table_name))
}
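    /// Rebuilds a MemTable from the current in-memory batches and (re-)registers it with the
    /// SessionContext under `name`.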
async fn refresh_table_mem(&self, name: &str) -> Result<(), DfOlapError> {
let tables = self.tables.read().await;
let table_data = tables
.get(name)
.ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;
let partitions = if table_data.batches.is_empty() {
vec![vec![]]
} else {
vec![table_data.batches.clone()]
};
let mem_table = MemTable::try_new(table_data.schema.clone(), partitions)?;
let ctx = self.ctx.write().await;
let _ = ctx.deregister_table(name);
ctx.register_table(name, Arc::new(mem_table))?;
Ok(())
}
async fn refresh_table_file(&self, name: &str) -> Result<(), DfOlapError> {
let file_tables = self.file_tables.read().await;
let meta = file_tables
.get(name)
.ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;
let table_path = meta.dir.to_string_lossy().to_string();
let format: Arc<dyn FileFormat> = match &self.storage_mode {
StorageMode::ArrowIpc { .. } => Arc::new(ArrowFormat),
StorageMode::Parquet { .. } => Arc::new(ParquetFormat::default()),
StorageMode::InMemory => unreachable!(),
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => unreachable!(),
};
let ext = self.storage_mode.file_extension();
let listing_options = ListingOptions::new(format).with_file_extension(ext);
let config =
ListingTableConfig::new_with_multi_paths(vec![ListingTableUrl::parse(&table_path)?])
.with_listing_options(listing_options)
.with_schema(meta.schema.clone());
let listing_table = ListingTable::try_new(config)?;
let ctx = self.ctx.write().await;
let _ = ctx.deregister_table(name);
ctx.register_table(name, Arc::new(listing_table))?;
Ok(())
}
#[cfg(feature = "cloud-storage")]
async fn refresh_table_cloud(&self, name: &str) -> Result<(), DfOlapError> {
let cloud_tables = self.cloud_tables.read().await;
let meta = cloud_tables
.get(name)
.ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;
let table_url = meta.table_url.clone();
let schema = meta.schema.clone();
drop(cloud_tables);
let format: Arc<dyn FileFormat> = Arc::new(ParquetFormat::default());
let listing_options = ListingOptions::new(format).with_file_extension("parquet");
let config =
ListingTableConfig::new_with_multi_paths(vec![ListingTableUrl::parse(&table_url)?])
.with_listing_options(listing_options)
.with_schema(schema);
let listing_table = ListingTable::try_new(config)?;
let ctx = self.ctx.write().await;
let _ = ctx.deregister_table(name);
ctx.register_table(name, Arc::new(listing_table))?;
Ok(())
}
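    /// Writes `batches` to a new sequentially numbered Arrow IPC or Parquet file in the table's
    /// directory, doing the file I/O on a blocking thread.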
async fn write_batches_to_file(
&self,
table_name: &str,
schema: &SchemaRef,
batches: &[RecordBatch],
) -> Result<PathBuf, DfOlapError> {
let dir = self
.table_dir(table_name)
.ok_or_else(|| DfOlapError::Other("no table dir for in-memory mode".into()))?;
let file_name = self.next_file_name(table_name);
let file_path = dir.join(&file_name);
let schema = schema.clone();
let batches: Vec<RecordBatch> = batches.to_vec();
let path = file_path.clone();
match &self.storage_mode {
StorageMode::ArrowIpc { .. } => {
tokio::task::spawn_blocking(move || {
let file = std::fs::File::create(&path)?;
let mut writer =
arrow::ipc::writer::FileWriter::try_new(file, schema.as_ref())?;
for batch in &batches {
writer.write(batch)?;
}
writer.finish()?;
Ok::<_, DfOlapError>(())
})
.await
.map_err(DfOlapError::from_join)??;
}
StorageMode::Parquet { .. } => {
tokio::task::spawn_blocking(move || {
let file = std::fs::File::create(&path)?;
let props = parquet::file::properties::WriterProperties::builder()
.set_writer_version(parquet::file::properties::WriterVersion::PARQUET_2_0)
.build();
let mut writer =
parquet::arrow::ArrowWriter::try_new(file, schema, Some(props))?;
for batch in &batches {
writer.write(batch)?;
}
writer.close()?;
Ok::<_, DfOlapError>(())
})
.await
.map_err(DfOlapError::from_join)??;
}
StorageMode::InMemory => unreachable!(),
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => unreachable!(),
}
Ok(file_path)
}
fn list_data_files(dir: &std::path::Path, ext: &str) -> Result<Vec<PathBuf>, DfOlapError> {
let mut files: Vec<PathBuf> = std::fs::read_dir(dir)?
.filter_map(|e| e.ok())
.map(|e| e.path())
.filter(|p| p.extension().is_some_and(|x| x.to_str() == Some(ext)))
.collect();
files.sort();
Ok(files)
}
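    /// Reads every data file for `table_name` back into memory; used when UPDATE, DELETE, or a
    /// schema change needs to rewrite the table.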
async fn read_all_batches(
&self,
table_name: &str,
) -> Result<(SchemaRef, Vec<RecordBatch>), DfOlapError> {
let file_tables = self.file_tables.read().await;
let meta = file_tables
.get(table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
let schema = meta.schema.clone();
let dir = meta.dir.clone();
let ext = self.storage_mode.file_extension().to_string();
let is_arrow_ipc = matches!(self.storage_mode, StorageMode::ArrowIpc { .. });
drop(file_tables);
tokio::task::spawn_blocking(move || {
let mut all_batches = Vec::new();
let files = Self::list_data_files(&dir, &ext)?;
for path in files {
if is_arrow_ipc {
let file = std::fs::File::open(&path)?;
let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;
for batch in reader {
all_batches.push(batch?);
}
} else {
let file = std::fs::File::open(&path)?;
let reader = parquet::arrow::arrow_reader::ParquetRecordBatchReader::try_new(
file, 8192,
)?;
for batch in reader {
all_batches.push(batch?);
}
}
}
Ok::<_, DfOlapError>((schema, all_batches))
})
.await
.map_err(DfOlapError::from_join)?
}
async fn clear_table_dir(&self, table_name: &str) -> Result<(), DfOlapError> {
let dir = match self.table_dir(table_name) {
Some(d) => d,
None => return Ok(()),
};
let ext = self.storage_mode.file_extension().to_string();
tokio::task::spawn_blocking(move || {
let files = Self::list_data_files(&dir, &ext)?;
for path in files {
std::fs::remove_file(path)?;
}
Ok::<_, DfOlapError>(())
})
.await
.map_err(DfOlapError::from_join)?
}
#[cfg(feature = "cloud-storage")]
fn cloud_store(&self) -> Result<Arc<dyn ObjectStore>, DfOlapError> {
self.cloud_store
.clone()
.ok_or_else(|| DfOlapError::Other("cloud store not initialised".into()))
}
#[cfg(feature = "cloud-storage")]
async fn write_batches_to_cloud(
&self,
table_name: &str,
schema: &SchemaRef,
batches: &[RecordBatch],
) -> Result<(), DfOlapError> {
let cloud_tables = self.cloud_tables.read().await;
let meta = cloud_tables
.get(table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
let table_url = meta.table_url.clone();
drop(cloud_tables);
let store = self.cloud_store()?;
let schema = schema.clone();
let batches: Vec<RecordBatch> = batches.to_vec();
let parquet_bytes = tokio::task::spawn_blocking(move || {
let mut buf = Vec::new();
let props = parquet::file::properties::WriterProperties::builder()
.set_writer_version(parquet::file::properties::WriterVersion::PARQUET_2_0)
.build();
let mut writer = parquet::arrow::ArrowWriter::try_new(&mut buf, schema, Some(props))?;
for batch in &batches {
writer.write(batch)?;
}
writer.close()?;
Ok::<_, DfOlapError>(buf)
})
.await
.map_err(DfOlapError::from_join)??;
let seq = self.file_counter.fetch_add(1, Ordering::Relaxed);
let object_key = format!("{table_name}_{seq:06}.parquet");
let table_path_prefix = Url::parse(&table_url)
.map_err(DfOlapError::UrlParse)?
.path()
.trim_start_matches('/')
.trim_end_matches('/')
.to_string();
let object_path =
object_store::path::Path::from(format!("{table_path_prefix}/{object_key}").as_str());
        use object_store::ObjectStore as _;
store.put(&object_path, parquet_bytes.into()).await?;
Ok(())
}
#[cfg(feature = "cloud-storage")]
async fn cloud_seq_for_prefix(
store: &Arc<dyn ObjectStore>,
table_url: &str,
table_name: &str,
) -> Result<u64, DfOlapError> {
let prefix_str = Url::parse(table_url)
.map_err(DfOlapError::UrlParse)?
.path()
.trim_start_matches('/')
.trim_end_matches('/')
.to_string();
let prefix = object_store::path::Path::from(prefix_str.as_str());
use futures::StreamExt as _;
use object_store::ObjectStore as _;
let mut list_stream = store.list(Some(&prefix));
let file_prefix = format!("{table_name}_");
let mut max_seq: Option<u64> = None;
while let Some(item) = list_stream.next().await {
let m = item?;
let path_str = m.location.to_string();
let file_name = path_str.rsplit('/').next().unwrap_or("");
if !file_name.starts_with(&file_prefix) || !file_name.ends_with(".parquet") {
continue;
}
let inner = &file_name[file_prefix.len()..file_name.len() - ".parquet".len()];
if let Ok(seq) = inner.parse::<u64>() {
max_seq = Some(max_seq.map_or(seq, |m| m.max(seq)));
}
}
Ok(max_seq.map_or(0, |m| m + 1))
}
#[cfg(feature = "cloud-storage")]
async fn list_cloud_objects(
&self,
table_name: &str,
) -> Result<(Arc<dyn ObjectStore>, String, Vec<object_store::path::Path>), DfOlapError> {
let cloud_tables = self.cloud_tables.read().await;
let meta = cloud_tables
.get(table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
let table_url = meta.table_url.clone();
drop(cloud_tables);
let store = self.cloud_store()?;
let prefix_str = Url::parse(&table_url)
.map_err(DfOlapError::UrlParse)?
.path()
.trim_start_matches('/')
.trim_end_matches('/')
.to_string();
let prefix = object_store::path::Path::from(prefix_str.as_str());
use futures::StreamExt as _;
use object_store::ObjectStore as _;
let mut list_stream = store.list(Some(&prefix));
let mut paths = Vec::new();
while let Some(item) = list_stream.next().await {
let m = item?;
if m.location.to_string().ends_with(".parquet") {
paths.push(m.location);
}
}
paths.sort_by_key(|a| a.to_string());
Ok((store, table_url, paths))
}
#[cfg(feature = "cloud-storage")]
async fn read_all_batches_cloud(
&self,
table_name: &str,
) -> Result<(SchemaRef, Vec<RecordBatch>), DfOlapError> {
let cloud_tables = self.cloud_tables.read().await;
let meta = cloud_tables
.get(table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
let schema = meta.schema.clone();
drop(cloud_tables);
let (store, _, object_paths) = self.list_cloud_objects(table_name).await?;
let mut all_batches = Vec::new();
for path in object_paths {
            use object_store::ObjectStore as _;
let get_result = store.get(&path).await?;
let bytes = get_result.bytes().await?;
let mut batch_vec = tokio::task::spawn_blocking(move || {
let reader =
parquet::arrow::arrow_reader::ParquetRecordBatchReader::try_new(bytes, 8192)?;
let mut batches = Vec::new();
for b in reader {
batches.push(b?);
}
Ok::<_, DfOlapError>(batches)
})
.await
.map_err(DfOlapError::from_join)??;
all_batches.append(&mut batch_vec);
}
Ok((schema, all_batches))
}
#[cfg(feature = "cloud-storage")]
async fn clear_cloud_table(&self, table_name: &str) -> Result<(), DfOlapError> {
let (store, _, paths) = self.list_cloud_objects(table_name).await?;
        use object_store::ObjectStore as _;
for path in paths {
store.delete(&path).await?;
}
Ok(())
}
async fn execute_sql(&self, sql: &str) -> Result<Vec<RecordBatch>, DfOlapError> {
let ctx = self.ctx.read().await;
let df = ctx.sql(sql).await?;
let batches = df.collect().await?;
Ok(batches)
}
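    /// Reorders the columns of parsed INSERT batches to match the table schema, casting values
    /// where the inferred type differs, and returns the aligned batches plus the total row count.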
fn align_batches_to_schema(
table_schema: &SchemaRef,
col_names: &[String],
batches: &[RecordBatch],
) -> Result<(Vec<RecordBatch>, u64), DfOlapError> {
let mut aligned_batches = Vec::with_capacity(batches.len());
let mut total_rows = 0u64;
for batch in batches {
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(table_schema.fields().len());
for field in table_schema.fields() {
let idx = col_names
.iter()
.position(|c| c == field.name())
.ok_or_else(|| {
DfOlapError::SchemaMismatch(format!(
"column '{}' not in INSERT column list",
field.name()
))
})?;
let col = batch.column(idx);
let col = if col.data_type() != field.data_type() {
arrow::compute::cast(col, field.data_type())?
} else {
col.clone()
};
columns.push(col);
}
let aligned = RecordBatch::try_new(table_schema.clone(), columns)?;
total_rows += aligned.num_rows() as u64;
aligned_batches.push(aligned);
}
Ok((aligned_batches, total_rows))
}
async fn execute_insert_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, col_names, batches) = parse_insert_values(sql)?;
let mut tables = self.tables.write().await;
let table_data = tables
.get_mut(&table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
let table_schema = table_data.schema.clone();
let (aligned_batches, total_rows) =
Self::align_batches_to_schema(&table_schema, &col_names, &batches)?;
table_data.batches.extend(aligned_batches);
drop(tables);
self.refresh_table_mem(&table_name).await?;
Ok(total_rows)
}
async fn execute_update_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, assignments, where_clause) = parse_update(sql)?;
let mut tables = self.tables.write().await;
let table_data = tables
.get_mut(&table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
let schema = table_data.schema.clone();
let mut updated_count = 0u64;
let all_rows = flatten_batches(&table_data.batches, &schema)?;
if let Some(all_rows) = all_rows {
let (updated_batch, count) =
apply_update(&all_rows, &schema, &assignments, &where_clause)?;
updated_count = count;
table_data.batches = vec![updated_batch];
}
drop(tables);
self.refresh_table_mem(&table_name).await?;
Ok(updated_count)
}
async fn execute_delete_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, where_clause) = parse_delete(sql)?;
let mut tables = self.tables.write().await;
let table_data = tables
.get_mut(&table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
let schema = table_data.schema.clone();
let all_rows = flatten_batches(&table_data.batches, &schema)?;
if let Some(all_rows) = all_rows {
let (filtered_batch, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
table_data.batches = if filtered_batch.num_rows() > 0 {
vec![filtered_batch]
} else {
vec![]
};
drop(tables);
self.refresh_table_mem(&table_name).await?;
Ok(deleted_count)
} else {
Ok(0)
}
}
async fn execute_insert_file(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, col_names, batches) = parse_insert_values(sql)?;
let file_tables = self.file_tables.read().await;
let meta = file_tables
.get(&table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
let table_schema = meta.schema.clone();
drop(file_tables);
let (aligned_batches, total_rows) =
Self::align_batches_to_schema(&table_schema, &col_names, &batches)?;
self.write_batches_to_file(&table_name, &table_schema, &aligned_batches)
.await?;
self.refresh_table_file(&table_name).await?;
Ok(total_rows)
}
async fn execute_update_file(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, assignments, where_clause) = parse_update(sql)?;
let (schema, existing_batches) = self.read_all_batches(&table_name).await?;
let all_rows = flatten_batches(&existing_batches, &schema)?;
if let Some(all_rows) = all_rows {
let (updated_batch, count) =
apply_update(&all_rows, &schema, &assignments, &where_clause)?;
self.clear_table_dir(&table_name).await?;
if updated_batch.num_rows() > 0 {
self.write_batches_to_file(&table_name, &schema, &[updated_batch])
.await?;
}
self.refresh_table_file(&table_name).await?;
Ok(count)
} else {
Ok(0)
}
}
async fn execute_delete_file(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, where_clause) = parse_delete(sql)?;
let (schema, existing_batches) = self.read_all_batches(&table_name).await?;
let all_rows = flatten_batches(&existing_batches, &schema)?;
if let Some(all_rows) = all_rows {
let (filtered_batch, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
self.clear_table_dir(&table_name).await?;
if filtered_batch.num_rows() > 0 {
self.write_batches_to_file(&table_name, &schema, &[filtered_batch])
.await?;
}
self.refresh_table_file(&table_name).await?;
Ok(deleted_count)
} else {
Ok(0)
}
}
#[cfg(feature = "cloud-storage")]
async fn execute_insert_cloud(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, col_names, batches) = parse_insert_values(sql)?;
let cloud_tables = self.cloud_tables.read().await;
let meta = cloud_tables
.get(&table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
let table_schema = meta.schema.clone();
drop(cloud_tables);
let (aligned_batches, total_rows) =
Self::align_batches_to_schema(&table_schema, &col_names, &batches)?;
self.write_batches_to_cloud(&table_name, &table_schema, &aligned_batches)
.await?;
self.refresh_table_cloud(&table_name).await?;
Ok(total_rows)
}
#[cfg(feature = "cloud-storage")]
async fn execute_update_cloud(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, assignments, where_clause) = parse_update(sql)?;
let (schema, existing_batches) = self.read_all_batches_cloud(&table_name).await?;
let all_rows = flatten_batches(&existing_batches, &schema)?;
if let Some(all_rows) = all_rows {
let (updated_batch, count) =
apply_update(&all_rows, &schema, &assignments, &where_clause)?;
self.clear_cloud_table(&table_name).await?;
if updated_batch.num_rows() > 0 {
self.write_batches_to_cloud(&table_name, &schema, &[updated_batch])
.await?;
}
self.refresh_table_cloud(&table_name).await?;
Ok(count)
} else {
Ok(0)
}
}
#[cfg(feature = "cloud-storage")]
async fn execute_delete_cloud(&self, sql: &str) -> Result<u64, DfOlapError> {
let (table_name, where_clause) = parse_delete(sql)?;
let (schema, existing_batches) = self.read_all_batches_cloud(&table_name).await?;
let all_rows = flatten_batches(&existing_batches, &schema)?;
if let Some(all_rows) = all_rows {
let (filtered_batch, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
self.clear_cloud_table(&table_name).await?;
if filtered_batch.num_rows() > 0 {
self.write_batches_to_cloud(&table_name, &schema, &[filtered_batch])
.await?;
}
self.refresh_table_cloud(&table_name).await?;
Ok(deleted_count)
} else {
Ok(0)
}
}
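    /// Routes an INSERT to the in-memory, local-file, or cloud implementation based on the
    /// configured storage mode (the UPDATE and DELETE dispatchers below do the same).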
async fn execute_insert(&self, sql: &str) -> Result<u64, DfOlapError> {
match &self.storage_mode {
StorageMode::InMemory => self.execute_insert_mem(sql).await,
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
self.execute_insert_cloud(sql).await
}
_ => self.execute_insert_file(sql).await,
}
}
async fn execute_update(&self, sql: &str) -> Result<u64, DfOlapError> {
match &self.storage_mode {
StorageMode::InMemory => self.execute_update_mem(sql).await,
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
self.execute_update_cloud(sql).await
}
_ => self.execute_update_file(sql).await,
}
}
async fn execute_delete(&self, sql: &str) -> Result<u64, DfOlapError> {
match &self.storage_mode {
StorageMode::InMemory => self.execute_delete_mem(sql).await,
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
self.execute_delete_cloud(sql).await
}
_ => self.execute_delete_file(sql).await,
}
}
}
impl Default for DataFusionEngine {
fn default() -> Self {
Self::new()
}
}
impl rhei_core::OlapEngine for DataFusionEngine {
type Error = DfOlapError;
async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
debug!(sql, "DataFusion query");
self.execute_sql(sql).await
}
async fn query_stream(
&self,
sql: &str,
) -> Result<rhei_core::RecordBatchBoxStream, Self::Error> {
debug!(sql, "DataFusion query_stream");
let ctx = self.ctx.read().await;
let df = ctx.sql(sql).await?;
let stream = df.execute_stream().await?;
let mapped = Box::pin(StreamAdapter(stream));
Ok(mapped)
}
async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
debug!(sql, "DataFusion execute");
let trimmed = sql.trim();
let upper = trimmed.to_ascii_uppercase();
if upper.starts_with("INSERT") {
self.execute_insert(trimmed).await
} else if upper.starts_with("UPDATE") {
self.execute_update(trimmed).await
} else if upper.starts_with("DELETE") {
self.execute_delete(trimmed).await
} else if upper.starts_with("BEGIN")
|| upper.starts_with("COMMIT")
|| upper.starts_with("ROLLBACK")
{
Ok(0)
} else {
let ctx = self.ctx.read().await;
let df = ctx.sql(trimmed).await?;
let _ = df.collect().await?;
Ok(0)
}
}
async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
if batches.is_empty() {
return Ok(0);
}
debug!(table, batch_count = batches.len(), "DataFusion load_arrow");
rhei_core::validate_identifier(table).map_err(|e| DfOlapError::Other(e.to_string()))?;
let total_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
match &self.storage_mode {
StorageMode::InMemory => {
let mut tables = self.tables.write().await;
let table_data = tables
.get_mut(table)
.ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;
for batch in batches {
table_data.batches.push(batch.clone());
}
drop(tables);
self.refresh_table_mem(table).await?;
}
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
let cloud_tables = self.cloud_tables.read().await;
let meta = cloud_tables
.get(table)
.ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;
let schema = meta.schema.clone();
drop(cloud_tables);
self.write_batches_to_cloud(table, &schema, batches).await?;
self.refresh_table_cloud(table).await?;
}
_ => {
let file_tables = self.file_tables.read().await;
let meta = file_tables
.get(table)
.ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;
let schema = meta.schema.clone();
drop(file_tables);
self.write_batches_to_file(table, &schema, batches).await?;
self.refresh_table_file(table).await?;
}
}
Ok(total_rows)
}
async fn create_table(
&self,
table_name: &str,
schema: &SchemaRef,
_primary_key: &[String],
) -> Result<(), Self::Error> {
rhei_core::validate_identifier(table_name)
.map_err(|e| DfOlapError::Other(e.to_string()))?;
for field in schema.fields() {
rhei_core::validate_identifier(field.name())
.map_err(|e| DfOlapError::Other(e.to_string()))?;
}
debug!(
table = table_name,
storage = ?self.storage_mode,
"DataFusion create_table"
);
match &self.storage_mode {
StorageMode::InMemory => {
let mut tables = self.tables.write().await;
if tables.contains_key(table_name) {
return Ok(());
}
tables.insert(
table_name.to_string(),
TableData {
schema: schema.clone(),
batches: vec![],
},
);
drop(tables);
self.refresh_table_mem(table_name).await?;
}
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { url } | StorageMode::GcsParquet { url } => {
let mut cloud_tables = self.cloud_tables.write().await;
if cloud_tables.contains_key(table_name) {
return Ok(());
}
let table_url = Self::cloud_table_url(url, table_name);
cloud_tables.insert(
table_name.to_string(),
CloudTableMeta {
schema: schema.clone(),
table_url: table_url.clone(),
},
);
drop(cloud_tables);
if let Ok(store) = self.cloud_store() {
match Self::cloud_seq_for_prefix(&store, &table_url, table_name).await {
Ok(next_seq) => {
self.file_counter.fetch_max(next_seq, Ordering::Relaxed);
if next_seq > 0 {
tracing::debug!(
table = table_name,
next_seq,
"cloud restart: advanced file_counter to avoid overwrites"
);
}
}
Err(e) => {
tracing::warn!(
table = table_name,
error = %e,
"cloud_seq_for_prefix failed; file_counter not advanced"
);
}
}
}
self.refresh_table_cloud(table_name).await?;
}
_ => {
let mut file_tables = self.file_tables.write().await;
if file_tables.contains_key(table_name) {
return Ok(());
}
let dir = self.table_dir(table_name).expect("file mode has base_path");
tokio::fs::create_dir_all(&dir).await?;
file_tables.insert(
table_name.to_string(),
FileTableMeta {
schema: schema.clone(),
dir,
},
);
drop(file_tables);
self.refresh_table_file(table_name).await?;
}
}
Ok(())
}
async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
match &self.storage_mode {
StorageMode::InMemory => {
let tables = self.tables.read().await;
Ok(tables.contains_key(table_name))
}
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
let cloud_tables = self.cloud_tables.read().await;
Ok(cloud_tables.contains_key(table_name))
}
_ => {
let file_tables = self.file_tables.read().await;
Ok(file_tables.contains_key(table_name))
}
}
}
async fn add_column(
&self,
table_name: &str,
column_name: &str,
data_type: &DataType,
) -> Result<(), Self::Error> {
rhei_core::validate_identifier(table_name)
.map_err(|e| DfOlapError::Other(e.to_string()))?;
rhei_core::validate_identifier(column_name)
.map_err(|e| DfOlapError::Other(e.to_string()))?;
debug!(
table = table_name,
column = column_name,
"DataFusion add_column"
);
match &self.storage_mode {
StorageMode::InMemory => {
let mut tables = self.tables.write().await;
let table_data = tables
.get_mut(table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
let new_schema = append_field(&table_data.schema, column_name, data_type);
let new_batches =
extend_batches_with_null_column(&table_data.batches, &new_schema, data_type)?;
table_data.schema = new_schema;
table_data.batches = new_batches;
drop(tables);
self.refresh_table_mem(table_name).await?;
}
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
let (old_schema, existing_batches) =
self.read_all_batches_cloud(table_name).await?;
let new_schema = append_field(&old_schema, column_name, data_type);
let new_batches =
extend_batches_with_null_column(&existing_batches, &new_schema, data_type)?;
self.clear_cloud_table(table_name).await?;
if !new_batches.is_empty() {
self.write_batches_to_cloud(table_name, &new_schema, &new_batches)
.await?;
}
let mut cloud_tables = self.cloud_tables.write().await;
if let Some(meta) = cloud_tables.get_mut(table_name) {
meta.schema = new_schema;
}
drop(cloud_tables);
self.refresh_table_cloud(table_name).await?;
}
_ => {
let (old_schema, existing_batches) = self.read_all_batches(table_name).await?;
let new_schema = append_field(&old_schema, column_name, data_type);
let new_batches =
extend_batches_with_null_column(&existing_batches, &new_schema, data_type)?;
self.clear_table_dir(table_name).await?;
if !new_batches.is_empty() {
self.write_batches_to_file(table_name, &new_schema, &new_batches)
.await?;
}
let mut file_tables = self.file_tables.write().await;
if let Some(meta) = file_tables.get_mut(table_name) {
meta.schema = new_schema;
}
drop(file_tables);
self.refresh_table_file(table_name).await?;
}
}
Ok(())
}
async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
rhei_core::validate_identifier(table_name)
.map_err(|e| DfOlapError::Other(e.to_string()))?;
rhei_core::validate_identifier(column_name)
.map_err(|e| DfOlapError::Other(e.to_string()))?;
debug!(
table = table_name,
column = column_name,
"DataFusion drop_column"
);
match &self.storage_mode {
StorageMode::InMemory => {
let mut tables = self.tables.write().await;
let table_data = tables
.get_mut(table_name)
.ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
let col_idx = find_column_index(&table_data.schema, column_name, table_name)?;
let new_schema = remove_field(&table_data.schema, col_idx);
let new_batches =
remove_column_from_batches(&table_data.batches, &new_schema, col_idx)?;
table_data.schema = new_schema;
table_data.batches = new_batches;
drop(tables);
self.refresh_table_mem(table_name).await?;
}
#[cfg(feature = "cloud-storage")]
StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
let (old_schema, existing_batches) =
self.read_all_batches_cloud(table_name).await?;
let col_idx = find_column_index(&old_schema, column_name, table_name)?;
let new_schema = remove_field(&old_schema, col_idx);
let new_batches =
remove_column_from_batches(&existing_batches, &new_schema, col_idx)?;
self.clear_cloud_table(table_name).await?;
if !new_batches.is_empty() {
self.write_batches_to_cloud(table_name, &new_schema, &new_batches)
.await?;
}
let mut cloud_tables = self.cloud_tables.write().await;
if let Some(meta) = cloud_tables.get_mut(table_name) {
meta.schema = new_schema;
}
drop(cloud_tables);
self.refresh_table_cloud(table_name).await?;
}
_ => {
let (old_schema, existing_batches) = self.read_all_batches(table_name).await?;
let col_idx = find_column_index(&old_schema, column_name, table_name)?;
let new_schema = remove_field(&old_schema, col_idx);
let new_batches =
remove_column_from_batches(&existing_batches, &new_schema, col_idx)?;
self.clear_table_dir(table_name).await?;
if !new_batches.is_empty() {
self.write_batches_to_file(table_name, &new_schema, &new_batches)
.await?;
}
let mut file_tables = self.file_tables.write().await;
if let Some(meta) = file_tables.get_mut(table_name) {
meta.schema = new_schema;
}
drop(file_tables);
self.refresh_table_file(table_name).await?;
}
}
Ok(())
}
}
fn append_field(schema: &SchemaRef, column_name: &str, data_type: &DataType) -> SchemaRef {
let mut fields: Vec<arrow::datatypes::Field> =
schema.fields().iter().map(|f| f.as_ref().clone()).collect();
fields.push(arrow::datatypes::Field::new(
column_name,
data_type.clone(),
true,
));
Arc::new(arrow::datatypes::Schema::new(fields))
}
fn remove_field(schema: &SchemaRef, col_idx: usize) -> SchemaRef {
let fields: Vec<arrow::datatypes::Field> = schema
.fields()
.iter()
.enumerate()
.filter(|(i, _)| *i != col_idx)
.map(|(_, f)| f.as_ref().clone())
.collect();
Arc::new(arrow::datatypes::Schema::new(fields))
}
fn find_column_index(
schema: &SchemaRef,
column_name: &str,
table_name: &str,
) -> Result<usize, DfOlapError> {
schema
.fields()
.iter()
.position(|f| f.name() == column_name)
.ok_or_else(|| {
DfOlapError::Other(format!(
"column '{}' not found in table '{}'",
column_name, table_name
))
})
}
fn extend_batches_with_null_column(
batches: &[RecordBatch],
new_schema: &SchemaRef,
data_type: &DataType,
) -> Result<Vec<RecordBatch>, DfOlapError> {
let mut new_batches = Vec::with_capacity(batches.len());
for batch in batches {
let null_array = arrow::array::new_null_array(data_type, batch.num_rows());
let mut columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
.map(|i| batch.column(i).clone())
.collect();
columns.push(null_array);
new_batches.push(RecordBatch::try_new(new_schema.clone(), columns)?);
}
Ok(new_batches)
}
fn remove_column_from_batches(
batches: &[RecordBatch],
new_schema: &SchemaRef,
col_idx: usize,
) -> Result<Vec<RecordBatch>, DfOlapError> {
let mut new_batches = Vec::with_capacity(batches.len());
for batch in batches {
let columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
.filter(|i| *i != col_idx)
.map(|i| batch.column(i).clone())
.collect();
new_batches.push(RecordBatch::try_new(new_schema.clone(), columns)?);
}
Ok(new_batches)
}
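/// Adapts DataFusion's `SendableRecordBatchStream` to the boxed-error stream type expected by
/// `rhei_core::RecordBatchBoxStream`.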
struct StreamAdapter(datafusion::physical_plan::SendableRecordBatchStream);
impl futures_core::Stream for StreamAdapter {
type Item = Result<RecordBatch, Box<dyn std::error::Error + Send + Sync>>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
std::pin::Pin::new(&mut self.0).poll_next(cx).map(|opt| {
opt.map(|r| r.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>))
})
}
}
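/// Cheaply cloneable handle that shares a single `DataFusionEngine` behind an `Arc` and forwards
/// all `OlapEngine` methods to it.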
#[derive(Clone)]
pub struct SharedDataFusionEngine(pub Arc<DataFusionEngine>);
impl SharedDataFusionEngine {
pub fn new(engine: DataFusionEngine) -> Self {
Self(Arc::new(engine))
}
}
impl std::ops::Deref for SharedDataFusionEngine {
type Target = DataFusionEngine;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl rhei_core::OlapEngine for SharedDataFusionEngine {
type Error = DfOlapError;
async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
self.0.query(sql).await
}
async fn query_stream(
&self,
sql: &str,
) -> Result<rhei_core::RecordBatchBoxStream, Self::Error> {
self.0.query_stream(sql).await
}
async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
self.0.execute(sql).await
}
async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
self.0.load_arrow(table, batches).await
}
async fn create_table(
&self,
table_name: &str,
schema: &SchemaRef,
primary_key: &[String],
) -> Result<(), Self::Error> {
self.0.create_table(table_name, schema, primary_key).await
}
async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
self.0.table_exists(table_name).await
}
async fn add_column(
&self,
table_name: &str,
column_name: &str,
data_type: &DataType,
) -> Result<(), Self::Error> {
self.0.add_column(table_name, column_name, data_type).await
}
async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
self.0.drop_column(table_name, column_name).await
}
}
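/// Converts a literal sqlparser expression (number, single-quoted string, boolean, NULL, or a
/// negated number) into the SQL text form consumed by the batch builders below.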
fn expr_to_sql_literal(expr: &Expr) -> Result<String, DfOlapError> {
match expr {
Expr::Value(v) => match &v.value {
Value::Number(n, _) => Ok(n.clone()),
Value::SingleQuotedString(s) => Ok(format!("'{}'", s.replace('\'', "''"))),
Value::Boolean(b) => Ok(if *b { "TRUE".into() } else { "FALSE".into() }),
Value::Null => Ok("NULL".into()),
other => Err(DfOlapError::Other(format!(
"unsupported value literal: {other:?}"
))),
},
Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: inner,
} => {
if let Expr::Value(v) = inner.as_ref() {
if let Value::Number(n, _) = &v.value {
return Ok(format!("-{n}"));
}
}
Err(DfOlapError::Other(format!(
"unsupported unary expression: {expr}"
)))
}
other => Err(DfOlapError::Other(format!(
"unsupported expression in VALUES: {other}"
))),
}
}
fn ident_from_expr(expr: &Expr) -> Result<String, DfOlapError> {
match expr {
Expr::Identifier(ident) => Ok(ident.value.clone()),
Expr::CompoundIdentifier(parts) => parts
.last()
.map(|i| i.value.clone())
.ok_or_else(|| DfOlapError::Other("empty compound identifier".into())),
other => Err(DfOlapError::Other(format!(
"expected column name, got: {other}"
))),
}
}
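/// Flattens a WHERE clause made of AND-ed equality, IS NULL, and IS NOT NULL predicates into
/// `(column, literal)` pairs (IS NOT NULL is encoded with the `__IS_NOT_NULL__` marker); any
/// other predicate is rejected as unsupported.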
fn extract_where_conditions(expr: &Expr) -> Result<Vec<(String, String)>, DfOlapError> {
match expr {
Expr::BinaryOp {
left,
op: BinaryOperator::And,
right,
} => {
let mut conditions = extract_where_conditions(left)?;
conditions.extend(extract_where_conditions(right)?);
Ok(conditions)
}
Expr::BinaryOp {
left,
op: BinaryOperator::Eq,
right,
} => {
let col = ident_from_expr(left)?;
let val = expr_to_sql_literal(right)?;
Ok(vec![(col, val)])
}
Expr::IsNull(inner) => {
let col = ident_from_expr(inner)?;
Ok(vec![(col, "NULL".into())])
}
Expr::IsNotNull(inner) => {
let col = ident_from_expr(inner)?;
Ok(vec![(col, "__IS_NOT_NULL__".into())])
}
Expr::Nested(inner) => extract_where_conditions(inner),
other => Err(DfOlapError::Other(format!(
"unsupported WHERE expression: {other}"
))),
}
}
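/// Parses an `INSERT INTO table (cols) VALUES (...)` statement and materialises the literal rows
/// as a RecordBatch. An explicit column list is required; other INSERT sources are rejected.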
fn parse_insert_values(sql: &str) -> Result<(String, Vec<String>, Vec<RecordBatch>), DfOlapError> {
let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
.map_err(|e| DfOlapError::Other(format!("failed to parse INSERT: {e}")))?;
let stmt = stmts
.pop()
.ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;
let insert = match stmt {
Statement::Insert(ins) => ins,
other => {
return Err(DfOlapError::Other(format!(
"expected INSERT statement, got: {other:?}"
)));
}
};
let table_name = match &insert.table {
TableObject::TableName(obj_name) => obj_name
.0
.last()
.and_then(|p| p.as_ident())
.map(|id| id.value.clone())
.ok_or_else(|| DfOlapError::Other("empty table name in INSERT".into()))?,
TableObject::TableFunction(_) => {
return Err(DfOlapError::Other(
"INSERT INTO TABLE FUNCTION not supported".into(),
));
}
};
rhei_core::validate_identifier(&table_name).map_err(|e| DfOlapError::Other(e.to_string()))?;
let col_name_strings: Vec<String> = insert.columns.iter().map(|id| id.value.clone()).collect();
let source = match insert.source {
Some(q) => q,
None => return Ok((table_name, col_name_strings, vec![])),
};
let values = match *source.body {
SetExpr::Values(v) => v,
other => {
return Err(DfOlapError::Other(format!(
"INSERT source is not a VALUES clause: {other:?}"
)));
}
};
if values.rows.is_empty() {
return Ok((table_name, col_name_strings, vec![]));
}
let rows: Vec<Vec<String>> = values
.rows
.iter()
.map(|row| {
row.iter()
.map(expr_to_sql_literal)
.collect::<Result<_, _>>()
})
.collect::<Result<_, _>>()?;
let col_name_refs: Vec<&str> = col_name_strings.iter().map(|s| s.as_str()).collect();
let num_cols = col_name_refs.len();
if num_cols == 0 {
return Err(DfOlapError::Other(format!(
"INSERT INTO {table_name} requires an explicit column list; `VALUES (...)` without columns is not supported"
)));
}
let batch = build_record_batch_from_values(&col_name_refs, &rows, num_cols)?;
Ok((table_name, col_name_strings, vec![batch]))
}
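/// Builds a RecordBatch from the literal VALUES rows, inferring each column's type from its first
/// non-NULL literal (Boolean, Utf8, Float64, or Int64, defaulting to Utf8).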
fn build_record_batch_from_values(
col_names: &[&str],
rows: &[Vec<String>],
num_cols: usize,
) -> Result<RecordBatch, DfOlapError> {
use arrow::array::*;
use arrow::datatypes::{Field, Schema};
    let mut types = vec![DataType::Utf8; num_cols];
    for col_idx in 0..num_cols {
for row in rows {
if col_idx < row.len() {
let val = &row[col_idx];
let upper = val.to_ascii_uppercase();
if upper == "NULL" {
continue;
}
if upper == "TRUE" || upper == "FALSE" {
types[col_idx] = DataType::Boolean;
break;
}
if val.starts_with('\'') {
types[col_idx] = DataType::Utf8;
break;
}
if val.contains('.') {
if val.parse::<f64>().is_ok() {
types[col_idx] = DataType::Float64;
break;
}
} else if val.parse::<i64>().is_ok() {
types[col_idx] = DataType::Int64;
break;
}
break;
}
}
}
let fields: Vec<Field> = col_names
.iter()
.zip(types.iter())
.map(|(name, dt)| Field::new(*name, dt.clone(), true))
.collect();
let schema = Arc::new(Schema::new(fields));
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(num_cols);
for col_idx in 0..num_cols {
let col_values: Vec<&str> = rows
.iter()
.map(|row| {
if col_idx < row.len() {
row[col_idx].as_str()
} else {
"NULL"
}
})
.collect();
columns.push(build_array(&types[col_idx], &col_values)?);
}
let batch = RecordBatch::try_new(schema, columns)?;
Ok(batch)
}
fn build_array(dt: &DataType, values: &[&str]) -> Result<Arc<dyn Array>, DfOlapError> {
use arrow::array::*;
match dt {
DataType::Int64 => {
let mut builder = Int64Builder::new();
for v in values {
if v.eq_ignore_ascii_case("NULL") {
builder.append_null();
} else {
builder.append_value(
v.parse::<i64>()
.map_err(|e| DfOlapError::Other(format!("parse i64: {e}")))?,
);
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Float64 => {
let mut builder = Float64Builder::new();
for v in values {
if v.eq_ignore_ascii_case("NULL") {
builder.append_null();
} else {
builder.append_value(
v.parse::<f64>()
.map_err(|e| DfOlapError::Other(format!("parse f64: {e}")))?,
);
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Boolean => {
let mut builder = BooleanBuilder::new();
for v in values {
let upper = v.to_ascii_uppercase();
if upper == "NULL" {
builder.append_null();
} else {
builder.append_value(upper == "TRUE");
}
}
Ok(Arc::new(builder.finish()))
}
_ => {
let mut builder = StringBuilder::new();
for v in values {
if v.eq_ignore_ascii_case("NULL") {
builder.append_null();
} else {
let stripped = if v.starts_with('\'') && v.ends_with('\'') && v.len() >= 2 {
&v[1..v.len() - 1]
} else {
v
};
builder.append_value(stripped.replace("''", "'"));
}
}
Ok(Arc::new(builder.finish()))
}
}
}
type ColVal = (String, String);
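/// Parses an UPDATE statement into the table name, the SET assignments, and the simplified
/// WHERE conditions.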
fn parse_update(sql: &str) -> Result<(String, Vec<ColVal>, Vec<ColVal>), DfOlapError> {
let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
.map_err(|e| DfOlapError::Other(format!("failed to parse UPDATE: {e}")))?;
let stmt = stmts
.pop()
.ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;
let update = match stmt {
Statement::Update(upd) => upd,
other => {
return Err(DfOlapError::Other(format!(
"expected UPDATE statement, got: {other:?}"
)));
}
};
let table_name = match &update.table.relation {
TableFactor::Table { name, .. } => name
.0
.last()
.and_then(|p| p.as_ident())
.map(|id| id.value.clone())
.ok_or_else(|| DfOlapError::Other("empty table name in UPDATE".into()))?,
other => {
return Err(DfOlapError::Other(format!(
"unexpected table factor in UPDATE: {other:?}"
)));
}
};
let assignments: Vec<ColVal> = update
.assignments
.iter()
.map(|a| {
let col = match &a.target {
AssignmentTarget::ColumnName(obj) => obj
.0
.last()
.and_then(|p| p.as_ident())
.map(|id| id.value.clone())
.ok_or_else(|| DfOlapError::Other("empty column name in SET".into()))?,
AssignmentTarget::Tuple(_) => {
return Err(DfOlapError::Other(
"tuple assignments in SET not supported".into(),
));
}
};
let val = expr_to_sql_literal(&a.value)?;
Ok((col, val))
})
.collect::<Result<_, DfOlapError>>()?;
let where_clause = match &update.selection {
Some(expr) => extract_where_conditions(expr)?,
None => vec![],
};
Ok((table_name, assignments, where_clause))
}
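/// Parses a DELETE statement into the table name and the simplified WHERE conditions.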
fn parse_delete(sql: &str) -> Result<(String, Vec<(String, String)>), DfOlapError> {
let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
.map_err(|e| DfOlapError::Other(format!("failed to parse DELETE: {e}")))?;
let stmt = stmts
.pop()
.ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;
let delete = match stmt {
Statement::Delete(del) => del,
other => {
return Err(DfOlapError::Other(format!(
"expected DELETE statement, got: {other:?}"
)));
}
};
let tables = match &delete.from {
FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => tables,
};
let table_name = tables
.first()
.and_then(|twj| {
if let TableFactor::Table { name, .. } = &twj.relation {
name.0
.last()
.and_then(|p| p.as_ident())
.map(|id| id.value.clone())
} else {
None
}
})
.ok_or_else(|| DfOlapError::Other("missing table name in DELETE".into()))?;
let where_clause = match &delete.selection {
Some(expr) => extract_where_conditions(expr)?,
None => vec![],
};
Ok((table_name, where_clause))
}
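/// Concatenates all batches into a single RecordBatch, or returns `None` when there are no rows,
/// so UPDATE and DELETE can operate on one contiguous batch.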
fn flatten_batches(
batches: &[RecordBatch],
schema: &SchemaRef,
) -> Result<Option<RecordBatch>, DfOlapError> {
if batches.is_empty() {
return Ok(None);
}
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
if total_rows == 0 {
return Ok(None);
}
let batch = arrow::compute::concat_batches(schema, batches)?;
Ok(Some(batch))
}
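/// Rewrites the columns named in `assignments` for every row matching `where_conditions`,
/// returning the new batch together with the number of rows updated.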
fn apply_update(
batch: &RecordBatch,
schema: &SchemaRef,
assignments: &[(String, String)],
where_conditions: &[(String, String)],
) -> Result<(RecordBatch, u64), DfOlapError> {
let matching = find_matching_rows(batch, schema, where_conditions)?;
let updated_count = matching.iter().filter(|&&m| m).count() as u64;
let mut new_columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
for (col_idx, field) in schema.fields().iter().enumerate() {
let assignment = assignments.iter().find(|(col, _)| col == field.name());
if let Some((_, new_val)) = assignment {
let original = batch.column(col_idx);
new_columns.push(apply_value_to_matching(
original,
&matching,
new_val,
field.data_type(),
)?);
} else {
new_columns.push(batch.column(col_idx).clone());
}
}
let new_batch = RecordBatch::try_new(schema.clone(), new_columns)?;
Ok((new_batch, updated_count))
}
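/// Filters out every row matching `where_conditions`, returning the surviving rows and the
/// number of rows deleted.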
fn apply_delete(
batch: &RecordBatch,
schema: &SchemaRef,
where_conditions: &[(String, String)],
) -> Result<(RecordBatch, u64), DfOlapError> {
let matching = find_matching_rows(batch, schema, where_conditions)?;
let deleted_count = matching.iter().filter(|&&m| m).count() as u64;
let mut builder = BooleanBuilder::new();
for &m in &matching {
builder.append_value(!m);
}
let filter_array = builder.finish();
let new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
.map(|i| arrow::compute::filter(batch.column(i), &filter_array).map_err(DfOlapError::Arrow))
.collect::<Result<_, _>>()?;
let new_batch = RecordBatch::try_new(schema.clone(), new_columns)?;
Ok((new_batch, deleted_count))
}
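/// Evaluates the conjunction of all `(column, literal)` conditions against each row and returns
/// one boolean per row.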
fn find_matching_rows(
batch: &RecordBatch,
schema: &SchemaRef,
conditions: &[(String, String)],
) -> Result<Vec<bool>, DfOlapError> {
let num_rows = batch.num_rows();
let mut matching = vec![true; num_rows];
for (col_name, expected_val) in conditions {
let col_idx = schema
.fields()
.iter()
.position(|f| f.name() == col_name)
.ok_or_else(|| DfOlapError::Other(format!("column not found: {col_name}")))?;
let col = batch.column(col_idx);
for (row_idx, m) in matching.iter_mut().enumerate() {
if !*m {
continue;
}
*m = value_matches(col, row_idx, expected_val);
}
}
Ok(matching)
}
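/// Compares a single cell with a SQL literal: handles the NULL / `__IS_NOT_NULL__` markers,
/// numeric parsing per Arrow type, quoted strings, and booleans; unsupported types never match.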
fn value_matches(array: &dyn Array, row_idx: usize, expected: &str) -> bool {
if expected == "__IS_NOT_NULL__" {
return !array.is_null(row_idx);
}
if array.is_null(row_idx) {
return expected.eq_ignore_ascii_case("NULL");
}
match array.data_type() {
DataType::Int8 => {
expected.parse::<i8>().ok() == Some(array.as_primitive::<Int8Type>().value(row_idx))
}
DataType::Int16 => {
expected.parse::<i16>().ok() == Some(array.as_primitive::<Int16Type>().value(row_idx))
}
DataType::Int32 => {
expected.parse::<i32>().ok() == Some(array.as_primitive::<Int32Type>().value(row_idx))
}
DataType::Int64 => {
expected.parse::<i64>().ok() == Some(array.as_primitive::<Int64Type>().value(row_idx))
}
DataType::UInt8 => {
expected.parse::<u8>().ok() == Some(array.as_primitive::<UInt8Type>().value(row_idx))
}
DataType::UInt16 => {
expected.parse::<u16>().ok() == Some(array.as_primitive::<UInt16Type>().value(row_idx))
}
DataType::UInt32 => {
expected.parse::<u32>().ok() == Some(array.as_primitive::<UInt32Type>().value(row_idx))
}
DataType::UInt64 => {
expected.parse::<u64>().ok() == Some(array.as_primitive::<UInt64Type>().value(row_idx))
}
DataType::Float32 => {
expected.parse::<f32>().ok() == Some(array.as_primitive::<Float32Type>().value(row_idx))
}
DataType::Float64 => {
expected.parse::<f64>().ok() == Some(array.as_primitive::<Float64Type>().value(row_idx))
}
DataType::Utf8 => {
let arr = array.as_string::<i32>();
let stripped =
if expected.starts_with('\'') && expected.ends_with('\'') && expected.len() >= 2 {
&expected[1..expected.len() - 1]
} else {
expected
};
arr.value(row_idx) == stripped
}
DataType::Boolean => {
let arr = array.as_boolean();
match expected.to_ascii_uppercase().as_str() {
"TRUE" => arr.value(row_idx),
"FALSE" => !arr.value(row_idx),
_ => false,
}
}
_ => false,
}
}
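/// Builds a replacement column where matching rows take `new_val` (parsed for the column's
/// Arrow type) and non-matching rows keep their original value.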
fn apply_value_to_matching(
original: &dyn Array,
matching: &[bool],
new_val: &str,
dt: &DataType,
) -> Result<Arc<dyn Array>, DfOlapError> {
use arrow::array::*;
match dt {
DataType::Int64 => {
let orig = original.as_primitive::<Int64Type>();
let parsed: i64 = new_val
.parse()
.map_err(|e| DfOlapError::Other(format!("parse i64: {e}")))?;
let mut builder = Int64Builder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(parsed);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Float64 => {
let orig = original.as_primitive::<Float64Type>();
let parsed: f64 = new_val
.parse()
.map_err(|e| DfOlapError::Other(format!("parse f64: {e}")))?;
let mut builder = Float64Builder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(parsed);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Utf8 => {
let orig = original.as_string::<i32>();
let stripped =
if new_val.starts_with('\'') && new_val.ends_with('\'') && new_val.len() >= 2 {
&new_val[1..new_val.len() - 1]
} else {
new_val
};
let unescaped = stripped.replace("''", "'");
let mut builder = StringBuilder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(&unescaped);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Boolean => {
let orig = original.as_boolean();
let parsed = new_val.eq_ignore_ascii_case("TRUE");
let mut builder = BooleanBuilder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(parsed);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
_ => {
let orig = original.as_string::<i32>();
let mut builder = StringBuilder::new();
for (i, &m) in matching.iter().enumerate() {
if m {
builder.append_value(new_val);
} else if orig.is_null(i) {
builder.append_null();
} else {
builder.append_value(orig.value(i));
}
}
Ok(Arc::new(builder.finish()))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::{Field, Schema};
use rhei_core::OlapEngine;
fn users_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
]))
}
fn make_in_memory(_: &std::path::Path) -> DataFusionEngine {
DataFusionEngine::new()
}
fn make_arrow_ipc(tmp: &std::path::Path) -> DataFusionEngine {
DataFusionEngine::with_storage(StorageMode::ArrowIpc {
path: tmp.join("arrow_olap"),
})
.unwrap()
}
fn make_parquet(tmp: &std::path::Path) -> DataFusionEngine {
DataFusionEngine::with_storage(StorageMode::Parquet {
path: tmp.join("parquet_olap"),
})
.unwrap()
}
macro_rules! storage_mode_tests {
($mod_name:ident, $make_engine:ident) => {
mod $mod_name {
use super::*;
#[tokio::test]
async fn create_and_query_empty() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
assert!(engine.table_exists("users").await.unwrap());
assert!(!engine.table_exists("nonexistent").await.unwrap());
}
#[tokio::test]
async fn insert_and_query() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.await
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.await
.unwrap();
let batches = engine
.query("SELECT * FROM users ORDER BY id")
.await
.unwrap();
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 2);
}
#[tokio::test]
async fn update() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.await
.unwrap();
let rows = engine
.execute("UPDATE users SET age = 31 WHERE id = 1")
.await
.unwrap();
assert_eq!(rows, 1);
let batches = engine
.query("SELECT age FROM users WHERE id = 1")
.await
.unwrap();
let age = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(age, 31);
}
#[tokio::test]
async fn delete() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute(
"INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)",
)
.await
.unwrap();
let rows = engine
.execute("DELETE FROM users WHERE id = 1")
.await
.unwrap();
assert_eq!(rows, 1);
let batches = engine.query("SELECT COUNT(*) FROM users").await.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 1);
}
#[tokio::test]
async fn load_arrow() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
Arc::new(arrow::array::StringArray::from(vec![
"Alice", "Bob", "Charlie",
])),
Arc::new(arrow::array::Int64Array::from(vec![30, 25, 35])),
],
)
.unwrap();
let loaded = engine.load_arrow("users", &[batch]).await.unwrap();
assert_eq!(loaded, 3);
let batches = engine
.query("SELECT COUNT(*) as cnt FROM users")
.await
.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 3);
}
#[tokio::test]
async fn aggregate() {
let _tmp = tempfile::tempdir().unwrap();
let engine = $make_engine(_tmp.path());
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute(
"INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25), (3, 'Charlie', 35)",
)
.await
.unwrap();
let batches = engine
.query("SELECT AVG(age) as avg_age FROM users")
.await
.unwrap();
let avg = batches[0].column(0).as_primitive::<Float64Type>().value(0);
assert!((avg - 30.0).abs() < 0.01);
}
}
};
}
storage_mode_tests!(in_memory, make_in_memory);
storage_mode_tests!(arrow_ipc, make_arrow_ipc);
storage_mode_tests!(parquet, make_parquet);
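// The tests below cover SQL edge cases (quoting, NULLs, compound WHERE) and the
// parser helpers, using the in-memory engine only; cloud-storage tests follow
// under the `cloud-storage` feature.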
#[tokio::test]
async fn insert_string_with_comma() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice, B', 30)")
.await
.unwrap();
let batches = engine
.query("SELECT name FROM users WHERE id = 1")
.await
.unwrap();
let name_arr = batches[0].column(0).as_string::<i32>();
assert_eq!(name_arr.value(0), "Alice, B");
}
#[tokio::test]
async fn insert_null_value() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, NULL, 30)")
.await
.unwrap();
let batches = engine
.query("SELECT name FROM users WHERE id = 1")
.await
.unwrap();
assert!(batches[0].column(0).is_null(0));
}
#[tokio::test]
async fn update_where_and() {
let engine = DataFusionEngine::new();
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("id", DataType::Int64, false),
arrow::datatypes::Field::new("name", DataType::Utf8, true),
arrow::datatypes::Field::new("status", DataType::Utf8, true),
]));
engine.create_table("t", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO t (id, name, status) VALUES (1, 'x', 'active')")
.await
.unwrap();
engine
.execute("INSERT INTO t (id, name, status) VALUES (2, 'y', 'inactive')")
.await
.unwrap();
let updated = engine
.execute("UPDATE t SET name = 'updated' WHERE id = 1 AND status = 'active'")
.await
.unwrap();
assert_eq!(updated, 1);
let batches = engine
.query("SELECT name FROM t WHERE id = 1")
.await
.unwrap();
assert_eq!(batches[0].column(0).as_string::<i32>().value(0), "updated");
let batches2 = engine
.query("SELECT name FROM t WHERE id = 2")
.await
.unwrap();
assert_eq!(batches2[0].column(0).as_string::<i32>().value(0), "y");
}
#[tokio::test]
async fn delete_quoted_identifier() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
.await
.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
.await
.unwrap();
let deleted = engine
.execute(r#"DELETE FROM "users" WHERE id = 1"#)
.await
.unwrap();
assert_eq!(deleted, 1);
let batches = engine.query("SELECT COUNT(*) FROM users").await.unwrap();
let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
assert_eq!(count, 1);
}
#[tokio::test]
async fn insert_escaped_single_quote() {
let engine = DataFusionEngine::new();
let schema = users_schema();
engine.create_table("users", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO users (id, name, age) VALUES (1, 'O''Brien', 42)")
.await
.unwrap();
let batches = engine
.query("SELECT name FROM users WHERE id = 1")
.await
.unwrap();
assert_eq!(batches[0].column(0).as_string::<i32>().value(0), "O'Brien");
}
#[test]
fn parse_insert_multi_row() {
let (table, cols, batches) =
parse_insert_values("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')")
.unwrap();
assert_eq!(table, "users");
assert_eq!(cols, vec!["id", "name"]);
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].num_rows(), 2);
}
#[test]
fn parse_update_basic() {
let (table, assignments, where_clause) =
parse_update("UPDATE users SET name = 'Alice' WHERE id = 1").unwrap();
assert_eq!(table, "users");
assert_eq!(
assignments,
vec![("name".to_string(), "'Alice'".to_string())]
);
assert_eq!(where_clause, vec![("id".to_string(), "1".to_string())]);
}
#[test]
fn parse_delete_no_where() {
let (table, conditions) = parse_delete("DELETE FROM logs").unwrap();
assert_eq!(table, "logs");
assert!(conditions.is_empty());
}
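// Cloud-storage tests: URL/bucket parsing, table-URL construction, and engine
// construction smoke tests. None of them talk to a real bucket; the
// `_does_not_panic` tests deliberately discard the Result, since building the
// object store may fail when no credentials are present in the environment.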
#[cfg(feature = "cloud-storage")]
#[test]
fn storage_mode_s3_parquet_attributes() {
let mode = StorageMode::S3Parquet {
url: "s3://my-bucket/rhei-data".to_string(),
};
assert!(mode.is_cloud());
assert_eq!(mode.file_extension(), "parquet");
assert!(mode.base_path().is_none());
assert_eq!(mode.cloud_base_url(), Some("s3://my-bucket/rhei-data"));
}
#[cfg(feature = "cloud-storage")]
#[test]
fn storage_mode_gcs_parquet_attributes() {
let mode = StorageMode::GcsParquet {
url: "gs://gcs-bucket/prefix".to_string(),
};
assert!(mode.is_cloud());
assert_eq!(mode.file_extension(), "parquet");
assert!(mode.base_path().is_none());
assert_eq!(mode.cloud_base_url(), Some("gs://gcs-bucket/prefix"));
}
#[cfg(feature = "cloud-storage")]
#[test]
fn parse_bucket_s3() {
let bucket = DataFusionEngine::parse_bucket("s3://my-bucket/some/prefix", "s3").unwrap();
assert_eq!(bucket, "my-bucket");
}
#[cfg(feature = "cloud-storage")]
#[test]
fn parse_bucket_gcs() {
let bucket = DataFusionEngine::parse_bucket("gs://gcs-bucket/data", "gs").unwrap();
assert_eq!(bucket, "gcs-bucket");
}
#[cfg(feature = "cloud-storage")]
#[test]
fn parse_bucket_wrong_scheme_returns_error() {
let result = DataFusionEngine::parse_bucket("gs://bucket/data", "s3");
assert!(result.is_err());
}
#[cfg(feature = "cloud-storage")]
#[test]
fn cloud_table_url_construction() {
assert_eq!(
DataFusionEngine::cloud_table_url("s3://bucket/prefix", "events"),
"s3://bucket/prefix/events/"
);
assert_eq!(
DataFusionEngine::cloud_table_url("s3://bucket/prefix/", "logs"),
"s3://bucket/prefix/logs/"
);
}
#[cfg(feature = "cloud-storage")]
#[test]
fn s3_parquet_engine_construction_does_not_panic() {
let result = DataFusionEngine::with_storage(StorageMode::S3Parquet {
url: "s3://test-bucket/test-prefix".to_string(),
});
let _ = result;
}
#[cfg(feature = "cloud-storage")]
#[test]
fn gcs_parquet_engine_construction_does_not_panic() {
let result = DataFusionEngine::with_storage(StorageMode::GcsParquet {
url: "gs://test-bucket/test-prefix".to_string(),
});
let _ = result;
}
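// Sequence-number scanning against an in-memory object store: verifies that the
// write counter resumes after existing `<table>_NNNNNN.parquet` segments instead
// of restarting at zero and overwriting them.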
#[cfg(feature = "cloud-storage")]
#[tokio::test]
async fn cloud_seq_for_prefix_empty_prefix_returns_zero() {
let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());
let table_url = "s3://bucket/data/events/";
let next = DataFusionEngine::cloud_seq_for_prefix(&store, table_url, "events")
.await
.unwrap();
assert_eq!(next, 0, "empty prefix should yield counter = 0");
}
#[cfg(feature = "cloud-storage")]
#[tokio::test]
async fn cloud_seq_for_prefix_advances_past_existing_files() {
let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());
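// Seed the store with three existing segments (events_000000..events_000002)
// so the scanner must report 3 as the next free sequence number.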
for seq in 0u64..3 {
let path = object_store::path::Path::from(
format!("data/events/events_{seq:06}.parquet").as_str(),
);
store
.put(&path, bytes::Bytes::from_static(b"dummy").into())
.await
.unwrap();
}
let table_url = "s3://bucket/data/events/";
let next = DataFusionEngine::cloud_seq_for_prefix(&store, table_url, "events")
.await
.unwrap();
assert_eq!(next, 3, "counter should start at max_existing + 1");
}
#[cfg(feature = "cloud-storage")]
#[tokio::test]
async fn cloud_engine_restart_does_not_overwrite_existing_files() {
let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());
for seq in 0u64..2 {
let path = object_store::path::Path::from(
format!("prefix/users/users_{seq:06}.parquet").as_str(),
);
store
.put(&path, bytes::Bytes::from_static(b"parquet-data").into())
.await
.unwrap();
}
let table_url = "s3://bucket/prefix/users/";
let next = DataFusionEngine::cloud_seq_for_prefix(&store, table_url, "users")
.await
.unwrap();
assert_eq!(
next, 2,
"restarted engine should begin writing at index 2, not 0"
);
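// Simulate the post-restart counter locally: seed an AtomicU64 with the
// scanned sequence number, then claim the next write index from it.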
let counter = AtomicU64::new(0);
counter.fetch_max(next, Ordering::Relaxed);
let seq = counter.fetch_add(1, Ordering::Relaxed);
assert_eq!(seq, 2, "first write after restart should use index 2");
for orig_seq in 0u64..2 {
let path = object_store::path::Path::from(
format!("prefix/users/users_{orig_seq:06}.parquet").as_str(),
);
let result = store.get(&path).await;
assert!(
result.is_ok(),
"original file users_{orig_seq:06}.parquet must not be overwritten"
);
}
}
}