/// Risk profile that selects the rho-seeding heuristics exposed by the
/// methods on this type: anchor shift, baseline centers, global shifts,
/// exploratory amplitude, and which keep-best policy applies.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeedRiskProfile {
    Gaussian,
    GaussianLocationScale,
    GeneralizedLinear,
    Survival,
}

impl SeedRiskProfile {
    /// Offset for the anchor rho of this profile: `0.0` for the Gaussian
    /// profiles, `1.0` for `GeneralizedLinear`, and `2.0` for `Survival`.
    ///
    /// How the offset is applied is decided by the caller (presumably added
    /// to the anchor rho — confirm at call sites).
    #[inline]
    pub const fn anchor_rho_shift(self) -> f64 {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => 0.0,
            Self::GeneralizedLinear => 1.0,
            Self::Survival => 2.0,
        }
    }

    /// Static list of baseline rho centers for this profile.
    ///
    /// Every list starts at `0.0`. The Gaussian profiles are symmetric around
    /// zero; the other profiles lean toward positive values.
    #[inline]
    pub const fn baseline_centers(self) -> &'static [f64] {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => &[0.0, -3.0, 3.0, -6.0, 6.0],
            Self::GeneralizedLinear => &[0.0, 2.0, 4.0, -2.0],
            Self::Survival => &[0.0, 2.0, 4.0, 6.0],
        }
    }

    /// Static list of global rho shifts for this profile.
    ///
    /// Unlike the other profiles, the Gaussian list does not contain `0.0`.
    #[inline]
    pub const fn global_shifts(self) -> &'static [f64] {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => &[-2.0, 2.0, -4.0, 4.0],
            Self::GeneralizedLinear => &[0.0, 2.0, 4.0, -1.0, -2.0, -4.0],
            Self::Survival => &[0.0, 2.0, 4.0, 6.0, -2.0, -4.0],
        }
    }

    /// Amplitude used for exploratory seeds: `2.0` for the Gaussian profiles,
    /// `2.5` for `GeneralizedLinear`, and `3.0` for `Survival`.
    #[inline]
    pub const fn exploratory_amplitude(self) -> f64 {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => 2.0,
            Self::GeneralizedLinear => 2.5,
            Self::Survival => 3.0,
        }
    }

    /// Whether interior seed extremes are promoted for this profile.
    ///
    /// Currently `true` for every profile. Written as an exhaustive `match`
    /// (not `matches!`) so adding a variant forces an explicit decision
    /// instead of silently defaulting to `false`.
    #[inline]
    pub const fn promotes_interior_seed_extremes(self) -> bool {
        match self {
            Self::Gaussian
            | Self::GaussianLocationScale
            | Self::GeneralizedLinear
            | Self::Survival => true,
        }
    }

    /// Whether the parsimonious keep-best policy applies (`GeneralizedLinear`
    /// and `Survival`).
    ///
    /// Exactly one of this and [`Self::uses_lowest_cost_keep_best`] is `true`
    /// for every profile. The `match` is exhaustive so a new variant must opt
    /// in or out explicitly.
    #[inline]
    pub const fn uses_parsimonious_keep_best(self) -> bool {
        match self {
            Self::GeneralizedLinear | Self::Survival => true,
            Self::Gaussian | Self::GaussianLocationScale => false,
        }
    }

    /// Whether the lowest-cost keep-best policy applies (the Gaussian
    /// profiles).
    ///
    /// Exactly one of this and [`Self::uses_parsimonious_keep_best`] is `true`
    /// for every profile. The `match` is exhaustive so a new variant must opt
    /// in or out explicitly.
    #[inline]
    pub const fn uses_lowest_cost_keep_best(self) -> bool {
        match self {
            Self::Gaussian | Self::GaussianLocationScale => true,
            Self::GeneralizedLinear | Self::Survival => false,
        }
    }
}

/// Configuration for rho seed generation.
///
/// Default values for every field are set in the `Default` impl for this type.
#[derive(Clone, Copy, Debug)]
pub struct SeedConfig {
    /// Bounds for seed rho values as a two-element pair. NOTE(review): the
    /// order is presumably normalized by callers (e.g. via
    /// `normalize_seed_bounds`) — confirm at call sites.
    pub bounds: (f64, f64),
    /// Maximum number of seeds. NOTE(review): presumably a cap on generated
    /// candidates, distinct from `seed_budget` — confirm.
    pub max_seeds: usize,
    /// Seed budget. NOTE(review): presumably how many seeds are carried past
    /// screening — confirm against the consumer.
    pub seed_budget: usize,
    /// Inner-iteration limit used while screening seeds.
    pub screen_max_inner_iterations: usize,
    /// Profile that selects the seeding heuristics.
    pub risk_profile: SeedRiskProfile,
    /// Count of auxiliary trailing entries. NOTE(review): presumably trailing
    /// rho coordinates that are not smoothing parameters — confirm.
    pub num_auxiliary_trailing: usize,
    /// Optional rho at which an over-smoothing probe is made; `None` disables
    /// it (presumably — confirm at the consumer).
    pub over_smoothing_probe_rho: Option<f64>,
}

