Skip to main content

gam_models/bms/gpu/
device_pcg.rs

// ────────────────────────────────────────────────────────────────────────
// Block 9 Phase 5 — device-resident PCG against the BMS-FLEX row-Hessian
// operator.
//
// The inner Newton solve in `BernoulliMarginalSlope` (matrix-free path,
// large-scale shape n=195k, p=44, r=20) currently reaches the GPU as a
// per-CG-iteration call to `launch_bms_flex_row_hvp` returning a host
// `Vec<f64>`. With ~6400 inner CG iterations per outer iteration, that round
// trip dominates: each iteration pays one `stream.synchronize()` plus one
// DtoH download. At p=44 the download itself is 352 bytes — trivial in
// bandwidth, painful in latency.
//
// Phase 5 keeps every PCG vector (x, r, z, p, q) on the device. The only
// per-iteration host traffic is the three scalar reductions the CG
// recurrences need — `p·q` (step length α), `r·r` (convergence check) and
// `r·z` (direction update β) — each a single f64 read back after a stream
// sync. The Hv product uses `launch_bms_flex_row_hvp_into_device` (the
// Block 9 addition to `bms_flex_row.rs`), and the axpy / dot / diagonal-
// preconditioner / scale-and-add steps run as tiny NVRTC kernels on the
// same default stream, so the launches between those syncs are ordered by
// the stream without any extra synchronization.
// ────────────────────────────────────────────────────────────────────────
20
21/// Inputs to [`run_pcg_against_row_hessian_device`]. The right-hand-side
22/// `b` is supplied as a host slice (it is the only host-resident vector
23/// that needs to enter the loop — the iterate, residual, search direction,
24/// and Hv output all live on the device).
25#[cfg(target_os = "linux")]
26pub struct DeviceResidentPcgInput<'a> {
27    /// Per-fit row-Hessian + design storage. The PCG operator is
28    /// `v ↦ launch_bms_flex_row_hvp_into_device(storage, ...)`.
29    pub storage: &'a crate::bms::gpu::row::DeviceResidentRowHess,
30    /// Right-hand-side `b`, length `storage.block.p_total`. Uploaded once.
31    pub b: &'a [f64],
32    /// Convergence tolerance on relative residual `‖r‖₂ / ‖b‖₂`.
33    pub rel_tol: f64,
34    /// Hard cap on iterations (the inner loop also bails on stagnation).
35    pub max_iters: usize,
36    /// Floor on `|diag(H)[i]|` used by the Jacobi preconditioner. Set to
37    /// `1e-12` for the matrix-free row-Hessian path; the row-primary
38    /// Hessian's diagonal is positive-definite by construction.
39    pub precond_diag_floor: f64,
40}
41
/// Output of [`run_pcg_against_row_hessian_device`].
///
/// A solve that exhausts `max_iters` without reaching `rel_tol` still
/// produces an output. Compare `final_rel_residual` against the requested
/// tolerance to tell a converged solve from a capped one.
#[cfg(target_os = "linux")]
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceResidentPcgOutput {
    /// Final PCG iterate `x` with `H · x ≈ b`, of length `storage.block.p_total`.
    pub x: Vec<f64>,
    /// Number of PCG iterations performed (one Hessian-vector product each).
    ///
    /// The iteration in which the residual test passes is counted. The value
    /// is `0` only for a zero right-hand side.
    pub iterations: usize,
    /// Relative residual `‖r‖₂ / ‖b‖₂` after the last iteration.
    ///
    /// `r` is the recursively updated residual (`r ← r − α·H·p`), not a
    /// freshly recomputed `b − H·x`.
    pub final_rel_residual: f64,
}
53
/// NVRTC source for the Phase-5 device-resident PCG support kernels.
///
/// Every kernel operates on length-`n` `double` device vectors. `n` is the
/// solve's `p_total`, typically 44–256.
///
/// * `pcg_axpy`, `pcg_axpby`, `pcg_apply_diag_precond`, `pcg_init_zero` and
///   `pcg_copy` are element-wise. Each thread handles one element, guarded by
///   `i < n`, so any grid covering at least `n` threads is valid. The solver
///   launches them over several 64-thread blocks.
/// * `pcg_dot_single_block` is a one-block reduction that writes its result
///   to `out[0]`. Each thread first accumulates a strided partial sum, then a
///   shared-memory tree halves the stride. It therefore must be launched
///   with exactly one block whose `blockDim.x` is a power of two no larger
///   than 1024, the size of its `__shared__` scratch array. The in-source
///   comment also states `n <= 1024`; the solver enforces that as its
///   `p_total` cap.
/// * NOTE(review): `pcg_init_zero` is not loaded by `pcg_device::run`.
///   Confirm whether it is still needed.
#[cfg(target_os = "linux")]
const PCG_KERNEL_SOURCE: &str = r#"
// y[i] += a * x[i]
extern "C" __global__ void pcg_axpy(int n, double a,
                                    const double * __restrict__ x,
                                    double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] += a * x[i];
}

