use std::sync::OnceLock;
use super::error::GpuError;
#[cfg(target_os = "linux")]
use super::error::GpuResultExt;
#[cfg(target_os = "linux")]
use std::sync::{Arc, Mutex};
#[cfg(target_os = "linux")]
use cudarc::driver::{CudaContext, CudaModule};
/// Response-distribution / link-function pairs that have a dedicated row routine.
///
/// Every variant has a stable hyphenated label ([`PirlsRowFamily::as_str`]) and three
/// kernel entry-point names (final-row, solve-row and alpha-ladder).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PirlsRowFamily {
    /// Bernoulli response, logit link.
    BernoulliLogit,
    /// Bernoulli response, probit link.
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link.
    BernoulliCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Gaussian response, identity link.
    GaussianIdentity,
    /// Gamma response, log link.
    GammaLog,
}

impl PirlsRowFamily {
    /// All families, listed in declaration order.
    pub const ALL: [Self; 6] = [
        Self::BernoulliLogit,
        Self::BernoulliProbit,
        Self::BernoulliCLogLog,
        Self::PoissonLog,
        Self::GaussianIdentity,
        Self::GammaLog,
    ];

    /// Hyphenated, lower-case label used in diagnostics.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "gaussian-identity",
            Self::PoissonLog => "poisson-log",
            Self::GammaLog => "gamma-log",
            Self::BernoulliLogit => "bernoulli-logit",
            Self::BernoulliProbit => "bernoulli-probit",
            Self::BernoulliCLogLog => "bernoulli-cloglog",
        }
    }

    /// Entry-point name of the full row-reweight kernel for this family.
    pub const fn kernel_name(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "pirls_row_gaussian_identity",
            Self::PoissonLog => "pirls_row_poisson_log",
            Self::GammaLog => "pirls_row_gamma_log",
            Self::BernoulliLogit => "pirls_row_bernoulli_logit",
            Self::BernoulliProbit => "pirls_row_bernoulli_probit",
            Self::BernoulliCLogLog => "pirls_row_bernoulli_cloglog",
        }
    }

    /// Entry-point name of the solve-row kernel for this family.
    pub const fn solve_kernel_name(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "pirls_solve_gaussian_identity",
            Self::PoissonLog => "pirls_solve_poisson_log",
            Self::GammaLog => "pirls_solve_gamma_log",
            Self::BernoulliLogit => "pirls_solve_bernoulli_logit",
            Self::BernoulliProbit => "pirls_solve_bernoulli_probit",
            Self::BernoulliCLogLog => "pirls_solve_bernoulli_cloglog",
        }
    }

    /// Entry-point name of the alpha-ladder kernel for this family.
    pub const fn ladder_kernel_name(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "pirls_ladder_gaussian_identity",
            Self::PoissonLog => "pirls_ladder_poisson_log",
            Self::GammaLog => "pirls_ladder_gamma_log",
            Self::BernoulliLogit => "pirls_ladder_bernoulli_logit",
            Self::BernoulliProbit => "pirls_ladder_bernoulli_probit",
            Self::BernoulliCLogLog => "pirls_ladder_bernoulli_cloglog",
        }
    }

    /// `true` for the families whose link is the canonical one
    /// (logit, Poisson-log and Gaussian-identity); `false` otherwise.
    pub const fn is_canonical(self) -> bool {
        matches!(
            self,
            Self::BernoulliLogit | Self::PoissonLog | Self::GaussianIdentity
        )
    }
}
/// Which curvature the row routines report as `w_hessian`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CurvatureMode {
    /// Use the Fisher (expected-information) weight unchanged.
    Fisher,
    /// Add the family's observed-information correction to the Fisher weight.
    Observed,
}

impl CurvatureMode {
    /// Lower-case label used in diagnostics.
    pub const fn as_str(self) -> &'static str {
        if matches!(self, Self::Fisher) {
            "fisher"
        } else {
            "observed"
        }
    }
}
/// Bit masks OR-ed into the per-row `status` word.
///
/// Written as hex so they can be compared directly with the `0x..u` literals
/// used inside the generated CUDA source.
pub mod status_flags {
    /// The linear predictor was outside the clamp range and was clamped.
    pub const ETA_CLAMPED: u32 = 0x01;
    /// The mean hit a numerical floor (or, for Bernoulli, ceiling).
    pub const MU_FLOORED: u32 = 0x02;
    /// NOTE(review): not raised by any routine visible in this part of the
    /// file; meaning inferred from the name only.
    pub const NONSMOOTH_BERNOULLI: u32 = 0x04;
    /// The response is non-finite or outside the family's valid range.
    pub const INVALID_RESPONSE: u32 = 0x08;
    /// The prior weight was `<= 0`.
    pub const ZERO_PRIOR_WEIGHT: u32 = 0x10;
}
/// Per-row inputs to the CPU row routines.
#[derive(Clone, Copy, Debug)]
pub struct RowInput {
    /// Linear predictor for the row (clamped by the non-Gaussian families before use).
    pub eta: f64,
    /// Observed response for the row.
    pub y: f64,
    /// Prior weight; negative values are treated as zero by the row routines.
    pub prior_weight: f64,
}
/// Per-row quantities produced by the row routines.
#[derive(Clone, Copy, Debug, Default)]
pub struct RowOutput {
    /// Mean after any family-specific floor/ceiling has been applied.
    pub mu: f64,
    /// Gradient contribution with respect to the linear predictor, as computed by the family.
    pub grad_eta: f64,
    /// Fisher (expected-information) weight.
    pub w_fisher: f64,
    /// Curvature selected by `CurvatureMode` (Fisher weight, optionally plus a correction).
    pub w_hessian: f64,
    /// Weight handed to the solver: non-positive curvature becomes `0.0`,
    /// positive curvature is floored at a small constant.
    pub w_solver: f64,
    /// Working response paired with the Fisher weight.
    pub z_fisher: f64,
    /// Working response paired with `w_hessian` (the visible routines set it equal to `z_fisher`).
    pub z_hessian: f64,
    /// Deviance contribution of the row.
    pub deviance: f64,
    /// Bitwise OR of `status_flags` values raised while processing the row.
    pub status: u32,
}
/// Symmetric bound applied to the linear predictor by `clamp_eta`.
const ETA_CLAMP: f64 = 700.0;
/// Lower bound applied to the Poisson mean.
const MU_FLOOR_POISSON: f64 = 1.0e-10;
/// Lower bound applied to the Gamma mean.
const MU_FLOOR_GAMMA: f64 = 1.0e-10;
/// Bernoulli means are kept inside `[floor, 1 - floor]`.
const MU_FLOOR_BERNOULLI: f64 = 1.0e-12;
/// Smallest strictly positive weight the row routines emit.
const W_SOLVER_FLOOR: f64 = 1.0e-12;
/// `bernoulli_z` only divides by `dmu/deta` when it is strictly greater than this.
const DMU_DETA_MIN: f64 = 0.0;
/// Restricts `eta` to `[-ETA_CLAMP, ETA_CLAMP]`.
///
/// Returns the value to use and whether clamping happened. A NaN input fails
/// both comparisons and is passed through unchanged with `false`.
#[inline]
fn clamp_eta(eta: f64) -> (f64, bool) {
    let too_high = eta > ETA_CLAMP;
    let too_low = eta < -ETA_CLAMP;
    match (too_high, too_low) {
        (true, _) => (ETA_CLAMP, true),
        (_, true) => (-ETA_CLAMP, true),
        _ => (eta, false),
    }
}
pub fn row_reweight_cpu(
family: PirlsRowFamily,
mode: CurvatureMode,
input: RowInput,
gamma_shape: f64,
) -> RowOutput {
match family {
PirlsRowFamily::GaussianIdentity => row_gaussian_identity(input, mode),
PirlsRowFamily::PoissonLog => row_poisson_log(input, mode),
PirlsRowFamily::GammaLog => row_gamma_log(input, mode, gamma_shape),
PirlsRowFamily::BernoulliLogit => row_bernoulli_logit(input, mode),
PirlsRowFamily::BernoulliProbit => row_bernoulli_probit(input, mode),
PirlsRowFamily::BernoulliCLogLog => row_bernoulli_cloglog(input, mode),
}
}
/// Picks the curvature for `mode`: the Fisher weight as-is, or the Fisher
/// weight plus `observed_correction` in observed mode.
#[inline]
fn select_w_hessian(mode: CurvatureMode, w_fisher: f64, observed_correction: f64) -> f64 {
    let extra = match mode {
        // Fisher mode returns the input untouched (no `+ 0.0` round trip).
        CurvatureMode::Fisher => return w_fisher,
        CurvatureMode::Observed => observed_correction,
    };
    w_fisher + extra
}
#[inline]
fn row_gaussian_identity(input: RowInput, mode: CurvatureMode) -> RowOutput {
let w = input.prior_weight.max(0.0);
let mu = input.eta;
let resid = input.y - mu;
let dev = w * resid * resid;
let status = if input.prior_weight <= 0.0 {
status_flags::ZERO_PRIOR_WEIGHT
} else {
0
};
let w_hessian = select_w_hessian(mode, w, 0.0);
RowOutput {
mu,
grad_eta: w * resid,
w_fisher: w,
w_hessian,
w_solver: if w_hessian > 0.0 {
w_hessian.max(W_SOLVER_FLOOR)
} else {
0.0
},
z_fisher: input.y,
z_hessian: input.y,
deviance: dev,
status,
}
}
/// Poisson / log-link row.
///
/// `eta` is clamped, `mu = exp(eta)` is floored at `MU_FLOOR_POISSON`, and the
/// Fisher weight `prior_weight * mu` is floored at `W_SOLVER_FLOOR` when it is
/// positive. The observed correction is zero, so both curvature modes yield
/// the same weights.
#[inline]
fn row_poisson_log(input: RowInput, mode: CurvatureMode) -> RowOutput {
    let (eta_used, hit_clamp) = clamp_eta(input.eta);
    let mean_unfloored = eta_used.exp();
    let mean = mean_unfloored.max(MU_FLOOR_POISSON);
    let weight = input.prior_weight.max(0.0);

    let unfloored_weight = weight * mean;
    let fisher = if unfloored_weight > 0.0 {
        unfloored_weight.max(W_SOLVER_FLOOR)
    } else {
        0.0
    };
    let curvature = select_w_hessian(mode, fisher, 0.0);

    let residual = input.y - mean;
    // The `y * ln(y / mu)` term is dropped when `y` is not positive (its limit at y -> 0 is 0).
    let unit_deviance = if input.y > 0.0 {
        input.y * (input.y / mean).ln() - residual
    } else {
        -residual
    };
    let working = eta_used + residual / mean;

    let response_ok = input.y.is_finite() && input.y >= 0.0;
    let raised = [
        (hit_clamp, status_flags::ETA_CLAMPED),
        (mean_unfloored < MU_FLOOR_POISSON, status_flags::MU_FLOORED),
        (input.prior_weight <= 0.0, status_flags::ZERO_PRIOR_WEIGHT),
        (!response_ok, status_flags::INVALID_RESPONSE),
    ];
    let status = raised
        .iter()
        .filter(|entry| entry.0)
        .fold(0u32, |acc, entry| acc | entry.1);

    RowOutput {
        mu: mean,
        grad_eta: weight * residual,
        w_fisher: fisher,
        w_hessian: curvature,
        w_solver: curvature,
        z_fisher: working,
        z_hessian: working,
        deviance: 2.0 * weight * unit_deviance,
        status,
    }
}
#[inline]
fn row_gamma_log(input: RowInput, mode: CurvatureMode, shape: f64) -> RowOutput {
let (eta_c, clamped) = clamp_eta(input.eta);
let mu_raw = eta_c.exp();
let mu_floored = mu_raw < MU_FLOOR_GAMMA;
let mu = mu_raw.max(MU_FLOOR_GAMMA);
let w_prior = input.prior_weight.max(0.0);
let w_fisher = w_prior * shape;
let obs_correction = if w_fisher > 0.0 && mu > 0.0 && input.y.is_finite() {
w_fisher * (input.y / mu - 1.0)
} else {
0.0
};
let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
let resid = input.y - mu;
let dev = if input.y > 0.0 {
2.0 * w_prior * (-((input.y / mu).ln()) + resid / mu)
} else {
f64::INFINITY
};
let z = eta_c + resid / mu;
let mut status = 0u32;
if clamped {
status |= status_flags::ETA_CLAMPED;
}
if mu_floored {
status |= status_flags::MU_FLOORED;
}
if input.prior_weight <= 0.0 {
status |= status_flags::ZERO_PRIOR_WEIGHT;
}
if !(input.y.is_finite() && input.y > 0.0) {
status |= status_flags::INVALID_RESPONSE;
}
RowOutput {
mu,
grad_eta: w_prior * resid / mu,
w_fisher,
w_hessian,
w_solver: if w_hessian > 0.0 {
w_hessian.max(W_SOLVER_FLOOR)
} else {
0.0
},
z_fisher: z,
z_hessian: z,
deviance: dev,
status,
}
}
/// Bernoulli / logit-link row.
///
/// The mean is evaluated as `0.5 * (1 + tanh(eta / 2))` and kept inside
/// `[MU_FLOOR_BERNOULLI, 1 - MU_FLOOR_BERNOULLI]`. With the canonical link the
/// observed correction is zero, so both curvature modes agree.
#[inline]
fn row_bernoulli_logit(input: RowInput, mode: CurvatureMode) -> RowOutput {
    let lower = MU_FLOOR_BERNOULLI;
    let upper = 1.0 - MU_FLOOR_BERNOULLI;

    let (eta_used, hit_clamp) = clamp_eta(input.eta);
    let prob_unbounded = 0.5 * (1.0 + (0.5 * eta_used).tanh());
    let saturated = prob_unbounded < lower || prob_unbounded > upper;
    let prob = prob_unbounded.clamp(lower, upper);

    let weight = input.prior_weight.max(0.0);
    let slope = prob * (1.0 - prob);
    let fisher = weight * slope;
    let residual = input.y - prob;
    let curvature = select_w_hessian(mode, fisher, 0.0);
    let solver_weight = if curvature > 0.0 {
        curvature.max(W_SOLVER_FLOOR)
    } else {
        0.0
    };
    let working = bernoulli_z(eta_used, input.y, prob, slope);

    let response_ok = input.y.is_finite() && (0.0..=1.0).contains(&input.y);
    let raised = [
        (hit_clamp, status_flags::ETA_CLAMPED),
        (saturated, status_flags::MU_FLOORED),
        (input.prior_weight <= 0.0, status_flags::ZERO_PRIOR_WEIGHT),
        (!response_ok, status_flags::INVALID_RESPONSE),
    ];
    let status = raised
        .iter()
        .filter(|entry| entry.0)
        .fold(0u32, |acc, entry| acc | entry.1);

    RowOutput {
        mu: prob,
        grad_eta: weight * residual,
        w_fisher: fisher,
        w_hessian: curvature,
        w_solver: solver_weight,
        z_fisher: working,
        z_hessian: working,
        deviance: bernoulli_deviance(input.y, prob, weight),
        status,
    }
}
/// Bernoulli / probit-link row.
///
/// With `h = dmu/deta = phi(eta)` and `v = mu * (1 - mu)`:
///
/// * score: `prior_weight * (y - mu) * h / v`
/// * Fisher weight: `prior_weight * h^2 / v`
/// * observed weight (negative second derivative of the log-likelihood):
///   `prior_weight * (h^2 / v - (y - mu) * d/deta[h / v])`
///
/// where `d/deta[h / v] = h'/v - h^2 * (1 - 2 mu) / v^2` and `h' = -eta * h`.
/// The correction therefore enters with a minus sign. For `y = 1` this
/// reproduces the closed form `phi * (eta * Phi + phi) / Phi^2`.
///
/// `w_hessian` may be negative in observed mode; `w_solver` is then `0.0`.
/// The working response is the Fisher one in both modes.
///
/// NOTE(review): the device-side probit body is outside this view — make sure
/// it applies the same sign so CPU and GPU results agree.
#[inline]
fn row_bernoulli_probit(input: RowInput, mode: CurvatureMode) -> RowOutput {
    let (eta_c, clamped) = clamp_eta(input.eta);
    let mu_raw = standard_normal_cdf(eta_c);
    let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
    let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
    let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
    let w_prior = input.prior_weight.max(0.0);
    let dmu_deta = standard_normal_pdf(eta_c);
    let v = mu * (1.0 - mu);
    let fisher_per_prior = if v > 0.0 {
        dmu_deta * dmu_deta / v
    } else {
        0.0
    };
    let w_fisher = w_prior * fisher_per_prior;
    let resid = input.y - mu;
    let grad_eta = if v > 0.0 {
        w_prior * resid * dmu_deta / v
    } else {
        0.0
    };
    let dev = bernoulli_deviance(input.y, mu, w_prior);
    let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);

    let mut status = 0u32;
    if clamped {
        status |= status_flags::ETA_CLAMPED;
    }
    if mu_low || mu_high {
        status |= status_flags::MU_FLOORED;
    }
    if input.prior_weight <= 0.0 {
        status |= status_flags::ZERO_PRIOR_WEIGHT;
    }
    if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
        status |= status_flags::INVALID_RESPONSE;
    }

    let obs_correction = if v > 0.0 && w_prior > 0.0 {
        let h_prime = -eta_c * dmu_deta;
        let v_prime = 1.0 - 2.0 * mu;
        // bracket = d/deta [h / v]
        let bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
        // -d2l/deta2 = h^2/v - (y - mu) * bracket, so the correction is subtracted.
        -w_prior * resid * bracket
    } else {
        0.0
    };
    let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
    let w_solver = if w_hessian > 0.0 {
        w_hessian.max(W_SOLVER_FLOOR)
    } else {
        0.0
    };

    RowOutput {
        mu,
        grad_eta,
        w_fisher,
        w_hessian,
        w_solver,
        z_fisher: z,
        z_hessian: z,
        deviance: dev,
        status,
    }
}
/// Bernoulli / complementary-log-log row.
///
/// `mu = 1 - exp(-exp(eta))` is evaluated through `exp_m1` and kept inside
/// `[MU_FLOOR_BERNOULLI, 1 - MU_FLOOR_BERNOULLI]`. With
/// `h = dmu/deta = exp(eta) * exp(-exp(eta))` (computed from the unclamped mean)
/// and `v = mu * (1 - mu)`:
///
/// * score: `prior_weight * (y - mu) * h / v`
/// * Fisher weight: `prior_weight * h^2 / v`
/// * observed weight (negative second derivative of the log-likelihood):
///   `prior_weight * (h^2 / v - (y - mu) * d/deta[h / v])`
///
/// where `d/deta[h / v] = h'/v - h^2 * (1 - 2 mu) / v^2` and
/// `h' = h * (1 - exp(eta))`. The correction therefore enters with a minus
/// sign. For `y = 1` this reproduces `h^2 / mu^2 - h' / mu`.
///
/// `w_hessian` may be negative in observed mode; `w_solver` is then `0.0`.
/// The working response is the Fisher one in both modes.
///
/// NOTE(review): the device-side cloglog body is outside this view — make sure
/// it applies the same sign so CPU and GPU results agree.
#[inline]
fn row_bernoulli_cloglog(input: RowInput, mode: CurvatureMode) -> RowOutput {
    let (eta_c, clamped) = clamp_eta(input.eta);
    let inner = eta_c.exp();
    let mu_raw = -(-inner).exp_m1();
    let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
    let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
    let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
    let dmu_deta = inner * (1.0 - mu_raw);
    let w_prior = input.prior_weight.max(0.0);
    let v = mu * (1.0 - mu);
    let fisher_per_prior = if v > 0.0 {
        dmu_deta * dmu_deta / v
    } else {
        0.0
    };
    let w_fisher = w_prior * fisher_per_prior;
    let resid = input.y - mu;
    let grad_eta = if v > 0.0 {
        w_prior * resid * dmu_deta / v
    } else {
        0.0
    };
    let dev = bernoulli_deviance(input.y, mu, w_prior);
    let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);

    let mut status = 0u32;
    if clamped {
        status |= status_flags::ETA_CLAMPED;
    }
    if mu_low || mu_high {
        status |= status_flags::MU_FLOORED;
    }
    if input.prior_weight <= 0.0 {
        status |= status_flags::ZERO_PRIOR_WEIGHT;
    }
    if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
        status |= status_flags::INVALID_RESPONSE;
    }

    let obs_correction = if v > 0.0 && w_prior > 0.0 {
        let h_prime = dmu_deta * (1.0 - inner);
        let v_prime = 1.0 - 2.0 * mu;
        // bracket = d/deta [h / v]
        let bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
        // -d2l/deta2 = h^2/v - (y - mu) * bracket, so the correction is subtracted.
        -w_prior * resid * bracket
    } else {
        0.0
    };
    let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
    let w_solver = if w_hessian > 0.0 {
        w_hessian.max(W_SOLVER_FLOOR)
    } else {
        0.0
    };

    RowOutput {
        mu,
        grad_eta,
        w_fisher,
        w_hessian,
        w_solver,
        z_fisher: z,
        z_hessian: z,
        deviance: dev,
        status,
    }
}
/// Weighted Bernoulli deviance contribution
/// `2 w [ y ln(y/mu) + (1 - y) ln((1 - y)/(1 - mu)) ]`.
///
/// Each term is dropped when its multiplier is not positive (`y <= 0`,
/// respectively `y >= 1`). A prior weight of exactly zero short-circuits to
/// `0.0`.
#[inline]
fn bernoulli_deviance(y: f64, mu: f64, w_prior: f64) -> f64 {
    if w_prior != 0.0 {
        let success_term = if y > 0.0 { y * (y / mu).ln() } else { 0.0 };
        let failure_term = if y < 1.0 {
            let q = 1.0 - y;
            q * (q / (1.0 - mu)).ln()
        } else {
            0.0
        };
        2.0 * w_prior * (success_term + failure_term)
    } else {
        0.0
    }
}
/// Working response `eta + (y - mu) / (dmu/deta)`.
///
/// Falls back to `eta_used` when the slope is non-finite, not above
/// `DMU_DETA_MIN`, or when the resulting step is non-finite.
#[inline]
fn bernoulli_z(eta_used: f64, y: f64, mu: f64, dmu_deta: f64) -> f64 {
    let slope_usable = dmu_deta.is_finite() && dmu_deta > DMU_DETA_MIN;
    if !slope_usable {
        return eta_used;
    }
    let step = (y - mu) / dmu_deta;
    if step.is_finite() {
        eta_used + step
    } else {
        eta_used
    }
}
/// Standard normal CDF, evaluated as `0.5 * erfc(-x / sqrt(2))`.
///
/// Accuracy is that of the file-local `erfc`.
#[inline]
fn standard_normal_cdf(x: f64) -> f64 {
    use std::f64::consts::FRAC_1_SQRT_2;

    let scaled = -x * FRAC_1_SQRT_2;
    0.5 * erfc(scaled)
}
/// Standard normal density `exp(-x^2 / 2) / sqrt(2 pi)`.
#[inline]
fn standard_normal_pdf(x: f64) -> f64 {
    /// `1 / sqrt(2 pi)`.
    const INV_SQRT_TWO_PI: f64 = 0.398_942_280_401_432_7;

    let exponent = -0.5 * x * x;
    INV_SQRT_TWO_PI * exponent.exp()
}
/// Complementary error function used by the CPU routines.
///
/// Thin forwarder to `libm_erfc`, the approximation defined in this file.
fn erfc(x: f64) -> f64 {
    libm_erfc(x)
}
/// Rational-exponential approximation of the complementary error function.
///
/// Non-finite inputs are handled explicitly: NaN maps to NaN, `+inf` to `0.0`
/// and `-inf` to `2.0`. Negative arguments use `erfc(-x) = 2 - erfc(x)`.
///
/// NOTE(review): despite the name this is not the libm implementation. The
/// coefficients look like the Numerical Recipes `erfcc` Chebyshev fit, whose
/// quoted fractional error is about 1.2e-7 — confirm that this accuracy is
/// sufficient wherever the result is compared with a device `erfc`.
#[inline]
fn libm_erfc(x: f64) -> f64 {
    // Coefficients of the polynomial in `t`, lowest order first. The constant
    // term (-1.26551223) is applied separately together with `-x^2`.
    const POLY: [f64; 9] = [
        1.000_023_68,
        0.374_091_96,
        0.096_784_18,
        -0.186_288_06,
        0.278_868_07,
        -1.135_203_98,
        1.488_515_87,
        -0.822_152_23,
        0.170_872_77,
    ];

    if x.is_nan() {
        return f64::NAN;
    }
    if x.is_infinite() {
        return if x > 0.0 { 0.0 } else { 2.0 };
    }

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + 0.5 * magnitude);
    // Horner evaluation from the highest coefficient down.
    let series = POLY.iter().rev().fold(0.0, |acc, &coeff| coeff + t * acc);
    let exponent = -magnitude * magnitude - 1.265_512_23 + t * series;
    let upper_tail = t * exponent.exp();

    if x >= 0.0 {
        upper_tail
    } else {
        2.0 - upper_tail
    }
}
/// Problem dimensions for a row pass.
#[derive(Clone, Copy, Debug)]
pub struct PirlsRowDims {
    /// Presumably the number of rows — not used in the visible part of the file; confirm against callers.
    pub n: usize,
}
/// Handle to the GPU row backend.
///
/// On Linux it wraps the CUDA state in `inner`; on other targets the struct has
/// no fields and `PirlsRowBackend::probe` always returns an error.
#[must_use]
pub struct PirlsRowBackend {
    #[cfg(target_os = "linux")]
    inner: PirlsRowBackendLinux,
}
/// Linux-only state behind `PirlsRowBackend`.
#[cfg(target_os = "linux")]
struct PirlsRowBackendLinux {
    /// CUDA context into which compiled modules are loaded.
    ctx: Arc<CudaContext>,
    /// Cache of the built-in family modules, keyed by family, curvature and kernel mode.
    modules: Mutex<std::collections::HashMap<ModuleKey, Arc<CudaModule>>>,
    /// Cache of JIT-spec modules, keyed by spec id and curvature.
    jit_modules: Mutex<std::collections::HashMap<JitKey, Arc<CudaModule>>>,
}
/// Which kind of kernel a cached module contains.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum KernelMode {
    /// Full row-reweight kernel (used by `module_for`).
    FinalRow,
    /// Solve-row kernel (used by `module_for_solve`).
    SolveRow,
    /// Alpha-ladder kernel (used by `module_for_ladder`).
    AlphaLadder,
}
/// Cache key for the built-in family modules.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct ModuleKey {
    /// Family the module was generated for.
    family: PirlsRowFamily,
    /// Curvature mode the source was generated with.
    curvature: CurvatureMode,
    /// Kernel flavour stored in the module.
    mode: KernelMode,
}
impl PirlsRowBackend {
    /// `true` when this build targets Linux, the only target with a GPU implementation.
    pub const fn compiled() -> bool {
        cfg!(target_os = "linux")
    }

