use realizar::bench::{
FlashAttentionConfig, GemmPerformanceResult, Imp900Result, MemoryPoolConfig,
OptimizedGemmBenchmark, OptimizedGemmConfig,
};
use realizar::cuda::CudaExecutor;
use std::time::Instant;
fn main() {
println!("╔══════════════════════════════════════════════════════════════╗");
println!("║ IMP-900: Optimized GPU Performance Benchmark ║");
println!("║ Closing the 18x Gap to Ollama Parity ║");
println!("╚══════════════════════════════════════════════════════════════╝");
println!();
if !CudaExecutor::is_available() {
println!("❌ CUDA not available. Cannot run GPU benchmark.");
return;
}
let num_devices = CudaExecutor::num_devices();
println!("✅ CUDA available: {} device(s)", num_devices);
let mut executor = match CudaExecutor::new(0) {
Ok(e) => e,
Err(err) => {
println!("❌ Failed to create CUDA executor: {}", err);
return;
},
};
let device_name = executor.device_name().unwrap_or_default();
let (vram_free, vram_total) = executor.memory_info().unwrap_or((0, 0));
let vram_mb = vram_total / 1024 / 1024;
println!(" Device: {}", device_name);
println!(
" VRAM: {} MB ({} MB free)",
vram_mb,
vram_free / 1024 / 1024
);
println!();
println!("═══════════════════════════════════════════════════════════════");
println!(" IMP-900a: OPTIMIZED GEMM BENCHMARK");
println!("═══════════════════════════════════════════════════════════════");
let configs = [
("Default (32x32 tiles)", OptimizedGemmConfig::default()),
("Small (16x16 tiles)", OptimizedGemmConfig::small()),
(
"Optimized (32x32, reg_block=8)",
OptimizedGemmConfig {
tile_size: 32,
reg_block: 8,
use_tensor_cores: false,
vector_width: 4,
k_unroll: 8,
double_buffer: true,
},
),
];
let sizes: [(u32, u32, u32); 3] = [(256, 256, 256), (512, 512, 512), (1024, 1024, 1024)];
let mut baseline_gflops_1024 = 0.0;
let mut optimized_gflops_1024 = 0.0;
for (config_name, config) in &configs {
println!("\n Configuration: {}", config_name);
println!(
" Tile size: {}, Reg block: {}, Double buffer: {}",
config.tile_size, config.reg_block, config.double_buffer
);
println!(
" Shared memory: {} KB, Threads/block: {}",
config.shared_memory_bytes() / 1024,
config.threads_per_block()
);
let benchmark = OptimizedGemmBenchmark::with_config(config.clone());
println!(
" Expected improvement: {:.2}x over naive",
benchmark.expected_improvement()
);
for (m, n, k) in sizes {
let a = vec![1.0f32; (m * k) as usize];
let b = vec![1.0f32; (k * n) as usize];
let mut c = vec![0.0f32; (m * n) as usize];
for _ in 0..3 {
let _ = executor.gemm_optimized(&a, &b, &mut c, m, n, k, config.tile_size);
}
let iterations = 10;
let start = Instant::now();
for _ in 0..iterations {
executor
.gemm_optimized(&a, &b, &mut c, m, n, k, config.tile_size)
.expect("GEMM");
}
let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0 / iterations as f64;
let result = GemmPerformanceResult::new(m, n, k, elapsed_ms);
println!(
" {}x{}x{}: {:.3}ms ({:.1} GFLOP/s)",
m, n, k, elapsed_ms, result.gflops
);
if m == 1024 && config_name.contains("Default") {
baseline_gflops_1024 = result.gflops;
} else if m == 1024 && config_name.contains("Optimized") {
optimized_gflops_1024 = result.gflops;
}
}
}
let gemm_improvement = if baseline_gflops_1024 > 0.0 {
optimized_gflops_1024 / baseline_gflops_1024
} else {
1.0
};
println!();
println!(
" GEMM Improvement (1024x1024): {:.2}x",
gemm_improvement.max(1.0)
);
println!();
println!("═══════════════════════════════════════════════════════════════");
println!(" IMP-900c: FLASHATTENTION MEMORY ANALYSIS");
println!("═══════════════════════════════════════════════════════════════");
let flash_config = FlashAttentionConfig::phi2();
println!(
" Config: phi-2 ({} heads, {} head_dim)",
flash_config.num_heads, flash_config.head_dim
);
println!(
" Causal: {}, Scale: {:.4}",
flash_config.causal, flash_config.scale
);
println!();
for seq_len in [128, 512, 1024, 2048, 4096] {
let (naive, flash) = flash_config.memory_comparison(seq_len);
let savings = flash_config.memory_savings(seq_len);
println!(
" Seq {}: Naive {:.1} KB → Flash {:.1} KB ({:.0}x savings)",
seq_len,
naive as f64 / 1024.0,
flash as f64 / 1024.0,
savings
);
}
println!();
println!("═══════════════════════════════════════════════════════════════");
println!(" IMP-900d: MEMORY POOL CONFIGURATION");
println!("═══════════════════════════════════════════════════════════════");
let pool_config = MemoryPoolConfig::default();
println!(
" Initial pool: {} MB, Max: {} GB",
pool_config.initial_size / 1024 / 1024,
pool_config.max_size / 1024 / 1024 / 1024
);
println!(" Pinned memory: {}", pool_config.use_pinned_memory);
println!(" Async transfers: {}", pool_config.async_transfers);
println!(
" Expected bandwidth improvement: {:.1}x",
pool_config.expected_bandwidth_improvement()
);
println!(" Size classes: {} tiers", pool_config.size_classes.len());
println!();
println!("═══════════════════════════════════════════════════════════════");
println!(" IMP-900: COMBINED IMPROVEMENT PROJECTION");
println!("═══════════════════════════════════════════════════════════════");
let baseline_tps = 13.1;
let result = Imp900Result::from_baseline(baseline_tps)
.with_gemm_improvement(gemm_improvement.max(1.0))
.with_fusion_improvement(1.5) .with_flash_attention_improvement(2.0) .with_memory_improvement(1.5);
println!(" Baseline (IMP-800): {:.1} tok/s", result.baseline_tps);
println!(" GEMM improvement: {:.2}x", result.gemm_improvement);
println!(
" Fusion improvement: {:.2}x (projected)",
result.fusion_improvement
);
println!(
" FlashAttention improvement: {:.2}x (projected)",
result.flash_attention_improvement
);
println!(
" Memory improvement: {:.2}x (projected)",
result.memory_improvement
);
println!();
println!(" Total improvement: {:.2}x", result.total_improvement());
println!(" Projected throughput: {:.1} tok/s", result.optimized_tps);
println!(" Gap to Ollama (240 tok/s): {:.2}x", result.gap_ratio);
println!();
if let Some(ref milestone) = result.milestone {
println!(" ✅ Milestone achieved: {}", milestone);
} else {
println!(" ⏳ No milestone achieved yet");
}
println!(
" M3 target (>48 tok/s, <5x gap): {}",
if result.achieves_m3() {
"✅ PASS"
} else {
"❌ FAIL"
}
);
println!(
" M4 target (>192 tok/s, <1.25x gap): {}",
if result.achieves_m4() {
"✅ PASS"
} else {
"❌ FAIL"
}
);
println!();
println!("╔══════════════════════════════════════════════════════════════╗");
println!("║ SUMMARY ║");
println!("╠══════════════════════════════════════════════════════════════╣");
println!("║ Device: {:50} ║", device_name);
println!(
"║ Baseline: {:>6.1} tok/s (IMP-800) ║",
baseline_tps
);
println!(
"║ Projected: {:>6.1} tok/s (IMP-900) ║",
result.optimized_tps
);
println!(
"║ Improvement: {:>5.1}x ║",
result.total_improvement()
);
println!(
"║ Gap: {:>6.2}x (target: <1.25x) ║",
result.gap_ratio
);
println!("╚══════════════════════════════════════════════════════════════╝");
}