// gam_problem/seeding.rs

/// Risk profile that selects the seeding heuristics for the log-smoothing
/// parameters (`rho`).
///
/// Each profile maps to a fixed set of tables exposed by the accessor methods
/// below: an anchor shift, baseline centres, global shifts, an exploratory
/// amplitude, whether interior seed extremes are promoted, and exactly one
/// keep-best policy (lowest cost for the Gaussian variants, parsimonious for
/// the rest).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeedRiskProfile {
    /// Gaussian response: no anchor shift, baseline centres symmetric about
    /// zero, no promotion of interior seed extremes, lowest-cost keep-best.
    Gaussian,
    /// Gaussian location-scale keeps Gaussian's lowest-REML keep-best policy,
    /// but its non-profiled log-scale predictor has the same capped-screening
    /// over-smoothing risk as multi-parameter likelihoods.
    GaussianLocationScale,
    /// Generalized-linear profile: anchor shifted by `+1`, baseline centres
    /// weighted towards positive `rho`, parsimonious keep-best.
    GeneralizedLinear,
    /// Survival profile: anchor shifted by `+2`, non-negative baseline
    /// centres, parsimonious keep-best.
    Survival,
}

impl SeedRiskProfile {
    /// Offset associated with the anchor seed in `rho` space.
    ///
    /// `0` for both Gaussian variants, `1` for generalized linear and `2` for
    /// survival.
    #[inline]
    pub const fn anchor_rho_shift(self) -> f64 {
        match self {
            Self::Survival => 2.0,
            Self::GeneralizedLinear => 1.0,
            Self::Gaussian | Self::GaussianLocationScale => 0.0,
        }
    }

    /// Baseline `rho` centres for this profile, in a fixed table order.
    ///
    /// The Gaussian variants spread symmetrically around zero. The other
    /// profiles lean towards positive `rho`.
    #[inline]
    pub const fn baseline_centers(self) -> &'static [f64] {
        match self {
            Self::Survival => &[0.0, 2.0, 4.0, 6.0],
            Self::GeneralizedLinear => &[0.0, 2.0, 4.0, -2.0],
            Self::Gaussian | Self::GaussianLocationScale => &[0.0, -3.0, 3.0, -6.0, 6.0],
        }
    }

    /// Global `rho` shifts for this profile, in a fixed table order.
    #[inline]
    pub const fn global_shifts(self) -> &'static [f64] {
        match self {
            Self::Survival => &[0.0, 2.0, 4.0, 6.0, -2.0, -4.0],
            Self::GeneralizedLinear => &[0.0, 2.0, 4.0, -1.0, -2.0, -4.0],
            Self::Gaussian | Self::GaussianLocationScale => &[-2.0, 2.0, -4.0, 4.0],
        }
    }

    /// Amplitude used for exploratory seeds.
    ///
    /// `2.0` for the Gaussian variants, `2.5` for generalized linear and `3.0`
    /// for survival.
    #[inline]
    pub const fn exploratory_amplitude(self) -> f64 {
        match self {
            Self::Survival => 3.0,
            Self::GeneralizedLinear => 2.5,
            Self::Gaussian | Self::GaussianLocationScale => 2.0,
        }
    }

    /// Whether interior seed extremes are promoted.
    ///
    /// True for every profile except plain [`SeedRiskProfile::Gaussian`].
    #[inline]
    pub const fn promotes_interior_seed_extremes(self) -> bool {
        match self {
            Self::Gaussian => false,
            Self::GaussianLocationScale | Self::GeneralizedLinear | Self::Survival => true,
        }
    }

    /// Whether the parsimonious keep-best policy applies.
    ///
    /// True for generalized linear and survival. Always the opposite of
    /// [`SeedRiskProfile::uses_lowest_cost_keep_best`].
    #[inline]
    pub const fn uses_parsimonious_keep_best(self) -> bool {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => false,
            Self::GeneralizedLinear | Self::Survival => true,
        }
    }

    /// Whether the lowest-cost keep-best policy applies.
    ///
    /// True for both Gaussian variants. Always the opposite of
    /// [`SeedRiskProfile::uses_parsimonious_keep_best`].
    #[inline]
    pub const fn uses_lowest_cost_keep_best(self) -> bool {
        match self {
            Self::GeneralizedLinear | Self::Survival => false,
            Self::Gaussian | Self::GaussianLocationScale => true,
        }
    }
}
67
68#[derive(Clone, Copy, Debug)]
69pub struct SeedConfig {
70    pub bounds: (f64, f64),
71    pub max_seeds: usize,
72    /// Maximum number of seed starts to run in heuristic order.
73    pub seed_budget: usize,
74    /// Initial inner-iteration cap used while ranking candidate seeds.
75    pub screen_max_inner_iterations: usize,
76    pub risk_profile: SeedRiskProfile,
77    /// Number of trailing dimensions that are auxiliary parameters rather than
78    /// log-smoothing parameters.
79    pub num_auxiliary_trailing: usize,
80    /// Optional absolute over-smoothing probe on every smoothing dimension.
81    pub over_smoothing_probe_rho: Option<f64>,
82}
83
84impl Default for SeedConfig {
85    fn default() -> Self {
86        Self {
87            bounds: (-12.0, 12.0),
88            max_seeds: 12,
89            seed_budget: 2,
90            screen_max_inner_iterations: 3,
91            risk_profile: SeedRiskProfile::GeneralizedLinear,
92            num_auxiliary_trailing: 0,
93            over_smoothing_probe_rho: None,
94        }
95    }
96}
97
/// Returns `bounds` ordered as `(lo, hi)`.
///
/// The pair is swapped only when it is provably reversed (`bounds.1 < bounds.0`).
/// Equal endpoints and pairs containing NaN are returned in their original
/// order, so a finite endpoint never changes sides just because the other one
/// is NaN.
#[inline]
pub fn normalize_seed_bounds(bounds: (f64, f64)) -> (f64, f64) {
    if bounds.1 < bounds.0 {
        (bounds.1, bounds.0)
    } else {
        bounds
    }
}

