use crate::ast::CypherQuery as CypherAST;
use crate::ast::ReadingClause;
use crate::config::GraphConfig;
use crate::error::{GraphError, Result};
use crate::logical_plan::LogicalPlanner;
use crate::namespace::DirNamespace;
use crate::parser::parse_cypher_query;
use crate::simple_executor::{
    to_df_boolean_expr_simple, to_df_order_by_expr_simple, to_df_value_expr_simple, PathExecutor,
};
use arrow_array::RecordBatch;
use arrow_schema::{Field, Schema, SchemaRef};
use lance_namespace::models::DescribeTableRequest;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

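/// Lowercases every field name in `schema`, leaving data types and
/// nullability untouched, so that Arrow schemas line up with the lowercased
/// table names registered in DataFusion.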
fn normalize_schema(schema: SchemaRef) -> Result<SchemaRef> {
    let fields: Vec<_> = schema
        .fields()
        .iter()
        .map(|f| {
            Arc::new(Field::new(
                f.name().to_lowercase(),
                f.data_type().clone(),
                f.is_nullable(),
            ))
        })
        .collect();
    Ok(Arc::new(Schema::new(fields)))
}

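/// Rebuilds `batch` against its lowercase-normalized schema; the underlying
/// column arrays are reused, not copied.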
fn normalize_record_batch(batch: &RecordBatch) -> Result<RecordBatch> {
    let normalized_schema = normalize_schema(batch.schema())?;
    RecordBatch::try_new(normalized_schema, batch.columns().to_vec()).map_err(|e| {
        GraphError::PlanError {
            message: format!("Failed to normalize record batch schema: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        }
    })
}

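/// How a [`CypherQuery`] is executed:
///
/// * `DataFusion` (the default) plans through the graph planner and runs on
///   a DataFusion `SessionContext`.
/// * `Simple` uses the lightweight single-table / [`PathExecutor`] path.
/// * `LanceNative` is reserved and currently returns an `UnsupportedFeature`
///   error.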
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionStrategy {
    #[default]
    DataFusion,
    Simple,
    LanceNative,
}

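/// A parsed Cypher query plus the [`GraphConfig`] and parameters needed to
/// execute it.
///
/// A minimal usage sketch (illustrative; assumes a `Person` dataset and a
/// matching config, as in the tests below):
///
/// ```ignore
/// let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")?
///     .with_config(config);
/// let batch = query.execute(datasets, None).await?;
/// ```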
#[derive(Debug, Clone)]
pub struct CypherQuery {
    query_text: String,
    ast: CypherAST,
    config: Option<GraphConfig>,
    parameters: HashMap<String, serde_json::Value>,
}

impl CypherQuery {
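    /// Parses `query` and stores the resulting AST; returns a parse error
    /// for invalid Cypher.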
    pub fn new(query: &str) -> Result<Self> {
        let ast = parse_cypher_query(query)?;

        Ok(Self {
            query_text: query.to_string(),
            ast,
            config: None,
            parameters: HashMap::new(),
        })
    }

    pub fn with_config(mut self, config: GraphConfig) -> Self {
        self.config = Some(config);
        self
    }

    pub fn with_parameter<K, V>(mut self, key: K, value: V) -> Self
    where
        K: Into<String>,
        V: Into<serde_json::Value>,
    {
        self.parameters.insert(key.into(), value.into());
        self
    }

    pub fn with_parameters(mut self, params: HashMap<String, serde_json::Value>) -> Self {
        self.parameters.extend(params);
        self
    }

    pub fn query_text(&self) -> &str {
        &self.query_text
    }

    pub fn ast(&self) -> &CypherAST {
        &self.ast
    }

    pub fn config(&self) -> Option<&GraphConfig> {
        self.config.as_ref()
    }

    pub fn parameters(&self) -> &HashMap<String, serde_json::Value> {
        &self.parameters
    }

    fn require_config(&self) -> Result<&GraphConfig> {
        self.config.as_ref().ok_or_else(|| GraphError::ConfigError {
            message: "Graph configuration is required for query execution".to_string(),
            location: snafu::Location::new(file!(), line!(), column!()),
        })
    }

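    /// Executes the query against in-memory `RecordBatch` datasets keyed by
    /// table name, using `strategy` (DataFusion when `None`).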
    pub async fn execute(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
        strategy: Option<ExecutionStrategy>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        let strategy = strategy.unwrap_or_default();
        match strategy {
            ExecutionStrategy::DataFusion => self.execute_datafusion(datasets).await,
            ExecutionStrategy::Simple => self.execute_simple(datasets).await,
            ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
                feature: "Lance native execution strategy is not yet implemented".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            }),
        }
    }

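    /// Executes against tables resolved through a [`DirNamespace`]. Only the
    /// DataFusion strategy supports namespace-backed execution.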
    pub async fn execute_with_namespace(
        &self,
        namespace: DirNamespace,
        strategy: Option<ExecutionStrategy>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        self.execute_with_namespace_arc(std::sync::Arc::new(namespace), strategy)
            .await
    }

    pub async fn execute_with_namespace_arc(
        &self,
        namespace: std::sync::Arc<DirNamespace>,
        strategy: Option<ExecutionStrategy>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        let namespace_trait: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync> =
            namespace;
        self.execute_with_namespace_internal(namespace_trait, strategy)
            .await
    }

    async fn execute_with_namespace_internal(
        &self,
        namespace: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync>,
        strategy: Option<ExecutionStrategy>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        let strategy = strategy.unwrap_or_default();
        match strategy {
            ExecutionStrategy::DataFusion => {
                let (catalog, ctx) = self
                    .build_catalog_and_context_from_namespace(namespace)
                    .await?;
                self.execute_with_catalog_and_context(std::sync::Arc::new(catalog), ctx)
                    .await
            }
            ExecutionStrategy::Simple => Err(GraphError::UnsupportedFeature {
                feature:
                    "Simple execution strategy is not supported for namespace-backed execution"
                        .to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            }),
            ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
                feature: "Lance native execution strategy is not yet implemented".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            }),
        }
    }

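    /// Renders an EXPLAIN-style table showing the graph logical plan, the
    /// DataFusion logical plan, and the physical plan.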
    pub async fn explain(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<String> {
        use std::sync::Arc;

        let (catalog, ctx) = self
            .build_catalog_and_context_from_datasets(datasets)
            .await?;

        self.explain_internal(Arc::new(catalog), ctx).await
    }

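    /// Converts the Cypher query to SQL by planning it, optimizing the
    /// DataFusion logical plan, and unparsing the optimized plan.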
    pub async fn to_sql(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<String> {
        use datafusion_sql::unparser::plan_to_sql;
        use std::sync::Arc;

        let _config = self.require_config()?;

        let (catalog, ctx) = self
            .build_catalog_and_context_from_datasets(datasets)
            .await?;

        let (_, df_plan) = self.create_logical_plans(Arc::new(catalog))?;

        let optimized_plan = ctx
            .state()
            .optimize(&df_plan)
            .map_err(|e| GraphError::PlanError {
                message: format!("Failed to optimize plan: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
            message: format!("Failed to unparse plan to SQL: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

        Ok(sql_ast.to_string())
    }

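    /// Executes against tables already registered in an existing
    /// `SessionContext`; every node label and relationship type named in the
    /// [`GraphConfig`] must resolve to a registered table.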
    pub async fn execute_with_context(
        &self,
        ctx: datafusion::execution::context::SessionContext,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use crate::source_catalog::InMemoryCatalog;
        use datafusion::datasource::DefaultTableSource;
        use std::sync::Arc;

        let config = self.require_config()?;

        let mut catalog = InMemoryCatalog::new();

        for label in config.node_mappings.keys() {
            let table_provider =
                ctx.table_provider(label)
                    .await
                    .map_err(|e| GraphError::ConfigError {
                        message: format!(
                            "Node label '{}' not found in SessionContext: {}",
                            label, e
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    })?;

            let table_source = Arc::new(DefaultTableSource::new(table_provider));
            catalog = catalog.with_node_source(label, table_source);
        }

        for rel_type in config.relationship_mappings.keys() {
            let table_provider =
                ctx.table_provider(rel_type)
                    .await
                    .map_err(|e| GraphError::ConfigError {
                        message: format!(
                            "Relationship type '{}' not found in SessionContext: {}",
                            rel_type, e
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    })?;

            let table_source = Arc::new(DefaultTableSource::new(table_provider));
            catalog = catalog.with_relationship_source(rel_type, table_source);
        }

        self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
            .await
    }

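    /// Core DataFusion execution: plans against `catalog`, runs on `ctx`,
    /// and concatenates the collected batches (returning an empty batch with
    /// the plan's schema when there are no results).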
    pub async fn execute_with_catalog_and_context(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
        ctx: datafusion::execution::context::SessionContext,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use arrow::compute::concat_batches;

        let (_logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;

        let df = ctx
            .execute_logical_plan(df_logical_plan)
            .await
            .map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to execute DataFusion plan: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        let result_schema = df.schema().inner().clone();

        let batches = df.collect().await.map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to collect query results: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

        if batches.is_empty() {
            return Ok(arrow::record_batch::RecordBatch::new_empty(result_schema));
        }

        let schema = batches[0].schema();
        concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to concatenate result batches: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })
    }

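    /// Convenience wrapper: builds a catalog and context from in-memory
    /// datasets, then runs the DataFusion pipeline.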
    async fn execute_datafusion(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use std::sync::Arc;

        let (catalog, ctx) = self
            .build_catalog_and_context_from_datasets(datasets)
            .await?;

        self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
            .await
    }

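    /// Registers every dataset as a lowercased `MemTable` and exposes each
    /// one through the catalog as both a node source and a relationship
    /// source, since a bare dataset name may be matched either way.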
    async fn build_catalog_and_context_from_datasets(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<(
        crate::source_catalog::InMemoryCatalog,
        datafusion::execution::context::SessionContext,
    )> {
        use crate::source_catalog::InMemoryCatalog;
        use datafusion::datasource::{DefaultTableSource, MemTable};
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        if datasets.is_empty() {
            return Err(GraphError::ConfigError {
                message: "No input datasets provided".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ctx = SessionContext::new();
        let mut catalog = InMemoryCatalog::new();

        for (name, batch) in &datasets {
            let normalized_batch = normalize_record_batch(batch)?;

            let mem_table = Arc::new(
                MemTable::try_new(
                    normalized_batch.schema(),
                    vec![vec![normalized_batch.clone()]],
                )
                .map_err(|e| GraphError::PlanError {
                    message: format!("Failed to create MemTable for {}: {}", name, e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?,
            );

            let normalized_name = name.to_lowercase();

            ctx.register_table(&normalized_name, mem_table.clone())
                .map_err(|e| GraphError::PlanError {
                    message: format!("Failed to register table {}: {}", name, e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

            let table_source = Arc::new(DefaultTableSource::new(mem_table));

            catalog = catalog
                .with_node_source(name, table_source.clone())
                .with_relationship_source(name, table_source);
        }

        Ok((catalog, ctx))
    }

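    /// Resolves each node label and relationship type in the config through
    /// the namespace, opens the backing Lance datasets, and registers them
    /// under lowercased names in a fresh `SessionContext` and catalog.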
    async fn build_catalog_and_context_from_namespace(
        &self,
        namespace: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync>,
    ) -> Result<(
        crate::source_catalog::InMemoryCatalog,
        datafusion::execution::context::SessionContext,
    )> {
        use crate::source_catalog::InMemoryCatalog;
        use datafusion::datasource::{DefaultTableSource, TableProvider};
        use datafusion::execution::context::SessionContext;
        use lance::datafusion::LanceTableProvider;
        use std::sync::Arc;

        let config = self.require_config()?;

        let mut required_tables: HashSet<String> = HashSet::new();
        required_tables.extend(config.node_mappings.values().map(|m| m.label.clone()));
        required_tables.extend(
            config
                .relationship_mappings
                .values()
                .map(|m| m.relationship_type.clone()),
        );

        if required_tables.is_empty() {
            return Err(GraphError::ConfigError {
                message:
                    "Graph configuration does not reference any node labels or relationship types"
                        .to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ctx = SessionContext::new();
        let mut catalog = InMemoryCatalog::new();
        let mut providers: HashMap<String, Arc<dyn TableProvider>> = HashMap::new();

        for table_name in required_tables {
            let mut request = DescribeTableRequest::new();
            request.id = Some(vec![table_name.clone()]);

            let response =
                namespace
                    .describe_table(request)
                    .await
                    .map_err(|e| GraphError::ConfigError {
                        message: format!(
                            "Namespace failed to resolve table '{}': {}",
                            table_name, e
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    })?;

            let location = response.location.ok_or_else(|| GraphError::ConfigError {
                message: format!(
                    "Namespace did not provide a location for table '{}'",
                    table_name
                ),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

            let dataset = lance::dataset::Dataset::open(&location)
                .await
                .map_err(|e| GraphError::ConfigError {
                    message: format!("Failed to open dataset for table '{}': {}", table_name, e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

            let dataset = Arc::new(dataset);
            let provider: Arc<dyn TableProvider> =
                Arc::new(LanceTableProvider::new(dataset.clone(), true, true));

            let normalized_table_name = table_name.to_lowercase();
            ctx.register_table(&normalized_table_name, provider.clone())
                .map_err(|e| GraphError::PlanError {
                    message: format!(
                        "Failed to register table '{}' in SessionContext: {}",
                        table_name, e
                    ),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

            providers.insert(normalized_table_name.clone(), provider);
        }

        for label in config.node_mappings.keys() {
            let provider = providers
                .get(label)
                .ok_or_else(|| GraphError::ConfigError {
                    message: format!(
                        "Namespace did not resolve dataset for node label '{}'",
                        label
                    ),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

            let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
            catalog = catalog.with_node_source(label, table_source);
        }

        for rel_type in config.relationship_mappings.keys() {
            let provider = providers
                .get(rel_type)
                .ok_or_else(|| GraphError::ConfigError {
                    message: format!(
                        "Namespace did not resolve dataset for relationship type '{}'",
                        rel_type
                    ),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

            let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
            catalog = catalog.with_relationship_source(rel_type, table_source);
        }

        Ok((catalog, ctx))
    }

    async fn explain_internal(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
        ctx: datafusion::execution::context::SessionContext,
    ) -> Result<String> {
        let (logical_plan, df_logical_plan, physical_plan) =
            self.create_plans(catalog, &ctx).await?;

        self.format_explain_output(&logical_plan, &df_logical_plan, physical_plan.as_ref())
    }

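    /// Runs semantic analysis, lowers the AST to the graph logical plan, and
    /// translates that into a DataFusion `LogicalPlan` via the catalog's
    /// table sources.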
    fn create_logical_plans(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
    ) -> Result<(
        crate::logical_plan::LogicalOperator,
        datafusion::logical_expr::LogicalPlan,
    )> {
        use crate::datafusion_planner::{DataFusionPlanner, GraphPhysicalPlanner};
        use crate::semantic::SemanticAnalyzer;

        let config = self.require_config()?;

        let mut analyzer = SemanticAnalyzer::new(config.clone());
        let semantic = analyzer.analyze(&self.ast)?;
        if !semantic.errors.is_empty() {
            return Err(GraphError::PlanError {
                message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let mut logical_planner = LogicalPlanner::new();
        let logical_plan = logical_planner.plan(&self.ast)?;

        let df_planner = DataFusionPlanner::with_catalog(config.clone(), catalog);
        let df_logical_plan = df_planner.plan(&logical_plan)?;

        Ok((logical_plan, df_logical_plan))
    }

    async fn create_plans(
        &self,
        catalog: std::sync::Arc<dyn crate::source_catalog::GraphSourceCatalog>,
        ctx: &datafusion::execution::context::SessionContext,
    ) -> Result<(
        crate::logical_plan::LogicalOperator,
        datafusion::logical_expr::LogicalPlan,
        std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
    )> {
        let (logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;

        let df = ctx
            .execute_logical_plan(df_logical_plan.clone())
            .await
            .map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to execute DataFusion plan: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        let physical_plan =
            df.create_physical_plan()
                .await
                .map_err(|e| GraphError::ExecutionError {
                    message: format!("Failed to create physical plan: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;

        Ok((logical_plan, df_logical_plan, physical_plan))
    }

    fn format_explain_output(
        &self,
        logical_plan: &crate::logical_plan::LogicalOperator,
        df_logical_plan: &datafusion::logical_expr::LogicalPlan,
        physical_plan: &dyn datafusion::physical_plan::ExecutionPlan,
    ) -> Result<String> {
        let mut output = String::new();

        output.push_str("Cypher Query:\n");
        output.push_str(&format!(" {}\n\n", self.query_text));

        let mut rows = vec![];

        let graph_plan_str = format!("{:#?}", logical_plan);
        rows.push(("graph_logical_plan", graph_plan_str));

        let df_logical_str = format!("{}", df_logical_plan.display_indent());
        rows.push(("logical_plan", df_logical_str));

        let df_physical_str = format!(
            "{}",
            datafusion::physical_plan::displayable(physical_plan).indent(true)
        );
        rows.push(("physical_plan", df_physical_str));

        let plan_type_width = rows.iter().map(|(t, _)| t.len()).max().unwrap_or(10);
        let plan_width = rows
            .iter()
            .map(|(_, p)| p.lines().map(|l| l.len()).max().unwrap_or(0))
            .max()
            .unwrap_or(50);

        let separator = format!(
            "+{}+{}+",
            "-".repeat(plan_type_width + 2),
            "-".repeat(plan_width + 2)
        );

        output.push_str(&separator);
        output.push('\n');

        output.push_str(&format!(
            "| {:<width$} | {:<plan_width$} |\n",
            "plan_type",
            "plan",
            width = plan_type_width,
            plan_width = plan_width
        ));
        output.push_str(&separator);
        output.push('\n');

        for (plan_type, plan_content) in rows {
            let lines: Vec<&str> = plan_content.lines().collect();
            if lines.is_empty() {
                output.push_str(&format!(
                    "| {:<width$} | {:<plan_width$} |\n",
                    plan_type,
                    "",
                    width = plan_type_width,
                    plan_width = plan_width
                ));
            } else {
                output.push_str(&format!(
                    "| {:<width$} | {:<plan_width$} |\n",
                    plan_type,
                    lines[0],
                    width = plan_type_width,
                    plan_width = plan_width
                ));

                for line in &lines[1..] {
                    output.push_str(&format!(
                        "| {:<width$} | {:<plan_width$} |\n",
                        "",
                        line,
                        width = plan_type_width,
                        plan_width = plan_width
                    ));
                }
            }
        }

        output.push_str(&separator);
        output.push('\n');

        Ok(output)
    }

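    /// Simple execution path: registers each dataset as a lowercased
    /// `MemTable`, tries the generic path executor first, and otherwise
    /// applies WHERE / RETURN / DISTINCT / ORDER BY / SKIP / LIMIT directly
    /// to the first dataset as a DataFusion `DataFrame`.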
    pub async fn execute_simple(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
    ) -> Result<arrow::record_batch::RecordBatch> {
        use crate::semantic::SemanticAnalyzer;
        use arrow::compute::concat_batches;
        use datafusion::datasource::MemTable;
        use datafusion::prelude::*;
        use std::sync::Arc;

        let config = self.require_config()?.clone();

        let mut analyzer = SemanticAnalyzer::new(config);
        let semantic = analyzer.analyze(&self.ast)?;
        if !semantic.errors.is_empty() {
            return Err(GraphError::PlanError {
                message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        if datasets.is_empty() {
            return Err(GraphError::PlanError {
                message: "No input datasets provided".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ctx = SessionContext::new();
        for (name, batch) in &datasets {
            let normalized_batch = normalize_record_batch(batch)?;
            let table = MemTable::try_new(
                normalized_batch.schema(),
                vec![vec![normalized_batch.clone()]],
            )
            .map_err(|e| GraphError::PlanError {
                message: format!("Failed to create DataFusion table: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

            let normalized_name = name.to_lowercase();
            ctx.register_table(&normalized_name, Arc::new(table))
                .map_err(|e| GraphError::PlanError {
                    message: format!("Failed to register table '{}': {}", name, e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;
        }

        if let Some(df) = self.try_execute_path_generic(&ctx).await? {
            let batches = df.collect().await.map_err(|e| GraphError::PlanError {
                message: format!("Failed to collect results: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
            if batches.is_empty() {
                let schema = datasets.values().next().unwrap().schema();
                return Ok(arrow_array::RecordBatch::new_empty(schema));
            }
            let merged = concat_batches(&batches[0].schema(), &batches).map_err(|e| {
                GraphError::PlanError {
                    message: format!("Failed to concatenate result batches: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                }
            })?;
            return Ok(merged);
        }

        let (table_name, batch) = datasets.iter().next().unwrap();
        let schema = batch.schema();

        let mut df = ctx
            .table(table_name)
            .await
            .map_err(|e| GraphError::PlanError {
                message: format!("Failed to create DataFrame for '{}': {}", table_name, e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;

        if let Some(where_clause) = &self.ast.where_clause {
            if let Some(filter_expr) = to_df_boolean_expr_simple(&where_clause.expression) {
                df = df.filter(filter_expr).map_err(|e| GraphError::PlanError {
                    message: format!("Failed to apply filter: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;
            }
        }

        let proj_exprs: Vec<Expr> = self
            .ast
            .return_clause
            .items
            .iter()
            .map(|item| {
                let expr = to_df_value_expr_simple(&item.expression);
                if let Some(alias) = &item.alias {
                    expr.alias(alias)
                } else {
                    expr
                }
            })
            .collect();
        if !proj_exprs.is_empty() {
            df = df.select(proj_exprs).map_err(|e| GraphError::PlanError {
                message: format!("Failed to project: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        if self.ast.return_clause.distinct {
            df = df.distinct().map_err(|e| GraphError::PlanError {
                message: format!("Failed to apply DISTINCT: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        if let Some(order_by) = &self.ast.order_by {
            let sort_expr = to_df_order_by_expr_simple(&order_by.items);
            df = df.sort(sort_expr).map_err(|e| GraphError::PlanError {
                message: format!("Failed to apply ORDER BY: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        if self.ast.skip.is_some() || self.ast.limit.is_some() {
            let offset = self.ast.skip.unwrap_or(0) as usize;
            let fetch = self.ast.limit.map(|l| l as usize);
            df = df.limit(offset, fetch).map_err(|e| GraphError::PlanError {
                message: format!("Failed to apply SKIP/LIMIT: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        }

        let batches = df.collect().await.map_err(|e| GraphError::PlanError {
            message: format!("Failed to collect results: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

        if batches.is_empty() {
            return Ok(arrow_array::RecordBatch::new_empty(schema));
        }

        let merged =
            concat_batches(&batches[0].schema(), &batches).map_err(|e| GraphError::PlanError {
                message: format!("Failed to concatenate result batches: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        Ok(merged)
    }

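    /// Sorted, deduplicated node labels referenced in MATCH patterns.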
    pub fn referenced_node_labels(&self) -> Vec<String> {
        let mut labels = Vec::new();

        for clause in &self.ast.reading_clauses {
            if let ReadingClause::Match(match_clause) = clause {
                for pattern in &match_clause.patterns {
                    self.collect_node_labels_from_pattern(pattern, &mut labels);
                }
            }
        }

        labels.sort();
        labels.dedup();
        labels
    }

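    /// Sorted, deduplicated relationship types referenced in MATCH patterns.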
    pub fn referenced_relationship_types(&self) -> Vec<String> {
        let mut types = Vec::new();

        for clause in &self.ast.reading_clauses {
            if let ReadingClause::Match(match_clause) = clause {
                for pattern in &match_clause.patterns {
                    self.collect_relationship_types_from_pattern(pattern, &mut types);
                }
            }
        }

        types.sort();
        types.dedup();
        types
    }

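    /// Sorted, deduplicated variables bound by MATCH patterns and UNWIND
    /// clauses.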
    pub fn variables(&self) -> Vec<String> {
        let mut variables = Vec::new();

        for clause in &self.ast.reading_clauses {
            match clause {
                ReadingClause::Match(match_clause) => {
                    for pattern in &match_clause.patterns {
                        self.collect_variables_from_pattern(pattern, &mut variables);
                    }
                }
                ReadingClause::Unwind(unwind_clause) => {
                    variables.push(unwind_clause.alias.clone());
                }
            }
        }

        variables.sort();
        variables.dedup();
        variables
    }

    fn collect_node_labels_from_pattern(
        &self,
        pattern: &crate::ast::GraphPattern,
        labels: &mut Vec<String>,
    ) {
        match pattern {
            crate::ast::GraphPattern::Node(node) => {
                labels.extend(node.labels.clone());
            }
            crate::ast::GraphPattern::Path(path) => {
                labels.extend(path.start_node.labels.clone());
                for segment in &path.segments {
                    labels.extend(segment.end_node.labels.clone());
                }
            }
        }
    }

    fn collect_relationship_types_from_pattern(
        &self,
        pattern: &crate::ast::GraphPattern,
        types: &mut Vec<String>,
    ) {
        if let crate::ast::GraphPattern::Path(path) = pattern {
            for segment in &path.segments {
                types.extend(segment.relationship.types.clone());
            }
        }
    }

    fn collect_variables_from_pattern(
        &self,
        pattern: &crate::ast::GraphPattern,
        variables: &mut Vec<String>,
    ) {
        match pattern {
            crate::ast::GraphPattern::Node(node) => {
                if let Some(var) = &node.variable {
                    variables.push(var.clone());
                }
            }
            crate::ast::GraphPattern::Path(path) => {
                if let Some(var) = &path.start_node.variable {
                    variables.push(var.clone());
                }
                for segment in &path.segments {
                    if let Some(var) = &segment.relationship.variable {
                        variables.push(var.clone());
                    }
                    if let Some(var) = &segment.end_node.variable {
                        variables.push(var.clone());
                    }
                }
            }
        }
    }
}

impl CypherQuery {
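    /// Runs the query to produce candidate rows, then hands the candidate
    /// batch to `vector_search` for vector-based reranking.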
    pub async fn execute_with_vector_rerank(
        &self,
        datasets: HashMap<String, arrow::record_batch::RecordBatch>,
        vector_search: crate::lance_vector_search::VectorSearch,
    ) -> Result<arrow::record_batch::RecordBatch> {
        let candidates = self.execute(datasets, None).await?;

        vector_search.search(&candidates).await
    }
}

impl CypherQuery {
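    /// Attempts to run the query as a single-MATCH path pattern via
    /// [`PathExecutor`]; returns `Ok(None)` when the query does not have
    /// that shape. A variable-length relationship (e.g. `-[:KNOWS*1..3]->`)
    /// is expanded into one fixed-length synthetic path per hop count,
    /// capped at `MAX_VARIABLE_LENGTH_HOPS`, and the per-hop results are
    /// UNIONed together.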
    async fn try_execute_path_generic(
        &self,
        ctx: &datafusion::prelude::SessionContext,
    ) -> Result<Option<datafusion::dataframe::DataFrame>> {
        use crate::ast::GraphPattern;

        let match_clause = match self.ast.reading_clauses.as_slice() {
            [ReadingClause::Match(mc)] => mc,
            _ => return Ok(None),
        };
        let path = match match_clause.patterns.as_slice() {
            [GraphPattern::Path(p)] if !p.segments.is_empty() => p,
            _ => return Ok(None),
        };
        let cfg = self.require_config()?;

        if path.segments.len() == 1 {
            if let Some(length_range) = &path.segments[0].relationship.length {
                let cap: u32 = crate::MAX_VARIABLE_LENGTH_HOPS;
                let min_len = length_range.min.unwrap_or(1).max(1);
                let max_len = length_range.max.unwrap_or(cap);

                if min_len > max_len {
                    return Err(GraphError::InvalidPattern {
                        message: format!(
                            "Invalid variable-length range: min {:?} greater than max {:?}",
                            length_range.min, length_range.max
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    });
                }

                if max_len > cap {
                    return Err(GraphError::UnsupportedFeature {
                        feature: format!(
                            "Variable-length paths with length > {} are not supported (got {:?}..{:?})",
                            cap, length_range.min, length_range.max
                        ),
                        location: snafu::Location::new(file!(), line!(), column!()),
                    });
                }

                use datafusion::dataframe::DataFrame;
                let mut union_df: Option<DataFrame> = None;

                for hops in min_len..=max_len {
                    let mut synthetic = crate::ast::PathPattern {
                        start_node: path.start_node.clone(),
                        segments: Vec::with_capacity(hops as usize),
                    };

                    for i in 0..hops {
                        let mut seg = path.segments[0].clone();
                        seg.relationship.variable = None;
                        if (i + 1) < hops {
                            seg.end_node.variable = None;
                        }
                        seg.relationship.length = None;
                        synthetic.segments.push(seg);
                    }

                    let exec = PathExecutor::new(ctx, cfg, &synthetic)?;
                    let mut df = exec.build_chain().await?;
                    df = exec.apply_where(df, &self.ast)?;
                    df = exec.apply_return(df, &self.ast)?;

                    union_df = Some(match union_df {
                        Some(acc) => acc.union(df).map_err(|e| GraphError::PlanError {
                            message: format!("Failed to UNION variable-length paths: {}", e),
                            location: snafu::Location::new(file!(), line!(), column!()),
                        })?,
                        None => df,
                    });
                }

                return Ok(union_df);
            }
        }

        let exec = PathExecutor::new(ctx, cfg, path)?;
        let df = exec.build_chain().await?;
        let df = exec.apply_where(df, &self.ast)?;
        let df = exec.apply_return(df, &self.ast)?;
        Ok(Some(df))
    }
}

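/// Programmatic builder for [`CypherQuery`], for constructing queries
/// without writing Cypher text.
///
/// A minimal sketch mirroring `test_query_builder` below:
///
/// ```ignore
/// let query = CypherQueryBuilder::new()
///     .with_config(config)
///     .match_node("n", "Person")
///     .return_property("n", "name")
///     .limit(10)
///     .build()?;
/// ```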
#[derive(Debug, Default)]
pub struct CypherQueryBuilder {
    match_clauses: Vec<crate::ast::MatchClause>,
    where_expression: Option<crate::ast::BooleanExpression>,
    return_items: Vec<crate::ast::ReturnItem>,
    order_by_items: Vec<crate::ast::OrderByItem>,
    limit: Option<u64>,
    distinct: bool,
    skip: Option<u64>,
    config: Option<GraphConfig>,
    parameters: HashMap<String, serde_json::Value>,
}

impl CypherQueryBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn match_node(mut self, variable: &str, label: &str) -> Self {
        let node = crate::ast::NodePattern {
            variable: Some(variable.to_string()),
            labels: vec![label.to_string()],
            properties: HashMap::new(),
        };

        let match_clause = crate::ast::MatchClause {
            patterns: vec![crate::ast::GraphPattern::Node(node)],
        };

        self.match_clauses.push(match_clause);
        self
    }

    pub fn with_config(mut self, config: GraphConfig) -> Self {
        self.config = Some(config);
        self
    }

    pub fn return_property(mut self, variable: &str, property: &str) -> Self {
        let prop_ref = crate::ast::PropertyRef::new(variable, property);
        let return_item = crate::ast::ReturnItem {
            expression: crate::ast::ValueExpression::Property(prop_ref),
            alias: None,
        };

        self.return_items.push(return_item);
        self
    }

    pub fn distinct(mut self, distinct: bool) -> Self {
        self.distinct = distinct;
        self
    }

    pub fn limit(mut self, limit: u64) -> Self {
        self.limit = Some(limit);
        self
    }

    pub fn skip(mut self, skip: u64) -> Self {
        self.skip = Some(skip);
        self
    }

    pub fn build(self) -> Result<CypherQuery> {
        if self.match_clauses.is_empty() {
            return Err(GraphError::PlanError {
                message: "Query must have at least one MATCH clause".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        if self.return_items.is_empty() {
            return Err(GraphError::PlanError {
                message: "Query must have at least one RETURN item".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        let ast = crate::ast::CypherQuery {
            reading_clauses: self
                .match_clauses
                .into_iter()
                .map(crate::ast::ReadingClause::Match)
                .collect(),
            where_clause: self
                .where_expression
                .map(|expr| crate::ast::WhereClause { expression: expr }),
            with_clause: None,
            post_with_reading_clauses: vec![],
            post_with_where_clause: None,
            return_clause: crate::ast::ReturnClause {
                distinct: self.distinct,
                items: self.return_items,
            },
            order_by: if self.order_by_items.is_empty() {
                None
            } else {
                Some(crate::ast::OrderByClause {
                    items: self.order_by_items,
                })
            },
            limit: self.limit,
            skip: self.skip,
        };

        let query_text = "MATCH ... RETURN ...".to_string();
        let query = CypherQuery {
            query_text,
            ast,
            config: self.config,
            parameters: self.parameters,
        };

        Ok(query)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GraphConfig;

    #[test]
    fn test_parse_simple_cypher_query() {
        let query = CypherQuery::new("MATCH (n:Person) RETURN n.name").unwrap();
        assert_eq!(query.query_text(), "MATCH (n:Person) RETURN n.name");
        assert_eq!(query.referenced_node_labels(), vec!["Person"]);
        assert_eq!(query.variables(), vec!["n"]);
    }

    #[test]
    fn test_query_with_parameters() {
        let mut params = HashMap::new();
        params.insert("minAge".to_string(), serde_json::Value::Number(30.into()));

        let query = CypherQuery::new("MATCH (n:Person) WHERE n.age > $minAge RETURN n.name")
            .unwrap()
            .with_parameters(params);

        assert!(query.parameters().contains_key("minAge"));
    }

    #[test]
    fn test_query_builder() {
        let config = GraphConfig::builder()
            .with_node_label("Person", "person_id")
            .build()
            .unwrap();

        let query = CypherQueryBuilder::new()
            .with_config(config)
            .match_node("n", "Person")
            .return_property("n", "name")
            .limit(10)
            .build()
            .unwrap();

        assert_eq!(query.referenced_node_labels(), vec!["Person"]);
        assert_eq!(query.variables(), vec!["n"]);
    }

    #[test]
    fn test_relationship_query_parsing() {
        let query =
            CypherQuery::new("MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a.name, b.name")
                .unwrap();
        assert_eq!(query.referenced_node_labels(), vec!["Person"]);
        assert_eq!(query.referenced_relationship_types(), vec!["KNOWS"]);
        assert_eq!(query.variables(), vec!["a", "b", "r"]);
    }

    #[tokio::test]
    async fn test_execute_basic_projection_and_filter() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        assert_eq!(out.num_rows(), 2);
        let names = out
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
        let result: Vec<(String, i64)> = (0..out.num_rows())
            .map(|i| (names.value(i).to_string(), ages.value(i)))
            .collect();
        assert!(result.contains(&("Bob".to_string(), 34)));
        assert!(result.contains(&("David".to_string(), 42)));
    }

    #[tokio::test]
    async fn test_execute_single_hop_path_join_projection() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let person_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let people = RecordBatch::try_new(
            person_schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
                Arc::new(Int64Array::from(vec![28, 34, 29])),
            ],
        )
        .unwrap();

        let rel_schema = Arc::new(Schema::new(vec![
            Field::new("src_person_id", DataType::Int64, false),
            Field::new("dst_person_id", DataType::Int64, false),
        ]));
        let knows = RecordBatch::try_new(
            rel_schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(Int64Array::from(vec![2, 3])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .with_relationship("KNOWS", "src_person_id", "dst_person_id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN b.name")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("Person".to_string(), people);
        data.insert("KNOWS".to_string(), knows);

        let out = q.execute_simple(data).await.unwrap();
        let names = out
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let got: Vec<String> = (0..out.num_rows())
            .map(|i| names.value(i).to_string())
            .collect();
        assert_eq!(got.len(), 2);
        assert!(got.contains(&"Bob".to_string()));
        assert!(got.contains(&"Carol".to_string()));
    }

    #[tokio::test]
    async fn test_execute_order_by_asc() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
                Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age ASC")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
        let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
        assert_eq!(collected, vec![28, 29, 34, 42]);
    }

    #[tokio::test]
    async fn test_execute_order_by_desc_with_skip_limit() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
                Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q =
            CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age DESC SKIP 1 LIMIT 2")
                .unwrap()
                .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        assert_eq!(out.num_rows(), 2);
        let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
        assert_eq!(collected, vec![34, 29]);
    }

    #[tokio::test]
    async fn test_execute_skip_without_limit() {
        use arrow_array::{Int64Array, RecordBatch};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![10, 20, 30, 40]))],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let q = CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age ASC SKIP 2")
            .unwrap()
            .with_config(cfg);

        let mut data = HashMap::new();
        data.insert("people".to_string(), batch);

        let out = q.execute_simple(data).await.unwrap();
        assert_eq!(out.num_rows(), 2);
        let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
        assert_eq!(collected, vec![30, 40]);
    }

    #[tokio::test]
    async fn test_execute_datafusion_pipeline() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
                Arc::new(Int64Array::from(vec![25, 35, 30])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let result = query.execute_datafusion(datasets.clone()).await;

        match &result {
            Ok(batch) => {
                println!(
                    "DataFusion result: {} rows, {} columns",
                    batch.num_rows(),
                    batch.num_columns()
                );
                if batch.num_rows() > 0 {
                    println!("First row data: {:?}", batch.slice(0, 1));
                }
            }
            Err(e) => {
                println!("DataFusion execution failed: {:?}", e);
            }
        }

        let legacy_result = query.execute_simple(datasets).await.unwrap();
        println!(
            "Legacy result: {} rows, {} columns",
            legacy_result.num_rows(),
            legacy_result.num_columns()
        );

        let result = result.unwrap();

        assert_eq!(
            result.num_rows(),
            1,
            "Expected 1 row after filtering WHERE p.age > 30"
        );

        assert_eq!(
            result.num_columns(),
            1,
            "Expected 1 column after projection RETURN p.name"
        );

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(
            names.value(0),
            "Bob",
            "Expected filtered result to contain Bob"
        );
    }

    #[tokio::test]
    async fn test_execute_datafusion_simple_scan() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            ],
        )
        .unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let result = query.execute_datafusion(datasets).await.unwrap();

        assert_eq!(
            result.num_rows(),
            2,
            "Should return all 2 rows without filtering"
        );
        assert_eq!(result.num_columns(), 1, "Should return 1 column (name)");

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let name_set: std::collections::HashSet<String> = (0..result.num_rows())
            .map(|i| names.value(i).to_string())
            .collect();
        let expected: std::collections::HashSet<String> =
            ["Alice", "Bob"].iter().map(|s| s.to_string()).collect();
        assert_eq!(name_set, expected, "Should return Alice and Bob");
    }

    #[tokio::test]
    async fn test_execute_with_context_simple_scan() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
                Arc::new(Int64Array::from(vec![28, 34, 29])),
            ],
        )
        .unwrap();

        let mem_table =
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
        let ctx = SessionContext::new();
        ctx.register_table("Person", mem_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 3);
        assert_eq!(result.num_columns(), 1);

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
        assert_eq!(names.value(2), "Carol");
    }

    #[tokio::test]
    async fn test_execute_with_context_with_filter() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
            ],
        )
        .unwrap();

        let mem_table =
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
        let ctx = SessionContext::new();
        ctx.register_table("Person", mem_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 2);
        assert_eq!(result.num_columns(), 2);

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let ages = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        let results: Vec<(String, i64)> = (0..result.num_rows())
            .map(|i| (names.value(i).to_string(), ages.value(i)))
            .collect();

        assert!(results.contains(&("Bob".to_string(), 34)));
        assert!(results.contains(&("David".to_string(), 42)));
    }

    #[tokio::test]
    async fn test_execute_with_context_relationship_traversal() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let person_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let person_batch = RecordBatch::try_new(
            person_schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
            ],
        )
        .unwrap();

        let knows_schema = Arc::new(Schema::new(vec![
            Field::new("src_id", DataType::Int64, false),
            Field::new("dst_id", DataType::Int64, false),
            Field::new("since", DataType::Int64, false),
        ]));
        let knows_batch = RecordBatch::try_new(
            knows_schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(Int64Array::from(vec![2, 3])),
                Arc::new(Int64Array::from(vec![2020, 2021])),
            ],
        )
        .unwrap();

        let person_table = Arc::new(
            MemTable::try_new(person_schema.clone(), vec![vec![person_batch.clone()]]).unwrap(),
        );
        let knows_table = Arc::new(
            MemTable::try_new(knows_schema.clone(), vec![vec![knows_batch.clone()]]).unwrap(),
        );

        let ctx = SessionContext::new();
        ctx.register_table("Person", person_table).unwrap();
        ctx.register_table("KNOWS", knows_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .with_relationship("KNOWS", "src_id", "dst_id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 2);
        assert_eq!(result.num_columns(), 2);

        let src_names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dst_names = result
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let relationships: Vec<(String, String)> = (0..result.num_rows())
            .map(|i| {
                (
                    src_names.value(i).to_string(),
                    dst_names.value(i).to_string(),
                )
            })
            .collect();

        assert!(relationships.contains(&("Alice".to_string(), "Bob".to_string())));
        assert!(relationships.contains(&("Bob".to_string(), "Carol".to_string())));
    }

    #[tokio::test]
    async fn test_execute_with_context_order_by_limit() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use datafusion::datasource::MemTable;
        use datafusion::execution::context::SessionContext;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("score", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![85, 92, 78, 95])),
            ],
        )
        .unwrap();

        let mem_table =
            Arc::new(MemTable::try_new(schema.clone(), vec![vec![batch.clone()]]).unwrap());
        let ctx = SessionContext::new();
        ctx.register_table("Student", mem_table).unwrap();

        let cfg = GraphConfig::builder()
            .with_node_label("Student", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new(
            "MATCH (s:Student) RETURN s.name, s.score ORDER BY s.score DESC LIMIT 2",
        )
        .unwrap()
        .with_config(cfg);

        let result = query.execute_with_context(ctx).await.unwrap();

        assert_eq!(result.num_rows(), 2);
        assert_eq!(result.num_columns(), 2);

        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        assert_eq!(names.value(0), "David");
        assert_eq!(scores.value(0), 95);

        assert_eq!(names.value(1), "Bob");
        assert_eq!(scores.value(1), 92);
    }

    #[tokio::test]
    async fn test_to_sql() {
        use arrow_array::RecordBatch;
        use arrow_schema::{DataType, Field, Schema};
        use std::collections::HashMap;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());

        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
            .unwrap()
            .with_config(cfg);

        let sql = query.to_sql(datasets).await.unwrap();
        println!("Generated SQL: {}", sql);

        assert!(sql.contains("SELECT"));
        assert!(sql.to_lowercase().contains("from person"));
        assert!(sql.contains("p.name"));
    }

    async fn write_lance_dataset(path: &std::path::Path, batch: arrow_array::RecordBatch) {
        use arrow_array::{RecordBatch, RecordBatchIterator};
        use lance::dataset::{Dataset, WriteParams};

        let schema = batch.schema();
        let batches: Vec<std::result::Result<RecordBatch, arrow::error::ArrowError>> =
            vec![std::result::Result::Ok(batch)];
        let reader = RecordBatchIterator::new(batches.into_iter(), schema);

        Dataset::write(reader, path.to_str().unwrap(), None::<WriteParams>)
            .await
            .expect("write lance dataset");
    }

    fn build_people_batch() -> arrow_array::RecordBatch {
        use arrow_array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("person_id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, false),
        ]));

        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![1, 2, 3, 4])) as ArrayRef,
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])) as ArrayRef,
            Arc::new(Int32Array::from(vec![28, 34, 29, 42])) as ArrayRef,
        ];

        RecordBatch::try_new(schema, columns).expect("valid person batch")
    }

    fn build_friendship_batch() -> arrow_array::RecordBatch {
        use arrow_array::{ArrayRef, Int64Array, RecordBatch};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("person1_id", DataType::Int64, false),
            Field::new("person2_id", DataType::Int64, false),
        ]));

        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![1, 1, 2, 3])) as ArrayRef,
            Arc::new(Int64Array::from(vec![2, 3, 4, 4])) as ArrayRef,
        ];

        RecordBatch::try_new(schema, columns).expect("valid friendship batch")
    }

    #[tokio::test]
    async fn executes_against_directory_namespace() {
        use arrow_array::StringArray;
        use tempfile::tempdir;

        let tmp_dir = tempdir().unwrap();
        write_lance_dataset(&tmp_dir.path().join("Person.lance"), build_people_batch()).await;
        write_lance_dataset(
            &tmp_dir.path().join("FRIEND_OF.lance"),
            build_friendship_batch(),
        )
        .await;

        let config = GraphConfig::builder()
            .with_node_label("Person", "person_id")
            .with_relationship("FRIEND_OF", "person1_id", "person2_id")
            .build()
            .expect("valid graph config");

        let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")
            .expect("query parses")
            .with_config(config);

        let namespace = DirNamespace::new(tmp_dir.path().to_string_lossy().into_owned());

        let result = query
            .execute_with_namespace(namespace.clone(), None)
            .await
            .expect("namespace execution succeeds");

        use arrow_array::Array;
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("string column");

        let mut values: Vec<String> = (0..names.len())
            .map(|i| names.value(i).to_string())
            .collect();
        values.sort();
        assert_eq!(values, vec!["Bob".to_string(), "David".to_string()]);

        let err = query
            .execute_with_namespace(namespace, Some(ExecutionStrategy::Simple))
            .await
            .expect_err("simple strategy not supported");
        assert!(
            matches!(err, GraphError::UnsupportedFeature { .. }),
            "expected unsupported feature error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_execute_fails_on_semantic_error() {
        use arrow_array::RecordBatch;
        use arrow_schema::{DataType, Field, Schema};
        use std::collections::HashMap;
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch = RecordBatch::new_empty(schema);
        let mut datasets = HashMap::new();
        datasets.insert("Person".to_string(), batch);

        let cfg = GraphConfig::builder()
            .with_node_label("Person", "id")
            .build()
            .unwrap();

        let query = CypherQuery::new("MATCH (n:Person) RETURN x.name")
            .unwrap()
            .with_config(cfg);

        let result = query.execute_simple(datasets).await;

        assert!(result.is_err());
        match result {
            Err(GraphError::PlanError { message, .. }) => {
                assert!(message.contains("Semantic analysis failed"));
                assert!(message.contains("Undefined variable: 'x'"));
            }
            _ => panic!(
                "Expected PlanError with semantic failure message, got {:?}",
                result
            ),
        }
    }
}