Skip to main content

gam_solve/pirls/
workspace.rs

//! Reusable inner-loop scratch (`PirlsWorkspace`), the P-IRLS options bundle
//! (`WorkingModelPirlsOptions`), the arrow-Schur structured-inner-solve
//! descriptor (`ArrowSchurInnerConfig`), and the arrow-latent
//! snapshot/restore/commit helpers.
4
5use super::*;
6
7pub struct PirlsWorkspace {
8    // Common IRLS buffers. Only O(n) state is kept persistently; any
9    // design-weighted n x p scratch must be streamed through bounded chunks.
10    pub wz: Array1<f64>,
11    pub eta_buf: Array1<f64>,
12    // Stage 2/4 assembly (use max needed sizes)
13    pub scaled_matrix: Array2<f64>,    // (<= p + ebrows) x p
14    pub final_aug_matrix: Array2<f64>, // (<= p + erows) x p
15    // Stage 5 RHS buffers
16    pub rhs_full: Array1<f64>, // length <= p + erows
17    // Gradient check helpers
18    pub working_residual: Array1<f64>,
19    pub weighted_residual: Array1<f64>,
20    // Step-halving direction (XΔβ)
21    pub delta_eta: Array1<f64>,
22    // Preallocated buffer for GEMV results (length p)
23    pub vec_buf_p: Array1<f64>,
24    // Cached sparse penalized-system workspace for sparse-native solve eligibility/assembly.
25    pub(crate) sparse_penalized_system_cache: Option<SparsePenalizedSystemCache>,
26    // Factorization scratch (avoid per-iteration allocation)
27    pub factorization_scratch: MemBuffer,
28    // Permutation buffers for LDLT
29    pub perm: Vec<usize>,
30    pub perm_inv: Vec<usize>,
31    // Buffer for in-place factorization (preserves original Hessian in WorkingState)
32    pub factorization_matrix: Array2<f64>,
33    // Buffer for sparse matrix scaling (avoid per-iteration allocation)
34    pub weighted_xvalues: Vec<f64>,
35    // Dense chunk buffer for streaming X'WX assembly on very large n.
36    pub weighted_x_chunk: Array2<f64>,
37    // Reusable p×p buffer for Hessian assembly (avoids per-iteration allocation).
38    pub hessian_buf: Array2<f64>,
39    // Reusable n-length buffer for X*β matvec (avoids per-iteration allocation in update).
40    pub matvec_buf: Array1<f64>,
41    // #1412: device-resident design `X` for the GPU `XᵀWX` Gram. The inner P-IRLS
42    // loop rebuilds the Gram once per Newton/LM iterate with the SAME design `X`
43    // (only the working weights `w` move), so re-uploading the full n×p `X` on
44    // every iterate starves the device on H2D staging (measured ~98% of the
45    // pipeline at <20% utilisation). This caches the device-resident `X` keyed on
46    // its host data pointer + shape, so the first Gram of an inner solve uploads
47    // `X` and every later iterate crosses only the n-vector `w` H2D and the p×p
48    // Gram D2H. `None` whenever CUDA is unavailable / the shape is below the GPU
49    // Gram threshold / the upload failed — the caller keeps its per-call path.
50    pub(crate) resident_design_gram: Option<(
51        usize,
52        usize,
53        usize,
54        gam_gpu::linalg_dispatch::ResidentDesignGram,
55    )>,
56}
57
58impl PirlsWorkspace {
59    pub fn new(n: usize, p: usize, _: usize, _: usize) -> Self {
60        // Default implementation ignores this parameter.
61        // Default implementation ignores this parameter.
62        // Stage buffers are allocated lazily: historically these were pre-sized to
63        // worst-case dimensions, which inflates memory when many PIRLS workspaces
64        // exist concurrently (e.g. parallel REML evals).
65        // The active code paths resize-on-demand where needed.
66
67        PirlsWorkspace {
68            wz: Array1::zeros(n),
69            eta_buf: Array1::zeros(n),
70            scaled_matrix: Array2::zeros((0, 0).f()),
71            final_aug_matrix: Array2::zeros((0, 0).f()),
72            rhs_full: Array1::zeros(0),
73            working_residual: Array1::zeros(n),
74            weighted_residual: Array1::zeros(n),
75            delta_eta: Array1::zeros(n),
76            vec_buf_p: Array1::zeros(p),
77            sparse_penalized_system_cache: None,
78            // Keep scratch minimal at init; grow only if/when a factorization path
79            // needs it.
80            factorization_scratch: {
81                let par = faer::Par::Seq;
82                let req = faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<f64>(
83                    1,
84                    par,
85                    Spec::new(<LltParams as Auto<f64>>::auto()),
86                );
87                MemBuffer::new(req)
88            },
89            perm: vec![0; p],
90            perm_inv: vec![0; p],
91            factorization_matrix: Array2::zeros((0, 0)),
92            weighted_xvalues: Vec::new(),
93            weighted_x_chunk: Array2::zeros((0, 0).f()),
94            hessian_buf: Array2::zeros((0, 0).f()),
95            matvec_buf: Array1::zeros(n),
96            resident_design_gram: None,
97        }
98    }
99
100    pub(super) fn add_dense_xtwx_signed(
101        weights: &Array1<f64>,
102        weighted_x_scratch: &mut Array2<f64>,
103        x: &Array2<f64>,
104        out: &mut Array2<f64>,
105    ) {
106        *out = crate::estimate::reml::assembly::xt_diag_x_dense_into(
107            x,
108            weights,
109            weighted_x_scratch,
110        );
111    }
112
113    /// Ensure the sparse penalty cache is populated and consistent with `x` and `s_lambda`.
114    pub(crate) fn ensure_sparse_penalty_cache(
115        &mut self,
116        x: &SparseColMat<usize, f64>,
117        s_lambda: &Array2<f64>,
118    ) -> Result<(), EstimationError> {
119        let penalty_pattern = SparsePenaltyPattern::from_dense_upper(s_lambda, 1e-12);
120        let rebuild = match self.sparse_penalized_system_cache.as_ref() {
121            Some(cache) => !cache.matches(x, &penalty_pattern),
122            None => true,
123        };
124        if rebuild {
125            self.sparse_penalized_system_cache =
126                Some(SparsePenalizedSystemCache::new(x, penalty_pattern)?);
127        }
128        Ok(())
129    }
130
131    pub(crate) fn sparse_penalized_system_stats(
132        &mut self,
133        x: &SparseColMat<usize, f64>,
134        s_lambda: &Array2<f64>,
135    ) -> Result<SparsePenalizedSystemStats, EstimationError> {
136        self.ensure_sparse_penalty_cache(x, s_lambda)?;
137        Ok(self.sparse_penalized_system_cache.as_ref().unwrap().stats())
138    }
139
140    // Phase 2 hook: numeric sparse penalized-system assembly in original coordinates.
141    pub(super) fn assemble_sparse_penalized_hessian(
142        &mut self,
143        x: &SparseColMat<usize, f64>,
144        weights: &Array1<f64>,
145        s_lambda: &Array2<f64>,
146        ridge: f64,
147        precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
148    ) -> Result<SparseColMat<usize, f64>, EstimationError> {
149        self.ensure_sparse_penalty_cache(x, s_lambda)?;
150        self.sparse_penalized_system_cache
151            .as_mut()
152            .unwrap()
153            .assemble_upper(x, weights, ridge, precomputed_xtwx)
154    }
155}
156
157#[derive(Clone, Debug)]
158pub struct WorkingModelPirlsOptions {
159    pub max_iterations: usize,
160    pub convergence_tolerance: f64,
161    pub adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
162    pub max_step_halving: usize,
163    pub min_step_size: f64,
164    pub firth_bias_reduction: bool,
165    /// Optional lower bounds on coefficients (same coordinate system as `beta`).
166    /// Use `-inf` for unconstrained entries.
167    pub coefficient_lower_bounds: Option<Array1<f64>>,
168    /// Optional linear inequality constraints in current coefficient coordinates:
169    ///   A * beta >= b.
170    pub linear_constraints: Option<LinearInequalityConstraints>,
171    /// Optional warm-start hint for the Levenberg-Marquardt damping
172    /// coefficient. When set, the inner solver seeds `λ_LM` to this
173    /// value instead of the default `1e-6`. Clamped on consumption to
174    /// `[1e-6, 1e-3]` so a stale or pathological hint cannot poison the
175    /// solve: the upper bound costs at most three damping halvings
176    /// versus the cold default, which is dwarfed by the savings when
177    /// the hint is informative.
178    ///
179    /// Used by `execute_pirls_if_needed` (in `solver::reml::outer_eval`)
180    /// to persist the converged λ across consecutive PIRLS calls in a
181    /// single REML outer optimization, so the inner Newton does not
182    /// have to rediscover problem-specific damping at every accepted
183    /// outer iterate.
184    pub initial_lm_lambda: Option<f64>,
185    /// Enable the Transtrum-Sethna geodesic-acceleration second-order
186    /// correction on each accepted Levenberg-Marquardt step. When true,
187    /// after the standard LM direction `δp = −(H + λ_lm·diag(H))⁻¹ g`
188    /// is computed and accepted by the LM gain test, the solver computes
189    /// a finite-difference estimate of the directional second derivative
190    /// of the gradient along `δp`, solves a *second* linear system with
191    /// the same (already-factored) Hessian, and adds the correction
192    /// `δp₂` to the step only if `‖δp₂‖ ≤ α‖δp‖` (the Transtrum-Sethna
193    /// 2011 acceptance criterion, α = 0.75 here). The correction costs
194    /// two extra full `WorkingModel::update` calls per accepted step
195    /// (for the FD evaluations); it is most useful for fits whose
196    /// penalized Hessian is near-singular (latent-coordinate fits,
197    /// near-collinear bases). Default `false`; opt-in until validated
198    /// across the broader family of likelihoods and penalties.
199    pub geodesic_acceleration: bool,
200    /// Optional arrow-Schur structured-inner-solve descriptor.
201    ///
202    /// When `Some`, every accepted LM Newton step inside the inner loop
203    /// is computed by the per-observation arrow-Schur path
204    /// ([`crate::arrow_schur::ArrowSchurSystem`]) instead of the
205    /// β-only `solve_newton_direction_dense`. When `None`, the existing
206    /// β-only path is used unchanged (back-compat: every existing call
207    /// site that does not opt in is unaffected).
208    ///
209    /// **Scope note.** This wires the *inner* Gauss–Newton step. The REML
210    /// outer-loop gradient w.r.t. `t` (which carries a shared `Schur⁻¹`
211    /// factor) is a separate plumbing change owned by the REML driver and is
212    /// **not** handled here.
213    pub arrow_schur: Option<ArrowSchurInnerConfig>,
214}
215
216/// Per-iteration arrow-Schur builder hook.
217///
218/// The driver supplies a closure that, given the current `β` iterate,
219/// returns a freshly-populated [`crate::arrow_schur::ArrowSchurSystem`]
220/// — i.e. the per-row `H_tt^(i)`, `H_tβ^(i)`, `g_t^(i)` blocks and the
221/// β-block `H_ββ`, `g_β`. The driver owns the assembly because the
222/// per-row Jacobians depend on the latent-coord term's basis (Duchon,
223/// Sphere, …) and the analytic-penalty contributions depend on the
224/// registry the outer-fit configuration owns. PIRLS only knows how to
225/// *solve* the bordered system once it has been assembled.
226#[derive(Clone)]
227pub struct ArrowSchurInnerConfig {
228    /// Number of latent rows `N`.
229    pub n_rows: usize,
230    /// Latent dimensionality `d`.
231    pub latent_dim: usize,
232    /// β dimensionality `K` (must match the inner Hessian dimension).
233    pub n_beta: usize,
234    /// Closure that builds the bordered system at the current `β` and
235    /// current latent `t` (the latter held externally by the driver, e.g.
236    /// in a `LatentCoordValues` registered alongside the working model).
237    /// Returning `None` signals "fall back to the β-only path for this
238    /// iteration" — useful for the seeding sweep before `t` has been
239    /// initialized.
240    pub build: std::sync::Arc<
241        dyn Fn(&Array1<f64>) -> Option<crate::arrow_schur::ArrowSchurSystem> + Send + Sync,
242    >,
243    /// BA Schur solve mode. `None` selects Direct for `K <= 2000` and
244    /// InexactPCG above, following "Bundle Adjustment in the Large".
245    pub solver_mode: Option<crate::arrow_schur::ArrowSolverMode>,
246    /// When set, assemble the reduced dense Schur block in row chunks.
247    pub streaming_chunk_size: Option<usize>,
248    /// Steihaug trust-region radius for the reduced shared step. This ports
249    /// the Ceres/BA trust-region guard while retaining PIRLS's LM damping.
250    pub trust_region_radius: f64,
251    /// Optional β-block column ranges for the block-Jacobi Schur preconditioner.
252    ///
253    /// When `Some`, the PIRLS driver calls
254    /// [`crate::arrow_schur::ArrowSchurSystem::set_block_offsets`] on
255    /// every system returned by the `build` closure, wiring the block-Jacobi
256    /// path without requiring each family's closure to call it manually.
257    ///
258    /// Derive from `ParameterBlockSpec` slices via
259    /// `gam_custom_family::block_offsets_from_specs`.  When
260    /// `None`, the preconditioner falls back to scalar-diagonal Jacobi (the
261    /// pre-#287 behaviour); when `Some([])` (empty slice), the same fallback
262    /// applies.
263    pub block_offsets: Option<Arc<[std::ops::Range<usize>]>>,
264    /// Callback that the inner solver invokes after each LM-attempted
265    /// joint step to write the latent tangent increment back into the
266    /// driver's `LatentCoordValues` via that latent's update rule
267    /// (`retract_flat_delta` for manifold latents). `delta_t` is the flat
268    /// row-major increment of length `n_rows * latent_dim`.
269    pub apply_delta_t: std::sync::Arc<dyn Fn(&Array1<f64>) + Send + Sync>,
270    /// Snapshot the driver's latent field before an LM trial step mutates it.
271    pub snapshot_t: std::sync::Arc<dyn Fn() -> Array1<f64> + Send + Sync>,
272    /// Restore a snapshot produced by [`Self::snapshot_t`] after any rejected
273    /// LM trial. Accepted trials deliberately do not call this hook: β and t
274    /// commit together.
275    pub restore_t: std::sync::Arc<dyn Fn(&Array1<f64>) + Send + Sync>,
276}
277
278impl std::fmt::Debug for ArrowSchurInnerConfig {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        f.debug_struct("ArrowSchurInnerConfig")
281            .field("n_rows", &self.n_rows)
282            .field("latent_dim", &self.latent_dim)
283            .field("n_beta", &self.n_beta)
284            .field("solver_mode", &self.solver_mode)
285            .field("streaming_chunk_size", &self.streaming_chunk_size)
286            .field("trust_region_radius", &self.trust_region_radius)
287            .field(
288                "block_offsets",
289                &self.block_offsets.as_ref().map(|o| o.len()),
290            )
291            .finish_non_exhaustive()
292    }
293}
294
295pub(crate) fn restore_arrow_latent_if_needed(
296    options: &WorkingModelPirlsOptions,
297    snapshot: Option<Array1<f64>>,
298) {
299    if let (Some(arrow_cfg), Some(snapshot)) = (options.arrow_schur.as_ref(), snapshot) {
300        arrow_cfg.restore_t.as_ref()(&snapshot);
301    }
302}
303
304pub(super) fn restore_pending_arrow_latent_if_needed(
305    options: &WorkingModelPirlsOptions,
306    pending_snapshot: &mut Option<Array1<f64>>,
307) {
308    restore_arrow_latent_if_needed(options, pending_snapshot.take());
309}
310
311pub(super) fn commit_pending_arrow_latent(pending_snapshot: &mut Option<Array1<f64>>) {
312    drop(pending_snapshot.take());
313}