1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2
/// Host-side inputs for a single PIRLS Newton step on the GPU.
///
/// NOTE(review): this type is imported by the `cuda` module, but none of the
/// solver entry points visible alongside it take it as a parameter. The field
/// notes below are inferred by analogy with `PirlsStepStreamInput`; confirm
/// them against the actual consumer.
#[derive(Clone, Debug)]
pub struct PirlsGpuInput<'a> {
    /// Design matrix `X` (presumably n × p).
    pub x: ArrayView2<'a, f64>,
    /// Per-row working weights (presumably length n).
    pub weights: ArrayView1<'a, f64>,
    /// Penalty Hessian `S` (presumably p × p).
    pub penalty_hessian: ArrayView2<'a, f64>,
    /// Right-hand side of the Newton system (presumably length p).
    pub gradient: ArrayView1<'a, f64>,
    /// Presumably the ridge added to the penalty for the factorised step system.
    pub step_lm_lambda: f64,
    /// Presumably the ridge added to the penalty for the exported penalized Hessian.
    pub objective_ridge: f64,
}
20
/// Result of one GPU PIRLS step, as returned by the `solve_step_on_stream*`
/// entry points that produce a full step.
#[derive(Clone, Debug)]
pub struct PirlsGpuStep {
    /// `Qsᵀ·(XᵀWX)·Qs + penalty_with_ridge(S, objective_ridge)`, assembled on
    /// the host (row-major `p × p`).
    pub penalized_hessian: Array2<f64>,
    /// Solution `d` of the factorised step system, whose matrix uses the
    /// `step_lm_lambda` ridge instead of `objective_ridge`.
    pub direction: Array1<f64>,
    /// `2·Σ log(diag(L))` of the Cholesky factor of the step matrix, i.e. the
    /// log-determinant of the `step_lm_lambda`-ridged system (not of
    /// `penalized_hessian` when the two ridges differ).
    pub logdet: f64,
}
27
/// Host-resident per-step inputs for `solve_step_on_stream`, which reuses the
/// design already uploaded in `PirlsGpuSharedData`.
#[derive(Clone, Debug)]
pub struct PirlsStepStreamInput<'a> {
    /// Working weights; must be contiguous with length `n`.
    pub weights: ArrayView1<'a, f64>,
    /// Penalty matrix `S`; must be `p × p`.
    pub penalty_hessian: ArrayView2<'a, f64>,
    /// Right-hand side uploaded unchanged into the solve; must be contiguous
    /// with length `p`.
    pub gradient: ArrayView1<'a, f64>,
    /// Ridge passed to `penalty_with_ridge` for the matrix that is factorised
    /// and solved.
    pub step_lm_lambda: f64,
    /// Ridge passed to `penalty_with_ridge` for the exported
    /// `PirlsGpuStep::penalized_hessian`.
    pub objective_ridge: f64,
}
47
/// Per-step inputs for the device-input solvers, where the row-level vectors
/// already live on the GPU.
///
/// The step right-hand side is assembled as
/// `Qsᵀ·(Xᵀ·grad_eta) − S·β + linear_shift`.
#[cfg(target_os = "linux")]
pub struct PirlsStepStreamDeviceInput<'a, 'b> {
    /// Weights used to form `XᵀWX` for the step; length `n`.
    pub w_solver_dev: &'a cudarc::driver::CudaSlice<f64>,
    /// Per-row score vector; `Xᵀ·grad_eta_dev` seeds the right-hand side.
    /// Length `n`.
    pub grad_eta_dev: &'b cudarc::driver::CudaSlice<f64>,
    /// Penalty matrix `S` (`p × p`). It is ridged for the step and export
    /// matrices and multiplied by `β` for the right-hand side.
    pub penalty_hessian: ArrayView2<'b, f64>,
    /// Ridge for the factorised step matrix.
    pub step_lm_lambda: f64,
    /// Ridge for the exported penalized Hessian.
    pub objective_ridge: f64,
    /// Current coefficients `β`; downloaded to form `S·β`. Must have length `p`.
    pub beta_dev: &'b cudarc::driver::CudaSlice<f64>,
    /// Vector added to the right-hand side after the `S·β` correction; must
    /// have length `p`.
    pub linear_shift: ArrayView1<'b, f64>,
}
81
/// Model data uploaded to the GPU once and shared by every workspace.
///
/// Built by `upload_impl` in the `cuda` module. That function rejects an
/// empty design and requires `y`, `prior_w` and `offset` to have length `n`.
#[cfg(target_os = "linux")]
pub struct PirlsGpuSharedData {
    /// CUDA context; each workspace creates its own stream from it.
    pub(crate) ctx: std::sync::Arc<cudarc::driver::CudaContext>,
    /// Number of design rows.
    pub(crate) n: usize,
    /// Number of design columns (coefficients).
    pub(crate) p: usize,
    /// Design matrix `X`, stored column-major (`n × p`).
    pub(crate) x_original_dev: cudarc::driver::CudaSlice<f64>,
    /// `y` vector (length `n`); presumably the response.
    pub(crate) y_dev: cudarc::driver::CudaSlice<f64>,
    /// Prior weights (length `n`).
    pub(crate) prior_w_dev: cudarc::driver::CudaSlice<f64>,
    /// Offset vector (length `n`); presumably added to the linear predictor.
    pub(crate) offset_dev: cudarc::driver::CudaSlice<f64>,
}
104
/// Per-solver scratch state for PIRLS steps against one `PirlsGpuSharedData`.
///
/// Created by `allocate_impl`. Buffers are sized from the shared `n`/`p`, and
/// every solver entry point checks `n`/`p` against the shared data before use.
#[cfg(target_os = "linux")]
pub struct SigmaPirlsGpuWorkspace {
    /// Stream created from the shared context; all work for this workspace
    /// is enqueued on it.
    pub(crate) stream: std::sync::Arc<cudarc::driver::CudaStream>,
    /// cuBLAS handle bound to `stream`.
    pub(crate) blas: cudarc::cublas::CudaBlas,
    /// cuSOLVER dense handle bound to `stream`.
    pub(crate) solver: cudarc::cusolver::DnHandle,
    /// `n × p` scratch for `diag(w)·X`. `Some` only when
    /// `p >= FUSED_XTWX_P_THRESHOLD`, which selects the cuBLAS `XᵀWX` path.
    pub(crate) wx_dev: Option<cudarc::driver::CudaSlice<f64>>,
    /// Host-uploaded weights (length `n`).
    pub(crate) w_dev: cudarc::driver::CudaSlice<f64>,
    /// `XᵀWX` (`p × p`, column-major).
    pub(crate) xtwx_dev: cudarc::driver::CudaSlice<f64>,
    /// Rotated, penalized Hessian (`p × p`). The potrf step overwrites it in
    /// place with its Cholesky factor.
    pub(crate) h_dev: cudarc::driver::CudaSlice<f64>,
    /// Right-hand side (length `p`). After potrs it holds the solved direction.
    pub(crate) rhs_dev: cudarc::driver::CudaSlice<f64>,
    /// Ridged penalty uploaded from the host (`p × p`, column-major).
    pub(crate) penalty_dev: cudarc::driver::CudaSlice<f64>,
    /// Basis rotation `Qs` (`p × p`, column-major); initialised to identity.
    pub(crate) qs_dev: cudarc::driver::CudaSlice<f64>,
    /// Scratch for `XᵀWX·Qs` (`p × p`).
    pub(crate) qs_tmp_dev: cudarc::driver::CudaSlice<f64>,
    /// Length-`p` buffer; in the visible code it serves as scratch for `Qsᵀ·score`.
    pub(crate) beta_orig_dev: cudarc::driver::CudaSlice<f64>,
    /// Length-`p` buffer; not referenced by the code visible alongside this type.
    pub(crate) dir_orig_dev: cudarc::driver::CudaSlice<f64>,
    /// potrf workspace; allocated with at least one element.
    pub(crate) potrf_work_dev: cudarc::driver::CudaSlice<f64>,
    /// potrf workspace length as returned by `potrf_query_lwork`.
    pub(crate) potrf_lwork: i32,
    /// Device-side potrf status, read back by `check_deferred_potrf_info`.
    pub(crate) potrf_info_dev: cudarc::driver::CudaSlice<i32>,
    /// Device-side potrs status, read back by `check_deferred_potrs_info`.
    pub(crate) potrs_info_dev: cudarc::driver::CudaSlice<i32>,
    /// Row count this workspace was sized for.
    pub(crate) n: usize,
    /// Column count this workspace was sized for.
    pub(crate) p: usize,
}
159
160#[cfg(target_os = "linux")]
161pub(crate) mod cuda {
162 use super::{
163 PirlsGpuInput, PirlsGpuSharedData, PirlsGpuStep, PirlsStepStreamDeviceInput,
164 PirlsStepStreamInput, SigmaPirlsGpuWorkspace,
165 };
166 use gam_gpu::device_cache::PtxModuleCache;
167 use gam_gpu::driver::{from_col_major, to_col_major};
168 use gam_gpu::solver::{
169 check_deferred_potrf_info, check_deferred_potrs_info, context_and_stream, pinned_htod,
170 potrf_in_place_reuse, potrf_query_lwork, potrs_in_place_reuse,
171 };
172 use cudarc::cublas::sys::{
173 cublasDdgmm, cublasDgeam, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
174 };
175 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
176 use cudarc::cusolver::DnHandle;
177 use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut, LaunchConfig, PushKernelArg};
178 use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
179
    // CUDA C source for `chol_logdet_col_major`.
    //
    // Only thread 0 of block 0 does any work; every other thread returns
    // immediately. That thread writes `2 * sum_i log(factor[i*p + i])` to
    // `out[0]`, i.e. twice the sum of the logs of the diagonal of a
    // column-major `p x p` matrix. When `factor` holds a Cholesky factor `L` of
    // `H = L*L^T`, this is `log det H`. A zero or negative diagonal entry
    // produces -inf / NaN rather than an error.
    //
    // NOTE(review): presumably compiled and launched by
    // `cholesky_logdet_device`, which is defined outside this view.
    const CHOL_LOGDET_PTX_SOURCE: &str = r#"
extern "C" __global__ void chol_logdet_col_major(
    const double* __restrict__ factor,
    int p,
    double* __restrict__ out
) {
    if (threadIdx.x != 0 || blockIdx.x != 0) return;
    double acc = 0.0;
    long long pp = (long long)p;
    for (long long i = 0; i < pp; ++i) {
        acc += log(factor[i * pp + i]);
    }
    out[0] = 2.0 * acc;
}
"#;

    // Module cache for `CHOL_LOGDET_PTX_SOURCE` (presumably populated lazily
    // on first launch by the helper that uses it).
    static CHOL_LOGDET_CACHE: PtxModuleCache = PtxModuleCache::new();
203
    // Column count at or above which `XᵀWX` goes through cuBLAS.
    // `SigmaPirlsGpuWorkspace::allocate_impl` allocates the `n x p` WX scratch
    // buffer only when `p >= FUSED_XTWX_P_THRESHOLD`. The solvers use cuBLAS
    // dgemm when that buffer exists and the kernels below otherwise. Because
    // the kernels compute `p*(p+1)/2` in 32-bit `int`, they are only used for
    // small `p`.
    const FUSED_XTWX_P_THRESHOLD: usize = 256;

    // CUDA C source for the small-p kernels. concat! joins the pieces with no
    // separators, so the whole module is a single source line.
    const FUSED_XTWX_PTX_SOURCE: &str = concat!(
        // xtwx_lower: one thread per lower-triangle entry (p*(p+1)/2 threads).
        // Thread t is mapped to (row jv, col kv) with kv <= jv, and writes
        // A[jv + kv*p] = sum_i w[i] * X[i,jv] * X[i,kv] (column-major).
        "extern \"C\" __global__ void xtwx_lower(",
        "const double* __restrict__ X,",
        "const double* __restrict__ w,",
        "double* __restrict__ A,",
        "int n, int p) {",
        "int t=blockIdx.x*blockDim.x+threadIdx.x;",
        "int np=p*(p+1)/2; if(t>=np)return;",
        // Floating-point initial guess for the triangular root, then exact
        // integer correction so that jv*(jv+1)/2 <= t < (jv+1)*(jv+2)/2.
        "int jv=(int)((__dsqrt_rn((double)(8*t+1))-1.0)*0.5);",
        "while((long long)(jv+1)*(jv+2)/2<=t)jv++;",
        "while(jv>0&&(long long)jv*(jv+1)/2>t)jv--;",
        "int kv=t-(int)((long long)jv*(jv+1)/2);",
        "double acc=0.0;",
        "const double*Xj=X+(long long)jv*n;",
        "const double*Xk=X+(long long)kv*n;",
        "for(int i=0;i<n;++i)acc+=w[i]*Xj[i]*Xk[i];",
        "A[jv+(long long)kv*p]=acc;}",
        // xtscore: one thread per column j; s[j] = sum_i score[i] * X[i,j].
        "extern \"C\" __global__ void xtscore(",
        "const double* __restrict__ X,",
        "const double* __restrict__ score,",
        "double* __restrict__ s,",
        "int n, int p) {",
        "int j=blockIdx.x*blockDim.x+threadIdx.x;",
        "if(j>=p)return;",
        "double acc=0.0;",
        "const double*Xj=X+(long long)j*n;",
        "for(int i=0;i<n;++i)acc+=score[i]*Xj[i];",
        "s[j]=acc;}",
        // symmetrize_lower: one thread per strictly-lower entry
        // (p*(p-1)/2 threads). Each thread copies A[jv + kv*p] (lower) into
        // A[kv + jv*p] (upper), with kv < jv.
        "extern \"C\" __global__ void symmetrize_lower(",
        "double* __restrict__ A, int p) {",
        "int ns=p*(p-1)/2;",
        "int t=blockIdx.x*blockDim.x+threadIdx.x;",
        "if(t>=ns)return;",
        // Initial guess, then exact correction so that
        // jv*(jv-1)/2 <= t < (jv+1)*jv/2.
        "int jv=(int)((__dsqrt_rn((double)(8*t+1))+1.0)*0.5);",
        "while((long long)jv*(jv-1)/2>t)jv--;",
        "while((long long)(jv+1)*jv/2<=t)jv++;",
        "int kv=t-(int)((long long)jv*(jv-1)/2);",
        "A[kv+(long long)jv*p]=A[jv+(long long)kv*p];}",
    );

    // Module cache for `FUSED_XTWX_PTX_SOURCE` (presumably populated lazily by
    // the launch helpers defined outside this view).
    static FUSED_XTWX_CACHE: PtxModuleCache = PtxModuleCache::new();
