#[cfg(target_os = "linux")]
use std::sync::Arc;
#[cfg(target_os = "linux")]
use std::sync::OnceLock;
#[cfg(target_os = "linux")]
use cudarc::driver::{CudaModule, CudaStream, LaunchConfig, PushKernelArg};
#[cfg(target_os = "linux")]
use super::error::GpuError;
#[cfg(target_os = "linux")]
use crate::gpu::error::GpuResultExt;
/// Largest per-row Hessian dimension `r` that the validators in this file accept.
/// Aliases `bms_flex_row::MAX_R` from the parent module.
// NOTE(review): the matvec kernel in `ROW_KERNEL_SOURCE` stages the direction in a
// fixed `__shared__ double v_shared[32]`, so this value presumably must stay <= 32 —
// confirm against `bms_flex_row::MAX_R`.
#[cfg(target_os = "linux")]
pub(crate) const MAX_R: usize = super::bms_flex_row::MAX_R;
/// Threads per block (`block_dim.x`) used when launching both row kernels.
#[cfg(target_os = "linux")]
const ROW_HV_THREADS: u32 = 32;
/// Borrowed inputs for the batched per-row Hessian matvec
/// `y_i[u] = sum_v H_i[u, v] * v_i[v]`.
pub(crate) struct RowHessianMatvecInputs<'a> {
/// Number of independent rows; each row owns one `r x r` Hessian block.
pub n_rows: usize,
/// Dimension of every per-row Hessian block and direction vector.
pub r: usize,
/// Flattened Hessian blocks; entry `[row, u, v]` is read at `row * r * r + u * r + v`.
pub h_rows: &'a [f64],
/// Flattened direction vectors; entry `[row, v]` is read at `row * r + v`.
pub v_rows: &'a [f64],
}
/// Host-side result of the GPU per-row matvec launch.
#[cfg(target_os = "linux")]
#[derive(Debug)]
pub(crate) struct RowHessianMatvecOutputs {
/// Flattened products downloaded from the device; the kernel writes entry
/// `[row, u]` at `row * r + u`.
pub y_rows: Vec<f64>,
}
/// Borrowed inputs for the batched per-row Hessian diagonal extraction
/// `d_i[u] = H_i[u, u]`.
pub(crate) struct RowHessianDiagInputs<'a> {
/// Number of independent rows; each row owns one `r x r` Hessian block.
pub n_rows: usize,
/// Dimension of every per-row Hessian block.
pub r: usize,
/// Flattened Hessian blocks; entry `[row, u, v]` is read at `row * r * r + u * r + v`.
pub h_rows: &'a [f64],
}
/// Host-side result of the GPU per-row diagonal launch.
#[cfg(target_os = "linux")]
#[derive(Debug)]
pub(crate) struct RowHessianDiagOutputs {
/// Flattened diagonals downloaded from the device; the kernel writes entry
/// `[row, u]` at `row * r + u`.
pub d_rows: Vec<f64>,
}
#[cfg(target_os = "linux")]
impl RowHessianMatvecInputs<'_> {
    /// Checks that the declared shape is usable and agrees with the slice lengths.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `r` is zero or larger than `MAX_R`;
    /// - `n_rows * r * r` does not fit in `usize`;
    /// - `h_rows.len() != n_rows * r * r`;
    /// - `v_rows.len() != n_rows * r`.
    pub(crate) fn validate(&self) -> Result<(), GpuError> {
        if self.r == 0 {
            return Err(GpuError::DriverCallFailed {
                reason: "row_hessian_matvec inputs: r must be > 0".to_string(),
            });
        }
        if self.r > MAX_R {
            crate::gpu_bail!(
                "row_hessian_matvec inputs: r={} exceeds MAX_R={MAX_R}",
                self.r
            );
        }
        // Checked arithmetic: a plain `n_rows * r * r` panics in debug builds and
        // wraps in release builds, where a wrapped product could spuriously match
        // a short slice. An unrepresentable shape is reported as an error instead.
        let expected_lens = self.n_rows.checked_mul(self.r).and_then(|v_len| {
            v_len
                .checked_mul(self.r)
                .map(|h_len| (h_len, v_len))
        });
        let Some((expected_h, expected_v)) = expected_lens else {
            return Err(GpuError::DriverCallFailed {
                reason: format!(
                    "row_hessian_matvec inputs: n_rows({})*r({})*r overflows usize",
                    self.n_rows, self.r
                ),
            });
        };
        if self.h_rows.len() != expected_h {
            crate::gpu_bail!(
                "row_hessian_matvec inputs: h_rows.len()={} != n_rows({})*r({})*r = {}",
                self.h_rows.len(),
                self.n_rows,
                self.r,
                expected_h
            );
        }
        if self.v_rows.len() != expected_v {
            crate::gpu_bail!(
                "row_hessian_matvec inputs: v_rows.len()={} != n_rows({})*r({}) = {}",
                self.v_rows.len(),
                self.n_rows,
                self.r,
                expected_v
            );
        }
        Ok(())
    }
}
#[cfg(target_os = "linux")]
impl RowHessianDiagInputs<'_> {
    /// Checks that the declared shape is usable and agrees with the slice length.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `r` is zero or larger than `MAX_R`;
    /// - `n_rows * r * r` does not fit in `usize`;
    /// - `h_rows.len() != n_rows * r * r`.
    pub(crate) fn validate(&self) -> Result<(), GpuError> {
        if self.r == 0 {
            return Err(GpuError::DriverCallFailed {
                reason: "row_hessian_diag inputs: r must be > 0".to_string(),
            });
        }
        if self.r > MAX_R {
            crate::gpu_bail!(
                "row_hessian_diag inputs: r={} exceeds MAX_R={MAX_R}",
                self.r
            );
        }
        // Checked arithmetic: with a wrapping product, a huge `n_rows` paired with
        // a short (even empty) `h_rows` could pass this check in release builds,
        // and debug builds would panic instead of returning an error.
        let expected_h = self
            .n_rows
            .checked_mul(self.r)
            .and_then(|nr| nr.checked_mul(self.r));
        let Some(expected_h) = expected_h else {
            return Err(GpuError::DriverCallFailed {
                reason: format!(
                    "row_hessian_diag inputs: n_rows({})*r({})*r overflows usize",
                    self.n_rows, self.r
                ),
            });
        };
        if self.h_rows.len() != expected_h {
            crate::gpu_bail!(
                "row_hessian_diag inputs: h_rows.len()={} != n_rows({})*r({})*r = {}",
                self.h_rows.len(),
                self.n_rows,
                self.r,
                expected_h
            );
        }
        Ok(())
    }
}
/// CUDA C source, compiled at runtime with NVRTC, for the two per-row kernels:
/// `row_hessian_matvec_kernel` and `row_hessian_diag_kernel`.
///
/// Both kernels use one block per row (`blockIdx.x` is the row index) and loop over
/// the output index `u` in steps of `blockDim.x`. Every global-memory offset is
/// widened to `size_t` before multiplying so that `row * r` cannot overflow a
/// 32-bit `int` on large batches.
// NOTE(review): `v_shared` is a fixed 32-element array, so the matvec kernel is only
// in bounds for `r <= 32`; this relies on the host-side `r <= MAX_R` check and on
// `MAX_R <= 32` — confirm.
#[cfg(target_os = "linux")]
const ROW_KERNEL_SOURCE: &str = r#"
extern "C" {
// Per-row matvec: y_i[u] = sum_v H_i[u, v] * v_i[v].
// One block per row; blockDim.x = 32. Each thread accumulates a partial
// sum over the inner `v` index for its slice of `u` rows.
//
// Parity reference: `scratch.hess.dot(&row_dir)` in CPU
// exact_newton_joint_hessian_matvec_from_cache.
__global__ void row_hessian_matvec_kernel(
const int n_rows,
const int r,
const double* __restrict__ h_rows, // [n_rows, r, r] row-major
const double* __restrict__ v_rows, // [n_rows, r] row-major
double* __restrict__ y_rows // [n_rows, r] row-major
) {
const int row = blockIdx.x;
if (row >= n_rows) return;
const int tid = threadIdx.x;
const int nthr = blockDim.x;
// Stage the direction in shared memory so each `u`-row reuses it.
// MAX_R = 32 (matches host const); we keep the array fixed-size and
// index-guard with `r` for the partial-warp case.
__shared__ double v_shared[32];
for (int u = tid; u < r; u += nthr) {
v_shared[u] = v_rows[(size_t)row * (size_t)r + (size_t)u];
}
__syncthreads();
// Each thread handles a strided subset of output rows `u`.
const double* h_base = h_rows + (size_t)row * (size_t)r * (size_t)r;
double* y_base = y_rows + (size_t)row * (size_t)r;
for (int u = tid; u < r; u += nthr) {
const double* h_row = h_base + (size_t)u * (size_t)r;
double acc = 0.0;
for (int v = 0; v < r; ++v) {
acc += h_row[v] * v_shared[v];
}
y_base[u] = acc;
}
}
// Per-row diagonal: d_i[u] = H_i[u, u].
// One block per row; blockDim.x = 32. Each thread extracts a strided
// subset of diagonal entries; no inner reduction is needed.
//
// Parity reference: `row_hess[[u, u]]` in CPU
// exact_newton_joint_hessian_diagonal_from_cache.
__global__ void row_hessian_diag_kernel(
const int n_rows,
const int r,
const double* __restrict__ h_rows, // [n_rows, r, r] row-major
double* __restrict__ d_rows // [n_rows, r] row-major
) {
const int row = blockIdx.x;
if (row >= n_rows) return;
const int tid = threadIdx.x;
const int nthr = blockDim.x;
const double* h_base = h_rows + (size_t)row * (size_t)r * (size_t)r;
double* d_base = d_rows + (size_t)row * (size_t)r;
for (int u = tid; u < r; u += nthr) {
d_base[u] = h_base[(size_t)u * (size_t)r + (size_t)u];
}
}
} // extern "C"
"#;
/// CUDA handles shared by both row kernels; built once by `RowOpsBackend::probe`
/// and cached in a process-wide `OnceLock`.
#[cfg(target_os = "linux")]
struct RowOpsBackend {
/// Default stream of the selected device's CUDA context; all uploads, launches
/// and downloads in this file are issued on it.
stream: Arc<CudaStream>,
/// Module loaded from the NVRTC-compiled `ROW_KERNEL_SOURCE`.
module: Arc<CudaModule>,
}
#[cfg(target_os = "linux")]
impl RowOpsBackend {
    /// Performs the one-time setup: resolves the global runtime, obtains the CUDA
    /// context for its selected device, compiles `ROW_KERNEL_SOURCE` with NVRTC and
    /// loads the resulting module.
    fn init_backend() -> Result<RowOpsBackend, GpuError> {
        let Some(runtime) = super::runtime::GpuRuntime::global() else {
            return Err(GpuError::DriverLibraryUnavailable {
                reason: "row_hessian_ops backend: no CUDA runtime available".to_string(),
            });
        };
        let ctx = super::runtime::cuda_context_for(runtime.selected_device().ordinal)
            .ok_or_else(|| {
                gpu_err!(
                    "row_hessian_ops backend: failed to create CUDA context for device {}",
                    runtime.selected_device().ordinal
                )
            })?;
        let ptx = cudarc::nvrtc::compile_ptx(ROW_KERNEL_SOURCE)
            .map_err(|err| gpu_err!("row_hessian_ops NVRTC compile failed: {err}"))?;
        let stream = ctx.default_stream();
        let module = ctx
            .load_module(ptx)
            .gpu_ctx("row_hessian_ops module load failed")?;
        Ok(RowOpsBackend { stream, module })
    }

