Skip to main content

grafeo_engine/query/
binder.rs

1//! Semantic validation - catching errors before execution.
2//!
3//! The binder walks the logical plan and validates that everything makes sense:
4//! - Is that variable actually defined? (You can't use `RETURN x` if `x` wasn't matched)
5//! - Does that property access make sense? (Accessing `.age` on an integer fails)
6//! - Are types compatible? (Can't compare a string to an integer)
7//!
8//! Better to catch these errors early than waste time executing a broken query.
9
10use crate::query::plan::{
11    ExpandOp, FilterOp, LogicalExpression, LogicalOperator, LogicalPlan, NodeScanOp, ReturnItem,
12    ReturnOp, TripleScanOp,
13};
14use grafeo_common::types::LogicalType;
15use grafeo_common::utils::error::{Error, QueryError, QueryErrorKind, Result};
16use grafeo_common::utils::strings::{find_similar, format_suggestion};
17use std::collections::HashMap;
18
19/// Creates a semantic binding error.
20fn binding_error(message: impl Into<String>) -> Error {
21    Error::Query(QueryError::new(QueryErrorKind::Semantic, message))
22}
23
24/// Creates a semantic binding error with a hint.
25fn binding_error_with_hint(message: impl Into<String>, hint: impl Into<String>) -> Error {
26    Error::Query(QueryError::new(QueryErrorKind::Semantic, message).with_hint(hint))
27}
28
29/// Creates an "undefined variable" error with a suggestion if a similar variable exists.
30fn undefined_variable_error(variable: &str, context: &BindingContext, suffix: &str) -> Error {
31    let candidates: Vec<String> = context.variable_names().to_vec();
32    let candidates_ref: Vec<&str> = candidates.iter().map(|s| s.as_str()).collect();
33
34    if let Some(suggestion) = find_similar(variable, &candidates_ref) {
35        binding_error_with_hint(
36            format!("Undefined variable '{variable}'{suffix}"),
37            format_suggestion(suggestion),
38        )
39    } else {
40        binding_error(format!("Undefined variable '{variable}'{suffix}"))
41    }
42}
43
44/// Information about a bound variable.
45#[derive(Debug, Clone)]
46pub struct VariableInfo {
47    /// The name of the variable.
48    pub name: String,
49    /// The inferred type of the variable.
50    pub data_type: LogicalType,
51    /// Whether this variable is a node.
52    pub is_node: bool,
53    /// Whether this variable is an edge.
54    pub is_edge: bool,
55}
56
57/// Context containing all bound variables and their information.
58#[derive(Debug, Clone, Default)]
59pub struct BindingContext {
60    /// Map from variable name to its info.
61    variables: HashMap<String, VariableInfo>,
62    /// Variables in order of definition.
63    order: Vec<String>,
64}
65
66impl BindingContext {
67    /// Creates a new empty binding context.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            variables: HashMap::new(),
72            order: Vec::new(),
73        }
74    }
75
76    /// Adds a variable to the context.
77    pub fn add_variable(&mut self, name: String, info: VariableInfo) {
78        if !self.variables.contains_key(&name) {
79            self.order.push(name.clone());
80        }
81        self.variables.insert(name, info);
82    }
83
84    /// Looks up a variable by name.
85    #[must_use]
86    pub fn get(&self, name: &str) -> Option<&VariableInfo> {
87        self.variables.get(name)
88    }
89
90    /// Checks if a variable is defined.
91    #[must_use]
92    pub fn contains(&self, name: &str) -> bool {
93        self.variables.contains_key(name)
94    }
95
96    /// Returns all variable names in definition order.
97    #[must_use]
98    pub fn variable_names(&self) -> &[String] {
99        &self.order
100    }
101
102    /// Returns the number of bound variables.
103    #[must_use]
104    pub fn len(&self) -> usize {
105        self.variables.len()
106    }
107
108    /// Returns true if no variables are bound.
109    #[must_use]
110    pub fn is_empty(&self) -> bool {
111        self.variables.is_empty()
112    }
113}
114
115/// Semantic binder for query plans.
116///
117/// The binder walks the logical plan and:
118/// 1. Collects all variable definitions
119/// 2. Validates that all variable references are valid
120/// 3. Infers types where possible
121/// 4. Reports semantic errors
122pub struct Binder {
123    /// The current binding context.
124    context: BindingContext,
125}
126
127impl Binder {
128    /// Creates a new binder.
129    #[must_use]
130    pub fn new() -> Self {
131        Self {
132            context: BindingContext::new(),
133        }
134    }
135
136    /// Binds a logical plan, returning the binding context.
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if semantic validation fails.
141    pub fn bind(&mut self, plan: &LogicalPlan) -> Result<BindingContext> {
142        self.bind_operator(&plan.root)?;
143        Ok(self.context.clone())
144    }
145
146    /// Binds a single logical operator.
147    fn bind_operator(&mut self, op: &LogicalOperator) -> Result<()> {
148        match op {
149            LogicalOperator::NodeScan(scan) => self.bind_node_scan(scan),
150            LogicalOperator::Expand(expand) => self.bind_expand(expand),
151            LogicalOperator::Filter(filter) => self.bind_filter(filter),
152            LogicalOperator::Return(ret) => self.bind_return(ret),
153            LogicalOperator::Project(project) => {
154                self.bind_operator(&project.input)?;
155                for projection in &project.projections {
156                    self.validate_expression(&projection.expression)?;
157                    // Add the projection alias to the context (for WITH clause support)
158                    if let Some(ref alias) = projection.alias {
159                        // Determine the type from the expression
160                        let data_type = self.infer_expression_type(&projection.expression);
161                        self.context.add_variable(
162                            alias.clone(),
163                            VariableInfo {
164                                name: alias.clone(),
165                                data_type,
166                                is_node: false,
167                                is_edge: false,
168                            },
169                        );
170                    }
171                }
172                Ok(())
173            }
174            LogicalOperator::Limit(limit) => self.bind_operator(&limit.input),
175            LogicalOperator::Skip(skip) => self.bind_operator(&skip.input),
176            LogicalOperator::Sort(sort) => {
177                self.bind_operator(&sort.input)?;
178                for key in &sort.keys {
179                    self.validate_expression(&key.expression)?;
180                }
181                Ok(())
182            }
183            LogicalOperator::CreateNode(create) => {
184                // CreateNode introduces a new variable
185                if let Some(ref input) = create.input {
186                    self.bind_operator(input)?;
187                }
188                self.context.add_variable(
189                    create.variable.clone(),
190                    VariableInfo {
191                        name: create.variable.clone(),
192                        data_type: LogicalType::Node,
193                        is_node: true,
194                        is_edge: false,
195                    },
196                );
197                // Validate property expressions
198                for (_, expr) in &create.properties {
199                    self.validate_expression(expr)?;
200                }
201                Ok(())
202            }
203            LogicalOperator::EdgeScan(scan) => {
204                if let Some(ref input) = scan.input {
205                    self.bind_operator(input)?;
206                }
207                self.context.add_variable(
208                    scan.variable.clone(),
209                    VariableInfo {
210                        name: scan.variable.clone(),
211                        data_type: LogicalType::Edge,
212                        is_node: false,
213                        is_edge: true,
214                    },
215                );
216                Ok(())
217            }
218            LogicalOperator::Distinct(distinct) => self.bind_operator(&distinct.input),
219            LogicalOperator::Join(join) => self.bind_join(join),
220            LogicalOperator::Aggregate(agg) => self.bind_aggregate(agg),
221            LogicalOperator::CreateEdge(create) => {
222                self.bind_operator(&create.input)?;
223                // Validate that source and target variables are defined
224                if !self.context.contains(&create.from_variable) {
225                    return Err(undefined_variable_error(
226                        &create.from_variable,
227                        &self.context,
228                        " (source in CREATE EDGE)",
229                    ));
230                }
231                if !self.context.contains(&create.to_variable) {
232                    return Err(undefined_variable_error(
233                        &create.to_variable,
234                        &self.context,
235                        " (target in CREATE EDGE)",
236                    ));
237                }
238                // Add edge variable if present
239                if let Some(ref var) = create.variable {
240                    self.context.add_variable(
241                        var.clone(),
242                        VariableInfo {
243                            name: var.clone(),
244                            data_type: LogicalType::Edge,
245                            is_node: false,
246                            is_edge: true,
247                        },
248                    );
249                }
250                // Validate property expressions
251                for (_, expr) in &create.properties {
252                    self.validate_expression(expr)?;
253                }
254                Ok(())
255            }
256            LogicalOperator::DeleteNode(delete) => {
257                self.bind_operator(&delete.input)?;
258                // Validate that the variable to delete is defined
259                if !self.context.contains(&delete.variable) {
260                    return Err(undefined_variable_error(
261                        &delete.variable,
262                        &self.context,
263                        " in DELETE",
264                    ));
265                }
266                Ok(())
267            }
268            LogicalOperator::DeleteEdge(delete) => {
269                self.bind_operator(&delete.input)?;
270                // Validate that the variable to delete is defined
271                if !self.context.contains(&delete.variable) {
272                    return Err(undefined_variable_error(
273                        &delete.variable,
274                        &self.context,
275                        " in DELETE",
276                    ));
277                }
278                Ok(())
279            }
280            LogicalOperator::SetProperty(set) => {
281                self.bind_operator(&set.input)?;
282                // Validate that the variable to update is defined
283                if !self.context.contains(&set.variable) {
284                    return Err(undefined_variable_error(
285                        &set.variable,
286                        &self.context,
287                        " in SET",
288                    ));
289                }
290                // Validate property value expressions
291                for (_, expr) in &set.properties {
292                    self.validate_expression(expr)?;
293                }
294                Ok(())
295            }
296            LogicalOperator::Empty => Ok(()),
297
298            LogicalOperator::Unwind(unwind) => {
299                // First bind the input
300                self.bind_operator(&unwind.input)?;
301                // Validate the expression being unwound
302                self.validate_expression(&unwind.expression)?;
303                // Add the new variable to the context
304                self.context.add_variable(
305                    unwind.variable.clone(),
306                    VariableInfo {
307                        name: unwind.variable.clone(),
308                        data_type: LogicalType::Any, // Unwound elements can be any type
309                        is_node: false,
310                        is_edge: false,
311                    },
312                );
313                Ok(())
314            }
315
316            // RDF/SPARQL operators
317            LogicalOperator::TripleScan(scan) => self.bind_triple_scan(scan),
318            LogicalOperator::Union(union) => {
319                for input in &union.inputs {
320                    self.bind_operator(input)?;
321                }
322                Ok(())
323            }
324            LogicalOperator::LeftJoin(lj) => {
325                self.bind_operator(&lj.left)?;
326                self.bind_operator(&lj.right)?;
327                if let Some(ref cond) = lj.condition {
328                    self.validate_expression(cond)?;
329                }
330                Ok(())
331            }
332            LogicalOperator::AntiJoin(aj) => {
333                self.bind_operator(&aj.left)?;
334                self.bind_operator(&aj.right)?;
335                Ok(())
336            }
337            LogicalOperator::Bind(bind) => {
338                self.bind_operator(&bind.input)?;
339                self.validate_expression(&bind.expression)?;
340                self.context.add_variable(
341                    bind.variable.clone(),
342                    VariableInfo {
343                        name: bind.variable.clone(),
344                        data_type: LogicalType::Any,
345                        is_node: false,
346                        is_edge: false,
347                    },
348                );
349                Ok(())
350            }
351            LogicalOperator::Merge(merge) => {
352                // First bind the input
353                self.bind_operator(&merge.input)?;
354                // Validate the match property expressions
355                for (_, expr) in &merge.match_properties {
356                    self.validate_expression(expr)?;
357                }
358                // Validate the ON CREATE property expressions
359                for (_, expr) in &merge.on_create {
360                    self.validate_expression(expr)?;
361                }
362                // Validate the ON MATCH property expressions
363                for (_, expr) in &merge.on_match {
364                    self.validate_expression(expr)?;
365                }
366                // MERGE introduces a new variable
367                self.context.add_variable(
368                    merge.variable.clone(),
369                    VariableInfo {
370                        name: merge.variable.clone(),
371                        data_type: LogicalType::Node,
372                        is_node: true,
373                        is_edge: false,
374                    },
375                );
376                Ok(())
377            }
378            LogicalOperator::AddLabel(add_label) => {
379                self.bind_operator(&add_label.input)?;
380                // Validate that the variable exists
381                if !self.context.contains(&add_label.variable) {
382                    return Err(undefined_variable_error(
383                        &add_label.variable,
384                        &self.context,
385                        " in SET labels",
386                    ));
387                }
388                Ok(())
389            }
390            LogicalOperator::RemoveLabel(remove_label) => {
391                self.bind_operator(&remove_label.input)?;
392                // Validate that the variable exists
393                if !self.context.contains(&remove_label.variable) {
394                    return Err(undefined_variable_error(
395                        &remove_label.variable,
396                        &self.context,
397                        " in REMOVE labels",
398                    ));
399                }
400                Ok(())
401            }
402            LogicalOperator::ShortestPath(sp) => {
403                // First bind the input
404                self.bind_operator(&sp.input)?;
405                // Validate that source and target variables are defined
406                if !self.context.contains(&sp.source_var) {
407                    return Err(undefined_variable_error(
408                        &sp.source_var,
409                        &self.context,
410                        " (source in shortestPath)",
411                    ));
412                }
413                if !self.context.contains(&sp.target_var) {
414                    return Err(undefined_variable_error(
415                        &sp.target_var,
416                        &self.context,
417                        " (target in shortestPath)",
418                    ));
419                }
420                // Add the path alias variable to the context
421                self.context.add_variable(
422                    sp.path_alias.clone(),
423                    VariableInfo {
424                        name: sp.path_alias.clone(),
425                        data_type: LogicalType::Any, // Path is a complex type
426                        is_node: false,
427                        is_edge: false,
428                    },
429                );
430                // Also add the path length variable for length(p) calls
431                let path_length_var = format!("_path_length_{}", sp.path_alias);
432                self.context.add_variable(
433                    path_length_var.clone(),
434                    VariableInfo {
435                        name: path_length_var,
436                        data_type: LogicalType::Int64,
437                        is_node: false,
438                        is_edge: false,
439                    },
440                );
441                Ok(())
442            }
443            // SPARQL Update operators - these don't require variable binding
444            LogicalOperator::InsertTriple(insert) => {
445                if let Some(ref input) = insert.input {
446                    self.bind_operator(input)?;
447                }
448                Ok(())
449            }
450            LogicalOperator::DeleteTriple(delete) => {
451                if let Some(ref input) = delete.input {
452                    self.bind_operator(input)?;
453                }
454                Ok(())
455            }
456            LogicalOperator::Modify(modify) => {
457                self.bind_operator(&modify.where_clause)?;
458                Ok(())
459            }
460            LogicalOperator::ClearGraph(_)
461            | LogicalOperator::CreateGraph(_)
462            | LogicalOperator::DropGraph(_)
463            | LogicalOperator::LoadGraph(_)
464            | LogicalOperator::CopyGraph(_)
465            | LogicalOperator::MoveGraph(_)
466            | LogicalOperator::AddGraph(_) => Ok(()),
467            LogicalOperator::VectorScan(scan) => {
468                // VectorScan introduces a variable for matched nodes
469                if let Some(ref input) = scan.input {
470                    self.bind_operator(input)?;
471                }
472                self.context.add_variable(
473                    scan.variable.clone(),
474                    VariableInfo {
475                        name: scan.variable.clone(),
476                        data_type: LogicalType::Node,
477                        is_node: true,
478                        is_edge: false,
479                    },
480                );
481                // Validate the query vector expression
482                self.validate_expression(&scan.query_vector)?;
483                Ok(())
484            }
485            LogicalOperator::VectorJoin(join) => {
486                // VectorJoin takes input from left side and produces right-side matches
487                self.bind_operator(&join.input)?;
488                // Add right variable for matched nodes
489                self.context.add_variable(
490                    join.right_variable.clone(),
491                    VariableInfo {
492                        name: join.right_variable.clone(),
493                        data_type: LogicalType::Node,
494                        is_node: true,
495                        is_edge: false,
496                    },
497                );
498                // Optionally add score variable
499                if let Some(ref score_var) = join.score_variable {
500                    self.context.add_variable(
501                        score_var.clone(),
502                        VariableInfo {
503                            name: score_var.clone(),
504                            data_type: LogicalType::Float64,
505                            is_node: false,
506                            is_edge: false,
507                        },
508                    );
509                }
510                // Validate the query vector expression
511                self.validate_expression(&join.query_vector)?;
512                Ok(())
513            }
514        }
515    }
516
517    /// Binds a triple scan operator (for RDF/SPARQL).
518    fn bind_triple_scan(&mut self, scan: &TripleScanOp) -> Result<()> {
519        use crate::query::plan::TripleComponent;
520
521        // First bind the input if present
522        if let Some(ref input) = scan.input {
523            self.bind_operator(input)?;
524        }
525
526        // Add variables for subject, predicate, object
527        if let TripleComponent::Variable(name) = &scan.subject
528            && !self.context.contains(name)
529        {
530            self.context.add_variable(
531                name.clone(),
532                VariableInfo {
533                    name: name.clone(),
534                    data_type: LogicalType::Any, // RDF term
535                    is_node: false,
536                    is_edge: false,
537                },
538            );
539        }
540
541        if let TripleComponent::Variable(name) = &scan.predicate
542            && !self.context.contains(name)
543        {
544            self.context.add_variable(
545                name.clone(),
546                VariableInfo {
547                    name: name.clone(),
548                    data_type: LogicalType::Any, // IRI
549                    is_node: false,
550                    is_edge: false,
551                },
552            );
553        }
554
555        if let TripleComponent::Variable(name) = &scan.object
556            && !self.context.contains(name)
557        {
558            self.context.add_variable(
559                name.clone(),
560                VariableInfo {
561                    name: name.clone(),
562                    data_type: LogicalType::Any, // RDF term
563                    is_node: false,
564                    is_edge: false,
565                },
566            );
567        }
568
569        if let Some(TripleComponent::Variable(name)) = &scan.graph
570            && !self.context.contains(name)
571        {
572            self.context.add_variable(
573                name.clone(),
574                VariableInfo {
575                    name: name.clone(),
576                    data_type: LogicalType::Any, // IRI
577                    is_node: false,
578                    is_edge: false,
579                },
580            );
581        }
582
583        Ok(())
584    }
585
586    /// Binds a node scan operator.
587    fn bind_node_scan(&mut self, scan: &NodeScanOp) -> Result<()> {
588        // First bind the input if present
589        if let Some(ref input) = scan.input {
590            self.bind_operator(input)?;
591        }
592
593        // Add the scanned variable to scope
594        self.context.add_variable(
595            scan.variable.clone(),
596            VariableInfo {
597                name: scan.variable.clone(),
598                data_type: LogicalType::Node,
599                is_node: true,
600                is_edge: false,
601            },
602        );
603
604        Ok(())
605    }
606
607    /// Binds an expand operator.
608    fn bind_expand(&mut self, expand: &ExpandOp) -> Result<()> {
609        // First bind the input
610        self.bind_operator(&expand.input)?;
611
612        // Validate that the source variable is defined
613        if !self.context.contains(&expand.from_variable) {
614            return Err(undefined_variable_error(
615                &expand.from_variable,
616                &self.context,
617                " in EXPAND",
618            ));
619        }
620
621        // Validate that the source is a node
622        if let Some(info) = self.context.get(&expand.from_variable)
623            && !info.is_node
624        {
625            return Err(binding_error(format!(
626                "Variable '{}' is not a node, cannot expand from it",
627                expand.from_variable
628            )));
629        }
630
631        // Add edge variable if present
632        if let Some(ref edge_var) = expand.edge_variable {
633            self.context.add_variable(
634                edge_var.clone(),
635                VariableInfo {
636                    name: edge_var.clone(),
637                    data_type: LogicalType::Edge,
638                    is_node: false,
639                    is_edge: true,
640                },
641            );
642        }
643
644        // Add target variable
645        self.context.add_variable(
646            expand.to_variable.clone(),
647            VariableInfo {
648                name: expand.to_variable.clone(),
649                data_type: LogicalType::Node,
650                is_node: true,
651                is_edge: false,
652            },
653        );
654
655        // Add path length variable for variable-length paths (for length(p) calls)
656        if let Some(ref path_alias) = expand.path_alias {
657            let path_length_var = format!("_path_length_{}", path_alias);
658            self.context.add_variable(
659                path_length_var.clone(),
660                VariableInfo {
661                    name: path_length_var,
662                    data_type: LogicalType::Int64,
663                    is_node: false,
664                    is_edge: false,
665                },
666            );
667        }
668
669        Ok(())
670    }
671
672    /// Binds a filter operator.
673    fn bind_filter(&mut self, filter: &FilterOp) -> Result<()> {
674        // First bind the input
675        self.bind_operator(&filter.input)?;
676
677        // Validate the predicate expression
678        self.validate_expression(&filter.predicate)?;
679
680        Ok(())
681    }
682
683    /// Binds a return operator.
684    fn bind_return(&mut self, ret: &ReturnOp) -> Result<()> {
685        // First bind the input
686        self.bind_operator(&ret.input)?;
687
688        // Validate all return expressions
689        for item in &ret.items {
690            self.validate_return_item(item)?;
691        }
692
693        Ok(())
694    }
695
696    /// Validates a return item.
697    fn validate_return_item(&self, item: &ReturnItem) -> Result<()> {
698        self.validate_expression(&item.expression)
699    }
700
701    /// Validates that an expression only references defined variables.
702    fn validate_expression(&self, expr: &LogicalExpression) -> Result<()> {
703        match expr {
704            LogicalExpression::Variable(name) => {
705                if !self.context.contains(name) && !name.starts_with("_anon_") {
706                    return Err(undefined_variable_error(name, &self.context, ""));
707                }
708                Ok(())
709            }
710            LogicalExpression::Property { variable, .. } => {
711                if !self.context.contains(variable) && !variable.starts_with("_anon_") {
712                    return Err(undefined_variable_error(
713                        variable,
714                        &self.context,
715                        " in property access",
716                    ));
717                }
718                Ok(())
719            }
720            LogicalExpression::Literal(_) => Ok(()),
721            LogicalExpression::Binary { left, right, .. } => {
722                self.validate_expression(left)?;
723                self.validate_expression(right)
724            }
725            LogicalExpression::Unary { operand, .. } => self.validate_expression(operand),
726            LogicalExpression::FunctionCall { args, .. } => {
727                for arg in args {
728                    self.validate_expression(arg)?;
729                }
730                Ok(())
731            }
732            LogicalExpression::List(items) => {
733                for item in items {
734                    self.validate_expression(item)?;
735                }
736                Ok(())
737            }
738            LogicalExpression::Map(pairs) => {
739                for (_, value) in pairs {
740                    self.validate_expression(value)?;
741                }
742                Ok(())
743            }
744            LogicalExpression::IndexAccess { base, index } => {
745                self.validate_expression(base)?;
746                self.validate_expression(index)
747            }
748            LogicalExpression::SliceAccess { base, start, end } => {
749                self.validate_expression(base)?;
750                if let Some(s) = start {
751                    self.validate_expression(s)?;
752                }
753                if let Some(e) = end {
754                    self.validate_expression(e)?;
755                }
756                Ok(())
757            }
758            LogicalExpression::Case {
759                operand,
760                when_clauses,
761                else_clause,
762            } => {
763                if let Some(op) = operand {
764                    self.validate_expression(op)?;
765                }
766                for (cond, result) in when_clauses {
767                    self.validate_expression(cond)?;
768                    self.validate_expression(result)?;
769                }
770                if let Some(else_expr) = else_clause {
771                    self.validate_expression(else_expr)?;
772                }
773                Ok(())
774            }
775            // Parameter references are validated externally
776            LogicalExpression::Parameter(_) => Ok(()),
777            // labels(n), type(e), id(n) need the variable to be defined
778            LogicalExpression::Labels(var)
779            | LogicalExpression::Type(var)
780            | LogicalExpression::Id(var) => {
781                if !self.context.contains(var) && !var.starts_with("_anon_") {
782                    return Err(undefined_variable_error(var, &self.context, " in function"));
783                }
784                Ok(())
785            }
786            LogicalExpression::ListComprehension {
787                list_expr,
788                filter_expr,
789                map_expr,
790                ..
791            } => {
792                // Validate the list expression
793                self.validate_expression(list_expr)?;
794                // Note: filter_expr and map_expr use the comprehension variable
795                // which is defined within the comprehension scope, so we don't
796                // need to validate it against the outer context
797                if let Some(filter) = filter_expr {
798                    self.validate_expression(filter)?;
799                }
800                self.validate_expression(map_expr)?;
801                Ok(())
802            }
803            LogicalExpression::ExistsSubquery(subquery)
804            | LogicalExpression::CountSubquery(subquery) => {
805                // Subqueries have their own binding context
806                // For now, just validate the structure exists
807                let _ = subquery; // Would need recursive binding
808                Ok(())
809            }
810        }
811    }
812
813    /// Infers the type of an expression for use in WITH clause aliasing.
814    fn infer_expression_type(&self, expr: &LogicalExpression) -> LogicalType {
815        match expr {
816            LogicalExpression::Variable(name) => {
817                // Look up the variable type from context
818                self.context
819                    .get(name)
820                    .map(|info| info.data_type.clone())
821                    .unwrap_or(LogicalType::Any)
822            }
823            LogicalExpression::Property { .. } => LogicalType::Any, // Properties can be any type
824            LogicalExpression::Literal(value) => {
825                // Infer type from literal value
826                use grafeo_common::types::Value;
827                match value {
828                    Value::Bool(_) => LogicalType::Bool,
829                    Value::Int64(_) => LogicalType::Int64,
830                    Value::Float64(_) => LogicalType::Float64,
831                    Value::String(_) => LogicalType::String,
832                    Value::List(_) => LogicalType::Any, // Complex type
833                    Value::Map(_) => LogicalType::Any,  // Complex type
834                    Value::Null => LogicalType::Any,
835                    _ => LogicalType::Any,
836                }
837            }
838            LogicalExpression::Binary { .. } => LogicalType::Any, // Could be bool or numeric
839            LogicalExpression::Unary { .. } => LogicalType::Any,
840            LogicalExpression::FunctionCall { name, .. } => {
841                // Infer based on function name
842                match name.to_lowercase().as_str() {
843                    "count" | "sum" | "id" => LogicalType::Int64,
844                    "avg" => LogicalType::Float64,
845                    "type" => LogicalType::String,
846                    // List-returning functions use Any since we don't track element type
847                    "labels" | "collect" => LogicalType::Any,
848                    _ => LogicalType::Any,
849                }
850            }
851            LogicalExpression::List(_) => LogicalType::Any, // Complex type
852            LogicalExpression::Map(_) => LogicalType::Any,  // Complex type
853            _ => LogicalType::Any,
854        }
855    }
856
857    /// Binds a join operator.
858    fn bind_join(&mut self, join: &crate::query::plan::JoinOp) -> Result<()> {
859        // Bind both sides of the join
860        self.bind_operator(&join.left)?;
861        self.bind_operator(&join.right)?;
862
863        // Validate join conditions
864        for condition in &join.conditions {
865            self.validate_expression(&condition.left)?;
866            self.validate_expression(&condition.right)?;
867        }
868
869        Ok(())
870    }
871
872    /// Binds an aggregate operator.
873    fn bind_aggregate(&mut self, agg: &crate::query::plan::AggregateOp) -> Result<()> {
874        // Bind the input first
875        self.bind_operator(&agg.input)?;
876
877        // Validate group by expressions
878        for expr in &agg.group_by {
879            self.validate_expression(expr)?;
880        }
881
882        // Validate aggregate expressions
883        for agg_expr in &agg.aggregates {
884            if let Some(ref expr) = agg_expr.expression {
885                self.validate_expression(expr)?;
886            }
887            // Add the alias as a new variable if present
888            if let Some(ref alias) = agg_expr.alias {
889                self.context.add_variable(
890                    alias.clone(),
891                    VariableInfo {
892                        name: alias.clone(),
893                        data_type: LogicalType::Any,
894                        is_node: false,
895                        is_edge: false,
896                    },
897                );
898            }
899        }
900
901        Ok(())
902    }
903}
904
905impl Default for Binder {
906    fn default() -> Self {
907        Self::new()
908    }
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{BinaryOp, FilterOp};

    /// Builds a `NodeScan` leaf that binds `variable`, optionally restricted to `label`.
    fn node_scan(variable: &str, label: Option<&str>) -> LogicalOperator {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: label.map(str::to_string),
            input: None,
        })
    }

    #[test]
    fn test_bind_simple_scan() {
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("n".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(node_scan("n", Some("Person"))),
        }));

        let mut binder = Binder::new();
        let ctx = binder
            .bind(&plan)
            .expect("scanning and returning `n` should bind");

        assert!(ctx.contains("n"));
        assert!(ctx.get("n").unwrap().is_node);
    }

    #[test]
    fn test_bind_undefined_variable() {
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("undefined".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(node_scan("n", None)),
        }));

        let mut binder = Binder::new();
        let err = binder
            .bind(&plan)
            .expect_err("returning an unbound variable must fail");

        assert!(err.to_string().contains("Undefined variable"));
    }

    #[test]
    fn test_bind_property_access() {
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Property {
                    variable: "n".to_string(),
                    property: "name".to_string(),
                },
                alias: None,
            }],
            distinct: false,
            input: Box::new(node_scan("n", Some("Person"))),
        }));

        let mut binder = Binder::new();
        let result = binder.bind(&plan);

        assert!(
            result.is_ok(),
            "property access on a bound node failed: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_bind_filter_with_undefined_variable() {
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("n".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(LogicalOperator::Filter(FilterOp {
                predicate: LogicalExpression::Binary {
                    left: Box::new(LogicalExpression::Property {
                        variable: "m".to_string(), // undefined!
                        property: "age".to_string(),
                    }),
                    op: BinaryOp::Gt,
                    right: Box::new(LogicalExpression::Literal(
                        grafeo_common::types::Value::Int64(30),
                    )),
                },
                input: Box::new(node_scan("n", None)),
            })),
        }));

        let mut binder = Binder::new();
        let err = binder
            .bind(&plan)
            .expect_err("a filter over an unbound variable must fail");

        assert!(err.to_string().contains("Undefined variable 'm'"));
    }

    #[test]
    fn test_bind_expand() {
        use crate::query::plan::{ExpandDirection, ExpandOp};

        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Expand(ExpandOp {
                from_variable: "a".to_string(),
                to_variable: "b".to_string(),
                edge_variable: Some("e".to_string()),
                direction: ExpandDirection::Outgoing,
                edge_type: Some("KNOWS".to_string()),
                min_hops: 1,
                max_hops: Some(1),
                input: Box::new(node_scan("a", Some("Person"))),
                path_alias: None,
            })),
        }));

        let mut binder = Binder::new();
        let ctx = binder
            .bind(&plan)
            .expect("expanding from a bound node should bind");

        assert!(ctx.contains("a"));
        assert!(ctx.contains("b"));
        assert!(ctx.contains("e"));
        assert!(ctx.get("a").unwrap().is_node);
        assert!(ctx.get("b").unwrap().is_node);
        assert!(ctx.get("e").unwrap().is_edge);
    }

    #[test]
    fn test_bind_expand_from_undefined_variable() {
        // Expanding from a variable that no upstream operator defines must
        // produce a clear "undefined variable" error.
        use crate::query::plan::{ExpandDirection, ExpandOp};

        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("b".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(LogicalOperator::Expand(ExpandOp {
                from_variable: "undefined".to_string(), // not defined!
                to_variable: "b".to_string(),
                edge_variable: None,
                direction: ExpandDirection::Outgoing,
                edge_type: None,
                min_hops: 1,
                max_hops: Some(1),
                input: Box::new(node_scan("a", None)),
                path_alias: None,
            })),
        }));

        let mut binder = Binder::new();
        let err = binder
            .bind(&plan)
            .expect_err("expanding from an unbound variable must fail");

        assert!(
            err.to_string().contains("Undefined variable 'undefined'"),
            "Expected error about undefined variable, got: {}",
            err
        );
    }

    #[test]
    fn test_bind_return_with_aggregate_and_non_aggregate() {
        // An aggregate function alongside a plain literal in the same RETURN
        // must bind successfully.
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::FunctionCall {
                        name: "count".to_string(),
                        args: vec![LogicalExpression::Variable("n".to_string())],
                        distinct: false,
                    },
                    alias: Some("cnt".to_string()),
                },
                ReturnItem {
                    expression: LogicalExpression::Literal(grafeo_common::types::Value::Int64(1)),
                    alias: Some("one".to_string()),
                },
            ],
            distinct: false,
            input: Box::new(node_scan("n", Some("Person"))),
        }));

        let mut binder = Binder::new();
        let result = binder.bind(&plan);

        // count(n) together with a literal is valid.
        assert!(
            result.is_ok(),
            "count(n) with a literal failed to bind: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_bind_nested_property_access() {
        // Several property accesses on the same bound variable must all bind.
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Property {
                        variable: "n".to_string(),
                        property: "name".to_string(),
                    },
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Property {
                        variable: "n".to_string(),
                        property: "age".to_string(),
                    },
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(node_scan("n", Some("Person"))),
        }));

        let mut binder = Binder::new();
        let result = binder.bind(&plan);

        assert!(
            result.is_ok(),
            "multiple property accesses on `n` failed to bind: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_bind_binary_expression_with_undefined() {
        // An undefined variable nested inside a binary expression must be reported.
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Binary {
                    left: Box::new(LogicalExpression::Property {
                        variable: "n".to_string(),
                        property: "age".to_string(),
                    }),
                    op: BinaryOp::Add,
                    right: Box::new(LogicalExpression::Property {
                        variable: "m".to_string(), // undefined!
                        property: "age".to_string(),
                    }),
                },
                alias: Some("total".to_string()),
            }],
            distinct: false,
            input: Box::new(node_scan("n", None)),
        }));

        let mut binder = Binder::new();
        let err = binder
            .bind(&plan)
            .expect_err("a binary expression over an unbound variable must fail");

        assert!(err.to_string().contains("Undefined variable 'm'"));
    }

    #[test]
    fn test_bind_join_with_distinct_variables() {
        // Each side of an inner join introduces its own node variable. After
        // binding, the variables from both sides must be visible in the context.
        use crate::query::plan::{JoinOp, JoinType};

        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("n".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(LogicalOperator::Join(JoinOp {
                left: Box::new(node_scan("n", Some("A"))),
                right: Box::new(node_scan("m", Some("B"))),
                join_type: JoinType::Inner,
                conditions: vec![],
            })),
        }));

        let mut binder = Binder::new();
        let ctx = binder
            .bind(&plan)
            .expect("joining two scans over distinct variables should bind");

        assert!(ctx.contains("n"));
        assert!(ctx.contains("m"));
        assert!(ctx.get("n").unwrap().is_node);
        assert!(ctx.get("m").unwrap().is_node);
    }

    #[test]
    fn test_bind_function_with_wrong_arity() {
        // count() is called without its required argument. This test does not
        // pin down whether the binder rejects this or leaves it to execution.
        // It only guards against a panic during binding.
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::FunctionCall {
                    name: "count".to_string(),
                    args: vec![], // count() needs an argument
                    distinct: false,
                },
                alias: None,
            }],
            distinct: false,
            input: Box::new(node_scan("n", None)),
        }));

        let mut binder = Binder::new();
        let result = binder.bind(&plan);

        // Either outcome is acceptable; reaching this line means no panic occurred.
        let _ = result;
    }
}