1use std::sync::OnceLock;
49
50use gam_gpu::gpu_error::GpuError;
51#[cfg(target_os = "linux")]
52use gam_gpu::gpu_error::GpuResultExt;
53
54#[cfg(target_os = "linux")]
55use std::sync::{Arc, Mutex};
56
57#[cfg(target_os = "linux")]
58use cudarc::driver::{CudaContext, CudaModule};
59
/// GLM family / link pairs that have dedicated PIRLS row kernels.
///
/// Each variant fixes both the response distribution and the inverse link used by the row
/// routines:
/// * Bernoulli uses the logistic, probit (Φ) or complementary log-log inverse link.
/// * Poisson and Gamma use `exp`.
/// * Gaussian uses the identity.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PirlsRowFamily {
    /// Bernoulli response, logit link.
    BernoulliLogit,
    /// Bernoulli response, probit link.
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link.
    BernoulliCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Gaussian response, identity link.
    GaussianIdentity,
    /// Gamma response, log link (the shape is supplied separately at call time).
    GammaLog,
}

impl PirlsRowFamily {
    /// Every variant, in declaration order.
    pub const ALL: [Self; 6] = [
        Self::BernoulliLogit,
        Self::BernoulliProbit,
        Self::BernoulliCLogLog,
        Self::PoissonLog,
        Self::GaussianIdentity,
        Self::GammaLog,
    ];

    /// Kebab-case label used in diagnostics and error messages.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "gaussian-identity",
            Self::PoissonLog => "poisson-log",
            Self::GammaLog => "gamma-log",
            Self::BernoulliLogit => "bernoulli-logit",
            Self::BernoulliProbit => "bernoulli-probit",
            Self::BernoulliCLogLog => "bernoulli-cloglog",
        }
    }

    /// Entry-point name of the full ("final row") reweight kernel for this family.
    pub const fn kernel_name(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "pirls_row_gaussian_identity",
            Self::PoissonLog => "pirls_row_poisson_log",
            Self::GammaLog => "pirls_row_gamma_log",
            Self::BernoulliLogit => "pirls_row_bernoulli_logit",
            Self::BernoulliProbit => "pirls_row_bernoulli_probit",
            Self::BernoulliCLogLog => "pirls_row_bernoulli_cloglog",
        }
    }

    /// Entry-point name of the reduced solve-row kernel for this family.
    pub const fn solve_kernel_name(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "pirls_solve_gaussian_identity",
            Self::PoissonLog => "pirls_solve_poisson_log",
            Self::GammaLog => "pirls_solve_gamma_log",
            Self::BernoulliLogit => "pirls_solve_bernoulli_logit",
            Self::BernoulliProbit => "pirls_solve_bernoulli_probit",
            Self::BernoulliCLogLog => "pirls_solve_bernoulli_cloglog",
        }
    }

    /// Entry-point name of the alpha-ladder kernel for this family.
    pub const fn ladder_kernel_name(self) -> &'static str {
        match self {
            Self::GaussianIdentity => "pirls_ladder_gaussian_identity",
            Self::PoissonLog => "pirls_ladder_poisson_log",
            Self::GammaLog => "pirls_ladder_gamma_log",
            Self::BernoulliLogit => "pirls_ladder_bernoulli_logit",
            Self::BernoulliProbit => "pirls_ladder_bernoulli_probit",
            Self::BernoulliCLogLog => "pirls_ladder_bernoulli_cloglog",
        }
    }

    /// `true` when the link is the family's canonical link.
    ///
    /// For canonical links, observed and expected information coincide. That is why the
    /// logit, Poisson-log and Gaussian-identity rows add no observed-curvature correction.
    pub const fn is_canonical(self) -> bool {
        match self {
            Self::GammaLog | Self::BernoulliProbit | Self::BernoulliCLogLog => false,
            Self::BernoulliLogit | Self::PoissonLog | Self::GaussianIdentity => true,
        }
    }
}
157
/// Which curvature the PIRLS row routines report in `w_hessian`.
///
/// * `Fisher`: expected information, so `w_hessian == w_fisher`.
/// * `Observed`: observed information, i.e. `w_fisher` plus a family-specific correction.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CurvatureMode {
    Fisher,
    Observed,
}

impl CurvatureMode {
    /// Lowercase label used in diagnostics and error messages.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Observed => "observed",
            Self::Fisher => "fisher",
        }
    }
}
178
/// Bit flags OR-ed into the `status` word of each row result.
///
/// The JIT kernel wrapper and the CUDA device prolog use the same bit values
/// (`0x1u` for an η clamp, `0x10u` for a non-positive prior weight).
pub mod status_flags {
    /// η fell outside `±ETA_CLAMP` and was saturated before evaluating the inverse link.
    pub const ETA_CLAMPED: u32 = 0x01;
    /// μ was raised to (or clamped into) the family's floor range.
    pub const MU_FLOORED: u32 = 0x02;
    /// Reserved Bernoulli non-smoothness flag (not set by the CPU row routines).
    pub const NONSMOOTH_BERNOULLI: u32 = 0x04;
    /// The response is non-finite or outside the family's support.
    pub const INVALID_RESPONSE: u32 = 0x08;
    /// The prior weight was `<= 0`, so the row carries no weight.
    pub const ZERO_PRIOR_WEIGHT: u32 = 0x10;
}
187
/// One observation fed to the per-row PIRLS reweighting (`row_reweight_cpu`).
#[derive(Clone, Copy, Debug)]
pub struct RowInput {
    /// Linear predictor η for this row.
    ///
    /// Non-identity-link families clamp it to `±ETA_CLAMP` before use. The
    /// Gaussian-identity row uses it unclamped.
    pub eta: f64,
    /// Observed response.
    ///
    /// The Poisson, Gamma and Bernoulli rows flag out-of-support or non-finite values
    /// with `INVALID_RESPONSE`. The Gaussian row performs no such check.
    pub y: f64,
    /// Prior (case) weight.
    ///
    /// Negative values are treated as 0 via `max(0.0)`. Any value `<= 0.0` sets
    /// `ZERO_PRIOR_WEIGHT`.
    pub prior_weight: f64,
}
206
/// Per-row results of one PIRLS reweighting step (CPU reference layout).
#[derive(Clone, Copy, Debug, Default)]
pub struct RowOutput {
    /// Mean μ after η clamping and the family's μ floor/clamp.
    pub mu: f64,
    /// Gradient with respect to η (prior weight included) used by the solve step.
    pub grad_eta: f64,
    /// Expected (Fisher) information weight, including the prior weight.
    pub w_fisher: f64,
    /// Curvature chosen by `CurvatureMode`.
    ///
    /// Equals `w_fisher` in Fisher mode. In Observed mode it is `w_fisher` plus a
    /// family-specific correction.
    pub w_hessian: f64,
    /// Weight handed to the linear solver.
    ///
    /// It is 0 when the curvature is not positive, otherwise at least `W_SOLVER_FLOOR`.
    pub w_solver: f64,
    /// IRLS working response paired with `w_fisher`.
    pub z_fisher: f64,
    /// Working response paired with `w_hessian`.
    ///
    /// The CPU row routines currently set it equal to `z_fisher`.
    pub z_hessian: f64,
    /// Prior-weighted deviance contribution of this row.
    pub deviance: f64,
    /// Bitwise OR of `status_flags` constants describing clamps and invalid inputs.
    pub status: u32,
}
220
/// Symmetric clamp applied to η by `clamp_eta`.
///
/// `exp(700)` is still finite in `f64`, so exponentiating a clamped η never overflows.
const ETA_CLAMP: f64 = 700.0;
/// Lower floor on μ for the Poisson-log row (keeps `y/μ` and `resid/μ` finite).
const MU_FLOOR_POISSON: f64 = 1.0e-10;
/// Lower floor on μ for the Gamma-log row.
const MU_FLOOR_GAMMA: f64 = 1.0e-10;
/// Bernoulli μ is clamped into `[MU_FLOOR_BERNOULLI, 1 - MU_FLOOR_BERNOULLI]`.
const MU_FLOOR_BERNOULLI: f64 = 1.0e-12;
/// Smallest positive weight passed to the solver; non-positive curvature becomes 0 instead.
const W_SOLVER_FLOOR: f64 = 1.0e-12;
/// `bernoulli_z` only takes a working-response step when `dμ/dη` is strictly above this.
const DMU_DETA_MIN: f64 = 0.0;
230
231#[inline]
232fn clamp_eta(eta: f64) -> (f64, bool) {
233 if eta > ETA_CLAMP {
234 (ETA_CLAMP, true)
235 } else if eta < -ETA_CLAMP {
236 (-ETA_CLAMP, true)
237 } else {
238 (eta, false)
239 }
240}
241
242pub fn row_reweight_cpu(
248 family: PirlsRowFamily,
249 mode: CurvatureMode,
250 input: RowInput,
251 gamma_shape: f64,
252) -> RowOutput {
253 match family {
254 PirlsRowFamily::GaussianIdentity => row_gaussian_identity(input, mode),
255 PirlsRowFamily::PoissonLog => row_poisson_log(input, mode),
256 PirlsRowFamily::GammaLog => row_gamma_log(input, mode, gamma_shape),
257 PirlsRowFamily::BernoulliLogit => row_bernoulli_logit(input, mode),
258 PirlsRowFamily::BernoulliProbit => row_bernoulli_probit(input, mode),
259 PirlsRowFamily::BernoulliCLogLog => row_bernoulli_cloglog(input, mode),
260 }
261}
262
263#[inline]
269fn select_w_hessian(mode: CurvatureMode, w_fisher: f64, observed_correction: f64) -> f64 {
270 match mode {
271 CurvatureMode::Fisher => w_fisher,
272 CurvatureMode::Observed => w_fisher + observed_correction,
273 }
274}
275
276#[inline]
277fn row_gaussian_identity(input: RowInput, mode: CurvatureMode) -> RowOutput {
278 let w = input.prior_weight.max(0.0);
279 let mu = input.eta;
280 let resid = input.y - mu;
281 let dev = w * resid * resid;
282 let status = if input.prior_weight <= 0.0 {
283 status_flags::ZERO_PRIOR_WEIGHT
284 } else {
285 0
286 };
287 let w_hessian = select_w_hessian(mode, w, 0.0);
289 RowOutput {
290 mu,
291 grad_eta: w * resid,
292 w_fisher: w,
293 w_hessian,
294 w_solver: if w_hessian > 0.0 {
295 w_hessian.max(W_SOLVER_FLOOR)
296 } else {
297 0.0
298 },
299 z_fisher: input.y,
300 z_hessian: input.y,
301 deviance: dev,
302 status,
303 }
304}
305
306#[inline]
307fn row_poisson_log(input: RowInput, mode: CurvatureMode) -> RowOutput {
308 let (eta_c, clamped) = clamp_eta(input.eta);
309 let mu_raw = eta_c.exp();
310 let mu_floored = mu_raw < MU_FLOOR_POISSON;
311 let mu = mu_raw.max(MU_FLOOR_POISSON);
312 let w_prior = input.prior_weight.max(0.0);
313 let raw_w = w_prior * mu;
314 let w_fisher = if raw_w > 0.0 {
315 raw_w.max(W_SOLVER_FLOOR)
316 } else {
317 0.0
318 };
319 let resid = input.y - mu;
320 let dev_term = if input.y > 0.0 {
323 input.y * (input.y / mu).ln() - resid
324 } else {
325 -resid
326 };
327 let dev = 2.0 * w_prior * dev_term;
328 let z = eta_c + resid / mu;
329 let mut status = 0u32;
330 if clamped {
331 status |= status_flags::ETA_CLAMPED;
332 }
333 if mu_floored {
334 status |= status_flags::MU_FLOORED;
335 }
336 if input.prior_weight <= 0.0 {
337 status |= status_flags::ZERO_PRIOR_WEIGHT;
338 }
339 if !(input.y.is_finite() && input.y >= 0.0) {
340 status |= status_flags::INVALID_RESPONSE;
341 }
342 let w_hessian = select_w_hessian(mode, w_fisher, 0.0);
344 RowOutput {
345 mu,
346 grad_eta: w_prior * resid,
347 w_fisher,
348 w_hessian,
349 w_solver: w_hessian,
350 z_fisher: z,
351 z_hessian: z,
352 deviance: dev,
353 status,
354 }
355}
356
/// Per-row PIRLS quantities for a Gamma response with log link (μ = exp(η)).
///
/// η is clamped to `±ETA_CLAMP` and μ is floored at `MU_FLOOR_GAMMA`. With shape ν
/// (`shape`):
/// * `w_fisher = w·ν`;
/// * in Observed mode, `w_hessian = w_fisher + w_fisher·(y/μ − 1) = w·ν·y/μ`.
///
/// NOTE(review): `grad_eta = w·(y − μ)/μ` carries no factor ν, while `w_fisher`/`w_hessian`
/// do. As a result, `grad_eta / w_fisher` is the IRLS step `(y − μ)/μ` divided by ν.
/// Confirm that callers rescale, or that ν is meant to enter only the curvature.
///
/// NOTE(review): for `y <= 0` (including NaN) the deviance is `+∞` even when the prior weight
/// is 0. `bernoulli_deviance`, by contrast, returns 0 for zero-weight rows.
#[inline]
fn row_gamma_log(input: RowInput, mode: CurvatureMode, shape: f64) -> RowOutput {
    let (eta_c, clamped) = clamp_eta(input.eta);
    let mu_raw = eta_c.exp();
    let mu_floored = mu_raw < MU_FLOOR_GAMMA;
    let mu = mu_raw.max(MU_FLOOR_GAMMA);
    let w_prior = input.prior_weight.max(0.0);
    let w_fisher = w_prior * shape;
    // −∂²ℓ/∂η² = w·ν·y/μ = w_fisher + w_fisher·(y/μ − 1). The `mu > 0.0` test always passes
    // after flooring; the other guards skip zero-weight rows and non-finite responses.
    let obs_correction = if w_fisher > 0.0 && mu > 0.0 && input.y.is_finite() {
        w_fisher * (input.y / mu - 1.0)
    } else {
        0.0
    };
    let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
    let resid = input.y - mu;
    // Gamma unit deviance 2[−ln(y/μ) + (y − μ)/μ]; undefined for y <= 0, reported as +∞.
    let dev = if input.y > 0.0 {
        2.0 * w_prior * (-((input.y / mu).ln()) + resid / mu)
    } else {
        f64::INFINITY
    };
    // Log-link working response: η + (y − μ)·dη/dμ with dη/dμ = 1/μ.
    let z = eta_c + resid / mu;
    let mut status = 0u32;
    if clamped {
        status |= status_flags::ETA_CLAMPED;
    }
    if mu_floored {
        status |= status_flags::MU_FLOORED;
    }
    if input.prior_weight <= 0.0 {
        status |= status_flags::ZERO_PRIOR_WEIGHT;
    }
    // Gamma support is y > 0.
    if !(input.y.is_finite() && input.y > 0.0) {
        status |= status_flags::INVALID_RESPONSE;
    }
    RowOutput {
        mu,
        grad_eta: w_prior * resid / mu,
        w_fisher,
        w_hessian,
        w_solver: if w_hessian > 0.0 {
            w_hessian.max(W_SOLVER_FLOOR)
        } else {
            0.0
        },
        z_fisher: z,
        z_hessian: z,
        deviance: dev,
        status,
    }
}
414
415#[inline]
416fn row_bernoulli_logit(input: RowInput, mode: CurvatureMode) -> RowOutput {
417 let (eta_c, clamped) = clamp_eta(input.eta);
418 let half = 0.5 * eta_c;
421 let mu_raw = 0.5 * (1.0 + half.tanh());
422 let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
423 let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
424 let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
425 let w_prior = input.prior_weight.max(0.0);
426 let dmu_deta = mu * (1.0 - mu); let w_fisher = w_prior * dmu_deta; let resid = input.y - mu;
429 let grad_eta = w_prior * resid; let dev = bernoulli_deviance(input.y, mu, w_prior);
432 let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);
433 let mut status = 0u32;
434 if clamped {
435 status |= status_flags::ETA_CLAMPED;
436 }
437 if mu_low || mu_high {
438 status |= status_flags::MU_FLOORED;
439 }
440 if input.prior_weight <= 0.0 {
441 status |= status_flags::ZERO_PRIOR_WEIGHT;
442 }
443 if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
444 status |= status_flags::INVALID_RESPONSE;
445 }
446 let w_hessian = select_w_hessian(mode, w_fisher, 0.0);
447 RowOutput {
448 mu,
449 grad_eta,
450 w_fisher,
451 w_hessian,
452 w_solver: if w_hessian > 0.0 {
453 w_hessian.max(W_SOLVER_FLOOR)
454 } else {
455 0.0
456 },
457 z_fisher: z,
458 z_hessian: z,
459 deviance: dev,
460 status,
461 }
462}
463
464#[inline]
465fn row_bernoulli_probit(input: RowInput, mode: CurvatureMode) -> RowOutput {
466 let (eta_c, clamped) = clamp_eta(input.eta);
467 let mu_raw = standard_normal_cdf(eta_c);
468 let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
469 let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
470 let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
471 let w_prior = input.prior_weight.max(0.0);
472 let dmu_deta = standard_normal_pdf(eta_c); let v = mu * (1.0 - mu);
474 let fisher_per_prior = if v > 0.0 {
475 dmu_deta * dmu_deta / v
476 } else {
477 0.0
478 };
479 let w_fisher = w_prior * fisher_per_prior;
480 let resid = input.y - mu;
481 let grad_eta = if v > 0.0 {
482 w_prior * resid * dmu_deta / v
483 } else {
484 0.0
485 };
486 let dev = bernoulli_deviance(input.y, mu, w_prior);
487 let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);
488 let mut status = 0u32;
489 if clamped {
490 status |= status_flags::ETA_CLAMPED;
491 }
492 if mu_low || mu_high {
493 status |= status_flags::MU_FLOORED;
494 }
495 if input.prior_weight <= 0.0 {
496 status |= status_flags::ZERO_PRIOR_WEIGHT;
497 }
498 if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
499 status |= status_flags::INVALID_RESPONSE;
500 }
501 let obs_correction = if v > 0.0 && w_prior > 0.0 {
505 let h_prime = -eta_c * dmu_deta;
506 let v_prime = 1.0 - 2.0 * mu;
507 let bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
508 w_prior * resid * bracket
509 } else {
510 0.0
511 };
512 let w_hessian_observed = select_w_hessian(mode, w_fisher, obs_correction);
513 RowOutput {
514 mu,
515 grad_eta,
516 w_fisher,
517 w_hessian: w_hessian_observed,
518 w_solver: {
519 let wh = w_hessian_observed;
520 if wh > 0.0 {
521 wh.max(W_SOLVER_FLOOR)
522 } else {
523 0.0
524 }
525 },
526 z_fisher: z,
527 z_hessian: z,
528 deviance: dev,
529 status,
530 }
531}
532
533#[inline]
534fn row_bernoulli_cloglog(input: RowInput, mode: CurvatureMode) -> RowOutput {
535 let (eta_c, clamped) = clamp_eta(input.eta);
536 let inner = eta_c.exp();
540 let mu_raw = -(-inner).exp_m1();
541 let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
542 let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
543 let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
544 let dmu_deta = inner * (1.0 - mu_raw);
547 let w_prior = input.prior_weight.max(0.0);
548 let v = mu * (1.0 - mu);
549 let fisher_per_prior = if v > 0.0 {
550 dmu_deta * dmu_deta / v
551 } else {
552 0.0
553 };
554 let w_fisher = w_prior * fisher_per_prior;
555 let resid = input.y - mu;
556 let grad_eta = if v > 0.0 {
557 w_prior * resid * dmu_deta / v
558 } else {
559 0.0
560 };
561 let dev = bernoulli_deviance(input.y, mu, w_prior);
562 let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);
563 let mut status = 0u32;
564 if clamped {
565 status |= status_flags::ETA_CLAMPED;
566 }
567 if mu_low || mu_high {
568 status |= status_flags::MU_FLOORED;
569 }
570 if input.prior_weight <= 0.0 {
571 status |= status_flags::ZERO_PRIOR_WEIGHT;
572 }
573 if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
574 status |= status_flags::INVALID_RESPONSE;
575 }
576 let obs_correction = if v > 0.0 && w_prior > 0.0 {
581 let h_prime = dmu_deta * (1.0 - inner);
582 let v_prime = 1.0 - 2.0 * mu;
583 let bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
584 w_prior * resid * bracket
585 } else {
586 0.0
587 };
588 let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
589 RowOutput {
590 mu,
591 grad_eta,
592 w_fisher,
593 w_hessian,
594 w_solver: if w_hessian > 0.0 {
595 w_hessian.max(W_SOLVER_FLOOR)
596 } else {
597 0.0
598 },
599 z_fisher: z,
600 z_hessian: z,
601 deviance: dev,
602 status,
603 }
604}
605
/// Prior-weighted Bernoulli deviance `2w[y·ln(y/μ) + (1−y)·ln((1−y)/(1−μ))]`.
///
/// Each term is dropped at its `x·ln(x) → 0` limit (`y = 0` or `y = 1`). A zero prior weight
/// short-circuits to 0.
#[inline]
fn bernoulli_deviance(y: f64, mu: f64, w_prior: f64) -> f64 {
    if w_prior == 0.0 {
        0.0
    } else {
        // x·ln(x/m) for one of the two outcomes.
        let xlog_ratio = |x: f64, m: f64| x * (x / m).ln();
        let success = if y > 0.0 { xlog_ratio(y, mu) } else { 0.0 };
        let failure = if y < 1.0 {
            xlog_ratio(1.0 - y, 1.0 - mu)
        } else {
            0.0
        };
        2.0 * w_prior * (success + failure)
    }
}
619
620#[inline]
621fn bernoulli_z(eta_used: f64, y: f64, mu: f64, dmu_deta: f64) -> f64 {
622 if dmu_deta.is_finite() && dmu_deta > DMU_DETA_MIN {
623 let delta = (y - mu) / dmu_deta;
624 if delta.is_finite() {
625 return eta_used + delta;
626 }
627 }
628 eta_used
629}
630
631#[inline]
634fn standard_normal_cdf(x: f64) -> f64 {
635 0.5 * gam_gpu::numerics_host::erfc(-x * std::f64::consts::FRAC_1_SQRT_2)
636}
637
/// Standard normal density φ(x) = e^{−x²/2} / √(2π).
#[inline]
fn standard_normal_pdf(x: f64) -> f64 {
    // 1/√(2π).
    const COEFF: f64 = 0.398_942_280_401_432_7;
    let exponent = -0.5 * x * x;
    COEFF * exponent.exp()
}
643
/// Process-wide handle to the CUDA backend for the PIRLS row kernels.
///
/// Obtain it through [`PirlsRowBackend::probe`]. On non-Linux targets the struct has no fields
/// and `probe` always returns an error.
#[must_use]
pub struct PirlsRowBackend {
    #[cfg(target_os = "linux")]
    inner: PirlsRowBackendLinux,
}
654
/// Linux-only state behind `PirlsRowBackend`: the CUDA context plus two lazily filled module
/// caches.
#[cfg(target_os = "linux")]
struct PirlsRowBackendLinux {
    /// Context returned by `probe_cuda_backend`; every cached module is loaded into it.
    ctx: Arc<CudaContext>,
    /// Built-in kernels, compiled on first use per (family, curvature, kernel mode).
    modules: Mutex<std::collections::HashMap<ModuleKey, Arc<CudaModule>>>,
    /// Caller-supplied JIT kernels, compiled on first use per (spec_id, curvature).
    jit_modules: Mutex<std::collections::HashMap<JitKey, Arc<CudaModule>>>,
}
664
/// Which built-in kernel source to compile for a family; part of the module cache key.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum KernelMode {
    /// Full reweight kernel writing every `RowOutputDevBuffers` field (`cuda_source_for`).
    FinalRow,
    /// Reduced kernel writing only `SolveRowBuffers` (`solve_row_source_for`).
    SolveRow,
    /// Step-length ladder kernel writing `AlphaLadderDevBuffers` (`ladder_source_for`).
    AlphaLadder,
}
677
/// Cache key for built-in kernel modules: one compiled module per combination.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct ModuleKey {
    family: PirlsRowFamily,
    curvature: CurvatureMode,
    mode: KernelMode,
}
685
impl PirlsRowBackend {
    /// Returns `true` when this build includes the CUDA code path (Linux targets only).
    pub const fn compiled() -> bool {
        cfg!(target_os = "linux")
    }

