tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Benchmark for the persistent HIP kernel server.
//!
//! Compares per-call latency between:
//!   - the legacy one-shot text path (a fresh child process is spawned
//!     for every call), and
//!   - the new persistent kernel server path (one long-lived child per
//!     `kernel_type`, with traffic demultiplexed over a length-prefixed
//!     binary protocol).
//!
//! Both paths must produce numerically identical results (the
//! persistent server is a pure I/O refactor; the kernel compute code
//! is unchanged). The benchmark is the gate that proves the
//! persistent path is the right design.

#![cfg(feature = "rocm-hip")]

use std::time::Instant;

use tokitai_operator::backend::hip_gemm_f16::{
    cpu_gemm_f16, f32_to_f16, hip_gemm_f16_kernel_source_fingerprint, run_rocm_hip_gemm_f16,
};
use tokitai_operator::backend::kernel_server;

/// Return the path of the cached GEMM executable built from the
/// **current** kernel source.
///
/// The file name is derived from `hip_gemm_f16_kernel_source_fingerprint()`,
/// the same fingerprint the production `run_rocm_hip_gemm_f16` uses. The
/// bench therefore exercises exactly the binary production would invoke.
/// Scanning the cache directory and taking the first entry could instead
/// pick up a stale binary left over from an older kernel revision.
///
/// # Panics
///
/// Panics if no executable exists for the current fingerprint under
/// `target/rocm-hip-cache` (resolved relative to the working directory).
fn locate_gemm_kernel() -> std::path::PathBuf {
    let file_name = format!("{}-gemm-fp16-f32", hip_gemm_f16_kernel_source_fingerprint());
    let path = std::path::Path::new("target/rocm-hip-cache").join(file_name);
    if !path.exists() {
        panic!(
            "expected cached gemm kernel at {} (run run_rocm_hip_gemm_f16 first to populate)",
            path.display()
        );
    }
    path
}

/// Generate `rows * cols` deterministic pseudo-random fp16 values, returned
/// as raw `u16` bit patterns (converted with `f32_to_f16`).
///
/// Each value comes from a xorshift32 step (shifts 13, 17, 5) on a state
/// seeded with `seed + 1` (wrapping). The state is mapped linearly from
/// `[0, u32::MAX]` onto roughly `[-0.05, 0.05]` before conversion. This
/// mirrors the generator used by the existing `rocm_hip_gemm_f16` test, so
/// the resulting matrices are bit-comparable with it.
///
/// Note: `seed == u32::MAX` wraps the initial state to 0, which xorshift
/// never leaves, so every element is then `f32_to_f16(-0.05)`.
fn make_matrix(rows: usize, cols: usize, seed: u32) -> Vec<u16> {
    let len = rows * cols;
    let mut x = seed.wrapping_add(1);
    (0..len)
        .map(|_| {
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            let unit = x as f32 / u32::MAX as f32;
            f32_to_f16(unit * 0.1 - 0.05)
        })
        .collect()
}

/// Build the text request for a 64x64x64 fp16 GEMM.
///
/// The format is a header line `"M N K"`, then one line holding every
/// element of `a` and one line holding every element of `b`. Elements are
/// written as decimal `u16` bit patterns separated by single spaces, and
/// each line ends in `\n`. This must match, byte for byte, the payload
/// `run_rocm_hip_gemm_f16` builds, so the persistent and one-shot paths
/// see identical kernel input.
///
/// # Panics
///
/// Panics unless `a` has `64 * 64` elements and `b` has `64 * 64` elements.
fn build_payload_64x64(a: &[u16], b: &[u16]) -> String {
    const M: usize = 64;
    const N: usize = 64;
    const K: usize = 64;
    assert_eq!(a.len(), M * K);
    assert_eq!(b.len(), K * N);

    let mut text = String::with_capacity((a.len() + b.len()) * 8);
    text.push_str(&format!("{M} {N} {K}\n"));
    for operand in [a, b] {
        let row: Vec<String> = operand.iter().map(ToString::to_string).collect();
        text.push_str(&row.join(" "));
        text.push('\n');
    }
    text
}

