use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::process::Command;
use std::time::Instant;
use crate::backend::hip_dense::{
hipcc_compile_executable, hipcc_compiler_fingerprint, hipcc_recheck_artifact,
};
use crate::backend::rocm::{RocmHipCapabilityReport, detect_local_rocm_hip};
use crate::object::sheaf::PrecisionClass;
use crate::{Error, Result};
/// Backend identifier recorded in the `backend` field of every
/// `RocmHipSheafOverlapCheckReport` produced by this module.
pub const ROCM_HIP_SHEAF_OVERLAP_CHECK_BACKEND: &str = "rocm_hip_sheaf_overlap_check_pilot";
/// HIP C++ source of the standalone sheaf-overlap residual program.
///
/// The string is written into the cache directory, compiled with hipcc and
/// run as a child process by `run_rocm_hip_sheaf_overlap_check_with_options`.
/// Its exact bytes also feed `hip_sheaf_overlap_check_kernel_source_fingerprint`,
/// so any edit (including to the embedded comments) changes the fingerprint
/// and therefore the cached file names.
///
/// Command line: `N_SECTIONS N_OVERLAPS SECTION_DIM PRECISION_CLASS ITERATIONS`,
/// followed by `N_SECTIONS * SECTION_DIM` fp16 bit patterns written in decimal,
/// followed by `2 * N_OVERLAPS` section indices. Results are printed on stdout
/// as `KEY=value` lines, the last one being a comma-separated `RESIDUALS=` line.
///
/// Kernel: one block per overlap pair, launched by `main` with 64 threads.
/// Each thread strides over `section_dim` accumulating a per-position
/// difference, and a shared-memory tree reduction produces one float per
/// overlap. Precision class id 0 uses `|a - b|` on the fp16 values; every
/// other id goes through the base-257 digit path.
///
/// NOTE(review): `shared_sum` has 64 entries and the reduction starts at
/// stride 32, so the kernel is only correct for `blockDim.x == 64`, which is
/// what `main` launches.
/// NOTE(review): in the digit path the magnitudes are at most `0x7FFF`, so
/// after 8 or 16 divisions by 257 both `abs_a` and `abs_b` are zero; the
/// `abs_a == 0 && abs_b == 0` branch therefore appears to be the only one
/// taken for class ids 1 and 2, and operands with equal sign bits yield a
/// zero difference whatever their magnitudes. Confirm this is intended.
/// NOTE(review): digits are stored in `unsigned char`, but `% 257u` can
/// produce 256, which presumably does not fit; verify.
/// NOTE(review): when a pair index is out of range the kernel returns without
/// writing `residuals[overlap_idx]`, and no initialisation of the device
/// buffer is visible, so the host would read an unspecified value there.
/// NOTE(review): the section reference in the embedded header comment looks
/// mis-encoded (likely `§5`); correcting it would change the fingerprint.
pub const HIP_SHEAF_OVERLAP_CHECK_KERNEL: &str = r#"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdlib>
#include <cmath>
#include <iostream>
#include <string>
#include <vector>
// Path B / MOE_DESIGN.md ยง5: sheaf compatibility check on GPU.
// One block per overlap, blockDim.x threads cooperate over the section_dim axis.
// reduction sums the per-position residuals to a single scalar per overlap.
__global__ void sheaf_overlap_residual_check(
const __half* section_values,
const int* overlap_pairs,
const int* overlap_indices,
float* residuals,
int n_sections,
int n_overlaps,
int section_dim,
int precision_class_id
) {
int overlap_idx = blockIdx.x;
if (overlap_idx >= n_overlaps) return;
int i = overlap_pairs[2 * overlap_idx];
int j = overlap_pairs[2 * overlap_idx + 1];
if (i < 0 || j < 0 || i >= n_sections || j >= n_sections) return;
int precision_digits = 0;
if (precision_class_id == 1) precision_digits = 8;
if (precision_class_id == 2) precision_digits = 16;
__shared__ float shared_sum[64];
float local_sum = 0.0f;
for (int k = threadIdx.x; k < section_dim; k += blockDim.x) {
const __half* a_ptr = section_values + (std::size_t)i * section_dim + k;
const __half* b_ptr = section_values + (std::size_t)j * section_dim + k;
float diff;
if (precision_class_id == 0) {
float fa = __half2float(*a_ptr);
float fb = __half2float(*b_ptr);
diff = fabsf(fa - fb);
} else {
// p-adic 257 subtraction: encode fp16 bits to 257-adic, subtract
// digit-by-digit (mod 257), truncate to precision_digits, decode
// back to fp16, then take the absolute difference of real values.
unsigned short bits_a = __half_as_ushort(*a_ptr);
unsigned short bits_b = __half_as_ushort(*b_ptr);
int sign_a = (bits_a & 0x8000) ? -1 : 1;
int sign_b = (bits_b & 0x8000) ? -1 : 1;
unsigned int abs_a = bits_a & 0x7FFF;
unsigned int abs_b = bits_b & 0x7FFF;
unsigned char digits_a[16];
unsigned char digits_b[16];
for (int d = 0; d < precision_digits; d++) {
digits_a[d] = (unsigned char)(abs_a % 257u);
digits_b[d] = (unsigned char)(abs_b % 257u);
abs_a /= 257u;
abs_b /= 257u;
}
// Subtract with borrow in 257-adic; treat signed by sign-of-difference
int diff_digits[16];
int borrow = 0;
int net_sign = sign_a;
if (abs_a == 0 && abs_b == 0) {
// magnitudes equal: handle via sign
if (sign_a != sign_b) {
int pos_larger = -1;
for (int d = precision_digits - 1; d >= 0; d--) {
if (digits_a[d] != digits_b[d]) {
pos_larger = (digits_a[d] > digits_b[d]) ? 1 : 0;
break;
}
}
if (pos_larger == 0) {
// a < b in magnitude
for (int d = 0; d < precision_digits; d++) {
int dd = (int)digits_b[d] - (int)digits_a[d] - borrow;
if (dd < 0) { dd += 257; borrow = 1; } else { borrow = 0; }
diff_digits[d] = dd;
}
net_sign = sign_b;
} else {
// a > b in magnitude
for (int d = 0; d < precision_digits; d++) {
int dd = (int)digits_a[d] - (int)digits_b[d] - borrow;
if (dd < 0) { dd += 257; borrow = 1; } else { borrow = 0; }
diff_digits[d] = dd;
}
net_sign = sign_a;
}
} else {
for (int d = 0; d < precision_digits; d++) diff_digits[d] = 0;
net_sign = 1;
}
} else if (abs_a > abs_b) {
for (int d = 0; d < precision_digits; d++) {
int dd = (int)digits_a[d] - (int)digits_b[d] - borrow;
if (dd < 0) { dd += 257; borrow = 1; } else { borrow = 0; }
diff_digits[d] = dd;
}
net_sign = sign_a;
} else {
for (int d = 0; d < precision_digits; d++) {
int dd = (int)digits_b[d] - (int)digits_a[d] - borrow;
if (dd < 0) { dd += 257; borrow = 1; } else { borrow = 0; }
diff_digits[d] = dd;
}
net_sign = sign_b;
}
// Convert p-adic digits back to integer
unsigned int result_int = 0;
unsigned int place = 1;
for (int d = 0; d < precision_digits; d++) {
result_int += (unsigned int)diff_digits[d] * place;
place *= 257u;
}
// Take lower 16 bits and reapply sign
unsigned short result_bits = (unsigned short)(result_int & 0xFFFFu);
if (net_sign < 0) result_bits |= 0x8000;
__half result_h = __ushort_as_half(result_bits);
float result_f = __half2float(result_h);
diff = fabsf(result_f);
}
local_sum += diff;
(void)overlap_indices; // currently unused (full alignment)
}
shared_sum[threadIdx.x] = local_sum;
__syncthreads();
// Tree reduction
for (int stride = 32; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
shared_sum[threadIdx.x] += shared_sum[threadIdx.x + stride];
}
__syncthreads();
}
if (threadIdx.x == 0) {
residuals[overlap_idx] = shared_sum[0];
}
}
static void check(hipError_t status, const char* label) {
if (status != hipSuccess) {
std::cerr << "HIP_ERROR " << label << "=" << hipGetErrorString(status) << "\n";
std::exit(10);
}
}
int main(int argc, char** argv) {
if (argc < 7) {
std::cerr << "usage: rocm_sheaf_overlap_check N_SECTIONS N_OVERLAPS SECTION_DIM PRECISION_CLASS ITERATIONS SECTIONS... PAIRS...\n";
return 2;
}
int n_sections = (int)std::strtol(argv[1], nullptr, 10);
int n_overlaps = (int)std::strtol(argv[2], nullptr, 10);
int section_dim = (int)std::strtol(argv[3], nullptr, 10);
int precision_class_id = (int)std::strtol(argv[4], nullptr, 10);
int iterations = (int)std::strtol(argv[5], nullptr, 10);
if (iterations < 1) iterations = 1;
int expected_section_args = 6 + n_sections * section_dim;
int expected_pair_args = expected_section_args + n_overlaps * 2;
if (argc < expected_pair_args) {
std::cerr << "argument count does not match: expected " << expected_pair_args << ", got " << argc << "\n";
return 3;
}
check(hipSetDevice(0), "hipSetDevice");
// Parse sections: each fp16 value is given as its 16-bit bits in decimal.
std::vector<unsigned short> sections_host((std::size_t)n_sections * section_dim);
for (int idx = 0; idx < n_sections * section_dim; idx++) {
int v = (int)std::strtol(argv[6 + idx], nullptr, 10);
sections_host[idx] = (unsigned short)(v & 0xFFFF);
}
// Parse overlap pairs: each as (i, j)
std::vector<int> pairs_host((std::size_t)n_overlaps * 2);
for (int idx = 0; idx < n_overlaps * 2; idx++) {
pairs_host[idx] = (int)std::strtol(argv[expected_section_args + idx], nullptr, 10);
}
// Pack sections as half (fp16) values via bit reinterpretation.
std::vector<unsigned short> section_bits(sections_host.size());
for (std::size_t idx = 0; idx < sections_host.size(); idx++) {
section_bits[idx] = sections_host[idx];
}
__half* d_sections = nullptr;
int* d_pairs = nullptr;
int* d_indices = nullptr;
float* d_residuals = nullptr;
std::size_t section_bytes = section_bits.size() * sizeof(unsigned short);
std::size_t pair_bytes = (std::size_t)n_overlaps * 2 * sizeof(int);
std::size_t residual_bytes = (std::size_t)n_overlaps * sizeof(float);
check(hipMalloc(&d_sections, section_bytes), "hipMalloc(sections)");
check(hipMalloc(&d_pairs, pair_bytes), "hipMalloc(pairs)");
check(hipMalloc(&d_residuals, residual_bytes), "hipMalloc(residuals)");
// overlap_indices is unused in the kernel; pass nullptr.
d_indices = nullptr;
// Copy sections by reinterpreting the unsigned-short buffer as half.
check(hipMemcpy(d_sections, section_bits.data(), section_bytes, hipMemcpyHostToDevice), "hipMemcpy(sections)");
check(hipMemcpy(d_pairs, pairs_host.data(), pair_bytes, hipMemcpyHostToDevice), "hipMemcpy(pairs)");
int block = 64;
int grid = n_overlaps;
hipEvent_t start_event;
hipEvent_t stop_event;
check(hipEventCreate(&start_event), "hipEventCreate(start)");
check(hipEventCreate(&stop_event), "hipEventCreate(stop)");
check(hipEventRecord(start_event, 0), "hipEventRecord(start)");
for (int it = 0; it < iterations; it++) {
hipLaunchKernelGGL(
sheaf_overlap_residual_check,
dim3(grid),
dim3(block),
0,
0,
d_sections,
d_pairs,
d_indices,
d_residuals,
n_sections,
n_overlaps,
section_dim,
precision_class_id
);
check(hipGetLastError(), "hipLaunchKernelGGL");
}
check(hipEventRecord(stop_event, 0), "hipEventRecord(stop)");
check(hipEventSynchronize(stop_event), "hipEventSynchronize(stop)");
float kernel_time_ms = 0.0f;
check(hipEventElapsedTime(&kernel_time_ms, start_event, stop_event), "hipEventElapsedTime");
check(hipEventDestroy(start_event), "hipEventDestroy(start)");
check(hipEventDestroy(stop_event), "hipEventDestroy(stop)");
std::vector<float> residuals_host((std::size_t)n_overlaps);
check(hipMemcpy(residuals_host.data(), d_residuals, residual_bytes, hipMemcpyDeviceToHost), "hipMemcpy(residuals)");
check(hipFree(d_sections), "hipFree(sections)");
check(hipFree(d_pairs), "hipFree(pairs)");
check(hipFree(d_residuals), "hipFree(residuals)");
std::cout << "N_SECTIONS=" << n_sections << "\n";
std::cout << "N_OVERLAPS=" << n_overlaps << "\n";
std::cout << "SECTION_DIM=" << section_dim << "\n";
std::cout << "PRECISION_CLASS_ID=" << precision_class_id << "\n";
std::cout << "ITERATIONS=" << iterations << "\n";
std::cout << "KERNEL_TIME_MS=" << kernel_time_ms << "\n";
std::cout << "GRID=" << grid << "\n";
std::cout << "BLOCK=" << block << "\n";
std::cout << "RESIDUALS=";
for (std::size_t idx = 0; idx < residuals_host.size(); idx++) {
if (idx != 0) std::cout << ",";
std::cout << residuals_host[idx];
}
std::cout << "\n";
return 0;
}
"#;
/// Outcome of one run of the compiled HIP sheaf-overlap program, together
/// with the provenance recorded alongside it.
#[derive(Debug, Clone, PartialEq)]
pub struct RocmHipSheafOverlapCheckReport {
/// Backend identifier; the runner stores `ROCM_HIP_SHEAF_OVERLAP_CHECK_BACKEND` here.
pub backend: String,
/// Precision class the caller requested; its `id()` is passed to the program.
pub precision_class: PrecisionClass,
/// Number of sections passed to the program.
pub n_sections: usize,
/// Number of overlap pairs passed to the program.
pub n_overlaps: usize,
/// Number of values taken from each section.
pub section_dim: usize,
/// Number of kernel launches requested (at least 1).
pub iterations: usize,
/// Values parsed from the program's `RESIDUALS=` output line.
pub residuals: Vec<f32>,
/// Value of the `GRID=` output line; falls back to `n_overlaps` when it is missing or unparsable.
pub launch_grid: u32,
/// Value of the `BLOCK=` output line; falls back to 64 when it is missing or unparsable.
pub launch_block: u32,
/// Value of the `KERNEL_TIME_MS=` output line; falls back to 0.0 when it is missing or unparsable.
pub kernel_time_ms: f32,
/// Fingerprint of the embedded kernel source string.
pub kernel_source_fingerprint: String,
/// Fingerprint returned by `hipcc_compiler_fingerprint` for the hipcc binary.
pub compiler_fingerprint: String,
/// Capability report returned by `detect_local_rocm_hip`.
pub device_evidence: RocmHipCapabilityReport,
/// Free-text statements describing what the run did.
pub evidence: Vec<String>,
/// Free-text statements listing what the run does not demonstrate.
pub non_claims: Vec<String>,
}
/// A GPU report paired with wall-clock measurements of the GPU path and of
/// the host baseline.
#[derive(Debug, Clone)]
pub struct RocmHipSheafOverlapCheckTimings {
/// Report returned by the GPU run.
pub report: RocmHipSheafOverlapCheckReport,
/// Wall-clock milliseconds spent in the whole GPU runner call, not just the kernel.
pub gpu_wall_time_ms: f64,
/// Wall-clock milliseconds spent running the host residual computation `iterations` times.
pub host_wall_time_ms: f64,
/// Copy of `report.kernel_time_ms`.
pub gpu_kernel_time_ms: f32,
}
/// Decodes an IEEE 754 binary16 bit pattern into an `f32`.
///
/// Signed zeros, subnormals and infinities are preserved; every NaN payload
/// maps to `f32::NAN`.
fn half_to_f32(bits: u16) -> f32 {
    let negative = bits & 0x8000 != 0;
    let exponent_field = u32::from((bits >> 10) & 0x1F);
    let mantissa_field = bits & 0x3FF;
    // Fraction in [0, 1); exact because the mantissa has only 10 bits.
    let fraction = f32::from(mantissa_field) / 1024.0;

    let magnitude = match exponent_field {
        // Zero and subnormals share the fixed exponent 2^-14 with no implicit one.
        0 => fraction * (1.0_f32 / 16384.0),
        0x1F if mantissa_field == 0 => f32::INFINITY,
        0x1F => return f32::NAN,
        biased => (1.0 + fraction) * 2.0_f32.powi(biased as i32 - 15),
    };
    if negative {
        -magnitude
    } else {
        magnitude
    }
}

