Skip to main content

jetro_core/
plan.rs

1//! Logical plan IR.
2//!
3//! Relational-style representation of a query: `Scan`, `Filter`, `Project`,
4//! `Aggregate`, `Sort`, `Limit`, `Join`.  Built from a compiled `Program`
5//! by walking method calls on arrays; unrecognised opcode sequences fall
6//! through as an opaque `Raw` node.
7//!
8//! The logical plan enables rewrites that are hard to express at the
9//! opcode level: e.g. predicate pushdown, filter-then-project reorder,
10//! join detection across `let` bindings.
11//!
12//! This is a scaffold — the rewrite rules library is intentionally small.
13
14use std::sync::Arc;
15use super::vm::{Program, Opcode, BuiltinMethod};
16
17#[derive(Debug, Clone)]
18pub enum LogicalPlan {
19    /// Root scan — produces the input document.
20    Scan,
21    /// Navigate into a field chain from the scan.
22    Path(Vec<Arc<str>>),
23    /// Filter rows by a boolean predicate program.
24    Filter { input: Box<LogicalPlan>, pred: Arc<Program> },
25    /// Project / transform each row.
26    Project { input: Box<LogicalPlan>, map: Arc<Program> },
27    /// Aggregate to a single scalar.
28    Aggregate { input: Box<LogicalPlan>, op: AggOp, arg: Option<Arc<Program>> },
29    /// Sort rows.
30    Sort { input: Box<LogicalPlan>, key: Option<Arc<Program>>, desc: bool },
31    /// Limit / TopN.
32    Limit { input: Box<LogicalPlan>, n: usize },
33    /// Join two plans on matching keys (stubbed — detected but not yet
34    /// rewritten into a fused execution).
35    Join { left: Box<LogicalPlan>, right: Box<LogicalPlan>, on: Arc<Program> },
36    /// Opaque fallback: opcode sequence not yet lifted.
37    Raw(Arc<Program>),
38}
39
/// Reduction performed by [`LogicalPlan::Aggregate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggOp {
    /// Lifted from `count()` / `len()` and from fused filter-count opcodes.
    Count,
    /// Lifted from `sum()` and from fused map-sum opcodes.
    Sum,
    /// Lifted from `avg()` and from fused map-avg opcodes.
    Avg,
    /// Lifted from `min()`.
    Min,
    /// Lifted from `max()`.
    Max,
    /// Lifted from `first()`.
    First,
    /// Lifted from `last()`.
    Last,
}
42
43impl LogicalPlan {
44    /// Lift a compiled `Program` into a `LogicalPlan`.  Best-effort: falls
45    /// back to `Raw` for opcode sequences that don't fit the relational shape.
46    pub fn lift(program: &Program) -> LogicalPlan {
47        let ops = &program.ops;
48        if ops.is_empty() { return LogicalPlan::Raw(Arc::new(program.clone())); }
49
50        let mut plan: Option<LogicalPlan> = None;
51        for op in ops.iter() {
52            plan = Some(match (plan.take(), op) {
53                (None, Opcode::PushRoot) => LogicalPlan::Scan,
54                (None, Opcode::RootChain(chain)) => {
55                    LogicalPlan::Path(chain.iter().cloned().collect())
56                }
57                (Some(p), Opcode::GetField(k)) => match p {
58                    LogicalPlan::Scan => LogicalPlan::Path(vec![k.clone()]),
59                    LogicalPlan::Path(mut v) => { v.push(k.clone()); LogicalPlan::Path(v) }
60                    other => LogicalPlan::Project {
61                        input: Box::new(other),
62                        map: Arc::new(Program {
63                            ops: Arc::from(vec![Opcode::PushCurrent, Opcode::GetField(k.clone())]),
64                            source: Arc::from(""), id: 0, is_structural: true,
65                        }),
66                    },
67                }
68                (Some(p), Opcode::RootChain(chain)) => {
69                    let mut v: Vec<Arc<str>> = match p { LogicalPlan::Scan => vec![], _ => return LogicalPlan::Raw(Arc::new(program.clone())) };
70                    for k in chain.iter() { v.push(k.clone()); }
71                    LogicalPlan::Path(v)
72                }
73                (Some(p), Opcode::CallMethod(c)) => lift_method(p, c),
74                (Some(p), Opcode::FilterMap { pred, map }) => LogicalPlan::Project {
75                    input: Box::new(LogicalPlan::Filter {
76                        input: Box::new(p),
77                        pred:  Arc::clone(pred),
78                    }),
79                    map: Arc::clone(map),
80                },
81                (Some(p), Opcode::FilterCount(pred)) => LogicalPlan::Aggregate {
82                    input: Box::new(LogicalPlan::Filter {
83                        input: Box::new(p),
84                        pred:  Arc::clone(pred),
85                    }),
86                    op:    AggOp::Count,
87                    arg:   None,
88                },
89                (Some(p), Opcode::MapSum(f)) => LogicalPlan::Aggregate {
90                    input: Box::new(p),
91                    op:    AggOp::Sum,
92                    arg:   Some(Arc::clone(f)),
93                },
94                (Some(p), Opcode::MapAvg(f)) => LogicalPlan::Aggregate {
95                    input: Box::new(p),
96                    op:    AggOp::Avg,
97                    arg:   Some(Arc::clone(f)),
98                },
99                (Some(p), Opcode::TopN { n, asc }) => LogicalPlan::Limit {
100                    input: Box::new(LogicalPlan::Sort {
101                        input: Box::new(p),
102                        key:   None,
103                        desc:  !asc,
104                    }),
105                    n: *n,
106                },
107                _ => return LogicalPlan::Raw(Arc::new(program.clone())),
108            });
109        }
110        plan.unwrap_or(LogicalPlan::Raw(Arc::new(program.clone())))
111    }
112
113    /// True if this plan is a pure aggregate (reduces to scalar).
114    pub fn is_aggregate(&self) -> bool {
115        matches!(self, LogicalPlan::Aggregate { .. })
116    }
117
118    /// Depth of the plan tree (for cost).
119    pub fn depth(&self) -> usize {
120        match self {
121            LogicalPlan::Scan | LogicalPlan::Path(_) | LogicalPlan::Raw(_) => 1,
122            LogicalPlan::Filter { input, .. } | LogicalPlan::Project { input, .. }
123                | LogicalPlan::Sort { input, .. } | LogicalPlan::Limit { input, .. }
124                | LogicalPlan::Aggregate { input, .. } => 1 + input.depth(),
125            LogicalPlan::Join { left, right, .. } =>
126                1 + left.depth().max(right.depth()),
127        }
128    }
129}
130
131fn lift_method(input: LogicalPlan, c: &Arc<super::vm::CompiledCall>) -> LogicalPlan {
132    match c.method {
133        BuiltinMethod::Filter if !c.sub_progs.is_empty() => LogicalPlan::Filter {
134            input: Box::new(input),
135            pred:  Arc::clone(&c.sub_progs[0]),
136        },
137        BuiltinMethod::Map if !c.sub_progs.is_empty() => LogicalPlan::Project {
138            input: Box::new(input),
139            map:   Arc::clone(&c.sub_progs[0]),
140        },
141        BuiltinMethod::Sort => LogicalPlan::Sort {
142            input: Box::new(input),
143            key:   c.sub_progs.first().map(Arc::clone),
144            desc:  false,
145        },
146        BuiltinMethod::Count | BuiltinMethod::Len =>
147            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Count, arg: None },
148        BuiltinMethod::Sum =>
149            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Sum, arg: c.sub_progs.first().map(Arc::clone) },
150        BuiltinMethod::Avg =>
151            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Avg, arg: c.sub_progs.first().map(Arc::clone) },
152        BuiltinMethod::Min =>
153            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Min, arg: None },
154        BuiltinMethod::Max =>
155            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Max, arg: None },
156        BuiltinMethod::First =>
157            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::First, arg: None },
158        BuiltinMethod::Last =>
159            LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Last, arg: None },
160        _ => LogicalPlan::Raw(Arc::new(Program {
161            ops: Arc::from(vec![Opcode::CallMethod(Arc::clone(c))]),
162            source: Arc::from(""), id: 0, is_structural: false,
163        })),
164    }
165}
166
167// ── Rewrite rules ─────────────────────────────────────────────────────────────
168
169/// Push a filter down through a project when the project's map is a pure
170/// field-access (equi-projection).  Enables evaluating predicate on the
171/// larger pre-project rowset, which is often cheaper.
172///
173/// This is a skeleton — only the trivially-safe case is rewritten.
174pub fn pushdown_filter(plan: LogicalPlan) -> LogicalPlan {
175    match plan {
176        LogicalPlan::Filter { input, pred } => match *input {
177            // filter(sort(x)) → sort(filter(x))   [sort is order-preserving filter-wise]
178            LogicalPlan::Sort { input: inner, key, desc } => LogicalPlan::Sort {
179                input: Box::new(pushdown_filter(LogicalPlan::Filter { input: inner, pred })),
180                key, desc,
181            },
182            other => LogicalPlan::Filter { input: Box::new(pushdown_filter(other)), pred },
183        },
184        LogicalPlan::Project { input, map } => LogicalPlan::Project {
185            input: Box::new(pushdown_filter(*input)),
186            map,
187        },
188        LogicalPlan::Sort { input, key, desc } => LogicalPlan::Sort {
189            input: Box::new(pushdown_filter(*input)),
190            key, desc,
191        },
192        LogicalPlan::Limit { input, n } => LogicalPlan::Limit {
193            input: Box::new(pushdown_filter(*input)),
194            n,
195        },
196        LogicalPlan::Aggregate { input, op, arg } => LogicalPlan::Aggregate {
197            input: Box::new(pushdown_filter(*input)),
198            op, arg,
199        },
200        other => other,
201    }
202}
203
204// ── Lowering: LogicalPlan → Program ──────────────────────────────────────────
205
206/// Compile a `LogicalPlan` back to a flat `Program`.  Inverse of `lift`.
207/// Lifting then lowering should produce a semantically-equivalent program
208/// (not necessarily byte-identical — the lowered form is canonicalised).
209pub fn lower(plan: &LogicalPlan) -> Arc<Program> {
210    let mut ops = Vec::new();
211    emit(plan, &mut ops);
212    Arc::new(Program {
213        ops:           ops.into(),
214        source:        Arc::from("<lowered>"),
215        id:            0,
216        is_structural: false,
217    })
218}
219
220fn emit(plan: &LogicalPlan, ops: &mut Vec<super::vm::Opcode>) {
221    use super::vm::Opcode;
222    match plan {
223        LogicalPlan::Scan => ops.push(Opcode::PushRoot),
224        LogicalPlan::Path(ks) => {
225            ops.push(Opcode::RootChain(ks.iter().cloned().collect::<Vec<_>>().into()));
226        }
227        LogicalPlan::Filter { input, pred } => {
228            emit(input, ops);
229            ops.push(Opcode::InlineFilter(Arc::clone(pred)));
230        }
231        LogicalPlan::Project { input, map } => {
232            emit(input, ops);
233            ops.push(map_as_call(map));
234        }
235        LogicalPlan::Sort { input, key, desc } => {
236            emit(input, ops);
237            ops.push(sort_as_call(key.as_ref()));
238            if *desc { ops.push(reverse_call()); }
239        }
240        LogicalPlan::Limit { input, n } => {
241            emit(input, ops);
242            ops.push(Opcode::TopN { n: *n, asc: true });
243        }
244        LogicalPlan::Aggregate { input, op, arg } => {
245            emit(input, ops);
246            match op {
247                AggOp::Count => ops.push(noarg_call(super::vm::BuiltinMethod::Count, "count")),
248                AggOp::Sum if arg.is_some() => ops.push(Opcode::MapSum(Arc::clone(arg.as_ref().unwrap()))),
249                AggOp::Avg if arg.is_some() => ops.push(Opcode::MapAvg(Arc::clone(arg.as_ref().unwrap()))),
250                AggOp::Sum => ops.push(noarg_call(super::vm::BuiltinMethod::Sum, "sum")),
251                AggOp::Avg => ops.push(noarg_call(super::vm::BuiltinMethod::Avg, "avg")),
252                AggOp::Min => ops.push(noarg_call(super::vm::BuiltinMethod::Min, "min")),
253                AggOp::Max => ops.push(noarg_call(super::vm::BuiltinMethod::Max, "max")),
254                AggOp::First => ops.push(noarg_call(super::vm::BuiltinMethod::First, "first")),
255                AggOp::Last  => ops.push(noarg_call(super::vm::BuiltinMethod::Last, "last")),
256            }
257        }
258        LogicalPlan::Join { left, right: _, on: _ } => {
259            // Placeholder: emit left only (no fused join runtime yet).
260            emit(left, ops);
261        }
262        LogicalPlan::Raw(p) => {
263            for op in p.ops.iter() { ops.push(op.clone()); }
264        }
265    }
266}
267
268fn noarg_call(method: super::vm::BuiltinMethod, name: &str) -> super::vm::Opcode {
269    use super::vm::{Opcode, CompiledCall};
270    Opcode::CallMethod(Arc::new(CompiledCall {
271        method, name: Arc::from(name),
272        sub_progs: Arc::from(Vec::new()),
273        orig_args: Arc::from(Vec::new()),
274    }))
275}
276
277fn reverse_call() -> super::vm::Opcode {
278    noarg_call(super::vm::BuiltinMethod::Reverse, "reverse")
279}
280
281fn map_as_call(map: &Arc<Program>) -> super::vm::Opcode {
282    use super::vm::{Opcode, CompiledCall, BuiltinMethod};
283    Opcode::CallMethod(Arc::new(CompiledCall {
284        method:    BuiltinMethod::Map,
285        name:      Arc::from("map"),
286        sub_progs: Arc::from(vec![Arc::clone(map)]),
287        orig_args: Arc::from(Vec::new()),
288    }))
289}
290
291fn sort_as_call(key: Option<&Arc<Program>>) -> super::vm::Opcode {
292    use super::vm::{Opcode, CompiledCall, BuiltinMethod};
293    let sub_progs: Vec<Arc<Program>> = key.map(|k| vec![Arc::clone(k)]).unwrap_or_default();
294    Opcode::CallMethod(Arc::new(CompiledCall {
295        method:    BuiltinMethod::Sort,
296        name:      Arc::from("sort"),
297        sub_progs: sub_progs.into(),
298        orig_args: Arc::from(Vec::new()),
299    }))
300}
301
302// ── Join detection ────────────────────────────────────────────────────────────
303
304/// Walk a `LogicalPlan` looking for filter predicates that compare two
305/// distinct identifiers — a candidate equi-join between correlated scans.
306/// Returns a vector of (left-path, right-path) fragment pairs per detected
307/// candidate.  Wiring actual `Join` rewrite is future work.
308pub fn detect_join_candidates(plan: &LogicalPlan) -> Vec<JoinCandidate> {
309    let mut out = Vec::new();
310    walk(plan, &mut out);
311    out
312}
313
/// One equi-join candidate reported by `detect_join_candidates`.
#[derive(Debug, Clone)]
pub struct JoinCandidate {
    /// Left-hand side of the comparison, rendered as `ident.field.field…`.
    pub left: String,
    /// Right-hand side of the comparison, rendered the same way.
    pub right: String,
}
320
321fn walk(plan: &LogicalPlan, out: &mut Vec<JoinCandidate>) {
322    match plan {
323        LogicalPlan::Filter { input, pred } => {
324            if let Some(j) = detect_eq_join(pred) { out.push(j); }
325            walk(input, out);
326        }
327        LogicalPlan::Project { input, .. }
328            | LogicalPlan::Sort { input, .. }
329            | LogicalPlan::Limit { input, .. }
330            | LogicalPlan::Aggregate { input, .. } => walk(input, out),
331        LogicalPlan::Join { left, right, .. } => { walk(left, out); walk(right, out); }
332        _ => {}
333    }
334}
335
336fn detect_eq_join(pred: &Arc<Program>) -> Option<JoinCandidate> {
337    use crate::vm::Opcode;
338    // Look for pattern: LoadIdent + GetField* + LoadIdent + GetField* + Eq
339    let ops = &pred.ops;
340    if ops.len() < 3 { return None; }
341    let last = ops.last()?;
342    if !matches!(last, Opcode::Eq) { return None; }
343    // Split ops by finding the second LoadIdent (crude): two independent
344    // chains, both starting with a LoadIdent.
345    let ident_positions: Vec<usize> = ops.iter().enumerate()
346        .filter(|(_, o)| matches!(o, Opcode::LoadIdent(_)))
347        .map(|(i, _)| i).collect();
348    if ident_positions.len() != 2 { return None; }
349    let a = describe_chain(&ops[ident_positions[0]..ident_positions[1]]);
350    let b = describe_chain(&ops[ident_positions[1]..ops.len()-1]);
351    if a == b { return None; }
352    Some(JoinCandidate { left: a, right: b })
353}
354
355fn describe_chain(ops: &[super::vm::Opcode]) -> String {
356    use super::vm::Opcode;
357    let mut s = String::new();
358    for op in ops {
359        match op {
360            Opcode::LoadIdent(n) => s.push_str(n),
361            Opcode::GetField(k)  => { s.push('.'); s.push_str(k); }
362            _ => {}
363        }
364    }
365    s
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Compiler;

    /// Compile `src` and lift the resulting program into a plan.
    fn lifted(src: &str) -> LogicalPlan {
        LogicalPlan::lift(&Compiler::compile_str(src).unwrap())
    }

    #[test]
    fn lift_path() {
        let plan = lifted("$.store.books");
        if let LogicalPlan::Path(segments) = &plan {
            let names: Vec<&str> = segments.iter().map(|s| &**s).collect();
            assert_eq!(names, ["store", "books"]);
        } else {
            panic!("expected Path, got {:?}", plan);
        }
    }

    #[test]
    fn lift_filter_map() {
        // filter+map is fused into a FilterMap opcode, which lifts to
        // Project(Filter(..)): at least two levels deep.
        let plan = lifted("$.books.filter(@.price > 10).map(@.title)");
        assert!(plan.depth() >= 2);
    }

    #[test]
    fn lift_aggregate() {
        assert!(lifted("$.books.count()").is_aggregate());
    }

    #[test]
    fn roundtrip_lower_preserves_semantics() {
        use crate::vm::VM;
        use serde_json::json;

        let doc = json!({"store": {"books": [{"price": 20}, {"price": 5}]}});
        let program = Compiler::compile_str("$.store.books.filter(@.price > 10).count()").unwrap();
        let lowered = lower(&LogicalPlan::lift(&program));

        let mut vm = VM::new();
        let expected = vm.execute(&program, &doc).unwrap();
        let actual = vm.execute(&lowered, &doc).unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn detect_join_candidates_finds_equi_join() {
        // Two different identifiers compared by `==` form a join candidate.
        let candidates = detect_join_candidates(&lifted("$.x.filter(a.id == b.id)"));
        assert!(!candidates.is_empty(), "should detect a.id == b.id as join");
    }
}
423}