    /// Returns the process-wide backend, probing CUDA on the first call.
    ///
    /// The first outcome, success or error, is stored in a `OnceLock`. Every later call
    /// returns that same outcome (errors are cloned), so a failed probe is never retried within
    /// this process.
    ///
    /// # Errors
    /// On non-Linux targets, always `GpuError::DriverLibraryUnavailable`. On Linux, whatever
    /// `gam_gpu::backend_probe::probe_cuda_backend` reported.
    pub fn probe() -> Result<&'static Self, GpuError> {
        static BACKEND: OnceLock<Result<PirlsRowBackend, GpuError>> = OnceLock::new();
        BACKEND
            .get_or_init(|| {
                #[cfg(target_os = "linux")]
                {
                    Self::probe_linux()
                }
                #[cfg(not(target_os = "linux"))]
                {
                    Err(GpuError::DriverLibraryUnavailable {
                        reason: "pirls_row GPU backend is Linux-only".to_string(),
                    })
                }
            })
            .as_ref()
            .map_err(GpuError::clone)
    }

    /// Builds the Linux backend from a freshly probed CUDA context with empty module caches.
    #[cfg(target_os = "linux")]
    fn probe_linux() -> Result<Self, GpuError> {
        let parts = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")?;
        Ok(Self {
            inner: PirlsRowBackendLinux {
                ctx: parts.ctx,
                modules: Mutex::new(std::collections::HashMap::new()),
                jit_modules: Mutex::new(std::collections::HashMap::new()),
            },
        })
    }

    /// Returns the module for (`family`, `curvature`, `mode`), compiling and caching it on a miss.
    ///
    /// `label` is spliced verbatim in front of the rest of each error message, so callers pass
    /// either `""` or a word followed by a space (e.g. `"solve "`).
    ///
    /// The cache mutex is released while NVRTC compiles and the module loads. Concurrent first
    /// requests for the same key may therefore each compile; the later `insert` overwrites the
    /// earlier entry, and each caller returns the module it loaded itself.
    ///
    /// # Errors
    /// Poisoned cache mutex, NVRTC compile failure, or module load failure.
    #[cfg(target_os = "linux")]
    fn module_for_kind(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
        mode: KernelMode,
        label: &str,
    ) -> Result<Arc<CudaModule>, GpuError> {
        let key = ModuleKey {
            family,
            curvature,
            mode,
        };
        // Fast path: the guard is a temporary, released when this `if let` ends.
        if let Some(existing) = self
            .inner
            .modules
            .lock()
            .gpu_ctx_with(|err| format!("pirls_row {label}module cache mutex poisoned: {err}"))?
            .get(&key)
        {
            return Ok(existing.clone());
        }
        let source = match mode {
            KernelMode::FinalRow => cuda_source_for(family, curvature),
            KernelMode::SolveRow => solve_row_source_for(family, curvature),
            KernelMode::AlphaLadder => ladder_source_for(family, curvature),
        };
        let ptx = gam_gpu::device_cache::compile_ptx_arch(&source).gpu_ctx_with(|err| {
            format!(
                "pirls_row {label}NVRTC compile failed for {family}/{curv}: {err}",
                family = family.as_str(),
                curv = curvature.as_str(),
            )
        })?;
        let module = self
            .inner
            .ctx
            .load_module(ptx)
            .gpu_ctx_with(|err| format!("pirls_row {label}module load failed: {err}"))?;
        self.inner
            .modules
            .lock()
            .gpu_ctx_with(|err| format!("pirls_row {label}module cache mutex poisoned: {err}"))?
            .insert(key, module.clone());
        Ok(module)
    }

    /// Module containing the full ("final row") reweight kernel for `family`/`curvature`.
    #[cfg(target_os = "linux")]
    pub fn module_for(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        self.module_for_kind(family, curvature, KernelMode::FinalRow, "")
    }

    /// Module containing the reduced solve-row kernel for `family`/`curvature`.
    #[cfg(target_os = "linux")]
    pub fn module_for_solve(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        self.module_for_kind(family, curvature, KernelMode::SolveRow, "solve ")
    }

    /// Module containing the alpha-ladder kernel for `family`/`curvature`.
    #[cfg(target_os = "linux")]
    pub fn module_for_ladder(
        &self,
        family: PirlsRowFamily,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        self.module_for_kind(family, curvature, KernelMode::AlphaLadder, "ladder ")
    }

    /// Returns the compiled module for a caller-supplied JIT kernel, compiling it on a miss.
    ///
    /// Modules are cached by (`spec.spec_id`, `curvature`) only. `spec.body` is not part of the
    /// key, so callers must keep spec ids unique per body, or they may receive a module compiled
    /// from a different body. As in `module_for_kind`, the cache lock is not held while
    /// compiling.
    ///
    /// # Errors
    /// Poisoned cache mutex, NVRTC compile failure, or module load failure.
    #[cfg(target_os = "linux")]
    pub fn module_for_jit(
        &self,
        spec: &JitFamilySpec,
        curvature: CurvatureMode,
    ) -> Result<Arc<CudaModule>, GpuError> {
        let key = JitKey {
            spec_id: spec.spec_id,
            curvature,
        };
        if let Some(existing) = self
            .inner
            .jit_modules
            .lock()
            .gpu_ctx("pirls_row jit cache poisoned")?
            .get(&key)
        {
            return Ok(existing.clone());
        }
        let source = spec.cuda_source(curvature);
        let ptx = gam_gpu::device_cache::compile_ptx_arch(&source).gpu_ctx_with(|err| {
            format!(
                "pirls_row JIT NVRTC compile failed for spec_id={} curvature={}: {err}",
                spec.spec_id,
                curvature.as_str(),
            )
        })?;
        let module = self
            .inner
            .ctx
            .load_module(ptx)
            .gpu_ctx("pirls_row JIT module load failed")?;
        self.inner
            .jit_modules
            .lock()
            .gpu_ctx("pirls_row jit cache poisoned (insert)")?
            .insert(key, module.clone());
        Ok(module)
    }
}
870
/// Cache key for JIT modules.
///
/// Only the spec id and curvature are included, not the spec body.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct JitKey {
    spec_id: u64,
    curvature: CurvatureMode,
}
878
/// A caller-defined PIRLS row kernel: a CUDA statement body that `cuda_source` wraps into a
/// complete kernel.
#[derive(Clone, Debug)]
pub struct JitFamilySpec {
    /// Names the generated entry point (`pirls_row_jit_{spec_id}`).
    ///
    /// Together with the curvature, it keys the backend's JIT module cache. Keep it unique per
    /// distinct `body`.
    pub spec_id: u64,
    /// CUDA statements spliced into the generated kernel.
    ///
    /// They must define the locals the wrapper stores afterwards (see `cuda_source`).
    pub body: String,
}
906
impl JitFamilySpec {
    /// Builds a spec whose body is one of the built-in GLM row bodies for `family`/`curvature`.
    ///
    /// NOTE(review): the JIT kernel signature has no `gamma_shape` argument, so the
    /// `GammaLog` body must obtain the shape some other way — confirm in `gamma_log_body`.
    #[cfg(target_os = "linux")]
    pub fn glm(spec_id: u64, family: PirlsRowFamily, curvature: CurvatureMode) -> Self {
        let body = match family {
            PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
            PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
            PirlsRowFamily::GammaLog => gamma_log_body(curvature),
            PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
            PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
            PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
        };
        Self { spec_id, body }
    }

