Skip to main content

lora_compiler/
planner.rs

1use crate::pattern::PatternPlanner;
2use crate::{
3    Aggregation, Argument, Filter, Limit, LogicalOp, LogicalPlan, OptionalMatch, PlanNodeId,
4    Projection, Sort, Unwind,
5};
6use lora_analyzer::symbols::VarId;
7use lora_analyzer::{
8    ResolvedClause, ResolvedCreate, ResolvedDelete, ResolvedExpr, ResolvedMatch, ResolvedMerge,
9    ResolvedPattern, ResolvedPatternElement, ResolvedProjection, ResolvedQuery, ResolvedRemove,
10    ResolvedReturn, ResolvedSet, ResolvedUnwind, ResolvedWith,
11};
12
13pub struct Planner {
14    nodes: Vec<LogicalOp>,
15}
16
17impl Default for Planner {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl Planner {
24    pub fn new() -> Self {
25        Self { nodes: Vec::new() }
26    }
27
28    pub(crate) fn push(&mut self, op: LogicalOp) -> PlanNodeId {
29        let id = self.nodes.len();
30        self.nodes.push(op);
31        id
32    }
33
34    pub fn plan(&mut self, query: &ResolvedQuery) -> LogicalPlan {
35        let root = self.plan_query(query);
36
37        LogicalPlan {
38            root,
39            nodes: std::mem::take(&mut self.nodes),
40        }
41    }
42
43    fn plan_query(&mut self, query: &ResolvedQuery) -> PlanNodeId {
44        let mut input = None;
45
46        for clause in &query.clauses {
47            input = Some(match clause {
48                ResolvedClause::Match(m) => self.plan_match(input, m),
49
50                ResolvedClause::Unwind(u) => {
51                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
52                    self.plan_unwind(upstream, u)
53                }
54
55                ResolvedClause::Create(c) => {
56                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
57                    self.plan_create(upstream, c)
58                }
59
60                ResolvedClause::Merge(m) => {
61                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
62                    self.plan_merge(upstream, m)
63                }
64
65                ResolvedClause::Delete(d) => {
66                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
67                    self.plan_delete(upstream, d)
68                }
69
70                ResolvedClause::Set(s) => {
71                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
72                    self.plan_set(upstream, s)
73                }
74
75                ResolvedClause::Remove(rm) => {
76                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
77                    self.plan_remove(upstream, rm)
78                }
79
80                ResolvedClause::With(w) => {
81                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
82                    self.plan_with(upstream, w)
83                }
84
85                ResolvedClause::Return(r) => {
86                    let upstream = input.unwrap_or_else(|| self.plan_unit_input());
87                    self.plan_return(upstream, r)
88                }
89            });
90        }
91
92        input.unwrap_or_else(|| self.plan_unit_input())
93    }
94
95    fn plan_match(&mut self, input: Option<PlanNodeId>, m: &ResolvedMatch) -> PlanNodeId {
96        if let (true, Some(upstream)) = (m.optional, input) {
97            // OPTIONAL MATCH: build the inner sub-plan that reads from Argument,
98            // then wrap it in an OptionalMatch node that provides null-extension.
99
100            // Collect variables introduced by this pattern (for null-extension).
101            let new_vars = collect_pattern_vars(&m.pattern);
102
103            // Build inner match plan WITHOUT the upstream input — the executor
104            // will inject each upstream row individually.
105            let mut pattern_planner = PatternPlanner::new(self);
106            let mut inner = pattern_planner.plan_pattern(None, &m.pattern);
107
108            if let Some(pred) = &m.where_ {
109                inner = self.push(LogicalOp::Filter(Filter {
110                    input: inner,
111                    predicate: pred.clone(),
112                }));
113            }
114
115            self.push(LogicalOp::OptionalMatch(OptionalMatch {
116                input: upstream,
117                inner,
118                new_vars,
119            }))
120        } else {
121            let mut pattern_planner = PatternPlanner::new(self);
122            let mut node = pattern_planner.plan_pattern(input, &m.pattern);
123
124            if let Some(pred) = &m.where_ {
125                node = self.push(LogicalOp::Filter(Filter {
126                    input: node,
127                    predicate: pred.clone(),
128                }));
129            }
130
131            node
132        }
133    }
134
135    fn plan_unwind(&mut self, input: PlanNodeId, u: &ResolvedUnwind) -> PlanNodeId {
136        self.push(LogicalOp::Unwind(Unwind {
137            input,
138            expr: u.expr.clone(),
139            alias: u.alias,
140        }))
141    }
142
143    fn plan_create(&mut self, input: PlanNodeId, c: &ResolvedCreate) -> PlanNodeId {
144        self.push(LogicalOp::Create(crate::Create {
145            input,
146            pattern: c.pattern.clone(),
147        }))
148    }
149
150    fn plan_merge(&mut self, input: PlanNodeId, m: &ResolvedMerge) -> PlanNodeId {
151        self.push(LogicalOp::Merge(crate::Merge {
152            input,
153            pattern_part: m.pattern_part.clone(),
154            actions: m.actions.clone(),
155        }))
156    }
157
158    fn plan_delete(&mut self, input: PlanNodeId, d: &ResolvedDelete) -> PlanNodeId {
159        self.push(LogicalOp::Delete(crate::Delete {
160            input,
161            detach: d.detach,
162            expressions: d.expressions.clone(),
163        }))
164    }
165
166    fn plan_set(&mut self, input: PlanNodeId, s: &ResolvedSet) -> PlanNodeId {
167        self.push(LogicalOp::Set(crate::Set {
168            input,
169            items: s.items.clone(),
170        }))
171    }
172
173    fn plan_remove(&mut self, input: PlanNodeId, r: &ResolvedRemove) -> PlanNodeId {
174        self.push(LogicalOp::Remove(crate::Remove {
175            input,
176            items: r.items.clone(),
177        }))
178    }
179
180    fn plan_with(&mut self, input: PlanNodeId, with: &ResolvedWith) -> PlanNodeId {
181        let mut node = input;
182
183        // Sort before projection so sort expressions can access original variables.
184        if !with.order.is_empty() {
185            node = self.push(LogicalOp::Sort(Sort {
186                input: node,
187                items: with.order.clone(),
188            }));
189        }
190
191        if with.skip.is_some() || with.limit.is_some() {
192            node = self.push(LogicalOp::Limit(Limit {
193                input: node,
194                skip: with.skip.clone(),
195                limit: with.limit.clone(),
196            }));
197        }
198
199        node = self.plan_projection_or_aggregation(
200            node,
201            &with.items,
202            with.distinct,
203            with.include_existing,
204        );
205
206        if let Some(pred) = &with.where_ {
207            node = self.push(LogicalOp::Filter(Filter {
208                input: node,
209                predicate: pred.clone(),
210            }));
211        }
212
213        node
214    }
215
216    fn plan_return(&mut self, input: PlanNodeId, ret: &ResolvedReturn) -> PlanNodeId {
217        let mut node = input;
218
219        // Sort must happen BEFORE projection so that the sort expressions
220        // can access the original variables (e.g. n.name) which are not
221        // available after projection replaces the row with output VarIds.
222        if !ret.order.is_empty() {
223            node = self.push(LogicalOp::Sort(Sort {
224                input: node,
225                items: ret.order.clone(),
226            }));
227        }
228
229        if ret.skip.is_some() || ret.limit.is_some() {
230            node = self.push(LogicalOp::Limit(Limit {
231                input: node,
232                skip: ret.skip.clone(),
233                limit: ret.limit.clone(),
234            }));
235        }
236
237        node = self.plan_projection_or_aggregation(
238            node,
239            &ret.items,
240            ret.distinct,
241            ret.include_existing,
242        );
243
244        node
245    }
246
247    /// If any projection item contains an aggregate function, emit an
248    /// Aggregation node followed by a Projection. Otherwise emit a plain
249    /// Projection.
250    fn plan_projection_or_aggregation(
251        &mut self,
252        input: PlanNodeId,
253        items: &[ResolvedProjection],
254        distinct: bool,
255        include_existing: bool,
256    ) -> PlanNodeId {
257        let has_aggregates = items.iter().any(|item| expr_contains_aggregate(&item.expr));
258
259        if !has_aggregates {
260            return self.push(LogicalOp::Projection(Projection {
261                input,
262                distinct,
263                items: items.to_vec(),
264                include_existing,
265            }));
266        }
267
268        // Split items into group-by keys and aggregate expressions.
269        let mut group_by = Vec::new();
270        let mut aggregates = Vec::new();
271
272        for item in items {
273            if expr_contains_aggregate(&item.expr) {
274                aggregates.push(item.clone());
275            } else {
276                group_by.push(item.clone());
277            }
278        }
279
280        let node = self.push(LogicalOp::Aggregation(Aggregation {
281            input,
282            group_by: group_by.clone(),
283            aggregates: aggregates.clone(),
284        }));
285
286        // After aggregation the row already contains the right VarIds and names,
287        // but we still emit a Projection to handle DISTINCT and to ensure the
288        // final column order matches the original item list. The projection uses
289        // include_existing=true so it picks up the aggregation output, and each
290        // item just reads its own output variable.
291        //
292        // However, since the aggregation node already produces correctly-named
293        // rows, we can skip the extra projection when not needed.
294        if distinct {
295            // For DISTINCT we still need the dedup pass in exec_projection.
296            let passthrough_items: Vec<ResolvedProjection> = items
297                .iter()
298                .map(|item| ResolvedProjection {
299                    expr: ResolvedExpr::Variable(item.output),
300                    output: item.output,
301                    name: item.name.clone(),
302                    explicit_alias: item.explicit_alias,
303                    span: item.span,
304                })
305                .collect();
306            self.push(LogicalOp::Projection(Projection {
307                input: node,
308                distinct: true,
309                items: passthrough_items,
310                include_existing: false,
311            }))
312        } else {
313            node
314        }
315    }
316
317    fn plan_unit_input(&mut self) -> PlanNodeId {
318        self.push(LogicalOp::Argument(Argument))
319    }
320}
321
/// Lowercase names of the Cypher aggregate functions.
///
/// Lookups against this table are ASCII case-insensitive (see
/// `is_aggregate_function`), so spellings such as `COUNT` or `stDevP` also
/// match.
const AGGREGATE_FUNCTIONS: &[&str] = &[
    // counting and arithmetic
    "count", "sum", "avg",
    // extrema
    "min", "max",
    // list building
    "collect",
    // statistics
    "stdev", "stdevp", "percentilecont", "percentiledisc",
];
334
335fn is_aggregate_function(name: &str) -> bool {
336    AGGREGATE_FUNCTIONS
337        .iter()
338        .any(|&f| f.eq_ignore_ascii_case(name))
339}
340
341/// Collect all VarIds introduced by a pattern (node vars, relationship vars).
342fn collect_pattern_vars(pattern: &ResolvedPattern) -> Vec<VarId> {
343    let mut vars = Vec::new();
344    for part in &pattern.parts {
345        if let Some(v) = part.binding {
346            vars.push(v);
347        }
348        match &part.element {
349            ResolvedPatternElement::Node { var, .. } => {
350                if let Some(v) = var {
351                    vars.push(*v);
352                }
353            }
354            ResolvedPatternElement::ShortestPath { head, chain, .. }
355            | ResolvedPatternElement::NodeChain { head, chain } => {
356                if let Some(v) = head.var {
357                    vars.push(v);
358                }
359                for step in chain {
360                    if let Some(v) = step.rel.var {
361                        vars.push(v);
362                    }
363                    if let Some(v) = step.node.var {
364                        vars.push(v);
365                    }
366                }
367            }
368        }
369    }
370    vars
371}
372
373fn expr_contains_aggregate(expr: &ResolvedExpr) -> bool {
374    match expr {
375        ResolvedExpr::Function { name, args, .. } => {
376            if is_aggregate_function(name) {
377                return true;
378            }
379            args.iter().any(expr_contains_aggregate)
380        }
381        ResolvedExpr::Property { expr, .. } => expr_contains_aggregate(expr),
382        ResolvedExpr::Binary { lhs, rhs, .. } => {
383            expr_contains_aggregate(lhs) || expr_contains_aggregate(rhs)
384        }
385        ResolvedExpr::Unary { expr, .. } => expr_contains_aggregate(expr),
386        ResolvedExpr::List(items) => items.iter().any(expr_contains_aggregate),
387        ResolvedExpr::Map(items) => items.iter().any(|(_, v)| expr_contains_aggregate(v)),
388        ResolvedExpr::Case {
389            input,
390            alternatives,
391            else_expr,
392        } => {
393            input.as_ref().is_some_and(|e| expr_contains_aggregate(e))
394                || alternatives
395                    .iter()
396                    .any(|(w, t)| expr_contains_aggregate(w) || expr_contains_aggregate(t))
397                || else_expr
398                    .as_ref()
399                    .is_some_and(|e| expr_contains_aggregate(e))
400        }
401        _ => false,
402    }
403}