1use std::{convert::TryFrom, sync::Arc};
8
9use arrow::array::{Array, ArrayRef, Float64Array, Int64Array};
10use arrow::compute::cast;
11use arrow::datatypes::DataType;
12use llkv_column_map::types::LogicalFieldId;
13use llkv_expr::{BinaryOp, CompareOp, ScalarExpr};
14use llkv_result::{Error, Result as LlkvResult};
15use rustc_hash::{FxHashMap, FxHashSet};
16
17use crate::types::FieldId;
18
19pub type NumericArrayMap = FxHashMap<FieldId, NumericArray>;
21
/// The storage class used for a numeric value or column.
///
/// Integer data is carried as `i64` and floating point data as `f64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericKind {
    /// Values are carried as `i64`.
    Integer,
    /// Values are carried as `f64`.
    Float,
}
28
/// A single non-null numeric scalar.
///
/// SQL NULL is represented elsewhere in this module as `Option::<NumericValue>::None`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NumericValue {
    /// A signed 64-bit integer.
    Integer(i64),
    /// A 64-bit floating point number.
    Float(f64),
}
35
36impl NumericValue {
37 #[inline]
38 pub fn as_f64(self) -> f64 {
39 match self {
40 NumericValue::Integer(v) => v as f64,
41 NumericValue::Float(v) => v,
42 }
43 }
44
45 #[inline]
46 pub fn as_i64(self) -> Option<i64> {
47 match self {
48 NumericValue::Integer(v) => Some(v),
49 NumericValue::Float(_) => None,
50 }
51 }
52
53 #[inline]
54 pub fn kind(self) -> NumericKind {
55 match self {
56 NumericValue::Integer(_) => NumericKind::Integer,
57 NumericValue::Float(_) => NumericKind::Float,
58 }
59 }
60}
61
62impl From<i64> for NumericValue {
63 fn from(value: i64) -> Self {
64 NumericValue::Integer(value)
65 }
66}
67
68impl From<f64> for NumericValue {
69 fn from(value: f64) -> Self {
70 NumericValue::Float(value)
71 }
72}
73
74#[derive(Clone)]
76pub struct NumericArray {
77 kind: NumericKind,
78 len: usize,
79 int_data: Option<Arc<Int64Array>>,
80 float_data: Option<Arc<Float64Array>>,
81}
82
83impl NumericArray {
84 pub(crate) fn from_int(array: Arc<Int64Array>) -> Self {
85 let len = array.len();
86 Self {
87 kind: NumericKind::Integer,
88 len,
89 int_data: Some(array),
90 float_data: None,
91 }
92 }
93
94 pub(crate) fn from_float(array: Arc<Float64Array>) -> Self {
95 let len = array.len();
96 Self {
97 kind: NumericKind::Float,
98 len,
99 int_data: None,
100 float_data: Some(array),
101 }
102 }
103
104 pub fn try_from_arrow(array: &ArrayRef) -> LlkvResult<Self> {
106 match array.data_type() {
107 DataType::Int64 => {
108 let int_array = array
109 .as_any()
110 .downcast_ref::<Int64Array>()
111 .ok_or_else(|| Error::Internal("expected Int64 array".into()))?
112 .clone();
113 Ok(NumericArray::from_int(Arc::new(int_array)))
114 }
115 DataType::Float64 => {
116 let float_array = array
117 .as_any()
118 .downcast_ref::<Float64Array>()
119 .ok_or_else(|| Error::Internal("expected Float64 array".into()))?
120 .clone();
121 Ok(NumericArray::from_float(Arc::new(float_array)))
122 }
123 DataType::Int8 | DataType::Int16 | DataType::Int32 => {
124 let casted = cast(array.as_ref(), &DataType::Int64)
125 .map_err(|e| Error::Internal(format!("cast to Int64 failed: {e}")))?;
126 let int_array = casted
127 .as_any()
128 .downcast_ref::<Int64Array>()
129 .ok_or_else(|| Error::Internal("cast produced non-Int64 array".into()))?
130 .clone();
131 Ok(NumericArray::from_int(Arc::new(int_array)))
132 }
133 DataType::UInt8
134 | DataType::UInt16
135 | DataType::UInt32
136 | DataType::UInt64
137 | DataType::Float32 => {
138 let casted = cast(array.as_ref(), &DataType::Float64)
139 .map_err(|e| Error::Internal(format!("cast to Float64 failed: {e}")))?;
140 let float_array = casted
141 .as_any()
142 .downcast_ref::<Float64Array>()
143 .ok_or_else(|| Error::Internal("cast produced non-Float64 array".into()))?
144 .clone();
145 Ok(NumericArray::from_float(Arc::new(float_array)))
146 }
147 DataType::Boolean => {
148 let casted = cast(array.as_ref(), &DataType::Int64)
149 .map_err(|e| Error::Internal(format!("cast to Int64 failed: {e}")))?;
150 let int_array = casted
151 .as_any()
152 .downcast_ref::<Int64Array>()
153 .ok_or_else(|| Error::Internal("cast produced non-Int64 array".into()))?
154 .clone();
155 Ok(NumericArray::from_int(Arc::new(int_array)))
156 }
157 DataType::Null => {
158 let float_array = Float64Array::from(vec![None; array.len()]);
159 Ok(NumericArray::from_float(Arc::new(float_array)))
160 }
161 other => Err(Error::InvalidArgumentError(format!(
162 "unsupported data type in numeric kernel: {other:?}"
163 ))),
164 }
165 }
166
167 #[inline]
168 pub fn kind(&self) -> NumericKind {
169 self.kind
170 }
171
172 #[inline]
173 pub fn len(&self) -> usize {
174 self.len
175 }
176
177 #[inline]
178 pub fn is_empty(&self) -> bool {
179 self.len == 0
180 }
181
182 pub fn value(&self, idx: usize) -> Option<NumericValue> {
183 match self.kind {
184 NumericKind::Integer => {
185 let array = self
186 .int_data
187 .as_ref()
188 .expect("integer array missing backing data");
189 if array.is_null(idx) {
190 None
191 } else {
192 Some(NumericValue::Integer(array.value(idx)))
193 }
194 }
195 NumericKind::Float => {
196 let array = self
197 .float_data
198 .as_ref()
199 .expect("float array missing backing data");
200 if array.is_null(idx) {
201 None
202 } else {
203 Some(NumericValue::Float(array.value(idx)))
204 }
205 }
206 }
207 }
208
209 fn to_array_ref(&self) -> ArrayRef {
210 match self.kind {
211 NumericKind::Integer => Arc::clone(
212 self.int_data
213 .as_ref()
214 .expect("integer array missing backing data"),
215 ) as ArrayRef,
216 NumericKind::Float => Arc::clone(
217 self.float_data
218 .as_ref()
219 .expect("float array missing backing data"),
220 ) as ArrayRef,
221 }
222 }
223
224 fn promote_to_float(&self) -> NumericArray {
225 match self.kind {
226 NumericKind::Float => self.clone(),
227 NumericKind::Integer => {
228 let array = self
229 .int_data
230 .as_ref()
231 .expect("integer array missing backing data");
232 let iter = (0..self.len).map(|idx| {
233 if array.is_null(idx) {
234 None
235 } else {
236 Some(array.value(idx) as f64)
237 }
238 });
239 let float_array = Float64Array::from_iter(iter);
240 NumericArray::from_float(Arc::new(float_array))
241 }
242 }
243 }
244
245 fn to_aligned_array_ref(&self, preferred: NumericKind) -> ArrayRef {
246 match (preferred, self.kind) {
247 (NumericKind::Float, NumericKind::Integer) => self.promote_to_float().to_array_ref(),
248 _ => self.to_array_ref(),
249 }
250 }
251
252 fn from_numeric_values(values: Vec<Option<NumericValue>>, preferred: NumericKind) -> Self {
253 let contains_float = values
254 .iter()
255 .any(|opt| matches!(opt, Some(NumericValue::Float(_))));
256 match (contains_float, preferred) {
257 (true, _) => {
258 let iter = values.into_iter().map(|opt| opt.map(|v| v.as_f64()));
259 let array = Float64Array::from_iter(iter);
260 NumericArray::from_float(Arc::new(array))
261 }
262 (false, NumericKind::Float) => {
263 let iter = values.into_iter().map(|opt| opt.map(|v| v.as_f64()));
264 let array = Float64Array::from_iter(iter);
265 NumericArray::from_float(Arc::new(array))
266 }
267 (false, NumericKind::Integer) => {
268 let iter = values
269 .into_iter()
270 .map(|opt| opt.map(|v| v.as_i64().expect("expected integer")));
271 let array = Int64Array::from_iter(iter);
272 NumericArray::from_int(Arc::new(array))
273 }
274 }
275 }
276}
277
278enum VectorizedExpr {
280 Array(NumericArray),
281 Scalar(Option<NumericValue>),
282}
283
284impl VectorizedExpr {
285 fn materialize(self, len: usize, kind: NumericKind) -> ArrayRef {
286 match self {
287 VectorizedExpr::Array(array) => array.to_aligned_array_ref(kind),
288 VectorizedExpr::Scalar(Some(value)) => {
289 let target_kind = match (value.kind(), kind) {
290 (NumericKind::Float, _) => NumericKind::Float,
291 (NumericKind::Integer, NumericKind::Float) => NumericKind::Float,
292 (NumericKind::Integer, NumericKind::Integer) => NumericKind::Integer,
293 };
294 let values = vec![Some(value); len];
295 let array = NumericArray::from_numeric_values(values, target_kind);
296 array.to_aligned_array_ref(kind)
297 }
298 VectorizedExpr::Scalar(None) => {
299 let values = vec![None; len];
300 let array = NumericArray::from_numeric_values(values, kind);
301 array.to_aligned_array_ref(kind)
302 }
303 }
304 }
305}
306
307#[derive(Clone, Copy, Debug)]
309pub struct AffineExpr {
310 pub field: FieldId,
311 pub scale: f64,
312 pub offset: f64,
313}
314
315#[derive(Clone, Copy, Debug)]
317struct AffineState {
318 field: Option<FieldId>,
319 scale: f64,
320 offset: f64,
321}
322
323fn merge_field(lhs: Option<FieldId>, rhs: Option<FieldId>) -> Option<Option<FieldId>> {
326 match (lhs, rhs) {
327 (Some(a), Some(b)) => {
328 if a == b {
329 Some(Some(a))
330 } else {
331 None
332 }
333 }
334 (Some(a), None) => Some(Some(a)),
335 (None, Some(b)) => Some(Some(b)),
336 (None, None) => Some(None),
337 }
338}
339
/// Namespace for the numeric scalar-expression kernels: field collection, simplification,
/// affine extraction, and row-wise / batch evaluation over [`NumericArray`] columns.
pub struct NumericKernels;
343
344impl NumericKernels {
345 pub fn collect_fields(expr: &ScalarExpr<FieldId>, acc: &mut FxHashSet<FieldId>) {
347 match expr {
348 ScalarExpr::Column(fid) => {
349 acc.insert(*fid);
350 }
351 ScalarExpr::Literal(_) => {}
352 ScalarExpr::Binary { left, right, .. } => {
353 Self::collect_fields(left, acc);
354 Self::collect_fields(right, acc);
355 }
356 ScalarExpr::Aggregate(agg) => {
357 match agg {
359 llkv_expr::expr::AggregateCall::CountStar => {}
360 llkv_expr::expr::AggregateCall::Count(fid)
361 | llkv_expr::expr::AggregateCall::Sum(fid)
362 | llkv_expr::expr::AggregateCall::Min(fid)
363 | llkv_expr::expr::AggregateCall::Max(fid)
364 | llkv_expr::expr::AggregateCall::CountNulls(fid) => {
365 acc.insert(*fid);
366 }
367 }
368 }
369 ScalarExpr::GetField { base, .. } => {
370 Self::collect_fields(base, acc);
372 }
373 }
374 }
375
376 pub fn prepare_numeric_arrays(
378 lfids: &[LogicalFieldId],
379 arrays: &[ArrayRef],
380 needed_fields: &FxHashSet<FieldId>,
381 ) -> LlkvResult<NumericArrayMap> {
382 let mut out: NumericArrayMap = FxHashMap::default();
383 if needed_fields.is_empty() {
384 return Ok(out);
385 }
386 for (lfid, array) in lfids.iter().zip(arrays.iter()) {
387 let fid = lfid.field_id();
388 if !needed_fields.contains(&fid) {
389 continue;
390 }
391 let numeric = Self::coerce_array(array)?;
392 out.insert(fid, numeric);
393 }
394 Ok(out)
395 }
396
397 pub fn evaluate_value(
399 expr: &ScalarExpr<FieldId>,
400 idx: usize,
401 arrays: &NumericArrayMap,
402 ) -> LlkvResult<Option<NumericValue>> {
403 match expr {
404 ScalarExpr::Column(fid) => {
405 let array = arrays
406 .get(fid)
407 .ok_or_else(|| Error::Internal(format!("missing column for field {fid}")))?;
408 Ok(array.value(idx))
409 }
410 ScalarExpr::Literal(_) => Ok(Self::literal_numeric_value(expr)),
411 ScalarExpr::Binary { left, op, right } => {
412 let l = Self::evaluate_value(left, idx, arrays)?;
413 let r = Self::evaluate_value(right, idx, arrays)?;
414 Ok(Self::apply_binary(*op, l, r))
415 }
416 ScalarExpr::Aggregate(_) => Err(Error::Internal(
417 "Aggregate expressions should not appear in row-level evaluation".into(),
418 )),
419 ScalarExpr::GetField { .. } => Err(Error::Internal(
420 "GetField expressions should not be evaluated in numeric kernels".into(),
421 )),
422 }
423 }
424
425 #[allow(dead_code)]
427 pub fn evaluate_batch(
428 expr: &ScalarExpr<FieldId>,
429 len: usize,
430 arrays: &NumericArrayMap,
431 ) -> LlkvResult<ArrayRef> {
432 let simplified = Self::simplify(expr);
433 Self::evaluate_batch_simplified(&simplified, len, arrays)
434 }
435
436 pub fn evaluate_batch_simplified(
438 expr: &ScalarExpr<FieldId>,
439 len: usize,
440 arrays: &NumericArrayMap,
441 ) -> LlkvResult<ArrayRef> {
442 let preferred = Self::infer_result_kind(expr, arrays);
443 if let Some(vectorized) = Self::try_evaluate_vectorized(expr, len, arrays, preferred)? {
444 return Ok(vectorized.materialize(len, preferred));
445 }
446
447 let mut values: Vec<Option<NumericValue>> = Vec::with_capacity(len);
448 for idx in 0..len {
449 values.push(Self::evaluate_value(expr, idx, arrays)?);
450 }
451 let array = NumericArray::from_numeric_values(values, preferred);
452 Ok(array.to_aligned_array_ref(preferred))
453 }
454
455 fn try_evaluate_vectorized(
456 expr: &ScalarExpr<FieldId>,
457 len: usize,
458 arrays: &NumericArrayMap,
459 preferred: NumericKind,
460 ) -> LlkvResult<Option<VectorizedExpr>> {
461 match expr {
462 ScalarExpr::Column(fid) => {
463 let array = arrays
464 .get(fid)
465 .ok_or_else(|| Error::Internal(format!("missing column for field {fid}")))?;
466 Ok(Some(VectorizedExpr::Array(array.clone())))
467 }
468 ScalarExpr::Literal(_) => Ok(Some(VectorizedExpr::Scalar(
469 Self::literal_numeric_value(expr),
470 ))),
471 ScalarExpr::Binary { left, op, right } => {
472 let left_kind = Self::infer_result_kind(left, arrays);
473 let right_kind = Self::infer_result_kind(right, arrays);
474
475 let left_vec = Self::try_evaluate_vectorized(left, len, arrays, left_kind)?;
476 let right_vec = Self::try_evaluate_vectorized(right, len, arrays, right_kind)?;
477
478 match (left_vec, right_vec) {
479 (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Scalar(rhs))) => Ok(
480 Some(VectorizedExpr::Scalar(Self::apply_binary(*op, lhs, rhs))),
481 ),
482 (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Array(rhs))) => {
483 let array =
484 Self::compute_binary_array_array(&lhs, &rhs, len, *op, preferred)?;
485 Ok(Some(VectorizedExpr::Array(array)))
486 }
487 (Some(VectorizedExpr::Array(lhs)), Some(VectorizedExpr::Scalar(rhs))) => {
488 let array = Self::compute_binary_array_scalar(
489 &lhs, rhs, len, *op, true, preferred,
490 )?;
491 Ok(Some(VectorizedExpr::Array(array)))
492 }
493 (Some(VectorizedExpr::Scalar(lhs)), Some(VectorizedExpr::Array(rhs))) => {
494 let array = Self::compute_binary_array_scalar(
495 &rhs, lhs, len, *op, false, preferred,
496 )?;
497 Ok(Some(VectorizedExpr::Array(array)))
498 }
499 _ => Ok(None),
500 }
501 }
502 ScalarExpr::Aggregate(_) => Err(Error::Internal(
503 "Aggregate expressions should not appear in row-level evaluation".into(),
504 )),
505 ScalarExpr::GetField { .. } => Err(Error::Internal(
506 "GetField expressions should not be evaluated in numeric kernels".into(),
507 )),
508 }
509 }
510
511 fn compute_binary_array_array(
512 left: &NumericArray,
513 right: &NumericArray,
514 len: usize,
515 op: BinaryOp,
516 preferred: NumericKind,
517 ) -> LlkvResult<NumericArray> {
518 if left.len() != len || right.len() != len {
519 return Err(Error::Internal("scalar expression length mismatch".into()));
520 }
521
522 let iter = (0..len).map(|idx| {
523 let lhs = left.value(idx);
524 let rhs = right.value(idx);
525 Self::apply_binary(op, lhs, rhs)
526 });
527
528 let values = iter.collect::<Vec<_>>();
529 Ok(NumericArray::from_numeric_values(values, preferred))
530 }
531
532 fn compute_binary_array_scalar(
533 array: &NumericArray,
534 scalar: Option<NumericValue>,
535 len: usize,
536 op: BinaryOp,
537 array_is_left: bool,
538 preferred: NumericKind,
539 ) -> LlkvResult<NumericArray> {
540 if array.len() != len {
541 return Err(Error::Internal("scalar expression length mismatch".into()));
542 }
543
544 if scalar.is_none() {
545 return Ok(NumericArray::from_numeric_values(
546 vec![None; len],
547 preferred,
548 ));
549 }
550 let scalar_value = scalar.expect("checked above");
551
552 if array_is_left && matches!(op, BinaryOp::Divide | BinaryOp::Modulo) {
553 let is_zero = matches!(
554 scalar_value,
555 NumericValue::Integer(0) | NumericValue::Float(0.0)
556 );
557 if is_zero {
558 return Ok(NumericArray::from_numeric_values(
559 vec![None; len],
560 preferred,
561 ));
562 }
563 }
564
565 let iter = (0..len).map(|idx| {
566 let array_val = array.value(idx);
567 let (lhs, rhs) = if array_is_left {
568 (array_val, Some(scalar_value))
569 } else {
570 (Some(scalar_value), array_val)
571 };
572 Self::apply_binary(op, lhs, rhs)
573 });
574
575 let values = iter.collect::<Vec<_>>();
576 Ok(NumericArray::from_numeric_values(values, preferred))
577 }
578
579 pub fn passthrough_column(expr: &ScalarExpr<FieldId>) -> Option<FieldId> {
581 match Self::simplify(expr) {
582 ScalarExpr::Column(fid) => Some(fid),
583 _ => None,
584 }
585 }
586
587 fn literal_numeric_value(expr: &ScalarExpr<FieldId>) -> Option<NumericValue> {
588 if let ScalarExpr::Literal(lit) = expr {
589 match lit {
590 llkv_expr::literal::Literal::Float(f) => Some(NumericValue::Float(*f)),
591 llkv_expr::literal::Literal::Integer(i) => {
592 if let Ok(value) = i64::try_from(*i) {
593 Some(NumericValue::Integer(value))
594 } else {
595 Some(NumericValue::Float(*i as f64))
596 }
597 }
598 llkv_expr::literal::Literal::Boolean(b) => {
599 Some(NumericValue::Integer(if *b { 1 } else { 0 }))
600 }
601 llkv_expr::literal::Literal::String(_) => None,
602 llkv_expr::literal::Literal::Struct(_) => None,
603 llkv_expr::literal::Literal::Null => None,
604 }
605 } else {
606 None
607 }
608 }
609
610 fn literal_is_zero(expr: &ScalarExpr<FieldId>) -> bool {
611 matches!(
612 Self::literal_numeric_value(expr),
613 Some(NumericValue::Integer(0)) | Some(NumericValue::Float(0.0))
614 )
615 }
616
617 fn literal_is_one(expr: &ScalarExpr<FieldId>) -> bool {
618 matches!(
619 Self::literal_numeric_value(expr),
620 Some(NumericValue::Integer(1)) | Some(NumericValue::Float(1.0))
621 )
622 }
623
624 pub fn simplify(expr: &ScalarExpr<FieldId>) -> ScalarExpr<FieldId> {
626 match expr {
627 ScalarExpr::Column(_)
628 | ScalarExpr::Literal(_)
629 | ScalarExpr::Aggregate(_)
630 | ScalarExpr::GetField { .. } => expr.clone(),
631 ScalarExpr::Binary { left, op, right } => {
632 let left_s = Self::simplify(left);
633 let right_s = Self::simplify(right);
634
635 if let (Some(lv), Some(rv)) = (
636 Self::literal_numeric_value(&left_s),
637 Self::literal_numeric_value(&right_s),
638 ) && let Some(lit) = Self::apply_binary_literal(*op, lv, rv)
639 {
640 return lit;
641 }
642
643 match op {
644 BinaryOp::Add => {
645 if Self::literal_is_zero(&left_s) {
646 return right_s;
647 }
648 if Self::literal_is_zero(&right_s) {
649 return left_s;
650 }
651 }
652 BinaryOp::Subtract => {
653 if Self::literal_is_zero(&right_s) {
654 return left_s;
655 }
656 }
657 BinaryOp::Multiply => {
658 if Self::literal_is_one(&left_s) {
659 return right_s;
660 }
661 if Self::literal_is_one(&right_s) {
662 return left_s;
663 }
664 if Self::literal_is_zero(&left_s) || Self::literal_is_zero(&right_s) {
665 return ScalarExpr::literal(0);
666 }
667 }
668 BinaryOp::Divide => {
669 if Self::literal_is_one(&right_s) {
670 return left_s;
671 }
672 }
673 BinaryOp::Modulo => {}
674 }
675
676 ScalarExpr::binary(left_s, *op, right_s)
677 }
678 }
679 }
680
681 #[allow(dead_code)]
684 pub fn extract_affine(expr: &ScalarExpr<FieldId>) -> Option<AffineExpr> {
685 let simplified = Self::simplify(expr);
686 Self::extract_affine_simplified(&simplified)
687 }
688
689 pub fn extract_affine_simplified(expr: &ScalarExpr<FieldId>) -> Option<AffineExpr> {
691 let state = Self::affine_state(expr)?;
692 let field = state.field?;
693 Some(AffineExpr {
694 field,
695 scale: state.scale,
696 offset: state.offset,
697 })
698 }
699
700 fn affine_state(expr: &ScalarExpr<FieldId>) -> Option<AffineState> {
701 match expr {
702 ScalarExpr::Column(fid) => Some(AffineState {
703 field: Some(*fid),
704 scale: 1.0,
705 offset: 0.0,
706 }),
707 ScalarExpr::Literal(_) => {
708 let value = Self::literal_numeric_value(expr)?.as_f64();
709 Some(AffineState {
710 field: None,
711 scale: 0.0,
712 offset: value,
713 })
714 }
715 ScalarExpr::Aggregate(_) => None, ScalarExpr::GetField { .. } => None, ScalarExpr::Binary { left, op, right } => {
718 let left_state = Self::affine_state(left)?;
719 let right_state = Self::affine_state(right)?;
720 match op {
721 BinaryOp::Add => Self::affine_add(left_state, right_state),
722 BinaryOp::Subtract => Self::affine_sub(left_state, right_state),
723 BinaryOp::Multiply => Self::affine_mul(left_state, right_state),
724 BinaryOp::Divide => Self::affine_div(left_state, right_state),
725 BinaryOp::Modulo => None,
726 }
727 }
728 }
729 }
730
731 fn affine_add(lhs: AffineState, rhs: AffineState) -> Option<AffineState> {
732 let field = merge_field(lhs.field, rhs.field)?;
733 Some(AffineState {
734 field,
735 scale: lhs.scale + rhs.scale,
736 offset: lhs.offset + rhs.offset,
737 })
738 }
739
740 fn affine_sub(lhs: AffineState, rhs: AffineState) -> Option<AffineState> {
741 let neg_rhs = AffineState {
742 field: rhs.field,
743 scale: -rhs.scale,
744 offset: -rhs.offset,
745 };
746 Self::affine_add(lhs, neg_rhs)
747 }
748
749 fn affine_mul(lhs: AffineState, rhs: AffineState) -> Option<AffineState> {
750 if rhs.field.is_none() {
751 let factor = rhs.offset;
752 return Some(AffineState {
753 field: lhs.field,
754 scale: lhs.scale * factor,
755 offset: lhs.offset * factor,
756 });
757 }
758 if lhs.field.is_none() {
759 let factor = lhs.offset;
760 return Some(AffineState {
761 field: rhs.field,
762 scale: rhs.scale * factor,
763 offset: rhs.offset * factor,
764 });
765 }
766 None
767 }
768
769 fn affine_div(lhs: AffineState, rhs: AffineState) -> Option<AffineState> {
770 if rhs.field.is_some() {
771 return None;
772 }
773 let denom = rhs.offset;
774 if denom == 0.0 {
775 return None;
776 }
777 Some(AffineState {
778 field: lhs.field,
779 scale: lhs.scale / denom,
780 offset: lhs.offset / denom,
781 })
782 }
783
784 fn apply_binary_literal(
785 op: BinaryOp,
786 lhs: NumericValue,
787 rhs: NumericValue,
788 ) -> Option<ScalarExpr<FieldId>> {
789 match op {
790 BinaryOp::Add => Some(Self::literal_from_numeric(Self::add_values(lhs, rhs))),
791 BinaryOp::Subtract => Some(Self::literal_from_numeric(Self::sub_values(lhs, rhs))),
792 BinaryOp::Multiply => Some(Self::literal_from_numeric(Self::mul_values(lhs, rhs))),
793 BinaryOp::Divide => Self::div_values(lhs, rhs).map(Self::literal_from_numeric),
794 BinaryOp::Modulo => Self::mod_values(lhs, rhs).map(Self::literal_from_numeric),
795 }
796 }
797
798 fn literal_from_numeric(value: NumericValue) -> ScalarExpr<FieldId> {
799 match value {
800 NumericValue::Integer(i) => ScalarExpr::literal(i),
801 NumericValue::Float(f) => ScalarExpr::literal(f),
802 }
803 }
804
805 pub fn apply_binary(
807 op: BinaryOp,
808 lhs: Option<NumericValue>,
809 rhs: Option<NumericValue>,
810 ) -> Option<NumericValue> {
811 match (lhs, rhs) {
812 (Some(lv), Some(rv)) => Self::apply_binary_values(op, lv, rv),
813 _ => None,
814 }
815 }
816
817 fn apply_binary_values(
818 op: BinaryOp,
819 lhs: NumericValue,
820 rhs: NumericValue,
821 ) -> Option<NumericValue> {
822 match op {
823 BinaryOp::Add => Some(Self::add_values(lhs, rhs)),
824 BinaryOp::Subtract => Some(Self::sub_values(lhs, rhs)),
825 BinaryOp::Multiply => Some(Self::mul_values(lhs, rhs)),
826 BinaryOp::Divide => Self::div_values(lhs, rhs),
827 BinaryOp::Modulo => Self::mod_values(lhs, rhs),
828 }
829 }
830
831 fn add_values(lhs: NumericValue, rhs: NumericValue) -> NumericValue {
832 match (lhs, rhs) {
833 (NumericValue::Integer(li), NumericValue::Integer(ri)) => match li.checked_add(ri) {
834 Some(sum) => NumericValue::Integer(sum),
835 None => NumericValue::Float(li as f64 + ri as f64),
836 },
837 _ => NumericValue::Float(lhs.as_f64() + rhs.as_f64()),
838 }
839 }
840
841 fn sub_values(lhs: NumericValue, rhs: NumericValue) -> NumericValue {
842 match (lhs, rhs) {
843 (NumericValue::Integer(li), NumericValue::Integer(ri)) => match li.checked_sub(ri) {
844 Some(diff) => NumericValue::Integer(diff),
845 None => NumericValue::Float(li as f64 - ri as f64),
846 },
847 _ => NumericValue::Float(lhs.as_f64() - rhs.as_f64()),
848 }
849 }
850
851 fn mul_values(lhs: NumericValue, rhs: NumericValue) -> NumericValue {
852 match (lhs, rhs) {
853 (NumericValue::Integer(li), NumericValue::Integer(ri)) => match li.checked_mul(ri) {
854 Some(prod) => NumericValue::Integer(prod),
855 None => NumericValue::Float(li as f64 * ri as f64),
856 },
857 _ => NumericValue::Float(lhs.as_f64() * rhs.as_f64()),
858 }
859 }
860
861 fn div_values(lhs: NumericValue, rhs: NumericValue) -> Option<NumericValue> {
862 match rhs {
863 NumericValue::Integer(0) | NumericValue::Float(0.0) => return None,
864 _ => {}
865 }
866
867 match (lhs, rhs) {
868 (NumericValue::Integer(li), NumericValue::Integer(ri)) => {
869 Some(NumericValue::Integer(li / ri))
870 }
871 _ => Some(NumericValue::Float(lhs.as_f64() / rhs.as_f64())),
872 }
873 }
874
875 fn mod_values(lhs: NumericValue, rhs: NumericValue) -> Option<NumericValue> {
876 match rhs {
877 NumericValue::Integer(0) | NumericValue::Float(0.0) => return None,
878 _ => {}
879 }
880
881 match (lhs, rhs) {
882 (NumericValue::Integer(li), NumericValue::Integer(ri)) => {
883 Some(NumericValue::Integer(li % ri))
884 }
885 _ => Some(NumericValue::Float(lhs.as_f64() % rhs.as_f64())),
886 }
887 }
888
889 fn infer_result_kind(expr: &ScalarExpr<FieldId>, arrays: &NumericArrayMap) -> NumericKind {
890 match expr {
891 ScalarExpr::Literal(lit) => match lit {
892 llkv_expr::literal::Literal::Float(_) => NumericKind::Float,
893 llkv_expr::literal::Literal::Integer(_) => NumericKind::Integer,
894 llkv_expr::literal::Literal::Boolean(_) => NumericKind::Integer,
895 llkv_expr::literal::Literal::Null => NumericKind::Integer,
896 llkv_expr::literal::Literal::String(_) => NumericKind::Float,
897 llkv_expr::literal::Literal::Struct(_) => NumericKind::Float,
898 },
899 ScalarExpr::Column(fid) => arrays
900 .get(fid)
901 .map(|arr| arr.kind())
902 .unwrap_or(NumericKind::Float),
903 ScalarExpr::Binary { left, op, right } => {
904 let left_kind = Self::infer_result_kind(left, arrays);
905 let right_kind = Self::infer_result_kind(right, arrays);
906 Self::binary_result_kind(*op, left_kind, right_kind)
907 }
908 ScalarExpr::Aggregate(_) => NumericKind::Float,
909 ScalarExpr::GetField { .. } => NumericKind::Float,
910 }
911 }
912
913 pub fn infer_result_kind_from_types<F>(
915 expr: &ScalarExpr<FieldId>,
916 resolve_kind: &mut F,
917 ) -> Option<NumericKind>
918 where
919 F: FnMut(FieldId) -> Option<NumericKind>,
920 {
921 match expr {
922 ScalarExpr::Literal(_) => Self::literal_numeric_value(expr).map(|v| v.kind()),
923 ScalarExpr::Column(fid) => resolve_kind(*fid),
924 ScalarExpr::Binary { left, op, right } => {
925 let left_kind = Self::infer_result_kind_from_types(left, resolve_kind)?;
926 let right_kind = Self::infer_result_kind_from_types(right, resolve_kind)?;
927 Some(Self::binary_result_kind(*op, left_kind, right_kind))
928 }
929 ScalarExpr::Aggregate(_) => Some(NumericKind::Float),
930 ScalarExpr::GetField { .. } => None,
931 }
932 }
933
934 pub fn kind_for_data_type(dtype: &DataType) -> Option<NumericKind> {
936 match dtype {
937 DataType::Int8
938 | DataType::Int16
939 | DataType::Int32
940 | DataType::Int64
941 | DataType::Boolean => Some(NumericKind::Integer),
942 DataType::UInt8
943 | DataType::UInt16
944 | DataType::UInt32
945 | DataType::UInt64
946 | DataType::Float32
947 | DataType::Float64
948 | DataType::Null => Some(NumericKind::Float),
949 _ => None,
950 }
951 }
952
953 fn binary_result_kind(
954 op: BinaryOp,
955 lhs_kind: NumericKind,
956 rhs_kind: NumericKind,
957 ) -> NumericKind {
958 let lhs_value = match lhs_kind {
959 NumericKind::Integer => NumericValue::Integer(1),
960 NumericKind::Float => NumericValue::Float(1.0),
961 };
962 let rhs_value = match rhs_kind {
963 NumericKind::Integer => NumericValue::Integer(1),
964 NumericKind::Float => NumericValue::Float(1.0),
965 };
966
967 Self::apply_binary_values(op, lhs_value, rhs_value)
968 .unwrap_or(NumericValue::Float(0.0))
969 .kind()
970 }
971
972 pub fn compare(op: CompareOp, lhs: NumericValue, rhs: NumericValue) -> bool {
974 match (lhs, rhs) {
975 (NumericValue::Integer(li), NumericValue::Integer(ri)) => match op {
976 CompareOp::Eq => li == ri,
977 CompareOp::NotEq => li != ri,
978 CompareOp::Lt => li < ri,
979 CompareOp::LtEq => li <= ri,
980 CompareOp::Gt => li > ri,
981 CompareOp::GtEq => li >= ri,
982 },
983 (lv, rv) => {
984 let lf = lv.as_f64();
985 let rf = rv.as_f64();
986 match op {
987 CompareOp::Eq => lf == rf,
988 CompareOp::NotEq => lf != rf,
989 CompareOp::Lt => lf < rf,
990 CompareOp::LtEq => lf <= rf,
991 CompareOp::Gt => lf > rf,
992 CompareOp::GtEq => lf >= rf,
993 }
994 }
995 }
996 }
997
998 fn coerce_array(array: &ArrayRef) -> LlkvResult<NumericArray> {
999 NumericArray::try_from_arrow(array)
1000 }
1001}
1002
1003#[cfg(test)]
1004mod tests {
1005 use super::*;
1006 use arrow::array::{Float64Array, Int64Array};
1007 use llkv_expr::Literal;
1008
1009 fn float_array(values: &[Option<f64>]) -> NumericArray {
1010 let array = Float64Array::from(values.to_vec());
1011 NumericArray::from_float(Arc::new(array))
1012 }
1013
1014 fn int_array(values: &[Option<i64>]) -> NumericArray {
1015 let array = Int64Array::from(values.to_vec());
1016 NumericArray::from_int(Arc::new(array))
1017 }
1018
1019 #[test]
1020 fn integer_addition_preserves_int_type() {
1021 const F1: FieldId = 30;
1022 let mut arrays: NumericArrayMap = NumericArrayMap::default();
1023 arrays.insert(F1, int_array(&[Some(1), Some(-5), None, Some(42)]));
1024
1025 let expr = ScalarExpr::binary(
1026 ScalarExpr::column(F1),
1027 BinaryOp::Add,
1028 ScalarExpr::literal(3),
1029 );
1030
1031 let result = NumericKernels::evaluate_batch(&expr, 4, &arrays).unwrap();
1032 let array = result
1033 .as_ref()
1034 .as_any()
1035 .downcast_ref::<Int64Array>()
1036 .expect("expected Int64Array");
1037
1038 assert_eq!(array.len(), 4);
1039 assert_eq!(array.value(0), 4);
1040 assert_eq!(array.value(1), -2);
1041 assert!(array.is_null(2));
1042 assert_eq!(array.value(3), 45);
1043 }
1044
1045 #[test]
1046 fn integer_division_matches_sqlite_semantics() {
1047 const F1: FieldId = 31;
1048 let mut arrays: NumericArrayMap = NumericArrayMap::default();
1049 arrays.insert(F1, int_array(&[Some(5), Some(-7), Some(0), None]));
1050
1051 let expr = ScalarExpr::binary(
1052 ScalarExpr::column(F1),
1053 BinaryOp::Divide,
1054 ScalarExpr::literal(2),
1055 );
1056
1057 let result = NumericKernels::evaluate_batch(&expr, 4, &arrays).unwrap();
1058 let array = result
1059 .as_ref()
1060 .as_any()
1061 .downcast_ref::<Int64Array>()
1062 .expect("expected Int64Array");
1063
1064 assert_eq!(array.len(), 4);
1065 assert_eq!(array.value(0), 2);
1066 assert_eq!(array.value(1), -3);
1067 assert_eq!(array.value(2), 0);
1068 assert!(array.is_null(3));
1069 }
1070
1071 #[test]
1072 fn integer_overflow_promotes_to_float_array() {
1073 const F1: FieldId = 32;
1074 let mut arrays: NumericArrayMap = NumericArrayMap::default();
1075 arrays.insert(F1, int_array(&[Some(i64::MAX), Some(10)]));
1076
1077 let expr = ScalarExpr::binary(
1078 ScalarExpr::column(F1),
1079 BinaryOp::Add,
1080 ScalarExpr::literal(1),
1081 );
1082
1083 let result = NumericKernels::evaluate_batch(&expr, 2, &arrays).unwrap();
1084 assert!(
1085 result
1086 .as_ref()
1087 .as_any()
1088 .downcast_ref::<Int64Array>()
1089 .is_none()
1090 );
1091
1092 let array = result
1093 .as_ref()
1094 .as_any()
1095 .downcast_ref::<Float64Array>()
1096 .expect("expected Float64Array after overflow");
1097
1098 assert_eq!(array.len(), 2);
1099 assert!(array.value(0).is_finite());
1100 assert_eq!(array.value(1), 11.0);
1101 }
1102
1103 #[test]
1104 fn vectorized_add_columns() {
1105 const F1: FieldId = 1;
1106 const F2: FieldId = 2;
1107 let mut arrays: NumericArrayMap = NumericArrayMap::default();
1108 arrays.insert(F1, float_array(&[Some(1.0), Some(2.0), None, Some(-1.0)]));
1109 arrays.insert(
1110 F2,
1111 float_array(&[Some(5.0), Some(-1.0), Some(3.0), Some(4.0)]),
1112 );
1113
1114 let expr = ScalarExpr::binary(
1115 ScalarExpr::column(F1),
1116 BinaryOp::Add,
1117 ScalarExpr::column(F2),
1118 );
1119
1120 let result = NumericKernels::evaluate_batch(&expr, 4, &arrays).unwrap();
1121 let result = result
1122 .as_ref()
1123 .as_any()
1124 .downcast_ref::<Float64Array>()
1125 .unwrap();
1126
1127 assert_eq!(result.len(), 4);
1128 assert_eq!(result.value(0), 6.0);
1129 assert_eq!(result.value(1), 1.0);
1130 assert!(result.is_null(2));
1131 assert_eq!(result.value(3), 3.0);
1132 }
1133
#[test]
fn vectorized_multiply_literal() {
    // `F1 * 3` with an integer literal on the right; the trailing null row
    // must stay null.
    const F1: FieldId = 10;
    let mut inputs = NumericArrayMap::default();
    inputs.insert(F1, float_array(&[Some(1.0), Some(-2.5), Some(0.0), None]));

    let scaled = ScalarExpr::binary(
        ScalarExpr::column(F1),
        BinaryOp::Multiply,
        ScalarExpr::literal(3),
    );
    let output = NumericKernels::evaluate_batch(&scaled, 4, &inputs).unwrap();
    let floats = output
        .as_ref()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    assert_eq!(floats.len(), 4);
    for (row, want) in [3.0, -7.5, 0.0].iter().enumerate() {
        assert_eq!(floats.value(row), *want);
    }
    assert!(floats.is_null(3));
}
1159
#[test]
fn vectorized_add_column_scalar_literal() {
    // `F1 + 4` over a float column that has a null in the middle row.
    const F1: FieldId = 11;
    let mut inputs = NumericArrayMap::default();
    inputs.insert(F1, float_array(&[Some(2.0), None, Some(-5.5)]));

    let plus_four = ScalarExpr::binary(
        ScalarExpr::column(F1),
        BinaryOp::Add,
        ScalarExpr::literal(4),
    );
    let output = NumericKernels::evaluate_batch(&plus_four, 3, &inputs).unwrap();
    let floats = output
        .as_ref()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    assert_eq!(floats.len(), 3);
    assert_eq!(floats.value(0), 6.0);
    assert!(floats.is_null(1));
    // -5.5 + 4 is expected to be -1.5; compared within an epsilon tolerance.
    let drift = (floats.value(2) + 1.5).abs();
    assert!(drift < f64::EPSILON);
}
1184
#[test]
fn vectorized_literal_minus_column() {
    // Literal on the left-hand side (`10 - F1`), so the operand order of a
    // non-commutative operator is exercised.
    const F1: FieldId = 12;
    let mut inputs = NumericArrayMap::default();
    inputs.insert(F1, float_array(&[Some(3.0), Some(-2.0), None]));

    let difference = ScalarExpr::binary(
        ScalarExpr::literal(10),
        BinaryOp::Subtract,
        ScalarExpr::column(F1),
    );
    let output = NumericKernels::evaluate_batch(&difference, 3, &inputs).unwrap();
    let floats = output
        .as_ref()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    // `None` marks a row that is expected to be null.
    let expected = [Some(7.0), Some(12.0), None];
    assert_eq!(floats.len(), expected.len());
    for (row, want) in expected.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(floats.value(row), *v),
            None => assert!(floats.is_null(row)),
        }
    }
}
1209
#[test]
fn vectorized_divide_by_zero_yields_null() {
    // Column / column over floats:
    // - a zero divisor (row 1) must produce NULL;
    // - a NULL divisor (row 2) must propagate;
    // - rows with a valid non-zero divisor (rows 0 and 3) must stay valid.
    // `Float64Array::value` reads the raw slot even for null rows, so the
    // value assertions alone cannot detect a spurious NULL. The explicit
    // `is_valid` / `null_count` checks close that gap.
    const NUM: FieldId = 20;
    const DEN: FieldId = 21;
    let mut arrays: NumericArrayMap = NumericArrayMap::default();
    arrays.insert(
        NUM,
        float_array(&[Some(4.0), Some(9.0), Some(5.0), Some(-6.0)]),
    );
    arrays.insert(DEN, float_array(&[Some(2.0), Some(0.0), None, Some(-3.0)]));

    let expr = ScalarExpr::binary(
        ScalarExpr::column(NUM),
        BinaryOp::Divide,
        ScalarExpr::column(DEN),
    );

    let result = NumericKernels::evaluate_batch(&expr, 4, &arrays).unwrap();
    let result = result
        .as_ref()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 2);
    assert!(result.is_valid(0));
    assert_eq!(result.value(0), 2.0);
    assert!(result.is_null(1));
    assert!(result.is_null(2));
    assert!(result.is_valid(3));
    assert_eq!(result.value(3), 2.0);
}
1240
#[test]
fn vectorized_divide_by_zero_literal_rhs_yields_nulls() {
    // `F1 / 0` with a literal zero divisor: every row, including the row
    // that is already null, must come back null.
    const F1: FieldId = 22;
    let mut inputs = NumericArrayMap::default();
    inputs.insert(F1, float_array(&[Some(1.0), Some(-4.0), None]));

    let over_zero = ScalarExpr::binary(
        ScalarExpr::column(F1),
        BinaryOp::Divide,
        ScalarExpr::literal(0),
    );
    let output = NumericKernels::evaluate_batch(&over_zero, 3, &inputs).unwrap();
    let floats = output
        .as_ref()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    assert_eq!(floats.len(), 3);
    assert!((0..3).all(|row| floats.is_null(row)));
}
1265
#[test]
fn vectorized_modulo_literals() {
    // `13 % 5` with two literal operands is expected to simplify to the
    // integer literal 3.
    let remainder_expr = ScalarExpr::binary(
        ScalarExpr::literal(13),
        BinaryOp::Modulo,
        ScalarExpr::literal(5),
    );

    match NumericKernels::simplify(&remainder_expr) {
        ScalarExpr::Literal(Literal::Integer(remainder)) => assert_eq!(remainder, 3),
        _ => panic!("expected literal result"),
    }
}
1280
#[test]
fn vectorized_modulo_column_rhs_zero_yields_null() {
    // Column % column over floats:
    // - a zero divisor (row 1) must produce NULL;
    // - a NULL dividend (row 2) must propagate;
    // - rows 0 and 3 must stay valid.
    // `Float64Array::value` does not consult the validity bitmap, so
    // validity is asserted explicitly alongside the values.
    const NUM: FieldId = 23;
    const DEN: FieldId = 24;
    let mut arrays: NumericArrayMap = NumericArrayMap::default();
    arrays.insert(NUM, float_array(&[Some(4.0), Some(7.0), None, Some(-6.0)]));
    arrays.insert(
        DEN,
        float_array(&[Some(2.0), Some(0.0), Some(3.0), Some(-4.0)]),
    );

    let expr = ScalarExpr::binary(
        ScalarExpr::column(NUM),
        BinaryOp::Modulo,
        ScalarExpr::column(DEN),
    );

    let result = NumericKernels::evaluate_batch(&expr, 4, &arrays).unwrap();
    let result = result
        .as_ref()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 2);
    assert!(result.is_valid(0));
    assert_eq!(result.value(0), 0.0);
    assert!(result.is_null(1));
    assert!(result.is_null(2));
    assert!(result.is_valid(3));
    // Rust's float `%` keeps the sign of the dividend: -6.0 % -4.0 == -2.0.
    assert_eq!(result.value(3), -6.0 % -4.0);
}
1311
#[test]
fn passthrough_detects_identity_ops() {
    // `col + 0`, `col - 0`, `col * 1` and `col / 1` are identities and should
    // be reported as a passthrough of `F1`. `col + 2` is not an identity and
    // should yield `None`.
    const F1: FieldId = 99;

    let cases = [
        (BinaryOp::Add, 0, Some(F1)),
        (BinaryOp::Subtract, 0, Some(F1)),
        (BinaryOp::Multiply, 1, Some(F1)),
        (BinaryOp::Divide, 1, Some(F1)),
        (BinaryOp::Add, 2, None),
    ];

    for (op, operand, expected) in cases {
        let expr = ScalarExpr::binary(
            ScalarExpr::column(F1),
            op,
            ScalarExpr::literal(operand),
        );
        assert_eq!(NumericKernels::passthrough_column(&expr), expected);
    }
}
1352}