use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use tokitai_operator::backend::hip_dense::{
hipcc_compile_executable, hipcc_compiler_fingerprint, hipcc_recheck_artifact,
};
use tokitai_operator::backend::kernel_server;
use tokitai_operator::backend::rocm::{RocmHipCapabilityReport, detect_local_rocm_hip};
use tokitai_operator::{Error, Result};
/// Backend label stored in `RocmHipAdamwStepReport::backend` by the AdamW
/// HIP runners in this module.
pub const ROCM_HIP_ADAMW_BACKEND: &str = "rocm_hip_adamw_pilot";
/// Lowering identifier for the AdamW step pilot.
/// NOTE(review): not referenced in this part of the file — presumably
/// consumed by a lowering registry elsewhere; confirm.
pub const ROCM_HIP_ADAMW_LOWERING_ID: &str = "hip.adamw.fp16_f32_step";
/// Kernel-type key passed to `kernel_server::run_persistent` /
/// `run_persistent_binary` when running the compiled AdamW executable.
const ADAMW_STEP_KERNEL_TYPE: &str = "hip-adamw-step";
/// Embedded HIP C++ source for the AdamW step pilot.
///
/// The translation unit contains `adamw_step_fp16_f32_kernel` (fp16 theta and
/// grad, fp32 m and v, updated in place on the device) and a host `main` with
/// two modes:
/// - a one-shot text protocol on stdin/stdout;
/// - a `--server` loop that reads length-prefixed requests and accepts either
///   the text payload or a binary payload tagged with the `ADAW` magic.
///
/// This text is the input to `hip_adamw_kernel_source_fingerprint`, whose
/// value names the cached `.cpp` source and executable.
pub const HIP_ADAMW_KERNEL: &str = r#"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
// In-place AdamW step: theta, m, v are all updated in device memory.
// grad is read-only. The host precomputes the per-step scalars
// (beta1, 1-beta1, beta2, 1-beta2, 1/(1-beta1^t), 1/(1-beta2^t))
// so the kernel avoids a per-thread `powf` call, which dominated
// the kernel runtime at small problem sizes.
__global__ void adamw_step_fp16_f32_kernel(
__half* __restrict__ theta,
float* __restrict__ m,
float* __restrict__ v,
const __half* __restrict__ grad,
int n,
float lr,
float beta1,
float one_minus_beta1,
float beta2,
float one_minus_beta2,
float inv_bc1,
float inv_bc2,
float eps,
float weight_decay) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= n) {
return;
}
// grad is the only read-only buffer, so it is the only load routed
// through __ldg (the read-only data path). theta, m and v are read
// and then written back by the same thread, so they use plain loads.
float theta_f = __half2float(theta[idx]);
float m_f = m[idx];
float v_f = v[idx];
float grad_f = __half2float(__ldg(grad + idx));
// Moment updates in fp32.
m_f = beta1 * m_f + one_minus_beta1 * grad_f;
v_f = beta2 * v_f + one_minus_beta2 * grad_f * grad_f;
// Bias correction applied as a precomputed reciprocal.
float m_hat = m_f * inv_bc1;
float v_hat = v_f * inv_bc2;
// Decoupled weight decay is added directly to the update.
float update = m_hat / (sqrtf(v_hat) + eps) + weight_decay * theta_f;
theta_f = theta_f - lr * update;
theta[idx] = __float2half_rn(theta_f);
m[idx] = m_f;
v[idx] = v_f;
}
static void check(hipError_t status, const char* label) {
if (status != hipSuccess) {
std::cerr << "HIP_ERROR " << label << "=" << hipGetErrorString(status) << "\n";
std::exit(10);
}
}
// Forward declaration of the existing main() body, extracted into
// a static helper so the server-mode loop can call it on each
// request. The default `main()` also routes through this helper so
// the one-shot and server code paths share the same compute logic.
static int run_one_shot_from_main_body();
// Binary AdamW request magic. The host prefixes every binary request
// payload with this 4-byte little-endian value ("ADAW" in ASCII) so
// the server can dispatch to the binary parser. Text payloads start
// with the ASCII representation of an integer (e.g. "4194304 ..."),
// which is extremely unlikely to alias 0x57414441.
static constexpr uint32_t ADAW_BINARY_MAGIC = 0x57414441u;
// Binary request layout (little-endian, no host padding):
// [4B magic=0x57414441] [4B n] [4B lr] [4B beta1] [4B beta2]
// [4B eps] [4B weight_decay] [4B t] [4B pad] -- total 36 bytes
// [n*2B theta] [n*4B m] [n*4B v] [n*2B grad] -- total 12*n bytes
//
// Binary response layout (little-endian):
// [4B status=0] [4B n] [4B kernel_time_ms] [4B pad] -- total 16 bytes
// [n*2B theta_out] [n*4B m_out] [n*4B v_out] -- total 10*n bytes
//
// On a non-zero status, the response body is omitted and a plain
// text error string is returned instead. The persistent protocol
// length-prefixes the response, so the Rust side can recover the
// message with its standard text-error parser.
static int run_one_shot_binary(const char* payload, size_t payload_len,
std::vector<char>& response) {
constexpr size_t HEADER_SIZE = 36;
if (payload_len < HEADER_SIZE) {
const char* msg = "ERR binary request too short\n";
response.assign(msg, msg + strlen(msg));
return 30;
}
uint32_t magic = 0;
uint32_t n_u32 = 0;
float lr = 0.0f, beta1 = 0.0f, beta2 = 0.0f, eps = 0.0f, weight_decay = 0.0f;
int32_t t = 0;
memcpy(&magic, payload + 0, 4);
memcpy(&n_u32, payload + 4, 4);
memcpy(&lr, payload + 8, 4);
memcpy(&beta1, payload + 12, 4);
memcpy(&beta2, payload + 16, 4);
memcpy(&eps, payload + 20, 4);
memcpy(&weight_decay, payload + 24, 4);
memcpy(&t, payload + 28, 4);
if (magic != ADAW_BINARY_MAGIC) {
const char* msg = "ERR bad binary magic\n";
response.assign(msg, msg + strlen(msg));
return 31;
}
if (n_u32 == 0 || n_u32 > (1u << 28)) {
const char* msg = "ERR n out of range\n";
response.assign(msg, msg + strlen(msg));
return 32;
}
if (t <= 0) {
const char* msg = "ERR t must be positive\n";
response.assign(msg, msg + strlen(msg));
return 33;
}
int n = static_cast<int>(n_u32);
size_t expected_body = static_cast<size_t>(n) * (2 + 4 + 4 + 2);
if (payload_len < HEADER_SIZE + expected_body) {
const char* msg = "ERR payload body truncated\n";
response.assign(msg, msg + strlen(msg));
return 34;
}
const uint16_t* theta_in = reinterpret_cast<const uint16_t*>(payload + HEADER_SIZE);
const float* m_in = reinterpret_cast<const float*>(payload + HEADER_SIZE + static_cast<size_t>(n) * 2);
const float* v_in = reinterpret_cast<const float*>(payload + HEADER_SIZE + static_cast<size_t>(n) * 6);
const uint16_t* grad_in = reinterpret_cast<const uint16_t*>(payload + HEADER_SIZE + static_cast<size_t>(n) * 10);
int device = 0;
check(hipSetDevice(device), "hipSetDevice");
__half* d_theta = nullptr;
float* d_m = nullptr;
float* d_v = nullptr;
__half* d_grad = nullptr;
size_t theta_bytes = static_cast<size_t>(n) * sizeof(__half);
size_t m_bytes = static_cast<size_t>(n) * sizeof(float);
size_t v_bytes = static_cast<size_t>(n) * sizeof(float);
size_t grad_bytes = static_cast<size_t>(n) * sizeof(__half);
check(hipMalloc(&d_theta, theta_bytes), "hipMalloc(d_theta)");
check(hipMalloc(&d_m, m_bytes), "hipMalloc(d_m)");
check(hipMalloc(&d_v, v_bytes), "hipMalloc(d_v)");
check(hipMalloc(&d_grad, grad_bytes), "hipMalloc(d_grad)");
check(hipMemcpy(d_theta, theta_in, theta_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_theta)");
check(hipMemcpy(d_m, m_in, m_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_m)");
check(hipMemcpy(d_v, v_in, v_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_v)");
check(hipMemcpy(d_grad, grad_in, grad_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_grad)");
int block = 256;
int grid = (n + block - 1) / block;
float one_minus_beta1 = 1.0f - beta1;
float one_minus_beta2 = 1.0f - beta2;
float bc1 = 1.0f - powf(beta1, static_cast<float>(t));
float bc2 = 1.0f - powf(beta2, static_cast<float>(t));
float inv_bc1 = 1.0f / bc1;
float inv_bc2 = 1.0f / bc2;
hipEvent_t start, stop;
check(hipEventCreate(&start), "hipEventCreate(start)");
check(hipEventCreate(&stop), "hipEventCreate(stop)");
hipLaunchKernelGGL(adamw_step_fp16_f32_kernel, dim3(grid), dim3(block), 0, 0,
d_theta, d_m, d_v, d_grad, n, lr,
beta1, one_minus_beta1, beta2, one_minus_beta2,
inv_bc1, inv_bc2, eps, weight_decay);
check(hipGetLastError(), "hipLaunchKernelGGL(warmup1)");
hipLaunchKernelGGL(adamw_step_fp16_f32_kernel, dim3(grid), dim3(block), 0, 0,
d_theta, d_m, d_v, d_grad, n, lr,
beta1, one_minus_beta1, beta2, one_minus_beta2,
inv_bc1, inv_bc2, eps, weight_decay);
check(hipGetLastError(), "hipLaunchKernelGGL(warmup2)");
check(hipDeviceSynchronize(), "hipDeviceSynchronize(warmup)");
check(hipMemcpy(d_theta, theta_in, theta_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_theta)_rewarm");
check(hipMemcpy(d_m, m_in, m_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_m)_rewarm");
check(hipMemcpy(d_v, v_in, v_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_v)_rewarm");
check(hipMemcpy(d_grad, grad_in, grad_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_grad)_rewarm");
check(hipEventRecord(start), "hipEventRecord(start)");
hipLaunchKernelGGL(adamw_step_fp16_f32_kernel, dim3(grid), dim3(block), 0, 0,
d_theta, d_m, d_v, d_grad, n, lr,
beta1, one_minus_beta1, beta2, one_minus_beta2,
inv_bc1, inv_bc2, eps, weight_decay);
check(hipGetLastError(), "hipLaunchKernelGGL");
check(hipEventRecord(stop), "hipEventRecord(stop)");
check(hipEventSynchronize(stop), "hipEventSynchronize");
float kernel_time_ms = 0.0f;
check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
check(hipEventDestroy(start), "hipEventDestroy(start)");
check(hipEventDestroy(stop), "hipEventDestroy(stop)");
std::vector<uint16_t> theta_out(n, 0);
std::vector<float> m_out(n, 0.0f);
std::vector<float> v_out(n, 0.0f);
check(hipMemcpy(theta_out.data(), d_theta, theta_bytes, hipMemcpyDeviceToHost), "hipMemcpy(theta_out)");
check(hipMemcpy(m_out.data(), d_m, m_bytes, hipMemcpyDeviceToHost), "hipMemcpy(m_out)");
check(hipMemcpy(v_out.data(), d_v, v_bytes, hipMemcpyDeviceToHost), "hipMemcpy(v_out)");
check(hipFree(d_theta), "hipFree(d_theta)");
check(hipFree(d_m), "hipFree(d_m)");
check(hipFree(d_v), "hipFree(d_v)");
check(hipFree(d_grad), "hipFree(d_grad)");
constexpr size_t RESPONSE_HEADER_SIZE = 16;
response.resize(RESPONSE_HEADER_SIZE + static_cast<size_t>(n) * 10);
uint32_t status = 0;
uint32_t n_back = static_cast<uint32_t>(n);
uint32_t pad = 0;
memcpy(response.data() + 0, &status, 4);
memcpy(response.data() + 4, &n_back, 4);
memcpy(response.data() + 8, &kernel_time_ms, 4);
memcpy(response.data() + 12, &pad, 4);
memcpy(response.data() + RESPONSE_HEADER_SIZE, theta_out.data(), theta_bytes);
memcpy(response.data() + RESPONSE_HEADER_SIZE + theta_bytes, m_out.data(), m_bytes);
memcpy(response.data() + RESPONSE_HEADER_SIZE + theta_bytes + m_bytes, v_out.data(), v_bytes);
return 0;
}
// Persistent server-mode protocol (see hip_gemm_f16.rs for the full
// design rationale). The host writes a little-endian u32 payload_len
// followed by `payload_len` bytes of the existing text payload, then
// reads back a little-endian u32 response_len followed by
// `response_len` bytes of the existing text response.
//
// The server peeks the first 4 bytes of the payload and dispatches to
// the binary parser if they match ADAW_BINARY_MAGIC. Otherwise it
// falls through to the legacy text path.
static int run_server_mode() {
while (true) {
uint32_t payload_len = 0;
std::cin.read(reinterpret_cast<char*>(&payload_len), 4);
// A short read sets failbit, so the stream state alone cannot tell a
// clean EOF (0 bytes) from a truncated length prefix (1-3 bytes);
// decide on the number of bytes actually read.
std::streamsize header_got = std::cin.gcount();
if (header_got == 0) {
return 0; // clean EOF
}
if (header_got != 4) {
std::cerr << "server_mode: short read on payload_len (got "
<< header_got << " bytes)\n";
return 20;
}
std::vector<char> payload(payload_len);
if (payload_len > 0) {
std::cin.read(payload.data(), payload_len);
if (static_cast<uint32_t>(std::cin.gcount()) != payload_len) {
std::cerr << "server_mode: short read on payload (got "
<< std::cin.gcount() << " of " << payload_len << ")\n";
return 21;
}
}
// Peek the first 4 bytes; if they match ADAW_BINARY_MAGIC,
// dispatch to the binary parser (no cin/cout streambuf swap
// needed — the binary path takes raw pointers and returns
// raw bytes).
std::string response;
int rc;
uint32_t magic = 0;
if (payload_len >= 4) {
memcpy(&magic, payload.data(), 4);
}
if (magic == ADAW_BINARY_MAGIC) {
std::vector<char> response_buf;
rc = run_one_shot_binary(payload.data(), payload_len, response_buf);
response.assign(response_buf.begin(), response_buf.end());
} else {
std::string payload_str(payload.begin(), payload.end());
std::istringstream fake_stdin(payload_str);
std::streambuf* old_buf = std::cin.rdbuf(fake_stdin.rdbuf());
std::ostringstream captured;
std::streambuf* old_cout = std::cout.rdbuf(captured.rdbuf());
std::ostringstream captured_err;
std::streambuf* old_cerr = std::cerr.rdbuf(captured_err.rdbuf());
rc = run_one_shot_from_main_body();
std::cin.rdbuf(old_buf);
std::cout.rdbuf(old_cout);
std::cerr.rdbuf(old_cerr);
response = captured.str();
if (rc != 0) {
std::string err_str = captured_err.str();
response += err_str;
}
}
uint32_t response_len = static_cast<uint32_t>(response.size());
std::cout.write(reinterpret_cast<const char*>(&response_len), 4);
if (response_len > 0) {
std::cout.write(response.data(), response_len);
}
std::cout.flush();
if (rc != 0) {
return rc;
}
}
}
int main(int argc, char** argv) {
if (argc > 1 && std::string(argv[1]) == "--server") {
return run_server_mode();
}
return run_one_shot_from_main_body();
}
static int run_one_shot_from_main_body() {
int n = 0;
float lr = 0.0f;
float beta1 = 0.0f;
float beta2 = 0.0f;
float eps = 0.0f;
float weight_decay = 0.0f;
int t = 0;
if (!(std::cin >> n >> lr >> beta1 >> beta2 >> eps >> weight_decay >> t)) {
std::cerr << "usage: stdin payload is \"N LR BETA1 BETA2 EPS WEIGHT_DECAY T\\n<theta_bits> ...\\n<m> ...\\n<v> ...\\n<grad_bits> ...\\n\"\n";
return 2;
}
if (n <= 0) {
std::cerr << "N must be positive, got " << n << "\n";
return 3;
}
if (t <= 0) {
std::cerr << "T must be positive (1-based step index), got " << t << "\n";
return 4;
}
int device = 0;
check(hipSetDevice(device), "hipSetDevice");
hipDeviceProp_t props;
check(hipGetDeviceProperties(&props, device), "hipGetDeviceProperties");
std::vector<uint16_t> theta_bits(n, 0);
std::vector<float> m(n, 0.0f);
std::vector<float> v(n, 0.0f);
std::vector<uint16_t> grad_bits(n, 0);
for (int i = 0; i < n; ++i) {
if (!(std::cin >> theta_bits[i])) {
std::cerr << "failed to read theta_bits[" << i << "]\n";
return 5;
}
}
for (int i = 0; i < n; ++i) {
if (!(std::cin >> m[i])) {
std::cerr << "failed to read m[" << i << "]\n";
return 6;
}
}
for (int i = 0; i < n; ++i) {
if (!(std::cin >> v[i])) {
std::cerr << "failed to read v[" << i << "]\n";
return 7;
}
}
for (int i = 0; i < n; ++i) {
if (!(std::cin >> grad_bits[i])) {
std::cerr << "failed to read grad_bits[" << i << "]\n";
return 8;
}
}
__half* d_theta = nullptr;
float* d_m = nullptr;
float* d_v = nullptr;
__half* d_grad = nullptr;
std::size_t theta_bytes = static_cast<std::size_t>(n) * sizeof(__half);
std::size_t grad_bytes = static_cast<std::size_t>(n) * sizeof(__half);
std::size_t m_bytes = static_cast<std::size_t>(n) * sizeof(float);
std::size_t v_bytes = static_cast<std::size_t>(n) * sizeof(float);
check(hipMalloc(&d_theta, theta_bytes), "hipMalloc(d_theta)");
check(hipMalloc(&d_m, m_bytes), "hipMalloc(d_m)");
check(hipMalloc(&d_v, v_bytes), "hipMalloc(d_v)");
check(hipMalloc(&d_grad, grad_bytes), "hipMalloc(d_grad)");
check(hipMemcpy(d_theta, theta_bits.data(), theta_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_theta)");
check(hipMemcpy(d_m, m.data(), m_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_m)");
check(hipMemcpy(d_v, v.data(), v_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_v)");
check(hipMemcpy(d_grad, grad_bits.data(), grad_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_grad)");
int block = 256;
int grid = (n + block - 1) / block;
// Precompute the per-step scalars on the host so the kernel avoids
// a per-thread powf. This is the dominant cost at small problem
// sizes (the 1024-element pilot workload).
float one_minus_beta1 = 1.0f - beta1;
float one_minus_beta2 = 1.0f - beta2;
float bc1 = 1.0f - powf(beta1, static_cast<float>(t));
float bc2 = 1.0f - powf(beta2, static_cast<float>(t));
float inv_bc1 = 1.0f / bc1;
float inv_bc2 = 1.0f / bc2;
hipEvent_t start;
hipEvent_t stop;
check(hipEventCreate(&start), "hipEventCreate(start)");
check(hipEventCreate(&stop), "hipEventCreate(stop)");
// Warmup launches to ensure the GPU clock is at its target frequency
// before the timed region. This matters on some ROCm stacks where
// the first few kernel launches run at a reduced clock. We re-copy
// the input buffers afterward so the warmup does not mutate the
// data that the timed launch will operate on. Two warmup launches
// are used to fully stabilize the clock on gfx1101.
hipLaunchKernelGGL(adamw_step_fp16_f32_kernel, dim3(grid), dim3(block), 0, 0,
d_theta, d_m, d_v, d_grad, n, lr,
beta1, one_minus_beta1, beta2, one_minus_beta2,
inv_bc1, inv_bc2, eps, weight_decay);
check(hipGetLastError(), "hipLaunchKernelGGL(warmup1)");
hipLaunchKernelGGL(adamw_step_fp16_f32_kernel, dim3(grid), dim3(block), 0, 0,
d_theta, d_m, d_v, d_grad, n, lr,
beta1, one_minus_beta1, beta2, one_minus_beta2,
inv_bc1, inv_bc2, eps, weight_decay);
check(hipGetLastError(), "hipLaunchKernelGGL(warmup2)");
check(hipDeviceSynchronize(), "hipDeviceSynchronize(warmup)");
// Re-copy the original inputs to the device so the warmup launches
// do not pollute the state that the timed launch observes.
check(hipMemcpy(d_theta, theta_bits.data(), theta_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_theta)_rewarm");
check(hipMemcpy(d_m, m.data(), m_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_m)_rewarm");
check(hipMemcpy(d_v, v.data(), v_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_v)_rewarm");
check(hipMemcpy(d_grad, grad_bits.data(), grad_bytes, hipMemcpyHostToDevice), "hipMemcpy(d_grad)_rewarm");
check(hipEventRecord(start), "hipEventRecord(start)");
hipLaunchKernelGGL(adamw_step_fp16_f32_kernel, dim3(grid), dim3(block), 0, 0,
d_theta, d_m, d_v, d_grad, n, lr,
beta1, one_minus_beta1, beta2, one_minus_beta2,
inv_bc1, inv_bc2, eps, weight_decay);
check(hipGetLastError(), "hipLaunchKernelGGL");
check(hipEventRecord(stop), "hipEventRecord(stop)");
check(hipEventSynchronize(stop), "hipEventSynchronize");
float kernel_time_ms = 0.0f;
check(hipEventElapsedTime(&kernel_time_ms, start, stop), "hipEventElapsedTime");
check(hipEventDestroy(start), "hipEventDestroy(start)");
check(hipEventDestroy(stop), "hipEventDestroy(stop)");
std::vector<uint16_t> theta_out(n, 0);
std::vector<float> m_out(n, 0.0f);
std::vector<float> v_out(n, 0.0f);
check(hipMemcpy(theta_out.data(), d_theta, theta_bytes, hipMemcpyDeviceToHost), "hipMemcpy(theta_out)");
check(hipMemcpy(m_out.data(), d_m, m_bytes, hipMemcpyDeviceToHost), "hipMemcpy(m_out)");
check(hipMemcpy(v_out.data(), d_v, v_bytes, hipMemcpyDeviceToHost), "hipMemcpy(v_out)");
check(hipFree(d_theta), "hipFree(d_theta)");
check(hipFree(d_m), "hipFree(d_m)");
check(hipFree(d_v), "hipFree(d_v)");
check(hipFree(d_grad), "hipFree(d_grad)");
std::cout << "DEVICE_NAME=" << props.name << "\n";
std::cout << "GFX=" << props.gcnArchName << "\n";
std::cout << "N=" << n << "\n";
std::cout << "GRID=" << grid << "\n";
std::cout << "BLOCK=" << block << "\n";
std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
std::cout << "THETA=";
for (int i = 0; i < n; ++i) {
if (i != 0) {
std::cout << ",";
}
std::cout << static_cast<unsigned int>(theta_out[i]);
}
std::cout << "\n";
// m and v are fp32 optimizer state that the host feeds back into the
// next step, so print them with 9 significant digits (max_digits10 for
// float) so they round-trip exactly; the default precision of 6 would
// silently round the state on every text-protocol step.
std::streamsize saved_precision = std::cout.precision(9);
std::cout << "M=";
for (int i = 0; i < n; ++i) {
if (i != 0) {
std::cout << ",";
}
std::cout << m_out[i];
}
std::cout << "\n";
std::cout << "V=";
for (int i = 0; i < n; ++i) {
if (i != 0) {
std::cout << ",";
}
std::cout << v_out[i];
}
std::cout << "\n";
std::cout.precision(saved_precision);
return 0;
}
"#;
/// Diagnostic report for one GPU AdamW step produced by this module's
/// ROCm/HIP runners.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipAdamwStepReport {
    /// Backend label (the runners here set it to `ROCM_HIP_ADAMW_BACKEND`).
    pub backend: String,
    /// Number of parameters updated by the step (`theta.len()`).
    pub n: usize,
    /// Updated parameters as raw f16 bit patterns, as returned by the GPU.
    pub theta: Vec<u16>,
    /// Updated first moment (f32), as returned by the GPU.
    pub m: Vec<f32>,
    /// Updated second moment (f32), as returned by the GPU.
    pub v: Vec<f32>,
    /// CPU reference theta (f16 bit patterns). Contents depend on the
    /// producing path: the binary runner leaves it empty.
    /// NOTE(review): confirm any other path computes it independently.
    pub cpu_oracle_theta: Vec<u16>,
    /// CPU reference first moment; empty on the binary runner.
    pub cpu_oracle_m: Vec<f32>,
    /// CPU reference second moment; empty on the binary runner.
    pub cpu_oracle_v: Vec<f32>,
    /// Largest absolute theta difference versus the CPU reference, as
    /// reported by the producing path (the binary runner always reports 0.0).
    pub max_abs_error_theta: f32,
    /// Whether the producing path judged the GPU theta to be within tolerance.
    /// NOTE(review): the binary runner reports `true` without any comparison.
    pub within_tolerance: bool,
    /// Duration of the single timed kernel launch, measured with HIP events, in ms.
    pub kernel_time_ms: f32,
    /// Fingerprint of the embedded HIP source; also names the cached source
    /// and executable files.
    pub kernel_source_fingerprint: String,
    /// Fingerprint of the hipcc compiler used for the build.
    pub compiler_fingerprint: String,
    /// Build command returned by `hipcc_compile_executable`.
    pub build_command: String,
    /// Path of the compiled HIP executable.
    pub executable_path: String,
    /// ROCm/HIP availability report from `detect_local_rocm_hip`.
    pub device_evidence: RocmHipCapabilityReport,
    /// Human-readable list of what the run did.
    pub evidence: Vec<String>,
    /// Human-readable list of what the run does not demonstrate.
    pub non_claims: Vec<String>,
}
impl RocmHipAdamwStepReport {
    /// Renders the report as Markdown.
    ///
    /// The output is a title, one `key: value` line per headline field, then
    /// a bulleted `## Evidence` section and a bulleted `## Non-Claims`
    /// section. Lines are joined with `\n`, with no trailing newline.
    pub fn to_markdown(&self) -> String {
        let summary = vec![
            "# ROCm/HIP AdamW Step Pilot".to_string(),
            String::new(),
            format!("backend: {}", self.backend),
            format!("n: {}", self.n),
            format!("max_abs_error_theta: {}", self.max_abs_error_theta),
            format!("within_tolerance: {}", self.within_tolerance),
            format!("kernel_time_ms: {}", self.kernel_time_ms),
            format!(
                "kernel_source_fingerprint: {}",
                self.kernel_source_fingerprint
            ),
            format!("compiler_fingerprint: {}", self.compiler_fingerprint),
            String::new(),
            "## Evidence".to_string(),
        ];
        // Captures nothing, so the closure is `Copy` and can be reused below.
        let bullet = |item: &String| format!("- {item}");
        let non_claims_heading = [String::new(), "## Non-Claims".to_string()];
        summary
            .into_iter()
            .chain(self.evidence.iter().map(bullet))
            .chain(non_claims_heading)
            .chain(self.non_claims.iter().map(bullet))
            .collect::<Vec<String>>()
            .join("\n")
    }
}
/// Fingerprint of [`HIP_ADAMW_KERNEL`] under the `hip-adamw-source` label.
///
/// The runners use this value to name the cached `.cpp` source and the
/// compiled executable.
pub fn hip_adamw_kernel_source_fingerprint() -> String {
    const LABEL: &str = "hip-adamw-source";
    fingerprint(LABEL, HIP_ADAMW_KERNEL)
}
/// Re-checks the cached AdamW build against its source, then sends the text
/// `payload` to it through the persistent kernel server.
///
/// Returns the executable's text response, as produced by
/// `kernel_server::run_persistent`.
fn run_adamw_executable(
    executable_path: &std::path::Path,
    source_path: &std::path::Path,
    payload: &str,
) -> Result<String> {
    const HIPCC: &str = "/opt/rocm/bin/hipcc";
    const OFFLOAD_ARCH: &str = "gfx1101";
    hipcc_recheck_artifact(HIPCC, source_path, executable_path, Some(OFFLOAD_ARCH))?;
    kernel_server::run_persistent(ADAMW_STEP_KERNEL_TYPE, executable_path, payload)
}
/// CPU reference implementation of one AdamW step.
///
/// Mirrors `adamw_step_fp16_f32_kernel` in [`HIP_ADAMW_KERNEL`]:
/// - `theta` and `grad` are f16 bit patterns;
/// - `m` and `v` are the f32 first and second moments;
/// - `t` is the 1-based step index used for bias correction;
/// - `theta`, `m` and `v` are updated in place.
///
/// # Panics
///
/// Panics if `m`, `v` or `grad` does not have the same length as `theta`.
pub fn cpu_adamw_step(
    theta: &mut [u16],
    m: &mut [f32],
    v: &mut [f32],
    grad: &[u16],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) {
    let n = theta.len();
    assert_eq!(m.len(), n, "AdamW CPU oracle: m length mismatch");
    assert_eq!(v.len(), n, "AdamW CPU oracle: v length mismatch");
    assert_eq!(grad.len(), n, "AdamW CPU oracle: grad length mismatch");
    // Compute the bias-correction terms exactly as the HIP host code does:
    // `1.0f - powf(beta, static_cast<float>(t))`. This keeps the oracle and
    // the kernel on identical `inv_bc` scalars. `powi` is documented to
    // round differently from `powf`.
    let step = t as f32;
    let inv_bc1 = 1.0f32 / (1.0f32 - beta1.powf(step));
    let inv_bc2 = 1.0f32 / (1.0f32 - beta2.powf(step));
    let one_minus_beta1 = 1.0f32 - beta1;
    let one_minus_beta2 = 1.0f32 - beta2;
    let state = theta.iter_mut().zip(m.iter_mut()).zip(v.iter_mut());
    for (((theta_i, m_i), v_i), &grad_bits) in state.zip(grad) {
        let theta_f = f16_to_f32(*theta_i);
        let grad_f = f16_to_f32(grad_bits);
        let m_new = one_minus_beta1.mul_add(grad_f, beta1 * *m_i);
        let v_new = one_minus_beta2.mul_add(grad_f * grad_f, beta2 * *v_i);
        let m_hat = m_new * inv_bc1;
        let v_hat = v_new * inv_bc2;
        // Decoupled weight decay is folded into the update, as in the kernel.
        let update = weight_decay.mul_add(theta_f, m_hat / (v_hat.sqrt() + eps));
        *theta_i = f32_to_f16(theta_f - lr * update);
        *m_i = m_new;
        *v_i = v_new;
    }
}
pub use tokitai_operator::backend::f16_convert::{f16_to_f32, f32_to_f16};
/// Runs one AdamW step on the GPU through the text protocol.
///
/// Returns a report that compares the GPU theta against [`cpu_adamw_step`].
/// `theta` and `grad` hold f16 bit patterns; `m` and `v` hold the f32 first
/// and second moments; `t` is the 1-based step index used for bias
/// correction. On success, `theta`, `m` and `v` are overwritten with the GPU
/// results.
///
/// # Errors
///
/// Returns an error when:
/// - the buffers are empty or have mismatched lengths, or `t <= 0`;
/// - ROCm/HIP is unavailable;
/// - writing, compiling or running the HIP executable fails;
/// - the executable's output is missing a marker, or does not contain
///   exactly `n` values per output vector.
///
/// The caller's buffers are left untouched in every error case.
fn run_rocm_hip_adamw_step_inner(
    theta: &mut [u16],
    m: &mut [f32],
    v: &mut [f32],
    grad: &[u16],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) -> Result<RocmHipAdamwStepReport> {
    // Tolerance for the GPU-vs-CPU theta comparison. Both sides round the
    // updated theta to f16, so tiny fp32 differences (for example FMA
    // contraction on the device) can move a result by about one f16 ulp.
    // A 1e-3 relative bound covers one ulp for normal f16 values; the
    // absolute term covers values near zero.
    // NOTE(review): confirm these bounds against measured GPU/CPU deltas.
    const THETA_ABS_TOLERANCE: f32 = 1e-3;
    const THETA_REL_TOLERANCE: f32 = 1e-3;
    let n = theta.len();
    if n == 0 {
        return Err(Error::backend("AdamW step requires a non-empty theta slice"));
    }
    if m.len() != n {
        return Err(Error::backend(format!(
            "AdamW step m length {} does not match theta length {}",
            m.len(),
            n
        )));
    }
    if v.len() != n {
        return Err(Error::backend(format!(
            "AdamW step v length {} does not match theta length {}",
            v.len(),
            n
        )));
    }
    if grad.len() != n {
        return Err(Error::backend(format!(
            "AdamW step grad length {} does not match theta length {}",
            grad.len(),
            n
        )));
    }
    if t <= 0 {
        return Err(Error::backend(format!(
            "AdamW step t must be a 1-based positive step index, got {t}"
        )));
    }
    let device_evidence = detect_local_rocm_hip();
    if !device_evidence.available {
        return Err(Error::backend(
            "ROCm/HIP is unavailable; AdamW step HIP pilot remains inadmissible",
        ));
    }
    let source_fingerprint = hip_adamw_kernel_source_fingerprint();
    let cache_dir = PathBuf::from("target/rocm-hip-cache");
    fs::create_dir_all(&cache_dir)
        .map_err(|err| Error::backend(format!("failed to create HIP cache directory: {err}")))?;
    let source_path = cache_dir.join(format!("{source_fingerprint}.cpp"));
    let executable_path = cache_dir.join(format!("{source_fingerprint}-adamw-fp16-f32"));
    fs::write(&source_path, HIP_ADAMW_KERNEL)
        .map_err(|err| Error::backend(format!("failed to write HIP kernel source: {err}")))?;
    let hipcc = "/opt/rocm/bin/hipcc";
    let compiler_fingerprint = hipcc_compiler_fingerprint(hipcc)?;
    let build_command =
        hipcc_compile_executable(hipcc, &source_path, &executable_path, Some("gfx1101"))?;
    let mut payload = String::with_capacity(n * 32 + 256);
    payload.push_str(&format!(
        "{n} {} {} {} {} {} {t}\n",
        f32_to_payload_string(lr),
        f32_to_payload_string(beta1),
        f32_to_payload_string(beta2),
        f32_to_payload_string(eps),
        f32_to_payload_string(weight_decay)
    ));
    for (i, val) in theta.iter().enumerate() {
        if i != 0 {
            payload.push(' ');
        }
        payload.push_str(&val.to_string());
    }
    payload.push('\n');
    for (i, val) in m.iter().enumerate() {
        if i != 0 {
            payload.push(' ');
        }
        payload.push_str(&f32_to_payload_string(*val));
    }
    payload.push('\n');
    for (i, val) in v.iter().enumerate() {
        if i != 0 {
            payload.push(' ');
        }
        payload.push_str(&f32_to_payload_string(*val));
    }
    payload.push('\n');
    for (i, val) in grad.iter().enumerate() {
        if i != 0 {
            payload.push(' ');
        }
        payload.push_str(&val.to_string());
    }
    payload.push('\n');
    let stdout = run_adamw_executable(&executable_path, &source_path, &payload)?;
    let theta_out = parse_u16_csv(&stdout, "THETA=")?;
    let m_out = parse_f32_csv(&stdout, "M=")?;
    let v_out = parse_f32_csv(&stdout, "V=")?;
    let kernel_time_ms = parse_f32_line(&stdout, "KERNEL_TIME_MS=")
        .ok_or_else(|| Error::backend("HIP AdamW step did not print KERNEL_TIME_MS marker"))?;
    // `copy_from_slice` below panics on a length mismatch, so reject a
    // malformed executable output as an error first.
    if theta_out.len() != n || m_out.len() != n || v_out.len() != n {
        return Err(Error::backend(format!(
            "HIP AdamW step returned {} theta, {} m and {} v values; expected {n} each",
            theta_out.len(),
            m_out.len(),
            v_out.len()
        )));
    }
    // Independent CPU reference computed from the caller's original inputs,
    // which are still unmodified at this point.
    let mut cpu_oracle_theta = theta.to_vec();
    let mut cpu_oracle_m = m.to_vec();
    let mut cpu_oracle_v = v.to_vec();
    cpu_adamw_step(
        &mut cpu_oracle_theta,
        &mut cpu_oracle_m,
        &mut cpu_oracle_v,
        grad,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        t,
    );
    let mut max_abs_error_theta = 0.0f32;
    let mut within_tolerance = true;
    for (&gpu_bits, &cpu_bits) in theta_out.iter().zip(&cpu_oracle_theta) {
        let gpu = f16_to_f32(gpu_bits);
        let cpu = f16_to_f32(cpu_bits);
        // Equal values (including equal infinities) and NaN-vs-NaN count as
        // a match; a NaN on only one side is an unbounded error.
        let abs_error = if gpu == cpu || (gpu.is_nan() && cpu.is_nan()) {
            0.0
        } else if gpu.is_nan() || cpu.is_nan() {
            f32::INFINITY
        } else {
            (gpu - cpu).abs()
        };
        // A non-finite reference value must be matched exactly.
        let allowed = if cpu.is_finite() {
            THETA_ABS_TOLERANCE + THETA_REL_TOLERANCE * cpu.abs()
        } else {
            0.0
        };
        max_abs_error_theta = max_abs_error_theta.max(abs_error);
        within_tolerance &= abs_error <= allowed;
    }
    theta.copy_from_slice(&theta_out);
    m.copy_from_slice(&m_out);
    v.copy_from_slice(&v_out);
    Ok(RocmHipAdamwStepReport {
        backend: ROCM_HIP_ADAMW_BACKEND.to_string(),
        n,
        theta: theta_out,
        m: m_out,
        v: v_out,
        cpu_oracle_theta,
        cpu_oracle_m,
        cpu_oracle_v,
        max_abs_error_theta,
        within_tolerance,
        kernel_time_ms,
        kernel_source_fingerprint: source_fingerprint,
        compiler_fingerprint,
        build_command,
        executable_path: executable_path.display().to_string(),
        device_evidence,
        evidence: vec![
            "compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101".to_string(),
            "shipped theta/m/v/grad to the kernel via stdin (Stdio::piped)".to_string(),
            "launched adamw_step_fp16_f32_kernel with grid=(n/256) block=(256)".to_string(),
            "captured kernel time with hipEventRecord/hipEventSynchronize".to_string(),
            "compared GPU theta against the cpu_adamw_step CPU oracle".to_string(),
        ],
        non_claims: vec![
            "not a general-purpose optimizer".to_string(),
            "not AMSGrad or Nadam variants".to_string(),
            "not production speedup evidence".to_string(),
            "not machine-code verification".to_string(),
        ],
    })
}
/// Runs one AdamW step on the GPU via the text protocol and discards the
/// diagnostic report.
///
/// `theta`, `m` and `v` are updated in place on success.
///
/// # Errors
///
/// Propagates every error from the underlying GPU step: invalid inputs,
/// unavailable ROCm/HIP, build or execution failure, or unparseable output.
pub fn run_rocm_hip_adamw_step(
    theta: &mut [u16],
    m: &mut [f32],
    v: &mut [f32],
    grad: &[u16],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) -> Result<()> {
    run_rocm_hip_adamw_step_inner(theta, m, v, grad, lr, beta1, beta2, eps, weight_decay, t)
        .map(drop)
}
/// Magic word that prefixes every binary AdamW request. Its little-endian
/// bytes spell "ADAW" and must equal `ADAW_BINARY_MAGIC` in the HIP source.
const ADAW_BINARY_MAGIC_HOST: u32 = u32::from_le_bytes(*b"ADAW");
/// Binary request header: magic, n, lr, beta1, beta2, eps, weight_decay, t
/// and one padding word — nine 4-byte fields.
const ADAW_HEADER_SIZE: usize = 9 * 4;
/// Binary response header: status, n, kernel_time_ms, pad — four 4-byte fields.
const ADAW_RESPONSE_HEADER_SIZE: usize = 4 * 4;
/// Serializes one AdamW step into the little-endian binary request format
/// read by the HIP executable's `--server` mode.
///
/// Layout: a 36-byte header (magic, n, lr, beta1, beta2, eps, weight_decay,
/// t, zero pad), followed by theta (2 B per element), m (4 B), v (4 B) and
/// grad (2 B). The header's element count is `theta.len()`; callers must pass
/// `m`, `v` and `grad` of the same length.
///
/// NOTE: the count is narrowed with `as u32`; the HIP side rejects
/// n > 2^28 anyway.
pub fn build_adamw_binary_request(
    theta: &[u16],
    m: &[f32],
    v: &[f32],
    grad: &[u16],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) -> Vec<u8> {
    let count = theta.len();
    let header_words: [[u8; 4]; 9] = [
        ADAW_BINARY_MAGIC_HOST.to_le_bytes(),
        (count as u32).to_le_bytes(),
        lr.to_le_bytes(),
        beta1.to_le_bytes(),
        beta2.to_le_bytes(),
        eps.to_le_bytes(),
        weight_decay.to_le_bytes(),
        t.to_le_bytes(),
        0u32.to_le_bytes(),
    ];
    let mut request = Vec::with_capacity(ADAW_HEADER_SIZE + count * 12);
    for word in &header_words {
        request.extend_from_slice(word);
    }
    request.extend(theta.iter().flat_map(|x| x.to_le_bytes()));
    request.extend(m.iter().flat_map(|x| x.to_le_bytes()));
    request.extend(v.iter().flat_map(|x| x.to_le_bytes()));
    request.extend(grad.iter().flat_map(|x| x.to_le_bytes()));
    request
}
fn parse_adamw_binary_response(
response: &[u8],
theta_out: &mut [u16],
m_out: &mut [f32],
v_out: &mut [f32],
) -> Result<f32> {
if response.len() < ADAW_RESPONSE_HEADER_SIZE {
return Err(Error::backend(format!(
"AdamW binary response too short: {} bytes (expected >= {})",
response.len(),
ADAW_RESPONSE_HEADER_SIZE
)));
}
let status = u32::from_le_bytes(response[0..4].try_into().unwrap());
let n = u32::from_le_bytes(response[4..8].try_into().unwrap()) as usize;
let kernel_time_ms = f32::from_le_bytes(response[8..12].try_into().unwrap());
if status != 0 {
let msg = String::from_utf8_lossy(&response[ADAW_RESPONSE_HEADER_SIZE..]);
return Err(Error::backend(format!(
"HIP AdamW binary step kernel reported status={status}: {msg}"
)));
}
let expected_body = n * 10;
if response.len() < ADAW_RESPONSE_HEADER_SIZE + expected_body {
return Err(Error::backend(format!(
"AdamW binary response body truncated: got {} bytes, expected {}",
response.len(),
ADAW_RESPONSE_HEADER_SIZE + expected_body
)));
}
if theta_out.len() != n || m_out.len() != n || v_out.len() != n {
return Err(Error::backend(format!(
"AdamW binary response n={n} does not match caller buffer sizes (theta={}, m={}, v={})",
theta_out.len(),
m_out.len(),
v_out.len()
)));
}
let mut offset = ADAW_RESPONSE_HEADER_SIZE;
for slot in theta_out.iter_mut() {
*slot = u16::from_le_bytes(response[offset..offset + 2].try_into().unwrap());
offset += 2;
}
for slot in m_out.iter_mut() {
*slot = f32::from_le_bytes(response[offset..offset + 4].try_into().unwrap());
offset += 4;
}
for slot in v_out.iter_mut() {
*slot = f32::from_le_bytes(response[offset..offset + 4].try_into().unwrap());
offset += 4;
}
Ok(kernel_time_ms)
}
/// Runs one AdamW step on the GPU through the binary `--server` protocol.
///
/// `theta`, `m` and `v` are updated in place. `theta`/`grad` hold f16 bit
/// patterns, `m`/`v` hold f32 moments, and `t` is the 1-based step index.
///
/// No CPU oracle runs on this path. The returned report has empty
/// `cpu_oracle_*` vectors and `max_abs_error_theta == 0.0`, and it reports
/// `within_tolerance == true` without any comparison having been made.
///
/// # Errors
///
/// Returns an error for empty or mismatched buffers, `t <= 0`, unavailable
/// ROCm/HIP, build/recheck/execution failures, or a response that fails to
/// parse. Caller buffers are only overwritten after the response parses.
pub fn run_rocm_hip_adamw_step_binary(
    theta: &mut [u16],
    m: &mut [f32],
    v: &mut [f32],
    grad: &[u16],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) -> Result<RocmHipAdamwStepReport> {
    const HIPCC: &str = "/opt/rocm/bin/hipcc";
    const OFFLOAD_ARCH: Option<&str> = Some("gfx1101");

    let count = theta.len();
    if count == 0 {
        return Err(Error::backend(
            "AdamW binary step requires a non-empty theta slice",
        ));
    }
    let lengths_agree = [m.len(), v.len(), grad.len()]
        .iter()
        .all(|&len| len == count);
    if !lengths_agree {
        return Err(Error::backend(format!(
            "AdamW binary step buffer length mismatch: theta={count}, m={}, v={}, grad={}",
            m.len(),
            v.len(),
            grad.len()
        )));
    }
    if t <= 0 {
        return Err(Error::backend(format!(
            "AdamW binary step t must be a 1-based positive step index, got {t}"
        )));
    }

    let device_evidence = detect_local_rocm_hip();
    if !device_evidence.available {
        return Err(Error::backend(
            "ROCm/HIP is unavailable; AdamW binary step remains inadmissible",
        ));
    }

    // Cache layout: both files are keyed by the kernel source fingerprint.
    let kernel_fingerprint = hip_adamw_kernel_source_fingerprint();
    let cache_root = PathBuf::from("target/rocm-hip-cache");
    fs::create_dir_all(&cache_root)
        .map_err(|err| Error::backend(format!("failed to create HIP cache directory: {err}")))?;
    let source_path = cache_root.join(format!("{kernel_fingerprint}.cpp"));
    let executable_path = cache_root.join(format!("{kernel_fingerprint}-adamw-fp16-f32"));
    fs::write(&source_path, HIP_ADAMW_KERNEL)
        .map_err(|err| Error::backend(format!("failed to write HIP kernel source: {err}")))?;
    let compiler_fingerprint = hipcc_compiler_fingerprint(HIPCC)?;
    let build_command =
        hipcc_compile_executable(HIPCC, &source_path, &executable_path, OFFLOAD_ARCH)?;

    let request = build_adamw_binary_request(
        theta,
        m,
        v,
        grad,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        t,
    );
    hipcc_recheck_artifact(HIPCC, &source_path, &executable_path, OFFLOAD_ARCH)?;
    let reply = kernel_server::run_persistent_binary(
        ADAMW_STEP_KERNEL_TYPE,
        &executable_path,
        &request,
    )?;

    let mut gpu_theta = vec![0u16; count];
    let mut gpu_m = vec![0f32; count];
    let mut gpu_v = vec![0f32; count];
    let kernel_time_ms =
        parse_adamw_binary_response(&reply, &mut gpu_theta, &mut gpu_m, &mut gpu_v)?;
    theta.copy_from_slice(&gpu_theta);
    m.copy_from_slice(&gpu_m);
    v.copy_from_slice(&gpu_v);

    let evidence = [
        "compiled HIP kernel with /opt/rocm/bin/hipcc -O2 --offload-arch=gfx1101",
        "shipped theta/m/v/grad to the kernel via binary I/O (ADAW magic + 36B header + 12B*n body)",
        "launched adamw_step_fp16_f32_kernel with grid=(n/256) block=(256)",
        "captured kernel time with hipEventRecord/hipEventSynchronize",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();
    let non_claims = [
        "binary I/O is for per-call wire overhead, not production speedup",
        "not AMSGrad or Nadam variants",
        "not machine-code verification",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();

    Ok(RocmHipAdamwStepReport {
        backend: ROCM_HIP_ADAMW_BACKEND.to_string(),
        n: count,
        theta: gpu_theta,
        m: gpu_m,
        v: gpu_v,
        cpu_oracle_theta: Vec::new(),
        cpu_oracle_m: Vec::new(),
        cpu_oracle_v: Vec::new(),
        max_abs_error_theta: 0.0,
        within_tolerance: true,
        kernel_time_ms,
        kernel_source_fingerprint: kernel_fingerprint,
        compiler_fingerprint,
        build_command,
        executable_path: executable_path.display().to_string(),
        device_evidence,
        evidence,
        non_claims,
    })
}
/// Applies one AdamW step to many parameter tensors with a single HIP launch.
///
/// The per-parameter `theta`/`m`/`v`/`grad` slices are concatenated (in order)
/// into flat buffers. These are handed once to `run_rocm_hip_adamw_step_binary`,
/// and the updated `theta`/`m`/`v` values are then copied back into the caller's
/// slices. The caller's slices are only written after the device call returns `Ok`.
///
/// # Errors
/// Returns a backend error when:
/// - no parameter slices are given;
/// - the four slice lists differ in count;
/// - any parameter's four slices differ in length;
/// - every slice is empty;
/// - the underlying binary step fails.
pub fn run_rocm_hip_adamw_step_all_binary(
    theta_slices: &mut [Vec<u16>],
    m_slices: &mut [Vec<f32>],
    v_slices: &mut [Vec<f32>],
    grad_slices: &[Vec<u16>],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) -> Result<()> {
    let param_count = theta_slices.len();
    if param_count == 0 {
        return Err(Error::backend(
            "AdamW batched binary step requires at least one parameter slice",
        ));
    }
    let counts_match = m_slices.len() == param_count
        && v_slices.len() == param_count
        && grad_slices.len() == param_count;
    if !counts_match {
        return Err(Error::backend(format!(
            "AdamW batched binary step: slice count mismatch (theta={param_count}, m={}, v={}, grad={})",
            m_slices.len(),
            v_slices.len(),
            grad_slices.len()
        )));
    }
    // Every parameter must have theta, m, v and grad of one common length.
    for i in 0..param_count {
        let len = theta_slices[i].len();
        let m_len = m_slices[i].len();
        let v_len = v_slices[i].len();
        let grad_len = grad_slices[i].len();
        if m_len != len || v_len != len || grad_len != len {
            return Err(Error::backend(format!(
                "AdamW batched binary step: param[{i}] slice length mismatch \
                 (theta={len}, m={m_len}, v={v_len}, grad={grad_len})"
            )));
        }
    }
    let total_n: usize = theta_slices.iter().map(Vec::len).sum();
    if total_n == 0 {
        return Err(Error::backend(
            "AdamW batched binary step: total parameter count is zero",
        ));
    }
    // Gather: parameter i occupies the same contiguous range in every flat buffer.
    let mut flat_theta = theta_slices.concat();
    let mut flat_m = m_slices.concat();
    let mut flat_v = v_slices.concat();
    let flat_grad = grad_slices.concat();
    run_rocm_hip_adamw_step_binary(
        &mut flat_theta,
        &mut flat_m,
        &mut flat_v,
        &flat_grad,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        t,
    )?;
    // Scatter: walk the same ranges again to copy results back per parameter.
    let mut cursor = 0usize;
    for ((theta, m), v) in theta_slices
        .iter_mut()
        .zip(m_slices.iter_mut())
        .zip(v_slices.iter_mut())
    {
        let end = cursor + theta.len();
        theta.copy_from_slice(&flat_theta[cursor..end]);
        m.copy_from_slice(&flat_m[cursor..end]);
        v.copy_from_slice(&flat_v[cursor..end]);
        cursor = end;
    }
    Ok(())
}
/// Runs one AdamW step on the HIP backend and cross-checks the updated `theta`
/// against a CPU reference computed from the same inputs.
///
/// The flow has three steps:
/// 1. `theta`, `m` and `v` are snapshotted before the device call.
/// 2. `cpu_adamw_step` is applied to those copies.
/// 3. The device path (`run_rocm_hip_adamw_step_inner`) is run on the caller's
///    buffers.
///
/// The returned report is the device report with the oracle vectors and the
/// comparison result filled in.
///
/// `max_abs_error_theta` is the largest absolute difference between the device
/// and oracle `theta`, with both decoded from fp16 bit patterns to f32. These
/// positions count as zero error:
/// - the two values compare equal (including same-signed infinities);
/// - both values are NaN.
///
/// These cases count as infinite error, so they can never pass:
/// - exactly one side is NaN;
/// - the two theta vectors differ in length.
///
/// `within_tolerance` is `max_abs_error_theta < tolerance`. Only `theta` is
/// compared; the `m`/`v` oracle values are reported but not checked.
///
/// # Errors
/// Propagates any error returned by the device path.
pub fn run_rocm_hip_adamw_step_oracle(
    theta: &mut [u16],
    m: &mut [f32],
    v: &mut [f32],
    grad: &[u16],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
    tolerance: f32,
) -> Result<RocmHipAdamwStepReport> {
    // Snapshot the inputs before the device path, which receives theta/m/v mutably.
    let mut theta_cpu = theta.to_vec();
    let mut m_cpu = m.to_vec();
    let mut v_cpu = v.to_vec();
    cpu_adamw_step(
        &mut theta_cpu,
        &mut m_cpu,
        &mut v_cpu,
        grad,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        t,
    );
    let report = run_rocm_hip_adamw_step_inner(
        theta, m, v, grad, lr, beta1, beta2, eps, weight_decay, t,
    )?;
    let max_abs_error_theta = if report.theta.len() != theta_cpu.len() {
        // A truncated or oversized device result must not be silently zipped away.
        f32::INFINITY
    } else {
        report
            .theta
            .iter()
            .zip(&theta_cpu)
            .map(|(&device_bits, &oracle_bits)| {
                let device = f16_to_f32(device_bits);
                let oracle = f16_to_f32(oracle_bits);
                if device == oracle || (device.is_nan() && oracle.is_nan()) {
                    0.0
                } else {
                    let err = (device - oracle).abs();
                    // Reaching NaN here means exactly one side is NaN. Map it to
                    // infinity, because NaN would be skipped by any `>`/max
                    // comparison and let a broken kernel pass.
                    if err.is_nan() { f32::INFINITY } else { err }
                }
            })
            .fold(0.0f32, f32::max)
    };
    let within_tolerance = max_abs_error_theta < tolerance;
    Ok(RocmHipAdamwStepReport {
        cpu_oracle_theta: theta_cpu,
        cpu_oracle_m: m_cpu,
        cpu_oracle_v: v_cpu,
        max_abs_error_theta,
        within_tolerance,
        ..report
    })
}
/// Renders an `f32` as a payload token.
///
/// Non-finite values become `nan` or `inf`, prefixed with `-` when the sign bit
/// is set. Finite values use scientific notation with nine fractional digits
/// (`{:.9e}`).
fn f32_to_payload_string(value: f32) -> String {
    let special = if value.is_nan() {
        "nan"
    } else if value.is_infinite() {
        "inf"
    } else {
        return format!("{value:.9e}");
    };
    if value.is_sign_negative() {
        format!("-{special}")
    } else {
        special.to_owned()
    }
}
/// Parses the comma-separated `u16` values that follow `prefix` on the first
/// line of `stdout` starting with `prefix`.
///
/// A line carrying only whitespace after the prefix yields an empty vector.
/// Each field is trimmed before parsing.
///
/// # Errors
/// Returns a backend error in two cases:
/// - no line starts with `prefix`;
/// - any field is not a valid `u16`. This includes values above `u16::MAX`,
///   which are rejected rather than truncated.
fn parse_u16_csv(stdout: &str, prefix: &str) -> Result<Vec<u16>> {
    let line = stdout
        .lines()
        .find_map(|line| line.strip_prefix(prefix))
        .ok_or_else(|| Error::backend(format!("HIP AdamW step did not print {prefix}")))?;
    if line.trim().is_empty() {
        return Ok(Vec::new());
    }
    line.split(',')
        .map(|value| {
            // Parse straight into u16: an `as u16` cast from a wider type would
            // silently wrap out-of-range values such as 65536 to 0.
            value
                .trim()
                .parse::<u16>()
                .map_err(|err| Error::backend(format!("invalid HIP AdamW step u16 {value:?}: {err}")))
        })
        .collect()
}
fn parse_f32_csv(stdout: &str, prefix: &str) -> Result<Vec<f32>> {
let line = stdout
.lines()
.find_map(|line| line.strip_prefix(prefix))
.ok_or_else(|| Error::backend(format!("HIP AdamW step did not print {prefix}")))?;
if line.trim().is_empty() {
return Ok(Vec::new());
}
line.split(',')
.map(|value| {
value
.trim()
.parse::<f32>()
.map_err(|err| Error::backend(format!("invalid HIP AdamW step f32 {value:?}: {err}")))
})
.collect()
}
/// Returns the `f32` that follows `prefix` on the first `stdout` line starting
/// with `prefix`.
///
/// Returns `None` if no such line exists or if its (trimmed) remainder does not
/// parse. Later matching lines are not consulted.
fn parse_f32_line(stdout: &str, prefix: &str) -> Option<f32> {
    let raw = stdout.lines().find_map(|line| line.strip_prefix(prefix))?;
    raw.trim().parse::<f32>().ok()
}
/// Builds a `"{label}-{hash:016x}"` fingerprint from `label` and `value`.
///
/// Both strings are fed, label first, into a fresh `DefaultHasher`. The std
/// docs do not guarantee that hasher's algorithm across Rust releases, so these
/// fingerprints are only comparable between binaries built with the same
/// toolchain.
fn fingerprint(label: &str, value: &str) -> String {
    let digest = {
        let mut hasher = DefaultHasher::new();
        for part in [label, value] {
            part.hash(&mut hasher);
        }
        hasher.finish()
    };
    format!("{label}-{digest:016x}")
}