1#[cfg(target_os = "linux")]
26pub struct DeviceResidentPcgInput<'a> {
27 pub storage: &'a crate::bms::gpu::row::DeviceResidentRowHess,
30 pub b: &'a [f64],
32 pub rel_tol: f64,
34 pub max_iters: usize,
36 pub precond_diag_floor: f64,
40}
41
#[cfg(target_os = "linux")]
/// Result of [`run_pcg_against_row_hessian_device`].
///
/// The solver may stop at `max_iters` without reaching the requested tolerance and still
/// return `Ok`, so callers should compare `final_rel_residual` with the `rel_tol` they
/// passed.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceResidentPcgOutput {
    /// Solution estimate, length `p_total`.
    pub x: Vec<f64>,
    /// Number of PCG iterations performed (0 when `b` has zero norm).
    pub iterations: usize,
    /// `‖r‖₂ / ‖b‖₂` after the last iteration. `r` is the recurrence-updated residual
    /// (`r -= α·H p`), not a freshly recomputed `b − H x`.
    pub final_rel_residual: f64,
}
53
#[cfg(target_os = "linux")]
/// CUDA C source for the PCG vector kernels, compiled with NVRTC by the PCG backend.
///
/// The element-wise kernels take the vector length `n` first and guard `i < n`.
/// `pcg_dot_single_block` handles any `n` through a strided per-thread loop. It must be
/// launched as a single block whose width is a power of two no larger than 1024,
/// because its shared-memory tree reduction halves `blockDim.x` at each step.
const PCG_KERNEL_SOURCE: &str = r#"
// y[i] += a * x[i]
extern "C" __global__ void pcg_axpy(int n, double a,
                                    const double * __restrict__ x,
                                    double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] += a * x[i];
}

// y[i] = a * x[i] + b * y[i]
extern "C" __global__ void pcg_axpby(int n, double a,
                                     const double * __restrict__ x,
                                     double b,
                                     double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = a * x[i] + b * y[i];
}

// z[i] = r[i] / clamp(diag[i], floor) (sign-preserving floor on |diag|).
extern "C" __global__ void pcg_apply_diag_precond(int n, double floor_val,
                                                  const double * __restrict__ diag,
                                                  const double * __restrict__ r,
                                                  double * __restrict__ z)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        double d = diag[i];
        double ad = d < 0 ? -d : d;
        double clamped = ad > floor_val ? d : (d >= 0.0 ? floor_val : -floor_val);
        z[i] = r[i] / clamped;
    }
}

// Single-block dot product: each thread accumulates a strided partial sum over all
// n elements, then a shared-memory tree reduction writes the total to out[0].
// Launch with gridDim.x == 1 and blockDim.x a power of two <= 1024.
extern "C" __global__ void pcg_dot_single_block(int n,
                                                const double * __restrict__ a,
                                                const double * __restrict__ b,
                                                double * __restrict__ out)
{
    __shared__ double s[1024];
    int tid = threadIdx.x;
    double acc = 0.0;
    for (int i = tid; i < n; i += blockDim.x) acc += a[i] * b[i];
    s[tid] = acc;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) s[tid] += s[tid + stride];
        __syncthreads();
    }
    if (tid == 0) out[0] = s[0];
}

// Set out[i] = 0 for i in [0, n).
extern "C" __global__ void pcg_init_zero(int n, double * __restrict__ out) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 0.0;
}

