radiate_core/
rate.rs

1use crate::Valid;
2
/// Waveform a [`Rate::Cyclical`] schedule follows between its minimum and maximum.
///
/// Both shapes sit at the minimum when the phase (position within the period, in `[0, 1)`)
/// is zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CycleShape {
    /// Linear ramp from the minimum up to the maximum at mid-period, then linearly back down.
    Triangle,
    /// Rectified sine, `|sin(2π · phase)|`: peaks at one quarter and three quarters of the
    /// period and returns to the minimum at mid-period.
    Sine,
}
8
9#[derive(Clone, Debug, PartialEq)]
10pub enum Rate {
11    Fixed(f32),
12    /// A linear rate that changes from start to end over a number of steps.
13    ///
14    /// # Parameters
15    /// - `start`: The starting rate value.
16    /// - `end`: The ending rate value.
17    /// - `steps`: The number of steps over which to change the rate.
18    Linear(f32, f32, usize),
19    /// An exponential rate that changes from start to end over a half-life period.
20    ///
21    /// # Parameters
22    /// - `start`: The starting rate value.
23    /// - `end`: The ending rate value.
24    /// - `half_life`: The half-life period over which to change the rate.
25    Exponential(f32, f32, usize),
26    /// A cyclical rate that oscillates between min and max over a period.
27    ///
28    /// # Parameters
29    /// - `min`: The minimum rate value.
30    /// - `max`: The maximum rate value.
31    /// - `period`: The period over which to cycle the rate.
32    /// - `shape`: The shape of the cycle (Triangle or Sine).
33    Cyclical(f32, f32, usize, CycleShape),
34    /// Piecewise-constant schedule: at each listed step, rate jumps to the given value.
35    /// The value remains constant until the next listed step.
36    /// The first step must be 0.
37    /// If the current step is beyond the last listed step, the rate remains at the last value.
38    ///
39    /// # Parameters
40    /// - `Vec<(usize, f32)>`: A vector of (step, rate) pairs.
41    Stepwise(Vec<(usize, f32)>),
42    /// A warmup exponential schedule that starts at `start`, rises to `peak` over `warmup_steps`,
43    /// then decays to `end` with a half-life of `half_life`.
44    ///
45    /// # Parameters
46    /// - `warmup_steps`: Number of steps to reach peak from start.
47    /// - `start`: The starting rate value.
48    /// - `peak`: The peak rate value after warmup.
49    /// - `end`: The ending rate value after decay.
50    /// - `half_life`: The half-life period for decay after warmup.
51    WarmupExp {
52        warmup_steps: usize,
53        start: f32,
54        peak: f32,
55        end: f32,
56        half_life: usize,
57    },
58}
59
60impl Rate {
61    pub fn value(&self, step: usize) -> f32 {
62        let f_step = step as f32;
63        match self {
64            Rate::Fixed(v) => *v,
65            Rate::Linear(start, end, steps) => {
66                if *steps == 0 {
67                    return *end;
68                }
69
70                let t = (f_step / *steps as f32).min(1.0);
71                start + (end - start) * t
72            }
73            Rate::Exponential(start, end, half_life) => {
74                if *half_life == 0 {
75                    return *end;
76                }
77
78                let decay = 0.5_f32.powf(f_step / *half_life as f32);
79                end + (start - end) * decay
80            }
81            Rate::Cyclical(min, max, period, shape) => {
82                let phase = (f_step % *period as f32) / *period as f32;
83                let tri = if phase < 0.5 {
84                    phase * 2.0
85                } else {
86                    (1.0 - phase) * 2.0
87                };
88
89                let s = match shape {
90                    CycleShape::Triangle => tri,
91                    CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
92                };
93
94                min + (max - min) * s
95            }
96            Rate::Stepwise(steps) => {
97                if steps.is_empty() {
98                    return 0.0;
99                }
100
101                let mut last_value = steps[0].1;
102                for (s, v) in steps {
103                    if step < *s {
104                        break;
105                    }
106
107                    last_value = *v;
108                }
109
110                last_value
111            }
112            Rate::WarmupExp {
113                warmup_steps,
114                start,
115                peak,
116                end,
117                half_life,
118            } => {
119                if step < *warmup_steps {
120                    if *warmup_steps == 0 {
121                        return *peak;
122                    }
123                    let t = f_step / *warmup_steps as f32;
124                    start + (peak - start) * t
125                } else {
126                    let decay_step = step - *warmup_steps;
127                    let decay = 0.5_f32.powf(decay_step as f32 / *half_life as f32);
128                    end + (peak - end) * decay
129                }
130            }
131        }
132    }
133}
134
135impl Valid for Rate {
136    fn is_valid(&self) -> bool {
137        match self {
138            Rate::Fixed(v) => (0.0..=1.0).contains(v),
139            Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
140            Rate::Exponential(start, end, _) => {
141                (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
142            }
143            Rate::Cyclical(min, max, _, _) => {
144                (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
145            }
146            Rate::Stepwise(steps) => {
147                if steps.is_empty() {
148                    return false;
149                }
150
151                if steps[0].0 != 0 {
152                    return false;
153                }
154
155                let mut last_step = 0;
156                for (s, v) in steps {
157                    if *s < last_step || !(0.0..=1.0).contains(v) {
158                        return false;
159                    }
160                    last_step = *s;
161                }
162
163                true
164            }
165            Rate::WarmupExp {
166                start, peak, end, ..
167            } => {
168                (0.0..=1.0).contains(start)
169                    && (0.0..=1.0).contains(peak)
170                    && (0.0..=1.0).contains(end)
171            }
172        }
173    }
174}
175
176impl Default for Rate {
177    fn default() -> Self {
178        Rate::Fixed(1.0)
179    }
180}
181
182impl From<f32> for Rate {
183    fn from(value: f32) -> Self {
184        Rate::Fixed(value)
185    }
186}
187
188impl From<Vec<(usize, f32)>> for Rate {
189    fn from(steps: Vec<(usize, f32)>) -> Self {
190        Rate::Stepwise(steps)
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` is strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} (tolerance {}), got {}",
            expected,
            tolerance,
            actual
        );
    }

    /// Spot-checks every schedule kind at hand-computed steps.
    #[test]
    fn test_rate_values() {
        // Fixed: identical at every step.
        let fixed = Rate::Fixed(0.5);
        for step in [0, 10] {
            assert_eq!(fixed.value(step), 0.5);
        }

        // Linear: ramps over 10 steps, then holds the end value.
        let linear = Rate::Linear(0.0, 1.0, 10);
        for (step, expected) in [(0, 0.0), (5, 0.5), (10, 1.0), (15, 1.0)] {
            assert_eq!(linear.value(step), expected);
        }

        // Exponential: the distance to the end value halves every 5 steps.
        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert_close(exponential.value(0), 1.0, 1e-6);
        assert_close(exponential.value(5), 0.55, 1e-2);
        assert_close(exponential.value(10), 0.325, 1e-2);

        // Triangle cycle: peaks at mid-period and wraps at the period.
        let cyclical = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        for (step, expected) in [(0, 0.0), (2, 0.4), (5, 1.0), (7, 0.6), (10, 0.0)] {
            assert_close(cyclical.value(step), expected, 1e-6);
        }

        // Sine cycle: rectified sine of the phase.
        let cyclical_sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        let rectified = |phase: f32| (std::f32::consts::TAU * phase).sin().abs();
        assert_close(cyclical_sine.value(0), 0.0, 1e-6);
        assert_close(cyclical_sine.value(2), rectified(0.2), 1e-6);
        assert_close(cyclical_sine.value(5), 0.0, 1e-6);
        assert_close(cyclical_sine.value(7), rectified(0.7), 1e-6);
        assert_close(cyclical_sine.value(10), 0.0, 1e-6);

        // Stepwise: holds each value until the next listed step.
        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        for (step, expected) in [(0, 0.0), (3, 0.0), (5, 0.5), (7, 0.5), (10, 1.0), (15, 1.0)] {
            assert_eq!(stepwise.value(step), expected);
        }

        // Warmup + decay: linear climb for 5 steps, then exponential approach to `end`.
        let warmup_exp = Rate::WarmupExp {
            warmup_steps: 5,
            start: 0.0,
            peak: 1.0,
            end: 0.1,
            half_life: 5,
        };
        for (step, expected) in [(0, 0.0), (2, 0.4), (5, 1.0)] {
            assert_eq!(warmup_exp.value(step), expected);
        }
        assert_close(warmup_exp.value(10), 0.55, 1e-2);
        assert_close(warmup_exp.value(15), 0.325, 1e-2);
    }

    /// Schedules configured inside `[0, 1]` must stay inside `[0, 1]` for a long run.
    #[test]
    fn test_rates_between_0_and_1() {
        let schedules = vec![
            Rate::Fixed(0.5),
            Rate::Linear(0.0, 1.0, 100),
            Rate::Exponential(1.0, 0.0, 50),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
            Rate::WarmupExp {
                warmup_steps: 50,
                start: 0.0,
                peak: 1.0,
                end: 0.0,
                half_life: 50,
            },
        ];

        for step in 0..100_000 {
            for schedule in &schedules {
                let rate = schedule.value(step);
                assert!(
                    (0.0..=1.0).contains(&rate),
                    "{:?} produced {} at step {}",
                    schedule,
                    rate,
                    step
                );
            }
        }
    }

    /// A linear schedule holds its end value after the final step.
    #[test]
    fn test_rate_clamping() {
        assert_eq!(Rate::Linear(0.0, 1.0, 10).value(15), 1.0);
    }

    /// The default schedule is a constant `1.0`.
    #[test]
    fn test_default_rate() {
        let default_rate = Rate::default();
        for step in [0, 100] {
            assert_eq!(default_rate.value(step), 1.0);
        }
    }

    /// A fixed rate is valid only inside `[0, 1]`.
    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());
    }
}