1use crate::query::plan::{
10 ExpandOp, FilterOp, LogicalExpression, LogicalOperator, LogicalPlan, NodeScanOp, ReturnItem,
11 ReturnOp, TripleScanOp,
12};
13use graphos_common::types::LogicalType;
14use graphos_common::utils::error::{Error, QueryError, QueryErrorKind, Result};
15use std::collections::HashMap;
16
17fn binding_error(message: impl Into<String>) -> Error {
19 Error::Query(QueryError::new(QueryErrorKind::Semantic, message))
20}
21
22#[derive(Debug, Clone)]
24pub struct VariableInfo {
25 pub name: String,
27 pub data_type: LogicalType,
29 pub is_node: bool,
31 pub is_edge: bool,
33}
34
35#[derive(Debug, Clone, Default)]
37pub struct BindingContext {
38 variables: HashMap<String, VariableInfo>,
40 order: Vec<String>,
42}
43
44impl BindingContext {
45 #[must_use]
47 pub fn new() -> Self {
48 Self {
49 variables: HashMap::new(),
50 order: Vec::new(),
51 }
52 }
53
54 pub fn add_variable(&mut self, name: String, info: VariableInfo) {
56 if !self.variables.contains_key(&name) {
57 self.order.push(name.clone());
58 }
59 self.variables.insert(name, info);
60 }
61
62 #[must_use]
64 pub fn get(&self, name: &str) -> Option<&VariableInfo> {
65 self.variables.get(name)
66 }
67
68 #[must_use]
70 pub fn contains(&self, name: &str) -> bool {
71 self.variables.contains_key(name)
72 }
73
74 #[must_use]
76 pub fn variable_names(&self) -> &[String] {
77 &self.order
78 }
79
80 #[must_use]
82 pub fn len(&self) -> usize {
83 self.variables.len()
84 }
85
86 #[must_use]
88 pub fn is_empty(&self) -> bool {
89 self.variables.is_empty()
90 }
91}
92
93pub struct Binder {
101 context: BindingContext,
103}
104
105impl Binder {
106 #[must_use]
108 pub fn new() -> Self {
109 Self {
110 context: BindingContext::new(),
111 }
112 }
113
114 pub fn bind(&mut self, plan: &LogicalPlan) -> Result<BindingContext> {
120 self.bind_operator(&plan.root)?;
121 Ok(self.context.clone())
122 }
123
124 fn bind_operator(&mut self, op: &LogicalOperator) -> Result<()> {
126 match op {
127 LogicalOperator::NodeScan(scan) => self.bind_node_scan(scan),
128 LogicalOperator::Expand(expand) => self.bind_expand(expand),
129 LogicalOperator::Filter(filter) => self.bind_filter(filter),
130 LogicalOperator::Return(ret) => self.bind_return(ret),
131 LogicalOperator::Project(project) => {
132 self.bind_operator(&project.input)?;
133 for projection in &project.projections {
134 self.validate_expression(&projection.expression)?;
135 }
136 Ok(())
137 }
138 LogicalOperator::Limit(limit) => self.bind_operator(&limit.input),
139 LogicalOperator::Skip(skip) => self.bind_operator(&skip.input),
140 LogicalOperator::Sort(sort) => {
141 self.bind_operator(&sort.input)?;
142 for key in &sort.keys {
143 self.validate_expression(&key.expression)?;
144 }
145 Ok(())
146 }
147 LogicalOperator::CreateNode(create) => {
148 if let Some(ref input) = create.input {
150 self.bind_operator(input)?;
151 }
152 self.context.add_variable(
153 create.variable.clone(),
154 VariableInfo {
155 name: create.variable.clone(),
156 data_type: LogicalType::Node,
157 is_node: true,
158 is_edge: false,
159 },
160 );
161 for (_, expr) in &create.properties {
163 self.validate_expression(expr)?;
164 }
165 Ok(())
166 }
167 LogicalOperator::EdgeScan(scan) => {
168 if let Some(ref input) = scan.input {
169 self.bind_operator(input)?;
170 }
171 self.context.add_variable(
172 scan.variable.clone(),
173 VariableInfo {
174 name: scan.variable.clone(),
175 data_type: LogicalType::Edge,
176 is_node: false,
177 is_edge: true,
178 },
179 );
180 Ok(())
181 }
182 LogicalOperator::Distinct(distinct) => self.bind_operator(&distinct.input),
183 LogicalOperator::Join(join) => self.bind_join(join),
184 LogicalOperator::Aggregate(agg) => self.bind_aggregate(agg),
185 LogicalOperator::CreateEdge(create) => {
186 self.bind_operator(&create.input)?;
187 if !self.context.contains(&create.from_variable) {
189 return Err(binding_error(format!(
190 "Undefined source variable '{}' in CREATE EDGE",
191 create.from_variable
192 )));
193 }
194 if !self.context.contains(&create.to_variable) {
195 return Err(binding_error(format!(
196 "Undefined target variable '{}' in CREATE EDGE",
197 create.to_variable
198 )));
199 }
200 if let Some(ref var) = create.variable {
202 self.context.add_variable(
203 var.clone(),
204 VariableInfo {
205 name: var.clone(),
206 data_type: LogicalType::Edge,
207 is_node: false,
208 is_edge: true,
209 },
210 );
211 }
212 for (_, expr) in &create.properties {
214 self.validate_expression(expr)?;
215 }
216 Ok(())
217 }
218 LogicalOperator::DeleteNode(delete) => {
219 self.bind_operator(&delete.input)?;
220 if !self.context.contains(&delete.variable) {
222 return Err(binding_error(format!(
223 "Undefined variable '{}' in DELETE",
224 delete.variable
225 )));
226 }
227 Ok(())
228 }
229 LogicalOperator::DeleteEdge(delete) => {
230 self.bind_operator(&delete.input)?;
231 if !self.context.contains(&delete.variable) {
233 return Err(binding_error(format!(
234 "Undefined variable '{}' in DELETE",
235 delete.variable
236 )));
237 }
238 Ok(())
239 }
240 LogicalOperator::SetProperty(set) => {
241 self.bind_operator(&set.input)?;
242 if !self.context.contains(&set.variable) {
244 return Err(binding_error(format!(
245 "Undefined variable '{}' in SET",
246 set.variable
247 )));
248 }
249 for (_, expr) in &set.properties {
251 self.validate_expression(expr)?;
252 }
253 Ok(())
254 }
255 LogicalOperator::Empty => Ok(()),
256
257 LogicalOperator::Unwind(unwind) => {
258 self.bind_operator(&unwind.input)?;
260 self.validate_expression(&unwind.expression)?;
262 self.context.add_variable(
264 unwind.variable.clone(),
265 VariableInfo {
266 name: unwind.variable.clone(),
267 data_type: LogicalType::Any, is_node: false,
269 is_edge: false,
270 },
271 );
272 Ok(())
273 }
274
275 LogicalOperator::TripleScan(scan) => self.bind_triple_scan(scan),
277 LogicalOperator::Union(union) => {
278 for input in &union.inputs {
279 self.bind_operator(input)?;
280 }
281 Ok(())
282 }
283 LogicalOperator::LeftJoin(lj) => {
284 self.bind_operator(&lj.left)?;
285 self.bind_operator(&lj.right)?;
286 if let Some(ref cond) = lj.condition {
287 self.validate_expression(cond)?;
288 }
289 Ok(())
290 }
291 LogicalOperator::AntiJoin(aj) => {
292 self.bind_operator(&aj.left)?;
293 self.bind_operator(&aj.right)?;
294 Ok(())
295 }
296 LogicalOperator::Bind(bind) => {
297 self.bind_operator(&bind.input)?;
298 self.validate_expression(&bind.expression)?;
299 self.context.add_variable(
300 bind.variable.clone(),
301 VariableInfo {
302 name: bind.variable.clone(),
303 data_type: LogicalType::Any,
304 is_node: false,
305 is_edge: false,
306 },
307 );
308 Ok(())
309 }
310 LogicalOperator::Merge(merge) => {
311 self.bind_operator(&merge.input)?;
313 for (_, expr) in &merge.match_properties {
315 self.validate_expression(expr)?;
316 }
317 for (_, expr) in &merge.on_create {
319 self.validate_expression(expr)?;
320 }
321 for (_, expr) in &merge.on_match {
323 self.validate_expression(expr)?;
324 }
325 self.context.add_variable(
327 merge.variable.clone(),
328 VariableInfo {
329 name: merge.variable.clone(),
330 data_type: LogicalType::Node,
331 is_node: true,
332 is_edge: false,
333 },
334 );
335 Ok(())
336 }
337 LogicalOperator::AddLabel(add_label) => {
338 self.bind_operator(&add_label.input)?;
339 if !self.context.contains(&add_label.variable) {
341 return Err(binding_error(format!(
342 "Undefined variable '{}' in SET labels",
343 add_label.variable
344 )));
345 }
346 Ok(())
347 }
348 LogicalOperator::RemoveLabel(remove_label) => {
349 self.bind_operator(&remove_label.input)?;
350 if !self.context.contains(&remove_label.variable) {
352 return Err(binding_error(format!(
353 "Undefined variable '{}' in REMOVE labels",
354 remove_label.variable
355 )));
356 }
357 Ok(())
358 }
359 }
360 }
361
362 fn bind_triple_scan(&mut self, scan: &TripleScanOp) -> Result<()> {
364 use crate::query::plan::TripleComponent;
365
366 if let Some(ref input) = scan.input {
368 self.bind_operator(input)?;
369 }
370
371 if let TripleComponent::Variable(name) = &scan.subject {
373 if !self.context.contains(name) {
374 self.context.add_variable(
375 name.clone(),
376 VariableInfo {
377 name: name.clone(),
378 data_type: LogicalType::Any, is_node: false,
380 is_edge: false,
381 },
382 );
383 }
384 }
385
386 if let TripleComponent::Variable(name) = &scan.predicate {
387 if !self.context.contains(name) {
388 self.context.add_variable(
389 name.clone(),
390 VariableInfo {
391 name: name.clone(),
392 data_type: LogicalType::Any, is_node: false,
394 is_edge: false,
395 },
396 );
397 }
398 }
399
400 if let TripleComponent::Variable(name) = &scan.object {
401 if !self.context.contains(name) {
402 self.context.add_variable(
403 name.clone(),
404 VariableInfo {
405 name: name.clone(),
406 data_type: LogicalType::Any, is_node: false,
408 is_edge: false,
409 },
410 );
411 }
412 }
413
414 if let Some(TripleComponent::Variable(name)) = &scan.graph {
415 if !self.context.contains(name) {
416 self.context.add_variable(
417 name.clone(),
418 VariableInfo {
419 name: name.clone(),
420 data_type: LogicalType::Any, is_node: false,
422 is_edge: false,
423 },
424 );
425 }
426 }
427
428 Ok(())
429 }
430
431 fn bind_node_scan(&mut self, scan: &NodeScanOp) -> Result<()> {
433 if let Some(ref input) = scan.input {
435 self.bind_operator(input)?;
436 }
437
438 self.context.add_variable(
440 scan.variable.clone(),
441 VariableInfo {
442 name: scan.variable.clone(),
443 data_type: LogicalType::Node,
444 is_node: true,
445 is_edge: false,
446 },
447 );
448
449 Ok(())
450 }
451
452 fn bind_expand(&mut self, expand: &ExpandOp) -> Result<()> {
454 self.bind_operator(&expand.input)?;
456
457 if !self.context.contains(&expand.from_variable) {
459 return Err(binding_error(format!(
460 "Undefined variable '{}' in EXPAND",
461 expand.from_variable
462 )));
463 }
464
465 if let Some(info) = self.context.get(&expand.from_variable) {
467 if !info.is_node {
468 return Err(binding_error(format!(
469 "Variable '{}' is not a node, cannot expand from it",
470 expand.from_variable
471 )));
472 }
473 }
474
475 if let Some(ref edge_var) = expand.edge_variable {
477 self.context.add_variable(
478 edge_var.clone(),
479 VariableInfo {
480 name: edge_var.clone(),
481 data_type: LogicalType::Edge,
482 is_node: false,
483 is_edge: true,
484 },
485 );
486 }
487
488 self.context.add_variable(
490 expand.to_variable.clone(),
491 VariableInfo {
492 name: expand.to_variable.clone(),
493 data_type: LogicalType::Node,
494 is_node: true,
495 is_edge: false,
496 },
497 );
498
499 Ok(())
500 }
501
502 fn bind_filter(&mut self, filter: &FilterOp) -> Result<()> {
504 self.bind_operator(&filter.input)?;
506
507 self.validate_expression(&filter.predicate)?;
509
510 Ok(())
511 }
512
513 fn bind_return(&mut self, ret: &ReturnOp) -> Result<()> {
515 self.bind_operator(&ret.input)?;
517
518 for item in &ret.items {
520 self.validate_return_item(item)?;
521 }
522
523 Ok(())
524 }
525
526 fn validate_return_item(&self, item: &ReturnItem) -> Result<()> {
528 self.validate_expression(&item.expression)
529 }
530
531 fn validate_expression(&self, expr: &LogicalExpression) -> Result<()> {
533 match expr {
534 LogicalExpression::Variable(name) => {
535 if !self.context.contains(name) && !name.starts_with("_anon_") {
536 return Err(binding_error(format!("Undefined variable '{name}'")));
537 }
538 Ok(())
539 }
540 LogicalExpression::Property { variable, .. } => {
541 if !self.context.contains(variable) && !variable.starts_with("_anon_") {
542 return Err(binding_error(format!(
543 "Undefined variable '{variable}' in property access"
544 )));
545 }
546 Ok(())
547 }
548 LogicalExpression::Literal(_) => Ok(()),
549 LogicalExpression::Binary { left, right, .. } => {
550 self.validate_expression(left)?;
551 self.validate_expression(right)
552 }
553 LogicalExpression::Unary { operand, .. } => self.validate_expression(operand),
554 LogicalExpression::FunctionCall { args, .. } => {
555 for arg in args {
556 self.validate_expression(arg)?;
557 }
558 Ok(())
559 }
560 LogicalExpression::List(items) => {
561 for item in items {
562 self.validate_expression(item)?;
563 }
564 Ok(())
565 }
566 LogicalExpression::Map(pairs) => {
567 for (_, value) in pairs {
568 self.validate_expression(value)?;
569 }
570 Ok(())
571 }
572 LogicalExpression::IndexAccess { base, index } => {
573 self.validate_expression(base)?;
574 self.validate_expression(index)
575 }
576 LogicalExpression::SliceAccess { base, start, end } => {
577 self.validate_expression(base)?;
578 if let Some(s) = start {
579 self.validate_expression(s)?;
580 }
581 if let Some(e) = end {
582 self.validate_expression(e)?;
583 }
584 Ok(())
585 }
586 LogicalExpression::Case {
587 operand,
588 when_clauses,
589 else_clause,
590 } => {
591 if let Some(op) = operand {
592 self.validate_expression(op)?;
593 }
594 for (cond, result) in when_clauses {
595 self.validate_expression(cond)?;
596 self.validate_expression(result)?;
597 }
598 if let Some(else_expr) = else_clause {
599 self.validate_expression(else_expr)?;
600 }
601 Ok(())
602 }
603 LogicalExpression::Parameter(_) => Ok(()),
605 LogicalExpression::Labels(var)
607 | LogicalExpression::Type(var)
608 | LogicalExpression::Id(var) => {
609 if !self.context.contains(var) && !var.starts_with("_anon_") {
610 return Err(binding_error(format!(
611 "Undefined variable '{var}' in function"
612 )));
613 }
614 Ok(())
615 }
616 LogicalExpression::ListComprehension {
617 list_expr,
618 filter_expr,
619 map_expr,
620 ..
621 } => {
622 self.validate_expression(list_expr)?;
624 if let Some(filter) = filter_expr {
628 self.validate_expression(filter)?;
629 }
630 self.validate_expression(map_expr)?;
631 Ok(())
632 }
633 LogicalExpression::ExistsSubquery(subquery)
634 | LogicalExpression::CountSubquery(subquery) => {
635 let _ = subquery; Ok(())
639 }
640 }
641 }
642
643 fn bind_join(&mut self, join: &crate::query::plan::JoinOp) -> Result<()> {
645 self.bind_operator(&join.left)?;
647 self.bind_operator(&join.right)?;
648
649 for condition in &join.conditions {
651 self.validate_expression(&condition.left)?;
652 self.validate_expression(&condition.right)?;
653 }
654
655 Ok(())
656 }
657
658 fn bind_aggregate(&mut self, agg: &crate::query::plan::AggregateOp) -> Result<()> {
660 self.bind_operator(&agg.input)?;
662
663 for expr in &agg.group_by {
665 self.validate_expression(expr)?;
666 }
667
668 for agg_expr in &agg.aggregates {
670 if let Some(ref expr) = agg_expr.expression {
671 self.validate_expression(expr)?;
672 }
673 if let Some(ref alias) = agg_expr.alias {
675 self.context.add_variable(
676 alias.clone(),
677 VariableInfo {
678 name: alias.clone(),
679 data_type: LogicalType::Any,
680 is_node: false,
681 is_edge: false,
682 },
683 );
684 }
685 }
686
687 Ok(())
688 }
689}
690
691impl Default for Binder {
692 fn default() -> Self {
693 Self::new()
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{BinaryOp, FilterOp};

    /// Shorthand for a variable reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Leaf node scan binding `variable`, optionally restricted to `label`.
    fn node_scan(variable: &str, label: Option<&str>) -> LogicalOperator {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: label.map(str::to_string),
            input: None,
        })
    }

    /// Plan whose root returns `expressions` (un-aliased, non-distinct) over
    /// `input`.
    fn returning(expressions: Vec<LogicalExpression>, input: LogicalOperator) -> LogicalPlan {
        let items = expressions
            .into_iter()
            .map(|expression| ReturnItem {
                expression,
                alias: None,
            })
            .collect();

        LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items,
            distinct: false,
            input: Box::new(input),
        }))
    }

    #[test]
    fn test_bind_simple_scan() {
        let plan = returning(vec![var("n")], node_scan("n", Some("Person")));

        let result = Binder::new().bind(&plan);

        assert!(result.is_ok());
        let ctx = result.unwrap();
        assert!(ctx.contains("n"));
        assert!(ctx.get("n").unwrap().is_node);
    }

    #[test]
    fn test_bind_undefined_variable() {
        let plan = returning(vec![var("undefined")], node_scan("n", None));

        let result = Binder::new().bind(&plan);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Undefined variable"));
    }

    #[test]
    fn test_bind_property_access() {
        let name_of_n = LogicalExpression::Property {
            variable: "n".to_string(),
            property: "name".to_string(),
        };
        let plan = returning(vec![name_of_n], node_scan("n", Some("Person")));

        let result = Binder::new().bind(&plan);

        assert!(result.is_ok());
    }

    #[test]
    fn test_bind_filter_with_undefined_variable() {
        // The predicate refers to `m`, which the scan never binds.
        let predicate = LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "m".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(
                graphos_common::types::Value::Int64(30),
            )),
        };
        let filter = LogicalOperator::Filter(FilterOp {
            predicate,
            input: Box::new(node_scan("n", None)),
        });
        let plan = returning(vec![var("n")], filter);

        let result = Binder::new().bind(&plan);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Undefined variable 'm'"));
    }

    #[test]
    fn test_bind_expand() {
        use crate::query::plan::{ExpandDirection, ExpandOp};

        let expand = LogicalOperator::Expand(ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: Some("e".to_string()),
            direction: ExpandDirection::Outgoing,
            edge_type: Some("KNOWS".to_string()),
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(node_scan("a", Some("Person"))),
        });
        let plan = returning(vec![var("a"), var("b")], expand);

        let result = Binder::new().bind(&plan);

        assert!(result.is_ok());
        let ctx = result.unwrap();
        for name in ["a", "b", "e"] {
            assert!(ctx.contains(name));
        }
        assert!(ctx.get("a").unwrap().is_node);
        assert!(ctx.get("b").unwrap().is_node);
        assert!(ctx.get("e").unwrap().is_edge);
    }
}