    /// Builds a spec from an arbitrary CUDA statement body.
    ///
    /// The body is not validated here; errors surface at NVRTC compile time.
    pub fn raw(spec_id: u64, body: impl Into<String>) -> Self {
        Self {
            spec_id,
            body: body.into(),
        }
    }

    /// Entry-point name of the generated kernel: `pirls_row_jit_{spec_id}`.
    pub fn kernel_name(&self) -> String {
        format!("pirls_row_jit_{}", self.spec_id)
    }

    /// Wraps `body` into a complete CUDA translation unit.
    ///
    /// The unit consists of:
    /// * a curvature `#define` (`PIRLS_CURVATURE_FISHER` or `PIRLS_CURVATURE_OBSERVED`);
    /// * the shared `COMMON_DEVICE_PROLOG`;
    /// * an `extern "C"` kernel named [`Self::kernel_name`].
    ///
    /// Before `body`, the wrapper declares `i`, `flags`, `eta_i`, `y_i` and `wp`. `wp` is the
    /// prior weight with non-positive values replaced by 0, and bit `0x10u` is set in `flags`
    /// for those rows. After `body`, the wrapper stores `mu`, `grad_eta`, `w_fisher`,
    /// `w_hessian`, `w_solver`, `z_f`, `z_h`, `dev` and `flags`, so `body` must define all of
    /// them.
    ///
    /// `body` is passed as a `format!` argument, so braces inside it need no escaping.
    #[cfg(target_os = "linux")]
    pub fn cuda_source(&self, curvature: CurvatureMode) -> String {
        let curvature_define = match curvature {
            CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
            CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
        };
        let kernel_name = self.kernel_name();
        let body = &self.body;
        format!(
            r#"
{curvature_define}
{prolog}

extern "C" __global__ void {kernel_name}(
 int n,
 const double* __restrict__ eta,
 const double* __restrict__ y,
 const double* __restrict__ prior_w,
 double* __restrict__ mu_out,
 double* __restrict__ grad_eta_out,
 double* __restrict__ w_fisher_out,
 double* __restrict__ w_hessian_out,
 double* __restrict__ w_solver_out,
 double* __restrict__ z_fisher_out,
 double* __restrict__ z_hessian_out,
 double* __restrict__ deviance_out,
 unsigned int* __restrict__ status_out
) {{
 int i = blockIdx.x * blockDim.x + threadIdx.x;
 if (i >= n) return;
 unsigned int flags = 0u;
 double eta_i = eta[i];
 double y_i = y[i];
 double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
 if (prior_w[i] <= 0.0) flags |= 0x10u;
{body}
 mu_out[i] = mu;
 grad_eta_out[i] = grad_eta;
 w_fisher_out[i] = w_fisher;
 w_hessian_out[i] = w_hessian;
 w_solver_out[i] = w_solver;
 z_fisher_out[i] = z_f;
 z_hessian_out[i] = z_h;
 deviance_out[i] = dev;
 status_out[i] = flags;
}}
"#,
            prolog = COMMON_DEVICE_PROLOG,
        )
    }
}
995
/// Device-side counterpart of `RowOutput`: one `n`-element buffer per output field.
///
/// The full reweight kernels write into these buffers.
#[cfg(target_os = "linux")]
pub struct RowOutputDevBuffers {
    pub mu: cudarc::driver::CudaSlice<f64>,
    pub grad_eta: cudarc::driver::CudaSlice<f64>,
    pub w_fisher: cudarc::driver::CudaSlice<f64>,
    pub w_hessian: cudarc::driver::CudaSlice<f64>,
    pub w_solver: cudarc::driver::CudaSlice<f64>,
    pub z_fisher: cudarc::driver::CudaSlice<f64>,
    pub z_hessian: cudarc::driver::CudaSlice<f64>,
    pub deviance: cudarc::driver::CudaSlice<f64>,
    pub status: cudarc::driver::CudaSlice<u32>,
    /// Row count the buffers were allocated for; launchers compare it to their `n`.
    pub n: usize,
}
1015
1016#[cfg(target_os = "linux")]
1017impl RowOutputDevBuffers {
1018 pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>, n: usize) -> Result<Self, GpuError> {
1020 let alloc_f64 = |label: &'static str| {
1021 stream
1022 .alloc_zeros::<f64>(n)
1023 .gpu_ctx_with(|err| format!("pirls_row alloc {label}: {err}"))
1024 };
1025 let alloc_u32 = |label: &'static str| {
1026 stream
1027 .alloc_zeros::<u32>(n)
1028 .gpu_ctx_with(|err| format!("pirls_row alloc {label}: {err}"))
1029 };
1030 Ok(Self {
1031 mu: alloc_f64("mu")?,
1032 grad_eta: alloc_f64("grad_eta")?,
1033 w_fisher: alloc_f64("w_fisher")?,
1034 w_hessian: alloc_f64("w_hessian")?,
1035 w_solver: alloc_f64("w_solver")?,
1036 z_fisher: alloc_f64("z_fisher")?,
1037 z_hessian: alloc_f64("z_hessian")?,
1038 deviance: alloc_f64("deviance")?,
1039 status: alloc_u32("status")?,
1040 n,
1041 })
1042 }
1043}
1044
/// Reduced device outputs of the solve-row kernel.
///
/// Only the per-row quantities that `launch_solve_row_on_stream` binds: gradient, solver
/// weight, deviance and status.
#[cfg(target_os = "linux")]
pub struct SolveRowBuffers {
    pub grad_eta: cudarc::driver::CudaSlice<f64>,
    pub w_solver: cudarc::driver::CudaSlice<f64>,
    pub deviance: cudarc::driver::CudaSlice<f64>,
    pub status: cudarc::driver::CudaSlice<u32>,
    /// Row count the buffers were allocated for; the launcher compares it to its `n`.
    pub n: usize,
}
1065
1066#[cfg(target_os = "linux")]
1067impl SolveRowBuffers {
1068 pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>, n: usize) -> Result<Self, GpuError> {
1070 let alloc_f64 = |label: &'static str| {
1071 stream
1072 .alloc_zeros::<f64>(n)
1073 .gpu_ctx_with(|err| format!("pirls_row solve alloc {label}: {err}"))
1074 };
1075 let alloc_u32 = |label: &'static str| {
1076 stream
1077 .alloc_zeros::<u32>(n)
1078 .gpu_ctx_with(|err| format!("pirls_row solve alloc {label}: {err}"))
1079 };
1080 Ok(Self {
1081 grad_eta: alloc_f64("grad_eta")?,
1082 w_solver: alloc_f64("w_solver")?,
1083 deviance: alloc_f64("deviance")?,
1084 status: alloc_u32("status")?,
1085 n,
1086 })
1087 }
1088}
1089
/// Number of step sizes evaluated by one alpha-ladder launch.
///
/// The launch uses one grid row (`grid_dim.y`) per step.
pub const ALPHA_LADDER_LEN: usize = 7;

/// Candidate step sizes α for the alpha ladder: 2⁰ down to 2⁻⁶, largest first.
pub const ALPHA_LADDER: [f64; ALPHA_LADDER_LEN] =
    [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625];
