// gam_solve/reml/reml_outer_engine/inner_solution.rs

1use super::*;
2
/// Selects how the dispersion parameter φ enters the outer objective:
/// profiled out of it (Gaussian REML) or held fixed (non-Gaussian LAML).
#[derive(Clone, Debug)]
pub enum DispersionHandling {
    /// Gaussian REML with the scale profiled out of the objective:
    /// φ̂ = D_p / (n − M_p).
    ///
    /// The cost carries the (n−M_p)/2 · log(2πφ̂) term and the gradient
    /// carries the derivative of the profiled scale. Both logdet terms are
    /// always part of the objective.
    ProfiledGaussian,
    /// Fixed dispersion: non-Gaussian LAML or maximum penalized likelihood.
    ///
    /// Typical configurations:
    /// - Standard LAML:
    ///   `Fixed { phi: 1.0, include_logdet_h: true, include_logdet_s: true }`
    /// - MaxPenalizedLikelihood:
    ///   `Fixed { phi: 1.0, include_logdet_h: false, include_logdet_s: false }`
    Fixed {
        /// The fixed dispersion value φ.
        phi: f64,
        /// Whether ½ log|H| is included in the objective (true for full
        /// LAML, false for MPL/PQL).
        include_logdet_h: bool,
        /// Whether −½ log|S|₊ is included in the objective.
        include_logdet_s: bool,
    },
}
25
26/// The unified inner solution produced by any inner solver.
27///
28/// Contains everything the outer REML/LAML evaluator needs. Produced by:
29/// - Single-block PIRLS (via `PirlsResult::into_inner_solution()`)
30/// - Blockwise coupled Newton (via `BlockwiseInnerResult::into_inner_solution()`)
31/// - Sparse Cholesky (via `SparsePenalizedSystem::into_inner_solution()`)
32pub struct InnerSolution<'dp> {
33    // === Objective ingredients ===
34    /// ℓ(β̂) — log-likelihood at the converged mode.
35    /// For Gaussian: −0.5 × deviance (RSS). For GLMs: actual log-likelihood.
36    pub log_likelihood: f64,
37
38    /// β̂ᵀS(ρ)β̂ — penalty quadratic form at the mode.
39    pub penalty_quadratic: f64,
40
41    // === The factorization (single source of truth for all linear algebra) ===
42    /// The Hessian operator providing logdet, trace, and solve.
43    /// Both cost and gradient use this same object.
44    ///
45    /// IMPORTANT: This MUST encode the **observed** Hessian H_obs = X'W_obs X + S
46    /// at the converged mode, where W_obs includes the residual-dependent correction
47    /// for non-canonical links. Using expected Fisher H_Fisher = X'W_Fisher X + S
48    /// would make this a PQL surrogate rather than the exact Laplace approximation.
49    /// See response.md Section 3 for the mathematical justification.
50    pub hessian_op: Arc<dyn HessianOperator>,
51
52    // === Coefficients and penalty structure ===
53    /// β̂ — coefficients at the converged mode (in the operator's native basis).
54    pub beta: Array1<f64>,
55
56    /// Penalty coordinates for the rho block.
57    ///
58    /// Each coordinate represents one smoothing-parameter direction
59    ///   A_k = λ_k S_k
60    /// through either a full-root or a block-local root.
61    pub penalty_coords: Vec<PenaltyCoordinate>,
62
63    /// Derivatives of log|S(ρ)|₊ — precomputed from penalty structure.
64    pub penalty_logdet: PenaltyLogdetDerivs,
65
66    // === Family-specific derivative info ===
67    /// Provider of third-derivative corrections for non-Gaussian families.
68    ///
69    /// The c and d arrays (dW/deta, d^2W/deta^2) carried by this provider MUST
70    /// be the **observed** derivatives, not the Fisher derivatives. For non-canonical
71    /// links the observed c/d include residual-dependent corrections:
72    ///   c_obs = c_Fisher + h'*B - (y-mu)*B_eta
73    ///   d_obs = d_Fisher + h''*B + 2*h'*B_eta - (y-mu)*B_etaeta
74    /// These corrections matter for the outer gradient (C[v] correction) and
75    /// outer Hessian (Q[v_k, v_l] correction). See response.md Section 3.
76    pub deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
77
78    // === Corrections ===
79    /// Optional exact Jeffreys/Firth term in the active coefficient basis.
80    pub firth: Option<ExactJeffreysTerm>,
81
82    /// Additive correction for the Hessian logdet when `hessian_op` encodes a
83    /// uniformly rescaled exact curvature matrix.
84    pub hessian_logdet_correction: f64,
85
86    /// When the cost uses `log|U_Sᵀ H U_S|_+` (rank-deficient LAML fix),
87    /// this carries the matching projected kernel so the gradient trace
88    /// `tr(K · Ḣ)` agrees with the cost's derivative.  See
89    /// [`PenaltySubspaceTrace`] for the full derivation.
90    pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
91
92    /// Uniform scale `s` applied to rho-coordinate penalty derivatives in the
93    /// H-dependent trace / solve parts of the outer calculus.
94    ///
95    /// ## Contract (CRITICAL — gradient/cost consistency)
96    ///
97    /// `rho_curvature_scale` is NOT a free knob.  It encodes the convention
98    /// that the supplied `hessian_op` represents the **rescaled** curvature
99    /// `H_op = s · (∇²(-ℓ) + Σ_k e^{ρ_k} S_k)`, i.e. every contribution to
100    /// the curvature (likelihood Hessian AND penalty `λ_k S_k`) has been
101    /// uniformly multiplied by `s` before reaching the evaluator.  Under this
102    /// convention:
103    ///
104    /// * `∂H_op/∂ρ_k = s · λ_k S_k` (matches the `curvature_lambdas = s · λ`
105    ///   drift used inside the gradient's trace term),
106    /// * `K = H_op⁻¹ = (1/s) · (∇²(-ℓ) + λS)⁻¹`,
107    /// * `tr(K · ∂H_op/∂ρ_k) = tr((∇²(-ℓ) + λS)⁻¹ · λ_k S_k)` (the analytic
108    ///   gradient of the **unscaled** `log|H|`),
109    /// * `log|H_op| = log|∇²(-ℓ) + λS| + p · log(s)`, which the caller MUST
110    ///   un-scale by supplying `hessian_logdet_correction += −p · log(s)` so
111    ///   that `hop.logdet() + hessian_logdet_correction` evaluates the same
112    ///   unscaled `log|H|` whose derivative the gradient trace computes.
113    ///
114    /// Callers that set `rho_curvature_scale ≠ 1` without ALSO pre-scaling
115    /// `hessian_op` AND adding the matching `−p·log(s)` term to
116    /// `hessian_logdet_correction` will get a gradient that is off by the
117    /// factor `s` from `dV/dρ_k`.  The unified evaluator does **not** scale
118    /// `hop` for the caller — that would defeat the purpose of the
119    /// curvature-conditioning trick survival families use to keep the
120    /// outer eigendecomposition numerically stable.
121    ///
122    /// See `survival::location_scale::exact_newton_outer_curvature` for the
123    /// canonical example: `rho_curvature_scale = exp(-log_scale)` paired with
124    /// `hessian_logdet_correction = p · log_scale = −p · log(scale)`.
125    ///
126    /// The evaluator enforces `rho_curvature_scale > 0` and finite; pass
127    /// `1.0` (the documented default) when no curvature conditioning is in
128    /// play.
129    pub rho_curvature_scale: f64,
130
131    /// Configured prior over rho coordinates. The evaluator receives the
132    /// realized cost/gradient tuple separately; this copy lets EFS use the
133    /// conjugate Gamma rate in its multiplicative denominator.
134    pub rho_prior: gam_problem::RhoPrior,
135
136    // === Model dimensions ===
137    /// Number of observations.
138    pub n_observations: usize,
139
140    /// M_p: dimension of the penalty null space (unpenalized coefficients).
141    pub nullspace_dim: f64,
142
143    /// ½·Σᵢ log(wᵢ) — half the sum of log prior weights.
144    ///
145    /// This is the per-observation Gaussian normalization constant that the
146    /// `log_likelihood` (computed by
147    /// [`calculate_loglikelihood_omitting_constants`]) deliberately drops. The
148    /// full weighted-Gaussian negative log-likelihood normalization is
149    ///   ½·Σᵢ log(2π·φ/wᵢ) = (n/2)·log(2πφ) − ½·Σᵢ log(wᵢ),
150    /// because `Var(yᵢ) = φ/wᵢ` under inverse-variance prior weights.
151    ///
152    /// Dropping `−½·Σ log(wᵢ)` does not move the ρ-argmin in exact arithmetic
153    /// (it is constant in ρ), but it makes the ProfiledGaussian objective VALUE
154    /// scale-dependent: under a global rescale `w → c·w` the invariance-
155    /// preserving smoothing `λ → c·λ` leaves the cost SHAPE fixed but inflates
156    /// its absolute value by `(n/2)·log c`. That inflation breaks the exact
157    /// weight-scale invariance of the selected λ̂ / EDF / fit (issue #877).
158    /// Restoring this term makes the ProfiledGaussian cost value exactly
159    /// invariant to `w → c·w` (with σ̂² absorbing the c factor), matching mgcv.
160    ///
161    /// Only consumed by the `ProfiledGaussian` arm; the `Fixed`-dispersion arm
162    /// already omits the Gaussian normalization constant by design and is not
163    /// affected.
164    pub gaussian_weight_log_sum_half: f64,
165
166    /// Deviance scale `D₀` used as the *relative* reference for the smooth
167    /// penalized-deviance floor (see [`crate::estimate::smooth_floor_dp`]).
168    ///
169    /// Set to the weighted null deviance of the Gaussian response,
170    /// `D₀ = Σ wᵢ(yᵢ − ȳ_w)²`, which is the natural upper reference for
171    /// `D_p` and — crucially — transforms as `D₀ → a²·D₀` under a response
172    /// rescale `y → a·y`, exactly as `D_p` does. Flooring `D_p` at a fixed
173    /// fraction of `D₀` therefore keeps the profiled Gaussian REML criterion
174    /// exactly scale-equivariant (issue #1127); an absolute floor does not.
175    ///
176    /// Only consumed by the `ProfiledGaussian` arm. Defaults to `1.0`, which
177    /// reproduces the historical absolute floor byte-for-byte for every caller
178    /// that does not supply a response scale.
179    pub dp_floor_scale: f64,
180
181    /// How the dispersion parameter is handled.
182    pub dispersion: DispersionHandling,
183
184    // === Extended hyperparameter coordinates (ψ / τ) ===
185    /// External (non-ρ) hyperparameter coordinates with their fixed-β objects.
186    /// These are appended after the ρ coordinates in the gradient/Hessian output.
187    pub ext_coords: Vec<HyperCoord>,
188
189    /// Callback to compute second-order fixed-β objects for a pair (i, j)
190    /// of external coordinates (or external × ρ cross pairs).
191    /// Arguments: (ext_index_i, ext_index_j) → HyperCoordPair.
192    /// When None, the outer Hessian is not computed for extended coordinates.
193    pub ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
194
195    /// Callback for ρ × ext cross pairs: (rho_index, ext_index) → HyperCoordPair.
196    pub rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
197
198    /// M_i[u] = D_β B_i[u] callback for extended coordinates.
199    /// Arguments: (ext_index, direction) → correction matrix.
200    pub fixed_drift_deriv: Option<FixedDriftDerivFn>,
201
202    /// Direction-contracted second-order ψ hook for the profiled θ-HVP (#740).
203    /// When present, the outer-Hessian operator builder skips the `K²` per-pair
204    /// `base_h2` ψψ assembly and instead applies this once per matvec to obtain
205    /// every output row's `tr(K · D²_ψ H_L[ψ_i, ψ(α)])` in a single family row
206    /// pass. `None` keeps the exact per-pair assembly. See
207    /// [`ContractedPsiSecondOrderFn`].
208    pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
209
210    /// Optional log-barrier configuration for monotonicity-constrained coefficients.
211    /// When present, the barrier cost and Hessian corrections are added to the
212    /// outer REML/LAML objective.
213    pub barrier_config: Option<BarrierConfig>,
214
215    /// Optional inner KKT residual `r = ∇_β L_pen(β̂)` at the converged β̂,
216    /// already projected onto the free subspace (see [`ProjectedKktResidual`]
217    /// for the invariant and why the type wraps this). `Some` activates the
218    /// implicit-function-theorem corrections in `reml_laml_evaluate` (cost
219    /// gets `−½ rᵀ H⁻¹ r`, ρ-gradient and ρρ Hessian get the matching first
220    /// and second derivatives of that same scalar correction). `None` keeps
221    /// the envelope-only behaviour for callers that genuinely guarantee
222    /// exact KKT.
223    pub kkt_residual: Option<ProjectedKktResidual>,
224
225    /// Optional active linear-inequality constraints at the converged inner
226    /// iterate. `Some(rows)` means the joint constraint matrix's row indices
227    /// in `rows.active_indices` are pinned (treated as equality constraints
228    /// at the cert point). The unified evaluator combines this with the
229    /// `penalty_subspace_trace` to form the **constraint-aware** kernel
230    /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S` for per-coordinate IFT mode
231    /// responses `v_k = ∂β/∂ρ_k`. See [`ConstrainedSubspaceKernel`] for
232    /// the full derivation and consistency with `log|U_Tᵀ H U_T|`.
233    ///
234    /// `None` is the legacy/unconstrained path (no active inequality
235    /// constraints to project against).
236    pub active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
237
238    /// Fit-level stochastic trace state. Shared by stochastic trace batches so
239    /// CRN probe prefixes stay fixed and matrix-free trace CG can warm-start
240    /// from the previous solve of the same probe id.
241    pub stochastic_trace_state: Arc<Mutex<StochasticTraceState>>,
242}
243
244/// Builder for `InnerSolution` that provides sensible defaults and
245/// auto-computes derived quantities (nullspace_dim).
246pub struct InnerSolutionBuilder<'dp> {
247    // Required fields
248    pub(crate) log_likelihood: f64,
249    pub(crate) penalty_quadratic: f64,
250    pub(crate) hessian_op: Arc<dyn HessianOperator>,
251    pub(crate) beta: Array1<f64>,
252    pub(crate) penalty_coords: Vec<PenaltyCoordinate>,
253    pub(crate) penalty_logdet: PenaltyLogdetDerivs,
254    pub(crate) n_observations: usize,
255    pub(crate) dispersion: DispersionHandling,
256    // Optional fields with defaults
257    pub(crate) deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
258    pub(crate) firth: Option<ExactJeffreysTerm>,
259    pub(crate) hessian_logdet_correction: f64,
260    pub(crate) penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
261    pub(crate) rho_curvature_scale: f64,
262    pub(crate) rho_prior: gam_problem::RhoPrior,
263    pub(crate) nullspace_dim_override: Option<f64>,
264    // Extended hyperparameter coordinates
265    pub(crate) ext_coords: Vec<HyperCoord>,
266    pub(crate) ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
267    pub(crate) rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
268    pub(crate) fixed_drift_deriv: Option<FixedDriftDerivFn>,
269    pub(crate) contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
270    pub(crate) barrier_config: Option<BarrierConfig>,
271    pub(crate) kkt_residual: Option<ProjectedKktResidual>,
272    pub(crate) active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
273    pub(crate) gaussian_weight_log_sum_half: f64,
274    pub(crate) dp_floor_scale: f64,
275}
276
277impl<'dp> InnerSolutionBuilder<'dp> {
278    /// Create a builder with the required core fields.
279    pub fn new(
280        log_likelihood: f64,
281        penalty_quadratic: f64,
282        beta: Array1<f64>,
283        n_observations: usize,
284        hessian_op: Arc<dyn HessianOperator>,
285        penalty_coords: Vec<PenaltyCoordinate>,
286        penalty_logdet: PenaltyLogdetDerivs,
287        dispersion: DispersionHandling,
288    ) -> Self {
289        Self {
290            log_likelihood,
291            penalty_quadratic,
292            hessian_op,
293            beta,
294            penalty_coords,
295            penalty_logdet,
296            n_observations,
297            dispersion,
298            deriv_provider: Box::new(GaussianDerivatives),
299            firth: None,
300            hessian_logdet_correction: 0.0,
301            penalty_subspace_trace: None,
302            rho_curvature_scale: 1.0,
303            rho_prior: gam_problem::RhoPrior::Flat,
304            nullspace_dim_override: None,
305            ext_coords: Vec::new(),
306            ext_coord_pair_fn: None,
307            rho_ext_pair_fn: None,
308            fixed_drift_deriv: None,
309            contracted_psi_second_order: None,
310            barrier_config: None,
311            kkt_residual: None,
312            active_constraints: None,
313            gaussian_weight_log_sum_half: 0.0,
314            dp_floor_scale: 1.0,
315        }
316    }
317
318    pub fn deriv_provider(mut self, p: Box<dyn HessianDerivativeProvider + 'dp>) -> Self {
319        self.deriv_provider = p;
320        self
321    }
322
323    /// Install a pre-built Jeffreys/Firth term (Tier-A operator-backed via
324    /// `ExactJeffreysTerm::new`, or the Tier-B value-only carrier via
325    /// `ExactJeffreysTerm::value_only`).
326    pub fn firth_term(mut self, term: Option<ExactJeffreysTerm>) -> Self {
327        self.firth = term;
328        self
329    }
330
331    pub fn hessian_logdet_correction(mut self, correction: f64) -> Self {
332        self.hessian_logdet_correction = correction;
333        self
334    }
335
336    /// Install the projected-logdet trace kernel that pairs with the
337    /// `hessian_logdet_correction` on a rank-deficient penalty surface.
338    /// See [`PenaltySubspaceTrace`] for the derivation and when it is
339    /// required for gradient consistency.
340    pub fn penalty_subspace_trace(mut self, kernel: Option<Arc<PenaltySubspaceTrace>>) -> Self {
341        self.penalty_subspace_trace = kernel;
342        self
343    }
344
345    pub fn rho_curvature_scale(mut self, scale: f64) -> Self {
346        self.rho_curvature_scale = scale;
347        self
348    }
349
350    pub fn rho_prior(mut self, prior: gam_problem::RhoPrior) -> Self {
351        self.rho_prior = prior;
352        self
353    }
354
355    /// Override the auto-computed nullspace dimension.
356    ///
357    /// By default, `build()` computes nullspace_dim as
358    /// `beta.len() - sum(penalty_coord.rank())`. Use this when the caller
359    /// has a different authoritative value (e.g. from stored per-penalty dims).
360    pub fn nullspace_dim_override(mut self, dim: f64) -> Self {
361        self.nullspace_dim_override = Some(dim);
362        self
363    }
364
365    pub fn ext_coords(mut self, coords: Vec<HyperCoord>) -> Self {
366        self.ext_coords = coords;
367        self
368    }
369
370    pub fn ext_coord_pair_fn(
371        mut self,
372        f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
373    ) -> Self {
374        self.ext_coord_pair_fn = Some(f);
375        self
376    }
377
378    pub fn rho_ext_pair_fn(
379        mut self,
380        f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
381    ) -> Self {
382        self.rho_ext_pair_fn = Some(f);
383        self
384    }
385
386    pub fn fixed_drift_deriv(mut self, f: FixedDriftDerivFn) -> Self {
387        self.fixed_drift_deriv = Some(f);
388        self
389    }
390
391    /// Install the direction-contracted second-order ψ hook (#740). When set,
392    /// the outer-Hessian operator builder uses it instead of the `K²` per-pair
393    /// `base_h2` ψψ assembly. See [`ContractedPsiSecondOrderFn`].
394    pub fn contracted_psi_second_order(mut self, f: Option<ContractedPsiSecondOrderFn>) -> Self {
395        self.contracted_psi_second_order = f;
396        self
397    }
398
399    pub fn barrier_config(mut self, config: Option<BarrierConfig>) -> Self {
400        self.barrier_config = config;
401        self
402    }
403
404    pub fn kkt_residual(mut self, residual: Option<ProjectedKktResidual>) -> Self {
405        self.kkt_residual = residual;
406        self
407    }
408
409    /// Stash the active linear-inequality constraint block carried alongside the
410    /// inner solution. Used by `PenaltySubspaceTrace::with_active_constraints`
411    /// at REML/LAML evaluation time to form the constraint-aware kernel
412    /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S`.
413    pub fn active_constraints(mut self, block: Option<Arc<ActiveLinearConstraintBlock>>) -> Self {
414        self.active_constraints = block;
415        self
416    }
417
418    /// Build the `InnerSolution`, auto-computing nullspace_dim from penalty coordinates.
419    pub fn build(self) -> InnerSolution<'dp> {
420        let beta_dim = self.beta.len();
421        let penalty_dim = self.penalty_coords.len();
422        assert_eq!(
423            self.hessian_op.dim(),
424            beta_dim,
425            "InnerSolutionBuilder: Hessian dimension {} does not match beta length {}",
426            self.hessian_op.dim(),
427            beta_dim
428        );
429        for (idx, coord) in self.penalty_coords.iter().enumerate() {
430            assert_eq!(
431                coord.dim(),
432                beta_dim,
433                "InnerSolutionBuilder: penalty coordinate {idx} has dimension {} but beta length is {}",
434                coord.dim(),
435                beta_dim
436            );
437        }
438        assert_eq!(
439            self.penalty_logdet.first.len(),
440            penalty_dim,
441            "InnerSolutionBuilder: penalty logdet first-derivative length {} does not match penalty coordinate count {}",
442            self.penalty_logdet.first.len(),
443            penalty_dim
444        );
445        if let Some(second) = self.penalty_logdet.second.as_ref() {
446            assert!(
447                second.nrows() == penalty_dim && second.ncols() == penalty_dim,
448                "InnerSolutionBuilder: penalty logdet Hessian shape {}x{} does not match penalty coordinate count {}",
449                second.nrows(),
450                second.ncols(),
451                penalty_dim
452            );
453        }
454        if let Some(barrier_config) = self.barrier_config.as_ref() {
455            assert_eq!(
456                barrier_config.constrained_indices.len(),
457                barrier_config.lower_bounds.len(),
458                "InnerSolutionBuilder: barrier constrained index count {} does not match lower-bound count {}",
459                barrier_config.constrained_indices.len(),
460                barrier_config.lower_bounds.len()
461            );
462            assert_eq!(
463                barrier_config.constrained_indices.len(),
464                barrier_config.bound_signs.len(),
465                "InnerSolutionBuilder: barrier constrained index count {} does not match bound-direction count {}",
466                barrier_config.constrained_indices.len(),
467                barrier_config.bound_signs.len()
468            );
469            assert!(
470                barrier_config.tau.is_finite() && barrier_config.tau >= 0.0,
471                "InnerSolutionBuilder: barrier tau must be finite and non-negative, got {}",
472                barrier_config.tau
473            );
474            for ((&idx, &lower_bound), &sign) in barrier_config
475                .constrained_indices
476                .iter()
477                .zip(barrier_config.lower_bounds.iter())
478                .zip(barrier_config.bound_signs.iter())
479            {
480                assert!(
481                    idx < beta_dim,
482                    "InnerSolutionBuilder: barrier constrained index {idx} out of bounds for beta length {beta_dim}"
483                );
484                assert!(
485                    lower_bound.is_finite(),
486                    "InnerSolutionBuilder: barrier lower bound for beta[{idx}] must be finite, got {lower_bound}"
487                );
488                assert!(
489                    sign == 1.0 || sign == -1.0,
490                    "InnerSolutionBuilder: barrier bound direction for beta[{idx}] must be ±1, got {sign}"
491                );
492            }
493        }
494        if let Some(active_constraints) = self.active_constraints.as_ref() {
495            assert_eq!(
496                active_constraints.a.ncols(),
497                beta_dim,
498                "InnerSolutionBuilder: active constraint width {} does not match beta length {}",
499                active_constraints.a.ncols(),
500                beta_dim
501            );
502        }
503        let nullspace_dim = self.nullspace_dim_override.unwrap_or_else(|| {
504            let penalty_rank: usize = self
505                .penalty_coords
506                .iter()
507                .map(PenaltyCoordinate::rank)
508                .sum();
509            beta_dim.saturating_sub(penalty_rank) as f64
510        });
511
512        InnerSolution {
513            log_likelihood: self.log_likelihood,
514            penalty_quadratic: self.penalty_quadratic,
515            hessian_op: self.hessian_op,
516            beta: self.beta,
517            penalty_coords: self.penalty_coords,
518            penalty_logdet: self.penalty_logdet,
519            deriv_provider: self.deriv_provider,
520            firth: self.firth,
521            hessian_logdet_correction: self.hessian_logdet_correction,
522            penalty_subspace_trace: self.penalty_subspace_trace,
523            rho_curvature_scale: self.rho_curvature_scale,
524            rho_prior: self.rho_prior,
525            n_observations: self.n_observations,
526            nullspace_dim,
527            gaussian_weight_log_sum_half: self.gaussian_weight_log_sum_half,
528            dp_floor_scale: self.dp_floor_scale,
529            dispersion: self.dispersion,
530            ext_coords: self.ext_coords,
531            ext_coord_pair_fn: self.ext_coord_pair_fn,
532            rho_ext_pair_fn: self.rho_ext_pair_fn,
533            fixed_drift_deriv: self.fixed_drift_deriv,
534            contracted_psi_second_order: self.contracted_psi_second_order,
535            barrier_config: self.barrier_config,
536            kkt_residual: self.kkt_residual,
537            active_constraints: self.active_constraints,
538            stochastic_trace_state: Arc::new(Mutex::new(StochasticTraceState::default())),
539        }
540    }
541}