1use crate::logical::*;
2use crate::physical::*;
3use lora_analyzer::{symbols::VarId, ResolvedExpr};
4use std::collections::BTreeSet;
5
/// Applies logical-plan rewrite passes and lowers logical plans to physical plans.
///
/// The type carries no state of its own; all work happens in its methods.
/// `Default` is implemented by hand (it delegates to `Optimizer::new`), so it is
/// intentionally not part of the derive list.
#[derive(Debug, Clone, Copy)]
pub struct Optimizer;
7
8impl Default for Optimizer {
9 fn default() -> Self {
10 Self::new()
11 }
12}
13
14impl Optimizer {
15 pub fn new() -> Self {
16 Self
17 }
18
19 pub fn optimize(&mut self, mut plan: LogicalPlan) -> LogicalPlan {
20 self.push_filter_below_projection(&mut plan);
21 self.remove_redundant_limit(&mut plan);
22 plan
23 }
24
25 fn push_filter_below_projection(&self, plan: &mut LogicalPlan) {
26 let len = plan.nodes.len();
27
28 for i in 0..len {
29 let input_id = match &plan.nodes[i] {
32 LogicalOp::Filter(f) => f.input,
33 _ => continue,
34 };
35
36 let should_push = match (&plan.nodes[i], &plan.nodes[input_id]) {
37 (LogicalOp::Filter(filter), LogicalOp::Projection(proj)) => {
38 if proj.distinct || proj.include_existing {
39 false
40 } else {
41 let output_vars: BTreeSet<VarId> =
42 proj.items.iter().map(|item| item.output).collect();
43 let pred_vars = collect_vars(&filter.predicate);
44 !pred_vars.iter().any(|v| output_vars.contains(v))
45 }
46 }
47 _ => false,
48 };
49
50 if !should_push {
51 continue;
52 }
53
54 let placeholder = || LogicalOp::Argument(Argument);
56 let filter = match std::mem::replace(&mut plan.nodes[i], placeholder()) {
57 LogicalOp::Filter(f) => f,
58 _ => unreachable!(),
59 };
60 let proj = match std::mem::replace(&mut plan.nodes[input_id], placeholder()) {
61 LogicalOp::Projection(p) => p,
62 _ => unreachable!(),
63 };
64
65 plan.nodes[input_id] = LogicalOp::Filter(Filter {
66 input: proj.input,
67 predicate: filter.predicate,
68 });
69 plan.nodes[i] = LogicalOp::Projection(Projection {
70 input: input_id,
71 distinct: proj.distinct,
72 items: proj.items,
73 include_existing: proj.include_existing,
74 });
75 }
76 }
77
78 fn remove_redundant_limit(&self, _plan: &mut LogicalPlan) {
79 }
81
82 pub fn lower_to_physical(&mut self, logical: LogicalPlan) -> PhysicalPlan {
86 let LogicalPlan { root, nodes } = logical;
87
88 let nodes = nodes
89 .into_iter()
90 .map(|op| match op {
91 LogicalOp::Argument(_) => PhysicalOp::Argument(ArgumentExec),
92
93 LogicalOp::NodeScan(scan) => {
94 if scan.labels.is_empty() {
95 PhysicalOp::NodeScan(NodeScanExec {
96 input: scan.input,
97 var: scan.var,
98 })
99 } else {
100 PhysicalOp::NodeByLabelScan(NodeByLabelScanExec {
101 input: scan.input,
102 var: scan.var,
103 labels: scan.labels,
104 })
105 }
106 }
107
108 LogicalOp::Expand(expand) => PhysicalOp::Expand(ExpandExec {
109 input: expand.input,
110 src: expand.src,
111 rel: expand.rel,
112 dst: expand.dst,
113 types: expand.types,
114 direction: expand.direction,
115 rel_properties: expand.rel_properties,
116 range: expand.range,
117 }),
118
119 LogicalOp::Filter(filter) => PhysicalOp::Filter(FilterExec {
120 input: filter.input,
121 predicate: filter.predicate,
122 }),
123
124 LogicalOp::Projection(proj) => PhysicalOp::Projection(ProjectionExec {
125 input: proj.input,
126 distinct: proj.distinct,
127 items: proj.items,
128 include_existing: proj.include_existing,
129 }),
130
131 LogicalOp::Unwind(unwind) => PhysicalOp::Unwind(UnwindExec {
132 input: unwind.input,
133 expr: unwind.expr,
134 alias: unwind.alias,
135 }),
136
137 LogicalOp::Aggregation(agg) => PhysicalOp::HashAggregation(HashAggregationExec {
138 input: agg.input,
139 group_by: agg.group_by,
140 aggregates: agg.aggregates,
141 }),
142
143 LogicalOp::Sort(sort) => PhysicalOp::Sort(SortExec {
144 input: sort.input,
145 items: sort.items,
146 }),
147
148 LogicalOp::Limit(limit) => PhysicalOp::Limit(LimitExec {
149 input: limit.input,
150 skip: limit.skip,
151 limit: limit.limit,
152 }),
153
154 LogicalOp::Create(create) => PhysicalOp::Create(CreateExec {
155 input: create.input,
156 pattern: create.pattern,
157 }),
158
159 LogicalOp::Merge(merge) => PhysicalOp::Merge(MergeExec {
160 input: merge.input,
161 pattern_part: merge.pattern_part,
162 actions: merge.actions,
163 }),
164
165 LogicalOp::Delete(delete) => PhysicalOp::Delete(DeleteExec {
166 input: delete.input,
167 detach: delete.detach,
168 expressions: delete.expressions,
169 }),
170
171 LogicalOp::Set(set) => PhysicalOp::Set(SetExec {
172 input: set.input,
173 items: set.items,
174 }),
175
176 LogicalOp::Remove(remove) => PhysicalOp::Remove(RemoveExec {
177 input: remove.input,
178 items: remove.items,
179 }),
180
181 LogicalOp::OptionalMatch(om) => PhysicalOp::OptionalMatch(OptionalMatchExec {
182 input: om.input,
183 inner: om.inner,
184 new_vars: om.new_vars,
185 }),
186
187 LogicalOp::PathBuild(pb) => PhysicalOp::PathBuild(PathBuildExec {
188 input: pb.input,
189 output: pb.output,
190 node_vars: pb.node_vars,
191 rel_vars: pb.rel_vars,
192 shortest_path_all: pb.shortest_path_all,
193 }),
194 })
195 .collect();
196
197 PhysicalPlan { root, nodes }
198 }
199}
200
201fn collect_vars(expr: &ResolvedExpr) -> BTreeSet<VarId> {
202 let mut vars = BTreeSet::new();
203 collect_vars_inner(expr, &mut vars);
204 vars
205}
206
207fn collect_vars_inner(expr: &ResolvedExpr, out: &mut BTreeSet<VarId>) {
208 match expr {
209 ResolvedExpr::Variable(v) => {
210 out.insert(*v);
211 }
212 ResolvedExpr::Property { expr, .. } => collect_vars_inner(expr, out),
213 ResolvedExpr::Binary { lhs, rhs, .. } => {
214 collect_vars_inner(lhs, out);
215 collect_vars_inner(rhs, out);
216 }
217 ResolvedExpr::Unary { expr, .. } => collect_vars_inner(expr, out),
218 ResolvedExpr::Function { args, .. } => {
219 for arg in args {
220 collect_vars_inner(arg, out);
221 }
222 }
223 ResolvedExpr::List(items) => {
224 for item in items {
225 collect_vars_inner(item, out);
226 }
227 }
228 ResolvedExpr::Map(items) => {
229 for (_, v) in items {
230 collect_vars_inner(v, out);
231 }
232 }
233 ResolvedExpr::Case {
234 input,
235 alternatives,
236 else_expr,
237 } => {
238 if let Some(e) = input {
239 collect_vars_inner(e, out);
240 }
241 for (w, t) in alternatives {
242 collect_vars_inner(w, out);
243 collect_vars_inner(t, out);
244 }
245 if let Some(e) = else_expr {
246 collect_vars_inner(e, out);
247 }
248 }
249 ResolvedExpr::ListPredicate {
250 variable,
251 list,
252 predicate,
253 ..
254 } => {
255 out.insert(*variable);
256 collect_vars_inner(list, out);
257 collect_vars_inner(predicate, out);
258 }
259 ResolvedExpr::ListComprehension {
260 variable,
261 list,
262 filter,
263 map_expr,
264 ..
265 } => {
266 out.insert(*variable);
267 collect_vars_inner(list, out);
268 if let Some(f) = filter {
269 collect_vars_inner(f, out);
270 }
271 if let Some(m) = map_expr {
272 collect_vars_inner(m, out);
273 }
274 }
275 ResolvedExpr::Reduce {
276 accumulator,
277 init,
278 variable,
279 list,
280 expr,
281 ..
282 } => {
283 out.insert(*accumulator);
284 out.insert(*variable);
285 collect_vars_inner(init, out);
286 collect_vars_inner(list, out);
287 collect_vars_inner(expr, out);
288 }
289 ResolvedExpr::Index { expr, index } => {
290 collect_vars_inner(expr, out);
291 collect_vars_inner(index, out);
292 }
293 ResolvedExpr::Slice { expr, from, to } => {
294 collect_vars_inner(expr, out);
295 if let Some(f) = from {
296 collect_vars_inner(f, out);
297 }
298 if let Some(t) = to {
299 collect_vars_inner(t, out);
300 }
301 }
302 ResolvedExpr::MapProjection { base, selectors } => {
303 collect_vars_inner(base, out);
304 for sel in selectors {
305 if let lora_analyzer::ResolvedMapSelector::Literal(_, e) = sel {
306 collect_vars_inner(e, out);
307 }
308 }
309 }
310 _ => {}
311 }
312}