1use std::sync::Arc;
15use super::vm::{Program, Opcode, BuiltinMethod};
16
17#[derive(Debug, Clone)]
18pub enum LogicalPlan {
19 Scan,
21 Path(Vec<Arc<str>>),
23 Filter { input: Box<LogicalPlan>, pred: Arc<Program> },
25 Project { input: Box<LogicalPlan>, map: Arc<Program> },
27 Aggregate { input: Box<LogicalPlan>, op: AggOp, arg: Option<Arc<Program>> },
29 Sort { input: Box<LogicalPlan>, key: Option<Arc<Program>>, desc: bool },
31 Limit { input: Box<LogicalPlan>, n: usize },
33 Join { left: Box<LogicalPlan>, right: Box<LogicalPlan>, on: Arc<Program> },
36 Raw(Arc<Program>),
38}
39
/// Reduction performed by a [`LogicalPlan::Aggregate`] node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggOp {
    /// Counting reduction (lifted from `count`/`len` calls and `FilterCount`).
    Count,
    /// Summing reduction.
    Sum,
    /// Averaging reduction.
    Avg,
    /// Lifted from the `min` builtin.
    Min,
    /// Lifted from the `max` builtin.
    Max,
    /// Lifted from the `first` builtin.
    First,
    /// Lifted from the `last` builtin.
    Last,
}
42
43impl LogicalPlan {
44 pub fn lift(program: &Program) -> LogicalPlan {
47 let ops = &program.ops;
48 if ops.is_empty() { return LogicalPlan::Raw(Arc::new(program.clone())); }
49
50 let mut plan: Option<LogicalPlan> = None;
51 for op in ops.iter() {
52 plan = Some(match (plan.take(), op) {
53 (None, Opcode::PushRoot) => LogicalPlan::Scan,
54 (None, Opcode::RootChain(chain)) => {
55 LogicalPlan::Path(chain.iter().cloned().collect())
56 }
57 (Some(p), Opcode::GetField(k)) => match p {
58 LogicalPlan::Scan => LogicalPlan::Path(vec![k.clone()]),
59 LogicalPlan::Path(mut v) => { v.push(k.clone()); LogicalPlan::Path(v) }
60 other => LogicalPlan::Project {
61 input: Box::new(other),
62 map: Arc::new({
63 let ops_vec = vec![Opcode::PushCurrent, Opcode::GetField(k.clone())];
64 let ics = crate::vm::fresh_ics(ops_vec.len());
65 Program {
66 ops: Arc::from(ops_vec),
67 source: Arc::from(""), id: 0, is_structural: true, ics,
68 }
69 }),
70 },
71 }
72 (Some(p), Opcode::RootChain(chain)) => {
73 let mut v: Vec<Arc<str>> = match p { LogicalPlan::Scan => vec![], _ => return LogicalPlan::Raw(Arc::new(program.clone())) };
74 for k in chain.iter() { v.push(k.clone()); }
75 LogicalPlan::Path(v)
76 }
77 (Some(p), Opcode::CallMethod(c)) => lift_method(p, c),
78 (Some(p), Opcode::FilterMap { pred, map }) => LogicalPlan::Project {
79 input: Box::new(LogicalPlan::Filter {
80 input: Box::new(p),
81 pred: Arc::clone(pred),
82 }),
83 map: Arc::clone(map),
84 },
85 (Some(p), Opcode::MapFilter { map, pred }) => LogicalPlan::Filter {
86 input: Box::new(LogicalPlan::Project {
87 input: Box::new(p),
88 map: Arc::clone(map),
89 }),
90 pred: Arc::clone(pred),
91 },
92 (Some(p), Opcode::FilterCount(pred)) => LogicalPlan::Aggregate {
93 input: Box::new(LogicalPlan::Filter {
94 input: Box::new(p),
95 pred: Arc::clone(pred),
96 }),
97 op: AggOp::Count,
98 arg: None,
99 },
100 (Some(p), Opcode::MapSum(f)) => LogicalPlan::Aggregate {
101 input: Box::new(p),
102 op: AggOp::Sum,
103 arg: Some(Arc::clone(f)),
104 },
105 (Some(p), Opcode::MapAvg(f)) => LogicalPlan::Aggregate {
106 input: Box::new(p),
107 op: AggOp::Avg,
108 arg: Some(Arc::clone(f)),
109 },
110 (Some(p), Opcode::FilterMapSum { pred, map }) => LogicalPlan::Aggregate {
111 input: Box::new(LogicalPlan::Filter {
112 input: Box::new(p),
113 pred: Arc::clone(pred),
114 }),
115 op: AggOp::Sum,
116 arg: Some(Arc::clone(map)),
117 },
118 (Some(p), Opcode::FilterMapAvg { pred, map }) => LogicalPlan::Aggregate {
119 input: Box::new(LogicalPlan::Filter {
120 input: Box::new(p),
121 pred: Arc::clone(pred),
122 }),
123 op: AggOp::Avg,
124 arg: Some(Arc::clone(map)),
125 },
126 (Some(p), Opcode::TopN { n, asc }) => LogicalPlan::Limit {
127 input: Box::new(LogicalPlan::Sort {
128 input: Box::new(p),
129 key: None,
130 desc: !asc,
131 }),
132 n: *n,
133 },
134 _ => return LogicalPlan::Raw(Arc::new(program.clone())),
135 });
136 }
137 plan.unwrap_or(LogicalPlan::Raw(Arc::new(program.clone())))
138 }
139
140 pub fn is_aggregate(&self) -> bool {
142 matches!(self, LogicalPlan::Aggregate { .. })
143 }
144
145 pub fn depth(&self) -> usize {
147 match self {
148 LogicalPlan::Scan | LogicalPlan::Path(_) | LogicalPlan::Raw(_) => 1,
149 LogicalPlan::Filter { input, .. } | LogicalPlan::Project { input, .. }
150 | LogicalPlan::Sort { input, .. } | LogicalPlan::Limit { input, .. }
151 | LogicalPlan::Aggregate { input, .. } => 1 + input.depth(),
152 LogicalPlan::Join { left, right, .. } =>
153 1 + left.depth().max(right.depth()),
154 }
155 }
156}
157
158fn lift_method(input: LogicalPlan, c: &Arc<super::vm::CompiledCall>) -> LogicalPlan {
159 match c.method {
160 BuiltinMethod::Filter if !c.sub_progs.is_empty() => LogicalPlan::Filter {
161 input: Box::new(input),
162 pred: Arc::clone(&c.sub_progs[0]),
163 },
164 BuiltinMethod::Map if !c.sub_progs.is_empty() => LogicalPlan::Project {
165 input: Box::new(input),
166 map: Arc::clone(&c.sub_progs[0]),
167 },
168 BuiltinMethod::Sort => LogicalPlan::Sort {
169 input: Box::new(input),
170 key: c.sub_progs.first().map(Arc::clone),
171 desc: false,
172 },
173 BuiltinMethod::Count | BuiltinMethod::Len =>
174 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Count, arg: None },
175 BuiltinMethod::Sum =>
176 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Sum, arg: c.sub_progs.first().map(Arc::clone) },
177 BuiltinMethod::Avg =>
178 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Avg, arg: c.sub_progs.first().map(Arc::clone) },
179 BuiltinMethod::Min =>
180 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Min, arg: None },
181 BuiltinMethod::Max =>
182 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Max, arg: None },
183 BuiltinMethod::First =>
184 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::First, arg: None },
185 BuiltinMethod::Last =>
186 LogicalPlan::Aggregate { input: Box::new(input), op: AggOp::Last, arg: None },
187 _ => LogicalPlan::Raw(Arc::new({
188 let ops_vec = vec![Opcode::CallMethod(Arc::clone(c))];
189 let ics = crate::vm::fresh_ics(ops_vec.len());
190 Program {
191 ops: Arc::from(ops_vec),
192 source: Arc::from(""), id: 0, is_structural: false, ics,
193 }
194 })),
195 }
196}
197
198pub fn pushdown_filter(plan: LogicalPlan) -> LogicalPlan {
206 match plan {
207 LogicalPlan::Filter { input, pred } => match *input {
208 LogicalPlan::Sort { input: inner, key, desc } => LogicalPlan::Sort {
210 input: Box::new(pushdown_filter(LogicalPlan::Filter { input: inner, pred })),
211 key, desc,
212 },
213 other => LogicalPlan::Filter { input: Box::new(pushdown_filter(other)), pred },
214 },
215 LogicalPlan::Project { input, map } => LogicalPlan::Project {
216 input: Box::new(pushdown_filter(*input)),
217 map,
218 },
219 LogicalPlan::Sort { input, key, desc } => LogicalPlan::Sort {
220 input: Box::new(pushdown_filter(*input)),
221 key, desc,
222 },
223 LogicalPlan::Limit { input, n } => LogicalPlan::Limit {
224 input: Box::new(pushdown_filter(*input)),
225 n,
226 },
227 LogicalPlan::Aggregate { input, op, arg } => LogicalPlan::Aggregate {
228 input: Box::new(pushdown_filter(*input)),
229 op, arg,
230 },
231 other => other,
232 }
233}
234
235pub fn lower(plan: &LogicalPlan) -> Arc<Program> {
241 let mut ops = Vec::new();
242 emit(plan, &mut ops);
243 let ics = crate::vm::fresh_ics(ops.len());
244 Arc::new(Program {
245 ops: ops.into(),
246 source: Arc::from("<lowered>"),
247 id: 0,
248 is_structural: false,
249 ics,
250 })
251}
252
253fn emit(plan: &LogicalPlan, ops: &mut Vec<super::vm::Opcode>) {
254 use super::vm::Opcode;
255 match plan {
256 LogicalPlan::Scan => ops.push(Opcode::PushRoot),
257 LogicalPlan::Path(ks) => {
258 ops.push(Opcode::RootChain(ks.iter().cloned().collect::<Vec<_>>().into()));
259 }
260 LogicalPlan::Filter { input, pred } => {
261 emit(input, ops);
262 ops.push(Opcode::InlineFilter(Arc::clone(pred)));
263 }
264 LogicalPlan::Project { input, map } => {
265 emit(input, ops);
266 ops.push(map_as_call(map));
267 }
268 LogicalPlan::Sort { input, key, desc } => {
269 emit(input, ops);
270 ops.push(sort_as_call(key.as_ref()));
271 if *desc { ops.push(reverse_call()); }
272 }
273 LogicalPlan::Limit { input, n } => {
274 emit(input, ops);
275 ops.push(Opcode::TopN { n: *n, asc: true });
276 }
277 LogicalPlan::Aggregate { input, op, arg } => {
278 emit(input, ops);
279 match op {
280 AggOp::Count => ops.push(noarg_call(super::vm::BuiltinMethod::Count, "count")),
281 AggOp::Sum if arg.is_some() => ops.push(Opcode::MapSum(Arc::clone(arg.as_ref().unwrap()))),
282 AggOp::Avg if arg.is_some() => ops.push(Opcode::MapAvg(Arc::clone(arg.as_ref().unwrap()))),
283 AggOp::Sum => ops.push(noarg_call(super::vm::BuiltinMethod::Sum, "sum")),
284 AggOp::Avg => ops.push(noarg_call(super::vm::BuiltinMethod::Avg, "avg")),
285 AggOp::Min => ops.push(noarg_call(super::vm::BuiltinMethod::Min, "min")),
286 AggOp::Max => ops.push(noarg_call(super::vm::BuiltinMethod::Max, "max")),
287 AggOp::First => ops.push(noarg_call(super::vm::BuiltinMethod::First, "first")),
288 AggOp::Last => ops.push(noarg_call(super::vm::BuiltinMethod::Last, "last")),
289 }
290 }
291 LogicalPlan::Join { left, right: _, on: _ } => {
292 emit(left, ops);
294 }
295 LogicalPlan::Raw(p) => {
296 for op in p.ops.iter() { ops.push(op.clone()); }
297 }
298 }
299}
300
301fn noarg_call(method: super::vm::BuiltinMethod, name: &str) -> super::vm::Opcode {
302 use super::vm::{Opcode, CompiledCall};
303 Opcode::CallMethod(Arc::new(CompiledCall {
304 method, name: Arc::from(name),
305 sub_progs: Arc::from(Vec::new()),
306 orig_args: Arc::from(Vec::new()),
307 }))
308}
309
310fn reverse_call() -> super::vm::Opcode {
311 noarg_call(super::vm::BuiltinMethod::Reverse, "reverse")
312}
313
314fn map_as_call(map: &Arc<Program>) -> super::vm::Opcode {
315 use super::vm::{Opcode, CompiledCall, BuiltinMethod};
316 Opcode::CallMethod(Arc::new(CompiledCall {
317 method: BuiltinMethod::Map,
318 name: Arc::from("map"),
319 sub_progs: Arc::from(vec![Arc::clone(map)]),
320 orig_args: Arc::from(Vec::new()),
321 }))
322}
323
324fn sort_as_call(key: Option<&Arc<Program>>) -> super::vm::Opcode {
325 use super::vm::{Opcode, CompiledCall, BuiltinMethod};
326 let sub_progs: Vec<Arc<Program>> = key.map(|k| vec![Arc::clone(k)]).unwrap_or_default();
327 Opcode::CallMethod(Arc::new(CompiledCall {
328 method: BuiltinMethod::Sort,
329 name: Arc::from("sort"),
330 sub_progs: sub_progs.into(),
331 orig_args: Arc::from(Vec::new()),
332 }))
333}
334
335pub fn detect_join_candidates(plan: &LogicalPlan) -> Vec<JoinCandidate> {
342 let mut out = Vec::new();
343 walk(plan, &mut out);
344 out
345}
346
/// A pair of operand chains compared for equality inside a filter predicate.
#[derive(Debug, Clone)]
pub struct JoinCandidate {
    /// Textual form of the first operand: identifier plus `.field` segments.
    pub left: String,
    /// Textual form of the second operand, in the same notation.
    pub right: String,
}
353
354fn walk(plan: &LogicalPlan, out: &mut Vec<JoinCandidate>) {
355 match plan {
356 LogicalPlan::Filter { input, pred } => {
357 if let Some(j) = detect_eq_join(pred) { out.push(j); }
358 walk(input, out);
359 }
360 LogicalPlan::Project { input, .. }
361 | LogicalPlan::Sort { input, .. }
362 | LogicalPlan::Limit { input, .. }
363 | LogicalPlan::Aggregate { input, .. } => walk(input, out),
364 LogicalPlan::Join { left, right, .. } => { walk(left, out); walk(right, out); }
365 _ => {}
366 }
367}
368
369fn detect_eq_join(pred: &Arc<Program>) -> Option<JoinCandidate> {
370 use crate::vm::Opcode;
371 let ops = &pred.ops;
373 if ops.len() < 3 { return None; }
374 let last = ops.last()?;
375 if !matches!(last, Opcode::Eq) { return None; }
376 let ident_positions: Vec<usize> = ops.iter().enumerate()
379 .filter(|(_, o)| matches!(o, Opcode::LoadIdent(_)))
380 .map(|(i, _)| i).collect();
381 if ident_positions.len() != 2 { return None; }
382 let a = describe_chain(&ops[ident_positions[0]..ident_positions[1]]);
383 let b = describe_chain(&ops[ident_positions[1]..ops.len()-1]);
384 if a == b { return None; }
385 Some(JoinCandidate { left: a, right: b })
386}
387
388fn describe_chain(ops: &[super::vm::Opcode]) -> String {
389 use super::vm::Opcode;
390 let mut s = String::new();
391 for op in ops {
392 match op {
393 Opcode::LoadIdent(n) => s.push_str(n),
394 Opcode::GetField(k) => { s.push('.'); s.push_str(k); }
395 _ => {}
396 }
397 }
398 s
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Compiler;

    /// A plain field chain off the root lifts to a `Path` with one key per step.
    #[test]
    fn lift_path() {
        let program = Compiler::compile_str("$.store.books").unwrap();
        match LogicalPlan::lift(&program) {
            LogicalPlan::Path(keys) => {
                let keys: Vec<&str> = keys.iter().map(|k| &**k).collect();
                assert_eq!(keys, ["store", "books"]);
            }
            other => panic!("expected Path, got {:?}", other),
        }
    }

    /// A filter followed by a map lifts to a plan at least two nodes deep.
    #[test]
    fn lift_filter_map() {
        let program =
            Compiler::compile_str("$.books.filter(@.price > 10).map(@.title)").unwrap();
        let lifted = LogicalPlan::lift(&program);
        assert!(lifted.depth() >= 2);
    }

    /// A trailing `count()` lifts to an aggregate at the top of the plan.
    #[test]
    fn lift_aggregate() {
        let program = Compiler::compile_str("$.books.count()").unwrap();
        let lifted = LogicalPlan::lift(&program);
        assert!(lifted.is_aggregate());
    }

    /// Lifting and then lowering a program must not change what it evaluates to.
    #[test]
    fn roundtrip_lower_preserves_semantics() {
        use crate::vm::VM;
        use serde_json::json;

        let doc = json!({"store": {"books": [{"price": 20}, {"price": 5}]}});
        let program =
            Compiler::compile_str("$.store.books.filter(@.price > 10).count()").unwrap();
        let relowered = lower(&LogicalPlan::lift(&program));

        let mut vm = VM::new();
        let expected = vm.execute(&program, &doc).unwrap();
        let actual = vm.execute(&relowered, &doc).unwrap();
        assert_eq!(expected, actual);
    }

    /// An equality between two different identifier chains inside a filter is
    /// reported as a join candidate.
    #[test]
    fn detect_join_candidates_finds_equi_join() {
        let program = Compiler::compile_str("$.x.filter(a.id == b.id)").unwrap();
        let lifted = LogicalPlan::lift(&program);
        let found = detect_join_candidates(&lifted);
        assert!(!found.is_empty(), "should detect a.id == b.id as join");
    }
}