Skip to main content

lora_compiler/
optimizer.rs

1use crate::logical::*;
2use crate::physical::*;
3use lora_analyzer::{symbols::VarId, ResolvedExpr};
4use lora_ast::BinaryOp;
5use std::collections::BTreeSet;
6
/// Rule-based rewriter for logical plans, which also lowers a logical plan to a
/// physical one (`lower_to_physical`).
///
/// The optimizer carries no state of its own; every rule operates only on the
/// plan passed in.
#[derive(Debug, Clone, Copy)]
pub struct Optimizer;
8
9impl Default for Optimizer {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15impl Optimizer {
16    pub fn new() -> Self {
17        Self
18    }
19
20    pub fn optimize(&mut self, mut plan: LogicalPlan) -> LogicalPlan {
21        self.push_filter_below_projection(&mut plan);
22        self.use_property_indexed_node_scans(&mut plan);
23        self.remove_redundant_limit(&mut plan);
24        plan
25    }
26
27    fn push_filter_below_projection(&self, plan: &mut LogicalPlan) {
28        let len = plan.nodes.len();
29
30        for i in 0..len {
31            // Inspect by reference first so we can decide without cloning the
32            // potentially-large op payloads.
33            let input_id = match &plan.nodes[i] {
34                LogicalOp::Filter(f) => f.input,
35                _ => continue,
36            };
37
38            let should_push = match (&plan.nodes[i], &plan.nodes[input_id]) {
39                (LogicalOp::Filter(filter), LogicalOp::Projection(proj)) => {
40                    if proj.distinct || proj.include_existing {
41                        false
42                    } else {
43                        let output_vars: BTreeSet<VarId> =
44                            proj.items.iter().map(|item| item.output).collect();
45                        let pred_vars = collect_vars(&filter.predicate);
46                        !pred_vars.iter().any(|v| output_vars.contains(v))
47                    }
48                }
49                _ => false,
50            };
51
52            if !should_push {
53                continue;
54            }
55
56            // Move both nodes out by swap, then rebuild without cloning.
57            let placeholder = || LogicalOp::Argument(Argument);
58            let filter = match std::mem::replace(&mut plan.nodes[i], placeholder()) {
59                LogicalOp::Filter(f) => f,
60                _ => unreachable!(),
61            };
62            let proj = match std::mem::replace(&mut plan.nodes[input_id], placeholder()) {
63                LogicalOp::Projection(p) => p,
64                _ => unreachable!(),
65            };
66
67            plan.nodes[input_id] = LogicalOp::Filter(Filter {
68                input: proj.input,
69                predicate: filter.predicate,
70            });
71            plan.nodes[i] = LogicalOp::Projection(Projection {
72                input: input_id,
73                distinct: proj.distinct,
74                items: proj.items,
75                include_existing: proj.include_existing,
76            });
77        }
78    }
79
80    fn remove_redundant_limit(&self, _plan: &mut LogicalPlan) {
81        // placeholder for future rules
82    }
83
84    fn use_property_indexed_node_scans(&self, plan: &mut LogicalPlan) {
85        let len = plan.nodes.len();
86
87        for i in 0..len {
88            let (input_id, predicate) = match &plan.nodes[i] {
89                LogicalOp::Filter(f) => (f.input, &f.predicate),
90                _ => continue,
91            };
92
93            let (var, key, value) =
94                match property_equality_candidate(predicate, &plan.nodes[input_id]) {
95                    Some(candidate) => candidate,
96                    None => continue,
97                };
98
99            let replacement = match &plan.nodes[input_id] {
100                LogicalOp::NodeScan(scan) => {
101                    Some(LogicalOp::NodeByPropertyScan(NodeByPropertyScan {
102                        input: scan.input,
103                        var,
104                        labels: scan.labels.clone(),
105                        key,
106                        value,
107                    }))
108                }
109                LogicalOp::NodeByPropertyScan(_) => None,
110                _ => None,
111            };
112
113            if let Some(replacement) = replacement {
114                plan.nodes[input_id] = replacement;
115            }
116        }
117    }
118
119    /// Lower a logical plan by consuming it — each op's owned payload
120    /// (expressions, patterns, items) is moved into the physical op rather
121    /// than cloned. Callers should not need the logical plan after this.
122    pub fn lower_to_physical(&mut self, logical: LogicalPlan) -> PhysicalPlan {
123        let LogicalPlan { root, nodes } = logical;
124
125        let nodes = nodes
126            .into_iter()
127            .map(|op| match op {
128                LogicalOp::Argument(_) => PhysicalOp::Argument(ArgumentExec),
129
130                LogicalOp::NodeScan(scan) => {
131                    if scan.labels.is_empty() {
132                        PhysicalOp::NodeScan(NodeScanExec {
133                            input: scan.input,
134                            var: scan.var,
135                        })
136                    } else {
137                        PhysicalOp::NodeByLabelScan(NodeByLabelScanExec {
138                            input: scan.input,
139                            var: scan.var,
140                            labels: scan.labels,
141                        })
142                    }
143                }
144
145                LogicalOp::NodeByPropertyScan(scan) => {
146                    PhysicalOp::NodeByPropertyScan(NodeByPropertyScanExec {
147                        input: scan.input,
148                        var: scan.var,
149                        labels: scan.labels,
150                        key: scan.key,
151                        value: scan.value,
152                    })
153                }
154
155                LogicalOp::Expand(expand) => PhysicalOp::Expand(ExpandExec {
156                    input: expand.input,
157                    src: expand.src,
158                    rel: expand.rel,
159                    dst: expand.dst,
160                    types: expand.types,
161                    direction: expand.direction,
162                    rel_properties: expand.rel_properties,
163                    range: expand.range,
164                }),
165
166                LogicalOp::Filter(filter) => PhysicalOp::Filter(FilterExec {
167                    input: filter.input,
168                    predicate: filter.predicate,
169                }),
170
171                LogicalOp::Projection(proj) => PhysicalOp::Projection(ProjectionExec {
172                    input: proj.input,
173                    distinct: proj.distinct,
174                    items: proj.items,
175                    include_existing: proj.include_existing,
176                }),
177
178                LogicalOp::Unwind(unwind) => PhysicalOp::Unwind(UnwindExec {
179                    input: unwind.input,
180                    expr: unwind.expr,
181                    alias: unwind.alias,
182                }),
183
184                LogicalOp::Aggregation(agg) => PhysicalOp::HashAggregation(HashAggregationExec {
185                    input: agg.input,
186                    group_by: agg.group_by,
187                    aggregates: agg.aggregates,
188                }),
189
190                LogicalOp::Sort(sort) => PhysicalOp::Sort(SortExec {
191                    input: sort.input,
192                    items: sort.items,
193                }),
194
195                LogicalOp::Limit(limit) => PhysicalOp::Limit(LimitExec {
196                    input: limit.input,
197                    skip: limit.skip,
198                    limit: limit.limit,
199                }),
200
201                LogicalOp::Create(create) => PhysicalOp::Create(CreateExec {
202                    input: create.input,
203                    pattern: create.pattern,
204                }),
205
206                LogicalOp::Merge(merge) => PhysicalOp::Merge(MergeExec {
207                    input: merge.input,
208                    pattern_part: merge.pattern_part,
209                    actions: merge.actions,
210                }),
211
212                LogicalOp::Delete(delete) => PhysicalOp::Delete(DeleteExec {
213                    input: delete.input,
214                    detach: delete.detach,
215                    expressions: delete.expressions,
216                }),
217
218                LogicalOp::Set(set) => PhysicalOp::Set(SetExec {
219                    input: set.input,
220                    items: set.items,
221                }),
222
223                LogicalOp::Remove(remove) => PhysicalOp::Remove(RemoveExec {
224                    input: remove.input,
225                    items: remove.items,
226                }),
227
228                LogicalOp::OptionalMatch(om) => PhysicalOp::OptionalMatch(OptionalMatchExec {
229                    input: om.input,
230                    inner: om.inner,
231                    new_vars: om.new_vars,
232                }),
233
234                LogicalOp::PathBuild(pb) => PhysicalOp::PathBuild(PathBuildExec {
235                    input: pb.input,
236                    output: pb.output,
237                    node_vars: pb.node_vars,
238                    rel_vars: pb.rel_vars,
239                    shortest_path_all: pb.shortest_path_all,
240                }),
241            })
242            .collect();
243
244        PhysicalPlan { root, nodes }
245    }
246}
247
248fn collect_vars(expr: &ResolvedExpr) -> BTreeSet<VarId> {
249    let mut vars = BTreeSet::new();
250    collect_vars_inner(expr, &mut vars);
251    vars
252}
253
254fn property_equality_candidate(
255    predicate: &ResolvedExpr,
256    input: &LogicalOp,
257) -> Option<(VarId, String, ResolvedExpr)> {
258    let LogicalOp::NodeScan(scan) = input else {
259        return None;
260    };
261
262    property_equality_for_var(predicate, scan.var)
263}
264
265fn property_equality_for_var(
266    predicate: &ResolvedExpr,
267    var: VarId,
268) -> Option<(VarId, String, ResolvedExpr)> {
269    let ResolvedExpr::Binary { lhs, op, rhs } = predicate else {
270        return None;
271    };
272
273    if matches!(op, BinaryOp::And) {
274        return property_equality_for_var(lhs, var).or_else(|| property_equality_for_var(rhs, var));
275    }
276
277    if !matches!(op, BinaryOp::Eq) {
278        return None;
279    }
280
281    property_access_for_var(lhs, var)
282        .filter(|_| !collect_vars(rhs).contains(&var))
283        .map(|key| (var, key, (**rhs).clone()))
284        .or_else(|| {
285            property_access_for_var(rhs, var)
286                .filter(|_| !collect_vars(lhs).contains(&var))
287                .map(|key| (var, key, (**lhs).clone()))
288        })
289}
290
291fn property_access_for_var(expr: &ResolvedExpr, var: VarId) -> Option<String> {
292    match expr {
293        ResolvedExpr::Property { expr, property } => match &**expr {
294            ResolvedExpr::Variable(v) if *v == var => Some(property.clone()),
295            _ => None,
296        },
297        _ => None,
298    }
299}
300
301fn collect_vars_inner(expr: &ResolvedExpr, out: &mut BTreeSet<VarId>) {
302    match expr {
303        ResolvedExpr::Variable(v) => {
304            out.insert(*v);
305        }
306        ResolvedExpr::Property { expr, .. } => collect_vars_inner(expr, out),
307        ResolvedExpr::Binary { lhs, rhs, .. } => {
308            collect_vars_inner(lhs, out);
309            collect_vars_inner(rhs, out);
310        }
311        ResolvedExpr::Unary { expr, .. } => collect_vars_inner(expr, out),
312        ResolvedExpr::Function { args, .. } => {
313            for arg in args {
314                collect_vars_inner(arg, out);
315            }
316        }
317        ResolvedExpr::List(items) => {
318            for item in items {
319                collect_vars_inner(item, out);
320            }
321        }
322        ResolvedExpr::Map(items) => {
323            for (_, v) in items {
324                collect_vars_inner(v, out);
325            }
326        }
327        ResolvedExpr::Case {
328            input,
329            alternatives,
330            else_expr,
331        } => {
332            if let Some(e) = input {
333                collect_vars_inner(e, out);
334            }
335            for (w, t) in alternatives {
336                collect_vars_inner(w, out);
337                collect_vars_inner(t, out);
338            }
339            if let Some(e) = else_expr {
340                collect_vars_inner(e, out);
341            }
342        }
343        ResolvedExpr::ListPredicate {
344            variable,
345            list,
346            predicate,
347            ..
348        } => {
349            out.insert(*variable);
350            collect_vars_inner(list, out);
351            collect_vars_inner(predicate, out);
352        }
353        ResolvedExpr::ListComprehension {
354            variable,
355            list,
356            filter,
357            map_expr,
358            ..
359        } => {
360            out.insert(*variable);
361            collect_vars_inner(list, out);
362            if let Some(f) = filter {
363                collect_vars_inner(f, out);
364            }
365            if let Some(m) = map_expr {
366                collect_vars_inner(m, out);
367            }
368        }
369        ResolvedExpr::Reduce {
370            accumulator,
371            init,
372            variable,
373            list,
374            expr,
375            ..
376        } => {
377            out.insert(*accumulator);
378            out.insert(*variable);
379            collect_vars_inner(init, out);
380            collect_vars_inner(list, out);
381            collect_vars_inner(expr, out);
382        }
383        ResolvedExpr::Index { expr, index } => {
384            collect_vars_inner(expr, out);
385            collect_vars_inner(index, out);
386        }
387        ResolvedExpr::Slice { expr, from, to } => {
388            collect_vars_inner(expr, out);
389            if let Some(f) = from {
390                collect_vars_inner(f, out);
391            }
392            if let Some(t) = to {
393                collect_vars_inner(t, out);
394            }
395        }
396        ResolvedExpr::MapProjection { base, selectors } => {
397            collect_vars_inner(base, out);
398            for sel in selectors {
399                if let lora_analyzer::ResolvedMapSelector::Literal(_, e) = sel {
400                    collect_vars_inner(e, out);
401                }
402            }
403        }
404        _ => {}
405    }
406}