1096
/// Device outputs of the alpha-ladder kernel, each `ALPHA_LADDER_LEN` elements long.
///
/// NOTE(review): presumably entry `k` corresponds to step `ALPHA_LADDER[k]`, and the kernel
/// accumulates into these buffers (hence `zero`). Confirm against the ladder kernel source.
#[cfg(target_os = "linux")]
pub struct AlphaLadderDevBuffers {
    pub objective_dev: cudarc::driver::CudaSlice<f64>,
    pub status_dev: cudarc::driver::CudaSlice<u32>,
}
1112
1113#[cfg(target_os = "linux")]
1114impl AlphaLadderDevBuffers {
1115 pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>) -> Result<Self, GpuError> {
1117 Ok(Self {
1118 objective_dev: stream
1119 .alloc_zeros::<f64>(ALPHA_LADDER_LEN)
1120 .gpu_ctx_with(|err| format!("pirls_row ladder alloc objective: {err}"))?,
1121 status_dev: stream
1122 .alloc_zeros::<u32>(ALPHA_LADDER_LEN)
1123 .gpu_ctx_with(|err| format!("pirls_row ladder alloc status: {err}"))?,
1124 })
1125 }
1126
1127 pub fn zero(&mut self, stream: &Arc<cudarc::driver::CudaStream>) -> Result<(), GpuError> {
1129 stream
1130 .memset_zeros(&mut self.objective_dev)
1131 .gpu_ctx_with(|err| format!("pirls_row ladder zero objective: {err}"))?;
1132 stream
1133 .memset_zeros(&mut self.status_dev)
1134 .gpu_ctx_with(|err| format!("pirls_row ladder zero status: {err}"))
1135 }
1136}
1137
1138#[cfg(target_os = "linux")]
1152pub fn launch_row_reweight_on_stream(
1153 backend: &PirlsRowBackend,
1154 family: PirlsRowFamily,
1155 curvature: CurvatureMode,
1156 gamma_shape: f64,
1157 stream: &Arc<cudarc::driver::CudaStream>,
1158 n: usize,
1159 eta_dev: &cudarc::driver::CudaSlice<f64>,
1160 y_dev: &cudarc::driver::CudaSlice<f64>,
1161 prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1162 out: &mut RowOutputDevBuffers,
1163) -> Result<(), GpuError> {
1164 use cudarc::driver::{LaunchConfig, PushKernelArg};
1165 if out.n != n {
1166 gam_gpu::gpu_bail!("row reweight buffers shape {} mismatches n={n}", out.n);
1167 }
1168 let module = backend.module_for(family, curvature)?;
1169 let func = module
1170 .load_function(family.kernel_name())
1171 .gpu_ctx_with(|err| {
1172 format!(
1173 "row reweight load_function({}): {err}",
1174 family.kernel_name()
1175 )
1176 })?;
1177 const THREADS_PER_BLOCK: u32 = 256;
1178 let n_u32 =
1179 u32::try_from(n).map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for row reweight grid sizing"))?;
1180 let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1181 let n_i32 = i32::try_from(n)
1182 .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for row reweight kernel argument"))?;
1183 let cfg = LaunchConfig {
1184 grid_dim: (grid_x, 1, 1),
1185 block_dim: (THREADS_PER_BLOCK, 1, 1),
1186 shared_mem_bytes: 0,
1187 };
1188 let mut builder = stream.launch_builder(&func);
1189 builder.arg(&n_i32);
1190 builder.arg(eta_dev);
1191 builder.arg(y_dev);
1192 builder.arg(prior_w_dev);
1193 if matches!(family, PirlsRowFamily::GammaLog) {
1195 builder.arg(&gamma_shape);
1196 }
1197 builder.arg(&mut out.mu);
1198 builder.arg(&mut out.grad_eta);
1199 builder.arg(&mut out.w_fisher);
1200 builder.arg(&mut out.w_hessian);
1201 builder.arg(&mut out.w_solver);
1202 builder.arg(&mut out.z_fisher);
1203 builder.arg(&mut out.z_hessian);
1204 builder.arg(&mut out.deviance);
1205 builder.arg(&mut out.status);
1206 unsafe { builder.launch(cfg) }
1213 .map(|_event_pair| ())
1214 .gpu_ctx_with(|err| format!("row reweight launch({}): {err}", family.kernel_name()))
1215}
1216
1217#[cfg(target_os = "linux")]
1223pub fn launch_row_reweight_jit_on_stream(
1224 backend: &PirlsRowBackend,
1225 spec: &JitFamilySpec,
1226 curvature: CurvatureMode,
1227 stream: &Arc<cudarc::driver::CudaStream>,
1228 n: usize,
1229 eta_dev: &cudarc::driver::CudaSlice<f64>,
1230 y_dev: &cudarc::driver::CudaSlice<f64>,
1231 prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1232 out: &mut RowOutputDevBuffers,
1233) -> Result<(), GpuError> {
1234 use cudarc::driver::{LaunchConfig, PushKernelArg};
1235 if out.n != n {
1236 gam_gpu::gpu_bail!("JIT row reweight buffers shape {} mismatches n={n}", out.n);
1237 }
1238 let module = backend.module_for_jit(spec, curvature)?;
1239 let kernel_name = spec.kernel_name();
1240 let func = module
1241 .load_function(&kernel_name)
1242 .gpu_ctx_with(|err| format!("JIT row reweight load_function({kernel_name}): {err}"))?;
1243 const THREADS_PER_BLOCK: u32 = 256;
1244 let n_u32 = u32::try_from(n)
1245 .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for JIT row reweight grid sizing"))?;
1246 let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1247 let n_i32 = i32::try_from(n)
1248 .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for JIT row reweight kernel argument"))?;
1249 let cfg = LaunchConfig {
1250 grid_dim: (grid_x, 1, 1),
1251 block_dim: (THREADS_PER_BLOCK, 1, 1),
1252 shared_mem_bytes: 0,
1253 };
1254 let mut builder = stream.launch_builder(&func);
1255 builder.arg(&n_i32);
1256 builder.arg(eta_dev);
1257 builder.arg(y_dev);
1258 builder.arg(prior_w_dev);
1259 builder.arg(&mut out.mu);
1260 builder.arg(&mut out.grad_eta);
1261 builder.arg(&mut out.w_fisher);
1262 builder.arg(&mut out.w_hessian);
1263 builder.arg(&mut out.w_solver);
1264 builder.arg(&mut out.z_fisher);
1265 builder.arg(&mut out.z_hessian);
1266 builder.arg(&mut out.deviance);
1267 builder.arg(&mut out.status);
1268 unsafe { builder.launch(cfg) }
1271 .map(|_event_pair| ())
1272 .gpu_ctx_with(|err| format!("JIT row reweight launch({kernel_name}): {err}"))
1273}
1274
1275#[cfg(target_os = "linux")]
1291pub fn launch_solve_row_on_stream(
1292 backend: &PirlsRowBackend,
1293 family: PirlsRowFamily,
1294 curvature: CurvatureMode,
1295 gamma_shape: f64,
1296 stream: &Arc<cudarc::driver::CudaStream>,
1297 n: usize,
1298 eta_dev: &cudarc::driver::CudaSlice<f64>,
1299 y_dev: &cudarc::driver::CudaSlice<f64>,
1300 prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1301 out: &mut SolveRowBuffers,
1302) -> Result<(), GpuError> {
1303 use cudarc::driver::{LaunchConfig, PushKernelArg};
1304 if out.n != n {
1305 gam_gpu::gpu_bail!("solve-row buffers shape {} mismatches n={n}", out.n);
1306 }
1307 let module = backend.module_for_solve(family, curvature)?;
1308 let kernel_name = family.solve_kernel_name();
1309 let func = module
1310 .load_function(kernel_name)
1311 .gpu_ctx_with(|err| format!("solve-row load_function({kernel_name}): {err}"))?;
1312 const THREADS_PER_BLOCK: u32 = 256;
1313 let n_u32 =
1314 u32::try_from(n).map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for solve-row grid sizing"))?;
1315 let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1316 let n_i32 = i32::try_from(n)
1317 .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for solve-row kernel argument"))?;
1318 let cfg = LaunchConfig {
1319 grid_dim: (grid_x, 1, 1),
1320 block_dim: (THREADS_PER_BLOCK, 1, 1),
1321 shared_mem_bytes: 0,
1322 };
1323 let mut builder = stream.launch_builder(&func);
1324 builder.arg(&n_i32);
1325 builder.arg(eta_dev);
1326 builder.arg(y_dev);
1327 builder.arg(prior_w_dev);
1328 if matches!(family, PirlsRowFamily::GammaLog) {
1330 builder.arg(&gamma_shape);
1331 }
1332 builder.arg(&mut out.grad_eta);
1333 builder.arg(&mut out.w_solver);
1334 builder.arg(&mut out.deviance);
1335 builder.arg(&mut out.status);
1336 unsafe { builder.launch(cfg) }
1343 .map(|_event_pair| ())
1344 .gpu_ctx_with(|err| format!("solve-row launch({kernel_name}): {err}"))
1345}
1346
/// Enqueues the alpha-ladder kernel for `family` / `curvature` on `stream`.
///
/// The launch grid is two-dimensional. `grid_dim.x` covers the `n` rows in
/// blocks of `THREADS_PER_BLOCK` threads. `grid_dim.y` has one block row per
/// ladder step (`ALPHA_LADDER_LEN`). Each thread evaluates the family body at
/// `eta + alpha * xd` for its step. It then accumulates its deviance into
/// `out.objective_dev[k]` and ORs its status bits into `out.status_dev[k]`
/// (see the kernel emitted by `ladder_source_for`).
///
/// # Arguments
/// * `backend` - supplies the compiled ladder module via `module_for_ladder`.
/// * `gamma_shape` - pushed as a kernel argument only for `PirlsRowFamily::GammaLog`.
/// * `stream` - stream the launch is enqueued on; this function does not synchronize.
/// * `n` - number of rows; must fit in both `u32` (grid sizing) and `i32` (kernel argument).
/// * `eta_dev`, `xd_dev`, `y_dev`, `prior_w_dev` - per-row device inputs
///   (presumably at least `n` elements each; not checked here).
/// * `out` - per-step device accumulators for the objective and status bits.
///
/// # Errors
/// Returns a `GpuError` in any of these cases: the ladder module cannot be
/// obtained, the kernel cannot be loaded, `n` overflows `u32`/`i32`, or the
/// launch itself fails.
///
/// NOTE(review): nothing here zeroes `out.objective_dev` / `out.status_dev`,
/// while the kernel only accumulates into them. Presumably the caller resets
/// them (or intentionally accumulates across launches) — confirm.
#[cfg(target_os = "linux")]
pub fn launch_alpha_ladder_on_stream(
    backend: &PirlsRowBackend,
    family: PirlsRowFamily,
    curvature: CurvatureMode,
    gamma_shape: f64,
    stream: &Arc<cudarc::driver::CudaStream>,
    n: usize,
    eta_dev: &cudarc::driver::CudaSlice<f64>,
    xd_dev: &cudarc::driver::CudaSlice<f64>,
    y_dev: &cudarc::driver::CudaSlice<f64>,
    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
    out: &mut AlphaLadderDevBuffers,
) -> Result<(), GpuError> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    let module = backend.module_for_ladder(family, curvature)?;
    let kernel_name = family.ladder_kernel_name();
    let func = module
        .load_function(kernel_name)
        .gpu_ctx_with(|err| format!("alpha-ladder load_function({kernel_name}): {err}"))?;
    const THREADS_PER_BLOCK: u32 = 256;
    let n_u32 =
        u32::try_from(n).map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for alpha-ladder grid sizing"))?;
    // At least one block so the grid is never empty when n == 0; the kernel's
    // `i >= n` guard makes the extra threads no-ops.
    let row_blocks = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
    let n_i32 = i32::try_from(n)
        .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for alpha-ladder kernel argument"))?;
    let cfg = LaunchConfig {
        // y axis selects the ladder step; the kernel reads PIRLS_ALPHAS[blockIdx.y].
        grid_dim: (row_blocks, ALPHA_LADDER_LEN as u32, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // Argument order must match the signature emitted by `ladder_source_for`:
    // n, eta, xd, y, prior_w, [shape], objective_out, status_out.
    let mut builder = stream.launch_builder(&func);
    builder.arg(&n_i32);
    builder.arg(eta_dev);
    builder.arg(xd_dev);
    builder.arg(y_dev);
    builder.arg(prior_w_dev);
    if matches!(family, PirlsRowFamily::GammaLog) {
        // Only Gamma-log ladder kernels declare the `double shape` parameter.
        builder.arg(&gamma_shape);
    }
    builder.arg(&mut out.objective_dev);
    builder.arg(&mut out.status_dev);
    // SAFETY: the pushed arguments match the generated kernel signature in type
    // and order, and row indices are bounded by `n_i32` inside the kernel.
    // NOTE(review): soundness also assumes ALPHA_LADDER_LEN does not exceed the
    // 7-entry PIRLS_ALPHAS table and that the input/output buffers hold at
    // least `n` / ALPHA_LADDER_LEN elements — confirm at the call sites.
    unsafe { builder.launch(cfg) }
        .map(|_event_pair| ())
        .gpu_ctx_with(|err| format!("alpha-ladder launch({kernel_name}): {err}"))
}
1415
/// CUDA device prolog spliced into every generated PIRLS kernel source (the
/// full row kernel, the solve-row kernel and the alpha-ladder kernel).
///
/// It provides:
/// - `extern "C"` declarations for the libm-style math functions the family
///   bodies call;
/// - `clamp_eta`: clamps η to [-700, 700] and sets status bit `0x1` when it
///   clamps, keeping `exp(η)` finite in double precision;
/// - `bernoulli_deviance`: weighted Bernoulli deviance using the
///   `0 · log 0 = 0` convention at y = 0 and y = 1. It returns 0 for zero weight;
/// - `bernoulli_z`: working response η + (y − μ)/(dμ/dη). It falls back to η
///   when the derivative is non-positive or non-finite, or when the step is
///   non-finite;
/// - `std_norm_cdf` (Φ via `erfc`) and `std_norm_pdf` (φ via `exp`).
///
/// NOTE(review): the cloglog body calls `expm1` and several bodies call
/// `fmin`/`fmax`/`isfinite`, none of which are declared here. They presumably
/// resolve through the compiler's built-in device math functions — confirm.
/// Consider also declaring `expm1` next to the others for consistency.
#[cfg(target_os = "linux")]
const COMMON_DEVICE_PROLOG: &str = r#"
extern "C" {
    double exp(double);
    double log(double);
    double log1p(double);
    double tanh(double);
    double sqrt(double);
    double fabs(double);
    double erfc(double);
}

__device__ __forceinline__ double clamp_eta(double eta, unsigned int* flags) {
    const double E = 700.0;
    if (eta > E) { *flags |= 0x1u; return E; }
    if (eta < -E) { *flags |= 0x1u; return -E; }
    return eta;
}

__device__ __forceinline__ double bernoulli_deviance(double y, double mu, double w) {
    if (w == 0.0) return 0.0;
    double t1 = (y > 0.0) ? y * log(y / mu) : 0.0;
    double t2 = (y < 1.0) ? (1.0 - y) * log((1.0 - y) / (1.0 - mu)) : 0.0;
    return 2.0 * w * (t1 + t2);
}

__device__ __forceinline__ double bernoulli_z(double eta, double y, double mu, double dmu_deta) {
    if (dmu_deta > 0.0 && isfinite(dmu_deta)) {
        double delta = (y - mu) / dmu_deta;
        if (isfinite(delta)) return eta + delta;
    }
    return eta;
}

__device__ __forceinline__ double std_norm_cdf(double x) {
    return 0.5 * erfc(-x * 0.7071067811865475);
}

__device__ __forceinline__ double std_norm_pdf(double x) {
    return 0.3989422804014327 * exp(-0.5 * x * x);
}
"#;
1463
1464#[cfg(target_os = "linux")]
1472fn cuda_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
1473 let body = match family {
1474 PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
1475 PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
1476 PirlsRowFamily::GammaLog => gamma_log_body(curvature),
1477 PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
1478 PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
1479 PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
1480 };
1481 let kernel_name = family.kernel_name();
1482 let curvature_define = match curvature {
1487 CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
1488 CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
1489 };
1490 let shape_param = if matches!(family, PirlsRowFamily::GammaLog) {
1493 " double shape,\n"
1494 } else {
1495 ""
1496 };
1497 format!(
1498 r#"
1499{curvature_define}
1500{prolog}
1501
1502extern "C" __global__ void {kernel_name}(
1503 int n,
1504 const double* __restrict__ eta,
1505 const double* __restrict__ y,
1506 const double* __restrict__ prior_w,
1507{shape_param} double* __restrict__ mu_out,
1508 double* __restrict__ grad_eta_out,
1509 double* __restrict__ w_fisher_out,
1510 double* __restrict__ w_hessian_out,
1511 double* __restrict__ w_solver_out,
1512 double* __restrict__ z_fisher_out,
1513 double* __restrict__ z_hessian_out,
1514 double* __restrict__ deviance_out,
1515 unsigned int* __restrict__ status_out
1516) {{
1517 int i = blockIdx.x * blockDim.x + threadIdx.x;
1518 if (i >= n) return;
1519 unsigned int flags = 0u;
1520 double eta_i = eta[i];
1521 double y_i = y[i];
1522 double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
1523 if (prior_w[i] <= 0.0) flags |= 0x10u;
1524{body}
1525 mu_out[i] = mu;
1526 grad_eta_out[i] = grad_eta;
1527 w_fisher_out[i] = w_fisher;
1528 w_hessian_out[i] = w_hessian;
1529 w_solver_out[i] = w_solver;
1530 z_fisher_out[i] = z_f;
1531 z_hessian_out[i] = z_h;
1532 deviance_out[i] = dev;
1533 status_out[i] = flags;
1534}}
1535"#,
1536 prolog = COMMON_DEVICE_PROLOG,
1537 )
1538}
1539
1540#[cfg(target_os = "linux")]
1545#[inline]
1546fn curvature_tag(curvature: CurvatureMode) -> &'static str {
1547 match curvature {
1548 CurvatureMode::Fisher => " // curvature: fisher\n",
1549 CurvatureMode::Observed => " // curvature: observed\n",
1550 }
1551}
1552
1553#[cfg(target_os = "linux")]
1554fn gaussian_identity_body(curvature: CurvatureMode) -> String {
1555 let tag = curvature_tag(curvature);
1556 format!(
1557 r#"{tag} double mu = eta_i;
1558 double resid = y_i - mu;
1559 double grad_eta = wp * resid;
1560 double w_fisher = wp;
1561 double w_hessian = wp;
1562 double w_solver = (wp > 0.0) ? fmax(wp, 1e-12) : 0.0;
1563 double z_f = y_i;
1564 double z_h = y_i;
1565 double dev = wp * resid * resid;
1566"#
1567 )
1568}
1569
1570#[cfg(target_os = "linux")]
1571fn poisson_log_body(curvature: CurvatureMode) -> String {
1572 let tag = curvature_tag(curvature);
1573 format!(
1574 r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
1575 double mu_raw = exp(eta_c);
1576 if (mu_raw < 1e-10) flags |= 0x2u;
1577 double mu = (mu_raw > 1e-10) ? mu_raw : 1e-10;
1578 double raw_w = wp * mu;
1579 double w_fisher = (raw_w > 0.0) ? fmax(raw_w, 1e-12) : 0.0;
1580 double resid = y_i - mu;
1581 double grad_eta = wp * resid;
1582 double w_hessian = w_fisher;
1583 double w_solver = w_fisher;
1584 double z_f = eta_c + resid / mu;
1585 double z_h = z_f;
1586 double dev_term = (y_i > 0.0) ? (y_i * log(y_i / mu) - resid) : (-resid);
1587 double dev = 2.0 * wp * dev_term;
1588 if (!(isfinite(y_i) && y_i >= 0.0)) flags |= 0x8u;
1589"#
1590 )
1591}
1592
1593#[cfg(target_os = "linux")]
1594fn gamma_log_body(curvature: CurvatureMode) -> String {
1595 let tag = curvature_tag(curvature);
1598 format!(
1599 r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
1600 double mu_raw = exp(eta_c);
1601 if (mu_raw < 1e-10) flags |= 0x2u;
1602 double mu = (mu_raw > 1e-10) ? mu_raw : 1e-10;
1603 double w_fisher = wp * shape;
1604#ifdef PIRLS_CURVATURE_OBSERVED
1605 // Stage 5: observed information for Gamma-log.
1606 // w_obs = w_F + w_F · (y/μ − 1) = w_F · y/μ.
1607 double w_hessian = (w_fisher > 0.0 && mu > 0.0 && isfinite(y_i))
1608 ? w_fisher * (y_i / mu)
1609 : w_fisher;
1610#else
1611 double w_hessian = w_fisher;
1612#endif
1613 double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
1614 double resid = y_i - mu;
1615 double grad_eta = wp * resid / mu;
1616 double z_f = eta_c + resid / mu;
1617 double z_h = z_f;
1618 double dev = (y_i > 0.0)
1619 ? (2.0 * wp * (-log(y_i / mu) + resid / mu))
1620 : (1.0 / 0.0);
1621 if (!(isfinite(y_i) && y_i > 0.0)) flags |= 0x8u;
1622"#
1623 )
1624}
1625
1626#[cfg(target_os = "linux")]
1627fn bernoulli_logit_body(curvature: CurvatureMode) -> String {
1628 let tag = curvature_tag(curvature);
1629 format!(
1630 r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
1631 double half = 0.5 * eta_c;
1632 double mu_raw = 0.5 * (1.0 + tanh(half));
1633 if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
1634 double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
1635 double dmu_deta = mu * (1.0 - mu);
1636 double w_fisher = wp * dmu_deta;
1637 double w_hessian = w_fisher;
1638 double w_solver = (w_fisher > 0.0) ? fmax(w_fisher, 1e-12) : 0.0;
1639 double resid = y_i - mu;
1640 double grad_eta = wp * resid;
1641 double dev = bernoulli_deviance(y_i, mu, wp);
1642 double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
1643 double z_h = z_f;
1644 if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
1645"#
1646 )
1647}
1648
1649#[cfg(target_os = "linux")]
1650fn bernoulli_probit_body(curvature: CurvatureMode) -> String {
1651 let tag = curvature_tag(curvature);
1652 format!(
1653 r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
1654 double mu_raw = std_norm_cdf(eta_c);
1655 if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
1656 double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
1657 double dmu_deta = std_norm_pdf(eta_c);
1658 double v = mu * (1.0 - mu);
1659 double fpp = (v > 0.0) ? dmu_deta * dmu_deta / v : 0.0;
1660 double w_fisher = wp * fpp;
1661#ifdef PIRLS_CURVATURE_OBSERVED
1662 // Stage 5: observed information for Bernoulli probit.
1663 // w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
1664 // h(η)=φ(η), h'(η)=−η·φ(η); V'=1−2μ.
1665 double w_hessian = w_fisher;
1666 if (v > 0.0 && wp > 0.0) {{
1667 double h_prime = -eta_c * dmu_deta;
1668 double v_prime = 1.0 - 2.0 * mu;
1669 double bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
1670 w_hessian = w_fisher + wp * (y_i - mu) * bracket;
1671 }}
1672#else
1673 double w_hessian = w_fisher;
1674#endif
1675 double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
1676 double resid = y_i - mu;
1677 double grad_eta = (v > 0.0) ? wp * resid * dmu_deta / v : 0.0;
1678 double dev = bernoulli_deviance(y_i, mu, wp);
1679 double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
1680 double z_h = z_f;
1681 if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
1682"#
1683 )
1684}
1685
1686#[cfg(target_os = "linux")]
1687fn bernoulli_cloglog_body(curvature: CurvatureMode) -> String {
1688 let tag = curvature_tag(curvature);
1689 format!(
1690 r#"{tag} double eta_c = clamp_eta(eta_i, &flags);
1691 double inner = exp(eta_c);
1692 // μ = 1 − exp(−exp(η)); use -expm1(-inner) to avoid catastrophic
1693 // cancellation in the deep negative tail (η ≲ -36).
1694 double mu_raw = -expm1(-inner);
1695 if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
1696 double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
1697 double dmu_deta = inner * (1.0 - mu_raw);
1698 double v = mu * (1.0 - mu);
1699 double fpp = (v > 0.0) ? dmu_deta * dmu_deta / v : 0.0;
1700 double w_fisher = wp * fpp;
1701#ifdef PIRLS_CURVATURE_OBSERVED
1702 // Stage 5: observed information for Bernoulli cloglog.
1703 // w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
1704 // h'(η) = h(η) · (1 − inner); V'=1−2μ.
1705 double w_hessian = w_fisher;
1706 if (v > 0.0 && wp > 0.0) {{
1707 double h_prime = dmu_deta * (1.0 - inner);
1708 double v_prime = 1.0 - 2.0 * mu;
1709 double bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
1710 w_hessian = w_fisher + wp * (y_i - mu) * bracket;
1711 }}
1712#else
1713 double w_hessian = w_fisher;
1714#endif
1715 double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
1716 double resid = y_i - mu;
1717 double grad_eta = (v > 0.0) ? wp * resid * dmu_deta / v : 0.0;
1718 double dev = bernoulli_deviance(y_i, mu, wp);
1719 double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
1720 double z_h = z_f;
1721 if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
1722"#
1723 )
1724}
1725
1726#[cfg(target_os = "linux")]
1740fn solve_row_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
1741 let body = match family {
1742 PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
1743 PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
1744 PirlsRowFamily::GammaLog => gamma_log_body(curvature),
1745 PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
1746 PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
1747 PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
1748 };
1749 let kernel_name = family.solve_kernel_name();
1750 let curvature_define = match curvature {
1751 CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
1752 CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
1753 };
1754 let shape_param = if matches!(family, PirlsRowFamily::GammaLog) {
1756 " double shape,\n"
1757 } else {
1758 ""
1759 };
1760 format!(
1761 r#"
1762{curvature_define}
1763{prolog}
1764
1765extern "C" __global__ void {kernel_name}(
1766 int n,
1767 const double* __restrict__ eta,
1768 const double* __restrict__ y,
1769 const double* __restrict__ prior_w,
1770{shape_param} double* __restrict__ grad_eta_out,
1771 double* __restrict__ w_solver_out,
1772 double* __restrict__ deviance_out,
1773 unsigned int* __restrict__ status_out
1774) {{
1775 int i = blockIdx.x * blockDim.x + threadIdx.x;
1776 if (i >= n) return;
1777 unsigned int flags = 0u;
1778 double eta_i = eta[i];
1779 double y_i = y[i];
1780 double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
1781 if (prior_w[i] <= 0.0) flags |= 0x10u;
1782{body}
1783 grad_eta_out[i] = grad_eta;
1784 w_solver_out[i] = w_solver;
1785 deviance_out[i] = dev;
1786 status_out[i] = flags;
1787}}
1788"#,
1789 prolog = COMMON_DEVICE_PROLOG,
1790 )
1791}
1792
/// CUDA `__constant__` table of the step sizes evaluated by the alpha-ladder
/// kernel: 2^0 down to 2^-6 (seven entries, halving each time).
///
/// The ladder kernel indexes this table with `blockIdx.y`, and
/// `launch_alpha_ladder_on_stream` launches `ALPHA_LADDER_LEN` blocks along y.
/// `ALPHA_LADDER_LEN` must therefore not exceed 7. NOTE(review): confirm the
/// two are kept equal when either changes.
#[cfg(target_os = "linux")]
const ALPHA_LADDER_CUDA_ARRAY: &str =
    "__constant__ double PIRLS_ALPHAS[7] = {1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625};";
