use datafusion::execution::context::SessionContext;
use lance_graph::config::GraphConfig;
use lance_graph::query::CypherQuery;
/// Runs two Cypher queries through `CypherQuery::execute_with_context` against
/// CSV files registered in a DataFusion `SessionContext`: a filtered, ordered
/// node scan over `Person`, then a single-hop `KNOWS` traversal.
#[tokio::test]
async fn test_execute_with_context_csv_simple() {
    // `scratch` owns the temporary directory; it stays bound until the end of the test.
    let scratch = tempfile::tempdir().unwrap();
    let person_file = scratch.path().join("persons.csv");
    let knows_file = scratch.path().join("knows.csv");

    // Four people; only Bob (34) and David (42) satisfy `age > 30`.
    let person_rows = "id,name,age\n\
                       1,Alice,28\n\
                       2,Bob,34\n\
                       3,Carol,29\n\
                       4,David,42\n";
    std::fs::write(&person_file, person_rows).unwrap();

    // Three directed edges between people 1, 2 and 3.
    let knows_rows = "src_id,dst_id,since\n\
                      1,2,2020\n\
                      2,3,2021\n\
                      1,3,2019\n";
    std::fs::write(&knows_file, knows_rows).unwrap();

    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .with_relationship("KNOWS", "src_id", "dst_id")
        .build()
        .unwrap();

    // Register each CSV under the same name that was handed to the graph config.
    let session = SessionContext::new();
    for (table, file) in [("Person", &person_file), ("KNOWS", &knows_file)] {
        session
            .register_csv(table, file.to_str().unwrap(), Default::default())
            .await
            .unwrap();
    }

    // --- Query 1: node scan with a predicate and ordering -------------------
    let over_thirty_query =
        CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age ORDER BY p.age")
            .unwrap()
            .with_config(graph_config.clone());
    let over_thirty = over_thirty_query
        .execute_with_context(session.clone())
        .await
        .unwrap();

    assert_eq!(over_thirty.num_rows(), 2);
    assert_eq!(over_thirty.num_columns(), 2);
    let node_schema = over_thirty.schema();
    for (idx, &expected) in ["p.name", "p.age"].iter().enumerate() {
        assert_eq!(node_schema.field(idx).name(), expected);
    }

    // The test expects a string name column and a 64-bit integer age column.
    let names = over_thirty
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    let ages = over_thirty
        .column(1)
        .as_any()
        .downcast_ref::<arrow_array::Int64Array>()
        .unwrap();
    // Rows must come back in ascending age order.
    for (row, &(name, age)) in [("Bob", 34_i64), ("David", 42)].iter().enumerate() {
        assert_eq!(names.value(row), name);
        assert_eq!(ages.value(row), age);
    }

    // --- Query 2: one-hop relationship traversal -----------------------------
    let knows_query = CypherQuery::new(
        "MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name ORDER BY a.name",
    )
    .unwrap()
    .with_config(graph_config);
    let knows = knows_query.execute_with_context(session).await.unwrap();

    assert_eq!(knows.num_rows(), 3);
    assert_eq!(knows.num_columns(), 2);
    let edge_schema = knows.schema();
    for (idx, &expected) in ["a.name", "b.name"].iter().enumerate() {
        assert_eq!(edge_schema.field(idx).name(), expected);
    }

    let src_names = knows
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    let dst_names = knows
        .column(1)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();

    let mut edges: Vec<(String, String)> = Vec::with_capacity(knows.num_rows());
    for row in 0..knows.num_rows() {
        edges.push((
            src_names.value(row).to_string(),
            dst_names.value(row).to_string(),
        ));
    }

    // Two rows share the source "Alice", so their relative order under
    // `ORDER BY a.name` is not pinned; check membership instead of position.
    for &(src, dst) in &[("Alice", "Bob"), ("Alice", "Carol"), ("Bob", "Carol")] {
        assert!(edges.contains(&(src.to_string(), dst.to_string())));
    }
}
/// Checks a query that combines a relationship traversal, a compound `WHERE`
/// clause spanning both endpoints, and a descending `ORDER BY`, using three
/// CSV-backed tables (`Employee`, `Department`, `WORKS_IN`).
#[tokio::test]
async fn test_execute_with_context_complex_query() {
    let scratch = tempfile::tempdir().unwrap();
    let employee_file = scratch.path().join("employees.csv");
    let department_file = scratch.path().join("departments.csv");
    let works_in_file = scratch.path().join("works_in.csv");

    let employee_rows = "emp_id,name,salary\n\
                         101,Alice,75000\n\
                         102,Bob,85000\n\
                         103,Carol,65000\n\
                         104,David,95000\n\
                         105,Eve,72000\n";
    std::fs::write(&employee_file, employee_rows).unwrap();

    let department_rows = "dept_id,name,budget\n\
                           1,Engineering,500000\n\
                           2,Sales,300000\n\
                           3,HR,200000\n";
    std::fs::write(&department_file, department_rows).unwrap();

    // Employees 101, 102 and 104 are the ones attached to department 1 (Engineering).
    let works_in_rows = "employee_id,department_id,role\n\
                         101,1,Engineer\n\
                         102,1,Senior Engineer\n\
                         103,2,Sales Rep\n\
                         104,1,Manager\n\
                         105,3,HR Specialist\n";
    std::fs::write(&works_in_file, works_in_rows).unwrap();

    let graph_config = GraphConfig::builder()
        .with_node_label("Employee", "emp_id")
        .with_node_label("Department", "dept_id")
        .with_relationship("WORKS_IN", "employee_id", "department_id")
        .build()
        .unwrap();

    // Register each CSV under the same name that was handed to the graph config.
    let session = SessionContext::new();
    let tables = [
        ("Employee", &employee_file),
        ("Department", &department_file),
        ("WORKS_IN", &works_in_file),
    ];
    for (table, file) in tables {
        session
            .register_csv(table, file.to_str().unwrap(), Default::default())
            .await
            .unwrap();
    }

    let well_paid_engineers_query = CypherQuery::new(
        "MATCH (e:Employee)-[:WORKS_IN]->(d:Department) \
         WHERE d.name = 'Engineering' AND e.salary > 80000 \
         RETURN e.name, e.salary, d.name \
         ORDER BY e.salary DESC",
    )
    .unwrap()
    .with_config(graph_config);
    let batch = well_paid_engineers_query
        .execute_with_context(session)
        .await
        .unwrap();

    assert_eq!(batch.num_rows(), 2);
    assert_eq!(batch.num_columns(), 3);
    let schema = batch.schema();
    for (idx, &expected) in ["e.name", "e.salary", "d.name"].iter().enumerate() {
        assert_eq!(schema.field(idx).name(), expected);
    }

    // The test expects string name columns and a 64-bit integer salary column.
    let emp_names = batch
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    let salaries = batch
        .column(1)
        .as_any()
        .downcast_ref::<arrow_array::Int64Array>()
        .unwrap();
    let dept_names = batch
        .column(2)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();

    // Highest salary first, per `ORDER BY e.salary DESC`.
    let expected_rows = [
        ("David", 95000_i64, "Engineering"),
        ("Bob", 85000, "Engineering"),
    ];
    for (row, &(employee, salary, department)) in expected_rows.iter().enumerate() {
        assert_eq!(emp_names.value(row), employee);
        assert_eq!(salaries.value(row), salary);
        assert_eq!(dept_names.value(row), department);
    }
}
/// Executing a query whose label has no table registered in the context must
/// fail, and the error text (compared case-insensitively) must contain both
/// "person" and "not found".
#[tokio::test]
async fn test_execute_with_context_missing_table() {
    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();

    // Deliberately left empty: nothing is registered on this context.
    let empty_session = SessionContext::new();

    let person_query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
        .unwrap()
        .with_config(graph_config);
    let outcome = person_query.execute_with_context(empty_session).await;

    assert!(outcome.is_err());

    // Lowercase once so the fragment checks below ignore the error's casing.
    let err_msg = outcome.unwrap_err().to_string().to_lowercase();
    let mentions_missing_table = ["person", "not found"]
        .iter()
        .all(|fragment| err_msg.contains(*fragment));
    assert!(
        mentions_missing_table,
        "Error should mention missing Person table: {}",
        err_msg
    );
}
/// Verifies that `AS` aliases in the `RETURN` clause become the output column
/// names, and that rows still come back ordered by the underlying `p.age`.
#[tokio::test]
async fn test_execute_with_context_aliases() {
    let scratch = tempfile::tempdir().unwrap();
    let person_file = scratch.path().join("persons.csv");

    let person_rows = "id,name,age\n\
                       1,Alice,28\n\
                       2,Bob,34\n";
    std::fs::write(&person_file, person_rows).unwrap();

    let graph_config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();

    let session = SessionContext::new();
    session
        .register_csv(
            "Person",
            person_file.to_str().unwrap(),
            Default::default(),
        )
        .await
        .unwrap();

    let aliased_query = CypherQuery::new(
        "MATCH (p:Person) RETURN p.name AS person_name, p.age AS person_age ORDER BY p.age",
    )
    .unwrap()
    .with_config(graph_config);
    let batch = aliased_query.execute_with_context(session).await.unwrap();

    assert_eq!(batch.num_rows(), 2);

    // Output fields must carry the aliases rather than the `p.<property>` form.
    let schema = batch.schema();
    for (idx, &alias) in ["person_name", "person_age"].iter().enumerate() {
        assert_eq!(schema.field(idx).name(), alias);
    }

    let names = batch
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    // Alice (28) sorts before Bob (34).
    for (row, &name) in ["Alice", "Bob"].iter().enumerate() {
        assert_eq!(names.value(row), name);
    }
}