// gam_problem/seeding.rs

/// Risk profile that selects the seed-generation heuristics for a model family.
///
/// Each profile picks its own anchor `rho` shift, seed centers, global shifts,
/// exploratory amplitude and keep-best policy through the `const fn`
/// accessors defined on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedRiskProfile {
    /// Gaussian likelihood; uses the lowest-REML (lowest-cost) keep-best policy.
    Gaussian,
    /// Gaussian location-scale keeps Gaussian's lowest-REML keep-best policy,
    /// but its non-profiled log-scale predictor has the same capped-screening
    /// over-smoothing risk as multi-parameter likelihoods.
    GaussianLocationScale,
    /// Generalized linear model likelihoods.
    GeneralizedLinear,
    /// Survival likelihoods.
    Survival,
}

12impl SeedRiskProfile {
13    #[inline]
14    pub const fn anchor_rho_shift(self) -> f64 {
15        match self {
16            Self::Gaussian | Self::GaussianLocationScale => 0.0,
17            Self::GeneralizedLinear => 1.0,
18            Self::Survival => 2.0,
19        }
20    }
21
22    #[inline]
23    pub const fn baseline_centers(self) -> &'static [f64] {
24        match self {
25            Self::Gaussian | Self::GaussianLocationScale => &[0.0, -3.0, 3.0, -6.0, 6.0],
26            Self::GeneralizedLinear => &[0.0, 2.0, 4.0, -2.0],
27            Self::Survival => &[0.0, 2.0, 4.0, 6.0],
28        }
29    }
30
31    #[inline]
32    pub const fn global_shifts(self) -> &'static [f64] {
33        match self {
34            Self::Gaussian | Self::GaussianLocationScale => &[-2.0, 2.0, -4.0, 4.0],
35            Self::GeneralizedLinear => &[0.0, 2.0, 4.0, -1.0, -2.0, -4.0],
36            Self::Survival => &[0.0, 2.0, 4.0, 6.0, -2.0, -4.0],
37        }
38    }
39
40    #[inline]
41    pub const fn exploratory_amplitude(self) -> f64 {
42        match self {
43            Self::Gaussian | Self::GaussianLocationScale => 2.0,
44            Self::GeneralizedLinear => 2.5,
45            Self::Survival => 3.0,
46        }
47    }
48
49    #[inline]
50    pub const fn promotes_interior_seed_extremes(self) -> bool {
51        matches!(
52            self,
53            Self::GaussianLocationScale | Self::GeneralizedLinear | Self::Survival
54        )
55    }
56
57    #[inline]
58    pub const fn uses_parsimonious_keep_best(self) -> bool {
59        matches!(self, Self::GeneralizedLinear | Self::Survival)
60    }
61
62    #[inline]
63    pub const fn uses_lowest_cost_keep_best(self) -> bool {
64        matches!(self, Self::Gaussian | Self::GaussianLocationScale)
65    }
66}
67
68#[derive(Clone, Copy, Debug)]
69pub struct SeedConfig {
70    pub bounds: (f64, f64),
71    pub max_seeds: usize,
72    /// Maximum number of seed starts to run in heuristic order.
73    pub seed_budget: usize,
74    /// Initial inner-iteration cap used while ranking candidate seeds.
75    pub screen_max_inner_iterations: usize,
76    pub risk_profile: SeedRiskProfile,
77    /// Number of trailing dimensions that are auxiliary parameters rather than
78    /// log-smoothing parameters.
79    pub num_auxiliary_trailing: usize,
80    /// Optional absolute over-smoothing probe on every smoothing dimension.
81    pub over_smoothing_probe_rho: Option<f64>,
82}
83
84impl Default for SeedConfig {
85    fn default() -> Self {
86        Self {
87            bounds: (-12.0, 12.0),
88            max_seeds: 12,
89            seed_budget: 2,
90            screen_max_inner_iterations: 3,
91            risk_profile: SeedRiskProfile::GeneralizedLinear,
92            num_auxiliary_trailing: 0,
93            over_smoothing_probe_rho: None,
94        }
95    }
96}
97
/// Returns `bounds` as an ordered `(lo, hi)` pair.
///
/// The pair is returned unchanged when `bounds.0 <= bounds.1`. Otherwise the
/// two elements are swapped. Because any comparison with NaN is false, a
/// pair containing NaN is always swapped.
#[inline]
pub fn normalize_seed_bounds(bounds: (f64, f64)) -> (f64, f64) {
    let (first, second) = bounds;
    let already_ordered = first <= second;
    if already_ordered {
        (first, second)
    } else {
        (second, first)
    }
}

107#[inline]
108pub fn clamp_seed_rho_to_bounds(value: f64, bounds: (f64, f64)) -> f64 {
109    let (lo, hi) = normalize_seed_bounds(bounds);
110    value.clamp(lo, hi)
111}