use std::collections::HashSet;
use std::path::PathBuf;
use std::time::Instant;
use clap::{Parser, ValueEnum};
use serde::Serialize;
use turboquant::backend::{sum_squared_error, weighted_sum_in_place, ExecutionBackend};
use turboquant::batch::{
batch_attention_scores_mse_with_backend, batch_dequantize_mse_with_backend,
batch_estimate_inner_products_with_backend, batch_quantize_mse_with_backend,
batch_quantize_prod_with_backend, BatchQuantizedMSE, BatchQuantizedProd,
};
#[cfg(feature = "gpu")]
use turboquant::gpu::WgpuMseBatchRunner;
use turboquant::polar::PolarQuant;
#[cfg(feature = "gpu")]
use turboquant::qjl::QJLQuantized;
use turboquant::real_model::{
RealModelGenerationConfig, RealModelQuantizationConfig, RealModelRunner, RealModelTrace,
SupportedRealModel,
};
use turboquant::trace::KvTrace;
use turboquant::turboquant_mse::TurboQuantMSE;
use turboquant::turboquant_prod::{ProdQuantized, TurboQuantProd};
use turboquant::utils::{inner_product, normalize};
use turboquant::TurboQuantError;
// Command-line arguments for the benchmark binary.
//
// NOTE: intentionally plain `//` comments here — `///` doc comments on
// clap-derive items become CLI help text and would change program behavior.
#[derive(Parser, Debug)]
#[command(
    about = "Benchmark TurboQuant on synthetic workloads, exported traces, and real ONNX Runtime decoder runs"
)]
struct Args {
    // Which benchmark workload to run.
    #[arg(long, value_enum, default_value_t = WorkloadArg::Synthetic)]
    workload: WorkloadArg,
    // Use the smaller synthetic presets for fast smoke runs.
    #[arg(long)]
    quick: bool,
    // Emit the report as pretty-printed JSON instead of the text table.
    #[arg(long)]
    json: bool,
    // Execution backend for quantization/scoring kernels.
    #[arg(long, value_enum, default_value_t = BackendArg::Simd)]
    backend: BackendArg,
    // Bit width for key quantization (values too, unless --value-bits is set).
    #[arg(long, default_value_t = 4)]
    bits: u8,
    // Optional separate bit width for value quantization (real-model mode).
    #[arg(long)]
    value_bits: Option<u8>,
    // RNG seed for synthetic data and quantizer initialization.
    #[arg(long, default_value_t = 42)]
    seed: u64,
    // Path to an exported KV trace (required for --workload trace).
    #[arg(long)]
    trace: Option<PathBuf>,
    // Cap on queries evaluated per sample; presumably 0 means "no cap" —
    // TODO confirm against max_queries_per_sample (defined elsewhere in file).
    #[arg(long, default_value_t = 0)]
    max_queries: usize,
    // Directory holding the ONNX model (required for --workload real-model).
    #[arg(long)]
    real_model_dir: Option<PathBuf>,
    // Single inline prompt (mutually exclusive with --prompts).
    #[arg(long)]
    prompt: Option<String>,
    // Prompt file: JSONL records or one prompt per line.
    #[arg(long)]
    prompts: Option<PathBuf>,
    // Maximum number of prompts read from --prompts.
    #[arg(long, default_value_t = 4)]
    max_prompts: usize,
    // Decode length per prompt in real-model mode.
    #[arg(long, default_value_t = 16)]
    max_new_tokens: usize,
    // k for top-k logit agreement in compare mode (clamped to at least 1).
    #[arg(long, default_value_t = 5)]
    top_k: usize,
    // Real-model evaluation mode: exact, quantized, or side-by-side compare.
    #[arg(long, value_enum, default_value_t = RealEvalModeArg::Compare)]
    real_eval_mode: RealEvalModeArg,
    // Key quantization strategy for real-model runs.
    #[arg(long, value_enum, default_value_t = KeyStrategyArg::Prod)]
    real_key_strategy: KeyStrategyArg,
    // List supported real-model presets and exit.
    #[arg(long)]
    supported_real_models: bool,
}
// Benchmark workload selector (`//` comments on purpose: doc comments would
// leak into clap help output).
#[derive(Debug, Clone, Copy, ValueEnum)]
enum WorkloadArg {
    // Randomly generated keys/values/queries shaped like real model heads.
    Synthetic,
    // Replay of an exported KV trace file.
    Trace,
    // End-to-end decoding with a real ONNX Runtime model.
    RealModel,
}
// Execution backend selector; the Wgpu variant exists only when the crate is
// built with the `gpu` feature.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum BackendArg {
    Scalar,
    Simd,
    #[cfg(feature = "gpu")]
    Wgpu,
}
impl BackendArg {
    /// Maps the CLI choice onto the library's `ExecutionBackend` enum.
    fn into_execution_backend(self) -> ExecutionBackend {
        match self {
            Self::Scalar => ExecutionBackend::Scalar,
            Self::Simd => ExecutionBackend::Simd,
            // Only compiled in (and selectable) with the `gpu` feature.
            #[cfg(feature = "gpu")]
            Self::Wgpu => ExecutionBackend::Wgpu,
        }
    }
}
// Real-model evaluation mode selector.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum RealEvalModeArg {
    // Run only the exact (unquantized) KV cache.
    Exact,
    // Run only the quantized KV cache.
    Quantized,
    // Run both and compare token streams and logits.
    Compare,
}
impl RealEvalModeArg {
fn as_str(self) -> &'static str {
match self {
Self::Exact => "exact",
Self::Quantized => "quantized",
Self::Compare => "compare",
}
}
}
// Key-quantization strategy selector for real-model runs.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum KeyStrategyArg {
    Mse,
    Prod,
}
impl KeyStrategyArg {
fn into_quant_strategy(self) -> turboquant::QuantStrategy {
match self {
Self::Mse => turboquant::QuantStrategy::MSE,
Self::Prod => turboquant::QuantStrategy::Prod,
}
}
fn algorithm_name(self) -> &'static str {
match self {
Self::Mse => "TurboQuantMSE",
Self::Prod => "TurboQuantProd",
}
}
}
/// One attention "sample": per-token key/value vectors plus the queries that
/// probe them and, for each query, the token position it attends up to.
#[derive(Debug, Clone)]
struct EvalSample {
    keys: Vec<Vec<f64>>,
    values: Vec<Vec<f64>>,
    queries: Vec<Vec<f64>>,
    // For query i, attention covers keys[..=query_positions[i]].
    query_positions: Vec<usize>,
}
/// A named collection of samples sharing one head dimension.
#[derive(Debug, Clone)]
struct EvalWorkload {
    name: String,
    // Originating model id when known (trace metadata); None for synthetic.
    model: Option<String>,
    // "synthetic" or "trace".
    source: String,
    dim: usize,
    samples: Vec<EvalSample>,
}
/// A single prompt loaded from `--prompt` or a `--prompts` file.
#[derive(Debug, Clone)]
struct PromptRecord {
    prompt: String,
}
/// One flattened result row of the benchmark report. Most fields are
/// `Option` because the different modes (synthetic/trace approximation vs.
/// real-model exact/quantized/compare) populate different subsets; `None`
/// means "not applicable for this mode".
#[derive(Debug, Serialize)]
struct BenchmarkRow {
    // Identification of what was benchmarked.
    workload: String,
    source: String,
    model: Option<String>,
    eval_mode: String,
    algorithm: Option<String>,
    backend: String,
    // Quantization configuration.
    bits: Option<u8>,
    value_bits: Option<u8>,
    // Workload size.
    dim: Option<usize>,
    samples: usize,
    tokens: usize,
    queries: usize,
    generated_tokens: Option<usize>,
    // Approximation-quality metrics (synthetic/trace modes).
    key_mse: Option<f64>,
    logit_rmse: Option<f64>,
    output_rmse: Option<f64>,
    recall_at_1: Option<f64>,
    recall_at_5: Option<f64>,
    recall_at_10: Option<f64>,
    // Real-model comparison metrics (exact vs quantized decoding).
    top_k_agreement: Option<f64>,
    token_match_rate: Option<f64>,
    divergence_rate: Option<f64>,
    mean_first_divergence_step: Option<f64>,
    cross_entropy_exact: Option<f64>,
    cross_entropy_quantized: Option<f64>,
    perplexity_exact: Option<f64>,
    perplexity_quantized: Option<f64>,
    // Latency / throughput of real-model runs.
    exact_latency_seconds: Option<f64>,
    quantized_latency_seconds: Option<f64>,
    exact_tokens_per_second: Option<f64>,
    quantized_tokens_per_second: Option<f64>,
    // Memory footprint and compression.
    kv_memory_exact_bytes: Option<usize>,
    kv_memory_quantized_bytes: Option<usize>,
    compressed_bytes: Option<usize>,
    uncompressed_bytes: Option<usize>,
    compression_ratio: Option<f64>,
    // Synthetic-benchmark throughput and overall score.
    quantize_tokens_per_second: Option<f64>,
    query_tokens_per_second: Option<f64>,
    aggregate_score: Option<f64>,
}
/// Top-level report: the run mode plus one row per (workload, algorithm)
/// pair or per real-model evaluation.
#[derive(Debug, Serialize)]
struct BenchmarkReport {
    mode: String,
    rows: Vec<BenchmarkRow>,
}
/// Backend label reported for real-model rows (inference runs through ONNX
/// Runtime on CPU rather than the library's own kernels).
const REAL_MODEL_BACKEND: &str = "onnxruntime-cpu";
/// Quantization algorithms exercised by the synthetic/trace workloads.
#[derive(Debug, Clone, Copy)]
enum Algorithm {
    Mse,
    Prod,
    Polar,
}
impl Algorithm {
fn name(self) -> &'static str {
match self {
Self::Mse => "TurboQuantMSE",
Self::Prod => "TurboQuantProd",
Self::Polar => "PolarQuant",
}
}
}
/// Running error/recall statistics across all evaluated queries of a workload.
#[derive(Default)]
struct MetricsAccumulator {
    query_count: usize,
    // Summed squared error and element counts for attention logits...
    logit_sq_error_sum: f64,
    logit_count: usize,
    // ...and for attention outputs.
    output_sq_error_sum: f64,
    output_count: usize,
    // Per-query recall@k values, summed; divided by query_count on read.
    recall_at_1_sum: f64,
    recall_at_5_sum: f64,
    recall_at_10_sum: f64,
}
impl MetricsAccumulator {
    /// Folds one query's exact-vs-approximate scores and outputs into the
    /// running statistics.
    fn observe(
        &mut self,
        exact_scores: &[f64],
        approx_scores: &[f64],
        exact_output: &[f64],
        approx_output: &[f64],
    ) {
        self.query_count += 1;
        self.logit_sq_error_sum += Self::squared_error_sum(exact_scores, approx_scores);
        self.logit_count += exact_scores.len();
        self.output_sq_error_sum += Self::squared_error_sum(exact_output, approx_output);
        self.output_count += exact_output.len();
        self.recall_at_1_sum += recall_at_k(exact_scores, approx_scores, 1);
        self.recall_at_5_sum += recall_at_k(exact_scores, approx_scores, 5);
        self.recall_at_10_sum += recall_at_k(exact_scores, approx_scores, 10);
    }

