llkv_compute/
program.rs

1//! Bytecode-style programs for predicate evaluation and domain analysis.
2
3use std::ops::Bound;
4use std::sync::Arc;
5
6use llkv_expr::{CompareOp, Expr, Filter, Operator, ScalarExpr, literal::Literal};
7use llkv_result::{Error, Result as LlkvResult};
8use llkv_types::FieldId;
9use rustc_hash::{FxHashMap, FxHashSet};
10
11#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
12struct ExprKey(*const ());
13
14impl ExprKey {
15    #[inline]
16    fn new(expr: &Expr<'_, FieldId>) -> Self {
17        Self(expr as *const _ as *const ())
18    }
19}
20
21#[derive(Debug)]
22pub struct ProgramSet<'expr> {
23    eval: EvalProgram,
24    domains: DomainRegistry,
25    _root_expr: Arc<Expr<'expr, FieldId>>,
26}
27
28impl<'expr> ProgramSet<'expr> {
29    pub fn eval_ops(&self) -> &[EvalOp] {
30        &self.eval.ops
31    }
32
33    pub fn domain(&self, id: DomainProgramId) -> Option<&DomainProgram> {
34        self.domains.domain(id)
35    }
36
37    pub fn root_domain(&self) -> Option<DomainProgramId> {
38        self.domains.root
39    }
40}
41
/// Flat evaluation program produced by `compile_eval`: ops listed in post-order
/// (every operand's ops come before the op that combines them).
#[derive(Debug, Default)]
struct EvalProgram {
    ops: Vec<EvalOp>,
}

/// One instruction of a compiled evaluation program.
///
/// Ops are emitted in post-order, so each combinator (`And`, `Or`, `Not`)
/// follows the ops of its operands. NOTE(review): the `Push*` names and the
/// `child_count` fields suggest a stack machine where a combinator consumes
/// the `child_count` most recent results. Confirm against the evaluator.
#[derive(Debug)]
pub enum EvalOp {
    /// Single-field predicate copied from an `Expr::Pred` leaf.
    PushPredicate(OwnedFilter),
    /// Comparison of two scalar expressions, copied from `Expr::Compare`.
    PushCompare {
        left: ScalarExpr<FieldId>,
        right: ScalarExpr<FieldId>,
        op: CompareOp,
    },
    /// List-membership test copied from `Expr::InList`. `negated` is taken
    /// verbatim from the source node (presumably `NOT IN`).
    PushInList {
        expr: ScalarExpr<FieldId>,
        list: Vec<ScalarExpr<FieldId>>,
        negated: bool,
    },
    /// Null test copied from `Expr::IsNull`. `negated` is taken verbatim from
    /// the source node (presumably `IS NOT NULL`).
    PushIsNull {
        expr: ScalarExpr<FieldId>,
        negated: bool,
    },
    /// Constant boolean from `Expr::Literal`.
    PushLiteral(bool),
    /// An `AND` whose children were all predicates on `field_id`. It is kept as
    /// a single op holding every child filter in their original order.
    FusedAnd {
        field_id: FieldId,
        filters: Vec<OwnedFilter>,
    },
    /// Conjunction of the preceding `child_count` operands.
    And {
        child_count: usize,
    },
    /// Disjunction of the preceding `child_count` operands.
    Or {
        child_count: usize,
    },
    /// Negation of the preceding operand. `domain` is the id of that operand's
    /// domain program in the owning `ProgramSet`.
    Not {
        domain: DomainProgramId,
    },
}
79
/// Owned counterpart of [`Filter`]: a single-field predicate that does not
/// borrow from the expression it was built from.
#[derive(Debug, Clone)]
pub struct OwnedFilter {
    /// Field (column) the operator is applied to.
    pub field_id: FieldId,
    /// Comparison applied to the field's values.
    pub op: OwnedOperator,
}

