1use crate::pattern::PatternPlanner;
2use crate::{
3 Aggregation, Argument, Filter, Limit, LogicalOp, LogicalPlan, OptionalMatch, PlanNodeId,
4 Projection, Sort, Unwind,
5};
6use lora_analyzer::symbols::VarId;
7use lora_analyzer::{
8 ResolvedClause, ResolvedCreate, ResolvedDelete, ResolvedExpr, ResolvedMatch, ResolvedMerge,
9 ResolvedPattern, ResolvedPatternElement, ResolvedProjection, ResolvedQuery, ResolvedRemove,
10 ResolvedReturn, ResolvedSet, ResolvedUnwind, ResolvedWith,
11};
12
13pub struct Planner {
14 nodes: Vec<LogicalOp>,
15}
16
17impl Default for Planner {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl Planner {
24 pub fn new() -> Self {
25 Self { nodes: Vec::new() }
26 }
27
28 pub(crate) fn push(&mut self, op: LogicalOp) -> PlanNodeId {
29 let id = self.nodes.len();
30 self.nodes.push(op);
31 id
32 }
33
34 pub fn plan(&mut self, query: &ResolvedQuery) -> LogicalPlan {
35 let root = self.plan_query(query);
36
37 LogicalPlan {
38 root,
39 nodes: std::mem::take(&mut self.nodes),
40 }
41 }
42
43 fn plan_query(&mut self, query: &ResolvedQuery) -> PlanNodeId {
44 let mut input = None;
45
46 for clause in &query.clauses {
47 input = Some(match clause {
48 ResolvedClause::Match(m) => self.plan_match(input, m),
49
50 ResolvedClause::Unwind(u) => {
51 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
52 self.plan_unwind(upstream, u)
53 }
54
55 ResolvedClause::Create(c) => {
56 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
57 self.plan_create(upstream, c)
58 }
59
60 ResolvedClause::Merge(m) => {
61 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
62 self.plan_merge(upstream, m)
63 }
64
65 ResolvedClause::Delete(d) => {
66 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
67 self.plan_delete(upstream, d)
68 }
69
70 ResolvedClause::Set(s) => {
71 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
72 self.plan_set(upstream, s)
73 }
74
75 ResolvedClause::Remove(rm) => {
76 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
77 self.plan_remove(upstream, rm)
78 }
79
80 ResolvedClause::With(w) => {
81 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
82 self.plan_with(upstream, w)
83 }
84
85 ResolvedClause::Return(r) => {
86 let upstream = input.unwrap_or_else(|| self.plan_unit_input());
87 self.plan_return(upstream, r)
88 }
89 });
90 }
91
92 input.unwrap_or_else(|| self.plan_unit_input())
93 }
94
95 fn plan_match(&mut self, input: Option<PlanNodeId>, m: &ResolvedMatch) -> PlanNodeId {
96 if let (true, Some(upstream)) = (m.optional, input) {
97 let new_vars = collect_pattern_vars(&m.pattern);
102
103 let mut pattern_planner = PatternPlanner::new(self);
106 let mut inner = pattern_planner.plan_pattern(None, &m.pattern);
107
108 if let Some(pred) = &m.where_ {
109 inner = self.push(LogicalOp::Filter(Filter {
110 input: inner,
111 predicate: pred.clone(),
112 }));
113 }
114
115 self.push(LogicalOp::OptionalMatch(OptionalMatch {
116 input: upstream,
117 inner,
118 new_vars,
119 }))
120 } else {
121 let mut pattern_planner = PatternPlanner::new(self);
122 let mut node = pattern_planner.plan_pattern(input, &m.pattern);
123
124 if let Some(pred) = &m.where_ {
125 node = self.push(LogicalOp::Filter(Filter {
126 input: node,
127 predicate: pred.clone(),
128 }));
129 }
130
131 node
132 }
133 }
134
135 fn plan_unwind(&mut self, input: PlanNodeId, u: &ResolvedUnwind) -> PlanNodeId {
136 self.push(LogicalOp::Unwind(Unwind {
137 input,
138 expr: u.expr.clone(),
139 alias: u.alias,
140 }))
141 }
142
143 fn plan_create(&mut self, input: PlanNodeId, c: &ResolvedCreate) -> PlanNodeId {
144 self.push(LogicalOp::Create(crate::Create {
145 input,
146 pattern: c.pattern.clone(),
147 }))
148 }
149
150 fn plan_merge(&mut self, input: PlanNodeId, m: &ResolvedMerge) -> PlanNodeId {
151 self.push(LogicalOp::Merge(crate::Merge {
152 input,
153 pattern_part: m.pattern_part.clone(),
154 actions: m.actions.clone(),
155 }))
156 }
157
158 fn plan_delete(&mut self, input: PlanNodeId, d: &ResolvedDelete) -> PlanNodeId {
159 self.push(LogicalOp::Delete(crate::Delete {
160 input,
161 detach: d.detach,
162 expressions: d.expressions.clone(),
163 }))
164 }
165
166 fn plan_set(&mut self, input: PlanNodeId, s: &ResolvedSet) -> PlanNodeId {
167 self.push(LogicalOp::Set(crate::Set {
168 input,
169 items: s.items.clone(),
170 }))
171 }
172
173 fn plan_remove(&mut self, input: PlanNodeId, r: &ResolvedRemove) -> PlanNodeId {
174 self.push(LogicalOp::Remove(crate::Remove {
175 input,
176 items: r.items.clone(),
177 }))
178 }
179
180 fn plan_with(&mut self, input: PlanNodeId, with: &ResolvedWith) -> PlanNodeId {
181 let mut node = input;
182
183 if !with.order.is_empty() {
185 node = self.push(LogicalOp::Sort(Sort {
186 input: node,
187 items: with.order.clone(),
188 }));
189 }
190
191 if with.skip.is_some() || with.limit.is_some() {
192 node = self.push(LogicalOp::Limit(Limit {
193 input: node,
194 skip: with.skip.clone(),
195 limit: with.limit.clone(),
196 }));
197 }
198
199 node = self.plan_projection_or_aggregation(
200 node,
201 &with.items,
202 with.distinct,
203 with.include_existing,
204 );
205
206 if let Some(pred) = &with.where_ {
207 node = self.push(LogicalOp::Filter(Filter {
208 input: node,
209 predicate: pred.clone(),
210 }));
211 }
212
213 node
214 }
215
216 fn plan_return(&mut self, input: PlanNodeId, ret: &ResolvedReturn) -> PlanNodeId {
217 let mut node = input;
218
219 if !ret.order.is_empty() {
223 node = self.push(LogicalOp::Sort(Sort {
224 input: node,
225 items: ret.order.clone(),
226 }));
227 }
228
229 if ret.skip.is_some() || ret.limit.is_some() {
230 node = self.push(LogicalOp::Limit(Limit {
231 input: node,
232 skip: ret.skip.clone(),
233 limit: ret.limit.clone(),
234 }));
235 }
236
237 node = self.plan_projection_or_aggregation(
238 node,
239 &ret.items,
240 ret.distinct,
241 ret.include_existing,
242 );
243
244 node
245 }
246
247 fn plan_projection_or_aggregation(
251 &mut self,
252 input: PlanNodeId,
253 items: &[ResolvedProjection],
254 distinct: bool,
255 include_existing: bool,
256 ) -> PlanNodeId {
257 let has_aggregates = items.iter().any(|item| expr_contains_aggregate(&item.expr));
258
259 if !has_aggregates {
260 return self.push(LogicalOp::Projection(Projection {
261 input,
262 distinct,
263 items: items.to_vec(),
264 include_existing,
265 }));
266 }
267
268 let mut group_by = Vec::new();
270 let mut aggregates = Vec::new();
271
272 for item in items {
273 if expr_contains_aggregate(&item.expr) {
274 aggregates.push(item.clone());
275 } else {
276 group_by.push(item.clone());
277 }
278 }
279
280 let node = self.push(LogicalOp::Aggregation(Aggregation {
281 input,
282 group_by: group_by.clone(),
283 aggregates: aggregates.clone(),
284 }));
285
286 if distinct {
295 let passthrough_items: Vec<ResolvedProjection> = items
297 .iter()
298 .map(|item| ResolvedProjection {
299 expr: ResolvedExpr::Variable(item.output),
300 output: item.output,
301 name: item.name.clone(),
302 explicit_alias: item.explicit_alias,
303 span: item.span,
304 })
305 .collect();
306 self.push(LogicalOp::Projection(Projection {
307 input: node,
308 distinct: true,
309 items: passthrough_items,
310 include_existing: false,
311 }))
312 } else {
313 node
314 }
315 }
316
317 fn plan_unit_input(&mut self) -> PlanNodeId {
318 self.push(LogicalOp::Argument(Argument))
319 }
320}
321
/// Aggregating function names, spelled in lower case. Lookups go through
/// [`is_aggregate_function`], which compares ignoring ASCII case, so e.g.
/// `percentileCont` matches `percentilecont`.
const AGGREGATE_FUNCTIONS: &[&str] = &[
    "count",
    "sum",
    "avg",
    "min",
    "max",
    "collect",
    "stdev",
    "stdevp",
    "percentilecont",
    "percentiledisc",
];

/// Returns `true` when `name` is one of [`AGGREGATE_FUNCTIONS`], ignoring
/// ASCII case.
fn is_aggregate_function(name: &str) -> bool {
    for candidate in AGGREGATE_FUNCTIONS {
        if name.eq_ignore_ascii_case(candidate) {
            return true;
        }
    }
    false
}
340
341fn collect_pattern_vars(pattern: &ResolvedPattern) -> Vec<VarId> {
343 let mut vars = Vec::new();
344 for part in &pattern.parts {
345 if let Some(v) = part.binding {
346 vars.push(v);
347 }
348 match &part.element {
349 ResolvedPatternElement::Node { var, .. } => {
350 if let Some(v) = var {
351 vars.push(*v);
352 }
353 }
354 ResolvedPatternElement::ShortestPath { head, chain, .. }
355 | ResolvedPatternElement::NodeChain { head, chain } => {
356 if let Some(v) = head.var {
357 vars.push(v);
358 }
359 for step in chain {
360 if let Some(v) = step.rel.var {
361 vars.push(v);
362 }
363 if let Some(v) = step.node.var {
364 vars.push(v);
365 }
366 }
367 }
368 }
369 }
370 vars
371}
372
373fn expr_contains_aggregate(expr: &ResolvedExpr) -> bool {
374 match expr {
375 ResolvedExpr::Function { name, args, .. } => {
376 if is_aggregate_function(name) {
377 return true;
378 }
379 args.iter().any(expr_contains_aggregate)
380 }
381 ResolvedExpr::Property { expr, .. } => expr_contains_aggregate(expr),
382 ResolvedExpr::Binary { lhs, rhs, .. } => {
383 expr_contains_aggregate(lhs) || expr_contains_aggregate(rhs)
384 }
385 ResolvedExpr::Unary { expr, .. } => expr_contains_aggregate(expr),
386 ResolvedExpr::List(items) => items.iter().any(expr_contains_aggregate),
387 ResolvedExpr::Map(items) => items.iter().any(|(_, v)| expr_contains_aggregate(v)),
388 ResolvedExpr::Case {
389 input,
390 alternatives,
391 else_expr,
392 } => {
393 input.as_ref().is_some_and(|e| expr_contains_aggregate(e))
394 || alternatives
395 .iter()
396 .any(|(w, t)| expr_contains_aggregate(w) || expr_contains_aggregate(t))
397 || else_expr
398 .as_ref()
399 .is_some_and(|e| expr_contains_aggregate(e))
400 }
401 _ => false,
402 }
403}