274
275 impl PirlsGpuSharedData {
276 pub(crate) fn upload_impl(
280 x: ArrayView2<'_, f64>,
281 y: ArrayView1<'_, f64>,
282 prior_w: ArrayView1<'_, f64>,
283 offset: ArrayView1<'_, f64>,
284 ) -> Result<Self, String> {
285 let (n, p) = x.dim();
286 if n == 0 || p == 0 {
287 return Err("empty design cannot be uploaded".to_string());
288 }
289 if y.len() != n || prior_w.len() != n || offset.len() != n {
290 return Err(format!(
291 "y/prior_w/offset length mismatch (y={}, w={}, offset={}, n={n})",
292 y.len(),
293 prior_w.len(),
294 offset.len()
295 ));
296 }
297 let (ctx, stream) = context_and_stream()?;
298 let x_col = to_col_major(&x);
299 let x_original_dev = pinned_htod(&stream, &x_col)?;
300 let y_dev = pinned_htod(&stream, y.as_slice().ok_or("y not contiguous")?)?;
301 let prior_w_dev =
302 pinned_htod(&stream, prior_w.as_slice().ok_or("prior_w not contiguous")?)?;
303 let offset_dev =
304 pinned_htod(&stream, offset.as_slice().ok_or("offset not contiguous")?)?;
305 stream
309 .synchronize()
310 .map_err(|e| format!("cuda sync after model upload: {e}"))?;
311 Ok(Self {
312 ctx,
313 n,
314 p,
315 x_original_dev,
316 y_dev,
317 prior_w_dev,
318 offset_dev,
319 })
320 }
321 }
322
323 impl SigmaPirlsGpuWorkspace {
324 pub(crate) fn allocate_impl(shared: &PirlsGpuSharedData) -> Result<Self, String> {
330 let n = shared.n;
331 let p = shared.p;
332 let stream = shared
333 .ctx
334 .new_stream()
335 .map_err(|e| format!("cuda stream alloc: {e}"))?;
336 let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
337 let solver =
338 DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
339 let np = n.checked_mul(p).ok_or("X size overflow")?;
340 let pp = p.checked_mul(p).ok_or("H size overflow")?;
341 let wx_dev = if p >= FUSED_XTWX_P_THRESHOLD {
343 Some(
344 stream
345 .alloc_zeros::<f64>(np)
346 .map_err(|e| format!("cuda alloc WX: {e}"))?,
347 )
348 } else {
349 None
350 };
351 let w_dev = stream
352 .alloc_zeros::<f64>(n)
353 .map_err(|e| format!("cuda alloc W: {e}"))?;
354 let xtwx_dev = stream
355 .alloc_zeros::<f64>(pp)
356 .map_err(|e| format!("cuda alloc XtWX: {e}"))?;
357 let h_dev = stream
358 .alloc_zeros::<f64>(pp)
359 .map_err(|e| format!("cuda alloc H: {e}"))?;
360 let rhs_dev = stream
361 .alloc_zeros::<f64>(p)
362 .map_err(|e| format!("cuda alloc RHS: {e}"))?;
363 let penalty_dev = stream
364 .alloc_zeros::<f64>(pp)
365 .map_err(|e| format!("cuda alloc penalty: {e}"))?;
366 let mut qs_dev = stream
368 .alloc_zeros::<f64>(pp)
369 .map_err(|e| format!("cuda alloc Qs: {e}"))?;
370 {
372 let mut qs_host = vec![0.0_f64; pp];
373 for i in 0..p {
374 qs_host[i * p + i] = 1.0;
375 }
376 stream
377 .memcpy_htod(&qs_host, &mut qs_dev)
378 .map_err(|e| format!("init Qs identity: {e}"))?;
379 }
380 let qs_tmp_dev = stream
381 .alloc_zeros::<f64>(pp)
382 .map_err(|e| format!("cuda alloc Qs tmp: {e}"))?;
383 let beta_orig_dev = stream
384 .alloc_zeros::<f64>(p)
385 .map_err(|e| format!("cuda alloc beta_orig: {e}"))?;
386 let dir_orig_dev = stream
387 .alloc_zeros::<f64>(p)
388 .map_err(|e| format!("cuda alloc dir_orig: {e}"))?;
389 let potrf_lwork_usize = potrf_query_lwork(&solver, &stream, p)?;
393 let potrf_lwork = i32::try_from(potrf_lwork_usize)
394 .map_err(|_| format!("potrf lwork {potrf_lwork_usize} exceeds i32"))?;
395 let alloc_len = potrf_lwork_usize.max(1);
398 let potrf_work_dev = stream
399 .alloc_zeros::<f64>(alloc_len)
400 .map_err(|e| format!("cuda alloc potrf workspace: {e}"))?;
401 let potrf_info_dev = stream
402 .alloc_zeros::<i32>(1)
403 .map_err(|e| format!("cuda alloc potrf info: {e}"))?;
404 let potrs_info_dev = stream
405 .alloc_zeros::<i32>(1)
406 .map_err(|e| format!("cuda alloc potrs info: {e}"))?;
407 Ok(Self {
408 stream,
409 blas,
410 solver,
411 wx_dev,
412 w_dev,
413 xtwx_dev,
414 h_dev,
415 rhs_dev,
416 penalty_dev,
417 qs_dev,
418 qs_tmp_dev,
419 beta_orig_dev,
420 dir_orig_dev,
421 potrf_work_dev,
422 potrf_lwork,
423 potrf_info_dev,
424 potrs_info_dev,
425 n,
426 p,
427 })
428 }
429 }
430
431 pub(super) fn upload_qs(
435 ws: &mut SigmaPirlsGpuWorkspace,
436 qs: ArrayView2<'_, f64>,
437 ) -> Result<(), String> {
438 let p = ws.p;
439 if qs.dim() != (p, p) {
440 return Err(format!("upload_qs: Qs shape {:?} != ({p},{p})", qs.dim()));
441 }
442 let qs_col = to_col_major(&qs);
443 ws.stream
444 .memcpy_htod(qs_col.as_ref(), &mut ws.qs_dev)
445 .map_err(|e| format!("upload Qs: {e}"))
446 }
447
448 pub(super) fn upload_qs_identity(ws: &mut SigmaPirlsGpuWorkspace) -> Result<(), String> {
450 let p = ws.p;
451 let pp = p * p;
452 let mut qs_host = vec![0.0_f64; pp];
453 for i in 0..p {
454 qs_host[i * p + i] = 1.0;
455 }
456 ws.stream
457 .memcpy_htod(&qs_host, &mut ws.qs_dev)
458 .map_err(|e| format!("upload Qs identity: {e}"))
459 }
460
461 fn newton_step_refine_once(
470 solver: &cudarc::cusolver::DnHandle,
471 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
472 p: usize,
473 chol_factor_dev: &CudaSlice<f64>,
474 rhs_dev: &mut CudaSlice<f64>,
475 potrs_info_dev: &mut CudaSlice<i32>,
476 mut direction_raw: Vec<f64>,
477 g: &[f64],
478 penalized_hessian: &ndarray::Array2<f64>,
479 step_lm_delta: f64,
480 ) -> Result<Vec<f64>, String> {
481 use gam_gpu::policy::GpuDispatchPolicy;
482 if p < GpuDispatchPolicy::REFINEMENT_MIN_P {
483 return Ok(direction_raw);
484 }
485 let norm_g = g.iter().map(|v| v * v).sum::<f64>().sqrt();
486 if norm_g == 0.0 {
487 return Ok(direction_raw);
488 }
489 let hx: Vec<f64> = (0..p)
490 .map(|i| {
491 penalized_hessian
492 .row(i)
493 .iter()
494 .zip(direction_raw.iter())
495 .map(|(hij, xj)| hij * xj)
496 .sum::<f64>()
497 + step_lm_delta * direction_raw[i]
498 })
499 .collect();
500 let residual: Vec<f64> = g.iter().zip(hx.iter()).map(|(gi, hxi)| gi - hxi).collect();
501 let rel_res = residual.iter().map(|v| v * v).sum::<f64>().sqrt() / norm_g;
502 if rel_res <= GpuDispatchPolicy::REFINEMENT_TOL {
503 return Ok(direction_raw);
504 }
505 stream
506 .memcpy_htod(&residual, rhs_dev)
507 .map_err(|e| format!("upload residual: {e}"))?;
508 potrs_in_place_reuse(
509 solver,
510 stream,
511 p,
512 1,
513 chol_factor_dev,
514 rhs_dev,
515 potrs_info_dev,
516 )?;
517 let correction = stream
518 .clone_dtoh(rhs_dev)
519 .map_err(|e| format!("download correction: {e}"))?;
520 check_deferred_potrs_info(stream, potrs_info_dev)?;
521 for (xi, ei) in direction_raw.iter_mut().zip(correction.iter()) {
522 *xi += ei;
523 }
524 Ok(direction_raw)
525 }
526
    /// Assemble and solve one PIRLS Newton step from host-resident inputs.
    ///
    /// On the workspace stream this computes `XᵀWX` (cuBLAS when `ws.wx_dev`
    /// exists, otherwise the fused kernels) and
    /// `H_step = Qsᵀ·(XᵀWX)·Qs + penalty_with_ridge(S, step_lm_lambda)`.
    /// It then Cholesky-factors `H_step` in place and solves
    /// `H_step·d = gradient`. The exported
    /// `penalized_hessian = Qsᵀ·(XᵀWX)·Qs + penalty_with_ridge(S, objective_ridge)`
    /// is assembled on the host. The direction receives one optional
    /// refinement pass (`newton_step_refine_once`).
    ///
    /// # Errors
    /// Returns `Err` on a workspace/shared shape mismatch, a wrong input
    /// length or shape, non-contiguous inputs, or any CUDA/cuBLAS/cuSOLVER
    /// failure. Solver status is checked through the deferred potrf/potrs
    /// info checks.
    pub(super) fn solve_step_on_stream(
        shared: &PirlsGpuSharedData,
        ws: &mut SigmaPirlsGpuWorkspace,
        input: PirlsStepStreamInput<'_>,
    ) -> Result<PirlsGpuStep, String> {
        let n = shared.n;
        let p = shared.p;
        if ws.n != n || ws.p != p {
            return Err(format!(
                "workspace shape ({}, {}) does not match shared design ({n}, {p})",
                ws.n, ws.p
            ));
        }
        if input.weights.len() != n {
            return Err(format!(
                "weights length {} does not match rows {n}",
                input.weights.len()
            ));
        }
        if input.penalty_hessian.dim() != (p, p) {
            return Err(format!(
                "penalty Hessian shape {:?} does not match p={p}",
                input.penalty_hessian.dim()
            ));
        }
        if input.gradient.len() != p {
            return Err(format!(
                "gradient length {} does not match p={p}",
                input.gradient.len()
            ));
        }

        // Upload the weights used to form XᵀWX.
        let w_slice = input
            .weights
            .as_slice()
            .ok_or("weights must be contiguous")?;
        ws.stream
            .memcpy_htod(w_slice, &mut ws.w_dev)
            .map_err(|e| format!("upload W: {e}"))?;

        let n_i = to_i32(n)?;
        let p_i = to_i32(p)?;
        if let Some(ref mut wx_dev) = ws.wx_dev {
            // Large-p path: WX = diag(w)·X, then XᵀWX = Xᵀ·WX with cuBLAS.
            left_scale_rows(
                &ws.blas,
                &ws.stream,
                n,
                p,
                &shared.x_original_dev,
                &mut ws.w_dev,
                wx_dev,
            )?;
            let cfg = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_T,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: n_i,
                alpha: 1.0,
                lda: n_i,
                ldb: n_i,
                beta: 0.0,
                ldc: p_i,
            };
            // SAFETY: X and WX are n × p column-major (lda = ldb = n) and
            // XtWX is p × p (ldc = p). The shape check above ties ws.n/ws.p
            // to shared.n/shared.p, which sized these buffers.
            unsafe {
                ws.blas
                    .gemm(cfg, &shared.x_original_dev, wx_dev, &mut ws.xtwx_dev)
            }
            .map_err(|e| format!("cublas dgemm XtWX: {e}"))?;
        } else {
            // Small-p path: lower triangle via the fused kernel, then mirror it.
            launch_xtwx_lower(
                &ws.stream,
                &shared.ctx,
                n,
                p,
                &shared.x_original_dev,
                &ws.w_dev,
                &mut ws.xtwx_dev,
            )?;
            launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
        }

        // The factorised system uses the step ridge (step_lm_lambda).
        let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
        let penalty_step_view = penalty_step.view();
        let penalty_step_col = to_col_major(&penalty_step_view);
        ws.stream
            .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
            .map_err(|e| format!("upload penalty: {e}"))?;

        {
            // qs_tmp = XᵀWX · Qs
            let cfg_aq = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_N,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: p_i,
                alpha: 1.0,
                lda: p_i,
                ldb: p_i,
                beta: 0.0,
                ldc: p_i,
            };
            // SAFETY: all three operands are p × p device buffers.
            unsafe {
                ws.blas
                    .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
            }
            .map_err(|e| format!("dgemm A·Qs (host-input step): {e}"))?;
        }
        {
            // h = Qsᵀ · qs_tmp
            let cfg_qt = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_T,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: p_i,
                alpha: 1.0,
                lda: p_i,
                ldb: p_i,
                beta: 0.0,
                ldc: p_i,
            };
            // SAFETY: all three operands are p × p device buffers.
            unsafe {
                ws.blas
                    .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
            }
            .map_err(|e| format!("dgemm Qsᵀ·A·Qs (host-input step): {e}"))?;
        }
        // h += penalty_step
        geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;

        // The gradient is used unchanged as the right-hand side.
        let g_slice = input
            .gradient
            .as_slice()
            .ok_or("gradient must be contiguous")?;
        ws.stream
            .memcpy_htod(g_slice, &mut ws.rhs_dev)
            .map_err(|e| format!("upload gradient: {e}"))?;

        // Build the exported Hessian (objective ridge) on the host before h_dev
        // is overwritten by the Cholesky factor.
        let xtwx_col = ws
            .stream
            .clone_dtoh(&ws.xtwx_dev)
            .map_err(|e| format!("download XᵀWX (host-input step): {e}"))?;
        let xtwx_host = from_col_major(&xtwx_col, p, p).ok_or("XᵀWX layout conversion failed")?;
        let qs_col = ws
            .stream
            .clone_dtoh(&ws.qs_dev)
            .map_err(|e| format!("download Qs (host-input step): {e}"))?;
        let qs_host =
            from_col_major(&qs_col, p, p).ok_or("Qs layout conversion failed (host-input step)")?;
        let tmp_aq = xtwx_host.dot(&qs_host);
        let h_rotated = qs_host.t().dot(&tmp_aq);
        let penalty_export = penalty_with_ridge(input.penalty_hessian, input.objective_ridge);
        let penalized_hessian = h_rotated + &penalty_export;

        // Factor H_step in place, then solve H_step·d = g into rhs_dev.
        potrf_in_place_reuse(
            &ws.solver,
            &ws.stream,
            p,
            ws.potrf_lwork,
            &mut ws.h_dev,
            &mut ws.potrf_work_dev,
            &mut ws.potrf_info_dev,
        )?;
        potrs_in_place_reuse(
            &ws.solver,
            &ws.stream,
            p,
            1,
            &ws.h_dev,
            &mut ws.rhs_dev,
            &mut ws.potrs_info_dev,
        )?;

        // log det of the step matrix (step_lm_lambda ridge), from the factor's diagonal.
        let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;

        let direction_raw = ws
            .stream
            .clone_dtoh(&ws.rhs_dev)
            .map_err(|e| format!("download direction: {e}"))?;
        // Solver status is only inspected here, after the results are downloaded.
        check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
        check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;

        // The refinement residual uses the exported Hessian (objective ridge),
        // so add back the ridge difference to match the factorised matrix.
        let lm_ridge_delta = input.step_lm_lambda - input.objective_ridge;
        let direction_raw = newton_step_refine_once(
            &ws.solver,
            &ws.stream,
            p,
            &ws.h_dev,
            &mut ws.rhs_dev,
            &mut ws.potrs_info_dev,
            direction_raw,
            g_slice,
            &penalized_hessian,
            lm_ridge_delta,
        )?;

        let direction = Array1::from_vec(direction_raw);

        Ok(PirlsGpuStep {
            penalized_hessian,
            direction,
            logdet,
        })
    }
