1use crate::ast::CypherQuery as CypherAST;
7use crate::ast::ReadingClause;
8use crate::config::GraphConfig;
9use crate::error::{GraphError, Result};
10use crate::logical_plan::LogicalPlanner;
11use crate::parser::parse_cypher_query;
12use arrow_array::RecordBatch;
13use arrow_schema::{Field, Schema, SchemaRef};
14use lance_graph_catalog::DirNamespace;
15use lance_namespace::models::DescribeTableRequest;
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19pub(crate) fn normalize_schema(schema: SchemaRef) -> Result<SchemaRef> {
24 let fields: Vec<_> = schema
25 .fields()
26 .iter()
27 .map(|f| {
28 Arc::new(Field::new(
29 f.name().to_lowercase(),
30 f.data_type().clone(),
31 f.is_nullable(),
32 ))
33 })
34 .collect();
35 Ok(Arc::new(Schema::new(fields)))
36}
37
38pub(crate) fn normalize_record_batch(batch: &RecordBatch) -> Result<RecordBatch> {
43 let normalized_schema = normalize_schema(batch.schema())?;
44 RecordBatch::try_new(normalized_schema, batch.columns().to_vec()).map_err(|e| {
45 GraphError::PlanError {
46 message: format!("Failed to normalize record batch schema: {}", e),
47 location: snafu::Location::new(file!(), line!(), column!()),
48 }
49 })
50}
51
/// Selects the engine used to run a Cypher query.
///
/// When no strategy is given, the execution entry points fall back to
/// `ExecutionStrategy::default()`, which is `DataFusion`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionStrategy {
    /// Translate the query into a DataFusion logical plan and execute it.
    #[default]
    DataFusion,
    /// Lance-native execution. It is not implemented yet; the execution entry points
    /// reject it with `GraphError::UnsupportedFeature`.
    LanceNative,
}
61
62#[derive(Debug, Clone)]
64pub struct CypherQuery {
65 query_text: String,
67 ast: CypherAST,
69 config: Option<GraphConfig>,
71 parameters: HashMap<String, serde_json::Value>,
73}
74impl CypherQuery {
75 pub fn new(query: &str) -> Result<Self> {
77 let ast = parse_cypher_query(query)?;
78
79 Ok(Self {
80 query_text: query.to_string(),
81 ast,
82 config: None,
83 parameters: HashMap::new(),
84 })
85 }
86
87 pub fn with_config(mut self, config: GraphConfig) -> Self {
89 self.config = Some(config);
90 self
91 }
92
93 pub fn with_parameter<K, V>(mut self, key: K, value: V) -> Self
95 where
96 K: Into<String>,
97 V: Into<serde_json::Value>,
98 {
99 self.parameters
100 .insert(key.into().to_lowercase(), value.into());
101 self
102 }
103
104 pub fn with_parameters(mut self, params: HashMap<String, serde_json::Value>) -> Self {
106 for (k, v) in params {
107 self.parameters.insert(k.to_lowercase(), v);
108 }
109 self
110 }
111
112 pub fn query_text(&self) -> &str {
114 &self.query_text
115 }
116
117 pub fn ast(&self) -> &CypherAST {
119 &self.ast
120 }
121
122 pub fn config(&self) -> Option<&GraphConfig> {
124 self.config.as_ref()
125 }
126
127 pub fn parameters(&self) -> &HashMap<String, serde_json::Value> {
129 &self.parameters
130 }
131
132 fn require_config(&self) -> Result<&GraphConfig> {
134 self.config.as_ref().ok_or_else(|| GraphError::ConfigError {
135 message: "Graph configuration is required for query execution".to_string(),
136 location: snafu::Location::new(file!(), line!(), column!()),
137 })
138 }
139
140 pub async fn execute(
174 &self,
175 datasets: HashMap<String, arrow::record_batch::RecordBatch>,
176 strategy: Option<ExecutionStrategy>,
177 ) -> Result<arrow::record_batch::RecordBatch> {
178 let strategy = strategy.unwrap_or_default();
179 match strategy {
180 ExecutionStrategy::DataFusion => self.execute_datafusion(datasets).await,
181 ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
182 feature: "Lance native execution strategy is not yet implemented".to_string(),
183 location: snafu::Location::new(file!(), line!(), column!()),
184 }),
185 }
186 }
187
188 pub async fn execute_with_namespace(
192 &self,
193 namespace: DirNamespace,
194 strategy: Option<ExecutionStrategy>,
195 ) -> Result<arrow::record_batch::RecordBatch> {
196 self.execute_with_namespace_arc(std::sync::Arc::new(namespace), strategy)
197 .await
198 }
199
200 pub async fn execute_with_namespace_arc(
202 &self,
203 namespace: std::sync::Arc<DirNamespace>,
204 strategy: Option<ExecutionStrategy>,
205 ) -> Result<arrow::record_batch::RecordBatch> {
206 let namespace_trait: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync> =
207 namespace;
208 self.execute_with_namespace_internal(namespace_trait, strategy)
209 .await
210 }
211
212 async fn execute_with_namespace_internal(
213 &self,
214 namespace: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync>,
215 strategy: Option<ExecutionStrategy>,
216 ) -> Result<arrow::record_batch::RecordBatch> {
217 let strategy = strategy.unwrap_or_default();
218 match strategy {
219 ExecutionStrategy::DataFusion => {
220 let (catalog, ctx) = self
221 .build_catalog_and_context_from_namespace(namespace)
222 .await?;
223 self.execute_with_catalog_and_context(std::sync::Arc::new(catalog), ctx)
224 .await
225 }
226 ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
227 feature: "Lance native execution strategy is not yet implemented".to_string(),
228 location: snafu::Location::new(file!(), line!(), column!()),
229 }),
230 }
231 }
232
233 pub async fn explain(
269 &self,
270 datasets: HashMap<String, arrow::record_batch::RecordBatch>,
271 ) -> Result<String> {
272 use std::sync::Arc;
273
274 let (catalog, ctx) = self
276 .build_catalog_and_context_from_datasets(datasets)
277 .await?;
278
279 self.explain_internal(Arc::new(catalog), ctx).await
281 }
282
283 pub async fn to_sql(
300 &self,
301 datasets: HashMap<String, arrow::record_batch::RecordBatch>,
302 ) -> Result<String> {
303 use datafusion_sql::unparser::plan_to_sql;
304 use std::sync::Arc;
305
306 let _config = self.require_config()?;
307
308 let (catalog, ctx) = self
310 .build_catalog_and_context_from_datasets(datasets)
311 .await?;
312
313 let (_, df_plan) = self.create_logical_plans(Arc::new(catalog))?;
315
316 let optimized_plan = ctx
319 .state()
320 .optimize(&df_plan)
321 .map_err(|e| GraphError::PlanError {
322 message: format!("Failed to optimize plan: {}", e),
323 location: snafu::Location::new(file!(), line!(), column!()),
324 })?;
325
326 let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
328 message: format!("Failed to unparse plan to SQL: {}", e),
329 location: snafu::Location::new(file!(), line!(), column!()),
330 })?;
331
332 Ok(sql_ast.to_string())
333 }
334
335 pub async fn execute_with_context(
384 &self,
385 ctx: datafusion::execution::context::SessionContext,
386 ) -> Result<arrow::record_batch::RecordBatch> {
387 use datafusion::datasource::DefaultTableSource;
388 use lance_graph_catalog::InMemoryCatalog;
389 use std::sync::Arc;
390
391 let config = self.require_config()?;
392
393 let mut catalog = InMemoryCatalog::new();
395
396 for label in config.node_mappings.keys() {
398 let table_provider =
399 ctx.table_provider(label)
400 .await
401 .map_err(|e| GraphError::ConfigError {
402 message: format!(
403 "Node label '{}' not found in SessionContext: {}",
404 label, e
405 ),
406 location: snafu::Location::new(file!(), line!(), column!()),
407 })?;
408
409 let table_source = Arc::new(DefaultTableSource::new(table_provider));
410 catalog = catalog.with_node_source(label, table_source);
411 }
412
413 for rel_type in config.relationship_mappings.keys() {
415 let table_provider =
416 ctx.table_provider(rel_type)
417 .await
418 .map_err(|e| GraphError::ConfigError {
419 message: format!(
420 "Relationship type '{}' not found in SessionContext: {}",
421 rel_type, e
422 ),
423 location: snafu::Location::new(file!(), line!(), column!()),
424 })?;
425
426 let table_source = Arc::new(DefaultTableSource::new(table_provider));
427 catalog = catalog.with_relationship_source(rel_type, table_source);
428 }
429
430 self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
432 .await
433 }
434
435 pub async fn execute_with_catalog_and_context(
472 &self,
473 catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
474 ctx: datafusion::execution::context::SessionContext,
475 ) -> Result<arrow::record_batch::RecordBatch> {
476 use arrow::compute::concat_batches;
477
478 let (_logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;
480
481 let df = ctx
483 .execute_logical_plan(df_logical_plan)
484 .await
485 .map_err(|e| GraphError::ExecutionError {
486 message: format!("Failed to execute DataFusion plan: {}", e),
487 location: snafu::Location::new(file!(), line!(), column!()),
488 })?;
489
490 let result_schema = df.schema().inner().clone();
492
493 let batches = df.collect().await.map_err(|e| GraphError::ExecutionError {
495 message: format!("Failed to collect query results: {}", e),
496 location: snafu::Location::new(file!(), line!(), column!()),
497 })?;
498
499 if batches.is_empty() {
500 return Ok(arrow::record_batch::RecordBatch::new_empty(result_schema));
503 }
504
505 let schema = batches[0].schema();
507 concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
508 message: format!("Failed to concatenate result batches: {}", e),
509 location: snafu::Location::new(file!(), line!(), column!()),
510 })
511 }
512
513 async fn execute_datafusion(
529 &self,
530 datasets: HashMap<String, arrow::record_batch::RecordBatch>,
531 ) -> Result<arrow::record_batch::RecordBatch> {
532 use std::sync::Arc;
533
534 let (catalog, ctx) = self
536 .build_catalog_and_context_from_datasets(datasets)
537 .await?;
538
539 self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
541 .await
542 }
543
544 async fn build_catalog_and_context_from_datasets(
546 &self,
547 datasets: HashMap<String, arrow::record_batch::RecordBatch>,
548 ) -> Result<(
549 lance_graph_catalog::InMemoryCatalog,
550 datafusion::execution::context::SessionContext,
551 )> {
552 use datafusion::datasource::{DefaultTableSource, MemTable};
553 use datafusion::execution::context::SessionContext;
554 use lance_graph_catalog::InMemoryCatalog;
555 use std::sync::Arc;
556
557 if datasets.is_empty() {
558 return Err(GraphError::ConfigError {
559 message: "No input datasets provided".to_string(),
560 location: snafu::Location::new(file!(), line!(), column!()),
561 });
562 }
563
564 let ctx = SessionContext::new();
566 let mut catalog = InMemoryCatalog::new();
567
568 for (name, batch) in &datasets {
570 let normalized_batch = normalize_record_batch(batch)?;
572
573 let mem_table = Arc::new(
574 MemTable::try_new(
575 normalized_batch.schema(),
576 vec![vec![normalized_batch.clone()]],
577 )
578 .map_err(|e| GraphError::PlanError {
579 message: format!("Failed to create MemTable for {}: {}", name, e),
580 location: snafu::Location::new(file!(), line!(), column!()),
581 })?,
582 );
583
584 let normalized_name = name.to_lowercase();
586
587 ctx.register_table(&normalized_name, mem_table.clone())
589 .map_err(|e| GraphError::PlanError {
590 message: format!("Failed to register table {}: {}", name, e),
591 location: snafu::Location::new(file!(), line!(), column!()),
592 })?;
593
594 let table_source = Arc::new(DefaultTableSource::new(mem_table));
595
596 catalog = catalog
599 .with_node_source(name, table_source.clone())
600 .with_relationship_source(name, table_source);
601 }
602
603 Ok((catalog, ctx))
604 }
605
606 async fn build_catalog_and_context_from_namespace(
608 &self,
609 namespace: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync>,
610 ) -> Result<(
611 lance_graph_catalog::InMemoryCatalog,
612 datafusion::execution::context::SessionContext,
613 )> {
614 use datafusion::datasource::{DefaultTableSource, TableProvider};
615 use datafusion::execution::context::SessionContext;
616 use lance::datafusion::LanceTableProvider;
617 use lance_graph_catalog::InMemoryCatalog;
618 use std::sync::Arc;
619
620 let config = self.require_config()?;
621
622 let mut required_tables: HashSet<String> = HashSet::new();
623 required_tables.extend(config.node_mappings.values().map(|m| m.label.clone()));
626 required_tables.extend(
627 config
628 .relationship_mappings
629 .values()
630 .map(|m| m.relationship_type.clone()),
631 );
632
633 if required_tables.is_empty() {
634 return Err(GraphError::ConfigError {
635 message:
636 "Graph configuration does not reference any node labels or relationship types"
637 .to_string(),
638 location: snafu::Location::new(file!(), line!(), column!()),
639 });
640 }
641
642 let ctx = SessionContext::new();
643 let mut catalog = InMemoryCatalog::new();
644 let mut providers: HashMap<String, Arc<dyn TableProvider>> = HashMap::new();
645
646 for table_name in required_tables {
647 let mut request = DescribeTableRequest::new();
648 request.id = Some(vec![table_name.clone()]);
649
650 let response =
651 namespace
652 .describe_table(request)
653 .await
654 .map_err(|e| GraphError::ConfigError {
655 message: format!(
656 "Namespace failed to resolve table '{}': {}",
657 table_name, e
658 ),
659 location: snafu::Location::new(file!(), line!(), column!()),
660 })?;
661
662 let location = response.location.ok_or_else(|| GraphError::ConfigError {
663 message: format!(
664 "Namespace did not provide a location for table '{}'",
665 table_name
666 ),
667 location: snafu::Location::new(file!(), line!(), column!()),
668 })?;
669
670 let dataset = lance::dataset::Dataset::open(&location)
671 .await
672 .map_err(|e| GraphError::ConfigError {
673 message: format!("Failed to open dataset for table '{}': {}", table_name, e),
674 location: snafu::Location::new(file!(), line!(), column!()),
675 })?;
676
677 let dataset = Arc::new(dataset);
678 let provider: Arc<dyn TableProvider> =
679 Arc::new(LanceTableProvider::new(dataset.clone(), true, true));
680
681 let normalized_table_name = table_name.to_lowercase();
683 ctx.register_table(&normalized_table_name, provider.clone())
684 .map_err(|e| GraphError::PlanError {
685 message: format!(
686 "Failed to register table '{}' in SessionContext: {}",
687 table_name, e
688 ),
689 location: snafu::Location::new(file!(), line!(), column!()),
690 })?;
691
692 providers.insert(normalized_table_name.clone(), provider);
694 }
695
696 for label in config.node_mappings.keys() {
697 let provider = providers
698 .get(label)
699 .ok_or_else(|| GraphError::ConfigError {
700 message: format!(
701 "Namespace did not resolve dataset for node label '{}'",
702 label
703 ),
704 location: snafu::Location::new(file!(), line!(), column!()),
705 })?;
706
707 let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
708 catalog = catalog.with_node_source(label, table_source);
709 }
710
711 for rel_type in config.relationship_mappings.keys() {
712 let provider = providers
713 .get(rel_type)
714 .ok_or_else(|| GraphError::ConfigError {
715 message: format!(
716 "Namespace did not resolve dataset for relationship type '{}'",
717 rel_type
718 ),
719 location: snafu::Location::new(file!(), line!(), column!()),
720 })?;
721
722 let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
723 catalog = catalog.with_relationship_source(rel_type, table_source);
724 }
725
726 Ok((catalog, ctx))
727 }
728
729 async fn explain_internal(
731 &self,
732 catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
733 ctx: datafusion::execution::context::SessionContext,
734 ) -> Result<String> {
735 let (logical_plan, df_logical_plan, physical_plan) =
737 self.create_plans(catalog, &ctx).await?;
738
739 self.format_explain_output(&logical_plan, &df_logical_plan, physical_plan.as_ref())
741 }
742
743 fn create_logical_plans(
748 &self,
749 catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
750 ) -> Result<(
751 crate::logical_plan::LogicalOperator,
752 datafusion::logical_expr::LogicalPlan,
753 )> {
754 use crate::datafusion_planner::{DataFusionPlanner, GraphPhysicalPlanner};
755 use crate::semantic::SemanticAnalyzer;
756
757 let config = self.require_config()?;
758
759 let mut analyzer = SemanticAnalyzer::new(config.clone());
761 let semantic = analyzer.analyze(&self.ast, &self.parameters)?;
762 if !semantic.errors.is_empty() {
763 return Err(GraphError::PlanError {
764 message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")),
765 location: snafu::Location::new(file!(), line!(), column!()),
766 });
767 }
768
769 let mut logical_planner = LogicalPlanner::new(config);
771 let logical_plan = logical_planner.plan(&semantic.ast)?;
773
774 let df_planner = DataFusionPlanner::with_catalog(config.clone(), catalog);
776 let df_logical_plan = df_planner.plan(&logical_plan)?;
777
778 Ok((logical_plan, df_logical_plan))
779 }
780
781 async fn create_plans(
783 &self,
784 catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
785 ctx: &datafusion::execution::context::SessionContext,
786 ) -> Result<(
787 crate::logical_plan::LogicalOperator,
788 datafusion::logical_expr::LogicalPlan,
789 std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
790 )> {
791 let (logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;
793
794 let df = ctx
796 .execute_logical_plan(df_logical_plan.clone())
797 .await
798 .map_err(|e| GraphError::ExecutionError {
799 message: format!("Failed to execute DataFusion plan: {}", e),
800 location: snafu::Location::new(file!(), line!(), column!()),
801 })?;
802
803 let physical_plan =
804 df.create_physical_plan()
805 .await
806 .map_err(|e| GraphError::ExecutionError {
807 message: format!("Failed to create physical plan: {}", e),
808 location: snafu::Location::new(file!(), line!(), column!()),
809 })?;
810
811 Ok((logical_plan, df_logical_plan, physical_plan))
812 }
813
814 fn format_explain_output(
816 &self,
817 logical_plan: &crate::logical_plan::LogicalOperator,
818 df_logical_plan: &datafusion::logical_expr::LogicalPlan,
819 physical_plan: &dyn datafusion::physical_plan::ExecutionPlan,
820 ) -> Result<String> {
821 let mut output = String::new();
823
824 output.push_str("Cypher Query:\n");
826 output.push_str(&format!(" {}\n\n", self.query_text));
827
828 let mut rows = vec![];
830
831 let graph_plan_str = format!("{:#?}", logical_plan);
833 rows.push(("graph_logical_plan", graph_plan_str));
834
835 let df_logical_str = format!("{}", df_logical_plan.display_indent());
837 rows.push(("logical_plan", df_logical_str));
838
839 let df_physical_str = format!(
841 "{}",
842 datafusion::physical_plan::displayable(physical_plan).indent(true)
843 );
844 rows.push(("physical_plan", df_physical_str));
845
846 let plan_type_width = rows.iter().map(|(t, _)| t.len()).max().unwrap_or(10);
848 let plan_width = rows
849 .iter()
850 .map(|(_, p)| p.lines().map(|l| l.len()).max().unwrap_or(0))
851 .max()
852 .unwrap_or(50);
853
854 let separator = format!(
856 "+{}+{}+",
857 "-".repeat(plan_type_width + 2),
858 "-".repeat(plan_width + 2)
859 );
860
861 output.push_str(&separator);
862 output.push('\n');
863
864 output.push_str(&format!(
866 "| {:<width$} | {:<plan_width$} |\n",
867 "plan_type",
868 "plan",
869 width = plan_type_width,
870 plan_width = plan_width
871 ));
872 output.push_str(&separator);
873 output.push('\n');
874
875 for (plan_type, plan_content) in rows {
877 let lines: Vec<&str> = plan_content.lines().collect();
878 if lines.is_empty() {
879 output.push_str(&format!(
880 "| {:<width$} | {:<plan_width$} |\n",
881 plan_type,
882 "",
883 width = plan_type_width,
884 plan_width = plan_width
885 ));
886 } else {
887 output.push_str(&format!(
889 "| {:<width$} | {:<plan_width$} |\n",
890 plan_type,
891 lines[0],
892 width = plan_type_width,
893 plan_width = plan_width
894 ));
895
896 for line in &lines[1..] {
898 output.push_str(&format!(
899 "| {:<width$} | {:<plan_width$} |\n",
900 "",
901 line,
902 width = plan_type_width,
903 plan_width = plan_width
904 ));
905 }
906 }
907 }
908
909 output.push_str(&separator);
910 output.push('\n');
911
912 Ok(output)
913 }
914
915 pub fn referenced_node_labels(&self) -> Vec<String> {
917 let mut labels = Vec::new();
918
919 for clause in &self.ast.reading_clauses {
920 if let ReadingClause::Match(match_clause) = clause {
921 for pattern in &match_clause.patterns {
922 self.collect_node_labels_from_pattern(pattern, &mut labels);
923 }
924 }
925 }
926
927 labels.sort();
928 labels.dedup();
929 labels
930 }
931
932 pub fn referenced_relationship_types(&self) -> Vec<String> {
934 let mut types = Vec::new();
935
936 for clause in &self.ast.reading_clauses {
937 if let ReadingClause::Match(match_clause) = clause {
938 for pattern in &match_clause.patterns {
939 self.collect_relationship_types_from_pattern(pattern, &mut types);
940 }
941 }
942 }
943
944 types.sort();
945 types.dedup();
946 types
947 }
948
949 pub fn variables(&self) -> Vec<String> {
951 let mut variables = Vec::new();
952
953 for clause in &self.ast.reading_clauses {
954 match clause {
955 ReadingClause::Match(match_clause) => {
956 for pattern in &match_clause.patterns {
957 self.collect_variables_from_pattern(pattern, &mut variables);
958 }
959 }
960 ReadingClause::Unwind(unwind_clause) => {
961 variables.push(unwind_clause.alias.clone());
962 }
963 }
964 }
965
966 variables.sort();
967 variables.dedup();
968 variables
969 }
970
971 fn collect_node_labels_from_pattern(
974 &self,
975 pattern: &crate::ast::GraphPattern,
976 labels: &mut Vec<String>,
977 ) {
978 match pattern {
979 crate::ast::GraphPattern::Node(node) => {
980 labels.extend(node.labels.clone());
981 }
982 crate::ast::GraphPattern::Path(path) => {
983 labels.extend(path.start_node.labels.clone());
984 for segment in &path.segments {
985 labels.extend(segment.end_node.labels.clone());
986 }
987 }
988 }
989 }
990
991 fn collect_relationship_types_from_pattern(
992 &self,
993 pattern: &crate::ast::GraphPattern,
994 types: &mut Vec<String>,
995 ) {
996 if let crate::ast::GraphPattern::Path(path) = pattern {
997 for segment in &path.segments {
998 types.extend(segment.relationship.types.clone());
999 }
1000 }
1001 }
1002
1003 fn collect_variables_from_pattern(
1004 &self,
1005 pattern: &crate::ast::GraphPattern,
1006 variables: &mut Vec<String>,
1007 ) {
1008 match pattern {
1009 crate::ast::GraphPattern::Node(node) => {
1010 if let Some(var) = &node.variable {
1011 variables.push(var.clone());
1012 }
1013 }
1014 crate::ast::GraphPattern::Path(path) => {
1015 if let Some(var) = &path.start_node.variable {
1016 variables.push(var.clone());
1017 }
1018 for segment in &path.segments {
1019 if let Some(var) = &segment.relationship.variable {
1020 variables.push(var.clone());
1021 }
1022 if let Some(var) = &segment.end_node.variable {
1023 variables.push(var.clone());
1024 }
1025 }
1026 }
1027 }
1028 }
1029}
1030
1031impl CypherQuery {
1032 pub async fn execute_with_vector_rerank(
1061 &self,
1062 datasets: HashMap<String, arrow::record_batch::RecordBatch>,
1063 vector_search: crate::lance_vector_search::VectorSearch,
1064 ) -> Result<arrow::record_batch::RecordBatch> {
1065 let candidates = self.execute(datasets, None).await?;
1067
1068 vector_search.search(&candidates).await
1070 }
1071}
1072
1073#[derive(Debug, Default)]
1075pub struct CypherQueryBuilder {
1076 match_clauses: Vec<crate::ast::MatchClause>,
1077 where_expression: Option<crate::ast::BooleanExpression>,
1078 return_items: Vec<crate::ast::ReturnItem>,
1079 order_by_items: Vec<crate::ast::OrderByItem>,
1080 limit: Option<u64>,
1081 distinct: bool,
1082 skip: Option<u64>,
1083 config: Option<GraphConfig>,
1084 parameters: HashMap<String, serde_json::Value>,
1085}
1086
1087impl CypherQueryBuilder {
1088 pub fn new() -> Self {
1090 Self::default()
1091 }
1092
1093 pub fn match_node(mut self, variable: &str, label: &str) -> Self {
1095 let node = crate::ast::NodePattern {
1096 variable: Some(variable.to_string()),
1097 labels: vec![label.to_string()],
1098 properties: HashMap::new(),
1099 };
1100
1101 let match_clause = crate::ast::MatchClause {
1102 patterns: vec![crate::ast::GraphPattern::Node(node)],
1103 };
1104
1105 self.match_clauses.push(match_clause);
1106 self
1107 }
1108
1109 pub fn with_config(mut self, config: GraphConfig) -> Self {
1111 self.config = Some(config);
1112 self
1113 }
1114
1115 pub fn return_property(mut self, variable: &str, property: &str) -> Self {
1117 let prop_ref = crate::ast::PropertyRef::new(variable, property);
1118 let return_item = crate::ast::ReturnItem {
1119 expression: crate::ast::ValueExpression::Property(prop_ref),
1120 alias: None,
1121 };
1122
1123 self.return_items.push(return_item);
1124 self
1125 }
1126
1127 pub fn distinct(mut self, distinct: bool) -> Self {
1129 self.distinct = distinct;
1130 self
1131 }
1132
1133 pub fn limit(mut self, limit: u64) -> Self {
1135 self.limit = Some(limit);
1136 self
1137 }
1138
1139 pub fn skip(mut self, skip: u64) -> Self {
1141 self.skip = Some(skip);
1142 self
1143 }
1144
1145 pub fn build(self) -> Result<CypherQuery> {
1147 if self.match_clauses.is_empty() {
1148 return Err(GraphError::PlanError {
1149 message: "Query must have at least one MATCH clause".to_string(),
1150 location: snafu::Location::new(file!(), line!(), column!()),
1151 });
1152 }
1153
1154 if self.return_items.is_empty() {
1155 return Err(GraphError::PlanError {
1156 message: "Query must have at least one RETURN item".to_string(),
1157 location: snafu::Location::new(file!(), line!(), column!()),
1158 });
1159 }
1160
1161 let ast = crate::ast::CypherQuery {
1162 reading_clauses: self
1163 .match_clauses
1164 .into_iter()
1165 .map(crate::ast::ReadingClause::Match)
1166 .collect(),
1167 where_clause: self
1168 .where_expression
1169 .map(|expr| crate::ast::WhereClause { expression: expr }),
1170 with_clause: None, post_with_reading_clauses: vec![],
1172 post_with_where_clause: None,
1173 return_clause: crate::ast::ReturnClause {
1174 distinct: self.distinct,
1175 items: self.return_items,
1176 },
1177 order_by: if self.order_by_items.is_empty() {
1178 None
1179 } else {
1180 Some(crate::ast::OrderByClause {
1181 items: self.order_by_items,
1182 })
1183 },
1184 limit: self.limit,
1185 skip: self.skip,
1186 };
1187
1188 let query_text = "MATCH ... RETURN ...".to_string(); let query = CypherQuery {
1192 query_text,
1193 ast,
1194 config: self.config,
1195 parameters: self.parameters,
1196 };
1197
1198 Ok(query)
1199 }
1200}
1201
1202#[cfg(test)]
1203mod tests {
1204 use super::*;
1205 use crate::config::GraphConfig;
1206
1207 #[test]
1208 fn test_parse_simple_cypher_query() {
1209 let query = CypherQuery::new("MATCH (n:Person) RETURN n.name").unwrap();
1210 assert_eq!(query.query_text(), "MATCH (n:Person) RETURN n.name");
1211 assert_eq!(query.referenced_node_labels(), vec!["Person"]);
1212 assert_eq!(query.variables(), vec!["n"]);
1213 }
1214
1215 #[test]
1216 fn test_query_with_parameters() {
1217 let mut params = HashMap::new();
1218 params.insert("minAge".to_string(), serde_json::Value::Number(30.into()));
1219 params.insert("maxAge".to_string(), serde_json::Value::Number(50.into()));
1220
1221 let query = CypherQuery::new(
1222 "MATCH (n:Person) WHERE n.age > $minAge AND n.age < $maxAge RETURN n.name",
1223 )
1224 .unwrap()
1225 .with_parameters(params);
1226
1227 assert!(query.parameters().contains_key("minage"));
1228 assert!(query.parameters().contains_key("maxage"));
1229 }
1230
1231 #[test]
1232 fn test_query_builder() {
1233 let config = GraphConfig::builder()
1234 .with_node_label("Person", "person_id")
1235 .build()
1236 .unwrap();
1237
1238 let query = CypherQueryBuilder::new()
1239 .with_config(config)
1240 .match_node("n", "Person")
1241 .return_property("n", "name")
1242 .limit(10)
1243 .build()
1244 .unwrap();
1245
1246 assert_eq!(query.referenced_node_labels(), vec!["Person"]);
1247 assert_eq!(query.variables(), vec!["n"]);
1248 }
1249
1250 #[test]
1251 fn test_relationship_query_parsing() {
1252 let query =
1253 CypherQuery::new("MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a.name, b.name")
1254 .unwrap();
1255 assert_eq!(query.referenced_node_labels(), vec!["Person"]);
1256 assert_eq!(query.referenced_relationship_types(), vec!["KNOWS"]);
1257 assert_eq!(query.variables(), vec!["a", "b", "r"]);
1258 }
1259
1260 #[tokio::test]
1261 async fn test_execute_basic_projection_and_filter() {
1262 use arrow_array::{Int64Array, RecordBatch, StringArray};
1263 use arrow_schema::{DataType, Field, Schema};
1264 use std::sync::Arc;
1265
1266 let schema = Arc::new(Schema::new(vec![
1268 Field::new("name", DataType::Utf8, true),
1269 Field::new("age", DataType::Int64, true),
1270 ]));
1271 let batch = RecordBatch::try_new(
1272 schema,
1273 vec![
1274 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
1275 Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
1276 ],
1277 )
1278 .unwrap();
1279
1280 let cfg = GraphConfig::builder()
1281 .with_node_label("Person", "id")
1282 .build()
1283 .unwrap();
1284
1285 let q = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
1286 .unwrap()
1287 .with_config(cfg);
1288
1289 let mut data = HashMap::new();
1290 data.insert("Person".to_string(), batch);
1291
1292 let out = q.execute(data, None).await.unwrap();
1293 assert_eq!(out.num_rows(), 2);
1294 let names = out
1295 .column(0)
1296 .as_any()
1297 .downcast_ref::<StringArray>()
1298 .unwrap();
1299 let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
1300 let result: Vec<(String, i64)> = (0..out.num_rows())
1302 .map(|i| (names.value(i).to_string(), ages.value(i)))
1303 .collect();
1304 assert!(result.contains(&("Bob".to_string(), 34)));
1305 assert!(result.contains(&("David".to_string(), 42)));
1306 }
1307
1308 #[tokio::test]
1309 async fn test_execute_single_hop_path_join_projection() {
1310 use arrow_array::{Int64Array, RecordBatch, StringArray};
1311 use arrow_schema::{DataType, Field, Schema};
1312 use std::sync::Arc;
1313
1314 let person_schema = Arc::new(Schema::new(vec![
1316 Field::new("id", DataType::Int64, false),
1317 Field::new("name", DataType::Utf8, true),
1318 Field::new("age", DataType::Int64, true),
1319 ]));
1320 let people = RecordBatch::try_new(
1321 person_schema,
1322 vec![
1323 Arc::new(Int64Array::from(vec![1, 2, 3])),
1324 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
1325 Arc::new(Int64Array::from(vec![28, 34, 29])),
1326 ],
1327 )
1328 .unwrap();
1329
1330 let rel_schema = Arc::new(Schema::new(vec![
1332 Field::new("src_person_id", DataType::Int64, false),
1333 Field::new("dst_person_id", DataType::Int64, false),
1334 ]));
1335 let knows = RecordBatch::try_new(
1336 rel_schema,
1337 vec![
1338 Arc::new(Int64Array::from(vec![1, 2])), Arc::new(Int64Array::from(vec![2, 3])),
1340 ],
1341 )
1342 .unwrap();
1343
1344 let cfg = GraphConfig::builder()
1346 .with_node_label("Person", "id")
1347 .with_relationship("KNOWS", "src_person_id", "dst_person_id")
1348 .build()
1349 .unwrap();
1350
1351 let q = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN b.name")
1353 .unwrap()
1354 .with_config(cfg);
1355
1356 let mut data = HashMap::new();
1357 data.insert("Person".to_string(), people);
1359 data.insert("KNOWS".to_string(), knows);
1360
1361 let out = q.execute(data, None).await.unwrap();
1362 let names = out
1364 .column(0)
1365 .as_any()
1366 .downcast_ref::<StringArray>()
1367 .unwrap();
1368 let got: Vec<String> = (0..out.num_rows())
1369 .map(|i| names.value(i).to_string())
1370 .collect();
1371 assert_eq!(got.len(), 2);
1372 assert!(got.contains(&"Bob".to_string()));
1373 assert!(got.contains(&"Carol".to_string()));
1374 }
1375
1376 #[tokio::test]
1377 async fn test_execute_order_by_asc() {
1378 use arrow_array::{Int64Array, RecordBatch, StringArray};
1379 use arrow_schema::{DataType, Field, Schema};
1380 use std::sync::Arc;
1381
1382 let schema = Arc::new(Schema::new(vec![
1384 Field::new("name", DataType::Utf8, true),
1385 Field::new("age", DataType::Int64, true),
1386 ]));
1387 let batch = RecordBatch::try_new(
1388 schema,
1389 vec![
1390 Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
1391 Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
1392 ],
1393 )
1394 .unwrap();
1395
1396 let cfg = GraphConfig::builder()
1397 .with_node_label("Person", "id")
1398 .build()
1399 .unwrap();
1400
1401 let q = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age ASC")
1403 .unwrap()
1404 .with_config(cfg);
1405
1406 let mut data = HashMap::new();
1407 data.insert("Person".to_string(), batch);
1408
1409 let out = q.execute(data, None).await.unwrap();
1410 let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
1411 let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
1412 assert_eq!(collected, vec![28, 29, 34, 42]);
1413 }
1414
1415 #[tokio::test]
1416 async fn test_execute_order_by_desc_with_skip_limit() {
1417 use arrow_array::{Int64Array, RecordBatch, StringArray};
1418 use arrow_schema::{DataType, Field, Schema};
1419 use std::sync::Arc;
1420
1421 let schema = Arc::new(Schema::new(vec![
1422 Field::new("name", DataType::Utf8, true),
1423 Field::new("age", DataType::Int64, true),
1424 ]));
1425 let batch = RecordBatch::try_new(
1426 schema,
1427 vec![
1428 Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
1429 Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
1430 ],
1431 )
1432 .unwrap();
1433
1434 let cfg = GraphConfig::builder()
1435 .with_node_label("Person", "id")
1436 .build()
1437 .unwrap();
1438
1439 let q =
1441 CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age DESC SKIP 1 LIMIT 2")
1442 .unwrap()
1443 .with_config(cfg);
1444
1445 let mut data = HashMap::new();
1446 data.insert("Person".to_string(), batch);
1447
1448 let out = q.execute(data, None).await.unwrap();
1449 assert_eq!(out.num_rows(), 2);
1450 let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
1451 let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
1452 assert_eq!(collected, vec![34, 29]);
1453 }
1454
1455 #[tokio::test]
1456 async fn test_execute_skip_without_limit() {
1457 use arrow_array::{Int64Array, RecordBatch};
1458 use arrow_schema::{DataType, Field, Schema};
1459 use std::sync::Arc;
1460
1461 let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));
1462 let batch = RecordBatch::try_new(
1463 schema,
1464 vec![Arc::new(Int64Array::from(vec![10, 20, 30, 40]))],
1465 )
1466 .unwrap();
1467
1468 let cfg = GraphConfig::builder()
1469 .with_node_label("Person", "id")
1470 .build()
1471 .unwrap();
1472
1473 let q = CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age ASC SKIP 2")
1474 .unwrap()
1475 .with_config(cfg);
1476
1477 let mut data = HashMap::new();
1478 data.insert("Person".to_string(), batch);
1479
1480 let out = q.execute(data, None).await.unwrap();
1481 assert_eq!(out.num_rows(), 2);
1482 let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
1483 let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
1484 assert_eq!(collected, vec![30, 40]);
1485 }
1486
1487 #[tokio::test]
1488 async fn test_execute_datafusion_simple_scan() {
1489 use arrow_array::{Int64Array, RecordBatch, StringArray};
1490 use arrow_schema::{DataType, Field, Schema};
1491 use std::sync::Arc;
1492
1493 let schema = Arc::new(Schema::new(vec![
1495 Field::new("id", DataType::Int64, false),
1496 Field::new("name", DataType::Utf8, false),
1497 ]));
1498
1499 let batch = RecordBatch::try_new(
1500 schema,
1501 vec![
1502 Arc::new(Int64Array::from(vec![1, 2])),
1503 Arc::new(StringArray::from(vec!["Alice", "Bob"])),
1504 ],
1505 )
1506 .unwrap();
1507
1508 let cfg = GraphConfig::builder()
1509 .with_node_label("Person", "id")
1510 .build()
1511 .unwrap();
1512
1513 let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
1515 .unwrap()
1516 .with_config(cfg);
1517
1518 let mut datasets = HashMap::new();
1519 datasets.insert("Person".to_string(), batch);
1520
1521 let result = query.execute_datafusion(datasets).await.unwrap();
1523
1524 assert_eq!(
1526 result.num_rows(),
1527 2,
1528 "Should return all 2 rows without filtering"
1529 );
1530 assert_eq!(result.num_columns(), 1, "Should return 1 column (name)");
1531
1532 let names = result
1534 .column(0)
1535 .as_any()
1536 .downcast_ref::<StringArray>()
1537 .unwrap();
1538 let name_set: std::collections::HashSet<String> = (0..result.num_rows())
1539 .map(|i| names.value(i).to_string())
1540 .collect();
1541 let expected: std::collections::HashSet<String> =
1542 ["Alice", "Bob"].iter().map(|s| s.to_string()).collect();
1543 assert_eq!(name_set, expected, "Should return Alice and Bob");
1544 }
1545
1546 #[tokio::test]
1547 async fn test_execute_with_context_simple_scan() {
1548 use arrow_array::{Int64Array, RecordBatch, StringArray};
1549 use arrow_schema::{DataType, Field, Schema};
1550 use datafusion::datasource::MemTable;
1551 use datafusion::execution::context::SessionContext;
1552 use std::sync::Arc;
1553
1554 let schema = Arc::new(Schema::new(vec![
1556 Field::new("id", DataType::Int64, false),
1557 Field::new("name", DataType::Utf8, false),
1558 Field::new("age", DataType::Int64, false),
1559 ]));
1560 let batch = RecordBatch::try_new(
1561 schema.clone(),
1562 vec![
1563 Arc::new(Int64Array::from(vec![1, 2, 3])),
1564 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
1565 Arc::new(Int64Array::from(vec![28, 34, 29])),
1566 ],
1567 )
1568 .unwrap();
1569
1570 let mem_table =
1572 Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
1573 let ctx = SessionContext::new();
1574 ctx.register_table("Person", mem_table).unwrap();
1575
1576 let cfg = GraphConfig::builder()
1578 .with_node_label("Person", "id")
1579 .build()
1580 .unwrap();
1581
1582 let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
1583 .unwrap()
1584 .with_config(cfg);
1585
1586 let result = query.execute_with_context(ctx).await.unwrap();
1588
1589 assert_eq!(result.num_rows(), 3);
1591 assert_eq!(result.num_columns(), 1);
1592
1593 let names = result
1594 .column(0)
1595 .as_any()
1596 .downcast_ref::<StringArray>()
1597 .unwrap();
1598 assert_eq!(names.value(0), "Alice");
1599 assert_eq!(names.value(1), "Bob");
1600 assert_eq!(names.value(2), "Carol");
1601 }
1602
1603 #[tokio::test]
1604 async fn test_execute_with_context_with_filter() {
1605 use arrow_array::{Int64Array, RecordBatch, StringArray};
1606 use arrow_schema::{DataType, Field, Schema};
1607 use datafusion::datasource::MemTable;
1608 use datafusion::execution::context::SessionContext;
1609 use std::sync::Arc;
1610
1611 let schema = Arc::new(Schema::new(vec![
1613 Field::new("id", DataType::Int64, false),
1614 Field::new("name", DataType::Utf8, false),
1615 Field::new("age", DataType::Int64, false),
1616 ]));
1617 let batch = RecordBatch::try_new(
1618 schema.clone(),
1619 vec![
1620 Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
1621 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
1622 Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
1623 ],
1624 )
1625 .unwrap();
1626
1627 let mem_table =
1629 Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
1630 let ctx = SessionContext::new();
1631 ctx.register_table("Person", mem_table).unwrap();
1632
1633 let cfg = GraphConfig::builder()
1635 .with_node_label("Person", "id")
1636 .build()
1637 .unwrap();
1638
1639 let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
1640 .unwrap()
1641 .with_config(cfg);
1642
1643 let result = query.execute_with_context(ctx).await.unwrap();
1645
1646 assert_eq!(result.num_rows(), 2);
1648 assert_eq!(result.num_columns(), 2);
1649
1650 let names = result
1651 .column(0)
1652 .as_any()
1653 .downcast_ref::<StringArray>()
1654 .unwrap();
1655 let ages = result
1656 .column(1)
1657 .as_any()
1658 .downcast_ref::<Int64Array>()
1659 .unwrap();
1660
1661 let results: Vec<(String, i64)> = (0..result.num_rows())
1662 .map(|i| (names.value(i).to_string(), ages.value(i)))
1663 .collect();
1664
1665 assert!(results.contains(&("Bob".to_string(), 34)));
1666 assert!(results.contains(&("David".to_string(), 42)));
1667 }
1668
1669 #[tokio::test]
1670 async fn test_execute_with_context_relationship_traversal() {
1671 use arrow_array::{Int64Array, RecordBatch, StringArray};
1672 use arrow_schema::{DataType, Field, Schema};
1673 use datafusion::datasource::MemTable;
1674 use datafusion::execution::context::SessionContext;
1675 use std::sync::Arc;
1676
1677 let person_schema = Arc::new(Schema::new(vec![
1679 Field::new("id", DataType::Int64, false),
1680 Field::new("name", DataType::Utf8, false),
1681 ]));
1682 let person_batch = RecordBatch::try_new(
1683 person_schema.clone(),
1684 vec![
1685 Arc::new(Int64Array::from(vec![1, 2, 3])),
1686 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
1687 ],
1688 )
1689 .unwrap();
1690
1691 let knows_schema = Arc::new(Schema::new(vec![
1693 Field::new("src_id", DataType::Int64, false),
1694 Field::new("dst_id", DataType::Int64, false),
1695 Field::new("since", DataType::Int64, false),
1696 ]));
1697 let knows_batch = RecordBatch::try_new(
1698 knows_schema.clone(),
1699 vec![
1700 Arc::new(Int64Array::from(vec![1, 2])),
1701 Arc::new(Int64Array::from(vec![2, 3])),
1702 Arc::new(Int64Array::from(vec![2020, 2021])),
1703 ],
1704 )
1705 .unwrap();
1706
1707 let person_table = Arc::new(
1709 MemTable::try_new(person_schema.clone(), vec![vec![person_batch.clone()]]).unwrap(),
1710 );
1711 let knows_table = Arc::new(
1712 MemTable::try_new(knows_schema.clone(), vec![vec![knows_batch.clone()]]).unwrap(),
1713 );
1714
1715 let ctx = SessionContext::new();
1716 ctx.register_table("Person", person_table).unwrap();
1717 ctx.register_table("KNOWS", knows_table).unwrap();
1718
1719 let cfg = GraphConfig::builder()
1721 .with_node_label("Person", "id")
1722 .with_relationship("KNOWS", "src_id", "dst_id")
1723 .build()
1724 .unwrap();
1725
1726 let query = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name")
1727 .unwrap()
1728 .with_config(cfg);
1729
1730 let result = query.execute_with_context(ctx).await.unwrap();
1732
1733 assert_eq!(result.num_rows(), 2);
1735 assert_eq!(result.num_columns(), 2);
1736
1737 let src_names = result
1738 .column(0)
1739 .as_any()
1740 .downcast_ref::<StringArray>()
1741 .unwrap();
1742 let dst_names = result
1743 .column(1)
1744 .as_any()
1745 .downcast_ref::<StringArray>()
1746 .unwrap();
1747
1748 let relationships: Vec<(String, String)> = (0..result.num_rows())
1749 .map(|i| {
1750 (
1751 src_names.value(i).to_string(),
1752 dst_names.value(i).to_string(),
1753 )
1754 })
1755 .collect();
1756
1757 assert!(relationships.contains(&("Alice".to_string(), "Bob".to_string())));
1758 assert!(relationships.contains(&("Bob".to_string(), "Carol".to_string())));
1759 }
1760
1761 #[tokio::test]
1762 async fn test_execute_with_context_order_by_limit() {
1763 use arrow_array::{Int64Array, RecordBatch, StringArray};
1764 use arrow_schema::{DataType, Field, Schema};
1765 use datafusion::datasource::MemTable;
1766 use datafusion::execution::context::SessionContext;
1767 use std::sync::Arc;
1768
1769 let schema = Arc::new(Schema::new(vec![
1771 Field::new("id", DataType::Int64, false),
1772 Field::new("name", DataType::Utf8, false),
1773 Field::new("score", DataType::Int64, false),
1774 ]));
1775 let batch = RecordBatch::try_new(
1776 schema.clone(),
1777 vec![
1778 Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
1779 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
1780 Arc::new(Int64Array::from(vec![85, 92, 78, 95])),
1781 ],
1782 )
1783 .unwrap();
1784
1785 let mem_table =
1787 Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
1788 let ctx = SessionContext::new();
1789 ctx.register_table("Student", mem_table).unwrap();
1790
1791 let cfg = GraphConfig::builder()
1793 .with_node_label("Student", "id")
1794 .build()
1795 .unwrap();
1796
1797 let query = CypherQuery::new(
1798 "MATCH (s:Student) RETURN s.name, s.score ORDER BY s.score DESC LIMIT 2",
1799 )
1800 .unwrap()
1801 .with_config(cfg);
1802
1803 let result = query.execute_with_context(ctx).await.unwrap();
1805
1806 assert_eq!(result.num_rows(), 2);
1808 assert_eq!(result.num_columns(), 2);
1809
1810 let names = result
1811 .column(0)
1812 .as_any()
1813 .downcast_ref::<StringArray>()
1814 .unwrap();
1815 let scores = result
1816 .column(1)
1817 .as_any()
1818 .downcast_ref::<Int64Array>()
1819 .unwrap();
1820
1821 assert_eq!(names.value(0), "David");
1823 assert_eq!(scores.value(0), 95);
1824
1825 assert_eq!(names.value(1), "Bob");
1827 assert_eq!(scores.value(1), 92);
1828 }
1829
1830 #[tokio::test]
1831 async fn test_to_sql() {
1832 use arrow_array::RecordBatch;
1833 use arrow_schema::{DataType, Field, Schema};
1834 use std::collections::HashMap;
1835 use std::sync::Arc;
1836
1837 let schema = Arc::new(Schema::new(vec![
1838 Field::new("id", DataType::Int64, false),
1839 Field::new("name", DataType::Utf8, false),
1840 ]));
1841 let batch = RecordBatch::new_empty(schema.clone());
1842
1843 let mut datasets = HashMap::new();
1844 datasets.insert("Person".to_string(), batch);
1845
1846 let cfg = GraphConfig::builder()
1847 .with_node_label("Person", "id")
1848 .build()
1849 .unwrap();
1850
1851 let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
1852 .unwrap()
1853 .with_config(cfg);
1854
1855 let sql = query.to_sql(datasets).await.unwrap();
1856 println!("Generated SQL: {}", sql);
1857
1858 assert!(sql.contains("SELECT"));
1859 assert!(sql.to_lowercase().contains("from person"));
1860 assert!(sql.contains("p.name"));
1863 }
1864
1865 async fn write_lance_dataset(path: &std::path::Path, batch: arrow_array::RecordBatch) {
1866 use arrow_array::{RecordBatch, RecordBatchIterator};
1867 use lance::dataset::{Dataset, WriteParams};
1868
1869 let schema = batch.schema();
1870 let batches: Vec<std::result::Result<RecordBatch, arrow::error::ArrowError>> =
1871 vec![std::result::Result::Ok(batch)];
1872 let reader = RecordBatchIterator::new(batches.into_iter(), schema);
1873
1874 Dataset::write(reader, path.to_str().unwrap(), None::<WriteParams>)
1875 .await
1876 .expect("write lance dataset");
1877 }
1878
1879 fn build_people_batch() -> arrow_array::RecordBatch {
1880 use arrow_array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray};
1881 use arrow_schema::{DataType, Field, Schema};
1882 use std::sync::Arc;
1883
1884 let schema = Arc::new(Schema::new(vec![
1885 Field::new("person_id", DataType::Int64, false),
1886 Field::new("name", DataType::Utf8, false),
1887 Field::new("age", DataType::Int32, false),
1888 ]));
1889
1890 let columns: Vec<ArrayRef> = vec![
1891 Arc::new(Int64Array::from(vec![1, 2, 3, 4])) as ArrayRef,
1892 Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])) as ArrayRef,
1893 Arc::new(Int32Array::from(vec![28, 34, 29, 42])) as ArrayRef,
1894 ];
1895
1896 RecordBatch::try_new(schema, columns).expect("valid person batch")
1897 }
1898
1899 fn build_friendship_batch() -> arrow_array::RecordBatch {
1900 use arrow_array::{ArrayRef, Int64Array, RecordBatch};
1901 use arrow_schema::{DataType, Field, Schema};
1902 use std::sync::Arc;
1903
1904 let schema = Arc::new(Schema::new(vec![
1905 Field::new("person1_id", DataType::Int64, false),
1906 Field::new("person2_id", DataType::Int64, false),
1907 ]));
1908
1909 let columns: Vec<ArrayRef> = vec![
1910 Arc::new(Int64Array::from(vec![1, 1, 2, 3])) as ArrayRef,
1911 Arc::new(Int64Array::from(vec![2, 3, 4, 4])) as ArrayRef,
1912 ];
1913
1914 RecordBatch::try_new(schema, columns).expect("valid friendship batch")
1915 }
1916
1917 #[tokio::test]
1918 async fn executes_against_directory_namespace() {
1919 use arrow_array::StringArray;
1920 use tempfile::tempdir;
1921
1922 let tmp_dir = tempdir().unwrap();
1923 write_lance_dataset(&tmp_dir.path().join("Person.lance"), build_people_batch()).await;
1924 write_lance_dataset(
1925 &tmp_dir.path().join("FRIEND_OF.lance"),
1926 build_friendship_batch(),
1927 )
1928 .await;
1929
1930 let config = GraphConfig::builder()
1931 .with_node_label("Person", "person_id")
1932 .with_relationship("FRIEND_OF", "person1_id", "person2_id")
1933 .build()
1934 .expect("valid graph config");
1935
1936 let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")
1937 .expect("query parses")
1938 .with_config(config);
1939
1940 let namespace = DirNamespace::new(tmp_dir.path().to_string_lossy().into_owned());
1941
1942 let result = query
1943 .execute_with_namespace(namespace.clone(), None)
1944 .await
1945 .expect("namespace execution succeeds");
1946
1947 use arrow_array::Array;
1948 let names = result
1949 .column(0)
1950 .as_any()
1951 .downcast_ref::<StringArray>()
1952 .expect("string column");
1953
1954 let mut values: Vec<String> = (0..names.len())
1955 .map(|i| names.value(i).to_string())
1956 .collect();
1957 values.sort();
1958 assert_eq!(values, vec!["Bob".to_string(), "David".to_string()]);
1959 }
1960
1961 #[tokio::test]
1962 async fn test_execute_fails_on_semantic_error() {
1963 use arrow_array::RecordBatch;
1964 use arrow_schema::{DataType, Field, Schema};
1965 use std::collections::HashMap;
1966 use std::sync::Arc;
1967
1968 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
1969 let batch = RecordBatch::new_empty(schema);
1970 let mut datasets = HashMap::new();
1971 datasets.insert("Person".to_string(), batch);
1972
1973 let cfg = GraphConfig::builder()
1974 .with_node_label("Person", "id")
1975 .build()
1976 .unwrap();
1977
1978 let query = CypherQuery::new("MATCH (n:Person) RETURN x.name")
1980 .unwrap()
1981 .with_config(cfg);
1982
1983 let result = query.execute(datasets, None).await;
1984
1985 assert!(result.is_err());
1986 match result {
1987 Err(GraphError::PlanError { message, .. }) => {
1988 assert!(message.contains("Semantic analysis failed"));
1989 assert!(message.contains("Undefined variable: 'x'"));
1990 }
1991 _ => panic!(
1992 "Expected PlanError with semantic failure message, got {:?}",
1993 result
1994 ),
1995 }
1996 }
1997}