use crate::{MetricSet, Valid};
use radiate_expr::{ApplyExpr, Expr};
use std::fmt::Debug;
/// A source of rate values that may depend on the current generation and on a
/// [`MetricSet`].
///
/// NOTE(review): no implementor is visible in this file; the `&mut self` receiver
/// presumably allows stateful calculators — confirm against implementors.
pub trait RateCalculator {
    /// Returns the rate for `generation`, given the current `metrics`.
    fn rate(&mut self, generation: usize, metrics: &MetricSet) -> f32;
}
/// Waveform used by [`Rate::Cyclical`] to map the position within a period
/// (`phase`, in `[0, 1)`) to an interpolation factor between `min` and `max`.
#[derive(Clone, Debug, PartialEq)]
pub enum CycleShape {
    /// Linear ramp from `min` at phase 0 up to `max` at phase 0.5, then back down.
    Triangle,
    /// Rectified sine `|sin(2π·phase)|`: `min` at phases 0 and 0.5, `max` at phases
    /// 0.25 and 0.75, i.e. two peaks per period (unlike `Triangle`, which has one).
    Sine,
}
/// A schedule producing an `f32` value for a given generation (step).
///
/// Values are obtained through [`Rate::get`], which also evaluates [`Rate::Expr`]
/// against a [`MetricSet`], or through [`Rate::get_by_index`] for the purely
/// step-based variants. The [`Valid`] impl requires configured values to lie in
/// `[0.0, 1.0]`.
#[derive(Clone)]
pub enum Rate {
    /// Constant value, independent of the step.
    Fixed(f32),
    /// `(start, end, steps)`: linear interpolation from `start` to `end` over `steps`
    /// steps, then held at `end`.
    Linear(f32, f32, usize),
    /// `(start, end, half_life)`: exponential approach from `start` toward `end`; the
    /// remaining distance halves every `half_life` steps.
    Exponential(f32, f32, usize),
    /// `(min, max, period, shape)`: oscillates between `min` and `max`, repeating
    /// every `period` steps, with the waveform given by `shape`.
    Cyclical(f32, f32, usize, CycleShape),
    /// `(start_step, value)` pairs: each value applies from its start step until the
    /// next entry's start step.
    Stepwise(Vec<(usize, f32)>),
    /// Expression evaluated against the current metrics by [`Rate::get`].
    Expr(Expr),
}
impl Rate {
    /// Returns the rate for `generation`.
    ///
    /// [`Rate::Expr`] is evaluated against `metrics` and falls back to `0.0` when
    /// `extract()` produces no value. Every other variant ignores `metrics` and
    /// delegates to [`Rate::get_by_index`].
    pub fn get(&mut self, generation: usize, metrics: &MetricSet) -> f32 {
        match self {
            Rate::Expr(expr) => metrics.apply(expr).extract().unwrap_or(0.0),
            _ => self.get_by_index(generation),
        }
    }