    /// Returns the process-wide backend, initialising it on the first call.
    ///
    /// The outcome of the first attempt — success or failure — is stored in a
    /// `OnceLock` and handed back on every later call; a stored error is cloned.
    ///
    /// # Errors
    /// On non-Linux targets this always fails with
    /// `GpuError::DriverLibraryUnavailable`. On Linux it forwards the error of
    /// the one-time probe.
    pub fn probe() -> Result<&'static Self, GpuError> {
        static BACKEND: OnceLock<Result<PirlsRowBackend, GpuError>> = OnceLock::new();

        let outcome = BACKEND.get_or_init(|| {
            #[cfg(target_os = "linux")]
            {
                Self::probe_linux()
            }
            #[cfg(not(target_os = "linux"))]
            {
                Err(GpuError::DriverLibraryUnavailable {
                    reason: "pirls_row GPU backend is Linux-only".to_string(),
                })
            }
        });

        match outcome {
            Ok(backend) => Ok(backend),
            Err(err) => Err(err.clone()),
        }
    }

    /// Builds the Linux backend from the global GPU runtime's selected device.
    #[cfg(target_os = "linux")]
    fn probe_linux() -> Result<Self, GpuError> {
        let Some(runtime) = super::runtime::GpuRuntime::global() else {
            return Err(GpuError::DriverLibraryUnavailable {
                reason: "pirls_row backend: no CUDA runtime available".to_string(),
            });
        };

        let maybe_ctx = super::runtime::cuda_context_for(runtime.selected_device().ordinal);
        let ctx = maybe_ctx.ok_or_else(|| {
            gpu_err!(
                "pirls_row backend: failed to create CUDA context for device {}",
                runtime.selected_device().ordinal
            )
        })?;

        let inner = PirlsRowBackendLinux {
            ctx,
            modules: Mutex::default(),
            jit_modules: Mutex::default(),
        };
        Ok(Self { inner })
    }

    /// Returns the module holding the full row-reweight kernel for
    /// `(family, curvature)`, compiling and caching it on a miss.
    ///
    /// The cache lock is released while NVRTC runs, so two threads missing at
    /// the same time may both compile; the later insert replaces the earlier.
    ///
    /// # Errors
    /// Fails on a poisoned cache mutex, an NVRTC compile error or a module
    /// load error.
    #[cfg(target_os = "linux")]
    pub fn module_for(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        let key = ModuleKey {
            family,
            curvature,
            mode: KernelMode::FinalRow,
        };
        {
            let cache = self
                .inner
                .modules
                .lock()
                .gpu_ctx("pirls_row module cache mutex poisoned")?;
            if let Some(hit) = cache.get(&key) {
                return Ok(Arc::clone(hit));
            }
        }

        let family_name = family.as_str();
        let curvature_name = curvature.as_str();
        let source = cuda_source_for(family, curvature);
        let ptx = cudarc::nvrtc::compile_ptx(source).gpu_ctx_with(|err| {
            format!("pirls_row NVRTC compile failed for {family_name}/{curvature_name}: {err}")
        })?;
        let module = self
            .inner
            .ctx
            .load_module(ptx)
            .gpu_ctx("pirls_row module load failed")?;

        {
            let mut cache = self
                .inner
                .modules
                .lock()
                .gpu_ctx("pirls_row module cache mutex poisoned")?;
            cache.insert(key, Arc::clone(&module));
        }
        Ok(module)
    }

    /// Same as [`Self::module_for`], for the solve-row kernel.
    ///
    /// # Errors
    /// Fails on a poisoned cache mutex, an NVRTC compile error or a module
    /// load error.
    #[cfg(target_os = "linux")]
    pub fn module_for_solve(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        let key = ModuleKey {
            family,
            curvature,
            mode: KernelMode::SolveRow,
        };
        {
            let cache = self
                .inner
                .modules
                .lock()
                .gpu_ctx("pirls_row solve module cache mutex poisoned")?;
            if let Some(hit) = cache.get(&key) {
                return Ok(Arc::clone(hit));
            }
        }

        let family_name = family.as_str();
        let curvature_name = curvature.as_str();
        let source = solve_row_source_for(family, curvature);
        let ptx = cudarc::nvrtc::compile_ptx(source).gpu_ctx_with(|err| {
            format!(
                "pirls_row solve NVRTC compile failed for {family_name}/{curvature_name}: {err}"
            )
        })?;
        let module = self
            .inner
            .ctx
            .load_module(ptx)
            .gpu_ctx("pirls_row solve module load failed")?;

        {
            let mut cache = self
                .inner
                .modules
                .lock()
                .gpu_ctx("pirls_row solve module cache mutex poisoned")?;
            cache.insert(key, Arc::clone(&module));
        }
        Ok(module)
    }

    /// Same as [`Self::module_for`], for the alpha-ladder kernel.
    ///
    /// # Errors
    /// Fails on a poisoned cache mutex, an NVRTC compile error or a module
    /// load error.
    #[cfg(target_os = "linux")]
    pub fn module_for_ladder(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        let key = ModuleKey {
            family,
            curvature,
            mode: KernelMode::AlphaLadder,
        };
        {
            let cache = self
                .inner
                .modules
                .lock()
                .gpu_ctx("pirls_row ladder module cache mutex poisoned")?;
            if let Some(hit) = cache.get(&key) {
                return Ok(Arc::clone(hit));
            }
        }

        let family_name = family.as_str();
        let curvature_name = curvature.as_str();
        let source = ladder_source_for(family, curvature);
        let ptx = cudarc::nvrtc::compile_ptx(source).gpu_ctx_with(|err| {
            format!(
                "pirls_row ladder NVRTC compile failed for {family_name}/{curvature_name}: {err}"
            )
        })?;
        let module = self
            .inner
            .ctx
            .load_module(ptx)
            .gpu_ctx("pirls_row ladder module load failed")?;

        {
            let mut cache = self
                .inner
                .modules
                .lock()
                .gpu_ctx("pirls_row ladder module cache mutex poisoned")?;
            cache.insert(key, Arc::clone(&module));
        }
        Ok(module)
    }

    /// Returns the module compiled from `spec` for `curvature`, compiling and
    /// caching it on a miss.
    ///
    /// The cache key is `(spec.spec_id, curvature)`; the spec body is not part
    /// of the key, so specs with different bodies must use different ids.
    ///
    /// # Errors
    /// Fails on a poisoned cache mutex, an NVRTC compile error or a module
    /// load error.
    #[cfg(target_os = "linux")]
    pub fn module_for_jit(
        &self,
        spec: &JitFamilySpec,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        let spec_id = spec.spec_id;
        let key = JitKey { spec_id, curvature };
        {
            let cache = self
                .inner
                .jit_modules
                .lock()
                .gpu_ctx("pirls_row jit cache poisoned")?;
            if let Some(hit) = cache.get(&key) {
                return Ok(Arc::clone(hit));
            }
        }

        let curvature_name = curvature.as_str();
        let source = spec.cuda_source(curvature);
        let ptx = cudarc::nvrtc::compile_ptx(source).gpu_ctx_with(|err| {
            format!(
                "pirls_row JIT NVRTC compile failed for spec_id={spec_id} curvature={curvature_name}: {err}"
            )
        })?;
        let module = self
            .inner
            .ctx
            .load_module(ptx)
            .gpu_ctx("pirls_row JIT module load failed")?;

        {
            let mut cache = self
                .inner
                .jit_modules
                .lock()
                .gpu_ctx("pirls_row jit cache poisoned (insert)")?;
            cache.insert(key, Arc::clone(&module));
        }
        Ok(module)
    }
}
/// Cache key for JIT-spec modules. The spec body is deliberately not part of
/// the key, so `spec_id` must identify the body uniquely.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct JitKey {
    /// Identifier taken from `JitFamilySpec::spec_id`.
    spec_id: u64,
    /// Curvature mode the source was generated with.
    curvature: CurvatureMode,
}
#[derive(Clone, Debug)]
pub struct JitFamilySpec {
pub spec_id: u64,
pub body: String,
}
impl JitFamilySpec {
#[cfg(target_os = "linux")]
pub fn glm(spec_id: u64, family: PirlsRowFamily, curvature: CurvatureMode) -> Self {
let body = match family {
PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
PirlsRowFamily::GammaLog => gamma_log_body(curvature),
PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
};
Self { spec_id, body }
}
pub fn raw(spec_id: u64, body: impl Into<String>) -> Self {
Self {
spec_id,
body: body.into(),
}
}
pub fn kernel_name(&self) -> String {
format!("pirls_row_jit_{}", self.spec_id)
}
#[cfg(target_os = "linux")]
pub fn cuda_source(&self, curvature: CurvatureMode) -> String {
let curvature_define = match curvature {
CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
};
let kernel_name = self.kernel_name();
let body = &self.body;
format!(
r#"
{curvature_define}
{prolog}
extern "C" __global__ void {kernel_name}(
int n,
const double* __restrict__ eta,
const double* __restrict__ y,
const double* __restrict__ prior_w,
double* __restrict__ mu_out,
double* __restrict__ grad_eta_out,
double* __restrict__ w_fisher_out,
double* __restrict__ w_hessian_out,
double* __restrict__ w_solver_out,
double* __restrict__ z_fisher_out,
double* __restrict__ z_hessian_out,
double* __restrict__ deviance_out,
unsigned int* __restrict__ status_out
) {{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n) return;
unsigned int flags = 0u;
double eta_i = eta[i];
double y_i = y[i];
double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
if (prior_w[i] <= 0.0) flags |= 0x10u;
{body}
mu_out[i] = mu;
grad_eta_out[i] = grad_eta;
w_fisher_out[i] = w_fisher;
w_hessian_out[i] = w_hessian;
w_solver_out[i] = w_solver;
z_fisher_out[i] = z_f;
z_hessian_out[i] = z_h;
deviance_out[i] = dev;
status_out[i] = flags;
}}
"#,
prolog = COMMON_DEVICE_PROLOG,
)
}
}
/// Device-side output buffers for the full row-reweight kernels, one slice per
/// `RowOutput` field, each of length `n`.
#[cfg(target_os = "linux")]
pub struct RowOutputDevBuffers {
    pub mu: cudarc::driver::CudaSlice<f64>,
    pub grad_eta: cudarc::driver::CudaSlice<f64>,
    pub w_fisher: cudarc::driver::CudaSlice<f64>,
    pub w_hessian: cudarc::driver::CudaSlice<f64>,
    pub w_solver: cudarc::driver::CudaSlice<f64>,
    pub z_fisher: cudarc::driver::CudaSlice<f64>,
    pub z_hessian: cudarc::driver::CudaSlice<f64>,
    pub deviance: cudarc::driver::CudaSlice<f64>,
    pub status: cudarc::driver::CudaSlice<u32>,
    /// Number of rows every slice was allocated for.
    pub n: usize,
}

