1use std::hash::Hash;
2use std::sync::Arc;
3
4use arrow::array::{Array, ArrayRef, Float64Array, UInt32Array, new_null_array};
5use arrow::compute::kernels::cast;
6use arrow::compute::kernels::zip::zip;
7use arrow::compute::{concat, is_not_null, take};
8use arrow::datatypes::{DataType, Field, IntervalMonthDayNanoType};
9use llkv_expr::literal::{Literal, LiteralExt};
10use llkv_expr::{AggregateCall, BinaryOp, CompareOp, ScalarExpr};
11use llkv_result::{Error, Result as LlkvResult};
12use llkv_types::IntervalValue;
13use rustc_hash::{FxHashMap, FxHashSet};
14use sqlparser::ast::BinaryOperator;
15
16use crate::date::{add_interval_to_date32, parse_date32_literal, subtract_interval_from_date32};
17use crate::fast_numeric::NumericFastPath;
18use crate::kernels::{compute_binary, get_common_type};
19
/// Map from a field identifier `F` to the Arrow array holding that field's values.
///
/// Backed by `FxHashMap`; the evaluator looks columns up in this map by field id.
pub type NumericArrayMap<F> = FxHashMap<F, ArrayRef>;
22
23enum VectorizedExpr {
25 Array(ArrayRef),
26 Scalar(ArrayRef),
27}
28
29impl VectorizedExpr {
30 fn materialize(self, len: usize, target_type: DataType) -> ArrayRef {
31 match self {
32 VectorizedExpr::Array(array) => {
33 if array.data_type() == &target_type {
34 array
35 } else {
36 cast::cast(&array, &target_type).unwrap_or(array)
37 }
38 }
39 VectorizedExpr::Scalar(scalar_array) => {
40 if scalar_array.is_empty() {
41 return new_null_array(&target_type, len);
42 }
43 if scalar_array.is_null(0) {
44 return new_null_array(scalar_array.data_type(), len);
45 }
46
47 let indices = UInt32Array::from(vec![0; len]);
49 take(&scalar_array, &indices, None)
50 .unwrap_or_else(|_| new_null_array(scalar_array.data_type(), len))
51 }
52 }
53 }
54}
55
/// Type-inference helpers for scalar expressions over field identifiers `F`.
pub trait ScalarExprTypeExt<F> {
    /// Infers the Arrow data type this expression produces.
    ///
    /// `resolve_type` maps a field identifier to its data type. Returns `None`
    /// when the type cannot be determined.
    fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
    where
        F: Hash + Eq + Copy,
        R: FnMut(F) -> Option<DataType>;

    /// Infers the result type using the data types of the arrays in `arrays`
    /// to resolve column references.
    fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType
    where
        F: Hash + Eq + Copy;

