1use arrow::array::{
17 Array, ArrayRef, BooleanArray, BooleanBuilder, Date32Array, Decimal128Array, Decimal128Builder,
18 Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, Int64Builder,
19 IntervalMonthDayNanoArray, RecordBatch, StringArray, UInt8Array, UInt16Array, UInt32Array,
20 UInt64Array, new_null_array,
21};
22use arrow::compute::{
23 SortColumn, SortOptions, cast, concat_batches, filter_record_batch, lexsort_to_indices, not,
24 or_kleene, take,
25};
26use arrow::datatypes::{DataType, Field, Float64Type, Int32Type, Int64Type, IntervalUnit, Schema};
27use arrow::row::{RowConverter, SortField};
28use arrow_buffer::IntervalMonthDayNano;
29use llkv_aggregate::{AggregateAccumulator, AggregateKind, AggregateSpec, AggregateState};
30use llkv_column_map::gather::gather_indices_from_batches;
31use llkv_column_map::store::Projection as StoreProjection;
32use llkv_expr::SubqueryId;
33use llkv_expr::expr::{
34 AggregateCall, BinaryOp, CompareOp, Expr as LlkvExpr, Filter, Operator, ScalarExpr,
35};
36use llkv_expr::literal::{Literal, LiteralExt};
37use llkv_expr::typed_predicate::{
38 build_bool_predicate, build_fixed_width_predicate, build_var_width_predicate,
39};
40use llkv_join::cross_join_pair;
41use llkv_plan::{
42 AggregateExpr, AggregateFunction, CanonicalRow, CompoundOperator, CompoundQuantifier,
43 CompoundSelectComponent, CompoundSelectPlan, OrderByPlan, OrderSortType, OrderTarget,
44 PlanValue, SelectPlan, SelectProjection, plan_value_from_literal,
45};
46use llkv_result::Error;
47use llkv_storage::pager::Pager;
48use llkv_table::table::{
49 RowIdFilter, ScanOrderDirection, ScanOrderSpec, ScanOrderTransform, ScanProjection,
50 ScanStreamOptions,
51};
52use llkv_table::types::FieldId;
53use llkv_table::{NumericArrayMap, NumericKernels, ROW_ID_FIELD_ID};
54use llkv_types::LogicalFieldId;
55use llkv_types::decimal::DecimalValue;
56use rayon::prelude::*;
57use rustc_hash::{FxHashMap, FxHashSet};
58use simd_r_drive_entry_handle::EntryHandle;
59use std::convert::TryFrom;
60use std::fmt;
61use std::sync::Arc;
62use std::sync::atomic::Ordering;
63
64#[cfg(test)]
65use std::cell::RefCell;
66
67pub mod insert;
72pub mod translation;
73pub mod types;
74
75pub type ExecutorResult<T> = Result<T, Error>;
81
82use crate::translation::schema::infer_computed_data_type;
83pub use insert::{
84 build_array_for_column, normalize_insert_value_for_column, resolve_insert_columns,
85};
86use llkv_compute::date::{format_date32_literal, parse_date32_literal};
87use llkv_compute::scalar::decimal::{
88 align_decimal_to_scale, decimal_from_f64, decimal_from_i64, decimal_truthy,
89};
90use llkv_compute::scalar::interval::{
91 compare_interval_values, interval_value_from_arrow, interval_value_to_arrow,
92};
93pub use llkv_compute::time::current_time_micros;
94pub use translation::{
95 build_projected_columns, build_wildcard_projections, full_table_scan_filter,
96 resolve_field_id_from_schema, schema_for_projections, translate_predicate,
97 translate_predicate_with, translate_scalar, translate_scalar_with,
98};
99pub use types::{
100 ExecutorColumn, ExecutorMultiColumnUnique, ExecutorRowBatch, ExecutorSchema, ExecutorTable,
101 ExecutorTableProvider,
102};
103
/// Normalized component of a GROUP BY key.
///
/// Collapses the supported key column types into one hashable,
/// equality-comparable value; all `Null` keys land in the same group.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum GroupKeyValue {
    Null,
    Int(i64),
    Bool(bool),
    String(String),
}
112
/// Scalar result of an aggregate computation, in its widest supported
/// representation per type family.
#[derive(Clone, Debug, PartialEq)]
enum AggregateValue {
    Null,
    Int64(i64),
    Float64(f64),
    // Raw unscaled decimal value together with its scale.
    Decimal128 { value: i128, scale: i8 },
    String(String),
}
124
125impl AggregateValue {
126 fn as_i64(&self) -> Option<i64> {
128 match self {
129 AggregateValue::Null => None,
130 AggregateValue::Int64(v) => Some(*v),
131 AggregateValue::Float64(v) => Some(*v as i64),
132 AggregateValue::Decimal128 { value, scale } => {
133 let divisor = 10_i128.pow(*scale as u32);
135 Some((value / divisor) as i64)
136 }
137 AggregateValue::String(s) => s.parse().ok(),
138 }
139 }
140
141 #[allow(dead_code)]
143 fn as_f64(&self) -> Option<f64> {
144 match self {
145 AggregateValue::Null => None,
146 AggregateValue::Int64(v) => Some(*v as f64),
147 AggregateValue::Float64(v) => Some(*v),
148 AggregateValue::Decimal128 { value, scale } => {
149 let divisor = 10_f64.powi(*scale as i32);
151 Some(*value as f64 / divisor)
152 }
153 AggregateValue::String(s) => s.parse().ok(),
154 }
155 }
156}
157
158fn decimal_exact_i64(decimal: DecimalValue) -> Option<i64> {
159 llkv_compute::scalar::decimal::rescale(decimal, 0)
160 .ok()
161 .and_then(|integral| i64::try_from(integral.raw_value()).ok())
162}
163
/// A group's representative row: the batch it lives in plus the row offset
/// within that batch.
struct GroupState {
    batch: RecordBatch,
    row_idx: usize,
}
168
/// Per-group bookkeeping while evaluating grouped aggregates.
struct GroupAggregateState {
    // (batch, row) coordinates of the row supplying the group's key columns.
    representative_batch_idx: usize,
    representative_row: usize,
    // Every (batch index, row index) pair belonging to this group.
    row_locations: Vec<(usize, usize)>,
}
175
/// One column of the final projected output: its Arrow field plus where its
/// values come from.
struct OutputColumn {
    field: Field,
    source: OutputSource,
}
180
/// Origin of an output column's data.
enum OutputSource {
    // Pass-through of the scanned table column at `index`.
    TableColumn { index: usize },
    // Computed expression, identified by its index in the projection list.
    Computed { projection_index: usize },
}
185
#[cfg(test)]
// Test-only, thread-local stack of human-readable query labels used to tag
// diagnostics while a labeled query runs on this thread.
thread_local! {
    static QUERY_LABEL_STACK: RefCell<Vec<String>> = const { RefCell::new(Vec::new()) };
}
194
/// RAII guard returned by [`push_query_label`]; in test builds its `Drop`
/// pops the label, in non-test builds both push and drop are no-ops.
pub struct QueryLogGuard {
    // Prevents construction outside this module.
    _private: (),
}
199
#[cfg(test)]
/// Pushes `label` onto the thread-local label stack; the returned guard pops
/// it again when dropped.
pub fn push_query_label(label: impl Into<String>) -> QueryLogGuard {
    QUERY_LABEL_STACK.with(|stack| stack.borrow_mut().push(label.into()));
    QueryLogGuard { _private: () }
}
207
#[cfg(not(test))]
/// Non-test builds: labeling is a no-op; an inert guard is returned so call
/// sites look identical in both configurations.
#[inline]
pub fn push_query_label(_label: impl Into<String>) -> QueryLogGuard {
    QueryLogGuard { _private: () }
}
217
#[cfg(test)]
impl Drop for QueryLogGuard {
    fn drop(&mut self) {
        // Pop the label pushed by `push_query_label`; an already-empty stack
        // is silently tolerated.
        QUERY_LABEL_STACK.with(|stack| {
            let _ = stack.borrow_mut().pop();
        });
    }
}
226
#[cfg(not(test))]
impl Drop for QueryLogGuard {
    #[inline]
    fn drop(&mut self) {
        // Intentionally empty: mirrors the test-only Drop so guard semantics
        // are uniform across build configurations.
    }
}
234
#[cfg(test)]
/// Returns the innermost (most recently pushed) query label, if any.
pub fn current_query_label() -> Option<String> {
    QUERY_LABEL_STACK.with(|stack| stack.borrow().last().cloned())
}
240
#[cfg(not(test))]
/// Non-test builds never record labels, so there is never a current one.
#[inline]
pub fn current_query_label() -> Option<String> {
    None
}
249
250fn try_extract_simple_column<F: AsRef<str>>(expr: &ScalarExpr<F>) -> Option<&str> {
265 match expr {
266 ScalarExpr::Column(name) => Some(name.as_ref()),
267 ScalarExpr::Binary { left, op, right } => {
269 match op {
271 BinaryOp::Add => {
272 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Int128(0))) {
274 return try_extract_simple_column(right);
275 }
276 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Int128(0))) {
277 return try_extract_simple_column(left);
278 }
279 }
280 BinaryOp::Multiply => {
283 if matches!(left.as_ref(), ScalarExpr::Literal(Literal::Int128(1))) {
285 return try_extract_simple_column(right);
286 }
287 if matches!(right.as_ref(), ScalarExpr::Literal(Literal::Int128(1))) {
288 return try_extract_simple_column(left);
289 }
290 }
291 _ => {}
292 }
293 None
294 }
295 _ => None,
296 }
297}
298
/// Materializes a column of `PlanValue`s into a single Arrow array.
///
/// The output type is chosen from the first non-null value; every other value
/// must then match it (INTEGER is additionally accepted where FLOAT was
/// chosen). An all-null input defaults to a null `Int64` array.
fn plan_values_to_arrow_array(values: &[PlanValue]) -> ExecutorResult<ArrayRef> {
    use arrow::array::{
        Date32Array, Decimal128Array, Float64Array, Int64Array, IntervalMonthDayNanoArray,
        StringArray,
    };

    // First non-null value determines the array type.
    let mut value_type = None;
    for v in values {
        if !matches!(v, PlanValue::Null) {
            value_type = Some(v);
            break;
        }
    }

    match value_type {
        Some(PlanValue::Decimal(d)) => {
            // Precision/scale come from the first decimal; later values are
            // appended raw.
            // NOTE(review): mixed-scale decimals are not rescaled here —
            // confirm upstream guarantees a uniform scale per column.
            let precision = d.precision();
            let scale = d.scale();
            let mut builder = Decimal128Array::builder(values.len())
                .with_precision_and_scale(precision, scale)
                .map_err(|e| {
                    Error::InvalidArgumentError(format!(
                        "invalid Decimal128 precision/scale: {}",
                        e
                    ))
                })?;
            for v in values {
                match v {
                    PlanValue::Decimal(d) => builder.append_value(d.raw_value()),
                    PlanValue::Null => builder.append_null(),
                    other => {
                        return Err(Error::InvalidArgumentError(format!(
                            "expected DECIMAL plan value, found {other:?}"
                        )));
                    }
                }
            }
            Ok(Arc::new(builder.finish()) as ArrayRef)
        }
        Some(PlanValue::Integer(_)) => {
            let int_values: Vec<Option<i64>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Integer(i) => Ok(Some(*i)),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected INTEGER plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(Int64Array::from(int_values)) as ArrayRef)
        }
        Some(PlanValue::Float(_)) => {
            let float_values: Vec<Option<f64>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Float(f) => Ok(Some(*f)),
                    PlanValue::Null => Ok(None),
                    // Integers widen to f64 when the column is FLOAT-typed.
                    PlanValue::Integer(i) => Ok(Some(*i as f64)),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected FLOAT plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(Float64Array::from(float_values)) as ArrayRef)
        }
        Some(PlanValue::String(_)) => {
            let string_values: Vec<Option<&str>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::String(s) => Ok(Some(s.as_str())),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected STRING plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(StringArray::from(string_values)) as ArrayRef)
        }
        Some(PlanValue::Date32(_)) => {
            let date_values: Vec<Option<i32>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Date32(d) => Ok(Some(*d)),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected DATE plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(Date32Array::from(date_values)) as ArrayRef)
        }
        Some(PlanValue::Interval(_)) => {
            let interval_values: Vec<Option<IntervalMonthDayNano>> = values
                .iter()
                .map(|v| match v {
                    PlanValue::Interval(interval) => Ok(Some(interval_value_to_arrow(*interval))),
                    PlanValue::Null => Ok(None),
                    other => Err(Error::InvalidArgumentError(format!(
                        "expected INTERVAL plan value, found {other:?}"
                    ))),
                })
                .collect::<Result<_, _>>()?;
            Ok(Arc::new(IntervalMonthDayNanoArray::from(interval_values)) as ArrayRef)
        }
        // All values were NULL (or the slice was empty): default to Int64.
        _ => Ok(new_null_array(&DataType::Int64, values.len())),
    }
}
412
413fn resolve_column_name_to_index(
423 col_name: &str,
424 column_lookup_map: &FxHashMap<String, usize>,
425) -> Option<usize> {
426 let col_lower = col_name.to_ascii_lowercase();
427
428 if let Some(&idx) = column_lookup_map.get(&col_lower) {
430 return Some(idx);
431 }
432
433 let unqualified = col_name
436 .rsplit('.')
437 .next()
438 .unwrap_or(col_name)
439 .to_ascii_lowercase();
440 column_lookup_map
441 .iter()
442 .find(|(k, _)| k.ends_with(&format!(".{}", unqualified)) || k == &&unqualified)
443 .map(|(_, &idx)| idx)
444}
445
446fn get_or_insert_column_projection<P>(
448 projections: &mut Vec<ScanProjection>,
449 cache: &mut FxHashMap<FieldId, usize>,
450 table: &ExecutorTable<P>,
451 column: &ExecutorColumn,
452) -> usize
453where
454 P: Pager<Blob = EntryHandle> + Send + Sync,
455{
456 if let Some(existing) = cache.get(&column.field_id) {
457 return *existing;
458 }
459
460 let projection_index = projections.len();
461 let alias = if column.name.is_empty() {
462 format!("col{}", column.field_id)
463 } else {
464 column.name.clone()
465 };
466 projections.push(ScanProjection::from(StoreProjection::with_alias(
467 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
468 alias,
469 )));
470 cache.insert(column.field_id, projection_index);
471 projection_index
472}
473
/// Returns the projection index and inferred output type for a computed
/// expression, appending a new computed projection when the expression has
/// not been seen yet.
///
/// `cache` is keyed by the expression's `Debug` rendering; `alias_counter`
/// supplies unique `__agg_expr_N` aliases for newly added projections.
fn ensure_computed_projection<P>(
    expr: &ScalarExpr<String>,
    table: &ExecutorTable<P>,
    projections: &mut Vec<ScanProjection>,
    cache: &mut FxHashMap<String, (usize, DataType)>,
    alias_counter: &mut usize,
) -> ExecutorResult<(usize, DataType)>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Debug formatting serves as a structural cache key for the expression.
    let key = format!("{:?}", expr);
    if let Some((idx, dtype)) = cache.get(&key) {
        return Ok((*idx, dtype.clone()));
    }

    let translated = translate_scalar(expr, table.schema.as_ref(), |name| {
        Error::InvalidArgumentError(format!("unknown column '{}' in aggregate expression", name))
    })?;
    let data_type = infer_computed_data_type(table.schema.as_ref(), &translated)?;
    if data_type == DataType::Null {
        // A Null inference is suspicious but not fatal; leave a trace for
        // debugging downstream type errors.
        tracing::debug!(
            "ensure_computed_projection inferred Null type for expr: {:?}",
            expr
        );
    }
    let alias = format!("__agg_expr_{}", *alias_counter);
    *alias_counter += 1;
    let projection_index = projections.len();
    projections.push(ScanProjection::computed(translated, alias));
    cache.insert(key, (projection_index, data_type.clone()));
    Ok((projection_index, data_type))
}
507
/// Executes planned `SELECT` statements against tables resolved through the
/// supplied [`ExecutorTableProvider`].
pub struct QueryExecutor<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    provider: Arc<dyn ExecutorTableProvider<P>>,
}
515
516impl<P> QueryExecutor<P>
517where
518 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
519{
520 pub fn new(provider: Arc<dyn ExecutorTableProvider<P>>) -> Self {
521 Self { provider }
522 }
523
    /// Executes `plan` with no additional row-id filtering.
    pub fn execute_select(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
        self.execute_select_with_filter(plan, None)
    }
527
    /// Executes `plan`, optionally restricting scans to row ids accepted by
    /// `row_filter`, then applies the plan's LIMIT/OFFSET to the result.
    pub fn execute_select_with_filter(
        &self,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        // Captured up front; applied once after the dispatch below.
        let limit = plan.limit;
        let offset = plan.offset;

        let execution = if plan.compound.is_some() {
            // UNION / EXCEPT / INTERSECT chain.
            self.execute_compound_select(plan, row_filter)?
        } else if plan.tables.is_empty() {
            // SELECT without FROM: constant projections only.
            self.execute_select_without_table(plan)?
        } else if !plan.group_by.is_empty() {
            if plan.tables.len() > 1 {
                self.execute_cross_product(plan)?
            } else {
                let table_ref = &plan.tables[0];
                let table = self.provider.get_table(&table_ref.qualified_name())?;
                let display_name = table_ref.qualified_name();
                self.execute_group_by_single_table(table, display_name, plan, row_filter)?
            }
        } else if plan.tables.len() > 1 {
            self.execute_cross_product(plan)?
        } else {
            let table_ref = &plan.tables[0];
            let table = self.provider.get_table(&table_ref.qualified_name())?;
            let display_name = table_ref.qualified_name();

            if !plan.aggregates.is_empty() {
                // Explicit aggregate list (no GROUP BY).
                self.execute_aggregates(table, display_name, plan, row_filter)?
            } else if self.has_computed_aggregates(&plan) {
                // Aggregates nested inside computed projections.
                self.execute_computed_aggregates(table, display_name, plan, row_filter)?
            } else {
                self.execute_projection(table, display_name, plan, row_filter)?
            }
        };

        Ok(execution.with_limit(limit).with_offset(offset))
    }
569
    /// Executes a compound SELECT (a chain of UNION / EXCEPT / INTERSECT
    /// operations) by fully materializing each component's rows and combining
    /// them in plan order.
    ///
    /// DISTINCT semantics operate over byte-encoded rows; `EXCEPT ALL` and
    /// `INTERSECT ALL` are rejected. The outer ORDER BY, if present, is
    /// applied to the combined result at the end.
    fn execute_compound_select(
        &self,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        let order_by = plan.order_by.clone();
        let compound = plan.compound.expect("compound plan should be present");

        let CompoundSelectPlan {
            initial,
            operations,
        } = compound;

        let initial_exec = self.execute_select_with_filter(*initial, row_filter.clone())?;
        let schema = initial_exec.schema();
        let mut rows = initial_exec.into_rows()?;
        // Lazily-built set of encoded rows; `Some` only while `rows` is
        // known to be duplicate-free.
        let mut distinct_cache: Option<FxHashSet<Vec<u8>>> = None;

        for component in operations {
            let exec = self.execute_select_with_filter(component.plan, row_filter.clone())?;
            let other_schema = exec.schema();
            // All compound arms must produce compatible schemas.
            ensure_schema_compatibility(schema.as_ref(), other_schema.as_ref())?;
            let other_rows = exec.into_rows()?;

            match (component.operator, component.quantifier) {
                (CompoundOperator::Union, CompoundQuantifier::All) => {
                    rows.extend(other_rows);
                    // Blind append may introduce duplicates; invalidate cache.
                    distinct_cache = None;
                }
                (CompoundOperator::Union, CompoundQuantifier::Distinct) => {
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let cache = distinct_cache
                        .as_mut()
                        .expect("distinct cache should be initialized");
                    for row in other_rows {
                        let key = encode_row(&row);
                        // Keep only rows whose encoding is new.
                        if cache.insert(key) {
                            rows.push(row);
                        }
                    }
                }
                (CompoundOperator::Except, CompoundQuantifier::Distinct) => {
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let cache = distinct_cache
                        .as_mut()
                        .expect("distinct cache should be initialized");
                    if rows.is_empty() {
                        continue;
                    }
                    let mut remove_keys = FxHashSet::default();
                    for row in other_rows {
                        remove_keys.insert(encode_row(&row));
                    }
                    if remove_keys.is_empty() {
                        continue;
                    }
                    rows.retain(|row| {
                        let key = encode_row(row);
                        if remove_keys.contains(&key) {
                            // Keep the cache consistent with retained rows.
                            cache.remove(&key);
                            false
                        } else {
                            true
                        }
                    });
                }
                (CompoundOperator::Except, CompoundQuantifier::All) => {
                    return Err(Error::InvalidArgumentError(
                        "EXCEPT ALL is not supported yet".into(),
                    ));
                }
                (CompoundOperator::Intersect, CompoundQuantifier::Distinct) => {
                    ensure_distinct_rows(&mut rows, &mut distinct_cache);
                    let mut right_keys = FxHashSet::default();
                    for row in other_rows {
                        right_keys.insert(encode_row(&row));
                    }
                    if right_keys.is_empty() {
                        // Intersection with the empty set is empty.
                        rows.clear();
                        distinct_cache = Some(FxHashSet::default());
                        continue;
                    }
                    let mut new_rows = Vec::new();
                    let mut new_cache = FxHashSet::default();
                    for row in rows.drain(..) {
                        let key = encode_row(&row);
                        if right_keys.contains(&key) && new_cache.insert(key) {
                            new_rows.push(row);
                        }
                    }
                    rows = new_rows;
                    distinct_cache = Some(new_cache);
                }
                (CompoundOperator::Intersect, CompoundQuantifier::All) => {
                    return Err(Error::InvalidArgumentError(
                        "INTERSECT ALL is not supported yet".into(),
                    ));
                }
            }
        }

        let mut batch = rows_to_record_batch(schema.clone(), &rows)?;
        if !order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            String::new(),
            schema,
            batch,
        ))
    }
701
702 fn has_computed_aggregates(&self, plan: &SelectPlan) -> bool {
704 plan.projections.iter().any(|proj| {
705 if let SelectProjection::Computed { expr, .. } = proj {
706 Self::expr_contains_aggregate(expr)
707 } else {
708 false
709 }
710 })
711 }
712
713 fn predicate_contains_aggregate(expr: &llkv_expr::expr::Expr<String>) -> bool {
715 match expr {
716 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
717 exprs.iter().any(Self::predicate_contains_aggregate)
718 }
719 llkv_expr::expr::Expr::Not(inner) => Self::predicate_contains_aggregate(inner),
720 llkv_expr::expr::Expr::Compare { left, right, .. } => {
721 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
722 }
723 llkv_expr::expr::Expr::InList { expr, list, .. } => {
724 Self::expr_contains_aggregate(expr)
725 || list.iter().any(|e| Self::expr_contains_aggregate(e))
726 }
727 llkv_expr::expr::Expr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
728 llkv_expr::expr::Expr::Literal(_) => false,
729 llkv_expr::expr::Expr::Pred(_) => false,
730 llkv_expr::expr::Expr::Exists(_) => false,
731 }
732 }
733
734 fn expr_contains_aggregate(expr: &ScalarExpr<String>) -> bool {
736 match expr {
737 ScalarExpr::Aggregate(_) => true,
738 ScalarExpr::Binary { left, right, .. } => {
739 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
740 }
741 ScalarExpr::Compare { left, right, .. } => {
742 Self::expr_contains_aggregate(left) || Self::expr_contains_aggregate(right)
743 }
744 ScalarExpr::GetField { base, .. } => Self::expr_contains_aggregate(base),
745 ScalarExpr::Cast { expr, .. } => Self::expr_contains_aggregate(expr),
746 ScalarExpr::Not(expr) => Self::expr_contains_aggregate(expr),
747 ScalarExpr::IsNull { expr, .. } => Self::expr_contains_aggregate(expr),
748 ScalarExpr::Case {
749 operand,
750 branches,
751 else_expr,
752 } => {
753 operand
754 .as_deref()
755 .map(Self::expr_contains_aggregate)
756 .unwrap_or(false)
757 || branches.iter().any(|(when_expr, then_expr)| {
758 Self::expr_contains_aggregate(when_expr)
759 || Self::expr_contains_aggregate(then_expr)
760 })
761 || else_expr
762 .as_deref()
763 .map(Self::expr_contains_aggregate)
764 .unwrap_or(false)
765 }
766 ScalarExpr::Coalesce(items) => items.iter().any(Self::expr_contains_aggregate),
767 ScalarExpr::Column(_) | ScalarExpr::Literal(_) | ScalarExpr::Random => false,
768 ScalarExpr::ScalarSubquery(_) => false,
769 }
770 }
771
772 fn evaluate_exists_subquery(
773 &self,
774 context: &mut CrossProductExpressionContext,
775 subquery: &llkv_plan::FilterSubquery,
776 batch: &RecordBatch,
777 row_idx: usize,
778 ) -> ExecutorResult<bool> {
779 let bindings =
780 collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
781 let bound_plan = bind_select_plan(&subquery.plan, &bindings)?;
782 let execution = self.execute_select(bound_plan)?;
783 let mut found = false;
784 execution.stream(|inner_batch| {
785 if inner_batch.num_rows() > 0 {
786 found = true;
787 }
788 Ok(())
789 })?;
790 Ok(found)
791 }
792
    /// Evaluates `subquery` for a single row of `batch`: gathers that row's
    /// correlated column values and delegates to
    /// [`Self::evaluate_scalar_subquery_with_bindings`].
    fn evaluate_scalar_subquery_literal(
        &self,
        context: &mut CrossProductExpressionContext,
        subquery: &llkv_plan::ScalarSubquery,
        batch: &RecordBatch,
        row_idx: usize,
    ) -> ExecutorResult<Literal> {
        let bindings =
            collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;
        self.evaluate_scalar_subquery_with_bindings(subquery, &bindings)
    }
804
    /// Executes `subquery` with correlated placeholders bound to `bindings`
    /// and returns its single scalar result.
    ///
    /// Returns `Literal::Null` when the subquery yields no rows; errors when
    /// it yields more than one row or more than one column.
    fn evaluate_scalar_subquery_with_bindings(
        &self,
        subquery: &llkv_plan::ScalarSubquery,
        bindings: &FxHashMap<String, Literal>,
    ) -> ExecutorResult<Literal> {
        let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
        let execution = self.execute_select(bound_plan)?;
        let mut rows_seen: usize = 0;
        let mut result: Option<Literal> = None;
        execution.stream(|inner_batch| {
            if inner_batch.num_columns() != 1 {
                return Err(Error::InvalidArgumentError(
                    "scalar subquery must return exactly one column".into(),
                ));
            }
            let column = inner_batch.column(0).clone();
            for idx in 0..inner_batch.num_rows() {
                // A second row anywhere in the stream is an error.
                if rows_seen >= 1 {
                    return Err(Error::InvalidArgumentError(
                        "scalar subquery produced more than one row".into(),
                    ));
                }
                rows_seen = rows_seen.saturating_add(1);
                result = Some(Literal::from_array_ref(&column, idx)?);
            }
            Ok(())
        })?;

        if rows_seen == 0 {
            // Empty result set means SQL NULL, not an error.
            Ok(Literal::Null)
        } else {
            result
                .ok_or_else(|| Error::Internal("scalar subquery evaluation missing result".into()))
        }
    }
840
    /// Evaluates a correlated scalar subquery for every row of `batch` and
    /// returns a numeric array: Int64 when all results are integral, Float64
    /// otherwise.
    ///
    /// Rows with identical correlated bindings are deduplicated so the
    /// subquery executes once per *distinct* binding set; distinct jobs run
    /// in parallel on the shared thread pool.
    fn evaluate_scalar_subquery_numeric(
        &self,
        context: &mut CrossProductExpressionContext,
        subquery: &llkv_plan::ScalarSubquery,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let row_count = batch.num_rows();
        // row index -> index into `unique_bindings` for that row's bindings.
        let mut row_job_indices: Vec<usize> = Vec::with_capacity(row_count);
        let mut unique_bindings: Vec<FxHashMap<String, Literal>> = Vec::new();
        let mut key_lookup: FxHashMap<Vec<u8>, usize> = FxHashMap::default();

        for row_idx in 0..row_count {
            let bindings =
                collect_correlated_bindings(context, batch, row_idx, &subquery.correlated_columns)?;

            // Encode the bindings (in correlated-column order) into a byte
            // key so identical binding sets share one job.
            let mut plan_values: Vec<PlanValue> =
                Vec::with_capacity(subquery.correlated_columns.len());
            for column in &subquery.correlated_columns {
                let literal = bindings
                    .get(&column.placeholder)
                    .cloned()
                    .unwrap_or(Literal::Null);
                let plan_value = plan_value_from_literal(&literal)?;
                plan_values.push(plan_value);
            }
            let key = encode_row(&plan_values);

            let job_idx = if let Some(&existing) = key_lookup.get(&key) {
                existing
            } else {
                let idx = unique_bindings.len();
                key_lookup.insert(key, idx);
                unique_bindings.push(bindings);
                idx
            };
            row_job_indices.push(job_idx);
        }

        // Execute each distinct binding set once, in parallel.
        let job_results: Vec<ExecutorResult<Literal>> =
            llkv_column_map::parallel::with_thread_pool(|| {
                let results: Vec<ExecutorResult<Literal>> = unique_bindings
                    .par_iter()
                    .map(|bindings| self.evaluate_scalar_subquery_with_bindings(subquery, bindings))
                    .collect();
                results
            });

        let mut job_literals: Vec<Literal> = Vec::with_capacity(job_results.len());
        for result in job_results {
            job_literals.push(result?);
        }

        // Stage results as f64, tracking whether everything stayed integral
        // so the output can be Int64 instead of Float64.
        // NOTE(review): integers with magnitude > 2^53 lose precision in this
        // f64 staging buffer even on the Int64 path — confirm acceptable.
        let mut values: Vec<Option<f64>> = Vec::with_capacity(row_count);
        let mut all_integer = true;

        for row_idx in 0..row_count {
            let literal = &job_literals[row_job_indices[row_idx]];
            match literal {
                Literal::Null => values.push(None),
                Literal::Int128(value) => {
                    let cast = i64::try_from(*value).map_err(|_| {
                        Error::InvalidArgumentError(
                            "scalar subquery integer result exceeds supported range".into(),
                        )
                    })?;
                    values.push(Some(cast as f64));
                }
                Literal::Float64(value) => {
                    all_integer = false;
                    values.push(Some(*value));
                }
                Literal::Boolean(flag) => {
                    // Booleans coerce to 1/0 but stay on the integer path.
                    let numeric = if *flag { 1.0 } else { 0.0 };
                    values.push(Some(numeric));
                }
                Literal::Decimal128(decimal) => {
                    // Exact integral decimals stay on the integer path;
                    // anything else falls back to a lossy f64.
                    if let Some(value) = decimal_exact_i64(*decimal) {
                        values.push(Some(value as f64));
                    } else {
                        all_integer = false;
                        values.push(Some(decimal.to_f64()));
                    }
                }
                Literal::String(_)
                | Literal::Struct(_)
                | Literal::Date32(_)
                | Literal::Interval(_) => {
                    return Err(Error::InvalidArgumentError(
                        "scalar subquery produced non-numeric result in numeric context".into(),
                    ));
                }
            }
        }

        if all_integer {
            let iter = values.into_iter().map(|opt| opt.map(|v| v as i64));
            let array = Int64Array::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        } else {
            let array = Float64Array::from_iter(values);
            Ok(Arc::new(array) as ArrayRef)
        }
    }
946
947 fn evaluate_scalar_subquery_array(
948 &self,
949 context: &mut CrossProductExpressionContext,
950 subquery: &llkv_plan::ScalarSubquery,
951 batch: &RecordBatch,
952 ) -> ExecutorResult<ArrayRef> {
953 let mut values = Vec::with_capacity(batch.num_rows());
954 for row_idx in 0..batch.num_rows() {
955 let literal =
956 self.evaluate_scalar_subquery_literal(context, subquery, batch, row_idx)?;
957 values.push(literal);
958 }
959 literals_to_array(&values)
960 }
961
    /// Evaluates a projection expression over `batch`, materializing any
    /// embedded scalar subqueries first.
    ///
    /// Each referenced subquery is evaluated into a numeric column, cached in
    /// `context` under a freshly allocated synthetic field id, and the
    /// expression is rewritten to read that column before numeric evaluation.
    fn evaluate_projection_expression(
        &self,
        context: &mut CrossProductExpressionContext,
        expr: &ScalarExpr<String>,
        batch: &RecordBatch,
        scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
    ) -> ExecutorResult<ArrayRef> {
        let translated = translate_scalar(expr, context.schema(), |name| {
            Error::InvalidArgumentError(format!(
                "column '{}' not found in cross product result",
                name
            ))
        })?;

        // Collect every scalar subquery the expression references.
        let mut subquery_ids: FxHashSet<SubqueryId> = FxHashSet::default();
        collect_scalar_subquery_ids(&translated, &mut subquery_ids);

        // Evaluate each subquery into a cached numeric column.
        let mut mapping: FxHashMap<SubqueryId, FieldId> = FxHashMap::default();
        for subquery_id in &subquery_ids {
            let info = scalar_lookup
                .get(subquery_id)
                .ok_or_else(|| Error::Internal("missing scalar subquery metadata".into()))?;
            let field_id = context.allocate_synthetic_field_id()?;
            let numeric = self.evaluate_scalar_subquery_numeric(context, info, batch)?;
            context.numeric_cache.insert(field_id, numeric);
            mapping.insert(*subquery_id, field_id);
        }

        // Swap subquery nodes for references to the cached columns, then
        // evaluate the whole expression numerically.
        let rewritten = rewrite_scalar_expr_for_subqueries(&translated, &mapping);
        context.evaluate_numeric(&rewritten, batch)
    }
993
    /// Executes a SELECT that has no FROM clause.
    ///
    /// Only computed projections over constant expressions are accepted; each
    /// one evaluates to a single-row column. DISTINCT is honored (trivially,
    /// as the result has at most one row).
    fn execute_select_without_table(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
        use arrow::array::ArrayRef;
        use arrow::datatypes::Field;

        let mut fields = Vec::new();
        let mut arrays: Vec<ArrayRef> = Vec::new();

        for proj in &plan.projections {
            match proj {
                SelectProjection::Computed { expr, alias } => {
                    // Constant-fold the expression; anything non-constant is
                    // rejected.
                    let literal =
                        evaluate_constant_scalar_with_aggregates(expr).ok_or_else(|| {
                            Error::InvalidArgumentError(
                                "SELECT without FROM only supports constant expressions".into(),
                            )
                        })?;
                    let (dtype, array) = Self::literal_to_array(&literal)?;

                    fields.push(Field::new(alias.clone(), dtype, true));
                    arrays.push(array);
                }
                _ => {
                    return Err(Error::InvalidArgumentError(
                        "SELECT without FROM only supports computed projections".into(),
                    ));
                }
            }
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
            .map_err(|e| Error::Internal(format!("failed to create record batch: {}", e)))?;

        if plan.distinct {
            let mut state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        // Re-read the schema from the (possibly replaced) batch.
        let schema = batch.schema();

        Ok(SelectExecution::new_single_batch(
            String::new(),
            schema,
            batch,
        ))
    }
1045
    /// Converts one literal into a single-element Arrow array together with
    /// its `DataType`, for table-less SELECT output. Struct literals recurse
    /// field-by-field.
    fn literal_to_array(lit: &llkv_expr::literal::Literal) -> ExecutorResult<(DataType, ArrayRef)> {
        use arrow::array::{
            ArrayRef, BooleanArray, Date32Array, Decimal128Array, Float64Array, Int64Array,
            IntervalMonthDayNanoArray, StringArray, StructArray, new_null_array,
        };
        use arrow::datatypes::{DataType, Field, IntervalUnit};
        use llkv_compute::scalar::interval::interval_value_to_arrow;
        use llkv_expr::literal::Literal;

        match lit {
            Literal::Int128(v) => {
                // NOTE(review): values outside the i64 range silently become
                // 0 here — confirm whether this should be an error instead.
                let val = i64::try_from(*v).unwrap_or(0);
                Ok((
                    DataType::Int64,
                    Arc::new(Int64Array::from(vec![val])) as ArrayRef,
                ))
            }
            Literal::Float64(v) => Ok((
                DataType::Float64,
                Arc::new(Float64Array::from(vec![*v])) as ArrayRef,
            )),
            Literal::Boolean(v) => Ok((
                DataType::Boolean,
                Arc::new(BooleanArray::from(vec![*v])) as ArrayRef,
            )),
            Literal::Decimal128(value) => {
                // Widen precision so it can never be smaller than the scale.
                let iter = std::iter::once(value.raw_value());
                let precision = std::cmp::max(value.precision(), value.scale() as u8);
                let array = Decimal128Array::from_iter_values(iter)
                    .with_precision_and_scale(precision, value.scale())
                    .map_err(|err| {
                        Error::InvalidArgumentError(format!(
                            "failed to build Decimal128 literal array: {err}"
                        ))
                    })?;
                Ok((
                    DataType::Decimal128(precision, value.scale()),
                    Arc::new(array) as ArrayRef,
                ))
            }
            Literal::String(v) => Ok((
                DataType::Utf8,
                Arc::new(StringArray::from(vec![v.clone()])) as ArrayRef,
            )),
            Literal::Date32(v) => Ok((
                DataType::Date32,
                Arc::new(Date32Array::from(vec![*v])) as ArrayRef,
            )),
            Literal::Null => Ok((DataType::Null, new_null_array(&DataType::Null, 1))),
            Literal::Interval(interval) => Ok((
                DataType::Interval(IntervalUnit::MonthDayNano),
                Arc::new(IntervalMonthDayNanoArray::from(vec![
                    interval_value_to_arrow(*interval),
                ])) as ArrayRef,
            )),
            Literal::Struct(struct_fields) => {
                // Build each field's single-element array, then assemble the
                // struct from the collected children.
                let mut inner_fields = Vec::new();
                let mut inner_arrays = Vec::new();

                for (field_name, field_lit) in struct_fields {
                    let (field_dtype, field_array) = Self::literal_to_array(field_lit)?;
                    inner_fields.push(Field::new(field_name.clone(), field_dtype, true));
                    inner_arrays.push(field_array);
                }

                let struct_array =
                    StructArray::try_new(inner_fields.clone().into(), inner_arrays, None).map_err(
                        |e| Error::Internal(format!("failed to create struct array: {}", e)),
                    )?;

                Ok((
                    DataType::Struct(inner_fields.into()),
                    Arc::new(struct_array) as ArrayRef,
                ))
            }
        }
    }
1125
    /// Execute a `SELECT` whose FROM clause references two or more tables.
    ///
    /// Pipeline: resolve table handles → build the joined/cross-joined row set
    /// (hash-join fast path, two-table llkv-join path, left-deep multi-join
    /// fold, or staged cross join) → apply any WHERE predicate the join did
    /// not consume → GROUP BY / aggregates → ORDER BY → projections →
    /// DISTINCT → single-batch result.
    fn execute_cross_product(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
        use arrow::compute::concat_batches;

        if plan.tables.len() < 2 {
            return Err(Error::InvalidArgumentError(
                "cross product requires at least 2 tables".into(),
            ));
        }

        // Resolve every FROM entry to its executor table handle up front so
        // later phases can borrow them freely.
        let mut tables_with_handles = Vec::with_capacity(plan.tables.len());
        for table_ref in &plan.tables {
            let qualified_name = table_ref.qualified_name();
            let table = self.provider.get_table(&qualified_name)?;
            tables_with_handles.push((table_ref.clone(), table));
        }

        // Human-readable execution name, e.g. "a,b,c".
        let display_name = tables_with_handles
            .iter()
            .map(|(table_ref, _)| table_ref.qualified_name())
            .collect::<Vec<_>>()
            .join(",");

        let mut remaining_filter = plan.filter.clone();

        // Fast path: when a WHERE clause exists, attempt a hash join that may
        // absorb some (or all) of its predicates.
        let join_data = if remaining_filter.as_ref().is_some() {
            self.try_execute_hash_join(&plan, &tables_with_handles)?
        } else {
            None
        };

        let current = if let Some((joined, handled_all_predicates)) = join_data {
            // The hash join may have evaluated every WHERE predicate already;
            // if so, drop the filter to avoid re-applying it below.
            if handled_all_predicates {
                remaining_filter = None;
            }
            joined
        } else {
            let has_joins = !plan.joins.is_empty();

            if has_joins && tables_with_handles.len() == 2 {
                // Exactly two tables with explicit JOIN syntax: use the
                // streaming llkv-join implementation.
                use llkv_join::{JoinOptions, TableJoinExt};

                let (left_ref, left_table) = &tables_with_handles[0];
                let (right_ref, right_table) = &tables_with_handles[1];

                let join_metadata = plan.joins.first().ok_or_else(|| {
                    Error::InvalidArgumentError("expected join metadata for two-table join".into())
                })?;

                let join_type = match join_metadata.join_type {
                    llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
                    llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
                    llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
                    llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
                };

                tracing::debug!(
                    "Using llkv-join for {join_type:?} join between {} and {}",
                    left_ref.qualified_name(),
                    right_ref.qualified_name()
                );

                // Output schema is left columns followed by right columns.
                let left_col_count = left_table.schema.columns.len();
                let right_col_count = right_table.schema.columns.len();

                let mut combined_fields = Vec::with_capacity(left_col_count + right_col_count);
                for col in &left_table.schema.columns {
                    combined_fields.push(Field::new(
                        col.name.clone(),
                        col.data_type.clone(),
                        col.nullable,
                    ));
                }
                for col in &right_table.schema.columns {
                    combined_fields.push(Field::new(
                        col.name.clone(),
                        col.data_type.clone(),
                        col.nullable,
                    ));
                }
                let combined_schema = Arc::new(Schema::new(combined_fields));
                let column_counts = vec![left_col_count, right_col_count];
                let table_indices = vec![0, 1];

                // Extract equi-join keys from the ON condition (when present)
                // and detect constant TRUE / constant FALSE conditions.
                let mut join_keys = Vec::new();
                let mut condition_is_trivial = false;
                let mut condition_is_impossible = false;

                if let Some(condition) = join_metadata.on_condition.as_ref() {
                    let plan = build_join_keys_from_condition(
                        condition,
                        left_ref,
                        left_table.as_ref(),
                        right_ref,
                        right_table.as_ref(),
                    )?;
                    join_keys = plan.keys;
                    condition_is_trivial = plan.always_true;
                    condition_is_impossible = plan.always_false;
                }

                if condition_is_impossible {
                    // ON folds to FALSE/NULL: inner joins yield nothing, but
                    // outer joins still emit the preserved side NULL-padded.
                    let batches = build_no_match_join_batches(
                        join_type,
                        left_ref,
                        left_table.as_ref(),
                        right_ref,
                        right_table.as_ref(),
                        Arc::clone(&combined_schema),
                    )?;

                    TableCrossProductData {
                        schema: combined_schema,
                        batches,
                        column_counts,
                        table_indices,
                    }
                } else {
                    // A non-trivial ON clause must contribute at least one
                    // equality key for the hash-based join to work.
                    if !condition_is_trivial
                        && join_metadata.on_condition.is_some()
                        && join_keys.is_empty()
                    {
                        return Err(Error::InvalidArgumentError(
                            "JOIN ON clause must include at least one equality predicate".into(),
                        ));
                    }

                    let mut result_batches = Vec::new();
                    left_table.table.join_stream(
                        &right_table.table,
                        &join_keys,
                        &JoinOptions {
                            join_type,
                            ..Default::default()
                        },
                        |batch| {
                            result_batches.push(batch);
                        },
                    )?;

                    TableCrossProductData {
                        schema: combined_schema,
                        batches: result_batches,
                        column_counts,
                        table_indices,
                    }
                }
            } else if has_joins && tables_with_handles.len() > 2 {
                // More than two tables with explicit JOINs: fold left-deep,
                // hash-joining each new table onto the accumulated result.
                let join_lookup: FxHashMap<usize, &llkv_plan::JoinMetadata> = plan
                    .joins
                    .iter()
                    .map(|join| (join.left_table_index, join))
                    .collect();

                // Push single-table literal constraints from WHERE down into
                // the base scans so each is as small as possible.
                let constraint_map = if let Some(filter_wrapper) = remaining_filter.as_ref() {
                    extract_literal_pushdown_filters(
                        &filter_wrapper.predicate,
                        &tables_with_handles,
                    )
                } else {
                    vec![Vec::new(); tables_with_handles.len()]
                };

                let (first_ref, first_table) = &tables_with_handles[0];
                let first_constraints = constraint_map.first().map(|v| v.as_slice()).unwrap_or(&[]);
                let mut accumulated =
                    collect_table_data(0, first_ref, first_table.as_ref(), first_constraints)?;

                for (idx, (right_ref, right_table)) in
                    tables_with_handles.iter().enumerate().skip(1)
                {
                    let right_constraints =
                        constraint_map.get(idx).map(|v| v.as_slice()).unwrap_or(&[]);

                    // Each table after the first must be joined to its
                    // predecessor via explicit JOIN metadata.
                    let join_metadata = join_lookup.get(&(idx - 1)).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "No join condition found between table {} and {}. Multi-table queries require explicit JOIN syntax.",
                            idx - 1, idx
                        ))
                    })?;

                    let join_type = match join_metadata.join_type {
                        llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
                        llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
                        llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
                        llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
                    };

                    let right_data = collect_table_data(
                        idx,
                        right_ref,
                        right_table.as_ref(),
                        right_constraints,
                    )?;

                    // A missing ON clause is treated as constant TRUE, i.e. a
                    // cross join with this table.
                    let condition_expr = join_metadata
                        .on_condition
                        .clone()
                        .unwrap_or(LlkvExpr::Literal(true));

                    let join_batches = execute_hash_join_batches(
                        &accumulated.schema,
                        &accumulated.batches,
                        &right_data.schema,
                        &right_data.batches,
                        &condition_expr,
                        join_type,
                    )?;

                    // Widen the accumulated schema with the new table's
                    // fields (accumulated columns first).
                    let combined_fields: Vec<Field> = accumulated
                        .schema
                        .fields()
                        .iter()
                        .chain(right_data.schema.fields().iter())
                        .map(|f| {
                            Field::new(f.name().clone(), f.data_type().clone(), f.is_nullable())
                        })
                        .collect();
                    let combined_schema = Arc::new(Schema::new(combined_fields));

                    accumulated = TableCrossProductData {
                        schema: combined_schema,
                        batches: join_batches,
                        column_counts: {
                            let mut counts = accumulated.column_counts;
                            counts.push(right_data.schema.fields().len());
                            counts
                        },
                        table_indices: {
                            let mut indices = accumulated.table_indices;
                            indices.push(idx);
                            indices
                        },
                    };
                }

                accumulated
            } else {
                // Pure cross product (possibly with isolated JOIN pairs whose
                // predicate is unsatisfiable): stage per-table data, then
                // cross join everything.
                let constraint_map = if let Some(filter_wrapper) = remaining_filter.as_ref() {
                    extract_literal_pushdown_filters(
                        &filter_wrapper.predicate,
                        &tables_with_handles,
                    )
                } else {
                    vec![Vec::new(); tables_with_handles.len()]
                };

                let mut staged: Vec<TableCrossProductData> =
                    Vec::with_capacity(tables_with_handles.len());
                let join_lookup: FxHashMap<usize, &llkv_plan::JoinMetadata> = plan
                    .joins
                    .iter()
                    .map(|join| (join.left_table_index, join))
                    .collect();

                let mut idx = 0usize;
                while idx < tables_with_handles.len() {
                    if let Some(join_metadata) = join_lookup.get(&idx) {
                        if idx + 1 >= tables_with_handles.len() {
                            return Err(Error::Internal(
                                "join metadata references table beyond FROM list".into(),
                            ));
                        }

                        // Special-case an adjacent pair whose ON predicate is
                        // provably unsatisfiable (and which is not chained
                        // into the next join): emit NULL-padded fallback
                        // batches and consume both tables at once.
                        let overlaps_next_join = join_lookup.contains_key(&(idx + 1));
                        if let Some(condition) = join_metadata.on_condition.as_ref() {
                            let (left_ref, left_table) = &tables_with_handles[idx];
                            let (right_ref, right_table) = &tables_with_handles[idx + 1];
                            let join_plan = build_join_keys_from_condition(
                                condition,
                                left_ref,
                                left_table.as_ref(),
                                right_ref,
                                right_table.as_ref(),
                            )?;
                            if join_plan.always_false && !overlaps_next_join {
                                let join_type = match join_metadata.join_type {
                                    llkv_plan::JoinPlan::Inner => llkv_join::JoinType::Inner,
                                    llkv_plan::JoinPlan::Left => llkv_join::JoinType::Left,
                                    llkv_plan::JoinPlan::Right => llkv_join::JoinType::Right,
                                    llkv_plan::JoinPlan::Full => llkv_join::JoinType::Full,
                                };

                                let left_col_count = left_table.schema.columns.len();
                                let right_col_count = right_table.schema.columns.len();

                                let mut combined_fields =
                                    Vec::with_capacity(left_col_count + right_col_count);
                                for col in &left_table.schema.columns {
                                    combined_fields.push(Field::new(
                                        col.name.clone(),
                                        col.data_type.clone(),
                                        col.nullable,
                                    ));
                                }
                                for col in &right_table.schema.columns {
                                    combined_fields.push(Field::new(
                                        col.name.clone(),
                                        col.data_type.clone(),
                                        col.nullable,
                                    ));
                                }

                                let combined_schema = Arc::new(Schema::new(combined_fields));
                                let batches = build_no_match_join_batches(
                                    join_type,
                                    left_ref,
                                    left_table.as_ref(),
                                    right_ref,
                                    right_table.as_ref(),
                                    Arc::clone(&combined_schema),
                                )?;

                                staged.push(TableCrossProductData {
                                    schema: combined_schema,
                                    batches,
                                    column_counts: vec![left_col_count, right_col_count],
                                    table_indices: vec![idx, idx + 1],
                                });
                                idx += 2;
                                continue;
                            }
                        }
                    }

                    // Default: stage this single table's (possibly
                    // constraint-filtered) data.
                    let (table_ref, table) = &tables_with_handles[idx];
                    let constraints = constraint_map.get(idx).map(|v| v.as_slice()).unwrap_or(&[]);
                    staged.push(collect_table_data(
                        idx,
                        table_ref,
                        table.as_ref(),
                        constraints,
                    )?);
                    idx += 1;
                }

                cross_join_all(staged)?
            }
        };

        // Unpack the joined result for downstream clause processing.
        let TableCrossProductData {
            schema: combined_schema,
            batches: mut combined_batches,
            column_counts,
            table_indices,
        } = current;

        // Map each addressable column name (qualified and short forms) to its
        // index in the combined schema.
        let column_lookup_map = build_cross_product_column_lookup(
            combined_schema.as_ref(),
            &plan.tables,
            &column_counts,
            &table_indices,
        );

        let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
            .scalar_subqueries
            .iter()
            .map(|subquery| (subquery.id, subquery))
            .collect();

        // Apply whatever WHERE predicate the join machinery did not consume.
        if let Some(filter_wrapper) = remaining_filter.as_ref() {
            let mut filter_context = CrossProductExpressionContext::new(
                combined_schema.as_ref(),
                column_lookup_map.clone(),
            )?;
            let translated_filter = translate_predicate(
                filter_wrapper.predicate.clone(),
                filter_context.schema(),
                |name| {
                    Error::InvalidArgumentError(format!(
                        "column '{}' not found in cross product result",
                        name
                    ))
                },
            )?;

            let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
                filter_wrapper
                    .subqueries
                    .iter()
                    .map(|subquery| (subquery.id, subquery))
                    .collect();
            let mut predicate_scalar_ids = FxHashSet::default();
            collect_predicate_scalar_subquery_ids(&translated_filter, &mut predicate_scalar_ids);

            let mut filtered_batches = Vec::with_capacity(combined_batches.len());
            for batch in combined_batches.into_iter() {
                filter_context.reset();
                // Materialize every scalar subquery the predicate references
                // as a per-row column for this batch.
                for subquery_id in &predicate_scalar_ids {
                    let info = scalar_lookup.get(subquery_id).ok_or_else(|| {
                        Error::Internal("missing scalar subquery metadata".into())
                    })?;
                    let array =
                        self.evaluate_scalar_subquery_array(&mut filter_context, info, &batch)?;
                    let accessor = ColumnAccessor::from_array(&array)?;
                    filter_context.register_scalar_subquery_column(*subquery_id, accessor);
                }
                // EXISTS-style correlated subqueries are evaluated lazily,
                // per row, through this callback.
                let mask = filter_context.evaluate_predicate_mask(
                    &translated_filter,
                    &batch,
                    |ctx, subquery_expr, row_idx, current_batch| {
                        let subquery = subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
                            Error::Internal("missing correlated subquery metadata".into())
                        })?;
                        let exists =
                            self.evaluate_exists_subquery(ctx, subquery, current_batch, row_idx)?;
                        let value = if subquery_expr.negated {
                            !exists
                        } else {
                            exists
                        };
                        Ok(Some(value))
                    },
                )?;
                let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
                    Error::InvalidArgumentError(format!(
                        "failed to apply cross product filter: {err}"
                    ))
                })?;
                // Keep only non-empty batches to avoid useless downstream work.
                if filtered.num_rows() > 0 {
                    filtered_batches.push(filtered);
                }
            }
            combined_batches = filtered_batches;
        }

        // GROUP BY takes over the rest of the pipeline.
        if !plan.group_by.is_empty() {
            return self.execute_group_by_from_batches(
                display_name,
                plan,
                combined_schema,
                combined_batches,
                column_lookup_map,
            );
        }

        // Plain aggregates (no grouping).
        if !plan.aggregates.is_empty() {
            return self.execute_cross_product_aggregates(
                Arc::clone(&combined_schema),
                combined_batches,
                &column_lookup_map,
                &plan,
                &display_name,
            );
        }

        // Aggregates embedded inside computed projection expressions.
        if self.has_computed_aggregates(&plan) {
            return self.execute_cross_product_computed_aggregates(
                Arc::clone(&combined_schema),
                combined_batches,
                &column_lookup_map,
                &plan,
                &display_name,
            );
        }

        // Collapse to a single batch for ordering/projection/distinct.
        let mut combined_batch = if combined_batches.is_empty() {
            RecordBatch::new_empty(Arc::clone(&combined_schema))
        } else if combined_batches.len() == 1 {
            combined_batches.pop().unwrap()
        } else {
            concat_batches(&combined_schema, &combined_batches).map_err(|e| {
                Error::Internal(format!(
                    "failed to concatenate cross product batches: {}",
                    e
                ))
            })?
        };

        if !plan.order_by.is_empty() {
            // Resolve ORDER BY column names to combined-schema indices,
            // trying the lookup map first and the raw schema name second.
            let mut resolved_order_by = Vec::with_capacity(plan.order_by.len());
            for order in &plan.order_by {
                let resolved_target = match &order.target {
                    OrderTarget::Column(name) => {
                        let col_name = name.to_ascii_lowercase();
                        if let Some(&idx) = column_lookup_map.get(&col_name) {
                            OrderTarget::Index(idx)
                        } else {
                            if let Ok(idx) = combined_schema.index_of(name) {
                                OrderTarget::Index(idx)
                            } else {
                                return Err(Error::InvalidArgumentError(format!(
                                    "ORDER BY references unknown column '{}'",
                                    name
                                )));
                            }
                        }
                    }
                    other => other.clone(),
                };
                resolved_order_by.push(llkv_plan::OrderByPlan {
                    target: resolved_target,
                    sort_type: order.sort_type.clone(),
                    ascending: order.ascending,
                    nulls_first: order.nulls_first,
                });
            }

            combined_batch = sort_record_batch_with_order(
                &combined_schema,
                &combined_batch,
                &resolved_order_by,
            )?;
        }

        if !plan.projections.is_empty() {
            let mut selected_fields = Vec::new();
            let mut selected_columns = Vec::new();
            // Lazily created: only needed for computed projections.
            let mut expr_context: Option<CrossProductExpressionContext> = None;

            for proj in &plan.projections {
                match proj {
                    SelectProjection::AllColumns => {
                        // SELECT *: pass everything through unchanged.
                        selected_fields = combined_schema.fields().iter().cloned().collect();
                        selected_columns = combined_batch.columns().to_vec();
                        break;
                    }
                    SelectProjection::AllColumnsExcept { exclude } => {
                        // SELECT * EXCEPT (...): drop excluded names, matched
                        // case-insensitively by name and by lookup index.
                        let exclude_lower: Vec<String> =
                            exclude.iter().map(|e| e.to_ascii_lowercase()).collect();

                        let mut excluded_indices = FxHashSet::default();
                        for excluded_name in &exclude_lower {
                            if let Some(&idx) = column_lookup_map.get(excluded_name) {
                                excluded_indices.insert(idx);
                            }
                        }

                        for (idx, field) in combined_schema.fields().iter().enumerate() {
                            let field_name_lower = field.name().to_ascii_lowercase();
                            if !exclude_lower.contains(&field_name_lower)
                                && !excluded_indices.contains(&idx)
                            {
                                selected_fields.push(field.clone());
                                selected_columns.push(combined_batch.column(idx).clone());
                            }
                        }
                        break;
                    }
                    SelectProjection::Column { name, alias } => {
                        let col_name = name.to_ascii_lowercase();
                        if let Some(&idx) = column_lookup_map.get(&col_name) {
                            let field = combined_schema.field(idx);
                            let output_name = alias.as_ref().unwrap_or(name).clone();
                            selected_fields.push(Arc::new(arrow::datatypes::Field::new(
                                output_name,
                                field.data_type().clone(),
                                field.is_nullable(),
                            )));
                            selected_columns.push(combined_batch.column(idx).clone());
                        } else {
                            return Err(Error::InvalidArgumentError(format!(
                                "column '{}' not found in cross product result",
                                name
                            )));
                        }
                    }
                    SelectProjection::Computed { expr, alias } => {
                        if expr_context.is_none() {
                            expr_context = Some(CrossProductExpressionContext::new(
                                combined_schema.as_ref(),
                                column_lookup_map.clone(),
                            )?);
                        }
                        let context = expr_context
                            .as_mut()
                            .expect("projection context must be initialized");
                        context.reset();
                        let evaluated = self.evaluate_projection_expression(
                            context,
                            expr,
                            &combined_batch,
                            &scalar_lookup,
                        )?;
                        // Computed columns are always declared nullable.
                        let field = Arc::new(arrow::datatypes::Field::new(
                            alias.clone(),
                            evaluated.data_type().clone(),
                            true,
                        ));
                        selected_fields.push(field);
                        selected_columns.push(evaluated);
                    }
                }
            }

            let projected_schema = Arc::new(Schema::new(selected_fields));
            combined_batch = RecordBatch::try_new(projected_schema, selected_columns)
                .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))?;
        }

        if plan.distinct {
            let mut state = DistinctState::default();
            let source_schema = combined_batch.schema();
            combined_batch = match distinct_filter_batch(combined_batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(source_schema),
            };
        }

        let schema = combined_batch.schema();

        Ok(SelectExecution::new_single_batch(
            display_name,
            schema,
            combined_batch,
        ))
    }
1760}
1761
/// Result of analyzing a JOIN ... ON condition for the equi-join fast path.
struct JoinKeyBuild {
    /// Extracted equality key pairs, each oriented (left field, right field).
    keys: Vec<llkv_join::JoinKey>,
    /// The condition folded to constant TRUE (degenerates to a cross join).
    always_true: bool,
    /// The condition folded to constant FALSE/NULL (no row pair can match).
    always_false: bool,
}
1767
// NOTE(review): alias appears unused (hence the allow); presumably retained
// for source compatibility with older call sites — confirm before removing.
#[allow(dead_code)]
type JoinKeyBuildEqualities = JoinKeyBuild;
1771
impl JoinKeyBuild {
    /// Returns the extracted equality key pairs as a slice.
    #[allow(dead_code)]
    fn equalities(&self) -> &[llkv_join::JoinKey] {
        &self.keys
    }
}
1778
/// Classification of a JOIN ... ON predicate produced by
/// `analyze_join_condition`.
#[derive(Debug)]
enum JoinConditionAnalysis {
    /// Condition is constant TRUE: every row pair matches.
    AlwaysTrue,
    /// Condition is constant FALSE (or NULL): no row pair matches.
    AlwaysFalse,
    /// Conjunction of plain column equalities, as raw (lhs, rhs) name pairs;
    /// which side each name belongs to is resolved later.
    EquiPairs(Vec<(String, String)>),
}
1785
1786fn build_join_keys_from_condition<P>(
1787 condition: &LlkvExpr<'static, String>,
1788 left_ref: &llkv_plan::TableRef,
1789 left_table: &ExecutorTable<P>,
1790 right_ref: &llkv_plan::TableRef,
1791 right_table: &ExecutorTable<P>,
1792) -> ExecutorResult<JoinKeyBuild>
1793where
1794 P: Pager<Blob = EntryHandle> + Send + Sync,
1795{
1796 match analyze_join_condition(condition)? {
1797 JoinConditionAnalysis::AlwaysTrue => Ok(JoinKeyBuild {
1798 keys: Vec::new(),
1799 always_true: true,
1800 always_false: false,
1801 }),
1802 JoinConditionAnalysis::AlwaysFalse => Ok(JoinKeyBuild {
1803 keys: Vec::new(),
1804 always_true: false,
1805 always_false: true,
1806 }),
1807 JoinConditionAnalysis::EquiPairs(pairs) => {
1808 let left_lookup = build_join_column_lookup(left_ref, left_table);
1809 let right_lookup = build_join_column_lookup(right_ref, right_table);
1810
1811 let mut keys = Vec::with_capacity(pairs.len());
1812 for (lhs, rhs) in pairs {
1813 let (lhs_side, lhs_field) = resolve_join_column(&lhs, &left_lookup, &right_lookup)?;
1814 let (rhs_side, rhs_field) = resolve_join_column(&rhs, &left_lookup, &right_lookup)?;
1815
1816 match (lhs_side, rhs_side) {
1817 (JoinColumnSide::Left, JoinColumnSide::Right) => {
1818 keys.push(llkv_join::JoinKey::new(lhs_field, rhs_field));
1819 }
1820 (JoinColumnSide::Right, JoinColumnSide::Left) => {
1821 keys.push(llkv_join::JoinKey::new(rhs_field, lhs_field));
1822 }
1823 (JoinColumnSide::Left, JoinColumnSide::Left) => {
1824 return Err(Error::InvalidArgumentError(format!(
1825 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1826 left_ref.display_name(),
1827 lhs,
1828 rhs
1829 )));
1830 }
1831 (JoinColumnSide::Right, JoinColumnSide::Right) => {
1832 return Err(Error::InvalidArgumentError(format!(
1833 "JOIN condition compares two columns from '{}': '{}' and '{}'",
1834 right_ref.display_name(),
1835 lhs,
1836 rhs
1837 )));
1838 }
1839 }
1840 }
1841
1842 Ok(JoinKeyBuild {
1843 keys,
1844 always_true: false,
1845 always_false: false,
1846 })
1847 }
1848 }
1849}
1850
1851fn analyze_join_condition(
1852 expr: &LlkvExpr<'static, String>,
1853) -> ExecutorResult<JoinConditionAnalysis> {
1854 match evaluate_constant_join_expr(expr) {
1855 ConstantJoinEvaluation::Known(true) => {
1856 return Ok(JoinConditionAnalysis::AlwaysTrue);
1857 }
1858 ConstantJoinEvaluation::Known(false) | ConstantJoinEvaluation::Unknown => {
1859 return Ok(JoinConditionAnalysis::AlwaysFalse);
1860 }
1861 ConstantJoinEvaluation::NotConstant => {}
1862 }
1863 match expr {
1864 LlkvExpr::Literal(value) => {
1865 if *value {
1866 Ok(JoinConditionAnalysis::AlwaysTrue)
1867 } else {
1868 Ok(JoinConditionAnalysis::AlwaysFalse)
1869 }
1870 }
1871 LlkvExpr::And(children) => {
1872 let mut collected: Vec<(String, String)> = Vec::new();
1873 for child in children {
1874 match analyze_join_condition(child)? {
1875 JoinConditionAnalysis::AlwaysTrue => {}
1876 JoinConditionAnalysis::AlwaysFalse => {
1877 return Ok(JoinConditionAnalysis::AlwaysFalse);
1878 }
1879 JoinConditionAnalysis::EquiPairs(mut pairs) => {
1880 collected.append(&mut pairs);
1881 }
1882 }
1883 }
1884
1885 if collected.is_empty() {
1886 Ok(JoinConditionAnalysis::AlwaysTrue)
1887 } else {
1888 Ok(JoinConditionAnalysis::EquiPairs(collected))
1889 }
1890 }
1891 LlkvExpr::Compare { left, op, right } => {
1892 if *op != CompareOp::Eq {
1893 return Err(Error::InvalidArgumentError(
1894 "JOIN ON clause only supports '=' comparisons in optimized path".into(),
1895 ));
1896 }
1897 let left_name = try_extract_simple_column(left).ok_or_else(|| {
1898 Error::InvalidArgumentError(
1899 "JOIN ON clause requires plain column references".into(),
1900 )
1901 })?;
1902 let right_name = try_extract_simple_column(right).ok_or_else(|| {
1903 Error::InvalidArgumentError(
1904 "JOIN ON clause requires plain column references".into(),
1905 )
1906 })?;
1907 Ok(JoinConditionAnalysis::EquiPairs(vec![(
1908 left_name.to_string(),
1909 right_name.to_string(),
1910 )]))
1911 }
1912 _ => Err(Error::InvalidArgumentError(
1913 "JOIN ON expressions must be conjunctions of column equality predicates".into(),
1914 )),
1915 }
1916}
1917
1918fn compare_literals_with_mode(
1919 op: CompareOp,
1920 left: &Literal,
1921 right: &Literal,
1922 null_behavior: NullComparisonBehavior,
1923) -> Option<bool> {
1924 use std::cmp::Ordering;
1925
1926 fn ordering_result(ord: Ordering, op: CompareOp) -> bool {
1927 match op {
1928 CompareOp::Eq => ord == Ordering::Equal,
1929 CompareOp::NotEq => ord != Ordering::Equal,
1930 CompareOp::Lt => ord == Ordering::Less,
1931 CompareOp::LtEq => ord != Ordering::Greater,
1932 CompareOp::Gt => ord == Ordering::Greater,
1933 CompareOp::GtEq => ord != Ordering::Less,
1934 }
1935 }
1936
1937 fn compare_f64(lhs: f64, rhs: f64, op: CompareOp) -> bool {
1938 match op {
1939 CompareOp::Eq => lhs == rhs,
1940 CompareOp::NotEq => lhs != rhs,
1941 CompareOp::Lt => lhs < rhs,
1942 CompareOp::LtEq => lhs <= rhs,
1943 CompareOp::Gt => lhs > rhs,
1944 CompareOp::GtEq => lhs >= rhs,
1945 }
1946 }
1947
1948 match (left, right) {
1949 (Literal::Null, _) | (_, Literal::Null) => match null_behavior {
1950 NullComparisonBehavior::ThreeValuedLogic => None,
1951 },
1952 (Literal::Int128(lhs), Literal::Int128(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1953 (Literal::Float64(lhs), Literal::Float64(rhs)) => Some(compare_f64(*lhs, *rhs, op)),
1954 (Literal::Int128(lhs), Literal::Float64(rhs)) => Some(compare_f64(*lhs as f64, *rhs, op)),
1955 (Literal::Float64(lhs), Literal::Int128(rhs)) => Some(compare_f64(*lhs, *rhs as f64, op)),
1956 (Literal::Boolean(lhs), Literal::Boolean(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1957 (Literal::String(lhs), Literal::String(rhs)) => Some(ordering_result(lhs.cmp(rhs), op)),
1958 (Literal::Decimal128(lhs), Literal::Decimal128(rhs)) => {
1959 llkv_compute::scalar::decimal::compare(*lhs, *rhs)
1960 .ok()
1961 .map(|ord| ordering_result(ord, op))
1962 }
1963 (Literal::Decimal128(lhs), Literal::Int128(rhs)) => {
1964 DecimalValue::new(*rhs, 0).ok().and_then(|rhs_dec| {
1965 llkv_compute::scalar::decimal::compare(*lhs, rhs_dec)
1966 .ok()
1967 .map(|ord| ordering_result(ord, op))
1968 })
1969 }
1970 (Literal::Int128(lhs), Literal::Decimal128(rhs)) => {
1971 DecimalValue::new(*lhs, 0).ok().and_then(|lhs_dec| {
1972 llkv_compute::scalar::decimal::compare(lhs_dec, *rhs)
1973 .ok()
1974 .map(|ord| ordering_result(ord, op))
1975 })
1976 }
1977 (Literal::Decimal128(lhs), Literal::Float64(rhs)) => {
1978 Some(compare_f64(lhs.to_f64(), *rhs, op))
1979 }
1980 (Literal::Float64(lhs), Literal::Decimal128(rhs)) => {
1981 Some(compare_f64(*lhs, rhs.to_f64(), op))
1982 }
1983 (Literal::Struct(_), _) | (_, Literal::Struct(_)) => None,
1984 _ => None,
1985 }
1986}
1987
1988fn build_no_match_join_batches<P>(
1989 join_type: llkv_join::JoinType,
1990 left_ref: &llkv_plan::TableRef,
1991 left_table: &ExecutorTable<P>,
1992 right_ref: &llkv_plan::TableRef,
1993 right_table: &ExecutorTable<P>,
1994 combined_schema: Arc<Schema>,
1995) -> ExecutorResult<Vec<RecordBatch>>
1996where
1997 P: Pager<Blob = EntryHandle> + Send + Sync,
1998{
1999 match join_type {
2000 llkv_join::JoinType::Inner => Ok(Vec::new()),
2001 llkv_join::JoinType::Left => {
2002 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
2003 let mut results = Vec::new();
2004
2005 for left_batch in left_batches {
2006 let row_count = left_batch.num_rows();
2007 if row_count == 0 {
2008 continue;
2009 }
2010
2011 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2012 columns.extend(left_batch.columns().iter().cloned());
2013 for column in &right_table.schema.columns {
2014 columns.push(new_null_array(&column.data_type, row_count));
2015 }
2016
2017 let batch =
2018 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2019 Error::Internal(format!("failed to build LEFT JOIN fallback batch: {err}"))
2020 })?;
2021 results.push(batch);
2022 }
2023
2024 Ok(results)
2025 }
2026 llkv_join::JoinType::Right => {
2027 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
2028 let mut results = Vec::new();
2029
2030 for right_batch in right_batches {
2031 let row_count = right_batch.num_rows();
2032 if row_count == 0 {
2033 continue;
2034 }
2035
2036 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2037 for column in &left_table.schema.columns {
2038 columns.push(new_null_array(&column.data_type, row_count));
2039 }
2040 columns.extend(right_batch.columns().iter().cloned());
2041
2042 let batch =
2043 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2044 Error::Internal(format!("failed to build RIGHT JOIN fallback batch: {err}"))
2045 })?;
2046 results.push(batch);
2047 }
2048
2049 Ok(results)
2050 }
2051 llkv_join::JoinType::Full => {
2052 let mut results = Vec::new();
2053
2054 let left_batches = scan_all_columns_for_join(left_ref, left_table)?;
2055 for left_batch in left_batches {
2056 let row_count = left_batch.num_rows();
2057 if row_count == 0 {
2058 continue;
2059 }
2060
2061 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2062 columns.extend(left_batch.columns().iter().cloned());
2063 for column in &right_table.schema.columns {
2064 columns.push(new_null_array(&column.data_type, row_count));
2065 }
2066
2067 let batch =
2068 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2069 Error::Internal(format!(
2070 "failed to build FULL JOIN left fallback batch: {err}"
2071 ))
2072 })?;
2073 results.push(batch);
2074 }
2075
2076 let right_batches = scan_all_columns_for_join(right_ref, right_table)?;
2077 for right_batch in right_batches {
2078 let row_count = right_batch.num_rows();
2079 if row_count == 0 {
2080 continue;
2081 }
2082
2083 let mut columns = Vec::with_capacity(combined_schema.fields().len());
2084 for column in &left_table.schema.columns {
2085 columns.push(new_null_array(&column.data_type, row_count));
2086 }
2087 columns.extend(right_batch.columns().iter().cloned());
2088
2089 let batch =
2090 RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(|err| {
2091 Error::Internal(format!(
2092 "failed to build FULL JOIN right fallback batch: {err}"
2093 ))
2094 })?;
2095 results.push(batch);
2096 }
2097
2098 Ok(results)
2099 }
2100 other => Err(Error::InvalidArgumentError(format!(
2101 "{other:?} join type is not supported when join predicate is unsatisfiable",
2102 ))),
2103 }
2104}
2105
2106fn scan_all_columns_for_join<P>(
2107 table_ref: &llkv_plan::TableRef,
2108 table: &ExecutorTable<P>,
2109) -> ExecutorResult<Vec<RecordBatch>>
2110where
2111 P: Pager<Blob = EntryHandle> + Send + Sync,
2112{
2113 if table.schema.columns.is_empty() {
2114 return Err(Error::InvalidArgumentError(format!(
2115 "table '{}' has no columns; joins require at least one column",
2116 table_ref.qualified_name()
2117 )));
2118 }
2119
2120 let mut projections = Vec::with_capacity(table.schema.columns.len());
2121 for column in &table.schema.columns {
2122 projections.push(ScanProjection::from(StoreProjection::with_alias(
2123 LogicalFieldId::for_user(table.table.table_id(), column.field_id),
2124 column.name.clone(),
2125 )));
2126 }
2127
2128 let filter_field = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);
2129 let filter_expr = full_table_scan_filter(filter_field);
2130
2131 let mut batches = Vec::new();
2132 table.table.scan_stream(
2133 projections,
2134 &filter_expr,
2135 ScanStreamOptions {
2136 include_nulls: true,
2137 ..ScanStreamOptions::default()
2138 },
2139 |batch| {
2140 batches.push(batch);
2141 },
2142 )?;
2143
2144 Ok(batches)
2145}
2146
2147fn build_join_column_lookup<P>(
2148 table_ref: &llkv_plan::TableRef,
2149 table: &ExecutorTable<P>,
2150) -> FxHashMap<String, FieldId>
2151where
2152 P: Pager<Blob = EntryHandle> + Send + Sync,
2153{
2154 let mut lookup = FxHashMap::default();
2155 let table_lower = table_ref.table.to_ascii_lowercase();
2156 let qualified_lower = table_ref.qualified_name().to_ascii_lowercase();
2157 let display_lower = table_ref.display_name().to_ascii_lowercase();
2158 let alias_lower = table_ref.alias.as_ref().map(|s| s.to_ascii_lowercase());
2159 let schema_lower = if table_ref.schema.is_empty() {
2160 None
2161 } else {
2162 Some(table_ref.schema.to_ascii_lowercase())
2163 };
2164
2165 for column in &table.schema.columns {
2166 let base = column.name.to_ascii_lowercase();
2167 let short = base.rsplit('.').next().unwrap_or(base.as_str()).to_string();
2168
2169 lookup.entry(short.clone()).or_insert(column.field_id);
2170 lookup.entry(base.clone()).or_insert(column.field_id);
2171
2172 lookup
2173 .entry(format!("{table_lower}.{short}"))
2174 .or_insert(column.field_id);
2175
2176 if display_lower != table_lower {
2177 lookup
2178 .entry(format!("{display_lower}.{short}"))
2179 .or_insert(column.field_id);
2180 }
2181
2182 if qualified_lower != table_lower {
2183 lookup
2184 .entry(format!("{qualified_lower}.{short}"))
2185 .or_insert(column.field_id);
2186 }
2187
2188 if let Some(schema) = &schema_lower {
2189 lookup
2190 .entry(format!("{schema}.{table_lower}.{short}"))
2191 .or_insert(column.field_id);
2192 if display_lower != table_lower {
2193 lookup
2194 .entry(format!("{schema}.{display_lower}.{short}"))
2195 .or_insert(column.field_id);
2196 }
2197 }
2198
2199 if let Some(alias) = &alias_lower {
2200 lookup
2201 .entry(format!("{alias}.{short}"))
2202 .or_insert(column.field_id);
2203 }
2204 }
2205
2206 lookup
2207}
2208
/// Which input of a two-table join a resolved column name belongs to.
#[derive(Clone, Copy)]
enum JoinColumnSide {
    /// Column resolves within the left table's lookup.
    Left,
    /// Column resolves within the right table's lookup.
    Right,
}
2214
2215fn resolve_join_column(
2216 column: &str,
2217 left_lookup: &FxHashMap<String, FieldId>,
2218 right_lookup: &FxHashMap<String, FieldId>,
2219) -> ExecutorResult<(JoinColumnSide, FieldId)> {
2220 let key = column.to_ascii_lowercase();
2221 match (left_lookup.get(&key), right_lookup.get(&key)) {
2222 (Some(&field_id), None) => Ok((JoinColumnSide::Left, field_id)),
2223 (None, Some(&field_id)) => Ok((JoinColumnSide::Right, field_id)),
2224 (Some(_), Some(_)) => Err(Error::InvalidArgumentError(format!(
2225 "join column '{column}' is ambiguous; qualify it with a table name or alias",
2226 ))),
2227 (None, None) => Err(Error::InvalidArgumentError(format!(
2228 "join column '{column}' was not found in either table",
2229 ))),
2230 }
2231}
2232
/// Executes a join between two sets of in-memory record batches.
///
/// The join condition is first classified by `analyze_join_condition`:
/// - `AlwaysTrue`  -> degrade to a full cross product.
/// - `AlwaysFalse` -> emit only the null-extension rows required by the
///   outer join type (or an empty batch for inner/semi/anti).
/// - `EquiPairs`   -> run a hash join: build a hash table over the right
///   side, probe with the left side.
///
/// The output schema is left fields followed by right fields. In the hash
/// path all output fields are marked nullable because outer joins can
/// null-extend either side.
fn execute_hash_join_batches(
    left_schema: &Arc<Schema>,
    left_batches: &[RecordBatch],
    right_schema: &Arc<Schema>,
    right_batches: &[RecordBatch],
    condition: &LlkvExpr<'static, String>,
    join_type: llkv_join::JoinType,
) -> ExecutorResult<Vec<RecordBatch>> {
    let equalities = match analyze_join_condition(condition)? {
        // Condition is a tautology: every pair of rows matches, so the
        // result is the Cartesian product of the two batch sets.
        JoinConditionAnalysis::AlwaysTrue => {
            let mut results = Vec::new();
            for left in left_batches {
                for right in right_batches {
                    results.push(execute_cross_join_batches(left, right)?);
                }
            }
            return Ok(results);
        }
        // Condition can never match: only outer-join null extension (if any)
        // survives. Nullability here preserves each field's original flag
        // because matched rows never appear.
        JoinConditionAnalysis::AlwaysFalse => {
            let combined_fields: Vec<Field> = left_schema
                .fields()
                .iter()
                .chain(right_schema.fields().iter())
                .map(|f| Field::new(f.name().clone(), f.data_type().clone(), f.is_nullable()))
                .collect();
            let combined_schema = Arc::new(Schema::new(combined_fields));

            let mut results = Vec::new();
            match join_type {
                // No row can match, so these join types produce no rows.
                llkv_join::JoinType::Inner
                | llkv_join::JoinType::Semi
                | llkv_join::JoinType::Anti => {
                    results.push(RecordBatch::new_empty(combined_schema));
                }
                // LEFT: every left row survives with NULLs for right fields.
                llkv_join::JoinType::Left => {
                    for left_batch in left_batches {
                        let row_count = left_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        columns.extend(left_batch.columns().iter().cloned());
                        for field in right_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns)
                                .map_err(|err| {
                                    Error::Internal(format!(
                                        "failed to materialize LEFT JOIN null-extension batch: {err}"
                                    ))
                                })?,
                        );
                    }
                    if results.is_empty() {
                        results.push(RecordBatch::new_empty(combined_schema));
                    }
                }
                // RIGHT: mirror image — every right row survives with NULLs
                // for left fields.
                llkv_join::JoinType::Right => {
                    for right_batch in right_batches {
                        let row_count = right_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        for field in left_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        columns.extend(right_batch.columns().iter().cloned());
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns)
                                .map_err(|err| {
                                    Error::Internal(format!(
                                        "failed to materialize RIGHT JOIN null-extension batch: {err}"
                                    ))
                                })?,
                        );
                    }
                    if results.is_empty() {
                        results.push(RecordBatch::new_empty(combined_schema));
                    }
                }
                // FULL: both the LEFT and RIGHT null-extension sets.
                llkv_join::JoinType::Full => {
                    for left_batch in left_batches {
                        let row_count = left_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        columns.extend(left_batch.columns().iter().cloned());
                        for field in right_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(
                                |err| {
                                    Error::Internal(format!(
                                        "failed to materialize FULL JOIN left batch: {err}"
                                    ))
                                },
                            )?,
                        );
                    }

                    for right_batch in right_batches {
                        let row_count = right_batch.num_rows();
                        if row_count == 0 {
                            continue;
                        }
                        let mut columns = Vec::with_capacity(combined_schema.fields().len());
                        for field in left_schema.fields() {
                            columns.push(new_null_array(field.data_type(), row_count));
                        }
                        columns.extend(right_batch.columns().iter().cloned());
                        results.push(
                            RecordBatch::try_new(Arc::clone(&combined_schema), columns).map_err(
                                |err| {
                                    Error::Internal(format!(
                                        "failed to materialize FULL JOIN right batch: {err}"
                                    ))
                                },
                            )?,
                        );
                    }

                    if results.is_empty() {
                        results.push(RecordBatch::new_empty(combined_schema));
                    }
                }
            }

            return Ok(results);
        }
        JoinConditionAnalysis::EquiPairs(pairs) => pairs,
    };

    // Hash-join path. First map lowercase field names to column indices on
    // each side so equi-pair column references can be resolved.
    let mut left_lookup: FxHashMap<String, usize> = FxHashMap::default();
    for (idx, field) in left_schema.fields().iter().enumerate() {
        left_lookup.insert(field.name().to_ascii_lowercase(), idx);
    }

    let mut right_lookup: FxHashMap<String, usize> = FxHashMap::default();
    for (idx, field) in right_schema.fields().iter().enumerate() {
        right_lookup.insert(field.name().to_ascii_lowercase(), idx);
    }

    // Resolve each equality's columns to (left index, right index). If a
    // pair resolves swapped (lhs in right table, rhs in left), flip it.
    let mut left_key_indices = Vec::new();
    let mut right_key_indices = Vec::new();

    for (lhs_col, rhs_col) in equalities {
        let lhs_lower = lhs_col.to_ascii_lowercase();
        let rhs_lower = rhs_col.to_ascii_lowercase();

        let (left_idx, right_idx) =
            match (left_lookup.get(&lhs_lower), right_lookup.get(&rhs_lower)) {
                (Some(&l), Some(&r)) => (l, r),
                (Some(_), None) => {
                    if left_lookup.contains_key(&rhs_lower) {
                        return Err(Error::InvalidArgumentError(format!(
                            "Both join columns '{}' and '{}' are from left table",
                            lhs_col, rhs_col
                        )));
                    }
                    return Err(Error::InvalidArgumentError(format!(
                        "Join column '{}' not found in right table",
                        rhs_col
                    )));
                }
                (None, Some(_)) => {
                    if right_lookup.contains_key(&lhs_lower) {
                        return Err(Error::InvalidArgumentError(format!(
                            "Both join columns '{}' and '{}' are from right table",
                            lhs_col, rhs_col
                        )));
                    }
                    return Err(Error::InvalidArgumentError(format!(
                        "Join column '{}' not found in left table",
                        lhs_col
                    )));
                }
                (None, None) => {
                    // Try the swapped orientation before giving up.
                    match (left_lookup.get(&rhs_lower), right_lookup.get(&lhs_lower)) {
                        (Some(&l), Some(&r)) => (l, r),
                        _ => {
                            return Err(Error::InvalidArgumentError(format!(
                                "Join columns '{}' and '{}' not found in either table",
                                lhs_col, rhs_col
                            )));
                        }
                    }
                }
            };

        left_key_indices.push(left_idx);
        right_key_indices.push(right_idx);
    }

    // Build phase: hash every right-side row by its normalized key values,
    // mapping key -> list of (batch index, row index). Rows with any NULL
    // key component never match and are skipped (SQL equality semantics).
    // NOTE(review): string keys are reduced to i64 hashes by
    // `extract_key_value_as_i64`; a hash collision between distinct strings
    // would be treated as a key match — confirm acceptable for this path.
    let mut hash_table: FxHashMap<Vec<i64>, Vec<(usize, usize)>> = FxHashMap::default();

    for (batch_idx, right_batch) in right_batches.iter().enumerate() {
        let num_rows = right_batch.num_rows();
        if num_rows == 0 {
            continue;
        }

        let key_columns: Vec<&ArrayRef> = right_key_indices
            .iter()
            .map(|&idx| right_batch.column(idx))
            .collect();

        for row_idx in 0..num_rows {
            let mut key_values = Vec::with_capacity(key_columns.len());
            let mut has_null = false;

            for col in &key_columns {
                if col.is_null(row_idx) {
                    has_null = true;
                    break;
                }
                let value = extract_key_value_as_i64(col, row_idx)?;
                key_values.push(value);
            }

            if has_null {
                continue;
            }

            hash_table
                .entry(key_values)
                .or_default()
                .push((batch_idx, row_idx));
        }
    }

    // All hash-join output fields are nullable: outer joins may null-extend
    // either side regardless of the source field's original nullability.
    let mut result_batches = Vec::new();
    let combined_fields: Vec<Field> = left_schema
        .fields()
        .iter()
        .chain(right_schema.fields().iter())
        .map(|f| Field::new(f.name().clone(), f.data_type().clone(), true)) .collect();
    let combined_schema = Arc::new(Schema::new(combined_fields));

    // Probe phase: for each left row, look up its key and record every
    // matching (right batch, right row) pair.
    for left_batch in left_batches {
        let num_rows = left_batch.num_rows();
        if num_rows == 0 {
            continue;
        }

        let left_key_columns: Vec<&ArrayRef> = left_key_indices
            .iter()
            .map(|&idx| left_batch.column(idx))
            .collect();

        // Tracks which left rows matched so LEFT JOIN can null-extend the rest.
        let mut left_matched = vec![false; num_rows];

        // Parallel arrays: output row i joins left row left_indices[i] with
        // right row right_refs[i].
        let mut left_indices = Vec::new();
        let mut right_refs = Vec::new();

        for (left_row_idx, matched) in left_matched.iter_mut().enumerate() {
            let mut key_values = Vec::with_capacity(left_key_columns.len());
            let mut has_null = false;

            for col in &left_key_columns {
                if col.is_null(left_row_idx) {
                    has_null = true;
                    break;
                }
                let value = extract_key_value_as_i64(col, left_row_idx)?;
                key_values.push(value);
            }

            if has_null {
                continue;
            }

            if let Some(right_rows) = hash_table.get(&key_values) {
                *matched = true;
                for &(right_batch_idx, right_row_idx) in right_rows {
                    left_indices.push(left_row_idx as u32);
                    right_refs.push((right_batch_idx, right_row_idx));
                }
            }
        }

        // LEFT JOIN always materializes (to null-extend unmatched rows);
        // other types only when at least one match pair exists.
        if !left_indices.is_empty() || join_type == llkv_join::JoinType::Left {
            let output_batch = build_join_output_batch(
                left_batch,
                right_batches,
                &left_indices,
                &right_refs,
                &left_matched,
                &combined_schema,
                join_type,
            )?;

            if output_batch.num_rows() > 0 {
                result_batches.push(output_batch);
            }
        }
    }

    if result_batches.is_empty() {
        result_batches.push(RecordBatch::new_empty(combined_schema));
    }

    Ok(result_batches)
}
2572
2573fn extract_key_value_as_i64(col: &ArrayRef, row_idx: usize) -> ExecutorResult<i64> {
2575 use arrow::array::*;
2576 use arrow::datatypes::DataType;
2577
2578 match col.data_type() {
2579 DataType::Int8 => Ok(col
2580 .as_any()
2581 .downcast_ref::<Int8Array>()
2582 .unwrap()
2583 .value(row_idx) as i64),
2584 DataType::Int16 => Ok(col
2585 .as_any()
2586 .downcast_ref::<Int16Array>()
2587 .unwrap()
2588 .value(row_idx) as i64),
2589 DataType::Int32 => Ok(col
2590 .as_any()
2591 .downcast_ref::<Int32Array>()
2592 .unwrap()
2593 .value(row_idx) as i64),
2594 DataType::Int64 => Ok(col
2595 .as_any()
2596 .downcast_ref::<Int64Array>()
2597 .unwrap()
2598 .value(row_idx)),
2599 DataType::UInt8 => Ok(col
2600 .as_any()
2601 .downcast_ref::<UInt8Array>()
2602 .unwrap()
2603 .value(row_idx) as i64),
2604 DataType::UInt16 => Ok(col
2605 .as_any()
2606 .downcast_ref::<UInt16Array>()
2607 .unwrap()
2608 .value(row_idx) as i64),
2609 DataType::UInt32 => Ok(col
2610 .as_any()
2611 .downcast_ref::<UInt32Array>()
2612 .unwrap()
2613 .value(row_idx) as i64),
2614 DataType::UInt64 => {
2615 let val = col
2616 .as_any()
2617 .downcast_ref::<UInt64Array>()
2618 .unwrap()
2619 .value(row_idx);
2620 Ok(val as i64) }
2622 DataType::Utf8 => {
2623 let s = col
2625 .as_any()
2626 .downcast_ref::<StringArray>()
2627 .unwrap()
2628 .value(row_idx);
2629 use std::collections::hash_map::DefaultHasher;
2630 use std::hash::{Hash, Hasher};
2631 let mut hasher = DefaultHasher::new();
2632 s.hash(&mut hasher);
2633 Ok(hasher.finish() as i64)
2634 }
2635 _ => Err(Error::InvalidArgumentError(format!(
2636 "Unsupported join key type: {:?}",
2637 col.data_type()
2638 ))),
2639 }
2640}
2641
/// Materializes one output batch for the probe side of the hash join.
///
/// `left_indices` and `right_refs` are parallel arrays: output row `i` pairs
/// left row `left_indices[i]` with right row `right_refs[i]` (a
/// `(batch, row)` reference). `left_matched[r]` records whether left row `r`
/// found at least one match.
///
/// Only Inner and Left are implemented; other types error.
fn build_join_output_batch(
    left_batch: &RecordBatch,
    right_batches: &[RecordBatch],
    left_indices: &[u32],
    right_refs: &[(usize, usize)],
    left_matched: &[bool],
    combined_schema: &Arc<Schema>,
    join_type: llkv_join::JoinType,
) -> ExecutorResult<RecordBatch> {
    use arrow::array::UInt32Array;
    use arrow::compute::take;

    match join_type {
        llkv_join::JoinType::Inner => {
            // Gather left columns by match order, then right columns from
            // their (possibly multiple) source batches.
            let left_indices_array = UInt32Array::from(left_indices.to_vec());

            let mut output_columns = Vec::new();

            for col in left_batch.columns() {
                let taken = take(col.as_ref(), &left_indices_array, None)
                    .map_err(|e| Error::Internal(format!("Failed to take left column: {}", e)))?;
                output_columns.push(taken);
            }

            for right_col_idx in 0..right_batches[0].num_columns() {
                // One (source array, row) pair per output row, in order.
                let mut values = Vec::with_capacity(right_refs.len());
                for &(batch_idx, row_idx) in right_refs {
                    let col = right_batches[batch_idx].column(right_col_idx);
                    values.push((col.clone(), row_idx));
                }

                let right_col = gather_from_multiple_batches(
                    &values,
                    right_batches[0].column(right_col_idx).data_type(),
                )?;
                output_columns.push(right_col);
            }

            RecordBatch::try_new(Arc::clone(combined_schema), output_columns)
                .map_err(|e| Error::Internal(format!("Failed to create output batch: {}", e)))
        }
        llkv_join::JoinType::Left => {
            // NOTE(review): this branch emits exactly one output row per
            // left row (left columns passed through unchanged) and
            // `build_left_join_column` currently returns all-NULL right
            // columns — so matched left rows also get NULL right values and
            // multiple matches are never expanded. Looks like an incomplete
            // implementation; confirm whether LEFT JOIN with real matches
            // ever reaches this path.
            let mut output_columns = Vec::new();

            for col in left_batch.columns() {
                output_columns.push(col.clone());
            }

            for right_col_idx in 0..right_batches[0].num_columns() {
                let right_col = build_left_join_column(
                    left_matched,
                    right_batches,
                    right_col_idx,
                    left_indices,
                    right_refs,
                )?;
                output_columns.push(right_col);
            }

            RecordBatch::try_new(Arc::clone(combined_schema), output_columns)
                .map_err(|e| Error::Internal(format!("Failed to create left join batch: {}", e)))
        }
        _ => Err(Error::InvalidArgumentError(format!(
            "{:?} join not yet implemented in batch join",
            join_type
        ))),
    }
}
2718
2719fn gather_from_multiple_batches(
2724 values: &[(ArrayRef, usize)],
2725 _data_type: &DataType,
2726) -> ExecutorResult<ArrayRef> {
2727 use arrow::array::*;
2728 use arrow::compute::take;
2729
2730 if values.is_empty() {
2731 return Ok(new_null_array(&DataType::Null, 0));
2732 }
2733
2734 if values.len() > 1 {
2736 let first_array_ptr = Arc::as_ptr(&values[0].0);
2737 let all_same_array = values
2738 .iter()
2739 .all(|(arr, _)| std::ptr::addr_eq(Arc::as_ptr(arr), first_array_ptr));
2740
2741 if all_same_array {
2742 let indices: Vec<u32> = values.iter().map(|(_, idx)| *idx as u32).collect();
2745 let indices_array = UInt32Array::from(indices);
2746 return take(values[0].0.as_ref(), &indices_array, None)
2747 .map_err(|e| Error::Internal(format!("Arrow take failed: {}", e)));
2748 }
2749 }
2750
2751 use arrow::compute::concat;
2754
2755 let mut unique_arrays: Vec<(Arc<dyn Array>, Vec<usize>)> = Vec::new();
2757 let mut array_map: FxHashMap<*const dyn Array, usize> = FxHashMap::default();
2758
2759 for (arr, row_idx) in values {
2760 let ptr = Arc::as_ptr(arr);
2761 if let Some(&idx) = array_map.get(&ptr) {
2762 unique_arrays[idx].1.push(*row_idx);
2763 } else {
2764 let idx = unique_arrays.len();
2765 array_map.insert(ptr, idx);
2766 unique_arrays.push((Arc::clone(arr), vec![*row_idx]));
2767 }
2768 }
2769
2770 if unique_arrays.len() == 1 {
2772 let (arr, indices) = &unique_arrays[0];
2773 let indices_u32: Vec<u32> = indices.iter().map(|&i| i as u32).collect();
2774 let indices_array = UInt32Array::from(indices_u32);
2775 return take(arr.as_ref(), &indices_array, None)
2776 .map_err(|e| Error::Internal(format!("Arrow take failed: {}", e)));
2777 }
2778
2779 let arrays_to_concat: Vec<&dyn Array> =
2781 unique_arrays.iter().map(|(arr, _)| arr.as_ref()).collect();
2782
2783 let concatenated = concat(&arrays_to_concat)
2784 .map_err(|e| Error::Internal(format!("Arrow concat failed: {}", e)))?;
2785
2786 let mut offset = 0;
2788 let mut adjusted_indices = Vec::with_capacity(values.len());
2789 for (arr, _) in &unique_arrays {
2790 let arr_len = arr.len();
2791 for (check_arr, row_idx) in values {
2792 if Arc::ptr_eq(arr, check_arr) {
2793 adjusted_indices.push((offset + row_idx) as u32);
2794 }
2795 }
2796 offset += arr_len;
2797 }
2798
2799 let indices_array = UInt32Array::from(adjusted_indices);
2800 take(&concatenated, &indices_array, None)
2801 .map_err(|e| Error::Internal(format!("Arrow take on concatenated failed: {}", e)))
2802}
2803
/// Produces the right-side column for the LEFT JOIN output path.
///
/// NOTE(review): this is a placeholder — it ignores `_left_indices` and
/// `_right_refs` entirely and returns an all-NULL column sized to the left
/// batch, so even matched left rows receive NULL right-side values. A
/// faithful LEFT JOIN would gather the matched right values here. Confirm
/// whether this path is reachable for queries with real matches before
/// relying on its output.
fn build_left_join_column(
    left_matched: &[bool],
    right_batches: &[RecordBatch],
    right_col_idx: usize,
    _left_indices: &[u32],
    _right_refs: &[(usize, usize)],
) -> ExecutorResult<ArrayRef> {
    // Data type comes from the first right batch; length matches the left
    // batch (one output row per left row).
    let data_type = right_batches[0].column(right_col_idx).data_type();
    Ok(new_null_array(data_type, left_matched.len()))
}
2817
2818fn execute_cross_join_batches(
2820 left: &RecordBatch,
2821 right: &RecordBatch,
2822) -> ExecutorResult<RecordBatch> {
2823 let combined_fields: Vec<Field> = left
2824 .schema()
2825 .fields()
2826 .iter()
2827 .chain(right.schema().fields().iter())
2828 .map(|f| Field::new(f.name().clone(), f.data_type().clone(), f.is_nullable()))
2829 .collect();
2830 let combined_schema = Arc::new(Schema::new(combined_fields));
2831
2832 cross_join_pair(left, right, &combined_schema)
2833}
2834
2835#[allow(dead_code)]
2837fn build_temp_table_from_batches<P>(
2838 _schema: &Arc<Schema>,
2839 _batches: &[RecordBatch],
2840) -> ExecutorResult<llkv_table::Table<P>>
2841where
2842 P: Pager<Blob = EntryHandle> + Send + Sync,
2843{
2844 Err(Error::Internal(
2846 "build_temp_table_from_batches should not be called".into(),
2847 ))
2848}
2849
2850#[allow(dead_code)]
2852fn build_join_keys_from_condition_indexed(
2853 _condition: &LlkvExpr<'static, String>,
2854 _left_data: &TableCrossProductData,
2855 _right_data: &TableCrossProductData,
2856 _left_idx: usize,
2857 _right_idx: usize,
2858) -> ExecutorResult<JoinKeyBuild> {
2859 Err(Error::Internal(
2861 "build_join_keys_from_condition_indexed should not be called".into(),
2862 ))
2863}
2864
/// Unit coverage for `analyze_join_condition` classification of join
/// predicates into AlwaysTrue / AlwaysFalse / EquiPairs.
#[cfg(test)]
mod join_condition_tests {
    use super::*;
    use llkv_expr::expr::{CompareOp, ScalarExpr};
    use llkv_expr::literal::Literal;

    /// Runs the analyzer, panicking if classification fails.
    fn analyze(expr: &LlkvExpr<'static, String>) -> JoinConditionAnalysis {
        analyze_join_condition(expr).expect("analysis succeeds")
    }

    /// Builds a column-vs-column comparison expression.
    fn column_compare(lhs: &str, op: CompareOp, rhs: &str) -> LlkvExpr<'static, String> {
        LlkvExpr::Compare {
            left: ScalarExpr::Column(lhs.into()),
            op,
            right: ScalarExpr::Column(rhs.into()),
        }
    }

    #[test]
    fn analyze_detects_simple_equality() {
        let condition = column_compare("t1.col", CompareOp::Eq, "t2.col");

        match analyze(&condition) {
            JoinConditionAnalysis::EquiPairs(pairs) => {
                assert_eq!(pairs, vec![("t1.col".to_string(), "t2.col".to_string())]);
            }
            other => panic!("unexpected analysis result: {other:?}"),
        }
    }

    #[test]
    fn analyze_handles_literal_true() {
        let condition = LlkvExpr::Literal(true);
        assert!(matches!(
            analyze(&condition),
            JoinConditionAnalysis::AlwaysTrue
        ));
    }

    #[test]
    fn analyze_rejects_non_equality() {
        // Anything other than `=` between columns is unsupported.
        let condition = column_compare("t1.col", CompareOp::Gt, "t2.col");
        assert!(analyze_join_condition(&condition).is_err());
    }

    #[test]
    fn analyze_handles_constant_is_not_null() {
        // NULL IS NOT NULL is constant-false.
        let condition = LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Null),
            negated: true,
        };

        assert!(matches!(
            analyze(&condition),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    #[test]
    fn analyze_handles_not_applied_to_is_not_null() {
        // NOT (86 IS NOT NULL) folds to constant-false.
        let condition = LlkvExpr::Not(Box::new(LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Int128(86)),
            negated: true,
        }));

        assert!(matches!(
            analyze(&condition),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    #[test]
    fn analyze_literal_is_null_is_always_false() {
        // 1 IS NULL is constant-false.
        let condition = LlkvExpr::IsNull {
            expr: ScalarExpr::Literal(Literal::Int128(1)),
            negated: false,
        };

        assert!(matches!(
            analyze(&condition),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }

    #[test]
    fn analyze_not_null_comparison_is_always_false() {
        // NOT (NULL < col): NULL comparisons are never true, and negating an
        // unknown stays non-matching.
        let condition = LlkvExpr::Not(Box::new(LlkvExpr::Compare {
            left: ScalarExpr::Literal(Literal::Null),
            op: CompareOp::Lt,
            right: ScalarExpr::Column("t2.col".into()),
        }));

        assert!(matches!(
            analyze(&condition),
            JoinConditionAnalysis::AlwaysFalse
        ));
    }
}
2959
/// Verifies `execute_cross_join_batches` produces the full Cartesian product
/// with left-major row ordering.
#[cfg(test)]
mod cross_join_batch_tests {
    use super::*;
    use arrow::array::Int32Array;

    /// Collects every value from the Int32 column at `idx`.
    fn int32_column(batch: &RecordBatch, idx: usize) -> Vec<i32> {
        let array = batch
            .column(idx)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        (0..array.len()).map(|i| array.value(i)).collect()
    }

    #[test]
    fn execute_cross_join_batches_emits_full_cartesian_product() {
        let left_schema = Arc::new(Schema::new(vec![Field::new("l", DataType::Int32, false)]));
        let right_schema = Arc::new(Schema::new(vec![Field::new("r", DataType::Int32, false)]));

        let left_batch = RecordBatch::try_new(
            Arc::clone(&left_schema),
            vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
        )
        .expect("left batch");
        let right_batch = RecordBatch::try_new(
            Arc::clone(&right_schema),
            vec![Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef],
        )
        .expect("right batch");

        let product = execute_cross_join_batches(&left_batch, &right_batch).expect("cross join");

        // 2 left rows x 3 right rows = 6 output rows; left column first.
        assert_eq!(product.num_rows(), 6);
        assert_eq!(product.num_columns(), 2);
        assert_eq!(int32_column(&product, 0), vec![1, 1, 1, 2, 2, 2]);
        assert_eq!(int32_column(&product, 1), vec![10, 20, 30, 10, 20, 30]);
    }
}
3007
3008impl<P> QueryExecutor<P>
3009where
3010 P: Pager<Blob = EntryHandle> + Send + Sync,
3011{
    /// Evaluates plain (non-computed) aggregate projections over the
    /// already-materialized cross-product batches.
    ///
    /// Builds one `AggregateSpec` per planned aggregate, folds every batch
    /// into the corresponding accumulators, then finalizes into a
    /// single-row result batch (with DISTINCT / ORDER BY applied last).
    ///
    /// Errors if the plan carries scalar subqueries, references an unknown
    /// column, or contains no aggregates at all.
    fn execute_cross_product_aggregates(
        &self,
        combined_schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
        column_lookup_map: &FxHashMap<String, usize>,
        plan: &SelectPlan,
        display_name: &str,
    ) -> ExecutorResult<SelectExecution<P>> {
        if !plan.scalar_subqueries.is_empty() {
            return Err(Error::InvalidArgumentError(
                "scalar subqueries not supported in aggregate joins".into(),
            ));
        }

        // spec_to_projection maps each spec to the combined-schema column it
        // reads (None for COUNT(*), which needs no input column).
        let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
        let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(plan.aggregates.len());

        for aggregate in &plan.aggregates {
            match aggregate {
                AggregateExpr::CountStar { alias, distinct } => {
                    specs.push(AggregateSpec {
                        alias: alias.clone(),
                        kind: AggregateKind::Count {
                            field_id: None,
                            distinct: *distinct,
                        },
                    });
                    spec_to_projection.push(None);
                }
                AggregateExpr::Column {
                    column,
                    alias,
                    function,
                    distinct,
                } => {
                    // Column names are resolved case-insensitively against
                    // the combined (left + right) schema.
                    let key = column.to_ascii_lowercase();
                    let column_index = *column_lookup_map.get(&key).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "unknown column '{column}' in aggregate"
                        ))
                    })?;
                    let field = combined_schema.field(column_index);
                    // In this path the accumulator "field id" is the
                    // combined-schema column index, not a table FieldId.
                    let kind = match function {
                        AggregateFunction::Count => AggregateKind::Count {
                            field_id: Some(column_index as u32),
                            distinct: *distinct,
                        },
                        AggregateFunction::SumInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "SUM",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Sum {
                                field_id: column_index as u32,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateFunction::TotalInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "TOTAL",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Total {
                                field_id: column_index as u32,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateFunction::MinInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "MIN",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Min {
                                field_id: column_index as u32,
                                data_type: input_type,
                            }
                        }
                        AggregateFunction::MaxInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(field.data_type().clone()),
                                "MAX",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Max {
                                field_id: column_index as u32,
                                data_type: input_type,
                            }
                        }
                        AggregateFunction::CountNulls => AggregateKind::CountNulls {
                            field_id: column_index as u32,
                        },
                        AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
                            field_id: column_index as u32,
                            distinct: *distinct,
                            // GROUP_CONCAT uses the SQL-default comma separator.
                            separator: ",".to_string(),
                        },
                    };

                    specs.push(AggregateSpec {
                        alias: alias.clone(),
                        kind,
                    });
                    spec_to_projection.push(Some(column_index));
                }
            }
        }

        if specs.is_empty() {
            return Err(Error::InvalidArgumentError(
                "aggregate query requires at least one aggregate expression".into(),
            ));
        }

        // One accumulator per spec, each bound to its projection index.
        let mut states = Vec::with_capacity(specs.len());
        for (idx, spec) in specs.iter().enumerate() {
            states.push(AggregateState {
                alias: spec.alias.clone(),
                accumulator: AggregateAccumulator::new_with_projection_index(
                    spec,
                    spec_to_projection[idx],
                    None,
                )?,
                override_value: None,
            });
        }

        // Fold every cross-product batch into every accumulator.
        for batch in &batches {
            for state in &mut states {
                state.update(batch)?;
            }
        }

        // Finalize each accumulator into a (field, single-value array) pair.
        let mut fields = Vec::with_capacity(states.len());
        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
        for state in states {
            let (field, array) = state.finalize()?;
            fields.push(Arc::new(field));
            arrays.push(array);
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;

        // SELECT DISTINCT on the aggregate row (usually a no-op for a
        // single-row result, but honored for consistency).
        if plan.distinct {
            let mut distinct_state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut distinct_state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        if !plan.order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            display_name.to_string(),
            schema,
            batch,
        ))
    }
3178
    /// Evaluates computed projections that embed aggregate calls (e.g.
    /// `SUM(a) + 1`) over the materialized cross-product batches.
    ///
    /// Aggregate calls are first collected from every computed projection and
    /// evaluated once via `compute_cross_product_aggregate_values`; each
    /// projection is then rendered either directly from its aggregate value
    /// or by re-evaluating the expression with aggregates substituted.
    ///
    /// Only `SelectProjection::Computed` projections are allowed here.
    fn execute_cross_product_computed_aggregates(
        &self,
        combined_schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
        column_lookup_map: &FxHashMap<String, usize>,
        plan: &SelectPlan,
        display_name: &str,
    ) -> ExecutorResult<SelectExecution<P>> {
        let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
        for projection in &plan.projections {
            match projection {
                SelectProjection::Computed { expr, .. } => {
                    Self::collect_aggregates(expr, &mut aggregate_specs);
                }
                SelectProjection::AllColumns
                | SelectProjection::AllColumnsExcept { .. }
                | SelectProjection::Column { .. } => {
                    return Err(Error::InvalidArgumentError(
                        "non-computed projections not supported with aggregate expressions".into(),
                    ));
                }
            }
        }

        if aggregate_specs.is_empty() {
            return Err(Error::InvalidArgumentError(
                "computed aggregate query requires at least one aggregate expression".into(),
            ));
        }

        // Evaluate all collected aggregates in one pass over the batches.
        let aggregate_values = self.compute_cross_product_aggregate_values(
            &combined_schema,
            &batches,
            column_lookup_map,
            &aggregate_specs,
        )?;

        let mut fields = Vec::with_capacity(plan.projections.len());
        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());

        for projection in &plan.projections {
            if let SelectProjection::Computed { expr, alias } = projection {
                if let ScalarExpr::Aggregate(agg) = expr {
                    // Bare aggregate projection: look up its precomputed
                    // value directly. NOTE(review): lookup is keyed by the
                    // Debug rendering of the call — fragile if Debug output
                    // changes, but must match the keying used on the
                    // compute side.
                    let key = format!("{:?}", agg);
                    if let Some(agg_value) = aggregate_values.get(&key) {
                        match agg_value {
                            AggregateValue::Null => {
                                fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
                                arrays.push(Arc::new(Int64Array::from(vec![None::<i64>])) as ArrayRef);
                            }
                            AggregateValue::Int64(v) => {
                                fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
                                arrays.push(Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef);
                            }
                            AggregateValue::Float64(v) => {
                                fields.push(Arc::new(Field::new(alias, DataType::Float64, true)));
                                arrays
                                    .push(Arc::new(Float64Array::from(vec![Some(*v)])) as ArrayRef);
                            }
                            AggregateValue::Decimal128 { value, scale } => {
                                // Precision inferred from the decimal digit
                                // count of the unscaled value.
                                // NOTE(review): assumes value != i128::MIN
                                // (abs would overflow) and precision >= scale;
                                // confirm upstream guarantees.
                                let precision = if *value == 0 {
                                    1
                                } else {
                                    (*value).abs().to_string().len() as u8
                                };
                                fields.push(Arc::new(Field::new(
                                    alias,
                                    DataType::Decimal128(precision, *scale),
                                    true,
                                )));
                                let array = Decimal128Array::from(vec![Some(*value)])
                                    .with_precision_and_scale(precision, *scale)
                                    .map_err(|e| {
                                        Error::Internal(format!("invalid Decimal128: {}", e))
                                    })?;
                                arrays.push(Arc::new(array) as ArrayRef);
                            }
                            AggregateValue::String(s) => {
                                fields.push(Arc::new(Field::new(alias, DataType::Utf8, true)));
                                arrays
                                    .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
                                        as ArrayRef);
                            }
                        }
                        continue;
                    }
                }

                // Composite expression around aggregates: evaluate with the
                // precomputed aggregate values substituted in.
                // NOTE(review): the result is always materialized as Int64
                // here, which would truncate float-valued expressions —
                // confirm only integer expressions reach this fallback.
                let value = Self::evaluate_expr_with_aggregates(expr, &aggregate_values)?;
                fields.push(Arc::new(Field::new(alias, DataType::Int64, true)));
                arrays.push(Arc::new(Int64Array::from(vec![value])) as ArrayRef);
            }
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;

        // DISTINCT and ORDER BY are applied to the single-row result for
        // plan consistency.
        if plan.distinct {
            let mut distinct_state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut distinct_state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        if !plan.order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            display_name.to_string(),
            schema,
            batch,
        ))
    }
3297
    /// Evaluates scalar (ungrouped) aggregates over the batches of a
    /// cross-product / join result.
    ///
    /// Aggregates whose argument is a plain column reference read that column
    /// directly from `batches` via `column_lookup_map`. Aggregates over
    /// computed expressions first have the expression materialized as an
    /// extra column appended to every batch (cached per expression, so
    /// duplicate expressions are evaluated only once).
    ///
    /// `aggregate_specs` pairs each output alias with the aggregate call to
    /// evaluate. Returns a map from alias to the finalized scalar value.
    fn compute_cross_product_aggregate_values(
        &self,
        combined_schema: &Arc<Schema>,
        batches: &[RecordBatch],
        column_lookup_map: &FxHashMap<String, usize>,
        aggregate_specs: &[(String, AggregateCall<String>)],
    ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
        let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
        // Parallel to `specs`: the column index each accumulator reads
        // (`None` for COUNT(*), which needs no input column).
        let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());

        // Lazily-populated state used only when at least one aggregate takes a
        // computed expression: per-batch column vectors, the augmented field
        // list, and the rebuilt batches carrying the extra columns.
        let mut columns_per_batch: Option<Vec<Vec<ArrayRef>>> = None;
        let mut augmented_fields: Option<Vec<Field>> = None;
        let mut owned_batches: Option<Vec<RecordBatch>> = None;
        // Cache keyed by the expression's Debug rendering so identical
        // expressions share one materialized column.
        let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
            FxHashMap::default();
        let mut computed_alias_counter: usize = 0;
        let mut expr_context = CrossProductExpressionContext::new(
            combined_schema.as_ref(),
            column_lookup_map.clone(),
        )?;

        // Materializes `expr` as an appended column on every batch (unless
        // cached) and returns its column index plus inferred data type.
        let mut ensure_computed_column =
            |expr: &ScalarExpr<String>| -> ExecutorResult<(usize, DataType)> {
                let key = format!("{:?}", expr);
                if let Some((idx, dtype)) = computed_projection_cache.get(&key) {
                    return Ok((*idx, dtype.clone()));
                }

                // First computed column: seed the working copies from the
                // original batches/schema.
                if columns_per_batch.is_none() {
                    let initial_columns: Vec<Vec<ArrayRef>> = batches
                        .iter()
                        .map(|batch| batch.columns().to_vec())
                        .collect();
                    columns_per_batch = Some(initial_columns);
                }
                if augmented_fields.is_none() {
                    augmented_fields = Some(
                        combined_schema
                            .fields()
                            .iter()
                            .map(|field| field.as_ref().clone())
                            .collect(),
                    );
                }

                let translated = translate_scalar(expr, expr_context.schema(), |name| {
                    Error::InvalidArgumentError(format!(
                        "unknown column '{}' in aggregate expression",
                        name
                    ))
                })?;
                let data_type = infer_computed_data_type(expr_context.schema(), &translated)?;

                // Evaluate the expression once per batch and append the result
                // as a new trailing column.
                if let Some(columns) = columns_per_batch.as_mut() {
                    for (batch_idx, batch) in batches.iter().enumerate() {
                        expr_context.reset();
                        let array = expr_context.materialize_scalar_array(&translated, batch)?;
                        if let Some(batch_columns) = columns.get_mut(batch_idx) {
                            batch_columns.push(array);
                        }
                    }
                }

                // The new column's index is the current field count (it is
                // appended at the end).
                let column_index = augmented_fields
                    .as_ref()
                    .map(|fields| fields.len())
                    .unwrap_or_else(|| combined_schema.fields().len());

                let alias = format!("__agg_expr_cp_{}", computed_alias_counter);
                computed_alias_counter += 1;
                augmented_fields
                    .as_mut()
                    .expect("augmented fields initialized")
                    .push(Field::new(&alias, data_type.clone(), true));

                computed_projection_cache.insert(key, (column_index, data_type.clone()));
                Ok((column_index, data_type))
            };

        // Translate each requested aggregate into an AggregateSpec, resolving
        // its input column (direct or materialized) along the way.
        for (key, agg) in aggregate_specs {
            match agg {
                AggregateCall::CountStar => {
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Count {
                            field_id: None,
                            distinct: false,
                        },
                    });
                    spec_to_projection.push(None);
                }
                AggregateCall::Count { expr, .. }
                | AggregateCall::Sum { expr, .. }
                | AggregateCall::Total { expr, .. }
                | AggregateCall::Avg { expr, .. }
                | AggregateCall::Min(expr)
                | AggregateCall::Max(expr)
                | AggregateCall::CountNulls(expr)
                | AggregateCall::GroupConcat { expr, .. } => {
                    // Plain column references resolve through the lookup map;
                    // anything else is materialized as a computed column.
                    let (column_index, data_type_opt) = if let Some(column) =
                        try_extract_simple_column(expr)
                    {
                        let key_lower = column.to_ascii_lowercase();
                        let column_index = *column_lookup_map.get(&key_lower).ok_or_else(|| {
                            Error::InvalidArgumentError(format!(
                                "unknown column '{column}' in aggregate"
                            ))
                        })?;
                        let field = combined_schema.field(column_index);
                        (column_index, Some(field.data_type().clone()))
                    } else {
                        let (index, dtype) = ensure_computed_column(expr)?;
                        (index, Some(dtype))
                    };

                    let kind = match agg {
                        AggregateCall::Count { distinct, .. } => {
                            // The accumulator addresses its input by a u32
                            // projection index; guard the narrowing cast.
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Count {
                                field_id: Some(field_id),
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Sum { distinct, .. } => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "SUM",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Sum {
                                field_id,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Total { distinct, .. } => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "TOTAL",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Total {
                                field_id,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Avg { distinct, .. } => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "AVG",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Avg {
                                field_id,
                                data_type: input_type,
                                distinct: *distinct,
                            }
                        }
                        AggregateCall::Min(_) => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "MIN",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Min {
                                field_id,
                                data_type: input_type,
                            }
                        }
                        AggregateCall::Max(_) => {
                            let input_type = Self::validate_aggregate_type(
                                data_type_opt.clone(),
                                "MAX",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::Max {
                                field_id,
                                data_type: input_type,
                            }
                        }
                        AggregateCall::CountNulls(_) => {
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::CountNulls { field_id }
                        }
                        AggregateCall::GroupConcat {
                            distinct,
                            separator,
                            ..
                        } => {
                            let field_id = u32::try_from(column_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            AggregateKind::GroupConcat {
                                field_id,
                                distinct: *distinct,
                                // SQL default separator for GROUP_CONCAT.
                                separator: separator.clone().unwrap_or_else(|| ",".to_string()),
                            }
                        }
                        // Safe: the outer match arm already restricted `agg`
                        // to the variants handled above.
                        _ => unreachable!(),
                    };

                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind,
                    });
                    spec_to_projection.push(Some(column_index));
                }
            }
        }

        // If any computed columns were materialized, rebuild the batches with
        // the augmented schema so accumulators can address them by index.
        if let Some(columns) = columns_per_batch {
            let fields = augmented_fields.unwrap_or_else(|| {
                combined_schema
                    .fields()
                    .iter()
                    .map(|field| field.as_ref().clone())
                    .collect()
            });
            let augmented_schema = Arc::new(Schema::new(fields));
            let mut new_batches = Vec::with_capacity(columns.len());
            for batch_columns in columns {
                let batch = RecordBatch::try_new(Arc::clone(&augmented_schema), batch_columns)
                    .map_err(|err| {
                        Error::InvalidArgumentError(format!(
                            "failed to materialize aggregate projections: {err}"
                        ))
                    })?;
                new_batches.push(batch);
            }
            owned_batches = Some(new_batches);
        }

        // One accumulator per spec, bound to its projection index.
        let mut states = Vec::with_capacity(specs.len());
        for (idx, spec) in specs.iter().enumerate() {
            states.push(AggregateState {
                alias: spec.alias.clone(),
                accumulator: AggregateAccumulator::new_with_projection_index(
                    spec,
                    spec_to_projection[idx],
                    None,
                )?,
                override_value: None,
            });
        }

        // Feed either the augmented batches or the originals to every state.
        let batch_iter: &[RecordBatch] = if let Some(ref extended) = owned_batches {
            extended.as_slice()
        } else {
            batches
        };

        for batch in batch_iter {
            for state in &mut states {
                state.update(batch)?;
            }
        }

        // Finalize each accumulator; every result must be a single-row array
        // of one of the supported output types.
        let mut results = FxHashMap::default();
        for state in states {
            let (field, array) = state.finalize()?;

            if let Some(int_array) = array.as_any().downcast_ref::<Int64Array>() {
                if int_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        int_array.len()
                    )));
                }
                let value = if int_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Int64(int_array.value(0))
                };
                results.insert(field.name().to_string(), value);
            }
            else if let Some(float_array) = array.as_any().downcast_ref::<Float64Array>() {
                if float_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        float_array.len()
                    )));
                }
                let value = if float_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Float64(float_array.value(0))
                };
                results.insert(field.name().to_string(), value);
            }
            else if let Some(string_array) = array.as_any().downcast_ref::<StringArray>() {
                if string_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        string_array.len()
                    )));
                }
                let value = if string_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::String(string_array.value(0).to_string())
                };
                results.insert(field.name().to_string(), value);
            }
            else if let Some(decimal_array) = array.as_any().downcast_ref::<Decimal128Array>() {
                if decimal_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        decimal_array.len()
                    )));
                }
                let value = if decimal_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Decimal128 {
                        value: decimal_array.value(0),
                        scale: decimal_array.scale(),
                    }
                };
                results.insert(field.name().to_string(), value);
            } else {
                return Err(Error::Internal(format!(
                    "Unexpected array type from aggregate: {:?}",
                    array.data_type()
                )));
            }
        }

        Ok(results)
    }
3667
    /// Attempts to execute the multi-table portion of `plan` as a chain of
    /// hash joins instead of a raw cross product.
    ///
    /// Returns:
    /// - `Ok(None)` when the optimization does not apply (no filter, explicit
    ///   non-INNER joins, fewer than two tables, unsupported predicate shape,
    ///   no equality constraints, or LEFT JOIN present) — the caller falls
    ///   back to the generic path.
    /// - `Ok(Some((data, handled_all)))` on success; `handled_all` is true
    ///   when every top-level conjunct of the filter was consumed by the join,
    ///   otherwise the caller must re-apply the residual filter.
    fn try_execute_hash_join(
        &self,
        plan: &SelectPlan,
        tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
    ) -> ExecutorResult<Option<(TableCrossProductData, bool)>> {
        let query_label_opt = current_query_label();
        let query_label = query_label_opt.as_deref().unwrap_or("<unknown query>");

        // Hash joins are only derivable from a WHERE clause.
        let filter_wrapper = match &plan.filter {
            Some(filter) => filter,
            None => {
                tracing::debug!(
                    "join_opt[{query_label}]: skipping optimization – no filter present"
                );
                return Ok(None);
            }
        };

        let all_inner_joins = plan
            .joins
            .iter()
            .all(|j| j.join_type == llkv_plan::JoinPlan::Inner);

        if !plan.joins.is_empty() && !all_inner_joins {
            tracing::debug!(
                "join_opt[{query_label}]: skipping optimization – explicit non-INNER JOINs present"
            );
            return Ok(None);
        }

        if tables_with_handles.len() < 2 {
            tracing::debug!(
                "join_opt[{query_label}]: skipping optimization – requires at least 2 tables"
            );
            return Ok(None);
        }

        // Per-table metadata: lower-cased column name -> column index, used by
        // constraint extraction to resolve qualified references.
        let mut table_infos = Vec::with_capacity(tables_with_handles.len());
        for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
            let mut column_map = FxHashMap::default();
            for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
                let column_name = column.name.to_ascii_lowercase();
                // Keep the first occurrence on duplicate names.
                column_map.entry(column_name).or_insert(column_idx);
            }
            table_infos.push(TableInfo {
                index,
                table_ref,
                column_map,
            });
        }

        // Decompose the predicate's top-level AND into join equalities and
        // single-table literal constraints; bail out on unsupported shapes.
        let constraint_plan = match extract_join_constraints(
            &filter_wrapper.predicate,
            &table_infos,
        ) {
            Some(plan) => plan,
            None => {
                tracing::debug!(
                    "join_opt[{query_label}]: skipping optimization – predicate parsing failed (contains OR or other unsupported top-level structure)"
                );
                return Ok(None);
            }
        };

        tracing::debug!(
            "join_opt[{query_label}]: constraint extraction succeeded - equalities={}, literals={}, handled={}/{} predicates",
            constraint_plan.equalities.len(),
            constraint_plan.literals.len(),
            constraint_plan.handled_conjuncts,
            constraint_plan.total_conjuncts
        );
        tracing::debug!(
            "join_opt[{query_label}]: attempting hash join with tables={:?} filter={:?}",
            plan.tables
                .iter()
                .map(|t| t.qualified_name())
                .collect::<Vec<_>>(),
            filter_wrapper.predicate,
        );

        // Provably-false predicate: build the combined schema and return one
        // empty batch, flagged as fully handled.
        if constraint_plan.unsatisfiable {
            tracing::debug!(
                "join_opt[{query_label}]: predicate unsatisfiable – returning empty result"
            );
            let mut combined_fields = Vec::new();
            let mut column_counts = Vec::new();
            for (_table_ref, executor_table) in tables_with_handles {
                for column in &executor_table.schema.columns {
                    combined_fields.push(Field::new(
                        column.name.clone(),
                        column.data_type.clone(),
                        column.nullable,
                    ));
                }
                column_counts.push(executor_table.schema.columns.len());
            }
            let combined_schema = Arc::new(Schema::new(combined_fields));
            let empty_batch = RecordBatch::new_empty(Arc::clone(&combined_schema));
            return Ok(Some((
                TableCrossProductData {
                    schema: combined_schema,
                    batches: vec![empty_batch],
                    column_counts,
                    table_indices: (0..tables_with_handles.len()).collect(),
                },
                true,
            )));
        }

        if constraint_plan.equalities.is_empty() {
            tracing::debug!(
                "join_opt[{query_label}]: skipping optimization – no join equalities found"
            );
            return Ok(None);
        }

        if !constraint_plan.literals.is_empty() {
            tracing::debug!(
                "join_opt[{query_label}]: found {} literal constraints - proceeding with hash join but may need fallback",
                constraint_plan.literals.len()
            );
        }

        tracing::debug!(
            "join_opt[{query_label}]: hash join optimization applicable with {} equality constraints",
            constraint_plan.equalities.len()
        );

        // Bucket the single-table constraints by table index so they can be
        // pushed down while loading each table's data.
        let mut literal_map: Vec<Vec<ColumnConstraint>> =
            vec![Vec::new(); tables_with_handles.len()];
        for constraint in &constraint_plan.literals {
            let table_idx = match constraint {
                ColumnConstraint::Equality(lit) => lit.column.table,
                ColumnConstraint::InList(in_list) => in_list.column.table,
            };
            if table_idx >= literal_map.len() {
                tracing::debug!(
                    "join_opt[{query_label}]: constraint references unknown table index {}; falling back",
                    table_idx
                );
                return Ok(None);
            }
            tracing::debug!(
                "join_opt[{query_label}]: mapping constraint to table_idx={} (table={})",
                table_idx,
                tables_with_handles[table_idx].0.qualified_name()
            );
            literal_map[table_idx].push(constraint.clone());
        }

        // Load each table's rows once, applying its pushed-down literals.
        // `Option` is used so each entry can be consumed exactly once below.
        let mut per_table: Vec<Option<TableCrossProductData>> =
            Vec::with_capacity(tables_with_handles.len());
        for (idx, (table_ref, table)) in tables_with_handles.iter().enumerate() {
            let data =
                collect_table_data(idx, table_ref, table.as_ref(), literal_map[idx].as_slice())?;
            per_table.push(Some(data));
        }

        let has_left_join = plan
            .joins
            .iter()
            .any(|j| j.join_type == llkv_plan::JoinPlan::Left);

        let mut current: Option<TableCrossProductData> = None;

        if has_left_join {
            // LEFT JOIN semantics are handled by the dedicated join crate.
            tracing::debug!(
                "join_opt[{query_label}]: delegating to llkv-join for LEFT JOIN support"
            );
            return Ok(None);
        } else {
            // Greedy join ordering: start with the first table, then always
            // pick a remaining table linked to the joined set by an equality,
            // falling back to a cartesian expansion when none is linked.
            let mut remaining: Vec<usize> = (0..tables_with_handles.len()).collect();
            let mut used_tables: FxHashSet<usize> = FxHashSet::default();

            while !remaining.is_empty() {
                let next_index = if used_tables.is_empty() {
                    remaining[0]
                } else {
                    match remaining.iter().copied().find(|idx| {
                        table_has_join_with_used(*idx, &used_tables, &constraint_plan.equalities)
                    }) {
                        Some(idx) => idx,
                        None => {
                            tracing::debug!(
                                "join_opt[{query_label}]: no remaining equality links – using cartesian expansion for table index {idx}",
                                idx = remaining[0]
                            );
                            remaining[0]
                        }
                    }
                };

                let position = remaining
                    .iter()
                    .position(|&idx| idx == next_index)
                    .expect("next index present");

                // Consume this table's data; a second take would be a bug.
                let next_data = per_table[next_index]
                    .take()
                    .ok_or_else(|| Error::Internal("hash join consumed table data twice".into()))?;

                if let Some(current_data) = current.take() {
                    let join_keys = gather_join_keys(
                        &current_data,
                        &next_data,
                        &used_tables,
                        next_index,
                        &constraint_plan.equalities,
                    )?;

                    let joined = if join_keys.is_empty() {
                        tracing::debug!(
                            "join_opt[{query_label}]: joining '{}' via cartesian expansion (no equality keys)",
                            tables_with_handles[next_index].0.qualified_name()
                        );
                        cross_join_table_batches(current_data, next_data)?
                    } else {
                        hash_join_table_batches(
                            current_data,
                            next_data,
                            &join_keys,
                            llkv_join::JoinType::Inner,
                        )?
                    };
                    current = Some(joined);
                } else {
                    // First table simply seeds the accumulated result.
                    current = Some(next_data);
                }

                used_tables.insert(next_index);
                remaining.remove(position);
            }
        }

        if let Some(result) = current {
            let handled_all = constraint_plan.handled_conjuncts == constraint_plan.total_conjuncts;
            tracing::debug!(
                "join_opt[{query_label}]: hash join succeeded across {} tables (handled {}/{} predicates)",
                tables_with_handles.len(),
                constraint_plan.handled_conjuncts,
                constraint_plan.total_conjuncts
            );
            return Ok(Some((result, handled_all)));
        }

        Ok(None)
    }
3951
3952 fn execute_projection(
3953 &self,
3954 table: Arc<ExecutorTable<P>>,
3955 display_name: String,
3956 plan: SelectPlan,
3957 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
3958 ) -> ExecutorResult<SelectExecution<P>> {
3959 if plan.having.is_some() {
3960 return Err(Error::InvalidArgumentError(
3961 "HAVING requires GROUP BY".into(),
3962 ));
3963 }
3964
3965 let has_filter_subqueries = plan
3966 .filter
3967 .as_ref()
3968 .is_some_and(|filter| !filter.subqueries.is_empty());
3969 let has_scalar_subqueries = !plan.scalar_subqueries.is_empty();
3970
3971 if has_filter_subqueries || has_scalar_subqueries {
3972 return self.execute_projection_with_subqueries(table, display_name, plan, row_filter);
3973 }
3974
3975 let table_ref = table.as_ref();
3976 let constant_filter = plan
3977 .filter
3978 .as_ref()
3979 .and_then(|filter| evaluate_constant_predicate(&filter.predicate));
3980 let projections = if plan.projections.is_empty() {
3981 build_wildcard_projections(table_ref)
3982 } else {
3983 build_projected_columns(table_ref, &plan.projections)?
3984 };
3985 let schema = schema_for_projections(table_ref, &projections)?;
3986
3987 if let Some(result) = constant_filter {
3988 match result {
3989 Some(true) => {
3990 }
3992 Some(false) | None => {
3993 let batch = RecordBatch::new_empty(Arc::clone(&schema));
3994 return Ok(SelectExecution::new_single_batch(
3995 display_name,
3996 schema,
3997 batch,
3998 ));
3999 }
4000 }
4001 }
4002
4003 let (mut filter_expr, mut full_table_scan) = match &plan.filter {
4004 Some(filter_wrapper) => (
4005 crate::translation::expression::translate_predicate(
4006 filter_wrapper.predicate.clone(),
4007 table_ref.schema.as_ref(),
4008 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
4009 )?,
4010 false,
4011 ),
4012 None => {
4013 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
4014 Error::InvalidArgumentError(
4015 "table has no columns; cannot perform wildcard scan".into(),
4016 )
4017 })?;
4018 (
4019 crate::translation::expression::full_table_scan_filter(field_id),
4020 true,
4021 )
4022 }
4023 };
4024
4025 if matches!(constant_filter, Some(Some(true))) {
4026 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
4027 Error::InvalidArgumentError(
4028 "table has no columns; cannot perform wildcard scan".into(),
4029 )
4030 })?;
4031 filter_expr = crate::translation::expression::full_table_scan_filter(field_id);
4032 full_table_scan = true;
4033 }
4034
4035 let expanded_order = expand_order_targets(&plan.order_by, &projections)?;
4036
4037 let mut physical_order: Option<ScanOrderSpec> = None;
4038
4039 if let Some(first) = expanded_order.first() {
4040 match &first.target {
4041 OrderTarget::Column(name) => {
4042 if table_ref.schema.resolve(name).is_some() {
4043 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
4044 }
4045 }
4046 OrderTarget::Index(position) => match projections.get(*position) {
4047 Some(ScanProjection::Column(_)) => {
4048 physical_order = Some(resolve_scan_order(table_ref, &projections, first)?);
4049 }
4050 Some(ScanProjection::Computed { .. }) => {}
4051 None => {
4052 return Err(Error::InvalidArgumentError(format!(
4053 "ORDER BY position {} is out of range",
4054 position + 1
4055 )));
4056 }
4057 },
4058 OrderTarget::All => {}
4059 }
4060 }
4061
4062 let options = if let Some(order_spec) = physical_order {
4063 if row_filter.is_some() {
4064 tracing::debug!("Applying MVCC row filter with ORDER BY");
4065 }
4066 ScanStreamOptions {
4067 include_nulls: true,
4068 order: Some(order_spec),
4069 row_id_filter: row_filter.clone(),
4070 }
4071 } else {
4072 if row_filter.is_some() {
4073 tracing::debug!("Applying MVCC row filter");
4074 }
4075 ScanStreamOptions {
4076 include_nulls: true,
4077 order: None,
4078 row_id_filter: row_filter.clone(),
4079 }
4080 };
4081
4082 Ok(SelectExecution::new_projection(
4083 display_name,
4084 schema,
4085 table,
4086 projections,
4087 filter_expr,
4088 options,
4089 full_table_scan,
4090 expanded_order,
4091 plan.distinct,
4092 ))
4093 }
4094
4095 fn execute_projection_with_subqueries(
4096 &self,
4097 table: Arc<ExecutorTable<P>>,
4098 display_name: String,
4099 plan: SelectPlan,
4100 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4101 ) -> ExecutorResult<SelectExecution<P>> {
4102 if plan.having.is_some() {
4103 return Err(Error::InvalidArgumentError(
4104 "HAVING requires GROUP BY".into(),
4105 ));
4106 }
4107 let table_ref = table.as_ref();
4108
4109 let (output_scan_projections, effective_projections): (
4110 Vec<ScanProjection>,
4111 Vec<SelectProjection>,
4112 ) = if plan.projections.is_empty() {
4113 (
4114 build_wildcard_projections(table_ref),
4115 vec![SelectProjection::AllColumns],
4116 )
4117 } else {
4118 (
4119 build_projected_columns(table_ref, &plan.projections)?,
4120 plan.projections.clone(),
4121 )
4122 };
4123
4124 let scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> = plan
4125 .scalar_subqueries
4126 .iter()
4127 .map(|subquery| (subquery.id, subquery))
4128 .collect();
4129
4130 let base_projections = build_wildcard_projections(table_ref);
4131
4132 let filter_wrapper_opt = plan.filter.as_ref();
4133
4134 let mut filter_has_scalar_subqueries = false;
4136 if let Some(filter_wrapper) = filter_wrapper_opt {
4137 let translated = crate::translation::expression::translate_predicate(
4138 filter_wrapper.predicate.clone(),
4139 table_ref.schema.as_ref(),
4140 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
4141 )?;
4142 let mut scalar_filter_ids = FxHashSet::default();
4143 collect_predicate_scalar_subquery_ids(&translated, &mut scalar_filter_ids);
4144 filter_has_scalar_subqueries = !scalar_filter_ids.is_empty();
4145 }
4146
4147 let mut translated_filter: Option<llkv_expr::expr::Expr<'static, FieldId>> = None;
4148 let pushdown_filter = if let Some(filter_wrapper) = filter_wrapper_opt {
4149 let translated = crate::translation::expression::translate_predicate(
4150 filter_wrapper.predicate.clone(),
4151 table_ref.schema.as_ref(),
4152 |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
4153 )?;
4154 if !filter_wrapper.subqueries.is_empty() || filter_has_scalar_subqueries {
4155 translated_filter = Some(translated.clone());
4156 if filter_has_scalar_subqueries {
4157 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
4160 Error::InvalidArgumentError(
4161 "table has no columns; cannot perform scalar subquery projection"
4162 .into(),
4163 )
4164 })?;
4165 crate::translation::expression::full_table_scan_filter(field_id)
4166 } else {
4167 strip_exists(&translated)
4169 }
4170 } else {
4171 translated
4172 }
4173 } else {
4174 let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
4175 Error::InvalidArgumentError(
4176 "table has no columns; cannot perform scalar subquery projection".into(),
4177 )
4178 })?;
4179 crate::translation::expression::full_table_scan_filter(field_id)
4180 };
4181
4182 let mut base_fields: Vec<Field> = Vec::with_capacity(table_ref.schema.columns.len());
4183 for column in &table_ref.schema.columns {
4184 base_fields.push(Field::new(
4185 column.name.clone(),
4186 column.data_type.clone(),
4187 column.nullable,
4188 ));
4189 }
4190 let base_schema = Arc::new(Schema::new(base_fields));
4191 let base_column_counts = vec![base_schema.fields().len()];
4192 let base_table_indices = vec![0usize];
4193 let base_lookup = build_cross_product_column_lookup(
4194 base_schema.as_ref(),
4195 &plan.tables,
4196 &base_column_counts,
4197 &base_table_indices,
4198 );
4199
4200 let mut filter_context = if translated_filter.is_some() {
4201 Some(CrossProductExpressionContext::new(
4202 base_schema.as_ref(),
4203 base_lookup.clone(),
4204 )?)
4205 } else {
4206 None
4207 };
4208
4209 let mut filter_scalar_subquery_ids = FxHashSet::default();
4211 if let Some(translated) = translated_filter.as_ref() {
4212 collect_predicate_scalar_subquery_ids(translated, &mut filter_scalar_subquery_ids);
4213 }
4214
4215 let filter_scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> =
4217 if !filter_scalar_subquery_ids.is_empty() {
4218 plan.scalar_subqueries
4219 .iter()
4220 .filter(|subquery| filter_scalar_subquery_ids.contains(&subquery.id))
4221 .map(|subquery| (subquery.id, subquery))
4222 .collect()
4223 } else {
4224 FxHashMap::default()
4225 };
4226
4227 let options = ScanStreamOptions {
4228 include_nulls: true,
4229 order: None,
4230 row_id_filter: row_filter.clone(),
4231 };
4232
4233 let subquery_lookup: FxHashMap<llkv_expr::SubqueryId, &llkv_plan::FilterSubquery> =
4234 filter_wrapper_opt
4235 .map(|wrapper| {
4236 wrapper
4237 .subqueries
4238 .iter()
4239 .map(|subquery| (subquery.id, subquery))
4240 .collect()
4241 })
4242 .unwrap_or_default();
4243
4244 let mut projected_batches: Vec<RecordBatch> = Vec::new();
4245 let mut scan_error: Option<Error> = None;
4246
4247 table.table.scan_stream(
4248 base_projections.clone(),
4249 &pushdown_filter,
4250 options,
4251 |batch| {
4252 if scan_error.is_some() {
4253 return;
4254 }
4255 let effective_batch = if let Some(context) = filter_context.as_mut() {
4256 context.reset();
4257
4258 for (subquery_id, subquery) in filter_scalar_lookup.iter() {
4260 let result_array = match self
4261 .evaluate_scalar_subquery_numeric(context, subquery, &batch)
4262 {
4263 Ok(array) => array,
4264 Err(err) => {
4265 scan_error = Some(err);
4266 return;
4267 }
4268 };
4269 let accessor = match ColumnAccessor::from_numeric_array(&result_array) {
4270 Ok(acc) => acc,
4271 Err(err) => {
4272 scan_error = Some(err);
4273 return;
4274 }
4275 };
4276 context
4277 .scalar_subquery_columns
4278 .insert(*subquery_id, accessor);
4279 }
4280 let translated = translated_filter
4281 .as_ref()
4282 .expect("filter context requires translated filter");
4283 let mask = match context.evaluate_predicate_mask(
4284 translated,
4285 &batch,
4286 |ctx, subquery_expr, row_idx, current_batch| {
4287 let subquery =
4288 subquery_lookup.get(&subquery_expr.id).ok_or_else(|| {
4289 Error::Internal("missing correlated subquery metadata".into())
4290 })?;
4291 let exists = self.evaluate_exists_subquery(
4292 ctx,
4293 subquery,
4294 current_batch,
4295 row_idx,
4296 )?;
4297 let value = if subquery_expr.negated {
4298 !exists
4299 } else {
4300 exists
4301 };
4302 Ok(Some(value))
4303 },
4304 ) {
4305 Ok(mask) => mask,
4306 Err(err) => {
4307 scan_error = Some(err);
4308 return;
4309 }
4310 };
4311 match filter_record_batch(&batch, &mask) {
4312 Ok(filtered) => {
4313 if filtered.num_rows() == 0 {
4314 return;
4315 }
4316 filtered
4317 }
4318 Err(err) => {
4319 scan_error = Some(Error::InvalidArgumentError(format!(
4320 "failed to apply EXISTS filter: {err}"
4321 )));
4322 return;
4323 }
4324 }
4325 } else {
4326 batch.clone()
4327 };
4328
4329 if effective_batch.num_rows() == 0 {
4330 return;
4331 }
4332
4333 let projected = match self.project_record_batch(
4334 &effective_batch,
4335 &effective_projections,
4336 &base_lookup,
4337 &scalar_lookup,
4338 ) {
4339 Ok(batch) => batch,
4340 Err(err) => {
4341 scan_error = Some(Error::InvalidArgumentError(format!(
4342 "failed to evaluate projections: {err}"
4343 )));
4344 return;
4345 }
4346 };
4347 projected_batches.push(projected);
4348 },
4349 )?;
4350
4351 if let Some(err) = scan_error {
4352 return Err(err);
4353 }
4354
4355 let mut result_batch = if projected_batches.is_empty() {
4356 let empty_batch = RecordBatch::new_empty(Arc::clone(&base_schema));
4357 self.project_record_batch(
4358 &empty_batch,
4359 &effective_projections,
4360 &base_lookup,
4361 &scalar_lookup,
4362 )?
4363 } else if projected_batches.len() == 1 {
4364 projected_batches.pop().unwrap()
4365 } else {
4366 let schema = projected_batches[0].schema();
4367 concat_batches(&schema, &projected_batches).map_err(|err| {
4368 Error::Internal(format!("failed to combine filtered batches: {err}"))
4369 })?
4370 };
4371
4372 if plan.distinct && result_batch.num_rows() > 0 {
4373 let mut state = DistinctState::default();
4374 let schema = result_batch.schema();
4375 result_batch = match distinct_filter_batch(result_batch, &mut state)? {
4376 Some(filtered) => filtered,
4377 None => RecordBatch::new_empty(schema),
4378 };
4379 }
4380
4381 if !plan.order_by.is_empty() && result_batch.num_rows() > 0 {
4382 let expanded_order = expand_order_targets(&plan.order_by, &output_scan_projections)?;
4383 if !expanded_order.is_empty() {
4384 result_batch = sort_record_batch_with_order(
4385 &result_batch.schema(),
4386 &result_batch,
4387 &expanded_order,
4388 )?;
4389 }
4390 }
4391
4392 let schema = result_batch.schema();
4393
4394 Ok(SelectExecution::new_single_batch(
4395 display_name,
4396 schema,
4397 result_batch,
4398 ))
4399 }
4400
4401 fn execute_group_by_single_table(
4402 &self,
4403 table: Arc<ExecutorTable<P>>,
4404 display_name: String,
4405 plan: SelectPlan,
4406 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
4407 ) -> ExecutorResult<SelectExecution<P>> {
4408 if plan
4409 .filter
4410 .as_ref()
4411 .is_some_and(|filter| !filter.subqueries.is_empty())
4412 || !plan.scalar_subqueries.is_empty()
4413 {
4414 return Err(Error::InvalidArgumentError(
4415 "GROUP BY with subqueries is not supported yet".into(),
4416 ));
4417 }
4418
4419 tracing::debug!(
4421 "[GROUP BY] Original plan: projections={}, aggregates={}, has_filter={}, has_having={}",
4422 plan.projections.len(),
4423 plan.aggregates.len(),
4424 plan.filter.is_some(),
4425 plan.having.is_some()
4426 );
4427
4428 let mut base_plan = plan.clone();
4432 base_plan.projections.clear();
4433 base_plan.aggregates.clear();
4434 base_plan.scalar_subqueries.clear();
4435 base_plan.order_by.clear();
4436 base_plan.distinct = false;
4437 base_plan.group_by.clear();
4438 base_plan.value_table_mode = None;
4439 base_plan.having = None;
4440
4441 tracing::debug!(
4442 "[GROUP BY] Base plan: projections={}, aggregates={}, has_filter={}, has_having={}",
4443 base_plan.projections.len(),
4444 base_plan.aggregates.len(),
4445 base_plan.filter.is_some(),
4446 base_plan.having.is_some()
4447 );
4448
4449 let table_ref = table.as_ref();
4452 let projections = build_wildcard_projections(table_ref);
4453 let base_schema = schema_for_projections(table_ref, &projections)?;
4454
4455 tracing::debug!(
4457 "[GROUP BY] Building base filter: has_filter={}",
4458 base_plan.filter.is_some()
4459 );
4460 let (filter_expr, full_table_scan) = match &base_plan.filter {
4461 Some(filter_wrapper) => {
4462 tracing::debug!(
4463 "[GROUP BY] Translating filter predicate: {:?}",
4464 filter_wrapper.predicate
4465 );
4466 let expr = crate::translation::expression::translate_predicate(
4467 filter_wrapper.predicate.clone(),
4468 table_ref.schema.as_ref(),
4469 |name| {
4470 Error::InvalidArgumentError(format!(
4471 "Binder Error: does not have a column named '{}'",
4472 name
4473 ))
4474 },
4475 )?;
4476 tracing::debug!("[GROUP BY] Translated filter expr: {:?}", expr);
4477 (expr, false)
4478 }
4479 None => {
4480 let first_col =
4482 table_ref.schema.columns.first().ok_or_else(|| {
4483 Error::InvalidArgumentError("Table has no columns".into())
4484 })?;
4485 (full_table_scan_filter(first_col.field_id), true)
4486 }
4487 };
4488
4489 let options = ScanStreamOptions {
4490 include_nulls: true,
4491 order: None,
4492 row_id_filter: row_filter.clone(),
4493 };
4494
4495 let execution = SelectExecution::new_projection(
4496 display_name.clone(),
4497 Arc::clone(&base_schema),
4498 Arc::clone(&table),
4499 projections,
4500 filter_expr,
4501 options,
4502 full_table_scan,
4503 vec![],
4504 false,
4505 );
4506
4507 let batches = execution.collect()?;
4508
4509 tracing::debug!(
4510 "[GROUP BY] Collected {} batches from base scan, total_rows={}",
4511 batches.len(),
4512 batches.iter().map(|b| b.num_rows()).sum::<usize>()
4513 );
4514
4515 let column_lookup_map = build_column_lookup_map(base_schema.as_ref());
4516
4517 self.execute_group_by_from_batches(
4518 display_name,
4519 plan,
4520 base_schema,
4521 batches,
4522 column_lookup_map,
4523 )
4524 }
4525
4526 fn execute_group_by_from_batches(
4527 &self,
4528 display_name: String,
4529 plan: SelectPlan,
4530 base_schema: Arc<Schema>,
4531 batches: Vec<RecordBatch>,
4532 column_lookup_map: FxHashMap<String, usize>,
4533 ) -> ExecutorResult<SelectExecution<P>> {
4534 if plan
4535 .filter
4536 .as_ref()
4537 .is_some_and(|filter| !filter.subqueries.is_empty())
4538 || !plan.scalar_subqueries.is_empty()
4539 {
4540 return Err(Error::InvalidArgumentError(
4541 "GROUP BY with subqueries is not supported yet".into(),
4542 ));
4543 }
4544
4545 let having_has_aggregates = plan
4548 .having
4549 .as_ref()
4550 .map(|h| Self::predicate_contains_aggregate(h))
4551 .unwrap_or(false);
4552
4553 tracing::debug!(
4554 "[GROUP BY PATH] aggregates={}, has_computed={}, having_has_agg={}",
4555 plan.aggregates.len(),
4556 self.has_computed_aggregates(&plan),
4557 having_has_aggregates
4558 );
4559
4560 if !plan.aggregates.is_empty()
4561 || self.has_computed_aggregates(&plan)
4562 || having_has_aggregates
4563 {
4564 tracing::debug!("[GROUP BY PATH] Taking aggregates path");
4565 return self.execute_group_by_with_aggregates(
4566 display_name,
4567 plan,
4568 base_schema,
4569 batches,
4570 column_lookup_map,
4571 );
4572 }
4573
4574 let mut key_indices = Vec::with_capacity(plan.group_by.len());
4575 for column in &plan.group_by {
4576 let key = column.to_ascii_lowercase();
4577 let index = column_lookup_map.get(&key).ok_or_else(|| {
4578 Error::InvalidArgumentError(format!(
4579 "column '{}' not found in GROUP BY input",
4580 column
4581 ))
4582 })?;
4583 key_indices.push(*index);
4584 }
4585
4586 let sample_batch = batches
4587 .first()
4588 .cloned()
4589 .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(&base_schema)));
4590
4591 let output_columns = self.build_group_by_output_columns(
4592 &plan,
4593 base_schema.as_ref(),
4594 &column_lookup_map,
4595 &sample_batch,
4596 )?;
4597
4598 let constant_having = plan.having.as_ref().and_then(evaluate_constant_predicate);
4599
4600 if let Some(result) = constant_having
4601 && !result.unwrap_or(false)
4602 {
4603 let fields: Vec<Field> = output_columns
4604 .iter()
4605 .map(|output| output.field.clone())
4606 .collect();
4607 let schema = Arc::new(Schema::new(fields));
4608 let batch = RecordBatch::new_empty(Arc::clone(&schema));
4609 return Ok(SelectExecution::new_single_batch(
4610 display_name,
4611 schema,
4612 batch,
4613 ));
4614 }
4615
4616 let translated_having = if plan.having.is_some() && constant_having.is_none() {
4617 let having = plan.having.clone().expect("checked above");
4618 if Self::predicate_contains_aggregate(&having) {
4621 None
4622 } else {
4623 let temp_context = CrossProductExpressionContext::new(
4624 base_schema.as_ref(),
4625 column_lookup_map.clone(),
4626 )?;
4627 Some(translate_predicate(
4628 having,
4629 temp_context.schema(),
4630 |name| {
4631 Error::InvalidArgumentError(format!(
4632 "column '{}' not found in GROUP BY result",
4633 name
4634 ))
4635 },
4636 )?)
4637 }
4638 } else {
4639 None
4640 };
4641
4642 let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
4643 let mut groups: Vec<GroupState> = Vec::new();
4644
4645 for batch in &batches {
4646 for row_idx in 0..batch.num_rows() {
4647 let key = build_group_key(batch, row_idx, &key_indices)?;
4648 if group_index.contains_key(&key) {
4649 continue;
4650 }
4651 group_index.insert(key, groups.len());
4652 groups.push(GroupState {
4653 batch: batch.clone(),
4654 row_idx,
4655 });
4656 }
4657 }
4658
4659 let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(groups.len());
4660
4661 for group in &groups {
4662 if let Some(predicate) = translated_having.as_ref() {
4663 let mut context = CrossProductExpressionContext::new(
4664 group.batch.schema().as_ref(),
4665 column_lookup_map.clone(),
4666 )?;
4667 context.reset();
4668 let mut eval = |_ctx: &mut CrossProductExpressionContext,
4669 _subquery_expr: &llkv_expr::SubqueryExpr,
4670 _row_idx: usize,
4671 _current_batch: &RecordBatch|
4672 -> ExecutorResult<Option<bool>> {
4673 Err(Error::InvalidArgumentError(
4674 "HAVING subqueries are not supported yet".into(),
4675 ))
4676 };
4677 let truths =
4678 context.evaluate_predicate_truths(predicate, &group.batch, &mut eval)?;
4679 let passes = truths
4680 .get(group.row_idx)
4681 .copied()
4682 .flatten()
4683 .unwrap_or(false);
4684 if !passes {
4685 continue;
4686 }
4687 }
4688
4689 let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
4690 for output in &output_columns {
4691 match output.source {
4692 OutputSource::TableColumn { index } => {
4693 let value = llkv_plan::plan_value_from_array(
4694 group.batch.column(index),
4695 group.row_idx,
4696 )?;
4697 row.push(value);
4698 }
4699 OutputSource::Computed { projection_index } => {
4700 let expr = match &plan.projections[projection_index] {
4701 SelectProjection::Computed { expr, .. } => expr,
4702 _ => unreachable!("projection index mismatch for computed column"),
4703 };
4704 let mut context = CrossProductExpressionContext::new(
4705 group.batch.schema().as_ref(),
4706 column_lookup_map.clone(),
4707 )?;
4708 context.reset();
4709 let evaluated = self.evaluate_projection_expression(
4710 &mut context,
4711 expr,
4712 &group.batch,
4713 &FxHashMap::default(),
4714 )?;
4715 let value = llkv_plan::plan_value_from_array(&evaluated, group.row_idx)?;
4716 row.push(value);
4717 }
4718 }
4719 }
4720 rows.push(row);
4721 }
4722
4723 let fields: Vec<Field> = output_columns
4724 .into_iter()
4725 .map(|output| output.field)
4726 .collect();
4727 let schema = Arc::new(Schema::new(fields));
4728
4729 let mut batch = rows_to_record_batch(Arc::clone(&schema), &rows)?;
4730
4731 if plan.distinct && batch.num_rows() > 0 {
4732 let mut state = DistinctState::default();
4733 batch = match distinct_filter_batch(batch, &mut state)? {
4734 Some(filtered) => filtered,
4735 None => RecordBatch::new_empty(Arc::clone(&schema)),
4736 };
4737 }
4738
4739 if !plan.order_by.is_empty() && batch.num_rows() > 0 {
4740 batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
4741 }
4742
4743 Ok(SelectExecution::new_single_batch(
4744 display_name,
4745 schema,
4746 batch,
4747 ))
4748 }
4749
4750 fn infer_computed_expression_type(
4752 expr: &ScalarExpr<String>,
4753 base_schema: &Schema,
4754 column_lookup_map: &FxHashMap<String, usize>,
4755 sample_batch: &RecordBatch,
4756 ) -> Option<DataType> {
4757 use llkv_expr::expr::AggregateCall;
4758
4759 if let ScalarExpr::Aggregate(agg_call) = expr {
4761 return match agg_call {
4762 AggregateCall::CountStar
4763 | AggregateCall::Count { .. }
4764 | AggregateCall::CountNulls(_) => Some(DataType::Int64),
4765 AggregateCall::Sum { expr: agg_expr, .. }
4766 | AggregateCall::Total { expr: agg_expr, .. }
4767 | AggregateCall::Avg { expr: agg_expr, .. }
4768 | AggregateCall::Min(agg_expr)
4769 | AggregateCall::Max(agg_expr) => {
4770 if let Some(dtype) =
4772 infer_type_recursive(agg_expr, base_schema, column_lookup_map)
4773 {
4774 return Some(dtype);
4775 }
4776
4777 if let Some(col_name) = try_extract_simple_column(agg_expr) {
4779 let idx = resolve_column_name_to_index(col_name, column_lookup_map)?;
4780 Some(base_schema.field(idx).data_type().clone())
4781 } else {
4782 if sample_batch.num_rows() > 0 {
4785 let mut computed_values = Vec::new();
4786 if let Ok(value) =
4787 Self::evaluate_expr_with_plan_value_aggregates_and_row(
4788 agg_expr,
4789 &FxHashMap::default(),
4790 Some(sample_batch),
4791 Some(column_lookup_map),
4792 0,
4793 )
4794 {
4795 computed_values.push(value);
4796 if let Ok(array) = plan_values_to_arrow_array(&computed_values) {
4797 match array.data_type() {
4798 DataType::Decimal128(_, scale) => {
4800 return Some(DataType::Decimal128(38, *scale));
4801 }
4802 DataType::Null => {
4804 return Some(DataType::Float64);
4805 }
4806 other => {
4807 return Some(other.clone());
4808 }
4809 }
4810 }
4811 }
4812 }
4813 Some(DataType::Float64)
4815 }
4816 }
4817 AggregateCall::GroupConcat { .. } => Some(DataType::Utf8),
4818 };
4819 }
4820
4821 None
4824 }
4825
4826 fn build_group_by_output_columns(
4827 &self,
4828 plan: &SelectPlan,
4829 base_schema: &Schema,
4830 column_lookup_map: &FxHashMap<String, usize>,
4831 _sample_batch: &RecordBatch,
4832 ) -> ExecutorResult<Vec<OutputColumn>> {
4833 let projections = if plan.projections.is_empty() {
4834 vec![SelectProjection::AllColumns]
4835 } else {
4836 plan.projections.clone()
4837 };
4838
4839 let mut columns: Vec<OutputColumn> = Vec::new();
4840
4841 for (proj_idx, projection) in projections.iter().enumerate() {
4842 match projection {
4843 SelectProjection::AllColumns => {
4844 for (index, field) in base_schema.fields().iter().enumerate() {
4845 columns.push(OutputColumn {
4846 field: (**field).clone(),
4847 source: OutputSource::TableColumn { index },
4848 });
4849 }
4850 }
4851 SelectProjection::AllColumnsExcept { exclude } => {
4852 let exclude_lower: FxHashSet<String> = exclude
4853 .iter()
4854 .map(|name| name.to_ascii_lowercase())
4855 .collect();
4856
4857 let mut excluded_indices = FxHashSet::default();
4858 for excluded_name in &exclude_lower {
4859 if let Some(&idx) = column_lookup_map.get(excluded_name) {
4860 excluded_indices.insert(idx);
4861 }
4862 }
4863
4864 for (index, field) in base_schema.fields().iter().enumerate() {
4865 if !exclude_lower.contains(&field.name().to_ascii_lowercase())
4866 && !excluded_indices.contains(&index)
4867 {
4868 columns.push(OutputColumn {
4869 field: (**field).clone(),
4870 source: OutputSource::TableColumn { index },
4871 });
4872 }
4873 }
4874 }
4875 SelectProjection::Column { name, alias } => {
4876 let lookup_key = name.to_ascii_lowercase();
4877 let index = column_lookup_map.get(&lookup_key).ok_or_else(|| {
4878 Error::InvalidArgumentError(format!(
4879 "column '{}' not found in GROUP BY result",
4880 name
4881 ))
4882 })?;
4883 let field = base_schema.field(*index);
4884 let field = Field::new(
4885 alias.as_ref().unwrap_or(name).clone(),
4886 field.data_type().clone(),
4887 field.is_nullable(),
4888 );
4889 columns.push(OutputColumn {
4890 field,
4891 source: OutputSource::TableColumn { index: *index },
4892 });
4893 }
4894 SelectProjection::Computed { expr, alias } => {
4895 let inferred_type = Self::infer_computed_expression_type(
4899 expr,
4900 base_schema,
4901 column_lookup_map,
4902 _sample_batch,
4903 )
4904 .unwrap_or(DataType::Float64);
4905 let field = Field::new(alias.clone(), inferred_type, true);
4906 columns.push(OutputColumn {
4907 field,
4908 source: OutputSource::Computed {
4909 projection_index: proj_idx,
4910 },
4911 });
4912 }
4913 }
4914 }
4915
4916 if columns.is_empty() {
4917 for (index, field) in base_schema.fields().iter().enumerate() {
4918 columns.push(OutputColumn {
4919 field: (**field).clone(),
4920 source: OutputSource::TableColumn { index },
4921 });
4922 }
4923 }
4924
4925 Ok(columns)
4926 }
4927
4928 fn project_record_batch(
4929 &self,
4930 batch: &RecordBatch,
4931 projections: &[SelectProjection],
4932 lookup: &FxHashMap<String, usize>,
4933 scalar_lookup: &FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery>,
4934 ) -> ExecutorResult<RecordBatch> {
4935 if projections.is_empty() {
4936 return Ok(batch.clone());
4937 }
4938
4939 let schema = batch.schema();
4940 let mut selected_fields: Vec<Arc<Field>> = Vec::new();
4941 let mut selected_columns: Vec<ArrayRef> = Vec::new();
4942 let mut expr_context: Option<CrossProductExpressionContext> = None;
4943
4944 for proj in projections {
4945 match proj {
4946 SelectProjection::AllColumns => {
4947 selected_fields = schema.fields().iter().cloned().collect();
4948 selected_columns = batch.columns().to_vec();
4949 break;
4950 }
4951 SelectProjection::AllColumnsExcept { exclude } => {
4952 let exclude_lower: FxHashSet<String> = exclude
4953 .iter()
4954 .map(|name| name.to_ascii_lowercase())
4955 .collect();
4956
4957 let mut excluded_indices = FxHashSet::default();
4958 for excluded_name in &exclude_lower {
4959 if let Some(&idx) = lookup.get(excluded_name) {
4960 excluded_indices.insert(idx);
4961 }
4962 }
4963
4964 for (idx, field) in schema.fields().iter().enumerate() {
4965 let column_name = field.name().to_ascii_lowercase();
4966 if !exclude_lower.contains(&column_name) && !excluded_indices.contains(&idx)
4967 {
4968 selected_fields.push(Arc::clone(field));
4969 selected_columns.push(batch.column(idx).clone());
4970 }
4971 }
4972 break;
4973 }
4974 SelectProjection::Column { name, alias } => {
4975 let normalized = name.to_ascii_lowercase();
4976 let column_index = lookup.get(&normalized).ok_or_else(|| {
4977 Error::InvalidArgumentError(format!(
4978 "column '{}' not found in projection",
4979 name
4980 ))
4981 })?;
4982 let field = schema.field(*column_index);
4983 let output_field = Arc::new(Field::new(
4984 alias.as_ref().unwrap_or_else(|| field.name()),
4985 field.data_type().clone(),
4986 field.is_nullable(),
4987 ));
4988 selected_fields.push(output_field);
4989 selected_columns.push(batch.column(*column_index).clone());
4990 }
4991 SelectProjection::Computed { expr, alias } => {
4992 if expr_context.is_none() {
4993 expr_context = Some(CrossProductExpressionContext::new(
4994 schema.as_ref(),
4995 lookup.clone(),
4996 )?);
4997 }
4998 let context = expr_context
4999 .as_mut()
5000 .expect("projection context must be initialized");
5001 context.reset();
5002 let evaluated =
5003 self.evaluate_projection_expression(context, expr, batch, scalar_lookup)?;
5004 let field = Arc::new(Field::new(
5005 alias.clone(),
5006 evaluated.data_type().clone(),
5007 true,
5008 ));
5009 selected_fields.push(field);
5010 selected_columns.push(evaluated);
5011 }
5012 }
5013 }
5014
5015 let projected_schema = Arc::new(Schema::new(selected_fields));
5016 RecordBatch::try_new(projected_schema, selected_columns)
5017 .map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))
5018 }
5019
    /// Executes GROUP BY when the plan contains aggregate calls (in computed
    /// projections and/or HAVING).
    ///
    /// Rows are bucketed by group key. For each group, its member rows are
    /// gathered into a dedicated batch (via Arrow `take`/`concat`) and every
    /// collected aggregate is accumulated over that batch. Output rows combine
    /// representative-row column values with the per-group aggregate results;
    /// HAVING, DISTINCT, and ORDER BY are then applied to the final batch.
    ///
    /// # Errors
    /// Fails when a GROUP BY or aggregate column is missing, or when
    /// expression evaluation, accumulation, or batch construction fails.
    fn execute_group_by_with_aggregates(
        &self,
        display_name: String,
        plan: SelectPlan,
        base_schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
        column_lookup_map: FxHashMap<String, usize>,
    ) -> ExecutorResult<SelectExecution<P>> {
        use llkv_expr::expr::AggregateCall;

        // Resolve GROUP BY column names (case-insensitively) to input indices.
        let mut key_indices = Vec::with_capacity(plan.group_by.len());
        for column in &plan.group_by {
            let key = column.to_ascii_lowercase();
            let index = column_lookup_map.get(&key).ok_or_else(|| {
                Error::InvalidArgumentError(format!(
                    "column '{}' not found in GROUP BY input",
                    column
                ))
            })?;
            key_indices.push(*index);
        }

        // Collect every aggregate call referenced by computed projections and
        // by the HAVING clause; each entry is keyed by a stable string.
        let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
        for proj in &plan.projections {
            if let SelectProjection::Computed { expr, .. } = proj {
                Self::collect_aggregates(expr, &mut aggregate_specs);
            }
        }

        if let Some(having_expr) = &plan.having {
            Self::collect_aggregates_from_predicate(having_expr, &mut aggregate_specs);
        }

        // Bucket rows by key. The first row seen for a key becomes the group's
        // representative; all member (batch, row) locations are recorded.
        let mut group_index: FxHashMap<Vec<GroupKeyValue>, usize> = FxHashMap::default();
        let mut group_states: Vec<GroupAggregateState> = Vec::new();

        for (batch_idx, batch) in batches.iter().enumerate() {
            for row_idx in 0..batch.num_rows() {
                let key = build_group_key(batch, row_idx, &key_indices)?;

                if let Some(&group_idx) = group_index.get(&key) {
                    group_states[group_idx]
                        .row_locations
                        .push((batch_idx, row_idx));
                } else {
                    let group_idx = group_states.len();
                    group_index.insert(key, group_idx);
                    group_states.push(GroupAggregateState {
                        representative_batch_idx: batch_idx,
                        representative_row: row_idx,
                        row_locations: vec![(batch_idx, row_idx)],
                    });
                }
            }
        }

        // One map of aggregate-key -> value per group, filled in below.
        let mut group_aggregate_values: Vec<FxHashMap<String, PlanValue>> =
            Vec::with_capacity(group_states.len());

        for group_state in &group_states {
            tracing::debug!(
                "[GROUP BY] aggregate group rows={:?}",
                group_state.row_locations
            );
            // Gather all member rows of this group into a single batch.
            let group_batch = {
                let representative_batch = &batches[group_state.representative_batch_idx];
                let schema = representative_batch.schema();

                // Group row indices by their source batch so each batch is
                // `take`n at most once per column. Linear find is fine: the
                // list stays short (one entry per source batch).
                let mut per_batch_indices: Vec<(usize, Vec<u64>)> = Vec::new();
                for &(batch_idx, row_idx) in &group_state.row_locations {
                    if let Some((_, indices)) = per_batch_indices
                        .iter_mut()
                        .find(|(idx, _)| *idx == batch_idx)
                    {
                        indices.push(row_idx as u64);
                    } else {
                        per_batch_indices.push((batch_idx, vec![row_idx as u64]));
                    }
                }

                let mut row_index_arrays: Vec<(usize, ArrayRef)> =
                    Vec::with_capacity(per_batch_indices.len());
                for (batch_idx, indices) in per_batch_indices {
                    let index_array: ArrayRef = Arc::new(arrow::array::UInt64Array::from(indices));
                    row_index_arrays.push((batch_idx, index_array));
                }

                let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());

                // Per column: take the group's rows from each source batch,
                // concatenating when the group spans multiple batches.
                for col_idx in 0..schema.fields().len() {
                    let column_array = if row_index_arrays.len() == 1 {
                        // Single-batch fast path: no concat needed.
                        let (batch_idx, indices) = &row_index_arrays[0];
                        let source_array = batches[*batch_idx].column(col_idx);
                        arrow::compute::take(source_array.as_ref(), indices.as_ref(), None)?
                    } else {
                        let mut partial_arrays: Vec<ArrayRef> =
                            Vec::with_capacity(row_index_arrays.len());
                        for (batch_idx, indices) in &row_index_arrays {
                            let source_array = batches[*batch_idx].column(col_idx);
                            let taken = arrow::compute::take(
                                source_array.as_ref(),
                                indices.as_ref(),
                                None,
                            )?;
                            partial_arrays.push(taken);
                        }
                        let slices: Vec<&dyn arrow::array::Array> =
                            partial_arrays.iter().map(|arr| arr.as_ref()).collect();
                        arrow::compute::concat(&slices)?
                    };
                    arrays.push(column_array);
                }

                let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
                tracing::debug!("[GROUP BY] group batch rows={}", batch.num_rows());
                batch
            };

            let mut aggregate_values: FxHashMap<String, PlanValue> = FxHashMap::default();

            // `working_batch` grows temporary columns for aggregates whose
            // argument is a computed expression rather than a plain column.
            let mut working_batch = group_batch.clone();
            let mut next_temp_col_idx = working_batch.num_columns();

            for (key, agg_call) in &aggregate_specs {
                // Determine which column the accumulator should read, and the
                // value type to configure it with. COUNT(*) reads no column.
                let (projection_idx, value_type) = match agg_call {
                    AggregateCall::CountStar => (None, None),
                    AggregateCall::Count { expr, .. }
                    | AggregateCall::Sum { expr, .. }
                    | AggregateCall::Total { expr, .. }
                    | AggregateCall::Avg { expr, .. }
                    | AggregateCall::Min(expr)
                    | AggregateCall::Max(expr)
                    | AggregateCall::CountNulls(expr)
                    | AggregateCall::GroupConcat { expr, .. } => {
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            // Plain column argument: read it directly.
                            let idx = resolve_column_name_to_index(col_name, &column_lookup_map)
                                .ok_or_else(|| {
                                    Error::InvalidArgumentError(format!(
                                        "column '{}' not found for aggregate",
                                        col_name
                                    ))
                                })?;
                            let field_type = working_batch.schema().field(idx).data_type().clone();
                            (Some(idx), Some(field_type))
                        } else {
                            // Computed argument: evaluate it row-by-row and
                            // append the result as a temporary column.
                            let mut computed_values = Vec::with_capacity(working_batch.num_rows());
                            for row_idx in 0..working_batch.num_rows() {
                                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                                    expr,
                                    &FxHashMap::default(),
                                    Some(&working_batch),
                                    Some(&column_lookup_map),
                                    row_idx,
                                )?;
                                computed_values.push(value);
                            }

                            let computed_array = plan_values_to_arrow_array(&computed_values)?;
                            let computed_type = computed_array.data_type().clone();

                            let mut new_columns: Vec<ArrayRef> = working_batch.columns().to_vec();
                            new_columns.push(computed_array);

                            let temp_field = Arc::new(Field::new(
                                format!("__temp_agg_expr_{}", next_temp_col_idx),
                                computed_type.clone(),
                                true,
                            ));
                            let mut new_fields: Vec<Arc<Field>> =
                                working_batch.schema().fields().iter().cloned().collect();
                            new_fields.push(temp_field);
                            let new_schema = Arc::new(Schema::new(new_fields));

                            working_batch = RecordBatch::try_new(new_schema, new_columns)?;

                            let col_idx = next_temp_col_idx;
                            next_temp_col_idx += 1;
                            (Some(col_idx), Some(computed_type))
                        }
                    }
                };

                let spec = Self::build_aggregate_spec_for_cross_product(
                    agg_call,
                    key.clone(),
                    value_type.clone(),
                )?;

                let mut state = llkv_aggregate::AggregateState {
                    alias: key.clone(),
                    accumulator: llkv_aggregate::AggregateAccumulator::new_with_projection_index(
                        &spec,
                        projection_idx,
                        None,
                    )?,
                    override_value: None,
                };

                // Single-shot accumulation over the group's gathered rows.
                state.update(&working_batch)?;

                // Finalize yields a one-row array; extract that scalar.
                let (_field, array) = state.finalize()?;
                let value = llkv_plan::plan_value_from_array(&array, 0)?;
                tracing::debug!(
                    "[GROUP BY] aggregate result key={:?} value={:?}",
                    key,
                    value
                );
                aggregate_values.insert(key.clone(), value);
            }

            group_aggregate_values.push(aggregate_values);
        }

        // NOTE(review): `unwrap_or` eagerly constructs the empty fallback
        // batch even when `batches` is non-empty; an `unwrap_or_else`/match
        // form would avoid that (cheap here, but clippy::or_fun_call).
        let output_columns = self.build_group_by_output_columns(
            &plan,
            base_schema.as_ref(),
            &column_lookup_map,
            batches
                .first()
                .unwrap_or(&RecordBatch::new_empty(Arc::clone(&base_schema))),
        )?;

        let mut rows: Vec<Vec<PlanValue>> = Vec::with_capacity(group_states.len());

        // Build one output row per group: plain columns come from the
        // representative row; computed columns see the aggregate results.
        for (group_idx, group_state) in group_states.iter().enumerate() {
            let aggregate_values = &group_aggregate_values[group_idx];
            let representative_batch = &batches[group_state.representative_batch_idx];

            let mut row: Vec<PlanValue> = Vec::with_capacity(output_columns.len());
            for output in &output_columns {
                match output.source {
                    OutputSource::TableColumn { index } => {
                        let value = llkv_plan::plan_value_from_array(
                            representative_batch.column(index),
                            group_state.representative_row,
                        )?;
                        row.push(value);
                    }
                    OutputSource::Computed { projection_index } => {
                        let expr = match &plan.projections[projection_index] {
                            SelectProjection::Computed { expr, .. } => expr,
                            _ => unreachable!("projection index mismatch for computed column"),
                        };
                        let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            expr,
                            aggregate_values,
                            Some(representative_batch),
                            Some(&column_lookup_map),
                            group_state.representative_row,
                        )?;
                        row.push(value);
                    }
                }
            }
            rows.push(row);
        }

        // HAVING filters whole groups; three-valued logic means only an
        // explicit `Some(true)` keeps the row.
        let filtered_rows = if let Some(having) = &plan.having {
            let mut filtered = Vec::new();
            for (row_idx, row) in rows.iter().enumerate() {
                let aggregate_values = &group_aggregate_values[row_idx];
                let group_state = &group_states[row_idx];
                let representative_batch = &batches[group_state.representative_batch_idx];
                let passes = Self::evaluate_having_expr(
                    having,
                    aggregate_values,
                    representative_batch,
                    &column_lookup_map,
                    group_state.representative_row,
                )?;
                if matches!(passes, Some(true)) {
                    filtered.push(row.clone());
                }
            }
            filtered
        } else {
            rows
        };

        let fields: Vec<Field> = output_columns
            .into_iter()
            .map(|output| output.field)
            .collect();
        let schema = Arc::new(Schema::new(fields));

        let mut batch = rows_to_record_batch(Arc::clone(&schema), &filtered_rows)?;

        // DISTINCT applies to the final projected rows.
        if plan.distinct && batch.num_rows() > 0 {
            let mut state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        if !plan.order_by.is_empty() && batch.num_rows() > 0 {
            batch = sort_record_batch_with_order(&schema, &batch, &plan.order_by)?;
        }

        Ok(SelectExecution::new_single_batch(
            display_name,
            schema,
            batch,
        ))
    }
5349
    /// Executes a purely-aggregate SELECT (no GROUP BY): translates each
    /// [`AggregateExpr`] into an [`AggregateSpec`], streams the table scan
    /// once, feeding every batch to all accumulators, and returns a single
    /// one-row result batch.
    ///
    /// When no filter and no row-id filter is present, `COUNT(*)` is answered
    /// from the table's cached row count instead of scanning.
    ///
    /// # Errors
    /// Fails on unknown columns, unsupported aggregate/type combinations,
    /// unsupported subquery forms in the filter, or scan/accumulation errors.
    fn execute_aggregates(
        &self,
        table: Arc<ExecutorTable<P>>,
        display_name: String,
        plan: SelectPlan,
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<SelectExecution<P>> {
        let table_ref = table.as_ref();
        let distinct = plan.distinct;
        // Translate planner aggregates into executor specs, validating column
        // references and input types along the way.
        let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
        for aggregate in plan.aggregates {
            match aggregate {
                AggregateExpr::CountStar { alias, distinct } => {
                    // COUNT(*): `field_id: None` marks the no-column form.
                    specs.push(AggregateSpec {
                        alias,
                        kind: AggregateKind::Count {
                            field_id: None,
                            distinct,
                        },
                    });
                }
                AggregateExpr::Column {
                    column,
                    alias,
                    function,
                    distinct,
                } => {
                    let col = table_ref.schema.resolve(&column).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "unknown column '{}' in aggregate",
                            column
                        ))
                    })?;

                    let kind = match function {
                        AggregateFunction::Count => AggregateKind::Count {
                            field_id: Some(col.field_id),
                            distinct,
                        },
                        // SUM/TOTAL/MIN/MAX accept Int64 or Float64 inputs;
                        // validation returns the concrete input type.
                        AggregateFunction::SumInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(col.data_type.clone()),
                                "SUM",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Sum {
                                field_id: col.field_id,
                                data_type: input_type,
                                distinct,
                            }
                        }
                        AggregateFunction::TotalInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(col.data_type.clone()),
                                "TOTAL",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Total {
                                field_id: col.field_id,
                                data_type: input_type,
                                distinct,
                            }
                        }
                        AggregateFunction::MinInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(col.data_type.clone()),
                                "MIN",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Min {
                                field_id: col.field_id,
                                data_type: input_type,
                            }
                        }
                        AggregateFunction::MaxInt64 => {
                            let input_type = Self::validate_aggregate_type(
                                Some(col.data_type.clone()),
                                "MAX",
                                &[DataType::Int64, DataType::Float64],
                            )?;
                            AggregateKind::Max {
                                field_id: col.field_id,
                                data_type: input_type,
                            }
                        }
                        AggregateFunction::CountNulls => {
                            if distinct {
                                return Err(Error::InvalidArgumentError(
                                    "DISTINCT is not supported for COUNT_NULLS".into(),
                                ));
                            }
                            AggregateKind::CountNulls {
                                field_id: col.field_id,
                            }
                        }
                        AggregateFunction::GroupConcat => AggregateKind::GroupConcat {
                            field_id: col.field_id,
                            distinct,
                            // Default SQL GROUP_CONCAT separator.
                            separator: ",".to_string(),
                        },
                    };
                    specs.push(AggregateSpec { alias, kind });
                }
            }
        }

        if specs.is_empty() {
            return Err(Error::InvalidArgumentError(
                "aggregate query requires at least one aggregate expression".into(),
            ));
        }

        // Remember whether a WHERE clause existed: it disables the COUNT(*)
        // shortcut below.
        let had_filter = plan.filter.is_some();
        let filter_expr = match &plan.filter {
            Some(filter_wrapper) => {
                if !filter_wrapper.subqueries.is_empty() {
                    return Err(Error::InvalidArgumentError(
                        "EXISTS subqueries not yet implemented in aggregate queries".into(),
                    ));
                }
                let mut translated = crate::translation::expression::translate_predicate(
                    filter_wrapper.predicate.clone(),
                    table.schema.as_ref(),
                    |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
                )?;

                // Scalar subqueries embedded in the filter are evaluated once
                // up front and substituted as literals.
                let mut filter_scalar_ids = FxHashSet::default();
                collect_predicate_scalar_subquery_ids(&translated, &mut filter_scalar_ids);

                if !filter_scalar_ids.is_empty() {
                    let filter_scalar_lookup: FxHashMap<SubqueryId, &llkv_plan::ScalarSubquery> =
                        plan.scalar_subqueries
                            .iter()
                            .filter(|subquery| filter_scalar_ids.contains(&subquery.id))
                            .map(|subquery| (subquery.id, subquery))
                            .collect();

                    // Uncorrelated evaluation: an empty schema/batch serves as
                    // the (unused) outer-row context.
                    let base_schema = Arc::new(Schema::new(Vec::<Field>::new()));
                    let base_lookup = FxHashMap::default();
                    let mut context =
                        CrossProductExpressionContext::new(base_schema.as_ref(), base_lookup)?;
                    let empty_batch =
                        RecordBatch::new_empty(Arc::new(Schema::new(Vec::<Field>::new())));

                    let mut scalar_literals: FxHashMap<SubqueryId, Literal> = FxHashMap::default();
                    for (subquery_id, subquery) in filter_scalar_lookup.iter() {
                        let literal = self.evaluate_scalar_subquery_literal(
                            &mut context,
                            subquery,
                            &empty_batch,
                            0,
                        )?;
                        scalar_literals.insert(*subquery_id, literal);
                    }

                    translated = rewrite_predicate_scalar_subqueries(translated, &scalar_literals)?;
                }

                translated
            }
            None => {
                // No WHERE clause: scan everything via a full-table filter on
                // the first column.
                let field_id = table.schema.first_field_id().ok_or_else(|| {
                    Error::InvalidArgumentError(
                        "table has no columns; cannot perform aggregate scan".into(),
                    )
                })?;
                crate::translation::expression::full_table_scan_filter(field_id)
            }
        };

        // Project only the columns that some aggregate actually reads;
        // `spec_to_projection` maps each spec to its projection slot (None
        // for COUNT(*), which reads no column).
        let mut projections = Vec::new();
        let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(specs.len());

        for spec in &specs {
            if let Some(field_id) = spec.kind.field_id() {
                let proj_idx = projections.len();
                spec_to_projection.push(Some(proj_idx));
                projections.push(ScanProjection::from(StoreProjection::with_alias(
                    LogicalFieldId::for_user(table.table.table_id(), field_id),
                    table
                        .schema
                        .column_by_field_id(field_id)
                        .map(|c| c.name.clone())
                        .unwrap_or_else(|| format!("col{field_id}")),
                )));
            } else {
                spec_to_projection.push(None);
            }
        }

        // The scan needs at least one projection even if only COUNT(*) was
        // requested; fall back to the first column.
        if projections.is_empty() {
            let field_id = table.schema.first_field_id().ok_or_else(|| {
                Error::InvalidArgumentError(
                    "table has no columns; cannot perform aggregate scan".into(),
                )
            })?;
            projections.push(ScanProjection::from(StoreProjection::with_alias(
                LogicalFieldId::for_user(table.table.table_id(), field_id),
                table
                    .schema
                    .column_by_field_id(field_id)
                    .map(|c| c.name.clone())
                    .unwrap_or_else(|| format!("col{field_id}")),
            )));
        }

        let options = ScanStreamOptions {
            include_nulls: true,
            order: None,
            row_id_filter: row_filter.clone(),
        };

        let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
        // COUNT(*) shortcut: with no WHERE clause and no row-id filter, the
        // cached total row count answers COUNT(*) without touching data.
        let mut count_star_override: Option<i64> = None;
        if !had_filter && row_filter.is_none() {
            let total_rows = table.total_rows.load(Ordering::SeqCst);
            tracing::debug!(
                "[AGGREGATE] Using COUNT(*) shortcut: total_rows={}",
                total_rows
            );
            if total_rows > i64::MAX as u64 {
                return Err(Error::InvalidArgumentError(
                    "COUNT(*) result exceeds supported range".into(),
                ));
            }
            count_star_override = Some(total_rows as i64);
        } else {
            tracing::debug!(
                "[AGGREGATE] NOT using COUNT(*) shortcut: had_filter={}, has_row_filter={}",
                had_filter,
                row_filter.is_some()
            );
        }

        for (idx, spec) in specs.iter().enumerate() {
            states.push(AggregateState {
                alias: spec.alias.clone(),
                accumulator: AggregateAccumulator::new_with_projection_index(
                    spec,
                    spec_to_projection[idx],
                    count_star_override,
                )?,
                // Only the COUNT(*) form (field_id: None) takes the override.
                override_value: match &spec.kind {
                    AggregateKind::Count { field_id: None, .. } => {
                        tracing::debug!(
                            "[AGGREGATE] CountStar override_value={:?}",
                            count_star_override
                        );
                        count_star_override
                    }
                    _ => None,
                },
            });
        }

        // Stream the scan, feeding each batch to every accumulator. Errors in
        // the callback are captured and re-raised after the scan returns.
        let mut error: Option<Error> = None;
        match table.table.scan_stream(
            projections,
            &filter_expr,
            ScanStreamOptions {
                row_id_filter: row_filter.clone(),
                ..options
            },
            |batch| {
                if error.is_some() {
                    return;
                }
                for state in &mut states {
                    if let Err(err) = state.update(&batch) {
                        error = Some(err);
                        return;
                    }
                }
            },
        ) {
            Ok(()) => {}
            Err(llkv_result::Error::NotFound) => {
                // Missing column data is treated as an empty scan: the
                // accumulators simply see no batches.
            }
            Err(err) => return Err(err),
        }
        if let Some(err) = error {
            return Err(err);
        }

        // Finalize each accumulator into one (field, single-value array) pair.
        let mut fields = Vec::with_capacity(states.len());
        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(states.len());
        for state in states {
            let (field, array) = state.finalize()?;
            fields.push(field);
            arrays.push(array);
        }

        let schema = Arc::new(Schema::new(fields));
        let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;

        if distinct {
            let mut state = DistinctState::default();
            batch = match distinct_filter_batch(batch, &mut state)? {
                Some(filtered) => filtered,
                None => RecordBatch::new_empty(Arc::clone(&schema)),
            };
        }

        let schema = batch.schema();

        Ok(SelectExecution::new_single_batch(
            display_name,
            schema,
            batch,
        ))
    }
5673
5674 fn execute_computed_aggregates(
5677 &self,
5678 table: Arc<ExecutorTable<P>>,
5679 display_name: String,
5680 plan: SelectPlan,
5681 row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
5682 ) -> ExecutorResult<SelectExecution<P>> {
5683 use arrow::array::Int64Array;
5684 use llkv_expr::expr::AggregateCall;
5685
5686 let table_ref = table.as_ref();
5687 let distinct = plan.distinct;
5688
5689 let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
5691 for proj in &plan.projections {
5692 if let SelectProjection::Computed { expr, .. } = proj {
5693 Self::collect_aggregates(expr, &mut aggregate_specs);
5694 }
5695 }
5696
5697 let filter_predicate = plan
5699 .filter
5700 .as_ref()
5701 .map(|wrapper| {
5702 if !wrapper.subqueries.is_empty() {
5703 return Err(Error::InvalidArgumentError(
5704 "EXISTS subqueries not yet implemented with aggregates".into(),
5705 ));
5706 }
5707 Ok(wrapper.predicate.clone())
5708 })
5709 .transpose()?;
5710
5711 let computed_aggregates = self.compute_aggregate_values(
5712 table.clone(),
5713 &filter_predicate,
5714 &aggregate_specs,
5715 row_filter.clone(),
5716 )?;
5717
5718 let mut fields = Vec::with_capacity(plan.projections.len());
5720 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
5721
5722 for proj in &plan.projections {
5723 match proj {
5724 SelectProjection::AllColumns | SelectProjection::AllColumnsExcept { .. } => {
5725 return Err(Error::InvalidArgumentError(
5726 "Wildcard projections not supported with computed aggregates".into(),
5727 ));
5728 }
5729 SelectProjection::Column { name, alias } => {
5730 let col = table_ref.schema.resolve(name).ok_or_else(|| {
5731 Error::InvalidArgumentError(format!("unknown column '{}'", name))
5732 })?;
5733 let field_name = alias.as_ref().unwrap_or(name);
5734 fields.push(arrow::datatypes::Field::new(
5735 field_name,
5736 col.data_type.clone(),
5737 col.nullable,
5738 ));
5739 return Err(Error::InvalidArgumentError(
5742 "Regular columns not supported in aggregate queries without GROUP BY"
5743 .into(),
5744 ));
5745 }
5746 SelectProjection::Computed { expr, alias } => {
5747 if let ScalarExpr::Aggregate(agg) = expr {
5749 let key = format!("{:?}", agg);
5750 if let Some(agg_value) = computed_aggregates.get(&key) {
5751 match agg_value {
5752 AggregateValue::Null => {
5753 fields.push(arrow::datatypes::Field::new(
5754 alias,
5755 DataType::Int64,
5756 true,
5757 ));
5758 arrays
5759 .push(Arc::new(Int64Array::from(vec![None::<i64>]))
5760 as ArrayRef);
5761 }
5762 AggregateValue::Int64(v) => {
5763 fields.push(arrow::datatypes::Field::new(
5764 alias,
5765 DataType::Int64,
5766 true,
5767 ));
5768 arrays.push(
5769 Arc::new(Int64Array::from(vec![Some(*v)])) as ArrayRef
5770 );
5771 }
5772 AggregateValue::Float64(v) => {
5773 fields.push(arrow::datatypes::Field::new(
5774 alias,
5775 DataType::Float64,
5776 true,
5777 ));
5778 arrays
5779 .push(Arc::new(Float64Array::from(vec![Some(*v)]))
5780 as ArrayRef);
5781 }
5782 AggregateValue::Decimal128 { value, scale } => {
5783 let precision = if *value == 0 {
5785 1
5786 } else {
5787 (*value).abs().to_string().len() as u8
5788 };
5789 fields.push(arrow::datatypes::Field::new(
5790 alias,
5791 DataType::Decimal128(precision, *scale),
5792 true,
5793 ));
5794 let array = Decimal128Array::from(vec![Some(*value)])
5795 .with_precision_and_scale(precision, *scale)
5796 .map_err(|e| {
5797 Error::Internal(format!("invalid Decimal128: {}", e))
5798 })?;
5799 arrays.push(Arc::new(array) as ArrayRef);
5800 }
5801 AggregateValue::String(s) => {
5802 fields.push(arrow::datatypes::Field::new(
5803 alias,
5804 DataType::Utf8,
5805 true,
5806 ));
5807 arrays
5808 .push(Arc::new(StringArray::from(vec![Some(s.as_str())]))
5809 as ArrayRef);
5810 }
5811 }
5812 continue;
5813 }
5814 }
5815
5816 let value = Self::evaluate_expr_with_aggregates(expr, &computed_aggregates)?;
5818
5819 fields.push(arrow::datatypes::Field::new(alias, DataType::Int64, true));
5820
5821 let array = Arc::new(Int64Array::from(vec![value])) as ArrayRef;
5822 arrays.push(array);
5823 }
5824 }
5825 }
5826
5827 let schema = Arc::new(Schema::new(fields));
5828 let mut batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
5829
5830 if distinct {
5831 let mut state = DistinctState::default();
5832 batch = match distinct_filter_batch(batch, &mut state)? {
5833 Some(filtered) => filtered,
5834 None => RecordBatch::new_empty(Arc::clone(&schema)),
5835 };
5836 }
5837
5838 let schema = batch.schema();
5839
5840 Ok(SelectExecution::new_single_batch(
5841 display_name,
5842 schema,
5843 batch,
5844 ))
5845 }
5846
5847 fn build_aggregate_spec_for_cross_product(
5850 agg_call: &llkv_expr::expr::AggregateCall<String>,
5851 alias: String,
5852 data_type: Option<DataType>,
5853 ) -> ExecutorResult<llkv_aggregate::AggregateSpec> {
5854 use llkv_expr::expr::AggregateCall;
5855
5856 let kind = match agg_call {
5857 AggregateCall::CountStar => llkv_aggregate::AggregateKind::Count {
5858 field_id: None,
5859 distinct: false,
5860 },
5861 AggregateCall::Count { distinct, .. } => llkv_aggregate::AggregateKind::Count {
5862 field_id: Some(0),
5863 distinct: *distinct,
5864 },
5865 AggregateCall::Sum { distinct, .. } => llkv_aggregate::AggregateKind::Sum {
5866 field_id: 0,
5867 data_type: Self::validate_aggregate_type(
5868 data_type.clone(),
5869 "SUM",
5870 &[DataType::Int64, DataType::Float64],
5871 )?,
5872 distinct: *distinct,
5873 },
5874 AggregateCall::Total { distinct, .. } => llkv_aggregate::AggregateKind::Total {
5875 field_id: 0,
5876 data_type: Self::validate_aggregate_type(
5877 data_type.clone(),
5878 "TOTAL",
5879 &[DataType::Int64, DataType::Float64],
5880 )?,
5881 distinct: *distinct,
5882 },
5883 AggregateCall::Avg { distinct, .. } => llkv_aggregate::AggregateKind::Avg {
5884 field_id: 0,
5885 data_type: Self::validate_aggregate_type(
5886 data_type.clone(),
5887 "AVG",
5888 &[DataType::Int64, DataType::Float64],
5889 )?,
5890 distinct: *distinct,
5891 },
5892 AggregateCall::Min(_) => llkv_aggregate::AggregateKind::Min {
5893 field_id: 0,
5894 data_type: Self::validate_aggregate_type(
5895 data_type.clone(),
5896 "MIN",
5897 &[DataType::Int64, DataType::Float64],
5898 )?,
5899 },
5900 AggregateCall::Max(_) => llkv_aggregate::AggregateKind::Max {
5901 field_id: 0,
5902 data_type: Self::validate_aggregate_type(
5903 data_type.clone(),
5904 "MAX",
5905 &[DataType::Int64, DataType::Float64],
5906 )?,
5907 },
5908 AggregateCall::CountNulls(_) => {
5909 llkv_aggregate::AggregateKind::CountNulls { field_id: 0 }
5910 }
5911 AggregateCall::GroupConcat {
5912 distinct,
5913 separator,
5914 ..
5915 } => llkv_aggregate::AggregateKind::GroupConcat {
5916 field_id: 0,
5917 distinct: *distinct,
5918 separator: separator.clone().unwrap_or_else(|| ",".to_string()),
5919 },
5920 };
5921
5922 Ok(llkv_aggregate::AggregateSpec { alias, kind })
5923 }
5924
5925 fn validate_aggregate_type(
5937 data_type: Option<DataType>,
5938 func_name: &str,
5939 allowed: &[DataType],
5940 ) -> ExecutorResult<DataType> {
5941 let dt = data_type.ok_or_else(|| {
5942 Error::Internal(format!(
5943 "missing input type metadata for {func_name} aggregate"
5944 ))
5945 })?;
5946
5947 if matches!(func_name, "SUM" | "AVG" | "TOTAL" | "MIN" | "MAX") {
5950 match dt {
5951 DataType::Int64 | DataType::Float64 | DataType::Decimal128(_, _) => Ok(dt),
5953
5954 DataType::Utf8 | DataType::Boolean | DataType::Date32 => Ok(DataType::Float64),
5957
5958 DataType::Null => Ok(DataType::Float64),
5961
5962 _ => Err(Error::InvalidArgumentError(format!(
5963 "{func_name} aggregate not supported for column type {:?}",
5964 dt
5965 ))),
5966 }
5967 } else {
5968 if allowed.iter().any(|candidate| candidate == &dt) {
5970 Ok(dt)
5971 } else {
5972 Err(Error::InvalidArgumentError(format!(
5973 "{func_name} aggregate not supported for column type {:?}",
5974 dt
5975 )))
5976 }
5977 }
5978 }
5979
5980 fn collect_aggregates(
5982 expr: &ScalarExpr<String>,
5983 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
5984 ) {
5985 match expr {
5986 ScalarExpr::Aggregate(agg) => {
5987 let key = format!("{:?}", agg);
5989 if !aggregates.iter().any(|(k, _)| k == &key) {
5990 aggregates.push((key, agg.clone()));
5991 }
5992 }
5993 ScalarExpr::Binary { left, right, .. } => {
5994 Self::collect_aggregates(left, aggregates);
5995 Self::collect_aggregates(right, aggregates);
5996 }
5997 ScalarExpr::Compare { left, right, .. } => {
5998 Self::collect_aggregates(left, aggregates);
5999 Self::collect_aggregates(right, aggregates);
6000 }
6001 ScalarExpr::GetField { base, .. } => {
6002 Self::collect_aggregates(base, aggregates);
6003 }
6004 ScalarExpr::Cast { expr, .. } => {
6005 Self::collect_aggregates(expr, aggregates);
6006 }
6007 ScalarExpr::Not(expr) => {
6008 Self::collect_aggregates(expr, aggregates);
6009 }
6010 ScalarExpr::IsNull { expr, .. } => {
6011 Self::collect_aggregates(expr, aggregates);
6012 }
6013 ScalarExpr::Case {
6014 operand,
6015 branches,
6016 else_expr,
6017 } => {
6018 if let Some(inner) = operand.as_deref() {
6019 Self::collect_aggregates(inner, aggregates);
6020 }
6021 for (when_expr, then_expr) in branches {
6022 Self::collect_aggregates(when_expr, aggregates);
6023 Self::collect_aggregates(then_expr, aggregates);
6024 }
6025 if let Some(inner) = else_expr.as_deref() {
6026 Self::collect_aggregates(inner, aggregates);
6027 }
6028 }
6029 ScalarExpr::Coalesce(items) => {
6030 for item in items {
6031 Self::collect_aggregates(item, aggregates);
6032 }
6033 }
6034 ScalarExpr::Column(_) | ScalarExpr::Literal(_) | ScalarExpr::Random => {}
6035 ScalarExpr::ScalarSubquery(_) => {}
6036 }
6037 }
6038
6039 fn collect_aggregates_from_predicate(
6041 expr: &llkv_expr::expr::Expr<String>,
6042 aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
6043 ) {
6044 match expr {
6045 llkv_expr::expr::Expr::Compare { left, right, .. } => {
6046 Self::collect_aggregates(left, aggregates);
6047 Self::collect_aggregates(right, aggregates);
6048 }
6049 llkv_expr::expr::Expr::And(exprs) | llkv_expr::expr::Expr::Or(exprs) => {
6050 for e in exprs {
6051 Self::collect_aggregates_from_predicate(e, aggregates);
6052 }
6053 }
6054 llkv_expr::expr::Expr::Not(inner) => {
6055 Self::collect_aggregates_from_predicate(inner, aggregates);
6056 }
6057 llkv_expr::expr::Expr::InList {
6058 expr: test_expr,
6059 list,
6060 ..
6061 } => {
6062 Self::collect_aggregates(test_expr, aggregates);
6063 for item in list {
6064 Self::collect_aggregates(item, aggregates);
6065 }
6066 }
6067 llkv_expr::expr::Expr::IsNull { expr, .. } => {
6068 Self::collect_aggregates(expr, aggregates);
6069 }
6070 llkv_expr::expr::Expr::Literal(_) => {}
6071 llkv_expr::expr::Expr::Pred(_) => {}
6072 llkv_expr::expr::Expr::Exists(_) => {}
6073 }
6074 }
6075
    /// Evaluate every aggregate call in `aggregate_specs` over `table` with a
    /// single streaming scan, returning results keyed by the same strings
    /// used in `aggregate_specs` (the Debug rendering of each call).
    ///
    /// `filter` is an optional column-name-based WHERE predicate;
    /// `row_filter` optionally restricts the scan to specific row ids (e.g.
    /// visibility filtering — TODO confirm against callers). Aggregate inputs
    /// that are plain columns share scan projections via a per-column cache;
    /// computed inputs get synthesized projections, whose projection index is
    /// reused as a synthetic field id for the accumulator.
    ///
    /// Errors on unknown columns, unsupported input types, or any
    /// scan/accumulator failure.
    fn compute_aggregate_values(
        &self,
        table: Arc<ExecutorTable<P>>,
        filter: &Option<llkv_expr::expr::Expr<'static, String>>,
        aggregate_specs: &[(String, llkv_expr::expr::AggregateCall<String>)],
        row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
    ) -> ExecutorResult<FxHashMap<String, AggregateValue>> {
        use llkv_expr::expr::AggregateCall;

        let table_ref = table.as_ref();
        let mut results =
            FxHashMap::with_capacity_and_hasher(aggregate_specs.len(), Default::default());

        // Per-spec accumulator descriptors, plus the scan projection (if any)
        // each accumulator reads from. `specs`, `spec_to_projection`, and
        // `aggregate_specs` stay index-aligned.
        let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
        let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(aggregate_specs.len());
        let mut projections: Vec<ScanProjection> = Vec::new();
        // Dedupe projections: one per referenced column / computed expression.
        let mut column_projection_cache: FxHashMap<FieldId, usize> = FxHashMap::default();
        let mut computed_projection_cache: FxHashMap<String, (usize, DataType)> =
            FxHashMap::default();
        let mut computed_alias_counter: usize = 0;

        // Build an AggregateSpec (+ projection wiring) for every requested
        // call. Each arm handles two shapes of input: a simple column
        // reference (resolved against the schema) or an arbitrary computed
        // expression (given a synthesized projection whose index doubles as
        // the accumulator's field id).
        for (key, agg) in aggregate_specs {
            match agg {
                AggregateCall::CountStar => {
                    // COUNT(*) needs no input column at all.
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Count {
                            field_id: None,
                            distinct: false,
                        },
                    });
                    spec_to_projection.push(None);
                }
                AggregateCall::Count { expr, distinct } => {
                    if let Some(col_name) = try_extract_simple_column(expr) {
                        let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                            Error::InvalidArgumentError(format!(
                                "unknown column '{}' in aggregate",
                                col_name
                            ))
                        })?;
                        let projection_index = get_or_insert_column_projection(
                            &mut projections,
                            &mut column_projection_cache,
                            table_ref,
                            col,
                        );
                        specs.push(AggregateSpec {
                            alias: key.clone(),
                            kind: AggregateKind::Count {
                                field_id: Some(col.field_id),
                                distinct: *distinct,
                            },
                        });
                        spec_to_projection.push(Some(projection_index));
                    } else {
                        let (projection_index, _dtype) = ensure_computed_projection(
                            expr,
                            table_ref,
                            &mut projections,
                            &mut computed_projection_cache,
                            &mut computed_alias_counter,
                        )?;
                        // Computed inputs have no real FieldId; reuse the
                        // projection index as a synthetic one.
                        let field_id = u32::try_from(projection_index).map_err(|_| {
                            Error::InvalidArgumentError(
                                "aggregate projection index exceeds supported range".into(),
                            )
                        })?;
                        specs.push(AggregateSpec {
                            alias: key.clone(),
                            kind: AggregateKind::Count {
                                field_id: Some(field_id),
                                distinct: *distinct,
                            },
                        });
                        spec_to_projection.push(Some(projection_index));
                    }
                }
                AggregateCall::Sum { expr, distinct } => {
                    let (projection_index, data_type, field_id) =
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                                Error::InvalidArgumentError(format!(
                                    "unknown column '{}' in aggregate",
                                    col_name
                                ))
                            })?;
                            let projection_index = get_or_insert_column_projection(
                                &mut projections,
                                &mut column_projection_cache,
                                table_ref,
                                col,
                            );
                            let data_type = col.data_type.clone();
                            (projection_index, data_type, col.field_id)
                        } else {
                            let (projection_index, inferred_type) = ensure_computed_projection(
                                expr,
                                table_ref,
                                &mut projections,
                                &mut computed_projection_cache,
                                &mut computed_alias_counter,
                            )?;
                            let field_id = u32::try_from(projection_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            (projection_index, inferred_type, field_id)
                        };
                    // Coerce/validate the input type for SUM semantics.
                    let normalized_type = Self::validate_aggregate_type(
                        Some(data_type.clone()),
                        "SUM",
                        &[DataType::Int64, DataType::Float64],
                    )?;
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Sum {
                            field_id,
                            data_type: normalized_type,
                            distinct: *distinct,
                        },
                    });
                    spec_to_projection.push(Some(projection_index));
                }
                AggregateCall::Total { expr, distinct } => {
                    // Same input resolution as SUM; only the accumulator
                    // kind differs.
                    let (projection_index, data_type, field_id) =
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                                Error::InvalidArgumentError(format!(
                                    "unknown column '{}' in aggregate",
                                    col_name
                                ))
                            })?;
                            let projection_index = get_or_insert_column_projection(
                                &mut projections,
                                &mut column_projection_cache,
                                table_ref,
                                col,
                            );
                            let data_type = col.data_type.clone();
                            (projection_index, data_type, col.field_id)
                        } else {
                            let (projection_index, inferred_type) = ensure_computed_projection(
                                expr,
                                table_ref,
                                &mut projections,
                                &mut computed_projection_cache,
                                &mut computed_alias_counter,
                            )?;
                            let field_id = u32::try_from(projection_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            (projection_index, inferred_type, field_id)
                        };
                    let normalized_type = Self::validate_aggregate_type(
                        Some(data_type.clone()),
                        "TOTAL",
                        &[DataType::Int64, DataType::Float64],
                    )?;
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Total {
                            field_id,
                            data_type: normalized_type,
                            distinct: *distinct,
                        },
                    });
                    spec_to_projection.push(Some(projection_index));
                }
                AggregateCall::Avg { expr, distinct } => {
                    let (projection_index, data_type, field_id) =
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                                Error::InvalidArgumentError(format!(
                                    "unknown column '{}' in aggregate",
                                    col_name
                                ))
                            })?;
                            let projection_index = get_or_insert_column_projection(
                                &mut projections,
                                &mut column_projection_cache,
                                table_ref,
                                col,
                            );
                            let data_type = col.data_type.clone();
                            (projection_index, data_type, col.field_id)
                        } else {
                            let (projection_index, inferred_type) = ensure_computed_projection(
                                expr,
                                table_ref,
                                &mut projections,
                                &mut computed_projection_cache,
                                &mut computed_alias_counter,
                            )?;
                            tracing::debug!(
                                "AVG aggregate expr={:?} inferred_type={:?}",
                                expr,
                                inferred_type
                            );
                            let field_id = u32::try_from(projection_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            (projection_index, inferred_type, field_id)
                        };
                    let normalized_type = Self::validate_aggregate_type(
                        Some(data_type.clone()),
                        "AVG",
                        &[DataType::Int64, DataType::Float64],
                    )?;
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Avg {
                            field_id,
                            data_type: normalized_type,
                            distinct: *distinct,
                        },
                    });
                    spec_to_projection.push(Some(projection_index));
                }
                AggregateCall::Min(expr) => {
                    let (projection_index, data_type, field_id) =
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                                Error::InvalidArgumentError(format!(
                                    "unknown column '{}' in aggregate",
                                    col_name
                                ))
                            })?;
                            let projection_index = get_or_insert_column_projection(
                                &mut projections,
                                &mut column_projection_cache,
                                table_ref,
                                col,
                            );
                            let data_type = col.data_type.clone();
                            (projection_index, data_type, col.field_id)
                        } else {
                            let (projection_index, inferred_type) = ensure_computed_projection(
                                expr,
                                table_ref,
                                &mut projections,
                                &mut computed_projection_cache,
                                &mut computed_alias_counter,
                            )?;
                            let field_id = u32::try_from(projection_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            (projection_index, inferred_type, field_id)
                        };
                    let normalized_type = Self::validate_aggregate_type(
                        Some(data_type.clone()),
                        "MIN",
                        &[DataType::Int64, DataType::Float64],
                    )?;
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Min {
                            field_id,
                            data_type: normalized_type,
                        },
                    });
                    spec_to_projection.push(Some(projection_index));
                }
                AggregateCall::Max(expr) => {
                    let (projection_index, data_type, field_id) =
                        if let Some(col_name) = try_extract_simple_column(expr) {
                            let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                                Error::InvalidArgumentError(format!(
                                    "unknown column '{}' in aggregate",
                                    col_name
                                ))
                            })?;
                            let projection_index = get_or_insert_column_projection(
                                &mut projections,
                                &mut column_projection_cache,
                                table_ref,
                                col,
                            );
                            let data_type = col.data_type.clone();
                            (projection_index, data_type, col.field_id)
                        } else {
                            let (projection_index, inferred_type) = ensure_computed_projection(
                                expr,
                                table_ref,
                                &mut projections,
                                &mut computed_projection_cache,
                                &mut computed_alias_counter,
                            )?;
                            let field_id = u32::try_from(projection_index).map_err(|_| {
                                Error::InvalidArgumentError(
                                    "aggregate projection index exceeds supported range".into(),
                                )
                            })?;
                            (projection_index, inferred_type, field_id)
                        };
                    let normalized_type = Self::validate_aggregate_type(
                        Some(data_type.clone()),
                        "MAX",
                        &[DataType::Int64, DataType::Float64],
                    )?;
                    specs.push(AggregateSpec {
                        alias: key.clone(),
                        kind: AggregateKind::Max {
                            field_id,
                            data_type: normalized_type,
                        },
                    });
                    spec_to_projection.push(Some(projection_index));
                }
                AggregateCall::CountNulls(expr) => {
                    if let Some(col_name) = try_extract_simple_column(expr) {
                        let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                            Error::InvalidArgumentError(format!(
                                "unknown column '{}' in aggregate",
                                col_name
                            ))
                        })?;
                        let projection_index = get_or_insert_column_projection(
                            &mut projections,
                            &mut column_projection_cache,
                            table_ref,
                            col,
                        );
                        specs.push(AggregateSpec {
                            alias: key.clone(),
                            kind: AggregateKind::CountNulls {
                                field_id: col.field_id,
                            },
                        });
                        spec_to_projection.push(Some(projection_index));
                    } else {
                        let (projection_index, _dtype) = ensure_computed_projection(
                            expr,
                            table_ref,
                            &mut projections,
                            &mut computed_projection_cache,
                            &mut computed_alias_counter,
                        )?;
                        let field_id = u32::try_from(projection_index).map_err(|_| {
                            Error::InvalidArgumentError(
                                "aggregate projection index exceeds supported range".into(),
                            )
                        })?;
                        specs.push(AggregateSpec {
                            alias: key.clone(),
                            kind: AggregateKind::CountNulls { field_id },
                        });
                        spec_to_projection.push(Some(projection_index));
                    }
                }
                AggregateCall::GroupConcat {
                    expr,
                    distinct,
                    separator,
                } => {
                    if let Some(col_name) = try_extract_simple_column(expr) {
                        let col = table_ref.schema.resolve(col_name).ok_or_else(|| {
                            Error::InvalidArgumentError(format!(
                                "unknown column '{}' in aggregate",
                                col_name
                            ))
                        })?;
                        let projection_index = get_or_insert_column_projection(
                            &mut projections,
                            &mut column_projection_cache,
                            table_ref,
                            col,
                        );
                        specs.push(AggregateSpec {
                            alias: key.clone(),
                            kind: AggregateKind::GroupConcat {
                                field_id: col.field_id,
                                distinct: *distinct,
                                // Default separator is a comma.
                                separator: separator.clone().unwrap_or_else(|| ",".to_string()),
                            },
                        });
                        spec_to_projection.push(Some(projection_index));
                    } else {
                        let (projection_index, _dtype) = ensure_computed_projection(
                            expr,
                            table_ref,
                            &mut projections,
                            &mut computed_projection_cache,
                            &mut computed_alias_counter,
                        )?;
                        let field_id = u32::try_from(projection_index).map_err(|_| {
                            Error::InvalidArgumentError(
                                "aggregate projection index exceeds supported range".into(),
                            )
                        })?;
                        specs.push(AggregateSpec {
                            alias: key.clone(),
                            kind: AggregateKind::GroupConcat {
                                field_id,
                                distinct: *distinct,
                                separator: separator.clone().unwrap_or_else(|| ",".to_string()),
                            },
                        });
                        spec_to_projection.push(Some(projection_index));
                    }
                }
            }
        }

        // Translate the WHERE clause into field-id form; with no filter, use
        // a full-table-scan predicate anchored on the first column.
        let filter_expr = match filter {
            Some(expr) => crate::translation::expression::translate_predicate(
                expr.clone(),
                table_ref.schema.as_ref(),
                |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
            )?,
            None => {
                let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
                    Error::InvalidArgumentError(
                        "table has no columns; cannot perform aggregate scan".into(),
                    )
                })?;
                crate::translation::expression::full_table_scan_filter(field_id)
            }
        };

        // COUNT(*)-only queries reference no columns; project the first
        // column anyway so the scan produces batches to drive the
        // accumulators.
        if projections.is_empty() {
            let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
                Error::InvalidArgumentError(
                    "table has no columns; cannot perform aggregate scan".into(),
                )
            })?;
            projections.push(ScanProjection::from(StoreProjection::with_alias(
                LogicalFieldId::for_user(table.table.table_id(), field_id),
                table
                    .schema
                    .column_by_field_id(field_id)
                    .map(|c| c.name.clone())
                    .unwrap_or_else(|| format!("col{field_id}")),
            )));
        }

        // include_nulls so accumulators see every row; COUNT/COUNT_NULLS
        // need the null slots, not just populated values.
        let base_options = ScanStreamOptions {
            include_nulls: true,
            order: None,
            row_id_filter: None,
        };

        // Unlike the plain-aggregate path, no COUNT(*) shortcut is taken
        // here: every value comes from the scan.
        let count_star_override: Option<i64> = None;

        let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
        for (idx, spec) in specs.iter().enumerate() {
            states.push(AggregateState {
                alias: spec.alias.clone(),
                accumulator: AggregateAccumulator::new_with_projection_index(
                    spec,
                    spec_to_projection[idx],
                    count_star_override,
                )?,
                override_value: match &spec.kind {
                    AggregateKind::Count { field_id: None, .. } => count_star_override,
                    _ => None,
                },
            });
        }

        // Stream the scan, feeding every batch to every accumulator. The
        // callback cannot return a Result, so the first failure is stashed
        // in `error` and re-raised after the scan finishes.
        let mut error: Option<Error> = None;
        match table.table.scan_stream(
            projections,
            &filter_expr,
            ScanStreamOptions {
                row_id_filter: row_filter.clone(),
                ..base_options
            },
            |batch| {
                if error.is_some() {
                    return;
                }
                for state in &mut states {
                    if let Err(err) = state.update(&batch) {
                        error = Some(err);
                        return;
                    }
                }
            },
        ) {
            Ok(()) => {}
            // Treated as "no rows", not a failure (accumulators keep their
            // empty-input defaults).
            Err(llkv_result::Error::NotFound) => {}
            Err(err) => return Err(err),
        }
        if let Some(err) = error {
            return Err(err);
        }

        // Finalize each accumulator into a single-element Arrow array and
        // convert it to an AggregateValue keyed by the spec alias.
        for state in states {
            let alias = state.alias.clone();
            let (_field, array) = state.finalize()?;

            if let Some(int64_array) = array.as_any().downcast_ref::<arrow::array::Int64Array>() {
                if int64_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        int64_array.len()
                    )));
                }
                let value = if int64_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Int64(int64_array.value(0))
                };
                results.insert(alias, value);
            } else if let Some(float64_array) =
                array.as_any().downcast_ref::<arrow::array::Float64Array>()
            {
                if float64_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        float64_array.len()
                    )));
                }
                let value = if float64_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::Float64(float64_array.value(0))
                };
                results.insert(alias, value);
            } else if let Some(string_array) =
                array.as_any().downcast_ref::<arrow::array::StringArray>()
            {
                if string_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        string_array.len()
                    )));
                }
                let value = if string_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    AggregateValue::String(string_array.value(0).to_string())
                };
                results.insert(alias, value);
            } else if let Some(decimal_array) = array
                .as_any()
                .downcast_ref::<arrow::array::Decimal128Array>()
            {
                if decimal_array.len() != 1 {
                    return Err(Error::Internal(format!(
                        "Expected single value from aggregate, got {}",
                        decimal_array.len()
                    )));
                }
                let value = if decimal_array.is_null(0) {
                    AggregateValue::Null
                } else {
                    // Keep the raw unscaled value together with its scale.
                    AggregateValue::Decimal128 {
                        value: decimal_array.value(0),
                        scale: decimal_array.scale(),
                    }
                };
                results.insert(alias, value);
            } else {
                return Err(Error::Internal(format!(
                    "Unexpected array type from aggregate: {:?}",
                    array.data_type()
                )));
            }
        }

        Ok(results)
    }
6648
6649 fn evaluate_having_expr(
6650 expr: &llkv_expr::expr::Expr<String>,
6651 aggregates: &FxHashMap<String, PlanValue>,
6652 row_batch: &RecordBatch,
6653 column_lookup: &FxHashMap<String, usize>,
6654 row_idx: usize,
6655 ) -> ExecutorResult<Option<bool>> {
6656 fn compare_plan_values_for_pred(
6657 left: &PlanValue,
6658 right: &PlanValue,
6659 ) -> Option<std::cmp::Ordering> {
6660 match (left, right) {
6661 (PlanValue::Integer(l), PlanValue::Integer(r)) => Some(l.cmp(r)),
6662 (PlanValue::Float(l), PlanValue::Float(r)) => l.partial_cmp(r),
6663 (PlanValue::Integer(l), PlanValue::Float(r)) => (*l as f64).partial_cmp(r),
6664 (PlanValue::Float(l), PlanValue::Integer(r)) => l.partial_cmp(&(*r as f64)),
6665 (PlanValue::String(l), PlanValue::String(r)) => Some(l.cmp(r)),
6666 (PlanValue::Interval(l), PlanValue::Interval(r)) => {
6667 Some(compare_interval_values(*l, *r))
6668 }
6669 _ => None,
6670 }
6671 }
6672
6673 fn evaluate_ordering_predicate<F>(
6674 value: &PlanValue,
6675 literal: &Literal,
6676 predicate: F,
6677 ) -> ExecutorResult<Option<bool>>
6678 where
6679 F: Fn(std::cmp::Ordering) -> bool,
6680 {
6681 if matches!(value, PlanValue::Null) {
6682 return Ok(None);
6683 }
6684 let expected = llkv_plan::plan_value_from_literal(literal)?;
6685 if matches!(expected, PlanValue::Null) {
6686 return Ok(None);
6687 }
6688
6689 match compare_plan_values_for_pred(value, &expected) {
6690 Some(ordering) => Ok(Some(predicate(ordering))),
6691 None => Err(Error::InvalidArgumentError(
6692 "unsupported HAVING comparison between column value and literal".into(),
6693 )),
6694 }
6695 }
6696
6697 match expr {
6698 llkv_expr::expr::Expr::Compare { left, op, right } => {
6699 let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6700 left,
6701 aggregates,
6702 Some(row_batch),
6703 Some(column_lookup),
6704 row_idx,
6705 )?;
6706 let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6707 right,
6708 aggregates,
6709 Some(row_batch),
6710 Some(column_lookup),
6711 row_idx,
6712 )?;
6713
6714 let (left_val, right_val) = match (&left_val, &right_val) {
6716 (PlanValue::Integer(i), PlanValue::Float(_)) => {
6717 (PlanValue::Float(*i as f64), right_val)
6718 }
6719 (PlanValue::Float(_), PlanValue::Integer(i)) => {
6720 (left_val, PlanValue::Float(*i as f64))
6721 }
6722 _ => (left_val, right_val),
6723 };
6724
6725 match (left_val, right_val) {
6726 (PlanValue::Null, _) | (_, PlanValue::Null) => Ok(None),
6728 (PlanValue::Integer(l), PlanValue::Integer(r)) => {
6729 use llkv_expr::expr::CompareOp;
6730 Ok(Some(match op {
6731 CompareOp::Eq => l == r,
6732 CompareOp::NotEq => l != r,
6733 CompareOp::Lt => l < r,
6734 CompareOp::LtEq => l <= r,
6735 CompareOp::Gt => l > r,
6736 CompareOp::GtEq => l >= r,
6737 }))
6738 }
6739 (PlanValue::Float(l), PlanValue::Float(r)) => {
6740 use llkv_expr::expr::CompareOp;
6741 Ok(Some(match op {
6742 CompareOp::Eq => l == r,
6743 CompareOp::NotEq => l != r,
6744 CompareOp::Lt => l < r,
6745 CompareOp::LtEq => l <= r,
6746 CompareOp::Gt => l > r,
6747 CompareOp::GtEq => l >= r,
6748 }))
6749 }
6750 (PlanValue::Interval(l), PlanValue::Interval(r)) => {
6751 use llkv_expr::expr::CompareOp;
6752 let ordering = compare_interval_values(l, r);
6753 Ok(Some(match op {
6754 CompareOp::Eq => ordering == std::cmp::Ordering::Equal,
6755 CompareOp::NotEq => ordering != std::cmp::Ordering::Equal,
6756 CompareOp::Lt => ordering == std::cmp::Ordering::Less,
6757 CompareOp::LtEq => {
6758 matches!(
6759 ordering,
6760 std::cmp::Ordering::Less | std::cmp::Ordering::Equal
6761 )
6762 }
6763 CompareOp::Gt => ordering == std::cmp::Ordering::Greater,
6764 CompareOp::GtEq => {
6765 matches!(
6766 ordering,
6767 std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
6768 )
6769 }
6770 }))
6771 }
6772 _ => Ok(Some(false)),
6773 }
6774 }
6775 llkv_expr::expr::Expr::Not(inner) => {
6776 match Self::evaluate_having_expr(
6778 inner,
6779 aggregates,
6780 row_batch,
6781 column_lookup,
6782 row_idx,
6783 )? {
6784 Some(b) => Ok(Some(!b)),
6785 None => Ok(None), }
6787 }
6788 llkv_expr::expr::Expr::InList {
6789 expr: test_expr,
6790 list,
6791 negated,
6792 } => {
6793 let test_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6794 test_expr,
6795 aggregates,
6796 Some(row_batch),
6797 Some(column_lookup),
6798 row_idx,
6799 )?;
6800
6801 if matches!(test_val, PlanValue::Null) {
6804 return Ok(None);
6805 }
6806
6807 let mut found = false;
6808 let mut has_null = false;
6809
6810 for list_item in list {
6811 let list_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6812 list_item,
6813 aggregates,
6814 Some(row_batch),
6815 Some(column_lookup),
6816 row_idx,
6817 )?;
6818
6819 if matches!(list_val, PlanValue::Null) {
6821 has_null = true;
6822 continue;
6823 }
6824
6825 let matches = match (&test_val, &list_val) {
6827 (PlanValue::Integer(a), PlanValue::Integer(b)) => a == b,
6828 (PlanValue::Float(a), PlanValue::Float(b)) => a == b,
6829 (PlanValue::Integer(a), PlanValue::Float(b)) => (*a as f64) == *b,
6830 (PlanValue::Float(a), PlanValue::Integer(b)) => *a == (*b as f64),
6831 (PlanValue::String(a), PlanValue::String(b)) => a == b,
6832 (PlanValue::Interval(a), PlanValue::Interval(b)) => {
6833 compare_interval_values(*a, *b) == std::cmp::Ordering::Equal
6834 }
6835 _ => false,
6836 };
6837
6838 if matches {
6839 found = true;
6840 break;
6841 }
6842 }
6843
6844 if *negated {
6848 Ok(if found {
6850 Some(false)
6851 } else if has_null {
6852 None } else {
6854 Some(true)
6855 })
6856 } else {
6857 Ok(if found {
6859 Some(true)
6860 } else if has_null {
6861 None } else {
6863 Some(false)
6864 })
6865 }
6866 }
6867 llkv_expr::expr::Expr::IsNull { expr, negated } => {
6868 let val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
6870 expr,
6871 aggregates,
6872 Some(row_batch),
6873 Some(column_lookup),
6874 row_idx,
6875 )?;
6876
6877 let is_null = matches!(val, PlanValue::Null);
6881 Ok(Some(if *negated { !is_null } else { is_null }))
6882 }
6883 llkv_expr::expr::Expr::Literal(val) => Ok(Some(*val)),
6884 llkv_expr::expr::Expr::And(exprs) => {
6885 let mut has_null = false;
6887 for e in exprs {
6888 match Self::evaluate_having_expr(
6889 e,
6890 aggregates,
6891 row_batch,
6892 column_lookup,
6893 row_idx,
6894 )? {
6895 Some(false) => return Ok(Some(false)), None => has_null = true,
6897 Some(true) => {} }
6899 }
6900 Ok(if has_null { None } else { Some(true) })
6901 }
6902 llkv_expr::expr::Expr::Or(exprs) => {
6903 let mut has_null = false;
6905 for e in exprs {
6906 match Self::evaluate_having_expr(
6907 e,
6908 aggregates,
6909 row_batch,
6910 column_lookup,
6911 row_idx,
6912 )? {
6913 Some(true) => return Ok(Some(true)), None => has_null = true,
6915 Some(false) => {} }
6917 }
6918 Ok(if has_null { None } else { Some(false) })
6919 }
6920 llkv_expr::expr::Expr::Pred(filter) => {
6921 use llkv_expr::expr::Operator;
6924
6925 let col_name = &filter.field_id;
6926 let col_idx = column_lookup
6927 .get(&col_name.to_ascii_lowercase())
6928 .ok_or_else(|| {
6929 Error::InvalidArgumentError(format!(
6930 "column '{}' not found in HAVING context",
6931 col_name
6932 ))
6933 })?;
6934
6935 let value = llkv_plan::plan_value_from_array(row_batch.column(*col_idx), row_idx)?;
6936
6937 match &filter.op {
6938 Operator::IsNull => Ok(Some(matches!(value, PlanValue::Null))),
6939 Operator::IsNotNull => Ok(Some(!matches!(value, PlanValue::Null))),
6940 Operator::Equals(expected) => {
6941 if matches!(value, PlanValue::Null) {
6943 return Ok(None);
6944 }
6945 let expected_value = llkv_plan::plan_value_from_literal(expected)?;
6947 if matches!(expected_value, PlanValue::Null) {
6948 return Ok(None);
6949 }
6950 Ok(Some(value == expected_value))
6951 }
6952 Operator::GreaterThan(expected) => {
6953 evaluate_ordering_predicate(&value, expected, |ordering| {
6954 ordering == std::cmp::Ordering::Greater
6955 })
6956 }
6957 Operator::GreaterThanOrEquals(expected) => {
6958 evaluate_ordering_predicate(&value, expected, |ordering| {
6959 ordering == std::cmp::Ordering::Greater
6960 || ordering == std::cmp::Ordering::Equal
6961 })
6962 }
6963 Operator::LessThan(expected) => {
6964 evaluate_ordering_predicate(&value, expected, |ordering| {
6965 ordering == std::cmp::Ordering::Less
6966 })
6967 }
6968 Operator::LessThanOrEquals(expected) => {
6969 evaluate_ordering_predicate(&value, expected, |ordering| {
6970 ordering == std::cmp::Ordering::Less
6971 || ordering == std::cmp::Ordering::Equal
6972 })
6973 }
6974 _ => {
6975 Err(Error::InvalidArgumentError(format!(
6978 "Operator {:?} not supported for column predicates in HAVING clause",
6979 filter.op
6980 )))
6981 }
6982 }
6983 }
6984 llkv_expr::expr::Expr::Exists(_) => Err(Error::InvalidArgumentError(
6985 "EXISTS subqueries not supported in HAVING clause".into(),
6986 )),
6987 }
6988 }
6989
    /// Evaluate a scalar expression to a `PlanValue`.
    ///
    /// Aggregate references are looked up in `aggregates`, keyed by the
    /// `format!("{:?}", agg)` debug string of the aggregate call. Column
    /// references resolve against `row_batch`/`column_lookup` at `row_idx`
    /// when both are provided; otherwise they are rejected. Boolean results
    /// are encoded as `PlanValue::Integer(0 | 1)`, and SQL NULL propagates
    /// as `PlanValue::Null`.
    ///
    /// Errors: `InvalidArgumentError` for unsupported constructs (struct
    /// literals, `GetField`, scalar subqueries, unknown columns, non-numeric
    /// arithmetic operands, unsupported casts) and `Internal` when an
    /// aggregate value is missing from the map.
    fn evaluate_expr_with_plan_value_aggregates_and_row(
        expr: &ScalarExpr<String>,
        aggregates: &FxHashMap<String, PlanValue>,
        row_batch: Option<&RecordBatch>,
        column_lookup: Option<&FxHashMap<String, usize>>,
        row_idx: usize,
    ) -> ExecutorResult<PlanValue> {
        use llkv_expr::expr::BinaryOp;
        use llkv_expr::literal::Literal;

        match expr {
            // Literals map directly onto PlanValue; booleans become 0/1.
            ScalarExpr::Literal(Literal::Int128(v)) => Ok(PlanValue::Integer(*v as i64)),
            ScalarExpr::Literal(Literal::Float64(v)) => Ok(PlanValue::Float(*v)),
            ScalarExpr::Literal(Literal::Decimal128(value)) => Ok(PlanValue::Decimal(*value)),
            ScalarExpr::Literal(Literal::Boolean(v)) => {
                Ok(PlanValue::Integer(if *v { 1 } else { 0 }))
            }
            ScalarExpr::Literal(Literal::String(s)) => Ok(PlanValue::String(s.clone())),
            ScalarExpr::Literal(Literal::Date32(days)) => Ok(PlanValue::Date32(*days)),
            ScalarExpr::Literal(Literal::Interval(interval)) => Ok(PlanValue::Interval(*interval)),
            ScalarExpr::Literal(Literal::Null) => Ok(PlanValue::Null),
            ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
                "Struct literals not supported in aggregate expressions".into(),
            )),
            // Column access requires the per-row context; lookups are
            // case-insensitive (keys are lowercased).
            ScalarExpr::Column(col_name) => {
                if let (Some(batch), Some(lookup)) = (row_batch, column_lookup) {
                    let col_idx = lookup.get(&col_name.to_ascii_lowercase()).ok_or_else(|| {
                        Error::InvalidArgumentError(format!("column '{}' not found", col_name))
                    })?;
                    llkv_plan::plan_value_from_array(batch.column(*col_idx), row_idx)
                } else {
                    Err(Error::InvalidArgumentError(
                        "Column references not supported in aggregate-only expressions".into(),
                    ))
                }
            }
            // Comparison: NULL on either side yields NULL; int/float mixes
            // are promoted to float; unsupported type pairs compare false.
            ScalarExpr::Compare { left, op, right } => {
                let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    left,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    right,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;

                if matches!(left_val, PlanValue::Null) || matches!(right_val, PlanValue::Null) {
                    return Ok(PlanValue::Null);
                }

                // Promote Integer to Float when the other side is Float.
                let (left_val, right_val) = match (&left_val, &right_val) {
                    (PlanValue::Integer(i), PlanValue::Float(_)) => {
                        (PlanValue::Float(*i as f64), right_val)
                    }
                    (PlanValue::Float(_), PlanValue::Integer(i)) => {
                        (left_val, PlanValue::Float(*i as f64))
                    }
                    _ => (left_val, right_val),
                };

                let result = match (&left_val, &right_val) {
                    (PlanValue::Integer(l), PlanValue::Integer(r)) => {
                        use llkv_expr::expr::CompareOp;
                        match op {
                            CompareOp::Eq => l == r,
                            CompareOp::NotEq => l != r,
                            CompareOp::Lt => l < r,
                            CompareOp::LtEq => l <= r,
                            CompareOp::Gt => l > r,
                            CompareOp::GtEq => l >= r,
                        }
                    }
                    (PlanValue::Float(l), PlanValue::Float(r)) => {
                        use llkv_expr::expr::CompareOp;
                        match op {
                            CompareOp::Eq => l == r,
                            CompareOp::NotEq => l != r,
                            CompareOp::Lt => l < r,
                            CompareOp::LtEq => l <= r,
                            CompareOp::Gt => l > r,
                            CompareOp::GtEq => l >= r,
                        }
                    }
                    (PlanValue::String(l), PlanValue::String(r)) => {
                        use llkv_expr::expr::CompareOp;
                        match op {
                            CompareOp::Eq => l == r,
                            CompareOp::NotEq => l != r,
                            CompareOp::Lt => l < r,
                            CompareOp::LtEq => l <= r,
                            CompareOp::Gt => l > r,
                            CompareOp::GtEq => l >= r,
                        }
                    }
                    (PlanValue::Interval(l), PlanValue::Interval(r)) => {
                        use llkv_expr::expr::CompareOp;
                        let ordering = compare_interval_values(*l, *r);
                        match op {
                            CompareOp::Eq => ordering == std::cmp::Ordering::Equal,
                            CompareOp::NotEq => ordering != std::cmp::Ordering::Equal,
                            CompareOp::Lt => ordering == std::cmp::Ordering::Less,
                            CompareOp::LtEq => {
                                matches!(
                                    ordering,
                                    std::cmp::Ordering::Less | std::cmp::Ordering::Equal
                                )
                            }
                            CompareOp::Gt => ordering == std::cmp::Ordering::Greater,
                            CompareOp::GtEq => {
                                matches!(
                                    ordering,
                                    std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
                                )
                            }
                        }
                    }
                    // NOTE(review): mismatched type pairs (e.g. Decimal vs
                    // String) silently compare false rather than erroring.
                    _ => false,
                };

                Ok(PlanValue::Integer(if result { 1 } else { 0 }))
            }
            // Logical NOT over the 0/1 integer encoding; NULL propagates.
            ScalarExpr::Not(inner) => {
                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    inner,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                match value {
                    PlanValue::Integer(v) => Ok(PlanValue::Integer(if v != 0 { 0 } else { 1 })),
                    PlanValue::Float(v) => Ok(PlanValue::Integer(if v != 0.0 { 0 } else { 1 })),
                    PlanValue::Null => Ok(PlanValue::Null),
                    other => Err(Error::InvalidArgumentError(format!(
                        "logical NOT does not support value {other:?}"
                    ))),
                }
            }
            // IS NULL / IS NOT NULL always yields a definite 0/1.
            ScalarExpr::IsNull { expr, negated } => {
                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    expr,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                let is_null = matches!(value, PlanValue::Null);
                let condition = if is_null { !negated } else { *negated };
                Ok(PlanValue::Integer(if condition { 1 } else { 0 }))
            }
            // Aggregates were computed earlier; look up by debug-string key.
            ScalarExpr::Aggregate(agg) => {
                let key = format!("{:?}", agg);
                aggregates
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| Error::Internal(format!("Aggregate value not found: {}", key)))
            }
            ScalarExpr::Binary { left, op, right } => {
                let left_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    left,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;
                let right_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    right,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;

                match op {
                    BinaryOp::Add
                    | BinaryOp::Subtract
                    | BinaryOp::Multiply
                    | BinaryOp::Divide
                    | BinaryOp::Modulo => {
                        // NULL operand -> NULL result.
                        if matches!(&left_val, PlanValue::Null)
                            || matches!(&right_val, PlanValue::Null)
                        {
                            return Ok(PlanValue::Null);
                        }

                        if matches!(left_val, PlanValue::Interval(_))
                            || matches!(right_val, PlanValue::Interval(_))
                        {
                            return Err(Error::InvalidArgumentError(
                                "interval arithmetic not supported in aggregate expressions".into(),
                            ));
                        }

                        // Pure-integer division: 0 divisor yields NULL, and
                        // i64::MIN / -1 (which would overflow) is computed in
                        // f64 instead.
                        if matches!(op, BinaryOp::Divide)
                            && let (PlanValue::Integer(lhs), PlanValue::Integer(rhs)) =
                                (&left_val, &right_val)
                        {
                            if *rhs == 0 {
                                return Ok(PlanValue::Null);
                            }

                            if *lhs == i64::MIN && *rhs == -1 {
                                return Ok(PlanValue::Float((*lhs as f64) / (*rhs as f64)));
                            }

                            return Ok(PlanValue::Integer(lhs / rhs));
                        }

                        // Decimal path: exact arithmetic; Float operands are
                        // rejected rather than silently losing precision.
                        let has_decimal = matches!(&left_val, PlanValue::Decimal(_))
                            || matches!(&right_val, PlanValue::Decimal(_));

                        if has_decimal {
                            use llkv_types::decimal::DecimalValue;

                            let left_dec = match &left_val {
                                PlanValue::Integer(i) => DecimalValue::from_i64(*i),
                                PlanValue::Float(_f) => {
                                    return Err(Error::InvalidArgumentError(
                                        "Cannot perform exact decimal arithmetic with Float operands"
                                            .into(),
                                    ));
                                }
                                PlanValue::Decimal(d) => *d,
                                other => {
                                    return Err(Error::InvalidArgumentError(format!(
                                        "Non-numeric value {:?} in binary operation",
                                        other
                                    )));
                                }
                            };

                            let right_dec = match &right_val {
                                PlanValue::Integer(i) => DecimalValue::from_i64(*i),
                                PlanValue::Float(_f) => {
                                    return Err(Error::InvalidArgumentError(
                                        "Cannot perform exact decimal arithmetic with Float operands"
                                            .into(),
                                    ));
                                }
                                PlanValue::Decimal(d) => *d,
                                other => {
                                    return Err(Error::InvalidArgumentError(format!(
                                        "Non-numeric value {:?} in binary operation",
                                        other
                                    )));
                                }
                            };

                            let result_dec = match op {
                                BinaryOp::Add => {
                                    llkv_compute::scalar::decimal::add(left_dec, right_dec)
                                        .map_err(|e| {
                                            Error::InvalidArgumentError(format!(
                                                "Decimal addition overflow: {}",
                                                e
                                            ))
                                        })?
                                }
                                BinaryOp::Subtract => {
                                    llkv_compute::scalar::decimal::sub(left_dec, right_dec)
                                        .map_err(|e| {
                                            Error::InvalidArgumentError(format!(
                                                "Decimal subtraction overflow: {}",
                                                e
                                            ))
                                        })?
                                }
                                BinaryOp::Multiply => {
                                    llkv_compute::scalar::decimal::mul(left_dec, right_dec)
                                        .map_err(|e| {
                                            Error::InvalidArgumentError(format!(
                                                "Decimal multiplication overflow: {}",
                                                e
                                            ))
                                        })?
                                }
                                BinaryOp::Divide => {
                                    // Decimal division by zero -> NULL; the
                                    // quotient keeps the dividend's scale.
                                    if right_dec.raw_value() == 0 {
                                        return Ok(PlanValue::Null);
                                    }
                                    let target_scale = left_dec.scale();
                                    llkv_compute::scalar::decimal::div(
                                        left_dec,
                                        right_dec,
                                        target_scale,
                                    )
                                    .map_err(|e| {
                                        Error::InvalidArgumentError(format!(
                                            "Decimal division error: {}",
                                            e
                                        ))
                                    })?
                                }
                                BinaryOp::Modulo => {
                                    return Err(Error::InvalidArgumentError(
                                        "Modulo not supported for Decimal types".into(),
                                    ));
                                }
                                // Excluded by the outer arithmetic-op match arm.
                                BinaryOp::And
                                | BinaryOp::Or
                                | BinaryOp::BitwiseShiftLeft
                                | BinaryOp::BitwiseShiftRight => unreachable!(),
                            };

                            return Ok(PlanValue::Decimal(result_dec));
                        }

                        // Numeric fallback: compute in f64, then narrow back
                        // to Integer when both inputs were integers.
                        let left_is_float = matches!(&left_val, PlanValue::Float(_));
                        let right_is_float = matches!(&right_val, PlanValue::Float(_));

                        let left_num = match left_val {
                            PlanValue::Integer(i) => i as f64,
                            PlanValue::Float(f) => f,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in binary operation",
                                    other
                                )));
                            }
                        };
                        let right_num = match right_val {
                            PlanValue::Integer(i) => i as f64,
                            PlanValue::Float(f) => f,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in binary operation",
                                    other
                                )));
                            }
                        };

                        let result = match op {
                            BinaryOp::Add => left_num + right_num,
                            BinaryOp::Subtract => left_num - right_num,
                            BinaryOp::Multiply => left_num * right_num,
                            BinaryOp::Divide => {
                                if right_num == 0.0 {
                                    return Ok(PlanValue::Null);
                                }
                                left_num / right_num
                            }
                            BinaryOp::Modulo => {
                                if right_num == 0.0 {
                                    return Ok(PlanValue::Null);
                                }
                                left_num % right_num
                            }
                            BinaryOp::And
                            | BinaryOp::Or
                            | BinaryOp::BitwiseShiftLeft
                            | BinaryOp::BitwiseShiftRight => unreachable!(),
                        };

                        // Division always produces a Float here (the pure
                        // integer case returned earlier).
                        if matches!(op, BinaryOp::Divide) {
                            return Ok(PlanValue::Float(result));
                        }

                        if left_is_float || right_is_float {
                            Ok(PlanValue::Float(result))
                        } else {
                            Ok(PlanValue::Integer(result as i64))
                        }
                    }
                    BinaryOp::And => Ok(evaluate_plan_value_logical_and(left_val, right_val)),
                    BinaryOp::Or => Ok(evaluate_plan_value_logical_or(left_val, right_val)),
                    // Shifts coerce both sides to i64 and wrap on the shift
                    // amount (wrapping_shl/shr semantics).
                    BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
                        if matches!(&left_val, PlanValue::Null)
                            || matches!(&right_val, PlanValue::Null)
                        {
                            return Ok(PlanValue::Null);
                        }

                        let lhs = match left_val {
                            PlanValue::Integer(i) => i,
                            PlanValue::Float(f) => f as i64,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in bitwise shift operation",
                                    other
                                )));
                            }
                        };
                        let rhs = match right_val {
                            PlanValue::Integer(i) => i,
                            PlanValue::Float(f) => f as i64,
                            other => {
                                return Err(Error::InvalidArgumentError(format!(
                                    "Non-numeric value {:?} in bitwise shift operation",
                                    other
                                )));
                            }
                        };

                        let result = match op {
                            BinaryOp::BitwiseShiftLeft => lhs.wrapping_shl(rhs as u32),
                            BinaryOp::BitwiseShiftRight => lhs.wrapping_shr(rhs as u32),
                            _ => unreachable!(),
                        };

                        Ok(PlanValue::Integer(result))
                    }
                }
            }
            // CAST over a small set of target families; NULL passes through.
            ScalarExpr::Cast { expr, data_type } => {
                let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                    expr,
                    aggregates,
                    row_batch,
                    column_lookup,
                    row_idx,
                )?;

                if matches!(value, PlanValue::Null) {
                    return Ok(PlanValue::Null);
                }

                match data_type {
                    DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => {
                        match value {
                            PlanValue::Integer(i) => Ok(PlanValue::Integer(i)),
                            // NOTE(review): f64 -> i64 truncates toward zero;
                            // no range check against the narrower int widths.
                            PlanValue::Float(f) => Ok(PlanValue::Integer(f as i64)),
                            PlanValue::String(s) => {
                                s.parse::<i64>().map(PlanValue::Integer).map_err(|_| {
                                    Error::InvalidArgumentError(format!(
                                        "Cannot cast '{}' to integer",
                                        s
                                    ))
                                })
                            }
                            _ => Err(Error::InvalidArgumentError(format!(
                                "Cannot cast {:?} to integer",
                                value
                            ))),
                        }
                    }
                    DataType::Float64 | DataType::Float32 => match value {
                        PlanValue::Integer(i) => Ok(PlanValue::Float(i as f64)),
                        PlanValue::Float(f) => Ok(PlanValue::Float(f)),
                        PlanValue::String(s) => {
                            s.parse::<f64>().map(PlanValue::Float).map_err(|_| {
                                Error::InvalidArgumentError(format!("Cannot cast '{}' to float", s))
                            })
                        }
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to float",
                            value
                        ))),
                    },
                    DataType::Utf8 | DataType::LargeUtf8 => match value {
                        PlanValue::String(s) => Ok(PlanValue::String(s)),
                        PlanValue::Integer(i) => Ok(PlanValue::String(i.to_string())),
                        PlanValue::Float(f) => Ok(PlanValue::String(f.to_string())),
                        PlanValue::Interval(_) => Err(Error::InvalidArgumentError(
                            "Cannot cast interval to string in aggregate expressions".into(),
                        )),
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to string",
                            value
                        ))),
                    },
                    DataType::Interval(IntervalUnit::MonthDayNano) => match value {
                        PlanValue::Interval(interval) => Ok(PlanValue::Interval(interval)),
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to interval",
                            value
                        ))),
                    },
                    DataType::Date32 => match value {
                        PlanValue::Date32(days) => Ok(PlanValue::Date32(days)),
                        PlanValue::String(text) => {
                            let days = parse_date32_literal(&text)?;
                            Ok(PlanValue::Date32(days))
                        }
                        _ => Err(Error::InvalidArgumentError(format!(
                            "Cannot cast {:?} to date",
                            value
                        ))),
                    },
                    _ => Err(Error::InvalidArgumentError(format!(
                        "CAST to {:?} not supported in aggregate expressions",
                        data_type
                    ))),
                }
            }
            // CASE: the simple form compares each WHEN against the operand;
            // the searched form treats each WHEN as a truthy test. No ELSE
            // means NULL.
            ScalarExpr::Case {
                operand,
                branches,
                else_expr,
            } => {
                let operand_value = if let Some(op) = operand {
                    Some(Self::evaluate_expr_with_plan_value_aggregates_and_row(
                        op,
                        aggregates,
                        row_batch,
                        column_lookup,
                        row_idx,
                    )?)
                } else {
                    None
                };

                for (when_expr, then_expr) in branches {
                    let matches = if let Some(ref op_val) = operand_value {
                        let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            when_expr,
                            aggregates,
                            row_batch,
                            column_lookup,
                            row_idx,
                        )?;
                        Self::simple_case_branch_matches(op_val, &when_val)
                    } else {
                        let when_val = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            when_expr,
                            aggregates,
                            row_batch,
                            column_lookup,
                            row_idx,
                        )?;
                        match when_val {
                            PlanValue::Integer(i) => i != 0,
                            PlanValue::Float(f) => f != 0.0,
                            PlanValue::Null => false,
                            _ => false,
                        }
                    };

                    if matches {
                        return Self::evaluate_expr_with_plan_value_aggregates_and_row(
                            then_expr,
                            aggregates,
                            row_batch,
                            column_lookup,
                            row_idx,
                        );
                    }
                }

                if let Some(else_e) = else_expr {
                    Self::evaluate_expr_with_plan_value_aggregates_and_row(
                        else_e,
                        aggregates,
                        row_batch,
                        column_lookup,
                        row_idx,
                    )
                } else {
                    Ok(PlanValue::Null)
                }
            }
            // COALESCE: first non-NULL argument wins; all-NULL yields NULL.
            ScalarExpr::Coalesce(exprs) => {
                for expr in exprs {
                    let value = Self::evaluate_expr_with_plan_value_aggregates_and_row(
                        expr,
                        aggregates,
                        row_batch,
                        column_lookup,
                        row_idx,
                    )?;
                    if !matches!(value, PlanValue::Null) {
                        return Ok(value);
                    }
                }
                Ok(PlanValue::Null)
            }
            // RANDOM() yields a fresh f64 per evaluation (non-deterministic).
            ScalarExpr::Random => Ok(PlanValue::Float(rand::random::<f64>())),
            ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
                "GetField not supported in aggregate expressions".into(),
            )),
            ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
                "Scalar subqueries not supported in aggregate expressions".into(),
            )),
        }
    }
7596
7597 fn simple_case_branch_matches(operand: &PlanValue, candidate: &PlanValue) -> bool {
7598 if matches!(operand, PlanValue::Null) || matches!(candidate, PlanValue::Null) {
7599 return false;
7600 }
7601
7602 match (operand, candidate) {
7603 (PlanValue::Integer(left), PlanValue::Integer(right)) => left == right,
7604 (PlanValue::Integer(left), PlanValue::Float(right)) => (*left as f64) == *right,
7605 (PlanValue::Float(left), PlanValue::Integer(right)) => *left == (*right as f64),
7606 (PlanValue::Float(left), PlanValue::Float(right)) => left == right,
7607 (PlanValue::String(left), PlanValue::String(right)) => left == right,
7608 (PlanValue::Struct(left), PlanValue::Struct(right)) => left == right,
7609 (PlanValue::Interval(left), PlanValue::Interval(right)) => {
7610 compare_interval_values(*left, *right) == std::cmp::Ordering::Equal
7611 }
7612 _ => operand == candidate,
7613 }
7614 }
7615
    /// Evaluate an aggregate-only scalar expression to an `Option<i64>`.
    ///
    /// `None` models SQL NULL. Aggregate references are looked up in
    /// `aggregates`, keyed by the `format!("{:?}", agg)` debug string.
    /// Column references, comparisons, CASE, COALESCE, and string/interval
    /// literals are rejected — this narrower evaluator only handles integer
    /// arithmetic over pre-computed aggregate values.
    fn evaluate_expr_with_aggregates(
        expr: &ScalarExpr<String>,
        aggregates: &FxHashMap<String, AggregateValue>,
    ) -> ExecutorResult<Option<i64>> {
        use llkv_expr::expr::BinaryOp;
        use llkv_expr::literal::Literal;

        match expr {
            ScalarExpr::Literal(Literal::Int128(v)) => Ok(Some(*v as i64)),
            // NOTE(review): float literals truncate toward zero here.
            ScalarExpr::Literal(Literal::Float64(v)) => Ok(Some(*v as i64)),
            ScalarExpr::Literal(Literal::Decimal128(value)) => {
                // Prefer the exact integer value when the decimal has one;
                // otherwise fall back to a (lossy) f64 -> i64 truncation.
                if let Some(int) = decimal_exact_i64(*value) {
                    Ok(Some(int))
                } else {
                    Ok(Some(value.to_f64() as i64))
                }
            }
            ScalarExpr::Literal(Literal::Boolean(v)) => Ok(Some(if *v { 1 } else { 0 })),
            ScalarExpr::Literal(Literal::String(_)) => Err(Error::InvalidArgumentError(
                "String literals not supported in aggregate expressions".into(),
            )),
            ScalarExpr::Literal(Literal::Date32(days)) => Ok(Some(*days as i64)),
            ScalarExpr::Literal(Literal::Null) => Ok(None),
            ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
                "Struct literals not supported in aggregate expressions".into(),
            )),
            ScalarExpr::Literal(Literal::Interval(_)) => Err(Error::InvalidArgumentError(
                "Interval literals not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Column(_) => Err(Error::InvalidArgumentError(
                "Column references not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Compare { .. } => Err(Error::InvalidArgumentError(
                "Comparisons not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Aggregate(agg) => {
                let key = format!("{:?}", agg);
                let value = aggregates.get(&key).ok_or_else(|| {
                    Error::Internal(format!("Aggregate value not found for key: {}", key))
                })?;
                Ok(value.as_i64())
            }
            // Logical NOT over the 0/1 encoding; NULL propagates.
            ScalarExpr::Not(inner) => {
                let value = Self::evaluate_expr_with_aggregates(inner, aggregates)?;
                Ok(value.map(|v| if v != 0 { 0 } else { 1 }))
            }
            ScalarExpr::IsNull { expr, negated } => {
                let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
                let is_null = value.is_none();
                Ok(Some(if is_null != *negated { 1 } else { 0 }))
            }
            ScalarExpr::Binary { left, op, right } => {
                let left_val = Self::evaluate_expr_with_aggregates(left, aggregates)?;
                let right_val = Self::evaluate_expr_with_aggregates(right, aggregates)?;

                match op {
                    // Checked i64 arithmetic: NULL operands propagate;
                    // division/modulo by zero yields NULL; any overflow
                    // (including i64::MIN / -1 via checked_div) is an error.
                    BinaryOp::Add
                    | BinaryOp::Subtract
                    | BinaryOp::Multiply
                    | BinaryOp::Divide
                    | BinaryOp::Modulo => match (left_val, right_val) {
                        (Some(lhs), Some(rhs)) => {
                            let result = match op {
                                BinaryOp::Add => lhs.checked_add(rhs),
                                BinaryOp::Subtract => lhs.checked_sub(rhs),
                                BinaryOp::Multiply => lhs.checked_mul(rhs),
                                BinaryOp::Divide => {
                                    if rhs == 0 {
                                        return Ok(None);
                                    }
                                    lhs.checked_div(rhs)
                                }
                                BinaryOp::Modulo => {
                                    if rhs == 0 {
                                        return Ok(None);
                                    }
                                    lhs.checked_rem(rhs)
                                }
                                BinaryOp::And
                                | BinaryOp::Or
                                | BinaryOp::BitwiseShiftLeft
                                | BinaryOp::BitwiseShiftRight => unreachable!(),
                            };

                            result.map(Some).ok_or_else(|| {
                                Error::InvalidArgumentError(
                                    "Arithmetic overflow in expression".into(),
                                )
                            })
                        }
                        _ => Ok(None),
                    },
                    BinaryOp::And => Ok(evaluate_option_logical_and(left_val, right_val)),
                    BinaryOp::Or => Ok(evaluate_option_logical_or(left_val, right_val)),
                    // Shifts wrap on the shift amount; NULL propagates.
                    BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
                        match (left_val, right_val) {
                            (Some(lhs), Some(rhs)) => {
                                let result = match op {
                                    BinaryOp::BitwiseShiftLeft => {
                                        Some(lhs.wrapping_shl(rhs as u32))
                                    }
                                    BinaryOp::BitwiseShiftRight => {
                                        Some(lhs.wrapping_shr(rhs as u32))
                                    }
                                    _ => unreachable!(),
                                };
                                Ok(result)
                            }
                            _ => Ok(None),
                        }
                    }
                }
            }
            // CAST delegates range checks to cast_aggregate_value; NULL
            // passes through untouched.
            ScalarExpr::Cast { expr, data_type } => {
                let value = Self::evaluate_expr_with_aggregates(expr, aggregates)?;
                match value {
                    Some(v) => Self::cast_aggregate_value(v, data_type).map(Some),
                    None => Ok(None),
                }
            }
            ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
                "GetField not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Case { .. } => Err(Error::InvalidArgumentError(
                "CASE not supported in aggregate-only expressions".into(),
            )),
            ScalarExpr::Coalesce(_) => Err(Error::InvalidArgumentError(
                "COALESCE not supported in aggregate-only expressions".into(),
            )),
            // RANDOM() scaled into the non-negative i64 range.
            ScalarExpr::Random => Ok(Some((rand::random::<f64>() * (i64::MAX as f64)) as i64)),
            ScalarExpr::ScalarSubquery(_) => Err(Error::InvalidArgumentError(
                "Scalar subqueries not supported in aggregate-only expressions".into(),
            )),
        }
    }
7751
7752 fn cast_aggregate_value(value: i64, data_type: &DataType) -> ExecutorResult<i64> {
7753 fn ensure_range(value: i64, min: i64, max: i64, ty: &DataType) -> ExecutorResult<i64> {
7754 if value < min || value > max {
7755 return Err(Error::InvalidArgumentError(format!(
7756 "value {} out of range for CAST target {:?}",
7757 value, ty
7758 )));
7759 }
7760 Ok(value)
7761 }
7762
7763 match data_type {
7764 DataType::Int8 => ensure_range(value, i8::MIN as i64, i8::MAX as i64, data_type),
7765 DataType::Int16 => ensure_range(value, i16::MIN as i64, i16::MAX as i64, data_type),
7766 DataType::Int32 => ensure_range(value, i32::MIN as i64, i32::MAX as i64, data_type),
7767 DataType::Int64 => Ok(value),
7768 DataType::UInt8 => ensure_range(value, 0, u8::MAX as i64, data_type),
7769 DataType::UInt16 => ensure_range(value, 0, u16::MAX as i64, data_type),
7770 DataType::UInt32 => ensure_range(value, 0, u32::MAX as i64, data_type),
7771 DataType::UInt64 => {
7772 if value < 0 {
7773 return Err(Error::InvalidArgumentError(format!(
7774 "value {} out of range for CAST target {:?}",
7775 value, data_type
7776 )));
7777 }
7778 Ok(value)
7779 }
7780 DataType::Float32 | DataType::Float64 => Ok(value),
7781 DataType::Boolean => Ok(if value == 0 { 0 } else { 1 }),
7782 DataType::Null => Err(Error::InvalidArgumentError(
7783 "CAST to NULL is not supported in aggregate-only expressions".into(),
7784 )),
7785 _ => Err(Error::InvalidArgumentError(format!(
7786 "CAST to {:?} is not supported in aggregate-only expressions",
7787 data_type
7788 ))),
7789 }
7790 }
7791}
7792
/// Evaluation context used when filtering/projecting expressions over a
/// cross-product (joined) batch.
struct CrossProductExpressionContext {
    // Output schema of the combined batch.
    schema: Arc<ExecutorSchema>,
    // Field id -> column index in the combined batch.
    field_id_to_index: FxHashMap<FieldId, usize>,
    // Cache of columns in numeric (Arrow) form, keyed by field id —
    // presumably to avoid repeated casts; confirm against usage.
    numeric_cache: FxHashMap<FieldId, ArrayRef>,
    // Cache of typed per-column accessors, keyed by field id.
    column_cache: FxHashMap<FieldId, ColumnAccessor>,
    // Materialized scalar-subquery results, keyed by subquery id.
    scalar_subquery_columns: FxHashMap<SubqueryId, ColumnAccessor>,
    // Next synthetic field id to allocate for derived values.
    next_field_id: FieldId,
}
7801
/// A typed, downcast-once view over a single Arrow column, so row values
/// can be read without repeated `as_any().downcast_ref` calls.
#[derive(Clone)]
enum ColumnAccessor {
    Int64(Arc<Int64Array>),
    Float64(Arc<Float64Array>),
    Boolean(Arc<BooleanArray>),
    Utf8(Arc<StringArray>),
    Date32(Arc<Date32Array>),
    Interval(Arc<IntervalMonthDayNanoArray>),
    // Decimal column plus its scale (taken from `DataType::Decimal128`).
    Decimal128 {
        array: Arc<Decimal128Array>,
        scale: i8,
    },
    // All-NULL column; only the row count is retained.
    Null(usize),
}
7816
7817impl ColumnAccessor {
7818 fn from_array(array: &ArrayRef) -> ExecutorResult<Self> {
7819 match array.data_type() {
7820 DataType::Int64 => {
7821 let typed = array
7822 .as_any()
7823 .downcast_ref::<Int64Array>()
7824 .ok_or_else(|| Error::Internal("expected Int64 array".into()))?
7825 .clone();
7826 Ok(Self::Int64(Arc::new(typed)))
7827 }
7828 DataType::Float64 => {
7829 let typed = array
7830 .as_any()
7831 .downcast_ref::<Float64Array>()
7832 .ok_or_else(|| Error::Internal("expected Float64 array".into()))?
7833 .clone();
7834 Ok(Self::Float64(Arc::new(typed)))
7835 }
7836 DataType::Boolean => {
7837 let typed = array
7838 .as_any()
7839 .downcast_ref::<BooleanArray>()
7840 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
7841 .clone();
7842 Ok(Self::Boolean(Arc::new(typed)))
7843 }
7844 DataType::Utf8 => {
7845 let typed = array
7846 .as_any()
7847 .downcast_ref::<StringArray>()
7848 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
7849 .clone();
7850 Ok(Self::Utf8(Arc::new(typed)))
7851 }
7852 DataType::Date32 => {
7853 let typed = array
7854 .as_any()
7855 .downcast_ref::<Date32Array>()
7856 .ok_or_else(|| Error::Internal("expected Date32 array".into()))?
7857 .clone();
7858 Ok(Self::Date32(Arc::new(typed)))
7859 }
7860 DataType::Interval(IntervalUnit::MonthDayNano) => {
7861 let typed = array
7862 .as_any()
7863 .downcast_ref::<IntervalMonthDayNanoArray>()
7864 .ok_or_else(|| Error::Internal("expected IntervalMonthDayNano array".into()))?
7865 .clone();
7866 Ok(Self::Interval(Arc::new(typed)))
7867 }
7868 DataType::Decimal128(_, scale) => {
7869 let typed = array
7870 .as_any()
7871 .downcast_ref::<Decimal128Array>()
7872 .ok_or_else(|| Error::Internal("expected Decimal128 array".into()))?
7873 .clone();
7874 Ok(Self::Decimal128 {
7875 array: Arc::new(typed),
7876 scale: *scale,
7877 })
7878 }
7879 DataType::Null => Ok(Self::Null(array.len())),
7880 other => Err(Error::InvalidArgumentError(format!(
7881 "unsupported column type {:?} in cross product filter",
7882 other
7883 ))),
7884 }
7885 }
7886
7887 fn from_numeric_array(numeric: &ArrayRef) -> ExecutorResult<Self> {
7888 let casted = cast(numeric, &DataType::Float64)?;
7889 let float_array = casted
7890 .as_any()
7891 .downcast_ref::<Float64Array>()
7892 .expect("cast to Float64 failed")
7893 .clone();
7894 Ok(Self::Float64(Arc::new(float_array)))
7895 }
7896
7897 fn len(&self) -> usize {
7898 match self {
7899 ColumnAccessor::Int64(array) => array.len(),
7900 ColumnAccessor::Float64(array) => array.len(),
7901 ColumnAccessor::Boolean(array) => array.len(),
7902 ColumnAccessor::Utf8(array) => array.len(),
7903 ColumnAccessor::Date32(array) => array.len(),
7904 ColumnAccessor::Interval(array) => array.len(),
7905 ColumnAccessor::Decimal128 { array, .. } => array.len(),
7906 ColumnAccessor::Null(len) => *len,
7907 }
7908 }
7909
7910 fn is_null(&self, idx: usize) -> bool {
7911 match self {
7912 ColumnAccessor::Int64(array) => array.is_null(idx),
7913 ColumnAccessor::Float64(array) => array.is_null(idx),
7914 ColumnAccessor::Boolean(array) => array.is_null(idx),
7915 ColumnAccessor::Utf8(array) => array.is_null(idx),
7916 ColumnAccessor::Date32(array) => array.is_null(idx),
7917 ColumnAccessor::Interval(array) => array.is_null(idx),
7918 ColumnAccessor::Decimal128 { array, .. } => array.is_null(idx),
7919 ColumnAccessor::Null(_) => true,
7920 }
7921 }
7922
7923 fn literal_at(&self, idx: usize) -> ExecutorResult<Literal> {
7924 if self.is_null(idx) {
7925 return Ok(Literal::Null);
7926 }
7927 match self {
7928 ColumnAccessor::Int64(array) => Ok(Literal::Int128(array.value(idx) as i128)),
7929 ColumnAccessor::Float64(array) => Ok(Literal::Float64(array.value(idx))),
7930 ColumnAccessor::Boolean(array) => Ok(Literal::Boolean(array.value(idx))),
7931 ColumnAccessor::Utf8(array) => Ok(Literal::String(array.value(idx).to_string())),
7932 ColumnAccessor::Date32(array) => Ok(Literal::Date32(array.value(idx))),
7933 ColumnAccessor::Interval(array) => Ok(Literal::Interval(interval_value_from_arrow(
7934 array.value(idx),
7935 ))),
7936 ColumnAccessor::Decimal128 { array, .. } => Ok(Literal::Int128(array.value(idx))),
7937 ColumnAccessor::Null(_) => Ok(Literal::Null),
7938 }
7939 }
7940
7941 fn as_array_ref(&self) -> ArrayRef {
7942 match self {
7943 ColumnAccessor::Int64(array) => Arc::clone(array) as ArrayRef,
7944 ColumnAccessor::Float64(array) => Arc::clone(array) as ArrayRef,
7945 ColumnAccessor::Boolean(array) => Arc::clone(array) as ArrayRef,
7946 ColumnAccessor::Utf8(array) => Arc::clone(array) as ArrayRef,
7947 ColumnAccessor::Date32(array) => Arc::clone(array) as ArrayRef,
7948 ColumnAccessor::Interval(array) => Arc::clone(array) as ArrayRef,
7949 ColumnAccessor::Decimal128 { array, .. } => Arc::clone(array) as ArrayRef,
7950 ColumnAccessor::Null(len) => new_null_array(&DataType::Null, *len),
7951 }
7952 }
7953}
7954
/// Coarse classification of an evaluated expression result: all numeric Arrow
/// types are lumped into `Numeric` (compared via compute kernels), while
/// boolean/string/interval keep typed wrappers and `Null` tracks only length.
#[derive(Clone)]
enum ValueArray {
    Numeric(ArrayRef),
    Boolean(Arc<BooleanArray>),
    Utf8(Arc<StringArray>),
    Interval(Arc<IntervalMonthDayNanoArray>),
    // All-null result (DataType::Null) of the given row count.
    Null(usize),
}
7963
7964impl ValueArray {
7965 fn from_array(array: ArrayRef) -> ExecutorResult<Self> {
7966 match array.data_type() {
7967 DataType::Boolean => {
7968 let typed = array
7969 .as_any()
7970 .downcast_ref::<BooleanArray>()
7971 .ok_or_else(|| Error::Internal("expected Boolean array".into()))?
7972 .clone();
7973 Ok(Self::Boolean(Arc::new(typed)))
7974 }
7975 DataType::Utf8 => {
7976 let typed = array
7977 .as_any()
7978 .downcast_ref::<StringArray>()
7979 .ok_or_else(|| Error::Internal("expected Utf8 array".into()))?
7980 .clone();
7981 Ok(Self::Utf8(Arc::new(typed)))
7982 }
7983 DataType::Interval(IntervalUnit::MonthDayNano) => {
7984 let typed = array
7985 .as_any()
7986 .downcast_ref::<IntervalMonthDayNanoArray>()
7987 .ok_or_else(|| Error::Internal("expected IntervalMonthDayNano array".into()))?
7988 .clone();
7989 Ok(Self::Interval(Arc::new(typed)))
7990 }
7991 DataType::Null => Ok(Self::Null(array.len())),
7992 DataType::Int8
7993 | DataType::Int16
7994 | DataType::Int32
7995 | DataType::Int64
7996 | DataType::UInt8
7997 | DataType::UInt16
7998 | DataType::UInt32
7999 | DataType::UInt64
8000 | DataType::Date32
8001 | DataType::Float32
8002 | DataType::Float64
8003 | DataType::Decimal128(_, _) => Ok(Self::Numeric(array)),
8004 other => Err(Error::InvalidArgumentError(format!(
8005 "unsupported data type {:?} in cross product expression",
8006 other
8007 ))),
8008 }
8009 }
8010
8011 fn len(&self) -> usize {
8012 match self {
8013 ValueArray::Numeric(array) => array.len(),
8014 ValueArray::Boolean(array) => array.len(),
8015 ValueArray::Utf8(array) => array.len(),
8016 ValueArray::Interval(array) => array.len(),
8017 ValueArray::Null(len) => *len,
8018 }
8019 }
8020
8021 fn as_array_ref(&self) -> ArrayRef {
8022 match self {
8023 ValueArray::Numeric(arr) => arr.clone(),
8024 ValueArray::Boolean(arr) => arr.clone() as ArrayRef,
8025 ValueArray::Utf8(arr) => arr.clone() as ArrayRef,
8026 ValueArray::Interval(arr) => arr.clone() as ArrayRef,
8027 ValueArray::Null(len) => new_null_array(&DataType::Null, *len),
8028 }
8029 }
8030}
8031
/// Three-valued (SQL/Kleene) AND over nullable booleans: `false` dominates,
/// `true AND true` is `true`, and any remaining combination with NULL
/// (`None`) is NULL.
fn truth_and(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(false) || rhs == Some(false) {
        return Some(false);
    }
    if lhs == Some(true) && rhs == Some(true) {
        return Some(true);
    }
    None
}
8039
/// Three-valued (SQL/Kleene) OR over nullable booleans: `true` dominates,
/// `false OR false` is `false`, and any remaining combination with NULL
/// (`None`) is NULL.
fn truth_or(lhs: Option<bool>, rhs: Option<bool>) -> Option<bool> {
    if lhs == Some(true) || rhs == Some(true) {
        return Some(true);
    }
    if lhs == Some(false) && rhs == Some(false) {
        return Some(false);
    }
    None
}
8047
/// Three-valued (SQL/Kleene) NOT over a nullable boolean: NOT NULL is NULL.
///
/// Idiom fix: the explicit `match` over `Some(true)`/`Some(false)`/`None`
/// collapses to `Option::map`, which has identical semantics.
fn truth_not(value: Option<bool>) -> Option<bool> {
    value.map(|flag| !flag)
}
8055
8056fn literal_to_constant_array(literal: &Literal, len: usize) -> ExecutorResult<ArrayRef> {
8057 match literal {
8058 Literal::Int128(v) => {
8059 let value = i64::try_from(*v).unwrap_or(0);
8060 let values = vec![value; len];
8061 Ok(Arc::new(Int64Array::from(values)) as ArrayRef)
8062 }
8063 Literal::Float64(v) => {
8064 let values = vec![*v; len];
8065 Ok(Arc::new(Float64Array::from(values)) as ArrayRef)
8066 }
8067 Literal::Boolean(v) => {
8068 let values = vec![Some(*v); len];
8069 Ok(Arc::new(BooleanArray::from(values)) as ArrayRef)
8070 }
8071 Literal::String(v) => {
8072 let values: Vec<Option<String>> = (0..len).map(|_| Some(v.clone())).collect();
8073 Ok(Arc::new(StringArray::from(values)) as ArrayRef)
8074 }
8075 Literal::Date32(days) => {
8076 let values = vec![*days; len];
8077 Ok(Arc::new(Date32Array::from(values)) as ArrayRef)
8078 }
8079 Literal::Decimal128(value) => {
8080 let iter = std::iter::repeat_n(value.raw_value(), len);
8081 let array = Decimal128Array::from_iter_values(iter)
8082 .with_precision_and_scale(value.precision(), value.scale())
8083 .map_err(|err| {
8084 Error::InvalidArgumentError(format!(
8085 "failed to synthesize decimal literal array: {err}"
8086 ))
8087 })?;
8088 Ok(Arc::new(array) as ArrayRef)
8089 }
8090 Literal::Interval(interval) => {
8091 let value = interval_value_to_arrow(*interval);
8092 let values = vec![value; len];
8093 Ok(Arc::new(IntervalMonthDayNanoArray::from(values)) as ArrayRef)
8094 }
8095 Literal::Null => Ok(new_null_array(&DataType::Null, len)),
8096 Literal::Struct(_) => Err(Error::InvalidArgumentError(
8097 "struct literals are not supported in cross product filters".into(),
8098 )),
8099 }
8100}
8101
/// Collapse a slice of scalar-subquery result literals into a single typed
/// Arrow array, picking one target type for the whole column.
///
/// Promotion rules: NULLs coexist with any type; String/Boolean/Date/Interval
/// may not mix with each other or with numerics; among numerics, Float wins
/// over Decimal, which wins over Integer. Decimal results are aligned to the
/// maximum scale seen, and integers mixed with decimals are rescaled into
/// decimals.
///
/// # Errors
/// `InvalidArgumentError` on incompatible type mixes, out-of-range integers,
/// struct literals, or decimal rescale/precision failures.
fn literals_to_array(values: &[Literal]) -> ExecutorResult<ArrayRef> {
    // Internal tag for the element type the output array will use.
    #[derive(Copy, Clone, Eq, PartialEq)]
    enum LiteralArrayKind {
        Null,
        Integer,
        Float,
        Boolean,
        String,
        Date32,
        Interval,
        Decimal,
    }

    if values.is_empty() {
        return Ok(new_null_array(&DataType::Null, 0));
    }

    // First pass: record which literal categories are present.
    let mut has_integer = false;
    let mut has_float = false;
    let mut has_decimal = false;
    let mut has_boolean = false;
    let mut has_string = false;
    let mut has_date = false;
    let mut has_interval = false;

    for literal in values {
        match literal {
            Literal::Null => {}
            Literal::Int128(_) => {
                has_integer = true;
            }
            Literal::Float64(_) => {
                has_float = true;
            }
            Literal::Decimal128(_) => {
                has_decimal = true;
            }
            Literal::Boolean(_) => {
                has_boolean = true;
            }
            Literal::String(_) => {
                has_string = true;
            }
            Literal::Date32(_) => {
                has_date = true;
            }
            Literal::Interval(_) => {
                has_interval = true;
            }
            Literal::Struct(_) => {
                return Err(Error::InvalidArgumentError(
                    "struct scalar subquery results are not supported".into(),
                ));
            }
        }
    }

    // Reject any mix of non-numeric categories, or non-numeric with numeric.
    // `mixed_numeric` counts how many distinct numeric categories appeared;
    // numerics mixing with each other is allowed (they promote below).
    let mixed_numeric = has_integer as u8 + has_float as u8 + has_decimal as u8;
    if has_string && (has_boolean || has_date || has_interval || mixed_numeric > 0)
        || has_boolean && (has_date || has_interval || mixed_numeric > 0)
        || has_date && (has_interval || mixed_numeric > 0)
        || has_interval && (mixed_numeric > 0)
    {
        return Err(Error::InvalidArgumentError(
            "mixed scalar subquery result types are not supported".into(),
        ));
    }

    // Pick the output kind; among numerics, Float > Decimal > Integer.
    let target_kind = if has_string {
        LiteralArrayKind::String
    } else if has_interval {
        LiteralArrayKind::Interval
    } else if has_date {
        LiteralArrayKind::Date32
    } else if has_boolean {
        LiteralArrayKind::Boolean
    } else if has_float {
        LiteralArrayKind::Float
    } else if has_decimal {
        LiteralArrayKind::Decimal
    } else if has_integer {
        LiteralArrayKind::Integer
    } else {
        LiteralArrayKind::Null
    };

    // Second pass: build the array. The `unreachable!` arms are sound because
    // the category scan above guarantees only matching literals (plus NULLs)
    // reach each branch.
    match target_kind {
        LiteralArrayKind::Null => Ok(new_null_array(&DataType::Null, values.len())),
        LiteralArrayKind::Integer => {
            let mut coerced: Vec<Option<i64>> = Vec::with_capacity(values.len());
            for literal in values {
                match literal {
                    Literal::Null => coerced.push(None),
                    Literal::Int128(value) => {
                        let v = i64::try_from(*value).map_err(|_| {
                            Error::InvalidArgumentError(
                                "scalar subquery integer result exceeds supported range".into(),
                            )
                        })?;
                        coerced.push(Some(v));
                    }
                    _ => unreachable!("non-integer value encountered in integer array"),
                }
            }
            let array = Int64Array::from_iter(coerced);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Float => {
            let mut coerced: Vec<Option<f64>> = Vec::with_capacity(values.len());
            for literal in values {
                match literal {
                    Literal::Null => coerced.push(None),
                    // Integers and decimals are widened to f64 alongside floats.
                    Literal::Int128(_) | Literal::Float64(_) | Literal::Decimal128(_) => {
                        let value = literal_to_f64(literal).ok_or_else(|| {
                            Error::InvalidArgumentError(
                                "failed to coerce scalar subquery value to FLOAT".into(),
                            )
                        })?;
                        coerced.push(Some(value));
                    }
                    _ => unreachable!("non-numeric value encountered in float array"),
                }
            }
            let array = Float64Array::from_iter(coerced);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Boolean => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::Boolean(flag) => Some(*flag),
                _ => unreachable!("non-boolean value encountered in boolean array"),
            });
            let array = BooleanArray::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::String => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::String(value) => Some(value.clone()),
                _ => unreachable!("non-string value encountered in string array"),
            });
            let array = StringArray::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Date32 => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::Date32(days) => Some(*days),
                _ => unreachable!("non-date value encountered in date array"),
            });
            let array = Date32Array::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Interval => {
            let iter = values.iter().map(|literal| match literal {
                Literal::Null => None,
                Literal::Interval(interval) => Some(interval_value_to_arrow(*interval)),
                _ => unreachable!("non-interval value encountered in interval array"),
            });
            let array = IntervalMonthDayNanoArray::from_iter(iter);
            Ok(Arc::new(array) as ArrayRef)
        }
        LiteralArrayKind::Decimal => {
            // Align every decimal (and integer) to the widest scale seen so
            // the whole column shares one (precision, scale) pair.
            let mut target_scale: Option<i8> = None;
            for literal in values {
                if let Literal::Decimal128(value) = literal {
                    target_scale = Some(match target_scale {
                        Some(scale) => scale.max(value.scale()),
                        None => value.scale(),
                    });
                }
            }
            // Safe: target_kind == Decimal implies at least one Decimal128.
            let target_scale = target_scale.expect("decimal literal expected");

            let mut max_precision: u8 = 1;
            let mut aligned: Vec<Option<DecimalValue>> = Vec::with_capacity(values.len());
            for literal in values {
                match literal {
                    Literal::Null => aligned.push(None),
                    Literal::Decimal128(value) => {
                        let adjusted = if value.scale() != target_scale {
                            llkv_compute::scalar::decimal::rescale(*value, target_scale).map_err(
                                |err| {
                                    Error::InvalidArgumentError(format!(
                                        "failed to align decimal scale: {err}"
                                    ))
                                },
                            )?
                        } else {
                            *value
                        };
                        max_precision = max_precision.max(adjusted.precision());
                        aligned.push(Some(adjusted));
                    }
                    Literal::Int128(value) => {
                        // Integers enter as scale-0 decimals, then rescale.
                        let decimal = DecimalValue::new(*value, 0).map_err(|err| {
                            Error::InvalidArgumentError(format!(
                                "failed to build decimal from integer: {err}"
                            ))
                        })?;
                        let decimal = llkv_compute::scalar::decimal::rescale(decimal, target_scale)
                            .map_err(|err| {
                                Error::InvalidArgumentError(format!(
                                    "failed to align integer decimal scale: {err}"
                                ))
                            })?;
                        max_precision = max_precision.max(decimal.precision());
                        aligned.push(Some(decimal));
                    }
                    _ => unreachable!("unexpected literal in decimal array"),
                }
            }

            let mut builder = Decimal128Builder::new()
                .with_precision_and_scale(max_precision, target_scale)
                .map_err(|err| {
                    Error::InvalidArgumentError(format!(
                        "invalid Decimal128 precision/scale: {err}"
                    ))
                })?;
            for value in aligned {
                match value {
                    Some(decimal) => builder.append_value(decimal.raw_value()),
                    None => builder.append_null(),
                }
            }
            let array = builder.finish();
            Ok(Arc::new(array) as ArrayRef)
        }
    }
}
8333
impl CrossProductExpressionContext {
    /// Build a context for the given joined `schema`, assigning each column a
    /// synthetic `FieldId` starting at 1. `lookup` maps column names to
    /// positions for the resulting `ExecutorSchema`.
    fn new(schema: &Schema, lookup: FxHashMap<String, usize>) -> ExecutorResult<Self> {
        let mut columns = Vec::with_capacity(schema.fields().len());
        let mut field_id_to_index = FxHashMap::default();
        let mut next_field_id: FieldId = 1;

        for (idx, field) in schema.fields().iter().enumerate() {
            // NOTE(review): this guard uses u32::MAX while
            // allocate_synthetic_field_id uses FieldId::MAX — assumes
            // FieldId == u32; confirm and unify.
            if next_field_id == u32::MAX {
                return Err(Error::Internal(
                    "cross product projection exhausted FieldId space".into(),
                ));
            }

            let executor_column = ExecutorColumn {
                name: field.name().clone(),
                data_type: field.data_type().clone(),
                nullable: field.is_nullable(),
                primary_key: false,
                unique: false,
                field_id: next_field_id,
                check_expr: None,
            };
            let field_id = next_field_id;
            next_field_id = next_field_id.saturating_add(1);

            columns.push(executor_column);
            field_id_to_index.insert(field_id, idx);
        }

        Ok(Self {
            schema: Arc::new(ExecutorSchema { columns, lookup }),
            field_id_to_index,
            numeric_cache: FxHashMap::default(),
            column_cache: FxHashMap::default(),
            scalar_subquery_columns: FxHashMap::default(),
            next_field_id,
        })
    }

    /// Borrow the executor-facing schema for this cross-product result.
    fn schema(&self) -> &ExecutorSchema {
        self.schema.as_ref()
    }

    /// Resolve a column name to its synthetic `FieldId`, if present.
    fn field_id_for_column(&self, name: &str) -> Option<FieldId> {
        self.schema.resolve(name).map(|column| column.field_id)
    }

    /// Drop all per-batch caches; call between batches since cached arrays
    /// are only valid for the batch they were built from.
    fn reset(&mut self) {
        self.numeric_cache.clear();
        self.column_cache.clear();
        self.scalar_subquery_columns.clear();
    }

    /// Hand out a fresh `FieldId` beyond the schema's columns (used for
    /// synthetic columns such as materialized subquery results).
    fn allocate_synthetic_field_id(&mut self) -> ExecutorResult<FieldId> {
        if self.next_field_id == FieldId::MAX {
            return Err(Error::Internal(
                "cross product projection exhausted FieldId space".into(),
            ));
        }
        let field_id = self.next_field_id;
        self.next_field_id = self.next_field_id.saturating_add(1);
        Ok(field_id)
    }

    /// Register the materialized per-row result column of a scalar subquery
    /// so `materialize_scalar_array` can serve `ScalarSubquery` expressions.
    fn register_scalar_subquery_column(
        &mut self,
        subquery_id: SubqueryId,
        accessor: ColumnAccessor,
    ) {
        self.scalar_subquery_columns.insert(subquery_id, accessor);
    }

    /// Test helper: translate a name-keyed scalar expression against this
    /// context's schema and evaluate it numerically over `batch`.
    #[cfg(test)]
    fn evaluate(
        &mut self,
        expr: &ScalarExpr<String>,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let translated = translate_scalar(expr, self.schema.as_ref(), |name| {
            Error::InvalidArgumentError(format!(
                "column '{}' not found in cross product result",
                name
            ))
        })?;

        self.evaluate_numeric(&translated, batch)
    }

    /// Evaluate a predicate over `batch` and collapse three-valued results to
    /// a selection mask: NULL (unknown) rows become `false`, per SQL WHERE
    /// semantics. `exists_eval` resolves EXISTS subqueries per row.
    fn evaluate_predicate_mask(
        &mut self,
        expr: &LlkvExpr<'static, FieldId>,
        batch: &RecordBatch,
        mut exists_eval: impl FnMut(
            &mut Self,
            &llkv_expr::SubqueryExpr,
            usize,
            &RecordBatch,
        ) -> ExecutorResult<Option<bool>>,
    ) -> ExecutorResult<BooleanArray> {
        let truths = self.evaluate_predicate_truths(expr, batch, &mut exists_eval)?;
        let mut builder = BooleanBuilder::with_capacity(truths.len());
        for value in truths {
            builder.append_value(value.unwrap_or(false));
        }
        Ok(builder.finish())
    }

    /// Recursively evaluate a predicate to per-row three-valued truth
    /// (`Some(true)`/`Some(false)`/`None` = unknown), combining children with
    /// Kleene logic via `truth_and`/`truth_or`/`truth_not`.
    fn evaluate_predicate_truths(
        &mut self,
        expr: &LlkvExpr<'static, FieldId>,
        batch: &RecordBatch,
        exists_eval: &mut impl FnMut(
            &mut Self,
            &llkv_expr::SubqueryExpr,
            usize,
            &RecordBatch,
        ) -> ExecutorResult<Option<bool>>,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        match expr {
            LlkvExpr::Literal(value) => Ok(vec![Some(*value); batch.num_rows()]),
            LlkvExpr::And(children) => {
                // Empty AND is vacuously true.
                if children.is_empty() {
                    return Ok(vec![Some(true); batch.num_rows()]);
                }
                let mut result =
                    self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
                for child in &children[1..] {
                    let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
                    for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
                        *lhs = truth_and(*lhs, rhs);
                    }
                }
                Ok(result)
            }
            LlkvExpr::Or(children) => {
                // Empty OR is vacuously false.
                if children.is_empty() {
                    return Ok(vec![Some(false); batch.num_rows()]);
                }
                let mut result =
                    self.evaluate_predicate_truths(&children[0], batch, exists_eval)?;
                for child in &children[1..] {
                    let next = self.evaluate_predicate_truths(child, batch, exists_eval)?;
                    for (lhs, rhs) in result.iter_mut().zip(next.into_iter()) {
                        *lhs = truth_or(*lhs, rhs);
                    }
                }
                Ok(result)
            }
            LlkvExpr::Not(inner) => {
                let mut values = self.evaluate_predicate_truths(inner, batch, exists_eval)?;
                for value in &mut values {
                    *value = truth_not(*value);
                }
                Ok(values)
            }
            LlkvExpr::Pred(filter) => self.evaluate_filter_truths(filter, batch),
            LlkvExpr::Compare { left, op, right } => {
                self.evaluate_compare_truths(left, *op, right, batch)
            }
            LlkvExpr::InList {
                expr: target,
                list,
                negated,
            } => self.evaluate_in_list_truths(target, list, *negated, batch),
            LlkvExpr::IsNull { expr, negated } => {
                self.evaluate_is_null_truths(expr, *negated, batch)
            }
            LlkvExpr::Exists(subquery_expr) => {
                // EXISTS is delegated row-by-row to the caller's evaluator.
                let mut values = Vec::with_capacity(batch.num_rows());
                for row_idx in 0..batch.num_rows() {
                    let value = exists_eval(self, subquery_expr, row_idx, batch)?;
                    values.push(value);
                }
                Ok(values)
            }
        }
    }

    /// Evaluate a single-column filter predicate to per-row truth values.
    /// NULL inputs produce `None` (unknown) except for the explicit
    /// IsNull/IsNotNull operators, which are always definite.
    fn evaluate_filter_truths(
        &mut self,
        filter: &Filter<FieldId>,
        batch: &RecordBatch,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        let accessor = self.column_accessor(filter.field_id, batch)?;
        let len = accessor.len();

        match &filter.op {
            Operator::IsNull => {
                let mut out = Vec::with_capacity(len);
                for idx in 0..len {
                    out.push(Some(accessor.is_null(idx)));
                }
                Ok(out)
            }
            Operator::IsNotNull => {
                let mut out = Vec::with_capacity(len);
                for idx in 0..len {
                    out.push(Some(!accessor.is_null(idx)));
                }
                Ok(out)
            }
            // All remaining operators dispatch on the column's type, building
            // a typed predicate once and applying it per row.
            _ => match accessor {
                ColumnAccessor::Int64(array) => {
                    let predicate = build_fixed_width_predicate::<Int64Type>(&filter.op)
                        .map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Float64(array) => {
                    let predicate = build_fixed_width_predicate::<Float64Type>(&filter.op)
                        .map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Boolean(array) => {
                    let predicate =
                        build_bool_predicate(&filter.op).map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Utf8(array) => {
                    let predicate =
                        build_var_width_predicate(&filter.op).map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Date32(array) => {
                    // Date32 days-since-epoch compare as plain i32 values.
                    let predicate = build_fixed_width_predicate::<Int32Type>(&filter.op)
                        .map_err(Error::predicate_build)?;
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                        } else {
                            let value = array.value(idx);
                            out.push(Some(predicate.matches(&value)));
                        }
                    }
                    Ok(out)
                }
                ColumnAccessor::Interval(array) => {
                    // No typed predicate for intervals: fall back to per-row
                    // literal evaluation.
                    let array = array.as_ref();
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                            continue;
                        }
                        let literal =
                            Literal::Interval(interval_value_from_arrow(array.value(idx)));
                        let matches = evaluate_filter_against_literal(&literal, &filter.op)?;
                        out.push(Some(matches));
                    }
                    Ok(out)
                }
                ColumnAccessor::Decimal128 { array, scale } => {
                    // NOTE(review): decimals are compared through f64, which
                    // loses precision for large raw values — confirm this is
                    // acceptable for filter semantics.
                    let scale_factor = 10_f64.powi(scale as i32);
                    let mut out = Vec::with_capacity(len);
                    for idx in 0..len {
                        if array.is_null(idx) {
                            out.push(None);
                            continue;
                        }
                        let raw_value = array.value(idx);
                        let decimal_value = raw_value as f64 / scale_factor;
                        let literal = Literal::Float64(decimal_value);
                        let matches = evaluate_filter_against_literal(&literal, &filter.op)?;
                        out.push(Some(matches));
                    }
                    Ok(out)
                }
                // An all-null column yields unknown for every row.
                ColumnAccessor::Null(len) => Ok(vec![None; len]),
            },
        }
    }

    /// Evaluate `left <op> right` per row via the compute comparison kernel,
    /// returning three-valued truth (kernel nulls map to `None`).
    fn evaluate_compare_truths(
        &mut self,
        left: &ScalarExpr<FieldId>,
        op: CompareOp,
        right: &ScalarExpr<FieldId>,
        batch: &RecordBatch,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        let left_values = self.materialize_value_array(left, batch)?;
        let right_values = self.materialize_value_array(right, batch)?;

        if left_values.len() != right_values.len() {
            return Err(Error::Internal(
                "mismatched compare operand lengths in cross product filter".into(),
            ));
        }

        let len = left_values.len();

        // Comparing against an all-null side is unknown for every row.
        if matches!(left_values, ValueArray::Null(_)) || matches!(right_values, ValueArray::Null(_))
        {
            return Ok(vec![None; len]);
        }

        let lhs_arr = left_values.as_array_ref();
        let rhs_arr = right_values.as_array_ref();

        let result_array = llkv_compute::kernels::compute_compare(&lhs_arr, op, &rhs_arr)?;
        let bool_array = result_array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("compute_compare must return BooleanArray");

        let out: Vec<Option<bool>> = bool_array.iter().collect();
        Ok(out)
    }

    /// Evaluate `expr IS [NOT] NULL` per row. The result is always definite
    /// (never `None`), since null-ness itself is known.
    fn evaluate_is_null_truths(
        &mut self,
        expr: &ScalarExpr<FieldId>,
        negated: bool,
        batch: &RecordBatch,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        let values = self.materialize_value_array(expr, batch)?;
        let len = values.len();

        // Fast path: an all-null expression is NULL in every row.
        if let ValueArray::Null(len) = values {
            let result = if negated { Some(false) } else { Some(true) };
            return Ok(vec![result; len]);
        }

        let arr = values.as_array_ref();
        let mut out = Vec::with_capacity(len);
        for idx in 0..len {
            let is_null = arr.is_null(idx);
            let result = if negated { !is_null } else { is_null };
            out.push(Some(result));
        }
        Ok(out)
    }

    /// Evaluate `target [NOT] IN (list…)` per row with SQL semantics:
    /// equality against each candidate is OR-ed with Kleene logic, so a NULL
    /// candidate can turn a non-match into unknown rather than false.
    fn evaluate_in_list_truths(
        &mut self,
        target: &ScalarExpr<FieldId>,
        list: &[ScalarExpr<FieldId>],
        negated: bool,
        batch: &RecordBatch,
    ) -> ExecutorResult<Vec<Option<bool>>> {
        let target_values = self.materialize_value_array(target, batch)?;
        let list_values = list
            .iter()
            .map(|expr| self.materialize_value_array(expr, batch))
            .collect::<ExecutorResult<Vec<_>>>()?;

        let len = target_values.len();
        for values in &list_values {
            if values.len() != len {
                return Err(Error::Internal(
                    "mismatched IN list operand lengths in cross product filter".into(),
                ));
            }
        }

        // NULL target: the whole IN test is unknown for every row.
        if matches!(target_values, ValueArray::Null(_)) {
            return Ok(vec![None; len]);
        }

        let target_arr = target_values.as_array_ref();
        let mut combined_result: Option<BooleanArray> = None;

        for candidate in &list_values {
            if matches!(candidate, ValueArray::Null(_)) {
                // An all-null candidate contributes "unknown" to the OR.
                let nulls = new_null_array(&DataType::Boolean, len);
                let bool_nulls = nulls
                    .as_any()
                    .downcast_ref::<BooleanArray>()
                    .unwrap()
                    .clone();

                match combined_result {
                    None => combined_result = Some(bool_nulls),
                    Some(prev) => {
                        combined_result = Some(or_kleene(&prev, &bool_nulls)?);
                    }
                }
                continue;
            }

            let candidate_arr = candidate.as_array_ref();

            let cmp =
                llkv_compute::kernels::compute_compare(&target_arr, CompareOp::Eq, &candidate_arr)?;
            let bool_cmp = cmp
                .as_any()
                .downcast_ref::<BooleanArray>()
                .expect("compute_compare returns BooleanArray")
                .clone();

            match combined_result {
                None => combined_result = Some(bool_cmp),
                Some(prev) => {
                    combined_result = Some(or_kleene(&prev, &bool_cmp)?);
                }
            }
        }

        // Empty IN list: nothing matches.
        let final_bool = combined_result.unwrap_or_else(|| {
            let mut builder = BooleanBuilder::new();
            for _ in 0..len {
                builder.append_value(false);
            }
            builder.finish()
        });

        // `not` is null-preserving, so unknown stays unknown under NOT IN.
        let final_bool = if negated {
            not(&final_bool)?
        } else {
            final_bool
        };

        let out: Vec<Option<bool>> = final_bool.iter().collect();
        Ok(out)
    }

    /// Evaluate a scalar expression through the numeric kernels, gathering
    /// the arrays for every field the expression references.
    fn evaluate_numeric(
        &mut self,
        expr: &ScalarExpr<FieldId>,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        let mut required = FxHashSet::default();
        collect_field_ids(expr, &mut required);

        let mut arrays = NumericArrayMap::default();
        for field_id in required {
            let numeric = self.numeric_array(field_id, batch)?;
            arrays.insert(field_id, numeric);
        }

        NumericKernels::evaluate_batch(expr, batch.num_rows(), &arrays)
    }

    /// Fetch (and cache) the batch column for `field_id` as a plain ArrayRef
    /// for numeric kernel consumption.
    fn numeric_array(
        &mut self,
        field_id: FieldId,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        if let Some(existing) = self.numeric_cache.get(&field_id) {
            return Ok(existing.clone());
        }

        let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
            Error::Internal("field mapping missing during cross product evaluation".into())
        })?;

        let array_ref = batch.column(column_index).clone();
        self.numeric_cache.insert(field_id, array_ref.clone());
        Ok(array_ref)
    }

    /// Fetch (and cache) a typed accessor for `field_id`'s batch column.
    fn column_accessor(
        &mut self,
        field_id: FieldId,
        batch: &RecordBatch,
    ) -> ExecutorResult<ColumnAccessor> {
        if let Some(existing) = self.column_cache.get(&field_id) {
            return Ok(existing.clone());
        }

        let column_index = *self.field_id_to_index.get(&field_id).ok_or_else(|| {
            Error::Internal("field mapping missing during cross product evaluation".into())
        })?;

        let accessor = ColumnAccessor::from_array(batch.column(column_index))?;
        self.column_cache.insert(field_id, accessor.clone());
        Ok(accessor)
    }

    /// Materialize any scalar expression into a full-length array over
    /// `batch`. Columns and literals are served directly; compound forms are
    /// routed through the numeric evaluator; scalar subqueries must have been
    /// registered beforehand via `register_scalar_subquery_column`.
    fn materialize_scalar_array(
        &mut self,
        expr: &ScalarExpr<FieldId>,
        batch: &RecordBatch,
    ) -> ExecutorResult<ArrayRef> {
        match expr {
            ScalarExpr::Column(field_id) => {
                let accessor = self.column_accessor(*field_id, batch)?;
                Ok(accessor.as_array_ref())
            }
            ScalarExpr::Literal(literal) => literal_to_constant_array(literal, batch.num_rows()),
            ScalarExpr::Binary { .. } => self.evaluate_numeric(expr, batch),
            ScalarExpr::Compare { .. } => self.evaluate_numeric(expr, batch),
            ScalarExpr::Not(_) => self.evaluate_numeric(expr, batch),
            ScalarExpr::IsNull { .. } => self.evaluate_numeric(expr, batch),
            ScalarExpr::Aggregate(_) => Err(Error::InvalidArgumentError(
                "aggregate expressions are not supported in cross product filters".into(),
            )),
            ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
                "struct field access is not supported in cross product filters".into(),
            )),
            ScalarExpr::Cast { expr, data_type } => {
                let source = self.materialize_scalar_array(expr.as_ref(), batch)?;
                let casted = cast(source.as_ref(), data_type).map_err(|err| {
                    Error::InvalidArgumentError(format!("failed to cast expression: {err}"))
                })?;
                Ok(casted)
            }
            ScalarExpr::Case { .. } => self.evaluate_numeric(expr, batch),
            ScalarExpr::Coalesce(_) => self.evaluate_numeric(expr, batch),
            ScalarExpr::Random => self.evaluate_numeric(expr, batch),
            ScalarExpr::ScalarSubquery(subquery) => {
                let accessor = self
                    .scalar_subquery_columns
                    .get(&subquery.id)
                    .ok_or_else(|| {
                        Error::InvalidArgumentError(
                            "scalar subqueries are not supported in cross product filters".into(),
                        )
                    })?
                    .clone();
                Ok(accessor.as_array_ref())
            }
        }
    }

    /// Materialize a scalar expression and classify it into a `ValueArray`.
    fn materialize_value_array(
        &mut self,
        expr: &ScalarExpr<FieldId>,
        batch: &RecordBatch,
    ) -> ExecutorResult<ValueArray> {
        let array = self.materialize_scalar_array(expr, batch)?;
        ValueArray::from_array(array)
    }
}
8895
8896fn collect_field_ids(expr: &ScalarExpr<FieldId>, out: &mut FxHashSet<FieldId>) {
8898 match expr {
8899 ScalarExpr::Column(fid) => {
8900 out.insert(*fid);
8901 }
8902 ScalarExpr::Binary { left, right, .. } => {
8903 collect_field_ids(left, out);
8904 collect_field_ids(right, out);
8905 }
8906 ScalarExpr::Compare { left, right, .. } => {
8907 collect_field_ids(left, out);
8908 collect_field_ids(right, out);
8909 }
8910 ScalarExpr::Aggregate(call) => match call {
8911 AggregateCall::CountStar => {}
8912 AggregateCall::Count { expr, .. }
8913 | AggregateCall::Sum { expr, .. }
8914 | AggregateCall::Total { expr, .. }
8915 | AggregateCall::Avg { expr, .. }
8916 | AggregateCall::Min(expr)
8917 | AggregateCall::Max(expr)
8918 | AggregateCall::CountNulls(expr)
8919 | AggregateCall::GroupConcat { expr, .. } => {
8920 collect_field_ids(expr, out);
8921 }
8922 },
8923 ScalarExpr::GetField { base, .. } => collect_field_ids(base, out),
8924 ScalarExpr::Cast { expr, .. } => collect_field_ids(expr, out),
8925 ScalarExpr::Not(expr) => collect_field_ids(expr, out),
8926 ScalarExpr::IsNull { expr, .. } => collect_field_ids(expr, out),
8927 ScalarExpr::Case {
8928 operand,
8929 branches,
8930 else_expr,
8931 } => {
8932 if let Some(inner) = operand.as_deref() {
8933 collect_field_ids(inner, out);
8934 }
8935 for (when_expr, then_expr) in branches {
8936 collect_field_ids(when_expr, out);
8937 collect_field_ids(then_expr, out);
8938 }
8939 if let Some(inner) = else_expr.as_deref() {
8940 collect_field_ids(inner, out);
8941 }
8942 }
8943 ScalarExpr::Coalesce(items) => {
8944 for item in items {
8945 collect_field_ids(item, out);
8946 }
8947 }
8948 ScalarExpr::Literal(_) | ScalarExpr::Random => {}
8949 ScalarExpr::ScalarSubquery(_) => {}
8950 }
8951}
8952
8953fn strip_exists(expr: &LlkvExpr<'static, FieldId>) -> LlkvExpr<'static, FieldId> {
8954 match expr {
8955 LlkvExpr::And(children) => LlkvExpr::And(children.iter().map(strip_exists).collect()),
8956 LlkvExpr::Or(children) => LlkvExpr::Or(children.iter().map(strip_exists).collect()),
8957 LlkvExpr::Not(inner) => LlkvExpr::Not(Box::new(strip_exists(inner))),
8958 LlkvExpr::Pred(filter) => LlkvExpr::Pred(filter.clone()),
8959 LlkvExpr::Compare { left, op, right } => LlkvExpr::Compare {
8960 left: left.clone(),
8961 op: *op,
8962 right: right.clone(),
8963 },
8964 LlkvExpr::InList {
8965 expr,
8966 list,
8967 negated,
8968 } => LlkvExpr::InList {
8969 expr: expr.clone(),
8970 list: list.clone(),
8971 negated: *negated,
8972 },
8973 LlkvExpr::IsNull { expr, negated } => LlkvExpr::IsNull {
8974 expr: expr.clone(),
8975 negated: *negated,
8976 },
8977 LlkvExpr::Literal(value) => LlkvExpr::Literal(*value),
8978 LlkvExpr::Exists(_) => LlkvExpr::Literal(true),
8979 }
8980}
8981
8982fn rewrite_predicate_scalar_subqueries(
8983 expr: LlkvExpr<'static, FieldId>,
8984 literals: &FxHashMap<SubqueryId, Literal>,
8985) -> ExecutorResult<LlkvExpr<'static, FieldId>> {
8986 match expr {
8987 LlkvExpr::And(children) => {
8988 let rewritten: ExecutorResult<Vec<_>> = children
8989 .into_iter()
8990 .map(|child| rewrite_predicate_scalar_subqueries(child, literals))
8991 .collect();
8992 Ok(LlkvExpr::And(rewritten?))
8993 }
8994 LlkvExpr::Or(children) => {
8995 let rewritten: ExecutorResult<Vec<_>> = children
8996 .into_iter()
8997 .map(|child| rewrite_predicate_scalar_subqueries(child, literals))
8998 .collect();
8999 Ok(LlkvExpr::Or(rewritten?))
9000 }
9001 LlkvExpr::Not(inner) => Ok(LlkvExpr::Not(Box::new(
9002 rewrite_predicate_scalar_subqueries(*inner, literals)?,
9003 ))),
9004 LlkvExpr::Pred(filter) => Ok(LlkvExpr::Pred(filter)),
9005 LlkvExpr::Compare { left, op, right } => Ok(LlkvExpr::Compare {
9006 left: rewrite_scalar_expr_subqueries(left, literals)?,
9007 op,
9008 right: rewrite_scalar_expr_subqueries(right, literals)?,
9009 }),
9010 LlkvExpr::InList {
9011 expr,
9012 list,
9013 negated,
9014 } => Ok(LlkvExpr::InList {
9015 expr: rewrite_scalar_expr_subqueries(expr, literals)?,
9016 list: list
9017 .into_iter()
9018 .map(|item| rewrite_scalar_expr_subqueries(item, literals))
9019 .collect::<ExecutorResult<_>>()?,
9020 negated,
9021 }),
9022 LlkvExpr::IsNull { expr, negated } => Ok(LlkvExpr::IsNull {
9023 expr: rewrite_scalar_expr_subqueries(expr, literals)?,
9024 negated,
9025 }),
9026 LlkvExpr::Literal(value) => Ok(LlkvExpr::Literal(value)),
9027 LlkvExpr::Exists(subquery) => Ok(LlkvExpr::Exists(subquery)),
9028 }
9029}
9030
9031fn rewrite_scalar_expr_subqueries(
9032 expr: ScalarExpr<FieldId>,
9033 literals: &FxHashMap<SubqueryId, Literal>,
9034) -> ExecutorResult<ScalarExpr<FieldId>> {
9035 match expr {
9036 ScalarExpr::ScalarSubquery(subquery) => {
9037 let literal = literals.get(&subquery.id).ok_or_else(|| {
9038 Error::Internal(format!(
9039 "missing literal for scalar subquery {:?}",
9040 subquery.id
9041 ))
9042 })?;
9043 Ok(ScalarExpr::Literal(literal.clone()))
9044 }
9045 ScalarExpr::Column(fid) => Ok(ScalarExpr::Column(fid)),
9046 ScalarExpr::Literal(lit) => Ok(ScalarExpr::Literal(lit)),
9047 ScalarExpr::Binary { left, op, right } => Ok(ScalarExpr::Binary {
9048 left: Box::new(rewrite_scalar_expr_subqueries(*left, literals)?),
9049 op,
9050 right: Box::new(rewrite_scalar_expr_subqueries(*right, literals)?),
9051 }),
9052 ScalarExpr::Compare { left, op, right } => Ok(ScalarExpr::Compare {
9053 left: Box::new(rewrite_scalar_expr_subqueries(*left, literals)?),
9054 op,
9055 right: Box::new(rewrite_scalar_expr_subqueries(*right, literals)?),
9056 }),
9057 ScalarExpr::Not(inner) => Ok(ScalarExpr::Not(Box::new(rewrite_scalar_expr_subqueries(
9058 *inner, literals,
9059 )?))),
9060 ScalarExpr::IsNull { expr, negated } => Ok(ScalarExpr::IsNull {
9061 expr: Box::new(rewrite_scalar_expr_subqueries(*expr, literals)?),
9062 negated,
9063 }),
9064 ScalarExpr::Aggregate(agg) => Ok(ScalarExpr::Aggregate(agg)),
9065 ScalarExpr::GetField { base, field_name } => Ok(ScalarExpr::GetField {
9066 base: Box::new(rewrite_scalar_expr_subqueries(*base, literals)?),
9067 field_name,
9068 }),
9069 ScalarExpr::Cast { expr, data_type } => Ok(ScalarExpr::Cast {
9070 expr: Box::new(rewrite_scalar_expr_subqueries(*expr, literals)?),
9071 data_type,
9072 }),
9073 ScalarExpr::Case {
9074 operand,
9075 branches,
9076 else_expr,
9077 } => Ok(ScalarExpr::Case {
9078 operand: operand
9079 .map(|e| rewrite_scalar_expr_subqueries(*e, literals))
9080 .transpose()?
9081 .map(Box::new),
9082 branches: branches
9083 .into_iter()
9084 .map(|(when, then)| {
9085 Ok((
9086 rewrite_scalar_expr_subqueries(when, literals)?,
9087 rewrite_scalar_expr_subqueries(then, literals)?,
9088 ))
9089 })
9090 .collect::<ExecutorResult<_>>()?,
9091 else_expr: else_expr
9092 .map(|e| rewrite_scalar_expr_subqueries(*e, literals))
9093 .transpose()?
9094 .map(Box::new),
9095 }),
9096 ScalarExpr::Coalesce(items) => Ok(ScalarExpr::Coalesce(
9097 items
9098 .into_iter()
9099 .map(|item| rewrite_scalar_expr_subqueries(item, literals))
9100 .collect::<ExecutorResult<_>>()?,
9101 )),
9102 ScalarExpr::Random => Ok(ScalarExpr::Random),
9103 }
9104}
9105
/// Substitutes correlated parameter `bindings` into the expressions of
/// `plan`, returning a new plan ready for standalone execution.
///
/// When `bindings` is empty the plan is returned as a plain clone. For
/// compound plans, only the compound body is bound and the surrounding plan
/// is rebuilt as an empty shell around it (plus ORDER BY/LIMIT/OFFSET).
fn bind_select_plan(
    plan: &SelectPlan,
    bindings: &FxHashMap<String, Literal>,
) -> ExecutorResult<SelectPlan> {
    // Nothing to substitute: reuse the plan as-is.
    if bindings.is_empty() {
        return Ok(plan.clone());
    }

    let projections = plan
        .projections
        .iter()
        .map(|projection| bind_projection(projection, bindings))
        .collect::<ExecutorResult<Vec<_>>>()?;

    let filter = match &plan.filter {
        Some(wrapper) => Some(bind_select_filter(wrapper, bindings)?),
        None => None,
    };

    let aggregates = plan
        .aggregates
        .iter()
        .map(|aggregate| bind_aggregate_expr(aggregate, bindings))
        .collect::<ExecutorResult<Vec<_>>>()?;

    let scalar_subqueries = plan
        .scalar_subqueries
        .iter()
        .map(|subquery| bind_scalar_subquery(subquery, bindings))
        .collect::<ExecutorResult<Vec<_>>>()?;

    if let Some(compound) = &plan.compound {
        // Compound SELECTs carry their real content inside the compound
        // body, so everything else is reset to its empty default.
        let bound_compound = bind_compound_select(compound, bindings)?;
        return Ok(SelectPlan {
            tables: Vec::new(),
            joins: Vec::new(),
            projections: Vec::new(),
            filter: None,
            having: None,
            aggregates: Vec::new(),
            order_by: plan.order_by.clone(),
            distinct: false,
            scalar_subqueries: Vec::new(),
            compound: Some(bound_compound),
            group_by: Vec::new(),
            value_table_mode: None,
            limit: plan.limit,
            offset: plan.offset,
        });
    }

    Ok(SelectPlan {
        tables: plan.tables.clone(),
        joins: plan.joins.clone(),
        projections,
        filter,
        // NOTE(review): HAVING is cloned without binding, unlike filter and
        // projections — verify parameters can never appear inside HAVING.
        having: plan.having.clone(),
        aggregates,
        // NOTE(review): ORDER BY is dropped here while the compound branch
        // above preserves it — confirm ordering is re-applied by the caller.
        order_by: Vec::new(),
        distinct: plan.distinct,
        scalar_subqueries,
        compound: None,
        group_by: plan.group_by.clone(),
        value_table_mode: plan.value_table_mode.clone(),
        limit: plan.limit,
        offset: plan.offset,
    })
}
9174
9175fn bind_compound_select(
9176 compound: &CompoundSelectPlan,
9177 bindings: &FxHashMap<String, Literal>,
9178) -> ExecutorResult<CompoundSelectPlan> {
9179 let initial = bind_select_plan(&compound.initial, bindings)?;
9180 let mut operations = Vec::with_capacity(compound.operations.len());
9181 for component in &compound.operations {
9182 let bound_plan = bind_select_plan(&component.plan, bindings)?;
9183 operations.push(CompoundSelectComponent {
9184 operator: component.operator.clone(),
9185 quantifier: component.quantifier.clone(),
9186 plan: bound_plan,
9187 });
9188 }
9189 Ok(CompoundSelectPlan {
9190 initial: Box::new(initial),
9191 operations,
9192 })
9193}
9194
9195fn ensure_schema_compatibility(base: &Schema, other: &Schema) -> ExecutorResult<()> {
9196 if base.fields().len() != other.fields().len() {
9197 return Err(Error::InvalidArgumentError(
9198 "compound SELECT requires matching column counts".into(),
9199 ));
9200 }
9201 for (left, right) in base.fields().iter().zip(other.fields().iter()) {
9202 if left.data_type() != right.data_type() {
9203 return Err(Error::InvalidArgumentError(format!(
9204 "compound SELECT column type mismatch: {} vs {}",
9205 left.data_type(),
9206 right.data_type()
9207 )));
9208 }
9209 }
9210 Ok(())
9211}
9212
/// Deduplicates `rows` in place (first occurrence wins) and stores the set
/// of encoded row keys in `cache`.
///
/// If `cache` is already populated this is a no-op: the rows are taken to
/// have been deduplicated by an earlier call.
/// NOTE(review): rows appended after the cache was built are never checked
/// against it on later calls — confirm callers maintain that invariant.
fn ensure_distinct_rows(rows: &mut Vec<Vec<PlanValue>>, cache: &mut Option<FxHashSet<Vec<u8>>>) {
    if cache.is_some() {
        return;
    }
    let mut set = FxHashSet::default();
    let mut deduped: Vec<Vec<PlanValue>> = Vec::with_capacity(rows.len());
    for row in rows.drain(..) {
        // `insert` returns false for keys already present, filtering dupes.
        let key = encode_row(&row);
        if set.insert(key) {
            deduped.push(row);
        }
    }
    *rows = deduped;
    *cache = Some(set);
}
9228
9229fn encode_row(row: &[PlanValue]) -> Vec<u8> {
9230 let mut buf = Vec::new();
9231 for value in row {
9232 encode_plan_value(&mut buf, value);
9233 buf.push(0x1F);
9234 }
9235 buf
9236}
9237
/// Appends a self-describing byte encoding of `value` to `buf`.
///
/// Every value starts with a one-byte type tag (so values of different
/// types never encode identically), followed by a fixed-width or
/// length-prefixed payload. The encoding is used for in-process
/// hashing/deduplication; it only needs to be deterministic and injective,
/// not stable across versions.
fn encode_plan_value(buf: &mut Vec<u8>, value: &PlanValue) {
    match value {
        // Tag 0: NULL, no payload.
        PlanValue::Null => buf.push(0),
        // Tag 1: integer, big-endian bytes.
        PlanValue::Integer(v) => {
            buf.push(1);
            buf.extend_from_slice(&v.to_be_bytes());
        }
        // Tag 2: float, encoded via its IEEE-754 bit pattern so NaN and
        // -0.0 encode deterministically.
        PlanValue::Float(v) => {
            buf.push(2);
            buf.extend_from_slice(&v.to_bits().to_be_bytes());
        }
        // Tag 7: decimal, raw big-endian value plus one scale byte.
        PlanValue::Decimal(decimal) => {
            buf.push(7);
            buf.extend_from_slice(&decimal.raw_value().to_be_bytes());
            // NOTE(review): only the most-significant byte of the scale's
            // big-endian form is kept; if `scale()` is wider than one byte
            // this loses information — confirm scale fits in a single byte.
            buf.push(decimal.scale().to_be_bytes()[0]);
        }
        // Tag 3: UTF-8 string, u32 big-endian length prefix + bytes.
        PlanValue::String(s) => {
            buf.push(3);
            let bytes = s.as_bytes();
            let len = u32::try_from(bytes.len()).unwrap_or(u32::MAX);
            buf.extend_from_slice(&len.to_be_bytes());
            buf.extend_from_slice(bytes);
        }
        // Tag 5: date as day count, big-endian.
        PlanValue::Date32(days) => {
            buf.push(5);
            buf.extend_from_slice(&days.to_be_bytes());
        }
        // Tag 4: struct; entries are sorted by key so field order never
        // affects the encoding, each entry is length-prefixed key + value.
        PlanValue::Struct(map) => {
            buf.push(4);
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by(|a, b| a.0.cmp(b.0));
            let len = u32::try_from(entries.len()).unwrap_or(u32::MAX);
            buf.extend_from_slice(&len.to_be_bytes());
            for (key, val) in entries {
                let key_bytes = key.as_bytes();
                let key_len = u32::try_from(key_bytes.len()).unwrap_or(u32::MAX);
                buf.extend_from_slice(&key_len.to_be_bytes());
                buf.extend_from_slice(key_bytes);
                encode_plan_value(buf, val);
            }
        }
        // Tag 6: interval as months/days/nanos, each big-endian.
        PlanValue::Interval(interval) => {
            buf.push(6);
            buf.extend_from_slice(&interval.months.to_be_bytes());
            buf.extend_from_slice(&interval.days.to_be_bytes());
            buf.extend_from_slice(&interval.nanos.to_be_bytes());
        }
    }
}
9287
9288fn rows_to_record_batch(
9289 schema: Arc<Schema>,
9290 rows: &[Vec<PlanValue>],
9291) -> ExecutorResult<RecordBatch> {
9292 let column_count = schema.fields().len();
9293 let mut columns: Vec<Vec<PlanValue>> = vec![Vec::with_capacity(rows.len()); column_count];
9294 for row in rows {
9295 if row.len() != column_count {
9296 return Err(Error::InvalidArgumentError(
9297 "compound SELECT produced mismatched column counts".into(),
9298 ));
9299 }
9300 for (idx, value) in row.iter().enumerate() {
9301 columns[idx].push(value.clone());
9302 }
9303 }
9304
9305 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_count);
9306 for (idx, field) in schema.fields().iter().enumerate() {
9307 let array = build_array_for_column(field.data_type(), &columns[idx])?;
9308 arrays.push(array);
9309 }
9310
9311 RecordBatch::try_new(schema, arrays).map_err(|err| {
9312 Error::InvalidArgumentError(format!("failed to materialize compound SELECT: {err}"))
9313 })
9314}
9315
9316fn build_column_lookup_map(schema: &Schema) -> FxHashMap<String, usize> {
9317 let mut lookup = FxHashMap::default();
9318 for (idx, field) in schema.fields().iter().enumerate() {
9319 lookup.insert(field.name().to_ascii_lowercase(), idx);
9320 }
9321 lookup
9322}
9323
9324fn build_group_key(
9325 batch: &RecordBatch,
9326 row_idx: usize,
9327 key_indices: &[usize],
9328) -> ExecutorResult<Vec<GroupKeyValue>> {
9329 let mut values = Vec::with_capacity(key_indices.len());
9330 for &index in key_indices {
9331 values.push(group_key_value(batch.column(index), row_idx)?);
9332 }
9333 Ok(values)
9334}
9335
9336fn group_key_value(array: &ArrayRef, row_idx: usize) -> ExecutorResult<GroupKeyValue> {
9337 if !array.is_valid(row_idx) {
9338 return Ok(GroupKeyValue::Null);
9339 }
9340
9341 match array.data_type() {
9342 DataType::Int8 => {
9343 let values = array
9344 .as_any()
9345 .downcast_ref::<Int8Array>()
9346 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
9347 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9348 }
9349 DataType::Int16 => {
9350 let values = array
9351 .as_any()
9352 .downcast_ref::<Int16Array>()
9353 .ok_or_else(|| Error::Internal("failed to downcast to Int16Array".into()))?;
9354 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9355 }
9356 DataType::Int32 => {
9357 let values = array
9358 .as_any()
9359 .downcast_ref::<Int32Array>()
9360 .ok_or_else(|| Error::Internal("failed to downcast to Int32Array".into()))?;
9361 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9362 }
9363 DataType::Int64 => {
9364 let values = array
9365 .as_any()
9366 .downcast_ref::<Int64Array>()
9367 .ok_or_else(|| Error::Internal("failed to downcast to Int64Array".into()))?;
9368 Ok(GroupKeyValue::Int(values.value(row_idx)))
9369 }
9370 DataType::UInt8 => {
9371 let values = array
9372 .as_any()
9373 .downcast_ref::<UInt8Array>()
9374 .ok_or_else(|| Error::Internal("failed to downcast to UInt8Array".into()))?;
9375 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9376 }
9377 DataType::UInt16 => {
9378 let values = array
9379 .as_any()
9380 .downcast_ref::<UInt16Array>()
9381 .ok_or_else(|| Error::Internal("failed to downcast to UInt16Array".into()))?;
9382 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9383 }
9384 DataType::UInt32 => {
9385 let values = array
9386 .as_any()
9387 .downcast_ref::<UInt32Array>()
9388 .ok_or_else(|| Error::Internal("failed to downcast to UInt32Array".into()))?;
9389 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9390 }
9391 DataType::UInt64 => {
9392 let values = array
9393 .as_any()
9394 .downcast_ref::<UInt64Array>()
9395 .ok_or_else(|| Error::Internal("failed to downcast to UInt64Array".into()))?;
9396 let value = values.value(row_idx);
9397 if value > i64::MAX as u64 {
9398 return Err(Error::InvalidArgumentError(
9399 "GROUP BY value exceeds supported integer range".into(),
9400 ));
9401 }
9402 Ok(GroupKeyValue::Int(value as i64))
9403 }
9404 DataType::Date32 => {
9405 let values = array
9406 .as_any()
9407 .downcast_ref::<Date32Array>()
9408 .ok_or_else(|| Error::Internal("failed to downcast to Date32Array".into()))?;
9409 Ok(GroupKeyValue::Int(values.value(row_idx) as i64))
9410 }
9411 DataType::Boolean => {
9412 let values = array
9413 .as_any()
9414 .downcast_ref::<BooleanArray>()
9415 .ok_or_else(|| Error::Internal("failed to downcast to BooleanArray".into()))?;
9416 Ok(GroupKeyValue::Bool(values.value(row_idx)))
9417 }
9418 DataType::Utf8 => {
9419 let values = array
9420 .as_any()
9421 .downcast_ref::<StringArray>()
9422 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
9423 Ok(GroupKeyValue::String(values.value(row_idx).to_string()))
9424 }
9425 other => Err(Error::InvalidArgumentError(format!(
9426 "GROUP BY does not support column type {:?}",
9427 other
9428 ))),
9429 }
9430}
9431
/// Attempts to constant-fold a column-free predicate.
///
/// Outer `None` means the expression is not constant (it references
/// non-constant inputs) and cannot be folded. Otherwise the inner value
/// follows SQL three-valued logic: `Some(true)`/`Some(false)` are definite
/// results and `None` represents SQL NULL (unknown).
fn evaluate_constant_predicate(expr: &LlkvExpr<'static, String>) -> Option<Option<bool>> {
    match expr {
        LlkvExpr::Literal(value) => Some(Some(*value)),
        LlkvExpr::Not(inner) => {
            let inner_val = evaluate_constant_predicate(inner)?;
            Some(truth_not(inner_val))
        }
        LlkvExpr::And(children) => {
            // Fold with three-valued AND; an empty conjunction is TRUE.
            let mut acc = Some(true);
            for child in children {
                let child_val = evaluate_constant_predicate(child)?;
                acc = truth_and(acc, child_val);
            }
            Some(acc)
        }
        LlkvExpr::Or(children) => {
            // Fold with three-valued OR; an empty disjunction is FALSE.
            let mut acc = Some(false);
            for child in children {
                let child_val = evaluate_constant_predicate(child)?;
                acc = truth_or(acc, child_val);
            }
            Some(acc)
        }
        LlkvExpr::Compare { left, op, right } => {
            let left_literal = evaluate_constant_scalar(left)?;
            let right_literal = evaluate_constant_scalar(right)?;
            // `compare_literals` returning None maps to SQL NULL here.
            Some(compare_literals(*op, &left_literal, &right_literal))
        }
        LlkvExpr::IsNull { expr, negated } => {
            let literal = evaluate_constant_scalar(expr)?;
            let is_null = matches!(literal, Literal::Null);
            Some(Some(if *negated { !is_null } else { is_null }))
        }
        LlkvExpr::InList {
            expr,
            list,
            negated,
        } => {
            let needle = evaluate_constant_scalar(expr)?;
            let mut saw_unknown = false;

            // SQL IN semantics: TRUE on any match; NULL when no match but
            // some comparison was NULL; otherwise FALSE. NOT IN inverts.
            for candidate in list {
                let value = evaluate_constant_scalar(candidate)?;
                match compare_literals(CompareOp::Eq, &needle, &value) {
                    Some(true) => {
                        return Some(Some(!*negated));
                    }
                    Some(false) => {}
                    None => saw_unknown = true,
                }
            }

            if saw_unknown {
                Some(None)
            } else {
                Some(Some(*negated))
            }
        }
        // Everything else (e.g. EXISTS, column predicates) is not constant.
        _ => None,
    }
}
9493
/// Outcome of constant-folding a join predicate under SQL three-valued logic.
enum ConstantJoinEvaluation {
    /// The predicate folds to a definite TRUE or FALSE.
    Known(bool),
    /// The predicate folds to SQL NULL (unknown truth value).
    Unknown,
    /// The predicate references non-constant inputs and cannot be folded.
    NotConstant,
}
9499
/// Constant-folds a join predicate under SQL three-valued logic, reporting
/// `NotConstant` as soon as any sub-expression references non-constant
/// inputs.
///
/// AND short-circuits on a definite FALSE and OR on a definite TRUE, so
/// unknown (NULL) children only surface when nothing decides the result.
fn evaluate_constant_join_expr(expr: &LlkvExpr<'static, String>) -> ConstantJoinEvaluation {
    match expr {
        LlkvExpr::Literal(value) => ConstantJoinEvaluation::Known(*value),
        LlkvExpr::And(children) => {
            let mut saw_unknown = false;
            for child in children {
                match evaluate_constant_join_expr(child) {
                    // A definite FALSE decides the conjunction outright.
                    ConstantJoinEvaluation::Known(false) => {
                        return ConstantJoinEvaluation::Known(false);
                    }
                    ConstantJoinEvaluation::Known(true) => {}
                    ConstantJoinEvaluation::Unknown => saw_unknown = true,
                    ConstantJoinEvaluation::NotConstant => {
                        return ConstantJoinEvaluation::NotConstant;
                    }
                }
            }
            if saw_unknown {
                ConstantJoinEvaluation::Unknown
            } else {
                ConstantJoinEvaluation::Known(true)
            }
        }
        LlkvExpr::Or(children) => {
            let mut saw_unknown = false;
            for child in children {
                match evaluate_constant_join_expr(child) {
                    // A definite TRUE decides the disjunction outright.
                    ConstantJoinEvaluation::Known(true) => {
                        return ConstantJoinEvaluation::Known(true);
                    }
                    ConstantJoinEvaluation::Known(false) => {}
                    ConstantJoinEvaluation::Unknown => saw_unknown = true,
                    ConstantJoinEvaluation::NotConstant => {
                        return ConstantJoinEvaluation::NotConstant;
                    }
                }
            }
            if saw_unknown {
                ConstantJoinEvaluation::Unknown
            } else {
                ConstantJoinEvaluation::Known(false)
            }
        }
        // NOT keeps unknown as unknown (NOT NULL is NULL in SQL).
        LlkvExpr::Not(inner) => match evaluate_constant_join_expr(inner) {
            ConstantJoinEvaluation::Known(value) => ConstantJoinEvaluation::Known(!value),
            ConstantJoinEvaluation::Unknown => ConstantJoinEvaluation::Unknown,
            ConstantJoinEvaluation::NotConstant => ConstantJoinEvaluation::NotConstant,
        },
        LlkvExpr::Compare { left, op, right } => {
            let left_lit = evaluate_constant_scalar(left);
            let right_lit = evaluate_constant_scalar(right);

            // A constant NULL operand makes the comparison NULL.
            if matches!(left_lit, Some(Literal::Null)) || matches!(right_lit, Some(Literal::Null)) {
                return ConstantJoinEvaluation::Unknown;
            }

            let (Some(left_lit), Some(right_lit)) = (left_lit, right_lit) else {
                return ConstantJoinEvaluation::NotConstant;
            };

            match compare_literals(*op, &left_lit, &right_lit) {
                Some(result) => ConstantJoinEvaluation::Known(result),
                None => ConstantJoinEvaluation::Unknown,
            }
        }
        // IS [NOT] NULL always yields a definite boolean for constants.
        LlkvExpr::IsNull { expr, negated } => match evaluate_constant_scalar(expr) {
            Some(literal) => {
                let is_null = matches!(literal, Literal::Null);
                let value = if *negated { !is_null } else { is_null };
                ConstantJoinEvaluation::Known(value)
            }
            None => ConstantJoinEvaluation::NotConstant,
        },
        LlkvExpr::InList {
            expr,
            list,
            negated,
        } => {
            let needle = match evaluate_constant_scalar(expr) {
                Some(literal) => literal,
                None => return ConstantJoinEvaluation::NotConstant,
            };

            // NOTE(review): a NULL needle returns Unknown even for an empty
            // IN list, where SQL yields FALSE — conservative but verify.
            if matches!(needle, Literal::Null) {
                return ConstantJoinEvaluation::Unknown;
            }

            let mut saw_unknown = false;
            for candidate in list {
                let value = match evaluate_constant_scalar(candidate) {
                    Some(literal) => literal,
                    None => return ConstantJoinEvaluation::NotConstant,
                };

                match compare_literals(CompareOp::Eq, &needle, &value) {
                    Some(true) => {
                        let result = !*negated;
                        return ConstantJoinEvaluation::Known(result);
                    }
                    Some(false) => {}
                    None => saw_unknown = true,
                }
            }

            if saw_unknown {
                ConstantJoinEvaluation::Unknown
            } else {
                let result = *negated;
                ConstantJoinEvaluation::Known(result)
            }
        }
        // Anything else cannot be folded at plan time.
        _ => ConstantJoinEvaluation::NotConstant,
    }
}
9615
/// How NULL operands are treated when comparing values.
///
/// NOTE(review): only the SQL-standard variant exists and nothing in this
/// chunk references it — confirm it is used elsewhere in the file.
enum NullComparisonBehavior {
    /// Comparisons involving NULL yield unknown (NULL), per SQL semantics.
    ThreeValuedLogic,
}
9619
/// Constant-folds `expr`, treating aggregate calls as non-constant.
/// Returns `None` when the expression cannot be folded.
fn evaluate_constant_scalar(expr: &ScalarExpr<String>) -> Option<Literal> {
    evaluate_constant_scalar_internal(expr, false)
}
9623
/// Constant-folds `expr`, additionally folding aggregate calls via
/// `evaluate_constant_aggregate` (single constant-row semantics).
fn evaluate_constant_scalar_with_aggregates(expr: &ScalarExpr<String>) -> Option<Literal> {
    evaluate_constant_scalar_internal(expr, true)
}
9627
/// Core constant folder for scalar expressions.
///
/// Returns `Some(literal)` when `expr` folds to a constant (possibly
/// `Literal::Null` for SQL NULL results) and `None` when it references
/// non-constant inputs (columns, struct fields, RANDOM(), unresolved
/// subqueries, or aggregates unless `allow_aggregates` is set).
fn evaluate_constant_scalar_internal(
    expr: &ScalarExpr<String>,
    allow_aggregates: bool,
) -> Option<Literal> {
    match expr {
        ScalarExpr::Literal(lit) => Some(lit.clone()),
        ScalarExpr::Binary { left, op, right } => {
            let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
            let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
            evaluate_binary_literal(*op, &left_value, &right_value)
        }
        ScalarExpr::Cast { expr, data_type } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            cast_literal_to_type(&value, data_type)
        }
        // Logical NOT folds truthiness to integer 0/1; NULL stays NULL.
        ScalarExpr::Not(inner) => {
            let value = evaluate_constant_scalar_internal(inner, allow_aggregates)?;
            match literal_truthiness(&value) {
                Some(true) => Some(Literal::Int128(0)),
                Some(false) => Some(Literal::Int128(1)),
                None => Some(Literal::Null),
            }
        }
        // IS [NOT] NULL always folds to a definite boolean.
        ScalarExpr::IsNull { expr, negated } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            let is_null = matches!(value, Literal::Null);
            Some(Literal::Boolean(if *negated { !is_null } else { is_null }))
        }
        // COALESCE: first non-NULL constant wins; any non-constant arm makes
        // the whole expression non-constant; all-NULL arms fold to NULL.
        // An empty argument list yields None (treated as non-constant).
        ScalarExpr::Coalesce(items) => {
            let mut saw_null = false;
            for item in items {
                match evaluate_constant_scalar_internal(item, allow_aggregates) {
                    Some(Literal::Null) => saw_null = true,
                    Some(value) => return Some(value),
                    None => return None,
                }
            }
            if saw_null { Some(Literal::Null) } else { None }
        }
        // Comparison folds to Boolean; incomparable operands become NULL.
        ScalarExpr::Compare { left, op, right } => {
            let left_value = evaluate_constant_scalar_internal(left, allow_aggregates)?;
            let right_value = evaluate_constant_scalar_internal(right, allow_aggregates)?;
            match compare_literals(*op, &left_value, &right_value) {
                Some(flag) => Some(Literal::Boolean(flag)),
                None => Some(Literal::Null),
            }
        }
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => {
            if let Some(operand_expr) = operand {
                // Simple CASE: match the operand against each WHEN value.
                let operand_value =
                    evaluate_constant_scalar_internal(operand_expr, allow_aggregates)?;
                for (when_expr, then_expr) in branches {
                    let when_value =
                        evaluate_constant_scalar_internal(when_expr, allow_aggregates)?;
                    if let Some(true) = compare_literals(CompareOp::Eq, &operand_value, &when_value)
                    {
                        return evaluate_constant_scalar_internal(then_expr, allow_aggregates);
                    }
                }
            } else {
                // Searched CASE: take the first truthy condition; NULL or
                // false conditions fall through to the next branch.
                for (condition_expr, result_expr) in branches {
                    let condition_value =
                        evaluate_constant_scalar_internal(condition_expr, allow_aggregates)?;
                    match literal_truthiness(&condition_value) {
                        Some(true) => {
                            return evaluate_constant_scalar_internal(
                                result_expr,
                                allow_aggregates,
                            );
                        }
                        Some(false) => {}
                        None => {}
                    }
                }
            }

            // No branch matched: ELSE value, or NULL without an ELSE.
            if let Some(else_branch) = else_expr {
                evaluate_constant_scalar_internal(else_branch, allow_aggregates)
            } else {
                Some(Literal::Null)
            }
        }
        ScalarExpr::Column(_) => None,
        ScalarExpr::Aggregate(call) => {
            if allow_aggregates {
                evaluate_constant_aggregate(call, allow_aggregates)
            } else {
                None
            }
        }
        ScalarExpr::GetField { .. } => None,
        ScalarExpr::Random => None, ScalarExpr::ScalarSubquery(_) => None,
    }
}
9727
/// Folds an aggregate call as if it ran over a single constant row.
///
/// Used when the aggregate's argument is itself a constant expression:
/// COUNT(*) is 1, COUNT(x) is 0/1 depending on NULL-ness, SUM/MIN/MAX/AVG
/// reduce to the single value (with AVG widened to float), TOTAL maps NULL
/// to 0, and GROUP_CONCAT stringifies the value. Returns `None` when the
/// argument is non-constant or has an unsupported literal type.
fn evaluate_constant_aggregate(
    call: &AggregateCall<String>,
    allow_aggregates: bool,
) -> Option<Literal> {
    match call {
        AggregateCall::CountStar => Some(Literal::Int128(1)),
        // COUNT(x): NULL counts as 0, anything else as 1.
        AggregateCall::Count { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            if matches!(value, Literal::Null) {
                Some(Literal::Int128(0))
            } else {
                Some(Literal::Int128(1))
            }
        }
        // SUM over one row is the value itself; booleans widen to 0/1.
        AggregateCall::Sum { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                Literal::Int128(value) => Some(Literal::Int128(value)),
                Literal::Float64(value) => Some(Literal::Float64(value)),
                Literal::Boolean(flag) => Some(Literal::Int128(if flag { 1 } else { 0 })),
                _ => None,
            }
        }
        // TOTAL is SUM with NULL coalesced to 0.
        AggregateCall::Total { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Int128(0)),
                Literal::Int128(value) => Some(Literal::Int128(value)),
                Literal::Float64(value) => Some(Literal::Float64(value)),
                Literal::Boolean(flag) => Some(Literal::Int128(if flag { 1 } else { 0 })),
                _ => None,
            }
        }
        // AVG over one row is the value itself, widened to float.
        AggregateCall::Avg { expr, .. } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                other => {
                    let numeric = literal_to_f64(&other)?;
                    Some(Literal::Float64(numeric))
                }
            }
        }
        // MIN/MAX over one row pass the value through unchanged.
        AggregateCall::Min(expr) => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                other => Some(other),
            }
        }
        AggregateCall::Max(expr) => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                other => Some(other),
            }
        }
        AggregateCall::CountNulls(expr) => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            let count = if matches!(value, Literal::Null) { 1 } else { 0 };
            Some(Literal::Int128(count))
        }
        // Single-value GROUP_CONCAT never uses the separator.
        AggregateCall::GroupConcat {
            expr, separator: _, ..
        } => {
            let value = evaluate_constant_scalar_internal(expr, allow_aggregates)?;
            match value {
                Literal::Null => Some(Literal::Null),
                Literal::String(s) => Some(Literal::String(s)),
                Literal::Int128(i) => Some(Literal::String(i.to_string())),
                Literal::Float64(f) => Some(Literal::String(f.to_string())),
                Literal::Boolean(b) => Some(Literal::String(if b { "1" } else { "0" }.to_string())),
                _ => None,
            }
        }
    }
}
9806
9807fn evaluate_binary_literal(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
9808 match op {
9809 BinaryOp::And => evaluate_literal_logical_and(left, right),
9810 BinaryOp::Or => evaluate_literal_logical_or(left, right),
9811 BinaryOp::Add
9812 | BinaryOp::Subtract
9813 | BinaryOp::Multiply
9814 | BinaryOp::Divide
9815 | BinaryOp::Modulo => {
9816 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
9817 return Some(Literal::Null);
9818 }
9819
9820 match op {
9821 BinaryOp::Add => add_literals(left, right),
9822 BinaryOp::Subtract => subtract_literals(left, right),
9823 BinaryOp::Multiply => multiply_literals(left, right),
9824 BinaryOp::Divide => divide_literals(left, right),
9825 BinaryOp::Modulo => modulo_literals(left, right),
9826 BinaryOp::And
9827 | BinaryOp::Or
9828 | BinaryOp::BitwiseShiftLeft
9829 | BinaryOp::BitwiseShiftRight => unreachable!(),
9830 }
9831 }
9832 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
9833 if matches!(left, Literal::Null) || matches!(right, Literal::Null) {
9834 return Some(Literal::Null);
9835 }
9836
9837 let lhs = literal_to_i128(left)?;
9839 let rhs = literal_to_i128(right)?;
9840
9841 let result = match op {
9843 BinaryOp::BitwiseShiftLeft => (lhs as i64).wrapping_shl(rhs as u32) as i128,
9844 BinaryOp::BitwiseShiftRight => (lhs as i64).wrapping_shr(rhs as u32) as i128,
9845 _ => unreachable!(),
9846 };
9847
9848 Some(Literal::Int128(result))
9849 }
9850 }
9851}
9852
9853fn evaluate_literal_logical_and(left: &Literal, right: &Literal) -> Option<Literal> {
9854 let left_truth = literal_truthiness(left);
9855 if matches!(left_truth, Some(false)) {
9856 return Some(Literal::Int128(0));
9857 }
9858
9859 let right_truth = literal_truthiness(right);
9860 if matches!(right_truth, Some(false)) {
9861 return Some(Literal::Int128(0));
9862 }
9863
9864 match (left_truth, right_truth) {
9865 (Some(true), Some(true)) => Some(Literal::Int128(1)),
9866 (Some(true), None) | (None, Some(true)) | (None, None) => Some(Literal::Null),
9867 _ => Some(Literal::Null),
9868 }
9869}
9870
9871fn evaluate_literal_logical_or(left: &Literal, right: &Literal) -> Option<Literal> {
9872 let left_truth = literal_truthiness(left);
9873 if matches!(left_truth, Some(true)) {
9874 return Some(Literal::Int128(1));
9875 }
9876
9877 let right_truth = literal_truthiness(right);
9878 if matches!(right_truth, Some(true)) {
9879 return Some(Literal::Int128(1));
9880 }
9881
9882 match (left_truth, right_truth) {
9883 (Some(false), Some(false)) => Some(Literal::Int128(0)),
9884 (Some(false), None) | (None, Some(false)) | (None, None) => Some(Literal::Null),
9885 _ => Some(Literal::Null),
9886 }
9887}
9888
9889fn add_literals(left: &Literal, right: &Literal) -> Option<Literal> {
9890 match (left, right) {
9891 (Literal::Int128(lhs), Literal::Int128(rhs)) => {
9892 Some(Literal::Int128(lhs.saturating_add(*rhs)))
9893 }
9894 _ => {
9895 let lhs = literal_to_f64(left)?;
9896 let rhs = literal_to_f64(right)?;
9897 Some(Literal::Float64(lhs + rhs))
9898 }
9899 }
9900}
9901
9902fn subtract_literals(left: &Literal, right: &Literal) -> Option<Literal> {
9903 match (left, right) {
9904 (Literal::Int128(lhs), Literal::Int128(rhs)) => {
9905 Some(Literal::Int128(lhs.saturating_sub(*rhs)))
9906 }
9907 _ => {
9908 let lhs = literal_to_f64(left)?;
9909 let rhs = literal_to_f64(right)?;
9910 Some(Literal::Float64(lhs - rhs))
9911 }
9912 }
9913}
9914
9915fn multiply_literals(left: &Literal, right: &Literal) -> Option<Literal> {
9916 match (left, right) {
9917 (Literal::Int128(lhs), Literal::Int128(rhs)) => {
9918 Some(Literal::Int128(lhs.saturating_mul(*rhs)))
9919 }
9920 _ => {
9921 let lhs = literal_to_f64(left)?;
9922 let rhs = literal_to_f64(right)?;
9923 Some(Literal::Float64(lhs * rhs))
9924 }
9925 }
9926}
9927
9928fn divide_literals(left: &Literal, right: &Literal) -> Option<Literal> {
9929 fn literal_to_i128_from_integer_like(literal: &Literal) -> Option<i128> {
9930 match literal {
9931 Literal::Int128(value) => Some(*value),
9932 Literal::Decimal128(value) => llkv_compute::scalar::decimal::rescale(*value, 0)
9933 .ok()
9934 .map(|integral| integral.raw_value()),
9935 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
9936 Literal::Date32(value) => Some(*value as i128),
9937 _ => None,
9938 }
9939 }
9940
9941 if let (Some(lhs), Some(rhs)) = (
9942 literal_to_i128_from_integer_like(left),
9943 literal_to_i128_from_integer_like(right),
9944 ) {
9945 if rhs == 0 {
9946 return Some(Literal::Null);
9947 }
9948
9949 if lhs == i128::MIN && rhs == -1 {
9950 return Some(Literal::Float64((lhs as f64) / (rhs as f64)));
9951 }
9952
9953 return Some(Literal::Int128(lhs / rhs));
9954 }
9955
9956 let lhs = literal_to_f64(left)?;
9957 let rhs = literal_to_f64(right)?;
9958 if rhs == 0.0 {
9959 return Some(Literal::Null);
9960 }
9961 Some(Literal::Float64(lhs / rhs))
9962}
9963
9964fn modulo_literals(left: &Literal, right: &Literal) -> Option<Literal> {
9965 let lhs = literal_to_i128(left)?;
9966 let rhs = literal_to_i128(right)?;
9967 if rhs == 0 {
9968 return Some(Literal::Null);
9969 }
9970 Some(Literal::Int128(lhs % rhs))
9971}
9972
9973fn literal_to_f64(literal: &Literal) -> Option<f64> {
9974 match literal {
9975 Literal::Int128(value) => Some(*value as f64),
9976 Literal::Float64(value) => Some(*value),
9977 Literal::Decimal128(value) => Some(value.to_f64()),
9978 Literal::Boolean(value) => Some(if *value { 1.0 } else { 0.0 }),
9979 Literal::Date32(value) => Some(*value as f64),
9980 _ => None,
9981 }
9982}
9983
9984fn literal_to_i128(literal: &Literal) -> Option<i128> {
9985 match literal {
9986 Literal::Int128(value) => Some(*value),
9987 Literal::Float64(value) => Some(*value as i128),
9988 Literal::Decimal128(value) => llkv_compute::scalar::decimal::rescale(*value, 0)
9989 .ok()
9990 .map(|integral| integral.raw_value()),
9991 Literal::Boolean(value) => Some(if *value { 1 } else { 0 }),
9992 Literal::Date32(value) => Some(*value as i128),
9993 _ => None,
9994 }
9995}
9996
9997fn literal_truthiness(literal: &Literal) -> Option<bool> {
9998 match literal {
9999 Literal::Boolean(value) => Some(*value),
10000 Literal::Int128(value) => Some(*value != 0),
10001 Literal::Float64(value) => Some(*value != 0.0),
10002 Literal::Decimal128(value) => Some(decimal_truthy(*value)),
10003 Literal::Date32(value) => Some(*value != 0),
10004 Literal::Null => None,
10005 _ => None,
10006 }
10007}
10008
10009fn plan_value_truthiness(value: &PlanValue) -> Option<bool> {
10010 match value {
10011 PlanValue::Integer(v) => Some(*v != 0),
10012 PlanValue::Float(v) => Some(*v != 0.0),
10013 PlanValue::Decimal(v) => Some(decimal_truthy(*v)),
10014 PlanValue::Date32(v) => Some(*v != 0),
10015 PlanValue::Null => None,
10016 _ => None,
10017 }
10018}
10019
/// Truthiness of an optional integer: non-zero is true, zero is false,
/// `None` (NULL) stays unknown.
fn option_i64_truthiness(value: Option<i64>) -> Option<bool> {
    Some(value? != 0)
}
10023
10024fn evaluate_plan_value_logical_and(left: PlanValue, right: PlanValue) -> PlanValue {
10025 let left_truth = plan_value_truthiness(&left);
10026 if matches!(left_truth, Some(false)) {
10027 return PlanValue::Integer(0);
10028 }
10029
10030 let right_truth = plan_value_truthiness(&right);
10031 if matches!(right_truth, Some(false)) {
10032 return PlanValue::Integer(0);
10033 }
10034
10035 match (left_truth, right_truth) {
10036 (Some(true), Some(true)) => PlanValue::Integer(1),
10037 (Some(true), None) | (None, Some(true)) | (None, None) => PlanValue::Null,
10038 _ => PlanValue::Null,
10039 }
10040}
10041
10042fn evaluate_plan_value_logical_or(left: PlanValue, right: PlanValue) -> PlanValue {
10043 let left_truth = plan_value_truthiness(&left);
10044 if matches!(left_truth, Some(true)) {
10045 return PlanValue::Integer(1);
10046 }
10047
10048 let right_truth = plan_value_truthiness(&right);
10049 if matches!(right_truth, Some(true)) {
10050 return PlanValue::Integer(1);
10051 }
10052
10053 match (left_truth, right_truth) {
10054 (Some(false), Some(false)) => PlanValue::Integer(0),
10055 (Some(false), None) | (None, Some(false)) | (None, None) => PlanValue::Null,
10056 _ => PlanValue::Null,
10057 }
10058}
10059
/// SQL three-valued `AND` over optional integers (`None` = NULL,
/// non-zero = true): FALSE dominates, both TRUE yields `Some(1)`, any
/// remaining combination yields `None`.
fn evaluate_option_logical_and(left: Option<i64>, right: Option<i64>) -> Option<i64> {
    let lhs = left.map(|v| v != 0);
    let rhs = right.map(|v| v != 0);

    if lhs == Some(false) || rhs == Some(false) {
        return Some(0);
    }

    if lhs == Some(true) && rhs == Some(true) {
        Some(1)
    } else {
        None
    }
}
10077
/// SQL three-valued `OR` over optional integers (`None` = NULL,
/// non-zero = true): TRUE dominates, both FALSE yields `Some(0)`, any
/// remaining combination yields `None`.
fn evaluate_option_logical_or(left: Option<i64>, right: Option<i64>) -> Option<i64> {
    let lhs = left.map(|v| v != 0);
    let rhs = right.map(|v| v != 0);

    if lhs == Some(true) || rhs == Some(true) {
        return Some(1);
    }

    if lhs == Some(false) && rhs == Some(false) {
        Some(0)
    } else {
        None
    }
}
10095
/// Constant-fold a `CAST(literal AS data_type)`.
///
/// `Null` casts to `Null` for any target type. Returns `None` when the
/// cast is unsupported for the literal/type pair, leaving the cast to be
/// evaluated at runtime instead.
fn cast_literal_to_type(literal: &Literal, data_type: &DataType) -> Option<Literal> {
    if matches!(literal, Literal::Null) {
        return Some(Literal::Null);
    }

    match data_type {
        DataType::Boolean => literal_truthiness(literal).map(Literal::Boolean),
        // All float widths fold to the widest representation.
        DataType::Float16 | DataType::Float32 | DataType::Float64 => {
            let value = literal_to_f64(literal)?;
            Some(Literal::Float64(value))
        }
        // All integer widths fold to Int128; no range check is applied
        // here — NOTE(review): presumably narrowing is validated when the
        // literal is materialized into an Arrow array; confirm.
        DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => {
            let value = literal_to_i128(literal)?;
            Some(Literal::Int128(value))
        }
        DataType::Utf8 | DataType::LargeUtf8 => Some(Literal::String(match literal {
            Literal::String(text) => text.clone(),
            Literal::Int128(value) => value.to_string(),
            Literal::Float64(value) => value.to_string(),
            Literal::Decimal128(value) => value.to_string(),
            // Booleans render as "1"/"0", matching the numeric truthiness
            // encoding used elsewhere in this module.
            Literal::Boolean(value) => {
                if *value {
                    "1".to_string()
                } else {
                    "0".to_string()
                }
            }
            Literal::Date32(days) => format_date32_literal(*days).ok()?,
            Literal::Struct(_) | Literal::Null | Literal::Interval(_) => return None,
        })),
        DataType::Decimal128(precision, scale) => {
            literal_to_decimal_literal(literal, *precision, *scale)
        }
        // NOTE(review): Decimal256 targets reuse the 128-bit conversion, so
        // values needing more than 128 bits cannot be represented here —
        // confirm this is an accepted limitation.
        DataType::Decimal256(precision, scale) => {
            literal_to_decimal_literal(literal, *precision, *scale)
        }
        DataType::Interval(IntervalUnit::MonthDayNano) => match literal {
            Literal::Interval(interval) => Some(Literal::Interval(*interval)),
            Literal::Null => Some(Literal::Null),
            _ => None,
        },
        DataType::Date32 => match literal {
            Literal::Null => Some(Literal::Null),
            Literal::Date32(days) => Some(Literal::Date32(*days)),
            Literal::String(text) => parse_date32_literal(text).ok().map(Literal::Date32),
            _ => None,
        },
        _ => None,
    }
}
10153
10154fn literal_to_decimal_literal(literal: &Literal, precision: u8, scale: i8) -> Option<Literal> {
10155 match literal {
10156 Literal::Decimal128(value) => align_decimal_to_scale(*value, precision, scale)
10157 .ok()
10158 .map(Literal::Decimal128),
10159 Literal::Int128(value) => {
10160 let int = i64::try_from(*value).ok()?;
10161 decimal_from_i64(int, precision, scale)
10162 .ok()
10163 .map(Literal::Decimal128)
10164 }
10165 Literal::Float64(value) => decimal_from_f64(*value, precision, scale)
10166 .ok()
10167 .map(Literal::Decimal128),
10168 Literal::Boolean(value) => {
10169 let int = if *value { 1 } else { 0 };
10170 decimal_from_i64(int, precision, scale)
10171 .ok()
10172 .map(Literal::Decimal128)
10173 }
10174 Literal::Null => Some(Literal::Null),
10175 _ => None,
10176 }
10177}
10178
/// Compare two literals with SQL three-valued semantics.
///
/// Thin wrapper over [`compare_literals_with_mode`] fixing the NULL
/// behavior to `ThreeValuedLogic`, so comparisons involving NULL yield
/// `None` instead of a boolean.
fn compare_literals(op: CompareOp, left: &Literal, right: &Literal) -> Option<bool> {
    compare_literals_with_mode(op, left, right, NullComparisonBehavior::ThreeValuedLogic)
}
10182
10183fn bind_select_filter(
10184 filter: &llkv_plan::SelectFilter,
10185 bindings: &FxHashMap<String, Literal>,
10186) -> ExecutorResult<llkv_plan::SelectFilter> {
10187 let predicate = bind_predicate_expr(&filter.predicate, bindings)?;
10188 let subqueries = filter
10189 .subqueries
10190 .iter()
10191 .map(|subquery| bind_filter_subquery(subquery, bindings))
10192 .collect::<ExecutorResult<Vec<_>>>()?;
10193
10194 Ok(llkv_plan::SelectFilter {
10195 predicate,
10196 subqueries,
10197 })
10198}
10199
10200fn bind_filter_subquery(
10201 subquery: &llkv_plan::FilterSubquery,
10202 bindings: &FxHashMap<String, Literal>,
10203) -> ExecutorResult<llkv_plan::FilterSubquery> {
10204 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
10205 Ok(llkv_plan::FilterSubquery {
10206 id: subquery.id,
10207 plan: Box::new(bound_plan),
10208 correlated_columns: subquery.correlated_columns.clone(),
10209 })
10210}
10211
10212fn bind_scalar_subquery(
10213 subquery: &llkv_plan::ScalarSubquery,
10214 bindings: &FxHashMap<String, Literal>,
10215) -> ExecutorResult<llkv_plan::ScalarSubquery> {
10216 let bound_plan = bind_select_plan(&subquery.plan, bindings)?;
10217 Ok(llkv_plan::ScalarSubquery {
10218 id: subquery.id,
10219 plan: Box::new(bound_plan),
10220 correlated_columns: subquery.correlated_columns.clone(),
10221 })
10222}
10223
10224fn bind_projection(
10225 projection: &SelectProjection,
10226 bindings: &FxHashMap<String, Literal>,
10227) -> ExecutorResult<SelectProjection> {
10228 match projection {
10229 SelectProjection::AllColumns => Ok(projection.clone()),
10230 SelectProjection::AllColumnsExcept { exclude } => Ok(SelectProjection::AllColumnsExcept {
10231 exclude: exclude.clone(),
10232 }),
10233 SelectProjection::Column { name, alias } => {
10234 if let Some(literal) = bindings.get(name) {
10235 let expr = ScalarExpr::Literal(literal.clone());
10236 Ok(SelectProjection::Computed {
10237 expr,
10238 alias: alias.clone().unwrap_or_else(|| name.clone()),
10239 })
10240 } else {
10241 Ok(projection.clone())
10242 }
10243 }
10244 SelectProjection::Computed { expr, alias } => Ok(SelectProjection::Computed {
10245 expr: bind_scalar_expr(expr, bindings)?,
10246 alias: alias.clone(),
10247 }),
10248 }
10249}
10250
10251fn bind_aggregate_expr(
10252 aggregate: &AggregateExpr,
10253 bindings: &FxHashMap<String, Literal>,
10254) -> ExecutorResult<AggregateExpr> {
10255 match aggregate {
10256 AggregateExpr::CountStar { .. } => Ok(aggregate.clone()),
10257 AggregateExpr::Column {
10258 column,
10259 alias,
10260 function,
10261 distinct,
10262 } => {
10263 if bindings.contains_key(column) {
10264 return Err(Error::InvalidArgumentError(
10265 "correlated columns are not supported inside aggregate expressions".into(),
10266 ));
10267 }
10268 Ok(AggregateExpr::Column {
10269 column: column.clone(),
10270 alias: alias.clone(),
10271 function: function.clone(),
10272 distinct: *distinct,
10273 })
10274 }
10275 }
10276}
10277
/// Recursively substitute correlated-column bindings into a scalar
/// expression, replacing any `Column` whose name appears in `bindings`
/// with the bound literal. All other variants are rebuilt with their
/// children bound; leaves are cloned unchanged.
fn bind_scalar_expr(
    expr: &ScalarExpr<String>,
    bindings: &FxHashMap<String, Literal>,
) -> ExecutorResult<ScalarExpr<String>> {
    match expr {
        // The substitution site: correlated columns become literals.
        ScalarExpr::Column(name) => {
            if let Some(literal) = bindings.get(name) {
                Ok(ScalarExpr::Literal(literal.clone()))
            } else {
                Ok(ScalarExpr::Column(name.clone()))
            }
        }
        ScalarExpr::Literal(literal) => Ok(ScalarExpr::Literal(literal.clone())),
        ScalarExpr::Binary { left, op, right } => Ok(ScalarExpr::Binary {
            left: Box::new(bind_scalar_expr(left, bindings)?),
            op: *op,
            right: Box::new(bind_scalar_expr(right, bindings)?),
        }),
        ScalarExpr::Compare { left, op, right } => Ok(ScalarExpr::Compare {
            left: Box::new(bind_scalar_expr(left, bindings)?),
            op: *op,
            right: Box::new(bind_scalar_expr(right, bindings)?),
        }),
        // Aggregates are not descended into; bind_aggregate_expr rejects
        // correlated columns inside them separately.
        ScalarExpr::Aggregate(call) => Ok(ScalarExpr::Aggregate(call.clone())),
        // Field access on a bound struct literal is folded eagerly; a
        // missing field degrades to NULL rather than an error.
        ScalarExpr::GetField { base, field_name } => {
            let bound_base = bind_scalar_expr(base, bindings)?;
            match bound_base {
                ScalarExpr::Literal(literal) => {
                    let value = extract_struct_field(&literal, field_name).unwrap_or(Literal::Null);
                    Ok(ScalarExpr::Literal(value))
                }
                other => Ok(ScalarExpr::GetField {
                    base: Box::new(other),
                    field_name: field_name.clone(),
                }),
            }
        }
        ScalarExpr::Cast { expr, data_type } => Ok(ScalarExpr::Cast {
            expr: Box::new(bind_scalar_expr(expr, bindings)?),
            data_type: data_type.clone(),
        }),
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => {
            let bound_operand = match operand {
                Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
                None => None,
            };
            let mut bound_branches = Vec::with_capacity(branches.len());
            for (when_expr, then_expr) in branches {
                bound_branches.push((
                    bind_scalar_expr(when_expr, bindings)?,
                    bind_scalar_expr(then_expr, bindings)?,
                ));
            }
            let bound_else = match else_expr {
                Some(inner) => Some(Box::new(bind_scalar_expr(inner, bindings)?)),
                None => None,
            };
            Ok(ScalarExpr::Case {
                operand: bound_operand,
                branches: bound_branches,
                else_expr: bound_else,
            })
        }
        ScalarExpr::Coalesce(items) => {
            let mut bound_items = Vec::with_capacity(items.len());
            for item in items {
                bound_items.push(bind_scalar_expr(item, bindings)?);
            }
            Ok(ScalarExpr::Coalesce(bound_items))
        }
        ScalarExpr::Not(inner) => Ok(ScalarExpr::Not(Box::new(bind_scalar_expr(
            inner, bindings,
        )?))),
        ScalarExpr::IsNull { expr, negated } => Ok(ScalarExpr::IsNull {
            expr: Box::new(bind_scalar_expr(expr, bindings)?),
            negated: *negated,
        }),
        ScalarExpr::Random => Ok(ScalarExpr::Random),
        // Scalar subqueries are cloned here; their inner plans are bound
        // through bind_scalar_subquery elsewhere.
        ScalarExpr::ScalarSubquery(sub) => Ok(ScalarExpr::ScalarSubquery(sub.clone())),
    }
}
10363
/// Recursively substitute correlated-column bindings into a predicate
/// expression. Boolean structure (`And`/`Or`/`Not`) is preserved; leaf
/// filters go through `bind_filter_predicate`, which may fold a bound
/// filter to a constant `Literal` truth value.
fn bind_predicate_expr(
    expr: &LlkvExpr<'static, String>,
    bindings: &FxHashMap<String, Literal>,
) -> ExecutorResult<LlkvExpr<'static, String>> {
    match expr {
        LlkvExpr::And(children) => {
            let mut bound = Vec::with_capacity(children.len());
            for child in children {
                bound.push(bind_predicate_expr(child, bindings)?);
            }
            Ok(LlkvExpr::And(bound))
        }
        LlkvExpr::Or(children) => {
            let mut bound = Vec::with_capacity(children.len());
            for child in children {
                bound.push(bind_predicate_expr(child, bindings)?);
            }
            Ok(LlkvExpr::Or(bound))
        }
        LlkvExpr::Not(inner) => Ok(LlkvExpr::Not(Box::new(bind_predicate_expr(
            inner, bindings,
        )?))),
        // Leaf filter: may fold to LlkvExpr::Literal if its column is bound.
        LlkvExpr::Pred(filter) => bind_filter_predicate(filter, bindings),
        LlkvExpr::Compare { left, op, right } => Ok(LlkvExpr::Compare {
            left: bind_scalar_expr(left, bindings)?,
            op: *op,
            right: bind_scalar_expr(right, bindings)?,
        }),
        LlkvExpr::InList {
            expr,
            list,
            negated,
        } => {
            let target = bind_scalar_expr(expr, bindings)?;
            let mut bound_list = Vec::with_capacity(list.len());
            for item in list {
                bound_list.push(bind_scalar_expr(item, bindings)?);
            }
            Ok(LlkvExpr::InList {
                expr: target,
                list: bound_list,
                negated: *negated,
            })
        }
        LlkvExpr::IsNull { expr, negated } => Ok(LlkvExpr::IsNull {
            expr: bind_scalar_expr(expr, bindings)?,
            negated: *negated,
        }),
        LlkvExpr::Literal(value) => Ok(LlkvExpr::Literal(*value)),
        // EXISTS subqueries are cloned; their plans are bound via
        // bind_filter_subquery elsewhere.
        LlkvExpr::Exists(subquery) => Ok(LlkvExpr::Exists(subquery.clone())),
    }
}
10416
10417fn bind_filter_predicate(
10418 filter: &Filter<'static, String>,
10419 bindings: &FxHashMap<String, Literal>,
10420) -> ExecutorResult<LlkvExpr<'static, String>> {
10421 if let Some(literal) = bindings.get(&filter.field_id) {
10422 let result = evaluate_filter_against_literal(literal, &filter.op)?;
10423 return Ok(LlkvExpr::Literal(result));
10424 }
10425 Ok(LlkvExpr::Pred(filter.clone()))
10426}
10427
/// Evaluate a filter operator against a single bound literal value.
///
/// Used when a correlated column has been substituted with a concrete
/// literal, collapsing the filter to a constant boolean. Incomparable
/// operand pairs (where `literal_compare`/`literal_equals` return `None`)
/// evaluate to `false` rather than erroring.
fn evaluate_filter_against_literal(value: &Literal, op: &Operator) -> ExecutorResult<bool> {
    use std::ops::Bound;

    match op {
        Operator::IsNull => Ok(matches!(value, Literal::Null)),
        Operator::IsNotNull => Ok(!matches!(value, Literal::Null)),
        Operator::Equals(rhs) => Ok(literal_equals(value, rhs).unwrap_or(false)),
        Operator::GreaterThan(rhs) => Ok(literal_compare(value, rhs)
            .map(|cmp| cmp == std::cmp::Ordering::Greater)
            .unwrap_or(false)),
        Operator::GreaterThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
            .map(|cmp| matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal))
            .unwrap_or(false)),
        Operator::LessThan(rhs) => Ok(literal_compare(value, rhs)
            .map(|cmp| cmp == std::cmp::Ordering::Less)
            .unwrap_or(false)),
        Operator::LessThanOrEquals(rhs) => Ok(literal_compare(value, rhs)
            .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal))
            .unwrap_or(false)),
        Operator::In(values) => Ok(values
            .iter()
            .any(|candidate| literal_equals(value, candidate).unwrap_or(false))),
        // Half-open/closed range: each bound defaults to satisfied when
        // unbounded and to false when the comparison is undefined.
        Operator::Range { lower, upper } => {
            let lower_ok = match lower {
                Bound::Unbounded => Some(true),
                Bound::Included(bound) => literal_compare(value, bound).map(|cmp| {
                    matches!(cmp, std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
                }),
                Bound::Excluded(bound) => {
                    literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Greater)
                }
            }
            .unwrap_or(false);

            let upper_ok = match upper {
                Bound::Unbounded => Some(true),
                Bound::Included(bound) => literal_compare(value, bound)
                    .map(|cmp| matches!(cmp, std::cmp::Ordering::Less | std::cmp::Ordering::Equal)),
                Bound::Excluded(bound) => {
                    literal_compare(value, bound).map(|cmp| cmp == std::cmp::Ordering::Less)
                }
            }
            .unwrap_or(false);

            Ok(lower_ok && upper_ok)
        }
        // String matchers: case-insensitive mode lowercases both the
        // pattern and (via literal_string) the candidate before matching.
        Operator::StartsWith {
            pattern,
            case_sensitive,
        } => {
            let target = if *case_sensitive {
                pattern.to_string()
            } else {
                pattern.to_ascii_lowercase()
            };
            Ok(literal_string(value, *case_sensitive)
                .map(|source| source.starts_with(&target))
                .unwrap_or(false))
        }
        Operator::EndsWith {
            pattern,
            case_sensitive,
        } => {
            let target = if *case_sensitive {
                pattern.to_string()
            } else {
                pattern.to_ascii_lowercase()
            };
            Ok(literal_string(value, *case_sensitive)
                .map(|source| source.ends_with(&target))
                .unwrap_or(false))
        }
        Operator::Contains {
            pattern,
            case_sensitive,
        } => {
            let target = if *case_sensitive {
                pattern.to_string()
            } else {
                pattern.to_ascii_lowercase()
            };
            Ok(literal_string(value, *case_sensitive)
                .map(|source| source.contains(&target))
                .unwrap_or(false))
        }
    }
}
10515
10516fn literal_compare(lhs: &Literal, rhs: &Literal) -> Option<std::cmp::Ordering> {
10517 match (lhs, rhs) {
10518 (Literal::Int128(a), Literal::Int128(b)) => Some(a.cmp(b)),
10519 (Literal::Float64(a), Literal::Float64(b)) => a.partial_cmp(b),
10520 (Literal::Int128(a), Literal::Float64(b)) => (*a as f64).partial_cmp(b),
10521 (Literal::Float64(a), Literal::Int128(b)) => a.partial_cmp(&(*b as f64)),
10522 (Literal::Date32(a), Literal::Date32(b)) => Some(a.cmp(b)),
10523 (Literal::Date32(a), Literal::Int128(b)) => Some((*a as i128).cmp(b)),
10524 (Literal::Int128(a), Literal::Date32(b)) => Some(a.cmp(&(*b as i128))),
10525 (Literal::Date32(a), Literal::Float64(b)) => (*a as f64).partial_cmp(b),
10526 (Literal::Float64(a), Literal::Date32(b)) => a.partial_cmp(&(*b as f64)),
10527 (Literal::String(a), Literal::String(b)) => Some(a.cmp(b)),
10528 (Literal::Interval(a), Literal::Interval(b)) => Some(compare_interval_values(*a, *b)),
10529 _ => None,
10530 }
10531}
10532
10533fn literal_equals(lhs: &Literal, rhs: &Literal) -> Option<bool> {
10534 match (lhs, rhs) {
10535 (Literal::Boolean(a), Literal::Boolean(b)) => Some(a == b),
10536 (Literal::String(a), Literal::String(b)) => Some(a == b),
10537 (Literal::Int128(_), Literal::Int128(_))
10538 | (Literal::Int128(_), Literal::Float64(_))
10539 | (Literal::Float64(_), Literal::Int128(_))
10540 | (Literal::Float64(_), Literal::Float64(_))
10541 | (Literal::Date32(_), Literal::Date32(_))
10542 | (Literal::Date32(_), Literal::Int128(_))
10543 | (Literal::Int128(_), Literal::Date32(_))
10544 | (Literal::Date32(_), Literal::Float64(_))
10545 | (Literal::Float64(_), Literal::Date32(_))
10546 | (Literal::Interval(_), Literal::Interval(_)) => {
10547 literal_compare(lhs, rhs).map(|cmp| cmp == std::cmp::Ordering::Equal)
10548 }
10549 _ => None,
10550 }
10551}
10552
10553fn literal_string(literal: &Literal, case_sensitive: bool) -> Option<String> {
10554 match literal {
10555 Literal::String(value) => {
10556 if case_sensitive {
10557 Some(value.clone())
10558 } else {
10559 Some(value.to_ascii_lowercase())
10560 }
10561 }
10562 Literal::Date32(value) => {
10563 let formatted = format_date32_literal(*value).ok()?;
10564 if case_sensitive {
10565 Some(formatted)
10566 } else {
10567 Some(formatted.to_ascii_lowercase())
10568 }
10569 }
10570 _ => None,
10571 }
10572}
10573
10574fn extract_struct_field(literal: &Literal, field_name: &str) -> Option<Literal> {
10575 if let Literal::Struct(fields) = literal {
10576 for (name, value) in fields {
10577 if name.eq_ignore_ascii_case(field_name) {
10578 return Some((**value).clone());
10579 }
10580 }
10581 }
10582 None
10583}
10584
/// Recursively collect the ids of every scalar subquery appearing in a
/// scalar expression into `ids`. Leaves (aggregates, columns, literals,
/// `Random`) contribute nothing; aggregates are not descended into.
fn collect_scalar_subquery_ids(expr: &ScalarExpr<FieldId>, ids: &mut FxHashSet<SubqueryId>) {
    match expr {
        ScalarExpr::ScalarSubquery(subquery) => {
            ids.insert(subquery.id);
        }
        ScalarExpr::Binary { left, right, .. } => {
            collect_scalar_subquery_ids(left, ids);
            collect_scalar_subquery_ids(right, ids);
        }
        ScalarExpr::Compare { left, right, .. } => {
            collect_scalar_subquery_ids(left, ids);
            collect_scalar_subquery_ids(right, ids);
        }
        ScalarExpr::GetField { base, .. } => {
            collect_scalar_subquery_ids(base, ids);
        }
        ScalarExpr::Cast { expr, .. } => {
            collect_scalar_subquery_ids(expr, ids);
        }
        ScalarExpr::Not(expr) => {
            collect_scalar_subquery_ids(expr, ids);
        }
        ScalarExpr::IsNull { expr, .. } => {
            collect_scalar_subquery_ids(expr, ids);
        }
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => {
            if let Some(op) = operand {
                collect_scalar_subquery_ids(op, ids);
            }
            for (when_expr, then_expr) in branches {
                collect_scalar_subquery_ids(when_expr, ids);
                collect_scalar_subquery_ids(then_expr, ids);
            }
            if let Some(else_expr) = else_expr {
                collect_scalar_subquery_ids(else_expr, ids);
            }
        }
        ScalarExpr::Coalesce(items) => {
            for item in items {
                collect_scalar_subquery_ids(item, ids);
            }
        }
        // Leaf variants: no nested expressions to visit.
        ScalarExpr::Aggregate(_)
        | ScalarExpr::Column(_)
        | ScalarExpr::Literal(_)
        | ScalarExpr::Random => {}
    }
}
10637
/// Recursively collect scalar-subquery ids referenced inside a predicate
/// tree, delegating to [`collect_scalar_subquery_ids`] at each embedded
/// scalar expression.
fn collect_predicate_scalar_subquery_ids(
    expr: &LlkvExpr<'static, FieldId>,
    ids: &mut FxHashSet<SubqueryId>,
) {
    match expr {
        LlkvExpr::And(children) | LlkvExpr::Or(children) => {
            for child in children {
                collect_predicate_scalar_subquery_ids(child, ids);
            }
        }
        LlkvExpr::Not(inner) => collect_predicate_scalar_subquery_ids(inner, ids),
        LlkvExpr::Compare { left, right, .. } => {
            collect_scalar_subquery_ids(left, ids);
            collect_scalar_subquery_ids(right, ids);
        }
        LlkvExpr::InList { expr, list, .. } => {
            collect_scalar_subquery_ids(expr, ids);
            for item in list {
                collect_scalar_subquery_ids(item, ids);
            }
        }
        LlkvExpr::IsNull { expr, .. } => {
            collect_scalar_subquery_ids(expr, ids);
        }
        // Exists carries an EXISTS subquery (a different kind, tracked
        // elsewhere); Pred and Literal contain no scalar expressions.
        LlkvExpr::Exists(_) | LlkvExpr::Pred(_) | LlkvExpr::Literal(_) => {
        }
    }
}
10667
/// Rewrite a scalar expression, replacing each scalar subquery whose id
/// appears in `mapping` with a column reference to the field that holds
/// its materialized result. Subqueries without a mapping entry are kept
/// as-is; all other variants are rebuilt with rewritten children.
fn rewrite_scalar_expr_for_subqueries(
    expr: &ScalarExpr<FieldId>,
    mapping: &FxHashMap<SubqueryId, FieldId>,
) -> ScalarExpr<FieldId> {
    match expr {
        // The rewrite site: a mapped subquery becomes a plain column read.
        ScalarExpr::ScalarSubquery(subquery) => mapping
            .get(&subquery.id)
            .map(|field_id| ScalarExpr::Column(*field_id))
            .unwrap_or_else(|| ScalarExpr::ScalarSubquery(subquery.clone())),
        ScalarExpr::Binary { left, op, right } => ScalarExpr::Binary {
            left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
            op: *op,
            right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
        },
        ScalarExpr::Compare { left, op, right } => ScalarExpr::Compare {
            left: Box::new(rewrite_scalar_expr_for_subqueries(left, mapping)),
            op: *op,
            right: Box::new(rewrite_scalar_expr_for_subqueries(right, mapping)),
        },
        ScalarExpr::GetField { base, field_name } => ScalarExpr::GetField {
            base: Box::new(rewrite_scalar_expr_for_subqueries(base, mapping)),
            field_name: field_name.clone(),
        },
        ScalarExpr::Cast { expr, data_type } => ScalarExpr::Cast {
            expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
            data_type: data_type.clone(),
        },
        ScalarExpr::Not(expr) => {
            ScalarExpr::Not(Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)))
        }
        ScalarExpr::IsNull { expr, negated } => ScalarExpr::IsNull {
            expr: Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping)),
            negated: *negated,
        },
        ScalarExpr::Case {
            operand,
            branches,
            else_expr,
        } => ScalarExpr::Case {
            operand: operand
                .as_ref()
                .map(|op| Box::new(rewrite_scalar_expr_for_subqueries(op, mapping))),
            branches: branches
                .iter()
                .map(|(when_expr, then_expr)| {
                    (
                        rewrite_scalar_expr_for_subqueries(when_expr, mapping),
                        rewrite_scalar_expr_for_subqueries(then_expr, mapping),
                    )
                })
                .collect(),
            else_expr: else_expr
                .as_ref()
                .map(|expr| Box::new(rewrite_scalar_expr_for_subqueries(expr, mapping))),
        },
        ScalarExpr::Coalesce(items) => ScalarExpr::Coalesce(
            items
                .iter()
                .map(|item| rewrite_scalar_expr_for_subqueries(item, mapping))
                .collect(),
        ),
        // Leaf variants contain no subqueries; clone unchanged.
        ScalarExpr::Aggregate(_)
        | ScalarExpr::Column(_)
        | ScalarExpr::Literal(_)
        | ScalarExpr::Random => expr.clone(),
    }
}
10735
/// Build the placeholder -> literal bindings for one outer row of a
/// correlated subquery.
///
/// For each correlated column, resolves its field id in the outer query's
/// output, reads the value at `row_idx` from `batch` as a `Literal`, and
/// keys it by the column's placeholder name. Errors when a correlated
/// column is unknown or uses a (not yet supported) nested field path.
fn collect_correlated_bindings(
    context: &mut CrossProductExpressionContext,
    batch: &RecordBatch,
    row_idx: usize,
    columns: &[llkv_plan::CorrelatedColumn],
) -> ExecutorResult<FxHashMap<String, Literal>> {
    let mut out = FxHashMap::default();

    for correlated in columns {
        // Nested struct-field correlation (e.g. outer.s.x) is unsupported.
        if !correlated.field_path.is_empty() {
            return Err(Error::InvalidArgumentError(
                "correlated field path resolution is not yet supported".into(),
            ));
        }

        let field_id = context
            .field_id_for_column(&correlated.column)
            .ok_or_else(|| {
                Error::InvalidArgumentError(format!(
                    "correlated column '{}' not found in outer query output",
                    correlated.column
                ))
            })?;

        let accessor = context.column_accessor(field_id, batch)?;
        let literal = accessor.literal_at(row_idx)?;
        out.insert(correlated.placeholder.clone(), literal);
    }

    Ok(out)
}
10767
/// A prepared SELECT ready to be streamed: pairs an output schema with
/// either a lazy table projection scan or an already-materialized batch,
/// plus optional LIMIT/OFFSET applied during streaming.
#[derive(Clone)]
pub struct SelectExecution<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Name of the (primary) table this execution reads from.
    table_name: String,
    // Output schema of the batches produced by `stream`.
    schema: Arc<Schema>,
    // Data source: lazy projection scan or a single pre-built batch.
    stream: SelectStream<P>,
    // LIMIT: maximum number of rows to emit across all batches.
    limit: Option<usize>,
    // OFFSET: number of leading rows to skip before emitting.
    offset: Option<usize>,
}
10780
/// The two data sources a `SelectExecution` can stream from.
#[derive(Clone)]
enum SelectStream<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    /// Lazily scan a table with projections, a filter, optional ordering,
    /// and optional DISTINCT de-duplication.
    Projection {
        table: Arc<ExecutorTable<P>>,
        projections: Vec<ScanProjection>,
        filter_expr: LlkvExpr<'static, FieldId>,
        options: ScanStreamOptions<P>,
        // True when the filter matches every row (enables fast paths).
        full_table_scan: bool,
        order_by: Vec<OrderByPlan>,
        distinct: bool,
    },
    /// Emit a single pre-computed batch (e.g. aggregate results).
    Aggregation {
        batch: RecordBatch,
    },
}
10799
10800impl<P> SelectExecution<P>
10801where
10802 P: Pager<Blob = EntryHandle> + Send + Sync,
10803{
    /// Build an execution that lazily scans `table` with the given
    /// projections, filter, ordering, and DISTINCT flag. LIMIT/OFFSET are
    /// unset; callers layer them on via `with_limit`/`with_offset`.
    #[allow(clippy::too_many_arguments)]
    fn new_projection(
        table_name: String,
        schema: Arc<Schema>,
        table: Arc<ExecutorTable<P>>,
        projections: Vec<ScanProjection>,
        filter_expr: LlkvExpr<'static, FieldId>,
        options: ScanStreamOptions<P>,
        full_table_scan: bool,
        order_by: Vec<OrderByPlan>,
        distinct: bool,
    ) -> Self {
        Self {
            table_name,
            schema,
            stream: SelectStream::Projection {
                table,
                projections,
                filter_expr,
                options,
                full_table_scan,
                order_by,
                distinct,
            },
            limit: None,
            offset: None,
        }
    }
10832
    /// Build an execution that emits exactly one pre-computed batch
    /// (used for aggregate and other already-materialized results).
    pub fn new_single_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
        Self {
            table_name,
            schema,
            stream: SelectStream::Aggregation { batch },
            limit: None,
            offset: None,
        }
    }
10842
10843 pub fn from_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
10844 Self::new_single_batch(table_name, schema, batch)
10845 }
10846
10847 pub fn table_name(&self) -> &str {
10848 &self.table_name
10849 }
10850
10851 pub fn schema(&self) -> Arc<Schema> {
10852 Arc::clone(&self.schema)
10853 }
10854
10855 pub fn with_limit(mut self, limit: Option<usize>) -> Self {
10856 self.limit = limit;
10857 self
10858 }
10859
10860 pub fn with_offset(mut self, offset: Option<usize>) -> Self {
10861 self.offset = offset;
10862 self
10863 }
10864
10865 pub fn stream(
10866 self,
10867 mut on_batch: impl FnMut(RecordBatch) -> ExecutorResult<()>,
10868 ) -> ExecutorResult<()> {
10869 let limit = self.limit;
10870 let mut offset = self.offset.unwrap_or(0);
10871 let mut rows_emitted = 0;
10872
10873 let mut on_batch = |batch: RecordBatch| -> ExecutorResult<()> {
10874 let rows = batch.num_rows();
10875 let mut batch_to_emit = batch;
10876
10877 if offset > 0 {
10879 if rows == 0 {
10880 } else if rows <= offset {
10882 offset -= rows;
10883 return Ok(());
10884 } else {
10885 batch_to_emit = batch_to_emit.slice(offset, rows - offset);
10886 offset = 0;
10887 }
10888 }
10889
10890 if let Some(limit_val) = limit {
10892 if rows_emitted >= limit_val {
10893 return Ok(());
10894 }
10895 let remaining = limit_val - rows_emitted;
10896 if batch_to_emit.num_rows() > remaining {
10897 batch_to_emit = batch_to_emit.slice(0, remaining);
10898 }
10899 rows_emitted += batch_to_emit.num_rows();
10900 }
10901
10902 on_batch(batch_to_emit)
10903 };
10904
10905 let schema = Arc::clone(&self.schema);
10906 match self.stream {
10907 SelectStream::Projection {
10908 table,
10909 projections,
10910 filter_expr,
10911 options,
10912 full_table_scan,
10913 order_by,
10914 distinct,
10915 } => {
10916 let total_rows = table.total_rows.load(Ordering::SeqCst);
10918 if total_rows == 0 {
10919 return Ok(());
10921 }
10922
10923 let mut error: Option<Error> = None;
10924 let mut produced = false;
10925 let mut produced_rows: u64 = 0;
10926 let capture_nulls_first = matches!(options.order, Some(spec) if spec.nulls_first);
10927 let needs_post_sort =
10928 !order_by.is_empty() && (order_by.len() > 1 || options.order.is_none());
10929 let collect_batches = needs_post_sort || capture_nulls_first;
10930 let include_nulls = options.include_nulls;
10931 let has_row_id_filter = options.row_id_filter.is_some();
10932 let mut distinct_state = if distinct {
10933 Some(DistinctState::default())
10934 } else {
10935 None
10936 };
10937 let scan_options = options;
10938 let mut buffered_batches: Vec<RecordBatch> = Vec::new();
10939 table
10940 .table
10941 .scan_stream(projections, &filter_expr, scan_options, |batch| {
10942 if error.is_some() {
10943 return;
10944 }
10945 let mut batch = batch;
10946 if let Some(state) = distinct_state.as_mut() {
10947 match distinct_filter_batch(batch, state) {
10948 Ok(Some(filtered)) => {
10949 batch = filtered;
10950 }
10951 Ok(None) => {
10952 return;
10953 }
10954 Err(err) => {
10955 error = Some(err);
10956 return;
10957 }
10958 }
10959 }
10960 produced = true;
10961 produced_rows = produced_rows.saturating_add(batch.num_rows() as u64);
10962 if collect_batches {
10963 buffered_batches.push(batch);
10964 } else if let Err(err) = on_batch(batch) {
10965 error = Some(err);
10966 }
10967 })?;
10968 if let Some(err) = error {
10969 return Err(err);
10970 }
10971 if !produced {
10972 if !distinct && full_table_scan && total_rows > 0 {
10975 for batch in synthesize_null_scan(Arc::clone(&schema), total_rows)? {
10976 on_batch(batch)?;
10977 }
10978 }
10979 return Ok(());
10980 }
10981 let mut null_batches: Vec<RecordBatch> = Vec::new();
10982 if !distinct
10988 && include_nulls
10989 && full_table_scan
10990 && produced_rows < total_rows
10991 && !has_row_id_filter
10992 {
10993 let missing = total_rows - produced_rows;
10994 if missing > 0 {
10995 null_batches = synthesize_null_scan(Arc::clone(&schema), missing)?;
10996 }
10997 }
10998
10999 if collect_batches {
11000 if needs_post_sort {
11001 if !null_batches.is_empty() {
11002 buffered_batches.extend(null_batches);
11003 }
11004 if !buffered_batches.is_empty() {
11005 let combined =
11006 concat_batches(&schema, &buffered_batches).map_err(|err| {
11007 Error::InvalidArgumentError(format!(
11008 "failed to concatenate result batches for ORDER BY: {}",
11009 err
11010 ))
11011 })?;
11012 let sorted_batch =
11013 sort_record_batch_with_order(&schema, &combined, &order_by)?;
11014 on_batch(sorted_batch)?;
11015 }
11016 } else if capture_nulls_first {
11017 for batch in null_batches {
11018 on_batch(batch)?;
11019 }
11020 for batch in buffered_batches {
11021 on_batch(batch)?;
11022 }
11023 }
11024 } else if !null_batches.is_empty() {
11025 for batch in null_batches {
11026 on_batch(batch)?;
11027 }
11028 }
11029 Ok(())
11030 }
11031 SelectStream::Aggregation { batch } => on_batch(batch),
11032 }
11033 }
11034
11035 pub fn collect(self) -> ExecutorResult<Vec<RecordBatch>> {
11036 let mut batches = Vec::new();
11037 self.stream(|batch| {
11038 batches.push(batch);
11039 Ok(())
11040 })?;
11041 Ok(batches)
11042 }
11043
11044 pub fn collect_rows(self) -> ExecutorResult<ExecutorRowBatch> {
11045 let schema = self.schema();
11046 let mut rows: Vec<Vec<PlanValue>> = Vec::new();
11047 self.stream(|batch| {
11048 for row_idx in 0..batch.num_rows() {
11049 let mut row: Vec<PlanValue> = Vec::with_capacity(batch.num_columns());
11050 for col_idx in 0..batch.num_columns() {
11051 let value = llkv_plan::plan_value_from_array(batch.column(col_idx), row_idx)?;
11052 row.push(value);
11053 }
11054 rows.push(row);
11055 }
11056 Ok(())
11057 })?;
11058 let columns = schema
11059 .fields()
11060 .iter()
11061 .map(|field| field.name().to_string())
11062 .collect();
11063 Ok(ExecutorRowBatch { columns, rows })
11064 }
11065
11066 pub fn into_rows(self) -> ExecutorResult<Vec<Vec<PlanValue>>> {
11067 Ok(self.collect_rows()?.rows)
11068 }
11069}
11070
11071impl<P> fmt::Debug for SelectExecution<P>
11072where
11073 P: Pager<Blob = EntryHandle> + Send + Sync,
11074{
11075 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11076 f.debug_struct("SelectExecution")
11077 .field("table_name", &self.table_name)
11078 .field("schema", &self.schema)
11079 .finish()
11080 }
11081}
11082
11083fn infer_type_recursive(
11089 expr: &ScalarExpr<String>,
11090 base_schema: &Schema,
11091 column_lookup_map: &FxHashMap<String, usize>,
11092) -> Option<DataType> {
11093 use arrow::datatypes::IntervalUnit;
11094 use llkv_expr::literal::Literal;
11095
11096 match expr {
11097 ScalarExpr::Column(name) => resolve_column_name_to_index(name, column_lookup_map)
11098 .map(|idx| base_schema.field(idx).data_type().clone()),
11099 ScalarExpr::Literal(lit) => match lit {
11100 Literal::Decimal128(v) => Some(DataType::Decimal128(v.precision(), v.scale())),
11101 Literal::Float64(_) => Some(DataType::Float64),
11102 Literal::Int128(_) => Some(DataType::Int64),
11103 Literal::Boolean(_) => Some(DataType::Boolean),
11104 Literal::String(_) => Some(DataType::Utf8),
11105 Literal::Date32(_) => Some(DataType::Date32),
11106 Literal::Null => Some(DataType::Null),
11107 Literal::Interval(_) => Some(DataType::Interval(IntervalUnit::MonthDayNano)),
11108 _ => None,
11109 },
11110 ScalarExpr::Binary { left, op: _, right } => {
11111 let l = infer_type_recursive(left, base_schema, column_lookup_map)?;
11112 let r = infer_type_recursive(right, base_schema, column_lookup_map)?;
11113
11114 if matches!(l, DataType::Float64) || matches!(r, DataType::Float64) {
11115 return Some(DataType::Float64);
11116 }
11117
11118 match (l, r) {
11119 (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) => {
11120 Some(DataType::Decimal128(38, s1.max(s2)))
11122 }
11123 (DataType::Decimal128(p, s), _) => Some(DataType::Decimal128(p, s)),
11124 (_, DataType::Decimal128(p, s)) => Some(DataType::Decimal128(p, s)),
11125 (l, _) => Some(l),
11126 }
11127 }
11128 ScalarExpr::Cast { data_type, .. } => Some(data_type.clone()),
11129 _ => None,
11131 }
11132}
11133
11134fn expand_order_targets(
11135 order_items: &[OrderByPlan],
11136 projections: &[ScanProjection],
11137) -> ExecutorResult<Vec<OrderByPlan>> {
11138 let mut expanded = Vec::new();
11139
11140 for item in order_items {
11141 match &item.target {
11142 OrderTarget::All => {
11143 if projections.is_empty() {
11144 return Err(Error::InvalidArgumentError(
11145 "ORDER BY ALL requires at least one projection".into(),
11146 ));
11147 }
11148
11149 for (idx, projection) in projections.iter().enumerate() {
11150 if matches!(projection, ScanProjection::Computed { .. }) {
11151 return Err(Error::InvalidArgumentError(
11152 "ORDER BY ALL cannot reference computed projections".into(),
11153 ));
11154 }
11155
11156 let mut clone = item.clone();
11157 clone.target = OrderTarget::Index(idx);
11158 expanded.push(clone);
11159 }
11160 }
11161 _ => expanded.push(item.clone()),
11162 }
11163 }
11164
11165 Ok(expanded)
11166}
11167
/// Translates one ORDER BY item into the storage scanner's native
/// `ScanOrderSpec` (field id, direction, NULLS placement, key transform).
///
/// Only column names and projection positions that resolve to plain stored
/// columns are supported; computed projections and un-expanded
/// `ORDER BY ALL` are rejected with `InvalidArgumentError`.
fn resolve_scan_order<P>(
    table: &ExecutorTable<P>,
    projections: &[ScanProjection],
    order_plan: &OrderByPlan,
) -> ExecutorResult<ScanOrderSpec>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Resolve the target down to a schema column and its field id.
    let (column, field_id) = match &order_plan.target {
        OrderTarget::Column(name) => {
            let column = table.schema.resolve(name).ok_or_else(|| {
                Error::InvalidArgumentError(format!("unknown column '{}' in ORDER BY", name))
            })?;
            (column, column.field_id)
        }
        OrderTarget::Index(position) => {
            // `position` is 0-based internally; errors report it 1-based
            // to match SQL's ORDER BY <ordinal> convention.
            let projection = projections.get(*position).ok_or_else(|| {
                Error::InvalidArgumentError(format!(
                    "ORDER BY position {} is out of range",
                    position + 1
                ))
            })?;
            match projection {
                ScanProjection::Column(store_projection) => {
                    let field_id = store_projection.logical_field_id.field_id();
                    let column = table.schema.column_by_field_id(field_id).ok_or_else(|| {
                        Error::InvalidArgumentError(format!(
                            "unknown column with field id {field_id} in ORDER BY"
                        ))
                    })?;
                    (column, field_id)
                }
                ScanProjection::Computed { .. } => {
                    return Err(Error::InvalidArgumentError(
                        "ORDER BY position referring to computed projection is not supported"
                            .into(),
                    ));
                }
            }
        }
        OrderTarget::All => {
            // `expand_order_targets` rewrites ALL into per-index items
            // before execution reaches this point.
            return Err(Error::InvalidArgumentError(
                "ORDER BY ALL should be expanded before execution".into(),
            ));
        }
    };

    // Pick the transform the scanner applies to values before comparison.
    let transform = match order_plan.sort_type {
        OrderSortType::Native => match column.data_type {
            DataType::Int64 => ScanOrderTransform::IdentityInt64,
            DataType::Int32 => ScanOrderTransform::IdentityInt32,
            DataType::Utf8 => ScanOrderTransform::IdentityUtf8,
            ref other => {
                return Err(Error::InvalidArgumentError(format!(
                    "ORDER BY on column type {:?} is not supported",
                    other
                )));
            }
        },
        OrderSortType::CastTextToInteger => {
            // Fast path for `ORDER BY CAST(text_col AS INTEGER)`.
            if column.data_type != DataType::Utf8 {
                return Err(Error::InvalidArgumentError(
                    "ORDER BY CAST expects a text column".into(),
                ));
            }
            ScanOrderTransform::CastUtf8ToInteger
        }
    };

    let direction = if order_plan.ascending {
        ScanOrderDirection::Ascending
    } else {
        ScanOrderDirection::Descending
    };

    Ok(ScanOrderSpec {
        field_id,
        direction,
        nulls_first: order_plan.nulls_first,
        transform,
    })
}
11250
11251fn synthesize_null_scan(schema: Arc<Schema>, total_rows: u64) -> ExecutorResult<Vec<RecordBatch>> {
11252 let row_count = usize::try_from(total_rows).map_err(|_| {
11253 Error::InvalidArgumentError("table row count exceeds supported in-memory batch size".into())
11254 })?;
11255
11256 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
11257 for field in schema.fields() {
11258 match field.data_type() {
11259 DataType::Int64 => {
11260 let mut builder = Int64Builder::with_capacity(row_count);
11261 for _ in 0..row_count {
11262 builder.append_null();
11263 }
11264 arrays.push(Arc::new(builder.finish()));
11265 }
11266 DataType::Float64 => {
11267 let mut builder = arrow::array::Float64Builder::with_capacity(row_count);
11268 for _ in 0..row_count {
11269 builder.append_null();
11270 }
11271 arrays.push(Arc::new(builder.finish()));
11272 }
11273 DataType::Utf8 => {
11274 let mut builder = arrow::array::StringBuilder::with_capacity(row_count, 0);
11275 for _ in 0..row_count {
11276 builder.append_null();
11277 }
11278 arrays.push(Arc::new(builder.finish()));
11279 }
11280 DataType::Date32 => {
11281 let mut builder = arrow::array::Date32Builder::with_capacity(row_count);
11282 for _ in 0..row_count {
11283 builder.append_null();
11284 }
11285 arrays.push(Arc::new(builder.finish()));
11286 }
11287 other => {
11288 return Err(Error::InvalidArgumentError(format!(
11289 "unsupported data type in null synthesis: {other:?}"
11290 )));
11291 }
11292 }
11293 }
11294
11295 let batch = RecordBatch::try_new(schema, arrays)?;
11296 Ok(vec![batch])
11297}
11298
/// Materialized per-table data accumulated while building a cross product.
struct TableCrossProductData {
    /// Combined schema of the fully-qualified columns gathered so far.
    schema: Arc<Schema>,
    /// Row batches conforming to `schema`.
    batches: Vec<RecordBatch>,
    /// Number of columns contributed by each participating table.
    column_counts: Vec<usize>,
    /// Planner index of each participating table, in column order.
    table_indices: Vec<usize>,
}
11305
11306fn plan_value_to_literal(value: &PlanValue) -> ExecutorResult<Literal> {
11307 match value {
11308 PlanValue::String(s) => Ok(Literal::String(s.clone())),
11309 PlanValue::Integer(i) => Ok(Literal::Int128(*i as i128)),
11310 PlanValue::Float(f) => Ok(Literal::Float64(*f)),
11311 PlanValue::Null => Ok(Literal::Null),
11312 PlanValue::Date32(d) => Ok(Literal::Date32(*d)),
11313 PlanValue::Decimal(d) => Ok(Literal::Decimal128(*d)),
11314 _ => Err(Error::Internal(format!(
11315 "unsupported plan value for literal conversion: {:?}",
11316 value
11317 ))),
11318 }
11319}
11320
/// Scans one table participating in a cross product and returns its rows
/// with fully-qualified column names (`schema.table_or_alias.column`).
///
/// Pushable `constraints` (equalities and IN lists) are translated into the
/// scan filter where possible, then re-applied batch-side as a safety net.
fn collect_table_data<P>(
    table_index: usize,
    table_ref: &llkv_plan::TableRef,
    table: &ExecutorTable<P>,
    constraints: &[ColumnConstraint],
) -> ExecutorResult<TableCrossProductData>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    if table.schema.columns.is_empty() {
        return Err(Error::InvalidArgumentError(format!(
            "table '{}' has no columns; cross products require at least one column",
            table_ref.qualified_name()
        )));
    }

    let mut projections = Vec::with_capacity(table.schema.columns.len());
    let mut fields = Vec::with_capacity(table.schema.columns.len());

    // Project every column under its fully-qualified name so joined
    // outputs from different tables cannot collide.
    for column in &table.schema.columns {
        let table_component = table_ref
            .alias
            .as_deref()
            .unwrap_or(table_ref.table.as_str());
        let qualified_name = format!("{}.{}.{}", table_ref.schema, table_component, column.name);
        projections.push(ScanProjection::from(StoreProjection::with_alias(
            LogicalFieldId::for_user(table.table.table_id(), column.field_id),
            qualified_name.clone(),
        )));
        fields.push(Field::new(
            qualified_name,
            column.data_type.clone(),
            column.nullable,
        ));
    }

    let schema = Arc::new(Schema::new(fields));

    // Anchor field for the match-everything filter when no constraint
    // translates into a scan predicate.
    let filter_field_id = table.schema.first_field_id().unwrap_or(ROW_ID_FIELD_ID);

    // Translate pushable constraints into scan-level predicates. Failures
    // to convert a literal are silently skipped here; the batch-side
    // re-application below still enforces them.
    let mut filter_exprs = Vec::new();
    for constraint in constraints {
        match constraint {
            ColumnConstraint::Equality(lit) => {
                let col_idx = lit.column.column;
                if col_idx < table.schema.columns.len() {
                    let field_id = table.schema.columns[col_idx].field_id;
                    if let Ok(literal) = plan_value_to_literal(&lit.value) {
                        filter_exprs.push(LlkvExpr::Compare {
                            left: ScalarExpr::Column(field_id),
                            op: CompareOp::Eq,
                            right: ScalarExpr::Literal(literal),
                        });
                    }
                }
            }
            ColumnConstraint::InList(in_list) => {
                let col_idx = in_list.column.column;
                if col_idx < table.schema.columns.len() {
                    let field_id = table.schema.columns[col_idx].field_id;
                    let literals: Vec<Literal> = in_list
                        .values
                        .iter()
                        .filter_map(|v| plan_value_to_literal(v).ok())
                        .collect();

                    if !literals.is_empty() {
                        filter_exprs.push(LlkvExpr::InList {
                            expr: ScalarExpr::Column(field_id),
                            list: literals.into_iter().map(ScalarExpr::Literal).collect(),
                            negated: false,
                        });
                    }
                }
            }
        }
    }

    let filter_expr = if filter_exprs.is_empty() {
        crate::translation::expression::full_table_scan_filter(filter_field_id)
    } else if filter_exprs.len() == 1 {
        filter_exprs.pop().unwrap()
    } else {
        LlkvExpr::And(filter_exprs)
    };

    let mut raw_batches = Vec::new();
    table.table.scan_stream(
        projections,
        &filter_expr,
        ScanStreamOptions {
            include_nulls: true,
            ..ScanStreamOptions::default()
        },
        |batch| {
            raw_batches.push(batch);
        },
    )?;

    // Rebuild each batch against the qualified-name schema so downstream
    // cross-product code sees consistent field names.
    let mut normalized_batches = Vec::with_capacity(raw_batches.len());
    for batch in raw_batches {
        let normalized = RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec())
            .map_err(|err| {
                Error::Internal(format!(
                    "failed to align scan batch for table '{}': {}",
                    table_ref.qualified_name(),
                    err
                ))
            })?;
        normalized_batches.push(normalized);
    }

    // Re-apply constraints batch-side; this also covers constraints whose
    // literals could not be pushed into the scan filter above.
    if !constraints.is_empty() {
        normalized_batches = apply_column_constraints_to_batches(normalized_batches, constraints)?;
    }

    Ok(TableCrossProductData {
        schema,
        batches: normalized_batches,
        column_counts: vec![table.schema.columns.len()],
        table_indices: vec![table_index],
    })
}
11445
11446fn apply_column_constraints_to_batches(
11447 batches: Vec<RecordBatch>,
11448 constraints: &[ColumnConstraint],
11449) -> ExecutorResult<Vec<RecordBatch>> {
11450 if batches.is_empty() {
11451 return Ok(batches);
11452 }
11453
11454 let mut filtered = batches;
11455 for constraint in constraints {
11456 match constraint {
11457 ColumnConstraint::Equality(lit) => {
11458 filtered = filter_batches_by_literal(filtered, lit.column.column, &lit.value)?;
11459 }
11460 ColumnConstraint::InList(in_list) => {
11461 filtered =
11462 filter_batches_by_in_list(filtered, in_list.column.column, &in_list.values)?;
11463 }
11464 }
11465 if filtered.is_empty() {
11466 break;
11467 }
11468 }
11469
11470 Ok(filtered)
11471}
11472
11473fn filter_batches_by_literal(
11474 batches: Vec<RecordBatch>,
11475 column_idx: usize,
11476 literal: &PlanValue,
11477) -> ExecutorResult<Vec<RecordBatch>> {
11478 let mut result = Vec::with_capacity(batches.len());
11479
11480 for batch in batches {
11481 if column_idx >= batch.num_columns() {
11482 return Err(Error::Internal(
11483 "literal constraint referenced invalid column index".into(),
11484 ));
11485 }
11486
11487 if batch.num_rows() == 0 {
11488 result.push(batch);
11489 continue;
11490 }
11491
11492 let column = batch.column(column_idx);
11493 let mut keep_rows: Vec<u32> = Vec::with_capacity(batch.num_rows());
11494
11495 for row_idx in 0..batch.num_rows() {
11496 if array_value_equals_plan_value(column.as_ref(), row_idx, literal)? {
11497 keep_rows.push(row_idx as u32);
11498 }
11499 }
11500
11501 if keep_rows.len() == batch.num_rows() {
11502 result.push(batch);
11503 continue;
11504 }
11505
11506 if keep_rows.is_empty() {
11507 continue;
11509 }
11510
11511 let indices = UInt32Array::from(keep_rows);
11512 let mut filtered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
11513 for col_idx in 0..batch.num_columns() {
11514 let filtered = take(batch.column(col_idx).as_ref(), &indices, None)
11515 .map_err(|err| Error::Internal(format!("failed to apply literal filter: {err}")))?;
11516 filtered_columns.push(filtered);
11517 }
11518
11519 let filtered_batch =
11520 RecordBatch::try_new(batch.schema(), filtered_columns).map_err(|err| {
11521 Error::Internal(format!(
11522 "failed to rebuild batch after literal filter: {err}"
11523 ))
11524 })?;
11525 result.push(filtered_batch);
11526 }
11527
11528 Ok(result)
11529}
11530
/// Keeps only the rows whose value in `column_idx` equals one of `values`
/// (SQL `IN (...)` semantics; NULL rows never match). An empty value list
/// matches nothing, so all batches are dropped.
fn filter_batches_by_in_list(
    batches: Vec<RecordBatch>,
    column_idx: usize,
    values: &[PlanValue],
) -> ExecutorResult<Vec<RecordBatch>> {
    use arrow::array::*;
    use arrow::compute::or;

    if values.is_empty() {
        return Ok(Vec::new());
    }

    let mut result = Vec::with_capacity(batches.len());

    for batch in batches {
        if column_idx >= batch.num_columns() {
            return Err(Error::Internal(
                "IN list constraint referenced invalid column index".into(),
            ));
        }

        if batch.num_rows() == 0 {
            result.push(batch);
            continue;
        }

        let column = batch.column(column_idx);

        // OR together one equality mask per candidate value.
        let mut mask = BooleanArray::from(vec![false; batch.num_rows()]);

        for value in values {
            let comparison_mask = build_comparison_mask(column.as_ref(), value)?;
            mask = or(&mask, &comparison_mask)
                .map_err(|err| Error::Internal(format!("failed to OR comparison masks: {err}")))?;
        }

        // Fast paths: keep the batch untouched, or drop it entirely.
        let true_count = mask.true_count();
        if true_count == batch.num_rows() {
            result.push(batch);
            continue;
        }

        if true_count == 0 {
            continue;
        }

        let filtered_batch = arrow::compute::filter_record_batch(&batch, &mask)
            .map_err(|err| Error::Internal(format!("failed to apply IN list filter: {err}")))?;

        result.push(filtered_batch);
    }

    Ok(result)
}
11591
11592fn build_comparison_mask(column: &dyn Array, value: &PlanValue) -> ExecutorResult<BooleanArray> {
11594 use arrow::array::*;
11595 use arrow::datatypes::DataType;
11596
11597 match value {
11598 PlanValue::Null => {
11599 let mut builder = BooleanBuilder::with_capacity(column.len());
11601 for i in 0..column.len() {
11602 builder.append_value(column.is_null(i));
11603 }
11604 Ok(builder.finish())
11605 }
11606 PlanValue::Integer(val) => {
11607 let mut builder = BooleanBuilder::with_capacity(column.len());
11608 match column.data_type() {
11609 DataType::Int8 => {
11610 let arr = column
11611 .as_any()
11612 .downcast_ref::<Int8Array>()
11613 .ok_or_else(|| Error::Internal("failed to downcast to Int8Array".into()))?;
11614 let target = *val as i8;
11615 for i in 0..arr.len() {
11616 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11617 }
11618 }
11619 DataType::Int16 => {
11620 let arr = column
11621 .as_any()
11622 .downcast_ref::<Int16Array>()
11623 .ok_or_else(|| {
11624 Error::Internal("failed to downcast to Int16Array".into())
11625 })?;
11626 let target = *val as i16;
11627 for i in 0..arr.len() {
11628 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11629 }
11630 }
11631 DataType::Int32 => {
11632 let arr = column
11633 .as_any()
11634 .downcast_ref::<Int32Array>()
11635 .ok_or_else(|| {
11636 Error::Internal("failed to downcast to Int32Array".into())
11637 })?;
11638 let target = *val as i32;
11639 for i in 0..arr.len() {
11640 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11641 }
11642 }
11643 DataType::Int64 => {
11644 let arr = column
11645 .as_any()
11646 .downcast_ref::<Int64Array>()
11647 .ok_or_else(|| {
11648 Error::Internal("failed to downcast to Int64Array".into())
11649 })?;
11650 for i in 0..arr.len() {
11651 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
11652 }
11653 }
11654 DataType::UInt8 => {
11655 let arr = column
11656 .as_any()
11657 .downcast_ref::<UInt8Array>()
11658 .ok_or_else(|| {
11659 Error::Internal("failed to downcast to UInt8Array".into())
11660 })?;
11661 let target = *val as u8;
11662 for i in 0..arr.len() {
11663 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11664 }
11665 }
11666 DataType::UInt16 => {
11667 let arr = column
11668 .as_any()
11669 .downcast_ref::<UInt16Array>()
11670 .ok_or_else(|| {
11671 Error::Internal("failed to downcast to UInt16Array".into())
11672 })?;
11673 let target = *val as u16;
11674 for i in 0..arr.len() {
11675 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11676 }
11677 }
11678 DataType::UInt32 => {
11679 let arr = column
11680 .as_any()
11681 .downcast_ref::<UInt32Array>()
11682 .ok_or_else(|| {
11683 Error::Internal("failed to downcast to UInt32Array".into())
11684 })?;
11685 let target = *val as u32;
11686 for i in 0..arr.len() {
11687 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11688 }
11689 }
11690 DataType::UInt64 => {
11691 let arr = column
11692 .as_any()
11693 .downcast_ref::<UInt64Array>()
11694 .ok_or_else(|| {
11695 Error::Internal("failed to downcast to UInt64Array".into())
11696 })?;
11697 let target = *val as u64;
11698 for i in 0..arr.len() {
11699 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11700 }
11701 }
11702 _ => {
11703 return Err(Error::Internal(format!(
11704 "unsupported integer type for IN list: {:?}",
11705 column.data_type()
11706 )));
11707 }
11708 }
11709 Ok(builder.finish())
11710 }
11711 PlanValue::Float(val) => {
11712 let mut builder = BooleanBuilder::with_capacity(column.len());
11713 match column.data_type() {
11714 DataType::Float32 => {
11715 let arr = column
11716 .as_any()
11717 .downcast_ref::<Float32Array>()
11718 .ok_or_else(|| {
11719 Error::Internal("failed to downcast to Float32Array".into())
11720 })?;
11721 let target = *val as f32;
11722 for i in 0..arr.len() {
11723 builder.append_value(!arr.is_null(i) && arr.value(i) == target);
11724 }
11725 }
11726 DataType::Float64 => {
11727 let arr = column
11728 .as_any()
11729 .downcast_ref::<Float64Array>()
11730 .ok_or_else(|| {
11731 Error::Internal("failed to downcast to Float64Array".into())
11732 })?;
11733 for i in 0..arr.len() {
11734 builder.append_value(!arr.is_null(i) && arr.value(i) == *val);
11735 }
11736 }
11737 _ => {
11738 return Err(Error::Internal(format!(
11739 "unsupported float type for IN list: {:?}",
11740 column.data_type()
11741 )));
11742 }
11743 }
11744 Ok(builder.finish())
11745 }
11746 PlanValue::Decimal(expected) => match column.data_type() {
11747 DataType::Decimal128(precision, scale) => {
11748 let arr = column
11749 .as_any()
11750 .downcast_ref::<Decimal128Array>()
11751 .ok_or_else(|| {
11752 Error::Internal("failed to downcast to Decimal128Array".into())
11753 })?;
11754 let expected_aligned = align_decimal_to_scale(*expected, *precision, *scale)
11755 .map_err(|err| {
11756 Error::InvalidArgumentError(format!(
11757 "decimal literal {expected} incompatible with DECIMAL({}, {}): {err}",
11758 precision, scale
11759 ))
11760 })?;
11761 let mut builder = BooleanBuilder::with_capacity(arr.len());
11762 for i in 0..arr.len() {
11763 if arr.is_null(i) {
11764 builder.append_value(false);
11765 } else {
11766 let actual = DecimalValue::new(arr.value(i), *scale).map_err(|err| {
11767 Error::InvalidArgumentError(format!(
11768 "invalid decimal value stored in column: {err}"
11769 ))
11770 })?;
11771 builder.append_value(actual.raw_value() == expected_aligned.raw_value());
11772 }
11773 }
11774 Ok(builder.finish())
11775 }
11776 DataType::Int8
11777 | DataType::Int16
11778 | DataType::Int32
11779 | DataType::Int64
11780 | DataType::UInt8
11781 | DataType::UInt16
11782 | DataType::UInt32
11783 | DataType::UInt64
11784 | DataType::Boolean => {
11785 if let Some(int_value) = decimal_exact_i64(*expected) {
11786 return build_comparison_mask(column, &PlanValue::Integer(int_value));
11787 }
11788 Ok(BooleanArray::from(vec![false; column.len()]))
11789 }
11790 DataType::Float32 | DataType::Float64 => {
11791 build_comparison_mask(column, &PlanValue::Float(expected.to_f64()))
11792 }
11793 _ => Err(Error::Internal(format!(
11794 "unsupported decimal type for IN list: {:?}",
11795 column.data_type()
11796 ))),
11797 },
11798 PlanValue::String(val) => {
11799 let mut builder = BooleanBuilder::with_capacity(column.len());
11800 let arr = column
11801 .as_any()
11802 .downcast_ref::<StringArray>()
11803 .ok_or_else(|| Error::Internal("failed to downcast to StringArray".into()))?;
11804 for i in 0..arr.len() {
11805 builder.append_value(!arr.is_null(i) && arr.value(i) == val.as_str());
11806 }
11807 Ok(builder.finish())
11808 }
11809 PlanValue::Date32(days) => {
11810 let mut builder = BooleanBuilder::with_capacity(column.len());
11811 match column.data_type() {
11812 DataType::Date32 => {
11813 let arr = column
11814 .as_any()
11815 .downcast_ref::<Date32Array>()
11816 .ok_or_else(|| {
11817 Error::Internal("failed to downcast to Date32Array".into())
11818 })?;
11819 for i in 0..arr.len() {
11820 builder.append_value(!arr.is_null(i) && arr.value(i) == *days);
11821 }
11822 }
11823 _ => {
11824 return Err(Error::Internal(format!(
11825 "unsupported DATE type for IN list: {:?}",
11826 column.data_type()
11827 )));
11828 }
11829 }
11830 Ok(builder.finish())
11831 }
11832 PlanValue::Interval(interval) => {
11833 let mut builder = BooleanBuilder::with_capacity(column.len());
11834 match column.data_type() {
11835 DataType::Interval(IntervalUnit::MonthDayNano) => {
11836 let arr = column
11837 .as_any()
11838 .downcast_ref::<IntervalMonthDayNanoArray>()
11839 .ok_or_else(|| {
11840 Error::Internal(
11841 "failed to downcast to IntervalMonthDayNanoArray".into(),
11842 )
11843 })?;
11844 let expected = *interval;
11845 for i in 0..arr.len() {
11846 if arr.is_null(i) {
11847 builder.append_value(false);
11848 } else {
11849 let candidate = interval_value_from_arrow(arr.value(i));
11850 let matches = compare_interval_values(expected, candidate)
11851 == std::cmp::Ordering::Equal;
11852 builder.append_value(matches);
11853 }
11854 }
11855 }
11856 _ => {
11857 return Err(Error::Internal(format!(
11858 "unsupported INTERVAL type for IN list: {:?}",
11859 column.data_type()
11860 )));
11861 }
11862 }
11863 Ok(builder.finish())
11864 }
11865 PlanValue::Struct(_) => Err(Error::Internal(
11866 "struct comparison in IN list not supported".into(),
11867 )),
11868 }
11869}
11870
11871fn array_value_equals_plan_value(
11872 array: &dyn Array,
11873 row_idx: usize,
11874 literal: &PlanValue,
11875) -> ExecutorResult<bool> {
11876 use arrow::array::*;
11877 use arrow::datatypes::DataType;
11878
11879 match literal {
11880 PlanValue::Null => Ok(array.is_null(row_idx)),
11881 PlanValue::Decimal(expected) => match array.data_type() {
11882 DataType::Decimal128(precision, scale) => {
11883 if array.is_null(row_idx) {
11884 return Ok(false);
11885 }
11886 let arr = array
11887 .as_any()
11888 .downcast_ref::<Decimal128Array>()
11889 .ok_or_else(|| {
11890 Error::Internal("failed to downcast to Decimal128Array".into())
11891 })?;
11892 let actual = DecimalValue::new(arr.value(row_idx), *scale).map_err(|err| {
11893 Error::InvalidArgumentError(format!(
11894 "invalid decimal value retrieved from column: {err}"
11895 ))
11896 })?;
11897 let expected_aligned = align_decimal_to_scale(*expected, *precision, *scale)
11898 .map_err(|err| {
11899 Error::InvalidArgumentError(format!(
11900 "failed to align decimal literal for comparison: {err}"
11901 ))
11902 })?;
11903 Ok(actual.raw_value() == expected_aligned.raw_value())
11904 }
11905 DataType::Int8
11906 | DataType::Int16
11907 | DataType::Int32
11908 | DataType::Int64
11909 | DataType::UInt8
11910 | DataType::UInt16
11911 | DataType::UInt32
11912 | DataType::UInt64 => {
11913 if array.is_null(row_idx) {
11914 return Ok(false);
11915 }
11916 if let Some(int_value) = decimal_exact_i64(*expected) {
11917 array_value_equals_plan_value(array, row_idx, &PlanValue::Integer(int_value))
11918 } else {
11919 Ok(false)
11920 }
11921 }
11922 DataType::Float32 | DataType::Float64 => {
11923 if array.is_null(row_idx) {
11924 return Ok(false);
11925 }
11926 array_value_equals_plan_value(array, row_idx, &PlanValue::Float(expected.to_f64()))
11927 }
11928 DataType::Boolean => {
11929 if array.is_null(row_idx) {
11930 return Ok(false);
11931 }
11932 if let Some(int_value) = decimal_exact_i64(*expected) {
11933 array_value_equals_plan_value(array, row_idx, &PlanValue::Integer(int_value))
11934 } else {
11935 Ok(false)
11936 }
11937 }
11938 _ => Err(Error::InvalidArgumentError(format!(
11939 "decimal literal comparison not supported for {:?}",
11940 array.data_type()
11941 ))),
11942 },
11943 PlanValue::Integer(expected) => match array.data_type() {
11944 DataType::Int8 => Ok(!array.is_null(row_idx)
11945 && array
11946 .as_any()
11947 .downcast_ref::<Int8Array>()
11948 .expect("int8 array")
11949 .value(row_idx) as i64
11950 == *expected),
11951 DataType::Int16 => Ok(!array.is_null(row_idx)
11952 && array
11953 .as_any()
11954 .downcast_ref::<Int16Array>()
11955 .expect("int16 array")
11956 .value(row_idx) as i64
11957 == *expected),
11958 DataType::Int32 => Ok(!array.is_null(row_idx)
11959 && array
11960 .as_any()
11961 .downcast_ref::<Int32Array>()
11962 .expect("int32 array")
11963 .value(row_idx) as i64
11964 == *expected),
11965 DataType::Int64 => Ok(!array.is_null(row_idx)
11966 && array
11967 .as_any()
11968 .downcast_ref::<Int64Array>()
11969 .expect("int64 array")
11970 .value(row_idx)
11971 == *expected),
11972 DataType::UInt8 if *expected >= 0 => Ok(!array.is_null(row_idx)
11973 && array
11974 .as_any()
11975 .downcast_ref::<UInt8Array>()
11976 .expect("uint8 array")
11977 .value(row_idx) as i64
11978 == *expected),
11979 DataType::UInt16 if *expected >= 0 => Ok(!array.is_null(row_idx)
11980 && array
11981 .as_any()
11982 .downcast_ref::<UInt16Array>()
11983 .expect("uint16 array")
11984 .value(row_idx) as i64
11985 == *expected),
11986 DataType::UInt32 if *expected >= 0 => Ok(!array.is_null(row_idx)
11987 && array
11988 .as_any()
11989 .downcast_ref::<UInt32Array>()
11990 .expect("uint32 array")
11991 .value(row_idx) as i64
11992 == *expected),
11993 DataType::UInt64 if *expected >= 0 => Ok(!array.is_null(row_idx)
11994 && array
11995 .as_any()
11996 .downcast_ref::<UInt64Array>()
11997 .expect("uint64 array")
11998 .value(row_idx)
11999 == *expected as u64),
12000 DataType::Boolean => {
12001 if array.is_null(row_idx) {
12002 Ok(false)
12003 } else if *expected == 0 || *expected == 1 {
12004 let value = array
12005 .as_any()
12006 .downcast_ref::<BooleanArray>()
12007 .expect("bool array")
12008 .value(row_idx);
12009 Ok(value == (*expected == 1))
12010 } else {
12011 Ok(false)
12012 }
12013 }
12014 _ => Err(Error::InvalidArgumentError(format!(
12015 "literal integer comparison not supported for {:?}",
12016 array.data_type()
12017 ))),
12018 },
12019 PlanValue::Float(expected) => match array.data_type() {
12020 DataType::Float32 => Ok(!array.is_null(row_idx)
12021 && (array
12022 .as_any()
12023 .downcast_ref::<Float32Array>()
12024 .expect("float32 array")
12025 .value(row_idx) as f64
12026 - *expected)
12027 .abs()
12028 .eq(&0.0)),
12029 DataType::Float64 => Ok(!array.is_null(row_idx)
12030 && (array
12031 .as_any()
12032 .downcast_ref::<Float64Array>()
12033 .expect("float64 array")
12034 .value(row_idx)
12035 - *expected)
12036 .abs()
12037 .eq(&0.0)),
12038 _ => Err(Error::InvalidArgumentError(format!(
12039 "literal float comparison not supported for {:?}",
12040 array.data_type()
12041 ))),
12042 },
12043 PlanValue::String(expected) => match array.data_type() {
12044 DataType::Utf8 => Ok(!array.is_null(row_idx)
12045 && array
12046 .as_any()
12047 .downcast_ref::<StringArray>()
12048 .expect("string array")
12049 .value(row_idx)
12050 == expected),
12051 DataType::LargeUtf8 => Ok(!array.is_null(row_idx)
12052 && array
12053 .as_any()
12054 .downcast_ref::<LargeStringArray>()
12055 .expect("large string array")
12056 .value(row_idx)
12057 == expected),
12058 _ => Err(Error::InvalidArgumentError(format!(
12059 "literal string comparison not supported for {:?}",
12060 array.data_type()
12061 ))),
12062 },
12063 PlanValue::Date32(expected) => match array.data_type() {
12064 DataType::Date32 => Ok(!array.is_null(row_idx)
12065 && array
12066 .as_any()
12067 .downcast_ref::<Date32Array>()
12068 .expect("date32 array")
12069 .value(row_idx)
12070 == *expected),
12071 _ => Err(Error::InvalidArgumentError(format!(
12072 "literal date comparison not supported for {:?}",
12073 array.data_type()
12074 ))),
12075 },
12076 PlanValue::Interval(expected) => {
12077 match array.data_type() {
12078 DataType::Interval(IntervalUnit::MonthDayNano) => {
12079 if array.is_null(row_idx) {
12080 Ok(false)
12081 } else {
12082 let value = array
12083 .as_any()
12084 .downcast_ref::<IntervalMonthDayNanoArray>()
12085 .expect("interval array")
12086 .value(row_idx);
12087 let arrow_value = interval_value_from_arrow(value);
12088 Ok(compare_interval_values(*expected, arrow_value)
12089 == std::cmp::Ordering::Equal)
12090 }
12091 }
12092 _ => Err(Error::InvalidArgumentError(format!(
12093 "literal interval comparison not supported for {:?}",
12094 array.data_type()
12095 ))),
12096 }
12097 }
12098 PlanValue::Struct(_) => Err(Error::InvalidArgumentError(
12099 "struct literals are not supported in join filters".into(),
12100 )),
12101 }
12102}
12103
12104fn hash_join_table_batches(
12105 left: TableCrossProductData,
12106 right: TableCrossProductData,
12107 join_keys: &[(usize, usize)],
12108 join_type: llkv_join::JoinType,
12109) -> ExecutorResult<TableCrossProductData> {
12110 let TableCrossProductData {
12111 schema: left_schema,
12112 batches: left_batches,
12113 column_counts: left_counts,
12114 table_indices: left_tables,
12115 } = left;
12116
12117 let TableCrossProductData {
12118 schema: right_schema,
12119 batches: right_batches,
12120 column_counts: right_counts,
12121 table_indices: right_tables,
12122 } = right;
12123
12124 let combined_fields: Vec<Field> = left_schema
12125 .fields()
12126 .iter()
12127 .chain(right_schema.fields().iter())
12128 .map(|field| field.as_ref().clone())
12129 .collect();
12130
12131 let combined_schema = Arc::new(Schema::new(combined_fields));
12132
12133 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
12134 column_counts.extend(left_counts.iter());
12135 column_counts.extend(right_counts.iter());
12136
12137 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
12138 table_indices.extend(left_tables.iter().copied());
12139 table_indices.extend(right_tables.iter().copied());
12140
12141 if left_batches.is_empty() {
12143 return Ok(TableCrossProductData {
12144 schema: combined_schema,
12145 batches: Vec::new(),
12146 column_counts,
12147 table_indices,
12148 });
12149 }
12150
12151 if right_batches.is_empty() {
12152 if join_type == llkv_join::JoinType::Left {
12154 let total_left_rows: usize = left_batches.iter().map(|b| b.num_rows()).sum();
12155 let mut left_arrays = Vec::new();
12156 for field in left_schema.fields() {
12157 let column_idx = left_schema.index_of(field.name()).map_err(|e| {
12158 Error::Internal(format!("failed to find field {}: {}", field.name(), e))
12159 })?;
12160 let arrays: Vec<ArrayRef> = left_batches
12161 .iter()
12162 .map(|batch| batch.column(column_idx).clone())
12163 .collect();
12164 let concatenated =
12165 arrow::compute::concat(&arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
12166 .map_err(|e| {
12167 Error::Internal(format!("failed to concat left arrays: {}", e))
12168 })?;
12169 left_arrays.push(concatenated);
12170 }
12171
12172 for field in right_schema.fields() {
12174 let null_array = arrow::array::new_null_array(field.data_type(), total_left_rows);
12175 left_arrays.push(null_array);
12176 }
12177
12178 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), left_arrays)
12179 .map_err(|err| {
12180 Error::Internal(format!(
12181 "failed to create LEFT JOIN batch with NULL right: {err}"
12182 ))
12183 })?;
12184
12185 return Ok(TableCrossProductData {
12186 schema: combined_schema,
12187 batches: vec![joined_batch],
12188 column_counts,
12189 table_indices,
12190 });
12191 } else {
12192 return Ok(TableCrossProductData {
12194 schema: combined_schema,
12195 batches: Vec::new(),
12196 column_counts,
12197 table_indices,
12198 });
12199 }
12200 }
12201
12202 match join_type {
12203 llkv_join::JoinType::Inner => {
12204 let (left_matches, right_matches) =
12205 build_join_match_indices(&left_batches, &right_batches, join_keys)?;
12206
12207 if left_matches.is_empty() {
12208 return Ok(TableCrossProductData {
12209 schema: combined_schema,
12210 batches: Vec::new(),
12211 column_counts,
12212 table_indices,
12213 });
12214 }
12215
12216 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
12217 let right_arrays = gather_indices_from_batches(&right_batches, &right_matches)?;
12218
12219 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
12220 combined_columns.extend(left_arrays);
12221 combined_columns.extend(right_arrays);
12222
12223 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
12224 .map_err(|err| {
12225 Error::Internal(format!("failed to materialize INNER JOIN batch: {err}"))
12226 })?;
12227
12228 Ok(TableCrossProductData {
12229 schema: combined_schema,
12230 batches: vec![joined_batch],
12231 column_counts,
12232 table_indices,
12233 })
12234 }
12235 llkv_join::JoinType::Left => {
12236 let (left_matches, right_optional_matches) =
12237 build_left_join_match_indices(&left_batches, &right_batches, join_keys)?;
12238
12239 if left_matches.is_empty() {
12240 return Ok(TableCrossProductData {
12242 schema: combined_schema,
12243 batches: Vec::new(),
12244 column_counts,
12245 table_indices,
12246 });
12247 }
12248
12249 let left_arrays = gather_indices_from_batches(&left_batches, &left_matches)?;
12250 let right_arrays = llkv_column_map::gather::gather_optional_indices_from_batches(
12252 &right_batches,
12253 &right_optional_matches,
12254 )?;
12255
12256 let mut combined_columns = Vec::with_capacity(left_arrays.len() + right_arrays.len());
12257 combined_columns.extend(left_arrays);
12258 combined_columns.extend(right_arrays);
12259
12260 let joined_batch = RecordBatch::try_new(Arc::clone(&combined_schema), combined_columns)
12261 .map_err(|err| {
12262 Error::Internal(format!("failed to materialize LEFT JOIN batch: {err}"))
12263 })?;
12264
12265 Ok(TableCrossProductData {
12266 schema: combined_schema,
12267 batches: vec![joined_batch],
12268 column_counts,
12269 table_indices,
12270 })
12271 }
12272 _ => Err(Error::Internal(format!(
12274 "join type {:?} not supported in hash_join_table_batches; use llkv-join",
12275 join_type
12276 ))),
12277 }
12278}
12279
/// `(batch_idx, row_idx)` positions collected for one side of a join.
type JoinMatchIndices = Vec<(usize, usize)>;
/// Encoded join-key bytes -> every `(batch_idx, row_idx)` holding that key.
type JoinHashTable = FxHashMap<Vec<u8>, Vec<(usize, usize)>>;
/// Paired `(left, right)` match positions produced by an inner-join probe.
type JoinMatchPairs = (JoinMatchIndices, JoinMatchIndices);
/// Right-side matches for a LEFT JOIN; `None` marks an unmatched left row.
type OptionalJoinMatches = Vec<Option<(usize, usize)>>;
/// Left positions plus optional right-side matches for a LEFT JOIN.
type LeftJoinMatchPairs = (JoinMatchIndices, OptionalJoinMatches);
12290
12291fn normalize_join_column(array: &ArrayRef) -> ExecutorResult<ArrayRef> {
12292 match array.data_type() {
12293 DataType::Boolean
12294 | DataType::Int8
12295 | DataType::Int16
12296 | DataType::Int32
12297 | DataType::UInt8
12298 | DataType::UInt16
12299 | DataType::UInt32
12300 | DataType::UInt64 => cast(array, &DataType::Int64)
12301 .map_err(|e| Error::Internal(format!("failed to cast integer/boolean to Int64: {e}"))),
12302 DataType::Float32 => cast(array, &DataType::Float64)
12303 .map_err(|e| Error::Internal(format!("failed to cast Float32 to Float64: {e}"))),
12304 DataType::Utf8 | DataType::LargeUtf8 => cast(array, &DataType::LargeUtf8)
12305 .map_err(|e| Error::Internal(format!("failed to cast Utf8 to LargeUtf8: {e}"))),
12306 DataType::Dictionary(_, value_type) => {
12307 let unpacked = cast(array, value_type)
12308 .map_err(|e| Error::Internal(format!("failed to unpack dictionary: {e}")))?;
12309 normalize_join_column(&unpacked)
12310 }
12311 _ => Ok(array.clone()),
12312 }
12313}
12314
12315fn build_join_match_indices(
12345 left_batches: &[RecordBatch],
12346 right_batches: &[RecordBatch],
12347 join_keys: &[(usize, usize)],
12348) -> ExecutorResult<JoinMatchPairs> {
12349 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
12350
12351 let hash_table: JoinHashTable = llkv_column_map::parallel::with_thread_pool(|| {
12354 let local_tables: Vec<ExecutorResult<JoinHashTable>> = right_batches
12355 .par_iter()
12356 .enumerate()
12357 .map(|(batch_idx, batch)| {
12358 let mut local_table: JoinHashTable = FxHashMap::default();
12359
12360 let columns: Vec<ArrayRef> = right_key_indices
12361 .iter()
12362 .map(|&idx| normalize_join_column(batch.column(idx)))
12363 .collect::<ExecutorResult<Vec<_>>>()?;
12364
12365 let sort_fields: Vec<SortField> = columns
12366 .iter()
12367 .map(|c| SortField::new(c.data_type().clone()))
12368 .collect();
12369
12370 let converter = RowConverter::new(sort_fields)
12371 .map_err(|e| Error::Internal(format!("failed to create RowConverter: {e}")))?;
12372 let rows = converter.convert_columns(&columns).map_err(|e| {
12373 Error::Internal(format!("failed to convert columns to rows: {e}"))
12374 })?;
12375
12376 for (row_idx, row) in rows.iter().enumerate() {
12377 if columns.iter().any(|c| c.is_null(row_idx)) {
12379 continue;
12380 }
12381
12382 local_table
12383 .entry(row.as_ref().to_vec())
12384 .or_default()
12385 .push((batch_idx, row_idx));
12386 }
12387
12388 Ok(local_table)
12389 })
12390 .collect();
12391
12392 let mut merged_table: JoinHashTable = FxHashMap::default();
12394 for local_table_res in local_tables {
12395 if let Ok(local_table) = local_table_res {
12396 for (key, mut positions) in local_table {
12397 merged_table.entry(key).or_default().append(&mut positions);
12398 }
12399 } else {
12400 tracing::error!("failed to build hash table for batch");
12401 }
12402 }
12403
12404 merged_table
12405 });
12406
12407 if hash_table.is_empty() {
12408 return Ok((Vec::new(), Vec::new()));
12409 }
12410
12411 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
12412
12413 let matches: Vec<ExecutorResult<JoinMatchPairs>> =
12416 llkv_column_map::parallel::with_thread_pool(|| {
12417 left_batches
12418 .par_iter()
12419 .enumerate()
12420 .map(|(batch_idx, batch)| {
12421 let mut local_left_matches: JoinMatchIndices = Vec::new();
12422 let mut local_right_matches: JoinMatchIndices = Vec::new();
12423
12424 let columns: Vec<ArrayRef> = left_key_indices
12425 .iter()
12426 .map(|&idx| normalize_join_column(batch.column(idx)))
12427 .collect::<ExecutorResult<Vec<_>>>()?;
12428
12429 let sort_fields: Vec<SortField> = columns
12430 .iter()
12431 .map(|c| SortField::new(c.data_type().clone()))
12432 .collect();
12433
12434 let converter = RowConverter::new(sort_fields).map_err(|e| {
12435 Error::Internal(format!("failed to create RowConverter: {e}"))
12436 })?;
12437 let rows = converter.convert_columns(&columns).map_err(|e| {
12438 Error::Internal(format!("failed to convert columns to rows: {e}"))
12439 })?;
12440
12441 for (row_idx, row) in rows.iter().enumerate() {
12442 if columns.iter().any(|c| c.is_null(row_idx)) {
12443 continue;
12444 }
12445
12446 if let Some(positions) = hash_table.get(row.as_ref()) {
12447 for &(r_batch_idx, r_row_idx) in positions {
12448 local_left_matches.push((batch_idx, row_idx));
12449 local_right_matches.push((r_batch_idx, r_row_idx));
12450 }
12451 }
12452 }
12453
12454 Ok((local_left_matches, local_right_matches))
12455 })
12456 .collect()
12457 });
12458
12459 let mut left_matches: JoinMatchIndices = Vec::new();
12461 let mut right_matches: JoinMatchIndices = Vec::new();
12462 for match_res in matches {
12463 let (mut left, mut right) = match_res?;
12464 left_matches.append(&mut left);
12465 right_matches.append(&mut right);
12466 }
12467
12468 Ok((left_matches, right_matches))
12469}
12470
12471fn build_left_join_match_indices(
12482 left_batches: &[RecordBatch],
12483 right_batches: &[RecordBatch],
12484 join_keys: &[(usize, usize)],
12485) -> ExecutorResult<LeftJoinMatchPairs> {
12486 let right_key_indices: Vec<usize> = join_keys.iter().map(|(_, right)| *right).collect();
12487
12488 let hash_table: JoinHashTable = llkv_column_map::parallel::with_thread_pool(|| {
12490 let local_tables: Vec<JoinHashTable> = right_batches
12491 .par_iter()
12492 .enumerate()
12493 .map(|(batch_idx, batch)| {
12494 let mut local_table: JoinHashTable = FxHashMap::default();
12495 let mut key_buffer: Vec<u8> = Vec::new();
12496
12497 for row_idx in 0..batch.num_rows() {
12498 key_buffer.clear();
12499 match build_join_key(batch, &right_key_indices, row_idx, &mut key_buffer) {
12500 Ok(true) => {
12501 local_table
12502 .entry(key_buffer.clone())
12503 .or_default()
12504 .push((batch_idx, row_idx));
12505 }
12506 Ok(false) => continue,
12507 Err(_) => continue,
12508 }
12509 }
12510
12511 local_table
12512 })
12513 .collect();
12514
12515 let mut merged_table: JoinHashTable = FxHashMap::default();
12516 for local_table in local_tables {
12517 for (key, mut positions) in local_table {
12518 merged_table.entry(key).or_default().append(&mut positions);
12519 }
12520 }
12521
12522 merged_table
12523 });
12524
12525 let left_key_indices: Vec<usize> = join_keys.iter().map(|(left, _)| *left).collect();
12526
12527 let matches: Vec<LeftJoinMatchPairs> = llkv_column_map::parallel::with_thread_pool(|| {
12529 left_batches
12530 .par_iter()
12531 .enumerate()
12532 .map(|(batch_idx, batch)| {
12533 let mut local_left_matches: JoinMatchIndices = Vec::new();
12534 let mut local_right_optional: Vec<Option<(usize, usize)>> = Vec::new();
12535 let mut key_buffer: Vec<u8> = Vec::new();
12536
12537 for row_idx in 0..batch.num_rows() {
12538 key_buffer.clear();
12539 match build_join_key(batch, &left_key_indices, row_idx, &mut key_buffer) {
12540 Ok(true) => {
12541 if let Some(entries) = hash_table.get(&key_buffer) {
12542 for &(r_batch, r_row) in entries {
12544 local_left_matches.push((batch_idx, row_idx));
12545 local_right_optional.push(Some((r_batch, r_row)));
12546 }
12547 } else {
12548 local_left_matches.push((batch_idx, row_idx));
12550 local_right_optional.push(None);
12551 }
12552 }
12553 Ok(false) => {
12554 local_left_matches.push((batch_idx, row_idx));
12556 local_right_optional.push(None);
12557 }
12558 Err(_) => {
12559 local_left_matches.push((batch_idx, row_idx));
12561 local_right_optional.push(None);
12562 }
12563 }
12564 }
12565
12566 (local_left_matches, local_right_optional)
12567 })
12568 .collect()
12569 });
12570
12571 let mut left_matches: JoinMatchIndices = Vec::new();
12573 let mut right_optional: Vec<Option<(usize, usize)>> = Vec::new();
12574 for (mut left, mut right) in matches {
12575 left_matches.append(&mut left);
12576 right_optional.append(&mut right);
12577 }
12578
12579 Ok((left_matches, right_optional))
12580}
12581
12582fn build_join_key(
12583 batch: &RecordBatch,
12584 column_indices: &[usize],
12585 row_idx: usize,
12586 buffer: &mut Vec<u8>,
12587) -> ExecutorResult<bool> {
12588 buffer.clear();
12589
12590 for &col_idx in column_indices {
12591 let array = batch.column(col_idx);
12592 if array.is_null(row_idx) {
12593 return Ok(false);
12594 }
12595 append_array_value_to_key(array.as_ref(), row_idx, buffer)?;
12596 }
12597
12598 Ok(true)
12599}
12600
12601fn append_array_value_to_key(
12602 array: &dyn Array,
12603 row_idx: usize,
12604 buffer: &mut Vec<u8>,
12605) -> ExecutorResult<()> {
12606 use arrow::array::*;
12607 use arrow::datatypes::DataType;
12608
12609 match array.data_type() {
12610 DataType::Int8 => buffer.extend_from_slice(
12611 &array
12612 .as_any()
12613 .downcast_ref::<Int8Array>()
12614 .expect("int8 array")
12615 .value(row_idx)
12616 .to_le_bytes(),
12617 ),
12618 DataType::Int16 => buffer.extend_from_slice(
12619 &array
12620 .as_any()
12621 .downcast_ref::<Int16Array>()
12622 .expect("int16 array")
12623 .value(row_idx)
12624 .to_le_bytes(),
12625 ),
12626 DataType::Int32 => buffer.extend_from_slice(
12627 &array
12628 .as_any()
12629 .downcast_ref::<Int32Array>()
12630 .expect("int32 array")
12631 .value(row_idx)
12632 .to_le_bytes(),
12633 ),
12634 DataType::Int64 => buffer.extend_from_slice(
12635 &array
12636 .as_any()
12637 .downcast_ref::<Int64Array>()
12638 .expect("int64 array")
12639 .value(row_idx)
12640 .to_le_bytes(),
12641 ),
12642 DataType::UInt8 => buffer.extend_from_slice(
12643 &array
12644 .as_any()
12645 .downcast_ref::<UInt8Array>()
12646 .expect("uint8 array")
12647 .value(row_idx)
12648 .to_le_bytes(),
12649 ),
12650 DataType::UInt16 => buffer.extend_from_slice(
12651 &array
12652 .as_any()
12653 .downcast_ref::<UInt16Array>()
12654 .expect("uint16 array")
12655 .value(row_idx)
12656 .to_le_bytes(),
12657 ),
12658 DataType::UInt32 => buffer.extend_from_slice(
12659 &array
12660 .as_any()
12661 .downcast_ref::<UInt32Array>()
12662 .expect("uint32 array")
12663 .value(row_idx)
12664 .to_le_bytes(),
12665 ),
12666 DataType::UInt64 => buffer.extend_from_slice(
12667 &array
12668 .as_any()
12669 .downcast_ref::<UInt64Array>()
12670 .expect("uint64 array")
12671 .value(row_idx)
12672 .to_le_bytes(),
12673 ),
12674 DataType::Float32 => buffer.extend_from_slice(
12675 &array
12676 .as_any()
12677 .downcast_ref::<Float32Array>()
12678 .expect("float32 array")
12679 .value(row_idx)
12680 .to_le_bytes(),
12681 ),
12682 DataType::Float64 => buffer.extend_from_slice(
12683 &array
12684 .as_any()
12685 .downcast_ref::<Float64Array>()
12686 .expect("float64 array")
12687 .value(row_idx)
12688 .to_le_bytes(),
12689 ),
12690 DataType::Boolean => buffer.push(
12691 array
12692 .as_any()
12693 .downcast_ref::<BooleanArray>()
12694 .expect("bool array")
12695 .value(row_idx) as u8,
12696 ),
12697 DataType::Utf8 => {
12698 let value = array
12699 .as_any()
12700 .downcast_ref::<StringArray>()
12701 .expect("utf8 array")
12702 .value(row_idx);
12703 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
12704 buffer.extend_from_slice(value.as_bytes());
12705 }
12706 DataType::LargeUtf8 => {
12707 let value = array
12708 .as_any()
12709 .downcast_ref::<LargeStringArray>()
12710 .expect("large utf8 array")
12711 .value(row_idx);
12712 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
12713 buffer.extend_from_slice(value.as_bytes());
12714 }
12715 DataType::Binary => {
12716 let value = array
12717 .as_any()
12718 .downcast_ref::<BinaryArray>()
12719 .expect("binary array")
12720 .value(row_idx);
12721 buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
12722 buffer.extend_from_slice(value);
12723 }
12724 other => {
12725 return Err(Error::InvalidArgumentError(format!(
12726 "hash join does not support join key type {:?}",
12727 other
12728 )));
12729 }
12730 }
12731
12732 Ok(())
12733}
12734
12735fn table_has_join_with_used(
12736 candidate: usize,
12737 used_tables: &FxHashSet<usize>,
12738 equalities: &[ColumnEquality],
12739) -> bool {
12740 equalities.iter().any(|equality| {
12741 (equality.left.table == candidate && used_tables.contains(&equality.right.table))
12742 || (equality.right.table == candidate && used_tables.contains(&equality.left.table))
12743 })
12744}
12745
12746fn gather_join_keys(
12747 left: &TableCrossProductData,
12748 right: &TableCrossProductData,
12749 used_tables: &FxHashSet<usize>,
12750 right_table_index: usize,
12751 equalities: &[ColumnEquality],
12752) -> ExecutorResult<Vec<(usize, usize)>> {
12753 let mut keys = Vec::new();
12754
12755 for equality in equalities {
12756 if equality.left.table == right_table_index && used_tables.contains(&equality.right.table) {
12757 let left_idx = resolve_column_index(left, &equality.right).ok_or_else(|| {
12758 Error::Internal("failed to resolve column offset for hash join".into())
12759 })?;
12760 let right_idx = resolve_column_index(right, &equality.left).ok_or_else(|| {
12761 Error::Internal("failed to resolve column offset for hash join".into())
12762 })?;
12763 keys.push((left_idx, right_idx));
12764 } else if equality.right.table == right_table_index
12765 && used_tables.contains(&equality.left.table)
12766 {
12767 let left_idx = resolve_column_index(left, &equality.left).ok_or_else(|| {
12768 Error::Internal("failed to resolve column offset for hash join".into())
12769 })?;
12770 let right_idx = resolve_column_index(right, &equality.right).ok_or_else(|| {
12771 Error::Internal("failed to resolve column offset for hash join".into())
12772 })?;
12773 keys.push((left_idx, right_idx));
12774 }
12775 }
12776
12777 Ok(keys)
12778}
12779
12780fn resolve_column_index(data: &TableCrossProductData, column: &ColumnRef) -> Option<usize> {
12781 let mut offset = 0;
12782 for (table_idx, count) in data.table_indices.iter().zip(data.column_counts.iter()) {
12783 if *table_idx == column.table {
12784 if column.column < *count {
12785 return Some(offset + column.column);
12786 } else {
12787 return None;
12788 }
12789 }
12790 offset += count;
12791 }
12792 None
12793}
12794
/// Build a lookup from every unambiguous column spelling to its flat index
/// in the cross-product schema.
///
/// Candidate keys (all lower-cased) per field: the full field name, the name
/// without a leading '.', `alias.column` / `schema.alias.column`,
/// `table.column` / `schema.table.column` (only when the base table is
/// referenced unambiguously), the bare column name, and the trailing
/// `table.column` suffix of the field name (only when that suffix is unique).
/// Insertion uses `entry(..).or_insert`, so on collisions the first
/// (left-most) field wins.
fn build_cross_product_column_lookup(
    schema: &Schema,
    tables: &[llkv_plan::TableRef],
    column_counts: &[usize],
    table_indices: &[usize],
) -> FxHashMap<String, usize> {
    debug_assert_eq!(tables.len(), column_counts.len());
    debug_assert_eq!(column_counts.len(), table_indices.len());

    // Count occurrences of bare column names and `table.column` suffixes so
    // that only unique spellings get mapped.
    let mut column_occurrences: FxHashMap<String, usize> = FxHashMap::default();
    let mut table_column_counts: FxHashMap<String, usize> = FxHashMap::default();
    for field in schema.fields() {
        let column_name = extract_column_name(field.name());
        *column_occurrences.entry(column_name).or_insert(0) += 1;
        if let Some(pair) = table_column_suffix(field.name()) {
            *table_column_counts.entry(pair).or_insert(0) += 1;
        }
    }

    // Count references per base table (and how many are unaliased) to decide
    // whether bare `table.column` spellings are unambiguous.
    let mut base_table_totals: FxHashMap<String, usize> = FxHashMap::default();
    let mut base_table_unaliased: FxHashMap<String, usize> = FxHashMap::default();
    for table_ref in tables {
        let key = base_table_key(table_ref);
        *base_table_totals.entry(key.clone()).or_insert(0) += 1;
        if table_ref.alias.is_none() {
            *base_table_unaliased.entry(key).or_insert(0) += 1;
        }
    }

    let mut lookup = FxHashMap::default();

    // Degenerate case: no per-table bookkeeping available. Fall back to
    // field-name-only keys, admitting bare/suffix spellings only when unique.
    if table_indices.is_empty() || column_counts.is_empty() {
        for (idx, field) in schema.fields().iter().enumerate() {
            let field_name_lower = field.name().to_ascii_lowercase();
            lookup.entry(field_name_lower).or_insert(idx);

            let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
            lookup.entry(trimmed_lower).or_insert(idx);

            if let Some(pair) = table_column_suffix(field.name())
                && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
            {
                lookup.entry(pair).or_insert(idx);
            }

            let column_name = extract_column_name(field.name());
            if column_occurrences.get(&column_name).copied().unwrap_or(0) == 1 {
                lookup.entry(column_name).or_insert(idx);
            }
        }
        return lookup;
    }

    // Walk each table's contiguous span of columns in schema order.
    let mut offset = 0usize;
    for (&table_idx, &count) in table_indices.iter().zip(column_counts.iter()) {
        if table_idx >= tables.len() {
            continue;
        }
        let table_ref = &tables[table_idx];
        let alias_lower = table_ref
            .alias
            .as_ref()
            .map(|alias| alias.to_ascii_lowercase());
        let table_lower = table_ref.table.to_ascii_lowercase();
        let schema_lower = table_ref.schema.to_ascii_lowercase();
        let base_key = base_table_key(table_ref);
        let total_refs = base_table_totals.get(&base_key).copied().unwrap_or(0);
        let unaliased_refs = base_table_unaliased.get(&base_key).copied().unwrap_or(0);

        // Bare `table.column` spellings are only safe when they cannot refer
        // to more than one occurrence of the base table: either this is the
        // single unaliased reference, or the single (aliased) reference.
        let allow_base_mapping = if table_ref.alias.is_none() {
            unaliased_refs == 1
        } else {
            unaliased_refs == 0 && total_refs == 1
        };

        let mut table_keys: Vec<String> = Vec::new();

        if let Some(alias) = &alias_lower {
            table_keys.push(alias.clone());
            if !schema_lower.is_empty() {
                table_keys.push(format!("{}.{}", schema_lower, alias));
            }
        }

        if allow_base_mapping {
            table_keys.push(table_lower.clone());
            if !schema_lower.is_empty() {
                table_keys.push(format!("{}.{}", schema_lower, table_lower));
            }
        }

        for local_idx in 0..count {
            let field_index = offset + local_idx;
            let field = schema.field(field_index);
            let field_name_lower = field.name().to_ascii_lowercase();
            lookup.entry(field_name_lower).or_insert(field_index);

            let trimmed_lower = field.name().trim_start_matches('.').to_ascii_lowercase();
            lookup.entry(trimmed_lower).or_insert(field_index);

            let column_name = extract_column_name(field.name());
            for table_key in &table_keys {
                lookup
                    .entry(format!("{}.{}", table_key, column_name))
                    .or_insert(field_index);
            }

            // Bare column name is mapped unconditionally here (unlike the
            // degenerate path above); `or_insert` means the left-most field
            // with this name wins on duplicates.
            lookup.entry(column_name.clone()).or_insert(field_index);

            if table_keys.is_empty()
                && let Some(pair) = table_column_suffix(field.name())
                && table_column_counts.get(&pair).copied().unwrap_or(0) == 1
            {
                lookup.entry(pair).or_insert(field_index);
            }
        }

        offset = offset.saturating_add(count);
    }

    lookup
}
12920
12921fn base_table_key(table_ref: &llkv_plan::TableRef) -> String {
12922 let schema_lower = table_ref.schema.to_ascii_lowercase();
12923 let table_lower = table_ref.table.to_ascii_lowercase();
12924 if schema_lower.is_empty() {
12925 table_lower
12926 } else {
12927 format!("{}.{}", schema_lower, table_lower)
12928 }
12929}
12930
/// Lower-cased final path segment of a (possibly dot-qualified) field name,
/// after stripping any leading dots.
fn extract_column_name(name: &str) -> String {
    let trimmed = name.trim_start_matches('.');
    let last_segment = trimmed.split('.').next_back().unwrap_or(trimmed);
    last_segment.to_ascii_lowercase()
}
12938
/// Lower-cased trailing `table.column` suffix of a dot-qualified field name,
/// or `None` when the name (after stripping leading dots) has fewer than two
/// segments.
fn table_column_suffix(name: &str) -> Option<String> {
    let trimmed = name.trim_start_matches('.');
    let (head, column) = trimmed.rsplit_once('.')?;
    // The table segment is the last component of whatever precedes the column.
    let table = head.rsplit('.').next()?;
    Some(format!(
        "{}.{}",
        table.to_ascii_lowercase(),
        column.to_ascii_lowercase()
    ))
}
12949
12950fn cross_join_table_batches(
12975 left: TableCrossProductData,
12976 right: TableCrossProductData,
12977) -> ExecutorResult<TableCrossProductData> {
12978 let TableCrossProductData {
12979 schema: left_schema,
12980 batches: left_batches,
12981 column_counts: mut left_counts,
12982 table_indices: mut left_tables,
12983 } = left;
12984 let TableCrossProductData {
12985 schema: right_schema,
12986 batches: right_batches,
12987 column_counts: right_counts,
12988 table_indices: right_tables,
12989 } = right;
12990
12991 let combined_fields: Vec<Field> = left_schema
12992 .fields()
12993 .iter()
12994 .chain(right_schema.fields().iter())
12995 .map(|field| field.as_ref().clone())
12996 .collect();
12997
12998 let mut column_counts = Vec::with_capacity(left_counts.len() + right_counts.len());
12999 column_counts.append(&mut left_counts);
13000 column_counts.extend(right_counts);
13001
13002 let mut table_indices = Vec::with_capacity(left_tables.len() + right_tables.len());
13003 table_indices.append(&mut left_tables);
13004 table_indices.extend(right_tables);
13005
13006 let combined_schema = Arc::new(Schema::new(combined_fields));
13007
13008 let left_has_rows = left_batches.iter().any(|batch| batch.num_rows() > 0);
13009 let right_has_rows = right_batches.iter().any(|batch| batch.num_rows() > 0);
13010
13011 if !left_has_rows || !right_has_rows {
13012 return Ok(TableCrossProductData {
13013 schema: combined_schema,
13014 batches: Vec::new(),
13015 column_counts,
13016 table_indices,
13017 });
13018 }
13019
13020 let output_batches: Vec<RecordBatch> = llkv_column_map::parallel::with_thread_pool(|| {
13023 left_batches
13024 .par_iter()
13025 .filter(|left_batch| left_batch.num_rows() > 0)
13026 .flat_map(|left_batch| {
13027 right_batches
13028 .par_iter()
13029 .filter(|right_batch| right_batch.num_rows() > 0)
13030 .filter_map(|right_batch| {
13031 cross_join_pair(left_batch, right_batch, &combined_schema).ok()
13032 })
13033 .collect::<Vec<_>>()
13034 })
13035 .collect()
13036 });
13037
13038 Ok(TableCrossProductData {
13039 schema: combined_schema,
13040 batches: output_batches,
13041 column_counts,
13042 table_indices,
13043 })
13044}
13045
13046fn cross_join_all(staged: Vec<TableCrossProductData>) -> ExecutorResult<TableCrossProductData> {
13047 let mut iter = staged.into_iter();
13048 let mut current = iter
13049 .next()
13050 .ok_or_else(|| Error::Internal("cross product preparation yielded no tables".into()))?;
13051 for next in iter {
13052 current = cross_join_table_batches(current, next)?;
13053 }
13054 Ok(current)
13055}
13056
/// Per-table metadata used while resolving column references during join
/// constraint extraction.
struct TableInfo<'a> {
    // Position of this table in the staged cross-product ordering.
    index: usize,
    // Plan-level table reference (schema, name, optional alias).
    table_ref: &'a llkv_plan::TableRef,
    // Lower-cased column name -> column index within the table's schema.
    column_map: FxHashMap<String, usize>,
}
13062
/// Identifies a column by table position and column position within that table.
#[derive(Clone, Copy)]
struct ColumnRef {
    // Index of the owning table in the staged table list.
    table: usize,
    // Index of the column within that table's schema.
    column: usize,
}
13068
/// An equi-join constraint between two columns (`left = right`).
#[derive(Clone, Copy)]
struct ColumnEquality {
    left: ColumnRef,
    right: ColumnRef,
}
13074
/// A `column = literal` constraint usable for single-table pushdown.
#[derive(Clone)]
struct ColumnLiteral {
    column: ColumnRef,
    value: PlanValue,
}
13080
/// A `column IN (v1, v2, ...)` constraint usable for single-table pushdown.
#[derive(Clone)]
struct ColumnInList {
    column: ColumnRef,
    values: Vec<PlanValue>,
}
13086
/// A single-table filter constraint extracted from a join predicate.
#[derive(Clone)]
enum ColumnConstraint {
    /// Column must equal one literal value.
    Equality(ColumnLiteral),
    /// Column must equal one of several literal values.
    InList(ColumnInList),
}
13092
/// Decomposition of a join predicate into pieces the executor can handle with
/// specialized join strategies.
struct JoinConstraintPlan {
    // Column-to-column equality conjuncts (hash-join candidates).
    equalities: Vec<ColumnEquality>,
    // Column-to-literal / IN-list conjuncts (single-table pushdown candidates).
    literals: Vec<ColumnConstraint>,
    // True when a literal FALSE conjunct makes the predicate unsatisfiable.
    unsatisfiable: bool,
    // Total number of top-level conjuncts found in the predicate.
    total_conjuncts: usize,
    // Conjuncts captured above; when smaller than `total_conjuncts`, a
    // residual filter must still be applied to the joined rows.
    handled_conjuncts: usize,
}
13103
/// Scans a join predicate for `column = literal` and `column IN (...)`
/// conjuncts and groups them by table, so each table scan can apply them
/// before the cross product is formed.
///
/// Returns one constraint list per entry in `tables_with_handles`; tables
/// without extractable constraints get an empty list. Conjuncts that do not
/// match a supported shape are skipped here — presumably the full predicate
/// is still evaluated on the joined rows, making this a pure optimization;
/// TODO confirm against callers.
fn extract_literal_pushdown_filters<P>(
    expr: &LlkvExpr<'static, String>,
    tables_with_handles: &[(llkv_plan::TableRef, Arc<ExecutorTable<P>>)],
) -> Vec<Vec<ColumnConstraint>>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    // Build a (lower-cased column name -> column index) lookup per table.
    let mut table_infos = Vec::with_capacity(tables_with_handles.len());
    for (index, (table_ref, executor_table)) in tables_with_handles.iter().enumerate() {
        let mut column_map = FxHashMap::default();
        for (column_idx, column) in executor_table.schema.columns.iter().enumerate() {
            let column_name = column.name.to_ascii_lowercase();
            // `or_insert` keeps the first index when names collide.
            column_map.entry(column_name).or_insert(column_idx);
        }
        table_infos.push(TableInfo {
            index,
            table_ref,
            column_map,
        });
    }

    let mut constraints: Vec<Vec<ColumnConstraint>> = vec![Vec::new(); tables_with_handles.len()];

    // Flatten nested ANDs into a list of top-level conjuncts.
    let mut conjuncts = Vec::new();
    collect_conjuncts_lenient(expr, &mut conjuncts);

    for conjunct in conjuncts {
        // Shape 1: `<column> = <literal>` (either operand order).
        if let LlkvExpr::Compare {
            left,
            op: CompareOp::Eq,
            right,
        } = conjunct
        {
            match (
                resolve_column_reference(left, &table_infos),
                resolve_column_reference(right, &table_infos),
            ) {
                (Some(column), None) => {
                    if let Some(literal) = extract_literal(right)
                        && let Some(value) = PlanValue::from_literal_for_join(literal)
                        && column.table < constraints.len()
                    {
                        constraints[column.table]
                            .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                    }
                }
                (None, Some(column)) => {
                    if let Some(literal) = extract_literal(left)
                        && let Some(value) = PlanValue::from_literal_for_join(literal)
                        && column.table < constraints.len()
                    {
                        constraints[column.table]
                            .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                    }
                }
                // Column-to-column or literal-to-literal: not a pushdown.
                _ => {}
            }
        }
        // Shape 2: pre-lowered predicate `field = literal`.
        else if let LlkvExpr::Pred(filter) = conjunct {
            if let Operator::Equals(ref literal_val) = filter.op {
                let field_name = filter.field_id.trim().to_ascii_lowercase();

                // Attribute the filter to the first table exposing the column;
                // `break` stops after that table. NOTE(review): with duplicate
                // unqualified column names this picks the first table, which
                // matches the behavior of `resolve_column_reference`.
                for info in &table_infos {
                    if let Some(&col_idx) = info.column_map.get(&field_name) {
                        if let Some(value) = PlanValue::from_operator_literal(literal_val) {
                            let column_ref = ColumnRef {
                                table: info.index,
                                column: col_idx,
                            };
                            if info.index < constraints.len() {
                                constraints[info.index].push(ColumnConstraint::Equality(
                                    ColumnLiteral {
                                        column: column_ref,
                                        value,
                                    },
                                ));
                            }
                        }
                        break;
                    }
                }
            }
        }
        // Shape 3: `<column> IN (<literals>)`.
        else if let LlkvExpr::InList {
            expr: col_expr,
            list,
            negated: false,
        } = conjunct
        {
            if let Some(column) = resolve_column_reference(col_expr, &table_infos) {
                let mut values = Vec::new();
                for item in list {
                    if let Some(literal) = extract_literal(item)
                        && let Some(value) = PlanValue::from_literal_for_join(literal)
                    {
                        values.push(value);
                    }
                }
                // NOTE(review): unconvertible list items are silently dropped,
                // narrowing the pushed-down IN list — safe only if the full
                // predicate is re-checked later; verify.
                if !values.is_empty() && column.table < constraints.len() {
                    constraints[column.table]
                        .push(ColumnConstraint::InList(ColumnInList { column, values }));
                }
            }
        }
        // Shape 4: an OR chain equivalent to an IN list on a single column.
        else if let LlkvExpr::Or(or_children) = conjunct
            && let Some((column, values)) = try_extract_or_as_in_list(or_children, &table_infos)
            && !values.is_empty()
            && column.table < constraints.len()
        {
            constraints[column.table]
                .push(ColumnConstraint::InList(ColumnInList { column, values }));
        }
    }

    constraints
}
13246
13247fn collect_conjuncts_lenient<'a>(
13252 expr: &'a LlkvExpr<'static, String>,
13253 out: &mut Vec<&'a LlkvExpr<'static, String>>,
13254) {
13255 match expr {
13256 LlkvExpr::And(children) => {
13257 for child in children {
13258 collect_conjuncts_lenient(child, out);
13259 }
13260 }
13261 other => {
13262 out.push(other);
13264 }
13265 }
13266}
13267
/// Tries to interpret an OR chain as `col = v1 OR col = v2 OR ...` over a
/// single column, yielding that column plus the literal values (an IN-list).
///
/// Returns `None` unless *every* branch is an equality between the same
/// column and a convertible literal.
fn try_extract_or_as_in_list(
    or_children: &[LlkvExpr<'static, String>],
    table_infos: &[TableInfo<'_>],
) -> Option<(ColumnRef, Vec<PlanValue>)> {
    if or_children.is_empty() {
        return None;
    }

    // The column shared by every branch, fixed by the first matching branch.
    let mut common_column: Option<ColumnRef> = None;
    let mut values = Vec::new();

    for child in or_children {
        // Branch shape 1: `<expr> = <expr>` with one side a column and the
        // other a convertible literal (either operand order).
        if let LlkvExpr::Compare {
            left,
            op: CompareOp::Eq,
            right,
        } = child
        {
            if let (Some(column), None) = (
                resolve_column_reference(left, table_infos),
                resolve_column_reference(right, table_infos),
            ) && let Some(literal) = extract_literal(right)
                && let Some(value) = PlanValue::from_literal_for_join(literal)
            {
                match common_column {
                    None => common_column = Some(column),
                    Some(ref prev)
                        if prev.table == column.table && prev.column == column.column =>
                    {
                    }
                    // A different column breaks the IN-list pattern.
                    _ => {
                        return None;
                    }
                }
                values.push(value);
                continue;
            }

            if let (None, Some(column)) = (
                resolve_column_reference(left, table_infos),
                resolve_column_reference(right, table_infos),
            ) && let Some(literal) = extract_literal(left)
                && let Some(value) = PlanValue::from_literal_for_join(literal)
            {
                match common_column {
                    None => common_column = Some(column),
                    Some(ref prev)
                        if prev.table == column.table && prev.column == column.column => {}
                    _ => return None,
                }
                values.push(value);
                continue;
            }
        }
        // Branch shape 2: pre-lowered `field = literal` predicate.
        else if let LlkvExpr::Pred(filter) = child
            && let Operator::Equals(ref literal) = filter.op
            && let Some(column) =
                resolve_column_reference(&ScalarExpr::Column(filter.field_id.clone()), table_infos)
            && let Some(value) = PlanValue::from_literal_for_join(literal)
        {
            match common_column {
                None => common_column = Some(column),
                Some(ref prev) if prev.table == column.table && prev.column == column.column => {}
                _ => return None,
            }
            values.push(value);
            continue;
        }

        // Any branch that did not `continue` above disqualifies the chain.
        return None;
    }

    common_column.map(|col| (col, values))
}
13353
/// Splits a join predicate into hash-join column equalities, single-table
/// literal constraints, and an unsatisfiability flag.
///
/// Comparing `handled_conjuncts` with `total_conjuncts` tells the caller
/// whether every top-level conjunct was captured; otherwise a residual filter
/// must still be evaluated against the joined rows.
///
/// NOTE(review): this function currently always returns `Some`; the `Option`
/// appears to be kept for interface stability with callers.
fn extract_join_constraints(
    expr: &LlkvExpr<'static, String>,
    table_infos: &[TableInfo<'_>],
) -> Option<JoinConstraintPlan> {
    // Flatten nested ANDs into top-level conjuncts.
    let mut conjuncts = Vec::new();
    collect_conjuncts_lenient(expr, &mut conjuncts);

    let total_conjuncts = conjuncts.len();
    let mut equalities = Vec::new();
    let mut literals = Vec::new();
    let mut unsatisfiable = false;
    let mut handled_conjuncts = 0;

    for conjunct in conjuncts {
        match conjunct {
            // TRUE contributes nothing but still counts as handled.
            LlkvExpr::Literal(true) => {
                handled_conjuncts += 1;
            }
            // FALSE makes the whole predicate unsatisfiable; stop early.
            LlkvExpr::Literal(false) => {
                unsatisfiable = true;
                handled_conjuncts += 1;
                break;
            }
            LlkvExpr::Compare {
                left,
                op: CompareOp::Eq,
                right,
            } => {
                match (
                    resolve_column_reference(left, table_infos),
                    resolve_column_reference(right, table_infos),
                ) {
                    // column = column: a join equality.
                    (Some(left_col), Some(right_col)) => {
                        equalities.push(ColumnEquality {
                            left: left_col,
                            right: right_col,
                        });
                        handled_conjuncts += 1;
                        continue;
                    }
                    // column = literal (either operand order): pushdown.
                    (Some(column), None) => {
                        if let Some(literal) = extract_literal(right)
                            && let Some(value) = PlanValue::from_literal_for_join(literal)
                        {
                            literals
                                .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                            handled_conjuncts += 1;
                            continue;
                        }
                    }
                    (None, Some(column)) => {
                        if let Some(literal) = extract_literal(left)
                            && let Some(value) = PlanValue::from_literal_for_join(literal)
                        {
                            literals
                                .push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                            handled_conjuncts += 1;
                            continue;
                        }
                    }
                    // Anything else is left for the residual filter.
                    _ => {}
                }
            }
            // column IN (literals): handled when at least one item converts.
            LlkvExpr::InList {
                expr: col_expr,
                list,
                negated: false,
            } => {
                if let Some(column) = resolve_column_reference(col_expr, table_infos) {
                    let mut in_list_values = Vec::new();
                    for item in list {
                        if let Some(literal) = extract_literal(item)
                            && let Some(value) = PlanValue::from_literal_for_join(literal)
                        {
                            in_list_values.push(value);
                        }
                    }
                    // NOTE(review): unconvertible items are dropped while the
                    // conjunct is counted as handled — confirm callers re-check
                    // the original predicate when the list is narrowed.
                    if !in_list_values.is_empty() {
                        literals.push(ColumnConstraint::InList(ColumnInList {
                            column,
                            values: in_list_values,
                        }));
                        handled_conjuncts += 1;
                        continue;
                    }
                }
            }
            // An OR chain of equalities on one column behaves like IN.
            LlkvExpr::Or(or_children) => {
                if let Some((column, values)) = try_extract_or_as_in_list(or_children, table_infos)
                {
                    literals.push(ColumnConstraint::InList(ColumnInList { column, values }));
                    handled_conjuncts += 1;
                    continue;
                }
            }
            // Pre-lowered `field = literal` predicate.
            LlkvExpr::Pred(filter) => {
                if let Operator::Equals(ref literal) = filter.op
                    && let Some(column) = resolve_column_reference(
                        &ScalarExpr::Column(filter.field_id.clone()),
                        table_infos,
                    )
                    && let Some(value) = PlanValue::from_literal_for_join(literal)
                {
                    literals.push(ColumnConstraint::Equality(ColumnLiteral { column, value }));
                    handled_conjuncts += 1;
                    continue;
                }
            }
            // Unsupported conjunct shapes stay unhandled (residual filter).
            _ => {
            }
        }
    }

    Some(JoinConstraintPlan {
        equalities,
        literals,
        unsatisfiable,
        total_conjuncts,
        handled_conjuncts,
    })
}
13513
13514fn resolve_column_reference(
13515 expr: &ScalarExpr<String>,
13516 table_infos: &[TableInfo<'_>],
13517) -> Option<ColumnRef> {
13518 let name = match expr {
13519 ScalarExpr::Column(name) => name.trim(),
13520 _ => return None,
13521 };
13522
13523 let mut parts: Vec<&str> = name
13524 .trim_start_matches('.')
13525 .split('.')
13526 .filter(|segment| !segment.is_empty())
13527 .collect();
13528
13529 if parts.is_empty() {
13530 return None;
13531 }
13532
13533 let column_part = parts.pop()?.to_ascii_lowercase();
13534 if parts.is_empty() {
13535 for info in table_infos {
13539 if let Some(&col_idx) = info.column_map.get(&column_part) {
13540 return Some(ColumnRef {
13541 table: info.index,
13542 column: col_idx,
13543 });
13544 }
13545 }
13546 return None;
13547 }
13548
13549 let table_ident = parts.join(".").to_ascii_lowercase();
13550 for info in table_infos {
13551 if matches_table_ident(info.table_ref, &table_ident) {
13552 if let Some(&col_idx) = info.column_map.get(&column_part) {
13553 return Some(ColumnRef {
13554 table: info.index,
13555 column: col_idx,
13556 });
13557 } else {
13558 return None;
13559 }
13560 }
13561 }
13562 None
13563}
13564
13565fn matches_table_ident(table_ref: &llkv_plan::TableRef, ident: &str) -> bool {
13566 if ident.is_empty() {
13567 return false;
13568 }
13569 if let Some(alias) = &table_ref.alias
13570 && alias.to_ascii_lowercase() == ident
13571 {
13572 return true;
13573 }
13574 if table_ref.table.to_ascii_lowercase() == ident {
13575 return true;
13576 }
13577 if !table_ref.schema.is_empty() {
13578 let full = format!(
13579 "{}.{}",
13580 table_ref.schema.to_ascii_lowercase(),
13581 table_ref.table.to_ascii_lowercase()
13582 );
13583 if full == ident {
13584 return true;
13585 }
13586 }
13587 false
13588}
13589
13590fn extract_literal(expr: &ScalarExpr<String>) -> Option<&Literal> {
13591 match expr {
13592 ScalarExpr::Literal(lit) => Some(lit),
13593 _ => None,
13594 }
13595}
13596
/// Tracks rows already emitted so SELECT DISTINCT can drop duplicates across
/// successive batches.
#[derive(Default)]
struct DistinctState {
    // Canonicalized rows observed so far.
    seen: FxHashSet<CanonicalRow>,
}
13601
impl DistinctState {
    /// Records `row`; returns `true` if it was not seen before.
    fn insert(&mut self, row: CanonicalRow) -> bool {
        self.seen.insert(row)
    }
}
13607
13608fn distinct_filter_batch(
13609 batch: RecordBatch,
13610 state: &mut DistinctState,
13611) -> ExecutorResult<Option<RecordBatch>> {
13612 if batch.num_rows() == 0 {
13613 return Ok(None);
13614 }
13615
13616 let mut keep_flags = Vec::with_capacity(batch.num_rows());
13617 let mut keep_count = 0usize;
13618
13619 for row_idx in 0..batch.num_rows() {
13620 let row = CanonicalRow::from_batch(&batch, row_idx)?;
13621 if state.insert(row) {
13622 keep_flags.push(true);
13623 keep_count += 1;
13624 } else {
13625 keep_flags.push(false);
13626 }
13627 }
13628
13629 if keep_count == 0 {
13630 return Ok(None);
13631 }
13632
13633 if keep_count == batch.num_rows() {
13634 return Ok(Some(batch));
13635 }
13636
13637 let mut builder = BooleanBuilder::with_capacity(batch.num_rows());
13638 for flag in keep_flags {
13639 builder.append_value(flag);
13640 }
13641 let mask = Arc::new(builder.finish());
13642
13643 let filtered = filter_record_batch(&batch, &mask).map_err(|err| {
13644 Error::InvalidArgumentError(format!("failed to apply DISTINCT filter: {err}"))
13645 })?;
13646
13647 Ok(Some(filtered))
13648}
13649
/// Reorders the rows of `batch` according to `order_by`, returning a new
/// batch with the same schema.
///
/// Sort targets may be column names or positional indexes. A target with
/// `OrderSortType::CastTextToInteger` sorts a TEXT column by its integer
/// value, with unparsable strings treated as NULL. `OrderTarget::All` must
/// have been expanded by the planner before this point.
fn sort_record_batch_with_order(
    schema: &Arc<Schema>,
    batch: &RecordBatch,
    order_by: &[OrderByPlan],
) -> ExecutorResult<RecordBatch> {
    if order_by.is_empty() {
        return Ok(batch.clone());
    }

    let mut sort_columns: Vec<SortColumn> = Vec::with_capacity(order_by.len());

    for order in order_by {
        // Resolve the ORDER BY target to a column index in the batch.
        let column_index = match &order.target {
            OrderTarget::Column(name) => schema.index_of(name).map_err(|_| {
                Error::InvalidArgumentError(format!(
                    "ORDER BY references unknown column '{}'",
                    name
                ))
            })?,
            OrderTarget::Index(idx) => {
                if *idx >= batch.num_columns() {
                    return Err(Error::InvalidArgumentError(format!(
                        "ORDER BY position {} is out of bounds for {} columns",
                        idx + 1,
                        batch.num_columns()
                    )));
                }
                *idx
            }
            OrderTarget::All => {
                return Err(Error::InvalidArgumentError(
                    "ORDER BY ALL should be expanded before sorting".into(),
                ));
            }
        };

        let source_array = batch.column(column_index);

        // Build the key array that actually drives the sort for this target.
        let values: ArrayRef = match order.sort_type {
            OrderSortType::Native => Arc::clone(source_array),
            OrderSortType::CastTextToInteger => {
                let strings = source_array
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .ok_or_else(|| {
                        Error::InvalidArgumentError(
                            "ORDER BY CAST expects the underlying column to be TEXT".into(),
                        )
                    })?;
                let mut builder = Int64Builder::with_capacity(strings.len());
                for i in 0..strings.len() {
                    if strings.is_null(i) {
                        builder.append_null();
                    } else {
                        // Unparsable text sorts like NULL rather than erroring.
                        match strings.value(i).parse::<i64>() {
                            Ok(value) => builder.append_value(value),
                            Err(_) => builder.append_null(),
                        }
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
        };

        let sort_options = SortOptions {
            descending: !order.ascending,
            nulls_first: order.nulls_first,
        };

        sort_columns.push(SortColumn {
            values,
            options: Some(sort_options),
        });
    }

    // Compute the sorted row permutation across all sort keys at once.
    let indices = lexsort_to_indices(&sort_columns, None).map_err(|err| {
        Error::InvalidArgumentError(format!("failed to compute ORDER BY indices: {err}"))
    })?;

    let perm = indices
        .as_any()
        .downcast_ref::<UInt32Array>()
        .ok_or_else(|| Error::Internal("ORDER BY sorting produced unexpected index type".into()))?;

    // Apply the permutation column by column.
    let mut reordered_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
    for col_idx in 0..batch.num_columns() {
        let reordered = take(batch.column(col_idx), perm, None).map_err(|err| {
            Error::InvalidArgumentError(format!(
                "failed to apply ORDER BY permutation to column {col_idx}: {err}"
            ))
        })?;
        reordered_columns.push(reordered);
    }

    RecordBatch::try_new(Arc::clone(schema), reordered_columns)
        .map_err(|err| Error::Internal(format!("failed to build reordered ORDER BY batch: {err}")))
}
13747
// Unit tests for cross-product evaluation, grouping keys, and constant
// aggregate-expression folding.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayRef, Date32Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use llkv_expr::expr::{BinaryOp, CompareOp};
    use llkv_expr::literal::Literal;
    use llkv_storage::pager::MemPager;
    use std::sync::Arc;

    // Literals broadcast to the batch length and column arithmetic resolves
    // through the cross-product column lookup.
    #[test]
    fn cross_product_context_evaluates_expressions() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("main.tab2.a", DataType::Int64, false),
            Field::new("main.tab2.b", DataType::Int64, false),
        ]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
                Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef,
            ],
        )
        .expect("valid batch");

        let lookup = build_cross_product_column_lookup(schema.as_ref(), &[], &[], &[]);
        let mut ctx = CrossProductExpressionContext::new(schema.as_ref(), lookup)
            .expect("context builds from schema");

        // A bare literal evaluates to a constant column of the batch length.
        let literal_expr: ScalarExpr<String> = ScalarExpr::literal(67);
        let literal = ctx
            .evaluate(&literal_expr, &batch)
            .expect("literal evaluation succeeds");
        let literal_array = literal
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 literal result");
        assert_eq!(literal_array.len(), 3);
        assert!(literal_array.iter().all(|value| value == Some(67)));

        // `tab2.a + 5` resolves the partially-qualified column name.
        let add_expr = ScalarExpr::binary(
            ScalarExpr::column("tab2.a".to_string()),
            BinaryOp::Add,
            ScalarExpr::literal(5),
        );
        let added = ctx
            .evaluate(&add_expr, &batch)
            .expect("column addition succeeds");
        let added_array = added
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("int64 addition result");
        assert_eq!(added_array.values(), &[6, 7, 8]);
    }

    // Date32 columns participate in comparison predicates.
    #[test]
    fn cross_product_filter_handles_date32_columns() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "orders.o_orderdate",
            DataType::Date32,
            false,
        )]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Date32Array::from(vec![0, 1, 3])) as ArrayRef],
        )
        .expect("valid batch");

        let lookup = build_cross_product_column_lookup(schema.as_ref(), &[], &[], &[]);
        let mut ctx = CrossProductExpressionContext::new(schema.as_ref(), lookup)
            .expect("context builds from schema");

        let field_id = ctx
            .schema()
            .columns
            .first()
            .expect("schema exposes date column")
            .field_id;

        let predicate = LlkvExpr::Compare {
            left: ScalarExpr::Column(field_id),
            op: CompareOp::GtEq,
            right: ScalarExpr::Literal(Literal::Date32(1)),
        };

        // The subquery callback is unused here; it always yields no result.
        let truths = ctx
            .evaluate_predicate_truths(&predicate, &batch, &mut |_, _, _, _| Ok(None))
            .expect("date comparison evaluates");

        assert_eq!(truths, vec![Some(false), Some(true), Some(true)]);
    }

    // Date32 values (including NULL and negative days) map to group keys.
    #[test]
    fn group_by_handles_date32_columns() {
        let array: ArrayRef = Arc::new(Date32Array::from(vec![Some(3), None, Some(-7)]));

        let first = group_key_value(&array, 0).expect("extract first group key");
        assert_eq!(first, GroupKeyValue::Int(3));

        let second = group_key_value(&array, 1).expect("extract second group key");
        assert_eq!(second, GroupKeyValue::Null);

        let third = group_key_value(&array, 2).expect("extract third group key");
        assert_eq!(third, GroupKeyValue::Int(-7));
    }

    // In-range integral CASTs inside aggregate expressions succeed.
    #[test]
    fn aggregate_expr_allows_numeric_casts() {
        let expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(31)),
            data_type: DataType::Int32,
        };
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("cast should succeed for in-range integral values");

        assert_eq!(value, Some(31));
    }

    // Casting -1 to an unsigned type is rejected rather than wrapped.
    #[test]
    fn aggregate_expr_cast_rejects_out_of_range_values() {
        let expr = ScalarExpr::Cast {
            expr: Box::new(ScalarExpr::literal(-1)),
            data_type: DataType::UInt8,
        };
        let aggregates = FxHashMap::default();

        let result = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates);

        assert!(matches!(result, Err(Error::InvalidArgumentError(_))));
    }

    // NULL propagates through arithmetic (SQL three-valued semantics).
    #[test]
    fn aggregate_expr_null_literal_remains_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(0),
            BinaryOp::Subtract,
            ScalarExpr::cast(ScalarExpr::literal(Literal::Null), DataType::Int64),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("expression should evaluate");

        assert_eq!(value, None);
    }

    // x / 0 yields NULL instead of an error.
    #[test]
    fn aggregate_expr_divide_by_zero_returns_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Divide,
            ScalarExpr::literal(0),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("division should evaluate");

        assert_eq!(value, None);
    }

    // x % 0 also yields NULL instead of an error.
    #[test]
    fn aggregate_expr_modulo_by_zero_returns_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(10),
            BinaryOp::Modulo,
            ScalarExpr::literal(0),
        );
        let aggregates = FxHashMap::default();

        let value = QueryExecutor::<MemPager>::evaluate_expr_with_aggregates(&expr, &aggregates)
            .expect("modulo should evaluate");

        assert_eq!(value, None);
    }

    // NULL AND <non-false> folds to NULL under constant evaluation.
    #[test]
    fn constant_and_with_null_yields_null() {
        let expr = ScalarExpr::binary(
            ScalarExpr::literal(Literal::Null),
            BinaryOp::And,
            ScalarExpr::literal(1),
        );

        let value = evaluate_constant_scalar_with_aggregates(&expr)
            .expect("expression should fold as constant");

        assert!(matches!(value, Literal::Null));
    }

    // Chaining cross products over three tables yields the expected row
    // multiplication (2 x 3 x 1 = 6) and column ordering.
    #[test]
    fn cross_product_handles_more_than_two_tables() {
        let schema_a = Arc::new(Schema::new(vec![Field::new(
            "main.t1.a",
            DataType::Int64,
            false,
        )]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "main.t2.b",
            DataType::Int64,
            false,
        )]));
        let schema_c = Arc::new(Schema::new(vec![Field::new(
            "main.t3.c",
            DataType::Int64,
            false,
        )]));

        let batch_a = RecordBatch::try_new(
            Arc::clone(&schema_a),
            vec![Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef],
        )
        .expect("valid batch");
        let batch_b = RecordBatch::try_new(
            Arc::clone(&schema_b),
            vec![Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef],
        )
        .expect("valid batch");
        let batch_c = RecordBatch::try_new(
            Arc::clone(&schema_c),
            vec![Arc::new(Int64Array::from(vec![100])) as ArrayRef],
        )
        .expect("valid batch");

        let data_a = TableCrossProductData {
            schema: schema_a,
            batches: vec![batch_a],
            column_counts: vec![1],
            table_indices: vec![0],
        };
        let data_b = TableCrossProductData {
            schema: schema_b,
            batches: vec![batch_b],
            column_counts: vec![1],
            table_indices: vec![1],
        };
        let data_c = TableCrossProductData {
            schema: schema_c,
            batches: vec![batch_c],
            column_counts: vec![1],
            table_indices: vec![2],
        };

        let ab = cross_join_table_batches(data_a, data_b).expect("two-table product");
        assert_eq!(ab.schema.fields().len(), 2);
        assert_eq!(ab.batches.len(), 1);
        assert_eq!(ab.batches[0].num_rows(), 6);

        let abc = cross_join_table_batches(ab, data_c).expect("three-table product");
        assert_eq!(abc.schema.fields().len(), 3);
        assert_eq!(abc.batches.len(), 1);

        let final_batch = &abc.batches[0];
        assert_eq!(final_batch.num_rows(), 6);

        // Left column repeats each value for every right-side row.
        let col_a = final_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("left column values");
        assert_eq!(col_a.values(), &[1, 1, 1, 2, 2, 2]);

        let col_b = final_batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("middle column values");
        assert_eq!(col_b.values(), &[10, 20, 30, 10, 20, 30]);

        let col_c = final_batch
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("right column values");
        assert_eq!(col_c.values(), &[100, 100, 100, 100, 100, 100]);
    }
}