// y[i] = a * x[i] + b * y[i]
extern "C" __global__ void pcg_axpby(int n, double a,
                                     const double * __restrict__ x,
                                     double b,
                                     double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = a * x[i] + b * y[i];
}

// z[i] = r[i] / clamp(diag[i], floor) (sign-preserving floor on |diag|).
extern "C" __global__ void pcg_apply_diag_precond(int n, double floor_val,
                                                  const double * __restrict__ diag,
                                                  const double * __restrict__ r,
                                                  double * __restrict__ z)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        double d = diag[i];
        double ad = d < 0 ? -d : d;
        double clamped = ad > floor_val ? d : (d >= 0.0 ? floor_val : -floor_val);
        z[i] = r[i] / clamped;
    }
}

// Single-block dot product; writes the scalar to out[0]. n must be <= 1024.
extern "C" __global__ void pcg_dot_single_block(int n,
                                                const double * __restrict__ a,
                                                const double * __restrict__ b,
                                                double * __restrict__ out)
{
    __shared__ double s[1024];
    int tid = threadIdx.x;
    double acc = 0.0;
    for (int i = tid; i < n; i += blockDim.x) acc += a[i] * b[i];
    s[tid] = acc;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) s[tid] += s[tid + stride];
        __syncthreads();
    }
    if (tid == 0) out[0] = s[0];
}

// Set out[i] = 0 for i in [0, n).
extern "C" __global__ void pcg_init_zero(int n, double * __restrict__ out) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 0.0;
}

