use std::env;
use std::time::Instant;
use serde::Serialize;
use turboquant::kv_cache::{
KVCacheConfig, MultiHeadConfig, MultiHeadKVCache, QuantStrategy, QuantizedKVCache,
};
use turboquant::polar::PolarQuant;
use turboquant::turboquant_mse::TurboQuantMSE;
use turboquant::turboquant_prod::{ProdQuantized, TurboQuantProd};
use turboquant::utils::{inner_product, normalize};
/// Generates `count` unit-norm Gaussian random vectors of dimension `dim`.
///
/// The RNG is seeded with `seed` so every benchmark run sees identical data.
fn random_unit_vectors(dim: usize, count: usize, seed: u64) -> Vec<Vec<f64>> {
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let gaussian = Normal::new(0.0, 1.0).unwrap();
    let mut vectors = Vec::with_capacity(count);
    for _ in 0..count {
        let sample: Vec<f64> = (0..dim).map(|_| gaussian.sample(&mut rng)).collect();
        vectors.push(normalize(&sample).unwrap());
    }
    vectors
}
/// How the final benchmark report is emitted: human-readable text (default)
/// or pretty-printed JSON (`--json`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OutputFormat {
    Text,
    Json,
}
/// Parsed command-line options controlling workload sizes and output format.
#[derive(Clone, Copy, Debug)]
struct BenchmarkConfig {
    // `--quick`: shrink every workload for a fast smoke run.
    quick: bool,
    // `--json`: emit the report as JSON instead of formatted text.
    output_format: OutputFormat,
}
impl BenchmarkConfig {
    /// Selects the quick-mode value or the full-mode value of a setting,
    /// depending on whether `--quick` was passed.
    fn pick<T>(self, quick_value: T, full_value: T) -> T {
        if self.quick {
            quick_value
        } else {
            full_value
        }
    }
    /// Human-readable label for the active mode (used in the report header).
    fn mode(self) -> &'static str {
        self.pick("quick", "full")
    }
    /// Vector dimensions exercised by the MSE accuracy benchmark.
    fn mse_dims(self) -> &'static [usize] {
        match self.quick {
            true => &[64, 128],
            false => &[64, 128, 256, 512],
        }
    }
    /// Vectors sampled per dimension in the MSE accuracy benchmark.
    fn mse_vector_count(self) -> usize {
        self.pick(20, 100)
    }
    /// Vector dimensions exercised by the inner-product benchmark.
    fn prod_dims(self) -> &'static [usize] {
        match self.quick {
            true => &[64, 128],
            false => &[64, 128, 256],
        }
    }
    /// Vectors sampled per dimension in the inner-product benchmark.
    fn prod_vector_count(self) -> usize {
        self.pick(20, 50)
    }
    /// Vector dimensions exercised by the PolarQuant benchmark.
    fn polar_dims(self) -> &'static [usize] {
        match self.quick {
            true => &[8, 16, 32, 64],
            false => &[8, 16, 32, 64, 128],
        }
    }
    /// Vectors sampled per dimension in the PolarQuant benchmark.
    fn polar_vector_count(self) -> usize {
        self.pick(20, 50)
    }
    /// Vectors quantized when measuring throughput.
    fn throughput_count(self) -> usize {
        self.pick(50, 200)
    }
    /// Database size for the recall@k benchmark.
    fn recall_db_size(self) -> usize {
        self.pick(100, 500)
    }
    /// Query count for the recall@k benchmark.
    fn recall_query_count(self) -> usize {
        self.pick(10, 50)
    }
    /// Vectors appended per batch in the KV-cache benchmark.
    fn kv_batch_size(self) -> usize {
        self.pick(8, 16)
    }
    /// Batches appended in the KV-cache benchmark.
    fn kv_num_batches(self) -> usize {
        self.pick(5, 20)
    }
    /// Repeated attention-score queries timed in the KV-cache benchmark.
    fn kv_query_count(self) -> usize {
        self.pick(20, 100)
    }
    /// Head count for the multi-head attention benchmark.
    fn multi_head_num_heads(self) -> usize {
        self.pick(4, 8)
    }
    /// Sequence length for the multi-head attention benchmark.
    fn multi_head_seq_len(self) -> usize {
        self.pick(16, 32)
    }
}
/// Top-level report aggregating every benchmark's results; serialized as-is
/// when `--json` is requested.
#[derive(Debug, Serialize)]
struct BenchmarkReport {
    // "quick" or "full" (see `BenchmarkConfig::mode`).
    mode: String,
    mse_accuracy: Vec<MseAccuracyRow>,
    prod_accuracy: Vec<ProdAccuracyRow>,
    polar_accuracy: Vec<PolarAccuracyRow>,
    throughput: Vec<ThroughputRow>,
    recall_at_k: Vec<RecallRow>,
    kv_cache_incremental: KvCacheRow,
    multi_head_attention: MultiHeadRow,
}
/// One TurboQuantMSE accuracy measurement for a (dim, bits) pair.
#[derive(Debug, Serialize)]
struct MseAccuracyRow {
    dim: usize,
    bits: u8,
    // Mean of `actual_mse` over all sampled vectors.
    avg_mse: f64,
    // Theoretical distortion bound reported by the quantizer.
    bound: f64,
    // Compression ratio reported by `compression_ratio()` on one quantized vector.
    compression_ratio: f64,
}
/// One TurboQuantProd inner-product error measurement for a (dim, bits) pair.
#[derive(Debug, Serialize)]
struct ProdAccuracyRow {
    dim: usize,
    bits: u8,
    // Mean absolute error between exact and estimated inner products.
    avg_abs_error: f64,
    // Largest absolute error observed over the sample.
    max_abs_error: f64,
    // Error bound from `distortion_bound(1.0)`.
    bound: f64,
}
/// One PolarQuant round-trip (quantize + dequantize) accuracy measurement.
#[derive(Debug, Serialize)]
struct PolarAccuracyRow {
    dim: usize,
    bits: u8,
    // Mean per-component squared reconstruction error over the sample.
    avg_mse: f64,
    // Compression ratio reported by `compression_ratio()` on one quantized vector.
    compression_ratio: f64,
}
/// Quantization throughput for one quantizer family at a (dim, bits) pair.
#[derive(Debug, Serialize)]
struct ThroughputRow {
    // Quantizer name, e.g. "TurboQuantMSE".
    method: String,
    dim: usize,
    bits: u8,
    // Number of vectors quantized during the timed run.
    count: usize,
    vectors_per_second: f64,
}
/// Recall@{1,5,10} for approximate nearest-neighbor search with one quantizer.
#[derive(Debug, Serialize)]
struct RecallRow {
    // Quantizer name, e.g. "PolarQuant".
    method: String,
    bits: u8,
    recall_at_1: f64,
    recall_at_5: f64,
    recall_at_10: f64,
}
/// Results of the incremental single-head KV-cache benchmark.
#[derive(Debug, Serialize)]
struct KvCacheRow {
    dim: usize,
    batch_size: usize,
    num_batches: usize,
    // batch_size * num_batches.
    total_tokens: usize,
    // Wall time for all `append` calls, in milliseconds.
    append_time_ms: f64,
    append_tokens_per_second: f64,
    // Number of repeated attention-score queries timed.
    attention_query_count: usize,
    average_query_time_ms: f64,
    // Storage and ratio reported by `cache.stats()`.
    compressed_bytes: usize,
    compression_ratio: f64,
}
/// Results of the multi-head KV-cache benchmark.
#[derive(Debug, Serialize)]
struct MultiHeadRow {
    num_heads: usize,
    head_dim: usize,
    seq_len: usize,
    // Wall time for `append_all`, in milliseconds.
    append_time_ms: f64,
    // Wall time for `attention_output_all`, in milliseconds.
    attention_output_time_ms: f64,
    // Length of the concatenated per-head outputs.
    output_dim: usize,
    // Memory figures reported by `cache.stats()`.
    compressed_bytes: usize,
    uncompressed_bytes: usize,
    compression_ratio: f64,
}
/// Parses CLI flags (`--quick`, `--json`, `--help`/`-h`).
///
/// Prints usage and exits 0 on `--help`; exits 2 on any unknown argument.
fn parse_args() -> BenchmarkConfig {
    let mut config = BenchmarkConfig {
        quick: false,
        output_format: OutputFormat::Text,
    };
    for arg in env::args().skip(1) {
        match arg.as_str() {
            "--quick" => config.quick = true,
            "--json" => config.output_format = OutputFormat::Json,
            "--help" | "-h" => {
                println!("Usage: cargo run --release --example benchmark -- [--quick] [--json]");
                std::process::exit(0);
            }
            other => {
                eprintln!("unknown argument: {other}");
                std::process::exit(2);
            }
        }
    }
    config
}
/// Entry point: runs every benchmark suite, then emits the report either as
/// formatted text or as pretty-printed JSON depending on `--json`.
fn main() {
    let config = parse_args();
    let report = BenchmarkReport {
        mode: config.mode().to_string(),
        mse_accuracy: bench_mse_accuracy(config),
        prod_accuracy: bench_prod_accuracy(config),
        polar_accuracy: bench_polar_accuracy(config),
        throughput: bench_quantize_throughput(config),
        recall_at_k: bench_recall_at_k(config),
        kv_cache_incremental: bench_kv_cache_incremental(config),
        multi_head_attention: bench_multi_head_attention(config),
    };
    if config.output_format == OutputFormat::Json {
        println!("{}", serde_json::to_string_pretty(&report).unwrap());
    } else {
        print_report(&report);
    }
}
/// Writes the full human-readable report to stdout, section by section.
fn print_report(report: &BenchmarkReport) {
    let banner = "========================================";
    println!("{banner}");
    println!(" TurboQuant Benchmark Suite ({})", report.mode);
    println!("{banner}\n");
    print_mse_accuracy(&report.mse_accuracy);
    print_prod_accuracy(&report.prod_accuracy);
    print_polar_accuracy(&report.polar_accuracy);
    print_throughput(&report.throughput);
    print_recall(&report.recall_at_k);
    print_kv_cache(&report.kv_cache_incremental);
    print_multi_head(&report.multi_head_attention);
    println!("{banner}");
    println!(" All benchmarks complete.");
    println!("{banner}");
}
/// Renders the TurboQuantMSE accuracy table, one line per (dim, bits) row.
fn print_mse_accuracy(rows: &[MseAccuracyRow]) {
    println!("--- TurboQuantMSE: Accuracy vs Bit Width ---\n");
    println!(
        " {:>6} {:>4} {:>12} {:>12} {:>10}",
        "dim", "bits", "avg_mse", "bound", "ratio"
    );
    println!(
        " {:-<6} {:-<4} {:-<12} {:-<12} {:-<10}",
        "", "", "", "", ""
    );
    rows.iter().for_each(|r| {
        println!(
            " {:>6} {:>4} {:>12.8} {:>12.8} {:>9.1}x",
            r.dim, r.bits, r.avg_mse, r.bound, r.compression_ratio
        );
    });
    println!();
}
/// Renders the TurboQuantProd inner-product error table.
fn print_prod_accuracy(rows: &[ProdAccuracyRow]) {
    println!("--- TurboQuantProd: Inner Product Error ---\n");
    println!(
        " {:>6} {:>4} {:>14} {:>14} {:>14}",
        "dim", "bits", "avg_|err|", "max_|err|", "bound"
    );
    println!(
        " {:-<6} {:-<4} {:-<14} {:-<14} {:-<14}",
        "", "", "", "", ""
    );
    rows.iter().for_each(|r| {
        println!(
            " {:>6} {:>4} {:>14.8} {:>14.8} {:>14.8}",
            r.dim, r.bits, r.avg_abs_error, r.max_abs_error, r.bound
        );
    });
    println!();
}
/// Renders the PolarQuant reconstruction-accuracy table.
fn print_polar_accuracy(rows: &[PolarAccuracyRow]) {
    println!("--- PolarQuant: Reconstruction Accuracy ---\n");
    println!(
        " {:>6} {:>4} {:>12} {:>10}",
        "dim", "bits", "avg_mse", "ratio"
    );
    println!(" {:-<6} {:-<4} {:-<12} {:-<10}", "", "", "", "");
    rows.iter().for_each(|r| {
        println!(
            " {:>6} {:>4} {:>12.8} {:>9.1}x",
            r.dim, r.bits, r.avg_mse, r.compression_ratio
        );
    });
    println!();
}
/// Renders the quantization throughput table.
fn print_throughput(rows: &[ThroughputRow]) {
    println!("--- Quantization Throughput ---\n");
    println!(
        " {:>20} {:>6} {:>4} {:>10} {:>12}",
        "method", "dim", "bits", "count", "vectors/s"
    );
    println!(
        " {:-<20} {:-<6} {:-<4} {:-<10} {:-<12}",
        "", "", "", "", ""
    );
    rows.iter().for_each(|r| {
        println!(
            " {:>20} {:>6} {:>4} {:>10} {:>12.0}",
            r.method, r.dim, r.bits, r.count, r.vectors_per_second
        );
    });
    println!();
}
/// Renders the recall@k table for approximate nearest-neighbor search.
fn print_recall(rows: &[RecallRow]) {
    println!("--- Recall@k: Approximate Nearest Neighbor Search ---\n");
    println!(
        " {:>20} {:>4} {:>10} {:>10} {:>10}",
        "method", "bits", "R@1", "R@5", "R@10"
    );
    println!(
        " {:-<20} {:-<4} {:-<10} {:-<10} {:-<10}",
        "", "", "", "", ""
    );
    rows.iter().for_each(|r| {
        println!(
            " {:>20} {:>4} {:>10.4} {:>10.4} {:>10.4}",
            r.method, r.bits, r.recall_at_1, r.recall_at_5, r.recall_at_10
        );
    });
    println!();
}
fn print_kv_cache(row: &KvCacheRow) {
println!("--- KV Cache: Incremental Append Performance ---\n");
println!(
" Appending {} batches of {} vectors (dim={})...",
row.num_batches, row.batch_size, row.dim
);
println!(
" Total tokens: {}, Append time: {:.1}ms ({:.0} tokens/s)",
row.total_tokens, row.append_time_ms, row.append_tokens_per_second
);
println!(
" Storage: {} bytes ({:.1}x compression)",
row.compressed_bytes, row.compression_ratio
);
println!(
" Attention query: {:.2}ms avg ({} queries)",
row.average_query_time_ms, row.attention_query_count
);
println!();
}
fn print_multi_head(row: &MultiHeadRow) {
println!("--- Multi-Head KV Cache ---\n");
println!(
" Config: {} heads x {} dim, {} tokens",
row.num_heads, row.head_dim, row.seq_len
);
println!(" Append time: {:.1}ms", row.append_time_ms);
println!(
" Attention output time: {:.2}ms",
row.attention_output_time_ms
);
println!(
" Output dim: {} ({}x{})",
row.output_dim, row.num_heads, row.head_dim
);
println!(
" Memory: {} bytes compressed, {} bytes uncompressed ({:.1}x)",
row.compressed_bytes, row.uncompressed_bytes, row.compression_ratio
);
println!();
}
/// Measures TurboQuantMSE reconstruction MSE against its theoretical
/// distortion bound across dimensions and bit widths.
fn bench_mse_accuracy(config: BenchmarkConfig) -> Vec<MseAccuracyRow> {
    let mut rows = Vec::new();
    for &dim in config.mse_dims() {
        let samples = random_unit_vectors(dim, config.mse_vector_count(), 42);
        for &bits in &[2u8, 3, 4] {
            let tq = TurboQuantMSE::new(dim, bits, 42).unwrap();
            let total_mse: f64 = samples.iter().map(|v| tq.actual_mse(v).unwrap()).sum();
            // One representative quantization just to read the compression ratio.
            let sample_quant = tq.quantize(&samples[0]).unwrap();
            rows.push(MseAccuracyRow {
                dim,
                bits,
                avg_mse: total_mse / samples.len() as f64,
                bound: tq.distortion_bound(),
                compression_ratio: sample_quant.compression_ratio(),
            });
        }
    }
    rows
}
/// Measures TurboQuantProd's inner-product estimation error (average and
/// worst case) against its bound, across dimensions and bit widths.
fn bench_prod_accuracy(config: BenchmarkConfig) -> Vec<ProdAccuracyRow> {
    let mut rows = Vec::new();
    for &dim in config.prod_dims() {
        let data = random_unit_vectors(dim, config.prod_vector_count(), 7);
        let probes = random_unit_vectors(dim, config.prod_vector_count(), 99);
        for &bits in &[3u8, 4] {
            let tq = TurboQuantProd::new(dim, bits, 7).unwrap();
            let mut sum_err = 0.0;
            let mut worst_err: f64 = 0.0;
            for (v, probe) in data.iter().zip(probes.iter()) {
                let exact = inner_product(v, probe);
                let quantized = tq.quantize(v).unwrap();
                let estimated = tq.estimate_inner_product(&quantized, probe).unwrap();
                let err = (exact - estimated).abs();
                sum_err += err;
                worst_err = worst_err.max(err);
            }
            rows.push(ProdAccuracyRow {
                dim,
                bits,
                avg_abs_error: sum_err / data.len() as f64,
                max_abs_error: worst_err,
                bound: tq.distortion_bound(1.0),
            });
        }
    }
    rows
}
/// Measures PolarQuant's quantize/dequantize round-trip MSE across
/// dimensions and bit widths.
fn bench_polar_accuracy(config: BenchmarkConfig) -> Vec<PolarAccuracyRow> {
    let mut rows = Vec::new();
    for &dim in config.polar_dims() {
        let samples = random_unit_vectors(dim, config.polar_vector_count(), 42);
        for &bits in &[4u8, 8] {
            let pq = PolarQuant::new(dim, 42, bits).unwrap();
            let mse_sum: f64 = samples
                .iter()
                .map(|v| {
                    let recon = pq.dequantize(&pq.quantize(v).unwrap()).unwrap();
                    let squared_err: f64 = v
                        .iter()
                        .zip(recon.iter())
                        .map(|(a, b)| (a - b) * (a - b))
                        .sum();
                    squared_err / dim as f64
                })
                .sum();
            // One representative quantization just to read the compression ratio.
            let sample_quant = pq.quantize(&samples[0]).unwrap();
            rows.push(PolarAccuracyRow {
                dim,
                bits,
                avg_mse: mse_sum / samples.len() as f64,
                compression_ratio: sample_quant.compression_ratio(),
            });
        }
    }
    rows
}
/// Measures single-vector quantization throughput (vectors/second) for each
/// quantizer family at dim = 128.
///
/// The three per-method loops in the original duplicated the timing logic;
/// they now share one `measure` harness so the measurement methodology cannot
/// drift between methods.
fn bench_quantize_throughput(config: BenchmarkConfig) -> Vec<ThroughputRow> {
    let dim = 128;
    let count = config.throughput_count();
    let vectors = random_unit_vectors(dim, count, 42);
    // Shared timing harness: runs `quantize_one` over the whole vector set
    // and converts elapsed wall time into a throughput row. Takes `&Vec<f64>`
    // so each quantizer sees exactly the argument the original loops passed.
    let measure = |method: &str, bits: u8, quantize_one: &dyn Fn(&Vec<f64>)| -> ThroughputRow {
        let start = Instant::now();
        for v in &vectors {
            quantize_one(v);
        }
        let elapsed = start.elapsed();
        ThroughputRow {
            method: method.into(),
            dim,
            bits,
            count,
            vectors_per_second: count as f64 / elapsed.as_secs_f64(),
        }
    };
    let mut rows = Vec::new();
    for &bits in &[2u8, 4] {
        let tq = TurboQuantMSE::new(dim, bits, 42).unwrap();
        rows.push(measure("TurboQuantMSE", bits, &|v| {
            let _ = tq.quantize(v).unwrap();
        }));
    }
    for &bits in &[3u8, 4] {
        let tq = TurboQuantProd::new(dim, bits, 42).unwrap();
        rows.push(measure("TurboQuantProd", bits, &|v| {
            let _ = tq.quantize(v).unwrap();
        }));
    }
    for &bits in &[4u8, 8] {
        let pq = PolarQuant::new(dim, 42, bits).unwrap();
        rows.push(measure("PolarQuant", bits, &|v| {
            let _ = pq.quantize(v).unwrap();
        }));
    }
    rows
}
/// Measures recall@{1,5,10} for approximate nearest-neighbor search over a
/// random unit-vector database, for each quantizer family and bit width.
fn bench_recall_at_k(config: BenchmarkConfig) -> Vec<RecallRow> {
    let dim = 128;
    let db_size = config.recall_db_size();
    let num_queries = config.recall_query_count();
    let ks = [1, 5, 10];
    let database = random_unit_vectors(dim, db_size, 42);
    let queries = random_unit_vectors(dim, num_queries, 99);
    // Exact descending inner-product ranking of the database for every query.
    let ground_truth: Vec<Vec<usize>> = queries
        .iter()
        .map(|query| {
            let mut ranked: Vec<(usize, f64)> = database
                .iter()
                .enumerate()
                .map(|(idx, vec)| (idx, inner_product(query, vec)))
                .collect();
            ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap());
            ranked.into_iter().map(|(idx, _)| idx).collect()
        })
        .collect();
    let mut rows = Vec::new();
    // TurboQuantMSE: search against the dequantized (reconstructed) vectors.
    for &bits in &[2u8, 3, 4] {
        let tq = TurboQuantMSE::new(dim, bits, 42).unwrap();
        let reconstructed: Vec<_> = database
            .iter()
            .map(|v| tq.dequantize(&tq.quantize(v).unwrap()).unwrap())
            .collect();
        let recalls = compute_recalls(&queries, &reconstructed, &ground_truth, &ks, |q, v| {
            inner_product(q, v)
        });
        rows.push(RecallRow {
            method: "TurboQuantMSE".into(),
            bits,
            recall_at_1: recalls[0],
            recall_at_5: recalls[1],
            recall_at_10: recalls[2],
        });
    }
    // TurboQuantProd: scores are estimated directly in the compressed domain.
    for &bits in &[3u8, 4] {
        let tq = TurboQuantProd::new(dim, bits, 42).unwrap();
        let compressed: Vec<_> = database.iter().map(|v| tq.quantize(v).unwrap()).collect();
        let recalls = compute_recalls_prod(&queries, &compressed, &ground_truth, &ks, &tq);
        rows.push(RecallRow {
            method: "TurboQuantProd".into(),
            bits,
            recall_at_1: recalls[0],
            recall_at_5: recalls[1],
            recall_at_10: recalls[2],
        });
    }
    // PolarQuant: quantize + dequantize round trip, then exact scoring.
    for &bits in &[4u8, 8] {
        let pq = PolarQuant::new(dim, 42, bits).unwrap();
        let reconstructed: Vec<_> = database
            .iter()
            .map(|v| pq.dequantize(&pq.quantize(v).unwrap()).unwrap())
            .collect();
        let recalls = compute_recalls(&queries, &reconstructed, &ground_truth, &ks, |q, v| {
            inner_product(q, v)
        });
        rows.push(RecallRow {
            method: "PolarQuant".into(),
            bits,
            recall_at_1: recalls[0],
            recall_at_5: recalls[1],
            recall_at_10: recalls[2],
        });
    }
    rows
}
/// Computes recall@k averaged over all queries.
///
/// Each query is scored against `reconstructed_db` with `score_fn`; the top
/// `max(ks)` approximate neighbors are then compared with the exact ranking
/// in `ground_truth`. Returns one averaged recall value per entry of `ks`.
fn compute_recalls<F>(
    queries: &[Vec<f64>],
    reconstructed_db: &[Vec<f64>],
    ground_truth: &[Vec<usize>],
    ks: &[usize],
    score_fn: F,
) -> Vec<f64>
where
    F: Fn(&[f64], &[f64]) -> f64,
{
    let max_k = *ks.iter().max().unwrap();
    let mut totals = vec![0.0f64; ks.len()];
    for (qi, query) in queries.iter().enumerate() {
        // Rank the whole database by approximate score, descending.
        let mut ranked: Vec<(usize, f64)> = reconstructed_db
            .iter()
            .enumerate()
            .map(|(idx, vec)| (idx, score_fn(query, vec)))
            .collect();
        ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap());
        let approx: Vec<usize> = ranked.into_iter().take(max_k).map(|(idx, _)| idx).collect();
        for (slot, &k) in totals.iter_mut().zip(ks) {
            let exact: std::collections::HashSet<usize> =
                ground_truth[qi].iter().take(k).copied().collect();
            let hits = approx
                .iter()
                .take(k)
                .filter(|idx| exact.contains(idx))
                .count();
            *slot += hits as f64 / k as f64;
        }
    }
    totals.iter().map(|t| t / queries.len() as f64).collect()
}
/// Recall@k where database scores come from `TurboQuantProd`'s asymmetric
/// inner-product estimator (the database stays in its compressed form).
fn compute_recalls_prod(
    queries: &[Vec<f64>],
    quantized_db: &[ProdQuantized],
    ground_truth: &[Vec<usize>],
    ks: &[usize],
    tq: &TurboQuantProd,
) -> Vec<f64> {
    let max_k = *ks.iter().max().unwrap();
    let mut totals = vec![0.0f64; ks.len()];
    for (qi, query) in queries.iter().enumerate() {
        // Rank the compressed database by estimated score, descending.
        let mut ranked: Vec<(usize, f64)> = quantized_db
            .iter()
            .enumerate()
            .map(|(idx, qv)| (idx, tq.estimate_inner_product(qv, query).unwrap()))
            .collect();
        ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap());
        let approx: Vec<usize> = ranked.into_iter().take(max_k).map(|(idx, _)| idx).collect();
        for (slot, &k) in totals.iter_mut().zip(ks) {
            let exact: std::collections::HashSet<usize> =
                ground_truth[qi].iter().take(k).copied().collect();
            let hits = approx
                .iter()
                .take(k)
                .filter(|idx| exact.contains(idx))
                .count();
            *slot += hits as f64 / k as f64;
        }
    }
    totals.iter().map(|t| t / queries.len() as f64).collect()
}
/// Benchmarks incremental append plus attention-score queries on a single
/// quantized KV cache (dim = 128, 4-bit keys and values, MSE key strategy).
fn bench_kv_cache_incremental(config: BenchmarkConfig) -> KvCacheRow {
    let dim = 128;
    let batch_size = config.kv_batch_size();
    let num_batches = config.kv_num_batches();
    let cache_config = KVCacheConfig::new(dim)
        .with_key_bits(4)
        .with_value_bits(4)
        .with_key_strategy(QuantStrategy::MSE)
        .with_seed(42);
    let mut cache = QuantizedKVCache::new(cache_config).unwrap();
    // Append phase: every batch gets its own deterministic seed so runs are
    // reproducible.
    let append_start = Instant::now();
    for batch in 0..num_batches {
        let batch_keys = random_unit_vectors(dim, batch_size, batch as u64 * 100);
        let batch_vals = random_unit_vectors(dim, batch_size, batch as u64 * 100 + 50);
        cache.append(&batch_keys, &batch_vals).unwrap();
    }
    let append_elapsed = append_start.elapsed();
    let stats = cache.stats();
    let total_tokens = num_batches * batch_size;
    // Query phase: repeat the same attention-score query to average out
    // timer noise.
    let query = &random_unit_vectors(dim, 1, 999)[0];
    let query_count = config.kv_query_count();
    let query_start = Instant::now();
    for _ in 0..query_count {
        let _ = cache.attention_scores(query).unwrap();
    }
    let query_elapsed = query_start.elapsed();
    KvCacheRow {
        dim,
        batch_size,
        num_batches,
        total_tokens,
        append_time_ms: append_elapsed.as_secs_f64() * 1000.0,
        append_tokens_per_second: total_tokens as f64 / append_elapsed.as_secs_f64(),
        attention_query_count: query_count,
        average_query_time_ms: query_elapsed.as_secs_f64() * 1000.0 / query_count as f64,
        compressed_bytes: stats.total_bytes,
        compression_ratio: stats.compression_ratio,
    }
}
/// Benchmarks a multi-head quantized KV cache: batch append, per-head
/// attention output, and compression statistics.
fn bench_multi_head_attention(config: BenchmarkConfig) -> MultiHeadRow {
    let num_heads = config.multi_head_num_heads();
    let head_dim = 64;
    let seq_len = config.multi_head_seq_len();
    let cache_config = MultiHeadConfig::new(
        num_heads,
        KVCacheConfig::new(head_dim)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::Prod)
            .with_seed(42),
    );
    let mut cache = MultiHeadKVCache::new(cache_config).unwrap();
    // One sequence of unit vectors per head, deterministically seeded per head.
    let per_head = |seed_offset: u64| -> Vec<Vec<Vec<f64>>> {
        (0..num_heads)
            .map(|h| random_unit_vectors(head_dim, seq_len, h as u64 * 1000 + seed_offset))
            .collect()
    };
    let keys = per_head(0);
    let values = per_head(500);
    let append_start = Instant::now();
    cache.append_all(&keys, &values).unwrap();
    let append_elapsed = append_start.elapsed();
    // One query vector per head.
    let queries: Vec<Vec<f64>> = (0..num_heads)
        .map(|h| random_unit_vectors(head_dim, 1, 9000 + h as u64).remove(0))
        .collect();
    let query_start = Instant::now();
    let outputs = cache.attention_output_all(&queries, 0.0).unwrap();
    let query_elapsed = query_start.elapsed();
    let concat = cache.concat_outputs(&outputs).unwrap();
    let stats = cache.stats();
    MultiHeadRow {
        num_heads,
        head_dim,
        seq_len,
        append_time_ms: append_elapsed.as_secs_f64() * 1000.0,
        attention_output_time_ms: query_elapsed.as_secs_f64() * 1000.0,
        output_dim: concat.len(),
        compressed_bytes: stats.total_bytes,
        uncompressed_bytes: stats.uncompressed_bytes,
        compression_ratio: stats.compression_ratio,
    }
}