use std::f64::consts::E;
/// A damped harmonic oscillator evaluated in closed form (e.g. to generate easing curves).
///
/// Construct one with `Spring::new()` or a preset and adjust it with the builder
/// methods, which clamp their inputs. The fields are public, so values written
/// directly bypass that clamping.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Spring {
    /// Spring constant `k`; together with `mass` it sets the natural frequency √(k/m).
    pub stiffness: f64,
    /// Damping coefficient `c`; the damping ratio is c / (2·√(k·m)).
    pub damping: f64,
    /// Mass `m` attached to the spring.
    pub mass: f64,
    /// Velocity at `t = 0`; the displacement itself always starts at `1.0`.
    pub initial_velocity: f64,
    /// The spring counts as at rest once `|position| + |velocity|` falls below this value.
    pub rest_threshold: f64,
}
impl Default for Spring {
fn default() -> Self {
Self::new()
}
}
impl Spring {
    /// Lower bound for stiffness and mass, keeping the square roots and
    /// divisions of the closed-form solution away from zero.
    const EPSILON: f64 = 1e-9;

    /// Damping ratios within this distance of 1.0 use the critically-damped
    /// formula. Shared by `evaluate` and `estimated_duration` so both
    /// classify a spring identically.
    const CRITICAL_TOLERANCE: f64 = 1e-6;

    /// Upper bound (and fallback value) for `estimated_duration`.
    const MAX_DURATION: f64 = 100.0;