/// Baseline benchmark for the legacy one-shot text path: drive
/// `ITERATIONS` calls through `KernelServer::oneshot` (each spawning a fresh
/// child) and report the mean per-call latency.
///
/// # Panics
///
/// Panics if the cached kernel executable is missing (see
/// `locate_gemm_kernel`), if any one-shot call fails, or if any timed
/// response lacks the `RESULTS=` marker.
#[test]
fn kernel_server_bench_oneshot_64x64() {
    let m = 64usize;
    let n = 64usize;
    let k = 64usize;
    let a = make_matrix(m, k, 0xC0FFEE_42);
    let b = make_matrix(k, n, 0xBADA55_99);
    let payload = build_payload_64x64(&a, &b);

    // Same cached executable the production wrapper resolves, so the
    // baseline is apples-to-apples with the production path.
    let executable = locate_gemm_kernel();

    // Warmup: one untimed call so the first measured iteration does not pay
    // one-time cold-start costs (e.g. reading the executable into the page
    // cache). No compilation happens here; `locate_gemm_kernel` already
    // asserted the binary exists.
    let _ = kernel_server::KernelServer::oneshot(&executable, &payload).expect("oneshot warmup");

    const ITERATIONS: usize = 10;
    let start = Instant::now();
    // Keep every response so all of them can be validated after the timed
    // region. Previously only the last one was checked.
    let responses: Vec<String> = (0..ITERATIONS)
        .map(|_| {
            kernel_server::KernelServer::oneshot(&executable, &payload).expect("oneshot call")
        })
        .collect();
    let elapsed = start.elapsed();
    let mean_ms = elapsed.as_secs_f64() * 1000.0 / ITERATIONS as f64;

    // Sanity: every timed response must carry the RESULTS= marker. Only the
    // marker is checked here. The numeric payload is compared against the
    // persistent path in `kernel_server_persistent_matches_oneshot_for_64x64`.
    for (i, response) in responses.iter().enumerate() {
        assert!(
            response.contains("RESULTS="),
            "one-shot response missing RESULTS= marker (iteration {i})"
        );
    }
    eprintln!(
        "[kernel_server_bench] one-shot 64x64: mean_per_call={:.3}ms total={:.3}ms ({} iters)",
        mean_ms,
        elapsed.as_secs_f64() * 1000.0,
        ITERATIONS
    );
}

/// Benchmark the persistent kernel server path: drive `ITERATIONS` calls
/// through `kernel_server::run_persistent` and report the mean per-call
/// latency.
///
/// No untimed warmup is performed. The first call presumably spawns the
/// long-lived child for the `"hip-gemm-f16-fwd-bench"` key, so that spawn
/// cost is amortised into the reported mean. Later calls reuse the child
/// and pay only the request/response round-trip plus kernel execution.
///
/// # Panics
///
/// Panics if the cached kernel executable is missing (see
/// `locate_gemm_kernel`), if any persistent call fails, or if any timed
/// response lacks the `RESULTS=` marker.
#[test]
fn kernel_server_bench_persistent_64x64() {
    let m = 64usize;
    let n = 64usize;
    let k = 64usize;
    let a = make_matrix(m, k, 0xC0FFEE_42);
    let b = make_matrix(k, n, 0xBADA55_99);
    let payload = build_payload_64x64(&a, &b);

    let executable = locate_gemm_kernel();

    const ITERATIONS: usize = 10;
    let start = Instant::now();
    // Keep every response so all of them can be validated after the timed
    // region. Previously only the last one was checked.
    let responses: Vec<String> = (0..ITERATIONS)
        .map(|_| {
            kernel_server::run_persistent("hip-gemm-f16-fwd-bench", &executable, &payload)
                .expect("persistent call")
        })
        .collect();
    let elapsed = start.elapsed();
    let mean_ms = elapsed.as_secs_f64() * 1000.0 / ITERATIONS as f64;

    for (i, response) in responses.iter().enumerate() {
        assert!(
            response.contains("RESULTS="),
            "persistent response missing RESULTS= marker (iteration {i})"
        );
    }
    eprintln!(
        "[kernel_server_bench] persistent 64x64: mean_per_call={:.3}ms total={:.3}ms ({} iters)",
        mean_ms,
        elapsed.as_secs_f64() * 1000.0,
        ITERATIONS
    );
}

