Skip to main content

radiate_core/
rate.rs

1use crate::Valid;
2
/// Waveform followed by a cyclical rate schedule as it moves between its
/// minimum and maximum.
#[derive(Debug, Clone, PartialEq)]
pub enum CycleShape {
    /// Piecewise-linear ramp: starts at the minimum, reaches the maximum at the
    /// midpoint of the period, then returns to the minimum.
    Triangle,
    /// Rectified sine `|sin(2π·phase)|`: at the minimum at the start and midpoint
    /// of the period, at the maximum at one quarter and three quarters.
    Sine,
}
8
9/// Rate enum representing different types of rate schedules where each variant defines a
10/// method to compute the rate value at a given step.
11/// These are designed to produce values within the range [0.0, 1.0] - ie: a rate.
12#[derive(Clone, Debug, PartialEq)]
13pub enum Rate {
14    /// A fixed rate that does not change over time.
15    ///
16    /// # Parameters
17    /// - `f32`: The fixed rate value.
18    Fixed(f32),
19    /// A linear rate that changes from start to end over a number of steps.
20    ///
21    /// # Parameters
22    /// - `start`: The starting rate value.
23    /// - `end`: The ending rate value.
24    /// - `steps`: The number of steps over which to change the rate.
25    Linear(f32, f32, usize),
26    /// An exponential rate that changes from start to end over a half-life period.
27    ///
28    /// # Parameters
29    /// - `start`: The starting rate value.
30    /// - `end`: The ending rate value.
31    /// - `half_life`: The half-life period over which to change the rate.
32    Exponential(f32, f32, usize),
33    /// A cyclical rate that oscillates between min and max over a period.
34    ///
35    /// # Parameters
36    /// - `min`: The minimum rate value.
37    /// - `max`: The maximum rate value.
38    /// - `period`: The period over which to cycle the rate.
39    /// - `shape`: The shape of the cycle (Triangle or Sine).
40    Cyclical(f32, f32, usize, CycleShape),
41    /// Piecewise-constant schedule: at each listed step, rate jumps to the given value.
42    /// The value remains constant until the next listed step.
43    /// The first step must be 0.
44    /// If the current step is beyond the last listed step, the rate remains at the last value.
45    ///
46    /// # Parameters
47    /// - `Vec<(usize, f32)>`: A vector of (step, rate) pairs.
48    Stepwise(Vec<(usize, f32)>),
49}
50
51impl Rate {
52    pub fn value(&self, step: usize) -> f32 {
53        let f_step = step as f32;
54        match self {
55            Rate::Fixed(v) => *v,
56            Rate::Linear(start, end, steps) => {
57                if *steps == 0 {
58                    return *end;
59                }
60
61                let t = (f_step / *steps as f32).min(1.0);
62                start + (end - start) * t
63            }
64            Rate::Exponential(start, end, half_life) => {
65                if *half_life == 0 {
66                    return *end;
67                }
68
69                let decay = 0.5_f32.powf(f_step / *half_life as f32);
70                end + (start - end) * decay
71            }
72            Rate::Cyclical(min, max, period, shape) => {
73                let phase = (f_step % *period as f32) / *period as f32;
74                let tri = if phase < 0.5 {
75                    phase * 2.0
76                } else {
77                    (1.0 - phase) * 2.0
78                };
79
80                let s = match shape {
81                    CycleShape::Triangle => tri,
82                    CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
83                };
84
85                min + (max - min) * s
86            }
87            Rate::Stepwise(steps) => {
88                if steps.is_empty() {
89                    return 0.0;
90                }
91
92                let mut last_value = steps[0].1;
93                for (s, v) in steps {
94                    if step < *s {
95                        break;
96                    }
97
98                    last_value = *v;
99                }
100
101                last_value
102            }
103        }
104    }
105}
106
107impl Valid for Rate {
108    fn is_valid(&self) -> bool {
109        match self {
110            Rate::Fixed(v) => (0.0..=1.0).contains(v),
111            Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
112            Rate::Exponential(start, end, _) => {
113                (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
114            }
115            Rate::Cyclical(min, max, _, _) => {
116                (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
117            }
118            Rate::Stepwise(steps) => {
119                if steps.is_empty() {
120                    return false;
121                }
122
123                if steps[0].0 != 0 {
124                    return false;
125                }
126
127                let mut last_step = 0;
128                for (s, v) in steps {
129                    if *s < last_step || !(0.0..=1.0).contains(v) {
130                        return false;
131                    }
132                    last_step = *s;
133                }
134
135                true
136            }
137        }
138    }
139}
140
141impl Default for Rate {
142    fn default() -> Self {
143        Rate::Fixed(1.0)
144    }
145}
146
147impl From<f32> for Rate {
148    fn from(value: f32) -> Self {
149        Rate::Fixed(value)
150    }
151}
152
153impl From<Vec<(usize, f32)>> for Rate {
154    fn from(steps: Vec<(usize, f32)>) -> Self {
155        Rate::Stepwise(steps)
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test if `value` is not inside the closed interval `[0.0, 1.0]`.
    fn assert_unit_interval(value: f32) {
        assert!(
            (0.0..=1.0).contains(&value),
            "{} is outside [0.0, 1.0]",
            value
        );
    }

    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        assert_eq!(fixed.value(0), 0.5);
        assert_eq!(fixed.value(10), 0.5);

        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.value(0), 0.0);
        assert_eq!(linear.value(5), 0.5);
        assert_eq!(linear.value(10), 1.0);
        assert_eq!(linear.value(15), 1.0);

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert!((exponential.value(0) - 1.0).abs() < 1e-6);
        assert!((exponential.value(5) - 0.55).abs() < 1e-2);
        assert!((exponential.value(10) - 0.325).abs() < 1e-2);

        let cyclical = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        assert!((cyclical.value(0) - 0.0).abs() < 1e-6);
        assert!((cyclical.value(2) - 0.4).abs() < 1e-6);
        assert!((cyclical.value(5) - 1.0).abs() < 1e-6);
        assert!((cyclical.value(7) - 0.6).abs() < 1e-6);
        assert!((cyclical.value(10) - 0.0).abs() < 1e-6);

        let cyclical_sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        assert!((cyclical_sine.value(0) - 0.0).abs() < 1e-6);
        assert!((cyclical_sine.value(2) - (std::f32::consts::TAU * 0.2).sin().abs()).abs() < 1e-6);
        assert!(cyclical_sine.value(5).abs() < 1e-6);
        assert!((cyclical_sine.value(7) - (std::f32::consts::TAU * 0.7).sin().abs()).abs() < 1e-6);
        assert!((cyclical_sine.value(10) - 0.0).abs() < 1e-6);

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        assert_eq!(stepwise.value(0), 0.0);
        assert_eq!(stepwise.value(3), 0.0);
        assert_eq!(stepwise.value(5), 0.5);
        assert_eq!(stepwise.value(7), 0.5);
        assert_eq!(stepwise.value(10), 1.0);
        assert_eq!(stepwise.value(15), 1.0);
    }

    #[test]
    fn test_rates_between_0_and_1() {
        let rates = [
            Rate::Fixed(0.5),
            Rate::Linear(0.0, 1.0, 100),
            Rate::Exponential(1.0, 0.0, 50),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
        ];

        for i in 0..100_000 {
            for rate in &rates {
                assert_unit_interval(rate.value(i));
            }
        }
    }

    #[test]
    fn test_rate_clamping() {
        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.value(15), 1.0);

        // Decreasing ramps clamp at their end value too.
        let decreasing = Rate::Linear(1.0, 0.0, 10);
        assert_eq!(decreasing.value(1_000), 0.0);
    }

    #[test]
    fn test_degenerate_parameters() {
        // Zero-length schedules jump straight to their end value.
        assert_eq!(Rate::Linear(0.2, 0.8, 0).value(0), 0.8);
        assert_eq!(Rate::Exponential(0.9, 0.1, 0).value(0), 0.1);

        // An empty stepwise schedule yields 0.0.
        assert_eq!(Rate::Stepwise(Vec::new()).value(7), 0.0);

        // Before the first listed step, the first value is used.
        let late_start = Rate::Stepwise(vec![(5, 0.3), (10, 0.6)]);
        assert_eq!(late_start.value(0), 0.3);
        assert_eq!(late_start.value(10), 0.6);
    }

    #[test]
    fn test_default_rate() {
        let default_rate = Rate::default();
        assert_eq!(default_rate.value(0), 1.0);
        assert_eq!(default_rate.value(100), 1.0);
    }

    #[test]
    fn test_conversions() {
        assert_eq!(Rate::from(0.25_f32), Rate::Fixed(0.25));

        let steps = vec![(0, 0.1), (3, 0.9)];
        assert_eq!(Rate::from(steps.clone()), Rate::Stepwise(steps));
    }

    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());

        assert!(Rate::Linear(1.0, 0.0, 10).is_valid());
        assert!(!Rate::Linear(0.0, 1.5, 10).is_valid());

        assert!(Rate::Exponential(1.0, 0.1, 5).is_valid());
        assert!(!Rate::Exponential(-0.1, 0.5, 5).is_valid());

        assert!(Rate::Cyclical(0.1, 0.9, 10, CycleShape::Sine).is_valid());
        assert!(!Rate::Cyclical(0.9, 0.1, 10, CycleShape::Triangle).is_valid());

        assert!(Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]).is_valid());
        assert!(!Rate::Stepwise(Vec::new()).is_valid());
        assert!(!Rate::Stepwise(vec![(1, 0.5)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.5), (10, 0.2), (5, 0.3)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.5), (10, 1.2)]).is_valid());
    }
}