use crate::math::MulAdd;
macro_rules! impl_spring {
    ($name:ident, $ty:ty) => {
        /// Critically damped spring that smoothly follows a (possibly moving) target.
        ///
        /// Each [`update`](Self::update) advances the spring by `dt` with the
        /// closed-form critically damped step. In that step the decay factor `e^(-x)`,
        /// with `x = 2 * dt / smooth_time`, is replaced by the approximation
        /// `1 / (1 + x + x²/2)`, so only multiply, add and divide are needed.
        /// The first update after construction or [`reset`](Self::reset) snaps
        /// straight to the target instead of smoothing.
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Smoothing time. `new` guarantees it is positive and that
            /// `2 / smooth_time` is finite and non-zero.
            smooth_time: $ty,
            /// Current smoothed output.
            value: $ty,
            /// Current rate of change of `value`, per unit of `dt`.
            velocity: $ty,
            /// `false` until the first `update` (or a `reset_to`). While `false`,
            /// the next `update` snaps to its target.
            initialized: bool,
        }

        impl $name {
            /// Creates a spring at rest with value `0` and the given smoothing time.
            ///
            /// # Errors
            ///
            /// Returns `ConfigError::Invalid` in two cases:
            /// - `smooth_time` is not strictly positive (this includes NaN);
            /// - `smooth_time` is so extreme that `2 / smooth_time` overflows to
            ///   infinity (vanishingly small values) or collapses to zero (infinite
            ///   values). In the first case `update` would produce NaN; in the
            ///   second the spring would never move.
            #[inline]
            pub fn new(smooth_time: $ty) -> Result<Self, crate::ConfigError> {
                // The negated comparison is deliberate. `NaN > 0` is false, so NaN is
                // rejected here. The `<=` form clippy suggests would let NaN through.
                #[allow(clippy::neg_cmp_op_on_partial_ord)]
                if !(smooth_time > 0.0 as $ty) {
                    return Err(crate::ConfigError::Invalid("smooth_time must be positive"));
                }
                // `update` multiplies by this factor. An infinite omega turns a zero
                // offset into `inf * 0 = NaN`; a zero omega freezes the spring.
                // omega cannot be NaN here because smooth_time > 0.
                let omega = 2.0 as $ty / smooth_time;
                if omega.is_infinite() || omega <= 0.0 as $ty {
                    return Err(crate::ConfigError::Invalid(
                        "smooth_time is too small or not finite",
                    ));
                }
                Ok(Self {
                    smooth_time,
                    value: 0.0 as $ty,
                    velocity: 0.0 as $ty,
                    initialized: false,
                })
            }

            /// Advances the spring toward `target` by `dt` and returns the new value.
            ///
            /// The first call after construction or [`reset`](Self::reset) sets the
            /// value to `target` exactly and leaves the velocity at zero.
            ///
            /// `dt` is only checked for finiteness.
            /// NOTE(review): a negative `dt` is accepted and runs the spring
            /// backward in time; confirm that callers never pass one.
            /// Intermediate products scale with `2 / smooth_time` and `dt`, so
            /// extreme combinations can still overflow (most easily in `f32`).
            ///
            /// # Errors
            ///
            /// Returns the error produced by `check_finite!` when `target` or `dt`
            /// is NaN or infinite. Both checks run before any state is modified.
            #[inline]
            pub fn update(&mut self, target: $ty, dt: $ty) -> Result<$ty, crate::DataError> {
                check_finite!(target);
                check_finite!(dt);
                if !self.initialized {
                    self.value = target;
                    self.initialized = true;
                    return Ok(target);
                }
                let omega = 2.0 as $ty / self.smooth_time;
                let x = omega * dt;
                // exp(-x) ≈ 1 / (1 + x + x²/2). The denominator is evaluated in
                // Horner form via `MulAdd::fma`; it is always >= 0.5, so it never
                // divides by zero.
                let exp_neg = 1.0 as $ty / (x.fma(x.fma(0.5 as $ty, 1.0 as $ty), 1.0 as $ty));
                let delta = self.value - target;
                let temp = (self.velocity + omega * delta) * dt;
                self.velocity = (self.velocity - omega * temp) * exp_neg;
                // target + (delta + temp) * exp_neg
                self.value = (delta + temp).fma(exp_neg, target);
                Ok(self.value)
            }

            /// Current smoothed value (`0` before the first update).
            #[inline]
            #[must_use]
            pub fn value(&self) -> $ty {
                self.value
            }

            /// Current velocity of the spring.
            #[inline]
            #[must_use]
            pub fn velocity(&self) -> $ty {
                self.velocity
            }

            /// Returns the spring to its freshly constructed state: value and
            /// velocity become zero, and the next `update` snaps to its target.
            #[inline]
            pub fn reset(&mut self) {
                self.value = 0.0 as $ty;
                self.velocity = 0.0 as $ty;
                self.initialized = false;
            }

            /// Places the spring at `value` with zero velocity. Unlike `reset`,
            /// the next `update` smooths from `value` instead of snapping.
            #[inline]
            pub fn reset_to(&mut self, value: $ty) {
                self.value = value;
                self.velocity = 0.0 as $ty;
                self.initialized = true;
            }
        }
    };
}
// Concrete spring types: single precision and double precision.
impl_spring!(SpringF32, f32);
impl_spring!(SpringF64, f64);
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn converges_to_target() {
        let mut s = SpringF64::new(0.5).unwrap();
        let target = 100.0;
        // The first update snaps to its target, so start away from `target`.
        // Otherwise the loop below would begin already converged.
        s.update(0.0, 0.016).unwrap();
        for _ in 0..200 {
            s.update(target, 0.016).unwrap();
        }
        assert!(
            (s.value() - target).abs() < 0.01,
            "should converge to {target}, got {}",
            s.value()
        );
    }

    #[test]
    fn no_overshoot() {
        let mut s = SpringF64::new(0.5).unwrap();
        let target = 100.0;
        s.update(0.0, 0.016).unwrap();
        let mut max_value = 0.0f64;
        for _ in 0..1000 {
            let v = s.update(target, 0.016).unwrap();
            if v > max_value {
                max_value = v;
            }
        }
        assert!(
            max_value <= target + 0.1,
            "should not overshoot, max was {max_value}"
        );
    }

    #[test]
    fn variable_dt_stable() {
        let mut s = SpringF64::new(1.0).unwrap();
        let target = 50.0;
        // Start away from the target so the large steps actually move the spring.
        s.update(0.0, 0.016).unwrap();
        for dt in [0.5, 2.0, 10.0] {
            let v = s.update(target, dt).unwrap();
            assert!(v.is_finite(), "value not finite after dt={dt}");
            assert!(s.velocity().is_finite(), "velocity not finite after dt={dt}");
            assert!(
                (0.0..=target).contains(&v),
                "value {v} left [0, {target}] after dt={dt}"
            );
        }
    }

    #[test]
    // Exact comparison is intended: reset_to assigns these values directly.
    #[allow(clippy::float_cmp)]
    fn reset_to() {
        let mut s = SpringF64::new(0.5).unwrap();
        s.update(100.0, 0.016).unwrap();
        s.reset_to(50.0);
        assert_eq!(s.value(), 50.0);
        assert_eq!(s.velocity(), 0.0);
        // reset_to marks the spring initialised, so the next update smooths
        // from 50 instead of snapping to the new target.
        let v = s.update(100.0, 0.016).unwrap();
        assert!(v > 50.0 && v < 100.0, "expected smoothing, got {v}");
    }

    #[test]
    // Exact comparison is intended: reset assigns zero and the first update
    // after it returns the target verbatim.
    #[allow(clippy::float_cmp)]
    fn reset_clears_state() {
        let mut s = SpringF64::new(0.5).unwrap();
        s.update(0.0, 0.016).unwrap();
        s.update(100.0, 0.016).unwrap();
        s.reset();
        assert_eq!(s.value(), 0.0);
        assert_eq!(s.velocity(), 0.0);
        assert_eq!(s.update(42.0, 0.016).unwrap(), 42.0);
    }

    #[test]
    fn f32_basic() {
        let mut s = SpringF32::new(0.5).unwrap();
        // First update on a fresh spring snaps straight to the target.
        let v = s.update(100.0, 0.016).unwrap();
        assert!((v - 100.0).abs() < 0.01);
    }

    #[test]
    fn rejects_zero_smooth_time() {
        assert!(matches!(
            SpringF64::new(0.0),
            Err(crate::ConfigError::Invalid(_))
        ));
    }

    #[test]
    fn rejects_negative_and_nan_smooth_time() {
        assert!(matches!(
            SpringF64::new(-1.0),
            Err(crate::ConfigError::Invalid(_))
        ));
        assert!(matches!(
            SpringF64::new(f64::NAN),
            Err(crate::ConfigError::Invalid(_))
        ));
    }

    #[test]
    fn rejects_nan_and_inf() {
        let mut s = SpringF64::new(0.5).unwrap();
        assert!(matches!(
            s.update(f64::NAN, 0.016),
            Err(crate::DataError::NotANumber)
        ));
        assert!(matches!(
            s.update(f64::INFINITY, 0.016),
            Err(crate::DataError::Infinite)
        ));
        assert!(matches!(
            s.update(100.0, f64::NAN),
            Err(crate::DataError::NotANumber)
        ));
        assert!(matches!(
            s.update(100.0, f64::INFINITY),
            Err(crate::DataError::Infinite)
        ));
    }
}