1use std::sync::Arc;
15use super::vm::{Program, Opcode, BuiltinMethod};
16
17#[derive(Debug, Clone)]
18pub enum LogicalPlan {
19 Scan,
21 Path(Vec<Arc<str>>),
23 Filter { input: Box<LogicalPlan>, pred: Arc<Program> },
25 Project { input: Box<LogicalPlan>, map: Arc<Program> },
27 Aggregate { input: Box<LogicalPlan>, op: AggOp, arg: Option<Arc<Program>> },
29 Sort { input: Box<LogicalPlan>, key: Option<Arc<Program>>, desc: bool },
31 Limit { input: Box<LogicalPlan>, n: usize },
33 Join { left: Box<LogicalPlan>, right: Box<LogicalPlan>, on: Arc<Program> },
36 Raw(Arc<Program>),
38}
39
/// Reduction applied by [`LogicalPlan::Aggregate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggOp {
    /// Number of items.
    Count,
    /// Sum of the items (or of a per-item expression).
    Sum,
    /// Average of the items (or of a per-item expression).
    Avg,
    /// Smallest item.
    Min,
    /// Largest item.
    Max,
    /// First item.
    First,
    /// Last item.
    Last,
}
42
43impl LogicalPlan {
44 pub fn lift(program: &Program) -> LogicalPlan {
47 let ops = &program.ops;
48 if ops.is_empty() { return LogicalPlan::Raw(Arc::new(program.clone())); }
49
50 let mut plan: Option<LogicalPlan> = None;
51 for op in ops.iter() {
52 plan = Some(match (plan.take(), op) {
53 (None, Opcode::PushRoot) => LogicalPlan::Scan,
54 (None, Opcode::RootChain(chain)) => {
55 LogicalPlan::Path(chain.iter().cloned().collect())
56 }
57 (Some(p), Opcode::GetField(k)) => match p {
58 LogicalPlan::Scan => LogicalPlan::Path(vec![k.clone()]),
59 LogicalPlan::Path(mut v) => { v.push(k.clone()); LogicalPlan::Path(v) }
60 other => LogicalPlan::Project {
61 input: Box::new(other),
62 map: Arc::new(Program {
63 ops: Arc::from(vec![Opcode::PushCurrent, Opcode::GetField(k.clone())]),
64 source: Arc::from(""), id: 0, is_structural: true,
65 }),
66 },
67 }
68 (Some(p), Opcode::RootChain(chain)) => {
69 let mut v: Vec<Arc<str>> = match p { LogicalPlan::Scan => vec![], _ => return LogicalPlan::Raw(Arc::new(program.clone())) };
70 for k in chain.iter() { v.push(k.clone()); }
71 LogicalPlan::Path(v)
72 }
73 (Some(p), Opcode::CallMethod(c)) => lift_method(p, c),
74 (Some(p), Opcode::FilterMap { pred, map }) => LogicalPlan::Project {
75 input: Box::new(LogicalPlan::Filter {
76 input: Box::new(p),
77 pred: Arc::clone(pred),
78 }),
79 map: Arc::clone(map),
80 },
81 (Some(p), Opcode::FilterCount(pred)) => LogicalPlan::Aggregate {
82 input: Box::new(LogicalPlan::Filter {
83 input: Box::new(p),
84 pred: Arc::clone(pred),
85 }),
86 op: AggOp::Count,
87 arg: None,
88 },
89 (Some(p), Opcode::MapSum(f)) => LogicalPlan::Aggregate {
90 input: Box::new(p),
91 op: AggOp::Sum,
92 arg: Some(Arc::clone(f)),
93 },
94 (Some(p), Opcode::MapAvg(f)) => LogicalPlan::Aggregate {
95 input: Box::new(p),
96 op: AggOp::Avg,
97 arg: Some(Arc::clone(f)),
98 },
99 (Some(p), Opcode::TopN { n, asc }) => LogicalPlan::Limit {
100 input: Box::new(LogicalPlan::Sort {
101 input: Box::new(p),
102 key: None,
103 desc: !asc,
104 }),
105 n: *n,
106 },
107 _ => return LogicalPlan::Raw(Arc::new(program.clone())),
108 });
109 }
110 plan.unwrap_or(LogicalPlan::Raw(Arc::new(program.clone())))
111 }
112
113 pub fn is_aggregate(&self) -> bool {
115 matches!(self, LogicalPlan::Aggregate { .. })
116 }
117
118 pub fn depth(&self) -> usize {
120 match self {
121 LogicalPlan::Scan | LogicalPlan::Path(_) | LogicalPlan::Raw(_) => 1,
122 LogicalPlan::Filter { input, .. } | LogicalPlan::Project { input, .. }
123 | LogicalPlan::Sort { input, .. } | LogicalPlan::Limit { input, .. }
124 | LogicalPlan::Aggregate { input, .. } => 1 + input.depth(),
125 LogicalPlan::Join { left, right, .. } =>
126 1 + left.depth().max(right.depth()),
127 }
128 }
129}
130
131fn lift_method(input: LogicalPlan, c: &Arc<super::vm::CompiledCall>) -> LogicalPlan {
132 match c.method {
133 BuiltinMethod::Filter if !c.sub_progs.is_empty() => LogicalPlan::Filter {
134 input: Box::new(input),
135 pred: Arc::clone(&c.sub_progs[0]),
136 },
137 BuiltinMethod::Map if !c.sub_progs.is_empty() => LogicalPlan::Project {
138 input: Box::new(input),
139 map: Arc::clone(&c.sub_progs[0]),
140 },
141 BuiltinMethod::Sort => LogicalPlan::Sort {
142 input: Box::new(input),
143 key: c.sub_progs.first().map(Arc::clone),
144 desc: false,
145 },
146 BuiltinMethod::Count | BuiltinMethod::Len =>
147 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Count, arg: None },
148 BuiltinMethod::Sum =>
149 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Sum, arg: c.sub_progs.first().map(Arc::clone) },
150 BuiltinMethod::Avg =>
151 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Avg, arg: c.sub_progs.first().map(Arc::clone) },
152 BuiltinMethod::Min =>
153 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Min, arg: None },
154 BuiltinMethod::Max =>
155 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Max, arg: None },
156 BuiltinMethod::First =>
157 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::First, arg: None },
158 BuiltinMethod::Last =>
159 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Last, arg: None },
160 _ => LogicalPlan::Raw(Arc::new(Program {
161 ops: Arc::from(vec![Opcode::CallMethod(Arc::clone(c))]),
162 source: Arc::from(""), id: 0, is_structural: false,
163 })),
164 }
165}
166
167pub fn pushdown_filter(plan: LogicalPlan) -> LogicalPlan {
175 match plan {
176 LogicalPlan::Filter { input, pred } => match *input {
177 LogicalPlan::Sort { input: inner, key, desc } => LogicalPlan::Sort {
179 input: Box::new(pushdown_filter(LogicalPlan::Filter { input: inner, pred })),
180 key, desc,
181 },
182 other => LogicalPlan::Filter { input: Box::new(pushdown_filter(other)), pred },
183 },
184 LogicalPlan::Project { input, map } => LogicalPlan::Project {
185 input: Box::new(pushdown_filter(*input)),
186 map,
187 },
188 LogicalPlan::Sort { input, key, desc } => LogicalPlan::Sort {
189 input: Box::new(pushdown_filter(*input)),
190 key, desc,
191 },
192 LogicalPlan::Limit { input, n } => LogicalPlan::Limit {
193 input: Box::new(pushdown_filter(*input)),
194 n,
195 },
196 LogicalPlan::Aggregate { input, op, arg } => LogicalPlan::Aggregate {
197 input: Box::new(pushdown_filter(*input)),
198 op, arg,
199 },
200 other => other,
201 }
202}
203
204pub fn lower(plan: &LogicalPlan) -> Arc<Program> {
210 let mut ops = Vec::new();
211 emit(plan, &mut ops);
212 Arc::new(Program {
213 ops: ops.into(),
214 source: Arc::from("<lowered>"),
215 id: 0,
216 is_structural: false,
217 })
218}
219
220fn emit(plan: &LogicalPlan, ops: &mut Vec<super::vm::Opcode>) {
221 use super::vm::Opcode;
222 match plan {
223 LogicalPlan::Scan => ops.push(Opcode::PushRoot),
224 LogicalPlan::Path(ks) => {
225 ops.push(Opcode::RootChain(ks.iter().cloned().collect::<Vec<_>>().into()));
226 }
227 LogicalPlan::Filter { input, pred } => {
228 emit(input, ops);
229 ops.push(Opcode::InlineFilter(Arc::clone(pred)));
230 }
231 LogicalPlan::Project { input, map } => {
232 emit(input, ops);
233 ops.push(map_as_call(map));
234 }
235 LogicalPlan::Sort { input, key, desc } => {
236 emit(input, ops);
237 ops.push(sort_as_call(key.as_ref()));
238 if *desc { ops.push(reverse_call()); }
239 }
240 LogicalPlan::Limit { input, n } => {
241 emit(input, ops);
242 ops.push(Opcode::TopN { n: *n, asc: true });
243 }
244 LogicalPlan::Aggregate { input, op, arg } => {
245 emit(input, ops);
246 match op {
247 AggOp::Count => ops.push(noarg_call(super::vm::BuiltinMethod::Count, "count")),
248 AggOp::Sum if arg.is_some() => ops.push(Opcode::MapSum(Arc::clone(arg.as_ref().unwrap()))),
249 AggOp::Avg if arg.is_some() => ops.push(Opcode::MapAvg(Arc::clone(arg.as_ref().unwrap()))),
250 AggOp::Sum => ops.push(noarg_call(super::vm::BuiltinMethod::Sum, "sum")),
251 AggOp::Avg => ops.push(noarg_call(super::vm::BuiltinMethod::Avg, "avg")),
252 AggOp::Min => ops.push(noarg_call(super::vm::BuiltinMethod::Min, "min")),
253 AggOp::Max => ops.push(noarg_call(super::vm::BuiltinMethod::Max, "max")),
254 AggOp::First => ops.push(noarg_call(super::vm::BuiltinMethod::First, "first")),
255 AggOp::Last => ops.push(noarg_call(super::vm::BuiltinMethod::Last, "last")),
256 }
257 }
258 LogicalPlan::Join { left, right: _, on: _ } => {
259 emit(left, ops);
261 }
262 LogicalPlan::Raw(p) => {
263 for op in p.ops.iter() { ops.push(op.clone()); }
264 }
265 }
266}
267
268fn noarg_call(method: super::vm::BuiltinMethod, name: &str) -> super::vm::Opcode {
269 use super::vm::{Opcode, CompiledCall};
270 Opcode::CallMethod(Arc::new(CompiledCall {
271 method, name: Arc::from(name),
272 sub_progs: Arc::from(Vec::new()),
273 orig_args: Arc::from(Vec::new()),
274 }))
275}
276
277fn reverse_call() -> super::vm::Opcode {
278 noarg_call(super::vm::BuiltinMethod::Reverse, "reverse")
279}
280
281fn map_as_call(map: &Arc<Program>) -> super::vm::Opcode {
282 use super::vm::{Opcode, CompiledCall, BuiltinMethod};
283 Opcode::CallMethod(Arc::new(CompiledCall {
284 method: BuiltinMethod::Map,
285 name: Arc::from("map"),
286 sub_progs: Arc::from(vec![Arc::clone(map)]),
287 orig_args: Arc::from(Vec::new()),
288 }))
289}
290
291fn sort_as_call(key: Option<&Arc<Program>>) -> super::vm::Opcode {
292 use super::vm::{Opcode, CompiledCall, BuiltinMethod};
293 let sub_progs: Vec<Arc<Program>> = key.map(|k| vec![Arc::clone(k)]).unwrap_or_default();
294 Opcode::CallMethod(Arc::new(CompiledCall {
295 method: BuiltinMethod::Sort,
296 name: Arc::from("sort"),
297 sub_progs: sub_progs.into(),
298 orig_args: Arc::from(Vec::new()),
299 }))
300}
301
302pub fn detect_join_candidates(plan: &LogicalPlan) -> Vec<JoinCandidate> {
309 let mut out = Vec::new();
310 walk(plan, &mut out);
311 out
312}
313
/// A pair of operands found on either side of an equality predicate.
///
/// Each side is the textual form of an identifier followed by its field
/// accesses, for example `a.id`.
#[derive(Debug, Clone)]
pub struct JoinCandidate {
    /// Operand that appears first in the predicate.
    pub left: String,
    /// Operand that appears second in the predicate.
    pub right: String,
}
320
321fn walk(plan: &LogicalPlan, out: &mut Vec<JoinCandidate>) {
322 match plan {
323 LogicalPlan::Filter { input, pred } => {
324 if let Some(j) = detect_eq_join(pred) { out.push(j); }
325 walk(input, out);
326 }
327 LogicalPlan::Project { input, .. }
328 | LogicalPlan::Sort { input, .. }
329 | LogicalPlan::Limit { input, .. }
330 | LogicalPlan::Aggregate { input, .. } => walk(input, out),
331 LogicalPlan::Join { left, right, .. } => { walk(left, out); walk(right, out); }
332 _ => {}
333 }
334}
335
336fn detect_eq_join(pred: &Arc<Program>) -> Option<JoinCandidate> {
337 use crate::vm::Opcode;
338 let ops = &pred.ops;
340 if ops.len() < 3 { return None; }
341 let last = ops.last()?;
342 if !matches!(last, Opcode::Eq) { return None; }
343 let ident_positions: Vec<usize> = ops.iter().enumerate()
346 .filter(|(_, o)| matches!(o, Opcode::LoadIdent(_)))
347 .map(|(i, _)| i).collect();
348 if ident_positions.len() != 2 { return None; }
349 let a = describe_chain(&ops[ident_positions[0]..ident_positions[1]]);
350 let b = describe_chain(&ops[ident_positions[1]..ops.len()-1]);
351 if a == b { return None; }
352 Some(JoinCandidate { left: a, right: b })
353}
354
355fn describe_chain(ops: &[super::vm::Opcode]) -> String {
356 use super::vm::Opcode;
357 let mut s = String::new();
358 for op in ops {
359 match op {
360 Opcode::LoadIdent(n) => s.push_str(n),
361 Opcode::GetField(k) => { s.push('.'); s.push_str(k); }
362 _ => {}
363 }
364 }
365 s
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Compiler;

    /// A plain field chain from the root lifts to a `Path` of its segments.
    #[test]
    fn lift_path() {
        let program = Compiler::compile_str("$.store.books").unwrap();
        let plan = LogicalPlan::lift(&program);
        match &plan {
            LogicalPlan::Path(segments) => {
                let names: Vec<&str> = segments.iter().map(|s| &**s).collect();
                assert_eq!(names, ["store", "books"]);
            }
            _ => panic!("expected Path, got {:?}", plan),
        }
    }

    /// A filter followed by a map produces a plan with more than one level.
    #[test]
    fn lift_filter_map() {
        let program =
            Compiler::compile_str("$.books.filter(@.price > 10).map(@.title)").unwrap();
        let levels = LogicalPlan::lift(&program).depth();
        assert!(levels >= 2);
    }

    /// A trailing `count()` makes the lifted plan an aggregate.
    #[test]
    fn lift_aggregate() {
        let program = Compiler::compile_str("$.books.count()").unwrap();
        assert!(LogicalPlan::lift(&program).is_aggregate());
    }

    /// Lifting and lowering a program must not change what it evaluates to.
    #[test]
    fn roundtrip_lower_preserves_semantics() {
        use crate::vm::VM;
        use serde_json::json;

        let doc = json!({"store": {"books": [{"price": 20}, {"price": 5}]}});
        let program =
            Compiler::compile_str("$.store.books.filter(@.price > 10).count()").unwrap();
        let lowered = lower(&LogicalPlan::lift(&program));

        let mut vm = VM::new();
        let expected = vm.execute(&program, &doc).unwrap();
        let actual = vm.execute(&lowered, &doc).unwrap();
        assert_eq!(expected, actual);
    }

    /// An equality between two identifier chains is reported as a candidate.
    #[test]
    fn detect_join_candidates_finds_equi_join() {
        let program = Compiler::compile_str("$.x.filter(a.id == b.id)").unwrap();
        let found = detect_join_candidates(&LogicalPlan::lift(&program));
        assert!(!found.is_empty(), "should detect a.id == b.id as join");
    }
}