use crate::ast::CypherQuery as CypherAST;
use crate::ast::ReadingClause;
use crate::config::GraphConfig;
use crate::error::{GraphError, Result};
use crate::logical_plan::LogicalPlanner;
use crate::parser::parse_cypher_query;
use arrow_array::RecordBatch;
use arrow_schema::{Field, Schema, SchemaRef};
use lance_graph_catalog::DirNamespace;
use lance_namespace::models::DescribeTableRequest;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
pub(crate) fn normalize_schema(schema: SchemaRef) -> Result<SchemaRef> {
let fields: Vec<_> = schema
.fields()
.iter()
.map(|f| {
Arc::new(Field::new(
f.name().to_lowercase(),
f.data_type().clone(),
f.is_nullable(),
))
})
.collect();
Ok(Arc::new(Schema::new(fields)))
}
pub(crate) fn normalize_record_batch(batch: &RecordBatch) -> Result<RecordBatch> {
let normalized_schema = normalize_schema(batch.schema())?;
RecordBatch::try_new(normalized_schema, batch.columns().to_vec()).map_err(|e| {
GraphError::PlanError {
message: format!("Failed to normalize record batch schema: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
}
})
}
/// Engine used to run a Cypher query.
///
/// `DataFusion` is the default. Requesting `LanceNative` currently makes the
/// execute methods return an `UnsupportedFeature` error.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ExecutionStrategy {
    /// Translate the query into a DataFusion logical plan and run it there.
    #[default]
    DataFusion,
    /// Run directly on Lance datasets (not implemented yet).
    LanceNative,
}
/// A parsed Cypher query, together with the optional graph configuration and
/// the parameter bindings used when planning and executing it.
#[derive(Debug, Clone)]
pub struct CypherQuery {
/// Query text as passed to `CypherQuery::new`; `CypherQueryBuilder::build`
/// stores the fixed placeholder `"MATCH ... RETURN ..."` instead.
query_text: String,
/// Parsed abstract syntax tree of the query.
ast: CypherAST,
/// Graph configuration; the planning/execution paths fail with a
/// `ConfigError` (via `require_config`) when this is `None`.
config: Option<GraphConfig>,
/// Parameter bindings handed to semantic analysis. `with_parameter` and
/// `with_parameters` lowercase keys before storing them.
parameters: HashMap<String, serde_json::Value>,
}
impl CypherQuery {
/// Parses `query` into a [`CypherQuery`] with no configuration and no
/// parameters attached.
///
/// # Errors
///
/// Propagates any error returned by the Cypher parser.
pub fn new(query: &str) -> Result<Self> {
    let parsed = parse_cypher_query(query)?;
    let query_text = String::from(query);
    Ok(Self {
        query_text,
        ast: parsed,
        config: None,
        parameters: HashMap::default(),
    })
}
pub fn with_config(mut self, config: GraphConfig) -> Self {
self.config = Some(config);
self
}
pub fn with_parameter<K, V>(mut self, key: K, value: V) -> Self
where
K: Into<String>,
V: Into<serde_json::Value>,
{
self.parameters
.insert(key.into().to_lowercase(), value.into());
self
}
pub fn with_parameters(mut self, params: HashMap<String, serde_json::Value>) -> Self {
for (k, v) in params {
self.parameters.insert(k.to_lowercase(), v);
}
self
}
/// Returns the query text stored for this query.
pub fn query_text(&self) -> &str {
    self.query_text.as_str()
}

/// Returns the parsed Cypher AST.
pub fn ast(&self) -> &CypherAST {
    &self.ast
}

/// Returns the attached graph configuration, if any.
pub fn config(&self) -> Option<&GraphConfig> {
    self.config.as_ref()
}

/// Returns the bound query parameters (keys as stored by the setters).
pub fn parameters(&self) -> &HashMap<String, serde_json::Value> {
    &self.parameters
}

/// Returns the attached graph configuration.
///
/// # Errors
///
/// Returns [`GraphError::ConfigError`] when no configuration has been set.
fn require_config(&self) -> Result<&GraphConfig> {
    let Some(config) = self.config.as_ref() else {
        return Err(GraphError::ConfigError {
            message: "Graph configuration is required for query execution".to_string(),
            location: snafu::Location::new(file!(), line!(), column!()),
        });
    };
    Ok(config)
}
pub async fn execute(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
strategy: Option<ExecutionStrategy>,
) -> Result<arrow::record_batch::RecordBatch> {
let strategy = strategy.unwrap_or_default();
match strategy {
ExecutionStrategy::DataFusion => self.execute_datafusion(datasets).await,
ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
feature: "Lance native execution strategy is not yet implemented".to_string(),
location: snafu::Location::new(file!(), line!(), column!()),
}),
}
}
pub async fn execute_with_namespace(
&self,
namespace: DirNamespace,
strategy: Option<ExecutionStrategy>,
) -> Result<arrow::record_batch::RecordBatch> {
self.execute_with_namespace_arc(std::sync::Arc::new(namespace), strategy)
.await
}
pub async fn execute_with_namespace_arc(
&self,
namespace: std::sync::Arc<DirNamespace>,
strategy: Option<ExecutionStrategy>,
) -> Result<arrow::record_batch::RecordBatch> {
let namespace_trait: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync> =
namespace;
self.execute_with_namespace_internal(namespace_trait, strategy)
.await
}
async fn execute_with_namespace_internal(
&self,
namespace: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync>,
strategy: Option<ExecutionStrategy>,
) -> Result<arrow::record_batch::RecordBatch> {
let strategy = strategy.unwrap_or_default();
match strategy {
ExecutionStrategy::DataFusion => {
let (catalog, ctx) = self
.build_catalog_and_context_from_namespace(namespace)
.await?;
self.execute_with_catalog_and_context(std::sync::Arc::new(catalog), ctx)
.await
}
ExecutionStrategy::LanceNative => Err(GraphError::UnsupportedFeature {
feature: "Lance native execution strategy is not yet implemented".to_string(),
location: snafu::Location::new(file!(), line!(), column!()),
}),
}
}
pub async fn explain(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
) -> Result<String> {
use std::sync::Arc;
let (catalog, ctx) = self
.build_catalog_and_context_from_datasets(datasets)
.await?;
self.explain_internal(Arc::new(catalog), ctx).await
}
pub async fn to_sql(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
) -> Result<String> {
use datafusion_sql::unparser::plan_to_sql;
use std::sync::Arc;
let _config = self.require_config()?;
let (catalog, ctx) = self
.build_catalog_and_context_from_datasets(datasets)
.await?;
let (_, df_plan) = self.create_logical_plans(Arc::new(catalog))?;
let optimized_plan = ctx
.state()
.optimize(&df_plan)
.map_err(|e| GraphError::PlanError {
message: format!("Failed to optimize plan: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
message: format!("Failed to unparse plan to SQL: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
Ok(sql_ast.to_string())
}
pub async fn execute_with_context(
&self,
ctx: datafusion::execution::context::SessionContext,
) -> Result<arrow::record_batch::RecordBatch> {
use datafusion::datasource::DefaultTableSource;
use lance_graph_catalog::InMemoryCatalog;
use std::sync::Arc;
let config = self.require_config()?;
let mut catalog = InMemoryCatalog::new();
for label in config.node_mappings.keys() {
let table_provider =
ctx.table_provider(label)
.await
.map_err(|e| GraphError::ConfigError {
message: format!(
"Node label '{}' not found in SessionContext: {}",
label, e
),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
let table_source = Arc::new(DefaultTableSource::new(table_provider));
catalog = catalog.with_node_source(label, table_source);
}
for rel_type in config.relationship_mappings.keys() {
let table_provider =
ctx.table_provider(rel_type)
.await
.map_err(|e| GraphError::ConfigError {
message: format!(
"Relationship type '{}' not found in SessionContext: {}",
rel_type, e
),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
let table_source = Arc::new(DefaultTableSource::new(table_provider));
catalog = catalog.with_relationship_source(rel_type, table_source);
}
self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
.await
}
pub async fn execute_with_catalog_and_context(
&self,
catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
ctx: datafusion::execution::context::SessionContext,
) -> Result<arrow::record_batch::RecordBatch> {
use arrow::compute::concat_batches;
let (_logical_plan, df_logical_plan) = self.create_logical_plans(catalog)?;
let df = ctx
.execute_logical_plan(df_logical_plan)
.await
.map_err(|e| GraphError::ExecutionError {
message: format!("Failed to execute DataFusion plan: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
let result_schema = df.schema().inner().clone();
let batches = df.collect().await.map_err(|e| GraphError::ExecutionError {
message: format!("Failed to collect query results: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
if batches.is_empty() {
return Ok(arrow::record_batch::RecordBatch::new_empty(result_schema));
}
let schema = batches[0].schema();
concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
message: format!("Failed to concatenate result batches: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})
}
async fn execute_datafusion(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
) -> Result<arrow::record_batch::RecordBatch> {
use std::sync::Arc;
let (catalog, ctx) = self
.build_catalog_and_context_from_datasets(datasets)
.await?;
self.execute_with_catalog_and_context(Arc::new(catalog), ctx)
.await
}
async fn build_catalog_and_context_from_datasets(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
) -> Result<(
lance_graph_catalog::InMemoryCatalog,
datafusion::execution::context::SessionContext,
)> {
use datafusion::datasource::{DefaultTableSource, MemTable};
use datafusion::execution::context::SessionContext;
use lance_graph_catalog::InMemoryCatalog;
use std::sync::Arc;
if datasets.is_empty() {
return Err(GraphError::ConfigError {
message: "No input datasets provided".to_string(),
location: snafu::Location::new(file!(), line!(), column!()),
});
}
let ctx = SessionContext::new();
let mut catalog = InMemoryCatalog::new();
for (name, batch) in &datasets {
let normalized_batch = normalize_record_batch(batch)?;
let mem_table = Arc::new(
MemTable::try_new(
normalized_batch.schema(),
vec![vec![normalized_batch.clone()]],
)
.map_err(|e| GraphError::PlanError {
message: format!("Failed to create MemTable for {}: {}", name, e),
location: snafu::Location::new(file!(), line!(), column!()),
})?,
);
let normalized_name = name.to_lowercase();
ctx.register_table(&normalized_name, mem_table.clone())
.map_err(|e| GraphError::PlanError {
message: format!("Failed to register table {}: {}", name, e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;
let table_source = Arc::new(DefaultTableSource::new(mem_table));
catalog = catalog
.with_node_source(name, table_source.clone())
.with_relationship_source(name, table_source);
}
Ok((catalog, ctx))
}
/// Resolves every table referenced by the graph configuration through
/// `namespace`, opens each one as a Lance dataset, and returns a graph catalog
/// plus a DataFusion session in which those tables are registered.
///
/// Table names come from the `label` / `relationship_type` fields of the
/// configured mappings. Tables are registered in the session under their
/// lowercased name. Catalog lookups by configuration key are lowercased the
/// same way, so mixed-case keys still resolve.
///
/// # Errors
///
/// Returns `ConfigError` in any of these cases:
/// - no configuration is attached, or it references no tables;
/// - the namespace cannot describe a table, or gives no location for it;
/// - a dataset cannot be opened;
/// - a configured label or relationship type has no resolved table.
///
/// Returns `PlanError` if registering a table in the session fails.
async fn build_catalog_and_context_from_namespace(
    &self,
    namespace: std::sync::Arc<dyn lance_namespace::LanceNamespace + Send + Sync>,
) -> Result<(
    lance_graph_catalog::InMemoryCatalog,
    datafusion::execution::context::SessionContext,
)> {
    use datafusion::datasource::{DefaultTableSource, TableProvider};
    use datafusion::execution::context::SessionContext;
    use lance::datafusion::LanceTableProvider;
    use lance_graph_catalog::InMemoryCatalog;
    use std::sync::Arc;

    let config = self.require_config()?;

    let mut required_tables: HashSet<String> = HashSet::new();
    required_tables.extend(config.node_mappings.values().map(|m| m.label.clone()));
    required_tables.extend(
        config
            .relationship_mappings
            .values()
            .map(|m| m.relationship_type.clone()),
    );
    if required_tables.is_empty() {
        return Err(GraphError::ConfigError {
            message:
                "Graph configuration does not reference any node labels or relationship types"
                    .to_string(),
            location: snafu::Location::new(file!(), line!(), column!()),
        });
    }

    let ctx = SessionContext::new();
    let mut catalog = InMemoryCatalog::new();
    // Keyed by lowercased table name, matching the names registered in `ctx`.
    let mut providers: HashMap<String, Arc<dyn TableProvider>> = HashMap::new();

    for table_name in required_tables {
        let mut request = DescribeTableRequest::new();
        request.id = Some(vec![table_name.clone()]);
        let response = namespace
            .describe_table(request)
            .await
            .map_err(|e| GraphError::ConfigError {
                message: format!(
                    "Namespace failed to resolve table '{}': {}",
                    table_name, e
                ),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        let location = response.location.ok_or_else(|| GraphError::ConfigError {
            message: format!(
                "Namespace did not provide a location for table '{}'",
                table_name
            ),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;
        let dataset = lance::dataset::Dataset::open(&location)
            .await
            .map_err(|e| GraphError::ConfigError {
                message: format!("Failed to open dataset for table '{}': {}", table_name, e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        let provider: Arc<dyn TableProvider> =
            Arc::new(LanceTableProvider::new(Arc::new(dataset), true, true));

        let normalized_table_name = table_name.to_lowercase();
        ctx.register_table(&normalized_table_name, provider.clone())
            .map_err(|e| GraphError::PlanError {
                message: format!(
                    "Failed to register table '{}' in SessionContext: {}",
                    table_name, e
                ),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        providers.insert(normalized_table_name, provider);
    }

    for label in config.node_mappings.keys() {
        // `providers` is keyed by lowercased name, so normalize the lookup key
        // the same way; otherwise a mixed-case key would never resolve.
        let provider = providers
            .get(&label.to_lowercase())
            .ok_or_else(|| GraphError::ConfigError {
                message: format!(
                    "Namespace did not resolve dataset for node label '{}'",
                    label
                ),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        catalog = catalog.with_node_source(label, table_source);
    }

    for rel_type in config.relationship_mappings.keys() {
        let provider = providers
            .get(&rel_type.to_lowercase())
            .ok_or_else(|| GraphError::ConfigError {
                message: format!(
                    "Namespace did not resolve dataset for relationship type '{}'",
                    rel_type
                ),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        catalog = catalog.with_relationship_source(rel_type, table_source);
    }

    Ok((catalog, ctx))
}
/// Builds every plan for this query and renders them as an EXPLAIN table.
async fn explain_internal(
    &self,
    catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
    ctx: datafusion::execution::context::SessionContext,
) -> Result<String> {
    let (graph_plan, df_plan, physical) = self.create_plans(catalog, &ctx).await?;
    self.format_explain_output(&graph_plan, &df_plan, &*physical)
}

/// Runs semantic analysis on the AST and produces two plans: the graph
/// logical plan, and the DataFusion logical plan resolved against `catalog`.
///
/// # Errors
///
/// Returns `ConfigError` when no configuration is attached and `PlanError`
/// when semantic analysis reports errors; analyzer and planner failures are
/// propagated.
fn create_logical_plans(
    &self,
    catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
) -> Result<(
    crate::logical_plan::LogicalOperator,
    datafusion::logical_expr::LogicalPlan,
)> {
    use crate::datafusion_planner::{DataFusionPlanner, GraphPhysicalPlanner};
    use crate::semantic::SemanticAnalyzer;

    let config = self.require_config()?;

    let mut semantic_analyzer = SemanticAnalyzer::new(config.clone());
    let semantic = semantic_analyzer.analyze(&self.ast, &self.parameters)?;
    if !semantic.errors.is_empty() {
        let details = semantic.errors.join("\n");
        return Err(GraphError::PlanError {
            message: format!("Semantic analysis failed:\n{}", details),
            location: snafu::Location::new(file!(), line!(), column!()),
        });
    }

    let mut graph_planner = LogicalPlanner::new(config);
    let graph_plan = graph_planner.plan(&semantic.ast)?;

    let datafusion_planner = DataFusionPlanner::with_catalog(config.clone(), catalog);
    let df_plan = datafusion_planner.plan(&graph_plan)?;

    Ok((graph_plan, df_plan))
}

/// Produces the graph logical plan, the DataFusion logical plan, and the
/// DataFusion physical plan created in `ctx`.
///
/// # Errors
///
/// Propagates logical planning errors; returns `ExecutionError` when DataFusion
/// cannot accept the logical plan or create the physical plan.
async fn create_plans(
    &self,
    catalog: std::sync::Arc<dyn lance_graph_catalog::GraphSourceCatalog>,
    ctx: &datafusion::execution::context::SessionContext,
) -> Result<(
    crate::logical_plan::LogicalOperator,
    datafusion::logical_expr::LogicalPlan,
    std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
)> {
    let (graph_plan, df_plan) = self.create_logical_plans(catalog)?;

    // The logical plan is cloned because it is also returned for display.
    let frame = ctx
        .execute_logical_plan(df_plan.clone())
        .await
        .map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to execute DataFusion plan: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

    let physical = frame
        .create_physical_plan()
        .await
        .map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to create physical plan: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;

    Ok((graph_plan, df_plan, physical))
}
fn format_explain_output(
&self,
logical_plan: &crate::logical_plan::LogicalOperator,
df_logical_plan: &datafusion::logical_expr::LogicalPlan,
physical_plan: &dyn datafusion::physical_plan::ExecutionPlan,
) -> Result<String> {
let mut output = String::new();
output.push_str("Cypher Query:\n");
output.push_str(&format!(" {}\n\n", self.query_text));
let mut rows = vec![];
let graph_plan_str = format!("{:#?}", logical_plan);
rows.push(("graph_logical_plan", graph_plan_str));
let df_logical_str = format!("{}", df_logical_plan.display_indent());
rows.push(("logical_plan", df_logical_str));
let df_physical_str = format!(
"{}",
datafusion::physical_plan::displayable(physical_plan).indent(true)
);
rows.push(("physical_plan", df_physical_str));
let plan_type_width = rows.iter().map(|(t, _)| t.len()).max().unwrap_or(10);
let plan_width = rows
.iter()
.map(|(_, p)| p.lines().map(|l| l.len()).max().unwrap_or(0))
.max()
.unwrap_or(50);
let separator = format!(
"+{}+{}+",
"-".repeat(plan_type_width + 2),
"-".repeat(plan_width + 2)
);
output.push_str(&separator);
output.push('\n');
output.push_str(&format!(
"| {:<width$} | {:<plan_width$} |\n",
"plan_type",
"plan",
width = plan_type_width,
plan_width = plan_width
));
output.push_str(&separator);
output.push('\n');
for (plan_type, plan_content) in rows {
let lines: Vec<&str> = plan_content.lines().collect();
if lines.is_empty() {
output.push_str(&format!(
"| {:<width$} | {:<plan_width$} |\n",
plan_type,
"",
width = plan_type_width,
plan_width = plan_width
));
} else {
output.push_str(&format!(
"| {:<width$} | {:<plan_width$} |\n",
plan_type,
lines[0],
width = plan_type_width,
plan_width = plan_width
));
for line in &lines[1..] {
output.push_str(&format!(
"| {:<width$} | {:<plan_width$} |\n",
"",
line,
width = plan_type_width,
plan_width = plan_width
));
}
}
}
output.push_str(&separator);
output.push('\n');
Ok(output)
}
pub fn referenced_node_labels(&self) -> Vec<String> {
let mut labels = Vec::new();
for clause in &self.ast.reading_clauses {
if let ReadingClause::Match(match_clause) = clause {
for pattern in &match_clause.patterns {
self.collect_node_labels_from_pattern(pattern, &mut labels);
}
}
}
labels.sort();
labels.dedup();
labels
}
pub fn referenced_relationship_types(&self) -> Vec<String> {
let mut types = Vec::new();
for clause in &self.ast.reading_clauses {
if let ReadingClause::Match(match_clause) = clause {
for pattern in &match_clause.patterns {
self.collect_relationship_types_from_pattern(pattern, &mut types);
}
}
}
types.sort();
types.dedup();
types
}
pub fn variables(&self) -> Vec<String> {
let mut variables = Vec::new();
for clause in &self.ast.reading_clauses {
match clause {
ReadingClause::Match(match_clause) => {
for pattern in &match_clause.patterns {
self.collect_variables_from_pattern(pattern, &mut variables);
}
}
ReadingClause::Unwind(unwind_clause) => {
variables.push(unwind_clause.alias.clone());
}
}
}
variables.sort();
variables.dedup();
variables
}
fn collect_node_labels_from_pattern(
&self,
pattern: &crate::ast::GraphPattern,
labels: &mut Vec<String>,
) {
match pattern {
crate::ast::GraphPattern::Node(node) => {
labels.extend(node.labels.clone());
}
crate::ast::GraphPattern::Path(path) => {
labels.extend(path.start_node.labels.clone());
for segment in &path.segments {
labels.extend(segment.end_node.labels.clone());
}
}
}
}
fn collect_relationship_types_from_pattern(
&self,
pattern: &crate::ast::GraphPattern,
types: &mut Vec<String>,
) {
if let crate::ast::GraphPattern::Path(path) = pattern {
for segment in &path.segments {
types.extend(segment.relationship.types.clone());
}
}
}
fn collect_variables_from_pattern(
&self,
pattern: &crate::ast::GraphPattern,
variables: &mut Vec<String>,
) {
match pattern {
crate::ast::GraphPattern::Node(node) => {
if let Some(var) = &node.variable {
variables.push(var.clone());
}
}
crate::ast::GraphPattern::Path(path) => {
if let Some(var) = &path.start_node.variable {
variables.push(var.clone());
}
for segment in &path.segments {
if let Some(var) = &segment.relationship.variable {
variables.push(var.clone());
}
if let Some(var) = &segment.end_node.variable {
variables.push(var.clone());
}
}
}
}
}
}
impl CypherQuery {
pub async fn execute_with_vector_rerank(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
vector_search: crate::lance_vector_search::VectorSearch,
) -> Result<arrow::record_batch::RecordBatch> {
let candidates = self.execute(datasets, None).await?;
vector_search.search(&candidates).await
}
}
/// Programmatic builder for simple `MATCH ... RETURN ...` queries; an
/// alternative to parsing query text with `CypherQuery::new`.
#[derive(Debug, Default)]
pub struct CypherQueryBuilder {
/// `MATCH` clauses, one per `match_node` call; `build` requires at least one.
match_clauses: Vec<crate::ast::MatchClause>,
/// Optional predicate; `build` turns it into the query's `WHERE` clause.
where_expression: Option<crate::ast::BooleanExpression>,
/// `RETURN` items; `build` requires at least one.
return_items: Vec<crate::ast::ReturnItem>,
/// `ORDER BY` items; `build` emits no `ORDER BY` clause when this is empty.
order_by_items: Vec<crate::ast::OrderByItem>,
/// Optional `LIMIT`.
limit: Option<u64>,
/// Whether the `RETURN` clause is `DISTINCT`.
distinct: bool,
/// Optional `SKIP`.
skip: Option<u64>,
/// Graph configuration copied into the built query.
config: Option<GraphConfig>,
/// Parameter bindings copied into the built query.
parameters: HashMap<String, serde_json::Value>,
}
impl CypherQueryBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn match_node(mut self, variable: &str, label: &str) -> Self {
let node = crate::ast::NodePattern {
variable: Some(variable.to_string()),
labels: vec![label.to_string()],
properties: HashMap::new(),
};
let match_clause = crate::ast::MatchClause {
patterns: vec![crate::ast::GraphPattern::Node(node)],
};
self.match_clauses.push(match_clause);
self
}
pub fn with_config(mut self, config: GraphConfig) -> Self {
self.config = Some(config);
self
}
pub fn return_property(mut self, variable: &str, property: &str) -> Self {
let prop_ref = crate::ast::PropertyRef::new(variable, property);
let return_item = crate::ast::ReturnItem {
expression: crate::ast::ValueExpression::Property(prop_ref),
alias: None,
};
self.return_items.push(return_item);
self
}
pub fn distinct(mut self, distinct: bool) -> Self {
self.distinct = distinct;
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.limit = Some(limit);
self
}
pub fn skip(mut self, skip: u64) -> Self {
self.skip = Some(skip);
self
}
pub fn build(self) -> Result<CypherQuery> {
if self.match_clauses.is_empty() {
return Err(GraphError::PlanError {
message: "Query must have at least one MATCH clause".to_string(),
location: snafu::Location::new(file!(), line!(), column!()),
});
}
if self.return_items.is_empty() {
return Err(GraphError::PlanError {
message: "Query must have at least one RETURN item".to_string(),
location: snafu::Location::new(file!(), line!(), column!()),
});
}
let ast = crate::ast::CypherQuery {
reading_clauses: self
.match_clauses
.into_iter()
.map(crate::ast::ReadingClause::Match)
.collect(),
where_clause: self
.where_expression
.map(|expr| crate::ast::WhereClause { expression: expr }),
with_clause: None, post_with_reading_clauses: vec![],
post_with_where_clause: None,
return_clause: crate::ast::ReturnClause {
distinct: self.distinct,
items: self.return_items,
},
order_by: if self.order_by_items.is_empty() {
None
} else {
Some(crate::ast::OrderByClause {
items: self.order_by_items,
})
},
limit: self.limit,
skip: self.skip,
};
let query_text = "MATCH ... RETURN ...".to_string();
let query = CypherQuery {
query_text,
ast,
config: self.config,
parameters: self.parameters,
};
Ok(query)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::GraphConfig;
#[test]
fn test_parse_simple_cypher_query() {
let query = CypherQuery::new("MATCH (n:Person) RETURN n.name").unwrap();
assert_eq!(query.query_text(), "MATCH (n:Person) RETURN n.name");
assert_eq!(query.referenced_node_labels(), vec!["Person"]);
assert_eq!(query.variables(), vec!["n"]);
}
#[test]
fn test_query_with_parameters() {
let mut params = HashMap::new();
params.insert("minAge".to_string(), serde_json::Value::Number(30.into()));
params.insert("maxAge".to_string(), serde_json::Value::Number(50.into()));
let query = CypherQuery::new(
"MATCH (n:Person) WHERE n.age > $minAge AND n.age < $maxAge RETURN n.name",
)
.unwrap()
.with_parameters(params);
assert!(query.parameters().contains_key("minage"));
assert!(query.parameters().contains_key("maxage"));
}
#[test]
fn test_query_builder() {
let config = GraphConfig::builder()
.with_node_label("Person", "person_id")
.build()
.unwrap();
let query = CypherQueryBuilder::new()
.with_config(config)
.match_node("n", "Person")
.return_property("n", "name")
.limit(10)
.build()
.unwrap();
assert_eq!(query.referenced_node_labels(), vec!["Person"]);
assert_eq!(query.variables(), vec!["n"]);
}
#[test]
fn test_relationship_query_parsing() {
let query =
CypherQuery::new("MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a.name, b.name")
.unwrap();
assert_eq!(query.referenced_node_labels(), vec!["Person"]);
assert_eq!(query.referenced_relationship_types(), vec!["KNOWS"]);
assert_eq!(query.variables(), vec!["a", "b", "r"]);
}
#[tokio::test]
async fn test_execute_basic_projection_and_filter() {
use arrow_array::{Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
],
)
.unwrap();
let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();
let q = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
.unwrap()
.with_config(cfg);
let mut data = HashMap::new();
data.insert("Person".to_string(), batch);
let out = q.execute(data, None).await.unwrap();
assert_eq!(out.num_rows(), 2);
let names = out
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
let result: Vec<(String, i64)> = (0..out.num_rows())
.map(|i| (names.value(i).to_string(), ages.value(i)))
.collect();
assert!(result.contains(&("Bob".to_string(), 34)));
assert!(result.contains(&("David".to_string(), 42)));
}
#[tokio::test]
async fn test_execute_single_hop_path_join_projection() {
use arrow_array::{Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let person_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
]));
let people = RecordBatch::try_new(
person_schema,
vec![
Arc::new(Int64Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
Arc::new(Int64Array::from(vec![28, 34, 29])),
],
)
.unwrap();
let rel_schema = Arc::new(Schema::new(vec![
Field::new("src_person_id", DataType::Int64, false),
Field::new("dst_person_id", DataType::Int64, false),
]));
let knows = RecordBatch::try_new(
rel_schema,
vec![
Arc::new(Int64Array::from(vec![1, 2])), Arc::new(Int64Array::from(vec![2, 3])),
],
)
.unwrap();
let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.with_relationship("KNOWS", "src_person_id", "dst_person_id")
.build()
.unwrap();
let q = CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN b.name")
.unwrap()
.with_config(cfg);
let mut data = HashMap::new();
data.insert("Person".to_string(), people);
data.insert("KNOWS".to_string(), knows);
let out = q.execute(data, None).await.unwrap();
let names = out
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let got: Vec<String> = (0..out.num_rows())
.map(|i| names.value(i).to_string())
.collect();
assert_eq!(got.len(), 2);
assert!(got.contains(&"Bob".to_string()));
assert!(got.contains(&"Carol".to_string()));
}
#[tokio::test]
async fn test_execute_order_by_asc() {
use arrow_array::{Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
],
)
.unwrap();
let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();
let q = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age ASC")
.unwrap()
.with_config(cfg);
let mut data = HashMap::new();
data.insert("Person".to_string(), batch);
let out = q.execute(data, None).await.unwrap();
let ages = out.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
assert_eq!(collected, vec![28, 29, 34, 42]);
}
#[tokio::test]
async fn test_execute_order_by_desc_with_skip_limit() {
use arrow_array::{Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec!["Bob", "Alice", "David", "Carol"])),
Arc::new(Int64Array::from(vec![34, 28, 42, 29])),
],
)
.unwrap();
let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();
let q =
CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age DESC SKIP 1 LIMIT 2")
.unwrap()
.with_config(cfg);
let mut data = HashMap::new();
data.insert("Person".to_string(), batch);
let out = q.execute(data, None).await.unwrap();
assert_eq!(out.num_rows(), 2);
let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
assert_eq!(collected, vec![34, 29]);
}
#[tokio::test]
async fn test_execute_skip_without_limit() {
use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));
let batch = RecordBatch::try_new(
schema,
vec![Arc::new(Int64Array::from(vec![10, 20, 30, 40]))],
)
.unwrap();
let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();
let q = CypherQuery::new("MATCH (p:Person) RETURN p.age ORDER BY p.age ASC SKIP 2")
.unwrap()
.with_config(cfg);
let mut data = HashMap::new();
data.insert("Person".to_string(), batch);
let out = q.execute(data, None).await.unwrap();
assert_eq!(out.num_rows(), 2);
let ages = out.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
let collected: Vec<i64> = (0..out.num_rows()).map(|i| ages.value(i)).collect();
assert_eq!(collected, vec![30, 40]);
}
#[tokio::test]
async fn test_execute_datafusion_simple_scan() {
use arrow_array::{Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int64Array::from(vec![1, 2])),
Arc::new(StringArray::from(vec!["Alice", "Bob"])),
],
)
.unwrap();
let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();
let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
.unwrap()
.with_config(cfg);
let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), batch);
let result = query.execute_datafusion(datasets).await.unwrap();
assert_eq!(
result.num_rows(),
2,
"Should return all 2 rows without filtering"
);
assert_eq!(result.num_columns(), 1, "Should return 1 column (name)");
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let name_set: std::collections::HashSet<String> = (0..result.num_rows())
.map(|i| names.value(i).to_string())
.collect();
let expected: std::collections::HashSet<String> =
["Alice", "Bob"].iter().map(|s| s.to_string()).collect();
assert_eq!(name_set, expected, "Should return Alice and Bob");
}
/// Executes a plain label scan against a caller-supplied DataFusion
/// `SessionContext` that already has the `Person` table registered.
///
/// The query has no `ORDER BY`, so Cypher promises nothing about row order.
/// The returned names are therefore sorted before comparison, as in the
/// other unordered tests in this module, rather than checked by row index.
#[tokio::test]
async fn test_execute_with_context_simple_scan() {
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use datafusion::execution::context::SessionContext;
    use std::sync::Arc;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("age", DataType::Int64, false),
    ]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
            Arc::new(Int64Array::from(vec![28, 34, 29])),
        ],
    )
    .unwrap();
    // Neither `schema` nor `batch` is needed afterwards, so move them in.
    let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap());
    let ctx = SessionContext::new();
    ctx.register_table("Person", mem_table).unwrap();
    let cfg = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();
    let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
        .unwrap()
        .with_config(cfg);
    let result = query.execute_with_context(ctx).await.unwrap();
    assert_eq!(result.num_rows(), 3);
    assert_eq!(result.num_columns(), 1);
    let names = result
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    // No ORDER BY in the query: compare as a sorted list, not by position.
    let mut returned: Vec<&str> = (0..result.num_rows()).map(|i| names.value(i)).collect();
    returned.sort_unstable();
    assert_eq!(returned, ["Alice", "Bob", "Carol"]);
}
/// Applies `WHERE p.age > 30` through a caller-supplied `SessionContext`.
///
/// Of the four people, only Bob (34) and David (42) satisfy the predicate.
/// Rows are checked by membership, so output order does not matter.
#[tokio::test]
async fn test_execute_with_context_with_filter() {
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use datafusion::execution::context::SessionContext;
    use std::sync::Arc;

    let person_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("age", DataType::Int64, false),
    ]));
    let people = RecordBatch::try_new(
        Arc::clone(&person_schema),
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
            Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
        ],
    )
    .unwrap();
    let person_table = Arc::new(MemTable::try_new(person_schema, vec![vec![people]]).unwrap());
    let ctx = SessionContext::new();
    ctx.register_table("Person", person_table).unwrap();

    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();
    let filtered_query =
        CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age")
            .unwrap()
            .with_config(graph_config);
    let result = filtered_query.execute_with_context(ctx).await.unwrap();

    assert_eq!(result.num_rows(), 2);
    assert_eq!(result.num_columns(), 2);
    let name_col = result
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let age_col = result
        .column(1)
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();

    let mut rows: Vec<(String, i64)> = Vec::with_capacity(result.num_rows());
    for row in 0..result.num_rows() {
        rows.push((name_col.value(row).to_string(), age_col.value(row)));
    }
    assert!(rows.contains(&("Bob".to_string(), 34)));
    assert!(rows.contains(&("David".to_string(), 42)));
}
/// Traverses one `KNOWS` hop between `Person` nodes, using node and edge
/// tables registered directly on a `SessionContext`.
///
/// The edge rows are 1 -> 2 and 2 -> 3, so the expected (source, target)
/// name pairs are (Alice, Bob) and (Bob, Carol). Row order is not asserted.
#[tokio::test]
async fn test_execute_with_context_relationship_traversal() {
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use datafusion::execution::context::SessionContext;
    use std::sync::Arc;

    // Node table: Person(id, name).
    let person_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let people = RecordBatch::try_new(
        Arc::clone(&person_schema),
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
        ],
    )
    .unwrap();
    let person_table = Arc::new(MemTable::try_new(person_schema, vec![vec![people]]).unwrap());

    // Edge table: KNOWS(src_id, dst_id, since).
    let knows_schema = Arc::new(Schema::new(vec![
        Field::new("src_id", DataType::Int64, false),
        Field::new("dst_id", DataType::Int64, false),
        Field::new("since", DataType::Int64, false),
    ]));
    let edges = RecordBatch::try_new(
        Arc::clone(&knows_schema),
        vec![
            Arc::new(Int64Array::from(vec![1, 2])),
            Arc::new(Int64Array::from(vec![2, 3])),
            Arc::new(Int64Array::from(vec![2020, 2021])),
        ],
    )
    .unwrap();
    let knows_table = Arc::new(MemTable::try_new(knows_schema, vec![vec![edges]]).unwrap());

    let ctx = SessionContext::new();
    ctx.register_table("Person", person_table).unwrap();
    ctx.register_table("KNOWS", knows_table).unwrap();

    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .with_relationship("KNOWS", "src_id", "dst_id")
        .build()
        .unwrap();
    let traversal =
        CypherQuery::new("MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name")
            .unwrap()
            .with_config(graph_config);
    let result = traversal.execute_with_context(ctx).await.unwrap();

    assert_eq!(result.num_rows(), 2);
    assert_eq!(result.num_columns(), 2);
    let from_col = result
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let to_col = result
        .column(1)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    let pairs: Vec<(String, String)> = (0..result.num_rows())
        .map(|row| (from_col.value(row).to_owned(), to_col.value(row).to_owned()))
        .collect();
    for (from, to) in [("Alice", "Bob"), ("Bob", "Carol")] {
        assert!(pairs.contains(&(from.to_owned(), to.to_owned())));
    }
}
/// Runs `ORDER BY s.score DESC LIMIT 2` through a caller-supplied
/// `SessionContext`.
///
/// The result must be the two highest scorers, in order: David (95), then
/// Bob (92).
#[tokio::test]
async fn test_execute_with_context_order_by_limit() {
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use datafusion::execution::context::SessionContext;
    use std::sync::Arc;

    let student_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("score", DataType::Int64, false),
    ]));
    let students = RecordBatch::try_new(
        Arc::clone(&student_schema),
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
            Arc::new(Int64Array::from(vec![85, 92, 78, 95])),
        ],
    )
    .unwrap();
    let ctx = SessionContext::new();
    let student_table =
        Arc::new(MemTable::try_new(student_schema, vec![vec![students]]).unwrap());
    ctx.register_table("Student", student_table).unwrap();

    let graph_config = GraphConfig::builder()
        .with_node_label("Student", "id")
        .build()
        .unwrap();
    let top_two = CypherQuery::new(
        "MATCH (s:Student) RETURN s.name, s.score ORDER BY s.score DESC LIMIT 2",
    )
    .unwrap()
    .with_config(graph_config);
    let result = top_two.execute_with_context(ctx).await.unwrap();

    assert_eq!(result.num_rows(), 2);
    assert_eq!(result.num_columns(), 2);
    let name_col = result
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let score_col = result
        .column(1)
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();

    // Order matters here: the query asks for descending score.
    let expected: [(&str, i64); 2] = [("David", 95), ("Bob", 92)];
    for (row, &(name, score)) in expected.iter().enumerate() {
        assert_eq!(name_col.value(row), name);
        assert_eq!(score_col.value(row), score);
    }
}
/// Checks that `to_sql` renders a simple label scan as SQL.
///
/// The SQL must contain `SELECT`, read from the `Person` table
/// (case-insensitively), and reference the `p.name` projection.
#[tokio::test]
async fn test_to_sql() {
    use arrow_array::RecordBatch;
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashMap;
    use std::sync::Arc;

    // An empty batch is enough here; presumably only the schema is consulted
    // when generating SQL (verify if `to_sql` changes).
    let person = RecordBatch::new_empty(Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ])));
    let datasets = HashMap::from([("Person".to_string(), person)]);

    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();
    let sql = CypherQuery::new("MATCH (p:Person) RETURN p.name")
        .unwrap()
        .with_config(graph_config)
        .to_sql(datasets)
        .await
        .unwrap();
    println!("Generated SQL: {}", sql);

    let lowered = sql.to_lowercase();
    assert!(sql.contains("SELECT"));
    assert!(lowered.contains("from person"));
    assert!(sql.contains("p.name"));
}
async fn write_lance_dataset(path: &std::path::Path, batch: arrow_array::RecordBatch) {
use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::{Dataset, WriteParams};
let schema = batch.schema();
let batches: Vec<std::result::Result<RecordBatch, arrow::error::ArrowError>> =
vec![std::result::Result::Ok(batch)];
let reader = RecordBatchIterator::new(batches.into_iter(), schema);
Dataset::write(reader, path.to_str().unwrap(), None::<WriteParams>)
.await
.expect("write lance dataset");
}
/// Builds the `Person` fixture batch with four rows.
///
/// Columns: `person_id` (Int64), `name` (Utf8) and `age` (Int32), all
/// non-nullable. Ages are Alice 28, Bob 34, Carol 29 and David 42.
fn build_people_batch() -> arrow_array::RecordBatch {
    use arrow_array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    let person_ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4]));
    let names: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"]));
    let ages: ArrayRef = Arc::new(Int32Array::from(vec![28, 34, 29, 42]));

    let fields = vec![
        Field::new("person_id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("age", DataType::Int32, false),
    ];
    RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![person_ids, names, ages])
        .expect("valid person batch")
}
/// Builds the `FRIEND_OF` edge fixture.
///
/// Each row is a non-nullable Int64 `(person1_id, person2_id)` pair:
/// (1, 2), (1, 3), (2, 4) and (3, 4).
fn build_friendship_batch() -> arrow_array::RecordBatch {
    use arrow_array::{ArrayRef, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    // One tuple per edge, split into the two id columns below.
    let edges: [(i64, i64); 4] = [(1, 2), (1, 3), (2, 4), (3, 4)];
    let (left_ids, right_ids): (Vec<i64>, Vec<i64>) = edges.iter().copied().unzip();

    let schema = Schema::new(vec![
        Field::new("person1_id", DataType::Int64, false),
        Field::new("person2_id", DataType::Int64, false),
    ]);
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int64Array::from(left_ids)),
        Arc::new(Int64Array::from(right_ids)),
    ];
    RecordBatch::try_new(Arc::new(schema), columns).expect("valid friendship batch")
}
/// End-to-end check that a query resolves its tables through a `DirNamespace`
/// rooted at a temp directory of Lance datasets.
///
/// `Person.lance` and `FRIEND_OF.lance` are written under the namespace root.
/// The query keeps people older than 30, so only Bob and David should be
/// returned. Names are sorted before comparison because there is no
/// `ORDER BY`.
#[tokio::test]
async fn executes_against_directory_namespace() {
    use arrow_array::StringArray;
    use tempfile::tempdir;
    // `tmp_dir` must stay alive until the query has run; dropping it deletes
    // the datasets.
    let tmp_dir = tempdir().unwrap();
    write_lance_dataset(&tmp_dir.path().join("Person.lance"), build_people_batch()).await;
    write_lance_dataset(
        &tmp_dir.path().join("FRIEND_OF.lance"),
        build_friendship_batch(),
    )
    .await;
    let config = GraphConfig::builder()
        .with_node_label("Person", "person_id")
        .with_relationship("FRIEND_OF", "person1_id", "person2_id")
        .build()
        .expect("valid graph config");
    let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")
        .expect("query parses")
        .with_config(config);
    let namespace = DirNamespace::new(tmp_dir.path().to_string_lossy().into_owned());
    // `namespace` is not used again, so hand it over rather than cloning it.
    let result = query
        .execute_with_namespace(namespace, None)
        .await
        .expect("namespace execution succeeds");
    assert_eq!(
        result.num_columns(),
        1,
        "RETURN p.name should yield exactly one column"
    );
    let names = result
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("string column");
    let mut values: Vec<&str> = (0..result.num_rows()).map(|i| names.value(i)).collect();
    values.sort_unstable();
    assert_eq!(values, ["Bob", "David"]);
}
/// Checks that a `RETURN` referencing a variable never bound in `MATCH` is
/// rejected by semantic analysis.
///
/// The failure must surface as a `GraphError::PlanError` whose message names
/// the undefined variable.
#[tokio::test]
async fn test_execute_fails_on_semantic_error() {
    use arrow_array::RecordBatch;
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashMap;
    use std::sync::Arc;

    let id_only = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let datasets = HashMap::from([("Person".to_string(), RecordBatch::new_empty(id_only))]);

    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();
    // The MATCH clause binds only `n`; `x` in RETURN is undefined.
    let bad_query = CypherQuery::new("MATCH (n:Person) RETURN x.name")
        .unwrap()
        .with_config(graph_config);

    let outcome = bad_query.execute(datasets, None).await;
    assert!(outcome.is_err());
    if let Err(GraphError::PlanError { message, .. }) = &outcome {
        assert!(message.contains("Semantic analysis failed"));
        assert!(message.contains("Undefined variable: 'x'"));
    } else {
        panic!(
            "Expected PlanError with semantic failure message, got {:?}",
            outcome
        );
    }
}
}