    /// Creates a spring with stiffness 100, damping 10, mass 1, zero initial
    /// velocity and a rest threshold of 0.001 (damping ratio 0.5).
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self {
            stiffness: 100.0,
            damping: 10.0,
            mass: 1.0,
            initial_velocity: 0.0,
            rest_threshold: 0.001,
        }
    }

    /// Returns the spring with stiffness `s`, clamped to at least `1e-9`
    /// (`f64::max` also maps a NaN argument to `1e-9`).
    #[inline]
    #[must_use]
    pub fn stiffness(mut self, s: f64) -> Self {
        self.stiffness = s.max(Self::EPSILON);
        self
    }

    /// Returns the spring with damping coefficient `d`, clamped to be non-negative.
    #[inline]
    #[must_use]
    pub fn damping(mut self, d: f64) -> Self {
        self.damping = d.max(0.0);
        self
    }

    /// Returns the spring with mass `m`, clamped to at least `1e-9`.
    #[inline]
    #[must_use]
    pub fn mass(mut self, m: f64) -> Self {
        self.mass = m.max(Self::EPSILON);
        self
    }

    /// Returns the spring with velocity `v` at `t = 0` (not clamped).
    #[inline]
    #[must_use]
    pub fn initial_velocity(mut self, v: f64) -> Self {
        self.initial_velocity = v;
        self
    }

    /// Returns the spring with rest threshold `t`, clamped to be non-negative.
    #[inline]
    #[must_use]
    pub fn rest_threshold(mut self, t: f64) -> Self {
        self.rest_threshold = t.max(0.0);
        self
    }

    /// Damping ratio ζ = c / (2·√(k·m)).
    ///
    /// Below 1 the spring is under-damped (oscillates), near 1 critically
    /// damped, and above 1 over-damped.
    #[inline]
    pub fn damping_ratio(&self) -> f64 {
        self.damping / (2.0 * (self.stiffness * self.mass).sqrt())
    }

    /// Undamped natural angular frequency ω₀ = √(k / m).
    #[inline]
    pub fn angular_frequency(&self) -> f64 {
        (self.stiffness / self.mass).sqrt()
    }

    /// Returns `(position, velocity)` at time `t`.
    ///
    /// Position is the displacement from rest. It starts at `1.0` with
    /// velocity `initial_velocity` and settles toward `0.0`. For `t <= 0` the
    /// initial state is returned unchanged.
    pub fn evaluate(&self, t: f64) -> (f64, f64) {
        if t <= 0.0 {
            return (1.0, self.initial_velocity);
        }
        let zeta = self.damping_ratio();
        let w0 = self.angular_frequency();
        let x0 = 1.0;
        let v0 = self.initial_velocity;
        if Self::is_critical(zeta) {
            self.evaluate_critical(t, w0, x0, v0)
        } else if zeta < 1.0 {
            self.evaluate_underdamped(t, w0, zeta, x0, v0)
        } else {
            self.evaluate_overdamped(t, w0, zeta, x0, v0)
        }
    }

    /// Whether `zeta` lies in the band treated as exactly critical damping.
    #[inline]
    fn is_critical(zeta: f64) -> bool {
        (zeta - 1.0).abs() < Self::CRITICAL_TOLERANCE
    }

    /// The two negative characteristic roots `(slow, fast)` of an over-damped
    /// spring (ζ > 1), where `|slow| < |fast|`.
    #[inline]
    fn overdamped_roots(w0: f64, zeta: f64) -> (f64, f64) {
        let sum = zeta + (zeta * zeta - 1.0).sqrt();
        // ζ − √(ζ²−1) equals 1 / (ζ + √(ζ²−1)). The reciprocal form avoids the
        // catastrophic cancellation of the subtraction when ζ is large.
        (-w0 / sum, -w0 * sum)
    }

    /// Under-damped solution x(t) = e^(−ζω₀t)·(a·cos ω_d t + b·sin ω_d t).
    #[inline]
    fn evaluate_underdamped(&self, t: f64, w0: f64, zeta: f64, x0: f64, v0: f64) -> (f64, f64) {
        let wd = w0 * (1.0 - zeta * zeta).sqrt();
        let a = x0;
        // Chosen so that x'(0) = v0.
        let b = (v0 + zeta * w0 * x0) / wd;
        let envelope = E.powf(-zeta * w0 * t);
        let cos_term = (wd * t).cos();
        let sin_term = (wd * t).sin();
        let position = envelope * (a * cos_term + b * sin_term);
        let velocity = envelope
            * ((-zeta * w0) * (a * cos_term + b * sin_term) + wd * (-a * sin_term + b * cos_term));
        (position, velocity)
    }

    /// Critically-damped solution x(t) = (a + b·t)·e^(−ω₀t).
    #[inline]
    fn evaluate_critical(&self, t: f64, w0: f64, x0: f64, v0: f64) -> (f64, f64) {
        let a = x0;
        let b = v0 + w0 * x0;
        let envelope = E.powf(-w0 * t);
        let position = (a + b * t) * envelope;
        let velocity = envelope * (b - w0 * (a + b * t));
        (position, velocity)
    }

    /// Over-damped solution x(t) = a·e^(r1·t) + b·e^(r2·t).
    #[inline]
    fn evaluate_overdamped(&self, t: f64, w0: f64, zeta: f64, x0: f64, v0: f64) -> (f64, f64) {
        let (r1, r2) = Self::overdamped_roots(w0, zeta);
        // Solve a + b = x0 and a·r1 + b·r2 = v0.
        let a = (v0 - r2 * x0) / (r1 - r2);
        let b = x0 - a;
        let exp1 = E.powf(r1 * t);
        let exp2 = E.powf(r2 * t);
        let position = a * exp1 + b * exp2;
        let velocity = a * r1 * exp1 + b * r2 * exp2;
        (position, velocity)
    }

    /// Whether `|position| + |velocity|` at time `t` is below `rest_threshold`.
    #[inline]
    pub fn is_at_rest(&self, t: f64) -> bool {
        let (position, velocity) = self.evaluate(t);
        position.abs() + velocity.abs() < self.rest_threshold
    }

    /// Estimates how long the spring takes to settle.
    ///
    /// The estimate is the time for the slowest-decaying exponential mode to
    /// shrink by a factor of `rest_threshold`, ignoring amplitude prefactors.
    /// The result is clamped to `[0, 100]`. Returns `100` when ω₀ or the
    /// decay rate is below `1e-9`.
    pub fn estimated_duration(&self) -> f64 {
        let zeta = self.damping_ratio();
        let w0 = self.angular_frequency();
        if w0 < Self::EPSILON {
            return Self::MAX_DURATION;
        }
        // Under-/critically-damped motion decays at ζω₀. An over-damped spring
        // is dominated by its slow root, whose rate is much smaller than ζω₀.
        let decay_rate = if zeta > 1.0 && !Self::is_critical(zeta) {
            -Self::overdamped_roots(w0, zeta).0
        } else {
            zeta * w0
        };
        if decay_rate < Self::EPSILON {
            return Self::MAX_DURATION;
        }
        (-self.rest_threshold.ln() / decay_rate).clamp(0.0, Self::MAX_DURATION)
    }

    /// Gentle preset: stiffness 120, damping 14 (ζ ≈ 0.64).
    #[inline]
    #[must_use]
    pub fn gentle() -> Self {
        Self::new().stiffness(120.0).damping(14.0)
    }

    /// Bouncy preset: stiffness 180, damping 12 (ζ ≈ 0.45).
    #[inline]
    #[must_use]
    pub fn bouncy() -> Self {
        Self::new().stiffness(180.0).damping(12.0)
    }

    /// Stiff preset: stiffness 300, damping 20 (ζ ≈ 0.58).
    #[inline]
    #[must_use]
    pub fn stiff() -> Self {
        Self::new().stiffness(300.0).damping(20.0)
    }

    /// Slow preset: stiffness 60, damping 14 (ζ ≈ 0.90).
    #[inline]
    #[must_use]
    pub fn slow() -> Self {
        Self::new().stiffness(60.0).damping(14.0)
    }

    /// Samples the spring as an easing curve of `samples` values.
    ///
    /// Sample `i` is taken at `t = i * estimated_duration() / samples`, so the
    /// last sample falls one step before the estimated duration. Each value is
    /// `1 - position` with position clamped to `[0, 1]`; overshoot past the
    /// target therefore saturates at `1.0`. Returns an empty vector when
    /// `samples == 0`.
    pub fn as_easing(&self, samples: usize) -> Vec<f64> {
        if samples == 0 {
            return Vec::new();
        }
        let dt = self.estimated_duration() / samples as f64;
        (0..samples)
            .map(|i| {
                let (position, _) = self.evaluate(i as f64 * dt);
                1.0 - position.clamp(0.0, 1.0)
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons against closed-form values.
    const TOL: f64 = 1e-6;

    #[test]
    fn test_damping_ratio() {
        assert!((Spring::new().damping_ratio() - 0.5).abs() < TOL);
    }

    #[test]
    fn test_angular_frequency() {
        assert!((Spring::new().angular_frequency() - 10.0).abs() < TOL);
    }

    #[test]
    fn test_initial_conditions() {
        let (pos, vel) = Spring::new().evaluate(0.0);
        assert!((pos - 1.0).abs() < TOL);
        assert!(vel.abs() < TOL);
    }

    #[test]
    fn test_underdamped_oscillation() {
        let spring = Spring::bouncy();
        let positions: Vec<f64> = (0..100)
            .map(|i| spring.evaluate(i as f64 * 0.01).0)
            .collect();
        let sign_changes = positions
            .windows(2)
            .filter(|pair| pair[0] * pair[1] < 0.0)
            .count();
        assert!(sign_changes > 0, "Under-damped spring should oscillate");
    }

    #[test]
    fn test_overdamped_no_oscillation() {
        let spring = Spring::new().stiffness(50.0).damping(50.0);
        for i in 0..100 {
            let (pos, _) = spring.evaluate(i as f64 * 0.01);
            assert!(pos >= -TOL, "Over-damped spring should not overshoot");
        }
    }

    #[test]
    fn test_critically_damped_fast_settle() {
        let spring = Spring::new().stiffness(100.0).damping(20.0);
        assert!((spring.damping_ratio() - 1.0).abs() < 1e-3);
        let duration = spring.estimated_duration();
        assert!(duration > 0.0 && duration < 10.0);
    }

    #[test]
    fn test_rest_detection() {
        let spring = Spring::stiff();
        assert!(!spring.is_at_rest(0.0));
        assert!(spring.is_at_rest(5.0));
    }

    #[test]
    fn test_presets() {
        let presets = [
            Spring::gentle(),
            Spring::bouncy(),
            Spring::stiff(),
            Spring::slow(),
        ];
        for spring in &presets {
            assert!(spring.stiffness > 0.0);
            // Every preset is configured to be under-damped.
            assert!(spring.damping_ratio() < 1.0);
        }
    }

    #[test]
    fn test_as_easing() {
        let easing = Spring::gentle().as_easing(100);
        assert_eq!(easing.len(), 100);
        assert!(easing[0] < 0.1);
        assert!(easing[99] > 0.9);
        assert!(easing[99] > easing[0]);
    }

    #[test]
    fn test_initial_velocity() {
        let (_, vel) = Spring::new().initial_velocity(10.0).evaluate(0.0);
        assert!((vel - 10.0).abs() < TOL);
    }

    #[test]
    fn test_edge_cases() {
        assert!(Spring::new().stiffness(0.0).stiffness > 0.0);
        assert!(Spring::new().mass(0.0).mass > 0.0);
        assert!(Spring::new().damping(-5.0).damping >= 0.0);
    }

    /// A damped spring loses energy; this checks that the displacement is
    /// bounded early on and has decayed near the rest threshold late in the
    /// estimated duration.
    #[test]
    fn test_amplitude_decays() {
        let spring = Spring::bouncy();
        let duration = spring.estimated_duration();
        let (pos_early, _) = spring.evaluate(duration * 0.1);
        let (pos_late, _) = spring.evaluate(duration * 0.9);
        assert!(pos_early.abs() <= 2.0);
        assert!(pos_late.abs() < spring.rest_threshold * 10.0);
    }

    /// The analytic velocity must match a central finite difference of
    /// position in every damping regime.
    #[test]
    fn test_velocity_matches_position_derivative() {
        let springs = [
            Spring::bouncy().initial_velocity(-3.0),
            Spring::new().stiffness(100.0).damping(20.0).initial_velocity(2.0),
            Spring::new().stiffness(50.0).damping(50.0).initial_velocity(5.0),
        ];
        let h = 1e-6;
        for spring in &springs {
            for &t in &[0.05, 0.2, 0.5] {
                let (_, vel) = spring.evaluate(t);
                let numeric = (spring.evaluate(t + h).0 - spring.evaluate(t - h).0) / (2.0 * h);
                assert!(
                    (numeric - vel).abs() < 1e-4,
                    "velocity mismatch at t={}: analytic {} vs numeric {}",
                    t,
                    vel,
                    numeric
                );
            }
        }
    }

    /// Positions just either side of critical damping must agree closely
    /// with the critical solution, so the regime switch is continuous.
    #[test]
    fn test_regime_boundary_is_continuous() {
        let t = 0.3;
        let (critical, _) = Spring::new().stiffness(100.0).damping(20.0).evaluate(t);
        let (under, _) = Spring::new().stiffness(100.0).damping(19.998).evaluate(t);
        let (over, _) = Spring::new().stiffness(100.0).damping(20.002).evaluate(t);
        assert!((under - critical).abs() < 1e-3);
        assert!((over - critical).abs() < 1e-3);
    }
}