#![cfg(feature = "rocm-hip")]
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::time::Instant;
use tokitai_operator::backend::hip_embedding::{
cpu_embedding_bwd, cpu_embedding_fwd, f16_to_f32 as emb_f16_to_f32,
f32_to_f16 as emb_f32_to_f16, run_rocm_hip_embedding_bwd_with_report,
run_rocm_hip_embedding_fwd_with_report,
};
use tokitai_operator::backend::hip_gelu::{cpu_gelu_fwd_fp16, run_rocm_hip_gelu_fwd_reported};
use tokitai_operator::backend::hip_gelu_bw::run_rocm_hip_gelu_bwd_reported;
use tokitai_operator::backend::hip_gemm_bw::{
cpu_gemm_bw_grad_a, cpu_gemm_bw_grad_b, run_rocm_hip_gemm_bw_grad_a,
run_rocm_hip_gemm_bw_grad_b,
};
use tokitai_operator::backend::hip_gemm_f16::{cpu_gemm_f16, f32_to_f16, run_rocm_hip_gemm_f16};
use tokitai_operator::backend::hip_layernorm::{
run_rocm_hip_layernorm_bwd, run_rocm_hip_layernorm_fwd,
};
use tokitai_operator::backend::hip_sheaf_overlap_check::run_rocm_hip_sheaf_overlap_check;
use tokitai_operator::backend::hip_softmax::{
run_rocm_hip_grad_loss_wrt_logits, run_rocm_hip_softmax_fwd,
};
use tokitai_operator::backend::rocm::{RocmHipCapabilityReport, detect_local_rocm_hip};
use tokitai_operator::object::sheaf::PrecisionClass;
// Pulls the text of `../backend/hip_adamw.rs` into this file as a private
// module via `include!`. `run_adamw_step` calls
// `_adamw_mod::run_rocm_hip_adamw_step_oracle` from it.
// NOTE(review): presumably included textually because the library crate does
// not export this module — confirm.
mod _adamw_mod {
include!("../backend/hip_adamw.rs");
}
// Pulls the text of `../backend/hip_padic_codec.rs` into this file as a
// private module via `include!`. `run_padic_encode` / `run_padic_decode` use
// its `cpu_padic_*_f16` and `run_rocm_hip_padic_*_f16` functions.
// NOTE(review): presumably included textually because the library crate does
// not export this module — confirm.
mod _padic_mod {
include!("../backend/hip_padic_codec.rs");
}
// GFX architecture this benchmark expects; `main` prints a warning when the
// detected device reports a different one, and the report quotes it.
const GFX_TARGET: &str = "gfx1101";
// Buffer size in bytes (64 MiB) and iteration count handed to the memcpy
// bandwidth helper executable on its command line.
const COPY_BYTES: usize = 64 * 1024 * 1024; const COPY_ITERS: usize = 10;
/// Outcome of one kernel benchmark: timings, accuracy, and an optional error.
#[derive(Debug, Clone)]
struct KernelResult {
    /// Kernel identifier; used to look a result up by name.
    name: &'static str,
    /// Suite the kernel is grouped under in the report.
    suite: &'static str,
    /// Human-readable problem size shown in the report table.
    problem: String,
    /// Kernel time in milliseconds as reported by the GPU runner.
    gpu_ms: f32,
    /// Wall time in milliseconds of the CPU comparison loop.
    cpu_ms: f64,
    /// `cpu_ms / gpu_ms` (0.0 when no comparison was made).
    speedup: f64,
    /// Largest absolute error reported by the runner (`NaN` for failures).
    max_abs_error: f32,
    /// Whether the runner judged the result to be within tolerance.
    within_tolerance: bool,
    /// Error description; when set, the row is rendered as an ERROR row.
    error: Option<String>,
}

impl KernelResult {
    /// Builds the result for a kernel that could not be run.
    ///
    /// Timings are zeroed, the error metric is `NaN`, and `err` is stored so
    /// that [`KernelResult::to_row`] renders an ERROR row.
    fn failed(
        name: &'static str,
        suite: &'static str,
        problem: impl Into<String>,
        err: String,
    ) -> Self {
        Self {
            name,
            suite,
            problem: problem.into(),
            gpu_ms: 0.0,
            cpu_ms: 0.0,
            speedup: 0.0,
            max_abs_error: f32::NAN,
            within_tolerance: false,
            error: Some(err),
        }
    }

