Skip to main content

grafeo_engine/query/
binder.rs

//! Semantic validation - catching errors before execution.
//!
//! The binder walks the logical plan and checks that the query makes sense:
//! - Is that variable actually defined? (You can't use `RETURN x` if `x` wasn't matched)
//! - Does an expansion start from a node? (You can't traverse edges from a plain value)
//!
//! Deeper checks are not yet performed by the binder: property access on non-entity
//! values (e.g. accessing `.age` on an integer) and type compatibility (e.g. comparing
//! a string to an integer) are not validated here.
//!
//! Better to catch these errors early than waste time executing a broken query.
9
10use crate::query::plan::{
11    ExpandOp, FilterOp, LogicalExpression, LogicalOperator, LogicalPlan, NodeScanOp, ReturnItem,
12    ReturnOp, TripleScanOp,
13};
14use grafeo_common::types::LogicalType;
15use grafeo_common::utils::error::{Error, QueryError, QueryErrorKind, Result};
16use std::collections::HashMap;
17
18/// Creates a semantic binding error.
19fn binding_error(message: impl Into<String>) -> Error {
20    Error::Query(QueryError::new(QueryErrorKind::Semantic, message))
21}
22
23/// Information about a bound variable.
24#[derive(Debug, Clone)]
25pub struct VariableInfo {
26    /// The name of the variable.
27    pub name: String,
28    /// The inferred type of the variable.
29    pub data_type: LogicalType,
30    /// Whether this variable is a node.
31    pub is_node: bool,
32    /// Whether this variable is an edge.
33    pub is_edge: bool,
34}
35
36/// Context containing all bound variables and their information.
37#[derive(Debug, Clone, Default)]
38pub struct BindingContext {
39    /// Map from variable name to its info.
40    variables: HashMap<String, VariableInfo>,
41    /// Variables in order of definition.
42    order: Vec<String>,
43}
44
45impl BindingContext {
46    /// Creates a new empty binding context.
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            variables: HashMap::new(),
51            order: Vec::new(),
52        }
53    }
54
55    /// Adds a variable to the context.
56    pub fn add_variable(&mut self, name: String, info: VariableInfo) {
57        if !self.variables.contains_key(&name) {
58            self.order.push(name.clone());
59        }
60        self.variables.insert(name, info);
61    }
62
63    /// Looks up a variable by name.
64    #[must_use]
65    pub fn get(&self, name: &str) -> Option<&VariableInfo> {
66        self.variables.get(name)
67    }
68
69    /// Checks if a variable is defined.
70    #[must_use]
71    pub fn contains(&self, name: &str) -> bool {
72        self.variables.contains_key(name)
73    }
74
75    /// Returns all variable names in definition order.
76    #[must_use]
77    pub fn variable_names(&self) -> &[String] {
78        &self.order
79    }
80
81    /// Returns the number of bound variables.
82    #[must_use]
83    pub fn len(&self) -> usize {
84        self.variables.len()
85    }
86
87    /// Returns true if no variables are bound.
88    #[must_use]
89    pub fn is_empty(&self) -> bool {
90        self.variables.is_empty()
91    }
92}
93
94/// Semantic binder for query plans.
95///
96/// The binder walks the logical plan and:
97/// 1. Collects all variable definitions
98/// 2. Validates that all variable references are valid
99/// 3. Infers types where possible
100/// 4. Reports semantic errors
101pub struct Binder {
102    /// The current binding context.
103    context: BindingContext,
104}
105
106impl Binder {
107    /// Creates a new binder.
108    #[must_use]
109    pub fn new() -> Self {
110        Self {
111            context: BindingContext::new(),
112        }
113    }
114
115    /// Binds a logical plan, returning the binding context.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if semantic validation fails.
120    pub fn bind(&mut self, plan: &LogicalPlan) -> Result<BindingContext> {
121        self.bind_operator(&plan.root)?;
122        Ok(self.context.clone())
123    }
124
125    /// Binds a single logical operator.
126    fn bind_operator(&mut self, op: &LogicalOperator) -> Result<()> {
127        match op {
128            LogicalOperator::NodeScan(scan) => self.bind_node_scan(scan),
129            LogicalOperator::Expand(expand) => self.bind_expand(expand),
130            LogicalOperator::Filter(filter) => self.bind_filter(filter),
131            LogicalOperator::Return(ret) => self.bind_return(ret),
132            LogicalOperator::Project(project) => {
133                self.bind_operator(&project.input)?;
134                for projection in &project.projections {
135                    self.validate_expression(&projection.expression)?;
136                    // Add the projection alias to the context (for WITH clause support)
137                    if let Some(ref alias) = projection.alias {
138                        // Determine the type from the expression
139                        let data_type = self.infer_expression_type(&projection.expression);
140                        self.context.add_variable(
141                            alias.clone(),
142                            VariableInfo {
143                                name: alias.clone(),
144                                data_type,
145                                is_node: false,
146                                is_edge: false,
147                            },
148                        );
149                    }
150                }
151                Ok(())
152            }
153            LogicalOperator::Limit(limit) => self.bind_operator(&limit.input),
154            LogicalOperator::Skip(skip) => self.bind_operator(&skip.input),
155            LogicalOperator::Sort(sort) => {
156                self.bind_operator(&sort.input)?;
157                for key in &sort.keys {
158                    self.validate_expression(&key.expression)?;
159                }
160                Ok(())
161            }
162            LogicalOperator::CreateNode(create) => {
163                // CreateNode introduces a new variable
164                if let Some(ref input) = create.input {
165                    self.bind_operator(input)?;
166                }
167                self.context.add_variable(
168                    create.variable.clone(),
169                    VariableInfo {
170                        name: create.variable.clone(),
171                        data_type: LogicalType::Node,
172                        is_node: true,
173                        is_edge: false,
174                    },
175                );
176                // Validate property expressions
177                for (_, expr) in &create.properties {
178                    self.validate_expression(expr)?;
179                }
180                Ok(())
181            }
182            LogicalOperator::EdgeScan(scan) => {
183                if let Some(ref input) = scan.input {
184                    self.bind_operator(input)?;
185                }
186                self.context.add_variable(
187                    scan.variable.clone(),
188                    VariableInfo {
189                        name: scan.variable.clone(),
190                        data_type: LogicalType::Edge,
191                        is_node: false,
192                        is_edge: true,
193                    },
194                );
195                Ok(())
196            }
197            LogicalOperator::Distinct(distinct) => self.bind_operator(&distinct.input),
198            LogicalOperator::Join(join) => self.bind_join(join),
199            LogicalOperator::Aggregate(agg) => self.bind_aggregate(agg),
200            LogicalOperator::CreateEdge(create) => {
201                self.bind_operator(&create.input)?;
202                // Validate that source and target variables are defined
203                if !self.context.contains(&create.from_variable) {
204                    return Err(binding_error(format!(
205                        "Undefined source variable '{}' in CREATE EDGE",
206                        create.from_variable
207                    )));
208                }
209                if !self.context.contains(&create.to_variable) {
210                    return Err(binding_error(format!(
211                        "Undefined target variable '{}' in CREATE EDGE",
212                        create.to_variable
213                    )));
214                }
215                // Add edge variable if present
216                if let Some(ref var) = create.variable {
217                    self.context.add_variable(
218                        var.clone(),
219                        VariableInfo {
220                            name: var.clone(),
221                            data_type: LogicalType::Edge,
222                            is_node: false,
223                            is_edge: true,
224                        },
225                    );
226                }
227                // Validate property expressions
228                for (_, expr) in &create.properties {
229                    self.validate_expression(expr)?;
230                }
231                Ok(())
232            }
233            LogicalOperator::DeleteNode(delete) => {
234                self.bind_operator(&delete.input)?;
235                // Validate that the variable to delete is defined
236                if !self.context.contains(&delete.variable) {
237                    return Err(binding_error(format!(
238                        "Undefined variable '{}' in DELETE",
239                        delete.variable
240                    )));
241                }
242                Ok(())
243            }
244            LogicalOperator::DeleteEdge(delete) => {
245                self.bind_operator(&delete.input)?;
246                // Validate that the variable to delete is defined
247                if !self.context.contains(&delete.variable) {
248                    return Err(binding_error(format!(
249                        "Undefined variable '{}' in DELETE",
250                        delete.variable
251                    )));
252                }
253                Ok(())
254            }
255            LogicalOperator::SetProperty(set) => {
256                self.bind_operator(&set.input)?;
257                // Validate that the variable to update is defined
258                if !self.context.contains(&set.variable) {
259                    return Err(binding_error(format!(
260                        "Undefined variable '{}' in SET",
261                        set.variable
262                    )));
263                }
264                // Validate property value expressions
265                for (_, expr) in &set.properties {
266                    self.validate_expression(expr)?;
267                }
268                Ok(())
269            }
270            LogicalOperator::Empty => Ok(()),
271
272            LogicalOperator::Unwind(unwind) => {
273                // First bind the input
274                self.bind_operator(&unwind.input)?;
275                // Validate the expression being unwound
276                self.validate_expression(&unwind.expression)?;
277                // Add the new variable to the context
278                self.context.add_variable(
279                    unwind.variable.clone(),
280                    VariableInfo {
281                        name: unwind.variable.clone(),
282                        data_type: LogicalType::Any, // Unwound elements can be any type
283                        is_node: false,
284                        is_edge: false,
285                    },
286                );
287                Ok(())
288            }
289
290            // RDF/SPARQL operators
291            LogicalOperator::TripleScan(scan) => self.bind_triple_scan(scan),
292            LogicalOperator::Union(union) => {
293                for input in &union.inputs {
294                    self.bind_operator(input)?;
295                }
296                Ok(())
297            }
298            LogicalOperator::LeftJoin(lj) => {
299                self.bind_operator(&lj.left)?;
300                self.bind_operator(&lj.right)?;
301                if let Some(ref cond) = lj.condition {
302                    self.validate_expression(cond)?;
303                }
304                Ok(())
305            }
306            LogicalOperator::AntiJoin(aj) => {
307                self.bind_operator(&aj.left)?;
308                self.bind_operator(&aj.right)?;
309                Ok(())
310            }
311            LogicalOperator::Bind(bind) => {
312                self.bind_operator(&bind.input)?;
313                self.validate_expression(&bind.expression)?;
314                self.context.add_variable(
315                    bind.variable.clone(),
316                    VariableInfo {
317                        name: bind.variable.clone(),
318                        data_type: LogicalType::Any,
319                        is_node: false,
320                        is_edge: false,
321                    },
322                );
323                Ok(())
324            }
325            LogicalOperator::Merge(merge) => {
326                // First bind the input
327                self.bind_operator(&merge.input)?;
328                // Validate the match property expressions
329                for (_, expr) in &merge.match_properties {
330                    self.validate_expression(expr)?;
331                }
332                // Validate the ON CREATE property expressions
333                for (_, expr) in &merge.on_create {
334                    self.validate_expression(expr)?;
335                }
336                // Validate the ON MATCH property expressions
337                for (_, expr) in &merge.on_match {
338                    self.validate_expression(expr)?;
339                }
340                // MERGE introduces a new variable
341                self.context.add_variable(
342                    merge.variable.clone(),
343                    VariableInfo {
344                        name: merge.variable.clone(),
345                        data_type: LogicalType::Node,
346                        is_node: true,
347                        is_edge: false,
348                    },
349                );
350                Ok(())
351            }
352            LogicalOperator::AddLabel(add_label) => {
353                self.bind_operator(&add_label.input)?;
354                // Validate that the variable exists
355                if !self.context.contains(&add_label.variable) {
356                    return Err(binding_error(format!(
357                        "Undefined variable '{}' in SET labels",
358                        add_label.variable
359                    )));
360                }
361                Ok(())
362            }
363            LogicalOperator::RemoveLabel(remove_label) => {
364                self.bind_operator(&remove_label.input)?;
365                // Validate that the variable exists
366                if !self.context.contains(&remove_label.variable) {
367                    return Err(binding_error(format!(
368                        "Undefined variable '{}' in REMOVE labels",
369                        remove_label.variable
370                    )));
371                }
372                Ok(())
373            }
374            LogicalOperator::ShortestPath(sp) => {
375                // First bind the input
376                self.bind_operator(&sp.input)?;
377                // Validate that source and target variables are defined
378                if !self.context.contains(&sp.source_var) {
379                    return Err(binding_error(format!(
380                        "Undefined source variable '{}' in shortestPath",
381                        sp.source_var
382                    )));
383                }
384                if !self.context.contains(&sp.target_var) {
385                    return Err(binding_error(format!(
386                        "Undefined target variable '{}' in shortestPath",
387                        sp.target_var
388                    )));
389                }
390                // Add the path alias variable to the context
391                self.context.add_variable(
392                    sp.path_alias.clone(),
393                    VariableInfo {
394                        name: sp.path_alias.clone(),
395                        data_type: LogicalType::Any, // Path is a complex type
396                        is_node: false,
397                        is_edge: false,
398                    },
399                );
400                // Also add the path length variable for length(p) calls
401                let path_length_var = format!("_path_length_{}", sp.path_alias);
402                self.context.add_variable(
403                    path_length_var.clone(),
404                    VariableInfo {
405                        name: path_length_var,
406                        data_type: LogicalType::Int64,
407                        is_node: false,
408                        is_edge: false,
409                    },
410                );
411                Ok(())
412            }
413            // SPARQL Update operators - these don't require variable binding
414            LogicalOperator::InsertTriple(insert) => {
415                if let Some(ref input) = insert.input {
416                    self.bind_operator(input)?;
417                }
418                Ok(())
419            }
420            LogicalOperator::DeleteTriple(delete) => {
421                if let Some(ref input) = delete.input {
422                    self.bind_operator(input)?;
423                }
424                Ok(())
425            }
426            LogicalOperator::Modify(modify) => {
427                self.bind_operator(&modify.where_clause)?;
428                Ok(())
429            }
430            LogicalOperator::ClearGraph(_)
431            | LogicalOperator::CreateGraph(_)
432            | LogicalOperator::DropGraph(_)
433            | LogicalOperator::LoadGraph(_)
434            | LogicalOperator::CopyGraph(_)
435            | LogicalOperator::MoveGraph(_)
436            | LogicalOperator::AddGraph(_) => Ok(()),
437        }
438    }
439
440    /// Binds a triple scan operator (for RDF/SPARQL).
441    fn bind_triple_scan(&mut self, scan: &TripleScanOp) -> Result<()> {
442        use crate::query::plan::TripleComponent;
443
444        // First bind the input if present
445        if let Some(ref input) = scan.input {
446            self.bind_operator(input)?;
447        }
448
449        // Add variables for subject, predicate, object
450        if let TripleComponent::Variable(name) = &scan.subject {
451            if !self.context.contains(name) {
452                self.context.add_variable(
453                    name.clone(),
454                    VariableInfo {
455                        name: name.clone(),
456                        data_type: LogicalType::Any, // RDF term
457                        is_node: false,
458                        is_edge: false,
459                    },
460                );
461            }
462        }
463
464        if let TripleComponent::Variable(name) = &scan.predicate {
465            if !self.context.contains(name) {
466                self.context.add_variable(
467                    name.clone(),
468                    VariableInfo {
469                        name: name.clone(),
470                        data_type: LogicalType::Any, // IRI
471                        is_node: false,
472                        is_edge: false,
473                    },
474                );
475            }
476        }
477
478        if let TripleComponent::Variable(name) = &scan.object {
479            if !self.context.contains(name) {
480                self.context.add_variable(
481                    name.clone(),
482                    VariableInfo {
483                        name: name.clone(),
484                        data_type: LogicalType::Any, // RDF term
485                        is_node: false,
486                        is_edge: false,
487                    },
488                );
489            }
490        }
491
492        if let Some(TripleComponent::Variable(name)) = &scan.graph {
493            if !self.context.contains(name) {
494                self.context.add_variable(
495                    name.clone(),
496                    VariableInfo {
497                        name: name.clone(),
498                        data_type: LogicalType::Any, // IRI
499                        is_node: false,
500                        is_edge: false,
501                    },
502                );
503            }
504        }
505
506        Ok(())
507    }
508
509    /// Binds a node scan operator.
510    fn bind_node_scan(&mut self, scan: &NodeScanOp) -> Result<()> {
511        // First bind the input if present
512        if let Some(ref input) = scan.input {
513            self.bind_operator(input)?;
514        }
515
516        // Add the scanned variable to scope
517        self.context.add_variable(
518            scan.variable.clone(),
519            VariableInfo {
520                name: scan.variable.clone(),
521                data_type: LogicalType::Node,
522                is_node: true,
523                is_edge: false,
524            },
525        );
526
527        Ok(())
528    }
529
530    /// Binds an expand operator.
531    fn bind_expand(&mut self, expand: &ExpandOp) -> Result<()> {
532        // First bind the input
533        self.bind_operator(&expand.input)?;
534
535        // Validate that the source variable is defined
536        if !self.context.contains(&expand.from_variable) {
537            return Err(binding_error(format!(
538                "Undefined variable '{}' in EXPAND",
539                expand.from_variable
540            )));
541        }
542
543        // Validate that the source is a node
544        if let Some(info) = self.context.get(&expand.from_variable) {
545            if !info.is_node {
546                return Err(binding_error(format!(
547                    "Variable '{}' is not a node, cannot expand from it",
548                    expand.from_variable
549                )));
550            }
551        }
552
553        // Add edge variable if present
554        if let Some(ref edge_var) = expand.edge_variable {
555            self.context.add_variable(
556                edge_var.clone(),
557                VariableInfo {
558                    name: edge_var.clone(),
559                    data_type: LogicalType::Edge,
560                    is_node: false,
561                    is_edge: true,
562                },
563            );
564        }
565
566        // Add target variable
567        self.context.add_variable(
568            expand.to_variable.clone(),
569            VariableInfo {
570                name: expand.to_variable.clone(),
571                data_type: LogicalType::Node,
572                is_node: true,
573                is_edge: false,
574            },
575        );
576
577        // Add path length variable for variable-length paths (for length(p) calls)
578        if let Some(ref path_alias) = expand.path_alias {
579            let path_length_var = format!("_path_length_{}", path_alias);
580            self.context.add_variable(
581                path_length_var.clone(),
582                VariableInfo {
583                    name: path_length_var,
584                    data_type: LogicalType::Int64,
585                    is_node: false,
586                    is_edge: false,
587                },
588            );
589        }
590
591        Ok(())
592    }
593
594    /// Binds a filter operator.
595    fn bind_filter(&mut self, filter: &FilterOp) -> Result<()> {
596        // First bind the input
597        self.bind_operator(&filter.input)?;
598
599        // Validate the predicate expression
600        self.validate_expression(&filter.predicate)?;
601
602        Ok(())
603    }
604
605    /// Binds a return operator.
606    fn bind_return(&mut self, ret: &ReturnOp) -> Result<()> {
607        // First bind the input
608        self.bind_operator(&ret.input)?;
609
610        // Validate all return expressions
611        for item in &ret.items {
612            self.validate_return_item(item)?;
613        }
614
615        Ok(())
616    }
617
618    /// Validates a return item.
619    fn validate_return_item(&self, item: &ReturnItem) -> Result<()> {
620        self.validate_expression(&item.expression)
621    }
622
623    /// Validates that an expression only references defined variables.
624    fn validate_expression(&self, expr: &LogicalExpression) -> Result<()> {
625        match expr {
626            LogicalExpression::Variable(name) => {
627                if !self.context.contains(name) && !name.starts_with("_anon_") {
628                    return Err(binding_error(format!("Undefined variable '{name}'")));
629                }
630                Ok(())
631            }
632            LogicalExpression::Property { variable, .. } => {
633                if !self.context.contains(variable) && !variable.starts_with("_anon_") {
634                    return Err(binding_error(format!(
635                        "Undefined variable '{variable}' in property access"
636                    )));
637                }
638                Ok(())
639            }
640            LogicalExpression::Literal(_) => Ok(()),
641            LogicalExpression::Binary { left, right, .. } => {
642                self.validate_expression(left)?;
643                self.validate_expression(right)
644            }
645            LogicalExpression::Unary { operand, .. } => self.validate_expression(operand),
646            LogicalExpression::FunctionCall { args, .. } => {
647                for arg in args {
648                    self.validate_expression(arg)?;
649                }
650                Ok(())
651            }
652            LogicalExpression::List(items) => {
653                for item in items {
654                    self.validate_expression(item)?;
655                }
656                Ok(())
657            }
658            LogicalExpression::Map(pairs) => {
659                for (_, value) in pairs {
660                    self.validate_expression(value)?;
661                }
662                Ok(())
663            }
664            LogicalExpression::IndexAccess { base, index } => {
665                self.validate_expression(base)?;
666                self.validate_expression(index)
667            }
668            LogicalExpression::SliceAccess { base, start, end } => {
669                self.validate_expression(base)?;
670                if let Some(s) = start {
671                    self.validate_expression(s)?;
672                }
673                if let Some(e) = end {
674                    self.validate_expression(e)?;
675                }
676                Ok(())
677            }
678            LogicalExpression::Case {
679                operand,
680                when_clauses,
681                else_clause,
682            } => {
683                if let Some(op) = operand {
684                    self.validate_expression(op)?;
685                }
686                for (cond, result) in when_clauses {
687                    self.validate_expression(cond)?;
688                    self.validate_expression(result)?;
689                }
690                if let Some(else_expr) = else_clause {
691                    self.validate_expression(else_expr)?;
692                }
693                Ok(())
694            }
695            // Parameter references are validated externally
696            LogicalExpression::Parameter(_) => Ok(()),
697            // labels(n), type(e), id(n) need the variable to be defined
698            LogicalExpression::Labels(var)
699            | LogicalExpression::Type(var)
700            | LogicalExpression::Id(var) => {
701                if !self.context.contains(var) && !var.starts_with("_anon_") {
702                    return Err(binding_error(format!(
703                        "Undefined variable '{var}' in function"
704                    )));
705                }
706                Ok(())
707            }
708            LogicalExpression::ListComprehension {
709                list_expr,
710                filter_expr,
711                map_expr,
712                ..
713            } => {
714                // Validate the list expression
715                self.validate_expression(list_expr)?;
716                // Note: filter_expr and map_expr use the comprehension variable
717                // which is defined within the comprehension scope, so we don't
718                // need to validate it against the outer context
719                if let Some(filter) = filter_expr {
720                    self.validate_expression(filter)?;
721                }
722                self.validate_expression(map_expr)?;
723                Ok(())
724            }
725            LogicalExpression::ExistsSubquery(subquery)
726            | LogicalExpression::CountSubquery(subquery) => {
727                // Subqueries have their own binding context
728                // For now, just validate the structure exists
729                let _ = subquery; // Would need recursive binding
730                Ok(())
731            }
732        }
733    }
734
735    /// Infers the type of an expression for use in WITH clause aliasing.
736    fn infer_expression_type(&self, expr: &LogicalExpression) -> LogicalType {
737        match expr {
738            LogicalExpression::Variable(name) => {
739                // Look up the variable type from context
740                self.context
741                    .get(name)
742                    .map(|info| info.data_type.clone())
743                    .unwrap_or(LogicalType::Any)
744            }
745            LogicalExpression::Property { .. } => LogicalType::Any, // Properties can be any type
746            LogicalExpression::Literal(value) => {
747                // Infer type from literal value
748                use grafeo_common::types::Value;
749                match value {
750                    Value::Bool(_) => LogicalType::Bool,
751                    Value::Int64(_) => LogicalType::Int64,
752                    Value::Float64(_) => LogicalType::Float64,
753                    Value::String(_) => LogicalType::String,
754                    Value::List(_) => LogicalType::Any, // Complex type
755                    Value::Map(_) => LogicalType::Any,  // Complex type
756                    Value::Null => LogicalType::Any,
757                    _ => LogicalType::Any,
758                }
759            }
760            LogicalExpression::Binary { .. } => LogicalType::Any, // Could be bool or numeric
761            LogicalExpression::Unary { .. } => LogicalType::Any,
762            LogicalExpression::FunctionCall { name, .. } => {
763                // Infer based on function name
764                match name.to_lowercase().as_str() {
765                    "count" | "sum" | "id" => LogicalType::Int64,
766                    "avg" => LogicalType::Float64,
767                    "type" => LogicalType::String,
768                    // List-returning functions use Any since we don't track element type
769                    "labels" | "collect" => LogicalType::Any,
770                    _ => LogicalType::Any,
771                }
772            }
773            LogicalExpression::List(_) => LogicalType::Any, // Complex type
774            LogicalExpression::Map(_) => LogicalType::Any,  // Complex type
775            _ => LogicalType::Any,
776        }
777    }
778
779    /// Binds a join operator.
780    fn bind_join(&mut self, join: &crate::query::plan::JoinOp) -> Result<()> {
781        // Bind both sides of the join
782        self.bind_operator(&join.left)?;
783        self.bind_operator(&join.right)?;
784
785        // Validate join conditions
786        for condition in &join.conditions {
787            self.validate_expression(&condition.left)?;
788            self.validate_expression(&condition.right)?;
789        }
790
791        Ok(())
792    }
793
794    /// Binds an aggregate operator.
795    fn bind_aggregate(&mut self, agg: &crate::query::plan::AggregateOp) -> Result<()> {
796        // Bind the input first
797        self.bind_operator(&agg.input)?;
798
799        // Validate group by expressions
800        for expr in &agg.group_by {
801            self.validate_expression(expr)?;
802        }
803
804        // Validate aggregate expressions
805        for agg_expr in &agg.aggregates {
806            if let Some(ref expr) = agg_expr.expression {
807                self.validate_expression(expr)?;
808            }
809            // Add the alias as a new variable if present
810            if let Some(ref alias) = agg_expr.alias {
811                self.context.add_variable(
812                    alias.clone(),
813                    VariableInfo {
814                        name: alias.clone(),
815                        data_type: LogicalType::Any,
816                        is_node: false,
817                        is_edge: false,
818                    },
819                );
820            }
821        }
822
823        Ok(())
824    }
825}
826
827impl Default for Binder {
828    fn default() -> Self {
829        Self::new()
830    }
831}
832
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{BinaryOp, FilterOp};

    /// Shorthand for a variable-reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// A leaf `NodeScan` that binds `variable`, optionally restricted to `label`.
    fn node_scan(variable: &str, label: Option<&str>) -> LogicalOperator {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: label.map(String::from),
            input: None,
        })
    }

    /// A non-distinct `Return` of `expressions` (no aliases) over `input`.
    fn return_of(expressions: Vec<LogicalExpression>, input: LogicalOperator) -> LogicalPlan {
        LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: expressions
                .into_iter()
                .map(|expression| ReturnItem {
                    expression,
                    alias: None,
                })
                .collect(),
            distinct: false,
            input: Box::new(input),
        }))
    }

    /// Returning the variable a scan introduces binds it as a node.
    #[test]
    fn test_bind_simple_scan() {
        let plan = return_of(vec![var("n")], node_scan("n", Some("Person")));

        let mut binder = Binder::new();
        let ctx = binder
            .bind(&plan)
            .expect("returning a scanned variable should bind");

        assert!(ctx.contains("n"));
        assert!(ctx.get("n").expect("`n` should be bound").is_node);
    }

    /// Returning a variable that nothing introduced is a semantic error.
    #[test]
    fn test_bind_undefined_variable() {
        let plan = return_of(vec![var("undefined")], node_scan("n", None));

        let mut binder = Binder::new();
        let msg = binder
            .bind(&plan)
            .expect_err("returning an unmatched variable must be rejected")
            .to_string();

        assert!(
            msg.contains("Undefined variable"),
            "unexpected error message: {}",
            msg
        );
    }

    /// Property access on a bound node variable is accepted.
    #[test]
    fn test_bind_property_access() {
        let plan = return_of(
            vec![LogicalExpression::Property {
                variable: "n".to_string(),
                property: "name".to_string(),
            }],
            node_scan("n", Some("Person")),
        );

        let mut binder = Binder::new();
        let ctx = binder
            .bind(&plan)
            .expect("property access on a scanned node should bind");

        assert!(ctx.contains("n"));
    }

    /// A filter predicate referencing an unbound variable is rejected, and the
    /// error names the offending variable.
    #[test]
    fn test_bind_filter_with_undefined_variable() {
        let filter = LogicalOperator::Filter(FilterOp {
            predicate: LogicalExpression::Binary {
                left: Box::new(LogicalExpression::Property {
                    variable: "m".to_string(), // never bound by the scan below
                    property: "age".to_string(),
                }),
                op: BinaryOp::Gt,
                right: Box::new(LogicalExpression::Literal(
                    grafeo_common::types::Value::Int64(30),
                )),
            },
            input: Box::new(node_scan("n", None)),
        });
        let plan = return_of(vec![var("n")], filter);

        let mut binder = Binder::new();
        let msg = binder
            .bind(&plan)
            .expect_err("filtering on an unbound variable must be rejected")
            .to_string();

        assert!(
            msg.contains("Undefined variable 'm'"),
            "unexpected error message: {}",
            msg
        );
    }

    /// An expand binds its source and target as nodes and its edge as an edge.
    #[test]
    fn test_bind_expand() {
        use crate::query::plan::{ExpandDirection, ExpandOp};

        let expand = LogicalOperator::Expand(ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: Some("e".to_string()),
            direction: ExpandDirection::Outgoing,
            edge_type: Some("KNOWS".to_string()),
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(node_scan("a", Some("Person"))),
            path_alias: None,
        });
        let plan = return_of(vec![var("a"), var("b")], expand);

        let mut binder = Binder::new();
        let ctx = binder
            .bind(&plan)
            .expect("expanding from a scanned node should bind");

        assert!(ctx.contains("a"));
        assert!(ctx.contains("b"));
        assert!(ctx.contains("e"));
        assert!(ctx.get("a").expect("`a` should be bound").is_node);
        assert!(ctx.get("b").expect("`b` should be bound").is_node);
        assert!(ctx.get("e").expect("`e` should be bound").is_edge);
    }
}