// Copy y[i] = x[i].
extern "C" __global__ void pcg_copy(int n,
                                    const double * __restrict__ x,
                                    double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i];
}
"#;
127
128#[cfg(target_os = "linux")]
129mod pcg_device {
130 use super::DeviceResidentPcgInput;
131 use super::DeviceResidentPcgOutput;
132 use super::PCG_KERNEL_SOURCE;
133 use crate::bms::gpu::row::launch_bms_flex_row_diagonal;
134 use crate::bms::gpu::row::launch_bms_flex_row_hvp_into_device;
135 use cudarc::driver::{CudaModule, CudaStream, LaunchConfig, PushKernelArg};
136 use std::sync::{Arc, OnceLock};
137
138 struct PcgBackend {
139 stream: Arc<CudaStream>,
140 module: Arc<CudaModule>,
141 }
142
143 impl PcgBackend {
144 fn probe() -> Result<&'static Self, String> {
145 static BACKEND: OnceLock<Result<PcgBackend, String>> = OnceLock::new();
146 BACKEND
147 .get_or_init(|| {
148 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
149 .ok_or_else(|| "pcg backend: no CUDA runtime available".to_string())?;
150 let ctx = gam_gpu::device_runtime::cuda_context_for(
151 runtime.selected_device().ordinal,
152 )
153 .ok_or_else(|| {
154 format!(
155 "pcg backend: failed to create CUDA context for device {}",
156 runtime.selected_device().ordinal
157 )
158 })?;
159 let stream = ctx.default_stream();
160 let ptx = gam_gpu::device_cache::compile_ptx_arch(PCG_KERNEL_SOURCE)
166 .map_err(|err| format!("pcg NVRTC compile failed: {err}"))?;
167 let module = ctx
168 .load_module(ptx)
169 .map_err(|err| format!("pcg module load failed: {err}"))?;
170 Ok(PcgBackend { stream, module })
171 })
172 .as_ref()
173 .map_err(String::clone)
174 }
175 }
176
177 fn launch_blocks(p: usize, threads: u32) -> u32 {
178 ((p as u32) + threads - 1) / threads
179 }
180
181 pub(super) fn run(
186 input: DeviceResidentPcgInput<'_>,
187 ) -> Result<DeviceResidentPcgOutput, String> {
188 let p = input.storage.block.p_total;
189 if input.b.len() != p {
190 return Err(format!(
191 "device-resident pcg: b.len()={} != p_total={p}",
192 input.b.len()
193 ));
194 }
195 if !input.rel_tol.is_finite() || input.rel_tol <= 0.0 {
196 return Err(format!(
197 "device-resident pcg: rel_tol must be positive and finite (got {})",
198 input.rel_tol
199 ));
200 }
201 if input.max_iters == 0 {
202 return Err("device-resident pcg: max_iters must be >= 1".to_string());
203 }
204 if !input.precond_diag_floor.is_finite() || input.precond_diag_floor <= 0.0 {
205 return Err(format!(
206 "device-resident pcg: precond_diag_floor must be positive and finite (got {})",
207 input.precond_diag_floor
208 ));
209 }
210
211 let backend = PcgBackend::probe()?;
212 let stream = backend.stream.clone();
213 let module = backend.module.clone();
214
215 let f_axpy = module
217 .load_function("pcg_axpy")
218 .map_err(|e| format!("pcg load pcg_axpy: {e}"))?;
219 let f_axpby = module
220 .load_function("pcg_axpby")
221 .map_err(|e| format!("pcg load pcg_axpby: {e}"))?;
222 let f_precond = module
223 .load_function("pcg_apply_diag_precond")
224 .map_err(|e| format!("pcg load pcg_apply_diag_precond: {e}"))?;
225 let f_dot = module
226 .load_function("pcg_dot_single_block")
227 .map_err(|e| format!("pcg load pcg_dot_single_block: {e}"))?;
228 let f_copy = module
229 .load_function("pcg_copy")
230 .map_err(|e| format!("pcg load pcg_copy: {e}"))?;
231
232 let mut d_x = stream
234 .alloc_zeros::<f64>(p)
235 .map_err(|e| format!("pcg alloc x: {e}"))?;
236 let mut d_r = stream
237 .clone_htod(input.b)
238 .map_err(|e| format!("pcg upload b -> r: {e}"))?;
239 let mut d_z = stream
240 .alloc_zeros::<f64>(p)
241 .map_err(|e| format!("pcg alloc z: {e}"))?;
242 let mut d_p = stream
243 .alloc_zeros::<f64>(p)
244 .map_err(|e| format!("pcg alloc p: {e}"))?;
245 let mut d_q = stream
246 .alloc_zeros::<f64>(p)
247 .map_err(|e| format!("pcg alloc q: {e}"))?;
248 let mut d_scalar = stream
251 .alloc_zeros::<f64>(1)
252 .map_err(|e| format!("pcg alloc scalar: {e}"))?;
253
254 let diag_host = launch_bms_flex_row_diagonal(input.storage)
258 .map_err(|e| format!("pcg diag fetch: {e}"))?;
259 if diag_host.len() != p {
260 return Err(format!(
261 "pcg: diag length {} != p_total {p}",
262 diag_host.len()
263 ));
264 }
265 let d_diag = stream
266 .clone_htod(&diag_host)
267 .map_err(|e| format!("pcg upload diag: {e}"))?;
268
269 let n_i32 = i32::try_from(p).map_err(|_| format!("pcg: p_total={p} exceeds i32 range"))?;
271 let vec_threads: u32 = 64;
272 let vec_blocks = launch_blocks(p, vec_threads);
273 let dot_threads: u32 = match p {
274 0..=64 => 64,
275 65..=128 => 128,
276 129..=256 => 256,
277 257..=512 => 512,
278 _ => 1024,
279 };
280 if p > 1024 {
281 return Err(format!(
282 "device-resident pcg: p_total={p} exceeds single-block dot capacity (1024); \
283 widen pcg_dot_single_block to multi-block reduce before raising the cap"
284 ));
285 }
286
287 unsafe {
295 stream
296 .launch_builder(&f_dot)
297 .arg(&n_i32)
298 .arg(&d_r)
299 .arg(&d_r)
300 .arg(&mut d_scalar)
301 .launch(LaunchConfig {
302 grid_dim: (1, 1, 1),
303 block_dim: (dot_threads, 1, 1),
304 shared_mem_bytes: 0,
305 })
306 }
307 .map_err(|e| format!("pcg b·b launch: {e}"))?;
308 stream
309 .synchronize()
310 .map_err(|e| format!("pcg b·b sync: {e}"))?;
311 let host_scalar = stream
312 .clone_dtoh(&d_scalar)
313 .map_err(|e| format!("pcg b·b download: {e}"))?;
314 let bb = host_scalar[0];
315 if !bb.is_finite() {
316 return Err(format!("pcg: b·b not finite ({bb})"));
317 }
318 let b_norm = bb.sqrt();
319 if b_norm == 0.0 {
320 return Ok(DeviceResidentPcgOutput {
322 x: vec![0.0; p],
323 iterations: 0,
324 final_rel_residual: 0.0,
325 });
326 }
327
328 unsafe {
334 stream
335 .launch_builder(&f_precond)
336 .arg(&n_i32)
337 .arg(&input.precond_diag_floor)
338 .arg(&d_diag)
339 .arg(&d_r)
340 .arg(&mut d_z)
341 .launch(LaunchConfig {
342 grid_dim: (vec_blocks, 1, 1),
343 block_dim: (vec_threads, 1, 1),
344 shared_mem_bytes: 0,
345 })
346 }
347 .map_err(|e| format!("pcg precond z₀: {e}"))?;
348
349 unsafe {
355 stream
356 .launch_builder(&f_copy)
357 .arg(&n_i32)
358 .arg(&d_z)
359 .arg(&mut d_p)
360 .launch(LaunchConfig {
361 grid_dim: (vec_blocks, 1, 1),
362 block_dim: (vec_threads, 1, 1),
363 shared_mem_bytes: 0,
364 })
365 }
366 .map_err(|e| format!("pcg copy p₀: {e}"))?;
367
368 unsafe {
374 stream
375 .launch_builder(&f_dot)
376 .arg(&n_i32)
377 .arg(&d_r)
378 .arg(&d_z)
379 .arg(&mut d_scalar)
380 .launch(LaunchConfig {
381 grid_dim: (1, 1, 1),
382 block_dim: (dot_threads, 1, 1),
383 shared_mem_bytes: 0,
384 })
385 }
386 .map_err(|e| format!("pcg ρ₀ launch: {e}"))?;
387 stream
388 .synchronize()
389 .map_err(|e| format!("pcg ρ₀ sync: {e}"))?;
390 let s = stream
391 .clone_dtoh(&d_scalar)
392 .map_err(|e| format!("pcg ρ₀ download: {e}"))?;
393 let mut rho = s[0];
394 if !rho.is_finite() {
395 return Err(format!("pcg: ρ₀ not finite ({rho})"));
396 }
397
398 let mut iters_taken: usize = 0;
399 let mut final_rel_residual: f64 = (bb.sqrt() / b_norm).max(0.0);
400 for iter in 0..input.max_iters {
401 iters_taken = iter + 1;
402
403 launch_bms_flex_row_hvp_into_device(input.storage, &d_p, &mut d_q)
405 .map_err(|e| format!("pcg Hv iter {iter}: {e}"))?;
406
407 unsafe {
413 stream
414 .launch_builder(&f_dot)
415 .arg(&n_i32)
416 .arg(&d_p)
417 .arg(&d_q)
418 .arg(&mut d_scalar)
419 .launch(LaunchConfig {
420 grid_dim: (1, 1, 1),
421 block_dim: (dot_threads, 1, 1),
422 shared_mem_bytes: 0,
423 })
424 }
425 .map_err(|e| format!("pcg p·q launch iter {iter}: {e}"))?;
426 stream
427 .synchronize()
428 .map_err(|e| format!("pcg p·q sync iter {iter}: {e}"))?;
429 let s = stream
430 .clone_dtoh(&d_scalar)
431 .map_err(|e| format!("pcg p·q download iter {iter}: {e}"))?;
432 let pq = s[0];
433 if !pq.is_finite() || pq == 0.0 {
434 return Err(format!(
435 "pcg iter {iter}: p·q={pq} (non-finite or zero); operator is not positive-definite"
436 ));
437 }
438 let alpha = rho / pq;
439
440 unsafe {
447 stream
448 .launch_builder(&f_axpy)
449 .arg(&n_i32)
450 .arg(&alpha)
451 .arg(&d_p)
452 .arg(&mut d_x)
453 .launch(LaunchConfig {
454 grid_dim: (vec_blocks, 1, 1),
455 block_dim: (vec_threads, 1, 1),
456 shared_mem_bytes: 0,
457 })
458 }
459 .map_err(|e| format!("pcg x+=αp iter {iter}: {e}"))?;
460
461 let neg_alpha = -alpha;
463 unsafe {
468 stream
469 .launch_builder(&f_axpy)
470 .arg(&n_i32)
471 .arg(&neg_alpha)
472 .arg(&d_q)
473 .arg(&mut d_r)
474 .launch(LaunchConfig {
475 grid_dim: (vec_blocks, 1, 1),
476 block_dim: (vec_threads, 1, 1),
477 shared_mem_bytes: 0,
478 })
479 }
480 .map_err(|e| format!("pcg r-=αq iter {iter}: {e}"))?;
481
482 unsafe {
487 stream
488 .launch_builder(&f_dot)
489 .arg(&n_i32)
490 .arg(&d_r)
491 .arg(&d_r)
492 .arg(&mut d_scalar)
493 .launch(LaunchConfig {
494 grid_dim: (1, 1, 1),
495 block_dim: (dot_threads, 1, 1),
496 shared_mem_bytes: 0,
497 })
498 }
499 .map_err(|e| format!("pcg ‖r‖₂² launch iter {iter}: {e}"))?;
500 stream
501 .synchronize()
502 .map_err(|e| format!("pcg ‖r‖₂² sync iter {iter}: {e}"))?;
503 let s = stream
504 .clone_dtoh(&d_scalar)
505 .map_err(|e| format!("pcg ‖r‖₂² download iter {iter}: {e}"))?;
506 let rr = s[0];
507 if !rr.is_finite() {
508 return Err(format!("pcg iter {iter}: ‖r‖₂²={rr} non-finite"));
509 }
510 let rel = rr.sqrt() / b_norm;
511 final_rel_residual = rel;
512 if rel <= input.rel_tol {
513 break;
514 }
515
516 unsafe {
522 stream
523 .launch_builder(&f_precond)
524 .arg(&n_i32)
525 .arg(&input.precond_diag_floor)
526 .arg(&d_diag)
527 .arg(&d_r)
528 .arg(&mut d_z)
529 .launch(LaunchConfig {
530 grid_dim: (vec_blocks, 1, 1),
531 block_dim: (vec_threads, 1, 1),
532 shared_mem_bytes: 0,
533 })
534 }
535 .map_err(|e| format!("pcg z=M⁻¹r iter {iter}: {e}"))?;
536
537 unsafe {
541 stream
542 .launch_builder(&f_dot)
543 .arg(&n_i32)
544 .arg(&d_r)
545 .arg(&d_z)
546 .arg(&mut d_scalar)
547 .launch(LaunchConfig {
548 grid_dim: (1, 1, 1),
549 block_dim: (dot_threads, 1, 1),
550 shared_mem_bytes: 0,
551 })
552 }
553 .map_err(|e| format!("pcg ρ_new launch iter {iter}: {e}"))?;
554 stream
555 .synchronize()
556 .map_err(|e| format!("pcg ρ_new sync iter {iter}: {e}"))?;
557 let s = stream
558 .clone_dtoh(&d_scalar)
559 .map_err(|e| format!("pcg ρ_new download iter {iter}: {e}"))?;
560 let rho_new = s[0];
561 if !rho_new.is_finite() {
562 return Err(format!("pcg iter {iter}: ρ_new={rho_new} non-finite"));
563 }
564 let beta_pcg = rho_new / rho;
565
566 unsafe {
573 stream
574 .launch_builder(&f_axpby)
575 .arg(&n_i32)
576 .arg(&1.0_f64)
577 .arg(&d_z)
578 .arg(&beta_pcg)
579 .arg(&mut d_p)
580 .launch(LaunchConfig {
581 grid_dim: (vec_blocks, 1, 1),
582 block_dim: (vec_threads, 1, 1),
583 shared_mem_bytes: 0,
584 })
585 }
586 .map_err(|e| format!("pcg p=z+βp iter {iter}: {e}"))?;
587
588 rho = rho_new;
589 }
590
591 let x_host = stream
593 .clone_dtoh(&d_x)
594 .map_err(|e| format!("pcg final x DtoH: {e}"))?;
595 drop(d_r);
598 drop(d_z);
599 drop(d_p);
600 drop(d_q);
601 drop(d_scalar);
602 drop(d_diag);
603 Ok(DeviceResidentPcgOutput {
604 x: x_host,
605 iterations: iters_taken,
606 final_rel_residual,
607 })
608 }
609}
610
611#[cfg(target_os = "linux")]
622pub fn run_pcg_against_row_hessian_device(
623 input: DeviceResidentPcgInput<'_>,
624) -> Result<DeviceResidentPcgOutput, String> {
625 pcg_device::run(input)
626}
627
#[cfg(all(test, target_os = "linux"))]
mod pcg_device_parity_tests {
    use super::*;
    use crate::bms::gpu::row::{
        BmsFlexBlockLayout, BmsFlexPrimaryLayout, DeviceResidentRowHess,
    };
    use ndarray::Array2;

