use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use std::sync::Arc;
use common::{DistanceMetric, QueryRequest, Vector};
use engine::SearchEngine;
use storage::InMemoryStorage;
use storage::VectorStorage;
/// Builds `count` deterministic pseudo-random vectors of `dimension` components.
///
/// Vector `i` gets the id `"{prefix}{i}"`. Its per-vector seed is the
/// `DefaultHasher` digest of `i`, and component `j` is the digest of
/// `seed + j` (wrapping on overflow) scaled into `[-1.0, 1.0]`.
/// `metadata`, `ttl_seconds` and `expires_at` are all left as `None`.
///
/// Returns an empty `Vec` when `count` is zero; each vector has an empty
/// `values` list when `dimension` is zero.
fn generate_vectors(count: usize, dimension: usize, prefix: &str) -> Vec<Vector> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    (0..count)
        .map(|i| {
            let mut hasher = DefaultHasher::new();
            i.hash(&mut hasher);
            let seed = hasher.finish();

            let values: Vec<f32> = (0..dimension)
                .map(|j| {
                    let mut h = DefaultHasher::new();
                    // `seed` is a full-range 64-bit digest, so a plain `+` could
                    // overflow and panic in builds with overflow checks enabled.
                    seed.wrapping_add(j as u64).hash(&mut h);
                    (h.finish() as f32 / u64::MAX as f32) * 2.0 - 1.0
                })
                .collect();

            Vector {
                id: format!("{}{}", prefix, i),
                values,
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            }
        })
        .collect()
}
/// Builds a deterministic pseudo-random query vector of `dimension` components.
///
/// Component `j` is the `DefaultHasher` digest of `seed + j` (wrapping on
/// overflow), scaled from the `u64` range into `[-1.0, 1.0]`. The same
/// `(dimension, seed)` pair yields the same vector, since every hasher is
/// created through `DefaultHasher::new()`.
///
/// Returns an empty `Vec` when `dimension` is zero.
fn generate_query_vector(dimension: usize, seed: u64) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    (0..dimension)
        .map(|j| {
            let mut hasher = DefaultHasher::new();
            // Wrapping add: a seed close to `u64::MAX` must not panic in builds
            // with overflow checks enabled.
            seed.wrapping_add(j as u64).hash(&mut hasher);
            (hasher.finish() as f32 / u64::MAX as f32) * 2.0 - 1.0
        })
        .collect()
}
/// Benchmarks `engine.search` for several `top_k` values.
///
/// Setup: a current-thread Tokio runtime, an `InMemoryStorage` holding 10,000
/// generated vectors of dimension 128 in the `"bench"` namespace, and a
/// `SearchEngine` on top of it. One benchmark is registered per `top_k` in
/// the group `query_top_k`, using the cosine metric and a query vector seeded
/// with the `top_k` value.
///
/// The request is built once per parameter, outside the timed closure, so the
/// measurement covers the `search` call (driven through `rt.block_on`) and
/// not a per-iteration clone of the query vector.
fn bench_query_top_k(c: &mut Criterion) {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("failed to build Tokio runtime for benchmark");
    let dimension = 128;
    let dataset_size = 10_000;
    let storage = Arc::new(InMemoryStorage::new());
    let engine = Arc::new(SearchEngine::new(storage.clone()));
    let namespace = "bench".to_string();

    // Populate the namespace once; every benchmark below reads the same data.
    rt.block_on(async {
        storage.ensure_namespace(&namespace).await.unwrap();
        let vectors = generate_vectors(dataset_size, dimension, "v");
        storage.upsert(&namespace, vectors).await.unwrap();
    });

    let mut group = c.benchmark_group("query_top_k");
    // Each iteration performs exactly one query.
    group.throughput(Throughput::Elements(1));

    for top_k in [1, 5, 10, 50, 100, 500] {
        // Loop-invariant for the timed closure: build it once and pass by reference.
        let request = QueryRequest {
            vector: generate_query_vector(dimension, top_k as u64),
            top_k,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: false,
            include_vectors: false,
            filter: None,
            cursor: None,
        };

        group.bench_with_input(BenchmarkId::from_parameter(top_k), &request, |b, request| {
            b.iter(|| {
                rt.block_on(async {
                    let response = engine.search(&namespace, request).await.unwrap();
                    black_box(response)
                })
            });
        });
    }

    group.finish();
}
/// Benchmarks `engine.search` against datasets of increasing size.
///
/// For each size a fresh `InMemoryStorage` and `SearchEngine` are created and
/// the `"bench"` namespace is filled with that many generated vectors of
/// dimension 128. Every benchmark in the group `query_dataset_sizes` issues
/// the same cosine query (seed 42, `top_k` 10).
///
/// The request is built once per dataset, outside the timed closure, so the
/// measurement covers the `search` call (driven through `rt.block_on`) and
/// not a per-iteration clone of the query vector.
fn bench_query_dataset_sizes(c: &mut Criterion) {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("failed to build Tokio runtime for benchmark");
    let dimension = 128;
    let top_k = 10;

    let mut group = c.benchmark_group("query_dataset_sizes");
    // Each iteration performs exactly one query.
    group.throughput(Throughput::Elements(1));

    for dataset_size in [100, 1_000, 5_000, 10_000, 50_000] {
        // Fresh storage per size so datasets do not accumulate across runs.
        let storage = Arc::new(InMemoryStorage::new());
        let engine = Arc::new(SearchEngine::new(storage.clone()));
        let namespace = "bench".to_string();

        rt.block_on(async {
            storage.ensure_namespace(&namespace).await.unwrap();
            let vectors = generate_vectors(dataset_size, dimension, "v");
            storage.upsert(&namespace, vectors).await.unwrap();
        });

        // Loop-invariant for the timed closure: build it once and pass by reference.
        let request = QueryRequest {
            vector: generate_query_vector(dimension, 42),
            top_k,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: false,
            include_vectors: false,
            filter: None,
            cursor: None,
        };

        group.bench_with_input(
            BenchmarkId::from_parameter(dataset_size),
            &request,
            |b, request| {
                b.iter(|| {
                    rt.block_on(async {
                        let response = engine.search(&namespace, request).await.unwrap();
                        black_box(response)
                    })
                });
            },
        );
    }

    group.finish();
}
/// Benchmarks `engine.search` once per distance metric.
///
/// Setup: 10,000 generated vectors of dimension 128 in the `"bench"`
/// namespace of an `InMemoryStorage`. The group `query_distance_metrics`
/// registers the benchmarks `cosine`, `euclidean` and `dot_product`, all
/// using the same query vector (seed 123) and `top_k` 10.
///
/// One request is built per metric, outside the timed closure, so the
/// measurement covers the `search` call (driven through `rt.block_on`) and
/// not a per-iteration clone of the query vector.
fn bench_query_distance_metrics(c: &mut Criterion) {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("failed to build Tokio runtime for benchmark");
    let dimension = 128;
    let dataset_size = 10_000;
    let top_k = 10;
    let storage = Arc::new(InMemoryStorage::new());
    let engine = Arc::new(SearchEngine::new(storage.clone()));
    let namespace = "bench".to_string();

    // Populate the namespace once; every metric queries the same data.
    rt.block_on(async {
        storage.ensure_namespace(&namespace).await.unwrap();
        let vectors = generate_vectors(dataset_size, dimension, "v");
        storage.upsert(&namespace, vectors).await.unwrap();
    });

    let query_vector = generate_query_vector(dimension, 123);

    let mut group = c.benchmark_group("query_distance_metrics");
    // Each iteration performs exactly one query.
    group.throughput(Throughput::Elements(1));

    for (name, metric) in [
        ("cosine", DistanceMetric::Cosine),
        ("euclidean", DistanceMetric::Euclidean),
        ("dot_product", DistanceMetric::DotProduct),
    ] {
        // The clone happens here, once per metric, not inside the timed loop.
        let request = QueryRequest {
            vector: query_vector.clone(),
            top_k,
            distance_metric: metric,
            include_metadata: false,
            include_vectors: false,
            filter: None,
            cursor: None,
        };

        group.bench_function(name, |b| {
            b.iter(|| {
                rt.block_on(async {
                    let response = engine.search(&namespace, &request).await.unwrap();
                    black_box(response)
                })
            });
        });
    }

    group.finish();
}
/// Benchmarks `engine.search` with the four combinations of the
/// `include_metadata` / `include_vectors` request flags.
///
/// Setup: 10,000 generated vectors of dimension 128, each given a JSON
/// metadata object with a `category` (`cat_0`..`cat_9`, from `i % 10`), a
/// numeric `value` (`i`) and a fixed two-element `tags` array. The group
/// `query_with_metadata` registers `exclude_metadata`, `include_metadata`,
/// `include_vectors` and `include_both`, all using the same cosine query
/// (seed 456, `top_k` 10).
///
/// The variants are table-driven and each request is built once, outside the
/// timed closure, so the measurement covers the `search` call (driven through
/// `rt.block_on`) and not a per-iteration clone of the query vector.
fn bench_query_with_metadata(c: &mut Criterion) {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("failed to build Tokio runtime for benchmark");
    let dimension = 128;
    let dataset_size = 10_000;
    let top_k = 10;
    let storage = Arc::new(InMemoryStorage::new());
    let engine = Arc::new(SearchEngine::new(storage.clone()));
    let namespace = "bench".to_string();

    // Populate the namespace once with metadata-carrying vectors.
    rt.block_on(async {
        storage.ensure_namespace(&namespace).await.unwrap();
        let mut vectors = generate_vectors(dataset_size, dimension, "v");
        for (i, v) in vectors.iter_mut().enumerate() {
            v.metadata = Some(serde_json::json!({
                "category": format!("cat_{}", i % 10),
                "value": i as f64,
                "tags": ["tag1", "tag2"]
            }));
        }
        storage.upsert(&namespace, vectors).await.unwrap();
    });

    let query_vector = generate_query_vector(dimension, 456);

    let mut group = c.benchmark_group("query_with_metadata");
    // Each iteration performs exactly one query.
    group.throughput(Throughput::Elements(1));

    // (benchmark name, include_metadata, include_vectors)
    for (name, include_metadata, include_vectors) in [
        ("exclude_metadata", false, false),
        ("include_metadata", true, false),
        ("include_vectors", false, true),
        ("include_both", true, true),
    ] {
        // The clone happens here, once per variant, not inside the timed loop.
        let request = QueryRequest {
            vector: query_vector.clone(),
            top_k,
            distance_metric: DistanceMetric::Cosine,
            include_metadata,
            include_vectors,
            filter: None,
            cursor: None,
        };

        group.bench_function(name, |b| {
            b.iter(|| {
                rt.block_on(async {
                    let response = engine.search(&namespace, &request).await.unwrap();
                    black_box(response)
                })
            });
        });
    }

    group.finish();
}
/// Benchmarks `engine.search` for vectors of increasing dimension.
///
/// For each dimension a fresh `InMemoryStorage` and `SearchEngine` are
/// created and the `"bench"` namespace is filled with 5,000 generated vectors
/// of that dimension. Every benchmark in the group `query_dimensions` issues
/// a cosine query (seed 789, `top_k` 10) of matching dimension.
///
/// The request is built once per dimension, outside the timed closure, so the
/// measurement covers the `search` call (driven through `rt.block_on`) and
/// not a per-iteration clone of the query vector.
fn bench_query_dimensions(c: &mut Criterion) {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("failed to build Tokio runtime for benchmark");
    let dataset_size = 5_000;
    let top_k = 10;

    let mut group = c.benchmark_group("query_dimensions");
    // Each iteration performs exactly one query.
    group.throughput(Throughput::Elements(1));

    for dimension in [32, 128, 384, 768, 1536] {
        // Fresh storage per dimension so stored and query vectors always match in size.
        let storage = Arc::new(InMemoryStorage::new());
        let engine = Arc::new(SearchEngine::new(storage.clone()));
        let namespace = "bench".to_string();

        rt.block_on(async {
            storage.ensure_namespace(&namespace).await.unwrap();
            let vectors = generate_vectors(dataset_size, dimension, "v");
            storage.upsert(&namespace, vectors).await.unwrap();
        });

        // Loop-invariant for the timed closure: build it once and pass by reference.
        let request = QueryRequest {
            vector: generate_query_vector(dimension, 789),
            top_k,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: false,
            include_vectors: false,
            filter: None,
            cursor: None,
        };

        group.bench_with_input(
            BenchmarkId::from_parameter(dimension),
            &request,
            |b, request| {
                b.iter(|| {
                    rt.block_on(async {
                        let response = engine.search(&namespace, request).await.unwrap();
                        black_box(response)
                    })
                });
            },
        );
    }

    group.finish();
}
// Registers every query benchmark in this file under the group name `benches`,
// in the order listed.
criterion_group!(
benches,
bench_query_top_k,
bench_query_dataset_sizes,
bench_query_distance_metrics,
bench_query_with_metadata,
bench_query_dimensions,
);
// Expands to the benchmark binary's entry point for the `benches` group.
criterion_main!(benches);