/// Host reference for the fp16 residual: for each overlap `(i, j)` returns the
/// sum over the first `section_dim` positions of `|sections[i][k] - sections[j][k]|`,
/// with section values given as binary16 bit patterns.
///
/// # Panics
///
/// Panics if an overlap index is out of range for `sections`, or if a
/// referenced section holds fewer than `section_dim` values. Callers are
/// expected to validate their inputs first.
fn compute_host_residuals(
    sections: &[Vec<u16>],
    overlaps: &[(usize, usize)],
    section_dim: usize,
) -> Vec<f32> {
    overlaps
        .iter()
        .map(|&(i, j)| {
            let lhs = &sections[i][..section_dim];
            let rhs = &sections[j][..section_dim];
            // Explicit fold from +0.0 keeps the left-to-right accumulation order.
            lhs.iter()
                .zip(rhs)
                .map(|(&a, &b)| (half_to_f32(a) - half_to_f32(b)).abs())
                .fold(0.0_f32, |sum, diff| sum + diff)
        })
        .collect()
}
/// Runs the HIP sheaf overlap check with default launch options and the
/// section dimension taken from the first section.
///
/// # Errors
///
/// Forwards every error of `run_rocm_hip_sheaf_overlap_check_with_options`.
pub fn run_rocm_hip_sheaf_overlap_check(
    sections: &[Vec<u16>],
    overlaps: &[(usize, usize)],
    precision_class: PrecisionClass,
) -> Result<RocmHipSheafOverlapCheckReport> {
    let launch_options = RocmHipSheafOverlapCheckLaunchOptions::default();
    let section_dim_override = None;
    run_rocm_hip_sheaf_overlap_check_with_options(
        sections,
        overlaps,
        precision_class,
        section_dim_override,
        launch_options,
    )
}
/// Launch tuning for the HIP sheaf overlap check.
#[derive(Debug, Clone, Copy)]
pub struct RocmHipSheafOverlapCheckLaunchOptions {
    /// How many times the kernel is launched inside the timed region; the
    /// runner raises values below 1 to 1.
    pub iterations: usize,
}

