Skip to main content

gam_solve/gpu_kernels/
sae_resident.rs

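//! Device-resident SAE inner-iteration workspace for issue #1017.
//!
//! This first vertical slice keeps production fitting untouched. It accepts
//! host-evaluated SAE basis/gate values plus already-assembled data-fit
//! Arrow-Schur slabs and uploads those buffers once when the runtime probe
//! admits the workload. On top of that workspace it offers a single Newton
//! step through the existing GPU Arrow-Schur sequence
//! (`DeviceResidentArrowWorkspace::one_inner_iteration`) and a full LM
//! trust-region inner Newton loop (`device_fit`). It also provides a
//! re-uploading baseline (`device_reupload_fit`) and dense CPU reference paths
//! for parity checks. Later slices can replace the host slab feed with
//! on-device basis/gate evaluation without changing the public step API.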
9
10use ndarray::Array1;
11
12use crate::gpu_kernels::arrow_schur::{
13    ArrowSchurGpuFailure, solve_arrow_newton_step, solve_arrow_newton_step_dense_reference,
14};
15use gam_problem::ExecutionPath;
16
/// Per-iterate solve backend for the resident inner Newton loop.
///
/// All three modes run the IDENTICAL host control flow (`run_inner_loop`):
/// residual-gradient assembly, LM trust-region accept/reject, ridge schedule.
/// They differ ONLY in how the per-iterate arrow step is computed, which is
/// exactly the residency lever #1017 measures:
///
/// * [`InnerSolveMode::DeviceResident`] — the Phase-3 fix. Factor the constant
///   Hessian blocks into a
///   [`crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle`] once per LM
///   ridge value. The frame bakes the ridge into its factors, so it is rebuilt
///   only when the ridge changes. Every iterate uploads only the `O(n·d + p)`
///   gradient and reads back only `δ`. While the ridge is unchanged there is no
///   per-solve `D`/`B` re-upload and no per-solve POTRF.
/// * [`InnerSolveMode::DeviceReupload`] — the BEFORE path. Call
///   `solve_arrow_newton_step` per iterate, which re-packs and re-uploads
///   `D`/`B`/`g` and re-runs the per-row POTRF + border Schur factor on every
///   call. This is the residency baseline the bench divides against.
/// * [`InnerSolveMode::CpuReference`] — the dense f64 oracle, which re-factors
///   per iterate on the host. Used for the correctness parity check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InnerSolveMode {
    /// Resident frame reused across iterations at an unchanged LM ridge.
    DeviceResident,
    /// Full re-upload and re-factor on every iterate (residency baseline).
    DeviceReupload,
    /// Host dense-reference solve (parity oracle).
    CpuReference,
}
40
41impl InnerSolveMode {
42    /// Truthful [`ExecutionPath`] this solve mode realizes (issue #1017): the
43    /// resident loop keeps factors on-device (`GpuResidentFull`), the baseline
44    /// re-uploads/re-factors every iterate (`GpuReupload`), and the reference
45    /// path runs on the host (`Cpu`).
46    #[inline]
47    const fn execution_path(self) -> ExecutionPath {
48        match self {
49            Self::DeviceResident => ExecutionPath::GpuResidentFull,
50            Self::DeviceReupload => ExecutionPath::GpuReupload,
51            Self::CpuReference => ExecutionPath::Cpu,
52        }
53    }
54}
55use crate::arrow_schur::{ArrowSchurError, ArrowSchurSystem};
56
/// Problem dimensions for the resident SAE inner-iteration workspace.
///
/// In this slice `p` is both the target width and the shared Schur-border
/// width. The real SAE decoder has a richer `(basis × output)` structure.
/// Slice 1 deliberately keeps that structure host-assembled into
/// `row_cross_slabs` while still carrying the qwen-scale target width in the
/// border.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceResidentArrowShape {
    /// Number of rows (one `d × d` Hessian block per row).
    pub n: usize,
    /// Target width, equal to the shared border width (`p × p` border block).
    pub p: usize,
    /// Basis columns evaluated per row (`n · basis_cols` basis values).
    pub basis_cols: usize,
    /// Per-row latent block size.
    pub d: usize,
}
70
71impl DeviceResidentArrowShape {
72    #[inline]
73    pub const fn qwen_non_gating() -> Self {
74        Self {
75            n: 2_000,
76            p: 2_048,
77            basis_cols: 8,
78            d: 2,
79        }
80    }
81
82    /// Color-arm shape from the #1017 measured gap (n=180, p=5120, M≈9, K=1):
83    /// few rows, very wide border. The dense-Schur device path (cuSOLVER border
84    /// POTRF) handles the `p=5120` border that exceeds the fused-kernel `P_MAX`.
85    #[inline]
86    pub const fn color_arm() -> Self {
87        Self {
88            n: 180,
89            p: 5_120,
90            basis_cols: 9,
91            d: 2,
92        }
93    }
94
95    #[inline]
96    pub const fn target_len(self) -> usize {
97        self.n * self.p
98    }
99
100    #[inline]
101    pub const fn basis_len(self) -> usize {
102        self.n * self.basis_cols
103    }
104
105    #[inline]
106    pub const fn row_hessian_len(self) -> usize {
107        self.n * self.d * self.d
108    }
109
110    #[inline]
111    pub const fn row_cross_len(self) -> usize {
112        self.n * self.d * self.p
113    }
114
115    #[inline]
116    pub const fn row_gradient_len(self) -> usize {
117        self.n * self.d
118    }
119
120    #[inline]
121    pub const fn border_hessian_len(self) -> usize {
122        self.p * self.p
123    }
124}
125
/// Host-fed row-block slabs for the first resident slice.
///
/// All matrices are row-major in host memory. For a shape with `n` rows,
/// per-row block size `d` and border width `p` (see
/// [`DeviceResidentArrowShape`]):
/// * `row_hessian_slabs`: `n` slabs of shape `d × d` (per-row `H_tt`).
/// * `row_cross_slabs`: `n` slabs of shape `d × p` (per-row `H_tβ`).
/// * `row_gradient_slabs`: `n` slabs of length `d` (per-row `g_t`).
/// * `border_hessian`: one `p × p` shared block (`H_ββ`).
/// * `border_gradient`: one length-`p` border gradient (`g_β`).
#[derive(Clone, Debug)]
pub struct DeviceResidentArrowSlabs {
    /// `n` concatenated row-major `d × d` row Hessian blocks.
    pub row_hessian_slabs: Vec<f64>,
    /// `n` concatenated row-major `d × p` row/border cross blocks.
    pub row_cross_slabs: Vec<f64>,
    /// `n` concatenated length-`d` row gradients.
    pub row_gradient_slabs: Vec<f64>,
    /// Row-major `p × p` shared border Hessian.
    pub border_hessian: Vec<f64>,
    /// Length-`p` shared border gradient.
    pub border_gradient: Vec<f64>,
}
140
141/// Result of one resident SAE inner Newton iteration.
142#[derive(Clone, Debug)]
143pub struct DeviceResidentArrowStep {
144    pub delta_t: Array1<f64>,
145    pub delta_beta: Array1<f64>,
146    pub objective: f64,
147    pub gradient_norm: f64,
148    pub log_det_hessian: f64,
149    pub execution_path: ExecutionPath,
150}
151
/// Failure modes of the device-resident SAE Arrow-Schur workspace.
///
/// Every variant carries a fully formatted, human-readable `reason`; the
/// `Display` impl prints it verbatim.
#[derive(Clone, Debug)]
pub enum DeviceResidentArrowError {
    /// Inputs inconsistent with the declared shape.
    Shape { reason: String },
    /// The device path was not available for this workload.
    Unavailable { reason: String },
    /// A factorization or solve step failed.
    Solve { reason: String },
}

impl std::fmt::Display for DeviceResidentArrowError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Shape { reason } => reason,
            Self::Unavailable { reason } => reason,
            Self::Solve { reason } => reason,
        };
        f.write_str(message)
    }
}

// No underlying cause is chained, so the default `source()` (None) applies.
impl std::error::Error for DeviceResidentArrowError {}
170
171#[cfg(target_os = "linux")]
172pub struct DeviceResidentArrowBuffers {
173    pub stream: std::sync::Arc<cudarc::driver::CudaStream>,
174    pub target_x_dev: cudarc::driver::CudaSlice<f64>,
175    pub basis_values_dev: cudarc::driver::CudaSlice<f64>,
176    pub gate_activations_dev: cudarc::driver::CudaSlice<f64>,
177    pub row_hessian_dev: cudarc::driver::CudaSlice<f64>,
178    pub row_cross_dev: cudarc::driver::CudaSlice<f64>,
179    pub row_gradient_dev: cudarc::driver::CudaSlice<f64>,
180    pub border_hessian_dev: cudarc::driver::CudaSlice<f64>,
181    pub border_gradient_dev: cudarc::driver::CudaSlice<f64>,
182    pub bytes: usize,
183}
184
185/// Upload-once workspace for the SAE data-fit Arrow-Schur inner iteration.
186pub struct DeviceResidentArrowWorkspace {
187    shape: DeviceResidentArrowShape,
188    target_x: Vec<f64>,
189    basis_values: Vec<f64>,
190    gate_activations: Vec<f64>,
191    slabs: DeviceResidentArrowSlabs,
192    #[cfg(target_os = "linux")]
193    device: Option<DeviceResidentArrowBuffers>,
194}
195
196impl DeviceResidentArrowWorkspace {
197    pub fn new(
198        shape: DeviceResidentArrowShape,
199        target_x: Vec<f64>,
200        basis_values: Vec<f64>,
201        gate_activations: Vec<f64>,
202        slabs: DeviceResidentArrowSlabs,
203    ) -> Result<Self, DeviceResidentArrowError> {
204        validate_shape(shape, &target_x, &basis_values, &gate_activations, &slabs)?;
205        #[cfg(target_os = "linux")]
206        let device =
207            upload_resident_buffers(shape, &target_x, &basis_values, &gate_activations, &slabs);
208        Ok(Self {
209            shape,
210            target_x,
211            basis_values,
212            gate_activations,
213            slabs,
214            #[cfg(target_os = "linux")]
215            device,
216        })
217    }
218
219    #[inline]
220    pub const fn shape(&self) -> DeviceResidentArrowShape {
221        self.shape
222    }
223
224    #[must_use]
225    pub fn device_resident(&self) -> bool {
226        #[cfg(target_os = "linux")]
227        {
228            self.device.is_some()
229        }
230        #[cfg(not(target_os = "linux"))]
231        {
232            false
233        }
234    }
235
236    #[must_use]
237    pub fn resident_device_bytes(&self) -> usize {
238        #[cfg(target_os = "linux")]
239        {
240            self.device.as_ref().map_or(0, |device| device.bytes)
241        }
242        #[cfg(not(target_os = "linux"))]
243        {
244            0
245        }
246    }
247
248    /// Opaque device-context identifier for telemetry: `1` when the resident
249    /// device buffers are live on this workspace, `0` when no device was bound.
250    /// Distinguishes "a device executed this fit" from "silent CPU fallback"
251    /// without leaking the cudarc handle.
252    #[must_use]
253    fn context_id(&self) -> usize {
254        usize::from(self.device_resident())
255    }
256
257    /// Bytes the re-uploading / frame-build path moves host→device for a full
258    /// `D`/`B`/`g`/border refresh, used to attribute H2D traffic in telemetry.
259    #[must_use]
260    fn frame_upload_bytes(&self) -> usize {
261        [
262            self.slabs.row_hessian_slabs.len(),
263            self.slabs.row_cross_slabs.len(),
264            self.slabs.row_gradient_slabs.len(),
265            self.slabs.border_hessian.len(),
266            self.slabs.border_gradient.len(),
267        ]
268        .into_iter()
269        .sum::<usize>()
270            * std::mem::size_of::<f64>()
271    }
272
273    #[must_use]
274    pub fn host_shadow_bytes(&self) -> usize {
275        [
276            self.target_x.len(),
277            self.basis_values.len(),
278            self.gate_activations.len(),
279            self.slabs.row_hessian_slabs.len(),
280            self.slabs.row_cross_slabs.len(),
281            self.slabs.row_gradient_slabs.len(),
282            self.slabs.border_hessian.len(),
283            self.slabs.border_gradient.len(),
284        ]
285        .into_iter()
286        .sum::<usize>()
287            * std::mem::size_of::<f64>()
288    }
289
290    /// Run one device-side Newton sequence. No CPU fallback is attempted here:
291    /// callers that want a reference path must call [`Self::cpu_reference_step`].
292    pub fn one_inner_iteration(
293        &self,
294        ridge_t: f64,
295        ridge_beta: f64,
296    ) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
297        if !self.device_resident() {
298            return Err(DeviceResidentArrowError::Unavailable {
299                reason: "SAE resident inner iteration unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
300            });
301        }
302        let sys = self.to_arrow_system();
303        solve_arrow_newton_step(&sys, ridge_t, ridge_beta)
304            .map(|solution| self.finish_step(solution, ExecutionPath::GpuResidentLinearization))
305            .map_err(map_gpu_error)
306    }
307
308    /// CPU reference for parity harnesses. This path is explicit and is never
309    /// called from [`Self::one_inner_iteration`].
310    pub fn cpu_reference_step(
311        &self,
312        ridge_t: f64,
313        ridge_beta: f64,
314    ) -> Result<DeviceResidentArrowStep, DeviceResidentArrowError> {
315        let sys = self.to_arrow_system();
316        solve_arrow_newton_step_dense_reference(&sys, ridge_t, ridge_beta)
317            .map(|solution| self.finish_step(solution, ExecutionPath::Cpu))
318            .map_err(|reason| DeviceResidentArrowError::Solve { reason })
319    }
320
321    pub fn to_arrow_system(&self) -> ArrowSchurSystem {
322        let shape = self.shape;
323        let mut sys = ArrowSchurSystem::new(shape.n, shape.d, shape.p);
324        for i in 0..shape.n {
325            let h_base = i * shape.d * shape.d;
326            let b_base = i * shape.d * shape.p;
327            let g_base = i * shape.d;
328            for r in 0..shape.d {
329                for c in 0..shape.d {
330                    sys.rows[i].htt[[r, c]] =
331                        self.slabs.row_hessian_slabs[h_base + r * shape.d + c];
332                }
333                sys.rows[i].gt[r] = self.slabs.row_gradient_slabs[g_base + r];
334                for c in 0..shape.p {
335                    sys.rows[i].htbeta[[r, c]] =
336                        self.slabs.row_cross_slabs[b_base + r * shape.p + c];
337                }
338            }
339        }
340        for r in 0..shape.p {
341            sys.gb[r] = self.slabs.border_gradient[r];
342            for c in 0..shape.p {
343                sys.hbb[[r, c]] = self.slabs.border_hessian[r * shape.p + c];
344            }
345        }
346        sys.refresh_row_hessian_fingerprint();
347        sys
348    }
349
350    fn finish_step(
351        &self,
352        solution: crate::gpu_kernels::arrow_schur::ArrowSchurGpuSolution,
353        execution_path: ExecutionPath,
354    ) -> DeviceResidentArrowStep {
355        DeviceResidentArrowStep {
356            delta_t: solution.delta_t,
357            delta_beta: solution.delta_beta,
358            objective: 0.5 * squared_norm(&self.target_x),
359            gradient_norm: self.gradient_norm(),
360            log_det_hessian: solution.log_det_hessian,
361            execution_path,
362        }
363    }
364
365    fn gradient_norm(&self) -> f64 {
366        let row = squared_norm(&self.slabs.row_gradient_slabs);
367        let border = squared_norm(&self.slabs.border_gradient);
368        (row + border).sqrt()
369    }
370
371    // ---------------------------------------------------------------------
372    // Phase 3: full device-resident inner Newton loop (#1017).
373    //
374    // The resident slabs define a fixed bordered-quadratic data-fit objective
375    //     φ(z) = ½‖X‖² + ½ zᵀ H z − g₀ᵀ z,   z = (t, β),
376    // where `H` is the arrow-structured Hessian (per-row `H_tt`/`H_tβ` blocks
377    // plus the shared `H_ββ` border) and `g₀` is the base gradient assembled
378    // once at upload. This is the quadratic the SAE joint inner Newton actually
379    // minimises at a frozen gate/basis evaluation; the production driver
380    // (`LatentInnerSolver::solve`) re-linearises per outer evaluation, so a
381    // single resident frame is one such inner solve.
382    //
    // The loop mirrors the production LM trust-region accept/reject exactly:
    // at iterate `z` it forms the residual gradient `r(z) = H z − g₀`, takes
    // the LM-damped arrow step (device or dense-reference), evaluates the trial
    // objective, and accepts on the actual-vs-predicted reduction ratio. The
    // host-side state carried across iterations is the iterate `(t, β)`, the
    // LM ridges, and the per-step scalars (objective, gradient norm, ρ). In
    // `DeviceResident` mode the heavy `O(n d³ + p³)` factor work lives in a
    // `ResidentArrowFrameHandle` that is rebuilt only when the LM ridge
    // changes. `DeviceReupload` instead re-runs it every iterate through
    // `solve_arrow_newton_step`. On this exact quadratic an undamped
    // (zero-ridge) step lands on the minimiser in one accepted step. A positive
    // ridge damps the step, so several accepted steps may be needed. Either
    // way the loop exercises the full assemble→solve→objective→accept machinery
    // and the scalar-only readback contract the production loop relies on.
393    // ---------------------------------------------------------------------
394
395    /// Run the full device-resident inner Newton loop. Routes the per-iteration
396    /// arrow solve through the GPU path; returns `Unavailable` when CUDA did not
397    /// admit the resident workload (callers wanting a CPU path use
398    /// [`Self::cpu_reference_fit`]).
399    pub fn device_fit(
400        &self,
401        opts: &DeviceResidentInnerOptions,
402    ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
403        if !self.device_resident() {
404            return Err(DeviceResidentArrowError::Unavailable {
405                reason: "SAE resident inner loop unavailable: CUDA runtime did not admit the qwen-scale row-block workload".to_string(),
406            });
407        }
408        self.run_inner_loop(opts, InnerSolveMode::DeviceResident)
409    }
410
411    /// The #1017 residency baseline: run the SAME inner Newton loop but compute
412    /// each per-iterate arrow step through `solve_arrow_newton_step`, which
413    /// re-packs/re-uploads `D`/`B`/`g` and re-runs the per-row POTRF + border
414    /// Schur factor on EVERY iterate. This is the "current re-uploading path";
415    /// the bench divides [`Self::device_fit`] (resident) against it to isolate
416    /// the across-iteration residency speedup on one device, holding the host
417    /// control flow and the GPU factor kernels fixed.
418    pub fn device_reupload_fit(
419        &self,
420        opts: &DeviceResidentInnerOptions,
421    ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
422        if !self.device_resident() {
423            return Err(DeviceResidentArrowError::Unavailable {
424                reason: "SAE re-uploading inner loop unavailable: CUDA runtime did not admit the row-block workload".to_string(),
425            });
426        }
427        self.run_inner_loop(opts, InnerSolveMode::DeviceReupload)
428    }
429
430    /// CPU dense-reference inner loop. Bit-for-bit the same host arithmetic as
431    /// [`Self::device_fit`] except the per-iteration arrow solve uses the dense
432    /// reference factorisation; the parity harness asserts the two agree.
433    pub fn cpu_reference_fit(
434        &self,
435        opts: &DeviceResidentInnerOptions,
436    ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
437        self.run_inner_loop(opts, InnerSolveMode::CpuReference)
438    }
439
440    fn run_inner_loop(
441        &self,
442        opts: &DeviceResidentInnerOptions,
443        mode: InnerSolveMode,
444    ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
445        let execution_path = mode.execution_path();
446        let n = self.shape.n;
447        let d = self.shape.d;
448        let p = self.shape.p;
449        let t_len = n * d;
450
451        // Resident iterate, host-side scalars only. The device buffers (X,
452        // slabs, border) never leave the device across iterations; only this
453        // O(t_len + p) iterate and the per-step reduction scalars cross back.
454        let mut t = vec![0.0_f64; t_len];
455        let mut beta = vec![0.0_f64; p];
456
457        let base = self.to_arrow_system();
458        let half_target_energy = 0.5 * squared_norm(&self.target_x);
459
460        let mut ridge_t = opts.initial_ridge_t.max(0.0);
461        let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
462        // #1017 Phase 3: when running on device, keep the resident Arrow frame
463        // (constant Hessian blocks + their factors) on the device across
464        // iterations. The frame bakes a fixed `(ridge_t, ridge_beta)` into the
465        // per-row and border Cholesky factors, so it is rebuilt only when the LM
466        // ridge changes (reject/shrink); every iteration that shares the cached
467        // ridge reuses the resident factors and uploads only the `O(n·d + p)`
468        // gradient. The CPU reference path keeps re-factoring per iterate so the
469        // parity harness compares residency against a fully independent solve.
470        let mut resident_frame: Option<(
471            f64,
472            f64,
473            crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle,
474        )> = None;
475        let mut current_objective = self.objective_at(&base, half_target_energy, &t, &beta);
476        let mut accepted_iters = 0_usize;
477        let mut total_iters = 0_usize;
478        let mut converged = false;
479        let mut last_step = DeviceResidentArrowStep {
480            delta_t: Array1::zeros(t_len),
481            delta_beta: Array1::zeros(p),
482            objective: current_objective,
483            gradient_norm: 0.0,
484            log_det_hessian: 0.0,
485            execution_path,
486        };
487
488        while total_iters < opts.max_iterations {
489            // Residual gradient r(z) = H z − g₀ becomes the system gradient.
490            let residual = self.residual_system(&base, &t, &beta);
491            let g_norm = arrow_system_gradient_norm(&residual);
492            let scale = 1.0 + iterate_norm(&t, &beta);
493            if g_norm / scale < opts.convergence_tolerance {
494                converged = true;
495                break;
496            }
497
498            let solution = match mode {
499                InnerSolveMode::DeviceResident => {
500                    // Rebuild the resident frame only when the LM ridge changed; an
501                    // unchanged ridge reuses the resident factors. A build failure
502                    // becomes a Solve error so the LM-escalation arm below grows the
503                    // ridge and retries, identical to a per-iterate solve failure.
504                    let frame_matches = resident_frame
505                        .as_ref()
506                        .is_some_and(|(rt, rb, _)| *rt == ridge_t && *rb == ridge_beta);
507                    let mut frame_build_error: Option<DeviceResidentArrowError> = None;
508                    if !frame_matches {
509                        resident_frame = None;
510                        match crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
511                            &residual, ridge_t, ridge_beta,
512                        ) {
513                            Ok(frame) => {
514                                // Building a resident frame creates the device
515                                // stream/handles and runs the per-row POTRF +
516                                // border Schur factor once; record both so a
517                                // silent decline (no rebuild ⇒ no factor count)
518                                // is visible in the telemetry.
519                                gam_gpu::profile::telemetry_record_handle_creation(
520                                    self.context_id(),
521                                );
522                                gam_gpu::profile::telemetry_record_factorization();
523                                gam_gpu::profile::telemetry_record_h2d(
524                                    self.frame_upload_bytes(),
525                                );
526                                resident_frame = Some((ridge_t, ridge_beta, frame));
527                            }
528                            Err(err) => frame_build_error = Some(map_gpu_error(err)),
529                        }
530                    }
531                    match resident_frame.as_ref() {
532                        Some((_, _, frame)) => {
533                            // Per-iterate gradient r(z) = (g_t rows, g_β), extracted
534                            // from the residual system the frame was built to match.
535                            let mut g_t = Vec::with_capacity(n * d);
536                            for row in &residual.rows {
537                                for &v in row.gt.iter() {
538                                    g_t.push(v);
539                                }
540                            }
541                            let g_beta: Vec<f64> = residual.gb.iter().copied().collect();
542                            // The resident solve uploads only the O(n·d + p)
543                            // gradient, launches the per-iterate solve kernel, and
544                            // reads back only δ.
545                            let grad_bytes =
546                                (g_t.len() + g_beta.len()) * std::mem::size_of::<f64>();
547                            gam_gpu::profile::telemetry_record_h2d(grad_bytes);
548                            gam_gpu::profile::telemetry_record_kernel_launch();
549                            gam_gpu::profile::telemetry_record_d2h(
550                                (n * d + p) * std::mem::size_of::<f64>(),
551                            );
552                            frame.solve_gradient(&g_t, &g_beta).map_err(map_gpu_error)
553                        }
554                        None => Err(frame_build_error.unwrap_or_else(|| {
555                            DeviceResidentArrowError::Solve {
556                                reason: "SAE resident frame build declined".to_string(),
557                            }
558                        })),
559                    }
560                }
561                InnerSolveMode::DeviceReupload => {
562                    // #1017 residency baseline: re-upload D/B/g and re-factor on
563                    // every iterate. Same GPU factor kernels as the resident path,
564                    // minus the across-iteration buffer/factor reuse — so EVERY
565                    // iterate creates handles, factorizes, launches, and re-uploads
566                    // the full slabs.
567                    gam_gpu::profile::telemetry_record_handle_creation(self.context_id());
568                    gam_gpu::profile::telemetry_record_factorization();
569                    gam_gpu::profile::telemetry_record_h2d(self.frame_upload_bytes());
570                    gam_gpu::profile::telemetry_record_kernel_launch();
571                    gam_gpu::profile::telemetry_record_d2h(
572                        (n * d + p) * std::mem::size_of::<f64>(),
573                    );
574                    solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
575                }
576                InnerSolveMode::CpuReference => {
577                    solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
578                        .map_err(|reason| DeviceResidentArrowError::Solve { reason })
579                }
580            };
581
582            let solution = match solution {
583                Ok(sol) => sol,
584                Err(DeviceResidentArrowError::Solve { .. })
585                | Err(DeviceResidentArrowError::Unavailable { .. }) => {
586                    // LM escalation: grow ridge, retry without consuming an
587                    // iteration. Mirrors the production per-row/Schur PD-failure
588                    // arm in `LatentInnerSolver::solve`.
589                    ridge_t = grow_ridge(ridge_t, opts.lm_grow);
590                    ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
591                    if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
592                        return Err(DeviceResidentArrowError::Solve {
593                            reason: format!(
594                                "SAE resident inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
595                                opts.max_ridge
596                            ),
597                        });
598                    }
599                    total_iters += 1;
600                    continue;
601                }
602                Err(other) => return Err(other),
603            };
604
605            // Predicted reduction from the bare quadratic model on the residual
606            // system, identical formula to the production trust-region ratio.
607            let predicted_reduction =
608                crate::arrow_schur::arrow_bare_quadratic_model_reduction(
609                    &residual,
610                    solution.delta_t.view(),
611                    solution.delta_beta.view(),
612                    ridge_t,
613                    ridge_beta,
614                )
615                .map_err(|err| DeviceResidentArrowError::Solve {
616                    reason: format!("SAE resident inner loop predicted-reduction failed: {err}"),
617                })?;
618
619            // Trial iterate.
620            let mut trial_t = t.clone();
621            let mut trial_beta = beta.clone();
622            for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
623                *slot += *dv;
624            }
625            for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
626                *slot += *dv;
627            }
628            let trial_objective =
629                self.objective_at(&base, half_target_energy, &trial_t, &trial_beta);
630
631            // Trust-region gain-ratio noise floor keyed to the objective's own
632            // magnitude, mirroring the production `LatentInnerSolver` (#1127): the
633            // floor must be equivariant under a response rescaling `y → a·y` (the
634            // penalized objective and both reductions scale as `O(a²)`). The
635            // previous `.max(1.0)` absolute floor broke this — near a converged
636            // iterate it pinned the floor at `1e-14` while a genuine refining
637            // step's `predicted_reduction` was `O(a²)`, misclassifying the real
638            // step as numerical noise and stalling the inner solve at a
639            // non-stationary point. A perfectly converged objective
640            // (`current_objective == 0`) yields a `0` floor, so the
641            // `predicted_reduction > 0` branch still governs and no step is lost.
642            let objective_scale = current_objective.abs();
643            let noise_floor = objective_scale * 1e-14;
644            let actual_reduction = current_objective - trial_objective;
645            let rho = if predicted_reduction > noise_floor {
646                actual_reduction / predicted_reduction
647            } else if actual_reduction >= -noise_floor {
648                1.0
649            } else {
650                -1.0
651            };
652
653            if rho > 0.0 && trial_objective.is_finite() {
654                t = trial_t;
655                beta = trial_beta;
656                current_objective = trial_objective;
657                ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
658                ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
659                last_step = DeviceResidentArrowStep {
660                    delta_t: solution.delta_t,
661                    delta_beta: solution.delta_beta,
662                    objective: current_objective,
663                    gradient_norm: g_norm,
664                    log_det_hessian: solution.log_det_hessian,
665                    execution_path,
666                };
667                accepted_iters += 1;
668                total_iters += 1;
669            } else {
670                ridge_t = grow_ridge(ridge_t, opts.lm_grow);
671                ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
672                if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
673                    return Err(DeviceResidentArrowError::Solve {
674                        reason: format!(
675                            "SAE resident inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
676                            opts.max_ridge
677                        ),
678                    });
679                }
680                total_iters += 1;
681            }
682        }
683
684        Ok(DeviceResidentInnerOutcome {
685            t: Array1::from_vec(t),
686            beta: Array1::from_vec(beta),
687            objective: current_objective,
688            gradient_norm: last_step.gradient_norm,
689            log_det_hessian: last_step.log_det_hessian,
690            iterations: total_iters,
691            accepted_iterations: accepted_iters,
692            converged,
693            execution_path,
694        })
695    }
696
697    // ---------------------------------------------------------------------
698    // Phase 3b: reuse the resident frame ACROSS OUTER iterations (#1017
699    // deliverable 3).
700    //
701    // The inner Newton loop above already keeps the resident Arrow frame
702    // (factored `D`/`B`/Schur) on the device across INNER iterations at a fixed
703    // ridge. The next residency tier is the OUTER loop: across consecutive outer
704    // evaluations the SAE Hessian operator is unchanged whenever the frozen
705    // gate/basis frame (hence `D = H_tt`, `B = H_tβ`, border `H_ββ`) does not
706    // move — only the base gradient `g₀` (the linearization point / target
707    // residual) changes. In that regime the `O(n·d³ + p³)` factor work and the
708    // dominant `O(n·d·p)` `D`/`B` upload need to happen ONCE for the whole outer
709    // sweep, not once per outer. `device_fit_outer_sequence` realizes that: it
710    // builds at most ONE resident frame for an unchanged operator and drives
711    // every outer's inner solve through it, re-uploading only the per-outer
712    // `O(n·d + p)` gradient. The per-outer parity oracle is an independent
713    // `device_fit` (fresh frame per outer); the two must agree because sharing
714    // the factor across outers skips only re-deriving operator-independent work.
715    // ---------------------------------------------------------------------
716
717    /// Run a sequence of outer evaluations that SHARE one resident frame when the
718    /// Hessian operator is unchanged across outers (#1017 deliverable 3).
719    ///
720    /// Each entry of `base_gradient_overrides` is one outer evaluation's base
721    /// gradient `(g_t rows: n·d, g_β: p)` — the only part of the bordered
722    /// quadratic that moves across outers at a frozen gate/basis frame. The
723    /// constant Hessian blocks ride the resident frame, which is built ONCE and
724    /// reused for every outer (frame builds are counted and returned so a caller
725    /// can assert the across-outer amortization actually fired: exactly one frame
726    /// build for an unchanged operator, regardless of how many outers run).
727    ///
728    /// Returns one [`DeviceResidentInnerOutcome`] per outer plus the number of
729    /// resident-frame builds performed across the whole sweep. On a CPU-only host
730    /// returns `Unavailable` (callers wanting a host path use
731    /// [`Self::cpu_reference_outer_sequence`]).
732    pub fn device_fit_outer_sequence(
733        &self,
734        base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
735        opts: &DeviceResidentInnerOptions,
736    ) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
737        if !self.device_resident() {
738            return Err(DeviceResidentArrowError::Unavailable {
739                reason: "SAE outer-sequence residency unavailable: CUDA runtime did not admit the row-block workload".to_string(),
740            });
741        }
742        self.run_outer_sequence(
743            base_gradient_overrides,
744            opts,
745            InnerSolveMode::DeviceResident,
746        )
747    }
748
749    /// CPU-reference outer sequence: same host control flow as
750    /// [`Self::device_fit_outer_sequence`] but the per-iterate arrow solve uses
751    /// the dense reference factorisation. The parity harness asserts the device
752    /// across-outer sweep agrees with this per-outer-independent reference.
753    pub fn cpu_reference_outer_sequence(
754        &self,
755        base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
756        opts: &DeviceResidentInnerOptions,
757    ) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
758        self.run_outer_sequence(base_gradient_overrides, opts, InnerSolveMode::CpuReference)
759    }
760
761    fn run_outer_sequence(
762        &self,
763        base_gradient_overrides: &[(Vec<f64>, Vec<f64>)],
764        opts: &DeviceResidentInnerOptions,
765        mode: InnerSolveMode,
766    ) -> Result<OuterSequenceOutcome, DeviceResidentArrowError> {
767        let n = self.shape.n;
768        let d = self.shape.d;
769        let p = self.shape.p;
770        let t_len = n * d;
771        let half_target_energy = 0.5 * squared_norm(&self.target_x);
772
773        // ONE resident frame for the whole sweep (device mode only). The operator
774        // is unchanged across outers — the frame bakes the constant `D`/`B`/Schur
775        // factors at `(initial_ridge_t, initial_ridge_beta)` once and every outer
776        // reuses it. A per-outer ridge escalation (PD failure) still rebuilds, but
777        // for a well-posed unchanged operator the build count stays at 1, which is
778        // the across-outer amortization this method delivers.
779        let mut shared = SharedFrameState::default();
780        let mut outcomes = Vec::with_capacity(base_gradient_overrides.len());
781
782        for (g_t_override, g_beta_override) in base_gradient_overrides {
783            if g_t_override.len() != t_len || g_beta_override.len() != p {
784                return Err(DeviceResidentArrowError::Shape {
785                    reason: format!(
786                        "outer-sequence gradient shape mismatch: g_t={} (want {t_len}), g_beta={} (want {p})",
787                        g_t_override.len(),
788                        g_beta_override.len()
789                    ),
790                });
791            }
792            // This outer's bordered quadratic: same Hessian blocks, base gradient
793            // swapped to this outer's `g₀`.
794            let mut base = self.to_arrow_system();
795            for (i, row) in base.rows.iter_mut().enumerate() {
796                for r in 0..d {
797                    row.gt[r] = g_t_override[i * d + r];
798                }
799            }
800            for (j, gb) in base.gb.iter_mut().enumerate() {
801                *gb = g_beta_override[j];
802            }
803            base.refresh_row_hessian_fingerprint();
804
805            let outcome = self.run_one_outer(&base, half_target_energy, opts, mode, &mut shared)?;
806            outcomes.push(outcome);
807        }
808
809        Ok(OuterSequenceOutcome {
810            outers: outcomes,
811            frame_builds: shared.frame_builds,
812        })
813    }
814
    /// One outer evaluation's inner Newton loop, optionally reusing the frame
    /// carried in `shared` across calls.
    ///
    /// Mirrors `run_inner_loop`, but takes the base system plus the shared
    /// across-outer state, so the caller can keep one frame live for the whole
    /// sweep. `shared.frame_builds` is incremented every time a frame is
    /// actually (re)built, so the caller can assert the across-outer
    /// amortization fired. Only the `DeviceResident` mode builds frames.
    ///
    /// The iterate `z = (t, β)` starts at zero, with ridges starting at
    /// `opts.initial_ridge_{t,beta}` clamped to `>= 0`. Each iteration:
    /// 1. Forms the residual system `r(z) = H z − g₀`.
    /// 2. Stops as converged when `‖r(z)‖ / (1 + ‖z‖) < convergence_tolerance`.
    /// 3. Solves one arrow step in the requested mode.
    /// 4. Runs the LM accept/reject test on the bordered-quadratic objective.
    ///
    /// A `Solve`/`Unavailable` failure from the step solve grows both ridges and
    /// retries; that retry counts toward `max_iterations`.
    ///
    /// The reported `gradient_norm` and `log_det_hessian` come from the last
    /// ACCEPTED step: the norm was measured before that step was taken, and
    /// both are `0.0` if no step was accepted.
    ///
    /// # Errors
    /// * `Solve` if either ridge grows past `opts.max_ridge` (after a failed
    ///   solve or a rejected step), or if the predicted-reduction evaluation
    ///   fails.
    /// * Any other error kind returned by the step solve, propagated unchanged.
    ///
    /// NOTE(review): `Unavailable` is retried exactly like a PD failure. With
    /// the default options (16 iterations, ridge grows ×4 starting from 1e-6)
    /// the ridge tops out near 1e3 and never reaches `max_ridge = 1e9`. A
    /// persistent device outage therefore ends as `Ok` with `converged == false`
    /// and zero accepted steps, rather than as an error.
    fn run_one_outer(
        &self,
        base: &ArrowSchurSystem,
        half_target_energy: f64,
        opts: &DeviceResidentInnerOptions,
        mode: InnerSolveMode,
        shared: &mut SharedFrameState,
    ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError> {
        let execution_path = mode.execution_path();
        let n = self.shape.n;
        let d = self.shape.d;
        let p = self.shape.p;
        let t_len = n * d;

        // Every outer restarts from z = 0 and the configured initial ridges.
        let mut t = vec![0.0_f64; t_len];
        let mut beta = vec![0.0_f64; p];
        let mut ridge_t = opts.initial_ridge_t.max(0.0);
        let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
        let mut current_objective = self.objective_at(base, half_target_energy, &t, &beta);
        let mut accepted_iters = 0_usize;
        let mut total_iters = 0_usize;
        let mut converged = false;
        // Both are updated only on accepted steps, and stay 0.0 otherwise.
        let mut last_gradient_norm = 0.0_f64;
        let mut last_log_det = 0.0_f64;

        while total_iters < opts.max_iterations {
            let residual = self.residual_system(base, &t, &beta);
            let g_norm = arrow_system_gradient_norm(&residual);
            // Relative convergence test: gradient norm scaled by 1 + ‖z‖.
            let scale = 1.0 + iterate_norm(&t, &beta);
            if g_norm / scale < opts.convergence_tolerance {
                converged = true;
                break;
            }

            let solution = match mode {
                InnerSolveMode::DeviceResident => {
                    // Reuse the cached frame only when it was factored at exactly
                    // the current ridges. The Hessian blocks themselves never
                    // change within the sweep.
                    let frame_matches = shared
                        .frame
                        .as_ref()
                        .is_some_and(|(rt, rb, _)| *rt == ridge_t && *rb == ridge_beta);
                    let mut frame_build_error: Option<DeviceResidentArrowError> = None;
                    if !frame_matches {
                        // Drop the stale frame before building its replacement,
                        // so only one frame handle is alive at a time.
                        shared.frame = None;
                        match crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
                            &residual, ridge_t, ridge_beta,
                        ) {
                            Ok(frame) => {
                                shared.frame_builds += 1;
                                gam_gpu::profile::telemetry_record_handle_creation(
                                    self.context_id(),
                                );
                                gam_gpu::profile::telemetry_record_factorization();
                                gam_gpu::profile::telemetry_record_h2d(
                                    self.frame_upload_bytes(),
                                );
                                shared.frame = Some((ridge_t, ridge_beta, frame));
                            }
                            Err(err) => frame_build_error = Some(map_gpu_error(err)),
                        }
                    }
                    match shared.frame.as_ref() {
                        Some((_, _, frame)) => {
                            // Per-iterate traffic: flatten the residual gradient
                            // (rows in order, then the border) and hand it to the
                            // resident frame.
                            let mut g_t = Vec::with_capacity(n * d);
                            for row in &residual.rows {
                                for &v in row.gt.iter() {
                                    g_t.push(v);
                                }
                            }
                            let g_beta: Vec<f64> = residual.gb.iter().copied().collect();
                            let grad_bytes =
                                (g_t.len() + g_beta.len()) * std::mem::size_of::<f64>();
                            // Telemetry is recorded before the solve runs, so it
                            // is counted even if `solve_gradient` fails.
                            gam_gpu::profile::telemetry_record_h2d(grad_bytes);
                            gam_gpu::profile::telemetry_record_kernel_launch();
                            gam_gpu::profile::telemetry_record_d2h(
                                (n * d + p) * std::mem::size_of::<f64>(),
                            );
                            frame.solve_gradient(&g_t, &g_beta).map_err(map_gpu_error)
                        }
                        // No frame: the rebuild above failed. The fallback message
                        // is defensive; the branches above always leave either a
                        // frame or a recorded build error.
                        None => Err(frame_build_error.unwrap_or_else(|| {
                            DeviceResidentArrowError::Solve {
                                reason: "SAE resident frame build declined".to_string(),
                            }
                        })),
                    }
                }
                InnerSolveMode::DeviceReupload => {
                    solve_arrow_newton_step(&residual, ridge_t, ridge_beta).map_err(map_gpu_error)
                }
                InnerSolveMode::CpuReference => {
                    solve_arrow_newton_step_dense_reference(&residual, ridge_t, ridge_beta)
                        .map_err(|reason| DeviceResidentArrowError::Solve { reason })
                }
            };

            let solution = match solution {
                Ok(sol) => sol,
                // Failed solve (e.g. PD failure or device unavailable): escalate
                // both ridges and retry. The retry consumes an iteration.
                Err(DeviceResidentArrowError::Solve { .. })
                | Err(DeviceResidentArrowError::Unavailable { .. }) => {
                    ridge_t = grow_ridge(ridge_t, opts.lm_grow);
                    ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
                    if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
                        return Err(DeviceResidentArrowError::Solve {
                            reason: format!(
                                "SAE outer-sequence inner loop: LM ridge exceeded max ({:e}) at iter {total_iters}",
                                opts.max_ridge
                            ),
                        });
                    }
                    total_iters += 1;
                    continue;
                }
                Err(other) => return Err(other),
            };

            // Model-predicted decrease for this step at the current ridges; the
            // denominator of the LM gain ratio.
            let predicted_reduction =
                crate::arrow_schur::arrow_bare_quadratic_model_reduction(
                    &residual,
                    solution.delta_t.view(),
                    solution.delta_beta.view(),
                    ridge_t,
                    ridge_beta,
                )
                .map_err(|err| DeviceResidentArrowError::Solve {
                    reason: format!("SAE outer-sequence predicted-reduction failed: {err}"),
                })?;

            let mut trial_t = t.clone();
            let mut trial_beta = beta.clone();
            for (slot, dv) in trial_t.iter_mut().zip(solution.delta_t.iter()) {
                *slot += *dv;
            }
            for (slot, dv) in trial_beta.iter_mut().zip(solution.delta_beta.iter()) {
                *slot += *dv;
            }
            let trial_objective =
                self.objective_at(base, half_target_energy, &trial_t, &trial_beta);

            // Gain ratio rho = actual / predicted. When the predicted decrease is
            // within a relative 1e-14 noise floor of the objective, the ratio is
            // meaningless. In that case accept (rho = 1) unless the objective
            // visibly increased (rho = -1).
            let objective_scale = current_objective.abs();
            let noise_floor = objective_scale * 1e-14;
            let actual_reduction = current_objective - trial_objective;
            let rho = if predicted_reduction > noise_floor {
                actual_reduction / predicted_reduction
            } else if actual_reduction >= -noise_floor {
                1.0
            } else {
                -1.0
            };

            if rho > 0.0 && trial_objective.is_finite() {
                // Accept: move the iterate and relax the ridges.
                t = trial_t;
                beta = trial_beta;
                current_objective = trial_objective;
                ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
                ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
                last_gradient_norm = g_norm;
                last_log_det = solution.log_det_hessian;
                accepted_iters += 1;
                total_iters += 1;
            } else {
                // Reject: keep the iterate, stiffen the ridges, abort past the cap.
                ridge_t = grow_ridge(ridge_t, opts.lm_grow);
                ridge_beta = grow_ridge(ridge_beta, opts.lm_grow);
                if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
                    return Err(DeviceResidentArrowError::Solve {
                        reason: format!(
                            "SAE outer-sequence inner loop: LM rejected step until ridge exceeded max ({:e}) at iter {total_iters} (rho={rho:.3e})",
                            opts.max_ridge
                        ),
                    });
                }
                total_iters += 1;
            }
        }

        Ok(DeviceResidentInnerOutcome {
            t: Array1::from_vec(t),
            beta: Array1::from_vec(beta),
            objective: current_objective,
            gradient_norm: last_gradient_norm,
            log_det_hessian: last_log_det,
            iterations: total_iters,
            accepted_iterations: accepted_iters,
            converged,
            execution_path,
        })
    }
1006
1007    /// Bordered-quadratic objective `½‖X‖² + ½ zᵀ H z − g₀ᵀ z` at iterate
1008    /// `z = (t, β)`. Uses the resident arrow structure: per-row `H_tt`/`H_tβ`
1009    /// contractions plus the shared `H_ββ` border, then the linear `g₀ᵀ z`
1010    /// term. This is the reduction the device line search evaluates; on a CUDA
1011    /// host the `H z` contraction rides the same resident slabs (batched
1012    /// per-row GEMV + border GEMV), with only the final dot reduced to a scalar.
1013    fn objective_at(
1014        &self,
1015        base: &ArrowSchurSystem,
1016        half_target_energy: f64,
1017        t: &[f64],
1018        beta: &[f64],
1019    ) -> f64 {
1020        let n = self.shape.n;
1021        let d = self.shape.d;
1022        let p = self.shape.p;
1023        // quad = zᵀ H z, lin = g₀ᵀ z.
1024        let mut quad = 0.0_f64;
1025        let mut lin = 0.0_f64;
1026        // Per-row blocks: tᵢᵀ H_tt tᵢ + 2 tᵢᵀ H_tβ β contributes to quad; the
1027        // β border H_ββ is added once below.
1028        for i in 0..n {
1029            let t_base = i * d;
1030            for r in 0..d {
1031                // H_tt tᵢ row.
1032                let mut htt_t = 0.0_f64;
1033                for c in 0..d {
1034                    htt_t += base.rows[i].htt[[r, c]] * t[t_base + c];
1035                }
1036                // H_tβ β row.
1037                let mut htb_b = 0.0_f64;
1038                for c in 0..p {
1039                    htb_b += base.rows[i].htbeta[[r, c]] * beta[c];
1040                }
1041                quad += t[t_base + r] * (htt_t + 2.0 * htb_b);
1042                lin += base.rows[i].gt[r] * t[t_base + r];
1043            }
1044        }
1045        // β border: βᵀ H_ββ β and g_β ᵀ β.
1046        for r in 0..p {
1047            let mut hbb_b = 0.0_f64;
1048            for c in 0..p {
1049                hbb_b += base.hbb[[r, c]] * beta[c];
1050            }
1051            quad += beta[r] * hbb_b;
1052            lin += base.gb[r] * beta[r];
1053        }
1054        half_target_energy + 0.5 * quad - lin
1055    }
1056
1057    /// Build the residual arrow system at iterate `z`: same Hessian blocks as
1058    /// `base`, but the gradient set to `r(z) = H z − g₀`. The arrow solver
1059    /// solves `H δ = −gradient = −r(z) = g₀ − H z`, the Newton direction toward
1060    /// the quadratic's minimiser.
1061    fn residual_system(
1062        &self,
1063        base: &ArrowSchurSystem,
1064        t: &[f64],
1065        beta: &[f64],
1066    ) -> ArrowSchurSystem {
1067        let n = self.shape.n;
1068        let d = self.shape.d;
1069        let p = self.shape.p;
1070        // `ArrowSchurSystem` is not `Clone` (it carries matrix-free operator
1071        // closures whose sharing across a then-mutated system would be a
1072        // footgun), so own a fresh system built from the resident slabs rather
1073        // than cloning `base`. `to_arrow_system` reproduces the identical
1074        // Hessian blocks; we overwrite only the gradients below with the
1075        // residual `r(z) = H z − g₀`. The Hessian reads stay on `base` (bit-
1076        // identical to the fresh system's blocks).
1077        let mut sys = self.to_arrow_system();
1078        for i in 0..n {
1079            let t_base = i * d;
1080            for r in 0..d {
1081                let mut hz = 0.0_f64;
1082                for c in 0..d {
1083                    hz += base.rows[i].htt[[r, c]] * t[t_base + c];
1084                }
1085                for c in 0..p {
1086                    hz += base.rows[i].htbeta[[r, c]] * beta[c];
1087                }
1088                sys.rows[i].gt[r] = hz - base.rows[i].gt[r];
1089            }
1090        }
1091        for r in 0..p {
1092            let mut hz = 0.0_f64;
1093            // H_ββ β.
1094            for c in 0..p {
1095                hz += base.hbb[[r, c]] * beta[c];
1096            }
1097            // Σ_i (H_tβ^(i))ᵀ tᵢ contribution to the β-gradient.
1098            for i in 0..n {
1099                let t_base = i * d;
1100                for rr in 0..d {
1101                    hz += base.rows[i].htbeta[[rr, r]] * t[t_base + rr];
1102                }
1103            }
1104            sys.gb[r] = hz - base.gb[r];
1105        }
1106        sys.refresh_row_hessian_fingerprint();
1107        sys
1108    }
1109}
1110
/// Tuning knobs for the device-resident inner Newton / Levenberg–Marquardt
/// loop.
///
/// The defaults mirror the production
/// [`crate::latent_inner::LatentInnerOptions`] trust-region schedule, so the
/// device and CPU paths run identical host-side control flow.
#[derive(Clone, Copy, Debug)]
pub struct DeviceResidentInnerOptions {
    /// Cap on loop iterations: accepted steps, rejected steps, and failed solves
    /// all count.
    pub max_iterations: usize,
    /// The loop stops as converged once `‖r(z)‖ / (1 + ‖z‖)` drops below this.
    pub convergence_tolerance: f64,
    /// Starting ridge on the row (`t`) blocks; negative values are clamped to 0.
    pub initial_ridge_t: f64,
    /// Starting ridge on the border (`β`) block; negative values are clamped to 0.
    pub initial_ridge_beta: f64,
    /// Ridge multiplier after a rejected step or failed solve. A ridge of
    /// exactly 0 jumps to `1e-6` instead.
    pub lm_grow: f64,
    /// Ridge multiplier after an accepted step (the result is clamped to `>= 0`).
    pub lm_shrink: f64,
    /// Ridge ceiling: growing either ridge past it aborts with a `Solve` error.
    pub max_ridge: f64,
}

impl Default for DeviceResidentInnerOptions {
    fn default() -> Self {
        // Start without any ridge: the first attempt is an undamped step, and
        // damping only engages after a rejection or a failed solve.
        const NO_RIDGE: f64 = 0.0;
        Self {
            max_iterations: 16,
            convergence_tolerance: 1e-9,
            initial_ridge_t: NO_RIDGE,
            initial_ridge_beta: NO_RIDGE,
            lm_grow: 4.0,
            lm_shrink: 0.5,
            max_ridge: 1e9,
        }
    }
}
1138
/// Result of the full device-resident inner Newton loop.
///
/// Returned by the inner loops for every [`InnerSolveMode`]. The iterate only
/// moves on accepted LM steps.
#[derive(Clone, Debug)]
pub struct DeviceResidentInnerOutcome {
    /// Final row-latent iterate `t` (`n·d` values, row `i` at `[i·d, (i+1)·d)`).
    pub t: Array1<f64>,
    /// Final border iterate `β` (`p` values).
    pub beta: Array1<f64>,
    /// Bordered-quadratic objective at the returned `(t, β)`.
    pub objective: f64,
    /// Residual-gradient norm measured at the START of the last accepted step,
    /// i.e. at the previous iterate, not at the returned one.
    /// NOTE(review): the outer-sequence loop reports `0.0` when no step was
    /// accepted; confirm that `run_inner_loop` uses the same initial value.
    pub gradient_norm: f64,
    /// `log_det_hessian` reported by the arrow solve of the last accepted step.
    /// The outer-sequence loop reports `0.0` when no step was accepted.
    pub log_det_hessian: f64,
    /// Total loop iterations: accepted steps, rejected steps, and failed solves
    /// that triggered a ridge increase.
    pub iterations: usize,
    /// Number of accepted LM steps.
    pub accepted_iterations: usize,
    /// `true` iff the loop stopped on the scaled gradient test
    /// (`‖r(z)‖ / (1 + ‖z‖) < convergence_tolerance`) rather than by
    /// exhausting `max_iterations`.
    pub converged: bool,
    /// Execution path of the solve mode that produced this outcome.
    pub execution_path: ExecutionPath,
}
1152
/// Result of an across-outer resident sweep ([`DeviceResidentArrowWorkspace::device_fit_outer_sequence`]).
///
/// `outers` holds one inner-loop outcome per outer evaluation, in input order.
///
/// `frame_builds` is the total number of resident-frame (re)builds performed
/// across the whole sweep:
/// * Only the device-resident mode builds frames, so the CPU-reference sweep
///   always reports `0`.
/// * In device mode, a frame is rebuilt whenever the ridge it was factored at
///   no longer matches the current LM ridge. That happens after a rejected step
///   or failed solve grows the ridge, after an accepted step shrinks a non-zero
///   ridge, or when a new outer restarts at the initial ridge.
/// * For an unchanged operator whose ridge never moves (e.g. the default zero
///   ridge with every step accepted), the count is exactly `1`, no matter how
///   many outers ran. This is the across-outer amortization #1017
///   deliverable 3 buys: factor once, reuse the device factors for every outer.
/// * The count is `0` if every outer converged before its first solve.
///
/// A value `> 1` means a ridge change forced a refactor. The parity oracle
/// still matches, but those outers lose the amortization.
#[derive(Clone, Debug)]
pub struct OuterSequenceOutcome {
    /// Per-outer inner-loop outcomes, in the order the outers were supplied.
    pub outers: Vec<DeviceResidentInnerOutcome>,
    /// Successful resident-frame builds across the whole sweep (see above).
    pub frame_builds: usize,
}
1168
/// Across-outer resident-frame state carried through a `device_fit_outer_sequence`
/// sweep.
///
/// Holds at most one live resident frame, tagged with the
/// `(ridge_t, ridge_beta)` it was factored at. The inner loop reuses the frame
/// only when both ridges match exactly. Otherwise it drops the frame before
/// building a replacement, so only one frame handle is alive at a time.
///
/// Also keeps the running count of successful (re)builds, so the caller can
/// assert the across-outer amortization fired.
#[derive(Default)]
struct SharedFrameState {
    /// `(ridge_t, ridge_beta, frame)`: the live frame and the ridges it was
    /// factored at. `None` before the first build or after a failed rebuild.
    frame: Option<(
        f64,
        f64,
        crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle,
    )>,
    /// Number of successful frame builds so far in this sweep.
    frame_builds: usize,
}
1182
/// LM ridge escalation: multiply the current ridge by `grow`. A ridge of
/// exactly zero (either sign) jumps to the `1e-6` seed instead, because
/// multiplying zero would never add any damping.
fn grow_ridge(current: f64, grow: f64) -> f64 {
    const RIDGE_SEED: f64 = 1e-6;
    if current != 0.0 {
        current * grow
    } else {
        RIDGE_SEED
    }
}
1186
1187fn arrow_system_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
1188    let mut acc = 0.0_f64;
1189    for row in &sys.rows {
1190        for &v in row.gt.iter() {
1191            acc += v * v;
1192        }
1193    }
1194    for &v in sys.gb.iter() {
1195        acc += v * v;
1196    }
1197    acc.sqrt()
1198}
1199
/// Euclidean norm of the stacked iterate `z = (t, β)`. The two halves are summed
/// separately and then combined, so the result matches adding the two
/// independent squared norms.
fn iterate_norm(t: &[f64], beta: &[f64]) -> f64 {
    let t_energy: f64 = t.iter().map(|v| v * v).sum();
    let beta_energy: f64 = beta.iter().map(|v| v * v).sum();
    (t_energy + beta_energy).sqrt()
}
1203
1204fn validate_shape(
1205    shape: DeviceResidentArrowShape,
1206    target_x: &[f64],
1207    basis_values: &[f64],
1208    gate_activations: &[f64],
1209    slabs: &DeviceResidentArrowSlabs,
1210) -> Result<(), DeviceResidentArrowError> {
1211    let checks = [
1212        ("target_x", target_x.len(), shape.target_len()),
1213        ("basis_values", basis_values.len(), shape.basis_len()),
1214        (
1215            "gate_activations",
1216            gate_activations.len(),
1217            shape.basis_len(),
1218        ),
1219        (
1220            "row_hessian_slabs",
1221            slabs.row_hessian_slabs.len(),
1222            shape.row_hessian_len(),
1223        ),
1224        (
1225            "row_cross_slabs",
1226            slabs.row_cross_slabs.len(),
1227            shape.row_cross_len(),
1228        ),
1229        (
1230            "row_gradient_slabs",
1231            slabs.row_gradient_slabs.len(),
1232            shape.row_gradient_len(),
1233        ),
1234        (
1235            "border_hessian",
1236            slabs.border_hessian.len(),
1237            shape.border_hessian_len(),
1238        ),
1239        ("border_gradient", slabs.border_gradient.len(), shape.p),
1240    ];
1241    for (label, got, want) in checks {
1242        if got != want {
1243            return Err(DeviceResidentArrowError::Shape {
1244                reason: format!(
1245                    "SAE resident workspace shape mismatch for {label}: got {got}, expected {want}"
1246                ),
1247            });
1248        }
1249    }
1250    if shape.n == 0 || shape.p == 0 || shape.d == 0 || shape.basis_cols == 0 {
1251        return Err(DeviceResidentArrowError::Shape {
1252            reason: "SAE resident workspace requires nonzero n, p, basis_cols, and d".to_string(),
1253        });
1254    }
1255    Ok(())
1256}
1257
1258#[cfg(target_os = "linux")]
1259fn upload_resident_buffers(
1260    shape: DeviceResidentArrowShape,
1261    target_x: &[f64],
1262    basis_values: &[f64],
1263    gate_activations: &[f64],
1264    slabs: &DeviceResidentArrowSlabs,
1265) -> Option<DeviceResidentArrowBuffers> {
1266    use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1267
1268    let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf {
1269        p: shape.d,
1270        batch: shape.n,
1271    })
1272    .or_else(|| {
1273        route_through_gpu(DispatchOp::Gemm {
1274            m: shape.p,
1275            n: shape.p,
1276            k: shape.n * shape.basis_cols,
1277        })
1278    })?;
1279    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)?;
1280    let stream = ctx.new_stream().ok()?;
1281    let target_x_dev = stream.clone_htod(target_x).ok()?;
1282    let basis_values_dev = stream.clone_htod(basis_values).ok()?;
1283    let gate_activations_dev = stream.clone_htod(gate_activations).ok()?;
1284    let row_hessian_dev = stream.clone_htod(&slabs.row_hessian_slabs).ok()?;
1285    let row_cross_dev = stream.clone_htod(&slabs.row_cross_slabs).ok()?;
1286    let row_gradient_dev = stream.clone_htod(&slabs.row_gradient_slabs).ok()?;
1287    let border_hessian_dev = stream.clone_htod(&slabs.border_hessian).ok()?;
1288    let border_gradient_dev = stream.clone_htod(&slabs.border_gradient).ok()?;
1289    let bytes = [
1290        target_x.len(),
1291        basis_values.len(),
1292        gate_activations.len(),
1293        slabs.row_hessian_slabs.len(),
1294        slabs.row_cross_slabs.len(),
1295        slabs.row_gradient_slabs.len(),
1296        slabs.border_hessian.len(),
1297        slabs.border_gradient.len(),
1298    ]
1299    .into_iter()
1300    .sum::<usize>()
1301        * std::mem::size_of::<f64>();
1302    Some(DeviceResidentArrowBuffers {
1303        stream,
1304        target_x_dev,
1305        basis_values_dev,
1306        gate_activations_dev,
1307        row_hessian_dev,
1308        row_cross_dev,
1309        row_gradient_dev,
1310        border_hessian_dev,
1311        border_gradient_dev,
1312        bytes,
1313    })
1314}
1315
1316fn map_gpu_error(err: ArrowSchurGpuFailure) -> DeviceResidentArrowError {
1317    match err {
1318        ArrowSchurGpuFailure::Unavailable => DeviceResidentArrowError::Unavailable {
1319            reason: "SAE resident inner iteration unavailable after GPU admission".to_string(),
1320        },
1321        ArrowSchurGpuFailure::RidgeBumpRequired { row, bump } => DeviceResidentArrowError::Solve {
1322            reason: format!("SAE resident inner iteration row {row} requires ridge bump {bump:e}"),
1323        },
1324        ArrowSchurGpuFailure::SchurFactorFailed { reason } => {
1325            DeviceResidentArrowError::Solve { reason }
1326        }
1327        ArrowSchurGpuFailure::GpuRequiresDenseSystem {
1328            had_hbb_matvec,
1329            had_htbeta_matvec,
1330        } => DeviceResidentArrowError::Solve {
1331            reason: format!(
1332                "SAE resident inner iteration requires dense slabs; hbb_matvec={had_hbb_matvec} htbeta_matvec={had_htbeta_matvec}"
1333            ),
1334        },
1335    }
1336}
1337
/// Sum of squares of `values`, accumulated in slice order (`‖values‖²`).
fn squared_norm(values: &[f64]) -> f64 {
    values.iter().copied().map(|x| x * x).sum::<f64>()
}
1341
1342impl From<ArrowSchurError> for DeviceResidentArrowError {
1343    fn from(err: ArrowSchurError) -> Self {
1344        Self::Solve {
1345            reason: err.to_string(),
1346        }
1347    }
1348}
1349
1350/// Deterministic qwen-scale non-gating fixture for the resident harness.
1351pub fn qwen_non_gating_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1352    qwen_non_gating_fixture_seeded(0x1017_0003_D3A1_5EED)
1353}
1354
1355/// Seeded variant of [`qwen_non_gating_fixture`]. Distinct seeds produce
1356/// distinct-but-well-conditioned resident frames, used to build independent
1357/// replicate fits for the stream-multiplexing parity harness.
1358pub fn qwen_non_gating_fixture_seeded(
1359    seed: u64,
1360) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1361    fixture_for_shape_seeded(DeviceResidentArrowShape::qwen_non_gating(), seed)
1362}
1363
1364/// Deterministic color-arm-scale resident fixture (n=180, p=5120) for the
1365/// #1017 GPU wall-clock bench: few rows, very wide border — the shape where the
1366/// per-iterate re-upload + re-factor that across-iteration residency eliminates
1367/// dominates.
1368pub fn color_arm_fixture() -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1369    fixture_for_shape_seeded(DeviceResidentArrowShape::color_arm(), 0x1017_C010_2A12_5EED)
1370}
1371
1372/// Build a well-conditioned resident frame for any `d == 2` shape. Both the
1373/// qwen and color-arm fixtures share this body; the conditioning (strong row
1374/// `H_tt` diagonals, tiny cross blocks, diagonally-dominant border) keeps the
1375/// dense reference factorisation PD so the parity harness is meaningful.
1376fn fixture_for_shape_seeded(
1377    shape: DeviceResidentArrowShape,
1378    seed: u64,
1379) -> Result<DeviceResidentArrowWorkspace, DeviceResidentArrowError> {
1380    if shape.d == 0 {
1381        return Err(DeviceResidentArrowError::Shape {
1382            reason: "fixture_for_shape_seeded requires d >= 1".to_string(),
1383        });
1384    }
1385    let d = shape.d;
1386    let mut rng = SplitMix64::new(seed);
1387    let mut target_x = vec![0.0_f64; shape.target_len()];
1388    for i in 0..shape.n {
1389        for j in 0..shape.p {
1390            let phase = ((i % 97) as f64) * 0.013 + ((j % 131) as f64) * 0.007;
1391            target_x[i * shape.p + j] = 0.02 * phase.sin() + 0.001 * rng.sample_signed();
1392        }
1393    }
1394    let mut basis_values = vec![0.0_f64; shape.basis_len()];
1395    let mut gate_activations = vec![1.0_f64; shape.basis_len()];
1396    for i in 0..shape.n {
1397        for a in 0..shape.basis_cols {
1398            let phase = ((i + 1) as f64) * ((a + 1) as f64) * 0.003;
1399            basis_values[i * shape.basis_cols + a] = phase.cos();
1400            gate_activations[i * shape.basis_cols + a] = 1.0;
1401        }
1402    }
1403    let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
1404    let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
1405    let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
1406    for i in 0..shape.n {
1407        let mut basis_sum = 0.0_f64;
1408        for a in 0..shape.basis_cols {
1409            basis_sum +=
1410                basis_values[i * shape.basis_cols + a] * gate_activations[i * shape.basis_cols + a];
1411        }
1412        // Strongly diagonally-dominant d×d H_tt (row-major): diagonal ≈ 3, tiny
1413        // symmetric off-diagonals — PD for any d so the dense reference factors.
1414        let h_base = i * d * d;
1415        for r in 0..d {
1416            for c in 0..d {
1417                let v = if r == c {
1418                    3.0 + 0.01 * basis_sum.abs() + 0.1 * (r as f64)
1419                } else {
1420                    0.02 * (basis_sum + (r + c) as f64).sin() / (d as f64)
1421                };
1422                row_hessian_slabs[h_base + r * d + c] = v;
1423            }
1424        }
1425        // Symmetrize the off-diagonals exactly.
1426        for r in 0..d {
1427            for c in 0..r {
1428                let avg = 0.5
1429                    * (row_hessian_slabs[h_base + r * d + c]
1430                        + row_hessian_slabs[h_base + c * d + r]);
1431                row_hessian_slabs[h_base + r * d + c] = avg;
1432                row_hessian_slabs[h_base + c * d + r] = avg;
1433            }
1434        }
1435        // d×p cross block (row-major) and length-d gradient.
1436        let b_base = i * d * shape.p;
1437        let g_base = i * d;
1438        for r in 0..d {
1439            for j in 0..shape.p {
1440                let feature = ((j % 257) as f64) * 0.011;
1441                row_cross_slabs[b_base + r * shape.p + j] =
1442                    1.0e-4 * (basis_sum + r as f64).sin() * feature.cos();
1443            }
1444            row_gradient_slabs[g_base + r] = 0.01 * (basis_sum + r as f64).sin();
1445        }
1446    }
1447    let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
1448    for r in 0..shape.p {
1449        border_hessian[r * shape.p + r] = 4.0;
1450        if r + 1 < shape.p {
1451            border_hessian[r * shape.p + r + 1] = 0.01;
1452            border_hessian[(r + 1) * shape.p + r] = 0.01;
1453        }
1454    }
1455    let mut border_gradient = vec![0.0_f64; shape.p];
1456    for j in 0..shape.p {
1457        border_gradient[j] = 0.001 * ((j % 193) as f64 * 0.017).sin();
1458    }
1459    DeviceResidentArrowWorkspace::new(
1460        shape,
1461        target_x,
1462        basis_values,
1463        gate_activations,
1464        DeviceResidentArrowSlabs {
1465            row_hessian_slabs,
1466            row_cross_slabs,
1467            row_gradient_slabs,
1468            border_hessian,
1469            border_gradient,
1470        },
1471    )
1472}
1473
1474/// One multiplexed resident fit: the workspace plus the inner-loop outcome.
1475pub struct MultiplexedFit {
1476    pub outcome: DeviceResidentInnerOutcome,
1477}
1478
1479/// Phase 4: run `workspaces.len()` independent device-resident inner fits that
1480/// share one device.
1481///
1482/// # Stream-multiplexing safety argument
1483///
1484/// Each fit calls [`DeviceResidentArrowWorkspace::device_fit`], whose per-row
1485/// arrow solve (`solve_arrow_newton_step`) acquires the **process-shared**
1486/// `Arc<CudaContext>` via `device_runtime::cuda_context_for` (a `Mutex`-guarded
1487/// `OnceLock` cache) and then creates its **own** `CudaStream` with its own
1488/// cuSOLVER/cuBLAS handles and its own device allocations. Distinct streams off
1489/// one shared context execute concurrently on the device; the only shared
1490/// mutable state — the context cache and cudarc's allocator — is internally
1491/// synchronised, and no two fits touch the same stream, handle, or buffer. So
1492/// independent fits are data-race-free and the device serialises only where the
1493/// hardware must (shared SMs / copy engines), which is exactly the throughput
1494/// multiplexing the issue's Phase 4 calls for.
1495///
1496/// Concurrency is driven through [`run_topology_race_parallel`] (bac4af426),
1497/// which already bounds nested Rayon so each fit's internal `par_iter`/faer
1498/// parallelism stays inside its per-fit thread budget rather than oversubscribing
1499/// the global pool. Results are returned in input order. A single A100 thus hosts
1500/// many color-/qwen-arm fits at once — the cross-fit batch where the 1e5–1e6×
1501/// race speedup materialises.
1502///
1503/// The GPU runtime singleton (`GpuRuntime::global`) and per-ordinal context
1504/// cache are warmed by constructing the resident workspaces (each `new` calls
1505/// the same probe), so the per-fit calls inside the Rayon scope only *read* the
1506/// already-initialised `OnceLock`s — they never trigger a `get_or_init` whose
1507/// closure does nested parallel work, avoiding the OnceLock×Rayon deadlock.
1508pub fn run_resident_fits_multiplexed(
1509    workspaces: Vec<DeviceResidentArrowWorkspace>,
1510    opts: DeviceResidentInnerOptions,
1511) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String> {
1512    run_resident_fits_multiplexed_with(workspaces, opts, |workspace, opts| {
1513        workspace.device_fit(opts)
1514    })
1515}
1516
1517/// Multiplexing core parameterised over the per-fit runner, so the CPU-reference
1518/// path can exercise the exact same `run_topology_race_parallel` plumbing as the
1519/// device path in tests that run without CUDA.
1520fn run_resident_fits_multiplexed_with<Run>(
1521    workspaces: Vec<DeviceResidentArrowWorkspace>,
1522    opts: DeviceResidentInnerOptions,
1523    run_one: Run,
1524) -> Result<Vec<Result<MultiplexedFit, DeviceResidentArrowError>>, String>
1525where
1526    Run: Fn(
1527            &DeviceResidentArrowWorkspace,
1528            &DeviceResidentInnerOptions,
1529        ) -> Result<DeviceResidentInnerOutcome, DeviceResidentArrowError>
1530        + Sync,
1531{
1532    let rows = crate::topology_selector::run_topology_race_parallel(
1533        workspaces,
1534        move |workspace: DeviceResidentArrowWorkspace| {
1535            run_one(&workspace, &opts).map(|outcome| MultiplexedFit { outcome })
1536        },
1537    )?;
1538    Ok(rows.into_iter().map(|row| row.result).collect())
1539}
1540
1541/// Sequential reference for the multiplexing parity harness: the same fits run
1542/// one after another on the same shared device. Multiplexed results must be
1543/// bit-identical to this because each fit's arithmetic is independent of the
1544/// others — sharing the device changes only scheduling, never the numbers.
1545pub fn run_resident_fits_sequential(
1546    workspaces: &[DeviceResidentArrowWorkspace],
1547    opts: &DeviceResidentInnerOptions,
1548) -> Vec<Result<MultiplexedFit, DeviceResidentArrowError>> {
1549    workspaces
1550        .iter()
1551        .map(|workspace| {
1552            workspace
1553                .device_fit(opts)
1554                .map(|outcome| MultiplexedFit { outcome })
1555        })
1556        .collect()
1557}
1558
1559// ---------------------------------------------------------------------------
1560// Phase 4 variant sweep (#1017): the OLMo research battery's independent-fit
1561// matrix (K × topology × basis × layer/checkpoint) dispatched concurrently on
1562// one device.
1563//
1564// Each variant is a SEPARATE fit with its OWN resident frame: the per-fit
// arithmetic is independent of the others, so multiplexing them onto one A100
1566// changes only scheduling, never the numbers. This is the cross-fit batch where
1567// the issue's 1e5–1e6× race throughput materialises — and unlike per-fit
1568// across-iteration residency it needs NO fixed-quadratic inner loop, because the
1569// parallelism is BETWEEN fits, not within one.
1570// ---------------------------------------------------------------------------
1571
1572/// One independent fit in the battery's variant sweep. The battery maps each
1573/// (K, topology, basis, layer, checkpoint, seed) cell of its matrix to a
1574/// `SweepVariant`; `dim` carries the resident-frame shape that cell produces
1575/// after the host assembles its row/border slabs. Distinct `seed`s keep the
1576/// fits genuinely independent (no shared device buffer, handle, or stream).
1577#[derive(Clone, Copy, Debug)]
1578pub struct SweepVariant {
1579    /// Resident-frame shape for this variant's frozen gate/basis frame.
1580    pub dim: DeviceResidentArrowShape,
1581    /// Deterministic seed for this variant's fixture/frame.
1582    pub seed: u64,
1583}
1584
/// Throughput summary for a variant sweep on one device.
///
/// This module's sweep helpers produce it both for the multiplexed dispatch and
/// for the sequential reference. Dividing the two `fits_per_second` values gives
/// the multiplex speedup.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SweepThroughput {
    /// Number of fits dispatched, successful or not.
    pub fits: usize,
    /// Number of fits that returned `Ok`.
    pub succeeded: usize,
    /// Wall-clock seconds spent running the fits. The sweep helpers in this
    /// module start the clock after the workspaces are built, so construction
    /// is excluded.
    pub wall_seconds: f64,
    /// Dispatched fits per wall-clock second on the single shared device.
    ///
    /// Computed as `fits / wall_seconds`, with the denominator floored at
    /// `1e-9` s. Failed fits are counted too; use `succeeded` for the success
    /// rate.
    pub fits_per_second: f64,
}
1594
1595/// Build the independent resident workspaces for a variant sweep. Each variant
1596/// gets its own well-conditioned `d == 2` frame (the host feeds real slabs in
1597/// production; here the deterministic fixture stands in for the parity/throughput
1598/// harness). Returns the workspaces in variant order.
1599pub fn build_sweep_workspaces(
1600    variants: &[SweepVariant],
1601) -> Result<Vec<DeviceResidentArrowWorkspace>, DeviceResidentArrowError> {
1602    variants
1603        .iter()
1604        .map(|v| fixture_for_shape_seeded(v.dim, v.seed))
1605        .collect()
1606}
1607
1608/// Dispatch a variant sweep concurrently on one device and measure cross-fit
1609/// throughput. Returns the per-variant outcomes (in variant order) and the
1610/// throughput summary (fits/sec on the single shared a100). Per-fit certified
1611/// parity is asserted by [`assert_sweep_parity_vs_sequential`].
1612pub fn run_variant_sweep_multiplexed(
1613    variants: &[SweepVariant],
1614    opts: DeviceResidentInnerOptions,
1615) -> Result<
1616    (
1617        Vec<Result<MultiplexedFit, DeviceResidentArrowError>>,
1618        SweepThroughput,
1619    ),
1620    String,
1621> {
1622    let workspaces = build_sweep_workspaces(variants).map_err(|e| e.to_string())?;
1623    run_battery_sweep_multiplexed(workspaces, opts)
1624}
1625
1626/// Production battery entry (#1017 Phase 4): dispatch CALLER-ASSEMBLED resident
1627/// workspaces concurrently on one device and measure cross-fit throughput.
1628///
1629/// This is the real-slab seam the OLMo battery uses: the host (pyffi) builds one
1630/// [`DeviceResidentArrowWorkspace`] per matrix cell from the cell's ACTUAL SAE
1631/// row_hessian/row_cross/border slabs via [`DeviceResidentArrowWorkspace::new`],
1632/// then hands the workspaces here. Unlike [`run_variant_sweep_multiplexed`]
1633/// (which builds frames from the deterministic harness fixture), this consumes
1634/// real frames, so the printed throughput is the battery's true fits/sec on one
1635/// device. Returns per-cell outcomes (in input order) + the throughput summary.
1636pub fn run_battery_sweep_multiplexed(
1637    workspaces: Vec<DeviceResidentArrowWorkspace>,
1638    opts: DeviceResidentInnerOptions,
1639) -> Result<
1640    (
1641        Vec<Result<MultiplexedFit, DeviceResidentArrowError>>,
1642        SweepThroughput,
1643    ),
1644    String,
1645> {
1646    let fits = workspaces.len();
1647    let start = std::time::Instant::now();
1648    let results = run_resident_fits_multiplexed(workspaces, opts)?;
1649    let wall_seconds = start.elapsed().as_secs_f64();
1650    let succeeded = results.iter().filter(|r| r.is_ok()).count();
1651    let throughput = SweepThroughput {
1652        fits,
1653        succeeded,
1654        wall_seconds,
1655        fits_per_second: (fits as f64) / wall_seconds.max(1e-9),
1656    };
1657    Ok((results, throughput))
1658}
1659
1660/// The OLMo battery's full color-arm variant matrix as [`SweepVariant`]s:
1661/// `K{1..=4} × topology{4} × basis{periodic, linear}` at the color-arm shape
1662/// (n=180, p=5120). `d` and `basis_cols` follow the intrinsic-rank convention
1663/// (periodic ⇒ d=2, basis_cols=8; linear ⇒ d=1, basis_cols=2). Exposed so the
1664/// pyffi battery seam can quote cross-fit throughput on the real shape matrix
1665/// (fixture frames) before the per-cell real-slab fits are wired through.
1666#[must_use]
1667pub fn color_arm_variant_matrix() -> Vec<SweepVariant> {
1668    let topologies = ["euclidean", "circle", "torus", "sphere"];
1669    let mut variants = Vec::with_capacity(4 * topologies.len() * 2);
1670    for k in 1..=4u64 {
1671        for (t_idx, _topology) in topologies.iter().enumerate() {
1672            // periodic (2 harmonics) and linear basis arms.
1673            for &(d, basis_cols, basis_tag) in &[(2usize, 8usize, 0u64), (1usize, 2usize, 1u64)] {
1674                let mut dim = DeviceResidentArrowShape::color_arm();
1675                dim.d = d;
1676                dim.basis_cols = basis_cols;
1677                let seed = 0x1017_C010_0000_0000 ^ (k << 16) ^ ((t_idx as u64) << 8) ^ basis_tag;
1678                variants.push(SweepVariant { dim, seed });
1679            }
1680        }
1681    }
1682    variants
1683}
1684
1685/// Certified per-fit parity for a variant sweep: the multiplexed (concurrent)
1686/// results must be bit-for-bit identical to the same fits run sequentially on
1687/// the same device, because independent fits' arithmetic does not depend on
1688/// scheduling. Returns the sequential throughput so the caller can report the
1689/// multiplex speedup (multiplexed fits/sec ÷ sequential fits/sec). Returns an
1690/// `Err` describing the first divergence so the harness fails loudly.
1691pub fn assert_sweep_parity_vs_sequential(
1692    variants: &[SweepVariant],
1693    opts: &DeviceResidentInnerOptions,
1694    multiplexed: &[Result<MultiplexedFit, DeviceResidentArrowError>],
1695) -> Result<SweepThroughput, String> {
1696    let workspaces = build_sweep_workspaces(variants).map_err(|e| e.to_string())?;
1697    let start = std::time::Instant::now();
1698    let sequential = run_resident_fits_sequential(&workspaces, opts);
1699    let wall_seconds = start.elapsed().as_secs_f64();
1700    if sequential.len() != multiplexed.len() {
1701        return Err(format!(
1702            "sweep parity: length mismatch seq={} mux={}",
1703            sequential.len(),
1704            multiplexed.len()
1705        ));
1706    }
1707    for (idx, (seq, mux)) in sequential.iter().zip(multiplexed.iter()).enumerate() {
1708        match (seq, mux) {
1709            (Ok(s), Ok(m)) => {
1710                if s.outcome.t.as_slice() != m.outcome.t.as_slice()
1711                    || s.outcome.beta.as_slice() != m.outcome.beta.as_slice()
1712                    || s.outcome.objective.to_bits() != m.outcome.objective.to_bits()
1713                {
1714                    return Err(format!(
1715                        "sweep parity: fit {idx} multiplexed result differs from sequential"
1716                    ));
1717                }
1718            }
1719            (Err(_), Err(_)) => {}
1720            _ => {
1721                return Err(format!(
1722                    "sweep parity: fit {idx} success/failure disagrees seq-vs-mux"
1723                ));
1724            }
1725        }
1726    }
1727    let fits = variants.len();
1728    let succeeded = sequential.iter().filter(|r| r.is_ok()).count();
1729    Ok(SweepThroughput {
1730        fits,
1731        succeeded,
1732        wall_seconds,
1733        fits_per_second: (fits as f64) / wall_seconds.max(1e-9),
1734    })
1735}
1736
1737struct SplitMix64 {
1738    state: u64,
1739}
1740
1741impl SplitMix64 {
1742    const fn new(seed: u64) -> Self {
1743        Self { state: seed }
1744    }
1745
1746    fn next_u64(&mut self) -> u64 {
1747        gam_linalg::utils::splitmix64(&mut self.state)
1748    }
1749
1750    fn sample_signed(&mut self) -> f64 {
1751        let unit = (self.next_u64() >> 11) as f64 / ((1_u64 << 53) as f64);
1752        2.0 * unit - 1.0
1753    }
1754}
1755
1756#[cfg(test)]
1757mod tests {
1758    use super::*;
1759    use ndarray::Array2;
1760
    /// Build a small, strongly diagonally-dominant resident frame whose dense
    /// reference factorisation is well-conditioned. The objective minimiser is
    /// `z* = H^{-1} g₀`, which the inner loop must reach.
    ///
    /// Every random value comes from one `SplitMix64` stream seeded with
    /// `seed`, so the fixture is a pure function of `seed`. The draw order
    /// (all per-row blocks first, then the border) is part of that contract:
    /// reordering the draws below changes every fixture.
    ///
    /// The per-row writes hard-code `d == 2`: four `H_tt` entries, two cross
    /// rows and two gradient entries per row. Changing `shape.d` requires
    /// rewriting the loop.
    fn small_fixture(seed: u64) -> DeviceResidentArrowWorkspace {
        // batch (n) = 8 clears the device dispatch floor
        // (`small_dense_batched_potrf_min_batch = 8`) so that on a CUDA host
        // `upload_resident_buffers` actually binds a device and
        // `device_resident()` is TRUE — otherwise the device-resident parity
        // branch of `device_resident_fit_matches_cpu_reference` is dead on real
        // GPU hardware (the route declines for batch < 8, so the test only ever
        // exercised the CPU-decline branch and never validated the device loop).
        let shape = DeviceResidentArrowShape {
            n: 8,
            p: 4,
            basis_cols: 2,
            d: 2,
        };
        let mut rng = SplitMix64::new(seed);
        // Zero targets, constant basis values (0.5) and fully-open gates (1.0).
        // The row/border slabs below are drawn directly from `rng`, not derived
        // from these vectors.
        let target_x = vec![0.0_f64; shape.target_len()];
        let basis_values = vec![0.5_f64; shape.basis_len()];
        let gate_activations = vec![1.0_f64; shape.basis_len()];

        let mut row_hessian_slabs = vec![0.0_f64; shape.row_hessian_len()];
        let mut row_cross_slabs = vec![0.0_f64; shape.row_cross_len()];
        let mut row_gradient_slabs = vec![0.0_f64; shape.row_gradient_len()];
        for i in 0..shape.n {
            // Row-major 2×2 H_tt: diagonals ≈ 5 and ≈ 4 with a symmetric
            // off-diagonal of magnitude < 0.05, so each block is diagonally
            // dominant (hence PD).
            let h = i * shape.d * shape.d;
            row_hessian_slabs[h] = 5.0 + 0.1 * rng.sample_signed();
            row_hessian_slabs[h + 1] = 0.05 * rng.sample_signed();
            row_hessian_slabs[h + 2] = row_hessian_slabs[h + 1];
            row_hessian_slabs[h + 3] = 4.0 + 0.1 * rng.sample_signed();
            // Row-major d×p cross block with entries in [-0.01, 0.01). For each
            // column j the draws alternate between row 0 and row 1.
            let b = i * shape.d * shape.p;
            for j in 0..shape.p {
                row_cross_slabs[b + j] = 0.01 * rng.sample_signed();
                row_cross_slabs[b + shape.p + j] = 0.01 * rng.sample_signed();
            }
            // Per-row gradient g_t, entries in [-1, 1).
            let g = i * shape.d;
            row_gradient_slabs[g] = rng.sample_signed();
            row_gradient_slabs[g + 1] = rng.sample_signed();
        }
        // Diagonal border Hessian H_ββ with entries ≈ 6.
        let mut border_hessian = vec![0.0_f64; shape.border_hessian_len()];
        for r in 0..shape.p {
            border_hessian[r * shape.p + r] = 6.0 + 0.1 * rng.sample_signed();
        }
        // Border gradient g_β, entries in [-1, 1), drawn after all row data.
        let border_gradient: Vec<f64> = (0..shape.p).map(|_| rng.sample_signed()).collect();

        DeviceResidentArrowWorkspace::new(
            shape,
            target_x,
            basis_values,
            gate_activations,
            DeviceResidentArrowSlabs {
                row_hessian_slabs,
                row_cross_slabs,
                row_gradient_slabs,
                border_hessian,
                border_gradient,
            },
        )
        .expect("small resident fixture must validate")
    }
1822
1823    /// Dense `H z` for the resident frame (independent of the arrow path),
1824    /// used to confirm the inner-loop fixed point is the true stationary point.
1825    fn dense_hz(
1826        ws: &DeviceResidentArrowWorkspace,
1827        sys: &ArrowSchurSystem,
1828    ) -> (Array2<f64>, Array1<f64>) {
1829        let shape = ws.shape;
1830        let total = shape.n * shape.d + shape.p;
1831        let mut h = Array2::<f64>::zeros((total, total));
1832        let mut g0 = Array1::<f64>::zeros(total);
1833        for i in 0..shape.n {
1834            let base = i * shape.d;
1835            for r in 0..shape.d {
1836                for c in 0..shape.d {
1837                    h[[base + r, base + c]] = sys.rows[i].htt[[r, c]];
1838                }
1839                for c in 0..shape.p {
1840                    let v = sys.rows[i].htbeta[[r, c]];
1841                    h[[base + r, shape.n * shape.d + c]] = v;
1842                    h[[shape.n * shape.d + c, base + r]] = v;
1843                }
1844                g0[base + r] = sys.rows[i].gt[r];
1845            }
1846        }
1847        for r in 0..shape.p {
1848            for c in 0..shape.p {
1849                h[[shape.n * shape.d + r, shape.n * shape.d + c]] = sys.hbb[[r, c]];
1850            }
1851            g0[shape.n * shape.d + r] = sys.gb[r];
1852        }
1853        (h, g0)
1854    }
1855
1856    #[test]
1857    fn cpu_inner_loop_reaches_quadratic_minimiser() {
1858        let ws = small_fixture(0xABCD_0001);
1859        let opts = DeviceResidentInnerOptions::default();
1860        let outcome = ws.cpu_reference_fit(&opts).expect("cpu fit");
1861        assert!(
1862            outcome.converged,
1863            "inner loop must converge on a PD quadratic"
1864        );
1865
1866        // The stationary point satisfies H z* = g₀; verify the residual is zero.
1867        let base = ws.to_arrow_system();
1868        let (h, g0) = dense_hz(&ws, &base);
1869        let total = ws.shape.n * ws.shape.d + ws.shape.p;
1870        let mut z = Array1::<f64>::zeros(total);
1871        for r in 0..ws.shape.n * ws.shape.d {
1872            z[r] = outcome.t[r];
1873        }
1874        for c in 0..ws.shape.p {
1875            z[ws.shape.n * ws.shape.d + c] = outcome.beta[c];
1876        }
1877        let hz = h.dot(&z);
1878        let mut max_resid = 0.0_f64;
1879        for r in 0..total {
1880            max_resid = max_resid.max((hz[r] - g0[r]).abs());
1881        }
1882        assert!(
1883            max_resid < 1e-9,
1884            "inner loop fixed point must solve H z = g0; residual {max_resid:e}"
1885        );
1886    }
1887
1888    #[test]
1889    fn cpu_multiplex_matches_sequential_bit_identical() {
1890        let seeds = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
1891        let opts = DeviceResidentInnerOptions::default();
1892
1893        let seq_workspaces: Vec<_> = seeds.iter().map(|&s| small_fixture(s)).collect();
1894        let sequential: Vec<_> = seq_workspaces
1895            .iter()
1896            .map(|ws| ws.cpu_reference_fit(&opts).expect("seq cpu fit"))
1897            .collect();
1898
1899        let mux_workspaces: Vec<_> = seeds.iter().map(|&s| small_fixture(s)).collect();
1900        let multiplexed = run_resident_fits_multiplexed_with(mux_workspaces, opts, |ws, opts| {
1901            ws.cpu_reference_fit(opts)
1902        })
1903        .expect("multiplexed cpu fits");
1904
1905        assert_eq!(sequential.len(), multiplexed.len());
1906        for (seq, mux) in sequential.iter().zip(multiplexed.iter()) {
1907            let mux = mux.as_ref().expect("mux fit ok");
1908            // Independent fits: scheduling cannot change the numbers, so the
1909            // parallel result must be bit-for-bit identical to sequential.
1910            assert_eq!(seq.t.as_slice(), mux.outcome.t.as_slice());
1911            assert_eq!(seq.beta.as_slice(), mux.outcome.beta.as_slice());
1912            assert_eq!(seq.objective.to_bits(), mux.outcome.objective.to_bits());
1913        }
1914    }
1915
1916    /// #1017 Phase 3 residency parity. On a CUDA host the device-resident inner
1917    /// loop (`device_fit`, which keeps the Hessian factors on-device across
1918    /// iterations via `ResidentArrowFrameHandle`) must reach the same minimiser
1919    /// as the fully independent CPU dense-reference loop (`cpu_reference_fit`,
1920    /// which re-factors per iterate). On a CPU-only host the resident path must
1921    /// decline cleanly (`Unavailable`) rather than silently disagree, and the
1922    /// resident-frame handle construction must likewise decline — so the gate is
1923    /// meaningful on the build box and the wall-clock arm runs on the GPU node.
1924    #[test]
1925    fn device_resident_fit_matches_cpu_reference() {
1926        let ws = small_fixture(0x5AE_1017);
1927        let opts = DeviceResidentInnerOptions::default();
1928
1929        // CPU reference (re-factors per iterate) — always available.
1930        let cpu = ws.cpu_reference_fit(&opts).expect("cpu reference fit");
1931        assert!(cpu.converged, "cpu reference must converge on PD quadratic");
1932
1933        let base = ws.to_arrow_system();
1934
1935        if ws.device_resident() {
1936            // Resident device loop: factors stay on-device across iterations.
1937            let dev = ws.device_fit(&opts).expect("device resident fit");
1938            assert_eq!(
1939                dev.execution_path,
1940                ExecutionPath::GpuResidentFull,
1941                "device_fit must report the full device-resident execution path"
1942            );
1943            assert!(dev.converged, "device resident loop must converge");
1944
1945            // Certified-refinement parity (#1014): the resident path and the
1946            // independent CPU path solve the same quadratic, so their minimisers
1947            // agree to a tight relative tolerance. The resident path differs from
1948            // the reference only by SKIPPING re-derivation of g-independent
1949            // factor work, not by changing the arithmetic.
1950            let t_scale = cpu.t.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
1951            let b_scale = cpu.beta.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
1952            let mut max_rel = 0.0_f64;
1953            for (a, b) in dev.t.iter().zip(cpu.t.iter()) {
1954                max_rel = max_rel.max((a - b).abs() / t_scale);
1955            }
1956            for (a, b) in dev.beta.iter().zip(cpu.beta.iter()) {
1957                max_rel = max_rel.max((a - b).abs() / b_scale);
1958            }
1959            assert!(
1960                max_rel < 1e-9,
1961                "resident device fit must match CPU reference (rel {max_rel:e})"
1962            );
1963
1964            // The resident frame's single-gradient solve must also match a full
1965            // independent solve at the same gradient (the per-iterate contract).
1966            let frame = crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
1967                &base,
1968                opts.initial_ridge_t,
1969                opts.initial_ridge_beta,
1970            )
1971            .expect("resident frame must build on CUDA host");
1972            let g_t: Vec<f64> = base
1973                .rows
1974                .iter()
1975                .flat_map(|r| r.gt.iter().copied())
1976                .collect();
1977            let g_beta: Vec<f64> = base.gb.iter().copied().collect();
1978            let resident_sol = frame
1979                .solve_gradient(&g_t, &g_beta)
1980                .expect("resident single-gradient solve");
1981            let full = crate::gpu_kernels::arrow_schur::solve_arrow_newton_step_dense_reference(
1982                &base,
1983                opts.initial_ridge_t,
1984                opts.initial_ridge_beta,
1985            )
1986            .expect("dense reference single solve");
1987            let mut max_step_rel = 0.0_f64;
1988            let step_scale = full
1989                .delta_t
1990                .iter()
1991                .chain(full.delta_beta.iter())
1992                .fold(1.0_f64, |m, &v| m.max(v.abs()));
1993            for (a, b) in resident_sol.delta_t.iter().zip(full.delta_t.iter()) {
1994                max_step_rel = max_step_rel.max((a - b).abs() / step_scale);
1995            }
1996            for (a, b) in resident_sol.delta_beta.iter().zip(full.delta_beta.iter()) {
1997                max_step_rel = max_step_rel.max((a - b).abs() / step_scale);
1998            }
1999            assert!(
2000                max_step_rel < 1e-9,
2001                "resident solve_gradient must match full dense reference step (rel {max_step_rel:e})"
2002            );
2003
2004            // The re-uploading GPU loop (residency baseline) must reach the same
2005            // minimiser as both the resident loop and the CPU reference.
2006            let reup = ws
2007                .device_reupload_fit(&opts)
2008                .expect("device re-uploading fit");
2009            assert_eq!(
2010                reup.execution_path,
2011                ExecutionPath::GpuReupload,
2012                "device_reupload_fit must report the re-uploading device path"
2013            );
2014            assert!(reup.converged, "re-uploading loop must converge");
2015            let mut max_reup_rel = 0.0_f64;
2016            for (a, b) in reup.t.iter().zip(cpu.t.iter()) {
2017                max_reup_rel = max_reup_rel.max((a - b).abs() / t_scale);
2018            }
2019            for (a, b) in reup.beta.iter().zip(cpu.beta.iter()) {
2020                max_reup_rel = max_reup_rel.max((a - b).abs() / b_scale);
2021            }
2022            assert!(
2023                max_reup_rel < 1e-9,
2024                "re-uploading GPU fit must match CPU reference (rel {max_reup_rel:e})"
2025            );
2026        } else {
2027            // The fixture is sized (batch = 8) to clear the device dispatch floor,
2028            // so on a host WITH a CUDA runtime `device_resident()` must be true and
2029            // we take the device branch above. Reaching this branch with a runtime
2030            // present means the device binding silently failed — which would mask a
2031            // real upload/dispatch fault behind the CPU-decline path (the
2032            // device-PCG skip-pass class, eee12f6b2). Fail loud unless this is a
2033            // genuinely CPU-only host.
2034            assert!(
2035                gam_gpu::device_runtime::GpuRuntime::global().is_none(),
2036                "device_resident() is false on a host WITH a CUDA runtime present, \
2037                 despite a floor-clearing fixture (batch=8): the resident device \
2038                 buffers failed to bind — a real device fault, not a CPU-only skip."
2039            );
2040            // CPU-only host: the resident path must decline, not disagree.
2041            let dev = ws.device_fit(&opts);
2042            assert!(
2043                matches!(dev, Err(DeviceResidentArrowError::Unavailable { .. })),
2044                "device_fit must report Unavailable on a CPU-only host, got {dev:?}"
2045            );
2046            let reup = ws.device_reupload_fit(&opts);
2047            assert!(
2048                matches!(reup, Err(DeviceResidentArrowError::Unavailable { .. })),
2049                "device_reupload_fit must report Unavailable on a CPU-only host, got {reup:?}"
2050            );
2051            let frame = crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(
2052                &base,
2053                opts.initial_ridge_t,
2054                opts.initial_ridge_beta,
2055            );
2056            assert!(
2057                frame.is_err(),
2058                "resident frame construction must decline on a CPU-only host"
2059            );
2060        }
2061    }
2062
2063    /// #1017 fit-path parity (CPU-runnable). The resident inner solve and the
2064    /// PRODUCTION arrow-Schur inner solve (`solve_arrow_newton_step_core`, the
2065    /// entry the SAE joint fit reaches through `solve_with_lm_escalation_inner`)
2066    /// must solve the SAME bordered-quadratic Newton system.
2067    ///
2068    /// This is the cross-implementation parity behind wiring the device seam into
2069    /// the SAE inner loop: `solve_arrow_newton_step_core` carries the #1017
2070    /// device-Schur seam (and falls through bit-identically to its CPU path off
2071    /// CUDA), and the resident workspace's `cpu_reference_fit` converges the same
2072    /// quadratic `φ(z) = ½‖X‖² + ½ zᵀH z − g₀ᵀ z`. The resident converged iterate
2073    /// `z*` is the stationary point `H z* = g₀`; the production arrow path solves
2074    /// the Newton system `H Δ = −g₀` from `z = 0`, so its step is
2075    /// `Δ = −H⁻¹ g₀ = −z*`. With `H` PD the exact relationship is therefore
2076    /// `Δ = −z*`; asserting it pins that routing the production inner solve through
2077    /// the device-aware `_core` (which a GPU host then offloads) solves the
2078    /// identical system the resident loop does. Runs on the CPU build box — no
2079    /// CUDA required.
2080    #[test]
2081    fn resident_inner_solve_matches_production_arrow_core() {
2082        use crate::arrow_schur::{ArrowSolveOptions, solve_arrow_newton_step_core};
2083
2084        let ws = small_fixture(0x1017_F17);
2085        let opts = DeviceResidentInnerOptions::default();
2086
2087        // Resident workspace converged fit (re-factoring CPU reference loop).
2088        let resident = ws.cpu_reference_fit(&opts).expect("resident cpu fit");
2089        assert!(
2090            resident.converged,
2091            "resident reference must converge on the PD quadratic"
2092        );
2093
2094        // Production arrow path: one Newton step on the same system from z = 0.
2095        // `_core` is the device-aware entry; on this CPU box it runs the dense
2096        // CPU solve, the exact path the GPU host would fall back to on decline.
2097        let sys = ws.to_arrow_system();
2098        let (delta_t, delta_beta, _diag) = solve_arrow_newton_step_core(
2099            &sys,
2100            opts.initial_ridge_t,
2101            opts.initial_ridge_beta,
2102            &ArrowSolveOptions::direct(),
2103        )
2104        .expect("production arrow-core solve");
2105
2106        // The Newton step from z = 0 is Δ = −H⁻¹g₀ = −z*, where z* is the resident
2107        // converged iterate (H z* = g₀, the invariant
2108        // `cpu_inner_loop_reaches_quadratic_minimiser` pins directly). With H PD
2109        // the relationship is exact, so Δ + z* = 0 to factorisation tolerance.
2110        let t_scale = resident.t.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
2111        let b_scale = resident.beta.iter().fold(1.0_f64, |m, &v| m.max(v.abs()));
2112        // #1399: report the t-block and beta-block mismatch SEPARATELY (not one
2113        // fused scalar). The two halves localise a divergence: a t-block-only gap
2114        // points at the per-row factor / row gradient assembly, a beta-block gap
2115        // at the border Schur path — turning the opaque overall rel into an
2116        // actionable signal for the resident-vs-production parity divergence.
2117        let mut max_rel_t = 0.0_f64;
2118        let mut worst_t: Option<(usize, f64, f64)> = None;
2119        for (i, (prod, res)) in delta_t.iter().zip(resident.t.iter()).enumerate() {
2120            let rel = (prod + res).abs() / t_scale;
2121            if rel > max_rel_t {
2122                max_rel_t = rel;
2123                worst_t = Some((i, *prod, *res));
2124            }
2125        }
2126        let mut max_rel_b = 0.0_f64;
2127        let mut worst_b: Option<(usize, f64, f64)> = None;
2128        for (i, (prod, res)) in delta_beta.iter().zip(resident.beta.iter()).enumerate() {
2129            let rel = (prod + res).abs() / b_scale;
2130            if rel > max_rel_b {
2131                max_rel_b = rel;
2132                worst_b = Some((i, *prod, *res));
2133            }
2134        }
2135        let max_rel = max_rel_t.max(max_rel_b);
2136        assert!(
2137            max_rel < 1e-9,
2138            "production arrow-core Newton step must be −(resident converged fit) on \
2139             the same quadratic; wiring the device seam into the SAE inner loop must \
2140             not change the system being solved. rel_t={max_rel_t:e} (worst {worst_t:?}: \
2141             Δt+t* must be 0), rel_beta={max_rel_b:e} (worst {worst_b:?}: Δβ+β* must \
2142             be 0). A t-only gap implicates the per-row factor / row-gradient \
2143             assembly; a β-only gap the border Schur path."
2144        );
2145    }
2146
2147    /// #1017 deliverable 3: across-OUTER residency. A sequence of outer
2148    /// evaluations whose Hessian operator is unchanged (only the base gradient
2149    /// moves) must share ONE resident frame — exactly one frame build for the
2150    /// whole sweep — and produce results bit-identical to per-outer-independent
2151    /// fits (each with a fresh frame). On a CPU-only host this asserts the
2152    /// reference path's outer-sequence wiring is consistent; on the A100 it
2153    /// proves the across-outer factor amortization fires AND stays exact.
2154    #[test]
2155    fn outer_sequence_reuses_frame_and_matches_independent() {
2156        let ws = super::color_arm_fixture().expect("color_arm fixture");
2157        let opts = DeviceResidentInnerOptions::default();
2158        let n = ws.shape.n;
2159        let d = ws.shape.d;
2160        let p = ws.shape.p;
2161
2162        // Three "outer" evaluations: same operator, distinct base gradients (the
2163        // moving linearization point). These stand in for consecutive outer REML
2164        // evaluations at a frozen gate/basis frame.
2165        let outers: Vec<(Vec<f64>, Vec<f64>)> = (0..3)
2166            .map(|s| {
2167                let g_t: Vec<f64> = (0..n * d)
2168                    .map(|i| 0.01 * (((i + 3 * s) as f64) * 0.002).sin())
2169                    .collect();
2170                let g_beta: Vec<f64> = (0..p)
2171                    .map(|j| 0.001 * (((j + 11 * s) as f64) * 0.0009).cos())
2172                    .collect();
2173                (g_t, g_beta)
2174            })
2175            .collect();
2176
2177        // Per-outer-independent reference (fresh frame each outer) via the CPU
2178        // path, which runs on any host.
2179        let independent = ws
2180            .cpu_reference_outer_sequence(&outers, &opts)
2181            .expect("cpu reference outer sequence");
2182        assert_eq!(independent.outers.len(), outers.len());
2183
2184        if ws.device_resident() {
2185            // Device across-outer sweep: ONE frame for all three outers.
2186            let shared = ws
2187                .device_fit_outer_sequence(&outers, &opts)
2188                .expect("device outer sequence");
2189            assert_eq!(
2190                shared.frame_builds,
2191                1,
2192                "across-outer residency must build the resident frame exactly once \
2193                 for an unchanged operator (got {} builds over {} outers) — a count \
2194                 > 1 means the frame was needlessly re-factored per outer",
2195                shared.frame_builds,
2196                outers.len()
2197            );
2198            // Bit-parity: sharing the factor across outers must not change the
2199            // numbers vs per-outer-independent device fits.
2200            for (idx, (sh, ind)) in shared
2201                .outers
2202                .iter()
2203                .zip(independent.outers.iter())
2204                .enumerate()
2205            {
2206                let scale = ind
2207                    .t
2208                    .iter()
2209                    .chain(ind.beta.iter())
2210                    .fold(1.0_f64, |m, &v| m.max(v.abs()));
2211                let mut max_rel = 0.0_f64;
2212                for (a, b) in sh.t.iter().zip(ind.t.iter()) {
2213                    max_rel = max_rel.max((a - b).abs() / scale);
2214                }
2215                for (a, b) in sh.beta.iter().zip(ind.beta.iter()) {
2216                    max_rel = max_rel.max((a - b).abs() / scale);
2217                }
2218                assert!(
2219                    max_rel < 1e-9,
2220                    "outer {idx}: across-outer-shared frame must match independent fit \
2221                     (rel {max_rel:e})"
2222                );
2223            }
2224            println!(
2225                "[#1017 outer-seq color_arm] outers={} frame_builds={} (across-outer factor \
2226                 amortized) parity<1e-9 OK",
2227                outers.len(),
2228                shared.frame_builds
2229            );
2230        } else {
2231            println!(
2232                "[#1017 outer-seq color_arm] no CUDA device — across-outer residency skipped; \
2233                 run on the GPU node to assert frame_builds==1 + device parity"
2234            );
2235        }
2236    }
2237
2238    /// #1017 residency-isolating per-solve bench. A full-fit wall-clock bench
2239    /// runs an exact quadratic that converges
2240    /// in ONE Newton step, so the resident frame is built once and solved once —
2241    /// the across-iteration amortization (factor `D`/`B`/Schur once, reuse for
2242    /// every gradient) has nothing to amortize over and the measured speedup is
2243    /// only the single-solve `D`/`B` upload saving.
2244    ///
2245    /// This bench isolates the residency lever the way the production inner loop
2246    /// actually exercises it: at a frozen gate/basis frame the Hessian blocks are
2247    /// CONSTANT and the SAE inner Newton takes MANY gradient solves against them.
2248    /// It therefore times
2249    ///   * RESIDENT: build the [`crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle`]
2250    ///     ONCE, then `N` `solve_gradient` calls (upload only the `O(n·d + k)`
2251    ///     gradient per solve; no POTRF, no `D`/`B` re-upload);
2252    ///   * REUPLOAD: `N` `solve_arrow_newton_step` calls (re-pack/upload `D`/`B`/`g`
2253    ///     and re-run the per-row POTRF + border Schur factor every call).
2254    /// Both produce bit-identical steps; the ratio is the pure across-iteration
2255    /// residency speedup, which is what #1017 Phase 3 buys per inner iteration.
2256    /// `N` mirrors a realistic SAE inner-Newton iteration count. CPU-only hosts
2257    /// print a skip line. Run with `--nocapture`.
2258    #[test]
2259    fn gpu_residency_per_solve_bench() {
2260        use std::time::Instant;
2261        const N_SOLVES: usize = 24;
2262        for (label, ws) in [
2263            ("color_arm", super::color_arm_fixture()),
2264            ("qwen_non_gating", super::qwen_non_gating_fixture()),
2265        ] {
2266            let ws = ws.expect("bench fixture must validate");
2267            let base = ws.to_arrow_system();
2268            // A family of distinct gradients standing in for the per-iterate
2269            // residual r(z) = H z − g₀ the inner loop feeds. Distinct gradients
2270            // make the reupload path redo the (g-independent) factor work each
2271            // time — exactly the waste residency removes.
2272            let n = ws.shape.n;
2273            let d = ws.shape.d;
2274            let p = ws.shape.p;
2275            let gradients: Vec<(Vec<f64>, Vec<f64>)> = (0..N_SOLVES)
2276                .map(|s| {
2277                    let g_t: Vec<f64> =
2278                        (0..n * d).map(|i| ((i + s) as f64 * 0.001).sin()).collect();
2279                    let g_beta: Vec<f64> = (0..p)
2280                        .map(|j| ((j + 7 * s) as f64 * 0.0007).cos())
2281                        .collect();
2282                    (g_t, g_beta)
2283                })
2284                .collect();
2285
2286            if !ws.device_resident() {
2287                println!(
2288                    "[#1017 per-solve {label}] no CUDA device — {N_SOLVES} solves skipped; \
2289                     run on the GPU node for the across-iteration residency speedup"
2290                );
2291                continue;
2292            }
2293
2294            // Build the resident frame ONCE (its factor cost is the across-
2295            // iteration amortization the bench is measuring, so it is timed
2296            // separately from the per-solve loop).
2297            let t_build = Instant::now();
2298            let frame =
2299                crate::gpu_kernels::arrow_schur::ResidentArrowFrameHandle::new(&base, 0.0, 0.0)
2300                    .expect("resident frame must build on CUDA host");
2301            let frame_build_ms = t_build.elapsed().as_secs_f64() * 1e3;
2302
2303            // Warm-up: one solve on each path before timing so the residency
2304            // ratio reflects steady-state per-iterate cost, not the one-time
2305            // NVRTC/cuSOLVER handle init, module JIT, or first-touch device
2306            // allocation (those are paid once per process, not per inner
2307            // iteration). The production inner loop pays them once and then runs
2308            // MANY solves, which is exactly the regime this assertion guards.
2309            let _ = frame
2310                .solve_gradient(&gradients[0].0, &gradients[0].1)
2311                .expect("resident warm-up solve");
2312            {
2313                let mut sys = ws.to_arrow_system();
2314                for (i, row) in sys.rows.iter_mut().enumerate() {
2315                    for r in 0..d {
2316                        row.gt[r] = gradients[0].0[i * d + r];
2317                    }
2318                }
2319                for (j, gb) in sys.gb.iter_mut().enumerate() {
2320                    *gb = gradients[0].1[j];
2321                }
2322                sys.refresh_row_hessian_fingerprint();
2323                let _ = crate::gpu_kernels::arrow_schur::solve_arrow_newton_step(&sys, 0.0, 0.0)
2324                    .expect("reupload warm-up solve");
2325            }
2326
2327            // RESIDENT: reuse the (already-built, already-warmed) frame for N
2328            // gradient-only solves. Times ONLY the per-iterate gradient solves —
2329            // upload `O(n·d + k)` gradient, run the cheap residual path, read
2330            // back `δ`. No POTRF, no `D`/`B` re-upload.
2331            let t_res = Instant::now();
2332            let mut resident_steps = Vec::with_capacity(N_SOLVES);
2333            for (g_t, g_beta) in &gradients {
2334                resident_steps.push(
2335                    frame
2336                        .solve_gradient(g_t, g_beta)
2337                        .expect("resident solve_gradient"),
2338                );
2339            }
2340            let resident_ms = t_res.elapsed().as_secs_f64() * 1e3;
2341
2342            // REUPLOAD: N full solves, each re-uploading D/B/g and re-factoring.
2343            let t_reup = Instant::now();
2344            let mut reupload_steps = Vec::with_capacity(N_SOLVES);
2345            for (g_t, g_beta) in &gradients {
2346                let mut sys = ws.to_arrow_system();
2347                for (i, row) in sys.rows.iter_mut().enumerate() {
2348                    for r in 0..d {
2349                        row.gt[r] = g_t[i * d + r];
2350                    }
2351                }
2352                for (j, gb) in sys.gb.iter_mut().enumerate() {
2353                    *gb = g_beta[j];
2354                }
2355                sys.refresh_row_hessian_fingerprint();
2356                reupload_steps.push(
2357                    crate::gpu_kernels::arrow_schur::solve_arrow_newton_step(&sys, 0.0, 0.0)
2358                        .expect("reupload solve_arrow_newton_step"),
2359                );
2360            }
2361            let reupload_ms = t_reup.elapsed().as_secs_f64() * 1e3;
2362
2363            // Parity: resident and reupload steps must be bit-identical (same
2364            // factor kernels; residency only skips re-deriving g-independent work).
2365            let mut max_rel = 0.0_f64;
2366            for (rs, us) in resident_steps.iter().zip(reupload_steps.iter()) {
2367                let scale = us
2368                    .delta_t
2369                    .iter()
2370                    .chain(us.delta_beta.iter())
2371                    .fold(1.0_f64, |m, &v| m.max(v.abs()));
2372                for (a, b) in rs.delta_t.iter().zip(us.delta_t.iter()) {
2373                    max_rel = max_rel.max((a - b).abs() / scale);
2374                }
2375                for (a, b) in rs.delta_beta.iter().zip(us.delta_beta.iter()) {
2376                    max_rel = max_rel.max((a - b).abs() / scale);
2377                }
2378            }
2379
2380            let resident_per_solve = resident_ms / N_SOLVES as f64;
2381            let reupload_per_solve = reupload_ms / N_SOLVES as f64;
2382            let residency_speedup = reupload_ms / resident_ms.max(1e-9);
2383            println!(
2384                "[#1017 per-solve {label}] N={N_SOLVES} frame_build={frame_build_ms:.2}ms \
2385                 resident={resident_ms:.2}ms ({resident_per_solve:.3}ms/solve, \
2386                 grad-upload + warm factors) reupload={reupload_ms:.2}ms \
2387                 ({reupload_per_solve:.3}ms/solve, N factors + N D/B uploads) \
2388                 residency_speedup={residency_speedup:.2}x parity_rel={max_rel:e}"
2389            );
2390            assert!(
2391                max_rel < 1e-9,
2392                "{label}: resident per-solve steps must match reupload (rel {max_rel:e})"
2393            );
2394
2395            // #1017 deliverable 2: the residency amortization must actually fire
2396            // on hardware — reusing the resident factors across iterations has to
2397            // be STRICTLY cheaper per solve than re-uploading D/B/g and
2398            // re-factoring every iterate. This is the core perf claim, asserted
2399            // (not merely printed) so a regression that silently re-uploads, or a
2400            // dispatch change that drops the resident path, fails the gate on the
2401            // A100 instead of slipping through as a slower-but-green run.
2402            //
2403            // The `color_arm` shape (n=180, p=5120) is the decisive case: the
2404            // per-solve reupload pays a 5120-wide border Schur factor + the
2405            // `O(n·d·p)` cross-block upload every iterate, while the resident path
2406            // pays only the `O(n·d + p)` gradient transfer and two border TRSMs.
2407            // We require a clear >1.5x margin there. The `qwen_non_gating` shape
2408            // (p=2048) has a smaller border so its margin is thinner; we still
2409            // require a genuine speedup (>1x) but do not over-tighten it.
2410            let min_speedup = if label == "color_arm" { 1.5 } else { 1.0 };
2411            assert!(
2412                residency_speedup > min_speedup,
2413                "{label}: across-iteration residency must beat per-solve re-upload \
2414                 (residency_speedup={residency_speedup:.3}x, required >{min_speedup}x; \
2415                 resident {resident_per_solve:.3}ms/solve vs reupload \
2416                 {reupload_per_solve:.3}ms/solve over N={N_SOLVES} solves) — the resident \
2417                 frame either silently re-uploaded D/B or the dispatch dropped the \
2418                 amortized factor path"
2419            );
2420        }
2421    }
2422
2423    /// #1017 Phase 4 variant sweep: an OLMo-battery-shaped matrix of independent
2424    /// fits (here K{1..4} × 3 basis widths = 12 color-arm variants) dispatched
2425    /// concurrently on one device. This is the cross-fit throughput lever — the
2426    /// fits are independent, so multiplexing changes only scheduling.
2427    fn battery_variant_matrix() -> Vec<super::SweepVariant> {
2428        let mut variants = Vec::new();
2429        // K is the topology rank; the battery races K{1..4}. Each K × basis cell
2430        // is an independent fit. Color-arm border, varied basis_cols per cell.
2431        for k in 1..=4u64 {
2432            for basis_cols in [4usize, 8, 12] {
2433                let mut dim = DeviceResidentArrowShape::color_arm();
2434                dim.basis_cols = basis_cols;
2435                variants.push(super::SweepVariant {
2436                    dim,
2437                    seed: 0x1017_0040_0000_0000 ^ (k << 8) ^ (basis_cols as u64),
2438                });
2439            }
2440        }
2441        variants
2442    }
2443
2444    /// Phase-4 parity: the multiplexed sweep must be bit-identical to running the
2445    /// same fits sequentially (CPU reference path here so the gate runs on the
2446    /// build box; the device path is exercised by the throughput bench on the a100).
2447    #[test]
2448    fn variant_sweep_multiplex_matches_sequential() {
2449        let variants = battery_variant_matrix();
2450        let opts = DeviceResidentInnerOptions::default();
2451
2452        // Multiplexed via the CPU-reference runner so the gate is meaningful
2453        // without CUDA, exercising the exact run_topology_race_parallel plumbing.
2454        let workspaces =
2455            super::build_sweep_workspaces(&variants).expect("sweep workspaces must build");
2456        let multiplexed =
2457            super::run_resident_fits_multiplexed_with(workspaces, opts, |ws, opts| {
2458                ws.cpu_reference_fit(opts)
2459            })
2460            .expect("multiplexed cpu sweep");
2461
2462        let seq_workspaces =
2463            super::build_sweep_workspaces(&variants).expect("sweep workspaces must build");
2464        let sequential: Vec<_> = seq_workspaces
2465            .iter()
2466            .map(|ws| ws.cpu_reference_fit(&opts))
2467            .collect();
2468
2469        assert_eq!(multiplexed.len(), sequential.len());
2470        for (idx, (mux, seq)) in multiplexed.iter().zip(sequential.iter()).enumerate() {
2471            let mux = &mux.as_ref().unwrap().outcome;
2472            let seq = seq.as_ref().unwrap();
2473            assert_eq!(
2474                mux.t.as_slice(),
2475                seq.t.as_slice(),
2476                "variant {idx}: multiplexed t differs from sequential"
2477            );
2478            assert_eq!(
2479                mux.beta.as_slice(),
2480                seq.beta.as_slice(),
2481                "variant {idx}: multiplexed beta differs from sequential"
2482            );
2483            assert_eq!(
2484                mux.objective.to_bits(),
2485                seq.objective.to_bits(),
2486                "variant {idx}: multiplexed objective differs from sequential"
2487            );
2488        }
2489    }
2490
2491    /// #1017 Phase 4 throughput bench. On a CUDA host this dispatches the battery
2492    /// variant matrix concurrently on one device, asserts per-fit certified
2493    /// parity vs sequential, and prints the cross-fit throughput (multiplexed
2494    /// fits/sec vs sequential fits/sec — the single-a100 race speedup). On a
2495    /// CPU-only host it prints a skip line. Run with `--nocapture`.
2496    #[test]
2497    fn gpu_multiplex_throughput_bench() {
2498        let variants = battery_variant_matrix();
2499        let opts = DeviceResidentInnerOptions::default();
2500
2501        let probe = super::build_sweep_workspaces(&variants).expect("sweep workspaces");
2502        let any_device = probe.iter().any(|w| w.device_resident());
2503        if !any_device {
2504            println!(
2505                "[#1017 mux-bench] no CUDA device — {} variants (K1..4 x 3 basis) \
2506                 skipped; run on the GPU node for cross-fit throughput",
2507                variants.len()
2508            );
2509            return;
2510        }
2511
2512        let (results, mux_tp) =
2513            super::run_variant_sweep_multiplexed(&variants, opts).expect("multiplexed sweep");
2514        let seq_tp = super::assert_sweep_parity_vs_sequential(&variants, &opts, &results)
2515            .expect("sweep parity vs sequential must hold");
2516        println!(
2517            "[#1017 mux-bench] fits={} succeeded={} multiplexed={:.3}s ({:.1} fits/s) \
2518             sequential={:.3}s ({:.1} fits/s) cross-fit-speedup={:.2}x",
2519            mux_tp.fits,
2520            mux_tp.succeeded,
2521            mux_tp.wall_seconds,
2522            mux_tp.fits_per_second,
2523            seq_tp.wall_seconds,
2524            seq_tp.fits_per_second,
2525            mux_tp.fits_per_second / seq_tp.fits_per_second.max(1e-9),
2526        );
2527        assert_eq!(
2528            mux_tp.succeeded, mux_tp.fits,
2529            "all battery variants must fit successfully on device"
2530        );
2531    }
2532}