Skip to main content

gam_solve/estimate/
external_options.rs

1use super::*;
2
3pub struct ExternalOptimResult {
4    pub beta: Array1<f64>,
5    pub lambdas: Array1<f64>,
6    pub likelihood_family: LikelihoodSpec,
7    pub likelihood_scale: LikelihoodScaleMetadata,
8    pub log_likelihood_normalization: LogLikelihoodNormalization,
9    pub log_likelihood: f64,
10    /// Residual scale on the response scale.
11    ///
12    /// Contract: Gaussian identity models store the residual standard
13    /// deviation sigma here. Non-Gaussian families keep the response-scale
14    /// summary used by their explicit likelihood-scale metadata.
15    pub standard_deviation: f64,
16    pub iterations: usize,
17    pub finalgrad_norm: f64,
18    /// True iff the outer optimizer reached a stationary point (gradient
19    /// norm below tolerance), as reported by the optimizer itself. False
20    /// when the run exhausted its iteration budget without reaching the
21    /// gradient tolerance. Downstream consumers should NOT assume that a
22    /// fit with `outer_converged == false` is unusable — it may still be
23    /// the best basin reached given the budget — but they must not treat
24    /// it as certified-converged either.
25    pub outer_converged: bool,
26    pub pirls_status: crate::pirls::PirlsStatus,
27    pub deviance: f64,
28    /// Stable quadratic penalty term βᵀSβ, including any solver ridge quadratic.
29    pub stable_penalty_term: f64,
30    pub used_device: bool,
31    pub max_abs_eta: f64,
32    pub constraint_kkt: Option<crate::pirls::ConstraintKktDiagnostics>,
33    pub artifacts: FitArtifacts,
34    pub inference: Option<FitInference>,
35    /// Complete REML/LAML objective value used for smoothing selection.
36    pub reml_score: f64,
37    pub fitted_link: FittedLinkState,
38    /// Number of outer REML cost-only evaluations executed during the fit
39    /// (the count the outer optimizer's trust-region/line-search probes drive,
40    /// each paying an inner P-IRLS solve). Surfaced for regression guards on
41    /// outer work (#1575); not part of the statistical contract.
42    pub outer_cost_evals: usize,
43    /// Number of *actual* full-n inner P-IRLS solves performed (cache-missing
44    /// `prepare_eval_bundlewithkey` calls). This is the true cost driver the
45    /// #1575 slowdown is measured in ("~150 outer cost evals each running a
46    /// full n-sized P-IRLS"): unlike `outer_cost_evals`, it excludes single-slot
47    /// cache hits and prior short-circuits and includes every solve done by the
48    /// seed-grid prepass, screening, multistart, and finalize phases. Surfaced
49    /// for a regression guard that pins the warm-start / parsimony-waiver /
50    /// PSIS-optin economy (#1575); not part of the statistical contract.
51    pub inner_pirls_solves: usize,
52}
53
54#[derive(Clone)]
55pub struct ExternalOptimOptions {
56    pub family: gam_problem::LikelihoodSpec,
57    pub latent_cloglog: Option<LatentCLogLogState>,
58    pub mixture_link: Option<MixtureLinkSpec>,
59    pub optimize_mixture: bool,
60    pub sas_link: Option<SasLinkSpec>,
61    pub optimize_sas: bool,
62    pub compute_inference: bool,
63    /// Internal lifecycle knob for fits whose result will be immediately
64    /// superseded. Keeps ordinary inference work but skips the live-objective
65    /// rho posterior certificate/escalation until the returned model is known.
66    pub skip_rho_posterior_inference: bool,
67    pub max_iter: usize,
68    pub tol: f64,
69    pub nullspace_dims: Vec<usize>,
70    pub linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
71    /// Optional explicit Firth override for external fitting families that
72    /// support Jeffreys/Firth bias reduction.
73    /// - `Some(true)`: force Firth on
74    /// - `Some(false)`: force Firth off
75    /// - `None`: use family default behavior
76    pub firth_bias_reduction: Option<bool>,
77    /// Relative shrinkage floor for penalized block eigenvalues.
78    /// See [`FitOptions::penalty_shrinkage_floor`] for details.
79    pub penalty_shrinkage_floor: Option<f64>,
80    /// Fixed prior on smoothing parameters for explicit joint HMC sampling
81    /// flows. Standard fitting stays on the REML/Laplace path.
82    pub rho_prior: gam_problem::RhoPrior,
83    /// Kronecker-factored penalty system for tensor-product smooth terms.
84    pub kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
85    /// Full Kronecker factored basis for P-IRLS factored reparameterization.
86    pub kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
87    /// Engage the cross-process ON-DISK persistent warm-start layer for this
88    /// fit. Default `false`: only the in-memory warm start runs, so throwaway /
89    /// replicate / CI-coverage loops pay no disk I/O (#1082). A caller that
90    /// wants cross-process resume threads `true` down from
91    /// `FitConfig::persist_warm_start_disk`; the standard `RemlState`
92    /// constructor then calls `enable_persistent_warm_start_disk()`.
93    pub persist_warm_start_disk: bool,
94}
95
96pub(crate) fn resolve_external_family(
97    family: &gam_problem::LikelihoodSpec,
98    firth_override: Option<bool>,
99) -> Result<(GlmLikelihoodSpec, bool), EstimationError> {
100    let external_glm_supported = match (&family.response, family.link_function()) {
101        (ResponseFamily::Gaussian, LinkFunction::Identity)
102        | (ResponseFamily::Poisson, LinkFunction::Log)
103        | (ResponseFamily::Gamma, LinkFunction::Log)
104        | (ResponseFamily::Tweedie { .. }, LinkFunction::Log)
105        | (ResponseFamily::NegativeBinomial { .. }, LinkFunction::Log)
106        | (ResponseFamily::Binomial, LinkFunction::Logit)
107        | (ResponseFamily::Binomial, LinkFunction::Probit)
108        | (ResponseFamily::Binomial, LinkFunction::CLogLog)
109        | (ResponseFamily::Binomial, LinkFunction::Sas)
110        | (ResponseFamily::Binomial, LinkFunction::BetaLogistic) => true,
111        // Beta regression with a constant precision φ is a genuine-dispersion
112        // mean family on par with Gamma/Tweedie/Negative-Binomial: the inner
113        // P-IRLS carries its full fixed-φ Fisher information and the outer loop
114        // estimates φ by the Pearson moment estimator (`estimate_beta_phi_from_eta`,
115        // mirroring the Tweedie φ / Gamma shape / NegBin θ locks). A
116        // `noise_formula` upgrades it to a dispersion-location-scale model that
117        // smooths log φ; without one, the external GLM route fits the mean with
118        // a single estimated φ exactly as betareg does by default.
119        (ResponseFamily::Beta { .. }, LinkFunction::Logit) => true,
120        _ => false,
121    };
122    if !external_glm_supported {
123        crate::bail_invalid_estim!(
124            "optimize_external_design requires a supported standard GLM family/link; got {}. \
125             The external-design route supports Gaussian(identity), Binomial(logit/probit/cloglog/SAS/Beta-Logistic), \
126             Beta(logit), and Poisson/Gamma/Tweedie/Negative-Binomial(log). For Beta precision modeling \
127             add a noise_formula to upgrade to the dispersion-location-scale route",
128            family.pretty_name(),
129        );
130    }
131
132    let supports_firth = family.supports_firth();
133    if firth_override == Some(true) && !supports_firth {
134        crate::bail_invalid_estim!(
135            "firth_bias_reduction requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
136            family.pretty_name(),
137        );
138    }
139
140    if let ResponseFamily::Tweedie { p } = &family.response {
141        if !gam_problem::is_valid_tweedie_power(*p) {
142            crate::bail_invalid_estim!("optimize_external_design requires a GLM family; Tweedie variance power must be finite and strictly between 1 and 2; use PoissonLog or GammaLog for boundary cases"
143                    .to_string(),);
144        }
145    }
146    Ok((
147        GlmLikelihoodSpec::canonical(family.clone()),
148        firth_override.unwrap_or(false) && supports_firth,
149    ))
150}
151
152#[inline]
153pub(crate) fn effective_sas_link_for_family(
154    family: &gam_problem::LikelihoodSpec,
155    sas_link: Option<SasLinkSpec>,
156) -> Option<SasLinkSpec> {
157    if (family.is_binomial_sas() || family.is_binomial_beta_logistic()) && sas_link.is_none() {
158        Some(SasLinkSpec {
159            initial_epsilon: 0.0,
160            initial_log_delta: 0.0,
161        })
162    } else {
163        sas_link
164    }
165}
166
167#[inline]
168pub(crate) fn resolved_external_inverse_link(
169    link: LinkFunction,
170    latent_cloglog: Option<LatentCLogLogState>,
171    mixture_link: Option<&MixtureLinkSpec>,
172    sas_link: Option<SasLinkSpec>,
173) -> Result<InverseLink, EstimationError> {
174    if let Some(state) = latent_cloglog {
175        return Ok(InverseLink::LatentCLogLog(state));
176    }
177    if let Some(spec) = mixture_link {
178        return Ok(InverseLink::Mixture(state_fromspec(spec).map_err(|e| {
179            EstimationError::InvalidInput(format!("invalid blended inverse link: {e}"))
180        })?));
181    }
182    if let Some(spec) = sas_link {
183        return Ok(match link {
184            LinkFunction::BetaLogistic => {
185                InverseLink::BetaLogistic(state_from_beta_logisticspec(spec).map_err(|e| {
186                    EstimationError::InvalidInput(format!("invalid Beta-Logistic link: {e}"))
187                })?)
188            }
189            _ => InverseLink::Sas(
190                state_from_sasspec(spec)
191                    .map_err(|e| EstimationError::InvalidInput(format!("invalid SAS link: {e}")))?,
192            ),
193        });
194    }
195    Ok(InverseLink::Standard(StandardLink::try_from(link).map_err(|e| {
196        EstimationError::InvalidInput(format!(
197            "inverse link resolution: {e}; supply `sas_link` or `latent_cloglog` configuration for state-bearing links"
198        ))
199    })?))
200}
201
202#[inline]
203pub(crate) fn resolved_external_config(
204    opts: &ExternalOptimOptions,
205) -> Result<(RemlConfig, Option<SasLinkSpec>), EstimationError> {
206    if opts.latent_cloglog.is_some() && (opts.mixture_link.is_some() || opts.sas_link.is_some()) {
207        crate::bail_invalid_estim!(
208            "latent_cloglog cannot be combined with mixture_link or sas_link"
209        );
210    }
211    if opts.mixture_link.is_some() && opts.sas_link.is_some() {
212        crate::bail_invalid_estim!("mixture_link and sas_link are mutually exclusive");
213    }
214    if opts.family.is_latent_cloglog() && opts.latent_cloglog.is_none() {
215        crate::bail_invalid_estim!("BinomialLatentCLogLog requires latent_cloglog state");
216    }
217    if opts.latent_cloglog.is_some() && !opts.family.is_latent_cloglog() {
218        crate::bail_invalid_estim!("latent_cloglog is only supported with BinomialLatentCLogLog");
219    }
220    let effective_sas_link = effective_sas_link_for_family(&opts.family, opts.sas_link);
221    let (likelihood, firth_active) =
222        resolve_external_family(&opts.family, opts.firth_bias_reduction)?;
223    let link = likelihood.link_function();
224    let mut cfg = RemlConfig::external(likelihood, opts.tol, firth_active);
225    cfg.link_kind = resolved_external_inverse_link(
226        link,
227        opts.latent_cloglog,
228        opts.mixture_link.as_ref(),
229        effective_sas_link,
230    )?;
231    Ok((cfg, effective_sas_link))
232}
233
234/// Shape/bounds validation for a single [`PenaltySpec`] against the total
235/// coefficient width `p`. Canonical home for the block/dense shape checks that
236/// were duplicated inline in `terms::construction`'s fused validate-and-
237/// destructure path; both call this so the diagnostics stay identical.
238pub(crate) fn validate_penalty_spec_shape(
239    idx: usize,
240    spec: &PenaltySpec,
241    p: usize,
242    context: &str,
243) -> Result<(), EstimationError> {
244    match spec {
245        PenaltySpec::Block {
246            local, col_range, ..
247        } => {
248            let bd = col_range.len();
249            if local.nrows() != bd || local.ncols() != bd {
250                crate::bail_invalid_estim!(
251                    "{context}: block penalty {idx} local matrix must be {bd}x{bd}, got {}x{}",
252                    local.nrows(),
253                    local.ncols()
254                );
255            }
256            if col_range.end > p {
257                crate::bail_invalid_estim!(
258                    "{context}: block penalty {idx} col_range {}..{} exceeds p={p}",
259                    col_range.start,
260                    col_range.end
261                );
262            }
263        }
264        PenaltySpec::Dense(m) => {
265            if m.nrows() != p || m.ncols() != p {
266                crate::bail_invalid_estim!(
267                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
268                    m.nrows(),
269                    m.ncols()
270                );
271            }
272        }
273        PenaltySpec::DenseWithMean { matrix, .. } => {
274            if matrix.nrows() != p || matrix.ncols() != p {
275                crate::bail_invalid_estim!(
276                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
277                    matrix.nrows(),
278                    matrix.ncols()
279                );
280            }
281        }
282    }
283    Ok(())
284}