Skip to main content

kyu_executor/
batch_eval.rs

1//! Batch expression evaluation — operates on columns/DataChunks instead of
2//! per-row TypedValue evaluation. Falls back to `None` for unsupported patterns
3//! so callers can use the scalar `evaluate()` path.
4
5use kyu_common::KyuResult;
6use kyu_expression::BoundExpression;
7use kyu_parser::ast::{BinaryOp, ComparisonOp};
8use kyu_types::{LogicalType, TypedValue};
9use smol_str::SmolStr;
10
11use crate::data_chunk::DataChunk;
12use crate::value_vector::{SelectionVector, ValueVector};
13
14// ---------------------------------------------------------------------------
15// Filter batch evaluation
16// ---------------------------------------------------------------------------
17
18/// Try to evaluate a predicate on a DataChunk in batch mode.
19/// Returns `Some(Ok(SelectionVector))` if handled, `None` to fall back to scalar.
20pub fn evaluate_filter_batch(
21    predicate: &BoundExpression,
22    chunk: &DataChunk,
23) -> Option<KyuResult<SelectionVector>> {
24    match predicate {
25        // Pattern: variable <cmp> literal (Int64)
26        BoundExpression::Comparison { op, left, right } => {
27            match (left.as_ref(), right.as_ref()) {
28                (
29                    BoundExpression::Variable {
30                        index,
31                        result_type: LogicalType::Int64 | LogicalType::Serial,
32                    },
33                    BoundExpression::Literal { value, .. },
34                ) => {
35                    if let TypedValue::Int64(lit) = value {
36                        Some(Ok(filter_cmp_var_lit_i64(
37                            chunk,
38                            *index as usize,
39                            *op,
40                            *lit,
41                        )))
42                    } else {
43                        None
44                    }
45                }
46                // Pattern: literal <cmp> variable — flip the comparison
47                (
48                    BoundExpression::Literal { value, .. },
49                    BoundExpression::Variable {
50                        index,
51                        result_type: LogicalType::Int64 | LogicalType::Serial,
52                    },
53                ) => {
54                    if let TypedValue::Int64(lit) = value {
55                        Some(Ok(filter_cmp_var_lit_i64(
56                            chunk,
57                            *index as usize,
58                            flip_cmp(*op),
59                            *lit,
60                        )))
61                    } else {
62                        None
63                    }
64                }
65                // Pattern: variable <cmp> literal (String)
66                (
67                    BoundExpression::Variable { index, result_type },
68                    BoundExpression::Literal { value, .. },
69                ) if *result_type == LogicalType::String => {
70                    if let TypedValue::String(lit) = value {
71                        Some(Ok(filter_cmp_var_lit_string(
72                            chunk,
73                            *index as usize,
74                            *op,
75                            lit,
76                        )))
77                    } else {
78                        None
79                    }
80                }
81                _ => None,
82            }
83        }
84        // Pattern: expr OR expr
85        BoundExpression::BinaryOp {
86            op: BinaryOp::Or,
87            left,
88            right,
89            ..
90        } => {
91            let left_sel = evaluate_filter_batch(left, chunk)?;
92            let right_sel = evaluate_filter_batch(right, chunk)?;
93            Some(match (left_sel, right_sel) {
94                (Ok(l), Ok(r)) => Ok(union_selections(&l, &r)),
95                (Err(e), _) | (_, Err(e)) => Err(e),
96            })
97        }
98        // Pattern: expr AND expr
99        BoundExpression::BinaryOp {
100            op: BinaryOp::And,
101            left,
102            right,
103            ..
104        } => {
105            let left_sel = evaluate_filter_batch(left, chunk)?;
106            let right_sel = evaluate_filter_batch(right, chunk)?;
107            Some(match (left_sel, right_sel) {
108                (Ok(l), Ok(r)) => Ok(intersect_selections(&l, &r)),
109                (Err(e), _) | (_, Err(e)) => Err(e),
110            })
111        }
112        // Pattern: variable STARTS WITH literal
113        BoundExpression::StringOp {
114            op: kyu_parser::ast::StringOp::StartsWith,
115            left,
116            right,
117        } => {
118            if let (
119                BoundExpression::Variable { index, .. },
120                BoundExpression::Literal {
121                    value: TypedValue::String(prefix),
122                    ..
123                },
124            ) = (left.as_ref(), right.as_ref())
125            {
126                Some(Ok(filter_starts_with(chunk, *index as usize, prefix)))
127            } else {
128                None
129            }
130        }
131        _ => None,
132    }
133}
134
135// ---------------------------------------------------------------------------
136// Column expression evaluation
137// ---------------------------------------------------------------------------
138
139/// Try to evaluate an expression on a DataChunk, producing an output ValueVector.
140/// Returns `None` for unsupported patterns (caller falls back to scalar).
141pub fn evaluate_column(
142    expr: &BoundExpression,
143    chunk: &DataChunk,
144) -> Option<KyuResult<ValueVector>> {
145    match expr {
146        // Variable reference — pass through the column
147        BoundExpression::Variable { index, .. } => Some(Ok(chunk.compact_column(*index as usize))),
148        // Literal — broadcast to N copies
149        BoundExpression::Literal { value, .. } => {
150            let n = chunk.num_rows();
151            Some(Ok(ValueVector::Owned(vec![value.clone(); n])))
152        }
153        // BinaryOp on column data
154        BoundExpression::BinaryOp {
155            op,
156            left,
157            right,
158            result_type,
159        } => eval_binop_column(op, left, right, result_type, chunk),
160        _ => None,
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Filter helpers
166// ---------------------------------------------------------------------------
167
168fn filter_cmp_var_lit_i64(
169    chunk: &DataChunk,
170    col_idx: usize,
171    op: ComparisonOp,
172    literal: i64,
173) -> SelectionVector {
174    let sel = chunk.selection();
175    let n = sel.len();
176    let col = chunk.column(col_idx);
177
178    // Fast path: FlatVector — direct i64 slice access
179    if let ValueVector::Flat(flat) = col
180        && matches!(
181            flat.logical_type(),
182            LogicalType::Int64 | LogicalType::Serial
183        )
184    {
185        let data = flat.data_as_i64_slice();
186        let mut selected = Vec::with_capacity(n);
187        for i in 0..n {
188            let phys = sel.get(i);
189            if !flat.is_null(phys) && cmp_i64(data[phys], literal, op) {
190                selected.push(phys as u32);
191            }
192        }
193        return SelectionVector::from_indices(selected);
194    }
195
196    // Fallback: Owned or other vector types
197    let mut selected = Vec::with_capacity(n);
198    for i in 0..n {
199        let phys = sel.get(i);
200        if let TypedValue::Int64(v) = col.get_value(phys)
201            && cmp_i64(v, literal, op)
202        {
203            selected.push(phys as u32);
204        }
205    }
206    SelectionVector::from_indices(selected)
207}
208
209#[inline]
210fn cmp_i64(val: i64, lit: i64, op: ComparisonOp) -> bool {
211    match op {
212        ComparisonOp::Lt => val < lit,
213        ComparisonOp::Le => val <= lit,
214        ComparisonOp::Gt => val > lit,
215        ComparisonOp::Ge => val >= lit,
216        ComparisonOp::Eq => val == lit,
217        ComparisonOp::Neq => val != lit,
218        ComparisonOp::RegexMatch => false,
219    }
220}
221
222fn flip_cmp(op: ComparisonOp) -> ComparisonOp {
223    match op {
224        ComparisonOp::Lt => ComparisonOp::Gt,
225        ComparisonOp::Le => ComparisonOp::Ge,
226        ComparisonOp::Gt => ComparisonOp::Lt,
227        ComparisonOp::Ge => ComparisonOp::Le,
228        ComparisonOp::Eq => ComparisonOp::Eq,
229        ComparisonOp::Neq => ComparisonOp::Neq,
230        ComparisonOp::RegexMatch => ComparisonOp::RegexMatch,
231    }
232}
233
234fn filter_cmp_var_lit_string(
235    chunk: &DataChunk,
236    col_idx: usize,
237    op: ComparisonOp,
238    literal: &SmolStr,
239) -> SelectionVector {
240    let sel = chunk.selection();
241    let n = sel.len();
242    let col = chunk.column(col_idx);
243
244    // Fast path: StringVector — direct data access
245    if let ValueVector::String(sv) = col {
246        let data = sv.data();
247        let mut selected = Vec::with_capacity(n);
248        for i in 0..n {
249            let phys = sel.get(i);
250            if let Some(ref s) = data[phys] {
251                let pass = match op {
252                    ComparisonOp::Eq => s == literal,
253                    ComparisonOp::Neq => s != literal,
254                    ComparisonOp::Lt => s < literal,
255                    ComparisonOp::Le => s <= literal,
256                    ComparisonOp::Gt => s > literal,
257                    ComparisonOp::Ge => s >= literal,
258                    ComparisonOp::RegexMatch => false,
259                };
260                if pass {
261                    selected.push(phys as u32);
262                }
263            }
264        }
265        return SelectionVector::from_indices(selected);
266    }
267
268    // Fallback
269    let mut selected = Vec::with_capacity(n);
270    for i in 0..n {
271        let phys = sel.get(i);
272        if let TypedValue::String(ref s) = col.get_value(phys) {
273            let pass = match op {
274                ComparisonOp::Eq => s == literal,
275                ComparisonOp::Neq => s != literal,
276                _ => false,
277            };
278            if pass {
279                selected.push(phys as u32);
280            }
281        }
282    }
283    SelectionVector::from_indices(selected)
284}
285
286fn filter_starts_with(chunk: &DataChunk, col_idx: usize, prefix: &SmolStr) -> SelectionVector {
287    let sel = chunk.selection();
288    let n = sel.len();
289    let col = chunk.column(col_idx);
290
291    if let ValueVector::String(sv) = col {
292        let data = sv.data();
293        let mut selected = Vec::with_capacity(n);
294        for i in 0..n {
295            let phys = sel.get(i);
296            if let Some(ref s) = data[phys]
297                && s.starts_with(prefix.as_str())
298            {
299                selected.push(phys as u32);
300            }
301        }
302        return SelectionVector::from_indices(selected);
303    }
304
305    // Fallback
306    let mut selected = Vec::with_capacity(n);
307    for i in 0..n {
308        let phys = sel.get(i);
309        if let TypedValue::String(ref s) = col.get_value(phys)
310            && s.starts_with(prefix.as_str())
311        {
312            selected.push(phys as u32);
313        }
314    }
315    SelectionVector::from_indices(selected)
316}
317
318/// Sorted merge union — O(n) since both inputs are already sorted.
319fn union_selections(a: &SelectionVector, b: &SelectionVector) -> SelectionVector {
320    let (na, nb) = (a.len(), b.len());
321    let mut result = Vec::with_capacity(na + nb);
322    let (mut i, mut j) = (0, 0);
323    while i < na && j < nb {
324        let va = a.get(i) as u32;
325        let vb = b.get(j) as u32;
326        match va.cmp(&vb) {
327            std::cmp::Ordering::Less => {
328                result.push(va);
329                i += 1;
330            }
331            std::cmp::Ordering::Greater => {
332                result.push(vb);
333                j += 1;
334            }
335            std::cmp::Ordering::Equal => {
336                result.push(va);
337                i += 1;
338                j += 1;
339            }
340        }
341    }
342    while i < na {
343        result.push(a.get(i) as u32);
344        i += 1;
345    }
346    while j < nb {
347        result.push(b.get(j) as u32);
348        j += 1;
349    }
350    SelectionVector::from_indices(result)
351}
352
353/// Sorted merge intersection — O(n) since both inputs are already sorted.
354fn intersect_selections(a: &SelectionVector, b: &SelectionVector) -> SelectionVector {
355    let (na, nb) = (a.len(), b.len());
356    let mut result = Vec::with_capacity(na.min(nb));
357    let (mut i, mut j) = (0, 0);
358    while i < na && j < nb {
359        let va = a.get(i) as u32;
360        let vb = b.get(j) as u32;
361        match va.cmp(&vb) {
362            std::cmp::Ordering::Less => i += 1,
363            std::cmp::Ordering::Greater => j += 1,
364            std::cmp::Ordering::Equal => {
365                result.push(va);
366                i += 1;
367                j += 1;
368            }
369        }
370    }
371    SelectionVector::from_indices(result)
372}
373
374// ---------------------------------------------------------------------------
375// Column expression helpers
376// ---------------------------------------------------------------------------
377
378fn eval_binop_column(
379    op: &BinaryOp,
380    left: &BoundExpression,
381    right: &BoundExpression,
382    _result_type: &LogicalType,
383    chunk: &DataChunk,
384) -> Option<KyuResult<ValueVector>> {
385    // Pattern: Variable{i64} <op> Literal{i64}
386    if let (
387        BoundExpression::Variable { index, result_type },
388        BoundExpression::Literal {
389            value: TypedValue::Int64(lit),
390            ..
391        },
392    ) = (left, right)
393        && matches!(result_type, LogicalType::Int64 | LogicalType::Serial)
394    {
395        return Some(Ok(binop_col_lit_i64(chunk, *index as usize, *op, *lit)));
396    }
397
398    // Pattern: Literal{i64} <op> Variable{i64}
399    if let (
400        BoundExpression::Literal {
401            value: TypedValue::Int64(lit),
402            ..
403        },
404        BoundExpression::Variable { index, result_type },
405    ) = (left, right)
406        && matches!(result_type, LogicalType::Int64 | LogicalType::Serial)
407    {
408        return Some(Ok(binop_lit_col_i64(*lit, chunk, *index as usize, *op)));
409    }
410
411    // Pattern: nested BinaryOp — evaluate children as columns, then combine
412    let left_col = evaluate_column(left, chunk)?;
413    let right_col = evaluate_column(right, chunk)?;
414    Some(match (left_col, right_col) {
415        (Ok(lv), Ok(rv)) => Ok(binop_vec_vec_i64(&lv, &rv, *op, chunk.num_rows())),
416        (Err(e), _) | (_, Err(e)) => Err(e),
417    })
418}
419
420fn binop_col_lit_i64(chunk: &DataChunk, col_idx: usize, op: BinaryOp, literal: i64) -> ValueVector {
421    let sel = chunk.selection();
422    let n = sel.len();
423    let col = chunk.column(col_idx);
424
425    if let ValueVector::Flat(flat) = col
426        && matches!(
427            flat.logical_type(),
428            LogicalType::Int64 | LogicalType::Serial
429        )
430    {
431        let data = flat.data_as_i64_slice();
432        let mut result = Vec::with_capacity(n);
433        for i in 0..n {
434            let phys = sel.get(i);
435            if flat.is_null(phys) {
436                result.push(TypedValue::Null);
437            } else {
438                result.push(TypedValue::Int64(apply_i64_op(data[phys], literal, op)));
439            }
440        }
441        return ValueVector::Owned(result);
442    }
443
444    // Fallback: Owned
445    let mut result = Vec::with_capacity(n);
446    for i in 0..n {
447        let phys = sel.get(i);
448        match col.get_value(phys) {
449            TypedValue::Int64(v) => {
450                result.push(TypedValue::Int64(apply_i64_op(v, literal, op)));
451            }
452            _ => result.push(TypedValue::Null),
453        }
454    }
455    ValueVector::Owned(result)
456}
457
458fn binop_lit_col_i64(literal: i64, chunk: &DataChunk, col_idx: usize, op: BinaryOp) -> ValueVector {
459    let sel = chunk.selection();
460    let n = sel.len();
461    let col = chunk.column(col_idx);
462    let mut result = Vec::with_capacity(n);
463    for i in 0..n {
464        let phys = sel.get(i);
465        match col.get_value(phys) {
466            TypedValue::Int64(v) => {
467                result.push(TypedValue::Int64(apply_i64_op(literal, v, op)));
468            }
469            _ => result.push(TypedValue::Null),
470        }
471    }
472    ValueVector::Owned(result)
473}
474
475fn binop_vec_vec_i64(
476    left: &ValueVector,
477    right: &ValueVector,
478    op: BinaryOp,
479    n: usize,
480) -> ValueVector {
481    let mut result = Vec::with_capacity(n);
482    for i in 0..n {
483        match (left.get_value(i), right.get_value(i)) {
484            (TypedValue::Int64(a), TypedValue::Int64(b)) => {
485                result.push(TypedValue::Int64(apply_i64_op(a, b, op)));
486            }
487            _ => result.push(TypedValue::Null),
488        }
489    }
490    ValueVector::Owned(result)
491}
492
493#[inline]
494fn apply_i64_op(a: i64, b: i64, op: BinaryOp) -> i64 {
495    match op {
496        BinaryOp::Add => a.wrapping_add(b),
497        BinaryOp::Sub => a.wrapping_sub(b),
498        BinaryOp::Mul => a.wrapping_mul(b),
499        BinaryOp::Div => {
500            if b == 0 {
501                0
502            } else {
503                a / b
504            }
505        }
506        BinaryOp::Mod => {
507            if b == 0 {
508                0
509            } else {
510                a % b
511            }
512        }
513        _ => 0,
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_storage::ColumnChunkData;
    use kyu_types::LogicalType;

    /// Build a one-column chunk holding `values` as a flat Int64 vector, with
    /// an identity selection over all rows.
    fn make_i64_chunk(values: &[i64]) -> DataChunk {
        let len = values.len();
        let mut column = ColumnChunkData::new(LogicalType::Int64, len as u64);
        for value in values.iter().copied() {
            column.append_value::<i64>(value);
        }
        let flat = crate::value_vector::FlatVector::from_column_chunk(&column, len);
        DataChunk::from_vectors(
            vec![ValueVector::Flat(flat)],
            SelectionVector::identity(len),
        )
    }

    /// Reference to column 0, typed as Int64.
    fn int_var() -> BoundExpression {
        BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
    }

    /// Int64 literal expression.
    fn int_lit(value: i64) -> BoundExpression {
        BoundExpression::Literal {
            value: TypedValue::Int64(value),
            result_type: LogicalType::Int64,
        }
    }

    /// `column0 <op> literal` comparison expression.
    fn var_cmp_lit(op: ComparisonOp, literal: i64) -> BoundExpression {
        BoundExpression::Comparison {
            op,
            left: Box::new(int_var()),
            right: Box::new(int_lit(literal)),
        }
    }

    /// `left * right` with an Int64 result type.
    fn mul(left: BoundExpression, right: BoundExpression) -> BoundExpression {
        BoundExpression::BinaryOp {
            op: BinaryOp::Mul,
            left: Box::new(left),
            right: Box::new(right),
            result_type: LogicalType::Int64,
        }
    }

    #[test]
    fn batch_filter_lt() {
        let chunk = make_i64_chunk(&[10, 20, 5, 30, 3]);
        let pred = var_cmp_lit(ComparisonOp::Lt, 15);
        let sel = evaluate_filter_batch(&pred, &chunk).unwrap().unwrap();
        assert_eq!(sel.len(), 3); // 10, 5, 3
    }

    #[test]
    fn batch_filter_or() {
        let chunk = make_i64_chunk(&[1, 50, 3, 100, 2]);
        let pred = BoundExpression::BinaryOp {
            op: BinaryOp::Or,
            left: Box::new(var_cmp_lit(ComparisonOp::Lt, 3)),
            right: Box::new(var_cmp_lit(ComparisonOp::Gt, 90)),
            result_type: LogicalType::Bool,
        };
        let sel = evaluate_filter_batch(&pred, &chunk).unwrap().unwrap();
        assert_eq!(sel.len(), 3); // 1, 100, 2
    }

    #[test]
    fn batch_column_variable() {
        let chunk = make_i64_chunk(&[10, 20, 30]);
        let col = evaluate_column(&int_var(), &chunk).unwrap().unwrap();
        assert_eq!(col.get_value(0), TypedValue::Int64(10));
        assert_eq!(col.get_value(2), TypedValue::Int64(30));
    }

    #[test]
    fn batch_column_mul() {
        let chunk = make_i64_chunk(&[5, 10, 15]);
        let expr = mul(int_var(), int_lit(2));
        let col = evaluate_column(&expr, &chunk).unwrap().unwrap();
        assert_eq!(col.get_value(0), TypedValue::Int64(10));
        assert_eq!(col.get_value(1), TypedValue::Int64(20));
        assert_eq!(col.get_value(2), TypedValue::Int64(30));
    }

    #[test]
    fn batch_column_nested_mul() {
        // c.length * 2 * 2 * 2 * 2  (benchmark query pattern)
        let chunk = make_i64_chunk(&[1, 3]);
        // Left-nested chain: (((var * 2) * 2) * 2) * 2
        let expr = (0..4).fold(int_var(), |acc, _| mul(acc, int_lit(2)));
        let col = evaluate_column(&expr, &chunk).unwrap().unwrap();
        assert_eq!(col.get_value(0), TypedValue::Int64(16));
        assert_eq!(col.get_value(1), TypedValue::Int64(48));
    }
}