use swarc::{HNSWIndex, Document};
use rand::Rng;
use std::time::{Duration, Instant};
use std::fs::File;
use std::io::Write;
/// Metrics collected from a single benchmark run (`run_benchmark`)
/// at one corpus size / dimension combination.
#[derive(Debug, Clone)]
struct BenchmarkResult {
/// Number of vectors inserted into the index.
num_embeddings: usize,
/// Dimensionality of each embedding vector.
dimension: usize,
/// Wall time for inserting all embeddings one at a time.
insertion_time: Duration,
/// Wall time for bulk insertion via `insert_parallel`.
parallel_insertion_time: Duration,
/// Total wall time for the whole query batch (not per-query).
search_time: Duration,
/// Rough estimated index size in bytes (see `estimate_memory_usage`).
memory_usage: usize,
}
/// Builds one embedding of `dimension` components, each drawn uniformly
/// from the half-open interval [-1.0, 1.0).
fn generate_random_embedding(dimension: usize) -> Vec<f32> {
    let mut rng = rand::thread_rng();
    let mut values = Vec::with_capacity(dimension);
    for _ in 0..dimension {
        values.push(rng.gen_range(-1.0..1.0));
    }
    values
}
/// Inserts every embedding into `index` one at a time and returns the
/// total wall time, including document construction and key formatting.
///
/// `batch_size` only controls how often a progress line is printed;
/// a value of 0 disables progress output entirely.
///
/// # Panics
/// Panics if any individual insert fails.
fn benchmark_insertion(
    index: &mut HNSWIndex<String>,
    embeddings: &[Vec<f32>],
    batch_size: usize,
) -> Duration {
    let start = Instant::now();
    for (i, embedding) in embeddings.iter().enumerate() {
        let doc = Document {
            id: format!("doc_{}", i),
            data: format!("Content for document {}", i),
        };
        // `insert` takes ownership of the vector, so the clone is required.
        index
            .insert(format!("node_{}", i), embedding.clone(), Some(doc))
            .expect("Failed to insert embedding");
        // Guard against batch_size == 0: `%` by zero panics in Rust.
        if batch_size > 0 && (i + 1) % batch_size == 0 {
            println!("Inserted {} embeddings", i + 1);
        }
    }
    start.elapsed()
}
/// Times a single bulk call to `insert_parallel` over every embedding.
/// Item construction is deliberately inside the timed region so the
/// measurement is comparable to `benchmark_insertion`, which also builds
/// its documents inside the timed loop.
///
/// # Panics
/// Panics if the bulk call fails, or if any per-item result is an error.
fn benchmark_parallel_insertion(
    index: &mut HNSWIndex<String>,
    embeddings: &[Vec<f32>],
    _batch_size: usize,
) -> Duration {
    let start = Instant::now();

    // Materialize the (key, vector, payload) triples expected by the bulk API.
    let mut items: Vec<(String, Vec<f32>, Option<Document<String>>)> =
        Vec::with_capacity(embeddings.len());
    for (i, embedding) in embeddings.iter().enumerate() {
        let doc = Document {
            id: format!("doc_{}", i),
            data: format!("Content for document {}", i),
        };
        items.push((format!("node_{}", i), embedding.clone(), Some(doc)));
    }

    let outcomes = index
        .insert_parallel(items)
        .expect("Failed to insert embeddings in parallel");
    for (i, outcome) in outcomes.iter().enumerate() {
        if let Err(e) = outcome {
            panic!("Failed to insert embedding {}: {}", i, e);
        }
    }

    println!("Inserted {} embeddings in parallel", embeddings.len());
    start.elapsed()
}
/// Runs one k-nearest-neighbour query per entry in `query_embeddings`
/// and returns the total wall time for the whole batch.
fn benchmark_search(
    index: &HNSWIndex<String>,
    query_embeddings: &[Vec<f32>],
    k: usize,
) -> Duration {
    let start = Instant::now();
    query_embeddings.iter().for_each(|query| {
        // Results are intentionally discarded; only latency is measured.
        let _ = index.search(query, k);
    });
    start.elapsed()
}
/// Back-of-the-envelope estimate of the index's memory footprint in bytes:
/// per node, the vector data plus neighbour links plus a flat overhead.
///
/// NOTE(review): the dimension is hard-coded to 3072 here. It must match the
/// `dimension` passed to `run_benchmark` (currently it does, via `main`),
/// otherwise this estimate is silently wrong — consider reading it from the
/// index if the API exposes it.
fn estimate_memory_usage(index: &HNSWIndex<String>) -> usize {
    const DIMENSION: usize = 3072; // must agree with main()'s `dimension`
    const AVG_CONNECTIONS: usize = 16; // assumed mean neighbours per node
    const BYTES_PER_F32: usize = 4;
    const BYTES_PER_LINK: usize = 4; // assumed 4-byte neighbour ids
    const PER_NODE_OVERHEAD: usize = 100; // ids, metadata strings, etc. (rough)

    let num_nodes = index.len();
    num_nodes * (DIMENSION * BYTES_PER_F32 + AVG_CONNECTIONS * BYTES_PER_LINK + PER_NODE_OVERHEAD)
}
/// Executes the full benchmark suite for one corpus size: generates random
/// data, times sequential insertion, parallel insertion, and search, and
/// returns the collected metrics.
fn run_benchmark(
    num_embeddings: usize,
    dimension: usize,
    batch_size: usize,
) -> BenchmarkResult {
    println!("Starting benchmark with {} embeddings of dimension {}", num_embeddings, dimension);

    // Pre-generate the corpus and a query set (10% of corpus, minimum 100).
    println!("Generating {} random embeddings...", num_embeddings);
    let corpus: Vec<Vec<f32>> = (0..num_embeddings)
        .map(|_| generate_random_embedding(dimension))
        .collect();
    let query_count = (num_embeddings / 10).max(100);
    let queries: Vec<Vec<f32>> = (0..query_count)
        .map(|_| generate_random_embedding(dimension))
        .collect();

    // Two separate indexes so the parallel run starts from an empty structure.
    println!("Starting sequential insertion benchmark...");
    let mut sequential_index = HNSWIndex::new(dimension, 16, 200);
    let insertion_time = benchmark_insertion(&mut sequential_index, &corpus, batch_size);

    println!("Starting parallel insertion benchmark...");
    let mut parallel_index = HNSWIndex::new(dimension, 16, 200);
    let parallel_insertion_time =
        benchmark_parallel_insertion(&mut parallel_index, &corpus, batch_size);

    // Search and memory estimation both run against the parallel-built index.
    println!("Starting search benchmark...");
    let search_time = benchmark_search(&parallel_index, &queries, 10);

    BenchmarkResult {
        num_embeddings,
        dimension,
        insertion_time,
        parallel_insertion_time,
        search_time,
        memory_usage: estimate_memory_usage(&parallel_index),
    }
}
/// Writes all benchmark results to `filename` as CSV (one header row, then
/// one row per result; times in whole milliseconds, memory in whole MiB).
///
/// Creates the parent directory if it does not exist — `main` passes
/// "performance_tests/benchmark_results.csv", and a bare `File::create`
/// fails when that directory is missing.
///
/// # Panics
/// Panics on any I/O error (acceptable for a benchmark binary).
fn save_results_to_csv(results: &[BenchmarkResult], filename: &str) {
    if let Some(parent) = std::path::Path::new(filename).parent() {
        if !parent.as_os_str().is_empty() {
            std::fs::create_dir_all(parent).expect("Failed to create output directory");
        }
    }
    // Buffer the writes: one syscall-sized flush instead of a write per row.
    let mut file =
        std::io::BufWriter::new(File::create(filename).expect("Failed to create CSV file"));
    writeln!(file, "num_embeddings,dimension,insertion_time_ms,parallel_insertion_time_ms,search_time_ms,memory_usage_mb").unwrap();
    for result in results {
        writeln!(
            file,
            "{},{},{},{},{},{}",
            result.num_embeddings,
            result.dimension,
            result.insertion_time.as_millis(),
            result.parallel_insertion_time.as_millis(),
            result.search_time.as_millis(),
            result.memory_usage / (1024 * 1024)
        ).unwrap();
    }
    // BufWriter's Drop swallows flush errors — flush explicitly so a failed
    // write is reported instead of silently truncating the CSV.
    file.flush().expect("Failed to flush CSV file");
    println!("Results saved to {}", filename);
}
/// Benchmark driver: runs the suite at several corpus sizes, prints a
/// per-size summary, and saves all results to a CSV file.
fn main() {
    // 3072 matches the dimension hard-coded in `estimate_memory_usage`.
    let dimension = 3072;
    let batch_size = 1000;
    let test_sizes = vec![1000, 10000, 50000];

    let mut results = Vec::new();
    for &size in &test_sizes {
        println!("\n=== Benchmarking {} embeddings ===", size);
        let result = run_benchmark(size, dimension, batch_size);
        println!("\nBenchmark Summary:");
        println!(" Embeddings: {}", size);
        println!(" Sequential insertion time: {:.2}s", result.insertion_time.as_secs_f64());
        println!(" Parallel insertion time: {:.2}s", result.parallel_insertion_time.as_secs_f64());
        println!(" Speedup: {:.2}x", result.insertion_time.as_secs_f64() / result.parallel_insertion_time.as_secs_f64());
        // Bug fix: `{:.2}` is ignored for integer types, and `as_millis()`
        // returns u128 — convert to f64 ms so the 2-decimal format applies.
        println!(" Search time: {:.2}ms", result.search_time.as_secs_f64() * 1000.0);
        println!(" Memory usage: {:.2}MB", result.memory_usage as f64 / (1024.0 * 1024.0));
        println!(" Sequential insertions/sec: {:.0}", size as f64 / result.insertion_time.as_secs_f64());
        println!(" Parallel insertions/sec: {:.0}", size as f64 / result.parallel_insertion_time.as_secs_f64());
        results.push(result);
    }

    save_results_to_csv(&results, "performance_tests/benchmark_results.csv");
    println!("\n=== All benchmarks completed ===");
    println!("Results saved to performance_tests/benchmark_results.csv");
}