    /// Returns the schedule value at `step` without consulting any metrics.
    ///
    /// - `Fixed(v)`: always `v`.
    /// - `Linear(start, end, steps)`: interpolates from `start` to `end` over `steps`
    ///   steps, then holds at `end`; `steps == 0` yields `end`.
    /// - `Exponential(start, end, half_life)`: `end + (start - end) * 0.5^(step / half_life)`;
    ///   `half_life == 0` yields `end`.
    /// - `Cyclical(min, max, period, shape)`: `min + (max - min) * shape(phase)` where
    ///   `phase = (step mod period) / period`; `period == 0` yields `min`.
    /// - `Stepwise(steps)`: walks `steps` in order and returns the value of the last
    ///   entry before the first one whose start exceeds `step` (the first entry's value
    ///   if `step` precedes every start); an empty list yields `0.0`.
    /// - `Expr(_)`: cannot be evaluated without metrics, so this returns `1.0`.
    pub fn get_by_index(&self, step: usize) -> f32 {
        let f_step = step as f32;
        match self {
            Rate::Fixed(v) => *v,
            Rate::Linear(start, end, steps) => {
                if *steps == 0 {
                    return *end;
                }
                let t = (f_step / *steps as f32).min(1.0);
                start + (end - start) * t
            }
            Rate::Exponential(start, end, half_life) => {
                if *half_life == 0 {
                    return *end;
                }
                let decay = 0.5_f32.powf(f_step / *half_life as f32);
                end + (start - end) * decay
            }
            Rate::Cyclical(min, max, period, shape) => {
                // `f_step % 0.0` is NaN, which would poison the result; a zero period
                // is treated as permanently at phase 0 (the `min` value).
                if *period == 0 {
                    return *min;
                }
                let period = *period as f32;
                let phase = (f_step % period) / period;
                let s = match shape {
                    CycleShape::Triangle => {
                        if phase < 0.5 {
                            phase * 2.0
                        } else {
                            (1.0 - phase) * 2.0
                        }
                    }
                    CycleShape::Sine => (std::f32::consts::TAU * phase).sin().abs(),
                };
                min + (max - min) * s
            }
            Rate::Stepwise(steps) => {
                if steps.is_empty() {
                    return 0.0;
                }
                steps
                    .iter()
                    .take_while(|&&(start, _)| start <= step)
                    .last()
                    .map_or(steps[0].1, |&(_, v)| v)
            }
            // Expression rates need a `MetricSet`; see `Rate::get`.
            Rate::Expr(_) => 1.0,
        }
    }
}
impl Valid for Rate {
    /// Checks that every value-like parameter lies in `[0.0, 1.0]`.
    ///
    /// Extra rules: `Cyclical` requires `min <= max`; `Stepwise` must be non-empty,
    /// begin at step `0`, and list its step indices in non-decreasing order.
    /// Step counts, half-lives and periods are not checked, and `Expr` is always
    /// considered valid.
    fn is_valid(&self) -> bool {
        let in_unit = |x: &f32| (0.0..=1.0).contains(x);

        match self {
            Self::Fixed(value) => in_unit(value),
            Self::Linear(start, end, _) | Self::Exponential(start, end, _) => {
                in_unit(start) && in_unit(end)
            }
            Self::Cyclical(min, max, _, _) => min <= max && in_unit(min) && in_unit(max),
            Self::Stepwise(steps) => {
                matches!(steps.first(), Some(&(0, _)))
                    && steps.windows(2).all(|pair| pair[0].0 <= pair[1].0)
                    && steps.iter().all(|&(_, value)| in_unit(&value))
            }
            Self::Expr(_) => true,
        }
    }
}
impl Default for Rate {
fn default() -> Self {
Rate::Fixed(1.0)
}
}
impl From<f32> for Rate {
fn from(value: f32) -> Self {
Rate::Fixed(value)
}
}
impl From<Vec<(usize, f32)>> for Rate {
fn from(steps: Vec<(usize, f32)>) -> Self {
Rate::Stepwise(steps)
}
}
impl Debug for Rate {
    /// Renders `Rate::<Variant>(...)` with labelled parameters; expression rates are
    /// shown opaquely as `<function>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Fixed(value) => write!(f, "Rate::Fixed({value})"),
            Self::Linear(start, end, steps) => {
                write!(f, "Rate::Linear(start: {start}, end: {end}, steps: {steps})")
            }
            Self::Exponential(start, end, half_life) => write!(
                f,
                "Rate::Exponential(start: {start}, end: {end}, half_life: {half_life})"
            ),
            Self::Cyclical(min, max, period, shape) => write!(
                f,
                "Rate::Cyclical(min: {min}, max: {max}, period: {period}, shape: {shape:?})"
            ),
            Self::Stepwise(steps) => write!(f, "Rate::Stepwise(steps: {steps:?})"),
            Self::Expr(_) => f.write_str("Rate::Expr(<function>)"),
        }
    }
}
impl PartialEq for Rate {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Rate::Fixed(a), Rate::Fixed(b)) => a == b,
(Rate::Linear(a_start, a_end, a_steps), Rate::Linear(b_start, b_end, b_steps)) => {
a_start == b_start && a_end == b_end && a_steps == b_steps
}
(
Rate::Exponential(a_start, a_end, a_half_life),
Rate::Exponential(b_start, b_end, b_half_life),
) => a_start == b_start && a_end == b_end && a_half_life == b_half_life,
(
Rate::Cyclical(a_min, a_max, a_period, a_shape),
Rate::Cyclical(b_min, b_max, b_period, b_shape),
) => a_min == b_min && a_max == b_max && a_period == b_period && a_shape == b_shape,
(Rate::Stepwise(a_steps), Rate::Stepwise(b_steps)) => a_steps == b_steps,
(Rate::Expr(_), Rate::Expr(_)) => true,
_ => false,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks each step-indexed variant against hand-computed values.
    #[test]
    fn test_rate_values() {
        let fixed = Rate::Fixed(0.5);
        assert_eq!(fixed.get_by_index(0), 0.5);
        assert_eq!(fixed.get_by_index(10), 0.5);

        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.get_by_index(0), 0.0);
        assert_eq!(linear.get_by_index(5), 0.5);
        assert_eq!(linear.get_by_index(10), 1.0);
        assert_eq!(linear.get_by_index(15), 1.0);

        let exponential = Rate::Exponential(1.0, 0.1, 5);
        assert!((exponential.get_by_index(0) - 1.0).abs() < 1e-6);
        assert!((exponential.get_by_index(5) - 0.55).abs() < 1e-2);
        assert!((exponential.get_by_index(10) - 0.325).abs() < 1e-2);

        let cyclical = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Triangle);
        assert!((cyclical.get_by_index(0) - 0.0).abs() < 1e-6);
        assert!((cyclical.get_by_index(2) - 0.4).abs() < 1e-6);
        assert!((cyclical.get_by_index(5) - 1.0).abs() < 1e-6);
        assert!((cyclical.get_by_index(7) - 0.6).abs() < 1e-6);
        assert!((cyclical.get_by_index(10) - 0.0).abs() < 1e-6);

        let cyclical_sine = Rate::Cyclical(0.0, 1.0, 10, CycleShape::Sine);
        assert!((cyclical_sine.get_by_index(0) - 0.0).abs() < 1e-6);
        assert!(
            (cyclical_sine.get_by_index(2) - (std::f32::consts::TAU * 0.2).sin().abs()).abs()
                < 1e-6
        );
        assert!(cyclical_sine.get_by_index(5).abs() < 1e-6);
        assert!(
            (cyclical_sine.get_by_index(7) - (std::f32::consts::TAU * 0.7).sin().abs()).abs()
                < 1e-6
        );
        assert!((cyclical_sine.get_by_index(10) - 0.0).abs() < 1e-6);

        let stepwise = Rate::Stepwise(vec![(0, 0.0), (5, 0.5), (10, 1.0)]);
        assert_eq!(stepwise.get_by_index(0), 0.0);
        assert_eq!(stepwise.get_by_index(3), 0.0);
        assert_eq!(stepwise.get_by_index(5), 0.5);
        assert_eq!(stepwise.get_by_index(7), 0.5);
        assert_eq!(stepwise.get_by_index(10), 1.0);
        assert_eq!(stepwise.get_by_index(15), 1.0);
    }

    /// Every schedule built from in-range parameters must stay within `[0.0, 1.0]`.
    #[test]
    fn test_rates_between_0_and_1() {
        let rates = [
            ("fixed", Rate::Fixed(0.5)),
            ("linear", Rate::Linear(0.0, 1.0, 100)),
            ("exponential", Rate::Exponential(1.0, 0.0, 50)),
            ("cyclical", Rate::Cyclical(0.0, 1.0, 20, CycleShape::Triangle)),
            ("cyclical_sine", Rate::Cyclical(0.0, 1.0, 20, CycleShape::Sine)),
            ("stepwise", Rate::Stepwise(vec![(0, 0.0), (10, 0.5), (20, 1.0)])),
        ];
        for i in 0..100_000 {
            for (name, rate) in &rates {
                let value = rate.get_by_index(i);
                assert!(
                    (0.0..=1.0).contains(&value),
                    "{name} out of [0, 1] at step {i}: {value}"
                );
            }
        }
    }

    /// A linear schedule holds its end value once its duration has elapsed.
    #[test]
    fn test_rate_clamping() {
        let linear = Rate::Linear(0.0, 1.0, 10);
        assert_eq!(linear.get_by_index(15), 1.0);
    }

    /// Degenerate lengths fall back to the documented values instead of dividing by zero.
    #[test]
    fn test_zero_length_schedules() {
        assert_eq!(Rate::Linear(0.2, 0.8, 0).get_by_index(0), 0.8);
        assert_eq!(Rate::Exponential(0.9, 0.1, 0).get_by_index(3), 0.1);
        assert_eq!(Rate::Stepwise(vec![]).get_by_index(5), 0.0);
    }

    #[test]
    fn test_default_rate() {
        let default_rate = Rate::default();
        assert_eq!(default_rate.get_by_index(0), 1.0);
        assert_eq!(default_rate.get_by_index(100), 1.0);
        assert!(default_rate.is_valid());
    }

    /// Exercises the validation rules of every non-expression variant.
    #[test]
    fn test_rate_validity() {
        assert!(Rate::Fixed(0.5).is_valid());
        assert!(!Rate::Fixed(1.5).is_valid());

        assert!(Rate::Linear(0.0, 1.0, 10).is_valid());
        assert!(!Rate::Linear(-0.1, 1.0, 10).is_valid());

        assert!(Rate::Exponential(1.0, 0.0, 5).is_valid());
        assert!(!Rate::Exponential(1.0, 1.5, 5).is_valid());

        assert!(Rate::Cyclical(0.2, 0.8, 10, CycleShape::Triangle).is_valid());
        assert!(!Rate::Cyclical(0.8, 0.2, 10, CycleShape::Sine).is_valid());
        assert!(!Rate::Cyclical(0.0, 1.2, 10, CycleShape::Triangle).is_valid());

        assert!(Rate::Stepwise(vec![(0, 0.0), (5, 0.5)]).is_valid());
        assert!(!Rate::Stepwise(vec![]).is_valid());
        assert!(!Rate::Stepwise(vec![(1, 0.5)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 0.1), (5, 0.2), (3, 0.3)]).is_valid());
        assert!(!Rate::Stepwise(vec![(0, 1.5)]).is_valid());
    }

    #[test]
    fn test_conversions_and_debug() {
        assert_eq!(Rate::from(0.25_f32), Rate::Fixed(0.25));
        assert_eq!(
            Rate::from(vec![(0, 0.5), (3, 0.1)]),
            Rate::Stepwise(vec![(0, 0.5), (3, 0.1)])
        );
        assert_eq!(format!("{:?}", Rate::Fixed(0.5)), "Rate::Fixed(0.5)");
        assert_eq!(
            format!("{:?}", Rate::Cyclical(0.25, 0.75, 4, CycleShape::Sine)),
            "Rate::Cyclical(min: 0.25, max: 0.75, period: 4, shape: Sine)"
        );
    }
}