    /// Assembles the dense joint Hessian `Σ_rows Φᵀ H_row Φ` on the host.
    ///
    /// `phi[u][m]` is the weight of coefficient `m` in primary `u` for the current row.
    /// Primary 0 carries the marginal design row in columns `0..p_m`, primary 1 carries
    /// the logslope design row in columns `p_m..p_m + p_g`, and the `h`/`w` primaries map
    /// one-to-one onto their coefficient blocks. `row_hessians` holds `n` row-major
    /// `r × r` blocks.
    fn cpu_dense_joint_hessian(
        row_hessians: &[f64],
        marginal: &[f64],
        logslope: &[f64],
        block: &BmsFlexBlockLayout,
        primary: &BmsFlexPrimaryLayout,
        n: usize,
    ) -> Array2<f64> {
        let p_total = block.p_total;
        let r = primary.r;
        let p_m = block.p_m;
        let p_g = block.p_g;
        let (h_block_start, h_block_len) =
            block.h.as_ref().map_or((0, 0), |span| (span.start, span.len()));
        let (w_block_start, w_block_len) =
            block.w.as_ref().map_or((0, 0), |span| (span.start, span.len()));
        let h_primary_start = primary.h.as_ref().map_or(0, |span| span.start);
        let w_primary_start = primary.w.as_ref().map_or(0, |span| span.start);

        let mut h_dense = Array2::<f64>::zeros((p_total, p_total));
        let mut phi = vec![vec![0.0_f64; p_total]; r];
        for row in 0..n {
            phi.iter_mut().for_each(|col| col.fill(0.0));
            phi[0][..p_m].copy_from_slice(&marginal[row * p_m..(row + 1) * p_m]);
            phi[1][p_m..p_m + p_g].copy_from_slice(&logslope[row * p_g..(row + 1) * p_g]);
            for k in 0..h_block_len {
                phi[h_primary_start + k][h_block_start + k] = 1.0;
            }
            for k in 0..w_block_len {
                phi[w_primary_start + k][w_block_start + k] = 1.0;
            }

            let h_row = &row_hessians[row * r * r..(row + 1) * r * r];
            for u in 0..r {
                for v in 0..r {
                    let huv = h_row[u * r + v];
                    if huv == 0.0 {
                        continue;
                    }
                    for m in 0..p_total {
                        let phim = phi[u][m];
                        if phim == 0.0 {
                            continue;
                        }
                        let scaled = huv * phim;
                        for (nn, &phv) in phi[v].iter().enumerate() {
                            h_dense[[m, nn]] += scaled * phv;
                        }
                    }
                }
            }
        }
        h_dense
    }

