use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use crate::backend::hip_dense::{
hipcc_compile_executable, hipcc_compiler_fingerprint, hipcc_recheck_artifact,
};
use crate::backend::kernel_server;
use crate::backend::rocm::{RocmHipCapabilityReport, detect_local_rocm_hip};
use crate::{Error, Result};
/// Backend name that `run_rocm_hip_gemm_bw_grad_a` records in the report's
/// `backend` field.
pub const ROCM_HIP_GEMM_BW_GRAD_A_BACKEND: &str = "rocm_hip_gemm_bw_grad_a_pilot";
/// Backend name that `run_rocm_hip_gemm_bw_grad_b` records in the report's
/// `backend` field.
pub const ROCM_HIP_GEMM_BW_GRAD_B_BACKEND: &str = "rocm_hip_gemm_bw_grad_b_pilot";
/// Lowering identifier for the grad_A kernel.
// NOTE(review): not referenced by the code visible in this file; presumably
// consumed by callers elsewhere — confirm.
pub const ROCM_HIP_GEMM_BW_GRAD_A_LOWERING_ID: &str = "hip.gemm_bw.grad_a.fp16_f32";
/// Lowering identifier for the grad_B kernel.
// NOTE(review): not referenced by the code visible in this file; presumably
// consumed by callers elsewhere — confirm.
pub const ROCM_HIP_GEMM_BW_GRAD_B_LOWERING_ID: &str = "hip.gemm_bw.grad_b.fp16_f32";
/// Kernel-type label passed as the first argument to
/// `kernel_server::run_persistent` when running the compiled executable.
const GEMM_BW_GRAD_A_B_KERNEL_TYPE: &str = "hip-gemm-bw-grad-a-b";
/// HIP C++ source for the fp16 backward-GEMM pilot executable.
///
/// The program contains two device kernels and a host driver:
///
/// * `gemm_bw_grad_a_fp16_f32_kernel` computes
///   `grad_A[m, k] = sum_n grad_C[m, n] * B[k, n]`.
/// * `gemm_bw_grad_b_fp16_f32_kernel` computes
///   `grad_B[k, n] = sum_m A[m, k] * grad_C[m, n]`.
///
/// Both accumulate in `float` with `fmaf` and store the result as fp16.
///
/// # Input protocol
///
/// One-shot mode reads a text payload from stdin: `kernel_id M N K`, then
/// `M*N` fp16 bit patterns for `grad_C` (decimal `uint16_t`), then the operand
/// bit patterns (`K*N` values of `B` for `kernel_id == 1`, `M*K` values of `A`
/// for `kernel_id == 2`). With `--server` as the first argument the same
/// payload is framed by a 4-byte length prefix and the response is returned
/// with a 4-byte length prefix, in a loop until EOF.
///
/// # Output
///
/// `KEY=VALUE` lines (`DEVICE_NAME`, `GFX`, `KERNEL_ID`, `M`, `N`, `K`,
/// `GRID_X`, `GRID_Y`, `BLOCK_X`, `BLOCK_Y`, `KERNEL_TIME_MS`) followed by
/// `RESULTS=` and the space-separated output bit patterns.
///
/// # Exit codes
///
/// `2` malformed header, `3` unknown `kernel_id`, `4` non-positive dimension,
/// `5`/`6` dimensions not multiples of 16, `7`/`8`/`9` short element data,
/// `10` HIP runtime error (via `std::exit`), `20`/`21` short read in server
/// mode.
pub const HIP_GEMM_BW_KERNEL: &str = r#"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
// grad_A[m, k] = sum_n grad_C[m, n] * B[k, n]
// Same 16x16 tile structure as the forward fp16 GEMM, with the K-loop index
// being N (the reduction/interior dimension) instead of K.
__global__ void gemm_bw_grad_a_fp16_f32_kernel(
    const __half* grad_C,
    const __half* B,
    __half* grad_A,
    int M,
    int N,
    int K) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= M || col >= K) {
        return;
    }
    // Widen to std::size_t before multiplying so that the flattened element
    // offsets cannot overflow a 32-bit int for large matrices.
    const std::size_t gc_base = static_cast<std::size_t>(row) * static_cast<std::size_t>(N);
    const std::size_t b_base = static_cast<std::size_t>(col) * static_cast<std::size_t>(N);
    float acc = 0.0f;
    for (int n = 0; n < N; ++n) {
        float gc = __half2float(grad_C[gc_base + static_cast<std::size_t>(n)]);
        float b = __half2float(B[b_base + static_cast<std::size_t>(n)]);
        // Use explicit fmaf() to fuse the multiply-add into a single
        // rounding step, matching the CPU oracle's `mul_add`-based
        // FMA. Without this, a 1024-element reduction can produce
        // ~2.0 errors due to severe cancellation: the unfused
        // `acc = acc + (gc * b)` rounds the product first, so when
        // partial sums have large magnitude with sign changes the
        // rounding step can land the partial sum on a different
        // side of zero than the FMA-fused version does.
        acc = fmaf(gc, b, acc);
    }
    grad_A[static_cast<std::size_t>(row) * static_cast<std::size_t>(K)
           + static_cast<std::size_t>(col)] = __float2half_rn(acc);
}
// grad_B[k, n] = sum_m A[m, k] * grad_C[m, n]
// Same 16x16 tile structure as the forward fp16 GEMM, with the K-loop index
// being M (the reduction/interior dimension) instead of K.
__global__ void gemm_bw_grad_b_fp16_f32_kernel(
    const __half* grad_C,
    const __half* A,
    __half* grad_B,
    int M,
    int N,
    int K) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= K || col >= N) {
        return;
    }
    float acc = 0.0f;
    for (int m = 0; m < M; ++m) {
        // Offsets are computed in std::size_t so m*K and m*N cannot overflow
        // a 32-bit int for large matrices.
        const std::size_t m_wide = static_cast<std::size_t>(m);
        float a = __half2float(
            A[m_wide * static_cast<std::size_t>(K) + static_cast<std::size_t>(row)]);
        float gc = __half2float(
            grad_C[m_wide * static_cast<std::size_t>(N) + static_cast<std::size_t>(col)]);
        // See gemm_bw_grad_a for why fmaf() is required.
        acc = fmaf(a, gc, acc);
    }
    grad_B[static_cast<std::size_t>(row) * static_cast<std::size_t>(N)
           + static_cast<std::size_t>(col)] = __float2half_rn(acc);
}
static void check(hipError_t status, const char* label) {
    if (status != hipSuccess) {
        std::cerr << "HIP_ERROR " << label << "=" << hipGetErrorString(status) << "\n";
        std::exit(10);
    }
}
// Forward declaration of the existing main() body, extracted into
// a static helper so the server-mode loop can call it on each
// request. The default `main()` also routes through this helper so
// the one-shot and server code paths share the same compute logic.
static int run_one_shot_from_main_body();
// Persistent server-mode protocol (see hip_gemm_f16.rs for the full
// design rationale). The host writes a little-endian u32 payload_len
// followed by `payload_len` bytes of the existing text payload, then
// reads back a little-endian u32 response_len followed by
// `response_len` bytes of the existing text response.
static int run_server_mode() {
    while (true) {
        uint32_t payload_len = 0;
        std::cin.read(reinterpret_cast<char*>(&payload_len), 4);
        if (!std::cin || std::cin.gcount() == 0) {
            return 0;  // clean EOF
        }
        if (std::cin.gcount() != 4) {
            std::cerr << "server_mode: short read on payload_len (got "
                      << std::cin.gcount() << " bytes)\n";
            return 20;
        }
        std::vector<char> payload(payload_len);
        if (payload_len > 0) {
            std::cin.read(payload.data(), payload_len);
            if (static_cast<uint32_t>(std::cin.gcount()) != payload_len) {
                std::cerr << "server_mode: short read on payload (got "
                          << std::cin.gcount() << " of " << payload_len << ")\n";
                return 21;
            }
        }
        std::string payload_str(payload.begin(), payload.end());
        std::istringstream fake_stdin(payload_str);
        std::streambuf* old_buf = std::cin.rdbuf(fake_stdin.rdbuf());
        std::ostringstream captured;
        std::streambuf* old_cout = std::cout.rdbuf(captured.rdbuf());
        std::ostringstream captured_err;
        std::streambuf* old_cerr = std::cerr.rdbuf(captured_err.rdbuf());
        int rc = run_one_shot_from_main_body();
        std::cin.rdbuf(old_buf);
        std::cout.rdbuf(old_cout);
        std::cerr.rdbuf(old_cerr);
        std::string response = captured.str();
        if (rc != 0) {
            std::string err_str = captured_err.str();
            response += err_str;
        }
        uint32_t response_len = static_cast<uint32_t>(response.size());
        std::cout.write(reinterpret_cast<const char*>(&response_len), 4);
        if (response_len > 0) {
            std::cout.write(response.data(), response_len);
        }
        std::cout.flush();
        if (rc != 0) {
            return rc;
        }
    }
}
int main(int argc, char** argv) {
    if (argc > 1 && std::string(argv[1]) == "--server") {
        return run_server_mode();
    }
    return run_one_shot_from_main_body();
}
static int run_one_shot_from_main_body() {
    int kernel_id = 0;
    int M = 0;
    int N = 0;
    int K = 0;
    if (!(std::cin >> kernel_id >> M >> N >> K)) {
        std::cerr << "usage: stdin payload is \"kernel_id M N K\\n<grad_C_bits>\\n<operand_bits>\\n\"\n";
        return 2;
    }
    if (kernel_id != 1 && kernel_id != 2) {
        std::cerr << "kernel_id must be 1 (grad_a) or 2 (grad_b); got " << kernel_id << "\n";
        return 3;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        std::cerr << "M N K must all be positive\n";
        return 4;
    }
    if (kernel_id == 1 && (M % 16 != 0 || K % 16 != 0)) {
        std::cerr << "grad_a kernel requires M=" << M << " and K=" << K
                  << " to be multiples of 16 for the 16x16 tile design\n";
        return 5;
    }
    if (kernel_id == 2 && (K % 16 != 0 || N % 16 != 0)) {
        std::cerr << "grad_b kernel requires K=" << K << " and N=" << N
                  << " to be multiples of 16 for the 16x16 tile design\n";
        return 6;
    }
    std::size_t grad_c_count = static_cast<std::size_t>(M) * static_cast<std::size_t>(N);
    std::size_t a_count = static_cast<std::size_t>(M) * static_cast<std::size_t>(K);
    std::size_t b_count = static_cast<std::size_t>(K) * static_cast<std::size_t>(N);
    std::size_t grad_a_count = static_cast<std::size_t>(M) * static_cast<std::size_t>(K);
    std::size_t grad_b_count = static_cast<std::size_t>(K) * static_cast<std::size_t>(N);
    std::vector<uint16_t> grad_c_bits(grad_c_count);
    for (std::size_t i = 0; i < grad_c_count; ++i) {
        if (!(std::cin >> grad_c_bits[i])) {
            std::cerr << "failed to read grad_C element " << i << "\n";
            return 7;
        }
    }
    std::vector<uint16_t> a_bits(a_count);
    std::vector<uint16_t> b_bits(b_count);
    if (kernel_id == 1) {
        for (std::size_t i = 0; i < b_count; ++i) {
            if (!(std::cin >> b_bits[i])) {
                std::cerr << "failed to read B element " << i << "\n";
                return 8;
            }
        }
    } else {
        for (std::size_t i = 0; i < a_count; ++i) {
            if (!(std::cin >> a_bits[i])) {
                std::cerr << "failed to read A element " << i << "\n";
                return 9;
            }
        }
    }
    int device = 0;
    check(hipSetDevice(device), "hipSetDevice");
    hipDeviceProp_t props;
    check(hipGetDeviceProperties(&props, device), "hipGetDeviceProperties");
    __half* d_grad_C = nullptr;
    __half* d_A = nullptr;
    __half* d_B = nullptr;
    __half* d_grad_A = nullptr;
    __half* d_grad_B = nullptr;
    std::size_t grad_c_bytes = grad_c_count * sizeof(__half);
    std::size_t a_bytes = a_count * sizeof(__half);
    std::size_t b_bytes = b_count * sizeof(__half);
    std::size_t grad_a_bytes = grad_a_count * sizeof(__half);
    std::size_t grad_b_bytes = grad_b_count * sizeof(__half);
    check(hipMalloc(&d_grad_C, grad_c_bytes), "hipMalloc(grad_C)");
    check(hipMemcpy(d_grad_C, grad_c_bits.data(), grad_c_bytes, hipMemcpyHostToDevice),
          "hipMemcpy(grad_C)");
    dim3 block(16, 16);
    dim3 grid(0, 0);
    // Allocate and upload the per-kernel buffers BEFORE the start event is
    // recorded, so that KERNEL_TIME_MS covers only the kernel launch and not
    // device allocation or host-to-device transfer.
    if (kernel_id == 1) {
        check(hipMalloc(&d_B, b_bytes), "hipMalloc(B)");
        check(hipMalloc(&d_grad_A, grad_a_bytes), "hipMalloc(grad_A)");
        check(hipMemcpy(d_B, b_bits.data(), b_bytes, hipMemcpyHostToDevice),
              "hipMemcpy(B)");
        grid = dim3(M / 16, K / 16);
    } else {
        check(hipMalloc(&d_A, a_bytes), "hipMalloc(A)");
        check(hipMalloc(&d_grad_B, grad_b_bytes), "hipMalloc(grad_B)");
        check(hipMemcpy(d_A, a_bits.data(), a_bytes, hipMemcpyHostToDevice),
              "hipMemcpy(A)");
        grid = dim3(K / 16, N / 16);
    }
    hipEvent_t start;
    hipEvent_t stop;
    check(hipEventCreate(&start), "hipEventCreate(start)");
    check(hipEventCreate(&stop), "hipEventCreate(stop)");
    check(hipEventRecord(start), "hipEventRecord(start)");
    if (kernel_id == 1) {
        hipLaunchKernelGGL(gemm_bw_grad_a_fp16_f32_kernel, grid, block, 0, 0,
                           d_grad_C, d_B, d_grad_A, M, N, K);
    } else {
        hipLaunchKernelGGL(gemm_bw_grad_b_fp16_f32_kernel, grid, block, 0, 0,
                           d_grad_C, d_A, d_grad_B, M, N, K);
    }
    check(hipGetLastError(), "hipLaunchKernelGGL");
    check(hipEventRecord(stop), "hipEventRecord(stop)");
    check(hipEventSynchronize(stop), "hipEventSynchronize");
    float kernel_time_ms = 0.0f;
    check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
    check(hipEventDestroy(start), "hipEventDestroy(start)");
    check(hipEventDestroy(stop), "hipEventDestroy(stop)");
    std::vector<uint16_t> out_bits;
    if (kernel_id == 1) {
        out_bits.resize(grad_a_count);
        check(hipMemcpy(out_bits.data(), d_grad_A, grad_a_bytes, hipMemcpyDeviceToHost),
              "hipMemcpy(grad_A)");
    } else {
        out_bits.resize(grad_b_count);
        check(hipMemcpy(out_bits.data(), d_grad_B, grad_b_bytes, hipMemcpyDeviceToHost),
              "hipMemcpy(grad_B)");
    }
    check(hipFree(d_grad_C), "hipFree(grad_C)");
    if (d_A) {
        check(hipFree(d_A), "hipFree(A)");
    }
    if (d_B) {
        check(hipFree(d_B), "hipFree(B)");
    }
    if (d_grad_A) {
        check(hipFree(d_grad_A), "hipFree(grad_A)");
    }
    if (d_grad_B) {
        check(hipFree(d_grad_B), "hipFree(grad_B)");
    }
    std::cout << "DEVICE_NAME=" << props.name << "\n";
    std::cout << "GFX=" << props.gcnArchName << "\n";
    std::cout << "KERNEL_ID=" << kernel_id << "\n";
    std::cout << "M=" << M << "\n";
    std::cout << "N=" << N << "\n";
    std::cout << "K=" << K << "\n";
    std::cout << "GRID_X=" << grid.x << "\n";
    std::cout << "GRID_Y=" << grid.y << "\n";
    std::cout << "BLOCK_X=" << block.x << "\n";
    std::cout << "BLOCK_Y=" << block.y << "\n";
    std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
    std::cout << "RESULTS=";
    for (std::size_t i = 0; i < out_bits.size(); ++i) {
        if (i != 0) {
            std::cout << " ";
        }
        std::cout << out_bits[i];
    }
    std::cout << "\n";
    return 0;
}
"#;
/// Result of one fp16 backward-GEMM pilot run: the GPU outputs, the CPU
/// oracle outputs they were compared against, and provenance information
/// about how the kernel executable was built and run.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipGemmBwReport {
/// Backend name supplied by the public entry point that produced the report.
pub backend: &'static str,
/// Kernel selector sent to the executable: `1` for grad_A, `2` for grad_B.
pub kernel_id: u32,
/// GEMM dimension `m` as passed by the caller.
pub m: usize,
/// GEMM dimension `n` as passed by the caller.
pub n: usize,
/// GEMM dimension `k` as passed by the caller.
pub k: usize,
/// Values parsed from the executable's `RESULTS=` line (fp16 bit patterns).
pub outputs: Vec<u16>,
/// Reference values computed on the CPU for the same inputs (fp16 bit patterns).
pub cpu_oracle_outputs: Vec<u16>,
/// Largest absolute difference found between `outputs` and
/// `cpu_oracle_outputs` after converting both to `f32`.
pub max_abs_error: f32,
/// `true` when `max_abs_error` is below `1e-2`.
pub within_tolerance: bool,
/// Value parsed from the executable's `KERNEL_TIME_MS=` line.
pub kernel_time_ms: f32,
/// Fingerprint of the embedded HIP source, also used to name the cached
/// source file and executable.
pub kernel_source_fingerprint: String,
/// Value returned by `hipcc_compiler_fingerprint` for the compiler used.
pub compiler_fingerprint: String,
/// Value returned by `hipcc_compile_executable` for this build.
pub build_command: String,
/// Display form of the compiled executable's path.
pub executable_path: String,
/// Result of `detect_local_rocm_hip()` captured before the run.
pub device_evidence: RocmHipCapabilityReport,
/// Human-readable statements describing what the run did; rendered as
/// bullets under "## Evidence" by `to_markdown`.
pub evidence: Vec<String>,
/// Human-readable statements describing what the run does not demonstrate;
/// rendered as bullets under "## Non-Claims" by `to_markdown`.
pub non_claims: Vec<String>,
}
impl RocmHipGemmBwReport {
    /// Renders the report as a Markdown document.
    ///
    /// The document has a title, a block of `key: value` summary lines, an
    /// "Evidence" bullet list and a "Non-Claims" bullet list. Lines are joined
    /// with `\n` and there is no trailing newline.
    pub fn to_markdown(&self) -> String {
        let mut doc: Vec<String> =
            Vec::with_capacity(16 + self.evidence.len() + self.non_claims.len());

        // Title and summary fields.
        doc.push("# ROCm/HIP fp16 GEMM Backward Pilot".to_string());
        doc.push(String::new());
        doc.push(format!("backend: {}", self.backend));
        doc.push(format!("kernel_id: {}", self.kernel_id));
        doc.push(format!("m: {}", self.m));
        doc.push(format!("n: {}", self.n));
        doc.push(format!("k: {}", self.k));
        doc.push(format!("max_abs_error: {}", self.max_abs_error));
        doc.push(format!("within_tolerance: {}", self.within_tolerance));
        doc.push(format!("kernel_time_ms: {}", self.kernel_time_ms));
        doc.push(format!(
            "kernel_source_fingerprint: {}",
            self.kernel_source_fingerprint
        ));
        doc.push(format!("compiler_fingerprint: {}", self.compiler_fingerprint));

        // Evidence section.
        doc.push(String::new());
        doc.push("## Evidence".to_string());
        doc.extend(self.evidence.iter().map(|item| format!("- {item}")));

        // Non-claims section.
        doc.push(String::new());
        doc.push("## Non-Claims".to_string());
        doc.extend(self.non_claims.iter().map(|item| format!("- {item}")));

        doc.join("\n")
    }
}
/// Runs the grad_A backward-GEMM kernel on the local ROCm/HIP device and
/// compares the result against the CPU oracle.
///
/// Computes `grad_A[m, k] = sum_n grad_C[m, n] * B[k, n]` on fp16 bit patterns.
///
/// * `grad_c` — `m * n` fp16 bit patterns.
/// * `b` — `k * n` fp16 bit patterns.
/// * `m`, `n`, `k` — dimensions; all must be positive, and `m` and `k` must be
///   multiples of 16.
///
/// # Errors
///
/// Returns a backend error when the slice lengths or dimensions are invalid,
/// when ROCm/HIP is reported unavailable, or when building, running or parsing
/// the output of the kernel executable fails.
pub fn run_rocm_hip_gemm_bw_grad_a(
grad_c: &[u16],
b: &[u16],
m: usize,
n: usize,
k: usize,
) -> Result<RocmHipGemmBwReport> {
run_bw_kernel(
ROCM_HIP_GEMM_BW_GRAD_A_BACKEND,
// kernel_id 1 selects the grad_A kernel.
1,
grad_c,
Some(b),
None,
m,
n,
k,
)
}
/// Runs the grad_B backward-GEMM kernel on the local ROCm/HIP device and
/// compares the result against the CPU oracle.
///
/// Computes `grad_B[k, n] = sum_m A[m, k] * grad_C[m, n]` on fp16 bit patterns.
///
/// * `grad_c` — `m * n` fp16 bit patterns.
/// * `a` — `m * k` fp16 bit patterns.
/// * `m`, `n`, `k` — dimensions; all must be positive, and `k` and `n` must be
///   multiples of 16.
///
/// # Errors
///
/// Returns a backend error when the slice lengths or dimensions are invalid,
/// when ROCm/HIP is reported unavailable, or when building, running or parsing
/// the output of the kernel executable fails.
pub fn run_rocm_hip_gemm_bw_grad_b(
grad_c: &[u16],
a: &[u16],
m: usize,
n: usize,
k: usize,
) -> Result<RocmHipGemmBwReport> {
run_bw_kernel(
ROCM_HIP_GEMM_BW_GRAD_B_BACKEND,
// kernel_id 2 selects the grad_B kernel.
2,
grad_c,
None,
Some(a),
m,
n,
k,
)
}
/// Shared driver behind the grad_A and grad_B entry points.
///
/// Steps, in order:
/// 1. validate slice lengths and dimensions for the selected kernel;
/// 2. require a local ROCm/HIP device;
/// 3. write the embedded HIP source into `target/rocm-hip-cache` and build it
///    with `/opt/rocm/bin/hipcc` for `gfx1101`;
/// 4. send the text payload to the executable and parse its response;
/// 5. compute the CPU oracle and the maximum absolute error.
///
/// `kernel_id == 1` selects grad_A and requires `b`; any other value is treated
/// as grad_B and requires `a`.
///
/// # Errors
///
/// Returns a backend error for invalid inputs, an unavailable device, build or
/// run failures, missing output markers, or when the executable returns a
/// different number of output values than the CPU oracle produces.
fn run_bw_kernel(
    backend: &'static str,
    kernel_id: u32,
    grad_c: &[u16],
    b: Option<&[u16]>,
    a: Option<&[u16]>,
    m: usize,
    n: usize,
    k: usize,
) -> Result<RocmHipGemmBwReport> {
    let is_grad_a = kernel_id == 1;

    if grad_c.len() != m * n {
        return Err(Error::backend(format!(
            "fp16 GEMM BW grad_C length {} does not match m*n={}",
            grad_c.len(),
            m * n
        )));
    }

    // Resolve the single operand this kernel needs once, so the rest of the
    // function never has to re-unwrap the `Option`s.
    let operand: &[u16] = if is_grad_a {
        let b = b.ok_or_else(|| Error::backend("grad_a kernel requires B operand"))?;
        if b.len() != k * n {
            return Err(Error::backend(format!(
                "fp16 GEMM BW B length {} does not match k*n={}",
                b.len(),
                k * n
            )));
        }
        if m % 16 != 0 || k % 16 != 0 {
            return Err(Error::backend(format!(
                "fp16 GEMM BW grad_a requires m={} and k={} to be multiples of 16",
                m, k
            )));
        }
        b
    } else {
        let a = a.ok_or_else(|| Error::backend("grad_b kernel requires A operand"))?;
        if a.len() != m * k {
            return Err(Error::backend(format!(
                "fp16 GEMM BW A length {} does not match m*k={}",
                a.len(),
                m * k
            )));
        }
        if k % 16 != 0 || n % 16 != 0 {
            return Err(Error::backend(format!(
                "fp16 GEMM BW grad_b requires k={} and n={} to be multiples of 16",
                k, n
            )));
        }
        a
    };

    if m == 0 || n == 0 || k == 0 {
        return Err(Error::backend(
            "fp16 GEMM BW dimensions must all be positive",
        ));
    }

    let device_evidence = detect_local_rocm_hip();
    if !device_evidence.available {
        return Err(Error::backend(
            "ROCm/HIP is unavailable; fp16 GEMM BW pilot remains inadmissible",
        ));
    }

    // The source fingerprint names both the cached source file and executable.
    let source_fingerprint = hip_gemm_bw_kernel_source_fingerprint();
    let cache_dir = PathBuf::from("target/rocm-hip-cache");
    fs::create_dir_all(&cache_dir)
        .map_err(|err| Error::backend(format!("failed to create HIP cache directory: {err}")))?;
    let source_path = cache_dir.join(format!("{source_fingerprint}.cpp"));
    let executable_path = cache_dir.join(format!("{source_fingerprint}-gemm-bw-fp16-f32"));
    fs::write(&source_path, HIP_GEMM_BW_KERNEL)
        .map_err(|err| Error::backend(format!("failed to write HIP BW kernel source: {err}")))?;

    let hipcc = "/opt/rocm/bin/hipcc";
    let compiler_fingerprint = hipcc_compiler_fingerprint(hipcc)?;
    let build_command =
        hipcc_compile_executable(hipcc, &source_path, &executable_path, Some("gfx1101"))?;

    // Text payload: header line, grad_C line, operand line.
    let mut payload = String::with_capacity((grad_c.len() + operand.len()) * 8);
    payload.push_str(&format!("{kernel_id} {m} {n} {k}\n"));
    append_bw_bits_line(&mut payload, grad_c);
    append_bw_bits_line(&mut payload, operand);

    let stdout = run_gemm_bw_executable(&executable_path, &source_path, &payload)?;
    let outputs = parse_bw_results(&stdout)?;
    let kernel_time_ms = parse_bw_f32_line(&stdout, "KERNEL_TIME_MS=")
        .ok_or_else(|| Error::backend("HIP fp16 GEMM BW did not print KERNEL_TIME_MS marker"))?;

    let cpu_oracle_outputs = if is_grad_a {
        cpu_gemm_bw_grad_a(grad_c, operand, m, n, k)
    } else {
        cpu_gemm_bw_grad_b(grad_c, operand, m, n, k)
    };

    // A short (or empty) RESULTS line must not be able to pass: `zip` would
    // silently compare only the common prefix and report zero error.
    if outputs.len() != cpu_oracle_outputs.len() {
        return Err(Error::backend(format!(
            "HIP fp16 GEMM BW returned {} output values but the CPU oracle expects {}",
            outputs.len(),
            cpu_oracle_outputs.len()
        )));
    }

    let max_abs_error = bw_max_abs_error(&outputs, &cpu_oracle_outputs);
    let within_tolerance = max_abs_error < 1e-2;
    let kernel_label = if is_grad_a { "grad_a" } else { "grad_b" };

    Ok(RocmHipGemmBwReport {
        backend,
        kernel_id,
        m,
        n,
        k,
        outputs,
        cpu_oracle_outputs,
        max_abs_error,
        within_tolerance,
        kernel_time_ms,
        kernel_source_fingerprint: source_fingerprint,
        compiler_fingerprint,
        build_command,
        executable_path: executable_path.display().to_string(),
        device_evidence,
        evidence: vec![
            "compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101".to_string(),
            format!(
                "shipped grad_C and the {} operand to the kernel via stdin (Stdio::piped)",
                if is_grad_a { "B" } else { "A" }
            ),
            format!(
                "launched gemm_bw_{}_fp16_f32_kernel with grid/block as documented in the kernel source",
                kernel_label
            ),
            "captured kernel time with hipEventRecord/hipEventSynchronize".to_string(),
            "compared every output element against the CPU fp16 oracle within 1e-2".to_string(),
        ],
        non_claims: vec![
            "not production speedup evidence".to_string(),
            "not optimized GEMM backward (no shared-memory tiling, no vectorized loads)"
                .to_string(),
            "not a generic autograd engine (only the two transposed fp16 GEMM variants)"
                .to_string(),
            "not machine-code verification".to_string(),
        ],
    })
}