#[cfg(target_os = "linux")]
impl RowOutputDevBuffers {
    /// Allocates zero-initialised buffers for `n` rows on `stream`.
    ///
    /// # Errors
    /// Returns the first allocation failure, labelled with the buffer name.
    pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>, n: usize) -> Result<Self, GpuError> {
        let zeros_f64 = |label: &'static str| {
            stream
                .alloc_zeros::<f64>(n)
                .gpu_ctx_with(|err| format!("pirls_row alloc {label}: {err}"))
        };
        let zeros_u32 = |label: &'static str| {
            stream
                .alloc_zeros::<u32>(n)
                .gpu_ctx_with(|err| format!("pirls_row alloc {label}: {err}"))
        };

        // Allocation order is kept fixed so the first failure reported is stable.
        let mu = zeros_f64("mu")?;
        let grad_eta = zeros_f64("grad_eta")?;
        let w_fisher = zeros_f64("w_fisher")?;
        let w_hessian = zeros_f64("w_hessian")?;
        let w_solver = zeros_f64("w_solver")?;
        let z_fisher = zeros_f64("z_fisher")?;
        let z_hessian = zeros_f64("z_hessian")?;
        let deviance = zeros_f64("deviance")?;
        let status = zeros_u32("status")?;

        Ok(Self {
            mu,
            grad_eta,
            w_fisher,
            w_hessian,
            w_solver,
            z_fisher,
            z_hessian,
            deviance,
            status,
            n,
        })
    }
}
/// Device-side output buffers for the solve-row kernels: the subset of row
/// outputs those kernels write, each of length `n`.
#[cfg(target_os = "linux")]
pub struct SolveRowBuffers {
    pub grad_eta: cudarc::driver::CudaSlice<f64>,
    pub w_solver: cudarc::driver::CudaSlice<f64>,
    pub deviance: cudarc::driver::CudaSlice<f64>,
    pub status: cudarc::driver::CudaSlice<u32>,
    /// Number of rows every slice was allocated for.
    pub n: usize,
}

#[cfg(target_os = "linux")]
impl SolveRowBuffers {
    /// Allocates zero-initialised buffers for `n` rows on `stream`.
    ///
    /// # Errors
    /// Returns the first allocation failure, labelled with the buffer name.
    pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>, n: usize) -> Result<Self, GpuError> {
        let zeros_f64 = |label: &'static str| {
            stream
                .alloc_zeros::<f64>(n)
                .gpu_ctx_with(|err| format!("pirls_row solve alloc {label}: {err}"))
        };
        let zeros_u32 = |label: &'static str| {
            stream
                .alloc_zeros::<u32>(n)
                .gpu_ctx_with(|err| format!("pirls_row solve alloc {label}: {err}"))
        };

        // Allocation order is kept fixed so the first failure reported is stable.
        let grad_eta = zeros_f64("grad_eta")?;
        let w_solver = zeros_f64("w_solver")?;
        let deviance = zeros_f64("deviance")?;
        let status = zeros_u32("status")?;

        Ok(Self {
            grad_eta,
            w_solver,
            deviance,
            status,
            n,
        })
    }
}
/// Number of step sizes in the alpha ladder; also the `y` extent of the ladder launch grid.
pub const ALPHA_LADDER_LEN: usize = 7;
/// Candidate step sizes, halving from `1.0` down to `2^-6`.
pub const ALPHA_LADDER: [f64; ALPHA_LADDER_LEN] =
    [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625];
/// Device-side outputs of the alpha-ladder kernel: one objective value and one
/// status word per ladder entry (`ALPHA_LADDER_LEN` elements each).
#[cfg(target_os = "linux")]
pub struct AlphaLadderDevBuffers {
    pub objective_dev: cudarc::driver::CudaSlice<f64>,
    pub status_dev: cudarc::driver::CudaSlice<u32>,
}

