1use crate::logical::*;
2use crate::physical::*;
3use lora_analyzer::{symbols::VarId, ResolvedExpr};
4use lora_ast::BinaryOp;
5use std::collections::BTreeSet;
6
/// Rule-based optimizer for logical plans and lowering to physical plans.
///
/// The type carries no state; every instance behaves identically.
#[derive(Debug, Clone, Copy)]
pub struct Optimizer;
8
9impl Default for Optimizer {
10 fn default() -> Self {
11 Self::new()
12 }
13}
14
15impl Optimizer {
16 pub fn new() -> Self {
17 Self
18 }
19
20 pub fn optimize(&mut self, mut plan: LogicalPlan) -> LogicalPlan {
21 self.push_filter_below_projection(&mut plan);
22 self.use_property_indexed_node_scans(&mut plan);
23 self.remove_redundant_limit(&mut plan);
24 plan
25 }
26
27 fn push_filter_below_projection(&self, plan: &mut LogicalPlan) {
28 let len = plan.nodes.len();
29
30 for i in 0..len {
31 let input_id = match &plan.nodes[i] {
34 LogicalOp::Filter(f) => f.input,
35 _ => continue,
36 };
37
38 let should_push = match (&plan.nodes[i], &plan.nodes[input_id]) {
39 (LogicalOp::Filter(filter), LogicalOp::Projection(proj)) => {
40 if proj.distinct || proj.include_existing {
41 false
42 } else {
43 let output_vars: BTreeSet<VarId> =
44 proj.items.iter().map(|item| item.output).collect();
45 let pred_vars = collect_vars(&filter.predicate);
46 !pred_vars.iter().any(|v| output_vars.contains(v))
47 }
48 }
49 _ => false,
50 };
51
52 if !should_push {
53 continue;
54 }
55
56 let placeholder = || LogicalOp::Argument(Argument);
58 let filter = match std::mem::replace(&mut plan.nodes[i], placeholder()) {
59 LogicalOp::Filter(f) => f,
60 _ => unreachable!(),
61 };
62 let proj = match std::mem::replace(&mut plan.nodes[input_id], placeholder()) {
63 LogicalOp::Projection(p) => p,
64 _ => unreachable!(),
65 };
66
67 plan.nodes[input_id] = LogicalOp::Filter(Filter {
68 input: proj.input,
69 predicate: filter.predicate,
70 });
71 plan.nodes[i] = LogicalOp::Projection(Projection {
72 input: input_id,
73 distinct: proj.distinct,
74 items: proj.items,
75 include_existing: proj.include_existing,
76 });
77 }
78 }
79
80 fn remove_redundant_limit(&self, _plan: &mut LogicalPlan) {
81 }
83
84 fn use_property_indexed_node_scans(&self, plan: &mut LogicalPlan) {
85 let len = plan.nodes.len();
86
87 for i in 0..len {
88 let (input_id, predicate) = match &plan.nodes[i] {
89 LogicalOp::Filter(f) => (f.input, &f.predicate),
90 _ => continue,
91 };
92
93 let (var, key, value) =
94 match property_equality_candidate(predicate, &plan.nodes[input_id]) {
95 Some(candidate) => candidate,
96 None => continue,
97 };
98
99 let replacement = match &plan.nodes[input_id] {
100 LogicalOp::NodeScan(scan) => {
101 Some(LogicalOp::NodeByPropertyScan(NodeByPropertyScan {
102 input: scan.input,
103 var,
104 labels: scan.labels.clone(),
105 key,
106 value,
107 }))
108 }
109 LogicalOp::NodeByPropertyScan(_) => None,
110 _ => None,
111 };
112
113 if let Some(replacement) = replacement {
114 plan.nodes[input_id] = replacement;
115 }
116 }
117 }
118
119 pub fn lower_to_physical(&mut self, logical: LogicalPlan) -> PhysicalPlan {
123 let LogicalPlan { root, nodes } = logical;
124
125 let nodes = nodes
126 .into_iter()
127 .map(|op| match op {
128 LogicalOp::Argument(_) => PhysicalOp::Argument(ArgumentExec),
129
130 LogicalOp::NodeScan(scan) => {
131 if scan.labels.is_empty() {
132 PhysicalOp::NodeScan(NodeScanExec {
133 input: scan.input,
134 var: scan.var,
135 })
136 } else {
137 PhysicalOp::NodeByLabelScan(NodeByLabelScanExec {
138 input: scan.input,
139 var: scan.var,
140 labels: scan.labels,
141 })
142 }
143 }
144
145 LogicalOp::NodeByPropertyScan(scan) => {
146 PhysicalOp::NodeByPropertyScan(NodeByPropertyScanExec {
147 input: scan.input,
148 var: scan.var,
149 labels: scan.labels,
150 key: scan.key,
151 value: scan.value,
152 })
153 }
154
155 LogicalOp::Expand(expand) => PhysicalOp::Expand(ExpandExec {
156 input: expand.input,
157 src: expand.src,
158 rel: expand.rel,
159 dst: expand.dst,
160 types: expand.types,
161 direction: expand.direction,
162 rel_properties: expand.rel_properties,
163 range: expand.range,
164 }),
165
166 LogicalOp::Filter(filter) => PhysicalOp::Filter(FilterExec {
167 input: filter.input,
168 predicate: filter.predicate,
169 }),
170
171 LogicalOp::Projection(proj) => PhysicalOp::Projection(ProjectionExec {
172 input: proj.input,
173 distinct: proj.distinct,
174 items: proj.items,
175 include_existing: proj.include_existing,
176 }),
177
178 LogicalOp::Unwind(unwind) => PhysicalOp::Unwind(UnwindExec {
179 input: unwind.input,
180 expr: unwind.expr,
181 alias: unwind.alias,
182 }),
183
184 LogicalOp::Aggregation(agg) => PhysicalOp::HashAggregation(HashAggregationExec {
185 input: agg.input,
186 group_by: agg.group_by,
187 aggregates: agg.aggregates,
188 }),
189
190 LogicalOp::Sort(sort) => PhysicalOp::Sort(SortExec {
191 input: sort.input,
192 items: sort.items,
193 }),
194
195 LogicalOp::Limit(limit) => PhysicalOp::Limit(LimitExec {
196 input: limit.input,
197 skip: limit.skip,
198 limit: limit.limit,
199 }),
200
201 LogicalOp::Create(create) => PhysicalOp::Create(CreateExec {
202 input: create.input,
203 pattern: create.pattern,
204 }),
205
206 LogicalOp::Merge(merge) => PhysicalOp::Merge(MergeExec {
207 input: merge.input,
208 pattern_part: merge.pattern_part,
209 actions: merge.actions,
210 }),
211
212 LogicalOp::Delete(delete) => PhysicalOp::Delete(DeleteExec {
213 input: delete.input,
214 detach: delete.detach,
215 expressions: delete.expressions,
216 }),
217
218 LogicalOp::Set(set) => PhysicalOp::Set(SetExec {
219 input: set.input,
220 items: set.items,
221 }),
222
223 LogicalOp::Remove(remove) => PhysicalOp::Remove(RemoveExec {
224 input: remove.input,
225 items: remove.items,
226 }),
227
228 LogicalOp::OptionalMatch(om) => PhysicalOp::OptionalMatch(OptionalMatchExec {
229 input: om.input,
230 inner: om.inner,
231 new_vars: om.new_vars,
232 }),
233
234 LogicalOp::PathBuild(pb) => PhysicalOp::PathBuild(PathBuildExec {
235 input: pb.input,
236 output: pb.output,
237 node_vars: pb.node_vars,
238 rel_vars: pb.rel_vars,
239 shortest_path_all: pb.shortest_path_all,
240 }),
241 })
242 .collect();
243
244 PhysicalPlan { root, nodes }
245 }
246}
247
248fn collect_vars(expr: &ResolvedExpr) -> BTreeSet<VarId> {
249 let mut vars = BTreeSet::new();
250 collect_vars_inner(expr, &mut vars);
251 vars
252}
253
254fn property_equality_candidate(
255 predicate: &ResolvedExpr,
256 input: &LogicalOp,
257) -> Option<(VarId, String, ResolvedExpr)> {
258 let LogicalOp::NodeScan(scan) = input else {
259 return None;
260 };
261
262 property_equality_for_var(predicate, scan.var)
263}
264
265fn property_equality_for_var(
266 predicate: &ResolvedExpr,
267 var: VarId,
268) -> Option<(VarId, String, ResolvedExpr)> {
269 let ResolvedExpr::Binary { lhs, op, rhs } = predicate else {
270 return None;
271 };
272
273 if matches!(op, BinaryOp::And) {
274 return property_equality_for_var(lhs, var).or_else(|| property_equality_for_var(rhs, var));
275 }
276
277 if !matches!(op, BinaryOp::Eq) {
278 return None;
279 }
280
281 property_access_for_var(lhs, var)
282 .filter(|_| !collect_vars(rhs).contains(&var))
283 .map(|key| (var, key, (**rhs).clone()))
284 .or_else(|| {
285 property_access_for_var(rhs, var)
286 .filter(|_| !collect_vars(lhs).contains(&var))
287 .map(|key| (var, key, (**lhs).clone()))
288 })
289}
290
291fn property_access_for_var(expr: &ResolvedExpr, var: VarId) -> Option<String> {
292 match expr {
293 ResolvedExpr::Property { expr, property } => match &**expr {
294 ResolvedExpr::Variable(v) if *v == var => Some(property.clone()),
295 _ => None,
296 },
297 _ => None,
298 }
299}
300
301fn collect_vars_inner(expr: &ResolvedExpr, out: &mut BTreeSet<VarId>) {
302 match expr {
303 ResolvedExpr::Variable(v) => {
304 out.insert(*v);
305 }
306 ResolvedExpr::Property { expr, .. } => collect_vars_inner(expr, out),
307 ResolvedExpr::Binary { lhs, rhs, .. } => {
308 collect_vars_inner(lhs, out);
309 collect_vars_inner(rhs, out);
310 }
311 ResolvedExpr::Unary { expr, .. } => collect_vars_inner(expr, out),
312 ResolvedExpr::Function { args, .. } => {
313 for arg in args {
314 collect_vars_inner(arg, out);
315 }
316 }
317 ResolvedExpr::List(items) => {
318 for item in items {
319 collect_vars_inner(item, out);
320 }
321 }
322 ResolvedExpr::Map(items) => {
323 for (_, v) in items {
324 collect_vars_inner(v, out);
325 }
326 }
327 ResolvedExpr::Case {
328 input,
329 alternatives,
330 else_expr,
331 } => {
332 if let Some(e) = input {
333 collect_vars_inner(e, out);
334 }
335 for (w, t) in alternatives {
336 collect_vars_inner(w, out);
337 collect_vars_inner(t, out);
338 }
339 if let Some(e) = else_expr {
340 collect_vars_inner(e, out);
341 }
342 }
343 ResolvedExpr::ListPredicate {
344 variable,
345 list,
346 predicate,
347 ..
348 } => {
349 out.insert(*variable);
350 collect_vars_inner(list, out);
351 collect_vars_inner(predicate, out);
352 }
353 ResolvedExpr::ListComprehension {
354 variable,
355 list,
356 filter,
357 map_expr,
358 ..
359 } => {
360 out.insert(*variable);
361 collect_vars_inner(list, out);
362 if let Some(f) = filter {
363 collect_vars_inner(f, out);
364 }
365 if let Some(m) = map_expr {
366 collect_vars_inner(m, out);
367 }
368 }
369 ResolvedExpr::Reduce {
370 accumulator,
371 init,
372 variable,
373 list,
374 expr,
375 ..
376 } => {
377 out.insert(*accumulator);
378 out.insert(*variable);
379 collect_vars_inner(init, out);
380 collect_vars_inner(list, out);
381 collect_vars_inner(expr, out);
382 }
383 ResolvedExpr::Index { expr, index } => {
384 collect_vars_inner(expr, out);
385 collect_vars_inner(index, out);
386 }
387 ResolvedExpr::Slice { expr, from, to } => {
388 collect_vars_inner(expr, out);
389 if let Some(f) = from {
390 collect_vars_inner(f, out);
391 }
392 if let Some(t) = to {
393 collect_vars_inner(t, out);
394 }
395 }
396 ResolvedExpr::MapProjection { base, selectors } => {
397 collect_vars_inner(base, out);
398 for sel in selectors {
399 if let lora_analyzer::ResolvedMapSelector::Literal(_, e) = sel {
400 collect_vars_inner(e, out);
401 }
402 }
403 }
404 _ => {}
405 }
406}