Skip to main content

graphos_engine/query/
gql_translator.rs

1//! GQL to LogicalPlan translator.
2//!
3//! Translates GQL AST to the common logical plan representation.
4
5use crate::query::plan::{
6    AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DeleteNodeOp, DistinctOp,
7    ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp, LogicalExpression,
8    LogicalOperator, LogicalPlan, NodeScanOp, ProjectOp, Projection, ReturnItem, ReturnOp,
9    SetPropertyOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
10};
11use graphos_adapters::query::gql::{self, ast};
12use graphos_common::types::Value;
13use graphos_common::utils::error::{Error, Result};
14
15/// Translates a GQL query string to a logical plan.
16///
17/// # Errors
18///
19/// Returns an error if the query cannot be parsed or translated.
20pub fn translate(query: &str) -> Result<LogicalPlan> {
21    let statement = gql::parse(query)?;
22    let translator = GqlTranslator::new();
23    translator.translate_statement(&statement)
24}
25
26/// Translator from GQL AST to LogicalPlan.
27struct GqlTranslator;
28
29impl GqlTranslator {
30    fn new() -> Self {
31        Self
32    }
33
34    fn translate_statement(&self, stmt: &ast::Statement) -> Result<LogicalPlan> {
35        match stmt {
36            ast::Statement::Query(query) => self.translate_query(query),
37            ast::Statement::DataModification(dm) => self.translate_data_modification(dm),
38            ast::Statement::Schema(_) => Err(Error::Internal(
39                "Schema statements not yet supported".to_string(),
40            )),
41        }
42    }
43
44    fn translate_query(&self, query: &ast::QueryStatement) -> Result<LogicalPlan> {
45        // Start with the pattern scan (MATCH clauses)
46        let mut plan = LogicalOperator::Empty;
47
48        for match_clause in &query.match_clauses {
49            let match_plan = self.translate_match(match_clause)?;
50            if matches!(plan, LogicalOperator::Empty) {
51                plan = match_plan;
52            } else if match_clause.optional {
53                // OPTIONAL MATCH uses LEFT JOIN semantics
54                plan = LogicalOperator::LeftJoin(LeftJoinOp {
55                    left: Box::new(plan),
56                    right: Box::new(match_plan),
57                    condition: None,
58                });
59            } else {
60                // Regular MATCH - combine with cross join (implicit join on shared variables)
61                plan = LogicalOperator::Join(JoinOp {
62                    left: Box::new(plan),
63                    right: Box::new(match_plan),
64                    join_type: JoinType::Cross,
65                    conditions: vec![],
66                });
67            }
68        }
69
70        // Apply WHERE filter
71        if let Some(where_clause) = &query.where_clause {
72            let predicate = self.translate_expression(&where_clause.expression)?;
73            plan = LogicalOperator::Filter(FilterOp {
74                predicate,
75                input: Box::new(plan),
76            });
77        }
78
79        // Handle WITH clauses (projection for query chaining)
80        for with_clause in &query.with_clauses {
81            let projections: Vec<Projection> = with_clause
82                .items
83                .iter()
84                .map(|item| {
85                    Ok(Projection {
86                        expression: self.translate_expression(&item.expression)?,
87                        alias: item.alias.clone(),
88                    })
89                })
90                .collect::<Result<_>>()?;
91
92            plan = LogicalOperator::Project(ProjectOp {
93                projections,
94                input: Box::new(plan),
95            });
96
97            // Apply WHERE filter if present in WITH clause
98            if let Some(where_clause) = &with_clause.where_clause {
99                let predicate = self.translate_expression(&where_clause.expression)?;
100                plan = LogicalOperator::Filter(FilterOp {
101                    predicate,
102                    input: Box::new(plan),
103                });
104            }
105
106            // Handle DISTINCT
107            if with_clause.distinct {
108                plan = LogicalOperator::Distinct(DistinctOp {
109                    input: Box::new(plan),
110                });
111            }
112        }
113
114        // Apply SKIP
115        if let Some(skip_expr) = &query.return_clause.skip {
116            if let ast::Expression::Literal(ast::Literal::Integer(n)) = skip_expr {
117                plan = LogicalOperator::Skip(SkipOp {
118                    count: *n as usize,
119                    input: Box::new(plan),
120                });
121            }
122        }
123
124        // Apply LIMIT
125        if let Some(limit_expr) = &query.return_clause.limit {
126            if let ast::Expression::Literal(ast::Literal::Integer(n)) = limit_expr {
127                plan = LogicalOperator::Limit(LimitOp {
128                    count: *n as usize,
129                    input: Box::new(plan),
130                });
131            }
132        }
133
134        // Check if RETURN contains aggregate functions
135        let has_aggregates = query
136            .return_clause
137            .items
138            .iter()
139            .any(|item| contains_aggregate(&item.expression));
140
141        if has_aggregates {
142            // Extract aggregate and group-by expressions
143            let (aggregates, group_by) =
144                self.extract_aggregates_and_groups(&query.return_clause.items)?;
145
146            // Insert Aggregate operator - this is the final operator for aggregate queries
147            // The aggregate operator produces the output columns directly
148            plan = LogicalOperator::Aggregate(AggregateOp {
149                group_by,
150                aggregates,
151                input: Box::new(plan),
152            });
153
154            // Note: For aggregate queries, we don't add a Return operator
155            // because Aggregate already produces the final output
156        } else {
157            // Apply ORDER BY
158            if let Some(order_by) = &query.return_clause.order_by {
159                let keys = order_by
160                    .items
161                    .iter()
162                    .map(|item| {
163                        Ok(SortKey {
164                            expression: self.translate_expression(&item.expression)?,
165                            order: match item.order {
166                                ast::SortOrder::Asc => SortOrder::Ascending,
167                                ast::SortOrder::Desc => SortOrder::Descending,
168                            },
169                        })
170                    })
171                    .collect::<Result<Vec<_>>>()?;
172
173                plan = LogicalOperator::Sort(SortOp {
174                    keys,
175                    input: Box::new(plan),
176                });
177            }
178
179            // Apply RETURN
180            let return_items = query
181                .return_clause
182                .items
183                .iter()
184                .map(|item| {
185                    Ok(ReturnItem {
186                        expression: self.translate_expression(&item.expression)?,
187                        alias: item.alias.clone(),
188                    })
189                })
190                .collect::<Result<Vec<_>>>()?;
191
192            plan = LogicalOperator::Return(ReturnOp {
193                items: return_items,
194                distinct: query.return_clause.distinct,
195                input: Box::new(plan),
196            });
197        }
198
199        Ok(LogicalPlan::new(plan))
200    }
201
202    /// Builds return items for an aggregate query.
203    #[allow(dead_code)]
204    fn build_aggregate_return_items(&self, items: &[ast::ReturnItem]) -> Result<Vec<ReturnItem>> {
205        let mut return_items = Vec::new();
206        let mut agg_idx = 0;
207
208        for item in items {
209            if contains_aggregate(&item.expression) {
210                // For aggregate expressions, use a variable reference to the aggregate result
211                let alias = item.alias.clone().unwrap_or_else(|| {
212                    if let ast::Expression::FunctionCall { name, .. } = &item.expression {
213                        format!("{}(...)", name.to_lowercase())
214                    } else {
215                        format!("agg_{}", agg_idx)
216                    }
217                });
218                return_items.push(ReturnItem {
219                    expression: LogicalExpression::Variable(format!("__agg_{}", agg_idx)),
220                    alias: Some(alias),
221                });
222                agg_idx += 1;
223            } else {
224                // Non-aggregate expressions are group-by columns
225                return_items.push(ReturnItem {
226                    expression: self.translate_expression(&item.expression)?,
227                    alias: item.alias.clone(),
228                });
229            }
230        }
231
232        Ok(return_items)
233    }
234
235    fn translate_match(&self, match_clause: &ast::MatchClause) -> Result<LogicalOperator> {
236        let mut plan: Option<LogicalOperator> = None;
237
238        for pattern in &match_clause.patterns {
239            let pattern_plan = self.translate_pattern(pattern, plan.take())?;
240            plan = Some(pattern_plan);
241        }
242
243        plan.ok_or_else(|| Error::Internal("Empty MATCH clause".to_string()))
244    }
245
246    fn translate_pattern(
247        &self,
248        pattern: &ast::Pattern,
249        input: Option<LogicalOperator>,
250    ) -> Result<LogicalOperator> {
251        match pattern {
252            ast::Pattern::Node(node) => self.translate_node_pattern(node, input),
253            ast::Pattern::Path(path) => self.translate_path_pattern(path, input),
254        }
255    }
256
257    fn translate_node_pattern(
258        &self,
259        node: &ast::NodePattern,
260        input: Option<LogicalOperator>,
261    ) -> Result<LogicalOperator> {
262        let variable = node
263            .variable
264            .clone()
265            .unwrap_or_else(|| format!("_anon_{}", rand_id()));
266
267        let label = node.labels.first().cloned();
268
269        Ok(LogicalOperator::NodeScan(NodeScanOp {
270            variable,
271            label,
272            input: input.map(Box::new),
273        }))
274    }
275
276    fn translate_path_pattern(
277        &self,
278        path: &ast::PathPattern,
279        input: Option<LogicalOperator>,
280    ) -> Result<LogicalOperator> {
281        // Start with the source node
282        let source_var = path
283            .source
284            .variable
285            .clone()
286            .unwrap_or_else(|| format!("_anon_{}", rand_id()));
287
288        let source_label = path.source.labels.first().cloned();
289
290        let mut plan = LogicalOperator::NodeScan(NodeScanOp {
291            variable: source_var.clone(),
292            label: source_label,
293            input: input.map(Box::new),
294        });
295
296        // Process each edge in the chain
297        let mut current_source = source_var;
298
299        for edge in &path.edges {
300            let target_var = edge
301                .target
302                .variable
303                .clone()
304                .unwrap_or_else(|| format!("_anon_{}", rand_id()));
305
306            let edge_var = edge.variable.clone();
307            let edge_type = edge.types.first().cloned();
308
309            let direction = match edge.direction {
310                ast::EdgeDirection::Outgoing => ExpandDirection::Outgoing,
311                ast::EdgeDirection::Incoming => ExpandDirection::Incoming,
312                ast::EdgeDirection::Undirected => ExpandDirection::Both,
313            };
314
315            plan = LogicalOperator::Expand(ExpandOp {
316                from_variable: current_source,
317                to_variable: target_var.clone(),
318                edge_variable: edge_var,
319                direction,
320                edge_type,
321                min_hops: 1,
322                max_hops: Some(1),
323                input: Box::new(plan),
324            });
325
326            current_source = target_var;
327        }
328
329        Ok(plan)
330    }
331
332    fn translate_data_modification(
333        &self,
334        dm: &ast::DataModificationStatement,
335    ) -> Result<LogicalPlan> {
336        match dm {
337            ast::DataModificationStatement::Insert(insert) => self.translate_insert(insert),
338            ast::DataModificationStatement::Delete(delete) => self.translate_delete(delete),
339            ast::DataModificationStatement::Set(set) => self.translate_set(set),
340        }
341    }
342
343    fn translate_delete(&self, delete: &ast::DeleteStatement) -> Result<LogicalPlan> {
344        // DELETE requires a preceding MATCH clause to identify what to delete.
345        // For standalone DELETE, we need to scan and delete the specified variables.
346        // This is typically used as: MATCH (n:Label) DELETE n
347
348        if delete.variables.is_empty() {
349            return Err(Error::Internal(
350                "DELETE requires at least one variable".to_string(),
351            ));
352        }
353
354        // For now, we only support deleting nodes (not edges directly)
355        // Build a chain of delete operators
356        let first_var = &delete.variables[0];
357
358        // Create a scan to find the entities to delete
359        let scan = LogicalOperator::NodeScan(NodeScanOp {
360            variable: first_var.clone(),
361            label: None,
362            input: None,
363        });
364
365        // Delete the first variable
366        let mut plan = LogicalOperator::DeleteNode(DeleteNodeOp {
367            variable: first_var.clone(),
368            input: Box::new(scan),
369        });
370
371        // Chain additional deletes
372        for var in delete.variables.iter().skip(1) {
373            plan = LogicalOperator::DeleteNode(DeleteNodeOp {
374                variable: var.clone(),
375                input: Box::new(plan),
376            });
377        }
378
379        Ok(LogicalPlan::new(plan))
380    }
381
382    fn translate_set(&self, set: &ast::SetStatement) -> Result<LogicalPlan> {
383        // SET requires a preceding MATCH clause to identify what to update.
384        // For standalone SET, we error - it should be part of a query.
385
386        if set.assignments.is_empty() {
387            return Err(Error::Internal(
388                "SET requires at least one assignment".to_string(),
389            ));
390        }
391
392        // Group assignments by variable
393        let first_assignment = &set.assignments[0];
394        let var = &first_assignment.variable;
395
396        // Create a scan to find the entity to update
397        let scan = LogicalOperator::NodeScan(NodeScanOp {
398            variable: var.clone(),
399            label: None,
400            input: None,
401        });
402
403        // Build property assignments for this variable
404        let properties: Vec<(String, LogicalExpression)> = set
405            .assignments
406            .iter()
407            .filter(|a| &a.variable == var)
408            .map(|a| Ok((a.property.clone(), self.translate_expression(&a.value)?)))
409            .collect::<Result<_>>()?;
410
411        let plan = LogicalOperator::SetProperty(SetPropertyOp {
412            variable: var.clone(),
413            properties,
414            replace: false,
415            input: Box::new(scan),
416        });
417
418        Ok(LogicalPlan::new(plan))
419    }
420
421    fn translate_insert(&self, insert: &ast::InsertStatement) -> Result<LogicalPlan> {
422        // For now, just translate insert patterns as creates
423        // A full implementation would handle multiple patterns
424
425        if insert.patterns.is_empty() {
426            return Err(Error::Internal("Empty INSERT statement".to_string()));
427        }
428
429        let pattern = &insert.patterns[0];
430
431        match pattern {
432            ast::Pattern::Node(node) => {
433                let variable = node
434                    .variable
435                    .clone()
436                    .unwrap_or_else(|| format!("_anon_{}", rand_id()));
437
438                let properties = node
439                    .properties
440                    .iter()
441                    .map(|(k, v)| Ok((k.clone(), self.translate_expression(v)?)))
442                    .collect::<Result<Vec<_>>>()?;
443
444                let create = LogicalOperator::CreateNode(crate::query::plan::CreateNodeOp {
445                    variable: variable.clone(),
446                    labels: node.labels.clone(),
447                    properties,
448                    input: None,
449                });
450
451                // Return the created node
452                let ret = LogicalOperator::Return(ReturnOp {
453                    items: vec![ReturnItem {
454                        expression: LogicalExpression::Variable(variable),
455                        alias: None,
456                    }],
457                    distinct: false,
458                    input: Box::new(create),
459                });
460
461                Ok(LogicalPlan::new(ret))
462            }
463            ast::Pattern::Path(_) => {
464                Err(Error::Internal("Path INSERT not yet supported".to_string()))
465            }
466        }
467    }
468
469    fn translate_expression(&self, expr: &ast::Expression) -> Result<LogicalExpression> {
470        match expr {
471            ast::Expression::Literal(lit) => Ok(self.translate_literal(lit)),
472            ast::Expression::Variable(name) => Ok(LogicalExpression::Variable(name.clone())),
473            ast::Expression::Parameter(name) => Ok(LogicalExpression::Parameter(name.clone())),
474            ast::Expression::PropertyAccess { variable, property } => {
475                Ok(LogicalExpression::Property {
476                    variable: variable.clone(),
477                    property: property.clone(),
478                })
479            }
480            ast::Expression::Binary { left, op, right } => {
481                let left = self.translate_expression(left)?;
482                let right = self.translate_expression(right)?;
483                let op = self.translate_binary_op(*op);
484                Ok(LogicalExpression::Binary {
485                    left: Box::new(left),
486                    op,
487                    right: Box::new(right),
488                })
489            }
490            ast::Expression::Unary { op, operand } => {
491                let operand = self.translate_expression(operand)?;
492                let op = self.translate_unary_op(*op);
493                Ok(LogicalExpression::Unary {
494                    op,
495                    operand: Box::new(operand),
496                })
497            }
498            ast::Expression::FunctionCall { name, args } => {
499                let args = args
500                    .iter()
501                    .map(|a| self.translate_expression(a))
502                    .collect::<Result<Vec<_>>>()?;
503                Ok(LogicalExpression::FunctionCall {
504                    name: name.clone(),
505                    args,
506                })
507            }
508            ast::Expression::List(items) => {
509                let items = items
510                    .iter()
511                    .map(|i| self.translate_expression(i))
512                    .collect::<Result<Vec<_>>>()?;
513                Ok(LogicalExpression::List(items))
514            }
515            ast::Expression::Case {
516                input,
517                whens,
518                else_clause,
519            } => {
520                let operand = input
521                    .as_ref()
522                    .map(|e| self.translate_expression(e))
523                    .transpose()?
524                    .map(Box::new);
525
526                let when_clauses = whens
527                    .iter()
528                    .map(|(cond, result)| {
529                        Ok((
530                            self.translate_expression(cond)?,
531                            self.translate_expression(result)?,
532                        ))
533                    })
534                    .collect::<Result<Vec<_>>>()?;
535
536                let else_clause = else_clause
537                    .as_ref()
538                    .map(|e| self.translate_expression(e))
539                    .transpose()?
540                    .map(Box::new);
541
542                Ok(LogicalExpression::Case {
543                    operand,
544                    when_clauses,
545                    else_clause,
546                })
547            }
548            ast::Expression::ExistsSubquery { query } => {
549                // Translate inner query to logical operator
550                let inner_plan = self.translate_subquery_to_operator(query)?;
551                Ok(LogicalExpression::ExistsSubquery(Box::new(inner_plan)))
552            }
553        }
554    }
555
556    fn translate_literal(&self, lit: &ast::Literal) -> LogicalExpression {
557        let value = match lit {
558            ast::Literal::Null => Value::Null,
559            ast::Literal::Bool(b) => Value::Bool(*b),
560            ast::Literal::Integer(i) => Value::Int64(*i),
561            ast::Literal::Float(f) => Value::Float64(*f),
562            ast::Literal::String(s) => Value::String(s.clone().into()),
563        };
564        LogicalExpression::Literal(value)
565    }
566
567    fn translate_binary_op(&self, op: ast::BinaryOp) -> BinaryOp {
568        match op {
569            ast::BinaryOp::Eq => BinaryOp::Eq,
570            ast::BinaryOp::Ne => BinaryOp::Ne,
571            ast::BinaryOp::Lt => BinaryOp::Lt,
572            ast::BinaryOp::Le => BinaryOp::Le,
573            ast::BinaryOp::Gt => BinaryOp::Gt,
574            ast::BinaryOp::Ge => BinaryOp::Ge,
575            ast::BinaryOp::And => BinaryOp::And,
576            ast::BinaryOp::Or => BinaryOp::Or,
577            ast::BinaryOp::Add => BinaryOp::Add,
578            ast::BinaryOp::Sub => BinaryOp::Sub,
579            ast::BinaryOp::Mul => BinaryOp::Mul,
580            ast::BinaryOp::Div => BinaryOp::Div,
581            ast::BinaryOp::Mod => BinaryOp::Mod,
582            ast::BinaryOp::Concat => BinaryOp::Concat,
583            ast::BinaryOp::Like => BinaryOp::Like,
584            ast::BinaryOp::In => BinaryOp::In,
585        }
586    }
587
588    fn translate_unary_op(&self, op: ast::UnaryOp) -> UnaryOp {
589        match op {
590            ast::UnaryOp::Not => UnaryOp::Not,
591            ast::UnaryOp::Neg => UnaryOp::Neg,
592            ast::UnaryOp::IsNull => UnaryOp::IsNull,
593            ast::UnaryOp::IsNotNull => UnaryOp::IsNotNull,
594        }
595    }
596
597    /// Translates a subquery to a logical operator (without Return).
598    fn translate_subquery_to_operator(
599        &self,
600        query: &ast::QueryStatement,
601    ) -> Result<LogicalOperator> {
602        let mut plan = LogicalOperator::Empty;
603
604        for match_clause in &query.match_clauses {
605            let match_plan = self.translate_match(match_clause)?;
606            plan = if matches!(plan, LogicalOperator::Empty) {
607                match_plan
608            } else {
609                LogicalOperator::Join(JoinOp {
610                    left: Box::new(plan),
611                    right: Box::new(match_plan),
612                    join_type: JoinType::Cross,
613                    conditions: vec![],
614                })
615            };
616        }
617
618        if let Some(where_clause) = &query.where_clause {
619            let predicate = self.translate_expression(&where_clause.expression)?;
620            plan = LogicalOperator::Filter(FilterOp {
621                predicate,
622                input: Box::new(plan),
623            });
624        }
625
626        Ok(plan)
627    }
628
629    /// Extracts aggregate expressions and group-by expressions from RETURN items.
630    fn extract_aggregates_and_groups(
631        &self,
632        items: &[ast::ReturnItem],
633    ) -> Result<(Vec<AggregateExpr>, Vec<LogicalExpression>)> {
634        let mut aggregates = Vec::new();
635        let mut group_by = Vec::new();
636
637        for item in items {
638            if let Some(agg_expr) = self.try_extract_aggregate(&item.expression, &item.alias)? {
639                aggregates.push(agg_expr);
640            } else {
641                // Non-aggregate expressions become group-by keys
642                let expr = self.translate_expression(&item.expression)?;
643                group_by.push(expr);
644            }
645        }
646
647        Ok((aggregates, group_by))
648    }
649
650    /// Tries to extract an aggregate expression from an AST expression.
651    fn try_extract_aggregate(
652        &self,
653        expr: &ast::Expression,
654        alias: &Option<String>,
655    ) -> Result<Option<AggregateExpr>> {
656        match expr {
657            ast::Expression::FunctionCall { name, args } => {
658                if let Some(func) = to_aggregate_function(name) {
659                    let agg_expr = if args.is_empty() {
660                        // COUNT(*) case
661                        AggregateExpr {
662                            function: func,
663                            expression: None,
664                            distinct: false,
665                            alias: alias.clone(),
666                        }
667                    } else {
668                        // COUNT(x), SUM(x), etc.
669                        AggregateExpr {
670                            function: func,
671                            expression: Some(self.translate_expression(&args[0])?),
672                            distinct: false,
673                            alias: alias.clone(),
674                        }
675                    };
676                    Ok(Some(agg_expr))
677                } else {
678                    Ok(None)
679                }
680            }
681            _ => Ok(None),
682        }
683    }
684}
685
/// Returns the next value of a process-wide counter. The value is used to give
/// anonymous pattern variables distinct names such as `_anon_3`.
///
/// Despite the name, the result is not random. Values start at 0, grow by
/// one per call, and wrap around after `u32::MAX`. `Relaxed` ordering is enough
/// because only the atomicity of the increment matters, not its ordering
/// relative to other memory operations.
fn rand_id() -> u32 {
    static NEXT_ANON_ID: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
    NEXT_ANON_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
692
/// Reports whether `name` is one of the supported aggregate functions
/// (`COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, `COLLECT`).
///
/// The comparison is made on the Unicode upper-cased name, so the check is
/// case-insensitive.
fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 6] = ["COUNT", "SUM", "AVG", "MIN", "MAX", "COLLECT"];
    let upper = name.to_uppercase();
    AGGREGATE_NAMES.contains(&upper.as_str())
}
700
701/// Converts a function name to an AggregateFunction enum.
702fn to_aggregate_function(name: &str) -> Option<AggregateFunction> {
703    match name.to_uppercase().as_str() {
704        "COUNT" => Some(AggregateFunction::Count),
705        "SUM" => Some(AggregateFunction::Sum),
706        "AVG" => Some(AggregateFunction::Avg),
707        "MIN" => Some(AggregateFunction::Min),
708        "MAX" => Some(AggregateFunction::Max),
709        "COLLECT" => Some(AggregateFunction::Collect),
710        _ => None,
711    }
712}
713
714/// Checks if an AST expression contains an aggregate function call.
715fn contains_aggregate(expr: &ast::Expression) -> bool {
716    match expr {
717        ast::Expression::FunctionCall { name, .. } => is_aggregate_function(name),
718        ast::Expression::Binary { left, right, .. } => {
719            contains_aggregate(left) || contains_aggregate(right)
720        }
721        ast::Expression::Unary { operand, .. } => contains_aggregate(operand),
722        _ => false,
723    }
724}
725
726#[cfg(test)]
727mod tests {
728    use super::*;
729
730    // === Basic MATCH Tests ===
731
732    #[test]
733    fn test_translate_simple_match() {
734        let query = "MATCH (n:Person) RETURN n";
735        let result = translate(query);
736        assert!(result.is_ok());
737
738        let plan = result.unwrap();
739        if let LogicalOperator::Return(ret) = &plan.root {
740            assert_eq!(ret.items.len(), 1);
741            assert!(!ret.distinct);
742        } else {
743            panic!("Expected Return operator");
744        }
745    }
746
747    #[test]
748    fn test_translate_match_with_where() {
749        let query = "MATCH (n:Person) WHERE n.age > 30 RETURN n.name";
750        let result = translate(query);
751        assert!(result.is_ok());
752
753        let plan = result.unwrap();
754        if let LogicalOperator::Return(ret) = &plan.root {
755            // Should have Filter as input
756            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
757                if let LogicalExpression::Binary { op, .. } = &filter.predicate {
758                    assert_eq!(*op, BinaryOp::Gt);
759                } else {
760                    panic!("Expected binary expression");
761                }
762            } else {
763                panic!("Expected Filter operator");
764            }
765        } else {
766            panic!("Expected Return operator");
767        }
768    }
769
770    #[test]
771    fn test_translate_match_without_label() {
772        let query = "MATCH (n) RETURN n";
773        let result = translate(query);
774        assert!(result.is_ok());
775
776        let plan = result.unwrap();
777        if let LogicalOperator::Return(ret) = &plan.root {
778            if let LogicalOperator::NodeScan(scan) = ret.input.as_ref() {
779                assert!(scan.label.is_none());
780            } else {
781                panic!("Expected NodeScan operator");
782            }
783        } else {
784            panic!("Expected Return operator");
785        }
786    }
787
788    #[test]
789    fn test_translate_match_distinct() {
790        let query = "MATCH (n:Person) RETURN DISTINCT n.name";
791        let result = translate(query);
792        assert!(result.is_ok());
793
794        let plan = result.unwrap();
795        if let LogicalOperator::Return(ret) = &plan.root {
796            assert!(ret.distinct);
797        } else {
798            panic!("Expected Return operator");
799        }
800    }
801
802    // === Filter and Predicate Tests ===
803
804    #[test]
805    fn test_translate_filter_equality() {
806        let query = "MATCH (n:Person) WHERE n.name = 'Alice' RETURN n";
807        let result = translate(query);
808        assert!(result.is_ok());
809
810        let plan = result.unwrap();
811        // Navigate to find Filter
812        fn find_filter(op: &LogicalOperator) -> Option<&FilterOp> {
813            match op {
814                LogicalOperator::Filter(f) => Some(f),
815                LogicalOperator::Return(r) => find_filter(&r.input),
816                _ => None,
817            }
818        }
819
820        let filter = find_filter(&plan.root).expect("Expected Filter");
821        if let LogicalExpression::Binary { op, .. } = &filter.predicate {
822            assert_eq!(*op, BinaryOp::Eq);
823        }
824    }
825
826    #[test]
827    fn test_translate_filter_and() {
828        let query = "MATCH (n:Person) WHERE n.age > 20 AND n.age < 40 RETURN n";
829        let result = translate(query);
830        assert!(result.is_ok());
831
832        let plan = result.unwrap();
833        fn find_filter(op: &LogicalOperator) -> Option<&FilterOp> {
834            match op {
835                LogicalOperator::Filter(f) => Some(f),
836                LogicalOperator::Return(r) => find_filter(&r.input),
837                _ => None,
838            }
839        }
840
841        let filter = find_filter(&plan.root).expect("Expected Filter");
842        if let LogicalExpression::Binary { op, .. } = &filter.predicate {
843            assert_eq!(*op, BinaryOp::And);
844        }
845    }
846
847    #[test]
848    fn test_translate_filter_or() {
849        let query = "MATCH (n:Person) WHERE n.name = 'Alice' OR n.name = 'Bob' RETURN n";
850        let result = translate(query);
851        assert!(result.is_ok());
852
853        let plan = result.unwrap();
854        fn find_filter(op: &LogicalOperator) -> Option<&FilterOp> {
855            match op {
856                LogicalOperator::Filter(f) => Some(f),
857                LogicalOperator::Return(r) => find_filter(&r.input),
858                _ => None,
859            }
860        }
861
862        let filter = find_filter(&plan.root).expect("Expected Filter");
863        if let LogicalExpression::Binary { op, .. } = &filter.predicate {
864            assert_eq!(*op, BinaryOp::Or);
865        }
866    }
867
868    #[test]
869    fn test_translate_filter_not() {
870        let query = "MATCH (n:Person) WHERE NOT n.active RETURN n";
871        let result = translate(query);
872        assert!(result.is_ok());
873
874        let plan = result.unwrap();
875        fn find_filter(op: &LogicalOperator) -> Option<&FilterOp> {
876            match op {
877                LogicalOperator::Filter(f) => Some(f),
878                LogicalOperator::Return(r) => find_filter(&r.input),
879                _ => None,
880            }
881        }
882
883        let filter = find_filter(&plan.root).expect("Expected Filter");
884        if let LogicalExpression::Unary { op, .. } = &filter.predicate {
885            assert_eq!(*op, UnaryOp::Not);
886        }
887    }
888
889    // === Path Pattern / Join Tests ===
890
891    #[test]
892    fn test_translate_path_pattern() {
893        let query = "MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a, b";
894        let result = translate(query);
895        assert!(result.is_ok());
896
897        let plan = result.unwrap();
898        // Find Expand operator
899        fn find_expand(op: &LogicalOperator) -> Option<&ExpandOp> {
900            match op {
901                LogicalOperator::Expand(e) => Some(e),
902                LogicalOperator::Return(r) => find_expand(&r.input),
903                LogicalOperator::Filter(f) => find_expand(&f.input),
904                _ => None,
905            }
906        }
907
908        let expand = find_expand(&plan.root).expect("Expected Expand");
909        assert_eq!(expand.direction, ExpandDirection::Outgoing);
910        assert_eq!(expand.edge_type.as_deref(), Some("KNOWS"));
911    }
912
913    #[test]
914    fn test_translate_incoming_path() {
915        let query = "MATCH (a:Person)<-[:KNOWS]-(b:Person) RETURN a, b";
916        let result = translate(query);
917        assert!(result.is_ok());
918
919        let plan = result.unwrap();
920        fn find_expand(op: &LogicalOperator) -> Option<&ExpandOp> {
921            match op {
922                LogicalOperator::Expand(e) => Some(e),
923                LogicalOperator::Return(r) => find_expand(&r.input),
924                _ => None,
925            }
926        }
927
928        let expand = find_expand(&plan.root).expect("Expected Expand");
929        assert_eq!(expand.direction, ExpandDirection::Incoming);
930    }
931
932    #[test]
933    fn test_translate_undirected_path() {
934        let query = "MATCH (a:Person)-[:KNOWS]-(b:Person) RETURN a, b";
935        let result = translate(query);
936        assert!(result.is_ok());
937
938        let plan = result.unwrap();
939        fn find_expand(op: &LogicalOperator) -> Option<&ExpandOp> {
940            match op {
941                LogicalOperator::Expand(e) => Some(e),
942                LogicalOperator::Return(r) => find_expand(&r.input),
943                _ => None,
944            }
945        }
946
947        let expand = find_expand(&plan.root).expect("Expected Expand");
948        assert_eq!(expand.direction, ExpandDirection::Both);
949    }
950
951    // === Aggregation Tests ===
952
953    #[test]
954    fn test_translate_count_aggregate() {
955        let query = "MATCH (n:Person) RETURN COUNT(n)";
956        let result = translate(query);
957        assert!(result.is_ok());
958
959        let plan = result.unwrap();
960        if let LogicalOperator::Aggregate(agg) = &plan.root {
961            assert_eq!(agg.aggregates.len(), 1);
962            assert_eq!(agg.aggregates[0].function, AggregateFunction::Count);
963        } else {
964            panic!("Expected Aggregate operator, got {:?}", plan.root);
965        }
966    }
967
968    #[test]
969    fn test_translate_sum_aggregate() {
970        let query = "MATCH (n:Person) RETURN SUM(n.age)";
971        let result = translate(query);
972        assert!(result.is_ok());
973
974        let plan = result.unwrap();
975        if let LogicalOperator::Aggregate(agg) = &plan.root {
976            assert_eq!(agg.aggregates.len(), 1);
977            assert_eq!(agg.aggregates[0].function, AggregateFunction::Sum);
978        } else {
979            panic!("Expected Aggregate operator");
980        }
981    }
982
983    #[test]
984    fn test_translate_group_by_aggregate() {
985        let query = "MATCH (n:Person) RETURN n.city, COUNT(n)";
986        let result = translate(query);
987        assert!(result.is_ok());
988
989        let plan = result.unwrap();
990        if let LogicalOperator::Aggregate(agg) = &plan.root {
991            assert_eq!(agg.group_by.len(), 1); // n.city
992            assert_eq!(agg.aggregates.len(), 1); // COUNT(n)
993        } else {
994            panic!("Expected Aggregate operator");
995        }
996    }
997
998    // === Ordering and Pagination Tests ===
999
1000    #[test]
1001    fn test_translate_order_by() {
1002        let query = "MATCH (n:Person) RETURN n ORDER BY n.name";
1003        let result = translate(query);
1004        assert!(result.is_ok());
1005
1006        let plan = result.unwrap();
1007        if let LogicalOperator::Return(ret) = &plan.root {
1008            if let LogicalOperator::Sort(sort) = ret.input.as_ref() {
1009                assert_eq!(sort.keys.len(), 1);
1010                assert_eq!(sort.keys[0].order, SortOrder::Ascending);
1011            } else {
1012                panic!("Expected Sort operator");
1013            }
1014        } else {
1015            panic!("Expected Return operator");
1016        }
1017    }
1018
1019    #[test]
1020    fn test_translate_limit() {
1021        let query = "MATCH (n:Person) RETURN n LIMIT 10";
1022        let result = translate(query);
1023        assert!(result.is_ok());
1024
1025        let plan = result.unwrap();
1026        // Find Limit
1027        fn find_limit(op: &LogicalOperator) -> Option<&LimitOp> {
1028            match op {
1029                LogicalOperator::Limit(l) => Some(l),
1030                LogicalOperator::Return(r) => find_limit(&r.input),
1031                LogicalOperator::Sort(s) => find_limit(&s.input),
1032                _ => None,
1033            }
1034        }
1035
1036        let limit = find_limit(&plan.root).expect("Expected Limit");
1037        assert_eq!(limit.count, 10);
1038    }
1039
1040    #[test]
1041    fn test_translate_skip() {
1042        let query = "MATCH (n:Person) RETURN n SKIP 5";
1043        let result = translate(query);
1044        assert!(result.is_ok());
1045
1046        let plan = result.unwrap();
1047        fn find_skip(op: &LogicalOperator) -> Option<&SkipOp> {
1048            match op {
1049                LogicalOperator::Skip(s) => Some(s),
1050                LogicalOperator::Return(r) => find_skip(&r.input),
1051                LogicalOperator::Limit(l) => find_skip(&l.input),
1052                _ => None,
1053            }
1054        }
1055
1056        let skip = find_skip(&plan.root).expect("Expected Skip");
1057        assert_eq!(skip.count, 5);
1058    }
1059
1060    // === Mutation Tests ===
1061
1062    #[test]
1063    fn test_translate_insert_node() {
1064        let query = "INSERT (n:Person {name: 'Alice', age: 30})";
1065        let result = translate(query);
1066        assert!(result.is_ok());
1067
1068        let plan = result.unwrap();
1069        // Find CreateNode
1070        fn find_create(op: &LogicalOperator) -> bool {
1071            match op {
1072                LogicalOperator::CreateNode(_) => true,
1073                LogicalOperator::Return(r) => find_create(&r.input),
1074                _ => false,
1075            }
1076        }
1077
1078        assert!(find_create(&plan.root));
1079    }
1080
1081    #[test]
1082    fn test_translate_delete() {
1083        let query = "DELETE n";
1084        let result = translate(query);
1085        assert!(result.is_ok());
1086
1087        let plan = result.unwrap();
1088        if let LogicalOperator::DeleteNode(del) = &plan.root {
1089            assert_eq!(del.variable, "n");
1090        } else {
1091            panic!("Expected DeleteNode operator");
1092        }
1093    }
1094
1095    #[test]
1096    fn test_translate_set() {
1097        // SET is not a standalone statement in GQL, test the translator method directly
1098        let translator = GqlTranslator::new();
1099        let set_stmt = ast::SetStatement {
1100            assignments: vec![ast::PropertyAssignment {
1101                variable: "n".to_string(),
1102                property: "name".to_string(),
1103                value: ast::Expression::Literal(ast::Literal::String("Bob".to_string())),
1104            }],
1105            span: None,
1106        };
1107
1108        let result = translator.translate_set(&set_stmt);
1109        assert!(result.is_ok());
1110
1111        let plan = result.unwrap();
1112        if let LogicalOperator::SetProperty(set) = &plan.root {
1113            assert_eq!(set.variable, "n");
1114            assert_eq!(set.properties.len(), 1);
1115            assert_eq!(set.properties[0].0, "name");
1116        } else {
1117            panic!("Expected SetProperty operator");
1118        }
1119    }
1120
1121    // === Expression Translation Tests ===
1122
1123    #[test]
1124    fn test_translate_literals() {
1125        let query = "MATCH (n) WHERE n.count = 42 AND n.active = true AND n.rate = 3.14 RETURN n";
1126        let result = translate(query);
1127        assert!(result.is_ok());
1128    }
1129
1130    #[test]
1131    fn test_translate_parameter() {
1132        let query = "MATCH (n:Person) WHERE n.name = $name RETURN n";
1133        let result = translate(query);
1134        assert!(result.is_ok());
1135
1136        let plan = result.unwrap();
1137        fn find_filter(op: &LogicalOperator) -> Option<&FilterOp> {
1138            match op {
1139                LogicalOperator::Filter(f) => Some(f),
1140                LogicalOperator::Return(r) => find_filter(&r.input),
1141                _ => None,
1142            }
1143        }
1144
1145        let filter = find_filter(&plan.root).expect("Expected Filter");
1146        if let LogicalExpression::Binary { right, .. } = &filter.predicate {
1147            if let LogicalExpression::Parameter(name) = right.as_ref() {
1148                assert_eq!(name, "name");
1149            } else {
1150                panic!("Expected Parameter");
1151            }
1152        }
1153    }
1154
1155    // === Error Handling Tests ===
1156
1157    #[test]
1158    fn test_translate_empty_delete_error() {
1159        // Create translator directly to test empty delete
1160        let translator = GqlTranslator::new();
1161        let delete = ast::DeleteStatement {
1162            variables: vec![],
1163            detach: false,
1164            span: None,
1165        };
1166        let result = translator.translate_delete(&delete);
1167        assert!(result.is_err());
1168    }
1169
1170    #[test]
1171    fn test_translate_empty_set_error() {
1172        let translator = GqlTranslator::new();
1173        let set = ast::SetStatement {
1174            assignments: vec![],
1175            span: None,
1176        };
1177        let result = translator.translate_set(&set);
1178        assert!(result.is_err());
1179    }
1180
1181    #[test]
1182    fn test_translate_empty_insert_error() {
1183        let translator = GqlTranslator::new();
1184        let insert = ast::InsertStatement {
1185            patterns: vec![],
1186            span: None,
1187        };
1188        let result = translator.translate_insert(&insert);
1189        assert!(result.is_err());
1190    }
1191
1192    // === Helper Function Tests ===
1193
1194    #[test]
1195    fn test_is_aggregate_function() {
1196        assert!(is_aggregate_function("COUNT"));
1197        assert!(is_aggregate_function("count"));
1198        assert!(is_aggregate_function("SUM"));
1199        assert!(is_aggregate_function("AVG"));
1200        assert!(is_aggregate_function("MIN"));
1201        assert!(is_aggregate_function("MAX"));
1202        assert!(is_aggregate_function("COLLECT"));
1203        assert!(!is_aggregate_function("UPPER"));
1204        assert!(!is_aggregate_function("RANDOM"));
1205    }
1206
1207    #[test]
1208    fn test_to_aggregate_function() {
1209        assert_eq!(
1210            to_aggregate_function("COUNT"),
1211            Some(AggregateFunction::Count)
1212        );
1213        assert_eq!(to_aggregate_function("sum"), Some(AggregateFunction::Sum));
1214        assert_eq!(to_aggregate_function("Avg"), Some(AggregateFunction::Avg));
1215        assert_eq!(to_aggregate_function("min"), Some(AggregateFunction::Min));
1216        assert_eq!(to_aggregate_function("MAX"), Some(AggregateFunction::Max));
1217        assert_eq!(
1218            to_aggregate_function("collect"),
1219            Some(AggregateFunction::Collect)
1220        );
1221        assert_eq!(to_aggregate_function("UNKNOWN"), None);
1222    }
1223
1224    #[test]
1225    fn test_contains_aggregate() {
1226        let count_expr = ast::Expression::FunctionCall {
1227            name: "COUNT".to_string(),
1228            args: vec![],
1229        };
1230        assert!(contains_aggregate(&count_expr));
1231
1232        let upper_expr = ast::Expression::FunctionCall {
1233            name: "UPPER".to_string(),
1234            args: vec![],
1235        };
1236        assert!(!contains_aggregate(&upper_expr));
1237
1238        let var_expr = ast::Expression::Variable("n".to_string());
1239        assert!(!contains_aggregate(&var_expr));
1240    }
1241
1242    #[test]
1243    fn test_binary_op_translation() {
1244        let translator = GqlTranslator::new();
1245
1246        assert_eq!(
1247            translator.translate_binary_op(ast::BinaryOp::Eq),
1248            BinaryOp::Eq
1249        );
1250        assert_eq!(
1251            translator.translate_binary_op(ast::BinaryOp::Ne),
1252            BinaryOp::Ne
1253        );
1254        assert_eq!(
1255            translator.translate_binary_op(ast::BinaryOp::Lt),
1256            BinaryOp::Lt
1257        );
1258        assert_eq!(
1259            translator.translate_binary_op(ast::BinaryOp::Le),
1260            BinaryOp::Le
1261        );
1262        assert_eq!(
1263            translator.translate_binary_op(ast::BinaryOp::Gt),
1264            BinaryOp::Gt
1265        );
1266        assert_eq!(
1267            translator.translate_binary_op(ast::BinaryOp::Ge),
1268            BinaryOp::Ge
1269        );
1270        assert_eq!(
1271            translator.translate_binary_op(ast::BinaryOp::And),
1272            BinaryOp::And
1273        );
1274        assert_eq!(
1275            translator.translate_binary_op(ast::BinaryOp::Or),
1276            BinaryOp::Or
1277        );
1278        assert_eq!(
1279            translator.translate_binary_op(ast::BinaryOp::Add),
1280            BinaryOp::Add
1281        );
1282        assert_eq!(
1283            translator.translate_binary_op(ast::BinaryOp::Sub),
1284            BinaryOp::Sub
1285        );
1286        assert_eq!(
1287            translator.translate_binary_op(ast::BinaryOp::Mul),
1288            BinaryOp::Mul
1289        );
1290        assert_eq!(
1291            translator.translate_binary_op(ast::BinaryOp::Div),
1292            BinaryOp::Div
1293        );
1294        assert_eq!(
1295            translator.translate_binary_op(ast::BinaryOp::Mod),
1296            BinaryOp::Mod
1297        );
1298        assert_eq!(
1299            translator.translate_binary_op(ast::BinaryOp::Like),
1300            BinaryOp::Like
1301        );
1302        assert_eq!(
1303            translator.translate_binary_op(ast::BinaryOp::In),
1304            BinaryOp::In
1305        );
1306    }
1307
1308    #[test]
1309    fn test_unary_op_translation() {
1310        let translator = GqlTranslator::new();
1311
1312        assert_eq!(
1313            translator.translate_unary_op(ast::UnaryOp::Not),
1314            UnaryOp::Not
1315        );
1316        assert_eq!(
1317            translator.translate_unary_op(ast::UnaryOp::Neg),
1318            UnaryOp::Neg
1319        );
1320        assert_eq!(
1321            translator.translate_unary_op(ast::UnaryOp::IsNull),
1322            UnaryOp::IsNull
1323        );
1324        assert_eq!(
1325            translator.translate_unary_op(ast::UnaryOp::IsNotNull),
1326            UnaryOp::IsNotNull
1327        );
1328    }
1329}