1use crate::query::plan::{
6 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DeleteNodeOp, DistinctOp,
7 ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp, LimitOp, LogicalExpression,
8 LogicalOperator, LogicalPlan, NodeScanOp, ProjectOp, Projection, ReturnItem, ReturnOp,
9 SetPropertyOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
10};
11use graphos_adapters::query::gql::{self, ast};
12use graphos_common::types::Value;
13use graphos_common::utils::error::{Error, Result};
14
15pub fn translate(query: &str) -> Result<LogicalPlan> {
21 let statement = gql::parse(query)?;
22 let translator = GqlTranslator::new();
23 translator.translate_statement(&statement)
24}
25
26struct GqlTranslator;
28
29impl GqlTranslator {
30 fn new() -> Self {
31 Self
32 }
33
34 fn translate_statement(&self, stmt: &ast::Statement) -> Result<LogicalPlan> {
35 match stmt {
36 ast::Statement::Query(query) => self.translate_query(query),
37 ast::Statement::DataModification(dm) => self.translate_data_modification(dm),
38 ast::Statement::Schema(_) => Err(Error::Internal(
39 "Schema statements not yet supported".to_string(),
40 )),
41 }
42 }
43
44 fn translate_query(&self, query: &ast::QueryStatement) -> Result<LogicalPlan> {
45 let mut plan = LogicalOperator::Empty;
47
48 for match_clause in &query.match_clauses {
49 let match_plan = self.translate_match(match_clause)?;
50 if matches!(plan, LogicalOperator::Empty) {
51 plan = match_plan;
52 } else if match_clause.optional {
53 plan = LogicalOperator::LeftJoin(LeftJoinOp {
55 left: Box::new(plan),
56 right: Box::new(match_plan),
57 condition: None,
58 });
59 } else {
60 plan = LogicalOperator::Join(JoinOp {
62 left: Box::new(plan),
63 right: Box::new(match_plan),
64 join_type: JoinType::Cross,
65 conditions: vec![],
66 });
67 }
68 }
69
70 if let Some(where_clause) = &query.where_clause {
72 let predicate = self.translate_expression(&where_clause.expression)?;
73 plan = LogicalOperator::Filter(FilterOp {
74 predicate,
75 input: Box::new(plan),
76 });
77 }
78
79 for with_clause in &query.with_clauses {
81 let projections: Vec<Projection> = with_clause
82 .items
83 .iter()
84 .map(|item| {
85 Ok(Projection {
86 expression: self.translate_expression(&item.expression)?,
87 alias: item.alias.clone(),
88 })
89 })
90 .collect::<Result<_>>()?;
91
92 plan = LogicalOperator::Project(ProjectOp {
93 projections,
94 input: Box::new(plan),
95 });
96
97 if let Some(where_clause) = &with_clause.where_clause {
99 let predicate = self.translate_expression(&where_clause.expression)?;
100 plan = LogicalOperator::Filter(FilterOp {
101 predicate,
102 input: Box::new(plan),
103 });
104 }
105
106 if with_clause.distinct {
108 plan = LogicalOperator::Distinct(DistinctOp {
109 input: Box::new(plan),
110 });
111 }
112 }
113
114 if let Some(skip_expr) = &query.return_clause.skip {
116 if let ast::Expression::Literal(ast::Literal::Integer(n)) = skip_expr {
117 plan = LogicalOperator::Skip(SkipOp {
118 count: *n as usize,
119 input: Box::new(plan),
120 });
121 }
122 }
123
124 if let Some(limit_expr) = &query.return_clause.limit {
126 if let ast::Expression::Literal(ast::Literal::Integer(n)) = limit_expr {
127 plan = LogicalOperator::Limit(LimitOp {
128 count: *n as usize,
129 input: Box::new(plan),
130 });
131 }
132 }
133
134 let has_aggregates = query
136 .return_clause
137 .items
138 .iter()
139 .any(|item| contains_aggregate(&item.expression));
140
141 if has_aggregates {
142 let (aggregates, group_by) =
144 self.extract_aggregates_and_groups(&query.return_clause.items)?;
145
146 plan = LogicalOperator::Aggregate(AggregateOp {
149 group_by,
150 aggregates,
151 input: Box::new(plan),
152 });
153
154 } else {
157 if let Some(order_by) = &query.return_clause.order_by {
159 let keys = order_by
160 .items
161 .iter()
162 .map(|item| {
163 Ok(SortKey {
164 expression: self.translate_expression(&item.expression)?,
165 order: match item.order {
166 ast::SortOrder::Asc => SortOrder::Ascending,
167 ast::SortOrder::Desc => SortOrder::Descending,
168 },
169 })
170 })
171 .collect::<Result<Vec<_>>>()?;
172
173 plan = LogicalOperator::Sort(SortOp {
174 keys,
175 input: Box::new(plan),
176 });
177 }
178
179 let return_items = query
181 .return_clause
182 .items
183 .iter()
184 .map(|item| {
185 Ok(ReturnItem {
186 expression: self.translate_expression(&item.expression)?,
187 alias: item.alias.clone(),
188 })
189 })
190 .collect::<Result<Vec<_>>>()?;
191
192 plan = LogicalOperator::Return(ReturnOp {
193 items: return_items,
194 distinct: query.return_clause.distinct,
195 input: Box::new(plan),
196 });
197 }
198
199 Ok(LogicalPlan::new(plan))
200 }
201
202 #[allow(dead_code)]
204 fn build_aggregate_return_items(&self, items: &[ast::ReturnItem]) -> Result<Vec<ReturnItem>> {
205 let mut return_items = Vec::new();
206 let mut agg_idx = 0;
207
208 for item in items {
209 if contains_aggregate(&item.expression) {
210 let alias = item.alias.clone().unwrap_or_else(|| {
212 if let ast::Expression::FunctionCall { name, .. } = &item.expression {
213 format!("{}(...)", name.to_lowercase())
214 } else {
215 format!("agg_{}", agg_idx)
216 }
217 });
218 return_items.push(ReturnItem {
219 expression: LogicalExpression::Variable(format!("__agg_{}", agg_idx)),
220 alias: Some(alias),
221 });
222 agg_idx += 1;
223 } else {
224 return_items.push(ReturnItem {
226 expression: self.translate_expression(&item.expression)?,
227 alias: item.alias.clone(),
228 });
229 }
230 }
231
232 Ok(return_items)
233 }
234
235 fn translate_match(&self, match_clause: &ast::MatchClause) -> Result<LogicalOperator> {
236 let mut plan: Option<LogicalOperator> = None;
237
238 for pattern in &match_clause.patterns {
239 let pattern_plan = self.translate_pattern(pattern, plan.take())?;
240 plan = Some(pattern_plan);
241 }
242
243 plan.ok_or_else(|| Error::Internal("Empty MATCH clause".to_string()))
244 }
245
246 fn translate_pattern(
247 &self,
248 pattern: &ast::Pattern,
249 input: Option<LogicalOperator>,
250 ) -> Result<LogicalOperator> {
251 match pattern {
252 ast::Pattern::Node(node) => self.translate_node_pattern(node, input),
253 ast::Pattern::Path(path) => self.translate_path_pattern(path, input),
254 }
255 }
256
257 fn translate_node_pattern(
258 &self,
259 node: &ast::NodePattern,
260 input: Option<LogicalOperator>,
261 ) -> Result<LogicalOperator> {
262 let variable = node
263 .variable
264 .clone()
265 .unwrap_or_else(|| format!("_anon_{}", rand_id()));
266
267 let label = node.labels.first().cloned();
268
269 Ok(LogicalOperator::NodeScan(NodeScanOp {
270 variable,
271 label,
272 input: input.map(Box::new),
273 }))
274 }
275
276 fn translate_path_pattern(
277 &self,
278 path: &ast::PathPattern,
279 input: Option<LogicalOperator>,
280 ) -> Result<LogicalOperator> {
281 let source_var = path
283 .source
284 .variable
285 .clone()
286 .unwrap_or_else(|| format!("_anon_{}", rand_id()));
287
288 let source_label = path.source.labels.first().cloned();
289
290 let mut plan = LogicalOperator::NodeScan(NodeScanOp {
291 variable: source_var.clone(),
292 label: source_label,
293 input: input.map(Box::new),
294 });
295
296 let mut current_source = source_var;
298
299 for edge in &path.edges {
300 let target_var = edge
301 .target
302 .variable
303 .clone()
304 .unwrap_or_else(|| format!("_anon_{}", rand_id()));
305
306 let edge_var = edge.variable.clone();
307 let edge_type = edge.types.first().cloned();
308
309 let direction = match edge.direction {
310 ast::EdgeDirection::Outgoing => ExpandDirection::Outgoing,
311 ast::EdgeDirection::Incoming => ExpandDirection::Incoming,
312 ast::EdgeDirection::Undirected => ExpandDirection::Both,
313 };
314
315 plan = LogicalOperator::Expand(ExpandOp {
316 from_variable: current_source,
317 to_variable: target_var.clone(),
318 edge_variable: edge_var,
319 direction,
320 edge_type,
321 min_hops: 1,
322 max_hops: Some(1),
323 input: Box::new(plan),
324 });
325
326 current_source = target_var;
327 }
328
329 Ok(plan)
330 }
331
332 fn translate_data_modification(
333 &self,
334 dm: &ast::DataModificationStatement,
335 ) -> Result<LogicalPlan> {
336 match dm {
337 ast::DataModificationStatement::Insert(insert) => self.translate_insert(insert),
338 ast::DataModificationStatement::Delete(delete) => self.translate_delete(delete),
339 ast::DataModificationStatement::Set(set) => self.translate_set(set),
340 }
341 }
342
343 fn translate_delete(&self, delete: &ast::DeleteStatement) -> Result<LogicalPlan> {
344 if delete.variables.is_empty() {
349 return Err(Error::Internal(
350 "DELETE requires at least one variable".to_string(),
351 ));
352 }
353
354 let first_var = &delete.variables[0];
357
358 let scan = LogicalOperator::NodeScan(NodeScanOp {
360 variable: first_var.clone(),
361 label: None,
362 input: None,
363 });
364
365 let mut plan = LogicalOperator::DeleteNode(DeleteNodeOp {
367 variable: first_var.clone(),
368 input: Box::new(scan),
369 });
370
371 for var in delete.variables.iter().skip(1) {
373 plan = LogicalOperator::DeleteNode(DeleteNodeOp {
374 variable: var.clone(),
375 input: Box::new(plan),
376 });
377 }
378
379 Ok(LogicalPlan::new(plan))
380 }
381
382 fn translate_set(&self, set: &ast::SetStatement) -> Result<LogicalPlan> {
383 if set.assignments.is_empty() {
387 return Err(Error::Internal(
388 "SET requires at least one assignment".to_string(),
389 ));
390 }
391
392 let first_assignment = &set.assignments[0];
394 let var = &first_assignment.variable;
395
396 let scan = LogicalOperator::NodeScan(NodeScanOp {
398 variable: var.clone(),
399 label: None,
400 input: None,
401 });
402
403 let properties: Vec<(String, LogicalExpression)> = set
405 .assignments
406 .iter()
407 .filter(|a| &a.variable == var)
408 .map(|a| Ok((a.property.clone(), self.translate_expression(&a.value)?)))
409 .collect::<Result<_>>()?;
410
411 let plan = LogicalOperator::SetProperty(SetPropertyOp {
412 variable: var.clone(),
413 properties,
414 replace: false,
415 input: Box::new(scan),
416 });
417
418 Ok(LogicalPlan::new(plan))
419 }
420
421 fn translate_insert(&self, insert: &ast::InsertStatement) -> Result<LogicalPlan> {
422 if insert.patterns.is_empty() {
426 return Err(Error::Internal("Empty INSERT statement".to_string()));
427 }
428
429 let pattern = &insert.patterns[0];
430
431 match pattern {
432 ast::Pattern::Node(node) => {
433 let variable = node
434 .variable
435 .clone()
436 .unwrap_or_else(|| format!("_anon_{}", rand_id()));
437
438 let properties = node
439 .properties
440 .iter()
441 .map(|(k, v)| Ok((k.clone(), self.translate_expression(v)?)))
442 .collect::<Result<Vec<_>>>()?;
443
444 let create = LogicalOperator::CreateNode(crate::query::plan::CreateNodeOp {
445 variable: variable.clone(),
446 labels: node.labels.clone(),
447 properties,
448 input: None,
449 });
450
451 let ret = LogicalOperator::Return(ReturnOp {
453 items: vec![ReturnItem {
454 expression: LogicalExpression::Variable(variable),
455 alias: None,
456 }],
457 distinct: false,
458 input: Box::new(create),
459 });
460
461 Ok(LogicalPlan::new(ret))
462 }
463 ast::Pattern::Path(_) => {
464 Err(Error::Internal("Path INSERT not yet supported".to_string()))
465 }
466 }
467 }
468
469 fn translate_expression(&self, expr: &ast::Expression) -> Result<LogicalExpression> {
470 match expr {
471 ast::Expression::Literal(lit) => Ok(self.translate_literal(lit)),
472 ast::Expression::Variable(name) => Ok(LogicalExpression::Variable(name.clone())),
473 ast::Expression::Parameter(name) => Ok(LogicalExpression::Parameter(name.clone())),
474 ast::Expression::PropertyAccess { variable, property } => {
475 Ok(LogicalExpression::Property {
476 variable: variable.clone(),
477 property: property.clone(),
478 })
479 }
480 ast::Expression::Binary { left, op, right } => {
481 let left = self.translate_expression(left)?;
482 let right = self.translate_expression(right)?;
483 let op = self.translate_binary_op(*op);
484 Ok(LogicalExpression::Binary {
485 left: Box::new(left),
486 op,
487 right: Box::new(right),
488 })
489 }
490 ast::Expression::Unary { op, operand } => {
491 let operand = self.translate_expression(operand)?;
492 let op = self.translate_unary_op(*op);
493 Ok(LogicalExpression::Unary {
494 op,
495 operand: Box::new(operand),
496 })
497 }
498 ast::Expression::FunctionCall { name, args } => {
499 let args = args
500 .iter()
501 .map(|a| self.translate_expression(a))
502 .collect::<Result<Vec<_>>>()?;
503 Ok(LogicalExpression::FunctionCall {
504 name: name.clone(),
505 args,
506 })
507 }
508 ast::Expression::List(items) => {
509 let items = items
510 .iter()
511 .map(|i| self.translate_expression(i))
512 .collect::<Result<Vec<_>>>()?;
513 Ok(LogicalExpression::List(items))
514 }
515 ast::Expression::Case {
516 input,
517 whens,
518 else_clause,
519 } => {
520 let operand = input
521 .as_ref()
522 .map(|e| self.translate_expression(e))
523 .transpose()?
524 .map(Box::new);
525
526 let when_clauses = whens
527 .iter()
528 .map(|(cond, result)| {
529 Ok((
530 self.translate_expression(cond)?,
531 self.translate_expression(result)?,
532 ))
533 })
534 .collect::<Result<Vec<_>>>()?;
535
536 let else_clause = else_clause
537 .as_ref()
538 .map(|e| self.translate_expression(e))
539 .transpose()?
540 .map(Box::new);
541
542 Ok(LogicalExpression::Case {
543 operand,
544 when_clauses,
545 else_clause,
546 })
547 }
548 ast::Expression::ExistsSubquery { query } => {
549 let inner_plan = self.translate_subquery_to_operator(query)?;
551 Ok(LogicalExpression::ExistsSubquery(Box::new(inner_plan)))
552 }
553 }
554 }
555
556 fn translate_literal(&self, lit: &ast::Literal) -> LogicalExpression {
557 let value = match lit {
558 ast::Literal::Null => Value::Null,
559 ast::Literal::Bool(b) => Value::Bool(*b),
560 ast::Literal::Integer(i) => Value::Int64(*i),
561 ast::Literal::Float(f) => Value::Float64(*f),
562 ast::Literal::String(s) => Value::String(s.clone().into()),
563 };
564 LogicalExpression::Literal(value)
565 }
566
567 fn translate_binary_op(&self, op: ast::BinaryOp) -> BinaryOp {
568 match op {
569 ast::BinaryOp::Eq => BinaryOp::Eq,
570 ast::BinaryOp::Ne => BinaryOp::Ne,
571 ast::BinaryOp::Lt => BinaryOp::Lt,
572 ast::BinaryOp::Le => BinaryOp::Le,
573 ast::BinaryOp::Gt => BinaryOp::Gt,
574 ast::BinaryOp::Ge => BinaryOp::Ge,
575 ast::BinaryOp::And => BinaryOp::And,
576 ast::BinaryOp::Or => BinaryOp::Or,
577 ast::BinaryOp::Add => BinaryOp::Add,
578 ast::BinaryOp::Sub => BinaryOp::Sub,
579 ast::BinaryOp::Mul => BinaryOp::Mul,
580 ast::BinaryOp::Div => BinaryOp::Div,
581 ast::BinaryOp::Mod => BinaryOp::Mod,
582 ast::BinaryOp::Concat => BinaryOp::Concat,
583 ast::BinaryOp::Like => BinaryOp::Like,
584 ast::BinaryOp::In => BinaryOp::In,
585 }
586 }
587
588 fn translate_unary_op(&self, op: ast::UnaryOp) -> UnaryOp {
589 match op {
590 ast::UnaryOp::Not => UnaryOp::Not,
591 ast::UnaryOp::Neg => UnaryOp::Neg,
592 ast::UnaryOp::IsNull => UnaryOp::IsNull,
593 ast::UnaryOp::IsNotNull => UnaryOp::IsNotNull,
594 }
595 }
596
597 fn translate_subquery_to_operator(
599 &self,
600 query: &ast::QueryStatement,
601 ) -> Result<LogicalOperator> {
602 let mut plan = LogicalOperator::Empty;
603
604 for match_clause in &query.match_clauses {
605 let match_plan = self.translate_match(match_clause)?;
606 plan = if matches!(plan, LogicalOperator::Empty) {
607 match_plan
608 } else {
609 LogicalOperator::Join(JoinOp {
610 left: Box::new(plan),
611 right: Box::new(match_plan),
612 join_type: JoinType::Cross,
613 conditions: vec![],
614 })
615 };
616 }
617
618 if let Some(where_clause) = &query.where_clause {
619 let predicate = self.translate_expression(&where_clause.expression)?;
620 plan = LogicalOperator::Filter(FilterOp {
621 predicate,
622 input: Box::new(plan),
623 });
624 }
625
626 Ok(plan)
627 }
628
629 fn extract_aggregates_and_groups(
631 &self,
632 items: &[ast::ReturnItem],
633 ) -> Result<(Vec<AggregateExpr>, Vec<LogicalExpression>)> {
634 let mut aggregates = Vec::new();
635 let mut group_by = Vec::new();
636
637 for item in items {
638 if let Some(agg_expr) = self.try_extract_aggregate(&item.expression, &item.alias)? {
639 aggregates.push(agg_expr);
640 } else {
641 let expr = self.translate_expression(&item.expression)?;
643 group_by.push(expr);
644 }
645 }
646
647 Ok((aggregates, group_by))
648 }
649
650 fn try_extract_aggregate(
652 &self,
653 expr: &ast::Expression,
654 alias: &Option<String>,
655 ) -> Result<Option<AggregateExpr>> {
656 match expr {
657 ast::Expression::FunctionCall { name, args } => {
658 if let Some(func) = to_aggregate_function(name) {
659 let agg_expr = if args.is_empty() {
660 AggregateExpr {
662 function: func,
663 expression: None,
664 distinct: false,
665 alias: alias.clone(),
666 }
667 } else {
668 AggregateExpr {
670 function: func,
671 expression: Some(self.translate_expression(&args[0])?),
672 distinct: false,
673 alias: alias.clone(),
674 }
675 };
676 Ok(Some(agg_expr))
677 } else {
678 Ok(None)
679 }
680 }
681 _ => Ok(None),
682 }
683 }
684}
685
/// Returns the next value of a process-wide counter, used to name anonymous
/// pattern variables (`_anon_<id>`).
///
/// Despite the name, the value is sequential, not random. The counter is a
/// `u32` incremented with `fetch_add`, which wraps on overflow.
fn rand_id() -> u32 {
    use std::sync::atomic::{AtomicU32, Ordering};

    static NEXT_ANON_ID: AtomicU32 = AtomicU32::new(0);
    let id = NEXT_ANON_ID.fetch_add(1, Ordering::Relaxed);
    id
}
692
693fn is_aggregate_function(name: &str) -> bool {
695 matches!(
696 name.to_uppercase().as_str(),
697 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COLLECT"
698 )
699}
700
701fn to_aggregate_function(name: &str) -> Option<AggregateFunction> {
703 match name.to_uppercase().as_str() {
704 "COUNT" => Some(AggregateFunction::Count),
705 "SUM" => Some(AggregateFunction::Sum),
706 "AVG" => Some(AggregateFunction::Avg),
707 "MIN" => Some(AggregateFunction::Min),
708 "MAX" => Some(AggregateFunction::Max),
709 "COLLECT" => Some(AggregateFunction::Collect),
710 _ => None,
711 }
712}
713
714fn contains_aggregate(expr: &ast::Expression) -> bool {
716 match expr {
717 ast::Expression::FunctionCall { name, .. } => is_aggregate_function(name),
718 ast::Expression::Binary { left, right, .. } => {
719 contains_aggregate(left) || contains_aggregate(right)
720 }
721 ast::Expression::Unary { operand, .. } => contains_aggregate(operand),
722 _ => false,
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;

    /// Translates `query`, failing the test with the translator's error if it
    /// is rejected (a bare `is_ok()` assertion would hide the reason).
    fn plan_of(query: &str) -> LogicalPlan {
        match translate(query) {
            Ok(plan) => plan,
            Err(err) => panic!("failed to translate {:?}: {:?}", query, err),
        }
    }

    /// Returns `root` followed by every operator reached through the `input` of
    /// Return / Filter / Sort / Skip / Limit nodes. The first other operator
    /// ends the chain and is included.
    fn spine(root: &LogicalOperator) -> Vec<&LogicalOperator> {
        let mut ops = vec![root];
        let mut current = root;
        loop {
            let next: &LogicalOperator = match current {
                LogicalOperator::Return(r) => r.input.as_ref(),
                LogicalOperator::Filter(f) => f.input.as_ref(),
                LogicalOperator::Sort(s) => s.input.as_ref(),
                LogicalOperator::Skip(s) => s.input.as_ref(),
                LogicalOperator::Limit(l) => l.input.as_ref(),
                _ => break,
            };
            ops.push(next);
            current = next;
        }
        ops
    }

    /// The plan root, which must be a `Return`.
    fn expect_return(plan: &LogicalPlan) -> &ReturnOp {
        match &plan.root {
            LogicalOperator::Return(ret) => ret,
            other => panic!("Expected Return operator, got {:?}", other),
        }
    }

    /// First `Filter` on the plan's spine.
    fn find_filter(plan: &LogicalPlan) -> &FilterOp {
        spine(&plan.root)
            .into_iter()
            .find_map(|op| match op {
                LogicalOperator::Filter(f) => Some(f),
                _ => None,
            })
            .expect("Expected Filter operator")
    }

    /// First `Expand` on the plan's spine.
    fn find_expand(plan: &LogicalPlan) -> &ExpandOp {
        spine(&plan.root)
            .into_iter()
            .find_map(|op| match op {
                LogicalOperator::Expand(e) => Some(e),
                _ => None,
            })
            .expect("Expected Expand operator")
    }

    /// First `Limit` on the plan's spine.
    fn find_limit(plan: &LogicalPlan) -> &LimitOp {
        spine(&plan.root)
            .into_iter()
            .find_map(|op| match op {
                LogicalOperator::Limit(l) => Some(l),
                _ => None,
            })
            .expect("Expected Limit operator")
    }

    /// First `Skip` on the plan's spine.
    fn find_skip(plan: &LogicalPlan) -> &SkipOp {
        spine(&plan.root)
            .into_iter()
            .find_map(|op| match op {
                LogicalOperator::Skip(s) => Some(s),
                _ => None,
            })
            .expect("Expected Skip operator")
    }

    /// Asserts `expr` is a binary expression with operator `expected`.
    fn assert_binary_op(expr: &LogicalExpression, expected: BinaryOp) {
        match expr {
            LogicalExpression::Binary { op, .. } => assert_eq!(*op, expected),
            other => panic!("Expected binary expression, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_simple_match() {
        let plan = plan_of("MATCH (n:Person) RETURN n");
        let ret = expect_return(&plan);
        assert_eq!(ret.items.len(), 1);
        assert!(!ret.distinct);
    }

    #[test]
    fn test_translate_match_with_where() {
        let plan = plan_of("MATCH (n:Person) WHERE n.age > 30 RETURN n.name");
        match expect_return(&plan).input.as_ref() {
            LogicalOperator::Filter(filter) => assert_binary_op(&filter.predicate, BinaryOp::Gt),
            other => panic!("Expected Filter operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_match_without_label() {
        let plan = plan_of("MATCH (n) RETURN n");
        match expect_return(&plan).input.as_ref() {
            LogicalOperator::NodeScan(scan) => assert!(scan.label.is_none()),
            other => panic!("Expected NodeScan operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_match_distinct() {
        let plan = plan_of("MATCH (n:Person) RETURN DISTINCT n.name");
        assert!(expect_return(&plan).distinct);
    }

    #[test]
    fn test_translate_filter_equality() {
        let plan = plan_of("MATCH (n:Person) WHERE n.name = 'Alice' RETURN n");
        assert_binary_op(&find_filter(&plan).predicate, BinaryOp::Eq);
    }

    #[test]
    fn test_translate_filter_and() {
        let plan = plan_of("MATCH (n:Person) WHERE n.age > 20 AND n.age < 40 RETURN n");
        assert_binary_op(&find_filter(&plan).predicate, BinaryOp::And);
    }

    #[test]
    fn test_translate_filter_or() {
        let plan = plan_of("MATCH (n:Person) WHERE n.name = 'Alice' OR n.name = 'Bob' RETURN n");
        assert_binary_op(&find_filter(&plan).predicate, BinaryOp::Or);
    }

    #[test]
    fn test_translate_filter_not() {
        let plan = plan_of("MATCH (n:Person) WHERE NOT n.active RETURN n");
        match &find_filter(&plan).predicate {
            LogicalExpression::Unary { op, .. } => assert_eq!(*op, UnaryOp::Not),
            other => panic!("Expected unary expression, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_path_pattern() {
        let plan = plan_of("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a, b");
        let expand = find_expand(&plan);
        assert_eq!(expand.direction, ExpandDirection::Outgoing);
        assert_eq!(expand.edge_type.as_deref(), Some("KNOWS"));
    }

    #[test]
    fn test_translate_incoming_path() {
        let plan = plan_of("MATCH (a:Person)<-[:KNOWS]-(b:Person) RETURN a, b");
        assert_eq!(find_expand(&plan).direction, ExpandDirection::Incoming);
    }

    #[test]
    fn test_translate_undirected_path() {
        let plan = plan_of("MATCH (a:Person)-[:KNOWS]-(b:Person) RETURN a, b");
        assert_eq!(find_expand(&plan).direction, ExpandDirection::Both);
    }

    #[test]
    fn test_translate_count_aggregate() {
        let plan = plan_of("MATCH (n:Person) RETURN COUNT(n)");
        match &plan.root {
            LogicalOperator::Aggregate(agg) => {
                assert_eq!(agg.aggregates.len(), 1);
                assert_eq!(agg.aggregates[0].function, AggregateFunction::Count);
            }
            other => panic!("Expected Aggregate operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_sum_aggregate() {
        let plan = plan_of("MATCH (n:Person) RETURN SUM(n.age)");
        match &plan.root {
            LogicalOperator::Aggregate(agg) => {
                assert_eq!(agg.aggregates.len(), 1);
                assert_eq!(agg.aggregates[0].function, AggregateFunction::Sum);
            }
            other => panic!("Expected Aggregate operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_group_by_aggregate() {
        let plan = plan_of("MATCH (n:Person) RETURN n.city, COUNT(n)");
        match &plan.root {
            LogicalOperator::Aggregate(agg) => {
                assert_eq!(agg.group_by.len(), 1);
                assert_eq!(agg.aggregates.len(), 1);
            }
            other => panic!("Expected Aggregate operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_order_by() {
        let plan = plan_of("MATCH (n:Person) RETURN n ORDER BY n.name");
        match expect_return(&plan).input.as_ref() {
            LogicalOperator::Sort(sort) => {
                assert_eq!(sort.keys.len(), 1);
                assert_eq!(sort.keys[0].order, SortOrder::Ascending);
            }
            other => panic!("Expected Sort operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_limit() {
        let plan = plan_of("MATCH (n:Person) RETURN n LIMIT 10");
        assert_eq!(find_limit(&plan).count, 10);
    }

    #[test]
    fn test_translate_skip() {
        let plan = plan_of("MATCH (n:Person) RETURN n SKIP 5");
        assert_eq!(find_skip(&plan).count, 5);
    }

    #[test]
    fn test_translate_insert_node() {
        let plan = plan_of("INSERT (n:Person {name: 'Alice', age: 30})");
        assert!(spine(&plan.root)
            .into_iter()
            .any(|op| matches!(op, LogicalOperator::CreateNode(_))));
    }

    #[test]
    fn test_translate_delete() {
        let plan = plan_of("DELETE n");
        match &plan.root {
            LogicalOperator::DeleteNode(del) => assert_eq!(del.variable, "n"),
            other => panic!("Expected DeleteNode operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_set() {
        let translator = GqlTranslator::new();
        let set_stmt = ast::SetStatement {
            assignments: vec![ast::PropertyAssignment {
                variable: "n".to_string(),
                property: "name".to_string(),
                value: ast::Expression::Literal(ast::Literal::String("Bob".to_string())),
            }],
            span: None,
        };

        let plan = translator
            .translate_set(&set_stmt)
            .expect("single-variable SET should translate");
        match &plan.root {
            LogicalOperator::SetProperty(set) => {
                assert_eq!(set.variable, "n");
                assert_eq!(set.properties.len(), 1);
                assert_eq!(set.properties[0].0, "name");
            }
            other => panic!("Expected SetProperty operator, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_literals() {
        plan_of("MATCH (n) WHERE n.count = 42 AND n.active = true AND n.rate = 3.14 RETURN n");
    }

    #[test]
    fn test_translate_parameter() {
        let plan = plan_of("MATCH (n:Person) WHERE n.name = $name RETURN n");
        match &find_filter(&plan).predicate {
            LogicalExpression::Binary { right, .. } => match right.as_ref() {
                LogicalExpression::Parameter(name) => assert_eq!(name, "name"),
                other => panic!("Expected Parameter, got {:?}", other),
            },
            other => panic!("Expected binary expression, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_empty_delete_error() {
        let translator = GqlTranslator::new();
        let delete = ast::DeleteStatement {
            variables: vec![],
            detach: false,
            span: None,
        };
        assert!(translator.translate_delete(&delete).is_err());
    }

    #[test]
    fn test_translate_empty_set_error() {
        let translator = GqlTranslator::new();
        let set = ast::SetStatement {
            assignments: vec![],
            span: None,
        };
        assert!(translator.translate_set(&set).is_err());
    }

    #[test]
    fn test_translate_empty_insert_error() {
        let translator = GqlTranslator::new();
        let insert = ast::InsertStatement {
            patterns: vec![],
            span: None,
        };
        assert!(translator.translate_insert(&insert).is_err());
    }

    #[test]
    fn test_is_aggregate_function() {
        for name in ["COUNT", "count", "SUM", "AVG", "MIN", "MAX", "COLLECT"] {
            assert!(is_aggregate_function(name), "{} should be an aggregate", name);
        }
        assert!(!is_aggregate_function("UPPER"));
        assert!(!is_aggregate_function("RANDOM"));
    }

    #[test]
    fn test_to_aggregate_function() {
        assert_eq!(to_aggregate_function("COUNT"), Some(AggregateFunction::Count));
        assert_eq!(to_aggregate_function("sum"), Some(AggregateFunction::Sum));
        assert_eq!(to_aggregate_function("Avg"), Some(AggregateFunction::Avg));
        assert_eq!(to_aggregate_function("min"), Some(AggregateFunction::Min));
        assert_eq!(to_aggregate_function("MAX"), Some(AggregateFunction::Max));
        assert_eq!(to_aggregate_function("collect"), Some(AggregateFunction::Collect));
        assert_eq!(to_aggregate_function("UNKNOWN"), None);
    }

    #[test]
    fn test_contains_aggregate() {
        let count_expr = ast::Expression::FunctionCall {
            name: "COUNT".to_string(),
            args: vec![],
        };
        assert!(contains_aggregate(&count_expr));

        let upper_expr = ast::Expression::FunctionCall {
            name: "UPPER".to_string(),
            args: vec![],
        };
        assert!(!contains_aggregate(&upper_expr));

        let var_expr = ast::Expression::Variable("n".to_string());
        assert!(!contains_aggregate(&var_expr));
    }

    #[test]
    fn test_binary_op_translation() {
        let translator = GqlTranslator::new();
        let cases = [
            (ast::BinaryOp::Eq, BinaryOp::Eq),
            (ast::BinaryOp::Ne, BinaryOp::Ne),
            (ast::BinaryOp::Lt, BinaryOp::Lt),
            (ast::BinaryOp::Le, BinaryOp::Le),
            (ast::BinaryOp::Gt, BinaryOp::Gt),
            (ast::BinaryOp::Ge, BinaryOp::Ge),
            (ast::BinaryOp::And, BinaryOp::And),
            (ast::BinaryOp::Or, BinaryOp::Or),
            (ast::BinaryOp::Add, BinaryOp::Add),
            (ast::BinaryOp::Sub, BinaryOp::Sub),
            (ast::BinaryOp::Mul, BinaryOp::Mul),
            (ast::BinaryOp::Div, BinaryOp::Div),
            (ast::BinaryOp::Mod, BinaryOp::Mod),
            (ast::BinaryOp::Concat, BinaryOp::Concat),
            (ast::BinaryOp::Like, BinaryOp::Like),
            (ast::BinaryOp::In, BinaryOp::In),
        ];
        for (input, expected) in cases {
            assert_eq!(translator.translate_binary_op(input), expected);
        }
    }

    #[test]
    fn test_unary_op_translation() {
        let translator = GqlTranslator::new();
        let cases = [
            (ast::UnaryOp::Not, UnaryOp::Not),
            (ast::UnaryOp::Neg, UnaryOp::Neg),
            (ast::UnaryOp::IsNull, UnaryOp::IsNull),
            (ast::UnaryOp::IsNotNull, UnaryOp::IsNotNull),
        ];
        for (input, expected) in cases {
            assert_eq!(translator.translate_unary_op(input), expected);
        }
    }
}