// Copy y[i] = x[i].
extern "C" __global__ void pcg_copy(int n,
                                    const double * __restrict__ x,
                                    double * __restrict__ y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i];
}
"#;
127
128#[cfg(target_os = "linux")]
129mod pcg_device {
130    use super::DeviceResidentPcgInput;
131    use super::DeviceResidentPcgOutput;
132    use super::PCG_KERNEL_SOURCE;
133    use crate::bms::gpu::row::launch_bms_flex_row_diagonal;
134    use crate::bms::gpu::row::launch_bms_flex_row_hvp_into_device;
135    use cudarc::driver::{CudaModule, CudaStream, LaunchConfig, PushKernelArg};
136    use std::sync::{Arc, OnceLock};
137
138    struct PcgBackend {
139        stream: Arc<CudaStream>,
140        module: Arc<CudaModule>,
141    }
142
143    impl PcgBackend {
144        fn probe() -> Result<&'static Self, String> {
145            static BACKEND: OnceLock<Result<PcgBackend, String>> = OnceLock::new();
146            BACKEND
147                .get_or_init(|| {
148                    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
149                        .ok_or_else(|| "pcg backend: no CUDA runtime available".to_string())?;
150                    let ctx = gam_gpu::device_runtime::cuda_context_for(
151                        runtime.selected_device().ordinal,
152                    )
153                    .ok_or_else(|| {
154                        format!(
155                            "pcg backend: failed to create CUDA context for device {}",
156                            runtime.selected_device().ordinal
157                        )
158                    })?;
159                    let stream = ctx.default_stream();
160                    let ptx = cudarc::nvrtc::compile_ptx(PCG_KERNEL_SOURCE)
161                        .map_err(|err| format!("pcg NVRTC compile failed: {err}"))?;
162                    let module = ctx
163                        .load_module(ptx)
164                        .map_err(|err| format!("pcg module load failed: {err}"))?;
165                    Ok(PcgBackend { stream, module })
166                })
167                .as_ref()
168                .map_err(String::clone)
169        }
170    }
171
172    fn launch_blocks(p: usize, threads: u32) -> u32 {
173        ((p as u32) + threads - 1) / threads
174    }
175
176    /// PCG against the row-Hessian operator with Jacobi preconditioner from
177    /// `diag(H)`. All vectors remain on the device for the duration of the
178    /// loop; only the squared residual norm crosses the host boundary each
179    /// iter (one f64, ≤ 8 bytes).
180    pub(super) fn run(
181        input: DeviceResidentPcgInput<'_>,
182    ) -> Result<DeviceResidentPcgOutput, String> {
183        let p = input.storage.block.p_total;
184        if input.b.len() != p {
185            return Err(format!(
186                "device-resident pcg: b.len()={} != p_total={p}",
187                input.b.len()
188            ));
189        }
190        if !input.rel_tol.is_finite() || input.rel_tol <= 0.0 {
191            return Err(format!(
192                "device-resident pcg: rel_tol must be positive and finite (got {})",
193                input.rel_tol
194            ));
195        }
196        if input.max_iters == 0 {
197            return Err("device-resident pcg: max_iters must be >= 1".to_string());
198        }
199        if !input.precond_diag_floor.is_finite() || input.precond_diag_floor <= 0.0 {
200            return Err(format!(
201                "device-resident pcg: precond_diag_floor must be positive and finite (got {})",
202                input.precond_diag_floor
203            ));
204        }
205
206        let backend = PcgBackend::probe()?;
207        let stream = backend.stream.clone();
208        let module = backend.module.clone();
209
210        // ── Load kernel handles once ─────────────────────────────────────
211        let f_axpy = module
212            .load_function("pcg_axpy")
213            .map_err(|e| format!("pcg load pcg_axpy: {e}"))?;
214        let f_axpby = module
215            .load_function("pcg_axpby")
216            .map_err(|e| format!("pcg load pcg_axpby: {e}"))?;
217        let f_precond = module
218            .load_function("pcg_apply_diag_precond")
219            .map_err(|e| format!("pcg load pcg_apply_diag_precond: {e}"))?;
220        let f_dot = module
221            .load_function("pcg_dot_single_block")
222            .map_err(|e| format!("pcg load pcg_dot_single_block: {e}"))?;
223        let f_copy = module
224            .load_function("pcg_copy")
225            .map_err(|e| format!("pcg load pcg_copy: {e}"))?;
226
227        // ── Allocate device vectors x, r, z, p_vec, q (length p each) ──
228        let mut d_x = stream
229            .alloc_zeros::<f64>(p)
230            .map_err(|e| format!("pcg alloc x: {e}"))?;
231        let mut d_r = stream
232            .clone_htod(input.b)
233            .map_err(|e| format!("pcg upload b -> r: {e}"))?;
234        let mut d_z = stream
235            .alloc_zeros::<f64>(p)
236            .map_err(|e| format!("pcg alloc z: {e}"))?;
237        let mut d_p = stream
238            .alloc_zeros::<f64>(p)
239            .map_err(|e| format!("pcg alloc p: {e}"))?;
240        let mut d_q = stream
241            .alloc_zeros::<f64>(p)
242            .map_err(|e| format!("pcg alloc q: {e}"))?;
243        // One-element scalar buffer reused across iters for `p·q` and
244        // `r·z` dot products.
245        let mut d_scalar = stream
246            .alloc_zeros::<f64>(1)
247            .map_err(|e| format!("pcg alloc scalar: {e}"))?;
248
249        // Preconditioner: M⁻¹ from diag(H). One HostVec download per
250        // *outer* call, but this is constant work per solve — not per
251        // iter — so it does not block the inner loop's no-sync property.
252        let diag_host = launch_bms_flex_row_diagonal(input.storage)
253            .map_err(|e| format!("pcg diag fetch: {e}"))?;
254        if diag_host.len() != p {
255            return Err(format!(
256                "pcg: diag length {} != p_total {p}",
257                diag_host.len()
258            ));
259        }
260        let d_diag = stream
261            .clone_htod(&diag_host)
262            .map_err(|e| format!("pcg upload diag: {e}"))?;
263
264        // ── Convergence baseline: ‖b‖₂ via one in-stream dot ─────────────
265        let n_i32 = i32::try_from(p).map_err(|_| format!("pcg: p_total={p} exceeds i32 range"))?;
266        let vec_threads: u32 = 64;
267        let vec_blocks = launch_blocks(p, vec_threads);
268        let dot_threads: u32 = match p {
269            0..=64 => 64,
270            65..=128 => 128,
271            129..=256 => 256,
272            257..=512 => 512,
273            _ => 1024,
274        };
275        if p > 1024 {
276            return Err(format!(
277                "device-resident pcg: p_total={p} exceeds single-block dot capacity (1024); \
278                 widen pcg_dot_single_block to multi-block reduce before raising the cap"
279            ));
280        }
281
282        // ‖b‖₂² = b · b (b is currently in d_r since r₀ = b - H·0 = b)
283        // SAFETY: `f_dot` is the `pcg_dot_single_block` device function loaded
284        // above; its signature is `(i32, *const f64, *const f64, *mut f64)`.
285        // `n_i32` was bounded against `1024` (kernel's max-n contract) two
286        // lines up; `d_r` is a `CudaSlice<f64>` of length `n` allocated to the
287        // same stream; `d_scalar` is the length-1 output slice. Single-block
288        // grid (1×dot_threads) matches the kernel's reduction strategy.
289        unsafe {
290            stream
291                .launch_builder(&f_dot)
292                .arg(&n_i32)
293                .arg(&d_r)
294                .arg(&d_r)
295                .arg(&mut d_scalar)
296                .launch(LaunchConfig {
297                    grid_dim: (1, 1, 1),
298                    block_dim: (dot_threads, 1, 1),
299                    shared_mem_bytes: 0,
300                })
301        }
302        .map_err(|e| format!("pcg b·b launch: {e}"))?;
303        stream
304            .synchronize()
305            .map_err(|e| format!("pcg b·b sync: {e}"))?;
306        let host_scalar = stream
307            .clone_dtoh(&d_scalar)
308            .map_err(|e| format!("pcg b·b download: {e}"))?;
309        let bb = host_scalar[0];
310        if !bb.is_finite() {
311            return Err(format!("pcg: b·b not finite ({bb})"));
312        }
313        let b_norm = bb.sqrt();
314        if b_norm == 0.0 {
315            // x = 0, r = b = 0, trivially converged.
316            return Ok(DeviceResidentPcgOutput {
317                x: vec![0.0; p],
318                iterations: 0,
319                final_rel_residual: 0.0,
320            });
321        }
322
323        // z₀ = M⁻¹ r₀
324        // SAFETY: `f_precond` is `pcg_jacobi_precond` with signature
325        // `(i32, f64, *const f64, *const f64, *mut f64)`. `d_diag`, `d_r`,
326        // `d_z` are all `CudaSlice<f64>` of length `n` on the same stream;
327        // `vec_blocks × vec_threads ≥ n` covers every output element.
328        unsafe {
329            stream
330                .launch_builder(&f_precond)
331                .arg(&n_i32)
332                .arg(&input.precond_diag_floor)
333                .arg(&d_diag)
334                .arg(&d_r)
335                .arg(&mut d_z)
336                .launch(LaunchConfig {
337                    grid_dim: (vec_blocks, 1, 1),
338                    block_dim: (vec_threads, 1, 1),
339                    shared_mem_bytes: 0,
340                })
341        }
342        .map_err(|e| format!("pcg precond z₀: {e}"))?;
343
344        // p₀ = z₀
345        // SAFETY: `f_copy` is `pcg_copy` with signature
346        // `(i32, *const f64, *mut f64)`. `d_z` and `d_p` are
347        // `CudaSlice<f64>` of length `n` on the same stream;
348        // `vec_blocks × vec_threads ≥ n` covers every element.
349        unsafe {
350            stream
351                .launch_builder(&f_copy)
352                .arg(&n_i32)
353                .arg(&d_z)
354                .arg(&mut d_p)
355                .launch(LaunchConfig {
356                    grid_dim: (vec_blocks, 1, 1),
357                    block_dim: (vec_threads, 1, 1),
358                    shared_mem_bytes: 0,
359                })
360        }
361        .map_err(|e| format!("pcg copy p₀: {e}"))?;
362
363        // ρ₀ = r₀·z₀
364        // SAFETY: same invariants as the ‖b‖₂² launch above — `f_dot`
365        // signature `(i32, *const f64, *const f64, *mut f64)`, `d_r` and
366        // `d_z` are length-`n` `CudaSlice<f64>`, `d_scalar` is length-1,
367        // single-block grid matches kernel's reduction.
368        unsafe {
369            stream
370                .launch_builder(&f_dot)
371                .arg(&n_i32)
372                .arg(&d_r)
373                .arg(&d_z)
374                .arg(&mut d_scalar)
375                .launch(LaunchConfig {
376                    grid_dim: (1, 1, 1),
377                    block_dim: (dot_threads, 1, 1),
378                    shared_mem_bytes: 0,
379                })
380        }
381        .map_err(|e| format!("pcg ρ₀ launch: {e}"))?;
382        stream
383            .synchronize()
384            .map_err(|e| format!("pcg ρ₀ sync: {e}"))?;
385        let s = stream
386            .clone_dtoh(&d_scalar)
387            .map_err(|e| format!("pcg ρ₀ download: {e}"))?;
388        let mut rho = s[0];
389        if !rho.is_finite() {
390            return Err(format!("pcg: ρ₀ not finite ({rho})"));
391        }
392
393        let mut iters_taken: usize = 0;
394        let mut final_rel_residual: f64 = (bb.sqrt() / b_norm).max(0.0);
395        for iter in 0..input.max_iters {
396            iters_taken = iter + 1;
397
398            // q = H · p (on device, no sync, no DtoH).
399            launch_bms_flex_row_hvp_into_device(input.storage, &d_p, &mut d_q)
400                .map_err(|e| format!("pcg Hv iter {iter}: {e}"))?;
401
402            // pq = p·q
403            // SAFETY: identical to ‖b‖₂² launch — `f_dot` signature
404            // `(i32, *const f64, *const f64, *mut f64)`; `d_p` is the
405            // current search direction and `d_q` was just populated by
406            // `launch_bms_flex_row_hvp_into_device` (same stream, same `n`).
407            unsafe {
408                stream
409                    .launch_builder(&f_dot)
410                    .arg(&n_i32)
411                    .arg(&d_p)
412                    .arg(&d_q)
413                    .arg(&mut d_scalar)
414                    .launch(LaunchConfig {
415                        grid_dim: (1, 1, 1),
416                        block_dim: (dot_threads, 1, 1),
417                        shared_mem_bytes: 0,
418                    })
419            }
420            .map_err(|e| format!("pcg p·q launch iter {iter}: {e}"))?;
421            stream
422                .synchronize()
423                .map_err(|e| format!("pcg p·q sync iter {iter}: {e}"))?;
424            let s = stream
425                .clone_dtoh(&d_scalar)
426                .map_err(|e| format!("pcg p·q download iter {iter}: {e}"))?;
427            let pq = s[0];
428            if !pq.is_finite() || pq == 0.0 {
429                return Err(format!(
430                    "pcg iter {iter}: p·q={pq} (non-finite or zero); operator is not positive-definite"
431                ));
432            }
433            let alpha = rho / pq;
434
435            // x += α p
436            // SAFETY: `f_axpy` is `pcg_axpy` with signature
437            // `(i32, f64, *const f64, *mut f64)`. `alpha` is the
438            // finite-checked CG step length (`rho/pq`, both validated
439            // above). `d_p` and `d_x` are length-`n` `CudaSlice<f64>` on
440            // the same stream. Grid covers all `n` elements.
441            unsafe {
442                stream
443                    .launch_builder(&f_axpy)
444                    .arg(&n_i32)
445                    .arg(&alpha)
446                    .arg(&d_p)
447                    .arg(&mut d_x)
448                    .launch(LaunchConfig {
449                        grid_dim: (vec_blocks, 1, 1),
450                        block_dim: (vec_threads, 1, 1),
451                        shared_mem_bytes: 0,
452                    })
453            }
454            .map_err(|e| format!("pcg x+=αp iter {iter}: {e}"))?;
455
456            // r -= α q
457            let neg_alpha = -alpha;
458            // SAFETY: same `f_axpy` invariants as the `x += α p` launch
459            // above; `neg_alpha = -alpha` is finite (alpha was checked),
460            // `d_q` and `d_r` are length-`n` `CudaSlice<f64>` on the same
461            // stream.
462            unsafe {
463                stream
464                    .launch_builder(&f_axpy)
465                    .arg(&n_i32)
466                    .arg(&neg_alpha)
467                    .arg(&d_q)
468                    .arg(&mut d_r)
469                    .launch(LaunchConfig {
470                        grid_dim: (vec_blocks, 1, 1),
471                        block_dim: (vec_threads, 1, 1),
472                        shared_mem_bytes: 0,
473                    })
474            }
475            .map_err(|e| format!("pcg r-=αq iter {iter}: {e}"))?;
476
477            // ‖r‖₂² = r·r (single device dot, single f64 DtoH)
478            // SAFETY: identical to the ‖b‖₂² launch at function entry —
479            // `f_dot` signature, `d_r` length-`n`, `d_scalar` length-1,
480            // single-block reduction grid.
481            unsafe {
482                stream
483                    .launch_builder(&f_dot)
484                    .arg(&n_i32)
485                    .arg(&d_r)
486                    .arg(&d_r)
487                    .arg(&mut d_scalar)
488                    .launch(LaunchConfig {
489                        grid_dim: (1, 1, 1),
490                        block_dim: (dot_threads, 1, 1),
491                        shared_mem_bytes: 0,
492                    })
493            }
494            .map_err(|e| format!("pcg ‖r‖₂² launch iter {iter}: {e}"))?;
495            stream
496                .synchronize()
497                .map_err(|e| format!("pcg ‖r‖₂² sync iter {iter}: {e}"))?;
498            let s = stream
499                .clone_dtoh(&d_scalar)
500                .map_err(|e| format!("pcg ‖r‖₂² download iter {iter}: {e}"))?;
501            let rr = s[0];
502            if !rr.is_finite() {
503                return Err(format!("pcg iter {iter}: ‖r‖₂²={rr} non-finite"));
504            }
505            let rel = rr.sqrt() / b_norm;
506            final_rel_residual = rel;
507            if rel <= input.rel_tol {
508                break;
509            }
510
511            // z = M⁻¹ r
512            // SAFETY: same `f_precond` invariants as the `z₀ = M⁻¹ r₀`
513            // launch above — signature `(i32, f64, *const f64, *const f64,
514            // *mut f64)`, all four slices length-`n` `CudaSlice<f64>`, grid
515            // covers all `n` elements.
516            unsafe {
517                stream
518                    .launch_builder(&f_precond)
519                    .arg(&n_i32)
520                    .arg(&input.precond_diag_floor)
521                    .arg(&d_diag)
522                    .arg(&d_r)
523                    .arg(&mut d_z)
524                    .launch(LaunchConfig {
525                        grid_dim: (vec_blocks, 1, 1),
526                        block_dim: (vec_threads, 1, 1),
527                        shared_mem_bytes: 0,
528                    })
529            }
530            .map_err(|e| format!("pcg z=M⁻¹r iter {iter}: {e}"))?;
531
532            // ρ_new = r·z
533            // SAFETY: identical to the ρ₀ launch above — `f_dot`
534            // signature, `d_r` and `d_z` length-`n`, `d_scalar` length-1.
535            unsafe {
536                stream
537                    .launch_builder(&f_dot)
538                    .arg(&n_i32)
539                    .arg(&d_r)
540                    .arg(&d_z)
541                    .arg(&mut d_scalar)
542                    .launch(LaunchConfig {
543                        grid_dim: (1, 1, 1),
544                        block_dim: (dot_threads, 1, 1),
545                        shared_mem_bytes: 0,
546                    })
547            }
548            .map_err(|e| format!("pcg ρ_new launch iter {iter}: {e}"))?;
549            stream
550                .synchronize()
551                .map_err(|e| format!("pcg ρ_new sync iter {iter}: {e}"))?;
552            let s = stream
553                .clone_dtoh(&d_scalar)
554                .map_err(|e| format!("pcg ρ_new download iter {iter}: {e}"))?;
555            let rho_new = s[0];
556            if !rho_new.is_finite() {
557                return Err(format!("pcg iter {iter}: ρ_new={rho_new} non-finite"));
558            }
559            let beta_pcg = rho_new / rho;
560
561            // p = z + β p  ⇒  via pcg_axpby with a=1, b=β
562            // SAFETY: `f_axpby` is `pcg_axpby` with signature
563            // `(i32, f64, *const f64, f64, *mut f64)`. `beta_pcg = rho_new/rho`
564            // was finite-checked. `d_z` and `d_p` are length-`n`
565            // `CudaSlice<f64>` on the same stream; grid covers all `n`
566            // elements.
567            unsafe {
568                stream
569                    .launch_builder(&f_axpby)
570                    .arg(&n_i32)
571                    .arg(&1.0_f64)
572                    .arg(&d_z)
573                    .arg(&beta_pcg)
574                    .arg(&mut d_p)
575                    .launch(LaunchConfig {
576                        grid_dim: (vec_blocks, 1, 1),
577                        block_dim: (vec_threads, 1, 1),
578                        shared_mem_bytes: 0,
579                    })
580            }
581            .map_err(|e| format!("pcg p=z+βp iter {iter}: {e}"))?;
582
583            rho = rho_new;
584        }
585
586        // Download x once at the end.
587        let x_host = stream
588            .clone_dtoh(&d_x)
589            .map_err(|e| format!("pcg final x DtoH: {e}"))?;
590        // The auxiliary device allocs (d_r/d_z/d_p/d_q/d_scalar/d_diag) drop
591        // here and return their bytes to cudarc's allocator.
592        drop(d_r);
593        drop(d_z);
594        drop(d_p);
595        drop(d_q);
596        drop(d_scalar);
597        drop(d_diag);
598        Ok(DeviceResidentPcgOutput {
599            x: x_host,
600            iterations: iters_taken,
601            final_rel_residual,
602        })
603    }
604}
605
606/// Device-resident PCG against the BMS-FLEX row-Hessian operator.
607///
608/// Block 9 Phase 5: every PCG vector — `x`, `r`, `z`, `p`, `q` — stays on
609/// the device for the entire loop; only the squared residual norm (one f64)
610/// is downloaded per iteration for the convergence check. Bit-equal output
611/// to a host-side reference PCG against the same operator + preconditioner
612/// when the tolerance is tight; differences only show up at the floating-
613/// point reduction-order level.
614///
615/// Linux-only. See [`DeviceResidentPcgInput`] for parameters.
616#[cfg(target_os = "linux")]
617pub fn run_pcg_against_row_hessian_device(
618    input: DeviceResidentPcgInput<'_>,
619) -> Result<DeviceResidentPcgOutput, String> {
620    pcg_device::run(input)
621}
622
623/// Block 9 Phase 5 — V100 parity for `run_pcg_against_row_hessian_device`.
624///
625/// Builds a small `(n=64, r=20, p=44)` BMS-FLEX row-Hessian fixture, computes
626/// the dense joint Hessian via the same CPU oracle the HVP parity test uses,
627/// solves `H · x = b` on the host via dense LU as ground truth, and asserts
628/// the device-resident PCG iterate matches to a tight tolerance.
629#[cfg(all(test, target_os = "linux"))]
630mod pcg_device_parity_tests {
631    use super::*;
632    use crate::bms::gpu::row::{
633        BmsFlexBlockLayout, BmsFlexPrimaryLayout, DeviceResidentRowHess,
634    };
635    use ndarray::Array2;
636
637    /// Dense oracle for `H_full = Σ_i P_iᵀ H_i P_i` consistent with
638    /// `cpu_oracle_bms_flex_row_hvp`'s pullback math.
639    fn cpu_dense_joint_hessian(
640        row_hessians: &[f64],
641        marginal: &[f64],
642        logslope: &[f64],
643        block: &BmsFlexBlockLayout,
644        primary: &BmsFlexPrimaryLayout,
645        n: usize,
646    ) -> Array2<f64> {
647        let p_total = block.p_total;
648        let r = primary.r;
649        let p_m = block.p_m;
650        let p_g = block.p_g;
651        let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
652        let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
653        let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
654        let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
655        let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
656        let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
657        let mut h_dense = Array2::<f64>::zeros((p_total, p_total));
658        // For each row build P_i columns as length-p_total vectors.
659        let mut phi = vec![vec![0.0_f64; p_total]; r];
660        for row in 0..n {
661            for col in phi.iter_mut() {
662                col.iter_mut().for_each(|v| *v = 0.0);
663            }
664            let mrow = &marginal[row * p_m..(row + 1) * p_m];
665            let grow = &logslope[row * p_g..(row + 1) * p_g];
666            for k in 0..p_m {
667                phi[0][k] = mrow[k];
668            }
669            for k in 0..p_g {
670                phi[1][p_m + k] = grow[k];
671            }
672            for k in 0..h_block_len {
673                phi[h_primary_start + k][h_block_start + k] = 1.0;
674            }
675            for k in 0..w_block_len {
676                phi[w_primary_start + k][w_block_start + k] = 1.0;
677            }
678            let h_row = &row_hessians[row * r * r..(row + 1) * r * r];
679            for u in 0..r {
680                for v in 0..r {
681                    let huv = h_row[u * r + v];
682                    if huv == 0.0 {
683                        continue;
684                    }
685                    for m in 0..p_total {
686                        let phim = phi[u][m];
687                        if phim == 0.0 {
688                            continue;
689                        }
690                        let scaled = huv * phim;
691                        for nn in 0..p_total {
692                            h_dense[[m, nn]] += scaled * phi[v][nn];
693                        }
694                    }
695                }
696            }
697        }
698        h_dense
699    }
700
701    /// Reference oracle: host PCG against the dense joint H + diag(H)
702    /// preconditioner, with a tolerance two decades tighter than the GPU
703    /// PCG's. Comparing GPU PCG to host PCG (rather than to a Cholesky
704    /// solve) keeps the comparison numerically apples-to-apples — only
705    /// reduction order differs between the two paths.
706    fn cpu_pcg_oracle(h: &Array2<f64>, b: &[f64], rel_tol: f64) -> Vec<f64> {
707        let p = b.len();
708        let diag: ndarray::Array1<f64> =
709            ndarray::Array1::from_vec((0..p).map(|i| h[[i, i]]).collect());
710        let rhs = ndarray::Array1::from_vec(b.to_vec());
711        let h_owned = h.clone();
712        let apply = move |v: &ndarray::Array1<f64>| h_owned.dot(v);
713        let (x, info) =
714            gam_linalg::utils::solve_spd_pcg_with_info(apply, &rhs, &diag, rel_tol, 4 * p)
715                .expect("host PCG oracle must converge on SPD fixture");
716        assert!(
717            info.converged,
718            "host PCG oracle failed to converge: iters={} rel_res={}",
719            info.iterations, info.relative_residual_norm
720        );
721        x.to_vec()
722    }
723
724    #[test]
725    fn pcg_device_matches_dense_oracle_at_n64_r20_p44() {
726        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
727            eprintln!("[pcg_device parity] no CUDA runtime — skipping");
728            return;
729        };
730        let n = 64_usize;
731        let p_m = 14_usize;
732        let p_g = 12_usize;
733        let p_h_dim = 10_usize;
734        let p_w_dim = 8_usize;
735        let r = 2 + p_h_dim + p_w_dim;
736        let p_total = p_m + p_g + p_h_dim + p_w_dim;
737        let block = BmsFlexBlockLayout {
738            p_m,
739            p_g,
740            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
741            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
742            p_total,
743        };
744        let primary = BmsFlexPrimaryLayout {
745            h: Some(2..2 + p_h_dim),
746            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
747            r,
748        };
749
750        // Same deterministic symmetric Hessians + designs as the HVP parity
751        // gate, so any drift between Phase 4 and Phase 5 surfaces here too.
752        let mut row_hessians = vec![0.0_f64; n * r * r];
753        for row in 0..n {
754            let base = row * r * r;
755            for u in 0..r {
756                for v in 0..r {
757                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
758                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
759                    row_hessians[base + u * r + v] = a;
760                }
761            }
762            for u in 0..r {
763                for v in (u + 1)..r {
764                    let upper = row_hessians[base + u * r + v];
765                    let lower = row_hessians[base + v * r + u];
766                    let sym = 0.5 * (upper + lower);
767                    row_hessians[base + u * r + v] = sym;
768                    row_hessians[base + v * r + u] = sym;
769                }
770                // Boost the diagonal heavily so each H_i is positive
771                // definite — guarantees the joint pulled-back Hessian is
772                // SPD, which PCG requires.
773                row_hessians[base + u * r + u] += 4.0 * (r as f64);
774            }
775        }
776        let mut marginal = vec![0.0_f64; n * p_m];
777        for row in 0..n {
778            for j in 0..p_m {
779                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
780                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
781            }
782        }
783        let mut logslope = vec![0.0_f64; n * p_g];
784        for row in 0..n {
785            for j in 0..p_g {
786                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
787                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
788            }
789        }
790
791        // Pick a non-trivial RHS.
792        let b: Vec<f64> = (0..p_total)
793            .map(|i| {
794                let seed = (i as f64) * 0.157 + 0.6;
795                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
796            })
797            .collect();
798
799        let h_dense =
800            cpu_dense_joint_hessian(&row_hessians, &marginal, &logslope, &block, &primary, n);
801        let x_oracle = cpu_pcg_oracle(&h_dense, &b, 1e-12);
802
803        // Grab the same CUDA context + default stream that the bms_flex_row
804        // kernels will use when `run_pcg_against_row_hessian_device` probes
805        // its own backend. Going through the public runtime APIs keeps the
806        // test independent of any private kernel-backend symbols.
807        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
808            .expect("runtime must exist when probe succeeded above");
809        // Past the GpuRuntime::global() Some-gate above: a context-creation or
810        // HtoD-upload failure here is a real device fault on a CUDA host, not a
811        // no-CUDA skip — fail loud (device-PCG skip-pass class, eee12f6b2). The old
812        // arms returned, so a context/upload fault on a GPU host passed silently.
813        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
814            .expect("[pcg_device parity] cuda_context_for must succeed on a CUDA host");
815        let stream = ctx.default_stream();
816        let d_h = stream
817            .clone_htod(&row_hessians)
818            .expect("[pcg_device parity] upload h must succeed on a CUDA host");
819        let d_m = stream
820            .clone_htod(&marginal)
821            .expect("[pcg_device parity] upload marginal must succeed on a CUDA host");
822        let d_g = stream
823            .clone_htod(&logslope)
824            .expect("[pcg_device parity] upload logslope must succeed on a CUDA host");
825        let storage = DeviceResidentRowHess {
826            hess: d_h,
827            marginal_design: d_m,
828            logslope_design: d_g,
829            n,
830            r,
831            block,
832            primary,
833
834            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
835        };
836
837        let out = run_pcg_against_row_hessian_device(DeviceResidentPcgInput {
838            storage: &storage,
839            b: &b,
840            rel_tol: 1e-10,
841            max_iters: 4 * p_total,
842            precond_diag_floor: 1e-12,
843        })
844        .expect("device-resident PCG must succeed on SPD fixture");
845
846        assert_eq!(out.x.len(), p_total);
847        let mut max_abs = 0.0_f64;
848        for i in 0..p_total {
849            let diff = (out.x[i] - x_oracle[i]).abs();
850            if diff > max_abs {
851                max_abs = diff;
852            }
853        }
854        // Each iteration introduces O(1) ULPs of round-off in the dot/
855        // axpy ladder; with ~88 iters max at p=44 we expect ‖Δx‖∞ comfortably
856        // below 1e-7. Anything larger means a code bug, not float noise.
857        assert!(
858            max_abs <= 1e-7,
859            "pcg_device parity ‖Δx‖∞={max_abs:.3e} > 1e-7 after {} iters \
860             (final rel residual={:.3e})",
861            out.iterations,
862            out.final_rel_residual
863        );
864        eprintln!(
865            "[pcg_device parity] n={n} p={p_total} r={r}: iters={} rel_res={:.3e} ‖Δx‖∞={:.3e}",
866            out.iterations, out.final_rel_residual, max_abs
867        );
868    }
869}