773
774 pub(super) fn solve_step_on_stream_device(
790 shared: &PirlsGpuSharedData,
791 ws: &mut SigmaPirlsGpuWorkspace,
792 input: PirlsStepStreamDeviceInput<'_, '_>,
793 ) -> Result<PirlsGpuStep, String> {
794 let n = shared.n;
795 let p = shared.p;
796 if ws.n != n || ws.p != p {
797 return Err(format!(
798 "workspace shape ({}, {}) does not match shared design ({n}, {p})",
799 ws.n, ws.p
800 ));
801 }
802 if input.w_solver_dev.len() != n {
803 return Err(format!(
804 "w_solver_dev length {} does not match n={n}",
805 input.w_solver_dev.len()
806 ));
807 }
808 if input.grad_eta_dev.len() != n {
809 return Err(format!(
810 "grad_eta_dev length {} does not match n={n}",
811 input.grad_eta_dev.len()
812 ));
813 }
814 if input.penalty_hessian.dim() != (p, p) {
815 return Err(format!(
816 "penalty Hessian shape {:?} does not match p={p}",
817 input.penalty_hessian.dim()
818 ));
819 }
820
821 let n_i = to_i32(n)?;
824 let p_i = to_i32(p)?;
825 if let Some(ref mut wx_dev_fb) = ws.wx_dev {
826 left_scale_rows_borrowed(
828 &ws.blas,
829 &ws.stream,
830 n,
831 p,
832 &shared.x_original_dev,
833 input.w_solver_dev,
834 wx_dev_fb,
835 )?;
836 let gemm_cfg = GemmConfig::<f64> {
837 transa: cublasOperation_t::CUBLAS_OP_T,
838 transb: cublasOperation_t::CUBLAS_OP_N,
839 m: p_i,
840 n: p_i,
841 k: n_i,
842 alpha: 1.0,
843 lda: n_i,
844 ldb: n_i,
845 beta: 0.0,
846 ldc: p_i,
847 };
848 unsafe {
851 ws.blas.gemm(
852 gemm_cfg,
853 &shared.x_original_dev,
854 wx_dev_fb,
855 &mut ws.xtwx_dev,
856 )
857 }
858 .map_err(|e| format!("cublas dgemm XtWX (device-input): {e}"))?;
859 let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
860 let penalty_step_col = to_col_major(&penalty_step);
861 ws.stream
862 .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
863 .map_err(|e| format!("upload penalty (device-input): {e}"))?;
864 {
866 let cfg_aq = GemmConfig::<f64> {
867 transa: cublasOperation_t::CUBLAS_OP_N,
868 transb: cublasOperation_t::CUBLAS_OP_N,
869 m: p_i,
870 n: p_i,
871 k: p_i,
872 alpha: 1.0,
873 lda: p_i,
874 ldb: p_i,
875 beta: 0.0,
876 ldc: p_i,
877 };
878 unsafe {
880 ws.blas
881 .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
882 }
883 .map_err(|e| format!("dgemm A·Qs (device-input large-p): {e}"))?;
884 }
885 {
886 let cfg_qt = GemmConfig::<f64> {
887 transa: cublasOperation_t::CUBLAS_OP_T,
888 transb: cublasOperation_t::CUBLAS_OP_N,
889 m: p_i,
890 n: p_i,
891 k: p_i,
892 alpha: 1.0,
893 lda: p_i,
894 ldb: p_i,
895 beta: 0.0,
896 ldc: p_i,
897 };
898 unsafe {
900 ws.blas
901 .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
902 }
903 .map_err(|e| format!("dgemm Qsᵀ·A·Qs (device-input large-p): {e}"))?;
904 }
905 geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
906 let gemv_cfg = GemvConfig::<f64> {
907 trans: cublasOperation_t::CUBLAS_OP_T,
908 m: n_i,
909 n: p_i,
910 alpha: 1.0,
911 lda: n_i,
912 incx: 1,
913 beta: 0.0,
914 incy: 1,
915 };
916 unsafe {
918 ws.blas.gemv(
919 gemv_cfg,
920 &shared.x_original_dev,
921 input.grad_eta_dev,
922 &mut ws.rhs_dev,
923 )
924 }
925 .map_err(|e| format!("cublas dgemv Xtg (device-input): {e}"))?;
926 } else {
927 launch_xtwx_lower(
929 &ws.stream,
930 &shared.ctx,
931 n,
932 p,
933 &shared.x_original_dev,
934 input.w_solver_dev,
935 &mut ws.xtwx_dev,
936 )?;
937 launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
938 launch_xtscore(
939 &ws.stream,
940 &shared.ctx,
941 n,
942 p,
943 &shared.x_original_dev,
944 input.grad_eta_dev,
945 &mut ws.rhs_dev,
946 )?;
947 {
949 let cfg_aq = GemmConfig::<f64> {
950 transa: cublasOperation_t::CUBLAS_OP_N,
951 transb: cublasOperation_t::CUBLAS_OP_N,
952 m: p_i,
953 n: p_i,
954 k: p_i,
955 alpha: 1.0,
956 lda: p_i,
957 ldb: p_i,
958 beta: 0.0,
959 ldc: p_i,
960 };
961 unsafe {
963 ws.blas
964 .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
965 }
966 .map_err(|e| format!("dgemm A·Qs (device-input fused): {e}"))?;
967 }
968 {
969 let cfg_qt = GemmConfig::<f64> {
970 transa: cublasOperation_t::CUBLAS_OP_T,
971 transb: cublasOperation_t::CUBLAS_OP_N,
972 m: p_i,
973 n: p_i,
974 k: p_i,
975 alpha: 1.0,
976 lda: p_i,
977 ldb: p_i,
978 beta: 0.0,
979 ldc: p_i,
980 };
981 unsafe {
983 ws.blas
984 .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
985 }
986 .map_err(|e| format!("dgemm Qsᵀ·A·Qs (device-input fused): {e}"))?;
987 }
988 let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
989 let penalty_step_col = to_col_major(&penalty_step);
990 ws.stream
991 .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
992 .map_err(|e| format!("upload penalty (fused device-input): {e}"))?;
993 geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
994 }
995
996 {
1001 let cfg_qts = GemvConfig::<f64> {
1003 trans: cublasOperation_t::CUBLAS_OP_T,
1004 m: p_i,
1005 n: p_i,
1006 alpha: 1.0,
1007 lda: p_i,
1008 incx: 1,
1009 beta: 0.0,
1010 incy: 1,
1011 };
1012 unsafe {
1014 ws.blas
1015 .gemv(cfg_qts, &ws.qs_dev, &ws.rhs_dev, &mut ws.beta_orig_dev)
1016 }
1017 .map_err(|e| format!("dgemv Qsᵀ·score (device-input): {e}"))?;
1018 ws.stream
1020 .memcpy_dtod(&ws.beta_orig_dev, &mut ws.rhs_dev)
1021 .map_err(|e| format!("d2d Qsᵀ·score→rhs (device-input): {e}"))?;
1022 let rhs_raw = ws
1024 .stream
1025 .clone_dtoh(&ws.rhs_dev)
1026 .map_err(|e| format!("download Qsᵀscore (device-input): {e}"))?;
1027 let beta_raw = ws
1028 .stream
1029 .clone_dtoh(input.beta_dev)
1030 .map_err(|e| format!("download beta (device-input): {e}"))?;
1031 let mut rhs_host = Array1::from_vec(rhs_raw);
1032 let beta_host = Array1::from_vec(beta_raw);
1033 let s_beta = input.penalty_hessian.dot(&beta_host);
1034 rhs_host -= &s_beta;
1035 rhs_host += &input.linear_shift;
1036 ws.stream
1037 .memcpy_htod(
1038 rhs_host
1039 .as_slice()
1040 .ok_or("rhs_host not contiguous (device-input correction)")?,
1041 &mut ws.rhs_dev,
1042 )
1043 .map_err(|e| format!("re-upload corrected rhs (device-input): {e}"))?;
1044 }
1045
1046 let xtwx_col = ws
1050 .stream
1051 .clone_dtoh(&ws.xtwx_dev)
1052 .map_err(|e| format!("download XᵀWX (device-input): {e}"))?;
1053 let xtwx_host = from_col_major(&xtwx_col, p, p)
1054 .ok_or("XᵀWX layout conversion failed (device-input)")?;
1055 let qs_col = ws
1056 .stream
1057 .clone_dtoh(&ws.qs_dev)
1058 .map_err(|e| format!("download Qs (device-input): {e}"))?;
1059 let qs_host =
1060 from_col_major(&qs_col, p, p).ok_or("Qs layout conversion failed (device-input)")?;
1061 let tmp_aq = xtwx_host.dot(&qs_host);
1062 let h_rotated = qs_host.t().dot(&tmp_aq);
1063 let penalty_export = penalty_with_ridge(input.penalty_hessian, input.objective_ridge);
1064 let penalized_hessian = h_rotated + &penalty_export;
1065
1066 potrf_in_place_reuse(
1069 &ws.solver,
1070 &ws.stream,
1071 p,
1072 ws.potrf_lwork,
1073 &mut ws.h_dev,
1074 &mut ws.potrf_work_dev,
1075 &mut ws.potrf_info_dev,
1076 )?;
1077 potrs_in_place_reuse(
1078 &ws.solver,
1079 &ws.stream,
1080 p,
1081 1,
1082 &ws.h_dev,
1083 &mut ws.rhs_dev,
1084 &mut ws.potrs_info_dev,
1085 )?;
1086
1087 let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
1088
1089 let direction_raw = ws
1090 .stream
1091 .clone_dtoh(&ws.rhs_dev)
1092 .map_err(|e| format!("download direction (device-input): {e}"))?;
1093 check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
1097 check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
1098 let direction = Array1::from_vec(direction_raw);
1101
1102 Ok(PirlsGpuStep {
1103 penalized_hessian,
1104 direction,
1105 logdet,
1106 })
1107 }
1108
1109 pub(super) fn solve_step_on_stream_device_inplace(
1120 shared: &PirlsGpuSharedData,
1121 ws: &mut SigmaPirlsGpuWorkspace,
1122 input: PirlsStepStreamDeviceInput<'_, '_>,
1123 ) -> Result<f64, String> {
1124 let n = shared.n;
1125 let p = shared.p;
1126 if ws.n != n || ws.p != p {
1127 return Err(format!(
1128 "workspace shape ({}, {}) does not match shared design ({n}, {p})",
1129 ws.n, ws.p
1130 ));
1131 }
1132 if input.w_solver_dev.len() != n {
1133 return Err(format!(
1134 "w_solver_dev length {} does not match n={n}",
1135 input.w_solver_dev.len()
1136 ));
1137 }
1138 if input.grad_eta_dev.len() != n {
1139 return Err(format!(
1140 "grad_eta_dev length {} does not match n={n}",
1141 input.grad_eta_dev.len()
1142 ));
1143 }
1144 if input.penalty_hessian.dim() != (p, p) {
1145 return Err(format!(
1146 "penalty Hessian shape {:?} does not match p={p}",
1147 input.penalty_hessian.dim()
1148 ));
1149 }
1150
1151 if input.linear_shift.len() != p {
1152 return Err(format!(
1153 "linear_shift length {} does not match p={p}",
1154 input.linear_shift.len()
1155 ));
1156 }
1157 let n_i = to_i32(n)?;
1158 let p_i = to_i32(p)?;
1159
1160 if let Some(ref mut wx_dev_ib) = ws.wx_dev {
1163 left_scale_rows_borrowed(
1165 &ws.blas,
1166 &ws.stream,
1167 n,
1168 p,
1169 &shared.x_original_dev,
1170 input.w_solver_dev,
1171 wx_dev_ib,
1172 )?;
1173 let cfg_xtx = GemmConfig::<f64> {
1174 transa: cublasOperation_t::CUBLAS_OP_T,
1175 transb: cublasOperation_t::CUBLAS_OP_N,
1176 m: p_i,
1177 n: p_i,
1178 k: n_i,
1179 alpha: 1.0,
1180 lda: n_i,
1181 ldb: n_i,
1182 beta: 0.0,
1183 ldc: p_i,
1184 };
1185 unsafe {
1187 ws.blas
1188 .gemm(cfg_xtx, &shared.x_original_dev, wx_dev_ib, &mut ws.xtwx_dev)
1189 }
1190 .map_err(|e| format!("dgemm XtWX inplace (large-p): {e}"))?;
1191 let cfg_xts = GemvConfig::<f64> {
1192 trans: cublasOperation_t::CUBLAS_OP_T,
1193 m: n_i,
1194 n: p_i,
1195 alpha: 1.0,
1196 lda: n_i,
1197 incx: 1,
1198 beta: 0.0,
1199 incy: 1,
1200 };
1201 unsafe {
1203 ws.blas.gemv(
1204 cfg_xts,
1205 &shared.x_original_dev,
1206 input.grad_eta_dev,
1207 &mut ws.rhs_dev,
1208 )
1209 }
1210 .map_err(|e| format!("dgemv Xᵀ·score inplace (large-p): {e}"))?;
1211 } else {
1212 launch_xtwx_lower(
1214 &ws.stream,
1215 &shared.ctx,
1216 n,
1217 p,
1218 &shared.x_original_dev,
1219 input.w_solver_dev,
1220 &mut ws.xtwx_dev,
1221 )?;
1222 launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
1223 launch_xtscore(
1224 &ws.stream,
1225 &shared.ctx,
1226 n,
1227 p,
1228 &shared.x_original_dev,
1229 input.grad_eta_dev,
1230 &mut ws.rhs_dev,
1231 )?;
1232 }
1233
1234 {
1237 let cfg_aq = GemmConfig::<f64> {
1238 transa: cublasOperation_t::CUBLAS_OP_N,
1239 transb: cublasOperation_t::CUBLAS_OP_N,
1240 m: p_i,
1241 n: p_i,
1242 k: p_i,
1243 alpha: 1.0,
1244 lda: p_i,
1245 ldb: p_i,
1246 beta: 0.0,
1247 ldc: p_i,
1248 };
1249 unsafe {
1251 ws.blas
1252 .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
1253 }
1254 .map_err(|e| format!("dgemm A·Qs inplace: {e}"))?;
1255 }
1256 {
1258 let cfg_qt = GemmConfig::<f64> {
1259 transa: cublasOperation_t::CUBLAS_OP_T,
1260 transb: cublasOperation_t::CUBLAS_OP_N,
1261 m: p_i,
1262 n: p_i,
1263 k: p_i,
1264 alpha: 1.0,
1265 lda: p_i,
1266 ldb: p_i,
1267 beta: 0.0,
1268 ldc: p_i,
1269 };
1270 unsafe {
1272 ws.blas
1273 .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
1274 }
1275 .map_err(|e| format!("dgemm Qsᵀ·A·Qs inplace: {e}"))?;
1276 }
1277 let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
1279 let penalty_step_col = to_col_major(&penalty_step);
1280 ws.stream
1281 .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
1282 .map_err(|e| format!("upload penalty inplace: {e}"))?;
1283 geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
1284
1285 {
1289 let cfg_qts = GemvConfig::<f64> {
1290 trans: cublasOperation_t::CUBLAS_OP_T,
1291 m: p_i,
1292 n: p_i,
1293 alpha: 1.0,
1294 lda: p_i,
1295 incx: 1,
1296 beta: 0.0,
1297 incy: 1,
1298 };
1299 unsafe {
1301 ws.blas
1302 .gemv(cfg_qts, &ws.qs_dev, &ws.rhs_dev, &mut ws.beta_orig_dev)
1303 }
1304 .map_err(|e| format!("dgemv Qsᵀ·score inplace: {e}"))?;
1305 ws.stream
1306 .memcpy_dtod(&ws.beta_orig_dev, &mut ws.rhs_dev)
1307 .map_err(|e| format!("d2d Qsᵀ·score→rhs inplace: {e}"))?;
1308 }
1309 let rhs_raw = ws
1312 .stream
1313 .clone_dtoh(&ws.rhs_dev)
1314 .map_err(|e| format!("download Qsᵀ·score inplace: {e}"))?;
1315 let beta_raw = ws
1316 .stream
1317 .clone_dtoh(input.beta_dev)
1318 .map_err(|e| format!("download beta inplace: {e}"))?;
1319 let mut rhs_host = Array1::from_vec(rhs_raw);
1320 let beta_host = Array1::from_vec(beta_raw);
1321 let s_beta = input.penalty_hessian.dot(&beta_host);
1323 rhs_host -= &s_beta;
1324 rhs_host += &input.linear_shift;
1325 ws.stream
1326 .memcpy_htod(
1327 rhs_host.as_slice().ok_or("rhs_host not contiguous")?,
1328 &mut ws.rhs_dev,
1329 )
1330 .map_err(|e| format!("re-upload corrected rhs inplace: {e}"))?;
1331
1332 potrf_in_place_reuse(
1334 &ws.solver,
1335 &ws.stream,
1336 p,
1337 ws.potrf_lwork,
1338 &mut ws.h_dev,
1339 &mut ws.potrf_work_dev,
1340 &mut ws.potrf_info_dev,
1341 )?;
1342 potrs_in_place_reuse(
1343 &ws.solver,
1344 &ws.stream,
1345 p,
1346 1,
1347 &ws.h_dev,
1348 &mut ws.rhs_dev,
1349 &mut ws.potrs_info_dev,
1350 )?;
1351 let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
1352 check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
1353 check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
1354
1355 Ok(logdet)
1358 }
1359
1360 pub(super) fn rebuild_h_final(
1368 shared: &PirlsGpuSharedData,
1369 ws: &mut SigmaPirlsGpuWorkspace,
1370 w_hessian_dev: &CudaSlice<f64>,
1371 penalty_hessian: ArrayView2<'_, f64>,
1372 objective_ridge: f64,
1373 ) -> Result<Array2<f64>, String> {
1374 let n = shared.n;
1375 let p = shared.p;
1376
1377 if let Some(ref mut wx_dev_rh) = ws.wx_dev {
1379 left_scale_rows_borrowed(
1381 &ws.blas,
1382 &ws.stream,
1383 n,
1384 p,
1385 &shared.x_original_dev,
1386 w_hessian_dev,
1387 wx_dev_rh,
1388 )?;
1389 let n_i = to_i32(n)?;
1390 let p_i = to_i32(p)?;
1391 let gemm_cfg = GemmConfig::<f64> {
1392 transa: cublasOperation_t::CUBLAS_OP_T,
1393 transb: cublasOperation_t::CUBLAS_OP_N,
1394 m: p_i,
1395 n: p_i,
1396 k: n_i,
1397 alpha: 1.0,
1398 lda: n_i,
1399 ldb: n_i,
1400 beta: 0.0,
1401 ldc: p_i,
1402 };
1403 unsafe {
1406 ws.blas.gemm(
1407 gemm_cfg,
1408 &shared.x_original_dev,
1409 wx_dev_rh,
1410 &mut ws.xtwx_dev,
1411 )
1412 }
1413 .map_err(|e| format!("cublas dgemm XtWX (final H rebuild): {e}"))?;
1414 } else {
1415 launch_xtwx_lower(
1417 &ws.stream,
1418 &shared.ctx,
1419 n,
1420 p,
1421 &shared.x_original_dev,
1422 w_hessian_dev,
1423 &mut ws.xtwx_dev,
1424 )?;
1425 launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
1426 }
1427
1428 let p_i = to_i32(p)?;
1430 {
1432 let cfg_aq = GemmConfig::<f64> {
1433 transa: cublasOperation_t::CUBLAS_OP_N,
1434 transb: cublasOperation_t::CUBLAS_OP_N,
1435 m: p_i,
1436 n: p_i,
1437 k: p_i,
1438 alpha: 1.0,
1439 lda: p_i,
1440 ldb: p_i,
1441 beta: 0.0,
1442 ldc: p_i,
1443 };
1444 unsafe {
1446 ws.blas
1447 .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
1448 }
1449 .map_err(|e| format!("dgemm A·Qs (final H rebuild): {e}"))?;
1450 }
1451 {
1453 let cfg_qt = GemmConfig::<f64> {
1454 transa: cublasOperation_t::CUBLAS_OP_T,
1455 transb: cublasOperation_t::CUBLAS_OP_N,
1456 m: p_i,
1457 n: p_i,
1458 k: p_i,
1459 alpha: 1.0,
1460 lda: p_i,
1461 ldb: p_i,
1462 beta: 0.0,
1463 ldc: p_i,
1464 };
1465 unsafe {
1467 ws.blas
1468 .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
1469 }
1470 .map_err(|e| format!("dgemm Qsᵀ·A·Qs (final H rebuild): {e}"))?;
1471 }
1472 let penalty = penalty_with_ridge(penalty_hessian, objective_ridge);
1473 let penalty_col = to_col_major(&penalty);
1474 ws.stream
1475 .memcpy_htod(penalty_col.as_ref(), &mut ws.penalty_dev)
1476 .map_err(|e| format!("upload penalty (final H rebuild): {e}"))?;
1477 geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
1478
1479 let h_col = ws
1481 .stream
1482 .clone_dtoh(&ws.h_dev)
1483 .map_err(|e| format!("download H_final: {e}"))?;
1484 from_col_major(&h_col, p, p).ok_or_else(|| "H_final layout conversion failed".to_string())
1485 }
1486
1487 pub(super) fn weighted_crossprod(
1488 x: ArrayView2<'_, f64>,
1489 weights: ArrayView1<'_, f64>,
1490 ) -> Result<Array2<f64>, String> {
1491 let (_, stream) = context_and_stream()?;
1492 let (n, p) = validate_design(x, weights)?;
1493 let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
1494 let x_col = to_col_major(&x);
1495 let x_dev = pinned_htod(&stream, &x_col)?;
1496 let mut w_dev = pinned_htod(
1497 &stream,
1498 weights.as_slice().ok_or("weights must be contiguous")?,
1499 )?;
1500 let mut wx_dev = stream
1501 .alloc_zeros::<f64>(n.checked_mul(p).ok_or("X size overflow")?)
1502 .map_err(|e| format!("cuda alloc WX: {e}"))?;
1503 left_scale_rows(&blas, &stream, n, p, &x_dev, &mut w_dev, &mut wx_dev)?;
1504 let mut h_dev = stream
1505 .alloc_zeros::<f64>(p.checked_mul(p).ok_or("H size overflow")?)
1506 .map_err(|e| format!("cuda alloc H: {e}"))?;
1507 let n_i = to_i32(n)?;
1508 let p_i = to_i32(p)?;
1509 let cfg = GemmConfig::<f64> {
1510 transa: cublasOperation_t::CUBLAS_OP_T,
1511 transb: cublasOperation_t::CUBLAS_OP_N,
1512 m: p_i,
1513 n: p_i,
1514 k: n_i,
1515 alpha: 1.0,
1516 lda: n_i,
1517 ldb: n_i,
1518 beta: 0.0,
1519 ldc: p_i,
1520 };
1521 unsafe { blas.gemm(cfg, &x_dev, &wx_dev, &mut h_dev) }
1524 .map_err(|e| format!("cublas dgemm XtWX: {e}"))?;
1525 let h_col = stream
1526 .clone_dtoh(&h_dev)
1527 .map_err(|e| format!("download H: {e}"))?;
1528 from_col_major(&h_col, p, p).ok_or_else(|| "H layout conversion failed".to_string())
1529 }
1530
1531 pub(super) fn solve_step(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
1532 let (_, p) = validate_design(input.x, input.weights)?;
1538 if input.penalty_hessian.dim() != (p, p) {
1539 return Err(format!(
1540 "penalty Hessian shape {:?} does not match p={p}",
1541 input.penalty_hessian.dim()
1542 ));
1543 }
1544 if input.gradient.len() != p {
1545 return Err(format!(
1546 "gradient length {} does not match p={p}",
1547 input.gradient.len()
1548 ));
1549 }
1550 let n_rows = input.x.nrows();
1556 let zero_n = ndarray::Array1::<f64>::zeros(n_rows);
1557 let shared =
1558 PirlsGpuSharedData::upload_impl(input.x, zero_n.view(), zero_n.view(), zero_n.view())?;
1559 let mut ws = SigmaPirlsGpuWorkspace::allocate_impl(&shared)?;
1560 solve_step_on_stream(
1561 &shared,
1562 &mut ws,
1563 PirlsStepStreamInput {
1564 weights: input.weights,
1565 penalty_hessian: input.penalty_hessian,
1566 gradient: input.gradient,
1567 step_lm_lambda: input.step_lm_lambda,
1568 objective_ridge: input.objective_ridge,
1569 },
1570 )
1571 }
1572
1573 fn validate_design(
1574 x: ArrayView2<'_, f64>,
1575 weights: ArrayView1<'_, f64>,
1576 ) -> Result<(usize, usize), String> {
1577 let (n, p) = x.dim();
1578 if weights.len() != n {
1579 return Err(format!(
1580 "weights length {} does not match rows {n}",
1581 weights.len()
1582 ));
1583 }
1584 if n == 0 || p == 0 {
1585 return Err("empty design cannot be solved on CUDA".to_string());
1586 }
1587 Ok((n, p))
1588 }
1589
1590 fn left_scale_rows(
1591 blas: &CudaBlas,
1592 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1593 n: usize,
1594 p: usize,
1595 x_dev: &CudaSlice<f64>,
1596 w_dev: &mut CudaSlice<f64>,
1597 wx_dev: &mut CudaSlice<f64>,
1598 ) -> Result<(), String> {
1599 let n_i = to_i32(n)?;
1600 let p_i = to_i32(p)?;
1601 let handle = *blas.handle();
1602 let (x_ptr, _x_record) = x_dev.device_ptr(stream);
1603 let (w_ptr, _w_record) = w_dev.device_ptr(stream);
1604 let (wx_ptr, _wx_record) = wx_dev.device_ptr_mut(stream);
1605 let status = unsafe {
1608 cublasDdgmm(
1609 handle,
1610 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1611 n_i,
1612 p_i,
1613 x_ptr as *const f64,
1614 n_i,
1615 w_ptr as *const f64,
1616 1,
1617 wx_ptr as *mut f64,
1618 n_i,
1619 )
1620 };
1621 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1622 Ok(())
1623 } else {
1624 Err(format!("cublasDdgmm failed with {status:?}"))
1625 }
1626 }
1627
1628 fn left_scale_rows_borrowed(
1633 blas: &CudaBlas,
1634 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1635 n: usize,
1636 p: usize,
1637 x_dev: &CudaSlice<f64>,
1638 w_dev: &CudaSlice<f64>,
1639 wx_dev: &mut CudaSlice<f64>,
1640 ) -> Result<(), String> {
1641 let n_i = to_i32(n)?;
1642 let p_i = to_i32(p)?;
1643 let handle = *blas.handle();
1644 let (x_ptr, _x_record) = x_dev.device_ptr(stream);
1645 let (w_ptr, _w_record) = w_dev.device_ptr(stream);
1646 let (wx_ptr, _wx_record) = wx_dev.device_ptr_mut(stream);
1647 let status = unsafe {
1652 cublasDdgmm(
1653 handle,
1654 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1655 n_i,
1656 p_i,
1657 x_ptr as *const f64,
1658 n_i,
1659 w_ptr as *const f64,
1660 1,
1661 wx_ptr as *mut f64,
1662 n_i,
1663 )
1664 };
1665 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1666 Ok(())
1667 } else {
1668 Err(format!("cublasDdgmm (borrowed) failed with {status:?}"))
1669 }
1670 }
1671
1672 fn geam_add_inplace(
1680 blas: &CudaBlas,
1681 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1682 p: usize,
1683 a: &mut CudaSlice<f64>,
1684 b: &CudaSlice<f64>,
1685 ) -> Result<(), String> {
1686 let p_i = to_i32(p)?;
1687 let alpha = 1.0_f64;
1688 let beta = 1.0_f64;
1689 let handle = *blas.handle();
1690 let (b_ptr, _b_record) = b.device_ptr(stream);
1691 let (a_ptr, _a_record) = a.device_ptr_mut(stream);
1692 let out_ptr = a_ptr;
1694 let status = unsafe {
1697 cublasDgeam(
1698 handle,
1699 cublasOperation_t::CUBLAS_OP_N,
1700 cublasOperation_t::CUBLAS_OP_N,
1701 p_i,
1702 p_i,
1703 &alpha,
1704 a_ptr as *const f64,
1705 p_i,
1706 &beta,
1707 b_ptr as *const f64,
1708 p_i,
1709 out_ptr as *mut f64,
1710 p_i,
1711 )
1712 };
1713 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1714 Ok(())
1715 } else {
1716 Err(format!("cublasDgeam failed with {status:?}"))
1717 }
1718 }
1719
1720 fn launch_xtwx_lower(
1724 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1725 ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1726 n: usize,
1727 p: usize,
1728 x_dev: &CudaSlice<f64>,
1729 w_dev: &CudaSlice<f64>,
1730 a_dev: &mut CudaSlice<f64>,
1731 ) -> Result<(), String> {
1732 let module = FUSED_XTWX_CACHE
1733 .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
1734 .map_err(|e| format!("fused_xtwx module: {e}"))?;
1735 let func = module
1736 .load_function("xtwx_lower")
1737 .map_err(|e| format!("load xtwx_lower: {e}"))?;
1738 let n_i = to_i32(n)?;
1739 let p_i = to_i32(p)?;
1740 let num_pairs = p * (p + 1) / 2;
1741 let num_pairs_u32 = u32::try_from(num_pairs)
1742 .map_err(|_| format!("xtwx_lower: num_pairs {num_pairs} > u32"))?;
1743 const BLOCK: u32 = 256;
1744 let grid = num_pairs_u32.div_ceil(BLOCK).max(1);
1745 let cfg = cudarc::driver::LaunchConfig {
1746 grid_dim: (grid, 1, 1),
1747 block_dim: (BLOCK, 1, 1),
1748 shared_mem_bytes: 0,
1749 };
1750 let mut builder = stream.launch_builder(&func);
1751 builder.arg(x_dev);
1752 builder.arg(w_dev);
1753 builder.arg(a_dev);
1754 builder.arg(&n_i);
1755 builder.arg(&p_i);
1756 unsafe { builder.launch(cfg) }
1759 .map_err(|e| format!("xtwx_lower launch: {e}"))
1760 .map(|_| ())
1761 }
1762
1763 fn launch_xtscore(
1766 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1767 ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1768 n: usize,
1769 p: usize,
1770 x_dev: &CudaSlice<f64>,
1771 score_dev: &CudaSlice<f64>,
1772 s_dev: &mut CudaSlice<f64>,
1773 ) -> Result<(), String> {
1774 let module = FUSED_XTWX_CACHE
1775 .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
1776 .map_err(|e| format!("fused_xtwx module (xtscore): {e}"))?;
1777 let func = module
1778 .load_function("xtscore")
1779 .map_err(|e| format!("load xtscore: {e}"))?;
1780 let n_i = to_i32(n)?;
1781 let p_i = to_i32(p)?;
1782 let p_u32 = u32::try_from(p).map_err(|_| format!("xtscore: p {p} > u32"))?;
1783 const BLOCK: u32 = 256;
1784 let grid = p_u32.div_ceil(BLOCK).max(1);
1785 let cfg = cudarc::driver::LaunchConfig {
1786 grid_dim: (grid, 1, 1),
1787 block_dim: (BLOCK, 1, 1),
1788 shared_mem_bytes: 0,
1789 };
1790 let mut builder = stream.launch_builder(&func);
1791 builder.arg(x_dev);
1792 builder.arg(score_dev);
1793 builder.arg(s_dev);
1794 builder.arg(&n_i);
1795 builder.arg(&p_i);
1796 unsafe { builder.launch(cfg) }
1799 .map_err(|e| format!("xtscore launch: {e}"))
1800 .map(|_| ())
1801 }
1802
1803 fn launch_symmetrize_lower(
1807 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1808 ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1809 p: usize,
1810 a_dev: &mut CudaSlice<f64>,
1811 ) -> Result<(), String> {
1812 if p <= 1 {
1813 return Ok(());
1814 }
1815 let module = FUSED_XTWX_CACHE
1816 .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
1817 .map_err(|e| format!("fused_xtwx module (sym): {e}"))?;
1818 let func = module
1819 .load_function("symmetrize_lower")
1820 .map_err(|e| format!("load symmetrize_lower: {e}"))?;
1821 let p_i = to_i32(p)?;
1822 let num_strict = p * (p - 1) / 2;
1823 let num_strict_u32 = u32::try_from(num_strict)
1824 .map_err(|_| format!("symmetrize_lower: num_strict {num_strict} > u32"))?;
1825 const BLOCK: u32 = 256;
1826 let grid = num_strict_u32.div_ceil(BLOCK).max(1);
1827 let cfg = cudarc::driver::LaunchConfig {
1828 grid_dim: (grid, 1, 1),
1829 block_dim: (BLOCK, 1, 1),
1830 shared_mem_bytes: 0,
1831 };
1832 let mut builder = stream.launch_builder(&func);
1833 builder.arg(a_dev);
1834 builder.arg(&p_i);
1835 unsafe { builder.launch(cfg) }
1838 .map_err(|e| format!("symmetrize_lower launch: {e}"))
1839 .map(|_| ())
1840 }
1841
1842 fn cholesky_logdet_device(
1847 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1848 ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1849 p: usize,
1850 factor_dev: &CudaSlice<f64>,
1851 ) -> Result<f64, String> {
1852 let module = CHOL_LOGDET_CACHE
1853 .get_or_compile(ctx, "pirls_gpu_chol_logdet", CHOL_LOGDET_PTX_SOURCE)
1854 .map_err(|err| format!("chol_logdet module: {err}"))?;
1855 let func = module
1856 .load_function("chol_logdet_col_major")
1857 .map_err(|err| format!("chol_logdet load_function: {err}"))?;
1858 let mut out_dev = stream
1859 .alloc_zeros::<f64>(1)
1860 .map_err(|err| format!("alloc chol_logdet out: {err}"))?;
1861 let p_i = to_i32(p)?;
1862 let cfg = LaunchConfig {
1863 grid_dim: (1, 1, 1),
1864 block_dim: (1, 1, 1),
1865 shared_mem_bytes: 0,
1866 };
1867 let mut builder = stream.launch_builder(&func);
1868 builder.arg(factor_dev);
1869 builder.arg(&p_i);
1870 builder.arg(&mut out_dev);
1871 unsafe { builder.launch(cfg) }.map_err(|err| format!("chol_logdet launch: {err}"))?;
1876 let out_host = stream
1877 .clone_dtoh(&out_dev)
1878 .map_err(|err| format!("download chol_logdet: {err}"))?;
1879 Ok(out_host[0])
1880 }
1881
1882 fn penalty_with_ridge(penalty: ArrayView2<'_, f64>, ridge: f64) -> Array2<f64> {
1883 let mut out = penalty.to_owned();
1884 if ridge != 0.0 {
1885 for i in 0..out.nrows().min(out.ncols()) {
1886 out[[i, i]] += ridge;
1887 }
1888 }
1889 out
1890 }
1891
1892 fn to_i32(value: usize) -> Result<i32, String> {
1893 i32::try_from(value).map_err(|_| format!("CUDA dimension {value} exceeds i32"))
1894 }
1895
1896 const PIRLS_LOOP_PTX_SOURCE: &str = r#"
1903extern "C" {
1904 double fabs(double);
1905}
1906
1907extern "C" __global__ void axpy_n(
1908 double alpha,
1909 const double* __restrict__ x,
1910 double* __restrict__ y,
1911 int n
1912) {
1913 int i = blockIdx.x * blockDim.x + threadIdx.x;
1914 if (i >= n) return;
1915 y[i] += alpha * x[i];
1916}
1917
1918extern "C" __global__ void deviance_sum(
1919 const double* __restrict__ d,
1920 int n,
1921 double* __restrict__ out
1922) {
1923 __shared__ double sm[1024];
1924 int tid = threadIdx.x;
1925 int bdim = blockDim.x;
1926 double acc = 0.0;
1927 for (int i = tid; i < n; i += bdim) {
1928 acc += d[i];
1929 }
1930 sm[tid] = acc;
1931 __syncthreads();
1932 for (int stride = bdim / 2; stride > 0; stride >>= 1) {
1933 if (tid < stride) sm[tid] += sm[tid + stride];
1934 __syncthreads();
1935 }
1936 if (tid == 0) out[0] = sm[0];
1937}
1938
1939extern "C" __global__ void linf_norm(
1940 const double* __restrict__ v,
1941 int p,
1942 double* __restrict__ out
1943) {
1944 __shared__ double sm[1024];
1945 int tid = threadIdx.x;
1946 int bdim = blockDim.x;
1947 double acc = 0.0;
1948 for (int i = tid; i < p; i += bdim) {
1949 double a = fabs(v[i]);
1950 if (a > acc) acc = a;
1951 }
1952 sm[tid] = acc;
1953 __syncthreads();
1954 for (int stride = bdim / 2; stride > 0; stride >>= 1) {
1955 if (tid < stride) {
1956 double r = sm[tid + stride];
1957 if (r > sm[tid]) sm[tid] = r;
1958 }
1959 __syncthreads();
1960 }
1961 if (tid == 0) out[0] = sm[0];
1962}
1963
1964extern "C" __global__ void negate_n(
1965 double* __restrict__ v,
1966 int n
1967) {
1968 int i = blockIdx.x * blockDim.x + threadIdx.x;
1969 if (i >= n) return;
1970 v[i] = -v[i];
1971}
1972
1973// OR-reduction over a u32 status array (length n). Single-block;
1974// same launch config as deviance_sum (1 block of 1024 threads).
1975// out[0] receives the bitwise-OR of all status[i] for i in [0, n).
1976extern "C" __global__ void status_or(
1977 const unsigned int* __restrict__ status,
1978 int n,
1979 unsigned int* __restrict__ out
1980) {
1981 __shared__ unsigned int sm[1024];
1982 int tid = threadIdx.x;
1983 int bdim = blockDim.x;
1984 unsigned int acc = 0u;
1985 for (int i = tid; i < n; i += bdim) {
1986 acc |= status[i];
1987 }
1988 sm[tid] = acc;
1989 __syncthreads();
1990 for (int stride = bdim / 2; stride > 0; stride >>= 1) {
1991 if (tid < stride) sm[tid] |= sm[tid + stride];
1992 __syncthreads();
1993 }
1994 if (tid == 0) out[0] = sm[0];
1995}
1996"#;
1997
    /// Module cache for `PIRLS_LOOP_PTX_SOURCE`; `pirls_loop` fetches the compiled module via
    /// `get_or_compile(&shared.ctx, "pirls_loop", PIRLS_LOOP_PTX_SOURCE)` (presumably reusing
    /// it across calls on the same context — confirm in `PtxModuleCache`).
    static PIRLS_LOOP_CACHE: PtxModuleCache = PtxModuleCache::new();