/// Owned counterpart of [`Operator`].
///
/// The variants mirror `Operator` one-to-one. The only structural difference is
/// `In`: `Operator` borrows its literal list, while this type stores an owned
/// `Vec`.
#[derive(Debug, Clone)]
pub enum OwnedOperator {
    Equals(Literal),
    /// Range with independently included/excluded/unbounded ends.
    Range {
        lower: Bound<Literal>,
        upper: Bound<Literal>,
    },
    GreaterThan(Literal),
    GreaterThanOrEquals(Literal),
    LessThan(Literal),
    LessThanOrEquals(Literal),
    In(Vec<Literal>),
    StartsWith {
        pattern: String,
        case_sensitive: bool,
    },
    EndsWith {
        pattern: String,
        case_sensitive: bool,
    },
    Contains {
        pattern: String,
        case_sensitive: bool,
    },
    IsNull,
    IsNotNull,
}
113
114impl OwnedOperator {
115    pub fn to_operator(&self) -> Operator<'_> {
116        match self {
117            Self::Equals(lit) => Operator::Equals(lit.clone()),
118            Self::Range { lower, upper } => Operator::Range {
119                lower: lower.clone(),
120                upper: upper.clone(),
121            },
122            Self::GreaterThan(lit) => Operator::GreaterThan(lit.clone()),
123            Self::GreaterThanOrEquals(lit) => Operator::GreaterThanOrEquals(lit.clone()),
124            Self::LessThan(lit) => Operator::LessThan(lit.clone()),
125            Self::LessThanOrEquals(lit) => Operator::LessThanOrEquals(lit.clone()),
126            Self::In(values) => Operator::In(values.as_slice()),
127            Self::StartsWith {
128                pattern,
129                case_sensitive,
130            } => Operator::StartsWith {
131                pattern: pattern.clone(),
132                case_sensitive: *case_sensitive,
133            },
134            Self::EndsWith {
135                pattern,
136                case_sensitive,
137            } => Operator::EndsWith {
138                pattern: pattern.clone(),
139                case_sensitive: *case_sensitive,
140            },
141            Self::Contains {
142                pattern,
143                case_sensitive,
144            } => Operator::Contains {
145                pattern: pattern.clone(),
146                case_sensitive: *case_sensitive,
147            },
148            Self::IsNull => Operator::IsNull,
149            Self::IsNotNull => Operator::IsNotNull,
150        }
151    }
152}
153
154impl<'a> From<&'a Operator<'a>> for OwnedOperator {
155    fn from(op: &'a Operator<'a>) -> Self {
156        match op {
157            Operator::Equals(lit) => Self::Equals(lit.clone()),
158            Operator::Range { lower, upper } => Self::Range {
159                lower: lower.clone(),
160                upper: upper.clone(),
161            },
162            Operator::GreaterThan(lit) => Self::GreaterThan(lit.clone()),
163            Operator::GreaterThanOrEquals(lit) => Self::GreaterThanOrEquals(lit.clone()),
164            Operator::LessThan(lit) => Self::LessThan(lit.clone()),
165            Operator::LessThanOrEquals(lit) => Self::LessThanOrEquals(lit.clone()),
166            Operator::In(values) => Self::In((*values).to_vec()),
167            Operator::StartsWith {
168                pattern,
169                case_sensitive,
170            } => Self::StartsWith {
171                pattern: (*pattern).to_string(),
172                case_sensitive: *case_sensitive,
173            },
174            Operator::EndsWith {
175                pattern,
176                case_sensitive,
177            } => Self::EndsWith {
178                pattern: (*pattern).to_string(),
179                case_sensitive: *case_sensitive,
180            },
181            Operator::Contains {
182                pattern,
183                case_sensitive,
184            } => Self::Contains {
185                pattern: (*pattern).to_string(),
186                case_sensitive: *case_sensitive,
187            },
188            Operator::IsNull => Self::IsNull,
189            Operator::IsNotNull => Self::IsNotNull,
190        }
191    }
192}
193
194impl<'a> From<&'a Filter<'a, FieldId>> for OwnedFilter {
195    fn from(filter: &'a Filter<'a, FieldId>) -> Self {
196        Self {
197            field_id: filter.field_id,
198            op: OwnedOperator::from(&filter.op),
199        }
200    }
201}
202
203pub type DomainProgramId = u32;
204
205#[derive(Debug, Default)]
206pub struct DomainProgram {
207    ops: Vec<DomainOp>,
208}
209
210impl DomainProgram {
211    pub fn ops(&self) -> &[DomainOp] {
212        &self.ops
213    }
214}
215
/// One instruction of a compiled [`DomainProgram`], emitted in post-order
/// (children before the op combining them).
///
/// Each `fields` list is the set of columns referenced by the node's scalar
/// operands, sorted and without duplicates.
#[derive(Debug)]
pub enum DomainOp {
    /// Domain of an `Expr::Pred` leaf on the given field. NOTE(review):
    /// presumably every row stored for that field; confirm in the evaluator.
    PushFieldAll(FieldId),
    /// Domain of an `Expr::Compare`; `fields` covers both sides.
    PushCompareDomain {
        left: ScalarExpr<FieldId>,
        right: ScalarExpr<FieldId>,
        op: CompareOp,
        fields: Vec<FieldId>,
    },
    /// Domain of an `Expr::InList`; `fields` covers `expr` and every list item.
    PushInListDomain {
        expr: ScalarExpr<FieldId>,
        list: Vec<ScalarExpr<FieldId>>,
        fields: Vec<FieldId>,
        negated: bool,
    },
    /// Domain of an `Expr::IsNull`; `fields` covers `expr`.
    PushIsNullDomain {
        expr: ScalarExpr<FieldId>,
        fields: Vec<FieldId>,
        negated: bool,
    },
    /// Every row (emitted e.g. for boolean literals).
    PushAllRows,
    /// Union of the preceding `child_count` domains (emitted for `OR`).
    Union {
        child_count: usize,
    },
    /// Intersection of the preceding `child_count` domains (emitted for `AND`).
    Intersect {
        child_count: usize,
    },
}
244
245#[derive(Debug, Default)]
246struct DomainRegistry {
247    programs: Vec<DomainProgram>,
248    index: FxHashMap<ExprKey, DomainProgramId>,
249    root: Option<DomainProgramId>,
250}
251
252impl DomainRegistry {
253    fn domain(&self, id: DomainProgramId) -> Option<&DomainProgram> {
254        self.programs.get(id as usize)
255    }
256
257    fn ensure(&mut self, expr: &Expr<'_, FieldId>) -> DomainProgramId {
258        let key = ExprKey::new(expr);
259        if let Some(existing) = self.index.get(&key) {
260            return *existing;
261        }
262        let id = self.programs.len() as DomainProgramId;
263        let program = compile_domain(expr);
264        self.programs.push(program);
265        self.index.insert(key, id);
266        id
267    }
268}
269
270#[derive(Debug)]
271pub struct ProgramCompiler<'expr> {
272    root: Arc<Expr<'expr, FieldId>>,
273}
274
275impl<'expr> ProgramCompiler<'expr> {
276    pub fn new(root: Arc<Expr<'expr, FieldId>>) -> Self {
277        Self { root }
278    }
279
280    pub fn compile(self) -> LlkvResult<ProgramSet<'expr>> {
281        let ProgramCompiler { root } = self;
282        let mut domains = DomainRegistry::default();
283        let eval_ops = {
284            let root_ref = root.as_ref();
285            compile_eval(root_ref, &mut domains)?
286        };
287        let root_domain = {
288            let root_ref = root.as_ref();
289            domains.ensure(root_ref)
290        };
291        domains.root = Some(root_domain);
292        Ok(ProgramSet {
293            eval: EvalProgram { ops: eval_ops },
294            domains,
295            _root_expr: root,
296        })
297    }
298}
299
300pub fn normalize_predicate<'expr>(expr: Expr<'expr, FieldId>) -> Expr<'expr, FieldId> {
301    llkv_expr::normalization::normalize_predicate(expr)
302}
303
304#[derive(Clone, Copy)]
305enum EvalVisit<'expr> {
306    Enter(&'expr Expr<'expr, FieldId>),
307    Exit(&'expr Expr<'expr, FieldId>),
308    EmitFused { key: ExprKey, field_id: FieldId },
309}
310
311type PredicateVec = Vec<OwnedFilter>;
312
313fn compile_eval<'expr>(
314    expr: &'expr Expr<'expr, FieldId>,
315    domains: &mut DomainRegistry,
316) -> LlkvResult<Vec<EvalOp>> {
317    let mut ops = Vec::new();
318    let mut stack = vec![EvalVisit::Enter(expr)];
319    let mut fused: FxHashMap<ExprKey, PredicateVec> = FxHashMap::default();
320
321    while let Some(frame) = stack.pop() {
322        match frame {
323            EvalVisit::Enter(node) => match node {
324                Expr::And(children) => {
325                    if children.is_empty() {
326                        return Err(Error::InvalidArgumentError(
327                            "AND expression requires at least one predicate".into(),
328                        ));
329                    }
330                    if let Some((field_id, filters)) = gather_fused(children) {
331                        let key = ExprKey::new(node);
332                        fused.insert(key, filters);
333                        stack.push(EvalVisit::EmitFused { key, field_id });
334                    } else {
335                        stack.push(EvalVisit::Exit(node));
336                        for child in children.iter().rev() {
337                            stack.push(EvalVisit::Enter(child));
338                        }
339                    }
340                }
341                Expr::Or(children) => {
342                    if children.is_empty() {
343                        return Err(Error::InvalidArgumentError(
344                            "OR expression requires at least one predicate".into(),
345                        ));
346                    }
347                    stack.push(EvalVisit::Exit(node));
348                    for child in children.iter().rev() {
349                        stack.push(EvalVisit::Enter(child));
350                    }
351                }
352                Expr::Not(inner) => {
353                    stack.push(EvalVisit::Exit(node));
354                    stack.push(EvalVisit::Enter(inner));
355                }
356                _ => stack.push(EvalVisit::Exit(node)),
357            },
358            EvalVisit::Exit(node) => match node {
359                Expr::Pred(filter) => {
360                    ops.push(EvalOp::PushPredicate(OwnedFilter::from(filter)));
361                }
362                Expr::Compare { left, op, right } => {
363                    ops.push(EvalOp::PushCompare {
364                        left: left.clone(),
365                        right: right.clone(),
366                        op: *op,
367                    });
368                }
369                Expr::InList {
370                    expr,
371                    list,
372                    negated,
373                } => {
374                    ops.push(EvalOp::PushInList {
375                        expr: expr.clone(),
376                        list: list.clone(),
377                        negated: *negated,
378                    });
379                }
380                Expr::IsNull { expr, negated } => {
381                    ops.push(EvalOp::PushIsNull {
382                        expr: expr.clone(),
383                        negated: *negated,
384                    });
385                }
386                Expr::Literal(value) => ops.push(EvalOp::PushLiteral(*value)),
387                Expr::And(children) => ops.push(EvalOp::And {
388                    child_count: children.len(),
389                }),
390                Expr::Or(children) => ops.push(EvalOp::Or {
391                    child_count: children.len(),
392                }),
393                Expr::Not(inner) => {
394                    let id = domains.ensure(inner);
395                    ops.push(EvalOp::Not { domain: id });
396                }
397                Expr::Exists(_) => {
398                    return Err(Error::InvalidArgumentError(
399                        "EXISTS predicates are not supported in storage evaluation".into(),
400                    ));
401                }
402            },
403            EvalVisit::EmitFused { key, field_id } => {
404                let filters = fused
405                    .remove(&key)
406                    .ok_or_else(|| Error::Internal("missing fused predicate metadata".into()))?;
407                ops.push(EvalOp::FusedAnd { field_id, filters });
408            }
409        }
410    }
411
412    Ok(ops)
413}
414
415fn gather_fused<'expr>(
416    children: &'expr [Expr<'expr, FieldId>],
417) -> Option<(FieldId, Vec<OwnedFilter>)> {
418    if children.is_empty() {
419        return None;
420    }
421    let mut field: Option<FieldId> = None;
422    let mut out: Vec<OwnedFilter> = Vec::with_capacity(children.len());
423    for child in children {
424        match child {
425            Expr::Pred(filter) => {
426                if let Some(expected) = field {
427                    if expected != filter.field_id {
428                        return None;
429                    }
430                } else {
431                    field = Some(filter.field_id);
432                }
433                out.push(OwnedFilter::from(filter));
434            }
435            _ => return None,
436        }
437    }
438    field.map(|fid| (fid, out))
439}
440
441#[derive(Clone, Copy)]
442enum DomainVisit<'expr> {
443    Enter(&'expr Expr<'expr, FieldId>),
444    Exit(&'expr Expr<'expr, FieldId>),
445}
446
447fn compile_domain(expr: &Expr<'_, FieldId>) -> DomainProgram {
448    let mut ops = Vec::new();
449    let mut stack = vec![DomainVisit::Enter(expr)];
450
451    while let Some(frame) = stack.pop() {
452        match frame {
453            DomainVisit::Enter(node) => match node {
454                Expr::And(children) | Expr::Or(children) => {
455                    stack.push(DomainVisit::Exit(node));
456                    for child in children.iter().rev() {
457                        stack.push(DomainVisit::Enter(child));
458                    }
459                }
460                Expr::Not(inner) => {
461                    stack.push(DomainVisit::Exit(node));
462                    stack.push(DomainVisit::Enter(inner));
463                }
464                _ => stack.push(DomainVisit::Exit(node)),
465            },
466            DomainVisit::Exit(node) => match node {
467                Expr::Pred(filter) => ops.push(DomainOp::PushFieldAll(filter.field_id)),
468                Expr::Compare { left, op, right } => ops.push(DomainOp::PushCompareDomain {
469                    left: left.clone(),
470                    right: right.clone(),
471                    op: *op,
472                    fields: collect_fields([left, right]),
473                }),
474                Expr::InList {
475                    expr,
476                    list,
477                    negated,
478                } => {
479                    let mut exprs: Vec<&ScalarExpr<FieldId>> = Vec::with_capacity(list.len() + 1);
480                    exprs.push(expr);
481                    exprs.extend(list.iter());
482                    ops.push(DomainOp::PushInListDomain {
483                        expr: expr.clone(),
484                        list: list.clone(),
485                        fields: collect_fields(exprs),
486                        negated: *negated,
487                    });
488                }
489                Expr::IsNull { expr, negated } => {
490                    ops.push(DomainOp::PushIsNullDomain {
491                        expr: expr.clone(),
492                        fields: collect_fields([expr]),
493                        negated: *negated,
494                    });
495                }
496                Expr::Literal(_) => {
497                    // Boolean literals are defined for every row regardless of value.
498                    ops.push(DomainOp::PushAllRows);
499                }
500                Expr::And(children) => {
501                    if children.len() > 1 {
502                        ops.push(DomainOp::Intersect {
503                            child_count: children.len(),
504                        });
505                    }
506                }
507                Expr::Or(children) => {
508                    if children.len() > 1 {
509                        ops.push(DomainOp::Union {
510                            child_count: children.len(),
511                        });
512                    }
513                }
514                Expr::Not(_) => {
515                    // Domain equals child domain; no-op.
516                }
517                Expr::Exists(_) => {
518                    panic!("EXISTS predicates should not reach storage domain evaluation stage");
519                }
520            },
521        }
522    }
523
524    DomainProgram { ops }
525}
526
527fn collect_fields<'expr>(
528    exprs: impl IntoIterator<Item = &'expr ScalarExpr<FieldId>>,
529) -> Vec<FieldId> {
530    let mut seen: FxHashSet<FieldId> = FxHashSet::default();
531    let mut stack: Vec<&'expr ScalarExpr<FieldId>> = exprs.into_iter().collect();
532    while let Some(expr) = stack.pop() {
533        match expr {
534            ScalarExpr::Column(fid) => {
535                seen.insert(*fid);
536            }
537            ScalarExpr::Literal(_) => {}
538            ScalarExpr::Binary { left, right, .. } => {
539                stack.push(left);
540                stack.push(right);
541            }
542            ScalarExpr::Compare { left, right, .. } => {
543                stack.push(left);
544                stack.push(right);
545            }
546            ScalarExpr::Aggregate(agg) => match agg {
547                llkv_expr::expr::AggregateCall::CountStar => {}
548                llkv_expr::expr::AggregateCall::Count { expr, .. }
549                | llkv_expr::expr::AggregateCall::Sum { expr, .. }
550                | llkv_expr::expr::AggregateCall::Total { expr, .. }
551                | llkv_expr::expr::AggregateCall::Avg { expr, .. }
552                | llkv_expr::expr::AggregateCall::Min(expr)
553                | llkv_expr::expr::AggregateCall::Max(expr)
554                | llkv_expr::expr::AggregateCall::CountNulls(expr)
555                | llkv_expr::expr::AggregateCall::GroupConcat { expr, .. } => {
556                    stack.push(expr.as_ref());
557                }
558            },
559            ScalarExpr::GetField { base, .. } => {
560                stack.push(base);
561            }
562            ScalarExpr::Cast { expr, .. } => {
563                stack.push(expr.as_ref());
564            }
565            ScalarExpr::Not(expr) => {
566                stack.push(expr.as_ref());
567            }
568            ScalarExpr::IsNull { expr, .. } => {
569                stack.push(expr.as_ref());
570            }
571            ScalarExpr::Case {
572                operand,
573                branches,
574                else_expr,
575            } => {
576                if let Some(inner) = operand.as_deref() {
577                    stack.push(inner);
578                }
579                for (when_expr, then_expr) in branches {
580                    stack.push(when_expr);
581                    stack.push(then_expr);
582                }
583                if let Some(inner) = else_expr.as_deref() {
584                    stack.push(inner);
585                }
586            }
587            ScalarExpr::Coalesce(items) => {
588                for item in items {
589                    stack.push(item);
590                }
591            }
592            ScalarExpr::Random => {
593                // Random does not reference any fields
594            }
595            ScalarExpr::ScalarSubquery(_) => {
596                // Scalar subqueries are resolved separately at planning time
597            }
598        }
599    }
600    let mut fields: Vec<FieldId> = seen.into_iter().collect();
601    fields.sort_unstable();
602    fields
603}