    /// Sum of element-wise squared differences over the zipped common prefix.
    fn squared_error_sum(reference: &[f64], candidate: &[f64]) -> f64 {
        reference
            .iter()
            .zip(candidate)
            .map(|(left, right)| (left - right) * (left - right))
            .sum()
    }

    /// Shared RMSE computation; defined as 0.0 when nothing was observed.
    fn root_mean(sq_error_sum: f64, count: usize) -> f64 {
        if count == 0 {
            0.0
        } else {
            (sq_error_sum / count as f64).sqrt()
        }
    }

    fn logit_rmse(&self) -> f64 {
        Self::root_mean(self.logit_sq_error_sum, self.logit_count)
    }

    fn output_rmse(&self) -> f64 {
        Self::root_mean(self.output_sq_error_sum, self.output_count)
    }

    fn recall_at_1(&self) -> f64 {
        average(self.recall_at_1_sum, self.query_count)
    }

    fn recall_at_5(&self) -> f64 {
        average(self.recall_at_5_sum, self.query_count)
    }

    fn recall_at_10(&self) -> f64 {
        average(self.recall_at_10_sum, self.query_count)
    }
}
/// Totals for a batch of single-mode (exact OR quantized) generation runs.
#[derive(Default)]
struct SingleRunAccumulator {
    samples: usize,
    prompt_tokens: usize,
    generated_tokens: usize,
    latency_seconds: f64,
    // KV-cache sizes as reported by the runner (quantized stays 0 for exact runs).
    kv_exact_bytes: usize,
    kv_quantized_bytes: usize,
    // NLL of each generated token under its own step logits.
    self_cross_entropy_sum: f64,
    self_cross_entropy_count: usize,
}
impl SingleRunAccumulator {
    /// Folds one generation trace into the running totals.
    fn observe(&mut self, trace: &RealModelTrace) {
        self.samples += 1;
        self.prompt_tokens += trace.prompt_tokens.len();
        self.generated_tokens += trace.generated_tokens.len();
        // End-to-end latency: prompt prefill plus autoregressive decode.
        self.latency_seconds += trace.prefill_seconds + trace.decode_seconds;
        self.kv_exact_bytes += trace.kv_cache.exact_bytes;
        // Exact runs report no quantized cache size; count it as zero.
        self.kv_quantized_bytes += trace.kv_cache.quantized_bytes.unwrap_or(0);
        // "Self" cross-entropy: NLL of each generated token under the logits
        // that produced it (zip truncates if the two lengths differ).
        for (logits, token_id) in trace.step_logits.iter().zip(trace.generated_tokens.iter()) {
            self.self_cross_entropy_sum += negative_log_probability(logits, *token_id as usize);
            self.self_cross_entropy_count += 1;
        }
    }

    /// Mean NLL per generated token, or `None` when no steps were observed.
    fn self_cross_entropy(&self) -> Option<f64> {
        if self.self_cross_entropy_count == 0 {
            None
        } else {
            Some(self.self_cross_entropy_sum / self.self_cross_entropy_count as f64)
        }
    }
}
/// Statistics comparing an exact decode against a quantized decode of the
/// same prompts.
#[derive(Default)]
struct ComparisonAccumulator {
    // Steps where both runs had logits and the exact run emitted a token.
    compared_steps: usize,
    logit_sq_error_sum: f64,
    logit_count: usize,
    top_k_agreement_sum: f64,
    // Token-stream agreement, measured over the longer of the two streams.
    token_matches: usize,
    token_positions: usize,
    // First-divergence bookkeeping (step numbers are 1-based).
    divergence_count: usize,
    divergence_step_sum: f64,
    // Cross-entropy of both runs scored against the exact run's tokens.
    reference_cross_entropy_exact_sum: f64,
    reference_cross_entropy_quantized_sum: f64,
    reference_cross_entropy_count: usize,
}
impl ComparisonAccumulator {
    /// Folds one exact/quantized trace pair into the comparison statistics.
    /// `top_k` is the k used for per-step top-k logit agreement.
    fn observe(&mut self, exact: &RealModelTrace, quantized: &RealModelTrace, top_k: usize) {
        // Token-level comparison runs over the LONGER of the two streams, so
        // a shorter stream counts its missing tail as mismatches.
        let max_tokens = exact
            .generated_tokens
            .len()
            .max(quantized.generated_tokens.len());
        self.token_positions += max_tokens;
        let mut divergence = None;
        for step in 0..max_tokens {
            let exact_token = exact.generated_tokens.get(step).copied();
            let quant_token = quantized.generated_tokens.get(step).copied();
            if exact_token.is_some() && exact_token == quant_token {
                self.token_matches += 1;
            } else if divergence.is_none() {
                // Record only the FIRST mismatching step (1-based).
                divergence = Some(step + 1);
            }
        }
        if let Some(step) = divergence {
            self.divergence_count += 1;
            self.divergence_step_sum += step as f64;
        }
        // Logit-level comparison only where both runs produced logits and the
        // exact run actually emitted a token at that step.
        let shared_steps = exact
            .step_logits
            .len()
            .min(quantized.step_logits.len())
            .min(exact.generated_tokens.len());
        for step in 0..shared_steps {
            let exact_logits = &exact.step_logits[step];
            let quantized_logits = &quantized.step_logits[step];
            self.logit_sq_error_sum += exact_logits
                .iter()
                .zip(quantized_logits.iter())
                .map(|(left, right)| {
                    let delta = *left as f64 - *right as f64;
                    delta * delta
                })
                .sum::<f64>();
            self.logit_count += exact_logits.len();
            self.top_k_agreement_sum += top_k_agreement(exact_logits, quantized_logits, top_k);
            // Both cross-entropies are scored against the EXACT run's token
            // stream so the two numbers are directly comparable.
            self.reference_cross_entropy_exact_sum +=
                negative_log_probability(exact_logits, exact.generated_tokens[step] as usize);
            self.reference_cross_entropy_quantized_sum +=
                negative_log_probability(quantized_logits, exact.generated_tokens[step] as usize);
            self.reference_cross_entropy_count += 1;
            self.compared_steps += 1;
        }
    }

