use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use crate::backend::hip_dense::{
hipcc_compile_executable, hipcc_compiler_fingerprint, hipcc_recheck_artifact,
};
use crate::backend::kernel_server;
use crate::backend::rocm::{RocmHipCapabilityReport, detect_local_rocm_hip};
use crate::{Error, Result};
/// Backend name string for the ROCm/HIP LayerNorm forward pilot.
/// NOTE(review): the consumers of this name are outside the visible part of the file.
pub const ROCM_HIP_LAYERNORM_FWD_BACKEND: &str = "rocm_hip_layernorm_fwd_pilot";
/// Lowering identifier string for the LayerNorm forward path
/// (presumably "fp16 storage, fp32 accumulation" — confirm against the lowering registry).
pub const ROCM_HIP_LAYERNORM_FWD_LOWERING_ID: &str = "hip.layernorm.fp16_f32.fwd";
/// Backend name string for the ROCm/HIP LayerNorm backward pilot.
pub const ROCM_HIP_LAYERNORM_BWD_BACKEND: &str = "rocm_hip_layernorm_bwd_pilot";
/// Lowering identifier string for the LayerNorm backward path.
pub const ROCM_HIP_LAYERNORM_BWD_LOWERING_ID: &str = "hip.layernorm.fp16_f32.bwd";
// Kernel-type tag shared by the forward and backward paths (both live in one
// HIP source). NOTE(review): its use is not visible in this chunk — presumably
// a cache / kernel-server key; confirm.
const LAYERNORM_FWD_BWD_KERNEL_TYPE: &str = "hip-layernorm-fwd-bwd";
pub const HIP_LAYERNORM_FWD_BWD_KERNEL: &str = r#"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
// Sums `v` over the lanes of one warp with shuffle-down steps of
// 16, 8, 4, 2 and 1. Written for a warp size of 32; the complete sum
// is the value returned on lane 0.
__device__ __forceinline__ float warp_reduce_sum(float v) {
    // ROCm 7.x static_asserts inside __shfl_down_sync that the mask is a
    // 64-bit integer, so a plain 32-bit unsigned mask does not compile.
    const unsigned long long full_mask = 0xffffffffULL;
    float acc = v;
#pragma unroll
    for (int stride = 16; stride >= 1; stride /= 2) {
        acc += __shfl_down_sync(full_mask, acc, stride);
    }
    return acc;
}
// Block-level fp32 sum reduction. Each warp reduces its own values and
// lane 0 parks the warp total in `shared[warp]`; the first warp then
// reduces those per-warp totals and publishes the block total in
// `shared[0]`. Assumes block size is a multiple of 32 and at most 1024
// threads (so at most 32 warps), i.e. `shared` must have room for 32 floats.
//
// Every thread of the block must call this function (it contains
// __syncthreads), and every thread receives the same block total.
//
// The function is safe to call repeatedly with the same `shared` scratch:
// the result is copied into a register and a trailing barrier guarantees
// that all threads have finished reading `shared[0]` before any thread can
// start overwriting the scratch in a later call.
__device__ __forceinline__ float block_reduce_sum(float v, float* shared) {
    int lane = threadIdx.x & 31;
    int warp = threadIdx.x >> 5;
    int n_warps = (blockDim.x + 31) >> 5;
    v = warp_reduce_sum(v);
    if (lane == 0) {
        shared[warp] = v;
    }
    __syncthreads();
    // Only the first n_warps threads (all in warp 0) pick up a per-warp total.
    float total = (threadIdx.x < n_warps) ? shared[lane] : 0.0f;
    __syncthreads();
    if (warp == 0) {
        total = warp_reduce_sum(total);
        if (lane == 0) {
            shared[0] = total;
        }
    }
    __syncthreads();
    float result = shared[0];
    // Without this barrier a fast warp could re-enter block_reduce_sum and
    // write shared[0] while a slower warp has not yet read the result above.
    __syncthreads();
    return result;
}
// LayerNorm forward: y = (x - mean) * rstd * gamma + beta.
// grid = (n_rows,), block = (BLOCK_SIZE,). One block per row.
// Reductions stay in fp32; only the final write is cast back to fp16.
// Besides the normalized output, the kernel stores the per-row mean and
// rstd (both fp32) so they can be reused later.
// `n_rows` is not read inside the body: the row index comes straight from
// blockIdx.x, so the launch grid must be exactly n_rows blocks.
__global__ void layernorm_fwd_fp16_f32_kernel(
const __half* input, // [n_rows, n_cols]
const __half* gamma, // [n_cols]
const __half* beta, // [n_cols]
__half* output, // [n_rows, n_cols]
float* mean, // [n_rows]
float* rstd, // [n_rows]
int n_rows,
int n_cols,
float eps,
float inv_n_cols) {
int row = blockIdx.x;
int tid = threadIdx.x;
int block_size = blockDim.x;
// 64-bit row offset so row * n_cols cannot overflow int.
long long row_off = static_cast<long long>(row) * n_cols;
const __half* row_in = input + row_off;
__half* row_out = output + row_off;
// Scratch for block_reduce_sum (one slot per warp, at most 32 warps).
// NOTE(review): both reductions below reuse this scratch back to back with
// no barrier in between at this level; confirm block_reduce_sum is safe
// for that reuse.
__shared__ float shared[32];
// Pass 1a: sum x.
float local_sum = 0.0f;
for (int i = tid; i < n_cols; i += block_size) {
local_sum += __half2float(row_in[i]);
}
float total_sum = block_reduce_sum(local_sum, shared);
// The caller passes inv_n_cols so the kernel multiplies instead of dividing.
float row_mean = total_sum * inv_n_cols;
// Pass 1b: sum (x - mean)^2.
float local_sum_sq = 0.0f;
for (int i = tid; i < n_cols; i += block_size) {
float v = __half2float(row_in[i]);
float d = v - row_mean;
local_sum_sq += d * d;
}
float total_sum_sq = block_reduce_sum(local_sum_sq, shared);
// Variance is the sum of squared deviations scaled by inv_n_cols.
float row_var = total_sum_sq * inv_n_cols;
float row_rstd = rsqrtf(row_var + eps);
// One thread per row publishes the statistics.
if (tid == 0) {
mean[row] = row_mean;
rstd[row] = row_rstd;
}
// Pass 2: write normalized output with affine transform.
for (int i = tid; i < n_cols; i += block_size) {
float v = __half2float(row_in[i]);
float g = __half2float(gamma[i]);
float b = __half2float(beta[i]);
float normalized = (v - row_mean) * row_rstd;
row_out[i] = __float2half_rn(normalized * g + b);
}
}
// LayerNorm backward: standard two-pass design for grad_input.
// dxhat = grad_output * gamma
// dx = (1/n_cols) * rstd * (n_cols * dxhat - sum_dxhat - xhat * sum_dxhat_xhat)
// grad_gamma and grad_beta are computed in a SEPARATE kernel
// (`layernorm_bw_grad_kernel` below) that uses one block per column with a
// deterministic block-level fp32 reduction. Keeping these in two kernels
// means grad_input is still produced row-by-row, and grad_gamma / grad_beta
// are produced column-by-column with no non-deterministic atomicAdd ordering
// (which previously caused ~1 ULP fp16 round-trip differences).
// grid = (n_rows,), block = (BLOCK_SIZE,). One block per row.
// Note that this kernel takes mean and rstd as fp16 and widens them to fp32
// on load. `n_rows` is not read inside the body: the row index comes from
// blockIdx.x, so the launch grid must be exactly n_rows blocks.
__global__ void layernorm_bw_fp16_f32_kernel(
const __half* grad_output, // [n_rows, n_cols]
const __half* input, // [n_rows, n_cols]
const __half* gamma, // [n_cols]
const __half* mean, // [n_rows]
const __half* rstd, // [n_rows]
__half* grad_input, // [n_rows, n_cols]
int n_rows,
int n_cols,
float inv_n_cols) {
int row = blockIdx.x;
int tid = threadIdx.x;
int block_size = blockDim.x;
// 64-bit row offset so row * n_cols cannot overflow int.
long long row_off = static_cast<long long>(row) * n_cols;
const __half* row_dy = grad_output + row_off;
const __half* row_x = input + row_off;
__half* row_dx = grad_input + row_off;
float row_mean = __half2float(mean[row]);
float row_rstd = __half2float(rstd[row]);
float n_cols_f = static_cast<float>(n_cols);
// Each of the two reductions gets its own scratch array, so the second
// reduction never touches the slots the first one is still publishing.
__shared__ float shared_dxhat[32];
__shared__ float shared_dxhat_xhat[32];
// Pass 1: per-element dxhat = grad_output * gamma, xhat = (x - mean) * rstd.
// * accumulate local sum_dxhat and sum_dxhat_xhat for the block reduction
// * grad_gamma / grad_beta contributions are computed in a separate
// column-major kernel below
float local_sum_dxhat = 0.0f;
float local_sum_dxhat_xhat = 0.0f;
for (int i = tid; i < n_cols; i += block_size) {
float dy = __half2float(row_dy[i]);
float x = __half2float(row_x[i]);
float g = __half2float(gamma[i]);
float xhat = (x - row_mean) * row_rstd;
float dxhat = dy * g;
local_sum_dxhat += dxhat;
local_sum_dxhat_xhat += dxhat * xhat;
}
float sum_dxhat = block_reduce_sum(local_sum_dxhat, shared_dxhat);
float sum_dxhat_xhat = block_reduce_sum(local_sum_dxhat_xhat, shared_dxhat_xhat);
// Pass 2: per-element grad_input. xhat and dxhat are recomputed from the
// inputs rather than cached from pass 1.
for (int i = tid; i < n_cols; i += block_size) {
float dy = __half2float(row_dy[i]);
float x = __half2float(row_x[i]);
float g = __half2float(gamma[i]);
float xhat = (x - row_mean) * row_rstd;
float dxhat = dy * g;
float dx = inv_n_cols * row_rstd *
(n_cols_f * dxhat - sum_dxhat - xhat * sum_dxhat_xhat);
row_dx[i] = __float2half_rn(dx);
}
}
// LayerNorm backward: grad_gamma and grad_beta via column-parallel reduction.
// dgamma[col] = sum_r xhat[r,col] * grad_output[r,col]
// dbeta[col] = sum_r grad_output[r,col]
//
// grid = (n_cols,). One block per column. Each thread accumulates the
// contribution from a subset of rows for that column, the block performs a
// deterministic fp32 warp + shared-memory reduction, and thread 0 writes
// the per-column sum directly to global fp32 grad_gamma / grad_beta. There
// is no atomicAdd at all, so the result is fully deterministic (up to the
// block reduction order) and matches the CPU fp32 oracle within ~1 fp16 ULP.
//
// The previous design had grid = (n_rows,) with each thread doing one
// atomicAdd per (row, column) pair. With 128 blocks contributing to each
// column, the non-deterministic atomicAdd order could produce fp32 results
// that, when cast to fp16, differ from the CPU oracle by 1 fp16 ULP. For
// large-magnitude inputs that 1 ULP easily exceeds 1e-2 absolute tolerance.
//
// mean and rstd are read as fp32 (not fp16) so the per-row xhat matches the
// CPU oracle's xhat exactly. The fp16 round-trip used in the previous
// design introduced ~1e-3 relative error in xhat, which then propagated
// into the per-column sum and routinely crossed the 1e-2 absolute
// tolerance for inputs in [-3, 3].
//
// block = (BLOCK_SIZE_GRAD,) where BLOCK_SIZE_GRAD is chosen so each thread
// accumulates ~1 row (capped at 1024).
//
// `gamma` is part of the signature but is never read in the body: neither
// dgamma nor dbeta depends on it.
__global__ void layernorm_bw_grad_kernel(
const __half* grad_output, // [n_rows, n_cols]
const __half* input, // [n_rows, n_cols]
const __half* gamma, // [n_cols]
const float* mean, // [n_rows] (fp32, NOT fp16)
const float* rstd, // [n_rows] (fp32, NOT fp16)
float* grad_gamma, // [n_cols] (fp32, no atomicAdd needed)
float* grad_beta, // [n_cols] (fp32, no atomicAdd needed)
int n_rows,
int n_cols) {
int col = blockIdx.x;
int tid = threadIdx.x;
int block_size = blockDim.x;
// Each thread walks every block_size-th row for this column, accumulating
// a thread-local fp32 sum.
float local_sum_gamma = 0.0f;
float local_sum_beta = 0.0f;
for (int r = tid; r < n_rows; r += block_size) {
// Element index of (r, col) in the row-major matrices, computed in 64 bits.
long long row_off = static_cast<long long>(r) * n_cols + col;
float dy = __half2float(grad_output[row_off]);
float x = __half2float(input[row_off]);
float m = mean[r];
float rs = rstd[r];
float xhat = (x - m) * rs;
local_sum_gamma += xhat * dy;
local_sum_beta += dy;
}
// Block-level fp32 reduction. Reuse a single 32-slot shared scratch for
// both reductions (the second block_reduce_sum is preceded by a
// __syncthreads to keep the scratch safe to overwrite).
__shared__ float s_sum[32];
float block_gamma = block_reduce_sum(local_sum_gamma, s_sum);
__syncthreads();
float block_beta = block_reduce_sum(local_sum_beta, s_sum);
// One thread per block writes the per-column fp32 sum. Each block owns a
// unique column, so no atomicAdd / race is possible.
if (tid == 0) {
grad_gamma[col] = block_gamma;
grad_beta[col] = block_beta;
}
}
// Guards a HIP runtime call. Returns normally when `status` is hipSuccess;
// otherwise writes "HIP_ERROR <label>=<HIP error string>" to std::cerr and
// terminates the process with exit status 10.
static void check(hipError_t status, const char* label) {
    if (status == hipSuccess) {
        return;
    }
    std::cerr << "HIP_ERROR " << label << "=" << hipGetErrorString(status) << "\n";
    std::exit(10);
}
// Writes one "<label><v0>,<v1>,...\n" line to std::cout. `label` is expected
// to already end in `=`; an empty vector produces just the label and newline.
static void print_u16(const char* label, const std::vector<uint16_t>& values) {
    std::cout << label;
    const char* separator = "";
    for (uint16_t bits : values) {
        // The explicit widening to unsigned int keeps every value formatted
        // as a plain decimal number.
        std::cout << separator << static_cast<unsigned int>(bits);
        separator = ",";
    }
    std::cout << "\n";
}
// Writes one "<label><v0>,<v1>,...\n" line of fp32 values to std::cout.
// `label` is expected to already end in `=`.
//
// Values are printed with 9 significant digits (max_digits10 for IEEE-754
// binary32), which is the minimum that lets a reader parse back exactly the
// same float. The stream's default of 6 digits would silently round MEAN /
// RSTD before they reach the host. The caller's precision setting is
// restored before returning so other output is unaffected.
static void print_f32(const char* label, const std::vector<float>& values) {
    std::cout << label;
    const std::streamsize saved_precision = std::cout.precision(9);
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (i != 0) {
            std::cout << ",";
        }
        std::cout << values[i];
    }
    std::cout.precision(saved_precision);
    std::cout << "\n";
}
// Forward declaration of the existing main() body, extracted into
// a static helper so the server-mode loop can call it on each
// request. The default `main()` also routes through this helper so
// the one-shot and server code paths share the same compute logic.
// The helper reads one request from std::cin, writes its result lines to
// std::cout and its diagnostics to std::cerr, and returns 0 on success or a
// non-zero status when the request is rejected.
static int run_one_shot_from_main_body();
// Stdin layout (whitespace-separated tokens, consumed in exactly this order):
// OP N_ROWS N_COLS BLOCK_SIZE EPS
// <bwd-only: GRAD_BLOCK_SIZE>
// gamma_bits[N_COLS] beta_bits[N_COLS]
// <fwd-only: input_bits[N_ROWS*N_COLS]>
// <bwd-only: grad_output_bits[N_ROWS*N_COLS] input_bits[N_ROWS*N_COLS]
// mean_bits[N_ROWS] rstd_bits[N_ROWS]
// mean_f32[N_ROWS] rstd_f32[N_ROWS]>
// OP=1 selects forward, OP=2 selects backward.
// GRAD_BLOCK_SIZE is only consumed for OP=2; it controls the launch shape of
// the column-parallel `layernorm_bw_grad_kernel`. For OP=1 it must be
// omitted: the token that follows EPS is parsed as gamma_bits[0].
// Persistent server-mode protocol (see hip_gemm_f16.rs for the full
// design rationale). The host writes a little-endian u32 payload_len
// followed by `payload_len` bytes of the existing text payload, then
// reads back a little-endian u32 response_len followed by
// `response_len` bytes of the existing text response.
//
// Each request is served by temporarily pointing std::cin / std::cout /
// std::cerr at in-memory buffers and running the one-shot body unchanged.
// On a non-zero status the captured stderr text is appended to the response,
// the response is still framed and flushed, and the loop then ends with that
// status.
//
// Return values: 0 on clean EOF (no bytes available where a length prefix is
// expected), 20 on a truncated length prefix, 21 on a truncated payload,
// otherwise the first non-zero status from the one-shot body.
//
// NOTE(review): check() calls std::exit() on a HIP failure, which bypasses
// the response framing below; the host then sees the pipe close instead of
// an error frame.
static int run_server_mode() {
    while (true) {
        uint32_t payload_len = 0;
        std::cin.read(reinterpret_cast<char*>(&payload_len), 4);
        // Classify by byte count, not stream state: read() sets failbit on
        // ANY short read, so testing `!std::cin` first would report a
        // truncated 1-3 byte prefix as a clean EOF.
        const std::streamsize header_bytes = std::cin.gcount();
        if (header_bytes == 0) {
            return 0;  // clean EOF
        }
        if (header_bytes != 4) {
            std::cerr << "server_mode: short read on payload_len (got "
                      << header_bytes << " bytes)\n";
            return 20;
        }
        std::vector<char> payload(payload_len);
        if (payload_len > 0) {
            std::cin.read(payload.data(), payload_len);
            if (static_cast<uint32_t>(std::cin.gcount()) != payload_len) {
                std::cerr << "server_mode: short read on payload (got "
                          << std::cin.gcount() << " of " << payload_len << ")\n";
                return 21;
            }
        }
        // Redirect the three standard streams to in-memory buffers for the
        // duration of one request.
        std::string payload_str(payload.begin(), payload.end());
        std::istringstream fake_stdin(payload_str);
        std::streambuf* old_buf = std::cin.rdbuf(fake_stdin.rdbuf());
        std::ostringstream captured;
        std::streambuf* old_cout = std::cout.rdbuf(captured.rdbuf());
        std::ostringstream captured_err;
        std::streambuf* old_cerr = std::cerr.rdbuf(captured_err.rdbuf());
        int rc = run_one_shot_from_main_body();
        // Restore the real streams before writing the framed response.
        std::cin.rdbuf(old_buf);
        std::cout.rdbuf(old_cout);
        std::cerr.rdbuf(old_cerr);
        std::string response = captured.str();
        if (rc != 0) {
            std::string err_str = captured_err.str();
            response += err_str;
        }
        uint32_t response_len = static_cast<uint32_t>(response.size());
        std::cout.write(reinterpret_cast<const char*>(&response_len), 4);
        if (response_len > 0) {
            std::cout.write(response.data(), response_len);
        }
        std::cout.flush();
        if (rc != 0) {
            return rc;
        }
    }
}
// Entry point. A first argument of exactly "--server" selects the persistent
// length-prefixed request loop; any other invocation handles a single
// request read from stdin. The chosen path's status is the exit status.
int main(int argc, char** argv) {
    const bool server_requested = argc > 1 && std::string(argv[1]) == "--server";
    if (!server_requested) {
        return run_one_shot_from_main_body();
    }
    return run_server_mode();
}
// Fills `values` with `values.size()` whitespace-separated tokens read from
// std::cin. On the first token that cannot be parsed, writes
// "failed to read <what> element <index>" to std::cerr and returns false.
template <typename T>
static bool read_values(std::vector<T>& values, const char* what) {
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (!(std::cin >> values[i])) {
            std::cerr << "failed to read " << what << " element " << i << "\n";
            return false;
        }
    }
    return true;
}

