1use crate::{MetricSet, Valid};
2use radiate_expr::{ApplyExpr, Expr};
3use std::fmt::Debug;
4
5pub trait RateCalculator {
6 fn rate(&mut self, generation: usize, metrics: &MetricSet) -> f32;
7}
8
/// Waveform used by [`Rate::Cyclical`] to move between its `min` and `max`.
///
/// This is a fieldless enum, so it is cheap to copy and can be used as a
/// hash key. `Copy`, `Eq` and `Hash` are therefore derived alongside the
/// original traits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CycleShape {
    /// Rises linearly from `min` to `max` at mid-period, then falls back.
    Triangle,
    /// Follows `|sin(2π · phase)|` over the period.
    Sine,
}
14
15#[derive(Clone)]
19pub enum Rate {
20 Fixed(f32),
25 Linear(f32, f32, usize),
32 Exponential(f32, f32, usize),
39 Cyclical(f32, f32, usize, CycleShape),
47 Stepwise(Vec<(usize, f32)>),
55
56 Expr(Expr),
61}
62
63impl Rate {
64 pub fn get(&mut self, generation: usize, metrics: &MetricSet) -> f32 {
65 match self {
66 Rate::Expr(expr) => metrics.apply(expr).extract().unwrap_or(0.0),
67 _ => self.get_by_index(generation),
68 }
69 }
70
71 pub fn get_by_index(&self, step: usize) -> f32 {
72 let f_step = step as f32;
73 match self {
74 Rate::Fixed(v) => *v,
75 Rate::Linear(start, end, steps) => {
76 if *steps == 0 {
77 return *end;
78 }
79
80 let t = (f_step / *steps as f32).min(1.0);
81 start + (end - start) * t
82 }
83 Rate::Exponential(start, end, half_life) => {
84 if *half_life == 0 {
85 return *end;
86 }
87
88 let decay = 0.5_f32.powf(f_step / *half_life as f32);
89 end + (start - end) * decay
90 }
91 Rate::Cyclical(min, max, period, shape) => {
92 let phase = (f_step % *period as f32) / *period as f32;
93 let tri = if phase < 0.5 {
94 phase * 2.0
95 } else {
96 (1.0 - phase) * 2.0
97 };
98
99 let s = match shape {
100 CycleShape::Triangle => tri,
101 CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
102 };
103
104 min + (max - min) * s
105 }
106 Rate::Stepwise(steps) => {
107 if steps.is_empty() {
108 return 0.0;
109 }
110
111 let mut last_value = steps[0].1;
112 for (s, v) in steps {
113 if step < *s {
114 break;
115 }
116
117 last_value = *v;
118 }
119
120 last_value
121 }
122 _ => 1.0,
123 }
124 }
125}
126
127impl Valid for Rate {
128 fn is_valid(&self) -> bool {
129 match self {
130 Rate::Fixed(v) => (0.0..=1.0).contains(v),
131 Rate::Linear(start, end, _) => (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end),
132 Rate::Exponential(start, end, _) => {
133 (0.0..=1.0).contains(start) && (0.0..=1.0).contains(end)
134 }
135 Rate::Cyclical(min, max, _, _) => {
136 (0.0..=1.0).contains(min) && (0.0..=1.0).contains(max) && min <= max
137 }
138 Rate::Stepwise(steps) => {
139 if steps.is_empty() {
140 return false;
141 }
142
143 if steps[0].0 != 0 {
144 return false;
145 }
146
147 let mut last_step = 0;
148 for (s, v) in steps {
149 if *s < last_step || !(0.0..=1.0).contains(v) {
150 return false;
151 }
152 last_step = *s;
153 }
154
155 true
156 }
157 _ => true,
158 }
159 }
160}
161
162impl Default for Rate {
163 fn default() -> Self {
164 Rate::Fixed(1.0)
165 }
166}
167
168impl From<f32> for Rate {
169 fn from(value: f32) -> Self {
170 Rate::Fixed(value)
171 }
172}
173
174impl From<Vec<(usize, f32)>> for Rate {
175 fn from(steps: Vec<(usize, f32)>) -> Self {
176 Rate::Stepwise(steps)
177 }
178}
179
180impl Debug for Rate {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 match self {
183 Rate::Fixed(v) => write!(f, "Rate::Fixed({})", v),
184 Rate::Linear(start, end, steps) => {
185 write!(
186 f,
187 "Rate::Linear(start: {}, end: {}, steps: {})",
188 start, end, steps
189 )
190 }
191 Rate::Exponential(start, end, half_life) => write!(
192 f,
193 "Rate::Exponential(start: {}, end: {}, half_life: {})",
194 start, end, half_life
195 ),
196 Rate::Cyclical(min, max, period, shape) => write!(
197 f,
198 "Rate::Cyclical(min: {}, max: {}, period: {}, shape: {:?})",
199 min, max, period, shape
200 ),
201 Rate::Stepwise(steps) => write!(f, "Rate::Stepwise(steps: {:?})", steps),
202 Rate::Expr(_) => write!(f, "Rate::Expr(<function>)"),
203 }
204 }
205}
206
207impl PartialEq for Rate {
208 fn eq(&self, other: &Self) -> bool {
209 match (self, other) {
210 (Rate::Fixed(a), Rate::Fixed(b)) => a == b,
211 (Rate::Linear(a_start, a_end, a_steps), Rate::Linear(b_start, b_end, b_steps)) => {
212 a_start == b_start && a_end == b_end && a_steps == b_steps
213 }
214 (
215 Rate::Exponential(a_start, a_end, a_half_life),
216 Rate::Exponential(b_start, b_end, b_half_life),
217 ) => a_start == b_start && a_end == b_end && a_half_life == b_half_life,
218 (
219 Rate::Cyclical(a_min, a_max, a_period, a_shape),
220 Rate::Cyclical(b_min, b_max, b_period, b_shape),
221 ) => a_min == b_min && a_max == b_max && a_period == b_period && a_shape == b_shape,
222 (Rate::Stepwise(a_steps), Rate::Stepwise(b_steps)) => a_steps == b_steps,
223 (Rate::Expr(_), Rate::Expr(_)) => true,
226 _ => false,
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        for step in [0, 10] {
            assert_eq!(fixed.get_by_index(step), 0.5);
        }

        let linear = Rate::Linear(0.0, 1.0, 10);
        for (step, expected) in [(0, 0.0), (5, 0.5), (10, 1.0), (15, 1.0)] {
            assert_eq!(linear.get_by_index(step), expected);
        }

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        for (step, expected, tol) in [(0, 1.0, 1e-6), (5, 0.55, 1e-2), (10, 0.325, 1e-2)] {
            assert_close(exponential.get_by_index(step), expected, tol);
        }

        let triangle = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        for (step, expected) in [(0, 0.0), (2, 0.4), (5, 1.0), (7, 0.6), (10, 0.0)] {
            assert_close(triangle.get_by_index(step), expected, 1e-6);
        }

        let sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        let sine_expectations = [
            (0, 0.0),
            (2, (std::f32::consts::TAU * 0.2).sin().abs()),
            (5, 0.0),
            (7, (std::f32::consts::TAU * 0.7).sin().abs()),
            (10, 0.0),
        ];
        for (step, expected) in sine_expectations {
            assert_close(sine.get_by_index(step), expected, 1e-6);
        }

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        for (step, expected) in [(0, 0.0), (3, 0.0), (5, 0.5), (7, 0.5), (10, 1.0), (15, 1.0)] {
            assert_eq!(stepwise.get_by_index(step), expected);
        }
    }

    #[test]
    fn test_rates_between_0_and_1() {
        let schedules = [
            Rate::Fixed(0.5),
            Rate::Linear(0.0, 1.0, 100),
            Rate::Exponential(1.0, 0.0, 50),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle),
            Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine),
            Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)]),
        ];

        for schedule in &schedules {
            for step in 0..100_000 {
                let value = schedule.get_by_index(step);
                assert!(
                    (0.0..=1.0).contains(&value),
                    "{:?} produced {} at step {}",
                    schedule,
                    value,
                    step
                );
            }
        }
    }

    #[test]
    fn test_rate_clamping() {
        assert_eq!(Rate::Linear(0.0, 1.0, 10).get_by_index(15), 1.0);
    }

    #[test]
    fn test_default_rate() {
        let rate = Rate::default();
        for step in [0, 100] {
            assert_eq!(rate.get_by_index(step), 1.0);
        }
    }

    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());
    }
}