// gam_solve/estimate/evaluation.rs

use std::f64::consts::LN_2;

use super::*;

/// Weak, fixed stabilization weight for the SAS tail parameter (log-delta).
///
/// Used during outer optimization to avoid boundary and flat-region
/// pathologies of the tail parameter.
pub(crate) fn sas_log_deltaridgeweight() -> f64 {
    const RIDGE_WEIGHT: f64 = 1e-4;
    RIDGE_WEIGHT
}

/// Weight of the edge barrier that keeps the raw SAS log-delta away from the
/// `tanh` saturation edges.
///
/// Near those edges the link sensitivities collapse and outer gradients stop
/// carrying useful information.
#[inline]
pub(crate) fn sas_log_delta_edge_barrierweight() -> f64 {
    const EDGE_BARRIER_WEIGHT: f64 = 1e-2;
    EDGE_BARRIER_WEIGHT
}

16#[inline]
17pub(crate) fn sas_log_delta_bound() -> f64 {
18    crate::mixture_link::SAS_LOG_DELTA_BOUND
19}
20
21#[inline]
22pub(crate) fn sas_log_delta_edge_barriercostgrad(raw_log_delta: f64) -> (f64, f64) {
23    let w = sas_log_delta_edge_barrierweight();
24    if w <= 0.0 || !raw_log_delta.is_finite() {
25        return (0.0, 0.0);
26    }
27    let b = sas_log_delta_bound().max(f64::EPSILON);
28    let t = (raw_log_delta / b).tanh();
29    let one_minus_t2 = (1.0 - t * t).max(1e-12);
30    let cost = -w * one_minus_t2.ln();
31    // d/draw[-w log(1-t^2)] = (2w/B) * t.
32    let grad = (2.0 * w / b) * t;
33    (cost, grad)
34}
35
36#[inline]
37pub(crate) fn sas_epsilon_bound() -> f64 {
38    // Fixed smooth bound on raw SAS epsilon during outer optimization.
39    8.0
40}
41
42#[inline]
43pub(crate) fn sas_effective_epsilon(raw_epsilon: f64) -> (f64, f64) {
44    let bound = sas_epsilon_bound().max(f64::EPSILON);
45    let t = (raw_epsilon / bound).tanh();
46    let epsilon = bound * t;
47    let d_epsilon_d_raw = 1.0 - t * t;
48    (epsilon, d_epsilon_d_raw)
49}
50
51#[inline]
52pub(crate) fn sas_effective_epsilon_second(raw_epsilon: f64) -> (f64, f64, f64) {
53    let bound = sas_epsilon_bound().max(f64::EPSILON);
54    let t = (raw_epsilon / bound).tanh();
55    let first = 1.0 - t * t;
56    let second = -2.0 * t * first / bound;
57    (bound * t, first, second)
58}
59
60#[inline]
61pub(crate) fn sas_log_delta_edge_barriercostgradhess(raw_log_delta: f64) -> (f64, f64, f64) {
62    let w = sas_log_delta_edge_barrierweight();
63    if w <= 0.0 || !raw_log_delta.is_finite() {
64        return (0.0, 0.0, 0.0);
65    }
66    let b = sas_log_delta_bound().max(f64::EPSILON);
67    let t = (raw_log_delta / b).tanh();
68    let one_minus_t2 = (1.0 - t * t).max(1e-12);
69    let cost = -w * one_minus_t2.ln();
70    let grad = (2.0 * w / b) * t;
71    let hess = (2.0 * w / (b * b)) * one_minus_t2;
72    (cost, grad, hess)
73}
74
75pub(crate) fn materialize_link_outer_hessian(
76    hessian: gam_problem::HessianResult,
77    theta_dim: usize,
78) -> Result<Array2<f64>, EstimationError> {
79    match hessian.materialize_dense() {
80        Ok(Some(h)) => {
81            if h.nrows() != theta_dim || h.ncols() != theta_dim {
82                crate::bail_invalid_estim!(
83                    "unified evaluator Hessian shape {}x{} != theta_dim {}",
84                    h.nrows(),
85                    h.ncols(),
86                    theta_dim
87                );
88            }
89            Ok(h)
90        }
91        Ok(None) => Err(EstimationError::InvalidInput(
92            "unified evaluator returned no analytic Hessian in ValueGradientHessian mode"
93                .to_string(),
94        )),
95        Err(err) => Err(EstimationError::InvalidInput(format!(
96            "failed to materialize analytic link Hessian: {err}"
97        ))),
98    }
99}
100
101/// Evaluate the analytic gradient of the external REML objective.
102pub fn evaluate_externalgradient<X>(
103    y: ArrayView1<'_, f64>,
104    w: ArrayView1<'_, f64>,
105    x: X,
106    offset: ArrayView1<'_, f64>,
107    s_list: &[BlockwisePenalty],
108    opts: &ExternalOptimOptions,
109    rho: &Array1<f64>,
110) -> Result<Array1<f64>, EstimationError>
111where
112    X: Into<DesignMatrix>,
113{
114    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
115    let x = x.into();
116    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
117        crate::bail_invalid_estim!("{}", message);
118    }
119
120    let p = x.ncols();
121    validate_penalty_specs(&specs, p, "evaluate_externalgradient")?;
122    let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
123        &specs,
124        &opts.nullspace_dims,
125        p,
126        "evaluate_externalgradient",
127    )?;
128    if rho.len() != active_nullspace_dims.len() {
129        crate::bail_invalid_estim!(
130            "rho dimension mismatch: rho_dim={}, active_penalties={}",
131            rho.len(),
132            active_nullspace_dims.len()
133        );
134    }
135
136    let (cfg, _) = resolved_external_config(opts)?;
137
138    let y_o = y.to_owned();
139    let w_o = w.to_owned();
140    let offset_o = offset.to_owned();
141    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
142    let x_fit = conditioning.apply_to_design(&x);
143    let fit_linear_constraints =
144        conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
145
146    let mut reml_state = RemlState::newwith_offset(
147        y_o.view(),
148        x_fit,
149        w_o.view(),
150        offset_o.view(),
151        canonical,
152        p,
153        &cfg,
154        Some(active_nullspace_dims),
155        None,
156        fit_linear_constraints,
157    )?;
158    reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
159    reml_state.set_rho_prior(opts.rho_prior.clone());
160    reml_state.set_link_states(
161        cfg.link_kind.mixture_state().cloned(),
162        cfg.link_kind.sas_state().copied(),
163    );
164
165    reml_state.compute_gradient(rho)
166}
167
168fn gaussian_identity_inner_residual_norm(
169    y: ArrayView1<'_, f64>,
170    w: ArrayView1<'_, f64>,
171    x: &DesignMatrix,
172    offset: ArrayView1<'_, f64>,
173    canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
174    rho: &Array1<f64>,
175    beta: &Array1<f64>,
176) -> Result<f64, EstimationError> {
177    if beta.len() != x.ncols() {
178        crate::bail_invalid_estim!(
179            "beta dimension mismatch: beta_dim={}, x_cols={}",
180            beta.len(),
181            x.ncols()
182        );
183    }
184    if rho.len() != canonical_penalties.len() {
185        crate::bail_invalid_estim!(
186            "rho dimension mismatch: rho_dim={}, active_penalties={}",
187            rho.len(),
188            canonical_penalties.len()
189        );
190    }
191
192    let mut residual = x.apply(beta);
193    residual += &offset;
194    residual -= &y;
195    residual *= &w;
196    let mut gradient = x.apply_transpose(&residual);
197
198    for (k, cp) in canonical_penalties.iter().enumerate() {
199        let lambda = rho[k].exp();
200        if lambda == 0.0 || cp.rank() == 0 {
201            continue;
202        }
203        let r = cp.col_range.clone();
204        let centered = &beta.slice(s![r.start..r.end]) - &cp.prior_mean;
205        let penalty_grad = cp.local.dot(&centered) * lambda;
206        gradient
207            .slice_mut(s![r.start..r.end])
208            .scaled_add(1.0, &penalty_grad);
209    }
210
211    Ok(gradient.iter().map(|v| v * v).sum::<f64>().sqrt())
212}
213
214/// Evaluate IFT and flat warm-start inner residuals at `rho + delta_rho`.
215///
216/// Computes the inner-KKT residual norm at the IFT-predicted coefficient
217/// `β_pred(ρ+Δρ)` obtained by linearizing the inner solution around the
218/// converged `β̂(ρ)`, alongside the residual norm for the "flat" warm start
219/// `β̂(ρ)` (the same coefficient without any IFT correction). The pair lets
220/// callers verify that the IFT predictor reduces the inner residual to the
221/// expected second-order remainder in `‖Δρ‖`.
222///
223/// # Math
224///
225/// Let `β̂(ρ)` minimize the penalized inner objective and `v_j = ∂β̂/∂ρ_j`
226/// be the IFT sensitivity vectors at `ρ`. The first-order predictor is
227///
228/// ```text
229///   β_pred(ρ + Δρ) = β̂(ρ) − Σ_j Δρ_j · v_j.
230/// ```
231///
232/// Writing `r(β, ρ) = ∇_β L(β, ρ)` for the inner-KKT residual, the test
233/// invariant exercised by callers is
234///
235/// ```text
236///   ‖ r( β_pred(ρ+Δρ),  ρ + Δρ ) ‖ = O( ‖Δρ‖² ).
237/// ```
238///
239/// The flat baseline `‖ r( β̂(ρ), ρ + Δρ ) ‖` is `O(‖Δρ‖)` for comparison.
240///
241/// # Arguments
242///
243/// * `y`, `w`, `x`, `offset` — full-data response, weights, design, offset.
244/// * `s_list` — blockwise penalty specifications matching `rho`.
245/// * `opts` — external optimization options; must be `GaussianIdentity`
246///   with no linear constraints.
247/// * `rho` — base log-smoothing parameter vector at which the IFT
248///   sensitivities are taken.
249/// * `delta_rho` — perturbation applied to `rho` for the residual probe.
250///
251/// # Returns
252///
253/// `(ift_residual_norm, flat_residual_norm)` — the L2 norm of the inner
254/// KKT residual at `β_pred(ρ+Δρ)` and at the flat warm start `β̂(ρ)`,
255/// both evaluated at `ρ + Δρ`.
256///
257/// # Used by
258///
259/// Tests that exercise the IFT predictor's residual-order property; not
260/// part of the production solver hot path.
261pub fn evaluate_external_ift_residual_at_perturbed_rho<X>(
262    y: ArrayView1<'_, f64>,
263    w: ArrayView1<'_, f64>,
264    x: X,
265    offset: ArrayView1<'_, f64>,
266    s_list: &[BlockwisePenalty],
267    opts: &ExternalOptimOptions,
268    rho: &Array1<f64>,
269    delta_rho: ArrayView1<'_, f64>,
270) -> Result<(f64, f64), EstimationError>
271where
272    X: Into<DesignMatrix>,
273{
274    if !opts.family.is_gaussian_identity() {
275        crate::bail_invalid_estim!(
276            "evaluate_external_ift_residual_at_perturbed_rho currently supports GaussianIdentity"
277                .to_string(),
278        );
279    }
280    if opts.linear_constraints.is_some() {
281        crate::bail_invalid_estim!(
282            "evaluate_external_ift_residual_at_perturbed_rho does not support constrained fits"
283                .to_string(),
284        );
285    }
286
287    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
288    let x = x.into();
289    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
290        crate::bail_invalid_estim!("{}", message);
291    }
292
293    let p = x.ncols();
294    validate_penalty_specs(&specs, p, "evaluate_external_ift_residual_at_perturbed_rho")?;
295    let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
296        &specs,
297        &opts.nullspace_dims,
298        p,
299        "evaluate_external_ift_residual_at_perturbed_rho",
300    )?;
301    if rho.len() != active_nullspace_dims.len() {
302        crate::bail_invalid_estim!(
303            "rho dimension mismatch: rho_dim={}, active_penalties={}",
304            rho.len(),
305            active_nullspace_dims.len()
306        );
307    }
308    if delta_rho.len() != rho.len() {
309        crate::bail_invalid_estim!(
310            "delta_rho dimension mismatch: delta_dim={}, rho_dim={}",
311            delta_rho.len(),
312            rho.len()
313        );
314    }
315
316    let mut tight_opts = opts.clone();
317    tight_opts.tol = 1e-12;
318    let (cfg, _) = resolved_external_config(&tight_opts)?;
319
320    let y_o = y.to_owned();
321    let w_o = w.to_owned();
322    let offset_o = offset.to_owned();
323    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
324    let x_fit = conditioning.apply_to_design(&x);
325    let fit_linear_constraints =
326        conditioning.transform_linear_constraints_to_internal(tight_opts.linear_constraints);
327
328    let mut reml_state = RemlState::newwith_offset(
329        y_o.view(),
330        x_fit.clone(),
331        w_o.view(),
332        offset_o.view(),
333        canonical.clone(),
334        p,
335        &cfg,
336        Some(active_nullspace_dims),
337        None,
338        fit_linear_constraints,
339    )?;
340    reml_state.set_penalty_shrinkage_floor(tight_opts.penalty_shrinkage_floor);
341    reml_state.set_rho_prior(tight_opts.rho_prior.clone());
342    reml_state.set_link_states(
343        cfg.link_kind.mixture_state().cloned(),
344        cfg.link_kind.sas_state().copied(),
345    );
346
347    reml_state.compute_gradient(rho)?;
348    let beta_hat = reml_state
349        .warm_start_beta
350        .read()
351        .unwrap()
352        .as_ref()
353        .map(|beta| beta.0.clone())
354        .ok_or_else(|| {
355            EstimationError::InvalidInput(
356                "PIRLS solve did not populate the warm-start beta cache".to_string(),
357            )
358        })?;
359
360    let rho_perturbed = rho + &delta_rho.to_owned();
361    let beta_pred = reml_state
362        .predict_warm_start_beta_ift_with_outcome(&rho_perturbed)
363        .map(|(beta, _)| beta.as_ref().clone())
364        .ok_or_else(|| {
365            EstimationError::InvalidInput(
366                "IFT warm-start predictor rejected the perturbed rho".to_string(),
367            )
368        })?;
369
370    let ift_residual = gaussian_identity_inner_residual_norm(
371        y_o.view(),
372        w_o.view(),
373        &x_fit,
374        offset_o.view(),
375        &canonical,
376        &rho_perturbed,
377        &beta_pred,
378    )?;
379    let flat_residual = gaussian_identity_inner_residual_norm(
380        y_o.view(),
381        w_o.view(),
382        &x_fit,
383        offset_o.view(),
384        &canonical,
385        &rho_perturbed,
386        &beta_hat,
387    )?;
388
389    Ok((ift_residual, flat_residual))
390}
391
392/// Evaluate the external cost and report the stabilization ridge used.
393/// This is a diagnostic helper for tests that need to detect ridge jitter.
394pub fn evaluate_externalcost_andridge<X>(
395    y: ArrayView1<'_, f64>,
396    w: ArrayView1<'_, f64>,
397    x: X,
398    offset: ArrayView1<'_, f64>,
399    s_list: &[BlockwisePenalty],
400    opts: &ExternalOptimOptions,
401    rho: &Array1<f64>,
402) -> Result<(f64, f64), EstimationError>
403where
404    X: Into<DesignMatrix>,
405{
406    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
407    let x = x.into();
408    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
409        crate::bail_invalid_estim!("{}", message);
410    }
411
412    let p = x.ncols();
413    validate_penalty_specs(&specs, p, "evaluate_externalcost_andridge")?;
414    let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
415        &specs,
416        &opts.nullspace_dims,
417        p,
418        "evaluate_externalcost_andridge",
419    )?;
420    if rho.len() != active_nullspace_dims.len() {
421        crate::bail_invalid_estim!(
422            "rho dimension mismatch: rho_dim={}, active_penalties={}",
423            rho.len(),
424            active_nullspace_dims.len()
425        );
426    }
427
428    let (cfg, _) = resolved_external_config(opts)?;
429
430    let y_o = y.to_owned();
431    let w_o = w.to_owned();
432    let offset_o = offset.to_owned();
433    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
434    let x_fit = conditioning.apply_to_design(&x);
435    let fit_linear_constraints =
436        conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
437
438    let mut reml_state = RemlState::newwith_offset(
439        y_o.view(),
440        x_fit,
441        w_o.view(),
442        offset_o.view(),
443        canonical,
444        p,
445        &cfg,
446        Some(active_nullspace_dims),
447        None,
448        fit_linear_constraints,
449    )?;
450    reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
451    reml_state.set_rho_prior(opts.rho_prior.clone());
452    reml_state.set_link_states(
453        cfg.link_kind.mixture_state().cloned(),
454        cfg.link_kind.sas_state().copied(),
455    );
456
457    let cost = reml_state.compute_cost(rho)?;
458    let ridge = reml_state.last_ridge_used().unwrap_or(0.0);
459    Ok((cost, ridge))
460}