Skip to main content

omnigraph_compiler/ir/
lower.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use crate::catalog::Catalog;
4use crate::error::Result;
5use crate::query::ast::*;
6use crate::query::typecheck::TypeContext;
7use crate::types::Direction;
8
9use super::*;
10
11pub fn lower_query(
12    catalog: &Catalog,
13    query: &QueryDecl,
14    type_ctx: &TypeContext,
15) -> Result<QueryIR> {
16    if !query.mutations.is_empty() {
17        return Err(crate::error::NanoError::Plan(
18            "cannot lower mutation query with read-query lowerer".to_string(),
19        ));
20    }
21    let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
22
23    let mut pipeline = Vec::new();
24    let mut bound_vars = HashSet::new();
25
26    lower_clauses(
27        catalog,
28        &query.match_clause,
29        type_ctx,
30        &mut pipeline,
31        &mut bound_vars,
32        &param_names,
33    )?;
34
35    let return_exprs: Vec<IRProjection> = query
36        .return_clause
37        .iter()
38        .map(|p| IRProjection {
39            expr: lower_expr(&p.expr, &param_names),
40            alias: p.alias.clone(),
41        })
42        .collect();
43
44    let order_by: Vec<IROrdering> = query
45        .order_clause
46        .iter()
47        .map(|o| IROrdering {
48            expr: lower_expr(&o.expr, &param_names),
49            descending: o.descending,
50        })
51        .collect();
52
53    Ok(QueryIR {
54        name: query.name.clone(),
55        params: query.params.clone(),
56        pipeline,
57        return_exprs,
58        order_by,
59        limit: query.limit,
60    })
61}
62
63pub fn lower_mutation_query(query: &QueryDecl) -> Result<MutationIR> {
64    if query.mutations.is_empty() {
65        return Err(crate::error::NanoError::Plan(
66            "query does not contain a mutation body".to_string(),
67        ));
68    }
69    let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
70
71    let ops = query
72        .mutations
73        .iter()
74        .map(|m| lower_single_mutation(m, &param_names))
75        .collect::<Result<Vec<_>>>()?;
76
77    Ok(MutationIR {
78        name: query.name.clone(),
79        params: query.params.clone(),
80        ops,
81    })
82}
83
84fn lower_single_mutation(
85    mutation: &Mutation,
86    param_names: &HashSet<String>,
87) -> Result<MutationOpIR> {
88    match mutation {
89        Mutation::Insert(insert) => Ok(MutationOpIR::Insert {
90            type_name: insert.type_name.clone(),
91            assignments: insert
92                .assignments
93                .iter()
94                .map(|a| IRAssignment {
95                    property: a.property.clone(),
96                    value: lower_match_value(&a.value, param_names),
97                })
98                .collect(),
99        }),
100        Mutation::Update(update) => Ok(MutationOpIR::Update {
101            type_name: update.type_name.clone(),
102            assignments: update
103                .assignments
104                .iter()
105                .map(|a| IRAssignment {
106                    property: a.property.clone(),
107                    value: lower_match_value(&a.value, param_names),
108                })
109                .collect(),
110            predicate: IRMutationPredicate {
111                property: update.predicate.property.clone(),
112                op: update.predicate.op,
113                value: lower_match_value(&update.predicate.value, param_names),
114            },
115        }),
116        Mutation::Delete(delete) => Ok(MutationOpIR::Delete {
117            type_name: delete.type_name.clone(),
118            predicate: IRMutationPredicate {
119                property: delete.predicate.property.clone(),
120                op: delete.predicate.op,
121                value: lower_match_value(&delete.predicate.value, param_names),
122            },
123        }),
124    }
125}
126
127fn lower_clauses(
128    catalog: &Catalog,
129    clauses: &[Clause],
130    type_ctx: &TypeContext,
131    pipeline: &mut Vec<IROp>,
132    bound_vars: &mut HashSet<String>,
133    param_names: &HashSet<String>,
134) -> Result<()> {
135    // Separate clause types for ordering: bindings first, then traversals, then filters
136    let mut bindings = Vec::new();
137    let mut traversals = Vec::new();
138    let mut filters = Vec::new();
139    let mut negations = Vec::new();
140
141    for clause in clauses {
142        match clause {
143            Clause::Binding(b) => bindings.push(b),
144            Clause::Traversal(t) => traversals.push(t),
145            Clause::Filter(f) => filters.push(f),
146            Clause::Negation(inner) => negations.push(inner),
147        }
148    }
149
150    // ── Determine which bindings are "deferred" ─────────────────────────
151    //
152    // When multiple bindings in the same match clause are connected by
153    // traversals, only the first-declared binding needs a NodeScan; the
154    // rest will be introduced by Expand operations.  Making them all
155    // NodeScans triggers expensive cross-joins followed by cycle-closing
156    // filters.
157    //
158    // Algorithm: build an undirected graph of variables connected by
159    // traversals, then walk connected components in binding declaration
160    // order.  The first binding in each component becomes the root (gets
161    // a NodeScan); all other bindings in the same component are deferred
162    // — their inline filters become post-Expand Filter ops.
163
164    let binding_set: HashSet<&str> = bindings.iter().map(|b| b.variable.as_str()).collect();
165
166    // Build undirected traversal adjacency (variable → neighbours).
167    // Exclude the anonymous wildcard "_" so it cannot falsely bridge
168    // otherwise-independent components.
169    let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
170    for t in &traversals {
171        let src = t.src.as_str();
172        let dst = t.dst.as_str();
173        if src != "_" && dst != "_" {
174            adj.entry(src).or_default().push(dst);
175            adj.entry(dst).or_default().push(src);
176        }
177    }
178
179    // Walk components to find deferred binding variables
180    let mut deferred_set: HashSet<String> = HashSet::new();
181    let mut component_visited: HashSet<&str> = HashSet::new();
182
183    for binding in &bindings {
184        if component_visited.contains(binding.variable.as_str()) {
185            continue;
186        }
187        // BFS from this binding through the traversal graph
188        let mut queue = VecDeque::new();
189        queue.push_back(binding.variable.as_str());
190        let mut component_bindings: Vec<&str> = Vec::new();
191
192        while let Some(var) = queue.pop_front() {
193            if !component_visited.insert(var) {
194                continue;
195            }
196            if binding_set.contains(var) {
197                component_bindings.push(var);
198            }
199            if let Some(neighbours) = adj.get(var) {
200                for &n in neighbours {
201                    if !component_visited.contains(n) {
202                        queue.push_back(n);
203                    }
204                }
205            }
206        }
207
208        // First binding in the component is the root; defer the rest.
209        for var in component_bindings.into_iter().skip(1) {
210            deferred_set.insert(var.to_string());
211        }
212    }
213
214    // Build deferred filters map for variables introduced by traversals
215    let mut deferred_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
216
217    // Lower bindings into NodeScan ops (skip deferred ones)
218    for binding in &bindings {
219        let node_type = catalog
220            .node_types
221            .get(&binding.type_name)
222            .expect("binding type was validated during typecheck");
223
224        let binding_filters = build_binding_filters(binding, node_type, param_names);
225
226        if deferred_set.contains(&binding.variable) {
227            // Save filters for emission after the Expand that introduces
228            // this variable.
229            if !binding_filters.is_empty() {
230                deferred_filters.insert(binding.variable.clone(), binding_filters);
231            }
232            continue;
233        }
234
235        pipeline.push(IROp::NodeScan {
236            variable: binding.variable.clone(),
237            type_name: binding.type_name.clone(),
238            filters: binding_filters,
239        });
240        bound_vars.insert(binding.variable.clone());
241    }
242
243    // Lower traversals into Expand ops.
244    //
245    // Traversals are processed iteratively rather than in a single pass
246    // because deferred bindings mean a traversal's source might not be
247    // bound until a prior traversal introduces it.  Each pass processes
248    // every traversal that has at least one bound endpoint; this repeats
249    // until all traversals are consumed.
250    let mut remaining: Vec<&Traversal> = traversals.to_vec();
251    while !remaining.is_empty() {
252        let mut next_remaining = Vec::new();
253        for traversal in &remaining {
254            let src_bound = bound_vars.contains(&traversal.src);
255            let dst_bound = bound_vars.contains(&traversal.dst);
256            if !src_bound && !dst_bound {
257                next_remaining.push(*traversal);
258                continue;
259            }
260
261            let edge = catalog
262                .lookup_edge_by_name(&traversal.edge_name)
263                .ok_or_else(|| {
264                    crate::error::NanoError::Plan(format!(
265                        "lowering traversal referenced missing edge '{}' after typecheck",
266                        traversal.edge_name
267                    ))
268                })?;
269
270            let direction = type_ctx
271                .traversals
272                .iter()
273                .find(|rt| {
274                    rt.src == traversal.src
275                        && rt.dst == traversal.dst
276                        && rt.edge_type == edge.name
277                })
278                .map(|rt| rt.direction)
279                .unwrap_or(Direction::Out);
280
281            let dst_type = match direction {
282                Direction::Out => edge.to_type.clone(),
283                Direction::In => edge.from_type.clone(),
284            };
285
286            if src_bound && dst_bound {
287                // Cycle closing: expand to a temp var, then filter temp.id = dst.id
288                let temp_var = format!("__temp_{}", traversal.dst);
289                pipeline.push(IROp::Expand {
290                    src_var: traversal.src.clone(),
291                    dst_var: temp_var.clone(),
292                    edge_type: edge.name.clone(),
293                    direction,
294                    dst_type,
295                    min_hops: traversal.min_hops,
296                    max_hops: traversal.max_hops,
297                    dst_filters: vec![],
298                });
299                pipeline.push(IROp::Filter(IRFilter {
300                    left: IRExpr::PropAccess {
301                        variable: temp_var,
302                        property: "id".to_string(),
303                    },
304                    op: CompOp::Eq,
305                    right: IRExpr::PropAccess {
306                        variable: traversal.dst.clone(),
307                        property: "id".to_string(),
308                    },
309                }));
310            } else if !src_bound && dst_bound {
311                // Reverse expand: dst is bound, src is not.
312                let reverse_dir = match direction {
313                    Direction::Out => Direction::In,
314                    Direction::In => Direction::Out,
315                };
316                let src_type = match direction {
317                    Direction::Out => edge.from_type.clone(),
318                    Direction::In => edge.to_type.clone(),
319                };
320                let introduced_filters =
321                    deferred_filters.remove(&traversal.src).unwrap_or_default();
322                pipeline.push(IROp::Expand {
323                    src_var: traversal.dst.clone(),
324                    dst_var: traversal.src.clone(),
325                    edge_type: edge.name.clone(),
326                    direction: reverse_dir,
327                    dst_type: src_type,
328                    min_hops: traversal.min_hops,
329                    max_hops: traversal.max_hops,
330                    dst_filters: introduced_filters,
331                });
332                if traversal.src != "_" {
333                    bound_vars.insert(traversal.src.clone());
334                }
335            } else {
336                // Normal expand: src is bound, dst is not.
337                let introduced_filters =
338                    deferred_filters.remove(&traversal.dst).unwrap_or_default();
339                pipeline.push(IROp::Expand {
340                    src_var: traversal.src.clone(),
341                    dst_var: traversal.dst.clone(),
342                    edge_type: edge.name.clone(),
343                    direction,
344                    dst_type,
345                    min_hops: traversal.min_hops,
346                    max_hops: traversal.max_hops,
347                    dst_filters: introduced_filters,
348                });
349                if traversal.dst != "_" {
350                    bound_vars.insert(traversal.dst.clone());
351                }
352            }
353        }
354        if next_remaining.len() == remaining.len() {
355            break;
356        }
357        remaining = next_remaining;
358    }
359
360    // Lower explicit filters
361    for filter in &filters {
362        pipeline.push(IROp::Filter(IRFilter {
363            left: lower_expr(&filter.left, param_names),
364            op: filter.op,
365            right: lower_expr(&filter.right, param_names),
366        }));
367    }
368
369    // Lower negations into AntiJoin ops
370    for neg_clauses in &negations {
371        // Find outer-bound variable referenced in the negation
372        let outer_var = find_outer_var(neg_clauses, bound_vars);
373
374        let mut inner_pipeline = Vec::new();
375        let mut inner_bound = bound_vars.clone();
376        lower_clauses(
377            catalog,
378            neg_clauses,
379            type_ctx,
380            &mut inner_pipeline,
381            &mut inner_bound,
382            param_names,
383        )?;
384
385        pipeline.push(IROp::AntiJoin {
386            outer_var: outer_var.unwrap_or_default(),
387            inner: inner_pipeline,
388        });
389    }
390
391    Ok(())
392}
393
394/// Build IR filters from a binding's inline property matches.
395fn build_binding_filters(
396    binding: &Binding,
397    node_type: &crate::catalog::NodeType,
398    param_names: &HashSet<String>,
399) -> Vec<IRFilter> {
400    let mut filters = Vec::new();
401    for pm in &binding.prop_matches {
402        let prop = node_type
403            .properties
404            .get(&pm.prop_name)
405            .expect("binding property was validated during typecheck");
406        let op = if prop.list {
407            CompOp::Contains
408        } else {
409            CompOp::Eq
410        };
411        let right = match &pm.value {
412            MatchValue::Literal(lit) => IRExpr::Literal(lit.clone()),
413            MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
414            MatchValue::Variable(v) => {
415                if param_names.contains(v) {
416                    IRExpr::Param(v.clone())
417                } else {
418                    IRExpr::Variable(v.clone())
419                }
420            }
421        };
422        filters.push(IRFilter {
423            left: IRExpr::PropAccess {
424                variable: binding.variable.clone(),
425                property: pm.prop_name.clone(),
426            },
427            op,
428            right,
429        });
430    }
431    filters
432}
433
434fn find_outer_var(clauses: &[Clause], outer_bound: &HashSet<String>) -> Option<String> {
435    for clause in clauses {
436        match clause {
437            Clause::Traversal(t) => {
438                if outer_bound.contains(&t.src) {
439                    return Some(t.src.clone());
440                }
441                if outer_bound.contains(&t.dst) {
442                    return Some(t.dst.clone());
443                }
444            }
445            Clause::Filter(f) => {
446                if let Some(v) = expr_var(&f.left)
447                    && outer_bound.contains(&v)
448                {
449                    return Some(v);
450                }
451                if let Some(v) = expr_var(&f.right)
452                    && outer_bound.contains(&v)
453                {
454                    return Some(v);
455                }
456            }
457            Clause::Binding(b) => {
458                if outer_bound.contains(&b.variable) {
459                    return Some(b.variable.clone());
460                }
461            }
462            _ => {}
463        }
464    }
465    None
466}
467
468fn expr_var(expr: &Expr) -> Option<String> {
469    match expr {
470        Expr::Now => None,
471        Expr::PropAccess { variable, .. } => Some(variable.clone()),
472        Expr::Variable(v) => Some(v.clone()),
473        Expr::Nearest { variable, .. } => Some(variable.clone()),
474        Expr::Search { field, query } => expr_var(field).or_else(|| expr_var(query)),
475        Expr::Fuzzy {
476            field,
477            query,
478            max_edits,
479        } => expr_var(field)
480            .or_else(|| expr_var(query))
481            .or_else(|| max_edits.as_deref().and_then(expr_var)),
482        Expr::MatchText { field, query } => expr_var(field).or_else(|| expr_var(query)),
483        Expr::Bm25 { field, query } => expr_var(field).or_else(|| expr_var(query)),
484        Expr::Rrf {
485            primary,
486            secondary,
487            k,
488        } => expr_var(primary)
489            .or_else(|| expr_var(secondary))
490            .or_else(|| k.as_deref().and_then(expr_var)),
491        Expr::Aggregate { arg, .. } => expr_var(arg),
492        _ => None,
493    }
494}
495
496fn lower_expr(expr: &Expr, param_names: &HashSet<String>) -> IRExpr {
497    match expr {
498        Expr::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
499        Expr::PropAccess { variable, property } => IRExpr::PropAccess {
500            variable: variable.clone(),
501            property: property.clone(),
502        },
503        Expr::Nearest {
504            variable,
505            property,
506            query,
507        } => IRExpr::Nearest {
508            variable: variable.clone(),
509            property: property.clone(),
510            query: Box::new(lower_expr(query, param_names)),
511        },
512        Expr::Search { field, query } => IRExpr::Search {
513            field: Box::new(lower_expr(field, param_names)),
514            query: Box::new(lower_expr(query, param_names)),
515        },
516        Expr::Fuzzy {
517            field,
518            query,
519            max_edits,
520        } => IRExpr::Fuzzy {
521            field: Box::new(lower_expr(field, param_names)),
522            query: Box::new(lower_expr(query, param_names)),
523            max_edits: max_edits
524                .as_ref()
525                .map(|expr| Box::new(lower_expr(expr, param_names))),
526        },
527        Expr::MatchText { field, query } => IRExpr::MatchText {
528            field: Box::new(lower_expr(field, param_names)),
529            query: Box::new(lower_expr(query, param_names)),
530        },
531        Expr::Bm25 { field, query } => IRExpr::Bm25 {
532            field: Box::new(lower_expr(field, param_names)),
533            query: Box::new(lower_expr(query, param_names)),
534        },
535        Expr::Rrf {
536            primary,
537            secondary,
538            k,
539        } => IRExpr::Rrf {
540            primary: Box::new(lower_expr(primary, param_names)),
541            secondary: Box::new(lower_expr(secondary, param_names)),
542            k: k.as_ref()
543                .map(|expr| Box::new(lower_expr(expr, param_names))),
544        },
545        Expr::Variable(v) => {
546            if param_names.contains(v) {
547                IRExpr::Param(v.clone())
548            } else {
549                IRExpr::Variable(v.clone())
550            }
551        }
552        Expr::Literal(l) => IRExpr::Literal(l.clone()),
553        Expr::Aggregate { func, arg } => IRExpr::Aggregate {
554            func: *func,
555            arg: Box::new(lower_expr(arg, param_names)),
556        },
557        Expr::AliasRef(name) => IRExpr::AliasRef(name.clone()),
558    }
559}
560
561fn lower_match_value(value: &MatchValue, param_names: &HashSet<String>) -> IRExpr {
562    match value {
563        MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
564        MatchValue::Literal(l) => IRExpr::Literal(l.clone()),
565        MatchValue::Variable(v) => {
566            if param_names.contains(v) {
567                IRExpr::Param(v.clone())
568            } else {
569                IRExpr::Variable(v.clone())
570            }
571        }
572    }
573}
574
575#[cfg(test)]
576#[path = "lower_tests.rs"]
577mod tests;