llkv_compute/
eval.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use arrow::array::{Array, ArrayRef, Float64Array, UInt32Array, new_null_array};
5use arrow::compute::kernels::cast;
6use arrow::compute::kernels::zip::zip;
7use arrow::compute::{concat, is_not_null, take};
8use arrow::datatypes::{DataType, Field, IntervalMonthDayNanoType};
9use llkv_expr::literal::{Literal, LiteralExt};
10use llkv_expr::{AggregateCall, BinaryOp, CompareOp, ScalarExpr};
11use llkv_result::{Error, Result as LlkvResult};
12use llkv_types::IntervalValue;
13use rustc_hash::{FxHashMap, FxHashSet};
14use sqlparser::ast::BinaryOperator;
15
16use crate::date::{add_interval_to_date32, parse_date32_literal, subtract_interval_from_date32};
17use crate::fast_numeric::NumericFastPath;
18use crate::kernels::{compute_binary, get_common_type};
19
/// Mapping from field identifiers to the numeric Arrow array used for evaluation.
///
/// Keys are field identifiers of type `F`; values are shared Arrow arrays
/// (`ArrayRef`). The map is an `FxHashMap`.
pub type NumericArrayMap<F> = FxHashMap<F, ArrayRef>;
22
/// Intermediate representation for vectorized evaluators.
///
/// Separates results that are already column-shaped from constants that still
/// have to be broadcast by `materialize`.
enum VectorizedExpr {
    /// A column-shaped result, such as an input column or a kernel output.
    Array(ArrayRef),
    /// A constant result; `materialize` repeats its element 0 across the batch.
    Scalar(ArrayRef),
}
28
29impl VectorizedExpr {
30    fn materialize(self, len: usize, target_type: DataType) -> ArrayRef {
31        match self {
32            VectorizedExpr::Array(array) => {
33                if array.data_type() == &target_type {
34                    array
35                } else {
36                    cast::cast(&array, &target_type).unwrap_or(array)
37                }
38            }
39            VectorizedExpr::Scalar(scalar_array) => {
40                if scalar_array.is_empty() {
41                    return new_null_array(&target_type, len);
42                }
43                if scalar_array.is_null(0) {
44                    return new_null_array(scalar_array.data_type(), len);
45                }
46
47                // Expand scalar to array of length len
48                let indices = UInt32Array::from(vec![0; len]);
49                take(&scalar_array, &indices, None)
50                    .unwrap_or_else(|_| new_null_array(scalar_array.data_type(), len))
51            }
52        }
53    }
54}
55
/// Extension methods for type inference on `ScalarExpr`.
pub trait ScalarExprTypeExt<F> {
    /// Infers the Arrow data type the expression evaluates to.
    ///
    /// `resolve_type` is asked for the type of each referenced column. Returns
    /// `None` when the type cannot be determined, for example when
    /// `resolve_type` does not know a column.
    fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
    where
        F: Hash + Eq + Copy,
        R: FnMut(F) -> Option<DataType>;

    /// Infers the result type using the data types of the arrays in `arrays`
    /// as the column types.
    ///
    /// The `ScalarExpr` implementation in this file falls back to `Float64`
    /// when inference yields no type.
    fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType
    where
        F: Hash + Eq + Copy;

