/// Tolerance for the floating-point comparisons in `Spring::new`: angular
/// frequencies below it are treated as zero, and damping ratios within it of
/// `1.0` are treated as critically damped.
const EPSILON: f64 = 1e-6;
/// Returns the per-frame time step `1 / n` for a frame rate of `n`.
///
/// A frame rate of zero yields `0.0` rather than dividing by zero.
#[inline]
pub fn fps(n: u32) -> f64 {
    match n {
        0 => 0.0,
        frames => 1.0 / f64::from(frames),
    }
}
/// Precomputed coefficients for one step of a spring.
///
/// `update` combines the displacement from equilibrium and the velocity
/// linearly using these four weights. The derived `Default` zeroes every
/// coefficient, so a default spring maps any state to the equilibrium
/// position with zero velocity.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Spring {
// Weight of the old displacement (position minus equilibrium) in the new position.
pos_pos_coef: f64,
// Weight of the old velocity in the new position.
pos_vel_coef: f64,
// Weight of the old displacement in the new velocity.
vel_pos_coef: f64,
// Weight of the old velocity in the new velocity.
vel_vel_coef: f64,
}
impl Spring {
    /// Builds a spring whose coefficients advance the motion by one step of
    /// `delta_time`.
    ///
    /// Negative `angular_frequency` and `damping_ratio` values are clamped to
    /// zero. An angular frequency below `EPSILON` produces the identity
    /// coefficients, so `update` leaves position and velocity unchanged (up to
    /// rounding). Otherwise the damping ratio selects the over-damped,
    /// under-damped or critically damped solution; ratios within `EPSILON` of
    /// `1.0` count as critically damped.
    pub fn new(delta_time: f64, angular_frequency: f64, damping_ratio: f64) -> Self {
        let omega = angular_frequency.max(0.0);
        let zeta = damping_ratio.max(0.0);

        // Without a restoring force there is nothing to integrate.
        if omega < EPSILON {
            return Self {
                pos_pos_coef: 1.0,
                pos_vel_coef: 0.0,
                vel_pos_coef: 0.0,
                vel_vel_coef: 1.0,
            };
        }

        match zeta {
            z if z > 1.0 + EPSILON => Self::over_damped(delta_time, omega, z),
            z if z < 1.0 - EPSILON => Self::under_damped(delta_time, omega, z),
            _ => Self::critically_damped(delta_time, omega),
        }
    }

    /// Coefficients for a damping ratio above one (two distinct real roots).
    fn over_damped(dt: f64, omega: f64, zeta: f64) -> Self {
        // NOTE(review): `.sqrt()` is called directly on `f64`, while exp/sin/cos
        // go through free helper functions; confirm this is intended for every
        // build configuration.
        let centre = -omega * zeta;
        let spread = omega * (zeta * zeta - 1.0).sqrt();

        // The two roots of the characteristic equation.
        let root_lo = centre - spread;
        let root_hi = centre + spread;
        let exp_lo = exp(root_lo * dt);
        let exp_hi = exp(root_hi * dt);

        let inv_two_spread = 1.0 / (2.0 * spread);
        let exp_lo_scaled = exp_lo * inv_two_spread;
        let exp_hi_scaled = exp_hi * inv_two_spread;
        let root_exp_lo_scaled = root_lo * exp_lo_scaled;
        let root_exp_hi_scaled = root_hi * exp_hi_scaled;

        Self {
            pos_pos_coef: exp_lo_scaled * root_hi - root_exp_hi_scaled + exp_hi,
            pos_vel_coef: -exp_lo_scaled + exp_hi_scaled,
            vel_pos_coef: (root_exp_lo_scaled - root_exp_hi_scaled + exp_hi) * root_hi,
            vel_vel_coef: -root_exp_lo_scaled + root_exp_hi_scaled,
        }
    }

