use std::collections::HashSet;
use crate::dsl::{Document, SchemaBuilder};
use crate::index::{Index, IndexConfig, IndexWriter};
use crate::query::{
BooleanQuery, PrefixQuery, RangeQuery, SparseTermQuery, SparseVectorQuery, TermQuery,
};
/// Builds a temporary 100-doc index shared by the boolean/sparse-query tests.
///
/// Schema: `content` (indexed + stored text), `timestamp` (stored u64,
/// additionally marked fast), `embedding` (indexed + stored sparse vector).
///
/// Doc `i` (0..100) contains:
/// - content:   "doc{i}"
/// - timestamp: 1000 + i * 100 (1000, 1100, ..., 10900)
/// - embedding: {0: 0.5, 1: 0.3, 100 + i: 0.8} — dims 0 and 1 are shared by
///   every doc, dim 100 + i is unique to doc i.
///
/// The writer commits then force-merges before reopening; the tests below
/// assume doc id == insertion index i after the merge.
///
/// NOTE(review): `tmp_dir` is dropped when this function returns, deleting
/// the on-disk files while the mmap'd index stays readable (Unix unlink
/// semantics) — confirm this is intended if Windows support matters.
async fn create_boolean_test_index() -> (
    Index<crate::directories::MmapDirectory>,
    crate::dsl::Field, // content (text)
    crate::dsl::Field, // timestamp (u64 fast)
    crate::dsl::Field, // embedding (sparse_vector)
) {
    use crate::directories::MmapDirectory;
    let tmp_dir = tempfile::tempdir().unwrap();
    let dir = MmapDirectory::new(tmp_dir.path());
    let mut sb = SchemaBuilder::default();
    let content = sb.add_text_field("content", true, true);
    let timestamp = sb.add_u64_field("timestamp", false, true);
    sb.set_fast(timestamp, true);
    let embedding = sb.add_sparse_vector_field("embedding", true, true);
    let schema = sb.build();
    // Tiny memory budget forces several segment flushes during indexing;
    // the force_merge below collapses them back into one segment.
    let config = IndexConfig {
        max_indexing_memory_bytes: 8192,
        ..Default::default()
    };
    let mut writer = IndexWriter::create(dir.clone(), schema, config.clone())
        .await
        .unwrap();
    for i in 0u64..100 {
        let mut doc = Document::new();
        doc.add_text(content, format!("doc{}", i));
        doc.add_u64(timestamp, 1000 + i * 100);
        // Shared dims 0 and 1, plus a per-doc unique dim 100 + i.
        let entries: Vec<(u32, f32)> = vec![(0, 0.5), (1, 0.3), (100 + i as u32, 0.8)];
        doc.add_sparse_vector(embedding, entries);
        writer.add_document(doc).unwrap();
    }
    writer.commit().await.unwrap();
    writer.force_merge().await.unwrap();
    let index = Index::open(dir, config).await.unwrap();
    assert_eq!(index.num_docs().await.unwrap(), 100);
    (index, content, timestamp, embedding)
}
/// MUST range filter + SHOULD sparse term: the predicate-aware MaxScore
/// path must still fill the full top-k when enough docs pass the filter
/// (ts in [3000, 5000] corresponds to docs 20..=40).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_predicate_aware_maxscore_fills_topk() {
    let (index, _content, timestamp, embedding) = create_boolean_test_index().await;
    let mut query = BooleanQuery::new();
    query = query.must(RangeQuery::u64(timestamp, Some(3000), Some(5000)));
    query = query.should(SparseTermQuery::new(embedding, 0, 1.0));
    let res = index.search(&query, 10).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        10,
        "Predicate-aware MaxScore should return exactly limit results when enough docs match, got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(
            (20..=40).contains(&id),
            "Doc {} should be in range [20, 40]",
            id
        );
        assert!(
            hit.score > 0.0,
            "Score should be positive, got {}",
            hit.score
        );
    }
}
/// Same predicate-aware path, but with a multi-dim SHOULD sparse vector
/// (ts in [2000, 8000] corresponds to docs 10..=70, 61 candidates).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_predicate_aware_maxscore_sparse_vector() {
    let (index, _content, timestamp, embedding) = create_boolean_test_index().await;
    let sparse = SparseVectorQuery::new(embedding, vec![(0, 1.0), (1, 0.5)]);
    let query = BooleanQuery::new()
        .must(RangeQuery::u64(timestamp, Some(2000), Some(8000)))
        .should(sparse);
    let res = index.search(&query, 20).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        20,
        "Should return exactly 20 results from 61 matching docs, got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(
            (10..=70).contains(&id),
            "Doc {} should be in range [10, 70]",
            id
        );
    }
}
/// A sparse term query on a per-doc unique dim hits exactly one document.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sparse_term_query_single_dim() {
    let (index, _content, _ts, embedding) = create_boolean_test_index().await;
    // Dim 142 is the unique dim (100 + i) of doc 42 only.
    let q = SparseTermQuery::new(embedding, 142, 1.0);
    let results = index.search(&q, 10).await.unwrap();
    // Fixed message: the queried dimension is 142 (= 100 + 42), not 42.
    assert_eq!(results.hits.len(), 1, "Only doc 42 has dim 142");
    assert_eq!(results.hits[0].address.doc_id, 42);
    assert!(results.hits[0].score > 0.0);
}
/// Dim 0 appears in every document's sparse vector, so a term query on it
/// must match the whole corpus.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sparse_term_query_shared_dim() {
    let (index, _content, _ts, embedding) = create_boolean_test_index().await;
    let query = SparseTermQuery::new(embedding, 0, 1.0);
    let res = index.search(&query, 200).await.unwrap();
    assert_eq!(res.hits.len(), 100, "Dim 0 is shared across all docs");
}
/// Querying a dimension no document contains must return an empty result.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sparse_term_query_missing_dim() {
    let (index, _content, _ts, embedding) = create_boolean_test_index().await;
    let query = SparseTermQuery::new(embedding, 99999, 1.0);
    let res = index.search(&query, 10).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(hit_count, 0, "Non-existent dim should match nothing");
}
/// Multi-dim sparse vector query (MaxScore path): dims 110 and 120 are the
/// unique dims of docs 10 and 20, which should therefore rank on top.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sparse_vector_query_maxscore_path() {
    let (index, _content, _ts, embedding) = create_boolean_test_index().await;
    let query = SparseVectorQuery::new(embedding, vec![(110, 1.0), (120, 1.0)]);
    let res = index.search(&query, 10).await.unwrap();
    let mut top2 = Vec::with_capacity(2);
    for hit in res.hits.iter().take(2) {
        top2.push(hit.address.doc_id);
    }
    assert!(
        top2.contains(&10) && top2.contains(&20),
        "Top 2 should be docs 10 and 20, got {:?}",
        top2
    );
}
/// MUST range with a lazy SHOULD sparse clause: dims 150/151 are the unique
/// dims of docs 50/51, both inside the ts window [5000, 7000].
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_must_range_should_sparse_lazy() {
    let (index, _content, timestamp, embedding) = create_boolean_test_index().await;
    let sparse = SparseVectorQuery::new(embedding, vec![(150, 1.0), (151, 1.0)]);
    let query = BooleanQuery::new()
        .must(RangeQuery::u64(timestamp, Some(5000), Some(7000)))
        .should(sparse);
    let res = index.search(&query, 100).await.unwrap();
    assert_eq!(res.hits.len(), 2);
    let doc_ids = res
        .hits
        .iter()
        .map(|h| h.address.doc_id)
        .collect::<Vec<u32>>();
    assert!(
        doc_ids.contains(&50) && doc_ids.contains(&51),
        "Docs 50 and 51 should be returned, got {:?}",
        doc_ids
    );
}
/// A MUST text term narrows the SHOULD sparse clause down to one document.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_must_term_should_sparse_term() {
    let (index, content, _ts, embedding) = create_boolean_test_index().await;
    let query = BooleanQuery::new()
        .must(TermQuery::text(content, "doc25"))
        .should(SparseTermQuery::new(embedding, 125, 1.0));
    let res = index.search(&query, 10).await.unwrap();
    assert_eq!(res.hits.len(), 1, "Only doc_25 matches the MUST term");
    let top = &res.hits[0];
    assert_eq!(top.address.doc_id, 25);
    assert!(top.score > 0.0);
}
/// MUST_NOT range excludes ts >= 5000 (docs 40..100); the shared-dim SHOULD
/// clause would otherwise match everything.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_should_sparse_must_not_range() {
    let (index, _content, timestamp, embedding) = create_boolean_test_index().await;
    let query = BooleanQuery::new()
        .should(SparseTermQuery::new(embedding, 0, 1.0))
        .must_not(RangeQuery::u64(timestamp, Some(5000), None));
    let res = index.search(&query, 200).await.unwrap();
    assert_eq!(res.hits.len(), 40, "Should exclude docs with ts >= 5000");
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 40, "Doc {} should have been excluded", id);
    }
}
/// MUST + two SHOULD + MUST_NOT combined: range [2000, 6000] covers docs
/// 10..=50 (41 docs) and doc 30 is removed by the MUST_NOT term.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_must_should_must_not_combined() {
    let (index, content, timestamp, embedding) = create_boolean_test_index().await;
    let query = BooleanQuery::new()
        .must(RangeQuery::u64(timestamp, Some(2000), Some(6000)))
        .should(SparseTermQuery::new(embedding, 0, 0.5))
        .should(SparseTermQuery::new(embedding, 130, 2.0))
        .must_not(TermQuery::text(content, "doc30"));
    let res = index.search(&query, 100).await.unwrap();
    assert_eq!(res.hits.len(), 40, "Should be 40 docs (41 range - 1 excluded)");
    let excluded_present = res.hits.iter().any(|h| h.address.doc_id == 30);
    assert!(!excluded_present, "Doc 30 should be excluded by MUST_NOT");
}
/// Two half-open MUST ranges intersect to ts in [3000, 5000]; the SHOULD
/// dims 125/130 pick out docs 25 and 30 inside that window.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_two_must_two_should_sparse() {
    let (index, _content, timestamp, embedding) = create_boolean_test_index().await;
    let query = BooleanQuery::new()
        .must(RangeQuery::u64(timestamp, Some(3000), None))
        .must(RangeQuery::u64(timestamp, None, Some(5000)))
        .should(SparseTermQuery::new(embedding, 125, 1.0))
        .should(SparseTermQuery::new(embedding, 130, 1.0));
    let res = index.search(&query, 100).await.unwrap();
    assert_eq!(res.hits.len(), 2);
    let mut doc_ids: Vec<u32> = Vec::new();
    for hit in &res.hits {
        doc_ids.push(hit.address.doc_id);
    }
    assert!(
        doc_ids.contains(&25) && doc_ids.contains(&30),
        "Results should be docs 25 and 30, got {:?}",
        doc_ids
    );
}
/// A SHOULD-only query on a dim no document has must come back empty.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_should_only_no_match() {
    let (index, _content, _ts, embedding) = create_boolean_test_index().await;
    let query = BooleanQuery::new().should(SparseTermQuery::new(embedding, 99999, 1.0));
    let res = index.search(&query, 10).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(hit_count, 0, "Non-existent dim should yield no results");
}
/// A single-entry sparse vector query behaves like a term query on that dim.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sparse_vector_single_dim_path() {
    let (index, _content, _ts, embedding) = create_boolean_test_index().await;
    let query = SparseVectorQuery::new(embedding, vec![(142, 1.0)]);
    let res = index.search(&query, 10).await.unwrap();
    assert_eq!(res.hits.len(), 1, "Single dim 142 should match only doc 42");
    assert_eq!(res.hits[0].address.doc_id, 42);
}
/// Builds a 200-doc index with four 50-doc groups for the pruning tests.
///
/// Every doc carries the shared dim 0 with weight 0.1; each group adds one
/// extra entry:
/// - docs   0..50  → dim 1000,     weight 5.0  ("strong" group)
/// - docs  50..100 → dim 2000,     weight 2.0  ("medium" group)
/// - docs 100..150 → dim 3000,     weight 0.05 ("weak" group)
/// - docs 150..200 → dim 4000 + i, weight 1.0  ("unique" group, one dim/doc)
async fn create_pruning_test_index() -> (
    Index<crate::directories::MmapDirectory>,
    crate::dsl::Field, // embedding
) {
    use crate::directories::MmapDirectory;
    let tmp_dir = tempfile::tempdir().unwrap();
    let dir = MmapDirectory::new(tmp_dir.path());
    let mut sb = SchemaBuilder::default();
    let embedding = sb.add_sparse_vector_field("embedding", true, true);
    let schema = sb.build();
    // Generous budget: all 200 docs fit in one segment without flushing.
    let config = IndexConfig {
        max_indexing_memory_bytes: 50 * 1024 * 1024,
        ..Default::default()
    };
    let mut writer = IndexWriter::create(dir.clone(), schema, config.clone())
        .await
        .unwrap();
    for i in 0u64..200 {
        let mut doc = Document::new();
        let mut entries: Vec<(u32, f32)> = vec![(0, 0.1)];
        // Reformatted (was crammed onto one line) and switched to inclusive
        // range patterns, which don't require Rust 1.80's exclusive ranges.
        match i {
            0..=49 => entries.push((1000, 5.0)),
            50..=99 => entries.push((2000, 2.0)),
            100..=149 => entries.push((3000, 0.05)),
            150..=199 => entries.push((4000 + i as u32, 1.0)),
            _ => {}
        }
        doc.add_sparse_vector(embedding, entries);
        writer.add_document(doc).unwrap();
    }
    writer.commit().await.unwrap();
    let index = Index::open(dir, config).await.unwrap();
    assert_eq!(index.num_docs().await.unwrap(), 200);
    (index, embedding)
}
/// Baseline without pruning: the shared dim 0 makes every doc match, and a
/// strong-group doc (dim 1000, query weight 10.0) should rank first.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_baseline_no_pruning() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    );
    let res = index.search(&query, 200).await.unwrap();
    assert_eq!(res.hits.len(), 200, "Baseline: all docs match via shared dim 0");
    let best = res.hits[0].address.doc_id;
    assert!(best < 50, "Top hit should be from strong group, got doc {}", best);
}
/// Threshold 0.005 drops only dim 0 (weight 0.001), so the unique group
/// (docs 150..200), which matched solely through dim 0, disappears.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_weight_threshold_drops_shared_dim() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    )
    .with_weight_threshold(0.005)
    .with_min_query_dims(1);
    let res = index.search(&query, 200).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        150,
        "After threshold 0.005: unique group (50 docs) should be gone, got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 150, "Doc {} from unique group should be pruned", id);
    }
}
/// Threshold 0.05 removes dims 3000 (weight 0.01) and 0 (weight 0.001):
/// only the strong and medium groups can still match.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_weight_threshold_drops_weak_and_shared() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    )
    .with_weight_threshold(0.05)
    .with_min_query_dims(1);
    let res = index.search(&query, 200).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        100,
        "After threshold 0.05: only strong+medium (100 docs), got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 100, "Doc {} should have been pruned (weak/unique group)", id);
    }
}
/// Keeping only the two highest-weight query dims leaves 1000 and 2000,
/// which match the strong and medium groups only.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_max_query_dims() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    )
    .with_max_query_dims(2);
    let res = index.search(&query, 200).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        100,
        "max_query_dims=2: only strong+medium (100 docs), got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 100, "Doc {} should be pruned (only 2 dims kept)", id);
    }
}
/// With only the single heaviest dim (1000) kept, just the strong group
/// (docs 0..50) can match.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_max_query_dims_single() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    )
    .with_max_query_dims(1);
    let res = index.search(&query, 200).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        50,
        "max_query_dims=1: only strong group (50 docs), got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 50, "Doc {} should be pruned (only top-1 dim kept)", id);
    }
}
/// pruning=0.5 keeps the top half of the 4 query dims (1000 and 2000),
/// so only strong + medium groups remain.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_fraction_half() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    )
    .with_pruning(0.5)
    .with_min_query_dims(1);
    let res = index.search(&query, 200).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        100,
        "pruning=0.5: keep top 2 dims → 100 docs, got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 100, "Doc {} should be pruned (bottom 50% dims dropped)", id);
    }
}
/// pruning=0.25 keeps a single dim (1000), matching only the strong group.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_fraction_quarter() {
    let (index, embedding) = create_pruning_test_index().await;
    let query = SparseVectorQuery::new(
        embedding,
        vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)],
    )
    .with_pruning(0.25)
    .with_min_query_dims(1);
    let res = index.search(&query, 200).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(
        hit_count,
        50,
        "pruning=0.25: keep 1 dim → 50 docs, got {}",
        hit_count
    );
    for hit in res.hits.iter() {
        let id = hit.address.doc_id;
        assert!(id < 50, "Doc {} should be pruned (top 25% = 1 dim)", id);
    }
}
/// Pruning must never raise a document's score, and a doc that matched
/// only through a pruned dim (doc 120 via dim 3000) must disappear.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_pruning_score_impact() {
    let (index, embedding) = create_pruning_test_index().await;
    let dims = vec![(1000, 10.0), (2000, 5.0), (3000, 0.01), (0, 0.001)];
    let q_full = SparseVectorQuery::new(embedding, dims.clone());
    let q_pruned = SparseVectorQuery::new(embedding, dims)
        .with_weight_threshold(0.05)
        .with_min_query_dims(1);
    let full = index.search(&q_full, 200).await.unwrap();
    let pruned = index.search(&q_pruned, 200).await.unwrap();
    let full_score = full
        .hits
        .iter()
        .find(|h| h.address.doc_id == 0)
        .unwrap()
        .score;
    let pruned_score = pruned
        .hits
        .iter()
        .find(|h| h.address.doc_id == 0)
        .unwrap()
        .score;
    assert!(
        pruned_score <= full_score,
        "Pruned score ({}) should be <= full score ({})",
        pruned_score,
        full_score
    );
    let in_full = full.hits.iter().any(|h| h.address.doc_id == 120);
    let in_pruned = pruned.hits.iter().any(|h| h.address.doc_id == 120);
    assert!(in_full, "Doc 120 should be in full results");
    assert!(
        !in_pruned,
        "Doc 120 should NOT be in pruned results (dim 3000 dropped)"
    );
}
/// SHOULD term clauses across two text fields: docs matching more clauses
/// should score higher, and docs matching none must be absent.
#[tokio::test]
async fn test_multi_field_text_should() {
    use crate::directories::MmapDirectory;
    let tmp_dir = tempfile::tempdir().unwrap();
    let dir = MmapDirectory::new(tmp_dir.path());
    let mut sb = SchemaBuilder::default();
    let title = sb.add_text_field("title", true, true);
    let body = sb.add_text_field("body", true, true);
    let schema = sb.build();
    let config = IndexConfig {
        max_indexing_memory_bytes: 8192,
        ..Default::default()
    };
    let mut writer = IndexWriter::create(dir.clone(), schema, config.clone())
        .await
        .unwrap();
    // Docs 0..=3 get fixed (title, body) pairs; docs 4..14 are filler.
    let fixed: [(&str, &str); 4] = [
        ("rust programming", "python scripting"),
        ("rust language", "rust compiler"),
        ("java enterprise", "python machine learning"),
        ("java enterprise", "java spring"),
    ];
    for (title_text, body_text) in fixed {
        let mut doc = Document::new();
        doc.add_text(title, title_text);
        doc.add_text(body, body_text);
        writer.add_document(doc).unwrap();
    }
    for i in 4..14 {
        let mut doc = Document::new();
        doc.add_text(title, format!("filler title {}", i));
        doc.add_text(body, format!("filler body {}", i));
        writer.add_document(doc).unwrap();
    }
    writer.commit().await.unwrap();
    let index = Index::open(dir, config).await.unwrap();
    let query = BooleanQuery::new()
        .should(TermQuery::text(title, "rust"))
        .should(TermQuery::text(body, "rust"))
        .should(TermQuery::text(title, "python"))
        .should(TermQuery::text(body, "python"));
    let results = index.search(&query, 10).await.unwrap();
    let doc_ids: Vec<u32> = results.hits.iter().map(|h| h.address.doc_id).collect();
    assert!(
        doc_ids.contains(&0),
        "Doc 0 should match (rust in title, python in body)"
    );
    assert!(
        doc_ids.contains(&1),
        "Doc 1 should match (rust in both fields)"
    );
    assert!(doc_ids.contains(&2), "Doc 2 should match (python in body)");
    assert!(!doc_ids.contains(&3), "Doc 3 should not match (java only)");
    let score_of = |id: u32| {
        results
            .hits
            .iter()
            .find(|h| h.address.doc_id == id)
            .unwrap()
            .score
    };
    let doc1_score = score_of(1);
    let doc2_score = score_of(2);
    assert!(
        doc1_score > doc2_score,
        "Doc 1 (rust in both fields) should score higher than doc 2 (python in body only): {} vs {}",
        doc1_score,
        doc2_score
    );
}
/// SHOULD sparse + MUST prefix must not lose results: every document that
/// satisfies the MUST prefix has to appear, regardless of how lazily the
/// sparse SHOULD clause is evaluated.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_must_prefix_should_sparse_no_result_loss() {
    let (index, content, _ts, embedding) = create_boolean_test_index().await;
    // Sanity check: standalone prefix matches doc1 and doc10..doc19.
    let prefix_q = PrefixQuery::text(content, "doc1");
    let prefix_results = index.search(&prefix_q, 100).await.unwrap();
    // assert_eq! (was `assert!(a == b, ...)`) prints both values on failure.
    assert_eq!(
        prefix_results.hits.len(),
        11,
        "prefix 'doc1' should match 11 docs (doc1, doc10-doc19), got {}",
        prefix_results.hits.len()
    );
    let bool_q = BooleanQuery::new()
        .should(SparseVectorQuery::new(embedding, vec![(0, 1.0)]))
        .must(PrefixQuery::text(content, "doc1"));
    let bool_results = index.search(&bool_q, 100).await.unwrap();
    assert_eq!(
        bool_results.hits.len(),
        11,
        "boolean(should=sparse, must=prefix) should return all 11 matching docs, got {}",
        bool_results.hits.len()
    );
    let expected: HashSet<u32> = [1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
        .into_iter()
        .collect();
    let actual: HashSet<u32> = bool_results.hits.iter().map(|h| h.address.doc_id).collect();
    assert_eq!(actual, expected, "wrong doc IDs returned");
    // The same query with a smaller limit must still fill top-k exactly.
    let bool_q5 = BooleanQuery::new()
        .should(SparseVectorQuery::new(embedding, vec![(0, 1.0)]))
        .must(PrefixQuery::text(content, "doc1"));
    let results5 = index.search(&bool_q5, 5).await.unwrap();
    assert_eq!(
        results5.hits.len(),
        5,
        "should return exactly 5 docs with limit=5, got {}",
        results5.hits.len()
    );
}
/// Per the test name, a MUST term on the non-fast text field exercises the
/// bitset fallback path; the single matching doc must still be returned.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_must_term_non_fast_should_sparse_bitset_fallback() {
    let (index, content, _ts, embedding) = create_boolean_test_index().await;
    let query = BooleanQuery::new()
        .should(SparseVectorQuery::new(embedding, vec![(0, 1.0)]))
        .must(TermQuery::text(content, "doc25"));
    let res = index.search(&query, 5).await.unwrap();
    let hit_count = res.hits.len();
    assert_eq!(hit_count, 1, "should find doc25 (the only match), got {}", hit_count);
    assert_eq!(res.hits[0].address.doc_id, 25);
}