Skip to main content

jetro_core/
plan.rs

1//! Logical plan IR.
2//!
3//! Relational-style representation of a query: `Scan`, `Filter`, `Project`,
4//! `Aggregate`, `Sort`, `Limit`, `Join`.  Built from a compiled `Program`
5//! by walking method calls on arrays; unrecognised opcode sequences fall
6//! through as an opaque `Raw` node.
7//!
8//! The logical plan enables rewrites that are hard to express at the
9//! opcode level: e.g. predicate pushdown, filter-then-project reorder,
10//! join detection across `let` bindings.
11//!
12//! This is a scaffold — the rewrite rules library is intentionally small.
13
14use std::sync::Arc;
15use super::vm::{Program, Opcode, BuiltinMethod};
16
17#[derive(Debug, Clone)]
18pub enum LogicalPlan {
19    /// Root scan — produces the input document.
20    Scan,
21    /// Navigate into a field chain from the scan.
22    Path(Vec<Arc<str>>),
23    /// Filter rows by a boolean predicate program.
24    Filter { input: Box<LogicalPlan>, pred: Arc<Program> },
25    /// Project / transform each row.
26    Project { input: Box<LogicalPlan>, map: Arc<Program> },
27    /// Aggregate to a single scalar.
28    Aggregate { input: Box<LogicalPlan>, op: AggOp, arg: Option<Arc<Program>> },
29    /// Sort rows.
30    Sort { input: Box<LogicalPlan>, key: Option<Arc<Program>>, desc: bool },
31    /// Limit / TopN.
32    Limit { input: Box<LogicalPlan>, n: usize },
33    /// Join two plans on matching keys (stubbed — detected but not yet
34    /// rewritten into a fused execution).
35    Join { left: Box<LogicalPlan>, right: Box<LogicalPlan>, on: Arc<Program> },
36    /// Opaque fallback: opcode sequence not yet lifted.
37    Raw(Arc<Program>),
38}
39
/// The reduction performed by a [`LogicalPlan::Aggregate`] node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggOp {
    /// Number of rows.
    Count,
    /// Sum of the (optionally mapped) rows.
    Sum,
    /// Mean of the (optionally mapped) rows.
    Avg,
    /// Smallest row.
    Min,
    /// Largest row.
    Max,
    /// First row.
    First,
    /// Last row.
    Last,
}
42
43impl LogicalPlan {
44    /// Lift a compiled `Program` into a `LogicalPlan`.  Best-effort: falls
45    /// back to `Raw` for opcode sequences that don't fit the relational shape.
46    pub fn lift(program: &Program) -> LogicalPlan {
47        let ops = &program.ops;
48        if ops.is_empty() { return LogicalPlan::Raw(Arc::new(program.clone())); }
49
50        let mut plan: Option<LogicalPlan> = None;
51        for op in ops.iter() {
52            plan = Some(match (plan.take(), op) {
53                (None, Opcode::PushRoot) => LogicalPlan::Scan,
54                (None, Opcode::RootChain(chain)) => {
55                    LogicalPlan::Path(chain.iter().cloned().collect())
56                }
57                (Some(p), Opcode::GetField(k)) => match p {
58                    LogicalPlan::Scan => LogicalPlan::Path(vec![k.clone()]),
59                    LogicalPlan::Path(mut v) => { v.push(k.clone()); LogicalPlan::Path(v) }
60                    other => LogicalPlan::Project {
61                        input: Box::new(other),
62                        map: Arc::new({
63                            let ops_vec = vec![Opcode::PushCurrent, Opcode::GetField(k.clone())];
64                            let ics = crate::vm::fresh_ics(ops_vec.len());
65                            Program {
66                                ops: Arc::from(ops_vec),
67                                source: Arc::from(""), id: 0, is_structural: true, ics,
68                            }
69                        }),
70                    },
71                }
72                (Some(p), Opcode::RootChain(chain)) => {
73                    let mut v: Vec<Arc<str>> = match p { LogicalPlan::Scan => vec![], _ => return LogicalPlan::Raw(Arc::new(program.clone())) };
74                    for k in chain.iter() { v.push(k.clone()); }
75                    LogicalPlan::Path(v)
76                }
77                (Some(p), Opcode::CallMethod(c)) => lift_method(p, c),
78                (Some(p), Opcode::FilterMap { pred, map }) => LogicalPlan::Project {
79                    input: Box::new(LogicalPlan::Filter {
80                        input: Box::new(p),
81                        pred:  Arc::clone(pred),
82                    }),
83                    map: Arc::clone(map),
84                },
85                (Some(p), Opcode::MapFilter { map, pred }) => LogicalPlan::Filter {
86                    input: Box::new(LogicalPlan::Project {
87                        input: Box::new(p),
88                        map:   Arc::clone(map),
89                    }),
90                    pred: Arc::clone(pred),
91                },
92                (Some(p), Opcode::FilterCount(pred)) => LogicalPlan::Aggregate {
93                    input: Box::new(LogicalPlan::Filter {
94                        input: Box::new(p),
95                        pred:  Arc::clone(pred),
96                    }),
97                    op:    AggOp::Count,
98                    arg:   None,
99                },
100                (Some(p), Opcode::MapSum(f)) => LogicalPlan::Aggregate {
101                    input: Box::new(p),
102                    op:    AggOp::Sum,
103                    arg:   Some(Arc::clone(f)),
104                },
105                (Some(p), Opcode::MapAvg(f)) => LogicalPlan::Aggregate {
106                    input: Box::new(p),
107                    op:    AggOp::Avg,
108                    arg:   Some(Arc::clone(f)),
109                },
110                (Some(p), Opcode::FilterMapSum { pred, map }) => LogicalPlan::Aggregate {
111                    input: Box::new(LogicalPlan::Filter {
112                        input: Box::new(p),
113                        pred:  Arc::clone(pred),
114                    }),
115                    op:    AggOp::Sum,
116                    arg:   Some(Arc::clone(map)),
117                },
118                (Some(p), Opcode::FilterMapAvg { pred, map }) => LogicalPlan::Aggregate {
119                    input: Box::new(LogicalPlan::Filter {
120                        input: Box::new(p),
121                        pred:  Arc::clone(pred),
122                    }),
123                    op:    AggOp::Avg,
124                    arg:   Some(Arc::clone(map)),
125                },
126                (Some(p), Opcode::TopN { n, asc }) => LogicalPlan::Limit {
127                    input: Box::new(LogicalPlan::Sort {
128                        input: Box::new(p),
129                        key:   None,
130                        desc:  !asc,
131                    }),
132                    n: *n,
133                },
134                _ => return LogicalPlan::Raw(Arc::new(program.clone())),
135            });
136        }
137        plan.unwrap_or(LogicalPlan::Raw(Arc::new(program.clone())))
138    }
139
140    /// True if this plan is a pure aggregate (reduces to scalar).
141    pub fn is_aggregate(&self) -> bool {
142        matches!(self, LogicalPlan::Aggregate { .. })
143    }
144
145    /// Depth of the plan tree (for cost).
146    pub fn depth(&self) -> usize {
147        match self {
148            LogicalPlan::Scan | LogicalPlan::Path(_) | LogicalPlan::Raw(_) => 1,
149            LogicalPlan::Filter { input, .. } | LogicalPlan::Project { input, .. }
150                | LogicalPlan::Sort { input, .. } | LogicalPlan::Limit { input, .. }
151                | LogicalPlan::Aggregate { input, .. } => 1 + input.depth(),
152            LogicalPlan::Join { left, right, .. } =>
153                1 + left.depth().max(right.depth()),
154        }
155    }
156}
157
158fn lift_method(input: LogicalPlan, c: &Arc<super::vm::CompiledCall>) -> LogicalPlan {
159    match c.method {
160        BuiltinMethod::Filter if !c.sub_progs.is_empty() => LogicalPlan::Filter {
161            input: Box::new(input),
162            pred:  Arc::clone(&c.sub_progs[0]),
163        },
164        BuiltinMethod::Map if !c.sub_progs.is_empty() => LogicalPlan::Project {
165            input: Box::new(input),
166            map:   Arc::clone(&c.sub_progs[0]),
167        },
168        BuiltinMethod::Sort => LogicalPlan::Sort {
169            input: Box::new(input),
170            key:   c.sub_progs.first().map(Arc::clone),
171            desc:  false,
172        },
173        BuiltinMethod::Count | BuiltinMethod::Len =>
174            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Count, arg: None },
175        BuiltinMethod::Sum =>
176            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Sum, arg: c.sub_progs.first().map(Arc::clone) },
177        BuiltinMethod::Avg =>
178            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Avg, arg: c.sub_progs.first().map(Arc::clone) },
179        BuiltinMethod::Min =>
180            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Min, arg: None },
181        BuiltinMethod::Max =>
182            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Max, arg: None },
183        BuiltinMethod::First =>
184            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::First, arg: None },
185        BuiltinMethod::Last =>
186            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Last, arg: None },
187        _ => LogicalPlan::Raw(Arc::new({
188            let ops_vec = vec![Opcode::CallMethod(Arc::clone(c))];
189            let ics = crate::vm::fresh_ics(ops_vec.len());
190            Program {
191                ops: Arc::from(ops_vec),
192                source: Arc::from(""), id: 0, is_structural: false, ics,
193            }
194        })),
195    }
196}
197
198// ── Rewrite rules ─────────────────────────────────────────────────────────────
199
200/// Push a filter down through a project when the project's map is a pure
201/// field-access (equi-projection).  Enables evaluating predicate on the
202/// larger pre-project rowset, which is often cheaper.
203///
204/// This is a skeleton — only the trivially-safe case is rewritten.
205pub fn pushdown_filter(plan: LogicalPlan) -> LogicalPlan {
206    match plan {
207        LogicalPlan::Filter { input, pred } => match *input {
208            // filter(sort(x)) → sort(filter(x))   [sort is order-preserving filter-wise]
209            LogicalPlan::Sort { input: inner, key, desc } => LogicalPlan::Sort {
210                input: Box::new(pushdown_filter(LogicalPlan::Filter { input: inner, pred })),
211                key, desc,
212            },
213            other => LogicalPlan::Filter { input: Box::new(pushdown_filter(other)), pred },
214        },
215        LogicalPlan::Project { input, map } => LogicalPlan::Project {
216            input: Box::new(pushdown_filter(*input)),
217            map,
218        },
219        LogicalPlan::Sort { input, key, desc } => LogicalPlan::Sort {
220            input: Box::new(pushdown_filter(*input)),
221            key, desc,
222        },
223        LogicalPlan::Limit { input, n } => LogicalPlan::Limit {
224            input: Box::new(pushdown_filter(*input)),
225            n,
226        },
227        LogicalPlan::Aggregate { input, op, arg } => LogicalPlan::Aggregate {
228            input: Box::new(pushdown_filter(*input)),
229            op, arg,
230        },
231        other => other,
232    }
233}
234
235// ── Lowering: LogicalPlan → Program ──────────────────────────────────────────
236
237/// Compile a `LogicalPlan` back to a flat `Program`.  Inverse of `lift`.
238/// Lifting then lowering should produce a semantically-equivalent program
239/// (not necessarily byte-identical — the lowered form is canonicalised).
240pub fn lower(plan: &LogicalPlan) -> Arc<Program> {
241    let mut ops = Vec::new();
242    emit(plan, &mut ops);
243    let ics = crate::vm::fresh_ics(ops.len());
244    Arc::new(Program {
245        ops:           ops.into(),
246        source:        Arc::from("<lowered>"),
247        id:            0,
248        is_structural: false,
249        ics,
250    })
251}
252
253fn emit(plan: &LogicalPlan, ops: &mut Vec<super::vm::Opcode>) {
254    use super::vm::Opcode;
255    match plan {
256        LogicalPlan::Scan => ops.push(Opcode::PushRoot),
257        LogicalPlan::Path(ks) => {
258            ops.push(Opcode::RootChain(ks.iter().cloned().collect::<Vec<_>>().into()));
259        }
260        LogicalPlan::Filter { input, pred } => {
261            emit(input, ops);
262            ops.push(Opcode::InlineFilter(Arc::clone(pred)));
263        }
264        LogicalPlan::Project { input, map } => {
265            emit(input, ops);
266            ops.push(map_as_call(map));
267        }
268        LogicalPlan::Sort { input, key, desc } => {
269            emit(input, ops);
270            ops.push(sort_as_call(key.as_ref()));
271            if *desc { ops.push(reverse_call()); }
272        }
273        LogicalPlan::Limit { input, n } => {
274            emit(input, ops);
275            ops.push(Opcode::TopN { n: *n, asc: true });
276        }
277        LogicalPlan::Aggregate { input, op, arg } => {
278            emit(input, ops);
279            match op {
280                AggOp::Count => ops.push(noarg_call(super::vm::BuiltinMethod::Count, "count")),
281                AggOp::Sum if arg.is_some() => ops.push(Opcode::MapSum(Arc::clone(arg.as_ref().unwrap()))),
282                AggOp::Avg if arg.is_some() => ops.push(Opcode::MapAvg(Arc::clone(arg.as_ref().unwrap()))),
283                AggOp::Sum => ops.push(noarg_call(super::vm::BuiltinMethod::Sum, "sum")),
284                AggOp::Avg => ops.push(noarg_call(super::vm::BuiltinMethod::Avg, "avg")),
285                AggOp::Min => ops.push(noarg_call(super::vm::BuiltinMethod::Min, "min")),
286                AggOp::Max => ops.push(noarg_call(super::vm::BuiltinMethod::Max, "max")),
287                AggOp::First => ops.push(noarg_call(super::vm::BuiltinMethod::First, "first")),
288                AggOp::Last  => ops.push(noarg_call(super::vm::BuiltinMethod::Last, "last")),
289            }
290        }
291        LogicalPlan::Join { left, right: _, on: _ } => {
292            // Placeholder: emit left only (no fused join runtime yet).
293            emit(left, ops);
294        }
295        LogicalPlan::Raw(p) => {
296            for op in p.ops.iter() { ops.push(op.clone()); }
297        }
298    }
299}
300
301fn noarg_call(method: super::vm::BuiltinMethod, name: &str) -> super::vm::Opcode {
302    use super::vm::{Opcode, CompiledCall};
303    Opcode::CallMethod(Arc::new(CompiledCall {
304        method, name: Arc::from(name),
305        sub_progs: Arc::from(Vec::new()),
306        orig_args: Arc::from(Vec::new()),
307    }))
308}
309
310fn reverse_call() -> super::vm::Opcode {
311    noarg_call(super::vm::BuiltinMethod::Reverse, "reverse")
312}
313
314fn map_as_call(map: &Arc<Program>) -> super::vm::Opcode {
315    use super::vm::{Opcode, CompiledCall, BuiltinMethod};
316    Opcode::CallMethod(Arc::new(CompiledCall {
317        method:    BuiltinMethod::Map,
318        name:      Arc::from("map"),
319        sub_progs: Arc::from(vec![Arc::clone(map)]),
320        orig_args: Arc::from(Vec::new()),
321    }))
322}
323
324fn sort_as_call(key: Option<&Arc<Program>>) -> super::vm::Opcode {
325    use super::vm::{Opcode, CompiledCall, BuiltinMethod};
326    let sub_progs: Vec<Arc<Program>> = key.map(|k| vec![Arc::clone(k)]).unwrap_or_default();
327    Opcode::CallMethod(Arc::new(CompiledCall {
328        method:    BuiltinMethod::Sort,
329        name:      Arc::from("sort"),
330        sub_progs: sub_progs.into(),
331        orig_args: Arc::from(Vec::new()),
332    }))
333}
334
335// ── Join detection ────────────────────────────────────────────────────────────
336
337/// Walk a `LogicalPlan` looking for filter predicates that compare two
338/// distinct identifiers — a candidate equi-join between correlated scans.
339/// Returns a vector of (left-path, right-path) fragment pairs per detected
340/// candidate.  Wiring actual `Join` rewrite is future work.
341pub fn detect_join_candidates(plan: &LogicalPlan) -> Vec<JoinCandidate> {
342    let mut out = Vec::new();
343    walk(plan, &mut out);
344    out
345}
346
/// A possible equi-join spotted in a filter predicate.
///
/// Each side is rendered by `describe_chain`: the identifier followed by
/// its `.field` accesses, e.g. `a.id`.
#[derive(Debug, Clone)]
pub struct JoinCandidate {
    /// Left-hand operand of the equality, e.g. `a.id`.
    pub left: String,
    /// Right-hand operand of the equality, e.g. `b.id`.
    pub right: String,
}
353
354fn walk(plan: &LogicalPlan, out: &mut Vec<JoinCandidate>) {
355    match plan {
356        LogicalPlan::Filter { input, pred } => {
357            if let Some(j) = detect_eq_join(pred) { out.push(j); }
358            walk(input, out);
359        }
360        LogicalPlan::Project { input, .. }
361            | LogicalPlan::Sort { input, .. }
362            | LogicalPlan::Limit { input, .. }
363            | LogicalPlan::Aggregate { input, .. } => walk(input, out),
364        LogicalPlan::Join { left, right, .. } => { walk(left, out); walk(right, out); }
365        _ => {}
366    }
367}
368
369fn detect_eq_join(pred: &Arc<Program>) -> Option<JoinCandidate> {
370    use crate::vm::Opcode;
371    // Look for pattern: LoadIdent + GetField* + LoadIdent + GetField* + Eq
372    let ops = &pred.ops;
373    if ops.len() < 3 { return None; }
374    let last = ops.last()?;
375    if !matches!(last, Opcode::Eq) { return None; }
376    // Split ops by finding the second LoadIdent (crude): two independent
377    // chains, both starting with a LoadIdent.
378    let ident_positions: Vec<usize> = ops.iter().enumerate()
379        .filter(|(_, o)| matches!(o, Opcode::LoadIdent(_)))
380        .map(|(i, _)| i).collect();
381    if ident_positions.len() != 2 { return None; }
382    let a = describe_chain(&ops[ident_positions[0]..ident_positions[1]]);
383    let b = describe_chain(&ops[ident_positions[1]..ops.len()-1]);
384    if a == b { return None; }
385    Some(JoinCandidate { left: a, right: b })
386}
387
388fn describe_chain(ops: &[super::vm::Opcode]) -> String {
389    use super::vm::Opcode;
390    let mut s = String::new();
391    for op in ops {
392        match op {
393            Opcode::LoadIdent(n) => s.push_str(n),
394            Opcode::GetField(k)  => { s.push('.'); s.push_str(k); }
395            _ => {}
396        }
397    }
398    s
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Compiler;

    /// Compile `src` with the real compiler and lift the result.
    fn lift_src(src: &str) -> LogicalPlan {
        let program = Compiler::compile_str(src).unwrap();
        LogicalPlan::lift(&program)
    }

    #[test]
    fn lift_path() {
        match &lift_src("$.store.books") {
            LogicalPlan::Path(keys) => {
                let names: Vec<&str> = keys.iter().map(|k| &**k).collect();
                assert_eq!(names, ["store", "books"]);
            }
            other => panic!("expected Path, got {:?}", other),
        }
    }

    #[test]
    fn lift_filter_map() {
        // filter+map fuses to the FilterMap opcode, which lifts to
        // Project(Filter(..)) — at least two levels deep.
        let plan = lift_src("$.books.filter(@.price > 10).map(@.title)");
        assert!(plan.depth() >= 2);
    }

    #[test]
    fn lift_aggregate() {
        assert!(lift_src("$.books.count()").is_aggregate());
    }

    #[test]
    fn roundtrip_lower_preserves_semantics() {
        use crate::vm::VM;
        use serde_json::json;
        let doc = json!({"store": {"books": [{"price": 20}, {"price": 5}]}});
        let program =
            Compiler::compile_str("$.store.books.filter(@.price > 10).count()").unwrap();
        let lowered = lower(&LogicalPlan::lift(&program));
        let mut vm = VM::new();
        let expected = vm.execute(&program, &doc).unwrap();
        let actual = vm.execute(&lowered, &doc).unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn detect_join_candidates_finds_equi_join() {
        // Two different identifiers compared by eq → join candidate.
        let plan = lift_src("$.x.filter(a.id == b.id)");
        let found = detect_join_candidates(&plan);
        assert!(!found.is_empty(), "should detect a.id == b.id as join");
    }
}