/// Correctness gate for the persistent kernel server.
///
/// For one 64x64x64 fp16 GEMM payload, this test checks two things:
/// 1. The one-shot and persistent responses both carry a non-empty
///    `RESULTS=` line, and those lines are byte-identical. A future
///    protocol refactor that drops the marker therefore fails here, rather
///    than the test comparing two empty strings.
/// 2. The production `run_rocm_hip_gemm_f16`, documented below as routing
///    through the persistent server, reports `within_tolerance` and
///    returns exactly `m * n` outputs. Every output must differ from
///    `cpu_gemm_f16` by a finite amount below `1e-2`.
#[test]
fn kernel_server_persistent_matches_oneshot_for_64x64() {
    let m = 64usize;
    let n = 64usize;
    let k = 64usize;
    let a = make_matrix(m, k, 0xDEAD_BEEF);
    let b = make_matrix(k, n, 0xCAFE_F00D);
    let payload = build_payload_64x64(&a, &b);

    let executable = locate_gemm_kernel();

    let oneshot_response =
        kernel_server::KernelServer::oneshot(&executable, &payload).expect("oneshot call");
    let persistent_response =
        kernel_server::run_persistent("hip-gemm-f16-fwd-correctness", &executable, &payload)
            .expect("persistent call");

    // `extract_results_line` yields "" when the marker is absent, so the
    // equality check alone would pass if *both* responses lost it.
    let oneshot_results = extract_results_line(&oneshot_response);
    let persistent_results = extract_results_line(&persistent_response);
    assert!(
        !oneshot_results.is_empty(),
        "one-shot response has no (or an empty) RESULTS= line"
    );
    assert!(
        !persistent_results.is_empty(),
        "persistent response has no (or an empty) RESULTS= line"
    );
    assert_eq!(
        oneshot_results, persistent_results,
        "persistent and one-shot responses differ"
    );
    eprintln!(
        "[kernel_server_bench] correctness: persistent matches one-shot, results line has {} chars",
        persistent_results.len()
    );

    // The production `run_rocm_hip_gemm_f16` must also succeed (it now
    // routes through the persistent server). We check the
    // `within_tolerance` flag (max_abs_error < 1e-2) rather than exact
    // equality, because the kernel does the K-accumulation in fp32 and the
    // CPU oracle round-trips the result through fp16. The two will differ
    // by ~1ulp at typical scales, but always well under the 1e-2 tolerance
    // the report asserts.
    let report = run_rocm_hip_gemm_f16(&a, &b, m, n, k).expect("production gemm call");
    assert!(report.within_tolerance, "production gemm failed tolerance");
    let cpu = cpu_gemm_f16(&a, &b, m, n, k);

    // `zip` stops at the shorter side, so pin both lengths first. Otherwise
    // a truncated output would only be compared on its prefix.
    assert_eq!(report.outputs.len(), m * n, "production gemm output length");
    assert_eq!(cpu.len(), m * n, "CPU oracle output length");

    let mut prod_max_err = 0.0f32;
    for (idx, (g, c)) in report.outputs.iter().zip(cpu.iter()).enumerate() {
        let err = (g - c).abs();
        // `f32::max` returns the non-NaN operand, so a NaN error would be
        // silently dropped by this fold. Reject non-finite errors explicitly.
        assert!(
            err.is_finite(),
            "non-finite production gemm error at output[{idx}]: gpu={g} cpu={c}"
        );
        prod_max_err = prod_max_err.max(err);
    }
    assert!(
        prod_max_err < 1e-2,
        "production gemm output drifts from CPU oracle: max_abs_err={prod_max_err}"
    );
}

/// Return the text following `RESULTS=` on the first line of `response`
/// that starts with that marker, or `""` if no such line exists.
///
/// Lines are split with `str::lines`, so a trailing `\r` from CRLF endings
/// is not part of the returned slice. A missing marker and an empty
/// `RESULTS=` line both produce `""`.
fn extract_results_line(response: &str) -> &str {
    for line in response.lines() {
        if let Some(payload) = line.strip_prefix("RESULTS=") {
            return payload;
        }
    }
    ""
}