1use crate::stats::expression::{Evaluate, Expr};
2use crate::{MetricSet, Valid};
3use std::fmt::Debug;
4
/// Waveform used by [`Rate::Cyclical`] to move between its lower and upper bounds.
///
/// This is a fieldless enum, so it is also `Copy`, `Eq` and `Hash`. It can be passed by
/// value and used as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CycleShape {
    /// Rises linearly from the lower bound at the start of the cycle to the upper bound at
    /// mid-cycle, then falls linearly back to the lower bound.
    Triangle,
    /// Follows `|sin(2π · phase)|`, where `phase` is the fractional position in the cycle.
    Sine,
}
10
/// A schedule producing a scalar rate that may vary with the generation (step) index.
///
/// Every variant except `Expr` is a pure function of the step and is evaluated by
/// [`Rate::get_by_index`]. `Expr` is evaluated against a `MetricSet` by [`Rate::get`].
/// `Valid::is_valid` accepts the numeric variants only when their configured values lie in
/// `0.0..=1.0`.
#[derive(Clone)]
pub enum Rate {
    /// A constant value, returned for every step.
    Fixed(f32),
    /// `(start, end, steps)`: linear interpolation from `start` at step 0 to `end` at step
    /// `steps`, then held at `end`.
    Linear(f32, f32, usize),
    /// `(start, end, half_life)`: moves from `start` toward `end`, halving the remaining
    /// distance every `half_life` steps.
    Exponential(f32, f32, usize),
    /// `(min, max, period, shape)`: oscillates between `min` and `max` over cycles of
    /// `period` steps, following the waveform `shape`.
    Cyclical(f32, f32, usize, CycleShape),
    /// `(start_step, value)` pairs forming a piecewise-constant schedule.
    ///
    /// At a given step, the value of the last pair whose start step is `<=` that step
    /// applies. `Valid::is_valid` requires a non-empty list that starts at step 0 and has
    /// non-decreasing start steps.
    Stepwise(Vec<(usize, f32)>),

