use crate::ast::CypherQuery as CypherAST;
use crate::config::GraphConfig;
use crate::error::{GraphError, Result};
use crate::logical_plan::LogicalPlanner;
use crate::parser::parse_cypher_query;
use crate::simple_executor::{
    to_df_boolean_expr_simple, to_df_order_by_expr_simple, to_df_value_expr_simple, PathExecutor,
};
use std::collections::HashMap;

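/// Strategy used by [`CypherQuery::execute`] to run a query.
///
/// `DataFusion` (the default) plans and executes through DataFusion,
/// `Simple` drives the lighter-weight `simple_executor` pipeline, and
/// `LanceNative` is a placeholder that currently returns
/// `GraphError::UnsupportedFeature`.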
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionStrategy {
    #[default]
    DataFusion,
    Simple,
    LanceNative,
}

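/// A parsed Cypher query plus the [`GraphConfig`] and parameters needed to
/// run it against Arrow `RecordBatch` data.
///
/// A minimal sketch of the intended flow (mirroring the tests at the bottom
/// of this file; `batch` stands in for any `RecordBatch` you have on hand):
///
/// ```ignore
/// let cfg = GraphConfig::builder()
///     .with_node_label("Person", "id")
///     .build()?;
/// let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")?
///     .with_config(cfg);
///
/// let mut datasets = HashMap::new();
/// datasets.insert("Person".to_string(), batch);
///
/// // `None` selects the default strategy (DataFusion).
/// let result = query.execute(datasets, None).await?;
/// ```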
#[derive(Debug, Clone)]
pub struct CypherQuery {
    query_text: String,
    ast: CypherAST,
    config: Option<GraphConfig>,
    parameters: HashMap<String, serde_json::Value>,
}

impl CypherQuery {
    pub fn new(query: &str) -> Result<Self> {
        let ast = parse_cypher_query(query)?;

        Ok(Self {
            query_text: query.to_string(),
            ast,
            config: None,
            parameters: HashMap::new(),
        })
    }

    pub fn with_config(mut self, config: GraphConfig) -> Self {
        self.config = Some(config);
        self
    }

    pub fn with_parameter<K, V>(mut self, key: K, value: V) -> Self
    where
        K: Into<String>,
        V: Into<serde_json::Value>,
    {
        self.parameters.insert(key.into(), value.into());
        self
    }

    pub fn with_parameters(mut self, params: HashMap<String, serde_json::Value>) -> Self {
        self.parameters.extend(params);
        self
    }

    pub fn query_text(&self) -> &str {
        &self.query_text
    }

    pub fn ast(&self) -> &CypherAST {
        &self.ast
    }

    pub fn config(&self) -> Option<&GraphConfig> {
        self.config.as_ref()
    }

    pub fn parameters(&self) -> &HashMap<String, serde_json::Value> {
        &self.parameters
    }

    fn require_config(&self) -> Result<&GraphConfig> {
        self.config.as_ref().ok_or_else(|| GraphError::ConfigError {
            message: "Graph configuration is required for query execution".to_string(),
            location: snafu::Location::new(file!(), line!(), column!()),
        })
    }

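    /// Executes the query against the given datasets, dispatching on the
    /// chosen [`ExecutionStrategy`] (defaulting to `DataFusion` when `None`).
    ///
    /// A hedged sketch of strategy selection (dataset contents elided):
    ///
    /// ```ignore
    /// // Default (DataFusion) path:
    /// let batch = query.execute(datasets.clone(), None).await?;
    /// // Force the simple executor instead:
    /// let batch = query.execute(datasets, Some(ExecutionStrategy::Simple)).await?;
    /// ```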
    pub async fn execute(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
        strategy: Option<ExecutionStrategy>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        let strategy = strategy.unwrap_or_default();
        match strategy {
            ExecutionStrategy::DataFusion => self.execute_datafusion(datasets).await,
            ExecutionStrategy::Simple => self.execute_simple(datasets).await,
            ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
                feature: "Lance native execution strategy is not yet implemented".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            }),
        }
    }

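    /// Produces an EXPLAIN-style report containing the graph logical plan,
    /// the DataFusion logical plan, and the physical plan, rendered as an
    /// ASCII table (see `format_explain_output`).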
    pub async fn explain(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<String> {
        use std::sync::Arc;

        let (catalog, ctx) = self
            .build_catalog_and_context_from_datasets(datasets)
            .await?;

        self.explain_internal(Arc::new(catalog), ctx).await
    }

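    /// Unparses the optimized DataFusion logical plan back into a SQL string
    /// via `datafusion_sql::unparser::plan_to_sql`.
    ///
    /// A rough sketch of the round trip (the output shape is illustrative,
    /// not guaranteed verbatim):
    ///
    /// ```ignore
    /// let sql = query.to_sql(datasets).await?;
    /// // e.g. something like: SELECT "p"."name" FROM "Person" AS "p"
    /// ```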
    pub async fn to_sql(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<String> {
        use datafusion_sql::unparser::plan_to_sql;
        use std::sync::Arc;

        let _config = self.require_config()?;

        let (catalog, ctx) = self
            .build_catalog_and_context_from_datasets(datasets)
            .await?;

        let (_, df_plan) = self.create_logical_plans(Arc::new(catalog))?;

        let optimized_plan = ctx
            .state()
            .optimize(&df_plan)
            .map_err(|e| GraphError::PlanError {
                message: format!("Failed to optimize plan: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
            message: format!("Failed to unparse plan to SQL: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

        Ok(sql_ast.to_string())
    }

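    /// Executes the query against tables already registered in an existing
    /// DataFusion `SessionContext`. Every node label and relationship type in
    /// the [`GraphConfig`] must resolve to a registered table of the same
    /// name.
    ///
    /// A minimal sketch (schema and data elided; see the tests below for a
    /// complete version):
    ///
    /// ```ignore
    /// let ctx = SessionContext::new();
    /// ctx.register_table("Person", mem_table)?;
    /// let result = query.execute_with_context(ctx).await?;
    /// ```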
    pub async fn execute_with_context(
        &self,
        ctx: datafusion::execution::context::SessionContext,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use crate::source_catalog::InMemoryCatalog;
        use datafusion::datasource::DefaultTableSource;
        use std::sync::Arc;

        let config = self.require_config()?;

        let mut catalog = InMemoryCatalog::new();

        for label in config.node_mappings.keys() {
            let table_provider =
                ctx.table_provider(label)
                    .await
                    .map_err(|e| GraphError::ConfigError {
                        message: format!(
                            "Node label '{}' not found in SessionContext: {}",
                            label, e
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    })?;

            let table_source = Arc::new(DefaultTableSource::new(table_provider));
            catalog = catalog.with_node_source(label, table_source);
        }

        for rel_type in config.relationship_mappings.keys() {
            let table_provider =
                ctx.table_provider(rel_type)
                    .await
                    .map_err(|e| GraphError::ConfigError {
                        message: format!(
                            "Relationship type '{}' not found in SessionContext: {}",
                            rel_type, e
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    })?;

            let table_source = Arc::new(DefaultTableSource::new(table_provider));
            catalog = catalog.with_relationship_source(rel_type, table_source);
        }

        self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
            .await
    }

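    /// Plans the query against the supplied `GraphSourceCatalog`, runs the
    /// resulting DataFusion logical plan on `ctx`, and concatenates all
    /// collected batches into a single `RecordBatch`; an empty result keeps
    /// the plan's schema instead of failing.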
    pub async fn execute_with_catalog_and_context(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
        ctx: datafusion::execution::context::SessionContext,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use arrow::compute::concat_batches;

        let (_logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;

        let df = ctx
            .execute_logical_plan(df_logical_plan)
            .await
            .map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to execute DataFusion plan: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        let result_schema = df.schema().inner().clone();

        let batches = df.collect().await.map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to collect query results: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

        if batches.is_empty() {
            return Ok(arrow::record_batch::RecordBatch::new_empty(result_schema));
        }

        let schema = batches[0].schema();
        concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to concatenate result batches: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })
    }

    async fn execute_datafusion(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use std::sync::Arc;

        let (catalog, ctx) = self
            .build_catalog_and_context_from_datasets(datasets)
            .await?;

        self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
            .await
    }

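    /// Builds a fresh `SessionContext` plus `InMemoryCatalog` from in-memory
    /// batches. Each dataset is wrapped in a `MemTable` and registered as
    /// both a node source and a relationship source, since which role a
    /// table plays is only known once the query pattern is planned.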
    async fn build_catalog_and_context_from_datasets(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<(
        crate::source_catalog::InMemoryCatalog,
        datafusion::execution::context::SessionContext,
    )> {
        use crate::source_catalog::InMemoryCatalog;
        use datafusion::datasource::{DefaultTableSource, MemTable};
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        if datasets.is_empty() {
            return Err(GraphError::ConfigError {
                message: "No input datasets provided".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ctx = SessionContext::new();
        let mut catalog = InMemoryCatalog::new();

        for (name, batch) in &datasets {
            let mem_table = Arc::new(
                MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]).map_err(|e| {
                    GraphError::PlanError {
                        message: format!("Failed to create MemTable for {}: {}", name, e),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    }
                })?,
            );

            ctx.register_table(name, mem_table.clone())
                .map_err(|e| GraphError::PlanError {
                    message: format!("Failed to register table {}: {}", name, e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

            let table_source = Arc::new(DefaultTableSource::new(mem_table));

            catalog = catalog
                .with_node_source(name, table_source.clone())
                .with_relationship_source(name, table_source);
        }

        Ok((catalog, ctx))
    }

    async fn explain_internal(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
        ctx: datafusion::execution::context::SessionContext,
    ) -> Result<String> {
        let (logical_plan, df_logical_plan, physical_plan) =
            self.create_plans(catalog, &ctx).await?;

        self.format_explain_output(&logical_plan, &df_logical_plan, physical_plan.as_ref())
    }

    fn create_logical_plans(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
    ) -> Result<(
        crate::logical_plan::LogicalOperator,
        datafusion::logical_expr::LogicalPlan,
    )> {
        use crate::datafusion_planner::{DataFusionPlanner, GraphPhysicalPlanner};
        use crate::semantic::SemanticAnalyzer;

        let config = self.require_config()?;

        let mut analyzer = SemanticAnalyzer::new(config.clone());
        analyzer.analyze(&self.ast)?;

        let mut logical_planner = LogicalPlanner::new();
        let logical_plan = logical_planner.plan(&self.ast)?;

        let df_planner = DataFusionPlanner::with_catalog(config.clone(), catalog);
        let df_logical_plan = df_planner.plan(&logical_plan)?;

        Ok((logical_plan, df_logical_plan))
    }

    async fn create_plans(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
        ctx: &datafusion::execution::context::SessionContext,
    ) -> Result<(
        crate::logical_plan::LogicalOperator,
        datafusion::logical_expr::LogicalPlan,
        std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
    )> {
        let (logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;

        let df = ctx
            .execute_logical_plan(df_logical_plan.clone())
            .await
            .map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to execute DataFusion plan: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        let physical_plan =
            df.create_physical_plan()
                .await
                .map_err(|e| GraphError::ExecutionError {
                    message: format!("Failed to create physical plan: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

        Ok((logical_plan, df_logical_plan, physical_plan))
    }

    fn format_explain_output(
        &self,
        logical_plan: &crate::logical_plan::LogicalOperator,
        df_logical_plan: &datafusion::logical_expr::LogicalPlan,
        physical_plan: &dyn datafusion::physical_plan::ExecutionPlan,
    ) -> Result<String> {
        let mut output = String::new();

        output.push_str("Cypher Query:\n");
        output.push_str(&format!(" {}\n\n", self.query_text));

        let mut rows = vec![];

        let graph_plan_str = format!("{:#?}", logical_plan);
        rows.push(("graph_logical_plan", graph_plan_str));

        let df_logical_str = format!("{}", df_logical_plan.display_indent());
        rows.push(("logical_plan", df_logical_str));

        let df_physical_str = format!(
            "{}",
            datafusion::physical_plan::displayable(physical_plan).indent(true)
        );
        rows.push(("physical_plan", df_physical_str));

        let plan_type_width = rows.iter().map(|(t, _)| t.len()).max().unwrap_or(10);
        let plan_width = rows
            .iter()
            .map(|(_, p)| p.lines().map(|l| l.len()).max().unwrap_or(0))
            .max()
            .unwrap_or(50);

        let separator = format!(
            "+{}+{}+",
            "-".repeat(plan_type_width + 2),
            "-".repeat(plan_width + 2)
        );

        output.push_str(&separator);
        output.push('\n');

        output.push_str(&format!(
            "| {:<width$} | {:<plan_width$} |\n",
            "plan_type",
            "plan",
            width = plan_type_width,
            plan_width = plan_width
        ));
        output.push_str(&separator);
        output.push('\n');

        for (plan_type, plan_content) in rows {
            let lines: Vec<&str> = plan_content.lines().collect();
            if lines.is_empty() {
                output.push_str(&format!(
                    "| {:<width$} | {:<plan_width$} |\n",
                    plan_type,
                    "",
                    width = plan_type_width,
                    plan_width = plan_width
                ));
            } else {
                output.push_str(&format!(
                    "| {:<width$} | {:<plan_width$} |\n",
                    plan_type,
                    lines[0],
                    width = plan_type_width,
                    plan_width = plan_width
                ));

                for line in &lines[1..] {
                    output.push_str(&format!(
                        "| {:<width$} | {:<plan_width$} |\n",
                        "",
                        line,
                        width = plan_type_width,
                        plan_width = plan_width
                    ));
                }
            }
        }

        output.push_str(&separator);
        output.push('\n');

        Ok(output)
    }

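    /// Executes the query through the simple executor pipeline rather than
    /// the full DataFusion planner: path patterns go through
    /// `try_execute_path_generic` / `PathExecutor`, and single-table queries
    /// fall back to a direct filter/project/distinct/sort/limit DataFrame
    /// chain built with the `to_df_*_simple` helpers. Also reachable via
    /// `ExecutionStrategy::Simple`.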
    pub async fn execute_simple(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use arrow::compute::concat_batches;
        use datafusion::datasource::MemTable;
        use datafusion::prelude::*;
        use std::sync::Arc;

        let _config = self.require_config()?;

        if datasets.is_empty() {
            return Err(GraphError::PlanError {
                message: "No input datasets provided".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ctx = SessionContext::new();
        for (name, batch) in &datasets {
            let table =
                MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]).map_err(|e| {
                    GraphError::PlanError {
                        message: format!("Failed to create DataFusion table: {}", e),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    }
                })?;
            ctx.register_table(name, Arc::new(table))
                .map_err(|e| GraphError::PlanError {
                    message: format!("Failed to register table '{}': {}", name, e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;
        }

        if let Some(df) = self.try_execute_path_generic(&ctx).await? {
            let batches = df.collect().await.map_err(|e| GraphError::PlanError {
                message: format!("Failed to collect results: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
            if batches.is_empty() {
                let schema = datasets.values().next().unwrap().schema();
                return Ok(arrow_array::RecordBatch::new_empty(schema));
            }
            let merged = concat_batches(&batches[0].schema(), &batches).map_err(|e| {
                GraphError::PlanError {
                    message: format!("Failed to concatenate result batches: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                }
            })?;
            return Ok(merged);
        }

        let (table_name, batch) = datasets.iter().next().unwrap();
        let schema = batch.schema();

        let mut df = ctx
            .table(table_name)
            .await
            .map_err(|e| GraphError::PlanError {
                message: format!("Failed to create DataFrame for '{}': {}", table_name, e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        if let Some(where_clause) = &self.ast.where_clause {
            if let Some(filter_expr) = to_df_boolean_expr_simple(&where_clause.expression) {
                df = df.filter(filter_expr).map_err(|e| GraphError::PlanError {
                    message: format!("Failed to apply filter: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;
            }
        }

        let proj_exprs: Vec<Expr> = self
            .ast
            .return_clause
            .items
            .iter()
            .map(|item| to_df_value_expr_simple(&item.expression))
            .collect();
        if !proj_exprs.is_empty() {
            df = df.select(proj_exprs).map_err(|e| GraphError::PlanError {
                message: format!("Failed to project: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        if self.ast.return_clause.distinct {
            df = df.distinct().map_err(|e| GraphError::PlanError {
                message: format!("Failed to apply DISTINCT: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        if let Some(order_by) = &self.ast.order_by {
            let sort_expr = to_df_order_by_expr_simple(&order_by.items);
            df = df.sort(sort_expr).map_err(|e| GraphError::PlanError {
                message: format!("Failed to apply ORDER BY: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        if self.ast.skip.is_some() || self.ast.limit.is_some() {
            let offset = self.ast.skip.unwrap_or(0) as usize;
            let fetch = self.ast.limit.map(|l| l as usize);
            df = df.limit(offset, fetch).map_err(|e| GraphError::PlanError {
                message: format!("Failed to apply SKIP/LIMIT: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        let batches = df.collect().await.map_err(|e| GraphError::PlanError {
            message: format!("Failed to collect results: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

        if batches.is_empty() {
            return Ok(arrow_array::RecordBatch::new_empty(schema));
        }

        let merged =
            concat_batches(&batches[0].schema(), &batches).map_err(|e| GraphError::PlanError {
                message: format!("Failed to concatenate result batches: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        Ok(merged)
    }

    pub fn referenced_node_labels(&self) -> Vec<String> {
        let mut labels = Vec::new();

        for match_clause in &self.ast.match_clauses {
            for pattern in &match_clause.patterns {
                self.collect_node_labels_from_pattern(pattern, &mut labels);
            }
        }

        labels.sort();
        labels.dedup();
        labels
    }

    pub fn referenced_relationship_types(&self) -> Vec<String> {
        let mut types = Vec::new();

        for match_clause in &self.ast.match_clauses {
            for pattern in &match_clause.patterns {
                self.collect_relationship_types_from_pattern(pattern, &mut types);
            }
        }

        types.sort();
        types.dedup();
        types
    }

    pub fn variables(&self) -> Vec<String> {
        let mut variables = Vec::new();

        for match_clause in &self.ast.match_clauses {
            for pattern in &match_clause.patterns {
                self.collect_variables_from_pattern(pattern, &mut variables);
            }
        }

        variables.sort();
        variables.dedup();
        variables
    }

    fn collect_node_labels_from_pattern(
        &self,
        pattern: &crate::ast::GraphPattern,
        labels: &mut Vec<String>,
    ) {
        match pattern {
            crate::ast::GraphPattern::Node(node) => {
                labels.extend(node.labels.clone());
            }
            crate::ast::GraphPattern::Path(path) => {
                labels.extend(path.start_node.labels.clone());
                for segment in &path.segments {
                    labels.extend(segment.end_node.labels.clone());
                }
            }
        }
    }

    fn collect_relationship_types_from_pattern(
        &self,
        pattern: &crate::ast::GraphPattern,
        types: &mut Vec<String>,
    ) {
        if let crate::ast::GraphPattern::Path(path) = pattern {
            for segment in &path.segments {
                types.extend(segment.relationship.types.clone());
            }
        }
    }

    fn collect_variables_from_pattern(
        &self,
        pattern: &crate::ast::GraphPattern,
        variables: &mut Vec<String>,
    ) {
        match pattern {
            crate::ast::GraphPattern::Node(node) => {
                if let Some(var) = &node.variable {
                    variables.push(var.clone());
                }
            }
            crate::ast::GraphPattern::Path(path) => {
                if let Some(var) = &path.start_node.variable {
                    variables.push(var.clone());
                }
                for segment in &path.segments {
                    if let Some(var) = &segment.relationship.variable {
                        variables.push(var.clone());
                    }
                    if let Some(var) = &segment.end_node.variable {
                        variables.push(var.clone());
                    }
                }
            }
        }
    }
}

impl CypherQuery {
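    /// Runs the Cypher query first, then reranks the resulting candidate
    /// rows with the provided Lance vector search.
    ///
    /// A hedged sketch (construction of `VectorSearch` depends on
    /// `crate::lance_vector_search` and is elided here):
    ///
    /// ```ignore
    /// let reranked = query
    ///     .execute_with_vector_rerank(datasets, vector_search)
    ///     .await?;
    /// ```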
    pub async fn execute_with_vector_rerank(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
        vector_search: crate::lance_vector_search::VectorSearch,
    ) -> Result<arrow::record_batch::RecordBatch> {
        let candidates = self.execute(datasets, None).await?;

        vector_search.search(&candidates).await
    }
}

impl CypherQuery {
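    /// Attempts to execute a single-MATCH, single-path query with the
    /// generic `PathExecutor`, returning `Ok(None)` when the query shape
    /// does not qualify so `execute_simple` can fall back to its
    /// single-table path.
    ///
    /// Variable-length relationships such as `-[:KNOWS*1..3]->` are expanded
    /// into a UNION of fixed-length chains, schematically:
    ///
    /// ```text
    /// (a)-[:KNOWS]->(b)
    /// UNION (a)-[:KNOWS]->()-[:KNOWS]->(b)
    /// UNION (a)-[:KNOWS]->()-[:KNOWS]->()-[:KNOWS]->(b)
    /// ```
    ///
    /// Hop counts are capped at `crate::MAX_VARIABLE_LENGTH_HOPS`.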
    async fn try_execute_path_generic(
        &self,
        ctx: &datafusion::prelude::SessionContext,
    ) -> Result<Option<datafusion::dataframe::DataFrame>> {
        use crate::ast::GraphPattern;
        let [match_clause] = self.ast.match_clauses.as_slice() else {
            return Ok(None);
        };
        let path = match match_clause.patterns.as_slice() {
            [GraphPattern::Path(p)] if !p.segments.is_empty() => p,
            _ => return Ok(None),
        };
        let cfg = self.require_config()?;

        if path.segments.len() == 1 {
            if let Some(length_range) = &path.segments[0].relationship.length {
                let cap: u32 = crate::MAX_VARIABLE_LENGTH_HOPS;
                let min_len = length_range.min.unwrap_or(1).max(1);
                let max_len = length_range.max.unwrap_or(cap);

                if min_len > max_len {
                    return Err(GraphError::InvalidPattern {
                        message: format!(
                            "Invalid variable-length range: min {:?} greater than max {:?}",
                            length_range.min, length_range.max
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    });
                }

                if max_len > cap {
                    return Err(GraphError::UnsupportedFeature {
                        feature: format!(
                            "Variable-length paths with length > {} are not supported (got {:?}..{:?})",
                            cap, length_range.min, length_range.max
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    });
                }

                use datafusion::dataframe::DataFrame;
                let mut union_df: Option<DataFrame> = None;

                for hops in min_len..=max_len {
                    let mut synthetic = crate::ast::PathPattern {
                        start_node: path.start_node.clone(),
                        segments: Vec::with_capacity(hops as usize),
                    };

                    for i in 0..hops {
                        let mut seg = path.segments[0].clone();
                        seg.relationship.variable = None;
                        if (i + 1) < hops {
                            seg.end_node.variable = None;
                        }
                        seg.relationship.length = None;
                        synthetic.segments.push(seg);
                    }

                    let exec = PathExecutor::new(ctx, cfg, &synthetic)?;
                    let mut df = exec.build_chain().await?;
                    df = exec.apply_where(df, &self.ast)?;
                    df = exec.apply_return(df, &self.ast)?;

                    union_df = Some(match union_df {
                        Some(acc) => acc.union(df).map_err(|e| GraphError::PlanError {
                            message: format!("Failed to UNION variable-length paths: {}", e),
                            location: snafu::Location::new(file!(), line!(), column!()),
                        })?,
                        None => df,
                    });
                }

                return Ok(union_df);
            }
        }

        let exec = PathExecutor::new(ctx, cfg, path)?;
        let df = exec.build_chain().await?;
        let df = exec.apply_where(df, &self.ast)?;
        let df = exec.apply_return(df, &self.ast)?;
        Ok(Some(df))
    }
}

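/// Programmatic builder for [`CypherQuery`] values, as an alternative to
/// parsing query text. Note that `build()` synthesizes a placeholder
/// `query_text` ("MATCH ... RETURN ...") rather than unparsing the AST.
///
/// A minimal sketch (mirroring `test_query_builder` below):
///
/// ```ignore
/// let query = CypherQueryBuilder::new()
///     .with_config(config)
///     .match_node("n", "Person")
///     .return_property("n", "name")
///     .limit(10)
///     .build()?;
/// ```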
#[derive(Debug, Default)]
pub struct CypherQueryBuilder {
    match_clauses: Vec<crate::ast::MatchClause>,
    where_expression: Option<crate::ast::BooleanExpression>,
    return_items: Vec<crate::ast::ReturnItem>,
    order_by_items: Vec<crate::ast::OrderByItem>,
    limit: Option<u64>,
    distinct: bool,
    skip: Option<u64>,
    config: Option<GraphConfig>,
    parameters: HashMap<String, serde_json::Value>,
}

impl CypherQueryBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn match_node(mut self, variable: &str, label: &str) -> Self {
        let node = crate::ast::NodePattern {
            variable: Some(variable.to_string()),
            labels: vec![label.to_string()],
            properties: HashMap::new(),
        };

        let match_clause = crate::ast::MatchClause {
            patterns: vec![crate::ast::GraphPattern::Node(node)],
        };

        self.match_clauses.push(match_clause);
        self
    }

    pub fn with_config(mut self, config: GraphConfig) -> Self {
        self.config = Some(config);
        self
    }

    pub fn return_property(mut self, variable: &str, property: &str) -> Self {
        let prop_ref = crate::ast::PropertyRef::new(variable, property);
        let return_item = crate::ast::ReturnItem {
            expression: crate::ast::ValueExpression::Property(prop_ref),
            alias: None,
        };

        self.return_items.push(return_item);
        self
    }

    pub fn distinct(mut self, distinct: bool) -> Self {
        self.distinct = distinct;
        self
    }

    pub fn limit(mut self, limit: u64) -> Self {
        self.limit = Some(limit);
        self
    }

    pub fn skip(mut self, skip: u64) -> Self {
        self.skip = Some(skip);
        self
    }

    pub fn build(self) -> Result<CypherQuery> {
        if self.match_clauses.is_empty() {
            return Err(GraphError::PlanError {
                message: "Query must have at least one MATCH clause".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        if self.return_items.is_empty() {
            return Err(GraphError::PlanError {
                message: "Query must have at least one RETURN item".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ast = crate::ast::CypherQuery {
            match_clauses: self.match_clauses,
            where_clause: self
                .where_expression
                .map(|expr| crate::ast::WhereClause { expression: expr }),
            return_clause: crate::ast::ReturnClause {
                distinct: self.distinct,
                items: self.return_items,
            },
            order_by: if self.order_by_items.is_empty() {
                None
            } else {
                Some(crate::ast::OrderByClause {
                    items: self.order_by_items,
                })
            },
            limit: self.limit,
            skip: self.skip,
        };

        let query_text = "MATCH ... RETURN ...".to_string();

        let query = CypherQuery {
            query_text,
            ast,
            config: self.config,
            parameters: self.parameters,
        };

        Ok(query)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GraphConfig;

    #[test]
    fn test_parse_simple_cypher_query() {
        let query = CypherQuery::new("MATCH (n:Person) RETURN n.name").unwrap();
        assert_eq!(query.query_text(), "MATCH (n:Person) RETURN n.name");
        assert_eq!(query.referenced_node_labels(), vec!["Person"]);
        assert_eq!(query.variables(), vec!["n"]);
    }

    #[test]
    fn test_query_with_parameters() {
        let mut params = HashMap::new();
        params.insert("minAge".to_string(), serde_json::Value::Number(30.into()));

        let query = CypherQuery::new("MATCH (n:Person) WHERE n.age > $minAge RETURN n.name")
            .unwrap()
            .with_parameters(params);

        assert!(query.parameters().contains_key("minAge"));
    }

    #[test]
    fn test_query_builder() {
        let config = GraphConfig::builder()
            .with_node_label("Person", "person_id")
            .build()
            .unwrap();

        let query = CypherQueryBuilder::new()
            .with_config(config)
            .match_node("n", "Person")
            .return_property("n", "name")
            .limit(10)
            .build()
            .unwrap();

        assert_eq!(query.referenced_node_labels(), vec!["Person"]);
        assert_eq!(query.variables(), vec!["n"]);
    }

    #[test]
    fn test_relationship_query_parsing() {
        let query =
            CypherQuery::new("MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a.name, b.name")
                .unwrap();
        assert_eq!(query.referenced_node_labels(), vec!["Person"]);
        assert_eq!(query.referenced_relationship_types(), vec!["KNOWS"]);
        assert_eq!(query.variables(), vec!["a", "b", "r"]);
    }

    #[tokio::test]
    async fn test_execute_basic_projection_and_filter() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        assert_eq!(out.num_rows(), 2);
        let names = out
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
        let result: Vec<(String, i64)> = (0..out.num_rows())
            .map(|i| (names.value(i).to_string(), ages.value(i)))
            .collect();
        assert!(result.contains(&("Bob".to_string(), 34)));
        assert!(result.contains(&("David".to_string(), 42)));
    }

    #[tokio::test]
    async fn test_execute_single_hop_path_join_projection() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let person_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let people = RecordBatch::try_new(
            person_schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
                Arc::new(Int64Array::from(vec![28, 34, 29])),
            ],
        )
        .unwrap();

        let rel_schema = Arc::new(Schema::new(vec![
            Field::new("src_person_id", DataType::Int64, false),
            Field::new("dst_person_id", DataType::Int64, false),
        ]));
        let knows = RecordBatch::try_new(
            rel_schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(Int64Array::from(vec![2, 3])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .with_relationship("KNOWS", "src_person_id", "dst_person_id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN b.name")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("Person".to_string(), people);
        data.insert("KNOWS".to_string(), knows);

        let out = q.execute_simple(data).await.unwrap();
        let names = out
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let got: Vec<String> = (0..out.num_rows())
            .map(|i| names.value(i).to_string())
            .collect();
        assert_eq!(got.len(), 2);
        assert!(got.contains(&"Bob".to_string()));
        assert!(got.contains(&"Carol".to_string()));
    }

    #[tokio::test]
    async fn test_execute_order_by_asc() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
                Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age ASC")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
        let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
        assert_eq!(collected, vec![28, 29, 34, 42]);
    }

    #[tokio::test]
    async fn test_execute_order_by_desc_with_skip_limit() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
                Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q =
            CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age DESC SKIP 1 LIMIT 2")
                .unwrap()
                .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        assert_eq!(out.num_rows(), 2);
        let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
        assert_eq!(collected, vec![34, 29]);
    }

    #[tokio::test]
    async fn test_execute_skip_without_limit() {
        use arrow_array::{Int64Array, RecordBatch};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![10, 20, 30, 40]))],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age ASC SKIP 2")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        assert_eq!(out.num_rows(), 2);
        let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
        assert_eq!(collected, vec![30, 40]);
    }

    #[tokio::test]
    async fn test_execute_datafusion_pipeline() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
                Arc::new(Int64Array::from(vec![25, 35, 30])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let result = query.execute_datafusion(datasets.clone()).await;

        match &result {
            Ok(batch) => {
                println!(
                    "DataFusion result: {} rows, {} columns",
                    batch.num_rows(),
                    batch.num_columns()
                );
                if batch.num_rows() > 0 {
                    println!("First row data: {:?}", batch.slice(0, 1));
                }
            }
            Err(e) => {
                println!("DataFusion execution failed: {:?}", e);
            }
        }

        let legacy_result = query.execute_simple(datasets).await.unwrap();
        println!(
            "Legacy result: {} rows, {} columns",
            legacy_result.num_rows(),
            legacy_result.num_columns()
        );

        let result = result.unwrap();

        assert_eq!(
            result.num_rows(),
            1,
            "Expected 1 row after filtering WHERE p.age > 30"
        );

        assert_eq!(
            result.num_columns(),
            1,
            "Expected 1 column after projection RETURN p.name"
        );

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(
            names.value(0),
            "Bob",
            "Expected filtered result to contain Bob"
        );
    }

    #[tokio::test]
    async fn test_execute_datafusion_simple_scan() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let result = query.execute_datafusion(datasets).await.unwrap();

        assert_eq!(
            result.num_rows(),
            2,
            "Should return all 2 rows without filtering"
        );
        assert_eq!(result.num_columns(), 1, "Should return 1 column (name)");

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let name_set: std::collections::HashSet<String> = (0..result.num_rows())
            .map(|i| names.value(i).to_string())
            .collect();
        let expected: std::collections::HashSet<String> =
            ["Alice", "Bob"].iter().map(|s| s.to_string()).collect();
        assert_eq!(name_set, expected, "Should return Alice and Bob");
    }

    #[tokio::test]
    async fn test_execute_with_context_simple_scan() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
                Arc::new(Int64Array::from(vec![28, 34, 29])),
            ],
        )
        .unwrap();

        let mem_table =
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
        let ctx = SessionContext::new();
        ctx.register_table("Person", mem_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 3);
        assert_eq!(result.num_columns(), 1);

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
        assert_eq!(names.value(2), "Carol");
    }

    #[tokio::test]
    async fn test_execute_with_context_with_filter() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
            ],
        )
        .unwrap();

        let mem_table =
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
        let ctx = SessionContext::new();
        ctx.register_table("Person", mem_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 2);
        assert_eq!(result.num_columns(), 2);

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let ages = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        let results: Vec<(String, i64)> = (0..result.num_rows())
            .map(|i| (names.value(i).to_string(), ages.value(i)))
            .collect();

        assert!(results.contains(&("Bob".to_string(), 34)));
        assert!(results.contains(&("David".to_string(), 42)));
    }

    #[tokio::test]
    async fn test_execute_with_context_relationship_traversal() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let person_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let person_batch = RecordBatch::try_new(
            person_schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
            ],
        )
        .unwrap();

        let knows_schema = Arc::new(Schema::new(vec![
            Field::new("src_id", DataType::Int64, false),
            Field::new("dst_id", DataType::Int64, false),
            Field::new("since", DataType::Int64, false),
        ]));
        let knows_batch = RecordBatch::try_new(
            knows_schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(Int64Array::from(vec![2, 3])),
                Arc::new(Int64Array::from(vec![2020, 2021])),
            ],
        )
        .unwrap();

        let person_table = Arc::new(
            MemTable::try_new(person_schema.clone(), vec![vec![person_batch.clone()]]).unwrap(),
        );
        let knows_table = Arc::new(
            MemTable::try_new(knows_schema.clone(), vec![vec![knows_batch.clone()]]).unwrap(),
        );

        let ctx = SessionContext::new();
        ctx.register_table("Person", person_table).unwrap();
        ctx.register_table("KNOWS", knows_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .with_relationship("KNOWS", "src_id", "dst_id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 2);
        assert_eq!(result.num_columns(), 2);

        let src_names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dst_names = result
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let relationships: Vec<(String, String)> = (0..result.num_rows())
            .map(|i| {
                (
                    src_names.value(i).to_string(),
                    dst_names.value(i).to_string(),
                )
            })
            .collect();

        assert!(relationships.contains(&("Alice".to_string(), "Bob".to_string())));
        assert!(relationships.contains(&("Bob".to_string(), "Carol".to_string())));
    }

    #[tokio::test]
    async fn test_execute_with_context_order_by_limit() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("score", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![85, 92, 78, 95])),
            ],
        )
        .unwrap();

        let mem_table =
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
        let ctx = SessionContext::new();
        ctx.register_table("Student", mem_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Student", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new(
            "MATCH (s:Student) RETURN s.name, s.score ORDER BY s.score DESC LIMIT 2",
        )
        .unwrap()
        .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 2);
        assert_eq!(result.num_columns(), 2);

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        assert_eq!(names.value(0), "David");
        assert_eq!(scores.value(0), 95);

        assert_eq!(names.value(1), "Bob");
        assert_eq!(scores.value(1), 92);
    }

    #[tokio::test]
    async fn test_to_sql() {
        use arrow_array::RecordBatch;
        use arrow_schema::{DataType, Field, Schema};
        use std::collections::HashMap;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());

        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let sql = query.to_sql(datasets).await.unwrap();
        println!("Generated SQL: {}", sql);

        assert!(sql.contains("SELECT"));
        assert!(sql.to_lowercase().contains("from person"));
        assert!(sql.contains("p.name"));
    }
}