/// Appends `bits` to `payload` as space-separated decimal values followed by a
/// newline.
fn append_bw_bits_line(payload: &mut String, bits: &[u16]) {
    for (i, v) in bits.iter().enumerate() {
        if i != 0 {
            payload.push(' ');
        }
        payload.push_str(&v.to_string());
    }
    payload.push('\n');
}

/// Returns the largest absolute difference between `gpu` and `oracle` after
/// converting each fp16 bit pattern to `f32`.
///
/// Pairs that compare equal (including equal infinities) or that are both NaN
/// count as agreeing. A pair where exactly one side is NaN yields
/// `f32::INFINITY`, so it can never pass a tolerance check.
fn bw_max_abs_error(gpu: &[u16], oracle: &[u16]) -> f32 {
    use crate::backend::hip_gemm_f16::f16_to_f32;

    let mut max_abs_error = 0.0f32;
    for (&g, &c) in gpu.iter().zip(oracle.iter()) {
        let gf = f16_to_f32(g);
        let cf = f16_to_f32(c);
        if gf == cf || (gf.is_nan() && cf.is_nan()) {
            continue;
        }
        let err = (gf - cf).abs();
        if err.is_nan() {
            // Exactly one side is NaN: a plain `>` comparison would drop it.
            return f32::INFINITY;
        }
        if err > max_abs_error {
            max_abs_error = err;
        }
    }
    max_abs_error
}
/// Returns the fingerprint of the embedded HIP source `HIP_GEMM_BW_KERNEL`,
/// computed by `fingerprint` with the label `"hip-gemm-bw-source"`.
///
/// The result changes whenever the source text changes, which is why it is
/// used to name the cached source file and executable.
pub fn hip_gemm_bw_kernel_source_fingerprint() -> String {
fingerprint("hip-gemm-bw-source", HIP_GEMM_BW_KERNEL)
}
/// Runs the compiled backward-GEMM executable with `payload` and returns the
/// text it produced.
///
/// First calls `hipcc_recheck_artifact` for the source/executable pair, then
/// hands the payload to `kernel_server::run_persistent`.
// NOTE(review): `hipcc_recheck_artifact` is defined elsewhere; presumably it
// verifies or rebuilds the cached executable — confirm against hip_dense.
// NOTE(review): the hipcc path and `gfx1101` target are repeated here and in
// `run_bw_kernel`; keep the two in sync.
///
/// # Errors
///
/// Propagates any error from the artifact recheck or from the kernel server.
fn run_gemm_bw_executable(
executable_path: &std::path::Path,
source_path: &std::path::Path,
payload: &str,
) -> Result<String> {
hipcc_recheck_artifact(
"/opt/rocm/bin/hipcc",
source_path,
executable_path,
Some("gfx1101"),
)?;
kernel_server::run_persistent(GEMM_BW_GRAD_A_B_KERNEL_TYPE, executable_path, payload)
}
/// CPU reference for the grad_A backward GEMM.
///
/// Computes `grad_A[i, kk] = sum_nn grad_C[i, nn] * B[kk, nn]`, where `grad_c`
/// is row-major `m x n`, `b` is row-major `k x n`, and the result is row-major
/// `m x k`. Inputs and outputs are fp16 bit patterns; each dot product is
/// accumulated in `f32` with `mul_add`, in increasing `nn` order, and converted
/// back with `f32_to_f16`.
///
/// # Panics
///
/// Panics if `grad_c` or `b` is too short for the given dimensions.
pub fn cpu_gemm_bw_grad_a(grad_c: &[u16], b: &[u16], m: usize, n: usize, k: usize) -> Vec<u16> {
    use crate::backend::hip_gemm_f16::{f16_to_f32, f32_to_f16};

    let grad_c_wide: Vec<f32> = grad_c.iter().map(|&bits| f16_to_f32(bits)).collect();
    let b_wide: Vec<f32> = b.iter().map(|&bits| f16_to_f32(bits)).collect();

    let mut result = Vec::with_capacity(m * k);
    for row in 0..m {
        for col in 0..k {
            // Both operands are contiguous along the reduction dimension.
            let grad_c_row = &grad_c_wide[row * n..row * n + n];
            let b_row = &b_wide[col * n..col * n + n];
            let dot = grad_c_row
                .iter()
                .zip(b_row.iter())
                .fold(0.0f32, |acc, (&gc, &weight)| gc.mul_add(weight, acc));
            result.push(f32_to_f16(dot));
        }
    }
    result
}
/// CPU reference for the grad_B backward GEMM.
///
/// Computes `grad_B[kk, nn] = sum_mm A[mm, kk] * grad_C[mm, nn]`, where `a` is
/// row-major `m x k`, `grad_c` is row-major `m x n`, and the result is
/// row-major `k x n`. Inputs and outputs are fp16 bit patterns; each dot
/// product is accumulated in `f32` with `mul_add`, in increasing `mm` order,
/// and converted back with `f32_to_f16`.
///
/// # Panics
///
/// Panics if `grad_c` or `a` is too short for the given dimensions.
pub fn cpu_gemm_bw_grad_b(grad_c: &[u16], a: &[u16], m: usize, n: usize, k: usize) -> Vec<u16> {
    use crate::backend::hip_gemm_f16::{f16_to_f32, f32_to_f16};

    let a_wide: Vec<f32> = a.iter().map(|&bits| f16_to_f32(bits)).collect();
    let grad_c_wide: Vec<f32> = grad_c.iter().map(|&bits| f16_to_f32(bits)).collect();

    let mut result = Vec::with_capacity(k * n);
    for out_row in 0..k {
        for out_col in 0..n {
            // Walk down column `out_row` of A and column `out_col` of grad_C.
            let dot = (0..m).fold(0.0f32, |acc, step| {
                a_wide[step * k + out_row].mul_add(grad_c_wide[step * n + out_col], acc)
            });
            result.push(f32_to_f16(dot));
        }
    }
    result
}
fn parse_bw_results(stdout: &str) -> Result<Vec<u16>> {
let line = stdout
.lines()
.find_map(|line| line.strip_prefix("RESULTS="))
.ok_or_else(|| Error::backend("HIP fp16 GEMM BW did not print RESULTS marker"))?;
if line.trim().is_empty() {
return Ok(Vec::new());
}
line.split_whitespace()
.map(|value| {
value.trim().parse::<u16>().map_err(|err| {
Error::backend(format!(
"invalid HIP fp16 GEMM BW output value {value:?}: {err}"
))
})
})
.collect()
}
/// Returns the `f32` value that follows `prefix` on the first line of `stdout`
/// starting with that prefix.
///
/// Surrounding whitespace in the value is ignored. Returns `None` when no line
/// has the prefix, or when the first matching line does not parse as `f32`
/// (later matching lines are not consulted).
fn parse_bw_f32_line(stdout: &str, prefix: &str) -> Option<f32> {
    for candidate in stdout.lines() {
        if let Some(raw) = candidate.strip_prefix(prefix) {
            return raw.trim().parse::<f32>().ok();
        }
    }
    None
}
/// Builds a fingerprint string of the form `<label>-<16 hex digits>`.
///
/// The hex digits are the 64-bit `DefaultHasher` digest of `label` followed by
/// `value`. The standard library does not specify `DefaultHasher`'s algorithm
/// across releases, so fingerprints are only comparable between builds made
/// with the same toolchain.
fn fingerprint(label: &str, value: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        // Hashing `&&str` forwards to the `str` implementation, so feeding the
        // parts through a loop produces the same digest as hashing each
        // directly, in the same order.
        for part in [label, value].iter() {
            part.hash(&mut state);
        }
        state.finish()
    };
    format!("{label}-{:016x}", digest)
}