1use std::hash::Hash;
2use std::sync::Arc;
3
4use arrow::array::{Array, ArrayRef, Float64Array, UInt32Array, new_null_array};
5use arrow::compute::kernels::cast;
6use arrow::compute::kernels::zip::zip;
7use arrow::compute::{concat, is_not_null, take};
8use arrow::datatypes::{DataType, Field, IntervalMonthDayNanoType};
9use llkv_expr::literal::{Literal, LiteralExt};
10use llkv_expr::{AggregateCall, BinaryOp, CompareOp, ScalarExpr};
11use llkv_result::{Error, Result as LlkvResult};
12use llkv_types::IntervalValue;
13use rustc_hash::{FxHashMap, FxHashSet};
14use sqlparser::ast::BinaryOperator;
15
16use crate::date::{add_interval_to_date32, parse_date32_literal, subtract_interval_from_date32};
17use crate::kernels::{compute_binary, get_common_type};
18
/// Column arrays keyed by field id, as consulted by the evaluators in this module.
///
/// `ScalarExpr::Column(fid)` is resolved by looking `fid` up in this map. Despite the name,
/// the alias does not restrict the Arrow type of the stored arrays.
pub type NumericArrayMap<F> = FxHashMap<F, ArrayRef>;
21
/// Intermediate result of vectorized (whole-column) expression evaluation.
enum VectorizedExpr {
    /// A full column-shaped array (one value per row).
    Array(ArrayRef),
    /// A constant held as a one-element array; it is broadcast to the row count lazily by
    /// `materialize`.
    Scalar(ArrayRef),
}
27
28impl VectorizedExpr {
29 fn materialize(self, len: usize, target_type: DataType) -> ArrayRef {
30 match self {
31 VectorizedExpr::Array(array) => {
32 if array.data_type() == &target_type {
33 array
34 } else {
35 cast::cast(&array, &target_type).unwrap_or(array)
36 }
37 }
38 VectorizedExpr::Scalar(scalar_array) => {
39 if scalar_array.is_empty() {
40 return new_null_array(&target_type, len);
41 }
42 if scalar_array.is_null(0) {
43 return new_null_array(scalar_array.data_type(), len);
44 }
45
46 let indices = UInt32Array::from(vec![0; len]);
48 take(&scalar_array, &indices, None)
49 .unwrap_or_else(|_| new_null_array(scalar_array.data_type(), len))
50 }
51 }
52 }
53}
54
/// Result-type inference helpers for scalar expressions over field ids of type `F`.
pub trait ScalarExprTypeExt<F> {
    /// Infers the Arrow type the expression evaluates to.
    ///
    /// `resolve_type` maps a field id to its column type. Returns `None` when the type cannot
    /// be determined (for example, when `resolve_type` does not know a referenced field).
    fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
    where
        F: Hash + Eq + Copy,
        R: FnMut(F) -> Option<DataType>;

    /// Infers the result type using the data types of the supplied column arrays.
    ///
    /// Never fails. The `ScalarExpr` implementation in this module falls back to `Float64`
    /// when inference returns `None`.
    fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType
    where
        F: Hash + Eq + Copy;