    /// RMSE over all compared logit elements; `None` if nothing was compared.
    fn logit_rmse(&self) -> Option<f64> {
        if self.logit_count == 0 {
            None
        } else {
            Some((self.logit_sq_error_sum / self.logit_count as f64).sqrt())
        }
    }

    /// Mean per-step top-k agreement; `None` if no steps were compared.
    fn top_k_agreement(&self) -> Option<f64> {
        if self.compared_steps == 0 {
            None
        } else {
            Some(self.top_k_agreement_sum / self.compared_steps as f64)
        }
    }

    /// Fraction of token positions where both runs emitted the same token.
    fn token_match_rate(&self) -> Option<f64> {
        if self.token_positions == 0 {
            None
        } else {
            Some(self.token_matches as f64 / self.token_positions as f64)
        }
    }

    /// Complement of the match rate.
    fn divergence_rate(&self) -> Option<f64> {
        self.token_match_rate().map(|rate| 1.0 - rate)
    }

    /// Mean (1-based) step of first divergence over prompts that diverged.
    fn mean_first_divergence_step(&self) -> Option<f64> {
        if self.divergence_count == 0 {
            None
        } else {
            Some(self.divergence_step_sum / self.divergence_count as f64)
        }
    }

    /// Mean exact-run cross-entropy against the exact token stream.
    fn reference_cross_entropy_exact(&self) -> Option<f64> {
        if self.reference_cross_entropy_count == 0 {
            None
        } else {
            Some(self.reference_cross_entropy_exact_sum / self.reference_cross_entropy_count as f64)
        }
    }

