Skip to main content

lora_compiler/
optimizer.rs

1use crate::logical::*;
2use crate::physical::*;
3use lora_analyzer::{symbols::VarId, ResolvedExpr};
4use std::collections::BTreeSet;
5
/// Rule-based optimizer for logical plans, which also lowers them to physical plans.
///
/// The type carries no state. It derives `Debug`, `Clone` and `Copy` so it can be
/// logged and passed around freely. `Default` is implemented by hand in terms of `new`.
#[derive(Debug, Clone, Copy)]
pub struct Optimizer;
7
8impl Default for Optimizer {
9    fn default() -> Self {
10        Self::new()
11    }
12}
13
14impl Optimizer {
15    pub fn new() -> Self {
16        Self
17    }
18
19    pub fn optimize(&mut self, mut plan: LogicalPlan) -> LogicalPlan {
20        self.push_filter_below_projection(&mut plan);
21        self.remove_redundant_limit(&mut plan);
22        plan
23    }
24
25    fn push_filter_below_projection(&self, plan: &mut LogicalPlan) {
26        let len = plan.nodes.len();
27
28        for i in 0..len {
29            // Inspect by reference first so we can decide without cloning the
30            // potentially-large op payloads.
31            let input_id = match &plan.nodes[i] {
32                LogicalOp::Filter(f) => f.input,
33                _ => continue,
34            };
35
36            let should_push = match (&plan.nodes[i], &plan.nodes[input_id]) {
37                (LogicalOp::Filter(filter), LogicalOp::Projection(proj)) => {
38                    if proj.distinct || proj.include_existing {
39                        false
40                    } else {
41                        let output_vars: BTreeSet<VarId> =
42                            proj.items.iter().map(|item| item.output).collect();
43                        let pred_vars = collect_vars(&filter.predicate);
44                        !pred_vars.iter().any(|v| output_vars.contains(v))
45                    }
46                }
47                _ => false,
48            };
49
50            if !should_push {
51                continue;
52            }
53
54            // Move both nodes out by swap, then rebuild without cloning.
55            let placeholder = || LogicalOp::Argument(Argument);
56            let filter = match std::mem::replace(&mut plan.nodes[i], placeholder()) {
57                LogicalOp::Filter(f) => f,
58                _ => unreachable!(),
59            };
60            let proj = match std::mem::replace(&mut plan.nodes[input_id], placeholder()) {
61                LogicalOp::Projection(p) => p,
62                _ => unreachable!(),
63            };
64
65            plan.nodes[input_id] = LogicalOp::Filter(Filter {
66                input: proj.input,
67                predicate: filter.predicate,
68            });
69            plan.nodes[i] = LogicalOp::Projection(Projection {
70                input: input_id,
71                distinct: proj.distinct,
72                items: proj.items,
73                include_existing: proj.include_existing,
74            });
75        }
76    }
77
78    fn remove_redundant_limit(&self, _plan: &mut LogicalPlan) {
79        // placeholder for future rules
80    }
81
82    /// Lower a logical plan by consuming it — each op's owned payload
83    /// (expressions, patterns, items) is moved into the physical op rather
84    /// than cloned. Callers should not need the logical plan after this.
85    pub fn lower_to_physical(&mut self, logical: LogicalPlan) -> PhysicalPlan {
86        let LogicalPlan { root, nodes } = logical;
87
88        let nodes = nodes
89            .into_iter()
90            .map(|op| match op {
91                LogicalOp::Argument(_) => PhysicalOp::Argument(ArgumentExec),
92
93                LogicalOp::NodeScan(scan) => {
94                    if scan.labels.is_empty() {
95                        PhysicalOp::NodeScan(NodeScanExec {
96                            input: scan.input,
97                            var: scan.var,
98                        })
99                    } else {
100                        PhysicalOp::NodeByLabelScan(NodeByLabelScanExec {
101                            input: scan.input,
102                            var: scan.var,
103                            labels: scan.labels,
104                        })
105                    }
106                }
107
108                LogicalOp::Expand(expand) => PhysicalOp::Expand(ExpandExec {
109                    input: expand.input,
110                    src: expand.src,
111                    rel: expand.rel,
112                    dst: expand.dst,
113                    types: expand.types,
114                    direction: expand.direction,
115                    rel_properties: expand.rel_properties,
116                    range: expand.range,
117                }),
118
119                LogicalOp::Filter(filter) => PhysicalOp::Filter(FilterExec {
120                    input: filter.input,
121                    predicate: filter.predicate,
122                }),
123
124                LogicalOp::Projection(proj) => PhysicalOp::Projection(ProjectionExec {
125                    input: proj.input,
126                    distinct: proj.distinct,
127                    items: proj.items,
128                    include_existing: proj.include_existing,
129                }),
130
131                LogicalOp::Unwind(unwind) => PhysicalOp::Unwind(UnwindExec {
132                    input: unwind.input,
133                    expr: unwind.expr,
134                    alias: unwind.alias,
135                }),
136
137                LogicalOp::Aggregation(agg) => PhysicalOp::HashAggregation(HashAggregationExec {
138                    input: agg.input,
139                    group_by: agg.group_by,
140                    aggregates: agg.aggregates,
141                }),
142
143                LogicalOp::Sort(sort) => PhysicalOp::Sort(SortExec {
144                    input: sort.input,
145                    items: sort.items,
146                }),
147
148                LogicalOp::Limit(limit) => PhysicalOp::Limit(LimitExec {
149                    input: limit.input,
150                    skip: limit.skip,
151                    limit: limit.limit,
152                }),
153
154                LogicalOp::Create(create) => PhysicalOp::Create(CreateExec {
155                    input: create.input,
156                    pattern: create.pattern,
157                }),
158
159                LogicalOp::Merge(merge) => PhysicalOp::Merge(MergeExec {
160                    input: merge.input,
161                    pattern_part: merge.pattern_part,
162                    actions: merge.actions,
163                }),
164
165                LogicalOp::Delete(delete) => PhysicalOp::Delete(DeleteExec {
166                    input: delete.input,
167                    detach: delete.detach,
168                    expressions: delete.expressions,
169                }),
170
171                LogicalOp::Set(set) => PhysicalOp::Set(SetExec {
172                    input: set.input,
173                    items: set.items,
174                }),
175
176                LogicalOp::Remove(remove) => PhysicalOp::Remove(RemoveExec {
177                    input: remove.input,
178                    items: remove.items,
179                }),
180
181                LogicalOp::OptionalMatch(om) => PhysicalOp::OptionalMatch(OptionalMatchExec {
182                    input: om.input,
183                    inner: om.inner,
184                    new_vars: om.new_vars,
185                }),
186
187                LogicalOp::PathBuild(pb) => PhysicalOp::PathBuild(PathBuildExec {
188                    input: pb.input,
189                    output: pb.output,
190                    node_vars: pb.node_vars,
191                    rel_vars: pb.rel_vars,
192                    shortest_path_all: pb.shortest_path_all,
193                }),
194            })
195            .collect();
196
197        PhysicalPlan { root, nodes }
198    }
199}
200
201fn collect_vars(expr: &ResolvedExpr) -> BTreeSet<VarId> {
202    let mut vars = BTreeSet::new();
203    collect_vars_inner(expr, &mut vars);
204    vars
205}
206
207fn collect_vars_inner(expr: &ResolvedExpr, out: &mut BTreeSet<VarId>) {
208    match expr {
209        ResolvedExpr::Variable(v) => {
210            out.insert(*v);
211        }
212        ResolvedExpr::Property { expr, .. } => collect_vars_inner(expr, out),
213        ResolvedExpr::Binary { lhs, rhs, .. } => {
214            collect_vars_inner(lhs, out);
215            collect_vars_inner(rhs, out);
216        }
217        ResolvedExpr::Unary { expr, .. } => collect_vars_inner(expr, out),
218        ResolvedExpr::Function { args, .. } => {
219            for arg in args {
220                collect_vars_inner(arg, out);
221            }
222        }
223        ResolvedExpr::List(items) => {
224            for item in items {
225                collect_vars_inner(item, out);
226            }
227        }
228        ResolvedExpr::Map(items) => {
229            for (_, v) in items {
230                collect_vars_inner(v, out);
231            }
232        }
233        ResolvedExpr::Case {
234            input,
235            alternatives,
236            else_expr,
237        } => {
238            if let Some(e) = input {
239                collect_vars_inner(e, out);
240            }
241            for (w, t) in alternatives {
242                collect_vars_inner(w, out);
243                collect_vars_inner(t, out);
244            }
245            if let Some(e) = else_expr {
246                collect_vars_inner(e, out);
247            }
248        }
249        ResolvedExpr::ListPredicate {
250            variable,
251            list,
252            predicate,
253            ..
254        } => {
255            out.insert(*variable);
256            collect_vars_inner(list, out);
257            collect_vars_inner(predicate, out);
258        }
259        ResolvedExpr::ListComprehension {
260            variable,
261            list,
262            filter,
263            map_expr,
264            ..
265        } => {
266            out.insert(*variable);
267            collect_vars_inner(list, out);
268            if let Some(f) = filter {
269                collect_vars_inner(f, out);
270            }
271            if let Some(m) = map_expr {
272                collect_vars_inner(m, out);
273            }
274        }
275        ResolvedExpr::Reduce {
276            accumulator,
277            init,
278            variable,
279            list,
280            expr,
281            ..
282        } => {
283            out.insert(*accumulator);
284            out.insert(*variable);
285            collect_vars_inner(init, out);
286            collect_vars_inner(list, out);
287            collect_vars_inner(expr, out);
288        }
289        ResolvedExpr::Index { expr, index } => {
290            collect_vars_inner(expr, out);
291            collect_vars_inner(index, out);
292        }
293        ResolvedExpr::Slice { expr, from, to } => {
294            collect_vars_inner(expr, out);
295            if let Some(f) = from {
296                collect_vars_inner(f, out);
297            }
298            if let Some(t) = to {
299                collect_vars_inner(t, out);
300            }
301        }
302        ResolvedExpr::MapProjection { base, selectors } => {
303            collect_vars_inner(base, out);
304            for sel in selectors {
305                if let lora_analyzer::ResolvedMapSelector::Literal(_, e) = sel {
306                    collect_vars_inner(e, out);
307                }
308            }
309        }
310        _ => {}
311    }
312}