1use arrow::array::{
18 Array, ArrayRef, BooleanArray, BooleanBuilder, Float32Array, Float64Array, Int8Array,
19 Int16Array, Int32Array, Int64Array, Int64Builder, LargeStringArray, RecordBatch, StringArray,
20 StructArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, new_null_array,
21};
22use arrow::compute::{
23 SortColumn, SortOptions, cast, concat_batches, filter_record_batch, lexsort_to_indices, take,
24};
25use arrow::datatypes::{DataType, Field, Float64Type, Int64Type, Schema};
26use llkv_aggregate::{AggregateAccumulator, AggregateKind, AggregateSpec, AggregateState};
27use llkv_column_map::gather::gather_indices_from_batches;
28use llkv_column_map::store::Projection as StoreProjection;
29use llkv_column_map::types::LogicalFieldId;
30use llkv_expr::SubqueryId;
31use llkv_expr::expr::{
32 AggregateCall, BinaryOp, CompareOp, Expr as LlkvExpr, Filter, Operator, ScalarExpr,
33};
34use llkv_expr::literal::Literal;
35use llkv_expr::typed_predicate::{
36 build_bool_predicate, build_fixed_width_predicate, build_var_width_predicate,
37};
38use llkv_join::cross_join_pair;
39use llkv_plan::{
40 AggregateExpr, AggregateFunction, CanonicalRow, CompoundOperator, CompoundQuantifier,
41 CompoundSelectComponent, CompoundSelectPlan, OrderByPlan, OrderSortType, OrderTarget,
42 PlanValue, SelectPlan, SelectProjection,
43};
44use llkv_result::Error;
45use llkv_storage::pager::Pager;
46use llkv_table::table::{
47 RowIdFilter, ScanOrderDirection, ScanOrderSpec, ScanOrderTransform, ScanProjection,
48 ScanStreamOptions,
49};
50use llkv_table::types::FieldId;
51use llkv_table::{NumericArray, NumericArrayMap, NumericKernels, ROW_ID_FIELD_ID};
52use rayon::prelude::*;
53use rustc_hash::{FxHashMap, FxHashSet};
54use simd_r_drive_entry_handle::EntryHandle;
55use std::convert::TryFrom;
56use std::fmt;
57use std::sync::Arc;
58use std::sync::atomic::Ordering;
59
60#[cfg(test)]
61use std::cell::RefCell;
62
63pub mod insert;
68pub mod translation;
69pub mod types;
70pub mod utils;
71
/// Result alias used throughout the executor; the error type is always [`Error`].
pub type ExecutorResult<T> = Result<T, Error>;
78
79use crate::translation::schema::infer_computed_data_type;
80pub use insert::{
81 build_array_for_column, normalize_insert_value_for_column, resolve_insert_columns,
82};
83pub use translation::{
84 build_projected_columns, build_wildcard_projections, full_table_scan_filter,
85 resolve_field_id_from_schema, schema_for_projections, translate_predicate,
86 translate_predicate_with, translate_scalar, translate_scalar_with,
87};
88pub use types::{
89 ExecutorColumn, ExecutorMultiColumnUnique, ExecutorRowBatch, ExecutorSchema, ExecutorTable,
90 ExecutorTableProvider,
91};
92pub use utils::current_time_micros;
93
/// A single hashable key value.
///
/// Derives `Eq` and `Hash`, so it can be used inside hash-map/set keys. There is
/// deliberately no floating-point variant here (floats are not `Eq`/`Hash`).
/// NOTE(review): presumably one component of a GROUP BY key; the usage site is
/// outside this view.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum GroupKeyValue {
    /// Absent value.
    Null,
    /// 64-bit signed integer value.
    Int(i64),
    /// Boolean value.
    Bool(bool),
    /// Owned string value.
    String(String),
}
101
/// A dynamically typed scalar value.
///
/// Only `PartialEq` is derived (not `Eq`/`Hash`) because of the `f64` payload.
/// NOTE(review): presumably the result of evaluating an aggregate; confirm at the
/// usage sites, which are outside this view.
#[derive(Clone, Debug, PartialEq)]
enum AggregateValue {
    /// Absent value.
    Null,
    /// 64-bit signed integer value.
    Int64(i64),
    /// 64-bit floating-point value.
    Float64(f64),
    /// Owned string value.
    String(String),
}
111
112impl AggregateValue {
113 fn as_i64(&self) -> Option<i64> {
115 match self {
116 AggregateValue::Null => None,
117 AggregateValue::Int64(v) => Some(*v),
118 AggregateValue::Float64(v) => Some(*v as i64),
119 AggregateValue::String(s) => s.parse().ok(),
120 }
121 }
122
123 #[allow(dead_code)]
125 fn as_f64(&self) -> Option<f64> {
126 match self {
127 AggregateValue::Null => None,
128 AggregateValue::Int64(v) => Some(*v as f64),
129 AggregateValue::Float64(v) => Some(*v),
130 AggregateValue::String(s) => s.parse().ok(),
131 }
132 }
133}
134
/// Pairs a record batch with the index of one row inside it.
///
/// NOTE(review): presumably the representative row of a group; the usage site is
/// outside this view.
struct GroupState {
    /// Batch that holds the row.
    batch: RecordBatch,
    /// Index of the row within `batch`.
    row_idx: usize,
}
139
/// Bookkeeping for one group: a representative row plus a list of row locations.
///
/// NOTE(review): the exact meaning of the indices is not visible in this view;
/// confirm against the code that fills this struct.
struct GroupAggregateState {
    /// Index of the batch containing the representative row.
    representative_batch_idx: usize,
    /// Index of the representative row (presumably within that batch).
    representative_row: usize,
    /// Row locations for the group; presumably `(batch index, row index)` pairs.
    row_locations: Vec<(usize, usize)>,
}
146
/// Describes one output column: its Arrow field and where its data comes from.
struct OutputColumn {
    /// Arrow field (name, type, nullability) of the column.
    field: Field,
    /// Where the column's values are taken from.
    source: OutputSource,
}
151
/// Origin of an output column's data.
enum OutputSource {
    /// Values come from a table column, identified by `index`.
    TableColumn { index: usize },
    /// Values come from a computed projection, identified by `projection_index`.
    Computed { projection_index: usize },
}
156
// Per-thread stack of query labels. Only compiled in test builds; it is pushed by
// `push_query_label`, popped when a `QueryLogGuard` is dropped, and read by
// `current_query_label`.
#[cfg(test)]
thread_local! {
    static QUERY_LABEL_STACK: RefCell<Vec<String>> = const { RefCell::new(Vec::new()) };
}
165
/// Guard returned by [`push_query_label`].
///
/// In test builds, dropping the guard pops the top entry of the thread-local label
/// stack; in non-test builds, dropping it does nothing.
pub struct QueryLogGuard {
    // Private field: the guard cannot be constructed outside this module.
    _private: (),
}
170
/// Pushes `label` onto the current thread's query-label stack and returns a guard.
///
/// Test-build variant. The label stays on the stack until a [`QueryLogGuard`] is
/// dropped, which pops the most recently pushed entry.
#[cfg(test)]
pub fn push_query_label(label: impl Into<String>) -> QueryLogGuard {
    QUERY_LABEL_STACK.with(|stack| stack.borrow_mut().push(label.into()));
    QueryLogGuard { _private: () }
}
178
/// Non-test variant of `push_query_label`: ignores the label and returns an inert
/// guard, so call sites compile identically in both configurations.
#[cfg(not(test))]
#[inline]
pub fn push_query_label(_label: impl Into<String>) -> QueryLogGuard {
    QueryLogGuard { _private: () }
}
188
/// Test-build drop behavior: pops the top entry of the thread-local label stack.
#[cfg(test)]
impl Drop for QueryLogGuard {
    fn drop(&mut self) {
        QUERY_LABEL_STACK.with(|stack| {
            // The popped label is discarded; an already-empty stack is tolerated.
            let _ = stack.borrow_mut().pop();
        });
    }
}
197
/// Non-test drop behavior: intentionally a no-op, since no label stack exists
/// outside test builds.
#[cfg(not(test))]
impl Drop for QueryLogGuard {
    #[inline]
    fn drop(&mut self) {
        // Nothing to clean up in non-test builds.
    }
}
205
/// Returns a clone of the most recently pushed query label on this thread, or
/// `None` when the stack is empty. Test-build variant.
#[cfg(test)]
pub fn current_query_label() -> Option<String> {
    QUERY_LABEL_STACK.with(|stack| stack.borrow().last().cloned())
}
211
/// Non-test variant of `current_query_label`: labels are not tracked, so this
/// always returns `None`.
#[cfg(not(test))]
#[inline]
pub fn current_query_label() -> Option<String> {
    None
}
220
221fn try_extract_simple_column<F: AsRef<str>>(expr: &ScalarExpr<F>) -> Option<&str> {
236 match expr {
237 ScalarExpr::Column(name) => Some(name.as_ref()),
238 ScalarExpr::Binary { left, op, right } => {
240 match op {
242 BinaryOp::Add => {
243 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Integer(0))) {
245 return try_extract_simple_column(right);
246 }
247 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Integer(0))) {
248 return try_extract_simple_column(left);
249 }
250 }
251 BinaryOp::Multiply => {
254 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Integer(1))) {
256 return try_extract_simple_column(right);
257 }
258 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Integer(1))) {
259 return try_extract_simple_column(left);
260 }
261 }
262 _ => {}
263 }
264 None
265 }
266 _ => None,
267 }
268}
269
270fn plan_values_to_arrow_array(values: &[PlanValue]) -> ExecutorResult<ArrayRef> {
275 use arrow::array::{Float64Array, Int64Array, StringArray};
276
277 let mut value_type = None;
279 for v in values {
280 if !matches!(v, PlanValue::Null) {
281 value_type = Some(v);
282 break;
283 }
284 }
285
286 match value_type {
287 Some(PlanValue::Integer(_)) => {
288 let int_values: Vec<Option<i64>> = values
289 .iter()
290 .map(|v| match v {
291 PlanValue::Integer(i) => Some(*i),
292 PlanValue::Null => None,
293 _ => Some(0), })
295 .collect();
296 Ok(Arc::new(Int64Array::from(int_values)) as ArrayRef)
297 }
298 Some(PlanValue::Float(_)) => {
299 let float_values: Vec<Option<f64>> = values
300 .iter()
301 .map(|v| match v {
302 PlanValue::Float(f) => Some(*f),
303 PlanValue::Integer(i) => Some(*i as f64),
304 PlanValue::Null => None,
305 _ => Some(0.0), })
307 .collect();
308 Ok(Arc::new(Float64Array::from(float_values)) as ArrayRef)
309 }
310 Some(PlanValue::String(_)) => {
311 let string_values: Vec<Option<&str>> = values
312 .iter()
313 .map(|v| match v {
314 PlanValue::String(s) => Some(s.as_str()),
315 PlanValue::Null => None,
316 _ => Some(""), })
318 .collect();
319 Ok(Arc::new(StringArray::from(string_values)) as ArrayRef)
320 }
321 _ => {
322 let null_values: Vec<Option<i64>> = vec![None; values.len()];
324 Ok(Arc::new(Int64Array::from(null_values)) as ArrayRef)
325 }
326 }
327}
328
329fn resolve_column_name_to_index(
339 col_name: &str,
340 column_lookup_map: &FxHashMap<String, usize>,
341) -> Option<usize> {
342 let col_lower = col_name.to_ascii_lowercase();
343
344 if let Some(&idx) = column_lookup_map.get(&col_lower) {
346 return Some(idx);
347 }
348
349 let unqualified = col_name
352 .rsplit('.')
353 .next()
354 .unwrap_or(col_name)
355 .to_ascii_lowercase();
356 column_lookup_map
357 .iter()
358 .find(|(k, _)| k.ends_with(&format!(".{}", unqualified)) || k == &&unqualified)
359 .map(|(_, &idx)| idx)
360}
361
362fn get_or_insert_column_projection<P>(
364 projections: &mut Vec<ScanProjection>,
365 cache: &mut FxHashMap<FieldId, usize>,
366 table: &ExecutorTable<P>,
367 column: &ExecutorColumn,
368) -> usize
369where
370 P: Pager<Blob = EntryHandle> + Send + Sync,
371{
372 if let Some(existing) = cache.get(&column.field_id) {
373 return *existing;
374 }
375
376 let projection_index = projections.len();
377 let alias = if column.name.is_empty() {
378 format!("col{}", column.field_id)
379 } else {
380 column.name.clone()
381 };
382 projections.push(ScanProjection::from(StoreProjection::with_alias(
383 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
384 alias,
385 )));
386 cache.insert(column.field_id, projection_index);
387 projection_index
388}
389
390fn ensure_computed_projection<P>(
392 expr: &ScalarExpr<String>,
393 table: &ExecutorTable<P>,
394 projections: &mut Vec<ScanProjection>,
395 cache: &mut FxHashMap<String, (usize, DataType)>,
396 alias_counter: &mut usize,
397) -> ExecutorResult<(usize, DataType)>
398where
399 P: Pager<Blob = EntryHandle> + Send + Sync,
400{
401 let key = format!("{:?}", expr);
402 if let Some((idx, dtype)) = cache.get(&key) {
403 return Ok((*idx, dtype.clone()));
404 }
405
406 let translated = translate_scalar(expr, table.schema.as_ref(), |name| {
407 Error::InvalidArgumentError(format!("unknown column '{}' in aggregate expression", name))
408 })?;
409 let data_type = infer_computed_data_type(table.schema.as_ref(), &translated)?;
410 if data_type == DataType::Null {
411 tracing::debug!(
412 "ensure_computed_projection inferred Null type for expr: {:?}",
413 expr
414 );
415 }
416 let alias = format!("__agg_expr_{}", *alias_counter);
417 *alias_counter += 1;
418 let projection_index = projections.len();
419 projections.push(ScanProjection::computed(translated, alias));
420 cache.insert(key, (projection_index, data_type.clone()));
421 Ok((projection_index, data_type))
422}
423
/// Executes SELECT plans against tables obtained from an [`ExecutorTableProvider`].
///
/// `P` is the pager type backing the tables.
pub struct QueryExecutor<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Source of table handles, looked up by qualified table name.
    provider: Arc<dyn ExecutorTableProvider<P>>,
}
431
432impl<P> QueryExecutor<P>
433where
434 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
435{
/// Creates an executor that resolves tables through `provider`.
pub fn new(provider: Arc<dyn ExecutorTableProvider<P>>) -> Self {
    Self { provider }
}
439
/// Executes `plan` without a row-id filter.
///
/// Equivalent to calling `execute_select_with_filter(plan, None)`.
pub fn execute_select(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
    self.execute_select_with_filter(plan, None)
}
443
444 pub fn execute_select_with_filter(
445 &self,
446 plan: SelectPlan,
447 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
448 ) -> ExecutorResult<SelectExecution<P>> {
449 if plan.compound.is_some() {
450 return self.execute_compound_select(plan, row_filter);
451 }
452
453 if plan.tables.is_empty() {
455 return self.execute_select_without_table(plan);
456 }
457
458 if !plan.group_by.is_empty() {
459 if plan.tables.len() > 1 {
460 return self.execute_cross_product(plan);
461 }
462 let table_ref = &plan.tables[0];
463 let table = self.provider.get_table(&table_ref.qualified_name())?;
464 let display_name = table_ref.qualified_name();
465 return self.execute_group_by_single_table(table, display_name, plan, row_filter);
466 }
467
468 if plan.tables.len() > 1 {
470 return self.execute_cross_product(plan);
471 }
472
473 let table_ref = &plan.tables[0];
475 let table = self.provider.get_table(&table_ref.qualified_name())?;
476 let display_name = table_ref.qualified_name();
477
478 if !plan.aggregates.is_empty() {
479 self.execute_aggregates(table, display_name, plan, row_filter)
480 } else if self.has_computed_aggregates(&plan) {
481 self.execute_computed_aggregates(table, display_name, plan, row_filter)
483 } else {
484 self.execute_projection(table, display_name, plan, row_filter)
485 }
486 }
487
/// Executes a compound SELECT (UNION / EXCEPT / INTERSECT chain).
///
/// The initial SELECT and every following component are executed through
/// `execute_select_with_filter` with the same `row_filter`, converted to rows via
/// `into_rows`, and folded left-to-right into one row set. Each component's schema
/// is checked against the initial schema with `ensure_schema_compatibility`.
/// The result is returned as a single record batch that uses the initial SELECT's
/// schema, sorted by the plan's ORDER BY when one is present and the batch is
/// non-empty.
///
/// Supported operators: `UNION ALL`, `UNION DISTINCT`, `EXCEPT DISTINCT`,
/// `INTERSECT DISTINCT`.
///
/// # Errors
/// Returns `InvalidArgumentError` for `EXCEPT ALL` and `INTERSECT ALL`, and
/// propagates errors from component execution, schema checks, batch building and
/// sorting.
///
/// # Panics
/// Panics if `plan.compound` is `None`; the visible caller checks
/// `plan.compound.is_some()` before calling.
fn execute_compound_select(
    &self,
    plan: SelectPlan,
    row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
) -> ExecutorResult<SelectExecution<P>> {
    // Cloned up front because `plan.compound` is moved out of `plan` next.
    let order_by = plan.order_by.clone();
    let compound = plan.compound.expect("compound plan should be present");

    let CompoundSelectPlan {
        initial,
        operations,
    } = compound;

    let initial_exec = self.execute_select_with_filter(*initial, row_filter.clone())?;
    let schema = initial_exec.schema();
    let mut rows = initial_exec.into_rows()?;
    // Encoded keys of the rows currently in `rows`. `None` means "not computed /
    // stale"; `ensure_distinct_rows` (defined elsewhere) is expected to dedupe
    // `rows` and (re)populate it.
    let mut distinct_cache: Option<FxHashSet<Vec<u8>>> = None;

    for component in operations {
        let exec = self.execute_select_with_filter(component.plan, row_filter.clone())?;
        let other_schema = exec.schema();
        ensure_schema_compatibility(schema.as_ref(), other_schema.as_ref())?;
        let other_rows = exec.into_rows()?;

        match (component.operator, component.quantifier) {
            (CompoundOperator::Union, CompoundQuantifier::All) => {
                // Duplicates are kept, so any previously built key set is stale.
                rows.extend(other_rows);
                distinct_cache = None;
            }
            (CompoundOperator::Union, CompoundQuantifier::Distinct) => {
                ensure_distinct_rows(&mut rows, &mut distinct_cache);
                let cache = distinct_cache
                    .as_mut()
                    .expect("distinct cache should be initialized");
                // Append only rows whose encoded key has not been seen yet.
                for row in other_rows {
                    let key = encode_row(&row);
                    if cache.insert(key) {
                        rows.push(row);
                    }
                }
            }
            (CompoundOperator::Except, CompoundQuantifier::Distinct) => {
                ensure_distinct_rows(&mut rows, &mut distinct_cache);
                let cache = distinct_cache
                    .as_mut()
                    .expect("distinct cache should be initialized");
                // Nothing on the left to remove from.
                if rows.is_empty() {
                    continue;
                }
                let mut remove_keys = FxHashSet::default();
                for row in other_rows {
                    remove_keys.insert(encode_row(&row));
                }
                // Nothing on the right to remove.
                if remove_keys.is_empty() {
                    continue;
                }
                // Drop every left row present on the right, keeping the key set
                // in sync with `rows`.
                rows.retain(|row| {
                    let key = encode_row(row);
                    if remove_keys.contains(&key) {
                        cache.remove(&key);
                        false
                    } else {
                        true
                    }
                });
            }
            (CompoundOperator::Except, CompoundQuantifier::All) => {
                return Err(Error::InvalidArgumentError(
                    "EXCEPT ALL is not supported yet".into(),
                ));
            }
            (CompoundOperator::Intersect, CompoundQuantifier::Distinct) => {
                ensure_distinct_rows(&mut rows, &mut distinct_cache);
                let mut right_keys = FxHashSet::default();
                for row in other_rows {
                    right_keys.insert(encode_row(&row));
                }
                // Intersecting with an empty set yields an empty (distinct) result.
                if right_keys.is_empty() {
                    rows.clear();
                    distinct_cache = Some(FxHashSet::default());
                    continue;
                }
                // Keep left rows that also occur on the right, at most once each,
                // rebuilding the key set for the surviving rows.
                let mut new_rows = Vec::new();
                let mut new_cache = FxHashSet::default();
                for row in rows.drain(..) {
                    let key = encode_row(&row);
                    if right_keys.contains(&key) && new_cache.insert(key) {
                        new_rows.push(row);
                    }
                }
                rows = new_rows;
                distinct_cache = Some(new_cache);
            }
            (CompoundOperator::Intersect, CompoundQuantifier::All) => {
                return Err(Error::InvalidArgumentError(
                    "INTERSECT ALL is not supported yet".into(),
                ));
            }
        }
    }

    let mut batch = rows_to_record_batch(schema.clone(), &rows)?;
    // ORDER BY applies to the combined result; sorting is skipped for empty output.
    if !order_by.is_empty() && batch.num_rows() > 0 {
        batch = sort_record_batch_with_order(&schema, &batch, &order_by)?;
    }

    // The first argument is an empty string (presumably the table/display name,
    // which a compound result does not have).
    Ok(SelectExecution::new_single_batch(
        String::new(),
        schema,
        batch,
    ))
}
619
620 fn has_computed_aggregates(&self, plan: &SelectPlan) -> bool {
622 plan.projections.iter().any(|proj| {
623 if let SelectProjection::Computed { expr, .. } = proj {
624 Self::expr_contains_aggregate(expr)
625 } else {
626 false
627 }
628 })
629 }
630
631 fn predicate_contains_aggregate(expr: &llkv_expr::expr::Expr<String>) -> bool {
633 match expr {
634 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
635 exprs.iter().any(Self::predicate_contains_aggregate)
636 }
637 llkv_expr::expr::Expr::Not(inner) => Self::predicate_contains_aggregate(inner),
638 llkv_expr::expr::Expr::Compare { left, right, .. } => {
639 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
640 }
641 llkv_expr::expr::Expr::InList { expr, list, .. } => {
642 Self::expr_contains_aggregate(expr)
643 || list.iter().any(|e| Self::expr_contains_aggregate(e))
644 }
645 llkv_expr::expr::Expr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
646 llkv_expr::expr::Expr::Literal(_) => false,
647 llkv_expr::expr::Expr::Pred(_) => false,
648 llkv_expr::expr::Expr::Exists(_) => false,
649 }
650 }
651
652 fn expr_contains_aggregate(expr: &ScalarExpr<String>) -> bool {
654 match expr {
655 ScalarExpr::Aggregate(_) => true,
656 ScalarExpr::Binary { left, right, .. } => {
657 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
658 }
659 ScalarExpr::Compare { left, right, .. } => {
660 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
661 }
662 ScalarExpr::GetField { base, .. } => Self::expr_contains_aggregate(base),
663 ScalarExpr::Cast { expr, .. } => Self::expr_contains_aggregate(expr),
664 ScalarExpr::Not(expr) => Self::expr_contains_aggregate(expr),
665 ScalarExpr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
666 ScalarExpr::Case {
667 operand,
668 branches,
669 else_expr,
670 } => {
671 operand
672 .as_deref()
673 .map(Self::expr_contains_aggregate)
674 .unwrap_or(false)
675 || branches.iter().any(|(when_expr, then_expr)| {
676 Self::expr_contains_aggregate(when_expr)
677 || Self::expr_contains_aggregate(then_expr)
678 })
679 || else_expr
680 .as_deref()
681 .map(Self::expr_contains_aggregate)
682 .unwrap_or(false)
683 }
684 ScalarExpr::Coalesce(items) => items.iter().any(Self::expr_contains_aggregate),
685 ScalarExpr::Column(_) | ScalarExpr::Literal(_) => false,
686 ScalarExpr::ScalarSubquery(_) => false,
687 }
688 }
689
690 fn evaluate_exists_subquery(
691 &self,
692 context: &mut CrossProductExpressionContext,
693 subquery: &llkv_plan::FilterSubquery,
694 batch: &RecordBatch,
695 row_idx: usize,
696 ) -> ExecutorResult<bool> {
697 let bindings =
698 collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
699 let bound_plan = bind_select_plan(&subquery.plan, &bindings)?;
700 let execution = self.execute_select(bound_plan)?;
701 let mut found = false;
702 execution.stream(|inner_batch| {
703 if inner_batch.num_rows() > 0 {
704 found = true;
705 }
706 Ok(())
707 })?;
708 Ok(found)
709 }
710
711 fn evaluate_scalar_subquery_literal(
712 &self,
713 context: &mut CrossProductExpressionContext,
714 subquery: &llkv_plan::ScalarSubquery,
715 batch: &RecordBatch,
716 row_idx: usize,
717 ) -> ExecutorResult<Literal> {
718 let bindings =
719 collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
720 let bound_plan = bind_select_plan(&subquery.plan, &bindings)?;
721 let execution = self.execute_select(bound_plan)?;
722 let mut rows_seen: usize = 0;
723 let mut result: Option<Literal> = None;
724 execution.stream(|inner_batch| {
725 if inner_batch.num_columns() != 1 {
726 return Err(Error::InvalidArgumentError(
727 "scalar subquery must return exactly one column".into(),
728 ));
729 }
730 let column = inner_batch.column(0).clone();
731 for idx in 0..inner_batch.num_rows() {
732 if rows_seen >= 1 {
733 return Err(Error::InvalidArgumentError(
734 "scalar subquery produced more than one row".into(),
735 ));
736 }
737 rows_seen = rows_seen.saturating_add(1);
738 result = Some(array_value_to_literal(&column, idx)?);
739 }
740 Ok(())
741 })?;
742
743 if rows_seen == 0 {
744 Ok(Literal::Null)
745 } else {
746 result
747 .ok_or_else(|| Error::Internal("scalar subquery evaluation missing result".into()))
748 }
749 }
750
751 fn evaluate_scalar_subquery_numeric(
752 &self,
753 context: &mut CrossProductExpressionContext,
754 subquery: &llkv_plan::ScalarSubquery,
755 batch: &RecordBatch,
756 ) -> ExecutorResult<NumericArray> {
757 let mut values: Vec<Option<f64>> = Vec::with_capacity(batch.num_rows());
758 let mut all_integer = true;
759
760 for row_idx in 0..batch.num_rows() {
761 let literal =
762 self.evaluate_scalar_subquery_literal(context, subquery, batch, row_idx)?;
763 match literal {
764 Literal::Null => values.push(None),
765 Literal::Integer(value) => {
766 let cast = i64::try_from(value).map_err(|_| {
767 Error::InvalidArgumentError(
768 "scalar subquery integer result exceeds supported range".into(),
769 )
770 })?;
771 values.push(Some(cast as f64));
772 }
773 Literal::Float(value) => {
774 all_integer = false;
775 values.push(Some(value));
776 }
777 Literal::Boolean(flag) => {
778 let numeric = if flag { 1.0 } else { 0.0 };
779 values.push(Some(numeric));
780 }
781 Literal::String(_) | Literal::Struct(_) => {
782 return Err(Error::InvalidArgumentError(
783 "scalar subquery produced non-numeric result in numeric context".into(),
784 ));
785 }
786 }
787 }
788
789 if all_integer {
790 let iter = values.into_iter().map(|opt| opt.map(|v| v as i64));
791 let array = Int64Array::from_iter(iter);
792 NumericArray::try_from_arrow(&(Arc::new(array) as ArrayRef))
793 } else {
794 let array = Float64Array::from_iter(values);
795 NumericArray::try_from_arrow(&(Arc::new(array) as ArrayRef))
796 }
797 }
798
799 fn evaluate_projection_expression(
800 &self,
801 context: &mut CrossProductExpressionContext,
802 expr: &ScalarExpr<String>,
803 batch: &RecordBatch,
804 scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
805 ) -> ExecutorResult<ArrayRef> {
806 let translated = translate_scalar(expr, context.schema(), |name| {
807 Error::InvalidArgumentError(format!(
808 "column '{}' not found in cross product result",
809 name
810 ))
811 })?;
812
813 let mut subquery_ids: FxHashSet<SubqueryId> = FxHashSet::default();
814 collect_scalar_subquery_ids(&translated, &mut subquery_ids);
815
816 let mut mapping: FxHashMap<SubqueryId, FieldId> = FxHashMap::default();
817 for subquery_id in subquery_ids {
818 let info = scalar_lookup
819 .get(&subquery_id)
820 .ok_or_else(|| Error::Internal("missing scalar subquery metadata".into()))?;
821 let field_id = context.allocate_synthetic_field_id()?;
822 let numeric = self.evaluate_scalar_subquery_numeric(context, info, batch)?;
823 context.numeric_cache.insert(field_id, numeric);
824 mapping.insert(subquery_id, field_id);
825 }
826
827 let rewritten = rewrite_scalar_expr_for_subqueries(&translated, &mapping);
828 context.evaluate_numeric(&rewritten, batch)
829 }
830
831 fn execute_select_without_table(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
833 use arrow::array::ArrayRef;
834 use arrow::datatypes::Field;
835
836 let mut fields = Vec::new();
838 let mut arrays: Vec<ArrayRef> = Vec::new();
839
840 for proj in &plan.projections {
841 match proj {
842 SelectProjection::Computed { expr, alias } => {
843 let literal =
844 evaluate_constant_scalar_with_aggregates(expr).ok_or_else(|| {
845 Error::InvalidArgumentError(
846 "SELECT without FROM only supports constant expressions".into(),
847 )
848 })?;
849 let (dtype, array) = Self::literal_to_array(&literal)?;
850
851 fields.push(Field::new(alias.clone(), dtype, true));
852 arrays.push(array);
853 }
854 _ => {
855 return Err(Error::InvalidArgumentError(
856 "SELECT without FROM only supports computed projections".into(),
857 ));
858 }
859 }
860 }
861
862 let schema = Arc::new(Schema::new(fields));
863 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
864 .map_err(|e| Error::Internal(format!("failed to create record batch: {}", e)))?;
865
866 if plan.distinct {
867 let mut state = DistinctState::default();
868 batch = match distinct_filter_batch(batch, &mut state)? {
869 Some(filtered) => filtered,
870 None => RecordBatch::new_empty(Arc::clone(&schema)),
871 };
872 }
873
874 let schema = batch.schema();
875
876 Ok(SelectExecution::new_single_batch(
877 String::new(), schema,
879 batch,
880 ))
881 }
882
883 fn literal_to_array(lit: &llkv_expr::literal::Literal) -> ExecutorResult<(DataType, ArrayRef)> {
885 use arrow::array::{
886 ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, StructArray,
887 new_null_array,
888 };
889 use arrow::datatypes::{DataType, Field};
890 use llkv_expr::literal::Literal;
891
892 match lit {
893 Literal::Integer(v) => {
894 let val = i64::try_from(*v).unwrap_or(0);
895 Ok((
896 DataType::Int64,
897 Arc::new(Int64Array::from(vec![val])) as ArrayRef,
898 ))
899 }
900 Literal::Float(v) => Ok((
901 DataType::Float64,
902 Arc::new(Float64Array::from(vec![*v])) as ArrayRef,
903 )),
904 Literal::Boolean(v) => Ok((
905 DataType::Boolean,
906 Arc::new(BooleanArray::from(vec![*v])) as ArrayRef,
907 )),
908 Literal::String(v) => Ok((
909 DataType::Utf8,
910 Arc::new(StringArray::from(vec![v.clone()])) as ArrayRef,
911 )),
912 Literal::Null => Ok((DataType::Null, new_null_array(&DataType::Null, 1))),
913 Literal::Struct(struct_fields) => {
914 let mut inner_fields = Vec::new();
916 let mut inner_arrays = Vec::new();
917
918 for (field_name, field_lit) in struct_fields {
919 let (field_dtype, field_array) = Self::literal_to_array(field_lit)?;
920 inner_fields.push(Field::new(field_name.clone(), field_dtype, true));
921 inner_arrays.push(field_array);
922 }
923
924 let struct_array =
925 StructArray::try_new(inner_fields.clone().into(), inner_arrays, None).map_err(
926 |e| Error::Internal(format!("failed to create struct array: {}", e)),
927 )?;
928
929 Ok((
930 DataType::Struct(inner_fields.into()),
931 Arc::new(struct_array) as ArrayRef,
932 ))
933 }
934 }
935 }
936
937 fn execute_cross_product(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
939 use arrow::compute::concat_batches;
940
941 if plan.tables.len() < 2 {
942 return Err(Error::InvalidArgumentError(
943 "cross product requires at least 2 tables".into(),
944 ));
945 }
946
947 let mut tables_with_handles = Vec::with_capacity(plan.tables.len());
948 for table_ref in &plan.tables {
949 let qualified_name = table_ref.qualified_name();
950 let table = self.provider.get_table(&qualified_name)?;
951 tables_with_handles.push((table_ref.clone(), table));
952 }
953
954 let display_name = tables_with_handles
955 .iter()
956 .map(|(table_ref, _)| table_ref.qualified_name())
957 .collect::<Vec<_>>()
958 .join(",");
959
960 let mut remaining_filter = plan.filter.clone();
961
962 let join_data = if plan.scalar_subqueries.is_empty() && remaining_filter.as_ref().is_some()
964 {
965 self.try_execute_hash_join(&plan, &tables_with_handles)?
966 } else {
967 None
968 };
969
970 let current = if let Some((joined, handled_all_predicates)) = join_data {
971 if handled_all_predicates {
973 remaining_filter = None;
974 }
975 joined
976 } else {
977 let has_joins = !plan.joins.is_empty();
979
980 if has_joins && tables_with_handles.len() == 2 {
981 use llkv_join::{JoinOptions, TableJoinExt};
983
984 let (left_ref, left_table) = &tables_with_handles[0];
985 let (right_ref, right_table) = &tables_with_handles[1];
986
987 let join_metadata = plan.joins.first().ok_or_else(|| {
988 Error::InvalidArgumentError("expected join metadata for two-table join".into())
989 })?;
990
991 let join_type = match join_metadata.join_type {
992 llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
993 llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
994 llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
995 llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
996 };
997
998 tracing::debug!(
999 "Using llkv-join for {join_type:?} join between {} and {}",
1000 left_ref.qualified_name(),
1001 right_ref.qualified_name()
1002 );
1003
1004 let left_col_count = left_table.schema.columns.len();
1005 let right_col_count = right_table.schema.columns.len();
1006
1007 let mut combined_fields = Vec::with_capacity(left_col_count + right_col_count);
1008 for col in &left_table.schema.columns {
1009 combined_fields.push(Field::new(
1010 col.name.clone(),
1011 col.data_type.clone(),
1012 col.nullable,
1013 ));
1014 }
1015 for col in &right_table.schema.columns {
1016 combined_fields.push(Field::new(
1017 col.name.clone(),
1018 col.data_type.clone(),
1019 col.nullable,
1020 ));
1021 }
1022 let combined_schema = Arc::new(Schema::new(combined_fields));
1023 let column_counts = vec![left_col_count, right_col_count];
1024 let table_indices = vec![0, 1];
1025
1026 let mut join_keys = Vec::new();
1027 let mut condition_is_trivial = false;
1028 let mut condition_is_impossible = false;
1029
1030 if let Some(condition) = join_metadata.on_condition.as_ref() {
1031 let plan = build_join_keys_from_condition(
1032 condition,
1033 left_ref,
1034 left_table.as_ref(),
1035 right_ref,
1036 right_table.as_ref(),
1037 )?;
1038 join_keys = plan.keys;
1039 condition_is_trivial = plan.always_true;
1040 condition_is_impossible = plan.always_false;
1041 }
1042
1043 if condition_is_impossible {
1044 let batches = build_no_match_join_batches(
1045 join_type,
1046 left_ref,
1047 left_table.as_ref(),
1048 right_ref,
1049 right_table.as_ref(),
1050 Arc::clone(&combined_schema),
1051 )?;
1052
1053 TableCrossProductData {
1054 schema: combined_schema,
1055 batches,
1056 column_counts,
1057 table_indices,
1058 }
1059 } else {
1060 if !condition_is_trivial
1061 && join_metadata.on_condition.is_some()
1062 && join_keys.is_empty()
1063 {
1064 return Err(Error::InvalidArgumentError(
1065 "JOIN ON clause must include at least one equality predicate".into(),
1066 ));
1067 }
1068
1069 let mut result_batches = Vec::new();
1070 left_table.table.join_stream(
1071 &right_table.table,
1072 &join_keys,
1073 &JoinOptions {
1074 join_type,
1075 ..Default::default()
1076 },
1077 |batch| {
1078 result_batches.push(batch);
1079 },
1080 )?;
1081
1082 TableCrossProductData {
1083 schema: combined_schema,
1084 batches: result_batches,
1085 column_counts,
1086 table_indices,
1087 }
1088 }
1089 } else {
1090 let constraint_map = if let Some(filter_wrapper) = remaining_filter.as_ref() {
1092 extract_literal_pushdown_filters(
1093 &filter_wrapper.predicate,
1094 &tables_with_handles,
1095 )
1096 } else {
1097 vec![Vec::new(); tables_with_handles.len()]
1098 };
1099
1100 let mut staged: Vec<TableCrossProductData> =
1101 Vec::with_capacity(tables_with_handles.len());
1102 let join_lookup: FxHashMap<usize, &llkv_plan::JoinMetadata> = plan
1103 .joins
1104 .iter()
1105 .map(|join| (join.left_table_index, join))
1106 .collect();
1107
1108 let mut idx = 0usize;
1109 while idx < tables_with_handles.len() {
1110 if let Some(join_metadata) = join_lookup.get(&idx) {
1111 if idx + 1 >= tables_with_handles.len() {
1112 return Err(Error::Internal(
1113 "join metadata references table beyond FROM list".into(),
1114 ));
1115 }
1116
1117 let overlaps_next_join = join_lookup.contains_key(&(idx + 1));
1122 if let Some(condition) = join_metadata.on_condition.as_ref() {
1123 let (left_ref, left_table) = &tables_with_handles[idx];
1124 let (right_ref, right_table) = &tables_with_handles[idx + 1];
1125 let join_plan = build_join_keys_from_condition(
1126 condition,
1127 left_ref,
1128 left_table.as_ref(),
1129 right_ref,
1130 right_table.as_ref(),
1131 )?;
1132 if join_plan.always_false && !overlaps_next_join {
1133 let join_type = match join_metadata.join_type {
1134 llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
1135 llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
1136 llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
1137 llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
1138 };
1139
1140 let left_col_count = left_table.schema.columns.len();
1141 let right_col_count = right_table.schema.columns.len();
1142
1143 let mut combined_fields =
1144 Vec::with_capacity(left_col_count + right_col_count);
1145 for col in &left_table.schema.columns {
1146 combined_fields.push(Field::new(
1147 col.name.clone(),
1148 col.data_type.clone(),
1149 col.nullable,
1150 ));
1151 }
1152 for col in &right_table.schema.columns {
1153 combined_fields.push(Field::new(
1154 col.name.clone(),
1155 col.data_type.clone(),
1156 col.nullable,
1157 ));
1158 }
1159
1160 let combined_schema = Arc::new(Schema::new(combined_fields));
1161 let batches = build_no_match_join_batches(
1162 join_type,
1163 left_ref,
1164 left_table.as_ref(),
1165 right_ref,
1166 right_table.as_ref(),
1167 Arc::clone(&combined_schema),
1168 )?;
1169
1170 staged.push(TableCrossProductData {
1171 schema: combined_schema,
1172 batches,
1173 column_counts: vec![left_col_count, right_col_count],
1174 table_indices: vec![idx, idx + 1],
1175 });
1176 idx += 2;
1177 continue;
1178 }
1179 }
1180 }
1181
1182 let (table_ref, table) = &tables_with_handles[idx];
1183 let constraints = constraint_map.get(idx).map(|v| v.as_slice()).unwrap_or(&[]);
1184 staged.push(collect_table_data(
1185 idx,
1186 table_ref,
1187 table.as_ref(),
1188 constraints,
1189 )?);
1190 idx += 1;
1191 }
1192
1193 cross_join_all(staged)?
1194 }
1195 };
1196
1197 let TableCrossProductData {
1198 schema: combined_schema,
1199 batches: mut combined_batches,
1200 column_counts,
1201 table_indices,
1202 } = current;
1203
1204 let column_lookup_map = build_cross_product_column_lookup(
1205 combined_schema.as_ref(),
1206 &plan.tables,
1207 &column_counts,
1208 &table_indices,
1209 );
1210
1211 if let Some(filter_wrapper) = remaining_filter.as_ref() {
1212 let mut filter_context = CrossProductExpressionContext::new(
1213 combined_schema.as_ref(),
1214 column_lookup_map.clone(),
1215 )?;
1216 let translated_filter = translate_predicate(
1217 filter_wrapper.predicate.clone(),
1218 filter_context.schema(),
1219 |name| {
1220 Error::InvalidArgumentError(format!(
1221 "column '{}' not found in cross product result",
1222 name
1223 ))
1224 },
1225 )?;
1226
1227 let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
1228 filter_wrapper
1229 .subqueries
1230 .iter()
1231 .map(|subquery| (subquery.id, subquery))
1232 .collect();
1233
1234 let mut filtered_batches = Vec::with_capacity(combined_batches.len());
1235 for batch in combined_batches.into_iter() {
1236 filter_context.reset();
1237 let mask = filter_context.evaluate_predicate_mask(
1238 &translated_filter,
1239 &batch,
1240 |ctx, subquery_expr, row_idx, current_batch| {
1241 let subquery = subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
1242 Error::Internal("missing correlated subquery metadata".into())
1243 })?;
1244 let exists =
1245 self.evaluate_exists_subquery(ctx, subquery, current_batch, row_idx)?;
1246 let value = if subquery_expr.negated {
1247 !exists
1248 } else {
1249 exists
1250 };
1251 Ok(Some(value))
1252 },
1253 )?;
1254 let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
1255 Error::InvalidArgumentError(format!(
1256 "failed to apply cross product filter: {err}"
1257 ))
1258 })?;
1259 if filtered.num_rows() > 0 {
1260 filtered_batches.push(filtered);
1261 }
1262 }
1263 combined_batches = filtered_batches;
1264 }
1265
1266 if !plan.group_by.is_empty() {
1268 return self.execute_group_by_from_batches(
1269 display_name,
1270 plan,
1271 combined_schema,
1272 combined_batches,
1273 column_lookup_map,
1274 );
1275 }
1276
1277 if !plan.aggregates.is_empty() {
1278 return self.execute_cross_product_aggregates(
1279 Arc::clone(&combined_schema),
1280 combined_batches,
1281 &column_lookup_map,
1282 &plan,
1283 &display_name,
1284 );
1285 }
1286
1287 if self.has_computed_aggregates(&plan) {
1288 return self.execute_cross_product_computed_aggregates(
1289 Arc::clone(&combined_schema),
1290 combined_batches,
1291 &column_lookup_map,
1292 &plan,
1293 &display_name,
1294 );
1295 }
1296
1297 let mut combined_batch = if combined_batches.is_empty() {
1298 RecordBatch::new_empty(Arc::clone(&combined_schema))
1299 } else if combined_batches.len() == 1 {
1300 combined_batches.pop().unwrap()
1301 } else {
1302 concat_batches(&combined_schema, &combined_batches).map_err(|e| {
1303 Error::Internal(format!(
1304 "failed to concatenate cross product batches: {}",
1305 e
1306 ))
1307 })?
1308 };
1309
1310 let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
1311 .scalar_subqueries
1312 .iter()
1313 .map(|subquery| (subquery.id, subquery))
1314 .collect();
1315
1316 if !plan.projections.is_empty() {
1318 let mut selected_fields = Vec::new();
1319 let mut selected_columns = Vec::new();
1320 let mut expr_context: Option<CrossProductExpressionContext> = None;
1321
1322 for proj in &plan.projections {
1323 match proj {
1324 SelectProjection::AllColumns => {
1325 selected_fields = combined_schema.fields().iter().cloned().collect();
1327 selected_columns = combined_batch.columns().to_vec();
1328 break;
1329 }
1330 SelectProjection::AllColumnsExcept { exclude } => {
1331 let exclude_lower: Vec<String> =
1333 exclude.iter().map(|e| e.to_ascii_lowercase()).collect();
1334
1335 for (idx, field) in combined_schema.fields().iter().enumerate() {
1336 let field_name_lower = field.name().to_ascii_lowercase();
1337 if !exclude_lower.contains(&field_name_lower) {
1338 selected_fields.push(field.clone());
1339 selected_columns.push(combined_batch.column(idx).clone());
1340 }
1341 }
1342 break;
1343 }
1344 SelectProjection::Column { name, alias } => {
1345 let col_name = name.to_ascii_lowercase();
1347 if let Some(&idx) = column_lookup_map.get(&col_name) {
1348 let field = combined_schema.field(idx);
1349 let output_name = alias.as_ref().unwrap_or(name).clone();
1350 selected_fields.push(Arc::new(arrow::datatypes::Field::new(
1351 output_name,
1352 field.data_type().clone(),
1353 field.is_nullable(),
1354 )));
1355 selected_columns.push(combined_batch.column(idx).clone());
1356 } else {
1357 return Err(Error::InvalidArgumentError(format!(
1358 "column '{}' not found in cross product result",
1359 name
1360 )));
1361 }
1362 }
1363 SelectProjection::Computed { expr, alias } => {
1364 if expr_context.is_none() {
1365 expr_context = Some(CrossProductExpressionContext::new(
1366 combined_schema.as_ref(),
1367 column_lookup_map.clone(),
1368 )?);
1369 }
1370 let context = expr_context
1371 .as_mut()
1372 .expect("projection context must be initialized");
1373 context.reset();
1374 let evaluated = self.evaluate_projection_expression(
1375 context,
1376 expr,
1377 &combined_batch,
1378 &scalar_lookup,
1379 )?;
1380 let field = Arc::new(arrow::datatypes::Field::new(
1381 alias.clone(),
1382 evaluated.data_type().clone(),
1383 true,
1384 ));
1385 selected_fields.push(field);
1386 selected_columns.push(evaluated);
1387 }
1388 }
1389 }
1390
1391 let projected_schema = Arc::new(Schema::new(selected_fields));
1392 combined_batch = RecordBatch::try_new(projected_schema, selected_columns)
1393 .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))?;
1394 }
1395
1396 if plan.distinct {
1397 let mut state = DistinctState::default();
1398 let source_schema = combined_batch.schema();
1399 combined_batch = match distinct_filter_batch(combined_batch, &mut state)? {
1400 Some(filtered) => filtered,
1401 None => RecordBatch::new_empty(source_schema),
1402 };
1403 }
1404
1405 let schema = combined_batch.schema();
1406
1407 Ok(SelectExecution::new_single_batch(
1408 display_name,
1409 schema,
1410 combined_batch,
1411 ))
1412 }
1413}
1414
1415struct JoinKeyBuild {
1416 keys: Vec<llkv_join::JoinKey>,
1417 always_true: bool,
1418 always_false: bool,
1419}
1420
/// Classification of a JOIN `ON` expression produced by `analyze_join_condition`.
#[derive(Debug)]
enum JoinConditionAnalysis {
    /// The condition is satisfied by every row pair.
    AlwaysTrue,
    /// The condition is satisfied by no row pair.
    AlwaysFalse,
    /// The condition is a conjunction of `column = column` predicates; each
    /// tuple holds the two column names exactly as written in the expression
    /// (left operand first), not yet resolved to a table side.
    EquiPairs(Vec<(String, String)>),
}
1427
1428fn build_join_keys_from_condition<P>(
1429 condition: &LlkvExpr<'static, String>,
1430 left_ref: &llkv_plan::TableRef,
1431 left_table: &ExecutorTable<P>,
1432 right_ref: &llkv_plan::TableRef,
1433 right_table: &ExecutorTable<P>,
1434) -> ExecutorResult<JoinKeyBuild>
1435where
1436 P: Pager<Blob = EntryHandle> + Send + Sync,
1437{
1438 match analyze_join_condition(condition)? {
1439 JoinConditionAnalysis::AlwaysTrue => Ok(JoinKeyBuild {
1440 keys: Vec::new(),
1441 always_true: true,
1442 always_false: false,
1443 }),
1444 JoinConditionAnalysis::AlwaysFalse => Ok(JoinKeyBuild {
1445 keys: Vec::new(),
1446 always_true: false,
1447 always_false: true,
1448 }),
1449 JoinConditionAnalysis::EquiPairs(pairs) => {
1450 let left_lookup = build_join_column_lookup(left_ref, left_table);
1451 let right_lookup = build_join_column_lookup(right_ref, right_table);
1452
1453 let mut keys = Vec::with_capacity(pairs.len());
1454 for (lhs, rhs) in pairs {
1455 let (lhs_side, lhs_field) = resolve_join_column(&lhs, &left_lookup, &right_lookup)?;
1456 let (rhs_side, rhs_field) = resolve_join_column(&rhs, &left_lookup, &right_lookup)?;
1457
1458 match (lhs_side, rhs_side) {
1459 (JoinColumnSide::Left, JoinColumnSide::Right) => {
1460 keys.push(llkv_join::JoinKey::new(lhs_field, rhs_field));
1461 }
1462 (JoinColumnSide::Right, JoinColumnSide::Left) => {
1463 keys.push(llkv_join::JoinKey::new(rhs_field, lhs_field));
1464 }
1465 (JoinColumnSide::Left, JoinColumnSide::Left) => {
1466 return Err(Error::InvalidArgumentError(format!(
1467 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1468 left_ref.display_name(),
1469 lhs,
1470 rhs
1471 )));
1472 }
1473 (JoinColumnSide::Right, JoinColumnSide::Right) => {
1474 return Err(Error::InvalidArgumentError(format!(
1475 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1476 right_ref.display_name(),
1477 lhs,
1478 rhs
1479 )));
1480 }
1481 }
1482 }
1483
1484 Ok(JoinKeyBuild {
1485 keys,
1486 always_true: false,
1487 always_false: false,
1488 })
1489 }
1490 }
1491}
1492
1493fn analyze_join_condition(
1494 expr: &LlkvExpr<'static, String>,
1495) -> ExecutorResult<JoinConditionAnalysis> {
1496 match evaluate_constant_join_expr(expr) {
1497 ConstantJoinEvaluation::Known(true) => {
1498 return Ok(JoinConditionAnalysis::AlwaysTrue);
1499 }
1500 ConstantJoinEvaluation::Known(false) | ConstantJoinEvaluation::Unknown => {
1501 return Ok(JoinConditionAnalysis::AlwaysFalse);
1502 }
1503 ConstantJoinEvaluation::NotConstant => {}
1504 }
1505 match expr {
1506 LlkvExpr::Literal(value) => {
1507 if *value {
1508 Ok(JoinConditionAnalysis::AlwaysTrue)
1509 } else {
1510 Ok(JoinConditionAnalysis::AlwaysFalse)
1511 }
1512 }
1513 LlkvExpr::And(children) => {
1514 let mut collected: Vec<(String, String)> = Vec::new();
1515 for child in children {
1516 match analyze_join_condition(child)? {
1517 JoinConditionAnalysis::AlwaysTrue => {}
1518 JoinConditionAnalysis::AlwaysFalse => {
1519 return Ok(JoinConditionAnalysis::AlwaysFalse);
1520 }
1521 JoinConditionAnalysis::EquiPairs(mut pairs) => {
1522 collected.append(&mut pairs);
1523 }
1524 }
1525 }
1526
1527 if collected.is_empty() {
1528 Ok(JoinConditionAnalysis::AlwaysTrue)
1529 } else {
1530 Ok(JoinConditionAnalysis::EquiPairs(collected))
1531 }
1532 }
1533 LlkvExpr::Compare { left, op, right } => {
1534 if *op != CompareOp::Eq {
1535 return Err(Error::InvalidArgumentError(
1536 "JOIN ON clause only supports '=' comparisons in optimized path".into(),
1537 ));
1538 }
1539 let left_name = try_extract_simple_column(left).ok_or_else(|| {
1540 Error::InvalidArgumentError(
1541 "JOIN ON clause requires plain column references".into(),
1542 )
1543 })?;
1544 let right_name = try_extract_simple_column(right).ok_or_else(|| {
1545 Error::InvalidArgumentError(
1546 "JOIN ON clause requires plain column references".into(),
1547 )
1548 })?;
1549 Ok(JoinConditionAnalysis::EquiPairs(vec![(
1550 left_name.to_string(),
1551 right_name.to_string(),
1552 )]))
1553 }
1554 _ => Err(Error::InvalidArgumentError(
1555 "JOIN ON expressions must be conjunctions of column equality predicates".into(),
1556 )),
1557 }
1558}
1559
1560fn compare_literals_with_mode(
1561 op: CompareOp,
1562 left: &Literal,
1563 right: &Literal,
1564 null_behavior: NullComparisonBehavior,
1565) -> Option<bool> {
1566 use std::cmp::Ordering;
1567
1568 fn ordering_result(ord: Ordering, op: CompareOp) -> bool {
1569 match op {
1570 CompareOp::Eq => ord == Ordering::Equal,
1571 CompareOp::NotEq => ord != Ordering::Equal,
1572 CompareOp::Lt => ord == Ordering::Less,
1573 CompareOp::LtEq => ord != Ordering::Greater,
1574 CompareOp::Gt => ord == Ordering::Greater,
1575 CompareOp::GtEq => ord != Ordering::Less,
1576 }
1577 }
1578
1579 fn compare_f64(lhs: f64, rhs: f64, op: CompareOp) -> bool {
1580 match op {
1581 CompareOp::Eq => lhs == rhs,
1582 CompareOp::NotEq => lhs != rhs,
1583 CompareOp::Lt => lhs < rhs,
1584 CompareOp::LtEq => lhs <= rhs,
1585 CompareOp::Gt => lhs > rhs,
1586 CompareOp::GtEq => lhs >= rhs,
1587 }
1588 }
1589
1590 match (left, right) {
1591 (Literal::Null, _) | (_, Literal::Null) => match null_behavior {
1592 NullComparisonBehavior::ThreeValuedLogic => None,
1593 },
1594 (Literal::Integer(lhs), Literal::Integer(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1595 (Literal::Float(lhs), Literal::Float(rhs)) => Some(compare_f64(*lhs, *rhs, op)),
1596 (Literal::Integer(lhs), Literal::Float(rhs)) => Some(compare_f64(*lhs as f64, *rhs, op)),
1597 (Literal::Float(lhs), Literal::Integer(rhs)) => Some(compare_f64(*lhs, *rhs as f64, op)),
1598 (Literal::Boolean(lhs), Literal::Boolean(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1599 (Literal::String(lhs), Literal::String(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1600 (Literal::Struct(_), _) | (_, Literal::Struct(_)) => None,
1601 _ => None,
1602 }
1603}
1604
1605fn build_no_match_join_batches<P>(
1606 join_type: llkv_join::JoinType,
1607 left_ref: &llkv_plan::TableRef,
1608 left_table: &ExecutorTable<P>,
1609 right_ref: &llkv_plan::TableRef,
1610 right_table: &ExecutorTable<P>,
1611 combined_schema: Arc<Schema>,
1612) -> ExecutorResult<Vec<RecordBatch>>
1613where
1614 P: Pager<Blob = EntryHandle> + Send + Sync,
1615{
1616 match join_type {
1617 llkv_join::JoinType::Inner => Ok(Vec::new()),
1618 llkv_join::JoinType::Left => {
1619 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
1620 let mut results = Vec::new();
1621
1622 for left_batch in left_batches {
1623 let row_count = left_batch.num_rows();
1624 if row_count == 0 {
1625 continue;
1626 }
1627
1628 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1629 columns.extend(left_batch.columns().iter().cloned());
1630 for column in &right_table.schema.columns {
1631 columns.push(new_null_array(&column.data_type, row_count));
1632 }
1633
1634 let batch =
1635 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1636 Error::Internal(format!("failed to build LEFT JOIN fallback batch: {err}"))
1637 })?;
1638 results.push(batch);
1639 }
1640
1641 Ok(results)
1642 }
1643 llkv_join::JoinType::Right => {
1644 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
1645 let mut results = Vec::new();
1646
1647 for right_batch in right_batches {
1648 let row_count = right_batch.num_rows();
1649 if row_count == 0 {
1650 continue;
1651 }
1652
1653 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1654 for column in &left_table.schema.columns {
1655 columns.push(new_null_array(&column.data_type, row_count));
1656 }
1657 columns.extend(right_batch.columns().iter().cloned());
1658
1659 let batch =
1660 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1661 Error::Internal(format!("failed to build RIGHT JOIN fallback batch: {err}"))
1662 })?;
1663 results.push(batch);
1664 }
1665
1666 Ok(results)
1667 }
1668 llkv_join::JoinType::Full => {
1669 let mut results = Vec::new();
1670
1671 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
1672 for left_batch in left_batches {
1673 let row_count = left_batch.num_rows();
1674 if row_count == 0 {
1675 continue;
1676 }
1677
1678 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1679 columns.extend(left_batch.columns().iter().cloned());
1680 for column in &right_table.schema.columns {
1681 columns.push(new_null_array(&column.data_type, row_count));
1682 }
1683
1684 let batch =
1685 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1686 Error::Internal(format!(
1687 "failed to build FULL JOIN left fallback batch: {err}"
1688 ))
1689 })?;
1690 results.push(batch);
1691 }
1692
1693 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
1694 for right_batch in right_batches {
1695 let row_count = right_batch.num_rows();
1696 if row_count == 0 {
1697 continue;
1698 }
1699
1700 let mut columns = Vec::with_capacity(combined_schema.fields().len());
1701 for column in &left_table.schema.columns {
1702 columns.push(new_null_array(&column.data_type, row_count));
1703 }
1704 columns.extend(right_batch.columns().iter().cloned());
1705
1706 let batch =
1707 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
1708 Error::Internal(format!(
1709 "failed to build FULL JOIN right fallback batch: {err}"
1710 ))
1711 })?;
1712 results.push(batch);
1713 }
1714
1715 Ok(results)
1716 }
1717 other => Err(Error::InvalidArgumentError(format!(
1718 "{other:?} join type is not supported when join predicate is unsatisfiable",
1719 ))),
1720 }
1721}
1722
1723fn scan_all_columns_for_join<P>(
1724 table_ref: &llkv_plan::TableRef,
1725 table: &ExecutorTable<P>,
1726) -> ExecutorResult<Vec<RecordBatch>>
1727where
1728 P: Pager<Blob = EntryHandle> + Send + Sync,
1729{
1730 if table.schema.columns.is_empty() {
1731 return Err(Error::InvalidArgumentError(format!(
1732 "table '{}' has no columns; joins require at least one column",
1733 table_ref.qualified_name()
1734 )));
1735 }
1736
1737 let mut projections = Vec::with_capacity(table.schema.columns.len());
1738 for column in &table.schema.columns {
1739 projections.push(ScanProjection::from(StoreProjection::with_alias(
1740 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
1741 column.name.clone(),
1742 )));
1743 }
1744
1745 let filter_field = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);
1746 let filter_expr = full_table_scan_filter(filter_field);
1747
1748 let mut batches = Vec::new();
1749 table.table.scan_stream(
1750 projections,
1751 &filter_expr,
1752 ScanStreamOptions {
1753 include_nulls: true,
1754 ..ScanStreamOptions::default()
1755 },
1756 |batch| {
1757 batches.push(batch);
1758 },
1759 )?;
1760
1761 Ok(batches)
1762}
1763
1764fn build_join_column_lookup<P>(
1765 table_ref: &llkv_plan::TableRef,
1766 table: &ExecutorTable<P>,
1767) -> FxHashMap<String, FieldId>
1768where
1769 P: Pager<Blob = EntryHandle> + Send + Sync,
1770{
1771 let mut lookup = FxHashMap::default();
1772 let table_lower = table_ref.table.to_ascii_lowercase();
1773 let qualified_lower = table_ref.qualified_name().to_ascii_lowercase();
1774 let display_lower = table_ref.display_name().to_ascii_lowercase();
1775 let alias_lower = table_ref.alias.as_ref().map(|s| s.to_ascii_lowercase());
1776 let schema_lower = if table_ref.schema.is_empty() {
1777 None
1778 } else {
1779 Some(table_ref.schema.to_ascii_lowercase())
1780 };
1781
1782 for column in &table.schema.columns {
1783 let base = column.name.to_ascii_lowercase();
1784 let short = base.rsplit('.').next().unwrap_or(base.as_str()).to_string();
1785
1786 lookup.entry(short.clone()).or_insert(column.field_id);
1787 lookup.entry(base.clone()).or_insert(column.field_id);
1788
1789 lookup
1790 .entry(format!("{table_lower}.{short}"))
1791 .or_insert(column.field_id);
1792
1793 if display_lower != table_lower {
1794 lookup
1795 .entry(format!("{display_lower}.{short}"))
1796 .or_insert(column.field_id);
1797 }
1798
1799 if qualified_lower != table_lower {
1800 lookup
1801 .entry(format!("{qualified_lower}.{short}"))
1802 .or_insert(column.field_id);
1803 }
1804
1805 if let Some(schema) = &schema_lower {
1806 lookup
1807 .entry(format!("{schema}.{table_lower}.{short}"))
1808 .or_insert(column.field_id);
1809 if display_lower != table_lower {
1810 lookup
1811 .entry(format!("{schema}.{display_lower}.{short}"))
1812 .or_insert(column.field_id);
1813 }
1814 }
1815
1816 if let Some(alias) = &alias_lower {
1817 lookup
1818 .entry(format!("{alias}.{short}"))
1819 .or_insert(column.field_id);
1820 }
1821 }
1822
1823 lookup
1824}
1825
/// Which input of a two-table join a resolved column belongs to.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone`/`Copy` so
/// values can be printed in diagnostics and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JoinColumnSide {
    /// The column was found in the left table's lookup.
    Left,
    /// The column was found in the right table's lookup.
    Right,
}
1831
1832fn resolve_join_column(
1833 column: &str,
1834 left_lookup: &FxHashMap<String, FieldId>,
1835 right_lookup: &FxHashMap<String, FieldId>,
1836) -> ExecutorResult<(JoinColumnSide, FieldId)> {
1837 let key = column.to_ascii_lowercase();
1838 match (left_lookup.get(&key), right_lookup.get(&key)) {
1839 (Some(&field_id), None) => Ok((JoinColumnSide::Left, field_id)),
1840 (None, Some(&field_id)) => Ok((JoinColumnSide::Right, field_id)),
1841 (Some(_), Some(_)) => Err(Error::InvalidArgumentError(format!(
1842 "join column '{column}' is ambiguous; qualify it with a table name or alias",
1843 ))),
1844 (None, None) => Err(Error::InvalidArgumentError(format!(
1845 "join column '{column}' was not found in either table",
1846 ))),
1847 }
1848}
1849
#[cfg(test)]
mod join_condition_tests {
    //! Unit tests for `analyze_join_condition`.

    use super::*;
    use llkv_expr::expr::{CompareOp, ScalarExpr};
    use llkv_expr::literal::Literal;

    /// Builds a column reference operand.
    fn column(name: &str) -> ScalarExpr<String> {
        ScalarExpr::Column(name.into())
    }

    /// Runs the analysis and unwraps the successful result.
    fn analyze(expr: &LlkvExpr<'static, String>) -> JoinConditionAnalysis {
        analyze_join_condition(expr).expect("analysis succeeds")
    }

    /// Asserts that the expression is classified as never matching.
    fn assert_always_false(expr: LlkvExpr<'static, String>) {
        assert!(matches!(
            analyze(&expr),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    #[test]
    fn analyze_detects_simple_equality() {
        let expr = LlkvExpr::Compare {
            left: column("t1.col"),
            op: CompareOp::Eq,
            right: column("t2.col"),
        };

        match analyze(&expr) {
            JoinConditionAnalysis::EquiPairs(pairs) => {
                assert_eq!(pairs, vec![("t1.col".to_string(), "t2.col".to_string())]);
            }
            other => panic!("unexpected analysis result: {other:?}"),
        }
    }

    #[test]
    fn analyze_handles_literal_true() {
        let expr = LlkvExpr::Literal(true);
        assert!(matches!(
            analyze(&expr),
            JoinConditionAnalysis::AlwaysTrue
        ));
    }

    #[test]
    fn analyze_rejects_non_equality() {
        let expr = LlkvExpr::Compare {
            left: column("t1.col"),
            op: CompareOp::Gt,
            right: column("t2.col"),
        };
        assert!(analyze_join_condition(&expr).is_err());
    }

    #[test]
    fn analyze_handles_constant_is_not_null() {
        // `NULL IS NOT NULL`
        assert_always_false(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Null),
            negated: true,
        });
    }

    #[test]
    fn analyze_handles_not_applied_to_is_not_null() {
        // `NOT (86 IS NOT NULL)`
        assert_always_false(LlkvExpr::Not(Box::new(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Integer(86)),
            negated: true,
        })));
    }

    #[test]
    fn analyze_literal_is_null_is_always_false() {
        // `1 IS NULL`
        assert_always_false(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Integer(1)),
            negated: false,
        });
    }

    #[test]
    fn analyze_not_null_comparison_is_always_false() {
        // `NOT (NULL < t2.col)`
        assert_always_false(LlkvExpr::Not(Box::new(LlkvExpr::Compare {
            left: ScalarExpr::Literal(Literal::Null),
            op: CompareOp::Lt,
            right: column("t2.col"),
        })));
    }
}
1944
1945impl<P> QueryExecutor<P>
1946where
1947 P: Pager<Blob = EntryHandle> + Send + Sync,
1948{
1949 fn execute_cross_product_aggregates(
1950 &self,
1951 combined_schema: Arc<Schema>,
1952 batches: Vec<RecordBatch>,
1953 column_lookup_map: &FxHashMap<String, usize>,
1954 plan: &SelectPlan,
1955 display_name: &str,
1956 ) -> ExecutorResult<SelectExecution<P>> {
1957 if !plan.scalar_subqueries.is_empty() {
1958 return Err(Error::InvalidArgumentError(
1959 "scalar subqueries not supported in aggregate joins".into(),
1960 ));
1961 }
1962
1963 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
1964 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(plan.aggregates.len());
1965
1966 for aggregate in &plan.aggregates {
1967 match aggregate {
1968 AggregateExpr::CountStar { alias, distinct } => {
1969 specs.push(AggregateSpec {
1970 alias: alias.clone(),
1971 kind: AggregateKind::Count {
1972 field_id: None,
1973 distinct: *distinct,
1974 },
1975 });
1976 spec_to_projection.push(None);
1977 }
1978 AggregateExpr::Column {
1979 column,
1980 alias,
1981 function,
1982 distinct,
1983 } => {
1984 let key = column.to_ascii_lowercase();
1985 let column_index = *column_lookup_map.get(&key).ok_or_else(|| {
1986 Error::InvalidArgumentError(format!(
1987 "unknown column '{column}' in aggregate"
1988 ))
1989 })?;
1990 let field = combined_schema.field(column_index);
1991 let kind = match function {
1992 AggregateFunction::Count => AggregateKind::Count {
1993 field_id: Some(column_index as u32),
1994 distinct: *distinct,
1995 },
1996 AggregateFunction::SumInt64 => {
1997 let input_type = Self::validate_aggregate_type(
1998 Some(field.data_type().clone()),
1999 "SUM",
2000 &[DataType::Int64, DataType::Float64],
2001 )?;
2002 AggregateKind::Sum {
2003 field_id: column_index as u32,
2004 data_type: input_type,
2005 distinct: *distinct,
2006 }
2007 }
2008 AggregateFunction::TotalInt64 => {
2009 let input_type = Self::validate_aggregate_type(
2010 Some(field.data_type().clone()),
2011 "TOTAL",
2012 &[DataType::Int64, DataType::Float64],
2013 )?;
2014 AggregateKind::Total {
2015 field_id: column_index as u32,
2016 data_type: input_type,
2017 distinct: *distinct,
2018 }
2019 }
2020 AggregateFunction::MinInt64 => {
2021 let input_type = Self::validate_aggregate_type(
2022 Some(field.data_type().clone()),
2023 "MIN",
2024 &[DataType::Int64, DataType::Float64],
2025 )?;
2026 AggregateKind::Min {
2027 field_id: column_index as u32,
2028 data_type: input_type,
2029 }
2030 }
2031 AggregateFunction::MaxInt64 => {
2032 let input_type = Self::validate_aggregate_type(
2033 Some(field.data_type().clone()),
2034 "MAX",
2035 &[DataType::Int64, DataType::Float64],
2036 )?;
2037 AggregateKind::Max {
2038 field_id: column_index as u32,
2039 data_type: input_type,
2040 }
2041 }
2042 AggregateFunction::CountNulls => AggregateKind::CountNulls {
2043 field_id: column_index as u32,
2044 },
2045 AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
2046 field_id: column_index as u32,
2047 distinct: *distinct,
2048 separator: ",".to_string(),
2049 },
2050 };
2051
2052 specs.push(AggregateSpec {
2053 alias: alias.clone(),
2054 kind,
2055 });
2056 spec_to_projection.push(Some(column_index));
2057 }
2058 }
2059 }
2060
2061 if specs.is_empty() {
2062 return Err(Error::InvalidArgumentError(
2063 "aggregate query requires at least one aggregate expression".into(),
2064 ));
2065 }
2066
2067 let mut states = Vec::with_capacity(specs.len());
2068 for (idx, spec) in specs.iter().enumerate() {
2069 states.push(AggregateState {
2070 alias: spec.alias.clone(),
2071 accumulator: AggregateAccumulator::new_with_projection_index(
2072 spec,
2073 spec_to_projection[idx],
2074 None,
2075 )?,
2076 override_value: None,
2077 });
2078 }
2079
2080 for batch in &batches {
2081 for state in &mut states {
2082 state.update(batch)?;
2083 }
2084 }
2085
2086 let mut fields = Vec::with_capacity(states.len());
2087 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
2088 for state in states {
2089 let (field, array) = state.finalize()?;
2090 fields.push(Arc::new(field));
2091 arrays.push(array);
2092 }
2093
2094 let schema = Arc::new(Schema::new(fields));
2095 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
2096
2097 if plan.distinct {
2098 let mut distinct_state = DistinctState::default();
2099 batch = match distinct_filter_batch(batch, &mut distinct_state)? {
2100 Some(filtered) => filtered,
2101 None => RecordBatch::new_empty(Arc::clone(&schema)),
2102 };
2103 }
2104
2105 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
2106 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
2107 }
2108
2109 Ok(SelectExecution::new_single_batch(
2110 display_name.to_string(),
2111 schema,
2112 batch,
2113 ))
2114 }
2115
2116 fn execute_cross_product_computed_aggregates(
2117 &self,
2118 combined_schema: Arc<Schema>,
2119 batches: Vec<RecordBatch>,
2120 column_lookup_map: &FxHashMap<String, usize>,
2121 plan: &SelectPlan,
2122 display_name: &str,
2123 ) -> ExecutorResult<SelectExecution<P>> {
2124 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
2125 for projection in &plan.projections {
2126 match projection {
2127 SelectProjection::Computed { expr, .. } => {
2128 Self::collect_aggregates(expr, &mut aggregate_specs);
2129 }
2130 SelectProjection::AllColumns
2131 | SelectProjection::AllColumnsExcept { .. }
2132 | SelectProjection::Column { .. } => {
2133 return Err(Error::InvalidArgumentError(
2134 "non-computed projections not supported with aggregate expressions".into(),
2135 ));
2136 }
2137 }
2138 }
2139
2140 if aggregate_specs.is_empty() {
2141 return Err(Error::InvalidArgumentError(
2142 "computed aggregate query requires at least one aggregate expression".into(),
2143 ));
2144 }
2145
2146 let aggregate_values = self.compute_cross_product_aggregate_values(
2147 &combined_schema,
2148 &batches,
2149 column_lookup_map,
2150 &aggregate_specs,
2151 )?;
2152
2153 let mut fields = Vec::with_capacity(plan.projections.len());
2154 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
2155
2156 for projection in &plan.projections {
2157 if let SelectProjection::Computed { expr, alias } = projection {
2158 if let ScalarExpr::Aggregate(agg) = expr {
2160 let key = format!("{:?}", agg);
2161 if let Some(agg_value) = aggregate_values.get(&key) {
2162 match agg_value {
2163 AggregateValue::Null => {
2164 fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
2165 arrays.push(Arc::new(Int64Array::from(vec![None::<i64>])) as ArrayRef);
2166 }
2167 AggregateValue::Int64(v) => {
2168 fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
2169 arrays.push(Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef);
2170 }
2171 AggregateValue::Float64(v) => {
2172 fields.push(Arc::new(Field::new(alias, DataType::Float64, true)));
2173 arrays
2174 .push(Arc::new(Float64Array::from(vec![Some(*v)])) as ArrayRef);
2175 }
2176 AggregateValue::String(s) => {
2177 fields.push(Arc::new(Field::new(alias, DataType::Utf8, true)));
2178 arrays
2179 .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
2180 as ArrayRef);
2181 }
2182 }
2183 continue;
2184 }
2185 }
2186
2187 let value = Self::evaluate_expr_with_aggregates(expr, &aggregate_values)?;
2189 fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
2190 arrays.push(Arc::new(Int64Array::from(vec![value])) as ArrayRef);
2191 }
2192 }
2193
2194 let schema = Arc::new(Schema::new(fields));
2195 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
2196
2197 if plan.distinct {
2198 let mut distinct_state = DistinctState::default();
2199 batch = match distinct_filter_batch(batch, &mut distinct_state)? {
2200 Some(filtered) => filtered,
2201 None => RecordBatch::new_empty(Arc::clone(&schema)),
2202 };
2203 }
2204
2205 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
2206 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
2207 }
2208
2209 Ok(SelectExecution::new_single_batch(
2210 display_name.to_string(),
2211 schema,
2212 batch,
2213 ))
2214 }
2215
2216 fn compute_cross_product_aggregate_values(
2217 &self,
2218 combined_schema: &Arc<Schema>,
2219 batches: &[RecordBatch],
2220 column_lookup_map: &FxHashMap<String, usize>,
2221 aggregate_specs: &[(String, AggregateCall<String>)],
2222 ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
2223 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
2224 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());
2225
2226 let mut columns_per_batch: Option<Vec<Vec<ArrayRef>>> = None;
2227 let mut augmented_fields: Option<Vec<Field>> = None;
2228 let mut owned_batches: Option<Vec<RecordBatch>> = None;
2229 let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
2230 FxHashMap::default();
2231 let mut computed_alias_counter: usize = 0;
2232 let mut expr_context = CrossProductExpressionContext::new(
2233 combined_schema.as_ref(),
2234 column_lookup_map.clone(),
2235 )?;
2236
2237 let mut ensure_computed_column =
2238 |expr: &ScalarExpr<String>| -> ExecutorResult<(usize, DataType)> {
2239 let key = format!("{:?}", expr);
2240 if let Some((idx, dtype)) = computed_projection_cache.get(&key) {
2241 return Ok((*idx, dtype.clone()));
2242 }
2243
2244 if columns_per_batch.is_none() {
2245 let initial_columns: Vec<Vec<ArrayRef>> = batches
2246 .iter()
2247 .map(|batch| batch.columns().to_vec())
2248 .collect();
2249 columns_per_batch = Some(initial_columns);
2250 }
2251 if augmented_fields.is_none() {
2252 augmented_fields = Some(
2253 combined_schema
2254 .fields()
2255 .iter()
2256 .map(|field| field.as_ref().clone())
2257 .collect(),
2258 );
2259 }
2260
2261 let translated = translate_scalar(expr, expr_context.schema(), |name| {
2262 Error::InvalidArgumentError(format!(
2263 "unknown column '{}' in aggregate expression",
2264 name
2265 ))
2266 })?;
2267 let data_type = infer_computed_data_type(expr_context.schema(), &translated)?;
2268
2269 if let Some(columns) = columns_per_batch.as_mut() {
2270 for (batch_idx, batch) in batches.iter().enumerate() {
2271 expr_context.reset();
2272 let array = expr_context.materialize_scalar_array(&translated, batch)?;
2273 if let Some(batch_columns) = columns.get_mut(batch_idx) {
2274 batch_columns.push(array);
2275 }
2276 }
2277 }
2278
2279 let column_index = augmented_fields
2280 .as_ref()
2281 .map(|fields| fields.len())
2282 .unwrap_or_else(|| combined_schema.fields().len());
2283
2284 let alias = format!("__agg_expr_cp_{}", computed_alias_counter);
2285 computed_alias_counter += 1;
2286 augmented_fields
2287 .as_mut()
2288 .expect("augmented fields initialized")
2289 .push(Field::new(&alias, data_type.clone(), true));
2290
2291 computed_projection_cache.insert(key, (column_index, data_type.clone()));
2292 Ok((column_index, data_type))
2293 };
2294
2295 for (key, agg) in aggregate_specs {
2296 match agg {
2297 AggregateCall::CountStar => {
2298 specs.push(AggregateSpec {
2299 alias: key.clone(),
2300 kind: AggregateKind::Count {
2301 field_id: None,
2302 distinct: false,
2303 },
2304 });
2305 spec_to_projection.push(None);
2306 }
2307 AggregateCall::Count { expr, .. }
2308 | AggregateCall::Sum { expr, .. }
2309 | AggregateCall::Total { expr, .. }
2310 | AggregateCall::Avg { expr, .. }
2311 | AggregateCall::Min(expr)
2312 | AggregateCall::Max(expr)
2313 | AggregateCall::CountNulls(expr)
2314 | AggregateCall::GroupConcat { expr, .. } => {
2315 let (column_index, data_type_opt) = if let Some(column) =
2316 try_extract_simple_column(expr)
2317 {
2318 let key_lower = column.to_ascii_lowercase();
2319 let column_index = *column_lookup_map.get(&key_lower).ok_or_else(|| {
2320 Error::InvalidArgumentError(format!(
2321 "unknown column '{column}' in aggregate"
2322 ))
2323 })?;
2324 let field = combined_schema.field(column_index);
2325 (column_index, Some(field.data_type().clone()))
2326 } else {
2327 let (index, dtype) = ensure_computed_column(expr)?;
2328 (index, Some(dtype))
2329 };
2330
2331 let kind = match agg {
2332 AggregateCall::Count { distinct, .. } => {
2333 let field_id = u32::try_from(column_index).map_err(|_| {
2334 Error::InvalidArgumentError(
2335 "aggregate projection index exceeds supported range".into(),
2336 )
2337 })?;
2338 AggregateKind::Count {
2339 field_id: Some(field_id),
2340 distinct: *distinct,
2341 }
2342 }
2343 AggregateCall::Sum { distinct, .. } => {
2344 let input_type = Self::validate_aggregate_type(
2345 data_type_opt.clone(),
2346 "SUM",
2347 &[DataType::Int64, DataType::Float64],
2348 )?;
2349 let field_id = u32::try_from(column_index).map_err(|_| {
2350 Error::InvalidArgumentError(
2351 "aggregate projection index exceeds supported range".into(),
2352 )
2353 })?;
2354 AggregateKind::Sum {
2355 field_id,
2356 data_type: input_type,
2357 distinct: *distinct,
2358 }
2359 }
2360 AggregateCall::Total { distinct, .. } => {
2361 let input_type = Self::validate_aggregate_type(
2362 data_type_opt.clone(),
2363 "TOTAL",
2364 &[DataType::Int64, DataType::Float64],
2365 )?;
2366 let field_id = u32::try_from(column_index).map_err(|_| {
2367 Error::InvalidArgumentError(
2368 "aggregate projection index exceeds supported range".into(),
2369 )
2370 })?;
2371 AggregateKind::Total {
2372 field_id,
2373 data_type: input_type,
2374 distinct: *distinct,
2375 }
2376 }
2377 AggregateCall::Avg { distinct, .. } => {
2378 let input_type = Self::validate_aggregate_type(
2379 data_type_opt.clone(),
2380 "AVG",
2381 &[DataType::Int64, DataType::Float64],
2382 )?;
2383 let field_id = u32::try_from(column_index).map_err(|_| {
2384 Error::InvalidArgumentError(
2385 "aggregate projection index exceeds supported range".into(),
2386 )
2387 })?;
2388 AggregateKind::Avg {
2389 field_id,
2390 data_type: input_type,
2391 distinct: *distinct,
2392 }
2393 }
2394 AggregateCall::Min(_) => {
2395 let input_type = Self::validate_aggregate_type(
2396 data_type_opt.clone(),
2397 "MIN",
2398 &[DataType::Int64, DataType::Float64],
2399 )?;
2400 let field_id = u32::try_from(column_index).map_err(|_| {
2401 Error::InvalidArgumentError(
2402 "aggregate projection index exceeds supported range".into(),
2403 )
2404 })?;
2405 AggregateKind::Min {
2406 field_id,
2407 data_type: input_type,
2408 }
2409 }
2410 AggregateCall::Max(_) => {
2411 let input_type = Self::validate_aggregate_type(
2412 data_type_opt.clone(),
2413 "MAX",
2414 &[DataType::Int64, DataType::Float64],
2415 )?;
2416 let field_id = u32::try_from(column_index).map_err(|_| {
2417 Error::InvalidArgumentError(
2418 "aggregate projection index exceeds supported range".into(),
2419 )
2420 })?;
2421 AggregateKind::Max {
2422 field_id,
2423 data_type: input_type,
2424 }
2425 }
2426 AggregateCall::CountNulls(_) => {
2427 let field_id = u32::try_from(column_index).map_err(|_| {
2428 Error::InvalidArgumentError(
2429 "aggregate projection index exceeds supported range".into(),
2430 )
2431 })?;
2432 AggregateKind::CountNulls { field_id }
2433 }
2434 AggregateCall::GroupConcat {
2435 distinct,
2436 separator,
2437 ..
2438 } => {
2439 let field_id = u32::try_from(column_index).map_err(|_| {
2440 Error::InvalidArgumentError(
2441 "aggregate projection index exceeds supported range".into(),
2442 )
2443 })?;
2444 AggregateKind::GroupConcat {
2445 field_id,
2446 distinct: *distinct,
2447 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
2448 }
2449 }
2450 _ => unreachable!(),
2451 };
2452
2453 specs.push(AggregateSpec {
2454 alias: key.clone(),
2455 kind,
2456 });
2457 spec_to_projection.push(Some(column_index));
2458 }
2459 }
2460 }
2461
2462 if let Some(columns) = columns_per_batch {
2463 let fields = augmented_fields.unwrap_or_else(|| {
2464 combined_schema
2465 .fields()
2466 .iter()
2467 .map(|field| field.as_ref().clone())
2468 .collect()
2469 });
2470 let augmented_schema = Arc::new(Schema::new(fields));
2471 let mut new_batches = Vec::with_capacity(columns.len());
2472 for batch_columns in columns {
2473 let batch = RecordBatch::try_new(Arc::clone(&augmented_schema), batch_columns)
2474 .map_err(|err| {
2475 Error::InvalidArgumentError(format!(
2476 "failed to materialize aggregate projections: {err}"
2477 ))
2478 })?;
2479 new_batches.push(batch);
2480 }
2481 owned_batches = Some(new_batches);
2482 }
2483
2484 let mut states = Vec::with_capacity(specs.len());
2485 for (idx, spec) in specs.iter().enumerate() {
2486 states.push(AggregateState {
2487 alias: spec.alias.clone(),
2488 accumulator: AggregateAccumulator::new_with_projection_index(
2489 spec,
2490 spec_to_projection[idx],
2491 None,
2492 )?,
2493 override_value: None,
2494 });
2495 }
2496
2497 let batch_iter: &[RecordBatch] = if let Some(ref extended) = owned_batches {
2498 extended.as_slice()
2499 } else {
2500 batches
2501 };
2502
2503 for batch in batch_iter {
2504 for state in &mut states {
2505 state.update(batch)?;
2506 }
2507 }
2508
2509 let mut results = FxHashMap::default();
2510 for state in states {
2511 let (field, array) = state.finalize()?;
2512
2513 if let Some(int_array) = array.as_any().downcast_ref::<Int64Array>() {
2515 if int_array.len() != 1 {
2516 return Err(Error::Internal(format!(
2517 "Expected single value from aggregate, got {}",
2518 int_array.len()
2519 )));
2520 }
2521 let value = if int_array.is_null(0) {
2522 AggregateValue::Null
2523 } else {
2524 AggregateValue::Int64(int_array.value(0))
2525 };
2526 results.insert(field.name().to_string(), value);
2527 }
2528 else if let Some(float_array) = array.as_any().downcast_ref::<Float64Array>() {
2530 if float_array.len() != 1 {
2531 return Err(Error::Internal(format!(
2532 "Expected single value from aggregate, got {}",
2533 float_array.len()
2534 )));
2535 }
2536 let value = if float_array.is_null(0) {
2537 AggregateValue::Null
2538 } else {
2539 AggregateValue::Float64(float_array.value(0))
2540 };
2541 results.insert(field.name().to_string(), value);
2542 }
2543 else if let Some(string_array) = array.as_any().downcast_ref::<StringArray>() {
2545 if string_array.len() != 1 {
2546 return Err(Error::Internal(format!(
2547 "Expected single value from aggregate, got {}",
2548 string_array.len()
2549 )));
2550 }
2551 let value = if string_array.is_null(0) {
2552 AggregateValue::Null
2553 } else {
2554 AggregateValue::String(string_array.value(0).to_string())
2555 };
2556 results.insert(field.name().to_string(), value);
2557 } else {
2558 return Err(Error::Internal(format!(
2559 "Unexpected array type from aggregate: {:?}",
2560 array.data_type()
2561 )));
2562 }
2563 }
2564
2565 Ok(results)
2566 }
2567
2568 fn try_execute_hash_join(
2585 &self,
2586 plan: &SelectPlan,
2587 tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
2588 ) -> ExecutorResult<Option<(TableCrossProductData, bool)>> {
2589 let query_label_opt = current_query_label();
2590 let query_label = query_label_opt.as_deref().unwrap_or("<unknown query>");
2591
2592 let filter_wrapper = match &plan.filter {
2594 Some(filter) if filter.subqueries.is_empty() => filter,
2595 _ => {
2596 tracing::debug!(
2597 "join_opt[{query_label}]: skipping optimization – filter missing or uses subqueries"
2598 );
2599 return Ok(None);
2600 }
2601 };
2602
2603 if tables_with_handles.len() < 2 {
2604 tracing::debug!(
2605 "join_opt[{query_label}]: skipping optimization – requires at least 2 tables"
2606 );
2607 return Ok(None);
2608 }
2609
2610 let mut table_infos = Vec::with_capacity(tables_with_handles.len());
2612 for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
2613 let mut column_map = FxHashMap::default();
2614 for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
2615 let column_name = column.name.to_ascii_lowercase();
2616 column_map.entry(column_name).or_insert(column_idx);
2617 }
2618 table_infos.push(TableInfo {
2619 index,
2620 table_ref,
2621 column_map,
2622 });
2623 }
2624
2625 let constraint_plan = match extract_join_constraints(
2627 &filter_wrapper.predicate,
2628 &table_infos,
2629 ) {
2630 Some(plan) => plan,
2631 None => {
2632 tracing::debug!(
2633 "join_opt[{query_label}]: skipping optimization – predicate parsing failed (contains OR or other unsupported top-level structure)"
2634 );
2635 return Ok(None);
2636 }
2637 };
2638
2639 tracing::debug!(
2640 "join_opt[{query_label}]: constraint extraction succeeded - equalities={}, literals={}, handled={}/{} predicates",
2641 constraint_plan.equalities.len(),
2642 constraint_plan.literals.len(),
2643 constraint_plan.handled_conjuncts,
2644 constraint_plan.total_conjuncts
2645 );
2646 tracing::debug!(
2647 "join_opt[{query_label}]: attempting hash join with tables={:?} filter={:?}",
2648 plan.tables
2649 .iter()
2650 .map(|t| t.qualified_name())
2651 .collect::<Vec<_>>(),
2652 filter_wrapper.predicate,
2653 );
2654
2655 if constraint_plan.unsatisfiable {
2657 tracing::debug!(
2658 "join_opt[{query_label}]: predicate unsatisfiable – returning empty result"
2659 );
2660 let mut combined_fields = Vec::new();
2661 let mut column_counts = Vec::new();
2662 for (_table_ref, executor_table) in tables_with_handles {
2663 for column in &executor_table.schema.columns {
2664 combined_fields.push(Field::new(
2665 column.name.clone(),
2666 column.data_type.clone(),
2667 column.nullable,
2668 ));
2669 }
2670 column_counts.push(executor_table.schema.columns.len());
2671 }
2672 let combined_schema = Arc::new(Schema::new(combined_fields));
2673 let empty_batch = RecordBatch::new_empty(Arc::clone(&combined_schema));
2674 return Ok(Some((
2675 TableCrossProductData {
2676 schema: combined_schema,
2677 batches: vec![empty_batch],
2678 column_counts,
2679 table_indices: (0..tables_with_handles.len()).collect(),
2680 },
2681 true, )));
2683 }
2684
2685 if constraint_plan.equalities.is_empty() {
2687 tracing::debug!(
2688 "join_opt[{query_label}]: skipping optimization – no join equalities found"
2689 );
2690 return Ok(None);
2691 }
2692
2693 if !constraint_plan.literals.is_empty() {
2698 tracing::debug!(
2699 "join_opt[{query_label}]: found {} literal constraints - proceeding with hash join but may need fallback",
2700 constraint_plan.literals.len()
2701 );
2702 }
2703
2704 tracing::debug!(
2705 "join_opt[{query_label}]: hash join optimization applicable with {} equality constraints",
2706 constraint_plan.equalities.len()
2707 );
2708
2709 let mut literal_map: Vec<Vec<ColumnConstraint>> =
2710 vec![Vec::new(); tables_with_handles.len()];
2711 for constraint in &constraint_plan.literals {
2712 let table_idx = match constraint {
2713 ColumnConstraint::Equality(lit) => lit.column.table,
2714 ColumnConstraint::InList(in_list) => in_list.column.table,
2715 };
2716 if table_idx >= literal_map.len() {
2717 tracing::debug!(
2718 "join_opt[{query_label}]: constraint references unknown table index {}; falling back",
2719 table_idx
2720 );
2721 return Ok(None);
2722 }
2723 tracing::debug!(
2724 "join_opt[{query_label}]: mapping constraint to table_idx={} (table={})",
2725 table_idx,
2726 tables_with_handles[table_idx].0.qualified_name()
2727 );
2728 literal_map[table_idx].push(constraint.clone());
2729 }
2730
2731 let mut per_table: Vec<Option<TableCrossProductData>> =
2732 Vec::with_capacity(tables_with_handles.len());
2733 for (idx, (table_ref, table)) in tables_with_handles.iter().enumerate() {
2734 let data =
2735 collect_table_data(idx, table_ref, table.as_ref(), literal_map[idx].as_slice())?;
2736 per_table.push(Some(data));
2737 }
2738
2739 let has_left_join = plan
2741 .joins
2742 .iter()
2743 .any(|j| j.join_type == llkv_plan::JoinPlan::Left);
2744
2745 let mut current: Option<TableCrossProductData> = None;
2746
2747 if has_left_join {
2748 tracing::debug!(
2750 "join_opt[{query_label}]: delegating to llkv-join for LEFT JOIN support"
2751 );
2752 return Ok(None);
2754 } else {
2755 let mut remaining: Vec<usize> = (0..tables_with_handles.len()).collect();
2757 let mut used_tables: FxHashSet<usize> = FxHashSet::default();
2758
2759 while !remaining.is_empty() {
2760 let next_index = if used_tables.is_empty() {
2761 remaining[0]
2762 } else {
2763 match remaining.iter().copied().find(|idx| {
2764 table_has_join_with_used(*idx, &used_tables, &constraint_plan.equalities)
2765 }) {
2766 Some(idx) => idx,
2767 None => {
2768 tracing::debug!(
2769 "join_opt[{query_label}]: no remaining equality links – using cartesian expansion for table index {idx}",
2770 idx = remaining[0]
2771 );
2772 remaining[0]
2773 }
2774 }
2775 };
2776
2777 let position = remaining
2778 .iter()
2779 .position(|&idx| idx == next_index)
2780 .expect("next index present");
2781
2782 let next_data = per_table[next_index]
2783 .take()
2784 .ok_or_else(|| Error::Internal("hash join consumed table data twice".into()))?;
2785
2786 if let Some(current_data) = current.take() {
2787 let join_keys = gather_join_keys(
2788 ¤t_data,
2789 &next_data,
2790 &used_tables,
2791 next_index,
2792 &constraint_plan.equalities,
2793 )?;
2794
2795 let joined = if join_keys.is_empty() {
2796 tracing::debug!(
2797 "join_opt[{query_label}]: joining '{}' via cartesian expansion (no equality keys)",
2798 tables_with_handles[next_index].0.qualified_name()
2799 );
2800 cross_join_table_batches(current_data, next_data)?
2801 } else {
2802 hash_join_table_batches(
2803 current_data,
2804 next_data,
2805 &join_keys,
2806 llkv_join::JoinType::Inner,
2807 )?
2808 };
2809 current = Some(joined);
2810 } else {
2811 current = Some(next_data);
2812 }
2813
2814 used_tables.insert(next_index);
2815 remaining.remove(position);
2816 }
2817 }
2818
2819 if let Some(result) = current {
2820 let handled_all = constraint_plan.handled_conjuncts == constraint_plan.total_conjuncts;
2821 tracing::debug!(
2822 "join_opt[{query_label}]: hash join succeeded across {} tables (handled {}/{} predicates)",
2823 tables_with_handles.len(),
2824 constraint_plan.handled_conjuncts,
2825 constraint_plan.total_conjuncts
2826 );
2827 return Ok(Some((result, handled_all)));
2828 }
2829
2830 Ok(None)
2831 }
2832
2833 fn execute_projection(
2834 &self,
2835 table: Arc<ExecutorTable<P>>,
2836 display_name: String,
2837 plan: SelectPlan,
2838 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
2839 ) -> ExecutorResult<SelectExecution<P>> {
2840 if plan.having.is_some() {
2841 return Err(Error::InvalidArgumentError(
2842 "HAVING requires GROUP BY".into(),
2843 ));
2844 }
2845 if plan
2846 .filter
2847 .as_ref()
2848 .is_some_and(|filter| !filter.subqueries.is_empty())
2849 || !plan.scalar_subqueries.is_empty()
2850 {
2851 return self.execute_projection_with_subqueries(table, display_name, plan, row_filter);
2852 }
2853
2854 let table_ref = table.as_ref();
2855 let constant_filter = plan
2856 .filter
2857 .as_ref()
2858 .and_then(|filter| evaluate_constant_predicate(&filter.predicate));
2859 let projections = if plan.projections.is_empty() {
2860 build_wildcard_projections(table_ref)
2861 } else {
2862 build_projected_columns(table_ref, &plan.projections)?
2863 };
2864 let schema = schema_for_projections(table_ref, &projections)?;
2865
2866 if let Some(result) = constant_filter {
2867 match result {
2868 Some(true) => {
2869 }
2871 Some(false) | None => {
2872 let batch = RecordBatch::new_empty(Arc::clone(&schema));
2873 return Ok(SelectExecution::new_single_batch(
2874 display_name,
2875 schema,
2876 batch,
2877 ));
2878 }
2879 }
2880 }
2881
2882 let (mut filter_expr, mut full_table_scan) = match &plan.filter {
2883 Some(filter_wrapper) => (
2884 crate::translation::expression::translate_predicate(
2885 filter_wrapper.predicate.clone(),
2886 table_ref.schema.as_ref(),
2887 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
2888 )?,
2889 false,
2890 ),
2891 None => {
2892 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
2893 Error::InvalidArgumentError(
2894 "table has no columns; cannot perform wildcard scan".into(),
2895 )
2896 })?;
2897 (
2898 crate::translation::expression::full_table_scan_filter(field_id),
2899 true,
2900 )
2901 }
2902 };
2903
2904 if matches!(constant_filter, Some(Some(true))) {
2905 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
2906 Error::InvalidArgumentError(
2907 "table has no columns; cannot perform wildcard scan".into(),
2908 )
2909 })?;
2910 filter_expr = crate::translation::expression::full_table_scan_filter(field_id);
2911 full_table_scan = true;
2912 }
2913
2914 let expanded_order = expand_order_targets(&plan.order_by, &projections)?;
2915
2916 let mut physical_order: Option<ScanOrderSpec> = None;
2917
2918 if let Some(first) = expanded_order.first() {
2919 match &first.target {
2920 OrderTarget::Column(name) => {
2921 if table_ref.schema.resolve(name).is_some() {
2922 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
2923 }
2924 }
2925 OrderTarget::Index(position) => match projections.get(*position) {
2926 Some(ScanProjection::Column(_)) => {
2927 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
2928 }
2929 Some(ScanProjection::Computed { .. }) => {}
2930 None => {
2931 return Err(Error::InvalidArgumentError(format!(
2932 "ORDER BY position {} is out of range",
2933 position + 1
2934 )));
2935 }
2936 },
2937 OrderTarget::All => {}
2938 }
2939 }
2940
2941 let options = if let Some(order_spec) = physical_order {
2942 if row_filter.is_some() {
2943 tracing::debug!("Applying MVCC row filter with ORDER BY");
2944 }
2945 ScanStreamOptions {
2946 include_nulls: true,
2947 order: Some(order_spec),
2948 row_id_filter: row_filter.clone(),
2949 }
2950 } else {
2951 if row_filter.is_some() {
2952 tracing::debug!("Applying MVCC row filter");
2953 }
2954 ScanStreamOptions {
2955 include_nulls: true,
2956 order: None,
2957 row_id_filter: row_filter.clone(),
2958 }
2959 };
2960
2961 Ok(SelectExecution::new_projection(
2962 display_name,
2963 schema,
2964 table,
2965 projections,
2966 filter_expr,
2967 options,
2968 full_table_scan,
2969 expanded_order,
2970 plan.distinct,
2971 ))
2972 }
2973
2974 fn execute_projection_with_subqueries(
2975 &self,
2976 table: Arc<ExecutorTable<P>>,
2977 display_name: String,
2978 plan: SelectPlan,
2979 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
2980 ) -> ExecutorResult<SelectExecution<P>> {
2981 if plan.having.is_some() {
2982 return Err(Error::InvalidArgumentError(
2983 "HAVING requires GROUP BY".into(),
2984 ));
2985 }
2986 let table_ref = table.as_ref();
2987
2988 let (output_scan_projections, effective_projections): (
2989 Vec<ScanProjection>,
2990 Vec<SelectProjection>,
2991 ) = if plan.projections.is_empty() {
2992 (
2993 build_wildcard_projections(table_ref),
2994 vec![SelectProjection::AllColumns],
2995 )
2996 } else {
2997 (
2998 build_projected_columns(table_ref, &plan.projections)?,
2999 plan.projections.clone(),
3000 )
3001 };
3002
3003 let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
3004 .scalar_subqueries
3005 .iter()
3006 .map(|subquery| (subquery.id, subquery))
3007 .collect();
3008
3009 let base_projections = build_wildcard_projections(table_ref);
3010
3011 let filter_wrapper_opt = plan.filter.as_ref();
3012
3013 let mut translated_filter: Option<llkv_expr::expr::Expr<'static, FieldId>> = None;
3014 let pushdown_filter = if let Some(filter_wrapper) = filter_wrapper_opt {
3015 let translated = crate::translation::expression::translate_predicate(
3016 filter_wrapper.predicate.clone(),
3017 table_ref.schema.as_ref(),
3018 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
3019 )?;
3020 if !filter_wrapper.subqueries.is_empty() {
3021 translated_filter = Some(translated.clone());
3022 strip_exists(&translated)
3023 } else {
3024 translated
3025 }
3026 } else {
3027 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
3028 Error::InvalidArgumentError(
3029 "table has no columns; cannot perform scalar subquery projection".into(),
3030 )
3031 })?;
3032 crate::translation::expression::full_table_scan_filter(field_id)
3033 };
3034
3035 let mut base_fields: Vec<Field> = Vec::with_capacity(table_ref.schema.columns.len());
3036 for column in &table_ref.schema.columns {
3037 base_fields.push(Field::new(
3038 column.name.clone(),
3039 column.data_type.clone(),
3040 column.nullable,
3041 ));
3042 }
3043 let base_schema = Arc::new(Schema::new(base_fields));
3044 let base_column_counts = vec![base_schema.fields().len()];
3045 let base_table_indices = vec![0usize];
3046 let base_lookup = build_cross_product_column_lookup(
3047 base_schema.as_ref(),
3048 &plan.tables,
3049 &base_column_counts,
3050 &base_table_indices,
3051 );
3052
3053 let mut filter_context = if translated_filter.is_some() {
3054 Some(CrossProductExpressionContext::new(
3055 base_schema.as_ref(),
3056 base_lookup.clone(),
3057 )?)
3058 } else {
3059 None
3060 };
3061
3062 let options = ScanStreamOptions {
3063 include_nulls: true,
3064 order: None,
3065 row_id_filter: row_filter.clone(),
3066 };
3067
3068 let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
3069 filter_wrapper_opt
3070 .map(|wrapper| {
3071 wrapper
3072 .subqueries
3073 .iter()
3074 .map(|subquery| (subquery.id, subquery))
3075 .collect()
3076 })
3077 .unwrap_or_default();
3078
3079 let mut projected_batches: Vec<RecordBatch> = Vec::new();
3080 let mut scan_error: Option<Error> = None;
3081
3082 table.table.scan_stream(
3083 base_projections.clone(),
3084 &pushdown_filter,
3085 options,
3086 |batch| {
3087 if scan_error.is_some() {
3088 return;
3089 }
3090 let effective_batch = if let Some(context) = filter_context.as_mut() {
3091 context.reset();
3092 let translated = translated_filter
3093 .as_ref()
3094 .expect("filter context requires translated filter");
3095 let mask = match context.evaluate_predicate_mask(
3096 translated,
3097 &batch,
3098 |ctx, subquery_expr, row_idx, current_batch| {
3099 let subquery =
3100 subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
3101 Error::Internal("missing correlated subquery metadata".into())
3102 })?;
3103 let exists = self.evaluate_exists_subquery(
3104 ctx,
3105 subquery,
3106 current_batch,
3107 row_idx,
3108 )?;
3109 let value = if subquery_expr.negated {
3110 !exists
3111 } else {
3112 exists
3113 };
3114 Ok(Some(value))
3115 },
3116 ) {
3117 Ok(mask) => mask,
3118 Err(err) => {
3119 scan_error = Some(err);
3120 return;
3121 }
3122 };
3123 match filter_record_batch(&batch, &mask) {
3124 Ok(filtered) => {
3125 if filtered.num_rows() == 0 {
3126 return;
3127 }
3128 filtered
3129 }
3130 Err(err) => {
3131 scan_error = Some(Error::InvalidArgumentError(format!(
3132 "failed to apply EXISTS filter: {err}"
3133 )));
3134 return;
3135 }
3136 }
3137 } else {
3138 batch.clone()
3139 };
3140
3141 if effective_batch.num_rows() == 0 {
3142 return;
3143 }
3144
3145 let projected = match self.project_record_batch(
3146 &effective_batch,
3147 &effective_projections,
3148 &base_lookup,
3149 &scalar_lookup,
3150 ) {
3151 Ok(batch) => batch,
3152 Err(err) => {
3153 scan_error = Some(Error::InvalidArgumentError(format!(
3154 "failed to evaluate projections: {err}"
3155 )));
3156 return;
3157 }
3158 };
3159 projected_batches.push(projected);
3160 },
3161 )?;
3162
3163 if let Some(err) = scan_error {
3164 return Err(err);
3165 }
3166
3167 let mut result_batch = if projected_batches.is_empty() {
3168 let empty_batch = RecordBatch::new_empty(Arc::clone(&base_schema));
3169 self.project_record_batch(
3170 &empty_batch,
3171 &effective_projections,
3172 &base_lookup,
3173 &scalar_lookup,
3174 )?
3175 } else if projected_batches.len() == 1 {
3176 projected_batches.pop().unwrap()
3177 } else {
3178 let schema = projected_batches[0].schema();
3179 concat_batches(&schema, &projected_batches).map_err(|err| {
3180 Error::Internal(format!("failed to combine filtered batches: {err}"))
3181 })?
3182 };
3183
3184 if plan.distinct && result_batch.num_rows() > 0 {
3185 let mut state = DistinctState::default();
3186 let schema = result_batch.schema();
3187 result_batch = match distinct_filter_batch(result_batch, &mut state)? {
3188 Some(filtered) => filtered,
3189 None => RecordBatch::new_empty(schema),
3190 };
3191 }
3192
3193 if !plan.order_by.is_empty() && result_batch.num_rows() > 0 {
3194 let expanded_order = expand_order_targets(&plan.order_by, &output_scan_projections)?;
3195 if !expanded_order.is_empty() {
3196 result_batch = sort_record_batch_with_order(
3197 &result_batch.schema(),
3198 &result_batch,
3199 &expanded_order,
3200 )?;
3201 }
3202 }
3203
3204 let schema = result_batch.schema();
3205
3206 Ok(SelectExecution::new_single_batch(
3207 display_name,
3208 schema,
3209 result_batch,
3210 ))
3211 }
3212
3213 fn execute_group_by_single_table(
3214 &self,
3215 table: Arc<ExecutorTable<P>>,
3216 display_name: String,
3217 plan: SelectPlan,
3218 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
3219 ) -> ExecutorResult<SelectExecution<P>> {
3220 if plan
3221 .filter
3222 .as_ref()
3223 .is_some_and(|filter| !filter.subqueries.is_empty())
3224 || !plan.scalar_subqueries.is_empty()
3225 {
3226 return Err(Error::InvalidArgumentError(
3227 "GROUP BY with subqueries is not supported yet".into(),
3228 ));
3229 }
3230
3231 tracing::debug!(
3233 "[GROUP BY] Original plan: projections={}, aggregates={}, has_filter={}, has_having={}",
3234 plan.projections.len(),
3235 plan.aggregates.len(),
3236 plan.filter.is_some(),
3237 plan.having.is_some()
3238 );
3239
3240 let mut base_plan = plan.clone();
3244 base_plan.projections.clear();
3245 base_plan.aggregates.clear();
3246 base_plan.scalar_subqueries.clear();
3247 base_plan.order_by.clear();
3248 base_plan.distinct = false;
3249 base_plan.group_by.clear();
3250 base_plan.value_table_mode = None;
3251 base_plan.having = None;
3252
3253 tracing::debug!(
3254 "[GROUP BY] Base plan: projections={}, aggregates={}, has_filter={}, has_having={}",
3255 base_plan.projections.len(),
3256 base_plan.aggregates.len(),
3257 base_plan.filter.is_some(),
3258 base_plan.having.is_some()
3259 );
3260
3261 let table_ref = table.as_ref();
3264 let projections = build_wildcard_projections(table_ref);
3265 let base_schema = schema_for_projections(table_ref, &projections)?;
3266
3267 tracing::debug!(
3269 "[GROUP BY] Building base filter: has_filter={}",
3270 base_plan.filter.is_some()
3271 );
3272 let (filter_expr, full_table_scan) = match &base_plan.filter {
3273 Some(filter_wrapper) => {
3274 tracing::debug!(
3275 "[GROUP BY] Translating filter predicate: {:?}",
3276 filter_wrapper.predicate
3277 );
3278 let expr = crate::translation::expression::translate_predicate(
3279 filter_wrapper.predicate.clone(),
3280 table_ref.schema.as_ref(),
3281 |name| {
3282 Error::InvalidArgumentError(format!(
3283 "Binder Error: does not have a column named '{}'",
3284 name
3285 ))
3286 },
3287 )?;
3288 tracing::debug!("[GROUP BY] Translated filter expr: {:?}", expr);
3289 (expr, false)
3290 }
3291 None => {
3292 let first_col =
3294 table_ref.schema.columns.first().ok_or_else(|| {
3295 Error::InvalidArgumentError("Table has no columns".into())
3296 })?;
3297 (full_table_scan_filter(first_col.field_id), true)
3298 }
3299 };
3300
3301 let options = ScanStreamOptions {
3302 include_nulls: true,
3303 order: None,
3304 row_id_filter: row_filter.clone(),
3305 };
3306
3307 let execution = SelectExecution::new_projection(
3308 display_name.clone(),
3309 Arc::clone(&base_schema),
3310 Arc::clone(&table),
3311 projections,
3312 filter_expr,
3313 options,
3314 full_table_scan,
3315 vec![],
3316 false,
3317 );
3318
3319 let batches = execution.collect()?;
3320
3321 let column_lookup_map = build_column_lookup_map(base_schema.as_ref());
3322
3323 self.execute_group_by_from_batches(
3324 display_name,
3325 plan,
3326 base_schema,
3327 batches,
3328 column_lookup_map,
3329 )
3330 }
3331
3332 fn execute_group_by_from_batches(
3333 &self,
3334 display_name: String,
3335 plan: SelectPlan,
3336 base_schema: Arc<Schema>,
3337 batches: Vec<RecordBatch>,
3338 column_lookup_map: FxHashMap<String, usize>,
3339 ) -> ExecutorResult<SelectExecution<P>> {
3340 if plan
3341 .filter
3342 .as_ref()
3343 .is_some_and(|filter| !filter.subqueries.is_empty())
3344 || !plan.scalar_subqueries.is_empty()
3345 {
3346 return Err(Error::InvalidArgumentError(
3347 "GROUP BY with subqueries is not supported yet".into(),
3348 ));
3349 }
3350
3351 let having_has_aggregates = plan
3354 .having
3355 .as_ref()
3356 .map(|h| Self::predicate_contains_aggregate(h))
3357 .unwrap_or(false);
3358
3359 tracing::debug!(
3360 "[GROUP BY PATH] aggregates={}, has_computed={}, having_has_agg={}",
3361 plan.aggregates.len(),
3362 self.has_computed_aggregates(&plan),
3363 having_has_aggregates
3364 );
3365
3366 if !plan.aggregates.is_empty()
3367 || self.has_computed_aggregates(&plan)
3368 || having_has_aggregates
3369 {
3370 tracing::debug!("[GROUP BY PATH] Taking aggregates path");
3371 return self.execute_group_by_with_aggregates(
3372 display_name,
3373 plan,
3374 base_schema,
3375 batches,
3376 column_lookup_map,
3377 );
3378 }
3379
3380 let mut key_indices = Vec::with_capacity(plan.group_by.len());
3381 for column in &plan.group_by {
3382 let key = column.to_ascii_lowercase();
3383 let index = column_lookup_map.get(&key).ok_or_else(|| {
3384 Error::InvalidArgumentError(format!(
3385 "column '{}' not found in GROUP BY input",
3386 column
3387 ))
3388 })?;
3389 key_indices.push(*index);
3390 }
3391
3392 let sample_batch = batches
3393 .first()
3394 .cloned()
3395 .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(&base_schema)));
3396
3397 let output_columns = self.build_group_by_output_columns(
3398 &plan,
3399 base_schema.as_ref(),
3400 &column_lookup_map,
3401 &sample_batch,
3402 )?;
3403
3404 let constant_having = plan.having.as_ref().and_then(evaluate_constant_predicate);
3405
3406 if let Some(result) = constant_having
3407 && !result.unwrap_or(false)
3408 {
3409 let fields: Vec<Field> = output_columns
3410 .iter()
3411 .map(|output| output.field.clone())
3412 .collect();
3413 let schema = Arc::new(Schema::new(fields));
3414 let batch = RecordBatch::new_empty(Arc::clone(&schema));
3415 return Ok(SelectExecution::new_single_batch(
3416 display_name,
3417 schema,
3418 batch,
3419 ));
3420 }
3421
3422 let translated_having = if plan.having.is_some() && constant_having.is_none() {
3423 let having = plan.having.clone().expect("checked above");
3424 if Self::predicate_contains_aggregate(&having) {
3427 None
3428 } else {
3429 let temp_context = CrossProductExpressionContext::new(
3430 base_schema.as_ref(),
3431 column_lookup_map.clone(),
3432 )?;
3433 Some(translate_predicate(
3434 having,
3435 temp_context.schema(),
3436 |name| {
3437 Error::InvalidArgumentError(format!(
3438 "column '{}' not found in GROUP BY result",
3439 name
3440 ))
3441 },
3442 )?)
3443 }
3444 } else {
3445 None
3446 };
3447
3448 let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
3449 let mut groups: Vec<GroupState> = Vec::new();
3450
3451 for batch in &batches {
3452 for row_idx in 0..batch.num_rows() {
3453 let key = build_group_key(batch, row_idx, &key_indices)?;
3454 if group_index.contains_key(&key) {
3455 continue;
3456 }
3457 group_index.insert(key, groups.len());
3458 groups.push(GroupState {
3459 batch: batch.clone(),
3460 row_idx,
3461 });
3462 }
3463 }
3464
3465 let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(groups.len());
3466
3467 for group in &groups {
3468 if let Some(predicate) = translated_having.as_ref() {
3469 let mut context = CrossProductExpressionContext::new(
3470 group.batch.schema().as_ref(),
3471 column_lookup_map.clone(),
3472 )?;
3473 context.reset();
3474 let mut eval = |_ctx: &mut CrossProductExpressionContext,
3475 _subquery_expr: &llkv_expr::SubqueryExpr,
3476 _row_idx: usize,
3477 _current_batch: &RecordBatch|
3478 -> ExecutorResult<Option<bool>> {
3479 Err(Error::InvalidArgumentError(
3480 "HAVING subqueries are not supported yet".into(),
3481 ))
3482 };
3483 let truths =
3484 context.evaluate_predicate_truths(predicate, &group.batch, &mut eval)?;
3485 let passes = truths
3486 .get(group.row_idx)
3487 .copied()
3488 .flatten()
3489 .unwrap_or(false);
3490 if !passes {
3491 continue;
3492 }
3493 }
3494
3495 let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
3496 for output in &output_columns {
3497 match output.source {
3498 OutputSource::TableColumn { index } => {
3499 let value = llkv_plan::plan_value_from_array(
3500 group.batch.column(index),
3501 group.row_idx,
3502 )?;
3503 row.push(value);
3504 }
3505 OutputSource::Computed { projection_index } => {
3506 let expr = match &plan.projections[projection_index] {
3507 SelectProjection::Computed { expr, .. } => expr,
3508 _ => unreachable!("projection index mismatch for computed column"),
3509 };
3510 let mut context = CrossProductExpressionContext::new(
3511 group.batch.schema().as_ref(),
3512 column_lookup_map.clone(),
3513 )?;
3514 context.reset();
3515 let evaluated = self.evaluate_projection_expression(
3516 &mut context,
3517 expr,
3518 &group.batch,
3519 &FxHashMap::default(),
3520 )?;
3521 let value = llkv_plan::plan_value_from_array(&evaluated, group.row_idx)?;
3522 row.push(value);
3523 }
3524 }
3525 }
3526 rows.push(row);
3527 }
3528
3529 let fields: Vec<Field> = output_columns
3530 .into_iter()
3531 .map(|output| output.field)
3532 .collect();
3533 let schema = Arc::new(Schema::new(fields));
3534
3535 let mut batch = rows_to_record_batch(Arc::clone(&schema), &rows)?;
3536
3537 if plan.distinct && batch.num_rows() > 0 {
3538 let mut state = DistinctState::default();
3539 batch = match distinct_filter_batch(batch, &mut state)? {
3540 Some(filtered) => filtered,
3541 None => RecordBatch::new_empty(Arc::clone(&schema)),
3542 };
3543 }
3544
3545 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
3546 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
3547 }
3548
3549 Ok(SelectExecution::new_single_batch(
3550 display_name,
3551 schema,
3552 batch,
3553 ))
3554 }
3555
3556 fn build_group_by_output_columns(
3557 &self,
3558 plan: &SelectPlan,
3559 base_schema: &Schema,
3560 column_lookup_map: &FxHashMap<String, usize>,
3561 _sample_batch: &RecordBatch,
3562 ) -> ExecutorResult<Vec<OutputColumn>> {
3563 let projections = if plan.projections.is_empty() {
3564 vec![SelectProjection::AllColumns]
3565 } else {
3566 plan.projections.clone()
3567 };
3568
3569 let mut columns: Vec<OutputColumn> = Vec::new();
3570
3571 for (proj_idx, projection) in projections.iter().enumerate() {
3572 match projection {
3573 SelectProjection::AllColumns => {
3574 for (index, field) in base_schema.fields().iter().enumerate() {
3575 columns.push(OutputColumn {
3576 field: (**field).clone(),
3577 source: OutputSource::TableColumn { index },
3578 });
3579 }
3580 }
3581 SelectProjection::AllColumnsExcept { exclude } => {
3582 let exclude_lower: FxHashSet<String> = exclude
3583 .iter()
3584 .map(|name| name.to_ascii_lowercase())
3585 .collect();
3586 for (index, field) in base_schema.fields().iter().enumerate() {
3587 if !exclude_lower.contains(&field.name().to_ascii_lowercase()) {
3588 columns.push(OutputColumn {
3589 field: (**field).clone(),
3590 source: OutputSource::TableColumn { index },
3591 });
3592 }
3593 }
3594 }
3595 SelectProjection::Column { name, alias } => {
3596 let lookup_key = name.to_ascii_lowercase();
3597 let index = column_lookup_map.get(&lookup_key).ok_or_else(|| {
3598 Error::InvalidArgumentError(format!(
3599 "column '{}' not found in GROUP BY result",
3600 name
3601 ))
3602 })?;
3603 let field = base_schema.field(*index);
3604 let field = Field::new(
3605 alias.as_ref().unwrap_or(name).clone(),
3606 field.data_type().clone(),
3607 field.is_nullable(),
3608 );
3609 columns.push(OutputColumn {
3610 field,
3611 source: OutputSource::TableColumn { index: *index },
3612 });
3613 }
3614 SelectProjection::Computed { expr: _, alias } => {
3615 let field = Field::new(alias.clone(), DataType::Float64, true);
3619 columns.push(OutputColumn {
3620 field,
3621 source: OutputSource::Computed {
3622 projection_index: proj_idx,
3623 },
3624 });
3625 }
3626 }
3627 }
3628
3629 if columns.is_empty() {
3630 for (index, field) in base_schema.fields().iter().enumerate() {
3631 columns.push(OutputColumn {
3632 field: (**field).clone(),
3633 source: OutputSource::TableColumn { index },
3634 });
3635 }
3636 }
3637
3638 Ok(columns)
3639 }
3640
3641 fn project_record_batch(
3642 &self,
3643 batch: &RecordBatch,
3644 projections: &[SelectProjection],
3645 lookup: &FxHashMap<String, usize>,
3646 scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
3647 ) -> ExecutorResult<RecordBatch> {
3648 if projections.is_empty() {
3649 return Ok(batch.clone());
3650 }
3651
3652 let schema = batch.schema();
3653 let mut selected_fields: Vec<Arc<Field>> = Vec::new();
3654 let mut selected_columns: Vec<ArrayRef> = Vec::new();
3655 let mut expr_context: Option<CrossProductExpressionContext> = None;
3656
3657 for proj in projections {
3658 match proj {
3659 SelectProjection::AllColumns => {
3660 selected_fields = schema.fields().iter().cloned().collect();
3661 selected_columns = batch.columns().to_vec();
3662 break;
3663 }
3664 SelectProjection::AllColumnsExcept { exclude } => {
3665 let exclude_lower: FxHashSet<String> = exclude
3666 .iter()
3667 .map(|name| name.to_ascii_lowercase())
3668 .collect();
3669 for (idx, field) in schema.fields().iter().enumerate() {
3670 let column_name = field.name().to_ascii_lowercase();
3671 if !exclude_lower.contains(&column_name) {
3672 selected_fields.push(Arc::clone(field));
3673 selected_columns.push(batch.column(idx).clone());
3674 }
3675 }
3676 break;
3677 }
3678 SelectProjection::Column { name, alias } => {
3679 let normalized = name.to_ascii_lowercase();
3680 let column_index = lookup.get(&normalized).ok_or_else(|| {
3681 Error::InvalidArgumentError(format!(
3682 "column '{}' not found in projection",
3683 name
3684 ))
3685 })?;
3686 let field = schema.field(*column_index);
3687 let output_field = Arc::new(Field::new(
3688 alias.as_ref().unwrap_or_else(|| field.name()),
3689 field.data_type().clone(),
3690 field.is_nullable(),
3691 ));
3692 selected_fields.push(output_field);
3693 selected_columns.push(batch.column(*column_index).clone());
3694 }
3695 SelectProjection::Computed { expr, alias } => {
3696 if expr_context.is_none() {
3697 expr_context = Some(CrossProductExpressionContext::new(
3698 schema.as_ref(),
3699 lookup.clone(),
3700 )?);
3701 }
3702 let context = expr_context
3703 .as_mut()
3704 .expect("projection context must be initialized");
3705 context.reset();
3706 let evaluated =
3707 self.evaluate_projection_expression(context, expr, batch, scalar_lookup)?;
3708 let field = Arc::new(Field::new(
3709 alias.clone(),
3710 evaluated.data_type().clone(),
3711 true,
3712 ));
3713 selected_fields.push(field);
3714 selected_columns.push(evaluated);
3715 }
3716 }
3717 }
3718
3719 let projected_schema = Arc::new(Schema::new(selected_fields));
3720 RecordBatch::try_new(projected_schema, selected_columns)
3721 .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))
3722 }
3723
3724 fn execute_group_by_with_aggregates(
3726 &self,
3727 display_name: String,
3728 plan: SelectPlan,
3729 base_schema: Arc<Schema>,
3730 batches: Vec<RecordBatch>,
3731 column_lookup_map: FxHashMap<String, usize>,
3732 ) -> ExecutorResult<SelectExecution<P>> {
3733 use llkv_expr::expr::AggregateCall;
3734
3735 let mut key_indices = Vec::with_capacity(plan.group_by.len());
3737 for column in &plan.group_by {
3738 let key = column.to_ascii_lowercase();
3739 let index = column_lookup_map.get(&key).ok_or_else(|| {
3740 Error::InvalidArgumentError(format!(
3741 "column '{}' not found in GROUP BY input",
3742 column
3743 ))
3744 })?;
3745 key_indices.push(*index);
3746 }
3747
3748 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
3750 for proj in &plan.projections {
3751 if let SelectProjection::Computed { expr, .. } = proj {
3752 Self::collect_aggregates(expr, &mut aggregate_specs);
3753 }
3754 }
3755
3756 if let Some(having_expr) = &plan.having {
3758 Self::collect_aggregates_from_predicate(having_expr, &mut aggregate_specs);
3759 }
3760
3761 let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
3763 let mut group_states: Vec<GroupAggregateState> = Vec::new();
3764
3765 for (batch_idx, batch) in batches.iter().enumerate() {
3767 for row_idx in 0..batch.num_rows() {
3768 let key = build_group_key(batch, row_idx, &key_indices)?;
3769
3770 if let Some(&group_idx) = group_index.get(&key) {
3771 group_states[group_idx]
3773 .row_locations
3774 .push((batch_idx, row_idx));
3775 } else {
3776 let group_idx = group_states.len();
3778 group_index.insert(key, group_idx);
3779 group_states.push(GroupAggregateState {
3780 representative_batch_idx: batch_idx,
3781 representative_row: row_idx,
3782 row_locations: vec![(batch_idx, row_idx)],
3783 });
3784 }
3785 }
3786 }
3787
3788 let mut group_aggregate_values: Vec<FxHashMap<String, PlanValue>> =
3790 Vec::with_capacity(group_states.len());
3791
3792 for group_state in &group_states {
3793 tracing::debug!(
3794 "[GROUP BY] aggregate group rows={:?}",
3795 group_state.row_locations
3796 );
3797 let group_batch = {
3799 let representative_batch = &batches[group_state.representative_batch_idx];
3800 let schema = representative_batch.schema();
3801
3802 let mut per_batch_indices: Vec<(usize, Vec<u64>)> = Vec::new();
3804 for &(batch_idx, row_idx) in &group_state.row_locations {
3805 if let Some((_, indices)) = per_batch_indices
3806 .iter_mut()
3807 .find(|(idx, _)| *idx == batch_idx)
3808 {
3809 indices.push(row_idx as u64);
3810 } else {
3811 per_batch_indices.push((batch_idx, vec![row_idx as u64]));
3812 }
3813 }
3814
3815 let mut row_index_arrays: Vec<(usize, ArrayRef)> =
3816 Vec::with_capacity(per_batch_indices.len());
3817 for (batch_idx, indices) in per_batch_indices {
3818 let index_array: ArrayRef = Arc::new(arrow::array::UInt64Array::from(indices));
3819 row_index_arrays.push((batch_idx, index_array));
3820 }
3821
3822 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
3823
3824 for col_idx in 0..schema.fields().len() {
3825 let column_array = if row_index_arrays.len() == 1 {
3826 let (batch_idx, indices) = &row_index_arrays[0];
3827 let source_array = batches[*batch_idx].column(col_idx);
3828 arrow::compute::take(source_array.as_ref(), indices.as_ref(), None)?
3829 } else {
3830 let mut partial_arrays: Vec<ArrayRef> =
3831 Vec::with_capacity(row_index_arrays.len());
3832 for (batch_idx, indices) in &row_index_arrays {
3833 let source_array = batches[*batch_idx].column(col_idx);
3834 let taken = arrow::compute::take(
3835 source_array.as_ref(),
3836 indices.as_ref(),
3837 None,
3838 )?;
3839 partial_arrays.push(taken);
3840 }
3841 let slices: Vec<&dyn arrow::array::Array> =
3842 partial_arrays.iter().map(|arr| arr.as_ref()).collect();
3843 arrow::compute::concat(&slices)?
3844 };
3845 arrays.push(column_array);
3846 }
3847
3848 let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
3849 tracing::debug!("[GROUP BY] group batch rows={}", batch.num_rows());
3850 batch
3851 };
3852
3853 let mut aggregate_values: FxHashMap<String, PlanValue> = FxHashMap::default();
3855
3856 let mut working_batch = group_batch.clone();
3858 let mut next_temp_col_idx = working_batch.num_columns();
3859
3860 for (key, agg_call) in &aggregate_specs {
3861 let (projection_idx, value_type) = match agg_call {
3863 AggregateCall::CountStar => (None, None),
3864 AggregateCall::Count { expr, .. }
3865 | AggregateCall::Sum { expr, .. }
3866 | AggregateCall::Total { expr, .. }
3867 | AggregateCall::Avg { expr, .. }
3868 | AggregateCall::Min(expr)
3869 | AggregateCall::Max(expr)
3870 | AggregateCall::CountNulls(expr)
3871 | AggregateCall::GroupConcat { expr, .. } => {
3872 if let Some(col_name) = try_extract_simple_column(expr) {
3873 let idx = resolve_column_name_to_index(col_name, &column_lookup_map)
3874 .ok_or_else(|| {
3875 Error::InvalidArgumentError(format!(
3876 "column '{}' not found for aggregate",
3877 col_name
3878 ))
3879 })?;
3880 let field_type = working_batch.schema().field(idx).data_type().clone();
3881 (Some(idx), Some(field_type))
3882 } else {
3883 let mut computed_values = Vec::with_capacity(working_batch.num_rows());
3885 for row_idx in 0..working_batch.num_rows() {
3886 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
3887 expr,
3888 &FxHashMap::default(),
3889 Some(&working_batch),
3890 Some(&column_lookup_map),
3891 row_idx,
3892 )?;
3893 computed_values.push(value);
3894 }
3895
3896 let computed_array = plan_values_to_arrow_array(&computed_values)?;
3897 let computed_type = computed_array.data_type().clone();
3898
3899 let mut new_columns: Vec<ArrayRef> = working_batch.columns().to_vec();
3900 new_columns.push(computed_array);
3901
3902 let temp_field = Arc::new(Field::new(
3903 format!("__temp_agg_expr_{}", next_temp_col_idx),
3904 computed_type.clone(),
3905 true,
3906 ));
3907 let mut new_fields: Vec<Arc<Field>> =
3908 working_batch.schema().fields().iter().cloned().collect();
3909 new_fields.push(temp_field);
3910 let new_schema = Arc::new(Schema::new(new_fields));
3911
3912 working_batch = RecordBatch::try_new(new_schema, new_columns)?;
3913
3914 let col_idx = next_temp_col_idx;
3915 next_temp_col_idx += 1;
3916 (Some(col_idx), Some(computed_type))
3917 }
3918 }
3919 };
3920
3921 let spec = Self::build_aggregate_spec_for_cross_product(
3923 agg_call,
3924 key.clone(),
3925 value_type.clone(),
3926 )?;
3927
3928 let mut state = llkv_aggregate::AggregateState {
3929 alias: key.clone(),
3930 accumulator: llkv_aggregate::AggregateAccumulator::new_with_projection_index(
3931 &spec,
3932 projection_idx,
3933 None,
3934 )?,
3935 override_value: None,
3936 };
3937
3938 state.update(&working_batch)?;
3940
3941 let (_field, array) = state.finalize()?;
3943 let value = llkv_plan::plan_value_from_array(&array, 0)?;
3944 tracing::debug!(
3945 "[GROUP BY] aggregate result key={:?} value={:?}",
3946 key,
3947 value
3948 );
3949 aggregate_values.insert(key.clone(), value);
3950 }
3951
3952 group_aggregate_values.push(aggregate_values);
3953 }
3954
3955 let output_columns = self.build_group_by_output_columns(
3957 &plan,
3958 base_schema.as_ref(),
3959 &column_lookup_map,
3960 batches
3961 .first()
3962 .unwrap_or(&RecordBatch::new_empty(Arc::clone(&base_schema))),
3963 )?;
3964
3965 let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(group_states.len());
3966
3967 for (group_idx, group_state) in group_states.iter().enumerate() {
3968 let aggregate_values = &group_aggregate_values[group_idx];
3969 let representative_batch = &batches[group_state.representative_batch_idx];
3970
3971 let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
3972 for output in &output_columns {
3973 match output.source {
3974 OutputSource::TableColumn { index } => {
3975 let value = llkv_plan::plan_value_from_array(
3977 representative_batch.column(index),
3978 group_state.representative_row,
3979 )?;
3980 row.push(value);
3981 }
3982 OutputSource::Computed { projection_index } => {
3983 let expr = match &plan.projections[projection_index] {
3984 SelectProjection::Computed { expr, .. } => expr,
3985 _ => unreachable!("projection index mismatch for computed column"),
3986 };
3987 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
3989 expr,
3990 aggregate_values,
3991 Some(representative_batch),
3992 Some(&column_lookup_map),
3993 group_state.representative_row,
3994 )?;
3995 row.push(value);
3996 }
3997 }
3998 }
3999 rows.push(row);
4000 }
4001
4002 let filtered_rows = if let Some(having) = &plan.having {
4004 let mut filtered = Vec::new();
4005 for (row_idx, row) in rows.iter().enumerate() {
4006 let aggregate_values = &group_aggregate_values[row_idx];
4007 let group_state = &group_states[row_idx];
4008 let representative_batch = &batches[group_state.representative_batch_idx];
4009 let passes = Self::evaluate_having_expr(
4011 having,
4012 aggregate_values,
4013 representative_batch,
4014 &column_lookup_map,
4015 group_state.representative_row,
4016 )?;
4017 if matches!(passes, Some(true)) {
4019 filtered.push(row.clone());
4020 }
4021 }
4022 filtered
4023 } else {
4024 rows
4025 };
4026
4027 let fields: Vec<Field> = output_columns
4028 .into_iter()
4029 .map(|output| output.field)
4030 .collect();
4031 let schema = Arc::new(Schema::new(fields));
4032
4033 let mut batch = rows_to_record_batch(Arc::clone(&schema), &filtered_rows)?;
4034
4035 if plan.distinct && batch.num_rows() > 0 {
4036 let mut state = DistinctState::default();
4037 batch = match distinct_filter_batch(batch, &mut state)? {
4038 Some(filtered) => filtered,
4039 None => RecordBatch::new_empty(Arc::clone(&schema)),
4040 };
4041 }
4042
4043 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
4044 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
4045 }
4046
4047 Ok(SelectExecution::new_single_batch(
4048 display_name,
4049 schema,
4050 batch,
4051 ))
4052 }
4053
4054 fn execute_aggregates(
4055 &self,
4056 table: Arc<ExecutorTable<P>>,
4057 display_name: String,
4058 plan: SelectPlan,
4059 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4060 ) -> ExecutorResult<SelectExecution<P>> {
4061 let table_ref = table.as_ref();
4062 let distinct = plan.distinct;
4063 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
4064 for aggregate in plan.aggregates {
4065 match aggregate {
4066 AggregateExpr::CountStar { alias, distinct } => {
4067 specs.push(AggregateSpec {
4068 alias,
4069 kind: AggregateKind::Count {
4070 field_id: None,
4071 distinct,
4072 },
4073 });
4074 }
4075 AggregateExpr::Column {
4076 column,
4077 alias,
4078 function,
4079 distinct,
4080 } => {
4081 let col = table_ref.schema.resolve(&column).ok_or_else(|| {
4082 Error::InvalidArgumentError(format!(
4083 "unknown column '{}' in aggregate",
4084 column
4085 ))
4086 })?;
4087
4088 let kind = match function {
4089 AggregateFunction::Count => AggregateKind::Count {
4090 field_id: Some(col.field_id),
4091 distinct,
4092 },
4093 AggregateFunction::SumInt64 => {
4094 let input_type = Self::validate_aggregate_type(
4095 Some(col.data_type.clone()),
4096 "SUM",
4097 &[DataType::Int64, DataType::Float64],
4098 )?;
4099 AggregateKind::Sum {
4100 field_id: col.field_id,
4101 data_type: input_type,
4102 distinct,
4103 }
4104 }
4105 AggregateFunction::TotalInt64 => {
4106 let input_type = Self::validate_aggregate_type(
4107 Some(col.data_type.clone()),
4108 "TOTAL",
4109 &[DataType::Int64, DataType::Float64],
4110 )?;
4111 AggregateKind::Total {
4112 field_id: col.field_id,
4113 data_type: input_type,
4114 distinct,
4115 }
4116 }
4117 AggregateFunction::MinInt64 => {
4118 let input_type = Self::validate_aggregate_type(
4119 Some(col.data_type.clone()),
4120 "MIN",
4121 &[DataType::Int64, DataType::Float64],
4122 )?;
4123 AggregateKind::Min {
4124 field_id: col.field_id,
4125 data_type: input_type,
4126 }
4127 }
4128 AggregateFunction::MaxInt64 => {
4129 let input_type = Self::validate_aggregate_type(
4130 Some(col.data_type.clone()),
4131 "MAX",
4132 &[DataType::Int64, DataType::Float64],
4133 )?;
4134 AggregateKind::Max {
4135 field_id: col.field_id,
4136 data_type: input_type,
4137 }
4138 }
4139 AggregateFunction::CountNulls => {
4140 if distinct {
4141 return Err(Error::InvalidArgumentError(
4142 "DISTINCT is not supported for COUNT_NULLS".into(),
4143 ));
4144 }
4145 AggregateKind::CountNulls {
4146 field_id: col.field_id,
4147 }
4148 }
4149 AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
4150 field_id: col.field_id,
4151 distinct,
4152 separator: ",".to_string(),
4153 },
4154 };
4155 specs.push(AggregateSpec { alias, kind });
4156 }
4157 }
4158 }
4159
4160 if specs.is_empty() {
4161 return Err(Error::InvalidArgumentError(
4162 "aggregate query requires at least one aggregate expression".into(),
4163 ));
4164 }
4165
4166 let had_filter = plan.filter.is_some();
4167 let filter_expr = match &plan.filter {
4168 Some(filter_wrapper) => {
4169 if !filter_wrapper.subqueries.is_empty() {
4170 return Err(Error::InvalidArgumentError(
4171 "EXISTS subqueries not yet implemented in aggregate queries".into(),
4172 ));
4173 }
4174 crate::translation::expression::translate_predicate(
4175 filter_wrapper.predicate.clone(),
4176 table.schema.as_ref(),
4177 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
4178 )?
4179 }
4180 None => {
4181 let field_id = table.schema.first_field_id().ok_or_else(|| {
4182 Error::InvalidArgumentError(
4183 "table has no columns; cannot perform aggregate scan".into(),
4184 )
4185 })?;
4186 crate::translation::expression::full_table_scan_filter(field_id)
4187 }
4188 };
4189
4190 let mut projections = Vec::new();
4192 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(specs.len());
4193
4194 for spec in &specs {
4195 if let Some(field_id) = spec.kind.field_id() {
4196 let proj_idx = projections.len();
4197 spec_to_projection.push(Some(proj_idx));
4198 projections.push(ScanProjection::from(StoreProjection::with_alias(
4199 LogicalFieldId::for_user(table.table.table_id(), field_id),
4200 table
4201 .schema
4202 .column_by_field_id(field_id)
4203 .map(|c| c.name.clone())
4204 .unwrap_or_else(|| format!("col{field_id}")),
4205 )));
4206 } else {
4207 spec_to_projection.push(None);
4208 }
4209 }
4210
4211 if projections.is_empty() {
4212 let field_id = table.schema.first_field_id().ok_or_else(|| {
4213 Error::InvalidArgumentError(
4214 "table has no columns; cannot perform aggregate scan".into(),
4215 )
4216 })?;
4217 projections.push(ScanProjection::from(StoreProjection::with_alias(
4218 LogicalFieldId::for_user(table.table.table_id(), field_id),
4219 table
4220 .schema
4221 .column_by_field_id(field_id)
4222 .map(|c| c.name.clone())
4223 .unwrap_or_else(|| format!("col{field_id}")),
4224 )));
4225 }
4226
4227 let options = ScanStreamOptions {
4228 include_nulls: true,
4229 order: None,
4230 row_id_filter: row_filter.clone(),
4231 };
4232
4233 let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
4234 let mut count_star_override: Option<i64> = None;
4238 if !had_filter && row_filter.is_none() {
4239 let total_rows = table.total_rows.load(Ordering::SeqCst);
4241 tracing::debug!(
4242 "[AGGREGATE] Using COUNT(*) shortcut: total_rows={}",
4243 total_rows
4244 );
4245 if total_rows > i64::MAX as u64 {
4246 return Err(Error::InvalidArgumentError(
4247 "COUNT(*) result exceeds supported range".into(),
4248 ));
4249 }
4250 count_star_override = Some(total_rows as i64);
4251 } else {
4252 tracing::debug!(
4253 "[AGGREGATE] NOT using COUNT(*) shortcut: had_filter={}, has_row_filter={}",
4254 had_filter,
4255 row_filter.is_some()
4256 );
4257 }
4258
4259 for (idx, spec) in specs.iter().enumerate() {
4260 states.push(AggregateState {
4261 alias: spec.alias.clone(),
4262 accumulator: AggregateAccumulator::new_with_projection_index(
4263 spec,
4264 spec_to_projection[idx],
4265 count_star_override,
4266 )?,
4267 override_value: match &spec.kind {
4268 AggregateKind::Count { field_id: None, .. } => {
4269 tracing::debug!(
4270 "[AGGREGATE] CountStar override_value={:?}",
4271 count_star_override
4272 );
4273 count_star_override
4274 }
4275 _ => None,
4276 },
4277 });
4278 }
4279
4280 let mut error: Option<Error> = None;
4281 match table.table.scan_stream(
4282 projections,
4283 &filter_expr,
4284 ScanStreamOptions {
4285 row_id_filter: row_filter.clone(),
4286 ..options
4287 },
4288 |batch| {
4289 if error.is_some() {
4290 return;
4291 }
4292 for state in &mut states {
4293 if let Err(err) = state.update(&batch) {
4294 error = Some(err);
4295 return;
4296 }
4297 }
4298 },
4299 ) {
4300 Ok(()) => {}
4301 Err(llkv_result::Error::NotFound) => {
4302 }
4305 Err(err) => return Err(err),
4306 }
4307 if let Some(err) = error {
4308 return Err(err);
4309 }
4310
4311 let mut fields = Vec::with_capacity(states.len());
4312 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
4313 for state in states {
4314 let (field, array) = state.finalize()?;
4315 fields.push(field);
4316 arrays.push(array);
4317 }
4318
4319 let schema = Arc::new(Schema::new(fields));
4320 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
4321
4322 if distinct {
4323 let mut state = DistinctState::default();
4324 batch = match distinct_filter_batch(batch, &mut state)? {
4325 Some(filtered) => filtered,
4326 None => RecordBatch::new_empty(Arc::clone(&schema)),
4327 };
4328 }
4329
4330 let schema = batch.schema();
4331
4332 Ok(SelectExecution::new_single_batch(
4333 display_name,
4334 schema,
4335 batch,
4336 ))
4337 }
4338
4339 fn execute_computed_aggregates(
4342 &self,
4343 table: Arc<ExecutorTable<P>>,
4344 display_name: String,
4345 plan: SelectPlan,
4346 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4347 ) -> ExecutorResult<SelectExecution<P>> {
4348 use arrow::array::Int64Array;
4349 use llkv_expr::expr::AggregateCall;
4350
4351 let table_ref = table.as_ref();
4352 let distinct = plan.distinct;
4353
4354 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
4356 for proj in &plan.projections {
4357 if let SelectProjection::Computed { expr, .. } = proj {
4358 Self::collect_aggregates(expr, &mut aggregate_specs);
4359 }
4360 }
4361
4362 let filter_predicate = plan
4364 .filter
4365 .as_ref()
4366 .map(|wrapper| {
4367 if !wrapper.subqueries.is_empty() {
4368 return Err(Error::InvalidArgumentError(
4369 "EXISTS subqueries not yet implemented with aggregates".into(),
4370 ));
4371 }
4372 Ok(wrapper.predicate.clone())
4373 })
4374 .transpose()?;
4375
4376 let computed_aggregates = self.compute_aggregate_values(
4377 table.clone(),
4378 &filter_predicate,
4379 &aggregate_specs,
4380 row_filter.clone(),
4381 )?;
4382
4383 let mut fields = Vec::with_capacity(plan.projections.len());
4385 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
4386
4387 for proj in &plan.projections {
4388 match proj {
4389 SelectProjection::AllColumns | SelectProjection::AllColumnsExcept { .. } => {
4390 return Err(Error::InvalidArgumentError(
4391 "Wildcard projections not supported with computed aggregates".into(),
4392 ));
4393 }
4394 SelectProjection::Column { name, alias } => {
4395 let col = table_ref.schema.resolve(name).ok_or_else(|| {
4396 Error::InvalidArgumentError(format!("unknown column '{}'", name))
4397 })?;
4398 let field_name = alias.as_ref().unwrap_or(name);
4399 fields.push(arrow::datatypes::Field::new(
4400 field_name,
4401 col.data_type.clone(),
4402 col.nullable,
4403 ));
4404 return Err(Error::InvalidArgumentError(
4407 "Regular columns not supported in aggregate queries without GROUP BY"
4408 .into(),
4409 ));
4410 }
4411 SelectProjection::Computed { expr, alias } => {
4412 if let ScalarExpr::Aggregate(agg) = expr {
4414 let key = format!("{:?}", agg);
4415 if let Some(agg_value) = computed_aggregates.get(&key) {
4416 match agg_value {
4417 AggregateValue::Null => {
4418 fields.push(arrow::datatypes::Field::new(
4419 alias,
4420 DataType::Int64,
4421 true,
4422 ));
4423 arrays
4424 .push(Arc::new(Int64Array::from(vec![None::<i64>]))
4425 as ArrayRef);
4426 }
4427 AggregateValue::Int64(v) => {
4428 fields.push(arrow::datatypes::Field::new(
4429 alias,
4430 DataType::Int64,
4431 true,
4432 ));
4433 arrays.push(
4434 Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef
4435 );
4436 }
4437 AggregateValue::Float64(v) => {
4438 fields.push(arrow::datatypes::Field::new(
4439 alias,
4440 DataType::Float64,
4441 true,
4442 ));
4443 arrays
4444 .push(Arc::new(Float64Array::from(vec![Some(*v)]))
4445 as ArrayRef);
4446 }
4447 AggregateValue::String(s) => {
4448 fields.push(arrow::datatypes::Field::new(
4449 alias,
4450 DataType::Utf8,
4451 true,
4452 ));
4453 arrays
4454 .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
4455 as ArrayRef);
4456 }
4457 }
4458 continue;
4459 }
4460 }
4461
4462 let value = Self::evaluate_expr_with_aggregates(expr, &computed_aggregates)?;
4464
4465 fields.push(arrow::datatypes::Field::new(alias, DataType::Int64, true));
4466
4467 let array = Arc::new(Int64Array::from(vec![value])) as ArrayRef;
4468 arrays.push(array);
4469 }
4470 }
4471 }
4472
4473 let schema = Arc::new(Schema::new(fields));
4474 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
4475
4476 if distinct {
4477 let mut state = DistinctState::default();
4478 batch = match distinct_filter_batch(batch, &mut state)? {
4479 Some(filtered) => filtered,
4480 None => RecordBatch::new_empty(Arc::clone(&schema)),
4481 };
4482 }
4483
4484 let schema = batch.schema();
4485
4486 Ok(SelectExecution::new_single_batch(
4487 display_name,
4488 schema,
4489 batch,
4490 ))
4491 }
4492
4493 fn build_aggregate_spec_for_cross_product(
4496 agg_call: &llkv_expr::expr::AggregateCall<String>,
4497 alias: String,
4498 data_type: Option<DataType>,
4499 ) -> ExecutorResult<llkv_aggregate::AggregateSpec> {
4500 use llkv_expr::expr::AggregateCall;
4501
4502 let kind = match agg_call {
4503 AggregateCall::CountStar => llkv_aggregate::AggregateKind::Count {
4504 field_id: None,
4505 distinct: false,
4506 },
4507 AggregateCall::Count { distinct, .. } => llkv_aggregate::AggregateKind::Count {
4508 field_id: Some(0),
4509 distinct: *distinct,
4510 },
4511 AggregateCall::Sum { distinct, .. } => llkv_aggregate::AggregateKind::Sum {
4512 field_id: 0,
4513 data_type: Self::validate_aggregate_type(
4514 data_type.clone(),
4515 "SUM",
4516 &[DataType::Int64, DataType::Float64],
4517 )?,
4518 distinct: *distinct,
4519 },
4520 AggregateCall::Total { distinct, .. } => llkv_aggregate::AggregateKind::Total {
4521 field_id: 0,
4522 data_type: Self::validate_aggregate_type(
4523 data_type.clone(),
4524 "TOTAL",
4525 &[DataType::Int64, DataType::Float64],
4526 )?,
4527 distinct: *distinct,
4528 },
4529 AggregateCall::Avg { distinct, .. } => llkv_aggregate::AggregateKind::Avg {
4530 field_id: 0,
4531 data_type: Self::validate_aggregate_type(
4532 data_type.clone(),
4533 "AVG",
4534 &[DataType::Int64, DataType::Float64],
4535 )?,
4536 distinct: *distinct,
4537 },
4538 AggregateCall::Min(_) => llkv_aggregate::AggregateKind::Min {
4539 field_id: 0,
4540 data_type: Self::validate_aggregate_type(
4541 data_type.clone(),
4542 "MIN",
4543 &[DataType::Int64, DataType::Float64],
4544 )?,
4545 },
4546 AggregateCall::Max(_) => llkv_aggregate::AggregateKind::Max {
4547 field_id: 0,
4548 data_type: Self::validate_aggregate_type(
4549 data_type.clone(),
4550 "MAX",
4551 &[DataType::Int64, DataType::Float64],
4552 )?,
4553 },
4554 AggregateCall::CountNulls(_) => {
4555 llkv_aggregate::AggregateKind::CountNulls { field_id: 0 }
4556 }
4557 AggregateCall::GroupConcat {
4558 distinct,
4559 separator,
4560 ..
4561 } => llkv_aggregate::AggregateKind::GroupConcat {
4562 field_id: 0,
4563 distinct: *distinct,
4564 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
4565 },
4566 };
4567
4568 Ok(llkv_aggregate::AggregateSpec { alias, kind })
4569 }
4570
4571 fn validate_aggregate_type(
4586 data_type: Option<DataType>,
4587 func_name: &str,
4588 allowed: &[DataType],
4589 ) -> ExecutorResult<DataType> {
4590 let dt = data_type.ok_or_else(|| {
4591 Error::Internal(format!(
4592 "missing input type metadata for {func_name} aggregate"
4593 ))
4594 })?;
4595
4596 if matches!(func_name, "SUM" | "AVG" | "TOTAL" | "MIN" | "MAX") {
4599 match dt {
4600 DataType::Int64 | DataType::Float64 => Ok(dt),
4602
4603 DataType::Utf8 | DataType::Boolean | DataType::Date32 => Ok(DataType::Float64),
4606
4607 _ => Err(Error::InvalidArgumentError(format!(
4608 "{func_name} aggregate not supported for column type {:?}",
4609 dt
4610 ))),
4611 }
4612 } else {
4613 if allowed.iter().any(|candidate| candidate == &dt) {
4615 Ok(dt)
4616 } else {
4617 Err(Error::InvalidArgumentError(format!(
4618 "{func_name} aggregate not supported for column type {:?}",
4619 dt
4620 )))
4621 }
4622 }
4623 }
4624
4625 fn collect_aggregates(
4627 expr: &ScalarExpr<String>,
4628 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
4629 ) {
4630 match expr {
4631 ScalarExpr::Aggregate(agg) => {
4632 let key = format!("{:?}", agg);
4634 if !aggregates.iter().any(|(k, _)| k == &key) {
4635 aggregates.push((key, agg.clone()));
4636 }
4637 }
4638 ScalarExpr::Binary { left, right, .. } => {
4639 Self::collect_aggregates(left, aggregates);
4640 Self::collect_aggregates(right, aggregates);
4641 }
4642 ScalarExpr::Compare { left, right, .. } => {
4643 Self::collect_aggregates(left, aggregates);
4644 Self::collect_aggregates(right, aggregates);
4645 }
4646 ScalarExpr::GetField { base, .. } => {
4647 Self::collect_aggregates(base, aggregates);
4648 }
4649 ScalarExpr::Cast { expr, .. } => {
4650 Self::collect_aggregates(expr, aggregates);
4651 }
4652 ScalarExpr::Not(expr) => {
4653 Self::collect_aggregates(expr, aggregates);
4654 }
4655 ScalarExpr::IsNull { expr, .. } => {
4656 Self::collect_aggregates(expr, aggregates);
4657 }
4658 ScalarExpr::Case {
4659 operand,
4660 branches,
4661 else_expr,
4662 } => {
4663 if let Some(inner) = operand.as_deref() {
4664 Self::collect_aggregates(inner, aggregates);
4665 }
4666 for (when_expr, then_expr) in branches {
4667 Self::collect_aggregates(when_expr, aggregates);
4668 Self::collect_aggregates(then_expr, aggregates);
4669 }
4670 if let Some(inner) = else_expr.as_deref() {
4671 Self::collect_aggregates(inner, aggregates);
4672 }
4673 }
4674 ScalarExpr::Coalesce(items) => {
4675 for item in items {
4676 Self::collect_aggregates(item, aggregates);
4677 }
4678 }
4679 ScalarExpr::Column(_) | ScalarExpr::Literal(_) => {}
4680 ScalarExpr::ScalarSubquery(_) => {}
4681 }
4682 }
4683
4684 fn collect_aggregates_from_predicate(
4686 expr: &llkv_expr::expr::Expr<String>,
4687 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
4688 ) {
4689 match expr {
4690 llkv_expr::expr::Expr::Compare { left, right, .. } => {
4691 Self::collect_aggregates(left, aggregates);
4692 Self::collect_aggregates(right, aggregates);
4693 }
4694 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
4695 for e in exprs {
4696 Self::collect_aggregates_from_predicate(e, aggregates);
4697 }
4698 }
4699 llkv_expr::expr::Expr::Not(inner) => {
4700 Self::collect_aggregates_from_predicate(inner, aggregates);
4701 }
4702 llkv_expr::expr::Expr::InList {
4703 expr: test_expr,
4704 list,
4705 ..
4706 } => {
4707 Self::collect_aggregates(test_expr, aggregates);
4708 for item in list {
4709 Self::collect_aggregates(item, aggregates);
4710 }
4711 }
4712 llkv_expr::expr::Expr::IsNull { expr, .. } => {
4713 Self::collect_aggregates(expr, aggregates);
4714 }
4715 llkv_expr::expr::Expr::Literal(_) => {}
4716 llkv_expr::expr::Expr::Pred(_) => {}
4717 llkv_expr::expr::Expr::Exists(_) => {}
4718 }
4719 }
4720
4721 fn compute_aggregate_values(
4723 &self,
4724 table: Arc<ExecutorTable<P>>,
4725 filter: &Option<llkv_expr::expr::Expr<'static, String>>,
4726 aggregate_specs: &[(String, llkv_expr::expr::AggregateCall<String>)],
4727 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4728 ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
4729 use llkv_expr::expr::AggregateCall;
4730
4731 let table_ref = table.as_ref();
4732 let mut results =
4733 FxHashMap::with_capacity_and_hasher(aggregate_specs.len(), Default::default());
4734
4735 let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
4736 let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());
4737 let mut projections: Vec<ScanProjection> = Vec::new();
4738 let mut column_projection_cache: FxHashMap<FieldId, usize> = FxHashMap::default();
4739 let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
4740 FxHashMap::default();
4741 let mut computed_alias_counter: usize = 0;
4742
4743 for (key, agg) in aggregate_specs {
4744 match agg {
4745 AggregateCall::CountStar => {
4746 specs.push(AggregateSpec {
4747 alias: key.clone(),
4748 kind: AggregateKind::Count {
4749 field_id: None,
4750 distinct: false,
4751 },
4752 });
4753 spec_to_projection.push(None);
4754 }
4755 AggregateCall::Count { expr, distinct } => {
4756 if let Some(col_name) = try_extract_simple_column(expr) {
4757 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4758 Error::InvalidArgumentError(format!(
4759 "unknown column '{}' in aggregate",
4760 col_name
4761 ))
4762 })?;
4763 let projection_index = get_or_insert_column_projection(
4764 &mut projections,
4765 &mut column_projection_cache,
4766 table_ref,
4767 col,
4768 );
4769 specs.push(AggregateSpec {
4770 alias: key.clone(),
4771 kind: AggregateKind::Count {
4772 field_id: Some(col.field_id),
4773 distinct: *distinct,
4774 },
4775 });
4776 spec_to_projection.push(Some(projection_index));
4777 } else {
4778 let (projection_index, _dtype) = ensure_computed_projection(
4779 expr,
4780 table_ref,
4781 &mut projections,
4782 &mut computed_projection_cache,
4783 &mut computed_alias_counter,
4784 )?;
4785 let field_id = u32::try_from(projection_index).map_err(|_| {
4786 Error::InvalidArgumentError(
4787 "aggregate projection index exceeds supported range".into(),
4788 )
4789 })?;
4790 specs.push(AggregateSpec {
4791 alias: key.clone(),
4792 kind: AggregateKind::Count {
4793 field_id: Some(field_id),
4794 distinct: *distinct,
4795 },
4796 });
4797 spec_to_projection.push(Some(projection_index));
4798 }
4799 }
4800 AggregateCall::Sum { expr, distinct } => {
4801 let (projection_index, data_type, field_id) =
4802 if let Some(col_name) = try_extract_simple_column(expr) {
4803 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4804 Error::InvalidArgumentError(format!(
4805 "unknown column '{}' in aggregate",
4806 col_name
4807 ))
4808 })?;
4809 let projection_index = get_or_insert_column_projection(
4810 &mut projections,
4811 &mut column_projection_cache,
4812 table_ref,
4813 col,
4814 );
4815 let data_type = col.data_type.clone();
4816 (projection_index, data_type, col.field_id)
4817 } else {
4818 let (projection_index, inferred_type) = ensure_computed_projection(
4819 expr,
4820 table_ref,
4821 &mut projections,
4822 &mut computed_projection_cache,
4823 &mut computed_alias_counter,
4824 )?;
4825 let field_id = u32::try_from(projection_index).map_err(|_| {
4826 Error::InvalidArgumentError(
4827 "aggregate projection index exceeds supported range".into(),
4828 )
4829 })?;
4830 (projection_index, inferred_type, field_id)
4831 };
4832 let normalized_type = Self::validate_aggregate_type(
4833 Some(data_type.clone()),
4834 "SUM",
4835 &[DataType::Int64, DataType::Float64],
4836 )?;
4837 specs.push(AggregateSpec {
4838 alias: key.clone(),
4839 kind: AggregateKind::Sum {
4840 field_id,
4841 data_type: normalized_type,
4842 distinct: *distinct,
4843 },
4844 });
4845 spec_to_projection.push(Some(projection_index));
4846 }
4847 AggregateCall::Total { expr, distinct } => {
4848 let (projection_index, data_type, field_id) =
4849 if let Some(col_name) = try_extract_simple_column(expr) {
4850 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4851 Error::InvalidArgumentError(format!(
4852 "unknown column '{}' in aggregate",
4853 col_name
4854 ))
4855 })?;
4856 let projection_index = get_or_insert_column_projection(
4857 &mut projections,
4858 &mut column_projection_cache,
4859 table_ref,
4860 col,
4861 );
4862 let data_type = col.data_type.clone();
4863 (projection_index, data_type, col.field_id)
4864 } else {
4865 let (projection_index, inferred_type) = ensure_computed_projection(
4866 expr,
4867 table_ref,
4868 &mut projections,
4869 &mut computed_projection_cache,
4870 &mut computed_alias_counter,
4871 )?;
4872 let field_id = u32::try_from(projection_index).map_err(|_| {
4873 Error::InvalidArgumentError(
4874 "aggregate projection index exceeds supported range".into(),
4875 )
4876 })?;
4877 (projection_index, inferred_type, field_id)
4878 };
4879 let normalized_type = Self::validate_aggregate_type(
4880 Some(data_type.clone()),
4881 "TOTAL",
4882 &[DataType::Int64, DataType::Float64],
4883 )?;
4884 specs.push(AggregateSpec {
4885 alias: key.clone(),
4886 kind: AggregateKind::Total {
4887 field_id,
4888 data_type: normalized_type,
4889 distinct: *distinct,
4890 },
4891 });
4892 spec_to_projection.push(Some(projection_index));
4893 }
4894 AggregateCall::Avg { expr, distinct } => {
4895 let (projection_index, data_type, field_id) =
4896 if let Some(col_name) = try_extract_simple_column(expr) {
4897 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4898 Error::InvalidArgumentError(format!(
4899 "unknown column '{}' in aggregate",
4900 col_name
4901 ))
4902 })?;
4903 let projection_index = get_or_insert_column_projection(
4904 &mut projections,
4905 &mut column_projection_cache,
4906 table_ref,
4907 col,
4908 );
4909 let data_type = col.data_type.clone();
4910 (projection_index, data_type, col.field_id)
4911 } else {
4912 let (projection_index, inferred_type) = ensure_computed_projection(
4913 expr,
4914 table_ref,
4915 &mut projections,
4916 &mut computed_projection_cache,
4917 &mut computed_alias_counter,
4918 )?;
4919 tracing::debug!(
4920 "AVG aggregate expr={:?} inferred_type={:?}",
4921 expr,
4922 inferred_type
4923 );
4924 let field_id = u32::try_from(projection_index).map_err(|_| {
4925 Error::InvalidArgumentError(
4926 "aggregate projection index exceeds supported range".into(),
4927 )
4928 })?;
4929 (projection_index, inferred_type, field_id)
4930 };
4931 let normalized_type = Self::validate_aggregate_type(
4932 Some(data_type.clone()),
4933 "AVG",
4934 &[DataType::Int64, DataType::Float64],
4935 )?;
4936 specs.push(AggregateSpec {
4937 alias: key.clone(),
4938 kind: AggregateKind::Avg {
4939 field_id,
4940 data_type: normalized_type,
4941 distinct: *distinct,
4942 },
4943 });
4944 spec_to_projection.push(Some(projection_index));
4945 }
4946 AggregateCall::Min(expr) => {
4947 let (projection_index, data_type, field_id) =
4948 if let Some(col_name) = try_extract_simple_column(expr) {
4949 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4950 Error::InvalidArgumentError(format!(
4951 "unknown column '{}' in aggregate",
4952 col_name
4953 ))
4954 })?;
4955 let projection_index = get_or_insert_column_projection(
4956 &mut projections,
4957 &mut column_projection_cache,
4958 table_ref,
4959 col,
4960 );
4961 let data_type = col.data_type.clone();
4962 (projection_index, data_type, col.field_id)
4963 } else {
4964 let (projection_index, inferred_type) = ensure_computed_projection(
4965 expr,
4966 table_ref,
4967 &mut projections,
4968 &mut computed_projection_cache,
4969 &mut computed_alias_counter,
4970 )?;
4971 let field_id = u32::try_from(projection_index).map_err(|_| {
4972 Error::InvalidArgumentError(
4973 "aggregate projection index exceeds supported range".into(),
4974 )
4975 })?;
4976 (projection_index, inferred_type, field_id)
4977 };
4978 let normalized_type = Self::validate_aggregate_type(
4979 Some(data_type.clone()),
4980 "MIN",
4981 &[DataType::Int64, DataType::Float64],
4982 )?;
4983 specs.push(AggregateSpec {
4984 alias: key.clone(),
4985 kind: AggregateKind::Min {
4986 field_id,
4987 data_type: normalized_type,
4988 },
4989 });
4990 spec_to_projection.push(Some(projection_index));
4991 }
4992 AggregateCall::Max(expr) => {
4993 let (projection_index, data_type, field_id) =
4994 if let Some(col_name) = try_extract_simple_column(expr) {
4995 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
4996 Error::InvalidArgumentError(format!(
4997 "unknown column '{}' in aggregate",
4998 col_name
4999 ))
5000 })?;
5001 let projection_index = get_or_insert_column_projection(
5002 &mut projections,
5003 &mut column_projection_cache,
5004 table_ref,
5005 col,
5006 );
5007 let data_type = col.data_type.clone();
5008 (projection_index, data_type, col.field_id)
5009 } else {
5010 let (projection_index, inferred_type) = ensure_computed_projection(
5011 expr,
5012 table_ref,
5013 &mut projections,
5014 &mut computed_projection_cache,
5015 &mut computed_alias_counter,
5016 )?;
5017 let field_id = u32::try_from(projection_index).map_err(|_| {
5018 Error::InvalidArgumentError(
5019 "aggregate projection index exceeds supported range".into(),
5020 )
5021 })?;
5022 (projection_index, inferred_type, field_id)
5023 };
5024 let normalized_type = Self::validate_aggregate_type(
5025 Some(data_type.clone()),
5026 "MAX",
5027 &[DataType::Int64, DataType::Float64],
5028 )?;
5029 specs.push(AggregateSpec {
5030 alias: key.clone(),
5031 kind: AggregateKind::Max {
5032 field_id,
5033 data_type: normalized_type,
5034 },
5035 });
5036 spec_to_projection.push(Some(projection_index));
5037 }
5038 AggregateCall::CountNulls(expr) => {
5039 if let Some(col_name) = try_extract_simple_column(expr) {
5040 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
5041 Error::InvalidArgumentError(format!(
5042 "unknown column '{}' in aggregate",
5043 col_name
5044 ))
5045 })?;
5046 let projection_index = get_or_insert_column_projection(
5047 &mut projections,
5048 &mut column_projection_cache,
5049 table_ref,
5050 col,
5051 );
5052 specs.push(AggregateSpec {
5053 alias: key.clone(),
5054 kind: AggregateKind::CountNulls {
5055 field_id: col.field_id,
5056 },
5057 });
5058 spec_to_projection.push(Some(projection_index));
5059 } else {
5060 let (projection_index, _dtype) = ensure_computed_projection(
5061 expr,
5062 table_ref,
5063 &mut projections,
5064 &mut computed_projection_cache,
5065 &mut computed_alias_counter,
5066 )?;
5067 let field_id = u32::try_from(projection_index).map_err(|_| {
5068 Error::InvalidArgumentError(
5069 "aggregate projection index exceeds supported range".into(),
5070 )
5071 })?;
5072 specs.push(AggregateSpec {
5073 alias: key.clone(),
5074 kind: AggregateKind::CountNulls { field_id },
5075 });
5076 spec_to_projection.push(Some(projection_index));
5077 }
5078 }
5079 AggregateCall::GroupConcat {
5080 expr,
5081 distinct,
5082 separator,
5083 } => {
5084 if let Some(col_name) = try_extract_simple_column(expr) {
5085 let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
5086 Error::InvalidArgumentError(format!(
5087 "unknown column '{}' in aggregate",
5088 col_name
5089 ))
5090 })?;
5091 let projection_index = get_or_insert_column_projection(
5092 &mut projections,
5093 &mut column_projection_cache,
5094 table_ref,
5095 col,
5096 );
5097 specs.push(AggregateSpec {
5098 alias: key.clone(),
5099 kind: AggregateKind::GroupConcat {
5100 field_id: col.field_id,
5101 distinct: *distinct,
5102 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
5103 },
5104 });
5105 spec_to_projection.push(Some(projection_index));
5106 } else {
5107 let (projection_index, _dtype) = ensure_computed_projection(
5108 expr,
5109 table_ref,
5110 &mut projections,
5111 &mut computed_projection_cache,
5112 &mut computed_alias_counter,
5113 )?;
5114 let field_id = u32::try_from(projection_index).map_err(|_| {
5115 Error::InvalidArgumentError(
5116 "aggregate projection index exceeds supported range".into(),
5117 )
5118 })?;
5119 specs.push(AggregateSpec {
5120 alias: key.clone(),
5121 kind: AggregateKind::GroupConcat {
5122 field_id,
5123 distinct: *distinct,
5124 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
5125 },
5126 });
5127 spec_to_projection.push(Some(projection_index));
5128 }
5129 }
5130 }
5131 }
5132
5133 let filter_expr = match filter {
5134 Some(expr) => crate::translation::expression::translate_predicate(
5135 expr.clone(),
5136 table_ref.schema.as_ref(),
5137 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
5138 )?,
5139 None => {
5140 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
5141 Error::InvalidArgumentError(
5142 "table has no columns; cannot perform aggregate scan".into(),
5143 )
5144 })?;
5145 crate::translation::expression::full_table_scan_filter(field_id)
5146 }
5147 };
5148
5149 if projections.is_empty() {
5150 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
5151 Error::InvalidArgumentError(
5152 "table has no columns; cannot perform aggregate scan".into(),
5153 )
5154 })?;
5155 projections.push(ScanProjection::from(StoreProjection::with_alias(
5156 LogicalFieldId::for_user(table.table.table_id(), field_id),
5157 table
5158 .schema
5159 .column_by_field_id(field_id)
5160 .map(|c| c.name.clone())
5161 .unwrap_or_else(|| format!("col{field_id}")),
5162 )));
5163 }
5164
5165 let base_options = ScanStreamOptions {
5166 include_nulls: true,
5167 order: None,
5168 row_id_filter: None,
5169 };
5170
5171 let count_star_override: Option<i64> = None;
5172
5173 let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
5174 for (idx, spec) in specs.iter().enumerate() {
5175 states.push(AggregateState {
5176 alias: spec.alias.clone(),
5177 accumulator: AggregateAccumulator::new_with_projection_index(
5178 spec,
5179 spec_to_projection[idx],
5180 count_star_override,
5181 )?,
5182 override_value: match &spec.kind {
5183 AggregateKind::Count { field_id: None, .. } => count_star_override,
5184 _ => None,
5185 },
5186 });
5187 }
5188
5189 let mut error: Option<Error> = None;
5190 match table.table.scan_stream(
5191 projections,
5192 &filter_expr,
5193 ScanStreamOptions {
5194 row_id_filter: row_filter.clone(),
5195 ..base_options
5196 },
5197 |batch| {
5198 if error.is_some() {
5199 return;
5200 }
5201 for state in &mut states {
5202 if let Err(err) = state.update(&batch) {
5203 error = Some(err);
5204 return;
5205 }
5206 }
5207 },
5208 ) {
5209 Ok(()) => {}
5210 Err(llkv_result::Error::NotFound) => {}
5211 Err(err) => return Err(err),
5212 }
5213 if let Some(err) = error {
5214 return Err(err);
5215 }
5216
5217 for state in states {
5218 let alias = state.alias.clone();
5219 let (_field, array) = state.finalize()?;
5220
5221 if let Some(int64_array) = array.as_any().downcast_ref::<arrow::array::Int64Array>() {
5222 if int64_array.len() != 1 {
5223 return Err(Error::Internal(format!(
5224 "Expected single value from aggregate, got {}",
5225 int64_array.len()
5226 )));
5227 }
5228 let value = if int64_array.is_null(0) {
5229 AggregateValue::Null
5230 } else {
5231 AggregateValue::Int64(int64_array.value(0))
5232 };
5233 results.insert(alias, value);
5234 } else if let Some(float64_array) =
5235 array.as_any().downcast_ref::<arrow::array::Float64Array>()
5236 {
5237 if float64_array.len() != 1 {
5238 return Err(Error::Internal(format!(
5239 "Expected single value from aggregate, got {}",
5240 float64_array.len()
5241 )));
5242 }
5243 let value = if float64_array.is_null(0) {
5244 AggregateValue::Null
5245 } else {
5246 AggregateValue::Float64(float64_array.value(0))
5247 };
5248 results.insert(alias, value);
5249 } else if let Some(string_array) =
5250 array.as_any().downcast_ref::<arrow::array::StringArray>()
5251 {
5252 if string_array.len() != 1 {
5253 return Err(Error::Internal(format!(
5254 "Expected single value from aggregate, got {}",
5255 string_array.len()
5256 )));
5257 }
5258 let value = if string_array.is_null(0) {
5259 AggregateValue::Null
5260 } else {
5261 AggregateValue::String(string_array.value(0).to_string())
5262 };
5263 results.insert(alias, value);
5264 } else {
5265 return Err(Error::Internal(format!(
5266 "Unexpected array type from aggregate: {:?}",
5267 array.data_type()
5268 )));
5269 }
5270 }
5271
5272 Ok(results)
5273 }
5274
5275 fn evaluate_having_expr(
5276 expr: &llkv_expr::expr::Expr<String>,
5277 aggregates: &FxHashMap<String, PlanValue>,
5278 row_batch: &RecordBatch,
5279 column_lookup: &FxHashMap<String, usize>,
5280 row_idx: usize,
5281 ) -> ExecutorResult<Option<bool>> {
5282 fn compare_plan_values_for_pred(
5283 left: &PlanValue,
5284 right: &PlanValue,
5285 ) -> Option<std::cmp::Ordering> {
5286 match (left, right) {
5287 (PlanValue::Integer(l), PlanValue::Integer(r)) => Some(l.cmp(r)),
5288 (PlanValue::Float(l), PlanValue::Float(r)) => l.partial_cmp(r),
5289 (PlanValue::Integer(l), PlanValue::Float(r)) => (*l as f64).partial_cmp(r),
5290 (PlanValue::Float(l), PlanValue::Integer(r)) => l.partial_cmp(&(*r as f64)),
5291 (PlanValue::String(l), PlanValue::String(r)) => Some(l.cmp(r)),
5292 _ => None,
5293 }
5294 }
5295
5296 fn evaluate_ordering_predicate<F>(
5297 value: &PlanValue,
5298 literal: &Literal,
5299 predicate: F,
5300 ) -> ExecutorResult<Option<bool>>
5301 where
5302 F: Fn(std::cmp::Ordering) -> bool,
5303 {
5304 if matches!(value, PlanValue::Null) {
5305 return Ok(None);
5306 }
5307 let expected = llkv_plan::plan_value_from_literal(literal)?;
5308 if matches!(expected, PlanValue::Null) {
5309 return Ok(None);
5310 }
5311
5312 match compare_plan_values_for_pred(value, &expected) {
5313 Some(ordering) => Ok(Some(predicate(ordering))),
5314 None => Err(Error::InvalidArgumentError(
5315 "unsupported HAVING comparison between column value and literal".into(),
5316 )),
5317 }
5318 }
5319
5320 match expr {
5321 llkv_expr::expr::Expr::Compare { left, op, right } => {
5322 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5323 left,
5324 aggregates,
5325 Some(row_batch),
5326 Some(column_lookup),
5327 row_idx,
5328 )?;
5329 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5330 right,
5331 aggregates,
5332 Some(row_batch),
5333 Some(column_lookup),
5334 row_idx,
5335 )?;
5336
5337 let (left_val, right_val) = match (&left_val, &right_val) {
5339 (PlanValue::Integer(i), PlanValue::Float(_)) => {
5340 (PlanValue::Float(*i as f64), right_val)
5341 }
5342 (PlanValue::Float(_), PlanValue::Integer(i)) => {
5343 (left_val, PlanValue::Float(*i as f64))
5344 }
5345 _ => (left_val, right_val),
5346 };
5347
5348 match (left_val, right_val) {
5349 (PlanValue::Null, _) | (_, PlanValue::Null) => Ok(None),
5351 (PlanValue::Integer(l), PlanValue::Integer(r)) => {
5352 use llkv_expr::expr::CompareOp;
5353 Ok(Some(match op {
5354 CompareOp::Eq => l == r,
5355 CompareOp::NotEq => l != r,
5356 CompareOp::Lt => l < r,
5357 CompareOp::LtEq => l <= r,
5358 CompareOp::Gt => l > r,
5359 CompareOp::GtEq => l >= r,
5360 }))
5361 }
5362 (PlanValue::Float(l), PlanValue::Float(r)) => {
5363 use llkv_expr::expr::CompareOp;
5364 Ok(Some(match op {
5365 CompareOp::Eq => l == r,
5366 CompareOp::NotEq => l != r,
5367 CompareOp::Lt => l < r,
5368 CompareOp::LtEq => l <= r,
5369 CompareOp::Gt => l > r,
5370 CompareOp::GtEq => l >= r,
5371 }))
5372 }
5373 _ => Ok(Some(false)),
5374 }
5375 }
5376 llkv_expr::expr::Expr::Not(inner) => {
5377 match Self::evaluate_having_expr(
5379 inner,
5380 aggregates,
5381 row_batch,
5382 column_lookup,
5383 row_idx,
5384 )? {
5385 Some(b) => Ok(Some(!b)),
5386 None => Ok(None), }
5388 }
5389 llkv_expr::expr::Expr::InList {
5390 expr: test_expr,
5391 list,
5392 negated,
5393 } => {
5394 let test_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5395 test_expr,
5396 aggregates,
5397 Some(row_batch),
5398 Some(column_lookup),
5399 row_idx,
5400 )?;
5401
5402 if matches!(test_val, PlanValue::Null) {
5405 return Ok(None);
5406 }
5407
5408 let mut found = false;
5409 let mut has_null = false;
5410
5411 for list_item in list {
5412 let list_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5413 list_item,
5414 aggregates,
5415 Some(row_batch),
5416 Some(column_lookup),
5417 row_idx,
5418 )?;
5419
5420 if matches!(list_val, PlanValue::Null) {
5422 has_null = true;
5423 continue;
5424 }
5425
5426 let matches = match (&test_val, &list_val) {
5428 (PlanValue::Integer(a), PlanValue::Integer(b)) => a == b,
5429 (PlanValue::Float(a), PlanValue::Float(b)) => a == b,
5430 (PlanValue::Integer(a), PlanValue::Float(b)) => (*a as f64) == *b,
5431 (PlanValue::Float(a), PlanValue::Integer(b)) => *a == (*b as f64),
5432 (PlanValue::String(a), PlanValue::String(b)) => a == b,
5433 _ => false,
5434 };
5435
5436 if matches {
5437 found = true;
5438 break;
5439 }
5440 }
5441
5442 if *negated {
5446 Ok(if found {
5448 Some(false)
5449 } else if has_null {
5450 None } else {
5452 Some(true)
5453 })
5454 } else {
5455 Ok(if found {
5457 Some(true)
5458 } else if has_null {
5459 None } else {
5461 Some(false)
5462 })
5463 }
5464 }
5465 llkv_expr::expr::Expr::IsNull { expr, negated } => {
5466 let val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5468 expr,
5469 aggregates,
5470 Some(row_batch),
5471 Some(column_lookup),
5472 row_idx,
5473 )?;
5474
5475 let is_null = matches!(val, PlanValue::Null);
5479 Ok(Some(if *negated { !is_null } else { is_null }))
5480 }
5481 llkv_expr::expr::Expr::Literal(val) => Ok(Some(*val)),
5482 llkv_expr::expr::Expr::And(exprs) => {
5483 let mut has_null = false;
5485 for e in exprs {
5486 match Self::evaluate_having_expr(
5487 e,
5488 aggregates,
5489 row_batch,
5490 column_lookup,
5491 row_idx,
5492 )? {
5493 Some(false) => return Ok(Some(false)), None => has_null = true,
5495 Some(true) => {} }
5497 }
5498 Ok(if has_null { None } else { Some(true) })
5499 }
5500 llkv_expr::expr::Expr::Or(exprs) => {
5501 let mut has_null = false;
5503 for e in exprs {
5504 match Self::evaluate_having_expr(
5505 e,
5506 aggregates,
5507 row_batch,
5508 column_lookup,
5509 row_idx,
5510 )? {
5511 Some(true) => return Ok(Some(true)), None => has_null = true,
5513 Some(false) => {} }
5515 }
5516 Ok(if has_null { None } else { Some(false) })
5517 }
5518 llkv_expr::expr::Expr::Pred(filter) => {
5519 use llkv_expr::expr::Operator;
5522
5523 let col_name = &filter.field_id;
5524 let col_idx = column_lookup
5525 .get(&col_name.to_ascii_lowercase())
5526 .ok_or_else(|| {
5527 Error::InvalidArgumentError(format!(
5528 "column '{}' not found in HAVING context",
5529 col_name
5530 ))
5531 })?;
5532
5533 let value = llkv_plan::plan_value_from_array(row_batch.column(*col_idx), row_idx)?;
5534
5535 match &filter.op {
5536 Operator::IsNull => Ok(Some(matches!(value, PlanValue::Null))),
5537 Operator::IsNotNull => Ok(Some(!matches!(value, PlanValue::Null))),
5538 Operator::Equals(expected) => {
5539 if matches!(value, PlanValue::Null) {
5541 return Ok(None);
5542 }
5543 let expected_value = llkv_plan::plan_value_from_literal(expected)?;
5545 if matches!(expected_value, PlanValue::Null) {
5546 return Ok(None);
5547 }
5548 Ok(Some(value == expected_value))
5549 }
5550 Operator::GreaterThan(expected) => {
5551 evaluate_ordering_predicate(&value, expected, |ordering| {
5552 ordering == std::cmp::Ordering::Greater
5553 })
5554 }
5555 Operator::GreaterThanOrEquals(expected) => {
5556 evaluate_ordering_predicate(&value, expected, |ordering| {
5557 ordering == std::cmp::Ordering::Greater
5558 || ordering == std::cmp::Ordering::Equal
5559 })
5560 }
5561 Operator::LessThan(expected) => {
5562 evaluate_ordering_predicate(&value, expected, |ordering| {
5563 ordering == std::cmp::Ordering::Less
5564 })
5565 }
5566 Operator::LessThanOrEquals(expected) => {
5567 evaluate_ordering_predicate(&value, expected, |ordering| {
5568 ordering == std::cmp::Ordering::Less
5569 || ordering == std::cmp::Ordering::Equal
5570 })
5571 }
5572 _ => {
5573 Err(Error::InvalidArgumentError(format!(
5576 "Operator {:?} not supported for column predicates in HAVING clause",
5577 filter.op
5578 )))
5579 }
5580 }
5581 }
5582 llkv_expr::expr::Expr::Exists(_) => Err(Error::InvalidArgumentError(
5583 "EXISTS subqueries not supported in HAVING clause".into(),
5584 )),
5585 }
5586 }
5587
5588 fn evaluate_expr_with_plan_value_aggregates_and_row(
5589 expr: &ScalarExpr<String>,
5590 aggregates: &FxHashMap<String, PlanValue>,
5591 row_batch: Option<&RecordBatch>,
5592 column_lookup: Option<&FxHashMap<String, usize>>,
5593 row_idx: usize,
5594 ) -> ExecutorResult<PlanValue> {
5595 use llkv_expr::expr::BinaryOp;
5596 use llkv_expr::literal::Literal;
5597
5598 match expr {
5599 ScalarExpr::Literal(Literal::Integer(v)) => Ok(PlanValue::Integer(*v as i64)),
5600 ScalarExpr::Literal(Literal::Float(v)) => Ok(PlanValue::Float(*v)),
5601 ScalarExpr::Literal(Literal::Boolean(v)) => {
5602 Ok(PlanValue::Integer(if *v { 1 } else { 0 }))
5603 }
5604 ScalarExpr::Literal(Literal::String(s)) => Ok(PlanValue::String(s.clone())),
5605 ScalarExpr::Literal(Literal::Null) => Ok(PlanValue::Null),
5606 ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
5607 "Struct literals not supported in aggregate expressions".into(),
5608 )),
5609 ScalarExpr::Column(col_name) => {
5610 if let (Some(batch), Some(lookup)) = (row_batch, column_lookup) {
5612 let col_idx = lookup.get(&col_name.to_ascii_lowercase()).ok_or_else(|| {
5613 Error::InvalidArgumentError(format!("column '{}' not found", col_name))
5614 })?;
5615 llkv_plan::plan_value_from_array(batch.column(*col_idx), row_idx)
5616 } else {
5617 Err(Error::InvalidArgumentError(
5618 "Column references not supported in aggregate-only expressions".into(),
5619 ))
5620 }
5621 }
5622 ScalarExpr::Compare { left, op, right } => {
5623 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5625 left,
5626 aggregates,
5627 row_batch,
5628 column_lookup,
5629 row_idx,
5630 )?;
5631 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5632 right,
5633 aggregates,
5634 row_batch,
5635 column_lookup,
5636 row_idx,
5637 )?;
5638
5639 if matches!(left_val, PlanValue::Null) || matches!(right_val, PlanValue::Null) {
5641 return Ok(PlanValue::Null);
5642 }
5643
5644 let (left_val, right_val) = match (&left_val, &right_val) {
5646 (PlanValue::Integer(i), PlanValue::Float(_)) => {
5647 (PlanValue::Float(*i as f64), right_val)
5648 }
5649 (PlanValue::Float(_), PlanValue::Integer(i)) => {
5650 (left_val, PlanValue::Float(*i as f64))
5651 }
5652 _ => (left_val, right_val),
5653 };
5654
5655 let result = match (&left_val, &right_val) {
5657 (PlanValue::Integer(l), PlanValue::Integer(r)) => {
5658 use llkv_expr::expr::CompareOp;
5659 match op {
5660 CompareOp::Eq => l == r,
5661 CompareOp::NotEq => l != r,
5662 CompareOp::Lt => l < r,
5663 CompareOp::LtEq => l <= r,
5664 CompareOp::Gt => l > r,
5665 CompareOp::GtEq => l >= r,
5666 }
5667 }
5668 (PlanValue::Float(l), PlanValue::Float(r)) => {
5669 use llkv_expr::expr::CompareOp;
5670 match op {
5671 CompareOp::Eq => l == r,
5672 CompareOp::NotEq => l != r,
5673 CompareOp::Lt => l < r,
5674 CompareOp::LtEq => l <= r,
5675 CompareOp::Gt => l > r,
5676 CompareOp::GtEq => l >= r,
5677 }
5678 }
5679 (PlanValue::String(l), PlanValue::String(r)) => {
5680 use llkv_expr::expr::CompareOp;
5681 match op {
5682 CompareOp::Eq => l == r,
5683 CompareOp::NotEq => l != r,
5684 CompareOp::Lt => l < r,
5685 CompareOp::LtEq => l <= r,
5686 CompareOp::Gt => l > r,
5687 CompareOp::GtEq => l >= r,
5688 }
5689 }
5690 _ => false,
5691 };
5692
5693 Ok(PlanValue::Integer(if result { 1 } else { 0 }))
5695 }
5696 ScalarExpr::Not(inner) => {
5697 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5698 inner,
5699 aggregates,
5700 row_batch,
5701 column_lookup,
5702 row_idx,
5703 )?;
5704 match value {
5705 PlanValue::Integer(v) => Ok(PlanValue::Integer(if v != 0 { 0 } else { 1 })),
5706 PlanValue::Float(v) => Ok(PlanValue::Integer(if v != 0.0 { 0 } else { 1 })),
5707 PlanValue::Null => Ok(PlanValue::Null),
5708 other => Err(Error::InvalidArgumentError(format!(
5709 "logical NOT does not support value {other:?}"
5710 ))),
5711 }
5712 }
5713 ScalarExpr::IsNull { expr, negated } => {
5714 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5715 expr,
5716 aggregates,
5717 row_batch,
5718 column_lookup,
5719 row_idx,
5720 )?;
5721 let is_null = matches!(value, PlanValue::Null);
5722 let condition = if is_null { !negated } else { *negated };
5723 Ok(PlanValue::Integer(if condition { 1 } else { 0 }))
5724 }
5725 ScalarExpr::Aggregate(agg) => {
5726 let key = format!("{:?}", agg);
5727 aggregates
5728 .get(&key)
5729 .cloned()
5730 .ok_or_else(|| Error::Internal(format!("Aggregate value not found: {}", key)))
5731 }
5732 ScalarExpr::Binary { left, op, right } => {
5733 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5734 left,
5735 aggregates,
5736 row_batch,
5737 column_lookup,
5738 row_idx,
5739 )?;
5740 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5741 right,
5742 aggregates,
5743 row_batch,
5744 column_lookup,
5745 row_idx,
5746 )?;
5747
5748 match op {
5749 BinaryOp::Add
5750 | BinaryOp::Subtract
5751 | BinaryOp::Multiply
5752 | BinaryOp::Divide
5753 | BinaryOp::Modulo => {
5754 if matches!(&left_val, PlanValue::Null)
5755 || matches!(&right_val, PlanValue::Null)
5756 {
5757 return Ok(PlanValue::Null);
5758 }
5759
5760 if matches!(op, BinaryOp::Divide)
5761 && let (PlanValue::Integer(lhs), PlanValue::Integer(rhs)) =
5762 (&left_val, &right_val)
5763 {
5764 if *rhs == 0 {
5765 return Ok(PlanValue::Null);
5766 }
5767
5768 if *lhs == i64::MIN && *rhs == -1 {
5769 return Ok(PlanValue::Float((*lhs as f64) / (*rhs as f64)));
5770 }
5771
5772 return Ok(PlanValue::Integer(lhs / rhs));
5773 }
5774
5775 let left_is_float = matches!(&left_val, PlanValue::Float(_));
5776 let right_is_float = matches!(&right_val, PlanValue::Float(_));
5777
5778 let left_num = match left_val {
5779 PlanValue::Integer(i) => i as f64,
5780 PlanValue::Float(f) => f,
5781 other => {
5782 return Err(Error::InvalidArgumentError(format!(
5783 "Non-numeric value {:?} in binary operation",
5784 other
5785 )));
5786 }
5787 };
5788 let right_num = match right_val {
5789 PlanValue::Integer(i) => i as f64,
5790 PlanValue::Float(f) => f,
5791 other => {
5792 return Err(Error::InvalidArgumentError(format!(
5793 "Non-numeric value {:?} in binary operation",
5794 other
5795 )));
5796 }
5797 };
5798
5799 let result = match op {
5800 BinaryOp::Add => left_num + right_num,
5801 BinaryOp::Subtract => left_num - right_num,
5802 BinaryOp::Multiply => left_num * right_num,
5803 BinaryOp::Divide => {
5804 if right_num == 0.0 {
5805 return Ok(PlanValue::Null);
5806 }
5807 left_num / right_num
5808 }
5809 BinaryOp::Modulo => {
5810 if right_num == 0.0 {
5811 return Ok(PlanValue::Null);
5812 }
5813 left_num % right_num
5814 }
5815 BinaryOp::And
5816 | BinaryOp::Or
5817 | BinaryOp::BitwiseShiftLeft
5818 | BinaryOp::BitwiseShiftRight => unreachable!(),
5819 };
5820
5821 if matches!(op, BinaryOp::Divide) {
5822 return Ok(PlanValue::Float(result));
5823 }
5824
5825 if left_is_float || right_is_float {
5826 Ok(PlanValue::Float(result))
5827 } else {
5828 Ok(PlanValue::Integer(result as i64))
5829 }
5830 }
5831 BinaryOp::And => Ok(evaluate_plan_value_logical_and(left_val, right_val)),
5832 BinaryOp::Or => Ok(evaluate_plan_value_logical_or(left_val, right_val)),
5833 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
5834 if matches!(&left_val, PlanValue::Null)
5835 || matches!(&right_val, PlanValue::Null)
5836 {
5837 return Ok(PlanValue::Null);
5838 }
5839
5840 let lhs = match left_val {
5842 PlanValue::Integer(i) => i,
5843 PlanValue::Float(f) => f as i64,
5844 other => {
5845 return Err(Error::InvalidArgumentError(format!(
5846 "Non-numeric value {:?} in bitwise shift operation",
5847 other
5848 )));
5849 }
5850 };
5851 let rhs = match right_val {
5852 PlanValue::Integer(i) => i,
5853 PlanValue::Float(f) => f as i64,
5854 other => {
5855 return Err(Error::InvalidArgumentError(format!(
5856 "Non-numeric value {:?} in bitwise shift operation",
5857 other
5858 )));
5859 }
5860 };
5861
5862 let result = match op {
5864 BinaryOp::BitwiseShiftLeft => lhs.wrapping_shl(rhs as u32),
5865 BinaryOp::BitwiseShiftRight => lhs.wrapping_shr(rhs as u32),
5866 _ => unreachable!(),
5867 };
5868
5869 Ok(PlanValue::Integer(result))
5870 }
5871 }
5872 }
5873 ScalarExpr::Cast { expr, data_type } => {
5874 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5876 expr,
5877 aggregates,
5878 row_batch,
5879 column_lookup,
5880 row_idx,
5881 )?;
5882
5883 if matches!(value, PlanValue::Null) {
5885 return Ok(PlanValue::Null);
5886 }
5887
5888 match data_type {
5890 DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => {
5891 match value {
5892 PlanValue::Integer(i) => Ok(PlanValue::Integer(i)),
5893 PlanValue::Float(f) => Ok(PlanValue::Integer(f as i64)),
5894 PlanValue::String(s) => {
5895 s.parse::<i64>().map(PlanValue::Integer).map_err(|_| {
5896 Error::InvalidArgumentError(format!(
5897 "Cannot cast '{}' to integer",
5898 s
5899 ))
5900 })
5901 }
5902 _ => Err(Error::InvalidArgumentError(format!(
5903 "Cannot cast {:?} to integer",
5904 value
5905 ))),
5906 }
5907 }
5908 DataType::Float64 | DataType::Float32 => match value {
5909 PlanValue::Integer(i) => Ok(PlanValue::Float(i as f64)),
5910 PlanValue::Float(f) => Ok(PlanValue::Float(f)),
5911 PlanValue::String(s) => {
5912 s.parse::<f64>().map(PlanValue::Float).map_err(|_| {
5913 Error::InvalidArgumentError(format!("Cannot cast '{}' to float", s))
5914 })
5915 }
5916 _ => Err(Error::InvalidArgumentError(format!(
5917 "Cannot cast {:?} to float",
5918 value
5919 ))),
5920 },
5921 DataType::Utf8 | DataType::LargeUtf8 => match value {
5922 PlanValue::String(s) => Ok(PlanValue::String(s)),
5923 PlanValue::Integer(i) => Ok(PlanValue::String(i.to_string())),
5924 PlanValue::Float(f) => Ok(PlanValue::String(f.to_string())),
5925 _ => Err(Error::InvalidArgumentError(format!(
5926 "Cannot cast {:?} to string",
5927 value
5928 ))),
5929 },
5930 _ => Err(Error::InvalidArgumentError(format!(
5931 "CAST to {:?} not supported in aggregate expressions",
5932 data_type
5933 ))),
5934 }
5935 }
5936 ScalarExpr::Case {
5937 operand,
5938 branches,
5939 else_expr,
5940 } => {
5941 let operand_value = if let Some(op) = operand {
5943 Some(Self::evaluate_expr_with_plan_value_aggregates_and_row(
5944 op,
5945 aggregates,
5946 row_batch,
5947 column_lookup,
5948 row_idx,
5949 )?)
5950 } else {
5951 None
5952 };
5953
5954 for (when_expr, then_expr) in branches {
5956 let matches = if let Some(ref op_val) = operand_value {
5957 let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5959 when_expr,
5960 aggregates,
5961 row_batch,
5962 column_lookup,
5963 row_idx,
5964 )?;
5965 Self::simple_case_branch_matches(op_val, &when_val)
5966 } else {
5967 let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
5969 when_expr,
5970 aggregates,
5971 row_batch,
5972 column_lookup,
5973 row_idx,
5974 )?;
5975 match when_val {
5977 PlanValue::Integer(i) => i != 0,
5978 PlanValue::Float(f) => f != 0.0,
5979 PlanValue::Null => false,
5980 _ => false,
5981 }
5982 };
5983
5984 if matches {
5985 return Self::evaluate_expr_with_plan_value_aggregates_and_row(
5986 then_expr,
5987 aggregates,
5988 row_batch,
5989 column_lookup,
5990 row_idx,
5991 );
5992 }
5993 }
5994
5995 if let Some(else_e) = else_expr {
5997 Self::evaluate_expr_with_plan_value_aggregates_and_row(
5998 else_e,
5999 aggregates,
6000 row_batch,
6001 column_lookup,
6002 row_idx,
6003 )
6004 } else {
6005 Ok(PlanValue::Null)
6006 }
6007 }
6008 ScalarExpr::Coalesce(exprs) => {
6009 for expr in exprs {
6011 let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6012 expr,
6013 aggregates,
6014 row_batch,
6015 column_lookup,
6016 row_idx,
6017 )?;
6018 if !matches!(value, PlanValue::Null) {
6019 return Ok(value);
6020 }
6021 }
6022 Ok(PlanValue::Null)
6023 }
6024 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
6025 "GetField not supported in aggregate expressions".into(),
6026 )),
6027 ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
6028 "Scalar subqueries not supported in aggregate expressions".into(),
6029 )),
6030 }
6031 }
6032
6033 fn simple_case_branch_matches(operand: &PlanValue, candidate: &PlanValue) -> bool {
6034 if matches!(operand, PlanValue::Null) || matches!(candidate, PlanValue::Null) {
6035 return false;
6036 }
6037
6038 match (operand, candidate) {
6039 (PlanValue::Integer(left), PlanValue::Integer(right)) => left == right,
6040 (PlanValue::Integer(left), PlanValue::Float(right)) => (*left as f64) == *right,
6041 (PlanValue::Float(left), PlanValue::Integer(right)) => *left == (*right as f64),
6042 (PlanValue::Float(left), PlanValue::Float(right)) => left == right,
6043 (PlanValue::String(left), PlanValue::String(right)) => left == right,
6044 (PlanValue::Struct(left), PlanValue::Struct(right)) => left == right,
6045 _ => operand == candidate,
6046 }
6047 }
6048
6049 fn evaluate_expr_with_aggregates(
6050 expr: &ScalarExpr<String>,
6051 aggregates: &FxHashMap<String, AggregateValue>,
6052 ) -> ExecutorResult<Option<i64>> {
6053 use llkv_expr::expr::BinaryOp;
6054 use llkv_expr::literal::Literal;
6055
6056 match expr {
6057 ScalarExpr::Literal(Literal::Integer(v)) => Ok(Some(*v as i64)),
6058 ScalarExpr::Literal(Literal::Float(v)) => Ok(Some(*v as i64)),
6059 ScalarExpr::Literal(Literal::Boolean(v)) => Ok(Some(if *v { 1 } else { 0 })),
6060 ScalarExpr::Literal(Literal::String(_)) => Err(Error::InvalidArgumentError(
6061 "String literals not supported in aggregate expressions".into(),
6062 )),
6063 ScalarExpr::Literal(Literal::Null) => Ok(None),
6064 ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
6065 "Struct literals not supported in aggregate expressions".into(),
6066 )),
6067 ScalarExpr::Column(_) => Err(Error::InvalidArgumentError(
6068 "Column references not supported in aggregate-only expressions".into(),
6069 )),
6070 ScalarExpr::Compare { .. } => Err(Error::InvalidArgumentError(
6071 "Comparisons not supported in aggregate-only expressions".into(),
6072 )),
6073 ScalarExpr::Aggregate(agg) => {
6074 let key = format!("{:?}", agg);
6075 let value = aggregates.get(&key).ok_or_else(|| {
6076 Error::Internal(format!("Aggregate value not found for key: {}", key))
6077 })?;
6078 Ok(value.as_i64())
6079 }
6080 ScalarExpr::Not(inner) => {
6081 let value = Self::evaluate_expr_with_aggregates(inner, aggregates)?;
6082 Ok(value.map(|v| if v != 0 { 0 } else { 1 }))
6083 }
6084 ScalarExpr::IsNull { expr, negated } => {
6085 let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
6086 let is_null = value.is_none();
6087 Ok(Some(if is_null != *negated { 1 } else { 0 }))
6088 }
6089 ScalarExpr::Binary { left, op, right } => {
6090 let left_val = Self::evaluate_expr_with_aggregates(left, aggregates)?;
6091 let right_val = Self::evaluate_expr_with_aggregates(right, aggregates)?;
6092
6093 match op {
6094 BinaryOp::Add
6095 | BinaryOp::Subtract
6096 | BinaryOp::Multiply
6097 | BinaryOp::Divide
6098 | BinaryOp::Modulo => match (left_val, right_val) {
6099 (Some(lhs), Some(rhs)) => {
6100 let result = match op {
6101 BinaryOp::Add => lhs.checked_add(rhs),
6102 BinaryOp::Subtract => lhs.checked_sub(rhs),
6103 BinaryOp::Multiply => lhs.checked_mul(rhs),
6104 BinaryOp::Divide => {
6105 if rhs == 0 {
6106 return Ok(None);
6107 }
6108 lhs.checked_div(rhs)
6109 }
6110 BinaryOp::Modulo => {
6111 if rhs == 0 {
6112 return Ok(None);
6113 }
6114 lhs.checked_rem(rhs)
6115 }
6116 BinaryOp::And
6117 | BinaryOp::Or
6118 | BinaryOp::BitwiseShiftLeft
6119 | BinaryOp::BitwiseShiftRight => unreachable!(),
6120 };
6121
6122 result.map(Some).ok_or_else(|| {
6123 Error::InvalidArgumentError(
6124 "Arithmetic overflow in expression".into(),
6125 )
6126 })
6127 }
6128 _ => Ok(None),
6129 },
6130 BinaryOp::And => Ok(evaluate_option_logical_and(left_val, right_val)),
6131 BinaryOp::Or => Ok(evaluate_option_logical_or(left_val, right_val)),
6132 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
6133 match (left_val, right_val) {
6134 (Some(lhs), Some(rhs)) => {
6135 let result = match op {
6136 BinaryOp::BitwiseShiftLeft => {
6137 Some(lhs.wrapping_shl(rhs as u32))
6138 }
6139 BinaryOp::BitwiseShiftRight => {
6140 Some(lhs.wrapping_shr(rhs as u32))
6141 }
6142 _ => unreachable!(),
6143 };
6144 Ok(result)
6145 }
6146 _ => Ok(None),
6147 }
6148 }
6149 }
6150 }
6151 ScalarExpr::Cast { expr, data_type } => {
6152 let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
6153 match value {
6154 Some(v) => Self::cast_aggregate_value(v, data_type).map(Some),
6155 None => Ok(None),
6156 }
6157 }
6158 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
6159 "GetField not supported in aggregate-only expressions".into(),
6160 )),
6161 ScalarExpr::Case { .. } => Err(Error::InvalidArgumentError(
6162 "CASE not supported in aggregate-only expressions".into(),
6163 )),
6164 ScalarExpr::Coalesce(_) => Err(Error::InvalidArgumentError(
6165 "COALESCE not supported in aggregate-only expressions".into(),
6166 )),
6167 ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
6168 "Scalar subqueries not supported in aggregate-only expressions".into(),
6169 )),
6170 }
6171 }
6172
6173 fn cast_aggregate_value(value: i64, data_type: &DataType) -> ExecutorResult<i64> {
6174 fn ensure_range(value: i64, min: i64, max: i64, ty: &DataType) -> ExecutorResult<i64> {
6175 if value < min || value > max {
6176 return Err(Error::InvalidArgumentError(format!(
6177 "value {} out of range for CAST target {:?}",
6178 value, ty
6179 )));
6180 }
6181 Ok(value)
6182 }
6183
6184 match data_type {
6185 DataType::Int8 => ensure_range(value, i8::MIN as i64, i8::MAX as i64, data_type),
6186 DataType::Int16 => ensure_range(value, i16::MIN as i64, i16::MAX as i64, data_type),
6187 DataType::Int32 => ensure_range(value, i32::MIN as i64, i32::MAX as i64, data_type),
6188 DataType::Int64 => Ok(value),
6189 DataType::UInt8 => ensure_range(value, 0, u8::MAX as i64, data_type),
6190 DataType::UInt16 => ensure_range(value, 0, u16::MAX as i64, data_type),
6191 DataType::UInt32 => ensure_range(value, 0, u32::MAX as i64, data_type),
6192 DataType::UInt64 => {
6193 if value < 0 {
6194 return Err(Error::InvalidArgumentError(format!(
6195 "value {} out of range for CAST target {:?}",
6196 value, data_type
6197 )));
6198 }
6199 Ok(value)
6200 }
6201 DataType::Float32 | DataType::Float64 => Ok(value),
6202 DataType::Boolean => Ok(if value == 0 { 0 } else { 1 }),
6203 DataType::Null => Err(Error::InvalidArgumentError(
6204 "CAST to NULL is not supported in aggregate-only expressions".into(),
6205 )),
6206 _ => Err(Error::InvalidArgumentError(format!(
6207 "CAST to {:?} is not supported in aggregate-only expressions",
6208 data_type
6209 ))),
6210 }
6211 }
6212}
6213
/// Evaluation state for expressions over a cross-product result batch.
///
/// `new` assigns every schema column a synthetic `FieldId`, starting at 1.
struct CrossProductExpressionContext {
    /// Executor-side schema with one `ExecutorColumn` per Arrow field.
    schema: Arc<ExecutorSchema>,
    /// Maps each assigned field id to the column's position in the schema.
    field_id_to_index: FxHashMap<FieldId, usize>,
    /// Per-field cache of numeric arrays; emptied by `reset`.
    numeric_cache: FxHashMap<FieldId, NumericArray>,
    /// Per-field cache of typed column accessors; emptied by `reset`.
    column_cache: FxHashMap<FieldId, ColumnAccessor>,
    /// Next unassigned field id, handed out by `allocate_synthetic_field_id`.
    next_field_id: FieldId,
}
6221
/// Typed handle to a single column, used when evaluating cross product filters.
///
/// Only the Arrow types listed here are supported; `from_array` rejects all others.
#[derive(Clone)]
enum ColumnAccessor {
    /// 64-bit signed integer column.
    Int64(Arc<Int64Array>),
    /// 64-bit floating point column.
    Float64(Arc<Float64Array>),
    /// Boolean column.
    Boolean(Arc<BooleanArray>),
    /// UTF-8 string column.
    Utf8(Arc<StringArray>),
    /// All-NULL column; only its row count is kept.
    Null(usize),
}
6230
6231impl ColumnAccessor {
6232 fn from_array(array: &ArrayRef) -> ExecutorResult<Self> {
6233 match array.data_type() {
6234 DataType::Int64 => {
6235 let typed = array
6236 .as_any()
6237 .downcast_ref::<Int64Array>()
6238 .ok_or_else(|| Error::Internal("expected Int64 array".into()))?
6239 .clone();
6240 Ok(Self::Int64(Arc::new(typed)))
6241 }
6242 DataType::Float64 => {
6243 let typed = array
6244 .as_any()
6245 .downcast_ref::<Float64Array>()
6246 .ok_or_else(|| Error::Internal("expected Float64 array".into()))?
6247 .clone();
6248 Ok(Self::Float64(Arc::new(typed)))
6249 }
6250 DataType::Boolean => {
6251 let typed = array
6252 .as_any()
6253 .downcast_ref::<BooleanArray>()
6254 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
6255 .clone();
6256 Ok(Self::Boolean(Arc::new(typed)))
6257 }
6258 DataType::Utf8 => {
6259 let typed = array
6260 .as_any()
6261 .downcast_ref::<StringArray>()
6262 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
6263 .clone();
6264 Ok(Self::Utf8(Arc::new(typed)))
6265 }
6266 DataType::Null => Ok(Self::Null(array.len())),
6267 other => Err(Error::InvalidArgumentError(format!(
6268 "unsupported column type {:?} in cross product filter",
6269 other
6270 ))),
6271 }
6272 }
6273
6274 fn len(&self) -> usize {
6275 match self {
6276 ColumnAccessor::Int64(array) => array.len(),
6277 ColumnAccessor::Float64(array) => array.len(),
6278 ColumnAccessor::Boolean(array) => array.len(),
6279 ColumnAccessor::Utf8(array) => array.len(),
6280 ColumnAccessor::Null(len) => *len,
6281 }
6282 }
6283
6284 fn is_null(&self, idx: usize) -> bool {
6285 match self {
6286 ColumnAccessor::Int64(array) => array.is_null(idx),
6287 ColumnAccessor::Float64(array) => array.is_null(idx),
6288 ColumnAccessor::Boolean(array) => array.is_null(idx),
6289 ColumnAccessor::Utf8(array) => array.is_null(idx),
6290 ColumnAccessor::Null(_) => true,
6291 }
6292 }
6293
6294 fn literal_at(&self, idx: usize) -> ExecutorResult<Literal> {
6295 if self.is_null(idx) {
6296 return Ok(Literal::Null);
6297 }
6298 match self {
6299 ColumnAccessor::Int64(array) => Ok(Literal::Integer(array.value(idx) as i128)),
6300 ColumnAccessor::Float64(array) => Ok(Literal::Float(array.value(idx))),
6301 ColumnAccessor::Boolean(array) => Ok(Literal::Boolean(array.value(idx))),
6302 ColumnAccessor::Utf8(array) => Ok(Literal::String(array.value(idx).to_string())),
6303 ColumnAccessor::Null(_) => Ok(Literal::Null),
6304 }
6305 }
6306
6307 fn as_array_ref(&self) -> ArrayRef {
6308 match self {
6309 ColumnAccessor::Int64(array) => Arc::clone(array) as ArrayRef,
6310 ColumnAccessor::Float64(array) => Arc::clone(array) as ArrayRef,
6311 ColumnAccessor::Boolean(array) => Arc::clone(array) as ArrayRef,
6312 ColumnAccessor::Utf8(array) => Arc::clone(array) as ArrayRef,
6313 ColumnAccessor::Null(len) => new_null_array(&DataType::Null, *len),
6314 }
6315 }
6316}
6317
/// Materialized operand values for a cross product expression.
///
/// Unlike `ColumnAccessor`, all integer and float widths share the single
/// `Numeric` variant.
#[derive(Clone)]
enum ValueArray {
    /// Any supported integer or floating point column, wrapped as a `NumericArray`.
    Numeric(NumericArray),
    /// Boolean values.
    Boolean(Arc<BooleanArray>),
    /// UTF-8 string values.
    Utf8(Arc<StringArray>),
    /// All-NULL values; only the row count is kept.
    Null(usize),
}
6325
6326impl ValueArray {
6327 fn from_array(array: ArrayRef) -> ExecutorResult<Self> {
6328 match array.data_type() {
6329 DataType::Boolean => {
6330 let typed = array
6331 .as_any()
6332 .downcast_ref::<BooleanArray>()
6333 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
6334 .clone();
6335 Ok(Self::Boolean(Arc::new(typed)))
6336 }
6337 DataType::Utf8 => {
6338 let typed = array
6339 .as_any()
6340 .downcast_ref::<StringArray>()
6341 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
6342 .clone();
6343 Ok(Self::Utf8(Arc::new(typed)))
6344 }
6345 DataType::Null => Ok(Self::Null(array.len())),
6346 DataType::Int8
6347 | DataType::Int16
6348 | DataType::Int32
6349 | DataType::Int64
6350 | DataType::UInt8
6351 | DataType::UInt16
6352 | DataType::UInt32
6353 | DataType::UInt64
6354 | DataType::Float32
6355 | DataType::Float64 => {
6356 let numeric = NumericArray::try_from_arrow(&array)?;
6357 Ok(Self::Numeric(numeric))
6358 }
6359 other => Err(Error::InvalidArgumentError(format!(
6360 "unsupported data type {:?} in cross product expression",
6361 other
6362 ))),
6363 }
6364 }
6365
6366 fn len(&self) -> usize {
6367 match self {
6368 ValueArray::Numeric(array) => array.len(),
6369 ValueArray::Boolean(array) => array.len(),
6370 ValueArray::Utf8(array) => array.len(),
6371 ValueArray::Null(len) => *len,
6372 }
6373 }
6374}
6375
/// Three-valued logical AND, with `None` standing for an unknown (NULL) truth value.
///
/// A definite `false` on either side decides the result; otherwise any unknown
/// operand makes the result unknown.
fn truth_and(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(false) || rhs == Some(false) {
        Some(false)
    } else if lhs.is_none() || rhs.is_none() {
        None
    } else {
        Some(true)
    }
}
6383
/// Three-valued logical OR, with `None` standing for an unknown (NULL) truth value.
///
/// A definite `true` on either side decides the result; otherwise any unknown
/// operand makes the result unknown.
fn truth_or(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(true) || rhs == Some(true) {
        Some(true)
    } else if lhs.is_none() || rhs.is_none() {
        None
    } else {
        Some(false)
    }
}
6391
/// Three-valued logical NOT: flips a known truth value and leaves unknown (`None`)
/// unchanged.
fn truth_not(value: Option<bool>) -> Option<bool> {
    value.map(|known| !known)
}
6399
6400fn compare_bool(op: CompareOp, lhs: bool, rhs: bool) -> bool {
6401 let l = lhs as u8;
6402 let r = rhs as u8;
6403 match op {
6404 CompareOp::Eq => lhs == rhs,
6405 CompareOp::NotEq => lhs != rhs,
6406 CompareOp::Lt => l < r,
6407 CompareOp::LtEq => l <= r,
6408 CompareOp::Gt => l > r,
6409 CompareOp::GtEq => l >= r,
6410 }
6411}
6412
6413fn compare_str(op: CompareOp, lhs: &str, rhs: &str) -> bool {
6414 match op {
6415 CompareOp::Eq => lhs == rhs,
6416 CompareOp::NotEq => lhs != rhs,
6417 CompareOp::Lt => lhs < rhs,
6418 CompareOp::LtEq => lhs <= rhs,
6419 CompareOp::Gt => lhs > rhs,
6420 CompareOp::GtEq => lhs >= rhs,
6421 }
6422}
6423
/// Produces the three-valued result of an `IN` / `NOT IN` list test.
///
/// * `has_match` - some list element equalled the target.
/// * `saw_null` - a NULL was encountered during the comparison.
/// * `negated` - the test is `NOT IN`.
///
/// A match always decides the result. Without a match, a NULL makes the result
/// unknown (`None`); otherwise `IN` is false and `NOT IN` is true.
fn finalize_in_list_result(has_match: bool, saw_null: bool, negated: bool) -> Option<bool> {
    match (has_match, saw_null) {
        (true, _) => Some(!negated),
        (false, true) => None,
        (false, false) => Some(negated),
    }
}
6435
6436fn literal_to_constant_array(literal: &Literal, len: usize) -> ExecutorResult<ArrayRef> {
6437 match literal {
6438 Literal::Integer(v) => {
6439 let value = i64::try_from(*v).unwrap_or(0);
6440 let values = vec![value; len];
6441 Ok(Arc::new(Int64Array::from(values)) as ArrayRef)
6442 }
6443 Literal::Float(v) => {
6444 let values = vec![*v; len];
6445 Ok(Arc::new(Float64Array::from(values)) as ArrayRef)
6446 }
6447 Literal::Boolean(v) => {
6448 let values = vec![Some(*v); len];
6449 Ok(Arc::new(BooleanArray::from(values)) as ArrayRef)
6450 }
6451 Literal::String(v) => {
6452 let values: Vec<Option<String>> = (0..len).map(|_| Some(v.clone())).collect();
6453 Ok(Arc::new(StringArray::from(values)) as ArrayRef)
6454 }
6455 Literal::Null => Ok(new_null_array(&DataType::Null, len)),
6456 Literal::Struct(_) => Err(Error::InvalidArgumentError(
6457 "struct literals are not supported in cross product filters".into(),
6458 )),
6459 }
6460}
6461
6462impl CrossProductExpressionContext {
6463 fn new(schema: &Schema, lookup: FxHashMap<String, usize>) -> ExecutorResult<Self> {
6464 let mut columns = Vec::with_capacity(schema.fields().len());
6465 let mut field_id_to_index = FxHashMap::default();
6466 let mut next_field_id: FieldId = 1;
6467
6468 for (idx, field) in schema.fields().iter().enumerate() {
6469 if next_field_id == u32::MAX {
6470 return Err(Error::Internal(
6471 "cross product projection exhausted FieldId space".into(),
6472 ));
6473 }
6474
6475 let executor_column = ExecutorColumn {
6476 name: field.name().clone(),
6477 data_type: field.data_type().clone(),
6478 nullable: field.is_nullable(),
6479 primary_key: false,
6480 unique: false,
6481 field_id: next_field_id,
6482 check_expr: None,
6483 };
6484 let field_id = next_field_id;
6485 next_field_id = next_field_id.saturating_add(1);
6486
6487 columns.push(executor_column);
6488 field_id_to_index.insert(field_id, idx);
6489 }
6490
6491 Ok(Self {
6492 schema: Arc::new(ExecutorSchema { columns, lookup }),
6493 field_id_to_index,
6494 numeric_cache: FxHashMap::default(),
6495 column_cache: FxHashMap::default(),
6496 next_field_id,
6497 })
6498 }
6499
6500 fn schema(&self) -> &ExecutorSchema {
6501 self.schema.as_ref()
6502 }
6503
6504 fn field_id_for_column(&self, name: &str) -> Option<FieldId> {
6505 self.schema.resolve(name).map(|column| column.field_id)
6506 }
6507
6508 fn reset(&mut self) {
6509 self.numeric_cache.clear();
6510 self.column_cache.clear();
6511 }
6512
6513 fn allocate_synthetic_field_id(&mut self) -> ExecutorResult<FieldId> {
6514 if self.next_field_id == FieldId::MAX {
6515 return Err(Error::Internal(
6516 "cross product projection exhausted FieldId space".into(),
6517 ));
6518 }
6519 let field_id = self.next_field_id;
6520 self.next_field_id = self.next_field_id.saturating_add(1);
6521 Ok(field_id)
6522 }
6523
    /// Test-only helper: translates a name-based scalar expression against this
    /// context's schema, then evaluates it over `batch` via `evaluate_numeric`.
    ///
    /// The closure passed to `translate_scalar` builds the
    /// `Error::InvalidArgumentError` reported for a column name.
    /// NOTE(review): presumably invoked for names the schema cannot resolve —
    /// confirm against `translate_scalar`.
    #[cfg(test)]
    fn evaluate(
        &mut self,
        expr: &ScalarExpr<String>,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let translated = translate_scalar(expr, self.schema.as_ref(), |name| {
            Error::InvalidArgumentError(format!(
                "column '{}' not found in cross product result",
                name
            ))
        })?;

        self.evaluate_numeric(&translated, batch)
    }
6539
6540 fn evaluate_predicate_mask(
6541 &mut self,
6542 expr: &LlkvExpr<'static, FieldId>,
6543 batch: &RecordBatch,
6544 mut exists_eval: impl FnMut(
6545 &mut Self,
6546 &llkv_expr::SubqueryExpr,
6547 usize,
6548 &RecordBatch,
6549 ) -> ExecutorResult<Option<bool>>,
6550 ) -> ExecutorResult<BooleanArray> {
6551 let truths = self.evaluate_predicate_truths(expr, batch, &mut exists_eval)?;
6552 let mut builder = BooleanBuilder::with_capacity(truths.len());
6553 for value in truths {
6554 builder.append_value(value.unwrap_or(false));
6555 }
6556 Ok(builder.finish())
6557 }
6558
6559 fn evaluate_predicate_truths(
6560 &mut self,
6561 expr: &LlkvExpr<'static, FieldId>,
6562 batch: &RecordBatch,
6563 exists_eval: &mut impl FnMut(
6564 &mut Self,
6565 &llkv_expr::SubqueryExpr,
6566 usize,
6567 &RecordBatch,
6568 ) -> ExecutorResult<Option<bool>>,
6569 ) -> ExecutorResult<Vec<Option<bool>>> {
6570 match expr {
6571 LlkvExpr::Literal(value) => Ok(vec![Some(*value); batch.num_rows()]),
6572 LlkvExpr::And(children) => {
6573 if children.is_empty() {
6574 return Ok(vec![Some(true); batch.num_rows()]);
6575 }
6576 let mut result =
6577 self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
6578 for child in &children[1..] {
6579 let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
6580 for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
6581 *lhs = truth_and(*lhs, rhs);
6582 }
6583 }
6584 Ok(result)
6585 }
6586 LlkvExpr::Or(children) => {
6587 if children.is_empty() {
6588 return Ok(vec![Some(false); batch.num_rows()]);
6589 }
6590 let mut result =
6591 self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
6592 for child in &children[1..] {
6593 let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
6594 for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
6595 *lhs = truth_or(*lhs, rhs);
6596 }
6597 }
6598 Ok(result)
6599 }
6600 LlkvExpr::Not(inner) => {
6601 let mut values = self.evaluate_predicate_truths(inner, batch, exists_eval)?;
6602 for value in &mut values {
6603 *value = truth_not(*value);
6604 }
6605 Ok(values)
6606 }
6607 LlkvExpr::Pred(filter) => self.evaluate_filter_truths(filter, batch),
6608 LlkvExpr::Compare { left, op, right } => {
6609 self.evaluate_compare_truths(left, *op, right, batch)
6610 }
6611 LlkvExpr::InList {
6612 expr: target,
6613 list,
6614 negated,
6615 } => self.evaluate_in_list_truths(target, list, *negated, batch),
6616 LlkvExpr::IsNull { expr, negated } => {
6617 self.evaluate_is_null_truths(expr, *negated, batch)
6618 }
6619 LlkvExpr::Exists(subquery_expr) => {
6620 let mut values = Vec::with_capacity(batch.num_rows());
6621 for row_idx in 0..batch.num_rows() {
6622 let value = exists_eval(self, subquery_expr, row_idx, batch)?;
6623 values.push(value);
6624 }
6625 Ok(values)
6626 }
6627 }
6628 }
6629
6630 fn evaluate_filter_truths(
6631 &mut self,
6632 filter: &Filter<FieldId>,
6633 batch: &RecordBatch,
6634 ) -> ExecutorResult<Vec<Option<bool>>> {
6635 let accessor = self.column_accessor(filter.field_id, batch)?;
6636 let len = accessor.len();
6637
6638 match &filter.op {
6639 Operator::IsNull => {
6640 let mut out = Vec::with_capacity(len);
6641 for idx in 0..len {
6642 out.push(Some(accessor.is_null(idx)));
6643 }
6644 Ok(out)
6645 }
6646 Operator::IsNotNull => {
6647 let mut out = Vec::with_capacity(len);
6648 for idx in 0..len {
6649 out.push(Some(!accessor.is_null(idx)));
6650 }
6651 Ok(out)
6652 }
6653 _ => match accessor {
6654 ColumnAccessor::Int64(array) => {
6655 let predicate = build_fixed_width_predicate::<Int64Type>(&filter.op)
6656 .map_err(Error::predicate_build)?;
6657 let mut out = Vec::with_capacity(len);
6658 for idx in 0..len {
6659 if array.is_null(idx) {
6660 out.push(None);
6661 } else {
6662 let value = array.value(idx);
6663 out.push(Some(predicate.matches(&value)));
6664 }
6665 }
6666 Ok(out)
6667 }
6668 ColumnAccessor::Float64(array) => {
6669 let predicate = build_fixed_width_predicate::<Float64Type>(&filter.op)
6670 .map_err(Error::predicate_build)?;
6671 let mut out = Vec::with_capacity(len);
6672 for idx in 0..len {
6673 if array.is_null(idx) {
6674 out.push(None);
6675 } else {
6676 let value = array.value(idx);
6677 out.push(Some(predicate.matches(&value)));
6678 }
6679 }
6680 Ok(out)
6681 }
6682 ColumnAccessor::Boolean(array) => {
6683 let predicate =
6684 build_bool_predicate(&filter.op).map_err(Error::predicate_build)?;
6685 let mut out = Vec::with_capacity(len);
6686 for idx in 0..len {
6687 if array.is_null(idx) {
6688 out.push(None);
6689 } else {
6690 let value = array.value(idx);
6691 out.push(Some(predicate.matches(&value)));
6692 }
6693 }
6694 Ok(out)
6695 }
6696 ColumnAccessor::Utf8(array) => {
6697 let predicate =
6698 build_var_width_predicate(&filter.op).map_err(Error::predicate_build)?;
6699 let mut out = Vec::with_capacity(len);
6700 for idx in 0..len {
6701 if array.is_null(idx) {
6702 out.push(None);
6703 } else {
6704 let value = array.value(idx);
6705 out.push(Some(predicate.matches(value)));
6706 }
6707 }
6708 Ok(out)
6709 }
6710 ColumnAccessor::Null(len) => Ok(vec![None; len]),
6711 },
6712 }
6713 }
6714
6715 fn evaluate_compare_truths(
6716 &mut self,
6717 left: &ScalarExpr<FieldId>,
6718 op: CompareOp,
6719 right: &ScalarExpr<FieldId>,
6720 batch: &RecordBatch,
6721 ) -> ExecutorResult<Vec<Option<bool>>> {
6722 let left_values = self.materialize_value_array(left, batch)?;
6723 let right_values = self.materialize_value_array(right, batch)?;
6724
6725 if left_values.len() != right_values.len() {
6726 return Err(Error::Internal(
6727 "mismatched compare operand lengths in cross product filter".into(),
6728 ));
6729 }
6730
6731 let len = left_values.len();
6732 match (&left_values, &right_values) {
6733 (ValueArray::Null(_), _) | (_, ValueArray::Null(_)) => Ok(vec![None; len]),
6734 (ValueArray::Numeric(lhs), ValueArray::Numeric(rhs)) => {
6735 let mut out = Vec::with_capacity(len);
6736 for idx in 0..len {
6737 match (lhs.value(idx), rhs.value(idx)) {
6738 (Some(lv), Some(rv)) => out.push(Some(NumericKernels::compare(op, lv, rv))),
6739 _ => out.push(None),
6740 }
6741 }
6742 Ok(out)
6743 }
6744 (ValueArray::Boolean(lhs), ValueArray::Boolean(rhs)) => {
6745 let lhs = lhs.as_ref();
6746 let rhs = rhs.as_ref();
6747 let mut out = Vec::with_capacity(len);
6748 for idx in 0..len {
6749 if lhs.is_null(idx) || rhs.is_null(idx) {
6750 out.push(None);
6751 } else {
6752 out.push(Some(compare_bool(op, lhs.value(idx), rhs.value(idx))));
6753 }
6754 }
6755 Ok(out)
6756 }
6757 (ValueArray::Utf8(lhs), ValueArray::Utf8(rhs)) => {
6758 let lhs = lhs.as_ref();
6759 let rhs = rhs.as_ref();
6760 let mut out = Vec::with_capacity(len);
6761 for idx in 0..len {
6762 if lhs.is_null(idx) || rhs.is_null(idx) {
6763 out.push(None);
6764 } else {
6765 out.push(Some(compare_str(op, lhs.value(idx), rhs.value(idx))));
6766 }
6767 }
6768 Ok(out)
6769 }
6770 _ => Err(Error::InvalidArgumentError(
6771 "unsupported comparison between mismatched types in cross product filter".into(),
6772 )),
6773 }
6774 }
6775
6776 fn evaluate_is_null_truths(
6777 &mut self,
6778 expr: &ScalarExpr<FieldId>,
6779 negated: bool,
6780 batch: &RecordBatch,
6781 ) -> ExecutorResult<Vec<Option<bool>>> {
6782 let values = self.materialize_value_array(expr, batch)?;
6783 let len = values.len();
6784
6785 match &values {
6786 ValueArray::Null(len) => {
6787 let result = if negated {
6789 Some(false) } else {
6791 Some(true) };
6793 Ok(vec![result; *len])
6794 }
6795 ValueArray::Numeric(arr) => {
6796 let mut out = Vec::with_capacity(len);
6797 for idx in 0..len {
6798 let is_null = arr.value(idx).is_none();
6799 let result = if negated {
6800 !is_null } else {
6802 is_null };
6804 out.push(Some(result));
6805 }
6806 Ok(out)
6807 }
6808 ValueArray::Boolean(arr) => {
6809 let mut out = Vec::with_capacity(len);
6810 for idx in 0..len {
6811 let is_null = arr.is_null(idx);
6812 let result = if negated { !is_null } else { is_null };
6813 out.push(Some(result));
6814 }
6815 Ok(out)
6816 }
6817 ValueArray::Utf8(arr) => {
6818 let mut out = Vec::with_capacity(len);
6819 for idx in 0..len {
6820 let is_null = arr.is_null(idx);
6821 let result = if negated { !is_null } else { is_null };
6822 out.push(Some(result));
6823 }
6824 Ok(out)
6825 }
6826 }
6827 }
6828
6829 fn evaluate_in_list_truths(
6830 &mut self,
6831 target: &ScalarExpr<FieldId>,
6832 list: &[ScalarExpr<FieldId>],
6833 negated: bool,
6834 batch: &RecordBatch,
6835 ) -> ExecutorResult<Vec<Option<bool>>> {
6836 let target_values = self.materialize_value_array(target, batch)?;
6837 let list_values = list
6838 .iter()
6839 .map(|expr| self.materialize_value_array(expr, batch))
6840 .collect::<ExecutorResult<Vec<_>>>()?;
6841
6842 let len = target_values.len();
6843 for values in &list_values {
6844 if values.len() != len {
6845 return Err(Error::Internal(
6846 "mismatched IN list operand lengths in cross product filter".into(),
6847 ));
6848 }
6849 }
6850
6851 match &target_values {
6852 ValueArray::Numeric(target_numeric) => {
6853 let mut out = Vec::with_capacity(len);
6854 for idx in 0..len {
6855 let target_value = match target_numeric.value(idx) {
6856 Some(value) => value,
6857 None => {
6858 out.push(None);
6859 continue;
6860 }
6861 };
6862 let mut has_match = false;
6863 let mut saw_null = false;
6864 for candidate in &list_values {
6865 match candidate {
6866 ValueArray::Numeric(array) => match array.value(idx) {
6867 Some(value) => {
6868 if NumericKernels::compare(CompareOp::Eq, target_value, value) {
6869 has_match = true;
6870 break;
6871 }
6872 }
6873 None => saw_null = true,
6874 },
6875 ValueArray::Null(_) => saw_null = true,
6876 _ => {
6877 return Err(Error::InvalidArgumentError(
6878 "type mismatch in IN list evaluation".into(),
6879 ));
6880 }
6881 }
6882 }
6883 out.push(finalize_in_list_result(has_match, saw_null, negated));
6884 }
6885 Ok(out)
6886 }
6887 ValueArray::Boolean(target_bool) => {
6888 let mut out = Vec::with_capacity(len);
6889 for idx in 0..len {
6890 if target_bool.is_null(idx) {
6891 out.push(None);
6892 continue;
6893 }
6894 let target_value = target_bool.value(idx);
6895 let mut has_match = false;
6896 let mut saw_null = false;
6897 for candidate in &list_values {
6898 match candidate {
6899 ValueArray::Boolean(array) => {
6900 if array.is_null(idx) {
6901 saw_null = true;
6902 } else if array.value(idx) == target_value {
6903 has_match = true;
6904 break;
6905 }
6906 }
6907 ValueArray::Null(_) => saw_null = true,
6908 _ => {
6909 return Err(Error::InvalidArgumentError(
6910 "type mismatch in IN list evaluation".into(),
6911 ));
6912 }
6913 }
6914 }
6915 out.push(finalize_in_list_result(has_match, saw_null, negated));
6916 }
6917 Ok(out)
6918 }
6919 ValueArray::Utf8(target_utf8) => {
6920 let mut out = Vec::with_capacity(len);
6921 for idx in 0..len {
6922 if target_utf8.is_null(idx) {
6923 out.push(None);
6924 continue;
6925 }
6926 let target_value = target_utf8.value(idx);
6927 let mut has_match = false;
6928 let mut saw_null = false;
6929 for candidate in &list_values {
6930 match candidate {
6931 ValueArray::Utf8(array) => {
6932 if array.is_null(idx) {
6933 saw_null = true;
6934 } else if array.value(idx) == target_value {
6935 has_match = true;
6936 break;
6937 }
6938 }
6939 ValueArray::Null(_) => saw_null = true,
6940 _ => {
6941 return Err(Error::InvalidArgumentError(
6942 "type mismatch in IN list evaluation".into(),
6943 ));
6944 }
6945 }
6946 }
6947 out.push(finalize_in_list_result(has_match, saw_null, negated));
6948 }
6949 Ok(out)
6950 }
6951 ValueArray::Null(len) => Ok(vec![None; *len]),
6952 }
6953 }
6954
6955 fn evaluate_numeric(
6956 &mut self,
6957 expr: &ScalarExpr<FieldId>,
6958 batch: &RecordBatch,
6959 ) -> ExecutorResult<ArrayRef> {
6960 let mut required = FxHashSet::default();
6961 collect_field_ids(expr, &mut required);
6962
6963 let mut arrays = NumericArrayMap::default();
6964 for field_id in required {
6965 let numeric = self.numeric_array(field_id, batch)?;
6966 arrays.insert(field_id, numeric);
6967 }
6968
6969 NumericKernels::evaluate_batch(expr, batch.num_rows(), &arrays)
6970 }
6971
6972 fn numeric_array(
6973 &mut self,
6974 field_id: FieldId,
6975 batch: &RecordBatch,
6976 ) -> ExecutorResult<NumericArray> {
6977 if let Some(existing) = self.numeric_cache.get(&field_id) {
6978 return Ok(existing.clone());
6979 }
6980
6981 let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
6982 Error::Internal("field mapping missing during cross product evaluation".into())
6983 })?;
6984
6985 let array_ref = batch.column(column_index).clone();
6986 let numeric = NumericArray::try_from_arrow(&array_ref)?;
6987 self.numeric_cache.insert(field_id, numeric.clone());
6988 Ok(numeric)
6989 }
6990
6991 fn column_accessor(
6992 &mut self,
6993 field_id: FieldId,
6994 batch: &RecordBatch,
6995 ) -> ExecutorResult<ColumnAccessor> {
6996 if let Some(existing) = self.column_cache.get(&field_id) {
6997 return Ok(existing.clone());
6998 }
6999
7000 let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
7001 Error::Internal("field mapping missing during cross product evaluation".into())
7002 })?;
7003
7004 let accessor = ColumnAccessor::from_array(batch.column(column_index))?;
7005 self.column_cache.insert(field_id, accessor.clone());
7006 Ok(accessor)
7007 }
7008
7009 fn materialize_scalar_array(
7010 &mut self,
7011 expr: &ScalarExpr<FieldId>,
7012 batch: &RecordBatch,
7013 ) -> ExecutorResult<ArrayRef> {
7014 match expr {
7015 ScalarExpr::Column(field_id) => {
7016 let accessor = self.column_accessor(*field_id, batch)?;
7017 Ok(accessor.as_array_ref())
7018 }
7019 ScalarExpr::Literal(literal) => literal_to_constant_array(literal, batch.num_rows()),
7020 ScalarExpr::Binary { .. } => self.evaluate_numeric(expr, batch),
7021 ScalarExpr::Compare { .. } => self.evaluate_numeric(expr, batch),
7022 ScalarExpr::Not(_) => self.evaluate_numeric(expr, batch),
7023 ScalarExpr::IsNull { .. } => self.evaluate_numeric(expr, batch),
7024 ScalarExpr::Aggregate(_) => Err(Error::InvalidArgumentError(
7025 "aggregate expressions are not supported in cross product filters".into(),
7026 )),
7027 ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
7028 "struct field access is not supported in cross product filters".into(),
7029 )),
7030 ScalarExpr::Cast { expr, data_type } => {
7031 let source = self.materialize_scalar_array(expr.as_ref(), batch)?;
7032 let casted = cast(source.as_ref(), data_type).map_err(|err| {
7033 Error::InvalidArgumentError(format!("failed to cast expression: {err}"))
7034 })?;
7035 Ok(casted)
7036 }
7037 ScalarExpr::Case { .. } => self.evaluate_numeric(expr, batch),
7038 ScalarExpr::Coalesce(_) => self.evaluate_numeric(expr, batch),
7039 ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
7040 "scalar subqueries are not supported in cross product filters".into(),
7041 )),
7042 }
7043 }
7044
7045 fn materialize_value_array(
7046 &mut self,
7047 expr: &ScalarExpr<FieldId>,
7048 batch: &RecordBatch,
7049 ) -> ExecutorResult<ValueArray> {
7050 let array = self.materialize_scalar_array(expr, batch)?;
7051 ValueArray::from_array(array)
7052 }
7053}
7054
7055fn collect_field_ids(expr: &ScalarExpr<FieldId>, out: &mut FxHashSet<FieldId>) {
7057 match expr {
7058 ScalarExpr::Column(fid) => {
7059 out.insert(*fid);
7060 }
7061 ScalarExpr::Binary { left, right, .. } => {
7062 collect_field_ids(left, out);
7063 collect_field_ids(right, out);
7064 }
7065 ScalarExpr::Compare { left, right, .. } => {
7066 collect_field_ids(left, out);
7067 collect_field_ids(right, out);
7068 }
7069 ScalarExpr::Aggregate(call) => match call {
7070 AggregateCall::CountStar => {}
7071 AggregateCall::Count { expr, .. }
7072 | AggregateCall::Sum { expr, .. }
7073 | AggregateCall::Total { expr, .. }
7074 | AggregateCall::Avg { expr, .. }
7075 | AggregateCall::Min(expr)
7076 | AggregateCall::Max(expr)
7077 | AggregateCall::CountNulls(expr)
7078 | AggregateCall::GroupConcat { expr, .. } => {
7079 collect_field_ids(expr, out);
7080 }
7081 },
7082 ScalarExpr::GetField { base, .. } => collect_field_ids(base, out),
7083 ScalarExpr::Cast { expr, .. } => collect_field_ids(expr, out),
7084 ScalarExpr::Not(expr) => collect_field_ids(expr, out),
7085 ScalarExpr::IsNull { expr, .. } => collect_field_ids(expr, out),
7086 ScalarExpr::Case {
7087 operand,
7088 branches,
7089 else_expr,
7090 } => {
7091 if let Some(inner) = operand.as_deref() {
7092 collect_field_ids(inner, out);
7093 }
7094 for (when_expr, then_expr) in branches {
7095 collect_field_ids(when_expr, out);
7096 collect_field_ids(then_expr, out);
7097 }
7098 if let Some(inner) = else_expr.as_deref() {
7099 collect_field_ids(inner, out);
7100 }
7101 }
7102 ScalarExpr::Coalesce(items) => {
7103 for item in items {
7104 collect_field_ids(item, out);
7105 }
7106 }
7107 ScalarExpr::Literal(_) => {}
7108 ScalarExpr::ScalarSubquery(_) => {}
7109 }
7110}
7111
7112fn strip_exists(expr: &LlkvExpr<'static, FieldId>) -> LlkvExpr<'static, FieldId> {
7113 match expr {
7114 LlkvExpr::And(children) => LlkvExpr::And(children.iter().map(strip_exists).collect()),
7115 LlkvExpr::Or(children) => LlkvExpr::Or(children.iter().map(strip_exists).collect()),
7116 LlkvExpr::Not(inner) => LlkvExpr::Not(Box::new(strip_exists(inner))),
7117 LlkvExpr::Pred(filter) => LlkvExpr::Pred(filter.clone()),
7118 LlkvExpr::Compare { left, op, right } => LlkvExpr::Compare {
7119 left: left.clone(),
7120 op: *op,
7121 right: right.clone(),
7122 },
7123 LlkvExpr::InList {
7124 expr,
7125 list,
7126 negated,
7127 } => LlkvExpr::InList {
7128 expr: expr.clone(),
7129 list: list.clone(),
7130 negated: *negated,
7131 },
7132 LlkvExpr::IsNull { expr, negated } => LlkvExpr::IsNull {
7133 expr: expr.clone(),
7134 negated: *negated,
7135 },
7136 LlkvExpr::Literal(value) => LlkvExpr::Literal(*value),
7137 LlkvExpr::Exists(_) => LlkvExpr::Literal(true),
7138 }
7139}
7140
7141fn bind_select_plan(
7142 plan: &SelectPlan,
7143 bindings: &FxHashMap<String, Literal>,
7144) -> ExecutorResult<SelectPlan> {
7145 if bindings.is_empty() {
7146 return Ok(plan.clone());
7147 }
7148
7149 let projections = plan
7150 .projections
7151 .iter()
7152 .map(|projection| bind_projection(projection, bindings))
7153 .collect::<ExecutorResult<Vec<_>>>()?;
7154
7155 let filter = match &plan.filter {
7156 Some(wrapper) => Some(bind_select_filter(wrapper, bindings)?),
7157 None => None,
7158 };
7159
7160 let aggregates = plan
7161 .aggregates
7162 .iter()
7163 .map(|aggregate| bind_aggregate_expr(aggregate, bindings))
7164 .collect::<ExecutorResult<Vec<_>>>()?;
7165
7166 let scalar_subqueries = plan
7167 .scalar_subqueries
7168 .iter()
7169 .map(|subquery| bind_scalar_subquery(subquery, bindings))
7170 .collect::<ExecutorResult<Vec<_>>>()?;
7171
7172 if let Some(compound) = &plan.compound {
7173 let bound_compound = bind_compound_select(compound, bindings)?;
7174 return Ok(SelectPlan {
7175 tables: Vec::new(),
7176 joins: Vec::new(),
7177 projections: Vec::new(),
7178 filter: None,
7179 having: None,
7180 aggregates: Vec::new(),
7181 order_by: plan.order_by.clone(),
7182 distinct: false,
7183 scalar_subqueries: Vec::new(),
7184 compound: Some(bound_compound),
7185 group_by: Vec::new(),
7186 value_table_mode: None,
7187 });
7188 }
7189
7190 Ok(SelectPlan {
7191 tables: plan.tables.clone(),
7192 joins: plan.joins.clone(),
7193 projections,
7194 filter,
7195 having: plan.having.clone(),
7196 aggregates,
7197 order_by: Vec::new(),
7198 distinct: plan.distinct,
7199 scalar_subqueries,
7200 compound: None,
7201 group_by: plan.group_by.clone(),
7202 value_table_mode: plan.value_table_mode.clone(),
7203 })
7204}
7205
7206fn bind_compound_select(
7207 compound: &CompoundSelectPlan,
7208 bindings: &FxHashMap<String, Literal>,
7209) -> ExecutorResult<CompoundSelectPlan> {
7210 let initial = bind_select_plan(&compound.initial, bindings)?;
7211 let mut operations = Vec::with_capacity(compound.operations.len());
7212 for component in &compound.operations {
7213 let bound_plan = bind_select_plan(&component.plan, bindings)?;
7214 operations.push(CompoundSelectComponent {
7215 operator: component.operator.clone(),
7216 quantifier: component.quantifier.clone(),
7217 plan: bound_plan,
7218 });
7219 }
7220 Ok(CompoundSelectPlan {
7221 initial: Box::new(initial),
7222 operations,
7223 })
7224}
7225
7226fn ensure_schema_compatibility(base: &Schema, other: &Schema) -> ExecutorResult<()> {
7227 if base.fields().len() != other.fields().len() {
7228 return Err(Error::InvalidArgumentError(
7229 "compound SELECT requires matching column counts".into(),
7230 ));
7231 }
7232 for (left, right) in base.fields().iter().zip(other.fields().iter()) {
7233 if left.data_type() != right.data_type() {
7234 return Err(Error::InvalidArgumentError(format!(
7235 "compound SELECT column type mismatch: {} vs {}",
7236 left.data_type(),
7237 right.data_type()
7238 )));
7239 }
7240 }
7241 Ok(())
7242}
7243
7244fn ensure_distinct_rows(rows: &mut Vec<Vec<PlanValue>>, cache: &mut Option<FxHashSet<Vec<u8>>>) {
7245 if cache.is_some() {
7246 return;
7247 }
7248 let mut set = FxHashSet::default();
7249 let mut deduped: Vec<Vec<PlanValue>> = Vec::with_capacity(rows.len());
7250 for row in rows.drain(..) {
7251 let key = encode_row(&row);
7252 if set.insert(key) {
7253 deduped.push(row);
7254 }
7255 }
7256 *rows = deduped;
7257 *cache = Some(set);
7258}
7259
7260fn encode_row(row: &[PlanValue]) -> Vec<u8> {
7261 let mut buf = Vec::new();
7262 for value in row {
7263 encode_plan_value(&mut buf, value);
7264 buf.push(0x1F);
7265 }
7266 buf
7267}
7268
7269fn encode_plan_value(buf: &mut Vec<u8>, value: &PlanValue) {
7270 match value {
7271 PlanValue::Null => buf.push(0),
7272 PlanValue::Integer(v) => {
7273 buf.push(1);
7274 buf.extend_from_slice(&v.to_be_bytes());
7275 }
7276 PlanValue::Float(v) => {
7277 buf.push(2);
7278 buf.extend_from_slice(&v.to_bits().to_be_bytes());
7279 }
7280 PlanValue::String(s) => {
7281 buf.push(3);
7282 let bytes = s.as_bytes();
7283 let len = u32::try_from(bytes.len()).unwrap_or(u32::MAX);
7284 buf.extend_from_slice(&len.to_be_bytes());
7285 buf.extend_from_slice(bytes);
7286 }
7287 PlanValue::Struct(map) => {
7288 buf.push(4);
7289 let mut entries: Vec<_> = map.iter().collect();
7290 entries.sort_by(|a, b| a.0.cmp(b.0));
7291 let len = u32::try_from(entries.len()).unwrap_or(u32::MAX);
7292 buf.extend_from_slice(&len.to_be_bytes());
7293 for (key, val) in entries {
7294 let key_bytes = key.as_bytes();
7295 let key_len = u32::try_from(key_bytes.len()).unwrap_or(u32::MAX);
7296 buf.extend_from_slice(&key_len.to_be_bytes());
7297 buf.extend_from_slice(key_bytes);
7298 encode_plan_value(buf, val);
7299 }
7300 }
7301 }
7302}
7303
7304fn rows_to_record_batch(
7305 schema: Arc<Schema>,
7306 rows: &[Vec<PlanValue>],
7307) -> ExecutorResult<RecordBatch> {
7308 let column_count = schema.fields().len();
7309 let mut columns: Vec<Vec<PlanValue>> = vec![Vec::with_capacity(rows.len()); column_count];
7310 for row in rows {
7311 if row.len() != column_count {
7312 return Err(Error::InvalidArgumentError(
7313 "compound SELECT produced mismatched column counts".into(),
7314 ));
7315 }
7316 for (idx, value) in row.iter().enumerate() {
7317 columns[idx].push(value.clone());
7318 }
7319 }
7320
7321 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_count);
7322 for (idx, field) in schema.fields().iter().enumerate() {
7323 let array = build_array_for_column(field.data_type(), &columns[idx])?;
7324 arrays.push(array);
7325 }
7326
7327 RecordBatch::try_new(schema, arrays).map_err(|err| {
7328 Error::InvalidArgumentError(format!("failed to materialize compound SELECT: {err}"))
7329 })
7330}
7331
7332fn build_column_lookup_map(schema: &Schema) -> FxHashMap<String, usize> {
7333 let mut lookup = FxHashMap::default();
7334 for (idx, field) in schema.fields().iter().enumerate() {
7335 lookup.insert(field.name().to_ascii_lowercase(), idx);
7336 }
7337 lookup
7338}
7339
7340fn build_group_key(
7341 batch: &RecordBatch,
7342 row_idx: usize,
7343 key_indices: &[usize],
7344) -> ExecutorResult<Vec<GroupKeyValue>> {
7345 let mut values = Vec::with_capacity(key_indices.len());
7346 for &index in key_indices {
7347 values.push(group_key_value(batch.column(index), row_idx)?);
7348 }
7349 Ok(values)
7350}
7351
7352fn group_key_value(array: &ArrayRef, row_idx: usize) -> ExecutorResult<GroupKeyValue> {
7353 if !array.is_valid(row_idx) {
7354 return Ok(GroupKeyValue::Null);
7355 }
7356
7357 match array.data_type() {
7358 DataType::Int8 => {
7359 let values = array
7360 .as_any()
7361 .downcast_ref::<Int8Array>()
7362 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
7363 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7364 }
7365 DataType::Int16 => {
7366 let values = array
7367 .as_any()
7368 .downcast_ref::<Int16Array>()
7369 .ok_or_else(|| Error::Internal("failed to downcast to Int16Array".into()))?;
7370 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7371 }
7372 DataType::Int32 => {
7373 let values = array
7374 .as_any()
7375 .downcast_ref::<Int32Array>()
7376 .ok_or_else(|| Error::Internal("failed to downcast to Int32Array".into()))?;
7377 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7378 }
7379 DataType::Int64 => {
7380 let values = array
7381 .as_any()
7382 .downcast_ref::<Int64Array>()
7383 .ok_or_else(|| Error::Internal("failed to downcast to Int64Array".into()))?;
7384 Ok(GroupKeyValue::Int(values.value(row_idx)))
7385 }
7386 DataType::UInt8 => {
7387 let values = array
7388 .as_any()
7389 .downcast_ref::<UInt8Array>()
7390 .ok_or_else(|| Error::Internal("failed to downcast to UInt8Array".into()))?;
7391 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7392 }
7393 DataType::UInt16 => {
7394 let values = array
7395 .as_any()
7396 .downcast_ref::<UInt16Array>()
7397 .ok_or_else(|| Error::Internal("failed to downcast to UInt16Array".into()))?;
7398 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7399 }
7400 DataType::UInt32 => {
7401 let values = array
7402 .as_any()
7403 .downcast_ref::<UInt32Array>()
7404 .ok_or_else(|| Error::Internal("failed to downcast to UInt32Array".into()))?;
7405 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
7406 }
7407 DataType::UInt64 => {
7408 let values = array
7409 .as_any()
7410 .downcast_ref::<UInt64Array>()
7411 .ok_or_else(|| Error::Internal("failed to downcast to UInt64Array".into()))?;
7412 let value = values.value(row_idx);
7413 if value > i64::MAX as u64 {
7414 return Err(Error::InvalidArgumentError(
7415 "GROUP BY value exceeds supported integer range".into(),
7416 ));
7417 }
7418 Ok(GroupKeyValue::Int(value as i64))
7419 }
7420 DataType::Boolean => {
7421 let values = array
7422 .as_any()
7423 .downcast_ref::<BooleanArray>()
7424 .ok_or_else(|| Error::Internal("failed to downcast to BooleanArray".into()))?;
7425 Ok(GroupKeyValue::Bool(values.value(row_idx)))
7426 }
7427 DataType::Utf8 => {
7428 let values = array
7429 .as_any()
7430 .downcast_ref::<StringArray>()
7431 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
7432 Ok(GroupKeyValue::String(values.value(row_idx).to_string()))
7433 }
7434 other => Err(Error::InvalidArgumentError(format!(
7435 "GROUP BY does not support column type {:?}",
7436 other
7437 ))),
7438 }
7439}
7440
7441fn evaluate_constant_predicate(expr: &LlkvExpr<'static, String>) -> Option<Option<bool>> {
7442 match expr {
7443 LlkvExpr::Literal(value) => Some(Some(*value)),
7444 LlkvExpr::Not(inner) => {
7445 let inner_val = evaluate_constant_predicate(inner)?;
7446 Some(truth_not(inner_val))
7447 }
7448 LlkvExpr::And(children) => {
7449 let mut acc = Some(true);
7450 for child in children {
7451 let child_val = evaluate_constant_predicate(child)?;
7452 acc = truth_and(acc, child_val);
7453 }
7454 Some(acc)
7455 }
7456 LlkvExpr::Or(children) => {
7457 let mut acc = Some(false);
7458 for child in children {
7459 let child_val = evaluate_constant_predicate(child)?;
7460 acc = truth_or(acc, child_val);
7461 }
7462 Some(acc)
7463 }
7464 LlkvExpr::Compare { left, op, right } => {
7465 let left_literal = evaluate_constant_scalar(left)?;
7466 let right_literal = evaluate_constant_scalar(right)?;
7467 Some(compare_literals(*op, &left_literal, &right_literal))
7468 }
7469 LlkvExpr::IsNull { expr, negated } => {
7470 let literal = evaluate_constant_scalar(expr)?;
7471 let is_null = matches!(literal, Literal::Null);
7472 Some(Some(if *negated { !is_null } else { is_null }))
7473 }
7474 LlkvExpr::InList {
7475 expr,
7476 list,
7477 negated,
7478 } => {
7479 let needle = evaluate_constant_scalar(expr)?;
7480 let mut saw_unknown = false;
7481
7482 for candidate in list {
7483 let value = evaluate_constant_scalar(candidate)?;
7484 match compare_literals(CompareOp::Eq, &needle, &value) {
7485 Some(true) => {
7486 return Some(Some(!*negated));
7487 }
7488 Some(false) => {}
7489 None => saw_unknown = true,
7490 }
7491 }
7492
7493 if saw_unknown {
7494 Some(None)
7495 } else {
7496 Some(Some(*negated))
7497 }
7498 }
7499 _ => None,
7500 }
7501}
7502
/// Outcome of trying to fold a join predicate without looking at any rows.
enum ConstantJoinEvaluation {
    /// The predicate folded to a definite boolean.
    Known(bool),
    /// The predicate is constant but its truth is unknown, for example
    /// because a NULL literal took part in a comparison.
    Unknown,
    /// The predicate could not be folded to a constant.
    NotConstant,
}
7508
7509fn evaluate_constant_join_expr(expr: &LlkvExpr<'static, String>) -> ConstantJoinEvaluation {
7510 match expr {
7511 LlkvExpr::Literal(value) => ConstantJoinEvaluation::Known(*value),
7512 LlkvExpr::And(children) => {
7513 let mut saw_unknown = false;
7514 for child in children {
7515 match evaluate_constant_join_expr(child) {
7516 ConstantJoinEvaluation::Known(false) => {
7517 return ConstantJoinEvaluation::Known(false);
7518 }
7519 ConstantJoinEvaluation::Known(true) => {}
7520 ConstantJoinEvaluation::Unknown => saw_unknown = true,
7521 ConstantJoinEvaluation::NotConstant => {
7522 return ConstantJoinEvaluation::NotConstant;
7523 }
7524 }
7525 }
7526 if saw_unknown {
7527 ConstantJoinEvaluation::Unknown
7528 } else {
7529 ConstantJoinEvaluation::Known(true)
7530 }
7531 }
7532 LlkvExpr::Or(children) => {
7533 let mut saw_unknown = false;
7534 for child in children {
7535 match evaluate_constant_join_expr(child) {
7536 ConstantJoinEvaluation::Known(true) => {
7537 return ConstantJoinEvaluation::Known(true);
7538 }
7539 ConstantJoinEvaluation::Known(false) => {}
7540 ConstantJoinEvaluation::Unknown => saw_unknown = true,
7541 ConstantJoinEvaluation::NotConstant => {
7542 return ConstantJoinEvaluation::NotConstant;
7543 }
7544 }
7545 }
7546 if saw_unknown {
7547 ConstantJoinEvaluation::Unknown
7548 } else {
7549 ConstantJoinEvaluation::Known(false)
7550 }
7551 }
7552 LlkvExpr::Not(inner) => match evaluate_constant_join_expr(inner) {
7553 ConstantJoinEvaluation::Known(value) => ConstantJoinEvaluation::Known(!value),
7554 ConstantJoinEvaluation::Unknown => ConstantJoinEvaluation::Unknown,
7555 ConstantJoinEvaluation::NotConstant => ConstantJoinEvaluation::NotConstant,
7556 },
7557 LlkvExpr::Compare { left, op, right } => {
7558 let left_lit = evaluate_constant_scalar(left);
7559 let right_lit = evaluate_constant_scalar(right);
7560
7561 if matches!(left_lit, Some(Literal::Null)) || matches!(right_lit, Some(Literal::Null)) {
7562 return ConstantJoinEvaluation::Unknown;
7564 }
7565
7566 let (Some(left_lit), Some(right_lit)) = (left_lit, right_lit) else {
7567 return ConstantJoinEvaluation::NotConstant;
7568 };
7569
7570 match compare_literals(*op, &left_lit, &right_lit) {
7571 Some(result) => ConstantJoinEvaluation::Known(result),
7572 None => ConstantJoinEvaluation::Unknown,
7573 }
7574 }
7575 LlkvExpr::IsNull { expr, negated } => match evaluate_constant_scalar(expr) {
7576 Some(literal) => {
7577 let is_null = matches!(literal, Literal::Null);
7578 let value = if *negated { !is_null } else { is_null };
7579 ConstantJoinEvaluation::Known(value)
7580 }
7581 None => ConstantJoinEvaluation::NotConstant,
7582 },
7583 LlkvExpr::InList {
7584 expr,
7585 list,
7586 negated,
7587 } => {
7588 let needle = match evaluate_constant_scalar(expr) {
7589 Some(literal) => literal,
7590 None => return ConstantJoinEvaluation::NotConstant,
7591 };
7592
7593 if matches!(needle, Literal::Null) {
7594 return ConstantJoinEvaluation::Unknown;
7595 }
7596
7597 let mut saw_unknown = false;
7598 for candidate in list {
7599 let value = match evaluate_constant_scalar(candidate) {
7600 Some(literal) => literal,
7601 None => return ConstantJoinEvaluation::NotConstant,
7602 };
7603
7604 match compare_literals(CompareOp::Eq, &needle, &value) {
7605 Some(true) => {
7606 let result = !*negated;
7607 return ConstantJoinEvaluation::Known(result);
7608 }
7609 Some(false) => {}
7610 None => saw_unknown = true,
7611 }
7612 }
7613
7614 if saw_unknown {
7615 ConstantJoinEvaluation::Unknown
7616 } else {
7617 let result = *negated;
7618 ConstantJoinEvaluation::Known(result)
7619 }
7620 }
7621 _ => ConstantJoinEvaluation::NotConstant,
7622 }
7623}
7624
/// Selects how a comparison should treat NULL operands.
///
/// NOTE(review): the users of this enum are outside this excerpt; the
/// description above is inferred from the names only — confirm against the
/// call sites.
enum NullComparisonBehavior {
    /// The only mode currently defined.
    ThreeValuedLogic,
}
7628
7629fn evaluate_constant_scalar(expr: &ScalarExpr<String>) -> Option<Literal> {
7630 evaluate_constant_scalar_internal(expr, false)
7631}
7632
7633fn evaluate_constant_scalar_with_aggregates(expr: &ScalarExpr<String>) -> Option<Literal> {
7634 evaluate_constant_scalar_internal(expr, true)
7635}
7636
7637fn evaluate_constant_scalar_internal(
7638 expr: &ScalarExpr<String>,
7639 allow_aggregates: bool,
7640) -> Option<Literal> {
7641 match expr {
7642 ScalarExpr::Literal(lit) => Some(lit.clone()),
7643 ScalarExpr::Binary { left, op, right } => {
7644 let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
7645 let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
7646 evaluate_binary_literal(*op, &left_value, &right_value)
7647 }
7648 ScalarExpr::Cast { expr, data_type } => {
7649 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7650 cast_literal_to_type(&value, data_type)
7651 }
7652 ScalarExpr::Not(inner) => {
7653 let value = evaluate_constant_scalar_internal(inner, allow_aggregates)?;
7654 match literal_truthiness(&value) {
7655 Some(true) => Some(Literal::Integer(0)),
7656 Some(false) => Some(Literal::Integer(1)),
7657 None => Some(Literal::Null),
7658 }
7659 }
7660 ScalarExpr::IsNull { expr, negated } => {
7661 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7662 let is_null = matches!(value, Literal::Null);
7663 Some(Literal::Boolean(if *negated { !is_null } else { is_null }))
7664 }
7665 ScalarExpr::Coalesce(items) => {
7666 let mut saw_null = false;
7667 for item in items {
7668 match evaluate_constant_scalar_internal(item, allow_aggregates) {
7669 Some(Literal::Null) => saw_null = true,
7670 Some(value) => return Some(value),
7671 None => return None,
7672 }
7673 }
7674 if saw_null { Some(Literal::Null) } else { None }
7675 }
7676 ScalarExpr::Compare { left, op, right } => {
7677 let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
7678 let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
7679 match compare_literals(*op, &left_value, &right_value) {
7680 Some(flag) => Some(Literal::Boolean(flag)),
7681 None => Some(Literal::Null),
7682 }
7683 }
7684 ScalarExpr::Case {
7685 operand,
7686 branches,
7687 else_expr,
7688 } => {
7689 if let Some(operand_expr) = operand {
7690 let operand_value =
7691 evaluate_constant_scalar_internal(operand_expr, allow_aggregates)?;
7692 for (when_expr, then_expr) in branches {
7693 let when_value =
7694 evaluate_constant_scalar_internal(when_expr, allow_aggregates)?;
7695 if let Some(true) = compare_literals(CompareOp::Eq, &operand_value, &when_value)
7696 {
7697 return evaluate_constant_scalar_internal(then_expr, allow_aggregates);
7698 }
7699 }
7700 } else {
7701 for (condition_expr, result_expr) in branches {
7702 let condition_value =
7703 evaluate_constant_scalar_internal(condition_expr, allow_aggregates)?;
7704 match literal_truthiness(&condition_value) {
7705 Some(true) => {
7706 return evaluate_constant_scalar_internal(
7707 result_expr,
7708 allow_aggregates,
7709 );
7710 }
7711 Some(false) => {}
7712 None => {}
7713 }
7714 }
7715 }
7716
7717 if let Some(else_branch) = else_expr {
7718 evaluate_constant_scalar_internal(else_branch, allow_aggregates)
7719 } else {
7720 Some(Literal::Null)
7721 }
7722 }
7723 ScalarExpr::Column(_) => None,
7724 ScalarExpr::Aggregate(call) => {
7725 if allow_aggregates {
7726 evaluate_constant_aggregate(call, allow_aggregates)
7727 } else {
7728 None
7729 }
7730 }
7731 ScalarExpr::GetField { .. } => None,
7732 ScalarExpr::ScalarSubquery(_) => None,
7733 }
7734}
7735
7736fn evaluate_constant_aggregate(
7737 call: &AggregateCall<String>,
7738 allow_aggregates: bool,
7739) -> Option<Literal> {
7740 match call {
7741 AggregateCall::CountStar => Some(Literal::Integer(1)),
7742 AggregateCall::Count { expr, .. } => {
7743 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7744 if matches!(value, Literal::Null) {
7745 Some(Literal::Integer(0))
7746 } else {
7747 Some(Literal::Integer(1))
7748 }
7749 }
7750 AggregateCall::Sum { expr, .. } => {
7751 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7752 match value {
7753 Literal::Null => Some(Literal::Null),
7754 Literal::Integer(value) => Some(Literal::Integer(value)),
7755 Literal::Float(value) => Some(Literal::Float(value)),
7756 Literal::Boolean(flag) => Some(Literal::Integer(if flag { 1 } else { 0 })),
7757 _ => None,
7758 }
7759 }
7760 AggregateCall::Total { expr, .. } => {
7761 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7762 match value {
7763 Literal::Null => Some(Literal::Integer(0)),
7764 Literal::Integer(value) => Some(Literal::Integer(value)),
7765 Literal::Float(value) => Some(Literal::Float(value)),
7766 Literal::Boolean(flag) => Some(Literal::Integer(if flag { 1 } else { 0 })),
7767 _ => None,
7768 }
7769 }
7770 AggregateCall::Avg { expr, .. } => {
7771 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7772 match value {
7773 Literal::Null => Some(Literal::Null),
7774 other => {
7775 let numeric = literal_to_f64(&other)?;
7776 Some(Literal::Float(numeric))
7777 }
7778 }
7779 }
7780 AggregateCall::Min(expr) => {
7781 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7782 match value {
7783 Literal::Null => Some(Literal::Null),
7784 other => Some(other),
7785 }
7786 }
7787 AggregateCall::Max(expr) => {
7788 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7789 match value {
7790 Literal::Null => Some(Literal::Null),
7791 other => Some(other),
7792 }
7793 }
7794 AggregateCall::CountNulls(expr) => {
7795 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7796 let count = if matches!(value, Literal::Null) { 1 } else { 0 };
7797 Some(Literal::Integer(count))
7798 }
7799 AggregateCall::GroupConcat {
7800 expr, separator: _, ..
7801 } => {
7802 let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
7803 match value {
7804 Literal::Null => Some(Literal::Null),
7805 Literal::String(s) => Some(Literal::String(s)),
7806 Literal::Integer(i) => Some(Literal::String(i.to_string())),
7807 Literal::Float(f) => Some(Literal::String(f.to_string())),
7808 Literal::Boolean(b) => Some(Literal::String(if b { "1" } else { "0" }.to_string())),
7809 _ => None,
7810 }
7811 }
7812 }
7813}
7814
7815fn evaluate_binary_literal(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
7816 match op {
7817 BinaryOp::And => evaluate_literal_logical_and(left, right),
7818 BinaryOp::Or => evaluate_literal_logical_or(left, right),
7819 BinaryOp::Add
7820 | BinaryOp::Subtract
7821 | BinaryOp::Multiply
7822 | BinaryOp::Divide
7823 | BinaryOp::Modulo => {
7824 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
7825 return Some(Literal::Null);
7826 }
7827
7828 match op {
7829 BinaryOp::Add => add_literals(left, right),
7830 BinaryOp::Subtract => subtract_literals(left, right),
7831 BinaryOp::Multiply => multiply_literals(left, right),
7832 BinaryOp::Divide => divide_literals(left, right),
7833 BinaryOp::Modulo => modulo_literals(left, right),
7834 BinaryOp::And
7835 | BinaryOp::Or
7836 | BinaryOp::BitwiseShiftLeft
7837 | BinaryOp::BitwiseShiftRight => unreachable!(),
7838 }
7839 }
7840 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
7841 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
7842 return Some(Literal::Null);
7843 }
7844
7845 let lhs = literal_to_i128(left)?;
7847 let rhs = literal_to_i128(right)?;
7848
7849 let result = match op {
7851 BinaryOp::BitwiseShiftLeft => (lhs as i64).wrapping_shl(rhs as u32) as i128,
7852 BinaryOp::BitwiseShiftRight => (lhs as i64).wrapping_shr(rhs as u32) as i128,
7853 _ => unreachable!(),
7854 };
7855
7856 Some(Literal::Integer(result))
7857 }
7858 }
7859}
7860
7861fn evaluate_literal_logical_and(left: &Literal, right: &Literal) -> Option<Literal> {
7862 let left_truth = literal_truthiness(left);
7863 if matches!(left_truth, Some(false)) {
7864 return Some(Literal::Integer(0));
7865 }
7866
7867 let right_truth = literal_truthiness(right);
7868 if matches!(right_truth, Some(false)) {
7869 return Some(Literal::Integer(0));
7870 }
7871
7872 match (left_truth, right_truth) {
7873 (Some(true), Some(true)) => Some(Literal::Integer(1)),
7874 (Some(true), None) | (None, Some(true)) | (None, None) => Some(Literal::Null),
7875 _ => Some(Literal::Null),
7876 }
7877}
7878
7879fn evaluate_literal_logical_or(left: &Literal, right: &Literal) -> Option<Literal> {
7880 let left_truth = literal_truthiness(left);
7881 if matches!(left_truth, Some(true)) {
7882 return Some(Literal::Integer(1));
7883 }
7884
7885 let right_truth = literal_truthiness(right);
7886 if matches!(right_truth, Some(true)) {
7887 return Some(Literal::Integer(1));
7888 }
7889
7890 match (left_truth, right_truth) {
7891 (Some(false), Some(false)) => Some(Literal::Integer(0)),
7892 (Some(false), None) | (None, Some(false)) | (None, None) => Some(Literal::Null),
7893 _ => Some(Literal::Null),
7894 }
7895}
7896
7897fn add_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7898 match (left, right) {
7899 (Literal::Integer(lhs), Literal::Integer(rhs)) => {
7900 Some(Literal::Integer(lhs.saturating_add(*rhs)))
7901 }
7902 _ => {
7903 let lhs = literal_to_f64(left)?;
7904 let rhs = literal_to_f64(right)?;
7905 Some(Literal::Float(lhs + rhs))
7906 }
7907 }
7908}
7909
7910fn subtract_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7911 match (left, right) {
7912 (Literal::Integer(lhs), Literal::Integer(rhs)) => {
7913 Some(Literal::Integer(lhs.saturating_sub(*rhs)))
7914 }
7915 _ => {
7916 let lhs = literal_to_f64(left)?;
7917 let rhs = literal_to_f64(right)?;
7918 Some(Literal::Float(lhs - rhs))
7919 }
7920 }
7921}
7922
7923fn multiply_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7924 match (left, right) {
7925 (Literal::Integer(lhs), Literal::Integer(rhs)) => {
7926 Some(Literal::Integer(lhs.saturating_mul(*rhs)))
7927 }
7928 _ => {
7929 let lhs = literal_to_f64(left)?;
7930 let rhs = literal_to_f64(right)?;
7931 Some(Literal::Float(lhs * rhs))
7932 }
7933 }
7934}
7935
7936fn divide_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7937 fn literal_to_i128_from_integer_like(literal: &Literal) -> Option<i128> {
7938 match literal {
7939 Literal::Integer(value) => Some(*value),
7940 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
7941 _ => None,
7942 }
7943 }
7944
7945 if let (Some(lhs), Some(rhs)) = (
7946 literal_to_i128_from_integer_like(left),
7947 literal_to_i128_from_integer_like(right),
7948 ) {
7949 if rhs == 0 {
7950 return Some(Literal::Null);
7951 }
7952
7953 if lhs == i128::MIN && rhs == -1 {
7954 return Some(Literal::Float((lhs as f64) / (rhs as f64)));
7955 }
7956
7957 return Some(Literal::Integer(lhs / rhs));
7958 }
7959
7960 let lhs = literal_to_f64(left)?;
7961 let rhs = literal_to_f64(right)?;
7962 if rhs == 0.0 {
7963 return Some(Literal::Null);
7964 }
7965 Some(Literal::Float(lhs / rhs))
7966}
7967
7968fn modulo_literals(left: &Literal, right: &Literal) -> Option<Literal> {
7969 let lhs = literal_to_i128(left)?;
7970 let rhs = literal_to_i128(right)?;
7971 if rhs == 0 {
7972 return Some(Literal::Null);
7973 }
7974 Some(Literal::Integer(lhs % rhs))
7975}
7976
7977fn literal_to_f64(literal: &Literal) -> Option<f64> {
7978 match literal {
7979 Literal::Integer(value) => Some(*value as f64),
7980 Literal::Float(value) => Some(*value),
7981 Literal::Boolean(value) => Some(if *value { 1.0 } else { 0.0 }),
7982 _ => None,
7983 }
7984}
7985
7986fn literal_to_i128(literal: &Literal) -> Option<i128> {
7987 match literal {
7988 Literal::Integer(value) => Some(*value),
7989 Literal::Float(value) => Some(*value as i128),
7990 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
7991 _ => None,
7992 }
7993}
7994
7995fn literal_truthiness(literal: &Literal) -> Option<bool> {
7996 match literal {
7997 Literal::Boolean(value) => Some(*value),
7998 Literal::Integer(value) => Some(*value != 0),
7999 Literal::Float(value) => Some(*value != 0.0),
8000 Literal::Null => None,
8001 _ => None,
8002 }
8003}
8004
8005fn plan_value_truthiness(value: &PlanValue) -> Option<bool> {
8006 match value {
8007 PlanValue::Integer(v) => Some(*v != 0),
8008 PlanValue::Float(v) => Some(*v != 0.0),
8009 PlanValue::Null => None,
8010 _ => None,
8011 }
8012}
8013
/// Maps a nullable integer to its truth value: `None` stays `None`, otherwise
/// the value is true when non-zero.
fn option_i64_truthiness(value: Option<i64>) -> Option<bool> {
    match value {
        Some(number) => Some(number != 0),
        None => None,
    }
}
8017
8018fn evaluate_plan_value_logical_and(left: PlanValue, right: PlanValue) -> PlanValue {
8019 let left_truth = plan_value_truthiness(&left);
8020 if matches!(left_truth, Some(false)) {
8021 return PlanValue::Integer(0);
8022 }
8023
8024 let right_truth = plan_value_truthiness(&right);
8025 if matches!(right_truth, Some(false)) {
8026 return PlanValue::Integer(0);
8027 }
8028
8029 match (left_truth, right_truth) {
8030 (Some(true), Some(true)) => PlanValue::Integer(1),
8031 (Some(true), None) | (None, Some(true)) | (None, None) => PlanValue::Null,
8032 _ => PlanValue::Null,
8033 }
8034}
8035
8036fn evaluate_plan_value_logical_or(left: PlanValue, right: PlanValue) -> PlanValue {
8037 let left_truth = plan_value_truthiness(&left);
8038 if matches!(left_truth, Some(true)) {
8039 return PlanValue::Integer(1);
8040 }
8041
8042 let right_truth = plan_value_truthiness(&right);
8043 if matches!(right_truth, Some(true)) {
8044 return PlanValue::Integer(1);
8045 }
8046
8047 match (left_truth, right_truth) {
8048 (Some(false), Some(false)) => PlanValue::Integer(0),
8049 (Some(false), None) | (None, Some(false)) | (None, None) => PlanValue::Null,
8050 _ => PlanValue::Null,
8051 }
8052}
8053
8054fn evaluate_option_logical_and(left: Option<i64>, right: Option<i64>) -> Option<i64> {
8055 let left_truth = option_i64_truthiness(left);
8056 if matches!(left_truth, Some(false)) {
8057 return Some(0);
8058 }
8059
8060 let right_truth = option_i64_truthiness(right);
8061 if matches!(right_truth, Some(false)) {
8062 return Some(0);
8063 }
8064
8065 match (left_truth, right_truth) {
8066 (Some(true), Some(true)) => Some(1),
8067 (Some(true), None) | (None, Some(true)) | (None, None) => None,
8068 _ => None,
8069 }
8070}
8071
8072fn evaluate_option_logical_or(left: Option<i64>, right: Option<i64>) -> Option<i64> {
8073 let left_truth = option_i64_truthiness(left);
8074 if matches!(left_truth, Some(true)) {
8075 return Some(1);
8076 }
8077
8078 let right_truth = option_i64_truthiness(right);
8079 if matches!(right_truth, Some(true)) {
8080 return Some(1);
8081 }
8082
8083 match (left_truth, right_truth) {
8084 (Some(false), Some(false)) => Some(0),
8085 (Some(false), None) | (None, Some(false)) | (None, None) => None,
8086 _ => None,
8087 }
8088}
8089
8090fn cast_literal_to_type(literal: &Literal, data_type: &DataType) -> Option<Literal> {
8091 if matches!(literal, Literal::Null) {
8092 return Some(Literal::Null);
8093 }
8094
8095 match data_type {
8096 DataType::Boolean => literal_truthiness(literal).map(Literal::Boolean),
8097 DataType::Float16 | DataType::Float32 | DataType::Float64 => {
8098 let value = literal_to_f64(literal)?;
8099 Some(Literal::Float(value))
8100 }
8101 DataType::Int8
8102 | DataType::Int16
8103 | DataType::Int32
8104 | DataType::Int64
8105 | DataType::UInt8
8106 | DataType::UInt16
8107 | DataType::UInt32
8108 | DataType::UInt64 => {
8109 let value = literal_to_i128(literal)?;
8110 Some(Literal::Integer(value))
8111 }
8112 DataType::Utf8 | DataType::LargeUtf8 => Some(Literal::String(match literal {
8113 Literal::String(text) => text.clone(),
8114 Literal::Integer(value) => value.to_string(),
8115 Literal::Float(value) => value.to_string(),
8116 Literal::Boolean(value) => {
8117 if *value {
8118 "1".to_string()
8119 } else {
8120 "0".to_string()
8121 }
8122 }
8123 Literal::Struct(_) | Literal::Null => return None,
8124 })),
8125 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
8126 literal_to_i128(literal).map(Literal::Integer)
8127 }
8128 _ => None,
8129 }
8130}
8131
8132fn compare_literals(op: CompareOp, left: &Literal, right: &Literal) -> Option<bool> {
8133 compare_literals_with_mode(op, left, right, NullComparisonBehavior::ThreeValuedLogic)
8134}
8135
8136fn bind_select_filter(
8137 filter: &llkv_plan::SelectFilter,
8138 bindings: &FxHashMap<String, Literal>,
8139) -> ExecutorResult<llkv_plan::SelectFilter> {
8140 let predicate = bind_predicate_expr(&filter.predicate, bindings)?;
8141 let subqueries = filter
8142 .subqueries
8143 .iter()
8144 .map(|subquery| bind_filter_subquery(subquery, bindings))
8145 .collect::<ExecutorResult<Vec<_>>>()?;
8146
8147 Ok(llkv_plan::SelectFilter {
8148 predicate,
8149 subqueries,
8150 })
8151}
8152
8153fn bind_filter_subquery(
8154 subquery: &llkv_plan::FilterSubquery,
8155 bindings: &FxHashMap<String, Literal>,
8156) -> ExecutorResult<llkv_plan::FilterSubquery> {
8157 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
8158 Ok(llkv_plan::FilterSubquery {
8159 id: subquery.id,
8160 plan: Box::new(bound_plan),
8161 correlated_columns: subquery.correlated_columns.clone(),
8162 })
8163}
8164
8165fn bind_scalar_subquery(
8166 subquery: &llkv_plan::ScalarSubquery,
8167 bindings: &FxHashMap<String, Literal>,
8168) -> ExecutorResult<llkv_plan::ScalarSubquery> {
8169 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
8170 Ok(llkv_plan::ScalarSubquery {
8171 id: subquery.id,
8172 plan: Box::new(bound_plan),
8173 correlated_columns: subquery.correlated_columns.clone(),
8174 })
8175}
8176
8177fn bind_projection(
8178 projection: &SelectProjection,
8179 bindings: &FxHashMap<String, Literal>,
8180) -> ExecutorResult<SelectProjection> {
8181 match projection {
8182 SelectProjection::AllColumns => Ok(projection.clone()),
8183 SelectProjection::AllColumnsExcept { exclude } => Ok(SelectProjection::AllColumnsExcept {
8184 exclude: exclude.clone(),
8185 }),
8186 SelectProjection::Column { name, alias } => {
8187 if let Some(literal) = bindings.get(name) {
8188 let expr = ScalarExpr::Literal(literal.clone());
8189 Ok(SelectProjection::Computed {
8190 expr,
8191 alias: alias.clone().unwrap_or_else(|| name.clone()),
8192 })
8193 } else {
8194 Ok(projection.clone())
8195 }
8196 }
8197 SelectProjection::Computed { expr, alias } => Ok(SelectProjection::Computed {
8198 expr: bind_scalar_expr(expr, bindings)?,
8199 alias: alias.clone(),
8200 }),
8201 }
8202}
8203
8204fn bind_aggregate_expr(
8205 aggregate: &AggregateExpr,
8206 bindings: &FxHashMap<String, Literal>,
8207) -> ExecutorResult<AggregateExpr> {
8208 match aggregate {
8209 AggregateExpr::CountStar { .. } => Ok(aggregate.clone()),
8210 AggregateExpr::Column {
8211 column,
8212 alias,
8213 function,
8214 distinct,
8215 } => {
8216 if bindings.contains_key(column) {
8217 return Err(Error::InvalidArgumentError(
8218 "correlated columns are not supported inside aggregate expressions".into(),
8219 ));
8220 }
8221 Ok(AggregateExpr::Column {
8222 column: column.clone(),
8223 alias: alias.clone(),
8224 function: function.clone(),
8225 distinct: *distinct,
8226 })
8227 }
8228 }
8229}
8230
8231fn bind_scalar_expr(
8232 expr: &ScalarExpr<String>,
8233 bindings: &FxHashMap<String, Literal>,
8234) -> ExecutorResult<ScalarExpr<String>> {
8235 match expr {
8236 ScalarExpr::Column(name) => {
8237 if let Some(literal) = bindings.get(name) {
8238 Ok(ScalarExpr::Literal(literal.clone()))
8239 } else {
8240 Ok(ScalarExpr::Column(name.clone()))
8241 }
8242 }
8243 ScalarExpr::Literal(literal) => Ok(ScalarExpr::Literal(literal.clone())),
8244 ScalarExpr::Binary { left, op, right } => Ok(ScalarExpr::Binary {
8245 left: Box::new(bind_scalar_expr(left, bindings)?),
8246 op: *op,
8247 right: Box::new(bind_scalar_expr(right, bindings)?),
8248 }),
8249 ScalarExpr::Compare { left, op, right } => Ok(ScalarExpr::Compare {
8250 left: Box::new(bind_scalar_expr(left, bindings)?),
8251 op: *op,
8252 right: Box::new(bind_scalar_expr(right, bindings)?),
8253 }),
8254 ScalarExpr::Aggregate(call) => Ok(ScalarExpr::Aggregate(call.clone())),
8255 ScalarExpr::GetField { base, field_name } => {
8256 let bound_base = bind_scalar_expr(base, bindings)?;
8257 match bound_base {
8258 ScalarExpr::Literal(literal) => {
8259 let value = extract_struct_field(&literal, field_name).unwrap_or(Literal::Null);
8260 Ok(ScalarExpr::Literal(value))
8261 }
8262 other => Ok(ScalarExpr::GetField {
8263 base: Box::new(other),
8264 field_name: field_name.clone(),
8265 }),
8266 }
8267 }
8268 ScalarExpr::Cast { expr, data_type } => Ok(ScalarExpr::Cast {
8269 expr: Box::new(bind_scalar_expr(expr, bindings)?),
8270 data_type: data_type.clone(),
8271 }),
8272 ScalarExpr::Case {
8273 operand,
8274 branches,
8275 else_expr,
8276 } => {
8277 let bound_operand = match operand {
8278 Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
8279 None => None,
8280 };
8281 let mut bound_branches = Vec::with_capacity(branches.len());
8282 for (when_expr, then_expr) in branches {
8283 bound_branches.push((
8284 bind_scalar_expr(when_expr, bindings)?,
8285 bind_scalar_expr(then_expr, bindings)?,
8286 ));
8287 }
8288 let bound_else = match else_expr {
8289 Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
8290 None => None,
8291 };
8292 Ok(ScalarExpr::Case {
8293 operand: bound_operand,
8294 branches: bound_branches,
8295 else_expr: bound_else,
8296 })
8297 }
8298 ScalarExpr::Coalesce(items) => {
8299 let mut bound_items = Vec::with_capacity(items.len());
8300 for item in items {
8301 bound_items.push(bind_scalar_expr(item, bindings)?);
8302 }
8303 Ok(ScalarExpr::Coalesce(bound_items))
8304 }
8305 ScalarExpr::Not(inner) => Ok(ScalarExpr::Not(Box::new(bind_scalar_expr(
8306 inner, bindings,
8307 )?))),
8308 ScalarExpr::IsNull { expr, negated } => Ok(ScalarExpr::IsNull {
8309 expr: Box::new(bind_scalar_expr(expr, bindings)?),
8310 negated: *negated,
8311 }),
8312 ScalarExpr::ScalarSubquery(sub) => Ok(ScalarExpr::ScalarSubquery(sub.clone())),
8313 }
8314}
8315
8316fn bind_predicate_expr(
8317 expr: &LlkvExpr<'static, String>,
8318 bindings: &FxHashMap<String, Literal>,
8319) -> ExecutorResult<LlkvExpr<'static, String>> {
8320 match expr {
8321 LlkvExpr::And(children) => {
8322 let mut bound = Vec::with_capacity(children.len());
8323 for child in children {
8324 bound.push(bind_predicate_expr(child, bindings)?);
8325 }
8326 Ok(LlkvExpr::And(bound))
8327 }
8328 LlkvExpr::Or(children) => {
8329 let mut bound = Vec::with_capacity(children.len());
8330 for child in children {
8331 bound.push(bind_predicate_expr(child, bindings)?);
8332 }
8333 Ok(LlkvExpr::Or(bound))
8334 }
8335 LlkvExpr::Not(inner) => Ok(LlkvExpr::Not(Box::new(bind_predicate_expr(
8336 inner, bindings,
8337 )?))),
8338 LlkvExpr::Pred(filter) => bind_filter_predicate(filter, bindings),
8339 LlkvExpr::Compare { left, op, right } => Ok(LlkvExpr::Compare {
8340 left: bind_scalar_expr(left, bindings)?,
8341 op: *op,
8342 right: bind_scalar_expr(right, bindings)?,
8343 }),
8344 LlkvExpr::InList {
8345 expr,
8346 list,
8347 negated,
8348 } => {
8349 let target = bind_scalar_expr(expr, bindings)?;
8350 let mut bound_list = Vec::with_capacity(list.len());
8351 for item in list {
8352 bound_list.push(bind_scalar_expr(item, bindings)?);
8353 }
8354 Ok(LlkvExpr::InList {
8355 expr: target,
8356 list: bound_list,
8357 negated: *negated,
8358 })
8359 }
8360 LlkvExpr::IsNull { expr, negated } => Ok(LlkvExpr::IsNull {
8361 expr: bind_scalar_expr(expr, bindings)?,
8362 negated: *negated,
8363 }),
8364 LlkvExpr::Literal(value) => Ok(LlkvExpr::Literal(*value)),
8365 LlkvExpr::Exists(subquery) => Ok(LlkvExpr::Exists(subquery.clone())),
8366 }
8367}
8368
8369fn bind_filter_predicate(
8370 filter: &Filter<'static, String>,
8371 bindings: &FxHashMap<String, Literal>,
8372) -> ExecutorResult<LlkvExpr<'static, String>> {
8373 if let Some(literal) = bindings.get(&filter.field_id) {
8374 let result = evaluate_filter_against_literal(literal, &filter.op)?;
8375 return Ok(LlkvExpr::Literal(result));
8376 }
8377 Ok(LlkvExpr::Pred(filter.clone()))
8378}
8379
8380fn evaluate_filter_against_literal(value: &Literal, op: &Operator) -> ExecutorResult<bool> {
8381 use std::ops::Bound;
8382
8383 match op {
8384 Operator::IsNull => Ok(matches!(value, Literal::Null)),
8385 Operator::IsNotNull => Ok(!matches!(value, Literal::Null)),
8386 Operator::Equals(rhs) => Ok(literal_equals(value, rhs).unwrap_or(false)),
8387 Operator::GreaterThan(rhs) => Ok(literal_compare(value, rhs)
8388 .map(|cmp| cmp == std::cmp::Ordering::Greater)
8389 .unwrap_or(false)),
8390 Operator::GreaterThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
8391 .map(|cmp| matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal))
8392 .unwrap_or(false)),
8393 Operator::LessThan(rhs) => Ok(literal_compare(value, rhs)
8394 .map(|cmp| cmp == std::cmp::Ordering::Less)
8395 .unwrap_or(false)),
8396 Operator::LessThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
8397 .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal))
8398 .unwrap_or(false)),
8399 Operator::In(values) => Ok(values
8400 .iter()
8401 .any(|candidate| literal_equals(value, candidate).unwrap_or(false))),
8402 Operator::Range { lower, upper } => {
8403 let lower_ok = match lower {
8404 Bound::Unbounded => Some(true),
8405 Bound::Included(bound) => literal_compare(value, bound).map(|cmp| {
8406 matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
8407 }),
8408 Bound::Excluded(bound) => {
8409 literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Greater)
8410 }
8411 }
8412 .unwrap_or(false);
8413
8414 let upper_ok = match upper {
8415 Bound::Unbounded => Some(true),
8416 Bound::Included(bound) => literal_compare(value, bound)
8417 .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal)),
8418 Bound::Excluded(bound) => {
8419 literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Less)
8420 }
8421 }
8422 .unwrap_or(false);
8423
8424 Ok(lower_ok && upper_ok)
8425 }
8426 Operator::StartsWith {
8427 pattern,
8428 case_sensitive,
8429 } => {
8430 let target = if *case_sensitive {
8431 pattern.to_string()
8432 } else {
8433 pattern.to_ascii_lowercase()
8434 };
8435 Ok(literal_string(value, *case_sensitive)
8436 .map(|source| source.starts_with(&target))
8437 .unwrap_or(false))
8438 }
8439 Operator::EndsWith {
8440 pattern,
8441 case_sensitive,
8442 } => {
8443 let target = if *case_sensitive {
8444 pattern.to_string()
8445 } else {
8446 pattern.to_ascii_lowercase()
8447 };
8448 Ok(literal_string(value, *case_sensitive)
8449 .map(|source| source.ends_with(&target))
8450 .unwrap_or(false))
8451 }
8452 Operator::Contains {
8453 pattern,
8454 case_sensitive,
8455 } => {
8456 let target = if *case_sensitive {
8457 pattern.to_string()
8458 } else {
8459 pattern.to_ascii_lowercase()
8460 };
8461 Ok(literal_string(value, *case_sensitive)
8462 .map(|source| source.contains(&target))
8463 .unwrap_or(false))
8464 }
8465 }
8466}
8467
8468fn literal_compare(lhs: &Literal, rhs: &Literal) -> Option<std::cmp::Ordering> {
8469 match (lhs, rhs) {
8470 (Literal::Integer(a), Literal::Integer(b)) => Some(a.cmp(b)),
8471 (Literal::Float(a), Literal::Float(b)) => a.partial_cmp(b),
8472 (Literal::Integer(a), Literal::Float(b)) => (*a as f64).partial_cmp(b),
8473 (Literal::Float(a), Literal::Integer(b)) => a.partial_cmp(&(*b as f64)),
8474 (Literal::String(a), Literal::String(b)) => Some(a.cmp(b)),
8475 _ => None,
8476 }
8477}
8478
8479fn literal_equals(lhs: &Literal, rhs: &Literal) -> Option<bool> {
8480 match (lhs, rhs) {
8481 (Literal::Boolean(a), Literal::Boolean(b)) => Some(a == b),
8482 (Literal::String(a), Literal::String(b)) => Some(a == b),
8483 (Literal::Integer(_), Literal::Integer(_))
8484 | (Literal::Integer(_), Literal::Float(_))
8485 | (Literal::Float(_), Literal::Integer(_))
8486 | (Literal::Float(_), Literal::Float(_)) => {
8487 literal_compare(lhs, rhs).map(|cmp| cmp == std::cmp::Ordering::Equal)
8488 }
8489 _ => None,
8490 }
8491}
8492
8493fn literal_string(literal: &Literal, case_sensitive: bool) -> Option<String> {
8494 match literal {
8495 Literal::String(value) => {
8496 if case_sensitive {
8497 Some(value.clone())
8498 } else {
8499 Some(value.to_ascii_lowercase())
8500 }
8501 }
8502 _ => None,
8503 }
8504}
8505
8506fn extract_struct_field(literal: &Literal, field_name: &str) -> Option<Literal> {
8507 if let Literal::Struct(fields) = literal {
8508 for (name, value) in fields {
8509 if name.eq_ignore_ascii_case(field_name) {
8510 return Some((**value).clone());
8511 }
8512 }
8513 }
8514 None
8515}
8516
8517fn array_value_to_literal(array: &ArrayRef, idx: usize) -> ExecutorResult<Literal> {
8518 if array.is_null(idx) {
8519 return Ok(Literal::Null);
8520 }
8521
8522 match array.data_type() {
8523 DataType::Boolean => {
8524 let array = array
8525 .as_any()
8526 .downcast_ref::<BooleanArray>()
8527 .ok_or_else(|| Error::Internal("failed to downcast boolean array".into()))?;
8528 Ok(Literal::Boolean(array.value(idx)))
8529 }
8530 DataType::Int8 => {
8531 let array = array
8532 .as_any()
8533 .downcast_ref::<Int8Array>()
8534 .ok_or_else(|| Error::Internal("failed to downcast int8 array".into()))?;
8535 Ok(Literal::Integer(array.value(idx) as i128))
8536 }
8537 DataType::Int16 => {
8538 let array = array
8539 .as_any()
8540 .downcast_ref::<Int16Array>()
8541 .ok_or_else(|| Error::Internal("failed to downcast int16 array".into()))?;
8542 Ok(Literal::Integer(array.value(idx) as i128))
8543 }
8544 DataType::Int32 => {
8545 let array = array
8546 .as_any()
8547 .downcast_ref::<Int32Array>()
8548 .ok_or_else(|| Error::Internal("failed to downcast int32 array".into()))?;
8549 Ok(Literal::Integer(array.value(idx) as i128))
8550 }
8551 DataType::Int64 => {
8552 let array = array
8553 .as_any()
8554 .downcast_ref::<Int64Array>()
8555 .ok_or_else(|| Error::Internal("failed to downcast int64 array".into()))?;
8556 Ok(Literal::Integer(array.value(idx) as i128))
8557 }
8558 DataType::UInt8 => {
8559 let array = array
8560 .as_any()
8561 .downcast_ref::<UInt8Array>()
8562 .ok_or_else(|| Error::Internal("failed to downcast uint8 array".into()))?;
8563 Ok(Literal::Integer(array.value(idx) as i128))
8564 }
8565 DataType::UInt16 => {
8566 let array = array
8567 .as_any()
8568 .downcast_ref::<UInt16Array>()
8569 .ok_or_else(|| Error::Internal("failed to downcast uint16 array".into()))?;
8570 Ok(Literal::Integer(array.value(idx) as i128))
8571 }
8572 DataType::UInt32 => {
8573 let array = array
8574 .as_any()
8575 .downcast_ref::<UInt32Array>()
8576 .ok_or_else(|| Error::Internal("failed to downcast uint32 array".into()))?;
8577 Ok(Literal::Integer(array.value(idx) as i128))
8578 }
8579 DataType::UInt64 => {
8580 let array = array
8581 .as_any()
8582 .downcast_ref::<UInt64Array>()
8583 .ok_or_else(|| Error::Internal("failed to downcast uint64 array".into()))?;
8584 Ok(Literal::Integer(array.value(idx) as i128))
8585 }
8586 DataType::Float32 => {
8587 let array = array
8588 .as_any()
8589 .downcast_ref::<Float32Array>()
8590 .ok_or_else(|| Error::Internal("failed to downcast float32 array".into()))?;
8591 Ok(Literal::Float(array.value(idx) as f64))
8592 }
8593 DataType::Float64 => {
8594 let array = array
8595 .as_any()
8596 .downcast_ref::<Float64Array>()
8597 .ok_or_else(|| Error::Internal("failed to downcast float64 array".into()))?;
8598 Ok(Literal::Float(array.value(idx)))
8599 }
8600 DataType::Utf8 => {
8601 let array = array
8602 .as_any()
8603 .downcast_ref::<StringArray>()
8604 .ok_or_else(|| Error::Internal("failed to downcast utf8 array".into()))?;
8605 Ok(Literal::String(array.value(idx).to_string()))
8606 }
8607 DataType::LargeUtf8 => {
8608 let array = array
8609 .as_any()
8610 .downcast_ref::<LargeStringArray>()
8611 .ok_or_else(|| Error::Internal("failed to downcast large utf8 array".into()))?;
8612 Ok(Literal::String(array.value(idx).to_string()))
8613 }
8614 DataType::Struct(fields) => {
8615 let struct_array = array
8616 .as_any()
8617 .downcast_ref::<StructArray>()
8618 .ok_or_else(|| Error::Internal("failed to downcast struct array".into()))?;
8619 let mut members = Vec::with_capacity(fields.len());
8620 for (field_idx, field) in fields.iter().enumerate() {
8621 let child = struct_array.column(field_idx);
8622 let literal = array_value_to_literal(child, idx)?;
8623 members.push((field.name().clone(), Box::new(literal)));
8624 }
8625 Ok(Literal::Struct(members))
8626 }
8627 other => Err(Error::InvalidArgumentError(format!(
8628 "unsupported scalar subquery result type: {other:?}"
8629 ))),
8630 }
8631}
8632
8633fn collect_scalar_subquery_ids(expr: &ScalarExpr<FieldId>, ids: &mut FxHashSet<SubqueryId>) {
8634 match expr {
8635 ScalarExpr::ScalarSubquery(subquery) => {
8636 ids.insert(subquery.id);
8637 }
8638 ScalarExpr::Binary { left, right, .. } => {
8639 collect_scalar_subquery_ids(left, ids);
8640 collect_scalar_subquery_ids(right, ids);
8641 }
8642 ScalarExpr::Compare { left, right, .. } => {
8643 collect_scalar_subquery_ids(left, ids);
8644 collect_scalar_subquery_ids(right, ids);
8645 }
8646 ScalarExpr::GetField { base, .. } => {
8647 collect_scalar_subquery_ids(base, ids);
8648 }
8649 ScalarExpr::Cast { expr, .. } => {
8650 collect_scalar_subquery_ids(expr, ids);
8651 }
8652 ScalarExpr::Not(expr) => {
8653 collect_scalar_subquery_ids(expr, ids);
8654 }
8655 ScalarExpr::IsNull { expr, .. } => {
8656 collect_scalar_subquery_ids(expr, ids);
8657 }
8658 ScalarExpr::Case {
8659 operand,
8660 branches,
8661 else_expr,
8662 } => {
8663 if let Some(op) = operand {
8664 collect_scalar_subquery_ids(op, ids);
8665 }
8666 for (when_expr, then_expr) in branches {
8667 collect_scalar_subquery_ids(when_expr, ids);
8668 collect_scalar_subquery_ids(then_expr, ids);
8669 }
8670 if let Some(else_expr) = else_expr {
8671 collect_scalar_subquery_ids(else_expr, ids);
8672 }
8673 }
8674 ScalarExpr::Coalesce(items) => {
8675 for item in items {
8676 collect_scalar_subquery_ids(item, ids);
8677 }
8678 }
8679 ScalarExpr::Aggregate(_) | ScalarExpr::Column(_) | ScalarExpr::Literal(_) => {}
8680 }
8681}
8682
8683fn rewrite_scalar_expr_for_subqueries(
8684 expr: &ScalarExpr<FieldId>,
8685 mapping: &FxHashMap<SubqueryId, FieldId>,
8686) -> ScalarExpr<FieldId> {
8687 match expr {
8688 ScalarExpr::ScalarSubquery(subquery) => mapping
8689 .get(&subquery.id)
8690 .map(|field_id| ScalarExpr::Column(*field_id))
8691 .unwrap_or_else(|| ScalarExpr::ScalarSubquery(subquery.clone())),
8692 ScalarExpr::Binary { left, op, right } => ScalarExpr::Binary {
8693 left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
8694 op: *op,
8695 right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
8696 },
8697 ScalarExpr::Compare { left, op, right } => ScalarExpr::Compare {
8698 left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
8699 op: *op,
8700 right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
8701 },
8702 ScalarExpr::GetField { base, field_name } => ScalarExpr::GetField {
8703 base: Box::new(rewrite_scalar_expr_for_subqueries(base, mapping)),
8704 field_name: field_name.clone(),
8705 },
8706 ScalarExpr::Cast { expr, data_type } => ScalarExpr::Cast {
8707 expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
8708 data_type: data_type.clone(),
8709 },
8710 ScalarExpr::Not(expr) => {
8711 ScalarExpr::Not(Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)))
8712 }
8713 ScalarExpr::IsNull { expr, negated } => ScalarExpr::IsNull {
8714 expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
8715 negated: *negated,
8716 },
8717 ScalarExpr::Case {
8718 operand,
8719 branches,
8720 else_expr,
8721 } => ScalarExpr::Case {
8722 operand: operand
8723 .as_ref()
8724 .map(|op| Box::new(rewrite_scalar_expr_for_subqueries(op, mapping))),
8725 branches: branches
8726 .iter()
8727 .map(|(when_expr, then_expr)| {
8728 (
8729 rewrite_scalar_expr_for_subqueries(when_expr, mapping),
8730 rewrite_scalar_expr_for_subqueries(then_expr, mapping),
8731 )
8732 })
8733 .collect(),
8734 else_expr: else_expr
8735 .as_ref()
8736 .map(|expr| Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping))),
8737 },
8738 ScalarExpr::Coalesce(items) => ScalarExpr::Coalesce(
8739 items
8740 .iter()
8741 .map(|item| rewrite_scalar_expr_for_subqueries(item, mapping))
8742 .collect(),
8743 ),
8744 ScalarExpr::Aggregate(_) | ScalarExpr::Column(_) | ScalarExpr::Literal(_) => expr.clone(),
8745 }
8746}
8747
8748fn collect_correlated_bindings(
8749 context: &mut CrossProductExpressionContext,
8750 batch: &RecordBatch,
8751 row_idx: usize,
8752 columns: &[llkv_plan::CorrelatedColumn],
8753) -> ExecutorResult<FxHashMap<String, Literal>> {
8754 let mut out = FxHashMap::default();
8755
8756 for correlated in columns {
8757 if !correlated.field_path.is_empty() {
8758 return Err(Error::InvalidArgumentError(
8759 "correlated field path resolution is not yet supported".into(),
8760 ));
8761 }
8762
8763 let field_id = context
8764 .field_id_for_column(&correlated.column)
8765 .ok_or_else(|| {
8766 Error::InvalidArgumentError(format!(
8767 "correlated column '{}' not found in outer query output",
8768 correlated.column
8769 ))
8770 })?;
8771
8772 let accessor = context.column_accessor(field_id, batch)?;
8773 let literal = accessor.literal_at(row_idx)?;
8774 out.insert(correlated.placeholder.clone(), literal);
8775 }
8776
8777 Ok(out)
8778}
8779
/// A prepared SELECT result: a table name, an output schema, and the source
/// that yields the result batches when the execution is streamed or collected.
#[derive(Clone)]
pub struct SelectExecution<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Name returned by `table_name()`.
    table_name: String,
    /// Output schema returned by `schema()`; also used when synthesizing
    /// null rows and when concatenating batches for a post-scan sort.
    schema: Arc<Schema>,
    /// Where the batches come from (a table scan or a single ready batch).
    stream: SelectStream<P>,
}
8790
/// Source of the batches produced by a `SelectExecution`.
#[derive(Clone)]
enum SelectStream<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// A table scan that runs when the execution is streamed.
    Projection {
        /// Table whose `scan_stream` is invoked.
        table: Arc<ExecutorTable<P>>,
        /// Projections handed to the scan.
        projections: Vec<ScanProjection>,
        /// Filter expression handed to the scan.
        filter_expr: LlkvExpr<'static, FieldId>,
        /// Scan options; `order`, `include_nulls` and `row_id_filter` are
        /// also inspected by `SelectExecution::stream` after the scan.
        options: ScanStreamOptions<P>,
        /// Gates the synthesis of all-null rows for rows the scan did not emit.
        full_table_scan: bool,
        /// ORDER BY items; a post-scan sort is applied when there is more than
        /// one item or when the scan options carry no order of their own.
        order_by: Vec<OrderByPlan>,
        /// When set, scanned batches are passed through `distinct_filter_batch`.
        distinct: bool,
    },
    /// A single batch that is emitted as-is.
    Aggregation {
        /// The batch to emit.
        batch: RecordBatch,
    },
}
8809
8810impl<P> SelectExecution<P>
8811where
8812 P: Pager<Blob = EntryHandle> + Send + Sync,
8813{
8814 #[allow(clippy::too_many_arguments)]
8815 fn new_projection(
8816 table_name: String,
8817 schema: Arc<Schema>,
8818 table: Arc<ExecutorTable<P>>,
8819 projections: Vec<ScanProjection>,
8820 filter_expr: LlkvExpr<'static, FieldId>,
8821 options: ScanStreamOptions<P>,
8822 full_table_scan: bool,
8823 order_by: Vec<OrderByPlan>,
8824 distinct: bool,
8825 ) -> Self {
8826 Self {
8827 table_name,
8828 schema,
8829 stream: SelectStream::Projection {
8830 table,
8831 projections,
8832 filter_expr,
8833 options,
8834 full_table_scan,
8835 order_by,
8836 distinct,
8837 },
8838 }
8839 }
8840
8841 pub fn new_single_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
8842 Self {
8843 table_name,
8844 schema,
8845 stream: SelectStream::Aggregation { batch },
8846 }
8847 }
8848
8849 pub fn from_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
8850 Self::new_single_batch(table_name, schema, batch)
8851 }
8852
8853 pub fn table_name(&self) -> &str {
8854 &self.table_name
8855 }
8856
8857 pub fn schema(&self) -> Arc<Schema> {
8858 Arc::clone(&self.schema)
8859 }
8860
8861 pub fn stream(
8862 self,
8863 mut on_batch: impl FnMut(RecordBatch) -> ExecutorResult<()>,
8864 ) -> ExecutorResult<()> {
8865 let schema = Arc::clone(&self.schema);
8866 match self.stream {
8867 SelectStream::Projection {
8868 table,
8869 projections,
8870 filter_expr,
8871 options,
8872 full_table_scan,
8873 order_by,
8874 distinct,
8875 } => {
8876 let total_rows = table.total_rows.load(Ordering::SeqCst);
8878 if total_rows == 0 {
8879 return Ok(());
8881 }
8882
8883 let mut error: Option<Error> = None;
8884 let mut produced = false;
8885 let mut produced_rows: u64 = 0;
8886 let capture_nulls_first = matches!(options.order, Some(spec) if spec.nulls_first);
8887 let needs_post_sort =
8888 !order_by.is_empty() && (order_by.len() > 1 || options.order.is_none());
8889 let collect_batches = needs_post_sort || capture_nulls_first;
8890 let include_nulls = options.include_nulls;
8891 let has_row_id_filter = options.row_id_filter.is_some();
8892 let mut distinct_state = if distinct {
8893 Some(DistinctState::default())
8894 } else {
8895 None
8896 };
8897 let scan_options = options;
8898 let mut buffered_batches: Vec<RecordBatch> = Vec::new();
8899 table
8900 .table
8901 .scan_stream(projections, &filter_expr, scan_options, |batch| {
8902 if error.is_some() {
8903 return;
8904 }
8905 let mut batch = batch;
8906 if let Some(state) = distinct_state.as_mut() {
8907 match distinct_filter_batch(batch, state) {
8908 Ok(Some(filtered)) => {
8909 batch = filtered;
8910 }
8911 Ok(None) => {
8912 return;
8913 }
8914 Err(err) => {
8915 error = Some(err);
8916 return;
8917 }
8918 }
8919 }
8920 produced = true;
8921 produced_rows = produced_rows.saturating_add(batch.num_rows() as u64);
8922 if collect_batches {
8923 buffered_batches.push(batch);
8924 } else if let Err(err) = on_batch(batch) {
8925 error = Some(err);
8926 }
8927 })?;
8928 if let Some(err) = error {
8929 return Err(err);
8930 }
8931 if !produced {
8932 if !distinct && full_table_scan && total_rows > 0 {
8935 for batch in synthesize_null_scan(Arc::clone(&schema), total_rows)? {
8936 on_batch(batch)?;
8937 }
8938 }
8939 return Ok(());
8940 }
8941 let mut null_batches: Vec<RecordBatch> = Vec::new();
8942 if !distinct
8948 && include_nulls
8949 && full_table_scan
8950 && produced_rows < total_rows
8951 && !has_row_id_filter
8952 {
8953 let missing = total_rows - produced_rows;
8954 if missing > 0 {
8955 null_batches = synthesize_null_scan(Arc::clone(&schema), missing)?;
8956 }
8957 }
8958
8959 if collect_batches {
8960 if needs_post_sort {
8961 if !null_batches.is_empty() {
8962 buffered_batches.extend(null_batches);
8963 }
8964 if !buffered_batches.is_empty() {
8965 let combined =
8966 concat_batches(&schema, &buffered_batches).map_err(|err| {
8967 Error::InvalidArgumentError(format!(
8968 "failed to concatenate result batches for ORDER BY: {}",
8969 err
8970 ))
8971 })?;
8972 let sorted_batch =
8973 sort_record_batch_with_order(&schema, &combined, &order_by)?;
8974 on_batch(sorted_batch)?;
8975 }
8976 } else if capture_nulls_first {
8977 for batch in null_batches {
8978 on_batch(batch)?;
8979 }
8980 for batch in buffered_batches {
8981 on_batch(batch)?;
8982 }
8983 }
8984 } else if !null_batches.is_empty() {
8985 for batch in null_batches {
8986 on_batch(batch)?;
8987 }
8988 }
8989 Ok(())
8990 }
8991 SelectStream::Aggregation { batch } => on_batch(batch),
8992 }
8993 }
8994
8995 pub fn collect(self) -> ExecutorResult<Vec<RecordBatch>> {
8996 let mut batches = Vec::new();
8997 self.stream(|batch| {
8998 batches.push(batch);
8999 Ok(())
9000 })?;
9001 Ok(batches)
9002 }
9003
9004 pub fn collect_rows(self) -> ExecutorResult<ExecutorRowBatch> {
9005 let schema = self.schema();
9006 let mut rows: Vec<Vec<PlanValue>> = Vec::new();
9007 self.stream(|batch| {
9008 for row_idx in 0..batch.num_rows() {
9009 let mut row: Vec<PlanValue> = Vec::with_capacity(batch.num_columns());
9010 for col_idx in 0..batch.num_columns() {
9011 let value = llkv_plan::plan_value_from_array(batch.column(col_idx), row_idx)?;
9012 row.push(value);
9013 }
9014 rows.push(row);
9015 }
9016 Ok(())
9017 })?;
9018 let columns = schema
9019 .fields()
9020 .iter()
9021 .map(|field| field.name().to_string())
9022 .collect();
9023 Ok(ExecutorRowBatch { columns, rows })
9024 }
9025
9026 pub fn into_rows(self) -> ExecutorResult<Vec<Vec<PlanValue>>> {
9027 Ok(self.collect_rows()?.rows)
9028 }
9029}
9030
9031impl<P> fmt::Debug for SelectExecution<P>
9032where
9033 P: Pager<Blob = EntryHandle> + Send + Sync,
9034{
9035 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
9036 f.debug_struct("SelectExecution")
9037 .field("table_name", &self.table_name)
9038 .field("schema", &self.schema)
9039 .finish()
9040 }
9041}
9042
9043fn expand_order_targets(
9048 order_items: &[OrderByPlan],
9049 projections: &[ScanProjection],
9050) -> ExecutorResult<Vec<OrderByPlan>> {
9051 let mut expanded = Vec::new();
9052
9053 for item in order_items {
9054 match &item.target {
9055 OrderTarget::All => {
9056 if projections.is_empty() {
9057 return Err(Error::InvalidArgumentError(
9058 "ORDER BY ALL requires at least one projection".into(),
9059 ));
9060 }
9061
9062 for (idx, projection) in projections.iter().enumerate() {
9063 if matches!(projection, ScanProjection::Computed { .. }) {
9064 return Err(Error::InvalidArgumentError(
9065 "ORDER BY ALL cannot reference computed projections".into(),
9066 ));
9067 }
9068
9069 let mut clone = item.clone();
9070 clone.target = OrderTarget::Index(idx);
9071 expanded.push(clone);
9072 }
9073 }
9074 _ => expanded.push(item.clone()),
9075 }
9076 }
9077
9078 Ok(expanded)
9079}
9080
9081fn resolve_scan_order<P>(
9082 table: &ExecutorTable<P>,
9083 projections: &[ScanProjection],
9084 order_plan: &OrderByPlan,
9085) -> ExecutorResult<ScanOrderSpec>
9086where
9087 P: Pager<Blob = EntryHandle> + Send + Sync,
9088{
9089 let (column, field_id) = match &order_plan.target {
9090 OrderTarget::Column(name) => {
9091 let column = table.schema.resolve(name).ok_or_else(|| {
9092 Error::InvalidArgumentError(format!("unknown column '{}' in ORDER BY", name))
9093 })?;
9094 (column, column.field_id)
9095 }
9096 OrderTarget::Index(position) => {
9097 let projection = projections.get(*position).ok_or_else(|| {
9098 Error::InvalidArgumentError(format!(
9099 "ORDER BY position {} is out of range",
9100 position + 1
9101 ))
9102 })?;
9103 match projection {
9104 ScanProjection::Column(store_projection) => {
9105 let field_id = store_projection.logical_field_id.field_id();
9106 let column = table.schema.column_by_field_id(field_id).ok_or_else(|| {
9107 Error::InvalidArgumentError(format!(
9108 "unknown column with field id {field_id} in ORDER BY"
9109 ))
9110 })?;
9111 (column, field_id)
9112 }
9113 ScanProjection::Computed { .. } => {
9114 return Err(Error::InvalidArgumentError(
9115 "ORDER BY position referring to computed projection is not supported"
9116 .into(),
9117 ));
9118 }
9119 }
9120 }
9121 OrderTarget::All => {
9122 return Err(Error::InvalidArgumentError(
9123 "ORDER BY ALL should be expanded before execution".into(),
9124 ));
9125 }
9126 };
9127
9128 let transform = match order_plan.sort_type {
9129 OrderSortType::Native => match column.data_type {
9130 DataType::Int64 => ScanOrderTransform::IdentityInteger,
9131 DataType::Utf8 => ScanOrderTransform::IdentityUtf8,
9132 ref other => {
9133 return Err(Error::InvalidArgumentError(format!(
9134 "ORDER BY on column type {:?} is not supported",
9135 other
9136 )));
9137 }
9138 },
9139 OrderSortType::CastTextToInteger => {
9140 if column.data_type != DataType::Utf8 {
9141 return Err(Error::InvalidArgumentError(
9142 "ORDER BY CAST expects a text column".into(),
9143 ));
9144 }
9145 ScanOrderTransform::CastUtf8ToInteger
9146 }
9147 };
9148
9149 let direction = if order_plan.ascending {
9150 ScanOrderDirection::Ascending
9151 } else {
9152 ScanOrderDirection::Descending
9153 };
9154
9155 Ok(ScanOrderSpec {
9156 field_id,
9157 direction,
9158 nulls_first: order_plan.nulls_first,
9159 transform,
9160 })
9161}
9162
9163fn synthesize_null_scan(schema: Arc<Schema>, total_rows: u64) -> ExecutorResult<Vec<RecordBatch>> {
9164 let row_count = usize::try_from(total_rows).map_err(|_| {
9165 Error::InvalidArgumentError("table row count exceeds supported in-memory batch size".into())
9166 })?;
9167
9168 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
9169 for field in schema.fields() {
9170 match field.data_type() {
9171 DataType::Int64 => {
9172 let mut builder = Int64Builder::with_capacity(row_count);
9173 for _ in 0..row_count {
9174 builder.append_null();
9175 }
9176 arrays.push(Arc::new(builder.finish()));
9177 }
9178 DataType::Float64 => {
9179 let mut builder = arrow::array::Float64Builder::with_capacity(row_count);
9180 for _ in 0..row_count {
9181 builder.append_null();
9182 }
9183 arrays.push(Arc::new(builder.finish()));
9184 }
9185 DataType::Utf8 => {
9186 let mut builder = arrow::array::StringBuilder::with_capacity(row_count, 0);
9187 for _ in 0..row_count {
9188 builder.append_null();
9189 }
9190 arrays.push(Arc::new(builder.finish()));
9191 }
9192 DataType::Date32 => {
9193 let mut builder = arrow::array::Date32Builder::with_capacity(row_count);
9194 for _ in 0..row_count {
9195 builder.append_null();
9196 }
9197 arrays.push(Arc::new(builder.finish()));
9198 }
9199 other => {
9200 return Err(Error::InvalidArgumentError(format!(
9201 "unsupported data type in null synthesis: {other:?}"
9202 )));
9203 }
9204 }
9205 }
9206
9207 let batch = RecordBatch::try_new(schema, arrays)?;
9208 Ok(vec![batch])
9209}
9210
/// Materialized rows of one table, or of several tables already joined,
/// together with bookkeeping about which source tables contributed columns.
struct TableCrossProductData {
    /// Schema shared by every batch in `batches`.
    schema: Arc<Schema>,
    /// The materialized rows.
    batches: Vec<RecordBatch>,
    /// Number of columns contributed by each source table; a join
    /// concatenates the left side's entries followed by the right side's.
    column_counts: Vec<usize>,
    /// Index of each source table, in the same order as `column_counts`.
    table_indices: Vec<usize>,
}
9217
9218fn collect_table_data<P>(
9219 table_index: usize,
9220 table_ref: &llkv_plan::TableRef,
9221 table: &ExecutorTable<P>,
9222 constraints: &[ColumnConstraint],
9223) -> ExecutorResult<TableCrossProductData>
9224where
9225 P: Pager<Blob = EntryHandle> + Send + Sync,
9226{
9227 if table.schema.columns.is_empty() {
9228 return Err(Error::InvalidArgumentError(format!(
9229 "table '{}' has no columns; cross products require at least one column",
9230 table_ref.qualified_name()
9231 )));
9232 }
9233
9234 let mut projections = Vec::with_capacity(table.schema.columns.len());
9235 let mut fields = Vec::with_capacity(table.schema.columns.len());
9236
9237 for column in &table.schema.columns {
9238 let table_component = table_ref
9239 .alias
9240 .as_deref()
9241 .unwrap_or(table_ref.table.as_str());
9242 let qualified_name = format!("{}.{}.{}", table_ref.schema, table_component, column.name);
9243 projections.push(ScanProjection::from(StoreProjection::with_alias(
9244 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
9245 qualified_name.clone(),
9246 )));
9247 fields.push(Field::new(
9248 qualified_name,
9249 column.data_type.clone(),
9250 column.nullable,
9251 ));
9252 }
9253
9254 let schema = Arc::new(Schema::new(fields));
9255
9256 let filter_field_id = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);
9257 let filter_expr = crate::translation::expression::full_table_scan_filter(filter_field_id);
9258
9259 let mut raw_batches = Vec::new();
9260 table.table.scan_stream(
9261 projections,
9262 &filter_expr,
9263 ScanStreamOptions {
9264 include_nulls: true,
9265 ..ScanStreamOptions::default()
9266 },
9267 |batch| {
9268 raw_batches.push(batch);
9269 },
9270 )?;
9271
9272 let mut normalized_batches = Vec::with_capacity(raw_batches.len());
9273 for batch in raw_batches {
9274 let normalized = RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec())
9275 .map_err(|err| {
9276 Error::Internal(format!(
9277 "failed to align scan batch for table '{}': {}",
9278 table_ref.qualified_name(),
9279 err
9280 ))
9281 })?;
9282 normalized_batches.push(normalized);
9283 }
9284
9285 if !constraints.is_empty() {
9286 normalized_batches = apply_column_constraints_to_batches(normalized_batches, constraints)?;
9287 }
9288
9289 Ok(TableCrossProductData {
9290 schema,
9291 batches: normalized_batches,
9292 column_counts: vec![table.schema.columns.len()],
9293 table_indices: vec![table_index],
9294 })
9295}
9296
9297fn apply_column_constraints_to_batches(
9298 batches: Vec<RecordBatch>,
9299 constraints: &[ColumnConstraint],
9300) -> ExecutorResult<Vec<RecordBatch>> {
9301 if batches.is_empty() {
9302 return Ok(batches);
9303 }
9304
9305 let mut filtered = batches;
9306 for constraint in constraints {
9307 match constraint {
9308 ColumnConstraint::Equality(lit) => {
9309 filtered = filter_batches_by_literal(filtered, lit.column.column, &lit.value)?;
9310 }
9311 ColumnConstraint::InList(in_list) => {
9312 filtered =
9313 filter_batches_by_in_list(filtered, in_list.column.column, &in_list.values)?;
9314 }
9315 }
9316 if filtered.is_empty() {
9317 break;
9318 }
9319 }
9320
9321 Ok(filtered)
9322}
9323
9324fn filter_batches_by_literal(
9325 batches: Vec<RecordBatch>,
9326 column_idx: usize,
9327 literal: &PlanValue,
9328) -> ExecutorResult<Vec<RecordBatch>> {
9329 let mut result = Vec::with_capacity(batches.len());
9330
9331 for batch in batches {
9332 if column_idx >= batch.num_columns() {
9333 return Err(Error::Internal(
9334 "literal constraint referenced invalid column index".into(),
9335 ));
9336 }
9337
9338 if batch.num_rows() == 0 {
9339 result.push(batch);
9340 continue;
9341 }
9342
9343 let column = batch.column(column_idx);
9344 let mut keep_rows: Vec<u32> = Vec::with_capacity(batch.num_rows());
9345
9346 for row_idx in 0..batch.num_rows() {
9347 if array_value_equals_plan_value(column.as_ref(), row_idx, literal)? {
9348 keep_rows.push(row_idx as u32);
9349 }
9350 }
9351
9352 if keep_rows.len() == batch.num_rows() {
9353 result.push(batch);
9354 continue;
9355 }
9356
9357 if keep_rows.is_empty() {
9358 continue;
9360 }
9361
9362 let indices = UInt32Array::from(keep_rows);
9363 let mut filtered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
9364 for col_idx in 0..batch.num_columns() {
9365 let filtered = take(batch.column(col_idx).as_ref(), &indices, None)
9366 .map_err(|err| Error::Internal(format!("failed to apply literal filter: {err}")))?;
9367 filtered_columns.push(filtered);
9368 }
9369
9370 let filtered_batch =
9371 RecordBatch::try_new(batch.schema(), filtered_columns).map_err(|err| {
9372 Error::Internal(format!(
9373 "failed to rebuild batch after literal filter: {err}"
9374 ))
9375 })?;
9376 result.push(filtered_batch);
9377 }
9378
9379 Ok(result)
9380}
9381
9382fn filter_batches_by_in_list(
9383 batches: Vec<RecordBatch>,
9384 column_idx: usize,
9385 values: &[PlanValue],
9386) -> ExecutorResult<Vec<RecordBatch>> {
9387 use arrow::array::*;
9388 use arrow::compute::or;
9389
9390 if values.is_empty() {
9391 return Ok(Vec::new());
9393 }
9394
9395 let mut result = Vec::with_capacity(batches.len());
9396
9397 for batch in batches {
9398 if column_idx >= batch.num_columns() {
9399 return Err(Error::Internal(
9400 "IN list constraint referenced invalid column index".into(),
9401 ));
9402 }
9403
9404 if batch.num_rows() == 0 {
9405 result.push(batch);
9406 continue;
9407 }
9408
9409 let column = batch.column(column_idx);
9410
9411 let mut mask = BooleanArray::from(vec![false; batch.num_rows()]);
9414
9415 for value in values {
9416 let comparison_mask = build_comparison_mask(column.as_ref(), value)?;
9417 mask = or(&mask, &comparison_mask)
9418 .map_err(|err| Error::Internal(format!("failed to OR comparison masks: {err}")))?;
9419 }
9420
9421 let true_count = mask.true_count();
9423 if true_count == batch.num_rows() {
9424 result.push(batch);
9425 continue;
9426 }
9427
9428 if true_count == 0 {
9429 continue;
9431 }
9432
9433 let filtered_batch = arrow::compute::filter_record_batch(&batch, &mask)
9435 .map_err(|err| Error::Internal(format!("failed to apply IN list filter: {err}")))?;
9436
9437 result.push(filtered_batch);
9438 }
9439
9440 Ok(result)
9441}
9442
9443fn build_comparison_mask(column: &dyn Array, value: &PlanValue) -> ExecutorResult<BooleanArray> {
9445 use arrow::array::*;
9446 use arrow::datatypes::DataType;
9447
9448 match value {
9449 PlanValue::Null => {
9450 let mut builder = BooleanBuilder::with_capacity(column.len());
9452 for i in 0..column.len() {
9453 builder.append_value(column.is_null(i));
9454 }
9455 Ok(builder.finish())
9456 }
9457 PlanValue::Integer(val) => {
9458 let mut builder = BooleanBuilder::with_capacity(column.len());
9459 match column.data_type() {
9460 DataType::Int8 => {
9461 let arr = column
9462 .as_any()
9463 .downcast_ref::<Int8Array>()
9464 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
9465 let target = *val as i8;
9466 for i in 0..arr.len() {
9467 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9468 }
9469 }
9470 DataType::Int16 => {
9471 let arr = column
9472 .as_any()
9473 .downcast_ref::<Int16Array>()
9474 .ok_or_else(|| {
9475 Error::Internal("failed to downcast to Int16Array".into())
9476 })?;
9477 let target = *val as i16;
9478 for i in 0..arr.len() {
9479 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9480 }
9481 }
9482 DataType::Int32 => {
9483 let arr = column
9484 .as_any()
9485 .downcast_ref::<Int32Array>()
9486 .ok_or_else(|| {
9487 Error::Internal("failed to downcast to Int32Array".into())
9488 })?;
9489 let target = *val as i32;
9490 for i in 0..arr.len() {
9491 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9492 }
9493 }
9494 DataType::Int64 => {
9495 let arr = column
9496 .as_any()
9497 .downcast_ref::<Int64Array>()
9498 .ok_or_else(|| {
9499 Error::Internal("failed to downcast to Int64Array".into())
9500 })?;
9501 for i in 0..arr.len() {
9502 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
9503 }
9504 }
9505 DataType::UInt8 => {
9506 let arr = column
9507 .as_any()
9508 .downcast_ref::<UInt8Array>()
9509 .ok_or_else(|| {
9510 Error::Internal("failed to downcast to UInt8Array".into())
9511 })?;
9512 let target = *val as u8;
9513 for i in 0..arr.len() {
9514 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9515 }
9516 }
9517 DataType::UInt16 => {
9518 let arr = column
9519 .as_any()
9520 .downcast_ref::<UInt16Array>()
9521 .ok_or_else(|| {
9522 Error::Internal("failed to downcast to UInt16Array".into())
9523 })?;
9524 let target = *val as u16;
9525 for i in 0..arr.len() {
9526 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9527 }
9528 }
9529 DataType::UInt32 => {
9530 let arr = column
9531 .as_any()
9532 .downcast_ref::<UInt32Array>()
9533 .ok_or_else(|| {
9534 Error::Internal("failed to downcast to UInt32Array".into())
9535 })?;
9536 let target = *val as u32;
9537 for i in 0..arr.len() {
9538 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9539 }
9540 }
9541 DataType::UInt64 => {
9542 let arr = column
9543 .as_any()
9544 .downcast_ref::<UInt64Array>()
9545 .ok_or_else(|| {
9546 Error::Internal("failed to downcast to UInt64Array".into())
9547 })?;
9548 let target = *val as u64;
9549 for i in 0..arr.len() {
9550 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9551 }
9552 }
9553 _ => {
9554 return Err(Error::Internal(format!(
9555 "unsupported integer type for IN list: {:?}",
9556 column.data_type()
9557 )));
9558 }
9559 }
9560 Ok(builder.finish())
9561 }
9562 PlanValue::Float(val) => {
9563 let mut builder = BooleanBuilder::with_capacity(column.len());
9564 match column.data_type() {
9565 DataType::Float32 => {
9566 let arr = column
9567 .as_any()
9568 .downcast_ref::<Float32Array>()
9569 .ok_or_else(|| {
9570 Error::Internal("failed to downcast to Float32Array".into())
9571 })?;
9572 let target = *val as f32;
9573 for i in 0..arr.len() {
9574 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
9575 }
9576 }
9577 DataType::Float64 => {
9578 let arr = column
9579 .as_any()
9580 .downcast_ref::<Float64Array>()
9581 .ok_or_else(|| {
9582 Error::Internal("failed to downcast to Float64Array".into())
9583 })?;
9584 for i in 0..arr.len() {
9585 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
9586 }
9587 }
9588 _ => {
9589 return Err(Error::Internal(format!(
9590 "unsupported float type for IN list: {:?}",
9591 column.data_type()
9592 )));
9593 }
9594 }
9595 Ok(builder.finish())
9596 }
9597 PlanValue::String(val) => {
9598 let mut builder = BooleanBuilder::with_capacity(column.len());
9599 let arr = column
9600 .as_any()
9601 .downcast_ref::<StringArray>()
9602 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
9603 for i in 0..arr.len() {
9604 builder.append_value(!arr.is_null(i) && arr.value(i) == val.as_str());
9605 }
9606 Ok(builder.finish())
9607 }
9608 PlanValue::Struct(_) => Err(Error::Internal(
9609 "struct comparison in IN list not supported".into(),
9610 )),
9611 }
9612}
9613
9614fn array_value_equals_plan_value(
9615 array: &dyn Array,
9616 row_idx: usize,
9617 literal: &PlanValue,
9618) -> ExecutorResult<bool> {
9619 use arrow::array::*;
9620 use arrow::datatypes::DataType;
9621
9622 match literal {
9623 PlanValue::Null => Ok(array.is_null(row_idx)),
9624 PlanValue::Integer(expected) => match array.data_type() {
9625 DataType::Int8 => Ok(!array.is_null(row_idx)
9626 && array
9627 .as_any()
9628 .downcast_ref::<Int8Array>()
9629 .expect("int8 array")
9630 .value(row_idx) as i64
9631 == *expected),
9632 DataType::Int16 => Ok(!array.is_null(row_idx)
9633 && array
9634 .as_any()
9635 .downcast_ref::<Int16Array>()
9636 .expect("int16 array")
9637 .value(row_idx) as i64
9638 == *expected),
9639 DataType::Int32 => Ok(!array.is_null(row_idx)
9640 && array
9641 .as_any()
9642 .downcast_ref::<Int32Array>()
9643 .expect("int32 array")
9644 .value(row_idx) as i64
9645 == *expected),
9646 DataType::Int64 => Ok(!array.is_null(row_idx)
9647 && array
9648 .as_any()
9649 .downcast_ref::<Int64Array>()
9650 .expect("int64 array")
9651 .value(row_idx)
9652 == *expected),
9653 DataType::UInt8 if *expected >= 0 => Ok(!array.is_null(row_idx)
9654 && array
9655 .as_any()
9656 .downcast_ref::<UInt8Array>()
9657 .expect("uint8 array")
9658 .value(row_idx) as i64
9659 == *expected),
9660 DataType::UInt16 if *expected >= 0 => Ok(!array.is_null(row_idx)
9661 && array
9662 .as_any()
9663 .downcast_ref::<UInt16Array>()
9664 .expect("uint16 array")
9665 .value(row_idx) as i64
9666 == *expected),
9667 DataType::UInt32 if *expected >= 0 => Ok(!array.is_null(row_idx)
9668 && array
9669 .as_any()
9670 .downcast_ref::<UInt32Array>()
9671 .expect("uint32 array")
9672 .value(row_idx) as i64
9673 == *expected),
9674 DataType::UInt64 if *expected >= 0 => Ok(!array.is_null(row_idx)
9675 && array
9676 .as_any()
9677 .downcast_ref::<UInt64Array>()
9678 .expect("uint64 array")
9679 .value(row_idx)
9680 == *expected as u64),
9681 DataType::Boolean => {
9682 if array.is_null(row_idx) {
9683 Ok(false)
9684 } else if *expected == 0 || *expected == 1 {
9685 let value = array
9686 .as_any()
9687 .downcast_ref::<BooleanArray>()
9688 .expect("bool array")
9689 .value(row_idx);
9690 Ok(value == (*expected == 1))
9691 } else {
9692 Ok(false)
9693 }
9694 }
9695 _ => Err(Error::InvalidArgumentError(format!(
9696 "literal integer comparison not supported for {:?}",
9697 array.data_type()
9698 ))),
9699 },
9700 PlanValue::Float(expected) => match array.data_type() {
9701 DataType::Float32 => Ok(!array.is_null(row_idx)
9702 && (array
9703 .as_any()
9704 .downcast_ref::<Float32Array>()
9705 .expect("float32 array")
9706 .value(row_idx) as f64
9707 - *expected)
9708 .abs()
9709 .eq(&0.0)),
9710 DataType::Float64 => Ok(!array.is_null(row_idx)
9711 && (array
9712 .as_any()
9713 .downcast_ref::<Float64Array>()
9714 .expect("float64 array")
9715 .value(row_idx)
9716 - *expected)
9717 .abs()
9718 .eq(&0.0)),
9719 _ => Err(Error::InvalidArgumentError(format!(
9720 "literal float comparison not supported for {:?}",
9721 array.data_type()
9722 ))),
9723 },
9724 PlanValue::String(expected) => match array.data_type() {
9725 DataType::Utf8 => Ok(!array.is_null(row_idx)
9726 && array
9727 .as_any()
9728 .downcast_ref::<StringArray>()
9729 .expect("string array")
9730 .value(row_idx)
9731 == expected),
9732 DataType::LargeUtf8 => Ok(!array.is_null(row_idx)
9733 && array
9734 .as_any()
9735 .downcast_ref::<LargeStringArray>()
9736 .expect("large string array")
9737 .value(row_idx)
9738 == expected),
9739 _ => Err(Error::InvalidArgumentError(format!(
9740 "literal string comparison not supported for {:?}",
9741 array.data_type()
9742 ))),
9743 },
9744 PlanValue::Struct(_) => Err(Error::InvalidArgumentError(
9745 "struct literals are not supported in join filters".into(),
9746 )),
9747 }
9748}
9749
9750fn hash_join_table_batches(
9751 left: TableCrossProductData,
9752 right: TableCrossProductData,
9753 join_keys: &[(usize, usize)],
9754 join_type: llkv_join::JoinType,
9755) -> ExecutorResult<TableCrossProductData> {
9756 let TableCrossProductData {
9757 schema: left_schema,
9758 batches: left_batches,
9759 column_counts: left_counts,
9760 table_indices: left_tables,
9761 } = left;
9762
9763 let TableCrossProductData {
9764 schema: right_schema,
9765 batches: right_batches,
9766 column_counts: right_counts,
9767 table_indices: right_tables,
9768 } = right;
9769
9770 let combined_fields: Vec<Field> = left_schema
9771 .fields()
9772 .iter()
9773 .chain(right_schema.fields().iter())
9774 .map(|field| field.as_ref().clone())
9775 .collect();
9776
9777 let combined_schema = Arc::new(Schema::new(combined_fields));
9778
9779 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
9780 column_counts.extend(left_counts.iter());
9781 column_counts.extend(right_counts.iter());
9782
9783 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
9784 table_indices.extend(left_tables.iter().copied());
9785 table_indices.extend(right_tables.iter().copied());
9786
9787 if left_batches.is_empty() {
9789 return Ok(TableCrossProductData {
9790 schema: combined_schema,
9791 batches: Vec::new(),
9792 column_counts,
9793 table_indices,
9794 });
9795 }
9796
9797 if right_batches.is_empty() {
9798 if join_type == llkv_join::JoinType::Left {
9800 let total_left_rows: usize = left_batches.iter().map(|b| b.num_rows()).sum();
9801 let mut left_arrays = Vec::new();
9802 for field in left_schema.fields() {
9803 let column_idx = left_schema.index_of(field.name()).map_err(|e| {
9804 Error::Internal(format!("failed to find field {}: {}", field.name(), e))
9805 })?;
9806 let arrays: Vec<ArrayRef> = left_batches
9807 .iter()
9808 .map(|batch| batch.column(column_idx).clone())
9809 .collect();
9810 let concatenated =
9811 arrow::compute::concat(&arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
9812 .map_err(|e| {
9813 Error::Internal(format!("failed to concat left arrays: {}", e))
9814 })?;
9815 left_arrays.push(concatenated);
9816 }
9817
9818 for field in right_schema.fields() {
9820 let null_array = arrow::array::new_null_array(field.data_type(), total_left_rows);
9821 left_arrays.push(null_array);
9822 }
9823
9824 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), left_arrays)
9825 .map_err(|err| {
9826 Error::Internal(format!(
9827 "failed to create LEFT JOIN batch with NULL right: {err}"
9828 ))
9829 })?;
9830
9831 return Ok(TableCrossProductData {
9832 schema: combined_schema,
9833 batches: vec![joined_batch],
9834 column_counts,
9835 table_indices,
9836 });
9837 } else {
9838 return Ok(TableCrossProductData {
9840 schema: combined_schema,
9841 batches: Vec::new(),
9842 column_counts,
9843 table_indices,
9844 });
9845 }
9846 }
9847
9848 match join_type {
9849 llkv_join::JoinType::Inner => {
9850 let (left_matches, right_matches) =
9851 build_join_match_indices(&left_batches, &right_batches, join_keys)?;
9852
9853 if left_matches.is_empty() {
9854 return Ok(TableCrossProductData {
9855 schema: combined_schema,
9856 batches: Vec::new(),
9857 column_counts,
9858 table_indices,
9859 });
9860 }
9861
9862 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
9863 let right_arrays = gather_indices_from_batches(&right_batches, &right_matches)?;
9864
9865 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
9866 combined_columns.extend(left_arrays);
9867 combined_columns.extend(right_arrays);
9868
9869 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
9870 .map_err(|err| {
9871 Error::Internal(format!("failed to materialize INNER JOIN batch: {err}"))
9872 })?;
9873
9874 Ok(TableCrossProductData {
9875 schema: combined_schema,
9876 batches: vec![joined_batch],
9877 column_counts,
9878 table_indices,
9879 })
9880 }
9881 llkv_join::JoinType::Left => {
9882 let (left_matches, right_optional_matches) =
9883 build_left_join_match_indices(&left_batches, &right_batches, join_keys)?;
9884
9885 if left_matches.is_empty() {
9886 return Ok(TableCrossProductData {
9888 schema: combined_schema,
9889 batches: Vec::new(),
9890 column_counts,
9891 table_indices,
9892 });
9893 }
9894
9895 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
9896 let right_arrays = llkv_column_map::gather::gather_optional_indices_from_batches(
9898 &right_batches,
9899 &right_optional_matches,
9900 )?;
9901
9902 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
9903 combined_columns.extend(left_arrays);
9904 combined_columns.extend(right_arrays);
9905
9906 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
9907 .map_err(|err| {
9908 Error::Internal(format!("failed to materialize LEFT JOIN batch: {err}"))
9909 })?;
9910
9911 Ok(TableCrossProductData {
9912 schema: combined_schema,
9913 batches: vec![joined_batch],
9914 column_counts,
9915 table_indices,
9916 })
9917 }
9918 _ => Err(Error::Internal(format!(
9920 "join type {:?} not supported in hash_join_table_batches; use llkv-join",
9921 join_type
9922 ))),
9923 }
9924}
9925
/// Row positions expressed as `(batch_index, row_index)` pairs into a slice of batches.
type JoinMatchIndices = Vec<(usize, usize)>;
/// Maps an encoded join key to every `(batch_index, row_index)` position that produced it.
type JoinHashTable = FxHashMap<Vec<u8>, Vec<(usize, usize)>>;
/// Left and right position lists for an inner join; entry `i` of each list forms one output row.
type JoinMatchPairs = (JoinMatchIndices, JoinMatchIndices);
/// Right-side positions for a left join, where `None` marks a left row without a match.
type OptionalJoinMatches = Vec<Option<(usize, usize)>>;
/// Left positions paired index-by-index with their optional right-side positions.
type LeftJoinMatchPairs = (JoinMatchIndices, OptionalJoinMatches);
9936
9937fn build_join_match_indices(
9967 left_batches: &[RecordBatch],
9968 right_batches: &[RecordBatch],
9969 join_keys: &[(usize, usize)],
9970) -> ExecutorResult<JoinMatchPairs> {
9971 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
9972
9973 let hash_table: JoinHashTable = llkv_column_map::parallel::with_thread_pool(|| {
9976 let local_tables: Vec<JoinHashTable> = right_batches
9977 .par_iter()
9978 .enumerate()
9979 .map(|(batch_idx, batch)| {
9980 let mut local_table: JoinHashTable = FxHashMap::default();
9981 let mut key_buffer: Vec<u8> = Vec::new();
9982
9983 for row_idx in 0..batch.num_rows() {
9984 key_buffer.clear();
9985 match build_join_key(batch, &right_key_indices, row_idx, &mut key_buffer) {
9986 Ok(true) => {
9987 local_table
9988 .entry(key_buffer.clone())
9989 .or_default()
9990 .push((batch_idx, row_idx));
9991 }
9992 Ok(false) => continue,
9993 Err(_) => continue, }
9995 }
9996
9997 local_table
9998 })
9999 .collect();
10000
10001 let mut merged_table: JoinHashTable = FxHashMap::default();
10003 for local_table in local_tables {
10004 for (key, mut positions) in local_table {
10005 merged_table.entry(key).or_default().append(&mut positions);
10006 }
10007 }
10008
10009 merged_table
10010 });
10011
10012 if hash_table.is_empty() {
10013 return Ok((Vec::new(), Vec::new()));
10014 }
10015
10016 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
10017
10018 let matches: Vec<JoinMatchPairs> = llkv_column_map::parallel::with_thread_pool(|| {
10021 left_batches
10022 .par_iter()
10023 .enumerate()
10024 .map(|(batch_idx, batch)| {
10025 let mut local_left_matches: JoinMatchIndices = Vec::new();
10026 let mut local_right_matches: JoinMatchIndices = Vec::new();
10027 let mut key_buffer: Vec<u8> = Vec::new();
10028
10029 for row_idx in 0..batch.num_rows() {
10030 key_buffer.clear();
10031 match build_join_key(batch, &left_key_indices, row_idx, &mut key_buffer) {
10032 Ok(true) => {
10033 if let Some(entries) = hash_table.get(&key_buffer) {
10034 for &(r_batch, r_row) in entries {
10035 local_left_matches.push((batch_idx, row_idx));
10036 local_right_matches.push((r_batch, r_row));
10037 }
10038 }
10039 }
10040 Ok(false) => continue,
10041 Err(_) => continue, }
10043 }
10044
10045 (local_left_matches, local_right_matches)
10046 })
10047 .collect()
10048 });
10049
10050 let mut left_matches: JoinMatchIndices = Vec::new();
10052 let mut right_matches: JoinMatchIndices = Vec::new();
10053 for (mut left, mut right) in matches {
10054 left_matches.append(&mut left);
10055 right_matches.append(&mut right);
10056 }
10057
10058 Ok((left_matches, right_matches))
10059}
10060
10061fn build_left_join_match_indices(
10072 left_batches: &[RecordBatch],
10073 right_batches: &[RecordBatch],
10074 join_keys: &[(usize, usize)],
10075) -> ExecutorResult<LeftJoinMatchPairs> {
10076 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
10077
10078 let hash_table: JoinHashTable = llkv_column_map::parallel::with_thread_pool(|| {
10080 let local_tables: Vec<JoinHashTable> = right_batches
10081 .par_iter()
10082 .enumerate()
10083 .map(|(batch_idx, batch)| {
10084 let mut local_table: JoinHashTable = FxHashMap::default();
10085 let mut key_buffer: Vec<u8> = Vec::new();
10086
10087 for row_idx in 0..batch.num_rows() {
10088 key_buffer.clear();
10089 match build_join_key(batch, &right_key_indices, row_idx, &mut key_buffer) {
10090 Ok(true) => {
10091 local_table
10092 .entry(key_buffer.clone())
10093 .or_default()
10094 .push((batch_idx, row_idx));
10095 }
10096 Ok(false) => continue,
10097 Err(_) => continue,
10098 }
10099 }
10100
10101 local_table
10102 })
10103 .collect();
10104
10105 let mut merged_table: JoinHashTable = FxHashMap::default();
10106 for local_table in local_tables {
10107 for (key, mut positions) in local_table {
10108 merged_table.entry(key).or_default().append(&mut positions);
10109 }
10110 }
10111
10112 merged_table
10113 });
10114
10115 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
10116
10117 let matches: Vec<LeftJoinMatchPairs> = llkv_column_map::parallel::with_thread_pool(|| {
10119 left_batches
10120 .par_iter()
10121 .enumerate()
10122 .map(|(batch_idx, batch)| {
10123 let mut local_left_matches: JoinMatchIndices = Vec::new();
10124 let mut local_right_optional: Vec<Option<(usize, usize)>> = Vec::new();
10125 let mut key_buffer: Vec<u8> = Vec::new();
10126
10127 for row_idx in 0..batch.num_rows() {
10128 key_buffer.clear();
10129 match build_join_key(batch, &left_key_indices, row_idx, &mut key_buffer) {
10130 Ok(true) => {
10131 if let Some(entries) = hash_table.get(&key_buffer) {
10132 for &(r_batch, r_row) in entries {
10134 local_left_matches.push((batch_idx, row_idx));
10135 local_right_optional.push(Some((r_batch, r_row)));
10136 }
10137 } else {
10138 local_left_matches.push((batch_idx, row_idx));
10140 local_right_optional.push(None);
10141 }
10142 }
10143 Ok(false) => {
10144 local_left_matches.push((batch_idx, row_idx));
10146 local_right_optional.push(None);
10147 }
10148 Err(_) => {
10149 local_left_matches.push((batch_idx, row_idx));
10151 local_right_optional.push(None);
10152 }
10153 }
10154 }
10155
10156 (local_left_matches, local_right_optional)
10157 })
10158 .collect()
10159 });
10160
10161 let mut left_matches: JoinMatchIndices = Vec::new();
10163 let mut right_optional: Vec<Option<(usize, usize)>> = Vec::new();
10164 for (mut left, mut right) in matches {
10165 left_matches.append(&mut left);
10166 right_optional.append(&mut right);
10167 }
10168
10169 Ok((left_matches, right_optional))
10170}
10171
10172fn build_join_key(
10173 batch: &RecordBatch,
10174 column_indices: &[usize],
10175 row_idx: usize,
10176 buffer: &mut Vec<u8>,
10177) -> ExecutorResult<bool> {
10178 buffer.clear();
10179
10180 for &col_idx in column_indices {
10181 let array = batch.column(col_idx);
10182 if array.is_null(row_idx) {
10183 return Ok(false);
10184 }
10185 append_array_value_to_key(array.as_ref(), row_idx, buffer)?;
10186 }
10187
10188 Ok(true)
10189}
10190
10191fn append_array_value_to_key(
10192 array: &dyn Array,
10193 row_idx: usize,
10194 buffer: &mut Vec<u8>,
10195) -> ExecutorResult<()> {
10196 use arrow::array::*;
10197 use arrow::datatypes::DataType;
10198
10199 match array.data_type() {
10200 DataType::Int8 => buffer.extend_from_slice(
10201 &array
10202 .as_any()
10203 .downcast_ref::<Int8Array>()
10204 .expect("int8 array")
10205 .value(row_idx)
10206 .to_le_bytes(),
10207 ),
10208 DataType::Int16 => buffer.extend_from_slice(
10209 &array
10210 .as_any()
10211 .downcast_ref::<Int16Array>()
10212 .expect("int16 array")
10213 .value(row_idx)
10214 .to_le_bytes(),
10215 ),
10216 DataType::Int32 => buffer.extend_from_slice(
10217 &array
10218 .as_any()
10219 .downcast_ref::<Int32Array>()
10220 .expect("int32 array")
10221 .value(row_idx)
10222 .to_le_bytes(),
10223 ),
10224 DataType::Int64 => buffer.extend_from_slice(
10225 &array
10226 .as_any()
10227 .downcast_ref::<Int64Array>()
10228 .expect("int64 array")
10229 .value(row_idx)
10230 .to_le_bytes(),
10231 ),
10232 DataType::UInt8 => buffer.extend_from_slice(
10233 &array
10234 .as_any()
10235 .downcast_ref::<UInt8Array>()
10236 .expect("uint8 array")
10237 .value(row_idx)
10238 .to_le_bytes(),
10239 ),
10240 DataType::UInt16 => buffer.extend_from_slice(
10241 &array
10242 .as_any()
10243 .downcast_ref::<UInt16Array>()
10244 .expect("uint16 array")
10245 .value(row_idx)
10246 .to_le_bytes(),
10247 ),
10248 DataType::UInt32 => buffer.extend_from_slice(
10249 &array
10250 .as_any()
10251 .downcast_ref::<UInt32Array>()
10252 .expect("uint32 array")
10253 .value(row_idx)
10254 .to_le_bytes(),
10255 ),
10256 DataType::UInt64 => buffer.extend_from_slice(
10257 &array
10258 .as_any()
10259 .downcast_ref::<UInt64Array>()
10260 .expect("uint64 array")
10261 .value(row_idx)
10262 .to_le_bytes(),
10263 ),
10264 DataType::Float32 => buffer.extend_from_slice(
10265 &array
10266 .as_any()
10267 .downcast_ref::<Float32Array>()
10268 .expect("float32 array")
10269 .value(row_idx)
10270 .to_le_bytes(),
10271 ),
10272 DataType::Float64 => buffer.extend_from_slice(
10273 &array
10274 .as_any()
10275 .downcast_ref::<Float64Array>()
10276 .expect("float64 array")
10277 .value(row_idx)
10278 .to_le_bytes(),
10279 ),
10280 DataType::Boolean => buffer.push(
10281 array
10282 .as_any()
10283 .downcast_ref::<BooleanArray>()
10284 .expect("bool array")
10285 .value(row_idx) as u8,
10286 ),
10287 DataType::Utf8 => {
10288 let value = array
10289 .as_any()
10290 .downcast_ref::<StringArray>()
10291 .expect("utf8 array")
10292 .value(row_idx);
10293 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
10294 buffer.extend_from_slice(value.as_bytes());
10295 }
10296 DataType::LargeUtf8 => {
10297 let value = array
10298 .as_any()
10299 .downcast_ref::<LargeStringArray>()
10300 .expect("large utf8 array")
10301 .value(row_idx);
10302 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
10303 buffer.extend_from_slice(value.as_bytes());
10304 }
10305 DataType::Binary => {
10306 let value = array
10307 .as_any()
10308 .downcast_ref::<BinaryArray>()
10309 .expect("binary array")
10310 .value(row_idx);
10311 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
10312 buffer.extend_from_slice(value);
10313 }
10314 other => {
10315 return Err(Error::InvalidArgumentError(format!(
10316 "hash join does not support join key type {:?}",
10317 other
10318 )));
10319 }
10320 }
10321
10322 Ok(())
10323}
10324
10325fn table_has_join_with_used(
10326 candidate: usize,
10327 used_tables: &FxHashSet<usize>,
10328 equalities: &[ColumnEquality],
10329) -> bool {
10330 equalities.iter().any(|equality| {
10331 (equality.left.table == candidate && used_tables.contains(&equality.right.table))
10332 || (equality.right.table == candidate && used_tables.contains(&equality.left.table))
10333 })
10334}
10335
10336fn gather_join_keys(
10337 left: &TableCrossProductData,
10338 right: &TableCrossProductData,
10339 used_tables: &FxHashSet<usize>,
10340 right_table_index: usize,
10341 equalities: &[ColumnEquality],
10342) -> ExecutorResult<Vec<(usize, usize)>> {
10343 let mut keys = Vec::new();
10344
10345 for equality in equalities {
10346 if equality.left.table == right_table_index && used_tables.contains(&equality.right.table) {
10347 let left_idx = resolve_column_index(left, &equality.right).ok_or_else(|| {
10348 Error::Internal("failed to resolve column offset for hash join".into())
10349 })?;
10350 let right_idx = resolve_column_index(right, &equality.left).ok_or_else(|| {
10351 Error::Internal("failed to resolve column offset for hash join".into())
10352 })?;
10353 keys.push((left_idx, right_idx));
10354 } else if equality.right.table == right_table_index
10355 && used_tables.contains(&equality.left.table)
10356 {
10357 let left_idx = resolve_column_index(left, &equality.left).ok_or_else(|| {
10358 Error::Internal("failed to resolve column offset for hash join".into())
10359 })?;
10360 let right_idx = resolve_column_index(right, &equality.right).ok_or_else(|| {
10361 Error::Internal("failed to resolve column offset for hash join".into())
10362 })?;
10363 keys.push((left_idx, right_idx));
10364 }
10365 }
10366
10367 Ok(keys)
10368}
10369
10370fn resolve_column_index(data: &TableCrossProductData, column: &ColumnRef) -> Option<usize> {
10371 let mut offset = 0;
10372 for (table_idx, count) in data.table_indices.iter().zip(data.column_counts.iter()) {
10373 if *table_idx == column.table {
10374 if column.column < *count {
10375 return Some(offset + column.column);
10376 } else {
10377 return None;
10378 }
10379 }
10380 offset += count;
10381 }
10382 None
10383}
10384
10385fn build_cross_product_column_lookup(
10386 schema: &Schema,
10387 tables: &[llkv_plan::TableRef],
10388 column_counts: &[usize],
10389 table_indices: &[usize],
10390) -> FxHashMap<String, usize> {
10391 debug_assert_eq!(tables.len(), column_counts.len());
10392 debug_assert_eq!(column_counts.len(), table_indices.len());
10393
10394 let mut column_occurrences: FxHashMap<String, usize> = FxHashMap::default();
10395 let mut table_column_counts: FxHashMap<String, usize> = FxHashMap::default();
10396 for field in schema.fields() {
10397 let column_name = extract_column_name(field.name());
10398 *column_occurrences.entry(column_name).or_insert(0) += 1;
10399 if let Some(pair) = table_column_suffix(field.name()) {
10400 *table_column_counts.entry(pair).or_insert(0) += 1;
10401 }
10402 }
10403
10404 let mut base_table_totals: FxHashMap<String, usize> = FxHashMap::default();
10405 let mut base_table_unaliased: FxHashMap<String, usize> = FxHashMap::default();
10406 for table_ref in tables {
10407 let key = base_table_key(table_ref);
10408 *base_table_totals.entry(key.clone()).or_insert(0) += 1;
10409 if table_ref.alias.is_none() {
10410 *base_table_unaliased.entry(key).or_insert(0) += 1;
10411 }
10412 }
10413
10414 let mut lookup = FxHashMap::default();
10415
10416 if table_indices.is_empty() || column_counts.is_empty() {
10417 for (idx, field) in schema.fields().iter().enumerate() {
10418 let field_name_lower = field.name().to_ascii_lowercase();
10419 lookup.entry(field_name_lower).or_insert(idx);
10420
10421 let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
10422 lookup.entry(trimmed_lower).or_insert(idx);
10423
10424 if let Some(pair) = table_column_suffix(field.name())
10425 && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
10426 {
10427 lookup.entry(pair).or_insert(idx);
10428 }
10429
10430 let column_name = extract_column_name(field.name());
10431 if column_occurrences.get(&column_name).copied().unwrap_or(0) == 1 {
10432 lookup.entry(column_name).or_insert(idx);
10433 }
10434 }
10435 return lookup;
10436 }
10437
10438 let mut offset = 0usize;
10439 for (&table_idx, &count) in table_indices.iter().zip(column_counts.iter()) {
10440 if table_idx >= tables.len() {
10441 continue;
10442 }
10443 let table_ref = &tables[table_idx];
10444 let alias_lower = table_ref
10445 .alias
10446 .as_ref()
10447 .map(|alias| alias.to_ascii_lowercase());
10448 let table_lower = table_ref.table.to_ascii_lowercase();
10449 let schema_lower = table_ref.schema.to_ascii_lowercase();
10450 let base_key = base_table_key(table_ref);
10451 let total_refs = base_table_totals.get(&base_key).copied().unwrap_or(0);
10452 let unaliased_refs = base_table_unaliased.get(&base_key).copied().unwrap_or(0);
10453
10454 let allow_base_mapping = if table_ref.alias.is_none() {
10455 unaliased_refs == 1
10456 } else {
10457 unaliased_refs == 0 && total_refs == 1
10458 };
10459
10460 let mut table_keys: Vec<String> = Vec::new();
10461
10462 if let Some(alias) = &alias_lower {
10463 table_keys.push(alias.clone());
10464 if !schema_lower.is_empty() {
10465 table_keys.push(format!("{}.{}", schema_lower, alias));
10466 }
10467 }
10468
10469 if allow_base_mapping {
10470 table_keys.push(table_lower.clone());
10471 if !schema_lower.is_empty() {
10472 table_keys.push(format!("{}.{}", schema_lower, table_lower));
10473 }
10474 }
10475
10476 for local_idx in 0..count {
10477 let field_index = offset + local_idx;
10478 let field = schema.field(field_index);
10479 let field_name_lower = field.name().to_ascii_lowercase();
10480 lookup.entry(field_name_lower).or_insert(field_index);
10481
10482 let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
10483 lookup.entry(trimmed_lower).or_insert(field_index);
10484
10485 let column_name = extract_column_name(field.name());
10486 for table_key in &table_keys {
10487 lookup
10488 .entry(format!("{}.{}", table_key, column_name))
10489 .or_insert(field_index);
10490 }
10491
10492 lookup.entry(column_name.clone()).or_insert(field_index);
10496
10497 if table_keys.is_empty()
10498 && let Some(pair) = table_column_suffix(field.name())
10499 && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
10500 {
10501 lookup.entry(pair).or_insert(field_index);
10502 }
10503 }
10504
10505 offset = offset.saturating_add(count);
10506 }
10507
10508 lookup
10509}
10510
10511fn base_table_key(table_ref: &llkv_plan::TableRef) -> String {
10512 let schema_lower = table_ref.schema.to_ascii_lowercase();
10513 let table_lower = table_ref.table.to_ascii_lowercase();
10514 if schema_lower.is_empty() {
10515 table_lower
10516 } else {
10517 format!("{}.{}", schema_lower, table_lower)
10518 }
10519}
10520
/// Returns the lowercase final segment of a dot-separated field name.
///
/// Leading dots are ignored. A name without any remaining dot is returned whole, and a name
/// ending in a dot yields an empty string.
fn extract_column_name(name: &str) -> String {
    let unprefixed = name.trim_start_matches('.');
    let last_segment = match unprefixed.rfind('.') {
        Some(dot) => &unprefixed[dot + 1..],
        None => unprefixed,
    };
    last_segment.to_ascii_lowercase()
}
10528
/// Returns the lowercase `table.column` suffix of a qualified field name.
///
/// Leading dots are ignored. The last two dot-separated segments are used; `None` is returned
/// when fewer than two segments remain.
fn table_column_suffix(name: &str) -> Option<String> {
    let unprefixed = name.trim_start_matches('.');
    // Walk the segments from the right: column first, then its table qualifier.
    let mut segments = unprefixed.rsplit('.');
    let column = segments.next()?.to_ascii_lowercase();
    let table = segments.next()?.to_ascii_lowercase();
    Some(format!("{}.{}", table, column))
}
10539
10540fn cross_join_table_batches(
10565 left: TableCrossProductData,
10566 right: TableCrossProductData,
10567) -> ExecutorResult<TableCrossProductData> {
10568 let TableCrossProductData {
10569 schema: left_schema,
10570 batches: left_batches,
10571 column_counts: mut left_counts,
10572 table_indices: mut left_tables,
10573 } = left;
10574 let TableCrossProductData {
10575 schema: right_schema,
10576 batches: right_batches,
10577 column_counts: right_counts,
10578 table_indices: right_tables,
10579 } = right;
10580
10581 let combined_fields: Vec<Field> = left_schema
10582 .fields()
10583 .iter()
10584 .chain(right_schema.fields().iter())
10585 .map(|field| field.as_ref().clone())
10586 .collect();
10587
10588 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
10589 column_counts.append(&mut left_counts);
10590 column_counts.extend(right_counts);
10591
10592 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
10593 table_indices.append(&mut left_tables);
10594 table_indices.extend(right_tables);
10595
10596 let combined_schema = Arc::new(Schema::new(combined_fields));
10597
10598 let left_has_rows = left_batches.iter().any(|batch| batch.num_rows() > 0);
10599 let right_has_rows = right_batches.iter().any(|batch| batch.num_rows() > 0);
10600
10601 if !left_has_rows || !right_has_rows {
10602 return Ok(TableCrossProductData {
10603 schema: combined_schema,
10604 batches: Vec::new(),
10605 column_counts,
10606 table_indices,
10607 });
10608 }
10609
10610 let output_batches: Vec<RecordBatch> = llkv_column_map::parallel::with_thread_pool(|| {
10613 left_batches
10614 .par_iter()
10615 .filter(|left_batch| left_batch.num_rows() > 0)
10616 .flat_map(|left_batch| {
10617 right_batches
10618 .par_iter()
10619 .filter(|right_batch| right_batch.num_rows() > 0)
10620 .filter_map(|right_batch| {
10621 cross_join_pair(left_batch, right_batch, &combined_schema).ok()
10622 })
10623 .collect::<Vec<_>>()
10624 })
10625 .collect()
10626 });
10627
10628 Ok(TableCrossProductData {
10629 schema: combined_schema,
10630 batches: output_batches,
10631 column_counts,
10632 table_indices,
10633 })
10634}
10635
10636fn cross_join_all(staged: Vec<TableCrossProductData>) -> ExecutorResult<TableCrossProductData> {
10637 let mut iter = staged.into_iter();
10638 let mut current = iter
10639 .next()
10640 .ok_or_else(|| Error::Internal("cross product preparation yielded no tables".into()))?;
10641 for next in iter {
10642 current = cross_join_table_batches(current, next)?;
10643 }
10644 Ok(current)
10645}
10646
/// Per-table lookup data used while resolving column references in join predicates.
struct TableInfo<'a> {
    /// Position of the table in the query's table list.
    index: usize,
    /// The table reference (schema, table name, optional alias) this entry describes.
    table_ref: &'a llkv_plan::TableRef,
    /// Column name → column index within the table. Lookups in this file use lowercased
    /// names, so the keys are expected to be lowercase.
    column_map: FxHashMap<String, usize>,
}
10652
/// Identifies a column by table position and column position within that table.
#[derive(Clone, Copy)]
struct ColumnRef {
    /// Index of the table in the query's table list.
    table: usize,
    /// Index of the column within that table.
    column: usize,
}
10658
/// An equality predicate between two resolved columns (`left = right`).
#[derive(Clone, Copy)]
struct ColumnEquality {
    /// Column on the left-hand side of the comparison.
    left: ColumnRef,
    /// Column on the right-hand side of the comparison.
    right: ColumnRef,
}
10664
/// An equality predicate between a resolved column and a constant value.
#[derive(Clone)]
struct ColumnLiteral {
    /// The constrained column.
    column: ColumnRef,
    /// The constant the column is compared against.
    value: PlanValue,
}
10670
/// A membership predicate restricting a resolved column to a set of constant values.
#[derive(Clone)]
struct ColumnInList {
    /// The constrained column.
    column: ColumnRef,
    /// The constants the column may equal.
    values: Vec<PlanValue>,
}
10676
/// A single-column constraint against constants, extracted from a filter expression.
#[derive(Clone)]
enum ColumnConstraint {
    /// The column must equal one constant.
    Equality(ColumnLiteral),
    /// The column must equal one of several constants.
    InList(ColumnInList),
}
10682
/// Result of decomposing a filter expression into join-relevant constraints.
struct JoinConstraintPlan {
    /// Column-to-column equalities found among the conjuncts.
    equalities: Vec<ColumnEquality>,
    /// Column-to-constant constraints (equality or IN-list) found among the conjuncts.
    literals: Vec<ColumnConstraint>,
    /// Set when a conjunct is the literal `false`, so the whole conjunction cannot hold.
    unsatisfiable: bool,
    /// Number of top-level conjuncts in the analyzed expression.
    total_conjuncts: usize,
    /// Number of conjuncts that were captured in this plan (or recognized as literals).
    handled_conjuncts: usize,
}
10693
10694fn extract_literal_pushdown_filters<P>(
10713 expr: &LlkvExpr<'static, String>,
10714 tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
10715) -> Vec<Vec<ColumnConstraint>>
10716where
10717 P: Pager<Blob = EntryHandle> + Send + Sync,
10718{
10719 let mut table_infos = Vec::with_capacity(tables_with_handles.len());
10720 for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
10721 let mut column_map = FxHashMap::default();
10722 for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
10723 let column_name = column.name.to_ascii_lowercase();
10724 column_map.entry(column_name).or_insert(column_idx);
10725 }
10726 table_infos.push(TableInfo {
10727 index,
10728 table_ref,
10729 column_map,
10730 });
10731 }
10732
10733 let mut constraints: Vec<Vec<ColumnConstraint>> = vec![Vec::new(); tables_with_handles.len()];
10734
10735 let mut conjuncts = Vec::new();
10737 collect_conjuncts_lenient(expr, &mut conjuncts);
10738
10739 for conjunct in conjuncts {
10740 if let LlkvExpr::Compare {
10742 left,
10743 op: CompareOp::Eq,
10744 right,
10745 } = conjunct
10746 {
10747 match (
10748 resolve_column_reference(left, &table_infos),
10749 resolve_column_reference(right, &table_infos),
10750 ) {
10751 (Some(column), None) => {
10752 if let Some(literal) = extract_literal(right)
10753 && let Some(value) = literal_to_plan_value_for_join(literal)
10754 && column.table < constraints.len()
10755 {
10756 constraints[column.table]
10757 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
10758 }
10759 }
10760 (None, Some(column)) => {
10761 if let Some(literal) = extract_literal(left)
10762 && let Some(value) = literal_to_plan_value_for_join(literal)
10763 && column.table < constraints.len()
10764 {
10765 constraints[column.table]
10766 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
10767 }
10768 }
10769 _ => {}
10770 }
10771 }
10772 else if let LlkvExpr::Pred(filter) = conjunct {
10775 if let Operator::Equals(ref literal_val) = filter.op {
10776 let field_name = filter.field_id.trim().to_ascii_lowercase();
10778
10779 for info in &table_infos {
10781 if let Some(&col_idx) = info.column_map.get(&field_name) {
10782 if let Some(value) = plan_value_from_operator_literal(literal_val) {
10783 let column_ref = ColumnRef {
10784 table: info.index,
10785 column: col_idx,
10786 };
10787 if info.index < constraints.len() {
10788 constraints[info.index].push(ColumnConstraint::Equality(
10789 ColumnLiteral {
10790 column: column_ref,
10791 value,
10792 },
10793 ));
10794 }
10795 }
10796 break; }
10798 }
10799 }
10800 }
10801 else if let LlkvExpr::InList {
10803 expr: col_expr,
10804 list,
10805 negated: false,
10806 } = conjunct
10807 {
10808 if let Some(column) = resolve_column_reference(col_expr, &table_infos) {
10809 let mut values = Vec::new();
10810 for item in list {
10811 if let Some(literal) = extract_literal(item)
10812 && let Some(value) = literal_to_plan_value_for_join(literal)
10813 {
10814 values.push(value);
10815 }
10816 }
10817 if !values.is_empty() && column.table < constraints.len() {
10818 constraints[column.table]
10819 .push(ColumnConstraint::InList(ColumnInList { column, values }));
10820 }
10821 }
10822 }
10823 else if let LlkvExpr::Or(or_children) = conjunct
10825 && let Some((column, values)) = try_extract_or_as_in_list(or_children, &table_infos)
10826 && !values.is_empty()
10827 && column.table < constraints.len()
10828 {
10829 constraints[column.table]
10830 .push(ColumnConstraint::InList(ColumnInList { column, values }));
10831 }
10832 }
10833
10834 constraints
10835}
10836
10837fn collect_conjuncts_lenient<'a>(
10842 expr: &'a LlkvExpr<'static, String>,
10843 out: &mut Vec<&'a LlkvExpr<'static, String>>,
10844) {
10845 match expr {
10846 LlkvExpr::And(children) => {
10847 for child in children {
10848 collect_conjuncts_lenient(child, out);
10849 }
10850 }
10851 other => {
10852 out.push(other);
10854 }
10855 }
10856}
10857
10858fn try_extract_or_as_in_list(
10862 or_children: &[LlkvExpr<'static, String>],
10863 table_infos: &[TableInfo<'_>],
10864) -> Option<(ColumnRef, Vec<PlanValue>)> {
10865 if or_children.is_empty() {
10866 return None;
10867 }
10868
10869 let mut common_column: Option<ColumnRef> = None;
10870 let mut values = Vec::new();
10871
10872 for child in or_children {
10873 if let LlkvExpr::Compare {
10875 left,
10876 op: CompareOp::Eq,
10877 right,
10878 } = child
10879 {
10880 if let (Some(column), None) = (
10882 resolve_column_reference(left, table_infos),
10883 resolve_column_reference(right, table_infos),
10884 ) && let Some(literal) = extract_literal(right)
10885 && let Some(value) = literal_to_plan_value_for_join(literal)
10886 {
10887 match common_column {
10889 None => common_column = Some(column),
10890 Some(ref prev)
10891 if prev.table == column.table && prev.column == column.column =>
10892 {
10893 }
10895 _ => {
10896 return None;
10898 }
10899 }
10900 values.push(value);
10901 continue;
10902 }
10903
10904 if let (None, Some(column)) = (
10906 resolve_column_reference(left, table_infos),
10907 resolve_column_reference(right, table_infos),
10908 ) && let Some(literal) = extract_literal(left)
10909 && let Some(value) = literal_to_plan_value_for_join(literal)
10910 {
10911 match common_column {
10912 None => common_column = Some(column),
10913 Some(ref prev)
10914 if prev.table == column.table && prev.column == column.column => {}
10915 _ => return None,
10916 }
10917 values.push(value);
10918 continue;
10919 }
10920 }
10921 else if let LlkvExpr::Pred(filter) = child
10923 && let Operator::Equals(ref literal) = filter.op
10924 && let Some(column) =
10925 resolve_column_reference(&ScalarExpr::Column(filter.field_id.clone()), table_infos)
10926 && let Some(value) = literal_to_plan_value_for_join(literal)
10927 {
10928 match common_column {
10929 None => common_column = Some(column),
10930 Some(ref prev) if prev.table == column.table && prev.column == column.column => {}
10931 _ => return None,
10932 }
10933 values.push(value);
10934 continue;
10935 }
10936
10937 return None;
10939 }
10940
10941 common_column.map(|col| (col, values))
10942}
10943
10944fn extract_join_constraints(
10971 expr: &LlkvExpr<'static, String>,
10972 table_infos: &[TableInfo<'_>],
10973) -> Option<JoinConstraintPlan> {
10974 let mut conjuncts = Vec::new();
10975 collect_conjuncts_lenient(expr, &mut conjuncts);
10977
10978 let total_conjuncts = conjuncts.len();
10979 let mut equalities = Vec::new();
10980 let mut literals = Vec::new();
10981 let mut unsatisfiable = false;
10982 let mut handled_conjuncts = 0;
10983
10984 for conjunct in conjuncts {
10985 match conjunct {
10986 LlkvExpr::Literal(true) => {
10987 handled_conjuncts += 1;
10988 }
10989 LlkvExpr::Literal(false) => {
10990 unsatisfiable = true;
10991 handled_conjuncts += 1;
10992 break;
10993 }
10994 LlkvExpr::Compare {
10995 left,
10996 op: CompareOp::Eq,
10997 right,
10998 } => {
10999 match (
11000 resolve_column_reference(left, table_infos),
11001 resolve_column_reference(right, table_infos),
11002 ) {
11003 (Some(left_col), Some(right_col)) => {
11004 equalities.push(ColumnEquality {
11005 left: left_col,
11006 right: right_col,
11007 });
11008 handled_conjuncts += 1;
11009 continue;
11010 }
11011 (Some(column), None) => {
11012 if let Some(literal) = extract_literal(right)
11013 && let Some(value) = literal_to_plan_value_for_join(literal)
11014 {
11015 literals
11016 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
11017 handled_conjuncts += 1;
11018 continue;
11019 }
11020 }
11021 (None, Some(column)) => {
11022 if let Some(literal) = extract_literal(left)
11023 && let Some(value) = literal_to_plan_value_for_join(literal)
11024 {
11025 literals
11026 .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
11027 handled_conjuncts += 1;
11028 continue;
11029 }
11030 }
11031 _ => {}
11032 }
11033 }
11035 LlkvExpr::InList {
11037 expr: col_expr,
11038 list,
11039 negated: false,
11040 } => {
11041 if let Some(column) = resolve_column_reference(col_expr, table_infos) {
11042 let mut in_list_values = Vec::new();
11044 for item in list {
11045 if let Some(literal) = extract_literal(item)
11046 && let Some(value) = literal_to_plan_value_for_join(literal)
11047 {
11048 in_list_values.push(value);
11049 }
11050 }
11051 if !in_list_values.is_empty() {
11052 literals.push(ColumnConstraint::InList(ColumnInList {
11053 column,
11054 values: in_list_values,
11055 }));
11056 handled_conjuncts += 1;
11057 continue;
11058 }
11059 }
11060 }
11062 LlkvExpr::Or(or_children) => {
11064 if let Some((column, values)) = try_extract_or_as_in_list(or_children, table_infos)
11065 {
11066 literals.push(ColumnConstraint::InList(ColumnInList { column, values }));
11068 handled_conjuncts += 1;
11069 continue;
11070 }
11071 }
11073 LlkvExpr::Pred(filter) => {
11075 if let Operator::Equals(ref literal) = filter.op
11077 && let Some(column) = resolve_column_reference(
11078 &ScalarExpr::Column(filter.field_id.clone()),
11079 table_infos,
11080 )
11081 && let Some(value) = literal_to_plan_value_for_join(literal)
11082 {
11083 literals.push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
11084 handled_conjuncts += 1;
11085 continue;
11086 }
11087 }
11089 _ => {
11090 }
11092 }
11093 }
11094
11095 Some(JoinConstraintPlan {
11096 equalities,
11097 literals,
11098 unsatisfiable,
11099 total_conjuncts,
11100 handled_conjuncts,
11101 })
11102}
11103
11104fn resolve_column_reference(
11105 expr: &ScalarExpr<String>,
11106 table_infos: &[TableInfo<'_>],
11107) -> Option<ColumnRef> {
11108 let name = match expr {
11109 ScalarExpr::Column(name) => name.trim(),
11110 _ => return None,
11111 };
11112
11113 let mut parts: Vec<&str> = name
11114 .trim_start_matches('.')
11115 .split('.')
11116 .filter(|segment| !segment.is_empty())
11117 .collect();
11118
11119 if parts.is_empty() {
11120 return None;
11121 }
11122
11123 let column_part = parts.pop()?.to_ascii_lowercase();
11124 if parts.is_empty() {
11125 for info in table_infos {
11129 if let Some(&col_idx) = info.column_map.get(&column_part) {
11130 return Some(ColumnRef {
11131 table: info.index,
11132 column: col_idx,
11133 });
11134 }
11135 }
11136 return None;
11137 }
11138
11139 let table_ident = parts.join(".").to_ascii_lowercase();
11140 for info in table_infos {
11141 if matches_table_ident(info.table_ref, &table_ident) {
11142 if let Some(&col_idx) = info.column_map.get(&column_part) {
11143 return Some(ColumnRef {
11144 table: info.index,
11145 column: col_idx,
11146 });
11147 } else {
11148 return None;
11149 }
11150 }
11151 }
11152 None
11153}
11154
11155fn matches_table_ident(table_ref: &llkv_plan::TableRef, ident: &str) -> bool {
11156 if ident.is_empty() {
11157 return false;
11158 }
11159 if let Some(alias) = &table_ref.alias
11160 && alias.to_ascii_lowercase() == ident
11161 {
11162 return true;
11163 }
11164 if table_ref.table.to_ascii_lowercase() == ident {
11165 return true;
11166 }
11167 if !table_ref.schema.is_empty() {
11168 let full = format!(
11169 "{}.{}",
11170 table_ref.schema.to_ascii_lowercase(),
11171 table_ref.table.to_ascii_lowercase()
11172 );
11173 if full == ident {
11174 return true;
11175 }
11176 }
11177 false
11178}
11179
11180fn extract_literal(expr: &ScalarExpr<String>) -> Option<&Literal> {
11181 match expr {
11182 ScalarExpr::Literal(lit) => Some(lit),
11183 _ => None,
11184 }
11185}
11186
11187fn plan_value_from_operator_literal(op_value: &llkv_expr::literal::Literal) -> Option<PlanValue> {
11188 match op_value {
11189 llkv_expr::literal::Literal::Integer(v) => i64::try_from(*v).ok().map(PlanValue::Integer),
11190 llkv_expr::literal::Literal::Float(v) => Some(PlanValue::Float(*v)),
11191 llkv_expr::literal::Literal::Boolean(v) => Some(PlanValue::Integer(if *v { 1 } else { 0 })),
11192 llkv_expr::literal::Literal::String(v) => Some(PlanValue::String(v.clone())),
11193 _ => None,
11194 }
11195}
11196
11197fn literal_to_plan_value_for_join(literal: &Literal) -> Option<PlanValue> {
11198 match literal {
11199 Literal::Integer(v) => i64::try_from(*v).ok().map(PlanValue::Integer),
11200 Literal::Float(v) => Some(PlanValue::Float(*v)),
11201 Literal::Boolean(v) => Some(PlanValue::Integer(if *v { 1 } else { 0 })),
11202 Literal::String(v) => Some(PlanValue::String(v.clone())),
11203 _ => None,
11204 }
11205}
11206
11207#[derive(Default)]
11208struct DistinctState {
11209 seen: FxHashSet<CanonicalRow>,
11210}
11211
11212impl DistinctState {
11213 fn insert(&mut self, row: CanonicalRow) -> bool {
11214 self.seen.insert(row)
11215 }
11216}
11217
11218fn distinct_filter_batch(
11219 batch: RecordBatch,
11220 state: &mut DistinctState,
11221) -> ExecutorResult<Option<RecordBatch>> {
11222 if batch.num_rows() == 0 {
11223 return Ok(None);
11224 }
11225
11226 let mut keep_flags = Vec::with_capacity(batch.num_rows());
11227 let mut keep_count = 0usize;
11228
11229 for row_idx in 0..batch.num_rows() {
11230 let row = CanonicalRow::from_batch(&batch, row_idx)?;
11231 if state.insert(row) {
11232 keep_flags.push(true);
11233 keep_count += 1;
11234 } else {
11235 keep_flags.push(false);
11236 }
11237 }
11238
11239 if keep_count == 0 {
11240 return Ok(None);
11241 }
11242
11243 if keep_count == batch.num_rows() {
11244 return Ok(Some(batch));
11245 }
11246
11247 let mut builder = BooleanBuilder::with_capacity(batch.num_rows());
11248 for flag in keep_flags {
11249 builder.append_value(flag);
11250 }
11251 let mask = Arc::new(builder.finish());
11252
11253 let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
11254 Error::InvalidArgumentError(format!("failed to apply DISTINCT filter: {err}"))
11255 })?;
11256
11257 Ok(Some(filtered))
11258}
11259
11260fn sort_record_batch_with_order(
11261 schema: &Arc<Schema>,
11262 batch: &RecordBatch,
11263 order_by: &[OrderByPlan],
11264) -> ExecutorResult<RecordBatch> {
11265 if order_by.is_empty() {
11266 return Ok(batch.clone());
11267 }
11268
11269 let mut sort_columns: Vec<SortColumn> = Vec::with_capacity(order_by.len());
11270
11271 for order in order_by {
11272 let column_index = match &order.target {
11273 OrderTarget::Column(name) => schema.index_of(name).map_err(|_| {
11274 Error::InvalidArgumentError(format!(
11275 "ORDER BY references unknown column '{}'",
11276 name
11277 ))
11278 })?,
11279 OrderTarget::Index(idx) => {
11280 if *idx >= batch.num_columns() {
11281 return Err(Error::InvalidArgumentError(format!(
11282 "ORDER BY position {} is out of bounds for {} columns",
11283 idx + 1,
11284 batch.num_columns()
11285 )));
11286 }
11287 *idx
11288 }
11289 OrderTarget::All => {
11290 return Err(Error::InvalidArgumentError(
11291 "ORDER BY ALL should be expanded before sorting".into(),
11292 ));
11293 }
11294 };
11295
11296 let source_array = batch.column(column_index);
11297
11298 let values: ArrayRef = match order.sort_type {
11299 OrderSortType::Native => Arc::clone(source_array),
11300 OrderSortType::CastTextToInteger => {
11301 let strings = source_array
11302 .as_any()
11303 .downcast_ref::<StringArray>()
11304 .ok_or_else(|| {
11305 Error::InvalidArgumentError(
11306 "ORDER BY CAST expects the underlying column to be TEXT".into(),
11307 )
11308 })?;
11309 let mut builder = Int64Builder::with_capacity(strings.len());
11310 for i in 0..strings.len() {
11311 if strings.is_null(i) {
11312 builder.append_null();
11313 } else {
11314 match strings.value(i).parse::<i64>() {
11315 Ok(value) => builder.append_value(value),
11316 Err(_) => builder.append_null(),
11317 }
11318 }
11319 }
11320 Arc::new(builder.finish()) as ArrayRef
11321 }
11322 };
11323
11324 let sort_options = SortOptions {
11325 descending: !order.ascending,
11326 nulls_first: order.nulls_first,
11327 };
11328
11329 sort_columns.push(SortColumn {
11330 values,
11331 options: Some(sort_options),
11332 });
11333 }
11334
11335 let indices = lexsort_to_indices(&sort_columns, None).map_err(|err| {
11336 Error::InvalidArgumentError(format!("failed to compute ORDER BY indices: {err}"))
11337 })?;
11338
11339 let perm = indices
11340 .as_any()
11341 .downcast_ref::<UInt32Array>()
11342 .ok_or_else(|| Error::Internal("ORDER BY sorting produced unexpected index type".into()))?;
11343
11344 let mut reordered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
11345 for col_idx in 0..batch.num_columns() {
11346 let reordered = take(batch.column(col_idx), perm, None).map_err(|err| {
11347 Error::InvalidArgumentError(format!(
11348 "failed to apply ORDER BY permutation to column {col_idx}: {err}"
11349 ))
11350 })?;
11351 reordered_columns.push(reordered);
11352 }
11353
11354 RecordBatch::try_new(Arc::clone(schema), reordered_columns)
11355 .map_err(|err| Error::Internal(format!("failed to build reordered ORDER BY batch: {err}")))
11356}
11357
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use llkv_expr::expr::BinaryOp;
    use llkv_expr::literal::Literal;
    use llkv_storage::pager::MemPager;
    use std::sync::Arc;

    /// Wraps `values` in a non-null Int64 Arrow column.
    fn int64_column(values: Vec<i64>) -> ArrayRef {
        Arc::new(Int64Array::from(values)) as ArrayRef
    }

    /// Builds a schema with one non-nullable Int64 field named `field_name`, plus a
    /// single batch holding `values` in that field.
    fn single_int64_batch(field_name: &str, values: Vec<i64>) -> (Arc<Schema>, RecordBatch) {
        let schema = Arc::new(Schema::new(vec![Field::new(
            field_name,
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![int64_column(values)])
            .expect("valid batch");
        (schema, batch)
    }

    /// Views column `index` of `batch` as an `Int64Array`, panicking with `message`
    /// when the column has a different type.
    fn int64_at<'a>(batch: &'a RecordBatch, index: usize, message: &str) -> &'a Int64Array {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect(message)
    }

    /// A literal evaluates to one value per batch row, and `tab2.a + 5` is computed
    /// from the column stored under the field `main.tab2.a`.
    #[test]
    fn cross_product_context_evaluates_expressions() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("main.tab2.a", DataType::Int64, false),
            Field::new("main.tab2.b", DataType::Int64, false),
        ]));
        let columns = vec![int64_column(vec![1, 2, 3]), int64_column(vec![10, 20, 30])];
        let batch = RecordBatch::try_new(Arc::clone(&schema), columns).expect("valid batch");

        let lookup = build_cross_product_column_lookup(schema.as_ref(), &[], &[], &[]);
        let mut ctx = CrossProductExpressionContext::new(schema.as_ref(), lookup)
            .expect("context builds from schema");

        let literal_expr: ScalarExpr<String> = ScalarExpr::literal(67);
        let literal = ctx
            .evaluate(&literal_expr, &batch)
            .expect("literal evaluation succeeds");
        let literal_array = literal
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 literal result");
        assert_eq!(literal_array.len(), 3);
        assert!(literal_array.iter().all(|entry| entry == Some(67)));

        let add_expr = ScalarExpr::binary(
            ScalarExpr::column("tab2.a".to_string()),
            BinaryOp::Add,
            ScalarExpr::literal(5),
        );
        let added = ctx
            .evaluate(&add_expr, &batch)
            .expect("column addition succeeds");
        let added_array = added
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 addition result");
        assert_eq!(added_array.values(), &[6, 7, 8]);
    }

    /// Casting an in-range integer literal to Int32 keeps its value.
    #[test]
    fn aggregate_expr_allows_numeric_casts() {
        let cast_expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(31)),
            data_type: DataType::Int32,
        };
        let no_aggregates = FxHashMap::default();

        let value =
            QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&cast_expr, &no_aggregates)
                .expect("cast should succeed for in-range integral values");

        assert_eq!(value, Some(31));
    }

    /// Casting -1 to UInt8 is reported as an invalid-argument error.
    #[test]
    fn aggregate_expr_cast_rejects_out_of_range_values() {
        let cast_expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(-1)),
            data_type: DataType::UInt8,
        };
        let no_aggregates = FxHashMap::default();

        let outcome =
            QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&cast_expr, &no_aggregates);

        assert!(matches!(outcome, Err(Error::InvalidArgumentError(_))));
    }

    /// `0 - CAST(NULL AS Int64)` evaluates to NULL.
    #[test]
    fn aggregate_expr_null_literal_remains_null() {
        let subtraction = ScalarExpr::binary(
            ScalarExpr::literal(0),
            BinaryOp::Subtract,
            ScalarExpr::cast(ScalarExpr::literal(Literal::Null), DataType::Int64),
        );
        let no_aggregates = FxHashMap::default();

        let value =
            QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&subtraction, &no_aggregates)
                .expect("expression should evaluate");

        assert_eq!(value, None);
    }

    /// `10 / 0` evaluates to NULL instead of failing.
    #[test]
    fn aggregate_expr_divide_by_zero_returns_null() {
        let division = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Divide,
            ScalarExpr::literal(0),
        );
        let no_aggregates = FxHashMap::default();

        let value =
            QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&division, &no_aggregates)
                .expect("division should evaluate");

        assert_eq!(value, None);
    }

    /// `10 % 0` evaluates to NULL instead of failing.
    #[test]
    fn aggregate_expr_modulo_by_zero_returns_null() {
        let modulo = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Modulo,
            ScalarExpr::literal(0),
        );
        let no_aggregates = FxHashMap::default();

        let value =
            QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&modulo, &no_aggregates)
                .expect("modulo should evaluate");

        assert_eq!(value, None);
    }

    /// Constant folding of `NULL AND 1` produces a NULL literal.
    #[test]
    fn constant_and_with_null_yields_null() {
        let conjunction = ScalarExpr::binary(
            ScalarExpr::literal(Literal::Null),
            BinaryOp::And,
            ScalarExpr::literal(1),
        );

        let folded = evaluate_constant_scalar_with_aggregates(&conjunction)
            .expect("expression should fold as constant");

        assert!(matches!(folded, Literal::Null));
    }

    /// Chaining two cross joins over three single-column tables (2 x 3 x 1 rows)
    /// yields six rows with the leftmost table varying slowest.
    #[test]
    fn cross_product_handles_more_than_two_tables() {
        let (schema_a, batch_a) = single_int64_batch("main.t1.a", vec![1, 2]);
        let (schema_b, batch_b) = single_int64_batch("main.t2.b", vec![10, 20, 30]);
        let (schema_c, batch_c) = single_int64_batch("main.t3.c", vec![100]);

        let data_a = TableCrossProductData {
            schema: schema_a,
            batches: vec![batch_a],
            column_counts: vec![1],
            table_indices: vec![0],
        };
        let data_b = TableCrossProductData {
            schema: schema_b,
            batches: vec![batch_b],
            column_counts: vec![1],
            table_indices: vec![1],
        };
        let data_c = TableCrossProductData {
            schema: schema_c,
            batches: vec![batch_c],
            column_counts: vec![1],
            table_indices: vec![2],
        };

        let ab = cross_join_table_batches(data_a, data_b).expect("two-table product");
        assert_eq!(ab.schema.fields().len(), 2);
        assert_eq!(ab.batches.len(), 1);
        assert_eq!(ab.batches[0].num_rows(), 6);

        let abc = cross_join_table_batches(ab, data_c).expect("three-table product");
        assert_eq!(abc.schema.fields().len(), 3);
        assert_eq!(abc.batches.len(), 1);

        let final_batch = &abc.batches[0];
        assert_eq!(final_batch.num_rows(), 6);

        let col_a = int64_at(final_batch, 0, "left column values");
        assert_eq!(col_a.values(), &[1, 1, 1, 2, 2, 2]);

        let col_b = int64_at(final_batch, 1, "middle column values");
        assert_eq!(col_b.values(), &[10, 20, 30, 10, 20, 30]);

        let col_c = int64_at(final_batch, 2, "right column values");
        assert_eq!(col_c.values(), &[100, 100, 100, 100, 100, 100]);
    }
}