    /// Coefficients for a damping ratio below one (decaying oscillation).
    fn under_damped(dt: f64, omega: f64, zeta: f64) -> Self {
        let omega_zeta = omega * zeta;
        // Angular frequency of the damped oscillation.
        let alpha = omega * (1.0 - zeta * zeta).sqrt();

        let decay = exp(-omega_zeta * dt);
        let phase = alpha * dt;
        let cos_term = cos(phase);
        let sin_term = sin(phase);

        let inv_alpha = 1.0 / alpha;
        let decay_sin = decay * sin_term;
        let decay_cos = decay * cos_term;
        let damped_sin = decay * omega_zeta * sin_term * inv_alpha;

        Self {
            pos_pos_coef: decay_cos + damped_sin,
            pos_vel_coef: decay_sin * inv_alpha,
            vel_pos_coef: -decay_sin * alpha - omega_zeta * damped_sin,
            vel_vel_coef: decay_cos - damped_sin,
        }
    }

    /// Coefficients for a damping ratio of (approximately) one.
    fn critically_damped(dt: f64, omega: f64) -> Self {
        let decay = exp(-omega * dt);
        let dt_decay = dt * decay;
        let dt_decay_omega = dt_decay * omega;

        Self {
            pos_pos_coef: dt_decay_omega + decay,
            pos_vel_coef: dt_decay,
            vel_pos_coef: -omega * dt_decay_omega,
            vel_vel_coef: -dt_decay_omega + decay,
        }
    }

