use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use crate::backend::hip_dense::{
hipcc_compile_executable, hipcc_compiler_fingerprint, hipcc_recheck_artifact,
};
use crate::backend::hip_gemm_f16::{f16_to_f32, f32_to_f16};
use crate::backend::kernel_server;
use crate::backend::rocm::{RocmHipCapabilityReport, detect_local_rocm_hip};
use crate::{Error, Result};
/// Backend label written into `RocmHipSoftmaxFwdReport::backend`.
pub const ROCM_HIP_SOFTMAX_FWD_BACKEND: &str = "rocm_hip_softmax_fwd_pilot";
/// Backend label written into `RocmHipGradLossWrtLogitsReport::backend`.
pub const ROCM_HIP_GRAD_LOSS_WRT_LOGITS_BACKEND: &str = "rocm_hip_grad_loss_wrt_logits_pilot";
/// Lowering identifier for the softmax forward kernel.
// NOTE(review): not referenced in the code visible here; presumably consumed
// by callers elsewhere in the crate.
pub const ROCM_HIP_SOFTMAX_FWD_LOWERING_ID: &str = "hip.softmax.fp16_f32_fwd";
/// Lowering identifier for the grad-loss-wrt-logits kernel.
// NOTE(review): not referenced in the code visible here; presumably consumed
// by callers elsewhere in the crate.
pub const ROCM_HIP_GRAD_LOSS_WRT_LOGITS_LOWERING_ID: &str = "hip.softmax.grad.fp16_f32";
/// Kernel-type label passed to `kernel_server::run_persistent` when the
/// compiled softmax binary is executed (one binary serves both modes).
const SOFTMAX_FWD_GRAD_KERNEL_TYPE: &str = "hip-softmax-fwd-grad";
/// HIP C++ source of the softmax forward / grad-loss-wrt-logits pilot program.
///
/// `compile_softmax_kernel` writes this text to disk and compiles it with
/// hipcc, and `hip_softmax_kernel_source_fingerprint` hashes it to name the
/// cached source and executable. Changing any byte of the literal therefore
/// changes the cache file names and the compiled artifact.
///
/// The program contains:
/// - `softmax_fwd_fp16_f32_kernel`: row-wise softmax with one block per row,
///   fp16 input/output and f32 intermediate arithmetic (max, exp, sum,
///   normalize).
/// - `grad_loss_wrt_logits_fp16_f32_kernel`: per row,
///   `softmax * (grad_output - dot)` where `dot = sum(grad_output * softmax)`.
/// - A host `main` that handles one request from stdin, or, when started with
///   `--server`, loops over requests framed as a little-endian `u32` length
///   followed by the text payload, answering with the same framing.
///
/// Request payload (text): `MODE N_ROWS N_COLS`, then whitespace-separated
/// u16 bit patterns: `N_ROWS * N_COLS` values for mode 0 (forward), or two
/// such arrays (grad_output, then softmax_output) for mode 1 (grad).
///
/// Response (text): `KEY=value` lines (`BACKEND`, `DEVICE_NAME`, `GFX`,
/// `N_ROWS`, `N_COLS`, `GRID`, `BLOCK`, `KERNEL_TIME_MS`) followed by
/// `OUTPUT=` with comma-separated u16 bit patterns.
///
/// Both kernels reduce through shared memory by halving a stride that starts
/// at `blockDim.x / 2`, which only covers every thread when the block size is
/// a power of two; the host code launches with `HIP_SOFTMAX_BLOCK` = 256.
pub const HIP_SOFTMAX_KERNEL: &str = r#"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdlib>
#include <cmath>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#define HIP_SOFTMAX_MODE_FWD 0
#define HIP_SOFTMAX_MODE_GRAD 1
#define HIP_SOFTMAX_BLOCK 256
// Numerically stable row-wise softmax. One block per row.
// Phase 1: parallel max reduction in shared memory.
// Phase 2: write exp(x - max) to output and accumulate sum in shared
// memory.
// Phase 3: multiply every output element by 1 / sum.
__global__ void softmax_fwd_fp16_f32_kernel(
const __half* input,
__half* output,
int n_rows,
int n_cols) {
extern __shared__ float sdata[];
int row = blockIdx.x;
if (row >= n_rows) {
return;
}
const __half* row_in = input + static_cast<std::size_t>(row) * static_cast<std::size_t>(n_cols);
__half* row_out = output + static_cast<std::size_t>(row) * static_cast<std::size_t>(n_cols);
int tid = threadIdx.x;
int block_size = blockDim.x;
// Phase 1: local max -> block-wide max.
float local_max = -std::numeric_limits<float>::infinity();
for (int j = tid; j < n_cols; j += block_size) {
float v = __half2float(row_in[j]);
if (v > local_max) {
local_max = v;
}
}
sdata[tid] = local_max;
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
float a = sdata[tid];
float b = sdata[tid + s];
sdata[tid] = (a > b) ? a : b;
}
__syncthreads();
}
float row_max = sdata[0];
__syncthreads();
// Phase 2: exp(x - max), accumulate sum, write to output.
float local_sum = 0.0f;
for (int j = tid; j < n_cols; j += block_size) {
float v = __half2float(row_in[j]) - row_max;
float e = expf(v);
row_out[j] = __float2half_rn(e);
local_sum += e;
}
sdata[tid] = local_sum;
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] = sdata[tid] + sdata[tid + s];
}
__syncthreads();
}
float row_sum = sdata[0];
__syncthreads();
float inv_sum = 1.0f / row_sum;
// Phase 3: normalize.
for (int j = tid; j < n_cols; j += block_size) {
float v = __half2float(row_out[j]) * inv_sum;
row_out[j] = __float2half_rn(v);
}
}
// Standard "softmax + cross-entropy" backward:
// grad_input = softmax_output * (grad_output - dot)
// where dot = sum(grad_output * softmax_output) per row.
// One block per row. Phase 1 reduces dot in shared memory; Phase 2
// writes the per-element result.
__global__ void grad_loss_wrt_logits_fp16_f32_kernel(
const __half* grad_output,
const __half* softmax_output,
__half* grad_input,
int n_rows,
int n_cols) {
extern __shared__ float sdata[];
int row = blockIdx.x;
if (row >= n_rows) {
return;
}
const __half* row_grad_out = grad_output + static_cast<std::size_t>(row) * static_cast<std::size_t>(n_cols);
const __half* row_softmax = softmax_output + static_cast<std::size_t>(row) * static_cast<std::size_t>(n_cols);
__half* row_grad_in = grad_input + static_cast<std::size_t>(row) * static_cast<std::size_t>(n_cols);
int tid = threadIdx.x;
int block_size = blockDim.x;
// Phase 1: dot = sum(grad_output * softmax_output).
float local_dot = 0.0f;
for (int j = tid; j < n_cols; j += block_size) {
float g = __half2float(row_grad_out[j]);
float s = __half2float(row_softmax[j]);
local_dot += g * s;
}
sdata[tid] = local_dot;
__syncthreads();
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] = sdata[tid] + sdata[tid + s];
}
__syncthreads();
}
float dot = sdata[0];
__syncthreads();
// Phase 2: grad_input = softmax * (grad_output - dot).
for (int j = tid; j < n_cols; j += block_size) {
float g = __half2float(row_grad_out[j]);
float s = __half2float(row_softmax[j]);
float v = s * (g - dot);
row_grad_in[j] = __float2half_rn(v);
}
}
static void hip_softmax_check(hipError_t status, const char* label) {
if (status != hipSuccess) {
std::cerr << "HIP_ERROR " << label << "=" << hipGetErrorString(status) << "\n";
std::exit(10);
}
}
// Forward declaration of the existing main() body, extracted into
// a static helper so the server-mode loop can call it on each
// request. The default `main()` also routes through this helper so
// the one-shot and server code paths share the same compute logic.
static int run_one_shot_from_main_body();
// Persistent server-mode protocol (see hip_gemm_f16.rs for the full
// design rationale). The host writes a little-endian u32 payload_len
// followed by `payload_len` bytes of the existing text payload, then
// reads back a little-endian u32 response_len followed by
// `response_len` bytes of the existing text response.
static int run_server_mode() {
while (true) {
uint32_t payload_len = 0;
std::cin.read(reinterpret_cast<char*>(&payload_len), 4);
if (!std::cin || std::cin.gcount() == 0) {
return 0; // clean EOF
}
if (std::cin.gcount() != 4) {
std::cerr << "server_mode: short read on payload_len (got "
<< std::cin.gcount() << " bytes)\n";
return 20;
}
std::vector<char> payload(payload_len);
if (payload_len > 0) {
std::cin.read(payload.data(), payload_len);
if (static_cast<uint32_t>(std::cin.gcount()) != payload_len) {
std::cerr << "server_mode: short read on payload (got "
<< std::cin.gcount() << " of " << payload_len << ")\n";
return 21;
}
}
std::string payload_str(payload.begin(), payload.end());
std::istringstream fake_stdin(payload_str);
std::streambuf* old_buf = std::cin.rdbuf(fake_stdin.rdbuf());
std::ostringstream captured;
std::streambuf* old_cout = std::cout.rdbuf(captured.rdbuf());
std::ostringstream captured_err;
std::streambuf* old_cerr = std::cerr.rdbuf(captured_err.rdbuf());
int rc = run_one_shot_from_main_body();
std::cin.rdbuf(old_buf);
std::cout.rdbuf(old_cout);
std::cerr.rdbuf(old_cerr);
std::string response = captured.str();
if (rc != 0) {
std::string err_str = captured_err.str();
response += err_str;
}
uint32_t response_len = static_cast<uint32_t>(response.size());
std::cout.write(reinterpret_cast<const char*>(&response_len), 4);
if (response_len > 0) {
std::cout.write(response.data(), response_len);
}
std::cout.flush();
if (rc != 0) {
return rc;
}
}
}
int main(int argc, char** argv) {
if (argc > 1 && std::string(argv[1]) == "--server") {
return run_server_mode();
}
return run_one_shot_from_main_body();
}
static int run_one_shot_from_main_body() {
int mode = -1;
int n_rows = 0;
int n_cols = 0;
if (!(std::cin >> mode >> n_rows >> n_cols)) {
std::cerr << "usage: stdin payload is \"MODE N_ROWS N_COLS\\n...\"\n";
return 2;
}
if (mode != HIP_SOFTMAX_MODE_FWD && mode != HIP_SOFTMAX_MODE_GRAD) {
std::cerr << "MODE must be 0 (softmax_fwd) or 1 (grad_loss_wrt_logits), got " << mode << "\n";
return 3;
}
if (n_rows <= 0 || n_cols <= 0) {
std::cerr << "N_ROWS and N_COLS must both be positive, got " << n_rows << " " << n_cols << "\n";
return 4;
}
int device = 0;
hip_softmax_check(hipSetDevice(device), "hipSetDevice");
hipDeviceProp_t props;
hip_softmax_check(hipGetDeviceProperties(&props, device), "hipGetDeviceProperties");
std::size_t count = static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(n_cols);
std::size_t bytes = count * sizeof(__half);
std::size_t shmem = HIP_SOFTMAX_BLOCK * sizeof(float);
int block = HIP_SOFTMAX_BLOCK;
int grid = n_rows;
if (mode == HIP_SOFTMAX_MODE_FWD) {
std::vector<uint16_t> input_bits(count, 0);
for (std::size_t i = 0; i < count; ++i) {
if (!(std::cin >> input_bits[i])) {
std::cerr << "failed to read input element " << i << "\n";
return 5;
}
}
__half* d_input = nullptr;
__half* d_output = nullptr;
hip_softmax_check(hipMalloc(&d_input, bytes), "hipMalloc(d_input)");
hip_softmax_check(hipMalloc(&d_output, bytes), "hipMalloc(d_output)");
hip_softmax_check(hipMemcpy(d_input, input_bits.data(), bytes, hipMemcpyHostToDevice), "hipMemcpy(d_input)");
hipEvent_t start;
hipEvent_t stop;
hip_softmax_check(hipEventCreate(&start), "hipEventCreate(start)");
hip_softmax_check(hipEventCreate(&stop), "hipEventCreate(stop)");
hip_softmax_check(hipEventRecord(start), "hipEventRecord(start)");
hipLaunchKernelGGL(softmax_fwd_fp16_f32_kernel,
dim3(grid), dim3(block), shmem, 0,
d_input, d_output, n_rows, n_cols);
hip_softmax_check(hipGetLastError(), "hipLaunchKernelGGL");
hip_softmax_check(hipEventRecord(stop), "hipEventRecord(stop)");
hip_softmax_check(hipEventSynchronize(stop), "hipEventSynchronize");
float kernel_time_ms = 0.0f;
hip_softmax_check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
hip_softmax_check(hipEventDestroy(start), "hipEventDestroy(start)");
hip_softmax_check(hipEventDestroy(stop), "hipEventDestroy(stop)");
std::vector<uint16_t> output_bits(count, 0);
hip_softmax_check(hipMemcpy(output_bits.data(), d_output, bytes, hipMemcpyDeviceToHost), "hipMemcpy(d_output)");
hip_softmax_check(hipFree(d_input), "hipFree(d_input)");
hip_softmax_check(hipFree(d_output), "hipFree(d_output)");
std::cout << "BACKEND=softmax_fwd\n";
std::cout << "DEVICE_NAME=" << props.name << "\n";
std::cout << "GFX=" << props.gcnArchName << "\n";
std::cout << "N_ROWS=" << n_rows << "\n";
std::cout << "N_COLS=" << n_cols << "\n";
std::cout << "GRID=" << grid << "\n";
std::cout << "BLOCK=" << block << "\n";
std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
std::cout << "OUTPUT=";
for (std::size_t i = 0; i < output_bits.size(); ++i) {
if (i != 0) {
std::cout << ",";
}
std::cout << static_cast<unsigned int>(output_bits[i]);
}
std::cout << "\n";
} else {
// HIP_SOFTMAX_MODE_GRAD
std::vector<uint16_t> grad_output_bits(count, 0);
std::vector<uint16_t> softmax_output_bits(count, 0);
for (std::size_t i = 0; i < count; ++i) {
if (!(std::cin >> grad_output_bits[i])) {
std::cerr << "failed to read grad_output element " << i << "\n";
return 6;
}
}
for (std::size_t i = 0; i < count; ++i) {
if (!(std::cin >> softmax_output_bits[i])) {
std::cerr << "failed to read softmax_output element " << i << "\n";
return 7;
}
}
__half* d_grad_output = nullptr;
__half* d_softmax_output = nullptr;
__half* d_grad_input = nullptr;
hip_softmax_check(hipMalloc(&d_grad_output, bytes), "hipMalloc(d_grad_output)");
hip_softmax_check(hipMalloc(&d_softmax_output, bytes), "hipMalloc(d_softmax_output)");
hip_softmax_check(hipMalloc(&d_grad_input, bytes), "hipMalloc(d_grad_input)");
hip_softmax_check(hipMemcpy(d_grad_output, grad_output_bits.data(), bytes, hipMemcpyHostToDevice), "hipMemcpy(d_grad_output)");
hip_softmax_check(hipMemcpy(d_softmax_output, softmax_output_bits.data(), bytes, hipMemcpyHostToDevice), "hipMemcpy(d_softmax_output)");
hipEvent_t start;
hipEvent_t stop;
hip_softmax_check(hipEventCreate(&start), "hipEventCreate(start)");
hip_softmax_check(hipEventCreate(&stop), "hipEventCreate(stop)");
hip_softmax_check(hipEventRecord(start), "hipEventRecord(start)");
hipLaunchKernelGGL(grad_loss_wrt_logits_fp16_f32_kernel,
dim3(grid), dim3(block), shmem, 0,
d_grad_output, d_softmax_output, d_grad_input, n_rows, n_cols);
hip_softmax_check(hipGetLastError(), "hipLaunchKernelGGL");
hip_softmax_check(hipEventRecord(stop), "hipEventRecord(stop)");
hip_softmax_check(hipEventSynchronize(stop), "hipEventSynchronize");
float kernel_time_ms = 0.0f;
hip_softmax_check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
hip_softmax_check(hipEventDestroy(start), "hipEventDestroy(start)");
hip_softmax_check(hipEventDestroy(stop), "hipEventDestroy(stop)");
std::vector<uint16_t> grad_input_bits(count, 0);
hip_softmax_check(hipMemcpy(grad_input_bits.data(), d_grad_input, bytes, hipMemcpyDeviceToHost), "hipMemcpy(d_grad_input)");
hip_softmax_check(hipFree(d_grad_output), "hipFree(d_grad_output)");
hip_softmax_check(hipFree(d_softmax_output), "hipFree(d_softmax_output)");
hip_softmax_check(hipFree(d_grad_input), "hipFree(d_grad_input)");
std::cout << "BACKEND=grad_loss_wrt_logits\n";
std::cout << "DEVICE_NAME=" << props.name << "\n";
std::cout << "GFX=" << props.gcnArchName << "\n";
std::cout << "N_ROWS=" << n_rows << "\n";
std::cout << "N_COLS=" << n_cols << "\n";
std::cout << "GRID=" << grid << "\n";
std::cout << "BLOCK=" << block << "\n";
std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
std::cout << "OUTPUT=";
for (std::size_t i = 0; i < grad_input_bits.size(); ++i) {
if (i != 0) {
std::cout << ",";
}
std::cout << static_cast<unsigned int>(grad_input_bits[i]);
}
std::cout << "\n";
}
return 0;
}
"#;
/// Outcome of one softmax forward pilot run, as assembled by
/// `run_rocm_hip_softmax_fwd`.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipSoftmaxFwdReport {
/// Backend label; set to `ROCM_HIP_SOFTMAX_FWD_BACKEND`.
pub backend: String,
/// Number of softmax rows in the request.
pub n_rows: usize,
/// Number of columns (elements per row) in the request.
pub n_cols: usize,
/// fp16 bit patterns parsed from the kernel's `OUTPUT=` line, row-major.
pub outputs: Vec<u16>,
/// fp16 bit patterns produced by `cpu_softmax_fwd` for the same input.
pub cpu_oracle_outputs: Vec<u16>,
/// Largest absolute difference between `outputs` and `cpu_oracle_outputs`.
pub max_abs_error: f32,
/// Largest absolute deviation of any output row's sum from 1.0.
pub max_row_sum_abs_error: f32,
/// True when both error figures above are below 1e-2.
pub within_tolerance: bool,
/// Value parsed from the kernel's `KERNEL_TIME_MS=` line.
pub kernel_time_ms: f32,
/// Fingerprint returned by `hip_softmax_kernel_source_fingerprint`.
pub kernel_source_fingerprint: String,
/// Fingerprint returned by `hipcc_compiler_fingerprint`.
pub compiler_fingerprint: String,
/// Build command string returned by `hipcc_compile_executable`.
pub build_command: String,
/// Display form of the compiled executable's path.
pub executable_path: String,
/// ROCm/HIP capability report returned by `detect_local_rocm_hip`.
pub device_evidence: RocmHipCapabilityReport,
/// Statements describing what the run did; rendered under "## Evidence".
pub evidence: Vec<String>,
/// Statements describing what the run does not establish; rendered under
/// "## Non-Claims".
pub non_claims: Vec<String>,
}
/// Outcome of one grad-loss-wrt-logits pilot run, as assembled by
/// `run_rocm_hip_grad_loss_wrt_logits`.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipGradLossWrtLogitsReport {
/// Backend label; set to `ROCM_HIP_GRAD_LOSS_WRT_LOGITS_BACKEND`.
pub backend: String,
/// Number of rows in the request.
pub n_rows: usize,
/// Number of columns (elements per row) in the request.
pub n_cols: usize,
/// fp16 bit patterns parsed from the kernel's `OUTPUT=` line, row-major.
pub outputs: Vec<u16>,
/// fp16 bit patterns produced by `cpu_grad_loss_wrt_logits` for the same
/// inputs.
pub cpu_oracle_outputs: Vec<u16>,
/// Largest absolute difference between `outputs` and `cpu_oracle_outputs`.
pub max_abs_error: f32,
/// True when `max_abs_error` is below 1e-2.
pub within_tolerance: bool,
/// Value parsed from the kernel's `KERNEL_TIME_MS=` line.
pub kernel_time_ms: f32,
/// Fingerprint returned by `hip_softmax_kernel_source_fingerprint`.
pub kernel_source_fingerprint: String,
/// Fingerprint returned by `hipcc_compiler_fingerprint`.
pub compiler_fingerprint: String,
/// Build command string returned by `hipcc_compile_executable`.
pub build_command: String,
/// Display form of the compiled executable's path.
pub executable_path: String,
/// ROCm/HIP capability report returned by `detect_local_rocm_hip`.
pub device_evidence: RocmHipCapabilityReport,
/// Statements describing what the run did; rendered under "## Evidence".
pub evidence: Vec<String>,
/// Statements describing what the run does not establish; rendered under
/// "## Non-Claims".
pub non_claims: Vec<String>,
}
impl RocmHipSoftmaxFwdReport {
    /// Renders the report as Markdown text.
    ///
    /// The result is a title, a blank line, one `key: value` line per scalar
    /// field, then an "## Evidence" bullet list and a "## Non-Claims" bullet
    /// list. Lines are joined with `\n`; there is no trailing newline.
    pub fn to_markdown(&self) -> String {
        let as_bullet = |item: &String| format!("- {item}");

        let mut doc: Vec<String> = Vec::new();
        doc.push("# ROCm/HIP fp16 Softmax Forward Pilot".to_string());
        doc.push(String::new());
        doc.push(format!("backend: {}", self.backend));
        doc.push(format!("n_rows: {}", self.n_rows));
        doc.push(format!("n_cols: {}", self.n_cols));
        doc.push(format!("max_abs_error: {}", self.max_abs_error));
        doc.push(format!(
            "max_row_sum_abs_error: {}",
            self.max_row_sum_abs_error
        ));
        doc.push(format!("within_tolerance: {}", self.within_tolerance));
        doc.push(format!("kernel_time_ms: {}", self.kernel_time_ms));
        doc.push(format!(
            "kernel_source_fingerprint: {}",
            self.kernel_source_fingerprint
        ));
        doc.push(format!(
            "compiler_fingerprint: {}",
            self.compiler_fingerprint
        ));
        doc.push(String::new());

        doc.push("## Evidence".to_string());
        doc.extend(self.evidence.iter().map(as_bullet));
        doc.push(String::new());

        doc.push("## Non-Claims".to_string());
        doc.extend(self.non_claims.iter().map(as_bullet));

        doc.join("\n")
    }
}
impl RocmHipGradLossWrtLogitsReport {
    /// Renders the report as Markdown text.
    ///
    /// The result is a title, a blank line, one `key: value` line per scalar
    /// field, then an "## Evidence" bullet list and a "## Non-Claims" bullet
    /// list. Lines are joined with `\n`; there is no trailing newline.
    pub fn to_markdown(&self) -> String {
        let as_bullet = |item: &String| format!("- {item}");

        let mut doc: Vec<String> = Vec::new();
        doc.push("# ROCm/HIP fp16 grad-loss-wrt-logits Pilot".to_string());
        doc.push(String::new());
        doc.push(format!("backend: {}", self.backend));
        doc.push(format!("n_rows: {}", self.n_rows));
        doc.push(format!("n_cols: {}", self.n_cols));
        doc.push(format!("max_abs_error: {}", self.max_abs_error));
        doc.push(format!("within_tolerance: {}", self.within_tolerance));
        doc.push(format!("kernel_time_ms: {}", self.kernel_time_ms));
        doc.push(format!(
            "kernel_source_fingerprint: {}",
            self.kernel_source_fingerprint
        ));
        doc.push(format!(
            "compiler_fingerprint: {}",
            self.compiler_fingerprint
        ));
        doc.push(String::new());

        doc.push("## Evidence".to_string());
        doc.extend(self.evidence.iter().map(as_bullet));
        doc.push(String::new());

        doc.push("## Non-Claims".to_string());
        doc.extend(self.non_claims.iter().map(as_bullet));

        doc.join("\n")
    }
}
/// Payload `MODE` value selecting the softmax forward kernel; must match
/// `HIP_SOFTMAX_MODE_FWD` in `HIP_SOFTMAX_KERNEL`.
const SOFTMAX_FWD_MODE: i32 = 0;
/// Payload `MODE` value selecting the grad-loss-wrt-logits kernel; must match
/// `HIP_SOFTMAX_MODE_GRAD` in `HIP_SOFTMAX_KERNEL`.
const GRAD_LOSS_WRT_LOGITS_MODE: i32 = 1;
fn compile_softmax_kernel() -> Result<(String, String, String, PathBuf, PathBuf)> {
let source_fingerprint = hip_softmax_kernel_source_fingerprint();
let cache_dir = PathBuf::from("target/rocm-hip-cache");
fs::create_dir_all(&cache_dir)
.map_err(|err| Error::backend(format!("failed to create HIP cache directory: {err}")))?;
let source_path = cache_dir.join(format!("{source_fingerprint}.cpp"));
let executable_path = cache_dir.join(format!("{source_fingerprint}-softmax"));
fs::write(&source_path, HIP_SOFTMAX_KERNEL)
.map_err(|err| Error::backend(format!("failed to write HIP kernel source: {err}")))?;
let hipcc = "/opt/rocm/bin/hipcc";
let compiler_fingerprint = hipcc_compiler_fingerprint(hipcc)?;
let build_command =
hipcc_compile_executable(hipcc, &source_path, &executable_path, Some("gfx1101"))?;
Ok((
source_fingerprint,
compiler_fingerprint,
build_command,
source_path,
executable_path,
))
}
/// Builds the first payload line, `"MODE N_ROWS N_COLS\n"`, that the kernel
/// binary reads from stdin before the element arrays.
fn build_payload_header(mode: i32, n_rows: usize, n_cols: usize) -> String {
    let fields = [mode.to_string(), n_rows.to_string(), n_cols.to_string()];
    let mut header = fields.join(" ");
    header.push('\n');
    header
}
/// Appends `values` to `payload` as decimal numbers on one line.
///
/// Despite the name, the separator is a single space (not a comma); the
/// kernel binary reads the values with whitespace-delimited extraction. The
/// line is always terminated with `\n`, even when `values` is empty.
fn append_u16_csv(payload: &mut String, values: &[u16]) {
    let mut remaining = values.iter();
    if let Some(first) = remaining.next() {
        payload.push_str(&first.to_string());
        for value in remaining {
            payload.push(' ');
            payload.push_str(&value.to_string());
        }
    }
    payload.push('\n');
}
/// Extracts the comma-separated u16 list from the first line of `stdout`
/// that starts with `prefix`.
///
/// Each item is trimmed before parsing. A matching line whose remainder is
/// blank yields an empty vector.
///
/// # Errors
///
/// Returns a backend error when no line starts with `prefix`, or when any
/// item is not a valid `u16`. Values above 65535 are rejected rather than
/// wrapped, so a malformed kernel response cannot be mistaken for valid
/// fp16 bit patterns.
fn parse_u16_csv(stdout: &str, prefix: &str) -> Result<Vec<u16>> {
    let line = stdout
        .lines()
        .find_map(|line| line.strip_prefix(prefix))
        .ok_or_else(|| Error::backend(format!("HIP softmax did not print {prefix}")))?;
    if line.trim().is_empty() {
        return Ok(Vec::new());
    }
    line.split(',')
        .map(|value| {
            // Parse straight into u16 so out-of-range values are an error
            // instead of being truncated by a cast.
            value
                .trim()
                .parse::<u16>()
                .map_err(|err| Error::backend(format!("invalid HIP softmax u16 {value:?}: {err}")))
        })
        .collect()
}
/// Parses the value of the first line of `stdout` that starts with `prefix`
/// as an `f32`.
///
/// Returns `None` when no line starts with `prefix`, or when the first such
/// line's trimmed remainder is not a valid float. Later matching lines are
/// not consulted.
fn parse_f32_line(stdout: &str, prefix: &str) -> Option<f32> {
    for line in stdout.lines() {
        if let Some(rest) = line.strip_prefix(prefix) {
            // The first match decides the result, even if it fails to parse.
            return rest.trim().parse::<f32>().ok();
        }
    }
    None
}
/// Builds the text payload for one request and runs it through the
/// persistent kernel server.
///
/// The payload is the header line followed by one line per entry of
/// `payload_arrays`. Before the run, the compiled artifact is re-checked
/// against `source_path` via `hipcc_recheck_artifact`.
///
/// Returns whatever `kernel_server::run_persistent` returns for the request.
fn run_softmax_binary(
    executable_path: &Path,
    source_path: &Path,
    mode: i32,
    n_rows: usize,
    n_cols: usize,
    payload_arrays: &[&[u16]],
) -> Result<String> {
    let payload = payload_arrays.iter().fold(
        build_payload_header(mode, n_rows, n_cols),
        |mut text, values| {
            append_u16_csv(&mut text, values);
            text
        },
    );

    hipcc_recheck_artifact(
        "/opt/rocm/bin/hipcc",
        source_path,
        executable_path,
        Some("gfx1101"),
    )?;

    kernel_server::run_persistent(SOFTMAX_FWD_GRAD_KERNEL_TYPE, executable_path, &payload)
}
pub fn run_rocm_hip_softmax_fwd(
input: &[u16],
n_rows: usize,
n_cols: usize,
) -> Result<RocmHipSoftmaxFwdReport> {
if n_rows == 0 || n_cols == 0 {
return Err(Error::backend(
"softmax forward requires positive n_rows and n_cols",
));
}
if input.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"softmax forward input length {} does not match n_rows*n_cols={}",
input.len(),
n_rows * n_cols
)));
}
let device_evidence = detect_local_rocm_hip();
if !device_evidence.available {
return Err(Error::backend(
"ROCm/HIP is unavailable; softmax forward pilot remains inadmissible",
));
}
let (source_fingerprint, compiler_fingerprint, build_command, source_path, executable_path) =
compile_softmax_kernel()?;
let stdout = run_softmax_binary(
&executable_path,
&source_path,
SOFTMAX_FWD_MODE,
n_rows,
n_cols,
&[input],
)?;
let outputs = parse_u16_csv(&stdout, "OUTPUT=")?;
let kernel_time_ms = parse_f32_line(&stdout, "KERNEL_TIME_MS=")
.ok_or_else(|| Error::backend("HIP softmax did not print KERNEL_TIME_MS marker"))?;
if outputs.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"HIP softmax forward returned {} outputs, expected {}",
outputs.len(),
n_rows * n_cols
)));
}
let cpu_oracle_outputs = cpu_softmax_fwd(input, n_rows, n_cols);
let (max_abs_error, max_row_sum_abs_error) =
compare_softmax_outputs(&outputs, &cpu_oracle_outputs, n_rows, n_cols);
let within_tolerance = max_abs_error < 1e-2 && max_row_sum_abs_error < 1e-2;
Ok(RocmHipSoftmaxFwdReport {
backend: ROCM_HIP_SOFTMAX_FWD_BACKEND.to_string(),
n_rows,
n_cols,
outputs,
cpu_oracle_outputs,
max_abs_error,
max_row_sum_abs_error,
within_tolerance,
kernel_time_ms,
kernel_source_fingerprint: source_fingerprint,
compiler_fingerprint,
build_command,
executable_path: executable_path.display().to_string(),
device_evidence,
evidence: vec![
"compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101".to_string(),
"shipped input bits to the kernel via stdin (Stdio::piped)".to_string(),
"launched softmax_fwd_fp16_f32_kernel with grid=n_rows block=256".to_string(),
"captured kernel time with hipEventRecord/hipEventSynchronize".to_string(),
"compared every output element against the CPU fp32 oracle within 1e-2".to_string(),
"verified every output row sums to 1.0 within 1e-2".to_string(),
],
non_claims: vec![
"not a fused softmax/attention kernel".to_string(),
"not online softmax (single-pass, large-col stable)".to_string(),
"not production speedup evidence".to_string(),
"not machine-code verification".to_string(),
],
})
}
pub fn run_rocm_hip_grad_loss_wrt_logits(
grad_output: &[u16],
softmax_output: &[u16],
n_rows: usize,
n_cols: usize,
) -> Result<RocmHipGradLossWrtLogitsReport> {
if n_rows == 0 || n_cols == 0 {
return Err(Error::backend(
"grad_loss_wrt_logits requires positive n_rows and n_cols",
));
}
if grad_output.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"grad_loss_wrt_logits grad_output length {} does not match n_rows*n_cols={}",
grad_output.len(),
n_rows * n_cols
)));
}
if softmax_output.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"grad_loss_wrt_logits softmax_output length {} does not match n_rows*n_cols={}",
softmax_output.len(),
n_rows * n_cols
)));
}
let device_evidence = detect_local_rocm_hip();
if !device_evidence.available {
return Err(Error::backend(
"ROCm/HIP is unavailable; grad_loss_wrt_logits pilot remains inadmissible",
));
}
let (source_fingerprint, compiler_fingerprint, build_command, source_path, executable_path) =
compile_softmax_kernel()?;
let stdout = run_softmax_binary(
&executable_path,
&source_path,
GRAD_LOSS_WRT_LOGITS_MODE,
n_rows,
n_cols,
&[grad_output, softmax_output],
)?;
let outputs = parse_u16_csv(&stdout, "OUTPUT=")?;
let kernel_time_ms = parse_f32_line(&stdout, "KERNEL_TIME_MS=")
.ok_or_else(|| Error::backend("HIP softmax did not print KERNEL_TIME_MS marker"))?;
if outputs.len() != n_rows * n_cols {
return Err(Error::backend(format!(
"HIP grad_loss_wrt_logits returned {} outputs, expected {}",
outputs.len(),
n_rows * n_cols
)));
}
let cpu_oracle_outputs = cpu_grad_loss_wrt_logits(grad_output, softmax_output, n_rows, n_cols);
let mut max_abs_error = 0.0f32;
for (g, c) in outputs.iter().zip(cpu_oracle_outputs.iter()) {
let err = (f16_to_f32(*g) - f16_to_f32(*c)).abs();
if err > max_abs_error {
max_abs_error = err;
}
}
let within_tolerance = max_abs_error < 1e-2;
Ok(RocmHipGradLossWrtLogitsReport {
backend: ROCM_HIP_GRAD_LOSS_WRT_LOGITS_BACKEND.to_string(),
n_rows,
n_cols,
outputs,
cpu_oracle_outputs,
max_abs_error,
within_tolerance,
kernel_time_ms,
kernel_source_fingerprint: source_fingerprint,
compiler_fingerprint,
build_command,
executable_path: executable_path.display().to_string(),
device_evidence,
evidence: vec![
"compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101".to_string(),
"shipped grad_output and softmax_output bits to the kernel via stdin (Stdio::piped)"
.to_string(),
"launched grad_loss_wrt_logits_fp16_f32_kernel with grid=n_rows block=256".to_string(),
"captured kernel time with hipEventRecord/hipEventSynchronize".to_string(),
"compared every output element against the CPU fp32 oracle within 1e-2".to_string(),
],
non_claims: vec![
"not a fused softmax/cross-entropy backward".to_string(),
"not label-smoothed or label-aware variant".to_string(),
"not production speedup evidence".to_string(),
"not machine-code verification".to_string(),
],
})
}
/// CPU reference for the row-wise softmax forward pass.
///
/// `input` is read as `n_rows` consecutive rows of `n_cols` fp16 bit
/// patterns. Each row is converted to f32, shifted by its maximum,
/// exponentiated, and scaled by the reciprocal of the row sum; results are
/// converted back to fp16 bit patterns.
///
/// # Panics
///
/// Panics if `input` holds fewer than `n_rows * n_cols` elements.
pub fn cpu_softmax_fwd(input: &[u16], n_rows: usize, n_cols: usize) -> Vec<u16> {
    let mut output = Vec::with_capacity(n_rows * n_cols);
    for row in 0..n_rows {
        let logits: Vec<f32> = input[row * n_cols..(row + 1) * n_cols]
            .iter()
            .map(|&bits| f16_to_f32(bits))
            .collect();

        // Strict `>` keeps the running maximum when a candidate is NaN.
        let row_max = logits
            .iter()
            .fold(f32::NEG_INFINITY, |best, &v| if v > best { v } else { best });

        let exps: Vec<f32> = logits.iter().map(|&v| (v - row_max).exp()).collect();
        // Left-to-right accumulation starting from 0.0.
        let total = exps.iter().fold(0.0f32, |acc, &e| acc + e);
        let scale = 1.0f32 / total;

        output.extend(exps.iter().map(|&e| f32_to_f16(e * scale)));
    }
    output
}
/// CPU reference for the grad-loss-wrt-logits computation.
///
/// Both inputs are read as `n_rows` consecutive rows of `n_cols` fp16 bit
/// patterns. For each row, `dot = sum(grad_output * softmax_output)` is
/// accumulated in f32 and every element becomes
/// `softmax_output * (grad_output - dot)`, converted back to fp16 bits.
///
/// # Panics
///
/// Panics if either input holds fewer than `n_rows * n_cols` elements.
pub fn cpu_grad_loss_wrt_logits(
    grad_output: &[u16],
    softmax_output: &[u16],
    n_rows: usize,
    n_cols: usize,
) -> Vec<u16> {
    let mut output = Vec::with_capacity(n_rows * n_cols);
    for row in 0..n_rows {
        let grads = &grad_output[row * n_cols..(row + 1) * n_cols];
        let probs = &softmax_output[row * n_cols..(row + 1) * n_cols];

        // Decode each (grad, softmax) pair once and reuse it for both passes.
        let pairs: Vec<(f32, f32)> = grads
            .iter()
            .zip(probs.iter())
            .map(|(&g, &s)| (f16_to_f32(g), f16_to_f32(s)))
            .collect();

        // Left-to-right accumulation starting from 0.0.
        let dot = pairs.iter().fold(0.0f32, |acc, &(g, s)| acc + g * s);

        output.extend(pairs.iter().map(|&(g, s)| f32_to_f16(s * (g - dot))));
    }
    output
}
/// Compares kernel softmax outputs with the CPU oracle.
///
/// Returns `(max_abs_error, max_row_sum_abs_error)`:
/// - `max_abs_error` is the largest absolute difference between paired
///   elements of `outputs` and `cpu_oracle_outputs` (paired up to the shorter
///   slice);
/// - `max_row_sum_abs_error` is the largest `|row_sum - 1.0|` over the
///   `n_rows` rows of `outputs`.
///
/// A NaN difference or NaN row sum makes the corresponding figure NaN, so a
/// later `< tolerance` test fails instead of silently skipping the NaN.
///
/// # Panics
///
/// Panics if `outputs` holds fewer than `n_rows * n_cols` elements.
fn compare_softmax_outputs(
    outputs: &[u16],
    cpu_oracle_outputs: &[u16],
    n_rows: usize,
    n_cols: usize,
) -> (f32, f32) {
    // Maximum that keeps NaN once seen: `err > worst` alone is false for a
    // NaN `err`, which would hide it; and once `worst` is NaN every later
    // comparison is false, so it stays NaN.
    fn worse_of(worst: f32, err: f32) -> f32 {
        if err.is_nan() || err > worst {
            err
        } else {
            worst
        }
    }

    let max_abs_error = outputs
        .iter()
        .zip(cpu_oracle_outputs.iter())
        .map(|(&gpu, &cpu)| (f16_to_f32(gpu) - f16_to_f32(cpu)).abs())
        .fold(0.0f32, worse_of);

    let mut max_row_sum_abs_error = 0.0f32;
    for r in 0..n_rows {
        let row = &outputs[r * n_cols..(r + 1) * n_cols];
        let sum = row.iter().fold(0.0f32, |acc, &bits| acc + f16_to_f32(bits));
        max_row_sum_abs_error = worse_of(max_row_sum_abs_error, (sum - 1.0f32).abs());
    }

    (max_abs_error, max_row_sum_abs_error)
}
/// Returns the fingerprint of the embedded HIP source, in the form
/// `hip-softmax-source-<16 hex digits>`.
///
/// `compile_softmax_kernel` uses this value to name the cached source file
/// and executable.
pub fn hip_softmax_kernel_source_fingerprint() -> String {
    const LABEL: &str = "hip-softmax-source";
    fingerprint(LABEL, HIP_SOFTMAX_KERNEL)
}
/// Hashes `label` then `value` with `DefaultHasher` and returns
/// `"{label}-{hash}"`, with the hash as 16 zero-padded lowercase hex digits.
// NOTE(review): the standard library does not promise that `DefaultHasher`
// output is stable across Rust releases, so fingerprints may differ between
// toolchains.
fn fingerprint(label: &str, value: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        for part in [label, value] {
            part.hash(&mut state);
        }
        state.finish()
    };
    format!("{label}-{digest:016x}")
}