gam_solve/pirls/workspace.rs
1//! Reusable inner-loop scratch (`PirlsWorkspace`), the P-IRLS options bundle
2//! (`WorkingModelPirlsOptions`), the arrow-Schur structured-inner-solve
3//! descriptor, and the arrow-latent snapshot/restore/commit helpers.
4
5use super::*;
6
/// Reusable scratch state for the P-IRLS inner loop.
///
/// Keeping these buffers in one workspace lets successive Newton/LM iterates
/// reuse allocations instead of re-creating them. [`PirlsWorkspace::new`]
/// eagerly sizes only vectors of length `n` or `p`; every matrix-shaped stage
/// buffer starts out `0 × 0` and is expected to be grown on demand by the code
/// path that needs it.
pub struct PirlsWorkspace {
    // Common IRLS buffers. Only O(n) state is kept persistently; any
    // design-weighted n x p scratch must be streamed through bounded chunks.
    /// Length-`n` buffer (presumably the weighted working response) — TODO confirm.
    pub wz: Array1<f64>,
    /// Length-`n` buffer (presumably linear-predictor scratch) — TODO confirm.
    pub eta_buf: Array1<f64>,
    // Stage 2/4 assembly (use max needed sizes)
    /// Stage 2/4 assembly matrix, at most `(p + ebrows) × p`.
    // NOTE(review): the sibling fields say `erows`; confirm whether `ebrows`
    // is a distinct row count or a typo.
    pub scaled_matrix: Array2<f64>,
    /// Stage 2/4 augmented matrix, at most `(p + erows) × p`.
    pub final_aug_matrix: Array2<f64>,
    // Stage 5 RHS buffers
    /// Stage 5 right-hand side, length at most `p + erows`.
    pub rhs_full: Array1<f64>,
    // Gradient check helpers
    /// Length-`n` working-residual buffer used by the gradient check.
    pub working_residual: Array1<f64>,
    /// Length-`n` weighted-residual buffer used by the gradient check.
    pub weighted_residual: Array1<f64>,
    /// Step-halving direction `XΔβ` (length `n`).
    pub delta_eta: Array1<f64>,
    /// Preallocated length-`p` buffer for GEMV results.
    pub vec_buf_p: Array1<f64>,
    /// Cached sparse penalized-system workspace used for sparse-native solve
    /// eligibility and assembly. It is (re)built by
    /// `ensure_sparse_penalty_cache` whenever the design/penalty pattern changes.
    pub(crate) sparse_penalized_system_cache: Option<SparsePenalizedSystemCache>,
    /// Factorization scratch, reused to avoid per-iteration allocation.
    /// `new` sizes it only for a 1×1 sequential Cholesky, so larger
    /// factorizations presumably grow it first — TODO confirm at call sites.
    pub factorization_scratch: MemBuffer,
    /// Row permutation buffer for LDLT (length `p`).
    pub perm: Vec<usize>,
    /// Inverse of [`Self::perm`] for LDLT (length `p`).
    pub perm_inv: Vec<usize>,
    /// Buffer for in-place factorization, so the original Hessian held in
    /// `WorkingState` is preserved.
    // NOTE(review): unlike the other matrices, `new` creates this one without
    // `.f()` (row-major); confirm the factorization path does not rely on a
    // column-major layout.
    pub factorization_matrix: Array2<f64>,
    /// Value buffer for sparse matrix scaling (avoids per-iteration allocation).
    pub weighted_xvalues: Vec<f64>,
    /// Dense chunk buffer for streaming `XᵀWX` assembly on very large `n`.
    pub weighted_x_chunk: Array2<f64>,
    /// Reusable `p × p` buffer for Hessian assembly (avoids per-iteration allocation).
    pub hessian_buf: Array2<f64>,
    /// Reusable length-`n` buffer for the `Xβ` matvec in `update`.
    pub matvec_buf: Array1<f64>,
    /// #1412: device-resident design `X` for the GPU `XᵀWX` Gram.
    ///
    /// The inner P-IRLS loop rebuilds the Gram once per Newton/LM iterate with
    /// the SAME design `X` (only the working weights `w` move). Re-uploading the
    /// full n×p `X` on every iterate starves the device on H2D staging (measured
    /// ~98% of the pipeline at <20% utilisation).
    ///
    /// This field caches the device-resident `X`, keyed on its host data
    /// pointer + shape. The first Gram of an inner solve uploads `X`; every
    /// later iterate sends only the n-vector `w` H2D and reads back the p×p Gram
    /// D2H. It is `None` whenever CUDA is unavailable, the shape is below the GPU
    /// Gram threshold, or the upload failed; in those cases the caller keeps its
    /// per-call path.
    // NOTE(review): the three `usize`s are presumably (host pointer, rows,
    // cols) — confirm against the code that populates this field.
    pub(crate) resident_design_gram: Option<(
        usize,
        usize,
        usize,
        gam_gpu::linalg_dispatch::ResidentDesignGram,
    )>,
}
57
58impl PirlsWorkspace {
59 pub fn new(n: usize, p: usize, _: usize, _: usize) -> Self {
60 // Default implementation ignores this parameter.
61 // Default implementation ignores this parameter.
62 // Stage buffers are allocated lazily: historically these were pre-sized to
63 // worst-case dimensions, which inflates memory when many PIRLS workspaces
64 // exist concurrently (e.g. parallel REML evals).
65 // The active code paths resize-on-demand where needed.
66
67 PirlsWorkspace {
68 wz: Array1::zeros(n),
69 eta_buf: Array1::zeros(n),
70 scaled_matrix: Array2::zeros((0, 0).f()),
71 final_aug_matrix: Array2::zeros((0, 0).f()),
72 rhs_full: Array1::zeros(0),
73 working_residual: Array1::zeros(n),
74 weighted_residual: Array1::zeros(n),
75 delta_eta: Array1::zeros(n),
76 vec_buf_p: Array1::zeros(p),
77 sparse_penalized_system_cache: None,
78 // Keep scratch minimal at init; grow only if/when a factorization path
79 // needs it.
80 factorization_scratch: {
81 let par = faer::Par::Seq;
82 let req = faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<f64>(
83 1,
84 par,
85 Spec::new(<LltParams as Auto<f64>>::auto()),
86 );
87 MemBuffer::new(req)
88 },
89 perm: vec![0; p],
90 perm_inv: vec![0; p],
91 factorization_matrix: Array2::zeros((0, 0)),
92 weighted_xvalues: Vec::new(),
93 weighted_x_chunk: Array2::zeros((0, 0).f()),
94 hessian_buf: Array2::zeros((0, 0).f()),
95 matvec_buf: Array1::zeros(n),
96 resident_design_gram: None,
97 }
98 }
99
100 pub(super) fn add_dense_xtwx_signed(
101 weights: &Array1<f64>,
102 weighted_x_scratch: &mut Array2<f64>,
103 x: &Array2<f64>,
104 out: &mut Array2<f64>,
105 ) {
106 *out = crate::estimate::reml::assembly::xt_diag_x_dense_into(
107 x,
108 weights,
109 weighted_x_scratch,
110 );
111 }
112
113 /// Ensure the sparse penalty cache is populated and consistent with `x` and `s_lambda`.
114 pub(crate) fn ensure_sparse_penalty_cache(
115 &mut self,
116 x: &SparseColMat<usize, f64>,
117 s_lambda: &Array2<f64>,
118 ) -> Result<(), EstimationError> {
119 let penalty_pattern = SparsePenaltyPattern::from_dense_upper(s_lambda, 1e-12);
120 let rebuild = match self.sparse_penalized_system_cache.as_ref() {
121 Some(cache) => !cache.matches(x, &penalty_pattern),
122 None => true,
123 };
124 if rebuild {
125 self.sparse_penalized_system_cache =
126 Some(SparsePenalizedSystemCache::new(x, penalty_pattern)?);
127 }
128 Ok(())
129 }
130
131 pub(crate) fn sparse_penalized_system_stats(
132 &mut self,
133 x: &SparseColMat<usize, f64>,
134 s_lambda: &Array2<f64>,
135 ) -> Result<SparsePenalizedSystemStats, EstimationError> {
136 self.ensure_sparse_penalty_cache(x, s_lambda)?;
137 Ok(self.sparse_penalized_system_cache.as_ref().unwrap().stats())
138 }
139
140 // Phase 2 hook: numeric sparse penalized-system assembly in original coordinates.
141 pub(super) fn assemble_sparse_penalized_hessian(
142 &mut self,
143 x: &SparseColMat<usize, f64>,
144 weights: &Array1<f64>,
145 s_lambda: &Array2<f64>,
146 ridge: f64,
147 precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
148 ) -> Result<SparseColMat<usize, f64>, EstimationError> {
149 self.ensure_sparse_penalty_cache(x, s_lambda)?;
150 self.sparse_penalized_system_cache
151 .as_mut()
152 .unwrap()
153 .assemble_upper(x, weights, ridge, precomputed_xtwx)
154 }
155}
156
/// Options bundle controlling a single P-IRLS inner solve over a working model.
///
/// This type is plain configuration and performs no validation itself.
/// NOTE(review): ranges and clamps for the numeric fields are enforced (if at
/// all) by the consuming solver, which is not visible in this file.
#[derive(Clone, Debug)]
pub struct WorkingModelPirlsOptions {
    /// Iteration cap for the inner loop (presumably counts Newton/LM iterates —
    /// confirm against the solver).
    pub max_iterations: usize,
    /// Inner-loop convergence tolerance. NOTE(review): which quantity it is
    /// compared against is decided by the solver — confirm.
    pub convergence_tolerance: f64,
    /// Optional adaptive KKT tolerance policy.
    pub adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
    /// Cap on step-halving attempts (presumably per iterate — confirm).
    pub max_step_halving: usize,
    /// Minimum step size for step halving (presumably the solver gives up
    /// halving below it — confirm).
    pub min_step_size: f64,
    /// Enable Firth bias reduction in the inner solve.
    pub firth_bias_reduction: bool,
    /// Optional lower bounds on coefficients (same coordinate system as `beta`).
    /// Use `-inf` for unconstrained entries.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints in current coefficient coordinates:
    /// A * beta >= b.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Optional warm-start hint for the Levenberg-Marquardt damping
    /// coefficient. When set, the inner solver seeds `λ_LM` to this
    /// value instead of the default `1e-6`.
    ///
    /// The hint is clamped on consumption to `[1e-6, 1e-3]`, so a stale or
    /// pathological hint cannot poison the solve. Even a hint at the upper bound
    /// is only three orders of magnitude (`1e-3 → 1e-6`) above the cold default.
    /// That cost is dwarfed by the savings when the hint is informative.
    ///
    /// Used by `execute_pirls_if_needed` (in `solver::reml::outer_eval`)
    /// to persist the converged λ across consecutive PIRLS calls in a
    /// single REML outer optimization, so the inner Newton does not
    /// have to rediscover problem-specific damping at every accepted
    /// outer iterate.
    pub initial_lm_lambda: Option<f64>,
    /// Enable the Transtrum-Sethna geodesic-acceleration second-order
    /// correction on each accepted Levenberg-Marquardt step.
    ///
    /// When true, the standard LM direction `δp = −(H + λ_lm·diag(H))⁻¹ g` is
    /// computed and accepted by the LM gain test as usual. The solver then:
    /// 1. computes a finite-difference estimate of the directional second
    ///    derivative of the gradient along `δp`;
    /// 2. solves a *second* linear system with the same (already-factored)
    ///    Hessian, giving a correction `δp₂`;
    /// 3. adds `δp₂` to the step only if `‖δp₂‖ ≤ α‖δp‖` (the Transtrum-Sethna
    ///    2011 acceptance criterion, α = 0.75 here).
    ///
    /// The correction costs two extra full `WorkingModel::update` calls per
    /// accepted step (for the FD evaluations). It is most useful for fits whose
    /// penalized Hessian is near-singular (latent-coordinate fits, near-collinear
    /// bases). Default `false`; opt-in until validated across the broader family
    /// of likelihoods and penalties.
    pub geodesic_acceleration: bool,
    /// Optional arrow-Schur structured-inner-solve descriptor.
    ///
    /// When `Some`, every accepted LM Newton step inside the inner loop
    /// is computed by the per-observation arrow-Schur path
    /// ([`crate::arrow_schur::ArrowSchurSystem`]) instead of the
    /// β-only `solve_newton_direction_dense`. When `None`, the existing
    /// β-only path is used unchanged (back-compat: every existing call
    /// site that does not opt in is unaffected).
    ///
    /// **Scope note.** This wires the *inner* Gauss–Newton step. The REML
    /// outer-loop gradient w.r.t. `t` (which carries a shared `Schur⁻¹`
    /// factor) is a separate plumbing change owned by the REML driver and is
    /// **not** handled here.
    pub arrow_schur: Option<ArrowSchurInnerConfig>,
}
215
/// Per-iteration arrow-Schur builder hook and solve settings.
///
/// The driver supplies a closure that, given the current `β` iterate,
/// returns a freshly-populated [`crate::arrow_schur::ArrowSchurSystem`],
/// i.e. the per-row `H_tt^(i)`, `H_tβ^(i)`, `g_t^(i)` blocks and the
/// β-block `H_ββ`, `g_β`.
///
/// The driver owns the assembly for two reasons. The per-row Jacobians depend
/// on the latent-coord term's basis (Duchon, Sphere, …). The analytic-penalty
/// contributions depend on the registry that the outer-fit configuration owns.
/// PIRLS only knows how to *solve* the bordered system once it has been
/// assembled.
///
/// The latent field `t` lives outside PIRLS. The `snapshot_t` / `apply_delta_t`
/// / `restore_t` hooks let the inner solver trial a joint `(β, t)` step and
/// roll back `t` if the step is rejected.
#[derive(Clone)]
pub struct ArrowSchurInnerConfig {
    /// Number of latent rows `N`.
    pub n_rows: usize,
    /// Latent dimensionality `d`.
    pub latent_dim: usize,
    /// β dimensionality `K` (must match the inner Hessian dimension).
    pub n_beta: usize,
    /// Closure that builds the bordered system at the current `β` (its only
    /// argument) and the current latent `t`. The latter is held externally by
    /// the driver, e.g. in a `LatentCoordValues` registered alongside the
    /// working model.
    ///
    /// Returning `None` signals "fall back to the β-only path for this
    /// iteration". This is useful for the seeding sweep before `t` has been
    /// initialized.
    pub build: std::sync::Arc<
        dyn Fn(&Array1<f64>) -> Option<crate::arrow_schur::ArrowSchurSystem> + Send + Sync,
    >,
    /// Bundle-adjustment (BA) Schur solve mode. `None` selects Direct for
    /// `K <= 2000` and InexactPCG above, following "Bundle Adjustment in the
    /// Large".
    pub solver_mode: Option<crate::arrow_schur::ArrowSolverMode>,
    /// When set, assemble the reduced dense Schur block in row chunks of this
    /// size.
    pub streaming_chunk_size: Option<usize>,
    /// Steihaug trust-region radius for the reduced shared step. This ports
    /// the Ceres/BA trust-region guard while retaining PIRLS's LM damping.
    pub trust_region_radius: f64,
    /// Optional β-block column ranges for the block-Jacobi Schur preconditioner.
    ///
    /// When `Some`, the PIRLS driver calls
    /// [`crate::arrow_schur::ArrowSchurSystem::set_block_offsets`] on
    /// every system returned by the `build` closure. This wires the block-Jacobi
    /// path without requiring each family's closure to call it manually.
    ///
    /// Derive from `ParameterBlockSpec` slices via
    /// `gam_custom_family::block_offsets_from_specs`. When `None`, the
    /// preconditioner falls back to scalar-diagonal Jacobi (the pre-#287
    /// behaviour). `Some([])` (an empty slice) takes the same fallback.
    pub block_offsets: Option<Arc<[std::ops::Range<usize>]>>,
    /// Callback the inner solver invokes after each LM-attempted joint step.
    ///
    /// It writes the latent tangent increment back into the driver's
    /// `LatentCoordValues` via that latent's update rule (`retract_flat_delta`
    /// for manifold latents). `delta_t` is the flat row-major increment of
    /// length `n_rows * latent_dim`.
    pub apply_delta_t: std::sync::Arc<dyn Fn(&Array1<f64>) + Send + Sync>,
    /// Snapshot the driver's latent field before an LM trial step mutates it.
    pub snapshot_t: std::sync::Arc<dyn Fn() -> Array1<f64> + Send + Sync>,
    /// Restore a snapshot produced by [`Self::snapshot_t`] after any rejected
    /// LM trial. Accepted trials deliberately do not call this hook: β and t
    /// commit together.
    pub restore_t: std::sync::Arc<dyn Fn(&Array1<f64>) + Send + Sync>,
}
277
278impl std::fmt::Debug for ArrowSchurInnerConfig {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 f.debug_struct("ArrowSchurInnerConfig")
281 .field("n_rows", &self.n_rows)
282 .field("latent_dim", &self.latent_dim)
283 .field("n_beta", &self.n_beta)
284 .field("solver_mode", &self.solver_mode)
285 .field("streaming_chunk_size", &self.streaming_chunk_size)
286 .field("trust_region_radius", &self.trust_region_radius)
287 .field(
288 "block_offsets",
289 &self.block_offsets.as_ref().map(|o| o.len()),
290 )
291 .finish_non_exhaustive()
292 }
293}
294
295pub(crate) fn restore_arrow_latent_if_needed(
296 options: &WorkingModelPirlsOptions,
297 snapshot: Option<Array1<f64>>,
298) {
299 if let (Some(arrow_cfg), Some(snapshot)) = (options.arrow_schur.as_ref(), snapshot) {
300 arrow_cfg.restore_t.as_ref()(&snapshot);
301 }
302}
303
304pub(super) fn restore_pending_arrow_latent_if_needed(
305 options: &WorkingModelPirlsOptions,
306 pending_snapshot: &mut Option<Array1<f64>>,
307) {
308 restore_arrow_latent_if_needed(options, pending_snapshot.take());
309}
310
311pub(super) fn commit_pending_arrow_latent(pending_snapshot: &mut Option<Array1<f64>>) {
312 drop(pending_snapshot.take());
313}