Skip to main content

gam_solve/reml/reml_outer_engine/
inner_solution.rs

1use super::*;
2
/// How the outer objective treats the dispersion parameter φ: profiled out
/// (Gaussian REML) or held fixed (non-Gaussian LAML / penalized likelihood).
#[derive(Debug, Clone)]
pub enum DispersionHandling {
    /// Gaussian REML with φ profiled out as φ̂ = D_p / (n − M_p).
    ///
    /// The cost carries the (n − M_p)/2 · log(2πφ̂) term, the gradient carries
    /// the derivative of the profiled scale, and both log-determinant terms
    /// are always present.
    ProfiledGaussian,
    /// Dispersion held at a fixed `phi`: non-Gaussian LAML or maximum
    /// penalized likelihood.
    ///
    /// * `include_logdet_h` — add ½ log|H| (true for full LAML, false for
    ///   MPL/PQL).
    /// * `include_logdet_s` — add −½ log|S|₊.
    ///
    /// Typical settings:
    /// * standard LAML:
    ///   `Fixed { phi: 1.0, include_logdet_h: true, include_logdet_s: true }`
    /// * maximum penalized likelihood:
    ///   `Fixed { phi: 1.0, include_logdet_h: false, include_logdet_s: false }`
    Fixed {
        /// The fixed dispersion value φ.
        phi: f64,
        /// Whether ½ log|H| enters the objective.
        include_logdet_h: bool,
        /// Whether −½ log|S|₊ enters the objective.
        include_logdet_s: bool,
    },
}
25
26/// The unified inner solution produced by any inner solver.
27///
28/// Contains everything the outer REML/LAML evaluator needs. Produced by:
29/// - Single-block PIRLS (via `PirlsResult::into_inner_solution()`)
30/// - Blockwise coupled Newton (via `BlockwiseInnerResult::into_inner_solution()`)
31/// - Sparse Cholesky (via `SparsePenalizedSystem::into_inner_solution()`)
32pub struct InnerSolution<'dp> {
33    // === Objective ingredients ===
34    /// ℓ(β̂) — log-likelihood at the converged mode.
35    /// For Gaussian: −0.5 × deviance (RSS). For GLMs: actual log-likelihood.
36    pub log_likelihood: f64,
37
38    /// β̂ᵀS(ρ)β̂ — penalty quadratic form at the mode.
39    pub penalty_quadratic: f64,
40
41    // === The factorization (single source of truth for all linear algebra) ===
42    /// The Hessian operator providing logdet, trace, and solve.
43    /// Both cost and gradient use this same object.
44    ///
45    /// IMPORTANT: This MUST encode the **observed** Hessian H_obs = X'W_obs X + S
46    /// at the converged mode, where W_obs includes the residual-dependent correction
47    /// for non-canonical links. Using expected Fisher H_Fisher = X'W_Fisher X + S
48    /// would make this a PQL surrogate rather than the exact Laplace approximation.
49    /// See response.md Section 3 for the mathematical justification.
50    pub hessian_op: Arc<dyn HessianOperator>,
51
52    // === Coefficients and penalty structure ===
53    /// β̂ — coefficients at the converged mode (in the operator's native basis).
54    pub beta: Array1<f64>,
55
56    /// Penalty coordinates for the rho block.
57    ///
58    /// Each coordinate represents one smoothing-parameter direction
59    ///   A_k = λ_k S_k
60    /// through either a full-root or a block-local root.
61    pub penalty_coords: Vec<PenaltyCoordinate>,
62
63    /// Derivatives of log|S(ρ)|₊ — precomputed from penalty structure.
64    pub penalty_logdet: PenaltyLogdetDerivs,
65
66    // === Family-specific derivative info ===
67    /// Provider of third-derivative corrections for non-Gaussian families.
68    ///
69    /// The c and d arrays (dW/deta, d^2W/deta^2) carried by this provider MUST
70    /// be the **observed** derivatives, not the Fisher derivatives. For non-canonical
71    /// links the observed c/d include residual-dependent corrections:
72    ///   c_obs = c_Fisher + h'*B - (y-mu)*B_eta
73    ///   d_obs = d_Fisher + h''*B + 2*h'*B_eta - (y-mu)*B_etaeta
74    /// These corrections matter for the outer gradient (C[v] correction) and
75    /// outer Hessian (Q[v_k, v_l] correction). See response.md Section 3.
76    pub deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
77
78    // === Corrections ===
79    /// Optional exact Jeffreys/Firth term in the active coefficient basis.
80    pub firth: Option<ExactJeffreysTerm>,
81
82    /// Additive correction for the Hessian logdet when `hessian_op` encodes a
83    /// uniformly rescaled exact curvature matrix.
84    pub hessian_logdet_correction: f64,
85
86    /// When the cost uses `log|U_Sᵀ H U_S|_+` (rank-deficient LAML fix),
87    /// this carries the matching projected kernel so the gradient trace
88    /// `tr(K · Ḣ)` agrees with the cost's derivative.  See
89    /// [`PenaltySubspaceTrace`] for the full derivation.
90    pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
91
92    /// Uniform scale `s` applied to rho-coordinate penalty derivatives in the
93    /// H-dependent trace / solve parts of the outer calculus.
94    ///
95    /// ## Contract (CRITICAL — gradient/cost consistency)
96    ///
97    /// `rho_curvature_scale` is NOT a free knob.  It encodes the convention
98    /// that the supplied `hessian_op` represents the **rescaled** curvature
99    /// `H_op = s · (∇²(-ℓ) + Σ_k e^{ρ_k} S_k)`, i.e. every contribution to
100    /// the curvature (likelihood Hessian AND penalty `λ_k S_k`) has been
101    /// uniformly multiplied by `s` before reaching the evaluator.  Under this
102    /// convention:
103    ///
104    /// * `∂H_op/∂ρ_k = s · λ_k S_k` (matches the `curvature_lambdas = s · λ`
105    ///   drift used inside the gradient's trace term),
106    /// * `K = H_op⁻¹ = (1/s) · (∇²(-ℓ) + λS)⁻¹`,
107    /// * `tr(K · ∂H_op/∂ρ_k) = tr((∇²(-ℓ) + λS)⁻¹ · λ_k S_k)` (the analytic
108    ///   gradient of the **unscaled** `log|H|`),
109    /// * `log|H_op| = log|∇²(-ℓ) + λS| + p · log(s)`, which the caller MUST
110    ///   un-scale by supplying `hessian_logdet_correction += −p · log(s)` so
111    ///   that `hop.logdet() + hessian_logdet_correction` evaluates the same
112    ///   unscaled `log|H|` whose derivative the gradient trace computes.
113    ///
114    /// Callers that set `rho_curvature_scale ≠ 1` without ALSO pre-scaling
115    /// `hessian_op` AND adding the matching `−p·log(s)` term to
116    /// `hessian_logdet_correction` will get a gradient that is off by the
117    /// factor `s` from `dV/dρ_k`.  The unified evaluator does **not** scale
118    /// `hop` for the caller — that would defeat the purpose of the
119    /// curvature-conditioning trick survival families use to keep the
120    /// outer eigendecomposition numerically stable.
121    ///
122    /// See `survival::location_scale::exact_newton_outer_curvature` for the
123    /// canonical example: `rho_curvature_scale = exp(-log_scale)` paired with
124    /// `hessian_logdet_correction = p · log_scale = −p · log(scale)`.
125    ///
126    /// The evaluator enforces `rho_curvature_scale > 0` and finite; pass
127    /// `1.0` (the documented default) when no curvature conditioning is in
128    /// play.
129    pub rho_curvature_scale: f64,
130
131    /// Configured prior over rho coordinates. The evaluator receives the
132    /// realized cost/gradient tuple separately; this copy lets EFS use the
133    /// conjugate Gamma rate in its multiplicative denominator.
134    pub rho_prior: gam_problem::RhoPrior,
135
136    // === Model dimensions ===
137    /// Number of observations.
138    pub n_observations: usize,
139
140    /// M_p: dimension of the penalty null space (unpenalized coefficients).
141    pub nullspace_dim: f64,
142
143    /// ½·Σᵢ log(wᵢ) — half the sum of log prior weights.
144    ///
145    /// This is the per-observation Gaussian normalization constant that the
146    /// `log_likelihood` (computed by
147    /// [`calculate_loglikelihood_omitting_constants`]) deliberately drops. The
148    /// full weighted-Gaussian negative log-likelihood normalization is
149    ///   ½·Σᵢ log(2π·φ/wᵢ) = (n/2)·log(2πφ) − ½·Σᵢ log(wᵢ),
150    /// because `Var(yᵢ) = φ/wᵢ` under inverse-variance prior weights.
151    ///
152    /// Dropping `−½·Σ log(wᵢ)` does not move the ρ-argmin in exact arithmetic
153    /// (it is constant in ρ), but it makes the ProfiledGaussian objective VALUE
154    /// scale-dependent: under a global rescale `w → c·w` the invariance-
155    /// preserving smoothing `λ → c·λ` leaves the cost SHAPE fixed but inflates
156    /// its absolute value by `(n/2)·log c`. That inflation breaks the exact
157    /// weight-scale invariance of the selected λ̂ / EDF / fit (issue #877).
158    /// Restoring this term makes the ProfiledGaussian cost value exactly
159    /// invariant to `w → c·w` (with σ̂² absorbing the c factor), matching mgcv.
160    ///
161    /// Only consumed by the `ProfiledGaussian` arm; the `Fixed`-dispersion arm
162    /// already omits the Gaussian normalization constant by design and is not
163    /// affected.
164    pub gaussian_weight_log_sum_half: f64,
165
166    /// Deviance scale `D₀` used as the *relative* reference for the smooth
167    /// penalized-deviance floor (see [`crate::estimate::smooth_floor_dp`]).
168    ///
169    /// Set to the weighted null deviance of the Gaussian response,
170    /// `D₀ = Σ wᵢ(yᵢ − ȳ_w)²`, which is the natural upper reference for
171    /// `D_p` and — crucially — transforms as `D₀ → a²·D₀` under a response
172    /// rescale `y → a·y`, exactly as `D_p` does. Flooring `D_p` at a fixed
173    /// fraction of `D₀` therefore keeps the profiled Gaussian REML criterion
174    /// exactly scale-equivariant (issue #1127); an absolute floor does not.
175    ///
176    /// Only consumed by the `ProfiledGaussian` arm. Defaults to `1.0`, which
177    /// reproduces the historical absolute floor byte-for-byte for every caller
178    /// that does not supply a response scale.
179    pub dp_floor_scale: f64,
180
181    /// How the dispersion parameter is handled.
182    pub dispersion: DispersionHandling,
183
184    // === Extended hyperparameter coordinates (ψ / τ) ===
185    /// External (non-ρ) hyperparameter coordinates with their fixed-β objects.
186    /// These are appended after the ρ coordinates in the gradient/Hessian output.
187    pub ext_coords: Vec<HyperCoord>,
188
189    /// Callback to compute second-order fixed-β objects for a pair (i, j)
190    /// of external coordinates (or external × ρ cross pairs).
191    /// Arguments: (ext_index_i, ext_index_j) → HyperCoordPair.
192    /// When None, the outer Hessian is not computed for extended coordinates.
193    ///
194    /// `Arc`-backed ([`HyperCoordPairFn`]) so a derived solution — notably the
195    /// tangent-projected solution built under active inequality constraints —
196    /// can clone the same callback through to `ValueGradientHessian` assembly.
197    pub ext_coord_pair_fn: Option<HyperCoordPairFn>,
198
199    /// Callback for ρ × ext cross pairs: (rho_index, ext_index) → HyperCoordPair.
200    pub rho_ext_pair_fn: Option<HyperCoordPairFn>,
201
202    /// M_i[u] = D_β B_i[u] callback for extended coordinates.
203    /// Arguments: (ext_index, direction) → correction matrix.
204    ///
205    /// `Arc`-backed ([`SharedFixedDriftDerivFn`]) so the tangent-projected
206    /// solution can clone it through — the drift `M` is a p-space matrix the
207    /// wrapper projects via `ZᵀMZ` in `trace_logdet_*`.
208    pub fixed_drift_deriv: Option<SharedFixedDriftDerivFn>,
209
210    /// Direction-contracted second-order ψ hook for the profiled θ-HVP (#740).
211    /// When present, the outer-Hessian operator builder skips the `K²` per-pair
212    /// `base_h2` ψψ assembly and instead applies this once per matvec to obtain
213    /// every output row's `tr(K · D²_ψ H_L[ψ_i, ψ(α)])` in a single family row
214    /// pass. `None` keeps the exact per-pair assembly. See
215    /// [`ContractedPsiSecondOrderFn`].
216    pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
217
218    /// Optional log-barrier configuration for monotonicity-constrained coefficients.
219    /// When present, the barrier cost and Hessian corrections are added to the
220    /// outer REML/LAML objective.
221    pub barrier_config: Option<BarrierConfig>,
222
223    /// Optional inner KKT residual `r = ∇_β L_pen(β̂)` at the converged β̂,
224    /// already projected onto the free subspace (see [`ProjectedKktResidual`]
225    /// for the invariant and why the type wraps this). `Some` activates the
226    /// implicit-function-theorem corrections in `reml_laml_evaluate` (cost
227    /// gets `−½ rᵀ H⁻¹ r`, ρ-gradient and ρρ Hessian get the matching first
228    /// and second derivatives of that same scalar correction). `None` keeps
229    /// the envelope-only behaviour for callers that genuinely guarantee
230    /// exact KKT.
231    pub kkt_residual: Option<ProjectedKktResidual>,
232
233    /// Optional active linear-inequality constraints at the converged inner
234    /// iterate. `Some(rows)` means the joint constraint matrix's row indices
235    /// in `rows.active_indices` are pinned (treated as equality constraints
236    /// at the cert point). The unified evaluator combines this with the
237    /// `penalty_subspace_trace` to form the **constraint-aware** kernel
238    /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S` for per-coordinate IFT mode
239    /// responses `v_k = ∂β/∂ρ_k`. See [`ConstrainedSubspaceKernel`] for
240    /// the full derivation and consistency with `log|U_Tᵀ H U_T|`.
241    ///
242    /// `None` is the legacy/unconstrained path (no active inequality
243    /// constraints to project against).
244    pub active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
245
246    /// Fit-level stochastic trace state. Shared by stochastic trace batches so
247    /// CRN probe prefixes stay fixed and matrix-free trace CG can warm-start
248    /// from the previous solve of the same probe id.
249    pub stochastic_trace_state: Arc<Mutex<StochasticTraceState>>,
250}
251
252/// Builder for `InnerSolution` that provides sensible defaults and
253/// auto-computes derived quantities (nullspace_dim).
254pub struct InnerSolutionBuilder<'dp> {
255    // Required fields
256    pub(crate) log_likelihood: f64,
257    pub(crate) penalty_quadratic: f64,
258    pub(crate) hessian_op: Arc<dyn HessianOperator>,
259    pub(crate) beta: Array1<f64>,
260    pub(crate) penalty_coords: Vec<PenaltyCoordinate>,
261    pub(crate) penalty_logdet: PenaltyLogdetDerivs,
262    pub(crate) n_observations: usize,
263    pub(crate) dispersion: DispersionHandling,
264    // Optional fields with defaults
265    pub(crate) deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
266    pub(crate) firth: Option<ExactJeffreysTerm>,
267    pub(crate) hessian_logdet_correction: f64,
268    pub(crate) penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
269    pub(crate) rho_curvature_scale: f64,
270    pub(crate) rho_prior: gam_problem::RhoPrior,
271    pub(crate) nullspace_dim_override: Option<f64>,
272    // Extended hyperparameter coordinates
273    pub(crate) ext_coords: Vec<HyperCoord>,
274    pub(crate) ext_coord_pair_fn: Option<HyperCoordPairFn>,
275    pub(crate) rho_ext_pair_fn: Option<HyperCoordPairFn>,
276    pub(crate) fixed_drift_deriv: Option<SharedFixedDriftDerivFn>,
277    pub(crate) contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
278    pub(crate) barrier_config: Option<BarrierConfig>,
279    pub(crate) kkt_residual: Option<ProjectedKktResidual>,
280    pub(crate) active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
281    pub(crate) gaussian_weight_log_sum_half: f64,
282    pub(crate) dp_floor_scale: f64,
283}
284
285impl<'dp> InnerSolutionBuilder<'dp> {
286    /// Create a builder with the required core fields.
287    pub fn new(
288        log_likelihood: f64,
289        penalty_quadratic: f64,
290        beta: Array1<f64>,
291        n_observations: usize,
292        hessian_op: Arc<dyn HessianOperator>,
293        penalty_coords: Vec<PenaltyCoordinate>,
294        penalty_logdet: PenaltyLogdetDerivs,
295        dispersion: DispersionHandling,
296    ) -> Self {
297        Self {
298            log_likelihood,
299            penalty_quadratic,
300            hessian_op,
301            beta,
302            penalty_coords,
303            penalty_logdet,
304            n_observations,
305            dispersion,
306            deriv_provider: Box::new(GaussianDerivatives),
307            firth: None,
308            hessian_logdet_correction: 0.0,
309            penalty_subspace_trace: None,
310            rho_curvature_scale: 1.0,
311            rho_prior: gam_problem::RhoPrior::Flat,
312            nullspace_dim_override: None,
313            ext_coords: Vec::new(),
314            ext_coord_pair_fn: None,
315            rho_ext_pair_fn: None,
316            fixed_drift_deriv: None,
317            contracted_psi_second_order: None,
318            barrier_config: None,
319            kkt_residual: None,
320            active_constraints: None,
321            gaussian_weight_log_sum_half: 0.0,
322            dp_floor_scale: 1.0,
323        }
324    }
325
326    pub fn deriv_provider(mut self, p: Box<dyn HessianDerivativeProvider + 'dp>) -> Self {
327        self.deriv_provider = p;
328        self
329    }
330
331    /// Install a pre-built Jeffreys/Firth term (Tier-A operator-backed via
332    /// `ExactJeffreysTerm::new`, or the Tier-B value-only carrier via
333    /// `ExactJeffreysTerm::value_only`).
334    pub fn firth_term(mut self, term: Option<ExactJeffreysTerm>) -> Self {
335        self.firth = term;
336        self
337    }
338
339    pub fn hessian_logdet_correction(mut self, correction: f64) -> Self {
340        self.hessian_logdet_correction = correction;
341        self
342    }
343
344    /// Install the projected-logdet trace kernel that pairs with the
345    /// `hessian_logdet_correction` on a rank-deficient penalty surface.
346    /// See [`PenaltySubspaceTrace`] for the derivation and when it is
347    /// required for gradient consistency.
348    pub fn penalty_subspace_trace(mut self, kernel: Option<Arc<PenaltySubspaceTrace>>) -> Self {
349        self.penalty_subspace_trace = kernel;
350        self
351    }
352
353    pub fn rho_curvature_scale(mut self, scale: f64) -> Self {
354        self.rho_curvature_scale = scale;
355        self
356    }
357
358    pub fn rho_prior(mut self, prior: gam_problem::RhoPrior) -> Self {
359        self.rho_prior = prior;
360        self
361    }
362
363    /// Override the auto-computed nullspace dimension.
364    ///
365    /// By default, `build()` computes nullspace_dim as
366    /// `beta.len() - sum(penalty_coord.rank())`. Use this when the caller
367    /// has a different authoritative value (e.g. from stored per-penalty dims).
368    pub fn nullspace_dim_override(mut self, dim: f64) -> Self {
369        self.nullspace_dim_override = Some(dim);
370        self
371    }
372
373    pub fn ext_coords(mut self, coords: Vec<HyperCoord>) -> Self {
374        self.ext_coords = coords;
375        self
376    }
377
378    pub fn ext_coord_pair_fn(
379        mut self,
380        f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
381    ) -> Self {
382        // `Arc::from(Box<dyn …>)` is a zero-cost re-tag of the existing
383        // allocation; storing it as `Arc` lets the projected solution clone
384        // the callback through.
385        self.ext_coord_pair_fn = Some(HyperCoordPairFn::from(f));
386        self
387    }
388
389    pub fn rho_ext_pair_fn(
390        mut self,
391        f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
392    ) -> Self {
393        self.rho_ext_pair_fn = Some(HyperCoordPairFn::from(f));
394        self
395    }
396
397    pub fn fixed_drift_deriv(mut self, f: FixedDriftDerivFn) -> Self {
398        // `Arc::from(Box<dyn …>)` re-tags the existing allocation for shared
399        // ownership so the projected solution can clone the callback through.
400        self.fixed_drift_deriv = Some(SharedFixedDriftDerivFn::from(f));
401        self
402    }
403
404    /// Install the direction-contracted second-order ψ hook (#740). When set,
405    /// the outer-Hessian operator builder uses it instead of the `K²` per-pair
406    /// `base_h2` ψψ assembly. See [`ContractedPsiSecondOrderFn`].
407    pub fn contracted_psi_second_order(mut self, f: Option<ContractedPsiSecondOrderFn>) -> Self {
408        self.contracted_psi_second_order = f;
409        self
410    }
411
412    pub fn barrier_config(mut self, config: Option<BarrierConfig>) -> Self {
413        self.barrier_config = config;
414        self
415    }
416
417    pub fn kkt_residual(mut self, residual: Option<ProjectedKktResidual>) -> Self {
418        self.kkt_residual = residual;
419        self
420    }
421
422    /// Stash the active linear-inequality constraint block carried alongside the
423    /// inner solution. Used by `PenaltySubspaceTrace::with_active_constraints`
424    /// at REML/LAML evaluation time to form the constraint-aware kernel
425    /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S`.
426    pub fn active_constraints(mut self, block: Option<Arc<ActiveLinearConstraintBlock>>) -> Self {
427        self.active_constraints = block;
428        self
429    }
430
431    /// Build the `InnerSolution`, auto-computing nullspace_dim from penalty coordinates.
432    pub fn build(self) -> InnerSolution<'dp> {
433        let beta_dim = self.beta.len();
434        let penalty_dim = self.penalty_coords.len();
435        assert_eq!(
436            self.hessian_op.dim(),
437            beta_dim,
438            "InnerSolutionBuilder: Hessian dimension {} does not match beta length {}",
439            self.hessian_op.dim(),
440            beta_dim
441        );
442        for (idx, coord) in self.penalty_coords.iter().enumerate() {
443            assert_eq!(
444                coord.dim(),
445                beta_dim,
446                "InnerSolutionBuilder: penalty coordinate {idx} has dimension {} but beta length is {}",
447                coord.dim(),
448                beta_dim
449            );
450        }
451        assert_eq!(
452            self.penalty_logdet.first.len(),
453            penalty_dim,
454            "InnerSolutionBuilder: penalty logdet first-derivative length {} does not match penalty coordinate count {}",
455            self.penalty_logdet.first.len(),
456            penalty_dim
457        );
458        if let Some(second) = self.penalty_logdet.second.as_ref() {
459            assert!(
460                second.nrows() == penalty_dim && second.ncols() == penalty_dim,
461                "InnerSolutionBuilder: penalty logdet Hessian shape {}x{} does not match penalty coordinate count {}",
462                second.nrows(),
463                second.ncols(),
464                penalty_dim
465            );
466        }
467        if let Some(barrier_config) = self.barrier_config.as_ref() {
468            assert_eq!(
469                barrier_config.constrained_indices.len(),
470                barrier_config.lower_bounds.len(),
471                "InnerSolutionBuilder: barrier constrained index count {} does not match lower-bound count {}",
472                barrier_config.constrained_indices.len(),
473                barrier_config.lower_bounds.len()
474            );
475            assert_eq!(
476                barrier_config.constrained_indices.len(),
477                barrier_config.bound_signs.len(),
478                "InnerSolutionBuilder: barrier constrained index count {} does not match bound-direction count {}",
479                barrier_config.constrained_indices.len(),
480                barrier_config.bound_signs.len()
481            );
482            assert!(
483                barrier_config.tau.is_finite() && barrier_config.tau >= 0.0,
484                "InnerSolutionBuilder: barrier tau must be finite and non-negative, got {}",
485                barrier_config.tau
486            );
487            for ((&idx, &lower_bound), &sign) in barrier_config
488                .constrained_indices
489                .iter()
490                .zip(barrier_config.lower_bounds.iter())
491                .zip(barrier_config.bound_signs.iter())
492            {
493                assert!(
494                    idx < beta_dim,
495                    "InnerSolutionBuilder: barrier constrained index {idx} out of bounds for beta length {beta_dim}"
496                );
497                assert!(
498                    lower_bound.is_finite(),
499                    "InnerSolutionBuilder: barrier lower bound for beta[{idx}] must be finite, got {lower_bound}"
500                );
501                assert!(
502                    sign == 1.0 || sign == -1.0,
503                    "InnerSolutionBuilder: barrier bound direction for beta[{idx}] must be ±1, got {sign}"
504                );
505            }
506        }
507        if let Some(active_constraints) = self.active_constraints.as_ref() {
508            assert_eq!(
509                active_constraints.a.ncols(),
510                beta_dim,
511                "InnerSolutionBuilder: active constraint width {} does not match beta length {}",
512                active_constraints.a.ncols(),
513                beta_dim
514            );
515        }
516        let nullspace_dim = self.nullspace_dim_override.unwrap_or_else(|| {
517            let penalty_rank: usize = self
518                .penalty_coords
519                .iter()
520                .map(PenaltyCoordinate::rank)
521                .sum();
522            beta_dim.saturating_sub(penalty_rank) as f64
523        });
524
525        InnerSolution {
526            log_likelihood: self.log_likelihood,
527            penalty_quadratic: self.penalty_quadratic,
528            hessian_op: self.hessian_op,
529            beta: self.beta,
530            penalty_coords: self.penalty_coords,
531            penalty_logdet: self.penalty_logdet,
532            deriv_provider: self.deriv_provider,
533            firth: self.firth,
534            hessian_logdet_correction: self.hessian_logdet_correction,
535            penalty_subspace_trace: self.penalty_subspace_trace,
536            rho_curvature_scale: self.rho_curvature_scale,
537            rho_prior: self.rho_prior,
538            n_observations: self.n_observations,
539            nullspace_dim,
540            gaussian_weight_log_sum_half: self.gaussian_weight_log_sum_half,
541            dp_floor_scale: self.dp_floor_scale,
542            dispersion: self.dispersion,
543            ext_coords: self.ext_coords,
544            ext_coord_pair_fn: self.ext_coord_pair_fn,
545            rho_ext_pair_fn: self.rho_ext_pair_fn,
546            fixed_drift_deriv: self.fixed_drift_deriv,
547            contracted_psi_second_order: self.contracted_psi_second_order,
548            barrier_config: self.barrier_config,
549            kkt_residual: self.kkt_residual,
550            active_constraints: self.active_constraints,
551            stochastic_trace_state: Arc::new(Mutex::new(StochasticTraceState::default())),
552        }
553    }
554}