// Serves one LayerNorm request read from std::cin and writes the result as
// KEY=VALUE lines to std::cout.
//
// Token order on stdin:
//   OP N_ROWS N_COLS BLOCK_SIZE EPS [GRAD_BLOCK_SIZE, only when OP=2]
//   gamma_bits[N_COLS] beta_bits[N_COLS]
//   OP=1: input_bits[N_ROWS*N_COLS]
//   OP=2: grad_output_bits[N_ROWS*N_COLS] input_bits[N_ROWS*N_COLS]
//         mean_bits[N_ROWS] rstd_bits[N_ROWS] mean_f32[N_ROWS] rstd_f32[N_ROWS]
//
// Returns 0 on success. Non-zero returns (diagnostic on std::cerr):
//   2  header could not be parsed        3  OP is neither 1 nor 2
//   4  non-positive dimension/block size 5  invalid BLOCK_SIZE
//   6  missing GRAD_BLOCK_SIZE, or gamma read failure
//   7  invalid GRAD_BLOCK_SIZE, or beta read failure
//   8  forward input read failure
//   9..14  backward grad_output / input / mean / rstd / mean_f32 / rstd_f32
//          read failure, in that order
// Any failing HIP call terminates the process through check() (exit 10).
// Device buffers already allocated are not freed on the early-return paths;
// both callers end the process on a non-zero status.
static int run_one_shot_from_main_body() {
    int op = 0;
    int n_rows = 0;
    int n_cols = 0;
    int block_size = 0;
    float eps = 0.0f;
    int grad_block_size = 0;
    if (!(std::cin >> op >> n_rows >> n_cols >> block_size >> eps)) {
        // The usage text mirrors the order in which tokens are consumed below.
        std::cerr << "usage: stdin payload is \"OP N_ROWS N_COLS BLOCK_SIZE EPS [GRAD_BLOCK_SIZE if OP=2]\\n"
                     "GAMMA[N_COLS] BETA[N_COLS] <fwd: INPUT[N_ROWS*N_COLS]>\\n"
                     "<bwd: GRAD_OUTPUT[N_ROWS*N_COLS] INPUT[N_ROWS*N_COLS] MEAN[N_ROWS] RSTD[N_ROWS] MEAN_F32[N_ROWS] RSTD_F32[N_ROWS]>\\n\"\n";
        return 2;
    }
    if (op != 1 && op != 2) {
        std::cerr << "OP must be 1 (forward) or 2 (backward)\n";
        return 3;
    }
    if (n_rows <= 0 || n_cols <= 0 || block_size <= 0) {
        std::cerr << "N_ROWS N_COLS BLOCK_SIZE must all be positive\n";
        return 4;
    }
    // The block reductions in the kernels require whole warps of 32 and at
    // most 32 warps per block.
    if (block_size % 32 != 0 || block_size > 1024) {
        std::cerr << "BLOCK_SIZE=" << block_size
                  << " must be a positive multiple of 32 up to 1024\n";
        return 5;
    }
    if (op == 2) {
        if (!(std::cin >> grad_block_size)) {
            std::cerr << "expected GRAD_BLOCK_SIZE for OP=2\n";
            return 6;
        }
        if (grad_block_size <= 0 || grad_block_size % 32 != 0 || grad_block_size > 1024) {
            std::cerr << "GRAD_BLOCK_SIZE=" << grad_block_size
                      << " must be a positive multiple of 32 up to 1024\n";
            return 7;
        }
    } else {
        grad_block_size = 32; // unused in forward
    }
    float inv_n_cols = 1.0f / static_cast<float>(n_cols);
    std::size_t row_count = static_cast<std::size_t>(n_rows) *
                            static_cast<std::size_t>(n_cols);
    std::size_t col_count = static_cast<std::size_t>(n_cols);
    std::size_t col_bytes = col_count * sizeof(__half);

    // Common buffers: gamma and beta (fp16 bit patterns).
    std::vector<uint16_t> gamma_bits(col_count);
    std::vector<uint16_t> beta_bits(col_count);
    if (!read_values(gamma_bits, "gamma")) {
        return 6;
    }
    if (!read_values(beta_bits, "beta")) {
        return 7;
    }

    int device = 0;
    check(hipSetDevice(device), "hipSetDevice");
    hipDeviceProp_t props;
    check(hipGetDeviceProperties(&props, device), "hipGetDeviceProperties");
    __half* d_gamma = nullptr;
    __half* d_beta = nullptr;
    check(hipMalloc(&d_gamma, col_bytes), "hipMalloc(gamma)");
    check(hipMalloc(&d_beta, col_bytes), "hipMalloc(beta)");
    check(hipMemcpy(d_gamma, gamma_bits.data(), col_bytes, hipMemcpyHostToDevice),
          "hipMemcpy(gamma)");
    check(hipMemcpy(d_beta, beta_bits.data(), col_bytes, hipMemcpyHostToDevice),
          "hipMemcpy(beta)");

    if (op == 1) {
        // Forward payload: input_bits[N_ROWS*N_COLS].
        std::vector<uint16_t> input_bits(row_count);
        if (!read_values(input_bits, "input")) {
            return 8;
        }
        std::vector<uint16_t> output_bits(row_count);
        std::vector<float> mean_out(n_rows);
        std::vector<float> rstd_out(n_rows);
        __half* d_input = nullptr;
        __half* d_output = nullptr;
        float* d_mean = nullptr;
        float* d_rstd = nullptr;
        check(hipMalloc(&d_input, row_count * sizeof(__half)), "hipMalloc(input)");
        check(hipMalloc(&d_output, row_count * sizeof(__half)), "hipMalloc(output)");
        check(hipMalloc(&d_mean, static_cast<std::size_t>(n_rows) * sizeof(float)),
              "hipMalloc(mean)");
        check(hipMalloc(&d_rstd, static_cast<std::size_t>(n_rows) * sizeof(float)),
              "hipMalloc(rstd)");
        check(hipMemcpy(d_input, input_bits.data(), row_count * sizeof(__half),
                        hipMemcpyHostToDevice), "hipMemcpy(input)");
        // One block per row.
        dim3 grid(n_rows);
        dim3 block(block_size);
        hipEvent_t start;
        hipEvent_t stop;
        check(hipEventCreate(&start), "hipEventCreate(start)");
        check(hipEventCreate(&stop), "hipEventCreate(stop)");
        check(hipEventRecord(start), "hipEventRecord(start)");
        hipLaunchKernelGGL(layernorm_fwd_fp16_f32_kernel, grid, block, 0, 0,
                           d_input, d_gamma, d_beta, d_output, d_mean, d_rstd,
                           n_rows, n_cols, eps, inv_n_cols);
        check(hipGetLastError(), "hipLaunchKernelGGL(fwd)");
        check(hipEventRecord(stop), "hipEventRecord(stop)");
        check(hipEventSynchronize(stop), "hipEventSynchronize");
        float kernel_time_ms = 0.0f;
        check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
        check(hipEventDestroy(start), "hipEventDestroy(start)");
        check(hipEventDestroy(stop), "hipEventDestroy(stop)");
        check(hipMemcpy(output_bits.data(), d_output, row_count * sizeof(__half),
                        hipMemcpyDeviceToHost), "hipMemcpy(output)");
        check(hipMemcpy(mean_out.data(), d_mean,
                        static_cast<std::size_t>(n_rows) * sizeof(float),
                        hipMemcpyDeviceToHost), "hipMemcpy(mean)");
        check(hipMemcpy(rstd_out.data(), d_rstd,
                        static_cast<std::size_t>(n_rows) * sizeof(float),
                        hipMemcpyDeviceToHost), "hipMemcpy(rstd)");
        check(hipFree(d_input), "hipFree(input)");
        check(hipFree(d_output), "hipFree(output)");
        check(hipFree(d_mean), "hipFree(mean)");
        check(hipFree(d_rstd), "hipFree(rstd)");
        check(hipFree(d_gamma), "hipFree(gamma)");
        check(hipFree(d_beta), "hipFree(beta)");
        std::cout << "OP=1\n";
        std::cout << "DEVICE_NAME=" << props.name << "\n";
        std::cout << "GFX=" << props.gcnArchName << "\n";
        std::cout << "N_ROWS=" << n_rows << "\n";
        std::cout << "N_COLS=" << n_cols << "\n";
        std::cout << "BLOCK_SIZE=" << block_size << "\n";
        std::cout << "GRID_X=" << grid.x << "\n";
        std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
        print_u16("OUTPUT=", output_bits);
        print_f32("MEAN=", mean_out);
        print_f32("RSTD=", rstd_out);
        return 0;
    }

    // Backward payload. Every array arrives as its own contiguous batch
    // (all of grad_output, then all of input, ...), so each is read with a
    // separate pass rather than interleaved.
    std::vector<uint16_t> grad_output_bits(row_count);
    std::vector<uint16_t> input_bits(row_count);
    if (!read_values(grad_output_bits, "grad_output")) {
        return 9;
    }
    if (!read_values(input_bits, "input")) {
        return 10;
    }
    std::vector<uint16_t> mean_bits(n_rows);
    std::vector<uint16_t> rstd_bits(n_rows);
    if (!read_values(mean_bits, "mean")) {
        return 11;
    }
    if (!read_values(rstd_bits, "rstd")) {
        return 12;
    }
    // The grad kernel reads mean and rstd as fp32 (not fp16) so the per-row
    // xhat computation matches the CPU oracle exactly. The bwd input kernel
    // still reads them as fp16 (its precision budget is ~1e-3 which fits
    // grad_input's 1e-2 tolerance); only the grad_gamma / grad_beta path
    // needs the fp32 values. They follow the fp16 batches on stdin so the
    // host can ship the original fp32 values without an fp16 round-trip.
    std::vector<float> mean_f32(n_rows);
    std::vector<float> rstd_f32(n_rows);
    if (!read_values(mean_f32, "mean_f32")) {
        return 13;
    }
    if (!read_values(rstd_f32, "rstd_f32")) {
        return 14;
    }

    std::vector<uint16_t> grad_input_bits(row_count);
    std::vector<uint16_t> grad_gamma_bits(col_count, 0);
    std::vector<uint16_t> grad_beta_bits(col_count, 0);
    // The kernel accumulates into fp32 buffers; we convert on the host side.
    std::vector<float> grad_gamma_f32(col_count, 0.0f);
    std::vector<float> grad_beta_f32(col_count, 0.0f);
    __half* d_grad_output = nullptr;
    __half* d_input = nullptr;
    __half* d_mean = nullptr;
    __half* d_rstd = nullptr;
    __half* d_grad_input = nullptr;
    float* d_mean_f32 = nullptr;
    float* d_rstd_f32 = nullptr;
    float* d_grad_gamma_f32 = nullptr;
    float* d_grad_beta_f32 = nullptr;
    check(hipMalloc(&d_grad_output, row_count * sizeof(__half)), "hipMalloc(grad_output)");
    check(hipMalloc(&d_input, row_count * sizeof(__half)), "hipMalloc(input)");
    check(hipMalloc(&d_mean, static_cast<std::size_t>(n_rows) * sizeof(__half)),
          "hipMalloc(mean)");
    check(hipMalloc(&d_rstd, static_cast<std::size_t>(n_rows) * sizeof(__half)),
          "hipMalloc(rstd)");
    check(hipMalloc(&d_grad_input, row_count * sizeof(__half)), "hipMalloc(grad_input)");
    check(hipMalloc(&d_mean_f32, static_cast<std::size_t>(n_rows) * sizeof(float)),
          "hipMalloc(mean_f32)");
    check(hipMalloc(&d_rstd_f32, static_cast<std::size_t>(n_rows) * sizeof(float)),
          "hipMalloc(rstd_f32)");
    check(hipMalloc(&d_grad_gamma_f32, col_count * sizeof(float)),
          "hipMalloc(grad_gamma_f32)");
    check(hipMalloc(&d_grad_beta_f32, col_count * sizeof(float)),
          "hipMalloc(grad_beta_f32)");
    check(hipMemcpy(d_grad_output, grad_output_bits.data(), row_count * sizeof(__half),
                    hipMemcpyHostToDevice), "hipMemcpy(grad_output)");
    check(hipMemcpy(d_input, input_bits.data(), row_count * sizeof(__half),
                    hipMemcpyHostToDevice), "hipMemcpy(input)");
    check(hipMemcpy(d_mean, mean_bits.data(),
                    static_cast<std::size_t>(n_rows) * sizeof(__half),
                    hipMemcpyHostToDevice), "hipMemcpy(mean)");
    check(hipMemcpy(d_rstd, rstd_bits.data(),
                    static_cast<std::size_t>(n_rows) * sizeof(__half),
                    hipMemcpyHostToDevice), "hipMemcpy(rstd)");
    check(hipMemcpy(d_mean_f32, mean_f32.data(),
                    static_cast<std::size_t>(n_rows) * sizeof(float),
                    hipMemcpyHostToDevice), "hipMemcpy(mean_f32)");
    check(hipMemcpy(d_rstd_f32, rstd_f32.data(),
                    static_cast<std::size_t>(n_rows) * sizeof(float),
                    hipMemcpyHostToDevice), "hipMemcpy(rstd_f32)");
    check(hipMemset(d_grad_gamma_f32, 0, col_count * sizeof(float)),
          "hipMemset(grad_gamma_f32)");
    check(hipMemset(d_grad_beta_f32, 0, col_count * sizeof(float)),
          "hipMemset(grad_beta_f32)");
    dim3 grid(n_rows);
    dim3 block(block_size);
    hipEvent_t start;
    hipEvent_t stop;
    check(hipEventCreate(&start), "hipEventCreate(start)");
    check(hipEventCreate(&stop), "hipEventCreate(stop)");
    check(hipEventRecord(start), "hipEventRecord(start)");
    // Pass 1: grad_input (one block per row).
    hipLaunchKernelGGL(layernorm_bw_fp16_f32_kernel, grid, block, 0, 0,
                       d_grad_output, d_input, d_gamma, d_mean, d_rstd,
                       d_grad_input,
                       n_rows, n_cols, inv_n_cols);
    check(hipGetLastError(), "hipLaunchKernelGGL(bwd)");
    // Pass 2: grad_gamma, grad_beta (one block per column, deterministic
    // fp32 block reduction, no atomicAdd). Each block owns a unique column
    // and overwrites its slot, so the memset above is not logically
    // required; it only keeps the buffers defined if the launch aborts
    // before any block runs.
    dim3 grad_grid(n_cols);
    dim3 grad_block(grad_block_size);
    hipLaunchKernelGGL(layernorm_bw_grad_kernel, grad_grid, grad_block, 0, 0,
                       d_grad_output, d_input, d_gamma, d_mean_f32, d_rstd_f32,
                       d_grad_gamma_f32, d_grad_beta_f32,
                       n_rows, n_cols);
    check(hipGetLastError(), "hipLaunchKernelGGL(bw_grad)");
    // The reported time spans both kernel launches.
    check(hipEventRecord(stop), "hipEventRecord(stop)");
    check(hipEventSynchronize(stop), "hipEventSynchronize");
    float kernel_time_ms = 0.0f;
    check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
    check(hipEventDestroy(start), "hipEventDestroy(start)");
    check(hipEventDestroy(stop), "hipEventDestroy(stop)");
    check(hipMemcpy(grad_input_bits.data(), d_grad_input, row_count * sizeof(__half),
                    hipMemcpyDeviceToHost), "hipMemcpy(grad_input)");
    check(hipMemcpy(grad_gamma_f32.data(), d_grad_gamma_f32,
                    col_count * sizeof(float), hipMemcpyDeviceToHost),
          "hipMemcpy(grad_gamma_f32)");
    check(hipMemcpy(grad_beta_f32.data(), d_grad_beta_f32,
                    col_count * sizeof(float), hipMemcpyDeviceToHost),
          "hipMemcpy(grad_beta_f32)");
    // Convert the fp32 accumulators to fp16 with round-to-nearest-even. The
    // sums run over n_rows terms, so for large n_rows they can exceed fp16's
    // range; the original author notes the test cases stay well inside it.
    //
    // The fp16 BIT PATTERN is copied out explicitly: assigning a `__half`
    // to a `uint16_t` would instead convert the numeric value.
    for (std::size_t i = 0; i < col_count; ++i) {
        __half g = __float2half_rn(grad_gamma_f32[i]);
        __half b = __float2half_rn(grad_beta_f32[i]);
        grad_gamma_bits[i] = *reinterpret_cast<uint16_t*>(&g);
        grad_beta_bits[i] = *reinterpret_cast<uint16_t*>(&b);
    }
    check(hipFree(d_grad_output), "hipFree(grad_output)");
    check(hipFree(d_input), "hipFree(input)");
    check(hipFree(d_mean), "hipFree(mean)");
    check(hipFree(d_rstd), "hipFree(rstd)");
    check(hipFree(d_grad_input), "hipFree(grad_input)");
    check(hipFree(d_mean_f32), "hipFree(mean_f32)");
    check(hipFree(d_rstd_f32), "hipFree(rstd_f32)");
    check(hipFree(d_grad_gamma_f32), "hipFree(grad_gamma_f32)");
    check(hipFree(d_grad_beta_f32), "hipFree(grad_beta_f32)");
    check(hipFree(d_gamma), "hipFree(gamma)");
    check(hipFree(d_beta), "hipFree(beta)");
    std::cout << "OP=2\n";
    std::cout << "DEVICE_NAME=" << props.name << "\n";
    std::cout << "GFX=" << props.gcnArchName << "\n";
    std::cout << "N_ROWS=" << n_rows << "\n";
    std::cout << "N_COLS=" << n_cols << "\n";
    std::cout << "BLOCK_SIZE=" << block_size << "\n";
    std::cout << "GRID_X=" << grid.x << "\n";
    std::cout << "GRAD_BLOCK_SIZE=" << grad_block_size << "\n";
    std::cout << "GRAD_GRID_X=" << grad_grid.x << "\n";
    std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
    print_u16("GRAD_INPUT=", grad_input_bits);
    print_u16("GRAD_GAMMA=", grad_gamma_bits);
    print_u16("GRAD_BETA=", grad_beta_bits);
    return 0;
}
"#;
/// Result of one HIP LayerNorm forward pilot run, bundled with the CPU oracle
/// values it was compared against and the provenance of the kernel build.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipLayernormFwdReport {
    /// Number of rows, as passed by the caller.
    pub n_rows: usize,
    /// Number of columns (the normalized dimension), as passed by the caller.
    pub n_cols: usize,
    /// Threads per block used for the launch.
    pub block_size: usize,
    /// Kernel output as 16-bit values (the HIP program emits fp16 bit patterns).
    pub output: Vec<u16>,
    /// Per-row mean reported by the kernel.
    pub mean: Vec<f32>,
    /// Per-row reciprocal standard deviation reported by the kernel.
    pub rstd: Vec<f32>,
    /// Output computed by the CPU oracle for the same inputs.
    pub cpu_oracle_output: Vec<u16>,
    /// Per-row mean computed by the CPU oracle.
    pub cpu_oracle_mean: Vec<f32>,
    /// Per-row rstd computed by the CPU oracle.
    pub cpu_oracle_rstd: Vec<f32>,
    /// Largest absolute kernel-vs-oracle difference over `output`.
    pub max_abs_error_output: f32,
    /// Largest absolute kernel-vs-oracle difference over `mean`.
    pub max_abs_error_mean: f32,
    /// Largest absolute kernel-vs-oracle difference over `rstd`.
    pub max_abs_error_rstd: f32,
    /// True when all three maximum errors are below `1e-2`.
    pub within_tolerance: bool,
    /// Kernel time in milliseconds as reported by the HIP program.
    pub kernel_time_ms: f32,
    /// Fingerprint of the kernel source, copied from the kernel run result.
    pub kernel_source_fingerprint: String,
    /// Fingerprint of the compiler, copied from the kernel run result.
    pub compiler_fingerprint: String,
    /// Build command, copied from the kernel run result.
    pub build_command: String,
    /// Path of the executable that was run, copied from the kernel run result.
    pub executable_path: String,
    /// Device capability report, copied from the kernel run result.
    pub device_evidence: RocmHipCapabilityReport,
    /// Human-readable statements of what this run did.
    pub evidence: Vec<String>,
    /// Human-readable statements of what this run does NOT demonstrate.
    pub non_claims: Vec<String>,
}
/// Result of one HIP LayerNorm backward pilot run, bundled with the CPU oracle
/// gradients it was compared against and the provenance of the kernel build.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipLayernormBwdReport {
    /// Number of rows, as passed by the caller.
    pub n_rows: usize,
    /// Number of columns (the normalized dimension), as passed by the caller.
    pub n_cols: usize,
    /// Threads per block used for the row-wise grad_input launch.
    pub block_size: usize,
    /// Gradient w.r.t. the input as 16-bit values (fp16 bit patterns emitted by the HIP program).
    pub grad_input: Vec<u16>,
    /// Gradient w.r.t. gamma, one 16-bit value per column.
    pub grad_gamma: Vec<u16>,
    /// Gradient w.r.t. beta, one 16-bit value per column.
    pub grad_beta: Vec<u16>,
    /// grad_input computed by the CPU oracle for the same inputs.
    pub cpu_oracle_grad_input: Vec<u16>,
    /// grad_gamma computed by the CPU oracle.
    pub cpu_oracle_grad_gamma: Vec<u16>,
    /// grad_beta computed by the CPU oracle.
    pub cpu_oracle_grad_beta: Vec<u16>,
    /// Largest absolute kernel-vs-oracle difference over `grad_input`.
    pub max_abs_error_grad_input: f32,
    /// Largest absolute kernel-vs-oracle difference over `grad_gamma`.
    pub max_abs_error_grad_gamma: f32,
    /// Largest absolute kernel-vs-oracle difference over `grad_beta`.
    pub max_abs_error_grad_beta: f32,
    /// True when all three maximum errors are below `1e-2`.
    pub within_tolerance: bool,
    /// Kernel time in milliseconds as reported by the HIP program.
    pub kernel_time_ms: f32,
    /// Fingerprint of the kernel source, copied from the kernel run result.
    pub kernel_source_fingerprint: String,
    /// Fingerprint of the compiler, copied from the kernel run result.
    pub compiler_fingerprint: String,
    /// Build command, copied from the kernel run result.
    pub build_command: String,
    /// Path of the executable that was run, copied from the kernel run result.
    pub executable_path: String,
    /// Device capability report, copied from the kernel run result.
    pub device_evidence: RocmHipCapabilityReport,
    /// Human-readable statements of what this run did.
    pub evidence: Vec<String>,
    /// Human-readable statements of what this run does NOT demonstrate.
    pub non_claims: Vec<String>,
}
/// Runs the HIP LayerNorm forward pilot kernel and compares it with the CPU oracle.
///
/// `input` holds `n_rows * n_cols` 16-bit values and `gamma` / `beta` hold
/// `n_cols` values each; they are forwarded unchanged to both the kernel
/// runner and the CPU oracle. The block size is derived from `n_cols`.
///
/// The returned report carries the kernel results, the oracle results, the
/// maximum absolute error for output / mean / rstd, and `within_tolerance`,
/// which is true only when all three errors are below `1e-2`. A NaN error or
/// a kernel/oracle length mismatch for mean or rstd is surfaced as a NaN or
/// infinite maximum error, so such a run is never reported as within tolerance.
///
/// # Errors
///
/// Propagates any error returned by the kernel runner.
pub fn run_rocm_hip_layernorm_fwd(
    input: &[u16],
    gamma: &[u16],
    beta: &[u16],
    n_rows: usize,
    n_cols: usize,
    eps: f32,
) -> Result<RocmHipLayernormFwdReport> {
    // Largest |gpu - cpu| over paired elements. A plain `err > max` scan
    // drops NaN (every comparison with NaN is false) and `zip` silently
    // truncates, so both cases are handled explicitly: NaN is sticky and a
    // length mismatch yields +inf.
    fn max_abs_err_f32(gpu: &[f32], cpu: &[f32]) -> f32 {
        if gpu.len() != cpu.len() {
            return f32::INFINITY;
        }
        gpu.iter()
            .zip(cpu)
            .map(|(g, c)| (g - c).abs())
            .fold(0.0f32, |worst, err| {
                if err.is_nan() || err > worst {
                    err
                } else {
                    worst
                }
            })
    }

    let block_size = layernorm_block_size(n_cols);
    let fwd = run_layernorm_fwd_kernel(input, gamma, beta, n_rows, n_cols, block_size, eps)?;
    let (cpu_output, cpu_mean, cpu_rstd) =
        cpu_layernorm_fwd(input, gamma, beta, n_rows, n_cols, eps);
    let max_abs_error_output = max_abs_err_u16(&fwd.output, &cpu_output);
    let max_abs_error_mean = max_abs_err_f32(&fwd.mean, &cpu_mean);
    let max_abs_error_rstd = max_abs_err_f32(&fwd.rstd, &cpu_rstd);
    // NaN compares false here, so a NaN error fails the tolerance check.
    let within_tolerance =
        max_abs_error_output < 1e-2 && max_abs_error_mean < 1e-2 && max_abs_error_rstd < 1e-2;
    Ok(RocmHipLayernormFwdReport {
        n_rows,
        n_cols,
        block_size,
        output: fwd.output,
        mean: fwd.mean,
        rstd: fwd.rstd,
        cpu_oracle_output: cpu_output,
        cpu_oracle_mean: cpu_mean,
        cpu_oracle_rstd: cpu_rstd,
        max_abs_error_output,
        max_abs_error_mean,
        max_abs_error_rstd,
        within_tolerance,
        kernel_time_ms: fwd.kernel_time_ms,
        kernel_source_fingerprint: fwd.kernel_source_fingerprint,
        compiler_fingerprint: fwd.compiler_fingerprint,
        build_command: fwd.build_command,
        executable_path: fwd.executable_path,
        device_evidence: fwd.device_evidence,
        evidence: vec![
            "compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101".to_string(),
            format!(
                "shipped input/gamma/beta to the kernel via stdin ({} + {} + {} bits)",
                input.len(),
                gamma.len(),
                beta.len()
            ),
            format!(
                "launched layernorm_fwd_fp16_f32_kernel with grid=({n_rows}) block=({block_size})"
            ),
            "captured kernel time with hipEventRecord/hipEventSynchronize".to_string(),
            "compared every output element, mean, and rstd against the CPU fp32 oracle within 1e-2"
                .to_string(),
        ],
        non_claims: vec![
            "not production speedup evidence".to_string(),
            "not a fused-attention / rotary / residual LayerNorm".to_string(),
            "not vectorized fp16 loads / shared-memory tiling of the row reduction".to_string(),
            "not machine-code verification".to_string(),
        ],
    })
}
pub fn run_rocm_hip_layernorm_bwd(
grad_output: &[u16],
input: &[u16],
gamma: &[u16],
mean: &[f32],
rstd: &[f32],
n_rows: usize,
n_cols: usize,
) -> Result<RocmHipLayernormBwdReport> {
let block_size = layernorm_block_size(n_cols);
let grad_block_size = layernorm_grad_block_size(n_rows);
let bwd = run_layernorm_bwd_kernel(
grad_output,
input,
gamma,
mean,
rstd,
n_rows,
n_cols,
block_size,
grad_block_size,
)?;
let (cpu_grad_input, cpu_grad_gamma, cpu_grad_beta) =
cpu_layernorm_bwd(grad_output, input, gamma, mean, rstd, n_rows, n_cols);
let max_abs_error_grad_input = max_abs_err_u16(&bwd.grad_input, &cpu_grad_input);
let max_abs_error_grad_gamma = max_abs_err_u16(&bwd.grad_gamma, &cpu_grad_gamma);
let max_abs_error_grad_beta = max_abs_err_u16(&bwd.grad_beta, &cpu_grad_beta);
let within_tolerance = max_abs_error_grad_input < 1e-2
&& max_abs_error_grad_gamma < 1e-2
&& max_abs_error_grad_beta < 1e-2;
Ok(RocmHipLayernormBwdReport {
n_rows,
n_cols,
block_size,
grad_input: bwd.grad_input,
grad_gamma: bwd.grad_gamma,
grad_beta: bwd.grad_beta,
cpu_oracle_grad_input: cpu_grad_input,
cpu_oracle_grad_gamma: cpu_grad_gamma,
cpu_oracle_grad_beta: cpu_grad_beta,
max_abs_error_grad_input,
max_abs_error_grad_gamma,
max_abs_error_grad_beta,
within_tolerance,
kernel_time_ms: bwd.kernel_time_ms,
kernel_source_fingerprint: bwd.kernel_source_fingerprint,
compiler_fingerprint: bwd.compiler_fingerprint,
build_command: bwd.build_command,
executable_path: bwd.executable_path,
device_evidence: bwd.device_evidence,
evidence: vec![
"compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101".to_string(),
format!("shipped grad_output/input/gamma/mean/rstd to the kernel via stdin ({} + {} + {} + {} + {} bits)",
grad_output.len(), input.len(), gamma.len(), mean.len(), rstd.len()),
format!("launched layernorm_bw_fp16_f32_kernel with grid=({n_rows}) block=({block_size}) for grad_input"),
format!("launched layernorm_bw_grad_kernel with grid=({n_cols}) block=({grad_block_size}) for grad_gamma/grad_beta"),
"captured kernel time with hipEventRecord/hipEventSynchronize".to_string(),
"accumulated grad_gamma and grad_beta in fp32 via a one-block-per-column deterministic warp + shared-memory block reduction (no atomicAdd across blocks) and converted to fp16 on host"
.to_string(),
"compared grad_input, grad_gamma, grad_beta against the CPU fp32 oracle within 1e-2"
.to_string(),
],
non_claims: vec![
"not production speedup evidence".to_string(),
"not a fused / online LayerNorm backward (no recomputation, no recompute-on-demand)"
.to_string(),
"not vectorized fp16 loads / shared-memory tiling of the row reduction".to_string(),
"not a single-kernel fused grad_input + grad_gamma + grad_beta launch (two separate kernels: row-major for grad_input, column-major for grad_gamma/grad_beta)"
.to_string(),
"not machine-code verification".to_string(),
],
})
}
/// Chooses the thread-block size used for the per-row LayerNorm kernels.
///
/// The result is `ceil(n_cols / 8)` rounded up to a power of two and then
/// clamped to the range `32..=1024`.
pub fn layernorm_block_size(n_cols: usize) -> usize {
    let eighths = (n_cols + 7) / 8;
    eighths.next_power_of_two().clamp(32, 1024)
}
/// Chooses the thread-block size for the per-column gradient reduction.
///
/// Returns the smallest power of two that is at least `n_rows`, clamped to
/// the range `32..=1024`.
pub fn layernorm_grad_block_size(n_rows: usize) -> usize {
    // Capping before rounding keeps `next_power_of_two` away from overflow
    // and yields the same value as doubling from 32 up to the 1024 limit.
    n_rows.min(1024).next_power_of_two().max(32)
}
/// Raw result of one forward-kernel invocation, as assembled by
/// `run_layernorm_fwd_kernel`.
struct FwdKernelResult {
/// Values parsed from the executable's `OUTPUT=` line; validated to hold
/// `n_rows * n_cols` entries.
output: Vec<u16>,
/// Values parsed from the `MEAN=` line; validated to hold `n_rows` entries.
mean: Vec<f32>,
/// Values parsed from the `RSTD=` line; validated to hold `n_rows` entries.
rstd: Vec<f32>,
/// Value parsed from the `KERNEL_TIME_MS=` line.
kernel_time_ms: f32,
/// Fingerprint of the kernel source, as returned by `compile_layernorm_kernel`.
kernel_source_fingerprint: String,
/// Compiler fingerprint, as returned by `compile_layernorm_kernel`.
compiler_fingerprint: String,
/// Build command string, as returned by `compile_layernorm_kernel`.
build_command: String,
/// Display form of the path of the executable that was run.
executable_path: String,
/// Capability report returned by `detect_local_rocm_hip`.
device_evidence: RocmHipCapabilityReport,
}
/// Raw result of one backward-kernel invocation, as assembled by
/// `run_layernorm_bwd_kernel`.
struct BwdKernelResult {
/// Values parsed from the executable's `GRAD_INPUT=` line; validated to hold
/// `n_rows * n_cols` entries.
grad_input: Vec<u16>,
/// Values parsed from the `GRAD_GAMMA=` line; validated to hold `n_cols` entries.
grad_gamma: Vec<u16>,
/// Values parsed from the `GRAD_BETA=` line; validated to hold `n_cols` entries.
grad_beta: Vec<u16>,
/// Value parsed from the `KERNEL_TIME_MS=` line.
kernel_time_ms: f32,
/// Fingerprint of the kernel source, as returned by `compile_layernorm_kernel`.
kernel_source_fingerprint: String,
/// Compiler fingerprint, as returned by `compile_layernorm_kernel`.
compiler_fingerprint: String,
/// Build command string, as returned by `compile_layernorm_kernel`.
build_command: String,
/// Display form of the path of the executable that was run.
executable_path: String,
/// Capability report returned by `detect_local_rocm_hip`.
device_evidence: RocmHipCapabilityReport,
}
fn run_layernorm_fwd_kernel(
input: &[u16],
gamma: &[u16],
beta: &[u16],
n_rows: usize,
n_cols: usize,
block_size: usize,
eps: f32,
) -> Result<FwdKernelResult> {
if n_rows == 0 || n_cols == 0 {
return Err(Error::backend("LayerNorm forward shape must be positive"));
}
if input.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"LayerNorm forward input length {} does not match n_rows*n_cols={}",
input.len(),
n_rows * n_cols
)));
}
if gamma.len() != n_cols {
return Err(Error::backend(format!(
"LayerNorm forward gamma length {} does not match n_cols={}",
gamma.len(),
n_cols
)));
}
if beta.len() != n_cols {
return Err(Error::backend(format!(
"LayerNorm forward beta length {} does not match n_cols={}",
beta.len(),
n_cols
)));
}
let beta_required = beta.to_vec();
let device_evidence = detect_local_rocm_hip();
if !device_evidence.available {
return Err(Error::backend(
"ROCm/HIP is unavailable; LayerNorm forward pilot remains inadmissible",
));
}
let (
executable_path,
build_command,
kernel_source_fingerprint,
compiler_fingerprint,
source_path,
) = compile_layernorm_kernel()?;
let mut payload = String::new();
payload.push_str(&format!("1 {n_rows} {n_cols} {block_size} {eps}\n"));
append_u16_line(&mut payload, gamma);
append_u16_line(&mut payload, &beta_required);
append_u16_line(&mut payload, input);
payload.push('\n');
let stdout = run_layernorm_executable(&executable_path, &source_path, &payload)?;
let kernel_time_ms = parse_f32_line(&stdout, "KERNEL_TIME_MS=").ok_or_else(|| {
Error::backend("HIP LayerNorm forward did not print KERNEL_TIME_MS marker")
})?;
let output = parse_u16_line(&stdout, "OUTPUT=")
.ok_or_else(|| Error::backend("HIP LayerNorm forward did not print OUTPUT marker"))?;
let mean = parse_f32_vec_line(&stdout, "MEAN=")
.ok_or_else(|| Error::backend("HIP LayerNorm forward did not print MEAN marker"))?;
let rstd = parse_f32_vec_line(&stdout, "RSTD=")
.ok_or_else(|| Error::backend("HIP LayerNorm forward did not print RSTD marker"))?;
if output.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"HIP LayerNorm forward OUTPUT length {} does not match n_rows*n_cols={}",
output.len(),
n_rows * n_cols
)));
}
if mean.len() != n_rows {
return Err(Error::backend(format!(
"HIP LayerNorm forward MEAN length {} does not match n_rows={}",
mean.len(),
n_rows
)));
}
if rstd.len() != n_rows {
return Err(Error::backend(format!(
"HIP LayerNorm forward RSTD length {} does not match n_rows={}",
rstd.len(),
n_rows
)));
}
Ok(FwdKernelResult {
output,
mean,
rstd,
kernel_time_ms,
kernel_source_fingerprint,
compiler_fingerprint,
build_command,
executable_path: executable_path.display().to_string(),
device_evidence,
})
}
fn run_layernorm_bwd_kernel(
grad_output: &[u16],
input: &[u16],
gamma: &[u16],
mean: &[f32],
rstd: &[f32],
n_rows: usize,
n_cols: usize,
block_size: usize,
grad_block_size: usize,
) -> Result<BwdKernelResult> {
if n_rows == 0 || n_cols == 0 {
return Err(Error::backend("LayerNorm backward shape must be positive"));
}
if grad_output.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"LayerNorm backward grad_output length {} does not match n_rows*n_cols={}",
grad_output.len(),
n_rows * n_cols
)));
}
if input.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"LayerNorm backward input length {} does not match n_rows*n_cols={}",
input.len(),
n_rows * n_cols
)));
}
if gamma.len() != n_cols {
return Err(Error::backend(format!(
"LayerNorm backward gamma length {} does not match n_cols={}",
gamma.len(),
n_cols
)));
}
if mean.len() != n_rows {
return Err(Error::backend(format!(
"LayerNorm backward mean length {} does not match n_rows={}",
mean.len(),
n_rows
)));
}
if rstd.len() != n_rows {
return Err(Error::backend(format!(
"LayerNorm backward rstd length {} does not match n_rows={}",
rstd.len(),
n_rows
)));
}
let beta_placeholder: Vec<u16> = vec![0u16; n_cols];
let device_evidence = detect_local_rocm_hip();
if !device_evidence.available {
return Err(Error::backend(
"ROCm/HIP is unavailable; LayerNorm backward pilot remains inadmissible",
));
}
let (
executable_path,
build_command,
kernel_source_fingerprint,
compiler_fingerprint,
source_path,
) = compile_layernorm_kernel()?;
let mut payload = String::new();
payload.push_str(&format!(
"2 {n_rows} {n_cols} {block_size} 0.0 {grad_block_size}\n"
));
append_u16_line(&mut payload, gamma);
append_u16_line(&mut payload, &beta_placeholder);
append_u16_line(&mut payload, grad_output);
append_u16_line(&mut payload, input);
let mean_bits: Vec<u16> = mean.iter().copied().map(f32_to_f16).collect();
let rstd_bits: Vec<u16> = rstd.iter().copied().map(f32_to_f16).collect();
append_u16_line(&mut payload, &mean_bits);
append_u16_line(&mut payload, &rstd_bits);
append_f32_line(&mut payload, mean);
append_f32_line(&mut payload, rstd);
payload.push('\n');
let stdout = run_layernorm_executable(&executable_path, &source_path, &payload)?;
let kernel_time_ms = parse_f32_line(&stdout, "KERNEL_TIME_MS=").ok_or_else(|| {
Error::backend("HIP LayerNorm backward did not print KERNEL_TIME_MS marker")
})?;
let grad_input = parse_u16_line(&stdout, "GRAD_INPUT=")
.ok_or_else(|| Error::backend("HIP LayerNorm backward did not print GRAD_INPUT marker"))?;
let grad_gamma = parse_u16_line(&stdout, "GRAD_GAMMA=")
.ok_or_else(|| Error::backend("HIP LayerNorm backward did not print GRAD_GAMMA marker"))?;
let grad_beta = parse_u16_line(&stdout, "GRAD_BETA=")
.ok_or_else(|| Error::backend("HIP LayerNorm backward did not print GRAD_BETA marker"))?;
if grad_input.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"HIP LayerNorm backward GRAD_INPUT length {} does not match n_rows*n_cols={}",
grad_input.len(),
n_rows * n_cols
)));
}
if grad_gamma.len() != n_cols {
return Err(Error::backend(format!(
"HIP LayerNorm backward GRAD_GAMMA length {} does not match n_cols={}",
grad_gamma.len(),
n_cols
)));
}
if grad_beta.len() != n_cols {
return Err(Error::backend(format!(
"HIP LayerNorm backward GRAD_BETA length {} does not match n_cols={}",
grad_beta.len(),
n_cols
)));
}
Ok(BwdKernelResult {
grad_input,
grad_gamma,
grad_beta,
kernel_time_ms,
kernel_source_fingerprint,
compiler_fingerprint,
build_command,
executable_path: executable_path.display().to_string(),
device_evidence,
})
}
fn compile_layernorm_kernel() -> Result<(PathBuf, String, String, String, PathBuf)> {
let source_fingerprint = layernorm_kernel_source_fingerprint();
let cache_dir = PathBuf::from("target/rocm-hip-cache");
fs::create_dir_all(&cache_dir)
.map_err(|err| Error::backend(format!("failed to create HIP cache directory: {err}")))?;
let source_path = cache_dir.join(format!("{source_fingerprint}.cpp"));
let executable_path = cache_dir.join(format!("{source_fingerprint}-layernorm-fp16-f32"));
fs::write(&source_path, HIP_LAYERNORM_FWD_BWD_KERNEL)
.map_err(|err| Error::backend(format!("failed to write HIP kernel source: {err}")))?;
let hipcc = "/opt/rocm/bin/hipcc";
let compiler_fingerprint = hipcc_compiler_fingerprint(hipcc)?;
let build_command =
hipcc_compile_executable(hipcc, &source_path, &executable_path, Some("gfx1101"))?;
Ok((
executable_path,
build_command,
source_fingerprint,
compiler_fingerprint,
source_path,
))
}
/// Rechecks the compiled artifact and then runs it with `payload`.
///
/// `hipcc_recheck_artifact` is called first with the same compiler path and
/// offload architecture used for compilation; any error it returns aborts the
/// run. The executable is then invoked through `kernel_server::run_persistent`
/// under the LayerNorm kernel type, and that call's result is returned as is.
fn run_layernorm_executable(
    executable_path: &Path,
    source_path: &Path,
    payload: &str,
) -> Result<String> {
    const HIPCC: &str = "/opt/rocm/bin/hipcc";
    const OFFLOAD_ARCH: Option<&str> = Some("gfx1101");

    hipcc_recheck_artifact(HIPCC, source_path, executable_path, OFFLOAD_ARCH)?;
    kernel_server::run_persistent(LAYERNORM_FWD_BWD_KERNEL_TYPE, executable_path, payload)
}
/// Returns the fingerprint of the embedded LayerNorm kernel source.
///
/// The value is `fingerprint` applied to the fixed label
/// `hip-layernorm-fp16-f32-source` and [`HIP_LAYERNORM_FWD_BWD_KERNEL`].
pub fn layernorm_kernel_source_fingerprint() -> String {
    const LABEL: &str = "hip-layernorm-fp16-f32-source";
    fingerprint(LABEL, HIP_LAYERNORM_FWD_BWD_KERNEL)
}
pub fn cpu_layernorm_fwd(
input: &[u16],
gamma: &[u16],
beta: &[u16],
n_rows: usize,
n_cols: usize,
eps: f32,
) -> (Vec<u16>, Vec<f32>, Vec<f32>) {
let input_f32: Vec<f32> = input.iter().copied().map(f16_to_f32).collect();
let gamma_f32: Vec<f32> = gamma.iter().copied().map(f16_to_f32).collect();
let beta_f32: Vec<f32> = beta.iter().copied().map(f16_to_f32).collect();
let mut output = vec![0u16; n_rows * n_cols];
let mut mean = vec![0.0f32; n_rows];
let mut rstd = vec![0.0f32; n_rows];
for r in 0..n_rows {
let row = &input_f32[r * n_cols..(r + 1) * n_cols];
let m: f32 = row.iter().sum::<f32>() / n_cols as f32;
let var: f32 = row.iter().map(|v| (v - m) * (v - m)).sum::<f32>() / n_cols as f32;
let rs = 1.0f32 / (var + eps).sqrt();
mean[r] = m;
rstd[r] = rs;
for c in 0..n_cols {
let normalized = (row[c] - m) * rs;
let out = normalized * gamma_f32[c] + beta_f32[c];
output[r * n_cols + c] = f32_to_f16(out);
}
}
(output, mean, rstd)
}
pub fn cpu_layernorm_bwd(
grad_output: &[u16],
input: &[u16],
gamma: &[u16],
mean: &[f32],
rstd: &[f32],
n_rows: usize,
n_cols: usize,
) -> (Vec<u16>, Vec<u16>, Vec<u16>) {
let dy_f32: Vec<f32> = grad_output.iter().copied().map(f16_to_f32).collect();
let x_f32: Vec<f32> = input.iter().copied().map(f16_to_f32).collect();
let g_f32: Vec<f32> = gamma.iter().copied().map(f16_to_f32).collect();
let mut grad_input = vec![0u16; n_rows * n_cols];
let mut grad_gamma = vec![0.0f32; n_cols];
let mut grad_beta = vec![0.0f32; n_cols];
let inv_n = 1.0f32 / n_cols as f32;
for r in 0..n_rows {
let row_dy = &dy_f32[r * n_cols..(r + 1) * n_cols];
let row_x = &x_f32[r * n_cols..(r + 1) * n_cols];
let m = mean[r];
let rs = rstd[r];
let mut sum_dxhat = 0.0f32;
let mut sum_dxhat_xhat = 0.0f32;
let mut xhat_row = vec![0.0f32; n_cols];
for c in 0..n_cols {
let xhat = (row_x[c] - m) * rs;
let dxhat = row_dy[c] * g_f32[c];
xhat_row[c] = xhat;
sum_dxhat += dxhat;
sum_dxhat_xhat += dxhat * xhat;
}
for c in 0..n_cols {
let xhat = xhat_row[c];
let dxhat = row_dy[c] * g_f32[c];
let dx = inv_n * rs * (n_cols as f32 * dxhat - sum_dxhat - xhat * sum_dxhat_xhat);
grad_input[r * n_cols + c] = f32_to_f16(dx);
grad_gamma[c] += xhat * row_dy[c];
grad_beta[c] += row_dy[c];
}
}
let grad_gamma_bits: Vec<u16> = grad_gamma.iter().copied().map(f32_to_f16).collect();
let grad_beta_bits: Vec<u16> = grad_beta.iter().copied().map(f32_to_f16).collect();
(grad_input, grad_gamma_bits, grad_beta_bits)
}
pub use crate::backend::f16_convert::{f16_to_f32, f32_to_f16};
/// Returns the largest absolute difference between two fp16 buffers.
///
/// Both slices are decoded element-wise with `f16_to_f32` and compared pair by
/// pair; if the slices differ in length, only the common prefix is compared.
///
/// Pairs that decode to equal values (including two infinities of the same
/// sign) contribute an error of zero. If any remaining difference is NaN the
/// function returns NaN immediately, so a tolerance check of the form
/// `err < tol` fails rather than silently passing. A plain `err > max` update
/// would drop NaN differences and report a kernel that emits NaN as exact.
fn max_abs_err_u16(a: &[u16], b: &[u16]) -> f32 {
    let mut max_err = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        let (lhs, rhs) = (f16_to_f32(x), f16_to_f32(y));
        if lhs == rhs {
            // Also covers inf == inf, whose difference would be NaN.
            continue;
        }
        let err = (lhs - rhs).abs();
        if err.is_nan() {
            return f32::NAN;
        }
        if err > max_err {
            max_err = err;
        }
    }
    max_err
}
/// Appends `values` to `out` as space-separated decimal tokens followed by a
/// newline. An empty slice appends just the newline.
fn append_u16_line(out: &mut String, values: &[u16]) {
    if let Some((head, tail)) = values.split_first() {
        out.push_str(&head.to_string());
        for value in tail {
            out.push(' ');
            out.push_str(&value.to_string());
        }
    }
    out.push('\n');
}
/// Appends `values` to `out` as space-separated tokens followed by a newline,
/// formatting each value with `f32`'s `Display` implementation. An empty
/// slice appends just the newline.
fn append_f32_line(out: &mut String, values: &[f32]) {
    let tokens: Vec<String> = values.iter().map(|value| value.to_string()).collect();
    out.push_str(&tokens.join(" "));
    out.push('\n');
}
/// Parses the `f32` that follows `prefix` on the first line of `stdout`
/// starting with `prefix`.
///
/// Returns `None` when no line starts with `prefix`, or when the remainder of
/// that first matching line (after trimming) is not a valid `f32`. Later
/// matching lines are never consulted.
fn parse_f32_line(stdout: &str, prefix: &str) -> Option<f32> {
    for line in stdout.lines() {
        if let Some(rest) = line.strip_prefix(prefix) {
            return rest.trim().parse().ok();
        }
    }
    None
}
/// Parses the comma-separated `u16` list that follows `prefix` on the first
/// line of `stdout` starting with `prefix`.
///
/// Returns:
/// * `None` when no line starts with `prefix`;
/// * `Some(vec![])` when the remainder of that line is empty or whitespace;
/// * `None` when any token fails to parse as a `u16`;
/// * `Some(values)` otherwise, with each token trimmed before parsing.
///
/// The text comes from an external process, so a malformed token is reported
/// to the caller as `None` rather than by panicking.
fn parse_u16_line(stdout: &str, prefix: &str) -> Option<Vec<u16>> {
    let payload = stdout
        .lines()
        .find_map(|line| line.strip_prefix(prefix))?
        .trim();
    if payload.is_empty() {
        return Some(Vec::new());
    }
    // Collecting into `Option<Vec<_>>` stops at the first bad token.
    payload
        .split(',')
        .map(|token| token.trim().parse::<u16>().ok())
        .collect()
}
/// Parses the comma-separated `f32` list that follows `prefix` on the first
/// line of `stdout` starting with `prefix`.
///
/// Returns:
/// * `None` when no line starts with `prefix`;
/// * `Some(vec![])` when the remainder of that line is empty or whitespace;
/// * `None` when any token fails to parse as an `f32`;
/// * `Some(values)` otherwise, with each token trimmed before parsing.
///
/// The text comes from an external process, so a malformed token is reported
/// to the caller as `None` rather than by panicking.
fn parse_f32_vec_line(stdout: &str, prefix: &str) -> Option<Vec<f32>> {
    let payload = stdout
        .lines()
        .find_map(|line| line.strip_prefix(prefix))?
        .trim();
    if payload.is_empty() {
        return Some(Vec::new());
    }
    // Collecting into `Option<Vec<_>>` stops at the first bad token.
    payload
        .split(',')
        .map(|token| token.trim().parse::<f32>().ok())
        .collect()
}
/// Builds a fingerprint string of the form `<label>-<16 hex digits>`.
///
/// `label` and then `value` are fed, in that order, into a single
/// `DefaultHasher`; the resulting 64-bit hash is rendered as zero-padded
/// lowercase hexadecimal.
fn fingerprint(label: &str, value: &str) -> String {
    let mut state = DefaultHasher::new();
    for part in [label, value] {
        part.hash(&mut state);
    }
    let digest = state.finish();
    format!("{label}-{digest:016x}")
}