Skip to main content

gam_problem/
seeding.rs

/// Risk classification that selects the seeding constants and the keep-best
/// policy reported by the accessor methods on this type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeedRiskProfile {
    /// Plain Gaussian profile: zero anchor shift, symmetric center/shift
    /// tables, and the lowest-cost keep-best policy.
    Gaussian,
    /// Gaussian location-scale keeps Gaussian's lowest-REML keep-best policy,
    /// but its non-profiled log-scale predictor has the same capped-screening
    /// over-smoothing risk as multi-parameter likelihoods.
    GaussianLocationScale,
    /// Generalized-linear profile: anchor shift `1.0`, amplitude `2.5`, and
    /// the parsimonious keep-best policy.
    GeneralizedLinear,
    /// Survival profile: the largest anchor shift (`2.0`) and exploratory
    /// amplitude (`3.0`), with the parsimonious keep-best policy.
    Survival,
}

impl SeedRiskProfile {
    /// Anchor rho shift for this profile: `0.0` for both Gaussian variants,
    /// `1.0` for generalized-linear, `2.0` for survival.
    ///
    /// How the shift is applied to the anchor seed is up to the caller.
    #[inline]
    pub const fn anchor_rho_shift(self) -> f64 {
        match self {
            Self::Survival => 2.0,
            Self::GeneralizedLinear => 1.0,
            Self::GaussianLocationScale | Self::Gaussian => 0.0,
        }
    }

    /// Static table of baseline centers for this profile, in table order.
    ///
    /// Both Gaussian variants share one symmetric table; the other profiles
    /// use tables skewed toward positive values.
    #[inline]
    pub const fn baseline_centers(self) -> &'static [f64] {
        const GAUSSIAN_CENTERS: &[f64] = &[0.0, -3.0, 3.0, -6.0, 6.0];
        const GLM_CENTERS: &[f64] = &[0.0, 2.0, 4.0, -2.0];
        const SURVIVAL_CENTERS: &[f64] = &[0.0, 2.0, 4.0, 6.0];

