realizar 0.3.2

Pure Rust ML inference engine built from scratch - model serving for GGUF and safetensors
Documentation
//! IMP-900: Optimized GPU Performance Benchmark
//!
//! Measures performance improvements from IMP-900 optimizations:
//! - IMP-900a: Optimized GEMM kernel
//! - IMP-900b: Kernel fusion
//! - IMP-900c: FlashAttention
//! - IMP-900d: Memory optimization
//!
//! Run with: cargo run --features cuda --example imp900_optimized_gpu

use realizar::bench::{
    FlashAttentionConfig, GemmPerformanceResult, Imp900Result, MemoryPoolConfig,
    OptimizedGemmBenchmark, OptimizedGemmConfig,
};
use realizar::cuda::CudaExecutor;
use std::time::Instant;

fn main() {
    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║       IMP-900: Optimized GPU Performance Benchmark           ║");
    println!("║       Closing the 18x Gap to Ollama Parity                   ║");
    println!("╚══════════════════════════════════════════════════════════════╝");
    println!();

    // Check CUDA availability
    if !CudaExecutor::is_available() {
        println!("❌ CUDA not available. Cannot run GPU benchmark.");
        return;
    }

    let num_devices = CudaExecutor::num_devices();
    println!("✅ CUDA available: {} device(s)", num_devices);

    // Create executor and get device info
    let mut executor = match CudaExecutor::new(0) {
        Ok(e) => e,
        Err(err) => {
            println!("❌ Failed to create CUDA executor: {}", err);
            return;
        },
    };

    let device_name = executor.device_name().unwrap_or_default();
    let (vram_free, vram_total) = executor.memory_info().unwrap_or((0, 0));
    let vram_mb = vram_total / 1024 / 1024;

    println!("   Device: {}", device_name);
    println!(
        "   VRAM: {} MB ({} MB free)",
        vram_mb,
        vram_free / 1024 / 1024
    );
    println!();

    // ========================================================================
    // IMP-900a: Optimized GEMM Benchmark
    // ========================================================================

    println!("═══════════════════════════════════════════════════════════════");
    println!("  IMP-900a: OPTIMIZED GEMM BENCHMARK");
    println!("═══════════════════════════════════════════════════════════════");

    // Benchmark configurations
    // Note: CUDA limits threads per block to 1024, so max tile_size is 32 (32x32=1024)
    let configs = [
        ("Default (32x32 tiles)", OptimizedGemmConfig::default()),
        ("Small (16x16 tiles)", OptimizedGemmConfig::small()),
        // Large config uses 32x32 tiles with larger reg_block for better register utilization
        (
            "Optimized (32x32, reg_block=8)",
            OptimizedGemmConfig {
                tile_size: 32,
                reg_block: 8,
                use_tensor_cores: false,
                vector_width: 4,
                k_unroll: 8,
                double_buffer: true,
            },
        ),
    ];

    // Test sizes
    let sizes: [(u32, u32, u32); 3] = [(256, 256, 256), (512, 512, 512), (1024, 1024, 1024)];

    let mut baseline_gflops_1024 = 0.0;
    let mut optimized_gflops_1024 = 0.0;

    for (config_name, config) in &configs {
        println!("\n  Configuration: {}", config_name);
        println!(
            "    Tile size: {}, Reg block: {}, Double buffer: {}",
            config.tile_size, config.reg_block, config.double_buffer
        );
        println!(
            "    Shared memory: {} KB, Threads/block: {}",
            config.shared_memory_bytes() / 1024,
            config.threads_per_block()
        );

        let benchmark = OptimizedGemmBenchmark::with_config(config.clone());
        println!(
            "    Expected improvement: {:.2}x over naive",
            benchmark.expected_improvement()
        );

        for (m, n, k) in sizes {
            let a = vec![1.0f32; (m * k) as usize];
            let b = vec![1.0f32; (k * n) as usize];
            let mut c = vec![0.0f32; (m * n) as usize];

            // Warmup
            for _ in 0..3 {
                let _ = executor.gemm_optimized(&a, &b, &mut c, m, n, k, config.tile_size);
            }

            // Measure
            let iterations = 10;
            let start = Instant::now();
            for _ in 0..iterations {
                executor
                    .gemm_optimized(&a, &b, &mut c, m, n, k, config.tile_size)
                    .expect("GEMM");
            }
            let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0 / iterations as f64;

            let result = GemmPerformanceResult::new(m, n, k, elapsed_ms);
            println!(
                "    {}x{}x{}: {:.3}ms ({:.1} GFLOP/s)",
                m, n, k, elapsed_ms, result.gflops
            );

            // Track 1024x1024 for comparison
            if m == 1024 && config_name.contains("Default") {
                baseline_gflops_1024 = result.gflops;
            } else if m == 1024 && config_name.contains("Optimized") {
                optimized_gflops_1024 = result.gflops;
            }
        }
    }

    // Calculate improvement
    let gemm_improvement = if baseline_gflops_1024 > 0.0 {
        optimized_gflops_1024 / baseline_gflops_1024
    } else {
        1.0
    };

    println!();
    println!(
        "  GEMM Improvement (1024x1024): {:.2}x",
        gemm_improvement.max(1.0)
    );

    // ========================================================================
    // IMP-900c: FlashAttention Analysis
    // ========================================================================

    println!();
    println!("═══════════════════════════════════════════════════════════════");
    println!("  IMP-900c: FLASHATTENTION MEMORY ANALYSIS");
    println!("═══════════════════════════════════════════════════════════════");

    let flash_config = FlashAttentionConfig::phi2();
    println!(
        "  Config: phi-2 ({} heads, {} head_dim)",
        flash_config.num_heads, flash_config.head_dim
    );
    println!(
        "  Causal: {}, Scale: {:.4}",
        flash_config.causal, flash_config.scale
    );
    println!();

    for seq_len in [128, 512, 1024, 2048, 4096] {
        let (naive, flash) = flash_config.memory_comparison(seq_len);
        let savings = flash_config.memory_savings(seq_len);
        println!(
            "  Seq {}: Naive {:.1} KB → Flash {:.1} KB ({:.0}x savings)",
            seq_len,
            naive as f64 / 1024.0,
            flash as f64 / 1024.0,
            savings
        );
    }

    // ========================================================================
    // IMP-900d: Memory Pool Analysis
    // ========================================================================

    println!();
    println!("═══════════════════════════════════════════════════════════════");
    println!("  IMP-900d: MEMORY POOL CONFIGURATION");
    println!("═══════════════════════════════════════════════════════════════");

    let pool_config = MemoryPoolConfig::default();
    println!(
        "  Initial pool: {} MB, Max: {} GB",
        pool_config.initial_size / 1024 / 1024,
        pool_config.max_size / 1024 / 1024 / 1024
    );
    println!("  Pinned memory: {}", pool_config.use_pinned_memory);
    println!("  Async transfers: {}", pool_config.async_transfers);
    println!(
        "  Expected bandwidth improvement: {:.1}x",
        pool_config.expected_bandwidth_improvement()
    );
    println!("  Size classes: {} tiers", pool_config.size_classes.len());

    // ========================================================================
    // Combined IMP-900 Result
    // ========================================================================

    println!();
    println!("═══════════════════════════════════════════════════════════════");
    println!("  IMP-900: COMBINED IMPROVEMENT PROJECTION");
    println!("═══════════════════════════════════════════════════════════════");

    // Baseline from IMP-800
    let baseline_tps = 13.1;

    // Use conservative measured/projected improvements
    let result = Imp900Result::from_baseline(baseline_tps)
        .with_gemm_improvement(gemm_improvement.max(1.0))
        .with_fusion_improvement(1.5) // Projected from kernel fusion
        .with_flash_attention_improvement(2.0) // Projected from FlashAttention
        .with_memory_improvement(1.5); // Projected from memory pooling

    println!("  Baseline (IMP-800): {:.1} tok/s", result.baseline_tps);
    println!("  GEMM improvement: {:.2}x", result.gemm_improvement);
    println!(
        "  Fusion improvement: {:.2}x (projected)",
        result.fusion_improvement
    );
    println!(
        "  FlashAttention improvement: {:.2}x (projected)",
        result.flash_attention_improvement
    );
    println!(
        "  Memory improvement: {:.2}x (projected)",
        result.memory_improvement
    );
    println!();
    println!("  Total improvement: {:.2}x", result.total_improvement());
    println!("  Projected throughput: {:.1} tok/s", result.optimized_tps);
    println!("  Gap to Ollama (240 tok/s): {:.2}x", result.gap_ratio);

    // Milestone check
    println!();
    if let Some(ref milestone) = result.milestone {
        println!("  ✅ Milestone achieved: {}", milestone);
    } else {
        println!("  ⏳ No milestone achieved yet");
    }

    println!(
        "  M3 target (>48 tok/s, <5x gap): {}",
        if result.achieves_m3() {
            "✅ PASS"
        } else {
            "❌ FAIL"
        }
    );
    println!(
        "  M4 target (>192 tok/s, <1.25x gap): {}",
        if result.achieves_m4() {
            "✅ PASS"
        } else {
            "❌ FAIL"
        }
    );

    // Summary
    println!();
    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║                         SUMMARY                              ║");
    println!("╠══════════════════════════════════════════════════════════════╣");
    println!("║  Device: {:50} ║", device_name);
    println!(
        "║  Baseline: {:>6.1} tok/s (IMP-800)                           ║",
        baseline_tps
    );
    println!(
        "║  Projected: {:>6.1} tok/s (IMP-900)                          ║",
        result.optimized_tps
    );
    println!(
        "║  Improvement: {:>5.1}x                                        ║",
        result.total_improvement()
    );
    println!(
        "║  Gap: {:>6.2}x (target: <1.25x)                              ║",
        result.gap_ratio
    );
    println!("╚══════════════════════════════════════════════════════════════╝");
}