100impl Default for SeedConfig {
101 fn default() -> Self {
102 Self {
103 bounds: (-12.0, 12.0),
104 max_seeds: 12,
105 seed_budget: 2,
106 screen_max_inner_iterations: 3,
107 risk_profile: SeedRiskProfile::GeneralizedLinear,
108 num_auxiliary_trailing: 0,
109 over_smoothing_probe_rho: None,
110 }
111 }
112}
113
/// Returns `bounds` ordered as `(lo, hi)`.
///
/// The pair is swapped only when it is provably reversed (`bounds.1 <
/// bounds.0`). Ordered and equal pairs are returned unchanged. A pair
/// containing NaN is also returned unchanged, so a NaN stays in the slot the
/// caller put it in instead of being moved by a failed comparison.
#[inline]
pub fn normalize_seed_bounds(bounds: (f64, f64)) -> (f64, f64) {
    let (a, b) = bounds;
    if b < a {
        (b, a)
    } else {
        (a, b)
    }
}

/// Clamps `value` into the closed interval described by `bounds`.
///
/// `bounds` may be given in either order; it is ordered with
/// [`normalize_seed_bounds`] before clamping. A NaN bound is treated as
/// unbounded on its side: a NaN in `bounds.0` acts as `-inf`, and a NaN in
/// `bounds.1` acts as `+inf`. This avoids the panic `f64::clamp` raises for
/// NaN limits. A NaN `value` is returned unchanged (as NaN).
#[inline]
pub fn clamp_seed_rho_to_bounds(value: f64, bounds: (f64, f64)) -> f64 {
    // `f64::clamp` panics if either limit is NaN, so replace NaN limits
    // positionally before ordering the pair.
    let lo = if bounds.0.is_nan() { f64::NEG_INFINITY } else { bounds.0 };
    let hi = if bounds.1.is_nan() { f64::INFINITY } else { bounds.1 };
    let (lo, hi) = normalize_seed_bounds((lo, hi));
    value.clamp(lo, hi)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every profile, in declaration order; used by the table-driven checks.
    const ALL_PROFILES: [SeedRiskProfile; 4] = [
        SeedRiskProfile::Gaussian,
        SeedRiskProfile::GaussianLocationScale,
        SeedRiskProfile::GeneralizedLinear,
        SeedRiskProfile::Survival,
    ];

    #[test]
    fn anchor_rho_shift_gaussian_is_zero() {
        let gaussians = [SeedRiskProfile::Gaussian, SeedRiskProfile::GaussianLocationScale];
        for profile in gaussians.iter() {
            assert_eq!(profile.anchor_rho_shift(), 0.0, "{:?}", profile);
        }
    }

    #[test]
    fn anchor_rho_shift_generalized_linear_is_one() {
        let shift = SeedRiskProfile::GeneralizedLinear.anchor_rho_shift();
        assert_eq!(shift, 1.0);
    }

    #[test]
    fn anchor_rho_shift_survival_is_two() {
        let shift = SeedRiskProfile::Survival.anchor_rho_shift();
        assert_eq!(shift, 2.0);
    }

    #[test]
    fn promotes_interior_extremes_for_all_profiles() {
        assert!(ALL_PROFILES
            .iter()
            .all(|profile| profile.promotes_interior_seed_extremes()));
    }

    #[test]
    fn parsimonious_keep_best_only_for_glm_and_survival() {
        let expected = [false, false, true, true];
        for (profile, want) in ALL_PROFILES.iter().zip(expected.iter()) {
            assert_eq!(profile.uses_parsimonious_keep_best(), *want, "{:?}", profile);
        }
    }

    #[test]
    fn lowest_cost_keep_best_only_for_gaussian_variants() {
        let expected = [true, true, false, false];
        for (profile, want) in ALL_PROFILES.iter().zip(expected.iter()) {
            assert_eq!(profile.uses_lowest_cost_keep_best(), *want, "{:?}", profile);
        }
    }

    #[test]
    fn normalize_already_ordered_bounds_unchanged() {
        let ordered = (-3.0, 5.0);
        assert_eq!(normalize_seed_bounds(ordered), ordered);
    }

    #[test]
    fn normalize_reversed_bounds_swaps() {
        let reversed = (5.0, -3.0);
        assert_eq!(normalize_seed_bounds(reversed), (-3.0, 5.0));
    }

    #[test]
    fn normalize_equal_bounds_unchanged() {
        let degenerate = (2.0, 2.0);
        assert_eq!(normalize_seed_bounds(degenerate), degenerate);
    }

    #[test]
    fn clamp_within_bounds_returns_value() {
        let bounds = (-3.0, 5.0);
        assert_eq!(clamp_seed_rho_to_bounds(1.0, bounds), 1.0);
    }

    #[test]
    fn clamp_below_lo_returns_lo() {
        let bounds = (-3.0, 5.0);
        assert_eq!(clamp_seed_rho_to_bounds(-10.0, bounds), bounds.0);
    }

    #[test]
    fn clamp_above_hi_returns_hi() {
        let bounds = (-3.0, 5.0);
        assert_eq!(clamp_seed_rho_to_bounds(100.0, bounds), bounds.1);
    }

    #[test]
    fn clamp_normalizes_reversed_bounds_before_clamping() {
        let reversed = (5.0, -3.0);
        assert_eq!(clamp_seed_rho_to_bounds(100.0, reversed), 5.0);
    }
}