1999
    /// Device buffers used by one `pirls_loop` run, sized for a fixed `(n, p)`.
    ///
    /// Build with `PirlsLoopWorkspace::allocate`. `pirls_loop` rejects a workspace whose
    /// `(n, p)` differs from the `PirlsGpuSharedData` it is run against.
    pub struct PirlsLoopWorkspace {
        /// Current coefficients (length `p`); `pirls_loop` maps them back to the original
        /// design basis with `Qs · beta`.
        pub beta_dev: CudaSlice<f64>,
        /// Current linear predictor `X · (Qs · beta) + offset` (length `n`).
        pub eta_dev: CudaSlice<f64>,
        /// Per-row working quantities (deviance, solver weights, `grad_eta`) recomputed at
        /// the current `eta` after every accepted step.
        pub row_solve: crate::gpu_kernels::pirls_row::SolveRowBuffers,
        /// Per-step-length deviance (`objective_dev`) and row status (`status_dev`) for the
        /// alpha ladder; zeroed before each launch.
        pub alpha_ladder: crate::gpu_kernels::pirls_row::AlphaLadderDevBuffers,
        /// Final per-row outputs (mu, grad_eta, w_hessian, w_solver, status) computed once
        /// the loop stops.
        pub row_final: crate::gpu_kernels::pirls_row::RowOutputDevBuffers,
        /// Newton direction for the current iteration (length `p`), copied from the step
        /// solver's right-hand-side buffer.
        pub direction_dev: CudaSlice<f64>,
        /// Direction mapped to linear-predictor space, `X · (Qs · direction)` (length `n`).
        pub xd_dev: CudaSlice<f64>,
        /// One-element scratch for the scalar reductions (deviance sum, L∞ norm).
        pub scalar_dev: CudaSlice<f64>,
        /// One-element scratch for the per-row status OR reduction.
        pub status_u32_dev: CudaSlice<u32>,
        /// Number of rows this workspace was allocated for.
        pub n: usize,
        /// Number of coefficients this workspace was allocated for.
        pub p: usize,
    }