    /// Mean quantized-run cross-entropy against the exact token stream.
    fn reference_cross_entropy_quantized(&self) -> Option<f64> {
        if self.reference_cross_entropy_count == 0 {
            None
        } else {
            Some(
                self.reference_cross_entropy_quantized_sum
                    / self.reference_cross_entropy_count as f64,
            )
        }
    }
}
/// CLI entry point: parse arguments, run the selected workload, and print
/// the resulting report as pretty JSON or a human-readable table.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();
    // Informational flag: list the model presets and exit without benchmarking.
    if args.supported_real_models {
        print_supported_real_models();
        return Ok(());
    }
    let report = match args.workload {
        WorkloadArg::Synthetic => synthetic_report(&args),
        WorkloadArg::Trace => trace_report(&args),
        WorkloadArg::RealModel => real_model_report(&args),
    }?;
    if args.json {
        let rendered = serde_json::to_string_pretty(&report)?;
        println!("{rendered}");
    } else {
        print_report(&report);
    }
    Ok(())
}
fn synthetic_report(args: &Args) -> turboquant::Result<BenchmarkReport> {
let backend = args.backend.into_execution_backend();
let workloads = synthetic_workloads(args.quick, args.seed);
let mut rows = Vec::new();
for workload in &workloads {
for algorithm in [Algorithm::Mse, Algorithm::Prod, Algorithm::Polar] {
rows.push(evaluate_workload(
workload,
algorithm,
backend,
args.bits,
args.seed,
args.max_queries,
)?);
}
}
Ok(BenchmarkReport {
mode: if args.quick { "quick" } else { "full" }.to_string(),
rows,
})
}
/// Loads the KV trace named by `--trace` and evaluates every algorithm
/// against it, stopping at the first error.
fn trace_report(args: &Args) -> turboquant::Result<BenchmarkReport> {
    let path = args
        .trace
        .as_ref()
        .ok_or_else(|| TurboQuantError::TraceFormat("trace mode requires --trace <path>".into()))?;
    let backend = args.backend.into_execution_backend();
    let workload = load_trace_workload(path)?;
    let rows = [Algorithm::Mse, Algorithm::Prod, Algorithm::Polar]
        .into_iter()
        .map(|algorithm| {
            evaluate_workload(
                &workload,
                algorithm,
                backend,
                args.bits,
                args.seed,
                args.max_queries,
            )
        })
        .collect::<turboquant::Result<Vec<_>>>()?;
    let mode = if args.quick { "quick" } else { "full" };
    Ok(BenchmarkReport {
        mode: mode.to_string(),
        rows,
    })
}
/// Runs the real-model workload: loads the ONNX runner, reads prompts, and
/// evaluates exact, quantized, or both (compare) decoding passes, producing
/// exactly one report row.
fn real_model_report(args: &Args) -> turboquant::Result<BenchmarkReport> {
    let model_dir = args.real_model_dir.as_ref().ok_or_else(|| {
        TurboQuantError::ModelConfig("real-model mode requires --real-model-dir <dir>".into())
    })?;
    // --prompt and --prompts are mutually exclusive.
    if args.prompt.is_some() && args.prompts.is_some() {
        return Err(TurboQuantError::ModelConfig(
            "pass either --prompt or --prompts, not both".into(),
        ));
    }
    let prompts = load_prompt_records(args)?;
    let runner = RealModelRunner::load(model_dir)?;
    let generation = RealModelGenerationConfig {
        max_new_tokens: args.max_new_tokens,
        stop_on_eos: true,
    };
    // Values default to the key bit width unless --value-bits overrides it.
    let value_bits = args.value_bits.unwrap_or(args.bits);
    let quantization = RealModelQuantizationConfig {
        key_bits: args.bits,
        value_bits,
        key_strategy: args.real_key_strategy.into_quant_strategy(),
        seed: args.seed,
    };
    let model_id = runner.model_id().to_string();
    let mut rows = Vec::new();
    match args.real_eval_mode {
        // Exact-only pass: no quantization involved.
        RealEvalModeArg::Exact => {
            let mut exact = SingleRunAccumulator::default();
            for prompt in &prompts {
                let trace = runner.generate_exact(&prompt.prompt, &generation)?;
                exact.observe(&trace);
            }
            rows.push(build_real_model_exact_row(
                &model_id,
                args.real_eval_mode,
                &exact,
            ));
        }
        // Quantized-only pass.
        RealEvalModeArg::Quantized => {
            let mut quantized = SingleRunAccumulator::default();
            for prompt in &prompts {
                let trace =
                    runner.generate_quantized(&prompt.prompt, &generation, &quantization)?;
                quantized.observe(&trace);
            }
            rows.push(build_real_model_quantized_row(
                &model_id,
                args.real_eval_mode,
                args.real_key_strategy,
                args.bits,
                value_bits,
                &quantized,
            ));
        }
        // Compare: run both passes per prompt and accumulate paired stats.
        RealEvalModeArg::Compare => {
            let mut exact = SingleRunAccumulator::default();
            let mut quantized = SingleRunAccumulator::default();
            let mut comparison = ComparisonAccumulator::default();
            for prompt in &prompts {
                let exact_trace = runner.generate_exact(&prompt.prompt, &generation)?;
                let quantized_trace =
                    runner.generate_quantized(&prompt.prompt, &generation, &quantization)?;
                // top_k is clamped to at least 1 so agreement is well-defined.
                comparison.observe(&exact_trace, &quantized_trace, args.top_k.max(1));
                exact.observe(&exact_trace);
                quantized.observe(&quantized_trace);
            }
            rows.push(build_real_model_compare_row(
                &model_id,
                args.real_key_strategy,
                args.bits,
                value_bits,
                &exact,
                &quantized,
                &comparison,
            ));
        }
    }
    Ok(BenchmarkReport {
        mode: args.real_eval_mode.as_str().to_string(),
        rows,
    })
}
/// Builds the report row for an exact-only real-model run; all quantization
/// and comparison columns are `None`.
fn build_real_model_exact_row(
    model_id: &str,
    mode: RealEvalModeArg,
    exact: &SingleRunAccumulator,
) -> BenchmarkRow {
    let cross_entropy = exact.self_cross_entropy();
    BenchmarkRow {
        workload: format!("real-model:{model_id}"),
        source: "real-model".to_string(),
        model: Some(model_id.to_string()),
        eval_mode: mode.as_str().to_string(),
        algorithm: Some("ExactCache".to_string()),
        backend: REAL_MODEL_BACKEND.to_string(),
        bits: None,
        value_bits: None,
        dim: None,
        samples: exact.samples,
        tokens: exact.prompt_tokens,
        queries: 0,
        generated_tokens: Some(exact.generated_tokens),
        key_mse: None,
        logit_rmse: None,
        output_rmse: None,
        recall_at_1: None,
        recall_at_5: None,
        recall_at_10: None,
        top_k_agreement: None,
        token_match_rate: None,
        divergence_rate: None,
        mean_first_divergence_step: None,
        cross_entropy_exact: cross_entropy,
        cross_entropy_quantized: None,
        // Perplexity is exp(mean NLL).
        perplexity_exact: cross_entropy.map(f64::exp),
        perplexity_quantized: None,
        exact_latency_seconds: Some(exact.latency_seconds),
        quantized_latency_seconds: None,
        exact_tokens_per_second: Some(rate(exact.generated_tokens, exact.latency_seconds)),
        quantized_tokens_per_second: None,
        kv_memory_exact_bytes: Some(exact.kv_exact_bytes),
        kv_memory_quantized_bytes: None,
        compressed_bytes: None,
        // With no quantization, the uncompressed size is the exact cache size.
        uncompressed_bytes: Some(exact.kv_exact_bytes),
        compression_ratio: None,
        quantize_tokens_per_second: None,
        query_tokens_per_second: None,
        aggregate_score: None,
    }
}
/// Builds the report row for a quantized-only real-model run; comparison
/// columns stay `None` because no exact run was performed.
fn build_real_model_quantized_row(
    model_id: &str,
    mode: RealEvalModeArg,
    strategy: KeyStrategyArg,
    bits: u8,
    value_bits: u8,
    quantized: &SingleRunAccumulator,
) -> BenchmarkRow {
    let cross_entropy = quantized.self_cross_entropy();
    // Guard against divide-by-zero when the runner reported no quantized bytes.
    let compression_ratio = if quantized.kv_quantized_bytes == 0 {
        None
    } else {
        Some(quantized.kv_exact_bytes as f64 / quantized.kv_quantized_bytes as f64)
    };
    BenchmarkRow {
        workload: format!("real-model:{model_id}"),
        source: "real-model".to_string(),
        model: Some(model_id.to_string()),
        eval_mode: mode.as_str().to_string(),
        algorithm: Some(strategy.algorithm_name().to_string()),
        backend: REAL_MODEL_BACKEND.to_string(),
        bits: Some(bits),
        value_bits: Some(value_bits),
        dim: None,
        samples: quantized.samples,
        tokens: quantized.prompt_tokens,
        queries: 0,
        generated_tokens: Some(quantized.generated_tokens),
        key_mse: None,
        logit_rmse: None,
        output_rmse: None,
        recall_at_1: None,
        recall_at_5: None,
        recall_at_10: None,
        top_k_agreement: None,
        token_match_rate: None,
        divergence_rate: None,
        mean_first_divergence_step: None,
        cross_entropy_exact: None,
        cross_entropy_quantized: cross_entropy,
        perplexity_exact: None,
        // Perplexity is exp(mean NLL).
        perplexity_quantized: cross_entropy.map(f64::exp),
        exact_latency_seconds: None,
        quantized_latency_seconds: Some(quantized.latency_seconds),
        exact_tokens_per_second: None,
        quantized_tokens_per_second: Some(rate(
            quantized.generated_tokens,
            quantized.latency_seconds,
        )),
        kv_memory_exact_bytes: Some(quantized.kv_exact_bytes),
        kv_memory_quantized_bytes: Some(quantized.kv_quantized_bytes),
        compressed_bytes: Some(quantized.kv_quantized_bytes),
        uncompressed_bytes: Some(quantized.kv_exact_bytes),
        compression_ratio,
        quantize_tokens_per_second: None,
        query_tokens_per_second: None,
        aggregate_score: None,
    }
}
/// Builds the report row for a compare run: latency/memory from both passes
/// plus the paired divergence and cross-entropy statistics.
fn build_real_model_compare_row(
    model_id: &str,
    strategy: KeyStrategyArg,
    bits: u8,
    value_bits: u8,
    exact: &SingleRunAccumulator,
    quantized: &SingleRunAccumulator,
    comparison: &ComparisonAccumulator,
) -> BenchmarkRow {
    // Guard against divide-by-zero when the runner reported no quantized bytes.
    let compression_ratio = if quantized.kv_quantized_bytes == 0 {
        None
    } else {
        Some(quantized.kv_exact_bytes as f64 / quantized.kv_quantized_bytes as f64)
    };
    BenchmarkRow {
        workload: format!("real-model:{model_id}"),
        source: "real-model".to_string(),
        model: Some(model_id.to_string()),
        eval_mode: "compare".to_string(),
        algorithm: Some(strategy.algorithm_name().to_string()),
        backend: REAL_MODEL_BACKEND.to_string(),
        bits: Some(bits),
        value_bits: Some(value_bits),
        dim: None,
        // Both passes saw the same prompts; report the exact pass's counts.
        samples: exact.samples,
        tokens: exact.prompt_tokens,
        queries: 0,
        generated_tokens: Some(exact.generated_tokens.max(quantized.generated_tokens)),
        key_mse: None,
        logit_rmse: comparison.logit_rmse(),
        output_rmse: None,
        recall_at_1: None,
        recall_at_5: None,
        recall_at_10: None,
        top_k_agreement: comparison.top_k_agreement(),
        token_match_rate: comparison.token_match_rate(),
        divergence_rate: comparison.divergence_rate(),
        mean_first_divergence_step: comparison.mean_first_divergence_step(),
        cross_entropy_exact: comparison.reference_cross_entropy_exact(),
        cross_entropy_quantized: comparison.reference_cross_entropy_quantized(),
        // Perplexity is exp(mean NLL), both scored against the exact tokens.
        perplexity_exact: comparison.reference_cross_entropy_exact().map(f64::exp),
        perplexity_quantized: comparison.reference_cross_entropy_quantized().map(f64::exp),
        exact_latency_seconds: Some(exact.latency_seconds),
        quantized_latency_seconds: Some(quantized.latency_seconds),
        exact_tokens_per_second: Some(rate(exact.generated_tokens, exact.latency_seconds)),
        quantized_tokens_per_second: Some(rate(
            quantized.generated_tokens,
            quantized.latency_seconds,
        )),
        kv_memory_exact_bytes: Some(exact.kv_exact_bytes),
        kv_memory_quantized_bytes: Some(quantized.kv_quantized_bytes),
        compressed_bytes: Some(quantized.kv_quantized_bytes),
        uncompressed_bytes: Some(exact.kv_exact_bytes),
        compression_ratio,
        quantize_tokens_per_second: None,
        query_tokens_per_second: None,
        aggregate_score: None,
    }
}
/// Prints the preset table for `--supported-real-models`.
fn print_supported_real_models() {
    println!("Supported real-model presets:");
    for model in SupportedRealModel::all() {
        let preset = model.preset_name();
        let id = model.model_id();
        let description = model.description();
        println!(" {:<28} {:<44} {}", preset, id, description);
    }
}
/// Resolves the prompt set: a single `--prompt` string, or up to
/// `max_prompts` records read from `--prompts` (JSONL or plain text with one
/// prompt per line). Errors if no non-empty prompt is found.
fn load_prompt_records(args: &Args) -> turboquant::Result<Vec<PromptRecord>> {
    // Inline prompt takes precedence (mutual exclusion with --prompts is
    // enforced by the caller).
    if let Some(prompt) = &args.prompt {
        let trimmed = prompt.trim();
        if trimmed.is_empty() {
            return Err(TurboQuantError::ModelConfig(
                "--prompt must not be empty".into(),
            ));
        }
        return Ok(vec![PromptRecord {
            prompt: trimmed.to_string(),
        }]);
    }
    let path = args.prompts.as_ref().ok_or_else(|| {
        TurboQuantError::ModelConfig(
            "real-model mode requires either --prompt or --prompts <file>".into(),
        )
    })?;
    let text = std::fs::read_to_string(path)
        .map_err(|error| TurboQuantError::Io(format!("{}: {error}", path.display())))?;
    let mut prompts = Vec::new();
    // NOTE(review): extension match is case-sensitive — only a literal
    // "jsonl" extension triggers JSONL parsing; everything else is plain text.
    if path.extension().is_some_and(|ext| ext == "jsonl") {
        for line in text.lines() {
            // Blank lines are permitted and skipped.
            if line.trim().is_empty() {
                continue;
            }
            let value: serde_json::Value = serde_json::from_str(line).map_err(|error| {
                TurboQuantError::ModelConfig(format!(
                    "{} contains invalid JSONL: {error}",
                    path.display()
                ))
            })?;
            // Accept the first present of several conventional field names.
            let prompt = value
                .get("prompt")
                .or_else(|| value.get("text"))
                .or_else(|| value.get("input"))
                .or_else(|| value.get("question"))
                .and_then(serde_json::Value::as_str)
                .map(str::trim)
                .filter(|prompt| !prompt.is_empty())
                .ok_or_else(|| {
                    TurboQuantError::ModelConfig(format!(
                        "{} JSONL records must contain a non-empty prompt/text/input/question field",
                        path.display()
                    ))
                })?;
            prompts.push(PromptRecord {
                prompt: prompt.to_string(),
            });
            // Stop once the configured prompt budget is reached.
            if prompts.len() >= args.max_prompts {
                break;
            }
        }
    } else {
        // Plain-text mode: one prompt per non-empty line.
        for line in text.lines() {
            let prompt = line.trim();
            if prompt.is_empty() {
                continue;
            }
            prompts.push(PromptRecord {
                prompt: prompt.to_string(),
            });
            if prompts.len() >= args.max_prompts {
                break;
            }
        }
    }
    if prompts.is_empty() {
        return Err(TurboQuantError::ModelConfig(format!(
            "no prompts found in {}",
            path.display()
        )));
    }
    Ok(prompts)
}
/// Builds the synthetic workload presets. Each preset tuple is
/// (name, head dim, tokens per sample, queries per sample, sample count);
/// `quick` selects a much smaller configuration for smoke runs.
fn synthetic_workloads(quick: bool, seed: u64) -> Vec<EvalWorkload> {
    let presets = if quick {
        vec![
            (
                "synthetic-gemma-2b-head-shape",
                256usize,
                128usize,
                12usize,
                1usize,
            ),
            (
                "synthetic-mistral-7b-head-shape",
                128usize,
                192usize,
                16usize,
                1usize,
            ),
        ]
    } else {
        vec![
            (
                "synthetic-gemma-2b-head-shape",
                256usize,
                1024usize,
                96usize,
                3usize,
            ),
            (
                "synthetic-mistral-7b-head-shape",
                128usize,
                2048usize,
                128usize,
                4usize,
            ),
            (
                "synthetic-llama-3.1-8b-head-shape",
                128usize,
                2048usize,
                128usize,
                4usize,
            ),
        ]
    };
    presets
        .into_iter()
        .enumerate()
        .map(
            |(preset_index, (name, dim, tokens, queries, samples))| EvalWorkload {
                name: name.to_string(),
                model: None,
                source: "synthetic".to_string(),
                dim,
                samples: (0..samples)
                    .map(|sample_index| {
                        // Distinct deterministic seed per (preset, sample).
                        synthetic_sample(
                            dim,
                            tokens,
                            queries,
                            seed + preset_index as u64 * 10_000 + sample_index as u64,
                        )
                    })
                    .collect(),
            },
        )
        .collect()
}
/// Builds one synthetic sample: unit-norm keys, values, and queries derived
/// deterministically from `seed` (values and queries use offset seeds), plus
/// evenly spaced query positions.
fn synthetic_sample(dim: usize, token_count: usize, query_count: usize, seed: u64) -> EvalSample {
    EvalSample {
        keys: random_unit_vectors(dim, token_count, seed),
        values: random_unit_vectors(dim, token_count, seed + 1),
        queries: random_unit_vectors(dim, query_count, seed + 2),
        query_positions: evenly_spaced_positions(token_count, query_count),
    }
}
/// Generates `count` vectors of dimension `dim` with i.i.d. standard-normal
/// entries, each passed through `normalize`. Deterministic per `seed`.
fn random_unit_vectors(dim: usize, count: usize, seed: u64) -> Vec<Vec<f64>> {
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    // Normal::new(0.0, 1.0) has valid parameters, so this unwrap cannot fail.
    let normal = Normal::new(0.0, 1.0).unwrap();
    (0..count)
        .map(|_| {
            let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
            // NOTE(review): unwrap assumes normalize only fails for
            // degenerate (e.g. zero-norm) input — confirm against utils.
            normalize(&raw).unwrap()
        })
        .collect()
}
/// Picks `query_count` positions spread evenly over `0..token_count`, always
/// ending at the last token. Rounding may repeat positions when
/// `query_count` exceeds `token_count`; empty if either count is zero.
fn evenly_spaced_positions(token_count: usize, query_count: usize) -> Vec<usize> {
    match (token_count, query_count) {
        (0, _) | (_, 0) => Vec::new(),
        // A single query always probes the final token.
        (_, 1) => vec![token_count - 1],
        _ => {
            let step = (token_count - 1) as f64 / (query_count - 1) as f64;
            (0..query_count)
                .map(|index| (index as f64 * step).round() as usize)
                .collect()
        }
    }
}
/// Loads an exported KV trace and re-normalizes every key/value/query row so
/// the trace is evaluated under the same unit-norm convention as the
/// synthetic workloads.
fn load_trace_workload(path: &PathBuf) -> turboquant::Result<EvalWorkload> {
    let trace = KvTrace::load(path)?;
    // Workload name: trace metadata's benchmark name, else the file stem,
    // else a literal "trace" fallback.
    let base_name = trace
        .metadata
        .benchmark
        .clone()
        .or_else(|| {
            path.file_stem()
                .map(|stem| stem.to_string_lossy().to_string())
        })
        .unwrap_or_else(|| "trace".to_string());
    let samples = trace
        .samples
        .into_iter()
        .map(|sample| -> turboquant::Result<EvalSample> {
            Ok(EvalSample {
                keys: normalize_rows(sample.keys)?,
                values: normalize_rows(sample.values)?,
                queries: normalize_rows(sample.queries)?,
                query_positions: sample.query_positions,
            })
        })
        .collect::<turboquant::Result<Vec<_>>>()?;
    Ok(EvalWorkload {
        name: format!("trace:{base_name}"),
        model: trace.metadata.model,
        source: "trace".to_string(),
        dim: trace.dim,
        samples,
    })
}
/// Normalizes every row via `normalize`, propagating the first failure.
fn normalize_rows(rows: Vec<Vec<f64>>) -> turboquant::Result<Vec<Vec<f64>>> {
    let mut normalized = Vec::with_capacity(rows.len());
    for row in rows {
        normalized.push(normalize(&row)?);
    }
    Ok(normalized)
}
/// Evaluates one quantization algorithm over every sample of a workload,
/// measuring key reconstruction MSE, attention score/output error, recall,
/// compression ratio, and quantize/query throughput.
///
/// Values are always quantized with a dedicated MSE quantizer (seed offset
/// +1000) regardless of which key algorithm is under test, so rows differ
/// only in key handling.
fn evaluate_workload(
    workload: &EvalWorkload,
    algorithm: Algorithm,
    backend: ExecutionBackend,
    bits: u8,
    seed: u64,
    max_queries: usize,
) -> turboquant::Result<BenchmarkRow> {
    let dim = workload.dim;
    let value_quantizer = TurboQuantMSE::new(dim, bits, seed.wrapping_add(1000))?;
    let mse_quantizer = TurboQuantMSE::new(dim, bits, seed)?;
    // Prod is given at least 2 bits — clamped up rather than erroring.
    let prod_quantizer = TurboQuantProd::new(dim, bits.max(2), seed)?;
    let polar_quantizer = PolarQuant::new(dim, seed, bits)?;
    // GPU runners are created up front only when the WGPU backend is selected.
    #[cfg(feature = "gpu")]
    let mse_runner = if matches!(backend, ExecutionBackend::Wgpu) {
        Some(WgpuMseBatchRunner::new(&mse_quantizer)?)
    } else {
        None
    };
    #[cfg(not(feature = "gpu"))]
    let _mse_runner: Option<()> = None;
    #[cfg(feature = "gpu")]
    let prod_runner = if matches!(backend, ExecutionBackend::Wgpu) {
        Some(WgpuMseBatchRunner::new(prod_quantizer.mse_stage())?)
    } else {
        None
    };
    #[cfg(not(feature = "gpu"))]
    let _prod_runner: Option<()> = None;
    let mut quantize_time = 0.0f64;
    let mut query_time = 0.0f64;
    let mut compressed_bytes = 0usize;
    let mut token_count = 0usize;
    let mut query_count = 0usize;
    let mut key_mse_sum = 0.0f64;
    let mut key_mse_batches = 0usize;
    let mut metrics = MetricsAccumulator::default();
    for sample in &workload.samples {
        token_count += sample.keys.len();
        // Quantize + dequantize the values once per sample.
        // NOTE(review): despite the `_ms` suffix, these timing locals hold
        // SECONDS (as_secs_f64) — consider renaming.
        let value_start = Instant::now();
        let value_batch =
            batch_quantize_mse_with_backend(backend, &value_quantizer, &sample.values)?;
        let value_quantize_ms = value_start.elapsed().as_secs_f64();
        let reconstructed_values =
            batch_dequantize_mse_with_backend(backend, &value_quantizer, &value_batch)?;
        // Quantize the keys with the algorithm under test.
        let key_start = Instant::now();
        let prepared = match algorithm {
            Algorithm::Mse => {
                let batch = batch_quantize_mse_with_backend(backend, &mse_quantizer, &sample.keys)?;
                let key_quantize_ms = key_start.elapsed().as_secs_f64();
                // Dequantization is kept outside the timed span.
                let reconstructed =
                    batch_dequantize_mse_with_backend(backend, &mse_quantizer, &batch)?;
                quantize_time += value_quantize_ms + key_quantize_ms;
                PreparedKeys::Mse {
                    batch,
                    reconstructed,
                }
            }
            Algorithm::Prod => PreparedKeys::Prod {
                batch: {
                    let batch =
                        batch_quantize_prod_with_backend(backend, &prod_quantizer, &sample.keys)?;
                    quantize_time += value_quantize_ms + key_start.elapsed().as_secs_f64();
                    batch
                },
            },
            Algorithm::Polar => {
                // PolarQuant has no batch API here; quantize per key.
                let quantized = sample
                    .keys
                    .iter()
                    .map(|key| polar_quantizer.quantize(key))
                    .collect::<turboquant::Result<Vec<_>>>()?;
                let key_quantize_ms = key_start.elapsed().as_secs_f64();
                let reconstructed = quantized
                    .iter()
                    .map(|item| polar_quantizer.dequantize(item))
                    .collect::<turboquant::Result<Vec<_>>>()?;
                quantize_time += value_quantize_ms + key_quantize_ms;
                PreparedKeys::Polar {
                    quantized,
                    reconstructed,
                }
            }
        };
        compressed_bytes += value_batch.total_bytes() + prepared.key_bytes();
        if let Some(mse) = prepared.key_mse(&sample.keys, dim, &prod_quantizer)? {
            key_mse_sum += mse;
            key_mse_batches += 1;
        }
        let query_limit = max_queries_per_sample(sample.queries.len(), max_queries);
        for query_index in 0..query_limit {
            query_count += 1;
            let position = *sample.query_positions.get(query_index).ok_or_else(|| {
                TurboQuantError::LengthMismatch {
                    context: "benchmark query position count".into(),
                    expected: query_limit,
                    got: sample.query_positions.len(),
                }
            })?;
            if position >= sample.keys.len() {
                return Err(TurboQuantError::TraceFormat(format!(
                    "query position {position} is out of range for sample with {} keys",
                    sample.keys.len()
                )));
            }
            // Causal prefix: the query attends to keys[0..=position].
            let prefix_len = position + 1;
            let exact_scores =
                attention_scores_exact(&sample.keys[..prefix_len], &sample.queries[query_index]);
            let exact_output =
                attention_output(&exact_scores, &sample.values[..prefix_len], dim, backend);
            let query_start = Instant::now();
            // Approximate scores are computed for ALL keys, then truncated to
            // the causal prefix below.
            let approx_scores_all = match &prepared {
                PreparedKeys::Mse { batch, .. } => score_mse(
                    backend,
                    &mse_quantizer,
                    batch,
                    &sample.queries[query_index],
                    #[cfg(feature = "gpu")]
                    mse_runner.as_ref(),
                )?,
                PreparedKeys::Prod { batch } => score_prod(
                    backend,
                    &prod_quantizer,
                    batch,
                    &sample.queries[query_index],
                    #[cfg(feature = "gpu")]
                    prod_runner.as_ref(),
                )?,
                // Polar scores exactly against its reconstructed keys.
                PreparedKeys::Polar { reconstructed, .. } => {
                    attention_scores_exact(&reconstructed[..], &sample.queries[query_index])
                }
            };
            let approx_scores = approx_scores_all[..prefix_len].to_vec();
            let approx_output = attention_output(
                &approx_scores,
                &reconstructed_values[..prefix_len],
                dim,
                backend,
            );
            query_time += query_start.elapsed().as_secs_f64();
            metrics.observe(&exact_scores, &approx_scores, &exact_output, &approx_output);
        }
    }
    // Uncompressed baseline: keys + values stored as f32.
    let uncompressed_bytes = token_count * dim * std::mem::size_of::<f32>() * 2;
    let compression_ratio = if compressed_bytes == 0 {
        0.0
    } else {
        uncompressed_bytes as f64 / compressed_bytes as f64
    };
    let recall_mean =
        (metrics.recall_at_1() + metrics.recall_at_5() + metrics.recall_at_10()) / 3.0;
    let aggregate_score = aggregate_score(
        metrics.logit_rmse(),
        metrics.output_rmse(),
        recall_mean,
        compression_ratio,
    );
    Ok(BenchmarkRow {
        workload: workload.name.clone(),
        source: workload.source.clone(),
        model: workload.model.clone(),
        eval_mode: "approximation".to_string(),
        algorithm: Some(algorithm.name().to_string()),
        backend: backend.name().to_string(),
        bits: Some(bits),
        value_bits: None,
        dim: Some(dim),
        samples: workload.samples.len(),
        tokens: token_count,
        queries: query_count,
        generated_tokens: None,
        key_mse: if key_mse_batches == 0 {
            None
        } else {
            Some(key_mse_sum / key_mse_batches as f64)
        },
        logit_rmse: Some(metrics.logit_rmse()),
        output_rmse: Some(metrics.output_rmse()),
        recall_at_1: Some(metrics.recall_at_1()),
        recall_at_5: Some(metrics.recall_at_5()),
        recall_at_10: Some(metrics.recall_at_10()),
        top_k_agreement: None,
        token_match_rate: None,
        divergence_rate: None,
        mean_first_divergence_step: None,
        cross_entropy_exact: None,
        cross_entropy_quantized: None,
        perplexity_exact: None,
        perplexity_quantized: None,
        exact_latency_seconds: None,
        quantized_latency_seconds: None,
        exact_tokens_per_second: None,
        quantized_tokens_per_second: None,
        kv_memory_exact_bytes: None,
        kv_memory_quantized_bytes: None,
        compressed_bytes: Some(compressed_bytes),
        uncompressed_bytes: Some(uncompressed_bytes),
        compression_ratio: Some(compression_ratio),
        quantize_tokens_per_second: Some(rate(token_count, quantize_time)),
        query_tokens_per_second: Some(rate(query_count, query_time)),
        aggregate_score: Some(aggregate_score),
    })
}
/// Per-sample quantized key material, one variant per algorithm.
enum PreparedKeys {
    Mse {
        batch: BatchQuantizedMSE,
        // Dequantized keys, retained for key-MSE measurement.
        reconstructed: Vec<Vec<f64>>,
    },
    Prod {
        // Prod scores directly from the packed batch; keys are reconstructed
        // on demand inside `key_mse`.
        batch: BatchQuantizedProd,
    },
    Polar {
        quantized: Vec<turboquant::polar::PolarQuantized>,
        reconstructed: Vec<Vec<f64>>,
    },
}
impl PreparedKeys {
    /// Compressed size of the key representation in bytes.
    fn key_bytes(&self) -> usize {
        match self {
            Self::Mse { batch, .. } => batch.total_bytes(),
            Self::Prod { batch } => batch.total_bytes(),
            // Polar reports fractional bytes per item; round each up.
            Self::Polar { quantized, .. } => quantized
                .iter()
                .map(|item| item.bytes().ceil() as usize)
                .sum(),
        }
    }