/// Clamps a seed `rho` value into `bounds`, accepting the endpoints in either
/// order.
///
/// For well-formed (non-NaN) bounds this returns exactly what
/// `value.clamp(lo, hi)` would after [`normalize_seed_bounds`]. Unlike
/// [`f64::clamp`], it never panics:
/// - a NaN bound imposes no constraint on its side;
/// - a NaN `value` is returned unchanged.
#[inline]
pub fn clamp_seed_rho_to_bounds(value: f64, bounds: (f64, f64)) -> f64 {
    let (lo, hi) = normalize_seed_bounds(bounds);
    // `f64::clamp` asserts that neither bound is NaN. These comparisons match
    // its result whenever lo <= hi, which normalization guarantees for non-NaN
    // bounds. Any comparison against NaN is false, so a NaN bound is skipped.
    let mut clamped = value;
    if clamped < lo {
        clamped = lo;
    }
    if clamped > hi {
        clamped = hi;
    }
    clamped
}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// All risk-profile variants, for table-driven checks.
    const ALL_PROFILES: [SeedRiskProfile; 4] = [
        SeedRiskProfile::Gaussian,
        SeedRiskProfile::GaussianLocationScale,
        SeedRiskProfile::GeneralizedLinear,
        SeedRiskProfile::Survival,
    ];

    // ── anchor_rho_shift ───────────────────────────────────────────────────────

    #[test]
    fn anchor_rho_shift_gaussian_is_zero() {
        assert_eq!(SeedRiskProfile::Gaussian.anchor_rho_shift(), 0.0);
        assert_eq!(SeedRiskProfile::GaussianLocationScale.anchor_rho_shift(), 0.0);
    }

    #[test]
    fn anchor_rho_shift_generalized_linear_is_one() {
        assert_eq!(SeedRiskProfile::GeneralizedLinear.anchor_rho_shift(), 1.0);
    }

    #[test]
    fn anchor_rho_shift_survival_is_two() {
        assert_eq!(SeedRiskProfile::Survival.anchor_rho_shift(), 2.0);
    }

    // ── seeding tables ─────────────────────────────────────────────────────────

    #[test]
    fn baseline_centers_per_profile() {
        assert_eq!(
            SeedRiskProfile::Gaussian.baseline_centers(),
            &[0.0, -3.0, 3.0, -6.0, 6.0][..]
        );
        assert_eq!(
            SeedRiskProfile::GeneralizedLinear.baseline_centers(),
            &[0.0, 2.0, 4.0, -2.0][..]
        );
        assert_eq!(
            SeedRiskProfile::Survival.baseline_centers(),
            &[0.0, 2.0, 4.0, 6.0][..]
        );
    }

    #[test]
    fn global_shifts_per_profile() {
        assert_eq!(
            SeedRiskProfile::Gaussian.global_shifts(),
            &[-2.0, 2.0, -4.0, 4.0][..]
        );
        assert_eq!(
            SeedRiskProfile::GeneralizedLinear.global_shifts(),
            &[0.0, 2.0, 4.0, -1.0, -2.0, -4.0][..]
        );
        assert_eq!(
            SeedRiskProfile::Survival.global_shifts(),
            &[0.0, 2.0, 4.0, 6.0, -2.0, -4.0][..]
        );
    }

    #[test]
    fn exploratory_amplitude_per_profile() {
        assert_eq!(SeedRiskProfile::Gaussian.exploratory_amplitude(), 2.0);
        assert_eq!(SeedRiskProfile::GeneralizedLinear.exploratory_amplitude(), 2.5);
        assert_eq!(SeedRiskProfile::Survival.exploratory_amplitude(), 3.0);
    }

    #[test]
    fn gaussian_variants_share_seeding_tables() {
        let plain = SeedRiskProfile::Gaussian;
        let location_scale = SeedRiskProfile::GaussianLocationScale;
        assert_eq!(plain.anchor_rho_shift(), location_scale.anchor_rho_shift());
        assert_eq!(plain.baseline_centers(), location_scale.baseline_centers());
        assert_eq!(plain.global_shifts(), location_scale.global_shifts());
        assert_eq!(
            plain.exploratory_amplitude(),
            location_scale.exploratory_amplitude()
        );
    }

    #[test]
    fn seeding_tables_are_finite_and_amplitudes_positive() {
        for &profile in ALL_PROFILES.iter() {
            assert!(
                profile
                    .baseline_centers()
                    .iter()
                    .chain(profile.global_shifts().iter())
                    .all(|v| v.is_finite()),
                "{:?}",
                profile
            );
            assert!(profile.exploratory_amplitude() > 0.0, "{:?}", profile);
        }
    }

    // ── promotes_interior_seed_extremes ───────────────────────────────────────

    #[test]
    fn promotes_interior_extremes_false_for_gaussian_only() {
        assert!(!SeedRiskProfile::Gaussian.promotes_interior_seed_extremes());
        assert!(SeedRiskProfile::GaussianLocationScale.promotes_interior_seed_extremes());
        assert!(SeedRiskProfile::GeneralizedLinear.promotes_interior_seed_extremes());
        assert!(SeedRiskProfile::Survival.promotes_interior_seed_extremes());
    }

    // ── keep-best policy flags ────────────────────────────────────────────────

    #[test]
    fn parsimonious_keep_best_only_for_glm_and_survival() {
        assert!(!SeedRiskProfile::Gaussian.uses_parsimonious_keep_best());
        assert!(!SeedRiskProfile::GaussianLocationScale.uses_parsimonious_keep_best());
        assert!(SeedRiskProfile::GeneralizedLinear.uses_parsimonious_keep_best());
        assert!(SeedRiskProfile::Survival.uses_parsimonious_keep_best());
    }

    #[test]
    fn lowest_cost_keep_best_only_for_gaussian_variants() {
        assert!(SeedRiskProfile::Gaussian.uses_lowest_cost_keep_best());
        assert!(SeedRiskProfile::GaussianLocationScale.uses_lowest_cost_keep_best());
        assert!(!SeedRiskProfile::GeneralizedLinear.uses_lowest_cost_keep_best());
        assert!(!SeedRiskProfile::Survival.uses_lowest_cost_keep_best());
    }

    #[test]
    fn keep_best_policies_are_complementary() {
        // Every profile must select exactly one keep-best policy.
        for &profile in ALL_PROFILES.iter() {
            assert_ne!(
                profile.uses_parsimonious_keep_best(),
                profile.uses_lowest_cost_keep_best(),
                "{:?}",
                profile
            );
        }
    }

    // ── SeedConfig::default ───────────────────────────────────────────────────

    #[test]
    fn default_config_values() {
        let cfg = SeedConfig::default();
        assert_eq!(cfg.bounds, (-12.0, 12.0));
        assert_eq!(cfg.max_seeds, 12);
        assert_eq!(cfg.seed_budget, 2);
        assert_eq!(cfg.screen_max_inner_iterations, 3);
        assert_eq!(cfg.risk_profile, SeedRiskProfile::GeneralizedLinear);
        assert_eq!(cfg.num_auxiliary_trailing, 0);
        assert_eq!(cfg.over_smoothing_probe_rho, None);
        assert!(cfg.seed_budget <= cfg.max_seeds);
        assert_eq!(normalize_seed_bounds(cfg.bounds), cfg.bounds);
    }

    // ── normalize_seed_bounds ─────────────────────────────────────────────────

    #[test]
    fn normalize_already_ordered_bounds_unchanged() {
        assert_eq!(normalize_seed_bounds((-3.0, 5.0)), (-3.0, 5.0));
    }

    #[test]
    fn normalize_reversed_bounds_swaps() {
        assert_eq!(normalize_seed_bounds((5.0, -3.0)), (-3.0, 5.0));
    }

    #[test]
    fn normalize_equal_bounds_unchanged() {
        assert_eq!(normalize_seed_bounds((2.0, 2.0)), (2.0, 2.0));
    }

    // ── clamp_seed_rho_to_bounds ──────────────────────────────────────────────

    #[test]
    fn clamp_within_bounds_returns_value() {
        assert_eq!(clamp_seed_rho_to_bounds(1.0, (-3.0, 5.0)), 1.0);
    }

    #[test]
    fn clamp_below_lo_returns_lo() {
        assert_eq!(clamp_seed_rho_to_bounds(-10.0, (-3.0, 5.0)), -3.0);
    }

    #[test]
    fn clamp_above_hi_returns_hi() {
        assert_eq!(clamp_seed_rho_to_bounds(100.0, (-3.0, 5.0)), 5.0);
    }

    #[test]
    fn clamp_normalizes_reversed_bounds_before_clamping() {
        // bounds (5, -3) normalizes to (-3, 5); 100 clamps to 5
        assert_eq!(clamp_seed_rho_to_bounds(100.0, (5.0, -3.0)), 5.0);
    }

    #[test]
    fn clamp_passes_nan_value_through() {
        assert!(clamp_seed_rho_to_bounds(f64::NAN, (-3.0, 5.0)).is_nan());
    }
}