#![cfg(feature = "rocm-hip")]
use std::time::Instant;
use tokitai_operator::backend::hip_gemm_f16::{
cpu_gemm_f16, f32_to_f16, hip_gemm_f16_kernel_source_fingerprint, run_rocm_hip_gemm_f16,
};
use tokitai_operator::backend::kernel_server;
fn locate_gemm_kernel() -> std::path::PathBuf {
let cache_dir = std::path::PathBuf::from("target/rocm-hip-cache");
let fingerprint = hip_gemm_f16_kernel_source_fingerprint();
let executable = cache_dir.join(format!("{fingerprint}-gemm-fp16-f32"));
assert!(
executable.exists(),
"expected cached gemm kernel at {} (run run_rocm_hip_gemm_f16 first to populate)",
executable.display()
);
executable
}
/// Builds a deterministic `rows * cols` matrix of half-precision values stored as `u16`.
///
/// Values come from a xorshift32 stream derived from `seed`. Each value is scaled to roughly
/// `[-0.05, 0.05]` and then converted with `f32_to_f16` (presumably into IEEE f16 bit
/// patterns; confirm against the backend).
fn make_matrix(rows: usize, cols: usize, seed: u32) -> Vec<u16> {
    // xorshift32 has a fixed point at 0. `seed == u32::MAX` would wrap to 0 and yield a
    // constant matrix, so substitute an arbitrary non-zero state in that single case. Every
    // other seed produces exactly the same stream as before.
    let mut state: u32 = match seed.wrapping_add(1) {
        0 => 0x9E37_79B9,
        nonzero => nonzero,
    };
    let mut out = Vec::with_capacity(rows * cols);
    for _ in 0..rows * cols {
        state ^= state << 13;
        state ^= state >> 17;
        state ^= state << 5;
        let scaled = (state as f32 / u32::MAX as f32) * 0.1 - 0.05;
        out.push(f32_to_f16(scaled));
    }
    out
}
/// Serialises a 64x64x64 GEMM request into the kernel's text payload format.
///
/// Layout, one item per line, each terminated by `'\n'`:
/// 1. the header `"M N K"`,
/// 2. every element of `a`, space-separated,
/// 3. every element of `b`, space-separated.
///
/// # Panics
///
/// Panics if `a` does not hold `M * K` elements or `b` does not hold `K * N` elements.
fn build_payload_64x64(a: &[u16], b: &[u16]) -> String {
    const M: usize = 64;
    const N: usize = 64;
    const K: usize = 64;

    /// Appends `values` to `out` as a single space-separated line ending in `'\n'`.
    fn push_row(out: &mut String, values: &[u16]) {
        let mut separator = "";
        for value in values {
            out.push_str(separator);
            out.push_str(&value.to_string());
            separator = " ";
        }
        out.push('\n');
    }

    assert_eq!(a.len(), M * K);
    assert_eq!(b.len(), K * N);

    // Up to 5 digits plus a separator per value; 8 bytes each is a comfortable upper bound.
    let mut request = String::with_capacity((a.len() + b.len()) * 8);
    request.push_str(&format!("{M} {N} {K}\n"));
    push_row(&mut request, a);
    push_row(&mut request, b);
    request
}
/// Benchmarks `KernelServer::oneshot` on a 64x64x64 GEMM payload.
///
/// One untimed warm-up call comes first, then `ITERATIONS` timed calls. The mean and total
/// wall-clock time of the timed calls are printed to stderr. The last response must carry a
/// `RESULTS=` marker.
#[test]
fn kernel_server_bench_oneshot_64x64() {
    const ITERATIONS: usize = 10;
    let (m, n, k) = (64usize, 64usize, 64usize);
    let lhs = make_matrix(m, k, 0xC0FFEE_42);
    let rhs = make_matrix(k, n, 0xBADA55_99);
    let payload = build_payload_64x64(&lhs, &rhs);
    let executable = locate_gemm_kernel();

    // The warm-up result is discarded, and the call is deliberately kept outside the timed
    // window.
    let _ = kernel_server::KernelServer::oneshot(&executable, &payload).expect("oneshot warmup");

    let timer = Instant::now();
    let mut response = String::new();
    for _ in 0..ITERATIONS {
        response =
            kernel_server::KernelServer::oneshot(&executable, &payload).expect("oneshot call");
    }
    let elapsed = timer.elapsed();

    assert!(
        response.contains("RESULTS="),
        "one-shot response missing RESULTS= marker"
    );

    let total_ms = elapsed.as_secs_f64() * 1000.0;
    let mean_ms = total_ms / ITERATIONS as f64;
    eprintln!(
        "[kernel_server_bench] one-shot 64x64: mean_per_call={:.3}ms total={:.3}ms ({} iters)",
        mean_ms, total_ms, ITERATIONS
    );
}
/// Benchmarks `kernel_server::run_persistent` on a 64x64x64 GEMM payload.
///
/// Like the one-shot bench, one untimed warm-up call precedes the `ITERATIONS` timed calls.
/// That way the reported mean covers steady-state calls only and the two benches are
/// comparable. The last response must carry a `RESULTS=` marker.
#[test]
fn kernel_server_bench_persistent_64x64() {
    let m = 64usize;
    let n = 64usize;
    let k = 64usize;
    let a = make_matrix(m, k, 0xC0FFEE_42);
    let b = make_matrix(k, n, 0xBADA55_99);
    let payload = build_payload_64x64(&a, &b);
    let executable = locate_gemm_kernel();
    // Untimed warm-up, mirroring the one-shot bench. Presumably the first call for a given
    // server key also starts the persistent process. Without this call, that start-up cost
    // would be averaged into `mean_per_call` and skew the comparison with the one-shot numbers.
    let _ = kernel_server::run_persistent("hip-gemm-f16-fwd-bench", &executable, &payload)
        .expect("persistent warmup");
    const ITERATIONS: usize = 10;
    let start = Instant::now();
    let mut last_response = String::new();
    for _ in 0..ITERATIONS {
        last_response =
            kernel_server::run_persistent("hip-gemm-f16-fwd-bench", &executable, &payload)
                .expect("persistent call");
    }
    let elapsed = start.elapsed();
    let mean_ms = elapsed.as_secs_f64() * 1000.0 / ITERATIONS as f64;
    assert!(
        last_response.contains("RESULTS="),
        "persistent response missing RESULTS= marker"
    );
    eprintln!(
        "[kernel_server_bench] persistent 64x64: mean_per_call={:.3}ms total={:.3}ms ({} iters)",
        mean_ms,
        elapsed.as_secs_f64() * 1000.0,
        ITERATIONS
    );
}
/// Checks two things for a 64x64x64 GEMM:
/// 1. The persistent kernel server returns the same `RESULTS=` payload as a one-shot launch.
/// 2. The production `run_rocm_hip_gemm_f16` path agrees with the CPU oracle `cpu_gemm_f16`
///    to within an absolute error of `1e-2`.
#[test]
fn kernel_server_persistent_matches_oneshot_for_64x64() {
    /// True when some line of `response` starts with the `RESULTS=` marker.
    fn has_results_line(response: &str) -> bool {
        response.lines().any(|line| line.starts_with("RESULTS="))
    }

    let m = 64usize;
    let n = 64usize;
    let k = 64usize;
    let a = make_matrix(m, k, 0xDEAD_BEEF);
    let b = make_matrix(k, n, 0xCAFE_F00D);
    let payload = build_payload_64x64(&a, &b);
    let executable = locate_gemm_kernel();
    let oneshot_response =
        kernel_server::KernelServer::oneshot(&executable, &payload).expect("oneshot call");
    let persistent_response =
        kernel_server::run_persistent("hip-gemm-f16-fwd-correctness", &executable, &payload)
            .expect("persistent call");
    // `extract_results_line` yields "" when no marker line exists. Without these checks, two
    // marker-less responses would compare equal and the test would pass vacuously.
    assert!(
        has_results_line(&oneshot_response),
        "one-shot response missing RESULTS= line"
    );
    assert!(
        has_results_line(&persistent_response),
        "persistent response missing RESULTS= line"
    );
    let oneshot_results = extract_results_line(&oneshot_response);
    let persistent_results = extract_results_line(&persistent_response);
    assert_eq!(
        oneshot_results, persistent_results,
        "persistent and one-shot responses differ"
    );
    eprintln!(
        "[kernel_server_bench] correctness: persistent matches one-shot, results line has {} chars",
        persistent_results.len()
    );
    let report = run_rocm_hip_gemm_f16(&a, &b, m, n, k).expect("production gemm call");
    assert!(report.within_tolerance, "production gemm failed tolerance");
    let cpu = cpu_gemm_f16(&a, &b, m, n, k);
    // `zip` stops at the shorter side, so a truncated output would otherwise go unnoticed.
    assert_eq!(
        report.outputs.len(),
        cpu.len(),
        "production gemm output length differs from CPU oracle"
    );
    // `f32::max` discards NaN operands. This fold instead lets a NaN error win and stick, so
    // NaN or inf-minus-inf outputs fail the bound below rather than reporting 0.
    let prod_max_err = report
        .outputs
        .iter()
        .zip(cpu.iter())
        .map(|(g, c)| (g - c).abs())
        .fold(0.0f32, |worst, err| {
            if err.is_nan() || err > worst {
                err
            } else {
                worst
            }
        });
    assert!(
        prod_max_err < 1e-2,
        "production gemm output drifts from CPU oracle: max_abs_err={prod_max_err}"
    );
}
/// Returns the text after `RESULTS=` on the first line of `response` that starts with that
/// marker.
///
/// Returns an empty string when no line starts with the marker. Lines with leading whitespace
/// before the marker do not match.
fn extract_results_line(response: &str) -> &str {
    for line in response.lines() {
        if let Some(payload) = line.strip_prefix("RESULTS=") {
            return payload;
        }
    }
    ""
}