llkv_compute/
eval.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use arrow::array::{Array, ArrayRef, Float64Array, UInt32Array, new_null_array};
5use arrow::compute::kernels::cast;
6use arrow::compute::kernels::zip::zip;
7use arrow::compute::{concat, is_not_null, take};
8use arrow::datatypes::{DataType, Field, IntervalMonthDayNanoType};
9use llkv_expr::literal::{Literal, LiteralExt};
10use llkv_expr::{AggregateCall, BinaryOp, CompareOp, ScalarExpr};
11use llkv_result::{Error, Result as LlkvResult};
12use llkv_types::IntervalValue;
13use rustc_hash::{FxHashMap, FxHashSet};
14use sqlparser::ast::BinaryOperator;
15
16use crate::date::{add_interval_to_date32, parse_date32_literal, subtract_interval_from_date32};
17use crate::kernels::{compute_binary, get_common_type};
18
19/// Mapping from field identifiers to the numeric Arrow array used for evaluation.
20pub type NumericArrayMap<F> = FxHashMap<F, ArrayRef>;
21
22/// Intermediate representation for vectorized evaluators.
23enum VectorizedExpr {
24    Array(ArrayRef),
25    Scalar(ArrayRef),
26}
27
28impl VectorizedExpr {
29    fn materialize(self, len: usize, target_type: DataType) -> ArrayRef {
30        match self {
31            VectorizedExpr::Array(array) => {
32                if array.data_type() == &target_type {
33                    array
34                } else {
35                    cast::cast(&array, &target_type).unwrap_or(array)
36                }
37            }
38            VectorizedExpr::Scalar(scalar_array) => {
39                if scalar_array.is_empty() {
40                    return new_null_array(&target_type, len);
41                }
42                if scalar_array.is_null(0) {
43                    return new_null_array(scalar_array.data_type(), len);
44                }
45
46                // Expand scalar to array of length len
47                let indices = UInt32Array::from(vec![0; len]);
48                take(&scalar_array, &indices, None)
49                    .unwrap_or_else(|_| new_null_array(scalar_array.data_type(), len))
50            }
51        }
52    }
53}
54
55/// Extension methods for type inference on `ScalarExpr`.
56pub trait ScalarExprTypeExt<F> {
57    fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
58    where
59        F: Hash + Eq + Copy,
60        R: FnMut(F) -> Option<DataType>;
61
62    fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType
63    where
64        F: Hash + Eq + Copy;
65
66    fn contains_interval(&self) -> bool;
67}
68
69impl<F: Hash + Eq + Copy> ScalarExprTypeExt<F> for ScalarExpr<F> {
70    fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
71    where
72        R: FnMut(F) -> Option<DataType>,
73    {
74        match self {
75            ScalarExpr::Literal(lit) => Some(literal_type(lit)),
76            ScalarExpr::Column(fid) => resolve_type(*fid),
77            ScalarExpr::Binary { left, op, right } => {
78                let left_type = left.infer_result_type(resolve_type)?;
79                let right_type = right.infer_result_type(resolve_type)?;
80                Some(binary_result_type(*op, left_type, right_type))
81            }
82            ScalarExpr::Compare { .. } => Some(DataType::Boolean),
83            ScalarExpr::Not(_) => Some(DataType::Boolean),
84            ScalarExpr::IsNull { .. } => Some(DataType::Boolean),
85            ScalarExpr::Aggregate(call) => aggregate_result_type(call, resolve_type),
86            ScalarExpr::GetField { base, field_name } => {
87                let base_type = base.infer_result_type(resolve_type)?;
88                match base_type {
89                    DataType::Struct(fields) => fields
90                        .iter()
91                        .find(|f| f.name() == field_name)
92                        .map(|f| f.data_type().clone()),
93                    _ => None,
94                }
95            }
96            ScalarExpr::Cast { data_type, .. } => Some(data_type.clone()),
97            ScalarExpr::Case {
98                branches,
99                else_expr,
100                ..
101            } => {
102                let mut types = Vec::new();
103                for (_, then_expr) in branches {
104                    if let Some(t) = then_expr.infer_result_type(resolve_type) {
105                        types.push(t);
106                    }
107                }
108                if let Some(else_expr) = else_expr {
109                    if let Some(t) = else_expr.infer_result_type(resolve_type) {
110                        types.push(t);
111                    }
112                } else {
113                    // Implicit ELSE NULL
114                    types.push(DataType::Null);
115                }
116
117                if types.is_empty() {
118                    return None;
119                }
120
121                let mut common = types[0].clone();
122                for t in &types[1..] {
123                    common = get_common_type(&common, t);
124                }
125
126                Some(common)
127            }
128            ScalarExpr::Coalesce(items) => {
129                let mut types = Vec::new();
130                for item in items {
131                    if let Some(t) = item.infer_result_type(resolve_type) {
132                        types.push(t);
133                    }
134                }
135                if types.is_empty() {
136                    return None;
137                }
138                let mut common = types[0].clone();
139                for t in &types[1..] {
140                    common = get_common_type(&common, t);
141                }
142                Some(common)
143            }
144            ScalarExpr::Random => Some(DataType::Float64),
145            ScalarExpr::ScalarSubquery(sub) => Some(sub.data_type.clone()),
146        }
147    }
148
149    fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType {
150        let mut resolver = |fid| arrays.get(&fid).map(|a| a.data_type().clone());
151        self.infer_result_type(&mut resolver)
152            .unwrap_or(DataType::Float64)
153    }
154
155    fn contains_interval(&self) -> bool {
156        match self {
157            ScalarExpr::Literal(Literal::Interval(_)) => true,
158            ScalarExpr::Binary { left, right, .. } => {
159                left.contains_interval() || right.contains_interval()
160            }
161            _ => false,
162        }
163    }
164}
165
166fn literal_type(lit: &Literal) -> DataType {
167    match lit {
168        Literal::Null => DataType::Null,
169        Literal::Boolean(_) => DataType::Boolean,
170        Literal::Int128(_) => DataType::Int64, // Default to Int64 for literals
171        Literal::Float64(_) => DataType::Float64,
172        Literal::Decimal128(d) => DataType::Decimal128(d.precision(), d.scale()),
173        Literal::String(_) => DataType::Utf8,
174        Literal::Date32(_) => DataType::Date32,
175        Literal::Interval(_) => DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
176        Literal::Struct(fields) => {
177            let arrow_fields = fields
178                .iter()
179                .map(|(name, lit)| Field::new(name, literal_type(lit), true))
180                .collect();
181            DataType::Struct(arrow_fields)
182        }
183    }
184}
185
186fn aggregate_result_type<F, R>(call: &AggregateCall<F>, resolve_type: &mut R) -> Option<DataType>
187where
188    F: Hash + Eq + Copy,
189    R: FnMut(F) -> Option<DataType>,
190{
191    match call {
192        AggregateCall::CountStar | AggregateCall::Count { .. } | AggregateCall::CountNulls(_) => {
193            Some(DataType::Int64)
194        }
195        AggregateCall::Sum { expr, .. } => {
196            let child = expr.infer_result_type(resolve_type)?;
197            Some(match child {
198                DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
199                DataType::Float32 | DataType::Float64 => DataType::Float64,
200                DataType::UInt64 | DataType::Int64 => child,
201                DataType::UInt32
202                | DataType::UInt16
203                | DataType::UInt8
204                | DataType::Int32
205                | DataType::Int16
206                | DataType::Int8 => DataType::Int64,
207                _ => DataType::Float64,
208            })
209        }
210        AggregateCall::Total { expr, .. } | AggregateCall::Avg { expr, .. } => {
211            let child = expr.infer_result_type(resolve_type)?;
212            Some(match child {
213                DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
214                _ => DataType::Float64,
215            })
216        }
217        AggregateCall::Min(expr) | AggregateCall::Max(expr) => expr.infer_result_type(resolve_type),
218        AggregateCall::GroupConcat { .. } => Some(DataType::Utf8),
219    }
220}
221
222fn binary_result_type(_op: BinaryOp, lhs: DataType, rhs: DataType) -> DataType {
223    get_common_type(&lhs, &rhs)
224}
225
/// An affine transformation of a single field: `scale * field + offset`.
#[derive(Debug, Clone, Copy)]
pub struct AffineExpr<F> {
    /// Field the transformation is applied to.
    pub field: F,
    /// Multiplicative factor.
    pub scale: f64,
    /// Additive constant.
    pub offset: f64,
}
233
/// Working state while folding an expression into affine form.
///
/// `field` is `None` while only constants have been seen.
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
struct AffineState<F> {
    /// Field referenced so far, if any.
    field: Option<F>,
    /// Accumulated multiplicative factor.
    scale: f64,
    /// Accumulated additive constant.
    offset: f64,
}
242
/// Namespace for the kernels that evaluate scalar expressions, kept in one
/// place so they can be tuned independently of the table scan logic around
/// them.
pub struct ScalarEvaluator;
246
247impl ScalarEvaluator {
/// Merge two optional field identifiers.
///
/// Yields `None` when both sides name different fields; otherwise yields
/// `Some` of whichever field is present (or `Some(None)` when neither is).
#[allow(dead_code)]
fn merge_field<F: Eq + Copy>(lhs: Option<F>, rhs: Option<F>) -> Option<Option<F>> {
    match (lhs, rhs) {
        // Two distinct fields cannot be merged.
        (Some(left), Some(right)) if left != right => None,
        (present @ Some(_), _) | (None, present) => Some(present),
    }
}
264
265    /// Collect every field referenced by the scalar expression into `acc`.
266    pub fn collect_fields<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>, acc: &mut FxHashSet<F>) {
267        match expr {
268            ScalarExpr::Column(fid) => {
269                acc.insert(*fid);
270            }
271            ScalarExpr::Literal(_) => {}
272            ScalarExpr::Binary { left, right, .. } => {
273                Self::collect_fields(left, acc);
274                Self::collect_fields(right, acc);
275            }
276            ScalarExpr::Compare { left, right, .. } => {
277                Self::collect_fields(left, acc);
278                Self::collect_fields(right, acc);
279            }
280            ScalarExpr::Not(inner) => {
281                Self::collect_fields(inner, acc);
282            }
283            ScalarExpr::IsNull { expr, .. } => {
284                Self::collect_fields(expr, acc);
285            }
286            ScalarExpr::Aggregate(agg) => {
287                // Collect fields referenced by the aggregate expression
288                match agg {
289                    AggregateCall::CountStar => {}
290                    AggregateCall::Count { expr, .. }
291                    | AggregateCall::Sum { expr, .. }
292                    | AggregateCall::Total { expr, .. }
293                    | AggregateCall::Avg { expr, .. }
294                    | AggregateCall::Min(expr)
295                    | AggregateCall::Max(expr)
296                    | AggregateCall::CountNulls(expr)
297                    | AggregateCall::GroupConcat { expr, .. } => {
298                        Self::collect_fields(expr, acc);
299                    }
300                }
301            }
302            ScalarExpr::GetField { base, .. } => {
303                // Collect fields from the base expression
304                Self::collect_fields(base, acc);
305            }
306            ScalarExpr::Cast { expr, .. } => {
307                Self::collect_fields(expr, acc);
308            }
309            ScalarExpr::Case {
310                operand,
311                branches,
312                else_expr,
313            } => {
314                if let Some(inner) = operand.as_deref() {
315                    Self::collect_fields(inner, acc);
316                }
317                for (when_expr, then_expr) in branches {
318                    Self::collect_fields(when_expr, acc);
319                    Self::collect_fields(then_expr, acc);
320                }
321                if let Some(inner) = else_expr.as_deref() {
322                    Self::collect_fields(inner, acc);
323                }
324            }
325            ScalarExpr::Coalesce(items) => {
326                for item in items {
327                    Self::collect_fields(item, acc);
328                }
329            }
330            ScalarExpr::Random => {
331                // Random does not reference any fields
332            }
333            ScalarExpr::ScalarSubquery(_) => {
334                // Scalar subqueries don't directly reference fields from the outer query
335            }
336        }
337    }
338
339    pub fn prepare_numeric_arrays<F: Hash + Eq + Copy>(
340        arrays: &FxHashMap<F, ArrayRef>,
341        _row_count: usize,
342    ) -> NumericArrayMap<F> {
343        arrays.clone()
344    }
345
346    /// Attempts to represent the expression as `scale * column + offset`.
347    /// Returns `None` when the expression depends on multiple columns or is non-linear.
348    pub fn extract_affine<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineExpr<F>> {
349        let simplified = Self::simplify(expr);
350        Self::extract_affine_simplified(&simplified)
351    }
352
353    /// Variant of `extract_affine` that assumes `expr` is already simplified.
354    pub fn extract_affine_simplified<F: Hash + Eq + Copy>(
355        expr: &ScalarExpr<F>,
356    ) -> Option<AffineExpr<F>> {
357        let state = Self::affine_state(expr)?;
358        let field = state.field?;
359        Some(AffineExpr {
360            field,
361            scale: state.scale,
362            offset: state.offset,
363        })
364    }
365
366    fn affine_state<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineState<F>> {
367        match expr {
368            ScalarExpr::Column(fid) => Some(AffineState {
369                field: Some(*fid),
370                scale: 1.0,
371                offset: 0.0,
372            }),
373            ScalarExpr::Literal(lit) => {
374                let arr = Self::literal_to_array(lit);
375                let val = cast::cast(&arr, &DataType::Float64).ok()?;
376                let val = val.as_any().downcast_ref::<Float64Array>()?;
377                if val.is_null(0) {
378                    return None;
379                }
380                Some(AffineState {
381                    field: None,
382                    scale: 0.0,
383                    offset: val.value(0),
384                })
385            }
386            ScalarExpr::Aggregate(_) => None,
387            ScalarExpr::GetField { .. } => None,
388            ScalarExpr::Binary { left, op, right } => {
389                let left_state = Self::affine_state(left)?;
390                let right_state = Self::affine_state(right)?;
391                match op {
392                    BinaryOp::Add => Self::affine_add(left_state, right_state),
393                    BinaryOp::Subtract => Self::affine_sub(left_state, right_state),
394                    BinaryOp::Multiply => Self::affine_mul(left_state, right_state),
395                    BinaryOp::Divide => Self::affine_div(left_state, right_state),
396                    _ => None,
397                }
398            }
399            ScalarExpr::Cast { expr, .. } => Self::affine_state(expr),
400            _ => None,
401        }
402    }
403
404    /// Evaluate a scalar expression for the row at `idx` using the provided numeric arrays.
405    pub fn evaluate_value<F: Hash + Eq + Copy>(
406        expr: &ScalarExpr<F>,
407        idx: usize,
408        arrays: &NumericArrayMap<F>,
409    ) -> LlkvResult<ArrayRef> {
410        match expr {
411            ScalarExpr::Column(fid) => {
412                let array = arrays
413                    .get(fid)
414                    .ok_or_else(|| Error::Internal("missing column for field".into()))?;
415                Ok(array.slice(idx, 1))
416            }
417            ScalarExpr::Literal(lit) => Ok(Self::literal_to_array(lit)),
418            ScalarExpr::Binary { left, op, right } => {
419                let l = Self::evaluate_value(left, idx, arrays)?;
420                let r = Self::evaluate_value(right, idx, arrays)?;
421                Self::evaluate_binary_scalar(&l, *op, &r)
422            }
423            ScalarExpr::Compare { left, op, right } => {
424                let l = Self::evaluate_value(left, idx, arrays)?;
425                let r = Self::evaluate_value(right, idx, arrays)?;
426                crate::kernels::compute_compare(&l, *op, &r)
427            }
428            ScalarExpr::Not(expr) => {
429                let val = Self::evaluate_value(expr, idx, arrays)?;
430                let bool_arr = cast::cast(&val, &DataType::Boolean)
431                    .map_err(|e| Error::Internal(e.to_string()))?;
432                let bool_arr = bool_arr
433                    .as_any()
434                    .downcast_ref::<arrow::array::BooleanArray>()
435                    .unwrap();
436                let result = arrow::compute::kernels::boolean::not(bool_arr)
437                    .map_err(|e| Error::Internal(e.to_string()))?;
438                Ok(Arc::new(result))
439            }
440            ScalarExpr::IsNull { expr, negated } => {
441                let val = Self::evaluate_value(expr, idx, arrays)?;
442                let is_null = val.is_null(0);
443                let result = if *negated { !is_null } else { is_null };
444                Ok(Arc::new(arrow::array::BooleanArray::from(vec![result])))
445            }
446            ScalarExpr::Cast { expr, data_type } => {
447                let val = Self::evaluate_value(expr, idx, arrays)?;
448                cast::cast(&val, data_type).map_err(|e| Error::Internal(e.to_string()))
449            }
450            ScalarExpr::Case {
451                operand,
452                branches,
453                else_expr,
454            } => {
455                let operand_val = if let Some(op) = operand {
456                    Some(Self::evaluate_value(op, idx, arrays)?)
457                } else {
458                    None
459                };
460
461                for (when_expr, then_expr) in branches {
462                    let when_val = Self::evaluate_value(when_expr, idx, arrays)?;
463
464                    let is_match = if let Some(op_val) = &operand_val {
465                        // Simple CASE: operand = when_val
466                        // If either is null, result is null (false for condition)
467                        if op_val.is_null(0) || when_val.is_null(0) {
468                            false
469                        } else {
470                            let eq =
471                                crate::kernels::compute_compare(op_val, CompareOp::Eq, &when_val)?;
472                            let bool_arr = eq
473                                .as_any()
474                                .downcast_ref::<arrow::array::BooleanArray>()
475                                .unwrap();
476                            bool_arr.value(0)
477                        }
478                    } else {
479                        // Searched CASE: when_val is boolean condition
480                        if when_val.is_null(0) {
481                            false
482                        } else {
483                            let bool_arr = cast::cast(&when_val, &DataType::Boolean)
484                                .map_err(|e| Error::Internal(e.to_string()))?;
485                            let bool_arr = bool_arr
486                                .as_any()
487                                .downcast_ref::<arrow::array::BooleanArray>()
488                                .unwrap();
489                            bool_arr.value(0)
490                        }
491                    };
492
493                    if is_match {
494                        return Self::evaluate_value(then_expr, idx, arrays);
495                    }
496                }
497                if let Some(else_expr) = else_expr {
498                    Self::evaluate_value(else_expr, idx, arrays)
499                } else {
500                    Ok(new_null_array(&DataType::Null, 1))
501                }
502            }
503            ScalarExpr::Coalesce(items) => {
504                for item in items {
505                    let val = Self::evaluate_value(item, idx, arrays)?;
506                    if !val.is_null(0) && val.data_type() != &DataType::Null {
507                        return Ok(val);
508                    }
509                }
510                Ok(new_null_array(&DataType::Null, 1))
511            }
512            ScalarExpr::Random => {
513                let val = rand::random::<f64>();
514                Ok(Arc::new(Float64Array::from(vec![val])))
515            }
516            _ => Err(Error::Internal("Unsupported scalar expression".into())),
517        }
518    }
519
520    fn literal_to_array(lit: &Literal) -> ArrayRef {
521        match lit {
522            Literal::Null => new_null_array(&DataType::Null, 1),
523            Literal::Boolean(b) => Arc::new(arrow::array::BooleanArray::from(vec![*b])),
524            Literal::Int128(i) => Arc::new(arrow::array::Int64Array::from(vec![*i as i64])),
525            Literal::Float64(f) => Arc::new(Float64Array::from(vec![*f])),
526            Literal::Decimal128(d) => {
527                let array = arrow::array::Decimal128Array::from(vec![Some(d.raw_value())])
528                    .with_precision_and_scale(d.precision(), d.scale())
529                    .unwrap();
530                Arc::new(array)
531            }
532            Literal::String(s) => Arc::new(arrow::array::StringArray::from(vec![s.clone()])),
533            Literal::Date32(d) => Arc::new(arrow::array::Date32Array::from(vec![*d])),
534            Literal::Interval(i) => {
535                let val = IntervalMonthDayNanoType::make_value(i.months, i.days, i.nanos);
536                Arc::new(arrow::array::IntervalMonthDayNanoArray::from(vec![val]))
537            }
538            Literal::Struct(_) => {
539                new_null_array(&DataType::Struct(arrow::datatypes::Fields::empty()), 1)
540            }
541        }
542    }
543
544    fn evaluate_binary_scalar(
545        lhs: &ArrayRef,
546        op: BinaryOp,
547        rhs: &ArrayRef,
548    ) -> LlkvResult<ArrayRef> {
549        compute_binary(lhs, rhs, op)
550    }
551
552    /// Evaluate a scalar expression for every row in the batch.
553    #[allow(dead_code)]
554    pub fn evaluate_batch<F: Hash + Eq + Copy + std::fmt::Debug>(
555        expr: &ScalarExpr<F>,
556        len: usize,
557        arrays: &NumericArrayMap<F>,
558    ) -> LlkvResult<ArrayRef> {
559        let simplified = Self::simplify(expr);
560        Self::evaluate_batch_simplified(&simplified, len, arrays)
561    }
562
563    /// Evaluate a scalar expression that has already been simplified.
564    pub fn evaluate_batch_simplified<F: Hash + Eq + Copy>(
565        expr: &ScalarExpr<F>,
566        len: usize,
567        arrays: &NumericArrayMap<F>,
568    ) -> LlkvResult<ArrayRef> {
569        let preferred = expr.infer_result_type_from_arrays(arrays);
570
571        if len == 0 {
572            return Ok(new_null_array(&preferred, 0));
573        }
574
575        if let Some(vectorized) =
576            Self::try_evaluate_vectorized(expr, len, arrays, preferred.clone())?
577        {
578            let result = vectorized.materialize(len, preferred);
579            return Ok(result);
580        }
581
582        let mut values = Vec::with_capacity(len);
583        for idx in 0..len {
584            let val = Self::evaluate_value(expr, idx, arrays)?;
585            if val.data_type() != &preferred {
586                let casted = cast::cast(&val, &preferred).map_err(|e| {
587                    Error::Internal(format!(
588                        "Failed to cast row {}: {} (Val type: {:?}, Preferred: {:?})",
589                        idx,
590                        e,
591                        val.data_type(),
592                        preferred
593                    ))
594                })?;
595                values.push(casted);
596            } else {
597                values.push(val);
598            }
599        }
600        concat(&values.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
601            .map_err(|e| Error::Internal(e.to_string()))
602    }
603
604    fn try_evaluate_vectorized<F: Hash + Eq + Copy>(
605        expr: &ScalarExpr<F>,
606        len: usize,
607        arrays: &NumericArrayMap<F>,
608        _target_type: DataType,
609    ) -> LlkvResult<Option<VectorizedExpr>> {
610        if expr.contains_interval() {
611            return Ok(None);
612        }
613        match expr {
614            ScalarExpr::Column(fid) => {
615                let array = arrays
616                    .get(fid)
617                    .ok_or_else(|| Error::Internal("missing column for field".into()))?;
618                Ok(Some(VectorizedExpr::Array(array.clone())))
619            }
620            ScalarExpr::Literal(lit) => {
621                let array = Self::literal_to_array(lit);
622                Ok(Some(VectorizedExpr::Scalar(array)))
623            }
624            ScalarExpr::Binary { left, op, right } => {
625                let left_type = left.infer_result_type_from_arrays(arrays);
626                let right_type = right.infer_result_type_from_arrays(arrays);
627
628                let left_vec = Self::try_evaluate_vectorized(left, len, arrays, left_type)?;
629                let right_vec = Self::try_evaluate_vectorized(right, len, arrays, right_type)?;
630
631                match (left_vec, right_vec) {
632                    (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
633                        let result = compute_binary(&lhs, &rhs, *op)?;
634                        Ok(Some(VectorizedExpr::Scalar(result)))
635                    }
636                    (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Array(rhs))) => {
637                        let array = compute_binary(&lhs, &rhs, *op)?;
638                        Ok(Some(VectorizedExpr::Array(array)))
639                    }
640                    (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
641                        let rhs_expanded = VectorizedExpr::Scalar(rhs)
642                            .materialize(lhs.len(), lhs.data_type().clone());
643                        let array = compute_binary(&lhs, &rhs_expanded, *op)?;
644                        Ok(Some(VectorizedExpr::Array(array)))
645                    }
646                    (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Array(rhs))) => {
647                        let lhs_expanded = VectorizedExpr::Scalar(lhs)
648                            .materialize(rhs.len(), rhs.data_type().clone());
649                        let array = compute_binary(&lhs_expanded, &rhs, *op)?;
650                        Ok(Some(VectorizedExpr::Array(array)))
651                    }
652                    _ => Ok(None),
653                }
654            }
655            ScalarExpr::Cast { expr, data_type } => {
656                let inner_type = expr.infer_result_type_from_arrays(arrays);
657                let inner_vec = Self::try_evaluate_vectorized(expr, len, arrays, inner_type)?;
658
659                match inner_vec {
660                    Some(VectorizedExpr::Scalar(array)) => {
661                        let casted = cast::cast(&array, data_type)
662                            .map_err(|e| Error::Internal(e.to_string()))?;
663                        Ok(Some(VectorizedExpr::Scalar(casted)))
664                    }
665                    Some(VectorizedExpr::Array(array)) => {
666                        let casted = cast::cast(&array, data_type)
667                            .map_err(|e| Error::Internal(e.to_string()))?;
668                        Ok(Some(VectorizedExpr::Array(casted)))
669                    }
670                    None => Ok(None),
671                }
672            }
673            ScalarExpr::Coalesce(items) => {
674                let mut evaluated_items = Vec::with_capacity(items.len());
675                let mut types = Vec::with_capacity(items.len());
676
677                for item in items {
678                    let item_type = item.infer_result_type_from_arrays(arrays);
679                    // If any item cannot be vectorized, we cannot vectorize the whole Coalesce
680                    let vec_expr = match Self::try_evaluate_vectorized(
681                        item,
682                        len,
683                        arrays,
684                        item_type.clone(),
685                    )? {
686                        Some(v) => v,
687                        None => return Ok(None),
688                    };
689
690                    let array = vec_expr.materialize(len, item_type.clone());
691                    types.push(array.data_type().clone());
692                    evaluated_items.push(array);
693                }
694
695                if evaluated_items.is_empty() {
696                    return Ok(Some(VectorizedExpr::Array(new_null_array(
697                        &DataType::Null,
698                        len,
699                    ))));
700                }
701
702                // Determine common type
703                let mut common_type = types[0].clone();
704                for t in &types[1..] {
705                    common_type = get_common_type(&common_type, t);
706                }
707
708                // Cast all arrays to common type
709                let mut casted_arrays = Vec::with_capacity(evaluated_items.len());
710                for array in evaluated_items {
711                    if array.data_type() != &common_type {
712                        let casted = cast::cast(&array, &common_type)
713                            .map_err(|e| Error::Internal(e.to_string()))?;
714                        casted_arrays.push(casted);
715                    } else {
716                        casted_arrays.push(array);
717                    }
718                }
719
720                let mut result = casted_arrays[0].clone();
721                for next_array in &casted_arrays[1..] {
722                    let mask = is_not_null(&result).map_err(|e| Error::Internal(e.to_string()))?;
723                    // result = zip(mask, result, next_array)
724                    // if mask is true (result is not null), keep result.
725                    // if mask is false (result is null), take next_array.
726                    result = zip(&mask, &result, next_array)
727                        .map_err(|e| Error::Internal(e.to_string()))?;
728                }
729                Ok(Some(VectorizedExpr::Array(result)))
730            }
731            ScalarExpr::Random => {
732                let values: Vec<f64> = (0..len).map(|_| rand::random::<f64>()).collect();
733                let array = Float64Array::from(values);
734                Ok(Some(VectorizedExpr::Array(Arc::new(array))))
735            }
736            _ => Ok(None),
737        }
738    }
739
740    /// Returns the column referenced by an expression when it's a direct or additive identity passthrough.
741    pub fn passthrough_column<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<F> {
742        match Self::simplify(expr) {
743            ScalarExpr::Column(fid) => Some(fid),
744            _ => None,
745        }
746    }
747
748    /// Simplify an expression by constant folding and identity removal.
749    pub fn simplify<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> ScalarExpr<F> {
750        match expr {
751            ScalarExpr::Binary { left, op, right } => {
752                let l = Self::simplify(left);
753                let r = Self::simplify(right);
754                if let (ScalarExpr::Literal(ll), ScalarExpr::Literal(rr)) = (&l, &r)
755                    && let Some(folded) = fold_binary_literals(*op, ll, rr)
756                {
757                    return ScalarExpr::Literal(folded);
758                }
759                ScalarExpr::Binary {
760                    left: Box::new(l),
761                    op: *op,
762                    right: Box::new(r),
763                }
764            }
765            ScalarExpr::Cast { expr, data_type } => {
766                let inner = Self::simplify(expr);
767                if let ScalarExpr::Literal(lit) = &inner
768                    && let Some(folded) = fold_cast_literal(lit, data_type)
769                {
770                    return ScalarExpr::Literal(folded);
771                }
772                ScalarExpr::Cast {
773                    expr: Box::new(inner),
774                    data_type: data_type.clone(),
775                }
776            }
777            _ => expr.clone(),
778        }
779    }
780
781    pub fn evaluate_constant_literal_expr<F: Hash + Eq + Copy>(
782        expr: &ScalarExpr<F>,
783    ) -> LlkvResult<Option<Literal>> {
784        let simplified = Self::simplify(expr);
785
786        if let Some(literal) = Self::evaluate_constant_literal_non_numeric(&simplified)? {
787            return Ok(Some(literal));
788        }
789
790        if let ScalarExpr::Literal(lit) = &simplified {
791            return Ok(Some(lit.clone()));
792        }
793
794        let arrays = NumericArrayMap::default();
795        let array = Self::evaluate_value(&simplified, 0, &arrays)?;
796        if array.is_null(0) {
797            return Ok(None);
798        }
799        Ok(Some(Literal::from_array_ref(&array, 0)?))
800    }
801
802    pub fn evaluate_constant_literal_non_numeric<F: Hash + Eq + Copy>(
803        expr: &ScalarExpr<F>,
804    ) -> LlkvResult<Option<Literal>> {
805        match expr {
806            ScalarExpr::Literal(lit) => Ok(Some(lit.clone())),
807            ScalarExpr::Cast {
808                expr,
809                data_type: DataType::Date32,
810            } => {
811                let inner = Self::evaluate_constant_literal_non_numeric(expr)?;
812                match inner {
813                    Some(Literal::Null) => Ok(Some(Literal::Null)),
814                    Some(Literal::String(text)) => {
815                        let days = parse_date32_literal(&text)?;
816                        Ok(Some(Literal::Date32(days)))
817                    }
818                    Some(Literal::Date32(days)) => Ok(Some(Literal::Date32(days))),
819                    Some(other) => Err(Error::InvalidArgumentError(format!(
820                        "cannot cast literal of type {} to DATE",
821                        other.type_name()
822                    ))),
823                    None => Ok(None),
824                }
825            }
826            ScalarExpr::Cast { .. } => Ok(None),
827            ScalarExpr::Binary { left, op, right } => {
828                let left_lit = match Self::evaluate_constant_literal_non_numeric(left)? {
829                    Some(lit) => lit,
830                    None => return Ok(None),
831                };
832                let right_lit = match Self::evaluate_constant_literal_non_numeric(right)? {
833                    Some(lit) => lit,
834                    None => return Ok(None),
835                };
836
837                if matches!(left_lit, Literal::Null) || matches!(right_lit, Literal::Null) {
838                    return Ok(Some(Literal::Null));
839                }
840
841                match op {
842                    BinaryOp::Add => match (&left_lit, &right_lit) {
843                        (Literal::Date32(days), Literal::Interval(interval))
844                        | (Literal::Interval(interval), Literal::Date32(days)) => {
845                            let adjusted = add_interval_to_date32(*days, *interval)?;
846                            Ok(Some(Literal::Date32(adjusted)))
847                        }
848                        (Literal::Interval(left), Literal::Interval(right)) => {
849                            let sum = left.checked_add(*right).ok_or_else(|| {
850                                Error::InvalidArgumentError(
851                                    "interval addition overflow during constant folding".into(),
852                                )
853                            })?;
854                            Ok(Some(Literal::Interval(sum)))
855                        }
856                        _ => Ok(None),
857                    },
858                    BinaryOp::Subtract => match (&left_lit, &right_lit) {
859                        (Literal::Date32(days), Literal::Interval(interval)) => {
860                            let adjusted = subtract_interval_from_date32(*days, *interval)?;
861                            Ok(Some(Literal::Date32(adjusted)))
862                        }
863                        (Literal::Interval(left), Literal::Interval(right)) => {
864                            let diff = left.checked_sub(*right).ok_or_else(|| {
865                                Error::InvalidArgumentError(
866                                    "interval subtraction overflow during constant folding".into(),
867                                )
868                            })?;
869                            Ok(Some(Literal::Interval(diff)))
870                        }
871                        (Literal::Date32(lhs), Literal::Date32(rhs)) => {
872                            let delta = i64::from(*lhs) - i64::from(*rhs);
873                            if delta < i64::from(i32::MIN) || delta > i64::from(i32::MAX) {
874                                return Err(Error::InvalidArgumentError(
875                                    "DATE subtraction overflowed day precision".into(),
876                                ));
877                            }
878                            Ok(Some(Literal::Interval(IntervalValue::new(
879                                0,
880                                delta as i32,
881                                0,
882                            ))))
883                        }
884                        _ => Ok(None),
885                    },
886                    _ => Ok(None),
887                }
888            }
889            _ => Ok(None),
890        }
891    }
892
893    // TODO: Should Decimal types be included here?
894    pub fn is_supported_numeric(dtype: &DataType) -> bool {
895        matches!(
896            dtype,
897            DataType::UInt64
898                | DataType::UInt32
899                | DataType::UInt16
900                | DataType::UInt8
901                | DataType::Int64
902                | DataType::Int32
903                | DataType::Int16
904                | DataType::Int8
905                | DataType::Float64
906                | DataType::Float32
907        )
908    }
909
910    #[allow(dead_code)]
911    fn affine_add<F: Eq + Copy>(
912        lhs: AffineState<F>,
913        rhs: AffineState<F>,
914    ) -> Option<AffineState<F>> {
915        let merged_field = Self::merge_field(lhs.field, rhs.field)?;
916        if merged_field.is_none() {
917            // Both constant
918            return Some(AffineState {
919                field: None,
920                scale: 0.0,
921                offset: lhs.offset + rhs.offset,
922            });
923        }
924        Some(AffineState {
925            field: merged_field,
926            scale: lhs.scale + rhs.scale,
927            offset: lhs.offset + rhs.offset,
928        })
929    }
930
931    #[allow(dead_code)]
932    fn affine_sub<F: Eq + Copy>(
933        lhs: AffineState<F>,
934        rhs: AffineState<F>,
935    ) -> Option<AffineState<F>> {
936        let merged_field = Self::merge_field(lhs.field, rhs.field)?;
937        if merged_field.is_none() {
938            return Some(AffineState {
939                field: None,
940                scale: 0.0,
941                offset: lhs.offset - rhs.offset,
942            });
943        }
944        Some(AffineState {
945            field: merged_field,
946            scale: lhs.scale - rhs.scale,
947            offset: lhs.offset - rhs.offset,
948        })
949    }
950
951    #[allow(dead_code)]
952    fn affine_mul<F: Eq + Copy>(
953        lhs: AffineState<F>,
954        rhs: AffineState<F>,
955    ) -> Option<AffineState<F>> {
956        if lhs.field.is_some() && rhs.field.is_some() {
957            return None; // Non-linear
958        }
959        if lhs.field.is_none() {
960            let factor = lhs.offset;
961            return Some(AffineState {
962                field: rhs.field,
963                scale: rhs.scale * factor,
964                offset: rhs.offset * factor,
965            });
966        }
967        if rhs.field.is_none() {
968            let factor = rhs.offset;
969            return Some(AffineState {
970                field: lhs.field,
971                scale: lhs.scale * factor,
972                offset: lhs.offset * factor,
973            });
974        }
975        None
976    }
977
978    #[allow(dead_code)]
979    fn affine_div<F: Eq + Copy>(
980        lhs: AffineState<F>,
981        rhs: AffineState<F>,
982    ) -> Option<AffineState<F>> {
983        if rhs.field.is_some() {
984            return None;
985        }
986        let denom = rhs.offset;
987        if denom == 0.0 {
988            return None;
989        }
990        Some(AffineState {
991            field: lhs.field,
992            scale: lhs.scale / denom,
993            offset: lhs.offset / denom,
994        })
995    }
996}
997
998fn fold_binary_literals(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
999    match op {
1000        BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
1001            let pg_op = match op {
1002                BinaryOp::BitwiseShiftLeft => BinaryOperator::PGBitwiseShiftLeft,
1003                BinaryOp::BitwiseShiftRight => BinaryOperator::PGBitwiseShiftRight,
1004                _ => unreachable!(),
1005            };
1006            crate::literal::bitshift_literals(pg_op, left, right).ok()
1007        }
1008        _ => {
1009            let l_arr = ScalarEvaluator::literal_to_array(left);
1010            let r_arr = ScalarEvaluator::literal_to_array(right);
1011            let result = compute_binary(&l_arr, &r_arr, op).ok()?;
1012            if result.is_null(0) {
1013                Some(Literal::Null)
1014            } else {
1015                Literal::from_array_ref(&result, 0).ok()
1016            }
1017        }
1018    }
1019}
1020
1021fn fold_cast_literal(lit: &Literal, data_type: &DataType) -> Option<Literal> {
1022    if matches!(lit, Literal::Null) {
1023        // Preserve explicit casts of NULL so the target type is kept.
1024        return None;
1025    }
1026    let arr = ScalarEvaluator::literal_to_array(lit);
1027    let casted = cast::cast(&arr, data_type).ok()?;
1028    if casted.is_null(0) {
1029        Some(Literal::Null)
1030    } else {
1031        Literal::from_array_ref(&casted, 0).ok()
1032    }
1033}