1use std::collections::{HashMap, HashSet, VecDeque};
2
3use crate::catalog::Catalog;
4use crate::error::Result;
5use crate::query::ast::*;
6use crate::query::typecheck::TypeContext;
7use crate::types::Direction;
8
9use super::*;
10
11pub fn lower_query(
12 catalog: &Catalog,
13 query: &QueryDecl,
14 type_ctx: &TypeContext,
15) -> Result<QueryIR> {
16 if !query.mutations.is_empty() {
17 return Err(crate::error::NanoError::Plan(
18 "cannot lower mutation query with read-query lowerer".to_string(),
19 ));
20 }
21 let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
22
23 let mut pipeline = Vec::new();
24 let mut bound_vars = HashSet::new();
25
26 lower_clauses(
27 catalog,
28 &query.match_clause,
29 type_ctx,
30 &mut pipeline,
31 &mut bound_vars,
32 ¶m_names,
33 )?;
34
35 let return_exprs: Vec<IRProjection> = query
36 .return_clause
37 .iter()
38 .map(|p| IRProjection {
39 expr: lower_expr(&p.expr, ¶m_names),
40 alias: p.alias.clone(),
41 })
42 .collect();
43
44 let order_by: Vec<IROrdering> = query
45 .order_clause
46 .iter()
47 .map(|o| IROrdering {
48 expr: lower_expr(&o.expr, ¶m_names),
49 descending: o.descending,
50 })
51 .collect();
52
53 Ok(QueryIR {
54 name: query.name.clone(),
55 params: query.params.clone(),
56 pipeline,
57 return_exprs,
58 order_by,
59 limit: query.limit,
60 })
61}
62
63pub fn lower_mutation_query(query: &QueryDecl) -> Result<MutationIR> {
64 if query.mutations.is_empty() {
65 return Err(crate::error::NanoError::Plan(
66 "query does not contain a mutation body".to_string(),
67 ));
68 }
69 let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
70
71 let ops = query
72 .mutations
73 .iter()
74 .map(|m| lower_single_mutation(m, ¶m_names))
75 .collect::<Result<Vec<_>>>()?;
76
77 Ok(MutationIR {
78 name: query.name.clone(),
79 params: query.params.clone(),
80 ops,
81 })
82}
83
84fn lower_single_mutation(
85 mutation: &Mutation,
86 param_names: &HashSet<String>,
87) -> Result<MutationOpIR> {
88 match mutation {
89 Mutation::Insert(insert) => Ok(MutationOpIR::Insert {
90 type_name: insert.type_name.clone(),
91 assignments: insert
92 .assignments
93 .iter()
94 .map(|a| IRAssignment {
95 property: a.property.clone(),
96 value: lower_match_value(&a.value, param_names),
97 })
98 .collect(),
99 }),
100 Mutation::Update(update) => Ok(MutationOpIR::Update {
101 type_name: update.type_name.clone(),
102 assignments: update
103 .assignments
104 .iter()
105 .map(|a| IRAssignment {
106 property: a.property.clone(),
107 value: lower_match_value(&a.value, param_names),
108 })
109 .collect(),
110 predicate: IRMutationPredicate {
111 property: update.predicate.property.clone(),
112 op: update.predicate.op,
113 value: lower_match_value(&update.predicate.value, param_names),
114 },
115 }),
116 Mutation::Delete(delete) => Ok(MutationOpIR::Delete {
117 type_name: delete.type_name.clone(),
118 predicate: IRMutationPredicate {
119 property: delete.predicate.property.clone(),
120 op: delete.predicate.op,
121 value: lower_match_value(&delete.predicate.value, param_names),
122 },
123 }),
124 }
125}
126
127fn lower_clauses(
128 catalog: &Catalog,
129 clauses: &[Clause],
130 type_ctx: &TypeContext,
131 pipeline: &mut Vec<IROp>,
132 bound_vars: &mut HashSet<String>,
133 param_names: &HashSet<String>,
134) -> Result<()> {
135 let mut bindings = Vec::new();
137 let mut traversals = Vec::new();
138 let mut filters = Vec::new();
139 let mut negations = Vec::new();
140
141 for clause in clauses {
142 match clause {
143 Clause::Binding(b) => bindings.push(b),
144 Clause::Traversal(t) => traversals.push(t),
145 Clause::Filter(f) => filters.push(f),
146 Clause::Negation(inner) => negations.push(inner),
147 }
148 }
149
150 let binding_set: HashSet<&str> = bindings.iter().map(|b| b.variable.as_str()).collect();
165
166 let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
170 for t in &traversals {
171 let src = t.src.as_str();
172 let dst = t.dst.as_str();
173 if src != "_" && dst != "_" {
174 adj.entry(src).or_default().push(dst);
175 adj.entry(dst).or_default().push(src);
176 }
177 }
178
179 let mut deferred_set: HashSet<String> = HashSet::new();
181 let mut component_visited: HashSet<&str> = HashSet::new();
182
183 for binding in &bindings {
184 if component_visited.contains(binding.variable.as_str()) {
185 continue;
186 }
187 let mut queue = VecDeque::new();
189 queue.push_back(binding.variable.as_str());
190 let mut component_bindings: Vec<&str> = Vec::new();
191
192 while let Some(var) = queue.pop_front() {
193 if !component_visited.insert(var) {
194 continue;
195 }
196 if binding_set.contains(var) {
197 component_bindings.push(var);
198 }
199 if let Some(neighbours) = adj.get(var) {
200 for &n in neighbours {
201 if !component_visited.contains(n) {
202 queue.push_back(n);
203 }
204 }
205 }
206 }
207
208 for var in component_bindings.into_iter().skip(1) {
210 deferred_set.insert(var.to_string());
211 }
212 }
213
214 let mut deferred_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
216
217 for binding in &bindings {
219 let node_type = catalog
220 .node_types
221 .get(&binding.type_name)
222 .expect("binding type was validated during typecheck");
223
224 let binding_filters = build_binding_filters(binding, node_type, param_names);
225
226 if deferred_set.contains(&binding.variable) {
227 if !binding_filters.is_empty() {
230 deferred_filters.insert(binding.variable.clone(), binding_filters);
231 }
232 continue;
233 }
234
235 pipeline.push(IROp::NodeScan {
236 variable: binding.variable.clone(),
237 type_name: binding.type_name.clone(),
238 filters: binding_filters,
239 });
240 bound_vars.insert(binding.variable.clone());
241 }
242
243 let mut remaining: Vec<&Traversal> = traversals.to_vec();
251 while !remaining.is_empty() {
252 let mut next_remaining = Vec::new();
253 for traversal in &remaining {
254 let src_bound = bound_vars.contains(&traversal.src);
255 let dst_bound = bound_vars.contains(&traversal.dst);
256 if !src_bound && !dst_bound {
257 next_remaining.push(*traversal);
258 continue;
259 }
260
261 let edge = catalog
262 .lookup_edge_by_name(&traversal.edge_name)
263 .ok_or_else(|| {
264 crate::error::NanoError::Plan(format!(
265 "lowering traversal referenced missing edge '{}' after typecheck",
266 traversal.edge_name
267 ))
268 })?;
269
270 let direction = type_ctx
271 .traversals
272 .iter()
273 .find(|rt| {
274 rt.src == traversal.src
275 && rt.dst == traversal.dst
276 && rt.edge_type == edge.name
277 })
278 .map(|rt| rt.direction)
279 .unwrap_or(Direction::Out);
280
281 let dst_type = match direction {
282 Direction::Out => edge.to_type.clone(),
283 Direction::In => edge.from_type.clone(),
284 };
285
286 if src_bound && dst_bound {
287 let temp_var = format!("__temp_{}", traversal.dst);
289 pipeline.push(IROp::Expand {
290 src_var: traversal.src.clone(),
291 dst_var: temp_var.clone(),
292 edge_type: edge.name.clone(),
293 direction,
294 dst_type,
295 min_hops: traversal.min_hops,
296 max_hops: traversal.max_hops,
297 dst_filters: vec![],
298 });
299 pipeline.push(IROp::Filter(IRFilter {
300 left: IRExpr::PropAccess {
301 variable: temp_var,
302 property: "id".to_string(),
303 },
304 op: CompOp::Eq,
305 right: IRExpr::PropAccess {
306 variable: traversal.dst.clone(),
307 property: "id".to_string(),
308 },
309 }));
310 } else if !src_bound && dst_bound {
311 let reverse_dir = match direction {
313 Direction::Out => Direction::In,
314 Direction::In => Direction::Out,
315 };
316 let src_type = match direction {
317 Direction::Out => edge.from_type.clone(),
318 Direction::In => edge.to_type.clone(),
319 };
320 let introduced_filters =
321 deferred_filters.remove(&traversal.src).unwrap_or_default();
322 pipeline.push(IROp::Expand {
323 src_var: traversal.dst.clone(),
324 dst_var: traversal.src.clone(),
325 edge_type: edge.name.clone(),
326 direction: reverse_dir,
327 dst_type: src_type,
328 min_hops: traversal.min_hops,
329 max_hops: traversal.max_hops,
330 dst_filters: introduced_filters,
331 });
332 if traversal.src != "_" {
333 bound_vars.insert(traversal.src.clone());
334 }
335 } else {
336 let introduced_filters =
338 deferred_filters.remove(&traversal.dst).unwrap_or_default();
339 pipeline.push(IROp::Expand {
340 src_var: traversal.src.clone(),
341 dst_var: traversal.dst.clone(),
342 edge_type: edge.name.clone(),
343 direction,
344 dst_type,
345 min_hops: traversal.min_hops,
346 max_hops: traversal.max_hops,
347 dst_filters: introduced_filters,
348 });
349 if traversal.dst != "_" {
350 bound_vars.insert(traversal.dst.clone());
351 }
352 }
353 }
354 if next_remaining.len() == remaining.len() {
355 break;
356 }
357 remaining = next_remaining;
358 }
359
360 for filter in &filters {
362 pipeline.push(IROp::Filter(IRFilter {
363 left: lower_expr(&filter.left, param_names),
364 op: filter.op,
365 right: lower_expr(&filter.right, param_names),
366 }));
367 }
368
369 for neg_clauses in &negations {
371 let outer_var = find_outer_var(neg_clauses, bound_vars);
373
374 let mut inner_pipeline = Vec::new();
375 let mut inner_bound = bound_vars.clone();
376 lower_clauses(
377 catalog,
378 neg_clauses,
379 type_ctx,
380 &mut inner_pipeline,
381 &mut inner_bound,
382 param_names,
383 )?;
384
385 pipeline.push(IROp::AntiJoin {
386 outer_var: outer_var.unwrap_or_default(),
387 inner: inner_pipeline,
388 });
389 }
390
391 Ok(())
392}
393
394fn build_binding_filters(
396 binding: &Binding,
397 node_type: &crate::catalog::NodeType,
398 param_names: &HashSet<String>,
399) -> Vec<IRFilter> {
400 let mut filters = Vec::new();
401 for pm in &binding.prop_matches {
402 let prop = node_type
403 .properties
404 .get(&pm.prop_name)
405 .expect("binding property was validated during typecheck");
406 let op = if prop.list {
407 CompOp::Contains
408 } else {
409 CompOp::Eq
410 };
411 let right = match &pm.value {
412 MatchValue::Literal(lit) => IRExpr::Literal(lit.clone()),
413 MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
414 MatchValue::Variable(v) => {
415 if param_names.contains(v) {
416 IRExpr::Param(v.clone())
417 } else {
418 IRExpr::Variable(v.clone())
419 }
420 }
421 };
422 filters.push(IRFilter {
423 left: IRExpr::PropAccess {
424 variable: binding.variable.clone(),
425 property: pm.prop_name.clone(),
426 },
427 op,
428 right,
429 });
430 }
431 filters
432}
433
434fn find_outer_var(clauses: &[Clause], outer_bound: &HashSet<String>) -> Option<String> {
435 for clause in clauses {
436 match clause {
437 Clause::Traversal(t) => {
438 if outer_bound.contains(&t.src) {
439 return Some(t.src.clone());
440 }
441 if outer_bound.contains(&t.dst) {
442 return Some(t.dst.clone());
443 }
444 }
445 Clause::Filter(f) => {
446 if let Some(v) = expr_var(&f.left)
447 && outer_bound.contains(&v)
448 {
449 return Some(v);
450 }
451 if let Some(v) = expr_var(&f.right)
452 && outer_bound.contains(&v)
453 {
454 return Some(v);
455 }
456 }
457 Clause::Binding(b) => {
458 if outer_bound.contains(&b.variable) {
459 return Some(b.variable.clone());
460 }
461 }
462 _ => {}
463 }
464 }
465 None
466}
467
468fn expr_var(expr: &Expr) -> Option<String> {
469 match expr {
470 Expr::Now => None,
471 Expr::PropAccess { variable, .. } => Some(variable.clone()),
472 Expr::Variable(v) => Some(v.clone()),
473 Expr::Nearest { variable, .. } => Some(variable.clone()),
474 Expr::Search { field, query } => expr_var(field).or_else(|| expr_var(query)),
475 Expr::Fuzzy {
476 field,
477 query,
478 max_edits,
479 } => expr_var(field)
480 .or_else(|| expr_var(query))
481 .or_else(|| max_edits.as_deref().and_then(expr_var)),
482 Expr::MatchText { field, query } => expr_var(field).or_else(|| expr_var(query)),
483 Expr::Bm25 { field, query } => expr_var(field).or_else(|| expr_var(query)),
484 Expr::Rrf {
485 primary,
486 secondary,
487 k,
488 } => expr_var(primary)
489 .or_else(|| expr_var(secondary))
490 .or_else(|| k.as_deref().and_then(expr_var)),
491 Expr::Aggregate { arg, .. } => expr_var(arg),
492 _ => None,
493 }
494}
495
496fn lower_expr(expr: &Expr, param_names: &HashSet<String>) -> IRExpr {
497 match expr {
498 Expr::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
499 Expr::PropAccess { variable, property } => IRExpr::PropAccess {
500 variable: variable.clone(),
501 property: property.clone(),
502 },
503 Expr::Nearest {
504 variable,
505 property,
506 query,
507 } => IRExpr::Nearest {
508 variable: variable.clone(),
509 property: property.clone(),
510 query: Box::new(lower_expr(query, param_names)),
511 },
512 Expr::Search { field, query } => IRExpr::Search {
513 field: Box::new(lower_expr(field, param_names)),
514 query: Box::new(lower_expr(query, param_names)),
515 },
516 Expr::Fuzzy {
517 field,
518 query,
519 max_edits,
520 } => IRExpr::Fuzzy {
521 field: Box::new(lower_expr(field, param_names)),
522 query: Box::new(lower_expr(query, param_names)),
523 max_edits: max_edits
524 .as_ref()
525 .map(|expr| Box::new(lower_expr(expr, param_names))),
526 },
527 Expr::MatchText { field, query } => IRExpr::MatchText {
528 field: Box::new(lower_expr(field, param_names)),
529 query: Box::new(lower_expr(query, param_names)),
530 },
531 Expr::Bm25 { field, query } => IRExpr::Bm25 {
532 field: Box::new(lower_expr(field, param_names)),
533 query: Box::new(lower_expr(query, param_names)),
534 },
535 Expr::Rrf {
536 primary,
537 secondary,
538 k,
539 } => IRExpr::Rrf {
540 primary: Box::new(lower_expr(primary, param_names)),
541 secondary: Box::new(lower_expr(secondary, param_names)),
542 k: k.as_ref()
543 .map(|expr| Box::new(lower_expr(expr, param_names))),
544 },
545 Expr::Variable(v) => {
546 if param_names.contains(v) {
547 IRExpr::Param(v.clone())
548 } else {
549 IRExpr::Variable(v.clone())
550 }
551 }
552 Expr::Literal(l) => IRExpr::Literal(l.clone()),
553 Expr::Aggregate { func, arg } => IRExpr::Aggregate {
554 func: *func,
555 arg: Box::new(lower_expr(arg, param_names)),
556 },
557 Expr::AliasRef(name) => IRExpr::AliasRef(name.clone()),
558 }
559}
560
561fn lower_match_value(value: &MatchValue, param_names: &HashSet<String>) -> IRExpr {
562 match value {
563 MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
564 MatchValue::Literal(l) => IRExpr::Literal(l.clone()),
565 MatchValue::Variable(v) => {
566 if param_names.contains(v) {
567 IRExpr::Param(v.clone())
568 } else {
569 IRExpr::Variable(v.clone())
570 }
571 }
572 }
573}
574
// Unit tests for this lowering module are kept in the sibling file
// `lower_tests.rs` and compiled only for test builds.
#[cfg(test)]
#[path = "lower_tests.rs"]
mod tests;