use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use edgevec::metric::scalar;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
use edgevec::metric::simd::x86;
use edgevec::metric::{DotProduct, L2Squared, Metric};
use edgevec::quantization::binary::QuantizedVector;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use std::hint::black_box;
/// Builds `count` pseudo-random `f32` vectors of length `dims`.
///
/// Every component is drawn from the half-open range `[-1.0, 1.0)` using a
/// `ChaCha8Rng` seeded with `seed`, so identical arguments always produce
/// identical data.
fn generate_vectors(count: usize, dims: usize, seed: u64) -> Vec<Vec<f32>> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(count);
    for _ in 0..count {
        let mut components: Vec<f32> = Vec::with_capacity(dims);
        for _ in 0..dims {
            components.push(rng.gen_range(-1.0..1.0));
        }
        vectors.push(components);
    }
    vectors
}
/// Builds `count` pseudo-random `u8` vectors of length `dims`.
///
/// Components cover the full `u8` domain, `0..=255`. The previous half-open
/// range `0..255` could never produce the value 255, so the benchmark inputs
/// never exercised the maximum byte value.
///
/// The generator is a `ChaCha8Rng` seeded with `seed`, so identical arguments
/// always produce identical data.
fn generate_u8_vectors(count: usize, dims: usize, seed: u64) -> Vec<Vec<u8>> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..count)
        .map(|_| {
            (0..dims)
                // Inclusive upper bound so that 255 is reachable.
                .map(|_| rng.gen_range(0..=u8::MAX))
                .collect()
        })
        .collect()
}
/// Creates a `QuantizedVector` from 96 pseudo-random bytes.
///
/// The bytes come from a `ChaCha8Rng` seeded with `seed`, one `gen()` call per
/// byte in index order, so the same seed always yields the same vector.
fn generate_quantized_vector(seed: u64) -> QuantizedVector {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut bytes = [0u8; 96];
    bytes.iter_mut().for_each(|slot| *slot = rng.gen());
    QuantizedVector::from_bytes(bytes)
}
/// Times a single `hamming_distance` call between two 96-byte vectors filled
/// with the complementary bit patterns `0xAA` and `0x55`.
fn bench_simd_hamming_cycles(c: &mut Criterion) {
    let lhs = QuantizedVector::from_bytes([0xAAu8; 96]);
    let rhs = QuantizedVector::from_bytes([0x55u8; 96]);

    c.bench_function("simd_hamming_96bytes", |bencher| {
        // `black_box` on both operands keeps the optimizer from folding the
        // constant inputs into a precomputed result.
        bencher.iter(|| black_box(&lhs).hamming_distance(black_box(&rhs)))
    });
}
/// Compares `QuantizedVector::hamming_distance` against a hand-written
/// byte-wise XOR + popcount loop over the same two 96-byte inputs.
///
/// Throughput is reported in bytes: both 96-byte operands are read per call.
fn bench_simd_vs_portable(c: &mut Criterion) {
    let lhs = QuantizedVector::from_bytes([0xAAu8; 96]);
    let rhs = QuantizedVector::from_bytes([0x55u8; 96]);

    let mut group = c.benchmark_group("simd_hamming_comparison");
    group.throughput(Throughput::Bytes(96 * 2));

    // NOTE(review): presumably this method dispatches to a SIMD kernel when
    // one is available — confirm against the `QuantizedVector` implementation.
    group.bench_function("simd_dispatch", |bencher| {
        bencher.iter(|| black_box(&lhs).hamming_distance(black_box(&rhs)))
    });

    group.bench_function("portable_baseline", |bencher| {
        bencher.iter(|| {
            let left = black_box(lhs.data());
            let right = black_box(rhs.data());
            // Indexed access is kept on purpose so the baseline workload is
            // the same per-byte XOR/popcount as before.
            let distance = (0..96usize)
                .fold(0u32, |acc, i| acc + (left[i] ^ right[i]).count_ones());
            black_box(distance)
        })
    });

    group.finish();
}
/// Reports `hamming_distance` throughput in operations per second: each
/// iteration counts as one element.
fn bench_simd_throughput(c: &mut Criterion) {
    let lhs = QuantizedVector::from_bytes([0xAAu8; 96]);
    let rhs = QuantizedVector::from_bytes([0x55u8; 96]);

    let mut group = c.benchmark_group("simd_hamming_throughput");
    group.throughput(Throughput::Elements(1));
    group.bench_function("hamming_ops_per_sec", |bencher| {
        bencher.iter(|| black_box(&lhs).hamming_distance(black_box(&rhs)))
    });
    group.finish();
}
/// Benchmarks `hamming_distance` over several input bit patterns: all zeros
/// against itself, all ones against all zeros, alternating `0xAA`/`0x55`, and
/// two seeded random vectors.
fn bench_simd_diverse_patterns(c: &mut Criterion) {
    let zeros = QuantizedVector::from_bytes([0x00u8; 96]);
    let ones = QuantizedVector::from_bytes([0xFFu8; 96]);
    let alt_aa = QuantizedVector::from_bytes([0xAAu8; 96]);
    let alt_55 = QuantizedVector::from_bytes([0x55u8; 96]);
    let random1 = generate_quantized_vector(42);
    let random2 = generate_quantized_vector(43);

    // (benchmark name, left operand, right operand), in registration order.
    let cases = [
        ("pattern_zeros_identical", &zeros, &zeros),
        ("pattern_ones_vs_zeros", &ones, &zeros),
        ("pattern_alternating", &alt_aa, &alt_55),
        ("pattern_random", &random1, &random2),
    ];

    let mut group = c.benchmark_group("simd_hamming_patterns");
    for (name, lhs, rhs) in cases {
        group.bench_function(name, |bencher| {
            bencher.iter(|| black_box(lhs).hamming_distance(black_box(rhs)))
        });
    }
    group.finish();
}
fn bench_simd_batch(c: &mut Criterion) {
let mut group = c.benchmark_group("simd_hamming_batch");
let batch_size = 1000;
let vectors: Vec<QuantizedVector> = (0..batch_size)
.map(|i| generate_quantized_vector(i as u64))
.collect();
let query = generate_quantized_vector(9999);
group.throughput(Throughput::Elements(batch_size as u64));
group.bench_function("batch_1000_vectors", |b| {
b.iter(|| {
for v in &vectors {
black_box(black_box(&query).hamming_distance(black_box(v)));
}
})
});
group.finish();
}
/// Benchmarks `L2Squared::distance` on `f32` vectors at 128, 768 and 1536
/// dimensions.
///
/// For each dimensionality, 10,001 vectors are generated with seed 42: the
/// first is the query and the remaining 10,000 are the targets compared
/// against it in every iteration.
fn bench_l2_squared(c: &mut Criterion) {
    const TARGET_COUNT: usize = 10_000;

    let mut group = c.benchmark_group("l2_squared");
    // The element count is the same for every dimensionality, so set it once.
    group.throughput(Throughput::Elements(TARGET_COUNT as u64));

    for dims in [128, 768, 1536] {
        let dataset = generate_vectors(TARGET_COUNT + 1, dims, 42);
        let (query, targets) = dataset
            .split_first()
            .expect("dataset always holds at least one vector");

        group.bench_with_input(BenchmarkId::from_parameter(dims), &dims, |bencher, _| {
            bencher.iter(|| {
                targets.iter().for_each(|target| {
                    black_box(L2Squared::distance(black_box(query), black_box(target)));
                });
            });
        });
    }
    group.finish();
}
/// Benchmarks squared-L2 distance on `u8` vectors at 128, 768 and 1536
/// dimensions.
///
/// The scalar implementation is always measured. The x86 implementation is
/// measured only when the crate is compiled for x86_64 with the `avx2` target
/// feature enabled. Each iteration compares one query against 10,000 targets
/// generated with seed 42.
fn bench_l2_squared_u8(c: &mut Criterion) {
    const TARGET_COUNT: usize = 10_000;

    let mut group = c.benchmark_group("l2_squared_u8");
    // The element count is the same for every dimensionality, so set it once.
    group.throughput(Throughput::Elements(TARGET_COUNT as u64));

    for dims in [128, 768, 1536] {
        let dataset = generate_u8_vectors(TARGET_COUNT + 1, dims, 42);
        let (query, targets) = dataset
            .split_first()
            .expect("dataset always holds at least one vector");

        group.bench_with_input(BenchmarkId::new("scalar", dims), &dims, |bencher, _| {
            bencher.iter(|| {
                targets.iter().for_each(|target| {
                    black_box(scalar::l2_squared_u8(black_box(query), black_box(target)));
                });
            });
        });

        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            group.bench_with_input(BenchmarkId::new("avx2", dims), &dims, |bencher, _| {
                bencher.iter(|| {
                    targets.iter().for_each(|target| {
                        // SAFETY: this block only compiles when the `avx2`
                        // target feature is enabled, and every vector here has
                        // the same length (`dims`).
                        // NOTE(review): presumably those are the callee's only
                        // requirements — confirm against `x86::l2_squared_u8`.
                        black_box(unsafe {
                            x86::l2_squared_u8(black_box(query), black_box(target))
                        });
                    });
                });
            });
        }
    }
    group.finish();
}
/// Benchmarks `DotProduct::distance` on `f32` vectors at 128 and 768
/// dimensions.
///
/// For each dimensionality, 10,001 vectors are generated with seed 42: the
/// first is the query and the remaining 10,000 are the targets compared
/// against it in every iteration.
fn bench_dot_product(c: &mut Criterion) {
    const TARGET_COUNT: usize = 10_000;

    let mut group = c.benchmark_group("dot_product");
    // The element count is the same for every dimensionality, so set it once.
    group.throughput(Throughput::Elements(TARGET_COUNT as u64));

    for dims in [128, 768] {
        let dataset = generate_vectors(TARGET_COUNT + 1, dims, 42);
        let (query, targets) = dataset
            .split_first()
            .expect("dataset always holds at least one vector");

        group.bench_with_input(BenchmarkId::from_parameter(dims), &dims, |bencher, _| {
            bencher.iter(|| {
                targets.iter().for_each(|target| {
                    black_box(DotProduct::distance(black_box(query), black_box(target)));
                });
            });
        });
    }
    group.finish();
}
/// Estimates the average number of time-stamp-counter ticks per call of `f`.
///
/// `f` is first called 1000 times as warm-up, then called `iterations` times
/// between two `_rdtsc` reads. The tick difference is divided by `iterations`
/// using integer division.
///
/// Returns 0 when `iterations` is 0 (nothing is measured and `f` is never
/// called), and also when the second counter read is lower than the first,
/// rather than panicking on division by zero or subtraction overflow.
#[cfg(target_arch = "x86_64")]
pub fn measure_cycles<F>(f: F, iterations: u64) -> u64
where
    F: Fn() -> u32,
{
    use std::arch::x86_64::_rdtsc;

    const WARMUP_RUNS: u32 = 1000;

    // Guard the division below; there is nothing to average over.
    if iterations == 0 {
        return 0;
    }

    for _ in 0..WARMUP_RUNS {
        std::hint::black_box(f());
    }

    // SAFETY: `_rdtsc` only reads the processor's time-stamp counter; it takes
    // no pointers and touches no memory.
    let start = unsafe { _rdtsc() };
    for _ in 0..iterations {
        std::hint::black_box(f());
    }
    // SAFETY: same as the first read.
    let end = unsafe { _rdtsc() };

    // Saturate so that a lower second reading yields 0 instead of an overflow
    // panic in debug builds.
    end.saturating_sub(start) / iterations
}
/// Measures the average tick count of `hamming_distance` on two 96-byte
/// vectors over 10,000 iterations and checks it against the cycle budget.
///
/// Fails when the measurement is 75 cycles or more. Between 50 and 74 cycles
/// it only prints a warning.
#[test]
#[cfg(target_arch = "x86_64")]
fn test_simd_cycle_target() {
    let lhs = QuantizedVector::from_bytes([0xAAu8; 96]);
    let rhs = QuantizedVector::from_bytes([0x55u8; 96]);
    let cycles = measure_cycles(|| lhs.hamming_distance(&rhs), 10_000);

    // Report header.
    println!("=================================================");
    println!("SIMD Hamming Distance Cycle Measurement");
    println!("=================================================");
    println!("Measured cycles: {}", cycles);
    println!("Target: <50 cycles");
    println!("Hard limit: <75 cycles");
    println!("=================================================");

    // The hard limit is enforced before the soft target is reported.
    assert!(
        cycles < 75,
        "FAIL: {} cycles exceeds hard limit of 75 cycles",
        cycles
    );

    if cycles < 50 {
        println!("SUCCESS: {} cycles meets <50 cycle target", cycles);
    } else {
        println!("WARNING: {} cycles exceeds target of 50 cycles", cycles);
        println!("Implementation meets hard limit but not target.");
    }
}
// Collects every benchmark function defined above into one Criterion group
// named `benches`. The Hamming-distance benchmarks are listed first, followed
// by the L2 and dot-product benchmarks.
criterion_group!(
benches,
bench_simd_hamming_cycles,
bench_simd_vs_portable,
bench_simd_throughput,
bench_simd_diverse_patterns,
bench_simd_batch,
bench_l2_squared,
bench_l2_squared_u8,
bench_dot_product
);
// Generates the benchmark binary's entry point, which runs the `benches` group.
criterion_main!(benches);