Skip to main content

gam_solve/estimate/
mod.rs

1//! # Model Estimation via Penalized Likelihood and REML
2//!
3//! This module orchestrates the core model fitting procedure for Generalized Additive
4//! Models (GAMs). It determines optimal smoothing parameters directly from the data,
5//! moving beyond simple hyperparameter-driven models. This is achieved through a
6//! nested optimization scheme, a standard approach for this class of models:
7//!
8//! 1.  Outer Loop (planner-selected optimizer): Optimizes the log-smoothing
9//!     parameters (`rho`) by maximizing a marginal likelihood criterion. For
10//!     non-Gaussian models (e.g., Logit), this is the Laplace Approximate
11//!     Marginal Likelihood (LAML). The concrete solver is chosen centrally by
12//!     `rho_optimizer` based on the derivatives the model path can supply:
13//!     ARC when an analytic Hessian is available, BFGS for gradient-only
14//!     problems, and EFS / hybrid EFS when the hyperparameter geometry
15//!     admits those fixed-point updates.
16//!
17//! 2.  Inner Loop (P-IRLS): For each set of trial smoothing parameters from the
18//!     outer loop, this routine finds the corresponding model coefficients (`beta`) by
19//!     running a Penalized Iteratively Reweighted Least Squares (P-IRLS) algorithm
20//!     to convergence.
21//!
22//! This two-tiered structure allows the model to learn the appropriate complexity for
23//! each smooth term directly from the data.
24
25use crate::estimate::reml::{DirectionalHyperParam, RemlState};
26use std::fmt;
27use std::time::Instant;
28
29// Crate-level imports
30use gam_terms::construction::{CanonicalPenalty, ReparamInvariant};
31use gam_linalg::utils::{
32    KahanSum, add_relative_diag_ridge, matrix_inversewith_regularization, row_mismatch_message,
33};
34use gam_linalg::matrix::{DesignMatrix, FactorizedSystem, LinearOperator};
35use crate::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
36pub use crate::model_types::{CoefficientPriorMean, Dispersion, EstimationError, PenaltySpec};
37use crate::pirls::{self, PirlsResult};
38use gam_problem::{SeedConfig, SeedRiskProfile};
39use gam_terms::smooth::BlockwisePenalty;
40use gam_problem::{
41    Coefficients, GlmLikelihoodSpec, InverseLink, LatentCLogLogState, LikelihoodScaleMetadata,
42    LikelihoodSpec, LinkFunction, LogLikelihoodNormalization, LogSmoothingParamsView,
43    ResponseFamily, RidgePassport, StandardLink,
44};
45use gam_problem::{MixtureLinkSpec, SasLinkSpec};
46
47// Ndarray and faer linear algebra helpers
48use ndarray::{Array1, Array2, ArrayView1, Axis, s};
49// faer: high-performance dense solvers
50use gam_linalg::faer_ndarray::{FaerArrayView, FaerCholesky, FaerEigh, fast_ab, fast_atb};
51use faer::{MatRef, Side};
52use rayon::prelude::*;
53
54// Note: `deflateweights_by_se` was removed. P-IRLS now uses integrated (GHQ),
55// family-dispatched likelihood updates instead of weight deflation.
56// The SE is passed through to P-IRLS, which integrates over that uncertainty
57// in the likelihood rather than applying an ad-hoc weight adjustment.
58
59use std::sync::Arc;
60
61#[path = "../reml/mod.rs"]
62pub mod reml;
63
64pub use reml::reml_outer_engine::PenaltyCoordinate;
65
66mod evaluation;
67mod external_options;
68mod fit;
69mod joint_hyper;
70mod optimizer;
71mod penalty;
72mod prefit;
73pub(crate) mod smoothing_correction;
74mod summary;
75
76pub use crate::model_types::result_types::dispersion_from_likelihood;
77pub use crate::model_types::{
78    AdaptiveRegularizationOptions, BlockRole, FitArtifacts, FitGeometry, FitInference, FitOptions,
79    FittedBlock, FittedLinkState, UnifiedFitResult, UnifiedFitResultParts,
80    saved_latent_cloglog_state_from_fit, saved_mixture_state_from_fit, saved_sas_state_from_fit,
81    validate_dense_hessian_export, validate_explicit_dense_hessian_for_whitening,
82};
83pub use gam_problem::{ensure_finite_scalar, validate_all_finite};
84pub use evaluation::{
85    evaluate_external_ift_residual_at_perturbed_rho, evaluate_externalcost_andridge,
86    evaluate_externalgradient,
87};
88pub(crate) use evaluation::{
89    materialize_link_outer_hessian, sas_effective_epsilon, sas_effective_epsilon_second,
90    sas_log_delta_edge_barriercostgrad, sas_log_delta_edge_barriercostgradhess,
91    sas_log_deltaridgeweight,
92};
93pub use external_options::{ExternalOptimOptions, ExternalOptimResult};
94pub(crate) use external_options::{
95    effective_sas_link_for_family, resolved_external_config, validate_penalty_spec_shape,
96};
97pub use fit::{fit_gam, fit_gam_with_penalty_specs, fit_gamwith_heuristic_lambdas};
98pub use joint_hyper::ExternalJointHyperEvaluator;
99pub(crate) use optimizer::optimize_external_designwith_heuristic_lambdas_andwarm_start;
100pub use optimizer::{optimize_external_design, optimize_external_designwith_heuristic_lambdas};
101pub(crate) use penalty::{
102    ParametricColumnConditioning, REML_CONTINUATION_PREWARM_RHO_CAP, REML_SECOND_ORDER_RHO_CAP,
103    REML_SEED_SCREENING_RHO_CAP, faer_frob_inner, kahan_sum, map_hessian_to_original_basis,
104    scaled_covariance,
105};
106pub(crate) use prefit::{
107    reject_prefit_binomial_separation, reject_prefit_unpenalized_rank_deficiency,
108    validate_penalty_specs,
109};
110pub(crate) use smoothing_correction::{
111    AUTO_CUBATURE_BOUNDARY_MARGIN, AUTO_CUBATURE_MAX_BETA_DIM, AUTO_CUBATURE_MAX_EIGENVECTORS,
112    AUTO_CUBATURE_MAX_RHO_DIM, AUTO_CUBATURE_TARGET_VAR_FRAC, MAX_FACTORIZATION_ATTEMPTS,
113    RHO_SOFT_PRIOR_SHARPNESS, RHO_SOFT_PRIOR_WEIGHT, RemlConfig, compute_smoothing_correction,
114    smooth_floor_dp,
115};
116// #1521 carve: the spatial-optimization driver reads the unified rho bound as
117// `gam_solve::estimate::RHO_BOUND`.
118pub use smoothing_correction::RHO_BOUND;
119pub use summary::{
120    ContinuousSmoothnessOrder, ContinuousSmoothnessOrderStatus, ModelSummary,
121    ParametricTermSummary, SmoothTermSummary, compute_continuous_smoothness_order,
122};
123
124#[cfg(test)]
125mod continuous_order_tests;
126#[cfg(test)]
127mod estimate_policy_tests;
128#[cfg(test)]
129mod invert_regularized_rho_hessian_tests;
130#[cfg(test)]
131mod tests_diagnostics;