Skip to main content

gam_models/bms/gpu/
device_pcg.rs

1// ────────────────────────────────────────────────────────────────────────
2// Block 9 Phase 5 — device-resident PCG against the BMS-FLEX row-Hessian
3// operator.
4//
5// The inner Newton solve in `BernoulliMarginalSlope` (matrix-free path,
6// large-scale shape n=195k, p=44, r=20) currently reaches the GPU as a
7// per-CG-iteration call to `launch_bms_flex_row_hvp` returning a host
8// `Vec<f64>`. With ~6400 inner CG iterations per outer iteration that round-
9// trip cost dominates: each iter pays one `stream.synchronize()` plus one
10// DtoH download. At p=44 the download itself is 352 bytes — trivial in
11// bandwidth, painful in latency.
12//
// Phase 5 keeps every PCG vector on the device. Inside the loop the host
// reads back only scalars: `p·q` (step length), `‖r‖₂²` (convergence check)
// and `ρ = r·z` (direction update), each one f64 behind one
// `stream.synchronize()`; no length-`p` vector crosses the host boundary
// until the final download of `x`. The Hv kernel becomes `into_device`
// (Block 9 addition to `bms_flex_row.rs`), and the axpy / dot / diagonal-
// preconditioner / scale-and-add steps run as tiny NVRTC kernels on the
// same default stream so consecutive launches are implicitly ordered
// without a sync between them.
19// ────────────────────────────────────────────────────────────────────────
20
21/// Inputs to [`run_pcg_against_row_hessian_device`]. The right-hand-side
22/// `b` is supplied as a host slice (it is the only host-resident vector
23/// that needs to enter the loop — the iterate, residual, search direction,
24/// and Hv output all live on the device).
25#[cfg(target_os = "linux")]
26pub struct DeviceResidentPcgInput<'a> {
27    /// Per-fit row-Hessian + design storage. The PCG operator is
28    /// `v ↦ launch_bms_flex_row_hvp_into_device(storage, ...)`.
29    pub storage: &'a crate::bms::gpu::row::DeviceResidentRowHess,
30    /// Right-hand-side `b`, length `storage.block.p_total`. Uploaded once.
31    pub b: &'a [f64],
32    /// Convergence tolerance on relative residual `‖r‖₂ / ‖b‖₂`.
33    pub rel_tol: f64,
34    /// Hard cap on iterations (the inner loop also bails on stagnation).
35    pub max_iters: usize,
36    /// Floor on `|diag(H)[i]|` used by the Jacobi preconditioner. Set to
37    /// `1e-12` for the matrix-free row-Hessian path; the row-primary
38    /// Hessian's diagonal is positive-definite by construction.
39    pub precond_diag_floor: f64,
40}
41
/// Output of [`run_pcg_against_row_hessian_device`].
#[cfg(target_os = "linux")]
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceResidentPcgOutput {
    /// Solution `x` such that `H · x ≈ b`, length `storage.block.p_total`.
    pub x: Vec<f64>,
    /// Number of PCG iterations started. The iteration on which the
    /// convergence check passes is counted; the value is `0` only when
    /// `b` is exactly zero and the solver returns `x = 0` without iterating.
    pub iterations: usize,
    /// Final achieved relative residual `‖r‖₂ / ‖b‖₂` (`0.0` when `b` is
    /// zero). May exceed the requested tolerance if the iteration cap was
    /// reached first.
    pub final_rel_residual: f64,
}
53
/// NVRTC source for the Phase-5 device-resident PCG support kernels.
///
/// Every kernel operates on length-`p` device vectors of `double`, with `p`
/// typically 44–256. Kernels defined here:
///
/// * `pcg_axpy` — `y[i] += a * x[i]`.
/// * `pcg_axpby` — `y[i] = a * x[i] + b * y[i]`.
/// * `pcg_apply_diag_precond` — `z[i] = r[i] / d`, where `d` is `diag[i]`
///   when `|diag[i]| > floor_val`, else `+floor_val` if `diag[i] >= 0` and
///   `-floor_val` otherwise.
/// * `pcg_dot_single_block` — dot product reduced through a 1024-slot shared
///   array; the scalar is written to `out[0]`.
/// * `pcg_init_zero` — `out[i] = 0`.
/// * `pcg_copy` — `y[i] = x[i]`.
///
/// The element-wise kernels index by `blockIdx.x * blockDim.x + threadIdx.x`
/// and guard with `i < n`, so any 1-D grid that supplies at least `n`
/// threads covers the vector. `pcg_dot_single_block` indexes by
/// `threadIdx.x` only, so it is meant to be launched as a single block; its
/// halving reduction (`stride = blockDim.x / 2`, then `stride >>= 1`) folds
/// in every partial sum only when `blockDim.x` is a power of two, and the
/// shared array limits `blockDim.x` to 1024.
///
/// The string is handed as-is to the NVRTC compile step, so the comments
/// inside the literal are part of the compiled text.
#[cfg(target_os = "linux")]
const PCG_KERNEL_SOURCE: &str = r#"
// y[i] += a * x[i]
extern "C" __global__ void pcg_axpy(int n, double a,
                                    const double * __restrict__ x,
                                    double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] += a * x[i];
}

