1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
use serde::{Deserialize, Serialize};
use typetag;

/// Maps a raw input value onto an output score.
///
/// Implementations are registered with `typetag` so that boxed trait objects
/// can be serialized and deserialized through serde.
#[typetag::serde]
pub trait Evaluator: std::fmt::Debug + Sync + Send {
    /// Returns the score produced for `value`.
    fn evaluate(&self, value: f32) -> f32;
}

/// Evaluator following the straight line through `(xa, ya)` and `(xb, yb)`,
/// with the result clamped to `[0.0, 1.0]`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LinearEvaluator {
    xa: f32,
    ya: f32,
    dy_over_dx: f32,
}

impl LinearEvaluator {
    /// Builds the evaluator from two points on the line.
    ///
    /// The slope is precomputed as `(yb - ya) / (xb - xa)`; when `xa == xb`
    /// that division is by zero, so the stored slope is not finite.
    pub fn new(xa: f32, ya: f32, xb: f32, yb: f32) -> Self {
        Self {
            xa,
            ya,
            dy_over_dx: (yb - ya) / (xb - xa),
        }
    }
}

#[typetag::serde]
impl Evaluator for LinearEvaluator {
    /// Evaluates the line at `value` and clamps the result to `[0.0, 1.0]`.
    fn evaluate(&self, value: f32) -> f32 {
        clamp(self.ya + self.dy_over_dx * (value - self.xa), 0.0, 1.0)
    }
}

/// Evaluator following a power curve that starts at `(xa, ya)` and ends at
/// `(xb, yb)`.
#[derive(Debug, Serialize, Deserialize)]
pub struct PowerEvaluator {
    xa: f32,
    ya: f32,
    xb: f32,
    power: f32,
    dy: f32,
}

impl PowerEvaluator {
    /// Builds the evaluator.
    ///
    /// `power` is clamped to `[0.0, 10000.0]`; the vertical span `yb - ya` is
    /// precomputed.
    pub fn new(power: f32, xa: f32, ya: f32, xb: f32, yb: f32) -> Self {
        Self {
            power: clamp(power, 0.0, 10000.0),
            dy: yb - ya,
            xa,
            ya,
            xb,
        }
    }
}

#[typetag::serde]
impl Evaluator for PowerEvaluator {
    /// Clamps `value` into `[xa, xb]`, normalizes it to a fraction of that
    /// interval, raises the fraction to `power`, and rescales it by the
    /// vertical span before offsetting by `ya`.
    fn evaluate(&self, value: f32) -> f32 {
        let cx = clamp(value, self.xa, self.xb);
        let fraction = (cx - self.xa) / (self.xb - self.xa);
        self.dy * fraction.powf(self.power) + self.ya
    }
}

/// Evaluator following a tunable sigmoid between `(xa, ya)` and `(xb, yb)`.
///
/// With `t` the input normalized to `[-1, 1]` around the midpoint of
/// `[xa, xb]`, the curve is `t * (1 - k) / (k * (1 - 2 * |t|) + 1)`, rescaled
/// to the `[ya, yb]` span.
#[derive(Debug, Serialize, Deserialize)]
pub struct SigmoidEvaluator {
    xa: f32,
    xb: f32,
    k: f32,
    two_over_dx: f32,
    x_mean: f32,
    y_mean: f32,
    dy_over_two: f32,
    one_minus_k: f32,
}

impl SigmoidEvaluator {
    /// Builds the evaluator.
    ///
    /// `k` is clamped to `[-0.99999, 0.99999]`; the remaining fields are
    /// values precomputed for [`Evaluator::evaluate`].
    pub fn new(k: f32, xa: f32, ya: f32, xb: f32, yb: f32) -> Self {
        let k = clamp(k, -0.99999, 0.99999);
        Self {
            xa,
            xb,
            // The normalization factor is based on the width of the input
            // interval, so both operands must be x-coordinates.
            two_over_dx: (2.0 / (xb - xa)).abs(),
            x_mean: (xa + xb) / 2.0,
            y_mean: (ya + yb) / 2.0,
            dy_over_two: (yb - ya) / 2.0,
            one_minus_k: 1.0 - k,
            k,
        }
    }
}

#[typetag::serde]
impl Evaluator for SigmoidEvaluator {
    /// Evaluates the sigmoid at `x`, after clamping `x` into `[xa, xb]`.
    ///
    /// For `xa < xb` the result is `ya` at `xa`, the mean of `ya` and `yb` at
    /// the midpoint, and `yb` at `xb`.
    fn evaluate(&self, x: f32) -> f32 {
        let cx_minus_x_mean = clamp(x, self.xa, self.xb) - self.x_mean;
        // Input normalized around the midpoint; lies in [-1, 1] when xa < xb.
        let t = self.two_over_dx * cx_minus_x_mean;
        let numerator = t * self.one_minus_k;
        // The absolute value applies to `t` only, which keeps the curve
        // symmetric about the midpoint. With |k| < 1 and |t| <= 1 the
        // denominator stays strictly positive.
        let denominator = self.k * (1.0 - 2.0 * t.abs()) + 1.0;
        self.dy_over_two * (numerator / denominator) + self.y_mean
    }
}

/// Restricts `val` to the range `[min, max]`.
///
/// The upper bound is applied first and the lower bound second. A value that
/// compares neither greater than `max` nor less than `min` (for example a
/// floating-point NaN) is returned unchanged.
fn clamp<T: PartialOrd>(val: T, min: T, max: T) -> T {
    let val = if val > max { max } else { val };
    if val < min {
        min
    } else {
        val
    }
}