1use crate::Valid;
2
/// Waveform used by [`Rate::Cyclical`] to move between its `min` and `max`.
#[derive(Clone, Debug, PartialEq)]
pub enum CycleShape {
    /// Rises linearly from `min` to `max` over the first half of the period,
    /// then falls linearly back to `min` over the second half.
    Triangle,
    /// Follows `|sin(2π · phase)|`: reaches `max` at one quarter and three
    /// quarters of the period, and `min` at the midpoint and period boundary.
    Sine,
}

/// A scalar schedule evaluated per step via [`Rate::value`].
#[derive(Clone, Debug, PartialEq)]
pub enum Rate {
    /// Constant value at every step.
    Fixed(f32),
    /// `Linear(start, end, steps)`: interpolates from `start` at step 0 to
    /// `end` at step `steps`, then holds `end`. `steps == 0` yields `end`.
    Linear(f32, f32, usize),
    /// `Exponential(start, end, half_life)`: decays from `start` toward `end`,
    /// halving the remaining gap every `half_life` steps.
    /// `half_life == 0` yields `end`.
    Exponential(f32, f32, usize),
    /// `Cyclical(min, max, period, shape)`: oscillates between `min` and `max`,
    /// repeating every `period` steps with the given [`CycleShape`].
    /// `period == 0` yields `min`.
    Cyclical(f32, f32, usize, CycleShape),
    /// Piecewise-constant schedule over `(start_step, value)` pairs. Each value
    /// applies from its step until the next entry's step. Entries are assumed
    /// to be sorted by step. Steps before the first entry use the first value,
    /// and an empty list yields `0.0`.
    Stepwise(Vec<(usize, f32)>),
    /// Linear warmup from `start` to `peak`, followed by exponential decay
    /// from `peak` toward `end`.
    WarmupExp {
        /// Number of steps spent ramping linearly from `start` to `peak`.
        warmup_steps: usize,
        /// Value at step 0 (when `warmup_steps > 0`).
        start: f32,
        /// Value reached at step `warmup_steps`.
        peak: f32,
        /// Value approached as decay progresses.
        end: f32,
        /// Steps (counted after warmup) for the gap to `end` to halve.
        /// `0` means the schedule jumps straight to `end` after warmup.
        half_life: usize,
    },
}