2023
2024 impl PirlsLoopWorkspace {
2025 pub fn allocate(
2026 shared: &PirlsGpuSharedData,
2027 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
2028 ) -> Result<Self, String> {
2029 let n = shared.n;
2030 let p = shared.p;
2031 let alloc_f64 = |label: &'static str, len: usize| {
2032 stream
2033 .alloc_zeros::<f64>(len)
2034 .map_err(|e| format!("pirls loop alloc {label}: {e}"))
2035 };
2036 Ok(Self {
2037 beta_dev: alloc_f64("beta", p)?,
2038 eta_dev: alloc_f64("eta", n)?,
2039 row_solve: crate::gpu_kernels::pirls_row::SolveRowBuffers::allocate(stream, n)
2040 .map_err(|e| format!("pirls loop alloc row_solve: {e}"))?,
2041 alpha_ladder: crate::gpu_kernels::pirls_row::AlphaLadderDevBuffers::allocate(
2042 stream,
2043 )
2044 .map_err(|e| format!("pirls loop alloc alpha_ladder: {e}"))?,
2045 row_final: crate::gpu_kernels::pirls_row::RowOutputDevBuffers::allocate(stream, n)
2046 .map_err(|e| format!("pirls loop alloc row_final: {e}"))?,
2047 direction_dev: alloc_f64("direction", p)?,
2048 xd_dev: alloc_f64("xd", n)?,
2049 scalar_dev: alloc_f64("scalar", 1)?,
2050 status_u32_dev: stream
2051 .alloc_zeros::<u32>(1)
2052 .map_err(|e| format!("pirls loop alloc status_u32: {e}"))?,
2053 n,
2054 p,
2055 })
2056 }
2057 }
2058
    /// Optional host-side context that lets `build_loop_outcome` fill the full
    /// `PirlsLoopOutcome`: working-weight derivatives, curvature arrays, constraint KKT
    /// diagnostics and pass-through metadata. Without it, those outputs are left empty.
    pub struct PirlsLoopExtra<'a> {
        /// Likelihood spec passed to the working-weight / curvature post-pass.
        pub likelihood: &'a gam_problem::GlmLikelihoodSpec,
        /// Inverse link passed to the working-weight / curvature post-pass.
        pub inverse_link: &'a gam_problem::InverseLink,
        /// Response vector; used for observed-curvature arrays.
        pub y: ndarray::ArrayView1<'a, f64>,
        /// Prior weights used by the host post-pass.
        pub priorweights: ndarray::ArrayView1<'a, f64>,
        /// Offset; copied verbatim into `PirlsLoopOutcome::final_offset`.
        pub offset: ndarray::ArrayView1<'a, f64>,
        /// Linear inequality constraints; when present with at least one row, KKT
        /// diagnostics are computed at the final beta.
        pub linear_constraints: Option<&'a gam_problem::LinearInequalityConstraints>,
        /// Which curvature (Observed or Fisher) to export as `finalweights` and the solve
        /// c/d arrays.
        pub exported_curvature: crate::pirls::HessianCurvatureKind,
        /// Ridge passport to report; a scaled-identity passport built from `objective_ridge`
        /// is used when `None`.
        pub ridge_passport: Option<gam_problem::RidgePassport>,
        /// Firth diagnostics to report; `FirthDiagnostics::Inactive` when `None`.
        pub firth: Option<crate::pirls::FirthDiagnostics>,
        /// NOTE(review): presumably the host copy of the Qs basis transform; its use is not
        /// visible in this section — confirm in `build_loop_outcome`.
        pub qs: Option<ndarray::ArrayView2<'a, f64>>,
        /// Effective degrees of freedom to report; `NaN` when `None`.
        pub edf: Option<f64>,
    }
2151
    /// Result of a device PIRLS loop run, downloaded to the host.
    #[derive(Clone, Debug)]
    pub struct PirlsLoopOutcome {
        /// Final coefficients, downloaded from the loop's beta buffer.
        pub beta: Array1<f64>,
        /// Penalized Hessian rebuilt at the final eta (`Qsᵀ XᵀWX Qs + S + ridge·I`).
        pub penalized_hessian: Array2<f64>,
        /// Log-determinant reported by the last inner Newton solve. It is not recomputed from
        /// `penalized_hessian`.
        pub logdet: f64,
        /// Deviance at the final coefficients.
        pub deviance: f64,
        /// Number of loop iterations run.
        pub iterations: usize,
        /// Whether the convergence test passed.
        pub converged: bool,
        /// Final linear predictor.
        pub final_eta: Array1<f64>,
        /// Final mean, from the final per-row pass.
        pub final_mu: Array1<f64>,
        /// Final per-row gradient with respect to eta.
        pub final_grad_eta: Array1<f64>,
        /// Final per-row Hessian weights (used to rebuild `penalized_hessian`).
        pub final_w_hessian: Array1<f64>,
        /// Final per-row solver weights.
        pub final_w_solver: Array1<f64>,
        /// Offset copied from `PirlsLoopExtra`; empty without extra context.
        pub final_offset: Array1<f64>,
        /// Copy of `beta` (same basis as `beta`).
        pub beta_transformed: Array1<f64>,
        /// Exported curvature weights (observed or Fisher); empty without extra context.
        pub finalweights: Array1<f64>,
        /// Solver weights (same values as `final_w_solver`).
        pub solveweights: Array1<f64>,
        /// First derivative of mu with respect to eta at the final eta; empty without extra
        /// context.
        pub solve_dmu_deta: Array1<f64>,
        /// Second derivative of mu with respect to eta; empty without extra context.
        pub solve_d2mu_deta2: Array1<f64>,
        /// Third derivative of mu with respect to eta; empty without extra context.
        pub solve_d3mu_deta3: Array1<f64>,
        /// Curvature c-array matching `finalweights`; empty without extra context.
        pub solve_c_array: Array1<f64>,
        /// Curvature d-array matching `finalweights`; empty without extra context.
        pub solve_d_array: Array1<f64>,
        /// `true` when the derivative arrays above were not computed (no extra context).
        pub derivatives_unsupported: bool,
        /// Overall status: Unstable, Converged, LmStepSearchExhausted or
        /// MaxIterationsReached.
        pub status: crate::pirls::PirlsStatus,
        /// Ridge passport reported for this fit.
        pub ridge_passport: gam_problem::RidgePassport,
        /// Firth diagnostics reported for this fit.
        pub firth: crate::pirls::FirthDiagnostics,
        /// Constraint KKT diagnostics at the final beta, when constraints were supplied.
        pub constraint_kkt: Option<crate::active_set::ConstraintKktDiagnostics>,
        /// Effective degrees of freedom (`NaN` when not supplied).
        pub edf: f64,
        /// Absolute change of the penalized objective at the last accepted step; 0 when the
        /// step search was exhausted.
        pub last_deviance_change: f64,
        /// Ladder index of the last accepted step length.
        pub last_step_halving: usize,
        /// Last accepted step length (0 when the step search was exhausted).
        pub last_step_size: f64,
        /// LM damping used for every inner step.
        pub final_lm_lambda: f64,
        /// Smallest deviance seen across the run, including the starting point.
        pub min_deviance: f64,
        /// Largest |eta| at the final linear predictor.
        pub max_abs_eta: f64,
        /// Bitwise OR of the final per-row status flags.
        pub per_row_status_or: u32,
    }