// y[i] = a * x[i] + b * y[i]
extern "C" __global__ void pcg_axpby(int n, double a,
                                     const double * __restrict__ x,
                                     double b,
                                     double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = a * x[i] + b * y[i];
}

// z[i] = r[i] / clamp(diag[i], floor) (sign-preserving floor on |diag|).
extern "C" __global__ void pcg_apply_diag_precond(int n, double floor_val,
                                                  const double * __restrict__ diag,
                                                  const double * __restrict__ r,
                                                  double * __restrict__ z)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        double d = diag[i];
        double ad = d < 0 ? -d : d;
        double clamped = ad > floor_val ? d : (d >= 0.0 ? floor_val : -floor_val);
        z[i] = r[i] / clamped;
    }
}

// Single-block dot product; writes the scalar to out[0]. n must be <= 1024.
extern "C" __global__ void pcg_dot_single_block(int n,
                                                const double * __restrict__ a,
                                                const double * __restrict__ b,
                                                double * __restrict__ out)
{
    __shared__ double s[1024];
    int tid = threadIdx.x;
    double acc = 0.0;
    for (int i = tid; i < n; i += blockDim.x) acc += a[i] * b[i];
    s[tid] = acc;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) s[tid] += s[tid + stride];
        __syncthreads();
    }
    if (tid == 0) out[0] = s[0];
}

// Set out[i] = 0 for i in [0, n).
extern "C" __global__ void pcg_init_zero(int n, double * __restrict__ out) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 0.0;
}

