use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use edgevec::hnsw::{HnswConfig, HnswIndex, SearchContext};
use edgevec::metric::{DotProduct, L2Squared, Metric};
use edgevec::quantization::binary::QuantizedVector;
use edgevec::storage::VectorStorage;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use std::hint::black_box;
/// Builds a deterministic pseudo-random `f32` vector of length `dim`.
///
/// Every component is drawn uniformly from `[-1.0, 1.0)` using a ChaCha8
/// stream seeded with `seed`, so the same `(dim, seed)` pair always yields
/// the same vector.
fn generate_random_vector(dim: usize, seed: u64) -> Vec<f32> {
    let mut generator = ChaCha8Rng::seed_from_u64(seed);
    let mut components: Vec<f32> = Vec::with_capacity(dim);
    for _ in 0..dim {
        components.push(generator.gen_range(-1.0f32..1.0f32));
    }
    components
}
/// Builds `count` deterministic random vectors, each of dimension `dim`.
///
/// Vector number `i` is generated with seed `seed + i`. As a result, a smaller
/// set is a prefix of a larger set built from the same base seed.
fn generate_vectors(count: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
    let mut batch = Vec::with_capacity(count);
    for offset in 0..count as u64 {
        batch.push(generate_random_vector(dim, seed + offset));
    }
    batch
}
/// Builds a deterministic random 96-byte (768-bit) binary-quantized vector.
///
/// Each byte is drawn independently with `Rng::gen::<u8>()` from a ChaCha8
/// stream seeded with `seed`.
fn generate_quantized_vector(seed: u64) -> QuantizedVector {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut bytes = [0u8; 96];
    // Deliberately one `gen::<u8>()` per byte rather than `rng.fill(..)`.
    // rand derives a `u8` from a full 32-bit draw, so `fill` would consume the
    // stream differently and change the generated data.
    bytes.iter_mut().for_each(|slot| *slot = rng.gen::<u8>());
    QuantizedVector::from_bytes(bytes)
}
/// Benchmarks `DotProduct::distance` on a single pair of random vectors for a
/// range of common embedding dimensions.
///
/// Throughput is reported as `dim` elements per call.
fn bench_dot_product(c: &mut Criterion) {
    const DIMS: [usize; 7] = [128, 256, 384, 512, 768, 1024, 1536];

    let mut group = c.benchmark_group("dot_product");
    for &dim in DIMS.iter() {
        let lhs = generate_random_vector(dim, 42);
        let rhs = generate_random_vector(dim, 43);

        group.throughput(Throughput::Elements(dim as u64));
        group.bench_function(BenchmarkId::new("simd", dim), |bench| {
            bench.iter(|| DotProduct::distance(black_box(&lhs), black_box(&rhs)));
        });
    }
    group.finish();
}
/// Benchmarks `L2Squared::distance` on one pair of random vectors across
/// embedding sizes from 128 to 1536 dimensions.
///
/// Throughput is reported as `dim` elements per call.
fn bench_l2_squared(c: &mut Criterion) {
    let mut group = c.benchmark_group("l2_squared");
    let dims: [usize; 7] = [128, 256, 384, 512, 768, 1024, 1536];

    for dim in dims {
        let left = generate_random_vector(dim, 42);
        let right = generate_random_vector(dim, 43);

        group.throughput(Throughput::Elements(dim as u64));
        let id = BenchmarkId::new("simd", dim);
        group.bench_function(id, |bench| {
            bench.iter(|| L2Squared::distance(black_box(&left), black_box(&right)));
        });
    }

    group.finish();
}
/// Benchmarks a cosine-style score built from `DotProduct::distance` plus
/// precomputed vector norms, across embedding sizes from 128 to 1536.
///
/// Only the dot-product call and the scalar normalisation are timed. Both
/// Euclidean norms are computed once per dimension, outside the measured loop.
///
/// NOTE(review): the timed expression is `1.0 + dot / (|a| * |b|)`. If
/// `DotProduct::distance` returns the *negated* dot product, this evaluates to
/// cosine *distance* (`1 - cos θ`), not the similarity the group name
/// suggests. Confirm against the metric's definition. The timing is the same
/// either way.
fn bench_cosine(c: &mut Criterion) {
    // Euclidean norm, summed in iteration order.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    let mut group = c.benchmark_group("cosine_similarity");
    for dim in [128, 256, 384, 512, 768, 1024, 1536] {
        let a = generate_random_vector(dim, 42);
        let b = generate_random_vector(dim, 43);
        let (a_norm, b_norm) = (l2_norm(&a), l2_norm(&b));

        group.throughput(Throughput::Elements(dim as u64));
        group.bench_function(BenchmarkId::new("simd", dim), |bench| {
            bench.iter(|| {
                let dot = DotProduct::distance(black_box(&a), black_box(&b));
                black_box(1.0 + dot / (a_norm * b_norm))
            });
        });
    }
    group.finish();
}
/// Benchmarks one Hamming-distance computation between two random 96-byte
/// (768-bit) quantized vectors.
///
/// Throughput is reported as the bytes of both operands (2 × 96).
fn bench_hamming(c: &mut Criterion) {
    const BYTES_PER_VECTOR: u64 = 96;

    let mut group = c.benchmark_group("hamming_distance");
    let lhs = generate_quantized_vector(42);
    let rhs = generate_quantized_vector(43);

    group.throughput(Throughput::Bytes(2 * BYTES_PER_VECTOR));
    group.bench_function("768bit", |bench| {
        bench.iter(|| black_box(&lhs).hamming_distance(black_box(&rhs)));
    });
    group.finish();
}
/// Benchmarks a k = 10 nearest-neighbour query against HNSW indexes holding
/// 1 000 and 10 000 random 768-dimensional vectors.
///
/// Index construction and insertion happen before the benchmark routine runs,
/// so only `search_with_context` is timed. Each routine invocation creates its
/// own `SearchContext` and reuses it across iterations. The sample size is
/// lowered to 50 to keep total run time manageable.
fn bench_search(c: &mut Criterion) {
    const DIMS: usize = 768;
    const DATA_SEED: u64 = 42;
    const QUERY_SEED: u64 = 999;

    let mut group = c.benchmark_group("search");
    group.sample_size(50);

    for count in [1_000, 10_000] {
        let dataset = generate_vectors(count, DIMS, DATA_SEED);

        // Build and populate the index outside of the timed region.
        let config = HnswConfig::new(DIMS as u32);
        let mut storage = VectorStorage::new(&config, None);
        let mut index = HnswIndex::new(config, &storage).unwrap();
        for vector in &dataset {
            index.insert(vector, &mut storage).unwrap();
        }

        let query = generate_random_vector(DIMS, QUERY_SEED);
        group.throughput(Throughput::Elements(1));
        group.bench_function(BenchmarkId::new("k=10", count), |bench| {
            let mut ctx = SearchContext::new();
            bench.iter(|| {
                index
                    .search_with_context(black_box(&query), 10, &storage, &mut ctx)
                    .unwrap()
            });
        });
    }

    group.finish();
}
/// Benchmarks a brute-force scan that compares one query against a batch of
/// 10 000 stored vectors, for both `L2Squared` and `DotProduct`, at several
/// dimensions.
///
/// Each timed iteration computes one distance per stored vector. Throughput is
/// therefore reported as one element per stored vector, taken from the batch
/// actually generated.
fn bench_distance_batch(c: &mut Criterion) {
    // Number of stored vectors scanned per iteration. Keep in sync with the
    // `_10k` suffix of the group name.
    const BATCH_SIZE: usize = 10_000;

    let mut group = c.benchmark_group("distance_batch_10k");
    for dim in [128, 384, 768] {
        let query = generate_random_vector(dim, 42);
        let vectors = generate_vectors(BATCH_SIZE, dim, 100);
        // Derive throughput from the data actually scanned, so the reported
        // rate cannot silently drift from the batch size.
        group.throughput(Throughput::Elements(vectors.len() as u64));
        group.bench_with_input(BenchmarkId::new("l2", dim), &dim, |bench, _| {
            bench.iter(|| {
                for v in &vectors {
                    // black_box on both operands and the result stops the
                    // optimiser from hoisting or eliminating the work.
                    black_box(L2Squared::distance(black_box(&query), black_box(v)));
                }
            })
        });
        group.bench_with_input(BenchmarkId::new("dot", dim), &dim, |bench, _| {
            bench.iter(|| {
                for v in &vectors {
                    black_box(DotProduct::distance(black_box(&query), black_box(v)));
                }
            })
        });
    }
    group.finish();
}
/// Benchmarks a brute-force Hamming scan of one 768-bit query against a batch
/// of 10 000 random quantized vectors.
///
/// Throughput is reported as one element per stored vector, taken from the
/// batch actually generated.
fn bench_hamming_batch(c: &mut Criterion) {
    // Number of stored vectors scanned per iteration. Keep in sync with the
    // `_10k` suffix of the group name.
    const BATCH_SIZE: u64 = 10_000;

    let mut group = c.benchmark_group("hamming_batch_10k");
    let query = generate_quantized_vector(42);
    // Seeds start at 100, so no stored vector reuses the query's seed (42).
    let vectors: Vec<QuantizedVector> = (0..BATCH_SIZE)
        .map(|i| generate_quantized_vector(100 + i))
        .collect();
    // Derive throughput from the data actually scanned, so the two cannot drift.
    group.throughput(Throughput::Elements(vectors.len() as u64));
    group.bench_function("768bit", |b| {
        b.iter(|| {
            for v in &vectors {
                black_box(black_box(&query).hamming_distance(black_box(v)));
            }
        })
    });
    group.finish();
}
// Registers every benchmark function in this file under one group named
// `benches`. This short form of `criterion_group!` uses criterion's default
// configuration; per-group settings such as sample size are applied inside
// the individual functions instead.
criterion_group!(
    benches,
    bench_dot_product,
    bench_l2_squared,
    bench_cosine,
    bench_hamming,
    bench_search,
    bench_distance_batch,
    bench_hamming_batch
);
// Generates the `main` entry point that runs the `benches` group.
// NOTE(review): criterion requires `harness = false` for this bench target in
// Cargo.toml; presumably set there — confirm.
criterion_main!(benches);