impl Default for RocmHipSheafOverlapCheckLaunchOptions {
    /// A single kernel launch.
    fn default() -> Self {
        let iterations = 1;
        RocmHipSheafOverlapCheckLaunchOptions { iterations }
    }
}
/// Validates the inputs, compiles the embedded HIP program into the local
/// cache directory, runs it as a child process and parses its stdout into a
/// report.
///
/// `section_dim_override` selects how many leading values of each section are
/// used; when `None`, the length of the first section is used.
///
/// # Errors
///
/// Returns a backend error when `sections` or `overlaps` is empty, a section
/// is shorter than the section dimension, an overlap index is out of range,
/// ROCm/HIP is not detected, the cache directory or source file cannot be
/// written, the child process cannot be started or exits unsuccessfully, or
/// the `RESIDUALS=` line is missing or malformed. Errors from the hipcc
/// helper functions are propagated unchanged.
pub fn run_rocm_hip_sheaf_overlap_check_with_options(
    sections: &[Vec<u16>],
    overlaps: &[(usize, usize)],
    precision_class: PrecisionClass,
    section_dim_override: Option<usize>,
    options: RocmHipSheafOverlapCheckLaunchOptions,
) -> Result<RocmHipSheafOverlapCheckReport> {
    const HIPCC: &str = "/opt/rocm/bin/hipcc";

    // --- Input validation -------------------------------------------------
    if sections.is_empty() {
        return Err(Error::backend(
            "HIP sheaf overlap check requires at least one section",
        ));
    }
    if overlaps.is_empty() {
        return Err(Error::backend(
            "HIP sheaf overlap check requires at least one overlap",
        ));
    }
    let n_sections = sections.len();
    let n_overlaps = overlaps.len();
    let section_dim = match section_dim_override {
        Some(dim) => dim,
        None => sections[0].len(),
    };
    if let Some(short) = sections
        .iter()
        .find(|section| section.len() < section_dim)
    {
        return Err(Error::backend(format!(
            "HIP sheaf overlap check section length {} is less than required section_dim {}",
            short.len(),
            section_dim
        )));
    }
    if let Some(&(i, j)) = overlaps
        .iter()
        .find(|&&(i, j)| i >= n_sections || j >= n_sections)
    {
        return Err(Error::backend(format!(
            "HIP sheaf overlap pair ({i}, {j}) is out of range for {} sections",
            n_sections
        )));
    }

    // --- Toolchain and device discovery -----------------------------------
    let device_evidence = detect_local_rocm_hip();
    if !device_evidence.available {
        return Err(Error::backend(
            "ROCm/HIP is unavailable; sheaf overlap check remains inadmissible",
        ));
    }
    let compiler_fingerprint = hipcc_compiler_fingerprint(HIPCC)?;
    let source_fingerprint = hip_sheaf_overlap_check_kernel_source_fingerprint();

    // --- Materialise and compile the program ------------------------------
    let cache_dir = PathBuf::from("target/rocm-hip-cache");
    if let Err(err) = fs::create_dir_all(&cache_dir) {
        return Err(Error::backend(format!(
            "failed to create HIP cache directory: {err}"
        )));
    }
    let source_path = cache_dir.join(format!("{source_fingerprint}.cpp"));
    let executable_path = cache_dir.join(format!("{source_fingerprint}-sheaf-overlap-check"));
    if let Err(err) = fs::write(&source_path, HIP_SHEAF_OVERLAP_CHECK_KERNEL) {
        return Err(Error::backend(format!(
            "failed to write HIP sheaf overlap source: {err}"
        )));
    }
    hipcc_compile_executable(HIPCC, &source_path, &executable_path, None)?;

    // --- Command line: header, section bit patterns, then index pairs -----
    let iterations = options.iterations.max(1);
    let mut args = vec![
        n_sections.to_string(),
        n_overlaps.to_string(),
        section_dim.to_string(),
        precision_class.id().to_string(),
        iterations.to_string(),
    ];
    args.extend(
        sections
            .iter()
            .flat_map(|section| section[..section_dim].iter())
            .map(|bits| bits.to_string()),
    );
    args.extend(
        overlaps
            .iter()
            .flat_map(|&(i, j)| [i.to_string(), j.to_string()]),
    );

    // --- Execute -----------------------------------------------------------
    hipcc_recheck_artifact(HIPCC, &source_path, &executable_path, None)?;
    let run = match Command::new(&executable_path).args(args).output() {
        Ok(output) => output,
        Err(err) => {
            return Err(Error::backend(format!(
                "failed to run HIP sheaf overlap check: {err}"
            )));
        }
    };
    if !run.status.success() {
        return Err(Error::backend(format!(
            "HIP sheaf overlap check failed: {}{}",
            String::from_utf8_lossy(&run.stderr),
            String::from_utf8_lossy(&run.stdout)
        )));
    }

    // --- Parse the KEY=value output ----------------------------------------
    let stdout = String::from_utf8_lossy(&run.stdout);
    let residuals = parse_residuals(&stdout)?;
    // The launch geometry and timing lines are best-effort: fall back to the
    // values the program is known to use when a line is absent or unparsable.
    let launch_grid = parse_u32_line(&stdout, "GRID=").unwrap_or(n_overlaps as u32);
    let launch_block = parse_u32_line(&stdout, "BLOCK=").unwrap_or(64);
    let kernel_time_ms = parse_f32_line(&stdout, "KERNEL_TIME_MS=").unwrap_or(0.0);

    let evidence = vec![
        "compiled HIP sheaf overlap residual check with /opt/rocm/bin/hipcc -O2 (hipcc auto-adds --offload-arch=gfx1101 for the local 7800 XT)"
            .to_string(),
        "launched one block per overlap with blockDim.x=64 and reduced per-position residual to a scalar"
            .to_string(),
        "SectionTable precision class routes the per-position subtraction (fp16 / 8-digit p-adic / 16-digit p-adic)"
            .to_string(),
    ];
    let non_claims = vec![
        "not broad GPU execution".to_string(),
        "not production performance evidence".to_string(),
        "not machine-code verification".to_string(),
    ];

    Ok(RocmHipSheafOverlapCheckReport {
        backend: ROCM_HIP_SHEAF_OVERLAP_CHECK_BACKEND.to_string(),
        precision_class,
        n_sections,
        n_overlaps,
        section_dim,
        iterations,
        residuals,
        launch_grid,
        launch_block,
        kernel_time_ms,
        kernel_source_fingerprint: source_fingerprint,
        compiler_fingerprint,
        device_evidence,
        evidence,
        non_claims,
    })
}
/// Runs the timed host/GPU comparison with the default of 32 iterations.
///
/// # Errors
///
/// Forwards every error of
/// `run_rocm_hip_sheaf_overlap_check_with_timing_and_iterations`.
pub fn run_rocm_hip_sheaf_overlap_check_with_timing(
    sections: &[Vec<u16>],
    overlaps: &[(usize, usize)],
    precision_class: PrecisionClass,
) -> Result<RocmHipSheafOverlapCheckTimings> {
    const DEFAULT_TIMING_ITERATIONS: usize = 32;
    run_rocm_hip_sheaf_overlap_check_with_timing_and_iterations(
        sections,
        overlaps,
        precision_class,
        DEFAULT_TIMING_ITERATIONS,
    )
}
/// Runs the GPU overlap check and a host baseline, each for `iterations`
/// repetitions (raised to at least 1), and reports wall-clock times for both.
///
/// `gpu_wall_time_ms` covers the whole GPU runner call (validation, compile,
/// process launch and output parsing), not only the kernel; the kernel-only
/// figure is `gpu_kernel_time_ms`. The host baseline always computes the fp16
/// absolute-difference residual, whatever `precision_class` is.
///
/// # Errors
///
/// Forwards every error of `run_rocm_hip_sheaf_overlap_check_with_options`,
/// including its input-validation errors.
pub fn run_rocm_hip_sheaf_overlap_check_with_timing_and_iterations(
    sections: &[Vec<u16>],
    overlaps: &[(usize, usize)],
    precision_class: PrecisionClass,
    iterations: usize,
) -> Result<RocmHipSheafOverlapCheckTimings> {
    let iterations = iterations.max(1);

    // Run the GPU path first: it validates `sections` and `overlaps`, so bad
    // input comes back as an `Err` instead of panicking inside the unchecked
    // host loop below (empty sections, out-of-range pairs, ragged sections).
    let gpu_start = Instant::now();
    let report = run_rocm_hip_sheaf_overlap_check_with_options(
        sections,
        overlaps,
        precision_class,
        None,
        RocmHipSheafOverlapCheckLaunchOptions { iterations },
    )?;
    let gpu_wall_time_ms = gpu_start.elapsed().as_secs_f64() * 1000.0;

    // With no override, the report's section_dim is the first section's length.
    let section_dim = report.section_dim;
    let host_start = Instant::now();
    for _ in 0..iterations {
        // black_box keeps the otherwise-unused result from being optimised
        // away, which would make the host timing meaningless.
        std::hint::black_box(compute_host_residuals(
            std::hint::black_box(sections),
            overlaps,
            section_dim,
        ));
    }
    let host_wall_time_ms = host_start.elapsed().as_secs_f64() * 1000.0;

    let gpu_kernel_time_ms = report.kernel_time_ms;
    Ok(RocmHipSheafOverlapCheckTimings {
        report,
        gpu_wall_time_ms,
        host_wall_time_ms,
        gpu_kernel_time_ms,
    })
}
/// Returns the labelled fingerprint of the embedded HIP kernel source; the
/// runner uses it to name the cached source file and executable.
pub fn hip_sheaf_overlap_check_kernel_source_fingerprint() -> String {
    let label = "hip-sheaf-overlap-check-source";
    fingerprint(label, HIP_SHEAF_OVERLAP_CHECK_KERNEL)
}
fn parse_residuals(stdout: &str) -> Result<Vec<f32>> {
let line = stdout
.lines()
.find_map(|line| line.strip_prefix("RESIDUALS="))
.ok_or_else(|| Error::backend("HIP sheaf overlap check did not print RESIDUALS"))?;
if line.trim().is_empty() {
return Ok(Vec::new());
}
line.split(',')
.map(|value| {
value
.trim()
.parse::<f32>()
.map_err(|err| Error::backend(format!("invalid HIP residual value {value}: {err}")))
})
.collect()
}
/// Parses the remainder of the first line that starts with `prefix` as a
/// `u32`.
///
/// Returns `None` when no line matches or when that first match does not
/// parse; later matching lines are not consulted.
fn parse_u32_line(stdout: &str, prefix: &str) -> Option<u32> {
    for line in stdout.lines() {
        if let Some(rest) = line.strip_prefix(prefix) {
            return rest.trim().parse().ok();
        }
    }
    None
}
/// Parses the remainder of the first line that starts with `prefix` as an
/// `f32`.
///
/// Returns `None` when no line matches or when that first match does not
/// parse; later matching lines are not consulted.
fn parse_f32_line(stdout: &str, prefix: &str) -> Option<f32> {
    for line in stdout.lines() {
        if let Some(rest) = line.strip_prefix(prefix) {
            return rest.trim().parse().ok();
        }
    }
    None
}
/// Builds `"{label}-{hash:016x}"`, where the 64-bit hash covers `label`
/// followed by `value`.
///
/// NOTE(review): the hash comes from `DefaultHasher`, whose algorithm the
/// standard library does not promise to keep stable across Rust releases, so
/// fingerprints are presumably only comparable within one toolchain.
fn fingerprint(label: &str, value: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        Hash::hash(label, &mut state);
        Hash::hash(value, &mut state);
        state.finish()
    };
    format!("{}-{:016x}", label, digest)
}