Skip to main content

kyu_executor/
batch_eval.rs

//! Batch expression evaluation — operates on columns/DataChunks instead of
//! per-row TypedValue evaluation. Returns `None` for unsupported patterns
//! so callers can fall back to the scalar `evaluate()` path.
4
5use kyu_common::KyuResult;
6use kyu_expression::BoundExpression;
7use kyu_parser::ast::{BinaryOp, ComparisonOp};
8use kyu_types::{LogicalType, TypedValue};
9use smol_str::SmolStr;
10
11use crate::data_chunk::DataChunk;
12use crate::value_vector::{SelectionVector, ValueVector};
13
// ---------------------------------------------------------------------------
// Filter batch evaluation
// ---------------------------------------------------------------------------
17
18/// Try to evaluate a predicate on a DataChunk in batch mode.
19/// Returns `Some(Ok(SelectionVector))` if handled, `None` to fall back to scalar.
20pub fn evaluate_filter_batch(
21    predicate: &BoundExpression,
22    chunk: &DataChunk,
23) -> Option<KyuResult<SelectionVector>> {
24    match predicate {
25        // Pattern: variable <cmp> literal (Int64)
26        BoundExpression::Comparison { op, left, right } => {
27            match (left.as_ref(), right.as_ref()) {
28                (
29                    BoundExpression::Variable {
30                        index,
31                        result_type: LogicalType::Int64 | LogicalType::Serial,
32                    },
33                    BoundExpression::Literal { value, .. },
34                ) => {
35                    if let TypedValue::Int64(lit) = value {
36                        Some(Ok(filter_cmp_var_lit_i64(
37                            chunk, *index as usize, *op, *lit,
38                        )))
39                    } else {
40                        None
41                    }
42                }
43                // Pattern: literal <cmp> variable — flip the comparison
44                (
45                    BoundExpression::Literal { value, .. },
46                    BoundExpression::Variable {
47                        index,
48                        result_type: LogicalType::Int64 | LogicalType::Serial,
49                    },
50                ) => {
51                    if let TypedValue::Int64(lit) = value {
52                        Some(Ok(filter_cmp_var_lit_i64(
53                            chunk,
54                            *index as usize,
55                            flip_cmp(*op),
56                            *lit,
57                        )))
58                    } else {
59                        None
60                    }
61                }
62                // Pattern: variable <cmp> literal (String)
63                (
64                    BoundExpression::Variable { index, result_type },
65                    BoundExpression::Literal { value, .. },
66                ) if *result_type == LogicalType::String => {
67                    if let TypedValue::String(lit) = value {
68                        Some(Ok(filter_cmp_var_lit_string(
69                            chunk, *index as usize, *op, lit,
70                        )))
71                    } else {
72                        None
73                    }
74                }
75                _ => None,
76            }
77        }
78        // Pattern: expr OR expr
79        BoundExpression::BinaryOp {
80            op: BinaryOp::Or,
81            left,
82            right,
83            ..
84        } => {
85            let left_sel = evaluate_filter_batch(left, chunk)?;
86            let right_sel = evaluate_filter_batch(right, chunk)?;
87            Some(match (left_sel, right_sel) {
88                (Ok(l), Ok(r)) => Ok(union_selections(&l, &r)),
89                (Err(e), _) | (_, Err(e)) => Err(e),
90            })
91        }
92        // Pattern: expr AND expr
93        BoundExpression::BinaryOp {
94            op: BinaryOp::And,
95            left,
96            right,
97            ..
98        } => {
99            let left_sel = evaluate_filter_batch(left, chunk)?;
100            let right_sel = evaluate_filter_batch(right, chunk)?;
101            Some(match (left_sel, right_sel) {
102                (Ok(l), Ok(r)) => Ok(intersect_selections(&l, &r)),
103                (Err(e), _) | (_, Err(e)) => Err(e),
104            })
105        }
106        // Pattern: variable STARTS WITH literal
107        BoundExpression::StringOp {
108            op: kyu_parser::ast::StringOp::StartsWith,
109            left,
110            right,
111        } => {
112            if let (
113                BoundExpression::Variable { index, .. },
114                BoundExpression::Literal {
115                    value: TypedValue::String(prefix),
116                    ..
117                },
118            ) = (left.as_ref(), right.as_ref())
119            {
120                Some(Ok(filter_starts_with(chunk, *index as usize, prefix)))
121            } else {
122                None
123            }
124        }
125        _ => None,
126    }
127}
128
// ---------------------------------------------------------------------------
// Column expression evaluation
// ---------------------------------------------------------------------------
132
133/// Try to evaluate an expression on a DataChunk, producing an output ValueVector.
134/// Returns `None` for unsupported patterns (caller falls back to scalar).
135pub fn evaluate_column(
136    expr: &BoundExpression,
137    chunk: &DataChunk,
138) -> Option<KyuResult<ValueVector>> {
139    match expr {
140        // Variable reference — pass through the column
141        BoundExpression::Variable { index, .. } => {
142            Some(Ok(chunk.compact_column(*index as usize)))
143        }
144        // Literal — broadcast to N copies
145        BoundExpression::Literal { value, .. } => {
146            let n = chunk.num_rows();
147            Some(Ok(ValueVector::Owned(vec![value.clone(); n])))
148        }
149        // BinaryOp on column data
150        BoundExpression::BinaryOp {
151            op,
152            left,
153            right,
154            result_type,
155        } => eval_binop_column(op, left, right, result_type, chunk),
156        _ => None,
157    }
158}
159
// ---------------------------------------------------------------------------
// Filter helpers
// ---------------------------------------------------------------------------
163
164fn filter_cmp_var_lit_i64(
165    chunk: &DataChunk,
166    col_idx: usize,
167    op: ComparisonOp,
168    literal: i64,
169) -> SelectionVector {
170    let sel = chunk.selection();
171    let n = sel.len();
172    let col = chunk.column(col_idx);
173
174    // Fast path: FlatVector — direct i64 slice access
175    if let ValueVector::Flat(flat) = col
176        && matches!(flat.logical_type(), LogicalType::Int64 | LogicalType::Serial)
177    {
178        let data = flat.data_as_i64_slice();
179        let mut selected = Vec::with_capacity(n);
180        for i in 0..n {
181            let phys = sel.get(i);
182            if !flat.is_null(phys) && cmp_i64(data[phys], literal, op) {
183                selected.push(phys as u32);
184            }
185        }
186        return SelectionVector::from_indices(selected);
187    }
188
189    // Fallback: Owned or other vector types
190    let mut selected = Vec::with_capacity(n);
191    for i in 0..n {
192        let phys = sel.get(i);
193        if let TypedValue::Int64(v) = col.get_value(phys)
194            && cmp_i64(v, literal, op)
195        {
196            selected.push(phys as u32);
197        }
198    }
199    SelectionVector::from_indices(selected)
200}
201
202#[inline]
203fn cmp_i64(val: i64, lit: i64, op: ComparisonOp) -> bool {
204    match op {
205        ComparisonOp::Lt => val < lit,
206        ComparisonOp::Le => val <= lit,
207        ComparisonOp::Gt => val > lit,
208        ComparisonOp::Ge => val >= lit,
209        ComparisonOp::Eq => val == lit,
210        ComparisonOp::Neq => val != lit,
211        ComparisonOp::RegexMatch => false,
212    }
213}
214
215fn flip_cmp(op: ComparisonOp) -> ComparisonOp {
216    match op {
217        ComparisonOp::Lt => ComparisonOp::Gt,
218        ComparisonOp::Le => ComparisonOp::Ge,
219        ComparisonOp::Gt => ComparisonOp::Lt,
220        ComparisonOp::Ge => ComparisonOp::Le,
221        ComparisonOp::Eq => ComparisonOp::Eq,
222        ComparisonOp::Neq => ComparisonOp::Neq,
223        ComparisonOp::RegexMatch => ComparisonOp::RegexMatch,
224    }
225}
226
227fn filter_cmp_var_lit_string(
228    chunk: &DataChunk,
229    col_idx: usize,
230    op: ComparisonOp,
231    literal: &SmolStr,
232) -> SelectionVector {
233    let sel = chunk.selection();
234    let n = sel.len();
235    let col = chunk.column(col_idx);
236
237    // Fast path: StringVector — direct data access
238    if let ValueVector::String(sv) = col {
239        let data = sv.data();
240        let mut selected = Vec::with_capacity(n);
241        for i in 0..n {
242            let phys = sel.get(i);
243            if let Some(ref s) = data[phys] {
244                let pass = match op {
245                    ComparisonOp::Eq => s == literal,
246                    ComparisonOp::Neq => s != literal,
247                    ComparisonOp::Lt => s < literal,
248                    ComparisonOp::Le => s <= literal,
249                    ComparisonOp::Gt => s > literal,
250                    ComparisonOp::Ge => s >= literal,
251                    ComparisonOp::RegexMatch => false,
252                };
253                if pass {
254                    selected.push(phys as u32);
255                }
256            }
257        }
258        return SelectionVector::from_indices(selected);
259    }
260
261    // Fallback
262    let mut selected = Vec::with_capacity(n);
263    for i in 0..n {
264        let phys = sel.get(i);
265        if let TypedValue::String(ref s) = col.get_value(phys) {
266            let pass = match op {
267                ComparisonOp::Eq => s == literal,
268                ComparisonOp::Neq => s != literal,
269                _ => false,
270            };
271            if pass {
272                selected.push(phys as u32);
273            }
274        }
275    }
276    SelectionVector::from_indices(selected)
277}
278
279fn filter_starts_with(
280    chunk: &DataChunk,
281    col_idx: usize,
282    prefix: &SmolStr,
283) -> SelectionVector {
284    let sel = chunk.selection();
285    let n = sel.len();
286    let col = chunk.column(col_idx);
287
288    if let ValueVector::String(sv) = col {
289        let data = sv.data();
290        let mut selected = Vec::with_capacity(n);
291        for i in 0..n {
292            let phys = sel.get(i);
293            if let Some(ref s) = data[phys]
294                && s.starts_with(prefix.as_str())
295            {
296                selected.push(phys as u32);
297            }
298        }
299        return SelectionVector::from_indices(selected);
300    }
301
302    // Fallback
303    let mut selected = Vec::with_capacity(n);
304    for i in 0..n {
305        let phys = sel.get(i);
306        if let TypedValue::String(ref s) = col.get_value(phys)
307            && s.starts_with(prefix.as_str())
308        {
309            selected.push(phys as u32);
310        }
311    }
312    SelectionVector::from_indices(selected)
313}
314
315/// Sorted merge union — O(n) since both inputs are already sorted.
316fn union_selections(a: &SelectionVector, b: &SelectionVector) -> SelectionVector {
317    let (na, nb) = (a.len(), b.len());
318    let mut result = Vec::with_capacity(na + nb);
319    let (mut i, mut j) = (0, 0);
320    while i < na && j < nb {
321        let va = a.get(i) as u32;
322        let vb = b.get(j) as u32;
323        match va.cmp(&vb) {
324            std::cmp::Ordering::Less => {
325                result.push(va);
326                i += 1;
327            }
328            std::cmp::Ordering::Greater => {
329                result.push(vb);
330                j += 1;
331            }
332            std::cmp::Ordering::Equal => {
333                result.push(va);
334                i += 1;
335                j += 1;
336            }
337        }
338    }
339    while i < na {
340        result.push(a.get(i) as u32);
341        i += 1;
342    }
343    while j < nb {
344        result.push(b.get(j) as u32);
345        j += 1;
346    }
347    SelectionVector::from_indices(result)
348}
349
350/// Sorted merge intersection — O(n) since both inputs are already sorted.
351fn intersect_selections(a: &SelectionVector, b: &SelectionVector) -> SelectionVector {
352    let (na, nb) = (a.len(), b.len());
353    let mut result = Vec::with_capacity(na.min(nb));
354    let (mut i, mut j) = (0, 0);
355    while i < na && j < nb {
356        let va = a.get(i) as u32;
357        let vb = b.get(j) as u32;
358        match va.cmp(&vb) {
359            std::cmp::Ordering::Less => i += 1,
360            std::cmp::Ordering::Greater => j += 1,
361            std::cmp::Ordering::Equal => {
362                result.push(va);
363                i += 1;
364                j += 1;
365            }
366        }
367    }
368    SelectionVector::from_indices(result)
369}
370
// ---------------------------------------------------------------------------
// Column expression helpers
// ---------------------------------------------------------------------------
374
375fn eval_binop_column(
376    op: &BinaryOp,
377    left: &BoundExpression,
378    right: &BoundExpression,
379    _result_type: &LogicalType,
380    chunk: &DataChunk,
381) -> Option<KyuResult<ValueVector>> {
382    // Pattern: Variable{i64} <op> Literal{i64}
383    if let (
384        BoundExpression::Variable { index, result_type },
385        BoundExpression::Literal {
386            value: TypedValue::Int64(lit),
387            ..
388        },
389    ) = (left, right)
390        && matches!(result_type, LogicalType::Int64 | LogicalType::Serial)
391    {
392        return Some(Ok(binop_col_lit_i64(chunk, *index as usize, *op, *lit)));
393    }
394
395    // Pattern: Literal{i64} <op> Variable{i64}
396    if let (
397        BoundExpression::Literal {
398            value: TypedValue::Int64(lit),
399            ..
400        },
401        BoundExpression::Variable { index, result_type },
402    ) = (left, right)
403        && matches!(result_type, LogicalType::Int64 | LogicalType::Serial)
404    {
405        return Some(Ok(binop_lit_col_i64(*lit, chunk, *index as usize, *op)));
406    }
407
408    // Pattern: nested BinaryOp — evaluate children as columns, then combine
409    let left_col = evaluate_column(left, chunk)?;
410    let right_col = evaluate_column(right, chunk)?;
411    Some(match (left_col, right_col) {
412        (Ok(lv), Ok(rv)) => Ok(binop_vec_vec_i64(&lv, &rv, *op, chunk.num_rows())),
413        (Err(e), _) | (_, Err(e)) => Err(e),
414    })
415}
416
417fn binop_col_lit_i64(
418    chunk: &DataChunk,
419    col_idx: usize,
420    op: BinaryOp,
421    literal: i64,
422) -> ValueVector {
423    let sel = chunk.selection();
424    let n = sel.len();
425    let col = chunk.column(col_idx);
426
427    if let ValueVector::Flat(flat) = col
428        && matches!(flat.logical_type(), LogicalType::Int64 | LogicalType::Serial)
429    {
430        let data = flat.data_as_i64_slice();
431        let mut result = Vec::with_capacity(n);
432        for i in 0..n {
433            let phys = sel.get(i);
434            if flat.is_null(phys) {
435                result.push(TypedValue::Null);
436            } else {
437                result.push(TypedValue::Int64(apply_i64_op(data[phys], literal, op)));
438            }
439        }
440        return ValueVector::Owned(result);
441    }
442
443    // Fallback: Owned
444    let mut result = Vec::with_capacity(n);
445    for i in 0..n {
446        let phys = sel.get(i);
447        match col.get_value(phys) {
448            TypedValue::Int64(v) => {
449                result.push(TypedValue::Int64(apply_i64_op(v, literal, op)));
450            }
451            _ => result.push(TypedValue::Null),
452        }
453    }
454    ValueVector::Owned(result)
455}
456
457fn binop_lit_col_i64(
458    literal: i64,
459    chunk: &DataChunk,
460    col_idx: usize,
461    op: BinaryOp,
462) -> ValueVector {
463    let sel = chunk.selection();
464    let n = sel.len();
465    let col = chunk.column(col_idx);
466    let mut result = Vec::with_capacity(n);
467    for i in 0..n {
468        let phys = sel.get(i);
469        match col.get_value(phys) {
470            TypedValue::Int64(v) => {
471                result.push(TypedValue::Int64(apply_i64_op(literal, v, op)));
472            }
473            _ => result.push(TypedValue::Null),
474        }
475    }
476    ValueVector::Owned(result)
477}
478
479fn binop_vec_vec_i64(
480    left: &ValueVector,
481    right: &ValueVector,
482    op: BinaryOp,
483    n: usize,
484) -> ValueVector {
485    let mut result = Vec::with_capacity(n);
486    for i in 0..n {
487        match (left.get_value(i), right.get_value(i)) {
488            (TypedValue::Int64(a), TypedValue::Int64(b)) => {
489                result.push(TypedValue::Int64(apply_i64_op(a, b, op)));
490            }
491            _ => result.push(TypedValue::Null),
492        }
493    }
494    ValueVector::Owned(result)
495}
496
497#[inline]
498fn apply_i64_op(a: i64, b: i64, op: BinaryOp) -> i64 {
499    match op {
500        BinaryOp::Add => a.wrapping_add(b),
501        BinaryOp::Sub => a.wrapping_sub(b),
502        BinaryOp::Mul => a.wrapping_mul(b),
503        BinaryOp::Div => {
504            if b == 0 {
505                0
506            } else {
507                a / b
508            }
509        }
510        BinaryOp::Mod => {
511            if b == 0 {
512                0
513            } else {
514                a % b
515            }
516        }
517        _ => 0,
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_storage::ColumnChunkData;
    use kyu_types::LogicalType;

    /// Build a single Int64 column holding `values`, with every row selected.
    fn make_i64_chunk(values: &[i64]) -> DataChunk {
        let len = values.len();
        let mut col = ColumnChunkData::new(LogicalType::Int64, len as u64);
        for &v in values {
            col.append_value::<i64>(v);
        }
        let flat = crate::value_vector::FlatVector::from_column_chunk(&col, len);
        DataChunk::from_vectors(vec![ValueVector::Flat(flat)], SelectionVector::identity(len))
    }

    /// Reference to Int64 column 0.
    fn var0() -> BoundExpression {
        BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
    }

    /// Int64 literal.
    fn int_lit(value: i64) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Int64(value),
            result_type: LogicalType::Int64,
        }
    }

    /// `left <op> right` comparison node.
    fn compare(op: ComparisonOp, left: BoundExpression, right: BoundExpression) -> BoundExpression {
        BoundExpression::Comparison {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// `left <op> right` binary-operator node with the given result type.
    fn binary(
        op: BinaryOp,
        left: BoundExpression,
        right: BoundExpression,
        result_type: LogicalType,
    ) -> BoundExpression {
        BoundExpression::BinaryOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
            result_type,
        }
    }

    #[test]
    fn batch_filter_lt() {
        let chunk = make_i64_chunk(&[10, 20, 5, 30, 3]);
        let pred = compare(ComparisonOp::Lt, var0(), int_lit(15));
        let sel = evaluate_filter_batch(&pred, &chunk).unwrap().unwrap();
        assert_eq!(sel.len(), 3); // 10, 5, 3
    }

    #[test]
    fn batch_filter_or() {
        let chunk = make_i64_chunk(&[1, 50, 3, 100, 2]);
        let small = compare(ComparisonOp::Lt, var0(), int_lit(3));
        let large = compare(ComparisonOp::Gt, var0(), int_lit(90));
        let pred = binary(BinaryOp::Or, small, large, LogicalType::Bool);
        let sel = evaluate_filter_batch(&pred, &chunk).unwrap().unwrap();
        assert_eq!(sel.len(), 3); // 1, 100, 2
    }

    #[test]
    fn batch_column_variable() {
        let chunk = make_i64_chunk(&[10, 20, 30]);
        let col = evaluate_column(&var0(), &chunk).unwrap().unwrap();
        assert_eq!(col.get_value(0), TypedValue::Int64(10));
        assert_eq!(col.get_value(2), TypedValue::Int64(30));
    }

    #[test]
    fn batch_column_mul() {
        let chunk = make_i64_chunk(&[5, 10, 15]);
        let expr = binary(BinaryOp::Mul, var0(), int_lit(2), LogicalType::Int64);
        let col = evaluate_column(&expr, &chunk).unwrap().unwrap();
        assert_eq!(col.get_value(0), TypedValue::Int64(10));
        assert_eq!(col.get_value(1), TypedValue::Int64(20));
        assert_eq!(col.get_value(2), TypedValue::Int64(30));
    }

    #[test]
    fn batch_column_nested_mul() {
        // c.length * 2 * 2 * 2 * 2  (benchmark query pattern)
        let chunk = make_i64_chunk(&[1, 3]);
        let times2 = |e| binary(BinaryOp::Mul, e, int_lit(2), LogicalType::Int64);
        let expr = times2(times2(times2(times2(var0()))));
        let col = evaluate_column(&expr, &chunk).unwrap().unwrap();
        assert_eq!(col.get_value(0), TypedValue::Int64(16));
        assert_eq!(col.get_value(1), TypedValue::Int64(48));
    }
}
648}