2274
2275 pub(super) fn pirls_loop(
2279 shared: &PirlsGpuSharedData,
2280 ws: &mut SigmaPirlsGpuWorkspace,
2281 loop_ws: &mut PirlsLoopWorkspace,
2282 family: crate::gpu_kernels::pirls_row::PirlsRowFamily,
2283 curvature: crate::gpu_kernels::pirls_row::CurvatureMode,
2284 gamma_shape: f64,
2287 beta0_host: ArrayView1<'_, f64>,
2288 penalty_hessian: ArrayView2<'_, f64>,
2289 linear_shift: ArrayView1<'_, f64>,
2294 constant_shift: f64,
2297 lm_ridge: f64,
2300 objective_ridge: f64,
2303 max_iter: usize,
2304 tol: f64,
2305 extra: Option<&PirlsLoopExtra<'_>>,
2306 ) -> Result<PirlsLoopOutcome, String> {
2307 let n = shared.n;
2308 let p = shared.p;
2309 if loop_ws.n != n || loop_ws.p != p {
2310 return Err(format!(
2311 "loop workspace ({}, {}) ≠ shared ({n}, {p})",
2312 loop_ws.n, loop_ws.p
2313 ));
2314 }
2315 if beta0_host.len() != p {
2316 return Err(format!("beta0 length {} ≠ p={p}", beta0_host.len()));
2317 }
2318
2319 if linear_shift.len() != p {
2320 return Err(format!(
2321 "linear_shift length {} ≠ p={p}",
2322 linear_shift.len()
2323 ));
2324 }
2325 if penalty_hessian.dim() != (p, p) {
2326 return Err(format!(
2327 "penalty_hessian shape {:?} ≠ (p={p}, p={p})",
2328 penalty_hessian.dim()
2329 ));
2330 }
2331
2332 ws.stream
2333 .memcpy_htod(
2334 beta0_host.as_slice().ok_or("beta0 not contiguous")?,
2335 &mut loop_ws.beta_dev,
2336 )
2337 .map_err(|e| format!("upload beta0: {e}"))?;
2338
2339 let backend = crate::gpu_kernels::pirls_row::PirlsRowBackend::probe()
2340 .map_err(|e| format!("pirls_row backend: {e}"))?;
2341 let loop_module = PIRLS_LOOP_CACHE
2342 .get_or_compile(&shared.ctx, "pirls_loop", PIRLS_LOOP_PTX_SOURCE)
2343 .map_err(|e| format!("pirls loop module: {e}"))?;
2344 let axpy_func = loop_module
2345 .load_function("axpy_n")
2346 .map_err(|e| format!("load axpy_n: {e}"))?;
2347 let sum_func = loop_module
2348 .load_function("deviance_sum")
2349 .map_err(|e| format!("load deviance_sum: {e}"))?;
2350 let linf_func = loop_module
2351 .load_function("linf_norm")
2352 .map_err(|e| format!("load linf_norm: {e}"))?;
2353 let status_or_func = loop_module
2354 .load_function("status_or")
2355 .map_err(|e| format!("load status_or: {e}"))?;
2356
2357 gemv_no_trans(
2360 &ws.blas,
2361 p,
2362 p,
2363 &ws.qs_dev,
2364 &loop_ws.beta_dev,
2365 &mut ws.beta_orig_dev,
2366 )?;
2367 gemv_no_trans(
2369 &ws.blas,
2370 n,
2371 p,
2372 &shared.x_original_dev,
2373 &ws.beta_orig_dev,
2374 &mut loop_ws.eta_dev,
2375 )?;
2376 axpy(
2377 &ws.stream,
2378 &axpy_func,
2379 1.0,
2380 &shared.offset_dev,
2381 &mut loop_ws.eta_dev,
2382 n,
2383 )?;
2384 crate::gpu_kernels::pirls_row::launch_solve_row_on_stream(
2386 backend,
2387 family,
2388 curvature,
2389 gamma_shape,
2390 &ws.stream,
2391 n,
2392 &loop_ws.eta_dev,
2393 &shared.y_dev,
2394 &shared.prior_w_dev,
2395 &mut loop_ws.row_solve,
2396 )
2397 .map_err(|e| format!("solve-row init: {e}"))?;
2398
2399 let mut prev_deviance = reduce_scalar(
2400 &ws.stream,
2401 &sum_func,
2402 &loop_ws.row_solve.deviance,
2403 n,
2404 &mut loop_ws.scalar_dev,
2405 "deviance_init",
2406 )?;
2407 let mut last_logdet = 0.0_f64;
2408 let mut converged = false;
2409
2410 let mut beta_host: Array1<f64> = beta0_host.to_owned();
2416
2417 let s_beta0 = penalty_hessian.dot(&beta_host);
2422 let penalty_init =
2423 beta_host.dot(&s_beta0) - 2.0 * beta_host.dot(&linear_shift) + constant_shift;
2424 let mut prev_objective = prev_deviance + penalty_init;
2425
2426 let mut last_dev_delta = 0.0_f64;
2434 let mut last_halving: usize = 0;
2435 let mut last_step_size = 0.0_f64;
2436 let mut min_dev = prev_deviance;
2437 let mut step_search_exhausted = false;
2438
2439 for it in 0..max_iter {
2440 last_logdet = solve_step_on_stream_device_inplace(
2441 shared,
2442 ws,
2443 PirlsStepStreamDeviceInput {
2444 w_solver_dev: &loop_ws.row_solve.w_solver,
2445 grad_eta_dev: &loop_ws.row_solve.grad_eta,
2446 penalty_hessian,
2447 step_lm_lambda: lm_ridge,
2448 objective_ridge,
2449 beta_dev: &loop_ws.beta_dev,
2450 linear_shift,
2451 },
2452 )
2453 .map_err(|e| format!("inner step it={it}: {e}"))?;
2454 ws.stream
2457 .memcpy_dtod(&ws.rhs_dev, &mut loop_ws.direction_dev)
2458 .map_err(|e| format!("direction d2d copy it={it}: {e}"))?;
2459
2460 let dir_linf = reduce_scalar(
2461 &ws.stream,
2462 &linf_func,
2463 &loop_ws.direction_dev,
2464 p,
2465 &mut loop_ws.scalar_dev,
2466 "dir_linf",
2467 )?;
2468
2469 gemv_no_trans(
2471 &ws.blas,
2472 p,
2473 p,
2474 &ws.qs_dev,
2475 &loop_ws.direction_dev,
2476 &mut ws.dir_orig_dev,
2477 )?;
2478 gemv_no_trans(
2479 &ws.blas,
2480 n,
2481 p,
2482 &shared.x_original_dev,
2483 &ws.dir_orig_dev,
2484 &mut loop_ws.xd_dev,
2485 )?;
2486
2487 loop_ws
2495 .alpha_ladder
2496 .zero(&ws.stream)
2497 .map_err(|e| format!("ladder zero it={it}: {e}"))?;
2498 crate::gpu_kernels::pirls_row::launch_alpha_ladder_on_stream(
2499 backend,
2500 family,
2501 curvature,
2502 gamma_shape,
2503 &ws.stream,
2504 n,
2505 &loop_ws.eta_dev,
2506 &loop_ws.xd_dev,
2507 &shared.y_dev,
2508 &shared.prior_w_dev,
2509 &mut loop_ws.alpha_ladder,
2510 )
2511 .map_err(|e| format!("alpha-ladder it={it}: {e}"))?;
2512 let obj_host: Vec<f64> = ws
2513 .stream
2514 .clone_dtoh(&loop_ws.alpha_ladder.objective_dev)
2515 .map_err(|e| format!("ladder dtoh obj it={it}: {e}"))?;
2516 let stat_host: Vec<u32> = ws
2517 .stream
2518 .clone_dtoh(&loop_ws.alpha_ladder.status_dev)
2519 .map_err(|e| format!("ladder dtoh stat it={it}: {e}"))?;
2520 let direction_host: Vec<f64> = ws
2523 .stream
2524 .clone_dtoh(&loop_ws.direction_dev)
2525 .map_err(|e| format!("dtoh direction it={it}: {e}"))?;
2526
2527 let dir_view = ndarray::aview1(&direction_host);
2536 let sd = penalty_hessian.dot(&dir_view);
2537 let s_beta = penalty_hessian.dot(&beta_host);
2538 let dtsd = dir_view.dot(&sd);
2539 let linear_coeff = 2.0 * dir_view.dot(&(&s_beta - &linear_shift));
2540 let penalty_beta =
2541 beta_host.dot(&s_beta) - 2.0 * beta_host.dot(&linear_shift) + constant_shift;
2542
2543 const FORBIDDEN_LINESEARCH: u32 =
2544 crate::gpu_kernels::pirls_row::status_flags::INVALID_RESPONSE
2545 | crate::gpu_kernels::pirls_row::status_flags::ZERO_PRIOR_WEIGHT;
2546 let mut alpha = 0.0_f64;
2547 let mut accepted_dev = prev_deviance;
2548 let mut accepted_objective = prev_objective;
2549 let mut halving_count: usize = 0;
2550 for (k, (&dev_k, &st)) in obj_host.iter().zip(stat_host.iter()).enumerate() {
2551 let a = crate::gpu_kernels::pirls_row::ALPHA_LADDER[k];
2552 let pen_k = penalty_beta + a * linear_coeff + a * a * dtsd;
2553 let obj_k = dev_k + pen_k;
2554 if obj_k.is_finite() && obj_k <= prev_objective && (st & FORBIDDEN_LINESEARCH) == 0
2562 {
2563 alpha = a;
2564 accepted_dev = dev_k;
2565 accepted_objective = obj_k;
2566 halving_count = k;
2567 break;
2568 }
2569 }
2570 if alpha == 0.0 {
2571 step_search_exhausted = true;
2590 last_halving = 0;
2591 last_step_size = 0.0;
2592 last_dev_delta = 0.0;
2593 break;
2594 }
2595 step_search_exhausted = false;
2596 axpy(
2598 &ws.stream,
2599 &axpy_func,
2600 alpha,
2601 &loop_ws.direction_dev,
2602 &mut loop_ws.beta_dev,
2603 p,
2604 )?;
2605 axpy(
2606 &ws.stream,
2607 &axpy_func,
2608 alpha,
2609 &loop_ws.xd_dev,
2610 &mut loop_ws.eta_dev,
2611 n,
2612 )?;
2613 for (b, &d) in beta_host.iter_mut().zip(direction_host.iter()) {
2615 *b += alpha * d;
2616 }
2617 crate::gpu_kernels::pirls_row::launch_solve_row_on_stream(
2619 backend,
2620 family,
2621 curvature,
2622 gamma_shape,
2623 &ws.stream,
2624 n,
2625 &loop_ws.eta_dev,
2626 &shared.y_dev,
2627 &shared.prior_w_dev,
2628 &mut loop_ws.row_solve,
2629 )
2630 .map_err(|e| format!("solve-row accepted it={it}: {e}"))?;
2631
2632 let step_norm = alpha.abs() * dir_linf;
2633 let dev_delta = (prev_objective - accepted_objective).abs();
2634 last_dev_delta = dev_delta;
2635 last_halving = halving_count;
2636 last_step_size = alpha;
2637 if accepted_dev < min_dev {
2638 min_dev = accepted_dev;
2639 }
2640
2641 prev_deviance = accepted_dev;
2642 prev_objective = accepted_objective;
2643
2644 if dir_linf <= tol
2645 && step_norm <= tol
2646 && dev_delta <= tol * (1.0 + prev_objective.abs())
2647 {
2648 converged = true;
2649 crate::gpu_kernels::pirls_row::launch_row_reweight_on_stream(
2651 backend,
2652 family,
2653 curvature,
2654 gamma_shape,
2655 &ws.stream,
2656 n,
2657 &loop_ws.eta_dev,
2658 &shared.y_dev,
2659 &shared.prior_w_dev,
2660 &mut loop_ws.row_final,
2661 )
2662 .map_err(|e| format!("final-row converged: {e}"))?;
2663 let h_final = rebuild_h_final(
2664 shared,
2665 ws,
2666 &loop_ws.row_final.w_hessian,
2667 penalty_hessian,
2668 objective_ridge,
2669 )
2670 .map_err(|e| format!("rebuild H_final (converged): {e}"))?;
2671 return build_loop_outcome(
2672 ws,
2673 loop_ws,
2674 h_final,
2675 last_logdet,
2676 prev_deviance,
2677 it + 1,
2678 converged,
2679 lm_ridge,
2680 objective_ridge,
2681 extra,
2682 LoopDiagnostics {
2683 last_deviance_change: last_dev_delta,
2684 last_step_halving: last_halving,
2685 last_step_size,
2686 min_deviance: min_dev,
2687 step_search_exhausted,
2688 },
2689 &status_or_func,
2690 );
2691 }
2692 }
2693
2694 crate::gpu_kernels::pirls_row::launch_row_reweight_on_stream(
2696 backend,
2697 family,
2698 curvature,
2699 gamma_shape,
2700 &ws.stream,
2701 n,
2702 &loop_ws.eta_dev,
2703 &shared.y_dev,
2704 &shared.prior_w_dev,
2705 &mut loop_ws.row_final,
2706 )
2707 .map_err(|e| format!("final-row max_iter: {e}"))?;
2708 let h_final = rebuild_h_final(
2709 shared,
2710 ws,
2711 &loop_ws.row_final.w_hessian,
2712 penalty_hessian,
2713 objective_ridge,
2714 )
2715 .map_err(|e| format!("rebuild H_final (max_iter): {e}"))?;
2716 build_loop_outcome(
2717 ws,
2718 loop_ws,
2719 h_final,
2720 last_logdet,
2721 prev_deviance,
2722 max_iter,
2723 converged,
2724 lm_ridge,
2725 objective_ridge,
2726 extra,
2727 LoopDiagnostics {
2728 last_deviance_change: last_dev_delta,
2729 last_step_halving: last_halving,
2730 last_step_size,
2731 min_deviance: min_dev,
2732 step_search_exhausted,
2733 },
2734 &status_or_func,
2735 )
2736 }
2737
    /// Step-history diagnostics that `pirls_loop` hands to `build_loop_outcome`.
    struct LoopDiagnostics {
        /// |Δ penalized objective| at the last accepted step (0 if the search was exhausted).
        last_deviance_change: f64,
        /// Ladder index of the last accepted step length (0 if the search was exhausted).
        last_step_halving: usize,
        /// Last accepted step length `alpha` (0 if the search was exhausted).
        last_step_size: f64,
        /// Smallest deviance observed, including the starting point.
        min_deviance: f64,
        /// `true` when the loop stopped because no ladder entry was acceptable.
        step_search_exhausted: bool,
    }
