Skip to main content

gam_solve/estimate/
external_options.rs

1use super::*;
2
3pub struct ExternalOptimResult {
4    pub beta: Array1<f64>,
5    pub lambdas: Array1<f64>,
6    pub likelihood_family: LikelihoodSpec,
7    pub likelihood_scale: LikelihoodScaleMetadata,
8    pub log_likelihood_normalization: LogLikelihoodNormalization,
9    pub log_likelihood: f64,
10    /// Residual scale on the response scale.
11    ///
12    /// Contract: Gaussian identity models store the residual standard
13    /// deviation sigma here. Non-Gaussian families keep the response-scale
14    /// summary used by their explicit likelihood-scale metadata.
15    pub standard_deviation: f64,
16    pub iterations: usize,
17    pub finalgrad_norm: f64,
18    /// True iff the outer optimizer reached a stationary point (gradient
19    /// norm below tolerance), as reported by the optimizer itself. False
20    /// when the run exhausted its iteration budget without reaching the
21    /// gradient tolerance. Downstream consumers should NOT assume that a
22    /// fit with `outer_converged == false` is unusable — it may still be
23    /// the best basin reached given the budget — but they must not treat
24    /// it as certified-converged either.
25    pub outer_converged: bool,
26    pub pirls_status: crate::pirls::PirlsStatus,
27    pub deviance: f64,
28    /// Stable quadratic penalty term βᵀSβ, including any solver ridge quadratic.
29    pub stable_penalty_term: f64,
30    pub used_device: bool,
31    pub max_abs_eta: f64,
32    pub constraint_kkt: Option<crate::pirls::ConstraintKktDiagnostics>,
33    pub artifacts: FitArtifacts,
34    pub inference: Option<FitInference>,
35    /// Complete REML/LAML objective value used for smoothing selection.
36    pub reml_score: f64,
37    pub fitted_link: FittedLinkState,
38    /// Number of outer REML cost-only evaluations executed during the fit
39    /// (the count the outer optimizer's trust-region/line-search probes drive,
40    /// each paying an inner P-IRLS solve). Surfaced for regression guards on
41    /// outer work (#1575); not part of the statistical contract.
42    pub outer_cost_evals: usize,
43}
44
45#[derive(Clone)]
46pub struct ExternalOptimOptions {
47    pub family: gam_problem::LikelihoodSpec,
48    pub latent_cloglog: Option<LatentCLogLogState>,
49    pub mixture_link: Option<MixtureLinkSpec>,
50    pub optimize_mixture: bool,
51    pub sas_link: Option<SasLinkSpec>,
52    pub optimize_sas: bool,
53    pub compute_inference: bool,
54    /// Internal lifecycle knob for fits whose result will be immediately
55    /// superseded. Keeps ordinary inference work but skips the live-objective
56    /// rho posterior certificate/escalation until the returned model is known.
57    pub skip_rho_posterior_inference: bool,
58    pub max_iter: usize,
59    pub tol: f64,
60    pub nullspace_dims: Vec<usize>,
61    pub linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
62    /// Optional explicit Firth override for external fitting families that
63    /// support Jeffreys/Firth bias reduction.
64    /// - `Some(true)`: force Firth on
65    /// - `Some(false)`: force Firth off
66    /// - `None`: use family default behavior
67    pub firth_bias_reduction: Option<bool>,
68    /// Relative shrinkage floor for penalized block eigenvalues.
69    /// See [`FitOptions::penalty_shrinkage_floor`] for details.
70    pub penalty_shrinkage_floor: Option<f64>,
71    /// Fixed prior on smoothing parameters for explicit joint HMC sampling
72    /// flows. Standard fitting stays on the REML/Laplace path.
73    pub rho_prior: gam_problem::RhoPrior,
74    /// Kronecker-factored penalty system for tensor-product smooth terms.
75    pub kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
76    /// Full Kronecker factored basis for P-IRLS factored reparameterization.
77    pub kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
78    /// Engage the cross-process ON-DISK persistent warm-start layer for this
79    /// fit. Default `false`: only the in-memory warm start runs, so throwaway /
80    /// replicate / CI-coverage loops pay no disk I/O (#1082). A caller that
81    /// wants cross-process resume threads `true` down from
82    /// `FitConfig::persist_warm_start_disk`; the standard `RemlState`
83    /// constructor then calls `enable_persistent_warm_start_disk()`.
84    pub persist_warm_start_disk: bool,
85}
86
87pub(crate) fn resolve_external_family(
88    family: &gam_problem::LikelihoodSpec,
89    firth_override: Option<bool>,
90) -> Result<(GlmLikelihoodSpec, bool), EstimationError> {
91    let external_glm_supported = match (&family.response, family.link_function()) {
92        (ResponseFamily::Gaussian, LinkFunction::Identity)
93        | (ResponseFamily::Poisson, LinkFunction::Log)
94        | (ResponseFamily::Gamma, LinkFunction::Log)
95        | (ResponseFamily::Tweedie { .. }, LinkFunction::Log)
96        | (ResponseFamily::NegativeBinomial { .. }, LinkFunction::Log)
97        | (ResponseFamily::Binomial, LinkFunction::Logit)
98        | (ResponseFamily::Binomial, LinkFunction::Probit)
99        | (ResponseFamily::Binomial, LinkFunction::CLogLog)
100        | (ResponseFamily::Binomial, LinkFunction::Sas)
101        | (ResponseFamily::Binomial, LinkFunction::BetaLogistic) => true,
102        // Beta regression with a constant precision φ is a genuine-dispersion
103        // mean family on par with Gamma/Tweedie/Negative-Binomial: the inner
104        // P-IRLS carries its full fixed-φ Fisher information and the outer loop
105        // estimates φ by the Pearson moment estimator (`estimate_beta_phi_from_eta`,
106        // mirroring the Tweedie φ / Gamma shape / NegBin θ locks). A
107        // `noise_formula` upgrades it to a dispersion-location-scale model that
108        // smooths log φ; without one, the external GLM route fits the mean with
109        // a single estimated φ exactly as betareg does by default.
110        (ResponseFamily::Beta { .. }, LinkFunction::Logit) => true,
111        _ => false,
112    };
113    if !external_glm_supported {
114        crate::bail_invalid_estim!(
115            "optimize_external_design requires a supported standard GLM family/link; got {}. \
116             The external-design route supports Gaussian(identity), Binomial(logit/probit/cloglog/SAS/Beta-Logistic), \
117             Beta(logit), and Poisson/Gamma/Tweedie/Negative-Binomial(log). For Beta precision modeling \
118             add a noise_formula to upgrade to the dispersion-location-scale route",
119            family.pretty_name(),
120        );
121    }
122
123    let supports_firth = family.supports_firth();
124    if firth_override == Some(true) && !supports_firth {
125        crate::bail_invalid_estim!(
126            "firth_bias_reduction requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
127            family.pretty_name(),
128        );
129    }
130
131    if let ResponseFamily::Tweedie { p } = &family.response {
132        if !gam_problem::is_valid_tweedie_power(*p) {
133            crate::bail_invalid_estim!("optimize_external_design requires a GLM family; Tweedie variance power must be finite and strictly between 1 and 2; use PoissonLog or GammaLog for boundary cases"
134                    .to_string(),);
135        }
136    }
137    Ok((
138        GlmLikelihoodSpec::canonical(family.clone()),
139        firth_override.unwrap_or(false) && supports_firth,
140    ))
141}
142
143#[inline]
144pub(crate) fn effective_sas_link_for_family(
145    family: &gam_problem::LikelihoodSpec,
146    sas_link: Option<SasLinkSpec>,
147) -> Option<SasLinkSpec> {
148    if (family.is_binomial_sas() || family.is_binomial_beta_logistic()) && sas_link.is_none() {
149        Some(SasLinkSpec {
150            initial_epsilon: 0.0,
151            initial_log_delta: 0.0,
152        })
153    } else {
154        sas_link
155    }
156}
157
158#[inline]
159pub(crate) fn resolved_external_inverse_link(
160    link: LinkFunction,
161    latent_cloglog: Option<LatentCLogLogState>,
162    mixture_link: Option<&MixtureLinkSpec>,
163    sas_link: Option<SasLinkSpec>,
164) -> Result<InverseLink, EstimationError> {
165    if let Some(state) = latent_cloglog {
166        return Ok(InverseLink::LatentCLogLog(state));
167    }
168    if let Some(spec) = mixture_link {
169        return Ok(InverseLink::Mixture(state_fromspec(spec).map_err(|e| {
170            EstimationError::InvalidInput(format!("invalid blended inverse link: {e}"))
171        })?));
172    }
173    if let Some(spec) = sas_link {
174        return Ok(match link {
175            LinkFunction::BetaLogistic => {
176                InverseLink::BetaLogistic(state_from_beta_logisticspec(spec).map_err(|e| {
177                    EstimationError::InvalidInput(format!("invalid Beta-Logistic link: {e}"))
178                })?)
179            }
180            _ => InverseLink::Sas(
181                state_from_sasspec(spec)
182                    .map_err(|e| EstimationError::InvalidInput(format!("invalid SAS link: {e}")))?,
183            ),
184        });
185    }
186    Ok(InverseLink::Standard(StandardLink::try_from(link).map_err(|e| {
187        EstimationError::InvalidInput(format!(
188            "inverse link resolution: {e}; supply `sas_link` or `latent_cloglog` configuration for state-bearing links"
189        ))
190    })?))
191}
192
193#[inline]
194pub(crate) fn resolved_external_config(
195    opts: &ExternalOptimOptions,
196) -> Result<(RemlConfig, Option<SasLinkSpec>), EstimationError> {
197    if opts.latent_cloglog.is_some() && (opts.mixture_link.is_some() || opts.sas_link.is_some()) {
198        crate::bail_invalid_estim!(
199            "latent_cloglog cannot be combined with mixture_link or sas_link"
200        );
201    }
202    if opts.mixture_link.is_some() && opts.sas_link.is_some() {
203        crate::bail_invalid_estim!("mixture_link and sas_link are mutually exclusive");
204    }
205    if opts.family.is_latent_cloglog() && opts.latent_cloglog.is_none() {
206        crate::bail_invalid_estim!("BinomialLatentCLogLog requires latent_cloglog state");
207    }
208    if opts.latent_cloglog.is_some() && !opts.family.is_latent_cloglog() {
209        crate::bail_invalid_estim!("latent_cloglog is only supported with BinomialLatentCLogLog");
210    }
211    let effective_sas_link = effective_sas_link_for_family(&opts.family, opts.sas_link);
212    let (likelihood, firth_active) =
213        resolve_external_family(&opts.family, opts.firth_bias_reduction)?;
214    let link = likelihood.link_function();
215    let mut cfg = RemlConfig::external(likelihood, opts.tol, firth_active);
216    cfg.link_kind = resolved_external_inverse_link(
217        link,
218        opts.latent_cloglog,
219        opts.mixture_link.as_ref(),
220        effective_sas_link,
221    )?;
222    Ok((cfg, effective_sas_link))
223}
224
225/// Shape/bounds validation for a single [`PenaltySpec`] against the total
226/// coefficient width `p`. Canonical home for the block/dense shape checks that
227/// were duplicated inline in `terms::construction`'s fused validate-and-
228/// destructure path; both call this so the diagnostics stay identical.
229pub(crate) fn validate_penalty_spec_shape(
230    idx: usize,
231    spec: &PenaltySpec,
232    p: usize,
233    context: &str,
234) -> Result<(), EstimationError> {
235    match spec {
236        PenaltySpec::Block {
237            local, col_range, ..
238        } => {
239            let bd = col_range.len();
240            if local.nrows() != bd || local.ncols() != bd {
241                crate::bail_invalid_estim!(
242                    "{context}: block penalty {idx} local matrix must be {bd}x{bd}, got {}x{}",
243                    local.nrows(),
244                    local.ncols()
245                );
246            }
247            if col_range.end > p {
248                crate::bail_invalid_estim!(
249                    "{context}: block penalty {idx} col_range {}..{} exceeds p={p}",
250                    col_range.start,
251                    col_range.end
252                );
253            }
254        }
255        PenaltySpec::Dense(m) => {
256            if m.nrows() != p || m.ncols() != p {
257                crate::bail_invalid_estim!(
258                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
259                    m.nrows(),
260                    m.ncols()
261                );
262            }
263        }
264        PenaltySpec::DenseWithMean { matrix, .. } => {
265            if matrix.nrows() != p || matrix.ncols() != p {
266                crate::bail_invalid_estim!(
267                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
268                    matrix.nrows(),
269                    matrix.ncols()
270                );
271            }
272        }
273    }
274    Ok(())
275}