1802
1803#[cfg(target_os = "linux")]
1818fn ladder_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
1819 let body = match family {
1820 PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
1821 PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
1822 PirlsRowFamily::GammaLog => gamma_log_body(curvature),
1823 PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
1824 PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
1825 PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
1826 };
1827 let kernel_name = family.ladder_kernel_name();
1828 let curvature_define = match curvature {
1829 CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
1830 CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
1831 };
1832 let shape_param = if matches!(family, PirlsRowFamily::GammaLog) {
1841 " double shape,\n"
1842 } else {
1843 ""
1844 };
1845 format!(
1846 r#"
1847{curvature_define}
1848{prolog}
1849{alphas}
1850
1851extern "C" __global__ void {kernel_name}(
1852 int n,
1853 const double* __restrict__ eta,
1854 const double* __restrict__ xd,
1855 const double* __restrict__ y,
1856 const double* __restrict__ prior_w,
1857{shape_param} double* __restrict__ objective_out,
1858 unsigned int* __restrict__ status_out
1859) {{
1860 int i = blockIdx.x * blockDim.x + threadIdx.x;
1861 int k = (int)blockIdx.y;
1862 if (i >= n) return;
1863 unsigned int flags = 0u;
1864 double alpha = PIRLS_ALPHAS[k];
1865 double eta_i = eta[i] + alpha * xd[i];
1866 double y_i = y[i];
1867 double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
1868 if (prior_w[i] <= 0.0) flags |= 0x10u;
1869{body}
1870 atomicAdd(&objective_out[k], dev);
1871 atomicOr(&status_out[k], flags);
1872}}
1873"#,
1874 prolog = COMMON_DEVICE_PROLOG,
1875 alphas = ALPHA_LADDER_CUDA_ARRAY,
1876 )
1877}
1878
1879#[cfg(test)]
1884mod pirls_row_gpu_tests {
1885 use super::*;
1886
1887 fn assert_close(label: &str, got: f64, expected: f64, tol: f64) {
1888 if !(got.is_finite() && expected.is_finite()) {
1889 assert_eq!(
1890 got.is_finite(),
1891 expected.is_finite(),
1892 "{label}: finiteness disagrees (got={got}, expected={expected})"
1893 );
1894 return;
1895 }
1896 let diff = (got - expected).abs();
1897 let denom = expected.abs().max(1.0);
1898 assert!(
1899 diff <= tol * denom,
1900 "{label}: |{got} - {expected}| = {diff} exceeds tol {tol} (rel denom {denom})"
1901 );
1902 }
1903
1904 fn check_family_matches_cpu_reference(family: PirlsRowFamily) {
1905 let etas = [-700.0, -3.0, -0.5, 0.0, 0.5, 3.0, 700.0];
1906 let ys = match family {
1907 PirlsRowFamily::GammaLog => vec![0.5, 1.0, 2.5],
1908 PirlsRowFamily::PoissonLog => vec![0.0, 1.0, 5.0],
1909 PirlsRowFamily::GaussianIdentity => vec![-1.5, 0.0, 2.0],
1910 _ => vec![0.0, 1.0],
1911 };
1912 let ws = [0.0, 1.0, 2.5];
1913 for &eta in &etas {
1914 for &y in &ys {
1915 for &wp in &ws {
1916 let input = RowInput {
1917 eta,
1918 y,
1919 prior_weight: wp,
1920 };
1921 let out = row_reweight_cpu(family, CurvatureMode::Fisher, input, 1.0);
1922 assert!(
1924 out.w_fisher >= 0.0,
1925 "{family:?}: w_fisher must be non-negative (got {})",
1926 out.w_fisher
1927 );
1928 assert!(
1929 out.w_solver >= 0.0,
1930 "{family:?}: w_solver must be non-negative (got {})",
1931 out.w_solver
1932 );
1933 if wp > 0.0 && out.w_hessian > 0.0 {
1934 assert!(
1935 out.w_solver >= W_SOLVER_FLOOR,
1936 "{family:?}: w_solver must be floored away from zero when positive (got {})",
1937 out.w_solver
1938 );
1939 }
1940 if (out.status & status_flags::ETA_CLAMPED) != 0 {
1944 continue;
1945 }
1946 if out.w_fisher > 0.0 && out.z_fisher.is_finite() {
1947 let reconstructed = out.w_fisher * (out.z_fisher - eta);
1948 if reconstructed.is_finite() {
1954 let denom = reconstructed.abs().max(out.grad_eta.abs()).max(1.0);
1955 let diff = (reconstructed - out.grad_eta).abs() / denom;
1956 assert!(
1957 diff < 1.0e-6,
1958 "{family:?} eta={eta} y={y} wp={wp}: grad_eta {} vs w·(z−η) {} differ by rel {}",
1959 out.grad_eta,
1960 reconstructed,
1961 diff
1962 );
1963 }
1964 }
1965 if out.status & status_flags::INVALID_RESPONSE == 0 && wp >= 0.0 {
1967 assert!(
1968 out.deviance >= 0.0 || !out.deviance.is_finite(),
1969 "{family:?} eta={eta} y={y} wp={wp}: deviance must be non-negative for valid inputs (got {})",
1970 out.deviance
1971 );
1972 }
1973 if out.status
1976 & (status_flags::INVALID_RESPONSE | status_flags::ZERO_PRIOR_WEIGHT)
1977 == 0
1978 {
1979 assert!(
1980 out.mu.is_finite(),
1981 "{family:?} eta={eta} y={y} wp={wp}: mu must be finite for valid inputs"
1982 );
1983 assert!(
1984 out.grad_eta.is_finite(),
1985 "{family:?} eta={eta} y={y} wp={wp}: grad_eta must be finite for valid inputs"
1986 );
1987 }
1988 }
1989 }
1990 }
1991 assert_close("self", 0.0, 0.0, 0.0);
1995 }
1996
1997 fn count_active_rows(family: PirlsRowFamily) -> usize {
2001 let mut active = 0usize;
2002 for &eta in [-700.0, -3.0, 0.0, 3.0, 700.0].iter() {
2003 for &y in [0.0, 0.5, 1.0].iter() {
2004 for &wp in [1.0, 2.5].iter() {
2005 let out = row_reweight_cpu(
2006 family,
2007 CurvatureMode::Fisher,
2008 RowInput {
2009 eta,
2010 y,
2011 prior_weight: wp,
2012 },
2013 1.0,
2014 );
2015 if out.w_fisher > 0.0 {
2016 active += 1;
2017 }
2018 }
2019 }
2020 }
2021 active
2022 }
2023
    /// Gaussian-identity: CPU row invariants hold and at least one grid row is active.
    #[test]
    fn gaussian_identity_row_invariants() {
        check_family_matches_cpu_reference(PirlsRowFamily::GaussianIdentity);
        assert!(count_active_rows(PirlsRowFamily::GaussianIdentity) > 0);
    }

    /// Poisson-log: CPU row invariants hold and at least one grid row is active.
    #[test]
    fn poisson_log_row_invariants() {
        check_family_matches_cpu_reference(PirlsRowFamily::PoissonLog);
        assert!(count_active_rows(PirlsRowFamily::PoissonLog) > 0);
    }

    /// Gamma-log: CPU row invariants hold and at least one grid row is active.
    #[test]
    fn gamma_log_row_invariants() {
        check_family_matches_cpu_reference(PirlsRowFamily::GammaLog);
        assert!(count_active_rows(PirlsRowFamily::GammaLog) > 0);
    }

    /// Bernoulli-logit: CPU row invariants hold and at least one grid row is active.
    #[test]
    fn bernoulli_logit_row_invariants() {
        check_family_matches_cpu_reference(PirlsRowFamily::BernoulliLogit);
        assert!(count_active_rows(PirlsRowFamily::BernoulliLogit) > 0);
    }

    /// Bernoulli-probit: CPU row invariants hold and at least one grid row is active.
    #[test]
    fn bernoulli_probit_row_invariants() {
        check_family_matches_cpu_reference(PirlsRowFamily::BernoulliProbit);
        assert!(count_active_rows(PirlsRowFamily::BernoulliProbit) > 0);
    }

    /// Bernoulli-cloglog: CPU row invariants hold and at least one grid row is active.
    #[test]
    fn bernoulli_cloglog_row_invariants() {
        check_family_matches_cpu_reference(PirlsRowFamily::BernoulliCLogLog);
        assert!(count_active_rows(PirlsRowFamily::BernoulliCLogLog) > 0);
    }
2059
2060 #[test]
2062 fn gaussian_identity_matches_explicit_formulas() {
2063 let out = row_reweight_cpu(
2064 PirlsRowFamily::GaussianIdentity,
2065 CurvatureMode::Fisher,
2066 RowInput {
2067 eta: 0.25,
2068 y: 1.0,
2069 prior_weight: 2.0,
2070 },
2071 1.0,
2072 );
2073 assert!(out.mu.is_finite() && out.deviance.is_finite());
2074 assert_close("mu", out.mu, 0.25, 0.0);
2075 assert_close("grad_eta", out.grad_eta, 2.0 * (1.0 - 0.25), 1e-15);
2076 assert_close("w_fisher", out.w_fisher, 2.0, 0.0);
2077 assert_close(
2078 "deviance",
2079 out.deviance,
2080 2.0 * (1.0 - 0.25_f64).powi(2),
2081 1e-15,
2082 );
2083 }
2084
2085 #[test]
2087 fn poisson_log_matches_explicit_formulas() {
2088 let out = row_reweight_cpu(
2089 PirlsRowFamily::PoissonLog,
2090 CurvatureMode::Fisher,
2091 RowInput {
2092 eta: 1.5,
2093 y: 4.0,
2094 prior_weight: 1.0,
2095 },
2096 1.0,
2097 );
2098 let expected_mu = (1.5_f64).exp();
2099 assert!(expected_mu.is_finite() && out.mu.is_finite());
2100 assert_close("mu", out.mu, expected_mu, 1e-15);
2101 assert_close("grad_eta", out.grad_eta, 4.0 - expected_mu, 1e-15);
2102 assert_close("w_fisher", out.w_fisher, expected_mu, 1e-15);
2103 }
2104
2105 #[test]
2107 fn bernoulli_logit_matches_explicit_formulas() {
2108 let eta: f64 = 0.7;
2109 let mu = 1.0 / (1.0 + (-eta).exp());
2110 let out = row_reweight_cpu(
2111 PirlsRowFamily::BernoulliLogit,
2112 CurvatureMode::Fisher,
2113 RowInput {
2114 eta,
2115 y: 1.0,
2116 prior_weight: 3.0,
2117 },
2118 1.0,
2119 );
2120 assert!(mu > 0.0 && mu < 1.0);
2121 assert_close("mu", out.mu, mu, 1e-12);
2122 assert_close("w_fisher", out.w_fisher, 3.0 * mu * (1.0 - mu), 1e-12);
2123 assert_close("grad_eta", out.grad_eta, 3.0 * (1.0 - mu), 1e-12);
2124 }
2125
2126 #[test]
2128 fn eta_clamp_status_flag_trips() {
2129 let out = row_reweight_cpu(
2130 PirlsRowFamily::PoissonLog,
2131 CurvatureMode::Fisher,
2132 RowInput {
2133 eta: 1000.0,
2134 y: 0.0,
2135 prior_weight: 1.0,
2136 },
2137 1.0,
2138 );
2139 assert!(out.status & status_flags::ETA_CLAMPED != 0);
2140 }
2141
    /// `PirlsRowBackend::compiled()` must be true exactly on Linux. When a CUDA
    /// runtime is present, each family's Fisher-curvature module must compile
    /// via `module_for`. A second lookup must return the identical cached `Arc`
    /// handle.
    #[test]
    fn backend_compiles_one_module_per_family_when_device_present() {
        assert_eq!(PirlsRowBackend::compiled(), cfg!(target_os = "linux"));
        // Without a device runtime there is nothing to compile against.
        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
            eprintln!("[pirls_row_gpu test] no CUDA runtime — skipping device compile test");
            return;
        }
        #[cfg(target_os = "linux")]
        {
            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
            for &family in PirlsRowFamily::ALL.iter() {
                let m1 = backend
                    .module_for(family, CurvatureMode::Fisher)
                    .unwrap_or_else(|err| panic!("compile {family:?}: {err}"));
                // The second fetch must be served from the module cache.
                let m2 = backend
                    .module_for(family, CurvatureMode::Fisher)
                    .unwrap_or_else(|err| panic!("re-fetch {family:?}: {err}"));
                assert!(
                    Arc::ptr_eq(&m1, &m2),
                    "{family:?}: module cache must return same handle on second call"
                );
            }
        }
    }