    /// Reports whether the expression contains an interval literal.
    ///
    /// The `ScalarExpr` implementation in this file looks at literals and at
    /// the operands of binary expressions only.
    fn contains_interval(&self) -> bool;
}
69
70impl<F: Hash + Eq + Copy> ScalarExprTypeExt<F> for ScalarExpr<F> {
71    fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
72    where
73        R: FnMut(F) -> Option<DataType>,
74    {
75        match self {
76            ScalarExpr::Literal(lit) => Some(literal_type(lit)),
77            ScalarExpr::Column(fid) => resolve_type(*fid),
78            ScalarExpr::Binary { left, op, right } => {
79                let left_type = left.infer_result_type(resolve_type)?;
80                let right_type = right.infer_result_type(resolve_type)?;
81                Some(binary_result_type(*op, left_type, right_type))
82            }
83            ScalarExpr::Compare { .. } => Some(DataType::Boolean),
84            ScalarExpr::Not(_) => Some(DataType::Boolean),
85            ScalarExpr::IsNull { .. } => Some(DataType::Boolean),
86            ScalarExpr::Aggregate(call) => aggregate_result_type(call, resolve_type),
87            ScalarExpr::GetField { base, field_name } => {
88                let base_type = base.infer_result_type(resolve_type)?;
89                match base_type {
90                    DataType::Struct(fields) => fields
91                        .iter()
92                        .find(|f| f.name() == field_name)
93                        .map(|f| f.data_type().clone()),
94                    _ => None,
95                }
96            }
97            ScalarExpr::Cast { data_type, .. } => Some(data_type.clone()),
98            ScalarExpr::Case {
99                branches,
100                else_expr,
101                ..
102            } => {
103                let mut types = Vec::new();
104                for (_, then_expr) in branches {
105                    if let Some(t) = then_expr.infer_result_type(resolve_type) {
106                        types.push(t);
107                    }
108                }
109                if let Some(else_expr) = else_expr {
110                    if let Some(t) = else_expr.infer_result_type(resolve_type) {
111                        types.push(t);
112                    }
113                } else {
114                    // Implicit ELSE NULL
115                    types.push(DataType::Null);
116                }
117
118                if types.is_empty() {
119                    return None;
120                }
121
122                let mut common = types[0].clone();
123                for t in &types[1..] {
124                    common = get_common_type(&common, t);
125                }
126
127                Some(common)
128            }
129            ScalarExpr::Coalesce(items) => {
130                let mut types = Vec::new();
131                for item in items {
132                    if let Some(t) = item.infer_result_type(resolve_type) {
133                        types.push(t);
134                    }
135                }
136                if types.is_empty() {
137                    return None;
138                }
139                let mut common = types[0].clone();
140                for t in &types[1..] {
141                    common = get_common_type(&common, t);
142                }
143                Some(common)
144            }
145            ScalarExpr::Random => Some(DataType::Float64),
146            ScalarExpr::ScalarSubquery(sub) => Some(sub.data_type.clone()),
147        }
148    }
149
150    fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType {
151        let mut resolver = |fid| arrays.get(&fid).map(|a| a.data_type().clone());
152        self.infer_result_type(&mut resolver)
153            .unwrap_or(DataType::Float64)
154    }
155
156    fn contains_interval(&self) -> bool {
157        match self {
158            ScalarExpr::Literal(Literal::Interval(_)) => true,
159            ScalarExpr::Binary { left, right, .. } => {
160                left.contains_interval() || right.contains_interval()
161            }
162            _ => false,
163        }
164    }
165}
166
167fn literal_type(lit: &Literal) -> DataType {
168    match lit {
169        Literal::Null => DataType::Null,
170        Literal::Boolean(_) => DataType::Boolean,
171        Literal::Int128(_) => DataType::Int64, // Default to Int64 for literals
172        Literal::Float64(_) => DataType::Float64,
173        Literal::Decimal128(d) => DataType::Decimal128(d.precision(), d.scale()),
174        Literal::String(_) => DataType::Utf8,
175        Literal::Date32(_) => DataType::Date32,
176        Literal::Interval(_) => DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
177        Literal::Struct(fields) => {
178            let arrow_fields = fields
179                .iter()
180                .map(|(name, lit)| Field::new(name, literal_type(lit), true))
181                .collect();
182            DataType::Struct(arrow_fields)
183        }
184    }
185}
186
187fn aggregate_result_type<F, R>(call: &AggregateCall<F>, resolve_type: &mut R) -> Option<DataType>
188where
189    F: Hash + Eq + Copy,
190    R: FnMut(F) -> Option<DataType>,
191{
192    match call {
193        AggregateCall::CountStar | AggregateCall::Count { .. } | AggregateCall::CountNulls(_) => {
194            Some(DataType::Int64)
195        }
196        AggregateCall::Sum { expr, .. } => {
197            let child = expr.infer_result_type(resolve_type)?;
198            Some(match child {
199                DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
200                DataType::Float32 | DataType::Float64 => DataType::Float64,
201                DataType::UInt64 | DataType::Int64 => child,
202                DataType::UInt32
203                | DataType::UInt16
204                | DataType::UInt8
205                | DataType::Int32
206                | DataType::Int16
207                | DataType::Int8 => DataType::Int64,
208                _ => DataType::Float64,
209            })
210        }
211        AggregateCall::Total { expr, .. } | AggregateCall::Avg { expr, .. } => {
212            let child = expr.infer_result_type(resolve_type)?;
213            Some(match child {
214                DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
215                _ => DataType::Float64,
216            })
217        }
218        AggregateCall::Min(expr) | AggregateCall::Max(expr) => expr.infer_result_type(resolve_type),
219        AggregateCall::GroupConcat { .. } => Some(DataType::Utf8),
220    }
221}
222
/// Result type of applying `op` to operands of types `lhs` and `rhs`.
///
/// Thin wrapper that forwards to `crate::kernels::common_type_for_op`.
fn binary_result_type(op: BinaryOp, lhs: DataType, rhs: DataType) -> DataType {
    crate::kernels::common_type_for_op(&lhs, &rhs, op)
}
226
/// Represents an affine transformation `scale * field + offset`.
#[derive(Clone, Copy, Debug)]
pub struct AffineExpr<F> {
    /// The single field the expression depends on.
    pub field: F,
    /// Multiplier applied to the field value.
    pub scale: f64,
    /// Constant added after scaling.
    pub offset: f64,
}
234
/// Internal accumulator representing a partially merged affine expression.
///
/// Unlike `AffineExpr`, the field is optional: `affine_state` builds a literal
/// as `field: None, scale: 0.0` with the literal value as `offset`.
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
struct AffineState<F> {
    /// The field seen so far, or `None` for a pure constant.
    field: Option<F>,
    /// Multiplier applied to the field value.
    scale: f64,
    /// Constant term.
    offset: f64,
}
243
/// Centralizes the numeric kernels applied to scalar expressions so they can be
/// tuned without touching the surrounding table scan logic.
///
/// This is a unit struct: it holds no state and serves as a namespace for
/// associated functions.
pub struct ScalarEvaluator;
247
248impl ScalarEvaluator {
249    /// Combine field identifiers while tracking whether multiple fields were encountered.
250    #[allow(dead_code)]
251    fn merge_field<F: Eq + Copy>(lhs: Option<F>, rhs: Option<F>) -> Option<Option<F>> {
252        match (lhs, rhs) {
253            (Some(a), Some(b)) => {
254                if a == b {
255                    Some(Some(a))
256                } else {
257                    None
258                }
259            }
260            (Some(a), None) => Some(Some(a)),
261            (None, Some(b)) => Some(Some(b)),
262            (None, None) => Some(None),
263        }
264    }
265
266    /// Collect every field referenced by the scalar expression into `acc`.
267    pub fn collect_fields<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>, acc: &mut FxHashSet<F>) {
268        match expr {
269            ScalarExpr::Column(fid) => {
270                acc.insert(*fid);
271            }
272            ScalarExpr::Literal(_) => {}
273            ScalarExpr::Binary { left, right, .. } => {
274                Self::collect_fields(left, acc);
275                Self::collect_fields(right, acc);
276            }
277            ScalarExpr::Compare { left, right, .. } => {
278                Self::collect_fields(left, acc);
279                Self::collect_fields(right, acc);
280            }
281            ScalarExpr::Not(inner) => {
282                Self::collect_fields(inner, acc);
283            }
284            ScalarExpr::IsNull { expr, .. } => {
285                Self::collect_fields(expr, acc);
286            }
287            ScalarExpr::Aggregate(agg) => {
288                // Collect fields referenced by the aggregate expression
289                match agg {
290                    AggregateCall::CountStar => {}
291                    AggregateCall::Count { expr, .. }
292                    | AggregateCall::Sum { expr, .. }
293                    | AggregateCall::Total { expr, .. }
294                    | AggregateCall::Avg { expr, .. }
295                    | AggregateCall::Min(expr)
296                    | AggregateCall::Max(expr)
297                    | AggregateCall::CountNulls(expr)
298                    | AggregateCall::GroupConcat { expr, .. } => {
299                        Self::collect_fields(expr, acc);
300                    }
301                }
302            }
303            ScalarExpr::GetField { base, .. } => {
304                // Collect fields from the base expression
305                Self::collect_fields(base, acc);
306            }
307            ScalarExpr::Cast { expr, .. } => {
308                Self::collect_fields(expr, acc);
309            }
310            ScalarExpr::Case {
311                operand,
312                branches,
313                else_expr,
314            } => {
315                if let Some(inner) = operand.as_deref() {
316                    Self::collect_fields(inner, acc);
317                }
318                for (when_expr, then_expr) in branches {
319                    Self::collect_fields(when_expr, acc);
320                    Self::collect_fields(then_expr, acc);
321                }
322                if let Some(inner) = else_expr.as_deref() {
323                    Self::collect_fields(inner, acc);
324                }
325            }
326            ScalarExpr::Coalesce(items) => {
327                for item in items {
328                    Self::collect_fields(item, acc);
329                }
330            }
331            ScalarExpr::Random => {
332                // Random does not reference any fields
333            }
334            ScalarExpr::ScalarSubquery(_) => {
335                // Scalar subqueries don't directly reference fields from the outer query
336            }
337        }
338    }
339
    /// Builds the array map used for evaluation from `arrays`.
    ///
    /// Currently returns a clone of the input map; `_row_count` is unused.
    pub fn prepare_numeric_arrays<F: Hash + Eq + Copy>(
        arrays: &FxHashMap<F, ArrayRef>,
        _row_count: usize,
    ) -> NumericArrayMap<F> {
        arrays.clone()
    }