        match self {
            Self::Gaussian => GAUSSIAN_CENTERS,
            Self::GaussianLocationScale => GAUSSIAN_CENTERS,
            Self::GeneralizedLinear => GLM_CENTERS,
            Self::Survival => SURVIVAL_CENTERS,
        }
    }

    /// Static table of global shifts for this profile, in table order.
    ///
    /// The Gaussian table has no `0.0` entry; the generalized-linear and
    /// survival tables both start with `0.0`.
    #[inline]
    pub const fn global_shifts(self) -> &'static [f64] {
        const GAUSSIAN_SHIFTS: &[f64] = &[-2.0, 2.0, -4.0, 4.0];
        const GLM_SHIFTS: &[f64] = &[0.0, 2.0, 4.0, -1.0, -2.0, -4.0];
        const SURVIVAL_SHIFTS: &[f64] = &[0.0, 2.0, 4.0, 6.0, -2.0, -4.0];

        match self {
            Self::Gaussian => GAUSSIAN_SHIFTS,
            Self::GaussianLocationScale => GAUSSIAN_SHIFTS,
            Self::GeneralizedLinear => GLM_SHIFTS,
            Self::Survival => SURVIVAL_SHIFTS,
        }
    }

    /// Exploratory amplitude for this profile: `2.0` for both Gaussian
    /// variants, `2.5` for generalized-linear, `3.0` for survival.
    #[inline]
    pub const fn exploratory_amplitude(self) -> f64 {
        match self {
            Self::Survival => 3.0,
            Self::GeneralizedLinear => 2.5,
            Self::GaussianLocationScale | Self::Gaussian => 2.0,
        }
    }

    /// Whether the interior seed extremes are promoted for this profile.
    ///
    /// Currently `true` for every variant.
    #[inline]
    pub const fn promotes_interior_seed_extremes(self) -> bool {
        // Plain Gaussian REML has a profiled-scale basin that does NOT show
        // the capped-screening over-smoothing bias of the other profiles, so
        // on that account alone it would not need the flexible slot-0
        // promotion. It is included because a weak-signal Gaussian fit on an
        // over-rich spatial basis fails the OPPOSITE way: REML descends from
        // the heuristic anchor into the flexible (low-λ) basin and over-fits
        // (#1074 quakes: edf≈104 vs mgcv≈15, held-out R²≈0.02), since the
        // heavily-penalized basin is a separate attractor that was never
        // seeded/solved. Promoting the heaviest INTERIOR seed to the second
        // full-budget slot (paired with the Gaussian over-smoothing probe and
        // seed_budget≥2 in `external_reml_seed_config`) lets the multi-start
        // SEE that basin; Gaussian's lowest-cost keep-best
        // (`uses_lowest_cost_keep_best`) then adopts it only when it scores a
        // strictly lower REML, so this can never worsen a flexible fit.
        //
        // Every variant is spelled out so the list stays explicit.
        match self {
            Self::Gaussian
            | Self::GaussianLocationScale
            | Self::GeneralizedLinear
            | Self::Survival => true,
        }
    }

    /// `true` for the generalized-linear and survival profiles, which use the
    /// parsimonious keep-best policy; `false` for both Gaussian variants.
    #[inline]
    pub const fn uses_parsimonious_keep_best(self) -> bool {
        match self {
            Self::GeneralizedLinear | Self::Survival => true,
            Self::Gaussian | Self::GaussianLocationScale => false,
        }
    }

    /// `true` for both Gaussian variants, which use the lowest-cost keep-best
    /// policy; `false` for generalized-linear and survival.
    #[inline]
    pub const fn uses_lowest_cost_keep_best(self) -> bool {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => true,
            Self::GeneralizedLinear | Self::Survival => false,
        }
    }
}
83
84#[derive(Clone, Copy, Debug)]
85pub struct SeedConfig {
86    pub bounds: (f64, f64),
87    pub max_seeds: usize,
88    /// Maximum number of seed starts to run in heuristic order.
89    pub seed_budget: usize,
90    /// Initial inner-iteration cap used while ranking candidate seeds.
91    pub screen_max_inner_iterations: usize,
92    pub risk_profile: SeedRiskProfile,
93    /// Number of trailing dimensions that are auxiliary parameters rather than
94    /// log-smoothing parameters.
95    pub num_auxiliary_trailing: usize,
96    /// Optional absolute over-smoothing probe on every smoothing dimension.
97    pub over_smoothing_probe_rho: Option<f64>,
98}
99
100impl Default for SeedConfig {
101    fn default() -> Self {
102        Self {
103            bounds: (-12.0, 12.0),
104            max_seeds: 12,
105            seed_budget: 2,
106            screen_max_inner_iterations: 3,
107            risk_profile: SeedRiskProfile::GeneralizedLinear,
108            num_auxiliary_trailing: 0,
109            over_smoothing_probe_rho: None,
110        }
111    }
112}
113
/// Returns `bounds` ordered as `(low, high)`.
///
/// The pair is returned untouched when `bounds.0 <= bounds.1` holds and is
/// swapped otherwise. Because the test is `<=`, a pair containing NaN (for
/// which the comparison is false) is also swapped.
#[inline]
pub fn normalize_seed_bounds(bounds: (f64, f64)) -> (f64, f64) {
    match bounds {
        (first, second) if first <= second => bounds,
        (first, second) => (second, first),
    }
}
122
123#[inline]
124pub fn clamp_seed_rho_to_bounds(value: f64, bounds: (f64, f64)) -> f64 {
125    let (lo, hi) = normalize_seed_bounds(bounds);
126    value.clamp(lo, hi)
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// The two Gaussian-family profiles, which share every seeding constant.
    const GAUSSIAN_FAMILY: [SeedRiskProfile; 2] = [
        SeedRiskProfile::Gaussian,
        SeedRiskProfile::GaussianLocationScale,
    ];

    /// The non-Gaussian profiles, which share the parsimonious keep-best policy.
    const NON_GAUSSIAN: [SeedRiskProfile; 2] = [
        SeedRiskProfile::GeneralizedLinear,
        SeedRiskProfile::Survival,
    ];

    /// Canonical, already-ordered bounds reused by the clamp tests.
    const ORDERED_BOUNDS: (f64, f64) = (-3.0, 5.0);

    // ── anchor_rho_shift ───────────────────────────────────────────────────────

    #[test]
    fn anchor_rho_shift_gaussian_is_zero() {
        for profile in GAUSSIAN_FAMILY {
            assert_eq!(profile.anchor_rho_shift(), 0.0, "profile {:?}", profile);
        }
    }

    #[test]
    fn anchor_rho_shift_generalized_linear_is_one() {
        let shift = SeedRiskProfile::GeneralizedLinear.anchor_rho_shift();
        assert_eq!(shift, 1.0);
    }

    #[test]
    fn anchor_rho_shift_survival_is_two() {
        let shift = SeedRiskProfile::Survival.anchor_rho_shift();
        assert_eq!(shift, 2.0);
    }

    // ── promotes_interior_seed_extremes ───────────────────────────────────────

    #[test]
    fn promotes_interior_extremes_for_all_profiles() {
        // #1074: plain Gaussian used to be excluded because its profiled-scale
        // REML basin has no capped-screening over-smoothing bias. A weak-signal
        // Gaussian fit on an over-rich basis fails the OPPOSITE way, though: it
        // descends into the flexible (low-λ) basin and over-fits. Promoting the
        // heaviest interior seed to the second full-budget slot (paired with
        // the over-smoothing probe + `seed_budget ≥ 2`) lets the multi-start
        // SEE the heavily-penalized basin; Gaussian's lowest-cost keep-best
        // then adopts it only when it scores a strictly lower REML, so this can
        // never worsen a flexible fit. Every risk profile now promotes the
        // interior extremes.
        for profile in GAUSSIAN_FAMILY.into_iter().chain(NON_GAUSSIAN) {
            assert!(
                profile.promotes_interior_seed_extremes(),
                "profile {:?}",
                profile
            );
        }
    }

    // ── keep-best policy flags ────────────────────────────────────────────────

    #[test]
    fn parsimonious_keep_best_only_for_glm_and_survival() {
        for profile in GAUSSIAN_FAMILY {
            assert!(!profile.uses_parsimonious_keep_best(), "profile {:?}", profile);
        }
        for profile in NON_GAUSSIAN {
            assert!(profile.uses_parsimonious_keep_best(), "profile {:?}", profile);
        }
    }

    #[test]
    fn lowest_cost_keep_best_only_for_gaussian_variants() {
        for profile in GAUSSIAN_FAMILY {
            assert!(profile.uses_lowest_cost_keep_best(), "profile {:?}", profile);
        }
        for profile in NON_GAUSSIAN {
            assert!(!profile.uses_lowest_cost_keep_best(), "profile {:?}", profile);
        }
    }

    // ── normalize_seed_bounds ─────────────────────────────────────────────────

    #[test]
    fn normalize_already_ordered_bounds_unchanged() {
        let normalized = normalize_seed_bounds(ORDERED_BOUNDS);
        assert_eq!(normalized, (-3.0, 5.0));
    }

    #[test]
    fn normalize_reversed_bounds_swaps() {
        let normalized = normalize_seed_bounds((5.0, -3.0));
        assert_eq!(normalized, (-3.0, 5.0));
    }

    #[test]
    fn normalize_equal_bounds_unchanged() {
        let normalized = normalize_seed_bounds((2.0, 2.0));
        assert_eq!(normalized, (2.0, 2.0));
    }

    // ── clamp_seed_rho_to_bounds ──────────────────────────────────────────────

    #[test]
    fn clamp_within_bounds_returns_value() {
        let clamped = clamp_seed_rho_to_bounds(1.0, ORDERED_BOUNDS);
        assert_eq!(clamped, 1.0);
    }

    #[test]
    fn clamp_below_lo_returns_lo() {
        let clamped = clamp_seed_rho_to_bounds(-10.0, ORDERED_BOUNDS);
        assert_eq!(clamped, -3.0);
    }

    #[test]
    fn clamp_above_hi_returns_hi() {
        let clamped = clamp_seed_rho_to_bounds(100.0, ORDERED_BOUNDS);
        assert_eq!(clamped, 5.0);
    }

    #[test]
    fn clamp_normalizes_reversed_bounds_before_clamping() {
        // Reversed bounds (5, -3) are ordered to (-3, 5) first, so 100 clamps to 5.
        let clamped = clamp_seed_rho_to_bounds(100.0, (5.0, -3.0));
        assert_eq!(clamped, 5.0);
    }
}