    /// Reports whether the expression involves an interval literal.
    fn contains_interval(&self) -> bool;
}
69
70impl<F: Hash + Eq + Copy> ScalarExprTypeExt<F> for ScalarExpr<F> {
71 fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
72 where
73 R: FnMut(F) -> Option<DataType>,
74 {
75 match self {
76 ScalarExpr::Literal(lit) => Some(literal_type(lit)),
77 ScalarExpr::Column(fid) => resolve_type(*fid),
78 ScalarExpr::Binary { left, op, right } => {
79 let left_type = left.infer_result_type(resolve_type)?;
80 let right_type = right.infer_result_type(resolve_type)?;
81 Some(binary_result_type(*op, left_type, right_type))
82 }
83 ScalarExpr::Compare { .. } => Some(DataType::Boolean),
84 ScalarExpr::Not(_) => Some(DataType::Boolean),
85 ScalarExpr::IsNull { .. } => Some(DataType::Boolean),
86 ScalarExpr::Aggregate(call) => aggregate_result_type(call, resolve_type),
87 ScalarExpr::GetField { base, field_name } => {
88 let base_type = base.infer_result_type(resolve_type)?;
89 match base_type {
90 DataType::Struct(fields) => fields
91 .iter()
92 .find(|f| f.name() == field_name)
93 .map(|f| f.data_type().clone()),
94 _ => None,
95 }
96 }
97 ScalarExpr::Cast { data_type, .. } => Some(data_type.clone()),
98 ScalarExpr::Case {
99 branches,
100 else_expr,
101 ..
102 } => {
103 let mut types = Vec::new();
104 for (_, then_expr) in branches {
105 if let Some(t) = then_expr.infer_result_type(resolve_type) {
106 types.push(t);
107 }
108 }
109 if let Some(else_expr) = else_expr {
110 if let Some(t) = else_expr.infer_result_type(resolve_type) {
111 types.push(t);
112 }
113 } else {
114 types.push(DataType::Null);
116 }
117
118 if types.is_empty() {
119 return None;
120 }
121
122 let mut common = types[0].clone();
123 for t in &types[1..] {
124 common = get_common_type(&common, t);
125 }
126
127 Some(common)
128 }
129 ScalarExpr::Coalesce(items) => {
130 let mut types = Vec::new();
131 for item in items {
132 if let Some(t) = item.infer_result_type(resolve_type) {
133 types.push(t);
134 }
135 }
136 if types.is_empty() {
137 return None;
138 }
139 let mut common = types[0].clone();
140 for t in &types[1..] {
141 common = get_common_type(&common, t);
142 }
143 Some(common)
144 }
145 ScalarExpr::Random => Some(DataType::Float64),
146 ScalarExpr::ScalarSubquery(sub) => Some(sub.data_type.clone()),
147 }
148 }
149
150 fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType {
151 let mut resolver = |fid| arrays.get(&fid).map(|a| a.data_type().clone());
152 self.infer_result_type(&mut resolver)
153 .unwrap_or(DataType::Float64)
154 }
155
156 fn contains_interval(&self) -> bool {
157 match self {
158 ScalarExpr::Literal(Literal::Interval(_)) => true,
159 ScalarExpr::Binary { left, right, .. } => {
160 left.contains_interval() || right.contains_interval()
161 }
162 _ => false,
163 }
164 }
165}
166
167fn literal_type(lit: &Literal) -> DataType {
168 match lit {
169 Literal::Null => DataType::Null,
170 Literal::Boolean(_) => DataType::Boolean,
171 Literal::Int128(_) => DataType::Int64, Literal::Float64(_) => DataType::Float64,
173 Literal::Decimal128(d) => DataType::Decimal128(d.precision(), d.scale()),
174 Literal::String(_) => DataType::Utf8,
175 Literal::Date32(_) => DataType::Date32,
176 Literal::Interval(_) => DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
177 Literal::Struct(fields) => {
178 let arrow_fields = fields
179 .iter()
180 .map(|(name, lit)| Field::new(name, literal_type(lit), true))
181 .collect();
182 DataType::Struct(arrow_fields)
183 }
184 }
185}
186
187fn aggregate_result_type<F, R>(call: &AggregateCall<F>, resolve_type: &mut R) -> Option<DataType>
188where
189 F: Hash + Eq + Copy,
190 R: FnMut(F) -> Option<DataType>,
191{
192 match call {
193 AggregateCall::CountStar | AggregateCall::Count { .. } | AggregateCall::CountNulls(_) => {
194 Some(DataType::Int64)
195 }
196 AggregateCall::Sum { expr, .. } => {
197 let child = expr.infer_result_type(resolve_type)?;
198 Some(match child {
199 DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
200 DataType::Float32 | DataType::Float64 => DataType::Float64,
201 DataType::UInt64 | DataType::Int64 => child,
202 DataType::UInt32
203 | DataType::UInt16
204 | DataType::UInt8
205 | DataType::Int32
206 | DataType::Int16
207 | DataType::Int8 => DataType::Int64,
208 _ => DataType::Float64,
209 })
210 }
211 AggregateCall::Total { expr, .. } | AggregateCall::Avg { expr, .. } => {
212 let child = expr.infer_result_type(resolve_type)?;
213 Some(match child {
214 DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
215 _ => DataType::Float64,
216 })
217 }
218 AggregateCall::Min(expr) | AggregateCall::Max(expr) => expr.infer_result_type(resolve_type),
219 AggregateCall::GroupConcat { .. } => Some(DataType::Utf8),
220 }
221}
222
/// Returns the result type of applying `op` to operands of types `lhs` and `rhs`.
///
/// Delegates to `crate::kernels::common_type_for_op`.
fn binary_result_type(op: BinaryOp, lhs: DataType, rhs: DataType) -> DataType {
    crate::kernels::common_type_for_op(&lhs, &rhs, op)
}
226
/// An expression of the form `scale * field + offset` over a single field.
#[derive(Clone, Copy, Debug)]
pub struct AffineExpr<F> {
    /// The field the expression depends on.
    pub field: F,
    /// Multiplier applied to the field's value.
    pub scale: f64,
    /// Constant term added after scaling.
    pub offset: f64,
}
234
/// Working state while reducing an expression to affine form.
///
/// Unlike [`AffineExpr`], `field` is optional: `None` means the sub-expression
/// seen so far is a pure constant carried in `offset`.
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
struct AffineState<F> {
    // The single field referenced so far, if any.
    field: Option<F>,
    // Multiplier applied to the field's value.
    scale: f64,
    // Constant term.
    offset: f64,
}
243
/// Namespace type for scalar-expression utilities: field collection, constant
/// folding, affine extraction, and row-wise / batch evaluation.
pub struct ScalarEvaluator;
247
248impl ScalarEvaluator {
249 #[allow(dead_code)]
251 fn merge_field<F: Eq + Copy>(lhs: Option<F>, rhs: Option<F>) -> Option<Option<F>> {
252 match (lhs, rhs) {
253 (Some(a), Some(b)) => {
254 if a == b {
255 Some(Some(a))
256 } else {
257 None
258 }
259 }
260 (Some(a), None) => Some(Some(a)),
261 (None, Some(b)) => Some(Some(b)),
262 (None, None) => Some(None),
263 }
264 }
265
266 pub fn collect_fields<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>, acc: &mut FxHashSet<F>) {
268 match expr {
269 ScalarExpr::Column(fid) => {
270 acc.insert(*fid);
271 }
272 ScalarExpr::Literal(_) => {}
273 ScalarExpr::Binary { left, right, .. } => {
274 Self::collect_fields(left, acc);
275 Self::collect_fields(right, acc);
276 }
277 ScalarExpr::Compare { left, right, .. } => {
278 Self::collect_fields(left, acc);
279 Self::collect_fields(right, acc);
280 }
281 ScalarExpr::Not(inner) => {
282 Self::collect_fields(inner, acc);
283 }
284 ScalarExpr::IsNull { expr, .. } => {
285 Self::collect_fields(expr, acc);
286 }
287 ScalarExpr::Aggregate(agg) => {
288 match agg {
290 AggregateCall::CountStar => {}
291 AggregateCall::Count { expr, .. }
292 | AggregateCall::Sum { expr, .. }
293 | AggregateCall::Total { expr, .. }
294 | AggregateCall::Avg { expr, .. }
295 | AggregateCall::Min(expr)
296 | AggregateCall::Max(expr)
297 | AggregateCall::CountNulls(expr)
298 | AggregateCall::GroupConcat { expr, .. } => {
299 Self::collect_fields(expr, acc);
300 }
301 }
302 }
303 ScalarExpr::GetField { base, .. } => {
304 Self::collect_fields(base, acc);
306 }
307 ScalarExpr::Cast { expr, .. } => {
308 Self::collect_fields(expr, acc);
309 }
310 ScalarExpr::Case {
311 operand,
312 branches,
313 else_expr,
314 } => {
315 if let Some(inner) = operand.as_deref() {
316 Self::collect_fields(inner, acc);
317 }
318 for (when_expr, then_expr) in branches {
319 Self::collect_fields(when_expr, acc);
320 Self::collect_fields(then_expr, acc);
321 }
322 if let Some(inner) = else_expr.as_deref() {
323 Self::collect_fields(inner, acc);
324 }
325 }
326 ScalarExpr::Coalesce(items) => {
327 for item in items {
328 Self::collect_fields(item, acc);
329 }
330 }
331 ScalarExpr::Random => {
332 }
334 ScalarExpr::ScalarSubquery(_) => {
335 }
337 }
338 }
339
340 pub fn prepare_numeric_arrays<F: Hash + Eq + Copy>(
341 arrays: &FxHashMap<F, ArrayRef>,
342 _row_count: usize,
343 ) -> NumericArrayMap<F> {
344 arrays.clone()
345 }
346
347 pub fn extract_affine<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineExpr<F>> {
350 let simplified = Self::simplify(expr);
351 Self::extract_affine_simplified(&simplified)
352 }
353
354 pub fn extract_affine_simplified<F: Hash + Eq + Copy>(
356 expr: &ScalarExpr<F>,
357 ) -> Option<AffineExpr<F>> {
358 let state = Self::affine_state(expr)?;
359 let field = state.field?;
360 Some(AffineExpr {
361 field,
362 scale: state.scale,
363 offset: state.offset,
364 })
365 }
366
367 fn affine_state<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineState<F>> {
368 match expr {
369 ScalarExpr::Column(fid) => Some(AffineState {
370 field: Some(*fid),
371 scale: 1.0,
372 offset: 0.0,
373 }),
374 ScalarExpr::Literal(lit) => {
375 let arr = Self::literal_to_array(lit);
376 let val = cast::cast(&arr, &DataType::Float64).ok()?;
377 let val = val.as_any().downcast_ref::<Float64Array>()?;
378 if val.is_null(0) {
379 return None;
380 }
381 Some(AffineState {
382 field: None,
383 scale: 0.0,
384 offset: val.value(0),
385 })
386 }
387 ScalarExpr::Aggregate(_) => None,
388 ScalarExpr::GetField { .. } => None,
389 ScalarExpr::Binary { left, op, right } => {
390 let left_state = Self::affine_state(left)?;
391 let right_state = Self::affine_state(right)?;
392 match op {
393 BinaryOp::Add => Self::affine_add(left_state, right_state),
394 BinaryOp::Subtract => Self::affine_sub(left_state, right_state),
395 BinaryOp::Multiply => Self::affine_mul(left_state, right_state),
396 BinaryOp::Divide => Self::affine_div(left_state, right_state),
397 _ => None,
398 }
399 }
400 ScalarExpr::Cast { expr, .. } => Self::affine_state(expr),
401 _ => None,
402 }
403 }
404
405 pub fn evaluate_value<F: Hash + Eq + Copy>(
407 expr: &ScalarExpr<F>,
408 idx: usize,
409 arrays: &NumericArrayMap<F>,
410 ) -> LlkvResult<ArrayRef> {
411 match expr {
412 ScalarExpr::Column(fid) => {
413 let array = arrays
414 .get(fid)
415 .ok_or_else(|| Error::Internal("missing column for field".into()))?;
416 Ok(array.slice(idx, 1))
417 }
418 ScalarExpr::Literal(lit) => Ok(Self::literal_to_array(lit)),
419 ScalarExpr::Binary { left, op, right } => {
420 let l = Self::evaluate_value(left, idx, arrays)?;
421 let r = Self::evaluate_value(right, idx, arrays)?;
422 Self::evaluate_binary_scalar(&l, *op, &r)
423 }
424 ScalarExpr::Compare { left, op, right } => {
425 let l = Self::evaluate_value(left, idx, arrays)?;
426 let r = Self::evaluate_value(right, idx, arrays)?;
427 crate::kernels::compute_compare(&l, *op, &r)
428 }
429 ScalarExpr::Not(expr) => {
430 let val = Self::evaluate_value(expr, idx, arrays)?;
431 let bool_arr = cast::cast(&val, &DataType::Boolean)
432 .map_err(|e| Error::Internal(e.to_string()))?;
433 let bool_arr = bool_arr
434 .as_any()
435 .downcast_ref::<arrow::array::BooleanArray>()
436 .unwrap();
437 let result = arrow::compute::kernels::boolean::not(bool_arr)
438 .map_err(|e| Error::Internal(e.to_string()))?;
439 Ok(Arc::new(result))
440 }
441 ScalarExpr::IsNull { expr, negated } => {
442 let val = Self::evaluate_value(expr, idx, arrays)?;
443 let is_null = val.is_null(0);
444 let result = if *negated { !is_null } else { is_null };
445 Ok(Arc::new(arrow::array::BooleanArray::from(vec![result])))
446 }
447 ScalarExpr::Cast { expr, data_type } => {
448 let val = Self::evaluate_value(expr, idx, arrays)?;
449 cast::cast(&val, data_type).map_err(|e| Error::Internal(e.to_string()))
450 }
451 ScalarExpr::Case {
452 operand,
453 branches,
454 else_expr,
455 } => {
456 let operand_val = if let Some(op) = operand {
457 Some(Self::evaluate_value(op, idx, arrays)?)
458 } else {
459 None
460 };
461
462 for (when_expr, then_expr) in branches {
463 let when_val = Self::evaluate_value(when_expr, idx, arrays)?;
464
465 let is_match = if let Some(op_val) = &operand_val {
466 if op_val.is_null(0) || when_val.is_null(0) {
469 false
470 } else {
471 let eq =
472 crate::kernels::compute_compare(op_val, CompareOp::Eq, &when_val)?;
473 let bool_arr = eq
474 .as_any()
475 .downcast_ref::<arrow::array::BooleanArray>()
476 .unwrap();
477 bool_arr.value(0)
478 }
479 } else {
480 if when_val.is_null(0) {
482 false
483 } else {
484 let bool_arr = cast::cast(&when_val, &DataType::Boolean)
485 .map_err(|e| Error::Internal(e.to_string()))?;
486 let bool_arr = bool_arr
487 .as_any()
488 .downcast_ref::<arrow::array::BooleanArray>()
489 .unwrap();
490 bool_arr.value(0)
491 }
492 };
493
494 if is_match {
495 return Self::evaluate_value(then_expr, idx, arrays);
496 }
497 }
498 if let Some(else_expr) = else_expr {
499 Self::evaluate_value(else_expr, idx, arrays)
500 } else {
501 Ok(new_null_array(&DataType::Null, 1))
502 }
503 }
504 ScalarExpr::Coalesce(items) => {
505 for item in items {
506 let val = Self::evaluate_value(item, idx, arrays)?;
507 if !val.is_null(0) && val.data_type() != &DataType::Null {
508 return Ok(val);
509 }
510 }
511 Ok(new_null_array(&DataType::Null, 1))
512 }
513 ScalarExpr::Random => {
514 let val = rand::random::<f64>();
515 Ok(Arc::new(Float64Array::from(vec![val])))
516 }
517 _ => Err(Error::Internal("Unsupported scalar expression".into())),
518 }
519 }
520
521 fn literal_to_array(lit: &Literal) -> ArrayRef {
522 match lit {
523 Literal::Null => new_null_array(&DataType::Null, 1),
524 Literal::Boolean(b) => Arc::new(arrow::array::BooleanArray::from(vec![*b])),
525 Literal::Int128(i) => Arc::new(arrow::array::Int64Array::from(vec![*i as i64])),
526 Literal::Float64(f) => Arc::new(Float64Array::from(vec![*f])),
527 Literal::Decimal128(d) => {
528 let array = arrow::array::Decimal128Array::from(vec![Some(d.raw_value())])
529 .with_precision_and_scale(d.precision(), d.scale())
530 .unwrap();
531 Arc::new(array)
532 }
533 Literal::String(s) => Arc::new(arrow::array::StringArray::from(vec![s.clone()])),
534 Literal::Date32(d) => Arc::new(arrow::array::Date32Array::from(vec![*d])),
535 Literal::Interval(i) => {
536 let val = IntervalMonthDayNanoType::make_value(i.months, i.days, i.nanos);
537 Arc::new(arrow::array::IntervalMonthDayNanoArray::from(vec![val]))
538 }
539 Literal::Struct(_) => {
540 new_null_array(&DataType::Struct(arrow::datatypes::Fields::empty()), 1)
541 }
542 }
543 }
544
    /// Applies `op` to `lhs` and `rhs` by delegating to `compute_binary`.
    ///
    /// Note the argument order differs from `compute_binary`, which takes the
    /// operator last.
    fn evaluate_binary_scalar(
        lhs: &ArrayRef,
        op: BinaryOp,
        rhs: &ArrayRef,
    ) -> LlkvResult<ArrayRef> {
        compute_binary(lhs, rhs, op)
    }
