Skip to main content

radiate_core/
rate.rs

1use crate::{MetricSet, Valid};
2use radiate_expr::{ApplyExpr, Expr};
3use std::fmt::Debug;
4
/// Computes a rate value for a given generation.
///
/// The method takes `&mut self`, so an implementor is free to update internal
/// state between calls.
pub trait RateCalculator {
    /// Returns the rate for `generation`, with `metrics` available as input.
    ///
    /// NOTE(review): presumably the result is expected to lie in `[0.0, 1.0]`,
    /// like the values produced by [`Rate`] — confirm against implementors.
    fn rate(&mut self, generation: usize, metrics: &MetricSet) -> f32;
}
8
/// Waveform used by [`Rate::Cyclical`] to move between its minimum and maximum.
#[derive(Clone, Debug, PartialEq)]
pub enum CycleShape {
    /// Rises linearly from the minimum to the maximum over the first half of the
    /// period, then falls linearly back to the minimum over the second half.
    Triangle,
    /// Follows `|sin(TAU * phase)|`, where `phase` is the elapsed fraction of the
    /// period; because of the absolute value the maximum is reached twice per period.
    Sine,
}
14
/// A rate schedule: each variant defines how to compute the rate value at a given step.
///
/// Schedules are designed to produce values within the range [0.0, 1.0] - ie: a rate.
/// The range is checked by the `Valid` implementation, not by construction.
#[derive(Clone)]
pub enum Rate {
    /// A fixed rate that does not change over time.
    ///
    /// # Parameters
    /// - `f32`: The fixed rate value.
    Fixed(f32),
    /// A linear rate that moves from start to end over a number of steps and then
    /// stays at end.
    ///
    /// # Parameters
    /// - `start`: The starting rate value.
    /// - `end`: The ending rate value.
    /// - `steps`: The number of steps over which to change the rate. With `0` steps
    ///   the rate is always `end`.
    Linear(f32, f32, usize),
    /// An exponential rate that starts at `start` and approaches `end`, halving the
    /// remaining distance every `half_life` steps.
    ///
    /// # Parameters
    /// - `start`: The starting rate value.
    /// - `end`: The rate value that is approached as the step grows.
    /// - `half_life`: The number of steps in which the distance to `end` is halved.
    ///   With a half-life of `0` the rate is always `end`.
    Exponential(f32, f32, usize),
    /// A cyclical rate that oscillates between min and max over a period.
    ///
    /// # Parameters
    /// - `min`: The minimum rate value.
    /// - `max`: The maximum rate value.
    /// - `period`: The number of steps in one full cycle; should be non-zero.
    /// - `shape`: The shape of the cycle (Triangle or Sine).
    Cyclical(f32, f32, usize, CycleShape),
    /// Piecewise-constant schedule: at each listed step, rate jumps to the given value.
    /// The value remains constant until the next listed step.
    /// The first step must be 0 and the steps must be listed in non-decreasing order.
    /// If the current step is beyond the last listed step, the rate remains at the last value.
    ///
    /// # Parameters
    /// - `Vec<(usize, f32)>`: A vector of (step, rate) pairs.
    Stepwise(Vec<(usize, f32)>),