2171
    /// Stage-6 JIT parity for Bernoulli-logit with Fisher curvature. The kernel
    /// built from `JitFamilySpec::glm` must produce outputs bit-identical to the
    /// built-in kernel on the same device inputs. Skips when no CUDA runtime
    /// exists.
    #[test]
    fn jit_glm_kernel_matches_builtin_byte_identical() {
        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
            eprintln!("[stage_6_jit] no CUDA runtime — skipping");
            return;
        }
        #[cfg(target_os = "linux")]
        {
            // Four rows covering both responses and a few prior weights.
            let etas = [-2.0_f64, -0.5, 0.3, 1.5];
            let ys = [0.0_f64, 1.0, 0.0, 1.0];
            let priors = [1.0_f64, 1.2, 0.8, 1.5];
            let n = etas.len();
            let family = PirlsRowFamily::BernoulliLogit;
            let curvature = CurvatureMode::Fisher;
            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
                .expect("shared backend probe")
                .stream;

            // Upload the inputs once; both launches read the same device buffers.
            let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("eta");
            let mut y_dev = stream.alloc_zeros::<f64>(n).expect("y");
            let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("prior");
            stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
            stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
            stream
                .memcpy_htod(&priors, &mut prior_dev)
                .expect("up prior");

            let mut out_builtin = RowOutputDevBuffers::allocate(&stream, n).expect("alloc builtin");
            launch_row_reweight_on_stream(
                backend,
                family,
                curvature,
                1.0,
                &stream,
                n,
                &eta_dev,
                &y_dev,
                &prior_dev,
                &mut out_builtin,
            )
            .expect("builtin launch");

            // 0x424c_4c47 is the ASCII tag "BLLG"; presumably it identifies the
            // spec (e.g. as a JIT cache key) — confirm.
            let spec = JitFamilySpec::glm(0x424c_4c47u64, family, curvature);
            let mut out_jit = RowOutputDevBuffers::allocate(&stream, n).expect("alloc jit");
            launch_row_reweight_jit_on_stream(
                backend,
                &spec,
                curvature,
                &stream,
                n,
                &eta_dev,
                &y_dev,
                &prior_dev,
                &mut out_jit,
            )
            .expect("jit launch");
            stream.synchronize().expect("sync");

            // Compare every f64 output bit-for-bit (the status buffer is not compared).
            for (label, b_dev, j_dev) in [
                ("mu", &out_builtin.mu, &out_jit.mu),
                ("grad_eta", &out_builtin.grad_eta, &out_jit.grad_eta),
                ("w_fisher", &out_builtin.w_fisher, &out_jit.w_fisher),
                ("w_hessian", &out_builtin.w_hessian, &out_jit.w_hessian),
                ("w_solver", &out_builtin.w_solver, &out_jit.w_solver),
                ("z_fisher", &out_builtin.z_fisher, &out_jit.z_fisher),
                ("z_hessian", &out_builtin.z_hessian, &out_jit.z_hessian),
                ("deviance", &out_builtin.deviance, &out_jit.deviance),
            ] {
                let b = stream.clone_dtoh(b_dev).expect("dl builtin");
                let j = stream.clone_dtoh(j_dev).expect("dl jit");
                for i in 0..n {
                    assert_eq!(
                        b[i].to_bits(),
                        j[i].to_bits(),
                        "{label}[{i}]: builtin {} ≠ jit {}",
                        b[i],
                        j[i],
                    );
                }
            }
        }
    }
2264
    /// Level-b JIT parity for Gaussian-identity with Fisher curvature. The kernel
    /// built from a hand-written raw CUDA body (`JitFamilySpec::raw`) must match
    /// the built-in kernel bit-for-bit on every f64 output. It must also match
    /// `row_reweight_cpu` bit-for-bit. Skips when no CUDA runtime exists.
    #[test]
    fn jit_raw_body_kernel_matches_builtin_gaussian_byte_identical() {
        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
            eprintln!("[stage_6_jit_raw] no CUDA runtime — skipping");
            return;
        }
        #[cfg(target_os = "linux")]
        {
            let n: usize = 256;
            let mut etas = vec![0.0_f64; n];
            let mut ys = vec![0.0_f64; n];
            let mut priors = vec![0.0_f64; n];
            // t sweeps [0, 1] across the rows: η spans [-3, 3], y spans [-2.5, 2.5].
            for i in 0..n {
                let t = (i as f64) / (n as f64 - 1.0);
                etas[i] = -3.0 + 6.0 * t;
                ys[i] = 5.0 * (t - 0.5);
                // Row 7 gets zero prior weight to exercise the w_solver = 0 /
                // status 0x10 path.
                priors[i] = if i == 7 {
                    0.0
                } else {
                    0.25 + 1.75 * t
                };
            }

            let family = PirlsRowFamily::GaussianIdentity;
            let curvature = CurvatureMode::Fisher;
            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
                .expect("shared backend probe")
                .stream;

            let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("eta");
            let mut y_dev = stream.alloc_zeros::<f64>(n).expect("y");
            let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("prior");
            stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
            stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
            stream
                .memcpy_htod(&priors, &mut prior_dev)
                .expect("up prior");

            let mut out_builtin = RowOutputDevBuffers::allocate(&stream, n).expect("alloc builtin");
            launch_row_reweight_on_stream(
                backend,
                family,
                curvature,
                1.0,
                &stream,
                n,
                &eta_dev,
                &y_dev,
                &prior_dev,
                &mut out_builtin,
            )
            .expect("builtin launch");

            // Hand-written body defining the same locals the kernel template reads.
            let raw_body = r#"    // level-b raw body: gaussian identity (hand-written)
    // identity link: mu = eta
    double mu = eta_i;
    // ordinary residual on the response scale
    double resid = y_i - mu;
    // canonical score contribution
    double grad_eta = wp * resid;
    // fisher info per row: weight itself (V(mu)=1, dmu/deta=1)
    double w_fisher = wp;
    // observed == fisher for canonical identity link
    double w_hessian = wp;
    // solver weight clamps tiny positives to avoid singularity
    double w_solver = (wp > 0.0) ? fmax(wp, 1e-12) : 0.0;
    // working response equals raw response on identity link
    double z_f = y_i;
    double z_h = y_i;
    // squared-error contribution to deviance
    double dev = wp * resid * resid;
"#;
            // 0x5241_575f_4741_5553 is the ASCII tag "RAW_GAUS"; presumably it
            // identifies the spec (e.g. as a JIT cache key) — confirm.
            let spec = JitFamilySpec::raw(0x5241_575f_4741_5553u64, raw_body);
            let mut out_jit = RowOutputDevBuffers::allocate(&stream, n).expect("alloc jit");
            launch_row_reweight_jit_on_stream(
                backend,
                &spec,
                curvature,
                &stream,
                n,
                &eta_dev,
                &y_dev,
                &prior_dev,
                &mut out_jit,
            )
            .expect("jit raw launch");
            stream.synchronize().expect("sync");

            // Built-in vs JIT: every f64 output must agree bit-for-bit.
            for (label, b_dev, j_dev) in [
                ("mu", &out_builtin.mu, &out_jit.mu),
                ("grad_eta", &out_builtin.grad_eta, &out_jit.grad_eta),
                ("w_fisher", &out_builtin.w_fisher, &out_jit.w_fisher),
                ("w_hessian", &out_builtin.w_hessian, &out_jit.w_hessian),
                ("w_solver", &out_builtin.w_solver, &out_jit.w_solver),
                ("z_fisher", &out_builtin.z_fisher, &out_jit.z_fisher),
                ("z_hessian", &out_builtin.z_hessian, &out_jit.z_hessian),
                ("deviance", &out_builtin.deviance, &out_jit.deviance),
            ] {
                let b = stream.clone_dtoh(b_dev).expect("dl builtin");
                let j = stream.clone_dtoh(j_dev).expect("dl jit raw");
                for i in 0..n {
                    assert_eq!(
                        b[i].to_bits(),
                        j[i].to_bits(),
                        "{label}[{i}]: builtin {} ≠ jit-raw {}",
                        b[i],
                        j[i],
                    );
                }
            }

            // JIT vs CPU reference, also bit-for-bit. Exact agreement is
            // presumably achievable because the Gaussian body uses only −, ×
            // and fmax.
            let mu_j = stream.clone_dtoh(&out_jit.mu).expect("dl jit mu");
            let g_j = stream.clone_dtoh(&out_jit.grad_eta).expect("dl jit g");
            let wf_j = stream.clone_dtoh(&out_jit.w_fisher).expect("dl jit wf");
            let wh_j = stream.clone_dtoh(&out_jit.w_hessian).expect("dl jit wh");
            let ws_j = stream.clone_dtoh(&out_jit.w_solver).expect("dl jit ws");
            let zf_j = stream.clone_dtoh(&out_jit.z_fisher).expect("dl jit zf");
            let zh_j = stream.clone_dtoh(&out_jit.z_hessian).expect("dl jit zh");
            let d_j = stream.clone_dtoh(&out_jit.deviance).expect("dl jit d");
            for i in 0..n {
                let cpu = row_reweight_cpu(
                    PirlsRowFamily::GaussianIdentity,
                    curvature,
                    RowInput {
                        eta: etas[i],
                        y: ys[i],
                        prior_weight: priors[i],
                    },
                    1.0,
                );
                for (label, cpu_v, jit_v) in [
                    ("mu", cpu.mu, mu_j[i]),
                    ("grad_eta", cpu.grad_eta, g_j[i]),
                    ("w_fisher", cpu.w_fisher, wf_j[i]),
                    ("w_hessian", cpu.w_hessian, wh_j[i]),
                    ("w_solver", cpu.w_solver, ws_j[i]),
                    ("z_fisher", cpu.z_fisher, zf_j[i]),
                    ("z_hessian", cpu.z_hessian, zh_j[i]),
                    ("deviance", cpu.deviance, d_j[i]),
                ] {
                    assert_eq!(
                        cpu_v.to_bits(),
                        jit_v.to_bits(),
                        "{label}[{i}]: cpu {} ≠ jit-raw {}",
                        cpu_v,
                        jit_v,
                    );
                }
            }
        }
    }