    /// Returns the process-wide backend, initialising it on first use.
    ///
    /// The outcome of the first initialisation attempt is cached, success or
    /// failure alike; on failure every call returns a clone of that first error.
    fn probe() -> Result<&'static Self, GpuError> {
        static BACKEND: OnceLock<Result<RowOpsBackend, GpuError>> = OnceLock::new();
        match BACKEND.get_or_init(Self::init_backend) {
            Ok(backend) => Ok(backend),
            Err(err) => Err(err.clone()),
        }
    }
}
/// Validates `inputs`, then runs the per-row Hessian matvec on the GPU.
///
/// # Errors
///
/// Propagates any validation error, otherwise whatever the GPU launch path reports.
#[cfg(target_os = "linux")]
pub(crate) fn launch_row_hessian_matvec(
    inputs: RowHessianMatvecInputs<'_>,
) -> Result<RowHessianMatvecOutputs, GpuError> {
    inputs
        .validate()
        .and_then(|()| launch_matvec_linux(inputs))
}
/// Validates `inputs`, then extracts the per-row Hessian diagonals on the GPU.
///
/// # Errors
///
/// Propagates any validation error, otherwise whatever the GPU launch path reports.
#[cfg(target_os = "linux")]
pub(crate) fn launch_row_hessian_diag(
    inputs: RowHessianDiagInputs<'_>,
) -> Result<RowHessianDiagOutputs, GpuError> {
    inputs
        .validate()
        .and_then(|()| launch_diag_linux(inputs))
}
/// Uploads the inputs, launches `row_hessian_matvec_kernel` with one block per row,
/// and downloads the products.
///
/// Callers are expected to have run `RowHessianMatvecInputs::validate` first; the
/// public wrapper `launch_row_hessian_matvec` does so.
///
/// An empty batch (`n_rows == 0`) yields an empty result without touching the
/// device, because CUDA rejects a launch whose grid dimension is zero. The backend
/// is still probed first, so an unavailable GPU is reported as before.
///
/// # Errors
///
/// Returns an error if the backend is unavailable, if `n_rows` or `r` does not fit
/// in `i32` (the kernel's parameter type), or if any upload, allocation, launch,
/// synchronisation or download fails.
#[cfg(target_os = "linux")]
fn launch_matvec_linux(
    inputs: RowHessianMatvecInputs<'_>,
) -> Result<RowHessianMatvecOutputs, GpuError> {
    let backend = RowOpsBackend::probe()?;
    let stream = &backend.stream;
    let n = inputs.n_rows;
    let r = inputs.r;
    if n == 0 {
        return Ok(RowHessianMatvecOutputs { y_rows: Vec::new() });
    }
    // Range-check before any upload so oversized batches fail fast, and so the
    // grid dimension below is derived from a value already known to fit.
    let n_i32 = i32::try_from(n)
        .map_err(|_| gpu_err!("row_hessian_matvec: n_rows={n} exceeds i32 range"))?;
    let r_i32 =
        i32::try_from(r).map_err(|_| gpu_err!("row_hessian_matvec: r={r} exceeds i32 range"))?;
    let d_h = stream
        .clone_htod(inputs.h_rows)
        .gpu_ctx("row_hessian_matvec upload h_rows")?;
    let d_v = stream
        .clone_htod(inputs.v_rows)
        .gpu_ctx("row_hessian_matvec upload v_rows")?;
    let mut d_y = stream
        .alloc_zeros::<f64>(n * r)
        .gpu_ctx("row_hessian_matvec alloc y_rows")?;
    let func = backend
        .module
        .load_function("row_hessian_matvec_kernel")
        .gpu_ctx("row_hessian_matvec load_function")?;
    let cfg = LaunchConfig {
        // `n_i32` came from a `usize`, so it is non-negative and converts losslessly.
        grid_dim: (n_i32.unsigned_abs(), 1, 1),
        block_dim: (ROW_HV_THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut builder = stream.launch_builder(&func);
    builder
        .arg(&n_i32)
        .arg(&r_i32)
        .arg(&d_h)
        .arg(&d_v)
        .arg(&mut d_y);
    // SAFETY: the arguments are pushed in the order and with the types of the
    // kernel signature (int, int, const double*, const double*, double*). `d_h`
    // and `d_v` hold exactly the host slices and `d_y` holds `n * r` doubles; for
    // validated inputs those are the `n*r*r`, `n*r` and `n*r` extents the kernel
    // indexes. The kernel's fixed 32-element shared array additionally requires
    // `r <= 32`, which relies on the validated `r <= MAX_R` bound.
    unsafe { builder.launch(cfg) }.gpu_ctx("row_hessian_matvec launch")?;
    stream
        .synchronize()
        .gpu_ctx("row_hessian_matvec synchronize")?;
    let y_rows = stream
        .clone_dtoh(&d_y)
        .gpu_ctx("row_hessian_matvec download y_rows")?;
    Ok(RowHessianMatvecOutputs { y_rows })
}
/// Uploads the Hessian blocks, launches `row_hessian_diag_kernel` with one block
/// per row, and downloads the diagonals.
///
/// Callers are expected to have run `RowHessianDiagInputs::validate` first; the
/// public wrapper `launch_row_hessian_diag` does so.
///
/// An empty batch (`n_rows == 0`) yields an empty result without touching the
/// device, because CUDA rejects a launch whose grid dimension is zero. The backend
/// is still probed first, so an unavailable GPU is reported as before.
///
/// # Errors
///
/// Returns an error if the backend is unavailable, if `n_rows` or `r` does not fit
/// in `i32` (the kernel's parameter type), or if any upload, allocation, launch,
/// synchronisation or download fails.
#[cfg(target_os = "linux")]
fn launch_diag_linux(inputs: RowHessianDiagInputs<'_>) -> Result<RowHessianDiagOutputs, GpuError> {
    let backend = RowOpsBackend::probe()?;
    let stream = &backend.stream;
    let n = inputs.n_rows;
    let r = inputs.r;
    if n == 0 {
        return Ok(RowHessianDiagOutputs { d_rows: Vec::new() });
    }
    // Range-check before any upload so oversized batches fail fast, and so the
    // grid dimension below is derived from a value already known to fit.
    let n_i32 =
        i32::try_from(n).map_err(|_| gpu_err!("row_hessian_diag: n_rows={n} exceeds i32 range"))?;
    let r_i32 =
        i32::try_from(r).map_err(|_| gpu_err!("row_hessian_diag: r={r} exceeds i32 range"))?;
    let d_h = stream
        .clone_htod(inputs.h_rows)
        .gpu_ctx("row_hessian_diag upload h_rows")?;
    let mut d_d = stream
        .alloc_zeros::<f64>(n * r)
        .gpu_ctx("row_hessian_diag alloc d_rows")?;
    let func = backend
        .module
        .load_function("row_hessian_diag_kernel")
        .gpu_ctx("row_hessian_diag load_function")?;
    let cfg = LaunchConfig {
        // `n_i32` came from a `usize`, so it is non-negative and converts losslessly.
        grid_dim: (n_i32.unsigned_abs(), 1, 1),
        block_dim: (ROW_HV_THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut builder = stream.launch_builder(&func);
    builder.arg(&n_i32).arg(&r_i32).arg(&d_h).arg(&mut d_d);
    // SAFETY: the arguments are pushed in the order and with the types of the
    // kernel signature (int, int, const double*, double*). `d_h` holds exactly
    // the host slice and `d_d` holds `n * r` doubles; for validated inputs those
    // are the `n*r*r` and `n*r` extents the kernel indexes.
    unsafe { builder.launch(cfg) }.gpu_ctx("row_hessian_diag launch")?;
    stream
        .synchronize()
        .gpu_ctx("row_hessian_diag synchronize")?;
    let d_rows = stream
        .clone_dtoh(&d_d)
        .gpu_ctx("row_hessian_diag download d_rows")?;
    Ok(RowHessianDiagOutputs { d_rows })
}
/// CPU reference for the per-row matvec: `y[row, u] = sum_v H[row, u, v] * v[row, v]`.
///
/// Returns a flat vector of length `n_rows * r` with entry `[row, u]` at
/// `row * r + u`. Terms are accumulated in increasing `v` order starting from `0.0`.
///
/// # Panics
///
/// Panics if `h_rows` is shorter than `n_rows * r * r` or `v_rows` is shorter than
/// `n_rows * r` (for `r > 0`); this function does not validate its inputs.
pub(crate) fn cpu_row_hessian_matvec(inputs: &RowHessianMatvecInputs<'_>) -> Vec<f64> {
    let (n, r) = (inputs.n_rows, inputs.r);
    let mut products = vec![0.0_f64; n * r];
    for row in 0..n {
        let h_block = &inputs.h_rows[row * r * r..(row + 1) * r * r];
        let direction = &inputs.v_rows[row * r..(row + 1) * r];
        let out_row = &mut products[row * r..(row + 1) * r];
        for (u, slot) in out_row.iter_mut().enumerate() {
            let h_row = &h_block[u * r..(u + 1) * r];
            *slot = h_row
                .iter()
                .zip(direction)
                .fold(0.0_f64, |acc, (&h, &v)| acc + h * v);
        }
    }
    products
}
/// CPU reference for the per-row diagonal: `d[row, u] = H[row, u, u]`.
///
/// Returns a flat vector of length `n_rows * r` with entry `[row, u]` at
/// `row * r + u`.
///
/// # Panics
///
/// Panics if `h_rows` is shorter than `n_rows * r * r` (for `r > 0`); this function
/// does not validate its inputs.
pub(crate) fn cpu_row_hessian_diag(inputs: &RowHessianDiagInputs<'_>) -> Vec<f64> {
    let (n, r) = (inputs.n_rows, inputs.r);
    let mut diagonals = vec![0.0_f64; n * r];
    for row in 0..n {
        let h_block = &inputs.h_rows[row * r * r..(row + 1) * r * r];
        let out_row = &mut diagonals[row * r..(row + 1) * r];
        for (u, slot) in out_row.iter_mut().enumerate() {
            // Entry (u, u) of a row-major r x r block sits at u * r + u.
            *slot = h_block[u * (r + 1)];
        }
    }
    diagonals
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic `(h_rows, v_rows)` fixture for `n_rows` blocks of size `r`.
    ///
    /// Each Hessian block is filled from a trigonometric seed, symmetrised by
    /// averaging the upper and lower triangles, and has `r` added to every
    /// diagonal entry. Direction vectors come from a second trigonometric seed.
    #[cfg(target_os = "linux")]
    fn make_fixture(n_rows: usize, r: usize) -> (Vec<f64>, Vec<f64>) {
        let mut h = vec![0.0_f64; n_rows * r * r];
        let mut v = vec![0.0_f64; n_rows * r];
        for row in 0..n_rows {
            let base = row * r * r;
            for u in 0..r {
                for vv in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    h[base + u * r + vv] = a;
                }
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let upper = h[base + u * r + vv];
                    let lower = h[base + vv * r + u];
                    let sym = 0.5 * (upper + lower);
                    h[base + u * r + vv] = sym;
                    h[base + vv * r + u] = sym;
                }
                h[base + u * r + u] += r as f64;
            }
            for u in 0..r {
                let seed = (row as f64) * 0.211 + (u as f64) * 0.733 + 1.5;
                v[row * r + u] = seed.sin() * 0.6 - (seed * 0.5).cos() * 0.4;
            }
        }
        (h, v)
    }

    /// The CPU oracles reproduce hand-computed results on two 2x2 blocks.
    ///
    /// The `validate` calls are Linux-only because `validate` itself is only
    /// compiled on Linux; the oracle checks run on every platform.
    #[test]
    fn cpu_oracle_matches_handwritten_2x2() {
        let h_rows = vec![2.0, 1.0, 1.0, 3.0, 4.0, 0.0, 0.0, 5.0];
        let v_rows = vec![1.0, -1.0, 2.0, 3.0];
        let inputs = RowHessianMatvecInputs {
            n_rows: 2,
            r: 2,
            h_rows: &h_rows,
            v_rows: &v_rows,
        };
        #[cfg(target_os = "linux")]
        inputs.validate().expect("hand fixture must validate");
        let y = cpu_row_hessian_matvec(&inputs);
        assert_eq!(y, vec![1.0, -2.0, 8.0, 15.0]);
        let diag_inputs = RowHessianDiagInputs {
            n_rows: 2,
            r: 2,
            h_rows: &h_rows,
        };
        #[cfg(target_os = "linux")]
        diag_inputs.validate().expect("hand fixture must validate");
        let d = cpu_row_hessian_diag(&diag_inputs);
        assert_eq!(d, vec![2.0, 3.0, 4.0, 5.0]);
    }

    /// `validate` rejects a `v_rows` length mismatch and an `r` above `MAX_R`.
    ///
    /// Linux-only: `validate`, `GpuError` and `MAX_R` are only available there.
    #[cfg(target_os = "linux")]
    #[test]
    fn validate_rejects_mismatched_shapes() {
        let h_rows = vec![1.0; 8];
        let v_rows = vec![1.0; 3];
        let inputs = RowHessianMatvecInputs {
            n_rows: 2,
            r: 2,
            h_rows: &h_rows,
            v_rows: &v_rows,
        };
        match inputs.validate() {
            Err(GpuError::DriverCallFailed { reason }) => {
                assert!(reason.contains("v_rows"), "unexpected reason: {reason}");
            }
            other => panic!("expected DriverCallFailed, got {other:?}"),
        }
        let big_r = MAX_R + 1;
        let h_rows = vec![0.0; big_r * big_r];
        let v_rows = vec![0.0; big_r];
        let inputs = RowHessianMatvecInputs {
            n_rows: 1,
            r: big_r,
            h_rows: &h_rows,
            v_rows: &v_rows,
        };
        match inputs.validate() {
            Err(GpuError::DriverCallFailed { reason }) => {
                assert!(reason.contains("MAX_R"), "unexpected reason: {reason}");
            }
            other => panic!("expected DriverCallFailed for over-MAX_R, got {other:?}"),
        }
    }

    /// GPU kernels agree with the CPU oracles within a mixed absolute/relative
    /// tolerance. The test returns early (passing) when no CUDA runtime is
    /// present or when a launch fails, treating both as infrastructure issues.
    #[cfg(target_os = "linux")]
    #[test]
    fn row_hessian_kernels_match_cpu_oracle_when_cuda_available() {
        let Some(_runtime) = crate::gpu::runtime::GpuRuntime::global() else {
            eprintln!("[row_hessian_ops parity] no CUDA runtime — skipping CUDA parity");
            return;
        };
        let n_rows = 4;
        let r = 5;
        let (h_rows, v_rows) = make_fixture(n_rows, r);
        let matvec_inputs = RowHessianMatvecInputs {
            n_rows,
            r,
            h_rows: &h_rows,
            v_rows: &v_rows,
        };
        matvec_inputs
            .validate()
            .expect("matvec fixture must validate");
        let cpu_y = cpu_row_hessian_matvec(&matvec_inputs);
        let gpu_y = match launch_row_hessian_matvec(matvec_inputs) {
            Ok(out) => out.y_rows,
            Err(err) => {
                eprintln!(
                    "[row_hessian_ops parity] matvec launch failed: {err}; \
                     treating as CI infra outage, not parity regression"
                );
                return;
            }
        };
        let tol_abs = 2e-8_f64;
        let tol_rel = 2e-7_f64;
        assert_eq!(cpu_y.len(), gpu_y.len(), "matvec output length mismatch");
        for (i, (&c, &g)) in cpu_y.iter().zip(gpu_y.iter()).enumerate() {
            let diff = (c - g).abs();
            let tol = tol_abs + tol_rel * c.abs();
            assert!(
                diff <= tol,
                "matvec[{i}]: |cpu - gpu| = {diff:.3e} > tol = {tol:.3e}; \
                 cpu={c:.17e}, gpu={g:.17e}"
            );
        }
        let diag_inputs = RowHessianDiagInputs {
            n_rows,
            r,
            h_rows: &h_rows,
        };
        diag_inputs.validate().expect("diag fixture must validate");
        let cpu_d = cpu_row_hessian_diag(&diag_inputs);
        let gpu_d = match launch_row_hessian_diag(diag_inputs) {
            Ok(out) => out.d_rows,
            Err(err) => {
                eprintln!(
                    "[row_hessian_ops parity] diag launch failed: {err}; \
                     treating as CI infra outage, not parity regression"
                );
                return;
            }
        };
        assert_eq!(cpu_d.len(), gpu_d.len(), "diag output length mismatch");
        for (i, (&c, &g)) in cpu_d.iter().zip(gpu_d.iter()).enumerate() {
            let diff = (c - g).abs();
            let tol = tol_abs + tol_rel * c.abs();
            assert!(
                diff <= tol,
                "diag[{i}]: |cpu - gpu| = {diff:.3e} > tol = {tol:.3e}; \
                 cpu={c:.17e}, gpu={g:.17e}"
            );
        }
    }
}