    /// A rate defined by an expression that can query metrics.
    /// The expression should evaluate to a float value representing the rate.
    /// The expression can use the provided metrics to compute a dynamic rate based on the current state of the ecosystem.
    /// The expression is expected to return a value in the range [0.0, 1.0], but this is not enforced at compile time.
    ///
    /// The expression is only evaluated by `Rate::get`, which falls back to `0.0` when
    /// no value can be extracted; `Rate::get_by_index` has no metrics available and
    /// returns `1.0` for this variant.
    Expr(Expr),
}
62
63impl Rate {
64    pub fn get(&mut self, generation: usize, metrics: &MetricSet) -> f32 {
65        match self {
66            Rate::Expr(expr) => metrics.apply(expr).extract().unwrap_or(0.0),
67            _ => self.get_by_index(generation),
68        }
69    }
70
71    pub fn get_by_index(&self, step: usize) -> f32 {
72        let f_step = step as f32;
73        match self {
74            Rate::Fixed(v) => *v,
75            Rate::Linear(start, end, steps) => {
76                if *steps == 0 {
77                    return *end;
78                }
79
80                let t = (f_step / *steps as f32).min(1.0);
81                start + (end - start) * t
82            }
83            Rate::Exponential(start, end, half_life) => {
84                if *half_life == 0 {
85                    return *end;
86                }
87
88                let decay = 0.5_f32.powf(f_step / *half_life as f32);
89                end + (start - end) * decay
90            }
91            Rate::Cyclical(min, max, period, shape) => {
92                let phase = (f_step % *period as f32) / *period as f32;
93                let tri = if phase < 0.5 {
94                    phase * 2.0
95                } else {
96                    (1.0 - phase) * 2.0
97                };
98
99                let s = match shape {
100                    CycleShape::Triangle => tri,
101                    CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
102                };
103
104                min + (max - min) * s
105            }
106            Rate::Stepwise(steps) => {
107                if steps.is_empty() {
108                    return 0.0;
109                }
110
111                let mut last_value = steps[0].1;
112                for (s, v) in steps {
113                    if step < *s {
114                        break;
115                    }
116
117                    last_value = *v;
118                }
119
120                last_value
121            }
122            _ => 1.0,
123        }
124    }
125}
126
127impl Valid for Rate {
128    fn is_valid(&self) -> bool {
129        match self {
130            Rate::Fixed(v) => (0.0..=1.0).contains(v),
131            Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
132            Rate::Exponential(start, end, _) => {
133                (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
134            }
135            Rate::Cyclical(min, max, _, _) => {
136                (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
137            }
138            Rate::Stepwise(steps) => {
139                if steps.is_empty() {
140                    return false;
141                }
142
143                if steps[0].0 != 0 {
144                    return false;
145                }
146
147                let mut last_step = 0;
148                for (s, v) in steps {
149                    if *s < last_step || !(0.0..=1.0).contains(v) {
150                        return false;
151                    }
152                    last_step = *s;
153                }
154
155                true
156            }
157            _ => true,
158        }
159    }
160}
161
162impl Default for Rate {
163    fn default() -> Self {
164        Rate::Fixed(1.0)
165    }
166}
167
168impl From<f32> for Rate {
169    fn from(value: f32) -> Self {
170        Rate::Fixed(value)
171    }
172}
173
174impl From<Vec<(usize, f32)>> for Rate {
175    fn from(steps: Vec<(usize, f32)>) -> Self {
176        Rate::Stepwise(steps)
177    }
178}
179
180impl Debug for Rate {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            Rate::Fixed(v) => write!(f, "Rate::Fixed({})", v),
184            Rate::Linear(start, end, steps) => {
185                write!(
186                    f,
187                    "Rate::Linear(start: {}, end: {}, steps: {})",
188                    start, end, steps
189                )
190            }
191            Rate::Exponential(start, end, half_life) => write!(
192                f,
193                "Rate::Exponential(start: {}, end: {}, half_life: {})",
194                start, end, half_life
195            ),
196            Rate::Cyclical(min, max, period, shape) => write!(
197                f,
198                "Rate::Cyclical(min: {}, max: {}, period: {}, shape: {:?})",
199                min, max, period, shape
200            ),
201            Rate::Stepwise(steps) => write!(f, "Rate::Stepwise(steps: {:?})", steps),
202            Rate::Expr(_) => write!(f, "Rate::Expr(<function>)"),
203        }
204    }
205}
206
207impl PartialEq for Rate {
208    fn eq(&self, other: &Self) -> bool {
209        match (self, other) {
210            (Rate::Fixed(a), Rate::Fixed(b)) => a == b,
211            (Rate::Linear(a_start, a_end, a_steps), Rate::Linear(b_start, b_end, b_steps)) => {
212                a_start == b_start && a_end == b_end && a_steps == b_steps
213            }
214            (
215                Rate::Exponential(a_start, a_end, a_half_life),
216                Rate::Exponential(b_start, b_end, b_half_life),
217            ) => a_start == b_start && a_end == b_end && a_half_life == b_half_life,
218            (
219                Rate::Cyclical(a_min, a_max, a_period, a_shape),
220                Rate::Cyclical(b_min, b_max, b_period, b_shape),
221            ) => a_min == b_min && a_max == b_max && a_period == b_period && a_shape == b_shape,
222            (Rate::Stepwise(a_steps), Rate::Stepwise(b_steps)) => a_steps == b_steps,
223            // For Expr variants, we consider them equal if they are the same variant,
224            // since we cannot compare the inner function for equality.
225            (Rate::Expr(_), Rate::Expr(_)) => true,
226            _ => false,
227        }
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Spot-checks every index-based schedule at hand-computed steps.
    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        for step in [0, 10] {
            assert_eq!(fixed.get_by_index(step), 0.5);
        }

        let linear = Rate::Linear(0.0, 1.0, 10);
        for (step, expected) in [(0, 0.0), (5, 0.5), (10, 1.0), (15, 1.0)] {
            assert_eq!(linear.get_by_index(step), expected);
        }

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert_close(exponential.get_by_index(0), 1.0, 1e-6);
        assert_close(exponential.get_by_index(5), 0.55, 1e-2);
        assert_close(exponential.get_by_index(10), 0.325, 1e-2);

        let triangle = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        for (step, expected) in [(0, 0.0), (2, 0.4), (5, 1.0), (7, 0.6), (10, 0.0)] {
            assert_close(triangle.get_by_index(step), expected, 1e-6);
        }

        let sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        let wave = |phase: f32| (std::f32::consts::TAU * phase).sin().abs();
        assert_close(sine.get_by_index(0), 0.0, 1e-6);
        assert_close(sine.get_by_index(2), wave(0.2), 1e-6);
        assert_close(sine.get_by_index(5), 0.0, 1e-6);
        assert_close(sine.get_by_index(7), wave(0.7), 1e-6);
        assert_close(sine.get_by_index(10), 0.0, 1e-6);

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        let expectations = [(0, 0.0), (3, 0.0), (5, 0.5), (7, 0.5), (10, 1.0), (15, 1.0)];
        for (step, expected) in expectations {
            assert_eq!(stepwise.get_by_index(step), expected);
        }
    }

    /// Sweeps a long range of steps and checks every schedule stays inside [0, 1].
    #[test]
    fn test_rates_between_0_and_1() {
        let schedules = [
            Rate::Fixed(0.5),
            Rate::Linear(0.0, 1.0, 100),
            Rate::Exponential(1.0, 0.0, 50),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
        ];

        for step in 0..100_000 {
            for schedule in &schedules {
                let value = schedule.get_by_index(step);
                assert!((0.0..=1.0).contains(&value));
            }
        }
    }

    /// A linear schedule holds its end value once the step count is exceeded.
    #[test]
    fn test_rate_clamping() {
        let ramp = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(ramp.get_by_index(15), 1.0);
    }

    /// The default schedule is a constant 1.0.
    #[test]
    fn test_default_rate() {
        let rate = Rate::default();
        for step in [0, 100] {
            assert_eq!(rate.get_by_index(step), 1.0);
        }
    }

    /// Fixed rates are valid only inside [0, 1].
    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());
    }
}