// Copy y[i] = x[i].
extern "C" __global__ void pcg_copy(int n,
                                    const double * __restrict__ x,
                                    double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i];
}
"#;
127
128#[cfg(target_os = "linux")]
129mod pcg_device {
130    use super::DeviceResidentPcgInput;
131    use super::DeviceResidentPcgOutput;
132    use super::PCG_KERNEL_SOURCE;
133    use crate::bms::gpu::row::launch_bms_flex_row_diagonal;
134    use crate::bms::gpu::row::launch_bms_flex_row_hvp_into_device;
135    use cudarc::driver::{CudaModule, CudaStream, LaunchConfig, PushKernelArg};
136    use std::sync::{Arc, OnceLock};
137
138    struct PcgBackend {
139        stream: Arc<CudaStream>,
140        module: Arc<CudaModule>,
141    }
142
143    impl PcgBackend {
144        fn probe() -> Result<&'static Self, String> {
145            static BACKEND: OnceLock<Result<PcgBackend, String>> = OnceLock::new();
146            BACKEND
147                .get_or_init(|| {
148                    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
149                        .ok_or_else(|| "pcg backend: no CUDA runtime available".to_string())?;
150                    let ctx = gam_gpu::device_runtime::cuda_context_for(
151                        runtime.selected_device().ordinal,
152                    )
153                    .ok_or_else(|| {
154                        format!(
155                            "pcg backend: failed to create CUDA context for device {}",
156                            runtime.selected_device().ordinal
157                        )
158                    })?;
159                    let stream = ctx.default_stream();
160                    // #1551: arch-aware compile via the project's shared NVRTC
161                    // entry point — pin `--gpu-architecture` to the device
162                    // capability and supply the standard CUDA include paths,
163                    // instead of bare `cudarc::nvrtc::compile_ptx` (NVRTC default
164                    // arch, no includes).
165                    let ptx = gam_gpu::device_cache::compile_ptx_arch(PCG_KERNEL_SOURCE)
166                        .map_err(|err| format!("pcg NVRTC compile failed: {err}"))?;
167                    let module = ctx
168                        .load_module(ptx)
169                        .map_err(|err| format!("pcg module load failed: {err}"))?;
170                    Ok(PcgBackend { stream, module })
171                })
172                .as_ref()
173                .map_err(String::clone)
174        }
175    }
176
177    fn launch_blocks(p: usize, threads: u32) -> u32 {
178        ((p as u32) + threads - 1) / threads
179    }
180
181    /// PCG against the row-Hessian operator with Jacobi preconditioner from
182    /// `diag(H)`. All vectors remain on the device for the duration of the
183    /// loop; only the squared residual norm crosses the host boundary each
184    /// iter (one f64, ≤ 8 bytes).
185    pub(super) fn run(
186        input: DeviceResidentPcgInput<'_>,
187    ) -> Result<DeviceResidentPcgOutput, String> {
188        let p = input.storage.block.p_total;
189        if input.b.len() != p {
190            return Err(format!(
191                "device-resident pcg: b.len()={} != p_total={p}",
192                input.b.len()
193            ));
194        }
195        if !input.rel_tol.is_finite() || input.rel_tol <= 0.0 {
196            return Err(format!(
197                "device-resident pcg: rel_tol must be positive and finite (got {})",
198                input.rel_tol
199            ));
200        }
201        if input.max_iters == 0 {
202            return Err("device-resident pcg: max_iters must be >= 1".to_string());
203        }
204        if !input.precond_diag_floor.is_finite() || input.precond_diag_floor <= 0.0 {
205            return Err(format!(
206                "device-resident pcg: precond_diag_floor must be positive and finite (got {})",
207                input.precond_diag_floor
208            ));
209        }
210
211        let backend = PcgBackend::probe()?;
212        let stream = backend.stream.clone();
213        let module = backend.module.clone();
214
215        // ── Load kernel handles once ─────────────────────────────────────
216        let f_axpy = module
217            .load_function("pcg_axpy")
218            .map_err(|e| format!("pcg load pcg_axpy: {e}"))?;
219        let f_axpby = module
220            .load_function("pcg_axpby")
221            .map_err(|e| format!("pcg load pcg_axpby: {e}"))?;
222        let f_precond = module
223            .load_function("pcg_apply_diag_precond")
224            .map_err(|e| format!("pcg load pcg_apply_diag_precond: {e}"))?;
225        let f_dot = module
226            .load_function("pcg_dot_single_block")
227            .map_err(|e| format!("pcg load pcg_dot_single_block: {e}"))?;
228        let f_copy = module
229            .load_function("pcg_copy")
230            .map_err(|e| format!("pcg load pcg_copy: {e}"))?;
231
232        // ── Allocate device vectors x, r, z, p_vec, q (length p each) ──
233        let mut d_x = stream
234            .alloc_zeros::<f64>(p)
235            .map_err(|e| format!("pcg alloc x: {e}"))?;
236        let mut d_r = stream
237            .clone_htod(input.b)
238            .map_err(|e| format!("pcg upload b -> r: {e}"))?;
239        let mut d_z = stream
240            .alloc_zeros::<f64>(p)
241            .map_err(|e| format!("pcg alloc z: {e}"))?;
242        let mut d_p = stream
243            .alloc_zeros::<f64>(p)
244            .map_err(|e| format!("pcg alloc p: {e}"))?;
245        let mut d_q = stream
246            .alloc_zeros::<f64>(p)
247            .map_err(|e| format!("pcg alloc q: {e}"))?;
248        // One-element scalar buffer reused across iters for `p·q` and
249        // `r·z` dot products.
250        let mut d_scalar = stream
251            .alloc_zeros::<f64>(1)
252            .map_err(|e| format!("pcg alloc scalar: {e}"))?;
253
254        // Preconditioner: M⁻¹ from diag(H). One HostVec download per
255        // *outer* call, but this is constant work per solve — not per
256        // iter — so it does not block the inner loop's no-sync property.
257        let diag_host = launch_bms_flex_row_diagonal(input.storage)
258            .map_err(|e| format!("pcg diag fetch: {e}"))?;
259        if diag_host.len() != p {
260            return Err(format!(
261                "pcg: diag length {} != p_total {p}",
262                diag_host.len()
263            ));
264        }
265        let d_diag = stream
266            .clone_htod(&diag_host)
267            .map_err(|e| format!("pcg upload diag: {e}"))?;
268
269        // ── Convergence baseline: ‖b‖₂ via one in-stream dot ─────────────
270        let n_i32 = i32::try_from(p).map_err(|_| format!("pcg: p_total={p} exceeds i32 range"))?;
271        let vec_threads: u32 = 64;
272        let vec_blocks = launch_blocks(p, vec_threads);
273        let dot_threads: u32 = match p {
274            0..=64 => 64,
275            65..=128 => 128,
276            129..=256 => 256,
277            257..=512 => 512,
278            _ => 1024,
279        };
280        if p > 1024 {
281            return Err(format!(
282                "device-resident pcg: p_total={p} exceeds single-block dot capacity (1024); \
283                 widen pcg_dot_single_block to multi-block reduce before raising the cap"
284            ));
285        }
286
287        // ‖b‖₂² = b · b (b is currently in d_r since r₀ = b - H·0 = b)
288        // SAFETY: `f_dot` is the `pcg_dot_single_block` device function loaded
289        // above; its signature is `(i32, *const f64, *const f64, *mut f64)`.
290        // `n_i32` was bounded against `1024` (kernel's max-n contract) two
291        // lines up; `d_r` is a `CudaSlice<f64>` of length `n` allocated to the
292        // same stream; `d_scalar` is the length-1 output slice. Single-block
293        // grid (1×dot_threads) matches the kernel's reduction strategy.
294        unsafe {
295            stream
296                .launch_builder(&f_dot)
297                .arg(&n_i32)
298                .arg(&d_r)
299                .arg(&d_r)
300                .arg(&mut d_scalar)
301                .launch(LaunchConfig {
302                    grid_dim: (1, 1, 1),
303                    block_dim: (dot_threads, 1, 1),
304                    shared_mem_bytes: 0,
305                })
306        }
307        .map_err(|e| format!("pcg b·b launch: {e}"))?;
308        stream
309            .synchronize()
310            .map_err(|e| format!("pcg b·b sync: {e}"))?;
311        let host_scalar = stream
312            .clone_dtoh(&d_scalar)
313            .map_err(|e| format!("pcg b·b download: {e}"))?;
314        let bb = host_scalar[0];
315        if !bb.is_finite() {
316            return Err(format!("pcg: b·b not finite ({bb})"));
317        }
318        let b_norm = bb.sqrt();
319        if b_norm == 0.0 {
320            // x = 0, r = b = 0, trivially converged.
321            return Ok(DeviceResidentPcgOutput {
322                x: vec![0.0; p],
323                iterations: 0,
324                final_rel_residual: 0.0,
325            });
326        }
327
328        // z₀ = M⁻¹ r₀
329        // SAFETY: `f_precond` is `pcg_jacobi_precond` with signature
330        // `(i32, f64, *const f64, *const f64, *mut f64)`. `d_diag`, `d_r`,
331        // `d_z` are all `CudaSlice<f64>` of length `n` on the same stream;
332        // `vec_blocks × vec_threads ≥ n` covers every output element.
333        unsafe {
334            stream
335                .launch_builder(&f_precond)
336                .arg(&n_i32)
337                .arg(&input.precond_diag_floor)
338                .arg(&d_diag)
339                .arg(&d_r)
340                .arg(&mut d_z)
341                .launch(LaunchConfig {
342                    grid_dim: (vec_blocks, 1, 1),
343                    block_dim: (vec_threads, 1, 1),
344                    shared_mem_bytes: 0,
345                })
346        }
347        .map_err(|e| format!("pcg precond z₀: {e}"))?;
348
349        // p₀ = z₀
350        // SAFETY: `f_copy` is `pcg_copy` with signature
351        // `(i32, *const f64, *mut f64)`. `d_z` and `d_p` are
352        // `CudaSlice<f64>` of length `n` on the same stream;
353        // `vec_blocks × vec_threads ≥ n` covers every element.
354        unsafe {
355            stream
356                .launch_builder(&f_copy)
357                .arg(&n_i32)
358                .arg(&d_z)
359                .arg(&mut d_p)
360                .launch(LaunchConfig {
361                    grid_dim: (vec_blocks, 1, 1),
362                    block_dim: (vec_threads, 1, 1),
363                    shared_mem_bytes: 0,
364                })
365        }
366        .map_err(|e| format!("pcg copy p₀: {e}"))?;
367
368        // ρ₀ = r₀·z₀
369        // SAFETY: same invariants as the ‖b‖₂² launch above — `f_dot`
370        // signature `(i32, *const f64, *const f64, *mut f64)`, `d_r` and
371        // `d_z` are length-`n` `CudaSlice<f64>`, `d_scalar` is length-1,
372        // single-block grid matches kernel's reduction.
373        unsafe {
374            stream
375                .launch_builder(&f_dot)
376                .arg(&n_i32)
377                .arg(&d_r)
378                .arg(&d_z)
379                .arg(&mut d_scalar)
380                .launch(LaunchConfig {
381                    grid_dim: (1, 1, 1),
382                    block_dim: (dot_threads, 1, 1),
383                    shared_mem_bytes: 0,
384                })
385        }
386        .map_err(|e| format!("pcg ρ₀ launch: {e}"))?;
387        stream
388            .synchronize()
389            .map_err(|e| format!("pcg ρ₀ sync: {e}"))?;
390        let s = stream
391            .clone_dtoh(&d_scalar)
392            .map_err(|e| format!("pcg ρ₀ download: {e}"))?;
393        let mut rho = s[0];
394        if !rho.is_finite() {
395            return Err(format!("pcg: ρ₀ not finite ({rho})"));
396        }
397
398        let mut iters_taken: usize = 0;
399        let mut final_rel_residual: f64 = (bb.sqrt() / b_norm).max(0.0);
400        for iter in 0..input.max_iters {
401            iters_taken = iter + 1;
402
403            // q = H · p (on device, no sync, no DtoH).
404            launch_bms_flex_row_hvp_into_device(input.storage, &d_p, &mut d_q)
405                .map_err(|e| format!("pcg Hv iter {iter}: {e}"))?;
406
407            // pq = p·q
408            // SAFETY: identical to ‖b‖₂² launch — `f_dot` signature
409            // `(i32, *const f64, *const f64, *mut f64)`; `d_p` is the
410            // current search direction and `d_q` was just populated by
411            // `launch_bms_flex_row_hvp_into_device` (same stream, same `n`).
412            unsafe {
413                stream
414                    .launch_builder(&f_dot)
415                    .arg(&n_i32)
416                    .arg(&d_p)
417                    .arg(&d_q)
418                    .arg(&mut d_scalar)
419                    .launch(LaunchConfig {
420                        grid_dim: (1, 1, 1),
421                        block_dim: (dot_threads, 1, 1),
422                        shared_mem_bytes: 0,
423                    })
424            }
425            .map_err(|e| format!("pcg p·q launch iter {iter}: {e}"))?;
426            stream
427                .synchronize()
428                .map_err(|e| format!("pcg p·q sync iter {iter}: {e}"))?;
429            let s = stream
430                .clone_dtoh(&d_scalar)
431                .map_err(|e| format!("pcg p·q download iter {iter}: {e}"))?;
432            let pq = s[0];
433            if !pq.is_finite() || pq == 0.0 {
434                return Err(format!(
435                    "pcg iter {iter}: p·q={pq} (non-finite or zero); operator is not positive-definite"
436                ));
437            }
438            let alpha = rho / pq;
439
440            // x += α p
441            // SAFETY: `f_axpy` is `pcg_axpy` with signature
442            // `(i32, f64, *const f64, *mut f64)`. `alpha` is the
443            // finite-checked CG step length (`rho/pq`, both validated
444            // above). `d_p` and `d_x` are length-`n` `CudaSlice<f64>` on
445            // the same stream. Grid covers all `n` elements.
446            unsafe {
447                stream
448                    .launch_builder(&f_axpy)
449                    .arg(&n_i32)
450                    .arg(&alpha)
451                    .arg(&d_p)
452                    .arg(&mut d_x)
453                    .launch(LaunchConfig {
454                        grid_dim: (vec_blocks, 1, 1),
455                        block_dim: (vec_threads, 1, 1),
456                        shared_mem_bytes: 0,
457                    })
458            }
459            .map_err(|e| format!("pcg x+=αp iter {iter}: {e}"))?;
460
461            // r -= α q
462            let neg_alpha = -alpha;
463            // SAFETY: same `f_axpy` invariants as the `x += α p` launch
464            // above; `neg_alpha = -alpha` is finite (alpha was checked),
465            // `d_q` and `d_r` are length-`n` `CudaSlice<f64>` on the same
466            // stream.
467            unsafe {
468                stream
469                    .launch_builder(&f_axpy)
470                    .arg(&n_i32)
471                    .arg(&neg_alpha)
472                    .arg(&d_q)
473                    .arg(&mut d_r)
474                    .launch(LaunchConfig {
475                        grid_dim: (vec_blocks, 1, 1),
476                        block_dim: (vec_threads, 1, 1),
477                        shared_mem_bytes: 0,
478                    })
479            }
480            .map_err(|e| format!("pcg r-=αq iter {iter}: {e}"))?;
481
482            // ‖r‖₂² = r·r (single device dot, single f64 DtoH)
483            // SAFETY: identical to the ‖b‖₂² launch at function entry —
484            // `f_dot` signature, `d_r` length-`n`, `d_scalar` length-1,
485            // single-block reduction grid.
486            unsafe {
487                stream
488                    .launch_builder(&f_dot)
489                    .arg(&n_i32)
490                    .arg(&d_r)
491                    .arg(&d_r)
492                    .arg(&mut d_scalar)
493                    .launch(LaunchConfig {
494                        grid_dim: (1, 1, 1),
495                        block_dim: (dot_threads, 1, 1),
496                        shared_mem_bytes: 0,
497                    })
498            }
499            .map_err(|e| format!("pcg ‖r‖₂² launch iter {iter}: {e}"))?;
500            stream
501                .synchronize()
502                .map_err(|e| format!("pcg ‖r‖₂² sync iter {iter}: {e}"))?;
503            let s = stream
504                .clone_dtoh(&d_scalar)
505                .map_err(|e| format!("pcg ‖r‖₂² download iter {iter}: {e}"))?;
506            let rr = s[0];
507            if !rr.is_finite() {
508                return Err(format!("pcg iter {iter}: ‖r‖₂²={rr} non-finite"));
509            }
510            let rel = rr.sqrt() / b_norm;
511            final_rel_residual = rel;
512            if rel <= input.rel_tol {
513                break;
514            }
515
516            // z = M⁻¹ r
517            // SAFETY: same `f_precond` invariants as the `z₀ = M⁻¹ r₀`
518            // launch above — signature `(i32, f64, *const f64, *const f64,
519            // *mut f64)`, all four slices length-`n` `CudaSlice<f64>`, grid
520            // covers all `n` elements.
521            unsafe {
522                stream
523                    .launch_builder(&f_precond)
524                    .arg(&n_i32)
525                    .arg(&input.precond_diag_floor)
526                    .arg(&d_diag)
527                    .arg(&d_r)
528                    .arg(&mut d_z)
529                    .launch(LaunchConfig {
530                        grid_dim: (vec_blocks, 1, 1),
531                        block_dim: (vec_threads, 1, 1),
532                        shared_mem_bytes: 0,
533                    })
534            }
535            .map_err(|e| format!("pcg z=M⁻¹r iter {iter}: {e}"))?;
536
537            // ρ_new = r·z
538            // SAFETY: identical to the ρ₀ launch above — `f_dot`
539            // signature, `d_r` and `d_z` length-`n`, `d_scalar` length-1.
540            unsafe {
541                stream
542                    .launch_builder(&f_dot)
543                    .arg(&n_i32)
544                    .arg(&d_r)
545                    .arg(&d_z)
546                    .arg(&mut d_scalar)
547                    .launch(LaunchConfig {
548                        grid_dim: (1, 1, 1),
549                        block_dim: (dot_threads, 1, 1),
550                        shared_mem_bytes: 0,
551                    })
552            }
553            .map_err(|e| format!("pcg ρ_new launch iter {iter}: {e}"))?;
554            stream
555                .synchronize()
556                .map_err(|e| format!("pcg ρ_new sync iter {iter}: {e}"))?;
557            let s = stream
558                .clone_dtoh(&d_scalar)
559                .map_err(|e| format!("pcg ρ_new download iter {iter}: {e}"))?;
560            let rho_new = s[0];
561            if !rho_new.is_finite() {
562                return Err(format!("pcg iter {iter}: ρ_new={rho_new} non-finite"));
563            }
564            let beta_pcg = rho_new / rho;
565
566            // p = z + β p  ⇒  via pcg_axpby with a=1, b=β
567            // SAFETY: `f_axpby` is `pcg_axpby` with signature
568            // `(i32, f64, *const f64, f64, *mut f64)`. `beta_pcg = rho_new/rho`
569            // was finite-checked. `d_z` and `d_p` are length-`n`
570            // `CudaSlice<f64>` on the same stream; grid covers all `n`
571            // elements.
572            unsafe {
573                stream
574                    .launch_builder(&f_axpby)
575                    .arg(&n_i32)
576                    .arg(&1.0_f64)
577                    .arg(&d_z)
578                    .arg(&beta_pcg)
579                    .arg(&mut d_p)
580                    .launch(LaunchConfig {
581                        grid_dim: (vec_blocks, 1, 1),
582                        block_dim: (vec_threads, 1, 1),
583                        shared_mem_bytes: 0,
584                    })
585            }
586            .map_err(|e| format!("pcg p=z+βp iter {iter}: {e}"))?;
587
588            rho = rho_new;
589        }
590
591        // Download x once at the end.
592        let x_host = stream
593            .clone_dtoh(&d_x)
594            .map_err(|e| format!("pcg final x DtoH: {e}"))?;
595        // The auxiliary device allocs (d_r/d_z/d_p/d_q/d_scalar/d_diag) drop
596        // here and return their bytes to cudarc's allocator.
597        drop(d_r);
598        drop(d_z);
599        drop(d_p);
600        drop(d_q);
601        drop(d_scalar);
602        drop(d_diag);
603        Ok(DeviceResidentPcgOutput {
604            x: x_host,
605            iterations: iters_taken,
606            final_rel_residual,
607        })
608    }
609}
610
611/// Device-resident PCG against the BMS-FLEX row-Hessian operator.
612///
613/// Block 9 Phase 5: every PCG vector — `x`, `r`, `z`, `p`, `q` — stays on
614/// the device for the entire loop; only the squared residual norm (one f64)
615/// is downloaded per iteration for the convergence check. Bit-equal output
616/// to a host-side reference PCG against the same operator + preconditioner
617/// when the tolerance is tight; differences only show up at the floating-
618/// point reduction-order level.
619///
620/// Linux-only. See [`DeviceResidentPcgInput`] for parameters.
621#[cfg(target_os = "linux")]
622pub fn run_pcg_against_row_hessian_device(
623    input: DeviceResidentPcgInput<'_>,
624) -> Result<DeviceResidentPcgOutput, String> {
625    pcg_device::run(input)
626}
627
628/// Block 9 Phase 5 — V100 parity for `run_pcg_against_row_hessian_device`.
629///
630/// Builds a small `(n=64, r=20, p=44)` BMS-FLEX row-Hessian fixture, computes
631/// the dense joint Hessian via the same CPU oracle the HVP parity test uses,
632/// solves `H · x = b` on the host via dense LU as ground truth, and asserts
633/// the device-resident PCG iterate matches to a tight tolerance.
634#[cfg(all(test, target_os = "linux"))]
635mod pcg_device_parity_tests {
636    use super::*;
637    use crate::bms::gpu::row::{
638        BmsFlexBlockLayout, BmsFlexPrimaryLayout, DeviceResidentRowHess,
639    };
640    use ndarray::Array2;
641
642    /// Dense oracle for `H_full = Σ_i P_iᵀ H_i P_i` consistent with
643    /// `cpu_oracle_bms_flex_row_hvp`'s pullback math.
644    fn cpu_dense_joint_hessian(
645        row_hessians: &[f64],
646        marginal: &[f64],
647        logslope: &[f64],
648        block: &BmsFlexBlockLayout,
649        primary: &BmsFlexPrimaryLayout,
650        n: usize,
651    ) -> Array2<f64> {
652        let p_total = block.p_total;
653        let r = primary.r;
654        let p_m = block.p_m;
655        let p_g = block.p_g;
656        let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
657        let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
658        let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
659        let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
660        let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
661        let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
662        let mut h_dense = Array2::<f64>::zeros((p_total, p_total));
663        // For each row build P_i columns as length-p_total vectors.
664        let mut phi = vec![vec![0.0_f64; p_total]; r];
665        for row in 0..n {
666            for col in phi.iter_mut() {
667                col.iter_mut().for_each(|v| *v = 0.0);
668            }
669            let mrow = &marginal[row * p_m..(row + 1) * p_m];
670            let grow = &logslope[row * p_g..(row + 1) * p_g];
671            for k in 0..p_m {
672                phi[0][k] = mrow[k];
673            }
674            for k in 0..p_g {
675                phi[1][p_m + k] = grow[k];
676            }
677            for k in 0..h_block_len {
678                phi[h_primary_start + k][h_block_start + k] = 1.0;
679            }
680            for k in 0..w_block_len {
681                phi[w_primary_start + k][w_block_start + k] = 1.0;
682            }
683            let h_row = &row_hessians[row * r * r..(row + 1) * r * r];
684            for u in 0..r {
685                for v in 0..r {
686                    let huv = h_row[u * r + v];
687                    if huv == 0.0 {
688                        continue;
689                    }
690                    for m in 0..p_total {
691                        let phim = phi[u][m];
692                        if phim == 0.0 {
693                            continue;
694                        }
695                        let scaled = huv * phim;
696                        for nn in 0..p_total {
697                            h_dense[[m, nn]] += scaled * phi[v][nn];
698                        }
699                    }
700                }
701            }
702        }
703        h_dense
704    }
705
706    /// Reference oracle: host PCG against the dense joint H + diag(H)
707    /// preconditioner, with a tolerance two decades tighter than the GPU
708    /// PCG's. Comparing GPU PCG to host PCG (rather than to a Cholesky
709    /// solve) keeps the comparison numerically apples-to-apples — only
710    /// reduction order differs between the two paths.
711    fn cpu_pcg_oracle(h: &Array2<f64>, b: &[f64], rel_tol: f64) -> Vec<f64> {
712        let p = b.len();
713        let diag: ndarray::Array1<f64> =
714            ndarray::Array1::from_vec((0..p).map(|i| h[[i, i]]).collect());
715        let rhs = ndarray::Array1::from_vec(b.to_vec());
716        let h_owned = h.clone();
717        let apply = move |v: &ndarray::Array1<f64>| h_owned.dot(v);
718        let (x, info) =
719            gam_linalg::utils::solve_spd_pcg_with_info(apply, &rhs, &diag, rel_tol, 4 * p)
720                .expect("host PCG oracle must converge on SPD fixture");
721        assert!(
722            info.converged,
723            "host PCG oracle failed to converge: iters={} rel_res={}",
724            info.iterations, info.relative_residual_norm
725        );
726        x.to_vec()
727    }
728
729    #[test]
730    fn pcg_device_matches_dense_oracle_at_n64_r20_p44() {
731        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
732            eprintln!("[pcg_device parity] no CUDA runtime — skipping");
733            return;
734        };
735        let n = 64_usize;
736        let p_m = 14_usize;
737        let p_g = 12_usize;
738        let p_h_dim = 10_usize;
739        let p_w_dim = 8_usize;
740        let r = 2 + p_h_dim + p_w_dim;
741        let p_total = p_m + p_g + p_h_dim + p_w_dim;
742        let block = BmsFlexBlockLayout {
743            p_m,
744            p_g,
745            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
746            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
747            p_total,
748        };
749        let primary = BmsFlexPrimaryLayout {
750            h: Some(2..2 + p_h_dim),
751            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
752            r,
753        };
754
755        // Same deterministic symmetric Hessians + designs as the HVP parity
756        // gate, so any drift between Phase 4 and Phase 5 surfaces here too.
757        let mut row_hessians = vec![0.0_f64; n * r * r];
758        for row in 0..n {
759            let base = row * r * r;
760            for u in 0..r {
761                for v in 0..r {
762                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
763                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
764                    row_hessians[base + u * r + v] = a;
765                }
766            }
767            for u in 0..r {
768                for v in (u + 1)..r {
769                    let upper = row_hessians[base + u * r + v];
770                    let lower = row_hessians[base + v * r + u];
771                    let sym = 0.5 * (upper + lower);
772                    row_hessians[base + u * r + v] = sym;
773                    row_hessians[base + v * r + u] = sym;
774                }
775                // Boost the diagonal heavily so each H_i is positive
776                // definite — guarantees the joint pulled-back Hessian is
777                // SPD, which PCG requires.
778                row_hessians[base + u * r + u] += 4.0 * (r as f64);
779            }
780        }
781        let mut marginal = vec![0.0_f64; n * p_m];
782        for row in 0..n {
783            for j in 0..p_m {
784                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
785                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
786            }
787        }
788        let mut logslope = vec![0.0_f64; n * p_g];
789        for row in 0..n {
790            for j in 0..p_g {
791                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
792                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
793            }
794        }
795
796        // Pick a non-trivial RHS.
797        let b: Vec<f64> = (0..p_total)
798            .map(|i| {
799                let seed = (i as f64) * 0.157 + 0.6;
800                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
801            })
802            .collect();
803
804        let h_dense =
805            cpu_dense_joint_hessian(&row_hessians, &marginal, &logslope, &block, &primary, n);
806        let x_oracle = cpu_pcg_oracle(&h_dense, &b, 1e-12);
807
808        // Grab the same CUDA context + default stream that the bms_flex_row
809        // kernels will use when `run_pcg_against_row_hessian_device` probes
810        // its own backend. Going through the public runtime APIs keeps the
811        // test independent of any private kernel-backend symbols.
812        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
813            .expect("runtime must exist when probe succeeded above");
814        // Past the GpuRuntime::global() Some-gate above: a context-creation or
815        // HtoD-upload failure here is a real device fault on a CUDA host, not a
816        // no-CUDA skip — fail loud (device-PCG skip-pass class, eee12f6b2). The old
817        // arms returned, so a context/upload fault on a GPU host passed silently.
818        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
819            .expect("[pcg_device parity] cuda_context_for must succeed on a CUDA host");
820        let stream = ctx.default_stream();
821        let d_h = stream
822            .clone_htod(&row_hessians)
823            .expect("[pcg_device parity] upload h must succeed on a CUDA host");
824        let d_m = stream
825            .clone_htod(&marginal)
826            .expect("[pcg_device parity] upload marginal must succeed on a CUDA host");
827        let d_g = stream
828            .clone_htod(&logslope)
829            .expect("[pcg_device parity] upload logslope must succeed on a CUDA host");
830        let storage = DeviceResidentRowHess {
831            hess: d_h,
832            marginal_design: d_m,
833            logslope_design: d_g,
834            n,
835            r,
836            block,
837            primary,
838
839            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
840        };
841
842        let out = run_pcg_against_row_hessian_device(DeviceResidentPcgInput {
843            storage: &storage,
844            b: &b,
845            rel_tol: 1e-10,
846            max_iters: 4 * p_total,
847            precond_diag_floor: 1e-12,
848        })
849        .expect("device-resident PCG must succeed on SPD fixture");
850
851        assert_eq!(out.x.len(), p_total);
852        let mut max_abs = 0.0_f64;
853        for i in 0..p_total {
854            let diff = (out.x[i] - x_oracle[i]).abs();
855            if diff > max_abs {
856                max_abs = diff;
857            }
858        }
859        // Each iteration introduces O(1) ULPs of round-off in the dot/
860        // axpy ladder; with ~88 iters max at p=44 we expect ‖Δx‖∞ comfortably
861        // below 1e-7. Anything larger means a code bug, not float noise.
862        assert!(
863            max_abs <= 1e-7,
864            "pcg_device parity ‖Δx‖∞={max_abs:.3e} > 1e-7 after {} iters \
865             (final rel residual={:.3e})",
866            out.iterations,
867            out.final_rel_residual
868        );
869        eprintln!(
870            "[pcg_device parity] n={n} p={p_total} r={r}: iters={} rel_res={:.3e} ‖Δx‖∞={:.3e}",
871            out.iterations, out.final_rel_residual, max_abs
872        );
873    }
874}