use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use arrow::array::{Array, AsArray, BooleanBuilder, RecordBatch};
use arrow::datatypes::{
    DataType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
    UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use datafusion::datasource::file_format::arrow::ArrowFormat;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::file_format::FileFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::prelude::*;
use sqlparser::ast::{
    AssignmentTarget, BinaryOperator, Expr, FromTable, SetExpr, Statement, TableFactor,
    TableObject, UnaryOperator, Value,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use tokio::sync::RwLock;
use tracing::debug;

use crate::error::DfOlapError;
use crate::storage::StorageMode;

#[cfg(feature = "cloud-storage")]
use object_store::ObjectStore;
#[cfg(feature = "cloud-storage")]
use url::Url;

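/// In-memory table state: the schema plus every batch appended so far.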
struct TableData {
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
}

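/// Metadata for a table persisted as data files under a per-table directory.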
struct FileTableMeta {
    schema: SchemaRef,
    dir: PathBuf,
}

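/// Metadata for a table persisted as Parquet objects under a cloud URL prefix.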
#[cfg(feature = "cloud-storage")]
struct CloudTableMeta {
    schema: SchemaRef,
    table_url: String,
}

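/// OLAP engine backed by DataFusion. Depending on `storage_mode`, table data
/// lives in memory (`tables`), in local Arrow IPC or Parquet files
/// (`file_tables`), or in cloud object storage (`cloud_tables`).
/// `file_counter` hands out monotonically increasing sequence numbers for
/// data file names.
///
/// A minimal usage sketch (not compiled here; assumes the
/// `rhei_core::OlapEngine` trait is in scope):
///
/// ```ignore
/// let engine = DataFusionEngine::new();
/// engine.create_table("events", &schema, &[]).await?;
/// engine.execute("INSERT INTO events (id) VALUES (1)").await?;
/// let batches = engine.query("SELECT * FROM events").await?;
/// ```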
pub struct DataFusionEngine {
    ctx: RwLock<SessionContext>,
    tables: RwLock<HashMap<String, TableData>>,
    file_tables: RwLock<HashMap<String, FileTableMeta>>,
    #[cfg(feature = "cloud-storage")]
    cloud_tables: RwLock<HashMap<String, CloudTableMeta>>,
    #[cfg(feature = "cloud-storage")]
    cloud_store: Option<Arc<dyn ObjectStore>>,
    storage_mode: StorageMode,
    file_counter: AtomicU64,
}

impl DataFusionEngine {
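    /// Creates an engine for the given storage mode. For file-backed modes
    /// this creates the base directory and resumes the file sequence counter
    /// from whatever data files already exist on disk.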
    pub fn with_storage(mode: StorageMode) -> Result<Self, DfOlapError> {
        let start_counter = if let Some(base_path) = mode.base_path() {
            std::fs::create_dir_all(base_path)?;
            Self::scan_max_file_seq(base_path, mode.file_extension())
        } else {
            0
        };

        #[cfg(feature = "cloud-storage")]
        let (ctx, cloud_store) = Self::build_session_context(&mode)?;
        #[cfg(not(feature = "cloud-storage"))]
        let ctx = SessionContext::new();

        Ok(Self {
            ctx: RwLock::new(ctx),
            tables: RwLock::new(HashMap::new()),
            file_tables: RwLock::new(HashMap::new()),
            #[cfg(feature = "cloud-storage")]
            cloud_tables: RwLock::new(HashMap::new()),
            #[cfg(feature = "cloud-storage")]
            cloud_store,
            storage_mode: mode,
            file_counter: AtomicU64::new(start_counter),
        })
    }

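    /// Builds the `SessionContext` and, for the S3/GCS modes, registers the
    /// matching object store (configured from the environment) under the
    /// bucket's base URL.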
    #[cfg(feature = "cloud-storage")]
    fn build_session_context(
        mode: &StorageMode,
    ) -> Result<(SessionContext, Option<Arc<dyn ObjectStore>>), DfOlapError> {
        let ctx = SessionContext::new();
        let mut cloud_store: Option<Arc<dyn ObjectStore>> = None;

        match mode {
            StorageMode::S3Parquet { url } => {
                let bucket = Self::parse_bucket(url, "s3")?;
                let store: Arc<dyn ObjectStore> = Arc::new(
                    object_store::aws::AmazonS3Builder::from_env()
                        .with_bucket_name(&bucket)
                        .build()?,
                );
                let base_url =
                    Url::parse(&format!("s3://{bucket}")).map_err(DfOlapError::UrlParse)?;
                ctx.runtime_env()
                    .register_object_store(&base_url, store.clone());
                cloud_store = Some(store);
                tracing::info!(bucket, "registered S3 object store");
            }
            StorageMode::GcsParquet { url } => {
                let bucket = Self::parse_bucket(url, "gs")?;
                let store: Arc<dyn ObjectStore> = Arc::new(
                    object_store::gcp::GoogleCloudStorageBuilder::from_env()
                        .with_bucket_name(&bucket)
                        .build()?,
                );
                let base_url =
                    Url::parse(&format!("gs://{bucket}")).map_err(DfOlapError::UrlParse)?;
                ctx.runtime_env()
                    .register_object_store(&base_url, store.clone());
                cloud_store = Some(store);
                tracing::info!(bucket, "registered GCS object store");
            }
            _ => {}
        }

        Ok((ctx, cloud_store))
    }

    #[cfg(feature = "cloud-storage")]
    fn parse_bucket(url: &str, expected_scheme: &str) -> Result<String, DfOlapError> {
        let parsed = Url::parse(url).map_err(DfOlapError::UrlParse)?;
        if parsed.scheme() != expected_scheme {
            return Err(DfOlapError::Other(format!(
                "expected {expected_scheme}:// URL, got '{url}'"
            )));
        }
        parsed
            .host_str()
            .map(|h| h.to_string())
            .ok_or_else(|| DfOlapError::Other(format!("missing bucket name in URL '{url}'")))
    }

    #[cfg(feature = "cloud-storage")]
    fn cloud_table_url(base_url: &str, table_name: &str) -> String {
        let base = base_url.trim_end_matches('/');
        format!("{base}/{table_name}/")
    }

    pub fn new() -> Self {
        Self::with_storage(StorageMode::InMemory).expect("in-memory mode cannot fail")
    }

    pub fn storage_mode(&self) -> &StorageMode {
        &self.storage_mode
    }

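    /// Scans per-table subdirectories of `base_path` for data files named
    /// `<table>_<seq>.<ext>` and returns one past the highest sequence seen,
    /// so files written after a restart never collide with existing ones.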
    fn scan_max_file_seq(base_path: &std::path::Path, ext: &str) -> u64 {
        let mut max_seq: u64 = 0;
        let Ok(entries) = std::fs::read_dir(base_path) else {
            return 0;
        };
        for entry in entries.flatten() {
            if !entry.path().is_dir() {
                continue;
            }
            let Ok(files) = std::fs::read_dir(entry.path()) else {
                continue;
            };
            for file in files.flatten() {
                let path = file.path();
                if path.extension().and_then(|x| x.to_str()) != Some(ext) {
                    continue;
                }
                if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
                    if let Some(seq_str) = stem.rsplit('_').next() {
                        if let Ok(seq) = seq_str.parse::<u64>() {
                            max_seq = max_seq.max(seq + 1);
                        }
                    }
                }
            }
        }
        max_seq
    }

    fn next_file_name(&self, table_name: &str) -> String {
        let seq = self.file_counter.fetch_add(1, Ordering::Relaxed);
        let ext = self.storage_mode.file_extension();
        format!("{table_name}_{seq:06}.{ext}")
    }

    fn table_dir(&self, table_name: &str) -> Option<PathBuf> {
        self.storage_mode
            .base_path()
            .map(|base| base.join(table_name))
    }

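    /// Re-registers an in-memory table with the session context so
    /// subsequent queries see the current set of batches.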
    async fn refresh_table_mem(&self, name: &str) -> Result<(), DfOlapError> {
        let tables = self.tables.read().await;
        let table_data = tables
            .get(name)
            .ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;

        let partitions = if table_data.batches.is_empty() {
            vec![vec![]]
        } else {
            vec![table_data.batches.clone()]
        };
        let mem_table = MemTable::try_new(table_data.schema.clone(), partitions)?;

        let ctx = self.ctx.write().await;
        let _ = ctx.deregister_table(name);
        ctx.register_table(name, Arc::new(mem_table))?;
        Ok(())
    }

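    /// Re-registers a file-backed table as a `ListingTable` over its
    /// directory, using the file format that matches the storage mode.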
    async fn refresh_table_file(&self, name: &str) -> Result<(), DfOlapError> {
        let file_tables = self.file_tables.read().await;
        let meta = file_tables
            .get(name)
            .ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;

        let table_path = meta.dir.to_string_lossy().to_string();
        let format: Arc<dyn FileFormat> = match &self.storage_mode {
            StorageMode::ArrowIpc { .. } => Arc::new(ArrowFormat),
            StorageMode::Parquet { .. } => Arc::new(ParquetFormat::default()),
            StorageMode::InMemory => unreachable!(),
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => unreachable!(),
        };

        let ext = self.storage_mode.file_extension();
        let listing_options = ListingOptions::new(format).with_file_extension(ext);

        let config =
            ListingTableConfig::new_with_multi_paths(vec![ListingTableUrl::parse(&table_path)?])
                .with_listing_options(listing_options)
                .with_schema(meta.schema.clone());

        let listing_table = ListingTable::try_new(config)?;

        let ctx = self.ctx.write().await;
        let _ = ctx.deregister_table(name);
        ctx.register_table(name, Arc::new(listing_table))?;
        Ok(())
    }

    #[cfg(feature = "cloud-storage")]
    async fn refresh_table_cloud(&self, name: &str) -> Result<(), DfOlapError> {
        let cloud_tables = self.cloud_tables.read().await;
        let meta = cloud_tables
            .get(name)
            .ok_or_else(|| DfOlapError::TableNotFound(name.to_string()))?;

        let table_url = meta.table_url.clone();
        let schema = meta.schema.clone();
        drop(cloud_tables);

        let format: Arc<dyn FileFormat> = Arc::new(ParquetFormat::default());
        let listing_options = ListingOptions::new(format).with_file_extension("parquet");

        let config =
            ListingTableConfig::new_with_multi_paths(vec![ListingTableUrl::parse(&table_url)?])
                .with_listing_options(listing_options)
                .with_schema(schema);

        let listing_table = ListingTable::try_new(config)?;

        let ctx = self.ctx.write().await;
        let _ = ctx.deregister_table(name);
        ctx.register_table(name, Arc::new(listing_table))?;
        Ok(())
    }

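    /// Writes `batches` to a freshly sequenced data file in the table's
    /// directory, doing the actual I/O on a blocking task. Returns the path
    /// of the file that was written.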
    async fn write_batches_to_file(
        &self,
        table_name: &str,
        schema: &SchemaRef,
        batches: &[RecordBatch],
    ) -> Result<PathBuf, DfOlapError> {
        let dir = self
            .table_dir(table_name)
            .ok_or_else(|| DfOlapError::Other("no table dir for in-memory mode".into()))?;

        let file_name = self.next_file_name(table_name);
        let file_path = dir.join(&file_name);

        let schema = schema.clone();
        let batches: Vec<RecordBatch> = batches.to_vec();
        let path = file_path.clone();

        match &self.storage_mode {
            StorageMode::ArrowIpc { .. } => {
                tokio::task::spawn_blocking(move || {
                    let file = std::fs::File::create(&path)?;
                    let mut writer =
                        arrow::ipc::writer::FileWriter::try_new(file, schema.as_ref())?;
                    for batch in &batches {
                        writer.write(batch)?;
                    }
                    writer.finish()?;
                    Ok::<_, DfOlapError>(())
                })
                .await
                .map_err(DfOlapError::from_join)??;
            }
            StorageMode::Parquet { .. } => {
                tokio::task::spawn_blocking(move || {
                    let file = std::fs::File::create(&path)?;
                    let props = parquet::file::properties::WriterProperties::builder()
                        .set_writer_version(parquet::file::properties::WriterVersion::PARQUET_2_0)
                        .build();
                    let mut writer =
                        parquet::arrow::ArrowWriter::try_new(file, schema, Some(props))?;
                    for batch in &batches {
                        writer.write(batch)?;
                    }
                    writer.close()?;
                    Ok::<_, DfOlapError>(())
                })
                .await
                .map_err(DfOlapError::from_join)??;
            }
            StorageMode::InMemory => unreachable!(),
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => unreachable!(),
        }

        Ok(file_path)
    }

    fn list_data_files(dir: &std::path::Path, ext: &str) -> Result<Vec<PathBuf>, DfOlapError> {
        let mut files: Vec<PathBuf> = std::fs::read_dir(dir)?
            .filter_map(|e| e.ok())
            .map(|e| e.path())
            .filter(|p| p.extension().is_some_and(|x| x.to_str() == Some(ext)))
            .collect();
        files.sort();
        Ok(files)
    }

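    /// Reads every data file of a file-backed table back into memory in
    /// sorted file order; UPDATE, DELETE, and schema changes use this to
    /// rewrite the table wholesale.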
    async fn read_all_batches(
        &self,
        table_name: &str,
    ) -> Result<(SchemaRef, Vec<RecordBatch>), DfOlapError> {
        let file_tables = self.file_tables.read().await;
        let meta = file_tables
            .get(table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;

        let schema = meta.schema.clone();
        let dir = meta.dir.clone();
        let ext = self.storage_mode.file_extension().to_string();
        let is_arrow_ipc = matches!(self.storage_mode, StorageMode::ArrowIpc { .. });

        drop(file_tables);

        tokio::task::spawn_blocking(move || {
            let mut all_batches = Vec::new();
            let files = Self::list_data_files(&dir, &ext)?;

            for path in files {
                if is_arrow_ipc {
                    let file = std::fs::File::open(&path)?;
                    let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;
                    for batch in reader {
                        all_batches.push(batch?);
                    }
                } else {
                    let file = std::fs::File::open(&path)?;
                    let reader = parquet::arrow::arrow_reader::ParquetRecordBatchReader::try_new(
                        file, 8192,
                    )?;
                    for batch in reader {
                        all_batches.push(batch?);
                    }
                }
            }

            Ok::<_, DfOlapError>((schema, all_batches))
        })
        .await
        .map_err(DfOlapError::from_join)?
    }

    async fn clear_table_dir(&self, table_name: &str) -> Result<(), DfOlapError> {
        let dir = match self.table_dir(table_name) {
            Some(d) => d,
            None => return Ok(()),
        };
        let ext = self.storage_mode.file_extension().to_string();

        tokio::task::spawn_blocking(move || {
            let files = Self::list_data_files(&dir, &ext)?;
            for path in files {
                std::fs::remove_file(path)?;
            }
            Ok::<_, DfOlapError>(())
        })
        .await
        .map_err(DfOlapError::from_join)?
    }

    #[cfg(feature = "cloud-storage")]
    fn cloud_store(&self) -> Result<Arc<dyn ObjectStore>, DfOlapError> {
        self.cloud_store
            .clone()
            .ok_or_else(|| DfOlapError::Other("cloud store not initialised".into()))
    }

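    /// Serialises `batches` into an in-memory Parquet buffer on a blocking
    /// task, then uploads it as a new sequenced object under the table's
    /// prefix.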
    #[cfg(feature = "cloud-storage")]
    async fn write_batches_to_cloud(
        &self,
        table_name: &str,
        schema: &SchemaRef,
        batches: &[RecordBatch],
    ) -> Result<(), DfOlapError> {
        let cloud_tables = self.cloud_tables.read().await;
        let meta = cloud_tables
            .get(table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
        let table_url = meta.table_url.clone();
        drop(cloud_tables);

        let store = self.cloud_store()?;

        let schema = schema.clone();
        let batches: Vec<RecordBatch> = batches.to_vec();
        let parquet_bytes = tokio::task::spawn_blocking(move || {
            let mut buf = Vec::new();
            let props = parquet::file::properties::WriterProperties::builder()
                .set_writer_version(parquet::file::properties::WriterVersion::PARQUET_2_0)
                .build();
            let mut writer = parquet::arrow::ArrowWriter::try_new(&mut buf, schema, Some(props))?;
            for batch in &batches {
                writer.write(batch)?;
            }
            writer.close()?;
            Ok::<_, DfOlapError>(buf)
        })
        .await
        .map_err(DfOlapError::from_join)??;

        let seq = self.file_counter.fetch_add(1, Ordering::Relaxed);
        let object_key = format!("{table_name}_{seq:06}.parquet");
        let table_path_prefix = Url::parse(&table_url)
            .map_err(DfOlapError::UrlParse)?
            .path()
            .trim_start_matches('/')
            .trim_end_matches('/')
            .to_string();
        let object_path =
            object_store::path::Path::from(format!("{table_path_prefix}/{object_key}").as_str());

        use object_store::ObjectStore as _;
        store.put(&object_path, parquet_bytes.into()).await?;

        Ok(())
    }

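    /// Lists `<table_name>_<seq>.parquet` objects under the table prefix and
    /// returns one past the highest sequence found (0 if none), so a
    /// restarted engine does not overwrite existing objects.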
    #[cfg(feature = "cloud-storage")]
    async fn cloud_seq_for_prefix(
        store: &Arc<dyn ObjectStore>,
        table_url: &str,
        table_name: &str,
    ) -> Result<u64, DfOlapError> {
        let prefix_str = Url::parse(table_url)
            .map_err(DfOlapError::UrlParse)?
            .path()
            .trim_start_matches('/')
            .trim_end_matches('/')
            .to_string();
        let prefix = object_store::path::Path::from(prefix_str.as_str());

        use futures::StreamExt as _;
        use object_store::ObjectStore as _;
        let mut list_stream = store.list(Some(&prefix));

        let file_prefix = format!("{table_name}_");
        let mut max_seq: Option<u64> = None;
        while let Some(item) = list_stream.next().await {
            let m = item?;
            let path_str = m.location.to_string();
            let file_name = path_str.rsplit('/').next().unwrap_or("");
            if !file_name.starts_with(&file_prefix) || !file_name.ends_with(".parquet") {
                continue;
            }
            let inner = &file_name[file_prefix.len()..file_name.len() - ".parquet".len()];
            if let Ok(seq) = inner.parse::<u64>() {
                max_seq = Some(max_seq.map_or(seq, |m| m.max(seq)));
            }
        }

        Ok(max_seq.map_or(0, |m| m + 1))
    }

    #[cfg(feature = "cloud-storage")]
    async fn list_cloud_objects(
        &self,
        table_name: &str,
    ) -> Result<(Arc<dyn ObjectStore>, String, Vec<object_store::path::Path>), DfOlapError> {
        let cloud_tables = self.cloud_tables.read().await;
        let meta = cloud_tables
            .get(table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
        let table_url = meta.table_url.clone();
        drop(cloud_tables);

        let store = self.cloud_store()?;

        let prefix_str = Url::parse(&table_url)
            .map_err(DfOlapError::UrlParse)?
            .path()
            .trim_start_matches('/')
            .trim_end_matches('/')
            .to_string();
        let prefix = object_store::path::Path::from(prefix_str.as_str());

        use futures::StreamExt as _;
        use object_store::ObjectStore as _;
        let mut list_stream = store.list(Some(&prefix));
        let mut paths = Vec::new();
        while let Some(item) = list_stream.next().await {
            let m = item?;
            if m.location.to_string().ends_with(".parquet") {
                paths.push(m.location);
            }
        }
        paths.sort_by_key(|a| a.to_string());

        Ok((store, table_url, paths))
    }

    #[cfg(feature = "cloud-storage")]
    async fn read_all_batches_cloud(
        &self,
        table_name: &str,
    ) -> Result<(SchemaRef, Vec<RecordBatch>), DfOlapError> {
        let cloud_tables = self.cloud_tables.read().await;
        let meta = cloud_tables
            .get(table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;
        let schema = meta.schema.clone();
        drop(cloud_tables);

        let (store, _, object_paths) = self.list_cloud_objects(table_name).await?;

        let mut all_batches = Vec::new();
        for path in object_paths {
            use object_store::ObjectStore as _;
            let get_result = store.get(&path).await?;
            let bytes = get_result.bytes().await?;
            let mut batch_vec = tokio::task::spawn_blocking(move || {
                let reader =
                    parquet::arrow::arrow_reader::ParquetRecordBatchReader::try_new(bytes, 8192)?;
                let mut batches = Vec::new();
                for b in reader {
                    batches.push(b?);
                }
                Ok::<_, DfOlapError>(batches)
            })
            .await
            .map_err(DfOlapError::from_join)??;
            all_batches.append(&mut batch_vec);
        }

        Ok((schema, all_batches))
    }

    #[cfg(feature = "cloud-storage")]
    async fn clear_cloud_table(&self, table_name: &str) -> Result<(), DfOlapError> {
        let (store, _, paths) = self.list_cloud_objects(table_name).await?;
        use object_store::ObjectStore as _;
        for path in paths {
            store.delete(&path).await?;
        }
        Ok(())
    }

    async fn execute_sql(&self, sql: &str) -> Result<Vec<RecordBatch>, DfOlapError> {
        let ctx = self.ctx.read().await;
        let df = ctx.sql(sql).await?;
        let batches = df.collect().await?;
        Ok(batches)
    }

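    /// Reorders, and where necessary casts, the columns of parsed INSERT
    /// batches to match the table schema; returns the aligned batches and
    /// the total row count.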
    fn align_batches_to_schema(
        table_schema: &SchemaRef,
        col_names: &[String],
        batches: &[RecordBatch],
    ) -> Result<(Vec<RecordBatch>, u64), DfOlapError> {
        let mut aligned_batches = Vec::with_capacity(batches.len());
        let mut total_rows = 0u64;
        for batch in batches {
            let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(table_schema.fields().len());
            for field in table_schema.fields() {
                let idx = col_names
                    .iter()
                    .position(|c| c == field.name())
                    .ok_or_else(|| {
                        DfOlapError::SchemaMismatch(format!(
                            "column '{}' not in INSERT column list",
                            field.name()
                        ))
                    })?;
                let col = batch.column(idx);
                let col = if col.data_type() != field.data_type() {
                    arrow::compute::cast(col, field.data_type())?
                } else {
                    col.clone()
                };
                columns.push(col);
            }
            let aligned = RecordBatch::try_new(table_schema.clone(), columns)?;
            total_rows += aligned.num_rows() as u64;
            aligned_batches.push(aligned);
        }
        Ok((aligned_batches, total_rows))
    }

    async fn execute_insert_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, col_names, batches) = parse_insert_values(sql)?;

        let mut tables = self.tables.write().await;
        let table_data = tables
            .get_mut(&table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;

        let table_schema = table_data.schema.clone();
        let (aligned_batches, total_rows) =
            Self::align_batches_to_schema(&table_schema, &col_names, &batches)?;
        table_data.batches.extend(aligned_batches);
        drop(tables);

        self.refresh_table_mem(&table_name).await?;
        Ok(total_rows)
    }

    async fn execute_update_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, assignments, where_clause) = parse_update(sql)?;

        let mut tables = self.tables.write().await;
        let table_data = tables
            .get_mut(&table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;

        let schema = table_data.schema.clone();
        let mut updated_count = 0u64;

        let all_rows = flatten_batches(&table_data.batches, &schema)?;
        if let Some(all_rows) = all_rows {
            let (updated_batch, count) =
                apply_update(&all_rows, &schema, &assignments, &where_clause)?;
            updated_count = count;
            table_data.batches = vec![updated_batch];
        }

        drop(tables);
        self.refresh_table_mem(&table_name).await?;
        Ok(updated_count)
    }

    async fn execute_delete_mem(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, where_clause) = parse_delete(sql)?;

        let mut tables = self.tables.write().await;
        let table_data = tables
            .get_mut(&table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;

        let schema = table_data.schema.clone();
        let all_rows = flatten_batches(&table_data.batches, &schema)?;

        if let Some(all_rows) = all_rows {
            let (filtered_batch, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
            table_data.batches = if filtered_batch.num_rows() > 0 {
                vec![filtered_batch]
            } else {
                vec![]
            };
            drop(tables);
            self.refresh_table_mem(&table_name).await?;
            Ok(deleted_count)
        } else {
            Ok(0)
        }
    }

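    /// INSERT for file-backed tables: appends a new data file rather than
    /// rewriting existing ones.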
    async fn execute_insert_file(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, col_names, batches) = parse_insert_values(sql)?;

        let file_tables = self.file_tables.read().await;
        let meta = file_tables
            .get(&table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
        let table_schema = meta.schema.clone();
        drop(file_tables);

        let (aligned_batches, total_rows) =
            Self::align_batches_to_schema(&table_schema, &col_names, &batches)?;
        self.write_batches_to_file(&table_name, &table_schema, &aligned_batches)
            .await?;
        self.refresh_table_file(&table_name).await?;
        Ok(total_rows)
    }

    async fn execute_update_file(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, assignments, where_clause) = parse_update(sql)?;

        let (schema, existing_batches) = self.read_all_batches(&table_name).await?;
        let all_rows = flatten_batches(&existing_batches, &schema)?;

        if let Some(all_rows) = all_rows {
            let (updated_batch, count) =
                apply_update(&all_rows, &schema, &assignments, &where_clause)?;
            self.clear_table_dir(&table_name).await?;
            if updated_batch.num_rows() > 0 {
                self.write_batches_to_file(&table_name, &schema, &[updated_batch])
                    .await?;
            }
            self.refresh_table_file(&table_name).await?;
            Ok(count)
        } else {
            Ok(0)
        }
    }

    async fn execute_delete_file(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, where_clause) = parse_delete(sql)?;

        let (schema, existing_batches) = self.read_all_batches(&table_name).await?;
        let all_rows = flatten_batches(&existing_batches, &schema)?;

        if let Some(all_rows) = all_rows {
            let (filtered_batch, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
            self.clear_table_dir(&table_name).await?;
            if filtered_batch.num_rows() > 0 {
                self.write_batches_to_file(&table_name, &schema, &[filtered_batch])
                    .await?;
            }
            self.refresh_table_file(&table_name).await?;
            Ok(deleted_count)
        } else {
            Ok(0)
        }
    }

    #[cfg(feature = "cloud-storage")]
    async fn execute_insert_cloud(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, col_names, batches) = parse_insert_values(sql)?;

        let cloud_tables = self.cloud_tables.read().await;
        let meta = cloud_tables
            .get(&table_name)
            .ok_or_else(|| DfOlapError::TableNotFound(table_name.clone()))?;
        let table_schema = meta.schema.clone();
        drop(cloud_tables);

        let (aligned_batches, total_rows) =
            Self::align_batches_to_schema(&table_schema, &col_names, &batches)?;
        self.write_batches_to_cloud(&table_name, &table_schema, &aligned_batches)
            .await?;
        self.refresh_table_cloud(&table_name).await?;
        Ok(total_rows)
    }

    #[cfg(feature = "cloud-storage")]
    async fn execute_update_cloud(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, assignments, where_clause) = parse_update(sql)?;

        let (schema, existing_batches) = self.read_all_batches_cloud(&table_name).await?;
        let all_rows = flatten_batches(&existing_batches, &schema)?;

        if let Some(all_rows) = all_rows {
            let (updated_batch, count) =
                apply_update(&all_rows, &schema, &assignments, &where_clause)?;
            self.clear_cloud_table(&table_name).await?;
            if updated_batch.num_rows() > 0 {
                self.write_batches_to_cloud(&table_name, &schema, &[updated_batch])
                    .await?;
            }
            self.refresh_table_cloud(&table_name).await?;
            Ok(count)
        } else {
            Ok(0)
        }
    }

    #[cfg(feature = "cloud-storage")]
    async fn execute_delete_cloud(&self, sql: &str) -> Result<u64, DfOlapError> {
        let (table_name, where_clause) = parse_delete(sql)?;

        let (schema, existing_batches) = self.read_all_batches_cloud(&table_name).await?;
        let all_rows = flatten_batches(&existing_batches, &schema)?;

        if let Some(all_rows) = all_rows {
            let (filtered_batch, deleted_count) = apply_delete(&all_rows, &schema, &where_clause)?;
            self.clear_cloud_table(&table_name).await?;
            if filtered_batch.num_rows() > 0 {
                self.write_batches_to_cloud(&table_name, &schema, &[filtered_batch])
                    .await?;
            }
            self.refresh_table_cloud(&table_name).await?;
            Ok(deleted_count)
        } else {
            Ok(0)
        }
    }

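    /// Routes INSERT to the backend that matches the storage mode; the
    /// UPDATE and DELETE dispatchers below follow the same pattern.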
    async fn execute_insert(&self, sql: &str) -> Result<u64, DfOlapError> {
        match &self.storage_mode {
            StorageMode::InMemory => self.execute_insert_mem(sql).await,
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                self.execute_insert_cloud(sql).await
            }
            _ => self.execute_insert_file(sql).await,
        }
    }

    async fn execute_update(&self, sql: &str) -> Result<u64, DfOlapError> {
        match &self.storage_mode {
            StorageMode::InMemory => self.execute_update_mem(sql).await,
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                self.execute_update_cloud(sql).await
            }
            _ => self.execute_update_file(sql).await,
        }
    }

    async fn execute_delete(&self, sql: &str) -> Result<u64, DfOlapError> {
        match &self.storage_mode {
            StorageMode::InMemory => self.execute_delete_mem(sql).await,
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                self.execute_delete_cloud(sql).await
            }
            _ => self.execute_delete_file(sql).await,
        }
    }
}

impl Default for DataFusionEngine {
    fn default() -> Self {
        Self::new()
    }
}

impl rhei_core::OlapEngine for DataFusionEngine {
    type Error = DfOlapError;

    async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
        debug!(sql, "DataFusion query");
        self.execute_sql(sql).await
    }

    async fn query_stream(
        &self,
        sql: &str,
    ) -> Result<rhei_core::RecordBatchBoxStream, Self::Error> {
        debug!(sql, "DataFusion query_stream");
        let ctx = self.ctx.read().await;
        let df = ctx.sql(sql).await?;
        let stream = df.execute_stream().await?;
        let mapped = Box::pin(StreamAdapter(stream));
        Ok(mapped)
    }

    async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
        debug!(sql, "DataFusion execute");
        let trimmed = sql.trim();
        let upper = trimmed.to_ascii_uppercase();

        if upper.starts_with("INSERT") {
            self.execute_insert(trimmed).await
        } else if upper.starts_with("UPDATE") {
            self.execute_update(trimmed).await
        } else if upper.starts_with("DELETE") {
            self.execute_delete(trimmed).await
        } else if upper.starts_with("BEGIN")
            || upper.starts_with("COMMIT")
            || upper.starts_with("ROLLBACK")
        {
            Ok(0)
        } else {
            let ctx = self.ctx.read().await;
            let df = ctx.sql(trimmed).await?;
            let _ = df.collect().await?;
            Ok(0)
        }
    }

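    /// Bulk-loads pre-built record batches into a table, appending in-memory
    /// batches or writing a new data file/object depending on the mode.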
    async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
        if batches.is_empty() {
            return Ok(0);
        }

        debug!(table, batch_count = batches.len(), "DataFusion load_arrow");
        rhei_core::validate_identifier(table).map_err(|e| DfOlapError::Other(e.to_string()))?;

        let total_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();

        match &self.storage_mode {
            StorageMode::InMemory => {
                let mut tables = self.tables.write().await;
                let table_data = tables
                    .get_mut(table)
                    .ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;

                for batch in batches {
                    table_data.batches.push(batch.clone());
                }
                drop(tables);
                self.refresh_table_mem(table).await?;
            }
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                let cloud_tables = self.cloud_tables.read().await;
                let meta = cloud_tables
                    .get(table)
                    .ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;
                let schema = meta.schema.clone();
                drop(cloud_tables);

                self.write_batches_to_cloud(table, &schema, batches).await?;
                self.refresh_table_cloud(table).await?;
            }
            _ => {
                let file_tables = self.file_tables.read().await;
                let meta = file_tables
                    .get(table)
                    .ok_or_else(|| DfOlapError::TableNotFound(table.to_string()))?;
                let schema = meta.schema.clone();
                drop(file_tables);

                self.write_batches_to_file(table, &schema, batches).await?;
                self.refresh_table_file(table).await?;
            }
        }

        Ok(total_rows)
    }

    async fn create_table(
        &self,
        table_name: &str,
        schema: &SchemaRef,
        _primary_key: &[String],
    ) -> Result<(), Self::Error> {
        rhei_core::validate_identifier(table_name)
            .map_err(|e| DfOlapError::Other(e.to_string()))?;
        for field in schema.fields() {
            rhei_core::validate_identifier(field.name())
                .map_err(|e| DfOlapError::Other(e.to_string()))?;
        }
        debug!(
            table = table_name,
            storage = ?self.storage_mode,
            "DataFusion create_table"
        );

        match &self.storage_mode {
            StorageMode::InMemory => {
                let mut tables = self.tables.write().await;
                if tables.contains_key(table_name) {
                    return Ok(());
                }
                tables.insert(
                    table_name.to_string(),
                    TableData {
                        schema: schema.clone(),
                        batches: vec![],
                    },
                );
                drop(tables);
                self.refresh_table_mem(table_name).await?;
            }
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { url } | StorageMode::GcsParquet { url } => {
                let mut cloud_tables = self.cloud_tables.write().await;
                if cloud_tables.contains_key(table_name) {
                    return Ok(());
                }
                let table_url = Self::cloud_table_url(url, table_name);
                cloud_tables.insert(
                    table_name.to_string(),
                    CloudTableMeta {
                        schema: schema.clone(),
                        table_url: table_url.clone(),
                    },
                );
                drop(cloud_tables);

                if let Ok(store) = self.cloud_store() {
                    match Self::cloud_seq_for_prefix(&store, &table_url, table_name).await {
                        Ok(next_seq) => {
                            self.file_counter.fetch_max(next_seq, Ordering::Relaxed);
                            if next_seq > 0 {
                                tracing::debug!(
                                    table = table_name,
                                    next_seq,
                                    "cloud restart: advanced file_counter to avoid overwrites"
                                );
                            }
                        }
                        Err(e) => {
                            tracing::warn!(
                                table = table_name,
                                error = %e,
                                "cloud_seq_for_prefix failed; file_counter not advanced"
                            );
                        }
                    }
                }

                self.refresh_table_cloud(table_name).await?;
            }
            _ => {
                let mut file_tables = self.file_tables.write().await;
                if file_tables.contains_key(table_name) {
                    return Ok(());
                }
                let dir = self.table_dir(table_name).expect("file mode has base_path");
                tokio::fs::create_dir_all(&dir).await?;
                file_tables.insert(
                    table_name.to_string(),
                    FileTableMeta {
                        schema: schema.clone(),
                        dir,
                    },
                );
                drop(file_tables);
                self.refresh_table_file(table_name).await?;
            }
        }

        Ok(())
    }

    async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
        match &self.storage_mode {
            StorageMode::InMemory => {
                let tables = self.tables.read().await;
                Ok(tables.contains_key(table_name))
            }
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                let cloud_tables = self.cloud_tables.read().await;
                Ok(cloud_tables.contains_key(table_name))
            }
            _ => {
                let file_tables = self.file_tables.read().await;
                Ok(file_tables.contains_key(table_name))
            }
        }
    }

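    /// Adds a nullable column. In-memory batches are extended with a null
    /// column in place; file- and cloud-backed tables are read back,
    /// rewritten with the new schema, and re-registered.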
    async fn add_column(
        &self,
        table_name: &str,
        column_name: &str,
        data_type: &DataType,
    ) -> Result<(), Self::Error> {
        rhei_core::validate_identifier(table_name)
            .map_err(|e| DfOlapError::Other(e.to_string()))?;
        rhei_core::validate_identifier(column_name)
            .map_err(|e| DfOlapError::Other(e.to_string()))?;

        debug!(
            table = table_name,
            column = column_name,
            "DataFusion add_column"
        );

        match &self.storage_mode {
            StorageMode::InMemory => {
                let mut tables = self.tables.write().await;
                let table_data = tables
                    .get_mut(table_name)
                    .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;

                let new_schema = append_field(&table_data.schema, column_name, data_type);
                let new_batches =
                    extend_batches_with_null_column(&table_data.batches, &new_schema, data_type)?;
                table_data.schema = new_schema;
                table_data.batches = new_batches;
                drop(tables);
                self.refresh_table_mem(table_name).await?;
            }
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                let (old_schema, existing_batches) =
                    self.read_all_batches_cloud(table_name).await?;
                let new_schema = append_field(&old_schema, column_name, data_type);
                let new_batches =
                    extend_batches_with_null_column(&existing_batches, &new_schema, data_type)?;

                self.clear_cloud_table(table_name).await?;
                if !new_batches.is_empty() {
                    self.write_batches_to_cloud(table_name, &new_schema, &new_batches)
                        .await?;
                }

                let mut cloud_tables = self.cloud_tables.write().await;
                if let Some(meta) = cloud_tables.get_mut(table_name) {
                    meta.schema = new_schema;
                }
                drop(cloud_tables);
                self.refresh_table_cloud(table_name).await?;
            }
            _ => {
                let (old_schema, existing_batches) = self.read_all_batches(table_name).await?;
                let new_schema = append_field(&old_schema, column_name, data_type);
                let new_batches =
                    extend_batches_with_null_column(&existing_batches, &new_schema, data_type)?;

                self.clear_table_dir(table_name).await?;
                if !new_batches.is_empty() {
                    self.write_batches_to_file(table_name, &new_schema, &new_batches)
                        .await?;
                }

                let mut file_tables = self.file_tables.write().await;
                if let Some(meta) = file_tables.get_mut(table_name) {
                    meta.schema = new_schema;
                }
                drop(file_tables);
                self.refresh_table_file(table_name).await?;
            }
        }

        Ok(())
    }

    async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
        rhei_core::validate_identifier(table_name)
            .map_err(|e| DfOlapError::Other(e.to_string()))?;
        rhei_core::validate_identifier(column_name)
            .map_err(|e| DfOlapError::Other(e.to_string()))?;

        debug!(
            table = table_name,
            column = column_name,
            "DataFusion drop_column"
        );

        match &self.storage_mode {
            StorageMode::InMemory => {
                let mut tables = self.tables.write().await;
                let table_data = tables
                    .get_mut(table_name)
                    .ok_or_else(|| DfOlapError::TableNotFound(table_name.to_string()))?;

                let col_idx = find_column_index(&table_data.schema, column_name, table_name)?;
                let new_schema = remove_field(&table_data.schema, col_idx);
                let new_batches =
                    remove_column_from_batches(&table_data.batches, &new_schema, col_idx)?;
                table_data.schema = new_schema;
                table_data.batches = new_batches;
                drop(tables);
                self.refresh_table_mem(table_name).await?;
            }
            #[cfg(feature = "cloud-storage")]
            StorageMode::S3Parquet { .. } | StorageMode::GcsParquet { .. } => {
                let (old_schema, existing_batches) =
                    self.read_all_batches_cloud(table_name).await?;
                let col_idx = find_column_index(&old_schema, column_name, table_name)?;
                let new_schema = remove_field(&old_schema, col_idx);
                let new_batches =
                    remove_column_from_batches(&existing_batches, &new_schema, col_idx)?;

                self.clear_cloud_table(table_name).await?;
                if !new_batches.is_empty() {
                    self.write_batches_to_cloud(table_name, &new_schema, &new_batches)
                        .await?;
                }

                let mut cloud_tables = self.cloud_tables.write().await;
                if let Some(meta) = cloud_tables.get_mut(table_name) {
                    meta.schema = new_schema;
                }
                drop(cloud_tables);
                self.refresh_table_cloud(table_name).await?;
            }
            _ => {
                let (old_schema, existing_batches) = self.read_all_batches(table_name).await?;
                let col_idx = find_column_index(&old_schema, column_name, table_name)?;
                let new_schema = remove_field(&old_schema, col_idx);
                let new_batches =
                    remove_column_from_batches(&existing_batches, &new_schema, col_idx)?;

                self.clear_table_dir(table_name).await?;
                if !new_batches.is_empty() {
                    self.write_batches_to_file(table_name, &new_schema, &new_batches)
                        .await?;
                }

                let mut file_tables = self.file_tables.write().await;
                if let Some(meta) = file_tables.get_mut(table_name) {
                    meta.schema = new_schema;
                }
                drop(file_tables);
                self.refresh_table_file(table_name).await?;
            }
        }

        Ok(())
    }
}

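/// Returns a copy of `schema` with a new nullable field appended.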
fn append_field(schema: &SchemaRef, column_name: &str, data_type: &DataType) -> SchemaRef {
    let mut fields: Vec<arrow::datatypes::Field> =
        schema.fields().iter().map(|f| f.as_ref().clone()).collect();
    fields.push(arrow::datatypes::Field::new(
        column_name,
        data_type.clone(),
        true,
    ));
    Arc::new(arrow::datatypes::Schema::new(fields))
}

fn remove_field(schema: &SchemaRef, col_idx: usize) -> SchemaRef {
    let fields: Vec<arrow::datatypes::Field> = schema
        .fields()
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != col_idx)
        .map(|(_, f)| f.as_ref().clone())
        .collect();
    Arc::new(arrow::datatypes::Schema::new(fields))
}

fn find_column_index(
    schema: &SchemaRef,
    column_name: &str,
    table_name: &str,
) -> Result<usize, DfOlapError> {
    schema
        .fields()
        .iter()
        .position(|f| f.name() == column_name)
        .ok_or_else(|| {
            DfOlapError::Other(format!(
                "column '{}' not found in table '{}'",
                column_name, table_name
            ))
        })
}

fn extend_batches_with_null_column(
    batches: &[RecordBatch],
    new_schema: &SchemaRef,
    data_type: &DataType,
) -> Result<Vec<RecordBatch>, DfOlapError> {
    let mut new_batches = Vec::with_capacity(batches.len());
    for batch in batches {
        let null_array = arrow::array::new_null_array(data_type, batch.num_rows());
        let mut columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
            .map(|i| batch.column(i).clone())
            .collect();
        columns.push(null_array);
        new_batches.push(RecordBatch::try_new(new_schema.clone(), columns)?);
    }
    Ok(new_batches)
}

fn remove_column_from_batches(
    batches: &[RecordBatch],
    new_schema: &SchemaRef,
    col_idx: usize,
) -> Result<Vec<RecordBatch>, DfOlapError> {
    let mut new_batches = Vec::with_capacity(batches.len());
    for batch in batches {
        let columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
            .filter(|i| *i != col_idx)
            .map(|i| batch.column(i).clone())
            .collect();
        new_batches.push(RecordBatch::try_new(new_schema.clone(), columns)?);
    }
    Ok(new_batches)
}

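/// Adapts DataFusion's `SendableRecordBatchStream` to the boxed-error item
/// type expected by `rhei_core::RecordBatchBoxStream`.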
struct StreamAdapter(datafusion::physical_plan::SendableRecordBatchStream);

impl futures_core::Stream for StreamAdapter {
    type Item = Result<RecordBatch, Box<dyn std::error::Error + Send + Sync>>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        std::pin::Pin::new(&mut self.0).poll_next(cx).map(|opt| {
            opt.map(|r| r.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>))
        })
    }
}

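/// Cheaply cloneable handle that shares one `DataFusionEngine` behind an
/// `Arc` and forwards the `OlapEngine` API to it.
///
/// ```ignore
/// let shared = SharedDataFusionEngine::new(DataFusionEngine::new());
/// let other = shared.clone(); // both handles hit the same engine
/// ```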
#[derive(Clone)]
pub struct SharedDataFusionEngine(pub Arc<DataFusionEngine>);

impl SharedDataFusionEngine {
    pub fn new(engine: DataFusionEngine) -> Self {
        Self(Arc::new(engine))
    }
}

impl std::ops::Deref for SharedDataFusionEngine {
    type Target = DataFusionEngine;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl rhei_core::OlapEngine for SharedDataFusionEngine {
    type Error = DfOlapError;

    async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
        self.0.query(sql).await
    }

    async fn query_stream(
        &self,
        sql: &str,
    ) -> Result<rhei_core::RecordBatchBoxStream, Self::Error> {
        self.0.query_stream(sql).await
    }

    async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
        self.0.execute(sql).await
    }

    async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
        self.0.load_arrow(table, batches).await
    }

    async fn create_table(
        &self,
        table_name: &str,
        schema: &SchemaRef,
        primary_key: &[String],
    ) -> Result<(), Self::Error> {
        self.0.create_table(table_name, schema, primary_key).await
    }

    async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
        self.0.table_exists(table_name).await
    }

    async fn add_column(
        &self,
        table_name: &str,
        column_name: &str,
        data_type: &DataType,
    ) -> Result<(), Self::Error> {
        self.0.add_column(table_name, column_name, data_type).await
    }

    async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
        self.0.drop_column(table_name, column_name).await
    }
}

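/// Renders a parsed literal expression back into a SQL literal string:
/// numbers as-is, strings re-quoted with doubled single quotes, booleans as
/// TRUE/FALSE, NULL, and negated numbers. Anything else is rejected.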
fn expr_to_sql_literal(expr: &Expr) -> Result<String, DfOlapError> {
    match expr {
        Expr::Value(v) => match &v.value {
            Value::Number(n, _) => Ok(n.clone()),
            Value::SingleQuotedString(s) => Ok(format!("'{}'", s.replace('\'', "''"))),
            Value::Boolean(b) => Ok(if *b { "TRUE".into() } else { "FALSE".into() }),
            Value::Null => Ok("NULL".into()),
            other => Err(DfOlapError::Other(format!(
                "unsupported value literal: {other:?}"
            ))),
        },
        Expr::UnaryOp {
            op: UnaryOperator::Minus,
            expr: inner,
        } => {
            if let Expr::Value(v) = inner.as_ref() {
                if let Value::Number(n, _) = &v.value {
                    return Ok(format!("-{n}"));
                }
            }
            Err(DfOlapError::Other(format!(
                "unsupported unary expression: {expr}"
            )))
        }
        other => Err(DfOlapError::Other(format!(
            "unsupported expression in VALUES: {other}"
        ))),
    }
}

fn ident_from_expr(expr: &Expr) -> Result<String, DfOlapError> {
    match expr {
        Expr::Identifier(ident) => Ok(ident.value.clone()),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|i| i.value.clone())
            .ok_or_else(|| DfOlapError::Other("empty compound identifier".into())),
        other => Err(DfOlapError::Other(format!(
            "expected column name, got: {other}"
        ))),
    }
}

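/// Flattens a WHERE expression into `(column, literal)` pairs, e.g.
/// `a = 1 AND b IS NULL` becomes `[("a", "1"), ("b", "NULL")]`. Only
/// AND-connected equality, `IS NULL` (encoded as "NULL"), and `IS NOT NULL`
/// (encoded as the "__IS_NOT_NULL__" sentinel) predicates are supported.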
fn extract_where_conditions(expr: &Expr) -> Result<Vec<(String, String)>, DfOlapError> {
    match expr {
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            let mut conditions = extract_where_conditions(left)?;
            conditions.extend(extract_where_conditions(right)?);
            Ok(conditions)
        }
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Eq,
            right,
        } => {
            let col = ident_from_expr(left)?;
            let val = expr_to_sql_literal(right)?;
            Ok(vec![(col, val)])
        }
        Expr::IsNull(inner) => {
            let col = ident_from_expr(inner)?;
            Ok(vec![(col, "NULL".into())])
        }
        Expr::IsNotNull(inner) => {
            let col = ident_from_expr(inner)?;
            Ok(vec![(col, "__IS_NOT_NULL__".into())])
        }
        Expr::Nested(inner) => extract_where_conditions(inner),
        other => Err(DfOlapError::Other(format!(
            "unsupported WHERE expression: {other}"
        ))),
    }
}

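/// Parses `INSERT INTO t (cols...) VALUES (...)` and converts the VALUES
/// rows into a `RecordBatch`, returning the table name, the column list,
/// and the batch (empty when there is no source or no rows).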
fn parse_insert_values(sql: &str) -> Result<(String, Vec<String>, Vec<RecordBatch>), DfOlapError> {
    let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
        .map_err(|e| DfOlapError::Other(format!("failed to parse INSERT: {e}")))?;

    let stmt = stmts
        .pop()
        .ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;

    let insert = match stmt {
        Statement::Insert(ins) => ins,
        other => {
            return Err(DfOlapError::Other(format!(
                "expected INSERT statement, got: {other:?}"
            )));
        }
    };

    let table_name = match &insert.table {
        TableObject::TableName(obj_name) => obj_name
            .0
            .last()
            .and_then(|p| p.as_ident())
            .map(|id| id.value.clone())
            .ok_or_else(|| DfOlapError::Other("empty table name in INSERT".into()))?,
        TableObject::TableFunction(_) => {
            return Err(DfOlapError::Other(
                "INSERT INTO TABLE FUNCTION not supported".into(),
            ));
        }
    };

    rhei_core::validate_identifier(&table_name).map_err(|e| DfOlapError::Other(e.to_string()))?;

    let col_name_strings: Vec<String> = insert.columns.iter().map(|id| id.value.clone()).collect();

    let source = match insert.source {
        Some(q) => q,
        None => return Ok((table_name, col_name_strings, vec![])),
    };

    let values = match *source.body {
        SetExpr::Values(v) => v,
        other => {
            return Err(DfOlapError::Other(format!(
                "INSERT source is not a VALUES clause: {other:?}"
            )));
        }
    };

    if values.rows.is_empty() {
        return Ok((table_name, col_name_strings, vec![]));
    }

    let rows: Vec<Vec<String>> = values
        .rows
        .iter()
        .map(|row| {
            row.iter()
                .map(expr_to_sql_literal)
                .collect::<Result<_, _>>()
        })
        .collect::<Result<_, _>>()?;

    let col_name_refs: Vec<&str> = col_name_strings.iter().map(|s| s.as_str()).collect();
    let num_cols = col_name_refs.len();

    if num_cols == 0 {
        return Err(DfOlapError::Other(format!(
            "INSERT INTO {table_name} requires an explicit column list; `VALUES (...)` without columns is not supported"
        )));
    }

    let batch = build_record_batch_from_values(&col_name_refs, &rows, num_cols)?;
    Ok((table_name, col_name_strings, vec![batch]))
}

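/// Builds a `RecordBatch` from stringified VALUES rows. Each column's type
/// is inferred from its first non-NULL value (Boolean, Utf8, Float64, or
/// Int64, defaulting to Utf8); alignment to the real table schema happens
/// later via `align_batches_to_schema`.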
fn build_record_batch_from_values(
    col_names: &[&str],
    rows: &[Vec<String>],
    num_cols: usize,
) -> Result<RecordBatch, DfOlapError> {
    use arrow::array::*;
    use arrow::datatypes::{Field, Schema};

    let mut types = vec![DataType::Utf8; num_cols];
    for col_idx in 0..num_cols {
        for row in rows {
            if col_idx < row.len() {
                let val = &row[col_idx];
                let upper = val.to_ascii_uppercase();
                if upper == "NULL" {
                    continue;
                }
                if upper == "TRUE" || upper == "FALSE" {
                    types[col_idx] = DataType::Boolean;
                    break;
                }
                if val.starts_with('\'') {
                    types[col_idx] = DataType::Utf8;
                    break;
                }
                if val.contains('.') {
                    if val.parse::<f64>().is_ok() {
                        types[col_idx] = DataType::Float64;
                        break;
                    }
                } else if val.parse::<i64>().is_ok() {
                    types[col_idx] = DataType::Int64;
                    break;
                }
                break;
            }
        }
    }

    let fields: Vec<Field> = col_names
        .iter()
        .zip(types.iter())
        .map(|(name, dt)| Field::new(*name, dt.clone(), true))
        .collect();
    let schema = Arc::new(Schema::new(fields));

    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(num_cols);
    for col_idx in 0..num_cols {
        let col_values: Vec<&str> = rows
            .iter()
            .map(|row| {
                if col_idx < row.len() {
                    row[col_idx].as_str()
                } else {
                    "NULL"
                }
            })
            .collect();

        columns.push(build_array(&types[col_idx], &col_values)?);
    }

    let batch = RecordBatch::try_new(schema, columns)?;
    Ok(batch)
}

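/// Builds a typed Arrow array from literal strings, mapping "NULL" to null
/// values and stripping/unescaping single quotes in the Utf8 fallback.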
fn build_array(dt: &DataType, values: &[&str]) -> Result<Arc<dyn Array>, DfOlapError> {
    use arrow::array::*;

    match dt {
        DataType::Int64 => {
            let mut builder = Int64Builder::new();
            for v in values {
                if v.eq_ignore_ascii_case("NULL") {
                    builder.append_null();
                } else {
                    builder.append_value(
                        v.parse::<i64>()
                            .map_err(|e| DfOlapError::Other(format!("parse i64: {e}")))?,
                    );
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Float64 => {
            let mut builder = Float64Builder::new();
            for v in values {
                if v.eq_ignore_ascii_case("NULL") {
                    builder.append_null();
                } else {
                    builder.append_value(
                        v.parse::<f64>()
                            .map_err(|e| DfOlapError::Other(format!("parse f64: {e}")))?,
                    );
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Boolean => {
            let mut builder = BooleanBuilder::new();
            for v in values {
                let upper = v.to_ascii_uppercase();
                if upper == "NULL" {
                    builder.append_null();
                } else {
                    builder.append_value(upper == "TRUE");
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        _ => {
            let mut builder = StringBuilder::new();
            for v in values {
                if v.eq_ignore_ascii_case("NULL") {
                    builder.append_null();
                } else {
                    let stripped = if v.starts_with('\'') && v.ends_with('\'') && v.len() >= 2 {
                        &v[1..v.len() - 1]
                    } else {
                        v
                    };
                    builder.append_value(stripped.replace("''", "'"));
                }
            }
            Ok(Arc::new(builder.finish()))
        }
    }
}

type ColVal = (String, String);

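/// Parses an UPDATE statement into the table name, the SET assignments, and
/// the flattened WHERE conditions.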
fn parse_update(sql: &str) -> Result<(String, Vec<ColVal>, Vec<ColVal>), DfOlapError> {
    let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
        .map_err(|e| DfOlapError::Other(format!("failed to parse UPDATE: {e}")))?;

    let stmt = stmts
        .pop()
        .ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;

    let update = match stmt {
        Statement::Update(upd) => upd,
        other => {
            return Err(DfOlapError::Other(format!(
                "expected UPDATE statement, got: {other:?}"
            )));
        }
    };

    let table_name = match &update.table.relation {
        TableFactor::Table { name, .. } => name
            .0
            .last()
            .and_then(|p| p.as_ident())
            .map(|id| id.value.clone())
            .ok_or_else(|| DfOlapError::Other("empty table name in UPDATE".into()))?,
        other => {
            return Err(DfOlapError::Other(format!(
                "unexpected table factor in UPDATE: {other:?}"
            )));
        }
    };

    let assignments: Vec<ColVal> = update
        .assignments
        .iter()
        .map(|a| {
            let col = match &a.target {
                AssignmentTarget::ColumnName(obj) => obj
                    .0
                    .last()
                    .and_then(|p| p.as_ident())
                    .map(|id| id.value.clone())
                    .ok_or_else(|| DfOlapError::Other("empty column name in SET".into()))?,
                AssignmentTarget::Tuple(_) => {
                    return Err(DfOlapError::Other(
                        "tuple assignments in SET not supported".into(),
                    ));
                }
            };
            let val = expr_to_sql_literal(&a.value)?;
            Ok((col, val))
        })
        .collect::<Result<_, DfOlapError>>()?;

    let where_clause = match &update.selection {
        Some(expr) => extract_where_conditions(expr)?,
        None => vec![],
    };

    Ok((table_name, assignments, where_clause))
}

fn parse_delete(sql: &str) -> Result<(String, Vec<(String, String)>), DfOlapError> {
    let mut stmts = Parser::parse_sql(&SQLiteDialect {}, sql)
        .map_err(|e| DfOlapError::Other(format!("failed to parse DELETE: {e}")))?;

    let stmt = stmts
        .pop()
        .ok_or_else(|| DfOlapError::Other("empty SQL statement".into()))?;

    let delete = match stmt {
        Statement::Delete(del) => del,
        other => {
            return Err(DfOlapError::Other(format!(
                "expected DELETE statement, got: {other:?}"
            )));
        }
    };

    let tables = match &delete.from {
        FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => tables,
    };

    let table_name = tables
        .first()
        .and_then(|twj| {
            if let TableFactor::Table { name, .. } = &twj.relation {
                name.0
                    .last()
                    .and_then(|p| p.as_ident())
                    .map(|id| id.value.clone())
            } else {
                None
            }
        })
        .ok_or_else(|| DfOlapError::Other("missing table name in DELETE".into()))?;

    let where_clause = match &delete.selection {
        Some(expr) => extract_where_conditions(expr)?,
        None => vec![],
    };

    Ok((table_name, where_clause))
}

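/// Concatenates batches into a single `RecordBatch`, or returns `None` when
/// there are no batches or no rows.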
fn flatten_batches(
    batches: &[RecordBatch],
    schema: &SchemaRef,
) -> Result<Option<RecordBatch>, DfOlapError> {
    if batches.is_empty() {
        return Ok(None);
    }

    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
    if total_rows == 0 {
        return Ok(None);
    }

    let batch = arrow::compute::concat_batches(schema, batches)?;
    Ok(Some(batch))
}

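/// Applies the SET assignments to every row matching the WHERE conditions,
/// returning the rewritten batch and the number of rows updated.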
2018fn apply_update(
2020 batch: &RecordBatch,
2021 schema: &SchemaRef,
2022 assignments: &[(String, String)],
2023 where_conditions: &[(String, String)],
2024) -> Result<(RecordBatch, u64), DfOlapError> {
2025 let matching = find_matching_rows(batch, schema, where_conditions)?;
2026 let updated_count = matching.iter().filter(|&&m| m).count() as u64;
2027
2028 let mut new_columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
2030 for (col_idx, field) in schema.fields().iter().enumerate() {
2031 let assignment = assignments.iter().find(|(col, _)| col == field.name());
2033
2034 if let Some((_, new_val)) = assignment {
2035 let original = batch.column(col_idx);
2037 new_columns.push(apply_value_to_matching(
2038 original,
2039 &matching,
2040 new_val,
2041 field.data_type(),
2042 )?);
2043 } else {
2044 new_columns.push(batch.column(col_idx).clone());
2045 }
2046 }
2047
2048 let new_batch = RecordBatch::try_new(schema.clone(), new_columns)?;
2049 Ok((new_batch, updated_count))
2050}
2051
2052fn apply_delete(
2054 batch: &RecordBatch,
2055 schema: &SchemaRef,
2056 where_conditions: &[(String, String)],
2057) -> Result<(RecordBatch, u64), DfOlapError> {
2058 let matching = find_matching_rows(batch, schema, where_conditions)?;
2059 let deleted_count = matching.iter().filter(|&&m| m).count() as u64;
2060
2061 let mut builder = BooleanBuilder::new();
2063 for &m in &matching {
2064 builder.append_value(!m);
2065 }
2066 let filter_array = builder.finish();
2067
2068 let new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
2069 .map(|i| arrow::compute::filter(batch.column(i), &filter_array).map_err(DfOlapError::Arrow))
2070 .collect::<Result<_, _>>()?;
2071
2072 let new_batch = RecordBatch::try_new(schema.clone(), new_columns)?;
2073 Ok((new_batch, deleted_count))
2074}
2075
fn find_matching_rows(
    batch: &RecordBatch,
    schema: &SchemaRef,
    conditions: &[(String, String)],
) -> Result<Vec<bool>, DfOlapError> {
    let num_rows = batch.num_rows();
    let mut matching = vec![true; num_rows];

    for (col_name, expected_val) in conditions {
        let col_idx = schema
            .fields()
            .iter()
            .position(|f| f.name() == col_name)
            .ok_or_else(|| DfOlapError::Other(format!("column not found: {col_name}")))?;

        let col = batch.column(col_idx);
        for (row_idx, m) in matching.iter_mut().enumerate() {
            if !*m {
                continue;
            }
            *m = value_matches(col, row_idx, expected_val);
        }
    }

    Ok(matching)
}

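/// Compare the value at `row_idx` of `array` with `expected`, the literal
/// text captured from the WHERE clause. The `__IS_NOT_NULL__` sentinel checks
/// for non-null values, and the literal `NULL` matches SQL NULLs.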
fn value_matches(array: &dyn Array, row_idx: usize, expected: &str) -> bool {
    // Sentinel emitted for `IS NOT NULL` conditions.
    if expected == "__IS_NOT_NULL__" {
        return !array.is_null(row_idx);
    }
    if array.is_null(row_idx) {
        return expected.eq_ignore_ascii_case("NULL");
    }

    match array.data_type() {
        DataType::Int8 => {
            expected.parse::<i8>().ok() == Some(array.as_primitive::<Int8Type>().value(row_idx))
        }
        DataType::Int16 => {
            expected.parse::<i16>().ok() == Some(array.as_primitive::<Int16Type>().value(row_idx))
        }
        DataType::Int32 => {
            expected.parse::<i32>().ok() == Some(array.as_primitive::<Int32Type>().value(row_idx))
        }
        DataType::Int64 => {
            expected.parse::<i64>().ok() == Some(array.as_primitive::<Int64Type>().value(row_idx))
        }
        DataType::UInt8 => {
            expected.parse::<u8>().ok() == Some(array.as_primitive::<UInt8Type>().value(row_idx))
        }
        DataType::UInt16 => {
            expected.parse::<u16>().ok() == Some(array.as_primitive::<UInt16Type>().value(row_idx))
        }
        DataType::UInt32 => {
            expected.parse::<u32>().ok() == Some(array.as_primitive::<UInt32Type>().value(row_idx))
        }
        DataType::UInt64 => {
            expected.parse::<u64>().ok() == Some(array.as_primitive::<UInt64Type>().value(row_idx))
        }
        DataType::Float32 => {
            expected.parse::<f32>().ok() == Some(array.as_primitive::<Float32Type>().value(row_idx))
        }
        DataType::Float64 => {
            expected.parse::<f64>().ok() == Some(array.as_primitive::<Float64Type>().value(row_idx))
        }
        DataType::Utf8 => {
            let arr = array.as_string::<i32>();
            // Strip the surrounding quotes from the SQL literal and undo the
            // doubled-quote escape, so `'O''Brien'` matches the stored value
            // `O'Brien` (mirrors the unescaping done in
            // `apply_value_to_matching`).
            let stripped =
                if expected.starts_with('\'') && expected.ends_with('\'') && expected.len() >= 2 {
                    &expected[1..expected.len() - 1]
                } else {
                    expected
                };
            arr.value(row_idx) == stripped.replace("''", "'")
        }
        DataType::Boolean => {
            let arr = array.as_boolean();
            match expected.to_ascii_uppercase().as_str() {
                "TRUE" => arr.value(row_idx),
                "FALSE" => !arr.value(row_idx),
                _ => false,
            }
        }
        // Comparisons against unsupported types never match.
        _ => false,
    }
}

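/// Build a replacement column: rows flagged in `matching` receive `new_val`
/// parsed as the column's type; all other rows keep their original values
/// (including nulls).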
fn apply_value_to_matching(
    original: &dyn Array,
    matching: &[bool],
    new_val: &str,
    dt: &DataType,
) -> Result<Arc<dyn Array>, DfOlapError> {
    use arrow::array::*;

    match dt {
        DataType::Int64 => {
            let orig = original.as_primitive::<Int64Type>();
            let parsed: i64 = new_val
                .parse()
                .map_err(|e| DfOlapError::Other(format!("parse i64: {e}")))?;
            let mut builder = Int64Builder::new();
            for (i, &m) in matching.iter().enumerate() {
                if m {
                    builder.append_value(parsed);
                } else if orig.is_null(i) {
                    builder.append_null();
                } else {
                    builder.append_value(orig.value(i));
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Float64 => {
            let orig = original.as_primitive::<Float64Type>();
            let parsed: f64 = new_val
                .parse()
                .map_err(|e| DfOlapError::Other(format!("parse f64: {e}")))?;
            let mut builder = Float64Builder::new();
            for (i, &m) in matching.iter().enumerate() {
                if m {
                    builder.append_value(parsed);
                } else if orig.is_null(i) {
                    builder.append_null();
                } else {
                    builder.append_value(orig.value(i));
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Utf8 => {
            let orig = original.as_string::<i32>();
            // Strip the surrounding quotes from the SQL literal and undo the
            // doubled-quote escape.
            let stripped =
                if new_val.starts_with('\'') && new_val.ends_with('\'') && new_val.len() >= 2 {
                    &new_val[1..new_val.len() - 1]
                } else {
                    new_val
                };
            let unescaped = stripped.replace("''", "'");
            let mut builder = StringBuilder::new();
            for (i, &m) in matching.iter().enumerate() {
                if m {
                    builder.append_value(&unescaped);
                } else if orig.is_null(i) {
                    builder.append_null();
                } else {
                    builder.append_value(orig.value(i));
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Boolean => {
            let orig = original.as_boolean();
            let parsed = new_val.eq_ignore_ascii_case("TRUE");
            let mut builder = BooleanBuilder::new();
            for (i, &m) in matching.iter().enumerate() {
                if m {
                    builder.append_value(parsed);
                } else if orig.is_null(i) {
                    builder.append_null();
                } else {
                    builder.append_value(orig.value(i));
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        // Downcasting any other array type as a string column would panic at
        // runtime, so reject unsupported types with an error instead.
        other => Err(DfOlapError::Other(format!(
            "UPDATE is not supported for column type {other:?}"
        ))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use rhei_core::OlapEngine;

    fn users_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]))
    }

    // Engine constructors for each storage mode; the in-memory engine has no
    // on-disk state and ignores the temp dir.
    fn make_in_memory(_: &std::path::Path) -> DataFusionEngine {
        DataFusionEngine::new()
    }

    fn make_arrow_ipc(tmp: &std::path::Path) -> DataFusionEngine {
        DataFusionEngine::with_storage(StorageMode::ArrowIpc {
            path: tmp.join("arrow_olap"),
        })
        .unwrap()
    }

    fn make_parquet(tmp: &std::path::Path) -> DataFusionEngine {
        DataFusionEngine::with_storage(StorageMode::Parquet {
            path: tmp.join("parquet_olap"),
        })
        .unwrap()
    }

    // Generates an identical test suite for each storage backend so every
    // mode is exercised the same way.
    macro_rules! storage_mode_tests {
        ($mod_name:ident, $make_engine:ident) => {
            mod $mod_name {
                use super::*;

                #[tokio::test]
                async fn create_and_query_empty() {
                    let _tmp = tempfile::tempdir().unwrap();
                    let engine = $make_engine(_tmp.path());
                    let schema = users_schema();
                    engine.create_table("users", &schema, &[]).await.unwrap();

                    assert!(engine.table_exists("users").await.unwrap());
                    assert!(!engine.table_exists("nonexistent").await.unwrap());
                }

                #[tokio::test]
                async fn insert_and_query() {
                    let _tmp = tempfile::tempdir().unwrap();
                    let engine = $make_engine(_tmp.path());
                    let schema = users_schema();
                    engine.create_table("users", &schema, &[]).await.unwrap();

                    engine
                        .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
                        .await
                        .unwrap();
                    engine
                        .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
                        .await
                        .unwrap();

                    let batches = engine
                        .query("SELECT * FROM users ORDER BY id")
                        .await
                        .unwrap();
                    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
                    assert_eq!(total_rows, 2);
                }

                #[tokio::test]
                async fn update() {
                    let _tmp = tempfile::tempdir().unwrap();
                    let engine = $make_engine(_tmp.path());
                    let schema = users_schema();
                    engine.create_table("users", &schema, &[]).await.unwrap();

                    engine
                        .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
                        .await
                        .unwrap();

                    let rows = engine
                        .execute("UPDATE users SET age = 31 WHERE id = 1")
                        .await
                        .unwrap();
                    assert_eq!(rows, 1);

                    let batches = engine
                        .query("SELECT age FROM users WHERE id = 1")
                        .await
                        .unwrap();
                    let age = batches[0].column(0).as_primitive::<Int64Type>().value(0);
                    assert_eq!(age, 31);
                }

                #[tokio::test]
                async fn delete() {
                    let _tmp = tempfile::tempdir().unwrap();
                    let engine = $make_engine(_tmp.path());
                    let schema = users_schema();
                    engine.create_table("users", &schema, &[]).await.unwrap();

                    engine
                        .execute(
                            "INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)",
                        )
                        .await
                        .unwrap();

                    let rows = engine
                        .execute("DELETE FROM users WHERE id = 1")
                        .await
                        .unwrap();
                    assert_eq!(rows, 1);

                    let batches = engine.query("SELECT COUNT(*) FROM users").await.unwrap();
                    let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
                    assert_eq!(count, 1);
                }

                #[tokio::test]
                async fn load_arrow() {
                    let _tmp = tempfile::tempdir().unwrap();
                    let engine = $make_engine(_tmp.path());
                    let schema = users_schema();
                    engine.create_table("users", &schema, &[]).await.unwrap();

                    let batch = RecordBatch::try_new(
                        schema.clone(),
                        vec![
                            Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
                            Arc::new(arrow::array::StringArray::from(vec![
                                "Alice", "Bob", "Charlie",
                            ])),
                            Arc::new(arrow::array::Int64Array::from(vec![30, 25, 35])),
                        ],
                    )
                    .unwrap();

                    let loaded = engine.load_arrow("users", &[batch]).await.unwrap();
                    assert_eq!(loaded, 3);

                    let batches = engine
                        .query("SELECT COUNT(*) as cnt FROM users")
                        .await
                        .unwrap();
                    let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
                    assert_eq!(count, 3);
                }

                #[tokio::test]
                async fn aggregate() {
                    let _tmp = tempfile::tempdir().unwrap();
                    let engine = $make_engine(_tmp.path());
                    let schema = users_schema();
                    engine.create_table("users", &schema, &[]).await.unwrap();

                    engine
                        .execute(
                            "INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25), (3, 'Charlie', 35)",
                        )
                        .await
                        .unwrap();

                    let batches = engine
                        .query("SELECT AVG(age) as avg_age FROM users")
                        .await
                        .unwrap();
                    let avg = batches[0].column(0).as_primitive::<Float64Type>().value(0);
                    assert!((avg - 30.0).abs() < 0.01);
                }
            }
        };
    }

    storage_mode_tests!(in_memory, make_in_memory);
    storage_mode_tests!(arrow_ipc, make_arrow_ipc);
    storage_mode_tests!(parquet, make_parquet);

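    // A comma inside a string literal must not be treated as a value
    // separator by the INSERT parser.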
    #[tokio::test]
    async fn insert_string_with_comma() {
        let engine = DataFusionEngine::new();
        let schema = users_schema();
        engine.create_table("users", &schema, &[]).await.unwrap();

        engine
            .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice, B', 30)")
            .await
            .unwrap();

        let batches = engine
            .query("SELECT name FROM users WHERE id = 1")
            .await
            .unwrap();
        let name_arr = batches[0].column(0).as_string::<i32>();
        assert_eq!(name_arr.value(0), "Alice, B");
    }

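    // A NULL literal should produce an Arrow null, not the string "NULL".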
    #[tokio::test]
    async fn insert_null_value() {
        let engine = DataFusionEngine::new();
        let schema = users_schema();
        engine.create_table("users", &schema, &[]).await.unwrap();

        engine
            .execute("INSERT INTO users (id, name, age) VALUES (1, NULL, 30)")
            .await
            .unwrap();

        let batches = engine
            .query("SELECT name FROM users WHERE id = 1")
            .await
            .unwrap();
        assert!(batches[0].column(0).is_null(0));
    }

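    // An ANDed WHERE clause must only update rows matching all conditions.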
    #[tokio::test]
    async fn update_where_and() {
        let engine = DataFusionEngine::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("status", DataType::Utf8, true),
        ]));
        engine.create_table("t", &schema, &[]).await.unwrap();

        engine
            .execute("INSERT INTO t (id, name, status) VALUES (1, 'x', 'active')")
            .await
            .unwrap();
        engine
            .execute("INSERT INTO t (id, name, status) VALUES (2, 'y', 'inactive')")
            .await
            .unwrap();

        let updated = engine
            .execute("UPDATE t SET name = 'updated' WHERE id = 1 AND status = 'active'")
            .await
            .unwrap();
        assert_eq!(updated, 1);

        let batches = engine
            .query("SELECT name FROM t WHERE id = 1")
            .await
            .unwrap();
        assert_eq!(batches[0].column(0).as_string::<i32>().value(0), "updated");

        let batches2 = engine
            .query("SELECT name FROM t WHERE id = 2")
            .await
            .unwrap();
        assert_eq!(batches2[0].column(0).as_string::<i32>().value(0), "y");
    }

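    // DELETE must resolve a double-quoted table identifier to the same table.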
    #[tokio::test]
    async fn delete_quoted_identifier() {
        let engine = DataFusionEngine::new();
        let schema = users_schema();
        engine.create_table("users", &schema, &[]).await.unwrap();

        engine
            .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
            .await
            .unwrap();
        engine
            .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
            .await
            .unwrap();

        let deleted = engine
            .execute(r#"DELETE FROM "users" WHERE id = 1"#)
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        let batches = engine.query("SELECT COUNT(*) FROM users").await.unwrap();
        let count = batches[0].column(0).as_primitive::<Int64Type>().value(0);
        assert_eq!(count, 1);
    }

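    // A doubled single quote ('') is the SQL escape for a literal quote.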
    #[tokio::test]
    async fn insert_escaped_single_quote() {
        let engine = DataFusionEngine::new();
        let schema = users_schema();
        engine.create_table("users", &schema, &[]).await.unwrap();

        engine
            .execute("INSERT INTO users (id, name, age) VALUES (1, 'O''Brien', 42)")
            .await
            .unwrap();

        let batches = engine
            .query("SELECT name FROM users WHERE id = 1")
            .await
            .unwrap();
        assert_eq!(batches[0].column(0).as_string::<i32>().value(0), "O'Brien");
    }

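    // Direct unit tests for the statement parsers.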
    #[test]
    fn parse_insert_multi_row() {
        let (table, cols, batches) =
            parse_insert_values("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')")
                .unwrap();
        assert_eq!(table, "users");
        assert_eq!(cols, vec!["id", "name"]);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
    }

    #[test]
    fn parse_update_basic() {
        let (table, assignments, where_clause) =
            parse_update("UPDATE users SET name = 'Alice' WHERE id = 1").unwrap();
        assert_eq!(table, "users");
        assert_eq!(
            assignments,
            vec![("name".to_string(), "'Alice'".to_string())]
        );
        assert_eq!(where_clause, vec![("id".to_string(), "1".to_string())]);
    }

    #[test]
    fn parse_delete_no_where() {
        let (table, conditions) = parse_delete("DELETE FROM logs").unwrap();
        assert_eq!(table, "logs");
        assert!(conditions.is_empty());
    }

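    // The tests below exercise the cloud-storage feature without requiring
    // real cloud credentials.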
    #[cfg(feature = "cloud-storage")]
    #[test]
    fn storage_mode_s3_parquet_attributes() {
        let mode = StorageMode::S3Parquet {
            url: "s3://my-bucket/rhei-data".to_string(),
        };
        assert!(mode.is_cloud());
        assert_eq!(mode.file_extension(), "parquet");
        assert!(mode.base_path().is_none());
        assert_eq!(mode.cloud_base_url(), Some("s3://my-bucket/rhei-data"));
    }

    #[cfg(feature = "cloud-storage")]
    #[test]
    fn storage_mode_gcs_parquet_attributes() {
        let mode = StorageMode::GcsParquet {
            url: "gs://gcs-bucket/prefix".to_string(),
        };
        assert!(mode.is_cloud());
        assert_eq!(mode.file_extension(), "parquet");
        assert!(mode.base_path().is_none());
        assert_eq!(mode.cloud_base_url(), Some("gs://gcs-bucket/prefix"));
    }

    #[cfg(feature = "cloud-storage")]
    #[test]
    fn parse_bucket_s3() {
        let bucket = DataFusionEngine::parse_bucket("s3://my-bucket/some/prefix", "s3").unwrap();
        assert_eq!(bucket, "my-bucket");
    }

    #[cfg(feature = "cloud-storage")]
    #[test]
    fn parse_bucket_gcs() {
        let bucket = DataFusionEngine::parse_bucket("gs://gcs-bucket/data", "gs").unwrap();
        assert_eq!(bucket, "gcs-bucket");
    }

    #[cfg(feature = "cloud-storage")]
    #[test]
    fn parse_bucket_wrong_scheme_returns_error() {
        let result = DataFusionEngine::parse_bucket("gs://bucket/data", "s3");
        assert!(result.is_err());
    }

    #[cfg(feature = "cloud-storage")]
    #[test]
    fn cloud_table_url_construction() {
        assert_eq!(
            DataFusionEngine::cloud_table_url("s3://bucket/prefix", "events"),
            "s3://bucket/prefix/events/"
        );
        // A trailing slash on the base URL must not produce a double slash.
        assert_eq!(
            DataFusionEngine::cloud_table_url("s3://bucket/prefix/", "logs"),
            "s3://bucket/prefix/logs/"
        );
    }

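    // Construction may fail without credentials in the environment; either
    // way it must return a Result rather than panic.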
    #[cfg(feature = "cloud-storage")]
    #[test]
    fn s3_parquet_engine_construction_does_not_panic() {
        let result = DataFusionEngine::with_storage(StorageMode::S3Parquet {
            url: "s3://test-bucket/test-prefix".to_string(),
        });
        // Ok and Err are both acceptable; only a panic would fail this test.
        let _ = result;
    }

    #[cfg(feature = "cloud-storage")]
    #[test]
    fn gcs_parquet_engine_construction_does_not_panic() {
        let result = DataFusionEngine::with_storage(StorageMode::GcsParquet {
            url: "gs://test-bucket/test-prefix".to_string(),
        });
        let _ = result;
    }

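    // With no existing objects under the table prefix, the sequence counter
    // starts at zero.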
    #[cfg(feature = "cloud-storage")]
    #[tokio::test]
    async fn cloud_seq_for_prefix_empty_prefix_returns_zero() {
        let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());
        let table_url = "s3://bucket/data/events/";
        let next = DataFusionEngine::cloud_seq_for_prefix(&store, table_url, "events")
            .await
            .unwrap();
        assert_eq!(next, 0, "empty prefix should yield counter = 0");
    }

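    // With files events_000000 through events_000002 already present, the
    // next sequence number must be 3.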
    #[cfg(feature = "cloud-storage")]
    #[tokio::test]
    async fn cloud_seq_for_prefix_advances_past_existing_files() {
        let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());

        // Pre-populate the store with three existing sequence files.
        for seq in 0u64..3 {
            let path = object_store::path::Path::from(
                format!("data/events/events_{seq:06}.parquet").as_str(),
            );
            store
                .put(&path, bytes::Bytes::from_static(b"dummy").into())
                .await
                .unwrap();
        }

        let table_url = "s3://bucket/data/events/";
        let next = DataFusionEngine::cloud_seq_for_prefix(&store, table_url, "events")
            .await
            .unwrap();
        assert_eq!(next, 3, "counter should start at max_existing + 1");
    }

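    // End-to-end check that a restarted engine resumes numbering after the
    // files already in the store instead of clobbering them.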
    #[cfg(feature = "cloud-storage")]
    #[tokio::test]
    async fn cloud_engine_restart_does_not_overwrite_existing_files() {
        let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());

        // Simulate files left behind by a previous engine instance.
        for seq in 0u64..2 {
            let path = object_store::path::Path::from(
                format!("prefix/users/users_{seq:06}.parquet").as_str(),
            );
            store
                .put(&path, bytes::Bytes::from_static(b"parquet-data").into())
                .await
                .unwrap();
        }

        let table_url = "s3://bucket/prefix/users/";
        let next = DataFusionEngine::cloud_seq_for_prefix(&store, table_url, "users")
            .await
            .unwrap();
        assert_eq!(
            next, 2,
            "restarted engine should begin writing at index 2, not 0"
        );

        // Mirror how the engine seeds its file counter on startup.
        let counter = AtomicU64::new(0);
        counter.fetch_max(next, Ordering::Relaxed);
        let seq = counter.fetch_add(1, Ordering::Relaxed);
        assert_eq!(seq, 2, "first write after restart should use index 2");

        // The pre-existing files must still be present and readable.
        for orig_seq in 0u64..2 {
            let path = object_store::path::Path::from(
                format!("prefix/users/users_{orig_seq:06}.parquet").as_str(),
            );
            let result = store.get(&path).await;
            assert!(
                result.is_ok(),
                "original file users_{orig_seq:06}.parquet must not be overwritten"
            );
        }
    }
}