1use std::ops::Bound;
4use std::sync::Arc;
5
6use llkv_expr::{CompareOp, Expr, Filter, Operator, ScalarExpr, literal::Literal};
7use llkv_result::{Error, Result as LlkvResult};
8use llkv_types::FieldId;
9use rustc_hash::{FxHashMap, FxHashSet};
10
11#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
12struct ExprKey(*const ());
13
14impl ExprKey {
15 #[inline]
16 fn new(expr: &Expr<'_, FieldId>) -> Self {
17 Self(expr as *const _ as *const ())
18 }
19}
20
/// Compiled form of a predicate: a post-order evaluation program plus the domain
/// programs that program refers to.
///
/// Built by `ProgramCompiler::compile`.
#[derive(Debug)]
pub struct ProgramSet<'expr> {
    /// Evaluation ops for the whole predicate, in post-order.
    eval: EvalProgram,
    /// Domain programs referenced by `EvalOp::Not` ops, plus the root's domain.
    domains: DomainRegistry,
    // Never read in this file; it holds the expression tree the programs were
    // compiled from. NOTE(review): presumably kept alive so the node addresses used
    // as `ExprKey`s inside `domains` cannot be reused while this set exists — confirm.
    _root_expr: Arc<Expr<'expr, FieldId>>,
}
27
28impl<'expr> ProgramSet<'expr> {
29 pub fn eval_ops(&self) -> &[EvalOp] {
30 &self.eval.ops
31 }
32
33 pub fn domain(&self, id: DomainProgramId) -> Option<&DomainProgram> {
34 self.domains.domain(id)
35 }
36
37 pub fn root_domain(&self) -> Option<DomainProgramId> {
38 self.domains.root
39 }
40}
41
/// Flat list of evaluation ops in the post-order emitted by `compile_eval`.
#[derive(Debug, Default)]
struct EvalProgram {
    ops: Vec<EvalOp>,
}
46
/// One instruction of an evaluation program.
///
/// `compile_eval` emits these in post-order: the ops of every `AND`/`OR`/`NOT`
/// operand come before the op of the node itself.
/// NOTE(review): the `Push*` naming suggests a stack machine in which `And`/`Or`
/// consume `child_count` results; the interpreter is not in this file — confirm.
#[derive(Debug)]
pub enum EvalOp {
    /// A single-field predicate (`Expr::Pred`) copied into owned form.
    PushPredicate(OwnedFilter),
    /// A comparison between two scalar expressions (`Expr::Compare`).
    PushCompare {
        left: ScalarExpr<FieldId>,
        right: ScalarExpr<FieldId>,
        op: CompareOp,
    },
    /// An `IN` list test (`Expr::InList`); `negated` is copied from the expression.
    PushInList {
        expr: ScalarExpr<FieldId>,
        list: Vec<ScalarExpr<FieldId>>,
        negated: bool,
    },
    /// A null test (`Expr::IsNull`); `negated` is copied from the expression.
    PushIsNull {
        expr: ScalarExpr<FieldId>,
        negated: bool,
    },
    /// A constant boolean (`Expr::Literal`).
    PushLiteral(bool),
    /// An `AND` whose children were all `Expr::Pred` on the same `field_id`,
    /// emitted as this one op instead of one op per child plus an `And`.
    FusedAnd {
        field_id: FieldId,
        filters: Vec<OwnedFilter>,
    },
    /// Combines the results of an `AND` node's `child_count` operands.
    And {
        child_count: usize,
    },
    /// Combines the results of an `OR` node's `child_count` operands.
    Or {
        child_count: usize,
    },
    /// Negates the preceding operand; `domain` is the id of that operand's domain
    /// program in the owning `ProgramSet`.
    Not {
        domain: DomainProgramId,
    },
}
79
/// Owned copy of a `Filter`: its operator no longer borrows from the source
/// expression, so it can live inside compiled programs.
#[derive(Debug, Clone)]
pub struct OwnedFilter {
    /// Field the predicate is applied to.
    pub field_id: FieldId,
    /// Operator applied to that field's values.
    pub op: OwnedOperator,
}
85
/// Owned mirror of `Operator`, one variant per `Operator` variant.
///
/// Where `Operator` borrows (the `In` value list), this type stores an owned
/// `Vec`; pattern text is stored as `String`.
#[derive(Debug, Clone)]
pub enum OwnedOperator {
    Equals(Literal),
    Range {
        lower: Bound<Literal>,
        upper: Bound<Literal>,
    },
    GreaterThan(Literal),
    GreaterThanOrEquals(Literal),
    LessThan(Literal),
    LessThanOrEquals(Literal),
    // Owned copy of the slice held by `Operator::In`.
    In(Vec<Literal>),
    StartsWith {
        pattern: String,
        case_sensitive: bool,
    },
    EndsWith {
        pattern: String,
        case_sensitive: bool,
    },
    Contains {
        pattern: String,
        case_sensitive: bool,
    },
    IsNull,
    IsNotNull,
}
113
114impl OwnedOperator {
115 pub fn to_operator(&self) -> Operator<'_> {
116 match self {
117 Self::Equals(lit) => Operator::Equals(lit.clone()),
118 Self::Range { lower, upper } => Operator::Range {
119 lower: lower.clone(),
120 upper: upper.clone(),
121 },
122 Self::GreaterThan(lit) => Operator::GreaterThan(lit.clone()),
123 Self::GreaterThanOrEquals(lit) => Operator::GreaterThanOrEquals(lit.clone()),
124 Self::LessThan(lit) => Operator::LessThan(lit.clone()),
125 Self::LessThanOrEquals(lit) => Operator::LessThanOrEquals(lit.clone()),
126 Self::In(values) => Operator::In(values.as_slice()),
127 Self::StartsWith {
128 pattern,
129 case_sensitive,
130 } => Operator::StartsWith {
131 pattern: pattern.clone(),
132 case_sensitive: *case_sensitive,
133 },
134 Self::EndsWith {
135 pattern,
136 case_sensitive,
137 } => Operator::EndsWith {
138 pattern: pattern.clone(),
139 case_sensitive: *case_sensitive,
140 },
141 Self::Contains {
142 pattern,
143 case_sensitive,
144 } => Operator::Contains {
145 pattern: pattern.clone(),
146 case_sensitive: *case_sensitive,
147 },
148 Self::IsNull => Operator::IsNull,
149 Self::IsNotNull => Operator::IsNotNull,
150 }
151 }
152}
153
154impl<'a> From<&'a Operator<'a>> for OwnedOperator {
155 fn from(op: &'a Operator<'a>) -> Self {
156 match op {
157 Operator::Equals(lit) => Self::Equals(lit.clone()),
158 Operator::Range { lower, upper } => Self::Range {
159 lower: lower.clone(),
160 upper: upper.clone(),
161 },
162 Operator::GreaterThan(lit) => Self::GreaterThan(lit.clone()),
163 Operator::GreaterThanOrEquals(lit) => Self::GreaterThanOrEquals(lit.clone()),
164 Operator::LessThan(lit) => Self::LessThan(lit.clone()),
165 Operator::LessThanOrEquals(lit) => Self::LessThanOrEquals(lit.clone()),
166 Operator::In(values) => Self::In((*values).to_vec()),
167 Operator::StartsWith {
168 pattern,
169 case_sensitive,
170 } => Self::StartsWith {
171 pattern: (*pattern).to_string(),
172 case_sensitive: *case_sensitive,
173 },
174 Operator::EndsWith {
175 pattern,
176 case_sensitive,
177 } => Self::EndsWith {
178 pattern: (*pattern).to_string(),
179 case_sensitive: *case_sensitive,
180 },
181 Operator::Contains {
182 pattern,
183 case_sensitive,
184 } => Self::Contains {
185 pattern: (*pattern).to_string(),
186 case_sensitive: *case_sensitive,
187 },
188 Operator::IsNull => Self::IsNull,
189 Operator::IsNotNull => Self::IsNotNull,
190 }
191 }
192}
193
194impl<'a> From<&'a Filter<'a, FieldId>> for OwnedFilter {
195 fn from(filter: &'a Filter<'a, FieldId>) -> Self {
196 Self {
197 field_id: filter.field_id,
198 op: OwnedOperator::from(&filter.op),
199 }
200 }
201}
202
/// Index of a `DomainProgram` inside a `DomainRegistry`; ids are assigned in
/// insertion order starting from 0.
pub type DomainProgramId = u32;
204
/// Post-order list of `DomainOp`s compiled from one expression by `compile_domain`.
///
/// NOTE(review): the op names (`PushAllRows`, `Union`, `Intersect`) suggest the
/// program computes a set of rows; its interpreter is not in this file — confirm.
#[derive(Debug, Default)]
pub struct DomainProgram {
    ops: Vec<DomainOp>,
}
209
210impl DomainProgram {
211 pub fn ops(&self) -> &[DomainOp] {
212 &self.ops
213 }
214}
215
/// One instruction of a domain program, as emitted by `compile_domain`.
#[derive(Debug)]
pub enum DomainOp {
    /// Emitted for `Expr::Pred`; only the predicate's field id is kept.
    PushFieldAll(FieldId),
    /// Emitted for `Expr::Compare`; `fields` holds the distinct field ids
    /// referenced by either side, sorted ascending.
    PushCompareDomain {
        left: ScalarExpr<FieldId>,
        right: ScalarExpr<FieldId>,
        op: CompareOp,
        fields: Vec<FieldId>,
    },
    /// Emitted for `Expr::InList`; `fields` covers `expr` and every list item.
    PushInListDomain {
        expr: ScalarExpr<FieldId>,
        list: Vec<ScalarExpr<FieldId>>,
        fields: Vec<FieldId>,
        negated: bool,
    },
    /// Emitted for `Expr::IsNull`; `fields` covers `expr`.
    PushIsNullDomain {
        expr: ScalarExpr<FieldId>,
        fields: Vec<FieldId>,
        negated: bool,
    },
    /// Emitted for `Expr::Literal`, whatever its value.
    PushAllRows,
    /// Emitted after the operands of an `OR` with more than one child
    /// (a single-child `OR` emits nothing of its own).
    Union {
        child_count: usize,
    },
    /// Emitted after the operands of an `AND` with more than one child
    /// (a single-child `AND` emits nothing of its own).
    Intersect {
        child_count: usize,
    },
}
244
/// Append-only store of domain programs, deduplicated by expression node identity.
#[derive(Debug, Default)]
struct DomainRegistry {
    /// Compiled programs; a program's position is its `DomainProgramId`.
    programs: Vec<DomainProgram>,
    /// Node address -> id of the program already compiled for that node.
    index: FxHashMap<ExprKey, DomainProgramId>,
    /// Domain of the whole predicate; set by `ProgramCompiler::compile`.
    root: Option<DomainProgramId>,
}
251
252impl DomainRegistry {
253 fn domain(&self, id: DomainProgramId) -> Option<&DomainProgram> {
254 self.programs.get(id as usize)
255 }
256
257 fn ensure(&mut self, expr: &Expr<'_, FieldId>) -> DomainProgramId {
258 let key = ExprKey::new(expr);
259 if let Some(existing) = self.index.get(&key) {
260 return *existing;
261 }
262 let id = self.programs.len() as DomainProgramId;
263 let program = compile_domain(expr);
264 self.programs.push(program);
265 self.index.insert(key, id);
266 id
267 }
268}
269
/// Turns a predicate expression into a [`ProgramSet`] via `compile`.
#[derive(Debug)]
pub struct ProgramCompiler<'expr> {
    /// Expression to compile; moved into the resulting `ProgramSet`.
    root: Arc<Expr<'expr, FieldId>>,
}
274
275impl<'expr> ProgramCompiler<'expr> {
276 pub fn new(root: Arc<Expr<'expr, FieldId>>) -> Self {
277 Self { root }
278 }
279
280 pub fn compile(self) -> LlkvResult<ProgramSet<'expr>> {
281 let ProgramCompiler { root } = self;
282 let mut domains = DomainRegistry::default();
283 let eval_ops = {
284 let root_ref = root.as_ref();
285 compile_eval(root_ref, &mut domains)?
286 };
287 let root_domain = {
288 let root_ref = root.as_ref();
289 domains.ensure(root_ref)
290 };
291 domains.root = Some(root_domain);
292 Ok(ProgramSet {
293 eval: EvalProgram { ops: eval_ops },
294 domains,
295 _root_expr: root,
296 })
297 }
298}
299
300pub fn normalize_predicate<'expr>(expr: Expr<'expr, FieldId>) -> Expr<'expr, FieldId> {
301 llkv_expr::normalization::normalize_predicate(expr)
302}
303
304#[derive(Clone, Copy)]
305enum EvalVisit<'expr> {
306 Enter(&'expr Expr<'expr, FieldId>),
307 Exit(&'expr Expr<'expr, FieldId>),
308 EmitFused { key: ExprKey, field_id: FieldId },
309}
310
311type PredicateVec = Vec<OwnedFilter>;
312
313fn compile_eval<'expr>(
314 expr: &'expr Expr<'expr, FieldId>,
315 domains: &mut DomainRegistry,
316) -> LlkvResult<Vec<EvalOp>> {
317 let mut ops = Vec::new();
318 let mut stack = vec![EvalVisit::Enter(expr)];
319 let mut fused: FxHashMap<ExprKey, PredicateVec> = FxHashMap::default();
320
321 while let Some(frame) = stack.pop() {
322 match frame {
323 EvalVisit::Enter(node) => match node {
324 Expr::And(children) => {
325 if children.is_empty() {
326 return Err(Error::InvalidArgumentError(
327 "AND expression requires at least one predicate".into(),
328 ));
329 }
330 if let Some((field_id, filters)) = gather_fused(children) {
331 let key = ExprKey::new(node);
332 fused.insert(key, filters);
333 stack.push(EvalVisit::EmitFused { key, field_id });
334 } else {
335 stack.push(EvalVisit::Exit(node));
336 for child in children.iter().rev() {
337 stack.push(EvalVisit::Enter(child));
338 }
339 }
340 }
341 Expr::Or(children) => {
342 if children.is_empty() {
343 return Err(Error::InvalidArgumentError(
344 "OR expression requires at least one predicate".into(),
345 ));
346 }
347 stack.push(EvalVisit::Exit(node));
348 for child in children.iter().rev() {
349 stack.push(EvalVisit::Enter(child));
350 }
351 }
352 Expr::Not(inner) => {
353 stack.push(EvalVisit::Exit(node));
354 stack.push(EvalVisit::Enter(inner));
355 }
356 _ => stack.push(EvalVisit::Exit(node)),
357 },
358 EvalVisit::Exit(node) => match node {
359 Expr::Pred(filter) => {
360 ops.push(EvalOp::PushPredicate(OwnedFilter::from(filter)));
361 }
362 Expr::Compare { left, op, right } => {
363 ops.push(EvalOp::PushCompare {
364 left: left.clone(),
365 right: right.clone(),
366 op: *op,
367 });
368 }
369 Expr::InList {
370 expr,
371 list,
372 negated,
373 } => {
374 ops.push(EvalOp::PushInList {
375 expr: expr.clone(),
376 list: list.clone(),
377 negated: *negated,
378 });
379 }
380 Expr::IsNull { expr, negated } => {
381 ops.push(EvalOp::PushIsNull {
382 expr: expr.clone(),
383 negated: *negated,
384 });
385 }
386 Expr::Literal(value) => ops.push(EvalOp::PushLiteral(*value)),
387 Expr::And(children) => ops.push(EvalOp::And {
388 child_count: children.len(),
389 }),
390 Expr::Or(children) => ops.push(EvalOp::Or {
391 child_count: children.len(),
392 }),
393 Expr::Not(inner) => {
394 let id = domains.ensure(inner);
395 ops.push(EvalOp::Not { domain: id });
396 }
397 Expr::Exists(_) => {
398 return Err(Error::InvalidArgumentError(
399 "EXISTS predicates are not supported in storage evaluation".into(),
400 ));
401 }
402 },
403 EvalVisit::EmitFused { key, field_id } => {
404 let filters = fused
405 .remove(&key)
406 .ok_or_else(|| Error::Internal("missing fused predicate metadata".into()))?;
407 ops.push(EvalOp::FusedAnd { field_id, filters });
408 }
409 }
410 }
411
412 Ok(ops)
413}
414
415fn gather_fused<'expr>(
416 children: &'expr [Expr<'expr, FieldId>],
417) -> Option<(FieldId, Vec<OwnedFilter>)> {
418 if children.is_empty() {
419 return None;
420 }
421 let mut field: Option<FieldId> = None;
422 let mut out: Vec<OwnedFilter> = Vec::with_capacity(children.len());
423 for child in children {
424 match child {
425 Expr::Pred(filter) => {
426 if let Some(expected) = field {
427 if expected != filter.field_id {
428 return None;
429 }
430 } else {
431 field = Some(filter.field_id);
432 }
433 out.push(OwnedFilter::from(filter));
434 }
435 _ => return None,
436 }
437 }
438 field.map(|fid| (fid, out))
439}
440
/// Frame of the explicit stack used by `compile_domain` to walk an expression tree
/// in post-order without recursion.
#[derive(Clone, Copy)]
enum DomainVisit<'expr> {
    /// First visit of a node: schedule its children (if any) and its exit frame.
    Enter(&'expr Expr<'expr, FieldId>),
    /// All children have emitted their ops; emit the node's own op, if any.
    Exit(&'expr Expr<'expr, FieldId>),
}
446
447fn compile_domain(expr: &Expr<'_, FieldId>) -> DomainProgram {
448 let mut ops = Vec::new();
449 let mut stack = vec![DomainVisit::Enter(expr)];
450
451 while let Some(frame) = stack.pop() {
452 match frame {
453 DomainVisit::Enter(node) => match node {
454 Expr::And(children) | Expr::Or(children) => {
455 stack.push(DomainVisit::Exit(node));
456 for child in children.iter().rev() {
457 stack.push(DomainVisit::Enter(child));
458 }
459 }
460 Expr::Not(inner) => {
461 stack.push(DomainVisit::Exit(node));
462 stack.push(DomainVisit::Enter(inner));
463 }
464 _ => stack.push(DomainVisit::Exit(node)),
465 },
466 DomainVisit::Exit(node) => match node {
467 Expr::Pred(filter) => ops.push(DomainOp::PushFieldAll(filter.field_id)),
468 Expr::Compare { left, op, right } => ops.push(DomainOp::PushCompareDomain {
469 left: left.clone(),
470 right: right.clone(),
471 op: *op,
472 fields: collect_fields([left, right]),
473 }),
474 Expr::InList {
475 expr,
476 list,
477 negated,
478 } => {
479 let mut exprs: Vec<&ScalarExpr<FieldId>> = Vec::with_capacity(list.len() + 1);
480 exprs.push(expr);
481 exprs.extend(list.iter());
482 ops.push(DomainOp::PushInListDomain {
483 expr: expr.clone(),
484 list: list.clone(),
485 fields: collect_fields(exprs),
486 negated: *negated,
487 });
488 }
489 Expr::IsNull { expr, negated } => {
490 ops.push(DomainOp::PushIsNullDomain {
491 expr: expr.clone(),
492 fields: collect_fields([expr]),
493 negated: *negated,
494 });
495 }
496 Expr::Literal(_) => {
497 ops.push(DomainOp::PushAllRows);
499 }
500 Expr::And(children) => {
501 if children.len() > 1 {
502 ops.push(DomainOp::Intersect {
503 child_count: children.len(),
504 });
505 }
506 }
507 Expr::Or(children) => {
508 if children.len() > 1 {
509 ops.push(DomainOp::Union {
510 child_count: children.len(),
511 });
512 }
513 }
514 Expr::Not(_) => {
515 }
517 Expr::Exists(_) => {
518 panic!("EXISTS predicates should not reach storage domain evaluation stage");
519 }
520 },
521 }
522 }
523
524 DomainProgram { ops }
525}
526
527fn collect_fields<'expr>(
528 exprs: impl IntoIterator<Item = &'expr ScalarExpr<FieldId>>,
529) -> Vec<FieldId> {
530 let mut seen: FxHashSet<FieldId> = FxHashSet::default();
531 let mut stack: Vec<&'expr ScalarExpr<FieldId>> = exprs.into_iter().collect();
532 while let Some(expr) = stack.pop() {
533 match expr {
534 ScalarExpr::Column(fid) => {
535 seen.insert(*fid);
536 }
537 ScalarExpr::Literal(_) => {}
538 ScalarExpr::Binary { left, right, .. } => {
539 stack.push(left);
540 stack.push(right);
541 }
542 ScalarExpr::Compare { left, right, .. } => {
543 stack.push(left);
544 stack.push(right);
545 }
546 ScalarExpr::Aggregate(agg) => match agg {
547 llkv_expr::expr::AggregateCall::CountStar => {}
548 llkv_expr::expr::AggregateCall::Count { expr, .. }
549 | llkv_expr::expr::AggregateCall::Sum { expr, .. }
550 | llkv_expr::expr::AggregateCall::Total { expr, .. }
551 | llkv_expr::expr::AggregateCall::Avg { expr, .. }
552 | llkv_expr::expr::AggregateCall::Min(expr)
553 | llkv_expr::expr::AggregateCall::Max(expr)
554 | llkv_expr::expr::AggregateCall::CountNulls(expr)
555 | llkv_expr::expr::AggregateCall::GroupConcat { expr, .. } => {
556 stack.push(expr.as_ref());
557 }
558 },
559 ScalarExpr::GetField { base, .. } => {
560 stack.push(base);
561 }
562 ScalarExpr::Cast { expr, .. } => {
563 stack.push(expr.as_ref());
564 }
565 ScalarExpr::Not(expr) => {
566 stack.push(expr.as_ref());
567 }
568 ScalarExpr::IsNull { expr, .. } => {
569 stack.push(expr.as_ref());
570 }
571 ScalarExpr::Case {
572 operand,
573 branches,
574 else_expr,
575 } => {
576 if let Some(inner) = operand.as_deref() {
577 stack.push(inner);
578 }
579 for (when_expr, then_expr) in branches {
580 stack.push(when_expr);
581 stack.push(then_expr);
582 }
583 if let Some(inner) = else_expr.as_deref() {
584 stack.push(inner);
585 }
586 }
587 ScalarExpr::Coalesce(items) => {
588 for item in items {
589 stack.push(item);
590 }
591 }
592 ScalarExpr::Random => {
593 }
595 ScalarExpr::ScalarSubquery(_) => {
596 }
598 }
599 }
600 let mut fields: Vec<FieldId> = seen.into_iter().collect();
601 fields.sort_unstable();
602 fields
603}