    /// Reports whether the expression contains an INTERVAL literal.
    ///
    /// NOTE(review): the `ScalarExpr` implementation only inspects literals and `Binary`
    /// operands, so intervals nested under other node kinds (e.g. `Cast`, `Coalesce`) are
    /// not detected at this level.
    fn contains_interval(&self) -> bool;
}
68
69impl<F: Hash + Eq + Copy> ScalarExprTypeExt<F> for ScalarExpr<F> {
70 fn infer_result_type<R>(&self, resolve_type: &mut R) -> Option<DataType>
71 where
72 R: FnMut(F) -> Option<DataType>,
73 {
74 match self {
75 ScalarExpr::Literal(lit) => Some(literal_type(lit)),
76 ScalarExpr::Column(fid) => resolve_type(*fid),
77 ScalarExpr::Binary { left, op, right } => {
78 let left_type = left.infer_result_type(resolve_type)?;
79 let right_type = right.infer_result_type(resolve_type)?;
80 Some(binary_result_type(*op, left_type, right_type))
81 }
82 ScalarExpr::Compare { .. } => Some(DataType::Boolean),
83 ScalarExpr::Not(_) => Some(DataType::Boolean),
84 ScalarExpr::IsNull { .. } => Some(DataType::Boolean),
85 ScalarExpr::Aggregate(call) => aggregate_result_type(call, resolve_type),
86 ScalarExpr::GetField { base, field_name } => {
87 let base_type = base.infer_result_type(resolve_type)?;
88 match base_type {
89 DataType::Struct(fields) => fields
90 .iter()
91 .find(|f| f.name() == field_name)
92 .map(|f| f.data_type().clone()),
93 _ => None,
94 }
95 }
96 ScalarExpr::Cast { data_type, .. } => Some(data_type.clone()),
97 ScalarExpr::Case {
98 branches,
99 else_expr,
100 ..
101 } => {
102 let mut types = Vec::new();
103 for (_, then_expr) in branches {
104 if let Some(t) = then_expr.infer_result_type(resolve_type) {
105 types.push(t);
106 }
107 }
108 if let Some(else_expr) = else_expr {
109 if let Some(t) = else_expr.infer_result_type(resolve_type) {
110 types.push(t);
111 }
112 } else {
113 types.push(DataType::Null);
115 }
116
117 if types.is_empty() {
118 return None;
119 }
120
121 let mut common = types[0].clone();
122 for t in &types[1..] {
123 common = get_common_type(&common, t);
124 }
125
126 Some(common)
127 }
128 ScalarExpr::Coalesce(items) => {
129 let mut types = Vec::new();
130 for item in items {
131 if let Some(t) = item.infer_result_type(resolve_type) {
132 types.push(t);
133 }
134 }
135 if types.is_empty() {
136 return None;
137 }
138 let mut common = types[0].clone();
139 for t in &types[1..] {
140 common = get_common_type(&common, t);
141 }
142 Some(common)
143 }
144 ScalarExpr::Random => Some(DataType::Float64),
145 ScalarExpr::ScalarSubquery(sub) => Some(sub.data_type.clone()),
146 }
147 }
148
149 fn infer_result_type_from_arrays(&self, arrays: &NumericArrayMap<F>) -> DataType {
150 let mut resolver = |fid| arrays.get(&fid).map(|a| a.data_type().clone());
151 self.infer_result_type(&mut resolver)
152 .unwrap_or(DataType::Float64)
153 }
154
155 fn contains_interval(&self) -> bool {
156 match self {
157 ScalarExpr::Literal(Literal::Interval(_)) => true,
158 ScalarExpr::Binary { left, right, .. } => {
159 left.contains_interval() || right.contains_interval()
160 }
161 _ => false,
162 }
163 }
164}
165
166fn literal_type(lit: &Literal) -> DataType {
167 match lit {
168 Literal::Null => DataType::Null,
169 Literal::Boolean(_) => DataType::Boolean,
170 Literal::Int128(_) => DataType::Int64, Literal::Float64(_) => DataType::Float64,
172 Literal::Decimal128(d) => DataType::Decimal128(d.precision(), d.scale()),
173 Literal::String(_) => DataType::Utf8,
174 Literal::Date32(_) => DataType::Date32,
175 Literal::Interval(_) => DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
176 Literal::Struct(fields) => {
177 let arrow_fields = fields
178 .iter()
179 .map(|(name, lit)| Field::new(name, literal_type(lit), true))
180 .collect();
181 DataType::Struct(arrow_fields)
182 }
183 }
184}
185
186fn aggregate_result_type<F, R>(call: &AggregateCall<F>, resolve_type: &mut R) -> Option<DataType>
187where
188 F: Hash + Eq + Copy,
189 R: FnMut(F) -> Option<DataType>,
190{
191 match call {
192 AggregateCall::CountStar | AggregateCall::Count { .. } | AggregateCall::CountNulls(_) => {
193 Some(DataType::Int64)
194 }
195 AggregateCall::Sum { expr, .. } => {
196 let child = expr.infer_result_type(resolve_type)?;
197 Some(match child {
198 DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
199 DataType::Float32 | DataType::Float64 => DataType::Float64,
200 DataType::UInt64 | DataType::Int64 => child,
201 DataType::UInt32
202 | DataType::UInt16
203 | DataType::UInt8
204 | DataType::Int32
205 | DataType::Int16
206 | DataType::Int8 => DataType::Int64,
207 _ => DataType::Float64,
208 })
209 }
210 AggregateCall::Total { expr, .. } | AggregateCall::Avg { expr, .. } => {
211 let child = expr.infer_result_type(resolve_type)?;
212 Some(match child {
213 DataType::Decimal128(p, s) => DataType::Decimal128(p, s),
214 _ => DataType::Float64,
215 })
216 }
217 AggregateCall::Min(expr) | AggregateCall::Max(expr) => expr.infer_result_type(resolve_type),
218 AggregateCall::GroupConcat { .. } => Some(DataType::Utf8),
219 }
220}
221
/// Result type of a binary arithmetic/bitwise expression: the common type of both operands.
///
/// NOTE(review): the operator is ignored. For example, DATE - DATE is typed as
/// `get_common_type(Date32, Date32)`, even though constant folding in this module turns that
/// case into an INTERVAL. Confirm this against what `compute_binary` actually returns.
fn binary_result_type(_op: BinaryOp, lhs: DataType, rhs: DataType) -> DataType {
    get_common_type(&lhs, &rhs)
}
225
/// An expression of the form `scale * field + offset` over a single column.
///
/// Produced by `ScalarEvaluator::extract_affine` for expressions built from one column,
/// numeric literals, `+`, `-`, `*` (by a constant) and `/` (by a non-zero constant).
#[derive(Clone, Copy, Debug)]
pub struct AffineExpr<F> {
    /// The single column the expression depends on.
    pub field: F,
    /// Multiplier applied to the column value.
    pub scale: f64,
    /// Constant added after scaling.
    pub offset: f64,
}
233
/// Intermediate affine form used while folding an expression tree.
///
/// `field == None` denotes a pure constant whose value is `offset`. Literals are created
/// with `scale == 0.0`.
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
struct AffineState<F> {
    // Column the expression depends on, or `None` for constants.
    field: Option<F>,
    // Multiplier of `field`.
    scale: f64,
    // Additive constant.
    offset: f64,
}
242
/// Stateless namespace for evaluating, simplifying and analysing `ScalarExpr` trees.
pub struct ScalarEvaluator;
246
247impl ScalarEvaluator {
248 #[allow(dead_code)]
250 fn merge_field<F: Eq + Copy>(lhs: Option<F>, rhs: Option<F>) -> Option<Option<F>> {
251 match (lhs, rhs) {
252 (Some(a), Some(b)) => {
253 if a == b {
254 Some(Some(a))
255 } else {
256 None
257 }
258 }
259 (Some(a), None) => Some(Some(a)),
260 (None, Some(b)) => Some(Some(b)),
261 (None, None) => Some(None),
262 }
263 }
264
265 pub fn collect_fields<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>, acc: &mut FxHashSet<F>) {
267 match expr {
268 ScalarExpr::Column(fid) => {
269 acc.insert(*fid);
270 }
271 ScalarExpr::Literal(_) => {}
272 ScalarExpr::Binary { left, right, .. } => {
273 Self::collect_fields(left, acc);
274 Self::collect_fields(right, acc);
275 }
276 ScalarExpr::Compare { left, right, .. } => {
277 Self::collect_fields(left, acc);
278 Self::collect_fields(right, acc);
279 }
280 ScalarExpr::Not(inner) => {
281 Self::collect_fields(inner, acc);
282 }
283 ScalarExpr::IsNull { expr, .. } => {
284 Self::collect_fields(expr, acc);
285 }
286 ScalarExpr::Aggregate(agg) => {
287 match agg {
289 AggregateCall::CountStar => {}
290 AggregateCall::Count { expr, .. }
291 | AggregateCall::Sum { expr, .. }
292 | AggregateCall::Total { expr, .. }
293 | AggregateCall::Avg { expr, .. }
294 | AggregateCall::Min(expr)
295 | AggregateCall::Max(expr)
296 | AggregateCall::CountNulls(expr)
297 | AggregateCall::GroupConcat { expr, .. } => {
298 Self::collect_fields(expr, acc);
299 }
300 }
301 }
302 ScalarExpr::GetField { base, .. } => {
303 Self::collect_fields(base, acc);
305 }
306 ScalarExpr::Cast { expr, .. } => {
307 Self::collect_fields(expr, acc);
308 }
309 ScalarExpr::Case {
310 operand,
311 branches,
312 else_expr,
313 } => {
314 if let Some(inner) = operand.as_deref() {
315 Self::collect_fields(inner, acc);
316 }
317 for (when_expr, then_expr) in branches {
318 Self::collect_fields(when_expr, acc);
319 Self::collect_fields(then_expr, acc);
320 }
321 if let Some(inner) = else_expr.as_deref() {
322 Self::collect_fields(inner, acc);
323 }
324 }
325 ScalarExpr::Coalesce(items) => {
326 for item in items {
327 Self::collect_fields(item, acc);
328 }
329 }
330 ScalarExpr::Random => {
331 }
333 ScalarExpr::ScalarSubquery(_) => {
334 }
336 }
337 }
338
339 pub fn prepare_numeric_arrays<F: Hash + Eq + Copy>(
340 arrays: &FxHashMap<F, ArrayRef>,
341 _row_count: usize,
342 ) -> NumericArrayMap<F> {
343 arrays.clone()
344 }
345
346 pub fn extract_affine<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineExpr<F>> {
349 let simplified = Self::simplify(expr);
350 Self::extract_affine_simplified(&simplified)
351 }
352
353 pub fn extract_affine_simplified<F: Hash + Eq + Copy>(
355 expr: &ScalarExpr<F>,
356 ) -> Option<AffineExpr<F>> {
357 let state = Self::affine_state(expr)?;
358 let field = state.field?;
359 Some(AffineExpr {
360 field,
361 scale: state.scale,
362 offset: state.offset,
363 })
364 }
365
366 fn affine_state<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<AffineState<F>> {
367 match expr {
368 ScalarExpr::Column(fid) => Some(AffineState {
369 field: Some(*fid),
370 scale: 1.0,
371 offset: 0.0,
372 }),
373 ScalarExpr::Literal(lit) => {
374 let arr = Self::literal_to_array(lit);
375 let val = cast::cast(&arr, &DataType::Float64).ok()?;
376 let val = val.as_any().downcast_ref::<Float64Array>()?;
377 if val.is_null(0) {
378 return None;
379 }
380 Some(AffineState {
381 field: None,
382 scale: 0.0,
383 offset: val.value(0),
384 })
385 }
386 ScalarExpr::Aggregate(_) => None,
387 ScalarExpr::GetField { .. } => None,
388 ScalarExpr::Binary { left, op, right } => {
389 let left_state = Self::affine_state(left)?;
390 let right_state = Self::affine_state(right)?;
391 match op {
392 BinaryOp::Add => Self::affine_add(left_state, right_state),
393 BinaryOp::Subtract => Self::affine_sub(left_state, right_state),
394 BinaryOp::Multiply => Self::affine_mul(left_state, right_state),
395 BinaryOp::Divide => Self::affine_div(left_state, right_state),
396 _ => None,
397 }
398 }
399 ScalarExpr::Cast { expr, .. } => Self::affine_state(expr),
400 _ => None,
401 }
402 }
403
404 pub fn evaluate_value<F: Hash + Eq + Copy>(
406 expr: &ScalarExpr<F>,
407 idx: usize,
408 arrays: &NumericArrayMap<F>,
409 ) -> LlkvResult<ArrayRef> {
410 match expr {
411 ScalarExpr::Column(fid) => {
412 let array = arrays
413 .get(fid)
414 .ok_or_else(|| Error::Internal("missing column for field".into()))?;
415 Ok(array.slice(idx, 1))
416 }
417 ScalarExpr::Literal(lit) => Ok(Self::literal_to_array(lit)),
418 ScalarExpr::Binary { left, op, right } => {
419 let l = Self::evaluate_value(left, idx, arrays)?;
420 let r = Self::evaluate_value(right, idx, arrays)?;
421 Self::evaluate_binary_scalar(&l, *op, &r)
422 }
423 ScalarExpr::Compare { left, op, right } => {
424 let l = Self::evaluate_value(left, idx, arrays)?;
425 let r = Self::evaluate_value(right, idx, arrays)?;
426 crate::kernels::compute_compare(&l, *op, &r)
427 }
428 ScalarExpr::Not(expr) => {
429 let val = Self::evaluate_value(expr, idx, arrays)?;
430 let bool_arr = cast::cast(&val, &DataType::Boolean)
431 .map_err(|e| Error::Internal(e.to_string()))?;
432 let bool_arr = bool_arr
433 .as_any()
434 .downcast_ref::<arrow::array::BooleanArray>()
435 .unwrap();
436 let result = arrow::compute::kernels::boolean::not(bool_arr)
437 .map_err(|e| Error::Internal(e.to_string()))?;
438 Ok(Arc::new(result))
439 }
440 ScalarExpr::IsNull { expr, negated } => {
441 let val = Self::evaluate_value(expr, idx, arrays)?;
442 let is_null = val.is_null(0);
443 let result = if *negated { !is_null } else { is_null };
444 Ok(Arc::new(arrow::array::BooleanArray::from(vec![result])))
445 }
446 ScalarExpr::Cast { expr, data_type } => {
447 let val = Self::evaluate_value(expr, idx, arrays)?;
448 cast::cast(&val, data_type).map_err(|e| Error::Internal(e.to_string()))
449 }
450 ScalarExpr::Case {
451 operand,
452 branches,
453 else_expr,
454 } => {
455 let operand_val = if let Some(op) = operand {
456 Some(Self::evaluate_value(op, idx, arrays)?)
457 } else {
458 None
459 };
460
461 for (when_expr, then_expr) in branches {
462 let when_val = Self::evaluate_value(when_expr, idx, arrays)?;
463
464 let is_match = if let Some(op_val) = &operand_val {
465 if op_val.is_null(0) || when_val.is_null(0) {
468 false
469 } else {
470 let eq =
471 crate::kernels::compute_compare(op_val, CompareOp::Eq, &when_val)?;
472 let bool_arr = eq
473 .as_any()
474 .downcast_ref::<arrow::array::BooleanArray>()
475 .unwrap();
476 bool_arr.value(0)
477 }
478 } else {
479 if when_val.is_null(0) {
481 false
482 } else {
483 let bool_arr = cast::cast(&when_val, &DataType::Boolean)
484 .map_err(|e| Error::Internal(e.to_string()))?;
485 let bool_arr = bool_arr
486 .as_any()
487 .downcast_ref::<arrow::array::BooleanArray>()
488 .unwrap();
489 bool_arr.value(0)
490 }
491 };
492
493 if is_match {
494 return Self::evaluate_value(then_expr, idx, arrays);
495 }
496 }
497 if let Some(else_expr) = else_expr {
498 Self::evaluate_value(else_expr, idx, arrays)
499 } else {
500 Ok(new_null_array(&DataType::Null, 1))
501 }
502 }
503 ScalarExpr::Coalesce(items) => {
504 for item in items {
505 let val = Self::evaluate_value(item, idx, arrays)?;
506 if !val.is_null(0) && val.data_type() != &DataType::Null {
507 return Ok(val);
508 }
509 }
510 Ok(new_null_array(&DataType::Null, 1))
511 }
512 ScalarExpr::Random => {
513 let val = rand::random::<f64>();
514 Ok(Arc::new(Float64Array::from(vec![val])))
515 }
516 _ => Err(Error::Internal("Unsupported scalar expression".into())),
517 }
518 }
519
520 fn literal_to_array(lit: &Literal) -> ArrayRef {
521 match lit {
522 Literal::Null => new_null_array(&DataType::Null, 1),
523 Literal::Boolean(b) => Arc::new(arrow::array::BooleanArray::from(vec![*b])),
524 Literal::Int128(i) => Arc::new(arrow::array::Int64Array::from(vec![*i as i64])),
525 Literal::Float64(f) => Arc::new(Float64Array::from(vec![*f])),
526 Literal::Decimal128(d) => {
527 let array = arrow::array::Decimal128Array::from(vec![Some(d.raw_value())])
528 .with_precision_and_scale(d.precision(), d.scale())
529 .unwrap();
530 Arc::new(array)
531 }
532 Literal::String(s) => Arc::new(arrow::array::StringArray::from(vec![s.clone()])),
533 Literal::Date32(d) => Arc::new(arrow::array::Date32Array::from(vec![*d])),
534 Literal::Interval(i) => {
535 let val = IntervalMonthDayNanoType::make_value(i.months, i.days, i.nanos);
536 Arc::new(arrow::array::IntervalMonthDayNanoArray::from(vec![val]))
537 }
538 Literal::Struct(_) => {
539 new_null_array(&DataType::Struct(arrow::datatypes::Fields::empty()), 1)
540 }
541 }
542 }
543
    /// Applies binary operator `op` to two single-row arrays.
    ///
    /// A thin wrapper over `compute_binary`; note that `compute_binary` takes the operator as
    /// its last argument.
    fn evaluate_binary_scalar(
        lhs: &ArrayRef,
        op: BinaryOp,
        rhs: &ArrayRef,
    ) -> LlkvResult<ArrayRef> {
        compute_binary(lhs, rhs, op)
    }