    /// Host reference solve of `h x = b` with the crate's CPU PCG (Jacobi diagonal),
    /// panicking if it does not converge within `4·p` iterations.
    fn cpu_pcg_oracle(h: &Array2<f64>, b: &[f64], rel_tol: f64) -> Vec<f64> {
        let p = b.len();
        let diag: ndarray::Array1<f64> =
            ndarray::Array1::from_vec((0..p).map(|i| h[[i, i]]).collect());
        let rhs = ndarray::Array1::from_vec(b.to_vec());
        let h_owned = h.clone();
        let apply = move |v: &ndarray::Array1<f64>| h_owned.dot(v);
        let (x, info) =
            gam_linalg::utils::solve_spd_pcg_with_info(apply, &rhs, &diag, rel_tol, 4 * p)
                .expect("host PCG oracle must converge on SPD fixture");
        assert!(
            info.converged,
            "host PCG oracle failed to converge: iters={} rel_res={}",
            info.iterations, info.relative_residual_norm
        );
        x.to_vec()
    }

    /// Device PCG must match the dense host oracle on a deterministic fixture.
    ///
    /// Each per-row `r × r` Hessian is symmetrised and its diagonal is shifted by `4r`
    /// (strong diagonal dominance), which is intended to make the joint Hessian SPD.
    #[test]
    fn pcg_device_matches_dense_oracle_at_n64_r20_p44() {
        let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[pcg_device parity] no CUDA runtime — skipping");
            return;
        };
        let n = 64_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            for u in 0..r {
                for v in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + v] = a;
                }
            }
            for u in 0..r {
                for v in (u + 1)..r {
                    let upper = row_hessians[base + u * r + v];
                    let lower = row_hessians[base + v * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + v] = sym;
                    row_hessians[base + v * r + u] = sym;
                }
                row_hessians[base + u * r + u] += 4.0 * (r as f64);
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }

        let b: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();

        let h_dense =
            cpu_dense_joint_hessian(&row_hessians, &marginal, &logslope, &block, &primary, n);
        let x_oracle = cpu_pcg_oracle(&h_dense, &b, 1e-12);

        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
            .expect("[pcg_device parity] cuda_context_for must succeed on a CUDA host");
        let stream = ctx.default_stream();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[pcg_device parity] upload h must succeed on a CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[pcg_device parity] upload marginal must succeed on a CUDA host");
        let d_g = stream
            .clone_htod(&logslope)
            .expect("[pcg_device parity] upload logslope must succeed on a CUDA host");
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block,
            primary,
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };

        let out = run_pcg_against_row_hessian_device(DeviceResidentPcgInput {
            storage: &storage,
            b: &b,
            rel_tol: 1e-10,
            max_iters: 4 * p_total,
            precond_diag_floor: 1e-12,
        })
        .expect("device-resident PCG must succeed on SPD fixture");

        assert_eq!(out.x.len(), p_total);
        // NaN differences would be invisible to the max-abs comparison below, so reject
        // non-finite solutions explicitly.
        assert!(
            out.x.iter().all(|v| v.is_finite()),
            "pcg_device returned non-finite entries: {:?}",
            out.x
        );
        let max_abs = out
            .x
            .iter()
            .zip(&x_oracle)
            .map(|(dev, host)| (dev - host).abs())
            .fold(0.0_f64, f64::max);
        assert!(
            max_abs <= 1e-7,
            "pcg_device parity ‖Δx‖∞={max_abs:.3e} > 1e-7 after {} iters \
             (final rel residual={:.3e})",
            out.iterations,
            out.final_rel_residual
        );
        eprintln!(
            "[pcg_device parity] n={n} p={p_total} r={r}: iters={} rel_res={:.3e} ‖Δx‖∞={:.3e}",
            out.iterations, out.final_rel_residual, max_abs
        );
    }
}