346
347    /// Attempts to represent the expression as `scale * column + offset`.
348    /// Returns `None` when the expression depends on multiple columns or is non-linear.
349    pub fn extract_affine<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineExpr<F>> {
350        let simplified = Self::simplify(expr);
351        Self::extract_affine_simplified(&simplified)
352    }
353
354    /// Variant of `extract_affine` that assumes `expr` is already simplified.
355    pub fn extract_affine_simplified<F: Hash + Eq + Copy>(
356        expr: &ScalarExpr<F>,
357    ) -> Option<AffineExpr<F>> {
358        let state = Self::affine_state(expr)?;
359        let field = state.field?;
360        Some(AffineExpr {
361            field,
362            scale: state.scale,
363            offset: state.offset,
364        })
365    }
366
367    fn affine_state<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineState<F>> {
368        match expr {
369            ScalarExpr::Column(fid) => Some(AffineState {
370                field: Some(*fid),
371                scale: 1.0,
372                offset: 0.0,
373            }),
374            ScalarExpr::Literal(lit) => {
375                let arr = Self::literal_to_array(lit);
376                let val = cast::cast(&arr, &DataType::Float64).ok()?;
377                let val = val.as_any().downcast_ref::<Float64Array>()?;
378                if val.is_null(0) {
379                    return None;
380                }
381                Some(AffineState {
382                    field: None,
383                    scale: 0.0,
384                    offset: val.value(0),
385                })
386            }
387            ScalarExpr::Aggregate(_) => None,
388            ScalarExpr::GetField { .. } => None,
389            ScalarExpr::Binary { left, op, right } => {
390                let left_state = Self::affine_state(left)?;
391                let right_state = Self::affine_state(right)?;
392                match op {
393                    BinaryOp::Add => Self::affine_add(left_state, right_state),
394                    BinaryOp::Subtract => Self::affine_sub(left_state, right_state),
395                    BinaryOp::Multiply => Self::affine_mul(left_state, right_state),
396                    BinaryOp::Divide => Self::affine_div(left_state, right_state),
397                    _ => None,
398                }
399            }
400            ScalarExpr::Cast { expr, .. } => Self::affine_state(expr),
401            _ => None,
402        }
403    }
404
405    /// Evaluate a scalar expression for the row at `idx` using the provided numeric arrays.
406    pub fn evaluate_value<F: Hash + Eq + Copy>(
407        expr: &ScalarExpr<F>,
408        idx: usize,
409        arrays: &NumericArrayMap<F>,
410    ) -> LlkvResult<ArrayRef> {
411        match expr {
412            ScalarExpr::Column(fid) => {
413                let array = arrays
414                    .get(fid)
415                    .ok_or_else(|| Error::Internal("missing column for field".into()))?;
416                Ok(array.slice(idx, 1))
417            }
418            ScalarExpr::Literal(lit) => Ok(Self::literal_to_array(lit)),
419            ScalarExpr::Binary { left, op, right } => {
420                let l = Self::evaluate_value(left, idx, arrays)?;
421                let r = Self::evaluate_value(right, idx, arrays)?;
422                Self::evaluate_binary_scalar(&l, *op, &r)
423            }
424            ScalarExpr::Compare { left, op, right } => {
425                let l = Self::evaluate_value(left, idx, arrays)?;
426                let r = Self::evaluate_value(right, idx, arrays)?;
427                crate::kernels::compute_compare(&l, *op, &r)
428            }
429            ScalarExpr::Not(expr) => {
430                let val = Self::evaluate_value(expr, idx, arrays)?;
431                let bool_arr = cast::cast(&val, &DataType::Boolean)
432                    .map_err(|e| Error::Internal(e.to_string()))?;
433                let bool_arr = bool_arr
434                    .as_any()
435                    .downcast_ref::<arrow::array::BooleanArray>()
436                    .unwrap();
437                let result = arrow::compute::kernels::boolean::not(bool_arr)
438                    .map_err(|e| Error::Internal(e.to_string()))?;
439                Ok(Arc::new(result))
440            }
441            ScalarExpr::IsNull { expr, negated } => {
442                let val = Self::evaluate_value(expr, idx, arrays)?;
443                let is_null = val.is_null(0);
444                let result = if *negated { !is_null } else { is_null };
445                Ok(Arc::new(arrow::array::BooleanArray::from(vec![result])))
446            }
447            ScalarExpr::Cast { expr, data_type } => {
448                let val = Self::evaluate_value(expr, idx, arrays)?;
449                cast::cast(&val, data_type).map_err(|e| Error::Internal(e.to_string()))
450            }
451            ScalarExpr::Case {
452                operand,
453                branches,
454                else_expr,
455            } => {
456                let operand_val = if let Some(op) = operand {
457                    Some(Self::evaluate_value(op, idx, arrays)?)
458                } else {
459                    None
460                };
461
462                for (when_expr, then_expr) in branches {
463                    let when_val = Self::evaluate_value(when_expr, idx, arrays)?;
464
465                    let is_match = if let Some(op_val) = &operand_val {
466                        // Simple CASE: operand = when_val
467                        // If either is null, result is null (false for condition)
468                        if op_val.is_null(0) || when_val.is_null(0) {
469                            false
470                        } else {
471                            let eq =
472                                crate::kernels::compute_compare(op_val, CompareOp::Eq, &when_val)?;
473                            let bool_arr = eq
474                                .as_any()
475                                .downcast_ref::<arrow::array::BooleanArray>()
476                                .unwrap();
477                            bool_arr.value(0)
478                        }
479                    } else {
480                        // Searched CASE: when_val is boolean condition
481                        if when_val.is_null(0) {
482                            false
483                        } else {
484                            let bool_arr = cast::cast(&when_val, &DataType::Boolean)
485                                .map_err(|e| Error::Internal(e.to_string()))?;
486                            let bool_arr = bool_arr
487                                .as_any()
488                                .downcast_ref::<arrow::array::BooleanArray>()
489                                .unwrap();
490                            bool_arr.value(0)
491                        }
492                    };
493
494                    if is_match {
495                        return Self::evaluate_value(then_expr, idx, arrays);
496                    }
497                }
498                if let Some(else_expr) = else_expr {
499                    Self::evaluate_value(else_expr, idx, arrays)
500                } else {
501                    Ok(new_null_array(&DataType::Null, 1))
502                }
503            }
504            ScalarExpr::Coalesce(items) => {
505                for item in items {
506                    let val = Self::evaluate_value(item, idx, arrays)?;
507                    if !val.is_null(0) && val.data_type() != &DataType::Null {
508                        return Ok(val);
509                    }
510                }
511                Ok(new_null_array(&DataType::Null, 1))
512            }
513            ScalarExpr::Random => {
514                let val = rand::random::<f64>();
515                Ok(Arc::new(Float64Array::from(vec![val])))
516            }
517            _ => Err(Error::Internal("Unsupported scalar expression".into())),
518        }
519    }
520
521    fn literal_to_array(lit: &Literal) -> ArrayRef {
522        match lit {
523            Literal::Null => new_null_array(&DataType::Null, 1),
524            Literal::Boolean(b) => Arc::new(arrow::array::BooleanArray::from(vec![*b])),
525            Literal::Int128(i) => Arc::new(arrow::array::Int64Array::from(vec![*i as i64])),
526            Literal::Float64(f) => Arc::new(Float64Array::from(vec![*f])),
527            Literal::Decimal128(d) => {
528                let array = arrow::array::Decimal128Array::from(vec![Some(d.raw_value())])
529                    .with_precision_and_scale(d.precision(), d.scale())
530                    .unwrap();
531                Arc::new(array)
532            }
533            Literal::String(s) => Arc::new(arrow::array::StringArray::from(vec![s.clone()])),
534            Literal::Date32(d) => Arc::new(arrow::array::Date32Array::from(vec![*d])),
535            Literal::Interval(i) => {
536                let val = IntervalMonthDayNanoType::make_value(i.months, i.days, i.nanos);
537                Arc::new(arrow::array::IntervalMonthDayNanoArray::from(vec![val]))
538            }
539            Literal::Struct(_) => {
540                new_null_array(&DataType::Struct(arrow::datatypes::Fields::empty()), 1)
541            }
542        }
543    }
544
    /// Applies the binary operator `op` to `lhs` and `rhs`.
    ///
    /// Thin wrapper over `compute_binary`; note that `compute_binary` takes the
    /// operator last.
    fn evaluate_binary_scalar(
        lhs: &ArrayRef,
        op: BinaryOp,
        rhs: &ArrayRef,
    ) -> LlkvResult<ArrayRef> {
        compute_binary(lhs, rhs, op)
    }