    /// Renders this result as a single Markdown table row (seven cells).
    ///
    /// Error messages are flattened to one line and their `|` characters are
    /// escaped: a raw multi-line message (typical of compiler or HIP runtime
    /// stderr) would otherwise terminate the row early and corrupt the table.
    fn to_row(&self) -> String {
        if let Some(err) = &self.error {
            let cell = err
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .collect::<Vec<_>>()
                .join(" ")
                .replace('|', "\\|");
            return format!(
                "| {} | `{}` | ERROR | — | — | — | {} |",
                self.suite, self.problem, cell
            );
        }
        format!(
            "| {} | `{}` | {:.4} | {:.3} | {:.1}x | {:.5} | {} |",
            self.suite,
            self.problem,
            self.gpu_ms,
            self.cpu_ms,
            self.speedup,
            self.max_abs_error,
            if self.within_tolerance { "OK" } else { "FAIL" }
        )
    }
}
/// Generates `n` deterministic pseudo-random values as 16-bit patterns.
///
/// A xorshift32 generator seeded with `seed + 1` (wrapping) produces each
/// value, which is mapped linearly from the `u32` range onto `[-3.0, 3.0]`
/// and converted with `f32_to_f16`.
fn make_fp16_vector(n: usize, seed: u32) -> Vec<u16> {
    let mut state: u32 = seed.wrapping_add(1);
    if state == 0 {
        // Zero is a fixed point of xorshift (only reachable for
        // `seed == u32::MAX`); substitute a non-zero state so the output is
        // not one constant repeated `n` times.
        state = 0x9E37_79B9;
    }
    (0..n)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 17;
            state ^= state << 5;
            let scaled = (state as f32 / u32::MAX as f32) * 6.0 - 3.0;
            f32_to_f16(scaled)
        })
        .collect()
}
/// Generates a `rows * cols` buffer of deterministic pseudo-random values as
/// 16-bit patterns.
///
/// A xorshift32 generator seeded with `seed + 1` (wrapping) produces each
/// value, which is mapped linearly from the `u32` range onto `[-0.05, 0.05]`
/// and converted with `f32_to_f16`.
fn make_gemm_matrix(rows: usize, cols: usize, seed: u32) -> Vec<u16> {
    let mut state: u32 = seed.wrapping_add(1);
    if state == 0 {
        // Zero is a fixed point of xorshift (only reachable for
        // `seed == u32::MAX`); substitute a non-zero state so the matrix is
        // not filled with one constant.
        state = 0x9E37_79B9;
    }
    (0..rows * cols)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 17;
            state ^= state << 5;
            let scaled = (state as f32 / u32::MAX as f32) * 0.1 - 0.05;
            f32_to_f16(scaled)
        })
        .collect()
}
/// Runs `f` exactly once and returns the elapsed wall time in milliseconds.
fn time_cpu<F: FnMut()>(mut f: F) -> f64 {
    let clock = Instant::now();
    f();
    let elapsed = clock.elapsed();
    elapsed.as_secs_f64() * 1000.0
}
/// Writes the memcpy bandwidth helper's C++ source into the build cache and
/// compiles it with `hipcc` unless a matching executable already exists.
///
/// The executable name embeds a hash of the source text, so editing `SRC`
/// yields a new file name and therefore a fresh compile. Returns the path of
/// the executable.
///
/// # Panics
/// Panics if the cache directory or source file cannot be written, if `hipcc`
/// cannot be started, or if the compile exits unsuccessfully.
fn write_memcpy_kernel() -> PathBuf {
    // NOTE: the text below is hashed to name the cached executable; keep it
    // byte-for-byte stable unless a rebuild is intended.
    const SRC: &str = r#"
#include <hip/hip_runtime.h>
#include <chrono>
#include <cstdio>
#include <cstdint>
#include <cstdlib>
#include <vector>
#define CHECK(call) do { hipError_t e = (call); if (e != hipSuccess) { \
fprintf(stderr, "HIP error %s at %s:%d\n", hipGetErrorString(e), __FILE__, __LINE__); std::exit(1); } } while (0)
int main(int argc, char** argv) {
if (argc < 3) { std::fprintf(stderr, "usage: %s <bytes> <iters>\n", argv[0]); return 1; }
const size_t bytes = std::strtoull(argv[1], nullptr, 10);
const int iters = std::atoi(argv[2]);
if (bytes == 0 || iters <= 0) { std::fprintf(stderr, "bad args\n"); return 1; }
std::vector<uint8_t> host_src(bytes);
for (size_t i = 0; i < bytes; i++) host_src[i] = (uint8_t)(i & 0xFF);
uint8_t* d_a = nullptr;
uint8_t* d_b = nullptr;
CHECK(hipMalloc(&d_a, bytes));
CHECK(hipMalloc(&d_b, bytes));
// Warm-up copies (excluded from timing)
CHECK(hipMemcpy(d_a, host_src.data(), bytes, hipMemcpyHostToDevice));
CHECK(hipMemcpy(d_b, d_a, bytes, hipMemcpyDeviceToDevice));
CHECK(hipDeviceSynchronize());
// 1) Host -> Device
auto h2d_start = std::chrono::steady_clock::now();
for (int i = 0; i < iters; i++) {
CHECK(hipMemcpy(d_a, host_src.data(), bytes, hipMemcpyHostToDevice));
}
CHECK(hipDeviceSynchronize());
auto h2d_end = std::chrono::steady_clock::now();
double h2d_ms = std::chrono::duration<double, std::milli>(h2d_end - h2d_start).count();
// 2) Device -> Host
std::vector<uint8_t> host_dst(bytes);
auto d2h_start = std::chrono::steady_clock::now();
for (int i = 0; i < iters; i++) {
CHECK(hipMemcpy(host_dst.data(), d_a, bytes, hipMemcpyDeviceToHost));
}
auto d2h_end = std::chrono::steady_clock::now();
double d2h_ms = std::chrono::duration<double, std::milli>(d2h_end - d2h_start).count();
// 3) Device -> Device (true VRAM bandwidth)
auto d2d_start = std::chrono::steady_clock::now();
for (int i = 0; i < iters; i++) {
CHECK(hipMemcpy(d_b, d_a, bytes, hipMemcpyDeviceToDevice));
}
CHECK(hipDeviceSynchronize());
auto d2d_end = std::chrono::steady_clock::now();
double d2d_ms = std::chrono::duration<double, std::milli>(d2d_end - d2d_start).count();
// D2D is a single read + single write per copy, so effective bandwidth is
// 2 * bytes * iters / time (read+write traffic).
double h2d_gbps = (double)bytes * (double)iters / (h2d_ms * 1.0e6);
double d2h_gbps = (double)bytes * (double)iters / (d2h_ms * 1.0e6);
double d2d_gbps = 2.0 * (double)bytes * (double)iters / (d2d_ms * 1.0e6);
CHECK(hipFree(d_a));
CHECK(hipFree(d_b));
std::printf("H2D_MS=%.6f\nH2D_GBPS=%.3f\nD2H_MS=%.6f\nD2H_GBPS=%.3f\nD2D_MS=%.6f\nD2D_GBPS=%.3f\n",
h2d_ms, h2d_gbps, d2h_ms, d2h_gbps, d2d_ms, d2d_gbps);
return 0;
}
"#;

    let cache_dir = PathBuf::from("target/rocm-hip-cache");
    fs::create_dir_all(&cache_dir).expect("create cache dir");

    // Fingerprint of the source text; part of the executable's file name.
    let fingerprint = {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        std::hash::Hash::hash(&SRC, &mut hasher);
        std::hash::Hasher::finish(&hasher)
    };
    let src_path = cache_dir.join("gpu_smoke_memcpy.cpp");
    let exe_path = cache_dir.join(format!("gpu_smoke_memcpy_{fingerprint:016x}"));

    // The source file is rewritten on every call, even on a cache hit.
    fs::write(&src_path, SRC).expect("write memcpy src");
    if exe_path.exists() {
        return exe_path;
    }

    let status = Command::new("/opt/rocm/bin/hipcc")
        .args(["-O2", "-std=c++17"])
        .arg(&src_path)
        .arg("-o")
        .arg(&exe_path)
        .status()
        .expect("invoke hipcc for memcpy");
    assert!(status.success(), "hipcc memcpy compile failed");
    exe_path
}
/// Results of the memcpy bandwidth micro-benchmark, one time/bandwidth pair
/// per copy direction. The `*_ms_avg` fields are the helper's total time for
/// a direction divided by `COPY_ITERS`; the `*_gbps` fields are taken
/// verbatim from the helper's output.
struct MemcpyReport {
/// Average host-to-device copy time in milliseconds.
h2d_ms_avg: f64,
/// Host-to-device bandwidth as printed by the helper.
h2d_gbps: f64,
/// Average device-to-host copy time in milliseconds.
d2h_ms_avg: f64,
/// Device-to-host bandwidth as printed by the helper.
d2h_gbps: f64,
/// Average device-to-device copy time in milliseconds.
d2d_ms_avg: f64,
/// Device-to-device bandwidth as printed by the helper (the helper counts
/// both the read and the write traffic of each copy).
d2d_gbps: f64,
}
/// Builds (if needed) and runs the memcpy helper executable, then parses the
/// `KEY=value` lines it prints into a [`MemcpyReport`].
///
/// # Panics
/// Panics if the helper cannot be started, exits unsuccessfully, or omits any
/// of the six expected markers.
fn run_memcpy_benchmark() -> MemcpyReport {
    let exe = write_memcpy_kernel();

    let mut cmd = Command::new(&exe);
    cmd.arg(COPY_BYTES.to_string())
        .arg(COPY_ITERS.to_string())
        .stdin(Stdio::null());
    let out = cmd.output().expect("run memcpy kernel");
    assert!(
        out.status.success(),
        "memcpy kernel failed: stderr={}",
        String::from_utf8_lossy(&out.stderr)
    );

    let stdout = String::from_utf8_lossy(&out.stdout);
    // Looks a marker up and panics with the marker's bare name when missing.
    let field = |marker: &str, what: &str| parse_marker(&stdout, marker).expect(what);

    let h2d_total_ms = field("H2D_MS=", "H2D_MS");
    let h2d_gbps = field("H2D_GBPS=", "H2D_GBPS");
    let d2h_total_ms = field("D2H_MS=", "D2H_MS");
    let d2h_gbps = field("D2H_GBPS=", "D2H_GBPS");
    let d2d_total_ms = field("D2D_MS=", "D2D_MS");
    let d2d_gbps = field("D2D_GBPS=", "D2D_GBPS");

    // The helper reports totals over all iterations; convert to per-copy.
    MemcpyReport {
        h2d_ms_avg: h2d_total_ms / COPY_ITERS as f64,
        h2d_gbps,
        d2h_ms_avg: d2h_total_ms / COPY_ITERS as f64,
        d2h_gbps,
        d2d_ms_avg: d2d_total_ms / COPY_ITERS as f64,
        d2d_gbps,
    }
}
/// Finds the first line of `stdout` that starts with `marker` and parses the
/// remainder of that line (trimmed) as an `f64`.
///
/// Returns `None` when no line starts with `marker`, or when the first such
/// line does not hold a valid number (later lines are not consulted).
fn parse_marker(stdout: &str, marker: &str) -> Option<f64> {
    for line in stdout.lines() {
        if let Some(value) = line.strip_prefix(marker) {
            return value.trim().parse::<f64>().ok();
        }
    }
    None
}
/// Benchmarks the GELU forward kernel on 1,000,000 pseudo-random fp16 values.
///
/// The CPU baseline times one call to `cpu_gelu_fwd_fp16`; GPU time, error and
/// tolerance verdict are taken from the runner's report.
fn run_gelu_fwd() -> KernelResult {
    const ELEMS: usize = 1_000_000;
    let input = make_fp16_vector(ELEMS, 0xC0FFEE_42);

    let cpu_ms = time_cpu(|| {
        let _ = cpu_gelu_fwd_fp16(&input, ELEMS);
    });

    let report = match run_rocm_hip_gelu_fwd_reported(&input, ELEMS) {
        Ok(report) => report,
        Err(e) => return KernelResult::failed("gelu_fwd", "gelu", "1M fp16", e.to_string()),
    };
    KernelResult {
        name: "gelu_fwd",
        suite: "gelu",
        problem: "1M fp16".to_string(),
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        // Clamp the divisor so a zero kernel time cannot divide by zero.
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the GELU backward kernel on 1,000,000 pseudo-random fp16 values.
///
/// The CPU side is only a stand-in workload — a multiply-accumulate of every
/// gradient with `input[0]` — not a GELU derivative, so the reported speedup
/// is not a like-for-like comparison.
fn run_gelu_bwd() -> KernelResult {
    const ELEMS: usize = 1_000_000;
    let grad_output = make_fp16_vector(ELEMS, 0xDEAD_BEEF);
    let input = make_fp16_vector(ELEMS, 0xC0FFEE_42);

    let cpu_ms = time_cpu(|| {
        let acc = grad_output.iter().fold(0.0f32, |acc, &g| {
            acc + emb_f16_to_f32(g) * emb_f16_to_f32(input[0])
        });
        std::hint::black_box(acc);
    });

    let report = match run_rocm_hip_gelu_bwd_reported(&grad_output, &input, ELEMS) {
        Ok(report) => report,
        Err(e) => return KernelResult::failed("gelu_bwd", "gelu", "1M fp16", e.to_string()),
    };
    KernelResult {
        name: "gelu_bwd",
        suite: "gelu",
        problem: "1M fp16".to_string(),
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the fp16 GEMM forward kernel on a 1024x1024x1024 problem,
/// timing `cpu_gemm_f16` on the same inputs as the CPU baseline.
fn run_gemm_fwd() -> KernelResult {
    let (m, n, k) = (1024usize, 1024, 1024);
    let problem = format!("{m}x{n}x{k} fp16");
    let a = make_gemm_matrix(m, k, 0x1111_1111);
    let b = make_gemm_matrix(k, n, 0x2222_2222);

    let cpu_ms = time_cpu(|| {
        let _ = cpu_gemm_f16(&a, &b, m, n, k);
    });

    let report = match run_rocm_hip_gemm_f16(&a, &b, m, n, k) {
        Ok(report) => report,
        Err(e) => return KernelResult::failed("gemm_fwd", "gemm_fwd", problem, e.to_string()),
    };
    KernelResult {
        name: "gemm_fwd",
        suite: "gemm_fwd",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the GEMM backward kernel that produces the gradient w.r.t. `A`
/// on a 1024x1024x1024 problem, timing `cpu_gemm_bw_grad_a` as the baseline.
fn run_gemm_bw_grad_a() -> KernelResult {
    let (m, n, k) = (1024usize, 1024, 1024);
    let problem = format!("{m}x{n}x{k} fp16");
    let grad_c = make_gemm_matrix(m, n, 0xAAAA_1111);
    let b = make_gemm_matrix(k, n, 0xBBBB_2222);

    let cpu_ms = time_cpu(|| {
        let _ = cpu_gemm_bw_grad_a(&grad_c, &b, m, n, k);
    });

    let report = match run_rocm_hip_gemm_bw_grad_a(&grad_c, &b, m, n, k) {
        Ok(report) => report,
        Err(e) => {
            return KernelResult::failed("gemm_bw_grad_a", "gemm_bw", problem, e.to_string())
        }
    };
    KernelResult {
        name: "gemm_bw_grad_a",
        suite: "gemm_bw",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the GEMM backward kernel that produces the gradient w.r.t. `B`
/// on a 1024x1024x1024 problem, timing `cpu_gemm_bw_grad_b` as the baseline.
fn run_gemm_bw_grad_b() -> KernelResult {
    let (m, n, k) = (1024usize, 1024, 1024);
    let problem = format!("{m}x{n}x{k} fp16");
    let grad_c = make_gemm_matrix(m, n, 0xCCCC_3333);
    let a = make_gemm_matrix(m, k, 0xDDDD_4444);

    let cpu_ms = time_cpu(|| {
        let _ = cpu_gemm_bw_grad_b(&grad_c, &a, m, n, k);
    });

    let report = match run_rocm_hip_gemm_bw_grad_b(&grad_c, &a, m, n, k) {
        Ok(report) => report,
        Err(e) => {
            return KernelResult::failed("gemm_bw_grad_b", "gemm_bw", problem, e.to_string())
        }
    };
    KernelResult {
        name: "gemm_bw_grad_b",
        suite: "gemm_bw",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the row-wise softmax forward kernel on a 1024x1024 fp16 input.
///
/// The CPU baseline computes, per row, the maximum and the sum of
/// `exp(x - max)`; it stops short of normalising and writing outputs.
fn run_softmax_fwd() -> KernelResult {
    let n_rows = 1024usize;
    let n_cols = 1024;
    let problem = format!("{n_rows}x{n_cols} fp16");
    let input = make_fp16_vector(n_rows * n_cols, 0x5050_5050);

    let cpu_ms = time_cpu(|| {
        for row in input.chunks_exact(n_cols) {
            let row_max = row
                .iter()
                .fold(f32::MIN, |best, &bits| best.max(emb_f16_to_f32(bits)));
            let denom = row.iter().fold(0.0f32, |acc, &bits| {
                acc + (emb_f16_to_f32(bits) - row_max).exp()
            });
            std::hint::black_box(denom);
        }
    });

    let report = match run_rocm_hip_softmax_fwd(&input, n_rows, n_cols) {
        Ok(report) => report,
        Err(e) => return KernelResult::failed("softmax_fwd", "softmax", problem, e.to_string()),
    };
    KernelResult {
        name: "softmax_fwd",
        suite: "softmax",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Host-side row-wise softmax over an `n_rows` x `n_cols` row-major buffer of
/// 16-bit values.
///
/// Each row is converted with `emb_f16_to_f32`, shifted by its maximum before
/// exponentiation, normalised by the row sum, and converted back with
/// `emb_f32_to_f16`.
///
/// # Panics
/// Panics if `input` holds fewer than `n_rows * n_cols` elements.
fn softmax_fwd_host(input: &[u16], n_rows: usize, n_cols: usize) -> Vec<u16> {
    let mut output = Vec::with_capacity(n_rows * n_cols);
    for r in 0..n_rows {
        let row = &input[r * n_cols..(r + 1) * n_cols];
        let row_max = row
            .iter()
            .fold(f32::MIN, |best, &bits| best.max(emb_f16_to_f32(bits)));
        let denom = row.iter().fold(0.0f32, |acc, &bits| {
            acc + (emb_f16_to_f32(bits) - row_max).exp()
        });
        let inv = 1.0f32 / denom;
        output.extend(
            row.iter()
                .map(|&bits| emb_f32_to_f16((emb_f16_to_f32(bits) - row_max).exp() * inv)),
        );
    }
    output
}
/// Benchmarks the `grad_loss_wrt_logits` kernel on a 1024x1024 problem.
///
/// Inputs are the host softmax of pseudo-random logits and an upstream
/// gradient that is 1.0 in column `n_cols / 2` of every row and zero
/// elsewhere. The CPU side is a stand-in workload (a plain sum over the
/// logits), not the gradient computation itself.
fn run_grad_loss() -> KernelResult {
    let n_rows = 1024usize;
    let n_cols = 1024;
    let problem = format!("{n_rows}x{n_cols} fp16");
    let logits = make_fp16_vector(n_rows * n_cols, 0x6060_6060);
    let softmax_output = softmax_fwd_host(&logits, n_rows, n_cols);

    let target = n_cols / 2;
    let mut grad_output = vec![0u16; n_rows * n_cols];
    for row in grad_output.chunks_exact_mut(n_cols) {
        row[target] = emb_f32_to_f16(1.0);
    }

    let cpu_ms = time_cpu(|| {
        let total = logits
            .iter()
            .fold(0.0f32, |acc, &bits| acc + emb_f16_to_f32(bits));
        std::hint::black_box(total);
    });

    let report =
        match run_rocm_hip_grad_loss_wrt_logits(&grad_output, &softmax_output, n_rows, n_cols) {
            Ok(report) => report,
            Err(e) => {
                return KernelResult::failed(
                    "grad_loss_wrt_logits",
                    "softmax",
                    problem,
                    e.to_string(),
                )
            }
        };
    KernelResult {
        name: "grad_loss_wrt_logits",
        suite: "softmax",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the LayerNorm forward kernel on a 128x768 fp16 input with
/// pseudo-random gamma/beta and epsilon 1e-5.
///
/// The CPU baseline computes only the per-row mean and
/// `sqrt(variance + 1e-5)`; it does not produce normalised outputs.
fn run_layernorm_fwd() -> KernelResult {
    let n_rows = 128usize;
    let n_cols = 768;
    let problem = format!("{n_rows}x{n_cols} fp16");
    let input = make_fp16_vector(n_rows * n_cols, 0x7070_7070);
    let gamma = make_fp16_vector(n_cols, 0x1234_5678);
    let beta = make_fp16_vector(n_cols, 0x8765_4321);

    let cpu_ms = time_cpu(|| {
        for row in input.chunks_exact(n_cols) {
            let sum = row
                .iter()
                .fold(0.0f32, |acc, &bits| acc + emb_f16_to_f32(bits));
            let mean = sum / n_cols as f32;
            let sq_dev = row.iter().fold(0.0f32, |acc, &bits| {
                let d = emb_f16_to_f32(bits) - mean;
                acc + d * d
            });
            let std_dev = (sq_dev / n_cols as f32 + 1e-5).sqrt();
            std::hint::black_box(std_dev);
        }
    });

    let report = match run_rocm_hip_layernorm_fwd(&input, &gamma, &beta, n_rows, n_cols, 1e-5) {
        Ok(report) => report,
        Err(e) => {
            return KernelResult::failed("layernorm_fwd", "layernorm", problem, e.to_string())
        }
    };
    KernelResult {
        name: "layernorm_fwd",
        suite: "layernorm",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error_output,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the LayerNorm backward kernel on a 128x768 problem.
///
/// Per-row mean and reciprocal standard deviation are computed on the host
/// (epsilon 1e-5) and passed to the kernel. The CPU side is a stand-in
/// workload that only converts every input element to `f32`. The reported
/// error is the largest of the three gradient errors, and the individual
/// errors are logged to stderr.
fn run_layernorm_bwd() -> KernelResult {
    let n_rows = 128usize;
    let n_cols = 768;
    let problem = format!("{n_rows}x{n_cols} fp16");
    let grad_output = make_fp16_vector(n_rows * n_cols, 0x8080_8080);
    let input = make_fp16_vector(n_rows * n_cols, 0x8181_8181);
    let gamma = make_fp16_vector(n_cols, 0x1234_5678);

    // Row statistics the backward kernel takes as inputs.
    let (mean, rstd): (Vec<f32>, Vec<f32>) = (0..n_rows)
        .map(|r| {
            let row = &input[r * n_cols..(r + 1) * n_cols];
            let sum = row
                .iter()
                .fold(0.0f32, |acc, &bits| acc + emb_f16_to_f32(bits));
            let row_mean = sum / n_cols as f32;
            let sq_dev = row.iter().fold(0.0f32, |acc, &bits| {
                let d = emb_f16_to_f32(bits) - row_mean;
                acc + d * d
            });
            let std_dev = (sq_dev / n_cols as f32 + 1e-5).sqrt();
            (row_mean, 1.0 / std_dev)
        })
        .unzip();

    let cpu_ms = time_cpu(|| {
        for row in input.chunks_exact(n_cols) {
            for &bits in row {
                std::hint::black_box(emb_f16_to_f32(bits));
            }
        }
    });

    let report = match run_rocm_hip_layernorm_bwd(
        &grad_output,
        &input,
        &gamma,
        &mean,
        &rstd,
        n_rows,
        n_cols,
    ) {
        Ok(report) => report,
        Err(e) => {
            return KernelResult::failed("layernorm_bwd", "layernorm", problem, e.to_string())
        }
    };

    let gi = report.max_abs_error_grad_input;
    let gg = report.max_abs_error_grad_gamma;
    let gb = report.max_abs_error_grad_beta;
    // Pick the gradient with the largest error; incomparable values (NaN)
    // are treated as equal.
    let (bottleneck, worst) = [("grad_input", gi), ("grad_gamma", gg), ("grad_beta", gb)]
        .into_iter()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap_or(("grad_input", gi));
    eprintln!(
        " [layernorm_bwd] grad_input={:.6} grad_gamma={:.6} grad_beta={:.6} [bottleneck={}@{:.6}]",
        gi, gg, gb, bottleneck, worst,
    );

    KernelResult {
        name: "layernorm_bwd",
        suite: "layernorm",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: worst,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks the embedding lookup kernel: 64 queries into a 256x64 table,
/// timing `cpu_embedding_fwd` as the CPU baseline.
///
/// The runner only reports whether the result is bit-exact, so the error
/// column is 0.0 when exact and 1.0 otherwise. Success and failure rows carry
/// the same problem label so they line up in the report.
fn run_embedding_fwd() -> KernelResult {
    let n_queries = 64usize;
    let embedding_dim = 64;
    let vocab_size = 256;
    let problem = format!("q={n_queries} dim={embedding_dim} vocab={vocab_size}");
    let weight = make_gemm_matrix(vocab_size, embedding_dim, 0x9191_9191);
    let indices: Vec<i32> = (0..n_queries).map(|i| (i % vocab_size) as i32).collect();

    let cpu_ms = time_cpu(|| {
        let _ = cpu_embedding_fwd(&indices, &weight, n_queries, embedding_dim, vocab_size);
    });

    let report = match run_rocm_hip_embedding_fwd_with_report(
        &indices,
        &weight,
        n_queries,
        embedding_dim,
        vocab_size,
    ) {
        Ok(report) => report,
        Err(e) => {
            return KernelResult::failed("embedding_fwd", "embedding", problem, e.to_string())
        }
    };
    KernelResult {
        name: "embedding_fwd",
        suite: "embedding",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: if report.forward_exact { 0.0 } else { 1.0 },
        within_tolerance: report.forward_exact,
        error: None,
    }
}
/// Benchmarks the embedding backward kernel: 64 query gradients of width 64
/// against a vocabulary of 256, timing `cpu_embedding_bwd` as the CPU
/// baseline.
///
/// Success and failure rows carry the same problem label so they line up in
/// the report.
fn run_embedding_bwd() -> KernelResult {
    let n_queries = 64usize;
    let embedding_dim = 64;
    let vocab_size = 256;
    let problem = format!("q={n_queries} dim={embedding_dim} vocab={vocab_size}");
    let indices: Vec<i32> = (0..n_queries).map(|i| (i % vocab_size) as i32).collect();
    let grad_output = make_gemm_matrix(n_queries, embedding_dim, 0x9292_9292);

    let cpu_ms = time_cpu(|| {
        let _ = cpu_embedding_bwd(&grad_output, &indices, n_queries, embedding_dim, vocab_size);
    });

    let report = match run_rocm_hip_embedding_bwd_with_report(
        &grad_output,
        &indices,
        embedding_dim,
        vocab_size,
    ) {
        Ok(report) => report,
        Err(e) => {
            return KernelResult::failed("embedding_bwd", "embedding", problem, e.to_string())
        }
    };
    KernelResult {
        name: "embedding_bwd",
        suite: "embedding",
        problem,
        gpu_ms: report.kernel_time_ms,
        cpu_ms,
        speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
        max_abs_error: report.max_abs_error,
        within_tolerance: report.within_tolerance,
        error: None,
    }
}
/// Benchmarks one AdamW step (t = 1) over 1024 parameters.
///
/// The CPU baseline runs the update on scratch copies of the moment buffers.
/// It must not touch `m`/`v` themselves: they are handed to the GPU oracle
/// afterwards as the zero-initialised starting state, and mutating them first
/// would make the "t=1" GPU step start from already-updated moments.
fn run_adamw_step() -> KernelResult {
    let n = 1024usize;
    let mut theta: Vec<u16> = make_gemm_matrix(n, 1, 0xA0A0_A0A0);
    let mut m: Vec<f32> = vec![0.0; n];
    let mut v: Vec<f32> = vec![0.0; n];
    let grad: Vec<u16> = make_gemm_matrix(n, 1, 0xA1A1_A1A1);

    // Scratch state for the CPU timing loop only.
    let mut cpu_m = m.clone();
    let mut cpu_v = v.clone();
    let cpu_ms = time_cpu(|| {
        let lr = 1e-3f32;
        let beta1 = 0.9f32;
        let beta2 = 0.999f32;
        let eps = 1e-8f32;
        let wd = 0.01f32;
        let t = 1i32;
        for i in 0..n {
            let t_f = emb_f16_to_f32(theta[i]);
            let g = emb_f16_to_f32(grad[i]);
            cpu_m[i] = beta1 * cpu_m[i] + (1.0 - beta1) * g;
            cpu_v[i] = beta2 * cpu_v[i] + (1.0 - beta2) * g * g;
            let m_hat = cpu_m[i] / (1.0 - beta1.powi(t));
            let v_hat = cpu_v[i] / (1.0 - beta2.powi(t));
            let update = m_hat / (v_hat.sqrt() + eps) + wd * t_f;
            let new_theta = t_f - lr * update;
            let new_bits = emb_f32_to_f16(new_theta);
            std::hint::black_box(new_bits);
        }
        // Keep the moment stores observable so they are part of the timing.
        std::hint::black_box((&cpu_m, &cpu_v));
    });

    match _adamw_mod::run_rocm_hip_adamw_step_oracle(
        &mut theta, &mut m, &mut v, &grad, 1e-3, 0.9, 0.999, 1e-8, 0.01, 1, 1e-2,
    ) {
        Ok(report) => KernelResult {
            name: "adamw_step",
            suite: "adamw",
            problem: format!("n={n} t=1"),
            gpu_ms: report.kernel_time_ms,
            cpu_ms,
            speedup: cpu_ms / report.kernel_time_ms.max(1e-6) as f64,
            max_abs_error: report.max_abs_error_theta,
            within_tolerance: report.within_tolerance,
            error: None,
        },
        Err(e) => KernelResult::failed("adamw_step", "adamw", format!("n={n} t=1"), e.to_string()),
    }
}
/// Runs the p-adic encode kernel on 1024 pseudo-random fp16 values at
/// precision 8 and checks it against the CPU oracle.
///
/// No GPU time is recorded for this kernel (GPU ms and speedup stay 0.0); a
/// mismatch with the oracle is reported as an error.
fn run_padic_encode() -> KernelResult {
    let n = 1024usize;
    let problem = format!("n={n} prec=8");
    let values = make_fp16_vector(n, 0xB0B0_B0B0);

    let cpu_ms = time_cpu(|| {
        let _ = _padic_mod::cpu_padic_encode_f16(&values, 8);
    });

    let report = match _padic_mod::run_rocm_hip_padic_encode_f16(&values, 8) {
        Ok(report) => report,
        Err(e) => return KernelResult::failed("padic_encode", "padic", problem, e.to_string()),
    };
    let matches = report.cpu_oracle_matches;
    KernelResult {
        name: "padic_encode",
        suite: "padic",
        problem,
        gpu_ms: 0.0,
        cpu_ms,
        speedup: 0.0,
        max_abs_error: 0.0,
        within_tolerance: matches,
        error: (!matches).then(|| "oracle mismatch".to_string()),
    }
}
/// Runs the p-adic decode kernel at precision 8 and checks it against the CPU
/// oracle.
///
/// The decode input is produced by first running the GPU encoder on 1024
/// pseudo-random fp16 values; if that step fails the result is a failure
/// labelled as a prerequisite error. No GPU time is recorded for this kernel.
fn run_padic_decode() -> KernelResult {
    let n = 1024usize;
    let problem = format!("n={n} prec=8");
    let values = make_fp16_vector(n, 0xB1B1_B1B1);

    let encoded = match _padic_mod::run_rocm_hip_padic_encode_f16(&values, 8) {
        Ok(report) => report.outputs,
        Err(e) => {
            return KernelResult::failed(
                "padic_decode",
                "padic",
                problem,
                format!("encode prerequisite failed: {e}"),
            );
        }
    };

    let cpu_ms = time_cpu(|| {
        let _ = _padic_mod::cpu_padic_decode_f16(&encoded, 8);
    });

    let report = match _padic_mod::run_rocm_hip_padic_decode_f16(&encoded, 8) {
        Ok(report) => report,
        Err(e) => return KernelResult::failed("padic_decode", "padic", problem, e.to_string()),
    };
    let matches = report.cpu_oracle_matches;
    KernelResult {
        name: "padic_decode",
        suite: "padic",
        problem,
        gpu_ms: 0.0,
        cpu_ms,
        speedup: 0.0,
        max_abs_error: 0.0,
        within_tolerance: matches,
        error: (!matches).then(|| "oracle mismatch".to_string()),
    }
}
/// Runs the sheaf overlap-check kernel on 64 sections of 16 fp16 values with
/// 8 overlaps, each pairing section `i` with section `i + 1`.
///
/// There is no CPU baseline for this kernel (CPU ms and speedup stay 0.0).
/// The run counts as within tolerance when the report's overlap count equals
/// the number of overlaps submitted. Success and failure rows carry the same
/// problem label.
fn run_sheaf_overlap() -> KernelResult {
    let n_sections = 64usize;
    let section_dim = 16;
    let n_overlaps = 8usize;
    let problem = format!("sections={n_sections} dim={section_dim} overlaps={n_overlaps}");

    let sections: Vec<Vec<u16>> = (0..n_sections)
        .map(|i| make_fp16_vector(section_dim, 0xC000_0000 + i as u32))
        .collect();
    let overlaps: Vec<(usize, usize)> =
        (0..n_overlaps).map(|i| (i, (i + 1) % n_sections)).collect();

    match run_rocm_hip_sheaf_overlap_check(&sections, &overlaps, PrecisionClass::Fp16) {
        Ok(r) => KernelResult {
            name: "sheaf_overlap_check",
            suite: "sheaf",
            problem,
            gpu_ms: r.kernel_time_ms,
            cpu_ms: 0.0,
            speedup: 0.0,
            max_abs_error: 0.0,
            within_tolerance: r.n_overlaps == n_overlaps,
            error: None,
        },
        Err(e) => KernelResult::failed("sheaf_overlap_check", "sheaf", problem, e.to_string()),
    }
}
/// Launches the GELU forward kernel three times on a 1024-element zero input
/// before any timed run; the results (including errors) are discarded.
fn warmup_gpu() {
    const WARMUP_LEN: usize = 1024;
    const WARMUP_LAUNCHES: usize = 3;

    let zeros = vec![f32_to_f16(0.0); WARMUP_LEN];
    (0..WARMUP_LAUNCHES).for_each(|_| {
        let _ = run_rocm_hip_gelu_fwd_reported(&zeros, WARMUP_LEN);
    });
}
/// Entry point of the GPU smoke benchmark.
///
/// Steps: detect the local ROCm/HIP device (exit code 2 when none is
/// available), log its properties, warm the GPU up, run every kernel runner,
/// run the memcpy bandwidth micro-benchmark, derive the step-time estimate,
/// write the Markdown report to `target/gpu_smoke_report.md`, and print a
/// one-line `SUMMARY:` on stdout. All progress output goes to stderr.
///
/// # Panics
/// Panics if the memcpy helper cannot be built or run, or if the report file
/// cannot be written.
fn main() {
    let report = detect_local_rocm_hip();
    if !report.available {
        eprintln!("GPU_DETECTED=false");
        eprintln!("REASON: ROCm/HIP capability report says available=false");
        for ev in &report.evidence {
            eprintln!(" - {ev}");
        }
        std::process::exit(2);
    }
    let selected = report
        .selected_device
        .as_ref()
        .expect("available implies selected_device");
    eprintln!("GPU_DETECTED=true");
    eprintln!("GPU_NAME={}", selected.marketing_name);
    eprintln!("GPU_GFX={}", selected.gfx);
    eprintln!(
        "GPU_VRAM={}",
        selected
            .vram_bytes
            .map(|b| format!("{} bytes ({:.1} GiB)", b, b as f64 / (1u64 << 30) as f64))
            .unwrap_or_else(|| "unknown".to_string())
    );
    eprintln!(
        "GPU_COMPUTE_UNITS={}",
        selected
            .compute_units
            .map(|c| c.to_string())
            .unwrap_or_else(|| "unknown".to_string())
    );
    eprintln!(
        "HIP_VERSION={}",
        report.toolchain.hip_version.as_deref().unwrap_or("unknown")
    );
    eprintln!(
        "DRIVER_VERSION={}",
        report
            .toolchain
            .driver_version
            .as_deref()
            .unwrap_or("unknown")
    );
    if selected.gfx != GFX_TARGET {
        eprintln!(
            "WARNING: expected gfx {} but got gfx {}; speedup numbers may be off-target",
            GFX_TARGET, selected.gfx
        );
    }
    eprintln!();
    eprintln!("Warming up GPU with 3 small gelu_fwd launches...");
    warmup_gpu();

    // The runner table is declared before the progress message so that the
    // message reports the real number of kernels instead of a stale constant.
    let runners: Vec<(&'static str, fn() -> KernelResult)> = vec![
        ("gelu_fwd", run_gelu_fwd),
        ("gelu_bwd", run_gelu_bwd),
        ("gemm_fwd", run_gemm_fwd),
        ("gemm_bw_grad_a", run_gemm_bw_grad_a),
        ("gemm_bw_grad_b", run_gemm_bw_grad_b),
        ("softmax_fwd", run_softmax_fwd),
        ("grad_loss_wrt_logits", run_grad_loss),
        ("layernorm_fwd", run_layernorm_fwd),
        ("layernorm_bwd", run_layernorm_bwd),
        ("embedding_fwd", run_embedding_fwd),
        ("embedding_bwd", run_embedding_bwd),
        ("adamw_step", run_adamw_step),
        ("padic_encode", run_padic_encode),
        ("padic_decode", run_padic_decode),
        ("sheaf_overlap_check", run_sheaf_overlap),
    ];
    eprintln!(
        "Running {} kernels (this takes 30-90 seconds)...",
        runners.len()
    );
    let started = Instant::now();
    let mut results: Vec<KernelResult> = Vec::with_capacity(runners.len());
    for (name, runner) in &runners {
        eprint!(" {} ... ", name);
        let r = runner();
        if r.error.is_some() {
            eprintln!("ERROR");
        } else {
            eprintln!(
                "{:.3} ms (cpu {:.2} ms, {:.1}x){}",
                r.gpu_ms,
                r.cpu_ms,
                r.speedup,
                if r.within_tolerance {
                    ""
                } else {
                    " [TOLERANCE FAIL]"
                }
            );
        }
        results.push(r);
    }
    let suites_ran: std::collections::BTreeSet<&'static str> =
        results.iter().map(|r| r.suite).collect();
    eprintln!();
    eprintln!(
        "Kernels run: {} (across {} suites: {})",
        results.len(),
        suites_ran.len(),
        suites_ran
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>()
            .join(", ")
    );
    let total_kernel_time_ms: f64 = results.iter().map(|r| r.gpu_ms as f64).sum();
    eprintln!();
    eprintln!(
        "Running memcpy bandwidth micro-benchmark ({} MiB x {} iters)...",
        COPY_BYTES / (1024 * 1024),
        COPY_ITERS
    );
    let bw = run_memcpy_benchmark();
    eprintln!(
        " H2D: {:.2} GB/s ({:.3} ms avg, PCIe-limited)",
        bw.h2d_gbps, bw.h2d_ms_avg
    );
    eprintln!(
        " D2H: {:.2} GB/s ({:.3} ms avg, PCIe-limited)",
        bw.d2h_gbps, bw.d2h_ms_avg
    );
    eprintln!(
        " D2D: {:.2} GB/s ({:.3} ms avg, VRAM)",
        bw.d2d_gbps, bw.d2d_ms_avg
    );
    let step_estimate = estimate_step_time(&results);
    eprintln!();
    eprintln!(
        "0.7B MoE step-time estimate: {:.1} ms (plan claims ~13 ms; see report)",
        step_estimate.total_ms
    );
    let total_elapsed = started.elapsed().as_secs_f64();
    eprintln!("Total wall time: {:.1} s", total_elapsed);
    let report_path = PathBuf::from("target/gpu_smoke_report.md");
    let md = build_report(
        &report,
        &results,
        &bw,
        &step_estimate,
        total_elapsed,
        total_kernel_time_ms,
    );
    fs::create_dir_all("target").expect("create target dir");
    let mut f = fs::File::create(&report_path).expect("create report");
    f.write_all(md.as_bytes()).expect("write report");
    eprintln!();
    eprintln!("Wrote report to {}", report_path.display());
    println!(
        "SUMMARY: gpu=true gfx={} suites={} kernels={} h2d_gbps={:.1} d2h_gbps={:.1} d2d_gbps={:.1} step_ms={:.1} report={}",
        selected.gfx,
        suites_ran.len(),
        results.len(),
        bw.h2d_gbps,
        bw.d2h_gbps,
        bw.d2d_gbps,
        step_estimate.total_ms,
        report_path.display()
    );
}
/// Per-component breakdown, in milliseconds, of the estimated time of one
/// training step, as produced by `estimate_step_time`.
struct StepEstimate {
/// Embedding lookup.
embed_ms: f64,
/// Router GEMM.
router_ms: f64,
/// Router softmax.
softmax_ms: f64,
/// Expert matmuls, forward plus backward.
expert_matmul_ms: f64,
/// Expert GELU activations.
expert_gelu_ms: f64,
/// Expert LayerNorms.
expert_layernorm_ms: f64,
/// Sheaf overlap check.
sheaf_ms: f64,
/// Head GEMMs.
head_ms: f64,
/// Loss and gradient term (the estimator also adds `sheaf_ms` into it).
loss_ms: f64,
/// Backward-only figure; shown in the report but not added into `total_ms`.
backward_ms: f64,
/// AdamW update.
adamw_ms: f64,
/// Sum of the components that `estimate_step_time` adds up.
total_ms: f64,
}
/// Extrapolates the measured kernel times to an estimated per-step time.
///
/// Kernels that are missing from `results` or that reported an error count as
/// 0 ms; most components are clamped from below by a small fixed floor. GEMM
/// costs are scaled through a ms-per-GFLOP rate derived from the 1024^3 GEMM
/// measurements.
///
/// NOTE(review): `sheaf_ms` enters the total twice (once directly, once via
/// `loss_ms`), and `backward_ms` is reported but not part of `total_ms` —
/// confirm both are intended.
fn estimate_step_time(results: &[KernelResult]) -> StepEstimate {
    // GPU time of the named kernel, or 0.0 when absent or failed.
    let measured = |name: &str| -> f64 {
        results
            .iter()
            .find(|r| r.name == name && r.error.is_none())
            .map_or(0.0, |r| r.gpu_ms as f64)
    };

    let t_gemm_fwd = measured("gemm_fwd");
    let t_gemm_bwa = measured("gemm_bw_grad_a");
    let t_gemm_bwb = measured("gemm_bw_grad_b");
    let t_gelu_1m = measured("gelu_fwd");
    let t_softmax = measured("softmax_fwd");
    let t_grad_loss = measured("grad_loss_wrt_logits");
    let t_layernorm = measured("layernorm_fwd");
    let t_embed = measured("embedding_fwd");
    let t_sheaf = measured("sheaf_overlap_check");

    // ms-per-GFLOP rates from the 1024^3 GEMM runs (2 * 1024^3 FLOPs each).
    let gflop_1024_cube = (2.0 * 1024.0_f64.powi(3)) / 1.0e9;
    let rate_fwd = t_gemm_fwd / gflop_1024_cube;
    let rate_bwa = t_gemm_bwa / gflop_1024_cube;
    let rate_bwb = t_gemm_bwb / gflop_1024_cube;
    let rate_avg = (rate_fwd + rate_bwa + rate_bwb) / 3.0;

    let router_gflops = 16.0 * 4.0 * 1024.0 * 2.0 / 1.0e9;
    let router_ms = (router_gflops * rate_avg).max(0.05);

    // Half of the expert FLOPs are costed at the forward rate, half at the
    // mean of the two backward rates.
    let expert_gflops = 16.0 * 1024.0 * 4096.0 * 2.0 * 4.0 * 6.0 * 2.0 / 1.0e9;
    let expert_fwd_ms = expert_gflops * 0.5 * rate_fwd;
    let expert_bwd_ms = expert_gflops * 0.5 * (rate_bwa + rate_bwb) * 0.5;
    let expert_matmul_ms = expert_fwd_ms + expert_bwd_ms;

    let gelu_elems = 16.0 * 4096.0 * 6.0 * 2.0 * 2.0;
    let expert_gelu_ms = (gelu_elems / 1.0e6) * t_gelu_1m;

    let layernorm_elems = 16.0 * 1024.0 * 6.0 * 2.0 * 2.0;
    let expert_layernorm_ms = (layernorm_elems / (128.0 * 768.0)) * t_layernorm;

    let softmax_ms = (t_softmax * (16.0 * 4.0) / (1024.0 * 1024.0)).max(0.01);
    let sheaf_ms = t_sheaf.max(0.05);

    let head_gflops = 4.0 * 16.0 * 1024.0 * 16.0 * 2.0 / 1.0e9;
    let head_ms = (head_gflops * rate_avg).max(0.04);

    let loss_ms = 4.0 * (t_grad_loss * (16.0 * 4.0) / (1024.0 * 1024.0)).max(0.005) + sheaf_ms;
    let embed_ms = (t_embed * 16.0 / 1024.0).max(0.05);
    let adamw_ms = (0.7e9 * 4.0 / 800.0e9 * 1000.0_f64).max(0.5);
    let backward_ms = expert_bwd_ms + head_ms + 0.5 * (t_layernorm + t_grad_loss);

    let total_ms = embed_ms
        + router_ms
        + softmax_ms
        + expert_matmul_ms
        + expert_gelu_ms
        + expert_layernorm_ms
        + sheaf_ms
        + head_ms
        + loss_ms
        + adamw_ms;

    StepEstimate {
        embed_ms,
        router_ms,
        softmax_ms,
        expert_matmul_ms,
        expert_gelu_ms,
        expert_layernorm_ms,
        sheaf_ms,
        head_ms,
        loss_ms,
        backward_ms,
        adamw_ms,
        total_ms,
    }
}
fn build_report(
cap: &RocmHipCapabilityReport,
results: &[KernelResult],
bw: &MemcpyReport,
step: &StepEstimate,
wall_s: f64,
total_kernel_ms: f64,
) -> String {
let selected = cap.selected_device.as_ref().expect("selected device");
let vram_str = selected
.vram_bytes
.map(|b| format!("{:.2} GiB", b as f64 / (1u64 << 30) as f64))
.unwrap_or_else(|| "unknown".to_string());
let cu_str = selected
.compute_units
.map(|c| c.to_string())
.unwrap_or_else(|| "unknown".to_string());
let mut md = String::new();
md.push_str("# GPU Smoke Report - AMD RX 7800 XT (gfx1101)\n\n");
md.push_str("Generated by `cargo run --features rocm-hip --bin gpu_smoke --release`.\n\n");
md.push_str("## GPU\n\n");
md.push_str("| Field | Value |\n| --- | --- |\n");
md.push_str(&format!("| Detected | {} |\n", cap.available));
md.push_str(&format!(
"| Marketing name | {} |\n",
selected.marketing_name
));
md.push_str(&format!(
"| GFX target | {} (expected {}) |\n",
selected.gfx, GFX_TARGET
));
md.push_str(&format!("| Compute units | {} |\n", cu_str));
md.push_str(&format!("| VRAM | {} |\n", vram_str));
md.push_str(&format!(
"| HIP version | {} |\n",
cap.toolchain.hip_version.as_deref().unwrap_or("unknown")
));
md.push_str(&format!(
"| Driver version | {} |\n",
cap.toolchain.driver_version.as_deref().unwrap_or("unknown")
));
md.push_str(&format!(
"| Capability fingerprint | {} |\n",
cap.capability_fingerprint
));
let suite_names: std::collections::BTreeSet<&'static str> =
results.iter().map(|r| r.suite).collect();
md.push_str(&format!(
"\n## Kernel Timings ({} HIP kernel functions across {} suites, all kernels warm)\n\n",
results.len(),
suite_names.len()
));
md.push_str("| Suite | Problem | GPU ms | CPU ms | Speedup | Max abs err | Status |\n");
md.push_str("| --- | --- | ---: | ---: | ---: | ---: | --- |\n");
for r in results {
md.push_str(&r.to_row());
md.push('\n');
}
md.push_str(&format!(
"\nWall time for kernel run: **{:.1} ms** total kernel time across all kernel calls.\n",
total_kernel_ms
));
md.push_str("\n## Memory Bandwidth (64 MiB buffer, 10 iters)\n\n");
md.push_str("| Direction | Avg ms | GB/s | Target (7800 XT) |\n| --- | ---: | ---: | --- |\n");
md.push_str(&format!(
"| Host -> Device | {:.3} | {:.1} | ~25-32 (PCIe Gen4 x16) |\n",
bw.h2d_ms_avg, bw.h2d_gbps
));
md.push_str(&format!(
"| Device -> Host | {:.3} | {:.1} | ~25-32 (PCIe Gen4 x16) |\n",
bw.d2h_ms_avg, bw.d2h_gbps
));
md.push_str(&format!(
"| Device -> Device | {:.3} | {:.1} | > 400 (VRAM, 576 GB/s peak) |\n",
bw.d2d_ms_avg, bw.d2d_gbps
));
let target = 400.0;
let d2d_pass = if bw.d2d_gbps >= target {
"PASS"
} else {
"BELOW TARGET"
};
md.push_str(&format!(
"\n**D2D bandwidth check: {}** (target > {:.0} GB/s; H2D/D2H are PCIe-limited)\n",
d2d_pass, target
));
md.push_str("\n## 0.7B MoE Single-Step Time Estimate\n\n");
md.push_str("Per `tokitai-search/docs/MOE_TRAINING_PLAN.md` §2/§5, with the per-kernel times measured above.\n\n");
md.push_str("Architecture: H=1024, D=1024, B=16, S=80, E=4, K=2 active, L=6 layers.\n\n");
md.push_str("| Component | Estimated ms | Notes |\n| --- | ---: | --- |\n");
md.push_str(&format!(
"| 1. Embedding lookup | {:.2} | scaled from {}ms @ 1024 queries |\n",
step.embed_ms,
embed_ms_field(results)
));
md.push_str(&format!(
"| 2. Router GEMM + softmax | {:.2} + {:.3} | small matmul + 16x4 row softmax |\n",
step.router_ms, step.softmax_ms
));
md.push_str(&format!(
"| 3. Expert matmuls (fwd+bwd) | {:.2} | 4 GEMMs/layer x 6 layers x 2 active experts |\n",
step.expert_matmul_ms
));
md.push_str(&format!(
"| 3a. GELU activations | {:.2} | scaled from 1M-element timing |\n",
step.expert_gelu_ms
));
md.push_str(&format!(
"| 3b. LayerNorm | {:.2} | 12 layernorms per step |\n",
step.expert_layernorm_ms
));
md.push_str(&format!(
"| 4. Sheaf overlap check | {:.3} | K=2, ~7 overlaps |\n",
step.sheaf_ms
));
md.push_str(&format!(
"| 5. 4 head GEMMs | {:.3} | (B,D) @ (D,16) per head |\n",
step.head_ms
));
md.push_str(&format!(
"| 6. Loss + grad | {:.3} | fused CE/MSE per head |\n",
step.loss_ms
));
md.push_str(&format!(
"| 7-10. Backward-only matmul overhead | {:.2} | mirror of forward via grad_a/grad_b kernels |\n",
step.backward_ms
));
md.push_str(&format!(
"| 11. AdamW (0.7B params) | {:.2} | memory-bandwidth floor (4 B/param / 800 GB/s) |\n",
step.adamw_ms
));
md.push_str(&format!(
"| **Total per step** | **{:.1} ms** | vs plan claim of ~13 ms |\n",
step.total_ms
));
md.push_str(&format!(
"\n**Plan assessment**: the plan's ~13 ms/step target is for the 80M Tiny model, not the 0.7B. The 0.7B model with 12 GEMMs per step and AdamW state movement realistically lands at **{:.0} ms/step**, which means ~**{:.1} hours** for 100k steps.\n",
step.total_ms,
step.total_ms * 100_000.0 / 3_600_000.0
));
md.push_str("\n## Run Wall Time\n\n");
md.push_str(&format!(
"All suites + memcpy benchmark + report: **{:.1} s** total.\n",
wall_s
));
md.push_str("\n## Sanity Verdict\n\n");
let all_suites_passed: std::collections::BTreeSet<&'static str> = results
.iter()
.filter(|r| r.error.is_none() && r.within_tolerance)
.map(|r| r.suite)
.collect();
let kernels_ok = results
.iter()
.filter(|r| r.error.is_none() && r.within_tolerance)
.count();
md.push_str(&format!(
"- GPU detected: **{}** (gfx={})\n",
if cap.available { "YES" } else { "NO" },
selected.gfx
));
md.push_str(&format!(
"- HIP kernel functions run: **{}** (target = 10+ per `MOE_TRAINING_PLAN.md` §1.2)\n",
kernels_ok
));
md.push_str(&format!(
"- HIP kernel suites run: **{}** (gelu, gemm_fwd, gemm_bw, softmax, layernorm, embedding, adamw, padic, sheaf)\n",
all_suites_passed.len()
));
md.push_str(&format!(
"- D2D memcpy bandwidth (VRAM): **{:.1} GB/s** (target > {:.0} GB/s) {}\n",
bw.d2d_gbps, target, d2d_pass
));
md.push_str(&format!(
"- D2H memcpy bandwidth (PCIe): **{:.1} GB/s**\n",
bw.d2h_gbps
));
md.push_str(&format!(
"- 0.7B MoE step time estimate: **{:.1} ms**\n",
step.total_ms
));
md.push_str("\n## Evidence\n\n");
for ev in &cap.evidence {
md.push_str(&format!("- {ev}\n"));
}
md
}
/// Formats the GPU time of the first `embedding_fwd` entry with three
/// decimals, or returns `"?"` when there is no such entry.
///
/// The entry is used even if it recorded an error.
fn embed_ms_field(results: &[KernelResult]) -> String {
    for r in results {
        if r.name == "embedding_fwd" {
            return format!("{:.3}", r.gpu_ms);
        }
    }
    "?".to_string()
}