// gam_solve/reml/reml_outer_engine/inner_solution.rs
use super::*;
2
/// Specifies whether the model uses a profiled scale (Gaussian REML) or a
/// fixed dispersion (non-Gaussian LAML).
#[derive(Clone, Debug)]
pub enum DispersionHandling {
    /// Gaussian REML: φ̂ = D_p / (n − M_p), profiled out of the objective.
    ///
    /// The cost includes (n−M_p)/2 · log(2πφ̂) and the gradient includes the
    /// profiled scale derivative. Always includes both logdet terms.
    ProfiledGaussian,
    /// Non-Gaussian LAML or maximum penalized likelihood.
    ///
    /// Standard LAML:
    /// `Fixed { phi: 1.0, include_logdet_h: true, include_logdet_s: true }`.
    ///
    /// MaxPenalizedLikelihood:
    /// `Fixed { phi: 1.0, include_logdet_h: false, include_logdet_s: false }`.
    Fixed {
        /// Fixed dispersion value φ.
        phi: f64,
        /// Whether ½ log|H| is included (true for full LAML, false for
        /// MPL/PQL).
        include_logdet_h: bool,
        /// Whether −½ log|S|₊ is included.
        include_logdet_s: bool,
    },
}
25
26/// The unified inner solution produced by any inner solver.
27///
28/// Contains everything the outer REML/LAML evaluator needs. Produced by:
29/// - Single-block PIRLS (via `PirlsResult::into_inner_solution()`)
30/// - Blockwise coupled Newton (via `BlockwiseInnerResult::into_inner_solution()`)
31/// - Sparse Cholesky (via `SparsePenalizedSystem::into_inner_solution()`)
32pub struct InnerSolution<'dp> {
33 // === Objective ingredients ===
34 /// ℓ(β̂) — log-likelihood at the converged mode.
35 /// For Gaussian: −0.5 × deviance (RSS). For GLMs: actual log-likelihood.
36 pub log_likelihood: f64,
37
38 /// β̂ᵀS(ρ)β̂ — penalty quadratic form at the mode.
39 pub penalty_quadratic: f64,
40
41 // === The factorization (single source of truth for all linear algebra) ===
42 /// The Hessian operator providing logdet, trace, and solve.
43 /// Both cost and gradient use this same object.
44 ///
45 /// IMPORTANT: This MUST encode the **observed** Hessian H_obs = X'W_obs X + S
46 /// at the converged mode, where W_obs includes the residual-dependent correction
47 /// for non-canonical links. Using expected Fisher H_Fisher = X'W_Fisher X + S
48 /// would make this a PQL surrogate rather than the exact Laplace approximation.
49 /// See response.md Section 3 for the mathematical justification.
50 pub hessian_op: Arc<dyn HessianOperator>,
51
52 // === Coefficients and penalty structure ===
53 /// β̂ — coefficients at the converged mode (in the operator's native basis).
54 pub beta: Array1<f64>,
55
56 /// Penalty coordinates for the rho block.
57 ///
58 /// Each coordinate represents one smoothing-parameter direction
59 /// A_k = λ_k S_k
60 /// through either a full-root or a block-local root.
61 pub penalty_coords: Vec<PenaltyCoordinate>,
62
63 /// Derivatives of log|S(ρ)|₊ — precomputed from penalty structure.
64 pub penalty_logdet: PenaltyLogdetDerivs,
65
66 // === Family-specific derivative info ===
67 /// Provider of third-derivative corrections for non-Gaussian families.
68 ///
69 /// The c and d arrays (dW/deta, d^2W/deta^2) carried by this provider MUST
70 /// be the **observed** derivatives, not the Fisher derivatives. For non-canonical
71 /// links the observed c/d include residual-dependent corrections:
72 /// c_obs = c_Fisher + h'*B - (y-mu)*B_eta
73 /// d_obs = d_Fisher + h''*B + 2*h'*B_eta - (y-mu)*B_etaeta
74 /// These corrections matter for the outer gradient (C[v] correction) and
75 /// outer Hessian (Q[v_k, v_l] correction). See response.md Section 3.
76 pub deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
77
78 // === Corrections ===
79 /// Optional exact Jeffreys/Firth term in the active coefficient basis.
80 pub firth: Option<ExactJeffreysTerm>,
81
82 /// Additive correction for the Hessian logdet when `hessian_op` encodes a
83 /// uniformly rescaled exact curvature matrix.
84 pub hessian_logdet_correction: f64,
85
86 /// When the cost uses `log|U_Sᵀ H U_S|_+` (rank-deficient LAML fix),
87 /// this carries the matching projected kernel so the gradient trace
88 /// `tr(K · Ḣ)` agrees with the cost's derivative. See
89 /// [`PenaltySubspaceTrace`] for the full derivation.
90 pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
91
92 /// Uniform scale `s` applied to rho-coordinate penalty derivatives in the
93 /// H-dependent trace / solve parts of the outer calculus.
94 ///
95 /// ## Contract (CRITICAL — gradient/cost consistency)
96 ///
97 /// `rho_curvature_scale` is NOT a free knob. It encodes the convention
98 /// that the supplied `hessian_op` represents the **rescaled** curvature
99 /// `H_op = s · (∇²(-ℓ) + Σ_k e^{ρ_k} S_k)`, i.e. every contribution to
100 /// the curvature (likelihood Hessian AND penalty `λ_k S_k`) has been
101 /// uniformly multiplied by `s` before reaching the evaluator. Under this
102 /// convention:
103 ///
104 /// * `∂H_op/∂ρ_k = s · λ_k S_k` (matches the `curvature_lambdas = s · λ`
105 /// drift used inside the gradient's trace term),
106 /// * `K = H_op⁻¹ = (1/s) · (∇²(-ℓ) + λS)⁻¹`,
107 /// * `tr(K · ∂H_op/∂ρ_k) = tr((∇²(-ℓ) + λS)⁻¹ · λ_k S_k)` (the analytic
108 /// gradient of the **unscaled** `log|H|`),
109 /// * `log|H_op| = log|∇²(-ℓ) + λS| + p · log(s)`, which the caller MUST
110 /// un-scale by supplying `hessian_logdet_correction += −p · log(s)` so
111 /// that `hop.logdet() + hessian_logdet_correction` evaluates the same
112 /// unscaled `log|H|` whose derivative the gradient trace computes.
113 ///
114 /// Callers that set `rho_curvature_scale ≠ 1` without ALSO pre-scaling
115 /// `hessian_op` AND adding the matching `−p·log(s)` term to
116 /// `hessian_logdet_correction` will get a gradient that is off by the
117 /// factor `s` from `dV/dρ_k`. The unified evaluator does **not** scale
118 /// `hop` for the caller — that would defeat the purpose of the
119 /// curvature-conditioning trick survival families use to keep the
120 /// outer eigendecomposition numerically stable.
121 ///
122 /// See `survival::location_scale::exact_newton_outer_curvature` for the
123 /// canonical example: `rho_curvature_scale = exp(-log_scale)` paired with
124 /// `hessian_logdet_correction = p · log_scale = −p · log(scale)`.
125 ///
126 /// The evaluator enforces `rho_curvature_scale > 0` and finite; pass
127 /// `1.0` (the documented default) when no curvature conditioning is in
128 /// play.
129 pub rho_curvature_scale: f64,
130
131 /// Configured prior over rho coordinates. The evaluator receives the
132 /// realized cost/gradient tuple separately; this copy lets EFS use the
133 /// conjugate Gamma rate in its multiplicative denominator.
134 pub rho_prior: gam_problem::RhoPrior,
135
136 // === Model dimensions ===
137 /// Number of observations.
138 pub n_observations: usize,
139
140 /// M_p: dimension of the penalty null space (unpenalized coefficients).
141 pub nullspace_dim: f64,
142
143 /// ½·Σᵢ log(wᵢ) — half the sum of log prior weights.
144 ///
145 /// This is the per-observation Gaussian normalization constant that the
146 /// `log_likelihood` (computed by
147 /// [`calculate_loglikelihood_omitting_constants`]) deliberately drops. The
148 /// full weighted-Gaussian negative log-likelihood normalization is
149 /// ½·Σᵢ log(2π·φ/wᵢ) = (n/2)·log(2πφ) − ½·Σᵢ log(wᵢ),
150 /// because `Var(yᵢ) = φ/wᵢ` under inverse-variance prior weights.
151 ///
152 /// Dropping `−½·Σ log(wᵢ)` does not move the ρ-argmin in exact arithmetic
153 /// (it is constant in ρ), but it makes the ProfiledGaussian objective VALUE
154 /// scale-dependent: under a global rescale `w → c·w` the invariance-
155 /// preserving smoothing `λ → c·λ` leaves the cost SHAPE fixed but inflates
156 /// its absolute value by `(n/2)·log c`. That inflation breaks the exact
157 /// weight-scale invariance of the selected λ̂ / EDF / fit (issue #877).
158 /// Restoring this term makes the ProfiledGaussian cost value exactly
159 /// invariant to `w → c·w` (with σ̂² absorbing the c factor), matching mgcv.
160 ///
161 /// Only consumed by the `ProfiledGaussian` arm; the `Fixed`-dispersion arm
162 /// already omits the Gaussian normalization constant by design and is not
163 /// affected.
164 pub gaussian_weight_log_sum_half: f64,
165
166 /// Deviance scale `D₀` used as the *relative* reference for the smooth
167 /// penalized-deviance floor (see [`crate::estimate::smooth_floor_dp`]).
168 ///
169 /// Set to the weighted null deviance of the Gaussian response,
170 /// `D₀ = Σ wᵢ(yᵢ − ȳ_w)²`, which is the natural upper reference for
171 /// `D_p` and — crucially — transforms as `D₀ → a²·D₀` under a response
172 /// rescale `y → a·y`, exactly as `D_p` does. Flooring `D_p` at a fixed
173 /// fraction of `D₀` therefore keeps the profiled Gaussian REML criterion
174 /// exactly scale-equivariant (issue #1127); an absolute floor does not.
175 ///
176 /// Only consumed by the `ProfiledGaussian` arm. Defaults to `1.0`, which
177 /// reproduces the historical absolute floor byte-for-byte for every caller
178 /// that does not supply a response scale.
179 pub dp_floor_scale: f64,
180
181 /// How the dispersion parameter is handled.
182 pub dispersion: DispersionHandling,
183
184 // === Extended hyperparameter coordinates (ψ / τ) ===
185 /// External (non-ρ) hyperparameter coordinates with their fixed-β objects.
186 /// These are appended after the ρ coordinates in the gradient/Hessian output.
187 pub ext_coords: Vec<HyperCoord>,
188
189 /// Callback to compute second-order fixed-β objects for a pair (i, j)
190 /// of external coordinates (or external × ρ cross pairs).
191 /// Arguments: (ext_index_i, ext_index_j) → HyperCoordPair.
192 /// When None, the outer Hessian is not computed for extended coordinates.
193 pub ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
194
195 /// Callback for ρ × ext cross pairs: (rho_index, ext_index) → HyperCoordPair.
196 pub rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
197
198 /// M_i[u] = D_β B_i[u] callback for extended coordinates.
199 /// Arguments: (ext_index, direction) → correction matrix.
200 pub fixed_drift_deriv: Option<FixedDriftDerivFn>,
201
202 /// Direction-contracted second-order ψ hook for the profiled θ-HVP (#740).
203 /// When present, the outer-Hessian operator builder skips the `K²` per-pair
204 /// `base_h2` ψψ assembly and instead applies this once per matvec to obtain
205 /// every output row's `tr(K · D²_ψ H_L[ψ_i, ψ(α)])` in a single family row
206 /// pass. `None` keeps the exact per-pair assembly. See
207 /// [`ContractedPsiSecondOrderFn`].
208 pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
209
210 /// Optional log-barrier configuration for monotonicity-constrained coefficients.
211 /// When present, the barrier cost and Hessian corrections are added to the
212 /// outer REML/LAML objective.
213 pub barrier_config: Option<BarrierConfig>,
214
215 /// Optional inner KKT residual `r = ∇_β L_pen(β̂)` at the converged β̂,
216 /// already projected onto the free subspace (see [`ProjectedKktResidual`]
217 /// for the invariant and why the type wraps this). `Some` activates the
218 /// implicit-function-theorem corrections in `reml_laml_evaluate` (cost
219 /// gets `−½ rᵀ H⁻¹ r`, ρ-gradient and ρρ Hessian get the matching first
220 /// and second derivatives of that same scalar correction). `None` keeps
221 /// the envelope-only behaviour for callers that genuinely guarantee
222 /// exact KKT.
223 pub kkt_residual: Option<ProjectedKktResidual>,
224
225 /// Optional active linear-inequality constraints at the converged inner
226 /// iterate. `Some(rows)` means the joint constraint matrix's row indices
227 /// in `rows.active_indices` are pinned (treated as equality constraints
228 /// at the cert point). The unified evaluator combines this with the
229 /// `penalty_subspace_trace` to form the **constraint-aware** kernel
230 /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S` for per-coordinate IFT mode
231 /// responses `v_k = ∂β/∂ρ_k`. See [`ConstrainedSubspaceKernel`] for
232 /// the full derivation and consistency with `log|U_Tᵀ H U_T|`.
233 ///
234 /// `None` is the legacy/unconstrained path (no active inequality
235 /// constraints to project against).
236 pub active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
237
238 /// Fit-level stochastic trace state. Shared by stochastic trace batches so
239 /// CRN probe prefixes stay fixed and matrix-free trace CG can warm-start
240 /// from the previous solve of the same probe id.
241 pub stochastic_trace_state: Arc<Mutex<StochasticTraceState>>,
242}
243
244/// Builder for `InnerSolution` that provides sensible defaults and
245/// auto-computes derived quantities (nullspace_dim).
246pub struct InnerSolutionBuilder<'dp> {
247 // Required fields
248 pub(crate) log_likelihood: f64,
249 pub(crate) penalty_quadratic: f64,
250 pub(crate) hessian_op: Arc<dyn HessianOperator>,
251 pub(crate) beta: Array1<f64>,
252 pub(crate) penalty_coords: Vec<PenaltyCoordinate>,
253 pub(crate) penalty_logdet: PenaltyLogdetDerivs,
254 pub(crate) n_observations: usize,
255 pub(crate) dispersion: DispersionHandling,
256 // Optional fields with defaults
257 pub(crate) deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
258 pub(crate) firth: Option<ExactJeffreysTerm>,
259 pub(crate) hessian_logdet_correction: f64,
260 pub(crate) penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
261 pub(crate) rho_curvature_scale: f64,
262 pub(crate) rho_prior: gam_problem::RhoPrior,
263 pub(crate) nullspace_dim_override: Option<f64>,
264 // Extended hyperparameter coordinates
265 pub(crate) ext_coords: Vec<HyperCoord>,
266 pub(crate) ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
267 pub(crate) rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
268 pub(crate) fixed_drift_deriv: Option<FixedDriftDerivFn>,
269 pub(crate) contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
270 pub(crate) barrier_config: Option<BarrierConfig>,
271 pub(crate) kkt_residual: Option<ProjectedKktResidual>,
272 pub(crate) active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
273 pub(crate) gaussian_weight_log_sum_half: f64,
274 pub(crate) dp_floor_scale: f64,
275}
276
277impl<'dp> InnerSolutionBuilder<'dp> {
278 /// Create a builder with the required core fields.
279 pub fn new(
280 log_likelihood: f64,
281 penalty_quadratic: f64,
282 beta: Array1<f64>,
283 n_observations: usize,
284 hessian_op: Arc<dyn HessianOperator>,
285 penalty_coords: Vec<PenaltyCoordinate>,
286 penalty_logdet: PenaltyLogdetDerivs,
287 dispersion: DispersionHandling,
288 ) -> Self {
289 Self {
290 log_likelihood,
291 penalty_quadratic,
292 hessian_op,
293 beta,
294 penalty_coords,
295 penalty_logdet,
296 n_observations,
297 dispersion,
298 deriv_provider: Box::new(GaussianDerivatives),
299 firth: None,
300 hessian_logdet_correction: 0.0,
301 penalty_subspace_trace: None,
302 rho_curvature_scale: 1.0,
303 rho_prior: gam_problem::RhoPrior::Flat,
304 nullspace_dim_override: None,
305 ext_coords: Vec::new(),
306 ext_coord_pair_fn: None,
307 rho_ext_pair_fn: None,
308 fixed_drift_deriv: None,
309 contracted_psi_second_order: None,
310 barrier_config: None,
311 kkt_residual: None,
312 active_constraints: None,
313 gaussian_weight_log_sum_half: 0.0,
314 dp_floor_scale: 1.0,
315 }
316 }
317
318 pub fn deriv_provider(mut self, p: Box<dyn HessianDerivativeProvider + 'dp>) -> Self {
319 self.deriv_provider = p;
320 self
321 }
322
323 /// Install a pre-built Jeffreys/Firth term (Tier-A operator-backed via
324 /// `ExactJeffreysTerm::new`, or the Tier-B value-only carrier via
325 /// `ExactJeffreysTerm::value_only`).
326 pub fn firth_term(mut self, term: Option<ExactJeffreysTerm>) -> Self {
327 self.firth = term;
328 self
329 }
330
331 pub fn hessian_logdet_correction(mut self, correction: f64) -> Self {
332 self.hessian_logdet_correction = correction;
333 self
334 }
335
336 /// Install the projected-logdet trace kernel that pairs with the
337 /// `hessian_logdet_correction` on a rank-deficient penalty surface.
338 /// See [`PenaltySubspaceTrace`] for the derivation and when it is
339 /// required for gradient consistency.
340 pub fn penalty_subspace_trace(mut self, kernel: Option<Arc<PenaltySubspaceTrace>>) -> Self {
341 self.penalty_subspace_trace = kernel;
342 self
343 }
344
345 pub fn rho_curvature_scale(mut self, scale: f64) -> Self {
346 self.rho_curvature_scale = scale;
347 self
348 }
349
350 pub fn rho_prior(mut self, prior: gam_problem::RhoPrior) -> Self {
351 self.rho_prior = prior;
352 self
353 }
354
355 /// Override the auto-computed nullspace dimension.
356 ///
357 /// By default, `build()` computes nullspace_dim as
358 /// `beta.len() - sum(penalty_coord.rank())`. Use this when the caller
359 /// has a different authoritative value (e.g. from stored per-penalty dims).
360 pub fn nullspace_dim_override(mut self, dim: f64) -> Self {
361 self.nullspace_dim_override = Some(dim);
362 self
363 }
364
365 pub fn ext_coords(mut self, coords: Vec<HyperCoord>) -> Self {
366 self.ext_coords = coords;
367 self
368 }
369
370 pub fn ext_coord_pair_fn(
371 mut self,
372 f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
373 ) -> Self {
374 self.ext_coord_pair_fn = Some(f);
375 self
376 }
377
378 pub fn rho_ext_pair_fn(
379 mut self,
380 f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
381 ) -> Self {
382 self.rho_ext_pair_fn = Some(f);
383 self
384 }
385
386 pub fn fixed_drift_deriv(mut self, f: FixedDriftDerivFn) -> Self {
387 self.fixed_drift_deriv = Some(f);
388 self
389 }
390
391 /// Install the direction-contracted second-order ψ hook (#740). When set,
392 /// the outer-Hessian operator builder uses it instead of the `K²` per-pair
393 /// `base_h2` ψψ assembly. See [`ContractedPsiSecondOrderFn`].
394 pub fn contracted_psi_second_order(mut self, f: Option<ContractedPsiSecondOrderFn>) -> Self {
395 self.contracted_psi_second_order = f;
396 self
397 }
398
399 pub fn barrier_config(mut self, config: Option<BarrierConfig>) -> Self {
400 self.barrier_config = config;
401 self
402 }
403
404 pub fn kkt_residual(mut self, residual: Option<ProjectedKktResidual>) -> Self {
405 self.kkt_residual = residual;
406 self
407 }
408
409 /// Stash the active linear-inequality constraint block carried alongside the
410 /// inner solution. Used by `PenaltySubspaceTrace::with_active_constraints`
411 /// at REML/LAML evaluation time to form the constraint-aware kernel
412 /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S`.
413 pub fn active_constraints(mut self, block: Option<Arc<ActiveLinearConstraintBlock>>) -> Self {
414 self.active_constraints = block;
415 self
416 }
417
418 /// Build the `InnerSolution`, auto-computing nullspace_dim from penalty coordinates.
419 pub fn build(self) -> InnerSolution<'dp> {
420 let beta_dim = self.beta.len();
421 let penalty_dim = self.penalty_coords.len();
422 assert_eq!(
423 self.hessian_op.dim(),
424 beta_dim,
425 "InnerSolutionBuilder: Hessian dimension {} does not match beta length {}",
426 self.hessian_op.dim(),
427 beta_dim
428 );
429 for (idx, coord) in self.penalty_coords.iter().enumerate() {
430 assert_eq!(
431 coord.dim(),
432 beta_dim,
433 "InnerSolutionBuilder: penalty coordinate {idx} has dimension {} but beta length is {}",
434 coord.dim(),
435 beta_dim
436 );
437 }
438 assert_eq!(
439 self.penalty_logdet.first.len(),
440 penalty_dim,
441 "InnerSolutionBuilder: penalty logdet first-derivative length {} does not match penalty coordinate count {}",
442 self.penalty_logdet.first.len(),
443 penalty_dim
444 );
445 if let Some(second) = self.penalty_logdet.second.as_ref() {
446 assert!(
447 second.nrows() == penalty_dim && second.ncols() == penalty_dim,
448 "InnerSolutionBuilder: penalty logdet Hessian shape {}x{} does not match penalty coordinate count {}",
449 second.nrows(),
450 second.ncols(),
451 penalty_dim
452 );
453 }
454 if let Some(barrier_config) = self.barrier_config.as_ref() {
455 assert_eq!(
456 barrier_config.constrained_indices.len(),
457 barrier_config.lower_bounds.len(),
458 "InnerSolutionBuilder: barrier constrained index count {} does not match lower-bound count {}",
459 barrier_config.constrained_indices.len(),
460 barrier_config.lower_bounds.len()
461 );
462 assert_eq!(
463 barrier_config.constrained_indices.len(),
464 barrier_config.bound_signs.len(),
465 "InnerSolutionBuilder: barrier constrained index count {} does not match bound-direction count {}",
466 barrier_config.constrained_indices.len(),
467 barrier_config.bound_signs.len()
468 );
469 assert!(
470 barrier_config.tau.is_finite() && barrier_config.tau >= 0.0,
471 "InnerSolutionBuilder: barrier tau must be finite and non-negative, got {}",
472 barrier_config.tau
473 );
474 for ((&idx, &lower_bound), &sign) in barrier_config
475 .constrained_indices
476 .iter()
477 .zip(barrier_config.lower_bounds.iter())
478 .zip(barrier_config.bound_signs.iter())
479 {
480 assert!(
481 idx < beta_dim,
482 "InnerSolutionBuilder: barrier constrained index {idx} out of bounds for beta length {beta_dim}"
483 );
484 assert!(
485 lower_bound.is_finite(),
486 "InnerSolutionBuilder: barrier lower bound for beta[{idx}] must be finite, got {lower_bound}"
487 );
488 assert!(
489 sign == 1.0 || sign == -1.0,
490 "InnerSolutionBuilder: barrier bound direction for beta[{idx}] must be ±1, got {sign}"
491 );
492 }
493 }
494 if let Some(active_constraints) = self.active_constraints.as_ref() {
495 assert_eq!(
496 active_constraints.a.ncols(),
497 beta_dim,
498 "InnerSolutionBuilder: active constraint width {} does not match beta length {}",
499 active_constraints.a.ncols(),
500 beta_dim
501 );
502 }
503 let nullspace_dim = self.nullspace_dim_override.unwrap_or_else(|| {
504 let penalty_rank: usize = self
505 .penalty_coords
506 .iter()
507 .map(PenaltyCoordinate::rank)
508 .sum();
509 beta_dim.saturating_sub(penalty_rank) as f64
510 });
511
512 InnerSolution {
513 log_likelihood: self.log_likelihood,
514 penalty_quadratic: self.penalty_quadratic,
515 hessian_op: self.hessian_op,
516 beta: self.beta,
517 penalty_coords: self.penalty_coords,
518 penalty_logdet: self.penalty_logdet,
519 deriv_provider: self.deriv_provider,
520 firth: self.firth,
521 hessian_logdet_correction: self.hessian_logdet_correction,
522 penalty_subspace_trace: self.penalty_subspace_trace,
523 rho_curvature_scale: self.rho_curvature_scale,
524 rho_prior: self.rho_prior,
525 n_observations: self.n_observations,
526 nullspace_dim,
527 gaussian_weight_log_sum_half: self.gaussian_weight_log_sum_half,
528 dp_floor_scale: self.dp_floor_scale,
529 dispersion: self.dispersion,
530 ext_coords: self.ext_coords,
531 ext_coord_pair_fn: self.ext_coord_pair_fn,
532 rho_ext_pair_fn: self.rho_ext_pair_fn,
533 fixed_drift_deriv: self.fixed_drift_deriv,
534 contracted_psi_second_order: self.contracted_psi_second_order,
535 barrier_config: self.barrier_config,
536 kkt_residual: self.kkt_residual,
537 active_constraints: self.active_constraints,
538 stochastic_trace_state: Arc::new(Mutex::new(StochasticTraceState::default())),
539 }
540 }
541}