//! Integration tests for batch operations on a Shardex index: batch add,
//! incremental add, document removal, flush, search, and statistics.

mod common;
use apithing::ApiOperation;
use common::create_temp_dir_for_test;
use shardex::{
api::{
operations::{
BatchAddPostings, CreateIndex, Flush, GetPerformanceStats, GetStats, IncrementalAdd, RemoveDocuments,
Search,
},
parameters::{
BatchAddPostingsParams, CreateIndexParams, FlushParams, GetPerformanceStatsParams, GetStatsParams,
IncrementalAddParams, RemoveDocumentsParams, SearchParams,
},
ShardexContext,
},
DocumentId, Posting,
};
use std::error::Error;
use std::time::{Duration, Instant};
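
/// Exercises a full batch workflow: create an index, add documents in batches,
/// apply incremental additions, remove a subset, run searches, and check stats.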
#[test]
fn test_batch_operations_example_workflow() -> Result<(), Box<dyn Error>> {
let temp_dir = create_temp_dir_for_test();
let mut context = ShardexContext::new();
let create_params = CreateIndexParams::builder()
.directory_path(temp_dir.path().to_path_buf())
        .vector_size(256)
        .shard_size(10000)
        .batch_write_interval_ms(100)
        .default_slop_factor(3)
        .bloom_filter_size(1024)
        .wal_segment_size(2 * 1024 * 1024)
        .build()?;
CreateIndex::execute(&mut context, &create_params)?;
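
    // Index TOTAL_DOCS documents in NUM_BATCHES equally sized batches, flushing
    // each batch with stats tracking and accumulating the reported processing time.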
const BATCH_SIZE: usize = 25;
const NUM_BATCHES: usize = 3;
const TOTAL_DOCS: usize = BATCH_SIZE * NUM_BATCHES;
let mut total_indexing_time = Duration::new(0, 0);
for batch_num in 0..NUM_BATCHES {
let batch_start = batch_num * BATCH_SIZE;
let postings = generate_document_batch(batch_start, BATCH_SIZE, 256);
let batch_params = BatchAddPostingsParams::with_flush_and_tracking(postings)?;
let batch_stats = BatchAddPostings::execute(&mut context, &batch_params)?;
total_indexing_time += batch_stats.processing_time;
assert!(batch_stats.processing_time.as_nanos() > 0);
assert_eq!(batch_stats.operations_flushed, BATCH_SIZE as u64);
assert!(batch_stats.throughput_docs_per_sec >= 0.0);
let stats_params = GetStatsParams::new();
let _stats = GetStats::execute(&mut context, &stats_params)?;
}
assert!(total_indexing_time.as_nanos() > 0);
let overall_throughput = TOTAL_DOCS as f64 / total_indexing_time.as_secs_f64();
assert!(overall_throughput >= 0.0);
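
    // Capture a baseline, then apply three small incremental additions,
    // flushing after every other iteration.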
let initial_stats_params = GetStatsParams::new();
let _initial_stats = GetStats::execute(&mut context, &initial_stats_params)?;
for i in 0..3 {
let increment = generate_document_batch(TOTAL_DOCS + i * 10, 10, 256);
let incremental_params = IncrementalAddParams::with_batch_id(increment, format!("increment_{}", i + 1))?;
let incremental_stats = IncrementalAdd::execute(&mut context, &incremental_params)?;
assert_eq!(incremental_stats.postings_added, 10);
assert!(incremental_stats.processing_time.as_nanos() > 0);
if i % 2 == 0 {
let flush_params = FlushParams::new();
Flush::execute(&mut context, &flush_params)?;
}
let current_stats_params = GetStatsParams::new();
        let _current_stats = GetStats::execute(&mut context, &current_stats_params)?;
        std::thread::sleep(Duration::from_millis(50));
    }
let final_flush_params = FlushParams::new();
Flush::execute(&mut context, &final_flush_params)?;
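
    // Remove every fifth document ID up to the current posting count, then flush.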
let final_stats_params = GetStatsParams::new();
let final_stats = GetStats::execute(&mut context, &final_stats_params)?;
let mut docs_to_remove = Vec::new();
for i in (5..=final_stats.total_postings).step_by(5) {
docs_to_remove.push(i as u128);
}
if !docs_to_remove.is_empty() {
let removal_params = RemoveDocumentsParams::new(docs_to_remove)?;
let removal_stats = RemoveDocuments::execute(&mut context, &removal_params)?;
let flush_params = FlushParams::new();
Flush::execute(&mut context, &flush_params)?;
let after_removal_stats_params = GetStatsParams::new();
let _after_removal_stats = GetStats::execute(&mut context, &after_removal_stats_params)?;
assert!(removal_stats.processing_time.as_nanos() > 0);
}
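
    // Run searches with several k values and sanity-check result counts and timing.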
let query_vector = generate_test_vector(256);
for k in [1, 3, 5, 10] {
let search_params = SearchParams::builder()
.query_vector(query_vector.clone())
.k(k)
.slop_factor(None)
.build()?;
let search_start = Instant::now();
let results = Search::execute(&mut context, &search_params)?;
let search_time = search_start.elapsed();
assert!(search_time.as_nanos() > 0);
        assert!(results.len() <= k);
    }
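
    // Gather final index statistics and detailed performance metrics.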
let detailed_stats_params = GetStatsParams::new();
let detailed_stats = GetStats::execute(&mut context, &detailed_stats_params)?;
let perf_stats_params = GetPerformanceStatsParams::detailed();
let perf_stats = GetPerformanceStats::execute(&mut context, &perf_stats_params)?;
assert!(detailed_stats.average_shard_utilization >= 0.0);
assert!(perf_stats.average_latency.as_nanos() > 0);
assert!(perf_stats.throughput >= 0.0);
Ok(())
}
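
/// Verifies that a single batch add and a follow-up search complete within
/// generous wall-clock bounds.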
#[test]
fn test_batch_operations_performance() -> Result<(), Box<dyn Error>> {
let temp_dir = create_temp_dir_for_test();
let mut context = ShardexContext::new();
let create_params = CreateIndexParams::builder()
.directory_path(temp_dir.path().to_path_buf())
        .vector_size(128)
        .shard_size(50)
        .build()?;
CreateIndex::execute(&mut context, &create_params)?;
let batch = generate_document_batch(0, 20, 128);
let start_time = Instant::now();
let batch_params = BatchAddPostingsParams::with_flush_and_tracking(batch)?;
let batch_stats = BatchAddPostings::execute(&mut context, &batch_params)?;
let total_time = start_time.elapsed();
assert!(
total_time < Duration::from_secs(30),
"Batch operation too slow: {:?}",
total_time
);
assert!(
batch_stats.throughput_docs_per_sec > 0.0,
"Throughput should be positive"
);
let query_vector = generate_test_vector(128);
let search_params = SearchParams::builder()
.query_vector(query_vector)
.k(5)
.build()?;
let search_start = Instant::now();
let _results = Search::execute(&mut context, &search_params)?;
let search_time = search_start.elapsed();
assert!(
search_time < Duration::from_secs(10),
"Search too slow: {:?}",
search_time
);
Ok(())
}
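
/// Verifies that index files persist on disk after the context that created
/// them is dropped.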
#[test]
fn test_batch_operations_cleanup() -> Result<(), Box<dyn Error>> {
let temp_dir = create_temp_dir_for_test();
let temp_path = temp_dir.path().to_path_buf();
{
let mut context = ShardexContext::new();
let create_params = CreateIndexParams::builder()
.directory_path(temp_path.clone())
.vector_size(64)
.shard_size(50)
.build()?;
CreateIndex::execute(&mut context, &create_params)?;
let batch = generate_document_batch(0, 10, 64);
let batch_params = BatchAddPostingsParams::with_flush_and_tracking(batch)?;
BatchAddPostings::execute(&mut context, &batch_params)?;
let flush_params = FlushParams::new();
Flush::execute(&mut context, &flush_params)?;
}
assert!(temp_path.exists(), "Index directory should exist");
Ok(())
}
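
/// Builds `count` postings with sequential document IDs starting at `start_id + 1`,
/// each with a deterministic vector of the given dimensionality.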
fn generate_document_batch(start_id: usize, count: usize, vector_size: usize) -> Vec<Posting> {
(0..count)
.map(|i| {
let doc_id = start_id + i + 1;
let document_id = DocumentId::from_raw(doc_id as u128);
let vector = generate_deterministic_vector(doc_id, vector_size);
Posting {
document_id,
start: 0,
length: 50 + (doc_id % 50) as u32,
vector,
}
})
.collect()
}
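
/// Generates a deterministic, hash-seeded vector of the given size and
/// normalizes it to unit length (falling back to a basis vector if degenerate).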
fn generate_deterministic_vector(seed: usize, size: usize) -> Vec<f32> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut vector = Vec::with_capacity(size);
let mut hasher = DefaultHasher::new();
seed.hash(&mut hasher);
for i in 0..size {
(seed + i).hash(&mut hasher);
let value = ((hasher.finish() % 10000) as f32 - 5000.0) / 5000.0;
vector.push(value);
}
let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude > 0.0 {
for value in &mut vector {
*value /= magnitude;
}
} else {
for (i, value) in vector.iter_mut().enumerate() {
*value = if i == 0 { 1.0 } else { 0.0 };
}
}
vector
}
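
/// Returns a fixed query vector derived from a constant seed.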
fn generate_test_vector(size: usize) -> Vec<f32> {
generate_deterministic_vector(99999, size)
}