#![cfg(feature = "vector-index")]
use std::collections::HashMap;
use grafeo_common::types::Value;
use grafeo_engine::GrafeoDB;
/// Wraps three `f32` components into a `Value::Vector`.
fn vec3(x: f32, y: f32, z: f32) -> Value {
    let components: Vec<f32> = Vec::from([x, y, z]);
    Value::Vector(components.into())
}
/// Builds an in-memory database with six `Doc` nodes carrying 3-d `emb` vectors.
///
/// Five docs get an `Int64` `user_id` (two for user 1, two for user 2, one for
/// user 3). The sixth doc has no `user_id` at all. A property index on `user_id`
/// and a cosine vector index on `Doc.emb` are created after the data is loaded.
fn setup_db() -> GrafeoDB {
    // (embedding, user_id); `None` leaves the `user_id` property unset.
    let docs: [([f32; 3], Option<i64>); 6] = [
        ([1.0, 0.0, 0.0], Some(1)),
        ([0.95, 0.05, 0.0], Some(1)),
        ([0.0, 1.0, 0.0], Some(2)),
        ([0.05, 0.95, 0.0], Some(2)),
        ([0.0, 0.0, 1.0], Some(3)),
        ([0.5, 0.5, 0.0], None),
    ];

    let db = GrafeoDB::new_in_memory();
    for ([x, y, z], user_id) in docs {
        let node = db.create_node(&["Doc"]);
        db.set_node_property(node, "emb", vec3(x, y, z));
        if let Some(uid) = user_id {
            db.set_node_property(node, "user_id", Value::Int64(uid));
        }
    }

    db.create_property_index("user_id");
    db.create_vector_index("Doc", "emb", Some(3), Some("cosine"), None, None)
        .expect("create index");
    db
}
/// Filtering on the indexed `user_id` property must only return that user's docs.
#[test]
fn test_filtered_vector_search_by_user_id() {
    use grafeo_common::types::PropertyKey;

    let db = setup_db();
    let filters = HashMap::from([("user_id".to_string(), Value::Int64(2))]);
    let hits = db
        .vector_search("Doc", "emb", &[1.0, 0.0, 0.0], 5, None, Some(&filters))
        .expect("filtered search");

    assert!(!hits.is_empty());
    // setup_db gives user_id = 2 to exactly two docs.
    assert!(hits.len() <= 2);

    let user_key = PropertyKey::new("user_id");
    for &(node_id, _) in hits.iter() {
        let node = db.get_node(node_id).expect("node exists");
        let owner = node.properties.get(&user_key).expect("has user_id");
        assert_eq!(owner, &Value::Int64(2), "result should be user_id=2");
    }
}
/// Passing `None` for the filters behaves like a plain unfiltered search.
#[test]
fn test_filtered_search_without_filters_returns_all() {
    let db = setup_db();
    let query: [f32; 3] = [0.5, 0.5, 0.0];
    let found = db
        .vector_search("Doc", "emb", &query, 10, None, None)
        .expect("unfiltered search")
        .len();
    assert_eq!(found, 6, "should find all 6 Doc nodes");
}
/// An empty filter map must not exclude any node.
#[test]
fn test_filtered_search_empty_filters_returns_all() {
    let db = setup_db();
    let no_filters = HashMap::<String, Value>::new();
    let query: [f32; 3] = [0.5, 0.5, 0.0];
    let found = db
        .vector_search("Doc", "emb", &query, 10, None, Some(&no_filters))
        .expect("empty filter search");
    assert_eq!(found.len(), 6, "empty filters should return all nodes");
}
/// A filter that no node satisfies yields an empty result set rather than an error.
#[test]
fn test_filtered_search_no_matches() {
    let db = setup_db();
    // setup_db never assigns user_id = 999.
    let filters = HashMap::from([("user_id".to_string(), Value::Int64(999))]);
    let hits = db
        .vector_search("Doc", "emb", &[1.0, 0.0, 0.0], 5, None, Some(&filters))
        .expect("filtered search");
    assert!(hits.is_empty(), "no matching nodes should return empty");
}
/// Batch search must apply the same filter to every query in the batch.
///
/// Each per-query result list must be non-empty. Otherwise the property check
/// below would pass vacuously and a broken batch filter could go unnoticed.
#[test]
fn test_batch_vector_search_with_filters() {
    use grafeo_common::types::PropertyKey;

    let db = setup_db();
    let filters: HashMap<String, Value> = [("user_id".to_string(), Value::Int64(1))]
        .into_iter()
        .collect();
    let queries = vec![vec![1.0f32, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
    let results = db
        .batch_vector_search("Doc", "emb", &queries, 5, None, Some(&filters))
        .expect("batch filtered search");
    assert_eq!(results.len(), 2);

    let user_key = PropertyKey::new("user_id");
    for (query_idx, query_results) in results.iter().enumerate() {
        assert!(
            !query_results.is_empty(),
            "query {query_idx} should return at least one user_id=1 node"
        );
        // setup_db gives user_id = 1 to exactly two docs.
        assert!(
            query_results.len() <= 2,
            "query {query_idx} returned more nodes than have user_id=1"
        );
        for (id, _) in query_results {
            let node = db.get_node(*id).expect("node exists");
            let uid = node.properties.get(&user_key).expect("has user_id");
            assert_eq!(uid, &Value::Int64(1));
        }
    }
}
/// MMR search honours property filters just like plain vector search.
#[test]
fn test_mmr_search_with_filters() {
    use grafeo_common::types::PropertyKey;

    let db = setup_db();
    let filters = HashMap::from([("user_id".to_string(), Value::Int64(2))]);
    let query = [0.0, 1.0, 0.0];
    let picked = db
        .mmr_search("Doc", "emb", &query, 2, None, None, None, Some(&filters))
        .expect("mmr filtered search");

    assert!(!picked.is_empty());
    assert!(picked.len() <= 2);

    let user_key = PropertyKey::new("user_id");
    for &(node_id, _) in picked.iter() {
        let node = db.get_node(node_id).expect("node exists");
        let owner = node.properties.get(&user_key).expect("has user_id");
        assert_eq!(owner, &Value::Int64(2));
    }
}
/// Filtering on a property that has no property index (`category`) must still
/// restrict results to exactly the matching nodes.
///
/// The previous `len() <= 2` check also passed when the filter returned nothing.
/// The test now pins the exact count and verifies the property on every hit.
#[test]
fn test_filtered_search_non_indexed_property() {
    use grafeo_common::types::PropertyKey;

    let db = setup_db();
    let results_all = db
        .vector_search("Doc", "emb", &[1.0, 0.0, 0.0], 6, None, None)
        .expect("find all");
    // Guard against a short list, which would silently tag fewer nodes.
    assert!(results_all.len() >= 2, "need at least 2 nodes to tag");
    for (id, _) in results_all.iter().take(2) {
        db.set_node_property(*id, "category", Value::String("science".into()));
    }

    let filters: HashMap<String, Value> =
        [("category".to_string(), Value::String("science".into()))]
            .into_iter()
            .collect();
    // k = 10 exceeds the 6 Doc nodes, so every tagged node is reachable.
    let results = db
        .vector_search("Doc", "emb", &[1.0, 0.0, 0.0], 10, None, Some(&filters))
        .expect("filtered search on non-indexed property");
    assert_eq!(results.len(), 2, "exactly 2 nodes have category=science");

    let category_key = PropertyKey::new("category");
    for (id, _) in &results {
        let node = db.get_node(*id).expect("node exists");
        let category = node.properties.get(&category_key).expect("has category");
        assert_eq!(
            category,
            &Value::String("science".into()),
            "result should be category=science"
        );
    }
}
/// Without explicit dimensions and with no stored vectors, index creation must fail.
#[test]
fn test_create_vector_index_no_dims_no_data_errors() {
    let db = GrafeoDB::new_in_memory();
    let (dimensions, metric) = (None, None);
    let outcome = db.create_vector_index("Doc", "emb", dimensions, metric, None, None);
    assert!(outcome.is_err(), "should error without dimensions or data");
}
/// Explicit dimensions allow creating an index before any data exists. A vector
/// written afterwards must then be found by search.
#[test]
fn test_create_vector_index_with_dims_no_data_succeeds() {
    let db = GrafeoDB::new_in_memory();
    db.create_vector_index("Doc", "emb", Some(4), Some("cosine"), None, None)
        .expect("should create empty index with explicit dimensions");

    let node = db.create_node(&["Doc"]);
    let embedding = Value::Vector(vec![1.0, 0.0, 0.0, 0.0].into());
    db.set_node_property(node, "emb", embedding);

    let hits = db
        .vector_search("Doc", "emb", &[1.0, 0.0, 0.0, 0.0], 5, None, None)
        .expect("search should work");
    assert_eq!(hits.len(), 1, "should find the one auto-inserted node");
}
/// Index-first workflow: create the vector index on an empty DB, insert memories
/// for two users, then filter by a non-indexed string `user_id`.
#[test]
fn test_grafeo_memory_pattern() {
    use grafeo_common::types::PropertyKey;

    let db = GrafeoDB::new_in_memory();
    db.create_vector_index("Memory", "embedding", Some(4), Some("cosine"), None, None)
        .expect("create index");

    // Ten memories: i = 0..=4 belong to "alix", i = 5..=9 belong to "gus".
    for i in 0..10 {
        let owner = match i {
            0..=4 => "alix",
            _ => "gus",
        };
        let props = vec![
            (
                PropertyKey::new("text"),
                Value::String(format!("memory {i}").into()),
            ),
            (PropertyKey::new("user_id"), Value::String(owner.into())),
        ];
        let id = db.create_node_with_props(&["Memory"], props);
        let t = i as f32 / 10.0;
        let embedding = vec![t, 1.0 - t, 0.1, 0.1];
        db.set_node_property(id, "embedding", Value::Vector(embedding.into()));
    }

    let query = [0.5, 0.5, 0.1, 0.1];

    let unfiltered = db
        .vector_search("Memory", "embedding", &query, 10, None, None)
        .expect("unfiltered search");
    assert_eq!(unfiltered.len(), 10, "should find all 10 Memory nodes");

    let filters = HashMap::from([("user_id".to_string(), Value::String("alix".into()))]);
    let hits = db
        .vector_search("Memory", "embedding", &query, 10, None, Some(&filters))
        .expect("filtered search should not error");
    assert_eq!(hits.len(), 5, "should find 5 alix nodes");

    let user_key = PropertyKey::new("user_id");
    for &(node_id, _) in hits.iter() {
        let node = db.get_node(node_id).expect("node exists");
        let owner = node.properties.get(&user_key).expect("has user_id");
        assert_eq!(owner, &Value::String("alix".into()));
    }
}
/// Builds an in-memory DB with ten `Item` nodes (rank 0..=9) for operator-filter tests.
///
/// Each item has `score = rank * 0.1`, an `Int64` `rank`, a `category` cycling
/// preference/fact/event by `rank % 3`, a `text` of the form
/// "item number {rank} is great", and a 3-d `emb` vector. The cosine vector index
/// is created first, so the vectors are inserted after it exists.
fn setup_operator_db() -> GrafeoDB {
    use grafeo_common::types::PropertyKey;

    const CATEGORIES: [&str; 3] = ["preference", "fact", "event"];

    let db = GrafeoDB::new_in_memory();
    db.create_vector_index("Item", "emb", Some(3), Some("cosine"), None, None)
        .expect("create index");

    for rank in 0..10_i64 {
        let category = CATEGORIES[(rank % 3) as usize];
        let props = vec![
            (PropertyKey::new("score"), Value::Float64(rank as f64 * 0.1)),
            (PropertyKey::new("rank"), Value::Int64(rank)),
            (PropertyKey::new("category"), Value::String(category.into())),
            (
                PropertyKey::new("text"),
                Value::String(format!("item number {rank} is great").into()),
            ),
        ];
        let id = db.create_node_with_props(&["Item"], props);
        let t = rank as f32 / 10.0;
        db.set_node_property(id, "emb", Value::Vector(vec![t, 1.0 - t, 0.5].into()));
    }
    db
}
/// Builds an operator-style filter value such as `{"$gt": 5}` from
/// `(operator, operand)` pairs. A repeated operator keeps its last operand.
fn op_filter(ops: Vec<(&str, Value)>) -> Value {
    use grafeo_common::types::PropertyKey;
    use std::collections::BTreeMap;
    use std::sync::Arc;

    let mut operators = BTreeMap::new();
    for (op, operand) in ops {
        operators.insert(PropertyKey::new(op), operand);
    }
    Value::Map(Arc::new(operators))
}
/// `$gt` and `$lt` on the integer `rank` property (ranks are 0..=9).
#[test]
fn test_filter_gt_lt() {
    let db = setup_operator_db();
    let count_matches = |filters: HashMap<String, Value>, context: &str| {
        db.vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
            .expect(context)
            .len()
    };

    let above_five = HashMap::from([(
        "rank".to_string(),
        op_filter(vec![("$gt", Value::Int64(5))]),
    )]);
    assert_eq!(
        count_matches(above_five, "gt filter"),
        4,
        "rank > 5 should match 4 nodes"
    );

    let below_three = HashMap::from([(
        "rank".to_string(),
        op_filter(vec![("$lt", Value::Int64(3))]),
    )]);
    assert_eq!(
        count_matches(below_three, "lt filter"),
        3,
        "rank < 3 should match 3 nodes"
    );
}
/// `$gte` and `$lte` combined in one operator map form an inclusive range.
#[test]
fn test_filter_gte_lte() {
    let db = setup_operator_db();
    let inclusive_range = op_filter(vec![("$gte", Value::Int64(3)), ("$lte", Value::Int64(6))]);
    let filters = HashMap::from([("rank".to_string(), inclusive_range)]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("gte/lte filter");
    assert_eq!(hits.len(), 4, "rank in [3, 6] should match 4 nodes");
}
/// `$in` matches when the property equals any element of the list.
#[test]
fn test_filter_in() {
    let db = setup_operator_db();
    let allowed = Value::List(
        vec![
            Value::String("preference".into()),
            Value::String("fact".into()),
        ]
        .into(),
    );
    let filters = HashMap::from([("category".to_string(), op_filter(vec![("$in", allowed)]))]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("in filter");
    assert_eq!(
        hits.len(),
        7,
        "category in [preference, fact] should match 7 nodes"
    );
}
/// `$nin` excludes nodes whose property appears in the list.
#[test]
fn test_filter_nin() {
    let db = setup_operator_db();
    let excluded = Value::List(vec![Value::String("event".into())].into());
    let filters = HashMap::from([("category".to_string(), op_filter(vec![("$nin", excluded)]))]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("nin filter");
    assert_eq!(hits.len(), 7, "category not in [event] should match 7");
}
/// `$contains` performs substring matching on string properties.
#[test]
fn test_filter_contains() {
    let db = setup_operator_db();
    let needle = Value::String("number 5".into());
    let filters = HashMap::from([("text".to_string(), op_filter(vec![("$contains", needle)]))]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("contains filter");
    assert_eq!(hits.len(), 1, "text contains 'number 5' should match 1");
}
/// `$ne` keeps every node whose property differs from the operand.
#[test]
fn test_filter_ne() {
    let db = setup_operator_db();
    let not_event = op_filter(vec![("$ne", Value::String("event".into()))]);
    let filters = HashMap::from([("category".to_string(), not_event)]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("ne filter");
    assert_eq!(hits.len(), 7, "category != event should match 7");
}
/// Plain equality and operator filters on different keys are ANDed together.
#[test]
fn test_filter_mixed_equality_and_operators() {
    let db = setup_operator_db();
    let equality = ("category".to_string(), Value::String("preference".into()));
    let operator = (
        "rank".to_string(),
        op_filter(vec![("$gt", Value::Int64(3))]),
    );
    let filters = HashMap::from([equality, operator]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("mixed filter");
    assert_eq!(
        hits.len(),
        2,
        "preference AND rank > 3 should match items 6 and 9"
    );
}
/// An `Int64` operand can be compared against the `Float64` `score` property.
#[test]
fn test_filter_cross_type_numeric_comparison() {
    let db = setup_operator_db();
    let positive = op_filter(vec![("$gt", Value::Int64(0))]);
    let filters = HashMap::from([("score".to_string(), positive)]);
    let hits = db
        .vector_search("Item", "emb", &[0.5, 0.5, 0.5], 10, None, Some(&filters))
        .expect("cross-type filter");
    assert_eq!(
        hits.len(),
        9,
        "score > 0 (cross-type) should match 9 nodes (all except score=0.0)"
    );
}
/// Unfiltered results come back sorted by non-decreasing distance.
#[test]
fn test_vector_search_results_ordered_by_distance() {
    let db = setup_db();
    let results = db
        .vector_search("Doc", "emb", &[1.0, 0.0, 0.0], 6, None, None)
        .expect("unfiltered search");
    assert!(results.len() >= 2, "should return multiple results");

    for (prev, next) in results.iter().zip(results.iter().skip(1)) {
        let (dist_a, dist_b) = (prev.1, next.1);
        assert!(
            dist_a <= dist_b,
            "results should be ordered by distance: {dist_a} <= {dist_b}"
        );
    }
}
/// Filtered results come back sorted by non-decreasing distance.
///
/// The ordering check needs at least two results. With a single result
/// `windows(2)` is empty and the loop below would assert nothing.
#[test]
fn test_filtered_vector_search_results_ordered_by_distance() {
    let db = setup_db();
    let filters: HashMap<String, Value> = [("user_id".to_string(), Value::Int64(1))]
        .into_iter()
        .collect();
    let results = db
        .vector_search("Doc", "emb", &[0.5, 0.5, 0.0], 5, None, Some(&filters))
        .expect("filtered search");
    // setup_db gives user_id = 1 to two docs and k = 5 can hold both.
    assert!(
        results.len() >= 2,
        "both user_id=1 docs should be returned so ordering is actually checked"
    );
    for window in results.windows(2) {
        let (_, dist_a) = window[0];
        let (_, dist_b) = window[1];
        assert!(
            dist_a <= dist_b,
            "filtered results should be ordered by distance: {dist_a} <= {dist_b}"
        );
    }
}