1use arrow::array::{
17 Array, ArrayRef, BooleanArray, BooleanBuilder, Date32Array, Decimal128Array, Decimal128Builder,
18 Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, Int64Builder,
19 IntervalMonthDayNanoArray, RecordBatch, StringArray, UInt8Array, UInt16Array, UInt32Array,
20 UInt64Array, new_null_array,
21};
22use arrow::compute::{
23 SortColumn, SortOptions, cast, concat_batches, filter_record_batch, lexsort_to_indices, not,
24 or_kleene, take,
25};
26use arrow::datatypes::{DataType, Field, Float64Type, Int32Type, Int64Type, IntervalUnit, Schema};
27use arrow::row::{RowConverter, SortField};
28use arrow_buffer::IntervalMonthDayNano;
29use llkv_aggregate::{AggregateAccumulator, AggregateKind, AggregateSpec, AggregateState};
30use llkv_column_map::gather::gather_indices_from_batches;
31use llkv_column_map::store::Projection as StoreProjection;
32use llkv_expr::SubqueryId;
33use llkv_expr::expr::{
34 AggregateCall, BinaryOp, CompareOp, Expr as LlkvExpr, Filter, Operator, ScalarExpr,
35};
36use llkv_expr::literal::{Literal, LiteralExt};
37use llkv_expr::typed_predicate::{
38 build_bool_predicate, build_fixed_width_predicate, build_var_width_predicate,
39};
40use llkv_join::cross_join_pair;
41use llkv_plan::{
42 AggregateExpr, AggregateFunction, CanonicalRow, CompoundOperator, CompoundQuantifier,
43 CompoundSelectComponent, CompoundSelectPlan, OrderByPlan, OrderSortType, OrderTarget,
44 PlanValue, SelectPlan, SelectProjection, plan_value_from_literal,
45};
46use llkv_result::Error;
47use llkv_storage::pager::Pager;
48use llkv_table::table::{
49 RowIdFilter, ScanOrderDirection, ScanOrderSpec, ScanOrderTransform, ScanProjection,
50 ScanStreamOptions,
51};
52use llkv_table::types::FieldId;
53use llkv_table::{NumericArrayMap, NumericKernels, ROW_ID_FIELD_ID};
54use llkv_threading::with_thread_pool;
55use llkv_types::LogicalFieldId;
56use llkv_types::decimal::DecimalValue;
57use rayon::prelude::*;
58use rustc_hash::{FxHashMap, FxHashSet};
59use simd_r_drive_entry_handle::EntryHandle;
60use std::convert::TryFrom;
61use std::fmt;
62use std::sync::Arc;
63use std::sync::atomic::Ordering;
64
65#[cfg(test)]
66use std::cell::RefCell;
67
68pub mod insert;
73pub mod scan;
74pub mod translation;
75pub mod types;
76
/// Convenience alias for fallible executor operations.
pub type ExecutorResult<T> = Result<T, Error>;
83
84use crate::translation::schema::infer_computed_data_type;
85pub use insert::{
86 build_array_for_column, normalize_insert_value_for_column, resolve_insert_columns,
87};
88use llkv_compute::date::{format_date32_literal, parse_date32_literal};
89use llkv_compute::scalar::decimal::{
90 align_decimal_to_scale, decimal_from_f64, decimal_from_i64, decimal_truthy,
91};
92use llkv_compute::scalar::interval::{
93 compare_interval_values, interval_value_from_arrow, interval_value_to_arrow,
94};
95pub use llkv_compute::time::current_time_micros;
96pub use translation::{
97 build_projected_columns, build_wildcard_projections, full_table_scan_filter,
98 resolve_field_id_from_schema, schema_for_projections, translate_predicate,
99 translate_predicate_with, translate_scalar, translate_scalar_with,
100};
101pub use types::{
102 ExecutorColumn, ExecutorMultiColumnUnique, ExecutorRowBatch, ExecutorSchema, ExecutorTable,
103 ExecutorTableProvider, StorageTable, TableStorageAdapter,
104};
105
/// One component of a GROUP BY key, normalized to a hashable representation
/// so composite keys can be used directly in hash maps.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum GroupKeyValue {
    /// SQL NULL key component (all `Null`s compare equal via the derives).
    Null,
    Int(i64),
    Bool(bool),
    String(String),
}
114
/// A scalar produced by aggregate evaluation, before being written into an
/// output column. `Null` represents SQL NULL.
#[derive(Clone, Debug, PartialEq)]
enum AggregateValue {
    Null,
    Int64(i64),
    Float64(f64),
    /// Raw decimal mantissa plus its base-10 scale (value = raw / 10^scale).
    Decimal128 { value: i128, scale: i8 },
    String(String),
}
126
127impl AggregateValue {
128 fn as_i64(&self) -> Option<i64> {
130 match self {
131 AggregateValue::Null => None,
132 AggregateValue::Int64(v) => Some(*v),
133 AggregateValue::Float64(v) => Some(*v as i64),
134 AggregateValue::Decimal128 { value, scale } => {
135 let divisor = 10_i128.pow(*scale as u32);
137 Some((value / divisor) as i64)
138 }
139 AggregateValue::String(s) => s.parse().ok(),
140 }
141 }
142
143 #[allow(dead_code)]
145 fn as_f64(&self) -> Option<f64> {
146 match self {
147 AggregateValue::Null => None,
148 AggregateValue::Int64(v) => Some(*v as f64),
149 AggregateValue::Float64(v) => Some(*v),
150 AggregateValue::Decimal128 { value, scale } => {
151 let divisor = 10_f64.powi(*scale as i32);
153 Some(*value as f64 / divisor)
154 }
155 AggregateValue::String(s) => s.parse().ok(),
156 }
157 }
158}
159
160fn decimal_exact_i64(decimal: DecimalValue) -> Option<i64> {
161 llkv_compute::scalar::decimal::rescale(decimal, 0)
162 .ok()
163 .and_then(|integral| i64::try_from(integral.raw_value()).ok())
164}
165
/// A group's representative row: the batch that holds it plus the row's
/// index within that batch.
struct GroupState {
    batch: RecordBatch,
    row_idx: usize,
}
170
/// Per-group bookkeeping during grouped aggregation.
struct GroupAggregateState {
    // Batch index of the row chosen to represent the group.
    representative_batch_idx: usize,
    // Row index of the representative within that batch.
    representative_row: usize,
    // All (batch_idx, row_idx) pairs belonging to this group.
    row_locations: Vec<(usize, usize)>,
}
177
/// One column of the output schema paired with where its values come from.
struct OutputColumn {
    field: Field,
    source: OutputSource,
}
182
/// Origin of an output column's data.
enum OutputSource {
    /// Passed through from the scanned table at this column index.
    TableColumn { index: usize },
    /// Produced by a computed projection at this projection index.
    Computed { projection_index: usize },
}
187
188#[cfg(test)]
193thread_local! {
194 static QUERY_LABEL_STACK: RefCell<Vec<String>> = const { RefCell::new(Vec::new()) };
195}
196
/// RAII guard returned by [`push_query_label`]. In test builds, dropping it
/// pops the label pushed for the current query; in non-test builds it is a
/// zero-cost no-op.
pub struct QueryLogGuard {
    // Prevents construction outside this module.
    _private: (),
}
201
202#[cfg(test)]
205pub fn push_query_label(label: impl Into<String>) -> QueryLogGuard {
206 QUERY_LABEL_STACK.with(|stack| stack.borrow_mut().push(label.into()));
207 QueryLogGuard { _private: () }
208}
209
/// Non-test counterpart of [`push_query_label`]: discards the label and
/// returns an inert guard so call sites compile identically in both modes.
#[cfg(not(test))]
#[inline]
pub fn push_query_label(_label: impl Into<String>) -> QueryLogGuard {
    QueryLogGuard { _private: () }
}
219
// Test builds: dropping the guard pops the label it pushed.
#[cfg(test)]
impl Drop for QueryLogGuard {
    fn drop(&mut self) {
        QUERY_LABEL_STACK.with(|stack| {
            let _ = stack.borrow_mut().pop();
        });
    }
}
228
// Non-test builds: there is no label stack, so drop is a no-op.
#[cfg(not(test))]
impl Drop for QueryLogGuard {
    #[inline]
    fn drop(&mut self) {
    }
}
236
/// Return the innermost query label on this thread's stack, if any
/// (test builds only).
#[cfg(test)]
pub fn current_query_label() -> Option<String> {
    QUERY_LABEL_STACK.with(|stack| stack.borrow().last().cloned())
}
242
/// Non-test counterpart of [`current_query_label`]: labels are never
/// recorded, so there is never one to return.
#[cfg(not(test))]
#[inline]
pub fn current_query_label() -> Option<String> {
    None
}
251
252fn try_extract_simple_column<F: AsRef<str>>(expr: &ScalarExpr<F>) -> Option<&str> {
267 match expr {
268 ScalarExpr::Column(name) => Some(name.as_ref()),
269 ScalarExpr::Binary { left, op, right } => {
271 match op {
273 BinaryOp::Add => {
274 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Int128(0))) {
276 return try_extract_simple_column(right);
277 }
278 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Int128(0))) {
279 return try_extract_simple_column(left);
280 }
281 }
282 BinaryOp::Multiply => {
285 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Int128(1))) {
287 return try_extract_simple_column(right);
288 }
289 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Int128(1))) {
290 return try_extract_simple_column(left);
291 }
292 }
293 _ => {}
294 }
295 None
296 }
297 _ => None,
298 }
299}
300
/// Build a single Arrow array from a slice of `PlanValue`s.
///
/// The output type is chosen by the first non-NULL value; every other
/// non-NULL entry must match it, except that the FLOAT branch also accepts
/// INTEGER values and widens them. A slice containing only NULLs yields an
/// all-null `Int64` array.
fn plan_values_to_arrow_array(values: &[PlanValue]) -> ExecutorResult<ArrayRef> {
    use arrow::array::{
        Date32Array, Decimal128Array, Float64Array, Int64Array, IntervalMonthDayNanoArray,
        StringArray,
    };

    // Find the first non-NULL value; it determines the output data type.
    let mut value_type = None;
    for v in values {
        if !matches!(v, PlanValue::Null) {
            value_type = Some(v);
            break;
        }
    }

    match value_type {
        Some(PlanValue::Decimal(d)) => {
            // Precision/scale come from the first decimal; the remaining
            // values are assumed to share them — TODO confirm callers
            // guarantee this.
            let precision = d.precision();
            let scale = d.scale();
            let mut builder = Decimal128Array::builder(values.len())
                .with_precision_and_scale(precision, scale)
                .map_err(|e| {
                    Error::InvalidArgumentError(format!(
                        "invalid Decimal128 precision/scale: {}",
                        e
                    ))
                })?;
            for v in values {
                match v {
                    PlanValue::Decimal(d) => builder.append_value(d.raw_value()),
                    PlanValue::Null => builder.append_null(),
                    other => {
                        return Err(Error::InvalidArgumentError(format!(
                            "expected DECIMAL plan value, found {other:?}"
                        )));
                    }
                }
            }
            Ok(Arc::new(builder.finish()) as ArrayRef)
        }
        Some(PlanValue::Integer(_)) => {
            let int_values: Vec<Option<i64>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Integer(i) => Ok(Some(*i)),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected INTEGER plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(Int64Array::from(int_values)) as ArrayRef)
        }
        Some(PlanValue::Float(_)) => {
            let float_values: Vec<Option<f64>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Float(f) => Ok(Some(*f)),
                    PlanValue::Null => Ok(None),
                    // Integers mixed in with floats are widened to f64.
                    PlanValue::Integer(i) => Ok(Some(*i as f64)),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected FLOAT plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(Float64Array::from(float_values)) as ArrayRef)
        }
        Some(PlanValue::String(_)) => {
            let string_values: Vec<Option<&str>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::String(s) => Ok(Some(s.as_str())),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected STRING plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(StringArray::from(string_values)) as ArrayRef)
        }
        Some(PlanValue::Date32(_)) => {
            let date_values: Vec<Option<i32>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Date32(d) => Ok(Some(*d)),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected DATE plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(Date32Array::from(date_values)) as ArrayRef)
        }
        Some(PlanValue::Interval(_)) => {
            let interval_values: Vec<Option<IntervalMonthDayNano>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Interval(interval) => Ok(Some(interval_value_to_arrow(*interval))),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected INTERVAL plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(IntervalMonthDayNanoArray::from(interval_values)) as ArrayRef)
        }
        // All-NULL (or empty) input: default to a nullable Int64 column.
        _ => Ok(new_null_array(&DataType::Int64, values.len())),
    }
}
414
415fn resolve_column_name_to_index(
425 col_name: &str,
426 column_lookup_map: &FxHashMap<String, usize>,
427) -> Option<usize> {
428 let col_lower = col_name.to_ascii_lowercase();
429
430 if let Some(&idx) = column_lookup_map.get(&col_lower) {
432 return Some(idx);
433 }
434
435 let unqualified = col_name
438 .rsplit('.')
439 .next()
440 .unwrap_or(col_name)
441 .to_ascii_lowercase();
442 column_lookup_map
443 .iter()
444 .find(|(k, _)| k.ends_with(&format!(".{}", unqualified)) || k == &&unqualified)
445 .map(|(_, &idx)| idx)
446}
447
448fn get_or_insert_column_projection<P>(
450 projections: &mut Vec<ScanProjection>,
451 cache: &mut FxHashMap<FieldId, usize>,
452 table: &ExecutorTable<P>,
453 column: &ExecutorColumn,
454) -> usize
455where
456 P: Pager<Blob = EntryHandle> + Send + Sync,
457{
458 if let Some(existing) = cache.get(&column.field_id) {
459 return *existing;
460 }
461
462 let projection_index = projections.len();
463 let alias = if column.name.is_empty() {
464 format!("col{}", column.field_id)
465 } else {
466 column.name.clone()
467 };
468 projections.push(ScanProjection::from(StoreProjection::with_alias(
469 LogicalFieldId::for_user(table.table_id(), column.field_id),
470 alias,
471 )));
472 cache.insert(column.field_id, projection_index);
473 projection_index
474}
475
476fn ensure_computed_projection<P>(
478 expr: &ScalarExpr<String>,
479 table: &ExecutorTable<P>,
480 projections: &mut Vec<ScanProjection>,
481 cache: &mut FxHashMap<String, (usize, DataType)>,
482 alias_counter: &mut usize,
483) -> ExecutorResult<(usize, DataType)>
484where
485 P: Pager<Blob = EntryHandle> + Send + Sync,
486{
487 let key = format!("{:?}", expr);
488 if let Some((idx, dtype)) = cache.get(&key) {
489 return Ok((*idx, dtype.clone()));
490 }
491
492 let translated = translate_scalar(expr, table.schema.as_ref(), |name| {
493 Error::InvalidArgumentError(format!("unknown column '{}' in aggregate expression", name))
494 })?;
495 let data_type = infer_computed_data_type(table.schema.as_ref(), &translated)?;
496 if data_type == DataType::Null {
497 tracing::debug!(
498 "ensure_computed_projection inferred Null type for expr: {:?}",
499 expr
500 );
501 }
502 let alias = format!("__agg_expr_{}", *alias_counter);
503 *alias_counter += 1;
504 let projection_index = projections.len();
505 projections.push(ScanProjection::computed(translated, alias));
506 cache.insert(key, (projection_index, data_type.clone()));
507 Ok((projection_index, data_type))
508}
509
/// Executes `SelectPlan`s against tables resolved through an
/// [`ExecutorTableProvider`].
pub struct QueryExecutor<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Source of table handles, looked up by qualified name.
    provider: Arc<dyn ExecutorTableProvider<P>>,
}
517
518impl<P> QueryExecutor<P>
519where
520 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
521{
    /// Create an executor backed by the given table provider.
    pub fn new(provider: Arc<dyn ExecutorTableProvider<P>>) -> Self {
        Self { provider }
    }
525
    /// Execute `plan` with no additional row-id filtering.
    pub fn execute_select(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
        self.execute_select_with_filter(plan, None)
    }
529
    /// Execute `plan`, optionally restricting visible rows via `row_filter`.
    ///
    /// Dispatches to the specialized execution path for the plan's shape
    /// (compound, table-less, grouped, multi-table, aggregate, or plain
    /// projection); LIMIT/OFFSET are applied uniformly afterwards.
    pub fn execute_select_with_filter(
        &self,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        // Peel LIMIT/OFFSET off first so every path below can ignore them;
        // they are re-applied to the resulting execution at the end.
        let limit = plan.limit;
        let offset = plan.offset;

        let execution = if plan.compound.is_some() {
            // UNION / EXCEPT / INTERSECT chain.
            self.execute_compound_select(plan, row_filter)?
        } else if plan.tables.is_empty() {
            // SELECT without FROM: constant projections only.
            self.execute_select_without_table(plan)?
        } else if !plan.group_by.is_empty() {
            if plan.tables.len() > 1 {
                self.execute_cross_product(plan)?
            } else {
                let table_ref = &plan.tables[0];
                let table = self.provider.get_table(&table_ref.qualified_name())?;
                let display_name = table_ref.qualified_name();
                self.execute_group_by_single_table(table, display_name, plan, row_filter)?
            }
        } else if plan.tables.len() > 1 {
            self.execute_cross_product(plan)?
        } else {
            let table_ref = &plan.tables[0];
            let table = self.provider.get_table(&table_ref.qualified_name())?;
            let display_name = table_ref.qualified_name();

            if !plan.aggregates.is_empty() {
                // Explicit aggregate list (e.g. SELECT COUNT(*)).
                self.execute_aggregates(table, display_name, plan, row_filter)?
            } else if self.has_computed_aggregates(&plan) {
                // Aggregates embedded inside computed projections.
                self.execute_computed_aggregates(table, display_name, plan, row_filter)?
            } else {
                self.execute_projection(table, display_name, plan, row_filter)?
            }
        };

        Ok(execution.with_limit(limit).with_offset(offset))
    }
571
    /// Execute a compound SELECT (`UNION` / `EXCEPT` / `INTERSECT` chain).
    ///
    /// Runs the initial component, then folds each subsequent component into
    /// the accumulated row set according to its operator and quantifier.
    /// Any outer ORDER BY is applied to the final combined batch.
    /// `EXCEPT ALL` and `INTERSECT ALL` are currently rejected.
    fn execute_compound_select(
        &self,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        let order_by = plan.order_by.clone();
        let compound = plan.compound.expect("compound plan should be present");

        let CompoundSelectPlan {
            initial,
            operations,
        } = compound;

        let initial_exec = self.execute_select_with_filter(*initial, row_filter.clone())?;
        let schema = initial_exec.schema();
        let mut rows = initial_exec.into_rows()?;
        // Set of encoded row keys; `Some` only while `rows` is known to be
        // duplicate-free. Built lazily and invalidated by UNION ALL.
        let mut distinct_cache: Option<FxHashSet<Vec<u8>>> = None;

        for component in operations {
            let exec = self.execute_select_with_filter(component.plan, row_filter.clone())?;
            let other_schema = exec.schema();
            // All components must produce mutually compatible schemas.
            ensure_schema_compatibility(schema.as_ref(), other_schema.as_ref())?;
            let other_rows = exec.into_rows()?;

            match (component.operator, component.quantifier) {
                (CompoundOperator::Union, CompoundQuantifier::All) => {
                    // Plain concatenation; duplicates are allowed, so any
                    // previously built distinct cache is now stale.
                    rows.extend(other_rows);
                    distinct_cache = None;
                }
                (CompoundOperator::Union, CompoundQuantifier::Distinct) => {
                    // Dedupe the accumulated rows, then add only unseen rows
                    // from the right-hand side.
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let cache = distinct_cache
                        .as_mut()
                        .expect("distinct cache should be initialized");
                    for row in other_rows {
                        let key = encode_row(&row);
                        if cache.insert(key) {
                            rows.push(row);
                        }
                    }
                }
                (CompoundOperator::Except, CompoundQuantifier::Distinct) => {
                    // Remove every left row whose key appears on the right,
                    // keeping the distinct cache in sync.
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let cache = distinct_cache
                        .as_mut()
                        .expect("distinct cache should be initialized");
                    if rows.is_empty() {
                        continue;
                    }
                    let mut remove_keys = FxHashSet::default();
                    for row in other_rows {
                        remove_keys.insert(encode_row(&row));
                    }
                    if remove_keys.is_empty() {
                        continue;
                    }
                    rows.retain(|row| {
                        let key = encode_row(row);
                        if remove_keys.contains(&key) {
                            cache.remove(&key);
                            false
                        } else {
                            true
                        }
                    });
                }
                (CompoundOperator::Except, CompoundQuantifier::All) => {
                    return Err(Error::InvalidArgumentError(
                        "EXCEPT ALL is not supported yet".into(),
                    ));
                }
                (CompoundOperator::Intersect, CompoundQuantifier::Distinct) => {
                    // Keep only (distinct) left rows that also appear on the
                    // right; rebuild rows and cache in one pass.
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let mut right_keys = FxHashSet::default();
                    for row in other_rows {
                        right_keys.insert(encode_row(&row));
                    }
                    if right_keys.is_empty() {
                        rows.clear();
                        distinct_cache = Some(FxHashSet::default());
                        continue;
                    }
                    let mut new_rows = Vec::new();
                    let mut new_cache = FxHashSet::default();
                    for row in rows.drain(..) {
                        let key = encode_row(&row);
                        if right_keys.contains(&key) && new_cache.insert(key) {
                            new_rows.push(row);
                        }
                    }
                    rows = new_rows;
                    distinct_cache = Some(new_cache);
                }
                (CompoundOperator::Intersect, CompoundQuantifier::All) => {
                    return Err(Error::InvalidArgumentError(
                        "INTERSECT ALL is not supported yet".into(),
                    ));
                }
            }
        }

        // Materialize the combined rows and apply the outer ORDER BY, if any.
        let mut batch = rows_to_record_batch(schema.clone(), &rows)?;
        if !order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            String::new(),
            schema,
            batch,
        ))
    }
703
    /// True when any computed projection in `plan` embeds an aggregate call
    /// (e.g. `SELECT 1 + COUNT(*)`), which requires the computed-aggregate
    /// execution path.
    fn has_computed_aggregates(&self, plan: &SelectPlan) -> bool {
        plan.projections.iter().any(|proj| {
            if let SelectProjection::Computed { expr, .. } = proj {
                Self::expr_contains_aggregate(expr)
            } else {
                false
            }
        })
    }
714
715 fn predicate_contains_aggregate(expr: &llkv_expr::expr::Expr<String>) -> bool {
717 match expr {
718 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
719 exprs.iter().any(Self::predicate_contains_aggregate)
720 }
721 llkv_expr::expr::Expr::Not(inner) => Self::predicate_contains_aggregate(inner),
722 llkv_expr::expr::Expr::Compare { left, right, .. } => {
723 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
724 }
725 llkv_expr::expr::Expr::InList { expr, list, .. } => {
726 Self::expr_contains_aggregate(expr)
727 || list.iter().any(|e| Self::expr_contains_aggregate(e))
728 }
729 llkv_expr::expr::Expr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
730 llkv_expr::expr::Expr::Literal(_) => false,
731 llkv_expr::expr::Expr::Pred(_) => false,
732 llkv_expr::expr::Expr::Exists(_) => false,
733 }
734 }
735
736 fn expr_contains_aggregate(expr: &ScalarExpr<String>) -> bool {
738 match expr {
739 ScalarExpr::Aggregate(_) => true,
740 ScalarExpr::Binary { left, right, .. } => {
741 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
742 }
743 ScalarExpr::Compare { left, right, .. } => {
744 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
745 }
746 ScalarExpr::GetField { base, .. } => Self::expr_contains_aggregate(base),
747 ScalarExpr::Cast { expr, .. } => Self::expr_contains_aggregate(expr),
748 ScalarExpr::Not(expr) => Self::expr_contains_aggregate(expr),
749 ScalarExpr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
750 ScalarExpr::Case {
751 operand,
752 branches,
753 else_expr,
754 } => {
755 operand
756 .as_deref()
757 .map(Self::expr_contains_aggregate)
758 .unwrap_or(false)
759 || branches.iter().any(|(when_expr, then_expr)| {
760 Self::expr_contains_aggregate(when_expr)
761 || Self::expr_contains_aggregate(then_expr)
762 })
763 || else_expr
764 .as_deref()
765 .map(Self::expr_contains_aggregate)
766 .unwrap_or(false)
767 }
768 ScalarExpr::Coalesce(items) => items.iter().any(Self::expr_contains_aggregate),
769 ScalarExpr::Column(_) | ScalarExpr::Literal(_) | ScalarExpr::Random => false,
770 ScalarExpr::ScalarSubquery(_) => false,
771 }
772 }
773
774 fn evaluate_exists_subquery(
775 &self,
776 context: &mut CrossProductExpressionContext,
777 subquery: &llkv_plan::FilterSubquery,
778 batch: &RecordBatch,
779 row_idx: usize,
780 ) -> ExecutorResult<bool> {
781 let bindings =
782 collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
783 let bound_plan = bind_select_plan(&subquery.plan, &bindings)?;
784 let execution = self.execute_select(bound_plan)?;
785 let mut found = false;
786 execution.stream(|inner_batch| {
787 if inner_batch.num_rows() > 0 {
788 found = true;
789 }
790 Ok(())
791 })?;
792 Ok(found)
793 }
794
    /// Evaluate a scalar subquery for a single outer row, returning its
    /// result as a `Literal` (NULL when the subquery yields no rows).
    fn evaluate_scalar_subquery_literal(
        &self,
        context: &mut CrossProductExpressionContext,
        subquery: &llkv_plan::ScalarSubquery,
        batch: &RecordBatch,
        row_idx: usize,
    ) -> ExecutorResult<Literal> {
        let bindings =
            collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
        self.evaluate_scalar_subquery_with_bindings(subquery, &bindings)
    }
806
    /// Run a scalar subquery with the given correlated bindings already
    /// resolved, enforcing scalar-subquery semantics: exactly one column,
    /// at most one row. Zero rows evaluate to SQL NULL; more than one row
    /// is an error.
    fn evaluate_scalar_subquery_with_bindings(
        &self,
        subquery: &llkv_plan::ScalarSubquery,
        bindings: &FxHashMap<String, Literal>,
    ) -> ExecutorResult<Literal> {
        let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
        let execution = self.execute_select(bound_plan)?;
        let mut rows_seen: usize = 0;
        let mut result: Option<Literal> = None;
        execution.stream(|inner_batch| {
            if inner_batch.num_columns() != 1 {
                return Err(Error::InvalidArgumentError(
                    "scalar subquery must return exactly one column".into(),
                ));
            }
            let column = inner_batch.column(0).clone();
            for idx in 0..inner_batch.num_rows() {
                // Fail as soon as a second row appears, across all batches.
                if rows_seen >= 1 {
                    return Err(Error::InvalidArgumentError(
                        "scalar subquery produced more than one row".into(),
                    ));
                }
                rows_seen = rows_seen.saturating_add(1);
                result = Some(Literal::from_array_ref(&column, idx)?);
            }
            Ok(())
        })?;

        if rows_seen == 0 {
            // Empty result set: scalar subquery evaluates to NULL.
            Ok(Literal::Null)
        } else {
            result
                .ok_or_else(|| Error::Internal("scalar subquery evaluation missing result".into()))
        }
    }
842
    /// Evaluate a correlated scalar subquery for every row of `batch`,
    /// producing a numeric array (Int64 when every result is integral,
    /// Float64 otherwise).
    ///
    /// Rows are deduplicated by their encoded correlated-binding values, so
    /// the subquery runs once per distinct binding; results are also cached
    /// on `context` keyed by (subquery id, encoded bindings). Uncached jobs
    /// run in parallel on the shared thread pool.
    fn evaluate_scalar_subquery_numeric(
        &self,
        context: &mut CrossProductExpressionContext,
        subquery: &llkv_plan::ScalarSubquery,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let row_count = batch.num_rows();
        // For each row: index into `job_literals` of its (shared) result.
        let mut row_job_indices: Vec<usize> = Vec::with_capacity(row_count);
        // encoded bindings -> job slot, for intra-batch deduplication.
        let mut key_lookup: FxHashMap<Vec<u8>, usize> = FxHashMap::default();
        let mut job_literals: Vec<Option<Literal>> = Vec::new();
        // Jobs not satisfied by the context cache, to be evaluated below.
        let mut pending_bindings: Vec<FxHashMap<String, Literal>> = Vec::new();
        let mut pending_keys: Vec<Vec<u8>> = Vec::new();
        let mut pending_slots: Vec<usize> = Vec::new();

        for row_idx in 0..row_count {
            let bindings =
                collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;

            // Encode the row's correlated values into a stable cache key.
            let mut plan_values: Vec<PlanValue> =
                Vec::with_capacity(subquery.correlated_columns.len());
            for column in &subquery.correlated_columns {
                let literal = bindings
                    .get(&column.placeholder)
                    .cloned()
                    .unwrap_or(Literal::Null);
                let plan_value = plan_value_from_literal(&literal)?;
                plan_values.push(plan_value);
            }
            let key = encode_row(&plan_values);
            let cache_key = (subquery.id, key.clone());

            let job_idx = if let Some(&existing) = key_lookup.get(&key) {
                // Same bindings already seen in this batch.
                existing
            } else if let Some(cached) = context.scalar_subquery_cache.get(&cache_key) {
                // Satisfied by a previous batch's cached result.
                let idx = job_literals.len();
                key_lookup.insert(key, idx);
                job_literals.push(Some(cached.clone()));
                idx
            } else {
                // New work: reserve a slot and queue the job.
                let idx = job_literals.len();
                key_lookup.insert(key, idx);
                job_literals.push(None);
                pending_bindings.push(bindings);
                pending_keys.push(cache_key.1);
                pending_slots.push(idx);
                idx
            };
            row_job_indices.push(job_idx);
        }

        if !pending_bindings.is_empty() {
            // Evaluate all uncached jobs in parallel.
            let job_results: Vec<ExecutorResult<Literal>> = with_thread_pool(|| {
                pending_bindings
                    .par_iter()
                    .map(|bindings| self.evaluate_scalar_subquery_with_bindings(subquery, bindings))
                    .collect()
            });

            // Fill the reserved slots and populate the cross-batch cache.
            for ((slot_idx, cache_key), result) in pending_slots
                .into_iter()
                .zip(pending_keys.into_iter())
                .zip(job_results.into_iter())
            {
                let literal = result?;
                job_literals[slot_idx] = Some(literal.clone());
                context
                    .scalar_subquery_cache
                    .insert((subquery.id, cache_key), literal);
            }
        }

        // Convert per-row literals to f64, tracking whether every value is
        // exactly integral so we can emit an Int64 array instead.
        let mut values: Vec<Option<f64>> = Vec::with_capacity(row_count);
        let mut all_integer = true;

        for row_idx in 0..row_count {
            let literal = job_literals[row_job_indices[row_idx]]
                .as_ref()
                .ok_or_else(|| Error::Internal("scalar subquery result missing".into()))?;
            match literal {
                Literal::Null => values.push(None),
                Literal::Int128(value) => {
                    let cast = i64::try_from(*value).map_err(|_| {
                        Error::InvalidArgumentError(
                            "scalar subquery integer result exceeds supported range".into(),
                        )
                    })?;
                    values.push(Some(cast as f64));
                }
                Literal::Float64(value) => {
                    all_integer = false;
                    values.push(Some(*value));
                }
                Literal::Boolean(flag) => {
                    let numeric = if *flag { 1.0 } else { 0.0 };
                    values.push(Some(numeric));
                }
                Literal::Decimal128(decimal) => {
                    // Prefer an exact integral conversion; otherwise fall
                    // back to (possibly lossy) float representation.
                    if let Some(value) = decimal_exact_i64(*decimal) {
                        values.push(Some(value as f64));
                    } else {
                        all_integer = false;
                        values.push(Some(decimal.to_f64()));
                    }
                }
                Literal::String(_)
                | Literal::Struct(_)
                | Literal::Date32(_)
                | Literal::Interval(_) => {
                    return Err(Error::InvalidArgumentError(
                        "scalar subquery produced non-numeric result in numeric context".into(),
                    ));
                }
            }
        }

        if all_integer {
            // Every value was integral (or NULL): emit Int64.
            let iter = values.into_iter().map(|opt| opt.map(|v| v as i64));
            let array = Int64Array::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        } else {
            let array = Float64Array::from_iter(values);
            Ok(Arc::new(array) as ArrayRef)
        }
    }
969
    /// Evaluate a correlated scalar subquery for every row of `batch`,
    /// returning the results as an Arrow array of whatever type the literal
    /// results imply (via `literals_to_array`).
    ///
    /// Shares the deduplication/caching/parallel-evaluation scheme of
    /// `evaluate_scalar_subquery_numeric`, but preserves the literals
    /// as-is instead of coercing them to numbers.
    fn evaluate_scalar_subquery_array(
        &self,
        context: &mut CrossProductExpressionContext,
        subquery: &llkv_plan::ScalarSubquery,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let row_count = batch.num_rows();
        // For each row: index into `job_literals` of its (shared) result.
        let mut row_job_indices: Vec<usize> = Vec::with_capacity(row_count);
        // encoded bindings -> job slot, for intra-batch deduplication.
        let mut key_lookup: FxHashMap<Vec<u8>, usize> = FxHashMap::default();
        let mut job_literals: Vec<Option<Literal>> = Vec::new();
        // Jobs not satisfied by the context cache, to be evaluated below.
        let mut pending_bindings: Vec<FxHashMap<String, Literal>> = Vec::new();
        let mut pending_keys: Vec<Vec<u8>> = Vec::new();
        let mut pending_slots: Vec<usize> = Vec::new();

        for row_idx in 0..row_count {
            let bindings =
                collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;

            // Encode the row's correlated values into a stable cache key.
            let mut plan_values: Vec<PlanValue> =
                Vec::with_capacity(subquery.correlated_columns.len());
            for column in &subquery.correlated_columns {
                let literal = bindings
                    .get(&column.placeholder)
                    .cloned()
                    .unwrap_or(Literal::Null);
                let plan_value = plan_value_from_literal(&literal)?;
                plan_values.push(plan_value);
            }
            let key = encode_row(&plan_values);
            let cache_key = (subquery.id, key.clone());

            let job_idx = if let Some(&existing) = key_lookup.get(&key) {
                // Same bindings already seen in this batch.
                existing
            } else if let Some(cached) = context.scalar_subquery_cache.get(&cache_key) {
                // Satisfied by a previous batch's cached result.
                let idx = job_literals.len();
                key_lookup.insert(key, idx);
                job_literals.push(Some(cached.clone()));
                idx
            } else {
                // New work: reserve a slot and queue the job.
                let idx = job_literals.len();
                key_lookup.insert(key, idx);
                job_literals.push(None);
                pending_bindings.push(bindings);
                pending_keys.push(cache_key.1);
                pending_slots.push(idx);
                idx
            };
            row_job_indices.push(job_idx);
        }

        if !pending_bindings.is_empty() {
            // Evaluate all uncached jobs in parallel.
            let job_results: Vec<ExecutorResult<Literal>> = with_thread_pool(|| {
                pending_bindings
                    .par_iter()
                    .map(|bindings| self.evaluate_scalar_subquery_with_bindings(subquery, bindings))
                    .collect()
            });

            // Fill the reserved slots and populate the cross-batch cache.
            for ((slot_idx, cache_key), result) in pending_slots
                .into_iter()
                .zip(pending_keys.into_iter())
                .zip(job_results.into_iter())
            {
                let literal = result?;
                job_literals[slot_idx] = Some(literal.clone());
                context
                    .scalar_subquery_cache
                    .insert((subquery.id, cache_key), literal);
            }
        }

        // Expand per-job results back to one literal per input row.
        let mut values: Vec<Literal> = Vec::with_capacity(row_count);
        for row_idx in 0..row_count {
            let literal = job_literals[row_job_indices[row_idx]]
                .as_ref()
                .ok_or_else(|| Error::Internal("scalar subquery result missing".into()))?;
            values.push(literal.clone());
        }

        literals_to_array(&values)
    }
1053
    /// Evaluate a projection expression over a cross-product batch.
    ///
    /// Any scalar subqueries inside the expression are first materialized as
    /// synthetic numeric columns (one per subquery, cached on the context by
    /// a freshly allocated field id); the expression is then rewritten to
    /// reference those columns and evaluated numerically.
    fn evaluate_projection_expression(
        &self,
        context: &mut CrossProductExpressionContext,
        expr: &ScalarExpr<String>,
        batch: &RecordBatch,
        scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
    ) -> ExecutorResult<ArrayRef> {
        let translated = translate_scalar(expr, context.schema(), |name| {
            Error::InvalidArgumentError(format!(
                "column '{}' not found in cross product result",
                name
            ))
        })?;

        // Collect every scalar subquery referenced by the expression.
        let mut subquery_ids: FxHashSet<SubqueryId> = FxHashSet::default();
        collect_scalar_subquery_ids(&translated, &mut subquery_ids);

        // Materialize each subquery as a synthetic numeric column.
        let mut mapping: FxHashMap<SubqueryId, FieldId> = FxHashMap::default();
        for subquery_id in &subquery_ids {
            let info = scalar_lookup
                .get(subquery_id)
                .ok_or_else(|| Error::Internal("missing scalar subquery metadata".into()))?;
            let field_id = context.allocate_synthetic_field_id()?;
            let numeric = self.evaluate_scalar_subquery_numeric(context, info, batch)?;
            context.numeric_cache.insert(field_id, numeric);
            mapping.insert(*subquery_id, field_id);
        }

        // Replace subquery nodes with references to the synthetic columns.
        let rewritten = rewrite_scalar_expr_for_subqueries(&translated, &mapping);
        context.evaluate_numeric(&rewritten, batch)
    }
1085
    /// Execute a `SELECT` with no `FROM` clause.
    ///
    /// Only computed projections over constant expressions are allowed; each
    /// one becomes a single-row column. `DISTINCT` is honored by filtering
    /// the resulting batch.
    fn execute_select_without_table(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
        use arrow::array::ArrayRef;
        use arrow::datatypes::Field;

        let mut fields = Vec::new();
        let mut arrays: Vec<ArrayRef> = Vec::new();

        for proj in &plan.projections {
            match proj {
                SelectProjection::Computed { expr, alias } => {
                    // Constant-fold the expression; anything non-constant is
                    // rejected since there are no input rows to evaluate it on.
                    let literal =
                        evaluate_constant_scalar_with_aggregates(expr).ok_or_else(|| {
                            Error::InvalidArgumentError(
                                "SELECT without FROM only supports constant expressions".into(),
                            )
                        })?;
                    let (dtype, array) = Self::literal_to_array(&literal)?;

                    fields.push(Field::new(alias.clone(), dtype, true));
                    arrays.push(array);
                }
                _ => {
                    return Err(Error::InvalidArgumentError(
                        "SELECT without FROM only supports computed projections".into(),
                    ));
                }
            }
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
            .map_err(|e| Error::Internal(format!("failed to create record batch: {}", e)))?;

        if plan.distinct {
            let mut state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut state)? {
                Some(filtered) => filtered,
                // Every row filtered out: keep an empty batch with the schema.
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        let schema = batch.schema();

        Ok(SelectExecution::new_single_batch(
            String::new(),
            schema,
            batch,
        ))
    }
1137
1138 fn literal_to_array(lit: &llkv_expr::literal::Literal) -> ExecutorResult<(DataType, ArrayRef)> {
1140 use arrow::array::{
1141 ArrayRef, BooleanArray, Date32Array, Decimal128Array, Float64Array, Int64Array,
1142 IntervalMonthDayNanoArray, StringArray, StructArray, new_null_array,
1143 };
1144 use arrow::datatypes::{DataType, Field, IntervalUnit};
1145 use llkv_compute::scalar::interval::interval_value_to_arrow;
1146 use llkv_expr::literal::Literal;
1147
1148 match lit {
1149 Literal::Int128(v) => {
1150 let val = i64::try_from(*v).unwrap_or(0);
1151 Ok((
1152 DataType::Int64,
1153 Arc::new(Int64Array::from(vec![val])) as ArrayRef,
1154 ))
1155 }
1156 Literal::Float64(v) => Ok((
1157 DataType::Float64,
1158 Arc::new(Float64Array::from(vec![*v])) as ArrayRef,
1159 )),
1160 Literal::Boolean(v) => Ok((
1161 DataType::Boolean,
1162 Arc::new(BooleanArray::from(vec![*v])) as ArrayRef,
1163 )),
1164 Literal::Decimal128(value) => {
1165 let iter = std::iter::once(value.raw_value());
1166 let precision = std::cmp::max(value.precision(), value.scale() as u8);
1167 let array = Decimal128Array::from_iter_values(iter)
1168 .with_precision_and_scale(precision, value.scale())
1169 .map_err(|err| {
1170 Error::InvalidArgumentError(format!(
1171 "failed to build Decimal128 literal array: {err}"
1172 ))
1173 })?;
1174 Ok((
1175 DataType::Decimal128(precision, value.scale()),
1176 Arc::new(array) as ArrayRef,
1177 ))
1178 }
1179 Literal::String(v) => Ok((
1180 DataType::Utf8,
1181 Arc::new(StringArray::from(vec![v.clone()])) as ArrayRef,
1182 )),
1183 Literal::Date32(v) => Ok((
1184 DataType::Date32,
1185 Arc::new(Date32Array::from(vec![*v])) as ArrayRef,
1186 )),
1187 Literal::Null => Ok((DataType::Null, new_null_array(&DataType::Null, 1))),
1188 Literal::Interval(interval) => Ok((
1189 DataType::Interval(IntervalUnit::MonthDayNano),
1190 Arc::new(IntervalMonthDayNanoArray::from(vec![
1191 interval_value_to_arrow(*interval),
1192 ])) as ArrayRef,
1193 )),
1194 Literal::Struct(struct_fields) => {
1195 let mut inner_fields = Vec::new();
1197 let mut inner_arrays = Vec::new();
1198
1199 for (field_name, field_lit) in struct_fields {
1200 let (field_dtype, field_array) = Self::literal_to_array(field_lit)?;
1201 inner_fields.push(Field::new(field_name.clone(), field_dtype, true));
1202 inner_arrays.push(field_array);
1203 }
1204
1205 let struct_array =
1206 StructArray::try_new(inner_fields.clone().into(), inner_arrays, None).map_err(
1207 |e| Error::Internal(format!("failed to create struct array: {}", e)),
1208 )?;
1209
1210 Ok((
1211 DataType::Struct(inner_fields.into()),
1212 Arc::new(struct_array) as ArrayRef,
1213 ))
1214 }
1215 }
1216 }
1217
    /// Executes a multi-table `SELECT` (explicit JOINs or comma cross
    /// products), producing a single combined batch after applying filters,
    /// ORDER BY, projections, and DISTINCT.
    ///
    /// Pipeline: resolve tables -> build joined/cross-product batches ->
    /// apply residual WHERE filter (with subquery support) -> dispatch to
    /// GROUP BY / aggregate paths if needed -> concat, sort, project,
    /// de-duplicate.
    fn execute_cross_product(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
        use arrow::compute::concat_batches;

        if plan.tables.len() < 2 {
            return Err(Error::InvalidArgumentError(
                "cross product requires at least 2 tables".into(),
            ));
        }

        // Resolve every table referenced in the FROM clause up front.
        let mut tables_with_handles = Vec::with_capacity(plan.tables.len());
        for table_ref in &plan.tables {
            let qualified_name = table_ref.qualified_name();
            let table = self.provider.get_table(&qualified_name)?;
            tables_with_handles.push((table_ref.clone(), table));
        }

        let display_name = tables_with_handles
            .iter()
            .map(|(table_ref, _)| table_ref.qualified_name())
            .collect::<Vec<_>>()
            .join(",");

        let mut remaining_filter = plan.filter.clone();

        // Fast path: if there is a WHERE clause, try the dedicated hash-join
        // executor first; it may absorb some or all predicates.
        let join_data = if remaining_filter.as_ref().is_some() {
            self.try_execute_hash_join(&plan, &tables_with_handles)?
        } else {
            None
        };

        let current = if let Some((joined, handled_all_predicates)) = join_data {
            // The hash-join path may have applied every predicate already,
            // in which case no residual filter pass is needed.
            if handled_all_predicates {
                remaining_filter = None;
            }
            joined
        } else {
            let has_joins = !plan.joins.is_empty();

            if has_joins && tables_with_handles.len() == 2 {
                // Two-table explicit JOIN: stream through llkv-join.
                use llkv_join::JoinOptions;

                let (left_ref, left_table) = &tables_with_handles[0];
                let (right_ref, right_table) = &tables_with_handles[1];

                let join_metadata = plan.joins.first().ok_or_else(|| {
                    Error::InvalidArgumentError("expected join metadata for two-table join".into())
                })?;

                let join_type = match join_metadata.join_type {
                    llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
                    llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
                    llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
                    llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
                };

                tracing::debug!(
                    "Using llkv-join for {join_type:?} join between {} and {}",
                    left_ref.qualified_name(),
                    right_ref.qualified_name()
                );

                let left_col_count = left_table.schema.columns.len();
                let right_col_count = right_table.schema.columns.len();

                // Combined schema is always left columns then right columns.
                let mut combined_fields = Vec::with_capacity(left_col_count + right_col_count);
                for col in &left_table.schema.columns {
                    combined_fields.push(Field::new(
                        col.name.clone(),
                        col.data_type.clone(),
                        col.nullable,
                    ));
                }
                for col in &right_table.schema.columns {
                    combined_fields.push(Field::new(
                        col.name.clone(),
                        col.data_type.clone(),
                        col.nullable,
                    ));
                }
                let combined_schema = Arc::new(Schema::new(combined_fields));
                let column_counts = vec![left_col_count, right_col_count];
                let table_indices = vec![0, 1];

                let mut join_keys = Vec::new();
                let mut condition_is_trivial = false;
                let mut condition_is_impossible = false;

                // Translate the ON condition into equi-join keys, detecting
                // constant TRUE/FALSE conditions along the way.
                if let Some(condition) = join_metadata.on_condition.as_ref() {
                    let plan = build_join_keys_from_condition(
                        condition,
                        left_ref,
                        left_table.as_ref(),
                        right_ref,
                        right_table.as_ref(),
                    )?;
                    join_keys = plan.keys;
                    condition_is_trivial = plan.always_true;
                    condition_is_impossible = plan.always_false;
                }

                if condition_is_impossible {
                    // ON is constant FALSE/NULL: emit null-padded batches for
                    // the preserved side(s) instead of running the join.
                    let batches = build_no_match_join_batches(
                        join_type,
                        left_ref,
                        left_table.as_ref(),
                        right_ref,
                        right_table.as_ref(),
                        Arc::clone(&combined_schema),
                    )?;

                    TableCrossProductData {
                        schema: combined_schema,
                        batches,
                        column_counts,
                        table_indices,
                    }
                } else {
                    // A non-trivial ON clause must contribute at least one
                    // equality key for the streaming join to work on.
                    if !condition_is_trivial
                        && join_metadata.on_condition.is_some()
                        && join_keys.is_empty()
                    {
                        return Err(Error::InvalidArgumentError(
                            "JOIN ON clause must include at least one equality predicate".into(),
                        ));
                    }

                    let mut result_batches = Vec::new();
                    let mut on_batch = |batch: RecordBatch| {
                        result_batches.push(batch);
                    };
                    left_table.storage().join_stream(
                        right_table.storage().as_ref(),
                        &join_keys,
                        &JoinOptions {
                            join_type,
                            ..Default::default()
                        },
                        &mut on_batch,
                    )?;

                    TableCrossProductData {
                        schema: combined_schema,
                        batches: result_batches,
                        column_counts,
                        table_indices,
                    }
                }
            } else if has_joins && tables_with_handles.len() > 2 {
                // N-table JOIN chain: fold tables left-to-right with a hash
                // join per step. join_lookup maps left-table index -> JOIN.
                let join_lookup: FxHashMap<usize, &llkv_plan::JoinMetadata> = plan
                    .joins
                    .iter()
                    .map(|join| (join.left_table_index, join))
                    .collect();

                // Push simple literal predicates down to the per-table scans.
                let constraint_map = if let Some(filter_wrapper) = remaining_filter.as_ref() {
                    extract_literal_pushdown_filters(
                        &filter_wrapper.predicate,
                        &tables_with_handles,
                    )
                } else {
                    vec![Vec::new(); tables_with_handles.len()]
                };

                let (first_ref, first_table) = &tables_with_handles[0];
                let first_constraints = constraint_map.first().map(|v| v.as_slice()).unwrap_or(&[]);
                let mut accumulated =
                    collect_table_data(0, first_ref, first_table.as_ref(), first_constraints)?;

                for (idx, (right_ref, right_table)) in
                    tables_with_handles.iter().enumerate().skip(1)
                {
                    let right_constraints =
                        constraint_map.get(idx).map(|v| v.as_slice()).unwrap_or(&[]);

                    let join_metadata = join_lookup.get(&(idx - 1)).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "No join condition found between table {} and {}. Multi-table queries require explicit JOIN syntax.",
                            idx - 1, idx
                        ))
                    })?;

                    let join_type = match join_metadata.join_type {
                        llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
                        llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
                        llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
                        llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
                    };

                    let right_data = collect_table_data(
                        idx,
                        right_ref,
                        right_table.as_ref(),
                        right_constraints,
                    )?;

                    // A missing ON clause degenerates to TRUE (cross join).
                    let condition_expr = join_metadata
                        .on_condition
                        .clone()
                        .unwrap_or(LlkvExpr::Literal(true));

                    let join_batches = execute_hash_join_batches(
                        &accumulated.schema,
                        &accumulated.batches,
                        &right_data.schema,
                        &right_data.batches,
                        &condition_expr,
                        join_type,
                    )?;

                    // Grow the accumulated schema: previous columns followed
                    // by the newly joined table's columns.
                    let combined_fields: Vec<Field> = accumulated
                        .schema
                        .fields()
                        .iter()
                        .chain(right_data.schema.fields().iter())
                        .map(|f| {
                            Field::new(f.name().clone(), f.data_type().clone(), f.is_nullable())
                        })
                        .collect();
                    let combined_schema = Arc::new(Schema::new(combined_fields));

                    accumulated = TableCrossProductData {
                        schema: combined_schema,
                        batches: join_batches,
                        column_counts: {
                            let mut counts = accumulated.column_counts;
                            counts.push(right_data.schema.fields().len());
                            counts
                        },
                        table_indices: {
                            let mut indices = accumulated.table_indices;
                            indices.push(idx);
                            indices
                        },
                    };
                }

                accumulated
            } else {
                // Pure comma-list cross product (possibly with isolated JOIN
                // pairs whose ON condition is unsatisfiable).
                let constraint_map = if let Some(filter_wrapper) = remaining_filter.as_ref() {
                    extract_literal_pushdown_filters(
                        &filter_wrapper.predicate,
                        &tables_with_handles,
                    )
                } else {
                    vec![Vec::new(); tables_with_handles.len()]
                };

                let mut staged: Vec<TableCrossProductData> =
                    Vec::with_capacity(tables_with_handles.len());
                let join_lookup: FxHashMap<usize, &llkv_plan::JoinMetadata> = plan
                    .joins
                    .iter()
                    .map(|join| (join.left_table_index, join))
                    .collect();

                let mut idx = 0usize;
                while idx < tables_with_handles.len() {
                    if let Some(join_metadata) = join_lookup.get(&idx) {
                        if idx + 1 >= tables_with_handles.len() {
                            return Err(Error::Internal(
                                "join metadata references table beyond FROM list".into(),
                            ));
                        }

                        // Only collapse this pair early if the next table is
                        // not itself the left side of another join.
                        let overlaps_next_join = join_lookup.contains_key(&(idx + 1));
                        if let Some(condition) = join_metadata.on_condition.as_ref() {
                            let (left_ref, left_table) = &tables_with_handles[idx];
                            let (right_ref, right_table) = &tables_with_handles[idx + 1];
                            let join_plan = build_join_keys_from_condition(
                                condition,
                                left_ref,
                                left_table.as_ref(),
                                right_ref,
                                right_table.as_ref(),
                            )?;
                            if join_plan.always_false && !overlaps_next_join {
                                // Unsatisfiable ON: materialize the outer
                                // join's null-padded result for this pair.
                                let join_type = match join_metadata.join_type {
                                    llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
                                    llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
                                    llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
                                    llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
                                };

                                let left_col_count = left_table.schema.columns.len();
                                let right_col_count = right_table.schema.columns.len();

                                let mut combined_fields =
                                    Vec::with_capacity(left_col_count + right_col_count);
                                for col in &left_table.schema.columns {
                                    combined_fields.push(Field::new(
                                        col.name.clone(),
                                        col.data_type.clone(),
                                        col.nullable,
                                    ));
                                }
                                for col in &right_table.schema.columns {
                                    combined_fields.push(Field::new(
                                        col.name.clone(),
                                        col.data_type.clone(),
                                        col.nullable,
                                    ));
                                }

                                let combined_schema = Arc::new(Schema::new(combined_fields));
                                let batches = build_no_match_join_batches(
                                    join_type,
                                    left_ref,
                                    left_table.as_ref(),
                                    right_ref,
                                    right_table.as_ref(),
                                    Arc::clone(&combined_schema),
                                )?;

                                staged.push(TableCrossProductData {
                                    schema: combined_schema,
                                    batches,
                                    column_counts: vec![left_col_count, right_col_count],
                                    table_indices: vec![idx, idx + 1],
                                });
                                idx += 2;
                                continue;
                            }
                        }
                    }

                    // Default: stage this table's data for the cross join.
                    let (table_ref, table) = &tables_with_handles[idx];
                    let constraints = constraint_map.get(idx).map(|v| v.as_slice()).unwrap_or(&[]);
                    staged.push(collect_table_data(
                        idx,
                        table_ref,
                        table.as_ref(),
                        constraints,
                    )?);
                    idx += 1;
                }

                cross_join_all(staged)?
            }
        };

        let TableCrossProductData {
            schema: combined_schema,
            batches: mut combined_batches,
            column_counts,
            table_indices,
        } = current;

        // Maps lowercase (qualified) column names to combined-schema indices.
        let column_lookup_map = build_cross_product_column_lookup(
            combined_schema.as_ref(),
            &plan.tables,
            &column_counts,
            &table_indices,
        );

        let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
            .scalar_subqueries
            .iter()
            .map(|subquery| (subquery.id, subquery))
            .collect();

        // Apply any WHERE predicate not absorbed by the join path, including
        // EXISTS and scalar subqueries evaluated per batch/row.
        if let Some(filter_wrapper) = remaining_filter.as_ref() {
            let mut filter_context = CrossProductExpressionContext::new(
                combined_schema.as_ref(),
                column_lookup_map.clone(),
            )?;
            let translated_filter = translate_predicate(
                filter_wrapper.predicate.clone(),
                filter_context.schema(),
                |name| {
                    Error::InvalidArgumentError(format!(
                        "column '{}' not found in cross product result",
                        name
                    ))
                },
            )?;

            let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
                filter_wrapper
                    .subqueries
                    .iter()
                    .map(|subquery| (subquery.id, subquery))
                    .collect();
            let mut predicate_scalar_ids = FxHashSet::default();
            collect_predicate_scalar_subquery_ids(&translated_filter, &mut predicate_scalar_ids);

            let mut filtered_batches = Vec::with_capacity(combined_batches.len());
            for batch in combined_batches.into_iter() {
                filter_context.reset();
                // Materialize each scalar subquery once per batch and expose
                // it to the predicate evaluator as a synthetic column.
                for subquery_id in &predicate_scalar_ids {
                    let info = scalar_lookup.get(subquery_id).ok_or_else(|| {
                        Error::Internal("missing scalar subquery metadata".into())
                    })?;
                    let array =
                        self.evaluate_scalar_subquery_array(&mut filter_context, info, &batch)?;
                    let accessor = ColumnAccessor::from_array(&array)?;
                    filter_context.register_scalar_subquery_column(*subquery_id, accessor);
                }
                let mask = filter_context.evaluate_predicate_mask(
                    &translated_filter,
                    &batch,
                    |ctx, subquery_expr, row_idx, current_batch| {
                        // Correlated EXISTS subqueries are evaluated per row.
                        let subquery = subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
                            Error::Internal("missing correlated subquery metadata".into())
                        })?;
                        let exists =
                            self.evaluate_exists_subquery(ctx, subquery, current_batch, row_idx)?;
                        let value = if subquery_expr.negated {
                            !exists
                        } else {
                            exists
                        };
                        Ok(Some(value))
                    },
                )?;
                let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
                    Error::InvalidArgumentError(format!(
                        "failed to apply cross product filter: {err}"
                    ))
                })?;
                if filtered.num_rows() > 0 {
                    filtered_batches.push(filtered);
                }
            }
            combined_batches = filtered_batches;
        }

        // GROUP BY / aggregate queries are delegated to specialized paths.
        if !plan.group_by.is_empty() {
            return self.execute_group_by_from_batches(
                display_name,
                plan,
                combined_schema,
                combined_batches,
                column_lookup_map,
            );
        }

        if !plan.aggregates.is_empty() {
            return self.execute_cross_product_aggregates(
                Arc::clone(&combined_schema),
                combined_batches,
                &column_lookup_map,
                &plan,
                &display_name,
            );
        }

        if self.has_computed_aggregates(&plan) {
            return self.execute_cross_product_computed_aggregates(
                Arc::clone(&combined_schema),
                combined_batches,
                &column_lookup_map,
                &plan,
                &display_name,
            );
        }

        // Collapse to a single batch for sorting/projection.
        let mut combined_batch = if combined_batches.is_empty() {
            RecordBatch::new_empty(Arc::clone(&combined_schema))
        } else if combined_batches.len() == 1 {
            combined_batches.pop().unwrap()
        } else {
            concat_batches(&combined_schema, &combined_batches).map_err(|e| {
                Error::Internal(format!(
                    "failed to concatenate cross product batches: {}",
                    e
                ))
            })?
        };

        if !plan.order_by.is_empty() {
            // Resolve ORDER BY column names to combined-schema indices,
            // trying the lowercase lookup map first, then the raw schema.
            let mut resolved_order_by = Vec::with_capacity(plan.order_by.len());
            for order in &plan.order_by {
                let resolved_target = match &order.target {
                    OrderTarget::Column(name) => {
                        let col_name = name.to_ascii_lowercase();
                        if let Some(&idx) = column_lookup_map.get(&col_name) {
                            OrderTarget::Index(idx)
                        } else {
                            if let Ok(idx) = combined_schema.index_of(name) {
                                OrderTarget::Index(idx)
                            } else {
                                return Err(Error::InvalidArgumentError(format!(
                                    "ORDER BY references unknown column '{}'",
                                    name
                                )));
                            }
                        }
                    }
                    other => other.clone(),
                };
                resolved_order_by.push(llkv_plan::OrderByPlan {
                    target: resolved_target,
                    sort_type: order.sort_type.clone(),
                    ascending: order.ascending,
                    nulls_first: order.nulls_first,
                });
            }

            combined_batch = sort_record_batch_with_order(
                &combined_schema,
                &combined_batch,
                &resolved_order_by,
            )?;
        }

        // Apply the SELECT list. `*` and `* EXCEPT` short-circuit the loop.
        if !plan.projections.is_empty() {
            let mut selected_fields = Vec::new();
            let mut selected_columns = Vec::new();
            let mut expr_context: Option<CrossProductExpressionContext> = None;

            for proj in &plan.projections {
                match proj {
                    SelectProjection::AllColumns => {
                        selected_fields = combined_schema.fields().iter().cloned().collect();
                        selected_columns = combined_batch.columns().to_vec();
                        break;
                    }
                    SelectProjection::AllColumnsExcept { exclude } => {
                        let exclude_lower: Vec<String> =
                            exclude.iter().map(|e| e.to_ascii_lowercase()).collect();

                        let mut excluded_indices = FxHashSet::default();
                        for excluded_name in &exclude_lower {
                            if let Some(&idx) = column_lookup_map.get(excluded_name) {
                                excluded_indices.insert(idx);
                            }
                        }

                        for (idx, field) in combined_schema.fields().iter().enumerate() {
                            let field_name_lower = field.name().to_ascii_lowercase();
                            if !exclude_lower.contains(&field_name_lower)
                                && !excluded_indices.contains(&idx)
                            {
                                selected_fields.push(field.clone());
                                selected_columns.push(combined_batch.column(idx).clone());
                            }
                        }
                        break;
                    }
                    SelectProjection::Column { name, alias } => {
                        let col_name = name.to_ascii_lowercase();
                        if let Some(&idx) = column_lookup_map.get(&col_name) {
                            let field = combined_schema.field(idx);
                            let output_name = alias.as_ref().unwrap_or(name).clone();
                            selected_fields.push(Arc::new(arrow::datatypes::Field::new(
                                output_name,
                                field.data_type().clone(),
                                field.is_nullable(),
                            )));
                            selected_columns.push(combined_batch.column(idx).clone());
                        } else {
                            return Err(Error::InvalidArgumentError(format!(
                                "column '{}' not found in cross product result",
                                name
                            )));
                        }
                    }
                    SelectProjection::Computed { expr, alias } => {
                        // The expression context is built lazily, only if at
                        // least one computed projection exists.
                        if expr_context.is_none() {
                            expr_context = Some(CrossProductExpressionContext::new(
                                combined_schema.as_ref(),
                                column_lookup_map.clone(),
                            )?);
                        }
                        let context = expr_context
                            .as_mut()
                            .expect("projection context must be initialized");
                        context.reset();
                        let evaluated = self.evaluate_projection_expression(
                            context,
                            expr,
                            &combined_batch,
                            &scalar_lookup,
                        )?;
                        let field = Arc::new(arrow::datatypes::Field::new(
                            alias.clone(),
                            evaluated.data_type().clone(),
                            true,
                        ));
                        selected_fields.push(field);
                        selected_columns.push(evaluated);
                    }
                }
            }

            let projected_schema = Arc::new(Schema::new(selected_fields));
            combined_batch = RecordBatch::try_new(projected_schema, selected_columns)
                .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))?;
        }

        if plan.distinct {
            let mut state = DistinctState::default();
            let source_schema = combined_batch.schema();
            combined_batch = match distinct_filter_batch(combined_batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(source_schema),
            };
        }

        let schema = combined_batch.schema();

        Ok(SelectExecution::new_single_batch(
            display_name,
            schema,
            combined_batch,
        ))
    }
1853}
1854
/// Result of translating a JOIN `ON` condition into hash-join inputs.
struct JoinKeyBuild {
    // Left/right equality key pairs usable by the streaming join.
    keys: Vec<llkv_join::JoinKey>,
    // Condition folded to constant TRUE (degenerates to a cross join).
    always_true: bool,
    // Condition folded to constant FALSE or NULL (no rows can match).
    always_false: bool,
}
1860
// Alias kept for potential external call sites; currently unused here.
#[allow(dead_code)]
type JoinKeyBuildEqualities = JoinKeyBuild;
1864
impl JoinKeyBuild {
    /// Returns the extracted equi-join key pairs.
    #[allow(dead_code)]
    fn equalities(&self) -> &[llkv_join::JoinKey] {
        &self.keys
    }
}
1871
/// Classification of a JOIN `ON` expression after constant folding.
#[derive(Debug)]
enum JoinConditionAnalysis {
    /// Condition always holds; join degenerates to a cross product.
    AlwaysTrue,
    /// Condition can never hold (constant FALSE or NULL).
    AlwaysFalse,
    /// Conjunction of `left = right` column-name equality pairs.
    EquiPairs(Vec<(String, String)>),
}
1878
1879fn build_join_keys_from_condition<P>(
1880 condition: &LlkvExpr<'static, String>,
1881 left_ref: &llkv_plan::TableRef,
1882 left_table: &ExecutorTable<P>,
1883 right_ref: &llkv_plan::TableRef,
1884 right_table: &ExecutorTable<P>,
1885) -> ExecutorResult<JoinKeyBuild>
1886where
1887 P: Pager<Blob = EntryHandle> + Send + Sync,
1888{
1889 match analyze_join_condition(condition)? {
1890 JoinConditionAnalysis::AlwaysTrue => Ok(JoinKeyBuild {
1891 keys: Vec::new(),
1892 always_true: true,
1893 always_false: false,
1894 }),
1895 JoinConditionAnalysis::AlwaysFalse => Ok(JoinKeyBuild {
1896 keys: Vec::new(),
1897 always_true: false,
1898 always_false: true,
1899 }),
1900 JoinConditionAnalysis::EquiPairs(pairs) => {
1901 let left_lookup = build_join_column_lookup(left_ref, left_table);
1902 let right_lookup = build_join_column_lookup(right_ref, right_table);
1903
1904 let mut keys = Vec::with_capacity(pairs.len());
1905 for (lhs, rhs) in pairs {
1906 let (lhs_side, lhs_field) = resolve_join_column(&lhs, &left_lookup, &right_lookup)?;
1907 let (rhs_side, rhs_field) = resolve_join_column(&rhs, &left_lookup, &right_lookup)?;
1908
1909 match (lhs_side, rhs_side) {
1910 (JoinColumnSide::Left, JoinColumnSide::Right) => {
1911 keys.push(llkv_join::JoinKey::new(lhs_field, rhs_field));
1912 }
1913 (JoinColumnSide::Right, JoinColumnSide::Left) => {
1914 keys.push(llkv_join::JoinKey::new(rhs_field, lhs_field));
1915 }
1916 (JoinColumnSide::Left, JoinColumnSide::Left) => {
1917 return Err(Error::InvalidArgumentError(format!(
1918 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1919 left_ref.display_name(),
1920 lhs,
1921 rhs
1922 )));
1923 }
1924 (JoinColumnSide::Right, JoinColumnSide::Right) => {
1925 return Err(Error::InvalidArgumentError(format!(
1926 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1927 right_ref.display_name(),
1928 lhs,
1929 rhs
1930 )));
1931 }
1932 }
1933 }
1934
1935 Ok(JoinKeyBuild {
1936 keys,
1937 always_true: false,
1938 always_false: false,
1939 })
1940 }
1941 }
1942}
1943
1944fn analyze_join_condition(
1945 expr: &LlkvExpr<'static, String>,
1946) -> ExecutorResult<JoinConditionAnalysis> {
1947 match evaluate_constant_join_expr(expr) {
1948 ConstantJoinEvaluation::Known(true) => {
1949 return Ok(JoinConditionAnalysis::AlwaysTrue);
1950 }
1951 ConstantJoinEvaluation::Known(false) | ConstantJoinEvaluation::Unknown => {
1952 return Ok(JoinConditionAnalysis::AlwaysFalse);
1953 }
1954 ConstantJoinEvaluation::NotConstant => {}
1955 }
1956 match expr {
1957 LlkvExpr::Literal(value) => {
1958 if *value {
1959 Ok(JoinConditionAnalysis::AlwaysTrue)
1960 } else {
1961 Ok(JoinConditionAnalysis::AlwaysFalse)
1962 }
1963 }
1964 LlkvExpr::And(children) => {
1965 let mut collected: Vec<(String, String)> = Vec::new();
1966 for child in children {
1967 match analyze_join_condition(child)? {
1968 JoinConditionAnalysis::AlwaysTrue => {}
1969 JoinConditionAnalysis::AlwaysFalse => {
1970 return Ok(JoinConditionAnalysis::AlwaysFalse);
1971 }
1972 JoinConditionAnalysis::EquiPairs(mut pairs) => {
1973 collected.append(&mut pairs);
1974 }
1975 }
1976 }
1977
1978 if collected.is_empty() {
1979 Ok(JoinConditionAnalysis::AlwaysTrue)
1980 } else {
1981 Ok(JoinConditionAnalysis::EquiPairs(collected))
1982 }
1983 }
1984 LlkvExpr::Compare { left, op, right } => {
1985 if *op != CompareOp::Eq {
1986 return Err(Error::InvalidArgumentError(
1987 "JOIN ON clause only supports '=' comparisons in optimized path".into(),
1988 ));
1989 }
1990 let left_name = try_extract_simple_column(left).ok_or_else(|| {
1991 Error::InvalidArgumentError(
1992 "JOIN ON clause requires plain column references".into(),
1993 )
1994 })?;
1995 let right_name = try_extract_simple_column(right).ok_or_else(|| {
1996 Error::InvalidArgumentError(
1997 "JOIN ON clause requires plain column references".into(),
1998 )
1999 })?;
2000 Ok(JoinConditionAnalysis::EquiPairs(vec![(
2001 left_name.to_string(),
2002 right_name.to_string(),
2003 )]))
2004 }
2005 _ => Err(Error::InvalidArgumentError(
2006 "JOIN ON expressions must be conjunctions of column equality predicates".into(),
2007 )),
2008 }
2009}
2010
2011fn compare_literals_with_mode(
2012 op: CompareOp,
2013 left: &Literal,
2014 right: &Literal,
2015 null_behavior: NullComparisonBehavior,
2016) -> Option<bool> {
2017 use std::cmp::Ordering;
2018
2019 fn ordering_result(ord: Ordering, op: CompareOp) -> bool {
2020 match op {
2021 CompareOp::Eq => ord == Ordering::Equal,
2022 CompareOp::NotEq => ord != Ordering::Equal,
2023 CompareOp::Lt => ord == Ordering::Less,
2024 CompareOp::LtEq => ord != Ordering::Greater,
2025 CompareOp::Gt => ord == Ordering::Greater,
2026 CompareOp::GtEq => ord != Ordering::Less,
2027 }
2028 }
2029
2030 fn compare_f64(lhs: f64, rhs: f64, op: CompareOp) -> bool {
2031 match op {
2032 CompareOp::Eq => lhs == rhs,
2033 CompareOp::NotEq => lhs != rhs,
2034 CompareOp::Lt => lhs < rhs,
2035 CompareOp::LtEq => lhs <= rhs,
2036 CompareOp::Gt => lhs > rhs,
2037 CompareOp::GtEq => lhs >= rhs,
2038 }
2039 }
2040
2041 match (left, right) {
2042 (Literal::Null, _) | (_, Literal::Null) => match null_behavior {
2043 NullComparisonBehavior::ThreeValuedLogic => None,
2044 },
2045 (Literal::Int128(lhs), Literal::Int128(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
2046 (Literal::Float64(lhs), Literal::Float64(rhs)) => Some(compare_f64(*lhs, *rhs, op)),
2047 (Literal::Int128(lhs), Literal::Float64(rhs)) => Some(compare_f64(*lhs as f64, *rhs, op)),
2048 (Literal::Float64(lhs), Literal::Int128(rhs)) => Some(compare_f64(*lhs, *rhs as f64, op)),
2049 (Literal::Boolean(lhs), Literal::Boolean(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
2050 (Literal::String(lhs), Literal::String(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
2051 (Literal::Decimal128(lhs), Literal::Decimal128(rhs)) => {
2052 llkv_compute::scalar::decimal::compare(*lhs, *rhs)
2053 .ok()
2054 .map(|ord| ordering_result(ord, op))
2055 }
2056 (Literal::Decimal128(lhs), Literal::Int128(rhs)) => {
2057 DecimalValue::new(*rhs, 0).ok().and_then(|rhs_dec| {
2058 llkv_compute::scalar::decimal::compare(*lhs, rhs_dec)
2059 .ok()
2060 .map(|ord| ordering_result(ord, op))
2061 })
2062 }
2063 (Literal::Int128(lhs), Literal::Decimal128(rhs)) => {
2064 DecimalValue::new(*lhs, 0).ok().and_then(|lhs_dec| {
2065 llkv_compute::scalar::decimal::compare(lhs_dec, *rhs)
2066 .ok()
2067 .map(|ord| ordering_result(ord, op))
2068 })
2069 }
2070 (Literal::Decimal128(lhs), Literal::Float64(rhs)) => {
2071 Some(compare_f64(lhs.to_f64(), *rhs, op))
2072 }
2073 (Literal::Float64(lhs), Literal::Decimal128(rhs)) => {
2074 Some(compare_f64(*lhs, rhs.to_f64(), op))
2075 }
2076 (Literal::Struct(_), _) | (_, Literal::Struct(_)) => None,
2077 _ => None,
2078 }
2079}
2080
2081fn build_no_match_join_batches<P>(
2082 join_type: llkv_join::JoinType,
2083 left_ref: &llkv_plan::TableRef,
2084 left_table: &ExecutorTable<P>,
2085 right_ref: &llkv_plan::TableRef,
2086 right_table: &ExecutorTable<P>,
2087 combined_schema: Arc<Schema>,
2088) -> ExecutorResult<Vec<RecordBatch>>
2089where
2090 P: Pager<Blob = EntryHandle> + Send + Sync,
2091{
2092 match join_type {
2093 llkv_join::JoinType::Inner => Ok(Vec::new()),
2094 llkv_join::JoinType::Left => {
2095 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
2096 let mut results = Vec::new();
2097
2098 for left_batch in left_batches {
2099 let row_count = left_batch.num_rows();
2100 if row_count == 0 {
2101 continue;
2102 }
2103
2104 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2105 columns.extend(left_batch.columns().iter().cloned());
2106 for column in &right_table.schema.columns {
2107 columns.push(new_null_array(&column.data_type, row_count));
2108 }
2109
2110 let batch =
2111 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2112 Error::Internal(format!("failed to build LEFT JOIN fallback batch: {err}"))
2113 })?;
2114 results.push(batch);
2115 }
2116
2117 Ok(results)
2118 }
2119 llkv_join::JoinType::Right => {
2120 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
2121 let mut results = Vec::new();
2122
2123 for right_batch in right_batches {
2124 let row_count = right_batch.num_rows();
2125 if row_count == 0 {
2126 continue;
2127 }
2128
2129 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2130 for column in &left_table.schema.columns {
2131 columns.push(new_null_array(&column.data_type, row_count));
2132 }
2133 columns.extend(right_batch.columns().iter().cloned());
2134
2135 let batch =
2136 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2137 Error::Internal(format!("failed to build RIGHT JOIN fallback batch: {err}"))
2138 })?;
2139 results.push(batch);
2140 }
2141
2142 Ok(results)
2143 }
2144 llkv_join::JoinType::Full => {
2145 let mut results = Vec::new();
2146
2147 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
2148 for left_batch in left_batches {
2149 let row_count = left_batch.num_rows();
2150 if row_count == 0 {
2151 continue;
2152 }
2153
2154 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2155 columns.extend(left_batch.columns().iter().cloned());
2156 for column in &right_table.schema.columns {
2157 columns.push(new_null_array(&column.data_type, row_count));
2158 }
2159
2160 let batch =
2161 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2162 Error::Internal(format!(
2163 "failed to build FULL JOIN left fallback batch: {err}"
2164 ))
2165 })?;
2166 results.push(batch);
2167 }
2168
2169 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
2170 for right_batch in right_batches {
2171 let row_count = right_batch.num_rows();
2172 if row_count == 0 {
2173 continue;
2174 }
2175
2176 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2177 for column in &left_table.schema.columns {
2178 columns.push(new_null_array(&column.data_type, row_count));
2179 }
2180 columns.extend(right_batch.columns().iter().cloned());
2181
2182 let batch =
2183 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2184 Error::Internal(format!(
2185 "failed to build FULL JOIN right fallback batch: {err}"
2186 ))
2187 })?;
2188 results.push(batch);
2189 }
2190
2191 Ok(results)
2192 }
2193 other => Err(Error::InvalidArgumentError(format!(
2194 "{other:?} join type is not supported when join predicate is unsatisfiable",
2195 ))),
2196 }
2197}
2198
2199fn scan_all_columns_for_join<P>(
2200 table_ref: &llkv_plan::TableRef,
2201 table: &ExecutorTable<P>,
2202) -> ExecutorResult<Vec<RecordBatch>>
2203where
2204 P: Pager<Blob = EntryHandle> + Send + Sync,
2205{
2206 if table.schema.columns.is_empty() {
2207 return Err(Error::InvalidArgumentError(format!(
2208 "table '{}' has no columns; joins require at least one column",
2209 table_ref.qualified_name()
2210 )));
2211 }
2212
2213 let mut projections = Vec::with_capacity(table.schema.columns.len());
2214 for column in &table.schema.columns {
2215 projections.push(ScanProjection::from(StoreProjection::with_alias(
2216 LogicalFieldId::for_user(table.table_id(), column.field_id),
2217 column.name.clone(),
2218 )));
2219 }
2220
2221 let filter_field = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);
2222 let filter_expr = full_table_scan_filter(filter_field);
2223
2224 let mut batches = Vec::new();
2225 let mut on_batch = |batch| {
2226 batches.push(batch);
2227 };
2228 table.storage().scan_stream(
2229 &projections,
2230 &filter_expr,
2231 ScanStreamOptions {
2232 include_nulls: true,
2233 include_row_ids: true,
2234 ..ScanStreamOptions::default()
2235 },
2236 &mut on_batch,
2237 )?;
2238
2239 Ok(batches)
2240}
2241
2242fn build_join_column_lookup<P>(
2243 table_ref: &llkv_plan::TableRef,
2244 table: &ExecutorTable<P>,
2245) -> FxHashMap<String, FieldId>
2246where
2247 P: Pager<Blob = EntryHandle> + Send + Sync,
2248{
2249 let mut lookup = FxHashMap::default();
2250 let table_lower = table_ref.table.to_ascii_lowercase();
2251 let qualified_lower = table_ref.qualified_name().to_ascii_lowercase();
2252 let display_lower = table_ref.display_name().to_ascii_lowercase();
2253 let alias_lower = table_ref.alias.as_ref().map(|s| s.to_ascii_lowercase());
2254 let schema_lower = if table_ref.schema.is_empty() {
2255 None
2256 } else {
2257 Some(table_ref.schema.to_ascii_lowercase())
2258 };
2259
2260 for column in &table.schema.columns {
2261 let base = column.name.to_ascii_lowercase();
2262 let short = base.rsplit('.').next().unwrap_or(base.as_str()).to_string();
2263
2264 lookup.entry(short.clone()).or_insert(column.field_id);
2265 lookup.entry(base.clone()).or_insert(column.field_id);
2266
2267 lookup
2268 .entry(format!("{table_lower}.{short}"))
2269 .or_insert(column.field_id);
2270
2271 if display_lower != table_lower {
2272 lookup
2273 .entry(format!("{display_lower}.{short}"))
2274 .or_insert(column.field_id);
2275 }
2276
2277 if qualified_lower != table_lower {
2278 lookup
2279 .entry(format!("{qualified_lower}.{short}"))
2280 .or_insert(column.field_id);
2281 }
2282
2283 if let Some(schema) = &schema_lower {
2284 lookup
2285 .entry(format!("{schema}.{table_lower}.{short}"))
2286 .or_insert(column.field_id);
2287 if display_lower != table_lower {
2288 lookup
2289 .entry(format!("{schema}.{display_lower}.{short}"))
2290 .or_insert(column.field_id);
2291 }
2292 }
2293
2294 if let Some(alias) = &alias_lower {
2295 lookup
2296 .entry(format!("{alias}.{short}"))
2297 .or_insert(column.field_id);
2298 }
2299 }
2300
2301 lookup
2302}
2303
/// Identifies which side of a join a resolved column belongs to.
#[derive(Clone, Copy)]
enum JoinColumnSide {
    Left,
    Right,
}
2309
2310fn resolve_join_column(
2311 column: &str,
2312 left_lookup: &FxHashMap<String, FieldId>,
2313 right_lookup: &FxHashMap<String, FieldId>,
2314) -> ExecutorResult<(JoinColumnSide, FieldId)> {
2315 let key = column.to_ascii_lowercase();
2316 match (left_lookup.get(&key), right_lookup.get(&key)) {
2317 (Some(&field_id), None) => Ok((JoinColumnSide::Left, field_id)),
2318 (None, Some(&field_id)) => Ok((JoinColumnSide::Right, field_id)),
2319 (Some(_), Some(_)) => Err(Error::InvalidArgumentError(format!(
2320 "join column '{column}' is ambiguous; qualify it with a table name or alias",
2321 ))),
2322 (None, None) => Err(Error::InvalidArgumentError(format!(
2323 "join column '{column}' was not found in either table",
2324 ))),
2325 }
2326}
2327
/// Execute an equi-hash join between two fully materialized batch sets.
///
/// `condition` is analyzed once:
/// * always-true  -> cartesian product of the two inputs (parallel cross join)
/// * always-false -> empty result for INNER/SEMI/ANTI; null-extended rows for
///   the preserved side(s) of LEFT/RIGHT/FULL
/// * equi pairs   -> hash table built over the right side, probed per left
///   batch in parallel
///
/// NOTE(review): join keys are collapsed to `i64` via
/// `extract_key_value_as_i64`, so string keys rely on a hash digest — see the
/// collision caveat on that function. Only INNER and LEFT outputs are
/// materialized by `build_join_output_batch`; other join types error there.
fn execute_hash_join_batches(
    left_schema: &Arc<Schema>,
    left_batches: &[RecordBatch],
    right_schema: &Arc<Schema>,
    right_batches: &[RecordBatch],
    condition: &LlkvExpr<'static, String>,
    join_type: llkv_join::JoinType,
) -> ExecutorResult<Vec<RecordBatch>> {
    let equalities = match analyze_join_condition(condition)? {
        JoinConditionAnalysis::AlwaysTrue => {
            // TRUE predicate: the join degenerates to a cross product.
            let results: Vec<RecordBatch> = left_batches
                .par_iter()
                .flat_map(|left| {
                    right_batches
                        .par_iter()
                        .map(move |right| execute_cross_join_batches(left, right))
                })
                .collect::<ExecutorResult<Vec<RecordBatch>>>()?;
            return Ok(results);
        }
        JoinConditionAnalysis::AlwaysFalse => {
            // FALSE predicate: no row pair ever matches. Outer joins still
            // emit their preserved side padded with NULLs.
            let combined_fields: Vec<Field> = left_schema
                .fields()
                .iter()
                .chain(right_schema.fields().iter())
                .map(|f| Field::new(f.name().clone(), f.data_type().clone(), f.is_nullable()))
                .collect();
            let combined_schema = Arc::new(Schema::new(combined_fields));

            let mut results = Vec::new();
            match join_type {
                llkv_join::JoinType::Inner
                | llkv_join::JoinType::Semi
                | llkv_join::JoinType::Anti => {
                    results.push(RecordBatch::new_empty(combined_schema));
                }
                llkv_join::JoinType::Left => {
                    for left_batch in left_batches {
                        let row_count = left_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        columns.extend(left_batch.columns().iter().cloned());
                        for field in right_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns)
                                .map_err(|err| {
                                    Error::Internal(format!(
                                        "failed to materialize LEFT JOIN null-extension batch: {err}"
                                    ))
                                })?,
                        );
                    }
                    if results.is_empty() {
                        results.push(RecordBatch::new_empty(combined_schema));
                    }
                }
                llkv_join::JoinType::Right => {
                    for right_batch in right_batches {
                        let row_count = right_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        for field in left_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        columns.extend(right_batch.columns().iter().cloned());
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns)
                                .map_err(|err| {
                                    Error::Internal(format!(
                                        "failed to materialize RIGHT JOIN null-extension batch: {err}"
                                    ))
                                })?,
                        );
                    }
                    if results.is_empty() {
                        results.push(RecordBatch::new_empty(combined_schema));
                    }
                }
                llkv_join::JoinType::Full => {
                    // FULL with a false predicate: every left row and every
                    // right row appears exactly once, padded with NULLs.
                    for left_batch in left_batches {
                        let row_count = left_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        columns.extend(left_batch.columns().iter().cloned());
                        for field in right_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(
                                |err| {
                                    Error::Internal(format!(
                                        "failed to materialize FULL JOIN left batch: {err}"
                                    ))
                                },
                            )?,
                        );
                    }

                    for right_batch in right_batches {
                        let row_count = right_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        for field in left_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        columns.extend(right_batch.columns().iter().cloned());
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(
                                |err| {
                                    Error::Internal(format!(
                                        "failed to materialize FULL JOIN right batch: {err}"
                                    ))
                                },
                            )?,
                        );
                    }

                    if results.is_empty() {
                        results.push(RecordBatch::new_empty(combined_schema));
                    }
                }
            }

            return Ok(results);
        }
        JoinConditionAnalysis::EquiPairs(pairs) => pairs,
    };

    // Case-insensitive column-name -> batch-column-index maps used to resolve
    // the equi-key column references on each side.
    let mut left_lookup: FxHashMap<String, usize> = FxHashMap::default();
    for (idx, field) in left_schema.fields().iter().enumerate() {
        left_lookup.insert(field.name().to_ascii_lowercase(), idx);
    }

    let mut right_lookup: FxHashMap<String, usize> = FxHashMap::default();
    for (idx, field) in right_schema.fields().iter().enumerate() {
        right_lookup.insert(field.name().to_ascii_lowercase(), idx);
    }

    let mut left_key_indices = Vec::new();
    let mut right_key_indices = Vec::new();

    for (lhs_col, rhs_col) in equalities {
        let lhs_lower = lhs_col.to_ascii_lowercase();
        let rhs_lower = rhs_col.to_ascii_lowercase();

        // Each equality may be written `left = right` or `right = left`;
        // try the direct orientation first, then the swapped one.
        let (left_idx, right_idx) =
            match (left_lookup.get(&lhs_lower), right_lookup.get(&rhs_lower)) {
                (Some(&l), Some(&r)) => (l, r),
                (Some(_), None) => {
                    if left_lookup.contains_key(&rhs_lower) {
                        return Err(Error::InvalidArgumentError(format!(
                            "Both join columns '{}' and '{}' are from left table",
                            lhs_col, rhs_col
                        )));
                    }
                    return Err(Error::InvalidArgumentError(format!(
                        "Join column '{}' not found in right table",
                        rhs_col
                    )));
                }
                (None, Some(_)) => {
                    if right_lookup.contains_key(&lhs_lower) {
                        return Err(Error::InvalidArgumentError(format!(
                            "Both join columns '{}' and '{}' are from right table",
                            lhs_col, rhs_col
                        )));
                    }
                    return Err(Error::InvalidArgumentError(format!(
                        "Join column '{}' not found in left table",
                        lhs_col
                    )));
                }
                (None, None) => {
                    match (left_lookup.get(&rhs_lower), right_lookup.get(&lhs_lower)) {
                        (Some(&l), Some(&r)) => (l, r),
                        _ => {
                            return Err(Error::InvalidArgumentError(format!(
                                "Join columns '{}' and '{}' not found in either table",
                                lhs_col, rhs_col
                            )));
                        }
                    }
                }
            };

        left_key_indices.push(left_idx);
        right_key_indices.push(right_idx);
    }

    // Build phase: map every non-null right-side key to its (batch, row)
    // locations. Rows with a NULL in any key column never match (SQL
    // semantics), so they are skipped.
    let mut hash_table: FxHashMap<Vec<i64>, Vec<(usize, usize)>> = FxHashMap::default();

    for (batch_idx, right_batch) in right_batches.iter().enumerate() {
        let num_rows = right_batch.num_rows();
        if num_rows == 0 {
            continue;
        }

        let key_columns: Vec<&ArrayRef> = right_key_indices
            .iter()
            .map(|&idx| right_batch.column(idx))
            .collect();

        for row_idx in 0..num_rows {
            let mut key_values = Vec::with_capacity(key_columns.len());
            let mut has_null = false;

            for col in &key_columns {
                if col.is_null(row_idx) {
                    has_null = true;
                    break;
                }
                let value = extract_key_value_as_i64(col, row_idx)?;
                key_values.push(value);
            }

            if has_null {
                continue;
            }

            hash_table
                .entry(key_values)
                .or_default()
                .push((batch_idx, row_idx));
        }
    }

    // Output schema: every field is forced nullable because outer joins may
    // pad either side with NULLs.
    let combined_fields: Vec<Field> = left_schema
        .fields()
        .iter()
        .chain(right_schema.fields().iter())
        .map(|f| Field::new(f.name().clone(), f.data_type().clone(), true)) .collect();
    let combined_schema = Arc::new(Schema::new(combined_fields));

    // Probe phase: each left batch is probed independently (in parallel) and
    // yields at most one output batch.
    let mut result_batches: Vec<RecordBatch> = left_batches
        .par_iter()
        .map(|left_batch| -> ExecutorResult<Option<RecordBatch>> {
            let num_rows = left_batch.num_rows();
            if num_rows == 0 {
                return Ok(None);
            }

            let left_key_columns: Vec<&ArrayRef> = left_key_indices
                .iter()
                .map(|&idx| left_batch.column(idx))
                .collect();

            // Tracks which left rows found at least one match (LEFT JOIN
            // null-extension needs this).
            let mut left_matched = vec![false; num_rows];

            let mut left_indices = Vec::new();
            let mut right_refs = Vec::new();

            for (left_row_idx, matched) in left_matched.iter_mut().enumerate() {
                let mut key_values = Vec::with_capacity(left_key_columns.len());
                let mut has_null = false;

                for col in &left_key_columns {
                    if col.is_null(left_row_idx) {
                        has_null = true;
                        break;
                    }
                    let value = extract_key_value_as_i64(col, left_row_idx)?;
                    key_values.push(value);
                }

                if has_null {
                    continue;
                }

                if let Some(right_rows) = hash_table.get(&key_values) {
                    *matched = true;
                    for &(right_batch_idx, right_row_idx) in right_rows {
                        left_indices.push(left_row_idx as u32);
                        right_refs.push((right_batch_idx, right_row_idx));
                    }
                }
            }

            // LEFT JOIN must materialize a batch even with zero matches so
            // unmatched left rows can be null-extended.
            if !left_indices.is_empty() || join_type == llkv_join::JoinType::Left {
                let output_batch = build_join_output_batch(
                    left_batch,
                    right_batches,
                    &left_indices,
                    &right_refs,
                    &left_matched,
                    &combined_schema,
                    join_type,
                )?;

                if output_batch.num_rows() > 0 {
                    return Ok(Some(output_batch));
                }
            }
            Ok(None)
        })
        .collect::<ExecutorResult<Vec<Option<RecordBatch>>>>()?
        .into_iter()
        .flatten()
        .collect();

    if result_batches.is_empty() {
        result_batches.push(RecordBatch::new_empty(combined_schema));
    }

    Ok(result_batches)
}
2675
/// Collapse one join-key cell into an `i64` so multi-column keys can be
/// compared as `Vec<i64>` in the hash table.
///
/// Integer widths are widened into `i64`; strings are reduced to a
/// `DefaultHasher` digest.
///
/// NOTE(review): this encoding is lossy — two distinct strings can hash to
/// the same `i64`, and `u64` values above `i64::MAX` wrap negative (so e.g.
/// `u64::MAX` keys compare equal to `i64` `-1`). Collisions would surface as
/// false join matches; confirm whether upstream key typing rules this out.
///
/// Callers are expected to skip NULL cells before calling (the build/probe
/// loops do).
fn extract_key_value_as_i64(col: &ArrayRef, row_idx: usize) -> ExecutorResult<i64> {
    use arrow::array::*;
    use arrow::datatypes::DataType;

    match col.data_type() {
        DataType::Int8 => Ok(col
            .as_any()
            .downcast_ref::<Int8Array>()
            .unwrap()
            .value(row_idx) as i64),
        DataType::Int16 => Ok(col
            .as_any()
            .downcast_ref::<Int16Array>()
            .unwrap()
            .value(row_idx) as i64),
        DataType::Int32 => Ok(col
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap()
            .value(row_idx) as i64),
        DataType::Int64 => Ok(col
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(row_idx)),
        DataType::UInt8 => Ok(col
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap()
            .value(row_idx) as i64),
        DataType::UInt16 => Ok(col
            .as_any()
            .downcast_ref::<UInt16Array>()
            .unwrap()
            .value(row_idx) as i64),
        DataType::UInt32 => Ok(col
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap()
            .value(row_idx) as i64),
        DataType::UInt64 => {
            let val = col
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap()
                .value(row_idx);
            // Wrapping cast: values > i64::MAX become negative (see note above).
            Ok(val as i64) }
        DataType::Utf8 => {
            let s = col
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap()
                .value(row_idx);
            // DefaultHasher with default keys is deterministic within a
            // process, so equal strings agree across build and probe sides;
            // distinct strings may still collide (see note above).
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut hasher = DefaultHasher::new();
            s.hash(&mut hasher);
            Ok(hasher.finish() as i64)
        }
        _ => Err(Error::InvalidArgumentError(format!(
            "Unsupported join key type: {:?}",
            col.data_type()
        ))),
    }
}
2744
2745fn build_join_output_batch(
2747 left_batch: &RecordBatch,
2748 right_batches: &[RecordBatch],
2749 left_indices: &[u32],
2750 right_refs: &[(usize, usize)],
2751 left_matched: &[bool],
2752 combined_schema: &Arc<Schema>,
2753 join_type: llkv_join::JoinType,
2754) -> ExecutorResult<RecordBatch> {
2755 use arrow::array::UInt32Array;
2756 use arrow::compute::take;
2757
2758 match join_type {
2759 llkv_join::JoinType::Inner => {
2760 let left_indices_array = UInt32Array::from(left_indices.to_vec());
2762
2763 let mut output_columns = Vec::new();
2764
2765 for col in left_batch.columns() {
2767 let taken = take(col.as_ref(), &left_indices_array, None)
2768 .map_err(|e| Error::Internal(format!("Failed to take left column: {}", e)))?;
2769 output_columns.push(taken);
2770 }
2771
2772 for right_col_idx in 0..right_batches[0].num_columns() {
2774 let mut values = Vec::with_capacity(right_refs.len());
2775 for &(batch_idx, row_idx) in right_refs {
2776 let col = right_batches[batch_idx].column(right_col_idx);
2777 values.push((col.clone(), row_idx));
2778 }
2779
2780 let right_col = gather_from_multiple_batches(
2782 &values,
2783 right_batches[0].column(right_col_idx).data_type(),
2784 )?;
2785 output_columns.push(right_col);
2786 }
2787
2788 RecordBatch::try_new(Arc::clone(combined_schema), output_columns)
2789 .map_err(|e| Error::Internal(format!("Failed to create output batch: {}", e)))
2790 }
2791 llkv_join::JoinType::Left => {
2792 let mut output_columns = Vec::new();
2794
2795 for col in left_batch.columns() {
2797 output_columns.push(col.clone());
2798 }
2799
2800 for right_col_idx in 0..right_batches[0].num_columns() {
2802 let right_col = build_left_join_column(
2803 left_matched,
2804 right_batches,
2805 right_col_idx,
2806 left_indices,
2807 right_refs,
2808 )?;
2809 output_columns.push(right_col);
2810 }
2811
2812 RecordBatch::try_new(Arc::clone(combined_schema), output_columns)
2813 .map_err(|e| Error::Internal(format!("Failed to create left join batch: {}", e)))
2814 }
2815 _ => Err(Error::InvalidArgumentError(format!(
2816 "{:?} join not yet implemented in batch join",
2817 join_type
2818 ))),
2819 }
2820}
2821
2822fn gather_from_multiple_batches(
2827 values: &[(ArrayRef, usize)],
2828 _data_type: &DataType,
2829) -> ExecutorResult<ArrayRef> {
2830 use arrow::array::*;
2831 use arrow::compute::take;
2832
2833 if values.is_empty() {
2834 return Ok(new_null_array(&DataType::Null, 0));
2835 }
2836
2837 if values.len() > 1 {
2839 let first_array_ptr = Arc::as_ptr(&values[0].0);
2840 let all_same_array = values
2841 .iter()
2842 .all(|(arr, _)| std::ptr::addr_eq(Arc::as_ptr(arr), first_array_ptr));
2843
2844 if all_same_array {
2845 let indices: Vec<u32> = values.iter().map(|(_, idx)| *idx as u32).collect();
2848 let indices_array = UInt32Array::from(indices);
2849 return take(values[0].0.as_ref(), &indices_array, None)
2850 .map_err(|e| Error::Internal(format!("Arrow take failed: {}", e)));
2851 }
2852 }
2853
2854 use arrow::compute::concat;
2857
2858 let mut unique_arrays: Vec<(Arc<dyn Array>, Vec<usize>)> = Vec::new();
2860 let mut array_map: FxHashMap<*const dyn Array, usize> = FxHashMap::default();
2861
2862 for (arr, row_idx) in values {
2863 let ptr = Arc::as_ptr(arr);
2864 if let Some(&idx) = array_map.get(&ptr) {
2865 unique_arrays[idx].1.push(*row_idx);
2866 } else {
2867 let idx = unique_arrays.len();
2868 array_map.insert(ptr, idx);
2869 unique_arrays.push((Arc::clone(arr), vec![*row_idx]));
2870 }
2871 }
2872
2873 if unique_arrays.len() == 1 {
2875 let (arr, indices) = &unique_arrays[0];
2876 let indices_u32: Vec<u32> = indices.iter().map(|&i| i as u32).collect();
2877 let indices_array = UInt32Array::from(indices_u32);
2878 return take(arr.as_ref(), &indices_array, None)
2879 .map_err(|e| Error::Internal(format!("Arrow take failed: {}", e)));
2880 }
2881
2882 let arrays_to_concat: Vec<&dyn Array> =
2884 unique_arrays.iter().map(|(arr, _)| arr.as_ref()).collect();
2885
2886 let concatenated = concat(&arrays_to_concat)
2887 .map_err(|e| Error::Internal(format!("Arrow concat failed: {}", e)))?;
2888
2889 let mut offset = 0;
2891 let mut adjusted_indices = Vec::with_capacity(values.len());
2892 for (arr, _) in &unique_arrays {
2893 let arr_len = arr.len();
2894 for (check_arr, row_idx) in values {
2895 if Arc::ptr_eq(arr, check_arr) {
2896 adjusted_indices.push((offset + row_idx) as u32);
2897 }
2898 }
2899 offset += arr_len;
2900 }
2901
2902 let indices_array = UInt32Array::from(adjusted_indices);
2903 take(&concatenated, &indices_array, None)
2904 .map_err(|e| Error::Internal(format!("Arrow take on concatenated failed: {}", e)))
2905}
2906
/// Produce one right-side column for the LEFT JOIN output path.
///
/// NOTE(review): this is a placeholder — it returns an all-null column of
/// `left_matched.len()` rows and ignores `_left_indices`/`_right_refs`, so
/// matched left rows lose their right-side values entirely. A correct
/// implementation needs to gather the matched right values (and expand
/// multi-match rows), which cannot be expressed within this per-column
/// signature alone.
///
/// Panics if `right_batches` is empty (indexes `right_batches[0]`).
fn build_left_join_column(
    left_matched: &[bool],
    right_batches: &[RecordBatch],
    right_col_idx: usize,
    _left_indices: &[u32],
    _right_refs: &[(usize, usize)],
) -> ExecutorResult<ArrayRef> {
    let data_type = right_batches[0].column(right_col_idx).data_type();
    Ok(new_null_array(data_type, left_matched.len()))
}
2920
2921fn execute_cross_join_batches(
2923 left: &RecordBatch,
2924 right: &RecordBatch,
2925) -> ExecutorResult<RecordBatch> {
2926 let combined_fields: Vec<Field> = left
2927 .schema()
2928 .fields()
2929 .iter()
2930 .chain(right.schema().fields().iter())
2931 .map(|f| Field::new(f.name().clone(), f.data_type().clone(), f.is_nullable()))
2932 .collect();
2933 let combined_schema = Arc::new(Schema::new(combined_fields));
2934
2935 cross_join_pair(left, right, &combined_schema)
2936}
2937
2938#[allow(dead_code)]
2940fn build_temp_table_from_batches<P>(
2941 _schema: &Arc<Schema>,
2942 _batches: &[RecordBatch],
2943) -> ExecutorResult<llkv_table::Table<P>>
2944where
2945 P: Pager<Blob = EntryHandle> + Send + Sync,
2946{
2947 Err(Error::Internal(
2949 "build_temp_table_from_batches should not be called".into(),
2950 ))
2951}
2952
2953#[allow(dead_code)]
2955fn build_join_keys_from_condition_indexed(
2956 _condition: &LlkvExpr<'static, String>,
2957 _left_data: &TableCrossProductData,
2958 _right_data: &TableCrossProductData,
2959 _left_idx: usize,
2960 _right_idx: usize,
2961) -> ExecutorResult<JoinKeyBuild> {
2962 Err(Error::Internal(
2964 "build_join_keys_from_condition_indexed should not be called".into(),
2965 ))
2966}
2967
/// Unit tests covering `analyze_join_condition`: extraction of equi-join
/// pairs and constant-folding of predicates that are always true or false.
#[cfg(test)]
mod join_condition_tests {
    use super::*;
    use llkv_expr::expr::{CompareOp, ScalarExpr};
    use llkv_expr::literal::Literal;

    // `t1.col = t2.col` should analyze to exactly one equi pair.
    #[test]
    fn analyze_detects_simple_equality() {
        let expr = LlkvExpr::Compare {
            left: ScalarExpr::Column("t1.col".into()),
            op: CompareOp::Eq,
            right: ScalarExpr::Column("t2.col".into()),
        };

        match analyze_join_condition(&expr).expect("analysis succeeds") {
            JoinConditionAnalysis::EquiPairs(pairs) => {
                assert_eq!(pairs, vec![("t1.col".to_string(), "t2.col".to_string())]);
            }
            other => panic!("unexpected analysis result: {other:?}"),
        }
    }

    // A literal TRUE condition degenerates into a cross join.
    #[test]
    fn analyze_handles_literal_true() {
        let expr = LlkvExpr::Literal(true);
        assert!(matches!(
            analyze_join_condition(&expr).expect("analysis succeeds"),
            JoinConditionAnalysis::AlwaysTrue
        ));
    }

    // Non-equality comparisons are rejected by the equi-join analyzer.
    #[test]
    fn analyze_rejects_non_equality() {
        let expr = LlkvExpr::Compare {
            left: ScalarExpr::Column("t1.col".into()),
            op: CompareOp::Gt,
            right: ScalarExpr::Column("t2.col".into()),
        };
        assert!(analyze_join_condition(&expr).is_err());
    }

    // `NULL IS NOT NULL` folds to an always-false join predicate.
    #[test]
    fn analyze_handles_constant_is_not_null() {
        let expr = LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Null),
            negated: true,
        };

        assert!(matches!(
            analyze_join_condition(&expr).expect("analysis succeeds"),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    // NOT (<non-null literal> IS NOT NULL) is likewise always false.
    #[test]
    fn analyze_handles_not_applied_to_is_not_null() {
        let expr = LlkvExpr::Not(Box::new(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Int128(86)),
            negated: true,
        }));

        assert!(matches!(
            analyze_join_condition(&expr).expect("analysis succeeds"),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    // `<non-null literal> IS NULL` can never hold.
    #[test]
    fn analyze_literal_is_null_is_always_false() {
        let expr = LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Int128(1)),
            negated: false,
        };

        assert!(matches!(
            analyze_join_condition(&expr).expect("analysis succeeds"),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    // A comparison against NULL is NULL, so its negation is still not true.
    #[test]
    fn analyze_not_null_comparison_is_always_false() {
        let expr = LlkvExpr::Not(Box::new(LlkvExpr::Compare {
            left: ScalarExpr::Literal(Literal::Null),
            op: CompareOp::Lt,
            right: ScalarExpr::Column("t2.col".into()),
        }));

        assert!(matches!(
            analyze_join_condition(&expr).expect("analysis succeeds"),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }
}
3062
/// Tests for `execute_cross_join_batches` over small in-memory batches.
#[cfg(test)]
mod cross_join_batch_tests {
    use super::*;
    use arrow::array::Int32Array;

    // 2 left rows x 3 right rows must yield all 6 combinations, with the
    // left value varying slowest.
    #[test]
    fn execute_cross_join_batches_emits_full_cartesian_product() {
        let left_schema = Arc::new(Schema::new(vec![Field::new("l", DataType::Int32, false)]));
        let right_schema = Arc::new(Schema::new(vec![Field::new("r", DataType::Int32, false)]));

        let left_batch = RecordBatch::try_new(
            Arc::clone(&left_schema),
            vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
        )
        .expect("left batch");
        let right_batch = RecordBatch::try_new(
            Arc::clone(&right_schema),
            vec![Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef],
        )
        .expect("right batch");

        let result = execute_cross_join_batches(&left_batch, &right_batch).expect("cross join");

        assert_eq!(result.num_rows(), 6);
        assert_eq!(result.num_columns(), 2);

        let left_values: Vec<i32> = {
            let array = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            (0..array.len()).map(|idx| array.value(idx)).collect()
        };
        let right_values: Vec<i32> = {
            let array = result
                .column(1)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            (0..array.len()).map(|idx| array.value(idx)).collect()
        };

        assert_eq!(left_values, vec![1, 1, 1, 2, 2, 2]);
        assert_eq!(right_values, vec![10, 20, 30, 10, 20, 30]);
    }
}
3110
3111impl<P> QueryExecutor<P>
3112where
3113 P: Pager<Blob = EntryHandle> + Send + Sync,
3114{
    /// Evaluate plain aggregate projections (COUNT/SUM/MIN/MAX/...) over the
    /// pre-joined `batches` of a cross-product query.
    ///
    /// Column names are resolved case-insensitively via `column_lookup_map`
    /// into positions of `combined_schema`; those positions are used as the
    /// aggregate `field_id`s (batch column indices, not table field ids).
    /// All batches fold into one accumulator set, so the output is a
    /// single-row batch; DISTINCT and ORDER BY from `plan` are applied to
    /// that batch afterwards.
    ///
    /// Errors when the plan carries scalar subqueries or has no aggregates.
    fn execute_cross_product_aggregates(
        &self,
        combined_schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
        column_lookup_map: &FxHashMap<String, usize>,
        plan: &SelectPlan,
        display_name: &str,
    ) -> ExecutorResult<SelectExecution<P>> {
        if !plan.scalar_subqueries.is_empty() {
            return Err(Error::InvalidArgumentError(
                "scalar subqueries not supported in aggregate joins".into(),
            ));
        }

        // Translate each plan aggregate into an AggregateSpec plus the batch
        // column index it reads from (None for COUNT(*)).
        let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
        let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(plan.aggregates.len());

        for aggregate in &plan.aggregates {
            match aggregate {
                AggregateExpr::CountStar { alias, distinct } => {
                    specs.push(AggregateSpec {
                        alias: alias.clone(),
                        kind: AggregateKind::Count {
                            field_id: None,
                            distinct: *distinct,
                        },
                    });
                    spec_to_projection.push(None);
                }
                AggregateExpr::Column {
                    column,
                    alias,
                    function,
                    distinct,
                } => {
                    let key = column.to_ascii_lowercase();
                    let column_index = *column_lookup_map.get(&key).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "unknown column '{column}' in aggregate"
                        ))
                    })?;
                    let field = combined_schema.field(column_index);
                    // Numeric aggregates validate their input type up front
                    // so type errors surface before any batch is scanned.
                    let kind = match function {
                        AggregateFunction::Count => AggregateKind::Count {
                            field_id: Some(column_index as u32),
                            distinct: *distinct,
                        },
                        AggregateFunction::SumInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "SUM",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Sum {
                                field_id: column_index as u32,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateFunction::TotalInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "TOTAL",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Total {
                                field_id: column_index as u32,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateFunction::MinInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "MIN",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Min {
                                field_id: column_index as u32,
                                data_type: input_type,
                            }
                        }
                        AggregateFunction::MaxInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "MAX",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Max {
                                field_id: column_index as u32,
                                data_type: input_type,
                            }
                        }
                        AggregateFunction::CountNulls => AggregateKind::CountNulls {
                            field_id: column_index as u32,
                        },
                        AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
                            field_id: column_index as u32,
                            distinct: *distinct,
                            separator: ",".to_string(),
                        },
                    };

                    specs.push(AggregateSpec {
                        alias: alias.clone(),
                        kind,
                    });
                    spec_to_projection.push(Some(column_index));
                }
            }
        }

        if specs.is_empty() {
            return Err(Error::InvalidArgumentError(
                "aggregate query requires at least one aggregate expression".into(),
            ));
        }

        // One accumulator per spec, bound to its batch column index.
        let mut states = Vec::with_capacity(specs.len());
        for (idx, spec) in specs.iter().enumerate() {
            states.push(AggregateState {
                alias: spec.alias.clone(),
                accumulator: AggregateAccumulator::new_with_projection_index(
                    spec,
                    spec_to_projection[idx],
                    None,
                )?,
                override_value: None,
            });
        }

        // Fold every batch into every accumulator.
        for batch in &batches {
            for state in &mut states {
                state.update(batch)?;
            }
        }

        // Finalize accumulators into a single-row result batch.
        let mut fields = Vec::with_capacity(states.len());
        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
        for state in states {
            let (field, array) = state.finalize()?;
            fields.push(Arc::new(field));
            arrays.push(array);
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;

        if plan.distinct {
            let mut distinct_state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut distinct_state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        if !plan.order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            display_name.to_string(),
            schema,
            batch,
        ))
    }
3281
    /// Evaluate projections that embed aggregate calls inside scalar
    /// expressions (e.g. `SUM(a) + 1`) over pre-joined cross-product batches.
    ///
    /// Every projection must be `Computed`; the aggregates they reference are
    /// collected and evaluated once via
    /// `compute_cross_product_aggregate_values`, then each projection is
    /// either a direct aggregate (materialized with a type matching its
    /// `AggregateValue` variant) or a scalar expression folded over the
    /// aggregate results.
    ///
    /// NOTE(review): aggregate results are keyed by `format!("{:?}", agg)`,
    /// which assumes `Debug` output is stable and unique per aggregate call —
    /// confirm this matches the keying in
    /// `compute_cross_product_aggregate_values`. The non-aggregate fallback
    /// path always produces an Int64 column.
    fn execute_cross_product_computed_aggregates(
        &self,
        combined_schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
        column_lookup_map: &FxHashMap<String, usize>,
        plan: &SelectPlan,
        display_name: &str,
    ) -> ExecutorResult<SelectExecution<P>> {
        // Collect every aggregate call mentioned by the computed projections.
        let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
        for projection in &plan.projections {
            match projection {
                SelectProjection::Computed { expr, .. } => {
                    Self::collect_aggregates(expr, &mut aggregate_specs);
                }
                SelectProjection::AllColumns
                | SelectProjection::AllColumnsExcept { .. }
                | SelectProjection::Column { .. } => {
                    return Err(Error::InvalidArgumentError(
                        "non-computed projections not supported with aggregate expressions".into(),
                    ));
                }
            }
        }

        if aggregate_specs.is_empty() {
            return Err(Error::InvalidArgumentError(
                "computed aggregate query requires at least one aggregate expression".into(),
            ));
        }

        // Evaluate all aggregates once over the batches.
        let aggregate_values = self.compute_cross_product_aggregate_values(
            &combined_schema,
            &batches,
            column_lookup_map,
            &aggregate_specs,
        )?;

        let mut fields = Vec::with_capacity(plan.projections.len());
        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());

        for projection in &plan.projections {
            if let SelectProjection::Computed { expr, alias } = projection {
                // A projection that is exactly one aggregate call is emitted
                // directly with the aggregate's own type.
                if let ScalarExpr::Aggregate(agg) = expr {
                    let key = format!("{:?}", agg);
                    if let Some(agg_value) = aggregate_values.get(&key) {
                        match agg_value {
                            AggregateValue::Null => {
                                fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
                                arrays.push(Arc::new(Int64Array::from(vec![None::<i64>])) as ArrayRef);
                            }
                            AggregateValue::Int64(v) => {
                                fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
                                arrays.push(Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef);
                            }
                            AggregateValue::Float64(v) => {
                                fields.push(Arc::new(Field::new(alias, DataType::Float64, true)));
                                arrays
                                    .push(Arc::new(Float64Array::from(vec![Some(*v)])) as ArrayRef);
                            }
                            AggregateValue::Decimal128 { value, scale } => {
                                // Precision is sized from the decimal digits
                                // of the value itself (minimum 1).
                                let precision = if *value == 0 {
                                    1
                                } else {
                                    (*value).abs().to_string().len() as u8
                                };
                                fields.push(Arc::new(Field::new(
                                    alias,
                                    DataType::Decimal128(precision, *scale),
                                    true,
                                )));
                                let array = Decimal128Array::from(vec![Some(*value)])
                                    .with_precision_and_scale(precision, *scale)
                                    .map_err(|e| {
                                        Error::Internal(format!("invalid Decimal128: {}", e))
                                    })?;
                                arrays.push(Arc::new(array) as ArrayRef);
                            }
                            AggregateValue::String(s) => {
                                fields.push(Arc::new(Field::new(alias, DataType::Utf8, true)));
                                arrays
                                    .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
                                        as ArrayRef);
                            }
                        }
                        continue;
                    }
                }

                // Otherwise fold the scalar expression over the aggregate
                // results; the result is materialized as Int64.
                let value = Self::evaluate_expr_with_aggregates(expr, &aggregate_values)?;
                fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
                arrays.push(Arc::new(Int64Array::from(vec![value])) as ArrayRef);
            }
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;

        if plan.distinct {
            let mut distinct_state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut distinct_state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        if !plan.order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            display_name.to_string(),
            schema,
            batch,
        ))
    }
3400
    /// Runs the aggregate accumulators for a cross-product SELECT and returns
    /// the finalized scalar value of each aggregate, keyed by the same
    /// `format!("{:?}")` identity the caller uses to look results back up.
    ///
    /// Aggregates over a simple column reference read that column's index in
    /// `combined_schema` directly. Aggregates over arbitrary scalar expressions
    /// are first materialized as synthetic `__agg_expr_cp_N` columns appended
    /// to every batch, and the accumulators then address those by index.
    fn compute_cross_product_aggregate_values(
        &self,
        combined_schema: &Arc<Schema>,
        batches: &[RecordBatch],
        column_lookup_map: &FxHashMap<String, usize>,
        aggregate_specs: &[(String, AggregateCall<String>)],
    ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
        let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
        // Parallel to `specs`: the column index each accumulator reads (None for COUNT(*)).
        let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());

        // Lazily-built augmented view of the input, populated only when some
        // aggregate targets a computed expression rather than a plain column:
        // per-batch column vectors, the extended field list, and the rebuilt batches.
        let mut columns_per_batch: Option<Vec<Vec<ArrayRef>>> = None;
        let mut augmented_fields: Option<Vec<Field>> = None;
        let mut owned_batches: Option<Vec<RecordBatch>> = None;
        // Dedupes computed expressions by debug representation so each distinct
        // expression is materialized exactly once.
        let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
            FxHashMap::default();
        let mut computed_alias_counter: usize = 0;
        let mut expr_context = CrossProductExpressionContext::new(
            combined_schema.as_ref(),
            column_lookup_map.clone(),
        )?;

        // Materializes `expr` as an appended column on every batch (once per
        // distinct expression) and returns its column index and data type.
        let mut ensure_computed_column =
            |expr: &ScalarExpr<String>| -> ExecutorResult<(usize, DataType)> {
                let key = format!("{:?}", expr);
                if let Some((idx, dtype)) = computed_projection_cache.get(&key) {
                    return Ok((*idx, dtype.clone()));
                }

                // First computed expression: snapshot the original columns/fields.
                if columns_per_batch.is_none() {
                    let initial_columns: Vec<Vec<ArrayRef>> = batches
                        .iter()
                        .map(|batch| batch.columns().to_vec())
                        .collect();
                    columns_per_batch = Some(initial_columns);
                }
                if augmented_fields.is_none() {
                    augmented_fields = Some(
                        combined_schema
                            .fields()
                            .iter()
                            .map(|field| field.as_ref().clone())
                            .collect(),
                    );
                }

                let translated = translate_scalar(expr, expr_context.schema(), |name| {
                    Error::InvalidArgumentError(format!(
                        "unknown column '{}' in aggregate expression",
                        name
                    ))
                })?;
                let data_type = infer_computed_data_type(expr_context.schema(), &translated)?;

                // Evaluate the expression against each batch and append the result column.
                if let Some(columns) = columns_per_batch.as_mut() {
                    for (batch_idx, batch) in batches.iter().enumerate() {
                        expr_context.reset();
                        let array = expr_context.materialize_scalar_array(&translated, batch)?;
                        if let Some(batch_columns) = columns.get_mut(batch_idx) {
                            batch_columns.push(array);
                        }
                    }
                }

                // The new column lands at the end of the augmented field list.
                let column_index = augmented_fields
                    .as_ref()
                    .map(|fields| fields.len())
                    .unwrap_or_else(|| combined_schema.fields().len());

                let alias = format!("__agg_expr_cp_{}", computed_alias_counter);
                computed_alias_counter += 1;
                augmented_fields
                    .as_mut()
                    .expect("augmented fields initialized")
                    .push(Field::new(&alias, data_type.clone(), true));

                computed_projection_cache.insert(key, (column_index, data_type.clone()));
                Ok((column_index, data_type))
            };

        // Translate each requested aggregate into an AggregateSpec bound to a
        // concrete column index (or to no column, for COUNT(*)).
        for (key, agg) in aggregate_specs {
            match agg {
                AggregateCall::CountStar => {
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Count {
                            field_id: None,
                            distinct: false,
                        },
                    });
                    spec_to_projection.push(None);
                }
                AggregateCall::Count { expr, .. }
                | AggregateCall::Sum { expr, .. }
                | AggregateCall::Total { expr, .. }
                | AggregateCall::Avg { expr, .. }
                | AggregateCall::Min(expr)
                | AggregateCall::Max(expr)
                | AggregateCall::CountNulls(expr)
                | AggregateCall::GroupConcat { expr, .. } => {
                    // Bind the aggregate input: a bare column resolves through the
                    // (lower-cased) lookup map; anything else becomes a computed column.
                    let (column_index, data_type_opt) = if let Some(column) =
                        try_extract_simple_column(expr)
                    {
                        let key_lower = column.to_ascii_lowercase();
                        let column_index = *column_lookup_map.get(&key_lower).ok_or_else(|| {
                            Error::InvalidArgumentError(format!(
                                "unknown column '{column}' in aggregate"
                            ))
                        })?;
                        let field = combined_schema.field(column_index);
                        (column_index, Some(field.data_type().clone()))
                    } else {
                        let (index, dtype) = ensure_computed_column(expr)?;
                        (index, Some(dtype))
                    };

                    // Second match on the same call picks the accumulator kind;
                    // the `field_id` is the (checked) u32 projection index.
                    let kind = match agg {
                        AggregateCall::Count { distinct, .. } => {
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Count {
                                field_id: Some(field_id),
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Sum { distinct, .. } => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "SUM",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Sum {
                                field_id,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Total { distinct, .. } => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "TOTAL",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Total {
                                field_id,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Avg { distinct, .. } => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "AVG",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Avg {
                                field_id,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Min(_) => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "MIN",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Min {
                                field_id,
                                data_type: input_type,
                            }
                        }
                        AggregateCall::Max(_) => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "MAX",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Max {
                                field_id,
                                data_type: input_type,
                            }
                        }
                        AggregateCall::CountNulls(_) => {
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::CountNulls { field_id }
                        }
                        AggregateCall::GroupConcat {
                            distinct,
                            separator,
                            ..
                        } => {
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::GroupConcat {
                                field_id,
                                distinct: *distinct,
                                // SQL default separator for GROUP_CONCAT.
                                separator: separator.clone().unwrap_or_else(|| ",".to_string()),
                            }
                        }
                        // CountStar is handled by the outer match arm above.
                        _ => unreachable!(),
                    };

                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind,
                    });
                    spec_to_projection.push(Some(column_index));
                }
            }
        }

        // If computed columns were appended, rebuild the batches against the
        // augmented schema so accumulators can address them by index.
        if let Some(columns) = columns_per_batch {
            let fields = augmented_fields.unwrap_or_else(|| {
                combined_schema
                    .fields()
                    .iter()
                    .map(|field| field.as_ref().clone())
                    .collect()
            });
            let augmented_schema = Arc::new(Schema::new(fields));
            let mut new_batches = Vec::with_capacity(columns.len());
            for batch_columns in columns {
                let batch = RecordBatch::try_new(Arc::clone(&augmented_schema), batch_columns)
                    .map_err(|err| {
                        Error::InvalidArgumentError(format!(
                            "failed to materialize aggregate projections: {err}"
                        ))
                    })?;
                new_batches.push(batch);
            }
            owned_batches = Some(new_batches);
        }

        // One accumulator per spec, each bound to its projection index.
        let mut states = Vec::with_capacity(specs.len());
        for (idx, spec) in specs.iter().enumerate() {
            states.push(AggregateState {
                alias: spec.alias.clone(),
                accumulator: AggregateAccumulator::new_with_projection_index(
                    spec,
                    spec_to_projection[idx],
                    None,
                )?,
                override_value: None,
            });
        }

        // Feed either the augmented batches or the originals, never both.
        let batch_iter: &[RecordBatch] = if let Some(ref extended) = owned_batches {
            extended.as_slice()
        } else {
            batches
        };

        for batch in batch_iter {
            for state in &mut states {
                state.update(batch)?;
            }
        }

        // Finalize each accumulator into a single-row array and convert it to
        // the executor's scalar AggregateValue representation. The key is the
        // finalized field name, which carries the spec alias.
        let mut results = FxHashMap::default();
        for state in states {
            let (field, array) = state.finalize()?;

            if let Some(int_array) = array.as_any().downcast_ref::<Int64Array>() {
                if int_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        int_array.len()
                    )));
                }
                let value = if int_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Int64(int_array.value(0))
                };
                results.insert(field.name().to_string(), value);
            }
            else if let Some(float_array) = array.as_any().downcast_ref::<Float64Array>() {
                if float_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        float_array.len()
                    )));
                }
                let value = if float_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Float64(float_array.value(0))
                };
                results.insert(field.name().to_string(), value);
            }
            else if let Some(string_array) = array.as_any().downcast_ref::<StringArray>() {
                if string_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        string_array.len()
                    )));
                }
                let value = if string_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::String(string_array.value(0).to_string())
                };
                results.insert(field.name().to_string(), value);
            }
            else if let Some(decimal_array) = array.as_any().downcast_ref::<Decimal128Array>() {
                if decimal_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        decimal_array.len()
                    )));
                }
                let value = if decimal_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Decimal128 {
                        value: decimal_array.value(0),
                        scale: decimal_array.scale(),
                    }
                };
                results.insert(field.name().to_string(), value);
            } else {
                return Err(Error::Internal(format!(
                    "Unexpected array type from aggregate: {:?}",
                    array.data_type()
                )));
            }
        }

        Ok(results)
    }
3770
    /// Attempts to evaluate a multi-table FROM clause with hash joins instead
    /// of a full cartesian product.
    ///
    /// Returns `Ok(None)` whenever the optimization does not apply — no filter,
    /// explicit non-INNER joins, fewer than two tables, a predicate the
    /// constraint extractor cannot parse (e.g. top-level OR), a LEFT JOIN, or
    /// no equality links at all — and the caller then falls back to the generic
    /// cross-product path. On success returns the joined data plus a flag
    /// telling the caller whether every predicate conjunct was consumed by the
    /// join (if `false`, the filter must still be re-applied afterwards).
    fn try_execute_hash_join(
        &self,
        plan: &SelectPlan,
        tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
    ) -> ExecutorResult<Option<(TableCrossProductData, bool)>> {
        let query_label_opt = current_query_label();
        let query_label = query_label_opt.as_deref().unwrap_or("<unknown query>");

        // Without a filter there are no equality constraints to hash on.
        let filter_wrapper = match &plan.filter {
            Some(filter) => filter,
            None => {
                tracing::debug!(
                    "join_opt[{query_label}]: skipping optimization – no filter present"
                );
                return Ok(None);
            }
        };

        let all_inner_joins = plan
            .joins
            .iter()
            .all(|j| j.join_type == llkv_plan::JoinPlan::Inner);

        if !plan.joins.is_empty() && !all_inner_joins {
            tracing::debug!(
                "join_opt[{query_label}]: skipping optimization – explicit non-INNER JOINs present"
            );
            return Ok(None);
        }

        if tables_with_handles.len() < 2 {
            tracing::debug!(
                "join_opt[{query_label}]: skipping optimization – requires at least 2 tables"
            );
            return Ok(None);
        }

        // Case-insensitive per-table column maps; on duplicate names the first
        // occurrence wins (entry().or_insert).
        let mut table_infos = Vec::with_capacity(tables_with_handles.len());
        for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
            let mut column_map = FxHashMap::default();
            for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
                let column_name = column.name.to_ascii_lowercase();
                column_map.entry(column_name).or_insert(column_idx);
            }
            table_infos.push(TableInfo {
                index,
                table_ref,
                column_map,
            });
        }

        // Decompose the predicate into join equalities and per-table literal
        // constraints; `None` means an unsupported shape (e.g. top-level OR).
        let constraint_plan = match extract_join_constraints(
            &filter_wrapper.predicate,
            &table_infos,
        ) {
            Some(plan) => plan,
            None => {
                tracing::debug!(
                    "join_opt[{query_label}]: skipping optimization – predicate parsing failed (contains OR or other unsupported top-level structure)"
                );
                return Ok(None);
            }
        };

        tracing::debug!(
            "join_opt[{query_label}]: constraint extraction succeeded - equalities={}, literals={}, handled={}/{} predicates",
            constraint_plan.equalities.len(),
            constraint_plan.literals.len(),
            constraint_plan.handled_conjuncts,
            constraint_plan.total_conjuncts
        );
        tracing::debug!(
            "join_opt[{query_label}]: attempting hash join with tables={:?} filter={:?}",
            plan.tables
                .iter()
                .map(|t| t.qualified_name())
                .collect::<Vec<_>>(),
            filter_wrapper.predicate,
        );

        // Contradictory predicate: short-circuit with an empty result carrying
        // the full combined schema so downstream projection still works.
        if constraint_plan.unsatisfiable {
            tracing::debug!(
                "join_opt[{query_label}]: predicate unsatisfiable – returning empty result"
            );
            let mut combined_fields = Vec::new();
            let mut column_counts = Vec::new();
            for (_table_ref, executor_table) in tables_with_handles {
                for column in &executor_table.schema.columns {
                    combined_fields.push(Field::new(
                        column.name.clone(),
                        column.data_type.clone(),
                        column.nullable,
                    ));
                }
                column_counts.push(executor_table.schema.columns.len());
            }
            let combined_schema = Arc::new(Schema::new(combined_fields));
            let empty_batch = RecordBatch::new_empty(Arc::clone(&combined_schema));
            return Ok(Some((
                TableCrossProductData {
                    schema: combined_schema,
                    batches: vec![empty_batch],
                    column_counts,
                    table_indices: (0..tables_with_handles.len()).collect(),
                },
                // The (unsatisfiable) predicate counts as fully handled.
                true,
            )));
        }

        if constraint_plan.equalities.is_empty() {
            tracing::debug!(
                "join_opt[{query_label}]: skipping optimization – no join equalities found"
            );
            return Ok(None);
        }

        if !constraint_plan.literals.is_empty() {
            tracing::debug!(
                "join_opt[{query_label}]: found {} literal constraints - proceeding with hash join but may need fallback",
                constraint_plan.literals.len()
            );
        }

        tracing::debug!(
            "join_opt[{query_label}]: hash join optimization applicable with {} equality constraints",
            constraint_plan.equalities.len()
        );

        // Bucket literal constraints by the table they reference so they can be
        // pushed down into each table's scan.
        let mut literal_map: Vec<Vec<ColumnConstraint>> =
            vec![Vec::new(); tables_with_handles.len()];
        for constraint in &constraint_plan.literals {
            let table_idx = match constraint {
                ColumnConstraint::Equality(lit) => lit.column.table,
                ColumnConstraint::InList(in_list) => in_list.column.table,
            };
            if table_idx >= literal_map.len() {
                tracing::debug!(
                    "join_opt[{query_label}]: constraint references unknown table index {}; falling back",
                    table_idx
                );
                return Ok(None);
            }
            tracing::debug!(
                "join_opt[{query_label}]: mapping constraint to table_idx={} (table={})",
                table_idx,
                tables_with_handles[table_idx].0.qualified_name()
            );
            literal_map[table_idx].push(constraint.clone());
        }

        // Scan every table up front (with its pushed-down literal constraints).
        let mut per_table: Vec<Option<TableCrossProductData>> =
            Vec::with_capacity(tables_with_handles.len());
        for (idx, (table_ref, table)) in tables_with_handles.iter().enumerate() {
            let data =
                collect_table_data(idx, table_ref, table.as_ref(), literal_map[idx].as_slice())?;
            per_table.push(Some(data));
        }

        let has_left_join = plan
            .joins
            .iter()
            .any(|j| j.join_type == llkv_plan::JoinPlan::Left);

        let mut current: Option<TableCrossProductData> = None;

        if has_left_join {
            // LEFT JOIN semantics are delegated to llkv-join; bail out here.
            tracing::debug!(
                "join_opt[{query_label}]: delegating to llkv-join for LEFT JOIN support"
            );
            return Ok(None);
        } else {
            // Greedy join ordering: start with the first table, then repeatedly
            // pick a remaining table that has an equality link to the tables
            // already joined; if none exists, fall back to a cartesian step.
            let mut remaining: Vec<usize> = (0..tables_with_handles.len()).collect();
            let mut used_tables: FxHashSet<usize> = FxHashSet::default();

            while !remaining.is_empty() {
                let next_index = if used_tables.is_empty() {
                    remaining[0]
                } else {
                    match remaining.iter().copied().find(|idx| {
                        table_has_join_with_used(*idx, &used_tables, &constraint_plan.equalities)
                    }) {
                        Some(idx) => idx,
                        None => {
                            tracing::debug!(
                                "join_opt[{query_label}]: no remaining equality links – using cartesian expansion for table index {idx}",
                                idx = remaining[0]
                            );
                            remaining[0]
                        }
                    }
                };

                let position = remaining
                    .iter()
                    .position(|&idx| idx == next_index)
                    .expect("next index present");

                // Each table's data may be consumed exactly once.
                let next_data = per_table[next_index]
                    .take()
                    .ok_or_else(|| Error::Internal("hash join consumed table data twice".into()))?;

                if let Some(current_data) = current.take() {
                    let join_keys = gather_join_keys(
                        &current_data,
                        &next_data,
                        &used_tables,
                        next_index,
                        &constraint_plan.equalities,
                    )?;

                    let joined = if join_keys.is_empty() {
                        tracing::debug!(
                            "join_opt[{query_label}]: joining '{}' via cartesian expansion (no equality keys)",
                            tables_with_handles[next_index].0.qualified_name()
                        );
                        cross_join_table_batches(current_data, next_data)?
                    } else {
                        hash_join_table_batches(
                            current_data,
                            next_data,
                            &join_keys,
                            llkv_join::JoinType::Inner,
                        )?
                    };
                    current = Some(joined);
                } else {
                    current = Some(next_data);
                }

                used_tables.insert(next_index);
                remaining.remove(position);
            }
        }

        if let Some(result) = current {
            // If some conjuncts were not expressible as join/literal constraints
            // the caller must still apply the residual filter.
            let handled_all = constraint_plan.handled_conjuncts == constraint_plan.total_conjuncts;
            tracing::debug!(
                "join_opt[{query_label}]: hash join succeeded across {} tables (handled {}/{} predicates)",
                tables_with_handles.len(),
                constraint_plan.handled_conjuncts,
                constraint_plan.total_conjuncts
            );
            return Ok(Some((result, handled_all)));
        }

        Ok(None)
    }
4060
4061 fn execute_projection(
4062 &self,
4063 table: Arc<ExecutorTable<P>>,
4064 display_name: String,
4065 plan: SelectPlan,
4066 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4067 ) -> ExecutorResult<SelectExecution<P>> {
4068 if plan.having.is_some() {
4069 return Err(Error::InvalidArgumentError(
4070 "HAVING requires GROUP BY".into(),
4071 ));
4072 }
4073
4074 let has_filter_subqueries = plan
4075 .filter
4076 .as_ref()
4077 .is_some_and(|filter| !filter.subqueries.is_empty());
4078 let has_scalar_subqueries = !plan.scalar_subqueries.is_empty();
4079
4080 if has_filter_subqueries || has_scalar_subqueries {
4081 return self.execute_projection_with_subqueries(table, display_name, plan, row_filter);
4082 }
4083
4084 let table_ref = table.as_ref();
4085 let constant_filter = plan
4086 .filter
4087 .as_ref()
4088 .and_then(|filter| evaluate_constant_predicate(&filter.predicate));
4089 let projections = if plan.projections.is_empty() {
4090 build_wildcard_projections(table_ref)
4091 } else {
4092 build_projected_columns(table_ref, &plan.projections)?
4093 };
4094 let schema = schema_for_projections(table_ref, &projections)?;
4095
4096 if let Some(result) = constant_filter {
4097 match result {
4098 Some(true) => {
4099 }
4101 Some(false) | None => {
4102 let batch = RecordBatch::new_empty(Arc::clone(&schema));
4103 return Ok(SelectExecution::new_single_batch(
4104 display_name,
4105 schema,
4106 batch,
4107 ));
4108 }
4109 }
4110 }
4111
4112 let (mut filter_expr, mut full_table_scan) = match &plan.filter {
4113 Some(filter_wrapper) => (
4114 crate::translation::expression::translate_predicate(
4115 filter_wrapper.predicate.clone(),
4116 table_ref.schema.as_ref(),
4117 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
4118 )?,
4119 false,
4120 ),
4121 None => {
4122 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
4123 Error::InvalidArgumentError(
4124 "table has no columns; cannot perform wildcard scan".into(),
4125 )
4126 })?;
4127 (
4128 crate::translation::expression::full_table_scan_filter(field_id),
4129 true,
4130 )
4131 }
4132 };
4133
4134 if matches!(constant_filter, Some(Some(true))) {
4135 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
4136 Error::InvalidArgumentError(
4137 "table has no columns; cannot perform wildcard scan".into(),
4138 )
4139 })?;
4140 filter_expr = crate::translation::expression::full_table_scan_filter(field_id);
4141 full_table_scan = true;
4142 }
4143
4144 let expanded_order = expand_order_targets(&plan.order_by, &projections)?;
4145
4146 let mut physical_order: Option<ScanOrderSpec> = None;
4147
4148 if let Some(first) = expanded_order.first() {
4149 match &first.target {
4150 OrderTarget::Column(name) => {
4151 if table_ref.schema.resolve(name).is_some() {
4152 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
4153 }
4154 }
4155 OrderTarget::Index(position) => match projections.get(*position) {
4156 Some(ScanProjection::Column(_)) => {
4157 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
4158 }
4159 Some(ScanProjection::Computed { .. }) => {}
4160 None => {
4161 return Err(Error::InvalidArgumentError(format!(
4162 "ORDER BY position {} is out of range",
4163 position + 1
4164 )));
4165 }
4166 },
4167 OrderTarget::All => {}
4168 }
4169 }
4170
4171 let options = if let Some(order_spec) = physical_order {
4172 if row_filter.is_some() {
4173 tracing::debug!("Applying MVCC row filter with ORDER BY");
4174 }
4175 ScanStreamOptions {
4176 include_nulls: true,
4177 order: Some(order_spec),
4178 row_id_filter: row_filter.clone(),
4179 include_row_ids: true,
4180 }
4181 } else {
4182 if row_filter.is_some() {
4183 tracing::debug!("Applying MVCC row filter");
4184 }
4185 ScanStreamOptions {
4186 include_nulls: true,
4187 order: None,
4188 row_id_filter: row_filter.clone(),
4189 include_row_ids: true,
4190 }
4191 };
4192
4193 Ok(SelectExecution::new_projection(
4194 display_name,
4195 schema,
4196 table,
4197 projections,
4198 filter_expr,
4199 options,
4200 full_table_scan,
4201 expanded_order,
4202 plan.distinct,
4203 ))
4204 }
4205
    /// Executes a single-table SELECT whose filter and/or projections contain
    /// EXISTS or scalar subqueries.
    ///
    /// The base table is scanned with all columns (plus row ids); for each
    /// batch, scalar subqueries referenced by the filter are materialized as
    /// per-row columns, the predicate (including correlated EXISTS) is
    /// evaluated into a boolean mask, and the surviving rows are projected.
    /// The results are buffered, concatenated, and then DISTINCT / ORDER BY
    /// are applied in memory before returning a single batch.
    fn execute_projection_with_subqueries(
        &self,
        table: Arc<ExecutorTable<P>>,
        display_name: String,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        // HAVING is only meaningful alongside GROUP BY, which this path lacks.
        if plan.having.is_some() {
            return Err(Error::InvalidArgumentError(
                "HAVING requires GROUP BY".into(),
            ));
        }
        let table_ref = table.as_ref();

        // The scan-level projections used for output plus the logical SELECT
        // projections they were derived from (wildcard expands to AllColumns).
        let (output_scan_projections, effective_projections): (
            Vec<ScanProjection>,
            Vec<SelectProjection>,
        ) = if plan.projections.is_empty() {
            (
                build_wildcard_projections(table_ref),
                vec![SelectProjection::AllColumns],
            )
        } else {
            (
                build_projected_columns(table_ref, &plan.projections)?,
                plan.projections.clone(),
            )
        };

        // Scalar subqueries referenced anywhere in the plan, by id.
        let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
            .scalar_subqueries
            .iter()
            .map(|subquery| (subquery.id, subquery))
            .collect();

        // The scan always reads every column so correlated subqueries and
        // computed projections can reference any of them.
        let base_projections = build_wildcard_projections(table_ref);

        let filter_wrapper_opt = plan.filter.as_ref();

        // First pass over the predicate: does it embed scalar subqueries?
        let mut filter_has_scalar_subqueries = false;
        if let Some(filter_wrapper) = filter_wrapper_opt {
            let translated = crate::translation::expression::translate_predicate(
                filter_wrapper.predicate.clone(),
                table_ref.schema.as_ref(),
                |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
            )?;
            let mut scalar_filter_ids = FxHashSet::default();
            collect_predicate_scalar_subquery_ids(&translated, &mut scalar_filter_ids);
            filter_has_scalar_subqueries = !scalar_filter_ids.is_empty();
        }

        // Decide what gets pushed down to the storage scan. When the predicate
        // involves subqueries it is kept (in `translated_filter`) for per-batch
        // evaluation, and the pushdown is weakened to a full scan (scalar
        // subqueries) or an EXISTS-stripped approximation.
        let mut translated_filter: Option<llkv_expr::expr::Expr<'static, FieldId>> = None;
        let pushdown_filter = if let Some(filter_wrapper) = filter_wrapper_opt {
            let translated = crate::translation::expression::translate_predicate(
                filter_wrapper.predicate.clone(),
                table_ref.schema.as_ref(),
                |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
            )?;
            if !filter_wrapper.subqueries.is_empty() || filter_has_scalar_subqueries {
                translated_filter = Some(translated.clone());
                if filter_has_scalar_subqueries {
                    let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
                        Error::InvalidArgumentError(
                            "table has no columns; cannot perform scalar subquery projection"
                                .into(),
                        )
                    })?;
                    crate::translation::expression::full_table_scan_filter(field_id)
                } else {
                    strip_exists(&translated)
                }
            } else {
                // No subqueries in the filter: push the whole predicate down.
                translated
            }
        } else {
            let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
                Error::InvalidArgumentError(
                    "table has no columns; cannot perform scalar subquery projection".into(),
                )
            })?;
            crate::translation::expression::full_table_scan_filter(field_id)
        };

        // Schema and name-lookup for the full-width base batches.
        let mut base_fields: Vec<Field> = Vec::with_capacity(table_ref.schema.columns.len());
        for column in &table_ref.schema.columns {
            base_fields.push(Field::new(
                column.name.clone(),
                column.data_type.clone(),
                column.nullable,
            ));
        }
        let base_schema = Arc::new(Schema::new(base_fields));
        let base_column_counts = vec![base_schema.fields().len()];
        let base_table_indices = vec![0usize];
        let base_lookup = build_cross_product_column_lookup(
            base_schema.as_ref(),
            &plan.tables,
            &base_column_counts,
            &base_table_indices,
        );

        // Expression context is only needed when the filter must be
        // re-evaluated per batch (i.e. it contains subqueries).
        let mut filter_context = if translated_filter.is_some() {
            Some(CrossProductExpressionContext::new(
                base_schema.as_ref(),
                base_lookup.clone(),
            )?)
        } else {
            None
        };

        // Ids of scalar subqueries the retained filter actually references.
        let mut filter_scalar_subquery_ids = FxHashSet::default();
        if let Some(translated) = translated_filter.as_ref() {
            collect_predicate_scalar_subquery_ids(translated, &mut filter_scalar_subquery_ids);
        }

        let filter_scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> =
            if !filter_scalar_subquery_ids.is_empty() {
                plan.scalar_subqueries
                    .iter()
                    .filter(|subquery| filter_scalar_subquery_ids.contains(&subquery.id))
                    .map(|subquery| (subquery.id, subquery))
                    .collect()
            } else {
                FxHashMap::default()
            };

        let options = ScanStreamOptions {
            include_nulls: true,
            order: None,
            row_id_filter: row_filter.clone(),
            include_row_ids: true,
        };

        // EXISTS subqueries from the filter, by id, for the mask callback.
        let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
            filter_wrapper_opt
                .map(|wrapper| {
                    wrapper
                        .subqueries
                        .iter()
                        .map(|subquery| (subquery.id, subquery))
                        .collect()
                })
                .unwrap_or_default();

        let mut projected_batches: Vec<RecordBatch> = Vec::new();
        // The scan callback cannot return errors; the first failure is parked
        // here and subsequent batches are skipped.
        let mut scan_error: Option<Error> = None;

        table.storage().scan_stream(
            &base_projections,
            &pushdown_filter,
            options,
            &mut |batch| {
                if scan_error.is_some() {
                    return;
                }
                let effective_batch = if let Some(context) = filter_context.as_mut() {
                    context.reset();

                    // Materialize each scalar subquery as a per-row column the
                    // predicate evaluator can reference.
                    for (subquery_id, subquery) in filter_scalar_lookup.iter() {
                        let result_array = match self
                            .evaluate_scalar_subquery_numeric(context, subquery, &batch)
                        {
                            Ok(array) => array,
                            Err(err) => {
                                scan_error = Some(err);
                                return;
                            }
                        };
                        let accessor = match ColumnAccessor::from_numeric_array(&result_array) {
                            Ok(acc) => acc,
                            Err(err) => {
                                scan_error = Some(err);
                                return;
                            }
                        };
                        context
                            .scalar_subquery_columns
                            .insert(*subquery_id, accessor);
                    }
                    let translated = translated_filter
                        .as_ref()
                        .expect("filter context requires translated filter");
                    // Evaluate the predicate into a boolean mask; correlated
                    // EXISTS terms are resolved per row via the callback.
                    let mask = match context.evaluate_predicate_mask(
                        translated,
                        &batch,
                        |ctx, subquery_expr, row_idx, current_batch| {
                            let subquery =
                                subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
                                    Error::Internal("missing correlated subquery metadata".into())
                                })?;
                            let exists = self.evaluate_exists_subquery(
                                ctx,
                                subquery,
                                current_batch,
                                row_idx,
                            )?;
                            let value = if subquery_expr.negated {
                                !exists
                            } else {
                                exists
                            };
                            Ok(Some(value))
                        },
                    ) {
                        Ok(mask) => mask,
                        Err(err) => {
                            scan_error = Some(err);
                            return;
                        }
                    };
                    match filter_record_batch(&batch, &mask) {
                        Ok(filtered) => {
                            if filtered.num_rows() == 0 {
                                return;
                            }
                            filtered
                        }
                        Err(err) => {
                            scan_error = Some(Error::InvalidArgumentError(format!(
                                "failed to apply EXISTS filter: {err}"
                            )));
                            return;
                        }
                    }
                } else {
                    batch.clone()
                };

                if effective_batch.num_rows() == 0 {
                    return;
                }

                // Apply the SELECT projections (including scalar subqueries).
                let projected = match self.project_record_batch(
                    &effective_batch,
                    &effective_projections,
                    &base_lookup,
                    &scalar_lookup,
                ) {
                    Ok(batch) => batch,
                    Err(err) => {
                        scan_error = Some(Error::InvalidArgumentError(format!(
                            "failed to evaluate projections: {err}"
                        )));
                        return;
                    }
                };
                projected_batches.push(projected);
            },
        )?;

        if let Some(err) = scan_error {
            return Err(err);
        }

        // Collapse the buffered batches into one result batch. The empty case
        // still runs the projection so the output schema is correct.
        let mut result_batch = if projected_batches.is_empty() {
            let empty_batch = RecordBatch::new_empty(Arc::clone(&base_schema));
            self.project_record_batch(
                &empty_batch,
                &effective_projections,
                &base_lookup,
                &scalar_lookup,
            )?
        } else if projected_batches.len() == 1 {
            projected_batches.pop().unwrap()
        } else {
            let schema = projected_batches[0].schema();
            concat_batches(&schema, &projected_batches).map_err(|err| {
                Error::Internal(format!("failed to combine filtered batches: {err}"))
            })?
        };

        if plan.distinct && result_batch.num_rows() > 0 {
            let mut state = DistinctState::default();
            let schema = result_batch.schema();
            result_batch = match distinct_filter_batch(result_batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(schema),
            };
        }

        if !plan.order_by.is_empty() && result_batch.num_rows() > 0 {
            let expanded_order = expand_order_targets(&plan.order_by, &output_scan_projections)?;
            if !expanded_order.is_empty() {
                result_batch = sort_record_batch_with_order(
                    &result_batch.schema(),
                    &result_batch,
                    &expanded_order,
                )?;
            }
        }

        let schema = result_batch.schema();

        Ok(SelectExecution::new_single_batch(
            display_name,
            schema,
            result_batch,
        ))
    }
4512
4513 fn execute_group_by_single_table(
4514 &self,
4515 table: Arc<ExecutorTable<P>>,
4516 display_name: String,
4517 plan: SelectPlan,
4518 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4519 ) -> ExecutorResult<SelectExecution<P>> {
4520 if plan
4521 .filter
4522 .as_ref()
4523 .is_some_and(|filter| !filter.subqueries.is_empty())
4524 || !plan.scalar_subqueries.is_empty()
4525 {
4526 return Err(Error::InvalidArgumentError(
4527 "GROUP BY with subqueries is not supported yet".into(),
4528 ));
4529 }
4530
4531 tracing::debug!(
4533 "[GROUP BY] Original plan: projections={}, aggregates={}, has_filter={}, has_having={}",
4534 plan.projections.len(),
4535 plan.aggregates.len(),
4536 plan.filter.is_some(),
4537 plan.having.is_some()
4538 );
4539
4540 let mut base_plan = plan.clone();
4544 base_plan.projections.clear();
4545 base_plan.aggregates.clear();
4546 base_plan.scalar_subqueries.clear();
4547 base_plan.order_by.clear();
4548 base_plan.distinct = false;
4549 base_plan.group_by.clear();
4550 base_plan.value_table_mode = None;
4551 base_plan.having = None;
4552
4553 tracing::debug!(
4554 "[GROUP BY] Base plan: projections={}, aggregates={}, has_filter={}, has_having={}",
4555 base_plan.projections.len(),
4556 base_plan.aggregates.len(),
4557 base_plan.filter.is_some(),
4558 base_plan.having.is_some()
4559 );
4560
4561 let table_ref = table.as_ref();
4564 let projections = build_wildcard_projections(table_ref);
4565 let base_schema = schema_for_projections(table_ref, &projections)?;
4566
4567 tracing::debug!(
4569 "[GROUP BY] Building base filter: has_filter={}",
4570 base_plan.filter.is_some()
4571 );
4572 let (filter_expr, full_table_scan) = match &base_plan.filter {
4573 Some(filter_wrapper) => {
4574 tracing::debug!(
4575 "[GROUP BY] Translating filter predicate: {:?}",
4576 filter_wrapper.predicate
4577 );
4578 let expr = crate::translation::expression::translate_predicate(
4579 filter_wrapper.predicate.clone(),
4580 table_ref.schema.as_ref(),
4581 |name| {
4582 Error::InvalidArgumentError(format!(
4583 "Binder Error: does not have a column named '{}'",
4584 name
4585 ))
4586 },
4587 )?;
4588 tracing::debug!("[GROUP BY] Translated filter expr: {:?}", expr);
4589 (expr, false)
4590 }
4591 None => {
4592 let first_col =
4594 table_ref.schema.columns.first().ok_or_else(|| {
4595 Error::InvalidArgumentError("Table has no columns".into())
4596 })?;
4597 (full_table_scan_filter(first_col.field_id), true)
4598 }
4599 };
4600
4601 let options = ScanStreamOptions {
4602 include_nulls: true,
4603 order: None,
4604 row_id_filter: row_filter.clone(),
4605 include_row_ids: true,
4606 };
4607
4608 let execution = SelectExecution::new_projection(
4609 display_name.clone(),
4610 Arc::clone(&base_schema),
4611 Arc::clone(&table),
4612 projections,
4613 filter_expr,
4614 options,
4615 full_table_scan,
4616 vec![],
4617 false,
4618 );
4619
4620 let batches = execution.collect()?;
4621
4622 tracing::debug!(
4623 "[GROUP BY] Collected {} batches from base scan, total_rows={}",
4624 batches.len(),
4625 batches.iter().map(|b| b.num_rows()).sum::<usize>()
4626 );
4627
4628 let column_lookup_map = build_column_lookup_map(base_schema.as_ref());
4629
4630 self.execute_group_by_from_batches(
4631 display_name,
4632 plan,
4633 base_schema,
4634 batches,
4635 column_lookup_map,
4636 )
4637 }
4638
4639 fn execute_group_by_from_batches(
4640 &self,
4641 display_name: String,
4642 plan: SelectPlan,
4643 base_schema: Arc<Schema>,
4644 batches: Vec<RecordBatch>,
4645 column_lookup_map: FxHashMap<String, usize>,
4646 ) -> ExecutorResult<SelectExecution<P>> {
4647 if plan
4648 .filter
4649 .as_ref()
4650 .is_some_and(|filter| !filter.subqueries.is_empty())
4651 || !plan.scalar_subqueries.is_empty()
4652 {
4653 return Err(Error::InvalidArgumentError(
4654 "GROUP BY with subqueries is not supported yet".into(),
4655 ));
4656 }
4657
4658 let having_has_aggregates = plan
4661 .having
4662 .as_ref()
4663 .map(|h| Self::predicate_contains_aggregate(h))
4664 .unwrap_or(false);
4665
4666 tracing::debug!(
4667 "[GROUP BY PATH] aggregates={}, has_computed={}, having_has_agg={}",
4668 plan.aggregates.len(),
4669 self.has_computed_aggregates(&plan),
4670 having_has_aggregates
4671 );
4672
4673 if !plan.aggregates.is_empty()
4674 || self.has_computed_aggregates(&plan)
4675 || having_has_aggregates
4676 {
4677 tracing::debug!("[GROUP BY PATH] Taking aggregates path");
4678 return self.execute_group_by_with_aggregates(
4679 display_name,
4680 plan,
4681 base_schema,
4682 batches,
4683 column_lookup_map,
4684 );
4685 }
4686
4687 let mut key_indices = Vec::with_capacity(plan.group_by.len());
4688 for column in &plan.group_by {
4689 let key = column.to_ascii_lowercase();
4690 let index = column_lookup_map.get(&key).ok_or_else(|| {
4691 Error::InvalidArgumentError(format!(
4692 "column '{}' not found in GROUP BY input",
4693 column
4694 ))
4695 })?;
4696 key_indices.push(*index);
4697 }
4698
4699 let sample_batch = batches
4700 .first()
4701 .cloned()
4702 .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(&base_schema)));
4703
4704 let output_columns = self.build_group_by_output_columns(
4705 &plan,
4706 base_schema.as_ref(),
4707 &column_lookup_map,
4708 &sample_batch,
4709 )?;
4710
4711 let constant_having = plan.having.as_ref().and_then(evaluate_constant_predicate);
4712
4713 if let Some(result) = constant_having
4714 && !result.unwrap_or(false)
4715 {
4716 let fields: Vec<Field> = output_columns
4717 .iter()
4718 .map(|output| output.field.clone())
4719 .collect();
4720 let schema = Arc::new(Schema::new(fields));
4721 let batch = RecordBatch::new_empty(Arc::clone(&schema));
4722 return Ok(SelectExecution::new_single_batch(
4723 display_name,
4724 schema,
4725 batch,
4726 ));
4727 }
4728
4729 let translated_having = if plan.having.is_some() && constant_having.is_none() {
4730 let having = plan.having.clone().expect("checked above");
4731 if Self::predicate_contains_aggregate(&having) {
4734 None
4735 } else {
4736 let temp_context = CrossProductExpressionContext::new(
4737 base_schema.as_ref(),
4738 column_lookup_map.clone(),
4739 )?;
4740 Some(translate_predicate(
4741 having,
4742 temp_context.schema(),
4743 |name| {
4744 Error::InvalidArgumentError(format!(
4745 "column '{}' not found in GROUP BY result",
4746 name
4747 ))
4748 },
4749 )?)
4750 }
4751 } else {
4752 None
4753 };
4754
4755 let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
4756 let mut groups: Vec<GroupState> = Vec::new();
4757
4758 for batch in &batches {
4759 for row_idx in 0..batch.num_rows() {
4760 let key = build_group_key(batch, row_idx, &key_indices)?;
4761 if group_index.contains_key(&key) {
4762 continue;
4763 }
4764 group_index.insert(key, groups.len());
4765 groups.push(GroupState {
4766 batch: batch.clone(),
4767 row_idx,
4768 });
4769 }
4770 }
4771
4772 let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(groups.len());
4773
4774 for group in &groups {
4775 if let Some(predicate) = translated_having.as_ref() {
4776 let mut context = CrossProductExpressionContext::new(
4777 group.batch.schema().as_ref(),
4778 column_lookup_map.clone(),
4779 )?;
4780 context.reset();
4781 let mut eval = |_ctx: &mut CrossProductExpressionContext,
4782 _subquery_expr: &llkv_expr::SubqueryExpr,
4783 _row_idx: usize,
4784 _current_batch: &RecordBatch|
4785 -> ExecutorResult<Option<bool>> {
4786 Err(Error::InvalidArgumentError(
4787 "HAVING subqueries are not supported yet".into(),
4788 ))
4789 };
4790 let truths =
4791 context.evaluate_predicate_truths(predicate, &group.batch, &mut eval)?;
4792 let passes = truths
4793 .get(group.row_idx)
4794 .copied()
4795 .flatten()
4796 .unwrap_or(false);
4797 if !passes {
4798 continue;
4799 }
4800 }
4801
4802 let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
4803 for output in &output_columns {
4804 match output.source {
4805 OutputSource::TableColumn { index } => {
4806 let value = llkv_plan::plan_value_from_array(
4807 group.batch.column(index),
4808 group.row_idx,
4809 )?;
4810 row.push(value);
4811 }
4812 OutputSource::Computed { projection_index } => {
4813 let expr = match &plan.projections[projection_index] {
4814 SelectProjection::Computed { expr, .. } => expr,
4815 _ => unreachable!("projection index mismatch for computed column"),
4816 };
4817 let mut context = CrossProductExpressionContext::new(
4818 group.batch.schema().as_ref(),
4819 column_lookup_map.clone(),
4820 )?;
4821 context.reset();
4822 let evaluated = self.evaluate_projection_expression(
4823 &mut context,
4824 expr,
4825 &group.batch,
4826 &FxHashMap::default(),
4827 )?;
4828 let value = llkv_plan::plan_value_from_array(&evaluated, group.row_idx)?;
4829 row.push(value);
4830 }
4831 }
4832 }
4833 rows.push(row);
4834 }
4835
4836 let fields: Vec<Field> = output_columns
4837 .into_iter()
4838 .map(|output| output.field)
4839 .collect();
4840 let schema = Arc::new(Schema::new(fields));
4841
4842 let mut batch = rows_to_record_batch(Arc::clone(&schema), &rows)?;
4843
4844 if plan.distinct && batch.num_rows() > 0 {
4845 let mut state = DistinctState::default();
4846 batch = match distinct_filter_batch(batch, &mut state)? {
4847 Some(filtered) => filtered,
4848 None => RecordBatch::new_empty(Arc::clone(&schema)),
4849 };
4850 }
4851
4852 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
4853 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
4854 }
4855
4856 Ok(SelectExecution::new_single_batch(
4857 display_name,
4858 schema,
4859 batch,
4860 ))
4861 }
4862
4863 fn infer_computed_expression_type(
4865 expr: &ScalarExpr<String>,
4866 base_schema: &Schema,
4867 column_lookup_map: &FxHashMap<String, usize>,
4868 sample_batch: &RecordBatch,
4869 ) -> Option<DataType> {
4870 use llkv_expr::expr::AggregateCall;
4871
4872 if let ScalarExpr::Aggregate(agg_call) = expr {
4874 return match agg_call {
4875 AggregateCall::CountStar
4876 | AggregateCall::Count { .. }
4877 | AggregateCall::CountNulls(_) => Some(DataType::Int64),
4878 AggregateCall::Sum { expr: agg_expr, .. }
4879 | AggregateCall::Total { expr: agg_expr, .. }
4880 | AggregateCall::Avg { expr: agg_expr, .. }
4881 | AggregateCall::Min(agg_expr)
4882 | AggregateCall::Max(agg_expr) => {
4883 if let Some(dtype) =
4885 infer_type_recursive(agg_expr, base_schema, column_lookup_map)
4886 {
4887 return Some(dtype);
4888 }
4889
4890 if let Some(col_name) = try_extract_simple_column(agg_expr) {
4892 let idx = resolve_column_name_to_index(col_name, column_lookup_map)?;
4893 Some(base_schema.field(idx).data_type().clone())
4894 } else {
4895 if sample_batch.num_rows() > 0 {
4898 let mut computed_values = Vec::new();
4899 if let Ok(value) =
4900 Self::evaluate_expr_with_plan_value_aggregates_and_row(
4901 agg_expr,
4902 &FxHashMap::default(),
4903 Some(sample_batch),
4904 Some(column_lookup_map),
4905 0,
4906 )
4907 {
4908 computed_values.push(value);
4909 if let Ok(array) = plan_values_to_arrow_array(&computed_values) {
4910 match array.data_type() {
4911 DataType::Decimal128(_, scale) => {
4913 return Some(DataType::Decimal128(38, *scale));
4914 }
4915 DataType::Null => {
4917 return Some(DataType::Float64);
4918 }
4919 other => {
4920 return Some(other.clone());
4921 }
4922 }
4923 }
4924 }
4925 }
4926 Some(DataType::Float64)
4928 }
4929 }
4930 AggregateCall::GroupConcat { .. } => Some(DataType::Utf8),
4931 };
4932 }
4933
4934 None
4937 }
4938
4939 fn build_group_by_output_columns(
4940 &self,
4941 plan: &SelectPlan,
4942 base_schema: &Schema,
4943 column_lookup_map: &FxHashMap<String, usize>,
4944 _sample_batch: &RecordBatch,
4945 ) -> ExecutorResult<Vec<OutputColumn>> {
4946 let projections = if plan.projections.is_empty() {
4947 vec![SelectProjection::AllColumns]
4948 } else {
4949 plan.projections.clone()
4950 };
4951
4952 let mut columns: Vec<OutputColumn> = Vec::new();
4953
4954 for (proj_idx, projection) in projections.iter().enumerate() {
4955 match projection {
4956 SelectProjection::AllColumns => {
4957 for (index, field) in base_schema.fields().iter().enumerate() {
4958 columns.push(OutputColumn {
4959 field: (**field).clone(),
4960 source: OutputSource::TableColumn { index },
4961 });
4962 }
4963 }
4964 SelectProjection::AllColumnsExcept { exclude } => {
4965 let exclude_lower: FxHashSet<String> = exclude
4966 .iter()
4967 .map(|name| name.to_ascii_lowercase())
4968 .collect();
4969
4970 let mut excluded_indices = FxHashSet::default();
4971 for excluded_name in &exclude_lower {
4972 if let Some(&idx) = column_lookup_map.get(excluded_name) {
4973 excluded_indices.insert(idx);
4974 }
4975 }
4976
4977 for (index, field) in base_schema.fields().iter().enumerate() {
4978 if !exclude_lower.contains(&field.name().to_ascii_lowercase())
4979 && !excluded_indices.contains(&index)
4980 {
4981 columns.push(OutputColumn {
4982 field: (**field).clone(),
4983 source: OutputSource::TableColumn { index },
4984 });
4985 }
4986 }
4987 }
4988 SelectProjection::Column { name, alias } => {
4989 let lookup_key = name.to_ascii_lowercase();
4990 let index = column_lookup_map.get(&lookup_key).ok_or_else(|| {
4991 Error::InvalidArgumentError(format!(
4992 "column '{}' not found in GROUP BY result",
4993 name
4994 ))
4995 })?;
4996 let field = base_schema.field(*index);
4997 let field = Field::new(
4998 alias.as_ref().unwrap_or(name).clone(),
4999 field.data_type().clone(),
5000 field.is_nullable(),
5001 );
5002 columns.push(OutputColumn {
5003 field,
5004 source: OutputSource::TableColumn { index: *index },
5005 });
5006 }
5007 SelectProjection::Computed { expr, alias } => {
5008 let inferred_type = Self::infer_computed_expression_type(
5012 expr,
5013 base_schema,
5014 column_lookup_map,
5015 _sample_batch,
5016 )
5017 .unwrap_or(DataType::Float64);
5018 let field = Field::new(alias.clone(), inferred_type, true);
5019 columns.push(OutputColumn {
5020 field,
5021 source: OutputSource::Computed {
5022 projection_index: proj_idx,
5023 },
5024 });
5025 }
5026 }
5027 }
5028
5029 if columns.is_empty() {
5030 for (index, field) in base_schema.fields().iter().enumerate() {
5031 columns.push(OutputColumn {
5032 field: (**field).clone(),
5033 source: OutputSource::TableColumn { index },
5034 });
5035 }
5036 }
5037
5038 Ok(columns)
5039 }
5040
5041 fn project_record_batch(
5042 &self,
5043 batch: &RecordBatch,
5044 projections: &[SelectProjection],
5045 lookup: &FxHashMap<String, usize>,
5046 scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
5047 ) -> ExecutorResult<RecordBatch> {
5048 if projections.is_empty() {
5049 return Ok(batch.clone());
5050 }
5051
5052 let schema = batch.schema();
5053 let mut selected_fields: Vec<Arc<Field>> = Vec::new();
5054 let mut selected_columns: Vec<ArrayRef> = Vec::new();
5055 let mut expr_context: Option<CrossProductExpressionContext> = None;
5056
5057 for proj in projections {
5058 match proj {
5059 SelectProjection::AllColumns => {
5060 selected_fields = schema.fields().iter().cloned().collect();
5061 selected_columns = batch.columns().to_vec();
5062 break;
5063 }
5064 SelectProjection::AllColumnsExcept { exclude } => {
5065 let exclude_lower: FxHashSet<String> = exclude
5066 .iter()
5067 .map(|name| name.to_ascii_lowercase())
5068 .collect();
5069
5070 let mut excluded_indices = FxHashSet::default();
5071 for excluded_name in &exclude_lower {
5072 if let Some(&idx) = lookup.get(excluded_name) {
5073 excluded_indices.insert(idx);
5074 }
5075 }
5076
5077 for (idx, field) in schema.fields().iter().enumerate() {
5078 let column_name = field.name().to_ascii_lowercase();
5079 if !exclude_lower.contains(&column_name) && !excluded_indices.contains(&idx)
5080 {
5081 selected_fields.push(Arc::clone(field));
5082 selected_columns.push(batch.column(idx).clone());
5083 }
5084 }
5085 break;
5086 }
5087 SelectProjection::Column { name, alias } => {
5088 let normalized = name.to_ascii_lowercase();
5089 let column_index = lookup.get(&normalized).ok_or_else(|| {
5090 Error::InvalidArgumentError(format!(
5091 "column '{}' not found in projection",
5092 name
5093 ))
5094 })?;
5095 let field = schema.field(*column_index);
5096 let output_field = Arc::new(Field::new(
5097 alias.as_ref().unwrap_or_else(|| field.name()),
5098 field.data_type().clone(),
5099 field.is_nullable(),
5100 ));
5101 selected_fields.push(output_field);
5102 selected_columns.push(batch.column(*column_index).clone());
5103 }
5104 SelectProjection::Computed { expr, alias } => {
5105 if expr_context.is_none() {
5106 expr_context = Some(CrossProductExpressionContext::new(
5107 schema.as_ref(),
5108 lookup.clone(),
5109 )?);
5110 }
5111 let context = expr_context
5112 .as_mut()
5113 .expect("projection context must be initialized");
5114 context.reset();
5115 let evaluated =
5116 self.evaluate_projection_expression(context, expr, batch, scalar_lookup)?;
5117 let field = Arc::new(Field::new(
5118 alias.clone(),
5119 evaluated.data_type().clone(),
5120 true,
5121 ));
5122 selected_fields.push(field);
5123 selected_columns.push(evaluated);
5124 }
5125 }
5126 }
5127
5128 let projected_schema = Arc::new(Schema::new(selected_fields));
5129 RecordBatch::try_new(projected_schema, selected_columns)
5130 .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))
5131 }
5132
    /// GROUP BY execution when aggregate calls appear in computed
    /// projections and/or the HAVING clause.
    ///
    /// Strategy: discover groups across all input batches, materialize each
    /// group's rows into a contiguous `RecordBatch`, run every collected
    /// aggregate over that batch, then assemble one output row per group
    /// from the group's representative row plus the aggregate values.
    /// HAVING is applied to the assembled rows, followed by DISTINCT and
    /// ORDER BY.
    fn execute_group_by_with_aggregates(
        &self,
        display_name: String,
        plan: SelectPlan,
        base_schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
        column_lookup_map: FxHashMap<String, usize>,
    ) -> ExecutorResult<SelectExecution<P>> {
        use llkv_expr::expr::AggregateCall;

        // Resolve GROUP BY key columns to indices in the base schema.
        let mut key_indices = Vec::with_capacity(plan.group_by.len());
        for column in &plan.group_by {
            let key = column.to_ascii_lowercase();
            let index = column_lookup_map.get(&key).ok_or_else(|| {
                Error::InvalidArgumentError(format!(
                    "column '{}' not found in GROUP BY input",
                    column
                ))
            })?;
            key_indices.push(*index);
        }

        // Collect every aggregate call referenced by computed projections
        // and by HAVING, keyed by a canonical alias string.
        let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
        for proj in &plan.projections {
            if let SelectProjection::Computed { expr, .. } = proj {
                Self::collect_aggregates(expr, &mut aggregate_specs);
            }
        }

        if let Some(having_expr) = &plan.having {
            Self::collect_aggregates_from_predicate(having_expr, &mut aggregate_specs);
        }

        // Group discovery: record every (batch, row) location per key; the
        // first location seen becomes the group's representative row.
        let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
        let mut group_states: Vec<GroupAggregateState> = Vec::new();

        for (batch_idx, batch) in batches.iter().enumerate() {
            for row_idx in 0..batch.num_rows() {
                let key = build_group_key(batch, row_idx, &key_indices)?;

                if let Some(&group_idx) = group_index.get(&key) {
                    group_states[group_idx]
                        .row_locations
                        .push((batch_idx, row_idx));
                } else {
                    let group_idx = group_states.len();
                    group_index.insert(key, group_idx);
                    group_states.push(GroupAggregateState {
                        representative_batch_idx: batch_idx,
                        representative_row: row_idx,
                        row_locations: vec![(batch_idx, row_idx)],
                    });
                }
            }
        }

        // Per-group map from aggregate alias to its computed value.
        let mut group_aggregate_values: Vec<FxHashMap<String, PlanValue>> =
            Vec::with_capacity(group_states.len());

        for group_state in &group_states {
            tracing::debug!(
                "[GROUP BY] aggregate group rows={:?}",
                group_state.row_locations
            );
            // Materialize this group's rows into one contiguous batch by
            // gathering row indices per source batch and concatenating.
            let group_batch = {
                let representative_batch = &batches[group_state.representative_batch_idx];
                let schema = representative_batch.schema();

                // Bucket the group's row indices by source batch, preserving
                // encounter order within each bucket.
                let mut per_batch_indices: Vec<(usize, Vec<u64>)> = Vec::new();
                for &(batch_idx, row_idx) in &group_state.row_locations {
                    if let Some((_, indices)) = per_batch_indices
                        .iter_mut()
                        .find(|(idx, _)| *idx == batch_idx)
                    {
                        indices.push(row_idx as u64);
                    } else {
                        per_batch_indices.push((batch_idx, vec![row_idx as u64]));
                    }
                }

                // Convert each index bucket into an Arrow UInt64 take-index
                // array.
                let mut row_index_arrays: Vec<(usize, ArrayRef)> =
                    Vec::with_capacity(per_batch_indices.len());
                for (batch_idx, indices) in per_batch_indices {
                    let index_array: ArrayRef = Arc::new(arrow::array::UInt64Array::from(indices));
                    row_index_arrays.push((batch_idx, index_array));
                }

                let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());

                for col_idx in 0..schema.fields().len() {
                    // Fast path: all rows come from a single source batch, so
                    // one `take` suffices; otherwise take per batch and
                    // concatenate the partial results.
                    let column_array = if row_index_arrays.len() == 1 {
                        let (batch_idx, indices) = &row_index_arrays[0];
                        let source_array = batches[*batch_idx].column(col_idx);
                        arrow::compute::take(source_array.as_ref(), indices.as_ref(), None)?
                    } else {
                        let mut partial_arrays: Vec<ArrayRef> =
                            Vec::with_capacity(row_index_arrays.len());
                        for (batch_idx, indices) in &row_index_arrays {
                            let source_array = batches[*batch_idx].column(col_idx);
                            let taken = arrow::compute::take(
                                source_array.as_ref(),
                                indices.as_ref(),
                                None,
                            )?;
                            partial_arrays.push(taken);
                        }
                        let slices: Vec<&dyn arrow::array::Array> =
                            partial_arrays.iter().map(|arr| arr.as_ref()).collect();
                        arrow::compute::concat(&slices)?
                    };
                    arrays.push(column_array);
                }

                let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
                tracing::debug!("[GROUP BY] group batch rows={}", batch.num_rows());
                batch
            };

            let mut aggregate_values: FxHashMap<String, PlanValue> = FxHashMap::default();

            // `working_batch` accumulates temporary columns for aggregate
            // inputs that are full expressions rather than simple columns;
            // `next_temp_col_idx` tracks where the next temp column lands.
            let mut working_batch = group_batch.clone();
            let mut next_temp_col_idx = working_batch.num_columns();

            for (key, agg_call) in &aggregate_specs {
                // Determine which column (if any) feeds the accumulator and
                // its value type. COUNT(*) needs neither.
                let (projection_idx, value_type) = match agg_call {
                    AggregateCall::CountStar => (None, None),
                    AggregateCall::Count { expr, .. }
                    | AggregateCall::Sum { expr, .. }
                    | AggregateCall::Total { expr, .. }
                    | AggregateCall::Avg { expr, .. }
                    | AggregateCall::Min(expr)
                    | AggregateCall::Max(expr)
                    | AggregateCall::CountNulls(expr)
                    | AggregateCall::GroupConcat { expr, .. } => {
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            // Simple column reference: aggregate directly
                            // over the existing column.
                            let idx = resolve_column_name_to_index(col_name, &column_lookup_map)
                                .ok_or_else(|| {
                                    Error::InvalidArgumentError(format!(
                                        "column '{}' not found for aggregate",
                                        col_name
                                    ))
                                })?;
                            let field_type = working_batch.schema().field(idx).data_type().clone();
                            (Some(idx), Some(field_type))
                        } else {
                            // Arbitrary expression: evaluate it row by row,
                            // append the result as a temp column, and
                            // aggregate over that column.
                            let mut computed_values = Vec::with_capacity(working_batch.num_rows());
                            for row_idx in 0..working_batch.num_rows() {
                                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                                    expr,
                                    &FxHashMap::default(),
                                    Some(&working_batch),
                                    Some(&column_lookup_map),
                                    row_idx,
                                )?;
                                computed_values.push(value);
                            }

                            let computed_array = plan_values_to_arrow_array(&computed_values)?;
                            let computed_type = computed_array.data_type().clone();

                            let mut new_columns: Vec<ArrayRef> = working_batch.columns().to_vec();
                            new_columns.push(computed_array);

                            let temp_field = Arc::new(Field::new(
                                format!("__temp_agg_expr_{}", next_temp_col_idx),
                                computed_type.clone(),
                                true,
                            ));
                            let mut new_fields: Vec<Arc<Field>> =
                                working_batch.schema().fields().iter().cloned().collect();
                            new_fields.push(temp_field);
                            let new_schema = Arc::new(Schema::new(new_fields));

                            working_batch = RecordBatch::try_new(new_schema, new_columns)?;

                            let col_idx = next_temp_col_idx;
                            next_temp_col_idx += 1;
                            (Some(col_idx), Some(computed_type))
                        }
                    }
                };

                let spec = Self::build_aggregate_spec_for_cross_product(
                    agg_call,
                    key.clone(),
                    value_type.clone(),
                )?;

                let mut state = llkv_aggregate::AggregateState {
                    alias: key.clone(),
                    accumulator: llkv_aggregate::AggregateAccumulator::new_with_projection_index(
                        &spec,
                        projection_idx,
                        None,
                    )?,
                    override_value: None,
                };

                // Feed the whole group batch through the accumulator, then
                // extract the single finalized value (row 0).
                state.update(&working_batch)?;

                let (_field, array) = state.finalize()?;
                let value = llkv_plan::plan_value_from_array(&array, 0)?;
                tracing::debug!(
                    "[GROUP BY] aggregate result key={:?} value={:?}",
                    key,
                    value
                );
                aggregate_values.insert(key.clone(), value);
            }

            group_aggregate_values.push(aggregate_values);
        }

        // NOTE(review): the `&RecordBatch::new_empty(...)` temporary below
        // lives to the end of the statement, so the borrow is sound even in
        // the `batches.is_empty()` case.
        let output_columns = self.build_group_by_output_columns(
            &plan,
            base_schema.as_ref(),
            &column_lookup_map,
            batches
                .first()
                .unwrap_or(&RecordBatch::new_empty(Arc::clone(&base_schema))),
        )?;

        let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(group_states.len());

        // Assemble one output row per group: table columns come from the
        // representative row, computed projections get the aggregate values
        // substituted in.
        for (group_idx, group_state) in group_states.iter().enumerate() {
            let aggregate_values = &group_aggregate_values[group_idx];
            let representative_batch = &batches[group_state.representative_batch_idx];

            let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
            for output in &output_columns {
                match output.source {
                    OutputSource::TableColumn { index } => {
                        let value = llkv_plan::plan_value_from_array(
                            representative_batch.column(index),
                            group_state.representative_row,
                        )?;
                        row.push(value);
                    }
                    OutputSource::Computed { projection_index } => {
                        let expr = match &plan.projections[projection_index] {
                            SelectProjection::Computed { expr, .. } => expr,
                            _ => unreachable!("projection index mismatch for computed column"),
                        };
                        let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            expr,
                            aggregate_values,
                            Some(representative_batch),
                            Some(&column_lookup_map),
                            group_state.representative_row,
                        )?;
                        row.push(value);
                    }
                }
            }
            rows.push(row);
        }

        // HAVING filters the assembled rows; a NULL result (None) drops the
        // group, matching SQL three-valued logic.
        let filtered_rows = if let Some(having) = &plan.having {
            let mut filtered = Vec::new();
            for (row_idx, row) in rows.iter().enumerate() {
                let aggregate_values = &group_aggregate_values[row_idx];
                let group_state = &group_states[row_idx];
                let representative_batch = &batches[group_state.representative_batch_idx];
                let passes = Self::evaluate_having_expr(
                    having,
                    aggregate_values,
                    representative_batch,
                    &column_lookup_map,
                    group_state.representative_row,
                )?;
                if matches!(passes, Some(true)) {
                    filtered.push(row.clone());
                }
            }
            filtered
        } else {
            rows
        };

        let fields: Vec<Field> = output_columns
            .into_iter()
            .map(|output| output.field)
            .collect();
        let schema = Arc::new(Schema::new(fields));

        let mut batch = rows_to_record_batch(Arc::clone(&schema), &filtered_rows)?;

        // DISTINCT and ORDER BY apply to the final grouped output.
        if plan.distinct && batch.num_rows() > 0 {
            let mut state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        if !plan.order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            display_name,
            schema,
            batch,
        ))
    }
5462
5463 fn execute_aggregates(
5464 &self,
5465 table: Arc<ExecutorTable<P>>,
5466 display_name: String,
5467 plan: SelectPlan,
5468 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
5469 ) -> ExecutorResult<SelectExecution<P>> {
5470 let table_ref = table.as_ref();
5471 let distinct = plan.distinct;
5472 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
5473 for aggregate in plan.aggregates {
5474 match aggregate {
5475 AggregateExpr::CountStar { alias, distinct } => {
5476 specs.push(AggregateSpec {
5477 alias,
5478 kind: AggregateKind::Count {
5479 field_id: None,
5480 distinct,
5481 },
5482 });
5483 }
5484 AggregateExpr::Column {
5485 column,
5486 alias,
5487 function,
5488 distinct,
5489 } => {
5490 let col = table_ref.schema.resolve(&column).ok_or_else(|| {
5491 Error::InvalidArgumentError(format!(
5492 "unknown column '{}' in aggregate",
5493 column
5494 ))
5495 })?;
5496
5497 let kind = match function {
5498 AggregateFunction::Count => AggregateKind::Count {
5499 field_id: Some(col.field_id),
5500 distinct,
5501 },
5502 AggregateFunction::SumInt64 => {
5503 let input_type = Self::validate_aggregate_type(
5504 Some(col.data_type.clone()),
5505 "SUM",
5506 &[DataType::Int64, DataType::Float64],
5507 )?;
5508 AggregateKind::Sum {
5509 field_id: col.field_id,
5510 data_type: input_type,
5511 distinct,
5512 }
5513 }
5514 AggregateFunction::TotalInt64 => {
5515 let input_type = Self::validate_aggregate_type(
5516 Some(col.data_type.clone()),
5517 "TOTAL",
5518 &[DataType::Int64, DataType::Float64],
5519 )?;
5520 AggregateKind::Total {
5521 field_id: col.field_id,
5522 data_type: input_type,
5523 distinct,
5524 }
5525 }
5526 AggregateFunction::MinInt64 => {
5527 let input_type = Self::validate_aggregate_type(
5528 Some(col.data_type.clone()),
5529 "MIN",
5530 &[DataType::Int64, DataType::Float64],
5531 )?;
5532 AggregateKind::Min {
5533 field_id: col.field_id,
5534 data_type: input_type,
5535 }
5536 }
5537 AggregateFunction::MaxInt64 => {
5538 let input_type = Self::validate_aggregate_type(
5539 Some(col.data_type.clone()),
5540 "MAX",
5541 &[DataType::Int64, DataType::Float64],
5542 )?;
5543 AggregateKind::Max {
5544 field_id: col.field_id,
5545 data_type: input_type,
5546 }
5547 }
5548 AggregateFunction::CountNulls => {
5549 if distinct {
5550 return Err(Error::InvalidArgumentError(
5551 "DISTINCT is not supported for COUNT_NULLS".into(),
5552 ));
5553 }
5554 AggregateKind::CountNulls {
5555 field_id: col.field_id,
5556 }
5557 }
5558 AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
5559 field_id: col.field_id,
5560 distinct,
5561 separator: ",".to_string(),
5562 },
5563 };
5564 specs.push(AggregateSpec { alias, kind });
5565 }
5566 }
5567 }
5568
5569 if specs.is_empty() {
5570 return Err(Error::InvalidArgumentError(
5571 "aggregate query requires at least one aggregate expression".into(),
5572 ));
5573 }
5574
5575 let had_filter = plan.filter.is_some();
5576 let filter_expr = match &plan.filter {
5577 Some(filter_wrapper) => {
5578 if !filter_wrapper.subqueries.is_empty() {
5579 return Err(Error::InvalidArgumentError(
5580 "EXISTS subqueries not yet implemented in aggregate queries".into(),
5581 ));
5582 }
5583 let mut translated = crate::translation::expression::translate_predicate(
5584 filter_wrapper.predicate.clone(),
5585 table.schema.as_ref(),
5586 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
5587 )?;
5588
5589 let mut filter_scalar_ids = FxHashSet::default();
5591 collect_predicate_scalar_subquery_ids(&translated, &mut filter_scalar_ids);
5592
5593 if !filter_scalar_ids.is_empty() {
5594 let filter_scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> =
5596 plan.scalar_subqueries
5597 .iter()
5598 .filter(|subquery| filter_scalar_ids.contains(&subquery.id))
5599 .map(|subquery| (subquery.id, subquery))
5600 .collect();
5601
5602 let base_schema = Arc::new(Schema::new(Vec::<Field>::new()));
5604 let base_lookup = FxHashMap::default();
5605 let mut context =
5606 CrossProductExpressionContext::new(base_schema.as_ref(), base_lookup)?;
5607 let empty_batch =
5608 RecordBatch::new_empty(Arc::new(Schema::new(Vec::<Field>::new())));
5609
5610 let mut scalar_literals: FxHashMap<SubqueryId, Literal> = FxHashMap::default();
5612 for (subquery_id, subquery) in filter_scalar_lookup.iter() {
5613 let literal = self.evaluate_scalar_subquery_literal(
5614 &mut context,
5615 subquery,
5616 &empty_batch,
5617 0,
5618 )?;
5619 scalar_literals.insert(*subquery_id, literal);
5620 }
5621
5622 translated = rewrite_predicate_scalar_subqueries(translated, &scalar_literals)?;
5624 }
5625
5626 translated
5627 }
5628 None => {
5629 let field_id = table.schema.first_field_id().ok_or_else(|| {
5630 Error::InvalidArgumentError(
5631 "table has no columns; cannot perform aggregate scan".into(),
5632 )
5633 })?;
5634 crate::translation::expression::full_table_scan_filter(field_id)
5635 }
5636 };
5637
5638 let mut projections = Vec::new();
5640 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(specs.len());
5641
5642 for spec in &specs {
5643 if let Some(field_id) = spec.kind.field_id() {
5644 let proj_idx = projections.len();
5645 spec_to_projection.push(Some(proj_idx));
5646 projections.push(ScanProjection::from(StoreProjection::with_alias(
5647 LogicalFieldId::for_user(table.table_id(), field_id),
5648 table
5649 .schema
5650 .column_by_field_id(field_id)
5651 .map(|c| c.name.clone())
5652 .unwrap_or_else(|| format!("col{field_id}")),
5653 )));
5654 } else {
5655 spec_to_projection.push(None);
5656 }
5657 }
5658
5659 if projections.is_empty() {
5660 let field_id = table.schema.first_field_id().ok_or_else(|| {
5661 Error::InvalidArgumentError(
5662 "table has no columns; cannot perform aggregate scan".into(),
5663 )
5664 })?;
5665 projections.push(ScanProjection::from(StoreProjection::with_alias(
5666 LogicalFieldId::for_user(table.table_id(), field_id),
5667 table
5668 .schema
5669 .column_by_field_id(field_id)
5670 .map(|c| c.name.clone())
5671 .unwrap_or_else(|| format!("col{field_id}")),
5672 )));
5673 }
5674
5675 let options = ScanStreamOptions {
5676 include_nulls: true,
5677 order: None,
5678 row_id_filter: row_filter.clone(),
5679 include_row_ids: true,
5680 };
5681
5682 let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
5683 let mut count_star_override: Option<i64> = None;
5687 if !had_filter && row_filter.is_none() {
5688 let total_rows = table.total_rows.load(Ordering::SeqCst);
5690 tracing::debug!(
5691 "[AGGREGATE] Using COUNT(*) shortcut: total_rows={}",
5692 total_rows
5693 );
5694 if total_rows > i64::MAX as u64 {
5695 return Err(Error::InvalidArgumentError(
5696 "COUNT(*) result exceeds supported range".into(),
5697 ));
5698 }
5699 count_star_override = Some(total_rows as i64);
5700 } else {
5701 tracing::debug!(
5702 "[AGGREGATE] NOT using COUNT(*) shortcut: had_filter={}, has_row_filter={}",
5703 had_filter,
5704 row_filter.is_some()
5705 );
5706 }
5707
5708 for (idx, spec) in specs.iter().enumerate() {
5709 states.push(AggregateState {
5710 alias: spec.alias.clone(),
5711 accumulator: AggregateAccumulator::new_with_projection_index(
5712 spec,
5713 spec_to_projection[idx],
5714 count_star_override,
5715 )?,
5716 override_value: match &spec.kind {
5717 AggregateKind::Count { field_id: None, .. } => {
5718 tracing::debug!(
5719 "[AGGREGATE] CountStar override_value={:?}",
5720 count_star_override
5721 );
5722 count_star_override
5723 }
5724 _ => None,
5725 },
5726 });
5727 }
5728
5729 let mut error: Option<Error> = None;
5730 match table.storage().scan_stream(
5731 &projections,
5732 &filter_expr,
5733 ScanStreamOptions {
5734 row_id_filter: row_filter.clone(),
5735 ..options
5736 },
5737 &mut |batch| {
5738 if error.is_some() {
5739 return;
5740 }
5741 for state in &mut states {
5742 if let Err(err) = state.update(&batch) {
5743 error = Some(err);
5744 return;
5745 }
5746 }
5747 },
5748 ) {
5749 Ok(()) => {}
5750 Err(llkv_result::Error::NotFound) => {
5751 }
5754 Err(err) => return Err(err),
5755 }
5756 if let Some(err) = error {
5757 return Err(err);
5758 }
5759
5760 let mut fields = Vec::with_capacity(states.len());
5761 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
5762 for state in states {
5763 let (field, array) = state.finalize()?;
5764 fields.push(field);
5765 arrays.push(array);
5766 }
5767
5768 let schema = Arc::new(Schema::new(fields));
5769 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
5770
5771 if distinct {
5772 let mut state = DistinctState::default();
5773 batch = match distinct_filter_batch(batch, &mut state)? {
5774 Some(filtered) => filtered,
5775 None => RecordBatch::new_empty(Arc::clone(&schema)),
5776 };
5777 }
5778
5779 let schema = batch.schema();
5780
5781 Ok(SelectExecution::new_single_batch(
5782 display_name,
5783 schema,
5784 batch,
5785 ))
5786 }
5787
5788 fn execute_computed_aggregates(
5791 &self,
5792 table: Arc<ExecutorTable<P>>,
5793 display_name: String,
5794 plan: SelectPlan,
5795 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
5796 ) -> ExecutorResult<SelectExecution<P>> {
5797 use arrow::array::Int64Array;
5798 use llkv_expr::expr::AggregateCall;
5799
5800 let table_ref = table.as_ref();
5801 let distinct = plan.distinct;
5802
5803 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
5805 for proj in &plan.projections {
5806 if let SelectProjection::Computed { expr, .. } = proj {
5807 Self::collect_aggregates(expr, &mut aggregate_specs);
5808 }
5809 }
5810
5811 let filter_predicate = plan
5813 .filter
5814 .as_ref()
5815 .map(|wrapper| {
5816 if !wrapper.subqueries.is_empty() {
5817 return Err(Error::InvalidArgumentError(
5818 "EXISTS subqueries not yet implemented with aggregates".into(),
5819 ));
5820 }
5821 Ok(wrapper.predicate.clone())
5822 })
5823 .transpose()?;
5824
5825 let computed_aggregates = self.compute_aggregate_values(
5826 table.clone(),
5827 &filter_predicate,
5828 &aggregate_specs,
5829 row_filter.clone(),
5830 )?;
5831
5832 let mut fields = Vec::with_capacity(plan.projections.len());
5834 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
5835
5836 for proj in &plan.projections {
5837 match proj {
5838 SelectProjection::AllColumns | SelectProjection::AllColumnsExcept { .. } => {
5839 return Err(Error::InvalidArgumentError(
5840 "Wildcard projections not supported with computed aggregates".into(),
5841 ));
5842 }
5843 SelectProjection::Column { name, alias } => {
5844 let col = table_ref.schema.resolve(name).ok_or_else(|| {
5845 Error::InvalidArgumentError(format!("unknown column '{}'", name))
5846 })?;
5847 let field_name = alias.as_ref().unwrap_or(name);
5848 fields.push(arrow::datatypes::Field::new(
5849 field_name,
5850 col.data_type.clone(),
5851 col.nullable,
5852 ));
5853 return Err(Error::InvalidArgumentError(
5856 "Regular columns not supported in aggregate queries without GROUP BY"
5857 .into(),
5858 ));
5859 }
5860 SelectProjection::Computed { expr, alias } => {
5861 if let ScalarExpr::Aggregate(agg) = expr {
5863 let key = format!("{:?}", agg);
5864 if let Some(agg_value) = computed_aggregates.get(&key) {
5865 match agg_value {
5866 AggregateValue::Null => {
5867 fields.push(arrow::datatypes::Field::new(
5868 alias,
5869 DataType::Int64,
5870 true,
5871 ));
5872 arrays
5873 .push(Arc::new(Int64Array::from(vec![None::<i64>]))
5874 as ArrayRef);
5875 }
5876 AggregateValue::Int64(v) => {
5877 fields.push(arrow::datatypes::Field::new(
5878 alias,
5879 DataType::Int64,
5880 true,
5881 ));
5882 arrays.push(
5883 Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef
5884 );
5885 }
5886 AggregateValue::Float64(v) => {
5887 fields.push(arrow::datatypes::Field::new(
5888 alias,
5889 DataType::Float64,
5890 true,
5891 ));
5892 arrays
5893 .push(Arc::new(Float64Array::from(vec![Some(*v)]))
5894 as ArrayRef);
5895 }
5896 AggregateValue::Decimal128 { value, scale } => {
5897 let precision = if *value == 0 {
5899 1
5900 } else {
5901 (*value).abs().to_string().len() as u8
5902 };
5903 fields.push(arrow::datatypes::Field::new(
5904 alias,
5905 DataType::Decimal128(precision, *scale),
5906 true,
5907 ));
5908 let array = Decimal128Array::from(vec![Some(*value)])
5909 .with_precision_and_scale(precision, *scale)
5910 .map_err(|e| {
5911 Error::Internal(format!("invalid Decimal128: {}", e))
5912 })?;
5913 arrays.push(Arc::new(array) as ArrayRef);
5914 }
5915 AggregateValue::String(s) => {
5916 fields.push(arrow::datatypes::Field::new(
5917 alias,
5918 DataType::Utf8,
5919 true,
5920 ));
5921 arrays
5922 .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
5923 as ArrayRef);
5924 }
5925 }
5926 continue;
5927 }
5928 }
5929
5930 let value = Self::evaluate_expr_with_aggregates(expr, &computed_aggregates)?;
5932
5933 fields.push(arrow::datatypes::Field::new(alias, DataType::Int64, true));
5934
5935 let array = Arc::new(Int64Array::from(vec![value])) as ArrayRef;
5936 arrays.push(array);
5937 }
5938 }
5939 }
5940
5941 let schema = Arc::new(Schema::new(fields));
5942 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
5943
5944 if distinct {
5945 let mut state = DistinctState::default();
5946 batch = match distinct_filter_batch(batch, &mut state)? {
5947 Some(filtered) => filtered,
5948 None => RecordBatch::new_empty(Arc::clone(&schema)),
5949 };
5950 }
5951
5952 let schema = batch.schema();
5953
5954 Ok(SelectExecution::new_single_batch(
5955 display_name,
5956 schema,
5957 batch,
5958 ))
5959 }
5960
5961 fn build_aggregate_spec_for_cross_product(
5964 agg_call: &llkv_expr::expr::AggregateCall<String>,
5965 alias: String,
5966 data_type: Option<DataType>,
5967 ) -> ExecutorResult<llkv_aggregate::AggregateSpec> {
5968 use llkv_expr::expr::AggregateCall;
5969
5970 let kind = match agg_call {
5971 AggregateCall::CountStar => llkv_aggregate::AggregateKind::Count {
5972 field_id: None,
5973 distinct: false,
5974 },
5975 AggregateCall::Count { distinct, .. } => llkv_aggregate::AggregateKind::Count {
5976 field_id: Some(0),
5977 distinct: *distinct,
5978 },
5979 AggregateCall::Sum { distinct, .. } => llkv_aggregate::AggregateKind::Sum {
5980 field_id: 0,
5981 data_type: Self::validate_aggregate_type(
5982 data_type.clone(),
5983 "SUM",
5984 &[DataType::Int64, DataType::Float64],
5985 )?,
5986 distinct: *distinct,
5987 },
5988 AggregateCall::Total { distinct, .. } => llkv_aggregate::AggregateKind::Total {
5989 field_id: 0,
5990 data_type: Self::validate_aggregate_type(
5991 data_type.clone(),
5992 "TOTAL",
5993 &[DataType::Int64, DataType::Float64],
5994 )?,
5995 distinct: *distinct,
5996 },
5997 AggregateCall::Avg { distinct, .. } => llkv_aggregate::AggregateKind::Avg {
5998 field_id: 0,
5999 data_type: Self::validate_aggregate_type(
6000 data_type.clone(),
6001 "AVG",
6002 &[DataType::Int64, DataType::Float64],
6003 )?,
6004 distinct: *distinct,
6005 },
6006 AggregateCall::Min(_) => llkv_aggregate::AggregateKind::Min {
6007 field_id: 0,
6008 data_type: Self::validate_aggregate_type(
6009 data_type.clone(),
6010 "MIN",
6011 &[DataType::Int64, DataType::Float64],
6012 )?,
6013 },
6014 AggregateCall::Max(_) => llkv_aggregate::AggregateKind::Max {
6015 field_id: 0,
6016 data_type: Self::validate_aggregate_type(
6017 data_type.clone(),
6018 "MAX",
6019 &[DataType::Int64, DataType::Float64],
6020 )?,
6021 },
6022 AggregateCall::CountNulls(_) => {
6023 llkv_aggregate::AggregateKind::CountNulls { field_id: 0 }
6024 }
6025 AggregateCall::GroupConcat {
6026 distinct,
6027 separator,
6028 ..
6029 } => llkv_aggregate::AggregateKind::GroupConcat {
6030 field_id: 0,
6031 distinct: *distinct,
6032 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
6033 },
6034 };
6035
6036 Ok(llkv_aggregate::AggregateSpec { alias, kind })
6037 }
6038
6039 fn validate_aggregate_type(
6051 data_type: Option<DataType>,
6052 func_name: &str,
6053 allowed: &[DataType],
6054 ) -> ExecutorResult<DataType> {
6055 let dt = data_type.ok_or_else(|| {
6056 Error::Internal(format!(
6057 "missing input type metadata for {func_name} aggregate"
6058 ))
6059 })?;
6060
6061 if matches!(func_name, "SUM" | "AVG" | "TOTAL" | "MIN" | "MAX") {
6064 match dt {
6065 DataType::Int64 | DataType::Float64 | DataType::Decimal128(_, _) => Ok(dt),
6067
6068 DataType::Utf8 | DataType::Boolean | DataType::Date32 => Ok(DataType::Float64),
6071
6072 DataType::Null => Ok(DataType::Float64),
6075
6076 _ => Err(Error::InvalidArgumentError(format!(
6077 "{func_name} aggregate not supported for column type {:?}",
6078 dt
6079 ))),
6080 }
6081 } else {
6082 if allowed.iter().any(|candidate| candidate == &dt) {
6084 Ok(dt)
6085 } else {
6086 Err(Error::InvalidArgumentError(format!(
6087 "{func_name} aggregate not supported for column type {:?}",
6088 dt
6089 )))
6090 }
6091 }
6092 }
6093
6094 fn collect_aggregates(
6096 expr: &ScalarExpr<String>,
6097 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
6098 ) {
6099 match expr {
6100 ScalarExpr::Aggregate(agg) => {
6101 let key = format!("{:?}", agg);
6103 if !aggregates.iter().any(|(k, _)| k == &key) {
6104 aggregates.push((key, agg.clone()));
6105 }
6106 }
6107 ScalarExpr::Binary { left, right, .. } => {
6108 Self::collect_aggregates(left, aggregates);
6109 Self::collect_aggregates(right, aggregates);
6110 }
6111 ScalarExpr::Compare { left, right, .. } => {
6112 Self::collect_aggregates(left, aggregates);
6113 Self::collect_aggregates(right, aggregates);
6114 }
6115 ScalarExpr::GetField { base, .. } => {
6116 Self::collect_aggregates(base, aggregates);
6117 }
6118 ScalarExpr::Cast { expr, .. } => {
6119 Self::collect_aggregates(expr, aggregates);
6120 }
6121 ScalarExpr::Not(expr) => {
6122 Self::collect_aggregates(expr, aggregates);
6123 }
6124 ScalarExpr::IsNull { expr, .. } => {
6125 Self::collect_aggregates(expr, aggregates);
6126 }
6127 ScalarExpr::Case {
6128 operand,
6129 branches,
6130 else_expr,
6131 } => {
6132 if let Some(inner) = operand.as_deref() {
6133 Self::collect_aggregates(inner, aggregates);
6134 }
6135 for (when_expr, then_expr) in branches {
6136 Self::collect_aggregates(when_expr, aggregates);
6137 Self::collect_aggregates(then_expr, aggregates);
6138 }
6139 if let Some(inner) = else_expr.as_deref() {
6140 Self::collect_aggregates(inner, aggregates);
6141 }
6142 }
6143 ScalarExpr::Coalesce(items) => {
6144 for item in items {
6145 Self::collect_aggregates(item, aggregates);
6146 }
6147 }
6148 ScalarExpr::Column(_) | ScalarExpr::Literal(_) | ScalarExpr::Random => {}
6149 ScalarExpr::ScalarSubquery(_) => {}
6150 }
6151 }
6152
6153 fn collect_aggregates_from_predicate(
6155 expr: &llkv_expr::expr::Expr<String>,
6156 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
6157 ) {
6158 match expr {
6159 llkv_expr::expr::Expr::Compare { left, right, .. } => {
6160 Self::collect_aggregates(left, aggregates);
6161 Self::collect_aggregates(right, aggregates);
6162 }
6163 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
6164 for e in exprs {
6165 Self::collect_aggregates_from_predicate(e, aggregates);
6166 }
6167 }
6168 llkv_expr::expr::Expr::Not(inner) => {
6169 Self::collect_aggregates_from_predicate(inner, aggregates);
6170 }
6171 llkv_expr::expr::Expr::InList {
6172 expr: test_expr,
6173 list,
6174 ..
6175 } => {
6176 Self::collect_aggregates(test_expr, aggregates);
6177 for item in list {
6178 Self::collect_aggregates(item, aggregates);
6179 }
6180 }
6181 llkv_expr::expr::Expr::IsNull { expr, .. } => {
6182 Self::collect_aggregates(expr, aggregates);
6183 }
6184 llkv_expr::expr::Expr::Literal(_) => {}
6185 llkv_expr::expr::Expr::Pred(_) => {}
6186 llkv_expr::expr::Expr::Exists(_) => {}
6187 }
6188 }
6189
6190 fn compute_aggregate_values(
6192 &self,
6193 table: Arc<ExecutorTable<P>>,
6194 filter: &Option<llkv_expr::expr::Expr<'static, String>>,
6195 aggregate_specs: &[(String, llkv_expr::expr::AggregateCall<String>)],
6196 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
6197 ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
6198 use llkv_expr::expr::AggregateCall;
6199
6200 let table_ref = table.as_ref();
6201 let mut results =
6202 FxHashMap::with_capacity_and_hasher(aggregate_specs.len(), Default::default());
6203
6204 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
6205 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());
6206 let mut projections: Vec<ScanProjection> = Vec::new();
6207 let mut column_projection_cache: FxHashMap<FieldId, usize> = FxHashMap::default();
6208 let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
6209 FxHashMap::default();
6210 let mut computed_alias_counter: usize = 0;
6211
6212 for (key, agg) in aggregate_specs {
6213 match agg {
6214 AggregateCall::CountStar => {
6215 specs.push(AggregateSpec {
6216 alias: key.clone(),
6217 kind: AggregateKind::Count {
6218 field_id: None,
6219 distinct: false,
6220 },
6221 });
6222 spec_to_projection.push(None);
6223 }
6224 AggregateCall::Count { expr, distinct } => {
6225 if let Some(col_name) = try_extract_simple_column(expr) {
6226 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6227 Error::InvalidArgumentError(format!(
6228 "unknown column '{}' in aggregate",
6229 col_name
6230 ))
6231 })?;
6232 let projection_index = get_or_insert_column_projection(
6233 &mut projections,
6234 &mut column_projection_cache,
6235 table_ref,
6236 col,
6237 );
6238 specs.push(AggregateSpec {
6239 alias: key.clone(),
6240 kind: AggregateKind::Count {
6241 field_id: Some(col.field_id),
6242 distinct: *distinct,
6243 },
6244 });
6245 spec_to_projection.push(Some(projection_index));
6246 } else {
6247 let (projection_index, _dtype) = ensure_computed_projection(
6248 expr,
6249 table_ref,
6250 &mut projections,
6251 &mut computed_projection_cache,
6252 &mut computed_alias_counter,
6253 )?;
6254 let field_id = u32::try_from(projection_index).map_err(|_| {
6255 Error::InvalidArgumentError(
6256 "aggregate projection index exceeds supported range".into(),
6257 )
6258 })?;
6259 specs.push(AggregateSpec {
6260 alias: key.clone(),
6261 kind: AggregateKind::Count {
6262 field_id: Some(field_id),
6263 distinct: *distinct,
6264 },
6265 });
6266 spec_to_projection.push(Some(projection_index));
6267 }
6268 }
6269 AggregateCall::Sum { expr, distinct } => {
6270 let (projection_index, data_type, field_id) =
6271 if let Some(col_name) = try_extract_simple_column(expr) {
6272 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6273 Error::InvalidArgumentError(format!(
6274 "unknown column '{}' in aggregate",
6275 col_name
6276 ))
6277 })?;
6278 let projection_index = get_or_insert_column_projection(
6279 &mut projections,
6280 &mut column_projection_cache,
6281 table_ref,
6282 col,
6283 );
6284 let data_type = col.data_type.clone();
6285 (projection_index, data_type, col.field_id)
6286 } else {
6287 let (projection_index, inferred_type) = ensure_computed_projection(
6288 expr,
6289 table_ref,
6290 &mut projections,
6291 &mut computed_projection_cache,
6292 &mut computed_alias_counter,
6293 )?;
6294 let field_id = u32::try_from(projection_index).map_err(|_| {
6295 Error::InvalidArgumentError(
6296 "aggregate projection index exceeds supported range".into(),
6297 )
6298 })?;
6299 (projection_index, inferred_type, field_id)
6300 };
6301 let normalized_type = Self::validate_aggregate_type(
6302 Some(data_type.clone()),
6303 "SUM",
6304 &[DataType::Int64, DataType::Float64],
6305 )?;
6306 specs.push(AggregateSpec {
6307 alias: key.clone(),
6308 kind: AggregateKind::Sum {
6309 field_id,
6310 data_type: normalized_type,
6311 distinct: *distinct,
6312 },
6313 });
6314 spec_to_projection.push(Some(projection_index));
6315 }
6316 AggregateCall::Total { expr, distinct } => {
6317 let (projection_index, data_type, field_id) =
6318 if let Some(col_name) = try_extract_simple_column(expr) {
6319 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6320 Error::InvalidArgumentError(format!(
6321 "unknown column '{}' in aggregate",
6322 col_name
6323 ))
6324 })?;
6325 let projection_index = get_or_insert_column_projection(
6326 &mut projections,
6327 &mut column_projection_cache,
6328 table_ref,
6329 col,
6330 );
6331 let data_type = col.data_type.clone();
6332 (projection_index, data_type, col.field_id)
6333 } else {
6334 let (projection_index, inferred_type) = ensure_computed_projection(
6335 expr,
6336 table_ref,
6337 &mut projections,
6338 &mut computed_projection_cache,
6339 &mut computed_alias_counter,
6340 )?;
6341 let field_id = u32::try_from(projection_index).map_err(|_| {
6342 Error::InvalidArgumentError(
6343 "aggregate projection index exceeds supported range".into(),
6344 )
6345 })?;
6346 (projection_index, inferred_type, field_id)
6347 };
6348 let normalized_type = Self::validate_aggregate_type(
6349 Some(data_type.clone()),
6350 "TOTAL",
6351 &[DataType::Int64, DataType::Float64],
6352 )?;
6353 specs.push(AggregateSpec {
6354 alias: key.clone(),
6355 kind: AggregateKind::Total {
6356 field_id,
6357 data_type: normalized_type,
6358 distinct: *distinct,
6359 },
6360 });
6361 spec_to_projection.push(Some(projection_index));
6362 }
6363 AggregateCall::Avg { expr, distinct } => {
6364 let (projection_index, data_type, field_id) =
6365 if let Some(col_name) = try_extract_simple_column(expr) {
6366 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6367 Error::InvalidArgumentError(format!(
6368 "unknown column '{}' in aggregate",
6369 col_name
6370 ))
6371 })?;
6372 let projection_index = get_or_insert_column_projection(
6373 &mut projections,
6374 &mut column_projection_cache,
6375 table_ref,
6376 col,
6377 );
6378 let data_type = col.data_type.clone();
6379 (projection_index, data_type, col.field_id)
6380 } else {
6381 let (projection_index, inferred_type) = ensure_computed_projection(
6382 expr,
6383 table_ref,
6384 &mut projections,
6385 &mut computed_projection_cache,
6386 &mut computed_alias_counter,
6387 )?;
6388 tracing::debug!(
6389 "AVG aggregate expr={:?} inferred_type={:?}",
6390 expr,
6391 inferred_type
6392 );
6393 let field_id = u32::try_from(projection_index).map_err(|_| {
6394 Error::InvalidArgumentError(
6395 "aggregate projection index exceeds supported range".into(),
6396 )
6397 })?;
6398 (projection_index, inferred_type, field_id)
6399 };
6400 let normalized_type = Self::validate_aggregate_type(
6401 Some(data_type.clone()),
6402 "AVG",
6403 &[DataType::Int64, DataType::Float64],
6404 )?;
6405 specs.push(AggregateSpec {
6406 alias: key.clone(),
6407 kind: AggregateKind::Avg {
6408 field_id,
6409 data_type: normalized_type,
6410 distinct: *distinct,
6411 },
6412 });
6413 spec_to_projection.push(Some(projection_index));
6414 }
6415 AggregateCall::Min(expr) => {
6416 let (projection_index, data_type, field_id) =
6417 if let Some(col_name) = try_extract_simple_column(expr) {
6418 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6419 Error::InvalidArgumentError(format!(
6420 "unknown column '{}' in aggregate",
6421 col_name
6422 ))
6423 })?;
6424 let projection_index = get_or_insert_column_projection(
6425 &mut projections,
6426 &mut column_projection_cache,
6427 table_ref,
6428 col,
6429 );
6430 let data_type = col.data_type.clone();
6431 (projection_index, data_type, col.field_id)
6432 } else {
6433 let (projection_index, inferred_type) = ensure_computed_projection(
6434 expr,
6435 table_ref,
6436 &mut projections,
6437 &mut computed_projection_cache,
6438 &mut computed_alias_counter,
6439 )?;
6440 let field_id = u32::try_from(projection_index).map_err(|_| {
6441 Error::InvalidArgumentError(
6442 "aggregate projection index exceeds supported range".into(),
6443 )
6444 })?;
6445 (projection_index, inferred_type, field_id)
6446 };
6447 let normalized_type = Self::validate_aggregate_type(
6448 Some(data_type.clone()),
6449 "MIN",
6450 &[DataType::Int64, DataType::Float64],
6451 )?;
6452 specs.push(AggregateSpec {
6453 alias: key.clone(),
6454 kind: AggregateKind::Min {
6455 field_id,
6456 data_type: normalized_type,
6457 },
6458 });
6459 spec_to_projection.push(Some(projection_index));
6460 }
6461 AggregateCall::Max(expr) => {
6462 let (projection_index, data_type, field_id) =
6463 if let Some(col_name) = try_extract_simple_column(expr) {
6464 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6465 Error::InvalidArgumentError(format!(
6466 "unknown column '{}' in aggregate",
6467 col_name
6468 ))
6469 })?;
6470 let projection_index = get_or_insert_column_projection(
6471 &mut projections,
6472 &mut column_projection_cache,
6473 table_ref,
6474 col,
6475 );
6476 let data_type = col.data_type.clone();
6477 (projection_index, data_type, col.field_id)
6478 } else {
6479 let (projection_index, inferred_type) = ensure_computed_projection(
6480 expr,
6481 table_ref,
6482 &mut projections,
6483 &mut computed_projection_cache,
6484 &mut computed_alias_counter,
6485 )?;
6486 let field_id = u32::try_from(projection_index).map_err(|_| {
6487 Error::InvalidArgumentError(
6488 "aggregate projection index exceeds supported range".into(),
6489 )
6490 })?;
6491 (projection_index, inferred_type, field_id)
6492 };
6493 let normalized_type = Self::validate_aggregate_type(
6494 Some(data_type.clone()),
6495 "MAX",
6496 &[DataType::Int64, DataType::Float64],
6497 )?;
6498 specs.push(AggregateSpec {
6499 alias: key.clone(),
6500 kind: AggregateKind::Max {
6501 field_id,
6502 data_type: normalized_type,
6503 },
6504 });
6505 spec_to_projection.push(Some(projection_index));
6506 }
6507 AggregateCall::CountNulls(expr) => {
6508 if let Some(col_name) = try_extract_simple_column(expr) {
6509 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6510 Error::InvalidArgumentError(format!(
6511 "unknown column '{}' in aggregate",
6512 col_name
6513 ))
6514 })?;
6515 let projection_index = get_or_insert_column_projection(
6516 &mut projections,
6517 &mut column_projection_cache,
6518 table_ref,
6519 col,
6520 );
6521 specs.push(AggregateSpec {
6522 alias: key.clone(),
6523 kind: AggregateKind::CountNulls {
6524 field_id: col.field_id,
6525 },
6526 });
6527 spec_to_projection.push(Some(projection_index));
6528 } else {
6529 let (projection_index, _dtype) = ensure_computed_projection(
6530 expr,
6531 table_ref,
6532 &mut projections,
6533 &mut computed_projection_cache,
6534 &mut computed_alias_counter,
6535 )?;
6536 let field_id = u32::try_from(projection_index).map_err(|_| {
6537 Error::InvalidArgumentError(
6538 "aggregate projection index exceeds supported range".into(),
6539 )
6540 })?;
6541 specs.push(AggregateSpec {
6542 alias: key.clone(),
6543 kind: AggregateKind::CountNulls { field_id },
6544 });
6545 spec_to_projection.push(Some(projection_index));
6546 }
6547 }
6548 AggregateCall::GroupConcat {
6549 expr,
6550 distinct,
6551 separator,
6552 } => {
6553 if let Some(col_name) = try_extract_simple_column(expr) {
6554 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
6555 Error::InvalidArgumentError(format!(
6556 "unknown column '{}' in aggregate",
6557 col_name
6558 ))
6559 })?;
6560 let projection_index = get_or_insert_column_projection(
6561 &mut projections,
6562 &mut column_projection_cache,
6563 table_ref,
6564 col,
6565 );
6566 specs.push(AggregateSpec {
6567 alias: key.clone(),
6568 kind: AggregateKind::GroupConcat {
6569 field_id: col.field_id,
6570 distinct: *distinct,
6571 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
6572 },
6573 });
6574 spec_to_projection.push(Some(projection_index));
6575 } else {
6576 let (projection_index, _dtype) = ensure_computed_projection(
6577 expr,
6578 table_ref,
6579 &mut projections,
6580 &mut computed_projection_cache,
6581 &mut computed_alias_counter,
6582 )?;
6583 let field_id = u32::try_from(projection_index).map_err(|_| {
6584 Error::InvalidArgumentError(
6585 "aggregate projection index exceeds supported range".into(),
6586 )
6587 })?;
6588 specs.push(AggregateSpec {
6589 alias: key.clone(),
6590 kind: AggregateKind::GroupConcat {
6591 field_id,
6592 distinct: *distinct,
6593 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
6594 },
6595 });
6596 spec_to_projection.push(Some(projection_index));
6597 }
6598 }
6599 }
6600 }
6601
6602 let filter_expr = match filter {
6603 Some(expr) => crate::translation::expression::translate_predicate(
6604 expr.clone(),
6605 table_ref.schema.as_ref(),
6606 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
6607 )?,
6608 None => {
6609 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
6610 Error::InvalidArgumentError(
6611 "table has no columns; cannot perform aggregate scan".into(),
6612 )
6613 })?;
6614 crate::translation::expression::full_table_scan_filter(field_id)
6615 }
6616 };
6617
6618 if projections.is_empty() {
6619 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
6620 Error::InvalidArgumentError(
6621 "table has no columns; cannot perform aggregate scan".into(),
6622 )
6623 })?;
6624 projections.push(ScanProjection::from(StoreProjection::with_alias(
6625 LogicalFieldId::for_user(table.table_id(), field_id),
6626 table
6627 .schema
6628 .column_by_field_id(field_id)
6629 .map(|c| c.name.clone())
6630 .unwrap_or_else(|| format!("col{field_id}")),
6631 )));
6632 }
6633
6634 let base_options = ScanStreamOptions {
6635 include_nulls: true,
6636 order: None,
6637 row_id_filter: None,
6638 include_row_ids: true,
6639 };
6640
6641 let count_star_override: Option<i64> = None;
6642
6643 let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
6644 for (idx, spec) in specs.iter().enumerate() {
6645 states.push(AggregateState {
6646 alias: spec.alias.clone(),
6647 accumulator: AggregateAccumulator::new_with_projection_index(
6648 spec,
6649 spec_to_projection[idx],
6650 count_star_override,
6651 )?,
6652 override_value: match &spec.kind {
6653 AggregateKind::Count { field_id: None, .. } => count_star_override,
6654 _ => None,
6655 },
6656 });
6657 }
6658
6659 let mut error: Option<Error> = None;
6660 match table.storage().scan_stream(
6661 &projections,
6662 &filter_expr,
6663 ScanStreamOptions {
6664 row_id_filter: row_filter.clone(),
6665 ..base_options
6666 },
6667 &mut |batch| {
6668 if error.is_some() {
6669 return;
6670 }
6671 for state in &mut states {
6672 if let Err(err) = state.update(&batch) {
6673 error = Some(err);
6674 return;
6675 }
6676 }
6677 },
6678 ) {
6679 Ok(()) => {}
6680 Err(llkv_result::Error::NotFound) => {}
6681 Err(err) => return Err(err),
6682 }
6683 if let Some(err) = error {
6684 return Err(err);
6685 }
6686
6687 for state in states {
6688 let alias = state.alias.clone();
6689 let (_field, array) = state.finalize()?;
6690
6691 if let Some(int64_array) = array.as_any().downcast_ref::<arrow::array::Int64Array>() {
6692 if int64_array.len() != 1 {
6693 return Err(Error::Internal(format!(
6694 "Expected single value from aggregate, got {}",
6695 int64_array.len()
6696 )));
6697 }
6698 let value = if int64_array.is_null(0) {
6699 AggregateValue::Null
6700 } else {
6701 AggregateValue::Int64(int64_array.value(0))
6702 };
6703 results.insert(alias, value);
6704 } else if let Some(float64_array) =
6705 array.as_any().downcast_ref::<arrow::array::Float64Array>()
6706 {
6707 if float64_array.len() != 1 {
6708 return Err(Error::Internal(format!(
6709 "Expected single value from aggregate, got {}",
6710 float64_array.len()
6711 )));
6712 }
6713 let value = if float64_array.is_null(0) {
6714 AggregateValue::Null
6715 } else {
6716 AggregateValue::Float64(float64_array.value(0))
6717 };
6718 results.insert(alias, value);
6719 } else if let Some(string_array) =
6720 array.as_any().downcast_ref::<arrow::array::StringArray>()
6721 {
6722 if string_array.len() != 1 {
6723 return Err(Error::Internal(format!(
6724 "Expected single value from aggregate, got {}",
6725 string_array.len()
6726 )));
6727 }
6728 let value = if string_array.is_null(0) {
6729 AggregateValue::Null
6730 } else {
6731 AggregateValue::String(string_array.value(0).to_string())
6732 };
6733 results.insert(alias, value);
6734 } else if let Some(decimal_array) = array
6735 .as_any()
6736 .downcast_ref::<arrow::array::Decimal128Array>()
6737 {
6738 if decimal_array.len() != 1 {
6739 return Err(Error::Internal(format!(
6740 "Expected single value from aggregate, got {}",
6741 decimal_array.len()
6742 )));
6743 }
6744 let value = if decimal_array.is_null(0) {
6745 AggregateValue::Null
6746 } else {
6747 AggregateValue::Decimal128 {
6748 value: decimal_array.value(0),
6749 scale: decimal_array.scale(),
6750 }
6751 };
6752 results.insert(alias, value);
6753 } else {
6754 return Err(Error::Internal(format!(
6755 "Unexpected array type from aggregate: {:?}",
6756 array.data_type()
6757 )));
6758 }
6759 }
6760
6761 Ok(results)
6762 }
6763
6764 fn evaluate_having_expr(
6765 expr: &llkv_expr::expr::Expr<String>,
6766 aggregates: &FxHashMap<String, PlanValue>,
6767 row_batch: &RecordBatch,
6768 column_lookup: &FxHashMap<String, usize>,
6769 row_idx: usize,
6770 ) -> ExecutorResult<Option<bool>> {
6771 fn compare_plan_values_for_pred(
6772 left: &PlanValue,
6773 right: &PlanValue,
6774 ) -> Option<std::cmp::Ordering> {
6775 match (left, right) {
6776 (PlanValue::Integer(l), PlanValue::Integer(r)) => Some(l.cmp(r)),
6777 (PlanValue::Float(l), PlanValue::Float(r)) => l.partial_cmp(r),
6778 (PlanValue::Integer(l), PlanValue::Float(r)) => (*l as f64).partial_cmp(r),
6779 (PlanValue::Float(l), PlanValue::Integer(r)) => l.partial_cmp(&(*r as f64)),
6780 (PlanValue::String(l), PlanValue::String(r)) => Some(l.cmp(r)),
6781 (PlanValue::Interval(l), PlanValue::Interval(r)) => {
6782 Some(compare_interval_values(*l, *r))
6783 }
6784 _ => None,
6785 }
6786 }
6787
6788 fn evaluate_ordering_predicate<F>(
6789 value: &PlanValue,
6790 literal: &Literal,
6791 predicate: F,
6792 ) -> ExecutorResult<Option<bool>>
6793 where
6794 F: Fn(std::cmp::Ordering) -> bool,
6795 {
6796 if matches!(value, PlanValue::Null) {
6797 return Ok(None);
6798 }
6799 let expected = llkv_plan::plan_value_from_literal(literal)?;
6800 if matches!(expected, PlanValue::Null) {
6801 return Ok(None);
6802 }
6803
6804 match compare_plan_values_for_pred(value, &expected) {
6805 Some(ordering) => Ok(Some(predicate(ordering))),
6806 None => Err(Error::InvalidArgumentError(
6807 "unsupported HAVING comparison between column value and literal".into(),
6808 )),
6809 }
6810 }
6811
6812 match expr {
6813 llkv_expr::expr::Expr::Compare { left, op, right } => {
6814 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6815 left,
6816 aggregates,
6817 Some(row_batch),
6818 Some(column_lookup),
6819 row_idx,
6820 )?;
6821 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6822 right,
6823 aggregates,
6824 Some(row_batch),
6825 Some(column_lookup),
6826 row_idx,
6827 )?;
6828
6829 let (left_val, right_val) = match (&left_val, &right_val) {
6831 (PlanValue::Integer(i), PlanValue::Float(_)) => {
6832 (PlanValue::Float(*i as f64), right_val)
6833 }
6834 (PlanValue::Float(_), PlanValue::Integer(i)) => {
6835 (left_val, PlanValue::Float(*i as f64))
6836 }
6837 _ => (left_val, right_val),
6838 };
6839
6840 match (left_val, right_val) {
6841 (PlanValue::Null, _) | (_, PlanValue::Null) => Ok(None),
6843 (PlanValue::Integer(l), PlanValue::Integer(r)) => {
6844 use llkv_expr::expr::CompareOp;
6845 Ok(Some(match op {
6846 CompareOp::Eq => l == r,
6847 CompareOp::NotEq => l != r,
6848 CompareOp::Lt => l < r,
6849 CompareOp::LtEq => l <= r,
6850 CompareOp::Gt => l > r,
6851 CompareOp::GtEq => l >= r,
6852 }))
6853 }
6854 (PlanValue::Float(l), PlanValue::Float(r)) => {
6855 use llkv_expr::expr::CompareOp;
6856 Ok(Some(match op {
6857 CompareOp::Eq => l == r,
6858 CompareOp::NotEq => l != r,
6859 CompareOp::Lt => l < r,
6860 CompareOp::LtEq => l <= r,
6861 CompareOp::Gt => l > r,
6862 CompareOp::GtEq => l >= r,
6863 }))
6864 }
6865 (PlanValue::Interval(l), PlanValue::Interval(r)) => {
6866 use llkv_expr::expr::CompareOp;
6867 let ordering = compare_interval_values(l, r);
6868 Ok(Some(match op {
6869 CompareOp::Eq => ordering == std::cmp::Ordering::Equal,
6870 CompareOp::NotEq => ordering != std::cmp::Ordering::Equal,
6871 CompareOp::Lt => ordering == std::cmp::Ordering::Less,
6872 CompareOp::LtEq => {
6873 matches!(
6874 ordering,
6875 std::cmp::Ordering::Less | std::cmp::Ordering::Equal
6876 )
6877 }
6878 CompareOp::Gt => ordering == std::cmp::Ordering::Greater,
6879 CompareOp::GtEq => {
6880 matches!(
6881 ordering,
6882 std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
6883 )
6884 }
6885 }))
6886 }
6887 _ => Ok(Some(false)),
6888 }
6889 }
6890 llkv_expr::expr::Expr::Not(inner) => {
6891 match Self::evaluate_having_expr(
6893 inner,
6894 aggregates,
6895 row_batch,
6896 column_lookup,
6897 row_idx,
6898 )? {
6899 Some(b) => Ok(Some(!b)),
6900 None => Ok(None), }
6902 }
6903 llkv_expr::expr::Expr::InList {
6904 expr: test_expr,
6905 list,
6906 negated,
6907 } => {
6908 let test_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6909 test_expr,
6910 aggregates,
6911 Some(row_batch),
6912 Some(column_lookup),
6913 row_idx,
6914 )?;
6915
6916 if matches!(test_val, PlanValue::Null) {
6919 return Ok(None);
6920 }
6921
6922 let mut found = false;
6923 let mut has_null = false;
6924
6925 for list_item in list {
6926 let list_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6927 list_item,
6928 aggregates,
6929 Some(row_batch),
6930 Some(column_lookup),
6931 row_idx,
6932 )?;
6933
6934 if matches!(list_val, PlanValue::Null) {
6936 has_null = true;
6937 continue;
6938 }
6939
6940 let matches = match (&test_val, &list_val) {
6942 (PlanValue::Integer(a), PlanValue::Integer(b)) => a == b,
6943 (PlanValue::Float(a), PlanValue::Float(b)) => a == b,
6944 (PlanValue::Integer(a), PlanValue::Float(b)) => (*a as f64) == *b,
6945 (PlanValue::Float(a), PlanValue::Integer(b)) => *a == (*b as f64),
6946 (PlanValue::String(a), PlanValue::String(b)) => a == b,
6947 (PlanValue::Interval(a), PlanValue::Interval(b)) => {
6948 compare_interval_values(*a, *b) == std::cmp::Ordering::Equal
6949 }
6950 _ => false,
6951 };
6952
6953 if matches {
6954 found = true;
6955 break;
6956 }
6957 }
6958
6959 if *negated {
6963 Ok(if found {
6965 Some(false)
6966 } else if has_null {
6967 None } else {
6969 Some(true)
6970 })
6971 } else {
6972 Ok(if found {
6974 Some(true)
6975 } else if has_null {
6976 None } else {
6978 Some(false)
6979 })
6980 }
6981 }
6982 llkv_expr::expr::Expr::IsNull { expr, negated } => {
6983 let val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6985 expr,
6986 aggregates,
6987 Some(row_batch),
6988 Some(column_lookup),
6989 row_idx,
6990 )?;
6991
6992 let is_null = matches!(val, PlanValue::Null);
6996 Ok(Some(if *negated { !is_null } else { is_null }))
6997 }
6998 llkv_expr::expr::Expr::Literal(val) => Ok(Some(*val)),
6999 llkv_expr::expr::Expr::And(exprs) => {
7000 let mut has_null = false;
7002 for e in exprs {
7003 match Self::evaluate_having_expr(
7004 e,
7005 aggregates,
7006 row_batch,
7007 column_lookup,
7008 row_idx,
7009 )? {
7010 Some(false) => return Ok(Some(false)), None => has_null = true,
7012 Some(true) => {} }
7014 }
7015 Ok(if has_null { None } else { Some(true) })
7016 }
7017 llkv_expr::expr::Expr::Or(exprs) => {
7018 let mut has_null = false;
7020 for e in exprs {
7021 match Self::evaluate_having_expr(
7022 e,
7023 aggregates,
7024 row_batch,
7025 column_lookup,
7026 row_idx,
7027 )? {
7028 Some(true) => return Ok(Some(true)), None => has_null = true,
7030 Some(false) => {} }
7032 }
7033 Ok(if has_null { None } else { Some(false) })
7034 }
7035 llkv_expr::expr::Expr::Pred(filter) => {
7036 use llkv_expr::expr::Operator;
7039
7040 let col_name = &filter.field_id;
7041 let col_idx = column_lookup
7042 .get(&col_name.to_ascii_lowercase())
7043 .ok_or_else(|| {
7044 Error::InvalidArgumentError(format!(
7045 "column '{}' not found in HAVING context",
7046 col_name
7047 ))
7048 })?;
7049
7050 let value = llkv_plan::plan_value_from_array(row_batch.column(*col_idx), row_idx)?;
7051
7052 match &filter.op {
7053 Operator::IsNull => Ok(Some(matches!(value, PlanValue::Null))),
7054 Operator::IsNotNull => Ok(Some(!matches!(value, PlanValue::Null))),
7055 Operator::Equals(expected) => {
7056 if matches!(value, PlanValue::Null) {
7058 return Ok(None);
7059 }
7060 let expected_value = llkv_plan::plan_value_from_literal(expected)?;
7062 if matches!(expected_value, PlanValue::Null) {
7063 return Ok(None);
7064 }
7065 Ok(Some(value == expected_value))
7066 }
7067 Operator::GreaterThan(expected) => {
7068 evaluate_ordering_predicate(&value, expected, |ordering| {
7069 ordering == std::cmp::Ordering::Greater
7070 })
7071 }
7072 Operator::GreaterThanOrEquals(expected) => {
7073 evaluate_ordering_predicate(&value, expected, |ordering| {
7074 ordering == std::cmp::Ordering::Greater
7075 || ordering == std::cmp::Ordering::Equal
7076 })
7077 }
7078 Operator::LessThan(expected) => {
7079 evaluate_ordering_predicate(&value, expected, |ordering| {
7080 ordering == std::cmp::Ordering::Less
7081 })
7082 }
7083 Operator::LessThanOrEquals(expected) => {
7084 evaluate_ordering_predicate(&value, expected, |ordering| {
7085 ordering == std::cmp::Ordering::Less
7086 || ordering == std::cmp::Ordering::Equal
7087 })
7088 }
7089 _ => {
7090 Err(Error::InvalidArgumentError(format!(
7093 "Operator {:?} not supported for column predicates in HAVING clause",
7094 filter.op
7095 )))
7096 }
7097 }
7098 }
7099 llkv_expr::expr::Expr::Exists(_) => Err(Error::InvalidArgumentError(
7100 "EXISTS subqueries not supported in HAVING clause".into(),
7101 )),
7102 }
7103 }
7104
    /// Evaluates a scalar expression to a `PlanValue`, resolving aggregate
    /// calls from `aggregates` and (optionally) column references from a
    /// grouped row.
    ///
    /// * `aggregates` — computed aggregate values keyed by
    ///   `format!("{:?}", agg)` of the aggregate call.
    /// * `row_batch` / `column_lookup` — when both are present, `Column`
    ///   references are read from `row_batch` at `row_idx` (lookup keys are
    ///   lowercase column names); when absent, column references are an
    ///   error (aggregate-only context).
    ///
    /// Conventions: SQL NULL is `PlanValue::Null` and propagates through
    /// comparisons and arithmetic; booleans are encoded as
    /// `PlanValue::Integer(0 | 1)`; division by zero yields NULL.
    fn evaluate_expr_with_plan_value_aggregates_and_row(
        expr: &ScalarExpr<String>,
        aggregates: &FxHashMap<String, PlanValue>,
        row_batch: Option<&RecordBatch>,
        column_lookup: Option<&FxHashMap<String, usize>>,
        row_idx: usize,
    ) -> ExecutorResult<PlanValue> {
        use llkv_expr::expr::BinaryOp;
        use llkv_expr::literal::Literal;

        match expr {
            // --- Literals map directly onto PlanValue variants. ---
            ScalarExpr::Literal(Literal::Int128(v)) => Ok(PlanValue::Integer(*v as i64)),
            ScalarExpr::Literal(Literal::Float64(v)) => Ok(PlanValue::Float(*v)),
            ScalarExpr::Literal(Literal::Decimal128(value)) => Ok(PlanValue::Decimal(*value)),
            // Booleans are carried as 0/1 integers.
            ScalarExpr::Literal(Literal::Boolean(v)) => {
                Ok(PlanValue::Integer(if *v { 1 } else { 0 }))
            }
            ScalarExpr::Literal(Literal::String(s)) => Ok(PlanValue::String(s.clone())),
            ScalarExpr::Literal(Literal::Date32(days)) => Ok(PlanValue::Date32(*days)),
            ScalarExpr::Literal(Literal::Interval(interval)) => Ok(PlanValue::Interval(*interval)),
            ScalarExpr::Literal(Literal::Null) => Ok(PlanValue::Null),
            ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
                "Struct literals not supported in aggregate expressions".into(),
            )),
            ScalarExpr::Column(col_name) => {
                // Column references require row context; without it this is
                // an aggregate-only expression and columns are rejected.
                if let (Some(batch), Some(lookup)) = (row_batch, column_lookup) {
                    let col_idx = lookup.get(&col_name.to_ascii_lowercase()).ok_or_else(|| {
                        Error::InvalidArgumentError(format!("column '{}' not found", col_name))
                    })?;
                    llkv_plan::plan_value_from_array(batch.column(*col_idx), row_idx)
                } else {
                    Err(Error::InvalidArgumentError(
                        "Column references not supported in aggregate-only expressions".into(),
                    ))
                }
            }
            ScalarExpr::Compare { left, op, right } => {
                let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    left,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    right,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;

                // NULL on either side makes the comparison NULL.
                if matches!(left_val, PlanValue::Null) || matches!(right_val, PlanValue::Null) {
                    return Ok(PlanValue::Null);
                }

                // Coerce mixed integer/float operands to float before comparing.
                let (left_val, right_val) = match (&left_val, &right_val) {
                    (PlanValue::Integer(i), PlanValue::Float(_)) => {
                        (PlanValue::Float(*i as f64), right_val)
                    }
                    (PlanValue::Float(_), PlanValue::Integer(i)) => {
                        (left_val, PlanValue::Float(*i as f64))
                    }
                    _ => (left_val, right_val),
                };

                // Incomparable kinds fall through to `false`.
                let result = match (&left_val, &right_val) {
                    (PlanValue::Integer(l), PlanValue::Integer(r)) => {
                        use llkv_expr::expr::CompareOp;
                        match op {
                            CompareOp::Eq => l == r,
                            CompareOp::NotEq => l != r,
                            CompareOp::Lt => l < r,
                            CompareOp::LtEq => l <= r,
                            CompareOp::Gt => l > r,
                            CompareOp::GtEq => l >= r,
                        }
                    }
                    (PlanValue::Float(l), PlanValue::Float(r)) => {
                        use llkv_expr::expr::CompareOp;
                        match op {
                            CompareOp::Eq => l == r,
                            CompareOp::NotEq => l != r,
                            CompareOp::Lt => l < r,
                            CompareOp::LtEq => l <= r,
                            CompareOp::Gt => l > r,
                            CompareOp::GtEq => l >= r,
                        }
                    }
                    (PlanValue::String(l), PlanValue::String(r)) => {
                        use llkv_expr::expr::CompareOp;
                        match op {
                            CompareOp::Eq => l == r,
                            CompareOp::NotEq => l != r,
                            CompareOp::Lt => l < r,
                            CompareOp::LtEq => l <= r,
                            CompareOp::Gt => l > r,
                            CompareOp::GtEq => l >= r,
                        }
                    }
                    (PlanValue::Interval(l), PlanValue::Interval(r)) => {
                        use llkv_expr::expr::CompareOp;
                        let ordering = compare_interval_values(*l, *r);
                        match op {
                            CompareOp::Eq => ordering == std::cmp::Ordering::Equal,
                            CompareOp::NotEq => ordering != std::cmp::Ordering::Equal,
                            CompareOp::Lt => ordering == std::cmp::Ordering::Less,
                            CompareOp::LtEq => {
                                matches!(
                                    ordering,
                                    std::cmp::Ordering::Less | std::cmp::Ordering::Equal
                                )
                            }
                            CompareOp::Gt => ordering == std::cmp::Ordering::Greater,
                            CompareOp::GtEq => {
                                matches!(
                                    ordering,
                                    std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
                                )
                            }
                        }
                    }
                    _ => false,
                };

                // Comparison results are 0/1 integers.
                Ok(PlanValue::Integer(if result { 1 } else { 0 }))
            }
            ScalarExpr::Not(inner) => {
                // Logical NOT over numeric truthiness; NULL stays NULL.
                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    inner,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                match value {
                    PlanValue::Integer(v) => Ok(PlanValue::Integer(if v != 0 { 0 } else { 1 })),
                    PlanValue::Float(v) => Ok(PlanValue::Integer(if v != 0.0 { 0 } else { 1 })),
                    PlanValue::Null => Ok(PlanValue::Null),
                    other => Err(Error::InvalidArgumentError(format!(
                        "logical NOT does not support value {other:?}"
                    ))),
                }
            }
            ScalarExpr::IsNull { expr, negated } => {
                // IS [NOT] NULL always yields a definite 0/1 result.
                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    expr,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                let is_null = matches!(value, PlanValue::Null);
                let condition = if is_null { !negated } else { *negated };
                Ok(PlanValue::Integer(if condition { 1 } else { 0 }))
            }
            ScalarExpr::Aggregate(agg) => {
                // Aggregates are precomputed; look them up by debug key.
                let key = format!("{:?}", agg);
                aggregates
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| Error::Internal(format!("Aggregate value not found: {}", key)))
            }
            ScalarExpr::Binary { left, op, right } => {
                let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    left,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    right,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;

                match op {
                    BinaryOp::Add
                    | BinaryOp::Subtract
                    | BinaryOp::Multiply
                    | BinaryOp::Divide
                    | BinaryOp::Modulo => {
                        // NULL propagates through arithmetic.
                        if matches!(&left_val, PlanValue::Null)
                            || matches!(&right_val, PlanValue::Null)
                        {
                            return Ok(PlanValue::Null);
                        }

                        if matches!(left_val, PlanValue::Interval(_))
                            || matches!(right_val, PlanValue::Interval(_))
                        {
                            return Err(Error::InvalidArgumentError(
                                "interval arithmetic not supported in aggregate expressions".into(),
                            ));
                        }

                        // Integer division stays integral (truncating); /0 is
                        // NULL, and i64::MIN / -1 (which would overflow) is
                        // computed in floating point instead.
                        if matches!(op, BinaryOp::Divide)
                            && let (PlanValue::Integer(lhs), PlanValue::Integer(rhs)) =
                                (&left_val, &right_val)
                        {
                            if *rhs == 0 {
                                return Ok(PlanValue::Null);
                            }

                            if *lhs == i64::MIN && *rhs == -1 {
                                return Ok(PlanValue::Float((*lhs as f64) / (*rhs as f64)));
                            }

                            return Ok(PlanValue::Integer(lhs / rhs));
                        }

                        // Decimal operands use exact decimal arithmetic;
                        // mixing decimals with floats is rejected.
                        let has_decimal = matches!(&left_val, PlanValue::Decimal(_))
                            || matches!(&right_val, PlanValue::Decimal(_));

                        if has_decimal {
                            use llkv_types::decimal::DecimalValue;

                            let left_dec = match &left_val {
                                PlanValue::Integer(i) => DecimalValue::from_i64(*i),
                                PlanValue::Float(_f) => {
                                    return Err(Error::InvalidArgumentError(
                                        "Cannot perform exact decimal arithmetic with Float operands"
                                            .into(),
                                    ));
                                }
                                PlanValue::Decimal(d) => *d,
                                other => {
                                    return Err(Error::InvalidArgumentError(format!(
                                        "Non-numeric value {:?} in binary operation",
                                        other
                                    )));
                                }
                            };

                            let right_dec = match &right_val {
                                PlanValue::Integer(i) => DecimalValue::from_i64(*i),
                                PlanValue::Float(_f) => {
                                    return Err(Error::InvalidArgumentError(
                                        "Cannot perform exact decimal arithmetic with Float operands"
                                            .into(),
                                    ));
                                }
                                PlanValue::Decimal(d) => *d,
                                other => {
                                    return Err(Error::InvalidArgumentError(format!(
                                        "Non-numeric value {:?} in binary operation",
                                        other
                                    )));
                                }
                            };

                            let result_dec = match op {
                                BinaryOp::Add => {
                                    llkv_compute::scalar::decimal::add(left_dec, right_dec)
                                        .map_err(|e| {
                                            Error::InvalidArgumentError(format!(
                                                "Decimal addition overflow: {}",
                                                e
                                            ))
                                        })?
                                }
                                BinaryOp::Subtract => {
                                    llkv_compute::scalar::decimal::sub(left_dec, right_dec)
                                        .map_err(|e| {
                                            Error::InvalidArgumentError(format!(
                                                "Decimal subtraction overflow: {}",
                                                e
                                            ))
                                        })?
                                }
                                BinaryOp::Multiply => {
                                    llkv_compute::scalar::decimal::mul(left_dec, right_dec)
                                        .map_err(|e| {
                                            Error::InvalidArgumentError(format!(
                                                "Decimal multiplication overflow: {}",
                                                e
                                            ))
                                        })?
                                }
                                BinaryOp::Divide => {
                                    // Decimal /0 is NULL, like the other
                                    // division paths; quotient keeps the
                                    // left operand's scale.
                                    if right_dec.raw_value() == 0 {
                                        return Ok(PlanValue::Null);
                                    }
                                    let target_scale = left_dec.scale();
                                    llkv_compute::scalar::decimal::div(
                                        left_dec,
                                        right_dec,
                                        target_scale,
                                    )
                                    .map_err(|e| {
                                        Error::InvalidArgumentError(format!(
                                            "Decimal division error: {}",
                                            e
                                        ))
                                    })?
                                }
                                BinaryOp::Modulo => {
                                    return Err(Error::InvalidArgumentError(
                                        "Modulo not supported for Decimal types".into(),
                                    ));
                                }
                                // Excluded by the outer arithmetic-op match arm.
                                BinaryOp::And
                                | BinaryOp::Or
                                | BinaryOp::BitwiseShiftLeft
                                | BinaryOp::BitwiseShiftRight => unreachable!(),
                            };

                            return Ok(PlanValue::Decimal(result_dec));
                        }

                        // Remaining numeric path: compute in f64, then narrow
                        // back to integer when both inputs were integers.
                        let left_is_float = matches!(&left_val, PlanValue::Float(_));
                        let right_is_float = matches!(&right_val, PlanValue::Float(_));

                        let left_num = match left_val {
                            PlanValue::Integer(i) => i as f64,
                            PlanValue::Float(f) => f,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in binary operation",
                                    other
                                )));
                            }
                        };
                        let right_num = match right_val {
                            PlanValue::Integer(i) => i as f64,
                            PlanValue::Float(f) => f,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in binary operation",
                                    other
                                )));
                            }
                        };

                        let result = match op {
                            BinaryOp::Add => left_num + right_num,
                            BinaryOp::Subtract => left_num - right_num,
                            BinaryOp::Multiply => left_num * right_num,
                            BinaryOp::Divide => {
                                if right_num == 0.0 {
                                    return Ok(PlanValue::Null);
                                }
                                left_num / right_num
                            }
                            BinaryOp::Modulo => {
                                if right_num == 0.0 {
                                    return Ok(PlanValue::Null);
                                }
                                left_num % right_num
                            }
                            BinaryOp::And
                            | BinaryOp::Or
                            | BinaryOp::BitwiseShiftLeft
                            | BinaryOp::BitwiseShiftRight => unreachable!(),
                        };

                        // Division always produces a float (the integer/
                        // integer case returned earlier).
                        if matches!(op, BinaryOp::Divide) {
                            return Ok(PlanValue::Float(result));
                        }

                        if left_is_float || right_is_float {
                            Ok(PlanValue::Float(result))
                        } else {
                            Ok(PlanValue::Integer(result as i64))
                        }
                    }
                    BinaryOp::And => Ok(evaluate_plan_value_logical_and(left_val, right_val)),
                    BinaryOp::Or => Ok(evaluate_plan_value_logical_or(left_val, right_val)),
                    BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
                        if matches!(&left_val, PlanValue::Null)
                            || matches!(&right_val, PlanValue::Null)
                        {
                            return Ok(PlanValue::Null);
                        }

                        // Floats are truncated to integers before shifting.
                        let lhs = match left_val {
                            PlanValue::Integer(i) => i,
                            PlanValue::Float(f) => f as i64,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in bitwise shift operation",
                                    other
                                )));
                            }
                        };
                        let rhs = match right_val {
                            PlanValue::Integer(i) => i,
                            PlanValue::Float(f) => f as i64,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in bitwise shift operation",
                                    other
                                )));
                            }
                        };

                        // Wrapping shifts: shift amount is taken mod 64
                        // rather than erroring on out-of-range amounts.
                        let result = match op {
                            BinaryOp::BitwiseShiftLeft => lhs.wrapping_shl(rhs as u32),
                            BinaryOp::BitwiseShiftRight => lhs.wrapping_shr(rhs as u32),
                            _ => unreachable!(),
                        };

                        Ok(PlanValue::Integer(result))
                    }
                }
            }
            ScalarExpr::Cast { expr, data_type } => {
                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    expr,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;

                // CAST(NULL AS anything) is NULL.
                if matches!(value, PlanValue::Null) {
                    return Ok(PlanValue::Null);
                }

                match data_type {
                    // All integer widths collapse to PlanValue::Integer (i64);
                    // float-to-int truncates toward zero.
                    DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => {
                        match value {
                            PlanValue::Integer(i) => Ok(PlanValue::Integer(i)),
                            PlanValue::Float(f) => Ok(PlanValue::Integer(f as i64)),
                            PlanValue::String(s) => {
                                s.parse::<i64>().map(PlanValue::Integer).map_err(|_| {
                                    Error::InvalidArgumentError(format!(
                                        "Cannot cast '{}' to integer",
                                        s
                                    ))
                                })
                            }
                            _ => Err(Error::InvalidArgumentError(format!(
                                "Cannot cast {:?} to integer",
                                value
                            ))),
                        }
                    }
                    DataType::Float64 | DataType::Float32 => match value {
                        PlanValue::Integer(i) => Ok(PlanValue::Float(i as f64)),
                        PlanValue::Float(f) => Ok(PlanValue::Float(f)),
                        PlanValue::String(s) => {
                            s.parse::<f64>().map(PlanValue::Float).map_err(|_| {
                                Error::InvalidArgumentError(format!("Cannot cast '{}' to float", s))
                            })
                        }
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to float",
                            value
                        ))),
                    },
                    DataType::Utf8 | DataType::LargeUtf8 => match value {
                        PlanValue::String(s) => Ok(PlanValue::String(s)),
                        PlanValue::Integer(i) => Ok(PlanValue::String(i.to_string())),
                        PlanValue::Float(f) => Ok(PlanValue::String(f.to_string())),
                        PlanValue::Interval(_) => Err(Error::InvalidArgumentError(
                            "Cannot cast interval to string in aggregate expressions".into(),
                        )),
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to string",
                            value
                        ))),
                    },
                    DataType::Interval(IntervalUnit::MonthDayNano) => match value {
                        PlanValue::Interval(interval) => Ok(PlanValue::Interval(interval)),
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to interval",
                            value
                        ))),
                    },
                    DataType::Date32 => match value {
                        PlanValue::Date32(days) => Ok(PlanValue::Date32(days)),
                        PlanValue::String(text) => {
                            let days = parse_date32_literal(&text)?;
                            Ok(PlanValue::Date32(days))
                        }
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to date",
                            value
                        ))),
                    },
                    _ => Err(Error::InvalidArgumentError(format!(
                        "CAST to {:?} not supported in aggregate expressions",
                        data_type
                    ))),
                }
            }
            ScalarExpr::Case {
                operand,
                branches,
                else_expr,
            } => {
                // Simple CASE (with operand) compares the operand against
                // each WHEN value; searched CASE (no operand) treats each
                // WHEN expression as a truthiness test.
                let operand_value = if let Some(op) = operand {
                    Some(Self::evaluate_expr_with_plan_value_aggregates_and_row(
                        op,
                        aggregates,
                        row_batch,
                        column_lookup,
                        row_idx,
                    )?)
                } else {
                    None
                };

                for (when_expr, then_expr) in branches {
                    let matches = if let Some(ref op_val) = operand_value {
                        let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            when_expr,
                            aggregates,
                            row_batch,
                            column_lookup,
                            row_idx,
                        )?;
                        Self::simple_case_branch_matches(op_val, &when_val)
                    } else {
                        let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            when_expr,
                            aggregates,
                            row_batch,
                            column_lookup,
                            row_idx,
                        )?;
                        match when_val {
                            PlanValue::Integer(i) => i != 0,
                            PlanValue::Float(f) => f != 0.0,
                            PlanValue::Null => false,
                            _ => false,
                        }
                    };

                    if matches {
                        return Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            then_expr,
                            aggregates,
                            row_batch,
                            column_lookup,
                            row_idx,
                        );
                    }
                }

                // No branch matched: fall back to ELSE, or NULL without one.
                if let Some(else_e) = else_expr {
                    Self::evaluate_expr_with_plan_value_aggregates_and_row(
                        else_e,
                        aggregates,
                        row_batch,
                        column_lookup,
                        row_idx,
                    )
                } else {
                    Ok(PlanValue::Null)
                }
            }
            ScalarExpr::Coalesce(exprs) => {
                // First non-NULL argument wins; all-NULL yields NULL.
                for expr in exprs {
                    let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                        expr,
                        aggregates,
                        row_batch,
                        column_lookup,
                        row_idx,
                    )?;
                    if !matches!(value, PlanValue::Null) {
                        return Ok(value);
                    }
                }
                Ok(PlanValue::Null)
            }
            ScalarExpr::Random => Ok(PlanValue::Float(rand::random::<f64>())),
            ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
                "GetField not supported in aggregate expressions".into(),
            )),
            ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
                "Scalar subqueries not supported in aggregate expressions".into(),
            )),
        }
    }
7711
7712 fn simple_case_branch_matches(operand: &PlanValue, candidate: &PlanValue) -> bool {
7713 if matches!(operand, PlanValue::Null) || matches!(candidate, PlanValue::Null) {
7714 return false;
7715 }
7716
7717 match (operand, candidate) {
7718 (PlanValue::Integer(left), PlanValue::Integer(right)) => left == right,
7719 (PlanValue::Integer(left), PlanValue::Float(right)) => (*left as f64) == *right,
7720 (PlanValue::Float(left), PlanValue::Integer(right)) => *left == (*right as f64),
7721 (PlanValue::Float(left), PlanValue::Float(right)) => left == right,
7722 (PlanValue::String(left), PlanValue::String(right)) => left == right,
7723 (PlanValue::Struct(left), PlanValue::Struct(right)) => left == right,
7724 (PlanValue::Interval(left), PlanValue::Interval(right)) => {
7725 compare_interval_values(*left, *right) == std::cmp::Ordering::Equal
7726 }
7727 _ => operand == candidate,
7728 }
7729 }
7730
    /// Evaluates an aggregate-only scalar expression to an `i64`.
    ///
    /// `Ok(None)` models SQL NULL. Unlike the `PlanValue`-returning sibling
    /// (`evaluate_expr_with_plan_value_aggregates_and_row`), this integer-only
    /// path rejects column references, strings, comparisons, CASE, and
    /// COALESCE, and reports checked-arithmetic overflow as an error.
    fn evaluate_expr_with_aggregates(
        expr: &ScalarExpr<String>,
        aggregates: &FxHashMap<String, AggregateValue>,
    ) -> ExecutorResult<Option<i64>> {
        use llkv_expr::expr::BinaryOp;
        use llkv_expr::literal::Literal;

        match expr {
            ScalarExpr::Literal(Literal::Int128(v)) => Ok(Some(*v as i64)),
            // NOTE(review): float literals truncate toward zero here —
            // confirm this is intended for the integer-only path.
            ScalarExpr::Literal(Literal::Float64(v)) => Ok(Some(*v as i64)),
            ScalarExpr::Literal(Literal::Decimal128(value)) => {
                // Prefer the exact integer value when the decimal has no
                // fractional part; otherwise truncate via f64.
                if let Some(int) = decimal_exact_i64(*value) {
                    Ok(Some(int))
                } else {
                    Ok(Some(value.to_f64() as i64))
                }
            }
            ScalarExpr::Literal(Literal::Boolean(v)) => Ok(Some(if *v { 1 } else { 0 })),
            ScalarExpr::Literal(Literal::String(_)) => Err(Error::InvalidArgumentError(
                "String literals not supported in aggregate expressions".into(),
            )),
            ScalarExpr::Literal(Literal::Date32(days)) => Ok(Some(*days as i64)),
            ScalarExpr::Literal(Literal::Null) => Ok(None),
            ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
                "Struct literals not supported in aggregate expressions".into(),
            )),
            ScalarExpr::Literal(Literal::Interval(_)) => Err(Error::InvalidArgumentError(
                "Interval literals not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Column(_) => Err(Error::InvalidArgumentError(
                "Column references not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Compare { .. } => Err(Error::InvalidArgumentError(
                "Comparisons not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Aggregate(agg) => {
                // Aggregates are precomputed; look them up by debug key.
                let key = format!("{:?}", agg);
                let value = aggregates.get(&key).ok_or_else(|| {
                    Error::Internal(format!("Aggregate value not found for key: {}", key))
                })?;
                Ok(value.as_i64())
            }
            ScalarExpr::Not(inner) => {
                // Logical NOT over truthiness; NULL stays NULL.
                let value = Self::evaluate_expr_with_aggregates(inner, aggregates)?;
                Ok(value.map(|v| if v != 0 { 0 } else { 1 }))
            }
            ScalarExpr::IsNull { expr, negated } => {
                // IS [NOT] NULL always yields a definite 0/1 result.
                let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
                let is_null = value.is_none();
                Ok(Some(if is_null != *negated { 1 } else { 0 }))
            }
            ScalarExpr::Binary { left, op, right } => {
                let left_val = Self::evaluate_expr_with_aggregates(left, aggregates)?;
                let right_val = Self::evaluate_expr_with_aggregates(right, aggregates)?;

                match op {
                    BinaryOp::Add
                    | BinaryOp::Subtract
                    | BinaryOp::Multiply
                    | BinaryOp::Divide
                    | BinaryOp::Modulo => match (left_val, right_val) {
                        (Some(lhs), Some(rhs)) => {
                            // Checked arithmetic: /0 and %0 yield NULL;
                            // overflow is surfaced as an error below.
                            let result = match op {
                                BinaryOp::Add => lhs.checked_add(rhs),
                                BinaryOp::Subtract => lhs.checked_sub(rhs),
                                BinaryOp::Multiply => lhs.checked_mul(rhs),
                                BinaryOp::Divide => {
                                    if rhs == 0 {
                                        return Ok(None);
                                    }
                                    lhs.checked_div(rhs)
                                }
                                BinaryOp::Modulo => {
                                    if rhs == 0 {
                                        return Ok(None);
                                    }
                                    lhs.checked_rem(rhs)
                                }
                                // Excluded by the outer arithmetic-op match arm.
                                BinaryOp::And
                                | BinaryOp::Or
                                | BinaryOp::BitwiseShiftLeft
                                | BinaryOp::BitwiseShiftRight => unreachable!(),
                            };

                            result.map(Some).ok_or_else(|| {
                                Error::InvalidArgumentError(
                                    "Arithmetic overflow in expression".into(),
                                )
                            })
                        }
                        // NULL propagates through arithmetic.
                        _ => Ok(None),
                    },
                    BinaryOp::And => Ok(evaluate_option_logical_and(left_val, right_val)),
                    BinaryOp::Or => Ok(evaluate_option_logical_or(left_val, right_val)),
                    BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
                        match (left_val, right_val) {
                            (Some(lhs), Some(rhs)) => {
                                // Wrapping shifts: shift amount is taken mod
                                // 64 rather than erroring on out-of-range.
                                let result = match op {
                                    BinaryOp::BitwiseShiftLeft => {
                                        Some(lhs.wrapping_shl(rhs as u32))
                                    }
                                    BinaryOp::BitwiseShiftRight => {
                                        Some(lhs.wrapping_shr(rhs as u32))
                                    }
                                    _ => unreachable!(),
                                };
                                Ok(result)
                            }
                            _ => Ok(None),
                        }
                    }
                }
            }
            ScalarExpr::Cast { expr, data_type } => {
                // CAST(NULL AS anything) stays NULL; otherwise range-checked
                // by `cast_aggregate_value`.
                let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
                match value {
                    Some(v) => Self::cast_aggregate_value(v, data_type).map(Some),
                    None => Ok(None),
                }
            }
            ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
                "GetField not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Case { .. } => Err(Error::InvalidArgumentError(
                "CASE not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Coalesce(_) => Err(Error::InvalidArgumentError(
                "COALESCE not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Random => Ok(Some((rand::random::<f64>() * (i64::MAX as f64)) as i64)),
            ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
                "Scalar subqueries not supported in aggregate-only expressions".into(),
            )),
        }
    }
7866
7867 fn cast_aggregate_value(value: i64, data_type: &DataType) -> ExecutorResult<i64> {
7868 fn ensure_range(value: i64, min: i64, max: i64, ty: &DataType) -> ExecutorResult<i64> {
7869 if value < min || value > max {
7870 return Err(Error::InvalidArgumentError(format!(
7871 "value {} out of range for CAST target {:?}",
7872 value, ty
7873 )));
7874 }
7875 Ok(value)
7876 }
7877
7878 match data_type {
7879 DataType::Int8 => ensure_range(value, i8::MIN as i64, i8::MAX as i64, data_type),
7880 DataType::Int16 => ensure_range(value, i16::MIN as i64, i16::MAX as i64, data_type),
7881 DataType::Int32 => ensure_range(value, i32::MIN as i64, i32::MAX as i64, data_type),
7882 DataType::Int64 => Ok(value),
7883 DataType::UInt8 => ensure_range(value, 0, u8::MAX as i64, data_type),
7884 DataType::UInt16 => ensure_range(value, 0, u16::MAX as i64, data_type),
7885 DataType::UInt32 => ensure_range(value, 0, u32::MAX as i64, data_type),
7886 DataType::UInt64 => {
7887 if value < 0 {
7888 return Err(Error::InvalidArgumentError(format!(
7889 "value {} out of range for CAST target {:?}",
7890 value, data_type
7891 )));
7892 }
7893 Ok(value)
7894 }
7895 DataType::Float32 | DataType::Float64 => Ok(value),
7896 DataType::Boolean => Ok(if value == 0 { 0 } else { 1 }),
7897 DataType::Null => Err(Error::InvalidArgumentError(
7898 "CAST to NULL is not supported in aggregate-only expressions".into(),
7899 )),
7900 _ => Err(Error::InvalidArgumentError(format!(
7901 "CAST to {:?} is not supported in aggregate-only expressions",
7902 data_type
7903 ))),
7904 }
7905 }
7906}
7907
/// Per-query evaluation state for filtering and projecting rows produced by
/// a cross product of tables, keyed by synthetic `FieldId`s assigned to the
/// combined output columns.
struct CrossProductExpressionContext {
    // Synthetic executor schema mirroring the combined batch's columns.
    schema: Arc<ExecutorSchema>,
    // Maps each synthetic FieldId to its column index in the RecordBatch.
    field_id_to_index: FxHashMap<FieldId, usize>,
    // Cache of numeric (Float64-coerced) arrays per field for the current batch.
    numeric_cache: FxHashMap<FieldId, ArrayRef>,
    // Cache of typed column accessors per field for the current batch.
    column_cache: FxHashMap<FieldId, ColumnAccessor>,
    // Materialized scalar-subquery result columns, one per subquery.
    scalar_subquery_columns: FxHashMap<SubqueryId, ColumnAccessor>,
    // Memoized scalar-subquery literals keyed by (subquery, correlation key).
    // NOTE(review): `reset` does not clear this map, so entries persist across
    // batches — presumably intentional memoization; confirm.
    scalar_subquery_cache: FxHashMap<(SubqueryId, Vec<u8>), Literal>,
    // Next synthetic FieldId to hand out.
    next_field_id: FieldId,
}
7917
/// Typed, cheaply-cloneable (Arc-backed) view over a single column of a
/// cross-product batch, used for row-at-a-time predicate evaluation.
#[derive(Clone)]
enum ColumnAccessor {
    Int64(Arc<Int64Array>),
    Float64(Arc<Float64Array>),
    Boolean(Arc<BooleanArray>),
    Utf8(Arc<StringArray>),
    Date32(Arc<Date32Array>),
    Interval(Arc<IntervalMonthDayNanoArray>),
    /// Decimal column plus its scale, needed to interpret the raw i128 values.
    Decimal128 {
        array: Arc<Decimal128Array>,
        scale: i8,
    },
    /// All-NULL column of the given row count.
    Null(usize),
}
7932
7933impl ColumnAccessor {
7934 fn from_array(array: &ArrayRef) -> ExecutorResult<Self> {
7935 match array.data_type() {
7936 DataType::Int64 => {
7937 let typed = array
7938 .as_any()
7939 .downcast_ref::<Int64Array>()
7940 .ok_or_else(|| Error::Internal("expected Int64 array".into()))?
7941 .clone();
7942 Ok(Self::Int64(Arc::new(typed)))
7943 }
7944 DataType::Float64 => {
7945 let typed = array
7946 .as_any()
7947 .downcast_ref::<Float64Array>()
7948 .ok_or_else(|| Error::Internal("expected Float64 array".into()))?
7949 .clone();
7950 Ok(Self::Float64(Arc::new(typed)))
7951 }
7952 DataType::Boolean => {
7953 let typed = array
7954 .as_any()
7955 .downcast_ref::<BooleanArray>()
7956 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
7957 .clone();
7958 Ok(Self::Boolean(Arc::new(typed)))
7959 }
7960 DataType::Utf8 => {
7961 let typed = array
7962 .as_any()
7963 .downcast_ref::<StringArray>()
7964 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
7965 .clone();
7966 Ok(Self::Utf8(Arc::new(typed)))
7967 }
7968 DataType::Date32 => {
7969 let typed = array
7970 .as_any()
7971 .downcast_ref::<Date32Array>()
7972 .ok_or_else(|| Error::Internal("expected Date32 array".into()))?
7973 .clone();
7974 Ok(Self::Date32(Arc::new(typed)))
7975 }
7976 DataType::Interval(IntervalUnit::MonthDayNano) => {
7977 let typed = array
7978 .as_any()
7979 .downcast_ref::<IntervalMonthDayNanoArray>()
7980 .ok_or_else(|| Error::Internal("expected IntervalMonthDayNano array".into()))?
7981 .clone();
7982 Ok(Self::Interval(Arc::new(typed)))
7983 }
7984 DataType::Decimal128(_, scale) => {
7985 let typed = array
7986 .as_any()
7987 .downcast_ref::<Decimal128Array>()
7988 .ok_or_else(|| Error::Internal("expected Decimal128 array".into()))?
7989 .clone();
7990 Ok(Self::Decimal128 {
7991 array: Arc::new(typed),
7992 scale: *scale,
7993 })
7994 }
7995 DataType::Null => Ok(Self::Null(array.len())),
7996 other => Err(Error::InvalidArgumentError(format!(
7997 "unsupported column type {:?} in cross product filter",
7998 other
7999 ))),
8000 }
8001 }
8002
8003 fn from_numeric_array(numeric: &ArrayRef) -> ExecutorResult<Self> {
8004 let casted = cast(numeric, &DataType::Float64)?;
8005 let float_array = casted
8006 .as_any()
8007 .downcast_ref::<Float64Array>()
8008 .expect("cast to Float64 failed")
8009 .clone();
8010 Ok(Self::Float64(Arc::new(float_array)))
8011 }
8012
8013 fn len(&self) -> usize {
8014 match self {
8015 ColumnAccessor::Int64(array) => array.len(),
8016 ColumnAccessor::Float64(array) => array.len(),
8017 ColumnAccessor::Boolean(array) => array.len(),
8018 ColumnAccessor::Utf8(array) => array.len(),
8019 ColumnAccessor::Date32(array) => array.len(),
8020 ColumnAccessor::Interval(array) => array.len(),
8021 ColumnAccessor::Decimal128 { array, .. } => array.len(),
8022 ColumnAccessor::Null(len) => *len,
8023 }
8024 }
8025
8026 fn is_null(&self, idx: usize) -> bool {
8027 match self {
8028 ColumnAccessor::Int64(array) => array.is_null(idx),
8029 ColumnAccessor::Float64(array) => array.is_null(idx),
8030 ColumnAccessor::Boolean(array) => array.is_null(idx),
8031 ColumnAccessor::Utf8(array) => array.is_null(idx),
8032 ColumnAccessor::Date32(array) => array.is_null(idx),
8033 ColumnAccessor::Interval(array) => array.is_null(idx),
8034 ColumnAccessor::Decimal128 { array, .. } => array.is_null(idx),
8035 ColumnAccessor::Null(_) => true,
8036 }
8037 }
8038
8039 fn literal_at(&self, idx: usize) -> ExecutorResult<Literal> {
8040 if self.is_null(idx) {
8041 return Ok(Literal::Null);
8042 }
8043 match self {
8044 ColumnAccessor::Int64(array) => Ok(Literal::Int128(array.value(idx) as i128)),
8045 ColumnAccessor::Float64(array) => Ok(Literal::Float64(array.value(idx))),
8046 ColumnAccessor::Boolean(array) => Ok(Literal::Boolean(array.value(idx))),
8047 ColumnAccessor::Utf8(array) => Ok(Literal::String(array.value(idx).to_string())),
8048 ColumnAccessor::Date32(array) => Ok(Literal::Date32(array.value(idx))),
8049 ColumnAccessor::Interval(array) => Ok(Literal::Interval(interval_value_from_arrow(
8050 array.value(idx),
8051 ))),
8052 ColumnAccessor::Decimal128 { array, .. } => Ok(Literal::Int128(array.value(idx))),
8053 ColumnAccessor::Null(_) => Ok(Literal::Null),
8054 }
8055 }
8056
8057 fn as_array_ref(&self) -> ArrayRef {
8058 match self {
8059 ColumnAccessor::Int64(array) => Arc::clone(array) as ArrayRef,
8060 ColumnAccessor::Float64(array) => Arc::clone(array) as ArrayRef,
8061 ColumnAccessor::Boolean(array) => Arc::clone(array) as ArrayRef,
8062 ColumnAccessor::Utf8(array) => Arc::clone(array) as ArrayRef,
8063 ColumnAccessor::Date32(array) => Arc::clone(array) as ArrayRef,
8064 ColumnAccessor::Interval(array) => Arc::clone(array) as ArrayRef,
8065 ColumnAccessor::Decimal128 { array, .. } => Arc::clone(array) as ArrayRef,
8066 ColumnAccessor::Null(len) => new_null_array(&DataType::Null, *len),
8067 }
8068 }
8069}
8070
/// Intermediate result of materializing a scalar expression over a batch:
/// either a concrete typed array, or an all-NULL placeholder of known length.
/// Numeric covers all integer/float/date/decimal arrays.
#[derive(Clone)]
enum ValueArray {
    Numeric(ArrayRef),
    Boolean(Arc<BooleanArray>),
    Utf8(Arc<StringArray>),
    Interval(Arc<IntervalMonthDayNanoArray>),
    /// All-NULL result of the given row count.
    Null(usize),
}
8079
8080impl ValueArray {
8081 fn from_array(array: ArrayRef) -> ExecutorResult<Self> {
8082 match array.data_type() {
8083 DataType::Boolean => {
8084 let typed = array
8085 .as_any()
8086 .downcast_ref::<BooleanArray>()
8087 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
8088 .clone();
8089 Ok(Self::Boolean(Arc::new(typed)))
8090 }
8091 DataType::Utf8 => {
8092 let typed = array
8093 .as_any()
8094 .downcast_ref::<StringArray>()
8095 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
8096 .clone();
8097 Ok(Self::Utf8(Arc::new(typed)))
8098 }
8099 DataType::Interval(IntervalUnit::MonthDayNano) => {
8100 let typed = array
8101 .as_any()
8102 .downcast_ref::<IntervalMonthDayNanoArray>()
8103 .ok_or_else(|| Error::Internal("expected IntervalMonthDayNano array".into()))?
8104 .clone();
8105 Ok(Self::Interval(Arc::new(typed)))
8106 }
8107 DataType::Null => Ok(Self::Null(array.len())),
8108 DataType::Int8
8109 | DataType::Int16
8110 | DataType::Int32
8111 | DataType::Int64
8112 | DataType::UInt8
8113 | DataType::UInt16
8114 | DataType::UInt32
8115 | DataType::UInt64
8116 | DataType::Date32
8117 | DataType::Float32
8118 | DataType::Float64
8119 | DataType::Decimal128(_, _) => Ok(Self::Numeric(array)),
8120 other => Err(Error::InvalidArgumentError(format!(
8121 "unsupported data type {:?} in cross product expression",
8122 other
8123 ))),
8124 }
8125 }
8126
8127 fn len(&self) -> usize {
8128 match self {
8129 ValueArray::Numeric(array) => array.len(),
8130 ValueArray::Boolean(array) => array.len(),
8131 ValueArray::Utf8(array) => array.len(),
8132 ValueArray::Interval(array) => array.len(),
8133 ValueArray::Null(len) => *len,
8134 }
8135 }
8136
8137 fn as_array_ref(&self) -> ArrayRef {
8138 match self {
8139 ValueArray::Numeric(arr) => arr.clone(),
8140 ValueArray::Boolean(arr) => arr.clone() as ArrayRef,
8141 ValueArray::Utf8(arr) => arr.clone() as ArrayRef,
8142 ValueArray::Interval(arr) => arr.clone() as ArrayRef,
8143 ValueArray::Null(len) => new_null_array(&DataType::Null, *len),
8144 }
8145 }
8146}
8147
/// Three-valued (SQL/Kleene) logical AND over nullable booleans.
///
/// FALSE dominates NULL (unknown); TRUE AND NULL is NULL.
fn truth_and(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(false) || rhs == Some(false) {
        Some(false)
    } else if lhs.is_some() && rhs.is_some() {
        // Neither side is FALSE or NULL, so both must be TRUE.
        Some(true)
    } else {
        None
    }
}
8155
/// Three-valued (SQL/Kleene) logical OR over nullable booleans.
///
/// TRUE dominates NULL (unknown); FALSE OR NULL is NULL.
fn truth_or(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(true) || rhs == Some(true) {
        Some(true)
    } else if lhs.is_some() && rhs.is_some() {
        // Neither side is TRUE or NULL, so both must be FALSE.
        Some(false)
    } else {
        None
    }
}
8163
/// Three-valued (SQL/Kleene) logical NOT: NULL (unknown) stays NULL.
///
/// The manual `match Some/None` was the textbook `Option::map` idiom
/// (clippy: `manual_map`).
fn truth_not(value: Option<bool>) -> Option<bool> {
    value.map(|flag| !flag)
}
8171
8172fn literal_to_constant_array(literal: &Literal, len: usize) -> ExecutorResult<ArrayRef> {
8173 match literal {
8174 Literal::Int128(v) => {
8175 let value = i64::try_from(*v).unwrap_or(0);
8176 let values = vec![value; len];
8177 Ok(Arc::new(Int64Array::from(values)) as ArrayRef)
8178 }
8179 Literal::Float64(v) => {
8180 let values = vec![*v; len];
8181 Ok(Arc::new(Float64Array::from(values)) as ArrayRef)
8182 }
8183 Literal::Boolean(v) => {
8184 let values = vec![Some(*v); len];
8185 Ok(Arc::new(BooleanArray::from(values)) as ArrayRef)
8186 }
8187 Literal::String(v) => {
8188 let values: Vec<Option<String>> = (0..len).map(|_| Some(v.clone())).collect();
8189 Ok(Arc::new(StringArray::from(values)) as ArrayRef)
8190 }
8191 Literal::Date32(days) => {
8192 let values = vec![*days; len];
8193 Ok(Arc::new(Date32Array::from(values)) as ArrayRef)
8194 }
8195 Literal::Decimal128(value) => {
8196 let iter = std::iter::repeat_n(value.raw_value(), len);
8197 let array = Decimal128Array::from_iter_values(iter)
8198 .with_precision_and_scale(value.precision(), value.scale())
8199 .map_err(|err| {
8200 Error::InvalidArgumentError(format!(
8201 "failed to synthesize decimal literal array: {err}"
8202 ))
8203 })?;
8204 Ok(Arc::new(array) as ArrayRef)
8205 }
8206 Literal::Interval(interval) => {
8207 let value = interval_value_to_arrow(*interval);
8208 let values = vec![value; len];
8209 Ok(Arc::new(IntervalMonthDayNanoArray::from(values)) as ArrayRef)
8210 }
8211 Literal::Null => Ok(new_null_array(&DataType::Null, len)),
8212 Literal::Struct(_) => Err(Error::InvalidArgumentError(
8213 "struct literals are not supported in cross product filters".into(),
8214 )),
8215 }
8216}
8217
/// Collapse a list of scalar-subquery result literals into one Arrow array,
/// inferring a common element type.
///
/// NULLs may appear alongside any kind. Integers may mix with floats (result
/// FLOAT) or decimals (result DECIMAL); any other cross-kind mix is rejected.
/// An empty input yields an empty NullArray.
fn literals_to_array(values: &[Literal]) -> ExecutorResult<ArrayRef> {
    // Target Arrow representation chosen for the whole list.
    #[derive(Copy, Clone, Eq, PartialEq)]
    enum LiteralArrayKind {
        Null,
        Integer,
        Float,
        Boolean,
        String,
        Date32,
        Interval,
        Decimal,
    }

    if values.is_empty() {
        return Ok(new_null_array(&DataType::Null, 0));
    }

    // First pass: record which literal kinds are present (NULLs are ignored
    // here; they are representable in every target kind).
    let mut has_integer = false;
    let mut has_float = false;
    let mut has_decimal = false;
    let mut has_boolean = false;
    let mut has_string = false;
    let mut has_date = false;
    let mut has_interval = false;

    for literal in values {
        match literal {
            Literal::Null => {}
            Literal::Int128(_) => {
                has_integer = true;
            }
            Literal::Float64(_) => {
                has_float = true;
            }
            Literal::Decimal128(_) => {
                has_decimal = true;
            }
            Literal::Boolean(_) => {
                has_boolean = true;
            }
            Literal::String(_) => {
                has_string = true;
            }
            Literal::Date32(_) => {
                has_date = true;
            }
            Literal::Interval(_) => {
                has_interval = true;
            }
            Literal::Struct(_) => {
                return Err(Error::InvalidArgumentError(
                    "struct scalar subquery results are not supported".into(),
                ));
            }
        }
    }

    // Reject incompatible combinations: numeric kinds (integer/float/decimal)
    // may mix with each other, but not with string/boolean/date/interval, and
    // none of those four mix with each other.
    let mixed_numeric = has_integer as u8 + has_float as u8 + has_decimal as u8;
    if has_string && (has_boolean || has_date || has_interval || mixed_numeric > 0)
        || has_boolean && (has_date || has_interval || mixed_numeric > 0)
        || has_date && (has_interval || mixed_numeric > 0)
        || has_interval && (mixed_numeric > 0)
    {
        return Err(Error::InvalidArgumentError(
            "mixed scalar subquery result types are not supported".into(),
        ));
    }

    // Pick the target kind; among numerics the preference is
    // Float > Decimal > Integer (the widest representation present wins).
    let target_kind = if has_string {
        LiteralArrayKind::String
    } else if has_interval {
        LiteralArrayKind::Interval
    } else if has_date {
        LiteralArrayKind::Date32
    } else if has_boolean {
        LiteralArrayKind::Boolean
    } else if has_float {
        LiteralArrayKind::Float
    } else if has_decimal {
        LiteralArrayKind::Decimal
    } else if has_integer {
        LiteralArrayKind::Integer
    } else {
        LiteralArrayKind::Null
    };

    // Second pass: build the array. The `unreachable!` arms are justified by
    // the kind detection and mixing checks above.
    match target_kind {
        LiteralArrayKind::Null => Ok(new_null_array(&DataType::Null, values.len())),
        LiteralArrayKind::Integer => {
            let mut coerced: Vec<Option<i64>> = Vec::with_capacity(values.len());
            for literal in values {
                match literal {
                    Literal::Null => coerced.push(None),
                    Literal::Int128(value) => {
                        // i128 literals wider than i64 cannot be represented.
                        let v = i64::try_from(*value).map_err(|_| {
                            Error::InvalidArgumentError(
                                "scalar subquery integer result exceeds supported range".into(),
                            )
                        })?;
                        coerced.push(Some(v));
                    }
                    _ => unreachable!("non-integer value encountered in integer array"),
                }
            }
            let array = Int64Array::from_iter(coerced);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Float => {
            let mut coerced: Vec<Option<f64>> = Vec::with_capacity(values.len());
            for literal in values {
                match literal {
                    Literal::Null => coerced.push(None),
                    Literal::Int128(_) | Literal::Float64(_) | Literal::Decimal128(_) => {
                        // Integers and decimals are widened to f64 here.
                        let value = literal_to_f64(literal).ok_or_else(|| {
                            Error::InvalidArgumentError(
                                "failed to coerce scalar subquery value to FLOAT".into(),
                            )
                        })?;
                        coerced.push(Some(value));
                    }
                    _ => unreachable!("non-numeric value encountered in float array"),
                }
            }
            let array = Float64Array::from_iter(coerced);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Boolean => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::Boolean(flag) => Some(*flag),
                _ => unreachable!("non-boolean value encountered in boolean array"),
            });
            let array = BooleanArray::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::String => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::String(value) => Some(value.clone()),
                _ => unreachable!("non-string value encountered in string array"),
            });
            let array = StringArray::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Date32 => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::Date32(days) => Some(*days),
                _ => unreachable!("non-date value encountered in date array"),
            });
            let array = Date32Array::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Interval => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::Interval(interval) => Some(interval_value_to_arrow(*interval)),
                _ => unreachable!("non-interval value encountered in interval array"),
            });
            let array = IntervalMonthDayNanoArray::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Decimal => {
            // Align all decimals (and any promoted integers) to the largest
            // scale present, then track the widest precision required.
            let mut target_scale: Option<i8> = None;
            for literal in values {
                if let Literal::Decimal128(value) = literal {
                    target_scale = Some(match target_scale {
                        Some(scale) => scale.max(value.scale()),
                        None => value.scale(),
                    });
                }
            }
            // Safe: target_kind == Decimal implies at least one decimal literal.
            let target_scale = target_scale.expect("decimal literal expected");

            let mut max_precision: u8 = 1;
            let mut aligned: Vec<Option<DecimalValue>> = Vec::with_capacity(values.len());
            for literal in values {
                match literal {
                    Literal::Null => aligned.push(None),
                    Literal::Decimal128(value) => {
                        let adjusted = if value.scale() != target_scale {
                            llkv_compute::scalar::decimal::rescale(*value, target_scale).map_err(
                                |err| {
                                    Error::InvalidArgumentError(format!(
                                        "failed to align decimal scale: {err}"
                                    ))
                                },
                            )?
                        } else {
                            *value
                        };
                        max_precision = max_precision.max(adjusted.precision());
                        aligned.push(Some(adjusted));
                    }
                    Literal::Int128(value) => {
                        // Promote the integer to a scale-0 decimal, then
                        // rescale to the common target scale.
                        let decimal = DecimalValue::new(*value, 0).map_err(|err| {
                            Error::InvalidArgumentError(format!(
                                "failed to build decimal from integer: {err}"
                            ))
                        })?;
                        let decimal = llkv_compute::scalar::decimal::rescale(decimal, target_scale)
                            .map_err(|err| {
                                Error::InvalidArgumentError(format!(
                                    "failed to align integer decimal scale: {err}"
                                ))
                            })?;
                        max_precision = max_precision.max(decimal.precision());
                        aligned.push(Some(decimal));
                    }
                    _ => unreachable!("unexpected literal in decimal array"),
                }
            }

            let mut builder = Decimal128Builder::new()
                .with_precision_and_scale(max_precision, target_scale)
                .map_err(|err| {
                    Error::InvalidArgumentError(format!(
                        "invalid Decimal128 precision/scale: {err}"
                    ))
                })?;
            for value in aligned {
                match value {
                    Some(decimal) => builder.append_value(decimal.raw_value()),
                    None => builder.append_null(),
                }
            }
            let array = builder.finish();
            Ok(Arc::new(array) as ArrayRef)
        }
    }
}
8449
8450impl CrossProductExpressionContext {
8451 fn new(schema: &Schema, lookup: FxHashMap<String, usize>) -> ExecutorResult<Self> {
8452 let mut columns = Vec::with_capacity(schema.fields().len());
8453 let mut field_id_to_index = FxHashMap::default();
8454 let mut next_field_id: FieldId = 1;
8455
8456 for (idx, field) in schema.fields().iter().enumerate() {
8457 if next_field_id == u32::MAX {
8458 return Err(Error::Internal(
8459 "cross product projection exhausted FieldId space".into(),
8460 ));
8461 }
8462
8463 let executor_column = ExecutorColumn {
8464 name: field.name().clone(),
8465 data_type: field.data_type().clone(),
8466 nullable: field.is_nullable(),
8467 primary_key: false,
8468 unique: false,
8469 field_id: next_field_id,
8470 check_expr: None,
8471 };
8472 let field_id = next_field_id;
8473 next_field_id = next_field_id.saturating_add(1);
8474
8475 columns.push(executor_column);
8476 field_id_to_index.insert(field_id, idx);
8477 }
8478
8479 Ok(Self {
8480 schema: Arc::new(ExecutorSchema { columns, lookup }),
8481 field_id_to_index,
8482 numeric_cache: FxHashMap::default(),
8483 column_cache: FxHashMap::default(),
8484 scalar_subquery_columns: FxHashMap::default(),
8485 scalar_subquery_cache: FxHashMap::default(),
8486 next_field_id,
8487 })
8488 }
8489
8490 fn schema(&self) -> &ExecutorSchema {
8491 self.schema.as_ref()
8492 }
8493
8494 fn field_id_for_column(&self, name: &str) -> Option<FieldId> {
8495 self.schema.resolve(name).map(|column| column.field_id)
8496 }
8497
    /// Clear per-batch caches before evaluating a new record batch.
    ///
    /// NOTE(review): `scalar_subquery_cache` is deliberately not cleared here,
    /// so memoized subquery results appear to survive across batches — confirm
    /// that is intentional.
    fn reset(&mut self) {
        self.numeric_cache.clear();
        self.column_cache.clear();
        self.scalar_subquery_columns.clear();
    }
8503
8504 fn allocate_synthetic_field_id(&mut self) -> ExecutorResult<FieldId> {
8505 if self.next_field_id == FieldId::MAX {
8506 return Err(Error::Internal(
8507 "cross product projection exhausted FieldId space".into(),
8508 ));
8509 }
8510 let field_id = self.next_field_id;
8511 self.next_field_id = self.next_field_id.saturating_add(1);
8512 Ok(field_id)
8513 }
8514
    /// Attach a materialized scalar-subquery result column so later predicate
    /// evaluation can reference it by `SubqueryId`. Replaces any previous
    /// column registered for the same subquery.
    fn register_scalar_subquery_column(
        &mut self,
        subquery_id: SubqueryId,
        accessor: ColumnAccessor,
    ) {
        self.scalar_subquery_columns.insert(subquery_id, accessor);
    }
8522
    /// Test-only helper: translate a name-keyed scalar expression against this
    /// context's schema (resolving column names to synthetic `FieldId`s) and
    /// evaluate it numerically over `batch`.
    #[cfg(test)]
    fn evaluate(
        &mut self,
        expr: &ScalarExpr<String>,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let translated = translate_scalar(expr, self.schema.as_ref(), |name| {
            Error::InvalidArgumentError(format!(
                "column '{}' not found in cross product result",
                name
            ))
        })?;

        self.evaluate_numeric(&translated, batch)
    }
8538
8539 fn evaluate_predicate_mask(
8540 &mut self,
8541 expr: &LlkvExpr<'static, FieldId>,
8542 batch: &RecordBatch,
8543 mut exists_eval: impl FnMut(
8544 &mut Self,
8545 &llkv_expr::SubqueryExpr,
8546 usize,
8547 &RecordBatch,
8548 ) -> ExecutorResult<Option<bool>>,
8549 ) -> ExecutorResult<BooleanArray> {
8550 let truths = self.evaluate_predicate_truths(expr, batch, &mut exists_eval)?;
8551 let mut builder = BooleanBuilder::with_capacity(truths.len());
8552 for value in truths {
8553 builder.append_value(value.unwrap_or(false));
8554 }
8555 Ok(builder.finish())
8556 }
8557
    /// Recursively evaluate a predicate tree to per-row three-valued truth
    /// (`Some(true)` / `Some(false)` / `None` for SQL NULL/unknown).
    ///
    /// AND/OR combine children with Kleene logic (`truth_and` / `truth_or`);
    /// EXISTS predicates are delegated row-by-row to `exists_eval`.
    fn evaluate_predicate_truths(
        &mut self,
        expr: &LlkvExpr<'static, FieldId>,
        batch: &RecordBatch,
        exists_eval: &mut impl FnMut(
            &mut Self,
            &llkv_expr::SubqueryExpr,
            usize,
            &RecordBatch,
        ) -> ExecutorResult<Option<bool>>,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        match expr {
            LlkvExpr::Literal(value) => Ok(vec![Some(*value); batch.num_rows()]),
            LlkvExpr::And(children) => {
                // Empty conjunction is vacuously true.
                if children.is_empty() {
                    return Ok(vec![Some(true); batch.num_rows()]);
                }
                let mut result =
                    self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
                for child in &children[1..] {
                    let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
                    for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
                        *lhs = truth_and(*lhs, rhs);
                    }
                }
                Ok(result)
            }
            LlkvExpr::Or(children) => {
                // Empty disjunction is vacuously false.
                if children.is_empty() {
                    return Ok(vec![Some(false); batch.num_rows()]);
                }
                let mut result =
                    self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
                for child in &children[1..] {
                    let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
                    for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
                        *lhs = truth_or(*lhs, rhs);
                    }
                }
                Ok(result)
            }
            LlkvExpr::Not(inner) => {
                // Kleene NOT: NULL stays NULL.
                let mut values = self.evaluate_predicate_truths(inner, batch, exists_eval)?;
                for value in &mut values {
                    *value = truth_not(*value);
                }
                Ok(values)
            }
            LlkvExpr::Pred(filter) => self.evaluate_filter_truths(filter, batch),
            LlkvExpr::Compare { left, op, right } => {
                self.evaluate_compare_truths(left, *op, right, batch)
            }
            LlkvExpr::InList {
                expr: target,
                list,
                negated,
            } => self.evaluate_in_list_truths(target, list, *negated, batch),
            LlkvExpr::IsNull { expr, negated } => {
                self.evaluate_is_null_truths(expr, *negated, batch)
            }
            LlkvExpr::Exists(subquery_expr) => {
                // EXISTS is resolved per row by the caller-provided closure.
                let mut values = Vec::with_capacity(batch.num_rows());
                for row_idx in 0..batch.num_rows() {
                    let value = exists_eval(self, subquery_expr, row_idx, batch)?;
                    values.push(value);
                }
                Ok(values)
            }
        }
    }
8628
    /// Evaluate a single-column filter to per-row three-valued truth.
    ///
    /// IS NULL / IS NOT NULL are answered from the accessor's null mask; all
    /// other operators dispatch on the column's type to a typed predicate.
    /// For value-comparing operators, NULL rows yield `None` (unknown).
    fn evaluate_filter_truths(
        &mut self,
        filter: &Filter<FieldId>,
        batch: &RecordBatch,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        let accessor = self.column_accessor(filter.field_id, batch)?;
        let len = accessor.len();

        match &filter.op {
            Operator::IsNull => {
                // Null checks are always definite (never unknown).
                let mut out = Vec::with_capacity(len);
                for idx in 0..len {
                    out.push(Some(accessor.is_null(idx)));
                }
                Ok(out)
            }
            Operator::IsNotNull => {
                let mut out = Vec::with_capacity(len);
                for idx in 0..len {
                    out.push(Some(!accessor.is_null(idx)));
                }
                Ok(out)
            }
            _ => match accessor {
                ColumnAccessor::Int64(array) => {
                    let predicate = build_fixed_width_predicate::<Int64Type>(&filter.op)
                        .map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Float64(array) => {
                    let predicate = build_fixed_width_predicate::<Float64Type>(&filter.op)
                        .map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Boolean(array) => {
                    let predicate =
                        build_bool_predicate(&filter.op).map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Utf8(array) => {
                    let predicate =
                        build_var_width_predicate(&filter.op).map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Date32(array) => {
                    // Date32 is day-count i32, so the Int32 predicate applies.
                    let predicate = build_fixed_width_predicate::<Int32Type>(&filter.op)
                        .map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Interval(array) => {
                    // Intervals have no typed predicate; evaluate through the
                    // literal-based fallback per row.
                    let array = array.as_ref();
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                            continue;
                        }
                        let literal =
                            Literal::Interval(interval_value_from_arrow(array.value(idx)));
                        let matches = evaluate_filter_against_literal(&literal, &filter.op)?;
                        out.push(Some(matches));
                    }
                    Ok(out)
                }
                ColumnAccessor::Decimal128 { array, scale } => {
                    // NOTE(review): decimals are compared through f64, which
                    // loses precision for values beyond ~2^53 — confirm this
                    // is acceptable for the filter path.
                    let scale_factor = 10_f64.powi(scale as i32);
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                            continue;
                        }
                        let raw_value = array.value(idx);
                        let decimal_value = raw_value as f64 / scale_factor;
                        let literal = Literal::Float64(decimal_value);
                        let matches = evaluate_filter_against_literal(&literal, &filter.op)?;
                        out.push(Some(matches));
                    }
                    Ok(out)
                }
                // An all-NULL column compares to unknown in every row.
                ColumnAccessor::Null(len) => Ok(vec![None; len]),
            },
        }
    }
8760
8761 fn evaluate_compare_truths(
8762 &mut self,
8763 left: &ScalarExpr<FieldId>,
8764 op: CompareOp,
8765 right: &ScalarExpr<FieldId>,
8766 batch: &RecordBatch,
8767 ) -> ExecutorResult<Vec<Option<bool>>> {
8768 let left_values = self.materialize_value_array(left, batch)?;
8769 let right_values = self.materialize_value_array(right, batch)?;
8770
8771 if left_values.len() != right_values.len() {
8772 return Err(Error::Internal(
8773 "mismatched compare operand lengths in cross product filter".into(),
8774 ));
8775 }
8776
8777 let len = left_values.len();
8778
8779 if matches!(left_values, ValueArray::Null(_)) || matches!(right_values, ValueArray::Null(_))
8780 {
8781 return Ok(vec![None; len]);
8782 }
8783
8784 let lhs_arr = left_values.as_array_ref();
8785 let rhs_arr = right_values.as_array_ref();
8786
8787 let result_array = llkv_compute::kernels::compute_compare(&lhs_arr, op, &rhs_arr)?;
8788 let bool_array = result_array
8789 .as_any()
8790 .downcast_ref::<BooleanArray>()
8791 .expect("compute_compare must return BooleanArray");
8792
8793 let out: Vec<Option<bool>> = bool_array.iter().collect();
8794 Ok(out)
8795 }
8796
8797 fn evaluate_is_null_truths(
8798 &mut self,
8799 expr: &ScalarExpr<FieldId>,
8800 negated: bool,
8801 batch: &RecordBatch,
8802 ) -> ExecutorResult<Vec<Option<bool>>> {
8803 let values = self.materialize_value_array(expr, batch)?;
8804 let len = values.len();
8805
8806 if let ValueArray::Null(len) = values {
8807 let result = if negated { Some(false) } else { Some(true) };
8808 return Ok(vec![result; len]);
8809 }
8810
8811 let arr = values.as_array_ref();
8812 let mut out = Vec::with_capacity(len);
8813 for idx in 0..len {
8814 let is_null = arr.is_null(idx);
8815 let result = if negated { !is_null } else { is_null };
8816 out.push(Some(result));
8817 }
8818 Ok(out)
8819 }
8820
    /// Row-wise `expr [NOT] IN (list...)` with SQL three-valued semantics.
    ///
    /// Each list element is compared for equality against the target and the
    /// per-element results are combined with Kleene OR, so a NULL candidate
    /// can only turn a "no match" into unknown, never hide a real match.
    fn evaluate_in_list_truths(
        &mut self,
        target: &ScalarExpr<FieldId>,
        list: &[ScalarExpr<FieldId>],
        negated: bool,
        batch: &RecordBatch,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        let target_values = self.materialize_value_array(target, batch)?;
        let list_values = list
            .iter()
            .map(|expr| self.materialize_value_array(expr, batch))
            .collect::<ExecutorResult<Vec<_>>>()?;

        let len = target_values.len();
        for values in &list_values {
            if values.len() != len {
                return Err(Error::Internal(
                    "mismatched IN list operand lengths in cross product filter".into(),
                ));
            }
        }

        // A typeless all-NULL target makes every membership test unknown.
        if matches!(target_values, ValueArray::Null(_)) {
            return Ok(vec![None; len]);
        }

        let target_arr = target_values.as_array_ref();
        let mut combined_result: Option<BooleanArray> = None;

        for candidate in &list_values {
            if matches!(candidate, ValueArray::Null(_)) {
                // A NULL candidate contributes "unknown" to the OR chain.
                let nulls = new_null_array(&DataType::Boolean, len);
                let bool_nulls = nulls
                    .as_any()
                    .downcast_ref::<BooleanArray>()
                    .unwrap()
                    .clone();

                match combined_result {
                    None => combined_result = Some(bool_nulls),
                    Some(prev) => {
                        combined_result = Some(or_kleene(&prev, &bool_nulls)?);
                    }
                }
                continue;
            }

            let candidate_arr = candidate.as_array_ref();

            let cmp =
                llkv_compute::kernels::compute_compare(&target_arr, CompareOp::Eq, &candidate_arr)?;
            let bool_cmp = cmp
                .as_any()
                .downcast_ref::<BooleanArray>()
                .expect("compute_compare returns BooleanArray")
                .clone();

            match combined_result {
                None => combined_result = Some(bool_cmp),
                Some(prev) => {
                    combined_result = Some(or_kleene(&prev, &bool_cmp)?);
                }
            }
        }

        // Empty IN list: nothing can match, so every row is false.
        let final_bool = combined_result.unwrap_or_else(|| {
            let mut builder = BooleanBuilder::new();
            for _ in 0..len {
                builder.append_value(false);
            }
            builder.finish()
        });

        // `not` preserves NULLs, so NOT IN stays unknown where IN was unknown.
        let final_bool = if negated {
            not(&final_bool)?
        } else {
            final_bool
        };

        let out: Vec<Option<bool>> = final_bool.iter().collect();
        Ok(out)
    }
8903
8904 fn evaluate_numeric(
8905 &mut self,
8906 expr: &ScalarExpr<FieldId>,
8907 batch: &RecordBatch,
8908 ) -> ExecutorResult<ArrayRef> {
8909 let mut required = FxHashSet::default();
8910 collect_field_ids(expr, &mut required);
8911
8912 let mut arrays = NumericArrayMap::default();
8913 for field_id in required {
8914 let numeric = self.numeric_array(field_id, batch)?;
8915 arrays.insert(field_id, numeric);
8916 }
8917
8918 NumericKernels::evaluate_batch(expr, batch.num_rows(), &arrays)
8919 }
8920
8921 fn numeric_array(
8922 &mut self,
8923 field_id: FieldId,
8924 batch: &RecordBatch,
8925 ) -> ExecutorResult<ArrayRef> {
8926 if let Some(existing) = self.numeric_cache.get(&field_id) {
8927 return Ok(existing.clone());
8928 }
8929
8930 let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
8931 Error::Internal("field mapping missing during cross product evaluation".into())
8932 })?;
8933
8934 let array_ref = batch.column(column_index).clone();
8935 self.numeric_cache.insert(field_id, array_ref.clone());
8936 Ok(array_ref)
8937 }
8938
8939 fn column_accessor(
8940 &mut self,
8941 field_id: FieldId,
8942 batch: &RecordBatch,
8943 ) -> ExecutorResult<ColumnAccessor> {
8944 if let Some(existing) = self.column_cache.get(&field_id) {
8945 return Ok(existing.clone());
8946 }
8947
8948 let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
8949 Error::Internal("field mapping missing during cross product evaluation".into())
8950 })?;
8951
8952 let accessor = ColumnAccessor::from_array(batch.column(column_index))?;
8953 self.column_cache.insert(field_id, accessor.clone());
8954 Ok(accessor)
8955 }
8956
8957 fn materialize_scalar_array(
8958 &mut self,
8959 expr: &ScalarExpr<FieldId>,
8960 batch: &RecordBatch,
8961 ) -> ExecutorResult<ArrayRef> {
8962 match expr {
8963 ScalarExpr::Column(field_id) => {
8964 let accessor = self.column_accessor(*field_id, batch)?;
8965 Ok(accessor.as_array_ref())
8966 }
8967 ScalarExpr::Literal(literal) => literal_to_constant_array(literal, batch.num_rows()),
8968 ScalarExpr::Binary { .. } => self.evaluate_numeric(expr, batch),
8969 ScalarExpr::Compare { .. } => self.evaluate_numeric(expr, batch),
8970 ScalarExpr::Not(_) => self.evaluate_numeric(expr, batch),
8971 ScalarExpr::IsNull { .. } => self.evaluate_numeric(expr, batch),
8972 ScalarExpr::Aggregate(_) => Err(Error::InvalidArgumentError(
8973 "aggregate expressions are not supported in cross product filters".into(),
8974 )),
8975 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
8976 "struct field access is not supported in cross product filters".into(),
8977 )),
8978 ScalarExpr::Cast { expr, data_type } => {
8979 let source = self.materialize_scalar_array(expr.as_ref(), batch)?;
8980 let casted = cast(source.as_ref(), data_type).map_err(|err| {
8981 Error::InvalidArgumentError(format!("failed to cast expression: {err}"))
8982 })?;
8983 Ok(casted)
8984 }
8985 ScalarExpr::Case { .. } => self.evaluate_numeric(expr, batch),
8986 ScalarExpr::Coalesce(_) => self.evaluate_numeric(expr, batch),
8987 ScalarExpr::Random => self.evaluate_numeric(expr, batch),
8988 ScalarExpr::ScalarSubquery(subquery) => {
8989 let accessor = self
8990 .scalar_subquery_columns
8991 .get(&subquery.id)
8992 .ok_or_else(|| {
8993 Error::InvalidArgumentError(
8994 "scalar subqueries are not supported in cross product filters".into(),
8995 )
8996 })?
8997 .clone();
8998 Ok(accessor.as_array_ref())
8999 }
9000 }
9001 }
9002
9003 fn materialize_value_array(
9004 &mut self,
9005 expr: &ScalarExpr<FieldId>,
9006 batch: &RecordBatch,
9007 ) -> ExecutorResult<ValueArray> {
9008 let array = self.materialize_scalar_array(expr, batch)?;
9009 ValueArray::from_array(array)
9010 }
9011}
9012
9013fn collect_field_ids(expr: &ScalarExpr<FieldId>, out: &mut FxHashSet<FieldId>) {
9015 match expr {
9016 ScalarExpr::Column(fid) => {
9017 out.insert(*fid);
9018 }
9019 ScalarExpr::Binary { left, right, .. } => {
9020 collect_field_ids(left, out);
9021 collect_field_ids(right, out);
9022 }
9023 ScalarExpr::Compare { left, right, .. } => {
9024 collect_field_ids(left, out);
9025 collect_field_ids(right, out);
9026 }
9027 ScalarExpr::Aggregate(call) => match call {
9028 AggregateCall::CountStar => {}
9029 AggregateCall::Count { expr, .. }
9030 | AggregateCall::Sum { expr, .. }
9031 | AggregateCall::Total { expr, .. }
9032 | AggregateCall::Avg { expr, .. }
9033 | AggregateCall::Min(expr)
9034 | AggregateCall::Max(expr)
9035 | AggregateCall::CountNulls(expr)
9036 | AggregateCall::GroupConcat { expr, .. } => {
9037 collect_field_ids(expr, out);
9038 }
9039 },
9040 ScalarExpr::GetField { base, .. } => collect_field_ids(base, out),
9041 ScalarExpr::Cast { expr, .. } => collect_field_ids(expr, out),
9042 ScalarExpr::Not(expr) => collect_field_ids(expr, out),
9043 ScalarExpr::IsNull { expr, .. } => collect_field_ids(expr, out),
9044 ScalarExpr::Case {
9045 operand,
9046 branches,
9047 else_expr,
9048 } => {
9049 if let Some(inner) = operand.as_deref() {
9050 collect_field_ids(inner, out);
9051 }
9052 for (when_expr, then_expr) in branches {
9053 collect_field_ids(when_expr, out);
9054 collect_field_ids(then_expr, out);
9055 }
9056 if let Some(inner) = else_expr.as_deref() {
9057 collect_field_ids(inner, out);
9058 }
9059 }
9060 ScalarExpr::Coalesce(items) => {
9061 for item in items {
9062 collect_field_ids(item, out);
9063 }
9064 }
9065 ScalarExpr::Literal(_) | ScalarExpr::Random => {}
9066 ScalarExpr::ScalarSubquery(_) => {}
9067 }
9068}
9069
9070fn strip_exists(expr: &LlkvExpr<'static, FieldId>) -> LlkvExpr<'static, FieldId> {
9071 match expr {
9072 LlkvExpr::And(children) => LlkvExpr::And(children.iter().map(strip_exists).collect()),
9073 LlkvExpr::Or(children) => LlkvExpr::Or(children.iter().map(strip_exists).collect()),
9074 LlkvExpr::Not(inner) => LlkvExpr::Not(Box::new(strip_exists(inner))),
9075 LlkvExpr::Pred(filter) => LlkvExpr::Pred(filter.clone()),
9076 LlkvExpr::Compare { left, op, right } => LlkvExpr::Compare {
9077 left: left.clone(),
9078 op: *op,
9079 right: right.clone(),
9080 },
9081 LlkvExpr::InList {
9082 expr,
9083 list,
9084 negated,
9085 } => LlkvExpr::InList {
9086 expr: expr.clone(),
9087 list: list.clone(),
9088 negated: *negated,
9089 },
9090 LlkvExpr::IsNull { expr, negated } => LlkvExpr::IsNull {
9091 expr: expr.clone(),
9092 negated: *negated,
9093 },
9094 LlkvExpr::Literal(value) => LlkvExpr::Literal(*value),
9095 LlkvExpr::Exists(_) => LlkvExpr::Literal(true),
9096 }
9097}
9098
9099fn rewrite_predicate_scalar_subqueries(
9100 expr: LlkvExpr<'static, FieldId>,
9101 literals: &FxHashMap<SubqueryId, Literal>,
9102) -> ExecutorResult<LlkvExpr<'static, FieldId>> {
9103 match expr {
9104 LlkvExpr::And(children) => {
9105 let rewritten: ExecutorResult<Vec<_>> = children
9106 .into_iter()
9107 .map(|child| rewrite_predicate_scalar_subqueries(child, literals))
9108 .collect();
9109 Ok(LlkvExpr::And(rewritten?))
9110 }
9111 LlkvExpr::Or(children) => {
9112 let rewritten: ExecutorResult<Vec<_>> = children
9113 .into_iter()
9114 .map(|child| rewrite_predicate_scalar_subqueries(child, literals))
9115 .collect();
9116 Ok(LlkvExpr::Or(rewritten?))
9117 }
9118 LlkvExpr::Not(inner) => Ok(LlkvExpr::Not(Box::new(
9119 rewrite_predicate_scalar_subqueries(*inner, literals)?,
9120 ))),
9121 LlkvExpr::Pred(filter) => Ok(LlkvExpr::Pred(filter)),
9122 LlkvExpr::Compare { left, op, right } => Ok(LlkvExpr::Compare {
9123 left: rewrite_scalar_expr_subqueries(left, literals)?,
9124 op,
9125 right: rewrite_scalar_expr_subqueries(right, literals)?,
9126 }),
9127 LlkvExpr::InList {
9128 expr,
9129 list,
9130 negated,
9131 } => Ok(LlkvExpr::InList {
9132 expr: rewrite_scalar_expr_subqueries(expr, literals)?,
9133 list: list
9134 .into_iter()
9135 .map(|item| rewrite_scalar_expr_subqueries(item, literals))
9136 .collect::<ExecutorResult<_>>()?,
9137 negated,
9138 }),
9139 LlkvExpr::IsNull { expr, negated } => Ok(LlkvExpr::IsNull {
9140 expr: rewrite_scalar_expr_subqueries(expr, literals)?,
9141 negated,
9142 }),
9143 LlkvExpr::Literal(value) => Ok(LlkvExpr::Literal(value)),
9144 LlkvExpr::Exists(subquery) => Ok(LlkvExpr::Exists(subquery)),
9145 }
9146}
9147
9148fn rewrite_scalar_expr_subqueries(
9149 expr: ScalarExpr<FieldId>,
9150 literals: &FxHashMap<SubqueryId, Literal>,
9151) -> ExecutorResult<ScalarExpr<FieldId>> {
9152 match expr {
9153 ScalarExpr::ScalarSubquery(subquery) => {
9154 let literal = literals.get(&subquery.id).ok_or_else(|| {
9155 Error::Internal(format!(
9156 "missing literal for scalar subquery {:?}",
9157 subquery.id
9158 ))
9159 })?;
9160 Ok(ScalarExpr::Literal(literal.clone()))
9161 }
9162 ScalarExpr::Column(fid) => Ok(ScalarExpr::Column(fid)),
9163 ScalarExpr::Literal(lit) => Ok(ScalarExpr::Literal(lit)),
9164 ScalarExpr::Binary { left, op, right } => Ok(ScalarExpr::Binary {
9165 left: Box::new(rewrite_scalar_expr_subqueries(*left, literals)?),
9166 op,
9167 right: Box::new(rewrite_scalar_expr_subqueries(*right, literals)?),
9168 }),
9169 ScalarExpr::Compare { left, op, right } => Ok(ScalarExpr::Compare {
9170 left: Box::new(rewrite_scalar_expr_subqueries(*left, literals)?),
9171 op,
9172 right: Box::new(rewrite_scalar_expr_subqueries(*right, literals)?),
9173 }),
9174 ScalarExpr::Not(inner) => Ok(ScalarExpr::Not(Box::new(rewrite_scalar_expr_subqueries(
9175 *inner, literals,
9176 )?))),
9177 ScalarExpr::IsNull { expr, negated } => Ok(ScalarExpr::IsNull {
9178 expr: Box::new(rewrite_scalar_expr_subqueries(*expr, literals)?),
9179 negated,
9180 }),
9181 ScalarExpr::Aggregate(agg) => Ok(ScalarExpr::Aggregate(agg)),
9182 ScalarExpr::GetField { base, field_name } => Ok(ScalarExpr::GetField {
9183 base: Box::new(rewrite_scalar_expr_subqueries(*base, literals)?),
9184 field_name,
9185 }),
9186 ScalarExpr::Cast { expr, data_type } => Ok(ScalarExpr::Cast {
9187 expr: Box::new(rewrite_scalar_expr_subqueries(*expr, literals)?),
9188 data_type,
9189 }),
9190 ScalarExpr::Case {
9191 operand,
9192 branches,
9193 else_expr,
9194 } => Ok(ScalarExpr::Case {
9195 operand: operand
9196 .map(|e| rewrite_scalar_expr_subqueries(*e, literals))
9197 .transpose()?
9198 .map(Box::new),
9199 branches: branches
9200 .into_iter()
9201 .map(|(when, then)| {
9202 Ok((
9203 rewrite_scalar_expr_subqueries(when, literals)?,
9204 rewrite_scalar_expr_subqueries(then, literals)?,
9205 ))
9206 })
9207 .collect::<ExecutorResult<_>>()?,
9208 else_expr: else_expr
9209 .map(|e| rewrite_scalar_expr_subqueries(*e, literals))
9210 .transpose()?
9211 .map(Box::new),
9212 }),
9213 ScalarExpr::Coalesce(items) => Ok(ScalarExpr::Coalesce(
9214 items
9215 .into_iter()
9216 .map(|item| rewrite_scalar_expr_subqueries(item, literals))
9217 .collect::<ExecutorResult<_>>()?,
9218 )),
9219 ScalarExpr::Random => Ok(ScalarExpr::Random),
9220 }
9221}
9222
/// Substitute named parameter `bindings` into a `SelectPlan`, returning the
/// bound copy. With no bindings the plan is cloned unchanged.
fn bind_select_plan(
    plan: &SelectPlan,
    bindings: &FxHashMap<String, Literal>,
) -> ExecutorResult<SelectPlan> {
    if bindings.is_empty() {
        return Ok(plan.clone());
    }

    let projections = plan
        .projections
        .iter()
        .map(|projection| bind_projection(projection, bindings))
        .collect::<ExecutorResult<Vec<_>>>()?;

    let filter = match &plan.filter {
        Some(wrapper) => Some(bind_select_filter(wrapper, bindings)?),
        None => None,
    };

    let aggregates = plan
        .aggregates
        .iter()
        .map(|aggregate| bind_aggregate_expr(aggregate, bindings))
        .collect::<ExecutorResult<Vec<_>>>()?;

    let scalar_subqueries = plan
        .scalar_subqueries
        .iter()
        .map(|subquery| bind_scalar_subquery(subquery, bindings))
        .collect::<ExecutorResult<Vec<_>>>()?;

    // Compound plans are rebuilt around their bound components; the outer
    // plan keeps only ORDER BY / LIMIT / OFFSET.
    // NOTE(review): the projections/filter/aggregates bound above and the
    // `distinct` flag are dropped on this path — presumably they live inside
    // the compound components instead; confirm against the planner.
    if let Some(compound) = &plan.compound {
        let bound_compound = bind_compound_select(compound, bindings)?;
        return Ok(SelectPlan {
            tables: Vec::new(),
            joins: Vec::new(),
            projections: Vec::new(),
            filter: None,
            having: None,
            aggregates: Vec::new(),
            order_by: plan.order_by.clone(),
            distinct: false,
            scalar_subqueries: Vec::new(),
            compound: Some(bound_compound),
            group_by: Vec::new(),
            value_table_mode: None,
            limit: plan.limit,
            offset: plan.offset,
        });
    }

    // NOTE(review): `order_by` is cleared here while the compound branch
    // preserves it — looks like ordering for simple plans is re-applied by
    // the caller; verify before relying on it.
    Ok(SelectPlan {
        tables: plan.tables.clone(),
        joins: plan.joins.clone(),
        projections,
        filter,
        having: plan.having.clone(),
        aggregates,
        order_by: Vec::new(),
        distinct: plan.distinct,
        scalar_subqueries,
        compound: None,
        group_by: plan.group_by.clone(),
        value_table_mode: plan.value_table_mode.clone(),
        limit: plan.limit,
        offset: plan.offset,
    })
}
9291
9292fn bind_compound_select(
9293 compound: &CompoundSelectPlan,
9294 bindings: &FxHashMap<String, Literal>,
9295) -> ExecutorResult<CompoundSelectPlan> {
9296 let initial = bind_select_plan(&compound.initial, bindings)?;
9297 let mut operations = Vec::with_capacity(compound.operations.len());
9298 for component in &compound.operations {
9299 let bound_plan = bind_select_plan(&component.plan, bindings)?;
9300 operations.push(CompoundSelectComponent {
9301 operator: component.operator.clone(),
9302 quantifier: component.quantifier.clone(),
9303 plan: bound_plan,
9304 });
9305 }
9306 Ok(CompoundSelectPlan {
9307 initial: Box::new(initial),
9308 operations,
9309 })
9310}
9311
9312fn ensure_schema_compatibility(base: &Schema, other: &Schema) -> ExecutorResult<()> {
9313 if base.fields().len() != other.fields().len() {
9314 return Err(Error::InvalidArgumentError(
9315 "compound SELECT requires matching column counts".into(),
9316 ));
9317 }
9318 for (left, right) in base.fields().iter().zip(other.fields().iter()) {
9319 if left.data_type() != right.data_type() {
9320 return Err(Error::InvalidArgumentError(format!(
9321 "compound SELECT column type mismatch: {} vs {}",
9322 left.data_type(),
9323 right.data_type()
9324 )));
9325 }
9326 }
9327 Ok(())
9328}
9329
9330fn ensure_distinct_rows(rows: &mut Vec<Vec<PlanValue>>, cache: &mut Option<FxHashSet<Vec<u8>>>) {
9331 if cache.is_some() {
9332 return;
9333 }
9334 let mut set = FxHashSet::default();
9335 let mut deduped: Vec<Vec<PlanValue>> = Vec::with_capacity(rows.len());
9336 for row in rows.drain(..) {
9337 let key = encode_row(&row);
9338 if set.insert(key) {
9339 deduped.push(row);
9340 }
9341 }
9342 *rows = deduped;
9343 *cache = Some(set);
9344}
9345
9346fn encode_row(row: &[PlanValue]) -> Vec<u8> {
9347 let mut buf = Vec::new();
9348 for value in row {
9349 encode_plan_value(&mut buf, value);
9350 buf.push(0x1F);
9351 }
9352 buf
9353}
9354
/// Append a deterministic, self-describing byte encoding of `value` to `buf`.
///
/// Each value is written as a one-byte type tag followed by a fixed-width or
/// length-prefixed payload. The encoding only needs to be deterministic and
/// collision-free for row deduplication — it is never decoded.
fn encode_plan_value(buf: &mut Vec<u8>, value: &PlanValue) {
    match value {
        // Tag 0: NULL, no payload.
        PlanValue::Null => buf.push(0),
        // Tag 1: integer, big-endian payload.
        PlanValue::Integer(v) => {
            buf.push(1);
            buf.extend_from_slice(&v.to_be_bytes());
        }
        // Tag 2: float, encoded via its IEEE-754 bit pattern so NaN and
        // signed zero encode deterministically.
        PlanValue::Float(v) => {
            buf.push(2);
            buf.extend_from_slice(&v.to_bits().to_be_bytes());
        }
        // Tag 7: decimal, raw unscaled value plus a single scale byte.
        PlanValue::Decimal(decimal) => {
            buf.push(7);
            buf.extend_from_slice(&decimal.raw_value().to_be_bytes());
            // NOTE(review): only the most-significant byte of the scale's
            // big-endian form is kept — fine if the scale type is one byte
            // wide; otherwise distinct scales could alias. Confirm the width
            // of DecimalValue::scale().
            buf.push(decimal.scale().to_be_bytes()[0]);
        }
        // Tag 3: UTF-8 string, u32 big-endian length prefix plus raw bytes.
        PlanValue::String(s) => {
            buf.push(3);
            let bytes = s.as_bytes();
            // NOTE(review): lengths beyond u32::MAX are clamped, so two
            // >4 GiB strings with equal prefixes would collide; presumably
            // unreachable in practice.
            let len = u32::try_from(bytes.len()).unwrap_or(u32::MAX);
            buf.extend_from_slice(&len.to_be_bytes());
            buf.extend_from_slice(bytes);
        }
        // Tag 5: date, days-since-epoch in big-endian.
        PlanValue::Date32(days) => {
            buf.push(5);
            buf.extend_from_slice(&days.to_be_bytes());
        }
        // Tag 4: struct, entries sorted by key so field iteration order never
        // affects the encoding; keys are length-prefixed, values recurse.
        PlanValue::Struct(map) => {
            buf.push(4);
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by(|a, b| a.0.cmp(b.0));
            let len = u32::try_from(entries.len()).unwrap_or(u32::MAX);
            buf.extend_from_slice(&len.to_be_bytes());
            for (key, val) in entries {
                let key_bytes = key.as_bytes();
                let key_len = u32::try_from(key_bytes.len()).unwrap_or(u32::MAX);
                buf.extend_from_slice(&key_len.to_be_bytes());
                buf.extend_from_slice(key_bytes);
                encode_plan_value(buf, val);
            }
        }
        // Tag 6: interval, months/days/nanos components in order.
        PlanValue::Interval(interval) => {
            buf.push(6);
            buf.extend_from_slice(&interval.months.to_be_bytes());
            buf.extend_from_slice(&interval.days.to_be_bytes());
            buf.extend_from_slice(&interval.nanos.to_be_bytes());
        }
    }
}
9404
9405fn rows_to_record_batch(
9406 schema: Arc<Schema>,
9407 rows: &[Vec<PlanValue>],
9408) -> ExecutorResult<RecordBatch> {
9409 let column_count = schema.fields().len();
9410 let mut columns: Vec<Vec<PlanValue>> = vec![Vec::with_capacity(rows.len()); column_count];
9411 for row in rows {
9412 if row.len() != column_count {
9413 return Err(Error::InvalidArgumentError(
9414 "compound SELECT produced mismatched column counts".into(),
9415 ));
9416 }
9417 for (idx, value) in row.iter().enumerate() {
9418 columns[idx].push(value.clone());
9419 }
9420 }
9421
9422 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_count);
9423 for (idx, field) in schema.fields().iter().enumerate() {
9424 let array = build_array_for_column(field.data_type(), &columns[idx])?;
9425 arrays.push(array);
9426 }
9427
9428 RecordBatch::try_new(schema, arrays).map_err(|err| {
9429 Error::InvalidArgumentError(format!("failed to materialize compound SELECT: {err}"))
9430 })
9431}
9432
9433fn build_column_lookup_map(schema: &Schema) -> FxHashMap<String, usize> {
9434 let mut lookup = FxHashMap::default();
9435 for (idx, field) in schema.fields().iter().enumerate() {
9436 lookup.insert(field.name().to_ascii_lowercase(), idx);
9437 }
9438 lookup
9439}
9440
9441fn build_group_key(
9442 batch: &RecordBatch,
9443 row_idx: usize,
9444 key_indices: &[usize],
9445) -> ExecutorResult<Vec<GroupKeyValue>> {
9446 let mut values = Vec::with_capacity(key_indices.len());
9447 for &index in key_indices {
9448 values.push(group_key_value(batch.column(index), row_idx)?);
9449 }
9450 Ok(values)
9451}
9452
9453fn group_key_value(array: &ArrayRef, row_idx: usize) -> ExecutorResult<GroupKeyValue> {
9454 if !array.is_valid(row_idx) {
9455 return Ok(GroupKeyValue::Null);
9456 }
9457
9458 match array.data_type() {
9459 DataType::Int8 => {
9460 let values = array
9461 .as_any()
9462 .downcast_ref::<Int8Array>()
9463 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
9464 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9465 }
9466 DataType::Int16 => {
9467 let values = array
9468 .as_any()
9469 .downcast_ref::<Int16Array>()
9470 .ok_or_else(|| Error::Internal("failed to downcast to Int16Array".into()))?;
9471 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9472 }
9473 DataType::Int32 => {
9474 let values = array
9475 .as_any()
9476 .downcast_ref::<Int32Array>()
9477 .ok_or_else(|| Error::Internal("failed to downcast to Int32Array".into()))?;
9478 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9479 }
9480 DataType::Int64 => {
9481 let values = array
9482 .as_any()
9483 .downcast_ref::<Int64Array>()
9484 .ok_or_else(|| Error::Internal("failed to downcast to Int64Array".into()))?;
9485 Ok(GroupKeyValue::Int(values.value(row_idx)))
9486 }
9487 DataType::UInt8 => {
9488 let values = array
9489 .as_any()
9490 .downcast_ref::<UInt8Array>()
9491 .ok_or_else(|| Error::Internal("failed to downcast to UInt8Array".into()))?;
9492 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9493 }
9494 DataType::UInt16 => {
9495 let values = array
9496 .as_any()
9497 .downcast_ref::<UInt16Array>()
9498 .ok_or_else(|| Error::Internal("failed to downcast to UInt16Array".into()))?;
9499 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9500 }
9501 DataType::UInt32 => {
9502 let values = array
9503 .as_any()
9504 .downcast_ref::<UInt32Array>()
9505 .ok_or_else(|| Error::Internal("failed to downcast to UInt32Array".into()))?;
9506 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9507 }
9508 DataType::UInt64 => {
9509 let values = array
9510 .as_any()
9511 .downcast_ref::<UInt64Array>()
9512 .ok_or_else(|| Error::Internal("failed to downcast to UInt64Array".into()))?;
9513 let value = values.value(row_idx);
9514 if value > i64::MAX as u64 {
9515 return Err(Error::InvalidArgumentError(
9516 "GROUP BY value exceeds supported integer range".into(),
9517 ));
9518 }
9519 Ok(GroupKeyValue::Int(value as i64))
9520 }
9521 DataType::Date32 => {
9522 let values = array
9523 .as_any()
9524 .downcast_ref::<Date32Array>()
9525 .ok_or_else(|| Error::Internal("failed to downcast to Date32Array".into()))?;
9526 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9527 }
9528 DataType::Boolean => {
9529 let values = array
9530 .as_any()
9531 .downcast_ref::<BooleanArray>()
9532 .ok_or_else(|| Error::Internal("failed to downcast to BooleanArray".into()))?;
9533 Ok(GroupKeyValue::Bool(values.value(row_idx)))
9534 }
9535 DataType::Utf8 => {
9536 let values = array
9537 .as_any()
9538 .downcast_ref::<StringArray>()
9539 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
9540 Ok(GroupKeyValue::String(values.value(row_idx).to_string()))
9541 }
9542 other => Err(Error::InvalidArgumentError(format!(
9543 "GROUP BY does not support column type {:?}",
9544 other
9545 ))),
9546 }
9547}
9548
/// Attempt to constant-fold a predicate into a SQL truth value.
///
/// Returns `None` when the predicate is not constant (e.g. it references a
/// column through `Pred` or `Exists`). Otherwise returns `Some(truth)` where
/// `truth` is `Some(bool)` for a definite result and `None` for SQL UNKNOWN.
fn evaluate_constant_predicate(expr: &LlkvExpr<'static, String>) -> Option<Option<bool>> {
    match expr {
        LlkvExpr::Literal(value) => Some(Some(*value)),
        LlkvExpr::Not(inner) => {
            let inner_val = evaluate_constant_predicate(inner)?;
            // Kleene NOT: UNKNOWN stays UNKNOWN.
            Some(truth_not(inner_val))
        }
        LlkvExpr::And(children) => {
            // Kleene AND over all children; an empty conjunction is TRUE.
            let mut acc = Some(true);
            for child in children {
                let child_val = evaluate_constant_predicate(child)?;
                acc = truth_and(acc, child_val);
            }
            Some(acc)
        }
        LlkvExpr::Or(children) => {
            // Kleene OR over all children; an empty disjunction is FALSE.
            let mut acc = Some(false);
            for child in children {
                let child_val = evaluate_constant_predicate(child)?;
                acc = truth_or(acc, child_val);
            }
            Some(acc)
        }
        LlkvExpr::Compare { left, op, right } => {
            let left_literal = evaluate_constant_scalar(left)?;
            let right_literal = evaluate_constant_scalar(right)?;
            // compare_literals yields None for UNKNOWN (NULL operand).
            Some(compare_literals(*op, &left_literal, &right_literal))
        }
        LlkvExpr::IsNull { expr, negated } => {
            let literal = evaluate_constant_scalar(expr)?;
            let is_null = matches!(literal, Literal::Null);
            // IS [NOT] NULL always yields a definite boolean, never UNKNOWN.
            Some(Some(if *negated { !is_null } else { is_null }))
        }
        LlkvExpr::InList {
            expr,
            list,
            negated,
        } => {
            let needle = evaluate_constant_scalar(expr)?;
            let mut saw_unknown = false;

            for candidate in list {
                let value = evaluate_constant_scalar(candidate)?;
                match compare_literals(CompareOp::Eq, &needle, &value) {
                    Some(true) => {
                        // A definite match decides the whole IN immediately.
                        return Some(Some(!*negated));
                    }
                    Some(false) => {}
                    None => saw_unknown = true,
                }
            }

            // No definite match: any UNKNOWN comparison makes the result
            // UNKNOWN; otherwise membership is definitively false (so NOT IN
            // is definitively true).
            if saw_unknown {
                Some(None)
            } else {
                Some(Some(*negated))
            }
        }
        // Pred / Exists reference table data and are not constants.
        _ => None,
    }
}
9610
/// Result of attempting to constant-fold a join predicate.
enum ConstantJoinEvaluation {
    /// The predicate folds to a definite TRUE or FALSE.
    Known(bool),
    /// The predicate folds to SQL UNKNOWN (a NULL was involved).
    Unknown,
    /// The predicate references non-constant terms and cannot be folded.
    NotConstant,
}
9616
/// Constant-fold a join predicate under three-valued logic, distinguishing
/// "definitely true/false", "UNKNOWN (NULL-valued)", and "not constant at
/// all" so the caller can decide whether the ON clause can be pre-resolved.
fn evaluate_constant_join_expr(expr: &LlkvExpr<'static, String>) -> ConstantJoinEvaluation {
    match expr {
        LlkvExpr::Literal(value) => ConstantJoinEvaluation::Known(*value),
        LlkvExpr::And(children) => {
            // Kleene AND: FALSE dominates, then NotConstant aborts, then
            // any UNKNOWN degrades TRUE to UNKNOWN.
            let mut saw_unknown = false;
            for child in children {
                match evaluate_constant_join_expr(child) {
                    ConstantJoinEvaluation::Known(false) => {
                        return ConstantJoinEvaluation::Known(false);
                    }
                    ConstantJoinEvaluation::Known(true) => {}
                    ConstantJoinEvaluation::Unknown => saw_unknown = true,
                    ConstantJoinEvaluation::NotConstant => {
                        return ConstantJoinEvaluation::NotConstant;
                    }
                }
            }
            if saw_unknown {
                ConstantJoinEvaluation::Unknown
            } else {
                ConstantJoinEvaluation::Known(true)
            }
        }
        LlkvExpr::Or(children) => {
            // Kleene OR: TRUE dominates, then NotConstant aborts, then any
            // UNKNOWN degrades FALSE to UNKNOWN.
            let mut saw_unknown = false;
            for child in children {
                match evaluate_constant_join_expr(child) {
                    ConstantJoinEvaluation::Known(true) => {
                        return ConstantJoinEvaluation::Known(true);
                    }
                    ConstantJoinEvaluation::Known(false) => {}
                    ConstantJoinEvaluation::Unknown => saw_unknown = true,
                    ConstantJoinEvaluation::NotConstant => {
                        return ConstantJoinEvaluation::NotConstant;
                    }
                }
            }
            if saw_unknown {
                ConstantJoinEvaluation::Unknown
            } else {
                ConstantJoinEvaluation::Known(false)
            }
        }
        // Kleene NOT flips Known and preserves Unknown.
        LlkvExpr::Not(inner) => match evaluate_constant_join_expr(inner) {
            ConstantJoinEvaluation::Known(value) => ConstantJoinEvaluation::Known(!value),
            ConstantJoinEvaluation::Unknown => ConstantJoinEvaluation::Unknown,
            ConstantJoinEvaluation::NotConstant => ConstantJoinEvaluation::NotConstant,
        },
        LlkvExpr::Compare { left, op, right } => {
            let left_lit = evaluate_constant_scalar(left);
            let right_lit = evaluate_constant_scalar(right);

            // A constant NULL on either side makes the comparison UNKNOWN
            // even if the other side is not constant.
            if matches!(left_lit, Some(Literal::Null)) || matches!(right_lit, Some(Literal::Null)) {
                return ConstantJoinEvaluation::Unknown;
            }

            let (Some(left_lit), Some(right_lit)) = (left_lit, right_lit) else {
                return ConstantJoinEvaluation::NotConstant;
            };

            match compare_literals(*op, &left_lit, &right_lit) {
                Some(result) => ConstantJoinEvaluation::Known(result),
                None => ConstantJoinEvaluation::Unknown,
            }
        }
        // IS [NOT] NULL on a constant always yields a definite boolean.
        LlkvExpr::IsNull { expr, negated } => match evaluate_constant_scalar(expr) {
            Some(literal) => {
                let is_null = matches!(literal, Literal::Null);
                let value = if *negated { !is_null } else { is_null };
                ConstantJoinEvaluation::Known(value)
            }
            None => ConstantJoinEvaluation::NotConstant,
        },
        LlkvExpr::InList {
            expr,
            list,
            negated,
        } => {
            let needle = match evaluate_constant_scalar(expr) {
                Some(literal) => literal,
                None => return ConstantJoinEvaluation::NotConstant,
            };

            // NULL IN (...) is always UNKNOWN, regardless of the list.
            if matches!(needle, Literal::Null) {
                return ConstantJoinEvaluation::Unknown;
            }

            let mut saw_unknown = false;
            for candidate in list {
                let value = match evaluate_constant_scalar(candidate) {
                    Some(literal) => literal,
                    None => return ConstantJoinEvaluation::NotConstant,
                };

                match compare_literals(CompareOp::Eq, &needle, &value) {
                    Some(true) => {
                        // A definite membership hit decides the result.
                        let result = !*negated;
                        return ConstantJoinEvaluation::Known(result);
                    }
                    Some(false) => {}
                    None => saw_unknown = true,
                }
            }

            if saw_unknown {
                ConstantJoinEvaluation::Unknown
            } else {
                // No match and no UNKNOWN comparisons: IN is definitively
                // false, NOT IN definitively true.
                let result = *negated;
                ConstantJoinEvaluation::Known(result)
            }
        }
        // Pred / Exists reference table data and cannot be folded.
        _ => ConstantJoinEvaluation::NotConstant,
    }
}
9732
/// How NULL operands behave in comparisons.
///
/// NOTE(review): only the SQL-standard three-valued-logic variant exists
/// today; the enum is presumably a seam for dialect-specific behavior —
/// confirm before removing.
enum NullComparisonBehavior {
    /// Comparisons involving NULL yield UNKNOWN (SQL three-valued logic).
    ThreeValuedLogic,
}
9736
/// Constant-fold `expr` to a literal, treating aggregate calls as
/// non-constant. Returns `None` when the expression is not a constant.
fn evaluate_constant_scalar(expr: &ScalarExpr<String>) -> Option<Literal> {
    evaluate_constant_scalar_internal(expr, false)
}
9740
/// Constant-fold `expr` to a literal, additionally folding aggregate calls
/// (valid only in single-row contexts where each aggregate sees exactly one
/// value). Returns `None` when the expression is not a constant.
fn evaluate_constant_scalar_with_aggregates(expr: &ScalarExpr<String>) -> Option<Literal> {
    evaluate_constant_scalar_internal(expr, true)
}
9744
/// Shared implementation for constant-folding a scalar expression.
///
/// Returns `None` when the expression depends on non-constant input
/// (columns, struct field access, RANDOM, subqueries, or — unless
/// `allow_aggregates` — aggregate calls). SQL NULL folds to
/// `Some(Literal::Null)`, not `None`, so callers can distinguish "not
/// constant" from "constantly NULL".
fn evaluate_constant_scalar_internal(
    expr: &ScalarExpr<String>,
    allow_aggregates: bool,
) -> Option<Literal> {
    match expr {
        ScalarExpr::Literal(lit) => Some(lit.clone()),
        ScalarExpr::Binary { left, op, right } => {
            let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
            let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
            evaluate_binary_literal(*op, &left_value, &right_value)
        }
        ScalarExpr::Cast { expr, data_type } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            cast_literal_to_type(&value, data_type)
        }
        ScalarExpr::Not(inner) => {
            let value = evaluate_constant_scalar_internal(inner, allow_aggregates)?;
            // NOT yields integer 0/1 (SQLite-style boolean) and maps an
            // UNKNOWN operand to NULL.
            match literal_truthiness(&value) {
                Some(true) => Some(Literal::Int128(0)),
                Some(false) => Some(Literal::Int128(1)),
                None => Some(Literal::Null),
            }
        }
        ScalarExpr::IsNull { expr, negated } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            let is_null = matches!(value, Literal::Null);
            Some(Literal::Boolean(if *negated { !is_null } else { is_null }))
        }
        ScalarExpr::Coalesce(items) => {
            // First non-NULL constant wins; a non-constant item aborts the
            // whole fold (we cannot know whether it would be NULL at runtime).
            // An empty or all-non-constant list folds to None.
            let mut saw_null = false;
            for item in items {
                match evaluate_constant_scalar_internal(item, allow_aggregates) {
                    Some(Literal::Null) => saw_null = true,
                    Some(value) => return Some(value),
                    None => return None,
                }
            }
            if saw_null { Some(Literal::Null) } else { None }
        }
        ScalarExpr::Compare { left, op, right } => {
            let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
            let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
            // UNKNOWN comparison results fold to NULL.
            match compare_literals(*op, &left_value, &right_value) {
                Some(flag) => Some(Literal::Boolean(flag)),
                None => Some(Literal::Null),
            }
        }
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => {
            if let Some(operand_expr) = operand {
                // Simple CASE: compare the operand against each WHEN value
                // by equality; the first definite match wins.
                let operand_value =
                    evaluate_constant_scalar_internal(operand_expr, allow_aggregates)?;
                for (when_expr, then_expr) in branches {
                    let when_value =
                        evaluate_constant_scalar_internal(when_expr, allow_aggregates)?;
                    if let Some(true) = compare_literals(CompareOp::Eq, &operand_value, &when_value)
                    {
                        return evaluate_constant_scalar_internal(then_expr, allow_aggregates);
                    }
                }
            } else {
                // Searched CASE: first branch whose condition is TRUE wins;
                // FALSE and UNKNOWN conditions both fall through.
                for (condition_expr, result_expr) in branches {
                    let condition_value =
                        evaluate_constant_scalar_internal(condition_expr, allow_aggregates)?;
                    match literal_truthiness(&condition_value) {
                        Some(true) => {
                            return evaluate_constant_scalar_internal(
                                result_expr,
                                allow_aggregates,
                            );
                        }
                        Some(false) => {}
                        None => {}
                    }
                }
            }

            // No branch matched: ELSE value, or NULL when ELSE is absent.
            if let Some(else_branch) = else_expr {
                evaluate_constant_scalar_internal(else_branch, allow_aggregates)
            } else {
                Some(Literal::Null)
            }
        }
        ScalarExpr::Column(_) => None,
        ScalarExpr::Aggregate(call) => {
            // Aggregates fold only when the caller has opted in (single-row
            // contexts); otherwise they are non-constant.
            if allow_aggregates {
                evaluate_constant_aggregate(call, allow_aggregates)
            } else {
                None
            }
        }
        ScalarExpr::GetField { .. } => None,
        ScalarExpr::Random => None, ScalarExpr::ScalarSubquery(_) => None,
    }
}
9844
/// Folds an aggregate call applied to a single constant row into a literal.
///
/// Returns `None` when the argument cannot be constant-folded (the recursive
/// call yields `None`) or when the resulting value type is unsupported for
/// the aggregate; `Some(Literal::Null)` models a SQL NULL result.
fn evaluate_constant_aggregate(
    call: &AggregateCall<String>,
    allow_aggregates: bool,
) -> Option<Literal> {
    match call {
        // COUNT(*) over the single constant row is always 1.
        AggregateCall::CountStar => Some(Literal::Int128(1)),
        AggregateCall::Count { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            // COUNT skips NULLs: a NULL argument counts as 0, anything else as 1.
            if matches!(value, Literal::Null) {
                Some(Literal::Int128(0))
            } else {
                Some(Literal::Int128(1))
            }
        }
        AggregateCall::Sum { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                // SUM over only NULL input stays NULL.
                Literal::Null => Some(Literal::Null),
                Literal::Int128(value) => Some(Literal::Int128(value)),
                Literal::Float64(value) => Some(Literal::Float64(value)),
                // Booleans sum as 0/1 integers.
                Literal::Boolean(flag) => Some(Literal::Int128(if flag { 1 } else { 0 })),
                _ => None,
            }
        }
        AggregateCall::Total { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                // TOTAL differs from SUM here: NULL input yields 0, not NULL.
                Literal::Null => Some(Literal::Int128(0)),
                Literal::Int128(value) => Some(Literal::Int128(value)),
                Literal::Float64(value) => Some(Literal::Float64(value)),
                Literal::Boolean(flag) => Some(Literal::Int128(if flag { 1 } else { 0 })),
                _ => None,
            }
        }
        AggregateCall::Avg { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                // AVG of one value is the value itself, widened to float.
                other => {
                    let numeric = literal_to_f64(&other)?;
                    Some(Literal::Float64(numeric))
                }
            }
        }
        // MIN/MAX of a single value is the value itself (NULL stays NULL).
        AggregateCall::Min(expr) => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                other => Some(other),
            }
        }
        AggregateCall::Max(expr) => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                other => Some(other),
            }
        }
        AggregateCall::CountNulls(expr) => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            // Inverse of COUNT: 1 only when the single value is NULL.
            let count = if matches!(value, Literal::Null) { 1 } else { 0 };
            Some(Literal::Int128(count))
        }
        AggregateCall::GroupConcat {
            expr, separator: _, ..
        } => {
            // Separator is irrelevant for a single row; render the one value
            // as text using the same 0/1 convention for booleans as above.
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                Literal::String(s) => Some(Literal::String(s)),
                Literal::Int128(i) => Some(Literal::String(i.to_string())),
                Literal::Float64(f) => Some(Literal::String(f.to_string())),
                Literal::Boolean(b) => Some(Literal::String(if b { "1" } else { "0" }.to_string())),
                _ => None,
            }
        }
    }
}
9923
9924fn evaluate_binary_literal(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
9925 match op {
9926 BinaryOp::And => evaluate_literal_logical_and(left, right),
9927 BinaryOp::Or => evaluate_literal_logical_or(left, right),
9928 BinaryOp::Add
9929 | BinaryOp::Subtract
9930 | BinaryOp::Multiply
9931 | BinaryOp::Divide
9932 | BinaryOp::Modulo => {
9933 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
9934 return Some(Literal::Null);
9935 }
9936
9937 match op {
9938 BinaryOp::Add => add_literals(left, right),
9939 BinaryOp::Subtract => subtract_literals(left, right),
9940 BinaryOp::Multiply => multiply_literals(left, right),
9941 BinaryOp::Divide => divide_literals(left, right),
9942 BinaryOp::Modulo => modulo_literals(left, right),
9943 BinaryOp::And
9944 | BinaryOp::Or
9945 | BinaryOp::BitwiseShiftLeft
9946 | BinaryOp::BitwiseShiftRight => unreachable!(),
9947 }
9948 }
9949 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
9950 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
9951 return Some(Literal::Null);
9952 }
9953
9954 let lhs = literal_to_i128(left)?;
9956 let rhs = literal_to_i128(right)?;
9957
9958 let result = match op {
9960 BinaryOp::BitwiseShiftLeft => (lhs as i64).wrapping_shl(rhs as u32) as i128,
9961 BinaryOp::BitwiseShiftRight => (lhs as i64).wrapping_shr(rhs as u32) as i128,
9962 _ => unreachable!(),
9963 };
9964
9965 Some(Literal::Int128(result))
9966 }
9967 }
9968}
9969
9970fn evaluate_literal_logical_and(left: &Literal, right: &Literal) -> Option<Literal> {
9971 let left_truth = literal_truthiness(left);
9972 if matches!(left_truth, Some(false)) {
9973 return Some(Literal::Int128(0));
9974 }
9975
9976 let right_truth = literal_truthiness(right);
9977 if matches!(right_truth, Some(false)) {
9978 return Some(Literal::Int128(0));
9979 }
9980
9981 match (left_truth, right_truth) {
9982 (Some(true), Some(true)) => Some(Literal::Int128(1)),
9983 (Some(true), None) | (None, Some(true)) | (None, None) => Some(Literal::Null),
9984 _ => Some(Literal::Null),
9985 }
9986}
9987
9988fn evaluate_literal_logical_or(left: &Literal, right: &Literal) -> Option<Literal> {
9989 let left_truth = literal_truthiness(left);
9990 if matches!(left_truth, Some(true)) {
9991 return Some(Literal::Int128(1));
9992 }
9993
9994 let right_truth = literal_truthiness(right);
9995 if matches!(right_truth, Some(true)) {
9996 return Some(Literal::Int128(1));
9997 }
9998
9999 match (left_truth, right_truth) {
10000 (Some(false), Some(false)) => Some(Literal::Int128(0)),
10001 (Some(false), None) | (None, Some(false)) | (None, None) => Some(Literal::Null),
10002 _ => Some(Literal::Null),
10003 }
10004}
10005
10006fn add_literals(left: &Literal, right: &Literal) -> Option<Literal> {
10007 match (left, right) {
10008 (Literal::Int128(lhs), Literal::Int128(rhs)) => {
10009 Some(Literal::Int128(lhs.saturating_add(*rhs)))
10010 }
10011 _ => {
10012 let lhs = literal_to_f64(left)?;
10013 let rhs = literal_to_f64(right)?;
10014 Some(Literal::Float64(lhs + rhs))
10015 }
10016 }
10017}
10018
10019fn subtract_literals(left: &Literal, right: &Literal) -> Option<Literal> {
10020 match (left, right) {
10021 (Literal::Int128(lhs), Literal::Int128(rhs)) => {
10022 Some(Literal::Int128(lhs.saturating_sub(*rhs)))
10023 }
10024 _ => {
10025 let lhs = literal_to_f64(left)?;
10026 let rhs = literal_to_f64(right)?;
10027 Some(Literal::Float64(lhs - rhs))
10028 }
10029 }
10030}
10031
10032fn multiply_literals(left: &Literal, right: &Literal) -> Option<Literal> {
10033 match (left, right) {
10034 (Literal::Int128(lhs), Literal::Int128(rhs)) => {
10035 Some(Literal::Int128(lhs.saturating_mul(*rhs)))
10036 }
10037 _ => {
10038 let lhs = literal_to_f64(left)?;
10039 let rhs = literal_to_f64(right)?;
10040 Some(Literal::Float64(lhs * rhs))
10041 }
10042 }
10043}
10044
/// Divides two literals, preferring exact integer division when both sides
/// are integer-like; division by zero yields SQL NULL rather than an error.
fn divide_literals(left: &Literal, right: &Literal) -> Option<Literal> {
    // Narrower than `literal_to_i128`: deliberately excludes Float64 so that
    // any floating operand forces the f64 fallback path below.
    fn literal_to_i128_from_integer_like(literal: &Literal) -> Option<i128> {
        match literal {
            Literal::Int128(value) => Some(*value),
            // A decimal participates only if it rescales to scale 0 cleanly.
            Literal::Decimal128(value) => llkv_compute::scalar::decimal::rescale(*value, 0)
                .ok()
                .map(|integral| integral.raw_value()),
            Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
            Literal::Date32(value) => Some(*value as i128),
            _ => None,
        }
    }

    if let (Some(lhs), Some(rhs)) = (
        literal_to_i128_from_integer_like(left),
        literal_to_i128_from_integer_like(right),
    ) {
        // SQL semantics: x / 0 is NULL, not a runtime error.
        if rhs == 0 {
            return Some(Literal::Null);
        }

        // i128::MIN / -1 would overflow (and panic); compute it in f64 instead.
        if lhs == i128::MIN && rhs == -1 {
            return Some(Literal::Float64((lhs as f64) / (rhs as f64)));
        }

        return Some(Literal::Int128(lhs / rhs));
    }

    // Mixed or floating operands: divide as f64, again mapping /0 to NULL.
    let lhs = literal_to_f64(left)?;
    let rhs = literal_to_f64(right)?;
    if rhs == 0.0 {
        return Some(Literal::Null);
    }
    Some(Literal::Float64(lhs / rhs))
}
10080
10081fn modulo_literals(left: &Literal, right: &Literal) -> Option<Literal> {
10082 let lhs = literal_to_i128(left)?;
10083 let rhs = literal_to_i128(right)?;
10084 if rhs == 0 {
10085 return Some(Literal::Null);
10086 }
10087 Some(Literal::Int128(lhs % rhs))
10088}
10089
10090fn literal_to_f64(literal: &Literal) -> Option<f64> {
10091 match literal {
10092 Literal::Int128(value) => Some(*value as f64),
10093 Literal::Float64(value) => Some(*value),
10094 Literal::Decimal128(value) => Some(value.to_f64()),
10095 Literal::Boolean(value) => Some(if *value { 1.0 } else { 0.0 }),
10096 Literal::Date32(value) => Some(*value as f64),
10097 _ => None,
10098 }
10099}
10100
10101fn literal_to_i128(literal: &Literal) -> Option<i128> {
10102 match literal {
10103 Literal::Int128(value) => Some(*value),
10104 Literal::Float64(value) => Some(*value as i128),
10105 Literal::Decimal128(value) => llkv_compute::scalar::decimal::rescale(*value, 0)
10106 .ok()
10107 .map(|integral| integral.raw_value()),
10108 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
10109 Literal::Date32(value) => Some(*value as i128),
10110 _ => None,
10111 }
10112}
10113
10114fn literal_truthiness(literal: &Literal) -> Option<bool> {
10115 match literal {
10116 Literal::Boolean(value) => Some(*value),
10117 Literal::Int128(value) => Some(*value != 0),
10118 Literal::Float64(value) => Some(*value != 0.0),
10119 Literal::Decimal128(value) => Some(decimal_truthy(*value)),
10120 Literal::Date32(value) => Some(*value != 0),
10121 Literal::Null => None,
10122 _ => None,
10123 }
10124}
10125
10126fn plan_value_truthiness(value: &PlanValue) -> Option<bool> {
10127 match value {
10128 PlanValue::Integer(v) => Some(*v != 0),
10129 PlanValue::Float(v) => Some(*v != 0.0),
10130 PlanValue::Decimal(v) => Some(decimal_truthy(*v)),
10131 PlanValue::Date32(v) => Some(*v != 0),
10132 PlanValue::Null => None,
10133 _ => None,
10134 }
10135}
10136
/// Maps an optional integer to SQL truthiness: `None` stays unknown,
/// otherwise true iff the value is nonzero.
fn option_i64_truthiness(value: Option<i64>) -> Option<bool> {
    match value {
        Some(v) => Some(v != 0),
        None => None,
    }
}
10140
10141fn evaluate_plan_value_logical_and(left: PlanValue, right: PlanValue) -> PlanValue {
10142 let left_truth = plan_value_truthiness(&left);
10143 if matches!(left_truth, Some(false)) {
10144 return PlanValue::Integer(0);
10145 }
10146
10147 let right_truth = plan_value_truthiness(&right);
10148 if matches!(right_truth, Some(false)) {
10149 return PlanValue::Integer(0);
10150 }
10151
10152 match (left_truth, right_truth) {
10153 (Some(true), Some(true)) => PlanValue::Integer(1),
10154 (Some(true), None) | (None, Some(true)) | (None, None) => PlanValue::Null,
10155 _ => PlanValue::Null,
10156 }
10157}
10158
10159fn evaluate_plan_value_logical_or(left: PlanValue, right: PlanValue) -> PlanValue {
10160 let left_truth = plan_value_truthiness(&left);
10161 if matches!(left_truth, Some(true)) {
10162 return PlanValue::Integer(1);
10163 }
10164
10165 let right_truth = plan_value_truthiness(&right);
10166 if matches!(right_truth, Some(true)) {
10167 return PlanValue::Integer(1);
10168 }
10169
10170 match (left_truth, right_truth) {
10171 (Some(false), Some(false)) => PlanValue::Integer(0),
10172 (Some(false), None) | (None, Some(false)) | (None, None) => PlanValue::Null,
10173 _ => PlanValue::Null,
10174 }
10175}
10176
/// SQL three-valued AND over optional integers (NULL encoded as `None`):
/// any false (zero) operand forces 0, two true operands give 1, and every
/// remaining combination involving NULL yields NULL.
fn evaluate_option_logical_and(left: Option<i64>, right: Option<i64>) -> Option<i64> {
    let lhs = left.map(|v| v != 0);
    let rhs = right.map(|v| v != 0);
    match (lhs, rhs) {
        (Some(false), _) | (_, Some(false)) => Some(0),
        (Some(true), Some(true)) => Some(1),
        _ => None,
    }
}
10194
/// SQL three-valued OR over optional integers (NULL encoded as `None`):
/// any true (nonzero) operand forces 1, two false operands give 0, and
/// every remaining combination involving NULL yields NULL.
fn evaluate_option_logical_or(left: Option<i64>, right: Option<i64>) -> Option<i64> {
    let lhs = left.map(|v| v != 0);
    let rhs = right.map(|v| v != 0);
    match (lhs, rhs) {
        (Some(true), _) | (_, Some(true)) => Some(1),
        (Some(false), Some(false)) => Some(0),
        _ => None,
    }
}
10212
/// Casts a constant literal to the target Arrow `DataType`, mirroring the
/// executor's runtime CAST semantics. NULL casts to NULL for every target;
/// unsupported literal/type pairings return `None`.
fn cast_literal_to_type(literal: &Literal, data_type: &DataType) -> Option<Literal> {
    if matches!(literal, Literal::Null) {
        return Some(Literal::Null);
    }

    match data_type {
        DataType::Boolean => literal_truthiness(literal).map(Literal::Boolean),
        // All float widths are represented by a single Float64 literal.
        DataType::Float16 | DataType::Float32 | DataType::Float64 => {
            let value = literal_to_f64(literal)?;
            Some(Literal::Float64(value))
        }
        // All integer widths collapse to the Int128 literal; no range check
        // is applied here — NOTE(review): confirm narrowing is validated
        // downstream when the value is materialized.
        DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => {
            let value = literal_to_i128(literal)?;
            Some(Literal::Int128(value))
        }
        // String casts render the value; booleans become "1"/"0" and dates
        // go through the date formatter.
        DataType::Utf8 | DataType::LargeUtf8 => Some(Literal::String(match literal {
            Literal::String(text) => text.clone(),
            Literal::Int128(value) => value.to_string(),
            Literal::Float64(value) => value.to_string(),
            Literal::Decimal128(value) => value.to_string(),
            Literal::Boolean(value) => {
                if *value {
                    "1".to_string()
                } else {
                    "0".to_string()
                }
            }
            Literal::Date32(days) => format_date32_literal(*days).ok()?,
            Literal::Struct(_) | Literal::Null | Literal::Interval(_) => return None,
        })),
        DataType::Decimal128(precision, scale) => {
            literal_to_decimal_literal(literal, *precision, *scale)
        }
        // NOTE(review): Decimal256 targets reuse the 128-bit conversion path —
        // confirm precision beyond 128 bits cannot be requested here.
        DataType::Decimal256(precision, scale) => {
            literal_to_decimal_literal(literal, *precision, *scale)
        }
        // Intervals only cast from intervals (identity) or NULL.
        DataType::Interval(IntervalUnit::MonthDayNano) => match literal {
            Literal::Interval(interval) => Some(Literal::Interval(*interval)),
            Literal::Null => Some(Literal::Null),
            _ => None,
        },
        // Dates cast from dates (identity) or from parseable date strings.
        DataType::Date32 => match literal {
            Literal::Null => Some(Literal::Null),
            Literal::Date32(days) => Some(Literal::Date32(*days)),
            Literal::String(text) => parse_date32_literal(text).ok().map(Literal::Date32),
            _ => None,
        },
        _ => None,
    }
}
10270
10271fn literal_to_decimal_literal(literal: &Literal, precision: u8, scale: i8) -> Option<Literal> {
10272 match literal {
10273 Literal::Decimal128(value) => align_decimal_to_scale(*value, precision, scale)
10274 .ok()
10275 .map(Literal::Decimal128),
10276 Literal::Int128(value) => {
10277 let int = i64::try_from(*value).ok()?;
10278 decimal_from_i64(int, precision, scale)
10279 .ok()
10280 .map(Literal::Decimal128)
10281 }
10282 Literal::Float64(value) => decimal_from_f64(*value, precision, scale)
10283 .ok()
10284 .map(Literal::Decimal128),
10285 Literal::Boolean(value) => {
10286 let int = if *value { 1 } else { 0 };
10287 decimal_from_i64(int, precision, scale)
10288 .ok()
10289 .map(Literal::Decimal128)
10290 }
10291 Literal::Null => Some(Literal::Null),
10292 _ => None,
10293 }
10294}
10295
/// Compares two literals using SQL three-valued NULL semantics; thin wrapper
/// over `compare_literals_with_mode` with the default behavior.
fn compare_literals(op: CompareOp, left: &Literal, right: &Literal) -> Option<bool> {
    compare_literals_with_mode(op, left, right, NullComparisonBehavior::ThreeValuedLogic)
}
10299
10300fn bind_select_filter(
10301 filter: &llkv_plan::SelectFilter,
10302 bindings: &FxHashMap<String, Literal>,
10303) -> ExecutorResult<llkv_plan::SelectFilter> {
10304 let predicate = bind_predicate_expr(&filter.predicate, bindings)?;
10305 let subqueries = filter
10306 .subqueries
10307 .iter()
10308 .map(|subquery| bind_filter_subquery(subquery, bindings))
10309 .collect::<ExecutorResult<Vec<_>>>()?;
10310
10311 Ok(llkv_plan::SelectFilter {
10312 predicate,
10313 subqueries,
10314 })
10315}
10316
10317fn bind_filter_subquery(
10318 subquery: &llkv_plan::FilterSubquery,
10319 bindings: &FxHashMap<String, Literal>,
10320) -> ExecutorResult<llkv_plan::FilterSubquery> {
10321 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
10322 Ok(llkv_plan::FilterSubquery {
10323 id: subquery.id,
10324 plan: Box::new(bound_plan),
10325 correlated_columns: subquery.correlated_columns.clone(),
10326 })
10327}
10328
10329fn bind_scalar_subquery(
10330 subquery: &llkv_plan::ScalarSubquery,
10331 bindings: &FxHashMap<String, Literal>,
10332) -> ExecutorResult<llkv_plan::ScalarSubquery> {
10333 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
10334 Ok(llkv_plan::ScalarSubquery {
10335 id: subquery.id,
10336 plan: Box::new(bound_plan),
10337 correlated_columns: subquery.correlated_columns.clone(),
10338 })
10339}
10340
10341fn bind_projection(
10342 projection: &SelectProjection,
10343 bindings: &FxHashMap<String, Literal>,
10344) -> ExecutorResult<SelectProjection> {
10345 match projection {
10346 SelectProjection::AllColumns => Ok(projection.clone()),
10347 SelectProjection::AllColumnsExcept { exclude } => Ok(SelectProjection::AllColumnsExcept {
10348 exclude: exclude.clone(),
10349 }),
10350 SelectProjection::Column { name, alias } => {
10351 if let Some(literal) = bindings.get(name) {
10352 let expr = ScalarExpr::Literal(literal.clone());
10353 Ok(SelectProjection::Computed {
10354 expr,
10355 alias: alias.clone().unwrap_or_else(|| name.clone()),
10356 })
10357 } else {
10358 Ok(projection.clone())
10359 }
10360 }
10361 SelectProjection::Computed { expr, alias } => Ok(SelectProjection::Computed {
10362 expr: bind_scalar_expr(expr, bindings)?,
10363 alias: alias.clone(),
10364 }),
10365 }
10366}
10367
10368fn bind_aggregate_expr(
10369 aggregate: &AggregateExpr,
10370 bindings: &FxHashMap<String, Literal>,
10371) -> ExecutorResult<AggregateExpr> {
10372 match aggregate {
10373 AggregateExpr::CountStar { .. } => Ok(aggregate.clone()),
10374 AggregateExpr::Column {
10375 column,
10376 alias,
10377 function,
10378 distinct,
10379 } => {
10380 if bindings.contains_key(column) {
10381 return Err(Error::InvalidArgumentError(
10382 "correlated columns are not supported inside aggregate expressions".into(),
10383 ));
10384 }
10385 Ok(AggregateExpr::Column {
10386 column: column.clone(),
10387 alias: alias.clone(),
10388 function: function.clone(),
10389 distinct: *distinct,
10390 })
10391 }
10392 }
10393}
10394
/// Recursively substitutes correlated-column bindings into a scalar
/// expression: bound column references become literals, and every composite
/// node is rebuilt with its children rebound. Aggregates and scalar
/// subqueries are cloned as-is (their binding is handled elsewhere).
fn bind_scalar_expr(
    expr: &ScalarExpr<String>,
    bindings: &FxHashMap<String, Literal>,
) -> ExecutorResult<ScalarExpr<String>> {
    match expr {
        // The only substitution point: a column that names a binding folds
        // into the bound literal value.
        ScalarExpr::Column(name) => {
            if let Some(literal) = bindings.get(name) {
                Ok(ScalarExpr::Literal(literal.clone()))
            } else {
                Ok(ScalarExpr::Column(name.clone()))
            }
        }
        ScalarExpr::Literal(literal) => Ok(ScalarExpr::Literal(literal.clone())),
        ScalarExpr::Binary { left, op, right } => Ok(ScalarExpr::Binary {
            left: Box::new(bind_scalar_expr(left, bindings)?),
            op: *op,
            right: Box::new(bind_scalar_expr(right, bindings)?),
        }),
        ScalarExpr::Compare { left, op, right } => Ok(ScalarExpr::Compare {
            left: Box::new(bind_scalar_expr(left, bindings)?),
            op: *op,
            right: Box::new(bind_scalar_expr(right, bindings)?),
        }),
        ScalarExpr::Aggregate(call) => Ok(ScalarExpr::Aggregate(call.clone())),
        // Field access on a base that folded to a struct literal is resolved
        // eagerly; a missing field degrades to NULL.
        ScalarExpr::GetField { base, field_name } => {
            let bound_base = bind_scalar_expr(base, bindings)?;
            match bound_base {
                ScalarExpr::Literal(literal) => {
                    let value = extract_struct_field(&literal, field_name).unwrap_or(Literal::Null);
                    Ok(ScalarExpr::Literal(value))
                }
                other => Ok(ScalarExpr::GetField {
                    base: Box::new(other),
                    field_name: field_name.clone(),
                }),
            }
        }
        ScalarExpr::Cast { expr, data_type } => Ok(ScalarExpr::Cast {
            expr: Box::new(bind_scalar_expr(expr, bindings)?),
            data_type: data_type.clone(),
        }),
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => {
            let bound_operand = match operand {
                Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
                None => None,
            };
            let mut bound_branches = Vec::with_capacity(branches.len());
            for (when_expr, then_expr) in branches {
                bound_branches.push((
                    bind_scalar_expr(when_expr, bindings)?,
                    bind_scalar_expr(then_expr, bindings)?,
                ));
            }
            let bound_else = match else_expr {
                Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
                None => None,
            };
            Ok(ScalarExpr::Case {
                operand: bound_operand,
                branches: bound_branches,
                else_expr: bound_else,
            })
        }
        ScalarExpr::Coalesce(items) => {
            let mut bound_items = Vec::with_capacity(items.len());
            for item in items {
                bound_items.push(bind_scalar_expr(item, bindings)?);
            }
            Ok(ScalarExpr::Coalesce(bound_items))
        }
        ScalarExpr::Not(inner) => Ok(ScalarExpr::Not(Box::new(bind_scalar_expr(
            inner, bindings,
        )?))),
        ScalarExpr::IsNull { expr, negated } => Ok(ScalarExpr::IsNull {
            expr: Box::new(bind_scalar_expr(expr, bindings)?),
            negated: *negated,
        }),
        ScalarExpr::Random => Ok(ScalarExpr::Random),
        // Scalar subqueries are rebound separately via bind_scalar_subquery.
        ScalarExpr::ScalarSubquery(sub) => Ok(ScalarExpr::ScalarSubquery(sub.clone())),
    }
}
10480
/// Recursively substitutes correlated-column bindings into a predicate tree.
/// Leaf predicates over a bound column are evaluated eagerly and collapse to
/// boolean literals; all other nodes are rebuilt with children rebound.
fn bind_predicate_expr(
    expr: &LlkvExpr<'static, String>,
    bindings: &FxHashMap<String, Literal>,
) -> ExecutorResult<LlkvExpr<'static, String>> {
    match expr {
        LlkvExpr::And(children) => {
            let mut bound = Vec::with_capacity(children.len());
            for child in children {
                bound.push(bind_predicate_expr(child, bindings)?);
            }
            Ok(LlkvExpr::And(bound))
        }
        LlkvExpr::Or(children) => {
            let mut bound = Vec::with_capacity(children.len());
            for child in children {
                bound.push(bind_predicate_expr(child, bindings)?);
            }
            Ok(LlkvExpr::Or(bound))
        }
        LlkvExpr::Not(inner) => Ok(LlkvExpr::Not(Box::new(bind_predicate_expr(
            inner, bindings,
        )?))),
        // Leaf predicate: may fold to a constant if its column is bound.
        LlkvExpr::Pred(filter) => bind_filter_predicate(filter, bindings),
        LlkvExpr::Compare { left, op, right } => Ok(LlkvExpr::Compare {
            left: bind_scalar_expr(left, bindings)?,
            op: *op,
            right: bind_scalar_expr(right, bindings)?,
        }),
        LlkvExpr::InList {
            expr,
            list,
            negated,
        } => {
            let target = bind_scalar_expr(expr, bindings)?;
            let mut bound_list = Vec::with_capacity(list.len());
            for item in list {
                bound_list.push(bind_scalar_expr(item, bindings)?);
            }
            Ok(LlkvExpr::InList {
                expr: target,
                list: bound_list,
                negated: *negated,
            })
        }
        LlkvExpr::IsNull { expr, negated } => Ok(LlkvExpr::IsNull {
            expr: bind_scalar_expr(expr, bindings)?,
            negated: *negated,
        }),
        LlkvExpr::Literal(value) => Ok(LlkvExpr::Literal(*value)),
        // EXISTS carries only a subquery reference; nothing to substitute here.
        LlkvExpr::Exists(subquery) => Ok(LlkvExpr::Exists(subquery.clone())),
    }
}
10533
10534fn bind_filter_predicate(
10535 filter: &Filter<'static, String>,
10536 bindings: &FxHashMap<String, Literal>,
10537) -> ExecutorResult<LlkvExpr<'static, String>> {
10538 if let Some(literal) = bindings.get(&filter.field_id) {
10539 let result = evaluate_filter_against_literal(literal, &filter.op)?;
10540 return Ok(LlkvExpr::Literal(result));
10541 }
10542 Ok(LlkvExpr::Pred(filter.clone()))
10543}
10544
10545fn evaluate_filter_against_literal(value: &Literal, op: &Operator) -> ExecutorResult<bool> {
10546 use std::ops::Bound;
10547
10548 match op {
10549 Operator::IsNull => Ok(matches!(value, Literal::Null)),
10550 Operator::IsNotNull => Ok(!matches!(value, Literal::Null)),
10551 Operator::Equals(rhs) => Ok(literal_equals(value, rhs).unwrap_or(false)),
10552 Operator::GreaterThan(rhs) => Ok(literal_compare(value, rhs)
10553 .map(|cmp| cmp == std::cmp::Ordering::Greater)
10554 .unwrap_or(false)),
10555 Operator::GreaterThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
10556 .map(|cmp| matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal))
10557 .unwrap_or(false)),
10558 Operator::LessThan(rhs) => Ok(literal_compare(value, rhs)
10559 .map(|cmp| cmp == std::cmp::Ordering::Less)
10560 .unwrap_or(false)),
10561 Operator::LessThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
10562 .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal))
10563 .unwrap_or(false)),
10564 Operator::In(values) => Ok(values
10565 .iter()
10566 .any(|candidate| literal_equals(value, candidate).unwrap_or(false))),
10567 Operator::Range { lower, upper } => {
10568 let lower_ok = match lower {
10569 Bound::Unbounded => Some(true),
10570 Bound::Included(bound) => literal_compare(value, bound).map(|cmp| {
10571 matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
10572 }),
10573 Bound::Excluded(bound) => {
10574 literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Greater)
10575 }
10576 }
10577 .unwrap_or(false);
10578
10579 let upper_ok = match upper {
10580 Bound::Unbounded => Some(true),
10581 Bound::Included(bound) => literal_compare(value, bound)
10582 .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal)),
10583 Bound::Excluded(bound) => {
10584 literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Less)
10585 }
10586 }
10587 .unwrap_or(false);
10588
10589 Ok(lower_ok && upper_ok)
10590 }
10591 Operator::StartsWith {
10592 pattern,
10593 case_sensitive,
10594 } => {
10595 let target = if *case_sensitive {
10596 pattern.to_string()
10597 } else {
10598 pattern.to_ascii_lowercase()
10599 };
10600 Ok(literal_string(value, *case_sensitive)
10601 .map(|source| source.starts_with(&target))
10602 .unwrap_or(false))
10603 }
10604 Operator::EndsWith {
10605 pattern,
10606 case_sensitive,
10607 } => {
10608 let target = if *case_sensitive {
10609 pattern.to_string()
10610 } else {
10611 pattern.to_ascii_lowercase()
10612 };
10613 Ok(literal_string(value, *case_sensitive)
10614 .map(|source| source.ends_with(&target))
10615 .unwrap_or(false))
10616 }
10617 Operator::Contains {
10618 pattern,
10619 case_sensitive,
10620 } => {
10621 let target = if *case_sensitive {
10622 pattern.to_string()
10623 } else {
10624 pattern.to_ascii_lowercase()
10625 };
10626 Ok(literal_string(value, *case_sensitive)
10627 .map(|source| source.contains(&target))
10628 .unwrap_or(false))
10629 }
10630 }
10631}
10632
/// Orders two literals, coercing across the numeric/date family where a
/// meaningful ordering exists; returns `None` for incomparable type pairs
/// (and for NaN float comparisons, via `partial_cmp`).
fn literal_compare(lhs: &Literal, rhs: &Literal) -> Option<std::cmp::Ordering> {
    match (lhs, rhs) {
        (Literal::Int128(a), Literal::Int128(b)) => Some(a.cmp(b)),
        (Literal::Float64(a), Literal::Float64(b)) => a.partial_cmp(b),
        // Mixed int/float comparisons are performed in f64.
        (Literal::Int128(a), Literal::Float64(b)) => (*a as f64).partial_cmp(b),
        (Literal::Float64(a), Literal::Int128(b)) => a.partial_cmp(&(*b as f64)),
        // Dates compare via their underlying day count, including against numbers.
        (Literal::Date32(a), Literal::Date32(b)) => Some(a.cmp(b)),
        (Literal::Date32(a), Literal::Int128(b)) => Some((*a as i128).cmp(b)),
        (Literal::Int128(a), Literal::Date32(b)) => Some(a.cmp(&(*b as i128))),
        (Literal::Date32(a), Literal::Float64(b)) => (*a as f64).partial_cmp(b),
        (Literal::Float64(a), Literal::Date32(b)) => a.partial_cmp(&(*b as f64)),
        (Literal::String(a), Literal::String(b)) => Some(a.cmp(b)),
        (Literal::Interval(a), Literal::Interval(b)) => Some(compare_interval_values(*a, *b)),
        _ => None,
    }
}
10649
/// Tests two literals for equality. Booleans and strings compare directly;
/// numeric/date/interval pairings defer to `literal_compare` so the same
/// cross-type coercions apply. Returns `None` when the pair is incomparable.
fn literal_equals(lhs: &Literal, rhs: &Literal) -> Option<bool> {
    match (lhs, rhs) {
        (Literal::Boolean(a), Literal::Boolean(b)) => Some(a == b),
        (Literal::String(a), Literal::String(b)) => Some(a == b),
        // Every pairing that literal_compare can order is tested via Equal.
        (Literal::Int128(_), Literal::Int128(_))
        | (Literal::Int128(_), Literal::Float64(_))
        | (Literal::Float64(_), Literal::Int128(_))
        | (Literal::Float64(_), Literal::Float64(_))
        | (Literal::Date32(_), Literal::Date32(_))
        | (Literal::Date32(_), Literal::Int128(_))
        | (Literal::Int128(_), Literal::Date32(_))
        | (Literal::Date32(_), Literal::Float64(_))
        | (Literal::Float64(_), Literal::Date32(_))
        | (Literal::Interval(_), Literal::Interval(_)) => {
            literal_compare(lhs, rhs).map(|cmp| cmp == std::cmp::Ordering::Equal)
        }
        _ => None,
    }
}
10669
10670fn literal_string(literal: &Literal, case_sensitive: bool) -> Option<String> {
10671 match literal {
10672 Literal::String(value) => {
10673 if case_sensitive {
10674 Some(value.clone())
10675 } else {
10676 Some(value.to_ascii_lowercase())
10677 }
10678 }
10679 Literal::Date32(value) => {
10680 let formatted = format_date32_literal(*value).ok()?;
10681 if case_sensitive {
10682 Some(formatted)
10683 } else {
10684 Some(formatted.to_ascii_lowercase())
10685 }
10686 }
10687 _ => None,
10688 }
10689}
10690
10691fn extract_struct_field(literal: &Literal, field_name: &str) -> Option<Literal> {
10692 if let Literal::Struct(fields) = literal {
10693 for (name, value) in fields {
10694 if name.eq_ignore_ascii_case(field_name) {
10695 return Some((**value).clone());
10696 }
10697 }
10698 }
10699 None
10700}
10701
/// Walks a scalar expression tree and records the id of every scalar
/// subquery it contains into `ids`.
fn collect_scalar_subquery_ids(expr: &ScalarExpr<FieldId>, ids: &mut FxHashSet<SubqueryId>) {
    match expr {
        ScalarExpr::ScalarSubquery(subquery) => {
            ids.insert(subquery.id);
        }
        ScalarExpr::Binary { left, right, .. } => {
            collect_scalar_subquery_ids(left, ids);
            collect_scalar_subquery_ids(right, ids);
        }
        ScalarExpr::Compare { left, right, .. } => {
            collect_scalar_subquery_ids(left, ids);
            collect_scalar_subquery_ids(right, ids);
        }
        ScalarExpr::GetField { base, .. } => {
            collect_scalar_subquery_ids(base, ids);
        }
        ScalarExpr::Cast { expr, .. } => {
            collect_scalar_subquery_ids(expr, ids);
        }
        ScalarExpr::Not(expr) => {
            collect_scalar_subquery_ids(expr, ids);
        }
        ScalarExpr::IsNull { expr, .. } => {
            collect_scalar_subquery_ids(expr, ids);
        }
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => {
            if let Some(op) = operand {
                collect_scalar_subquery_ids(op, ids);
            }
            for (when_expr, then_expr) in branches {
                collect_scalar_subquery_ids(when_expr, ids);
                collect_scalar_subquery_ids(then_expr, ids);
            }
            if let Some(else_expr) = else_expr {
                collect_scalar_subquery_ids(else_expr, ids);
            }
        }
        ScalarExpr::Coalesce(items) => {
            for item in items {
                collect_scalar_subquery_ids(item, ids);
            }
        }
        // Leaves with no nested expressions. NOTE(review): Aggregate arguments
        // are not descended into — confirm scalar subqueries cannot appear
        // inside an aggregate call at this stage.
        ScalarExpr::Aggregate(_)
        | ScalarExpr::Column(_)
        | ScalarExpr::Literal(_)
        | ScalarExpr::Random => {}
    }
}
10754
10755fn collect_predicate_scalar_subquery_ids(
10756 expr: &LlkvExpr<'static, FieldId>,
10757 ids: &mut FxHashSet<SubqueryId>,
10758) {
10759 match expr {
10760 LlkvExpr::And(children) | LlkvExpr::Or(children) => {
10761 for child in children {
10762 collect_predicate_scalar_subquery_ids(child, ids);
10763 }
10764 }
10765 LlkvExpr::Not(inner) => collect_predicate_scalar_subquery_ids(inner, ids),
10766 LlkvExpr::Compare { left, right, .. } => {
10767 collect_scalar_subquery_ids(left, ids);
10768 collect_scalar_subquery_ids(right, ids);
10769 }
10770 LlkvExpr::InList { expr, list, .. } => {
10771 collect_scalar_subquery_ids(expr, ids);
10772 for item in list {
10773 collect_scalar_subquery_ids(item, ids);
10774 }
10775 }
10776 LlkvExpr::IsNull { expr, .. } => {
10777 collect_scalar_subquery_ids(expr, ids);
10778 }
10779 LlkvExpr::Exists(_) | LlkvExpr::Pred(_) | LlkvExpr::Literal(_) => {
10780 }
10782 }
10783}
10784
/// Rewrites a scalar expression so that every scalar subquery whose id has a
/// materialized result column in `mapping` becomes a plain column reference
/// to that field; unmapped subqueries are left in place. All composite nodes
/// are rebuilt with their children rewritten.
fn rewrite_scalar_expr_for_subqueries(
    expr: &ScalarExpr<FieldId>,
    mapping: &FxHashMap<SubqueryId, FieldId>,
) -> ScalarExpr<FieldId> {
    match expr {
        // The substitution point: mapped subqueries become column reads.
        ScalarExpr::ScalarSubquery(subquery) => mapping
            .get(&subquery.id)
            .map(|field_id| ScalarExpr::Column(*field_id))
            .unwrap_or_else(|| ScalarExpr::ScalarSubquery(subquery.clone())),
        ScalarExpr::Binary { left, op, right } => ScalarExpr::Binary {
            left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
            op: *op,
            right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
        },
        ScalarExpr::Compare { left, op, right } => ScalarExpr::Compare {
            left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
            op: *op,
            right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
        },
        ScalarExpr::GetField { base, field_name } => ScalarExpr::GetField {
            base: Box::new(rewrite_scalar_expr_for_subqueries(base, mapping)),
            field_name: field_name.clone(),
        },
        ScalarExpr::Cast { expr, data_type } => ScalarExpr::Cast {
            expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
            data_type: data_type.clone(),
        },
        ScalarExpr::Not(expr) => {
            ScalarExpr::Not(Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)))
        }
        ScalarExpr::IsNull { expr, negated } => ScalarExpr::IsNull {
            expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
            negated: *negated,
        },
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => ScalarExpr::Case {
            operand: operand
                .as_ref()
                .map(|op| Box::new(rewrite_scalar_expr_for_subqueries(op, mapping))),
            branches: branches
                .iter()
                .map(|(when_expr, then_expr)| {
                    (
                        rewrite_scalar_expr_for_subqueries(when_expr, mapping),
                        rewrite_scalar_expr_for_subqueries(then_expr, mapping),
                    )
                })
                .collect(),
            else_expr: else_expr
                .as_ref()
                .map(|expr| Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping))),
        },
        ScalarExpr::Coalesce(items) => ScalarExpr::Coalesce(
            items
                .iter()
                .map(|item| rewrite_scalar_expr_for_subqueries(item, mapping))
                .collect(),
        ),
        // Leaves: nothing to rewrite, clone as-is.
        ScalarExpr::Aggregate(_)
        | ScalarExpr::Column(_)
        | ScalarExpr::Literal(_)
        | ScalarExpr::Random => expr.clone(),
    }
}
10852
10853fn collect_correlated_bindings(
10854 context: &mut CrossProductExpressionContext,
10855 batch: &RecordBatch,
10856 row_idx: usize,
10857 columns: &[llkv_plan::CorrelatedColumn],
10858) -> ExecutorResult<FxHashMap<String, Literal>> {
10859 let mut out = FxHashMap::default();
10860
10861 for correlated in columns {
10862 if !correlated.field_path.is_empty() {
10863 return Err(Error::InvalidArgumentError(
10864 "correlated field path resolution is not yet supported".into(),
10865 ));
10866 }
10867
10868 let field_id = context
10869 .field_id_for_column(&correlated.column)
10870 .ok_or_else(|| {
10871 Error::InvalidArgumentError(format!(
10872 "correlated column '{}' not found in outer query output",
10873 correlated.column
10874 ))
10875 })?;
10876
10877 let accessor = context.column_accessor(field_id, batch)?;
10878 let literal = accessor.literal_at(row_idx)?;
10879 out.insert(correlated.placeholder.clone(), literal);
10880 }
10881
10882 Ok(out)
10883}
10884
/// Lazily evaluated result of a SELECT: holds either a pending table scan or
/// a pre-materialized batch, plus the LIMIT/OFFSET to apply while streaming.
#[derive(Clone)]
pub struct SelectExecution<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Display name of the table this execution reads from.
    table_name: String,
    // Output schema of every emitted batch.
    schema: Arc<Schema>,
    // Backing data source (lazy scan or single batch).
    stream: SelectStream<P>,
    // Optional LIMIT, applied to the emitted row stream.
    limit: Option<usize>,
    // Optional OFFSET (rows skipped before emitting), applied while streaming.
    offset: Option<usize>,
}
10897
/// Data source backing a [`SelectExecution`].
#[derive(Clone)]
enum SelectStream<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Lazily scan a table with the given projections, filter, and ordering.
    Projection {
        table: Arc<ExecutorTable<P>>,
        projections: Vec<ScanProjection>,
        filter_expr: LlkvExpr<'static, FieldId>,
        options: ScanStreamOptions<P>,
        // True when the scan covers the whole table; enables synthesizing
        // all-null rows for rows the scan did not produce.
        full_table_scan: bool,
        order_by: Vec<OrderByPlan>,
        distinct: bool,
    },
    /// A single pre-computed batch (constructed via `new_single_batch`).
    Aggregation {
        batch: RecordBatch,
    },
}
10916
impl<P> SelectExecution<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Build an execution that lazily scans `table` when [`Self::stream`]
    /// is called. LIMIT/OFFSET start unset; see `with_limit`/`with_offset`.
    #[allow(clippy::too_many_arguments)]
    fn new_projection(
        table_name: String,
        schema: Arc<Schema>,
        table: Arc<ExecutorTable<P>>,
        projections: Vec<ScanProjection>,
        filter_expr: LlkvExpr<'static, FieldId>,
        options: ScanStreamOptions<P>,
        full_table_scan: bool,
        order_by: Vec<OrderByPlan>,
        distinct: bool,
    ) -> Self {
        Self {
            table_name,
            schema,
            stream: SelectStream::Projection {
                table,
                projections,
                filter_expr,
                options,
                full_table_scan,
                order_by,
                distinct,
            },
            limit: None,
            offset: None,
        }
    }

    /// Wrap an already-materialized batch as a single-batch execution.
    pub fn new_single_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
        Self {
            table_name,
            schema,
            stream: SelectStream::Aggregation { batch },
            limit: None,
            offset: None,
        }
    }

    /// Alias for [`Self::new_single_batch`].
    pub fn from_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
        Self::new_single_batch(table_name, schema, batch)
    }

    /// Name of the table this execution reads from.
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    /// Schema of every batch this execution emits.
    pub fn schema(&self) -> Arc<Schema> {
        Arc::clone(&self.schema)
    }

    /// Set (or clear) the LIMIT applied while streaming.
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
        self.limit = limit;
        self
    }

    /// Set (or clear) the OFFSET applied while streaming.
    pub fn with_offset(mut self, offset: Option<usize>) -> Self {
        self.offset = offset;
        self
    }

    /// Drive the execution, invoking `on_batch` for each output batch.
    ///
    /// OFFSET and LIMIT are applied here by wrapping `on_batch`; DISTINCT,
    /// multi-key ORDER BY, and null-row synthesis for full table scans are
    /// handled in the projection arm below.
    pub fn stream(
        self,
        mut on_batch: impl FnMut(RecordBatch) -> ExecutorResult<()>,
    ) -> ExecutorResult<()> {
        let limit = self.limit;
        let mut offset = self.offset.unwrap_or(0);
        let mut rows_emitted = 0;

        // Shadow `on_batch` with a wrapper that applies OFFSET then LIMIT
        // before forwarding to the caller's callback.
        let mut on_batch = |batch: RecordBatch| -> ExecutorResult<()> {
            let rows = batch.num_rows();
            let mut batch_to_emit = batch;

            if offset > 0 {
                if rows == 0 {
                    // Empty batch: nothing to skip; fall through and emit.
                } else if rows <= offset {
                    // The whole batch is consumed by the remaining offset.
                    offset -= rows;
                    return Ok(());
                } else {
                    // Skip the first `offset` rows; offset is now satisfied.
                    batch_to_emit = batch_to_emit.slice(offset, rows - offset);
                    offset = 0;
                }
            }

            if let Some(limit_val) = limit {
                if rows_emitted >= limit_val {
                    return Ok(());
                }
                let remaining = limit_val - rows_emitted;
                if batch_to_emit.num_rows() > remaining {
                    batch_to_emit = batch_to_emit.slice(0, remaining);
                }
                rows_emitted += batch_to_emit.num_rows();
            }

            on_batch(batch_to_emit)
        };

        let schema = Arc::clone(&self.schema);
        match self.stream {
            SelectStream::Projection {
                table,
                projections,
                filter_expr,
                options,
                full_table_scan,
                order_by,
                distinct,
            } => {
                let total_rows = table.total_rows.load(Ordering::SeqCst);
                if total_rows == 0 {
                    return Ok(());
                }

                let mut error: Option<Error> = None;
                let mut produced = false;
                let mut produced_rows: u64 = 0;
                // Whether synthesized null rows must precede scan output.
                let capture_nulls_first = matches!(options.order, Some(spec) if spec.nulls_first);
                // Sorting must happen here (post-scan) when there are
                // multiple keys or the scan itself is unordered.
                let needs_post_sort =
                    !order_by.is_empty() && (order_by.len() > 1 || options.order.is_none());
                let collect_batches = needs_post_sort || capture_nulls_first;
                let include_nulls = options.include_nulls;
                let has_row_id_filter = options.row_id_filter.is_some();
                let mut distinct_state = if distinct {
                    Some(DistinctState::default())
                } else {
                    None
                };
                let scan_options = options;
                let mut buffered_batches: Vec<RecordBatch> = Vec::new();
                table
                    .table
                    .scan_stream(projections, &filter_expr, scan_options, |batch| {
                        // The scan callback cannot return errors, so the
                        // first failure is stashed and later batches skipped.
                        if error.is_some() {
                            return;
                        }
                        let mut batch = batch;
                        if let Some(state) = distinct_state.as_mut() {
                            match distinct_filter_batch(batch, state) {
                                Ok(Some(filtered)) => {
                                    batch = filtered;
                                }
                                Ok(None) => {
                                    // Every row was a duplicate.
                                    return;
                                }
                                Err(err) => {
                                    error = Some(err);
                                    return;
                                }
                            }
                        }
                        produced = true;
                        produced_rows = produced_rows.saturating_add(batch.num_rows() as u64);
                        if collect_batches {
                            // Buffer for post-sort / nulls-first reordering.
                            buffered_batches.push(batch);
                        } else if let Err(err) = on_batch(batch) {
                            error = Some(err);
                        }
                    })?;
                if let Some(err) = error {
                    return Err(err);
                }
                if !produced {
                    // A full table scan that produced nothing still has
                    // `total_rows` rows: represent them as all-null output.
                    if !distinct && full_table_scan && total_rows > 0 {
                        for batch in synthesize_null_scan(Arc::clone(&schema), total_rows)? {
                            on_batch(batch)?;
                        }
                    }
                    return Ok(());
                }
                // Rows the scan skipped (e.g. entirely-null rows) are
                // synthesized so a full scan accounts for every table row.
                let mut null_batches: Vec<RecordBatch> = Vec::new();
                if !distinct
                    && include_nulls
                    && full_table_scan
                    && produced_rows < total_rows
                    && !has_row_id_filter
                {
                    let missing = total_rows - produced_rows;
                    if missing > 0 {
                        null_batches = synthesize_null_scan(Arc::clone(&schema), missing)?;
                    }
                }

                if collect_batches {
                    if needs_post_sort {
                        if !null_batches.is_empty() {
                            buffered_batches.extend(null_batches);
                        }
                        if !buffered_batches.is_empty() {
                            // ORDER BY requires a single combined batch.
                            let combined =
                                concat_batches(&schema, &buffered_batches).map_err(|err| {
                                    Error::InvalidArgumentError(format!(
                                        "failed to concatenate result batches for ORDER BY: {}",
                                        err
                                    ))
                                })?;
                            let sorted_batch =
                                sort_record_batch_with_order(&schema, &combined, &order_by)?;
                            on_batch(sorted_batch)?;
                        }
                    } else if capture_nulls_first {
                        // NULLS FIRST: synthesized rows precede scan output.
                        for batch in null_batches {
                            on_batch(batch)?;
                        }
                        for batch in buffered_batches {
                            on_batch(batch)?;
                        }
                    }
                } else if !null_batches.is_empty() {
                    // Unsorted path: append synthesized rows at the end.
                    for batch in null_batches {
                        on_batch(batch)?;
                    }
                }
                Ok(())
            }
            SelectStream::Aggregation { batch } => on_batch(batch),
        }
    }

    /// Run the execution and collect every emitted batch.
    pub fn collect(self) -> ExecutorResult<Vec<RecordBatch>> {
        let mut batches = Vec::new();
        self.stream(|batch| {
            batches.push(batch);
            Ok(())
        })?;
        Ok(batches)
    }

    /// Run the execution and convert the output into row-oriented
    /// `PlanValue`s plus the output column names.
    pub fn collect_rows(self) -> ExecutorResult<ExecutorRowBatch> {
        let schema = self.schema();
        let mut rows: Vec<Vec<PlanValue>> = Vec::new();
        self.stream(|batch| {
            for row_idx in 0..batch.num_rows() {
                let mut row: Vec<PlanValue> = Vec::with_capacity(batch.num_columns());
                for col_idx in 0..batch.num_columns() {
                    let value = llkv_plan::plan_value_from_array(batch.column(col_idx), row_idx)?;
                    row.push(value);
                }
                rows.push(row);
            }
            Ok(())
        })?;
        let columns = schema
            .fields()
            .iter()
            .map(|field| field.name().to_string())
            .collect();
        Ok(ExecutorRowBatch { columns, rows })
    }

    /// Like [`Self::collect_rows`] but discards the column names.
    pub fn into_rows(self) -> ExecutorResult<Vec<Vec<PlanValue>>> {
        Ok(self.collect_rows()?.rows)
    }
}
11187
11188impl<P> fmt::Debug for SelectExecution<P>
11189where
11190 P: Pager<Blob = EntryHandle> + Send + Sync,
11191{
11192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11193 f.debug_struct("SelectExecution")
11194 .field("table_name", &self.table_name)
11195 .field("schema", &self.schema)
11196 .finish()
11197 }
11198}
11199
/// Best-effort static type inference for a scalar expression against
/// `base_schema`. Returns `None` when the type cannot be determined from the
/// visible structure (e.g. unsupported node kinds or unresolvable columns).
fn infer_type_recursive(
    expr: &ScalarExpr<String>,
    base_schema: &Schema,
    column_lookup_map: &FxHashMap<String, usize>,
) -> Option<DataType> {
    use arrow::datatypes::IntervalUnit;
    use llkv_expr::literal::Literal;

    match expr {
        // Columns take the type recorded in the base schema.
        ScalarExpr::Column(name) => resolve_column_name_to_index(name, column_lookup_map)
            .map(|idx| base_schema.field(idx).data_type().clone()),
        ScalarExpr::Literal(lit) => match lit {
            Literal::Decimal128(v) => Some(DataType::Decimal128(v.precision(), v.scale())),
            Literal::Float64(_) => Some(DataType::Float64),
            // Integer literals are carried as i128 but surface as Int64.
            Literal::Int128(_) => Some(DataType::Int64),
            Literal::Boolean(_) => Some(DataType::Boolean),
            Literal::String(_) => Some(DataType::Utf8),
            Literal::Date32(_) => Some(DataType::Date32),
            Literal::Null => Some(DataType::Null),
            Literal::Interval(_) => Some(DataType::Interval(IntervalUnit::MonthDayNano)),
            _ => None,
        },
        ScalarExpr::Binary { left, op: _, right } => {
            let l = infer_type_recursive(left, base_schema, column_lookup_map)?;
            let r = infer_type_recursive(right, base_schema, column_lookup_map)?;

            // Any float operand promotes the whole expression to Float64.
            if matches!(l, DataType::Float64) || matches!(r, DataType::Float64) {
                return Some(DataType::Float64);
            }

            match (l, r) {
                // Two decimals: keep the wider scale; precision is widened to
                // the Decimal128 maximum (38) as a conservative bound.
                (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) => {
                    Some(DataType::Decimal128(38, s1.max(s2)))
                }
                // One decimal operand: keep its precision/scale.
                (DataType::Decimal128(p, s), _) => Some(DataType::Decimal128(p, s)),
                (_, DataType::Decimal128(p, s)) => Some(DataType::Decimal128(p, s)),
                // Otherwise default to the left operand's type.
                (l, _) => Some(l),
            }
        }
        // A cast pins the result type explicitly.
        ScalarExpr::Cast { data_type, .. } => Some(data_type.clone()),
        _ => None,
    }
}
11250
11251fn expand_order_targets(
11252 order_items: &[OrderByPlan],
11253 projections: &[ScanProjection],
11254) -> ExecutorResult<Vec<OrderByPlan>> {
11255 let mut expanded = Vec::new();
11256
11257 for item in order_items {
11258 match &item.target {
11259 OrderTarget::All => {
11260 if projections.is_empty() {
11261 return Err(Error::InvalidArgumentError(
11262 "ORDER BY ALL requires at least one projection".into(),
11263 ));
11264 }
11265
11266 for (idx, projection) in projections.iter().enumerate() {
11267 if matches!(projection, ScanProjection::Computed { .. }) {
11268 return Err(Error::InvalidArgumentError(
11269 "ORDER BY ALL cannot reference computed projections".into(),
11270 ));
11271 }
11272
11273 let mut clone = item.clone();
11274 clone.target = OrderTarget::Index(idx);
11275 expanded.push(clone);
11276 }
11277 }
11278 _ => expanded.push(item.clone()),
11279 }
11280 }
11281
11282 Ok(expanded)
11283}
11284
/// Translate a single ORDER BY item into a storage-level [`ScanOrderSpec`].
///
/// Supports column targets (by name or by projection index); `ORDER BY ALL`
/// must already have been expanded (see `expand_order_targets`). Only Int64,
/// Int32, and Utf8 columns support native ordering, plus a text-to-integer
/// cast transform for Utf8 columns.
fn resolve_scan_order<P>(
    table: &ExecutorTable<P>,
    projections: &[ScanProjection],
    order_plan: &OrderByPlan,
) -> ExecutorResult<ScanOrderSpec>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Resolve the order target to a schema column plus its field id.
    let (column, field_id) = match &order_plan.target {
        OrderTarget::Column(name) => {
            let column = table.schema.resolve(name).ok_or_else(|| {
                Error::InvalidArgumentError(format!("unknown column '{}' in ORDER BY", name))
            })?;
            (column, column.field_id)
        }
        OrderTarget::Index(position) => {
            let projection = projections.get(*position).ok_or_else(|| {
                Error::InvalidArgumentError(format!(
                    "ORDER BY position {} is out of range",
                    // Positions are reported 1-based to the user.
                    position + 1
                ))
            })?;
            match projection {
                ScanProjection::Column(store_projection) => {
                    let field_id = store_projection.logical_field_id.field_id();
                    let column = table.schema.column_by_field_id(field_id).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "unknown column with field id {field_id} in ORDER BY"
                        ))
                    })?;
                    (column, field_id)
                }
                ScanProjection::Computed { .. } => {
                    return Err(Error::InvalidArgumentError(
                        "ORDER BY position referring to computed projection is not supported"
                            .into(),
                    ));
                }
            }
        }
        OrderTarget::All => {
            return Err(Error::InvalidArgumentError(
                "ORDER BY ALL should be expanded before execution".into(),
            ));
        }
    };

    // Map the requested sort flavor onto a supported storage transform.
    let transform = match order_plan.sort_type {
        OrderSortType::Native => match column.data_type {
            DataType::Int64 => ScanOrderTransform::IdentityInt64,
            DataType::Int32 => ScanOrderTransform::IdentityInt32,
            DataType::Utf8 => ScanOrderTransform::IdentityUtf8,
            ref other => {
                return Err(Error::InvalidArgumentError(format!(
                    "ORDER BY on column type {:?} is not supported",
                    other
                )));
            }
        },
        OrderSortType::CastTextToInteger => {
            if column.data_type != DataType::Utf8 {
                return Err(Error::InvalidArgumentError(
                    "ORDER BY CAST expects a text column".into(),
                ));
            }
            ScanOrderTransform::CastUtf8ToInteger
        }
    };

    let direction = if order_plan.ascending {
        ScanOrderDirection::Ascending
    } else {
        ScanOrderDirection::Descending
    };

    Ok(ScanOrderSpec {
        field_id,
        direction,
        nulls_first: order_plan.nulls_first,
        transform,
    })
}
11367
11368fn synthesize_null_scan(schema: Arc<Schema>, total_rows: u64) -> ExecutorResult<Vec<RecordBatch>> {
11369 let row_count = usize::try_from(total_rows).map_err(|_| {
11370 Error::InvalidArgumentError("table row count exceeds supported in-memory batch size".into())
11371 })?;
11372
11373 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
11374 for field in schema.fields() {
11375 match field.data_type() {
11376 DataType::Int64 => {
11377 let mut builder = Int64Builder::with_capacity(row_count);
11378 for _ in 0..row_count {
11379 builder.append_null();
11380 }
11381 arrays.push(Arc::new(builder.finish()));
11382 }
11383 DataType::Float64 => {
11384 let mut builder = arrow::array::Float64Builder::with_capacity(row_count);
11385 for _ in 0..row_count {
11386 builder.append_null();
11387 }
11388 arrays.push(Arc::new(builder.finish()));
11389 }
11390 DataType::Utf8 => {
11391 let mut builder = arrow::array::StringBuilder::with_capacity(row_count, 0);
11392 for _ in 0..row_count {
11393 builder.append_null();
11394 }
11395 arrays.push(Arc::new(builder.finish()));
11396 }
11397 DataType::Date32 => {
11398 let mut builder = arrow::array::Date32Builder::with_capacity(row_count);
11399 for _ in 0..row_count {
11400 builder.append_null();
11401 }
11402 arrays.push(Arc::new(builder.finish()));
11403 }
11404 other => {
11405 return Err(Error::InvalidArgumentError(format!(
11406 "unsupported data type in null synthesis: {other:?}"
11407 )));
11408 }
11409 }
11410 }
11411
11412 let batch = RecordBatch::try_new(schema, arrays)?;
11413 Ok(vec![batch])
11414}
11415
/// Materialized per-table scan output used when building cross products.
struct TableCrossProductData {
    // Schema whose field names are fully qualified (schema.table.column).
    schema: Arc<Schema>,
    // All row data for the table, possibly pre-filtered by constraints.
    batches: Vec<RecordBatch>,
    // Number of columns contributed by each participating table.
    column_counts: Vec<usize>,
    // Original indices of the participating tables.
    table_indices: Vec<usize>,
}
11422
11423fn plan_value_to_literal(value: &PlanValue) -> ExecutorResult<Literal> {
11424 match value {
11425 PlanValue::String(s) => Ok(Literal::String(s.clone())),
11426 PlanValue::Integer(i) => Ok(Literal::Int128(*i as i128)),
11427 PlanValue::Float(f) => Ok(Literal::Float64(*f)),
11428 PlanValue::Null => Ok(Literal::Null),
11429 PlanValue::Date32(d) => Ok(Literal::Date32(*d)),
11430 PlanValue::Decimal(d) => Ok(Literal::Decimal128(*d)),
11431 _ => Err(Error::Internal(format!(
11432 "unsupported plan value for literal conversion: {:?}",
11433 value
11434 ))),
11435 }
11436}
11437
/// Scan one table and materialize its rows for use in a cross product.
///
/// Output column names are qualified as `schema.table(or alias).column` so
/// the combined cross-product schema stays unambiguous. Equality / IN-list
/// constraints are pushed down into the scan filter where possible and then
/// re-applied exactly, batch by batch.
fn collect_table_data<P>(
    table_index: usize,
    table_ref: &llkv_plan::TableRef,
    table: &ExecutorTable<P>,
    constraints: &[ColumnConstraint],
) -> ExecutorResult<TableCrossProductData>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    if table.schema.columns.is_empty() {
        return Err(Error::InvalidArgumentError(format!(
            "table '{}' has no columns; cross products require at least one column",
            table_ref.qualified_name()
        )));
    }

    let mut projections = Vec::with_capacity(table.schema.columns.len());
    let mut fields = Vec::with_capacity(table.schema.columns.len());

    for column in &table.schema.columns {
        // Prefer the table alias (if any) in the qualified output name.
        let table_component = table_ref
            .alias
            .as_deref()
            .unwrap_or(table_ref.table.as_str());
        let qualified_name = format!("{}.{}.{}", table_ref.schema, table_component, column.name);
        projections.push(ScanProjection::from(StoreProjection::with_alias(
            LogicalFieldId::for_user(table.table_id(), column.field_id),
            qualified_name.clone(),
        )));
        fields.push(Field::new(
            qualified_name,
            column.data_type.clone(),
            column.nullable,
        ));
    }

    let schema = Arc::new(Schema::new(fields));

    // Anchor field for expressing a full-table-scan filter when no
    // constraint can be pushed down.
    let filter_field_id = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);

    // Translate pushable constraints into scan filter expressions.
    let mut filter_exprs = Vec::new();
    for constraint in constraints {
        match constraint {
            ColumnConstraint::Equality(lit) => {
                let col_idx = lit.column.column;
                if col_idx < table.schema.columns.len() {
                    let field_id = table.schema.columns[col_idx].field_id;
                    // Values with no literal form are skipped here; they are
                    // still enforced by the exact re-application below.
                    if let Ok(literal) = plan_value_to_literal(&lit.value) {
                        filter_exprs.push(LlkvExpr::Compare {
                            left: ScalarExpr::Column(field_id),
                            op: CompareOp::Eq,
                            right: ScalarExpr::Literal(literal),
                        });
                    }
                }
            }
            ColumnConstraint::InList(in_list) => {
                let col_idx = in_list.column.column;
                if col_idx < table.schema.columns.len() {
                    let field_id = table.schema.columns[col_idx].field_id;
                    let literals: Vec<Literal> = in_list
                        .values
                        .iter()
                        .filter_map(|v| plan_value_to_literal(v).ok())
                        .collect();

                    if !literals.is_empty() {
                        filter_exprs.push(LlkvExpr::InList {
                            expr: ScalarExpr::Column(field_id),
                            list: literals.into_iter().map(ScalarExpr::Literal).collect(),
                            negated: false,
                        });
                    }
                }
            }
        }
    }

    // Combine the pushed-down filters; no filters means a full scan.
    let filter_expr = if filter_exprs.is_empty() {
        crate::translation::expression::full_table_scan_filter(filter_field_id)
    } else if filter_exprs.len() == 1 {
        filter_exprs.pop().unwrap()
    } else {
        LlkvExpr::And(filter_exprs)
    };

    let mut raw_batches = Vec::new();
    table.storage().scan_stream(
        &projections,
        &filter_expr,
        ScanStreamOptions {
            include_nulls: true,
            ..ScanStreamOptions::default()
        },
        &mut |batch| {
            raw_batches.push(batch);
        },
    )?;

    // Rewrap each scan batch under the qualified-name schema (in parallel).
    let normalized_batches = raw_batches
        .par_iter()
        .map(|batch| {
            RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()).map_err(|err| {
                Error::Internal(format!(
                    "failed to align scan batch for table '{}': {}",
                    table_ref.qualified_name(),
                    err
                ))
            })
        })
        .collect::<ExecutorResult<Vec<RecordBatch>>>()?;

    let mut normalized_batches = normalized_batches;

    // Re-apply all constraints exactly; the push-down above may be partial.
    if !constraints.is_empty() {
        normalized_batches = apply_column_constraints_to_batches(normalized_batches, constraints)?;
    }

    Ok(TableCrossProductData {
        schema,
        batches: normalized_batches,
        column_counts: vec![table.schema.columns.len()],
        table_indices: vec![table_index],
    })
}
11564
11565fn apply_column_constraints_to_batches(
11566 batches: Vec<RecordBatch>,
11567 constraints: &[ColumnConstraint],
11568) -> ExecutorResult<Vec<RecordBatch>> {
11569 if batches.is_empty() {
11570 return Ok(batches);
11571 }
11572
11573 let mut filtered = batches;
11574 for constraint in constraints {
11575 match constraint {
11576 ColumnConstraint::Equality(lit) => {
11577 filtered = filter_batches_by_literal(filtered, lit.column.column, &lit.value)?;
11578 }
11579 ColumnConstraint::InList(in_list) => {
11580 filtered =
11581 filter_batches_by_in_list(filtered, in_list.column.column, &in_list.values)?;
11582 }
11583 }
11584 if filtered.is_empty() {
11585 break;
11586 }
11587 }
11588
11589 Ok(filtered)
11590}
11591
11592fn filter_batches_by_literal(
11593 batches: Vec<RecordBatch>,
11594 column_idx: usize,
11595 literal: &PlanValue,
11596) -> ExecutorResult<Vec<RecordBatch>> {
11597 let result: Vec<RecordBatch> = batches
11598 .par_iter()
11599 .map(|batch| -> ExecutorResult<Option<RecordBatch>> {
11600 if column_idx >= batch.num_columns() {
11601 return Err(Error::Internal(
11602 "literal constraint referenced invalid column index".into(),
11603 ));
11604 }
11605
11606 if batch.num_rows() == 0 {
11607 return Ok(Some(batch.clone()));
11608 }
11609
11610 let column = batch.column(column_idx);
11611 let mut keep_rows: Vec<u32> = Vec::with_capacity(batch.num_rows());
11612
11613 for row_idx in 0..batch.num_rows() {
11614 if array_value_equals_plan_value(column.as_ref(), row_idx, literal)? {
11615 keep_rows.push(row_idx as u32);
11616 }
11617 }
11618
11619 if keep_rows.len() == batch.num_rows() {
11620 return Ok(Some(batch.clone()));
11621 }
11622
11623 if keep_rows.is_empty() {
11624 return Ok(None);
11625 }
11626
11627 let indices = UInt32Array::from(keep_rows);
11628 let mut filtered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
11629 for col_idx in 0..batch.num_columns() {
11630 let filtered =
11631 take(batch.column(col_idx).as_ref(), &indices, None).map_err(|err| {
11632 Error::Internal(format!("failed to apply literal filter: {err}"))
11633 })?;
11634 filtered_columns.push(filtered);
11635 }
11636
11637 let filtered_batch =
11638 RecordBatch::try_new(batch.schema(), filtered_columns).map_err(|err| {
11639 Error::Internal(format!(
11640 "failed to rebuild batch after literal filter: {err}"
11641 ))
11642 })?;
11643 Ok(Some(filtered_batch))
11644 })
11645 .collect::<ExecutorResult<Vec<Option<RecordBatch>>>>()?
11646 .into_iter()
11647 .flatten()
11648 .collect();
11649
11650 Ok(result)
11651}
11652
/// Keep only rows whose value at `column_idx` appears in `values` (IN list).
///
/// An empty value list matches nothing, so every batch is dropped. Batches
/// are processed in parallel; fully-matching batches are passed through
/// unchanged and fully-non-matching batches are dropped without copying.
fn filter_batches_by_in_list(
    batches: Vec<RecordBatch>,
    column_idx: usize,
    values: &[PlanValue],
) -> ExecutorResult<Vec<RecordBatch>> {
    use arrow::array::*;
    use arrow::compute::or;

    if values.is_empty() {
        return Ok(Vec::new());
    }

    let result: Vec<RecordBatch> = batches
        .par_iter()
        .map(|batch| -> ExecutorResult<Option<RecordBatch>> {
            if column_idx >= batch.num_columns() {
                return Err(Error::Internal(
                    "IN list constraint referenced invalid column index".into(),
                ));
            }

            if batch.num_rows() == 0 {
                return Ok(Some(batch.clone()));
            }

            let column = batch.column(column_idx);

            // OR together one equality mask per candidate value, starting
            // from an all-false mask.
            let mut mask = BooleanArray::from(vec![false; batch.num_rows()]);

            for value in values {
                let comparison_mask = build_comparison_mask(column.as_ref(), value)?;
                mask = or(&mask, &comparison_mask).map_err(|err| {
                    Error::Internal(format!("failed to OR comparison masks: {err}"))
                })?;
            }

            let true_count = mask.true_count();
            // Everything matched: reuse the batch without copying.
            if true_count == batch.num_rows() {
                return Ok(Some(batch.clone()));
            }

            // Nothing matched: drop the batch entirely.
            if true_count == 0 {
                return Ok(None);
            }

            let filtered_batch = arrow::compute::filter_record_batch(batch, &mask)
                .map_err(|err| Error::Internal(format!("failed to apply IN list filter: {err}")))?;

            Ok(Some(filtered_batch))
        })
        .collect::<ExecutorResult<Vec<Option<RecordBatch>>>>()?
        .into_iter()
        .flatten()
        .collect();

    Ok(result)
}
11716
11717fn build_comparison_mask(column: &dyn Array, value: &PlanValue) -> ExecutorResult<BooleanArray> {
11719 use arrow::array::*;
11720 use arrow::datatypes::DataType;
11721
11722 match value {
11723 PlanValue::Null => {
11724 let mut builder = BooleanBuilder::with_capacity(column.len());
11726 for i in 0..column.len() {
11727 builder.append_value(column.is_null(i));
11728 }
11729 Ok(builder.finish())
11730 }
11731 PlanValue::Integer(val) => {
11732 let mut builder = BooleanBuilder::with_capacity(column.len());
11733 match column.data_type() {
11734 DataType::Int8 => {
11735 let arr = column
11736 .as_any()
11737 .downcast_ref::<Int8Array>()
11738 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
11739 let target = *val as i8;
11740 for i in 0..arr.len() {
11741 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11742 }
11743 }
11744 DataType::Int16 => {
11745 let arr = column
11746 .as_any()
11747 .downcast_ref::<Int16Array>()
11748 .ok_or_else(|| {
11749 Error::Internal("failed to downcast to Int16Array".into())
11750 })?;
11751 let target = *val as i16;
11752 for i in 0..arr.len() {
11753 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11754 }
11755 }
11756 DataType::Int32 => {
11757 let arr = column
11758 .as_any()
11759 .downcast_ref::<Int32Array>()
11760 .ok_or_else(|| {
11761 Error::Internal("failed to downcast to Int32Array".into())
11762 })?;
11763 let target = *val as i32;
11764 for i in 0..arr.len() {
11765 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11766 }
11767 }
11768 DataType::Int64 => {
11769 let arr = column
11770 .as_any()
11771 .downcast_ref::<Int64Array>()
11772 .ok_or_else(|| {
11773 Error::Internal("failed to downcast to Int64Array".into())
11774 })?;
11775 for i in 0..arr.len() {
11776 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
11777 }
11778 }
11779 DataType::UInt8 => {
11780 let arr = column
11781 .as_any()
11782 .downcast_ref::<UInt8Array>()
11783 .ok_or_else(|| {
11784 Error::Internal("failed to downcast to UInt8Array".into())
11785 })?;
11786 let target = *val as u8;
11787 for i in 0..arr.len() {
11788 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11789 }
11790 }
11791 DataType::UInt16 => {
11792 let arr = column
11793 .as_any()
11794 .downcast_ref::<UInt16Array>()
11795 .ok_or_else(|| {
11796 Error::Internal("failed to downcast to UInt16Array".into())
11797 })?;
11798 let target = *val as u16;
11799 for i in 0..arr.len() {
11800 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11801 }
11802 }
11803 DataType::UInt32 => {
11804 let arr = column
11805 .as_any()
11806 .downcast_ref::<UInt32Array>()
11807 .ok_or_else(|| {
11808 Error::Internal("failed to downcast to UInt32Array".into())
11809 })?;
11810 let target = *val as u32;
11811 for i in 0..arr.len() {
11812 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11813 }
11814 }
11815 DataType::UInt64 => {
11816 let arr = column
11817 .as_any()
11818 .downcast_ref::<UInt64Array>()
11819 .ok_or_else(|| {
11820 Error::Internal("failed to downcast to UInt64Array".into())
11821 })?;
11822 let target = *val as u64;
11823 for i in 0..arr.len() {
11824 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11825 }
11826 }
11827 _ => {
11828 return Err(Error::Internal(format!(
11829 "unsupported integer type for IN list: {:?}",
11830 column.data_type()
11831 )));
11832 }
11833 }
11834 Ok(builder.finish())
11835 }
11836 PlanValue::Float(val) => {
11837 let mut builder = BooleanBuilder::with_capacity(column.len());
11838 match column.data_type() {
11839 DataType::Float32 => {
11840 let arr = column
11841 .as_any()
11842 .downcast_ref::<Float32Array>()
11843 .ok_or_else(|| {
11844 Error::Internal("failed to downcast to Float32Array".into())
11845 })?;
11846 let target = *val as f32;
11847 for i in 0..arr.len() {
11848 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11849 }
11850 }
11851 DataType::Float64 => {
11852 let arr = column
11853 .as_any()
11854 .downcast_ref::<Float64Array>()
11855 .ok_or_else(|| {
11856 Error::Internal("failed to downcast to Float64Array".into())
11857 })?;
11858 for i in 0..arr.len() {
11859 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
11860 }
11861 }
11862 _ => {
11863 return Err(Error::Internal(format!(
11864 "unsupported float type for IN list: {:?}",
11865 column.data_type()
11866 )));
11867 }
11868 }
11869 Ok(builder.finish())
11870 }
11871 PlanValue::Decimal(expected) => match column.data_type() {
11872 DataType::Decimal128(precision, scale) => {
11873 let arr = column
11874 .as_any()
11875 .downcast_ref::<Decimal128Array>()
11876 .ok_or_else(|| {
11877 Error::Internal("failed to downcast to Decimal128Array".into())
11878 })?;
11879 let expected_aligned = align_decimal_to_scale(*expected, *precision, *scale)
11880 .map_err(|err| {
11881 Error::InvalidArgumentError(format!(
11882 "decimal literal {expected} incompatible with DECIMAL({}, {}): {err}",
11883 precision, scale
11884 ))
11885 })?;
11886 let mut builder = BooleanBuilder::with_capacity(arr.len());
11887 for i in 0..arr.len() {
11888 if arr.is_null(i) {
11889 builder.append_value(false);
11890 } else {
11891 let actual = DecimalValue::new(arr.value(i), *scale).map_err(|err| {
11892 Error::InvalidArgumentError(format!(
11893 "invalid decimal value stored in column: {err}"
11894 ))
11895 })?;
11896 builder.append_value(actual.raw_value() == expected_aligned.raw_value());
11897 }
11898 }
11899 Ok(builder.finish())
11900 }
11901 DataType::Int8
11902 | DataType::Int16
11903 | DataType::Int32
11904 | DataType::Int64
11905 | DataType::UInt8
11906 | DataType::UInt16
11907 | DataType::UInt32
11908 | DataType::UInt64
11909 | DataType::Boolean => {
11910 if let Some(int_value) = decimal_exact_i64(*expected) {
11911 return build_comparison_mask(column, &PlanValue::Integer(int_value));
11912 }
11913 Ok(BooleanArray::from(vec![false; column.len()]))
11914 }
11915 DataType::Float32 | DataType::Float64 => {
11916 build_comparison_mask(column, &PlanValue::Float(expected.to_f64()))
11917 }
11918 _ => Err(Error::Internal(format!(
11919 "unsupported decimal type for IN list: {:?}",
11920 column.data_type()
11921 ))),
11922 },
11923 PlanValue::String(val) => {
11924 let mut builder = BooleanBuilder::with_capacity(column.len());
11925 let arr = column
11926 .as_any()
11927 .downcast_ref::<StringArray>()
11928 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
11929 for i in 0..arr.len() {
11930 builder.append_value(!arr.is_null(i) && arr.value(i) == val.as_str());
11931 }
11932 Ok(builder.finish())
11933 }
11934 PlanValue::Date32(days) => {
11935 let mut builder = BooleanBuilder::with_capacity(column.len());
11936 match column.data_type() {
11937 DataType::Date32 => {
11938 let arr = column
11939 .as_any()
11940 .downcast_ref::<Date32Array>()
11941 .ok_or_else(|| {
11942 Error::Internal("failed to downcast to Date32Array".into())
11943 })?;
11944 for i in 0..arr.len() {
11945 builder.append_value(!arr.is_null(i) && arr.value(i) == *days);
11946 }
11947 }
11948 _ => {
11949 return Err(Error::Internal(format!(
11950 "unsupported DATE type for IN list: {:?}",
11951 column.data_type()
11952 )));
11953 }
11954 }
11955 Ok(builder.finish())
11956 }
11957 PlanValue::Interval(interval) => {
11958 let mut builder = BooleanBuilder::with_capacity(column.len());
11959 match column.data_type() {
11960 DataType::Interval(IntervalUnit::MonthDayNano) => {
11961 let arr = column
11962 .as_any()
11963 .downcast_ref::<IntervalMonthDayNanoArray>()
11964 .ok_or_else(|| {
11965 Error::Internal(
11966 "failed to downcast to IntervalMonthDayNanoArray".into(),
11967 )
11968 })?;
11969 let expected = *interval;
11970 for i in 0..arr.len() {
11971 if arr.is_null(i) {
11972 builder.append_value(false);
11973 } else {
11974 let candidate = interval_value_from_arrow(arr.value(i));
11975 let matches = compare_interval_values(expected, candidate)
11976 == std::cmp::Ordering::Equal;
11977 builder.append_value(matches);
11978 }
11979 }
11980 }
11981 _ => {
11982 return Err(Error::Internal(format!(
11983 "unsupported INTERVAL type for IN list: {:?}",
11984 column.data_type()
11985 )));
11986 }
11987 }
11988 Ok(builder.finish())
11989 }
11990 PlanValue::Struct(_) => Err(Error::Internal(
11991 "struct comparison in IN list not supported".into(),
11992 )),
11993 }
11994}
11995
/// Compare a single array cell at `row_idx` against a literal `PlanValue`.
///
/// Returns `Ok(true)` only when the cell is non-NULL and equal to `literal`;
/// for `PlanValue::Null` the test is instead "is the cell NULL". Narrow
/// integer columns are widened to `i64` for comparison, decimal literals are
/// aligned to the column's declared scale, and type/literal combinations with
/// no defined comparison yield `Err(InvalidArgumentError)`.
fn array_value_equals_plan_value(
    array: &dyn Array,
    row_idx: usize,
    literal: &PlanValue,
) -> ExecutorResult<bool> {
    use arrow::array::*;
    use arrow::datatypes::DataType;

    match literal {
        PlanValue::Null => Ok(array.is_null(row_idx)),
        // Decimal literal: align scales for an exact raw comparison against a
        // DECIMAL column, or delegate to integer/float comparison otherwise.
        PlanValue::Decimal(expected) => match array.data_type() {
            DataType::Decimal128(precision, scale) => {
                if array.is_null(row_idx) {
                    return Ok(false);
                }
                let arr = array
                    .as_any()
                    .downcast_ref::<Decimal128Array>()
                    .ok_or_else(|| {
                        Error::Internal("failed to downcast to Decimal128Array".into())
                    })?;
                let actual = DecimalValue::new(arr.value(row_idx), *scale).map_err(|err| {
                    Error::InvalidArgumentError(format!(
                        "invalid decimal value retrieved from column: {err}"
                    ))
                })?;
                // Rescale the literal to the column's (precision, scale) so the
                // two raw i128 values are directly comparable.
                let expected_aligned = align_decimal_to_scale(*expected, *precision, *scale)
                    .map_err(|err| {
                        Error::InvalidArgumentError(format!(
                            "failed to align decimal literal for comparison: {err}"
                        ))
                    })?;
                Ok(actual.raw_value() == expected_aligned.raw_value())
            }
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64 => {
                if array.is_null(row_idx) {
                    return Ok(false);
                }
                // Only a decimal with an exact i64 representation can equal an
                // integer column value; a fractional literal never matches.
                if let Some(int_value) = decimal_exact_i64(*expected) {
                    array_value_equals_plan_value(array, row_idx, &PlanValue::Integer(int_value))
                } else {
                    Ok(false)
                }
            }
            DataType::Float32 | DataType::Float64 => {
                if array.is_null(row_idx) {
                    return Ok(false);
                }
                array_value_equals_plan_value(array, row_idx, &PlanValue::Float(expected.to_f64()))
            }
            DataType::Boolean => {
                if array.is_null(row_idx) {
                    return Ok(false);
                }
                if let Some(int_value) = decimal_exact_i64(*expected) {
                    array_value_equals_plan_value(array, row_idx, &PlanValue::Integer(int_value))
                } else {
                    Ok(false)
                }
            }
            _ => Err(Error::InvalidArgumentError(format!(
                "decimal literal comparison not supported for {:?}",
                array.data_type()
            ))),
        },
        // Integer literal: widen the cell to i64 and compare. Unsigned arms
        // are guarded on `*expected >= 0`; NOTE(review): a negative literal
        // against an unsigned column therefore falls through to the error arm
        // rather than returning Ok(false) — confirm this is intended.
        PlanValue::Integer(expected) => match array.data_type() {
            DataType::Int8 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<Int8Array>()
                    .expect("int8 array")
                    .value(row_idx) as i64
                    == *expected),
            DataType::Int16 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<Int16Array>()
                    .expect("int16 array")
                    .value(row_idx) as i64
                    == *expected),
            DataType::Int32 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("int32 array")
                    .value(row_idx) as i64
                    == *expected),
            DataType::Int64 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<Int64Array>()
                    .expect("int64 array")
                    .value(row_idx)
                    == *expected),
            DataType::UInt8 if *expected >= 0 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<UInt8Array>()
                    .expect("uint8 array")
                    .value(row_idx) as i64
                    == *expected),
            DataType::UInt16 if *expected >= 0 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<UInt16Array>()
                    .expect("uint16 array")
                    .value(row_idx) as i64
                    == *expected),
            DataType::UInt32 if *expected >= 0 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .expect("uint32 array")
                    .value(row_idx) as i64
                    == *expected),
            DataType::UInt64 if *expected >= 0 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<UInt64Array>()
                    .expect("uint64 array")
                    .value(row_idx)
                    == *expected as u64),
            // Booleans compare equal to the integers 0 and 1 only.
            DataType::Boolean => {
                if array.is_null(row_idx) {
                    Ok(false)
                } else if *expected == 0 || *expected == 1 {
                    let value = array
                        .as_any()
                        .downcast_ref::<BooleanArray>()
                        .expect("bool array")
                        .value(row_idx);
                    Ok(value == (*expected == 1))
                } else {
                    Ok(false)
                }
            }
            _ => Err(Error::InvalidArgumentError(format!(
                "literal integer comparison not supported for {:?}",
                array.data_type()
            ))),
        },
        // Float literal: the `(a - b).abs() == 0.0` form is equivalent to
        // `a == b` (NaN still compares unequal; +0.0 and -0.0 compare equal).
        PlanValue::Float(expected) => match array.data_type() {
            DataType::Float32 => Ok(!array.is_null(row_idx)
                && (array
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .expect("float32 array")
                    .value(row_idx) as f64
                    - *expected)
                    .abs()
                    .eq(&0.0)),
            DataType::Float64 => Ok(!array.is_null(row_idx)
                && (array
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .expect("float64 array")
                    .value(row_idx)
                    - *expected)
                    .abs()
                    .eq(&0.0)),
            _ => Err(Error::InvalidArgumentError(format!(
                "literal float comparison not supported for {:?}",
                array.data_type()
            ))),
        },
        PlanValue::String(expected) => match array.data_type() {
            DataType::Utf8 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .expect("string array")
                    .value(row_idx)
                    == expected),
            DataType::LargeUtf8 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<LargeStringArray>()
                    .expect("large string array")
                    .value(row_idx)
                    == expected),
            _ => Err(Error::InvalidArgumentError(format!(
                "literal string comparison not supported for {:?}",
                array.data_type()
            ))),
        },
        PlanValue::Date32(expected) => match array.data_type() {
            DataType::Date32 => Ok(!array.is_null(row_idx)
                && array
                    .as_any()
                    .downcast_ref::<Date32Array>()
                    .expect("date32 array")
                    .value(row_idx)
                    == *expected),
            _ => Err(Error::InvalidArgumentError(format!(
                "literal date comparison not supported for {:?}",
                array.data_type()
            ))),
        },
        // Intervals use the dedicated ordering helper rather than raw struct
        // equality so equivalent month/day/nano combinations compare equal.
        PlanValue::Interval(expected) => {
            match array.data_type() {
                DataType::Interval(IntervalUnit::MonthDayNano) => {
                    if array.is_null(row_idx) {
                        Ok(false)
                    } else {
                        let value = array
                            .as_any()
                            .downcast_ref::<IntervalMonthDayNanoArray>()
                            .expect("interval array")
                            .value(row_idx);
                        let arrow_value = interval_value_from_arrow(value);
                        Ok(compare_interval_values(*expected, arrow_value)
                            == std::cmp::Ordering::Equal)
                    }
                }
                _ => Err(Error::InvalidArgumentError(format!(
                    "literal interval comparison not supported for {:?}",
                    array.data_type()
                ))),
            }
        }
        PlanValue::Struct(_) => Err(Error::InvalidArgumentError(
            "struct literals are not supported in join filters".into(),
        )),
    }
}
12228
12229fn hash_join_table_batches(
12230 left: TableCrossProductData,
12231 right: TableCrossProductData,
12232 join_keys: &[(usize, usize)],
12233 join_type: llkv_join::JoinType,
12234) -> ExecutorResult<TableCrossProductData> {
12235 let TableCrossProductData {
12236 schema: left_schema,
12237 batches: left_batches,
12238 column_counts: left_counts,
12239 table_indices: left_tables,
12240 } = left;
12241
12242 let TableCrossProductData {
12243 schema: right_schema,
12244 batches: right_batches,
12245 column_counts: right_counts,
12246 table_indices: right_tables,
12247 } = right;
12248
12249 let combined_fields: Vec<Field> = left_schema
12250 .fields()
12251 .iter()
12252 .chain(right_schema.fields().iter())
12253 .map(|field| field.as_ref().clone())
12254 .collect();
12255
12256 let combined_schema = Arc::new(Schema::new(combined_fields));
12257
12258 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
12259 column_counts.extend(left_counts.iter());
12260 column_counts.extend(right_counts.iter());
12261
12262 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
12263 table_indices.extend(left_tables.iter().copied());
12264 table_indices.extend(right_tables.iter().copied());
12265
12266 if left_batches.is_empty() {
12268 return Ok(TableCrossProductData {
12269 schema: combined_schema,
12270 batches: Vec::new(),
12271 column_counts,
12272 table_indices,
12273 });
12274 }
12275
12276 if right_batches.is_empty() {
12277 if join_type == llkv_join::JoinType::Left {
12279 let total_left_rows: usize = left_batches.iter().map(|b| b.num_rows()).sum();
12280 let mut left_arrays = Vec::new();
12281 for field in left_schema.fields() {
12282 let column_idx = left_schema.index_of(field.name()).map_err(|e| {
12283 Error::Internal(format!("failed to find field {}: {}", field.name(), e))
12284 })?;
12285 let arrays: Vec<ArrayRef> = left_batches
12286 .iter()
12287 .map(|batch| batch.column(column_idx).clone())
12288 .collect();
12289 let concatenated =
12290 arrow::compute::concat(&arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
12291 .map_err(|e| {
12292 Error::Internal(format!("failed to concat left arrays: {}", e))
12293 })?;
12294 left_arrays.push(concatenated);
12295 }
12296
12297 for field in right_schema.fields() {
12299 let null_array = arrow::array::new_null_array(field.data_type(), total_left_rows);
12300 left_arrays.push(null_array);
12301 }
12302
12303 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), left_arrays)
12304 .map_err(|err| {
12305 Error::Internal(format!(
12306 "failed to create LEFT JOIN batch with NULL right: {err}"
12307 ))
12308 })?;
12309
12310 return Ok(TableCrossProductData {
12311 schema: combined_schema,
12312 batches: vec![joined_batch],
12313 column_counts,
12314 table_indices,
12315 });
12316 } else {
12317 return Ok(TableCrossProductData {
12319 schema: combined_schema,
12320 batches: Vec::new(),
12321 column_counts,
12322 table_indices,
12323 });
12324 }
12325 }
12326
12327 match join_type {
12328 llkv_join::JoinType::Inner => {
12329 let (left_matches, right_matches) =
12330 build_join_match_indices(&left_batches, &right_batches, join_keys)?;
12331
12332 if left_matches.is_empty() {
12333 return Ok(TableCrossProductData {
12334 schema: combined_schema,
12335 batches: Vec::new(),
12336 column_counts,
12337 table_indices,
12338 });
12339 }
12340
12341 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
12342 let right_arrays = gather_indices_from_batches(&right_batches, &right_matches)?;
12343
12344 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
12345 combined_columns.extend(left_arrays);
12346 combined_columns.extend(right_arrays);
12347
12348 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
12349 .map_err(|err| {
12350 Error::Internal(format!("failed to materialize INNER JOIN batch: {err}"))
12351 })?;
12352
12353 Ok(TableCrossProductData {
12354 schema: combined_schema,
12355 batches: vec![joined_batch],
12356 column_counts,
12357 table_indices,
12358 })
12359 }
12360 llkv_join::JoinType::Left => {
12361 let (left_matches, right_optional_matches) =
12362 build_left_join_match_indices(&left_batches, &right_batches, join_keys)?;
12363
12364 if left_matches.is_empty() {
12365 return Ok(TableCrossProductData {
12367 schema: combined_schema,
12368 batches: Vec::new(),
12369 column_counts,
12370 table_indices,
12371 });
12372 }
12373
12374 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
12375 let right_arrays = llkv_column_map::gather::gather_optional_indices_from_batches(
12377 &right_batches,
12378 &right_optional_matches,
12379 )?;
12380
12381 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
12382 combined_columns.extend(left_arrays);
12383 combined_columns.extend(right_arrays);
12384
12385 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
12386 .map_err(|err| {
12387 Error::Internal(format!("failed to materialize LEFT JOIN batch: {err}"))
12388 })?;
12389
12390 Ok(TableCrossProductData {
12391 schema: combined_schema,
12392 batches: vec![joined_batch],
12393 column_counts,
12394 table_indices,
12395 })
12396 }
12397 _ => Err(Error::Internal(format!(
12399 "join type {:?} not supported in hash_join_table_batches; use llkv-join",
12400 join_type
12401 ))),
12402 }
12403}
12404
/// Positions of matched rows as `(batch_index, row_index)` pairs.
type JoinMatchIndices = Vec<(usize, usize)>;
/// Serialized join-key bytes mapped to every right-side `(batch, row)`
/// position carrying that key.
type JoinHashTable = FxHashMap<Vec<u8>, Vec<(usize, usize)>>;
/// INNER-join result: left and right position lists that pair up
/// index-for-index.
type JoinMatchPairs = (JoinMatchIndices, JoinMatchIndices);
/// Per-output-row right-side position; `None` marks a left row with no match.
type OptionalJoinMatches = Vec<Option<(usize, usize)>>;
/// LEFT-join result: left positions paired with optional right positions.
type LeftJoinMatchPairs = (JoinMatchIndices, OptionalJoinMatches);
12415
12416fn normalize_join_column(array: &ArrayRef) -> ExecutorResult<ArrayRef> {
12417 match array.data_type() {
12418 DataType::Boolean
12419 | DataType::Int8
12420 | DataType::Int16
12421 | DataType::Int32
12422 | DataType::UInt8
12423 | DataType::UInt16
12424 | DataType::UInt32
12425 | DataType::UInt64 => cast(array, &DataType::Int64)
12426 .map_err(|e| Error::Internal(format!("failed to cast integer/boolean to Int64: {e}"))),
12427 DataType::Float32 => cast(array, &DataType::Float64)
12428 .map_err(|e| Error::Internal(format!("failed to cast Float32 to Float64: {e}"))),
12429 DataType::Utf8 | DataType::LargeUtf8 => cast(array, &DataType::LargeUtf8)
12430 .map_err(|e| Error::Internal(format!("failed to cast Utf8 to LargeUtf8: {e}"))),
12431 DataType::Dictionary(_, value_type) => {
12432 let unpacked = cast(array, value_type)
12433 .map_err(|e| Error::Internal(format!("failed to unpack dictionary: {e}")))?;
12434 normalize_join_column(&unpacked)
12435 }
12436 _ => Ok(array.clone()),
12437 }
12438}
12439
12440fn build_join_match_indices(
12470 left_batches: &[RecordBatch],
12471 right_batches: &[RecordBatch],
12472 join_keys: &[(usize, usize)],
12473) -> ExecutorResult<JoinMatchPairs> {
12474 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
12475
12476 let hash_table: JoinHashTable = with_thread_pool(|| {
12479 let local_tables: Vec<ExecutorResult<JoinHashTable>> = right_batches
12480 .par_iter()
12481 .enumerate()
12482 .map(|(batch_idx, batch)| {
12483 let mut local_table: JoinHashTable = FxHashMap::default();
12484
12485 let columns: Vec<ArrayRef> = right_key_indices
12486 .iter()
12487 .map(|&idx| normalize_join_column(batch.column(idx)))
12488 .collect::<ExecutorResult<Vec<_>>>()?;
12489
12490 let sort_fields: Vec<SortField> = columns
12491 .iter()
12492 .map(|c| SortField::new(c.data_type().clone()))
12493 .collect();
12494
12495 let converter = RowConverter::new(sort_fields)
12496 .map_err(|e| Error::Internal(format!("failed to create RowConverter: {e}")))?;
12497 let rows = converter.convert_columns(&columns).map_err(|e| {
12498 Error::Internal(format!("failed to convert columns to rows: {e}"))
12499 })?;
12500
12501 for (row_idx, row) in rows.iter().enumerate() {
12502 if columns.iter().any(|c| c.is_null(row_idx)) {
12504 continue;
12505 }
12506
12507 local_table
12508 .entry(row.as_ref().to_vec())
12509 .or_default()
12510 .push((batch_idx, row_idx));
12511 }
12512
12513 Ok(local_table)
12514 })
12515 .collect();
12516
12517 let mut merged_table: JoinHashTable = FxHashMap::default();
12519 for local_table_res in local_tables {
12520 if let Ok(local_table) = local_table_res {
12521 for (key, mut positions) in local_table {
12522 merged_table.entry(key).or_default().append(&mut positions);
12523 }
12524 } else {
12525 tracing::error!("failed to build hash table for batch");
12526 }
12527 }
12528
12529 merged_table
12530 });
12531
12532 if hash_table.is_empty() {
12533 return Ok((Vec::new(), Vec::new()));
12534 }
12535
12536 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
12537
12538 let matches: Vec<ExecutorResult<JoinMatchPairs>> = with_thread_pool(|| {
12541 left_batches
12542 .par_iter()
12543 .enumerate()
12544 .map(|(batch_idx, batch)| {
12545 let mut local_left_matches: JoinMatchIndices = Vec::new();
12546 let mut local_right_matches: JoinMatchIndices = Vec::new();
12547
12548 let columns: Vec<ArrayRef> = left_key_indices
12549 .iter()
12550 .map(|&idx| normalize_join_column(batch.column(idx)))
12551 .collect::<ExecutorResult<Vec<_>>>()?;
12552
12553 let sort_fields: Vec<SortField> = columns
12554 .iter()
12555 .map(|c| SortField::new(c.data_type().clone()))
12556 .collect();
12557
12558 let converter = RowConverter::new(sort_fields)
12559 .map_err(|e| Error::Internal(format!("failed to create RowConverter: {e}")))?;
12560 let rows = converter.convert_columns(&columns).map_err(|e| {
12561 Error::Internal(format!("failed to convert columns to rows: {e}"))
12562 })?;
12563
12564 for (row_idx, row) in rows.iter().enumerate() {
12565 if columns.iter().any(|c| c.is_null(row_idx)) {
12566 continue;
12567 }
12568
12569 if let Some(positions) = hash_table.get(row.as_ref()) {
12570 for &(r_batch_idx, r_row_idx) in positions {
12571 local_left_matches.push((batch_idx, row_idx));
12572 local_right_matches.push((r_batch_idx, r_row_idx));
12573 }
12574 }
12575 }
12576
12577 Ok((local_left_matches, local_right_matches))
12578 })
12579 .collect()
12580 });
12581
12582 let mut left_matches: JoinMatchIndices = Vec::new();
12584 let mut right_matches: JoinMatchIndices = Vec::new();
12585 for match_res in matches {
12586 let (mut left, mut right) = match_res?;
12587 left_matches.append(&mut left);
12588 right_matches.append(&mut right);
12589 }
12590
12591 Ok((left_matches, right_matches))
12592}
12593
/// Build match indices for a LEFT OUTER hash join.
///
/// Returns two parallel vectors with one entry per output row: the left
/// `(batch, row)` position and an optional right position (`None` gathers as
/// NULLs for the right-hand columns). Every left row appears at least once;
/// a left row with several right matches appears once per match.
///
/// NOTE(review): unlike `build_join_match_indices`, keys here are raw
/// little-endian bytes from `build_join_key` with no type normalization, so
/// differently-typed key columns (e.g. Int32 vs Int64) never match — confirm
/// callers guarantee identical key types on both sides.
fn build_left_join_match_indices(
    left_batches: &[RecordBatch],
    right_batches: &[RecordBatch],
    join_keys: &[(usize, usize)],
) -> ExecutorResult<LeftJoinMatchPairs> {
    let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();

    // Build phase: hash every right row by its serialized key bytes, one
    // local table per batch in parallel, then merge.
    let hash_table: JoinHashTable = with_thread_pool(|| {
        let local_tables: Vec<JoinHashTable> = right_batches
            .par_iter()
            .enumerate()
            .map(|(batch_idx, batch)| {
                let mut local_table: JoinHashTable = FxHashMap::default();
                let mut key_buffer: Vec<u8> = Vec::new();

                for row_idx in 0..batch.num_rows() {
                    key_buffer.clear();
                    match build_join_key(batch, &right_key_indices, row_idx, &mut key_buffer) {
                        Ok(true) => {
                            local_table
                                .entry(key_buffer.clone())
                                .or_default()
                                .push((batch_idx, row_idx));
                        }
                        // Ok(false): a NULL key column — this row can never match.
                        Ok(false) => continue,
                        // NOTE(review): serialization errors (unsupported key
                        // type) are swallowed, silently excluding the row
                        // instead of failing the query — confirm intended.
                        Err(_) => continue,
                    }
                }

                local_table
            })
            .collect();

        let mut merged_table: JoinHashTable = FxHashMap::default();
        for local_table in local_tables {
            for (key, mut positions) in local_table {
                merged_table.entry(key).or_default().append(&mut positions);
            }
        }

        merged_table
    });

    let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();

    // Probe phase: each left row emits one entry per right match, or a single
    // unmatched (`None`) entry.
    let matches: Vec<LeftJoinMatchPairs> = with_thread_pool(|| {
        left_batches
            .par_iter()
            .enumerate()
            .map(|(batch_idx, batch)| {
                let mut local_left_matches: JoinMatchIndices = Vec::new();
                let mut local_right_optional: Vec<Option<(usize, usize)>> = Vec::new();
                let mut key_buffer: Vec<u8> = Vec::new();

                for row_idx in 0..batch.num_rows() {
                    key_buffer.clear();
                    match build_join_key(batch, &left_key_indices, row_idx, &mut key_buffer) {
                        Ok(true) => {
                            if let Some(entries) = hash_table.get(&key_buffer) {
                                for &(r_batch, r_row) in entries {
                                    local_left_matches.push((batch_idx, row_idx));
                                    local_right_optional.push(Some((r_batch, r_row)));
                                }
                            } else {
                                local_left_matches.push((batch_idx, row_idx));
                                local_right_optional.push(None);
                            }
                        }
                        // NULL key: keep the left row, padded with NULLs.
                        Ok(false) => {
                            local_left_matches.push((batch_idx, row_idx));
                            local_right_optional.push(None);
                        }
                        // NOTE(review): key errors also degrade to "no match"
                        // rather than propagating — confirm intended.
                        Err(_) => {
                            local_left_matches.push((batch_idx, row_idx));
                            local_right_optional.push(None);
                        }
                    }
                }

                (local_left_matches, local_right_optional)
            })
            .collect()
    });

    // Flatten per-batch results, preserving batch order.
    let mut left_matches: JoinMatchIndices = Vec::new();
    let mut right_optional: Vec<Option<(usize, usize)>> = Vec::new();
    for (mut left, mut right) in matches {
        left_matches.append(&mut left);
        right_optional.append(&mut right);
    }

    Ok((left_matches, right_optional))
}
12704
12705fn build_join_key(
12706 batch: &RecordBatch,
12707 column_indices: &[usize],
12708 row_idx: usize,
12709 buffer: &mut Vec<u8>,
12710) -> ExecutorResult<bool> {
12711 buffer.clear();
12712
12713 for &col_idx in column_indices {
12714 let array = batch.column(col_idx);
12715 if array.is_null(row_idx) {
12716 return Ok(false);
12717 }
12718 append_array_value_to_key(array.as_ref(), row_idx, buffer)?;
12719 }
12720
12721 Ok(true)
12722}
12723
12724fn append_array_value_to_key(
12725 array: &dyn Array,
12726 row_idx: usize,
12727 buffer: &mut Vec<u8>,
12728) -> ExecutorResult<()> {
12729 use arrow::array::*;
12730 use arrow::datatypes::DataType;
12731
12732 match array.data_type() {
12733 DataType::Int8 => buffer.extend_from_slice(
12734 &array
12735 .as_any()
12736 .downcast_ref::<Int8Array>()
12737 .expect("int8 array")
12738 .value(row_idx)
12739 .to_le_bytes(),
12740 ),
12741 DataType::Int16 => buffer.extend_from_slice(
12742 &array
12743 .as_any()
12744 .downcast_ref::<Int16Array>()
12745 .expect("int16 array")
12746 .value(row_idx)
12747 .to_le_bytes(),
12748 ),
12749 DataType::Int32 => buffer.extend_from_slice(
12750 &array
12751 .as_any()
12752 .downcast_ref::<Int32Array>()
12753 .expect("int32 array")
12754 .value(row_idx)
12755 .to_le_bytes(),
12756 ),
12757 DataType::Int64 => buffer.extend_from_slice(
12758 &array
12759 .as_any()
12760 .downcast_ref::<Int64Array>()
12761 .expect("int64 array")
12762 .value(row_idx)
12763 .to_le_bytes(),
12764 ),
12765 DataType::UInt8 => buffer.extend_from_slice(
12766 &array
12767 .as_any()
12768 .downcast_ref::<UInt8Array>()
12769 .expect("uint8 array")
12770 .value(row_idx)
12771 .to_le_bytes(),
12772 ),
12773 DataType::UInt16 => buffer.extend_from_slice(
12774 &array
12775 .as_any()
12776 .downcast_ref::<UInt16Array>()
12777 .expect("uint16 array")
12778 .value(row_idx)
12779 .to_le_bytes(),
12780 ),
12781 DataType::UInt32 => buffer.extend_from_slice(
12782 &array
12783 .as_any()
12784 .downcast_ref::<UInt32Array>()
12785 .expect("uint32 array")
12786 .value(row_idx)
12787 .to_le_bytes(),
12788 ),
12789 DataType::UInt64 => buffer.extend_from_slice(
12790 &array
12791 .as_any()
12792 .downcast_ref::<UInt64Array>()
12793 .expect("uint64 array")
12794 .value(row_idx)
12795 .to_le_bytes(),
12796 ),
12797 DataType::Float32 => buffer.extend_from_slice(
12798 &array
12799 .as_any()
12800 .downcast_ref::<Float32Array>()
12801 .expect("float32 array")
12802 .value(row_idx)
12803 .to_le_bytes(),
12804 ),
12805 DataType::Float64 => buffer.extend_from_slice(
12806 &array
12807 .as_any()
12808 .downcast_ref::<Float64Array>()
12809 .expect("float64 array")
12810 .value(row_idx)
12811 .to_le_bytes(),
12812 ),
12813 DataType::Boolean => buffer.push(
12814 array
12815 .as_any()
12816 .downcast_ref::<BooleanArray>()
12817 .expect("bool array")
12818 .value(row_idx) as u8,
12819 ),
12820 DataType::Utf8 => {
12821 let value = array
12822 .as_any()
12823 .downcast_ref::<StringArray>()
12824 .expect("utf8 array")
12825 .value(row_idx);
12826 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
12827 buffer.extend_from_slice(value.as_bytes());
12828 }
12829 DataType::LargeUtf8 => {
12830 let value = array
12831 .as_any()
12832 .downcast_ref::<LargeStringArray>()
12833 .expect("large utf8 array")
12834 .value(row_idx);
12835 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
12836 buffer.extend_from_slice(value.as_bytes());
12837 }
12838 DataType::Binary => {
12839 let value = array
12840 .as_any()
12841 .downcast_ref::<BinaryArray>()
12842 .expect("binary array")
12843 .value(row_idx);
12844 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
12845 buffer.extend_from_slice(value);
12846 }
12847 other => {
12848 return Err(Error::InvalidArgumentError(format!(
12849 "hash join does not support join key type {:?}",
12850 other
12851 )));
12852 }
12853 }
12854
12855 Ok(())
12856}
12857
12858fn table_has_join_with_used(
12859 candidate: usize,
12860 used_tables: &FxHashSet<usize>,
12861 equalities: &[ColumnEquality],
12862) -> bool {
12863 equalities.iter().any(|equality| {
12864 (equality.left.table == candidate && used_tables.contains(&equality.right.table))
12865 || (equality.right.table == candidate && used_tables.contains(&equality.left.table))
12866 })
12867}
12868
12869fn gather_join_keys(
12870 left: &TableCrossProductData,
12871 right: &TableCrossProductData,
12872 used_tables: &FxHashSet<usize>,
12873 right_table_index: usize,
12874 equalities: &[ColumnEquality],
12875) -> ExecutorResult<Vec<(usize, usize)>> {
12876 let mut keys = Vec::new();
12877
12878 for equality in equalities {
12879 if equality.left.table == right_table_index && used_tables.contains(&equality.right.table) {
12880 let left_idx = resolve_column_index(left, &equality.right).ok_or_else(|| {
12881 Error::Internal("failed to resolve column offset for hash join".into())
12882 })?;
12883 let right_idx = resolve_column_index(right, &equality.left).ok_or_else(|| {
12884 Error::Internal("failed to resolve column offset for hash join".into())
12885 })?;
12886 keys.push((left_idx, right_idx));
12887 } else if equality.right.table == right_table_index
12888 && used_tables.contains(&equality.left.table)
12889 {
12890 let left_idx = resolve_column_index(left, &equality.left).ok_or_else(|| {
12891 Error::Internal("failed to resolve column offset for hash join".into())
12892 })?;
12893 let right_idx = resolve_column_index(right, &equality.right).ok_or_else(|| {
12894 Error::Internal("failed to resolve column offset for hash join".into())
12895 })?;
12896 keys.push((left_idx, right_idx));
12897 }
12898 }
12899
12900 Ok(keys)
12901}
12902
12903fn resolve_column_index(data: &TableCrossProductData, column: &ColumnRef) -> Option<usize> {
12904 let mut offset = 0;
12905 for (table_idx, count) in data.table_indices.iter().zip(data.column_counts.iter()) {
12906 if *table_idx == column.table {
12907 if column.column < *count {
12908 return Some(offset + column.column);
12909 } else {
12910 return None;
12911 }
12912 }
12913 offset += count;
12914 }
12915 None
12916}
12917
/// Build a case-insensitive name → column-index map for a cross-product
/// schema.
///
/// Keys inserted per column include: the raw lowercased field name, the name
/// with a leading `.` stripped, `alias.column` / `schema.alias.column`
/// forms, bare-table-qualified forms (only when unambiguous), and the bare
/// column name. Every insertion uses `entry(...).or_insert(...)`, so the
/// FIRST mapping for a key wins — insertion order encodes precedence.
fn build_cross_product_column_lookup(
    schema: &Schema,
    tables: &[llkv_plan::TableRef],
    column_counts: &[usize],
    table_indices: &[usize],
) -> FxHashMap<String, usize> {
    debug_assert_eq!(tables.len(), column_counts.len());
    debug_assert_eq!(column_counts.len(), table_indices.len());

    // Count duplicates so ambiguous bare names are never mapped.
    let mut column_occurrences: FxHashMap<String, usize> = FxHashMap::default();
    let mut table_column_counts: FxHashMap<String, usize> = FxHashMap::default();
    for field in schema.fields() {
        let column_name = extract_column_name(field.name());
        *column_occurrences.entry(column_name).or_insert(0) += 1;
        if let Some(pair) = table_column_suffix(field.name()) {
            *table_column_counts.entry(pair).or_insert(0) += 1;
        }
    }

    // Count how often each base table is referenced (and how often without an
    // alias) to decide whether bare-table-qualified names are unambiguous.
    let mut base_table_totals: FxHashMap<String, usize> = FxHashMap::default();
    let mut base_table_unaliased: FxHashMap<String, usize> = FxHashMap::default();
    for table_ref in tables {
        let key = base_table_key(table_ref);
        *base_table_totals.entry(key.clone()).or_insert(0) += 1;
        if table_ref.alias.is_none() {
            *base_table_unaliased.entry(key).or_insert(0) += 1;
        }
    }

    let mut lookup = FxHashMap::default();

    // Fallback: no per-table layout information — map names positionally.
    if table_indices.is_empty() || column_counts.is_empty() {
        for (idx, field) in schema.fields().iter().enumerate() {
            let field_name_lower = field.name().to_ascii_lowercase();
            lookup.entry(field_name_lower).or_insert(idx);

            let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
            lookup.entry(trimmed_lower).or_insert(idx);

            if let Some(pair) = table_column_suffix(field.name())
                && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
            {
                lookup.entry(pair).or_insert(idx);
            }

            let column_name = extract_column_name(field.name());
            if column_occurrences.get(&column_name).copied().unwrap_or(0) == 1 {
                lookup.entry(column_name).or_insert(idx);
            }
        }
        return lookup;
    }

    let mut offset = 0usize;
    for (&table_idx, &count) in table_indices.iter().zip(column_counts.iter()) {
        // NOTE(review): this skip does not advance `offset` by `count`, which
        // would misalign every subsequent table's field indices if it ever
        // triggered — presumably unreachable given the debug_asserts; confirm.
        if table_idx >= tables.len() {
            continue;
        }
        let table_ref = &tables[table_idx];
        let alias_lower = table_ref
            .alias
            .as_ref()
            .map(|alias| alias.to_ascii_lowercase());
        let table_lower = table_ref.table.to_ascii_lowercase();
        let schema_lower = table_ref.schema.to_ascii_lowercase();
        let base_key = base_table_key(table_ref);
        let total_refs = base_table_totals.get(&base_key).copied().unwrap_or(0);
        let unaliased_refs = base_table_unaliased.get(&base_key).copied().unwrap_or(0);

        // A bare table name may be used only when it cannot refer to more
        // than one occurrence of that base table.
        let allow_base_mapping = if table_ref.alias.is_none() {
            unaliased_refs == 1
        } else {
            unaliased_refs == 0 && total_refs == 1
        };

        let mut table_keys: Vec<String> = Vec::new();

        if let Some(alias) = &alias_lower {
            table_keys.push(alias.clone());
            if !schema_lower.is_empty() {
                table_keys.push(format!("{}.{}", schema_lower, alias));
            }
        }

        if allow_base_mapping {
            table_keys.push(table_lower.clone());
            if !schema_lower.is_empty() {
                table_keys.push(format!("{}.{}", schema_lower, table_lower));
            }
        }

        for local_idx in 0..count {
            let field_index = offset + local_idx;
            let field = schema.field(field_index);
            let field_name_lower = field.name().to_ascii_lowercase();
            lookup.entry(field_name_lower).or_insert(field_index);

            let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
            lookup.entry(trimmed_lower).or_insert(field_index);

            let column_name = extract_column_name(field.name());
            for table_key in &table_keys {
                lookup
                    .entry(format!("{}.{}", table_key, column_name))
                    .or_insert(field_index);
            }

            lookup.entry(column_name.clone()).or_insert(field_index);

            if table_keys.is_empty()
                && let Some(pair) = table_column_suffix(field.name())
                && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
            {
                lookup.entry(pair).or_insert(field_index);
            }
        }

        offset = offset.saturating_add(count);
    }

    lookup
}
13043
13044fn base_table_key(table_ref: &llkv_plan::TableRef) -> String {
13045 let schema_lower = table_ref.schema.to_ascii_lowercase();
13046 let table_lower = table_ref.table.to_ascii_lowercase();
13047 if schema_lower.is_empty() {
13048 table_lower
13049 } else {
13050 format!("{}.{}", schema_lower, table_lower)
13051 }
13052}
13053
/// Returns the lowercase final `.`-separated segment of a (possibly
/// qualified) column name, after stripping any leading dots.
fn extract_column_name(name: &str) -> String {
    let trimmed = name.trim_start_matches('.');
    let tail = match trimmed.rfind('.') {
        Some(pos) => &trimmed[pos + 1..],
        None => trimmed,
    };
    tail.to_ascii_lowercase()
}
13061
/// Extracts the lowercase `"table.column"` suffix (last two `.`-separated
/// segments) from a qualified name, or `None` when fewer than two segments
/// remain after stripping leading dots.
fn table_column_suffix(name: &str) -> Option<String> {
    let trimmed = name.trim_start_matches('.');
    let segments: Vec<&str> = trimmed.split('.').collect();
    match segments.as_slice() {
        [.., table, column] => Some(format!(
            "{}.{}",
            table.to_ascii_lowercase(),
            column.to_ascii_lowercase()
        )),
        _ => None,
    }
}
13072
13073fn cross_join_table_batches(
13098 left: TableCrossProductData,
13099 right: TableCrossProductData,
13100) -> ExecutorResult<TableCrossProductData> {
13101 let TableCrossProductData {
13102 schema: left_schema,
13103 batches: left_batches,
13104 column_counts: mut left_counts,
13105 table_indices: mut left_tables,
13106 } = left;
13107 let TableCrossProductData {
13108 schema: right_schema,
13109 batches: right_batches,
13110 column_counts: right_counts,
13111 table_indices: right_tables,
13112 } = right;
13113
13114 let combined_fields: Vec<Field> = left_schema
13115 .fields()
13116 .iter()
13117 .chain(right_schema.fields().iter())
13118 .map(|field| field.as_ref().clone())
13119 .collect();
13120
13121 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
13122 column_counts.append(&mut left_counts);
13123 column_counts.extend(right_counts);
13124
13125 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
13126 table_indices.append(&mut left_tables);
13127 table_indices.extend(right_tables);
13128
13129 let combined_schema = Arc::new(Schema::new(combined_fields));
13130
13131 let left_has_rows = left_batches.iter().any(|batch| batch.num_rows() > 0);
13132 let right_has_rows = right_batches.iter().any(|batch| batch.num_rows() > 0);
13133
13134 if !left_has_rows || !right_has_rows {
13135 return Ok(TableCrossProductData {
13136 schema: combined_schema,
13137 batches: Vec::new(),
13138 column_counts,
13139 table_indices,
13140 });
13141 }
13142
13143 let output_batches: Vec<RecordBatch> = with_thread_pool(|| {
13146 left_batches
13147 .par_iter()
13148 .filter(|left_batch| left_batch.num_rows() > 0)
13149 .flat_map(|left_batch| {
13150 right_batches
13151 .par_iter()
13152 .filter(|right_batch| right_batch.num_rows() > 0)
13153 .filter_map(|right_batch| {
13154 cross_join_pair(left_batch, right_batch, &combined_schema).ok()
13155 })
13156 .collect::<Vec<_>>()
13157 })
13158 .collect()
13159 });
13160
13161 Ok(TableCrossProductData {
13162 schema: combined_schema,
13163 batches: output_batches,
13164 column_counts,
13165 table_indices,
13166 })
13167}
13168
13169fn cross_join_all(staged: Vec<TableCrossProductData>) -> ExecutorResult<TableCrossProductData> {
13170 let mut iter = staged.into_iter();
13171 let mut current = iter
13172 .next()
13173 .ok_or_else(|| Error::Internal("cross product preparation yielded no tables".into()))?;
13174 for next in iter {
13175 current = cross_join_table_batches(current, next)?;
13176 }
13177 Ok(current)
13178}
13179
/// Per-table metadata used when resolving column references in join predicates.
struct TableInfo<'a> {
    // Position of this table within the FROM-clause table list.
    index: usize,
    // Plan-level reference carrying table name, schema, and optional alias.
    table_ref: &'a llkv_plan::TableRef,
    // Lowercased column name -> column index within this table's schema.
    column_map: FxHashMap<String, usize>,
}
13185
/// Identifies a column by table position and column position.
#[derive(Clone, Copy)]
struct ColumnRef {
    // Index into the FROM-clause table list.
    table: usize,
    // Column index within that table's schema.
    column: usize,
}
13191
/// A column-to-column equality (`left = right`) extracted from a join
/// predicate; these become equi-join keys.
#[derive(Clone, Copy)]
struct ColumnEquality {
    left: ColumnRef,
    right: ColumnRef,
}
13197
/// A `column = literal` constraint usable for per-table pushdown.
#[derive(Clone)]
struct ColumnLiteral {
    column: ColumnRef,
    value: PlanValue,
}
13203
/// A `column IN (v1, v2, ...)` constraint usable for per-table pushdown.
#[derive(Clone)]
struct ColumnInList {
    column: ColumnRef,
    values: Vec<PlanValue>,
}
13209
/// A single-table filter constraint extracted from a join predicate.
#[derive(Clone)]
enum ColumnConstraint {
    // Exact equality against one literal value.
    Equality(ColumnLiteral),
    // Membership in a list of literal values.
    InList(ColumnInList),
}
13215
/// Decomposition of a join predicate into pieces the join planner understands.
struct JoinConstraintPlan {
    // Column-to-column equalities (equi-join keys).
    equalities: Vec<ColumnEquality>,
    // Single-column literal constraints (equality or IN-list).
    literals: Vec<ColumnConstraint>,
    // True when a literal `false` conjunct makes the predicate unsatisfiable.
    unsatisfiable: bool,
    // Number of top-level AND conjuncts found in the original predicate.
    total_conjuncts: usize,
    // How many conjuncts were recognized; callers can compare against
    // `total_conjuncts` to detect residual predicates needing post-join
    // evaluation.
    handled_conjuncts: usize,
}
13226
/// Scans a join predicate for per-table literal constraints that can be
/// pushed down to individual table scans before the cross product is formed.
///
/// Returns one constraint list per entry in `tables_with_handles`, in the
/// same order. Extraction is best-effort: conjuncts that are not recognized
/// are simply skipped, which affects only how much data is pruned early,
/// never correctness (the full predicate is still evaluated later).
fn extract_literal_pushdown_filters<P>(
    expr: &LlkvExpr<'static, String>,
    tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
) -> Vec<Vec<ColumnConstraint>>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Build a lowercase column-name lookup per table. `or_insert` keeps the
    // first occurrence when a table has duplicate (case-folded) column names.
    let mut table_infos = Vec::with_capacity(tables_with_handles.len());
    for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
        let mut column_map = FxHashMap::default();
        for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
            let column_name = column.name.to_ascii_lowercase();
            column_map.entry(column_name).or_insert(column_idx);
        }
        table_infos.push(TableInfo {
            index,
            table_ref,
            column_map,
        });
    }

    let mut constraints: Vec<Vec<ColumnConstraint>> = vec![Vec::new(); tables_with_handles.len()];

    // Flatten nested ANDs so each conjunct can be inspected independently.
    let mut conjuncts = Vec::new();
    collect_conjuncts_lenient(expr, &mut conjuncts);

    for conjunct in conjuncts {
        // Case 1: `col = literal` (either operand order) as a Compare node.
        if let LlkvExpr::Compare {
            left,
            op: CompareOp::Eq,
            right,
        } = conjunct
        {
            match (
                resolve_column_reference(left, &table_infos),
                resolve_column_reference(right, &table_infos),
            ) {
                (Some(column), None) => {
                    if let Some(literal) = extract_literal(right)
                        && let Some(value) = PlanValue::from_literal_for_join(literal)
                        && column.table < constraints.len()
                    {
                        constraints[column.table]
                            .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                    }
                }
                (None, Some(column)) => {
                    if let Some(literal) = extract_literal(left)
                        && let Some(value) = PlanValue::from_literal_for_join(literal)
                        && column.table < constraints.len()
                    {
                        constraints[column.table]
                            .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                    }
                }
                // Column=column (join key) or literal=literal: not pushdownable here.
                _ => {}
            }
        }
        // Case 2: a field-level equality predicate. The field name is matched
        // against each table in order; `break` stops at the first table that
        // has the column. NOTE(review): with ambiguous unqualified names this
        // attributes the constraint to the first matching table — presumably
        // acceptable because it only prunes; confirm against planner rules.
        else if let LlkvExpr::Pred(filter) = conjunct {
            if let Operator::Equals(ref literal_val) = filter.op {
                let field_name = filter.field_id.trim().to_ascii_lowercase();

                for info in &table_infos {
                    if let Some(&col_idx) = info.column_map.get(&field_name) {
                        if let Some(value) = PlanValue::from_operator_literal(literal_val) {
                            let column_ref = ColumnRef {
                                table: info.index,
                                column: col_idx,
                            };
                            if info.index < constraints.len() {
                                constraints[info.index].push(ColumnConstraint::Equality(
                                    ColumnLiteral {
                                        column: column_ref,
                                        value,
                                    },
                                ));
                            }
                        }
                        break; }
                }
            }
        }
        // Case 3: a non-negated IN-list over a resolvable column. Items that
        // are not plain literals are skipped (lenient).
        else if let LlkvExpr::InList {
            expr: col_expr,
            list,
            negated: false,
        } = conjunct
        {
            if let Some(column) = resolve_column_reference(col_expr, &table_infos) {
                let mut values = Vec::new();
                for item in list {
                    if let Some(literal) = extract_literal(item)
                        && let Some(value) = PlanValue::from_literal_for_join(literal)
                    {
                        values.push(value);
                    }
                }
                if !values.is_empty() && column.table < constraints.len() {
                    constraints[column.table]
                        .push(ColumnConstraint::InList(ColumnInList { column, values }));
                }
            }
        }
        // Case 4: an OR chain of equalities on one column, rewritten as IN.
        else if let LlkvExpr::Or(or_children) = conjunct
            && let Some((column, values)) = try_extract_or_as_in_list(or_children, &table_infos)
            && !values.is_empty()
            && column.table < constraints.len()
        {
            constraints[column.table]
                .push(ColumnConstraint::InList(ColumnInList { column, values }));
        }
    }

    constraints
}
13369
13370fn collect_conjuncts_lenient<'a>(
13375 expr: &'a LlkvExpr<'static, String>,
13376 out: &mut Vec<&'a LlkvExpr<'static, String>>,
13377) {
13378 match expr {
13379 LlkvExpr::And(children) => {
13380 for child in children {
13381 collect_conjuncts_lenient(child, out);
13382 }
13383 }
13384 other => {
13385 out.push(other);
13387 }
13388 }
13389}
13390
/// Attempts to rewrite an OR chain as `column IN (values...)`.
///
/// Succeeds only when every OR child is an equality between the *same* single
/// column and a literal (in either operand order, as a Compare node or a
/// field predicate). Returns `None` as soon as any child fails that shape or
/// references a different column — the rewrite is all-or-nothing.
fn try_extract_or_as_in_list(
    or_children: &[LlkvExpr<'static, String>],
    table_infos: &[TableInfo<'_>],
) -> Option<(ColumnRef, Vec<PlanValue>)> {
    if or_children.is_empty() {
        return None;
    }

    // The one column all children must agree on, fixed by the first child.
    let mut common_column: Option<ColumnRef> = None;
    let mut values = Vec::new();

    for child in or_children {
        if let LlkvExpr::Compare {
            left,
            op: CompareOp::Eq,
            right,
        } = child
        {
            // `column = literal` orientation.
            if let (Some(column), None) = (
                resolve_column_reference(left, table_infos),
                resolve_column_reference(right, table_infos),
            ) && let Some(literal) = extract_literal(right)
                && let Some(value) = PlanValue::from_literal_for_join(literal)
            {
                match common_column {
                    None => common_column = Some(column),
                    Some(ref prev)
                        if prev.table == column.table && prev.column == column.column =>
                    {
                    }
                    // Different column: the OR is not a single-column IN.
                    _ => {
                        return None;
                    }
                }
                values.push(value);
                continue;
            }

            // `literal = column` orientation.
            if let (None, Some(column)) = (
                resolve_column_reference(left, table_infos),
                resolve_column_reference(right, table_infos),
            ) && let Some(literal) = extract_literal(left)
                && let Some(value) = PlanValue::from_literal_for_join(literal)
            {
                match common_column {
                    None => common_column = Some(column),
                    Some(ref prev)
                        if prev.table == column.table && prev.column == column.column => {}
                    _ => return None,
                }
                values.push(value);
                continue;
            }
            // Compare matched but neither orientation extracted: falls
            // through to the final `return None` below.
        }
        // Field-predicate equality (`field == literal`) form.
        else if let LlkvExpr::Pred(filter) = child
            && let Operator::Equals(ref literal) = filter.op
            && let Some(column) =
                resolve_column_reference(&ScalarExpr::Column(filter.field_id.clone()), table_infos)
            && let Some(value) = PlanValue::from_literal_for_join(literal)
        {
            match common_column {
                None => common_column = Some(column),
                Some(ref prev) if prev.table == column.table && prev.column == column.column => {}
                _ => return None,
            }
            values.push(value);
            continue;
        }

        // Any unrecognized child aborts the whole rewrite.
        return None;
    }

    common_column.map(|col| (col, values))
}
13476
/// Decomposes a join predicate into equi-join keys and single-column literal
/// constraints, tracking how many top-level conjuncts were understood.
///
/// Recognized conjunct shapes: boolean literals, `col = col`, `col = literal`
/// (either orientation), non-negated IN-lists, single-column OR chains
/// (rewritten as IN), and field-predicate equalities. A literal `false`
/// conjunct marks the plan unsatisfiable and stops further scanning.
///
/// NOTE(review): this function currently always returns `Some`; the `Option`
/// wrapper is kept for interface stability with existing callers.
fn extract_join_constraints(
    expr: &LlkvExpr<'static, String>,
    table_infos: &[TableInfo<'_>],
) -> Option<JoinConstraintPlan> {
    let mut conjuncts = Vec::new();
    collect_conjuncts_lenient(expr, &mut conjuncts);

    let total_conjuncts = conjuncts.len();
    let mut equalities = Vec::new();
    let mut literals = Vec::new();
    let mut unsatisfiable = false;
    let mut handled_conjuncts = 0;

    for conjunct in conjuncts {
        match conjunct {
            // TRUE is a no-op conjunct; counted as handled.
            LlkvExpr::Literal(true) => {
                handled_conjuncts += 1;
            }
            // FALSE makes the whole predicate unsatisfiable; stop early.
            LlkvExpr::Literal(false) => {
                unsatisfiable = true;
                handled_conjuncts += 1;
                break;
            }
            LlkvExpr::Compare {
                left,
                op: CompareOp::Eq,
                right,
            } => {
                match (
                    resolve_column_reference(left, table_infos),
                    resolve_column_reference(right, table_infos),
                ) {
                    // col = col: an equi-join key.
                    (Some(left_col), Some(right_col)) => {
                        equalities.push(ColumnEquality {
                            left: left_col,
                            right: right_col,
                        });
                        handled_conjuncts += 1;
                        continue;
                    }
                    // col = literal.
                    (Some(column), None) => {
                        if let Some(literal) = extract_literal(right)
                            && let Some(value) = PlanValue::from_literal_for_join(literal)
                        {
                            literals
                                .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                            handled_conjuncts += 1;
                            continue;
                        }
                    }
                    // literal = col.
                    (None, Some(column)) => {
                        if let Some(literal) = extract_literal(left)
                            && let Some(value) = PlanValue::from_literal_for_join(literal)
                        {
                            literals
                                .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                            handled_conjuncts += 1;
                            continue;
                        }
                    }
                    // Unresolvable operands: left unhandled (residual).
                    _ => {}
                }
            }
            LlkvExpr::InList {
                expr: col_expr,
                list,
                negated: false,
            } => {
                if let Some(column) = resolve_column_reference(col_expr, table_infos) {
                    // Skip list items that are not plain literals (lenient).
                    let mut in_list_values = Vec::new();
                    for item in list {
                        if let Some(literal) = extract_literal(item)
                            && let Some(value) = PlanValue::from_literal_for_join(literal)
                        {
                            in_list_values.push(value);
                        }
                    }
                    if !in_list_values.is_empty() {
                        literals.push(ColumnConstraint::InList(ColumnInList {
                            column,
                            values: in_list_values,
                        }));
                        handled_conjuncts += 1;
                        continue;
                    }
                }
            }
            // OR of equalities over one column becomes an IN-list.
            LlkvExpr::Or(or_children) => {
                if let Some((column, values)) = try_extract_or_as_in_list(or_children, table_infos)
                {
                    literals.push(ColumnConstraint::InList(ColumnInList { column, values }));
                    handled_conjuncts += 1;
                    continue;
                }
            }
            // Field-predicate equality (`field == literal`).
            LlkvExpr::Pred(filter) => {
                if let Operator::Equals(ref literal) = filter.op
                    && let Some(column) = resolve_column_reference(
                        &ScalarExpr::Column(filter.field_id.clone()),
                        table_infos,
                    )
                    && let Some(value) = PlanValue::from_literal_for_join(literal)
                {
                    literals.push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                    handled_conjuncts += 1;
                    continue;
                }
            }
            // Anything else stays unhandled; callers detect the residual via
            // handled_conjuncts < total_conjuncts.
            _ => {
            }
        }
    }

    Some(JoinConstraintPlan {
        equalities,
        literals,
        unsatisfiable,
        total_conjuncts,
        handled_conjuncts,
    })
}
13636
/// Resolves a `ScalarExpr::Column` name to a (table, column) pair.
///
/// Qualified names (`[schema.]table.column` or `alias.column`) must match a
/// table identifier; once a table matches, resolution is final — a missing
/// column in that table returns `None` rather than trying other tables.
/// Unqualified names resolve to the first table (in declaration order) that
/// has the column. NOTE(review): ambiguous unqualified names therefore bind
/// silently to the earliest table — confirm this matches planner semantics.
fn resolve_column_reference(
    expr: &ScalarExpr<String>,
    table_infos: &[TableInfo<'_>],
) -> Option<ColumnRef> {
    let name = match expr {
        ScalarExpr::Column(name) => name.trim(),
        _ => return None,
    };

    // Split into dot-separated segments, dropping empty ones (e.g. "a..b").
    let mut parts: Vec<&str> = name
        .trim_start_matches('.')
        .split('.')
        .filter(|segment| !segment.is_empty())
        .collect();

    if parts.is_empty() {
        return None;
    }

    // Last segment is the column; anything before it qualifies the table.
    let column_part = parts.pop()?.to_ascii_lowercase();
    if parts.is_empty() {
        // Unqualified: first table containing the column wins.
        for info in table_infos {
            if let Some(&col_idx) = info.column_map.get(&column_part) {
                return Some(ColumnRef {
                    table: info.index,
                    column: col_idx,
                });
            }
        }
        return None;
    }

    let table_ident = parts.join(".").to_ascii_lowercase();
    for info in table_infos {
        if matches_table_ident(info.table_ref, &table_ident) {
            if let Some(&col_idx) = info.column_map.get(&column_part) {
                return Some(ColumnRef {
                    table: info.index,
                    column: col_idx,
                });
            } else {
                // Table matched but lacks the column: do not fall through to
                // other tables — the qualifier pinned the target.
                return None;
            }
        }
    }
    None
}
13687
13688fn matches_table_ident(table_ref: &llkv_plan::TableRef, ident: &str) -> bool {
13689 if ident.is_empty() {
13690 return false;
13691 }
13692 if let Some(alias) = &table_ref.alias
13693 && alias.to_ascii_lowercase() == ident
13694 {
13695 return true;
13696 }
13697 if table_ref.table.to_ascii_lowercase() == ident {
13698 return true;
13699 }
13700 if !table_ref.schema.is_empty() {
13701 let full = format!(
13702 "{}.{}",
13703 table_ref.schema.to_ascii_lowercase(),
13704 table_ref.table.to_ascii_lowercase()
13705 );
13706 if full == ident {
13707 return true;
13708 }
13709 }
13710 false
13711}
13712
13713fn extract_literal(expr: &ScalarExpr<String>) -> Option<&Literal> {
13714 match expr {
13715 ScalarExpr::Literal(lit) => Some(lit),
13716 _ => None,
13717 }
13718}
13719
/// Tracks rows already emitted so DISTINCT can drop duplicates across batches.
#[derive(Default)]
struct DistinctState {
    // Canonicalized rows observed so far.
    seen: FxHashSet<CanonicalRow>,
}
13724
impl DistinctState {
    /// Records `row`; returns `true` when it was not seen before (keep it).
    fn insert(&mut self, row: CanonicalRow) -> bool {
        self.seen.insert(row)
    }
}
13730
13731fn distinct_filter_batch(
13732 batch: RecordBatch,
13733 state: &mut DistinctState,
13734) -> ExecutorResult<Option<RecordBatch>> {
13735 if batch.num_rows() == 0 {
13736 return Ok(None);
13737 }
13738
13739 let mut keep_flags = Vec::with_capacity(batch.num_rows());
13740 let mut keep_count = 0usize;
13741
13742 for row_idx in 0..batch.num_rows() {
13743 let row = CanonicalRow::from_batch(&batch, row_idx)?;
13744 if state.insert(row) {
13745 keep_flags.push(true);
13746 keep_count += 1;
13747 } else {
13748 keep_flags.push(false);
13749 }
13750 }
13751
13752 if keep_count == 0 {
13753 return Ok(None);
13754 }
13755
13756 if keep_count == batch.num_rows() {
13757 return Ok(Some(batch));
13758 }
13759
13760 let mut builder = BooleanBuilder::with_capacity(batch.num_rows());
13761 for flag in keep_flags {
13762 builder.append_value(flag);
13763 }
13764 let mask = Arc::new(builder.finish());
13765
13766 let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
13767 Error::InvalidArgumentError(format!("failed to apply DISTINCT filter: {err}"))
13768 })?;
13769
13770 Ok(Some(filtered))
13771}
13772
/// Reorders `batch` according to `order_by` using Arrow's lexicographic sort.
///
/// Resolves each ORDER BY target to a column (by name or 1-based position),
/// optionally casts TEXT columns to integers for `CAST`-style ordering
/// (unparseable strings become NULL), sorts, and applies the resulting
/// permutation to every column. Returns the batch unchanged when `order_by`
/// is empty; `OrderTarget::All` must have been expanded by the planner.
fn sort_record_batch_with_order(
    schema: &Arc<Schema>,
    batch: &RecordBatch,
    order_by: &[OrderByPlan],
) -> ExecutorResult<RecordBatch> {
    if order_by.is_empty() {
        return Ok(batch.clone());
    }

    let mut sort_columns: Vec<SortColumn> = Vec::with_capacity(order_by.len());

    for order in order_by {
        // Resolve the sort key to a concrete column index.
        let column_index = match &order.target {
            OrderTarget::Column(name) => schema.index_of(name).map_err(|_| {
                Error::InvalidArgumentError(format!(
                    "ORDER BY references unknown column '{}'",
                    name
                ))
            })?,
            OrderTarget::Index(idx) => {
                if *idx >= batch.num_columns() {
                    return Err(Error::InvalidArgumentError(format!(
                        "ORDER BY position {} is out of bounds for {} columns",
                        idx + 1,
                        batch.num_columns()
                    )));
                }
                *idx
            }
            OrderTarget::All => {
                return Err(Error::InvalidArgumentError(
                    "ORDER BY ALL should be expanded before sorting".into(),
                ));
            }
        };

        let source_array = batch.column(column_index);

        // Materialize the key array; only the CastTextToInteger transform
        // builds a new array, the native case is a cheap Arc clone.
        let values: ArrayRef = match order.sort_type {
            OrderSortType::Native => Arc::clone(source_array),
            OrderSortType::CastTextToInteger => {
                let strings = source_array
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .ok_or_else(|| {
                        Error::InvalidArgumentError(
                            "ORDER BY CAST expects the underlying column to be TEXT".into(),
                        )
                    })?;
                let mut builder = Int64Builder::with_capacity(strings.len());
                for i in 0..strings.len() {
                    if strings.is_null(i) {
                        builder.append_null();
                    } else {
                        // Unparseable text sorts as NULL rather than erroring.
                        match strings.value(i).parse::<i64>() {
                            Ok(value) => builder.append_value(value),
                            Err(_) => builder.append_null(),
                        }
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
        };

        let sort_options = SortOptions {
            descending: !order.ascending,
            nulls_first: order.nulls_first,
        };

        sort_columns.push(SortColumn {
            values,
            options: Some(sort_options),
        });
    }

    // Compute the row permutation once, then apply it to every column.
    let indices = lexsort_to_indices(&sort_columns, None).map_err(|err| {
        Error::InvalidArgumentError(format!("failed to compute ORDER BY indices: {err}"))
    })?;

    let perm = indices
        .as_any()
        .downcast_ref::<UInt32Array>()
        .ok_or_else(|| Error::Internal("ORDER BY sorting produced unexpected index type".into()))?;

    let mut reordered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
    for col_idx in 0..batch.num_columns() {
        let reordered = take(batch.column(col_idx), perm, None).map_err(|err| {
            Error::InvalidArgumentError(format!(
                "failed to apply ORDER BY permutation to column {col_idx}: {err}"
            ))
        })?;
        reordered_columns.push(reordered);
    }

    RecordBatch::try_new(Arc::clone(schema), reordered_columns)
        .map_err(|err| Error::Internal(format!("failed to build reordered ORDER BY batch: {err}")))
}
13870
// Unit tests for cross-product evaluation, grouping, and constant/aggregate
// expression folding in the executor.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayRef, Date32Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use llkv_expr::expr::{BinaryOp, CompareOp};
    use llkv_expr::literal::Literal;
    use llkv_storage::pager::MemPager;
    use std::sync::Arc;

    // Literal and column-arithmetic scalar expressions evaluate row-wise
    // against a cross-product batch through the expression context.
    #[test]
    fn cross_product_context_evaluates_expressions() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("main.tab2.a", DataType::Int64, false),
            Field::new("main.tab2.b", DataType::Int64, false),
        ]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
                Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef,
            ],
        )
        .expect("valid batch");

        let lookup = build_cross_product_column_lookup(schema.as_ref(), &[], &[], &[]);
        let mut ctx = CrossProductExpressionContext::new(schema.as_ref(), lookup)
            .expect("context builds from schema");

        // A bare literal broadcasts to every row of the batch.
        let literal_expr: ScalarExpr<String> = ScalarExpr::literal(67);
        let literal = ctx
            .evaluate(&literal_expr, &batch)
            .expect("literal evaluation succeeds");
        let literal_array = literal
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 literal result");
        assert_eq!(literal_array.len(), 3);
        assert!(literal_array.iter().all(|value| value == Some(67)));

        // Column + literal addition resolves "tab2.a" through the lookup.
        let add_expr = ScalarExpr::binary(
            ScalarExpr::column("tab2.a".to_string()),
            BinaryOp::Add,
            ScalarExpr::literal(5),
        );
        let added = ctx
            .evaluate(&add_expr, &batch)
            .expect("column addition succeeds");
        let added_array = added
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 addition result");
        assert_eq!(added_array.values(), &[6, 7, 8]);
    }

    // Date32 columns participate in comparison predicates with three-valued
    // truth output.
    #[test]
    fn cross_product_filter_handles_date32_columns() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "orders.o_orderdate",
            DataType::Date32,
            false,
        )]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Date32Array::from(vec![0, 1, 3])) as ArrayRef],
        )
        .expect("valid batch");

        let lookup = build_cross_product_column_lookup(schema.as_ref(), &[], &[], &[]);
        let mut ctx = CrossProductExpressionContext::new(schema.as_ref(), lookup)
            .expect("context builds from schema");

        let field_id = ctx
            .schema()
            .columns
            .first()
            .expect("schema exposes date column")
            .field_id;

        let predicate = LlkvExpr::Compare {
            left: ScalarExpr::Column(field_id),
            op: CompareOp::GtEq,
            right: ScalarExpr::Literal(Literal::Date32(1)),
        };

        // The subquery callback is unused here, so it always yields None.
        let truths = ctx
            .evaluate_predicate_truths(&predicate, &batch, &mut |_, _, _, _| Ok(None))
            .expect("date comparison evaluates");

        assert_eq!(truths, vec![Some(false), Some(true), Some(true)]);
    }

    // Date32 group keys map to integer keys, with nulls preserved.
    #[test]
    fn group_by_handles_date32_columns() {
        let array: ArrayRef = Arc::new(Date32Array::from(vec![Some(3), None, Some(-7)]));

        let first = group_key_value(&array, 0).expect("extract first group key");
        assert_eq!(first, GroupKeyValue::Int(3));

        let second = group_key_value(&array, 1).expect("extract second group key");
        assert_eq!(second, GroupKeyValue::Null);

        let third = group_key_value(&array, 2).expect("extract third group key");
        assert_eq!(third, GroupKeyValue::Int(-7));
    }

    // CAST of an in-range integral literal succeeds in aggregate folding.
    #[test]
    fn aggregate_expr_allows_numeric_casts() {
        let expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(31)),
            data_type: DataType::Int32,
        };
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("cast should succeed for in-range integral values");

        assert_eq!(value, Some(31));
    }

    // CAST of -1 to UInt8 is out of range and must error, not wrap.
    #[test]
    fn aggregate_expr_cast_rejects_out_of_range_values() {
        let expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(-1)),
            data_type: DataType::UInt8,
        };
        let aggregates = FxHashMap::default();

        let result = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates);

        assert!(matches!(result, Err(Error::InvalidArgumentError(_))));
    }

    // NULL propagates through arithmetic: 0 - CAST(NULL AS BIGINT) is NULL.
    #[test]
    fn aggregate_expr_null_literal_remains_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(0),
            BinaryOp::Subtract,
            ScalarExpr::cast(ScalarExpr::literal(Literal::Null), DataType::Int64),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("expression should evaluate");

        assert_eq!(value, None);
    }

    // SQL semantics: integer division by zero yields NULL, not an error.
    #[test]
    fn aggregate_expr_divide_by_zero_returns_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Divide,
            ScalarExpr::literal(0),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("division should evaluate");

        assert_eq!(value, None);
    }

    // Likewise, modulo by zero yields NULL.
    #[test]
    fn aggregate_expr_modulo_by_zero_returns_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Modulo,
            ScalarExpr::literal(0),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("modulo should evaluate");

        assert_eq!(value, None);
    }

    // Three-valued logic: NULL AND TRUE folds to NULL for constants.
    #[test]
    fn constant_and_with_null_yields_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(Literal::Null),
            BinaryOp::And,
            ScalarExpr::literal(1),
        );

        let value = evaluate_constant_scalar_with_aggregates(&expr)
            .expect("expression should fold as constant");

        assert!(matches!(value, Literal::Null));
    }

    // Chained cross joins (2x3x1 rows) produce the expected 6-row product
    // with correct left-major repetition order.
    #[test]
    fn cross_product_handles_more_than_two_tables() {
        let schema_a = Arc::new(Schema::new(vec![Field::new(
            "main.t1.a",
            DataType::Int64,
            false,
        )]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "main.t2.b",
            DataType::Int64,
            false,
        )]));
        let schema_c = Arc::new(Schema::new(vec![Field::new(
            "main.t3.c",
            DataType::Int64,
            false,
        )]));

        let batch_a = RecordBatch::try_new(
            Arc::clone(&schema_a),
            vec![Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef],
        )
        .expect("valid batch");
        let batch_b = RecordBatch::try_new(
            Arc::clone(&schema_b),
            vec![Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef],
        )
        .expect("valid batch");
        let batch_c = RecordBatch::try_new(
            Arc::clone(&schema_c),
            vec![Arc::new(Int64Array::from(vec![100])) as ArrayRef],
        )
        .expect("valid batch");

        let data_a = TableCrossProductData {
            schema: schema_a,
            batches: vec![batch_a],
            column_counts: vec![1],
            table_indices: vec![0],
        };
        let data_b = TableCrossProductData {
            schema: schema_b,
            batches: vec![batch_b],
            column_counts: vec![1],
            table_indices: vec![1],
        };
        let data_c = TableCrossProductData {
            schema: schema_c,
            batches: vec![batch_c],
            column_counts: vec![1],
            table_indices: vec![2],
        };

        let ab = cross_join_table_batches(data_a, data_b).expect("two-table product");
        assert_eq!(ab.schema.fields().len(), 2);
        assert_eq!(ab.batches.len(), 1);
        assert_eq!(ab.batches[0].num_rows(), 6);

        let abc = cross_join_table_batches(ab, data_c).expect("three-table product");
        assert_eq!(abc.schema.fields().len(), 3);
        assert_eq!(abc.batches.len(), 1);

        let final_batch = &abc.batches[0];
        assert_eq!(final_batch.num_rows(), 6);

        let col_a = final_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("left column values");
        assert_eq!(col_a.values(), &[1, 1, 1, 2, 2, 2]);

        let col_b = final_batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("middle column values");
        assert_eq!(col_b.values(), &[10, 20, 30, 10, 20, 30]);

        let col_c = final_batch
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("right column values");
        assert_eq!(col_c.values(), &[100, 100, 100, 100, 100, 100]);
    }
}