gam_solve/reml/reml_outer_engine/inner_solution.rs
1use super::*;
2
/// Selects how the dispersion parameter φ enters the outer objective:
/// profiled out of it (Gaussian REML) or held at a fixed value
/// (non-Gaussian LAML / maximum penalized likelihood).
///
/// `PartialEq` is derived so callers and tests can compare modes directly
/// (`==`, `assert_eq!`). `Eq` is intentionally not derived because `phi` is
/// an `f64`; as with any float comparison, a `Fixed` value whose `phi` is NaN
/// never compares equal, not even to itself.
#[derive(Clone, Debug, PartialEq)]
pub enum DispersionHandling {
    /// Gaussian REML: φ̂ = D_p / (n − M_p), profiled out of the objective.
    /// The cost includes (n−M_p)/2 · log(2πφ̂) and the gradient includes
    /// the profiled scale derivative. Always includes both logdet terms.
    ProfiledGaussian,
    /// Non-Gaussian LAML or maximum penalized likelihood.
    ///
    /// Standard LAML: `Fixed { phi: 1.0, include_logdet_h: true, include_logdet_s: true }`
    /// MaxPenalizedLikelihood: `Fixed { phi: 1.0, include_logdet_h: false, include_logdet_s: false }`
    Fixed {
        /// The fixed dispersion value φ.
        phi: f64,
        /// Whether ½ log|H| is included (true for full LAML, false for
        /// MPL/PQL).
        include_logdet_h: bool,
        /// Whether −½ log|S|₊ is included.
        include_logdet_s: bool,
    },
}
25
26/// The unified inner solution produced by any inner solver.
27///
28/// Contains everything the outer REML/LAML evaluator needs. Produced by:
29/// - Single-block PIRLS (via `PirlsResult::into_inner_solution()`)
30/// - Blockwise coupled Newton (via `BlockwiseInnerResult::into_inner_solution()`)
31/// - Sparse Cholesky (via `SparsePenalizedSystem::into_inner_solution()`)
32pub struct InnerSolution<'dp> {
33 // === Objective ingredients ===
34 /// ℓ(β̂) — log-likelihood at the converged mode.
35 /// For Gaussian: −0.5 × deviance (RSS). For GLMs: actual log-likelihood.
36 pub log_likelihood: f64,
37
38 /// β̂ᵀS(ρ)β̂ — penalty quadratic form at the mode.
39 pub penalty_quadratic: f64,
40
41 // === The factorization (single source of truth for all linear algebra) ===
42 /// The Hessian operator providing logdet, trace, and solve.
43 /// Both cost and gradient use this same object.
44 ///
45 /// IMPORTANT: This MUST encode the **observed** Hessian H_obs = X'W_obs X + S
46 /// at the converged mode, where W_obs includes the residual-dependent correction
47 /// for non-canonical links. Using expected Fisher H_Fisher = X'W_Fisher X + S
48 /// would make this a PQL surrogate rather than the exact Laplace approximation.
49 /// See response.md Section 3 for the mathematical justification.
50 pub hessian_op: Arc<dyn HessianOperator>,
51
52 // === Coefficients and penalty structure ===
53 /// β̂ — coefficients at the converged mode (in the operator's native basis).
54 pub beta: Array1<f64>,
55
56 /// Penalty coordinates for the rho block.
57 ///
58 /// Each coordinate represents one smoothing-parameter direction
59 /// A_k = λ_k S_k
60 /// through either a full-root or a block-local root.
61 pub penalty_coords: Vec<PenaltyCoordinate>,
62
63 /// Derivatives of log|S(ρ)|₊ — precomputed from penalty structure.
64 pub penalty_logdet: PenaltyLogdetDerivs,
65
66 // === Family-specific derivative info ===
67 /// Provider of third-derivative corrections for non-Gaussian families.
68 ///
69 /// The c and d arrays (dW/deta, d^2W/deta^2) carried by this provider MUST
70 /// be the **observed** derivatives, not the Fisher derivatives. For non-canonical
71 /// links the observed c/d include residual-dependent corrections:
72 /// c_obs = c_Fisher + h'*B - (y-mu)*B_eta
73 /// d_obs = d_Fisher + h''*B + 2*h'*B_eta - (y-mu)*B_etaeta
74 /// These corrections matter for the outer gradient (C[v] correction) and
75 /// outer Hessian (Q[v_k, v_l] correction). See response.md Section 3.
76 pub deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
77
78 // === Corrections ===
79 /// Optional exact Jeffreys/Firth term in the active coefficient basis.
80 pub firth: Option<ExactJeffreysTerm>,
81
82 /// Additive correction for the Hessian logdet when `hessian_op` encodes a
83 /// uniformly rescaled exact curvature matrix.
84 pub hessian_logdet_correction: f64,
85
86 /// When the cost uses `log|U_Sᵀ H U_S|_+` (rank-deficient LAML fix),
87 /// this carries the matching projected kernel so the gradient trace
88 /// `tr(K · Ḣ)` agrees with the cost's derivative. See
89 /// [`PenaltySubspaceTrace`] for the full derivation.
90 pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
91
92 /// Uniform scale `s` applied to rho-coordinate penalty derivatives in the
93 /// H-dependent trace / solve parts of the outer calculus.
94 ///
95 /// ## Contract (CRITICAL — gradient/cost consistency)
96 ///
97 /// `rho_curvature_scale` is NOT a free knob. It encodes the convention
98 /// that the supplied `hessian_op` represents the **rescaled** curvature
99 /// `H_op = s · (∇²(-ℓ) + Σ_k e^{ρ_k} S_k)`, i.e. every contribution to
100 /// the curvature (likelihood Hessian AND penalty `λ_k S_k`) has been
101 /// uniformly multiplied by `s` before reaching the evaluator. Under this
102 /// convention:
103 ///
104 /// * `∂H_op/∂ρ_k = s · λ_k S_k` (matches the `curvature_lambdas = s · λ`
105 /// drift used inside the gradient's trace term),
106 /// * `K = H_op⁻¹ = (1/s) · (∇²(-ℓ) + λS)⁻¹`,
107 /// * `tr(K · ∂H_op/∂ρ_k) = tr((∇²(-ℓ) + λS)⁻¹ · λ_k S_k)` (the analytic
108 /// gradient of the **unscaled** `log|H|`),
109 /// * `log|H_op| = log|∇²(-ℓ) + λS| + p · log(s)`, which the caller MUST
110 /// un-scale by supplying `hessian_logdet_correction += −p · log(s)` so
111 /// that `hop.logdet() + hessian_logdet_correction` evaluates the same
112 /// unscaled `log|H|` whose derivative the gradient trace computes.
113 ///
114 /// Callers that set `rho_curvature_scale ≠ 1` without ALSO pre-scaling
115 /// `hessian_op` AND adding the matching `−p·log(s)` term to
116 /// `hessian_logdet_correction` will get a gradient that is off by the
117 /// factor `s` from `dV/dρ_k`. The unified evaluator does **not** scale
118 /// `hop` for the caller — that would defeat the purpose of the
119 /// curvature-conditioning trick survival families use to keep the
120 /// outer eigendecomposition numerically stable.
121 ///
122 /// See `survival::location_scale::exact_newton_outer_curvature` for the
123 /// canonical example: `rho_curvature_scale = exp(-log_scale)` paired with
124 /// `hessian_logdet_correction = p · log_scale = −p · log(scale)`.
125 ///
126 /// The evaluator enforces `rho_curvature_scale > 0` and finite; pass
127 /// `1.0` (the documented default) when no curvature conditioning is in
128 /// play.
129 pub rho_curvature_scale: f64,
130
131 /// Configured prior over rho coordinates. The evaluator receives the
132 /// realized cost/gradient tuple separately; this copy lets EFS use the
133 /// conjugate Gamma rate in its multiplicative denominator.
134 pub rho_prior: gam_problem::RhoPrior,
135
136 // === Model dimensions ===
137 /// Number of observations.
138 pub n_observations: usize,
139
140 /// M_p: dimension of the penalty null space (unpenalized coefficients).
141 pub nullspace_dim: f64,
142
143 /// ½·Σᵢ log(wᵢ) — half the sum of log prior weights.
144 ///
145 /// This is the per-observation Gaussian normalization constant that the
146 /// `log_likelihood` (computed by
147 /// [`calculate_loglikelihood_omitting_constants`]) deliberately drops. The
148 /// full weighted-Gaussian negative log-likelihood normalization is
149 /// ½·Σᵢ log(2π·φ/wᵢ) = (n/2)·log(2πφ) − ½·Σᵢ log(wᵢ),
150 /// because `Var(yᵢ) = φ/wᵢ` under inverse-variance prior weights.
151 ///
152 /// Dropping `−½·Σ log(wᵢ)` does not move the ρ-argmin in exact arithmetic
153 /// (it is constant in ρ), but it makes the ProfiledGaussian objective VALUE
154 /// scale-dependent: under a global rescale `w → c·w` the invariance-
155 /// preserving smoothing `λ → c·λ` leaves the cost SHAPE fixed but inflates
156 /// its absolute value by `(n/2)·log c`. That inflation breaks the exact
157 /// weight-scale invariance of the selected λ̂ / EDF / fit (issue #877).
158 /// Restoring this term makes the ProfiledGaussian cost value exactly
159 /// invariant to `w → c·w` (with σ̂² absorbing the c factor), matching mgcv.
160 ///
161 /// Only consumed by the `ProfiledGaussian` arm; the `Fixed`-dispersion arm
162 /// already omits the Gaussian normalization constant by design and is not
163 /// affected.
164 pub gaussian_weight_log_sum_half: f64,
165
166 /// Deviance scale `D₀` used as the *relative* reference for the smooth
167 /// penalized-deviance floor (see [`crate::estimate::smooth_floor_dp`]).
168 ///
169 /// Set to the weighted null deviance of the Gaussian response,
170 /// `D₀ = Σ wᵢ(yᵢ − ȳ_w)²`, which is the natural upper reference for
171 /// `D_p` and — crucially — transforms as `D₀ → a²·D₀` under a response
172 /// rescale `y → a·y`, exactly as `D_p` does. Flooring `D_p` at a fixed
173 /// fraction of `D₀` therefore keeps the profiled Gaussian REML criterion
174 /// exactly scale-equivariant (issue #1127); an absolute floor does not.
175 ///
176 /// Only consumed by the `ProfiledGaussian` arm. Defaults to `1.0`, which
177 /// reproduces the historical absolute floor byte-for-byte for every caller
178 /// that does not supply a response scale.
179 pub dp_floor_scale: f64,
180
181 /// How the dispersion parameter is handled.
182 pub dispersion: DispersionHandling,
183
184 // === Extended hyperparameter coordinates (ψ / τ) ===
185 /// External (non-ρ) hyperparameter coordinates with their fixed-β objects.
186 /// These are appended after the ρ coordinates in the gradient/Hessian output.
187 pub ext_coords: Vec<HyperCoord>,
188
189 /// Callback to compute second-order fixed-β objects for a pair (i, j)
190 /// of external coordinates (or external × ρ cross pairs).
191 /// Arguments: (ext_index_i, ext_index_j) → HyperCoordPair.
192 /// When None, the outer Hessian is not computed for extended coordinates.
193 ///
194 /// `Arc`-backed ([`HyperCoordPairFn`]) so a derived solution — notably the
195 /// tangent-projected solution built under active inequality constraints —
196 /// can clone the same callback through to `ValueGradientHessian` assembly.
197 pub ext_coord_pair_fn: Option<HyperCoordPairFn>,
198
199 /// Callback for ρ × ext cross pairs: (rho_index, ext_index) → HyperCoordPair.
200 pub rho_ext_pair_fn: Option<HyperCoordPairFn>,
201
202 /// M_i[u] = D_β B_i[u] callback for extended coordinates.
203 /// Arguments: (ext_index, direction) → correction matrix.
204 ///
205 /// `Arc`-backed ([`SharedFixedDriftDerivFn`]) so the tangent-projected
206 /// solution can clone it through — the drift `M` is a p-space matrix the
207 /// wrapper projects via `ZᵀMZ` in `trace_logdet_*`.
208 pub fixed_drift_deriv: Option<SharedFixedDriftDerivFn>,
209
210 /// Direction-contracted second-order ψ hook for the profiled θ-HVP (#740).
211 /// When present, the outer-Hessian operator builder skips the `K²` per-pair
212 /// `base_h2` ψψ assembly and instead applies this once per matvec to obtain
213 /// every output row's `tr(K · D²_ψ H_L[ψ_i, ψ(α)])` in a single family row
214 /// pass. `None` keeps the exact per-pair assembly. See
215 /// [`ContractedPsiSecondOrderFn`].
216 pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
217
218 /// Optional log-barrier configuration for monotonicity-constrained coefficients.
219 /// When present, the barrier cost and Hessian corrections are added to the
220 /// outer REML/LAML objective.
221 pub barrier_config: Option<BarrierConfig>,
222
223 /// Optional inner KKT residual `r = ∇_β L_pen(β̂)` at the converged β̂,
224 /// already projected onto the free subspace (see [`ProjectedKktResidual`]
225 /// for the invariant and why the type wraps this). `Some` activates the
226 /// implicit-function-theorem corrections in `reml_laml_evaluate` (cost
227 /// gets `−½ rᵀ H⁻¹ r`, ρ-gradient and ρρ Hessian get the matching first
228 /// and second derivatives of that same scalar correction). `None` keeps
229 /// the envelope-only behaviour for callers that genuinely guarantee
230 /// exact KKT.
231 pub kkt_residual: Option<ProjectedKktResidual>,
232
233 /// Optional active linear-inequality constraints at the converged inner
234 /// iterate. `Some(rows)` means the joint constraint matrix's row indices
235 /// in `rows.active_indices` are pinned (treated as equality constraints
236 /// at the cert point). The unified evaluator combines this with the
237 /// `penalty_subspace_trace` to form the **constraint-aware** kernel
238 /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S` for per-coordinate IFT mode
239 /// responses `v_k = ∂β/∂ρ_k`. See [`ConstrainedSubspaceKernel`] for
240 /// the full derivation and consistency with `log|U_Tᵀ H U_T|`.
241 ///
242 /// `None` is the legacy/unconstrained path (no active inequality
243 /// constraints to project against).
244 pub active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
245
246 /// Fit-level stochastic trace state. Shared by stochastic trace batches so
247 /// CRN probe prefixes stay fixed and matrix-free trace CG can warm-start
248 /// from the previous solve of the same probe id.
249 pub stochastic_trace_state: Arc<Mutex<StochasticTraceState>>,
250}
251
252/// Builder for `InnerSolution` that provides sensible defaults and
253/// auto-computes derived quantities (nullspace_dim).
254pub struct InnerSolutionBuilder<'dp> {
255 // Required fields
256 pub(crate) log_likelihood: f64,
257 pub(crate) penalty_quadratic: f64,
258 pub(crate) hessian_op: Arc<dyn HessianOperator>,
259 pub(crate) beta: Array1<f64>,
260 pub(crate) penalty_coords: Vec<PenaltyCoordinate>,
261 pub(crate) penalty_logdet: PenaltyLogdetDerivs,
262 pub(crate) n_observations: usize,
263 pub(crate) dispersion: DispersionHandling,
264 // Optional fields with defaults
265 pub(crate) deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
266 pub(crate) firth: Option<ExactJeffreysTerm>,
267 pub(crate) hessian_logdet_correction: f64,
268 pub(crate) penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
269 pub(crate) rho_curvature_scale: f64,
270 pub(crate) rho_prior: gam_problem::RhoPrior,
271 pub(crate) nullspace_dim_override: Option<f64>,
272 // Extended hyperparameter coordinates
273 pub(crate) ext_coords: Vec<HyperCoord>,
274 pub(crate) ext_coord_pair_fn: Option<HyperCoordPairFn>,
275 pub(crate) rho_ext_pair_fn: Option<HyperCoordPairFn>,
276 pub(crate) fixed_drift_deriv: Option<SharedFixedDriftDerivFn>,
277 pub(crate) contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
278 pub(crate) barrier_config: Option<BarrierConfig>,
279 pub(crate) kkt_residual: Option<ProjectedKktResidual>,
280 pub(crate) active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
281 pub(crate) gaussian_weight_log_sum_half: f64,
282 pub(crate) dp_floor_scale: f64,
283}
284
285impl<'dp> InnerSolutionBuilder<'dp> {
286 /// Create a builder with the required core fields.
287 pub fn new(
288 log_likelihood: f64,
289 penalty_quadratic: f64,
290 beta: Array1<f64>,
291 n_observations: usize,
292 hessian_op: Arc<dyn HessianOperator>,
293 penalty_coords: Vec<PenaltyCoordinate>,
294 penalty_logdet: PenaltyLogdetDerivs,
295 dispersion: DispersionHandling,
296 ) -> Self {
297 Self {
298 log_likelihood,
299 penalty_quadratic,
300 hessian_op,
301 beta,
302 penalty_coords,
303 penalty_logdet,
304 n_observations,
305 dispersion,
306 deriv_provider: Box::new(GaussianDerivatives),
307 firth: None,
308 hessian_logdet_correction: 0.0,
309 penalty_subspace_trace: None,
310 rho_curvature_scale: 1.0,
311 rho_prior: gam_problem::RhoPrior::Flat,
312 nullspace_dim_override: None,
313 ext_coords: Vec::new(),
314 ext_coord_pair_fn: None,
315 rho_ext_pair_fn: None,
316 fixed_drift_deriv: None,
317 contracted_psi_second_order: None,
318 barrier_config: None,
319 kkt_residual: None,
320 active_constraints: None,
321 gaussian_weight_log_sum_half: 0.0,
322 dp_floor_scale: 1.0,
323 }
324 }
325
326 pub fn deriv_provider(mut self, p: Box<dyn HessianDerivativeProvider + 'dp>) -> Self {
327 self.deriv_provider = p;
328 self
329 }
330
331 /// Install a pre-built Jeffreys/Firth term (Tier-A operator-backed via
332 /// `ExactJeffreysTerm::new`, or the Tier-B value-only carrier via
333 /// `ExactJeffreysTerm::value_only`).
334 pub fn firth_term(mut self, term: Option<ExactJeffreysTerm>) -> Self {
335 self.firth = term;
336 self
337 }
338
339 pub fn hessian_logdet_correction(mut self, correction: f64) -> Self {
340 self.hessian_logdet_correction = correction;
341 self
342 }
343
344 /// Install the projected-logdet trace kernel that pairs with the
345 /// `hessian_logdet_correction` on a rank-deficient penalty surface.
346 /// See [`PenaltySubspaceTrace`] for the derivation and when it is
347 /// required for gradient consistency.
348 pub fn penalty_subspace_trace(mut self, kernel: Option<Arc<PenaltySubspaceTrace>>) -> Self {
349 self.penalty_subspace_trace = kernel;
350 self
351 }
352
353 pub fn rho_curvature_scale(mut self, scale: f64) -> Self {
354 self.rho_curvature_scale = scale;
355 self
356 }
357
358 pub fn rho_prior(mut self, prior: gam_problem::RhoPrior) -> Self {
359 self.rho_prior = prior;
360 self
361 }
362
363 /// Override the auto-computed nullspace dimension.
364 ///
365 /// By default, `build()` computes nullspace_dim as
366 /// `beta.len() - sum(penalty_coord.rank())`. Use this when the caller
367 /// has a different authoritative value (e.g. from stored per-penalty dims).
368 pub fn nullspace_dim_override(mut self, dim: f64) -> Self {
369 self.nullspace_dim_override = Some(dim);
370 self
371 }
372
373 pub fn ext_coords(mut self, coords: Vec<HyperCoord>) -> Self {
374 self.ext_coords = coords;
375 self
376 }
377
378 pub fn ext_coord_pair_fn(
379 mut self,
380 f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
381 ) -> Self {
382 // `Arc::from(Box<dyn …>)` is a zero-cost re-tag of the existing
383 // allocation; storing it as `Arc` lets the projected solution clone
384 // the callback through.
385 self.ext_coord_pair_fn = Some(HyperCoordPairFn::from(f));
386 self
387 }
388
389 pub fn rho_ext_pair_fn(
390 mut self,
391 f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
392 ) -> Self {
393 self.rho_ext_pair_fn = Some(HyperCoordPairFn::from(f));
394 self
395 }
396
397 pub fn fixed_drift_deriv(mut self, f: FixedDriftDerivFn) -> Self {
398 // `Arc::from(Box<dyn …>)` re-tags the existing allocation for shared
399 // ownership so the projected solution can clone the callback through.
400 self.fixed_drift_deriv = Some(SharedFixedDriftDerivFn::from(f));
401 self
402 }
403
404 /// Install the direction-contracted second-order ψ hook (#740). When set,
405 /// the outer-Hessian operator builder uses it instead of the `K²` per-pair
406 /// `base_h2` ψψ assembly. See [`ContractedPsiSecondOrderFn`].
407 pub fn contracted_psi_second_order(mut self, f: Option<ContractedPsiSecondOrderFn>) -> Self {
408 self.contracted_psi_second_order = f;
409 self
410 }
411
412 pub fn barrier_config(mut self, config: Option<BarrierConfig>) -> Self {
413 self.barrier_config = config;
414 self
415 }
416
417 pub fn kkt_residual(mut self, residual: Option<ProjectedKktResidual>) -> Self {
418 self.kkt_residual = residual;
419 self
420 }
421
422 /// Stash the active linear-inequality constraint block carried alongside the
423 /// inner solution. Used by `PenaltySubspaceTrace::with_active_constraints`
424 /// at REML/LAML evaluation time to form the constraint-aware kernel
425 /// `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S`.
426 pub fn active_constraints(mut self, block: Option<Arc<ActiveLinearConstraintBlock>>) -> Self {
427 self.active_constraints = block;
428 self
429 }
430
431 /// Build the `InnerSolution`, auto-computing nullspace_dim from penalty coordinates.
432 pub fn build(self) -> InnerSolution<'dp> {
433 let beta_dim = self.beta.len();
434 let penalty_dim = self.penalty_coords.len();
435 assert_eq!(
436 self.hessian_op.dim(),
437 beta_dim,
438 "InnerSolutionBuilder: Hessian dimension {} does not match beta length {}",
439 self.hessian_op.dim(),
440 beta_dim
441 );
442 for (idx, coord) in self.penalty_coords.iter().enumerate() {
443 assert_eq!(
444 coord.dim(),
445 beta_dim,
446 "InnerSolutionBuilder: penalty coordinate {idx} has dimension {} but beta length is {}",
447 coord.dim(),
448 beta_dim
449 );
450 }
451 assert_eq!(
452 self.penalty_logdet.first.len(),
453 penalty_dim,
454 "InnerSolutionBuilder: penalty logdet first-derivative length {} does not match penalty coordinate count {}",
455 self.penalty_logdet.first.len(),
456 penalty_dim
457 );
458 if let Some(second) = self.penalty_logdet.second.as_ref() {
459 assert!(
460 second.nrows() == penalty_dim && second.ncols() == penalty_dim,
461 "InnerSolutionBuilder: penalty logdet Hessian shape {}x{} does not match penalty coordinate count {}",
462 second.nrows(),
463 second.ncols(),
464 penalty_dim
465 );
466 }
467 if let Some(barrier_config) = self.barrier_config.as_ref() {
468 assert_eq!(
469 barrier_config.constrained_indices.len(),
470 barrier_config.lower_bounds.len(),
471 "InnerSolutionBuilder: barrier constrained index count {} does not match lower-bound count {}",
472 barrier_config.constrained_indices.len(),
473 barrier_config.lower_bounds.len()
474 );
475 assert_eq!(
476 barrier_config.constrained_indices.len(),
477 barrier_config.bound_signs.len(),
478 "InnerSolutionBuilder: barrier constrained index count {} does not match bound-direction count {}",
479 barrier_config.constrained_indices.len(),
480 barrier_config.bound_signs.len()
481 );
482 assert!(
483 barrier_config.tau.is_finite() && barrier_config.tau >= 0.0,
484 "InnerSolutionBuilder: barrier tau must be finite and non-negative, got {}",
485 barrier_config.tau
486 );
487 for ((&idx, &lower_bound), &sign) in barrier_config
488 .constrained_indices
489 .iter()
490 .zip(barrier_config.lower_bounds.iter())
491 .zip(barrier_config.bound_signs.iter())
492 {
493 assert!(
494 idx < beta_dim,
495 "InnerSolutionBuilder: barrier constrained index {idx} out of bounds for beta length {beta_dim}"
496 );
497 assert!(
498 lower_bound.is_finite(),
499 "InnerSolutionBuilder: barrier lower bound for beta[{idx}] must be finite, got {lower_bound}"
500 );
501 assert!(
502 sign == 1.0 || sign == -1.0,
503 "InnerSolutionBuilder: barrier bound direction for beta[{idx}] must be ±1, got {sign}"
504 );
505 }
506 }
507 if let Some(active_constraints) = self.active_constraints.as_ref() {
508 assert_eq!(
509 active_constraints.a.ncols(),
510 beta_dim,
511 "InnerSolutionBuilder: active constraint width {} does not match beta length {}",
512 active_constraints.a.ncols(),
513 beta_dim
514 );
515 }
516 let nullspace_dim = self.nullspace_dim_override.unwrap_or_else(|| {
517 let penalty_rank: usize = self
518 .penalty_coords
519 .iter()
520 .map(PenaltyCoordinate::rank)
521 .sum();
522 beta_dim.saturating_sub(penalty_rank) as f64
523 });
524
525 InnerSolution {
526 log_likelihood: self.log_likelihood,
527 penalty_quadratic: self.penalty_quadratic,
528 hessian_op: self.hessian_op,
529 beta: self.beta,
530 penalty_coords: self.penalty_coords,
531 penalty_logdet: self.penalty_logdet,
532 deriv_provider: self.deriv_provider,
533 firth: self.firth,
534 hessian_logdet_correction: self.hessian_logdet_correction,
535 penalty_subspace_trace: self.penalty_subspace_trace,
536 rho_curvature_scale: self.rho_curvature_scale,
537 rho_prior: self.rho_prior,
538 n_observations: self.n_observations,
539 nullspace_dim,
540 gaussian_weight_log_sum_half: self.gaussian_weight_log_sum_half,
541 dp_floor_scale: self.dp_floor_scale,
542 dispersion: self.dispersion,
543 ext_coords: self.ext_coords,
544 ext_coord_pair_fn: self.ext_coord_pair_fn,
545 rho_ext_pair_fn: self.rho_ext_pair_fn,
546 fixed_drift_deriv: self.fixed_drift_deriv,
547 contracted_psi_second_order: self.contracted_psi_second_order,
548 barrier_config: self.barrier_config,
549 kkt_residual: self.kkt_residual,
550 active_constraints: self.active_constraints,
551 stochastic_trace_state: Arc::new(Mutex::new(StochasticTraceState::default())),
552 }
553 }
554}