552
553    /// Evaluate a scalar expression for every row in the batch.
554    #[allow(dead_code)]
555    pub fn evaluate_batch<F: Hash + Eq + Copy + std::fmt::Debug>(
556        expr: &ScalarExpr<F>,
557        len: usize,
558        arrays: &NumericArrayMap<F>,
559    ) -> LlkvResult<ArrayRef> {
560        let simplified = Self::simplify(expr);
561        Self::evaluate_batch_simplified(&simplified, len, arrays)
562    }
563
564    /// Evaluate a scalar expression that has already been simplified.
565    pub fn evaluate_batch_simplified<F: Hash + Eq + Copy>(
566        expr: &ScalarExpr<F>,
567        len: usize,
568        arrays: &NumericArrayMap<F>,
569    ) -> LlkvResult<ArrayRef> {
570        let preferred = expr.infer_result_type_from_arrays(arrays);
571
572        if len == 0 {
573            return Ok(new_null_array(&preferred, 0));
574        }
575
576        if let Some(fast_path) = NumericFastPath::compile(expr, arrays, &preferred) {
577            let fast_result = fast_path.execute(len, arrays)?;
578            if fast_result.data_type() != &preferred {
579                let casted = cast::cast(&fast_result, &preferred).map_err(|e| {
580                    Error::Internal(format!("Failed to cast fast path result: {}", e))
581                })?;
582                return Ok(casted);
583            }
584            return Ok(fast_result);
585        }
586
587        if let Some(vectorized) =
588            Self::try_evaluate_vectorized(expr, len, arrays, preferred.clone())?
589        {
590            let result = vectorized.materialize(len, preferred);
591            return Ok(result);
592        }
593
594        let mut values = Vec::with_capacity(len);
595        for idx in 0..len {
596            let val = Self::evaluate_value(expr, idx, arrays)?;
597            if val.data_type() != &preferred {
598                let casted = cast::cast(&val, &preferred).map_err(|e| {
599                    Error::Internal(format!(
600                        "Failed to cast row {}: {} (Val type: {:?}, Preferred: {:?})",
601                        idx,
602                        e,
603                        val.data_type(),
604                        preferred
605                    ))
606                })?;
607                values.push(casted);
608            } else {
609                values.push(val);
610            }
611        }
612        concat(&values.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
613            .map_err(|e| Error::Internal(e.to_string()))
614    }
615
616    fn try_evaluate_vectorized<F: Hash + Eq + Copy>(
617        expr: &ScalarExpr<F>,
618        len: usize,
619        arrays: &NumericArrayMap<F>,
620        _target_type: DataType,
621    ) -> LlkvResult<Option<VectorizedExpr>> {
622        if expr.contains_interval() {
623            return Ok(None);
624        }
625        match expr {
626            ScalarExpr::Column(fid) => {
627                let array = arrays
628                    .get(fid)
629                    .ok_or_else(|| Error::Internal("missing column for field".into()))?;
630                Ok(Some(VectorizedExpr::Array(array.clone())))
631            }
632            ScalarExpr::Literal(lit) => {
633                let array = Self::literal_to_array(lit);
634                Ok(Some(VectorizedExpr::Scalar(array)))
635            }
636            ScalarExpr::Binary { left, op, right } => {
637                let left_type = left.infer_result_type_from_arrays(arrays);
638                let right_type = right.infer_result_type_from_arrays(arrays);
639
640                let left_vec = Self::try_evaluate_vectorized(left, len, arrays, left_type)?;
641                let right_vec = Self::try_evaluate_vectorized(right, len, arrays, right_type)?;
642
643                match (left_vec, right_vec) {
644                    (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
645                        let result = compute_binary(&lhs, &rhs, *op)?;
646                        Ok(Some(VectorizedExpr::Scalar(result)))
647                    }
648                    (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Array(rhs))) => {
649                        let array = compute_binary(&lhs, &rhs, *op)?;
650                        Ok(Some(VectorizedExpr::Array(array)))
651                    }
652                    (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
653                        let rhs_expanded = VectorizedExpr::Scalar(rhs)
654                            .materialize(lhs.len(), lhs.data_type().clone());
655                        let array = compute_binary(&lhs, &rhs_expanded, *op)?;
656                        Ok(Some(VectorizedExpr::Array(array)))
657                    }
658                    (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Array(rhs))) => {
659                        let lhs_expanded = VectorizedExpr::Scalar(lhs)
660                            .materialize(rhs.len(), rhs.data_type().clone());
661                        let array = compute_binary(&lhs_expanded, &rhs, *op)?;
662                        Ok(Some(VectorizedExpr::Array(array)))
663                    }
664                    _ => Ok(None),
665                }
666            }
667            ScalarExpr::Cast { expr, data_type } => {
668                let inner_type = expr.infer_result_type_from_arrays(arrays);
669                let inner_vec = Self::try_evaluate_vectorized(expr, len, arrays, inner_type)?;
670
671                match inner_vec {
672                    Some(VectorizedExpr::Scalar(array)) => {
673                        let casted = cast::cast(&array, data_type)
674                            .map_err(|e| Error::Internal(e.to_string()))?;
675                        Ok(Some(VectorizedExpr::Scalar(casted)))
676                    }
677                    Some(VectorizedExpr::Array(array)) => {
678                        let casted = cast::cast(&array, data_type)
679                            .map_err(|e| Error::Internal(e.to_string()))?;
680                        Ok(Some(VectorizedExpr::Array(casted)))
681                    }
682                    None => Ok(None),
683                }
684            }
685            ScalarExpr::Coalesce(items) => {
686                let mut evaluated_items = Vec::with_capacity(items.len());
687                let mut types = Vec::with_capacity(items.len());
688
689                for item in items {
690                    let item_type = item.infer_result_type_from_arrays(arrays);
691                    // If any item cannot be vectorized, we cannot vectorize the whole Coalesce
692                    let vec_expr = match Self::try_evaluate_vectorized(
693                        item,
694                        len,
695                        arrays,
696                        item_type.clone(),
697                    )? {
698                        Some(v) => v,
699                        None => return Ok(None),
700                    };
701
702                    let array = vec_expr.materialize(len, item_type.clone());
703                    types.push(array.data_type().clone());
704                    evaluated_items.push(array);
705                }
706
707                if evaluated_items.is_empty() {
708                    return Ok(Some(VectorizedExpr::Array(new_null_array(
709                        &DataType::Null,
710                        len,
711                    ))));
712                }
713
714                // Determine common type
715                let mut common_type = types[0].clone();
716                for t in &types[1..] {
717                    common_type = get_common_type(&common_type, t);
718                }
719
720                // Cast all arrays to common type
721                let mut casted_arrays = Vec::with_capacity(evaluated_items.len());
722                for array in evaluated_items {
723                    if array.data_type() != &common_type {
724                        let casted = cast::cast(&array, &common_type)
725                            .map_err(|e| Error::Internal(e.to_string()))?;
726                        casted_arrays.push(casted);
727                    } else {
728                        casted_arrays.push(array);
729                    }
730                }
731
732                let mut result = casted_arrays[0].clone();
733                for next_array in &casted_arrays[1..] {
734                    let mask = is_not_null(&result).map_err(|e| Error::Internal(e.to_string()))?;
735                    // result = zip(mask, result, next_array)
736                    // if mask is true (result is not null), keep result.
737                    // if mask is false (result is null), take next_array.
738                    result = zip(&mask, &result, next_array)
739                        .map_err(|e| Error::Internal(e.to_string()))?;
740                }
741                Ok(Some(VectorizedExpr::Array(result)))
742            }
743            ScalarExpr::Random => {
744                let values: Vec<f64> = (0..len).map(|_| rand::random::<f64>()).collect();
745                let array = Float64Array::from(values);
746                Ok(Some(VectorizedExpr::Array(Arc::new(array))))
747            }
748            _ => Ok(None),
749        }
750    }
751
752    /// Returns the column referenced by an expression when it's a direct or additive identity passthrough.
753    pub fn passthrough_column<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<F> {
754        match Self::simplify(expr) {
755            ScalarExpr::Column(fid) => Some(fid),
756            _ => None,
757        }
758    }
759
760    /// Simplify an expression by constant folding and identity removal.
761    pub fn simplify<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> ScalarExpr<F> {
762        match expr {
763            ScalarExpr::Binary { left, op, right } => {
764                let l = Self::simplify(left);
765                let r = Self::simplify(right);
766                if let (ScalarExpr::Literal(ll), ScalarExpr::Literal(rr)) = (&l, &r)
767                    && let Some(folded) = fold_binary_literals(*op, ll, rr)
768                {
769                    return ScalarExpr::Literal(folded);
770                }
771                ScalarExpr::Binary {
772                    left: Box::new(l),
773                    op: *op,
774                    right: Box::new(r),
775                }
776            }
777            ScalarExpr::Cast { expr, data_type } => {
778                let inner = Self::simplify(expr);
779                if let ScalarExpr::Literal(lit) = &inner
780                    && let Some(folded) = fold_cast_literal(lit, data_type)
781                {
782                    return ScalarExpr::Literal(folded);
783                }
784                ScalarExpr::Cast {
785                    expr: Box::new(inner),
786                    data_type: data_type.clone(),
787                }
788            }
789            _ => expr.clone(),
790        }
791    }
792
793    pub fn evaluate_constant_literal_expr<F: Hash + Eq + Copy>(
794        expr: &ScalarExpr<F>,
795    ) -> LlkvResult<Option<Literal>> {
796        let simplified = Self::simplify(expr);
797
798        if let Some(literal) = Self::evaluate_constant_literal_non_numeric(&simplified)? {
799            return Ok(Some(literal));
800        }
801
802        if let ScalarExpr::Literal(lit) = &simplified {
803            return Ok(Some(lit.clone()));
804        }
805
806        let arrays = NumericArrayMap::default();
807        let array = Self::evaluate_value(&simplified, 0, &arrays)?;
808        if array.is_null(0) {
809            return Ok(None);
810        }
811        Ok(Some(Literal::from_array_ref(&array, 0)?))
812    }
813
814    pub fn evaluate_constant_literal_non_numeric<F: Hash + Eq + Copy>(
815        expr: &ScalarExpr<F>,
816    ) -> LlkvResult<Option<Literal>> {
817        match expr {
818            ScalarExpr::Literal(lit) => Ok(Some(lit.clone())),
819            ScalarExpr::Cast {
820                expr,
821                data_type: DataType::Date32,
822            } => {
823                let inner = Self::evaluate_constant_literal_non_numeric(expr)?;
824                match inner {
825                    Some(Literal::Null) => Ok(Some(Literal::Null)),
826                    Some(Literal::String(text)) => {
827                        let days = parse_date32_literal(&text)?;
828                        Ok(Some(Literal::Date32(days)))
829                    }
830                    Some(Literal::Date32(days)) => Ok(Some(Literal::Date32(days))),
831                    Some(other) => Err(Error::InvalidArgumentError(format!(
832                        "cannot cast literal of type {} to DATE",
833                        other.type_name()
834                    ))),
835                    None => Ok(None),
836                }
837            }
838            ScalarExpr::Cast { .. } => Ok(None),
839            ScalarExpr::Binary { left, op, right } => {
840                let left_lit = match Self::evaluate_constant_literal_non_numeric(left)? {
841                    Some(lit) => lit,
842                    None => return Ok(None),
843                };
844                let right_lit = match Self::evaluate_constant_literal_non_numeric(right)? {
845                    Some(lit) => lit,
846                    None => return Ok(None),
847                };
848
849                if matches!(left_lit, Literal::Null) || matches!(right_lit, Literal::Null) {
850                    return Ok(Some(Literal::Null));
851                }
852
853                match op {
854                    BinaryOp::Add => match (&left_lit, &right_lit) {
855                        (Literal::Date32(days), Literal::Interval(interval))
856                        | (Literal::Interval(interval), Literal::Date32(days)) => {
857                            let adjusted = add_interval_to_date32(*days, *interval)?;
858                            Ok(Some(Literal::Date32(adjusted)))
859                        }
860                        (Literal::Interval(left), Literal::Interval(right)) => {
861                            let sum = left.checked_add(*right).ok_or_else(|| {
862                                Error::InvalidArgumentError(
863                                    "interval addition overflow during constant folding".into(),
864                                )
865                            })?;
866                            Ok(Some(Literal::Interval(sum)))
867                        }
868                        _ => Ok(None),
869                    },
870                    BinaryOp::Subtract => match (&left_lit, &right_lit) {
871                        (Literal::Date32(days), Literal::Interval(interval)) => {
872                            let adjusted = subtract_interval_from_date32(*days, *interval)?;
873                            Ok(Some(Literal::Date32(adjusted)))
874                        }
875                        (Literal::Interval(left), Literal::Interval(right)) => {
876                            let diff = left.checked_sub(*right).ok_or_else(|| {
877                                Error::InvalidArgumentError(
878                                    "interval subtraction overflow during constant folding".into(),
879                                )
880                            })?;
881                            Ok(Some(Literal::Interval(diff)))
882                        }
883                        (Literal::Date32(lhs), Literal::Date32(rhs)) => {
884                            let delta = i64::from(*lhs) - i64::from(*rhs);
885                            if delta < i64::from(i32::MIN) || delta > i64::from(i32::MAX) {
886                                return Err(Error::InvalidArgumentError(
887                                    "DATE subtraction overflowed day precision".into(),
888                                ));
889                            }
890                            Ok(Some(Literal::Interval(IntervalValue::new(
891                                0,
892                                delta as i32,
893                                0,
894                            ))))
895                        }
896                        _ => Ok(None),
897                    },
898                    _ => Ok(None),
899                }
900            }
901            _ => Ok(None),
902        }
903    }
904
905    // TODO: Should Decimal types be included here?
906    pub fn is_supported_numeric(dtype: &DataType) -> bool {
907        matches!(
908            dtype,
909            DataType::UInt64
910                | DataType::UInt32
911                | DataType::UInt16
912                | DataType::UInt8
913                | DataType::Int64
914                | DataType::Int32
915                | DataType::Int16
916                | DataType::Int8
917                | DataType::Float64
918                | DataType::Float32
919        )
920    }
921
922    #[allow(dead_code)]
923    fn affine_add<F: Eq + Copy>(
924        lhs: AffineState<F>,
925        rhs: AffineState<F>,
926    ) -> Option<AffineState<F>> {
927        let merged_field = Self::merge_field(lhs.field, rhs.field)?;
928        if merged_field.is_none() {
929            // Both constant
930            return Some(AffineState {
931                field: None,
932                scale: 0.0,
933                offset: lhs.offset + rhs.offset,
934            });
935        }
936        Some(AffineState {
937            field: merged_field,
938            scale: lhs.scale + rhs.scale,
939            offset: lhs.offset + rhs.offset,
940        })
941    }
942
943    #[allow(dead_code)]
944    fn affine_sub<F: Eq + Copy>(
945        lhs: AffineState<F>,
946        rhs: AffineState<F>,
947    ) -> Option<AffineState<F>> {
948        let merged_field = Self::merge_field(lhs.field, rhs.field)?;
949        if merged_field.is_none() {
950            return Some(AffineState {
951                field: None,
952                scale: 0.0,
953                offset: lhs.offset - rhs.offset,
954            });
955        }
956        Some(AffineState {
957            field: merged_field,
958            scale: lhs.scale - rhs.scale,
959            offset: lhs.offset - rhs.offset,
960        })
961    }
962
963    #[allow(dead_code)]
964    fn affine_mul<F: Eq + Copy>(
965        lhs: AffineState<F>,
966        rhs: AffineState<F>,
967    ) -> Option<AffineState<F>> {
968        if lhs.field.is_some() && rhs.field.is_some() {
969            return None; // Non-linear
970        }
971        if lhs.field.is_none() {
972            let factor = lhs.offset;
973            return Some(AffineState {
974                field: rhs.field,
975                scale: rhs.scale * factor,
976                offset: rhs.offset * factor,
977            });
978        }
979        if rhs.field.is_none() {
980            let factor = rhs.offset;
981            return Some(AffineState {
982                field: lhs.field,
983                scale: lhs.scale * factor,
984                offset: lhs.offset * factor,
985            });
986        }
987        None
988    }
989
990    #[allow(dead_code)]
991    fn affine_div<F: Eq + Copy>(
992        lhs: AffineState<F>,
993        rhs: AffineState<F>,
994    ) -> Option<AffineState<F>> {
995        if rhs.field.is_some() {
996            return None;
997        }
998        let denom = rhs.offset;
999        if denom == 0.0 {
1000            return None;
1001        }
1002        Some(AffineState {
1003            field: lhs.field,
1004            scale: lhs.scale / denom,
1005            offset: lhs.offset / denom,
1006        })
1007    }
1008}
1009
1010fn fold_binary_literals(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
1011    match op {
1012        BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
1013            let pg_op = match op {
1014                BinaryOp::BitwiseShiftLeft => BinaryOperator::PGBitwiseShiftLeft,
1015                BinaryOp::BitwiseShiftRight => BinaryOperator::PGBitwiseShiftRight,
1016                _ => unreachable!(),
1017            };
1018            crate::literal::bitshift_literals(pg_op, left, right).ok()
1019        }
1020        _ => {
1021            let l_arr = ScalarEvaluator::literal_to_array(left);
1022            let r_arr = ScalarEvaluator::literal_to_array(right);
1023            let result = compute_binary(&l_arr, &r_arr, op).ok()?;
1024            if result.is_null(0) {
1025                Some(Literal::Null)
1026            } else {
1027                Literal::from_array_ref(&result, 0).ok()
1028            }
1029        }
1030    }
1031}
1032
1033fn fold_cast_literal(lit: &Literal, data_type: &DataType) -> Option<Literal> {
1034    if matches!(lit, Literal::Null) {
1035        // Preserve explicit casts of NULL so the target type is kept.
1036        return None;
1037    }
1038    let arr = ScalarEvaluator::literal_to_array(lit);
1039    let casted = cast::cast(&arr, data_type).ok()?;
1040    if casted.is_null(0) {
1041        Some(Literal::Null)
1042    } else {
1043        Literal::from_array_ref(&casted, 0).ok()
1044    }
1045}