552
553 #[allow(dead_code)]
555 pub fn evaluate_batch<F: Hash + Eq + Copy + std::fmt::Debug>(
556 expr: &ScalarExpr<F>,
557 len: usize,
558 arrays: &NumericArrayMap<F>,
559 ) -> LlkvResult<ArrayRef> {
560 let simplified = Self::simplify(expr);
561 Self::evaluate_batch_simplified(&simplified, len, arrays)
562 }
563
564 pub fn evaluate_batch_simplified<F: Hash + Eq + Copy>(
566 expr: &ScalarExpr<F>,
567 len: usize,
568 arrays: &NumericArrayMap<F>,
569 ) -> LlkvResult<ArrayRef> {
570 let preferred = expr.infer_result_type_from_arrays(arrays);
571
572 if len == 0 {
573 return Ok(new_null_array(&preferred, 0));
574 }
575
576 if let Some(fast_path) = NumericFastPath::compile(expr, arrays, &preferred) {
577 let fast_result = fast_path.execute(len, arrays)?;
578 if fast_result.data_type() != &preferred {
579 let casted = cast::cast(&fast_result, &preferred).map_err(|e| {
580 Error::Internal(format!("Failed to cast fast path result: {}", e))
581 })?;
582 return Ok(casted);
583 }
584 return Ok(fast_result);
585 }
586
587 if let Some(vectorized) =
588 Self::try_evaluate_vectorized(expr, len, arrays, preferred.clone())?
589 {
590 let result = vectorized.materialize(len, preferred);
591 return Ok(result);
592 }
593
594 let mut values = Vec::with_capacity(len);
595 for idx in 0..len {
596 let val = Self::evaluate_value(expr, idx, arrays)?;
597 if val.data_type() != &preferred {
598 let casted = cast::cast(&val, &preferred).map_err(|e| {
599 Error::Internal(format!(
600 "Failed to cast row {}: {} (Val type: {:?}, Preferred: {:?})",
601 idx,
602 e,
603 val.data_type(),
604 preferred
605 ))
606 })?;
607 values.push(casted);
608 } else {
609 values.push(val);
610 }
611 }
612 concat(&values.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
613 .map_err(|e| Error::Internal(e.to_string()))
614 }
615
616 fn try_evaluate_vectorized<F: Hash + Eq + Copy>(
617 expr: &ScalarExpr<F>,
618 len: usize,
619 arrays: &NumericArrayMap<F>,
620 _target_type: DataType,
621 ) -> LlkvResult<Option<VectorizedExpr>> {
622 if expr.contains_interval() {
623 return Ok(None);
624 }
625 match expr {
626 ScalarExpr::Column(fid) => {
627 let array = arrays
628 .get(fid)
629 .ok_or_else(|| Error::Internal("missing column for field".into()))?;
630 Ok(Some(VectorizedExpr::Array(array.clone())))
631 }
632 ScalarExpr::Literal(lit) => {
633 let array = Self::literal_to_array(lit);
634 Ok(Some(VectorizedExpr::Scalar(array)))
635 }
636 ScalarExpr::Binary { left, op, right } => {
637 let left_type = left.infer_result_type_from_arrays(arrays);
638 let right_type = right.infer_result_type_from_arrays(arrays);
639
640 let left_vec = Self::try_evaluate_vectorized(left, len, arrays, left_type)?;
641 let right_vec = Self::try_evaluate_vectorized(right, len, arrays, right_type)?;
642
643 match (left_vec, right_vec) {
644 (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
645 let result = compute_binary(&lhs, &rhs, *op)?;
646 Ok(Some(VectorizedExpr::Scalar(result)))
647 }
648 (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Array(rhs))) => {
649 let array = compute_binary(&lhs, &rhs, *op)?;
650 Ok(Some(VectorizedExpr::Array(array)))
651 }
652 (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
653 let rhs_expanded = VectorizedExpr::Scalar(rhs)
654 .materialize(lhs.len(), lhs.data_type().clone());
655 let array = compute_binary(&lhs, &rhs_expanded, *op)?;
656 Ok(Some(VectorizedExpr::Array(array)))
657 }
658 (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Array(rhs))) => {
659 let lhs_expanded = VectorizedExpr::Scalar(lhs)
660 .materialize(rhs.len(), rhs.data_type().clone());
661 let array = compute_binary(&lhs_expanded, &rhs, *op)?;
662 Ok(Some(VectorizedExpr::Array(array)))
663 }
664 _ => Ok(None),
665 }
666 }
667 ScalarExpr::Cast { expr, data_type } => {
668 let inner_type = expr.infer_result_type_from_arrays(arrays);
669 let inner_vec = Self::try_evaluate_vectorized(expr, len, arrays, inner_type)?;
670
671 match inner_vec {
672 Some(VectorizedExpr::Scalar(array)) => {
673 let casted = cast::cast(&array, data_type)
674 .map_err(|e| Error::Internal(e.to_string()))?;
675 Ok(Some(VectorizedExpr::Scalar(casted)))
676 }
677 Some(VectorizedExpr::Array(array)) => {
678 let casted = cast::cast(&array, data_type)
679 .map_err(|e| Error::Internal(e.to_string()))?;
680 Ok(Some(VectorizedExpr::Array(casted)))
681 }
682 None => Ok(None),
683 }
684 }
685 ScalarExpr::Coalesce(items) => {
686 let mut evaluated_items = Vec::with_capacity(items.len());
687 let mut types = Vec::with_capacity(items.len());
688
689 for item in items {
690 let item_type = item.infer_result_type_from_arrays(arrays);
691 let vec_expr = match Self::try_evaluate_vectorized(
693 item,
694 len,
695 arrays,
696 item_type.clone(),
697 )? {
698 Some(v) => v,
699 None => return Ok(None),
700 };
701
702 let array = vec_expr.materialize(len, item_type.clone());
703 types.push(array.data_type().clone());
704 evaluated_items.push(array);
705 }
706
707 if evaluated_items.is_empty() {
708 return Ok(Some(VectorizedExpr::Array(new_null_array(
709 &DataType::Null,
710 len,
711 ))));
712 }
713
714 let mut common_type = types[0].clone();
716 for t in &types[1..] {
717 common_type = get_common_type(&common_type, t);
718 }
719
720 let mut casted_arrays = Vec::with_capacity(evaluated_items.len());
722 for array in evaluated_items {
723 if array.data_type() != &common_type {
724 let casted = cast::cast(&array, &common_type)
725 .map_err(|e| Error::Internal(e.to_string()))?;
726 casted_arrays.push(casted);
727 } else {
728 casted_arrays.push(array);
729 }
730 }
731
732 let mut result = casted_arrays[0].clone();
733 for next_array in &casted_arrays[1..] {
734 let mask = is_not_null(&result).map_err(|e| Error::Internal(e.to_string()))?;
735 result = zip(&mask, &result, next_array)
739 .map_err(|e| Error::Internal(e.to_string()))?;
740 }
741 Ok(Some(VectorizedExpr::Array(result)))
742 }
743 ScalarExpr::Random => {
744 let values: Vec<f64> = (0..len).map(|_| rand::random::<f64>()).collect();
745 let array = Float64Array::from(values);
746 Ok(Some(VectorizedExpr::Array(Arc::new(array))))
747 }
748 _ => Ok(None),
749 }
750 }
751
752 pub fn passthrough_column<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<F> {
754 match Self::simplify(expr) {
755 ScalarExpr::Column(fid) => Some(fid),
756 _ => None,
757 }
758 }
759
760 pub fn simplify<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> ScalarExpr<F> {
762 match expr {
763 ScalarExpr::Binary { left, op, right } => {
764 let l = Self::simplify(left);
765 let r = Self::simplify(right);
766 if let (ScalarExpr::Literal(ll), ScalarExpr::Literal(rr)) = (&l, &r)
767 && let Some(folded) = fold_binary_literals(*op, ll, rr)
768 {
769 return ScalarExpr::Literal(folded);
770 }
771 ScalarExpr::Binary {
772 left: Box::new(l),
773 op: *op,
774 right: Box::new(r),
775 }
776 }
777 ScalarExpr::Cast { expr, data_type } => {
778 let inner = Self::simplify(expr);
779 if let ScalarExpr::Literal(lit) = &inner
780 && let Some(folded) = fold_cast_literal(lit, data_type)
781 {
782 return ScalarExpr::Literal(folded);
783 }
784 ScalarExpr::Cast {
785 expr: Box::new(inner),
786 data_type: data_type.clone(),
787 }
788 }
789 _ => expr.clone(),
790 }
791 }
792
793 pub fn evaluate_constant_literal_expr<F: Hash + Eq + Copy>(
794 expr: &ScalarExpr<F>,
795 ) -> LlkvResult<Option<Literal>> {
796 let simplified = Self::simplify(expr);
797
798 if let Some(literal) = Self::evaluate_constant_literal_non_numeric(&simplified)? {
799 return Ok(Some(literal));
800 }
801
802 if let ScalarExpr::Literal(lit) = &simplified {
803 return Ok(Some(lit.clone()));
804 }
805
806 let arrays = NumericArrayMap::default();
807 let array = Self::evaluate_value(&simplified, 0, &arrays)?;
808 if array.is_null(0) {
809 return Ok(None);
810 }
811 Ok(Some(Literal::from_array_ref(&array, 0)?))
812 }
813
814 pub fn evaluate_constant_literal_non_numeric<F: Hash + Eq + Copy>(
815 expr: &ScalarExpr<F>,
816 ) -> LlkvResult<Option<Literal>> {
817 match expr {
818 ScalarExpr::Literal(lit) => Ok(Some(lit.clone())),
819 ScalarExpr::Cast {
820 expr,
821 data_type: DataType::Date32,
822 } => {
823 let inner = Self::evaluate_constant_literal_non_numeric(expr)?;
824 match inner {
825 Some(Literal::Null) => Ok(Some(Literal::Null)),
826 Some(Literal::String(text)) => {
827 let days = parse_date32_literal(&text)?;
828 Ok(Some(Literal::Date32(days)))
829 }
830 Some(Literal::Date32(days)) => Ok(Some(Literal::Date32(days))),
831 Some(other) => Err(Error::InvalidArgumentError(format!(
832 "cannot cast literal of type {} to DATE",
833 other.type_name()
834 ))),
835 None => Ok(None),
836 }
837 }
838 ScalarExpr::Cast { .. } => Ok(None),
839 ScalarExpr::Binary { left, op, right } => {
840 let left_lit = match Self::evaluate_constant_literal_non_numeric(left)? {
841 Some(lit) => lit,
842 None => return Ok(None),
843 };
844 let right_lit = match Self::evaluate_constant_literal_non_numeric(right)? {
845 Some(lit) => lit,
846 None => return Ok(None),
847 };
848
849 if matches!(left_lit, Literal::Null) || matches!(right_lit, Literal::Null) {
850 return Ok(Some(Literal::Null));
851 }
852
853 match op {
854 BinaryOp::Add => match (&left_lit, &right_lit) {
855 (Literal::Date32(days), Literal::Interval(interval))
856 | (Literal::Interval(interval), Literal::Date32(days)) => {
857 let adjusted = add_interval_to_date32(*days, *interval)?;
858 Ok(Some(Literal::Date32(adjusted)))
859 }
860 (Literal::Interval(left), Literal::Interval(right)) => {
861 let sum = left.checked_add(*right).ok_or_else(|| {
862 Error::InvalidArgumentError(
863 "interval addition overflow during constant folding".into(),
864 )
865 })?;
866 Ok(Some(Literal::Interval(sum)))
867 }
868 _ => Ok(None),
869 },
870 BinaryOp::Subtract => match (&left_lit, &right_lit) {
871 (Literal::Date32(days), Literal::Interval(interval)) => {
872 let adjusted = subtract_interval_from_date32(*days, *interval)?;
873 Ok(Some(Literal::Date32(adjusted)))
874 }
875 (Literal::Interval(left), Literal::Interval(right)) => {
876 let diff = left.checked_sub(*right).ok_or_else(|| {
877 Error::InvalidArgumentError(
878 "interval subtraction overflow during constant folding".into(),
879 )
880 })?;
881 Ok(Some(Literal::Interval(diff)))
882 }
883 (Literal::Date32(lhs), Literal::Date32(rhs)) => {
884 let delta = i64::from(*lhs) - i64::from(*rhs);
885 if delta < i64::from(i32::MIN) || delta > i64::from(i32::MAX) {
886 return Err(Error::InvalidArgumentError(
887 "DATE subtraction overflowed day precision".into(),
888 ));
889 }
890 Ok(Some(Literal::Interval(IntervalValue::new(
891 0,
892 delta as i32,
893 0,
894 ))))
895 }
896 _ => Ok(None),
897 },
898 _ => Ok(None),
899 }
900 }
901 _ => Ok(None),
902 }
903 }
904
905 pub fn is_supported_numeric(dtype: &DataType) -> bool {
907 matches!(
908 dtype,
909 DataType::UInt64
910 | DataType::UInt32
911 | DataType::UInt16
912 | DataType::UInt8
913 | DataType::Int64
914 | DataType::Int32
915 | DataType::Int16
916 | DataType::Int8
917 | DataType::Float64
918 | DataType::Float32
919 )
920 }
921
922 #[allow(dead_code)]
923 fn affine_add<F: Eq + Copy>(
924 lhs: AffineState<F>,
925 rhs: AffineState<F>,
926 ) -> Option<AffineState<F>> {
927 let merged_field = Self::merge_field(lhs.field, rhs.field)?;
928 if merged_field.is_none() {
929 return Some(AffineState {
931 field: None,
932 scale: 0.0,
933 offset: lhs.offset + rhs.offset,
934 });
935 }
936 Some(AffineState {
937 field: merged_field,
938 scale: lhs.scale + rhs.scale,
939 offset: lhs.offset + rhs.offset,
940 })
941 }
942
943 #[allow(dead_code)]
944 fn affine_sub<F: Eq + Copy>(
945 lhs: AffineState<F>,
946 rhs: AffineState<F>,
947 ) -> Option<AffineState<F>> {
948 let merged_field = Self::merge_field(lhs.field, rhs.field)?;
949 if merged_field.is_none() {
950 return Some(AffineState {
951 field: None,
952 scale: 0.0,
953 offset: lhs.offset - rhs.offset,
954 });
955 }
956 Some(AffineState {
957 field: merged_field,
958 scale: lhs.scale - rhs.scale,
959 offset: lhs.offset - rhs.offset,
960 })
961 }
962
963 #[allow(dead_code)]
964 fn affine_mul<F: Eq + Copy>(
965 lhs: AffineState<F>,
966 rhs: AffineState<F>,
967 ) -> Option<AffineState<F>> {
968 if lhs.field.is_some() && rhs.field.is_some() {
969 return None; }
971 if lhs.field.is_none() {
972 let factor = lhs.offset;
973 return Some(AffineState {
974 field: rhs.field,
975 scale: rhs.scale * factor,
976 offset: rhs.offset * factor,
977 });
978 }
979 if rhs.field.is_none() {
980 let factor = rhs.offset;
981 return Some(AffineState {
982 field: lhs.field,
983 scale: lhs.scale * factor,
984 offset: lhs.offset * factor,
985 });
986 }
987 None
988 }
989
990 #[allow(dead_code)]
991 fn affine_div<F: Eq + Copy>(
992 lhs: AffineState<F>,
993 rhs: AffineState<F>,
994 ) -> Option<AffineState<F>> {
995 if rhs.field.is_some() {
996 return None;
997 }
998 let denom = rhs.offset;
999 if denom == 0.0 {
1000 return None;
1001 }
1002 Some(AffineState {
1003 field: lhs.field,
1004 scale: lhs.scale / denom,
1005 offset: lhs.offset / denom,
1006 })
1007 }
1008}
1009
1010fn fold_binary_literals(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
1011 match op {
1012 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
1013 let pg_op = match op {
1014 BinaryOp::BitwiseShiftLeft => BinaryOperator::PGBitwiseShiftLeft,
1015 BinaryOp::BitwiseShiftRight => BinaryOperator::PGBitwiseShiftRight,
1016 _ => unreachable!(),
1017 };
1018 crate::literal::bitshift_literals(pg_op, left, right).ok()
1019 }
1020 _ => {
1021 let l_arr = ScalarEvaluator::literal_to_array(left);
1022 let r_arr = ScalarEvaluator::literal_to_array(right);
1023 let result = compute_binary(&l_arr, &r_arr, op).ok()?;
1024 if result.is_null(0) {
1025 Some(Literal::Null)
1026 } else {
1027 Literal::from_array_ref(&result, 0).ok()
1028 }
1029 }
1030 }
1031}
1032
1033fn fold_cast_literal(lit: &Literal, data_type: &DataType) -> Option<Literal> {
1034 if matches!(lit, Literal::Null) {
1035 return None;
1037 }
1038 let arr = ScalarEvaluator::literal_to_array(lit);
1039 let casted = cast::cast(&arr, data_type).ok()?;
1040 if casted.is_null(0) {
1041 Some(Literal::Null)
1042 } else {
1043 Literal::from_array_ref(&casted, 0).ok()
1044 }
1045}