551
552 #[allow(dead_code)]
554 pub fn evaluate_batch<F: Hash + Eq + Copy + std::fmt::Debug>(
555 expr: &ScalarExpr<F>,
556 len: usize,
557 arrays: &NumericArrayMap<F>,
558 ) -> LlkvResult<ArrayRef> {
559 let simplified = Self::simplify(expr);
560 Self::evaluate_batch_simplified(&simplified, len, arrays)
561 }
562
563 pub fn evaluate_batch_simplified<F: Hash + Eq + Copy>(
565 expr: &ScalarExpr<F>,
566 len: usize,
567 arrays: &NumericArrayMap<F>,
568 ) -> LlkvResult<ArrayRef> {
569 let preferred = expr.infer_result_type_from_arrays(arrays);
570
571 if len == 0 {
572 return Ok(new_null_array(&preferred, 0));
573 }
574
575 if let Some(vectorized) =
576 Self::try_evaluate_vectorized(expr, len, arrays, preferred.clone())?
577 {
578 let result = vectorized.materialize(len, preferred);
579 return Ok(result);
580 }
581
582 let mut values = Vec::with_capacity(len);
583 for idx in 0..len {
584 let val = Self::evaluate_value(expr, idx, arrays)?;
585 if val.data_type() != &preferred {
586 let casted = cast::cast(&val, &preferred).map_err(|e| {
587 Error::Internal(format!(
588 "Failed to cast row {}: {} (Val type: {:?}, Preferred: {:?})",
589 idx,
590 e,
591 val.data_type(),
592 preferred
593 ))
594 })?;
595 values.push(casted);
596 } else {
597 values.push(val);
598 }
599 }
600 concat(&values.iter().map(|a| a.as_ref()).collect::<Vec<_>>())
601 .map_err(|e| Error::Internal(e.to_string()))
602 }
603
604 fn try_evaluate_vectorized<F: Hash + Eq + Copy>(
605 expr: &ScalarExpr<F>,
606 len: usize,
607 arrays: &NumericArrayMap<F>,
608 _target_type: DataType,
609 ) -> LlkvResult<Option<VectorizedExpr>> {
610 if expr.contains_interval() {
611 return Ok(None);
612 }
613 match expr {
614 ScalarExpr::Column(fid) => {
615 let array = arrays
616 .get(fid)
617 .ok_or_else(|| Error::Internal("missing column for field".into()))?;
618 Ok(Some(VectorizedExpr::Array(array.clone())))
619 }
620 ScalarExpr::Literal(lit) => {
621 let array = Self::literal_to_array(lit);
622 Ok(Some(VectorizedExpr::Scalar(array)))
623 }
624 ScalarExpr::Binary { left, op, right } => {
625 let left_type = left.infer_result_type_from_arrays(arrays);
626 let right_type = right.infer_result_type_from_arrays(arrays);
627
628 let left_vec = Self::try_evaluate_vectorized(left, len, arrays, left_type)?;
629 let right_vec = Self::try_evaluate_vectorized(right, len, arrays, right_type)?;
630
631 match (left_vec, right_vec) {
632 (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
633 let result = compute_binary(&lhs, &rhs, *op)?;
634 Ok(Some(VectorizedExpr::Scalar(result)))
635 }
636 (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Array(rhs))) => {
637 let array = compute_binary(&lhs, &rhs, *op)?;
638 Ok(Some(VectorizedExpr::Array(array)))
639 }
640 (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
641 let rhs_expanded = VectorizedExpr::Scalar(rhs)
642 .materialize(lhs.len(), lhs.data_type().clone());
643 let array = compute_binary(&lhs, &rhs_expanded, *op)?;
644 Ok(Some(VectorizedExpr::Array(array)))
645 }
646 (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Array(rhs))) => {
647 let lhs_expanded = VectorizedExpr::Scalar(lhs)
648 .materialize(rhs.len(), rhs.data_type().clone());
649 let array = compute_binary(&lhs_expanded, &rhs, *op)?;
650 Ok(Some(VectorizedExpr::Array(array)))
651 }
652 _ => Ok(None),
653 }
654 }
655 ScalarExpr::Cast { expr, data_type } => {
656 let inner_type = expr.infer_result_type_from_arrays(arrays);
657 let inner_vec = Self::try_evaluate_vectorized(expr, len, arrays, inner_type)?;
658
659 match inner_vec {
660 Some(VectorizedExpr::Scalar(array)) => {
661 let casted = cast::cast(&array, data_type)
662 .map_err(|e| Error::Internal(e.to_string()))?;
663 Ok(Some(VectorizedExpr::Scalar(casted)))
664 }
665 Some(VectorizedExpr::Array(array)) => {
666 let casted = cast::cast(&array, data_type)
667 .map_err(|e| Error::Internal(e.to_string()))?;
668 Ok(Some(VectorizedExpr::Array(casted)))
669 }
670 None => Ok(None),
671 }
672 }
673 ScalarExpr::Coalesce(items) => {
674 let mut evaluated_items = Vec::with_capacity(items.len());
675 let mut types = Vec::with_capacity(items.len());
676
677 for item in items {
678 let item_type = item.infer_result_type_from_arrays(arrays);
679 let vec_expr = match Self::try_evaluate_vectorized(
681 item,
682 len,
683 arrays,
684 item_type.clone(),
685 )? {
686 Some(v) => v,
687 None => return Ok(None),
688 };
689
690 let array = vec_expr.materialize(len, item_type.clone());
691 types.push(array.data_type().clone());
692 evaluated_items.push(array);
693 }
694
695 if evaluated_items.is_empty() {
696 return Ok(Some(VectorizedExpr::Array(new_null_array(
697 &DataType::Null,
698 len,
699 ))));
700 }
701
702 let mut common_type = types[0].clone();
704 for t in &types[1..] {
705 common_type = get_common_type(&common_type, t);
706 }
707
708 let mut casted_arrays = Vec::with_capacity(evaluated_items.len());
710 for array in evaluated_items {
711 if array.data_type() != &common_type {
712 let casted = cast::cast(&array, &common_type)
713 .map_err(|e| Error::Internal(e.to_string()))?;
714 casted_arrays.push(casted);
715 } else {
716 casted_arrays.push(array);
717 }
718 }
719
720 let mut result = casted_arrays[0].clone();
721 for next_array in &casted_arrays[1..] {
722 let mask = is_not_null(&result).map_err(|e| Error::Internal(e.to_string()))?;
723 result = zip(&mask, &result, next_array)
727 .map_err(|e| Error::Internal(e.to_string()))?;
728 }
729 Ok(Some(VectorizedExpr::Array(result)))
730 }
731 ScalarExpr::Random => {
732 let values: Vec<f64> = (0..len).map(|_| rand::random::<f64>()).collect();
733 let array = Float64Array::from(values);
734 Ok(Some(VectorizedExpr::Array(Arc::new(array))))
735 }
736 _ => Ok(None),
737 }
738 }
739
740 pub fn passthrough_column<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> Option<F> {
742 match Self::simplify(expr) {
743 ScalarExpr::Column(fid) => Some(fid),
744 _ => None,
745 }
746 }
747
748 pub fn simplify<F: Hash + Eq + Copy>(expr: &ScalarExpr<F>) -> ScalarExpr<F> {
750 match expr {
751 ScalarExpr::Binary { left, op, right } => {
752 let l = Self::simplify(left);
753 let r = Self::simplify(right);
754 if let (ScalarExpr::Literal(ll), ScalarExpr::Literal(rr)) = (&l, &r)
755 && let Some(folded) = fold_binary_literals(*op, ll, rr)
756 {
757 return ScalarExpr::Literal(folded);
758 }
759 ScalarExpr::Binary {
760 left: Box::new(l),
761 op: *op,
762 right: Box::new(r),
763 }
764 }
765 ScalarExpr::Cast { expr, data_type } => {
766 let inner = Self::simplify(expr);
767 if let ScalarExpr::Literal(lit) = &inner
768 && let Some(folded) = fold_cast_literal(lit, data_type)
769 {
770 return ScalarExpr::Literal(folded);
771 }
772 ScalarExpr::Cast {
773 expr: Box::new(inner),
774 data_type: data_type.clone(),
775 }
776 }
777 _ => expr.clone(),
778 }
779 }
780
781 pub fn evaluate_constant_literal_expr<F: Hash + Eq + Copy>(
782 expr: &ScalarExpr<F>,
783 ) -> LlkvResult<Option<Literal>> {
784 let simplified = Self::simplify(expr);
785
786 if let Some(literal) = Self::evaluate_constant_literal_non_numeric(&simplified)? {
787 return Ok(Some(literal));
788 }
789
790 if let ScalarExpr::Literal(lit) = &simplified {
791 return Ok(Some(lit.clone()));
792 }
793
794 let arrays = NumericArrayMap::default();
795 let array = Self::evaluate_value(&simplified, 0, &arrays)?;
796 if array.is_null(0) {
797 return Ok(None);
798 }
799 Ok(Some(Literal::from_array_ref(&array, 0)?))
800 }
801
802 pub fn evaluate_constant_literal_non_numeric<F: Hash + Eq + Copy>(
803 expr: &ScalarExpr<F>,
804 ) -> LlkvResult<Option<Literal>> {
805 match expr {
806 ScalarExpr::Literal(lit) => Ok(Some(lit.clone())),
807 ScalarExpr::Cast {
808 expr,
809 data_type: DataType::Date32,
810 } => {
811 let inner = Self::evaluate_constant_literal_non_numeric(expr)?;
812 match inner {
813 Some(Literal::Null) => Ok(Some(Literal::Null)),
814 Some(Literal::String(text)) => {
815 let days = parse_date32_literal(&text)?;
816 Ok(Some(Literal::Date32(days)))
817 }
818 Some(Literal::Date32(days)) => Ok(Some(Literal::Date32(days))),
819 Some(other) => Err(Error::InvalidArgumentError(format!(
820 "cannot cast literal of type {} to DATE",
821 other.type_name()
822 ))),
823 None => Ok(None),
824 }
825 }
826 ScalarExpr::Cast { .. } => Ok(None),
827 ScalarExpr::Binary { left, op, right } => {
828 let left_lit = match Self::evaluate_constant_literal_non_numeric(left)? {
829 Some(lit) => lit,
830 None => return Ok(None),
831 };
832 let right_lit = match Self::evaluate_constant_literal_non_numeric(right)? {
833 Some(lit) => lit,
834 None => return Ok(None),
835 };
836
837 if matches!(left_lit, Literal::Null) || matches!(right_lit, Literal::Null) {
838 return Ok(Some(Literal::Null));
839 }
840
841 match op {
842 BinaryOp::Add => match (&left_lit, &right_lit) {
843 (Literal::Date32(days), Literal::Interval(interval))
844 | (Literal::Interval(interval), Literal::Date32(days)) => {
845 let adjusted = add_interval_to_date32(*days, *interval)?;
846 Ok(Some(Literal::Date32(adjusted)))
847 }
848 (Literal::Interval(left), Literal::Interval(right)) => {
849 let sum = left.checked_add(*right).ok_or_else(|| {
850 Error::InvalidArgumentError(
851 "interval addition overflow during constant folding".into(),
852 )
853 })?;
854 Ok(Some(Literal::Interval(sum)))
855 }
856 _ => Ok(None),
857 },
858 BinaryOp::Subtract => match (&left_lit, &right_lit) {
859 (Literal::Date32(days), Literal::Interval(interval)) => {
860 let adjusted = subtract_interval_from_date32(*days, *interval)?;
861 Ok(Some(Literal::Date32(adjusted)))
862 }
863 (Literal::Interval(left), Literal::Interval(right)) => {
864 let diff = left.checked_sub(*right).ok_or_else(|| {
865 Error::InvalidArgumentError(
866 "interval subtraction overflow during constant folding".into(),
867 )
868 })?;
869 Ok(Some(Literal::Interval(diff)))
870 }
871 (Literal::Date32(lhs), Literal::Date32(rhs)) => {
872 let delta = i64::from(*lhs) - i64::from(*rhs);
873 if delta < i64::from(i32::MIN) || delta > i64::from(i32::MAX) {
874 return Err(Error::InvalidArgumentError(
875 "DATE subtraction overflowed day precision".into(),
876 ));
877 }
878 Ok(Some(Literal::Interval(IntervalValue::new(
879 0,
880 delta as i32,
881 0,
882 ))))
883 }
884 _ => Ok(None),
885 },
886 _ => Ok(None),
887 }
888 }
889 _ => Ok(None),
890 }
891 }
892
893 pub fn is_supported_numeric(dtype: &DataType) -> bool {
895 matches!(
896 dtype,
897 DataType::UInt64
898 | DataType::UInt32
899 | DataType::UInt16
900 | DataType::UInt8
901 | DataType::Int64
902 | DataType::Int32
903 | DataType::Int16
904 | DataType::Int8
905 | DataType::Float64
906 | DataType::Float32
907 )
908 }
909
910 #[allow(dead_code)]
911 fn affine_add<F: Eq + Copy>(
912 lhs: AffineState<F>,
913 rhs: AffineState<F>,
914 ) -> Option<AffineState<F>> {
915 let merged_field = Self::merge_field(lhs.field, rhs.field)?;
916 if merged_field.is_none() {
917 return Some(AffineState {
919 field: None,
920 scale: 0.0,
921 offset: lhs.offset + rhs.offset,
922 });
923 }
924 Some(AffineState {
925 field: merged_field,
926 scale: lhs.scale + rhs.scale,
927 offset: lhs.offset + rhs.offset,
928 })
929 }
930
931 #[allow(dead_code)]
932 fn affine_sub<F: Eq + Copy>(
933 lhs: AffineState<F>,
934 rhs: AffineState<F>,
935 ) -> Option<AffineState<F>> {
936 let merged_field = Self::merge_field(lhs.field, rhs.field)?;
937 if merged_field.is_none() {
938 return Some(AffineState {
939 field: None,
940 scale: 0.0,
941 offset: lhs.offset - rhs.offset,
942 });
943 }
944 Some(AffineState {
945 field: merged_field,
946 scale: lhs.scale - rhs.scale,
947 offset: lhs.offset - rhs.offset,
948 })
949 }
950
951 #[allow(dead_code)]
952 fn affine_mul<F: Eq + Copy>(
953 lhs: AffineState<F>,
954 rhs: AffineState<F>,
955 ) -> Option<AffineState<F>> {
956 if lhs.field.is_some() && rhs.field.is_some() {
957 return None; }
959 if lhs.field.is_none() {
960 let factor = lhs.offset;
961 return Some(AffineState {
962 field: rhs.field,
963 scale: rhs.scale * factor,
964 offset: rhs.offset * factor,
965 });
966 }
967 if rhs.field.is_none() {
968 let factor = rhs.offset;
969 return Some(AffineState {
970 field: lhs.field,
971 scale: lhs.scale * factor,
972 offset: lhs.offset * factor,
973 });
974 }
975 None
976 }
977
978 #[allow(dead_code)]
979 fn affine_div<F: Eq + Copy>(
980 lhs: AffineState<F>,
981 rhs: AffineState<F>,
982 ) -> Option<AffineState<F>> {
983 if rhs.field.is_some() {
984 return None;
985 }
986 let denom = rhs.offset;
987 if denom == 0.0 {
988 return None;
989 }
990 Some(AffineState {
991 field: lhs.field,
992 scale: lhs.scale / denom,
993 offset: lhs.offset / denom,
994 })
995 }
996}
997
998fn fold_binary_literals(op: BinaryOp, left: &Literal, right: &Literal) -> Option<Literal> {
999 match op {
1000 BinaryOp::BitwiseShiftLeft | BinaryOp::BitwiseShiftRight => {
1001 let pg_op = match op {
1002 BinaryOp::BitwiseShiftLeft => BinaryOperator::PGBitwiseShiftLeft,
1003 BinaryOp::BitwiseShiftRight => BinaryOperator::PGBitwiseShiftRight,
1004 _ => unreachable!(),
1005 };
1006 crate::literal::bitshift_literals(pg_op, left, right).ok()
1007 }
1008 _ => {
1009 let l_arr = ScalarEvaluator::literal_to_array(left);
1010 let r_arr = ScalarEvaluator::literal_to_array(right);
1011 let result = compute_binary(&l_arr, &r_arr, op).ok()?;
1012 if result.is_null(0) {
1013 Some(Literal::Null)
1014 } else {
1015 Literal::from_array_ref(&result, 0).ok()
1016 }
1017 }
1018 }
1019}
1020
1021fn fold_cast_literal(lit: &Literal, data_type: &DataType) -> Option<Literal> {
1022 if matches!(lit, Literal::Null) {
1023 return None;
1025 }
1026 let arr = ScalarEvaluator::literal_to_array(lit);
1027 let casted = cast::cast(&arr, data_type).ok()?;
1028 if casted.is_null(0) {
1029 Some(Literal::Null)
1030 } else {
1031 Literal::from_array_ref(&casted, 0).ok()
1032 }
1033}