impl Rate {
    /// Evaluates the schedule at `step`.
    ///
    /// The result is interpolated between the configured values and is never
    /// clamped. Degenerate lengths never produce NaN:
    /// - a zero `steps` (`Linear`) or `half_life` (`Exponential`) yields `end`;
    /// - a zero `period` (`Cyclical`) yields `min`;
    /// - a zero `half_life` (`WarmupExp`) yields `end` once warmup is over.
    pub fn value(&self, step: usize) -> f32 {
        match self {
            Rate::Fixed(v) => *v,
            Rate::Linear(start, end, steps) => {
                if *steps == 0 {
                    return *end;
                }
                let t = (step as f32 / *steps as f32).min(1.0);
                start + (end - start) * t
            }
            Rate::Exponential(start, end, half_life) => {
                if *half_life == 0 {
                    return *end;
                }
                let decay = 0.5_f32.powf(step as f32 / *half_life as f32);
                end + (start - end) * decay
            }
            Rate::Cyclical(min, max, period, shape) => {
                if *period == 0 {
                    // `x % 0` and `0 / 0` would produce NaN. Stay at phase 0 instead.
                    return *min;
                }
                // Reduce in integer space: f32 represents integers exactly only
                // up to 2^24, so a float modulo drifts for large step counts.
                let phase = (step % *period) as f32 / *period as f32;
                let s = match shape {
                    CycleShape::Triangle => {
                        if phase < 0.5 {
                            phase * 2.0
                        } else {
                            (1.0 - phase) * 2.0
                        }
                    }
                    // NOTE(review): this completes two half-waves per period
                    // (peaks at 1/4 and 3/4), unlike Triangle's single peak at
                    // 1/2. The tests pin this behaviour, so it is left as-is.
                    CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
                };
                min + (max - min) * s
            }
            Rate::Stepwise(steps) => steps
                .iter()
                // Last entry whose start step has been reached...
                .take_while(|(s, _)| step >= *s)
                .last()
                // ...otherwise the first entry; an empty list yields 0.0.
                .or_else(|| steps.first())
                .map_or(0.0, |(_, v)| *v),
            Rate::WarmupExp {
                warmup_steps,
                start,
                peak,
                end,
                half_life,
            } => {
                if step < *warmup_steps {
                    // `warmup_steps > step >= 0`, so the divisor is non-zero.
                    let t = step as f32 / *warmup_steps as f32;
                    start + (peak - start) * t
                } else if *half_life == 0 {
                    // Mirror `Exponential`: instant decay, and no 0/0 = NaN at
                    // the warmup boundary.
                    *end
                } else {
                    let decay_step = step - *warmup_steps;
                    let decay = 0.5_f32.powf(decay_step as f32 / *half_life as f32);
                    end + (peak - end) * decay
                }
            }
        }
    }
}
134
135impl Valid for Rate {
136 fn is_valid(&self) -> bool {
137 match self {
138 Rate::Fixed(v) => (0.0..=1.0).contains(v),
139 Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
140 Rate::Exponential(start, end, _) => {
141 (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
142 }
143 Rate::Cyclical(min, max, _, _) => {
144 (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
145 }
146 Rate::Stepwise(steps) => {
147 if steps.is_empty() {
148 return false;
149 }
150
151 if steps[0].0 != 0 {
152 return false;
153 }
154
155 let mut last_step = 0;
156 for (s, v) in steps {
157 if *s < last_step || !(0.0..=1.0).contains(v) {
158 return false;
159 }
160 last_step = *s;
161 }
162
163 true
164 }
165 Rate::WarmupExp {
166 start, peak, end, ..
167 } => {
168 (0.0..=1.0).contains(start)
169 && (0.0..=1.0).contains(peak)
170 && (0.0..=1.0).contains(end)
171 }
172 }
173 }
174}
175
176impl Default for Rate {
177 fn default() -> Self {
178 Rate::Fixed(1.0)
179 }
180}
181
182impl From<f32> for Rate {
183 fn from(value: f32) -> Self {
184 Rate::Fixed(value)
185 }
186}
187
188impl From<Vec<(usize, f32)>> for Rate {
189 fn from(steps: Vec<(usize, f32)>) -> Self {
190 Rate::Stepwise(steps)
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks every schedule variant at representative steps.
    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        assert_eq!(fixed.value(0), 0.5);
        assert_eq!(fixed.value(10), 0.5);

        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.value(0), 0.0);
        assert_eq!(linear.value(5), 0.5);
        assert_eq!(linear.value(10), 1.0);
        assert_eq!(linear.value(15), 1.0);

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert!((exponential.value(0) - 1.0).abs() < 1e-6);
        assert!((exponential.value(5) - 0.55).abs() < 1e-2);
        assert!((exponential.value(10) - 0.325).abs() < 1e-2);

        let cyclical = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        assert!((cyclical.value(0) - 0.0).abs() < 1e-6);
        assert!((cyclical.value(2) - 0.4).abs() < 1e-6);
        assert!((cyclical.value(5) - 1.0).abs() < 1e-6);
        assert!((cyclical.value(7) - 0.6).abs() < 1e-6);
        assert!((cyclical.value(10) - 0.0).abs() < 1e-6);

        let cyclical_sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        assert!((cyclical_sine.value(0) - 0.0).abs() < 1e-6);
        assert!((cyclical_sine.value(2) - (std::f32::consts::TAU * 0.2).sin().abs()).abs() < 1e-6);
        assert!(cyclical_sine.value(5).abs() < 1e-6);
        assert!((cyclical_sine.value(7) - (std::f32::consts::TAU * 0.7).sin().abs()).abs() < 1e-6);
        assert!((cyclical_sine.value(10) - 0.0).abs() < 1e-6);

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        assert_eq!(stepwise.value(0), 0.0);
        assert_eq!(stepwise.value(3), 0.0);
        assert_eq!(stepwise.value(5), 0.5);
        assert_eq!(stepwise.value(7), 0.5);
        assert_eq!(stepwise.value(10), 1.0);
        assert_eq!(stepwise.value(15), 1.0);

        let warmup_exp = Rate::WarmupExp {
            warmup_steps: 5,
            start: 0.0,
            peak: 1.0,
            end: 0.1,
            half_life: 5,
        };
        assert_eq!(warmup_exp.value(0), 0.0);
        assert_eq!(warmup_exp.value(2), 0.4);
        assert_eq!(warmup_exp.value(5), 1.0);
        assert!((warmup_exp.value(10) - 0.55).abs() < 1e-2);
        assert!((warmup_exp.value(15) - 0.325).abs() < 1e-2);
    }

    /// Schedules configured within `[0.0, 1.0]` must never leave that range.
    #[test]
    fn test_rates_between_0_and_1() {
        let rates = [
            Rate::Fixed(0.5),
            Rate::Linear(0.0, 1.0, 100),
            Rate::Exponential(1.0, 0.0, 50),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
            Rate::WarmupExp {
                warmup_steps: 50,
                start: 0.0,
                peak: 1.0,
                end: 0.0,
                half_life: 50,
            },
        ];

        for i in 0..100_000 {
            for rate in &rates {
                let v = rate.value(i);
                assert!(
                    (0.0..=1.0).contains(&v),
                    "{:?} produced {} at step {}",
                    rate,
                    v,
                    i
                );
            }
        }
    }

    /// A linear schedule holds its end value past the ramp.
    #[test]
    fn test_rate_clamping() {
        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.value(15), 1.0);
    }

    /// The default schedule is a constant 1.0.
    #[test]
    fn test_default_rate() {
        let default_rate = Rate::default();
        assert_eq!(default_rate.value(0), 1.0);
        assert_eq!(default_rate.value(100), 1.0);
    }

    /// `From` conversions build the expected variants.
    #[test]
    fn test_rate_conversions() {
        assert_eq!(Rate::from(0.25), Rate::Fixed(0.25));
        assert_eq!(
            Rate::from(vec![(0, 0.1), (3, 0.2)]),
            Rate::Stepwise(vec![(0, 0.1), (3, 0.2)])
        );
    }

    /// Validity checks for every variant.
    #[test]
    fn test_rate_validity() {
        let valid_fixed = Rate::Fixed(0.5);
        let invalid_fixed = Rate::Fixed(1.5);
        assert!(valid_fixed.is_valid());
        assert!(!invalid_fixed.is_valid());
        assert!(!Rate::Fixed(-0.1).is_valid());
        assert!(!Rate::Fixed(f32::NAN).is_valid());
        assert!(Rate::default().is_valid());

        assert!(Rate::Linear(0.0, 1.0, 10).is_valid());
        assert!(!Rate::Linear(0.0, 1.1, 10).is_valid());

        assert!(Rate::Exponential(1.0, 0.0, 5).is_valid());
        assert!(!Rate::Exponential(-0.5, 0.0, 5).is_valid());

        assert!(Rate::Cyclical(0.2, 0.8, 10, CycleShape::Triangle).is_valid());
        assert!(!Rate::Cyclical(0.8, 0.2, 10, CycleShape::Sine).is_valid());

        assert!(Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]).is_valid());
        assert!(!Rate::Stepwise(vec![]).is_valid());
        assert!(!Rate::Stepwise(vec![(1, 0.5)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.5), (10, 0.2), (5, 0.3)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.0), (5, 1.5)]).is_valid());

        let warmup = Rate::WarmupExp {
            warmup_steps: 5,
            start: 0.0,
            peak: 1.0,
            end: 0.1,
            half_life: 5,
        };
        assert!(warmup.is_valid());
        let bad_warmup = Rate::WarmupExp {
            warmup_steps: 5,
            start: 0.0,
            peak: 1.2,
            end: 0.1,
            half_life: 5,
        };
        assert!(!bad_warmup.is_valid());
    }
}