    /// An expression evaluated against the current metrics by [`Rate::get`].
    /// [`Rate::get_by_index`] cannot evaluate it and returns `1.0` instead.
    Expr(Expr),
}
58
59impl Rate {
60 pub fn get(&mut self, generation: usize, metrics: &MetricSet) -> f32 {
61 match self {
62 Rate::Expr(expr) => expr
63 .eval(metrics)
64 .ok()
65 .and_then(|v| v.extract())
66 .unwrap_or(0.0),
67 _ => self.get_by_index(generation),
68 }
69 }
70
71 pub fn get_by_index(&self, step: usize) -> f32 {
72 let f_step = step as f32;
73 match self {
74 Rate::Fixed(v) => *v,
75 Rate::Linear(start, end, steps) => {
76 if *steps == 0 {
77 return *end;
78 }
79
80 let t = (f_step / *steps as f32).min(1.0);
81 start + (end - start) * t
82 }
83 Rate::Exponential(start, end, half_life) => {
84 if *half_life == 0 {
85 return *end;
86 }
87
88 let decay = 0.5_f32.powf(f_step / *half_life as f32);
89 end + (start - end) * decay
90 }
91 Rate::Cyclical(min, max, period, shape) => {
92 let phase = (f_step % *period as f32) / *period as f32;
93 let tri = if phase < 0.5 {
94 phase * 2.0
95 } else {
96 (1.0 - phase) * 2.0
97 };
98
99 let s = match shape {
100 CycleShape::Triangle => tri,
101 CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
102 };
103
104 min + (max - min) * s
105 }
106 Rate::Stepwise(steps) => {
107 if steps.is_empty() {
108 return 0.0;
109 }
110
111 let mut last_value = steps[0].1;
112 for (s, v) in steps {
113 if step < *s {
114 break;
115 }
116
117 last_value = *v;
118 }
119
120 last_value
121 }
122 _ => 1.0,
123 }
124 }
125}
126
127impl Valid for Rate {
128 fn is_valid(&self) -> bool {
129 match self {
130 Rate::Fixed(v) => (0.0..=1.0).contains(v),
131 Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
132 Rate::Exponential(start, end, _) => {
133 (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
134 }
135 Rate::Cyclical(min, max, _, _) => {
136 (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
137 }
138 Rate::Stepwise(steps) => {
139 if steps.is_empty() {
140 return false;
141 }
142
143 if steps[0].0 != 0 {
144 return false;
145 }
146
147 let mut last_step = 0;
148 for (s, v) in steps {
149 if *s < last_step || !(0.0..=1.0).contains(v) {
150 return false;
151 }
152 last_step = *s;
153 }
154
155 true
156 }
157 _ => true,
158 }
159 }
160}
161
162impl Default for Rate {
163 fn default() -> Self {
164 Rate::Fixed(1.0)
165 }
166}
167
168impl From<f32> for Rate {
169 fn from(value: f32) -> Self {
170 Rate::Fixed(value)
171 }
172}
173
174impl From<Vec<(usize, f32)>> for Rate {
175 fn from(steps: Vec<(usize, f32)>) -> Self {
176 Rate::Stepwise(steps)
177 }
178}
179
180impl From<Expr> for Rate {
181 fn from(expr: Expr) -> Self {
182 Rate::Expr(expr.compile())
183 }
184}
185
186impl Debug for Rate {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 match self {
189 Rate::Fixed(v) => write!(f, "Rate::Fixed({})", v),
190 Rate::Linear(start, end, steps) => {
191 write!(
192 f,
193 "Rate::Linear(start: {}, end: {}, steps: {})",
194 start, end, steps
195 )
196 }
197 Rate::Exponential(start, end, half_life) => write!(
198 f,
199 "Rate::Exponential(start: {}, end: {}, half_life: {})",
200 start, end, half_life
201 ),
202 Rate::Cyclical(min, max, period, shape) => write!(
203 f,
204 "Rate::Cyclical(min: {}, max: {}, period: {}, shape: {:?})",
205 min, max, period, shape
206 ),
207 Rate::Stepwise(steps) => write!(f, "Rate::Stepwise(steps: {:?})", steps),
208 Rate::Expr(_) => write!(f, "Rate::Expr(<function>)"),
209 }
210 }
211}
212
213impl PartialEq for Rate {
214 fn eq(&self, other: &Self) -> bool {
215 match (self, other) {
216 (Rate::Fixed(a), Rate::Fixed(b)) => a == b,
217 (Rate::Linear(a_start, a_end, a_steps), Rate::Linear(b_start, b_end, b_steps)) => {
218 a_start == b_start && a_end == b_end && a_steps == b_steps
219 }
220 (
221 Rate::Exponential(a_start, a_end, a_half_life),
222 Rate::Exponential(b_start, b_end, b_half_life),
223 ) => a_start == b_start && a_end == b_end && a_half_life == b_half_life,
224 (
225 Rate::Cyclical(a_min, a_max, a_period, a_shape),
226 Rate::Cyclical(b_min, b_max, b_period, b_shape),
227 ) => a_min == b_min && a_max == b_max && a_period == b_period && a_shape == b_shape,
228 (Rate::Stepwise(a_steps), Rate::Stepwise(b_steps)) => a_steps == b_steps,
229 (Rate::Expr(_), Rate::Expr(_)) => true,
232 _ => false,
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        for &step in &[0, 10] {
            assert_eq!(fixed.get_by_index(step), 0.5);
        }

        let linear = Rate::Linear(0.0, 1.0, 10);
        for &(step, expected) in &[(0, 0.0), (5, 0.5), (10, 1.0), (15, 1.0)] {
            assert_eq!(linear.get_by_index(step), expected);
        }

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert_close(exponential.get_by_index(0), 1.0, 1e-6);
        assert_close(exponential.get_by_index(5), 0.55, 1e-2);
        assert_close(exponential.get_by_index(10), 0.325, 1e-2);

        let triangle = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        for &(step, expected) in &[(0, 0.0), (2, 0.4), (5, 1.0), (7, 0.6), (10, 0.0)] {
            assert_close(triangle.get_by_index(step), expected, 1e-6);
        }

        let sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        let tau = std::f32::consts::TAU;
        for &(step, expected) in &[
            (0, 0.0),
            (2, (tau * 0.2).sin().abs()),
            (5, 0.0),
            (7, (tau * 0.7).sin().abs()),
            (10, 0.0),
        ] {
            assert_close(sine.get_by_index(step), expected, 1e-6);
        }

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        for &(step, expected) in &[(0, 0.0), (3, 0.0), (5, 0.5), (7, 0.5), (10, 1.0), (15, 1.0)]
        {
            assert_eq!(stepwise.get_by_index(step), expected);
        }
    }

    #[test]
    fn test_rates_between_0_and_1() {
        let schedules = [
            Rate::Fixed(0.5),
            Rate::Linear(0.0, 1.0, 100),
            Rate::Exponential(1.0, 0.0, 50),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
        ];

        for step in 0..100_000 {
            for schedule in &schedules {
                let value = schedule.get_by_index(step);
                assert!((0.0..=1.0).contains(&value));
            }
        }
    }

    #[test]
    fn test_rate_clamping() {
        // Past its final step, a linear schedule holds at its end value.
        assert_eq!(Rate::Linear(0.0, 1.0, 10).get_by_index(15), 1.0);
    }

    #[test]
    fn test_default_rate() {
        let rate = Rate::default();
        for &step in &[0, 100] {
            assert_eq!(rate.get_by_index(step), 1.0);
        }
    }

    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());
    }
}