    /// Mean per-dimension squared reconstruction error of the keys. Every
    /// variant currently yields `Some`; the `Option` mirrors the optional
    /// `key_mse` report column.
    fn key_mse(
        &self,
        exact_keys: &[Vec<f64>],
        dim: usize,
        prod_quantizer: &TurboQuantProd,
    ) -> turboquant::Result<Option<f64>> {
        match self {
            // Mse and Polar already hold reconstructed keys.
            Self::Mse { reconstructed, .. } | Self::Polar { reconstructed, .. } => {
                Ok(Some(average_squared_error(exact_keys, reconstructed, dim)))
            }
            // Prod stores packed rows; rebuild each key before measuring.
            Self::Prod { batch } => {
                if batch.len() != exact_keys.len() {
                    return Err(TurboQuantError::LengthMismatch {
                        context: "benchmark Prod key count".into(),
                        expected: batch.len(),
                        got: exact_keys.len(),
                    });
                }
                let mut total = 0.0;
                for (index, exact_key) in exact_keys.iter().take(batch.len()).enumerate() {
                    // Reassemble the per-row quantized form from the packed batch.
                    let quantized = ProdQuantized {
                        mse_indices: batch.unpack_mse_indices(index).ok_or_else(|| {
                            TurboQuantError::Internal(format!(
                                "missing packed MSE row {index} in benchmark batch"
                            ))
                        })?,
                        qjl_signs: batch.unpack_qjl_signs(index).ok_or_else(|| {
                            TurboQuantError::Internal(format!(
                                "missing packed QJL row {index} in benchmark batch"
                            ))
                        })?,
                        residual_norm: batch.residual_norm(index).ok_or_else(|| {
                            TurboQuantError::Internal(format!(
                                "missing residual norm row {index} in benchmark batch"
                            ))
                        })?,
                        bit_width: batch.bit_width,
                        dim: batch.dim,
                    };
                    let reconstructed = prod_quantizer.dequantize(&quantized)?;
                    // Per-key mean squared error over dimensions.
                    total +=
                        sum_squared_error(ExecutionBackend::default(), exact_key, &reconstructed)
                            / dim as f64;
                }
                Ok(Some(average(total, batch.len())))
            }
        }
    }
}
/// Attention scores of `query` against an MSE-quantized key batch, routed to
/// the WGPU runner when the WGPU backend was selected (GPU builds only).
fn score_mse(
    backend: ExecutionBackend,
    quantizer: &TurboQuantMSE,
    batch: &BatchQuantizedMSE,
    query: &[f64],
    #[cfg(feature = "gpu")] runner: Option<&WgpuMseBatchRunner>,
) -> turboquant::Result<Vec<f64>> {
    #[cfg(feature = "gpu")]
    if let ExecutionBackend::Wgpu = backend {
        // The WGPU backend is only usable when a runner was constructed.
        return match runner {
            Some(runner) => runner.attention_scores(quantizer, batch, query),
            None => Err(TurboQuantError::UnsupportedBackend(
                "WGPU runner unavailable".into(),
            )),
        };
    }
    batch_attention_scores_mse_with_backend(backend, quantizer, batch, query)
}
/// Estimated inner products of `query` against a Prod-quantized key batch.
///
/// On the WGPU backend (GPU builds) the MSE stage runs on the GPU runner and a
/// per-row QJL correction is then added on the CPU; otherwise the whole
/// estimate is computed by the batch backend helper.
fn score_prod(
    backend: ExecutionBackend,
    quantizer: &TurboQuantProd,
    batch: &BatchQuantizedProd,
    query: &[f64],
    #[cfg(feature = "gpu")] runner: Option<&WgpuMseBatchRunner>,
) -> turboquant::Result<Vec<f64>> {
    #[cfg(feature = "gpu")]
    if matches!(backend, ExecutionBackend::Wgpu) {
        // Unpack every row's MSE codebook indices up front; any missing row is
        // an internal error in the packed batch.
        let mse_indices = (0..batch.len())
            .map(|index| {
                batch.unpack_mse_indices(index).ok_or_else(|| {
                    TurboQuantError::Internal(format!(
                        "missing packed MSE row {index} in benchmark batch"
                    ))
                })
            })
            .collect::<turboquant::Result<Vec<_>>>()?;
        let runner = runner
            .ok_or_else(|| TurboQuantError::UnsupportedBackend("WGPU runner unavailable".into()))?;
        // GPU computes the MSE-stage scores for all rows in one dispatch.
        let mut scores = runner.mse_stage_scores_from_indices(&mse_indices, query)?;
        // Add the QJL residual correction per row on the CPU.
        for (index, score) in scores.iter_mut().enumerate() {
            let correction = quantizer.qjl_stage().estimate_inner_product(
                &QJLQuantized {
                    signs: batch.unpack_qjl_signs(index).ok_or_else(|| {
                        TurboQuantError::Internal(format!(
                            "missing packed QJL row {index} in benchmark batch"
                        ))
                    })?,
                    residual_norm: batch.residual_norm(index).ok_or_else(|| {
                        TurboQuantError::Internal(format!(
                            "missing residual norm row {index} in benchmark batch"
                        ))
                    })?,
                    dim: batch.dim,
                },
                query,
            )?;
            *score += correction;
        }
        return Ok(scores);
    }
    batch_estimate_inner_products_with_backend(backend, quantizer, batch, query)
}
/// Exact (unquantized) attention scores: one inner product per key.
fn attention_scores_exact(keys: &[Vec<f64>], query: &[f64]) -> Vec<f64> {
    let mut scores = Vec::with_capacity(keys.len());
    for key in keys {
        scores.push(inner_product(key, query));
    }
    scores
}
/// Softmax-weighted sum of `values` using the given attention `scores`,
/// scaled by 1/sqrt(dim) and max-shifted for numerical stability.
/// An empty score list yields the zero vector.
fn attention_output(
    scores: &[f64],
    values: &[Vec<f64>],
    dim: usize,
    backend: ExecutionBackend,
) -> Vec<f64> {
    if scores.is_empty() {
        return vec![0.0; dim];
    }
    // Standard attention temperature: 1 / sqrt(dim).
    let inv_sqrt_dim = 1.0 / (dim as f64).sqrt();
    let mut scaled = Vec::with_capacity(scores.len());
    for score in scores {
        scaled.push(score * inv_sqrt_dim);
    }
    // Subtract the max before exponentiating so large scores cannot overflow.
    let mut peak = f64::NEG_INFINITY;
    for &value in &scaled {
        peak = peak.max(value);
    }
    let exp_scores: Vec<f64> = scaled.iter().map(|score| (score - peak).exp()).collect();
    let normalizer: f64 = exp_scores.iter().sum();
    // Accumulate the weighted values in place via the selected backend.
    let mut output = vec![0.0; dim];
    for (exp_score, value) in exp_scores.iter().zip(values.iter()) {
        weighted_sum_in_place(backend, &mut output, exp_score / normalizer, value);
    }
    output
}
/// Mean per-dimension squared error between paired rows of `left` and `right`.
/// Pairs beyond the shorter slice are ignored; the mean divisor is `left.len()`.
fn average_squared_error(left: &[Vec<f64>], right: &[Vec<f64>], dim: usize) -> f64 {
    let total: f64 = left
        .iter()
        .zip(right.iter())
        .map(|(lhs, rhs)| sum_squared_error(ExecutionBackend::default(), lhs, rhs) / dim as f64)
        .sum();
    average(total, left.len())
}
/// Fraction of the approximate top-k indices that also appear in the exact
/// top-k. `k` is clamped to both slice lengths; an effective k of 0 scores 0.
fn recall_at_k(exact: &[f64], approx: &[f64], k: usize) -> f64 {
    let k = k.min(exact.len()).min(approx.len());
    if k == 0 {
        return 0.0;
    }
    // Rank indices by descending value (stable sort; total_cmp handles NaN).
    let top = |scores: &[f64]| -> Vec<usize> {
        let mut order: Vec<usize> = (0..scores.len()).collect();
        order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
        order.truncate(k);
        order
    };
    let exact_top = top(exact);
    let mut hits = 0usize;
    for index in top(approx) {
        if exact_top.contains(&index) {
            hits += 1;
        }
    }
    hits as f64 / k as f64
}
/// Indices of the `k` largest values, in descending value order.
/// Stable sort keeps the lower index first on ties; `k` beyond the length is a no-op.
fn top_k_indices(values: &[f64], k: usize) -> Vec<usize> {
    let mut order: Vec<usize> = (0..values.len()).collect();
    order.sort_by(|&a, &b| values[a].total_cmp(&values[b]).reverse());
    order.truncate(k);
    order
}
/// Overlap ratio of the top-k index sets of two logit vectors.
/// `k == 0` is treated as `k == 1` so the denominator is never zero.
fn top_k_agreement(left: &[f32], right: &[f32], k: usize) -> f64 {
    let k = k.max(1);
    // Rank indices by descending logit and keep at most k of them.
    let top_set = |logits: &[f32]| -> HashSet<usize> {
        let mut order: Vec<usize> = (0..logits.len()).collect();
        order.sort_by(|&a, &b| logits[b].total_cmp(&logits[a]));
        order.truncate(k);
        order.into_iter().collect()
    };
    // NOTE: the divisor is k itself, so slices shorter than k cannot reach 1.0.
    let overlap = top_set(left).intersection(&top_set(right)).count();
    overlap as f64 / k as f64
}
/// f32 variant of the top-k index ranking: indices of the `k` largest logits,
/// descending, stable on ties.
fn top_k_indices_f32(values: &[f32], k: usize) -> Vec<usize> {
    let mut ranked: Vec<usize> = Vec::with_capacity(values.len());
    ranked.extend(0..values.len());
    ranked.sort_by(|&lhs, &rhs| values[lhs].total_cmp(&values[rhs]).reverse());
    ranked.truncate(k);
    ranked
}
/// Negative log-probability of `token_index` under a softmax over `logits`,
/// computed via a max-shifted log-sum-exp. Empty logits or an out-of-range
/// index contribute zero loss.
fn negative_log_probability(logits: &[f32], token_index: usize) -> f64 {
    let Some(&target) = logits.get(token_index) else {
        return 0.0;
    };
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Exponentiate in f32 (matching the logit precision), accumulate in f64.
    let shifted_sum: f64 = logits
        .iter()
        .map(|&logit| f64::from((logit - peak).exp()))
        .sum();
    shifted_sum.ln() + f64::from(peak) - f64::from(target)
}
/// Single 0-100 quality score blending recall (35%), logit accuracy (25%),
/// output accuracy (20%), and memory savings (20%).
fn aggregate_score(
    logit_rmse: f64,
    output_rmse: f64,
    recall_mean: f64,
    compression_ratio: f64,
) -> f64 {
    // Map each RMSE into (0, 1]; zero error scores 1.0.
    let logit_score = (1.0 + logit_rmse).recip();
    let output_score = (1.0 + output_rmse).recip();
    // Memory score rewards ratios above 1x; non-positive ratios score zero.
    let memory_score = if compression_ratio <= 0.0 {
        0.0
    } else {
        1.0 - compression_ratio.recip().min(1.0)
    };
    100.0 * (0.35 * recall_mean + 0.25 * logit_score + 0.20 * output_score + 0.20 * memory_score)
}
/// Effective query cap for one sample: a `max_queries` of zero means "no cap".
fn max_queries_per_sample(total: usize, max_queries: usize) -> usize {
    match max_queries {
        0 => total,
        limit => total.min(limit),
    }
}
/// Items per second, returning 0.0 when the elapsed time is (near-)zero
/// to avoid a division blow-up.
fn rate(count: usize, seconds: f64) -> f64 {
    if seconds <= f64::EPSILON {
        return 0.0;
    }
    count as f64 / seconds
}
/// Arithmetic mean of a precomputed sum; an empty set averages to zero
/// rather than NaN.
fn average(total: f64, count: usize) -> f64 {
    match count {
        0 => 0.0,
        n => total / n as f64,
    }
}
/// Pretty-print a benchmark report to stdout, one block per row.
/// Real-model rows get latency/perplexity-style fields; all other sources get
/// the synthetic/trace accuracy and throughput fields.
fn print_report(report: &BenchmarkReport) {
    println!("==============================================");
    println!("  TurboQuant Benchmark ({})", report.mode);
    println!("==============================================\n");
    for row in &report.rows {
        // Dispatch on the row's source string; "real-model" rows carry a
        // different set of populated metrics than synthetic/trace rows.
        match row.source.as_str() {
            "real-model" => {
                println!(
                    "[real-model] {} / {} / {}",
                    // Fall back to the workload name when no model name is set.
                    row.model.as_deref().unwrap_or(&row.workload),
                    row.eval_mode,
                    row.algorithm.as_deref().unwrap_or("n/a")
                );
                println!(
                    "  tokens={} generated={} logit_rmse={} topk={} token_match={} ratio={} exact_tps={} quant_tps={}",
                    row.tokens,
                    row.generated_tokens.unwrap_or(0),
                    format_opt(row.logit_rmse, 4),
                    format_opt(row.top_k_agreement, 3),
                    format_opt(row.token_match_rate, 3),
                    format_ratio(row.compression_ratio),
                    format_opt(row.exact_tokens_per_second, 1),
                    format_opt(row.quantized_tokens_per_second, 1),
                );
                println!(
                    "  ce_exact={} ce_quantized={} exact_latency={}s quant_latency={}s kv_exact={} kv_quantized={}",
                    format_opt(row.cross_entropy_exact, 4),
                    format_opt(row.cross_entropy_quantized, 4),
                    format_opt(row.exact_latency_seconds, 3),
                    format_opt(row.quantized_latency_seconds, 3),
                    format_usize(row.kv_memory_exact_bytes),
                    format_usize(row.kv_memory_quantized_bytes),
                );
            }
            _ => {
                // Synthetic and trace rows: accuracy, recall, and throughput.
                println!(
                    "[{}] {} / {} / {}",
                    row.source,
                    row.workload,
                    row.algorithm.as_deref().unwrap_or("n/a"),
                    row.backend
                );
                println!(
                    "  dim={} logit_rmse={} output_rmse={} r1={} r5={} r10={} ratio={} score={}",
                    row.dim
                        .map(|dim| dim.to_string())
                        .unwrap_or_else(|| "-".to_string()),
                    format_opt(row.logit_rmse, 4),
                    format_opt(row.output_rmse, 4),
                    format_opt(row.recall_at_1, 3),
                    format_opt(row.recall_at_5, 3),
                    format_opt(row.recall_at_10, 3),
                    format_ratio(row.compression_ratio),
                    format_opt(row.aggregate_score, 1),
                );
                println!(
                    "  quantize_tps={} query_tps={} bytes={} -> {}",
                    format_opt(row.quantize_tokens_per_second, 1),
                    format_opt(row.query_tokens_per_second, 1),
                    format_usize(row.uncompressed_bytes),
                    format_usize(row.compressed_bytes),
                );
            }
        }
        // Blank line between rows for readability.
        println!();
    }
}
/// Format an optional float with the given decimal precision; `None` prints "-".
fn format_opt(value: Option<f64>, precision: usize) -> String {
    match value {
        Some(number) => format!("{number:.precision$}"),
        None => "-".to_string(),
    }
}
/// Format an optional compression ratio as "N.NNx"; `None` prints "-".
fn format_ratio(value: Option<f64>) -> String {
    match value {
        Some(ratio) => format!("{ratio:.2}x"),
        None => "-".to_string(),
    }
}
/// Format an optional byte/count value; `None` prints "-".
fn format_usize(value: Option<usize>) -> String {
    match value {
        Some(count) => count.to_string(),
        None => "-".to_string(),
    }
}