    /// Advances `pos` and `vel` by one step toward `equilibrium_pos`.
    ///
    /// Returns the new `(position, velocity)` pair. The coefficients act on
    /// the displacement from equilibrium, which is added back to the position
    /// afterwards.
    #[inline]
    pub fn update(&self, pos: f64, vel: f64, equilibrium_pos: f64) -> (f64, f64) {
        let displacement = pos - equilibrium_pos;
        (
            displacement * self.pos_pos_coef + vel * self.pos_vel_coef + equilibrium_pos,
            displacement * self.vel_pos_coef + vel * self.vel_vel_coef,
        )
    }
}
/// Exponential function: `f64::exp` when the `std` feature is enabled,
/// `libm::exp` otherwise.
#[inline]
fn exp(x: f64) -> f64 {
    #[cfg(feature = "std")]
    {
        f64::exp(x)
    }
    #[cfg(not(feature = "std"))]
    {
        libm::exp(x)
    }
}
/// Sine function: `f64::sin` when the `std` feature is enabled, `libm::sin`
/// otherwise.
#[inline]
fn sin(x: f64) -> f64 {
    #[cfg(feature = "std")]
    {
        f64::sin(x)
    }
    #[cfg(not(feature = "std"))]
    {
        libm::sin(x)
    }
}
/// Cosine function: `f64::cos` when the `std` feature is enabled, `libm::cos`
/// otherwise.
#[inline]
fn cos(x: f64) -> f64 {
    #[cfg(feature = "std")]
    {
        f64::cos(x)
    }
    #[cfg(not(feature = "std"))]
    {
        libm::cos(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by `approx_eq`.
    const TOLERANCE: f64 = 1e-10;

    /// Returns `true` when `a` and `b` differ by less than `TOLERANCE`.
    fn approx_eq(a: f64, b: f64) -> bool {
        let gap = (a - b).abs();
        gap < TOLERANCE
    }

    /// Steps `spring` `frames` times toward `target`, starting from the
    /// `(position, velocity)` pair in `start`, and returns the final pair.
    fn simulate(spring: &Spring, frames: usize, start: (f64, f64), target: f64) -> (f64, f64) {
        (0..frames).fold(start, |(pos, vel), _| spring.update(pos, vel, target))
    }

    #[test]
    fn test_fps() {
        for rate in [60_u32, 30, 120] {
            assert!(approx_eq(fps(rate), 1.0 / f64::from(rate)));
        }
        assert!(approx_eq(fps(0), 0.0));
    }

    #[test]
    fn test_identity_spring() {
        // Zero angular frequency: the state must pass through untouched.
        let (new_pos, new_vel) = Spring::new(fps(60), 0.0, 0.5).update(10.0, 5.0, 100.0);
        assert!(approx_eq(new_pos, 10.0));
        assert!(approx_eq(new_vel, 5.0));
    }

    #[test]
    fn test_critically_damped_approaches_target() {
        let target = 100.0;
        let spring = Spring::new(fps(60), 5.0, 1.0);
        let (pos, vel) = simulate(&spring, 300, (0.0, 0.0), target);
        assert!(
            (pos - target).abs() < 0.01,
            "Expected pos ≈ {target}, got {pos}"
        );
        assert!(vel.abs() < 0.01, "Expected vel ≈ 0, got {vel}");
    }

    #[test]
    fn test_under_damped_oscillates() {
        let target: f64 = 100.0;
        let spring = Spring::new(fps(60), 10.0, 0.1);
        let (mut pos, mut vel) = (0.0_f64, 0.0_f64);
        let (mut crossed_target, mut overshot) = (false, false);
        for _ in 0..120 {
            let previous = pos;
            (pos, vel) = spring.update(pos, vel, target);
            crossed_target |= previous < target && pos >= target;
            overshot |= pos > target;
        }
        assert!(crossed_target, "Under-damped spring should cross target");
        assert!(overshot, "Under-damped spring should overshoot target");
    }

    #[test]
    fn test_over_damped_no_oscillation() {
        let target: f64 = 100.0;
        let spring = Spring::new(fps(60), 5.0, 2.0);
        // Track the running maximum alongside the simulated state.
        let (pos, _, max_pos) = (0..600).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(pos, vel, peak), _| {
                let (pos, vel) = spring.update(pos, vel, target);
                (pos, vel, peak.max(pos))
            },
        );
        assert!(
            max_pos <= target + TOLERANCE,
            "Over-damped spring should not overshoot: max_pos={max_pos}, target={target}"
        );
        assert!(
            (pos - target).abs() < 1.0,
            "Over-damped spring should approach target"
        );
    }

    #[test]
    fn test_spring_is_copy() {
        // Using `original` after the move-like binding only compiles for `Copy` types.
        let original = Spring::new(fps(60), 5.0, 0.5);
        let duplicate = original;
        let _ = original.update(0.0, 0.0, 100.0);
        let _ = duplicate.update(0.0, 0.0, 100.0);
    }

    #[test]
    fn test_negative_values_clamped() {
        // A negative angular frequency is clamped to zero, giving the identity spring.
        let (new_pos, new_vel) = Spring::new(fps(60), -5.0, 0.5).update(10.0, 5.0, 100.0);
        assert!(approx_eq(new_pos, 10.0));
        assert!(approx_eq(new_vel, 5.0));
    }

    #[test]
    fn test_zero_damping_oscillates_indefinitely() {
        let target: f64 = 100.0;
        let spring = Spring::new(fps(60), 5.0, 0.0);
        let (mut pos, mut vel) = (0.0_f64, 0.0_f64);
        let mut oscillations = 0;
        let mut last_sign = (pos - target).signum();
        for _ in 0..600 {
            (pos, vel) = spring.update(pos, vel, target);
            let sign = (pos - target).signum();
            // Signs are exact values, so comparing them with `!=` is intentional.
            #[allow(clippy::float_cmp)]
            let flipped = sign != last_sign && sign != 0.0;
            if flipped {
                oscillations += 1;
                last_sign = sign;
            }
        }
        assert!(
            oscillations >= 5,
            "Zero damping should oscillate indefinitely, got {oscillations} oscillations"
        );
    }

    #[test]
    fn test_very_high_stiffness_snaps() {
        let target = 100.0;
        let spring = Spring::new(fps(60), 100.0, 1.0);
        let (pos, _) = simulate(&spring, 30, (0.0, 0.0), target);
        assert!(
            (pos - target).abs() < 1.0,
            "High stiffness should snap quickly, got pos={pos}"
        );
    }

    #[test]
    fn test_negative_target() {
        let target = -50.0;
        let spring = Spring::new(fps(60), 5.0, 1.0);
        let (pos, _) = simulate(&spring, 300, (100.0, 0.0), target);
        assert!(
            (pos - target).abs() < 0.1,
            "Should approach negative target, got pos={pos}"
        );
    }

    #[test]
    fn test_very_small_movements() {
        let target = 0.001;
        let spring = Spring::new(fps(60), 5.0, 1.0);
        let (pos, _) = simulate(&spring, 300, (0.0, 0.0), target);
        assert!(
            (pos - target).abs() < 0.0001,
            "Should handle small movements, got pos={pos}, target={target}"
        );
    }

    #[test]
    fn test_large_time_delta() {
        // One full time unit per step instead of a frame-sized step.
        let target = 100.0;
        let spring = Spring::new(1.0, 5.0, 1.0);
        let (pos, _) = simulate(&spring, 10, (0.0, 0.0), target);
        assert!(
            (pos - target).abs() < 5.0,
            "Large delta should still converge, got pos={pos}"
        );
    }

    #[test]
    fn test_accumulated_error_bounded() {
        let target = 100.0;
        let spring = Spring::new(fps(60), 5.0, 0.5);
        let (pos, vel) = simulate(&spring, 3600, (0.0, 0.0), target);
        assert!(
            (pos - target).abs() < 0.001,
            "Accumulated error should be bounded, got pos={pos}"
        );
        assert!(
            vel.abs() < 0.001,
            "Velocity should decay completely, got vel={vel}"
        );
    }

    #[test]
    fn test_spring_default() {
        // All-zero coefficients collapse any state onto the equilibrium.
        let (new_pos, new_vel) = Spring::default().update(10.0, 5.0, 100.0);
        assert!(approx_eq(new_pos, 100.0));
        assert!(approx_eq(new_vel, 0.0));
    }

    #[test]
    fn test_spring_clone() {
        let first = Spring::new(fps(60), 5.0, 0.5);
        let second = first;
        let (pos_a, vel_a) = first.update(0.0, 0.0, 100.0);
        let (pos_b, vel_b) = second.update(0.0, 0.0, 100.0);
        assert!(approx_eq(pos_a, pos_b));
        assert!(approx_eq(vel_a, vel_b));
    }

    #[test]
    fn test_spring_equilibrium_at_target() {
        let target = 50.0;
        let (new_pos, new_vel) = Spring::new(fps(60), 5.0, 0.5).update(target, 0.0, target);
        assert!(approx_eq(new_pos, target));
        assert!(approx_eq(new_vel, 0.0));
    }

    #[test]
    fn test_fps_various_rates() {
        for rate in [30_u32, 60, 120, 144, 240] {
            assert!(approx_eq(fps(rate), 1.0 / f64::from(rate)));
        }
        assert!(approx_eq(fps(1), 1.0));
    }

    #[test]
    fn test_damping_ratio_boundary() {
        // Just below, at, and just above critical damping: none may panic.
        for damping_ratio in [0.999, 1.0, 1.001] {
            let spring = Spring::new(fps(60), 5.0, damping_ratio);
            let _ = spring.update(0.0, 0.0, 100.0);
        }
    }

    #[test]
    fn test_initial_velocity() {
        let target = 50.0;
        let spring = Spring::new(fps(60), 5.0, 1.0);
        let (pos, _) = simulate(&spring, 600, (0.0, 1000.0), target);
        assert!(
            (pos - target).abs() < 0.1,
            "Should converge despite initial velocity"
        );
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Number of 60-per-unit frames to simulate before checking convergence.
    ///
    /// Derived from a time constant `tau` scaled by a multiplier (8 above a
    /// damping ratio of one, 10 otherwise) and clamped to `600..=18000`.
    fn convergence_frames(angular_freq: f64, damping_ratio: f64) -> usize {
        let (decay_factor, multiplier) = if damping_ratio > 1.0 {
            let discriminant = (damping_ratio * damping_ratio - 1.0).sqrt();
            (damping_ratio - discriminant, 8.0)
        } else {
            // Floor the ratio so near-zero damping does not blow up `tau`.
            (damping_ratio.max(0.01), 10.0)
        };
        let tau = 1.0 / (angular_freq * decay_factor);

        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let raw_frames = (tau * multiplier * 60.0) as usize;
        raw_frames.clamp(600, 18000)
    }

    proptest! {
        // A single update must stay finite over a wide parameter range.
        #[test]
        fn spring_update_never_produces_nan_or_inf(
            delta_time in 0.0001f64..1.0,
            angular_freq in 0.0f64..500.0,
            damping_ratio in 0.0f64..50.0,
            pos in -1e6f64..1e6,
            vel in -1e6f64..1e6,
            target in -1e6f64..1e6,
        ) {
            let (new_pos, new_vel) =
                Spring::new(delta_time, angular_freq, damping_ratio).update(pos, vel, target);
            prop_assert!(!new_pos.is_nan(), "position was NaN for dt={delta_time}, af={angular_freq}, dr={damping_ratio}, pos={pos}, vel={vel}, target={target}");
            prop_assert!(!new_pos.is_infinite(), "position was Inf for dt={delta_time}, af={angular_freq}, dr={damping_ratio}, pos={pos}, vel={vel}, target={target}");
            prop_assert!(!new_vel.is_nan(), "velocity was NaN for dt={delta_time}, af={angular_freq}, dr={damping_ratio}, pos={pos}, vel={vel}, target={target}");
            prop_assert!(!new_vel.is_infinite(), "velocity was Inf for dt={delta_time}, af={angular_freq}, dr={damping_ratio}, pos={pos}, vel={vel}, target={target}");
        }

        // Every intermediate state of a 600-frame run must stay finite.
        #[test]
        fn spring_multi_frame_never_produces_nan_or_inf(
            angular_freq in 0.1f64..100.0,
            damping_ratio in 0.01f64..10.0,
            target in -1000.0f64..1000.0,
        ) {
            let spring = Spring::new(fps(60), angular_freq, damping_ratio);
            let (mut position, mut velocity) = (0.0_f64, 0.0_f64);
            for _ in 0..600 {
                (position, velocity) = spring.update(position, velocity, target);
                prop_assert!(!position.is_nan(), "position became NaN during simulation");
                prop_assert!(!position.is_infinite(), "position became Inf during simulation");
                prop_assert!(!velocity.is_nan(), "velocity became NaN during simulation");
                prop_assert!(!velocity.is_infinite(), "velocity became Inf during simulation");
            }
        }
    }

    proptest! {
        #[test]
        fn damped_spring_converges_to_target(
            angular_freq in 1.0f64..50.0,
            damping_ratio in 0.2f64..10.0,
            target in -500.0f64..500.0,
        ) {
            let spring = Spring::new(fps(60), angular_freq, damping_ratio);
            let frames = convergence_frames(angular_freq, damping_ratio);
            let (pos, _) = (0..frames)
                .fold((0.0_f64, 0.0_f64), |(p, v), _| spring.update(p, v, target));
            let error = (pos - target).abs();
            let tolerance = 1.0f64.max(target.abs() * 0.005);
            prop_assert!(
                error < tolerance,
                "Spring did not converge: pos={pos}, target={target}, error={error}, tol={tolerance}, af={angular_freq}, dr={damping_ratio}, frames={frames}"
            );
        }

        #[test]
        fn spring_final_velocity_near_zero(
            angular_freq in 1.0f64..50.0,
            damping_ratio in 0.2f64..10.0,
            target in -500.0f64..500.0,
        ) {
            let spring = Spring::new(fps(60), angular_freq, damping_ratio);
            let frames = convergence_frames(angular_freq, damping_ratio);
            let (_, vel) = (0..frames)
                .fold((0.0_f64, 0.0_f64), |(p, v), _| spring.update(p, v, target));
            let tolerance = 1.0f64.max(target.abs() * 0.005);
            prop_assert!(
                vel.abs() < tolerance,
                "Velocity did not decay: vel={vel}, tol={tolerance}, af={angular_freq}, dr={damping_ratio}, frames={frames}"
            );
        }
    }

    proptest! {
        // Compares the first step of two critically damped springs.
        #[test]
        fn higher_stiffness_means_faster_initial_response(
            low_freq in 1.0f64..10.0,
            high_freq_add in 10.0f64..90.0,
        ) {
            let damping = 1.0;
            let target = 100.0;
            let high_freq = low_freq + high_freq_add;
            let (pos_low, _) = Spring::new(fps(60), low_freq, damping).update(0.0, 0.0, target);
            let (pos_high, _) = Spring::new(fps(60), high_freq, damping).update(0.0, 0.0, target);
            prop_assert!(
                pos_high >= pos_low,
                "Higher stiffness should respond faster: pos_high={pos_high}, pos_low={pos_low}"
            );
        }

        #[test]
        fn over_damped_does_not_overshoot(
            angular_freq in 1.0f64..50.0,
            damping_excess in 0.5f64..10.0,
            target in 1.0f64..1000.0,
        ) {
            let damping_ratio = 1.0 + damping_excess;
            let spring = Spring::new(fps(60), angular_freq, damping_ratio);
            let (mut pos, mut vel) = (0.0_f64, 0.0_f64);
            // Checked every frame, so the loop cannot be collapsed into a fold.
            for _ in 0..600 {
                (pos, vel) = spring.update(pos, vel, target);
                prop_assert!(
                    pos <= target + 0.01,
                    "Over-damped spring overshot: pos={pos}, target={target}, af={angular_freq}, dr={damping_ratio}"
                );
            }
        }

        #[test]
        fn under_damped_oscillates(
            angular_freq in 5.0f64..50.0,
            damping_ratio in 0.01f64..0.3,
        ) {
            let spring = Spring::new(fps(60), angular_freq, damping_ratio);
            let target = 100.0;
            // `any` stops at the first frame that passes the target.
            let overshot = (0..300)
                .scan((0.0_f64, 0.0_f64), |state, _| {
                    *state = spring.update(state.0, state.1, target);
                    Some(state.0)
                })
                .any(|pos| pos > target);
            prop_assert!(
                overshot,
                "Under-damped spring should overshoot: af={angular_freq}, dr={damping_ratio}"
            );
        }
    }

    proptest! {
        #[test]
        fn at_equilibrium_stays_at_equilibrium(
            angular_freq in 0.1f64..100.0,
            damping_ratio in 0.0f64..10.0,
            target in -1000.0f64..1000.0,
        ) {
            let (new_pos, new_vel) =
                Spring::new(fps(60), angular_freq, damping_ratio).update(target, 0.0, target);
            let pos_error = (new_pos - target).abs();
            prop_assert!(
                pos_error < 1e-10,
                "Position drifted from equilibrium: error={pos_error}"
            );
            prop_assert!(
                new_vel.abs() < 1e-10,
                "Velocity non-zero at equilibrium: vel={new_vel}"
            );
        }
    }

    proptest! {
        // Simulates the same elapsed time at two frame rates and compares positions.
        #[test]
        fn frame_independence_approximate(
            angular_freq in 1.0f64..20.0,
            damping_ratio in 0.1f64..5.0,
            target in 10.0f64..500.0,
        ) {
            // Runs `rate` frames at a step of `fps(rate)`.
            let position_after = |rate: u32| {
                let spring = Spring::new(fps(rate), angular_freq, damping_ratio);
                (0..rate)
                    .fold((0.0_f64, 0.0_f64), |(p, v), _| spring.update(p, v, target))
                    .0
            };
            let pos_60 = position_after(60);
            let pos_120 = position_after(120);
            let pos_diff = (pos_60 - pos_120).abs();
            let tolerance = target.abs() * 0.05;
            prop_assert!(
                pos_diff < tolerance,
                "Frame rate independence violated: pos@60fps={pos_60}, pos@120fps={pos_120}, diff={pos_diff}, tol={tolerance}"
            );
        }
    }

    proptest! {
        #[test]
        fn negative_angular_freq_acts_as_identity(
            neg_freq in -100.0f64..0.0,
            damping in 0.0f64..10.0,
            pos in -1000.0f64..1000.0,
            vel in -1000.0f64..1000.0,
            target in -1000.0f64..1000.0,
        ) {
            let (new_pos, new_vel) = Spring::new(fps(60), neg_freq, damping).update(pos, vel, target);
            let pos_error = (new_pos - pos).abs();
            let vel_error = (new_vel - vel).abs();
            prop_assert!(pos_error < 1e-10, "Identity spring changed position: {new_pos} != {pos}");
            prop_assert!(vel_error < 1e-10, "Identity spring changed velocity: {new_vel} != {vel}");
        }

        // Only checks that a zero time step does not produce NaN.
        #[test]
        fn zero_delta_time_identity(
            angular_freq in 0.1f64..100.0,
            damping_ratio in 0.0f64..10.0,
            pos in -1000.0f64..1000.0,
            vel in -1000.0f64..1000.0,
            target in -1000.0f64..1000.0,
        ) {
            let (new_pos, new_vel) =
                Spring::new(0.0, angular_freq, damping_ratio).update(pos, vel, target);
            prop_assert!(!new_pos.is_nan(), "NaN with zero delta time");
            prop_assert!(!new_vel.is_nan(), "NaN velocity with zero delta time");
        }
    }
}