#[cfg(target_os = "linux")]
impl AlphaLadderDevBuffers {
    /// Allocates both buffers, zero-initialised, on `stream`.
    ///
    /// # Errors
    /// Returns the first allocation failure.
    pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>) -> Result<Self, GpuError> {
        let objective_dev = stream
            .alloc_zeros::<f64>(ALPHA_LADDER_LEN)
            .gpu_ctx_with(|err| format!("pirls_row ladder alloc objective: {err}"))?;
        let status_dev = stream
            .alloc_zeros::<u32>(ALPHA_LADDER_LEN)
            .gpu_ctx_with(|err| format!("pirls_row ladder alloc status: {err}"))?;
        Ok(Self {
            objective_dev,
            status_dev,
        })
    }

    /// Resets both buffers to zero on `stream`, objective first.
    ///
    /// # Errors
    /// Returns the first memset failure.
    pub fn zero(&mut self, stream: &Arc<cudarc::driver::CudaStream>) -> Result<(), GpuError> {
        stream
            .memset_zeros(&mut self.objective_dev)
            .gpu_ctx_with(|err| format!("pirls_row ladder zero objective: {err}"))?;
        stream
            .memset_zeros(&mut self.status_dev)
            .gpu_ctx_with(|err| format!("pirls_row ladder zero status: {err}"))?;
        Ok(())
    }
}
/// Enqueues the full row-reweight kernel for `family` on `stream`.
///
/// The module for `(family, curvature)` is fetched from the backend cache
/// (compiled on first use). The launch uses 256 threads per block and at least
/// one block; each thread handles one row index and returns when the index is
/// `>= n`. This function does not synchronise the stream.
///
/// `gamma_shape` is passed to the kernel only for `PirlsRowFamily::GammaLog`.
///
/// # Errors
/// Fails when `out.n != n`, when the module or function cannot be obtained,
/// when `n` does not fit in `u32`/`i32`, or when the launch itself fails.
#[cfg(target_os = "linux")]
pub fn launch_row_reweight_on_stream(
    backend: &PirlsRowBackend,
    family: PirlsRowFamily,
    curvature: CurvatureMode,
    gamma_shape: f64,
    stream: &Arc<cudarc::driver::CudaStream>,
    n: usize,
    eta_dev: &cudarc::driver::CudaSlice<f64>,
    y_dev: &cudarc::driver::CudaSlice<f64>,
    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
    out: &mut RowOutputDevBuffers,
) -> Result<(), GpuError> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    // Only the recorded length of the output buffers is checked here.
    if out.n != n {
        crate::gpu_bail!("row reweight buffers shape {} mismatches n={n}", out.n);
    }
    let module = backend.module_for(family, curvature)?;
    let func = module
        .load_function(family.kernel_name())
        .gpu_ctx_with(|err| {
            format!(
                "row reweight load_function({}): {err}",
                family.kernel_name()
            )
        })?;
    const THREADS_PER_BLOCK: u32 = 256;
    let n_u32 =
        u32::try_from(n).map_err(|_| gpu_err!("n={n} exceeds u32 for row reweight grid sizing"))?;
    // `.max(1)` keeps the grid valid when `n == 0`; the kernel then exits immediately.
    let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
    // The kernel takes `n` as a C `int`.
    let n_i32 = i32::try_from(n)
        .map_err(|_| gpu_err!("n={n} exceeds i32 for row reweight kernel argument"))?;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // Argument order must match the signature emitted by `cuda_source_for`:
    // n, eta, y, prior_w, [shape for GammaLog], then the nine output pointers.
    let mut builder = stream.launch_builder(&func);
    builder.arg(&n_i32);
    builder.arg(eta_dev);
    builder.arg(y_dev);
    builder.arg(prior_w_dev);
    if matches!(family, PirlsRowFamily::GammaLog) {
        builder.arg(&gamma_shape);
    }
    builder.arg(&mut out.mu);
    builder.arg(&mut out.grad_eta);
    builder.arg(&mut out.w_fisher);
    builder.arg(&mut out.w_hessian);
    builder.arg(&mut out.w_solver);
    builder.arg(&mut out.z_fisher);
    builder.arg(&mut out.z_hessian);
    builder.arg(&mut out.deviance);
    builder.arg(&mut out.status);
    // NOTE(review): nothing here verifies that `eta_dev`, `y_dev` and
    // `prior_w_dev` hold at least `n` elements, or that the public `out`
    // slices still match `out.n`; the kernel indexes all of them up to
    // `n - 1`, so the caller has to guarantee it.
    unsafe { builder.launch(cfg) }
        .map(|_event_pair| ())
        .gpu_ctx_with(|err| format!("row reweight launch({}): {err}", family.kernel_name()))
}
/// Enqueues the JIT-compiled row-reweight kernel described by `spec` on `stream`.
///
/// The module for `(spec.spec_id, curvature)` is fetched from the backend's
/// JIT cache (compiled on first use). The launch uses 256 threads per block
/// and at least one block; each thread handles one row index and returns when
/// the index is `>= n`. This function does not synchronise the stream.
///
/// # Errors
/// Fails when `out.n != n`, when the module or function cannot be obtained,
/// when `n` does not fit in `u32`/`i32`, or when the launch itself fails.
#[cfg(target_os = "linux")]
pub fn launch_row_reweight_jit_on_stream(
    backend: &PirlsRowBackend,
    spec: &JitFamilySpec,
    curvature: CurvatureMode,
    stream: &Arc<cudarc::driver::CudaStream>,
    n: usize,
    eta_dev: &cudarc::driver::CudaSlice<f64>,
    y_dev: &cudarc::driver::CudaSlice<f64>,
    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
    out: &mut RowOutputDevBuffers,
) -> Result<(), GpuError> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    // Only the recorded length of the output buffers is checked here.
    if out.n != n {
        crate::gpu_bail!("JIT row reweight buffers shape {} mismatches n={n}", out.n);
    }
    let module = backend.module_for_jit(spec, curvature)?;
    let kernel_name = spec.kernel_name();
    let func = module
        .load_function(&kernel_name)
        .gpu_ctx_with(|err| format!("JIT row reweight load_function({kernel_name}): {err}"))?;
    const THREADS_PER_BLOCK: u32 = 256;
    let n_u32 = u32::try_from(n)
        .map_err(|_| gpu_err!("n={n} exceeds u32 for JIT row reweight grid sizing"))?;
    // `.max(1)` keeps the grid valid when `n == 0`; the kernel then exits immediately.
    let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
    // The kernel takes `n` as a C `int`.
    let n_i32 = i32::try_from(n)
        .map_err(|_| gpu_err!("n={n} exceeds i32 for JIT row reweight kernel argument"))?;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // Argument order must match the signature emitted by
    // `JitFamilySpec::cuda_source`: n, eta, y, prior_w, then the nine output
    // pointers. The JIT skeleton has no shape parameter.
    let mut builder = stream.launch_builder(&func);
    builder.arg(&n_i32);
    builder.arg(eta_dev);
    builder.arg(y_dev);
    builder.arg(prior_w_dev);
    builder.arg(&mut out.mu);
    builder.arg(&mut out.grad_eta);
    builder.arg(&mut out.w_fisher);
    builder.arg(&mut out.w_hessian);
    builder.arg(&mut out.w_solver);
    builder.arg(&mut out.z_fisher);
    builder.arg(&mut out.z_hessian);
    builder.arg(&mut out.deviance);
    builder.arg(&mut out.status);
    // NOTE(review): nothing here verifies that `eta_dev`, `y_dev` and
    // `prior_w_dev` hold at least `n` elements, or that the public `out`
    // slices still match `out.n`; the kernel indexes all of them up to
    // `n - 1`, so the caller has to guarantee it.
    unsafe { builder.launch(cfg) }
        .map(|_event_pair| ())
        .gpu_ctx_with(|err| format!("JIT row reweight launch({kernel_name}): {err}"))
}
/// Enqueues the solve-row kernel for `family` on `stream`.
///
/// The module for `(family, curvature)` is fetched from the backend cache
/// (compiled on first use). The launch uses 256 threads per block and at least
/// one block. This function does not synchronise the stream.
///
/// `gamma_shape` is passed to the kernel only for `PirlsRowFamily::GammaLog`.
///
/// # Errors
/// Fails when `out.n != n`, when the module or function cannot be obtained,
/// when `n` does not fit in `u32`/`i32`, or when the launch itself fails.
#[cfg(target_os = "linux")]
pub fn launch_solve_row_on_stream(
    backend: &PirlsRowBackend,
    family: PirlsRowFamily,
    curvature: CurvatureMode,
    gamma_shape: f64,
    stream: &Arc<cudarc::driver::CudaStream>,
    n: usize,
    eta_dev: &cudarc::driver::CudaSlice<f64>,
    y_dev: &cudarc::driver::CudaSlice<f64>,
    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
    out: &mut SolveRowBuffers,
) -> Result<(), GpuError> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    // Only the recorded length of the output buffers is checked here.
    if out.n != n {
        crate::gpu_bail!("solve-row buffers shape {} mismatches n={n}", out.n);
    }
    let module = backend.module_for_solve(family, curvature)?;
    let kernel_name = family.solve_kernel_name();
    let func = module
        .load_function(kernel_name)
        .gpu_ctx_with(|err| format!("solve-row load_function({kernel_name}): {err}"))?;
    const THREADS_PER_BLOCK: u32 = 256;
    let n_u32 =
        u32::try_from(n).map_err(|_| gpu_err!("n={n} exceeds u32 for solve-row grid sizing"))?;
    // `.max(1)` keeps the grid non-empty when `n == 0`.
    let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
    let n_i32 = i32::try_from(n)
        .map_err(|_| gpu_err!("n={n} exceeds i32 for solve-row kernel argument"))?;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // Pushed in order: n, eta, y, prior_w, [shape for GammaLog], grad_eta,
    // w_solver, deviance, status.
    // NOTE(review): the solve-row kernel source is outside this view — confirm
    // that its signature uses exactly this order.
    let mut builder = stream.launch_builder(&func);
    builder.arg(&n_i32);
    builder.arg(eta_dev);
    builder.arg(y_dev);
    builder.arg(prior_w_dev);
    if matches!(family, PirlsRowFamily::GammaLog) {
        builder.arg(&gamma_shape);
    }
    builder.arg(&mut out.grad_eta);
    builder.arg(&mut out.w_solver);
    builder.arg(&mut out.deviance);
    builder.arg(&mut out.status);
    // NOTE(review): input slice lengths are not checked against `n` here; the
    // caller presumably has to guarantee they hold at least `n` elements.
    unsafe { builder.launch(cfg) }
        .map(|_event_pair| ())
        .gpu_ctx_with(|err| format!("solve-row launch({kernel_name}): {err}"))
}
/// Enqueues the alpha-ladder kernel for `family` on `stream`.
///
/// The module for `(family, curvature)` is fetched from the backend cache
/// (compiled on first use). The grid is two-dimensional: `x` covers the rows
/// in blocks of 256 threads (at least one block) and `y` has
/// `ALPHA_LADDER_LEN` entries. This function neither zeroes `out` nor
/// synchronises the stream.
///
/// `gamma_shape` is passed to the kernel only for `PirlsRowFamily::GammaLog`.
///
/// # Errors
/// Fails when the module or function cannot be obtained, when `n` does not fit
/// in `u32`/`i32`, or when the launch itself fails.
#[cfg(target_os = "linux")]
pub fn launch_alpha_ladder_on_stream(
    backend: &PirlsRowBackend,
    family: PirlsRowFamily,
    curvature: CurvatureMode,
    gamma_shape: f64,
    stream: &Arc<cudarc::driver::CudaStream>,
    n: usize,
    eta_dev: &cudarc::driver::CudaSlice<f64>,
    xd_dev: &cudarc::driver::CudaSlice<f64>,
    y_dev: &cudarc::driver::CudaSlice<f64>,
    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
    out: &mut AlphaLadderDevBuffers,
) -> Result<(), GpuError> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    let module = backend.module_for_ladder(family, curvature)?;
    let kernel_name = family.ladder_kernel_name();
    let func = module
        .load_function(kernel_name)
        .gpu_ctx_with(|err| format!("alpha-ladder load_function({kernel_name}): {err}"))?;
    const THREADS_PER_BLOCK: u32 = 256;
    let n_u32 =
        u32::try_from(n).map_err(|_| gpu_err!("n={n} exceeds u32 for alpha-ladder grid sizing"))?;
    // `.max(1)` keeps the grid non-empty when `n == 0`.
    let row_blocks = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
    let n_i32 = i32::try_from(n)
        .map_err(|_| gpu_err!("n={n} exceeds i32 for alpha-ladder kernel argument"))?;
    let cfg = LaunchConfig {
        // One grid row per ladder entry.
        grid_dim: (row_blocks, ALPHA_LADDER_LEN as u32, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // Pushed in order: n, eta, xd, y, prior_w, [shape for GammaLog],
    // objective, status.
    // NOTE(review): the ladder kernel source is outside this view — confirm
    // that its signature uses exactly this order, and whether it accumulates
    // into `out` (in which case the caller must call `zero` first).
    let mut builder = stream.launch_builder(&func);
    builder.arg(&n_i32);
    builder.arg(eta_dev);
    builder.arg(xd_dev);
    builder.arg(y_dev);
    builder.arg(prior_w_dev);
    if matches!(family, PirlsRowFamily::GammaLog) {
        builder.arg(&gamma_shape);
    }
    builder.arg(&mut out.objective_dev);
    builder.arg(&mut out.status_dev);
    // NOTE(review): input slice lengths are not checked against `n` here; the
    // caller presumably has to guarantee they hold at least `n` elements.
    unsafe { builder.launch(cfg) }
        .map(|_event_pair| ())
        .gpu_ctx_with(|err| format!("alpha-ladder launch({kernel_name}): {err}"))
}
/// CUDA source placed ahead of every generated row kernel.
///
/// It declares the math functions the helpers call and defines device
/// counterparts of the CPU helpers in this file: `clamp_eta` (same 700 bound,
/// raises flag `0x1`), `bernoulli_deviance`, `bernoulli_z`, `std_norm_cdf` and
/// `std_norm_pdf`.
///
/// The text is handed to NVRTC verbatim; any change to it changes the
/// compiled kernels.
#[cfg(target_os = "linux")]
const COMMON_DEVICE_PROLOG: &str = r#"
extern "C" {
double exp(double);
double log(double);
double log1p(double);
double tanh(double);
double sqrt(double);
double fabs(double);
double erfc(double);
}
__device__ __forceinline__ double clamp_eta(double eta, unsigned int* flags) {
const double E = 700.0;
if (eta > E) { *flags |= 0x1u; return E; }
if (eta < -E) { *flags |= 0x1u; return -E; }
return eta;
}
__device__ __forceinline__ double bernoulli_deviance(double y, double mu, double w) {
if (w == 0.0) return 0.0;
double t1 = (y > 0.0) ? y * log(y / mu) : 0.0;
double t2 = (y < 1.0) ? (1.0 - y) * log((1.0 - y) / (1.0 - mu)) : 0.0;
return 2.0 * w * (t1 + t2);
}
__device__ __forceinline__ double bernoulli_z(double eta, double y, double mu, double dmu_deta) {
if (dmu_deta > 0.0 && isfinite(dmu_deta)) {
double delta = (y - mu) / dmu_deta;
if (isfinite(delta)) return eta + delta;
}
return eta;
}
__device__ __forceinline__ double std_norm_cdf(double x) {
return 0.5 * erfc(-x * 0.7071067811865475);
}
__device__ __forceinline__ double std_norm_pdf(double x) {
return 0.3989422804014327 * exp(-0.5 * x * x);
}
"#;
/// Assembles the CUDA C source of the full per-row reweight kernel for `family`.
///
/// The generated text consists of the curvature `#define`, the shared device
/// prolog (`COMMON_DEVICE_PROLOG`) and one `extern "C" __global__` kernel named
/// `family.kernel_name()`. Each thread handles row `i`, runs the family body and
/// stores `mu`, `grad_eta`, `w_fisher`, `w_hessian`, `w_solver`, `z_f`, `z_h`,
/// `dev` and the status flags into the matching output arrays. Rows with a
/// non-positive prior weight are evaluated with weight 0 and flagged `0x10`.
/// Only the Gamma-log kernel receives the additional `double shape` argument.
#[cfg(target_os = "linux")]
fn cuda_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
    // Extra kernel parameter that only the Gamma-log family needs.
    let gamma_shape_arg = match family {
        PirlsRowFamily::GammaLog => " double shape,\n",
        _ => "",
    };
    // Preprocessor switch consumed by the `#ifdef` sections of the family bodies.
    let mode_define = match curvature {
        CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
        CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
    };
    // Family-specific statements that declare the per-row quantities.
    let row_body = match family {
        PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
        PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
        PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
        PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
        PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
        PirlsRowFamily::GammaLog => gamma_log_body(curvature),
    };
    format!(
        r#"
{curvature_define}
{prolog}
extern "C" __global__ void {kernel_name}(
int n,
const double* __restrict__ eta,
const double* __restrict__ y,
const double* __restrict__ prior_w,
{shape_param} double* __restrict__ mu_out,
double* __restrict__ grad_eta_out,
double* __restrict__ w_fisher_out,
double* __restrict__ w_hessian_out,
double* __restrict__ w_solver_out,
double* __restrict__ z_fisher_out,
double* __restrict__ z_hessian_out,
double* __restrict__ deviance_out,
unsigned int* __restrict__ status_out
) {{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n) return;
unsigned int flags = 0u;
double eta_i = eta[i];
double y_i = y[i];
double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
if (prior_w[i] <= 0.0) flags |= 0x10u;
{body}
mu_out[i] = mu;
grad_eta_out[i] = grad_eta;
w_fisher_out[i] = w_fisher;
w_hessian_out[i] = w_hessian;
w_solver_out[i] = w_solver;
z_fisher_out[i] = z_f;
z_hessian_out[i] = z_h;
deviance_out[i] = dev;
status_out[i] = flags;
}}
"#,
        curvature_define = mode_define,
        prolog = COMMON_DEVICE_PROLOG,
        kernel_name = family.kernel_name(),
        shape_param = gamma_shape_arg,
        body = row_body,
    )
}
/// Returns the one-line CUDA comment that records which curvature mode a
/// generated family body was emitted for. The tag ends with a newline so it
/// can be placed directly in front of the first body statement.
#[cfg(target_os = "linux")]
#[inline]
fn curvature_tag(curvature: CurvatureMode) -> &'static str {
    const FISHER_TAG: &str = " // curvature: fisher\n";
    const OBSERVED_TAG: &str = " // curvature: observed\n";
    match curvature {
        CurvatureMode::Observed => OBSERVED_TAG,
        CurvatureMode::Fisher => FISHER_TAG,
    }
}
/// CUDA statements for one Gaussian / identity-link row.
///
/// The snippet expects `eta_i`, `y_i` and `wp` to be in scope and declares
/// `mu`, `grad_eta`, `w_fisher`, `w_hessian`, `w_solver`, `z_f`, `z_h` and
/// `dev`. Fisher and Hessian weights are both `wp`; the solver weight is
/// floored at 1e-12 when `wp` is positive. `curvature` only selects the
/// leading tag comment.
#[cfg(target_os = "linux")]
fn gaussian_identity_body(curvature: CurvatureMode) -> String {
    format!(
        r#"{tag} double mu = eta_i;
double resid = y_i - mu;
double grad_eta = wp * resid;
double w_fisher = wp;
double w_hessian = wp;
double w_solver = (wp > 0.0) ? fmax(wp, 1e-12) : 0.0;
double z_f = y_i;
double z_h = y_i;
double dev = wp * resid * resid;
"#,
        tag = curvature_tag(curvature),
    )
}
/// CUDA statements for one Poisson / log-link row.
///
/// The snippet expects `eta_i`, `y_i`, `wp` and `flags` to be in scope. It
/// clamps `eta` via `clamp_eta`, floors `mu` at 1e-10 (flag `0x2` when the raw
/// mean is below the floor) and flags `0x8` for a response that is not finite
/// or is negative. Hessian and solver weights equal the Fisher weight.
/// `curvature` only selects the leading tag comment.
#[cfg(target_os = "linux")]
fn poisson_log_body(curvature: CurvatureMode) -> String {
    format!(
        r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
double mu_raw = exp(eta_c);
if (mu_raw < 1e-10) flags |= 0x2u;
double mu = (mu_raw > 1e-10) ? mu_raw : 1e-10;
double raw_w = wp * mu;
double w_fisher = (raw_w > 0.0) ? fmax(raw_w, 1e-12) : 0.0;
double resid = y_i - mu;
double grad_eta = wp * resid;
double w_hessian = w_fisher;
double w_solver = w_fisher;
double z_f = eta_c + resid / mu;
double z_h = z_f;
double dev_term = (y_i > 0.0) ? (y_i * log(y_i / mu) - resid) : (-resid);
double dev = 2.0 * wp * dev_term;
if (!(isfinite(y_i) && y_i >= 0.0)) flags |= 0x8u;
"#,
        tag = curvature_tag(curvature),
    )
}
/// CUDA statements for one Gamma / log-link row.
///
/// The snippet expects `eta_i`, `y_i`, `wp`, `flags` and the kernel argument
/// `shape` to be in scope. The Fisher weight is `wp * shape`; when the kernel
/// is compiled with `PIRLS_CURVATURE_OBSERVED` the Hessian weight becomes
/// `w_fisher * y / mu`, otherwise it equals the Fisher weight. A response that
/// is not finite or not strictly positive sets flag `0x8`, and the deviance is
/// +inf when `y_i <= 0`.
///
/// NOTE(review): `grad_eta` and `dev` are not multiplied by `shape` while
/// `w_fisher` is — confirm this matches the CPU reference for `shape != 1`.
#[cfg(target_os = "linux")]
fn gamma_log_body(curvature: CurvatureMode) -> String {
    format!(
        r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
double mu_raw = exp(eta_c);
if (mu_raw < 1e-10) flags |= 0x2u;
double mu = (mu_raw > 1e-10) ? mu_raw : 1e-10;
double w_fisher = wp * shape;
#ifdef PIRLS_CURVATURE_OBSERVED
// Stage 5: observed information for Gamma-log.
// w_obs = w_F + w_F · (y/μ − 1) = w_F · y/μ.
double w_hessian = (w_fisher > 0.0 && mu > 0.0 && isfinite(y_i))
? w_fisher * (y_i / mu)
: w_fisher;
#else
double w_hessian = w_fisher;
#endif
double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
double resid = y_i - mu;
double grad_eta = wp * resid / mu;
double z_f = eta_c + resid / mu;
double z_h = z_f;
double dev = (y_i > 0.0)
? (2.0 * wp * (-log(y_i / mu) + resid / mu))
: (1.0 / 0.0);
if (!(isfinite(y_i) && y_i > 0.0)) flags |= 0x8u;
"#,
        tag = curvature_tag(curvature),
    )
}
/// CUDA statements for one Bernoulli / logit-link row.
///
/// The snippet expects `eta_i`, `y_i`, `wp` and `flags` to be in scope. The
/// mean is computed as `0.5 * (1 + tanh(eta / 2))` and clamped to
/// `[1e-12, 1 - 1e-12]` (flag `0x2` when the clamp is active). Hessian weight
/// equals the Fisher weight `wp * mu * (1 - mu)`. A response outside `[0, 1]`
/// or non-finite sets flag `0x8`. `curvature` only selects the leading tag
/// comment.
#[cfg(target_os = "linux")]
fn bernoulli_logit_body(curvature: CurvatureMode) -> String {
    format!(
        r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
double half = 0.5 * eta_c;
double mu_raw = 0.5 * (1.0 + tanh(half));
if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
double dmu_deta = mu * (1.0 - mu);
double w_fisher = wp * dmu_deta;
double w_hessian = w_fisher;
double w_solver = (w_fisher > 0.0) ? fmax(w_fisher, 1e-12) : 0.0;
double resid = y_i - mu;
double grad_eta = wp * resid;
double dev = bernoulli_deviance(y_i, mu, wp);
double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
double z_h = z_f;
if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
"#,
        tag = curvature_tag(curvature),
    )
}
/// CUDA statements for one Bernoulli / probit-link row.
///
/// The snippet expects `eta_i`, `y_i`, `wp` and `flags` to be in scope. The
/// mean is `std_norm_cdf(eta)` clamped to `[1e-12, 1 - 1e-12]`, the link
/// derivative is `std_norm_pdf(eta)` and the Fisher weight is
/// `wp * h^2 / (mu * (1 - mu))`. When the kernel is compiled with
/// `PIRLS_CURVATURE_OBSERVED` the Hessian weight adds the
/// `wp * (y - mu) * [h'/V - h^2 V'/V^2]` correction; otherwise it equals the
/// Fisher weight.
#[cfg(target_os = "linux")]
fn bernoulli_probit_body(curvature: CurvatureMode) -> String {
    format!(
        r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
double mu_raw = std_norm_cdf(eta_c);
if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
double dmu_deta = std_norm_pdf(eta_c);
double v = mu * (1.0 - mu);
double fpp = (v > 0.0) ? dmu_deta * dmu_deta / v : 0.0;
double w_fisher = wp * fpp;
#ifdef PIRLS_CURVATURE_OBSERVED
// Stage 5: observed information for Bernoulli probit.
// w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
// h(η)=φ(η), h'(η)=−η·φ(η); V'=1−2μ.
double w_hessian = w_fisher;
if (v > 0.0 && wp > 0.0) {{
double h_prime = -eta_c * dmu_deta;
double v_prime = 1.0 - 2.0 * mu;
double bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
w_hessian = w_fisher + wp * (y_i - mu) * bracket;
}}
#else
double w_hessian = w_fisher;
#endif
double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
double resid = y_i - mu;
double grad_eta = (v > 0.0) ? wp * resid * dmu_deta / v : 0.0;
double dev = bernoulli_deviance(y_i, mu, wp);
double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
double z_h = z_f;
if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
"#,
        tag = curvature_tag(curvature),
    )
}
/// CUDA statements for one Bernoulli / complementary-log-log row.
///
/// The snippet expects `eta_i`, `y_i`, `wp` and `flags` to be in scope. The
/// mean is `-expm1(-exp(eta))` clamped to `[1e-12, 1 - 1e-12]`; the link
/// derivative uses the unclamped mean. When the kernel is compiled with
/// `PIRLS_CURVATURE_OBSERVED` the Hessian weight adds the
/// `wp * (y - mu) * [h'/V - h^2 V'/V^2]` correction with
/// `h' = h * (1 - exp(eta))`; otherwise it equals the Fisher weight.
#[cfg(target_os = "linux")]
fn bernoulli_cloglog_body(curvature: CurvatureMode) -> String {
    format!(
        r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
double inner = exp(eta_c);
// μ = 1 − exp(−exp(η)); use -expm1(-inner) to avoid catastrophic
// cancellation in the deep negative tail (η ≲ -36).
double mu_raw = -expm1(-inner);
if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
double dmu_deta = inner * (1.0 - mu_raw);
double v = mu * (1.0 - mu);
double fpp = (v > 0.0) ? dmu_deta * dmu_deta / v : 0.0;
double w_fisher = wp * fpp;
#ifdef PIRLS_CURVATURE_OBSERVED
// Stage 5: observed information for Bernoulli cloglog.
// w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
// h'(η) = h(η) · (1 − inner); V'=1−2μ.
double w_hessian = w_fisher;
if (v > 0.0 && wp > 0.0) {{
double h_prime = dmu_deta * (1.0 - inner);
double v_prime = 1.0 - 2.0 * mu;
double bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
w_hessian = w_fisher + wp * (y_i - mu) * bracket;
}}
#else
double w_hessian = w_fisher;
#endif
double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
double resid = y_i - mu;
double grad_eta = (v > 0.0) ? wp * resid * dmu_deta / v : 0.0;
double dev = bernoulli_deviance(y_i, mu, wp);
double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
double z_h = z_f;
if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
"#,
        tag = curvature_tag(curvature),
    )
}
/// Assembles the CUDA C source of the reduced "solve" row kernel for `family`.
///
/// The kernel is named `family.solve_kernel_name()` and evaluates the same
/// family body as the full row kernel, but stores only `grad_eta`, `w_solver`,
/// `dev` and the status flags. Rows with a non-positive prior weight are
/// evaluated with weight 0 and flagged `0x10`. Only the Gamma-log kernel
/// receives the additional `double shape` argument.
#[cfg(target_os = "linux")]
fn solve_row_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
    // Extra kernel parameter that only the Gamma-log family needs.
    let gamma_shape_arg = match family {
        PirlsRowFamily::GammaLog => " double shape,\n",
        _ => "",
    };
    // Preprocessor switch consumed by the `#ifdef` sections of the family bodies.
    let mode_define = match curvature {
        CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
        CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
    };
    // Family-specific statements that declare the per-row quantities.
    let row_body = match family {
        PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
        PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
        PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
        PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
        PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
        PirlsRowFamily::GammaLog => gamma_log_body(curvature),
    };
    format!(
        r#"
{curvature_define}
{prolog}
extern "C" __global__ void {kernel_name}(
int n,
const double* __restrict__ eta,
const double* __restrict__ y,
const double* __restrict__ prior_w,
{shape_param} double* __restrict__ grad_eta_out,
double* __restrict__ w_solver_out,
double* __restrict__ deviance_out,
unsigned int* __restrict__ status_out
) {{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n) return;
unsigned int flags = 0u;
double eta_i = eta[i];
double y_i = y[i];
double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
if (prior_w[i] <= 0.0) flags |= 0x10u;
{body}
grad_eta_out[i] = grad_eta;
w_solver_out[i] = w_solver;
deviance_out[i] = dev;
status_out[i] = flags;
}}
"#,
        curvature_define = mode_define,
        prolog = COMMON_DEVICE_PROLOG,
        kernel_name = family.solve_kernel_name(),
        shape_param = gamma_shape_arg,
        body = row_body,
    )
}
/// CUDA declaration of the `__constant__` table `PIRLS_ALPHAS`: seven step
/// sizes, each half of the previous one, from 1.0 down to 0.015625 (2^-6).
/// The ladder kernel source splices this line in and indexes the table with
/// `blockIdx.y`.
#[cfg(target_os = "linux")]
const ALPHA_LADDER_CUDA_ARRAY: &str =
"__constant__ double PIRLS_ALPHAS[7] = {1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625};";
/// Assembles the CUDA C source of the step-size "ladder" kernel for `family`.
///
/// The kernel is named `family.ladder_kernel_name()`. Thread `(i, k)` — with
/// `k = blockIdx.y` — evaluates the family body at `eta[i] + PIRLS_ALPHAS[k] * xd[i]`,
/// then atomically adds the row deviance to `objective_out[k]` and ORs the row
/// status flags into `status_out[k]`. Only the Gamma-log kernel receives the
/// additional `double shape` argument.
///
/// NOTE(review): the kernel does not bounds-check `k` against the 7-entry
/// `PIRLS_ALPHAS` table; presumably the launcher never uses more than 7 blocks
/// in y — confirm at the launch site.
#[cfg(target_os = "linux")]
fn ladder_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
    // Extra kernel parameter that only the Gamma-log family needs.
    let gamma_shape_arg = match family {
        PirlsRowFamily::GammaLog => " double shape,\n",
        _ => "",
    };
    // Preprocessor switch consumed by the `#ifdef` sections of the family bodies.
    let mode_define = match curvature {
        CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
        CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
    };
    // Family-specific statements; only `dev` and `flags` are consumed here.
    let row_body = match family {
        PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
        PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
        PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
        PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
        PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
        PirlsRowFamily::GammaLog => gamma_log_body(curvature),
    };
    format!(
        r#"
{curvature_define}
{prolog}
{alphas}
extern "C" __global__ void {kernel_name}(
int n,
const double* __restrict__ eta,
const double* __restrict__ xd,
const double* __restrict__ y,
const double* __restrict__ prior_w,
{shape_param} double* __restrict__ objective_out,
unsigned int* __restrict__ status_out
) {{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int k = (int)blockIdx.y;
if (i >= n) return;
unsigned int flags = 0u;
double alpha = PIRLS_ALPHAS[k];
double eta_i = eta[i] + alpha * xd[i];
double y_i = y[i];
double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
if (prior_w[i] <= 0.0) flags |= 0x10u;
{body}
atomicAdd(&objective_out[k], dev);
atomicOr(&status_out[k], flags);
}}
"#,
        curvature_define = mode_define,
        prolog = COMMON_DEVICE_PROLOG,
        alphas = ALPHA_LADDER_CUDA_ARRAY,
        kernel_name = family.ladder_kernel_name(),
        shape_param = gamma_shape_arg,
        body = row_body,
    )
}
#[cfg(test)]
mod pirls_row_gpu_tests {
use super::*;
/// Asserts that `got` matches `expected` within a relative tolerance.
///
/// * Finite values: passes when `|got - expected| <= tol * max(|expected|, 1)`,
///   i.e. the tolerance is relative for large magnitudes and absolute below 1.
/// * Non-finite values: passes only when both are NaN or both are the same
///   signed infinity. Comparing finiteness alone is not enough for a parity
///   check: it would accept `+inf` vs `-inf` and `NaN` vs `inf`.
///
/// `label` is included in the panic message to identify the failing quantity.
fn assert_close(label: &str, got: f64, expected: f64, tol: f64) {
    if !(got.is_finite() && expected.is_finite()) {
        // The relative-error formula is meaningless here (inf - inf = NaN),
        // so require the exact same special value instead.
        let same_special = (got.is_nan() && expected.is_nan()) || got == expected;
        assert!(
            same_special,
            "{label}: non-finite mismatch (got={got}, expected={expected})"
        );
        return;
    }
    let diff = (got - expected).abs();
    let denom = expected.abs().max(1.0);
    assert!(
        diff <= tol * denom,
        "{label}: |{got} - {expected}| = {diff} exceeds tol {tol} (rel denom {denom})"
    );
}
/// Sweeps a small eta × y × prior-weight grid through `row_reweight_cpu`
/// (Fisher curvature, last argument 1.0) and asserts per-row invariants:
///
/// * `w_fisher` and `w_solver` are never negative, and `w_solver` is at least
///   `W_SOLVER_FLOOR` whenever the prior and Hessian weights are positive;
/// * for rows whose eta was not clamped, `w_fisher * (z_fisher - eta)` agrees
///   with `grad_eta` to a relative 1e-6;
/// * the deviance is non-negative (or non-finite) for valid responses;
/// * `mu` and `grad_eta` are finite when neither the invalid-response nor the
///   zero-prior-weight flag is set.
fn check_family_matches_cpu_reference(family: PirlsRowFamily) {
    const ETA_GRID: [f64; 7] = [-700.0, -3.0, -0.5, 0.0, 0.5, 3.0, 700.0];
    const PRIOR_WEIGHTS: [f64; 3] = [0.0, 1.0, 2.5];
    // Response values that are valid for the family being probed.
    let responses: Vec<f64> = match family {
        PirlsRowFamily::GaussianIdentity => vec![-1.5, 0.0, 2.0],
        PirlsRowFamily::PoissonLog => vec![0.0, 1.0, 5.0],
        PirlsRowFamily::GammaLog => vec![0.5, 1.0, 2.5],
        _ => vec![0.0, 1.0],
    };
    for eta in ETA_GRID {
        for &y in &responses {
            for wp in PRIOR_WEIGHTS {
                let out = row_reweight_cpu(
                    family,
                    CurvatureMode::Fisher,
                    RowInput {
                        eta,
                        y,
                        prior_weight: wp,
                    },
                    1.0,
                );

                // Weight sign / floor invariants hold for every row.
                assert!(
                    out.w_fisher >= 0.0,
                    "{family:?}: w_fisher must be non-negative (got {})",
                    out.w_fisher
                );
                assert!(
                    out.w_solver >= 0.0,
                    "{family:?}: w_solver must be non-negative (got {})",
                    out.w_solver
                );
                if wp > 0.0 && out.w_hessian > 0.0 {
                    assert!(
                        out.w_solver >= W_SOLVER_FLOOR,
                        "{family:?}: w_solver must be floored away from zero when positive (got {})",
                        out.w_solver
                    );
                }

                // The remaining checks compare against the unclamped eta, so
                // rows whose eta was clamped are skipped.
                if (out.status & status_flags::ETA_CLAMPED) != 0 {
                    continue;
                }

                if out.w_fisher > 0.0 && out.z_fisher.is_finite() {
                    let reconstructed = out.w_fisher * (out.z_fisher - eta);
                    if reconstructed.is_finite() {
                        let scale = reconstructed.abs().max(out.grad_eta.abs()).max(1.0);
                        let rel_gap = (reconstructed - out.grad_eta).abs() / scale;
                        assert!(
                            rel_gap < 1.0e-6,
                            "{family:?} eta={eta} y={y} wp={wp}: grad_eta {} vs w·(z−η) {} differ by rel {}",
                            out.grad_eta,
                            reconstructed,
                            rel_gap
                        );
                    }
                }

                let response_ok = (out.status & status_flags::INVALID_RESPONSE) == 0;
                if response_ok && wp >= 0.0 {
                    assert!(
                        out.deviance >= 0.0 || !out.deviance.is_finite(),
                        "{family:?} eta={eta} y={y} wp={wp}: deviance must be non-negative for valid inputs (got {})",
                        out.deviance
                    );
                }

                let blocking = status_flags::INVALID_RESPONSE | status_flags::ZERO_PRIOR_WEIGHT;
                if (out.status & blocking) == 0 {
                    assert!(
                        out.mu.is_finite(),
                        "{family:?} eta={eta} y={y} wp={wp}: mu must be finite for valid inputs"
                    );
                    assert!(
                        out.grad_eta.is_finite(),
                        "{family:?} eta={eta} y={y} wp={wp}: grad_eta must be finite for valid inputs"
                    );
                }
            }
        }
    }
    assert_close("self", 0.0, 0.0, 0.0);
}
/// Counts how many points of a fixed 5 × 3 × 2 (eta, y, prior weight) grid
/// yield a strictly positive `w_fisher` from `row_reweight_cpu` under Fisher
/// curvature with the last argument set to 1.0.
fn count_active_rows(family: PirlsRowFamily) -> usize {
    let eta_grid = [-700.0, -3.0, 0.0, 3.0, 700.0];
    let y_grid = [0.0, 0.5, 1.0];
    let weight_grid = [1.0, 2.5];
    eta_grid
        .iter()
        .flat_map(|&eta| y_grid.iter().map(move |&y| (eta, y)))
        .flat_map(|(eta, y)| weight_grid.iter().map(move |&wp| (eta, y, wp)))
        .filter(|&(eta, y, wp)| {
            let row = row_reweight_cpu(
                family,
                CurvatureMode::Fisher,
                RowInput {
                    eta,
                    y,
                    prior_weight: wp,
                },
                1.0,
            );
            row.w_fisher > 0.0
        })
        .count()
}
/// Gaussian/identity: the CPU row invariants hold and at least one probe row
/// has a positive Fisher weight.
#[test]
fn gaussian_identity_row_invariants() {
    let family = PirlsRowFamily::GaussianIdentity;
    check_family_matches_cpu_reference(family);
    let active = count_active_rows(family);
    assert!(active > 0);
}
/// Poisson/log: the CPU row invariants hold and at least one probe row has a
/// positive Fisher weight.
#[test]
fn poisson_log_row_invariants() {
    let family = PirlsRowFamily::PoissonLog;
    check_family_matches_cpu_reference(family);
    let active = count_active_rows(family);
    assert!(active > 0);
}
/// Gamma/log: the CPU row invariants hold and at least one probe row has a
/// positive Fisher weight.
#[test]
fn gamma_log_row_invariants() {
    let family = PirlsRowFamily::GammaLog;
    check_family_matches_cpu_reference(family);
    let active = count_active_rows(family);
    assert!(active > 0);
}
/// Bernoulli/logit: the CPU row invariants hold and at least one probe row has
/// a positive Fisher weight.
#[test]
fn bernoulli_logit_row_invariants() {
    let family = PirlsRowFamily::BernoulliLogit;
    check_family_matches_cpu_reference(family);
    let active = count_active_rows(family);
    assert!(active > 0);
}
/// Bernoulli/probit: the CPU row invariants hold and at least one probe row
/// has a positive Fisher weight.
#[test]
fn bernoulli_probit_row_invariants() {
    let family = PirlsRowFamily::BernoulliProbit;
    check_family_matches_cpu_reference(family);
    let active = count_active_rows(family);
    assert!(active > 0);
}
/// Bernoulli/cloglog: the CPU row invariants hold and at least one probe row
/// has a positive Fisher weight.
#[test]
fn bernoulli_cloglog_row_invariants() {
    let family = PirlsRowFamily::BernoulliCLogLog;
    check_family_matches_cpu_reference(family);
    let active = count_active_rows(family);
    assert!(active > 0);
}
/// Pins the Gaussian/identity CPU row against hand-written values for
/// eta = 0.25, y = 1, prior weight = 2: mu = eta, grad_eta = w·(y − mu),
/// w_fisher = w and deviance = w·(y − mu)².
#[test]
fn gaussian_identity_matches_explicit_formulas() {
    let probe = RowInput {
        eta: 0.25,
        y: 1.0,
        prior_weight: 2.0,
    };
    let row = row_reweight_cpu(
        PirlsRowFamily::GaussianIdentity,
        CurvatureMode::Fisher,
        probe,
        1.0,
    );
    assert!(row.mu.is_finite() && row.deviance.is_finite());

    let expected_grad = 2.0 * (1.0 - 0.25);
    let expected_deviance = 2.0 * (1.0 - 0.25_f64).powi(2);
    assert_close("mu", row.mu, 0.25, 0.0);
    assert_close("grad_eta", row.grad_eta, expected_grad, 1e-15);
    assert_close("w_fisher", row.w_fisher, 2.0, 0.0);
    assert_close("deviance", row.deviance, expected_deviance, 1e-15);
}
/// Pins the Poisson/log CPU row against hand-written values for eta = 1.5,
/// y = 4, prior weight = 1: mu = exp(eta), grad_eta = y − mu, w_fisher = mu.
#[test]
fn poisson_log_matches_explicit_formulas() {
    let probe = RowInput {
        eta: 1.5,
        y: 4.0,
        prior_weight: 1.0,
    };
    let row = row_reweight_cpu(
        PirlsRowFamily::PoissonLog,
        CurvatureMode::Fisher,
        probe,
        1.0,
    );
    let expected_mu = (1.5_f64).exp();
    assert!(expected_mu.is_finite() && row.mu.is_finite());

    let expected_grad = 4.0 - expected_mu;
    assert_close("mu", row.mu, expected_mu, 1e-15);
    assert_close("grad_eta", row.grad_eta, expected_grad, 1e-15);
    assert_close("w_fisher", row.w_fisher, expected_mu, 1e-15);
}
/// Pins the Bernoulli/logit CPU row against hand-written values for eta = 0.7,
/// y = 1, prior weight = 3: mu = 1/(1 + exp(−eta)), w_fisher = w·mu·(1 − mu),
/// grad_eta = w·(y − mu).
#[test]
fn bernoulli_logit_matches_explicit_formulas() {
    let eta: f64 = 0.7;
    let mu = 1.0 / (1.0 + (-eta).exp());
    let probe = RowInput {
        eta,
        y: 1.0,
        prior_weight: 3.0,
    };
    let row = row_reweight_cpu(
        PirlsRowFamily::BernoulliLogit,
        CurvatureMode::Fisher,
        probe,
        1.0,
    );
    assert!(mu > 0.0 && mu < 1.0);

    let expected_fisher = 3.0 * mu * (1.0 - mu);
    let expected_grad = 3.0 * (1.0 - mu);
    assert_close("mu", row.mu, mu, 1e-12);
    assert_close("w_fisher", row.w_fisher, expected_fisher, 1e-12);
    assert_close("grad_eta", row.grad_eta, expected_grad, 1e-12);
}
/// A Poisson/log row with eta = 1000 must come back with the `ETA_CLAMPED`
/// status bit set.
#[test]
fn eta_clamp_status_flag_trips() {
    let extreme = RowInput {
        eta: 1000.0,
        y: 0.0,
        prior_weight: 1.0,
    };
    let row = row_reweight_cpu(
        PirlsRowFamily::PoissonLog,
        CurvatureMode::Fisher,
        extreme,
        1.0,
    );
    assert!((row.status & status_flags::ETA_CLAMPED) != 0);
}
/// Checks that `PirlsRowBackend::compiled()` reports true exactly on Linux
/// builds and — when a CUDA runtime is present — that every family compiles
/// under Fisher curvature and that a second `module_for` call returns the
/// same `Arc` (pointer-equal) as the first.
#[test]
fn backend_compiles_one_module_per_family_when_device_present() {
    assert_eq!(PirlsRowBackend::compiled(), cfg!(target_os = "linux"));
    if super::super::runtime::GpuRuntime::global().is_none() {
        eprintln!("[pirls_row_gpu test] no CUDA runtime — skipping device compile test");
        return;
    }
    #[cfg(target_os = "linux")]
    {
        let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
        for family in PirlsRowFamily::ALL {
            // First request compiles (or loads) the module ...
            let first = backend
                .module_for(family, CurvatureMode::Fisher)
                .unwrap_or_else(|err| panic!("compile {family:?}: {err}"));
            // ... the second request must be served from the cache.
            let second = backend
                .module_for(family, CurvatureMode::Fisher)
                .unwrap_or_else(|err| panic!("re-fetch {family:?}: {err}"));
            assert!(
                Arc::ptr_eq(&first, &second),
                "{family:?}: module cache must return same handle on second call"
            );
        }
    }
}
/// Runs four Bernoulli-logit rows (Fisher curvature) through the built-in row
/// kernel and through the JIT launcher with a `JitFamilySpec::glm` spec for
/// the same family, then requires all eight floating-point output arrays to
/// match bit-for-bit. Returns early (with a note on stderr) when no CUDA
/// runtime is available; the device part is compiled on Linux only.
#[test]
fn jit_glm_kernel_matches_builtin_byte_identical() {
if super::super::runtime::GpuRuntime::global().is_none() {
eprintln!("[stage_6_jit] no CUDA runtime — skipping");
return;
}
#[cfg(target_os = "linux")]
{
// Host-side inputs: one eta / y / prior weight per row.
let etas = [-2.0_f64, -0.5, 0.3, 1.5];
let ys = [0.0_f64, 1.0, 0.0, 1.0];
let priors = [1.0_f64, 1.2, 0.8, 1.5];
let n = etas.len();
let family = PirlsRowFamily::BernoulliLogit;
let curvature = CurvatureMode::Fisher;
let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
let runtime =
super::super::runtime::GpuRuntime::global().expect("GPU runtime available");
let ctx = super::super::runtime::cuda_context_for(runtime.selected_device().ordinal)
.expect("ctx");
let stream = ctx.default_stream();
// Allocate device buffers and upload the inputs on the same stream.
let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("eta");
let mut y_dev = stream.alloc_zeros::<f64>(n).expect("y");
let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("prior");
stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
stream
.memcpy_htod(&priors, &mut prior_dev)
.expect("up prior");
// Reference run: built-in kernel for the family.
let mut out_builtin = RowOutputDevBuffers::allocate(&stream, n).expect("alloc builtin");
launch_row_reweight_on_stream(
backend,
family,
curvature,
1.0,
&stream,
n,
&eta_dev,
&y_dev,
&prior_dev,
&mut out_builtin,
)
.expect("builtin launch");
// Candidate run: JIT launcher fed with a GLM spec for the same family.
let spec = JitFamilySpec::glm(0x424c_4c47u64, family, curvature);
let mut out_jit = RowOutputDevBuffers::allocate(&stream, n).expect("alloc jit");
launch_row_reweight_jit_on_stream(
backend,
&spec,
curvature,
&stream,
n,
&eta_dev,
&y_dev,
&prior_dev,
&mut out_jit,
)
.expect("jit launch");
// Wait for both launches before reading anything back.
stream.synchronize().expect("sync");
// Download each output pair and compare the raw bit patterns.
for (label, b_dev, j_dev) in [
("mu", &out_builtin.mu, &out_jit.mu),
("grad_eta", &out_builtin.grad_eta, &out_jit.grad_eta),
("w_fisher", &out_builtin.w_fisher, &out_jit.w_fisher),
("w_hessian", &out_builtin.w_hessian, &out_jit.w_hessian),
("w_solver", &out_builtin.w_solver, &out_jit.w_solver),
("z_fisher", &out_builtin.z_fisher, &out_jit.z_fisher),
("z_hessian", &out_builtin.z_hessian, &out_jit.z_hessian),
("deviance", &out_builtin.deviance, &out_jit.deviance),
] {
let b = stream.clone_dtoh(b_dev).expect("dl builtin");
let j = stream.clone_dtoh(j_dev).expect("dl jit");
for i in 0..n {
assert_eq!(
b[i].to_bits(),
j[i].to_bits(),
"{label}[{i}]: builtin {} ≠ jit {}",
b[i],
j[i],
);
}
}
}
}
/// Feeds a hand-written Gaussian/identity body to the JIT launcher through
/// `JitFamilySpec::raw` and requires its eight output arrays to be
/// bit-identical to (a) the built-in Gaussian kernel and (b) the CPU reference
/// `row_reweight_cpu`, over 256 rows that include one zero-prior-weight row.
/// Returns early (with a note on stderr) when no CUDA runtime is available;
/// the device part is compiled on Linux only.
#[test]
fn jit_raw_body_kernel_matches_builtin_gaussian_byte_identical() {
if super::super::runtime::GpuRuntime::global().is_none() {
eprintln!("[stage_6_jit_raw] no CUDA runtime — skipping");
return;
}
#[cfg(target_os = "linux")]
{
let n: usize = 256;
let mut etas = vec![0.0_f64; n];
let mut ys = vec![0.0_f64; n];
let mut priors = vec![0.0_f64; n];
// t runs from 0 to 1 across the rows; eta spans [-3, 3], y spans
// [-2.5, 2.5], and row 7 gets a zero prior weight.
for i in 0..n {
let t = (i as f64) / (n as f64 - 1.0); etas[i] = -3.0 + 6.0 * t;
ys[i] = 5.0 * (t - 0.5);
priors[i] = if i == 7 {
0.0 } else {
0.25 + 1.75 * t
};
}
let family = PirlsRowFamily::GaussianIdentity;
let curvature = CurvatureMode::Fisher;
let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
let runtime =
super::super::runtime::GpuRuntime::global().expect("GPU runtime available");
let ctx = super::super::runtime::cuda_context_for(runtime.selected_device().ordinal)
.expect("ctx");
let stream = ctx.default_stream();
// Allocate device buffers and upload the inputs on the same stream.
let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("eta");
let mut y_dev = stream.alloc_zeros::<f64>(n).expect("y");
let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("prior");
stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
stream
.memcpy_htod(&priors, &mut prior_dev)
.expect("up prior");
// Reference run: built-in Gaussian/identity kernel.
let mut out_builtin = RowOutputDevBuffers::allocate(&stream, n).expect("alloc builtin");
launch_row_reweight_on_stream(
backend,
family,
curvature,
1.0,
&stream,
n,
&eta_dev,
&y_dev,
&prior_dev,
&mut out_builtin,
)
.expect("builtin launch");
// Hand-written body: same statements as the built-in Gaussian body, with
// extra CUDA comments interleaved.
let raw_body = r#" // level-b raw body: gaussian identity (hand-written)
// identity link: mu = eta
double mu = eta_i;
// ordinary residual on the response scale
double resid = y_i - mu;
// canonical score contribution
double grad_eta = wp * resid;
// fisher info per row: weight itself (V(mu)=1, dmu/deta=1)
double w_fisher = wp;
// observed == fisher for canonical identity link
double w_hessian = wp;
// solver weight clamps tiny positives to avoid singularity
double w_solver = (wp > 0.0) ? fmax(wp, 1e-12) : 0.0;
// working response equals raw response on identity link
double z_f = y_i;
double z_h = y_i;
// squared-error contribution to deviance
double dev = wp * resid * resid;
"#;
// Candidate run: JIT launcher fed with the raw body.
let spec = JitFamilySpec::raw(0x5241_575f_4741_5553u64, raw_body);
let mut out_jit = RowOutputDevBuffers::allocate(&stream, n).expect("alloc jit");
launch_row_reweight_jit_on_stream(
backend,
&spec,
curvature,
&stream,
n,
&eta_dev,
&y_dev,
&prior_dev,
&mut out_jit,
)
.expect("jit raw launch");
// Wait for both launches before reading anything back.
stream.synchronize().expect("sync");
// Part 1: JIT output vs built-in kernel output, bit for bit.
for (label, b_dev, j_dev) in [
("mu", &out_builtin.mu, &out_jit.mu),
("grad_eta", &out_builtin.grad_eta, &out_jit.grad_eta),
("w_fisher", &out_builtin.w_fisher, &out_jit.w_fisher),
("w_hessian", &out_builtin.w_hessian, &out_jit.w_hessian),
("w_solver", &out_builtin.w_solver, &out_jit.w_solver),
("z_fisher", &out_builtin.z_fisher, &out_jit.z_fisher),
("z_hessian", &out_builtin.z_hessian, &out_jit.z_hessian),
("deviance", &out_builtin.deviance, &out_jit.deviance),
] {
let b = stream.clone_dtoh(b_dev).expect("dl builtin");
let j = stream.clone_dtoh(j_dev).expect("dl jit raw");
for i in 0..n {
assert_eq!(
b[i].to_bits(),
j[i].to_bits(),
"{label}[{i}]: builtin {} ≠ jit-raw {}",
b[i],
j[i],
);
}
}
// Part 2: JIT output vs the CPU reference, bit for bit.
let mu_j = stream.clone_dtoh(&out_jit.mu).expect("dl jit mu");
let g_j = stream.clone_dtoh(&out_jit.grad_eta).expect("dl jit g");
let wf_j = stream.clone_dtoh(&out_jit.w_fisher).expect("dl jit wf");
let wh_j = stream.clone_dtoh(&out_jit.w_hessian).expect("dl jit wh");
let ws_j = stream.clone_dtoh(&out_jit.w_solver).expect("dl jit ws");
let zf_j = stream.clone_dtoh(&out_jit.z_fisher).expect("dl jit zf");
let zh_j = stream.clone_dtoh(&out_jit.z_hessian).expect("dl jit zh");
let d_j = stream.clone_dtoh(&out_jit.deviance).expect("dl jit d");
for i in 0..n {
let cpu = row_reweight_cpu(
PirlsRowFamily::GaussianIdentity,
curvature,
RowInput {
eta: etas[i],
y: ys[i],
prior_weight: priors[i],
},
1.0,
);
for (label, cpu_v, jit_v) in [
("mu", cpu.mu, mu_j[i]),
("grad_eta", cpu.grad_eta, g_j[i]),
("w_fisher", cpu.w_fisher, wf_j[i]),
("w_hessian", cpu.w_hessian, wh_j[i]),
("w_solver", cpu.w_solver, ws_j[i]),
("z_fisher", cpu.z_fisher, zf_j[i]),
("z_hessian", cpu.z_hessian, zh_j[i]),
("deviance", cpu.deviance, d_j[i]),
] {
assert_eq!(
cpu_v.to_bits(),
jit_v.to_bits(),
"{label}[{i}]: cpu {} ≠ jit-raw {}",
cpu_v,
jit_v,
);
}
}
}
}
/// CPU-only check of how `CurvatureMode::Observed` changes `w_hessian` at a
/// single probe row (eta = 0.4, y = 1, prior weight = 1.5):
///
/// * Gaussian/identity, Poisson/log, Bernoulli/logit: identical to Fisher;
/// * Gamma/log: equals `w_fisher * y / mu` and differs from Fisher;
/// * Bernoulli probit / cloglog: differs from Fisher unless `y` equals `mu`.
#[test]
fn observed_curvature_matches_expected_per_family() {
    let probe_eta = 0.4_f64;
    let probe_y = 1.0_f64;
    let wp = 1.5_f64;
    let input = RowInput {
        eta: probe_eta,
        y: probe_y,
        prior_weight: wp,
    };

    // Families for which the observed weight must coincide with Fisher.
    let canonical_families = [
        PirlsRowFamily::GaussianIdentity,
        PirlsRowFamily::PoissonLog,
        PirlsRowFamily::BernoulliLogit,
    ];
    for canonical in canonical_families {
        let f = row_reweight_cpu(canonical, CurvatureMode::Fisher, input, 1.0);
        let o = row_reweight_cpu(canonical, CurvatureMode::Observed, input, 1.0);
        assert_eq!(
            f.w_hessian, o.w_hessian,
            "{canonical:?}: observed must equal Fisher for canonical link"
        );
    }

    // Gamma/log has a closed form for the observed weight.
    for shape in [1.0_f64, 2.5] {
        let gf = row_reweight_cpu(
            PirlsRowFamily::GammaLog,
            CurvatureMode::Fisher,
            input,
            shape,
        );
        let go = row_reweight_cpu(
            PirlsRowFamily::GammaLog,
            CurvatureMode::Observed,
            input,
            shape,
        );
        let expected = gf.w_fisher * (probe_y / gf.mu);
        assert!(
            (go.w_hessian - expected).abs() <= 1e-12,
            "Gamma-log observed mismatch (shape={shape}): got={} expected={} (mu={})",
            go.w_hessian,
            expected,
            gf.mu
        );
        assert_ne!(
            gf.w_hessian, go.w_hessian,
            "Gamma-log: observed must differ from Fisher when y ≠ μ (shape={shape})"
        );
    }

    // Remaining Bernoulli links: only require that the two modes disagree.
    let noncanonical_families = [
        PirlsRowFamily::BernoulliProbit,
        PirlsRowFamily::BernoulliCLogLog,
    ];
    for noncanon in noncanonical_families {
        let f = row_reweight_cpu(noncanon, CurvatureMode::Fisher, input, 1.0);
        let o = row_reweight_cpu(noncanon, CurvatureMode::Observed, input, 1.0);
        let weights_differ = (f.w_hessian - o.w_hessian).abs() > 0.0;
        let y_equals_mu = (probe_y - f.mu).abs() < 1e-15;
        assert!(
            weights_differ || y_equals_mu,
            "{noncanon:?}: observed should differ from Fisher when y ≠ μ"
        );
    }
}
/// CPU-only check of how the Gamma/log row depends on `shape` at eta = 0.5,
/// y = 2, prior weight = 1: `w_fisher` scales linearly with `shape`, `mu` is
/// bit-identical for every shape, and the observed `w_hessian` equals
/// `w_fisher * y / mu`.
#[test]
fn gamma_log_shape_scaling() {
    let input = RowInput {
        eta: 0.5,
        y: 2.0,
        prior_weight: 1.0,
    };
    // Shorthand for a Gamma/log evaluation of the fixed probe row.
    let gamma_row = |mode: CurvatureMode, shape: f64| {
        row_reweight_cpu(PirlsRowFamily::GammaLog, mode, input, shape)
    };
    let base = gamma_row(CurvatureMode::Fisher, 1.0);
    for shape in [0.5_f64, 1.5, 3.0, 10.0] {
        let r = gamma_row(CurvatureMode::Fisher, shape);
        assert!(
            (r.w_fisher - shape * base.w_fisher).abs() <= 1e-14,
            "w_fisher should scale with shape: got {} expected {} (shape={shape})",
            r.w_fisher,
            shape * base.w_fisher,
        );
        assert_eq!(
            r.mu.to_bits(),
            base.mu.to_bits(),
            "mu must not depend on shape"
        );

        let ro = gamma_row(CurvatureMode::Observed, shape);
        let expected_obs = r.w_fisher * (input.y / r.mu);
        assert!(
            (ro.w_hessian - expected_obs).abs() <= 1e-13,
            "observed w_hessian mismatch (shape={shape}): got={} expected={}",
            ro.w_hessian,
            expected_obs,
        );
    }
}
/// For every family, runs eight rows through the built-in device kernel
/// (Fisher curvature, fourth launcher argument 1.0) and compares all eight
/// floating-point outputs with `row_reweight_cpu` via `assert_close` at
/// tolerance 1e-12. Returns early (with a note on stderr) when no CUDA runtime
/// is available; the device part is compiled on Linux only.
#[test]
fn launch_row_reweight_matches_cpu_reference_on_device() {
if super::super::runtime::GpuRuntime::global().is_none() {
eprintln!("[pirls_row_gpu test] no CUDA runtime — skipping launcher parity test");
return;
}
#[cfg(target_os = "linux")]
{
let etas = [-3.0_f64, -0.5, 0.0, 0.5, 3.0, 10.0, -10.0, 1.5];
let n = etas.len();
let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
let runtime = super::super::runtime::GpuRuntime::global()
.expect("GPU runtime available when probe succeeded");
let ctx = super::super::runtime::cuda_context_for(runtime.selected_device().ordinal)
.expect("ctx for selected device");
let stream = ctx.default_stream();
for &family in PirlsRowFamily::ALL.iter() {
// Responses chosen per family: positive values for Gamma/Poisson, signed
// values for Gaussian, alternating 0/1 for the Bernoulli links.
let ys: Vec<f64> = match family {
PirlsRowFamily::GammaLog | PirlsRowFamily::PoissonLog => {
(0..n).map(|i| 1.0 + 0.5 * (i as f64)).collect()
}
PirlsRowFamily::GaussianIdentity => {
(0..n).map(|i| -1.0 + 0.5 * (i as f64)).collect()
}
_ => (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
};
let priors: Vec<f64> = (0..n).map(|i| 1.0 + 0.25 * (i as f64)).collect();
// CPU reference for every row.
let mut cpu_out = Vec::with_capacity(n);
for i in 0..n {
cpu_out.push(row_reweight_cpu(
family,
CurvatureMode::Fisher,
RowInput {
eta: etas[i],
y: ys[i],
prior_weight: priors[i],
},
1.0,
));
}
// Allocate device buffers and upload the inputs on the same stream.
let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("alloc eta_dev");
let mut y_dev = stream.alloc_zeros::<f64>(n).expect("alloc y_dev");
let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("alloc prior_dev");
stream
.memcpy_htod(etas.as_slice(), &mut eta_dev)
.expect("upload eta");
stream
.memcpy_htod(ys.as_slice(), &mut y_dev)
.expect("upload y");
stream
.memcpy_htod(priors.as_slice(), &mut prior_dev)
.expect("upload prior");
let mut out = RowOutputDevBuffers::allocate(&stream, n).expect("alloc row buffers");
launch_row_reweight_on_stream(
backend,
family,
CurvatureMode::Fisher,
1.0,
&stream,
n,
&eta_dev,
&y_dev,
&prior_dev,
&mut out,
)
.unwrap_or_else(|err| panic!("launch {family:?}: {err}"));
// Wait for the launch before downloading the results.
stream.synchronize().expect("stream sync");
let mu = stream.clone_dtoh(&out.mu).expect("dl mu");
let g = stream.clone_dtoh(&out.grad_eta).expect("dl grad_eta");
let wf = stream.clone_dtoh(&out.w_fisher).expect("dl w_fisher");
let wh = stream.clone_dtoh(&out.w_hessian).expect("dl w_hessian");
let ws_v = stream.clone_dtoh(&out.w_solver).expect("dl w_solver");
let zf = stream.clone_dtoh(&out.z_fisher).expect("dl z_fisher");
let zh = stream.clone_dtoh(&out.z_hessian).expect("dl z_hessian");
let dev = stream.clone_dtoh(&out.deviance).expect("dl deviance");
let tol = 1e-12;
// Row-by-row comparison of each device output with the CPU value.
for i in 0..n {
let r = cpu_out[i];
assert_close(&format!("{family:?}/row{i}/mu"), mu[i], r.mu, tol);
assert_close(
&format!("{family:?}/row{i}/grad_eta"),
g[i],
r.grad_eta,
tol,
);
assert_close(
&format!("{family:?}/row{i}/w_fisher"),
wf[i],
r.w_fisher,
tol,
);
assert_close(
&format!("{family:?}/row{i}/w_hessian"),
wh[i],
r.w_hessian,
tol,
);
assert_close(
&format!("{family:?}/row{i}/w_solver"),
ws_v[i],
r.w_solver,
tol,
);
assert_close(
&format!("{family:?}/row{i}/z_fisher"),
zf[i],
r.z_fisher,
tol,
);
assert_close(
&format!("{family:?}/row{i}/z_hessian"),
zh[i],
r.z_hessian,
tol,
);
assert_close(
&format!("{family:?}/row{i}/deviance"),
dev[i],
r.deviance,
tol,
);
}
}
}
}
/// For every family, launches the built-in row kernel with
/// `CurvatureMode::Observed` on 256 rows (eta spanning [-6, 6]) and checks the
/// device `w_hessian`:
///
/// * when `family.is_canonical()` is true it must be bit-identical to the
///   device `w_fisher`;
/// * otherwise it must match the CPU observed `w_hessian` with absolute error
///   at most 1e-12 or relative error at most 1e-11.
///
/// Returns early (with a note on stderr) when no CUDA runtime is available;
/// the device part is compiled on Linux only.
#[test]
fn gpu_observed_parity() {
if super::super::runtime::GpuRuntime::global().is_none() {
eprintln!("[gpu_observed_parity] no CUDA runtime — skipping");
return;
}
#[cfg(target_os = "linux")]
{
const N: usize = 256;
let etas: Vec<f64> = (0..N)
.map(|i| -6.0 + 12.0 * (i as f64) / ((N - 1) as f64))
.collect();
let priors: Vec<f64> = (0..N)
.map(|i| 0.5 + 1.5 * ((i as f64) / (N as f64)))
.collect();
let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
let runtime = super::super::runtime::GpuRuntime::global()
.expect("GPU runtime available when probe succeeded");
let ctx = super::super::runtime::cuda_context_for(runtime.selected_device().ordinal)
.expect("ctx for selected device");
let stream = ctx.default_stream();
for &family in PirlsRowFamily::ALL.iter() {
// Responses chosen per family: positive reals for Gamma, small counts for
// Poisson, signed reals for Gaussian, alternating 0/1 for Bernoulli links.
let ys: Vec<f64> = match family {
PirlsRowFamily::GammaLog => (0..N).map(|i| 0.25 + 0.05 * (i as f64)).collect(),
PirlsRowFamily::PoissonLog => (0..N).map(|i| (i % 6) as f64).collect(),
PirlsRowFamily::GaussianIdentity => (0..N)
.map(|i| -2.0 + 4.0 * (i as f64) / ((N - 1) as f64))
.collect(),
_ => (0..N).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
};
// Allocate device buffers and upload the inputs on the same stream.
let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("alloc eta_dev");
let mut y_dev = stream.alloc_zeros::<f64>(N).expect("alloc y_dev");
let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("alloc prior_dev");
stream
.memcpy_htod(etas.as_slice(), &mut eta_dev)
.expect("upload eta");
stream
.memcpy_htod(ys.as_slice(), &mut y_dev)
.expect("upload y");
stream
.memcpy_htod(priors.as_slice(), &mut prior_dev)
.expect("upload prior");
let mut out_obs = RowOutputDevBuffers::allocate(&stream, N).expect("alloc out_obs");
launch_row_reweight_on_stream(
backend,
family,
CurvatureMode::Observed,
1.0,
&stream,
N,
&eta_dev,
&y_dev,
&prior_dev,
&mut out_obs,
)
.unwrap_or_else(|err| panic!("observed launch {family:?}: {err}"));
// Wait for the launch before downloading the results.
stream.synchronize().expect("stream sync (observed)");
let wh_obs = stream
.clone_dtoh(&out_obs.w_hessian)
.expect("dl w_hessian (observed)");
let wf_obs = stream
.clone_dtoh(&out_obs.w_fisher)
.expect("dl w_fisher (observed)");
if family.is_canonical() {
// Canonical link: observed and Fisher weights must share a bit pattern.
for i in 0..N {
assert_eq!(
wh_obs[i].to_bits(),
wf_obs[i].to_bits(),
"{family:?} row {i}: observed w_hessian {} must bit-equal w_fisher {} on canonical link",
wh_obs[i],
wf_obs[i],
);
}
} else {
// Non-canonical link: compare against the CPU observed weight.
for i in 0..N {
let cpu = row_reweight_cpu(
family,
CurvatureMode::Observed,
RowInput {
eta: etas[i],
y: ys[i],
prior_weight: priors[i],
},
1.0,
);
let got = wh_obs[i];
let exp = cpu.w_hessian;
let abs_err = (got - exp).abs();
// Fall back to the absolute error when the expected value is exactly 0.
let rel_err = if exp.abs() > 0.0 {
abs_err / exp.abs()
} else {
abs_err
};
assert!(
abs_err <= 1.0e-12 || rel_err <= 1.0e-11,
"{family:?} row {i} (eta={}, y={}, wp={}): \
device w_hessian={} vs CPU observed={} (abs={}, rel={})",
etas[i],
ys[i],
priors[i],
got,
exp,
abs_err,
rel_err,
);
}
}
}
}
}
/// GPU-vs-CPU parity check for `CurvatureMode::Observed` on 1000 rows per family.
///
/// For every family in `PirlsRowFamily::ALL` the test uploads a shared `eta`
/// grid spanning [-8, 8], a shared prior-weight ramp starting at 0.25, and a
/// family-specific response vector, launches `launch_row_reweight_on_stream`,
/// and compares the downloaded `w_hessian` and `grad_eta` row by row against
/// `row_reweight_cpu`. A row passes when either its absolute or its relative
/// error is at most `TOL` (1e-9).
///
/// The test logs and returns early when no `GpuRuntime` is available; the
/// body itself is only compiled on Linux.
#[test]
fn gpu_observed_parity_end_to_end_n1000() {
    if super::super::runtime::GpuRuntime::global().is_none() {
        eprintln!("[gpu_observed_parity_end_to_end_n1000] no CUDA runtime — skipping");
        return;
    }
    #[cfg(target_os = "linux")]
    {
        const N: usize = 1000;
        const TOL: f64 = 1.0e-9;

        // Relative error against `reference`, falling back to the absolute
        // error when the reference is exactly zero (no division by zero).
        let rel_or_abs = |abs_err: f64, reference: f64| -> f64 {
            let scale = reference.abs();
            if scale > 0.0 {
                abs_err / scale
            } else {
                abs_err
            }
        };

        // Inputs shared by every family.
        let last = (N - 1) as f64;
        let etas: Vec<f64> = (0..N).map(|i| -8.0 + 16.0 * (i as f64) / last).collect();
        let priors: Vec<f64> = (0..N)
            .map(|i| 0.25 + 1.75 * ((i as f64) / (N as f64)))
            .collect();

        let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
        let runtime = super::super::runtime::GpuRuntime::global()
            .expect("GPU runtime available when probe succeeded");
        let ctx = super::super::runtime::cuda_context_for(runtime.selected_device().ordinal)
            .expect("ctx for selected device");
        let stream = ctx.default_stream();

        for family in PirlsRowFamily::ALL {
            // Response value for row `i`, chosen per family.
            let response_at = |i: usize| -> f64 {
                match family {
                    PirlsRowFamily::GammaLog => 0.10 + 0.05 * ((i % 97) as f64),
                    PirlsRowFamily::PoissonLog => (i % 11) as f64,
                    PirlsRowFamily::GaussianIdentity => -3.0 + 6.0 * (i as f64) / last,
                    PirlsRowFamily::BernoulliLogit
                    | PirlsRowFamily::BernoulliProbit
                    | PirlsRowFamily::BernoulliCLogLog => {
                        if i % 2 == 0 {
                            0.0
                        } else {
                            1.0
                        }
                    }
                }
            };
            let ys: Vec<f64> = (0..N).map(response_at).collect();

            // Stage the three input columns on the device.
            let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("alloc eta_dev");
            let mut y_dev = stream.alloc_zeros::<f64>(N).expect("alloc y_dev");
            let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("alloc prior_dev");
            stream
                .memcpy_htod(etas.as_slice(), &mut eta_dev)
                .expect("upload eta");
            stream
                .memcpy_htod(ys.as_slice(), &mut y_dev)
                .expect("upload y");
            stream
                .memcpy_htod(priors.as_slice(), &mut prior_dev)
                .expect("upload prior");

            // Run the observed-curvature kernel; the scalar `1.0` matches the
            // value handed to `row_reweight_cpu` below.
            let mut out_obs = RowOutputDevBuffers::allocate(&stream, N).expect("alloc out_obs");
            launch_row_reweight_on_stream(
                backend,
                family,
                CurvatureMode::Observed,
                1.0,
                &stream,
                N,
                &eta_dev,
                &y_dev,
                &prior_dev,
                &mut out_obs,
            )
            .unwrap_or_else(|err| panic!("observed launch {family:?}: {err}"));
            stream.synchronize().expect("stream sync (observed)");

            let wh_obs = stream
                .clone_dtoh(&out_obs.w_hessian)
                .expect("dl w_hessian (observed)");
            let ge_obs = stream
                .clone_dtoh(&out_obs.grad_eta)
                .expect("dl grad_eta (observed)");

            // Row-by-row comparison against the CPU reference.
            let rows = etas.iter().zip(ys.iter()).zip(priors.iter()).enumerate();
            for (i, ((&eta, &y), &prior_weight)) in rows {
                let cpu = row_reweight_cpu(
                    family,
                    CurvatureMode::Observed,
                    RowInput {
                        eta,
                        y,
                        prior_weight,
                    },
                    1.0,
                );

                let (h_got, h_exp) = (wh_obs[i], cpu.w_hessian);
                let h_abs = (h_got - h_exp).abs();
                let h_rel = rel_or_abs(h_abs, h_exp);
                assert!(
                    h_abs <= TOL || h_rel <= TOL,
                    "{family:?} row {i} (eta={}, y={}, wp={}): \
                     observed w_hessian GPU={} vs CPU={} (abs={}, rel={})",
                    eta,
                    y,
                    prior_weight,
                    h_got,
                    h_exp,
                    h_abs,
                    h_rel,
                );

                let (g_got, g_exp) = (ge_obs[i], cpu.grad_eta);
                let g_abs = (g_got - g_exp).abs();
                let g_rel = rel_or_abs(g_abs, g_exp);
                assert!(
                    g_abs <= TOL || g_rel <= TOL,
                    "{family:?} row {i} (eta={}, y={}, wp={}): \
                     observed grad_eta GPU={} vs CPU={} (abs={}, rel={})",
                    eta,
                    y,
                    prior_weight,
                    g_got,
                    g_exp,
                    g_abs,
                    g_rel,
                );
            }
        }
    }
}
/// End-to-end check of the JIT "raw body" path for all six row families.
///
/// For each family the test builds a `JitFamilySpec::raw` from a hand-written
/// CUDA snippet, launches it with `launch_row_reweight_jit_on_stream` in
/// `CurvatureMode::Fisher` over 1000 rows, downloads all eight output columns
/// and compares them row by row against `row_reweight_cpu`. A value passes
/// when its absolute or relative error is at most `TOL` (1e-10).
///
/// The test logs and returns early when no `GpuRuntime` is available; the
/// body itself is only compiled on Linux.
#[test]
fn gpu_jit_level_b_raw_body_end_to_end_all_families_n1000() {
if super::super::runtime::GpuRuntime::global().is_none() {
eprintln!(
"[gpu_jit_level_b_raw_body_end_to_end_all_families_n1000] no CUDA runtime — skipping"
);
return;
}
#[cfg(target_os = "linux")]
{
const N: usize = 1000;
const TOL: f64 = 1.0e-10;
let curvature = CurvatureMode::Fisher;
let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
let runtime =
super::super::runtime::GpuRuntime::global().expect("GPU runtime available");
let ctx = super::super::runtime::cuda_context_for(runtime.selected_device().ordinal)
.expect("ctx");
let stream = ctx.default_stream();
// Inputs shared by every family: eta evenly spaced on [-6, 6] and a
// prior-weight ramp over [0.25, 2.0).
let etas: Vec<f64> = (0..N)
.map(|i| -6.0 + 12.0 * (i as f64) / ((N - 1) as f64))
.collect();
let priors: Vec<f64> = (0..N)
.map(|i| 0.25 + 1.75 * ((i as f64) / (N as f64)))
.collect();
// The six `raw_*` strings below are CUDA source text passed verbatim to
// `JitFamilySpec::raw`. They are runtime data: do not reformat them.
// Each one reads `eta_i`, `y_i`, `wp` and `flags`, and defines `mu`,
// `grad_eta`, `w_fisher`, `w_hessian`, `w_solver`, `z_f`, `z_h` and `dev`.
// NOTE(review): those identifiers, and the helpers `clamp_eta`,
// `std_norm_cdf`, `std_norm_pdf`, `bernoulli_deviance` and `bernoulli_z`,
// are not defined in this test — presumably the JIT kernel wrapper
// supplies them; confirm against the code generator.
let raw_gaussian = r#" // raw-body gaussian identity (independent re-derivation)
double resp = y_i;
double pred = eta_i;
double mu = pred;
double w_p = wp;
double e_resid = resp - pred;
double grad_eta = w_p * e_resid;
double w_fisher = w_p;
double w_hessian = w_p;
double w_solver = (w_p > 0.0) ? fmax(w_p, 1e-12) : 0.0;
double z_f = resp;
double z_h = resp;
double dev = w_p * e_resid * e_resid;
"#;
let raw_poisson = r#" // raw-body poisson log (independent re-derivation)
double eta_c = clamp_eta(eta_i, &flags);
double mu_pre = exp(eta_c);
if (mu_pre < 1e-10) flags |= 0x2u;
double mu = (mu_pre > 1e-10) ? mu_pre : 1e-10;
double wrate = wp * mu;
double w_fisher = (wrate > 0.0) ? fmax(wrate, 1e-12) : 0.0;
double w_hessian = w_fisher;
double w_solver = w_fisher;
double pres = y_i - mu;
double grad_eta = wp * pres;
double z_lin = eta_c + pres / mu;
double z_f = z_lin;
double z_h = z_lin;
double dterm;
if (y_i > 0.0) {
dterm = y_i * log(y_i / mu) - pres;
} else {
dterm = -pres;
}
double dev = 2.0 * wp * dterm;
if (!(isfinite(y_i) && y_i >= 0.0)) flags |= 0x8u;
"#;
// The gamma responses built in `cases` below are all >= 0.10, so with
// these inputs only the `y_i > 0.0` deviance branch of this body runs.
let raw_gamma = r#" // raw-body gamma log (independent re-derivation; unit shape)
double k_shape = 1.0;
double eta_c = clamp_eta(eta_i, &flags);
double mu_pre = exp(eta_c);
if (mu_pre < 1e-10) flags |= 0x2u;
double mu = (mu_pre > 1e-10) ? mu_pre : 1e-10;
double w_fisher = wp * k_shape;
double w_hessian = w_fisher;
double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
double pres = y_i - mu;
double grad_eta = wp * pres / mu;
double z_lin = eta_c + pres / mu;
double z_f = z_lin;
double z_h = z_lin;
double dev;
if (y_i > 0.0) {
dev = 2.0 * wp * (-log(y_i / mu) + pres / mu);
} else {
dev = 1.0 / 0.0;
}
if (!(isfinite(y_i) && y_i > 0.0)) flags |= 0x8u;
"#;
let raw_logit = r#" // raw-body bernoulli logit (independent re-derivation)
double eta_c = clamp_eta(eta_i, &flags);
double te = tanh(0.5 * eta_c);
double mu_pre = 0.5 * (1.0 + te);
if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
double dmu_deta = mu * (1.0 - mu);
double w_fisher = wp * dmu_deta;
double w_hessian = w_fisher;
double w_solver = (w_fisher > 0.0) ? fmax(w_fisher, 1e-12) : 0.0;
double bres = y_i - mu;
double grad_eta = wp * bres;
double dev = bernoulli_deviance(y_i, mu, wp);
double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
double z_f = z_lin;
double z_h = z_lin;
if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
"#;
let raw_probit = r#" // raw-body bernoulli probit (independent re-derivation; Fisher mode)
double eta_c = clamp_eta(eta_i, &flags);
double mu_pre = std_norm_cdf(eta_c);
if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
double phi = std_norm_pdf(eta_c);
double dmu_deta = phi;
double vmu = mu * (1.0 - mu);
double w_pp = (vmu > 0.0) ? (phi * phi) / vmu : 0.0;
double w_fisher = wp * w_pp;
double w_hessian = w_fisher;
double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
double bres = y_i - mu;
double grad_eta = (vmu > 0.0) ? wp * bres * phi / vmu : 0.0;
double dev = bernoulli_deviance(y_i, mu, wp);
double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
double z_f = z_lin;
double z_h = z_lin;
if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
"#;
let raw_cloglog = r#" // raw-body bernoulli cloglog (independent re-derivation; Fisher mode)
double eta_c = clamp_eta(eta_i, &flags);
double a = exp(eta_c);
double mu_pre = 1.0 - exp(-a);
if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
double dmu_deta = a * (1.0 - mu_pre);
double vmu = mu * (1.0 - mu);
double w_pp = (vmu > 0.0) ? (dmu_deta * dmu_deta) / vmu : 0.0;
double w_fisher = wp * w_pp;
double w_hessian = w_fisher;
double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
double bres = y_i - mu;
double grad_eta = (vmu > 0.0) ? wp * bres * dmu_deta / vmu : 0.0;
double dev = bernoulli_deviance(y_i, mu, wp);
double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
double z_f = z_lin;
double z_h = z_lin;
if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
"#;
// One entry per family: (family, CUDA raw body, spec id, response builder).
// The spec ids are distinct per entry; read as big-endian ASCII their bytes
// spell "RB01GAUS", "RB02POIS", "RB03GMAL", "RB04LGIT", "RB05PRBT" and
// "RB06CLOG". NOTE(review): presumably the id keys a compiled-kernel
// cache — confirm against `JitFamilySpec`.
// The non-capturing closures coerce to the `fn(usize) -> Vec<f64>` pointer
// type named in the array annotation.
let cases: [(PirlsRowFamily, &str, u64, fn(usize) -> Vec<f64>); 6] = [
(
PirlsRowFamily::GaussianIdentity,
raw_gaussian,
0x5242_3031_4741_5553u64,
|n| {
(0..n)
.map(|i| -3.0 + 6.0 * (i as f64) / ((n - 1) as f64))
.collect()
},
),
(
PirlsRowFamily::PoissonLog,
raw_poisson,
0x5242_3032_504f_4953u64,
|n| (0..n).map(|i| (i % 11) as f64).collect(),
),
(
PirlsRowFamily::GammaLog,
raw_gamma,
0x5242_3033_474d_414cu64,
|n| (0..n).map(|i| 0.10 + 0.05 * ((i % 97) as f64)).collect(),
),
(
PirlsRowFamily::BernoulliLogit,
raw_logit,
0x5242_3034_4c47_4954u64,
|n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
),
(
PirlsRowFamily::BernoulliProbit,
raw_probit,
0x5242_3035_5052_4254u64,
|n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
),
(
PirlsRowFamily::BernoulliCLogLog,
raw_cloglog,
0x5242_3036_434c_4f47u64,
|n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
),
];
for (family, raw_body, spec_id, build_y) in cases {
// Stage this family's inputs on the device.
let ys: Vec<f64> = build_y(N);
let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("eta");
let mut y_dev = stream.alloc_zeros::<f64>(N).expect("y");
let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("prior");
stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
stream
.memcpy_htod(&priors, &mut prior_dev)
.expect("up prior");
// Build the JIT spec from the raw body and launch it.
let spec = JitFamilySpec::raw(spec_id, raw_body);
let mut out_jit = RowOutputDevBuffers::allocate(&stream, N).expect("alloc jit out");
launch_row_reweight_jit_on_stream(
backend,
&spec,
curvature,
&stream,
N,
&eta_dev,
&y_dev,
&prior_dev,
&mut out_jit,
)
.unwrap_or_else(|err| panic!("jit raw-body launch {family:?}: {err}"));
stream.synchronize().expect("sync");
// Download every output column for comparison.
let mu_j = stream.clone_dtoh(&out_jit.mu).expect("dl mu");
let ge_j = stream.clone_dtoh(&out_jit.grad_eta).expect("dl g");
let wf_j = stream.clone_dtoh(&out_jit.w_fisher).expect("dl wf");
let wh_j = stream.clone_dtoh(&out_jit.w_hessian).expect("dl wh");
let ws_j = stream.clone_dtoh(&out_jit.w_solver).expect("dl ws");
let zf_j = stream.clone_dtoh(&out_jit.z_fisher).expect("dl zf");
let zh_j = stream.clone_dtoh(&out_jit.z_hessian).expect("dl zh");
let dv_j = stream.clone_dtoh(&out_jit.deviance).expect("dl dv");
for i in 0..N {
// CPU reference for this row, using the same family and curvature.
let cpu = row_reweight_cpu(
family,
curvature,
RowInput {
eta: etas[i],
y: ys[i],
prior_weight: priors[i],
},
1.0,
);
for (label, got, exp) in [
("mu", mu_j[i], cpu.mu),
("grad_eta", ge_j[i], cpu.grad_eta),
("w_fisher", wf_j[i], cpu.w_fisher),
("w_hessian", wh_j[i], cpu.w_hessian),
("w_solver", ws_j[i], cpu.w_solver),
("z_fisher", zf_j[i], cpu.z_fisher),
("z_hessian", zh_j[i], cpu.z_hessian),
("deviance", dv_j[i], cpu.deviance),
] {
// Skip when both sides are non-finite.
// NOTE(review): this accepts any non-finite pair, including
// NaN vs ±inf or infinities of opposite sign — confirm that
// this leniency is intended.
if !got.is_finite() && !exp.is_finite() {
continue;
}
let abs_err = (got - exp).abs();
// Relative error, falling back to the absolute error when the
// CPU reference is exactly zero.
let rel_err = if exp.abs() > 0.0 {
abs_err / exp.abs()
} else {
abs_err
};
assert!(
abs_err <= TOL || rel_err <= TOL,
"{family:?} {label}[{i}] (eta={}, y={}, wp={}): \
JIT raw-body={} vs CPU={} (abs={}, rel={})",
etas[i],
ys[i],
priors[i],
got,
exp,
abs_err,
rel_err,
);
}
}
}
}
}
}