// radiate_core/rate.rs

1use crate::stats::expression::{Evaluate, Expr};
2use crate::{MetricSet, Valid};
3use std::fmt::Debug;
4
/// The waveform a cyclical rate follows within each period.
#[derive(Debug, PartialEq, Clone)]
pub enum CycleShape {
    /// Rises linearly from the low bound to the high bound at the half-period,
    /// then falls linearly back to the low bound.
    Triangle,
    /// Follows the absolute value of a sine wave over the period.
    Sine,
}
10
11/// Rate enum representing different types of rate schedules where each variant defines a
12/// method to compute the rate value at a given step.
13/// These are designed to produce values within the range [0.0, 1.0] - ie: a rate.
14#[derive(Clone)]
15pub enum Rate {
16    /// A fixed rate that does not change over time.
17    ///
18    /// # Parameters
19    /// - `f32`: The fixed rate value.
20    Fixed(f32),
21    /// A linear rate that changes from start to end over a number of steps.
22    ///
23    /// # Parameters
24    /// - `start`: The starting rate value.
25    /// - `end`: The ending rate value.
26    /// - `steps`: The number of steps over which to change the rate.
27    Linear(f32, f32, usize),
28    /// An exponential rate that changes from start to end over a half-life period.
29    ///
30    /// # Parameters
31    /// - `start`: The starting rate value.
32    /// - `end`: The ending rate value.
33    /// - `half_life`: The half-life period over which to change the rate.
34    Exponential(f32, f32, usize),
35    /// A cyclical rate that oscillates between min and max over a period.
36    ///
37    /// # Parameters
38    /// - `min`: The minimum rate value.
39    /// - `max`: The maximum rate value.
40    /// - `period`: The period over which to cycle the rate.
41    /// - `shape`: The shape of the cycle (Triangle or Sine).
42    Cyclical(f32, f32, usize, CycleShape),
43    /// Piecewise-constant schedule: at each listed step, rate jumps to the given value.
44    /// The value remains constant until the next listed step.
45    /// The first step must be 0.
46    /// If the current step is beyond the last listed step, the rate remains at the last value.
47    ///
48    /// # Parameters
49    /// - `Vec<(usize, f32)>`: A vector of (step, rate) pairs.
50    Stepwise(Vec<(usize, f32)>),
51
52    /// A rate defined by an expression that can query metrics.
53    /// The expression should evaluate to a float value representing the rate.
54    /// The expression can use the provided metrics to compute a dynamic rate based on the current state of the ecosystem.
55    /// The expression is expected to return a value in the range [0.0, 1.0], but this is not enforced at compile time.
56    Expr(Expr),
57}
58
59impl Rate {
60    pub fn get(&mut self, generation: usize, metrics: &MetricSet) -> f32 {
61        match self {
62            Rate::Expr(expr) => expr
63                .eval(metrics)
64                .ok()
65                .and_then(|v| v.extract())
66                .unwrap_or(0.0),
67            _ => self.get_by_index(generation),
68        }
69    }
70
71    pub fn get_by_index(&self, step: usize) -> f32 {
72        let f_step = step as f32;
73        match self {
74            Rate::Fixed(v) => *v,
75            Rate::Linear(start, end, steps) => {
76                if *steps == 0 {
77                    return *end;
78                }
79
80                let t = (f_step / *steps as f32).min(1.0);
81                start + (end - start) * t
82            }
83            Rate::Exponential(start, end, half_life) => {
84                if *half_life == 0 {
85                    return *end;
86                }
87
88                let decay = 0.5_f32.powf(f_step / *half_life as f32);
89                end + (start - end) * decay
90            }
91            Rate::Cyclical(min, max, period, shape) => {
92                let phase = (f_step % *period as f32) / *period as f32;
93                let tri = if phase < 0.5 {
94                    phase * 2.0
95                } else {
96                    (1.0 - phase) * 2.0
97                };
98
99                let s = match shape {
100                    CycleShape::Triangle => tri,
101                    CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
102                };
103
104                min + (max - min) * s
105            }
106            Rate::Stepwise(steps) => {
107                if steps.is_empty() {
108                    return 0.0;
109                }
110
111                let mut last_value = steps[0].1;
112                for (s, v) in steps {
113                    if step < *s {
114                        break;
115                    }
116
117                    last_value = *v;
118                }
119
120                last_value
121            }
122            _ => 1.0,
123        }
124    }
125}
126
127impl Valid for Rate {
128    fn is_valid(&self) -> bool {
129        match self {
130            Rate::Fixed(v) => (0.0..=1.0).contains(v),
131            Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
132            Rate::Exponential(start, end, _) => {
133                (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
134            }
135            Rate::Cyclical(min, max, _, _) => {
136                (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
137            }
138            Rate::Stepwise(steps) => {
139                if steps.is_empty() {
140                    return false;
141                }
142
143                if steps[0].0 != 0 {
144                    return false;
145                }
146
147                let mut last_step = 0;
148                for (s, v) in steps {
149                    if *s < last_step || !(0.0..=1.0).contains(v) {
150                        return false;
151                    }
152                    last_step = *s;
153                }
154
155                true
156            }
157            _ => true,
158        }
159    }
160}
161
162impl Default for Rate {
163    fn default() -> Self {
164        Rate::Fixed(1.0)
165    }
166}
167
168impl From<f32> for Rate {
169    fn from(value: f32) -> Self {
170        Rate::Fixed(value)
171    }
172}
173
174impl From<Vec<(usize, f32)>> for Rate {
175    fn from(steps: Vec<(usize, f32)>) -> Self {
176        Rate::Stepwise(steps)
177    }
178}
179
180impl From<Expr> for Rate {
181    fn from(expr: Expr) -> Self {
182        Rate::Expr(expr.compile())
183    }
184}
185
186impl Debug for Rate {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        match self {
189            Rate::Fixed(v) => write!(f, "Rate::Fixed({})", v),
190            Rate::Linear(start, end, steps) => {
191                write!(
192                    f,
193                    "Rate::Linear(start: {}, end: {}, steps: {})",
194                    start, end, steps
195                )
196            }
197            Rate::Exponential(start, end, half_life) => write!(
198                f,
199                "Rate::Exponential(start: {}, end: {}, half_life: {})",
200                start, end, half_life
201            ),
202            Rate::Cyclical(min, max, period, shape) => write!(
203                f,
204                "Rate::Cyclical(min: {}, max: {}, period: {}, shape: {:?})",
205                min, max, period, shape
206            ),
207            Rate::Stepwise(steps) => write!(f, "Rate::Stepwise(steps: {:?})", steps),
208            Rate::Expr(_) => write!(f, "Rate::Expr(<function>)"),
209        }
210    }
211}
212
213impl PartialEq for Rate {
214    fn eq(&self, other: &Self) -> bool {
215        match (self, other) {
216            (Rate::Fixed(a), Rate::Fixed(b)) => a == b,
217            (Rate::Linear(a_start, a_end, a_steps), Rate::Linear(b_start, b_end, b_steps)) => {
218                a_start == b_start && a_end == b_end && a_steps == b_steps
219            }
220            (
221                Rate::Exponential(a_start, a_end, a_half_life),
222                Rate::Exponential(b_start, b_end, b_half_life),
223            ) => a_start == b_start && a_end == b_end && a_half_life == b_half_life,
224            (
225                Rate::Cyclical(a_min, a_max, a_period, a_shape),
226                Rate::Cyclical(b_min, b_max, b_period, b_shape),
227            ) => a_min == b_min && a_max == b_max && a_period == b_period && a_shape == b_shape,
228            (Rate::Stepwise(a_steps), Rate::Stepwise(b_steps)) => a_steps == b_steps,
229            // For Expr variants, we consider them equal if they are the same variant,
230            // since we cannot compare the inner function for equality.
231            (Rate::Expr(_), Rate::Expr(_)) => true,
232            _ => false,
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        assert_eq!(fixed.get_by_index(0), 0.5);
        assert_eq!(fixed.get_by_index(10), 0.5);

        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.get_by_index(0), 0.0);
        assert_eq!(linear.get_by_index(5), 0.5);
        assert_eq!(linear.get_by_index(10), 1.0);
        assert_eq!(linear.get_by_index(15), 1.0);

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert!((exponential.get_by_index(0) - 1.0).abs() < 1e-6);
        assert!((exponential.get_by_index(5) - 0.55).abs() < 1e-2);
        assert!((exponential.get_by_index(10) - 0.325).abs() < 1e-2);

        let cyclical = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        assert!((cyclical.get_by_index(0) - 0.0).abs() < 1e-6);
        assert!((cyclical.get_by_index(2) - 0.4).abs() < 1e-6);
        assert!((cyclical.get_by_index(5) - 1.0).abs() < 1e-6);
        assert!((cyclical.get_by_index(7) - 0.6).abs() < 1e-6);
        assert!((cyclical.get_by_index(10) - 0.0).abs() < 1e-6);

        let cyclical_sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        assert!((cyclical_sine.get_by_index(0) - 0.0).abs() < 1e-6);
        assert!(
            (cyclical_sine.get_by_index(2) - (std::f32::consts::TAU * 0.2).sin().abs()).abs()
                < 1e-6
        );
        assert!(cyclical_sine.get_by_index(5).abs() < 1e-6);
        assert!(
            (cyclical_sine.get_by_index(7) - (std::f32::consts::TAU * 0.7).sin().abs()).abs()
                < 1e-6
        );
        assert!((cyclical_sine.get_by_index(10) - 0.0).abs() < 1e-6);

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        assert_eq!(stepwise.get_by_index(0), 0.0);
        assert_eq!(stepwise.get_by_index(3), 0.0);
        assert_eq!(stepwise.get_by_index(5), 0.5);
        assert_eq!(stepwise.get_by_index(7), 0.5);
        assert_eq!(stepwise.get_by_index(10), 1.0);
        assert_eq!(stepwise.get_by_index(15), 1.0);
    }

    #[test]
    fn test_rates_between_0_and_1() {
        let schedules = [
            ("fixed", Rate::Fixed(0.5)),
            ("linear", Rate::Linear(0.0, 1.0, 100)),
            ("exponential", Rate::Exponential(1.0, 0.0, 50)),
            (
                "cyclical",
                Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            ),
            (
                "cyclical_sine",
                Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            ),
            (
                "stepwise",
                Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
            ),
        ];

        for step in 0..100_000 {
            for (name, rate) in &schedules {
                let value = rate.get_by_index(step);
                // Report which schedule escaped the unit interval, and where.
                assert!(
                    (0.0..=1.0).contains(&value),
                    "{} produced {} at step {}",
                    name,
                    value,
                    step
                );
            }
        }
    }

    #[test]
    fn test_rate_clamping() {
        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.get_by_index(15), 1.0);

        // A decreasing linear schedule must clamp at its end value as well.
        let decreasing = Rate::Linear(1.0, 0.0, 10);
        assert_eq!(decreasing.get_by_index(20), 0.0);

        // Far past many half-lives, an exponential schedule settles on its target.
        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert!((exponential.get_by_index(1_000) - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_zero_length_schedules() {
        assert_eq!(Rate::Linear(0.2, 0.8, 0).get_by_index(0), 0.8);
        assert_eq!(Rate::Exponential(0.9, 0.1, 0).get_by_index(3), 0.1);
        assert_eq!(Rate::Stepwise(vec![]).get_by_index(7), 0.0);
    }

    #[test]
    fn test_default_rate() {
        let default_rate = Rate::default();
        assert_eq!(default_rate.get_by_index(0), 1.0);
        assert_eq!(default_rate.get_by_index(100), 1.0);
    }

    #[test]
    fn test_rate_conversions() {
        assert_eq!(Rate::from(0.25_f32), Rate::Fixed(0.25));

        let steps: Vec<(usize, f32)> = vec![(0, 0.1), (4, 0.9)];
        assert_eq!(Rate::from(steps.clone()), Rate::Stepwise(steps));

        assert_eq!(Rate::default(), Rate::Fixed(1.0));
    }

    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());
        assert!(!Rate::Fixed(f32::NAN).is_valid());

        assert!(Rate::Linear(0.0, 1.0, 10).is_valid());
        assert!(!Rate::Linear(-0.1, 1.0, 10).is_valid());
        assert!(Rate::Exponential(1.0, 0.1, 5).is_valid());
        assert!(!Rate::Exponential(1.0, 1.5, 10).is_valid());

        assert!(Rate::Cyclical(0.2, 0.8, 10, CycleShape::Sine).is_valid());
        assert!(!Rate::Cyclical(0.8, 0.2, 10, CycleShape::Triangle).is_valid());

        // Repeated steps are allowed; ordering only has to be non-decreasing.
        assert!(Rate::Stepwise(vec![(0, 0.1), (5, 0.5), (5, 0.6)]).is_valid());
        assert!(!Rate::Stepwise(vec![]).is_valid());
        assert!(!Rate::Stepwise(vec![(1, 0.5)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.5), (10, 0.2), (5, 0.3)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.5), (5, 1.2)]).is_valid());
    }
}
329        let invalid_fixed = Rate::Fixed(1.5);
330        assert!(valid_fixed.is_valid());
331        assert!(!invalid_fixed.is_valid());
332    }
333}