use luci::index::Index;
use luci::mapping::{FieldType, Mapping};
use luci::search::expression::parse_search;
use serde_json::json;
/// Returns a scratch directory path for the test identified by `name`.
///
/// The path sits under the system temp directory and embeds the current
/// process id together with `name`. Whatever already exists at that path is
/// removed on a best-effort basis; the directory itself is not created here.
fn test_dir(name: &str) -> std::path::PathBuf {
    let leaf = format!("luci_knn_query_{}_{name}", std::process::id());
    let dir = std::env::temp_dir().join(leaf);
    // Removal failures (e.g. nothing to remove) are deliberately ignored.
    std::fs::remove_dir_all(&dir).ok();
    dir
}
/// Recursively deletes `path`, ignoring any error (including a missing path).
fn cleanup(path: &std::path::Path) {
    // Best-effort teardown: a failed removal must not fail the calling test.
    std::fs::remove_dir_all(path).ok();
}
/// Creates a fresh index in a scratch directory and loads five fixture docs.
///
/// The mapping declares a text `title`, a keyword `tag`, and a 4-dimensional
/// dense-vector `embedding`. Returns the scratch path (to hand to `cleanup`
/// afterwards) together with the populated index. Panics if index creation or
/// bulk indexing fails.
fn build_index(name: &str) -> (std::path::PathBuf, Index) {
    let path = test_dir(name);
    let mapping = Mapping::builder()
        .field("title", FieldType::Text)
        .field("tag", FieldType::Keyword)
        .field("embedding", FieldType::dense_vector(4))
        .build();
    let index = Index::create_with_mapping(&path, mapping).unwrap();

    let fixtures = vec![
        json!({"title": "search engine design", "tag": "tech", "embedding": [0.9, 0.1, 0.0, 0.0]}),
        json!({"title": "search algorithms", "tag": "tech", "embedding": [0.1, 0.9, 0.0, 0.0]}),
        json!({"title": "cute cats", "tag": "animal", "embedding": [0.0, 0.0, 0.9, 0.1]}),
        json!({"title": "search optimization", "tag": "tech", "embedding": [0.0, 0.0, 0.1, 0.9]}),
        json!({"title": "happy dog", "tag": "animal", "embedding": [0.0, 0.0, 0.0, 0.1]}),
    ];
    index.bulk(fixtures).unwrap();

    (path, index)
}
/// Parses `query` with the given `size` and runs it against `index`.
///
/// Panics if either parsing or execution fails, so callers only deal with
/// successful results.
fn search(
    index: &Index,
    query: serde_json::Value,
    size: usize,
) -> luci::search::results::SearchResults {
    let parsed = parse_search(query, size).unwrap();
    let outcome = index.search(&parsed);
    outcome.unwrap()
}
/// A bare kNN query with `k = 3` yields three hits, the first being doc 0.
#[test]
fn knn_query_standalone() {
    let (path, index) = build_index("standalone");

    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 3
        }}
    });
    let results = search(&index, request, 10);

    assert_eq!(results.len(), 3);
    let top_id = results.hit(0).unwrap().doc_id().as_u32();
    assert_eq!(top_id, 0);

    cleanup(&path);
}
/// Adding a `threshold` to a kNN query drops some, but not all, of the hits.
///
/// The same query is run twice: once without a threshold (expects all five
/// fixture docs) and once with `threshold = 0.6`.
#[test]
fn knn_query_with_threshold() {
    let (path, index) = build_index("threshold");

    let unfiltered_request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 5
        }}
    });
    let all = search(&index, unfiltered_request, 10);
    assert_eq!(all.len(), 5);

    let thresholded_request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 5,
            "threshold": 0.6
        }}
    });
    let filtered = search(&index, thresholded_request, 10);

    assert!(
        all.len() > filtered.len(),
        "threshold should reduce results"
    );
    assert!(!filtered.is_empty(), "closest doc should pass threshold");

    cleanup(&path);
}
/// A `threshold` of 0.999 is expected to leave no hits for this query vector.
#[test]
fn knn_query_threshold_excludes_all() {
    let (path, index) = build_index("threshold_all");

    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 5,
            "threshold": 0.999
        }}
    });
    let results = search(&index, request, 10);

    assert_eq!(results.len(), 0);

    cleanup(&path);
}
/// kNN as one of two `bool.should` clauses next to a `match` on `title`.
///
/// Expects a non-empty result whose top hit is doc 0.
#[test]
fn knn_query_in_bool_should() {
    let (path, index) = build_index("bool_should");

    let request = json!({
        "query": {"bool": {"should": [
            {"match": {"title": "search engine"}},
            {"knn": {
                "field": "embedding",
                "query_vector": [1.0, 0.0, 0.0, 0.0],
                "k": 3
            }}
        ]}}
    });
    let results = search(&index, request, 10);

    assert!(!results.is_empty());
    let top_id = results.hit(0).unwrap().doc_id().as_u32();
    assert_eq!(top_id, 0);

    cleanup(&path);
}
/// kNN combined with a `match` clause under `bool.must`.
///
/// Every returned hit must have a `title` containing "search", and at least
/// one hit must come back: three fixture titles contain "search" and the kNN
/// clause asks for `k = 5` over a five-document index.
#[test]
fn knn_query_in_bool_must() {
    let (path, index) = build_index("bool_must");

    let request = json!({
        "query": {"bool": {"must": [
            {"match": {"title": "search"}},
            {"knn": {
                "field": "embedding",
                "query_vector": [1.0, 0.0, 0.0, 0.0],
                "k": 5
            }}
        ]}}
    });
    let results = search(&index, request, 10);

    // Without this guard the per-hit loop below passes vacuously when the
    // conjunction returns nothing at all.
    assert!(
        !results.is_empty(),
        "must conjunction should return at least one doc matching 'search'"
    );

    for hit in results.iter() {
        let source = hit.source().unwrap();
        let title = source["title"].as_str().unwrap();
        assert!(
            title.contains("search"),
            "must conjunction: doc should match 'search', got '{title}'"
        );
    }

    cleanup(&path);
}
/// kNN used as a `bool.filter` alongside a `must` match on `title`.
///
/// The filter asks for `k = 2`, so at most two hits may survive and each must
/// be doc 0 or doc 1. At least one hit is required: doc 0's title contains
/// "search" and its embedding is the fixture closest to the query vector.
#[test]
fn knn_query_in_bool_filter() {
    let (path, index) = build_index("bool_filter");

    let request = json!({
        "query": {"bool": {
            "must": [{"match": {"title": "search"}}],
            "filter": [{"knn": {
                "field": "embedding",
                "query_vector": [1.0, 0.0, 0.0, 0.0],
                "k": 2
            }}]
        }}
    });
    let results = search(&index, request, 10);

    // Both checks below hold trivially for an empty result set, so rule that
    // case out explicitly.
    assert!(
        !results.is_empty(),
        "kNN filter should keep at least the nearest 'search' doc"
    );
    assert!(results.len() <= 2, "filter should restrict to kNN top-2");

    for hit in results.iter() {
        let id = hit.doc_id().as_u32();
        assert!(id == 0 || id == 1, "expected doc 0 or 1, got {id}");
    }

    cleanup(&path);
}
/// A kNN clause that omits `num_candidates` must still parse.
///
/// Only parsing is exercised; no index is built and the parsed expression is
/// not inspected further.
#[test]
fn knn_query_num_candidates_default() {
    let request = json!({
        "query": {"knn": {
            "field": "f",
            "query_vector": [1.0],
            "k": 10
        }}
    });
    let _parsed = parse_search(request, 10).unwrap();
}
/// Hits from a kNN query must be ordered by non-increasing score.
///
/// The query asks for `k = 5` over the five fixture docs. At least two hits
/// are required so that at least one adjacent pair is actually compared.
#[test]
fn knn_query_scores_descending() {
    let (path, index) = build_index("scores_desc");

    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 5
        }}
    });
    let results = search(&index, request, 10);

    // With zero or one hit the ordering check below would never run.
    assert!(
        results.len() >= 2,
        "need at least two hits to verify ordering, got {}",
        results.len()
    );

    let scores: Vec<_> = results.iter().map(|hit| hit.score()).collect();
    for (i, pair) in scores.windows(2).enumerate() {
        let (a, b) = (pair[0], pair[1]);
        assert!(
            a >= b,
            "scores should be descending: hit[{i}]={a} < hit[{}]={b}",
            i + 1
        );
    }

    cleanup(&path);
}
/// Every score returned by a kNN query must be non-negative.
///
/// The query asks for `k = 5` over the five fixture docs, so an empty result
/// is treated as a failure rather than a trivially passing run.
#[test]
fn knn_query_score_range() {
    let (path, index) = build_index("score_range");

    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 5
        }}
    });
    let results = search(&index, request, 10);

    // The per-hit loop alone would pass vacuously with no hits.
    assert!(
        !results.is_empty(),
        "kNN query should return hits to range-check"
    );

    for hit in results.iter() {
        let s = hit.score();
        assert!(s >= 0.0, "score should be non-negative, got {s}");
    }

    cleanup(&path);
}
/// The top hit of a kNN query can be explained, with a positive `value`.
#[test]
fn knn_query_explain() {
    let (path, index) = build_index("explain");

    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 3
        }}
    });
    let results = search(&index, request, 10);

    let top = results.hit(0).unwrap();
    // First `expect` covers the call failing, second covers "no explanation".
    let expl = top
        .explain()
        .expect("explain should not error")
        .expect("kNN query should produce an explanation");
    assert!(expl.value > 0.0, "explanation score should be > 0");

    cleanup(&path);
}
/// Parsing must fail when `query_vector` contains a non-numeric element.
#[test]
fn knn_query_invalid_vector() {
    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1, "bad", 3]
        }}
    });
    let outcome = parse_search(request, 10);

    assert!(
        outcome.is_err(),
        "non-numeric vector elements should be rejected"
    );
}
/// Parsing must fail when the kNN clause specifies `k = 0`.
#[test]
fn knn_query_zero_k() {
    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 0
        }}
    });
    let outcome = parse_search(request, 10);

    assert!(outcome.is_err(), "k=0 should be rejected");
}
/// `bool.must` of a term clause and a kNN clause returns only docs that each
/// clause returns on its own.
///
/// Twenty docs are indexed; the first five carry `tag = "target"`, the rest
/// `tag = "other"`. The kNN-only, term-only and combined queries are run, and
/// every combined hit must appear in both individual result sets. The combined
/// result must also be non-empty.
#[test]
fn knn_query_bool_must_correctness() {
    let path = test_dir("bool_must_correct");
    let mapping = Mapping::builder()
        .field("tag", FieldType::Keyword)
        .field("embedding", FieldType::dense_vector(4))
        .build();
    let index = Index::create_with_mapping(&path, mapping).unwrap();

    // Embeddings are derived from an angle that grows with the doc position.
    let docs: Vec<_> = (0..20)
        .map(|i| {
            let tag = if i < 5 { "target" } else { "other" };
            let angle = (i as f32) * 0.3;
            let v = [
                angle.cos(),
                angle.sin(),
                (angle * 0.5).cos(),
                (angle * 0.5).sin(),
            ];
            json!({"tag": tag, "embedding": v})
        })
        .collect();
    index.bulk(docs).unwrap();

    let knn_request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 1.0, 0.0],
            "k": 10
        }}
    });
    let knn_only = search(&index, knn_request, 10);

    let term_request = json!({
        "query": {"term": {"tag": "target"}}
    });
    let term_only = search(&index, term_request, 10);

    let combined_request = json!({
        "query": {"bool": {"must": [
            {"term": {"tag": "target"}},
            {"knn": {
                "field": "embedding",
                "query_vector": [1.0, 0.0, 1.0, 0.0],
                "k": 10
            }}
        ]}}
    });
    let conjunction = search(&index, combined_request, 10);

    let knn_ids: std::collections::HashSet<u32> = knn_only
        .iter()
        .map(|hit| hit.doc_id().as_u32())
        .collect();
    let term_ids: std::collections::HashSet<u32> = term_only
        .iter()
        .map(|hit| hit.doc_id().as_u32())
        .collect();

    for hit in conjunction.iter() {
        let id = hit.doc_id().as_u32();
        assert!(
            knn_ids.contains(&id),
            "conjunction doc {id} not in kNN results"
        );
        assert!(
            term_ids.contains(&id),
            "conjunction doc {id} not in term results"
        );
    }
    assert!(
        !conjunction.is_empty(),
        "conjunction should find at least one doc matching both conditions"
    );

    cleanup(&path);
}
/// A 2-element query vector against the 4-dimensional `embedding` field must
/// fail at search time, with an error mentioning both the dimension count and
/// the field name.
#[test]
fn knn_query_dimension_mismatch() {
    let (path, index) = build_index("dim_mismatch");

    let request = json!({
        "query": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 2.0],
            "k": 5
        }}
    });
    // Parsing is expected to succeed; the mismatch surfaces on execution.
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("dimension mismatch must error, not return empty"),
    };
    let msg = failure.to_string();
    let mentions_dims = msg.contains("2 dimensions");
    let mentions_field = msg.contains("embedding");
    assert!(
        mentions_dims && mentions_field,
        "error message should name the dim mismatch: {msg}"
    );

    cleanup(&path);
}
/// kNN against the text field `title` must fail at search time, with an error
/// mentioning the field name and `dense_vector`.
#[test]
fn knn_query_non_vector_field_errors() {
    let (path, index) = build_index("non_vector_field");

    let request = json!({
        "query": {"knn": {
            "field": "title",
            "query_vector": [1.0, 2.0, 3.0, 4.0],
            "k": 5
        }}
    });
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("knn on a non-dense_vector field must error, not return empty"),
    };
    let msg = failure.to_string();
    let mentions_field = msg.contains("title");
    let mentions_type = msg.contains("dense_vector");
    assert!(
        mentions_field && mentions_type,
        "error should name the field and that it is not a dense_vector: {msg}"
    );

    cleanup(&path);
}
/// kNN against a field absent from the mapping (`nope`) must fail at search
/// time, with an error mentioning the field name and the word "unknown".
#[test]
fn knn_query_unknown_field_errors() {
    let (path, index) = build_index("unknown_field");

    let request = json!({
        "query": {"knn": {
            "field": "nope",
            "query_vector": [1.0, 2.0, 3.0, 4.0],
            "k": 5
        }}
    });
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("knn on an unknown field must error, not return empty"),
    };
    let msg = failure.to_string();
    let mentions_field = msg.contains("nope");
    let mentions_unknown = msg.contains("unknown");
    assert!(
        mentions_field && mentions_unknown,
        "error should name the unknown field: {msg}"
    );

    cleanup(&path);
}
/// Creating an index whose mapping contains `dense_vector(0)` must fail.
///
/// The error text is expected to mention the field `v` and either "dims" or
/// "dimension".
#[test]
fn knn_dims_zero_builder_rejected() {
    let path = test_dir("dims_zero");
    let mapping = Mapping::builder()
        .field("v", FieldType::dense_vector(0))
        .build();

    let failure = match Index::create_with_mapping(&path, mapping) {
        Err(failure) => failure,
        Ok(_) => panic!("dense_vector(0) must be rejected at mapping validation"),
    };
    let msg = failure.to_string();
    // NOTE(review): `contains("v")` is satisfied by any message containing the
    // letter "v", so the field-name half of this check is weak — consider a
    // more specific pattern once the error format is confirmed.
    let mentions_field = msg.contains("v");
    let mentions_dims = msg.contains("dims") || msg.contains("dimension");
    assert!(
        mentions_field && mentions_dims,
        "error should name the field and dims: {msg}"
    );

    cleanup(&path);
}
/// A kNN clause on an unknown field inside a `filter` aggregation must make
/// the search fail, with an error mentioning the field name and "unknown".
#[test]
fn knn_bad_field_in_agg_filter_errors() {
    let (path, index) = build_index("agg_filter_bad_field");

    let request = json!({
        "query": {"match_all": {}},
        "aggs": {"f": {"filter": {"knn": {
            "field": "nope",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 2
        }}}}
    });
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("bad-field knn in a filter agg must error, not return empty buckets"),
    };
    let msg = failure.to_string();
    let mentions_field = msg.contains("nope");
    let mentions_unknown = msg.contains("unknown");
    assert!(
        mentions_field && mentions_unknown,
        "error should name the unknown field: {msg}"
    );

    cleanup(&path);
}
/// A kNN clause on a valid field inside a `filter` aggregation counts exactly
/// `k` documents.
///
/// Four docs are indexed and the aggregation asks for `k = 2`; the first
/// bucket's `doc_count` must be 2.
#[test]
fn knn_valid_field_in_agg_filter_works() {
    let path = test_dir("agg_filter_valid");
    let mapping = Mapping::builder()
        .field("embedding", FieldType::dense_vector(4))
        .build();
    let index = Index::create_with_mapping(&path, mapping).unwrap();

    let docs = vec![
        json!({"embedding": [1.0, 0.0, 0.0, 0.0]}),
        json!({"embedding": [0.9, 0.1, 0.0, 0.0]}),
        json!({"embedding": [0.0, 1.0, 0.0, 0.0]}),
        json!({"embedding": [0.0, 0.0, 1.0, 0.0]}),
    ];
    index.bulk(docs).unwrap();

    let request = json!({
        "query": {"match_all": {}},
        "aggs": {"f": {"filter": {"knn": {
            "field": "embedding",
            "query_vector": [1.0, 0.0, 0.0, 0.0],
            "k": 2,
            "num_candidates": 10
        }}}}
    });
    let expr = parse_search(request, 10).unwrap();
    let results = index.search(&expr).unwrap();

    let agg = results.aggregations()["f"].to_json();
    let counted = agg["buckets"][0]["doc_count"].as_u64().unwrap();
    assert_eq!(
        counted,
        2,
        "filter-agg knn top-2 should count exactly docs {{0,1}}: {agg}"
    );

    cleanup(&path);
}
/// A kNN clause on an unknown field inside a `filters` aggregation must make
/// the search fail, even when a sibling filter (`match_all`) is valid.
#[test]
fn knn_bad_field_in_filters_agg_errors() {
    let (path, index) = build_index("filters_agg_bad_field");

    let request = json!({
        "query": {"match_all": {}},
        "aggs": {"f": {"filters": {"filters": {
            "a": {"knn": {"field": "nope", "query_vector": [1.0, 0.0, 0.0, 0.0], "k": 2}},
            "b": {"match_all": {}}
        }}}}
    });
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("bad-field knn in a filters agg must error, not return empty buckets"),
    };
    let msg = failure.to_string();
    let mentions_field = msg.contains("nope");
    let mentions_unknown = msg.contains("unknown");
    assert!(
        mentions_field && mentions_unknown,
        "error should name the unknown field: {msg}"
    );

    cleanup(&path);
}
/// A `filters` aggregation that carries sub-aggregations must be rejected at
/// search time, with an error mentioning "filters" and "not yet supported".
#[test]
fn filters_agg_sub_aggs_refused() {
    let (path, index) = build_index("filters_agg_sub_aggs");

    let request = json!({
        "query": {"match_all": {}},
        "aggs": {"f": {
            "filters": {"filters": {"a": {"match_all": {}}}},
            "aggs": {"by_tag": {"terms": {"field": "tag"}}}
        }}
    });
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("filters agg with sub_aggs must be refused, not silently dropped"),
    };
    let msg = failure.to_string();
    let mentions_agg = msg.contains("filters");
    let mentions_unsupported = msg.contains("not yet supported");
    assert!(
        mentions_agg && mentions_unsupported,
        "error should explain filters sub-aggs are unsupported: {msg}"
    );

    cleanup(&path);
}
/// A kNN clause on an unknown field inside a `nested` query with `inner_hits`
/// must make the search fail, with an error mentioning the field name and
/// "unknown".
#[test]
fn knn_bad_field_in_nested_inner_hits_errors() {
    let (path, index) = build_index("nested_inner_hits_bad_field");

    let request = json!({
        "query": {"nested": {
            "path": "items",
            "query": {"knn": {
                "field": "nope",
                "query_vector": [1.0, 0.0, 0.0, 0.0],
                "k": 2
            }},
            "inner_hits": {"name": "matched"}
        }}
    });
    let expr = parse_search(request, 10).unwrap();

    let failure = match index.search(&expr) {
        Err(failure) => failure,
        Ok(_) => panic!("bad-field knn under nested inner_hits must error, not silently empty"),
    };
    let msg = failure.to_string();
    let mentions_field = msg.contains("nope");
    let mentions_unknown = msg.contains("unknown");
    assert!(
        mentions_field && mentions_unknown,
        "error should name the unknown field: {msg}"
    );

    cleanup(&path);
}
/// After `force_merge(1)`, a kNN query asking for all `n` docs returns all `n`.
///
/// The memory budget is set to 1 before bulk-indexing 16 docs.
/// NOTE(review): presumably the tiny budget forces several segments so the
/// merge has real work to do — confirm against `set_memory_budget` semantics.
#[test]
fn knn_recall_survives_segment_merge() {
    let path = test_dir("merge_recall");
    let mapping = Mapping::builder()
        .field("embedding", FieldType::dense_vector(4))
        .build();
    let index = Index::create_with_mapping(&path, mapping).unwrap();
    index.set_memory_budget(1);

    let n: usize = 16;
    let docs: Vec<_> = (0..n)
        .map(|i| json!({"embedding": [i as f32 + 1.0, 1.0, 0.0, 0.0]}))
        .collect();
    index.bulk(docs).unwrap();
    index.force_merge(1).unwrap();

    let request = json!({"query": {"knn": {
        "field": "embedding",
        "query_vector": [1.0, 1.0, 0.0, 0.0],
        "k": n,
        "num_candidates": 100
    }}});
    let results = search(&index, request, n);

    let found = results.len();
    assert_eq!(
        found,
        n,
        "merge left dangling resolver entries: kNN returned {} of {n} hits",
        found
    );

    cleanup(&path);
}