2757
2758 fn build_loop_outcome(
2773 ws: &mut SigmaPirlsGpuWorkspace,
2774 loop_ws: &mut PirlsLoopWorkspace,
2775 penalized_hessian: Array2<f64>,
2776 logdet: f64,
2777 deviance: f64,
2778 iterations: usize,
2779 converged: bool,
2780 step_lm_lambda: f64,
2781 objective_ridge: f64,
2782 extra: Option<&PirlsLoopExtra<'_>>,
2783 diagnostics: LoopDiagnostics,
2784 status_or_func: &cudarc::driver::CudaFunction,
2785 ) -> Result<PirlsLoopOutcome, String> {
2786 let beta = download_vec(&ws.stream, &loop_ws.beta_dev)?;
2787 let final_eta = download_vec(&ws.stream, &loop_ws.eta_dev)?;
2788 let final_mu = download_vec(&ws.stream, &loop_ws.row_final.mu)?;
2789 let final_grad_eta = download_vec(&ws.stream, &loop_ws.row_final.grad_eta)?;
2790 let final_w_hessian = download_vec(&ws.stream, &loop_ws.row_final.w_hessian)?;
2791 let final_w_solver = download_vec(&ws.stream, &loop_ws.row_final.w_solver)?;
2792
2793 let n_rows = loop_ws.n;
2798 let final_row_status = reduce_status_or(
2799 &ws.stream,
2800 status_or_func,
2801 &loop_ws.row_final.status,
2802 n_rows,
2803 &mut loop_ws.status_u32_dev,
2804 "final_row_status",
2805 )?;
2806 const FORBIDDEN_FINAL: u32 = crate::gpu_kernels::pirls_row::status_flags::INVALID_RESPONSE
2807 | crate::gpu_kernels::pirls_row::status_flags::ZERO_PRIOR_WEIGHT;
2808
2809 let eta_finite = final_eta.iter().all(|v| v.is_finite());
2815 let mu_finite = final_mu.iter().all(|v| v.is_finite());
2816 let beta_finite = beta.iter().all(|v| v.is_finite());
2817 let stability_ok =
2818 eta_finite && mu_finite && beta_finite && (final_row_status & FORBIDDEN_FINAL) == 0;
2819 let status = if !stability_ok {
2820 crate::pirls::PirlsStatus::Unstable
2821 } else if converged {
2822 crate::pirls::PirlsStatus::Converged
2823 } else if diagnostics.step_search_exhausted {
2824 crate::pirls::PirlsStatus::LmStepSearchExhausted
2832 } else {
2833 crate::pirls::PirlsStatus::MaxIterationsReached
2834 };
2835
2836 let default_ridge = gam_problem::RidgePassport::scaled_identity(
2839 objective_ridge,
2840 gam_linalg::RidgePolicy::explicit_stabilization_full(),
2841 );
2842
2843 let max_abs_eta = final_eta.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2844
2845 match extra {
2846 Some(ext) => {
2847 let (score_c, score_d, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
2850 crate::pirls::computeworkingweight_derivatives_from_eta(
2851 ext.likelihood,
2852 ext.inverse_link,
2853 &final_eta,
2854 ext.priorweights,
2855 )
2856 .map_err(|e| format!("pirls postpass dmu/deta: {e:?}"))?;
2857
2858 let (finalweights, solve_c_array, solve_d_array) = match ext.exported_curvature {
2859 crate::pirls::HessianCurvatureKind::Observed => {
2860 crate::pirls::compute_observed_hessian_curvature_arrays(
2861 ext.likelihood,
2862 ext.inverse_link,
2863 &final_eta,
2864 ext.y,
2865 &final_w_solver,
2866 ext.priorweights,
2867 )
2868 .map_err(|e| format!("pirls postpass observed curvature: {e:?}"))?
2869 }
2870 crate::pirls::HessianCurvatureKind::Fisher => {
2871 (final_w_solver.clone(), score_c.clone(), score_d.clone())
2872 }
2873 };
2874
2875 let beta_transformed = beta.clone();
2881
2882 let constraint_kkt = ext.linear_constraints.and_then(|lin| {
2883 if lin.a.nrows() == 0 {
2884 return None;
2885 }
2886 let grad = penalized_hessian.dot(&beta);
2892 Some(
2893 crate::active_set::compute_constraint_kkt_diagnostics(
2894 &beta, &grad, lin,
2895 ),
2896 )
2897 });
2898
2899 let ridge_passport = ext.ridge_passport.unwrap_or(default_ridge);
2900 let firth = ext
2901 .firth
2902 .clone()
2903 .unwrap_or(crate::pirls::FirthDiagnostics::Inactive);
2904 let edf = ext.edf.unwrap_or(f64::NAN);
2905 let derivatives_unsupported = false;
2911
2912 Ok(PirlsLoopOutcome {
2913 beta,
2914 penalized_hessian,
2915 logdet,
2916 deviance,
2917 iterations,
2918 converged,
2919 final_eta,
2920 final_mu,
2921 final_grad_eta,
2922 final_w_hessian,
2923 final_w_solver: final_w_solver.clone(),
2924 final_offset: ext.offset.to_owned(),
2925 beta_transformed,
2926 finalweights,
2927 solveweights: final_w_solver,
2928 solve_dmu_deta,
2929 solve_d2mu_deta2,
2930 solve_d3mu_deta3,
2931 solve_c_array,
2932 solve_d_array,
2933 derivatives_unsupported,
2934 status,
2935 ridge_passport,
2936 firth,
2937 constraint_kkt,
2938 edf,
2939 last_deviance_change: diagnostics.last_deviance_change,
2940 last_step_halving: diagnostics.last_step_halving,
2941 last_step_size: diagnostics.last_step_size,
2942 final_lm_lambda: step_lm_lambda,
2943 min_deviance: diagnostics.min_deviance,
2944 max_abs_eta,
2945 per_row_status_or: final_row_status,
2946 })
2947 }
2948 None => {
2949 Ok(PirlsLoopOutcome {
2957 beta: beta.clone(),
2958 penalized_hessian,
2959 logdet,
2960 deviance,
2961 iterations,
2962 converged,
2963 final_eta,
2964 final_mu,
2965 final_grad_eta,
2966 final_w_hessian,
2967 final_w_solver: final_w_solver.clone(),
2968 final_offset: Array1::<f64>::zeros(0),
2969 beta_transformed: beta,
2970 finalweights: Array1::<f64>::zeros(0),
2971 solveweights: final_w_solver,
2972 solve_dmu_deta: Array1::<f64>::zeros(0),
2973 solve_d2mu_deta2: Array1::<f64>::zeros(0),
2974 solve_d3mu_deta3: Array1::<f64>::zeros(0),
2975 solve_c_array: Array1::<f64>::zeros(0),
2976 solve_d_array: Array1::<f64>::zeros(0),
2977 derivatives_unsupported: true,
2978 status,
2979 ridge_passport: default_ridge,
2980 firth: crate::pirls::FirthDiagnostics::Inactive,
2981 constraint_kkt: None,
2982 edf: f64::NAN,
2983 last_deviance_change: diagnostics.last_deviance_change,
2984 last_step_halving: diagnostics.last_step_halving,
2985 last_step_size: diagnostics.last_step_size,
2986 final_lm_lambda: step_lm_lambda,
2987 min_deviance: diagnostics.min_deviance,
2988 max_abs_eta,
2989 per_row_status_or: final_row_status,
2990 })
2991 }
2992 }
2993 }
2994
2995 fn gemv_no_trans(
2996 blas: &CudaBlas,
2997 n: usize,
2998 p: usize,
2999 a_dev: &CudaSlice<f64>,
3000 x_dev: &CudaSlice<f64>,
3001 y_dev: &mut CudaSlice<f64>,
3002 ) -> Result<(), String> {
3003 let n_i = to_i32(n)?;
3004 let p_i = to_i32(p)?;
3005 let cfg = GemvConfig::<f64> {
3006 trans: cublasOperation_t::CUBLAS_OP_N,
3007 m: n_i,
3008 n: p_i,
3009 alpha: 1.0,
3010 lda: n_i,
3011 incx: 1,
3012 beta: 0.0,
3013 incy: 1,
3014 };
3015 unsafe { blas.gemv(cfg, a_dev, x_dev, y_dev) }.map_err(|e| format!("dgemv no-trans: {e}"))
3017 }
3018
3019 fn axpy(
3020 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3021 func: &cudarc::driver::CudaFunction,
3022 alpha: f64,
3023 x_dev: &CudaSlice<f64>,
3024 y_dev: &mut CudaSlice<f64>,
3025 n: usize,
3026 ) -> Result<(), String> {
3027 const THREADS: u32 = 256;
3028 let n_i = to_i32(n)?;
3029 let n_u = u32::try_from(n).map_err(|_| format!("axpy n={n} > u32"))?;
3030 let grid = n_u.div_ceil(THREADS).max(1);
3031 let cfg = LaunchConfig {
3032 grid_dim: (grid, 1, 1),
3033 block_dim: (THREADS, 1, 1),
3034 shared_mem_bytes: 0,
3035 };
3036 let mut builder = stream.launch_builder(func);
3037 builder.arg(&alpha);
3038 builder.arg(x_dev);
3039 builder.arg(y_dev);
3040 builder.arg(&n_i);
3041 unsafe { builder.launch(cfg) }
3044 .map(|_event_pair| ())
3045 .map_err(|e| format!("axpy launch: {e}"))
3046 }
3047
3048 fn reduce_scalar(
3049 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3050 func: &cudarc::driver::CudaFunction,
3051 src: &CudaSlice<f64>,
3052 len: usize,
3053 scalar_dev: &mut CudaSlice<f64>,
3054 label: &'static str,
3055 ) -> Result<f64, String> {
3056 const THREADS: u32 = 1024;
3057 let len_i = to_i32(len)?;
3058 let cfg = LaunchConfig {
3059 grid_dim: (1, 1, 1),
3060 block_dim: (THREADS, 1, 1),
3061 shared_mem_bytes: 0,
3062 };
3063 let mut builder = stream.launch_builder(func);
3064 builder.arg(src);
3065 builder.arg(&len_i);
3066 builder.arg(&mut *scalar_dev);
3067 unsafe { builder.launch(cfg) }.map_err(|e| format!("{label} reduce launch: {e}"))?;
3071 let host = stream
3072 .clone_dtoh(scalar_dev)
3073 .map_err(|e| format!("download {label}: {e}"))?;
3074 Ok(host[0])
3075 }
3076
3077 fn reduce_status_or(
3081 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3082 func: &cudarc::driver::CudaFunction,
3083 src: &CudaSlice<u32>,
3084 len: usize,
3085 status_dev: &mut CudaSlice<u32>,
3086 label: &'static str,
3087 ) -> Result<u32, String> {
3088 const THREADS: u32 = 1024;
3089 let len_i = to_i32(len)?;
3090 let cfg = LaunchConfig {
3091 grid_dim: (1, 1, 1),
3092 block_dim: (THREADS, 1, 1),
3093 shared_mem_bytes: 0,
3094 };
3095 let mut builder = stream.launch_builder(func);
3096 builder.arg(src);
3097 builder.arg(&len_i);
3098 builder.arg(&mut *status_dev);
3099 unsafe { builder.launch(cfg) }.map_err(|e| format!("{label} or reduce launch: {e}"))?;
3102 let host = stream
3103 .clone_dtoh(status_dev)
3104 .map_err(|e| format!("download {label}: {e}"))?;
3105 Ok(host[0])
3106 }
3107
3108 fn download_vec(
3109 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3110 dev: &CudaSlice<f64>,
3111 ) -> Result<Array1<f64>, String> {
3112 let host = stream
3113 .clone_dtoh(dev)
3114 .map_err(|e| format!("download vec: {e}"))?;
3115 Ok(Array1::from_vec(host))
3116 }
3117
3118 pub struct GaussianPlsResult {
3120 pub beta: Array1<f64>,
3121 pub penalized_hessian: Array2<f64>,
3122 pub logdet: f64,
3123 }
3124
3125 pub fn solve_gaussian_pls_on_stream(
3128 a_orig: ArrayView2<'_, f64>,
3129 b_orig: ArrayView1<'_, f64>,
3130 s_transformed: ArrayView2<'_, f64>,
3131 linear_shift: ArrayView1<'_, f64>,
3132 prior_mean_target: ArrayView1<'_, f64>,
3133 ridge: f64,
3134 qs: Option<ArrayView2<'_, f64>>,
3135 ) -> Result<GaussianPlsResult, String> {
3136 let p = b_orig.len();
3137 if a_orig.dim() != (p, p) {
3138 return Err(format!("A shape {:?} != ({p},{p})", a_orig.dim()));
3139 }
3140 if s_transformed.dim() != (p, p) {
3141 return Err(format!("S shape {:?} != ({p},{p})", s_transformed.dim()));
3142 }
3143 if linear_shift.len() != p {
3144 return Err(format!("linear_shift len {} != p={p}", linear_shift.len()));
3145 }
3146 if prior_mean_target.len() != p {
3147 return Err(format!(
3148 "prior_mean_target len {} != p={p}",
3149 prior_mean_target.len()
3150 ));
3151 }
3152 if let Some(qs_v) = qs {
3153 if qs_v.dim() != (p, p) {
3154 return Err(format!("qs shape {:?} != ({p},{p})", qs_v.dim()));
3155 }
3156 }
3157 let (h_rotated, rhs_base) = if let Some(qs_v) = qs {
3158 let qs_owned = qs_v.to_owned();
3159 let tmp = a_orig.dot(&qs_owned);
3160 let h = qs_owned.t().dot(&tmp);
3161 let rb = qs_owned.t().dot(&b_orig);
3162 (h, rb)
3163 } else {
3164 (a_orig.to_owned(), b_orig.to_owned())
3165 };
3166 let penalized_hessian: Array2<f64> = &h_rotated + &s_transformed;
3167 let mut regularized = penalized_hessian.clone();
3168 if ridge > 0.0 {
3169 for i in 0..p {
3170 regularized[[i, i]] += ridge;
3171 }
3172 }
3173 let mut rhs_host = rhs_base;
3174 rhs_host += &linear_shift;
3175 if ridge > 0.0 {
3176 rhs_host.scaled_add(ridge, &prior_mean_target);
3177 }
3178 let (ctx, stream) = context_and_stream()?;
3179 let solver = DnHandle::new(stream.clone())
3180 .map_err(|e| format!("cusolver init (gaussian pls): {e}"))?;
3181 let pp = p.checked_mul(p).ok_or("p*p overflow (gaussian pls)")?;
3182 let mut h_dev = stream
3183 .alloc_zeros::<f64>(pp)
3184 .map_err(|e| format!("alloc H (gaussian pls): {e}"))?;
3185 let mut rhs_dev = stream
3186 .alloc_zeros::<f64>(p)
3187 .map_err(|e| format!("alloc rhs (gaussian pls): {e}"))?;
3188 let potrf_lwork_usize = potrf_query_lwork(&solver, &stream, p)?;
3189 let potrf_lwork = i32::try_from(potrf_lwork_usize)
3190 .map_err(|_| "potrf lwork overflow (gaussian pls)".to_string())?;
3191 let mut potrf_work_dev = stream
3192 .alloc_zeros::<f64>(potrf_lwork_usize.max(1))
3193 .map_err(|e| format!("alloc potrf workspace (gaussian pls): {e}"))?;
3194 let mut potrf_info_dev = stream
3195 .alloc_zeros::<i32>(1)
3196 .map_err(|e| format!("alloc potrf info (gaussian pls): {e}"))?;
3197 let mut potrs_info_dev = stream
3198 .alloc_zeros::<i32>(1)
3199 .map_err(|e| format!("alloc potrs info (gaussian pls): {e}"))?;
3200 let reg_col = to_col_major(®ularized);
3201 stream
3202 .memcpy_htod(reg_col.as_ref(), &mut h_dev)
3203 .map_err(|e| format!("upload H (gaussian pls): {e}"))?;
3204 let rhs_slice = rhs_host
3205 .as_slice()
3206 .ok_or("rhs_host not contiguous (gaussian pls)")?;
3207 stream
3208 .memcpy_htod(rhs_slice, &mut rhs_dev)
3209 .map_err(|e| format!("upload rhs (gaussian pls): {e}"))?;
3210 potrf_in_place_reuse(
3211 &solver,
3212 &stream,
3213 p,
3214 potrf_lwork,
3215 &mut h_dev,
3216 &mut potrf_work_dev,
3217 &mut potrf_info_dev,
3218 )?;
3219 potrs_in_place_reuse(
3220 &solver,
3221 &stream,
3222 p,
3223 1,
3224 &h_dev,
3225 &mut rhs_dev,
3226 &mut potrs_info_dev,
3227 )?;
3228 let logdet = cholesky_logdet_device(&stream, &ctx, p, &h_dev)?;
3229 let beta_raw = stream
3230 .clone_dtoh(&rhs_dev)
3231 .map_err(|e| format!("download beta (gaussian pls): {e}"))?;
3232 check_deferred_potrf_info(&stream, &potrf_info_dev)?;
3233 check_deferred_potrs_info(&stream, &potrs_info_dev)?;
3234 Ok(GaussianPlsResult {
3235 beta: Array1::from_vec(beta_raw),
3236 penalized_hessian,
3237 logdet,
3238 })
3239 }
3240}
3241
3242pub fn weighted_crossprod_gpu(
3243 x: ArrayView2<'_, f64>,
3244 weights: ArrayView1<'_, f64>,
3245) -> Result<Array2<f64>, String> {
3246 #[cfg(not(target_os = "linux"))]
3247 {
3248 return cpu_fallback::weighted_crossprod_cpu(x, weights);
3249 }
3250
3251 #[cfg(target_os = "linux")]
3252 {
3253 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3254 return cpu_fallback::weighted_crossprod_cpu(x, weights);
3255 }
3256 cuda::weighted_crossprod(x, weights)
3257 }
3258}
3259
3260pub fn solve_pirls_step_gpu(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
3261 #[cfg(not(target_os = "linux"))]
3262 {
3263 return cpu_fallback::solve_step_cpu(input);
3264 }
3265
3266 #[cfg(target_os = "linux")]
3267 {
3268 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3269 return cpu_fallback::solve_step_cpu(input);
3270 }
3271 cuda::solve_step(input)
3272 }
3273}
3274
3275#[cfg(target_os = "linux")]
3281pub fn upload_shared_pirls_gpu(
3282 x: ndarray::ArrayView2<'_, f64>,
3283 y: ndarray::ArrayView1<'_, f64>,
3284 prior_w: ndarray::ArrayView1<'_, f64>,
3285 offset: ndarray::ArrayView1<'_, f64>,
3286) -> Result<PirlsGpuSharedData, String> {
3287 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3288 return Err("cuda runtime unavailable; cannot upload shared GPU PIRLS data".to_string());
3289 }
3290 PirlsGpuSharedData::upload_impl(x, y, prior_w, offset)
3291}
3292
3293#[cfg(target_os = "linux")]
3297pub fn allocate_sigma_pirls_workspace(
3298 shared: &PirlsGpuSharedData,
3299) -> Result<SigmaPirlsGpuWorkspace, String> {
3300 SigmaPirlsGpuWorkspace::allocate_impl(shared)
3301}
3302
3303#[cfg(target_os = "linux")]
3308pub fn upload_qs_pirls(
3309 ws: &mut SigmaPirlsGpuWorkspace,
3310 qs: ndarray::ArrayView2<'_, f64>,
3311) -> Result<(), String> {
3312 cuda::upload_qs(ws, qs)
3313}
3314
3315#[cfg(target_os = "linux")]
3318pub fn upload_qs_identity_pirls(ws: &mut SigmaPirlsGpuWorkspace) -> Result<(), String> {
3319 cuda::upload_qs_identity(ws)
3320}
3321
3322#[cfg(target_os = "linux")]
3328pub fn solve_pirls_step_on_stream(
3329 shared: &PirlsGpuSharedData,
3330 ws: &mut SigmaPirlsGpuWorkspace,
3331 input: PirlsStepStreamInput<'_>,
3332) -> Result<PirlsGpuStep, String> {
3333 cuda::solve_step_on_stream(shared, ws, input)
3334}
3335
3336#[cfg(target_os = "linux")]
3344pub fn solve_pirls_step_on_stream_device(
3345 shared: &PirlsGpuSharedData,
3346 ws: &mut SigmaPirlsGpuWorkspace,
3347 input: PirlsStepStreamDeviceInput<'_, '_>,
3348) -> Result<PirlsGpuStep, String> {
3349 cuda::solve_step_on_stream_device(shared, ws, input)
3350}
3351
3352#[cfg(target_os = "linux")]
3362pub fn pirls_loop_on_stream(
3363 shared: &PirlsGpuSharedData,
3364 ws: &mut SigmaPirlsGpuWorkspace,
3365 loop_ws: &mut cuda::PirlsLoopWorkspace,
3366 family: crate::gpu_kernels::pirls_row::PirlsRowFamily,
3367 curvature: crate::gpu_kernels::pirls_row::CurvatureMode,
3368 gamma_shape: f64,
3370 beta0: ndarray::ArrayView1<'_, f64>,
3371 penalty_hessian: ndarray::ArrayView2<'_, f64>,
3372 linear_shift: ndarray::ArrayView1<'_, f64>,
3375 constant_shift: f64,
3377 step_lm_lambda: f64,
3378 objective_ridge: f64,
3379 max_iter: usize,
3380 tol: f64,
3381 extra: Option<&cuda::PirlsLoopExtra<'_>>,
3382) -> Result<cuda::PirlsLoopOutcome, String> {
3383 cuda::pirls_loop(
3384 shared,
3385 ws,
3386 loop_ws,
3387 family,
3388 curvature,
3389 gamma_shape,
3390 beta0,
3391 penalty_hessian,
3392 linear_shift,
3393 constant_shift,
3394 step_lm_lambda,
3395 objective_ridge,
3396 max_iter,
3397 tol,
3398 extra,
3399 )
3400}
3401
3402#[cfg(target_os = "linux")]
3405pub fn allocate_pirls_loop_workspace(
3406 shared: &PirlsGpuSharedData,
3407 ws: &SigmaPirlsGpuWorkspace,
3408) -> Result<cuda::PirlsLoopWorkspace, String> {
3409 cuda::PirlsLoopWorkspace::allocate(shared, &ws.stream)
3410}
3411
3412#[cfg(target_os = "linux")]
3418pub fn solve_gaussian_pls_gpu(
3419 a_orig: ndarray::ArrayView2<'_, f64>,
3420 b_orig: ndarray::ArrayView1<'_, f64>,
3421 s_transformed: ndarray::ArrayView2<'_, f64>,
3422 linear_shift: ndarray::ArrayView1<'_, f64>,
3423 prior_mean_target: ndarray::ArrayView1<'_, f64>,
3424 ridge: f64,
3425 qs: Option<ndarray::ArrayView2<'_, f64>>,
3426) -> Result<cuda::GaussianPlsResult, String> {
3427 cuda::solve_gaussian_pls_on_stream(
3428 a_orig,
3429 b_orig,
3430 s_transformed,
3431 linear_shift,
3432 prior_mean_target,
3433 ridge,
3434 qs,
3435 )
3436}
3437
3438mod cpu_fallback {
3446 use super::{PirlsGpuInput, PirlsGpuStep};
3447 use gam_linalg::faer_ndarray::FaerCholesky;
3448 use crate::estimate::reml::assembly::xt_diag_x_dense_into;
3449 use faer::Side;
3450 use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
3451
3452 pub(super) fn weighted_crossprod_cpu(
3453 x: ArrayView2<'_, f64>,
3454 weights: ArrayView1<'_, f64>,
3455 ) -> Result<Array2<f64>, String> {
3456 validate(x, weights)?;
3457 let x_owned = x.to_owned();
3458 let w_owned = weights.to_owned();
3459 let mut scratch = Array2::<f64>::zeros(x_owned.dim());
3460 Ok(xt_diag_x_dense_into(&x_owned, &w_owned, &mut scratch))
3461 }
3462
3463 pub(super) fn solve_step_cpu(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
3464 validate(input.x, input.weights)?;
3465 let (_n, p) = input.x.dim();
3466 if input.penalty_hessian.dim() != (p, p) {
3467 return Err(format!(
3468 "penalty Hessian shape {:?} does not match p={p}",
3469 input.penalty_hessian.dim()
3470 ));
3471 }
3472 if input.gradient.len() != p {
3473 return Err(format!(
3474 "gradient length {} does not match p={p}",
3475 input.gradient.len()
3476 ));
3477 }
3478 let xtwx = weighted_crossprod_cpu(input.x, input.weights)?;
3479 let mut penalized_hessian = xtwx.clone();
3481 penalized_hessian += &input.penalty_hessian;
3482 if input.objective_ridge != 0.0 {
3483 for i in 0..p {
3484 penalized_hessian[[i, i]] += input.objective_ridge;
3485 }
3486 }
3487 let mut h_step = xtwx;
3489 h_step += &input.penalty_hessian;
3490 if input.step_lm_lambda != 0.0 {
3491 for i in 0..p {
3492 h_step[[i, i]] += input.step_lm_lambda;
3493 }
3494 }
3495 let factor = h_step
3496 .cholesky(Side::Lower)
3497 .map_err(|e| format!("CPU Cholesky failed in PIRLS fallback: {e:?}"))?;
3498 let g = Array1::from_iter(input.gradient.iter().copied());
3499 let direction = factor.solvevec(&g);
3502 let logdet = 2.0 * factor.diag().iter().map(|v| v.ln()).sum::<f64>();
3504 Ok(PirlsGpuStep {
3505 penalized_hessian,
3506 direction,
3507 logdet,
3508 })
3509 }
3510
3511 fn validate(x: ArrayView2<'_, f64>, weights: ArrayView1<'_, f64>) -> Result<(), String> {
3512 let (n, p) = x.dim();
3513 if weights.len() != n {
3514 return Err(format!(
3515 "weights length {} does not match rows {n}",
3516 weights.len()
3517 ));
3518 }
3519 if n == 0 || p == 0 {
3520 return Err("empty design cannot be solved".to_string());
3521 }
3522 Ok(())
3523 }
3524}
3525
3526pub fn cholesky_solve_gpu(
3527 hessian: ArrayView2<'_, f64>,
3528 rhs: ArrayView2<'_, f64>,
3529) -> Result<(Array2<f64>, f64), String> {
3530 gam_gpu::solver::cholesky_solve_gpu(hessian, rhs)
3531}
3532
3533pub fn cholesky_solve_only_gpu(
3537 hessian: ArrayView2<'_, f64>,
3538 rhs: ArrayView2<'_, f64>,
3539) -> Result<Array2<f64>, String> {
3540 gam_gpu::solver::cholesky_solve_only_gpu(hessian, rhs)
3541}
3542
3543pub fn cholesky_lower_gpu(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
3544 gam_gpu::solver::cholesky_lower_gpu(hessian)
3545}
3546
3547#[cfg(all(test, target_os = "linux"))]
3553mod stream_device_parity_tests {
3554 use super::*;
3555 use ndarray::arr2;
3556
3557 #[test]
3558 fn device_input_step_matches_host_input_step_on_v100() {
3559 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3560 eprintln!("[stream_device_parity] no CUDA runtime — skipping");
3561 return;
3562 }
3563 let x = arr2(&[
3564 [1.0, 0.5, 0.1],
3565 [0.2, -0.3, 1.4],
3566 [0.7, 1.1, -0.2],
3567 [-0.4, 0.9, 0.6],
3568 [0.3, -0.8, 0.5],
3569 ]);
3570 let weights = ndarray::arr1(&[1.0, 0.8, 1.2, 0.9, 1.05]);
3571 let g_eta = ndarray::arr1(&[0.10_f64, -0.20, 0.05, 0.30, -0.15]);
3575 let gradient: ndarray::Array1<f64> = x.t().dot(&g_eta);
3576 let penalty = arr2(&[[0.4, 0.0, 0.0], [0.0, 0.9, 0.0], [0.0, 0.0, 1.2]]);
3577 let lm_ridge = 0.1;
3578
3579 let n = x.nrows();
3580 let y_dummy = ndarray::Array1::<f64>::zeros(n);
3581 let prior_w_dummy = ndarray::Array1::<f64>::ones(n);
3582 let offset_dummy = ndarray::Array1::<f64>::zeros(n);
3583 let shared = upload_shared_pirls_gpu(
3584 x.view(),
3585 y_dummy.view(),
3586 prior_w_dummy.view(),
3587 offset_dummy.view(),
3588 )
3589 .expect("upload shared design");
3590 let mut ws_host = allocate_sigma_pirls_workspace(&shared).expect("alloc host-input ws");
3591 let mut ws_dev = allocate_sigma_pirls_workspace(&shared).expect("alloc device-input ws");
3592
3593 let host_step = solve_pirls_step_on_stream(
3594 &shared,
3595 &mut ws_host,
3596 PirlsStepStreamInput {
3597 weights: weights.view(),
3598 penalty_hessian: penalty.view(),
3599 gradient: gradient.view(),
3600 step_lm_lambda: lm_ridge,
3601 objective_ridge: 0.0,
3602 },
3603 )
3604 .expect("host-input step");
3605
3606 let mut w_dev = ws_dev.stream.alloc_zeros::<f64>(n).expect("alloc w_dev");
3607 let mut g_dev = ws_dev.stream.alloc_zeros::<f64>(n).expect("alloc g_dev");
3608 ws_dev
3609 .stream
3610 .memcpy_htod(weights.as_slice().unwrap(), &mut w_dev)
3611 .expect("upload w_dev");
3612 ws_dev
3613 .stream
3614 .memcpy_htod(g_eta.as_slice().unwrap(), &mut g_dev)
3615 .expect("upload g_dev");
3616
3617 let beta_dev_test = ws_dev
3618 .stream
3619 .alloc_zeros::<f64>(x.ncols())
3620 .expect("alloc beta_dev_test");
3621 let linear_shift_test = ndarray::Array1::<f64>::zeros(x.ncols());
3622 let dev_step = solve_pirls_step_on_stream_device(
3623 &shared,
3624 &mut ws_dev,
3625 PirlsStepStreamDeviceInput {
3626 w_solver_dev: &w_dev,
3627 grad_eta_dev: &g_dev,
3628 penalty_hessian: penalty.view(),
3629 step_lm_lambda: lm_ridge,
3630 objective_ridge: 0.0,
3631 beta_dev: &beta_dev_test,
3632 linear_shift: linear_shift_test.view(),
3633 },
3634 )
3635 .expect("device-input step");
3636
3637 for i in 0..3 {
3640 for j in 0..3 {
3641 let diff = (host_step.penalized_hessian[[i, j]]
3642 - dev_step.penalized_hessian[[i, j]])
3643 .abs();
3644 assert!(diff <= 1e-10, "H[{i},{j}] mismatch: {diff}");
3645 }
3646 }
3647 assert!(
3648 (host_step.logdet - dev_step.logdet).abs() <= 1e-9,
3649 "logdet mismatch: host={} dev={}",
3650 host_step.logdet,
3651 dev_step.logdet
3652 );
3653 for i in 0..3 {
3656 let diff = (host_step.direction[i] - dev_step.direction[i]).abs();
3657 assert!(diff <= 1e-9, "direction[{i}] mismatch: {diff}");
3658 }
3659 }
3660
3661 #[test]
3670 fn hill_climb_loop_beats_cpu_10x_on_large_scale_logit() {
3671 use crate::gpu_kernels::pirls_row::{
3672 CurvatureMode, PirlsRowFamily, RowInput, row_reweight_cpu,
3673 };
3674 use std::time::Instant;
3675 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3676 eprintln!("[hill_climb] no CUDA runtime — skipping");
3677 return;
3678 }
3679 let n = 80_000_usize;
3680 let p = 44_usize;
3681 let beta_true: ndarray::Array1<f64> = ndarray::Array1::from_iter(
3683 (0..p).map(|j| 0.05 * ((j as f64) - 0.5 * p as f64) / p as f64),
3684 );
3685 let mut x = ndarray::Array2::<f64>::zeros((n, p));
3686 for i in 0..n {
3687 for j in 0..p {
3688 x[[i, j]] = ((i as f64 + j as f64 * 17.0) * 0.001).sin();
3689 }
3690 }
3691 let eta: ndarray::Array1<f64> = x.dot(&beta_true);
3692 let y: ndarray::Array1<f64> = eta
3693 .iter()
3694 .enumerate()
3695 .map(|(i, &e)| {
3696 let mu = 0.5 * (1.0 + (0.5 * e).tanh());
3697 if (i as f64 * 1.31).fract() < mu {
3698 1.0
3699 } else {
3700 0.0
3701 }
3702 })
3703 .collect();
3704 let prior_w = ndarray::Array1::<f64>::ones(n);
3705 let penalty = ndarray::Array2::<f64>::eye(p) * 1e-3;
3706 let beta0 = ndarray::Array1::<f64>::zeros(p);
3707
3708 let offset_bench = ndarray::Array1::<f64>::zeros(n);
3710 let shared =
3711 upload_shared_pirls_gpu(x.view(), y.view(), prior_w.view(), offset_bench.view())
3712 .expect("upload shared design");
3713 let mut ws = allocate_sigma_pirls_workspace(&shared).expect("alloc ws");
3714 let mut loop_ws = allocate_pirls_loop_workspace(&shared, &ws).expect("alloc loop_ws");
3715 let t0 = Instant::now();
3716 let linear_shift_zero = ndarray::Array1::<f64>::zeros(p);
3720 drop(
3721 pirls_loop_on_stream(
3722 &shared,
3723 &mut ws,
3724 &mut loop_ws,
3725 PirlsRowFamily::BernoulliLogit,
3726 CurvatureMode::Fisher,
3727 1.0,
3728 beta0.view(),
3729 penalty.view(),
3730 linear_shift_zero.view(),
3731 0.0,
3732 0.0,
3733 0.0,
3734 30,
3735 1e-6,
3736 None,
3737 )
3738 .expect("pirls loop"),
3739 );
3740 let gpu_secs = t0.elapsed().as_secs_f64();
3741
3742 let t1 = Instant::now();
3745 let mut beta = ndarray::Array1::<f64>::zeros(p);
3746 for _ in 0..30 {
3747 let eta: ndarray::Array1<f64> = x.dot(&beta);
3748 let mut w = ndarray::Array1::<f64>::zeros(n);
3749 let mut g = ndarray::Array1::<f64>::zeros(n);
3750 for i in 0..n {
3751 let out = row_reweight_cpu(
3752 PirlsRowFamily::BernoulliLogit,
3753 CurvatureMode::Fisher,
3754 RowInput {
3755 eta: eta[i],
3756 y: y[i],
3757 prior_weight: prior_w[i],
3758 },
3759 1.0,
3760 );
3761 w[i] = out.w_solver;
3762 g[i] = out.grad_eta;
3763 }
3764 let mut wx_full = x.clone();
3765 for j in 0..p {
3766 for i in 0..n {
3767 wx_full[[i, j]] *= w[i];
3768 }
3769 }
3770 let h = x.t().dot(&wx_full) + &penalty;
3771 let rhs = x.t().dot(&g);
3772 use gam_linalg::faer_ndarray::FaerCholesky;
3773 let chol = h
3774 .cholesky(faer::Side::Lower)
3775 .expect("CPU PIRLS reference Cholesky");
3776 let d = chol.solvevec(&rhs);
3777 for i in 0..p {
3778 beta[i] -= d[i];
3779 }
3780 }
3781 let cpu_secs = t1.elapsed().as_secs_f64();
3782
3783 let speedup = cpu_secs / gpu_secs;
3784 eprintln!(
3785 "[hill_climb] n={n} p={p} BernoulliLogit/Fisher: gpu={:.3}s cpu={:.3}s speedup={:.2}×",
3786 gpu_secs, cpu_secs, speedup
3787 );
3788 assert!(
3789 speedup >= 10.0,
3790 "GPU PIRLS loop must be ≥10× CPU at large-scale shape; got speedup={speedup:.2}× (gpu={gpu_secs:.3}s cpu={cpu_secs:.3}s)"
3791 );
3792 }
3793
3794 #[test]
3799 fn pirls_loop_converges_to_ols_solution_on_gaussian_identity() {
3800 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3801 eprintln!("[stage_3_3] no CUDA runtime — skipping");
3802 return;
3803 }
3804 let x = arr2(&[
3805 [1.0, 0.5, 0.1],
3806 [0.2, -0.3, 1.4],
3807 [0.7, 1.1, -0.2],
3808 [-0.4, 0.9, 0.6],
3809 [0.3, -0.8, 0.5],
3810 [1.1, 0.2, -0.4],
3811 [-0.6, 0.4, 0.3],
3812 [0.8, -1.0, 0.7],
3813 ]);
3814 let n = x.nrows();
3815 let p = x.ncols();
3816 let beta_true = ndarray::arr1(&[0.5_f64, -1.2, 0.3]);
3818 let y: ndarray::Array1<f64> = x.dot(&beta_true);
3819 let prior_w = ndarray::Array1::<f64>::ones(n);
3820 let penalty = ndarray::Array2::<f64>::eye(p) * 1e-4; let beta0 = ndarray::Array1::<f64>::zeros(p);
3822
3823 let offset_ols = ndarray::Array1::<f64>::zeros(n);
3824 let shared = upload_shared_pirls_gpu(x.view(), y.view(), prior_w.view(), offset_ols.view())
3825 .expect("upload shared design");
3826 let mut ws = allocate_sigma_pirls_workspace(&shared).expect("alloc ws");
3827 let mut loop_ws = allocate_pirls_loop_workspace(&shared, &ws).expect("alloc loop_ws");
3828
3829 let linear_shift_zero = ndarray::Array1::<f64>::zeros(p);
3833 let outcome = pirls_loop_on_stream(
3834 &shared,
3835 &mut ws,
3836 &mut loop_ws,
3837 crate::gpu_kernels::pirls_row::PirlsRowFamily::GaussianIdentity,
3838 crate::gpu_kernels::pirls_row::CurvatureMode::Fisher,
3839 1.0,
3840 beta0.view(),
3841 penalty.view(),
3842 linear_shift_zero.view(),
3843 0.0,
3844 0.0,
3845 0.0,
3846 20,
3847 1e-9,
3848 None,
3849 )
3850 .expect("pirls loop");
3851
3852 let xtx = x.t().dot(&x);
3854 let xty = x.t().dot(&y);
3855 let h_ref = xtx + &penalty;
3856 use gam_linalg::faer_ndarray::FaerCholesky;
3858 let chol = h_ref
3859 .cholesky(faer::Side::Lower)
3860 .expect("OLS reference Cholesky");
3861 let beta_ref: ndarray::Array1<f64> = chol.solvevec(&xty);
3862
3863 assert!(
3868 outcome.converged || outcome.iterations <= 5,
3869 "PIRLS loop did not converge in 20 iters on Gaussian-identity (iters={})",
3870 outcome.iterations
3871 );
3872 for i in 0..p {
3873 let diff = (outcome.beta[i] - beta_ref[i]).abs();
3874 assert!(
3875 diff <= 1e-6,
3876 "β[{i}] mismatch: gpu={} ref={} diff={}",
3877 outcome.beta[i],
3878 beta_ref[i],
3879 diff
3880 );
3881 }
3882 for i in 0..p {
3885 for j in 0..p {
3886 let diff = (outcome.penalized_hessian[[i, j]] - h_ref[[i, j]]).abs();
3887 assert!(diff <= 1e-8, "H[{i},{j}] mismatch: {diff}");
3888 }
3889 }
3890 }
3891}
3892
3893#[cfg(test)]
3902mod weighted_crossprod_cpu_fallback_tests {
3903 use super::weighted_crossprod_gpu;
3904 use ndarray::{Array1, Array2};
3905
3906 #[test]
3907 fn weighted_crossprod_gpu_cpu_fallback_matches_dense_xtwx() {
3908 let x = Array2::<f64>::from_shape_fn((4, 3), |(i, j)| (i + j) as f64 + 1.0);
3911 let w = Array1::<f64>::from_vec(vec![0.5, 1.0, 1.5, 2.0]);
3912
3913 let got = weighted_crossprod_gpu(x.view(), w.view())
3914 .expect("weighted_crossprod_gpu must return Ok via CPU fallback on a CPU-only host");
3915
3916 let (n, p) = x.dim();
3918 let mut expected = Array2::<f64>::zeros((p, p));
3919 for k in 0..n {
3920 for i in 0..p {
3921 for j in 0..p {
3922 expected[[i, j]] += w[k] * x[[k, i]] * x[[k, j]];
3923 }
3924 }
3925 }
3926
3927 assert_eq!(got.dim(), (p, p));
3928 for i in 0..p {
3929 for j in 0..p {
3930 let diff = (got[[i, j]] - expected[[i, j]]).abs();
3931 assert!(diff <= 1e-10, "XtWX[{i},{j}] mismatch: got vs expected diff={diff}");
3932 }
3933 }
3934 }
3935}