fn attention_cost_no_cache(seq_len: usize, hidden_dim: usize, num_heads: usize) -> f64 {
let head_dim = hidden_dim / num_heads;
let mut total = 0.0f64;
for pos in 0..seq_len {
for _past in 0..=pos {
total += (head_dim as f64) * 2.0; }
total += (pos + 1) as f64 * 3.0;
for _past in 0..=pos {
total += (head_dim as f64) * 2.0;
}
}
total * (num_heads as f64)
}
fn attention_cost_with_cache(seq_len: usize, hidden_dim: usize, num_heads: usize) -> f64 {
let head_dim = hidden_dim / num_heads;
let pos = seq_len - 1;
let mut total = 0.0f64;
for _past in 0..=pos {
total += (head_dim as f64) * 2.0;
}
total += (pos + 1) as f64 * 3.0;
for _past in 0..=pos {
total += (head_dim as f64) * 2.0;
}
total * (num_heads as f64)
}
fn main() {
println!("=== IMP-800: KV Cache Falsification ===");
println!("Claim: trueno-db KV cache provides 10-100x speedup\n");
let hidden_dim = 2560;
let num_heads = 32;
println!(
"Model: phi-2 style (hidden={}, heads={})",
hidden_dim, num_heads
);
println!();
println!(
"{:<12} {:>15} {:>15} {:>10}",
"Seq Length", "No Cache (ops)", "With Cache (ops)", "Speedup"
);
println!("{}", "-".repeat(55));
let mut speedups = Vec::new();
for &seq_len in &[8, 16, 32, 64, 128, 256, 512, 1024] {
let no_cache = attention_cost_no_cache(seq_len, hidden_dim, num_heads);
let with_cache = attention_cost_with_cache(seq_len, hidden_dim, num_heads);
let speedup = no_cache / with_cache;
speedups.push((seq_len, speedup));
println!(
"{:<12} {:>15.0} {:>15.0} {:>10.1}x",
seq_len, no_cache, with_cache, speedup
);
}
println!();
println!("=== Falsification Analysis ===");
let avg_speedup: f64 = speedups.iter().map(|(_, s)| s).sum::<f64>() / speedups.len() as f64;
let min_speedup = speedups
.iter()
.map(|(_, s)| *s)
.fold(f64::INFINITY, f64::min);
let max_speedup = speedups.iter().map(|(_, s)| *s).fold(0.0, f64::max);
println!("Average speedup: {:.1}x", avg_speedup);
println!("Range: {:.1}x - {:.1}x", min_speedup, max_speedup);
println!();
if avg_speedup >= 10.0 {
println!(
"CLAIM VERIFIED: KV cache provides {:.0}x average speedup",
avg_speedup
);
println!(
"- At seq_len=512: {:.0}x speedup",
speedups
.iter()
.find(|(l, _)| *l == 512)
.map_or(0.0, |(_, s)| *s)
);
println!(
"- At seq_len=1024: {:.0}x speedup",
speedups
.iter()
.find(|(l, _)| *l == 1024)
.map_or(0.0, |(_, s)| *s)
);
} else {
println!(
"CLAIM FALSIFIED: Average speedup ({:.1}x) < 10x threshold",
avg_speedup
);
}
println!();
println!("=== Projected Impact on Performance Gap ===");
let current_gap = 1090.0; let projected_gap = current_gap / avg_speedup;
println!("Current gap to Ollama: {:.0}x", current_gap);
println!("With KV cache: {:.0}x (projected)", projected_gap);
println!();
if projected_gap < 100.0 {
println!("KV cache alone brings gap under 100x - remaining work:");
println!("- trueno-gpu FlashAttention for prompt processing");
println!("- Q4_K quantized operations");
}
#[cfg(feature = "kv-cache")]
{
println!("\n=== trueno-db KV Store Benchmark ===");
test_trueno_db_kv();
}
#[cfg(not(feature = "kv-cache"))]
{
println!("\n(Run with --features kv-cache to test trueno-db integration)");
}
}
#[cfg(feature = "kv-cache")]
fn test_trueno_db_kv() {
println!("Testing trueno-db MemoryKvStore...");
let hidden_dim = 2560;
let num_layers = 32;
let seq_len = 512;
let tensor_size = hidden_dim * 4; let kv_size = tensor_size * 2; let total_cache_size = kv_size * num_layers * seq_len;
println!(
"Cache size for seq_len={}: {:.1} MB",
seq_len,
total_cache_size as f64 / 1_000_000.0
);
let read_per_token = kv_size * num_layers * seq_len;
println!(
"Memory read per new token: {:.2} MB",
read_per_token as f64 / 1_000_000.0
);
let bandwidth_gbps = 50.0;
let read_time_ms = (read_per_token as f64 / 1e9) / bandwidth_gbps * 1000.0;
println!("Estimated cache read time: {:.2} ms", read_time_ms);
let flops_per_layer = seq_len * hidden_dim * hidden_dim;
let total_flops = flops_per_layer * num_layers;
println!(
"FLOPs to recompute attention: {:.2}e9",
total_flops as f64 / 1e9
);
let cpu_gflops = 100.0;
let compute_time_ms = (total_flops as f64 / 1e9) / cpu_gflops * 1000.0;
println!("Estimated recompute time: {:.2} ms", compute_time_ms);
let cache_speedup = compute_time_ms / read_time_ms;
println!("\nCache vs Recompute speedup: {:.1}x", cache_speedup);
}