1use arrow::array::{
18 Array, ArrayRef, BooleanArray, BooleanBuilder, Float32Array, Float64Array, Int8Array,
19 Int16Array, Int32Array, Int64Array, Int64Builder, LargeStringArray, RecordBatch, StringArray,
20 StructArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, new_null_array,
21};
22use arrow::compute::{
23 SortColumn, SortOptions, cast, concat_batches, filter_record_batch, lexsort_to_indices, take,
24};
25use arrow::datatypes::{DataType, Field, Float64Type, Int64Type, Schema};
26use llkv_aggregate::{AggregateAccumulator, AggregateKind, AggregateSpec, AggregateState};
27use llkv_column_map::gather::gather_indices_from_batches;
28use llkv_column_map::store::Projection as StoreProjection;
29use llkv_column_map::types::LogicalFieldId;
30use llkv_expr::SubqueryId;
31use llkv_expr::expr::{
32 AggregateCall, BinaryOp, CompareOp, Expr as LlkvExpr, Filter, Operator, ScalarExpr,
33};
34use llkv_expr::literal::Literal;
35use llkv_expr::typed_predicate::{
36 build_bool_predicate, build_fixed_width_predicate, build_var_width_predicate,
37};
38use llkv_join::cross_join_pair;
39use llkv_plan::{
40 AggregateExpr, AggregateFunction, CanonicalRow, CompoundOperator, CompoundQuantifier,
41 CompoundSelectComponent, CompoundSelectPlan, OrderByPlan, OrderSortType, OrderTarget,
42 PlanValue, SelectPlan, SelectProjection,
43};
44use llkv_result::Error;
45use llkv_storage::pager::Pager;
46use llkv_table::table::{
47 RowIdFilter, ScanOrderDirection, ScanOrderSpec, ScanOrderTransform, ScanProjection,
48 ScanStreamOptions,
49};
50use llkv_table::types::FieldId;
51use llkv_table::{NumericArray, NumericArrayMap, NumericKernels, ROW_ID_FIELD_ID};
52use rayon::prelude::*;
53use rustc_hash::{FxHashMap, FxHashSet};
54use simd_r_drive_entry_handle::EntryHandle;
55use std::convert::TryFrom;
56use std::fmt;
57use std::sync::Arc;
58use std::sync::atomic::Ordering;
59
60#[cfg(test)]
61use std::cell::RefCell;
62
63pub mod insert;
68pub mod translation;
69pub mod types;
70pub mod utils;
71
/// Result alias used throughout the executor; the error type is `llkv_result::Error`.
pub type ExecutorResult<T> = Result<T, Error>;
78
79use crate::translation::schema::infer_computed_data_type;
80pub use insert::{
81 build_array_for_column, normalize_insert_value_for_column, resolve_insert_columns,
82};
83pub use translation::{
84 build_projected_columns, build_wildcard_projections, full_table_scan_filter,
85 resolve_field_id_from_schema, schema_for_projections, translate_predicate,
86 translate_predicate_with, translate_scalar, translate_scalar_with,
87};
88pub use types::{
89 ExecutorColumn, ExecutorMultiColumnUnique, ExecutorRowBatch, ExecutorSchema, ExecutorTable,
90 ExecutorTableProvider,
91};
92pub use utils::current_time_micros;
93
/// Hashable value used as (part of) a grouping key.
///
/// Derives `Eq` and `Hash` so it can key a hash map. There is deliberately no floating-point
/// variant (presumably because `f64` is neither `Eq` nor `Hash` — confirm against callers).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum GroupKeyValue {
    /// SQL `NULL`.
    Null,
    /// Integer key component.
    Int(i64),
    /// Boolean key component.
    Bool(bool),
    /// Text key component.
    String(String),
}
101
/// A single scalar value produced by an aggregate computation.
#[derive(Clone, Debug, PartialEq)]
enum AggregateValue {
    Null,
    Int64(i64),
    Float64(f64),
    String(String),
}

impl AggregateValue {
    /// Best-effort conversion to `i64`.
    ///
    /// Floats are converted with an `as` cast (truncation toward zero, saturating at the `i64`
    /// bounds). Strings must parse as an integer literal, otherwise `None`. `Null` yields `None`.
    fn as_i64(&self) -> Option<i64> {
        match self {
            Self::Int64(int_value) => Some(*int_value),
            Self::Float64(float_value) => Some(*float_value as i64),
            Self::String(text) => text.parse::<i64>().ok(),
            Self::Null => None,
        }
    }

    /// Best-effort conversion to `f64`: integers are widened, strings are parsed,
    /// and `Null` yields `None`.
    #[allow(dead_code)]
    fn as_f64(&self) -> Option<f64> {
        match self {
            Self::Float64(float_value) => Some(*float_value),
            Self::Int64(int_value) => Some(*int_value as f64),
            Self::String(text) => text.parse::<f64>().ok(),
            Self::Null => None,
        }
    }
}
134
/// A record batch paired with a row index inside it.
// NOTE(review): presumably this identifies a group's representative row during GROUP BY
// execution — confirm against the code that constructs it.
struct GroupState {
    /// Batch containing the row.
    batch: RecordBatch,
    /// Index of the row within `batch`.
    row_idx: usize,
}
139
/// Per-group bookkeeping used while aggregating.
struct GroupAggregateState {
    /// Index of the batch that holds the group's representative row (presumably an index
    /// into a list of collected batches — confirm).
    representative_batch_idx: usize,
    /// Row index of the representative row (presumably within `representative_batch_idx`).
    representative_row: usize,
    /// Locations of all rows belonging to the group; presumably `(batch_index, row_index)`
    /// pairs — confirm against the code that populates it.
    row_locations: Vec<(usize, usize)>,
}
146
/// One column of an output schema, paired with where its values come from.
struct OutputColumn {
    /// Arrow field describing the output column.
    field: Field,
    /// Origin of the column's values.
    source: OutputSource,
}
151
/// Origin of an output column's values.
enum OutputSource {
    /// Values are copied from a table column; `index` presumably indexes the table's column
    /// list — confirm against callers.
    TableColumn { index: usize },
    /// Values come from the computed projection at `projection_index`.
    Computed { projection_index: usize },
}
156
#[cfg(test)]
thread_local! {
    // Per-thread stack of active query labels; only compiled into test builds.
    static QUERY_LABEL_STACK: RefCell<Vec<String>> = const { RefCell::new(Vec::new()) };
}

/// Guard returned by [`push_query_label`].
///
/// In test builds, dropping the guard pops the top entry of the thread-local label stack. If
/// guards are dropped out of order, that entry is not necessarily the label this guard pushed.
/// In non-test builds the guard does nothing.
pub struct QueryLogGuard {
    _private: (),
}

/// Pushes `label` onto this thread's query-label stack and returns a guard that pops it on drop.
#[cfg(test)]
pub fn push_query_label(label: impl Into<String>) -> QueryLogGuard {
    let owned_label: String = label.into();
    QUERY_LABEL_STACK.with(move |labels| labels.borrow_mut().push(owned_label));
    QueryLogGuard { _private: () }
}

/// No-op outside of tests: the label is discarded and an inert guard is returned.
#[cfg(not(test))]
#[inline]
pub fn push_query_label(_label: impl Into<String>) -> QueryLogGuard {
    QueryLogGuard { _private: () }
}

#[cfg(test)]
impl Drop for QueryLogGuard {
    fn drop(&mut self) {
        QUERY_LABEL_STACK.with(|labels| {
            let _discarded = labels.borrow_mut().pop();
        });
    }
}

#[cfg(not(test))]
impl Drop for QueryLogGuard {
    #[inline]
    fn drop(&mut self) {}
}

/// Returns the most recently pushed label on this thread, if any.
#[cfg(test)]
pub fn current_query_label() -> Option<String> {
    QUERY_LABEL_STACK.with(|labels| labels.borrow().last().map(String::clone))
}

/// Always `None` outside of tests, where no labels are tracked.
#[cfg(not(test))]
#[inline]
pub fn current_query_label() -> Option<String> {
    None
}
220
221fn try_extract_simple_column<F: AsRef<str>>(expr: &ScalarExpr<F>) -> Option<&str> {
236 match expr {
237 ScalarExpr::Column(name) => Some(name.as_ref()),
238 ScalarExpr::Binary { left, op, right } => {
240 match op {
242 BinaryOp::Add => {
243 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Integer(0))) {
245 return try_extract_simple_column(right);
246 }
247 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Integer(0))) {
248 return try_extract_simple_column(left);
249 }
250 }
251 BinaryOp::Multiply => {
254 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Integer(1))) {
256 return try_extract_simple_column(right);
257 }
258 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Integer(1))) {
259 return try_extract_simple_column(left);
260 }
261 }
262 _ => {}
263 }
264 None
265 }
266 _ => None,
267 }
268}
269
270fn plan_values_to_arrow_array(values: &[PlanValue]) -> ExecutorResult<ArrayRef> {
275 use arrow::array::{Float64Array, Int64Array, StringArray};
276
277 let mut value_type = None;
279 for v in values {
280 if !matches!(v, PlanValue::Null) {
281 value_type = Some(v);
282 break;
283 }
284 }
285
286 match value_type {
287 Some(PlanValue::Integer(_)) => {
288 let int_values: Vec<Option<i64>> = values
289 .iter()
290 .map(|v| match v {
291 PlanValue::Integer(i) => Some(*i),
292 PlanValue::Null => None,
293 _ => Some(0), })
295 .collect();
296 Ok(Arc::new(Int64Array::from(int_values)) as ArrayRef)
297 }
298 Some(PlanValue::Float(_)) => {
299 let float_values: Vec<Option<f64>> = values
300 .iter()
301 .map(|v| match v {
302 PlanValue::Float(f) => Some(*f),
303 PlanValue::Integer(i) => Some(*i as f64),
304 PlanValue::Null => None,
305 _ => Some(0.0), })
307 .collect();
308 Ok(Arc::new(Float64Array::from(float_values)) as ArrayRef)
309 }
310 Some(PlanValue::String(_)) => {
311 let string_values: Vec<Option<&str>> = values
312 .iter()
313 .map(|v| match v {
314 PlanValue::String(s) => Some(s.as_str()),
315 PlanValue::Null => None,
316 _ => Some(""), })
318 .collect();
319 Ok(Arc::new(StringArray::from(string_values)) as ArrayRef)
320 }
321 _ => {
322 let null_values: Vec<Option<i64>> = vec![None; values.len()];
324 Ok(Arc::new(Int64Array::from(null_values)) as ArrayRef)
325 }
326 }
327}
328
/// Resolves a possibly-qualified column name to its index in a combined schema.
///
/// Matching is ASCII case-insensitive and proceeds in two steps:
/// 1. an exact match of the full lower-cased name (e.g. `t1.col`);
/// 2. otherwise, the unqualified part (the text after the last `.`) is matched against keys that
///    either equal it or end in `.<name>`.
///
/// If step 2 matches several keys (the same column name in several tables), the smallest index
/// is returned. The result is therefore deterministic rather than dependent on hash-map
/// iteration order.
///
/// The keys of `column_lookup_map` are presumably already lower-cased by the caller — confirm.
/// The map may use any hasher; `FxHashMap` callers keep working unchanged.
fn resolve_column_name_to_index<S: std::hash::BuildHasher>(
    col_name: &str,
    column_lookup_map: &std::collections::HashMap<String, usize, S>,
) -> Option<usize> {
    let col_lower = col_name.to_ascii_lowercase();
    if let Some(&idx) = column_lookup_map.get(&col_lower) {
        return Some(idx);
    }

    // `rsplit` always yields at least one piece, which is the whole name when it is unqualified.
    let unqualified = col_name
        .rsplit('.')
        .next()
        .unwrap_or(col_name)
        .to_ascii_lowercase();
    // Build the `.name` suffix once rather than once per key.
    let qualified_suffix = format!(".{}", unqualified);
    column_lookup_map
        .iter()
        .filter(|(key, _)| key.ends_with(&qualified_suffix) || key.as_str() == unqualified)
        .map(|(_, &idx)| idx)
        .min()
}
361
362fn get_or_insert_column_projection<P>(
364 projections: &mut Vec<ScanProjection>,
365 cache: &mut FxHashMap<FieldId, usize>,
366 table: &ExecutorTable<P>,
367 column: &ExecutorColumn,
368) -> usize
369where
370 P: Pager<Blob = EntryHandle> + Send + Sync,
371{
372 if let Some(existing) = cache.get(&column.field_id) {
373 return *existing;
374 }
375
376 let projection_index = projections.len();
377 let alias = if column.name.is_empty() {
378 format!("col{}", column.field_id)
379 } else {
380 column.name.clone()
381 };
382 projections.push(ScanProjection::from(StoreProjection::with_alias(
383 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
384 alias,
385 )));
386 cache.insert(column.field_id, projection_index);
387 projection_index
388}
389
390fn ensure_computed_projection<P>(
392 expr: &ScalarExpr<String>,
393 table: &ExecutorTable<P>,
394 projections: &mut Vec<ScanProjection>,
395 cache: &mut FxHashMap<String, (usize, DataType)>,
396 alias_counter: &mut usize,
397) -> ExecutorResult<(usize, DataType)>
398where
399 P: Pager<Blob = EntryHandle> + Send + Sync,
400{
401 let key = format!("{:?}", expr);
402 if let Some((idx, dtype)) = cache.get(&key) {
403 return Ok((*idx, dtype.clone()));
404 }
405
406 let translated = translate_scalar(expr, table.schema.as_ref(), |name| {
407 Error::InvalidArgumentError(format!("unknown column '{}' in aggregate expression", name))
408 })?;
409 let data_type = infer_computed_data_type(table.schema.as_ref(), &translated)?;
410 if data_type == DataType::Null {
411 tracing::debug!(
412 "ensure_computed_projection inferred Null type for expr: {:?}",
413 expr
414 );
415 }
416 let alias = format!("__agg_expr_{}", *alias_counter);
417 *alias_counter += 1;
418 let projection_index = projections.len();
419 projections.push(ScanProjection::computed(translated, alias));
420 cache.insert(key, (projection_index, data_type.clone()));
421 Ok((projection_index, data_type))
422}
423
/// Executes `SelectPlan`s against tables obtained from an `ExecutorTableProvider`.
pub struct QueryExecutor<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Source of table handles; tables are looked up by their qualified name.
    provider: Arc<dyn ExecutorTableProvider<P>>,
}
431
432impl<P> QueryExecutor<P>
433where
434 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
435{
436 pub fn new(provider: Arc<dyn ExecutorTableProvider<P>>) -> Self {
437 Self { provider }
438 }
439
440 pub fn execute_select(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
441 self.execute_select_with_filter(plan, None)
442 }
443
444 pub fn execute_select_with_filter(
445 &self,
446 plan: SelectPlan,
447 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
448 ) -> ExecutorResult<SelectExecution<P>> {
449 if plan.compound.is_some() {
450 return self.execute_compound_select(plan, row_filter);
451 }
452
453 if plan.tables.is_empty() {
455 return self.execute_select_without_table(plan);
456 }
457
458 if !plan.group_by.is_empty() {
459 if plan.tables.len() > 1 {
460 return self.execute_cross_product(plan);
461 }
462 let table_ref = &plan.tables[0];
463 let table = self.provider.get_table(&table_ref.qualified_name())?;
464 let display_name = table_ref.qualified_name();
465 return self.execute_group_by_single_table(table, display_name, plan, row_filter);
466 }
467
468 if plan.tables.len() > 1 {
470 return self.execute_cross_product(plan);
471 }
472
473 let table_ref = &plan.tables[0];
475 let table = self.provider.get_table(&table_ref.qualified_name())?;
476 let display_name = table_ref.qualified_name();
477
478 if !plan.aggregates.is_empty() {
479 self.execute_aggregates(table, display_name, plan, row_filter)
480 } else if self.has_computed_aggregates(&plan) {
481 self.execute_computed_aggregates(table, display_name, plan, row_filter)
483 } else {
484 self.execute_projection(table, display_name, plan, row_filter)
485 }
486 }
487
    /// Executes a compound SELECT (`UNION` / `EXCEPT` / `INTERSECT`).
    ///
    /// The initial plan runs first. Each following component then runs with the same
    /// `row_filter`, and its rows are merged into the accumulated rows from left to right.
    /// Supported combinations are `UNION ALL`, `UNION`, `EXCEPT` and `INTERSECT`. `EXCEPT ALL`
    /// and `INTERSECT ALL` return `InvalidArgumentError`. The compound's `ORDER BY` is applied
    /// once, to the final result, and only when it has rows.
    ///
    /// # Errors
    /// Propagates errors from:
    /// - executing any component;
    /// - `ensure_schema_compatibility` (mismatched component schemas);
    /// - row/batch conversion;
    /// - sorting.
    ///
    /// # Panics
    /// Panics if `plan.compound` is `None`; `execute_select_with_filter` only calls this
    /// after checking it.
    fn execute_compound_select(
        &self,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        let order_by = plan.order_by.clone();
        let compound = plan.compound.expect("compound plan should be present");

        let CompoundSelectPlan {
            initial,
            operations,
        } = compound;

        // The first component's schema becomes the schema of the whole result.
        let initial_exec = self.execute_select_with_filter(*initial, row_filter.clone())?;
        let schema = initial_exec.schema();
        let mut rows = initial_exec.into_rows()?;
        // Encoded keys of the rows currently in `rows` while they are known to be distinct.
        // `None` means "unknown" (e.g. after UNION ALL). `ensure_distinct_rows` (defined
        // elsewhere) is relied upon to de-duplicate `rows` and leave this `Some`, as the
        // `.expect` calls below assume.
        let mut distinct_cache: Option<FxHashSet<Vec<u8>>> = None;

        for component in operations {
            let exec = self.execute_select_with_filter(component.plan, row_filter.clone())?;
            let other_schema = exec.schema();
            ensure_schema_compatibility(schema.as_ref(), other_schema.as_ref())?;
            let other_rows = exec.into_rows()?;

            match (component.operator, component.quantifier) {
                (CompoundOperator::Union, CompoundQuantifier::All) => {
                    // Duplicates may now be present, so the distinct cache is invalidated.
                    rows.extend(other_rows);
                    distinct_cache = None;
                }
                (CompoundOperator::Union, CompoundQuantifier::Distinct) => {
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let cache = distinct_cache
                        .as_mut()
                        .expect("distinct cache should be initialized");
                    // Append only rows whose encoding has not been seen yet.
                    for row in other_rows {
                        let key = encode_row(&row);
                        if cache.insert(key) {
                            rows.push(row);
                        }
                    }
                }
                (CompoundOperator::Except, CompoundQuantifier::Distinct) => {
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let cache = distinct_cache
                        .as_mut()
                        .expect("distinct cache should be initialized");
                    if rows.is_empty() {
                        continue;
                    }
                    let mut remove_keys = FxHashSet::default();
                    for row in other_rows {
                        remove_keys.insert(encode_row(&row));
                    }
                    if remove_keys.is_empty() {
                        continue;
                    }
                    // Drop matching rows and keep the cache in sync with what remains.
                    rows.retain(|row| {
                        let key = encode_row(row);
                        if remove_keys.contains(&key) {
                            cache.remove(&key);
                            false
                        } else {
                            true
                        }
                    });
                }
                (CompoundOperator::Except, CompoundQuantifier::All) => {
                    return Err(Error::InvalidArgumentError(
                        "EXCEPT ALL is not supported yet".into(),
                    ));
                }
                (CompoundOperator::Intersect, CompoundQuantifier::Distinct) => {
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let mut right_keys = FxHashSet::default();
                    for row in other_rows {
                        right_keys.insert(encode_row(&row));
                    }
                    if right_keys.is_empty() {
                        // Intersecting with an empty set leaves nothing.
                        rows.clear();
                        distinct_cache = Some(FxHashSet::default());
                        continue;
                    }
                    // Keep rows present on both sides; rebuild the cache from the survivors.
                    let mut new_rows = Vec::new();
                    let mut new_cache = FxHashSet::default();
                    for row in rows.drain(..) {
                        let key = encode_row(&row);
                        if right_keys.contains(&key) && new_cache.insert(key) {
                            new_rows.push(row);
                        }
                    }
                    rows = new_rows;
                    distinct_cache = Some(new_cache);
                }
                (CompoundOperator::Intersect, CompoundQuantifier::All) => {
                    return Err(Error::InvalidArgumentError(
                        "INTERSECT ALL is not supported yet".into(),
                    ));
                }
            }
        }

        let mut batch = rows_to_record_batch(schema.clone(), &rows)?;
        if !order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            String::new(),
            schema,
            batch,
        ))
    }
619
620 fn has_computed_aggregates(&self, plan: &SelectPlan) -> bool {
622 plan.projections.iter().any(|proj| {
623 if let SelectProjection::Computed { expr, .. } = proj {
624 Self::expr_contains_aggregate(expr)
625 } else {
626 false
627 }
628 })
629 }
630
631 fn predicate_contains_aggregate(expr: &llkv_expr::expr::Expr<String>) -> bool {
633 match expr {
634 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
635 exprs.iter().any(Self::predicate_contains_aggregate)
636 }
637 llkv_expr::expr::Expr::Not(inner) => Self::predicate_contains_aggregate(inner),
638 llkv_expr::expr::Expr::Compare { left, right, .. } => {
639 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
640 }
641 llkv_expr::expr::Expr::InList { expr, list, .. } => {
642 Self::expr_contains_aggregate(expr)
643 || list.iter().any(|e| Self::expr_contains_aggregate(e))
644 }
645 llkv_expr::expr::Expr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
646 llkv_expr::expr::Expr::Literal(_) => false,
647 llkv_expr::expr::Expr::Pred(_) => false,
648 llkv_expr::expr::Expr::Exists(_) => false,
649 }
650 }
651
652 fn expr_contains_aggregate(expr: &ScalarExpr<String>) -> bool {
654 match expr {
655 ScalarExpr::Aggregate(_) => true,
656 ScalarExpr::Binary { left, right, .. } => {
657 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
658 }
659 ScalarExpr::Compare { left, right, .. } => {
660 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
661 }
662 ScalarExpr::GetField { base, .. } => Self::expr_contains_aggregate(base),
663 ScalarExpr::Cast { expr, .. } => Self::expr_contains_aggregate(expr),
664 ScalarExpr::Not(expr) => Self::expr_contains_aggregate(expr),
665 ScalarExpr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
666 ScalarExpr::Case {
667 operand,
668 branches,
669 else_expr,
670 } => {
671 operand
672 .as_deref()
673 .map(Self::expr_contains_aggregate)
674 .unwrap_or(false)
675 || branches.iter().any(|(when_expr, then_expr)| {
676 Self::expr_contains_aggregate(when_expr)
677 || Self::expr_contains_aggregate(then_expr)
678 })
679 || else_expr
680 .as_deref()
681 .map(Self::expr_contains_aggregate)
682 .unwrap_or(false)
683 }
684 ScalarExpr::Coalesce(items) => items.iter().any(Self::expr_contains_aggregate),
685 ScalarExpr::Column(_) | ScalarExpr::Literal(_) | ScalarExpr::Random => false,
686 ScalarExpr::ScalarSubquery(_) => false,
687 }
688 }
689
690 fn evaluate_exists_subquery(
691 &self,
692 context: &mut CrossProductExpressionContext,
693 subquery: &llkv_plan::FilterSubquery,
694 batch: &RecordBatch,
695 row_idx: usize,
696 ) -> ExecutorResult<bool> {
697 let bindings =
698 collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
699 let bound_plan = bind_select_plan(&subquery.plan, &bindings)?;
700 let execution = self.execute_select(bound_plan)?;
701 let mut found = false;
702 execution.stream(|inner_batch| {
703 if inner_batch.num_rows() > 0 {
704 found = true;
705 }
706 Ok(())
707 })?;
708 Ok(found)
709 }
710
711 fn evaluate_scalar_subquery_literal(
712 &self,
713 context: &mut CrossProductExpressionContext,
714 subquery: &llkv_plan::ScalarSubquery,
715 batch: &RecordBatch,
716 row_idx: usize,
717 ) -> ExecutorResult<Literal> {
718 let bindings =
719 collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
720 let bound_plan = bind_select_plan(&subquery.plan, &bindings)?;
721 let execution = self.execute_select(bound_plan)?;
722 let mut rows_seen: usize = 0;
723 let mut result: Option<Literal> = None;
724 execution.stream(|inner_batch| {
725 if inner_batch.num_columns() != 1 {
726 return Err(Error::InvalidArgumentError(
727 "scalar subquery must return exactly one column".into(),
728 ));
729 }
730 let column = inner_batch.column(0).clone();
731 for idx in 0..inner_batch.num_rows() {
732 if rows_seen >= 1 {
733 return Err(Error::InvalidArgumentError(
734 "scalar subquery produced more than one row".into(),
735 ));
736 }
737 rows_seen = rows_seen.saturating_add(1);
738 result = Some(array_value_to_literal(&column, idx)?);
739 }
740 Ok(())
741 })?;
742
743 if rows_seen == 0 {
744 Ok(Literal::Null)
745 } else {
746 result
747 .ok_or_else(|| Error::Internal("scalar subquery evaluation missing result".into()))
748 }
749 }
750
751 fn evaluate_scalar_subquery_numeric(
752 &self,
753 context: &mut CrossProductExpressionContext,
754 subquery: &llkv_plan::ScalarSubquery,
755 batch: &RecordBatch,
756 ) -> ExecutorResult<NumericArray> {
757 let mut values: Vec<Option<f64>> = Vec::with_capacity(batch.num_rows());
758 let mut all_integer = true;
759
760 for row_idx in 0..batch.num_rows() {
761 let literal =
762 self.evaluate_scalar_subquery_literal(context, subquery, batch, row_idx)?;
763 match literal {
764 Literal::Null => values.push(None),
765 Literal::Integer(value) => {
766 let cast = i64::try_from(value).map_err(|_| {
767 Error::InvalidArgumentError(
768 "scalar subquery integer result exceeds supported range".into(),
769 )
770 })?;
771 values.push(Some(cast as f64));
772 }
773 Literal::Float(value) => {
774 all_integer = false;
775 values.push(Some(value));
776 }
777 Literal::Boolean(flag) => {
778 let numeric = if flag { 1.0 } else { 0.0 };
779 values.push(Some(numeric));
780 }
781 Literal::String(_) | Literal::Struct(_) => {
782 return Err(Error::InvalidArgumentError(
783 "scalar subquery produced non-numeric result in numeric context".into(),
784 ));
785 }
786 }
787 }
788
789 if all_integer {
790 let iter = values.into_iter().map(|opt| opt.map(|v| v as i64));
791 let array = Int64Array::from_iter(iter);
792 NumericArray::try_from_arrow(&(Arc::new(array) as ArrayRef))
793 } else {
794 let array = Float64Array::from_iter(values);
795 NumericArray::try_from_arrow(&(Arc::new(array) as ArrayRef))
796 }
797 }
798
799 fn evaluate_projection_expression(
800 &self,
801 context: &mut CrossProductExpressionContext,
802 expr: &ScalarExpr<String>,
803 batch: &RecordBatch,
804 scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
805 ) -> ExecutorResult<ArrayRef> {
806 let translated = translate_scalar(expr, context.schema(), |name| {
807 Error::InvalidArgumentError(format!(
808 "column '{}' not found in cross product result",
809 name
810 ))
811 })?;
812
813 let mut subquery_ids: FxHashSet<SubqueryId> = FxHashSet::default();
814 collect_scalar_subquery_ids(&translated, &mut subquery_ids);
815
816 let mut mapping: FxHashMap<SubqueryId, FieldId> = FxHashMap::default();
817 for subquery_id in subquery_ids {
818 let info = scalar_lookup
819 .get(&subquery_id)
820 .ok_or_else(|| Error::Internal("missing scalar subquery metadata".into()))?;
821 let field_id = context.allocate_synthetic_field_id()?;
822 let numeric = self.evaluate_scalar_subquery_numeric(context, info, batch)?;
823 context.numeric_cache.insert(field_id, numeric);
824 mapping.insert(subquery_id, field_id);
825 }
826
827 let rewritten = rewrite_scalar_expr_for_subqueries(&translated, &mapping);
828 context.evaluate_numeric(&rewritten, batch)
829 }
830
831 fn execute_select_without_table(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
833 use arrow::array::ArrayRef;
834 use arrow::datatypes::Field;
835
836 let mut fields = Vec::new();
838 let mut arrays: Vec<ArrayRef> = Vec::new();
839
840 for proj in &plan.projections {
841 match proj {
842 SelectProjection::Computed { expr, alias } => {
843 let literal =
844 evaluate_constant_scalar_with_aggregates(expr).ok_or_else(|| {
845 Error::InvalidArgumentError(
846 "SELECT without FROM only supports constant expressions".into(),
847 )
848 })?;
849 let (dtype, array) = Self::literal_to_array(&literal)?;
850
851 fields.push(Field::new(alias.clone(), dtype, true));
852 arrays.push(array);
853 }
854 _ => {
855 return Err(Error::InvalidArgumentError(
856 "SELECT without FROM only supports computed projections".into(),
857 ));
858 }
859 }
860 }
861
862 let schema = Arc::new(Schema::new(fields));
863 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
864 .map_err(|e| Error::Internal(format!("failed to create record batch: {}", e)))?;
865
866 if plan.distinct {
867 let mut state = DistinctState::default();
868 batch = match distinct_filter_batch(batch, &mut state)? {
869 Some(filtered) => filtered,
870 None => RecordBatch::new_empty(Arc::clone(&schema)),
871 };
872 }
873
874 let schema = batch.schema();
875
876 Ok(SelectExecution::new_single_batch(
877 String::new(), schema,
879 batch,
880 ))
881 }
882
883 fn literal_to_array(lit: &llkv_expr::literal::Literal) -> ExecutorResult<(DataType, ArrayRef)> {
885 use arrow::array::{
886 ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, StructArray,
887 new_null_array,
888 };
889 use arrow::datatypes::{DataType, Field};
890 use llkv_expr::literal::Literal;
891
892 match lit {
893 Literal::Integer(v) => {
894 let val = i64::try_from(*v).unwrap_or(0);
895 Ok((
896 DataType::Int64,
897 Arc::new(Int64Array::from(vec![val])) as ArrayRef,
898 ))
899 }
900 Literal::Float(v) => Ok((
901 DataType::Float64,
902 Arc::new(Float64Array::from(vec![*v])) as ArrayRef,
903 )),
904 Literal::Boolean(v) => Ok((
905 DataType::Boolean,
906 Arc::new(BooleanArray::from(vec![*v])) as ArrayRef,
907 )),
908 Literal::String(v) => Ok((
909 DataType::Utf8,
910 Arc::new(StringArray::from(vec![v.clone()])) as ArrayRef,
911 )),
912 Literal::Null => Ok((DataType::Null, new_null_array(&DataType::Null, 1))),
913 Literal::Struct(struct_fields) => {
914 let mut inner_fields = Vec::new();
916 let mut inner_arrays = Vec::new();
917
918 for (field_name, field_lit) in struct_fields {
919 let (field_dtype, field_array) = Self::literal_to_array(field_lit)?;
920 inner_fields.push(Field::new(field_name.clone(), field_dtype, true));
921 inner_arrays.push(field_array);
922 }
923
924 let struct_array =
925 StructArray::try_new(inner_fields.clone().into(), inner_arrays, None).map_err(
926 |e| Error::Internal(format!("failed to create struct array: {}", e)),
927 )?;
928
929 Ok((
930 DataType::Struct(inner_fields.into()),
931 Arc::new(struct_array) as ArrayRef,
932 ))
933 }
934 }
935 }
936
937 fn execute_cross_product(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
939 use arrow::compute::concat_batches;
940
941 if plan.tables.len() < 2 {
942 return Err(Error::InvalidArgumentError(
943 "cross product requires at least 2 tables".into(),
944 ));
945 }
946
947 let mut tables_with_handles = Vec::with_capacity(plan.tables.len());
948 for table_ref in &plan.tables {
949 let qualified_name = table_ref.qualified_name();
950 let table = self.provider.get_table(&qualified_name)?;
951 tables_with_handles.push((table_ref.clone(), table));
952 }
953
954 let display_name = tables_with_handles
955 .iter()
956 .map(|(table_ref, _)| table_ref.qualified_name())
957 .collect::<Vec<_>>()
958 .join(",");
959
960 let mut remaining_filter = plan.filter.clone();
961
962 let join_data = if plan.scalar_subqueries.is_empty() && remaining_filter.as_ref().is_some()
964 {
965 self.try_execute_hash_join(&plan, &tables_with_handles)?
966 } else {
967 None
968 };
969
970 let current = if let Some((joined, handled_all_predicates)) = join_data {
971 if handled_all_predicates {
973 remaining_filter = None;
974 }
975 joined
976 } else {
977 let has_joins = !plan.joins.is_empty();
979
980 if has_joins && tables_with_handles.len() == 2 {
981 use llkv_join::{JoinOptions, TableJoinExt};
983
984 let (left_ref, left_table) = &tables_with_handles[0];
985 let (right_ref, right_table) = &tables_with_handles[1];
986
987 let join_metadata = plan.joins.first().ok_or_else(|| {
988 Error::InvalidArgumentError("expected join metadata for two-table join".into())
989 })?;
990
991 let join_type = match join_metadata.join_type {
992 llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
993 llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
994 llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
995 llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
996 };
997
998 tracing::debug!(
999 "Using llkv-join for {join_type:?} join between {} and {}",
1000 left_ref.qualified_name(),
1001 right_ref.qualified_name()
1002 );
1003
1004 let left_col_count = left_table.schema.columns.len();
1005 let right_col_count = right_table.schema.columns.len();
1006
1007 let mut combined_fields = Vec::with_capacity(left_col_count + right_col_count);
1008 for col in &left_table.schema.columns {
1009 combined_fields.push(Field::new(
1010 col.name.clone(),
1011 col.data_type.clone(),
1012 col.nullable,
1013 ));
1014 }
1015 for col in &right_table.schema.columns {
1016 combined_fields.push(Field::new(
1017 col.name.clone(),
1018 col.data_type.clone(),
1019 col.nullable,
1020 ));
1021 }
1022 let combined_schema = Arc::new(Schema::new(combined_fields));
1023 let column_counts = vec![left_col_count, right_col_count];
1024 let table_indices = vec![0, 1];
1025
1026 let mut join_keys = Vec::new();
1027 let mut condition_is_trivial = false;
1028 let mut condition_is_impossible = false;
1029
1030 if let Some(condition) = join_metadata.on_condition.as_ref() {
1031 let plan = build_join_keys_from_condition(
1032 condition,
1033 left_ref,
1034 left_table.as_ref(),
1035 right_ref,
1036 right_table.as_ref(),
1037 )?;
1038 join_keys = plan.keys;
1039 condition_is_trivial = plan.always_true;
1040 condition_is_impossible = plan.always_false;
1041 }
1042
1043 if condition_is_impossible {
1044 let batches = build_no_match_join_batches(
1045 join_type,
1046 left_ref,
1047 left_table.as_ref(),
1048 right_ref,
1049 right_table.as_ref(),
1050 Arc::clone(&combined_schema),
1051 )?;
1052
1053 TableCrossProductData {
1054 schema: combined_schema,
1055 batches,
1056 column_counts,
1057 table_indices,
1058 }
1059 } else {
1060 if !condition_is_trivial
1061 && join_metadata.on_condition.is_some()
1062 && join_keys.is_empty()
1063 {
1064 return Err(Error::InvalidArgumentError(
1065 "JOIN ON clause must include at least one equality predicate".into(),
1066 ));
1067 }
1068
1069 let mut result_batches = Vec::new();
1070 left_table.table.join_stream(
1071 &right_table.table,
1072 &join_keys,
1073 &JoinOptions {
1074 join_type,
1075 ..Default::default()
1076 },
1077 |batch| {
1078 result_batches.push(batch);
1079 },
1080 )?;
1081
1082 TableCrossProductData {
1083 schema: combined_schema,
1084 batches: result_batches,
1085 column_counts,
1086 table_indices,
1087 }
1088 }
1089 } else {
1090 let constraint_map = if let Some(filter_wrapper) = remaining_filter.as_ref() {
1092 extract_literal_pushdown_filters(
1093 &filter_wrapper.predicate,
1094 &tables_with_handles,
1095 )
1096 } else {
1097 vec![Vec::new(); tables_with_handles.len()]
1098 };
1099
1100 let mut staged: Vec<TableCrossProductData> =
1101 Vec::with_capacity(tables_with_handles.len());
1102 let join_lookup: FxHashMap<usize, &llkv_plan::JoinMetadata> = plan
1103 .joins
1104 .iter()
1105 .map(|join| (join.left_table_index, join))
1106 .collect();
1107
1108 let mut idx = 0usize;
1109 while idx < tables_with_handles.len() {
1110 if let Some(join_metadata) = join_lookup.get(&idx) {
1111 if idx + 1 >= tables_with_handles.len() {
1112 return Err(Error::Internal(
1113 "join metadata references table beyond FROM list".into(),
1114 ));
1115 }
1116
1117 let overlaps_next_join = join_lookup.contains_key(&(idx + 1));
1122 if let Some(condition) = join_metadata.on_condition.as_ref() {
1123 let (left_ref, left_table) = &tables_with_handles[idx];
1124 let (right_ref, right_table) = &tables_with_handles[idx + 1];
1125 let join_plan = build_join_keys_from_condition(
1126 condition,
1127 left_ref,
1128 left_table.as_ref(),
1129 right_ref,
1130 right_table.as_ref(),
1131 )?;
1132 if join_plan.always_false && !overlaps_next_join {
1133 let join_type = match join_metadata.join_type {
1134 llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
1135 llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
1136 llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
1137 llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
1138 };
1139
1140 let left_col_count = left_table.schema.columns.len();
1141 let right_col_count = right_table.schema.columns.len();
1142
1143 let mut combined_fields =
1144 Vec::with_capacity(left_col_count + right_col_count);
1145 for col in &left_table.schema.columns {
1146 combined_fields.push(Field::new(
1147 col.name.clone(),
1148 col.data_type.clone(),
1149 col.nullable,
1150 ));
1151 }
1152 for col in &right_table.schema.columns {
1153 combined_fields.push(Field::new(
1154 col.name.clone(),
1155 col.data_type.clone(),
1156 col.nullable,
1157 ));
1158 }
1159
1160 let combined_schema = Arc::new(Schema::new(combined_fields));
1161 let batches = build_no_match_join_batches(
1162 join_type,
1163 left_ref,
1164 left_table.as_ref(),
1165 right_ref,
1166 right_table.as_ref(),
1167 Arc::clone(&combined_schema),
1168 )?;
1169
1170 staged.push(TableCrossProductData {
1171 schema: combined_schema,
1172 batches,
1173 column_counts: vec![left_col_count, right_col_count],
1174 table_indices: vec![idx, idx + 1],
1175 });
1176 idx += 2;
1177 continue;
1178 }
1179 }
1180 }
1181
1182 let (table_ref, table) = &tables_with_handles[idx];
1183 let constraints = constraint_map.get(idx).map(|v| v.as_slice()).unwrap_or(&[]);
1184 staged.push(collect_table_data(
1185 idx,
1186 table_ref,
1187 table.as_ref(),
1188 constraints,
1189 )?);
1190 idx += 1;
1191 }
1192
1193 cross_join_all(staged)?
1194 }
1195 };
1196
1197 let TableCrossProductData {
1198 schema: combined_schema,
1199 batches: mut combined_batches,
1200 column_counts,
1201 table_indices,
1202 } = current;
1203
1204 let column_lookup_map = build_cross_product_column_lookup(
1205 combined_schema.as_ref(),
1206 &plan.tables,
1207 &column_counts,
1208 &table_indices,
1209 );
1210
1211 if let Some(filter_wrapper) = remaining_filter.as_ref() {
1212 let mut filter_context = CrossProductExpressionContext::new(
1213 combined_schema.as_ref(),
1214 column_lookup_map.clone(),
1215 )?;
1216 let translated_filter = translate_predicate(
1217 filter_wrapper.predicate.clone(),
1218 filter_context.schema(),
1219 |name| {
1220 Error::InvalidArgumentError(format!(
1221 "column '{}' not found in cross product result",
1222 name
1223 ))
1224 },
1225 )?;
1226
1227 let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
1228 filter_wrapper
1229 .subqueries
1230 .iter()
1231 .map(|subquery| (subquery.id, subquery))
1232 .collect();
1233
1234 let mut filtered_batches = Vec::with_capacity(combined_batches.len());
1235 for batch in combined_batches.into_iter() {
1236 filter_context.reset();
1237 let mask = filter_context.evaluate_predicate_mask(
1238 &translated_filter,
1239 &batch,
1240 |ctx, subquery_expr, row_idx, current_batch| {
1241 let subquery = subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
1242 Error::Internal("missing correlated subquery metadata".into())
1243 })?;
1244 let exists =
1245 self.evaluate_exists_subquery(ctx, subquery, current_batch, row_idx)?;
1246 let value = if subquery_expr.negated {
1247 !exists
1248 } else {
1249 exists
1250 };
1251 Ok(Some(value))
1252 },
1253 )?;
1254 let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
1255 Error::InvalidArgumentError(format!(
1256 "failed to apply cross product filter: {err}"
1257 ))
1258 })?;
1259 if filtered.num_rows() > 0 {
1260 filtered_batches.push(filtered);
1261 }
1262 }
1263 combined_batches = filtered_batches;
1264 }
1265
1266 if !plan.group_by.is_empty() {
1268 return self.execute_group_by_from_batches(
1269 display_name,
1270 plan,
1271 combined_schema,
1272 combined_batches,
1273 column_lookup_map,
1274 );
1275 }
1276
1277 if !plan.aggregates.is_empty() {
1278 return self.execute_cross_product_aggregates(
1279 Arc::clone(&combined_schema),
1280 combined_batches,
1281 &column_lookup_map,
1282 &plan,
1283 &display_name,
1284 );
1285 }
1286
1287 if self.has_computed_aggregates(&plan) {
1288 return self.execute_cross_product_computed_aggregates(
1289 Arc::clone(&combined_schema),
1290 combined_batches,
1291 &column_lookup_map,
1292 &plan,
1293 &display_name,
1294 );
1295 }
1296
1297 let mut combined_batch = if combined_batches.is_empty() {
1298 RecordBatch::new_empty(Arc::clone(&combined_schema))
1299 } else if combined_batches.len() == 1 {
1300 combined_batches.pop().unwrap()
1301 } else {
1302 concat_batches(&combined_schema, &combined_batches).map_err(|e| {
1303 Error::Internal(format!(
1304 "failed to concatenate cross product batches: {}",
1305 e
1306 ))
1307 })?
1308 };
1309
1310 let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
1311 .scalar_subqueries
1312 .iter()
1313 .map(|subquery| (subquery.id, subquery))
1314 .collect();
1315
1316 if !plan.projections.is_empty() {
1318 let mut selected_fields = Vec::new();
1319 let mut selected_columns = Vec::new();
1320 let mut expr_context: Option<CrossProductExpressionContext> = None;
1321
1322 for proj in &plan.projections {
1323 match proj {
1324 SelectProjection::AllColumns => {
1325 selected_fields = combined_schema.fields().iter().cloned().collect();
1327 selected_columns = combined_batch.columns().to_vec();
1328 break;
1329 }
1330 SelectProjection::AllColumnsExcept { exclude } => {
1331 let exclude_lower: Vec<String> =
1333 exclude.iter().map(|e| e.to_ascii_lowercase()).collect();
1334
1335 for (idx, field) in combined_schema.fields().iter().enumerate() {
1336 let field_name_lower = field.name().to_ascii_lowercase();
1337 if !exclude_lower.contains(&field_name_lower) {
1338 selected_fields.push(field.clone());
1339 selected_columns.push(combined_batch.column(idx).clone());
1340 }
1341 }
1342 break;
1343 }
1344 SelectProjection::Column { name, alias } => {
1345 let col_name = name.to_ascii_lowercase();
1347 if let Some(&idx) = column_lookup_map.get(&col_name) {
1348 let field = combined_schema.field(idx);
1349 let output_name = alias.as_ref().unwrap_or(name).clone();
1350 selected_fields.push(Arc::new(arrow::datatypes::Field::new(
1351 output_name,
1352 field.data_type().clone(),
1353 field.is_nullable(),
1354 )));
1355 selected_columns.push(combined_batch.column(idx).clone());
1356 } else {
1357 return Err(Error::InvalidArgumentError(format!(
1358 "column '{}' not found in cross product result",
1359 name
1360 )));
1361 }
1362 }
1363 SelectProjection::Computed { expr, alias } => {
1364 if expr_context.is_none() {
1365 expr_context = Some(CrossProductExpressionContext::new(
1366 combined_schema.as_ref(),
1367 column_lookup_map.clone(),
1368 )?);
1369 }
1370 let context = expr_context
1371 .as_mut()
1372 .expect("projection context must be initialized");
1373 context.reset();
1374 let evaluated = self.evaluate_projection_expression(
1375 context,
1376 expr,
1377 &combined_batch,
1378 &scalar_lookup,
1379 )?;
1380 let field = Arc::new(arrow::datatypes::Field::new(
1381 alias.clone(),
1382 evaluated.data_type().clone(),
1383 true,
1384 ));
1385 selected_fields.push(field);
1386 selected_columns.push(evaluated);
1387 }
1388 }
1389 }
1390
1391 let projected_schema = Arc::new(Schema::new(selected_fields));
1392 combined_batch = RecordBatch::try_new(projected_schema, selected_columns)
1393 .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))?;
1394 }
1395
1396 if plan.distinct {
1397 let mut state = DistinctState::default();
1398 let source_schema = combined_batch.schema();
1399 combined_batch = match distinct_filter_batch(combined_batch, &mut state)? {
1400 Some(filtered) => filtered,
1401 None => RecordBatch::new_empty(source_schema),
1402 };
1403 }
1404
1405 let schema = combined_batch.schema();
1406
1407 Ok(SelectExecution::new_single_batch(
1408 display_name,
1409 schema,
1410 combined_batch,
1411 ))
1412 }
1413}
1414
1415struct JoinKeyBuild {
1416 keys: Vec<llkv_join::JoinKey>,
1417 always_true: bool,
1418 always_false: bool,
1419}
1420
/// Classification of a JOIN ... ON predicate, produced by `analyze_join_condition`.
#[derive(Debug)]
enum JoinConditionAnalysis {
    /// The predicate holds for every row pair; no join keys are needed.
    AlwaysTrue,
    /// The predicate can never hold (constant FALSE, or a constant NULL/unknown result).
    AlwaysFalse,
    /// A conjunction of column equalities, as `(lhs column, rhs column)` name pairs in the
    /// order they appeared.
    EquiPairs(Vec<(String, String)>),
}
1427
1428fn build_join_keys_from_condition<P>(
1429 condition: &LlkvExpr<'static, String>,
1430 left_ref: &llkv_plan::TableRef,
1431 left_table: &ExecutorTable<P>,
1432 right_ref: &llkv_plan::TableRef,
1433 right_table: &ExecutorTable<P>,
1434) -> ExecutorResult<JoinKeyBuild>
1435where
1436 P: Pager<Blob = EntryHandle> + Send + Sync,
1437{
1438 match analyze_join_condition(condition)? {
1439 JoinConditionAnalysis::AlwaysTrue => Ok(JoinKeyBuild {
1440 keys: Vec::new(),
1441 always_true: true,
1442 always_false: false,
1443 }),
1444 JoinConditionAnalysis::AlwaysFalse => Ok(JoinKeyBuild {
1445 keys: Vec::new(),
1446 always_true: false,
1447 always_false: true,
1448 }),
1449 JoinConditionAnalysis::EquiPairs(pairs) => {
1450 let left_lookup = build_join_column_lookup(left_ref, left_table);
1451 let right_lookup = build_join_column_lookup(right_ref, right_table);
1452
1453 let mut keys = Vec::with_capacity(pairs.len());
1454 for (lhs, rhs) in pairs {
1455 let (lhs_side, lhs_field) = resolve_join_column(&lhs, &left_lookup, &right_lookup)?;
1456 let (rhs_side, rhs_field) = resolve_join_column(&rhs, &left_lookup, &right_lookup)?;
1457
1458 match (lhs_side, rhs_side) {
1459 (JoinColumnSide::Left, JoinColumnSide::Right) => {
1460 keys.push(llkv_join::JoinKey::new(lhs_field, rhs_field));
1461 }
1462 (JoinColumnSide::Right, JoinColumnSide::Left) => {
1463 keys.push(llkv_join::JoinKey::new(rhs_field, lhs_field));
1464 }
1465 (JoinColumnSide::Left, JoinColumnSide::Left) => {
1466 return Err(Error::InvalidArgumentError(format!(
1467 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1468 left_ref.display_name(),
1469 lhs,
1470 rhs
1471 )));
1472 }
1473 (JoinColumnSide::Right, JoinColumnSide::Right) => {
1474 return Err(Error::InvalidArgumentError(format!(
1475 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1476 right_ref.display_name(),
1477 lhs,
1478 rhs
1479 )));
1480 }
1481 }
1482 }
1483
1484 Ok(JoinKeyBuild {
1485 keys,
1486 always_true: false,
1487 always_false: false,
1488 })
1489 }
1490 }
1491}
1492
1493fn analyze_join_condition(
1494 expr: &LlkvExpr<'static, String>,
1495) -> ExecutorResult<JoinConditionAnalysis> {
1496 match evaluate_constant_join_expr(expr) {
1497 ConstantJoinEvaluation::Known(true) => {
1498 return Ok(JoinConditionAnalysis::AlwaysTrue);
1499 }
1500 ConstantJoinEvaluation::Known(false) | ConstantJoinEvaluation::Unknown => {
1501 return Ok(JoinConditionAnalysis::AlwaysFalse);
1502 }
1503 ConstantJoinEvaluation::NotConstant => {}
1504 }
1505 match expr {
1506 LlkvExpr::Literal(value) => {
1507 if *value {
1508 Ok(JoinConditionAnalysis::AlwaysTrue)
1509 } else {
1510 Ok(JoinConditionAnalysis::AlwaysFalse)
1511 }
1512 }
1513 LlkvExpr::And(children) => {
1514 let mut collected: Vec<(String, String)> = Vec::new();
1515 for child in children {
1516 match analyze_join_condition(child)? {
1517 JoinConditionAnalysis::AlwaysTrue => {}
1518 JoinConditionAnalysis::AlwaysFalse => {
1519 return Ok(JoinConditionAnalysis::AlwaysFalse);
1520 }
1521 JoinConditionAnalysis::EquiPairs(mut pairs) => {
1522 collected.append(&mut pairs);
1523 }
1524 }
1525 }
1526
1527 if collected.is_empty() {
1528 Ok(JoinConditionAnalysis::AlwaysTrue)
1529 } else {
1530 Ok(JoinConditionAnalysis::EquiPairs(collected))
1531 }
1532 }
1533 LlkvExpr::Compare { left, op, right } => {
1534 if *op != CompareOp::Eq {
1535 return Err(Error::InvalidArgumentError(
1536 "JOIN ON clause only supports '=' comparisons in optimized path".into(),
1537 ));
1538 }
1539 let left_name = try_extract_simple_column(left).ok_or_else(|| {
1540 Error::InvalidArgumentError(
1541 "JOIN ON clause requires plain column references".into(),
1542 )
1543 })?;
1544 let right_name = try_extract_simple_column(right).ok_or_else(|| {
1545 Error::InvalidArgumentError(
1546 "JOIN ON clause requires plain column references".into(),
1547 )
1548 })?;
1549 Ok(JoinConditionAnalysis::EquiPairs(vec![(
1550 left_name.to_string(),
1551 right_name.to_string(),
1552 )]))
1553 }
1554 _ => Err(Error::InvalidArgumentError(
1555 "JOIN ON expressions must be conjunctions of column equality predicates".into(),
1556 )),
1557 }
1558}
1559
1560fn compare_literals_with_mode(
1561 op: CompareOp,
1562 left: &Literal,
1563 right: &Literal,
1564 null_behavior: NullComparisonBehavior,
1565) -> Option<bool> {
1566 use std::cmp::Ordering;
1567
1568 fn ordering_result(ord: Ordering, op: CompareOp) -> bool {
1569 match op {
1570 CompareOp::Eq => ord == Ordering::Equal,
1571 CompareOp::NotEq => ord != Ordering::Equal,
1572 CompareOp::Lt => ord == Ordering::Less,
1573 CompareOp::LtEq => ord != Ordering::Greater,
1574 CompareOp::Gt => ord == Ordering::Greater,
1575 CompareOp::GtEq => ord != Ordering::Less,
1576 }
1577 }
1578
1579 fn compare_f64(lhs: f64, rhs: f64, op: CompareOp) -> bool {
1580 match op {
1581 CompareOp::Eq => lhs == rhs,
1582 CompareOp::NotEq => lhs != rhs,
1583 CompareOp::Lt => lhs < rhs,
1584 CompareOp::LtEq => lhs <= rhs,
1585 CompareOp::Gt => lhs > rhs,
1586 CompareOp::GtEq => lhs >= rhs,
1587 }
1588 }
1589
1590 match (left, right) {
1591 (Literal::Null, _) | (_, Literal::Null) => match null_behavior {
1592 NullComparisonBehavior::ThreeValuedLogic => None,
1593 },
1594 (Literal::Integer(lhs), Literal::Integer(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1595 (Literal::Float(lhs), Literal::Float(rhs)) => Some(compare_f64(*lhs, *rhs, op)),
1596 (Literal::Integer(lhs), Literal::Float(rhs)) => Some(compare_f64(*lhs as f64, *rhs, op)),
1597 (Literal::Float(lhs), Literal::Integer(rhs)) => Some(compare_f64(*lhs, *rhs as f64, op)),
1598 (Literal::Boolean(lhs), Literal::Boolean(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1599 (Literal::String(lhs), Literal::String(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1600 (Literal::Struct(_), _) | (_, Literal::Struct(_)) => None,
1601 _ => None,
1602 }
1603}
1604
1605fn build_no_match_join_batches<P>(
1606 join_type: llkv_join::JoinType,
1607 left_ref: &llkv_plan::TableRef,
1608 left_table: &ExecutorTable<P>,
1609 right_ref: &llkv_plan::TableRef,
1610 right_table: &ExecutorTable<P>,
1611 combined_schema: Arc<Schema>,
1612) -> ExecutorResult<Vec<RecordBatch>>
1613where
1614 P: Pager<Blob = EntryHandle> + Send + Sync,
1615{
1616 match join_type {
1617 llkv_join::JoinType::Inner => Ok(Vec::new()),
1618 llkv_join::JoinType::Left => {
1619 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
1620 let mut results = Vec::new();
1621
1622 for left_batch in left_batches {
1623 let row_count = left_batch.num_rows();
1624 if row_count == 0 {
1625 continue;
1626 }
1627
1628 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1629 columns.extend(left_batch.columns().iter().cloned());
1630 for column in &right_table.schema.columns {
1631 columns.push(new_null_array(&column.data_type, row_count));
1632 }
1633
1634 let batch =
1635 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1636 Error::Internal(format!("failed to build LEFT JOIN fallback batch: {err}"))
1637 })?;
1638 results.push(batch);
1639 }
1640
1641 Ok(results)
1642 }
1643 llkv_join::JoinType::Right => {
1644 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
1645 let mut results = Vec::new();
1646
1647 for right_batch in right_batches {
1648 let row_count = right_batch.num_rows();
1649 if row_count == 0 {
1650 continue;
1651 }
1652
1653 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1654 for column in &left_table.schema.columns {
1655 columns.push(new_null_array(&column.data_type, row_count));
1656 }
1657 columns.extend(right_batch.columns().iter().cloned());
1658
1659 let batch =
1660 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1661 Error::Internal(format!("failed to build RIGHT JOIN fallback batch: {err}"))
1662 })?;
1663 results.push(batch);
1664 }
1665
1666 Ok(results)
1667 }
1668 llkv_join::JoinType::Full => {
1669 let mut results = Vec::new();
1670
1671 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
1672 for left_batch in left_batches {
1673 let row_count = left_batch.num_rows();
1674 if row_count == 0 {
1675 continue;
1676 }
1677
1678 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1679 columns.extend(left_batch.columns().iter().cloned());
1680 for column in &right_table.schema.columns {
1681 columns.push(new_null_array(&column.data_type, row_count));
1682 }
1683
1684 let batch =
1685 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1686 Error::Internal(format!(
1687 "failed to build FULL JOIN left fallback batch: {err}"
1688 ))
1689 })?;
1690 results.push(batch);
1691 }
1692
1693 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
1694 for right_batch in right_batches {
1695 let row_count = right_batch.num_rows();
1696 if row_count == 0 {
1697 continue;
1698 }
1699
1700 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1701 for column in &left_table.schema.columns {
1702 columns.push(new_null_array(&column.data_type, row_count));
1703 }
1704 columns.extend(right_batch.columns().iter().cloned());
1705
1706 let batch =
1707 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1708 Error::Internal(format!(
1709 "failed to build FULL JOIN right fallback batch: {err}"
1710 ))
1711 })?;
1712 results.push(batch);
1713 }
1714
1715 Ok(results)
1716 }
1717 other => Err(Error::InvalidArgumentError(format!(
1718 "{other:?} join type is not supported when join predicate is unsatisfiable",
1719 ))),
1720 }
1721}
1722
1723fn scan_all_columns_for_join<P>(
1724 table_ref: &llkv_plan::TableRef,
1725 table: &ExecutorTable<P>,
1726) -> ExecutorResult<Vec<RecordBatch>>
1727where
1728 P: Pager<Blob = EntryHandle> + Send + Sync,
1729{
1730 if table.schema.columns.is_empty() {
1731 return Err(Error::InvalidArgumentError(format!(
1732 "table '{}' has no columns; joins require at least one column",
1733 table_ref.qualified_name()
1734 )));
1735 }
1736
1737 let mut projections = Vec::with_capacity(table.schema.columns.len());
1738 for column in &table.schema.columns {
1739 projections.push(ScanProjection::from(StoreProjection::with_alias(
1740 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
1741 column.name.clone(),
1742 )));
1743 }
1744
1745 let filter_field = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);
1746 let filter_expr = full_table_scan_filter(filter_field);
1747
1748 let mut batches = Vec::new();
1749 table.table.scan_stream(
1750 projections,
1751 &filter_expr,
1752 ScanStreamOptions {
1753 include_nulls: true,
1754 ..ScanStreamOptions::default()
1755 },
1756 |batch| {
1757 batches.push(batch);
1758 },
1759 )?;
1760
1761 Ok(batches)
1762}
1763
1764fn build_join_column_lookup<P>(
1765 table_ref: &llkv_plan::TableRef,
1766 table: &ExecutorTable<P>,
1767) -> FxHashMap<String, FieldId>
1768where
1769 P: Pager<Blob = EntryHandle> + Send + Sync,
1770{
1771 let mut lookup = FxHashMap::default();
1772 let table_lower = table_ref.table.to_ascii_lowercase();
1773 let qualified_lower = table_ref.qualified_name().to_ascii_lowercase();
1774 let display_lower = table_ref.display_name().to_ascii_lowercase();
1775 let alias_lower = table_ref.alias.as_ref().map(|s| s.to_ascii_lowercase());
1776 let schema_lower = if table_ref.schema.is_empty() {
1777 None
1778 } else {
1779 Some(table_ref.schema.to_ascii_lowercase())
1780 };
1781
1782 for column in &table.schema.columns {
1783 let base = column.name.to_ascii_lowercase();
1784 let short = base.rsplit('.').next().unwrap_or(base.as_str()).to_string();
1785
1786 lookup.entry(short.clone()).or_insert(column.field_id);
1787 lookup.entry(base.clone()).or_insert(column.field_id);
1788
1789 lookup
1790 .entry(format!("{table_lower}.{short}"))
1791 .or_insert(column.field_id);
1792
1793 if display_lower != table_lower {
1794 lookup
1795 .entry(format!("{display_lower}.{short}"))
1796 .or_insert(column.field_id);
1797 }
1798
1799 if qualified_lower != table_lower {
1800 lookup
1801 .entry(format!("{qualified_lower}.{short}"))
1802 .or_insert(column.field_id);
1803 }
1804
1805 if let Some(schema) = &schema_lower {
1806 lookup
1807 .entry(format!("{schema}.{table_lower}.{short}"))
1808 .or_insert(column.field_id);
1809 if display_lower != table_lower {
1810 lookup
1811 .entry(format!("{schema}.{display_lower}.{short}"))
1812 .or_insert(column.field_id);
1813 }
1814 }
1815
1816 if let Some(alias) = &alias_lower {
1817 lookup
1818 .entry(format!("{alias}.{short}"))
1819 .or_insert(column.field_id);
1820 }
1821 }
1822
1823 lookup
1824}
1825
/// Which input of a two-table join a resolved column name belongs to.
#[derive(Copy, Clone)]
enum JoinColumnSide {
    /// Found only in the left table's column lookup.
    Left,
    /// Found only in the right table's column lookup.
    Right,
}
1831
1832fn resolve_join_column(
1833 column: &str,
1834 left_lookup: &FxHashMap<String, FieldId>,
1835 right_lookup: &FxHashMap<String, FieldId>,
1836) -> ExecutorResult<(JoinColumnSide, FieldId)> {
1837 let key = column.to_ascii_lowercase();
1838 match (left_lookup.get(&key), right_lookup.get(&key)) {
1839 (Some(&field_id), None) => Ok((JoinColumnSide::Left, field_id)),
1840 (None, Some(&field_id)) => Ok((JoinColumnSide::Right, field_id)),
1841 (Some(_), Some(_)) => Err(Error::InvalidArgumentError(format!(
1842 "join column '{column}' is ambiguous; qualify it with a table name or alias",
1843 ))),
1844 (None, None) => Err(Error::InvalidArgumentError(format!(
1845 "join column '{column}' was not found in either table",
1846 ))),
1847 }
1848}
1849
#[cfg(test)]
mod join_condition_tests {
    use super::*;
    use llkv_expr::expr::{CompareOp, ScalarExpr};
    use llkv_expr::literal::Literal;

    /// Runs `analyze_join_condition`, failing the test if it reports an error.
    fn analyze(expr: &LlkvExpr<'static, String>) -> JoinConditionAnalysis {
        analyze_join_condition(expr).expect("analysis succeeds")
    }

    /// Builds a column-reference scalar expression.
    fn column(name: &str) -> ScalarExpr<String> {
        ScalarExpr::Column(name.to_string())
    }

    /// Asserts that `expr` is classified as a predicate that can never match.
    fn assert_always_false(expr: LlkvExpr<'static, String>) {
        assert!(matches!(analyze(&expr), JoinConditionAnalysis::AlwaysFalse));
    }

    #[test]
    fn analyze_detects_simple_equality() {
        let expr = LlkvExpr::Compare {
            left: column("t1.col"),
            op: CompareOp::Eq,
            right: column("t2.col"),
        };

        match analyze(&expr) {
            JoinConditionAnalysis::EquiPairs(pairs) => {
                let expected = vec![("t1.col".to_string(), "t2.col".to_string())];
                assert_eq!(pairs, expected);
            }
            other => panic!("unexpected analysis result: {other:?}"),
        }
    }

    #[test]
    fn analyze_handles_literal_true() {
        let analysis = analyze(&LlkvExpr::Literal(true));
        assert!(matches!(analysis, JoinConditionAnalysis::AlwaysTrue));
    }

    #[test]
    fn analyze_rejects_non_equality() {
        let expr = LlkvExpr::Compare {
            left: column("t1.col"),
            op: CompareOp::Gt,
            right: column("t2.col"),
        };
        assert!(analyze_join_condition(&expr).is_err());
    }

    #[test]
    fn analyze_handles_constant_is_not_null() {
        // NULL IS NOT NULL is constant false.
        assert_always_false(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Null),
            negated: true,
        });
    }

    #[test]
    fn analyze_handles_not_applied_to_is_not_null() {
        // NOT (86 IS NOT NULL) is constant false.
        assert_always_false(LlkvExpr::Not(Box::new(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Integer(86)),
            negated: true,
        })));
    }

    #[test]
    fn analyze_literal_is_null_is_always_false() {
        // 1 IS NULL is constant false.
        assert_always_false(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Integer(1)),
            negated: false,
        });
    }

    #[test]
    fn analyze_not_null_comparison_is_always_false() {
        // NOT (NULL < t2.col) is unknown, which an ON clause treats as false.
        assert_always_false(LlkvExpr::Not(Box::new(LlkvExpr::Compare {
            left: ScalarExpr::Literal(Literal::Null),
            op: CompareOp::Lt,
            right: column("t2.col"),
        })));
    }
}
1944
1945impl<P> QueryExecutor<P>
1946where
1947 P: Pager<Blob = EntryHandle> + Send + Sync,
1948{
1949 fn execute_cross_product_aggregates(
1950 &self,
1951 combined_schema: Arc<Schema>,
1952 batches: Vec<RecordBatch>,
1953 column_lookup_map: &FxHashMap<String, usize>,
1954 plan: &SelectPlan,
1955 display_name: &str,
1956 ) -> ExecutorResult<SelectExecution<P>> {
1957 if !plan.scalar_subqueries.is_empty() {
1958 return Err(Error::InvalidArgumentError(
1959 "scalar subqueries not supported in aggregate joins".into(),
1960 ));
1961 }
1962
1963 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
1964 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(plan.aggregates.len());
1965
1966 for aggregate in &plan.aggregates {
1967 match aggregate {
1968 AggregateExpr::CountStar { alias, distinct } => {
1969 specs.push(AggregateSpec {
1970 alias: alias.clone(),
1971 kind: AggregateKind::Count {
1972 field_id: None,
1973 distinct: *distinct,
1974 },
1975 });
1976 spec_to_projection.push(None);
1977 }
1978 AggregateExpr::Column {
1979 column,
1980 alias,
1981 function,
1982 distinct,
1983 } => {
1984 let key = column.to_ascii_lowercase();
1985 let column_index = *column_lookup_map.get(&key).ok_or_else(|| {
1986 Error::InvalidArgumentError(format!(
1987 "unknown column '{column}' in aggregate"
1988 ))
1989 })?;
1990 let field = combined_schema.field(column_index);
1991 let kind = match function {
1992 AggregateFunction::Count => AggregateKind::Count {
1993 field_id: Some(column_index as u32),
1994 distinct: *distinct,
1995 },
1996 AggregateFunction::SumInt64 => {
1997 let input_type = Self::validate_aggregate_type(
1998 Some(field.data_type().clone()),
1999 "SUM",
2000 &[DataType::Int64, DataType::Float64],
2001 )?;
2002 AggregateKind::Sum {
2003 field_id: column_index as u32,
2004 data_type: input_type,
2005 distinct: *distinct,
2006 }
2007 }
2008 AggregateFunction::TotalInt64 => {
2009 let input_type = Self::validate_aggregate_type(
2010 Some(field.data_type().clone()),
2011 "TOTAL",
2012 &[DataType::Int64, DataType::Float64],
2013 )?;
2014 AggregateKind::Total {
2015 field_id: column_index as u32,
2016 data_type: input_type,
2017 distinct: *distinct,
2018 }
2019 }
2020 AggregateFunction::MinInt64 => {
2021 let input_type = Self::validate_aggregate_type(
2022 Some(field.data_type().clone()),
2023 "MIN",
2024 &[DataType::Int64, DataType::Float64],
2025 )?;
2026 AggregateKind::Min {
2027 field_id: column_index as u32,
2028 data_type: input_type,
2029 }
2030 }
2031 AggregateFunction::MaxInt64 => {
2032 let input_type = Self::validate_aggregate_type(
2033 Some(field.data_type().clone()),
2034 "MAX",
2035 &[DataType::Int64, DataType::Float64],
2036 )?;
2037 AggregateKind::Max {
2038 field_id: column_index as u32,
2039 data_type: input_type,
2040 }
2041 }
2042 AggregateFunction::CountNulls => AggregateKind::CountNulls {
2043 field_id: column_index as u32,
2044 },
2045 AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
2046 field_id: column_index as u32,
2047 distinct: *distinct,
2048 separator: ",".to_string(),
2049 },
2050 };
2051
2052 specs.push(AggregateSpec {
2053 alias: alias.clone(),
2054 kind,
2055 });
2056 spec_to_projection.push(Some(column_index));
2057 }
2058 }
2059 }
2060
2061 if specs.is_empty() {
2062 return Err(Error::InvalidArgumentError(
2063 "aggregate query requires at least one aggregate expression".into(),
2064 ));
2065 }
2066
2067 let mut states = Vec::with_capacity(specs.len());
2068 for (idx, spec) in specs.iter().enumerate() {
2069 states.push(AggregateState {
2070 alias: spec.alias.clone(),
2071 accumulator: AggregateAccumulator::new_with_projection_index(
2072 spec,
2073 spec_to_projection[idx],
2074 None,
2075 )?,
2076 override_value: None,
2077 });
2078 }
2079
2080 for batch in &batches {
2081 for state in &mut states {
2082 state.update(batch)?;
2083 }
2084 }
2085
2086 let mut fields = Vec::with_capacity(states.len());
2087 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
2088 for state in states {
2089 let (field, array) = state.finalize()?;
2090 fields.push(Arc::new(field));
2091 arrays.push(array);
2092 }
2093
2094 let schema = Arc::new(Schema::new(fields));
2095 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
2096
2097 if plan.distinct {
2098 let mut distinct_state = DistinctState::default();
2099 batch = match distinct_filter_batch(batch, &mut distinct_state)? {
2100 Some(filtered) => filtered,
2101 None => RecordBatch::new_empty(Arc::clone(&schema)),
2102 };
2103 }
2104
2105 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
2106 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
2107 }
2108
2109 Ok(SelectExecution::new_single_batch(
2110 display_name.to_string(),
2111 schema,
2112 batch,
2113 ))
2114 }
2115
2116 fn execute_cross_product_computed_aggregates(
2117 &self,
2118 combined_schema: Arc<Schema>,
2119 batches: Vec<RecordBatch>,
2120 column_lookup_map: &FxHashMap<String, usize>,
2121 plan: &SelectPlan,
2122 display_name: &str,
2123 ) -> ExecutorResult<SelectExecution<P>> {
2124 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
2125 for projection in &plan.projections {
2126 match projection {
2127 SelectProjection::Computed { expr, .. } => {
2128 Self::collect_aggregates(expr, &mut aggregate_specs);
2129 }
2130 SelectProjection::AllColumns
2131 | SelectProjection::AllColumnsExcept { .. }
2132 | SelectProjection::Column { .. } => {
2133 return Err(Error::InvalidArgumentError(
2134 "non-computed projections not supported with aggregate expressions".into(),
2135 ));
2136 }
2137 }
2138 }
2139
2140 if aggregate_specs.is_empty() {
2141 return Err(Error::InvalidArgumentError(
2142 "computed aggregate query requires at least one aggregate expression".into(),
2143 ));
2144 }
2145
2146 let aggregate_values = self.compute_cross_product_aggregate_values(
2147 &combined_schema,
2148 &batches,
2149 column_lookup_map,
2150 &aggregate_specs,
2151 )?;
2152
2153 let mut fields = Vec::with_capacity(plan.projections.len());
2154 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
2155
2156 for projection in &plan.projections {
2157 if let SelectProjection::Computed { expr, alias } = projection {
2158 if let ScalarExpr::Aggregate(agg) = expr {
2160 let key = format!("{:?}", agg);
2161 if let Some(agg_value) = aggregate_values.get(&key) {
2162 match agg_value {
2163 AggregateValue::Null => {
2164 fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
2165 arrays.push(Arc::new(Int64Array::from(vec![None::<i64>])) as ArrayRef);
2166 }
2167 AggregateValue::Int64(v) => {
2168 fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
2169 arrays.push(Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef);
2170 }
2171 AggregateValue::Float64(v) => {
2172 fields.push(Arc::new(Field::new(alias, DataType::Float64, true)));
2173 arrays
2174 .push(Arc::new(Float64Array::from(vec![Some(*v)])) as ArrayRef);
2175 }
2176 AggregateValue::String(s) => {
2177 fields.push(Arc::new(Field::new(alias, DataType::Utf8, true)));
2178 arrays
2179 .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
2180 as ArrayRef);
2181 }
2182 }
2183 continue;
2184 }
2185 }
2186
2187 let value = Self::evaluate_expr_with_aggregates(expr, &aggregate_values)?;
2189 fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
2190 arrays.push(Arc::new(Int64Array::from(vec![value])) as ArrayRef);
2191 }
2192 }
2193
2194 let schema = Arc::new(Schema::new(fields));
2195 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
2196
2197 if plan.distinct {
2198 let mut distinct_state = DistinctState::default();
2199 batch = match distinct_filter_batch(batch, &mut distinct_state)? {
2200 Some(filtered) => filtered,
2201 None => RecordBatch::new_empty(Arc::clone(&schema)),
2202 };
2203 }
2204
2205 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
2206 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
2207 }
2208
2209 Ok(SelectExecution::new_single_batch(
2210 display_name.to_string(),
2211 schema,
2212 batch,
2213 ))
2214 }
2215
2216 fn compute_cross_product_aggregate_values(
2217 &self,
2218 combined_schema: &Arc<Schema>,
2219 batches: &[RecordBatch],
2220 column_lookup_map: &FxHashMap<String, usize>,
2221 aggregate_specs: &[(String, AggregateCall<String>)],
2222 ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
2223 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
2224 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());
2225
2226 let mut columns_per_batch: Option<Vec<Vec<ArrayRef>>> = None;
2227 let mut augmented_fields: Option<Vec<Field>> = None;
2228 let mut owned_batches: Option<Vec<RecordBatch>> = None;
2229 let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
2230 FxHashMap::default();
2231 let mut computed_alias_counter: usize = 0;
2232 let mut expr_context = CrossProductExpressionContext::new(
2233 combined_schema.as_ref(),
2234 column_lookup_map.clone(),
2235 )?;
2236
2237 let mut ensure_computed_column =
2238 |expr: &ScalarExpr<String>| -> ExecutorResult<(usize, DataType)> {
2239 let key = format!("{:?}", expr);
2240 if let Some((idx, dtype)) = computed_projection_cache.get(&key) {
2241 return Ok((*idx, dtype.clone()));
2242 }
2243
2244 if columns_per_batch.is_none() {
2245 let initial_columns: Vec<Vec<ArrayRef>> = batches
2246 .iter()
2247 .map(|batch| batch.columns().to_vec())
2248 .collect();
2249 columns_per_batch = Some(initial_columns);
2250 }
2251 if augmented_fields.is_none() {
2252 augmented_fields = Some(
2253 combined_schema
2254 .fields()
2255 .iter()
2256 .map(|field| field.as_ref().clone())
2257 .collect(),
2258 );
2259 }
2260
2261 let translated = translate_scalar(expr, expr_context.schema(), |name| {
2262 Error::InvalidArgumentError(format!(
2263 "unknown column '{}' in aggregate expression",
2264 name
2265 ))
2266 })?;
2267 let data_type = infer_computed_data_type(expr_context.schema(), &translated)?;
2268
2269 if let Some(columns) = columns_per_batch.as_mut() {
2270 for (batch_idx, batch) in batches.iter().enumerate() {
2271 expr_context.reset();
2272 let array = expr_context.materialize_scalar_array(&translated, batch)?;
2273 if let Some(batch_columns) = columns.get_mut(batch_idx) {
2274 batch_columns.push(array);
2275 }
2276 }
2277 }
2278
2279 let column_index = augmented_fields
2280 .as_ref()
2281 .map(|fields| fields.len())
2282 .unwrap_or_else(|| combined_schema.fields().len());
2283
2284 let alias = format!("__agg_expr_cp_{}", computed_alias_counter);
2285 computed_alias_counter += 1;
2286 augmented_fields
2287 .as_mut()
2288 .expect("augmented fields initialized")
2289 .push(Field::new(&alias, data_type.clone(), true));
2290
2291 computed_projection_cache.insert(key, (column_index, data_type.clone()));
2292 Ok((column_index, data_type))
2293 };
2294
2295 for (key, agg) in aggregate_specs {
2296 match agg {
2297 AggregateCall::CountStar => {
2298 specs.push(AggregateSpec {
2299 alias: key.clone(),
2300 kind: AggregateKind::Count {
2301 field_id: None,
2302 distinct: false,
2303 },
2304 });
2305 spec_to_projection.push(None);
2306 }
2307 AggregateCall::Count { expr, .. }
2308 | AggregateCall::Sum { expr, .. }
2309 | AggregateCall::Total { expr, .. }
2310 | AggregateCall::Avg { expr, .. }
2311 | AggregateCall::Min(expr)
2312 | AggregateCall::Max(expr)
2313 | AggregateCall::CountNulls(expr)
2314 | AggregateCall::GroupConcat { expr, .. } => {
2315 let (column_index, data_type_opt) = if let Some(column) =
2316 try_extract_simple_column(expr)
2317 {
2318 let key_lower = column.to_ascii_lowercase();
2319 let column_index = *column_lookup_map.get(&key_lower).ok_or_else(|| {
2320 Error::InvalidArgumentError(format!(
2321 "unknown column '{column}' in aggregate"
2322 ))
2323 })?;
2324 let field = combined_schema.field(column_index);
2325 (column_index, Some(field.data_type().clone()))
2326 } else {
2327 let (index, dtype) = ensure_computed_column(expr)?;
2328 (index, Some(dtype))
2329 };
2330
2331 let kind = match agg {
2332 AggregateCall::Count { distinct, .. } => {
2333 let field_id = u32::try_from(column_index).map_err(|_| {
2334 Error::InvalidArgumentError(
2335 "aggregate projection index exceeds supported range".into(),
2336 )
2337 })?;
2338 AggregateKind::Count {
2339 field_id: Some(field_id),
2340 distinct: *distinct,
2341 }
2342 }
2343 AggregateCall::Sum { distinct, .. } => {
2344 let input_type = Self::validate_aggregate_type(
2345 data_type_opt.clone(),
2346 "SUM",
2347 &[DataType::Int64, DataType::Float64],
2348 )?;
2349 let field_id = u32::try_from(column_index).map_err(|_| {
2350 Error::InvalidArgumentError(
2351 "aggregate projection index exceeds supported range".into(),
2352 )
2353 })?;
2354 AggregateKind::Sum {
2355 field_id,
2356 data_type: input_type,
2357 distinct: *distinct,
2358 }
2359 }
2360 AggregateCall::Total { distinct, .. } => {
2361 let input_type = Self::validate_aggregate_type(
2362 data_type_opt.clone(),
2363 "TOTAL",
2364 &[DataType::Int64, DataType::Float64],
2365 )?;
2366 let field_id = u32::try_from(column_index).map_err(|_| {
2367 Error::InvalidArgumentError(
2368 "aggregate projection index exceeds supported range".into(),
2369 )
2370 })?;
2371 AggregateKind::Total {
2372 field_id,
2373 data_type: input_type,
2374 distinct: *distinct,
2375 }
2376 }
2377 AggregateCall::Avg { distinct, .. } => {
2378 let input_type = Self::validate_aggregate_type(
2379 data_type_opt.clone(),
2380 "AVG",
2381 &[DataType::Int64, DataType::Float64],
2382 )?;
2383 let field_id = u32::try_from(column_index).map_err(|_| {
2384 Error::InvalidArgumentError(
2385 "aggregate projection index exceeds supported range".into(),
2386 )
2387 })?;
2388 AggregateKind::Avg {
2389 field_id,
2390 data_type: input_type,
2391 distinct: *distinct,
2392 }
2393 }
2394 AggregateCall::Min(_) => {
2395 let input_type = Self::validate_aggregate_type(
2396 data_type_opt.clone(),
2397 "MIN",
2398 &[DataType::Int64, DataType::Float64],
2399 )?;
2400 let field_id = u32::try_from(column_index).map_err(|_| {
2401 Error::InvalidArgumentError(
2402 "aggregate projection index exceeds supported range".into(),
2403 )
2404 })?;
2405 AggregateKind::Min {
2406 field_id,
2407 data_type: input_type,
2408 }
2409 }
2410 AggregateCall::Max(_) => {
2411 let input_type = Self::validate_aggregate_type(
2412 data_type_opt.clone(),
2413 "MAX",
2414 &[DataType::Int64, DataType::Float64],
2415 )?;
2416 let field_id = u32::try_from(column_index).map_err(|_| {
2417 Error::InvalidArgumentError(
2418 "aggregate projection index exceeds supported range".into(),
2419 )
2420 })?;
2421 AggregateKind::Max {
2422 field_id,
2423 data_type: input_type,
2424 }
2425 }
2426 AggregateCall::CountNulls(_) => {
2427 let field_id = u32::try_from(column_index).map_err(|_| {
2428 Error::InvalidArgumentError(
2429 "aggregate projection index exceeds supported range".into(),
2430 )
2431 })?;
2432 AggregateKind::CountNulls { field_id }
2433 }
2434 AggregateCall::GroupConcat {
2435 distinct,
2436 separator,
2437 ..
2438 } => {
2439 let field_id = u32::try_from(column_index).map_err(|_| {
2440 Error::InvalidArgumentError(
2441 "aggregate projection index exceeds supported range".into(),
2442 )
2443 })?;
2444 AggregateKind::GroupConcat {
2445 field_id,
2446 distinct: *distinct,
2447 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
2448 }
2449 }
2450 _ => unreachable!(),
2451 };
2452
2453 specs.push(AggregateSpec {
2454 alias: key.clone(),
2455 kind,
2456 });
2457 spec_to_projection.push(Some(column_index));
2458 }
2459 }
2460 }
2461
2462 if let Some(columns) = columns_per_batch {
2463 let fields = augmented_fields.unwrap_or_else(|| {
2464 combined_schema
2465 .fields()
2466 .iter()
2467 .map(|field| field.as_ref().clone())
2468 .collect()
2469 });
2470 let augmented_schema = Arc::new(Schema::new(fields));
2471 let mut new_batches = Vec::with_capacity(columns.len());
2472 for batch_columns in columns {
2473 let batch = RecordBatch::try_new(Arc::clone(&augmented_schema), batch_columns)
2474 .map_err(|err| {
2475 Error::InvalidArgumentError(format!(
2476 "failed to materialize aggregate projections: {err}"
2477 ))
2478 })?;
2479 new_batches.push(batch);
2480 }
2481 owned_batches = Some(new_batches);
2482 }
2483
2484 let mut states = Vec::with_capacity(specs.len());
2485 for (idx, spec) in specs.iter().enumerate() {
2486 states.push(AggregateState {
2487 alias: spec.alias.clone(),
2488 accumulator: AggregateAccumulator::new_with_projection_index(
2489 spec,
2490 spec_to_projection[idx],
2491 None,
2492 )?,
2493 override_value: None,
2494 });
2495 }
2496
2497 let batch_iter: &[RecordBatch] = if let Some(ref extended) = owned_batches {
2498 extended.as_slice()
2499 } else {
2500 batches
2501 };
2502
2503 for batch in batch_iter {
2504 for state in &mut states {
2505 state.update(batch)?;
2506 }
2507 }
2508
2509 let mut results = FxHashMap::default();
2510 for state in states {
2511 let (field, array) = state.finalize()?;
2512
2513 if let Some(int_array) = array.as_any().downcast_ref::<Int64Array>() {
2515 if int_array.len() != 1 {
2516 return Err(Error::Internal(format!(
2517 "Expected single value from aggregate, got {}",
2518 int_array.len()
2519 )));
2520 }
2521 let value = if int_array.is_null(0) {
2522 AggregateValue::Null
2523 } else {
2524 AggregateValue::Int64(int_array.value(0))
2525 };
2526 results.insert(field.name().to_string(), value);
2527 }
2528 else if let Some(float_array) = array.as_any().downcast_ref::<Float64Array>() {
2530 if float_array.len() != 1 {
2531 return Err(Error::Internal(format!(
2532 "Expected single value from aggregate, got {}",
2533 float_array.len()
2534 )));
2535 }
2536 let value = if float_array.is_null(0) {
2537 AggregateValue::Null
2538 } else {
2539 AggregateValue::Float64(float_array.value(0))
2540 };
2541 results.insert(field.name().to_string(), value);
2542 }
2543 else if let Some(string_array) = array.as_any().downcast_ref::<StringArray>() {
2545 if string_array.len() != 1 {
2546 return Err(Error::Internal(format!(
2547 "Expected single value from aggregate, got {}",
2548 string_array.len()
2549 )));
2550 }
2551 let value = if string_array.is_null(0) {
2552 AggregateValue::Null
2553 } else {
2554 AggregateValue::String(string_array.value(0).to_string())
2555 };
2556 results.insert(field.name().to_string(), value);
2557 } else {
2558 return Err(Error::Internal(format!(
2559 "Unexpected array type from aggregate: {:?}",
2560 array.data_type()
2561 )));
2562 }
2563 }
2564
2565 Ok(results)
2566 }
2567
2568 fn try_execute_hash_join(
2585 &self,
2586 plan: &SelectPlan,
2587 tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
2588 ) -> ExecutorResult<Option<(TableCrossProductData, bool)>> {
2589 let query_label_opt = current_query_label();
2590 let query_label = query_label_opt.as_deref().unwrap_or("<unknown query>");
2591
2592 let filter_wrapper = match &plan.filter {
2594 Some(filter) if filter.subqueries.is_empty() => filter,
2595 _ => {
2596 tracing::debug!(
2597 "join_opt[{query_label}]: skipping optimization – filter missing or uses subqueries"
2598 );
2599 return Ok(None);
2600 }
2601 };
2602
2603 if tables_with_handles.len() < 2 {
2604 tracing::debug!(
2605 "join_opt[{query_label}]: skipping optimization – requires at least 2 tables"
2606 );
2607 return Ok(None);
2608 }
2609
2610 let mut table_infos = Vec::with_capacity(tables_with_handles.len());
2612 for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
2613 let mut column_map = FxHashMap::default();
2614 for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
2615 let column_name = column.name.to_ascii_lowercase();
2616 column_map.entry(column_name).or_insert(column_idx);
2617 }
2618 table_infos.push(TableInfo {
2619 index,
2620 table_ref,
2621 column_map,
2622 });
2623 }
2624
2625 let constraint_plan = match extract_join_constraints(
2627 &filter_wrapper.predicate,
2628 &table_infos,
2629 ) {
2630 Some(plan) => plan,
2631 None => {
2632 tracing::debug!(
2633 "join_opt[{query_label}]: skipping optimization – predicate parsing failed (contains OR or other unsupported top-level structure)"
2634 );
2635 return Ok(None);
2636 }
2637 };
2638
2639 tracing::debug!(
2640 "join_opt[{query_label}]: constraint extraction succeeded - equalities={}, literals={}, handled={}/{} predicates",
2641 constraint_plan.equalities.len(),
2642 constraint_plan.literals.len(),
2643 constraint_plan.handled_conjuncts,
2644 constraint_plan.total_conjuncts
2645 );
2646 tracing::debug!(
2647 "join_opt[{query_label}]: attempting hash join with tables={:?} filter={:?}",
2648 plan.tables
2649 .iter()
2650 .map(|t| t.qualified_name())
2651 .collect::<Vec<_>>(),
2652 filter_wrapper.predicate,
2653 );
2654
2655 if constraint_plan.unsatisfiable {
2657 tracing::debug!(
2658 "join_opt[{query_label}]: predicate unsatisfiable – returning empty result"
2659 );
2660 let mut combined_fields = Vec::new();
2661 let mut column_counts = Vec::new();
2662 for (_table_ref, executor_table) in tables_with_handles {
2663 for column in &executor_table.schema.columns {
2664 combined_fields.push(Field::new(
2665 column.name.clone(),
2666 column.data_type.clone(),
2667 column.nullable,
2668 ));
2669 }
2670 column_counts.push(executor_table.schema.columns.len());
2671 }
2672 let combined_schema = Arc::new(Schema::new(combined_fields));
2673 let empty_batch = RecordBatch::new_empty(Arc::clone(&combined_schema));
2674 return Ok(Some((
2675 TableCrossProductData {
2676 schema: combined_schema,
2677 batches: vec![empty_batch],
2678 column_counts,
2679 table_indices: (0..tables_with_handles.len()).collect(),
2680 },
2681 true, )));
2683 }
2684
2685 if constraint_plan.equalities.is_empty() {
2687 tracing::debug!(
2688 "join_opt[{query_label}]: skipping optimization – no join equalities found"
2689 );
2690 return Ok(None);
2691 }
2692
2693 if !constraint_plan.literals.is_empty() {
2698 tracing::debug!(
2699 "join_opt[{query_label}]: found {} literal constraints - proceeding with hash join but may need fallback",
2700 constraint_plan.literals.len()
2701 );
2702 }
2703
2704 tracing::debug!(
2705 "join_opt[{query_label}]: hash join optimization applicable with {} equality constraints",
2706 constraint_plan.equalities.len()
2707 );
2708
2709 let mut literal_map: Vec<Vec<ColumnConstraint>> =
2710 vec![Vec::new(); tables_with_handles.len()];
2711 for constraint in &constraint_plan.literals {
2712 let table_idx = match constraint {
2713 ColumnConstraint::Equality(lit) => lit.column.table,
2714 ColumnConstraint::InList(in_list) => in_list.column.table,
2715 };
2716 if table_idx >= literal_map.len() {
2717 tracing::debug!(
2718 "join_opt[{query_label}]: constraint references unknown table index {}; falling back",
2719 table_idx
2720 );
2721 return Ok(None);
2722 }
2723 tracing::debug!(
2724 "join_opt[{query_label}]: mapping constraint to table_idx={} (table={})",
2725 table_idx,
2726 tables_with_handles[table_idx].0.qualified_name()
2727 );
2728 literal_map[table_idx].push(constraint.clone());
2729 }
2730
2731 let mut per_table: Vec<Option<TableCrossProductData>> =
2732 Vec::with_capacity(tables_with_handles.len());
2733 for (idx, (table_ref, table)) in tables_with_handles.iter().enumerate() {
2734 let data =
2735 collect_table_data(idx, table_ref, table.as_ref(), literal_map[idx].as_slice())?;
2736 per_table.push(Some(data));
2737 }
2738
2739 let has_left_join = plan
2741 .joins
2742 .iter()
2743 .any(|j| j.join_type == llkv_plan::JoinPlan::Left);
2744
2745 let mut current: Option<TableCrossProductData> = None;
2746
2747 if has_left_join {
2748 tracing::debug!(
2750 "join_opt[{query_label}]: delegating to llkv-join for LEFT JOIN support"
2751 );
2752 return Ok(None);
2754 } else {
2755 let mut remaining: Vec<usize> = (0..tables_with_handles.len()).collect();
2757 let mut used_tables: FxHashSet<usize> = FxHashSet::default();
2758
2759 while !remaining.is_empty() {
2760 let next_index = if used_tables.is_empty() {
2761 remaining[0]
2762 } else {
2763 match remaining.iter().copied().find(|idx| {
2764 table_has_join_with_used(*idx, &used_tables, &constraint_plan.equalities)
2765 }) {
2766 Some(idx) => idx,
2767 None => {
2768 tracing::debug!(
2769 "join_opt[{query_label}]: no remaining equality links – using cartesian expansion for table index {idx}",
2770 idx = remaining[0]
2771 );
2772 remaining[0]
2773 }
2774 }
2775 };
2776
2777 let position = remaining
2778 .iter()
2779 .position(|&idx| idx == next_index)
2780 .expect("next index present");
2781
2782 let next_data = per_table[next_index]
2783 .take()
2784 .ok_or_else(|| Error::Internal("hash join consumed table data twice".into()))?;
2785
2786 if let Some(current_data) = current.take() {
2787 let join_keys = gather_join_keys(
2788 ¤t_data,
2789 &next_data,
2790 &used_tables,
2791 next_index,
2792 &constraint_plan.equalities,
2793 )?;
2794
2795 let joined = if join_keys.is_empty() {
2796 tracing::debug!(
2797 "join_opt[{query_label}]: joining '{}' via cartesian expansion (no equality keys)",
2798 tables_with_handles[next_index].0.qualified_name()
2799 );
2800 cross_join_table_batches(current_data, next_data)?
2801 } else {
2802 hash_join_table_batches(
2803 current_data,
2804 next_data,
2805 &join_keys,
2806 llkv_join::JoinType::Inner,
2807 )?
2808 };
2809 current = Some(joined);
2810 } else {
2811 current = Some(next_data);
2812 }
2813
2814 used_tables.insert(next_index);
2815 remaining.remove(position);
2816 }
2817 }
2818
2819 if let Some(result) = current {
2820 let handled_all = constraint_plan.handled_conjuncts == constraint_plan.total_conjuncts;
2821 tracing::debug!(
2822 "join_opt[{query_label}]: hash join succeeded across {} tables (handled {}/{} predicates)",
2823 tables_with_handles.len(),
2824 constraint_plan.handled_conjuncts,
2825 constraint_plan.total_conjuncts
2826 );
2827 return Ok(Some((result, handled_all)));
2828 }
2829
2830 Ok(None)
2831 }
2832
2833 fn execute_projection(
2834 &self,
2835 table: Arc<ExecutorTable<P>>,
2836 display_name: String,
2837 plan: SelectPlan,
2838 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
2839 ) -> ExecutorResult<SelectExecution<P>> {
2840 if plan.having.is_some() {
2841 return Err(Error::InvalidArgumentError(
2842 "HAVING requires GROUP BY".into(),
2843 ));
2844 }
2845 if plan
2846 .filter
2847 .as_ref()
2848 .is_some_and(|filter| !filter.subqueries.is_empty())
2849 || !plan.scalar_subqueries.is_empty()
2850 {
2851 return self.execute_projection_with_subqueries(table, display_name, plan, row_filter);
2852 }
2853
2854 let table_ref = table.as_ref();
2855 let constant_filter = plan
2856 .filter
2857 .as_ref()
2858 .and_then(|filter| evaluate_constant_predicate(&filter.predicate));
2859 let projections = if plan.projections.is_empty() {
2860 build_wildcard_projections(table_ref)
2861 } else {
2862 build_projected_columns(table_ref, &plan.projections)?
2863 };
2864 let schema = schema_for_projections(table_ref, &projections)?;
2865
2866 if let Some(result) = constant_filter {
2867 match result {
2868 Some(true) => {
2869 }
2871 Some(false) | None => {
2872 let batch = RecordBatch::new_empty(Arc::clone(&schema));
2873 return Ok(SelectExecution::new_single_batch(
2874 display_name,
2875 schema,
2876 batch,
2877 ));
2878 }
2879 }
2880 }
2881
2882 let (mut filter_expr, mut full_table_scan) = match &plan.filter {
2883 Some(filter_wrapper) => (
2884 crate::translation::expression::translate_predicate(
2885 filter_wrapper.predicate.clone(),
2886 table_ref.schema.as_ref(),
2887 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
2888 )?,
2889 false,
2890 ),
2891 None => {
2892 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
2893 Error::InvalidArgumentError(
2894 "table has no columns; cannot perform wildcard scan".into(),
2895 )
2896 })?;
2897 (
2898 crate::translation::expression::full_table_scan_filter(field_id),
2899 true,
2900 )
2901 }
2902 };
2903
2904 if matches!(constant_filter, Some(Some(true))) {
2905 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
2906 Error::InvalidArgumentError(
2907 "table has no columns; cannot perform wildcard scan".into(),
2908 )
2909 })?;
2910 filter_expr = crate::translation::expression::full_table_scan_filter(field_id);
2911 full_table_scan = true;
2912 }
2913
2914 let expanded_order = expand_order_targets(&plan.order_by, &projections)?;
2915
2916 let mut physical_order: Option<ScanOrderSpec> = None;
2917
2918 if let Some(first) = expanded_order.first() {
2919 match &first.target {
2920 OrderTarget::Column(name) => {
2921 if table_ref.schema.resolve(name).is_some() {
2922 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
2923 }
2924 }
2925 OrderTarget::Index(position) => match projections.get(*position) {
2926 Some(ScanProjection::Column(_)) => {
2927 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
2928 }
2929 Some(ScanProjection::Computed { .. }) => {}
2930 None => {
2931 return Err(Error::InvalidArgumentError(format!(
2932 "ORDER BY position {} is out of range",
2933 position + 1
2934 )));
2935 }
2936 },
2937 OrderTarget::All => {}
2938 }
2939 }
2940
2941 let options = if let Some(order_spec) = physical_order {
2942 if row_filter.is_some() {
2943 tracing::debug!("Applying MVCC row filter with ORDER BY");
2944 }
2945 ScanStreamOptions {
2946 include_nulls: true,
2947 order: Some(order_spec),
2948 row_id_filter: row_filter.clone(),
2949 }
2950 } else {
2951 if row_filter.is_some() {
2952 tracing::debug!("Applying MVCC row filter");
2953 }
2954 ScanStreamOptions {
2955 include_nulls: true,
2956 order: None,
2957 row_id_filter: row_filter.clone(),
2958 }
2959 };
2960
2961 Ok(SelectExecution::new_projection(
2962 display_name,
2963 schema,
2964 table,
2965 projections,
2966 filter_expr,
2967 options,
2968 full_table_scan,
2969 expanded_order,
2970 plan.distinct,
2971 ))
2972 }
2973
2974 fn execute_projection_with_subqueries(
2975 &self,
2976 table: Arc<ExecutorTable<P>>,
2977 display_name: String,
2978 plan: SelectPlan,
2979 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
2980 ) -> ExecutorResult<SelectExecution<P>> {
2981 if plan.having.is_some() {
2982 return Err(Error::InvalidArgumentError(
2983 "HAVING requires GROUP BY".into(),
2984 ));
2985 }
2986 let table_ref = table.as_ref();
2987
2988 let (output_scan_projections, effective_projections): (
2989 Vec<ScanProjection>,
2990 Vec<SelectProjection>,
2991 ) = if plan.projections.is_empty() {
2992 (
2993 build_wildcard_projections(table_ref),
2994 vec![SelectProjection::AllColumns],
2995 )
2996 } else {
2997 (
2998 build_projected_columns(table_ref, &plan.projections)?,
2999 plan.projections.clone(),
3000 )
3001 };
3002
3003 let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
3004 .scalar_subqueries
3005 .iter()
3006 .map(|subquery| (subquery.id, subquery))
3007 .collect();
3008
3009 let base_projections = build_wildcard_projections(table_ref);
3010
3011 let filter_wrapper_opt = plan.filter.as_ref();
3012
3013 let mut translated_filter: Option<llkv_expr::expr::Expr<'static, FieldId>> = None;
3014 let pushdown_filter = if let Some(filter_wrapper) = filter_wrapper_opt {
3015 let translated = crate::translation::expression::translate_predicate(
3016 filter_wrapper.predicate.clone(),
3017 table_ref.schema.as_ref(),
3018 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
3019 )?;
3020 if !filter_wrapper.subqueries.is_empty() {
3021 translated_filter = Some(translated.clone());
3022 strip_exists(&translated)
3023 } else {
3024 translated
3025 }
3026 } else {
3027 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
3028 Error::InvalidArgumentError(
3029 "table has no columns; cannot perform scalar subquery projection".into(),
3030 )
3031 })?;
3032 crate::translation::expression::full_table_scan_filter(field_id)
3033 };
3034
3035 let mut base_fields: Vec<Field> = Vec::with_capacity(table_ref.schema.columns.len());
3036 for column in &table_ref.schema.columns {
3037 base_fields.push(Field::new(
3038 column.name.clone(),
3039 column.data_type.clone(),
3040 column.nullable,
3041 ));
3042 }
3043 let base_schema = Arc::new(Schema::new(base_fields));
3044 let base_column_counts = vec![base_schema.fields().len()];
3045 let base_table_indices = vec![0usize];
3046 let base_lookup = build_cross_product_column_lookup(
3047 base_schema.as_ref(),
3048 &plan.tables,
3049 &base_column_counts,
3050 &base_table_indices,
3051 );
3052
3053 let mut filter_context = if translated_filter.is_some() {
3054 Some(CrossProductExpressionContext::new(
3055 base_schema.as_ref(),
3056 base_lookup.clone(),
3057 )?)
3058 } else {
3059 None
3060 };
3061
3062 let options = ScanStreamOptions {
3063 include_nulls: true,
3064 order: None,
3065 row_id_filter: row_filter.clone(),
3066 };
3067
3068 let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
3069 filter_wrapper_opt
3070 .map(|wrapper| {
3071 wrapper
3072 .subqueries
3073 .iter()
3074 .map(|subquery| (subquery.id, subquery))
3075 .collect()
3076 })
3077 .unwrap_or_default();
3078
3079 let mut projected_batches: Vec<RecordBatch> = Vec::new();
3080 let mut scan_error: Option<Error> = None;
3081
3082 table.table.scan_stream(
3083 base_projections.clone(),
3084 &pushdown_filter,
3085 options,
3086 |batch| {
3087 if scan_error.is_some() {
3088 return;
3089 }
3090 let effective_batch = if let Some(context) = filter_context.as_mut() {
3091 context.reset();
3092 let translated = translated_filter
3093 .as_ref()
3094 .expect("filter context requires translated filter");
3095 let mask = match context.evaluate_predicate_mask(
3096 translated,
3097 &batch,
3098 |ctx, subquery_expr, row_idx, current_batch| {
3099 let subquery =
3100 subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
3101 Error::Internal("missing correlated subquery metadata".into())
3102 })?;
3103 let exists = self.evaluate_exists_subquery(
3104 ctx,
3105 subquery,
3106 current_batch,
3107 row_idx,
3108 )?;
3109 let value = if subquery_expr.negated {
3110 !exists
3111 } else {
3112 exists
3113 };
3114 Ok(Some(value))
3115 },
3116 ) {
3117 Ok(mask) => mask,
3118 Err(err) => {
3119 scan_error = Some(err);
3120 return;
3121 }
3122 };
3123 match filter_record_batch(&batch, &mask) {
3124 Ok(filtered) => {
3125 if filtered.num_rows() == 0 {
3126 return;
3127 }
3128 filtered
3129 }
3130 Err(err) => {
3131 scan_error = Some(Error::InvalidArgumentError(format!(
3132 "failed to apply EXISTS filter: {err}"
3133 )));
3134 return;
3135 }
3136 }
3137 } else {
3138 batch.clone()
3139 };
3140
3141 if effective_batch.num_rows() == 0 {
3142 return;
3143 }
3144
3145 let projected = match self.project_record_batch(
3146 &effective_batch,
3147 &effective_projections,
3148 &base_lookup,
3149 &scalar_lookup,
3150 ) {
3151 Ok(batch) => batch,
3152 Err(err) => {
3153 scan_error = Some(Error::InvalidArgumentError(format!(
3154 "failed to evaluate projections: {err}"
3155 )));
3156 return;
3157 }
3158 };
3159 projected_batches.push(projected);
3160 },
3161 )?;
3162
3163 if let Some(err) = scan_error {
3164 return Err(err);
3165 }
3166
3167 let mut result_batch = if projected_batches.is_empty() {
3168 let empty_batch = RecordBatch::new_empty(Arc::clone(&base_schema));
3169 self.project_record_batch(
3170 &empty_batch,
3171 &effective_projections,
3172 &base_lookup,
3173 &scalar_lookup,
3174 )?
3175 } else if projected_batches.len() == 1 {
3176 projected_batches.pop().unwrap()
3177 } else {
3178 let schema = projected_batches[0].schema();
3179 concat_batches(&schema, &projected_batches).map_err(|err| {
3180 Error::Internal(format!("failed to combine filtered batches: {err}"))
3181 })?
3182 };
3183
3184 if plan.distinct && result_batch.num_rows() > 0 {
3185 let mut state = DistinctState::default();
3186 let schema = result_batch.schema();
3187 result_batch = match distinct_filter_batch(result_batch, &mut state)? {
3188 Some(filtered) => filtered,
3189 None => RecordBatch::new_empty(schema),
3190 };
3191 }
3192
3193 if !plan.order_by.is_empty() && result_batch.num_rows() > 0 {
3194 let expanded_order = expand_order_targets(&plan.order_by, &output_scan_projections)?;
3195 if !expanded_order.is_empty() {
3196 result_batch = sort_record_batch_with_order(
3197 &result_batch.schema(),
3198 &result_batch,
3199 &expanded_order,
3200 )?;
3201 }
3202 }
3203
3204 let schema = result_batch.schema();
3205
3206 Ok(SelectExecution::new_single_batch(
3207 display_name,
3208 schema,
3209 result_batch,
3210 ))
3211 }
3212
3213 fn execute_group_by_single_table(
3214 &self,
3215 table: Arc<ExecutorTable<P>>,
3216 display_name: String,
3217 plan: SelectPlan,
3218 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
3219 ) -> ExecutorResult<SelectExecution<P>> {
3220 if plan
3221 .filter
3222 .as_ref()
3223 .is_some_and(|filter| !filter.subqueries.is_empty())
3224 || !plan.scalar_subqueries.is_empty()
3225 {
3226 return Err(Error::InvalidArgumentError(
3227 "GROUP BY with subqueries is not supported yet".into(),
3228 ));
3229 }
3230
3231 tracing::debug!(
3233 "[GROUP BY] Original plan: projections={}, aggregates={}, has_filter={}, has_having={}",
3234 plan.projections.len(),
3235 plan.aggregates.len(),
3236 plan.filter.is_some(),
3237 plan.having.is_some()
3238 );
3239
3240 let mut base_plan = plan.clone();
3244 base_plan.projections.clear();
3245 base_plan.aggregates.clear();
3246 base_plan.scalar_subqueries.clear();
3247 base_plan.order_by.clear();
3248 base_plan.distinct = false;
3249 base_plan.group_by.clear();
3250 base_plan.value_table_mode = None;
3251 base_plan.having = None;
3252
3253 tracing::debug!(
3254 "[GROUP BY] Base plan: projections={}, aggregates={}, has_filter={}, has_having={}",
3255 base_plan.projections.len(),
3256 base_plan.aggregates.len(),
3257 base_plan.filter.is_some(),
3258 base_plan.having.is_some()
3259 );
3260
3261 let table_ref = table.as_ref();
3264 let projections = build_wildcard_projections(table_ref);
3265 let base_schema = schema_for_projections(table_ref, &projections)?;
3266
3267 tracing::debug!(
3269 "[GROUP BY] Building base filter: has_filter={}",
3270 base_plan.filter.is_some()
3271 );
3272 let (filter_expr, full_table_scan) = match &base_plan.filter {
3273 Some(filter_wrapper) => {
3274 tracing::debug!(
3275 "[GROUP BY] Translating filter predicate: {:?}",
3276 filter_wrapper.predicate
3277 );
3278 let expr = crate::translation::expression::translate_predicate(
3279 filter_wrapper.predicate.clone(),
3280 table_ref.schema.as_ref(),
3281 |name| {
3282 Error::InvalidArgumentError(format!(
3283 "Binder Error: does not have a column named '{}'",
3284 name
3285 ))
3286 },
3287 )?;
3288 tracing::debug!("[GROUP BY] Translated filter expr: {:?}", expr);
3289 (expr, false)
3290 }
3291 None => {
3292 let first_col =
3294 table_ref.schema.columns.first().ok_or_else(|| {
3295 Error::InvalidArgumentError("Table has no columns".into())
3296 })?;
3297 (full_table_scan_filter(first_col.field_id), true)
3298 }
3299 };
3300
3301 let options = ScanStreamOptions {
3302 include_nulls: true,
3303 order: None,
3304 row_id_filter: row_filter.clone(),
3305 };
3306
3307 let execution = SelectExecution::new_projection(
3308 display_name.clone(),
3309 Arc::clone(&base_schema),
3310 Arc::clone(&table),
3311 projections,
3312 filter_expr,
3313 options,
3314 full_table_scan,
3315 vec![],
3316 false,
3317 );
3318
3319 let batches = execution.collect()?;
3320
3321 let column_lookup_map = build_column_lookup_map(base_schema.as_ref());
3322
3323 self.execute_group_by_from_batches(
3324 display_name,
3325 plan,
3326 base_schema,
3327 batches,
3328 column_lookup_map,
3329 )
3330 }
3331
3332 fn execute_group_by_from_batches(
3333 &self,
3334 display_name: String,
3335 plan: SelectPlan,
3336 base_schema: Arc<Schema>,
3337 batches: Vec<RecordBatch>,
3338 column_lookup_map: FxHashMap<String, usize>,
3339 ) -> ExecutorResult<SelectExecution<P>> {
3340 if plan
3341 .filter
3342 .as_ref()
3343 .is_some_and(|filter| !filter.subqueries.is_empty())
3344 || !plan.scalar_subqueries.is_empty()
3345 {
3346 return Err(Error::InvalidArgumentError(
3347 "GROUP BY with subqueries is not supported yet".into(),
3348 ));
3349 }
3350
3351 let having_has_aggregates = plan
3354 .having
3355 .as_ref()
3356 .map(|h| Self::predicate_contains_aggregate(h))
3357 .unwrap_or(false);
3358
3359 tracing::debug!(
3360 "[GROUP BY PATH] aggregates={}, has_computed={}, having_has_agg={}",
3361 plan.aggregates.len(),
3362 self.has_computed_aggregates(&plan),
3363 having_has_aggregates
3364 );
3365
3366 if !plan.aggregates.is_empty()
3367 || self.has_computed_aggregates(&plan)
3368 || having_has_aggregates
3369 {
3370 tracing::debug!("[GROUP BY PATH] Taking aggregates path");
3371 return self.execute_group_by_with_aggregates(
3372 display_name,
3373 plan,
3374 base_schema,
3375 batches,
3376 column_lookup_map,
3377 );
3378 }
3379
3380 let mut key_indices = Vec::with_capacity(plan.group_by.len());
3381 for column in &plan.group_by {
3382 let key = column.to_ascii_lowercase();
3383 let index = column_lookup_map.get(&key).ok_or_else(|| {
3384 Error::InvalidArgumentError(format!(
3385 "column '{}' not found in GROUP BY input",
3386 column
3387 ))
3388 })?;
3389 key_indices.push(*index);
3390 }
3391
3392 let sample_batch = batches
3393 .first()
3394 .cloned()
3395 .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(&base_schema)));
3396
3397 let output_columns = self.build_group_by_output_columns(
3398 &plan,
3399 base_schema.as_ref(),
3400 &column_lookup_map,
3401 &sample_batch,
3402 )?;
3403
3404 let constant_having = plan.having.as_ref().and_then(evaluate_constant_predicate);
3405
3406 if let Some(result) = constant_having
3407 && !result.unwrap_or(false)
3408 {
3409 let fields: Vec<Field> = output_columns
3410 .iter()
3411 .map(|output| output.field.clone())
3412 .collect();
3413 let schema = Arc::new(Schema::new(fields));
3414 let batch = RecordBatch::new_empty(Arc::clone(&schema));
3415 return Ok(SelectExecution::new_single_batch(
3416 display_name,
3417 schema,
3418 batch,
3419 ));
3420 }
3421
3422 let translated_having = if plan.having.is_some() && constant_having.is_none() {
3423 let having = plan.having.clone().expect("checked above");
3424 if Self::predicate_contains_aggregate(&having) {
3427 None
3428 } else {
3429 let temp_context = CrossProductExpressionContext::new(
3430 base_schema.as_ref(),
3431 column_lookup_map.clone(),
3432 )?;
3433 Some(translate_predicate(
3434 having,
3435 temp_context.schema(),
3436 |name| {
3437 Error::InvalidArgumentError(format!(
3438 "column '{}' not found in GROUP BY result",
3439 name
3440 ))
3441 },
3442 )?)
3443 }
3444 } else {
3445 None
3446 };
3447
3448 let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
3449 let mut groups: Vec<GroupState> = Vec::new();
3450
3451 for batch in &batches {
3452 for row_idx in 0..batch.num_rows() {
3453 let key = build_group_key(batch, row_idx, &key_indices)?;
3454 if group_index.contains_key(&key) {
3455 continue;
3456 }
3457 group_index.insert(key, groups.len());
3458 groups.push(GroupState {
3459 batch: batch.clone(),
3460 row_idx,
3461 });
3462 }
3463 }
3464
3465 let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(groups.len());
3466
3467 for group in &groups {
3468 if let Some(predicate) = translated_having.as_ref() {
3469 let mut context = CrossProductExpressionContext::new(
3470 group.batch.schema().as_ref(),
3471 column_lookup_map.clone(),
3472 )?;
3473 context.reset();
3474 let mut eval = |_ctx: &mut CrossProductExpressionContext,
3475 _subquery_expr: &llkv_expr::SubqueryExpr,
3476 _row_idx: usize,
3477 _current_batch: &RecordBatch|
3478 -> ExecutorResult<Option<bool>> {
3479 Err(Error::InvalidArgumentError(
3480 "HAVING subqueries are not supported yet".into(),
3481 ))
3482 };
3483 let truths =
3484 context.evaluate_predicate_truths(predicate, &group.batch, &mut eval)?;
3485 let passes = truths
3486 .get(group.row_idx)
3487 .copied()
3488 .flatten()
3489 .unwrap_or(false);
3490 if !passes {
3491 continue;
3492 }
3493 }
3494
3495 let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
3496 for output in &output_columns {
3497 match output.source {
3498 OutputSource::TableColumn { index } => {
3499 let value = llkv_plan::plan_value_from_array(
3500 group.batch.column(index),
3501 group.row_idx,
3502 )?;
3503 row.push(value);
3504 }
3505 OutputSource::Computed { projection_index } => {
3506 let expr = match &plan.projections[projection_index] {
3507 SelectProjection::Computed { expr, .. } => expr,
3508 _ => unreachable!("projection index mismatch for computed column"),
3509 };
3510 let mut context = CrossProductExpressionContext::new(
3511 group.batch.schema().as_ref(),
3512 column_lookup_map.clone(),
3513 )?;
3514 context.reset();
3515 let evaluated = self.evaluate_projection_expression(
3516 &mut context,
3517 expr,
3518 &group.batch,
3519 &FxHashMap::default(),
3520 )?;
3521 let value = llkv_plan::plan_value_from_array(&evaluated, group.row_idx)?;
3522 row.push(value);
3523 }
3524 }
3525 }
3526 rows.push(row);
3527 }
3528
3529 let fields: Vec<Field> = output_columns
3530 .into_iter()
3531 .map(|output| output.field)
3532 .collect();
3533 let schema = Arc::new(Schema::new(fields));
3534
3535 let mut batch = rows_to_record_batch(Arc::clone(&schema), &rows)?;
3536
3537 if plan.distinct && batch.num_rows() > 0 {
3538 let mut state = DistinctState::default();
3539 batch = match distinct_filter_batch(batch, &mut state)? {
3540 Some(filtered) => filtered,
3541 None => RecordBatch::new_empty(Arc::clone(&schema)),
3542 };
3543 }
3544
3545 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
3546 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
3547 }
3548
3549 Ok(SelectExecution::new_single_batch(
3550 display_name,
3551 schema,
3552 batch,
3553 ))
3554 }
3555
3556 fn build_group_by_output_columns(
3557 &self,
3558 plan: &SelectPlan,
3559 base_schema: &Schema,
3560 column_lookup_map: &FxHashMap<String, usize>,
3561 _sample_batch: &RecordBatch,
3562 ) -> ExecutorResult<Vec<OutputColumn>> {
3563 let projections = if plan.projections.is_empty() {
3564 vec![SelectProjection::AllColumns]
3565 } else {
3566 plan.projections.clone()
3567 };
3568
3569 let mut columns: Vec<OutputColumn> = Vec::new();
3570
3571 for (proj_idx, projection) in projections.iter().enumerate() {
3572 match projection {
3573 SelectProjection::AllColumns => {
3574 for (index, field) in base_schema.fields().iter().enumerate() {
3575 columns.push(OutputColumn {
3576 field: (**field).clone(),
3577 source: OutputSource::TableColumn { index },
3578 });
3579 }
3580 }
3581 SelectProjection::AllColumnsExcept { exclude } => {
3582 let exclude_lower: FxHashSet<String> = exclude
3583 .iter()
3584 .map(|name| name.to_ascii_lowercase())
3585 .collect();
3586 for (index, field) in base_schema.fields().iter().enumerate() {
3587 if !exclude_lower.contains(&field.name().to_ascii_lowercase()) {
3588 columns.push(OutputColumn {
3589 field: (**field).clone(),
3590 source: OutputSource::TableColumn { index },
3591 });
3592 }
3593 }
3594 }
3595 SelectProjection::Column { name, alias } => {
3596 let lookup_key = name.to_ascii_lowercase();
3597 let index = column_lookup_map.get(&lookup_key).ok_or_else(|| {
3598 Error::InvalidArgumentError(format!(
3599 "column '{}' not found in GROUP BY result",
3600 name
3601 ))
3602 })?;
3603 let field = base_schema.field(*index);
3604 let field = Field::new(
3605 alias.as_ref().unwrap_or(name).clone(),
3606 field.data_type().clone(),
3607 field.is_nullable(),
3608 );
3609 columns.push(OutputColumn {
3610 field,
3611 source: OutputSource::TableColumn { index: *index },
3612 });
3613 }
3614 SelectProjection::Computed { expr: _, alias } => {
3615 let field = Field::new(alias.clone(), DataType::Float64, true);
3619 columns.push(OutputColumn {
3620 field,
3621 source: OutputSource::Computed {
3622 projection_index: proj_idx,
3623 },
3624 });
3625 }
3626 }
3627 }
3628
3629 if columns.is_empty() {
3630 for (index, field) in base_schema.fields().iter().enumerate() {
3631 columns.push(OutputColumn {
3632 field: (**field).clone(),
3633 source: OutputSource::TableColumn { index },
3634 });
3635 }
3636 }
3637
3638 Ok(columns)
3639 }
3640
3641 fn project_record_batch(
3642 &self,
3643 batch: &RecordBatch,
3644 projections: &[SelectProjection],
3645 lookup: &FxHashMap<String, usize>,
3646 scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
3647 ) -> ExecutorResult<RecordBatch> {
3648 if projections.is_empty() {
3649 return Ok(batch.clone());
3650 }
3651
3652 let schema = batch.schema();
3653 let mut selected_fields: Vec<Arc<Field>> = Vec::new();
3654 let mut selected_columns: Vec<ArrayRef> = Vec::new();
3655 let mut expr_context: Option<CrossProductExpressionContext> = None;
3656
3657 for proj in projections {
3658 match proj {
3659 SelectProjection::AllColumns => {
3660 selected_fields = schema.fields().iter().cloned().collect();
3661 selected_columns = batch.columns().to_vec();
3662 break;
3663 }
3664 SelectProjection::AllColumnsExcept { exclude } => {
3665 let exclude_lower: FxHashSet<String> = exclude
3666 .iter()
3667 .map(|name| name.to_ascii_lowercase())
3668 .collect();
3669 for (idx, field) in schema.fields().iter().enumerate() {
3670 let column_name = field.name().to_ascii_lowercase();
3671 if !exclude_lower.contains(&column_name) {
3672 selected_fields.push(Arc::clone(field));
3673 selected_columns.push(batch.column(idx).clone());
3674 }
3675 }
3676 break;
3677 }
3678 SelectProjection::Column { name, alias } => {
3679 let normalized = name.to_ascii_lowercase();
3680 let column_index = lookup.get(&normalized).ok_or_else(|| {
3681 Error::InvalidArgumentError(format!(
3682 "column '{}' not found in projection",
3683 name
3684 ))
3685 })?;
3686 let field = schema.field(*column_index);
3687 let output_field = Arc::new(Field::new(
3688 alias.as_ref().unwrap_or_else(|| field.name()),
3689 field.data_type().clone(),
3690 field.is_nullable(),
3691 ));
3692 selected_fields.push(output_field);
3693 selected_columns.push(batch.column(*column_index).clone());
3694 }
3695 SelectProjection::Computed { expr, alias } => {
3696 if expr_context.is_none() {
3697 expr_context = Some(CrossProductExpressionContext::new(
3698 schema.as_ref(),
3699 lookup.clone(),
3700 )?);
3701 }
3702 let context = expr_context
3703 .as_mut()
3704 .expect("projection context must be initialized");
3705 context.reset();
3706 let evaluated =
3707 self.evaluate_projection_expression(context, expr, batch, scalar_lookup)?;
3708 let field = Arc::new(Field::new(
3709 alias.clone(),
3710 evaluated.data_type().clone(),
3711 true,
3712 ));
3713 selected_fields.push(field);
3714 selected_columns.push(evaluated);
3715 }
3716 }
3717 }
3718
3719 let projected_schema = Arc::new(Schema::new(selected_fields));
3720 RecordBatch::try_new(projected_schema, selected_columns)
3721 .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))
3722 }
3723
3724 fn execute_group_by_with_aggregates(
3726 &self,
3727 display_name: String,
3728 plan: SelectPlan,
3729 base_schema: Arc<Schema>,
3730 batches: Vec<RecordBatch>,
3731 column_lookup_map: FxHashMap<String, usize>,
3732 ) -> ExecutorResult<SelectExecution<P>> {
3733 use llkv_expr::expr::AggregateCall;
3734
3735 let mut key_indices = Vec::with_capacity(plan.group_by.len());
3737 for column in &plan.group_by {
3738 let key = column.to_ascii_lowercase();
3739 let index = column_lookup_map.get(&key).ok_or_else(|| {
3740 Error::InvalidArgumentError(format!(
3741 "column '{}' not found in GROUP BY input",
3742 column
3743 ))
3744 })?;
3745 key_indices.push(*index);
3746 }
3747
3748 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
3750 for proj in &plan.projections {
3751 if let SelectProjection::Computed { expr, .. } = proj {
3752 Self::collect_aggregates(expr, &mut aggregate_specs);
3753 }
3754 }
3755
3756 if let Some(having_expr) = &plan.having {
3758 Self::collect_aggregates_from_predicate(having_expr, &mut aggregate_specs);
3759 }
3760
3761 let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
3763 let mut group_states: Vec<GroupAggregateState> = Vec::new();
3764
3765 for (batch_idx, batch) in batches.iter().enumerate() {
3767 for row_idx in 0..batch.num_rows() {
3768 let key = build_group_key(batch, row_idx, &key_indices)?;
3769
3770 if let Some(&group_idx) = group_index.get(&key) {
3771 group_states[group_idx]
3773 .row_locations
3774 .push((batch_idx, row_idx));
3775 } else {
3776 let group_idx = group_states.len();
3778 group_index.insert(key, group_idx);
3779 group_states.push(GroupAggregateState {
3780 representative_batch_idx: batch_idx,
3781 representative_row: row_idx,
3782 row_locations: vec![(batch_idx, row_idx)],
3783 });
3784 }
3785 }
3786 }
3787
3788 let mut group_aggregate_values: Vec<FxHashMap<String, PlanValue>> =
3790 Vec::with_capacity(group_states.len());
3791
3792 for group_state in &group_states {
3793 tracing::debug!(
3794 "[GROUP BY] aggregate group rows={:?}",
3795 group_state.row_locations
3796 );
3797 let group_batch = {
3799 let representative_batch = &batches[group_state.representative_batch_idx];
3800 let schema = representative_batch.schema();
3801
3802 let mut per_batch_indices: Vec<(usize, Vec<u64>)> = Vec::new();
3804 for &(batch_idx, row_idx) in &group_state.row_locations {
3805 if let Some((_, indices)) = per_batch_indices
3806 .iter_mut()
3807 .find(|(idx, _)| *idx == batch_idx)
3808 {
3809 indices.push(row_idx as u64);
3810 } else {
3811 per_batch_indices.push((batch_idx, vec![row_idx as u64]));
3812 }
3813 }
3814
3815 let mut row_index_arrays: Vec<(usize, ArrayRef)> =
3816 Vec::with_capacity(per_batch_indices.len());
3817 for (batch_idx, indices) in per_batch_indices {
3818 let index_array: ArrayRef = Arc::new(arrow::array::UInt64Array::from(indices));
3819 row_index_arrays.push((batch_idx, index_array));
3820 }
3821
3822 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
3823
3824 for col_idx in 0..schema.fields().len() {
3825 let column_array = if row_index_arrays.len() == 1 {
3826 let (batch_idx, indices) = &row_index_arrays[0];
3827 let source_array = batches[*batch_idx].column(col_idx);
3828 arrow::compute::take(source_array.as_ref(), indices.as_ref(), None)?
3829 } else {
3830 let mut partial_arrays: Vec<ArrayRef> =
3831 Vec::with_capacity(row_index_arrays.len());
3832 for (batch_idx, indices) in &row_index_arrays {
3833 let source_array = batches[*batch_idx].column(col_idx);
3834 let taken = arrow::compute::take(
3835 source_array.as_ref(),
3836 indices.as_ref(),
3837 None,
3838 )?;
3839 partial_arrays.push(taken);
3840 }
3841 let slices: Vec<&dyn arrow::array::Array> =
3842 partial_arrays.iter().map(|arr| arr.as_ref()).collect();
3843 arrow::compute::concat(&slices)?
3844 };
3845 arrays.push(column_array);
3846 }
3847
3848 let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
3849 tracing::debug!("[GROUP BY] group batch rows={}", batch.num_rows());
3850 batch
3851 };
3852
3853 let mut aggregate_values: FxHashMap<String, PlanValue> = FxHashMap::default();
3855
3856 let mut working_batch = group_batch.clone();
3858 let mut next_temp_col_idx = working_batch.num_columns();
3859
3860 for (key, agg_call) in &aggregate_specs {
3861 let (projection_idx, value_type) = match agg_call {
3863 AggregateCall::CountStar => (None, None),
3864 AggregateCall::Count { expr, .. }
3865 | AggregateCall::Sum { expr, .. }
3866 | AggregateCall::Total { expr, .. }
3867 | AggregateCall::Avg { expr, .. }
3868 | AggregateCall::Min(expr)
3869 | AggregateCall::Max(expr)
3870 | AggregateCall::CountNulls(expr)
3871 | AggregateCall::GroupConcat { expr, .. } => {
3872 if let Some(col_name) = try_extract_simple_column(expr) {
3873 let idx = resolve_column_name_to_index(col_name, &column_lookup_map)
3874 .ok_or_else(|| {
3875 Error::InvalidArgumentError(format!(
3876 "column '{}' not found for aggregate",
3877 col_name
3878 ))
3879 })?;
3880 let field_type = working_batch.schema().field(idx).data_type().clone();
3881 (Some(idx), Some(field_type))
3882 } else {
3883 let mut computed_values = Vec::with_capacity(working_batch.num_rows());
3885 for row_idx in 0..working_batch.num_rows() {
3886 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
3887 expr,
3888 &FxHashMap::default(),
3889 Some(&working_batch),
3890 Some(&column_lookup_map),
3891 row_idx,
3892 )?;
3893 computed_values.push(value);
3894 }
3895
3896 let computed_array = plan_values_to_arrow_array(&computed_values)?;
3897 let computed_type = computed_array.data_type().clone();
3898
3899 let mut new_columns: Vec<ArrayRef> = working_batch.columns().to_vec();
3900 new_columns.push(computed_array);
3901
3902 let temp_field = Arc::new(Field::new(
3903 format!("__temp_agg_expr_{}", next_temp_col_idx),
3904 computed_type.clone(),
3905 true,
3906 ));
3907 let mut new_fields: Vec<Arc<Field>> =
3908 working_batch.schema().fields().iter().cloned().collect();
3909 new_fields.push(temp_field);
3910 let new_schema = Arc::new(Schema::new(new_fields));
3911
3912 working_batch = RecordBatch::try_new(new_schema, new_columns)?;
3913
3914 let col_idx = next_temp_col_idx;
3915 next_temp_col_idx += 1;
3916 (Some(col_idx), Some(computed_type))
3917 }
3918 }
3919 };
3920
3921 let spec = Self::build_aggregate_spec_for_cross_product(
3923 agg_call,
3924 key.clone(),
3925 value_type.clone(),
3926 )?;
3927
3928 let mut state = llkv_aggregate::AggregateState {
3929 alias: key.clone(),
3930 accumulator: llkv_aggregate::AggregateAccumulator::new_with_projection_index(
3931 &spec,
3932 projection_idx,
3933 None,
3934 )?,
3935 override_value: None,
3936 };
3937
3938 state.update(&working_batch)?;
3940
3941 let (_field, array) = state.finalize()?;
3943 let value = llkv_plan::plan_value_from_array(&array, 0)?;
3944 tracing::debug!(
3945 "[GROUP BY] aggregate result key={:?} value={:?}",
3946 key,
3947 value
3948 );
3949 aggregate_values.insert(key.clone(), value);
3950 }
3951
3952 group_aggregate_values.push(aggregate_values);
3953 }
3954
3955 let output_columns = self.build_group_by_output_columns(
3957 &plan,
3958 base_schema.as_ref(),
3959 &column_lookup_map,
3960 batches
3961 .first()
3962 .unwrap_or(&RecordBatch::new_empty(Arc::clone(&base_schema))),
3963 )?;
3964
3965 let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(group_states.len());
3966
3967 for (group_idx, group_state) in group_states.iter().enumerate() {
3968 let aggregate_values = &group_aggregate_values[group_idx];
3969 let representative_batch = &batches[group_state.representative_batch_idx];
3970
3971 let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
3972 for output in &output_columns {
3973 match output.source {
3974 OutputSource::TableColumn { index } => {
3975 let value = llkv_plan::plan_value_from_array(
3977 representative_batch.column(index),
3978 group_state.representative_row,
3979 )?;
3980 row.push(value);
3981 }
3982 OutputSource::Computed { projection_index } => {
3983 let expr = match &plan.projections[projection_index] {
3984 SelectProjection::Computed { expr, .. } => expr,
3985 _ => unreachable!("projection index mismatch for computed column"),
3986 };
3987 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
3989 expr,
3990 aggregate_values,
3991 Some(representative_batch),
3992 Some(&column_lookup_map),
3993 group_state.representative_row,
3994 )?;
3995 row.push(value);
3996 }
3997 }
3998 }
3999 rows.push(row);
4000 }
4001
4002 let filtered_rows = if let Some(having) = &plan.having {
4004 let mut filtered = Vec::new();
4005 for (row_idx, row) in rows.iter().enumerate() {
4006 let aggregate_values = &group_aggregate_values[row_idx];
4007 let group_state = &group_states[row_idx];
4008 let representative_batch = &batches[group_state.representative_batch_idx];
4009 let passes = Self::evaluate_having_expr(
4011 having,
4012 aggregate_values,
4013 representative_batch,
4014 &column_lookup_map,
4015 group_state.representative_row,
4016 )?;
4017 if matches!(passes, Some(true)) {
4019 filtered.push(row.clone());
4020 }
4021 }
4022 filtered
4023 } else {
4024 rows
4025 };
4026
4027 let fields: Vec<Field> = output_columns
4028 .into_iter()
4029 .map(|output| output.field)
4030 .collect();
4031 let schema = Arc::new(Schema::new(fields));
4032
4033 let mut batch = rows_to_record_batch(Arc::clone(&schema), &filtered_rows)?;
4034
4035 if plan.distinct && batch.num_rows() > 0 {
4036 let mut state = DistinctState::default();
4037 batch = match distinct_filter_batch(batch, &mut state)? {
4038 Some(filtered) => filtered,
4039 None => RecordBatch::new_empty(Arc::clone(&schema)),
4040 };
4041 }
4042
4043 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
4044 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
4045 }
4046
4047 Ok(SelectExecution::new_single_batch(
4048 display_name,
4049 schema,
4050 batch,
4051 ))
4052 }
4053
4054 fn execute_aggregates(
4055 &self,
4056 table: Arc<ExecutorTable<P>>,
4057 display_name: String,
4058 plan: SelectPlan,
4059 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4060 ) -> ExecutorResult<SelectExecution<P>> {
4061 let table_ref = table.as_ref();
4062 let distinct = plan.distinct;
4063 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
4064 for aggregate in plan.aggregates {
4065 match aggregate {
4066 AggregateExpr::CountStar { alias, distinct } => {
4067 specs.push(AggregateSpec {
4068 alias,
4069 kind: AggregateKind::Count {
4070 field_id: None,
4071 distinct,
4072 },
4073 });
4074 }
4075 AggregateExpr::Column {
4076 column,
4077 alias,
4078 function,
4079 distinct,
4080 } => {
4081 let col = table_ref.schema.resolve(&column).ok_or_else(|| {
4082 Error::InvalidArgumentError(format!(
4083 "unknown column '{}' in aggregate",
4084 column
4085 ))
4086 })?;
4087
4088 let kind = match function {
4089 AggregateFunction::Count => AggregateKind::Count {
4090 field_id: Some(col.field_id),
4091 distinct,
4092 },
4093 AggregateFunction::SumInt64 => {
4094 let input_type = Self::validate_aggregate_type(
4095 Some(col.data_type.clone()),
4096 "SUM",
4097 &[DataType::Int64, DataType::Float64],
4098 )?;
4099 AggregateKind::Sum {
4100 field_id: col.field_id,
4101 data_type: input_type,
4102 distinct,
4103 }
4104 }
4105 AggregateFunction::TotalInt64 => {
4106 let input_type = Self::validate_aggregate_type(
4107 Some(col.data_type.clone()),
4108 "TOTAL",
4109 &[DataType::Int64, DataType::Float64],
4110 )?;
4111 AggregateKind::Total {
4112 field_id: col.field_id,
4113 data_type: input_type,
4114 distinct,
4115 }
4116 }
4117 AggregateFunction::MinInt64 => {
4118 let input_type = Self::validate_aggregate_type(
4119 Some(col.data_type.clone()),
4120 "MIN",
4121 &[DataType::Int64, DataType::Float64],
4122 )?;
4123 AggregateKind::Min {
4124 field_id: col.field_id,
4125 data_type: input_type,
4126 }
4127 }
4128 AggregateFunction::MaxInt64 => {
4129 let input_type = Self::validate_aggregate_type(
4130 Some(col.data_type.clone()),
4131 "MAX",
4132 &[DataType::Int64, DataType::Float64],
4133 )?;
4134 AggregateKind::Max {
4135 field_id: col.field_id,
4136 data_type: input_type,
4137 }
4138 }
4139 AggregateFunction::CountNulls => {
4140 if distinct {
4141 return Err(Error::InvalidArgumentError(
4142 "DISTINCT is not supported for COUNT_NULLS".into(),
4143 ));
4144 }
4145 AggregateKind::CountNulls {
4146 field_id: col.field_id,
4147 }
4148 }
4149 AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
4150 field_id: col.field_id,
4151 distinct,
4152 separator: ",".to_string(),
4153 },
4154 };
4155 specs.push(AggregateSpec { alias, kind });
4156 }
4157 }
4158 }
4159
4160 if specs.is_empty() {
4161 return Err(Error::InvalidArgumentError(
4162 "aggregate query requires at least one aggregate expression".into(),
4163 ));
4164 }
4165
4166 let had_filter = plan.filter.is_some();
4167 let filter_expr = match &plan.filter {
4168 Some(filter_wrapper) => {
4169 if !filter_wrapper.subqueries.is_empty() {
4170 return Err(Error::InvalidArgumentError(
4171 "EXISTS subqueries not yet implemented in aggregate queries".into(),
4172 ));
4173 }
4174 crate::translation::expression::translate_predicate(
4175 filter_wrapper.predicate.clone(),
4176 table.schema.as_ref(),
4177 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
4178 )?
4179 }
4180 None => {
4181 let field_id = table.schema.first_field_id().ok_or_else(|| {
4182 Error::InvalidArgumentError(
4183 "table has no columns; cannot perform aggregate scan".into(),
4184 )
4185 })?;
4186 crate::translation::expression::full_table_scan_filter(field_id)
4187 }
4188 };
4189
4190 let mut projections = Vec::new();
4192 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(specs.len());
4193
4194 for spec in &specs {
4195 if let Some(field_id) = spec.kind.field_id() {
4196 let proj_idx = projections.len();
4197 spec_to_projection.push(Some(proj_idx));
4198 projections.push(ScanProjection::from(StoreProjection::with_alias(
4199 LogicalFieldId::for_user(table.table.table_id(), field_id),
4200 table
4201 .schema
4202 .column_by_field_id(field_id)
4203 .map(|c| c.name.clone())
4204 .unwrap_or_else(|| format!("col{field_id}")),
4205 )));
4206 } else {
4207 spec_to_projection.push(None);
4208 }
4209 }
4210
4211 if projections.is_empty() {
4212 let field_id = table.schema.first_field_id().ok_or_else(|| {
4213 Error::InvalidArgumentError(
4214 "table has no columns; cannot perform aggregate scan".into(),
4215 )
4216 })?;
4217 projections.push(ScanProjection::from(StoreProjection::with_alias(
4218 LogicalFieldId::for_user(table.table.table_id(), field_id),
4219 table
4220 .schema
4221 .column_by_field_id(field_id)
4222 .map(|c| c.name.clone())
4223 .unwrap_or_else(|| format!("col{field_id}")),
4224 )));
4225 }
4226
4227 let options = ScanStreamOptions {
4228 include_nulls: true,
4229 order: None,
4230 row_id_filter: row_filter.clone(),
4231 };
4232
4233 let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
4234 let mut count_star_override: Option<i64> = None;
4238 if !had_filter && row_filter.is_none() {
4239 let total_rows = table.total_rows.load(Ordering::SeqCst);
4241 tracing::debug!(
4242 "[AGGREGATE] Using COUNT(*) shortcut: total_rows={}",
4243 total_rows
4244 );
4245 if total_rows > i64::MAX as u64 {
4246 return Err(Error::InvalidArgumentError(
4247 "COUNT(*) result exceeds supported range".into(),
4248 ));
4249 }
4250 count_star_override = Some(total_rows as i64);
4251 } else {
4252 tracing::debug!(
4253 "[AGGREGATE] NOT using COUNT(*) shortcut: had_filter={}, has_row_filter={}",
4254 had_filter,
4255 row_filter.is_some()
4256 );
4257 }
4258
4259 for (idx, spec) in specs.iter().enumerate() {
4260 states.push(AggregateState {
4261 alias: spec.alias.clone(),
4262 accumulator: AggregateAccumulator::new_with_projection_index(
4263 spec,
4264 spec_to_projection[idx],
4265 count_star_override,
4266 )?,
4267 override_value: match &spec.kind {
4268 AggregateKind::Count { field_id: None, .. } => {
4269 tracing::debug!(
4270 "[AGGREGATE] CountStar override_value={:?}",
4271 count_star_override
4272 );
4273 count_star_override
4274 }
4275 _ => None,
4276 },
4277 });
4278 }
4279
4280 let mut error: Option<Error> = None;
4281 match table.table.scan_stream(
4282 projections,
4283 &filter_expr,
4284 ScanStreamOptions {
4285 row_id_filter: row_filter.clone(),
4286 ..options
4287 },
4288 |batch| {
4289 if error.is_some() {
4290 return;
4291 }
4292 for state in &mut states {
4293 if let Err(err) = state.update(&batch) {
4294 error = Some(err);
4295 return;
4296 }
4297 }
4298 },
4299 ) {
4300 Ok(()) => {}
4301 Err(llkv_result::Error::NotFound) => {
4302 }
4305 Err(err) => return Err(err),
4306 }
4307 if let Some(err) = error {
4308 return Err(err);
4309 }
4310
4311 let mut fields = Vec::with_capacity(states.len());
4312 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
4313 for state in states {
4314 let (field, array) = state.finalize()?;
4315 fields.push(field);
4316 arrays.push(array);
4317 }
4318
4319 let schema = Arc::new(Schema::new(fields));
4320 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
4321
4322 if distinct {
4323 let mut state = DistinctState::default();
4324 batch = match distinct_filter_batch(batch, &mut state)? {
4325 Some(filtered) => filtered,
4326 None => RecordBatch::new_empty(Arc::clone(&schema)),
4327 };
4328 }
4329
4330 let schema = batch.schema();
4331
4332 Ok(SelectExecution::new_single_batch(
4333 display_name,
4334 schema,
4335 batch,
4336 ))
4337 }
4338
4339 fn execute_computed_aggregates(
4342 &self,
4343 table: Arc<ExecutorTable<P>>,
4344 display_name: String,
4345 plan: SelectPlan,
4346 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4347 ) -> ExecutorResult<SelectExecution<P>> {
4348 use arrow::array::Int64Array;
4349 use llkv_expr::expr::AggregateCall;
4350
4351 let table_ref = table.as_ref();
4352 let distinct = plan.distinct;
4353
4354 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
4356 for proj in &plan.projections {
4357 if let SelectProjection::Computed { expr, .. } = proj {
4358 Self::collect_aggregates(expr, &mut aggregate_specs);
4359 }
4360 }
4361
4362 let filter_predicate = plan
4364 .filter
4365 .as_ref()
4366 .map(|wrapper| {
4367 if !wrapper.subqueries.is_empty() {
4368 return Err(Error::InvalidArgumentError(
4369 "EXISTS subqueries not yet implemented with aggregates".into(),
4370 ));
4371 }
4372 Ok(wrapper.predicate.clone())
4373 })
4374 .transpose()?;
4375
4376 let computed_aggregates = self.compute_aggregate_values(
4377 table.clone(),
4378 &filter_predicate,
4379 &aggregate_specs,
4380 row_filter.clone(),
4381 )?;
4382
4383 let mut fields = Vec::with_capacity(plan.projections.len());
4385 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
4386
4387 for proj in &plan.projections {
4388 match proj {
4389 SelectProjection::AllColumns | SelectProjection::AllColumnsExcept { .. } => {
4390 return Err(Error::InvalidArgumentError(
4391 "Wildcard projections not supported with computed aggregates".into(),
4392 ));
4393 }
4394 SelectProjection::Column { name, alias } => {
4395 let col = table_ref.schema.resolve(name).ok_or_else(|| {
4396 Error::InvalidArgumentError(format!("unknown column '{}'", name))
4397 })?;
4398 let field_name = alias.as_ref().unwrap_or(name);
4399 fields.push(arrow::datatypes::Field::new(
4400 field_name,
4401 col.data_type.clone(),
4402 col.nullable,
4403 ));
4404 return Err(Error::InvalidArgumentError(
4407 "Regular columns not supported in aggregate queries without GROUP BY"
4408 .into(),
4409 ));
4410 }
4411 SelectProjection::Computed { expr, alias } => {
4412 if let ScalarExpr::Aggregate(agg) = expr {
4414 let key = format!("{:?}", agg);
4415 if let Some(agg_value) = computed_aggregates.get(&key) {
4416 match agg_value {
4417 AggregateValue::Null => {
4418 fields.push(arrow::datatypes::Field::new(
4419 alias,
4420 DataType::Int64,
4421 true,
4422 ));
4423 arrays
4424 .push(Arc::new(Int64Array::from(vec![None::<i64>]))
4425 as ArrayRef);
4426 }
4427 AggregateValue::Int64(v) => {
4428 fields.push(arrow::datatypes::Field::new(
4429 alias,
4430 DataType::Int64,
4431 true,
4432 ));
4433 arrays.push(
4434 Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef
4435 );
4436 }
4437 AggregateValue::Float64(v) => {
4438 fields.push(arrow::datatypes::Field::new(
4439 alias,
4440 DataType::Float64,
4441 true,
4442 ));
4443 arrays
4444 .push(Arc::new(Float64Array::from(vec![Some(*v)]))
4445 as ArrayRef);
4446 }
4447 AggregateValue::String(s) => {
4448 fields.push(arrow::datatypes::Field::new(
4449 alias,
4450 DataType::Utf8,
4451 true,
4452 ));
4453 arrays
4454 .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
4455 as ArrayRef);
4456 }
4457 }
4458 continue;
4459 }
4460 }
4461
4462 let value = Self::evaluate_expr_with_aggregates(expr, &computed_aggregates)?;
4464
4465 fields.push(arrow::datatypes::Field::new(alias, DataType::Int64, true));
4466
4467 let array = Arc::new(Int64Array::from(vec![value])) as ArrayRef;
4468 arrays.push(array);
4469 }
4470 }
4471 }
4472
4473 let schema = Arc::new(Schema::new(fields));
4474 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
4475
4476 if distinct {
4477 let mut state = DistinctState::default();
4478 batch = match distinct_filter_batch(batch, &mut state)? {
4479 Some(filtered) => filtered,
4480 None => RecordBatch::new_empty(Arc::clone(&schema)),
4481 };
4482 }
4483
4484 let schema = batch.schema();
4485
4486 Ok(SelectExecution::new_single_batch(
4487 display_name,
4488 schema,
4489 batch,
4490 ))
4491 }
4492
4493 fn build_aggregate_spec_for_cross_product(
4496 agg_call: &llkv_expr::expr::AggregateCall<String>,
4497 alias: String,
4498 data_type: Option<DataType>,
4499 ) -> ExecutorResult<llkv_aggregate::AggregateSpec> {
4500 use llkv_expr::expr::AggregateCall;
4501
4502 let kind = match agg_call {
4503 AggregateCall::CountStar => llkv_aggregate::AggregateKind::Count {
4504 field_id: None,
4505 distinct: false,
4506 },
4507 AggregateCall::Count { distinct, .. } => llkv_aggregate::AggregateKind::Count {
4508 field_id: Some(0),
4509 distinct: *distinct,
4510 },
4511 AggregateCall::Sum { distinct, .. } => llkv_aggregate::AggregateKind::Sum {
4512 field_id: 0,
4513 data_type: Self::validate_aggregate_type(
4514 data_type.clone(),
4515 "SUM",
4516 &[DataType::Int64, DataType::Float64],
4517 )?,
4518 distinct: *distinct,
4519 },
4520 AggregateCall::Total { distinct, .. } => llkv_aggregate::AggregateKind::Total {
4521 field_id: 0,
4522 data_type: Self::validate_aggregate_type(
4523 data_type.clone(),
4524 "TOTAL",
4525 &[DataType::Int64, DataType::Float64],
4526 )?,
4527 distinct: *distinct,
4528 },
4529 AggregateCall::Avg { distinct, .. } => llkv_aggregate::AggregateKind::Avg {
4530 field_id: 0,
4531 data_type: Self::validate_aggregate_type(
4532 data_type.clone(),
4533 "AVG",
4534 &[DataType::Int64, DataType::Float64],
4535 )?,
4536 distinct: *distinct,
4537 },
4538 AggregateCall::Min(_) => llkv_aggregate::AggregateKind::Min {
4539 field_id: 0,
4540 data_type: Self::validate_aggregate_type(
4541 data_type.clone(),
4542 "MIN",
4543 &[DataType::Int64, DataType::Float64],
4544 )?,
4545 },
4546 AggregateCall::Max(_) => llkv_aggregate::AggregateKind::Max {
4547 field_id: 0,
4548 data_type: Self::validate_aggregate_type(
4549 data_type.clone(),
4550 "MAX",
4551 &[DataType::Int64, DataType::Float64],
4552 )?,
4553 },
4554 AggregateCall::CountNulls(_) => {
4555 llkv_aggregate::AggregateKind::CountNulls { field_id: 0 }
4556 }
4557 AggregateCall::GroupConcat {
4558 distinct,
4559 separator,
4560 ..
4561 } => llkv_aggregate::AggregateKind::GroupConcat {
4562 field_id: 0,
4563 distinct: *distinct,
4564 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
4565 },
4566 };
4567
4568 Ok(llkv_aggregate::AggregateSpec { alias, kind })
4569 }
4570
4571 fn validate_aggregate_type(
4586 data_type: Option<DataType>,
4587 func_name: &str,
4588 allowed: &[DataType],
4589 ) -> ExecutorResult<DataType> {
4590 let dt = data_type.ok_or_else(|| {
4591 Error::Internal(format!(
4592 "missing input type metadata for {func_name} aggregate"
4593 ))
4594 })?;
4595
4596 if matches!(func_name, "SUM" | "AVG" | "TOTAL" | "MIN" | "MAX") {
4599 match dt {
4600 DataType::Int64 | DataType::Float64 => Ok(dt),
4602
4603 DataType::Utf8 | DataType::Boolean | DataType::Date32 => Ok(DataType::Float64),
4606
4607 _ => Err(Error::InvalidArgumentError(format!(
4608 "{func_name} aggregate not supported for column type {:?}",
4609 dt
4610 ))),
4611 }
4612 } else {
4613 if allowed.iter().any(|candidate| candidate == &dt) {
4615 Ok(dt)
4616 } else {
4617 Err(Error::InvalidArgumentError(format!(
4618 "{func_name} aggregate not supported for column type {:?}",
4619 dt
4620 )))
4621 }
4622 }
4623 }
4624
4625 fn collect_aggregates(
4627 expr: &ScalarExpr<String>,
4628 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
4629 ) {
4630 match expr {
4631 ScalarExpr::Aggregate(agg) => {
4632 let key = format!("{:?}", agg);
4634 if !aggregates.iter().any(|(k, _)| k == &key) {
4635 aggregates.push((key, agg.clone()));
4636 }
4637 }
4638 ScalarExpr::Binary { left, right, .. } => {
4639 Self::collect_aggregates(left, aggregates);
4640 Self::collect_aggregates(right, aggregates);
4641 }
4642 ScalarExpr::Compare { left, right, .. } => {
4643 Self::collect_aggregates(left, aggregates);
4644 Self::collect_aggregates(right, aggregates);
4645 }
4646 ScalarExpr::GetField { base, .. } => {
4647 Self::collect_aggregates(base, aggregates);
4648 }
4649 ScalarExpr::Cast { expr, .. } => {
4650 Self::collect_aggregates(expr, aggregates);
4651 }
4652 ScalarExpr::Not(expr) => {
4653 Self::collect_aggregates(expr, aggregates);
4654 }
4655 ScalarExpr::IsNull { expr, .. } => {
4656 Self::collect_aggregates(expr, aggregates);
4657 }
4658 ScalarExpr::Case {
4659 operand,
4660 branches,
4661 else_expr,
4662 } => {
4663 if let Some(inner) = operand.as_deref() {
4664 Self::collect_aggregates(inner, aggregates);
4665 }
4666 for (when_expr, then_expr) in branches {
4667 Self::collect_aggregates(when_expr, aggregates);
4668 Self::collect_aggregates(then_expr, aggregates);
4669 }
4670 if let Some(inner) = else_expr.as_deref() {
4671 Self::collect_aggregates(inner, aggregates);
4672 }
4673 }
4674 ScalarExpr::Coalesce(items) => {
4675 for item in items {
4676 Self::collect_aggregates(item, aggregates);
4677 }
4678 }
4679 ScalarExpr::Column(_) | ScalarExpr::Literal(_) | ScalarExpr::Random => {}
4680 ScalarExpr::ScalarSubquery(_) => {}
4681 }
4682 }
4683
4684 fn collect_aggregates_from_predicate(
4686 expr: &llkv_expr::expr::Expr<String>,
4687 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
4688 ) {
4689 match expr {
4690 llkv_expr::expr::Expr::Compare { left, right, .. } => {
4691 Self::collect_aggregates(left, aggregates);
4692 Self::collect_aggregates(right, aggregates);
4693 }
4694 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
4695 for e in exprs {
4696 Self::collect_aggregates_from_predicate(e, aggregates);
4697 }
4698 }
4699 llkv_expr::expr::Expr::Not(inner) => {
4700 Self::collect_aggregates_from_predicate(inner, aggregates);
4701 }
4702 llkv_expr::expr::Expr::InList {
4703 expr: test_expr,
4704 list,
4705 ..
4706 } => {
4707 Self::collect_aggregates(test_expr, aggregates);
4708 for item in list {
4709 Self::collect_aggregates(item, aggregates);
4710 }
4711 }
4712 llkv_expr::expr::Expr::IsNull { expr, .. } => {
4713 Self::collect_aggregates(expr, aggregates);
4714 }
4715 llkv_expr::expr::Expr::Literal(_) => {}
4716 llkv_expr::expr::Expr::Pred(_) => {}
4717 llkv_expr::expr::Expr::Exists(_) => {}
4718 }
4719 }
4720
4721 fn compute_aggregate_values(
4723 &self,
4724 table: Arc<ExecutorTable<P>>,
4725 filter: &Option<llkv_expr::expr::Expr<'static, String>>,
4726 aggregate_specs: &[(String, llkv_expr::expr::AggregateCall<String>)],
4727 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4728 ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
4729 use llkv_expr::expr::AggregateCall;
4730
4731 let table_ref = table.as_ref();
4732 let mut results =
4733 FxHashMap::with_capacity_and_hasher(aggregate_specs.len(), Default::default());
4734
4735 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
4736 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());
4737 let mut projections: Vec<ScanProjection> = Vec::new();
4738 let mut column_projection_cache: FxHashMap<FieldId, usize> = FxHashMap::default();
4739 let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
4740 FxHashMap::default();
4741 let mut computed_alias_counter: usize = 0;
4742
4743 for (key, agg) in aggregate_specs {
4744 match agg {
4745 AggregateCall::CountStar => {
4746 specs.push(AggregateSpec {
4747 alias: key.clone(),
4748 kind: AggregateKind::Count {
4749 field_id: None,
4750 distinct: false,
4751 },
4752 });
4753 spec_to_projection.push(None);
4754 }
4755 AggregateCall::Count { expr, distinct } => {
4756 if let Some(col_name) = try_extract_simple_column(expr) {
4757 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4758 Error::InvalidArgumentError(format!(
4759 "unknown column '{}' in aggregate",
4760 col_name
4761 ))
4762 })?;
4763 let projection_index = get_or_insert_column_projection(
4764 &mut projections,
4765 &mut column_projection_cache,
4766 table_ref,
4767 col,
4768 );
4769 specs.push(AggregateSpec {
4770 alias: key.clone(),
4771 kind: AggregateKind::Count {
4772 field_id: Some(col.field_id),
4773 distinct: *distinct,
4774 },
4775 });
4776 spec_to_projection.push(Some(projection_index));
4777 } else {
4778 let (projection_index, _dtype) = ensure_computed_projection(
4779 expr,
4780 table_ref,
4781 &mut projections,
4782 &mut computed_projection_cache,
4783 &mut computed_alias_counter,
4784 )?;
4785 let field_id = u32::try_from(projection_index).map_err(|_| {
4786 Error::InvalidArgumentError(
4787 "aggregate projection index exceeds supported range".into(),
4788 )
4789 })?;
4790 specs.push(AggregateSpec {
4791 alias: key.clone(),
4792 kind: AggregateKind::Count {
4793 field_id: Some(field_id),
4794 distinct: *distinct,
4795 },
4796 });
4797 spec_to_projection.push(Some(projection_index));
4798 }
4799 }
4800 AggregateCall::Sum { expr, distinct } => {
4801 let (projection_index, data_type, field_id) =
4802 if let Some(col_name) = try_extract_simple_column(expr) {
4803 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4804 Error::InvalidArgumentError(format!(
4805 "unknown column '{}' in aggregate",
4806 col_name
4807 ))
4808 })?;
4809 let projection_index = get_or_insert_column_projection(
4810 &mut projections,
4811 &mut column_projection_cache,
4812 table_ref,
4813 col,
4814 );
4815 let data_type = col.data_type.clone();
4816 (projection_index, data_type, col.field_id)
4817 } else {
4818 let (projection_index, inferred_type) = ensure_computed_projection(
4819 expr,
4820 table_ref,
4821 &mut projections,
4822 &mut computed_projection_cache,
4823 &mut computed_alias_counter,
4824 )?;
4825 let field_id = u32::try_from(projection_index).map_err(|_| {
4826 Error::InvalidArgumentError(
4827 "aggregate projection index exceeds supported range".into(),
4828 )
4829 })?;
4830 (projection_index, inferred_type, field_id)
4831 };
4832 let normalized_type = Self::validate_aggregate_type(
4833 Some(data_type.clone()),
4834 "SUM",
4835 &[DataType::Int64, DataType::Float64],
4836 )?;
4837 specs.push(AggregateSpec {
4838 alias: key.clone(),
4839 kind: AggregateKind::Sum {
4840 field_id,
4841 data_type: normalized_type,
4842 distinct: *distinct,
4843 },
4844 });
4845 spec_to_projection.push(Some(projection_index));
4846 }
4847 AggregateCall::Total { expr, distinct } => {
4848 let (projection_index, data_type, field_id) =
4849 if let Some(col_name) = try_extract_simple_column(expr) {
4850 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4851 Error::InvalidArgumentError(format!(
4852 "unknown column '{}' in aggregate",
4853 col_name
4854 ))
4855 })?;
4856 let projection_index = get_or_insert_column_projection(
4857 &mut projections,
4858 &mut column_projection_cache,
4859 table_ref,
4860 col,
4861 );
4862 let data_type = col.data_type.clone();
4863 (projection_index, data_type, col.field_id)
4864 } else {
4865 let (projection_index, inferred_type) = ensure_computed_projection(
4866 expr,
4867 table_ref,
4868 &mut projections,
4869 &mut computed_projection_cache,
4870 &mut computed_alias_counter,
4871 )?;
4872 let field_id = u32::try_from(projection_index).map_err(|_| {
4873 Error::InvalidArgumentError(
4874 "aggregate projection index exceeds supported range".into(),
4875 )
4876 })?;
4877 (projection_index, inferred_type, field_id)
4878 };
4879 let normalized_type = Self::validate_aggregate_type(
4880 Some(data_type.clone()),
4881 "TOTAL",
4882 &[DataType::Int64, DataType::Float64],
4883 )?;
4884 specs.push(AggregateSpec {
4885 alias: key.clone(),
4886 kind: AggregateKind::Total {
4887 field_id,
4888 data_type: normalized_type,
4889 distinct: *distinct,
4890 },
4891 });
4892 spec_to_projection.push(Some(projection_index));
4893 }
4894 AggregateCall::Avg { expr, distinct } => {
4895 let (projection_index, data_type, field_id) =
4896 if let Some(col_name) = try_extract_simple_column(expr) {
4897 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4898 Error::InvalidArgumentError(format!(
4899 "unknown column '{}' in aggregate",
4900 col_name
4901 ))
4902 })?;
4903 let projection_index = get_or_insert_column_projection(
4904 &mut projections,
4905 &mut column_projection_cache,
4906 table_ref,
4907 col,
4908 );
4909 let data_type = col.data_type.clone();
4910 (projection_index, data_type, col.field_id)
4911 } else {
4912 let (projection_index, inferred_type) = ensure_computed_projection(
4913 expr,
4914 table_ref,
4915 &mut projections,
4916 &mut computed_projection_cache,
4917 &mut computed_alias_counter,
4918 )?;
4919 tracing::debug!(
4920 "AVG aggregate expr={:?} inferred_type={:?}",
4921 expr,
4922 inferred_type
4923 );
4924 let field_id = u32::try_from(projection_index).map_err(|_| {
4925 Error::InvalidArgumentError(
4926 "aggregate projection index exceeds supported range".into(),
4927 )
4928 })?;
4929 (projection_index, inferred_type, field_id)
4930 };
4931 let normalized_type = Self::validate_aggregate_type(
4932 Some(data_type.clone()),
4933 "AVG",
4934 &[DataType::Int64, DataType::Float64],
4935 )?;
4936 specs.push(AggregateSpec {
4937 alias: key.clone(),
4938 kind: AggregateKind::Avg {
4939 field_id,
4940 data_type: normalized_type,
4941 distinct: *distinct,
4942 },
4943 });
4944 spec_to_projection.push(Some(projection_index));
4945 }
4946 AggregateCall::Min(expr) => {
4947 let (projection_index, data_type, field_id) =
4948 if let Some(col_name) = try_extract_simple_column(expr) {
4949 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4950 Error::InvalidArgumentError(format!(
4951 "unknown column '{}' in aggregate",
4952 col_name
4953 ))
4954 })?;
4955 let projection_index = get_or_insert_column_projection(
4956 &mut projections,
4957 &mut column_projection_cache,
4958 table_ref,
4959 col,
4960 );
4961 let data_type = col.data_type.clone();
4962 (projection_index, data_type, col.field_id)
4963 } else {
4964 let (projection_index, inferred_type) = ensure_computed_projection(
4965 expr,
4966 table_ref,
4967 &mut projections,
4968 &mut computed_projection_cache,
4969 &mut computed_alias_counter,
4970 )?;
4971 let field_id = u32::try_from(projection_index).map_err(|_| {
4972 Error::InvalidArgumentError(
4973 "aggregate projection index exceeds supported range".into(),
4974 )
4975 })?;
4976 (projection_index, inferred_type, field_id)
4977 };
4978 let normalized_type = Self::validate_aggregate_type(
4979 Some(data_type.clone()),
4980 "MIN",
4981 &[DataType::Int64, DataType::Float64],
4982 )?;
4983 specs.push(AggregateSpec {
4984 alias: key.clone(),
4985 kind: AggregateKind::Min {
4986 field_id,
4987 data_type: normalized_type,
4988 },
4989 });
4990 spec_to_projection.push(Some(projection_index));
4991 }
4992 AggregateCall::Max(expr) => {
4993 let (projection_index, data_type, field_id) =
4994 if let Some(col_name) = try_extract_simple_column(expr) {
4995 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4996 Error::InvalidArgumentError(format!(
4997 "unknown column '{}' in aggregate",
4998 col_name
4999 ))
5000 })?;
5001 let projection_index = get_or_insert_column_projection(
5002 &mut projections,
5003 &mut column_projection_cache,
5004 table_ref,
5005 col,
5006 );
5007 let data_type = col.data_type.clone();
5008 (projection_index, data_type, col.field_id)
5009 } else {
5010 let (projection_index, inferred_type) = ensure_computed_projection(
5011 expr,
5012 table_ref,
5013 &mut projections,
5014 &mut computed_projection_cache,
5015 &mut computed_alias_counter,
5016 )?;
5017 let field_id = u32::try_from(projection_index).map_err(|_| {
5018 Error::InvalidArgumentError(
5019 "aggregate projection index exceeds supported range".into(),
5020 )
5021 })?;
5022 (projection_index, inferred_type, field_id)
5023 };
5024 let normalized_type = Self::validate_aggregate_type(
5025 Some(data_type.clone()),
5026 "MAX",
5027 &[DataType::Int64, DataType::Float64],
5028 )?;
5029 specs.push(AggregateSpec {
5030 alias: key.clone(),
5031 kind: AggregateKind::Max {
5032 field_id,
5033 data_type: normalized_type,
5034 },
5035 });
5036 spec_to_projection.push(Some(projection_index));
5037 }
5038 AggregateCall::CountNulls(expr) => {
5039 if let Some(col_name) = try_extract_simple_column(expr) {
5040 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
5041 Error::InvalidArgumentError(format!(
5042 "unknown column '{}' in aggregate",
5043 col_name
5044 ))
5045 })?;
5046 let projection_index = get_or_insert_column_projection(
5047 &mut projections,
5048 &mut column_projection_cache,
5049 table_ref,
5050 col,
5051 );
5052 specs.push(AggregateSpec {
5053 alias: key.clone(),
5054 kind: AggregateKind::CountNulls {
5055 field_id: col.field_id,
5056 },
5057 });
5058 spec_to_projection.push(Some(projection_index));
5059 } else {
5060 let (projection_index, _dtype) = ensure_computed_projection(
5061 expr,
5062 table_ref,
5063 &mut projections,
5064 &mut computed_projection_cache,
5065 &mut computed_alias_counter,
5066 )?;
5067 let field_id = u32::try_from(projection_index).map_err(|_| {
5068 Error::InvalidArgumentError(
5069 "aggregate projection index exceeds supported range".into(),
5070 )
5071 })?;
5072 specs.push(AggregateSpec {
5073 alias: key.clone(),
5074 kind: AggregateKind::CountNulls { field_id },
5075 });
5076 spec_to_projection.push(Some(projection_index));
5077 }
5078 }
5079 AggregateCall::GroupConcat {
5080 expr,
5081 distinct,
5082 separator,
5083 } => {
5084 if let Some(col_name) = try_extract_simple_column(expr) {
5085 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
5086 Error::InvalidArgumentError(format!(
5087 "unknown column '{}' in aggregate",
5088 col_name
5089 ))
5090 })?;
5091 let projection_index = get_or_insert_column_projection(
5092 &mut projections,
5093 &mut column_projection_cache,
5094 table_ref,
5095 col,
5096 );
5097 specs.push(AggregateSpec {
5098 alias: key.clone(),
5099 kind: AggregateKind::GroupConcat {
5100 field_id: col.field_id,
5101 distinct: *distinct,
5102 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
5103 },
5104 });
5105 spec_to_projection.push(Some(projection_index));
5106 } else {
5107 let (projection_index, _dtype) = ensure_computed_projection(
5108 expr,
5109 table_ref,
5110 &mut projections,
5111 &mut computed_projection_cache,
5112 &mut computed_alias_counter,
5113 )?;
5114 let field_id = u32::try_from(projection_index).map_err(|_| {
5115 Error::InvalidArgumentError(
5116 "aggregate projection index exceeds supported range".into(),
5117 )
5118 })?;
5119 specs.push(AggregateSpec {
5120 alias: key.clone(),
5121 kind: AggregateKind::GroupConcat {
5122 field_id,
5123 distinct: *distinct,
5124 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
5125 },
5126 });
5127 spec_to_projection.push(Some(projection_index));
5128 }
5129 }
5130 }
5131 }
5132
5133 let filter_expr = match filter {
5134 Some(expr) => crate::translation::expression::translate_predicate(
5135 expr.clone(),
5136 table_ref.schema.as_ref(),
5137 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
5138 )?,
5139 None => {
5140 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
5141 Error::InvalidArgumentError(
5142 "table has no columns; cannot perform aggregate scan".into(),
5143 )
5144 })?;
5145 crate::translation::expression::full_table_scan_filter(field_id)
5146 }
5147 };
5148
5149 if projections.is_empty() {
5150 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
5151 Error::InvalidArgumentError(
5152 "table has no columns; cannot perform aggregate scan".into(),
5153 )
5154 })?;
5155 projections.push(ScanProjection::from(StoreProjection::with_alias(
5156 LogicalFieldId::for_user(table.table.table_id(), field_id),
5157 table
5158 .schema
5159 .column_by_field_id(field_id)
5160 .map(|c| c.name.clone())
5161 .unwrap_or_else(|| format!("col{field_id}")),
5162 )));
5163 }
5164
5165 let base_options = ScanStreamOptions {
5166 include_nulls: true,
5167 order: None,
5168 row_id_filter: None,
5169 };
5170
5171 let count_star_override: Option<i64> = None;
5172
5173 let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
5174 for (idx, spec) in specs.iter().enumerate() {
5175 states.push(AggregateState {
5176 alias: spec.alias.clone(),
5177 accumulator: AggregateAccumulator::new_with_projection_index(
5178 spec,
5179 spec_to_projection[idx],
5180 count_star_override,
5181 )?,
5182 override_value: match &spec.kind {
5183 AggregateKind::Count { field_id: None, .. } => count_star_override,
5184 _ => None,
5185 },
5186 });
5187 }
5188
5189 let mut error: Option<Error> = None;
5190 match table.table.scan_stream(
5191 projections,
5192 &filter_expr,
5193 ScanStreamOptions {
5194 row_id_filter: row_filter.clone(),
5195 ..base_options
5196 },
5197 |batch| {
5198 if error.is_some() {
5199 return;
5200 }
5201 for state in &mut states {
5202 if let Err(err) = state.update(&batch) {
5203 error = Some(err);
5204 return;
5205 }
5206 }
5207 },
5208 ) {
5209 Ok(()) => {}
5210 Err(llkv_result::Error::NotFound) => {}
5211 Err(err) => return Err(err),
5212 }
5213 if let Some(err) = error {
5214 return Err(err);
5215 }
5216
5217 for state in states {
5218 let alias = state.alias.clone();
5219 let (_field, array) = state.finalize()?;
5220
5221 if let Some(int64_array) = array.as_any().downcast_ref::<arrow::array::Int64Array>() {
5222 if int64_array.len() != 1 {
5223 return Err(Error::Internal(format!(
5224 "Expected single value from aggregate, got {}",
5225 int64_array.len()
5226 )));
5227 }
5228 let value = if int64_array.is_null(0) {
5229 AggregateValue::Null
5230 } else {
5231 AggregateValue::Int64(int64_array.value(0))
5232 };
5233 results.insert(alias, value);
5234 } else if let Some(float64_array) =
5235 array.as_any().downcast_ref::<arrow::array::Float64Array>()
5236 {
5237 if float64_array.len() != 1 {
5238 return Err(Error::Internal(format!(
5239 "Expected single value from aggregate, got {}",
5240 float64_array.len()
5241 )));
5242 }
5243 let value = if float64_array.is_null(0) {
5244 AggregateValue::Null
5245 } else {
5246 AggregateValue::Float64(float64_array.value(0))
5247 };
5248 results.insert(alias, value);
5249 } else if let Some(string_array) =
5250 array.as_any().downcast_ref::<arrow::array::StringArray>()
5251 {
5252 if string_array.len() != 1 {
5253 return Err(Error::Internal(format!(
5254 "Expected single value from aggregate, got {}",
5255 string_array.len()
5256 )));
5257 }
5258 let value = if string_array.is_null(0) {
5259 AggregateValue::Null
5260 } else {
5261 AggregateValue::String(string_array.value(0).to_string())
5262 };
5263 results.insert(alias, value);
5264 } else {
5265 return Err(Error::Internal(format!(
5266 "Unexpected array type from aggregate: {:?}",
5267 array.data_type()
5268 )));
5269 }
5270 }
5271
5272 Ok(results)
5273 }
5274
5275 fn evaluate_having_expr(
5276 expr: &llkv_expr::expr::Expr<String>,
5277 aggregates: &FxHashMap<String, PlanValue>,
5278 row_batch: &RecordBatch,
5279 column_lookup: &FxHashMap<String, usize>,
5280 row_idx: usize,
5281 ) -> ExecutorResult<Option<bool>> {
5282 fn compare_plan_values_for_pred(
5283 left: &PlanValue,
5284 right: &PlanValue,
5285 ) -> Option<std::cmp::Ordering> {
5286 match (left, right) {
5287 (PlanValue::Integer(l), PlanValue::Integer(r)) => Some(l.cmp(r)),
5288 (PlanValue::Float(l), PlanValue::Float(r)) => l.partial_cmp(r),
5289 (PlanValue::Integer(l), PlanValue::Float(r)) => (*l as f64).partial_cmp(r),
5290 (PlanValue::Float(l), PlanValue::Integer(r)) => l.partial_cmp(&(*r as f64)),
5291 (PlanValue::String(l), PlanValue::String(r)) => Some(l.cmp(r)),
5292 _ => None,
5293 }
5294 }
5295
5296 fn evaluate_ordering_predicate<F>(
5297 value: &PlanValue,
5298 literal: &Literal,
5299 predicate: F,
5300 ) -> ExecutorResult<Option<bool>>
5301 where
5302 F: Fn(std::cmp::Ordering) -> bool,
5303 {
5304 if matches!(value, PlanValue::Null) {
5305 return Ok(None);
5306 }
5307 let expected = llkv_plan::plan_value_from_literal(literal)?;
5308 if matches!(expected, PlanValue::Null) {
5309 return Ok(None);
5310 }
5311
5312 match compare_plan_values_for_pred(value, &expected) {
5313 Some(ordering) => Ok(Some(predicate(ordering))),
5314 None => Err(Error::InvalidArgumentError(
5315 "unsupported HAVING comparison between column value and literal".into(),
5316 )),
5317 }
5318 }
5319
5320 match expr {
5321 llkv_expr::expr::Expr::Compare { left, op, right } => {
5322 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5323 left,
5324 aggregates,
5325 Some(row_batch),
5326 Some(column_lookup),
5327 row_idx,
5328 )?;
5329 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5330 right,
5331 aggregates,
5332 Some(row_batch),
5333 Some(column_lookup),
5334 row_idx,
5335 )?;
5336
5337 let (left_val, right_val) = match (&left_val, &right_val) {
5339 (PlanValue::Integer(i), PlanValue::Float(_)) => {
5340 (PlanValue::Float(*i as f64), right_val)
5341 }
5342 (PlanValue::Float(_), PlanValue::Integer(i)) => {
5343 (left_val, PlanValue::Float(*i as f64))
5344 }
5345 _ => (left_val, right_val),
5346 };
5347
5348 match (left_val, right_val) {
5349 (PlanValue::Null, _) | (_, PlanValue::Null) => Ok(None),
5351 (PlanValue::Integer(l), PlanValue::Integer(r)) => {
5352 use llkv_expr::expr::CompareOp;
5353 Ok(Some(match op {
5354 CompareOp::Eq => l == r,
5355 CompareOp::NotEq => l != r,
5356 CompareOp::Lt => l < r,
5357 CompareOp::LtEq => l <= r,
5358 CompareOp::Gt => l > r,
5359 CompareOp::GtEq => l >= r,
5360 }))
5361 }
5362 (PlanValue::Float(l), PlanValue::Float(r)) => {
5363 use llkv_expr::expr::CompareOp;
5364 Ok(Some(match op {
5365 CompareOp::Eq => l == r,
5366 CompareOp::NotEq => l != r,
5367 CompareOp::Lt => l < r,
5368 CompareOp::LtEq => l <= r,
5369 CompareOp::Gt => l > r,
5370 CompareOp::GtEq => l >= r,
5371 }))
5372 }
5373 _ => Ok(Some(false)),
5374 }
5375 }
5376 llkv_expr::expr::Expr::Not(inner) => {
5377 match Self::evaluate_having_expr(
5379 inner,
5380 aggregates,
5381 row_batch,
5382 column_lookup,
5383 row_idx,
5384 )? {
5385 Some(b) => Ok(Some(!b)),
5386 None => Ok(None), }
5388 }
5389 llkv_expr::expr::Expr::InList {
5390 expr: test_expr,
5391 list,
5392 negated,
5393 } => {
5394 let test_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5395 test_expr,
5396 aggregates,
5397 Some(row_batch),
5398 Some(column_lookup),
5399 row_idx,
5400 )?;
5401
5402 if matches!(test_val, PlanValue::Null) {
5405 return Ok(None);
5406 }
5407
5408 let mut found = false;
5409 let mut has_null = false;
5410
5411 for list_item in list {
5412 let list_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5413 list_item,
5414 aggregates,
5415 Some(row_batch),
5416 Some(column_lookup),
5417 row_idx,
5418 )?;
5419
5420 if matches!(list_val, PlanValue::Null) {
5422 has_null = true;
5423 continue;
5424 }
5425
5426 let matches = match (&test_val, &list_val) {
5428 (PlanValue::Integer(a), PlanValue::Integer(b)) => a == b,
5429 (PlanValue::Float(a), PlanValue::Float(b)) => a == b,
5430 (PlanValue::Integer(a), PlanValue::Float(b)) => (*a as f64) == *b,
5431 (PlanValue::Float(a), PlanValue::Integer(b)) => *a == (*b as f64),
5432 (PlanValue::String(a), PlanValue::String(b)) => a == b,
5433 _ => false,
5434 };
5435
5436 if matches {
5437 found = true;
5438 break;
5439 }
5440 }
5441
5442 if *negated {
5446 Ok(if found {
5448 Some(false)
5449 } else if has_null {
5450 None } else {
5452 Some(true)
5453 })
5454 } else {
5455 Ok(if found {
5457 Some(true)
5458 } else if has_null {
5459 None } else {
5461 Some(false)
5462 })
5463 }
5464 }
5465 llkv_expr::expr::Expr::IsNull { expr, negated } => {
5466 let val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5468 expr,
5469 aggregates,
5470 Some(row_batch),
5471 Some(column_lookup),
5472 row_idx,
5473 )?;
5474
5475 let is_null = matches!(val, PlanValue::Null);
5479 Ok(Some(if *negated { !is_null } else { is_null }))
5480 }
5481 llkv_expr::expr::Expr::Literal(val) => Ok(Some(*val)),
5482 llkv_expr::expr::Expr::And(exprs) => {
5483 let mut has_null = false;
5485 for e in exprs {
5486 match Self::evaluate_having_expr(
5487 e,
5488 aggregates,
5489 row_batch,
5490 column_lookup,
5491 row_idx,
5492 )? {
5493 Some(false) => return Ok(Some(false)), None => has_null = true,
5495 Some(true) => {} }
5497 }
5498 Ok(if has_null { None } else { Some(true) })
5499 }
5500 llkv_expr::expr::Expr::Or(exprs) => {
5501 let mut has_null = false;
5503 for e in exprs {
5504 match Self::evaluate_having_expr(
5505 e,
5506 aggregates,
5507 row_batch,
5508 column_lookup,
5509 row_idx,
5510 )? {
5511 Some(true) => return Ok(Some(true)), None => has_null = true,
5513 Some(false) => {} }
5515 }
5516 Ok(if has_null { None } else { Some(false) })
5517 }
5518 llkv_expr::expr::Expr::Pred(filter) => {
5519 use llkv_expr::expr::Operator;
5522
5523 let col_name = &filter.field_id;
5524 let col_idx = column_lookup
5525 .get(&col_name.to_ascii_lowercase())
5526 .ok_or_else(|| {
5527 Error::InvalidArgumentError(format!(
5528 "column '{}' not found in HAVING context",
5529 col_name
5530 ))
5531 })?;
5532
5533 let value = llkv_plan::plan_value_from_array(row_batch.column(*col_idx), row_idx)?;
5534
5535 match &filter.op {
5536 Operator::IsNull => Ok(Some(matches!(value, PlanValue::Null))),
5537 Operator::IsNotNull => Ok(Some(!matches!(value, PlanValue::Null))),
5538 Operator::Equals(expected) => {
5539 if matches!(value, PlanValue::Null) {
5541 return Ok(None);
5542 }
5543 let expected_value = llkv_plan::plan_value_from_literal(expected)?;
5545 if matches!(expected_value, PlanValue::Null) {
5546 return Ok(None);
5547 }
5548 Ok(Some(value == expected_value))
5549 }
5550 Operator::GreaterThan(expected) => {
5551 evaluate_ordering_predicate(&value, expected, |ordering| {
5552 ordering == std::cmp::Ordering::Greater
5553 })
5554 }
5555 Operator::GreaterThanOrEquals(expected) => {
5556 evaluate_ordering_predicate(&value, expected, |ordering| {
5557 ordering == std::cmp::Ordering::Greater
5558 || ordering == std::cmp::Ordering::Equal
5559 })
5560 }
5561 Operator::LessThan(expected) => {
5562 evaluate_ordering_predicate(&value, expected, |ordering| {
5563 ordering == std::cmp::Ordering::Less
5564 })
5565 }
5566 Operator::LessThanOrEquals(expected) => {
5567 evaluate_ordering_predicate(&value, expected, |ordering| {
5568 ordering == std::cmp::Ordering::Less
5569 || ordering == std::cmp::Ordering::Equal
5570 })
5571 }
5572 _ => {
5573 Err(Error::InvalidArgumentError(format!(
5576 "Operator {:?} not supported for column predicates in HAVING clause",
5577 filter.op
5578 )))
5579 }
5580 }
5581 }
5582 llkv_expr::expr::Expr::Exists(_) => Err(Error::InvalidArgumentError(
5583 "EXISTS subqueries not supported in HAVING clause".into(),
5584 )),
5585 }
5586 }
5587
5588 fn evaluate_expr_with_plan_value_aggregates_and_row(
5589 expr: &ScalarExpr<String>,
5590 aggregates: &FxHashMap<String, PlanValue>,
5591 row_batch: Option<&RecordBatch>,
5592 column_lookup: Option<&FxHashMap<String, usize>>,
5593 row_idx: usize,
5594 ) -> ExecutorResult<PlanValue> {
5595 use llkv_expr::expr::BinaryOp;
5596 use llkv_expr::literal::Literal;
5597
5598 match expr {
5599 ScalarExpr::Literal(Literal::Integer(v)) => Ok(PlanValue::Integer(*v as i64)),
5600 ScalarExpr::Literal(Literal::Float(v)) => Ok(PlanValue::Float(*v)),
5601 ScalarExpr::Literal(Literal::Boolean(v)) => {
5602 Ok(PlanValue::Integer(if *v { 1 } else { 0 }))
5603 }
5604 ScalarExpr::Literal(Literal::String(s)) => Ok(PlanValue::String(s.clone())),
5605 ScalarExpr::Literal(Literal::Null) => Ok(PlanValue::Null),
5606 ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
5607 "Struct literals not supported in aggregate expressions".into(),
5608 )),
5609 ScalarExpr::Column(col_name) => {
5610 if let (Some(batch), Some(lookup)) = (row_batch, column_lookup) {
5612 let col_idx = lookup.get(&col_name.to_ascii_lowercase()).ok_or_else(|| {
5613 Error::InvalidArgumentError(format!("column '{}' not found", col_name))
5614 })?;
5615 llkv_plan::plan_value_from_array(batch.column(*col_idx), row_idx)
5616 } else {
5617 Err(Error::InvalidArgumentError(
5618 "Column references not supported in aggregate-only expressions".into(),
5619 ))
5620 }
5621 }
5622 ScalarExpr::Compare { left, op, right } => {
5623 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5625 left,
5626 aggregates,
5627 row_batch,
5628 column_lookup,
5629 row_idx,
5630 )?;
5631 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5632 right,
5633 aggregates,
5634 row_batch,
5635 column_lookup,
5636 row_idx,
5637 )?;
5638
5639 if matches!(left_val, PlanValue::Null) || matches!(right_val, PlanValue::Null) {
5641 return Ok(PlanValue::Null);
5642 }
5643
5644 let (left_val, right_val) = match (&left_val, &right_val) {
5646 (PlanValue::Integer(i), PlanValue::Float(_)) => {
5647 (PlanValue::Float(*i as f64), right_val)
5648 }
5649 (PlanValue::Float(_), PlanValue::Integer(i)) => {
5650 (left_val, PlanValue::Float(*i as f64))
5651 }
5652 _ => (left_val, right_val),
5653 };
5654
5655 let result = match (&left_val, &right_val) {
5657 (PlanValue::Integer(l), PlanValue::Integer(r)) => {
5658 use llkv_expr::expr::CompareOp;
5659 match op {
5660 CompareOp::Eq => l == r,
5661 CompareOp::NotEq => l != r,
5662 CompareOp::Lt => l < r,
5663 CompareOp::LtEq => l <= r,
5664 CompareOp::Gt => l > r,
5665 CompareOp::GtEq => l >= r,
5666 }
5667 }
5668 (PlanValue::Float(l), PlanValue::Float(r)) => {
5669 use llkv_expr::expr::CompareOp;
5670 match op {
5671 CompareOp::Eq => l == r,
5672 CompareOp::NotEq => l != r,
5673 CompareOp::Lt => l < r,
5674 CompareOp::LtEq => l <= r,
5675 CompareOp::Gt => l > r,
5676 CompareOp::GtEq => l >= r,
5677 }
5678 }
5679 (PlanValue::String(l), PlanValue::String(r)) => {
5680 use llkv_expr::expr::CompareOp;
5681 match op {
5682 CompareOp::Eq => l == r,
5683 CompareOp::NotEq => l != r,
5684 CompareOp::Lt => l < r,
5685 CompareOp::LtEq => l <= r,
5686 CompareOp::Gt => l > r,
5687 CompareOp::GtEq => l >= r,
5688 }
5689 }
5690 _ => false,
5691 };
5692
5693 Ok(PlanValue::Integer(if result { 1 } else { 0 }))
5695 }
5696 ScalarExpr::Not(inner) => {
5697 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5698 inner,
5699 aggregates,
5700 row_batch,
5701 column_lookup,
5702 row_idx,
5703 )?;
5704 match value {
5705 PlanValue::Integer(v) => Ok(PlanValue::Integer(if v != 0 { 0 } else { 1 })),
5706 PlanValue::Float(v) => Ok(PlanValue::Integer(if v != 0.0 { 0 } else { 1 })),
5707 PlanValue::Null => Ok(PlanValue::Null),
5708 other => Err(Error::InvalidArgumentError(format!(
5709 "logical NOT does not support value {other:?}"
5710 ))),
5711 }
5712 }
5713 ScalarExpr::IsNull { expr, negated } => {
5714 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5715 expr,
5716 aggregates,
5717 row_batch,
5718 column_lookup,
5719 row_idx,
5720 )?;
5721 let is_null = matches!(value, PlanValue::Null);
5722 let condition = if is_null { !negated } else { *negated };
5723 Ok(PlanValue::Integer(if condition { 1 } else { 0 }))
5724 }
5725 ScalarExpr::Aggregate(agg) => {
5726 let key = format!("{:?}", agg);
5727 aggregates
5728 .get(&key)
5729 .cloned()
5730 .ok_or_else(|| Error::Internal(format!("Aggregate value not found: {}", key)))
5731 }
5732 ScalarExpr::Binary { left, op, right } => {
5733 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5734 left,
5735 aggregates,
5736 row_batch,
5737 column_lookup,
5738 row_idx,
5739 )?;
5740 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5741 right,
5742 aggregates,
5743 row_batch,
5744 column_lookup,
5745 row_idx,
5746 )?;
5747
5748 match op {
5749 BinaryOp::Add
5750 | BinaryOp::Subtract
5751 | BinaryOp::Multiply
5752 | BinaryOp::Divide
5753 | BinaryOp::Modulo => {
5754 if matches!(&left_val, PlanValue::Null)
5755 || matches!(&right_val, PlanValue::Null)
5756 {
5757 return Ok(PlanValue::Null);
5758 }
5759
5760 if matches!(op, BinaryOp::Divide)
5761 && let (PlanValue::Integer(lhs), PlanValue::Integer(rhs)) =
5762 (&left_val, &right_val)
5763 {
5764 if *rhs == 0 {
5765 return Ok(PlanValue::Null);
5766 }
5767
5768 if *lhs == i64::MIN && *rhs == -1 {
5769 return Ok(PlanValue::Float((*lhs as f64) / (*rhs as f64)));
5770 }
5771
5772 return Ok(PlanValue::Integer(lhs / rhs));
5773 }
5774
5775 let left_is_float = matches!(&left_val, PlanValue::Float(_));
5776 let right_is_float = matches!(&right_val, PlanValue::Float(_));
5777
5778 let left_num = match left_val {
5779 PlanValue::Integer(i) => i as f64,
5780 PlanValue::Float(f) => f,
5781 other => {
5782 return Err(Error::InvalidArgumentError(format!(
5783 "Non-numeric value {:?} in binary operation",
5784 other
5785 )));
5786 }
5787 };
5788 let right_num = match right_val {
5789 PlanValue::Integer(i) => i as f64,
5790 PlanValue::Float(f) => f,
5791 other => {
5792 return Err(Error::InvalidArgumentError(format!(
5793 "Non-numeric value {:?} in binary operation",
5794 other
5795 )));
5796 }
5797 };
5798
5799 let result = match op {
5800 BinaryOp::Add => left_num + right_num,
5801 BinaryOp::Subtract => left_num - right_num,
5802 BinaryOp::Multiply => left_num * right_num,
5803 BinaryOp::Divide => {
5804 if right_num == 0.0 {
5805 return Ok(PlanValue::Null);
5806 }
5807 left_num / right_num
5808 }
5809 BinaryOp::Modulo => {
5810 if right_num == 0.0 {
5811 return Ok(PlanValue::Null);
5812 }
5813 left_num % right_num
5814 }
5815 BinaryOp::And
5816 | BinaryOp::Or
5817 | BinaryOp::BitwiseShiftLeft
5818 | BinaryOp::BitwiseShiftRight => unreachable!(),
5819 };
5820
5821 if matches!(op, BinaryOp::Divide) {
5822 return Ok(PlanValue::Float(result));
5823 }
5824
5825 if left_is_float || right_is_float {
5826 Ok(PlanValue::Float(result))
5827 } else {
5828 Ok(PlanValue::Integer(result as i64))
5829 }
5830 }
5831 BinaryOp::And => Ok(evaluate_plan_value_logical_and(left_val, right_val)),
5832 BinaryOp::Or => Ok(evaluate_plan_value_logical_or(left_val, right_val)),
5833 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
5834 if matches!(&left_val, PlanValue::Null)
5835 || matches!(&right_val, PlanValue::Null)
5836 {
5837 return Ok(PlanValue::Null);
5838 }
5839
5840 let lhs = match left_val {
5842 PlanValue::Integer(i) => i,
5843 PlanValue::Float(f) => f as i64,
5844 other => {
5845 return Err(Error::InvalidArgumentError(format!(
5846 "Non-numeric value {:?} in bitwise shift operation",
5847 other
5848 )));
5849 }
5850 };
5851 let rhs = match right_val {
5852 PlanValue::Integer(i) => i,
5853 PlanValue::Float(f) => f as i64,
5854 other => {
5855 return Err(Error::InvalidArgumentError(format!(
5856 "Non-numeric value {:?} in bitwise shift operation",
5857 other
5858 )));
5859 }
5860 };
5861
5862 let result = match op {
5864 BinaryOp::BitwiseShiftLeft => lhs.wrapping_shl(rhs as u32),
5865 BinaryOp::BitwiseShiftRight => lhs.wrapping_shr(rhs as u32),
5866 _ => unreachable!(),
5867 };
5868
5869 Ok(PlanValue::Integer(result))
5870 }
5871 }
5872 }
5873 ScalarExpr::Cast { expr, data_type } => {
5874 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5876 expr,
5877 aggregates,
5878 row_batch,
5879 column_lookup,
5880 row_idx,
5881 )?;
5882
5883 if matches!(value, PlanValue::Null) {
5885 return Ok(PlanValue::Null);
5886 }
5887
5888 match data_type {
5890 DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => {
5891 match value {
5892 PlanValue::Integer(i) => Ok(PlanValue::Integer(i)),
5893 PlanValue::Float(f) => Ok(PlanValue::Integer(f as i64)),
5894 PlanValue::String(s) => {
5895 s.parse::<i64>().map(PlanValue::Integer).map_err(|_| {
5896 Error::InvalidArgumentError(format!(
5897 "Cannot cast '{}' to integer",
5898 s
5899 ))
5900 })
5901 }
5902 _ => Err(Error::InvalidArgumentError(format!(
5903 "Cannot cast {:?} to integer",
5904 value
5905 ))),
5906 }
5907 }
5908 DataType::Float64 | DataType::Float32 => match value {
5909 PlanValue::Integer(i) => Ok(PlanValue::Float(i as f64)),
5910 PlanValue::Float(f) => Ok(PlanValue::Float(f)),
5911 PlanValue::String(s) => {
5912 s.parse::<f64>().map(PlanValue::Float).map_err(|_| {
5913 Error::InvalidArgumentError(format!("Cannot cast '{}' to float", s))
5914 })
5915 }
5916 _ => Err(Error::InvalidArgumentError(format!(
5917 "Cannot cast {:?} to float",
5918 value
5919 ))),
5920 },
5921 DataType::Utf8 | DataType::LargeUtf8 => match value {
5922 PlanValue::String(s) => Ok(PlanValue::String(s)),
5923 PlanValue::Integer(i) => Ok(PlanValue::String(i.to_string())),
5924 PlanValue::Float(f) => Ok(PlanValue::String(f.to_string())),
5925 _ => Err(Error::InvalidArgumentError(format!(
5926 "Cannot cast {:?} to string",
5927 value
5928 ))),
5929 },
5930 _ => Err(Error::InvalidArgumentError(format!(
5931 "CAST to {:?} not supported in aggregate expressions",
5932 data_type
5933 ))),
5934 }
5935 }
5936 ScalarExpr::Case {
5937 operand,
5938 branches,
5939 else_expr,
5940 } => {
5941 let operand_value = if let Some(op) = operand {
5943 Some(Self::evaluate_expr_with_plan_value_aggregates_and_row(
5944 op,
5945 aggregates,
5946 row_batch,
5947 column_lookup,
5948 row_idx,
5949 )?)
5950 } else {
5951 None
5952 };
5953
5954 for (when_expr, then_expr) in branches {
5956 let matches = if let Some(ref op_val) = operand_value {
5957 let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5959 when_expr,
5960 aggregates,
5961 row_batch,
5962 column_lookup,
5963 row_idx,
5964 )?;
5965 Self::simple_case_branch_matches(op_val, &when_val)
5966 } else {
5967 let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5969 when_expr,
5970 aggregates,
5971 row_batch,
5972 column_lookup,
5973 row_idx,
5974 )?;
5975 match when_val {
5977 PlanValue::Integer(i) => i != 0,
5978 PlanValue::Float(f) => f != 0.0,
5979 PlanValue::Null => false,
5980 _ => false,
5981 }
5982 };
5983
5984 if matches {
5985 return Self::evaluate_expr_with_plan_value_aggregates_and_row(
5986 then_expr,
5987 aggregates,
5988 row_batch,
5989 column_lookup,
5990 row_idx,
5991 );
5992 }
5993 }
5994
5995 if let Some(else_e) = else_expr {
5997 Self::evaluate_expr_with_plan_value_aggregates_and_row(
5998 else_e,
5999 aggregates,
6000 row_batch,
6001 column_lookup,
6002 row_idx,
6003 )
6004 } else {
6005 Ok(PlanValue::Null)
6006 }
6007 }
6008 ScalarExpr::Coalesce(exprs) => {
6009 for expr in exprs {
6011 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6012 expr,
6013 aggregates,
6014 row_batch,
6015 column_lookup,
6016 row_idx,
6017 )?;
6018 if !matches!(value, PlanValue::Null) {
6019 return Ok(value);
6020 }
6021 }
6022 Ok(PlanValue::Null)
6023 }
6024 ScalarExpr::Random => Ok(PlanValue::Float(rand::random::<f64>())),
6025 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
6026 "GetField not supported in aggregate expressions".into(),
6027 )),
6028 ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
6029 "Scalar subqueries not supported in aggregate expressions".into(),
6030 )),
6031 }
6032 }
6033
6034 fn simple_case_branch_matches(operand: &PlanValue, candidate: &PlanValue) -> bool {
6035 if matches!(operand, PlanValue::Null) || matches!(candidate, PlanValue::Null) {
6036 return false;
6037 }
6038
6039 match (operand, candidate) {
6040 (PlanValue::Integer(left), PlanValue::Integer(right)) => left == right,
6041 (PlanValue::Integer(left), PlanValue::Float(right)) => (*left as f64) == *right,
6042 (PlanValue::Float(left), PlanValue::Integer(right)) => *left == (*right as f64),
6043 (PlanValue::Float(left), PlanValue::Float(right)) => left == right,
6044 (PlanValue::String(left), PlanValue::String(right)) => left == right,
6045 (PlanValue::Struct(left), PlanValue::Struct(right)) => left == right,
6046 _ => operand == candidate,
6047 }
6048 }
6049
6050 fn evaluate_expr_with_aggregates(
6051 expr: &ScalarExpr<String>,
6052 aggregates: &FxHashMap<String, AggregateValue>,
6053 ) -> ExecutorResult<Option<i64>> {
6054 use llkv_expr::expr::BinaryOp;
6055 use llkv_expr::literal::Literal;
6056
6057 match expr {
6058 ScalarExpr::Literal(Literal::Integer(v)) => Ok(Some(*v as i64)),
6059 ScalarExpr::Literal(Literal::Float(v)) => Ok(Some(*v as i64)),
6060 ScalarExpr::Literal(Literal::Boolean(v)) => Ok(Some(if *v { 1 } else { 0 })),
6061 ScalarExpr::Literal(Literal::String(_)) => Err(Error::InvalidArgumentError(
6062 "String literals not supported in aggregate expressions".into(),
6063 )),
6064 ScalarExpr::Literal(Literal::Null) => Ok(None),
6065 ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
6066 "Struct literals not supported in aggregate expressions".into(),
6067 )),
6068 ScalarExpr::Column(_) => Err(Error::InvalidArgumentError(
6069 "Column references not supported in aggregate-only expressions".into(),
6070 )),
6071 ScalarExpr::Compare { .. } => Err(Error::InvalidArgumentError(
6072 "Comparisons not supported in aggregate-only expressions".into(),
6073 )),
6074 ScalarExpr::Aggregate(agg) => {
6075 let key = format!("{:?}", agg);
6076 let value = aggregates.get(&key).ok_or_else(|| {
6077 Error::Internal(format!("Aggregate value not found for key: {}", key))
6078 })?;
6079 Ok(value.as_i64())
6080 }
6081 ScalarExpr::Not(inner) => {
6082 let value = Self::evaluate_expr_with_aggregates(inner, aggregates)?;
6083 Ok(value.map(|v| if v != 0 { 0 } else { 1 }))
6084 }
6085 ScalarExpr::IsNull { expr, negated } => {
6086 let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
6087 let is_null = value.is_none();
6088 Ok(Some(if is_null != *negated { 1 } else { 0 }))
6089 }
6090 ScalarExpr::Binary { left, op, right } => {
6091 let left_val = Self::evaluate_expr_with_aggregates(left, aggregates)?;
6092 let right_val = Self::evaluate_expr_with_aggregates(right, aggregates)?;
6093
6094 match op {
6095 BinaryOp::Add
6096 | BinaryOp::Subtract
6097 | BinaryOp::Multiply
6098 | BinaryOp::Divide
6099 | BinaryOp::Modulo => match (left_val, right_val) {
6100 (Some(lhs), Some(rhs)) => {
6101 let result = match op {
6102 BinaryOp::Add => lhs.checked_add(rhs),
6103 BinaryOp::Subtract => lhs.checked_sub(rhs),
6104 BinaryOp::Multiply => lhs.checked_mul(rhs),
6105 BinaryOp::Divide => {
6106 if rhs == 0 {
6107 return Ok(None);
6108 }
6109 lhs.checked_div(rhs)
6110 }
6111 BinaryOp::Modulo => {
6112 if rhs == 0 {
6113 return Ok(None);
6114 }
6115 lhs.checked_rem(rhs)
6116 }
6117 BinaryOp::And
6118 | BinaryOp::Or
6119 | BinaryOp::BitwiseShiftLeft
6120 | BinaryOp::BitwiseShiftRight => unreachable!(),
6121 };
6122
6123 result.map(Some).ok_or_else(|| {
6124 Error::InvalidArgumentError(
6125 "Arithmetic overflow in expression".into(),
6126 )
6127 })
6128 }
6129 _ => Ok(None),
6130 },
6131 BinaryOp::And => Ok(evaluate_option_logical_and(left_val, right_val)),
6132 BinaryOp::Or => Ok(evaluate_option_logical_or(left_val, right_val)),
6133 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
6134 match (left_val, right_val) {
6135 (Some(lhs), Some(rhs)) => {
6136 let result = match op {
6137 BinaryOp::BitwiseShiftLeft => {
6138 Some(lhs.wrapping_shl(rhs as u32))
6139 }
6140 BinaryOp::BitwiseShiftRight => {
6141 Some(lhs.wrapping_shr(rhs as u32))
6142 }
6143 _ => unreachable!(),
6144 };
6145 Ok(result)
6146 }
6147 _ => Ok(None),
6148 }
6149 }
6150 }
6151 }
6152 ScalarExpr::Cast { expr, data_type } => {
6153 let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
6154 match value {
6155 Some(v) => Self::cast_aggregate_value(v, data_type).map(Some),
6156 None => Ok(None),
6157 }
6158 }
6159 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
6160 "GetField not supported in aggregate-only expressions".into(),
6161 )),
6162 ScalarExpr::Case { .. } => Err(Error::InvalidArgumentError(
6163 "CASE not supported in aggregate-only expressions".into(),
6164 )),
6165 ScalarExpr::Coalesce(_) => Err(Error::InvalidArgumentError(
6166 "COALESCE not supported in aggregate-only expressions".into(),
6167 )),
6168 ScalarExpr::Random => Ok(Some((rand::random::<f64>() * (i64::MAX as f64)) as i64)),
6169 ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
6170 "Scalar subqueries not supported in aggregate-only expressions".into(),
6171 )),
6172 }
6173 }
6174
6175 fn cast_aggregate_value(value: i64, data_type: &DataType) -> ExecutorResult<i64> {
6176 fn ensure_range(value: i64, min: i64, max: i64, ty: &DataType) -> ExecutorResult<i64> {
6177 if value < min || value > max {
6178 return Err(Error::InvalidArgumentError(format!(
6179 "value {} out of range for CAST target {:?}",
6180 value, ty
6181 )));
6182 }
6183 Ok(value)
6184 }
6185
6186 match data_type {
6187 DataType::Int8 => ensure_range(value, i8::MIN as i64, i8::MAX as i64, data_type),
6188 DataType::Int16 => ensure_range(value, i16::MIN as i64, i16::MAX as i64, data_type),
6189 DataType::Int32 => ensure_range(value, i32::MIN as i64, i32::MAX as i64, data_type),
6190 DataType::Int64 => Ok(value),
6191 DataType::UInt8 => ensure_range(value, 0, u8::MAX as i64, data_type),
6192 DataType::UInt16 => ensure_range(value, 0, u16::MAX as i64, data_type),
6193 DataType::UInt32 => ensure_range(value, 0, u32::MAX as i64, data_type),
6194 DataType::UInt64 => {
6195 if value < 0 {
6196 return Err(Error::InvalidArgumentError(format!(
6197 "value {} out of range for CAST target {:?}",
6198 value, data_type
6199 )));
6200 }
6201 Ok(value)
6202 }
6203 DataType::Float32 | DataType::Float64 => Ok(value),
6204 DataType::Boolean => Ok(if value == 0 { 0 } else { 1 }),
6205 DataType::Null => Err(Error::InvalidArgumentError(
6206 "CAST to NULL is not supported in aggregate-only expressions".into(),
6207 )),
6208 _ => Err(Error::InvalidArgumentError(format!(
6209 "CAST to {:?} is not supported in aggregate-only expressions",
6210 data_type
6211 ))),
6212 }
6213 }
6214}
6215
6216struct CrossProductExpressionContext {
6217 schema: Arc<ExecutorSchema>,
6218 field_id_to_index: FxHashMap<FieldId, usize>,
6219 numeric_cache: FxHashMap<FieldId, NumericArray>,
6220 column_cache: FxHashMap<FieldId, ColumnAccessor>,
6221 next_field_id: FieldId,
6222}
6223
6224#[derive(Clone)]
6225enum ColumnAccessor {
6226 Int64(Arc<Int64Array>),
6227 Float64(Arc<Float64Array>),
6228 Boolean(Arc<BooleanArray>),
6229 Utf8(Arc<StringArray>),
6230 Null(usize),
6231}
6232
6233impl ColumnAccessor {
6234 fn from_array(array: &ArrayRef) -> ExecutorResult<Self> {
6235 match array.data_type() {
6236 DataType::Int64 => {
6237 let typed = array
6238 .as_any()
6239 .downcast_ref::<Int64Array>()
6240 .ok_or_else(|| Error::Internal("expected Int64 array".into()))?
6241 .clone();
6242 Ok(Self::Int64(Arc::new(typed)))
6243 }
6244 DataType::Float64 => {
6245 let typed = array
6246 .as_any()
6247 .downcast_ref::<Float64Array>()
6248 .ok_or_else(|| Error::Internal("expected Float64 array".into()))?
6249 .clone();
6250 Ok(Self::Float64(Arc::new(typed)))
6251 }
6252 DataType::Boolean => {
6253 let typed = array
6254 .as_any()
6255 .downcast_ref::<BooleanArray>()
6256 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
6257 .clone();
6258 Ok(Self::Boolean(Arc::new(typed)))
6259 }
6260 DataType::Utf8 => {
6261 let typed = array
6262 .as_any()
6263 .downcast_ref::<StringArray>()
6264 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
6265 .clone();
6266 Ok(Self::Utf8(Arc::new(typed)))
6267 }
6268 DataType::Null => Ok(Self::Null(array.len())),
6269 other => Err(Error::InvalidArgumentError(format!(
6270 "unsupported column type {:?} in cross product filter",
6271 other
6272 ))),
6273 }
6274 }
6275
6276 fn len(&self) -> usize {
6277 match self {
6278 ColumnAccessor::Int64(array) => array.len(),
6279 ColumnAccessor::Float64(array) => array.len(),
6280 ColumnAccessor::Boolean(array) => array.len(),
6281 ColumnAccessor::Utf8(array) => array.len(),
6282 ColumnAccessor::Null(len) => *len,
6283 }
6284 }
6285
6286 fn is_null(&self, idx: usize) -> bool {
6287 match self {
6288 ColumnAccessor::Int64(array) => array.is_null(idx),
6289 ColumnAccessor::Float64(array) => array.is_null(idx),
6290 ColumnAccessor::Boolean(array) => array.is_null(idx),
6291 ColumnAccessor::Utf8(array) => array.is_null(idx),
6292 ColumnAccessor::Null(_) => true,
6293 }
6294 }
6295
6296 fn literal_at(&self, idx: usize) -> ExecutorResult<Literal> {
6297 if self.is_null(idx) {
6298 return Ok(Literal::Null);
6299 }
6300 match self {
6301 ColumnAccessor::Int64(array) => Ok(Literal::Integer(array.value(idx) as i128)),
6302 ColumnAccessor::Float64(array) => Ok(Literal::Float(array.value(idx))),
6303 ColumnAccessor::Boolean(array) => Ok(Literal::Boolean(array.value(idx))),
6304 ColumnAccessor::Utf8(array) => Ok(Literal::String(array.value(idx).to_string())),
6305 ColumnAccessor::Null(_) => Ok(Literal::Null),
6306 }
6307 }
6308
6309 fn as_array_ref(&self) -> ArrayRef {
6310 match self {
6311 ColumnAccessor::Int64(array) => Arc::clone(array) as ArrayRef,
6312 ColumnAccessor::Float64(array) => Arc::clone(array) as ArrayRef,
6313 ColumnAccessor::Boolean(array) => Arc::clone(array) as ArrayRef,
6314 ColumnAccessor::Utf8(array) => Arc::clone(array) as ArrayRef,
6315 ColumnAccessor::Null(len) => new_null_array(&DataType::Null, *len),
6316 }
6317 }
6318}
6319
6320#[derive(Clone)]
6321enum ValueArray {
6322 Numeric(NumericArray),
6323 Boolean(Arc<BooleanArray>),
6324 Utf8(Arc<StringArray>),
6325 Null(usize),
6326}
6327
6328impl ValueArray {
6329 fn from_array(array: ArrayRef) -> ExecutorResult<Self> {
6330 match array.data_type() {
6331 DataType::Boolean => {
6332 let typed = array
6333 .as_any()
6334 .downcast_ref::<BooleanArray>()
6335 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
6336 .clone();
6337 Ok(Self::Boolean(Arc::new(typed)))
6338 }
6339 DataType::Utf8 => {
6340 let typed = array
6341 .as_any()
6342 .downcast_ref::<StringArray>()
6343 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
6344 .clone();
6345 Ok(Self::Utf8(Arc::new(typed)))
6346 }
6347 DataType::Null => Ok(Self::Null(array.len())),
6348 DataType::Int8
6349 | DataType::Int16
6350 | DataType::Int32
6351 | DataType::Int64
6352 | DataType::UInt8
6353 | DataType::UInt16
6354 | DataType::UInt32
6355 | DataType::UInt64
6356 | DataType::Float32
6357 | DataType::Float64 => {
6358 let numeric = NumericArray::try_from_arrow(&array)?;
6359 Ok(Self::Numeric(numeric))
6360 }
6361 other => Err(Error::InvalidArgumentError(format!(
6362 "unsupported data type {:?} in cross product expression",
6363 other
6364 ))),
6365 }
6366 }
6367
6368 fn len(&self) -> usize {
6369 match self {
6370 ValueArray::Numeric(array) => array.len(),
6371 ValueArray::Boolean(array) => array.len(),
6372 ValueArray::Utf8(array) => array.len(),
6373 ValueArray::Null(len) => *len,
6374 }
6375 }
6376}
6377
/// Three-valued (Kleene) SQL `AND`, where `None` is UNKNOWN.
///
/// FALSE on either side wins. Otherwise the result is TRUE only when both
/// sides are known.
fn truth_and(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(false) || rhs == Some(false) {
        return Some(false);
    }
    // Neither side is FALSE, so any known pair must be (TRUE, TRUE).
    lhs.zip(rhs).map(|_| true)
}
6385
/// Three-valued (Kleene) SQL `OR`, where `None` is UNKNOWN.
///
/// TRUE on either side wins. Otherwise the result is FALSE only when both
/// sides are known.
fn truth_or(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(true) || rhs == Some(true) {
        return Some(true);
    }
    // Neither side is TRUE, so any known pair must be (FALSE, FALSE).
    lhs.zip(rhs).map(|_| false)
}
6393
/// Three-valued SQL `NOT`: negates known values and leaves UNKNOWN (`None`) as is.
fn truth_not(value: Option<bool>) -> Option<bool> {
    value.map(|known| !known)
}
6401
6402fn compare_bool(op: CompareOp, lhs: bool, rhs: bool) -> bool {
6403 let l = lhs as u8;
6404 let r = rhs as u8;
6405 match op {
6406 CompareOp::Eq => lhs == rhs,
6407 CompareOp::NotEq => lhs != rhs,
6408 CompareOp::Lt => l < r,
6409 CompareOp::LtEq => l <= r,
6410 CompareOp::Gt => l > r,
6411 CompareOp::GtEq => l >= r,
6412 }
6413}
6414
6415fn compare_str(op: CompareOp, lhs: &str, rhs: &str) -> bool {
6416 match op {
6417 CompareOp::Eq => lhs == rhs,
6418 CompareOp::NotEq => lhs != rhs,
6419 CompareOp::Lt => lhs < rhs,
6420 CompareOp::LtEq => lhs <= rhs,
6421 CompareOp::Gt => lhs > rhs,
6422 CompareOp::GtEq => lhs >= rhs,
6423 }
6424}
6425
/// Produces the three-valued result of `x [NOT] IN (list)`.
///
/// A match decides the result. With no match, a NULL in the list makes it
/// UNKNOWN (`None`). Otherwise `IN` is FALSE and `NOT IN` is TRUE.
fn finalize_in_list_result(has_match: bool, saw_null: bool, negated: bool) -> Option<bool> {
    match (has_match, saw_null) {
        (true, _) => Some(!negated),
        (false, true) => None,
        (false, false) => Some(negated),
    }
}
6437
6438fn literal_to_constant_array(literal: &Literal, len: usize) -> ExecutorResult<ArrayRef> {
6439 match literal {
6440 Literal::Integer(v) => {
6441 let value = i64::try_from(*v).unwrap_or(0);
6442 let values = vec![value; len];
6443 Ok(Arc::new(Int64Array::from(values)) as ArrayRef)
6444 }
6445 Literal::Float(v) => {
6446 let values = vec![*v; len];
6447 Ok(Arc::new(Float64Array::from(values)) as ArrayRef)
6448 }
6449 Literal::Boolean(v) => {
6450 let values = vec![Some(*v); len];
6451 Ok(Arc::new(BooleanArray::from(values)) as ArrayRef)
6452 }
6453 Literal::String(v) => {
6454 let values: Vec<Option<String>> = (0..len).map(|_| Some(v.clone())).collect();
6455 Ok(Arc::new(StringArray::from(values)) as ArrayRef)
6456 }
6457 Literal::Null => Ok(new_null_array(&DataType::Null, len)),
6458 Literal::Struct(_) => Err(Error::InvalidArgumentError(
6459 "struct literals are not supported in cross product filters".into(),
6460 )),
6461 }
6462}
6463
6464impl CrossProductExpressionContext {
6465 fn new(schema: &Schema, lookup: FxHashMap<String, usize>) -> ExecutorResult<Self> {
6466 let mut columns = Vec::with_capacity(schema.fields().len());
6467 let mut field_id_to_index = FxHashMap::default();
6468 let mut next_field_id: FieldId = 1;
6469
6470 for (idx, field) in schema.fields().iter().enumerate() {
6471 if next_field_id == u32::MAX {
6472 return Err(Error::Internal(
6473 "cross product projection exhausted FieldId space".into(),
6474 ));
6475 }
6476
6477 let executor_column = ExecutorColumn {
6478 name: field.name().clone(),
6479 data_type: field.data_type().clone(),
6480 nullable: field.is_nullable(),
6481 primary_key: false,
6482 unique: false,
6483 field_id: next_field_id,
6484 check_expr: None,
6485 };
6486 let field_id = next_field_id;
6487 next_field_id = next_field_id.saturating_add(1);
6488
6489 columns.push(executor_column);
6490 field_id_to_index.insert(field_id, idx);
6491 }
6492
6493 Ok(Self {
6494 schema: Arc::new(ExecutorSchema { columns, lookup }),
6495 field_id_to_index,
6496 numeric_cache: FxHashMap::default(),
6497 column_cache: FxHashMap::default(),
6498 next_field_id,
6499 })
6500 }
6501
6502 fn schema(&self) -> &ExecutorSchema {
6503 self.schema.as_ref()
6504 }
6505
6506 fn field_id_for_column(&self, name: &str) -> Option<FieldId> {
6507 self.schema.resolve(name).map(|column| column.field_id)
6508 }
6509
6510 fn reset(&mut self) {
6511 self.numeric_cache.clear();
6512 self.column_cache.clear();
6513 }
6514
6515 fn allocate_synthetic_field_id(&mut self) -> ExecutorResult<FieldId> {
6516 if self.next_field_id == FieldId::MAX {
6517 return Err(Error::Internal(
6518 "cross product projection exhausted FieldId space".into(),
6519 ));
6520 }
6521 let field_id = self.next_field_id;
6522 self.next_field_id = self.next_field_id.saturating_add(1);
6523 Ok(field_id)
6524 }
6525
/// Test helper: resolves the column names in `expr` against the cross-product
/// schema, then evaluates the result numerically over `batch`.
#[cfg(test)]
fn evaluate(
    &mut self,
    expr: &ScalarExpr<String>,
    batch: &RecordBatch,
) -> ExecutorResult<ArrayRef> {
    let resolved = translate_scalar(expr, self.schema.as_ref(), |name| {
        Error::InvalidArgumentError(format!(
            "column '{}' not found in cross product result",
            name
        ))
    })?;
    self.evaluate_numeric(&resolved, batch)
}
6541
6542 fn evaluate_predicate_mask(
6543 &mut self,
6544 expr: &LlkvExpr<'static, FieldId>,
6545 batch: &RecordBatch,
6546 mut exists_eval: impl FnMut(
6547 &mut Self,
6548 &llkv_expr::SubqueryExpr,
6549 usize,
6550 &RecordBatch,
6551 ) -> ExecutorResult<Option<bool>>,
6552 ) -> ExecutorResult<BooleanArray> {
6553 let truths = self.evaluate_predicate_truths(expr, batch, &mut exists_eval)?;
6554 let mut builder = BooleanBuilder::with_capacity(truths.len());
6555 for value in truths {
6556 builder.append_value(value.unwrap_or(false));
6557 }
6558 Ok(builder.finish())
6559 }
6560
6561 fn evaluate_predicate_truths(
6562 &mut self,
6563 expr: &LlkvExpr<'static, FieldId>,
6564 batch: &RecordBatch,
6565 exists_eval: &mut impl FnMut(
6566 &mut Self,
6567 &llkv_expr::SubqueryExpr,
6568 usize,
6569 &RecordBatch,
6570 ) -> ExecutorResult<Option<bool>>,
6571 ) -> ExecutorResult<Vec<Option<bool>>> {
6572 match expr {
6573 LlkvExpr::Literal(value) => Ok(vec![Some(*value); batch.num_rows()]),
6574 LlkvExpr::And(children) => {
6575 if children.is_empty() {
6576 return Ok(vec![Some(true); batch.num_rows()]);
6577 }
6578 let mut result =
6579 self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
6580 for child in &children[1..] {
6581 let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
6582 for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
6583 *lhs = truth_and(*lhs, rhs);
6584 }
6585 }
6586 Ok(result)
6587 }
6588 LlkvExpr::Or(children) => {
6589 if children.is_empty() {
6590 return Ok(vec![Some(false); batch.num_rows()]);
6591 }
6592 let mut result =
6593 self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
6594 for child in &children[1..] {
6595 let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
6596 for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
6597 *lhs = truth_or(*lhs, rhs);
6598 }
6599 }
6600 Ok(result)
6601 }
6602 LlkvExpr::Not(inner) => {
6603 let mut values = self.evaluate_predicate_truths(inner, batch, exists_eval)?;
6604 for value in &mut values {
6605 *value = truth_not(*value);
6606 }
6607 Ok(values)
6608 }
6609 LlkvExpr::Pred(filter) => self.evaluate_filter_truths(filter, batch),
6610 LlkvExpr::Compare { left, op, right } => {
6611 self.evaluate_compare_truths(left, *op, right, batch)
6612 }
6613 LlkvExpr::InList {
6614 expr: target,
6615 list,
6616 negated,
6617 } => self.evaluate_in_list_truths(target, list, *negated, batch),
6618 LlkvExpr::IsNull { expr, negated } => {
6619 self.evaluate_is_null_truths(expr, *negated, batch)
6620 }
6621 LlkvExpr::Exists(subquery_expr) => {
6622 let mut values = Vec::with_capacity(batch.num_rows());
6623 for row_idx in 0..batch.num_rows() {
6624 let value = exists_eval(self, subquery_expr, row_idx, batch)?;
6625 values.push(value);
6626 }
6627 Ok(values)
6628 }
6629 }
6630 }
6631
6632 fn evaluate_filter_truths(
6633 &mut self,
6634 filter: &Filter<FieldId>,
6635 batch: &RecordBatch,
6636 ) -> ExecutorResult<Vec<Option<bool>>> {
6637 let accessor = self.column_accessor(filter.field_id, batch)?;
6638 let len = accessor.len();
6639
6640 match &filter.op {
6641 Operator::IsNull => {
6642 let mut out = Vec::with_capacity(len);
6643 for idx in 0..len {
6644 out.push(Some(accessor.is_null(idx)));
6645 }
6646 Ok(out)
6647 }
6648 Operator::IsNotNull => {
6649 let mut out = Vec::with_capacity(len);
6650 for idx in 0..len {
6651 out.push(Some(!accessor.is_null(idx)));
6652 }
6653 Ok(out)
6654 }
6655 _ => match accessor {
6656 ColumnAccessor::Int64(array) => {
6657 let predicate = build_fixed_width_predicate::<Int64Type>(&filter.op)
6658 .map_err(Error::predicate_build)?;
6659 let mut out = Vec::with_capacity(len);
6660 for idx in 0..len {
6661 if array.is_null(idx) {
6662 out.push(None);
6663 } else {
6664 let value = array.value(idx);
6665 out.push(Some(predicate.matches(&value)));
6666 }
6667 }
6668 Ok(out)
6669 }
6670 ColumnAccessor::Float64(array) => {
6671 let predicate = build_fixed_width_predicate::<Float64Type>(&filter.op)
6672 .map_err(Error::predicate_build)?;
6673 let mut out = Vec::with_capacity(len);
6674 for idx in 0..len {
6675 if array.is_null(idx) {
6676 out.push(None);
6677 } else {
6678 let value = array.value(idx);
6679 out.push(Some(predicate.matches(&value)));
6680 }
6681 }
6682 Ok(out)
6683 }
6684 ColumnAccessor::Boolean(array) => {
6685 let predicate =
6686 build_bool_predicate(&filter.op).map_err(Error::predicate_build)?;
6687 let mut out = Vec::with_capacity(len);
6688 for idx in 0..len {
6689 if array.is_null(idx) {
6690 out.push(None);
6691 } else {
6692 let value = array.value(idx);
6693 out.push(Some(predicate.matches(&value)));
6694 }
6695 }
6696 Ok(out)
6697 }
6698 ColumnAccessor::Utf8(array) => {
6699 let predicate =
6700 build_var_width_predicate(&filter.op).map_err(Error::predicate_build)?;
6701 let mut out = Vec::with_capacity(len);
6702 for idx in 0..len {
6703 if array.is_null(idx) {
6704 out.push(None);
6705 } else {
6706 let value = array.value(idx);
6707 out.push(Some(predicate.matches(value)));
6708 }
6709 }
6710 Ok(out)
6711 }
6712 ColumnAccessor::Null(len) => Ok(vec![None; len]),
6713 },
6714 }
6715 }
6716
6717 fn evaluate_compare_truths(
6718 &mut self,
6719 left: &ScalarExpr<FieldId>,
6720 op: CompareOp,
6721 right: &ScalarExpr<FieldId>,
6722 batch: &RecordBatch,
6723 ) -> ExecutorResult<Vec<Option<bool>>> {
6724 let left_values = self.materialize_value_array(left, batch)?;
6725 let right_values = self.materialize_value_array(right, batch)?;
6726
6727 if left_values.len() != right_values.len() {
6728 return Err(Error::Internal(
6729 "mismatched compare operand lengths in cross product filter".into(),
6730 ));
6731 }
6732
6733 let len = left_values.len();
6734 match (&left_values, &right_values) {
6735 (ValueArray::Null(_), _) | (_, ValueArray::Null(_)) => Ok(vec![None; len]),
6736 (ValueArray::Numeric(lhs), ValueArray::Numeric(rhs)) => {
6737 let mut out = Vec::with_capacity(len);
6738 for idx in 0..len {
6739 match (lhs.value(idx), rhs.value(idx)) {
6740 (Some(lv), Some(rv)) => out.push(Some(NumericKernels::compare(op, lv, rv))),
6741 _ => out.push(None),
6742 }
6743 }
6744 Ok(out)
6745 }
6746 (ValueArray::Boolean(lhs), ValueArray::Boolean(rhs)) => {
6747 let lhs = lhs.as_ref();
6748 let rhs = rhs.as_ref();
6749 let mut out = Vec::with_capacity(len);
6750 for idx in 0..len {
6751 if lhs.is_null(idx) || rhs.is_null(idx) {
6752 out.push(None);
6753 } else {
6754 out.push(Some(compare_bool(op, lhs.value(idx), rhs.value(idx))));
6755 }
6756 }
6757 Ok(out)
6758 }
6759 (ValueArray::Utf8(lhs), ValueArray::Utf8(rhs)) => {
6760 let lhs = lhs.as_ref();
6761 let rhs = rhs.as_ref();
6762 let mut out = Vec::with_capacity(len);
6763 for idx in 0..len {
6764 if lhs.is_null(idx) || rhs.is_null(idx) {
6765 out.push(None);
6766 } else {
6767 out.push(Some(compare_str(op, lhs.value(idx), rhs.value(idx))));
6768 }
6769 }
6770 Ok(out)
6771 }
6772 _ => Err(Error::InvalidArgumentError(
6773 "unsupported comparison between mismatched types in cross product filter".into(),
6774 )),
6775 }
6776 }
6777
6778 fn evaluate_is_null_truths(
6779 &mut self,
6780 expr: &ScalarExpr<FieldId>,
6781 negated: bool,
6782 batch: &RecordBatch,
6783 ) -> ExecutorResult<Vec<Option<bool>>> {
6784 let values = self.materialize_value_array(expr, batch)?;
6785 let len = values.len();
6786
6787 match &values {
6788 ValueArray::Null(len) => {
6789 let result = if negated {
6791 Some(false) } else {
6793 Some(true) };
6795 Ok(vec![result; *len])
6796 }
6797 ValueArray::Numeric(arr) => {
6798 let mut out = Vec::with_capacity(len);
6799 for idx in 0..len {
6800 let is_null = arr.value(idx).is_none();
6801 let result = if negated {
6802 !is_null } else {
6804 is_null };
6806 out.push(Some(result));
6807 }
6808 Ok(out)
6809 }
6810 ValueArray::Boolean(arr) => {
6811 let mut out = Vec::with_capacity(len);
6812 for idx in 0..len {
6813 let is_null = arr.is_null(idx);
6814 let result = if negated { !is_null } else { is_null };
6815 out.push(Some(result));
6816 }
6817 Ok(out)
6818 }
6819 ValueArray::Utf8(arr) => {
6820 let mut out = Vec::with_capacity(len);
6821 for idx in 0..len {
6822 let is_null = arr.is_null(idx);
6823 let result = if negated { !is_null } else { is_null };
6824 out.push(Some(result));
6825 }
6826 Ok(out)
6827 }
6828 }
6829 }
6830
6831 fn evaluate_in_list_truths(
6832 &mut self,
6833 target: &ScalarExpr<FieldId>,
6834 list: &[ScalarExpr<FieldId>],
6835 negated: bool,
6836 batch: &RecordBatch,
6837 ) -> ExecutorResult<Vec<Option<bool>>> {
6838 let target_values = self.materialize_value_array(target, batch)?;
6839 let list_values = list
6840 .iter()
6841 .map(|expr| self.materialize_value_array(expr, batch))
6842 .collect::<ExecutorResult<Vec<_>>>()?;
6843
6844 let len = target_values.len();
6845 for values in &list_values {
6846 if values.len() != len {
6847 return Err(Error::Internal(
6848 "mismatched IN list operand lengths in cross product filter".into(),
6849 ));
6850 }
6851 }
6852
6853 match &target_values {
6854 ValueArray::Numeric(target_numeric) => {
6855 let mut out = Vec::with_capacity(len);
6856 for idx in 0..len {
6857 let target_value = match target_numeric.value(idx) {
6858 Some(value) => value,
6859 None => {
6860 out.push(None);
6861 continue;
6862 }
6863 };
6864 let mut has_match = false;
6865 let mut saw_null = false;
6866 for candidate in &list_values {
6867 match candidate {
6868 ValueArray::Numeric(array) => match array.value(idx) {
6869 Some(value) => {
6870 if NumericKernels::compare(CompareOp::Eq, target_value, value) {
6871 has_match = true;
6872 break;
6873 }
6874 }
6875 None => saw_null = true,
6876 },
6877 ValueArray::Null(_) => saw_null = true,
6878 _ => {
6879 return Err(Error::InvalidArgumentError(
6880 "type mismatch in IN list evaluation".into(),
6881 ));
6882 }
6883 }
6884 }
6885 out.push(finalize_in_list_result(has_match, saw_null, negated));
6886 }
6887 Ok(out)
6888 }
6889 ValueArray::Boolean(target_bool) => {
6890 let mut out = Vec::with_capacity(len);
6891 for idx in 0..len {
6892 if target_bool.is_null(idx) {
6893 out.push(None);
6894 continue;
6895 }
6896 let target_value = target_bool.value(idx);
6897 let mut has_match = false;
6898 let mut saw_null = false;
6899 for candidate in &list_values {
6900 match candidate {
6901 ValueArray::Boolean(array) => {
6902 if array.is_null(idx) {
6903 saw_null = true;
6904 } else if array.value(idx) == target_value {
6905 has_match = true;
6906 break;
6907 }
6908 }
6909 ValueArray::Null(_) => saw_null = true,
6910 _ => {
6911 return Err(Error::InvalidArgumentError(
6912 "type mismatch in IN list evaluation".into(),
6913 ));
6914 }
6915 }
6916 }
6917 out.push(finalize_in_list_result(has_match, saw_null, negated));
6918 }
6919 Ok(out)
6920 }
6921 ValueArray::Utf8(target_utf8) => {
6922 let mut out = Vec::with_capacity(len);
6923 for idx in 0..len {
6924 if target_utf8.is_null(idx) {
6925 out.push(None);
6926 continue;
6927 }
6928 let target_value = target_utf8.value(idx);
6929 let mut has_match = false;
6930 let mut saw_null = false;
6931 for candidate in &list_values {
6932 match candidate {
6933 ValueArray::Utf8(array) => {
6934 if array.is_null(idx) {
6935 saw_null = true;
6936 } else if array.value(idx) == target_value {
6937 has_match = true;
6938 break;
6939 }
6940 }
6941 ValueArray::Null(_) => saw_null = true,
6942 _ => {
6943 return Err(Error::InvalidArgumentError(
6944 "type mismatch in IN list evaluation".into(),
6945 ));
6946 }
6947 }
6948 }
6949 out.push(finalize_in_list_result(has_match, saw_null, negated));
6950 }
6951 Ok(out)
6952 }
6953 ValueArray::Null(len) => Ok(vec![None; *len]),
6954 }
6955 }
6956
6957 fn evaluate_numeric(
6958 &mut self,
6959 expr: &ScalarExpr<FieldId>,
6960 batch: &RecordBatch,
6961 ) -> ExecutorResult<ArrayRef> {
6962 let mut required = FxHashSet::default();
6963 collect_field_ids(expr, &mut required);
6964
6965 let mut arrays = NumericArrayMap::default();
6966 for field_id in required {
6967 let numeric = self.numeric_array(field_id, batch)?;
6968 arrays.insert(field_id, numeric);
6969 }
6970
6971 NumericKernels::evaluate_batch(expr, batch.num_rows(), &arrays)
6972 }
6973
6974 fn numeric_array(
6975 &mut self,
6976 field_id: FieldId,
6977 batch: &RecordBatch,
6978 ) -> ExecutorResult<NumericArray> {
6979 if let Some(existing) = self.numeric_cache.get(&field_id) {
6980 return Ok(existing.clone());
6981 }
6982
6983 let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
6984 Error::Internal("field mapping missing during cross product evaluation".into())
6985 })?;
6986
6987 let array_ref = batch.column(column_index).clone();
6988 let numeric = NumericArray::try_from_arrow(&array_ref)?;
6989 self.numeric_cache.insert(field_id, numeric.clone());
6990 Ok(numeric)
6991 }
6992
6993 fn column_accessor(
6994 &mut self,
6995 field_id: FieldId,
6996 batch: &RecordBatch,
6997 ) -> ExecutorResult<ColumnAccessor> {
6998 if let Some(existing) = self.column_cache.get(&field_id) {
6999 return Ok(existing.clone());
7000 }
7001
7002 let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
7003 Error::Internal("field mapping missing during cross product evaluation".into())
7004 })?;
7005
7006 let accessor = ColumnAccessor::from_array(batch.column(column_index))?;
7007 self.column_cache.insert(field_id, accessor.clone());
7008 Ok(accessor)
7009 }
7010
7011 fn materialize_scalar_array(
7012 &mut self,
7013 expr: &ScalarExpr<FieldId>,
7014 batch: &RecordBatch,
7015 ) -> ExecutorResult<ArrayRef> {
7016 match expr {
7017 ScalarExpr::Column(field_id) => {
7018 let accessor = self.column_accessor(*field_id, batch)?;
7019 Ok(accessor.as_array_ref())
7020 }
7021 ScalarExpr::Literal(literal) => literal_to_constant_array(literal, batch.num_rows()),
7022 ScalarExpr::Binary { .. } => self.evaluate_numeric(expr, batch),
7023 ScalarExpr::Compare { .. } => self.evaluate_numeric(expr, batch),
7024 ScalarExpr::Not(_) => self.evaluate_numeric(expr, batch),
7025 ScalarExpr::IsNull { .. } => self.evaluate_numeric(expr, batch),
7026 ScalarExpr::Aggregate(_) => Err(Error::InvalidArgumentError(
7027 "aggregate expressions are not supported in cross product filters".into(),
7028 )),
7029 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
7030 "struct field access is not supported in cross product filters".into(),
7031 )),
7032 ScalarExpr::Cast { expr, data_type } => {
7033 let source = self.materialize_scalar_array(expr.as_ref(), batch)?;
7034 let casted = cast(source.as_ref(), data_type).map_err(|err| {
7035 Error::InvalidArgumentError(format!("failed to cast expression: {err}"))
7036 })?;
7037 Ok(casted)
7038 }
7039 ScalarExpr::Case { .. } => self.evaluate_numeric(expr, batch),
7040 ScalarExpr::Coalesce(_) => self.evaluate_numeric(expr, batch),
7041 ScalarExpr::Random => self.evaluate_numeric(expr, batch),
7042 ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
7043 "scalar subqueries are not supported in cross product filters".into(),
7044 )),
7045 }
7046 }
7047
7048 fn materialize_value_array(
7049 &mut self,
7050 expr: &ScalarExpr<FieldId>,
7051 batch: &RecordBatch,
7052 ) -> ExecutorResult<ValueArray> {
7053 let array = self.materialize_scalar_array(expr, batch)?;
7054 ValueArray::from_array(array)
7055 }
7056}
7057
7058fn collect_field_ids(expr: &ScalarExpr<FieldId>, out: &mut FxHashSet<FieldId>) {
7060 match expr {
7061 ScalarExpr::Column(fid) => {
7062 out.insert(*fid);
7063 }
7064 ScalarExpr::Binary { left, right, .. } => {
7065 collect_field_ids(left, out);
7066 collect_field_ids(right, out);
7067 }
7068 ScalarExpr::Compare { left, right, .. } => {
7069 collect_field_ids(left, out);
7070 collect_field_ids(right, out);
7071 }
7072 ScalarExpr::Aggregate(call) => match call {
7073 AggregateCall::CountStar => {}
7074 AggregateCall::Count { expr, .. }
7075 | AggregateCall::Sum { expr, .. }
7076 | AggregateCall::Total { expr, .. }
7077 | AggregateCall::Avg { expr, .. }
7078 | AggregateCall::Min(expr)
7079 | AggregateCall::Max(expr)
7080 | AggregateCall::CountNulls(expr)
7081 | AggregateCall::GroupConcat { expr, .. } => {
7082 collect_field_ids(expr, out);
7083 }
7084 },
7085 ScalarExpr::GetField { base, .. } => collect_field_ids(base, out),
7086 ScalarExpr::Cast { expr, .. } => collect_field_ids(expr, out),
7087 ScalarExpr::Not(expr) => collect_field_ids(expr, out),
7088 ScalarExpr::IsNull { expr, .. } => collect_field_ids(expr, out),
7089 ScalarExpr::Case {
7090 operand,
7091 branches,
7092 else_expr,
7093 } => {
7094 if let Some(inner) = operand.as_deref() {
7095 collect_field_ids(inner, out);
7096 }
7097 for (when_expr, then_expr) in branches {
7098 collect_field_ids(when_expr, out);
7099 collect_field_ids(then_expr, out);
7100 }
7101 if let Some(inner) = else_expr.as_deref() {
7102 collect_field_ids(inner, out);
7103 }
7104 }
7105 ScalarExpr::Coalesce(items) => {
7106 for item in items {
7107 collect_field_ids(item, out);
7108 }
7109 }
7110 ScalarExpr::Literal(_) | ScalarExpr::Random => {}
7111 ScalarExpr::ScalarSubquery(_) => {}
7112 }
7113}
7114
7115fn strip_exists(expr: &LlkvExpr<'static, FieldId>) -> LlkvExpr<'static, FieldId> {
7116 match expr {
7117 LlkvExpr::And(children) => LlkvExpr::And(children.iter().map(strip_exists).collect()),
7118 LlkvExpr::Or(children) => LlkvExpr::Or(children.iter().map(strip_exists).collect()),
7119 LlkvExpr::Not(inner) => LlkvExpr::Not(Box::new(strip_exists(inner))),
7120 LlkvExpr::Pred(filter) => LlkvExpr::Pred(filter.clone()),
7121 LlkvExpr::Compare { left, op, right } => LlkvExpr::Compare {
7122 left: left.clone(),
7123 op: *op,
7124 right: right.clone(),
7125 },
7126 LlkvExpr::InList {
7127 expr,
7128 list,
7129 negated,
7130 } => LlkvExpr::InList {
7131 expr: expr.clone(),
7132 list: list.clone(),
7133 negated: *negated,
7134 },
7135 LlkvExpr::IsNull { expr, negated } => LlkvExpr::IsNull {
7136 expr: expr.clone(),
7137 negated: *negated,
7138 },
7139 LlkvExpr::Literal(value) => LlkvExpr::Literal(*value),
7140 LlkvExpr::Exists(_) => LlkvExpr::Literal(true),
7141 }
7142}
7143
7144fn bind_select_plan(
7145 plan: &SelectPlan,
7146 bindings: &FxHashMap<String, Literal>,
7147) -> ExecutorResult<SelectPlan> {
7148 if bindings.is_empty() {
7149 return Ok(plan.clone());
7150 }
7151
7152 let projections = plan
7153 .projections
7154 .iter()
7155 .map(|projection| bind_projection(projection, bindings))
7156 .collect::<ExecutorResult<Vec<_>>>()?;
7157
7158 let filter = match &plan.filter {
7159 Some(wrapper) => Some(bind_select_filter(wrapper, bindings)?),
7160 None => None,
7161 };
7162
7163 let aggregates = plan
7164 .aggregates
7165 .iter()
7166 .map(|aggregate| bind_aggregate_expr(aggregate, bindings))
7167 .collect::<ExecutorResult<Vec<_>>>()?;
7168
7169 let scalar_subqueries = plan
7170 .scalar_subqueries
7171 .iter()
7172 .map(|subquery| bind_scalar_subquery(subquery, bindings))
7173 .collect::<ExecutorResult<Vec<_>>>()?;
7174
7175 if let Some(compound) = &plan.compound {
7176 let bound_compound = bind_compound_select(compound, bindings)?;
7177 return Ok(SelectPlan {
7178 tables: Vec::new(),
7179 joins: Vec::new(),
7180 projections: Vec::new(),
7181 filter: None,
7182 having: None,
7183 aggregates: Vec::new(),
7184 order_by: plan.order_by.clone(),
7185 distinct: false,
7186 scalar_subqueries: Vec::new(),
7187 compound: Some(bound_compound),
7188 group_by: Vec::new(),
7189 value_table_mode: None,
7190 });
7191 }
7192
7193 Ok(SelectPlan {
7194 tables: plan.tables.clone(),
7195 joins: plan.joins.clone(),
7196 projections,
7197 filter,
7198 having: plan.having.clone(),
7199 aggregates,
7200 order_by: Vec::new(),
7201 distinct: plan.distinct,
7202 scalar_subqueries,
7203 compound: None,
7204 group_by: plan.group_by.clone(),
7205 value_table_mode: plan.value_table_mode.clone(),
7206 })
7207}
7208
7209fn bind_compound_select(
7210 compound: &CompoundSelectPlan,
7211 bindings: &FxHashMap<String, Literal>,
7212) -> ExecutorResult<CompoundSelectPlan> {
7213 let initial = bind_select_plan(&compound.initial, bindings)?;
7214 let mut operations = Vec::with_capacity(compound.operations.len());
7215 for component in &compound.operations {
7216 let bound_plan = bind_select_plan(&component.plan, bindings)?;
7217 operations.push(CompoundSelectComponent {
7218 operator: component.operator.clone(),
7219 quantifier: component.quantifier.clone(),
7220 plan: bound_plan,
7221 });
7222 }
7223 Ok(CompoundSelectPlan {
7224 initial: Box::new(initial),
7225 operations,
7226 })
7227}
7228
7229fn ensure_schema_compatibility(base: &Schema, other: &Schema) -> ExecutorResult<()> {
7230 if base.fields().len() != other.fields().len() {
7231 return Err(Error::InvalidArgumentError(
7232 "compound SELECT requires matching column counts".into(),
7233 ));
7234 }
7235 for (left, right) in base.fields().iter().zip(other.fields().iter()) {
7236 if left.data_type() != right.data_type() {
7237 return Err(Error::InvalidArgumentError(format!(
7238 "compound SELECT column type mismatch: {} vs {}",
7239 left.data_type(),
7240 right.data_type()
7241 )));
7242 }
7243 }
7244 Ok(())
7245}
7246
7247fn ensure_distinct_rows(rows: &mut Vec<Vec<PlanValue>>, cache: &mut Option<FxHashSet<Vec<u8>>>) {
7248 if cache.is_some() {
7249 return;
7250 }
7251 let mut set = FxHashSet::default();
7252 let mut deduped: Vec<Vec<PlanValue>> = Vec::with_capacity(rows.len());
7253 for row in rows.drain(..) {
7254 let key = encode_row(&row);
7255 if set.insert(key) {
7256 deduped.push(row);
7257 }
7258 }
7259 *rows = deduped;
7260 *cache = Some(set);
7261}
7262
7263fn encode_row(row: &[PlanValue]) -> Vec<u8> {
7264 let mut buf = Vec::new();
7265 for value in row {
7266 encode_plan_value(&mut buf, value);
7267 buf.push(0x1F);
7268 }
7269 buf
7270}
7271
7272fn encode_plan_value(buf: &mut Vec<u8>, value: &PlanValue) {
7273 match value {
7274 PlanValue::Null => buf.push(0),
7275 PlanValue::Integer(v) => {
7276 buf.push(1);
7277 buf.extend_from_slice(&v.to_be_bytes());
7278 }
7279 PlanValue::Float(v) => {
7280 buf.push(2);
7281 buf.extend_from_slice(&v.to_bits().to_be_bytes());
7282 }
7283 PlanValue::String(s) => {
7284 buf.push(3);
7285 let bytes = s.as_bytes();
7286 let len = u32::try_from(bytes.len()).unwrap_or(u32::MAX);
7287 buf.extend_from_slice(&len.to_be_bytes());
7288 buf.extend_from_slice(bytes);
7289 }
7290 PlanValue::Struct(map) => {
7291 buf.push(4);
7292 let mut entries: Vec<_> = map.iter().collect();
7293 entries.sort_by(|a, b| a.0.cmp(b.0));
7294 let len = u32::try_from(entries.len()).unwrap_or(u32::MAX);
7295 buf.extend_from_slice(&len.to_be_bytes());
7296 for (key, val) in entries {
7297 let key_bytes = key.as_bytes();
7298 let key_len = u32::try_from(key_bytes.len()).unwrap_or(u32::MAX);
7299 buf.extend_from_slice(&key_len.to_be_bytes());
7300 buf.extend_from_slice(key_bytes);
7301 encode_plan_value(buf, val);
7302 }
7303 }
7304 }
7305}
7306
7307fn rows_to_record_batch(
7308 schema: Arc<Schema>,
7309 rows: &[Vec<PlanValue>],
7310) -> ExecutorResult<RecordBatch> {
7311 let column_count = schema.fields().len();
7312 let mut columns: Vec<Vec<PlanValue>> = vec![Vec::with_capacity(rows.len()); column_count];
7313 for row in rows {
7314 if row.len() != column_count {
7315 return Err(Error::InvalidArgumentError(
7316 "compound SELECT produced mismatched column counts".into(),
7317 ));
7318 }
7319 for (idx, value) in row.iter().enumerate() {
7320 columns[idx].push(value.clone());
7321 }
7322 }
7323
7324 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_count);
7325 for (idx, field) in schema.fields().iter().enumerate() {
7326 let array = build_array_for_column(field.data_type(), &columns[idx])?;
7327 arrays.push(array);
7328 }
7329
7330 RecordBatch::try_new(schema, arrays).map_err(|err| {
7331 Error::InvalidArgumentError(format!("failed to materialize compound SELECT: {err}"))
7332 })
7333}
7334
7335fn build_column_lookup_map(schema: &Schema) -> FxHashMap<String, usize> {
7336 let mut lookup = FxHashMap::default();
7337 for (idx, field) in schema.fields().iter().enumerate() {
7338 lookup.insert(field.name().to_ascii_lowercase(), idx);
7339 }
7340 lookup
7341}
7342
7343fn build_group_key(
7344 batch: &RecordBatch,
7345 row_idx: usize,
7346 key_indices: &[usize],
7347) -> ExecutorResult<Vec<GroupKeyValue>> {
7348 let mut values = Vec::with_capacity(key_indices.len());
7349 for &index in key_indices {
7350 values.push(group_key_value(batch.column(index), row_idx)?);
7351 }
7352 Ok(values)
7353}
7354
7355fn group_key_value(array: &ArrayRef, row_idx: usize) -> ExecutorResult<GroupKeyValue> {
7356 if !array.is_valid(row_idx) {
7357 return Ok(GroupKeyValue::Null);
7358 }
7359
7360 match array.data_type() {
7361 DataType::Int8 => {
7362 let values = array
7363 .as_any()
7364 .downcast_ref::<Int8Array>()
7365 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
7366 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7367 }
7368 DataType::Int16 => {
7369 let values = array
7370 .as_any()
7371 .downcast_ref::<Int16Array>()
7372 .ok_or_else(|| Error::Internal("failed to downcast to Int16Array".into()))?;
7373 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7374 }
7375 DataType::Int32 => {
7376 let values = array
7377 .as_any()
7378 .downcast_ref::<Int32Array>()
7379 .ok_or_else(|| Error::Internal("failed to downcast to Int32Array".into()))?;
7380 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7381 }
7382 DataType::Int64 => {
7383 let values = array
7384 .as_any()
7385 .downcast_ref::<Int64Array>()
7386 .ok_or_else(|| Error::Internal("failed to downcast to Int64Array".into()))?;
7387 Ok(GroupKeyValue::Int(values.value(row_idx)))
7388 }
7389 DataType::UInt8 => {
7390 let values = array
7391 .as_any()
7392 .downcast_ref::<UInt8Array>()
7393 .ok_or_else(|| Error::Internal("failed to downcast to UInt8Array".into()))?;
7394 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7395 }
7396 DataType::UInt16 => {
7397 let values = array
7398 .as_any()
7399 .downcast_ref::<UInt16Array>()
7400 .ok_or_else(|| Error::Internal("failed to downcast to UInt16Array".into()))?;
7401 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7402 }
7403 DataType::UInt32 => {
7404 let values = array
7405 .as_any()
7406 .downcast_ref::<UInt32Array>()
7407 .ok_or_else(|| Error::Internal("failed to downcast to UInt32Array".into()))?;
7408 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7409 }
7410 DataType::UInt64 => {
7411 let values = array
7412 .as_any()
7413 .downcast_ref::<UInt64Array>()
7414 .ok_or_else(|| Error::Internal("failed to downcast to UInt64Array".into()))?;
7415 let value = values.value(row_idx);
7416 if value > i64::MAX as u64 {
7417 return Err(Error::InvalidArgumentError(
7418 "GROUP BY value exceeds supported integer range".into(),
7419 ));
7420 }
7421 Ok(GroupKeyValue::Int(value as i64))
7422 }
7423 DataType::Boolean => {
7424 let values = array
7425 .as_any()
7426 .downcast_ref::<BooleanArray>()
7427 .ok_or_else(|| Error::Internal("failed to downcast to BooleanArray".into()))?;
7428 Ok(GroupKeyValue::Bool(values.value(row_idx)))
7429 }
7430 DataType::Utf8 => {
7431 let values = array
7432 .as_any()
7433 .downcast_ref::<StringArray>()
7434 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
7435 Ok(GroupKeyValue::String(values.value(row_idx).to_string()))
7436 }
7437 other => Err(Error::InvalidArgumentError(format!(
7438 "GROUP BY does not support column type {:?}",
7439 other
7440 ))),
7441 }
7442}
7443
7444fn evaluate_constant_predicate(expr: &LlkvExpr<'static, String>) -> Option<Option<bool>> {
7445 match expr {
7446 LlkvExpr::Literal(value) => Some(Some(*value)),
7447 LlkvExpr::Not(inner) => {
7448 let inner_val = evaluate_constant_predicate(inner)?;
7449 Some(truth_not(inner_val))
7450 }
7451 LlkvExpr::And(children) => {
7452 let mut acc = Some(true);
7453 for child in children {
7454 let child_val = evaluate_constant_predicate(child)?;
7455 acc = truth_and(acc, child_val);
7456 }
7457 Some(acc)
7458 }
7459 LlkvExpr::Or(children) => {
7460 let mut acc = Some(false);
7461 for child in children {
7462 let child_val = evaluate_constant_predicate(child)?;
7463 acc = truth_or(acc, child_val);
7464 }
7465 Some(acc)
7466 }
7467 LlkvExpr::Compare { left, op, right } => {
7468 let left_literal = evaluate_constant_scalar(left)?;
7469 let right_literal = evaluate_constant_scalar(right)?;
7470 Some(compare_literals(*op, &left_literal, &right_literal))
7471 }
7472 LlkvExpr::IsNull { expr, negated } => {
7473 let literal = evaluate_constant_scalar(expr)?;
7474 let is_null = matches!(literal, Literal::Null);
7475 Some(Some(if *negated { !is_null } else { is_null }))
7476 }
7477 LlkvExpr::InList {
7478 expr,
7479 list,
7480 negated,
7481 } => {
7482 let needle = evaluate_constant_scalar(expr)?;
7483 let mut saw_unknown = false;
7484
7485 for candidate in list {
7486 let value = evaluate_constant_scalar(candidate)?;
7487 match compare_literals(CompareOp::Eq, &needle, &value) {
7488 Some(true) => {
7489 return Some(Some(!*negated));
7490 }
7491 Some(false) => {}
7492 None => saw_unknown = true,
7493 }
7494 }
7495
7496 if saw_unknown {
7497 Some(None)
7498 } else {
7499 Some(Some(*negated))
7500 }
7501 }
7502 _ => None,
7503 }
7504}
7505
/// Outcome of attempting to fold a join condition that may consist solely of
/// constants.
enum ConstantJoinEvaluation {
    /// The condition folds to this definite boolean.
    Known(bool),
    /// The condition is constant but evaluates to SQL NULL (unknown).
    Unknown,
    /// The condition references non-constant input and cannot be folded.
    NotConstant,
}
7511
7512fn evaluate_constant_join_expr(expr: &LlkvExpr<'static, String>) -> ConstantJoinEvaluation {
7513 match expr {
7514 LlkvExpr::Literal(value) => ConstantJoinEvaluation::Known(*value),
7515 LlkvExpr::And(children) => {
7516 let mut saw_unknown = false;
7517 for child in children {
7518 match evaluate_constant_join_expr(child) {
7519 ConstantJoinEvaluation::Known(false) => {
7520 return ConstantJoinEvaluation::Known(false);
7521 }
7522 ConstantJoinEvaluation::Known(true) => {}
7523 ConstantJoinEvaluation::Unknown => saw_unknown = true,
7524 ConstantJoinEvaluation::NotConstant => {
7525 return ConstantJoinEvaluation::NotConstant;
7526 }
7527 }
7528 }
7529 if saw_unknown {
7530 ConstantJoinEvaluation::Unknown
7531 } else {
7532 ConstantJoinEvaluation::Known(true)
7533 }
7534 }
7535 LlkvExpr::Or(children) => {
7536 let mut saw_unknown = false;
7537 for child in children {
7538 match evaluate_constant_join_expr(child) {
7539 ConstantJoinEvaluation::Known(true) => {
7540 return ConstantJoinEvaluation::Known(true);
7541 }
7542 ConstantJoinEvaluation::Known(false) => {}
7543 ConstantJoinEvaluation::Unknown => saw_unknown = true,
7544 ConstantJoinEvaluation::NotConstant => {
7545 return ConstantJoinEvaluation::NotConstant;
7546 }
7547 }
7548 }
7549 if saw_unknown {
7550 ConstantJoinEvaluation::Unknown
7551 } else {
7552 ConstantJoinEvaluation::Known(false)
7553 }
7554 }
7555 LlkvExpr::Not(inner) => match evaluate_constant_join_expr(inner) {
7556 ConstantJoinEvaluation::Known(value) => ConstantJoinEvaluation::Known(!value),
7557 ConstantJoinEvaluation::Unknown => ConstantJoinEvaluation::Unknown,
7558 ConstantJoinEvaluation::NotConstant => ConstantJoinEvaluation::NotConstant,
7559 },
7560 LlkvExpr::Compare { left, op, right } => {
7561 let left_lit = evaluate_constant_scalar(left);
7562 let right_lit = evaluate_constant_scalar(right);
7563
7564 if matches!(left_lit, Some(Literal::Null)) || matches!(right_lit, Some(Literal::Null)) {
7565 return ConstantJoinEvaluation::Unknown;
7567 }
7568
7569 let (Some(left_lit), Some(right_lit)) = (left_lit, right_lit) else {
7570 return ConstantJoinEvaluation::NotConstant;
7571 };
7572
7573 match compare_literals(*op, &left_lit, &right_lit) {
7574 Some(result) => ConstantJoinEvaluation::Known(result),
7575 None => ConstantJoinEvaluation::Unknown,
7576 }
7577 }
7578 LlkvExpr::IsNull { expr, negated } => match evaluate_constant_scalar(expr) {
7579 Some(literal) => {
7580 let is_null = matches!(literal, Literal::Null);
7581 let value = if *negated { !is_null } else { is_null };
7582 ConstantJoinEvaluation::Known(value)
7583 }
7584 None => ConstantJoinEvaluation::NotConstant,
7585 },
7586 LlkvExpr::InList {
7587 expr,
7588 list,
7589 negated,
7590 } => {
7591 let needle = match evaluate_constant_scalar(expr) {
7592 Some(literal) => literal,
7593 None => return ConstantJoinEvaluation::NotConstant,
7594 };
7595
7596 if matches!(needle, Literal::Null) {
7597 return ConstantJoinEvaluation::Unknown;
7598 }
7599
7600 let mut saw_unknown = false;
7601 for candidate in list {
7602 let value = match evaluate_constant_scalar(candidate) {
7603 Some(literal) => literal,
7604 None => return ConstantJoinEvaluation::NotConstant,
7605 };
7606
7607 match compare_literals(CompareOp::Eq, &needle, &value) {
7608 Some(true) => {
7609 let result = !*negated;
7610 return ConstantJoinEvaluation::Known(result);
7611 }
7612 Some(false) => {}
7613 None => saw_unknown = true,
7614 }
7615 }
7616
7617 if saw_unknown {
7618 ConstantJoinEvaluation::Unknown
7619 } else {
7620 let result = *negated;
7621 ConstantJoinEvaluation::Known(result)
7622 }
7623 }
7624 _ => ConstantJoinEvaluation::NotConstant,
7625 }
7626}
7627
/// How comparisons involving NULL should be treated.
///
/// Only SQL three-valued logic is defined here. The consumers of this enum are
/// outside this section of the file.
enum NullComparisonBehavior {
    /// Standard SQL semantics: any comparison with NULL is unknown.
    ThreeValuedLogic,
}
7631
7632fn evaluate_constant_scalar(expr: &ScalarExpr<String>) -> Option<Literal> {
7633 evaluate_constant_scalar_internal(expr, false)
7634}
7635
7636fn evaluate_constant_scalar_with_aggregates(expr: &ScalarExpr<String>) -> Option<Literal> {
7637 evaluate_constant_scalar_internal(expr, true)
7638}
7639
7640fn evaluate_constant_scalar_internal(
7641 expr: &ScalarExpr<String>,
7642 allow_aggregates: bool,
7643) -> Option<Literal> {
7644 match expr {
7645 ScalarExpr::Literal(lit) => Some(lit.clone()),
7646 ScalarExpr::Binary { left, op, right } => {
7647 let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
7648 let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
7649 evaluate_binary_literal(*op, &left_value, &right_value)
7650 }
7651 ScalarExpr::Cast { expr, data_type } => {
7652 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7653 cast_literal_to_type(&value, data_type)
7654 }
7655 ScalarExpr::Not(inner) => {
7656 let value = evaluate_constant_scalar_internal(inner, allow_aggregates)?;
7657 match literal_truthiness(&value) {
7658 Some(true) => Some(Literal::Integer(0)),
7659 Some(false) => Some(Literal::Integer(1)),
7660 None => Some(Literal::Null),
7661 }
7662 }
7663 ScalarExpr::IsNull { expr, negated } => {
7664 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7665 let is_null = matches!(value, Literal::Null);
7666 Some(Literal::Boolean(if *negated { !is_null } else { is_null }))
7667 }
7668 ScalarExpr::Coalesce(items) => {
7669 let mut saw_null = false;
7670 for item in items {
7671 match evaluate_constant_scalar_internal(item, allow_aggregates) {
7672 Some(Literal::Null) => saw_null = true,
7673 Some(value) => return Some(value),
7674 None => return None,
7675 }
7676 }
7677 if saw_null { Some(Literal::Null) } else { None }
7678 }
7679 ScalarExpr::Compare { left, op, right } => {
7680 let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
7681 let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
7682 match compare_literals(*op, &left_value, &right_value) {
7683 Some(flag) => Some(Literal::Boolean(flag)),
7684 None => Some(Literal::Null),
7685 }
7686 }
7687 ScalarExpr::Case {
7688 operand,
7689 branches,
7690 else_expr,
7691 } => {
7692 if let Some(operand_expr) = operand {
7693 let operand_value =
7694 evaluate_constant_scalar_internal(operand_expr, allow_aggregates)?;
7695 for (when_expr, then_expr) in branches {
7696 let when_value =
7697 evaluate_constant_scalar_internal(when_expr, allow_aggregates)?;
7698 if let Some(true) = compare_literals(CompareOp::Eq, &operand_value, &when_value)
7699 {
7700 return evaluate_constant_scalar_internal(then_expr, allow_aggregates);
7701 }
7702 }
7703 } else {
7704 for (condition_expr, result_expr) in branches {
7705 let condition_value =
7706 evaluate_constant_scalar_internal(condition_expr, allow_aggregates)?;
7707 match literal_truthiness(&condition_value) {
7708 Some(true) => {
7709 return evaluate_constant_scalar_internal(
7710 result_expr,
7711 allow_aggregates,
7712 );
7713 }
7714 Some(false) => {}
7715 None => {}
7716 }
7717 }
7718 }
7719
7720 if let Some(else_branch) = else_expr {
7721 evaluate_constant_scalar_internal(else_branch, allow_aggregates)
7722 } else {
7723 Some(Literal::Null)
7724 }
7725 }
7726 ScalarExpr::Column(_) => None,
7727 ScalarExpr::Aggregate(call) => {
7728 if allow_aggregates {
7729 evaluate_constant_aggregate(call, allow_aggregates)
7730 } else {
7731 None
7732 }
7733 }
7734 ScalarExpr::GetField { .. } => None,
7735 ScalarExpr::Random => None, ScalarExpr::ScalarSubquery(_) => None,
7737 }
7738}
7739
7740fn evaluate_constant_aggregate(
7741 call: &AggregateCall<String>,
7742 allow_aggregates: bool,
7743) -> Option<Literal> {
7744 match call {
7745 AggregateCall::CountStar => Some(Literal::Integer(1)),
7746 AggregateCall::Count { expr, .. } => {
7747 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7748 if matches!(value, Literal::Null) {
7749 Some(Literal::Integer(0))
7750 } else {
7751 Some(Literal::Integer(1))
7752 }
7753 }
7754 AggregateCall::Sum { expr, .. } => {
7755 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7756 match value {
7757 Literal::Null => Some(Literal::Null),
7758 Literal::Integer(value) => Some(Literal::Integer(value)),
7759 Literal::Float(value) => Some(Literal::Float(value)),
7760 Literal::Boolean(flag) => Some(Literal::Integer(if flag { 1 } else { 0 })),
7761 _ => None,
7762 }
7763 }
7764 AggregateCall::Total { expr, .. } => {
7765 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7766 match value {
7767 Literal::Null => Some(Literal::Integer(0)),
7768 Literal::Integer(value) => Some(Literal::Integer(value)),
7769 Literal::Float(value) => Some(Literal::Float(value)),
7770 Literal::Boolean(flag) => Some(Literal::Integer(if flag { 1 } else { 0 })),
7771 _ => None,
7772 }
7773 }
7774 AggregateCall::Avg { expr, .. } => {
7775 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7776 match value {
7777 Literal::Null => Some(Literal::Null),
7778 other => {
7779 let numeric = literal_to_f64(&other)?;
7780 Some(Literal::Float(numeric))
7781 }
7782 }
7783 }
7784 AggregateCall::Min(expr) => {
7785 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7786 match value {
7787 Literal::Null => Some(Literal::Null),
7788 other => Some(other),
7789 }
7790 }
7791 AggregateCall::Max(expr) => {
7792 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7793 match value {
7794 Literal::Null => Some(Literal::Null),
7795 other => Some(other),
7796 }
7797 }
7798 AggregateCall::CountNulls(expr) => {
7799 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7800 let count = if matches!(value, Literal::Null) { 1 } else { 0 };
7801 Some(Literal::Integer(count))
7802 }
7803 AggregateCall::GroupConcat {
7804 expr, separator: _, ..
7805 } => {
7806 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7807 match value {
7808 Literal::Null => Some(Literal::Null),
7809 Literal::String(s) => Some(Literal::String(s)),
7810 Literal::Integer(i) => Some(Literal::String(i.to_string())),
7811 Literal::Float(f) => Some(Literal::String(f.to_string())),
7812 Literal::Boolean(b) => Some(Literal::String(if b { "1" } else { "0" }.to_string())),
7813 _ => None,
7814 }
7815 }
7816 }
7817}
7818
7819fn evaluate_binary_literal(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
7820 match op {
7821 BinaryOp::And => evaluate_literal_logical_and(left, right),
7822 BinaryOp::Or => evaluate_literal_logical_or(left, right),
7823 BinaryOp::Add
7824 | BinaryOp::Subtract
7825 | BinaryOp::Multiply
7826 | BinaryOp::Divide
7827 | BinaryOp::Modulo => {
7828 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
7829 return Some(Literal::Null);
7830 }
7831
7832 match op {
7833 BinaryOp::Add => add_literals(left, right),
7834 BinaryOp::Subtract => subtract_literals(left, right),
7835 BinaryOp::Multiply => multiply_literals(left, right),
7836 BinaryOp::Divide => divide_literals(left, right),
7837 BinaryOp::Modulo => modulo_literals(left, right),
7838 BinaryOp::And
7839 | BinaryOp::Or
7840 | BinaryOp::BitwiseShiftLeft
7841 | BinaryOp::BitwiseShiftRight => unreachable!(),
7842 }
7843 }
7844 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
7845 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
7846 return Some(Literal::Null);
7847 }
7848
7849 let lhs = literal_to_i128(left)?;
7851 let rhs = literal_to_i128(right)?;
7852
7853 let result = match op {
7855 BinaryOp::BitwiseShiftLeft => (lhs as i64).wrapping_shl(rhs as u32) as i128,
7856 BinaryOp::BitwiseShiftRight => (lhs as i64).wrapping_shr(rhs as u32) as i128,
7857 _ => unreachable!(),
7858 };
7859
7860 Some(Literal::Integer(result))
7861 }
7862 }
7863}
7864
7865fn evaluate_literal_logical_and(left: &Literal, right: &Literal) -> Option<Literal> {
7866 let left_truth = literal_truthiness(left);
7867 if matches!(left_truth, Some(false)) {
7868 return Some(Literal::Integer(0));
7869 }
7870
7871 let right_truth = literal_truthiness(right);
7872 if matches!(right_truth, Some(false)) {
7873 return Some(Literal::Integer(0));
7874 }
7875
7876 match (left_truth, right_truth) {
7877 (Some(true), Some(true)) => Some(Literal::Integer(1)),
7878 (Some(true), None) | (None, Some(true)) | (None, None) => Some(Literal::Null),
7879 _ => Some(Literal::Null),
7880 }
7881}
7882
7883fn evaluate_literal_logical_or(left: &Literal, right: &Literal) -> Option<Literal> {
7884 let left_truth = literal_truthiness(left);
7885 if matches!(left_truth, Some(true)) {
7886 return Some(Literal::Integer(1));
7887 }
7888
7889 let right_truth = literal_truthiness(right);
7890 if matches!(right_truth, Some(true)) {
7891 return Some(Literal::Integer(1));
7892 }
7893
7894 match (left_truth, right_truth) {
7895 (Some(false), Some(false)) => Some(Literal::Integer(0)),
7896 (Some(false), None) | (None, Some(false)) | (None, None) => Some(Literal::Null),
7897 _ => Some(Literal::Null),
7898 }
7899}
7900
7901fn add_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7902 match (left, right) {
7903 (Literal::Integer(lhs), Literal::Integer(rhs)) => {
7904 Some(Literal::Integer(lhs.saturating_add(*rhs)))
7905 }
7906 _ => {
7907 let lhs = literal_to_f64(left)?;
7908 let rhs = literal_to_f64(right)?;
7909 Some(Literal::Float(lhs + rhs))
7910 }
7911 }
7912}
7913
7914fn subtract_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7915 match (left, right) {
7916 (Literal::Integer(lhs), Literal::Integer(rhs)) => {
7917 Some(Literal::Integer(lhs.saturating_sub(*rhs)))
7918 }
7919 _ => {
7920 let lhs = literal_to_f64(left)?;
7921 let rhs = literal_to_f64(right)?;
7922 Some(Literal::Float(lhs - rhs))
7923 }
7924 }
7925}
7926
7927fn multiply_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7928 match (left, right) {
7929 (Literal::Integer(lhs), Literal::Integer(rhs)) => {
7930 Some(Literal::Integer(lhs.saturating_mul(*rhs)))
7931 }
7932 _ => {
7933 let lhs = literal_to_f64(left)?;
7934 let rhs = literal_to_f64(right)?;
7935 Some(Literal::Float(lhs * rhs))
7936 }
7937 }
7938}
7939
7940fn divide_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7941 fn literal_to_i128_from_integer_like(literal: &Literal) -> Option<i128> {
7942 match literal {
7943 Literal::Integer(value) => Some(*value),
7944 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
7945 _ => None,
7946 }
7947 }
7948
7949 if let (Some(lhs), Some(rhs)) = (
7950 literal_to_i128_from_integer_like(left),
7951 literal_to_i128_from_integer_like(right),
7952 ) {
7953 if rhs == 0 {
7954 return Some(Literal::Null);
7955 }
7956
7957 if lhs == i128::MIN && rhs == -1 {
7958 return Some(Literal::Float((lhs as f64) / (rhs as f64)));
7959 }
7960
7961 return Some(Literal::Integer(lhs / rhs));
7962 }
7963
7964 let lhs = literal_to_f64(left)?;
7965 let rhs = literal_to_f64(right)?;
7966 if rhs == 0.0 {
7967 return Some(Literal::Null);
7968 }
7969 Some(Literal::Float(lhs / rhs))
7970}
7971
7972fn modulo_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7973 let lhs = literal_to_i128(left)?;
7974 let rhs = literal_to_i128(right)?;
7975 if rhs == 0 {
7976 return Some(Literal::Null);
7977 }
7978 Some(Literal::Integer(lhs % rhs))
7979}
7980
7981fn literal_to_f64(literal: &Literal) -> Option<f64> {
7982 match literal {
7983 Literal::Integer(value) => Some(*value as f64),
7984 Literal::Float(value) => Some(*value),
7985 Literal::Boolean(value) => Some(if *value { 1.0 } else { 0.0 }),
7986 _ => None,
7987 }
7988}
7989
7990fn literal_to_i128(literal: &Literal) -> Option<i128> {
7991 match literal {
7992 Literal::Integer(value) => Some(*value),
7993 Literal::Float(value) => Some(*value as i128),
7994 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
7995 _ => None,
7996 }
7997}
7998
7999fn literal_truthiness(literal: &Literal) -> Option<bool> {
8000 match literal {
8001 Literal::Boolean(value) => Some(*value),
8002 Literal::Integer(value) => Some(*value != 0),
8003 Literal::Float(value) => Some(*value != 0.0),
8004 Literal::Null => None,
8005 _ => None,
8006 }
8007}
8008
8009fn plan_value_truthiness(value: &PlanValue) -> Option<bool> {
8010 match value {
8011 PlanValue::Integer(v) => Some(*v != 0),
8012 PlanValue::Float(v) => Some(*v != 0.0),
8013 PlanValue::Null => None,
8014 _ => None,
8015 }
8016}
8017
/// Truth value of a nullable integer: `None` stays unknown, and non-zero values are true.
fn option_i64_truthiness(value: Option<i64>) -> Option<bool> {
    Some(value? != 0)
}
8021
8022fn evaluate_plan_value_logical_and(left: PlanValue, right: PlanValue) -> PlanValue {
8023 let left_truth = plan_value_truthiness(&left);
8024 if matches!(left_truth, Some(false)) {
8025 return PlanValue::Integer(0);
8026 }
8027
8028 let right_truth = plan_value_truthiness(&right);
8029 if matches!(right_truth, Some(false)) {
8030 return PlanValue::Integer(0);
8031 }
8032
8033 match (left_truth, right_truth) {
8034 (Some(true), Some(true)) => PlanValue::Integer(1),
8035 (Some(true), None) | (None, Some(true)) | (None, None) => PlanValue::Null,
8036 _ => PlanValue::Null,
8037 }
8038}
8039
8040fn evaluate_plan_value_logical_or(left: PlanValue, right: PlanValue) -> PlanValue {
8041 let left_truth = plan_value_truthiness(&left);
8042 if matches!(left_truth, Some(true)) {
8043 return PlanValue::Integer(1);
8044 }
8045
8046 let right_truth = plan_value_truthiness(&right);
8047 if matches!(right_truth, Some(true)) {
8048 return PlanValue::Integer(1);
8049 }
8050
8051 match (left_truth, right_truth) {
8052 (Some(false), Some(false)) => PlanValue::Integer(0),
8053 (Some(false), None) | (None, Some(false)) | (None, None) => PlanValue::Null,
8054 _ => PlanValue::Null,
8055 }
8056}
8057
8058fn evaluate_option_logical_and(left: Option<i64>, right: Option<i64>) -> Option<i64> {
8059 let left_truth = option_i64_truthiness(left);
8060 if matches!(left_truth, Some(false)) {
8061 return Some(0);
8062 }
8063
8064 let right_truth = option_i64_truthiness(right);
8065 if matches!(right_truth, Some(false)) {
8066 return Some(0);
8067 }
8068
8069 match (left_truth, right_truth) {
8070 (Some(true), Some(true)) => Some(1),
8071 (Some(true), None) | (None, Some(true)) | (None, None) => None,
8072 _ => None,
8073 }
8074}
8075
8076fn evaluate_option_logical_or(left: Option<i64>, right: Option<i64>) -> Option<i64> {
8077 let left_truth = option_i64_truthiness(left);
8078 if matches!(left_truth, Some(true)) {
8079 return Some(1);
8080 }
8081
8082 let right_truth = option_i64_truthiness(right);
8083 if matches!(right_truth, Some(true)) {
8084 return Some(1);
8085 }
8086
8087 match (left_truth, right_truth) {
8088 (Some(false), Some(false)) => Some(0),
8089 (Some(false), None) | (None, Some(false)) | (None, None) => None,
8090 _ => None,
8091 }
8092}
8093
8094fn cast_literal_to_type(literal: &Literal, data_type: &DataType) -> Option<Literal> {
8095 if matches!(literal, Literal::Null) {
8096 return Some(Literal::Null);
8097 }
8098
8099 match data_type {
8100 DataType::Boolean => literal_truthiness(literal).map(Literal::Boolean),
8101 DataType::Float16 | DataType::Float32 | DataType::Float64 => {
8102 let value = literal_to_f64(literal)?;
8103 Some(Literal::Float(value))
8104 }
8105 DataType::Int8
8106 | DataType::Int16
8107 | DataType::Int32
8108 | DataType::Int64
8109 | DataType::UInt8
8110 | DataType::UInt16
8111 | DataType::UInt32
8112 | DataType::UInt64 => {
8113 let value = literal_to_i128(literal)?;
8114 Some(Literal::Integer(value))
8115 }
8116 DataType::Utf8 | DataType::LargeUtf8 => Some(Literal::String(match literal {
8117 Literal::String(text) => text.clone(),
8118 Literal::Integer(value) => value.to_string(),
8119 Literal::Float(value) => value.to_string(),
8120 Literal::Boolean(value) => {
8121 if *value {
8122 "1".to_string()
8123 } else {
8124 "0".to_string()
8125 }
8126 }
8127 Literal::Struct(_) | Literal::Null => return None,
8128 })),
8129 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
8130 literal_to_i128(literal).map(Literal::Integer)
8131 }
8132 _ => None,
8133 }
8134}
8135
8136fn compare_literals(op: CompareOp, left: &Literal, right: &Literal) -> Option<bool> {
8137 compare_literals_with_mode(op, left, right, NullComparisonBehavior::ThreeValuedLogic)
8138}
8139
8140fn bind_select_filter(
8141 filter: &llkv_plan::SelectFilter,
8142 bindings: &FxHashMap<String, Literal>,
8143) -> ExecutorResult<llkv_plan::SelectFilter> {
8144 let predicate = bind_predicate_expr(&filter.predicate, bindings)?;
8145 let subqueries = filter
8146 .subqueries
8147 .iter()
8148 .map(|subquery| bind_filter_subquery(subquery, bindings))
8149 .collect::<ExecutorResult<Vec<_>>>()?;
8150
8151 Ok(llkv_plan::SelectFilter {
8152 predicate,
8153 subqueries,
8154 })
8155}
8156
8157fn bind_filter_subquery(
8158 subquery: &llkv_plan::FilterSubquery,
8159 bindings: &FxHashMap<String, Literal>,
8160) -> ExecutorResult<llkv_plan::FilterSubquery> {
8161 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
8162 Ok(llkv_plan::FilterSubquery {
8163 id: subquery.id,
8164 plan: Box::new(bound_plan),
8165 correlated_columns: subquery.correlated_columns.clone(),
8166 })
8167}
8168
8169fn bind_scalar_subquery(
8170 subquery: &llkv_plan::ScalarSubquery,
8171 bindings: &FxHashMap<String, Literal>,
8172) -> ExecutorResult<llkv_plan::ScalarSubquery> {
8173 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
8174 Ok(llkv_plan::ScalarSubquery {
8175 id: subquery.id,
8176 plan: Box::new(bound_plan),
8177 correlated_columns: subquery.correlated_columns.clone(),
8178 })
8179}
8180
8181fn bind_projection(
8182 projection: &SelectProjection,
8183 bindings: &FxHashMap<String, Literal>,
8184) -> ExecutorResult<SelectProjection> {
8185 match projection {
8186 SelectProjection::AllColumns => Ok(projection.clone()),
8187 SelectProjection::AllColumnsExcept { exclude } => Ok(SelectProjection::AllColumnsExcept {
8188 exclude: exclude.clone(),
8189 }),
8190 SelectProjection::Column { name, alias } => {
8191 if let Some(literal) = bindings.get(name) {
8192 let expr = ScalarExpr::Literal(literal.clone());
8193 Ok(SelectProjection::Computed {
8194 expr,
8195 alias: alias.clone().unwrap_or_else(|| name.clone()),
8196 })
8197 } else {
8198 Ok(projection.clone())
8199 }
8200 }
8201 SelectProjection::Computed { expr, alias } => Ok(SelectProjection::Computed {
8202 expr: bind_scalar_expr(expr, bindings)?,
8203 alias: alias.clone(),
8204 }),
8205 }
8206}
8207
8208fn bind_aggregate_expr(
8209 aggregate: &AggregateExpr,
8210 bindings: &FxHashMap<String, Literal>,
8211) -> ExecutorResult<AggregateExpr> {
8212 match aggregate {
8213 AggregateExpr::CountStar { .. } => Ok(aggregate.clone()),
8214 AggregateExpr::Column {
8215 column,
8216 alias,
8217 function,
8218 distinct,
8219 } => {
8220 if bindings.contains_key(column) {
8221 return Err(Error::InvalidArgumentError(
8222 "correlated columns are not supported inside aggregate expressions".into(),
8223 ));
8224 }
8225 Ok(AggregateExpr::Column {
8226 column: column.clone(),
8227 alias: alias.clone(),
8228 function: function.clone(),
8229 distinct: *distinct,
8230 })
8231 }
8232 }
8233}
8234
8235fn bind_scalar_expr(
8236 expr: &ScalarExpr<String>,
8237 bindings: &FxHashMap<String, Literal>,
8238) -> ExecutorResult<ScalarExpr<String>> {
8239 match expr {
8240 ScalarExpr::Column(name) => {
8241 if let Some(literal) = bindings.get(name) {
8242 Ok(ScalarExpr::Literal(literal.clone()))
8243 } else {
8244 Ok(ScalarExpr::Column(name.clone()))
8245 }
8246 }
8247 ScalarExpr::Literal(literal) => Ok(ScalarExpr::Literal(literal.clone())),
8248 ScalarExpr::Binary { left, op, right } => Ok(ScalarExpr::Binary {
8249 left: Box::new(bind_scalar_expr(left, bindings)?),
8250 op: *op,
8251 right: Box::new(bind_scalar_expr(right, bindings)?),
8252 }),
8253 ScalarExpr::Compare { left, op, right } => Ok(ScalarExpr::Compare {
8254 left: Box::new(bind_scalar_expr(left, bindings)?),
8255 op: *op,
8256 right: Box::new(bind_scalar_expr(right, bindings)?),
8257 }),
8258 ScalarExpr::Aggregate(call) => Ok(ScalarExpr::Aggregate(call.clone())),
8259 ScalarExpr::GetField { base, field_name } => {
8260 let bound_base = bind_scalar_expr(base, bindings)?;
8261 match bound_base {
8262 ScalarExpr::Literal(literal) => {
8263 let value = extract_struct_field(&literal, field_name).unwrap_or(Literal::Null);
8264 Ok(ScalarExpr::Literal(value))
8265 }
8266 other => Ok(ScalarExpr::GetField {
8267 base: Box::new(other),
8268 field_name: field_name.clone(),
8269 }),
8270 }
8271 }
8272 ScalarExpr::Cast { expr, data_type } => Ok(ScalarExpr::Cast {
8273 expr: Box::new(bind_scalar_expr(expr, bindings)?),
8274 data_type: data_type.clone(),
8275 }),
8276 ScalarExpr::Case {
8277 operand,
8278 branches,
8279 else_expr,
8280 } => {
8281 let bound_operand = match operand {
8282 Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
8283 None => None,
8284 };
8285 let mut bound_branches = Vec::with_capacity(branches.len());
8286 for (when_expr, then_expr) in branches {
8287 bound_branches.push((
8288 bind_scalar_expr(when_expr, bindings)?,
8289 bind_scalar_expr(then_expr, bindings)?,
8290 ));
8291 }
8292 let bound_else = match else_expr {
8293 Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
8294 None => None,
8295 };
8296 Ok(ScalarExpr::Case {
8297 operand: bound_operand,
8298 branches: bound_branches,
8299 else_expr: bound_else,
8300 })
8301 }
8302 ScalarExpr::Coalesce(items) => {
8303 let mut bound_items = Vec::with_capacity(items.len());
8304 for item in items {
8305 bound_items.push(bind_scalar_expr(item, bindings)?);
8306 }
8307 Ok(ScalarExpr::Coalesce(bound_items))
8308 }
8309 ScalarExpr::Not(inner) => Ok(ScalarExpr::Not(Box::new(bind_scalar_expr(
8310 inner, bindings,
8311 )?))),
8312 ScalarExpr::IsNull { expr, negated } => Ok(ScalarExpr::IsNull {
8313 expr: Box::new(bind_scalar_expr(expr, bindings)?),
8314 negated: *negated,
8315 }),
8316 ScalarExpr::Random => Ok(ScalarExpr::Random),
8317 ScalarExpr::ScalarSubquery(sub) => Ok(ScalarExpr::ScalarSubquery(sub.clone())),
8318 }
8319}
8320
8321fn bind_predicate_expr(
8322 expr: &LlkvExpr<'static, String>,
8323 bindings: &FxHashMap<String, Literal>,
8324) -> ExecutorResult<LlkvExpr<'static, String>> {
8325 match expr {
8326 LlkvExpr::And(children) => {
8327 let mut bound = Vec::with_capacity(children.len());
8328 for child in children {
8329 bound.push(bind_predicate_expr(child, bindings)?);
8330 }
8331 Ok(LlkvExpr::And(bound))
8332 }
8333 LlkvExpr::Or(children) => {
8334 let mut bound = Vec::with_capacity(children.len());
8335 for child in children {
8336 bound.push(bind_predicate_expr(child, bindings)?);
8337 }
8338 Ok(LlkvExpr::Or(bound))
8339 }
8340 LlkvExpr::Not(inner) => Ok(LlkvExpr::Not(Box::new(bind_predicate_expr(
8341 inner, bindings,
8342 )?))),
8343 LlkvExpr::Pred(filter) => bind_filter_predicate(filter, bindings),
8344 LlkvExpr::Compare { left, op, right } => Ok(LlkvExpr::Compare {
8345 left: bind_scalar_expr(left, bindings)?,
8346 op: *op,
8347 right: bind_scalar_expr(right, bindings)?,
8348 }),
8349 LlkvExpr::InList {
8350 expr,
8351 list,
8352 negated,
8353 } => {
8354 let target = bind_scalar_expr(expr, bindings)?;
8355 let mut bound_list = Vec::with_capacity(list.len());
8356 for item in list {
8357 bound_list.push(bind_scalar_expr(item, bindings)?);
8358 }
8359 Ok(LlkvExpr::InList {
8360 expr: target,
8361 list: bound_list,
8362 negated: *negated,
8363 })
8364 }
8365 LlkvExpr::IsNull { expr, negated } => Ok(LlkvExpr::IsNull {
8366 expr: bind_scalar_expr(expr, bindings)?,
8367 negated: *negated,
8368 }),
8369 LlkvExpr::Literal(value) => Ok(LlkvExpr::Literal(*value)),
8370 LlkvExpr::Exists(subquery) => Ok(LlkvExpr::Exists(subquery.clone())),
8371 }
8372}
8373
8374fn bind_filter_predicate(
8375 filter: &Filter<'static, String>,
8376 bindings: &FxHashMap<String, Literal>,
8377) -> ExecutorResult<LlkvExpr<'static, String>> {
8378 if let Some(literal) = bindings.get(&filter.field_id) {
8379 let result = evaluate_filter_against_literal(literal, &filter.op)?;
8380 return Ok(LlkvExpr::Literal(result));
8381 }
8382 Ok(LlkvExpr::Pred(filter.clone()))
8383}
8384
8385fn evaluate_filter_against_literal(value: &Literal, op: &Operator) -> ExecutorResult<bool> {
8386 use std::ops::Bound;
8387
8388 match op {
8389 Operator::IsNull => Ok(matches!(value, Literal::Null)),
8390 Operator::IsNotNull => Ok(!matches!(value, Literal::Null)),
8391 Operator::Equals(rhs) => Ok(literal_equals(value, rhs).unwrap_or(false)),
8392 Operator::GreaterThan(rhs) => Ok(literal_compare(value, rhs)
8393 .map(|cmp| cmp == std::cmp::Ordering::Greater)
8394 .unwrap_or(false)),
8395 Operator::GreaterThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
8396 .map(|cmp| matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal))
8397 .unwrap_or(false)),
8398 Operator::LessThan(rhs) => Ok(literal_compare(value, rhs)
8399 .map(|cmp| cmp == std::cmp::Ordering::Less)
8400 .unwrap_or(false)),
8401 Operator::LessThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
8402 .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal))
8403 .unwrap_or(false)),
8404 Operator::In(values) => Ok(values
8405 .iter()
8406 .any(|candidate| literal_equals(value, candidate).unwrap_or(false))),
8407 Operator::Range { lower, upper } => {
8408 let lower_ok = match lower {
8409 Bound::Unbounded => Some(true),
8410 Bound::Included(bound) => literal_compare(value, bound).map(|cmp| {
8411 matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
8412 }),
8413 Bound::Excluded(bound) => {
8414 literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Greater)
8415 }
8416 }
8417 .unwrap_or(false);
8418
8419 let upper_ok = match upper {
8420 Bound::Unbounded => Some(true),
8421 Bound::Included(bound) => literal_compare(value, bound)
8422 .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal)),
8423 Bound::Excluded(bound) => {
8424 literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Less)
8425 }
8426 }
8427 .unwrap_or(false);
8428
8429 Ok(lower_ok && upper_ok)
8430 }
8431 Operator::StartsWith {
8432 pattern,
8433 case_sensitive,
8434 } => {
8435 let target = if *case_sensitive {
8436 pattern.to_string()
8437 } else {
8438 pattern.to_ascii_lowercase()
8439 };
8440 Ok(literal_string(value, *case_sensitive)
8441 .map(|source| source.starts_with(&target))
8442 .unwrap_or(false))
8443 }
8444 Operator::EndsWith {
8445 pattern,
8446 case_sensitive,
8447 } => {
8448 let target = if *case_sensitive {
8449 pattern.to_string()
8450 } else {
8451 pattern.to_ascii_lowercase()
8452 };
8453 Ok(literal_string(value, *case_sensitive)
8454 .map(|source| source.ends_with(&target))
8455 .unwrap_or(false))
8456 }
8457 Operator::Contains {
8458 pattern,
8459 case_sensitive,
8460 } => {
8461 let target = if *case_sensitive {
8462 pattern.to_string()
8463 } else {
8464 pattern.to_ascii_lowercase()
8465 };
8466 Ok(literal_string(value, *case_sensitive)
8467 .map(|source| source.contains(&target))
8468 .unwrap_or(false))
8469 }
8470 }
8471}
8472
8473fn literal_compare(lhs: &Literal, rhs: &Literal) -> Option<std::cmp::Ordering> {
8474 match (lhs, rhs) {
8475 (Literal::Integer(a), Literal::Integer(b)) => Some(a.cmp(b)),
8476 (Literal::Float(a), Literal::Float(b)) => a.partial_cmp(b),
8477 (Literal::Integer(a), Literal::Float(b)) => (*a as f64).partial_cmp(b),
8478 (Literal::Float(a), Literal::Integer(b)) => a.partial_cmp(&(*b as f64)),
8479 (Literal::String(a), Literal::String(b)) => Some(a.cmp(b)),
8480 _ => None,
8481 }
8482}
8483
8484fn literal_equals(lhs: &Literal, rhs: &Literal) -> Option<bool> {
8485 match (lhs, rhs) {
8486 (Literal::Boolean(a), Literal::Boolean(b)) => Some(a == b),
8487 (Literal::String(a), Literal::String(b)) => Some(a == b),
8488 (Literal::Integer(_), Literal::Integer(_))
8489 | (Literal::Integer(_), Literal::Float(_))
8490 | (Literal::Float(_), Literal::Integer(_))
8491 | (Literal::Float(_), Literal::Float(_)) => {
8492 literal_compare(lhs, rhs).map(|cmp| cmp == std::cmp::Ordering::Equal)
8493 }
8494 _ => None,
8495 }
8496}
8497
8498fn literal_string(literal: &Literal, case_sensitive: bool) -> Option<String> {
8499 match literal {
8500 Literal::String(value) => {
8501 if case_sensitive {
8502 Some(value.clone())
8503 } else {
8504 Some(value.to_ascii_lowercase())
8505 }
8506 }
8507 _ => None,
8508 }
8509}
8510
8511fn extract_struct_field(literal: &Literal, field_name: &str) -> Option<Literal> {
8512 if let Literal::Struct(fields) = literal {
8513 for (name, value) in fields {
8514 if name.eq_ignore_ascii_case(field_name) {
8515 return Some((**value).clone());
8516 }
8517 }
8518 }
8519 None
8520}
8521
8522fn array_value_to_literal(array: &ArrayRef, idx: usize) -> ExecutorResult<Literal> {
8523 if array.is_null(idx) {
8524 return Ok(Literal::Null);
8525 }
8526
8527 match array.data_type() {
8528 DataType::Boolean => {
8529 let array = array
8530 .as_any()
8531 .downcast_ref::<BooleanArray>()
8532 .ok_or_else(|| Error::Internal("failed to downcast boolean array".into()))?;
8533 Ok(Literal::Boolean(array.value(idx)))
8534 }
8535 DataType::Int8 => {
8536 let array = array
8537 .as_any()
8538 .downcast_ref::<Int8Array>()
8539 .ok_or_else(|| Error::Internal("failed to downcast int8 array".into()))?;
8540 Ok(Literal::Integer(array.value(idx) as i128))
8541 }
8542 DataType::Int16 => {
8543 let array = array
8544 .as_any()
8545 .downcast_ref::<Int16Array>()
8546 .ok_or_else(|| Error::Internal("failed to downcast int16 array".into()))?;
8547 Ok(Literal::Integer(array.value(idx) as i128))
8548 }
8549 DataType::Int32 => {
8550 let array = array
8551 .as_any()
8552 .downcast_ref::<Int32Array>()
8553 .ok_or_else(|| Error::Internal("failed to downcast int32 array".into()))?;
8554 Ok(Literal::Integer(array.value(idx) as i128))
8555 }
8556 DataType::Int64 => {
8557 let array = array
8558 .as_any()
8559 .downcast_ref::<Int64Array>()
8560 .ok_or_else(|| Error::Internal("failed to downcast int64 array".into()))?;
8561 Ok(Literal::Integer(array.value(idx) as i128))
8562 }
8563 DataType::UInt8 => {
8564 let array = array
8565 .as_any()
8566 .downcast_ref::<UInt8Array>()
8567 .ok_or_else(|| Error::Internal("failed to downcast uint8 array".into()))?;
8568 Ok(Literal::Integer(array.value(idx) as i128))
8569 }
8570 DataType::UInt16 => {
8571 let array = array
8572 .as_any()
8573 .downcast_ref::<UInt16Array>()
8574 .ok_or_else(|| Error::Internal("failed to downcast uint16 array".into()))?;
8575 Ok(Literal::Integer(array.value(idx) as i128))
8576 }
8577 DataType::UInt32 => {
8578 let array = array
8579 .as_any()
8580 .downcast_ref::<UInt32Array>()
8581 .ok_or_else(|| Error::Internal("failed to downcast uint32 array".into()))?;
8582 Ok(Literal::Integer(array.value(idx) as i128))
8583 }
8584 DataType::UInt64 => {
8585 let array = array
8586 .as_any()
8587 .downcast_ref::<UInt64Array>()
8588 .ok_or_else(|| Error::Internal("failed to downcast uint64 array".into()))?;
8589 Ok(Literal::Integer(array.value(idx) as i128))
8590 }
8591 DataType::Float32 => {
8592 let array = array
8593 .as_any()
8594 .downcast_ref::<Float32Array>()
8595 .ok_or_else(|| Error::Internal("failed to downcast float32 array".into()))?;
8596 Ok(Literal::Float(array.value(idx) as f64))
8597 }
8598 DataType::Float64 => {
8599 let array = array
8600 .as_any()
8601 .downcast_ref::<Float64Array>()
8602 .ok_or_else(|| Error::Internal("failed to downcast float64 array".into()))?;
8603 Ok(Literal::Float(array.value(idx)))
8604 }
8605 DataType::Utf8 => {
8606 let array = array
8607 .as_any()
8608 .downcast_ref::<StringArray>()
8609 .ok_or_else(|| Error::Internal("failed to downcast utf8 array".into()))?;
8610 Ok(Literal::String(array.value(idx).to_string()))
8611 }
8612 DataType::LargeUtf8 => {
8613 let array = array
8614 .as_any()
8615 .downcast_ref::<LargeStringArray>()
8616 .ok_or_else(|| Error::Internal("failed to downcast large utf8 array".into()))?;
8617 Ok(Literal::String(array.value(idx).to_string()))
8618 }
8619 DataType::Struct(fields) => {
8620 let struct_array = array
8621 .as_any()
8622 .downcast_ref::<StructArray>()
8623 .ok_or_else(|| Error::Internal("failed to downcast struct array".into()))?;
8624 let mut members = Vec::with_capacity(fields.len());
8625 for (field_idx, field) in fields.iter().enumerate() {
8626 let child = struct_array.column(field_idx);
8627 let literal = array_value_to_literal(child, idx)?;
8628 members.push((field.name().clone(), Box::new(literal)));
8629 }
8630 Ok(Literal::Struct(members))
8631 }
8632 other => Err(Error::InvalidArgumentError(format!(
8633 "unsupported scalar subquery result type: {other:?}"
8634 ))),
8635 }
8636}
8637
8638fn collect_scalar_subquery_ids(expr: &ScalarExpr<FieldId>, ids: &mut FxHashSet<SubqueryId>) {
8639 match expr {
8640 ScalarExpr::ScalarSubquery(subquery) => {
8641 ids.insert(subquery.id);
8642 }
8643 ScalarExpr::Binary { left, right, .. } => {
8644 collect_scalar_subquery_ids(left, ids);
8645 collect_scalar_subquery_ids(right, ids);
8646 }
8647 ScalarExpr::Compare { left, right, .. } => {
8648 collect_scalar_subquery_ids(left, ids);
8649 collect_scalar_subquery_ids(right, ids);
8650 }
8651 ScalarExpr::GetField { base, .. } => {
8652 collect_scalar_subquery_ids(base, ids);
8653 }
8654 ScalarExpr::Cast { expr, .. } => {
8655 collect_scalar_subquery_ids(expr, ids);
8656 }
8657 ScalarExpr::Not(expr) => {
8658 collect_scalar_subquery_ids(expr, ids);
8659 }
8660 ScalarExpr::IsNull { expr, .. } => {
8661 collect_scalar_subquery_ids(expr, ids);
8662 }
8663 ScalarExpr::Case {
8664 operand,
8665 branches,
8666 else_expr,
8667 } => {
8668 if let Some(op) = operand {
8669 collect_scalar_subquery_ids(op, ids);
8670 }
8671 for (when_expr, then_expr) in branches {
8672 collect_scalar_subquery_ids(when_expr, ids);
8673 collect_scalar_subquery_ids(then_expr, ids);
8674 }
8675 if let Some(else_expr) = else_expr {
8676 collect_scalar_subquery_ids(else_expr, ids);
8677 }
8678 }
8679 ScalarExpr::Coalesce(items) => {
8680 for item in items {
8681 collect_scalar_subquery_ids(item, ids);
8682 }
8683 }
8684 ScalarExpr::Aggregate(_)
8685 | ScalarExpr::Column(_)
8686 | ScalarExpr::Literal(_)
8687 | ScalarExpr::Random => {}
8688 }
8689}
8690
8691fn rewrite_scalar_expr_for_subqueries(
8692 expr: &ScalarExpr<FieldId>,
8693 mapping: &FxHashMap<SubqueryId, FieldId>,
8694) -> ScalarExpr<FieldId> {
8695 match expr {
8696 ScalarExpr::ScalarSubquery(subquery) => mapping
8697 .get(&subquery.id)
8698 .map(|field_id| ScalarExpr::Column(*field_id))
8699 .unwrap_or_else(|| ScalarExpr::ScalarSubquery(subquery.clone())),
8700 ScalarExpr::Binary { left, op, right } => ScalarExpr::Binary {
8701 left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
8702 op: *op,
8703 right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
8704 },
8705 ScalarExpr::Compare { left, op, right } => ScalarExpr::Compare {
8706 left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
8707 op: *op,
8708 right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
8709 },
8710 ScalarExpr::GetField { base, field_name } => ScalarExpr::GetField {
8711 base: Box::new(rewrite_scalar_expr_for_subqueries(base, mapping)),
8712 field_name: field_name.clone(),
8713 },
8714 ScalarExpr::Cast { expr, data_type } => ScalarExpr::Cast {
8715 expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
8716 data_type: data_type.clone(),
8717 },
8718 ScalarExpr::Not(expr) => {
8719 ScalarExpr::Not(Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)))
8720 }
8721 ScalarExpr::IsNull { expr, negated } => ScalarExpr::IsNull {
8722 expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
8723 negated: *negated,
8724 },
8725 ScalarExpr::Case {
8726 operand,
8727 branches,
8728 else_expr,
8729 } => ScalarExpr::Case {
8730 operand: operand
8731 .as_ref()
8732 .map(|op| Box::new(rewrite_scalar_expr_for_subqueries(op, mapping))),
8733 branches: branches
8734 .iter()
8735 .map(|(when_expr, then_expr)| {
8736 (
8737 rewrite_scalar_expr_for_subqueries(when_expr, mapping),
8738 rewrite_scalar_expr_for_subqueries(then_expr, mapping),
8739 )
8740 })
8741 .collect(),
8742 else_expr: else_expr
8743 .as_ref()
8744 .map(|expr| Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping))),
8745 },
8746 ScalarExpr::Coalesce(items) => ScalarExpr::Coalesce(
8747 items
8748 .iter()
8749 .map(|item| rewrite_scalar_expr_for_subqueries(item, mapping))
8750 .collect(),
8751 ),
8752 ScalarExpr::Aggregate(_)
8753 | ScalarExpr::Column(_)
8754 | ScalarExpr::Literal(_)
8755 | ScalarExpr::Random => expr.clone(),
8756 }
8757}
8758
8759fn collect_correlated_bindings(
8760 context: &mut CrossProductExpressionContext,
8761 batch: &RecordBatch,
8762 row_idx: usize,
8763 columns: &[llkv_plan::CorrelatedColumn],
8764) -> ExecutorResult<FxHashMap<String, Literal>> {
8765 let mut out = FxHashMap::default();
8766
8767 for correlated in columns {
8768 if !correlated.field_path.is_empty() {
8769 return Err(Error::InvalidArgumentError(
8770 "correlated field path resolution is not yet supported".into(),
8771 ));
8772 }
8773
8774 let field_id = context
8775 .field_id_for_column(&correlated.column)
8776 .ok_or_else(|| {
8777 Error::InvalidArgumentError(format!(
8778 "correlated column '{}' not found in outer query output",
8779 correlated.column
8780 ))
8781 })?;
8782
8783 let accessor = context.column_accessor(field_id, batch)?;
8784 let literal = accessor.literal_at(row_idx)?;
8785 out.insert(correlated.placeholder.clone(), literal);
8786 }
8787
8788 Ok(out)
8789}
8790
/// A prepared `SELECT` result: either a pending scan of one table or an
/// already-materialized batch. Nothing is scanned until
/// [`SelectExecution::stream`] (or a collecting helper built on it) runs.
#[derive(Clone)]
pub struct SelectExecution<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Name returned by `table_name()`.
    table_name: String,
    /// Output schema returned by `schema()`; also used when concatenating,
    /// sorting, or synthesizing all-NULL batches during streaming.
    schema: Arc<Schema>,
    /// The work to perform when the execution is streamed.
    stream: SelectStream<P>,
}
8801
/// The pending work behind a [`SelectExecution`].
#[derive(Clone)]
enum SelectStream<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Scan `table` when streamed, applying the projections, filter, ordering
    /// and optional DISTINCT.
    Projection {
        /// Table to scan.
        table: Arc<ExecutorTable<P>>,
        /// Projections handed to the table scan.
        projections: Vec<ScanProjection>,
        /// Predicate handed to the table scan.
        filter_expr: LlkvExpr<'static, FieldId>,
        /// Scan options, including the scan-level order, `include_nulls`, and
        /// an optional row-id filter.
        options: ScanStreamOptions<P>,
        /// Enables all-NULL row synthesis for rows the scan did not emit;
        /// presumably set when the filter selects the whole table — confirm
        /// against callers.
        full_table_scan: bool,
        /// ORDER BY keys; applied after the scan when the scan-level order
        /// cannot satisfy them.
        order_by: Vec<OrderByPlan>,
        /// Whether duplicate rows are removed.
        distinct: bool,
    },
    /// A batch that is already materialized and is emitted unchanged.
    Aggregation {
        batch: RecordBatch,
    },
}
8820
impl<P> SelectExecution<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Creates an execution that scans `table` when streamed.
    ///
    /// No scanning happens here; the arguments are stored in a
    /// `SelectStream::Projection` and consumed by [`Self::stream`].
    #[allow(clippy::too_many_arguments)]
    fn new_projection(
        table_name: String,
        schema: Arc<Schema>,
        table: Arc<ExecutorTable<P>>,
        projections: Vec<ScanProjection>,
        filter_expr: LlkvExpr<'static, FieldId>,
        options: ScanStreamOptions<P>,
        full_table_scan: bool,
        order_by: Vec<OrderByPlan>,
        distinct: bool,
    ) -> Self {
        Self {
            table_name,
            schema,
            stream: SelectStream::Projection {
                table,
                projections,
                filter_expr,
                options,
                full_table_scan,
                order_by,
                distinct,
            },
        }
    }

    /// Creates an execution that emits `batch` unchanged when streamed.
    pub fn new_single_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
        Self {
            table_name,
            schema,
            stream: SelectStream::Aggregation { batch },
        }
    }

    /// Alias for [`Self::new_single_batch`].
    pub fn from_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
        Self::new_single_batch(table_name, schema, batch)
    }

    /// Returns the table name this execution was created with.
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    /// Returns a shared handle to the output schema.
    pub fn schema(&self) -> Arc<Schema> {
        Arc::clone(&self.schema)
    }

    /// Consumes the execution, passing each result batch to `on_batch` in order.
    ///
    /// For a projection this runs the table scan. Batches are forwarded as the
    /// scan yields them unless they must be buffered. Buffering happens when a
    /// post-scan sort is needed (more than one ORDER BY key, or ORDER BY without
    /// a scan-level order in `options.order`) or when the scan-level order puts
    /// nulls first. DISTINCT filtering uses a `DistinctState` shared across
    /// batches. For full table scans, rows the scan did not emit may be
    /// synthesized as all-NULL rows (conditions are commented below).
    ///
    /// # Errors
    ///
    /// Returns the first error raised by the scan, distinct filtering,
    /// `on_batch`, batch concatenation, sorting, or null synthesis.
    pub fn stream(
        self,
        mut on_batch: impl FnMut(RecordBatch) -> ExecutorResult<()>,
    ) -> ExecutorResult<()> {
        let schema = Arc::clone(&self.schema);
        match self.stream {
            SelectStream::Projection {
                table,
                projections,
                filter_expr,
                options,
                full_table_scan,
                order_by,
                distinct,
            } => {
                let total_rows = table.total_rows.load(Ordering::SeqCst);
                // An empty table yields no batches at all.
                if total_rows == 0 {
                    return Ok(());
                }

                // The scan callback returns nothing, so the first failure is
                // stashed here and every later batch is ignored.
                let mut error: Option<Error> = None;
                let mut produced = false;
                let mut produced_rows: u64 = 0;
                let capture_nulls_first = matches!(options.order, Some(spec) if spec.nulls_first);
                // `options.order` holds at most one scan-level key; more keys, or
                // ORDER BY with no scan-level order, require sorting afterwards.
                let needs_post_sort =
                    !order_by.is_empty() && (order_by.len() > 1 || options.order.is_none());
                let collect_batches = needs_post_sort || capture_nulls_first;
                let include_nulls = options.include_nulls;
                let has_row_id_filter = options.row_id_filter.is_some();
                let mut distinct_state = if distinct {
                    Some(DistinctState::default())
                } else {
                    None
                };
                let scan_options = options;
                let mut buffered_batches: Vec<RecordBatch> = Vec::new();
                table
                    .table
                    .scan_stream(projections, &filter_expr, scan_options, |batch| {
                        if error.is_some() {
                            return;
                        }
                        let mut batch = batch;
                        if let Some(state) = distinct_state.as_mut() {
                            match distinct_filter_batch(batch, state) {
                                Ok(Some(filtered)) => {
                                    batch = filtered;
                                }
                                // No new distinct rows in this batch.
                                Ok(None) => {
                                    return;
                                }
                                Err(err) => {
                                    error = Some(err);
                                    return;
                                }
                            }
                        }
                        produced = true;
                        produced_rows = produced_rows.saturating_add(batch.num_rows() as u64);
                        if collect_batches {
                            buffered_batches.push(batch);
                        } else if let Err(err) = on_batch(batch) {
                            error = Some(err);
                        }
                    })?;
                if let Some(err) = error {
                    return Err(err);
                }
                if !produced {
                    // Nothing emitted: a non-DISTINCT full table scan reports every
                    // row as all-NULL.
                    // NOTE(review): unlike the padding below, this path does not
                    // check `include_nulls` or `has_row_id_filter` — confirm intended.
                    if !distinct && full_table_scan && total_rows > 0 {
                        for batch in synthesize_null_scan(Arc::clone(&schema), total_rows)? {
                            on_batch(batch)?;
                        }
                    }
                    return Ok(());
                }
                // Rows the scan skipped are padded back in as all-NULL rows, but
                // only when nulls were requested and no row-id filter is active.
                let mut null_batches: Vec<RecordBatch> = Vec::new();
                if !distinct
                    && include_nulls
                    && full_table_scan
                    && produced_rows < total_rows
                    && !has_row_id_filter
                {
                    let missing = total_rows - produced_rows;
                    if missing > 0 {
                        null_batches = synthesize_null_scan(Arc::clone(&schema), missing)?;
                    }
                }

                if collect_batches {
                    if needs_post_sort {
                        // Sort scanned and synthesized rows together as one batch.
                        if !null_batches.is_empty() {
                            buffered_batches.extend(null_batches);
                        }
                        if !buffered_batches.is_empty() {
                            let combined =
                                concat_batches(&schema, &buffered_batches).map_err(|err| {
                                    Error::InvalidArgumentError(format!(
                                        "failed to concatenate result batches for ORDER BY: {}",
                                        err
                                    ))
                                })?;
                            let sorted_batch =
                                sort_record_batch_with_order(&schema, &combined, &order_by)?;
                            on_batch(sorted_batch)?;
                        }
                    } else if capture_nulls_first {
                        // No post-sort: emit the NULL rows ahead of the buffered scan output.
                        for batch in null_batches {
                            on_batch(batch)?;
                        }
                        for batch in buffered_batches {
                            on_batch(batch)?;
                        }
                    }
                } else if !null_batches.is_empty() {
                    // Scanned batches were already forwarded; NULL rows go last.
                    for batch in null_batches {
                        on_batch(batch)?;
                    }
                }
                Ok(())
            }
            SelectStream::Aggregation { batch } => on_batch(batch),
        }
    }

    /// Streams the execution and returns every batch in emission order.
    pub fn collect(self) -> ExecutorResult<Vec<RecordBatch>> {
        let mut batches = Vec::new();
        self.stream(|batch| {
            batches.push(batch);
            Ok(())
        })?;
        Ok(batches)
    }

    /// Streams the execution and converts every cell into a `PlanValue`,
    /// returning the rows together with the schema's column names.
    pub fn collect_rows(self) -> ExecutorResult<ExecutorRowBatch> {
        let schema = self.schema();
        let mut rows: Vec<Vec<PlanValue>> = Vec::new();
        self.stream(|batch| {
            for row_idx in 0..batch.num_rows() {
                let mut row: Vec<PlanValue> = Vec::with_capacity(batch.num_columns());
                for col_idx in 0..batch.num_columns() {
                    let value = llkv_plan::plan_value_from_array(batch.column(col_idx), row_idx)?;
                    row.push(value);
                }
                rows.push(row);
            }
            Ok(())
        })?;
        let columns = schema
            .fields()
            .iter()
            .map(|field| field.name().to_string())
            .collect();
        Ok(ExecutorRowBatch { columns, rows })
    }

    /// Like [`Self::collect_rows`], but returns only the rows.
    pub fn into_rows(self) -> ExecutorResult<Vec<Vec<PlanValue>>> {
        Ok(self.collect_rows()?.rows)
    }
}
9041
9042impl<P> fmt::Debug for SelectExecution<P>
9043where
9044 P: Pager<Blob = EntryHandle> + Send + Sync,
9045{
9046 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
9047 f.debug_struct("SelectExecution")
9048 .field("table_name", &self.table_name)
9049 .field("schema", &self.schema)
9050 .finish()
9051 }
9052}
9053
9054fn expand_order_targets(
9059 order_items: &[OrderByPlan],
9060 projections: &[ScanProjection],
9061) -> ExecutorResult<Vec<OrderByPlan>> {
9062 let mut expanded = Vec::new();
9063
9064 for item in order_items {
9065 match &item.target {
9066 OrderTarget::All => {
9067 if projections.is_empty() {
9068 return Err(Error::InvalidArgumentError(
9069 "ORDER BY ALL requires at least one projection".into(),
9070 ));
9071 }
9072
9073 for (idx, projection) in projections.iter().enumerate() {
9074 if matches!(projection, ScanProjection::Computed { .. }) {
9075 return Err(Error::InvalidArgumentError(
9076 "ORDER BY ALL cannot reference computed projections".into(),
9077 ));
9078 }
9079
9080 let mut clone = item.clone();
9081 clone.target = OrderTarget::Index(idx);
9082 expanded.push(clone);
9083 }
9084 }
9085 _ => expanded.push(item.clone()),
9086 }
9087 }
9088
9089 Ok(expanded)
9090}
9091
9092fn resolve_scan_order<P>(
9093 table: &ExecutorTable<P>,
9094 projections: &[ScanProjection],
9095 order_plan: &OrderByPlan,
9096) -> ExecutorResult<ScanOrderSpec>
9097where
9098 P: Pager<Blob = EntryHandle> + Send + Sync,
9099{
9100 let (column, field_id) = match &order_plan.target {
9101 OrderTarget::Column(name) => {
9102 let column = table.schema.resolve(name).ok_or_else(|| {
9103 Error::InvalidArgumentError(format!("unknown column '{}' in ORDER BY", name))
9104 })?;
9105 (column, column.field_id)
9106 }
9107 OrderTarget::Index(position) => {
9108 let projection = projections.get(*position).ok_or_else(|| {
9109 Error::InvalidArgumentError(format!(
9110 "ORDER BY position {} is out of range",
9111 position + 1
9112 ))
9113 })?;
9114 match projection {
9115 ScanProjection::Column(store_projection) => {
9116 let field_id = store_projection.logical_field_id.field_id();
9117 let column = table.schema.column_by_field_id(field_id).ok_or_else(|| {
9118 Error::InvalidArgumentError(format!(
9119 "unknown column with field id {field_id} in ORDER BY"
9120 ))
9121 })?;
9122 (column, field_id)
9123 }
9124 ScanProjection::Computed { .. } => {
9125 return Err(Error::InvalidArgumentError(
9126 "ORDER BY position referring to computed projection is not supported"
9127 .into(),
9128 ));
9129 }
9130 }
9131 }
9132 OrderTarget::All => {
9133 return Err(Error::InvalidArgumentError(
9134 "ORDER BY ALL should be expanded before execution".into(),
9135 ));
9136 }
9137 };
9138
9139 let transform = match order_plan.sort_type {
9140 OrderSortType::Native => match column.data_type {
9141 DataType::Int64 => ScanOrderTransform::IdentityInteger,
9142 DataType::Utf8 => ScanOrderTransform::IdentityUtf8,
9143 ref other => {
9144 return Err(Error::InvalidArgumentError(format!(
9145 "ORDER BY on column type {:?} is not supported",
9146 other
9147 )));
9148 }
9149 },
9150 OrderSortType::CastTextToInteger => {
9151 if column.data_type != DataType::Utf8 {
9152 return Err(Error::InvalidArgumentError(
9153 "ORDER BY CAST expects a text column".into(),
9154 ));
9155 }
9156 ScanOrderTransform::CastUtf8ToInteger
9157 }
9158 };
9159
9160 let direction = if order_plan.ascending {
9161 ScanOrderDirection::Ascending
9162 } else {
9163 ScanOrderDirection::Descending
9164 };
9165
9166 Ok(ScanOrderSpec {
9167 field_id,
9168 direction,
9169 nulls_first: order_plan.nulls_first,
9170 transform,
9171 })
9172}
9173
9174fn synthesize_null_scan(schema: Arc<Schema>, total_rows: u64) -> ExecutorResult<Vec<RecordBatch>> {
9175 let row_count = usize::try_from(total_rows).map_err(|_| {
9176 Error::InvalidArgumentError("table row count exceeds supported in-memory batch size".into())
9177 })?;
9178
9179 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
9180 for field in schema.fields() {
9181 match field.data_type() {
9182 DataType::Int64 => {
9183 let mut builder = Int64Builder::with_capacity(row_count);
9184 for _ in 0..row_count {
9185 builder.append_null();
9186 }
9187 arrays.push(Arc::new(builder.finish()));
9188 }
9189 DataType::Float64 => {
9190 let mut builder = arrow::array::Float64Builder::with_capacity(row_count);
9191 for _ in 0..row_count {
9192 builder.append_null();
9193 }
9194 arrays.push(Arc::new(builder.finish()));
9195 }
9196 DataType::Utf8 => {
9197 let mut builder = arrow::array::StringBuilder::with_capacity(row_count, 0);
9198 for _ in 0..row_count {
9199 builder.append_null();
9200 }
9201 arrays.push(Arc::new(builder.finish()));
9202 }
9203 DataType::Date32 => {
9204 let mut builder = arrow::array::Date32Builder::with_capacity(row_count);
9205 for _ in 0..row_count {
9206 builder.append_null();
9207 }
9208 arrays.push(Arc::new(builder.finish()));
9209 }
9210 other => {
9211 return Err(Error::InvalidArgumentError(format!(
9212 "unsupported data type in null synthesis: {other:?}"
9213 )));
9214 }
9215 }
9216 }
9217
9218 let batch = RecordBatch::try_new(schema, arrays)?;
9219 Ok(vec![batch])
9220}
9221
/// Rows gathered for one table, or for several tables already joined together,
/// as input/output of the cross-product and hash-join helpers.
struct TableCrossProductData {
    /// Schema of every batch in `batches` (qualified `schema.table.column` names).
    schema: Arc<Schema>,
    /// Row data; may be empty when no rows survive filtering or joining.
    batches: Vec<RecordBatch>,
    /// Number of columns contributed by each constituent table, in schema order.
    column_counts: Vec<usize>,
    /// Table index of each constituent table, parallel to `column_counts`.
    table_indices: Vec<usize>,
}
9228
9229fn collect_table_data<P>(
9230 table_index: usize,
9231 table_ref: &llkv_plan::TableRef,
9232 table: &ExecutorTable<P>,
9233 constraints: &[ColumnConstraint],
9234) -> ExecutorResult<TableCrossProductData>
9235where
9236 P: Pager<Blob = EntryHandle> + Send + Sync,
9237{
9238 if table.schema.columns.is_empty() {
9239 return Err(Error::InvalidArgumentError(format!(
9240 "table '{}' has no columns; cross products require at least one column",
9241 table_ref.qualified_name()
9242 )));
9243 }
9244
9245 let mut projections = Vec::with_capacity(table.schema.columns.len());
9246 let mut fields = Vec::with_capacity(table.schema.columns.len());
9247
9248 for column in &table.schema.columns {
9249 let table_component = table_ref
9250 .alias
9251 .as_deref()
9252 .unwrap_or(table_ref.table.as_str());
9253 let qualified_name = format!("{}.{}.{}", table_ref.schema, table_component, column.name);
9254 projections.push(ScanProjection::from(StoreProjection::with_alias(
9255 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
9256 qualified_name.clone(),
9257 )));
9258 fields.push(Field::new(
9259 qualified_name,
9260 column.data_type.clone(),
9261 column.nullable,
9262 ));
9263 }
9264
9265 let schema = Arc::new(Schema::new(fields));
9266
9267 let filter_field_id = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);
9268 let filter_expr = crate::translation::expression::full_table_scan_filter(filter_field_id);
9269
9270 let mut raw_batches = Vec::new();
9271 table.table.scan_stream(
9272 projections,
9273 &filter_expr,
9274 ScanStreamOptions {
9275 include_nulls: true,
9276 ..ScanStreamOptions::default()
9277 },
9278 |batch| {
9279 raw_batches.push(batch);
9280 },
9281 )?;
9282
9283 let mut normalized_batches = Vec::with_capacity(raw_batches.len());
9284 for batch in raw_batches {
9285 let normalized = RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec())
9286 .map_err(|err| {
9287 Error::Internal(format!(
9288 "failed to align scan batch for table '{}': {}",
9289 table_ref.qualified_name(),
9290 err
9291 ))
9292 })?;
9293 normalized_batches.push(normalized);
9294 }
9295
9296 if !constraints.is_empty() {
9297 normalized_batches = apply_column_constraints_to_batches(normalized_batches, constraints)?;
9298 }
9299
9300 Ok(TableCrossProductData {
9301 schema,
9302 batches: normalized_batches,
9303 column_counts: vec![table.schema.columns.len()],
9304 table_indices: vec![table_index],
9305 })
9306}
9307
9308fn apply_column_constraints_to_batches(
9309 batches: Vec<RecordBatch>,
9310 constraints: &[ColumnConstraint],
9311) -> ExecutorResult<Vec<RecordBatch>> {
9312 if batches.is_empty() {
9313 return Ok(batches);
9314 }
9315
9316 let mut filtered = batches;
9317 for constraint in constraints {
9318 match constraint {
9319 ColumnConstraint::Equality(lit) => {
9320 filtered = filter_batches_by_literal(filtered, lit.column.column, &lit.value)?;
9321 }
9322 ColumnConstraint::InList(in_list) => {
9323 filtered =
9324 filter_batches_by_in_list(filtered, in_list.column.column, &in_list.values)?;
9325 }
9326 }
9327 if filtered.is_empty() {
9328 break;
9329 }
9330 }
9331
9332 Ok(filtered)
9333}
9334
9335fn filter_batches_by_literal(
9336 batches: Vec<RecordBatch>,
9337 column_idx: usize,
9338 literal: &PlanValue,
9339) -> ExecutorResult<Vec<RecordBatch>> {
9340 let mut result = Vec::with_capacity(batches.len());
9341
9342 for batch in batches {
9343 if column_idx >= batch.num_columns() {
9344 return Err(Error::Internal(
9345 "literal constraint referenced invalid column index".into(),
9346 ));
9347 }
9348
9349 if batch.num_rows() == 0 {
9350 result.push(batch);
9351 continue;
9352 }
9353
9354 let column = batch.column(column_idx);
9355 let mut keep_rows: Vec<u32> = Vec::with_capacity(batch.num_rows());
9356
9357 for row_idx in 0..batch.num_rows() {
9358 if array_value_equals_plan_value(column.as_ref(), row_idx, literal)? {
9359 keep_rows.push(row_idx as u32);
9360 }
9361 }
9362
9363 if keep_rows.len() == batch.num_rows() {
9364 result.push(batch);
9365 continue;
9366 }
9367
9368 if keep_rows.is_empty() {
9369 continue;
9371 }
9372
9373 let indices = UInt32Array::from(keep_rows);
9374 let mut filtered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
9375 for col_idx in 0..batch.num_columns() {
9376 let filtered = take(batch.column(col_idx).as_ref(), &indices, None)
9377 .map_err(|err| Error::Internal(format!("failed to apply literal filter: {err}")))?;
9378 filtered_columns.push(filtered);
9379 }
9380
9381 let filtered_batch =
9382 RecordBatch::try_new(batch.schema(), filtered_columns).map_err(|err| {
9383 Error::Internal(format!(
9384 "failed to rebuild batch after literal filter: {err}"
9385 ))
9386 })?;
9387 result.push(filtered_batch);
9388 }
9389
9390 Ok(result)
9391}
9392
9393fn filter_batches_by_in_list(
9394 batches: Vec<RecordBatch>,
9395 column_idx: usize,
9396 values: &[PlanValue],
9397) -> ExecutorResult<Vec<RecordBatch>> {
9398 use arrow::array::*;
9399 use arrow::compute::or;
9400
9401 if values.is_empty() {
9402 return Ok(Vec::new());
9404 }
9405
9406 let mut result = Vec::with_capacity(batches.len());
9407
9408 for batch in batches {
9409 if column_idx >= batch.num_columns() {
9410 return Err(Error::Internal(
9411 "IN list constraint referenced invalid column index".into(),
9412 ));
9413 }
9414
9415 if batch.num_rows() == 0 {
9416 result.push(batch);
9417 continue;
9418 }
9419
9420 let column = batch.column(column_idx);
9421
9422 let mut mask = BooleanArray::from(vec![false; batch.num_rows()]);
9425
9426 for value in values {
9427 let comparison_mask = build_comparison_mask(column.as_ref(), value)?;
9428 mask = or(&mask, &comparison_mask)
9429 .map_err(|err| Error::Internal(format!("failed to OR comparison masks: {err}")))?;
9430 }
9431
9432 let true_count = mask.true_count();
9434 if true_count == batch.num_rows() {
9435 result.push(batch);
9436 continue;
9437 }
9438
9439 if true_count == 0 {
9440 continue;
9442 }
9443
9444 let filtered_batch = arrow::compute::filter_record_batch(&batch, &mask)
9446 .map_err(|err| Error::Internal(format!("failed to apply IN list filter: {err}")))?;
9447
9448 result.push(filtered_batch);
9449 }
9450
9451 Ok(result)
9452}
9453
9454fn build_comparison_mask(column: &dyn Array, value: &PlanValue) -> ExecutorResult<BooleanArray> {
9456 use arrow::array::*;
9457 use arrow::datatypes::DataType;
9458
9459 match value {
9460 PlanValue::Null => {
9461 let mut builder = BooleanBuilder::with_capacity(column.len());
9463 for i in 0..column.len() {
9464 builder.append_value(column.is_null(i));
9465 }
9466 Ok(builder.finish())
9467 }
9468 PlanValue::Integer(val) => {
9469 let mut builder = BooleanBuilder::with_capacity(column.len());
9470 match column.data_type() {
9471 DataType::Int8 => {
9472 let arr = column
9473 .as_any()
9474 .downcast_ref::<Int8Array>()
9475 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
9476 let target = *val as i8;
9477 for i in 0..arr.len() {
9478 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9479 }
9480 }
9481 DataType::Int16 => {
9482 let arr = column
9483 .as_any()
9484 .downcast_ref::<Int16Array>()
9485 .ok_or_else(|| {
9486 Error::Internal("failed to downcast to Int16Array".into())
9487 })?;
9488 let target = *val as i16;
9489 for i in 0..arr.len() {
9490 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9491 }
9492 }
9493 DataType::Int32 => {
9494 let arr = column
9495 .as_any()
9496 .downcast_ref::<Int32Array>()
9497 .ok_or_else(|| {
9498 Error::Internal("failed to downcast to Int32Array".into())
9499 })?;
9500 let target = *val as i32;
9501 for i in 0..arr.len() {
9502 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9503 }
9504 }
9505 DataType::Int64 => {
9506 let arr = column
9507 .as_any()
9508 .downcast_ref::<Int64Array>()
9509 .ok_or_else(|| {
9510 Error::Internal("failed to downcast to Int64Array".into())
9511 })?;
9512 for i in 0..arr.len() {
9513 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
9514 }
9515 }
9516 DataType::UInt8 => {
9517 let arr = column
9518 .as_any()
9519 .downcast_ref::<UInt8Array>()
9520 .ok_or_else(|| {
9521 Error::Internal("failed to downcast to UInt8Array".into())
9522 })?;
9523 let target = *val as u8;
9524 for i in 0..arr.len() {
9525 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9526 }
9527 }
9528 DataType::UInt16 => {
9529 let arr = column
9530 .as_any()
9531 .downcast_ref::<UInt16Array>()
9532 .ok_or_else(|| {
9533 Error::Internal("failed to downcast to UInt16Array".into())
9534 })?;
9535 let target = *val as u16;
9536 for i in 0..arr.len() {
9537 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9538 }
9539 }
9540 DataType::UInt32 => {
9541 let arr = column
9542 .as_any()
9543 .downcast_ref::<UInt32Array>()
9544 .ok_or_else(|| {
9545 Error::Internal("failed to downcast to UInt32Array".into())
9546 })?;
9547 let target = *val as u32;
9548 for i in 0..arr.len() {
9549 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9550 }
9551 }
9552 DataType::UInt64 => {
9553 let arr = column
9554 .as_any()
9555 .downcast_ref::<UInt64Array>()
9556 .ok_or_else(|| {
9557 Error::Internal("failed to downcast to UInt64Array".into())
9558 })?;
9559 let target = *val as u64;
9560 for i in 0..arr.len() {
9561 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9562 }
9563 }
9564 _ => {
9565 return Err(Error::Internal(format!(
9566 "unsupported integer type for IN list: {:?}",
9567 column.data_type()
9568 )));
9569 }
9570 }
9571 Ok(builder.finish())
9572 }
9573 PlanValue::Float(val) => {
9574 let mut builder = BooleanBuilder::with_capacity(column.len());
9575 match column.data_type() {
9576 DataType::Float32 => {
9577 let arr = column
9578 .as_any()
9579 .downcast_ref::<Float32Array>()
9580 .ok_or_else(|| {
9581 Error::Internal("failed to downcast to Float32Array".into())
9582 })?;
9583 let target = *val as f32;
9584 for i in 0..arr.len() {
9585 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9586 }
9587 }
9588 DataType::Float64 => {
9589 let arr = column
9590 .as_any()
9591 .downcast_ref::<Float64Array>()
9592 .ok_or_else(|| {
9593 Error::Internal("failed to downcast to Float64Array".into())
9594 })?;
9595 for i in 0..arr.len() {
9596 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
9597 }
9598 }
9599 _ => {
9600 return Err(Error::Internal(format!(
9601 "unsupported float type for IN list: {:?}",
9602 column.data_type()
9603 )));
9604 }
9605 }
9606 Ok(builder.finish())
9607 }
9608 PlanValue::String(val) => {
9609 let mut builder = BooleanBuilder::with_capacity(column.len());
9610 let arr = column
9611 .as_any()
9612 .downcast_ref::<StringArray>()
9613 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
9614 for i in 0..arr.len() {
9615 builder.append_value(!arr.is_null(i) && arr.value(i) == val.as_str());
9616 }
9617 Ok(builder.finish())
9618 }
9619 PlanValue::Struct(_) => Err(Error::Internal(
9620 "struct comparison in IN list not supported".into(),
9621 )),
9622 }
9623}
9624
9625fn array_value_equals_plan_value(
9626 array: &dyn Array,
9627 row_idx: usize,
9628 literal: &PlanValue,
9629) -> ExecutorResult<bool> {
9630 use arrow::array::*;
9631 use arrow::datatypes::DataType;
9632
9633 match literal {
9634 PlanValue::Null => Ok(array.is_null(row_idx)),
9635 PlanValue::Integer(expected) => match array.data_type() {
9636 DataType::Int8 => Ok(!array.is_null(row_idx)
9637 && array
9638 .as_any()
9639 .downcast_ref::<Int8Array>()
9640 .expect("int8 array")
9641 .value(row_idx) as i64
9642 == *expected),
9643 DataType::Int16 => Ok(!array.is_null(row_idx)
9644 && array
9645 .as_any()
9646 .downcast_ref::<Int16Array>()
9647 .expect("int16 array")
9648 .value(row_idx) as i64
9649 == *expected),
9650 DataType::Int32 => Ok(!array.is_null(row_idx)
9651 && array
9652 .as_any()
9653 .downcast_ref::<Int32Array>()
9654 .expect("int32 array")
9655 .value(row_idx) as i64
9656 == *expected),
9657 DataType::Int64 => Ok(!array.is_null(row_idx)
9658 && array
9659 .as_any()
9660 .downcast_ref::<Int64Array>()
9661 .expect("int64 array")
9662 .value(row_idx)
9663 == *expected),
9664 DataType::UInt8 if *expected >= 0 => Ok(!array.is_null(row_idx)
9665 && array
9666 .as_any()
9667 .downcast_ref::<UInt8Array>()
9668 .expect("uint8 array")
9669 .value(row_idx) as i64
9670 == *expected),
9671 DataType::UInt16 if *expected >= 0 => Ok(!array.is_null(row_idx)
9672 && array
9673 .as_any()
9674 .downcast_ref::<UInt16Array>()
9675 .expect("uint16 array")
9676 .value(row_idx) as i64
9677 == *expected),
9678 DataType::UInt32 if *expected >= 0 => Ok(!array.is_null(row_idx)
9679 && array
9680 .as_any()
9681 .downcast_ref::<UInt32Array>()
9682 .expect("uint32 array")
9683 .value(row_idx) as i64
9684 == *expected),
9685 DataType::UInt64 if *expected >= 0 => Ok(!array.is_null(row_idx)
9686 && array
9687 .as_any()
9688 .downcast_ref::<UInt64Array>()
9689 .expect("uint64 array")
9690 .value(row_idx)
9691 == *expected as u64),
9692 DataType::Boolean => {
9693 if array.is_null(row_idx) {
9694 Ok(false)
9695 } else if *expected == 0 || *expected == 1 {
9696 let value = array
9697 .as_any()
9698 .downcast_ref::<BooleanArray>()
9699 .expect("bool array")
9700 .value(row_idx);
9701 Ok(value == (*expected == 1))
9702 } else {
9703 Ok(false)
9704 }
9705 }
9706 _ => Err(Error::InvalidArgumentError(format!(
9707 "literal integer comparison not supported for {:?}",
9708 array.data_type()
9709 ))),
9710 },
9711 PlanValue::Float(expected) => match array.data_type() {
9712 DataType::Float32 => Ok(!array.is_null(row_idx)
9713 && (array
9714 .as_any()
9715 .downcast_ref::<Float32Array>()
9716 .expect("float32 array")
9717 .value(row_idx) as f64
9718 - *expected)
9719 .abs()
9720 .eq(&0.0)),
9721 DataType::Float64 => Ok(!array.is_null(row_idx)
9722 && (array
9723 .as_any()
9724 .downcast_ref::<Float64Array>()
9725 .expect("float64 array")
9726 .value(row_idx)
9727 - *expected)
9728 .abs()
9729 .eq(&0.0)),
9730 _ => Err(Error::InvalidArgumentError(format!(
9731 "literal float comparison not supported for {:?}",
9732 array.data_type()
9733 ))),
9734 },
9735 PlanValue::String(expected) => match array.data_type() {
9736 DataType::Utf8 => Ok(!array.is_null(row_idx)
9737 && array
9738 .as_any()
9739 .downcast_ref::<StringArray>()
9740 .expect("string array")
9741 .value(row_idx)
9742 == expected),
9743 DataType::LargeUtf8 => Ok(!array.is_null(row_idx)
9744 && array
9745 .as_any()
9746 .downcast_ref::<LargeStringArray>()
9747 .expect("large string array")
9748 .value(row_idx)
9749 == expected),
9750 _ => Err(Error::InvalidArgumentError(format!(
9751 "literal string comparison not supported for {:?}",
9752 array.data_type()
9753 ))),
9754 },
9755 PlanValue::Struct(_) => Err(Error::InvalidArgumentError(
9756 "struct literals are not supported in join filters".into(),
9757 )),
9758 }
9759}
9760
9761fn hash_join_table_batches(
9762 left: TableCrossProductData,
9763 right: TableCrossProductData,
9764 join_keys: &[(usize, usize)],
9765 join_type: llkv_join::JoinType,
9766) -> ExecutorResult<TableCrossProductData> {
9767 let TableCrossProductData {
9768 schema: left_schema,
9769 batches: left_batches,
9770 column_counts: left_counts,
9771 table_indices: left_tables,
9772 } = left;
9773
9774 let TableCrossProductData {
9775 schema: right_schema,
9776 batches: right_batches,
9777 column_counts: right_counts,
9778 table_indices: right_tables,
9779 } = right;
9780
9781 let combined_fields: Vec<Field> = left_schema
9782 .fields()
9783 .iter()
9784 .chain(right_schema.fields().iter())
9785 .map(|field| field.as_ref().clone())
9786 .collect();
9787
9788 let combined_schema = Arc::new(Schema::new(combined_fields));
9789
9790 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
9791 column_counts.extend(left_counts.iter());
9792 column_counts.extend(right_counts.iter());
9793
9794 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
9795 table_indices.extend(left_tables.iter().copied());
9796 table_indices.extend(right_tables.iter().copied());
9797
9798 if left_batches.is_empty() {
9800 return Ok(TableCrossProductData {
9801 schema: combined_schema,
9802 batches: Vec::new(),
9803 column_counts,
9804 table_indices,
9805 });
9806 }
9807
9808 if right_batches.is_empty() {
9809 if join_type == llkv_join::JoinType::Left {
9811 let total_left_rows: usize = left_batches.iter().map(|b| b.num_rows()).sum();
9812 let mut left_arrays = Vec::new();
9813 for field in left_schema.fields() {
9814 let column_idx = left_schema.index_of(field.name()).map_err(|e| {
9815 Error::Internal(format!("failed to find field {}: {}", field.name(), e))
9816 })?;
9817 let arrays: Vec<ArrayRef> = left_batches
9818 .iter()
9819 .map(|batch| batch.column(column_idx).clone())
9820 .collect();
9821 let concatenated =
9822 arrow::compute::concat(&arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
9823 .map_err(|e| {
9824 Error::Internal(format!("failed to concat left arrays: {}", e))
9825 })?;
9826 left_arrays.push(concatenated);
9827 }
9828
9829 for field in right_schema.fields() {
9831 let null_array = arrow::array::new_null_array(field.data_type(), total_left_rows);
9832 left_arrays.push(null_array);
9833 }
9834
9835 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), left_arrays)
9836 .map_err(|err| {
9837 Error::Internal(format!(
9838 "failed to create LEFT JOIN batch with NULL right: {err}"
9839 ))
9840 })?;
9841
9842 return Ok(TableCrossProductData {
9843 schema: combined_schema,
9844 batches: vec![joined_batch],
9845 column_counts,
9846 table_indices,
9847 });
9848 } else {
9849 return Ok(TableCrossProductData {
9851 schema: combined_schema,
9852 batches: Vec::new(),
9853 column_counts,
9854 table_indices,
9855 });
9856 }
9857 }
9858
9859 match join_type {
9860 llkv_join::JoinType::Inner => {
9861 let (left_matches, right_matches) =
9862 build_join_match_indices(&left_batches, &right_batches, join_keys)?;
9863
9864 if left_matches.is_empty() {
9865 return Ok(TableCrossProductData {
9866 schema: combined_schema,
9867 batches: Vec::new(),
9868 column_counts,
9869 table_indices,
9870 });
9871 }
9872
9873 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
9874 let right_arrays = gather_indices_from_batches(&right_batches, &right_matches)?;
9875
9876 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
9877 combined_columns.extend(left_arrays);
9878 combined_columns.extend(right_arrays);
9879
9880 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
9881 .map_err(|err| {
9882 Error::Internal(format!("failed to materialize INNER JOIN batch: {err}"))
9883 })?;
9884
9885 Ok(TableCrossProductData {
9886 schema: combined_schema,
9887 batches: vec![joined_batch],
9888 column_counts,
9889 table_indices,
9890 })
9891 }
9892 llkv_join::JoinType::Left => {
9893 let (left_matches, right_optional_matches) =
9894 build_left_join_match_indices(&left_batches, &right_batches, join_keys)?;
9895
9896 if left_matches.is_empty() {
9897 return Ok(TableCrossProductData {
9899 schema: combined_schema,
9900 batches: Vec::new(),
9901 column_counts,
9902 table_indices,
9903 });
9904 }
9905
9906 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
9907 let right_arrays = llkv_column_map::gather::gather_optional_indices_from_batches(
9909 &right_batches,
9910 &right_optional_matches,
9911 )?;
9912
9913 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
9914 combined_columns.extend(left_arrays);
9915 combined_columns.extend(right_arrays);
9916
9917 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
9918 .map_err(|err| {
9919 Error::Internal(format!("failed to materialize LEFT JOIN batch: {err}"))
9920 })?;
9921
9922 Ok(TableCrossProductData {
9923 schema: combined_schema,
9924 batches: vec![joined_batch],
9925 column_counts,
9926 table_indices,
9927 })
9928 }
9929 _ => Err(Error::Internal(format!(
9931 "join type {:?} not supported in hash_join_table_batches; use llkv-join",
9932 join_type
9933 ))),
9934 }
9935}
9936
/// `(batch_index, row_index)` positions, one entry per output row of a join.
type JoinMatchIndices = Vec<(usize, usize)>;
/// Encoded join-key bytes -> every build-side (right) `(batch_index, row_index)` with that key.
type JoinHashTable = FxHashMap<Vec<u8>, Vec<(usize, usize)>>;
/// Parallel left/right positions for an inner join: entry `i` of each vector is one output row.
type JoinMatchPairs = (JoinMatchIndices, JoinMatchIndices);
/// Right-side positions for a left join; `None` means the left row has no matching right row.
type OptionalJoinMatches = Vec<Option<(usize, usize)>>;
/// Left positions paired entry-for-entry with optional right positions (left outer join).
type LeftJoinMatchPairs = (JoinMatchIndices, OptionalJoinMatches);
9947
9948fn build_join_match_indices(
9978 left_batches: &[RecordBatch],
9979 right_batches: &[RecordBatch],
9980 join_keys: &[(usize, usize)],
9981) -> ExecutorResult<JoinMatchPairs> {
9982 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
9983
9984 let hash_table: JoinHashTable = llkv_column_map::parallel::with_thread_pool(|| {
9987 let local_tables: Vec<JoinHashTable> = right_batches
9988 .par_iter()
9989 .enumerate()
9990 .map(|(batch_idx, batch)| {
9991 let mut local_table: JoinHashTable = FxHashMap::default();
9992 let mut key_buffer: Vec<u8> = Vec::new();
9993
9994 for row_idx in 0..batch.num_rows() {
9995 key_buffer.clear();
9996 match build_join_key(batch, &right_key_indices, row_idx, &mut key_buffer) {
9997 Ok(true) => {
9998 local_table
9999 .entry(key_buffer.clone())
10000 .or_default()
10001 .push((batch_idx, row_idx));
10002 }
10003 Ok(false) => continue,
10004 Err(_) => continue, }
10006 }
10007
10008 local_table
10009 })
10010 .collect();
10011
10012 let mut merged_table: JoinHashTable = FxHashMap::default();
10014 for local_table in local_tables {
10015 for (key, mut positions) in local_table {
10016 merged_table.entry(key).or_default().append(&mut positions);
10017 }
10018 }
10019
10020 merged_table
10021 });
10022
10023 if hash_table.is_empty() {
10024 return Ok((Vec::new(), Vec::new()));
10025 }
10026
10027 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
10028
10029 let matches: Vec<JoinMatchPairs> = llkv_column_map::parallel::with_thread_pool(|| {
10032 left_batches
10033 .par_iter()
10034 .enumerate()
10035 .map(|(batch_idx, batch)| {
10036 let mut local_left_matches: JoinMatchIndices = Vec::new();
10037 let mut local_right_matches: JoinMatchIndices = Vec::new();
10038 let mut key_buffer: Vec<u8> = Vec::new();
10039
10040 for row_idx in 0..batch.num_rows() {
10041 key_buffer.clear();
10042 match build_join_key(batch, &left_key_indices, row_idx, &mut key_buffer) {
10043 Ok(true) => {
10044 if let Some(entries) = hash_table.get(&key_buffer) {
10045 for &(r_batch, r_row) in entries {
10046 local_left_matches.push((batch_idx, row_idx));
10047 local_right_matches.push((r_batch, r_row));
10048 }
10049 }
10050 }
10051 Ok(false) => continue,
10052 Err(_) => continue, }
10054 }
10055
10056 (local_left_matches, local_right_matches)
10057 })
10058 .collect()
10059 });
10060
10061 let mut left_matches: JoinMatchIndices = Vec::new();
10063 let mut right_matches: JoinMatchIndices = Vec::new();
10064 for (mut left, mut right) in matches {
10065 left_matches.append(&mut left);
10066 right_matches.append(&mut right);
10067 }
10068
10069 Ok((left_matches, right_matches))
10070}
10071
10072fn build_left_join_match_indices(
10083 left_batches: &[RecordBatch],
10084 right_batches: &[RecordBatch],
10085 join_keys: &[(usize, usize)],
10086) -> ExecutorResult<LeftJoinMatchPairs> {
10087 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
10088
10089 let hash_table: JoinHashTable = llkv_column_map::parallel::with_thread_pool(|| {
10091 let local_tables: Vec<JoinHashTable> = right_batches
10092 .par_iter()
10093 .enumerate()
10094 .map(|(batch_idx, batch)| {
10095 let mut local_table: JoinHashTable = FxHashMap::default();
10096 let mut key_buffer: Vec<u8> = Vec::new();
10097
10098 for row_idx in 0..batch.num_rows() {
10099 key_buffer.clear();
10100 match build_join_key(batch, &right_key_indices, row_idx, &mut key_buffer) {
10101 Ok(true) => {
10102 local_table
10103 .entry(key_buffer.clone())
10104 .or_default()
10105 .push((batch_idx, row_idx));
10106 }
10107 Ok(false) => continue,
10108 Err(_) => continue,
10109 }
10110 }
10111
10112 local_table
10113 })
10114 .collect();
10115
10116 let mut merged_table: JoinHashTable = FxHashMap::default();
10117 for local_table in local_tables {
10118 for (key, mut positions) in local_table {
10119 merged_table.entry(key).or_default().append(&mut positions);
10120 }
10121 }
10122
10123 merged_table
10124 });
10125
10126 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
10127
10128 let matches: Vec<LeftJoinMatchPairs> = llkv_column_map::parallel::with_thread_pool(|| {
10130 left_batches
10131 .par_iter()
10132 .enumerate()
10133 .map(|(batch_idx, batch)| {
10134 let mut local_left_matches: JoinMatchIndices = Vec::new();
10135 let mut local_right_optional: Vec<Option<(usize, usize)>> = Vec::new();
10136 let mut key_buffer: Vec<u8> = Vec::new();
10137
10138 for row_idx in 0..batch.num_rows() {
10139 key_buffer.clear();
10140 match build_join_key(batch, &left_key_indices, row_idx, &mut key_buffer) {
10141 Ok(true) => {
10142 if let Some(entries) = hash_table.get(&key_buffer) {
10143 for &(r_batch, r_row) in entries {
10145 local_left_matches.push((batch_idx, row_idx));
10146 local_right_optional.push(Some((r_batch, r_row)));
10147 }
10148 } else {
10149 local_left_matches.push((batch_idx, row_idx));
10151 local_right_optional.push(None);
10152 }
10153 }
10154 Ok(false) => {
10155 local_left_matches.push((batch_idx, row_idx));
10157 local_right_optional.push(None);
10158 }
10159 Err(_) => {
10160 local_left_matches.push((batch_idx, row_idx));
10162 local_right_optional.push(None);
10163 }
10164 }
10165 }
10166
10167 (local_left_matches, local_right_optional)
10168 })
10169 .collect()
10170 });
10171
10172 let mut left_matches: JoinMatchIndices = Vec::new();
10174 let mut right_optional: Vec<Option<(usize, usize)>> = Vec::new();
10175 for (mut left, mut right) in matches {
10176 left_matches.append(&mut left);
10177 right_optional.append(&mut right);
10178 }
10179
10180 Ok((left_matches, right_optional))
10181}
10182
10183fn build_join_key(
10184 batch: &RecordBatch,
10185 column_indices: &[usize],
10186 row_idx: usize,
10187 buffer: &mut Vec<u8>,
10188) -> ExecutorResult<bool> {
10189 buffer.clear();
10190
10191 for &col_idx in column_indices {
10192 let array = batch.column(col_idx);
10193 if array.is_null(row_idx) {
10194 return Ok(false);
10195 }
10196 append_array_value_to_key(array.as_ref(), row_idx, buffer)?;
10197 }
10198
10199 Ok(true)
10200}
10201
10202fn append_array_value_to_key(
10203 array: &dyn Array,
10204 row_idx: usize,
10205 buffer: &mut Vec<u8>,
10206) -> ExecutorResult<()> {
10207 use arrow::array::*;
10208 use arrow::datatypes::DataType;
10209
10210 match array.data_type() {
10211 DataType::Int8 => buffer.extend_from_slice(
10212 &array
10213 .as_any()
10214 .downcast_ref::<Int8Array>()
10215 .expect("int8 array")
10216 .value(row_idx)
10217 .to_le_bytes(),
10218 ),
10219 DataType::Int16 => buffer.extend_from_slice(
10220 &array
10221 .as_any()
10222 .downcast_ref::<Int16Array>()
10223 .expect("int16 array")
10224 .value(row_idx)
10225 .to_le_bytes(),
10226 ),
10227 DataType::Int32 => buffer.extend_from_slice(
10228 &array
10229 .as_any()
10230 .downcast_ref::<Int32Array>()
10231 .expect("int32 array")
10232 .value(row_idx)
10233 .to_le_bytes(),
10234 ),
10235 DataType::Int64 => buffer.extend_from_slice(
10236 &array
10237 .as_any()
10238 .downcast_ref::<Int64Array>()
10239 .expect("int64 array")
10240 .value(row_idx)
10241 .to_le_bytes(),
10242 ),
10243 DataType::UInt8 => buffer.extend_from_slice(
10244 &array
10245 .as_any()
10246 .downcast_ref::<UInt8Array>()
10247 .expect("uint8 array")
10248 .value(row_idx)
10249 .to_le_bytes(),
10250 ),
10251 DataType::UInt16 => buffer.extend_from_slice(
10252 &array
10253 .as_any()
10254 .downcast_ref::<UInt16Array>()
10255 .expect("uint16 array")
10256 .value(row_idx)
10257 .to_le_bytes(),
10258 ),
10259 DataType::UInt32 => buffer.extend_from_slice(
10260 &array
10261 .as_any()
10262 .downcast_ref::<UInt32Array>()
10263 .expect("uint32 array")
10264 .value(row_idx)
10265 .to_le_bytes(),
10266 ),
10267 DataType::UInt64 => buffer.extend_from_slice(
10268 &array
10269 .as_any()
10270 .downcast_ref::<UInt64Array>()
10271 .expect("uint64 array")
10272 .value(row_idx)
10273 .to_le_bytes(),
10274 ),
10275 DataType::Float32 => buffer.extend_from_slice(
10276 &array
10277 .as_any()
10278 .downcast_ref::<Float32Array>()
10279 .expect("float32 array")
10280 .value(row_idx)
10281 .to_le_bytes(),
10282 ),
10283 DataType::Float64 => buffer.extend_from_slice(
10284 &array
10285 .as_any()
10286 .downcast_ref::<Float64Array>()
10287 .expect("float64 array")
10288 .value(row_idx)
10289 .to_le_bytes(),
10290 ),
10291 DataType::Boolean => buffer.push(
10292 array
10293 .as_any()
10294 .downcast_ref::<BooleanArray>()
10295 .expect("bool array")
10296 .value(row_idx) as u8,
10297 ),
10298 DataType::Utf8 => {
10299 let value = array
10300 .as_any()
10301 .downcast_ref::<StringArray>()
10302 .expect("utf8 array")
10303 .value(row_idx);
10304 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
10305 buffer.extend_from_slice(value.as_bytes());
10306 }
10307 DataType::LargeUtf8 => {
10308 let value = array
10309 .as_any()
10310 .downcast_ref::<LargeStringArray>()
10311 .expect("large utf8 array")
10312 .value(row_idx);
10313 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
10314 buffer.extend_from_slice(value.as_bytes());
10315 }
10316 DataType::Binary => {
10317 let value = array
10318 .as_any()
10319 .downcast_ref::<BinaryArray>()
10320 .expect("binary array")
10321 .value(row_idx);
10322 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
10323 buffer.extend_from_slice(value);
10324 }
10325 other => {
10326 return Err(Error::InvalidArgumentError(format!(
10327 "hash join does not support join key type {:?}",
10328 other
10329 )));
10330 }
10331 }
10332
10333 Ok(())
10334}
10335
10336fn table_has_join_with_used(
10337 candidate: usize,
10338 used_tables: &FxHashSet<usize>,
10339 equalities: &[ColumnEquality],
10340) -> bool {
10341 equalities.iter().any(|equality| {
10342 (equality.left.table == candidate && used_tables.contains(&equality.right.table))
10343 || (equality.right.table == candidate && used_tables.contains(&equality.left.table))
10344 })
10345}
10346
10347fn gather_join_keys(
10348 left: &TableCrossProductData,
10349 right: &TableCrossProductData,
10350 used_tables: &FxHashSet<usize>,
10351 right_table_index: usize,
10352 equalities: &[ColumnEquality],
10353) -> ExecutorResult<Vec<(usize, usize)>> {
10354 let mut keys = Vec::new();
10355
10356 for equality in equalities {
10357 if equality.left.table == right_table_index && used_tables.contains(&equality.right.table) {
10358 let left_idx = resolve_column_index(left, &equality.right).ok_or_else(|| {
10359 Error::Internal("failed to resolve column offset for hash join".into())
10360 })?;
10361 let right_idx = resolve_column_index(right, &equality.left).ok_or_else(|| {
10362 Error::Internal("failed to resolve column offset for hash join".into())
10363 })?;
10364 keys.push((left_idx, right_idx));
10365 } else if equality.right.table == right_table_index
10366 && used_tables.contains(&equality.left.table)
10367 {
10368 let left_idx = resolve_column_index(left, &equality.left).ok_or_else(|| {
10369 Error::Internal("failed to resolve column offset for hash join".into())
10370 })?;
10371 let right_idx = resolve_column_index(right, &equality.right).ok_or_else(|| {
10372 Error::Internal("failed to resolve column offset for hash join".into())
10373 })?;
10374 keys.push((left_idx, right_idx));
10375 }
10376 }
10377
10378 Ok(keys)
10379}
10380
10381fn resolve_column_index(data: &TableCrossProductData, column: &ColumnRef) -> Option<usize> {
10382 let mut offset = 0;
10383 for (table_idx, count) in data.table_indices.iter().zip(data.column_counts.iter()) {
10384 if *table_idx == column.table {
10385 if column.column < *count {
10386 return Some(offset + column.column);
10387 } else {
10388 return None;
10389 }
10390 }
10391 offset += count;
10392 }
10393 None
10394}
10395
10396fn build_cross_product_column_lookup(
10397 schema: &Schema,
10398 tables: &[llkv_plan::TableRef],
10399 column_counts: &[usize],
10400 table_indices: &[usize],
10401) -> FxHashMap<String, usize> {
10402 debug_assert_eq!(tables.len(), column_counts.len());
10403 debug_assert_eq!(column_counts.len(), table_indices.len());
10404
10405 let mut column_occurrences: FxHashMap<String, usize> = FxHashMap::default();
10406 let mut table_column_counts: FxHashMap<String, usize> = FxHashMap::default();
10407 for field in schema.fields() {
10408 let column_name = extract_column_name(field.name());
10409 *column_occurrences.entry(column_name).or_insert(0) += 1;
10410 if let Some(pair) = table_column_suffix(field.name()) {
10411 *table_column_counts.entry(pair).or_insert(0) += 1;
10412 }
10413 }
10414
10415 let mut base_table_totals: FxHashMap<String, usize> = FxHashMap::default();
10416 let mut base_table_unaliased: FxHashMap<String, usize> = FxHashMap::default();
10417 for table_ref in tables {
10418 let key = base_table_key(table_ref);
10419 *base_table_totals.entry(key.clone()).or_insert(0) += 1;
10420 if table_ref.alias.is_none() {
10421 *base_table_unaliased.entry(key).or_insert(0) += 1;
10422 }
10423 }
10424
10425 let mut lookup = FxHashMap::default();
10426
10427 if table_indices.is_empty() || column_counts.is_empty() {
10428 for (idx, field) in schema.fields().iter().enumerate() {
10429 let field_name_lower = field.name().to_ascii_lowercase();
10430 lookup.entry(field_name_lower).or_insert(idx);
10431
10432 let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
10433 lookup.entry(trimmed_lower).or_insert(idx);
10434
10435 if let Some(pair) = table_column_suffix(field.name())
10436 && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
10437 {
10438 lookup.entry(pair).or_insert(idx);
10439 }
10440
10441 let column_name = extract_column_name(field.name());
10442 if column_occurrences.get(&column_name).copied().unwrap_or(0) == 1 {
10443 lookup.entry(column_name).or_insert(idx);
10444 }
10445 }
10446 return lookup;
10447 }
10448
10449 let mut offset = 0usize;
10450 for (&table_idx, &count) in table_indices.iter().zip(column_counts.iter()) {
10451 if table_idx >= tables.len() {
10452 continue;
10453 }
10454 let table_ref = &tables[table_idx];
10455 let alias_lower = table_ref
10456 .alias
10457 .as_ref()
10458 .map(|alias| alias.to_ascii_lowercase());
10459 let table_lower = table_ref.table.to_ascii_lowercase();
10460 let schema_lower = table_ref.schema.to_ascii_lowercase();
10461 let base_key = base_table_key(table_ref);
10462 let total_refs = base_table_totals.get(&base_key).copied().unwrap_or(0);
10463 let unaliased_refs = base_table_unaliased.get(&base_key).copied().unwrap_or(0);
10464
10465 let allow_base_mapping = if table_ref.alias.is_none() {
10466 unaliased_refs == 1
10467 } else {
10468 unaliased_refs == 0 && total_refs == 1
10469 };
10470
10471 let mut table_keys: Vec<String> = Vec::new();
10472
10473 if let Some(alias) = &alias_lower {
10474 table_keys.push(alias.clone());
10475 if !schema_lower.is_empty() {
10476 table_keys.push(format!("{}.{}", schema_lower, alias));
10477 }
10478 }
10479
10480 if allow_base_mapping {
10481 table_keys.push(table_lower.clone());
10482 if !schema_lower.is_empty() {
10483 table_keys.push(format!("{}.{}", schema_lower, table_lower));
10484 }
10485 }
10486
10487 for local_idx in 0..count {
10488 let field_index = offset + local_idx;
10489 let field = schema.field(field_index);
10490 let field_name_lower = field.name().to_ascii_lowercase();
10491 lookup.entry(field_name_lower).or_insert(field_index);
10492
10493 let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
10494 lookup.entry(trimmed_lower).or_insert(field_index);
10495
10496 let column_name = extract_column_name(field.name());
10497 for table_key in &table_keys {
10498 lookup
10499 .entry(format!("{}.{}", table_key, column_name))
10500 .or_insert(field_index);
10501 }
10502
10503 lookup.entry(column_name.clone()).or_insert(field_index);
10507
10508 if table_keys.is_empty()
10509 && let Some(pair) = table_column_suffix(field.name())
10510 && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
10511 {
10512 lookup.entry(pair).or_insert(field_index);
10513 }
10514 }
10515
10516 offset = offset.saturating_add(count);
10517 }
10518
10519 lookup
10520}
10521
10522fn base_table_key(table_ref: &llkv_plan::TableRef) -> String {
10523 let schema_lower = table_ref.schema.to_ascii_lowercase();
10524 let table_lower = table_ref.table.to_ascii_lowercase();
10525 if schema_lower.is_empty() {
10526 table_lower
10527 } else {
10528 format!("{}.{}", schema_lower, table_lower)
10529 }
10530}
10531
/// Returns the lowercased last dot-separated segment of `name`, i.e. the bare column name.
///
/// Leading dots are ignored. A name without dots is returned whole, lowercased; a name ending
/// in `.` yields an empty string.
fn extract_column_name(name: &str) -> String {
    let unqualified = name.trim_start_matches('.');
    let column = unqualified
        .rsplit_once('.')
        .map_or(unqualified, |(_, last)| last);
    column.to_ascii_lowercase()
}
10539
/// Returns the lowercased `table.column` suffix of a dotted field name.
///
/// Leading dots are ignored. Names with fewer than two segments yield `None`. For longer
/// names only the last two segments are kept (`"s.T.Col"` -> `"t.col"`).
fn table_column_suffix(name: &str) -> Option<String> {
    let (qualifier, column) = name.trim_start_matches('.').rsplit_once('.')?;
    let table = qualifier
        .rsplit_once('.')
        .map_or(qualifier, |(_, last)| last);
    Some(format!("{}.{}", table, column).to_ascii_lowercase())
}
10550
10551fn cross_join_table_batches(
10576 left: TableCrossProductData,
10577 right: TableCrossProductData,
10578) -> ExecutorResult<TableCrossProductData> {
10579 let TableCrossProductData {
10580 schema: left_schema,
10581 batches: left_batches,
10582 column_counts: mut left_counts,
10583 table_indices: mut left_tables,
10584 } = left;
10585 let TableCrossProductData {
10586 schema: right_schema,
10587 batches: right_batches,
10588 column_counts: right_counts,
10589 table_indices: right_tables,
10590 } = right;
10591
10592 let combined_fields: Vec<Field> = left_schema
10593 .fields()
10594 .iter()
10595 .chain(right_schema.fields().iter())
10596 .map(|field| field.as_ref().clone())
10597 .collect();
10598
10599 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
10600 column_counts.append(&mut left_counts);
10601 column_counts.extend(right_counts);
10602
10603 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
10604 table_indices.append(&mut left_tables);
10605 table_indices.extend(right_tables);
10606
10607 let combined_schema = Arc::new(Schema::new(combined_fields));
10608
10609 let left_has_rows = left_batches.iter().any(|batch| batch.num_rows() > 0);
10610 let right_has_rows = right_batches.iter().any(|batch| batch.num_rows() > 0);
10611
10612 if !left_has_rows || !right_has_rows {
10613 return Ok(TableCrossProductData {
10614 schema: combined_schema,
10615 batches: Vec::new(),
10616 column_counts,
10617 table_indices,
10618 });
10619 }
10620
10621 let output_batches: Vec<RecordBatch> = llkv_column_map::parallel::with_thread_pool(|| {
10624 left_batches
10625 .par_iter()
10626 .filter(|left_batch| left_batch.num_rows() > 0)
10627 .flat_map(|left_batch| {
10628 right_batches
10629 .par_iter()
10630 .filter(|right_batch| right_batch.num_rows() > 0)
10631 .filter_map(|right_batch| {
10632 cross_join_pair(left_batch, right_batch, &combined_schema).ok()
10633 })
10634 .collect::<Vec<_>>()
10635 })
10636 .collect()
10637 });
10638
10639 Ok(TableCrossProductData {
10640 schema: combined_schema,
10641 batches: output_batches,
10642 column_counts,
10643 table_indices,
10644 })
10645}
10646
10647fn cross_join_all(staged: Vec<TableCrossProductData>) -> ExecutorResult<TableCrossProductData> {
10648 let mut iter = staged.into_iter();
10649 let mut current = iter
10650 .next()
10651 .ok_or_else(|| Error::Internal("cross product preparation yielded no tables".into()))?;
10652 for next in iter {
10653 current = cross_join_table_batches(current, next)?;
10654 }
10655 Ok(current)
10656}
10657
/// Name-resolution metadata for one table participating in a join.
struct TableInfo<'a> {
    /// Position of the table in the caller's table list; becomes `ColumnRef::table`.
    index: usize,
    /// The table reference (schema, table name, optional alias).
    table_ref: &'a llkv_plan::TableRef,
    /// Lowercased column name -> column position within this table's schema.
    column_map: FxHashMap<String, usize>,
}
10663
/// A column identified by its table position and its position within that table.
#[derive(Clone, Copy)]
struct ColumnRef {
    /// Index of the owning table (matches `TableInfo::index`).
    table: usize,
    /// Column position within the table's own schema, not within a combined join schema.
    column: usize,
}

/// A column-to-column equality `left = right`, a candidate hash-join key.
#[derive(Clone, Copy)]
struct ColumnEquality {
    left: ColumnRef,
    right: ColumnRef,
}
10675
/// A column restricted to equal a single literal value.
#[derive(Clone)]
struct ColumnLiteral {
    column: ColumnRef,
    value: PlanValue,
}

/// A column restricted to equal one of several literal values (`IN`-list semantics).
#[derive(Clone)]
struct ColumnInList {
    column: ColumnRef,
    values: Vec<PlanValue>,
}

/// A single-column literal restriction extracted from a predicate.
#[derive(Clone)]
enum ColumnConstraint {
    /// `column = value`.
    Equality(ColumnLiteral),
    /// `column IN (values...)`; also produced from an `OR` of equalities on one column.
    InList(ColumnInList),
}
10693
/// Result of classifying a predicate's top-level `AND` conjuncts for join planning.
struct JoinConstraintPlan {
    /// Column-to-column equalities between resolved columns.
    equalities: Vec<ColumnEquality>,
    /// Column-to-literal equality and `IN`-list constraints.
    literals: Vec<ColumnConstraint>,
    /// Set when a conjunct is the literal `FALSE`, making the whole predicate unsatisfiable.
    unsatisfiable: bool,
    /// Number of top-level conjuncts in the predicate.
    total_conjuncts: usize,
    /// Conjuncts captured into `equalities`/`literals` or resolved as `TRUE`/`FALSE`. A value
    /// below `total_conjuncts` means some conjunct was not captured, or was never examined
    /// because a preceding `FALSE` stopped the scan.
    handled_conjuncts: usize,
}
10704
10705fn extract_literal_pushdown_filters<P>(
10724 expr: &LlkvExpr<'static, String>,
10725 tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
10726) -> Vec<Vec<ColumnConstraint>>
10727where
10728 P: Pager<Blob = EntryHandle> + Send + Sync,
10729{
10730 let mut table_infos = Vec::with_capacity(tables_with_handles.len());
10731 for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
10732 let mut column_map = FxHashMap::default();
10733 for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
10734 let column_name = column.name.to_ascii_lowercase();
10735 column_map.entry(column_name).or_insert(column_idx);
10736 }
10737 table_infos.push(TableInfo {
10738 index,
10739 table_ref,
10740 column_map,
10741 });
10742 }
10743
10744 let mut constraints: Vec<Vec<ColumnConstraint>> = vec![Vec::new(); tables_with_handles.len()];
10745
10746 let mut conjuncts = Vec::new();
10748 collect_conjuncts_lenient(expr, &mut conjuncts);
10749
10750 for conjunct in conjuncts {
10751 if let LlkvExpr::Compare {
10753 left,
10754 op: CompareOp::Eq,
10755 right,
10756 } = conjunct
10757 {
10758 match (
10759 resolve_column_reference(left, &table_infos),
10760 resolve_column_reference(right, &table_infos),
10761 ) {
10762 (Some(column), None) => {
10763 if let Some(literal) = extract_literal(right)
10764 && let Some(value) = literal_to_plan_value_for_join(literal)
10765 && column.table < constraints.len()
10766 {
10767 constraints[column.table]
10768 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
10769 }
10770 }
10771 (None, Some(column)) => {
10772 if let Some(literal) = extract_literal(left)
10773 && let Some(value) = literal_to_plan_value_for_join(literal)
10774 && column.table < constraints.len()
10775 {
10776 constraints[column.table]
10777 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
10778 }
10779 }
10780 _ => {}
10781 }
10782 }
10783 else if let LlkvExpr::Pred(filter) = conjunct {
10786 if let Operator::Equals(ref literal_val) = filter.op {
10787 let field_name = filter.field_id.trim().to_ascii_lowercase();
10789
10790 for info in &table_infos {
10792 if let Some(&col_idx) = info.column_map.get(&field_name) {
10793 if let Some(value) = plan_value_from_operator_literal(literal_val) {
10794 let column_ref = ColumnRef {
10795 table: info.index,
10796 column: col_idx,
10797 };
10798 if info.index < constraints.len() {
10799 constraints[info.index].push(ColumnConstraint::Equality(
10800 ColumnLiteral {
10801 column: column_ref,
10802 value,
10803 },
10804 ));
10805 }
10806 }
10807 break; }
10809 }
10810 }
10811 }
10812 else if let LlkvExpr::InList {
10814 expr: col_expr,
10815 list,
10816 negated: false,
10817 } = conjunct
10818 {
10819 if let Some(column) = resolve_column_reference(col_expr, &table_infos) {
10820 let mut values = Vec::new();
10821 for item in list {
10822 if let Some(literal) = extract_literal(item)
10823 && let Some(value) = literal_to_plan_value_for_join(literal)
10824 {
10825 values.push(value);
10826 }
10827 }
10828 if !values.is_empty() && column.table < constraints.len() {
10829 constraints[column.table]
10830 .push(ColumnConstraint::InList(ColumnInList { column, values }));
10831 }
10832 }
10833 }
10834 else if let LlkvExpr::Or(or_children) = conjunct
10836 && let Some((column, values)) = try_extract_or_as_in_list(or_children, &table_infos)
10837 && !values.is_empty()
10838 && column.table < constraints.len()
10839 {
10840 constraints[column.table]
10841 .push(ColumnConstraint::InList(ColumnInList { column, values }));
10842 }
10843 }
10844
10845 constraints
10846}
10847
10848fn collect_conjuncts_lenient<'a>(
10853 expr: &'a LlkvExpr<'static, String>,
10854 out: &mut Vec<&'a LlkvExpr<'static, String>>,
10855) {
10856 match expr {
10857 LlkvExpr::And(children) => {
10858 for child in children {
10859 collect_conjuncts_lenient(child, out);
10860 }
10861 }
10862 other => {
10863 out.push(other);
10865 }
10866 }
10867}
10868
10869fn try_extract_or_as_in_list(
10873 or_children: &[LlkvExpr<'static, String>],
10874 table_infos: &[TableInfo<'_>],
10875) -> Option<(ColumnRef, Vec<PlanValue>)> {
10876 if or_children.is_empty() {
10877 return None;
10878 }
10879
10880 let mut common_column: Option<ColumnRef> = None;
10881 let mut values = Vec::new();
10882
10883 for child in or_children {
10884 if let LlkvExpr::Compare {
10886 left,
10887 op: CompareOp::Eq,
10888 right,
10889 } = child
10890 {
10891 if let (Some(column), None) = (
10893 resolve_column_reference(left, table_infos),
10894 resolve_column_reference(right, table_infos),
10895 ) && let Some(literal) = extract_literal(right)
10896 && let Some(value) = literal_to_plan_value_for_join(literal)
10897 {
10898 match common_column {
10900 None => common_column = Some(column),
10901 Some(ref prev)
10902 if prev.table == column.table && prev.column == column.column =>
10903 {
10904 }
10906 _ => {
10907 return None;
10909 }
10910 }
10911 values.push(value);
10912 continue;
10913 }
10914
10915 if let (None, Some(column)) = (
10917 resolve_column_reference(left, table_infos),
10918 resolve_column_reference(right, table_infos),
10919 ) && let Some(literal) = extract_literal(left)
10920 && let Some(value) = literal_to_plan_value_for_join(literal)
10921 {
10922 match common_column {
10923 None => common_column = Some(column),
10924 Some(ref prev)
10925 if prev.table == column.table && prev.column == column.column => {}
10926 _ => return None,
10927 }
10928 values.push(value);
10929 continue;
10930 }
10931 }
10932 else if let LlkvExpr::Pred(filter) = child
10934 && let Operator::Equals(ref literal) = filter.op
10935 && let Some(column) =
10936 resolve_column_reference(&ScalarExpr::Column(filter.field_id.clone()), table_infos)
10937 && let Some(value) = literal_to_plan_value_for_join(literal)
10938 {
10939 match common_column {
10940 None => common_column = Some(column),
10941 Some(ref prev) if prev.table == column.table && prev.column == column.column => {}
10942 _ => return None,
10943 }
10944 values.push(value);
10945 continue;
10946 }
10947
10948 return None;
10950 }
10951
10952 common_column.map(|col| (col, values))
10953}
10954
10955fn extract_join_constraints(
10982 expr: &LlkvExpr<'static, String>,
10983 table_infos: &[TableInfo<'_>],
10984) -> Option<JoinConstraintPlan> {
10985 let mut conjuncts = Vec::new();
10986 collect_conjuncts_lenient(expr, &mut conjuncts);
10988
10989 let total_conjuncts = conjuncts.len();
10990 let mut equalities = Vec::new();
10991 let mut literals = Vec::new();
10992 let mut unsatisfiable = false;
10993 let mut handled_conjuncts = 0;
10994
10995 for conjunct in conjuncts {
10996 match conjunct {
10997 LlkvExpr::Literal(true) => {
10998 handled_conjuncts += 1;
10999 }
11000 LlkvExpr::Literal(false) => {
11001 unsatisfiable = true;
11002 handled_conjuncts += 1;
11003 break;
11004 }
11005 LlkvExpr::Compare {
11006 left,
11007 op: CompareOp::Eq,
11008 right,
11009 } => {
11010 match (
11011 resolve_column_reference(left, table_infos),
11012 resolve_column_reference(right, table_infos),
11013 ) {
11014 (Some(left_col), Some(right_col)) => {
11015 equalities.push(ColumnEquality {
11016 left: left_col,
11017 right: right_col,
11018 });
11019 handled_conjuncts += 1;
11020 continue;
11021 }
11022 (Some(column), None) => {
11023 if let Some(literal) = extract_literal(right)
11024 && let Some(value) = literal_to_plan_value_for_join(literal)
11025 {
11026 literals
11027 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
11028 handled_conjuncts += 1;
11029 continue;
11030 }
11031 }
11032 (None, Some(column)) => {
11033 if let Some(literal) = extract_literal(left)
11034 && let Some(value) = literal_to_plan_value_for_join(literal)
11035 {
11036 literals
11037 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
11038 handled_conjuncts += 1;
11039 continue;
11040 }
11041 }
11042 _ => {}
11043 }
11044 }
11046 LlkvExpr::InList {
11048 expr: col_expr,
11049 list,
11050 negated: false,
11051 } => {
11052 if let Some(column) = resolve_column_reference(col_expr, table_infos) {
11053 let mut in_list_values = Vec::new();
11055 for item in list {
11056 if let Some(literal) = extract_literal(item)
11057 && let Some(value) = literal_to_plan_value_for_join(literal)
11058 {
11059 in_list_values.push(value);
11060 }
11061 }
11062 if !in_list_values.is_empty() {
11063 literals.push(ColumnConstraint::InList(ColumnInList {
11064 column,
11065 values: in_list_values,
11066 }));
11067 handled_conjuncts += 1;
11068 continue;
11069 }
11070 }
11071 }
11073 LlkvExpr::Or(or_children) => {
11075 if let Some((column, values)) = try_extract_or_as_in_list(or_children, table_infos)
11076 {
11077 literals.push(ColumnConstraint::InList(ColumnInList { column, values }));
11079 handled_conjuncts += 1;
11080 continue;
11081 }
11082 }
11084 LlkvExpr::Pred(filter) => {
11086 if let Operator::Equals(ref literal) = filter.op
11088 && let Some(column) = resolve_column_reference(
11089 &ScalarExpr::Column(filter.field_id.clone()),
11090 table_infos,
11091 )
11092 && let Some(value) = literal_to_plan_value_for_join(literal)
11093 {
11094 literals.push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
11095 handled_conjuncts += 1;
11096 continue;
11097 }
11098 }
11100 _ => {
11101 }
11103 }
11104 }
11105
11106 Some(JoinConstraintPlan {
11107 equalities,
11108 literals,
11109 unsatisfiable,
11110 total_conjuncts,
11111 handled_conjuncts,
11112 })
11113}
11114
11115fn resolve_column_reference(
11116 expr: &ScalarExpr<String>,
11117 table_infos: &[TableInfo<'_>],
11118) -> Option<ColumnRef> {
11119 let name = match expr {
11120 ScalarExpr::Column(name) => name.trim(),
11121 _ => return None,
11122 };
11123
11124 let mut parts: Vec<&str> = name
11125 .trim_start_matches('.')
11126 .split('.')
11127 .filter(|segment| !segment.is_empty())
11128 .collect();
11129
11130 if parts.is_empty() {
11131 return None;
11132 }
11133
11134 let column_part = parts.pop()?.to_ascii_lowercase();
11135 if parts.is_empty() {
11136 for info in table_infos {
11140 if let Some(&col_idx) = info.column_map.get(&column_part) {
11141 return Some(ColumnRef {
11142 table: info.index,
11143 column: col_idx,
11144 });
11145 }
11146 }
11147 return None;
11148 }
11149
11150 let table_ident = parts.join(".").to_ascii_lowercase();
11151 for info in table_infos {
11152 if matches_table_ident(info.table_ref, &table_ident) {
11153 if let Some(&col_idx) = info.column_map.get(&column_part) {
11154 return Some(ColumnRef {
11155 table: info.index,
11156 column: col_idx,
11157 });
11158 } else {
11159 return None;
11160 }
11161 }
11162 }
11163 None
11164}
11165
11166fn matches_table_ident(table_ref: &llkv_plan::TableRef, ident: &str) -> bool {
11167 if ident.is_empty() {
11168 return false;
11169 }
11170 if let Some(alias) = &table_ref.alias
11171 && alias.to_ascii_lowercase() == ident
11172 {
11173 return true;
11174 }
11175 if table_ref.table.to_ascii_lowercase() == ident {
11176 return true;
11177 }
11178 if !table_ref.schema.is_empty() {
11179 let full = format!(
11180 "{}.{}",
11181 table_ref.schema.to_ascii_lowercase(),
11182 table_ref.table.to_ascii_lowercase()
11183 );
11184 if full == ident {
11185 return true;
11186 }
11187 }
11188 false
11189}
11190
11191fn extract_literal(expr: &ScalarExpr<String>) -> Option<&Literal> {
11192 match expr {
11193 ScalarExpr::Literal(lit) => Some(lit),
11194 _ => None,
11195 }
11196}
11197
11198fn plan_value_from_operator_literal(op_value: &llkv_expr::literal::Literal) -> Option<PlanValue> {
11199 match op_value {
11200 llkv_expr::literal::Literal::Integer(v) => i64::try_from(*v).ok().map(PlanValue::Integer),
11201 llkv_expr::literal::Literal::Float(v) => Some(PlanValue::Float(*v)),
11202 llkv_expr::literal::Literal::Boolean(v) => Some(PlanValue::Integer(if *v { 1 } else { 0 })),
11203 llkv_expr::literal::Literal::String(v) => Some(PlanValue::String(v.clone())),
11204 _ => None,
11205 }
11206}
11207
11208fn literal_to_plan_value_for_join(literal: &Literal) -> Option<PlanValue> {
11209 match literal {
11210 Literal::Integer(v) => i64::try_from(*v).ok().map(PlanValue::Integer),
11211 Literal::Float(v) => Some(PlanValue::Float(*v)),
11212 Literal::Boolean(v) => Some(PlanValue::Integer(if *v { 1 } else { 0 })),
11213 Literal::String(v) => Some(PlanValue::String(v.clone())),
11214 _ => None,
11215 }
11216}
11217
11218#[derive(Default)]
11219struct DistinctState {
11220 seen: FxHashSet<CanonicalRow>,
11221}
11222
11223impl DistinctState {
11224 fn insert(&mut self, row: CanonicalRow) -> bool {
11225 self.seen.insert(row)
11226 }
11227}
11228
11229fn distinct_filter_batch(
11230 batch: RecordBatch,
11231 state: &mut DistinctState,
11232) -> ExecutorResult<Option<RecordBatch>> {
11233 if batch.num_rows() == 0 {
11234 return Ok(None);
11235 }
11236
11237 let mut keep_flags = Vec::with_capacity(batch.num_rows());
11238 let mut keep_count = 0usize;
11239
11240 for row_idx in 0..batch.num_rows() {
11241 let row = CanonicalRow::from_batch(&batch, row_idx)?;
11242 if state.insert(row) {
11243 keep_flags.push(true);
11244 keep_count += 1;
11245 } else {
11246 keep_flags.push(false);
11247 }
11248 }
11249
11250 if keep_count == 0 {
11251 return Ok(None);
11252 }
11253
11254 if keep_count == batch.num_rows() {
11255 return Ok(Some(batch));
11256 }
11257
11258 let mut builder = BooleanBuilder::with_capacity(batch.num_rows());
11259 for flag in keep_flags {
11260 builder.append_value(flag);
11261 }
11262 let mask = Arc::new(builder.finish());
11263
11264 let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
11265 Error::InvalidArgumentError(format!("failed to apply DISTINCT filter: {err}"))
11266 })?;
11267
11268 Ok(Some(filtered))
11269}
11270
11271fn sort_record_batch_with_order(
11272 schema: &Arc<Schema>,
11273 batch: &RecordBatch,
11274 order_by: &[OrderByPlan],
11275) -> ExecutorResult<RecordBatch> {
11276 if order_by.is_empty() {
11277 return Ok(batch.clone());
11278 }
11279
11280 let mut sort_columns: Vec<SortColumn> = Vec::with_capacity(order_by.len());
11281
11282 for order in order_by {
11283 let column_index = match &order.target {
11284 OrderTarget::Column(name) => schema.index_of(name).map_err(|_| {
11285 Error::InvalidArgumentError(format!(
11286 "ORDER BY references unknown column '{}'",
11287 name
11288 ))
11289 })?,
11290 OrderTarget::Index(idx) => {
11291 if *idx >= batch.num_columns() {
11292 return Err(Error::InvalidArgumentError(format!(
11293 "ORDER BY position {} is out of bounds for {} columns",
11294 idx + 1,
11295 batch.num_columns()
11296 )));
11297 }
11298 *idx
11299 }
11300 OrderTarget::All => {
11301 return Err(Error::InvalidArgumentError(
11302 "ORDER BY ALL should be expanded before sorting".into(),
11303 ));
11304 }
11305 };
11306
11307 let source_array = batch.column(column_index);
11308
11309 let values: ArrayRef = match order.sort_type {
11310 OrderSortType::Native => Arc::clone(source_array),
11311 OrderSortType::CastTextToInteger => {
11312 let strings = source_array
11313 .as_any()
11314 .downcast_ref::<StringArray>()
11315 .ok_or_else(|| {
11316 Error::InvalidArgumentError(
11317 "ORDER BY CAST expects the underlying column to be TEXT".into(),
11318 )
11319 })?;
11320 let mut builder = Int64Builder::with_capacity(strings.len());
11321 for i in 0..strings.len() {
11322 if strings.is_null(i) {
11323 builder.append_null();
11324 } else {
11325 match strings.value(i).parse::<i64>() {
11326 Ok(value) => builder.append_value(value),
11327 Err(_) => builder.append_null(),
11328 }
11329 }
11330 }
11331 Arc::new(builder.finish()) as ArrayRef
11332 }
11333 };
11334
11335 let sort_options = SortOptions {
11336 descending: !order.ascending,
11337 nulls_first: order.nulls_first,
11338 };
11339
11340 sort_columns.push(SortColumn {
11341 values,
11342 options: Some(sort_options),
11343 });
11344 }
11345
11346 let indices = lexsort_to_indices(&sort_columns, None).map_err(|err| {
11347 Error::InvalidArgumentError(format!("failed to compute ORDER BY indices: {err}"))
11348 })?;
11349
11350 let perm = indices
11351 .as_any()
11352 .downcast_ref::<UInt32Array>()
11353 .ok_or_else(|| Error::Internal("ORDER BY sorting produced unexpected index type".into()))?;
11354
11355 let mut reordered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
11356 for col_idx in 0..batch.num_columns() {
11357 let reordered = take(batch.column(col_idx), perm, None).map_err(|err| {
11358 Error::InvalidArgumentError(format!(
11359 "failed to apply ORDER BY permutation to column {col_idx}: {err}"
11360 ))
11361 })?;
11362 reordered_columns.push(reordered);
11363 }
11364
11365 RecordBatch::try_new(Arc::clone(schema), reordered_columns)
11366 .map_err(|err| Error::Internal(format!("failed to build reordered ORDER BY batch: {err}")))
11367}
11368
// Unit tests for the executor helpers: cross-product expression evaluation,
// constant/aggregate expression folding, and multi-table cross joins.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use llkv_expr::expr::BinaryOp;
    use llkv_expr::literal::Literal;
    use llkv_storage::pager::MemPager;
    use std::sync::Arc;

    // Evaluates a literal and a column-plus-literal expression over a 3-row batch.
    // The schema uses fully qualified `main.tab2.*` field names while the expression
    // refers to `tab2.a`, so this also exercises partially qualified column lookup.
    #[test]
    fn cross_product_context_evaluates_expressions() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("main.tab2.a", DataType::Int64, false),
            Field::new("main.tab2.b", DataType::Int64, false),
        ]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
                Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef,
            ],
        )
        .expect("valid batch");

        let lookup = build_cross_product_column_lookup(schema.as_ref(), &[], &[], &[]);
        let mut ctx = CrossProductExpressionContext::new(schema.as_ref(), lookup)
            .expect("context builds from schema");

        // A literal is broadcast to one value per input row.
        let literal_expr: ScalarExpr<String> = ScalarExpr::literal(67);
        let literal = ctx
            .evaluate(&literal_expr, &batch)
            .expect("literal evaluation succeeds");
        let literal_array = literal
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 literal result");
        assert_eq!(literal_array.len(), 3);
        assert!(literal_array.iter().all(|value| value == Some(67)));

        // `tab2.a + 5` is evaluated row-wise: [1, 2, 3] + 5.
        let add_expr = ScalarExpr::binary(
            ScalarExpr::column("tab2.a".to_string()),
            BinaryOp::Add,
            ScalarExpr::literal(5),
        );
        let added = ctx
            .evaluate(&add_expr, &batch)
            .expect("column addition succeeds");
        let added_array = added
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 addition result");
        assert_eq!(added_array.values(), &[6, 7, 8]);
    }

    // CAST of an in-range integer literal to Int32 succeeds and yields the same value.
    #[test]
    fn aggregate_expr_allows_numeric_casts() {
        let expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(31)),
            data_type: DataType::Int32,
        };
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("cast should succeed for in-range integral values");

        assert_eq!(value, Some(31));
    }

    // CAST of -1 to UInt8 is out of range and must be reported as InvalidArgumentError.
    #[test]
    fn aggregate_expr_cast_rejects_out_of_range_values() {
        let expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(-1)),
            data_type: DataType::UInt8,
        };
        let aggregates = FxHashMap::default();

        let result = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates);

        assert!(matches!(result, Err(Error::InvalidArgumentError(_))));
    }

    // `0 - CAST(NULL AS BIGINT)` propagates NULL (evaluates to `None`).
    #[test]
    fn aggregate_expr_null_literal_remains_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(0),
            BinaryOp::Subtract,
            ScalarExpr::cast(ScalarExpr::literal(Literal::Null), DataType::Int64),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("expression should evaluate");

        assert_eq!(value, None);
    }

    // Integer division by zero evaluates to NULL rather than erroring.
    #[test]
    fn aggregate_expr_divide_by_zero_returns_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Divide,
            ScalarExpr::literal(0),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("division should evaluate");

        assert_eq!(value, None);
    }

    // Integer modulo by zero evaluates to NULL rather than erroring.
    #[test]
    fn aggregate_expr_modulo_by_zero_returns_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Modulo,
            ScalarExpr::literal(0),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("modulo should evaluate");

        assert_eq!(value, None);
    }

    // Constant folding of `NULL AND 1` yields NULL (three-valued logic).
    #[test]
    fn constant_and_with_null_yields_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(Literal::Null),
            BinaryOp::And,
            ScalarExpr::literal(1),
        );

        let value = evaluate_constant_scalar_with_aggregates(&expr)
            .expect("expression should fold as constant");

        assert!(matches!(value, Literal::Null));
    }

    // Chains `cross_join_table_batches` twice (2 rows x 3 rows, then x 1 row) and
    // checks both the schema width and the left-major row order of the product.
    #[test]
    fn cross_product_handles_more_than_two_tables() {
        let schema_a = Arc::new(Schema::new(vec![Field::new(
            "main.t1.a",
            DataType::Int64,
            false,
        )]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "main.t2.b",
            DataType::Int64,
            false,
        )]));
        let schema_c = Arc::new(Schema::new(vec![Field::new(
            "main.t3.c",
            DataType::Int64,
            false,
        )]));

        let batch_a = RecordBatch::try_new(
            Arc::clone(&schema_a),
            vec![Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef],
        )
        .expect("valid batch");
        let batch_b = RecordBatch::try_new(
            Arc::clone(&schema_b),
            vec![Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef],
        )
        .expect("valid batch");
        let batch_c = RecordBatch::try_new(
            Arc::clone(&schema_c),
            vec![Arc::new(Int64Array::from(vec![100])) as ArrayRef],
        )
        .expect("valid batch");

        let data_a = TableCrossProductData {
            schema: schema_a,
            batches: vec![batch_a],
            column_counts: vec![1],
            table_indices: vec![0],
        };
        let data_b = TableCrossProductData {
            schema: schema_b,
            batches: vec![batch_b],
            column_counts: vec![1],
            table_indices: vec![1],
        };
        let data_c = TableCrossProductData {
            schema: schema_c,
            batches: vec![batch_c],
            column_counts: vec![1],
            table_indices: vec![2],
        };

        // First join: 2 x 3 = 6 rows, 2 columns.
        let ab = cross_join_table_batches(data_a, data_b).expect("two-table product");
        assert_eq!(ab.schema.fields().len(), 2);
        assert_eq!(ab.batches.len(), 1);
        assert_eq!(ab.batches[0].num_rows(), 6);

        // Second join with a single-row table keeps 6 rows and adds a third column.
        let abc = cross_join_table_batches(ab, data_c).expect("three-table product");
        assert_eq!(abc.schema.fields().len(), 3);
        assert_eq!(abc.batches.len(), 1);

        let final_batch = &abc.batches[0];
        assert_eq!(final_batch.num_rows(), 6);

        // Left input varies slowest: each `a` value repeats once per `b` value.
        let col_a = final_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("left column values");
        assert_eq!(col_a.values(), &[1, 1, 1, 2, 2, 2]);

        let col_b = final_batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("middle column values");
        assert_eq!(col_b.values(), &[10, 20, 30, 10, 20, 30]);

        let col_c = final_batch
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("right column values");
        assert_eq!(col_c.values(), &[100, 100, 100, 100, 100, 100]);
    }
}