2442
2443 #[test]
2449 fn observed_curvature_matches_expected_per_family() {
2450 let probe_eta = 0.4_f64;
2453 let probe_y = 1.0_f64;
2454 let wp = 1.5_f64;
2455 let input = RowInput {
2456 eta: probe_eta,
2457 y: probe_y,
2458 prior_weight: wp,
2459 };
2460
2461 for canonical in [
2463 PirlsRowFamily::GaussianIdentity,
2464 PirlsRowFamily::PoissonLog,
2465 PirlsRowFamily::BernoulliLogit,
2466 ] {
2467 let f = row_reweight_cpu(canonical, CurvatureMode::Fisher, input, 1.0);
2468 let o = row_reweight_cpu(canonical, CurvatureMode::Observed, input, 1.0);
2469 assert_eq!(
2470 f.w_hessian, o.w_hessian,
2471 "{canonical:?}: observed must equal Fisher for canonical link"
2472 );
2473 }
2474
2475 for &shape in &[1.0_f64, 2.5] {
2478 let gf = row_reweight_cpu(
2479 PirlsRowFamily::GammaLog,
2480 CurvatureMode::Fisher,
2481 input,
2482 shape,
2483 );
2484 let go = row_reweight_cpu(
2485 PirlsRowFamily::GammaLog,
2486 CurvatureMode::Observed,
2487 input,
2488 shape,
2489 );
2490 assert!(
2491 (go.w_hessian - gf.w_fisher * (probe_y / gf.mu)).abs() <= 1e-12,
2492 "Gamma-log observed mismatch (shape={shape}): got={} expected={} (mu={})",
2493 go.w_hessian,
2494 gf.w_fisher * (probe_y / gf.mu),
2495 gf.mu
2496 );
2497 assert_ne!(
2498 gf.w_hessian, go.w_hessian,
2499 "Gamma-log: observed must differ from Fisher when y ≠ μ (shape={shape})"
2500 );
2501 }
2502
2503 for noncanon in [
2507 PirlsRowFamily::BernoulliProbit,
2508 PirlsRowFamily::BernoulliCLogLog,
2509 ] {
2510 let f = row_reweight_cpu(noncanon, CurvatureMode::Fisher, input, 1.0);
2511 let o = row_reweight_cpu(noncanon, CurvatureMode::Observed, input, 1.0);
2512 assert!(
2513 (f.w_hessian - o.w_hessian).abs() > 0.0 || (probe_y - f.mu).abs() < 1e-15,
2514 "{noncanon:?}: observed should differ from Fisher when y ≠ μ"
2515 );
2516 }
2517 }
2518
2519 #[test]
2522 fn gamma_log_shape_scaling() {
2523 let input = RowInput {
2524 eta: 0.5,
2525 y: 2.0,
2526 prior_weight: 1.0,
2527 };
2528 let base = row_reweight_cpu(PirlsRowFamily::GammaLog, CurvatureMode::Fisher, input, 1.0);
2529 for &shape in &[0.5_f64, 1.5, 3.0, 10.0] {
2530 let r = row_reweight_cpu(
2531 PirlsRowFamily::GammaLog,
2532 CurvatureMode::Fisher,
2533 input,
2534 shape,
2535 );
2536 assert!(
2537 (r.w_fisher - shape * base.w_fisher).abs() <= 1e-14,
2538 "w_fisher should scale with shape: got {} expected {} (shape={shape})",
2539 r.w_fisher,
2540 shape * base.w_fisher,
2541 );
2542 assert_eq!(
2543 r.mu.to_bits(),
2544 base.mu.to_bits(),
2545 "mu must not depend on shape"
2546 );
2547 let ro = row_reweight_cpu(
2548 PirlsRowFamily::GammaLog,
2549 CurvatureMode::Observed,
2550 input,
2551 shape,
2552 );
2553 let expected_obs = r.w_fisher * (input.y / r.mu);
2554 assert!(
2555 (ro.w_hessian - expected_obs).abs() <= 1e-13,
2556 "observed w_hessian mismatch (shape={shape}): got={} expected={}",
2557 ro.w_hessian,
2558 expected_obs,
2559 );
2560 }
2561 }
2562
2563 #[test]
2571 fn launch_row_reweight_matches_cpu_reference_on_device() {
2572 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2573 eprintln!("[pirls_row_gpu test] no CUDA runtime — skipping launcher parity test");
2574 return;
2575 }
2576 #[cfg(target_os = "linux")]
2577 {
2578 let etas = [-3.0_f64, -0.5, 0.0, 0.5, 3.0, 10.0, -10.0, 1.5];
2583 let n = etas.len();
2584 let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2585 let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2586 .expect("shared backend probe")
2587 .stream;
2588
2589 for &family in PirlsRowFamily::ALL.iter() {
2590 let ys: Vec<f64> = match family {
2591 PirlsRowFamily::GammaLog | PirlsRowFamily::PoissonLog => {
2592 (0..n).map(|i| 1.0 + 0.5 * (i as f64)).collect()
2593 }
2594 PirlsRowFamily::GaussianIdentity => {
2595 (0..n).map(|i| -1.0 + 0.5 * (i as f64)).collect()
2596 }
2597 _ => (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
2598 };
2599 let priors: Vec<f64> = (0..n).map(|i| 1.0 + 0.25 * (i as f64)).collect();
2600
2601 let mut cpu_out = Vec::with_capacity(n);
2603 for i in 0..n {
2604 cpu_out.push(row_reweight_cpu(
2605 family,
2606 CurvatureMode::Fisher,
2607 RowInput {
2608 eta: etas[i],
2609 y: ys[i],
2610 prior_weight: priors[i],
2611 },
2612 1.0,
2613 ));
2614 }
2615
2616 let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("alloc eta_dev");
2618 let mut y_dev = stream.alloc_zeros::<f64>(n).expect("alloc y_dev");
2619 let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("alloc prior_dev");
2620 stream
2621 .memcpy_htod(etas.as_slice(), &mut eta_dev)
2622 .expect("upload eta");
2623 stream
2624 .memcpy_htod(ys.as_slice(), &mut y_dev)
2625 .expect("upload y");
2626 stream
2627 .memcpy_htod(priors.as_slice(), &mut prior_dev)
2628 .expect("upload prior");
2629 let mut out = RowOutputDevBuffers::allocate(&stream, n).expect("alloc row buffers");
2630 launch_row_reweight_on_stream(
2631 backend,
2632 family,
2633 CurvatureMode::Fisher,
2634 1.0,
2635 &stream,
2636 n,
2637 &eta_dev,
2638 &y_dev,
2639 &prior_dev,
2640 &mut out,
2641 )
2642 .unwrap_or_else(|err| panic!("launch {family:?}: {err}"));
2643 stream.synchronize().expect("stream sync");
2644 let mu = stream.clone_dtoh(&out.mu).expect("dl mu");
2645 let g = stream.clone_dtoh(&out.grad_eta).expect("dl grad_eta");
2646 let wf = stream.clone_dtoh(&out.w_fisher).expect("dl w_fisher");
2647 let wh = stream.clone_dtoh(&out.w_hessian).expect("dl w_hessian");
2648 let ws_v = stream.clone_dtoh(&out.w_solver).expect("dl w_solver");
2649 let zf = stream.clone_dtoh(&out.z_fisher).expect("dl z_fisher");
2650 let zh = stream.clone_dtoh(&out.z_hessian).expect("dl z_hessian");
2651 let dev = stream.clone_dtoh(&out.deviance).expect("dl deviance");
2652
2653 let tol = 1e-12;
2654 for i in 0..n {
2655 let r = cpu_out[i];
2656 assert_close(&format!("{family:?}/row{i}/mu"), mu[i], r.mu, tol);
2657 assert_close(
2658 &format!("{family:?}/row{i}/grad_eta"),
2659 g[i],
2660 r.grad_eta,
2661 tol,
2662 );
2663 assert_close(
2664 &format!("{family:?}/row{i}/w_fisher"),
2665 wf[i],
2666 r.w_fisher,
2667 tol,
2668 );
2669 assert_close(
2670 &format!("{family:?}/row{i}/w_hessian"),
2671 wh[i],
2672 r.w_hessian,
2673 tol,
2674 );
2675 assert_close(
2676 &format!("{family:?}/row{i}/w_solver"),
2677 ws_v[i],
2678 r.w_solver,
2679 tol,
2680 );
2681 assert_close(
2682 &format!("{family:?}/row{i}/z_fisher"),
2683 zf[i],
2684 r.z_fisher,
2685 tol,
2686 );
2687 assert_close(
2688 &format!("{family:?}/row{i}/z_hessian"),
2689 zh[i],
2690 r.z_hessian,
2691 tol,
2692 );
2693 assert_close(
2694 &format!("{family:?}/row{i}/deviance"),
2695 dev[i],
2696 r.deviance,
2697 tol,
2698 );
2699 }
2700 }
2701 }
2702 }
2703
2704 #[test]
2718 fn gpu_observed_parity() {
2719 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2720 eprintln!("[gpu_observed_parity] no CUDA runtime — skipping");
2721 return;
2722 }
2723 #[cfg(target_os = "linux")]
2724 {
2725 const N: usize = 256;
2726 let etas: Vec<f64> = (0..N)
2727 .map(|i| -6.0 + 12.0 * (i as f64) / ((N - 1) as f64))
2728 .collect();
2729 let priors: Vec<f64> = (0..N)
2730 .map(|i| 0.5 + 1.5 * ((i as f64) / (N as f64)))
2731 .collect();
2732
2733 let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2734 let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2735 .expect("shared backend probe")
2736 .stream;
2737
2738 for &family in PirlsRowFamily::ALL.iter() {
2739 let ys: Vec<f64> = match family {
2740 PirlsRowFamily::GammaLog => (0..N).map(|i| 0.25 + 0.05 * (i as f64)).collect(),
2741 PirlsRowFamily::PoissonLog => (0..N).map(|i| (i % 6) as f64).collect(),
2742 PirlsRowFamily::GaussianIdentity => (0..N)
2743 .map(|i| -2.0 + 4.0 * (i as f64) / ((N - 1) as f64))
2744 .collect(),
2745 _ => (0..N).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
2746 };
2747
2748 let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("alloc eta_dev");
2749 let mut y_dev = stream.alloc_zeros::<f64>(N).expect("alloc y_dev");
2750 let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("alloc prior_dev");
2751 stream
2752 .memcpy_htod(etas.as_slice(), &mut eta_dev)
2753 .expect("upload eta");
2754 stream
2755 .memcpy_htod(ys.as_slice(), &mut y_dev)
2756 .expect("upload y");
2757 stream
2758 .memcpy_htod(priors.as_slice(), &mut prior_dev)
2759 .expect("upload prior");
2760
2761 let mut out_obs = RowOutputDevBuffers::allocate(&stream, N).expect("alloc out_obs");
2762 launch_row_reweight_on_stream(
2763 backend,
2764 family,
2765 CurvatureMode::Observed,
2766 1.0,
2767 &stream,
2768 N,
2769 &eta_dev,
2770 &y_dev,
2771 &prior_dev,
2772 &mut out_obs,
2773 )
2774 .unwrap_or_else(|err| panic!("observed launch {family:?}: {err}"));
2775 stream.synchronize().expect("stream sync (observed)");
2776
2777 let wh_obs = stream
2778 .clone_dtoh(&out_obs.w_hessian)
2779 .expect("dl w_hessian (observed)");
2780 let wf_obs = stream
2781 .clone_dtoh(&out_obs.w_fisher)
2782 .expect("dl w_fisher (observed)");
2783
2784 if family.is_canonical() {
2785 for i in 0..N {
2786 assert_eq!(
2787 wh_obs[i].to_bits(),
2788 wf_obs[i].to_bits(),
2789 "{family:?} row {i}: observed w_hessian {} must bit-equal w_fisher {} on canonical link",
2790 wh_obs[i],
2791 wf_obs[i],
2792 );
2793 }
2794 } else {
2795 for i in 0..N {
2796 let cpu = row_reweight_cpu(
2797 family,
2798 CurvatureMode::Observed,
2799 RowInput {
2800 eta: etas[i],
2801 y: ys[i],
2802 prior_weight: priors[i],
2803 },
2804 1.0,
2805 );
2806 let got = wh_obs[i];
2807 let exp = cpu.w_hessian;
2808 let abs_err = (got - exp).abs();
2809 let rel_err = if exp.abs() > 0.0 {
2810 abs_err / exp.abs()
2811 } else {
2812 abs_err
2813 };
2814 assert!(
2815 abs_err <= 1.0e-12 || rel_err <= 1.0e-11,
2816 "{family:?} row {i} (eta={}, y={}, wp={}): \
2817 device w_hessian={} vs CPU observed={} (abs={}, rel={})",
2818 etas[i],
2819 ys[i],
2820 priors[i],
2821 got,
2822 exp,
2823 abs_err,
2824 rel_err,
2825 );
2826 }
2827 }
2828 }
2829 }
2830 }
2831
2832 #[test]
2841 fn gpu_observed_parity_end_to_end_n1000() {
2842 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2843 eprintln!("[gpu_observed_parity_end_to_end_n1000] no CUDA runtime — skipping");
2844 return;
2845 }
2846 #[cfg(target_os = "linux")]
2847 {
2848 const N: usize = 1000;
2849 let etas: Vec<f64> = (0..N)
2853 .map(|i| -8.0 + 16.0 * (i as f64) / ((N - 1) as f64))
2854 .collect();
2855 let priors: Vec<f64> = (0..N)
2856 .map(|i| 0.25 + 1.75 * ((i as f64) / (N as f64)))
2857 .collect();
2858
2859 let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2860 let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2861 .expect("shared backend probe")
2862 .stream;
2863
2864 const TOL: f64 = 1.0e-9;
2865
2866 for &family in PirlsRowFamily::ALL.iter() {
2867 let ys: Vec<f64> = match family {
2870 PirlsRowFamily::GammaLog => {
2871 (0..N).map(|i| 0.10 + 0.05 * ((i % 97) as f64)).collect()
2872 }
2873 PirlsRowFamily::PoissonLog => (0..N).map(|i| (i % 11) as f64).collect(),
2874 PirlsRowFamily::GaussianIdentity => (0..N)
2875 .map(|i| -3.0 + 6.0 * (i as f64) / ((N - 1) as f64))
2876 .collect(),
2877 PirlsRowFamily::BernoulliLogit
2878 | PirlsRowFamily::BernoulliProbit
2879 | PirlsRowFamily::BernoulliCLogLog => {
2880 (0..N).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect()
2881 }
2882 };
2883
2884 let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("alloc eta_dev");
2885 let mut y_dev = stream.alloc_zeros::<f64>(N).expect("alloc y_dev");
2886 let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("alloc prior_dev");
2887 stream
2888 .memcpy_htod(etas.as_slice(), &mut eta_dev)
2889 .expect("upload eta");
2890 stream
2891 .memcpy_htod(ys.as_slice(), &mut y_dev)
2892 .expect("upload y");
2893 stream
2894 .memcpy_htod(priors.as_slice(), &mut prior_dev)
2895 .expect("upload prior");
2896
2897 let mut out_obs = RowOutputDevBuffers::allocate(&stream, N).expect("alloc out_obs");
2898 launch_row_reweight_on_stream(
2899 backend,
2900 family,
2901 CurvatureMode::Observed,
2902 1.0,
2903 &stream,
2904 N,
2905 &eta_dev,
2906 &y_dev,
2907 &prior_dev,
2908 &mut out_obs,
2909 )
2910 .unwrap_or_else(|err| panic!("observed launch {family:?}: {err}"));
2911 stream.synchronize().expect("stream sync (observed)");
2912
2913 let wh_obs = stream
2914 .clone_dtoh(&out_obs.w_hessian)
2915 .expect("dl w_hessian (observed)");
2916 let ge_obs = stream
2917 .clone_dtoh(&out_obs.grad_eta)
2918 .expect("dl grad_eta (observed)");
2919
2920 for i in 0..N {
2921 let cpu = row_reweight_cpu(
2922 family,
2923 CurvatureMode::Observed,
2924 RowInput {
2925 eta: etas[i],
2926 y: ys[i],
2927 prior_weight: priors[i],
2928 },
2929 1.0,
2930 );
2931
2932 let h_got = wh_obs[i];
2934 let h_exp = cpu.w_hessian;
2935 let h_abs = (h_got - h_exp).abs();
2936 let h_rel = if h_exp.abs() > 0.0 {
2937 h_abs / h_exp.abs()
2938 } else {
2939 h_abs
2940 };
2941 assert!(
2942 h_abs <= TOL || h_rel <= TOL,
2943 "{family:?} row {i} (eta={}, y={}, wp={}): \
2944 observed w_hessian GPU={} vs CPU={} (abs={}, rel={})",
2945 etas[i],
2946 ys[i],
2947 priors[i],
2948 h_got,
2949 h_exp,
2950 h_abs,
2951 h_rel,
2952 );
2953
2954 let g_got = ge_obs[i];
2959 let g_exp = cpu.grad_eta;
2960 let g_abs = (g_got - g_exp).abs();
2961 let g_rel = if g_exp.abs() > 0.0 {
2962 g_abs / g_exp.abs()
2963 } else {
2964 g_abs
2965 };
2966 assert!(
2967 g_abs <= TOL || g_rel <= TOL,
2968 "{family:?} row {i} (eta={}, y={}, wp={}): \
2969 observed grad_eta GPU={} vs CPU={} (abs={}, rel={})",
2970 etas[i],
2971 ys[i],
2972 priors[i],
2973 g_got,
2974 g_exp,
2975 g_abs,
2976 g_rel,
2977 );
2978 }
2979 }
2980 }
2981 }
2982
2983 #[test]
2993 fn gpu_jit_level_b_raw_body_end_to_end_all_families_n1000() {
2994 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2995 eprintln!(
2996 "[gpu_jit_level_b_raw_body_end_to_end_all_families_n1000] no CUDA runtime — skipping"
2997 );
2998 return;
2999 }
3000 #[cfg(target_os = "linux")]
3001 {
3002 const N: usize = 1000;
3003 const TOL: f64 = 1.0e-10;
3004 let curvature = CurvatureMode::Fisher;
3005
3006 let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
3007 let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
3008 .expect("shared backend probe")
3009 .stream;
3010
3011 let etas: Vec<f64> = (0..N)
3016 .map(|i| -6.0 + 12.0 * (i as f64) / ((N - 1) as f64))
3017 .collect();
3018 let priors: Vec<f64> = (0..N)
3019 .map(|i| 0.25 + 1.75 * ((i as f64) / (N as f64)))
3020 .collect();
3021
3022 let raw_gaussian = r#" // raw-body gaussian identity (independent re-derivation)
3027 double resp = y_i;
3028 double pred = eta_i;
3029 double mu = pred;
3030 double w_p = wp;
3031 double e_resid = resp - pred;
3032 double grad_eta = w_p * e_resid;
3033 double w_fisher = w_p;
3034 double w_hessian = w_p;
3035 double w_solver = (w_p > 0.0) ? fmax(w_p, 1e-12) : 0.0;
3036 double z_f = resp;
3037 double z_h = resp;
3038 double dev = w_p * e_resid * e_resid;
3039"#;
3040
3041 let raw_poisson = r#" // raw-body poisson log (independent re-derivation)
3042 double eta_c = clamp_eta(eta_i, &flags);
3043 double mu_pre = exp(eta_c);
3044 if (mu_pre < 1e-10) flags |= 0x2u;
3045 double mu = (mu_pre > 1e-10) ? mu_pre : 1e-10;
3046 double wrate = wp * mu;
3047 double w_fisher = (wrate > 0.0) ? fmax(wrate, 1e-12) : 0.0;
3048 double w_hessian = w_fisher;
3049 double w_solver = w_fisher;
3050 double pres = y_i - mu;
3051 double grad_eta = wp * pres;
3052 double z_lin = eta_c + pres / mu;
3053 double z_f = z_lin;
3054 double z_h = z_lin;
3055 double dterm;
3056 if (y_i > 0.0) {
3057 dterm = y_i * log(y_i / mu) - pres;
3058 } else {
3059 dterm = -pres;
3060 }
3061 double dev = 2.0 * wp * dterm;
3062 if (!(isfinite(y_i) && y_i >= 0.0)) flags |= 0x8u;
3063"#;
3064
3065 let raw_gamma = r#" // raw-body gamma log (independent re-derivation; unit shape)
3066 double k_shape = 1.0;
3067 double eta_c = clamp_eta(eta_i, &flags);
3068 double mu_pre = exp(eta_c);
3069 if (mu_pre < 1e-10) flags |= 0x2u;
3070 double mu = (mu_pre > 1e-10) ? mu_pre : 1e-10;
3071 double w_fisher = wp * k_shape;
3072 double w_hessian = w_fisher;
3073 double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
3074 double pres = y_i - mu;
3075 double grad_eta = wp * pres / mu;
3076 double z_lin = eta_c + pres / mu;
3077 double z_f = z_lin;
3078 double z_h = z_lin;
3079 double dev;
3080 if (y_i > 0.0) {
3081 dev = 2.0 * wp * (-log(y_i / mu) + pres / mu);
3082 } else {
3083 dev = 1.0 / 0.0;
3084 }
3085 if (!(isfinite(y_i) && y_i > 0.0)) flags |= 0x8u;
3086"#;
3087
3088 let raw_logit = r#" // raw-body bernoulli logit (independent re-derivation)
3089 double eta_c = clamp_eta(eta_i, &flags);
3090 double te = tanh(0.5 * eta_c);
3091 double mu_pre = 0.5 * (1.0 + te);
3092 if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
3093 double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
3094 double dmu_deta = mu * (1.0 - mu);
3095 double w_fisher = wp * dmu_deta;
3096 double w_hessian = w_fisher;
3097 double w_solver = (w_fisher > 0.0) ? fmax(w_fisher, 1e-12) : 0.0;
3098 double bres = y_i - mu;
3099 double grad_eta = wp * bres;
3100 double dev = bernoulli_deviance(y_i, mu, wp);
3101 double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
3102 double z_f = z_lin;
3103 double z_h = z_lin;
3104 if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
3105"#;
3106
3107 let raw_probit = r#" // raw-body bernoulli probit (independent re-derivation; Fisher mode)
3108 double eta_c = clamp_eta(eta_i, &flags);
3109 double mu_pre = std_norm_cdf(eta_c);
3110 if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
3111 double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
3112 double phi = std_norm_pdf(eta_c);
3113 double dmu_deta = phi;
3114 double vmu = mu * (1.0 - mu);
3115 double w_pp = (vmu > 0.0) ? (phi * phi) / vmu : 0.0;
3116 double w_fisher = wp * w_pp;
3117 double w_hessian = w_fisher;
3118 double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
3119 double bres = y_i - mu;
3120 double grad_eta = (vmu > 0.0) ? wp * bres * phi / vmu : 0.0;
3121 double dev = bernoulli_deviance(y_i, mu, wp);
3122 double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
3123 double z_f = z_lin;
3124 double z_h = z_lin;
3125 if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
3126"#;
3127
3128 let raw_cloglog = r#" // raw-body bernoulli cloglog (independent re-derivation; Fisher mode)
3129 double eta_c = clamp_eta(eta_i, &flags);
3130 double a = exp(eta_c);
3131 double mu_pre = 1.0 - exp(-a);
3132 if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
3133 double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
3134 double dmu_deta = a * (1.0 - mu_pre);
3135 double vmu = mu * (1.0 - mu);
3136 double w_pp = (vmu > 0.0) ? (dmu_deta * dmu_deta) / vmu : 0.0;
3137 double w_fisher = wp * w_pp;
3138 double w_hessian = w_fisher;
3139 double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
3140 double bres = y_i - mu;
3141 double grad_eta = (vmu > 0.0) ? wp * bres * dmu_deta / vmu : 0.0;
3142 double dev = bernoulli_deviance(y_i, mu, wp);
3143 double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
3144 double z_f = z_lin;
3145 double z_h = z_lin;
3146 if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
3147"#;
3148
3149 let cases: [(PirlsRowFamily, &str, u64, fn(usize) -> Vec<f64>); 6] = [
3153 (
3154 PirlsRowFamily::GaussianIdentity,
3155 raw_gaussian,
3156 0x5242_3031_4741_5553u64,
3157 |n| {
3158 (0..n)
3159 .map(|i| -3.0 + 6.0 * (i as f64) / ((n - 1) as f64))
3160 .collect()
3161 },
3162 ),
3163 (
3164 PirlsRowFamily::PoissonLog,
3165 raw_poisson,
3166 0x5242_3032_504f_4953u64,
3167 |n| (0..n).map(|i| (i % 11) as f64).collect(),
3168 ),
3169 (
3170 PirlsRowFamily::GammaLog,
3171 raw_gamma,
3172 0x5242_3033_474d_414cu64,
3173 |n| (0..n).map(|i| 0.10 + 0.05 * ((i % 97) as f64)).collect(),
3174 ),
3175 (
3176 PirlsRowFamily::BernoulliLogit,
3177 raw_logit,
3178 0x5242_3034_4c47_4954u64,
3179 |n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
3180 ),
3181 (
3182 PirlsRowFamily::BernoulliProbit,
3183 raw_probit,
3184 0x5242_3035_5052_4254u64,
3185 |n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
3186 ),
3187 (
3188 PirlsRowFamily::BernoulliCLogLog,
3189 raw_cloglog,
3190 0x5242_3036_434c_4f47u64,
3191 |n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
3192 ),
3193 ];
3194
3195 for (family, raw_body, spec_id, build_y) in cases {
3196 let ys: Vec<f64> = build_y(N);
3197
3198 let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("eta");
3199 let mut y_dev = stream.alloc_zeros::<f64>(N).expect("y");
3200 let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("prior");
3201 stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
3202 stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
3203 stream
3204 .memcpy_htod(&priors, &mut prior_dev)
3205 .expect("up prior");
3206
3207 let spec = JitFamilySpec::raw(spec_id, raw_body);
3208 let mut out_jit = RowOutputDevBuffers::allocate(&stream, N).expect("alloc jit out");
3209 launch_row_reweight_jit_on_stream(
3210 backend,
3211 &spec,
3212 curvature,
3213 &stream,
3214 N,
3215 &eta_dev,
3216 &y_dev,
3217 &prior_dev,
3218 &mut out_jit,
3219 )
3220 .unwrap_or_else(|err| panic!("jit raw-body launch {family:?}: {err}"));
3221 stream.synchronize().expect("sync");
3222
3223 let mu_j = stream.clone_dtoh(&out_jit.mu).expect("dl mu");
3224 let ge_j = stream.clone_dtoh(&out_jit.grad_eta).expect("dl g");
3225 let wf_j = stream.clone_dtoh(&out_jit.w_fisher).expect("dl wf");
3226 let wh_j = stream.clone_dtoh(&out_jit.w_hessian).expect("dl wh");
3227 let ws_j = stream.clone_dtoh(&out_jit.w_solver).expect("dl ws");
3228 let zf_j = stream.clone_dtoh(&out_jit.z_fisher).expect("dl zf");
3229 let zh_j = stream.clone_dtoh(&out_jit.z_hessian).expect("dl zh");
3230 let dv_j = stream.clone_dtoh(&out_jit.deviance).expect("dl dv");
3231
3232 for i in 0..N {
3233 let cpu = row_reweight_cpu(
3234 family,
3235 curvature,
3236 RowInput {
3237 eta: etas[i],
3238 y: ys[i],
3239 prior_weight: priors[i],
3240 },
3241 1.0,
3242 );
3243 for (label, got, exp) in [
3244 ("mu", mu_j[i], cpu.mu),
3245 ("grad_eta", ge_j[i], cpu.grad_eta),
3246 ("w_fisher", wf_j[i], cpu.w_fisher),
3247 ("w_hessian", wh_j[i], cpu.w_hessian),
3248 ("w_solver", ws_j[i], cpu.w_solver),
3249 ("z_fisher", zf_j[i], cpu.z_fisher),
3250 ("z_hessian", zh_j[i], cpu.z_hessian),
3251 ("deviance", dv_j[i], cpu.deviance),
3252 ] {
3253 if !got.is_finite() && !exp.is_finite() {
3254 continue;
3257 }
3258 let abs_err = (got - exp).abs();
3259 let rel_err = if exp.abs() > 0.0 {
3260 abs_err / exp.abs()
3261 } else {
3262 abs_err
3263 };
3264 assert!(
3265 abs_err <= TOL || rel_err <= TOL,
3266 "{family:?} {label}[{i}] (eta={}, y={}, wp={}): \
3267 JIT raw-body={} vs CPU={} (abs={}, rel={})",
3268 etas[i],
3269 ys[i],
3270 priors[i],
3271 got,
3272 exp,
3273 abs_err,
3274 rel_err,
3275 );
3276 }
3277 }
3278 }
3279 }
3280 }
3281}