use crate::spring::SpringAnimatable;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
use crate::traits::Update;
/// Tuning parameters for a friction-based inertia (flick/fling) animation.
///
/// This is plain data; the associated constructors provide ready-made presets.
#[derive(Clone, Debug, PartialEq)]
pub struct InertiaConfig {
    /// Fraction of the velocity removed per 60 Hz frame; the update step raises
    /// `1.0 - friction` to `dt * 60.0` so the decay scales with the actual step length.
    /// NOTE(review): values outside `0.0..=1.0` are not validated anywhere visible — confirm
    /// callers keep it in range.
    pub friction: f32,
    /// Speed below which a velocity is snapped to zero and the animation is considered settled.
    pub epsilon: f32,
}
impl InertiaConfig {
pub fn default_flick() -> Self {
Self {
friction: 0.05,
epsilon: 0.1,
}
}
pub fn heavy() -> Self {
Self {
friction: 0.02,
epsilon: 0.1,
}
}
pub fn snappy() -> Self {
Self {
friction: 0.1,
epsilon: 0.1,
}
}
}
impl Default for InertiaConfig {
fn default() -> Self {
Self::default_flick()
}
}
/// One-dimensional inertia animation: a value that keeps moving with a velocity that
/// decays under friction until it drops below the configured stop threshold.
#[derive(Debug, Clone)]
pub struct Inertia {
    /// Friction and stop threshold used by every update step.
    pub config: InertiaConfig,
    /// Current speed, in position units per unit of `dt` (an update adds `velocity * dt`).
    velocity: f32,
    /// Current value of the animated quantity.
    position: f32,
    /// `true` once the animation has come to rest; updates are then no-ops.
    settled: bool,
}
impl Inertia {
    /// Creates an animation resting at position `0.0` with zero velocity.
    pub fn new(config: InertiaConfig) -> Self {
        Self {
            config,
            velocity: 0.0,
            position: 0.0,
            settled: true,
        }
    }

    /// Builder: sets the starting velocity.
    ///
    /// If `|velocity|` is below `config.epsilon`, the animation stays settled, so `update`
    /// will not move it. The velocity value itself is stored unchanged either way.
    #[must_use]
    pub fn with_velocity(mut self, velocity: f32) -> Self {
        self.velocity = velocity;
        self.settled = velocity.abs() < self.config.epsilon;
        self
    }

    /// Builder: sets the starting position without touching velocity or settled state.
    #[must_use]
    pub fn with_position(mut self, position: f32) -> Self {
        self.position = position;
        self
    }

    /// Replaces the velocity and marks the animation as moving.
    ///
    /// Unlike [`Inertia::with_velocity`], the magnitude is not compared against
    /// `config.epsilon` here. That check happens after decay on the next `update`.
    pub fn kick(&mut self, velocity: f32) {
        self.velocity = velocity;
        self.settled = false;
    }

    /// Current position.
    pub fn position(&self) -> f32 {
        self.position
    }

    /// Current velocity (exactly `0.0` once an update has settled the animation).
    pub fn velocity(&self) -> f32 {
        self.velocity
    }

    /// Whether the animation is at rest.
    pub fn is_settled(&self) -> bool {
        self.settled
    }

    /// Returns to the state produced by [`Inertia::new`]: position and velocity `0.0`,
    /// settled. The configuration is kept.
    pub fn reset(&mut self) {
        self.position = 0.0;
        self.velocity = 0.0;
        self.settled = true;
    }
}
impl Update for Inertia {
    /// Advances the animation by `dt` and returns `true` while it is still moving.
    ///
    /// Settled animations are left untouched and report `false`.
    fn update(&mut self, dt: f32) -> bool {
        if !self.settled {
            // `friction` is defined per 60 Hz frame; raising the per-frame retention factor to
            // `dt * 60` rescales the decay to the actual step length.
            let retention = (1.0 - self.config.friction).powf(dt * 60.0);
            self.velocity *= retention;
            self.position += self.velocity * dt;

            let stopped = self.velocity.abs() < self.config.epsilon;
            if stopped {
                // Snap to an exact zero so callers never observe residual drift.
                self.velocity = 0.0;
                self.settled = true;
            }
            !stopped
        } else {
            false
        }
    }
}
/// Multi-component inertia animation for any [`SpringAnimatable`] value.
///
/// The value is decomposed into `f32` components that each decay independently under the
/// shared [`InertiaConfig`]; the animation settles once every component has stopped.
#[derive(Debug, Clone)]
pub struct InertiaN<T: SpringAnimatable> {
    /// Friction and stop threshold applied to every component.
    pub config: InertiaConfig,
    /// Per-component velocity, taken from `SpringAnimatable::to_components`.
    velocities: Vec<f32>,
    /// Per-component position, integrated from `velocities`.
    positions: Vec<f32>,
    /// `positions` reassembled into a `T`; refreshed after each update that moves it.
    current: T,
    /// `true` once all components have come to rest.
    settled: bool,
}
impl<T: SpringAnimatable> InertiaN<T> {
    /// Creates an animation resting at `initial`, with every velocity component at zero.
    pub fn new(config: InertiaConfig, initial: T) -> Self {
        let positions = initial.to_components();
        let n = positions.len();
        Self {
            config,
            velocities: vec![0.0; n],
            positions,
            current: initial,
            settled: true,
        }
    }

    /// Builder: sets the starting velocity.
    ///
    /// The velocity's components are zero-padded or truncated to the number of position
    /// components. The animation stays settled only if every component's magnitude is
    /// below `config.epsilon`.
    #[must_use]
    pub fn with_velocity(mut self, velocity: T) -> Self {
        self.set_velocity_components(velocity);
        let epsilon = self.config.epsilon;
        self.settled = self.velocities.iter().all(|&v: &f32| v.abs() < epsilon);
        self
    }

    /// Replaces the velocity (length-matched as in [`InertiaN::with_velocity`]) and marks
    /// the animation as moving, regardless of magnitude.
    pub fn kick(&mut self, velocity: T) {
        self.set_velocity_components(velocity);
        self.settled = false;
    }

    /// Current position as a `T` (a clone of the cached value).
    pub fn position(&self) -> T {
        self.current.clone()
    }

    /// Raw per-component velocities.
    pub fn velocity_components(&self) -> &[f32] {
        &self.velocities
    }

    /// Whether every component is at rest.
    pub fn is_settled(&self) -> bool {
        self.settled
    }

    /// Restarts at rest at `initial`, with zero velocity for each of its components.
    pub fn reset(&mut self, initial: T) {
        self.positions = initial.to_components();
        self.velocities = vec![0.0; self.positions.len()];
        self.current = initial;
        self.settled = true;
    }

    /// Stores `velocity`'s components, resized to the number of position components.
    fn set_velocity_components(&mut self, velocity: T) {
        self.velocities = velocity.to_components();
        // If `T`'s component count can vary between values, keep velocities and positions
        // aligned one-to-one: missing components get zero velocity, extras are dropped.
        self.velocities.resize(self.positions.len(), 0.0);
    }
}
impl<T: SpringAnimatable> Update for InertiaN<T> {
    /// Advances every component by `dt` and returns `true` while any is still moving.
    ///
    /// Each component decays and integrates independently. A component whose speed falls
    /// below `config.epsilon` is snapped to exactly zero. The animation settles once no
    /// component is above the threshold.
    fn update(&mut self, dt: f32) -> bool {
        if self.settled {
            return false;
        }
        // `friction` is defined per 60 Hz frame; rescale the decay to the actual step length.
        let decay = (1.0 - self.config.friction).powf(dt * 60.0);
        let epsilon = self.config.epsilon;
        let mut all_settled = true;
        // Zipping keeps velocities and positions in lock-step. It cannot index out of
        // bounds if the two vectors ever differ in length.
        for (vel, pos) in self.velocities.iter_mut().zip(self.positions.iter_mut()) {
            let v = *vel * decay;
            *pos += v * dt;
            if v.abs() >= epsilon {
                *vel = v;
                all_settled = false;
            } else {
                *vel = 0.0;
            }
        }
        self.current = T::from_components(&self.positions);
        if all_settled {
            self.settled = true;
        }
        !self.settled
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Step length used by most tests: one 60 Hz frame.
    const FRAME: f32 = 1.0 / 60.0;

    /// Calls `update(dt)` until it reports the animation has stopped, or until `max_steps`
    /// calls have been made. Returns how many calls were made, including the final one.
    fn step_until_settled<U: Update>(anim: &mut U, dt: f32, max_steps: u32) -> u32 {
        let mut steps = 0u32;
        while steps < max_steps {
            steps += 1;
            if !anim.update(dt) {
                break;
            }
        }
        steps
    }

    #[test]
    fn inertia_decelerates_to_zero() {
        let mut inertia = Inertia::new(InertiaConfig::default_flick()).with_velocity(500.0);
        step_until_settled(&mut inertia, FRAME, 600);
        assert!(inertia.is_settled());
        assert!(inertia.velocity().abs() < 0.2);
    }

    #[test]
    fn inertia_position_increases_for_positive_velocity() {
        let mut inertia = Inertia::new(InertiaConfig::default_flick()).with_velocity(100.0);
        let before = inertia.position();
        inertia.update(FRAME);
        assert!(inertia.position() > before);
    }

    #[test]
    fn inertia_zero_velocity_is_settled() {
        let inertia = Inertia::new(InertiaConfig::default_flick()).with_velocity(0.0);
        assert!(inertia.is_settled());
    }

    #[test]
    fn inertia_kick_restarts() {
        let mut inertia = Inertia::new(InertiaConfig::default_flick()).with_velocity(100.0);
        step_until_settled(&mut inertia, FRAME, 600);
        assert!(inertia.is_settled());

        inertia.kick(200.0);
        assert!(!inertia.is_settled());
        assert!(inertia.update(FRAME));
    }

    #[test]
    fn inertia_snappy_stops_faster_than_heavy() {
        let mut snappy = Inertia::new(InertiaConfig::snappy())
            .with_velocity(500.0)
            .with_position(0.0);
        let mut heavy = Inertia::new(InertiaConfig::heavy())
            .with_velocity(500.0)
            .with_position(0.0);

        let snappy_frames = step_until_settled(&mut snappy, FRAME, 10000);
        let heavy_frames = step_until_settled(&mut heavy, FRAME, 10000);

        assert!(
            snappy_frames < heavy_frames,
            "snappy ({snappy_frames}) should stop before heavy ({heavy_frames})"
        );
    }

    #[test]
    fn inertia_n_2d_decelerates() {
        let mut inertia = InertiaN::new(InertiaConfig::default_flick(), [0.0_f32, 0.0])
            .with_velocity([300.0, -200.0]);
        step_until_settled(&mut inertia, FRAME, 600);
        assert!(inertia.is_settled());

        let vel = inertia.velocity_components();
        assert!(vel.iter().take(2).all(|v| v.abs() < 0.2));
        assert!(vel[0].abs() < 0.2);
        assert!(vel[1].abs() < 0.2);
    }

    #[test]
    fn inertia_n_position_changes() {
        let mut inertia = InertiaN::new(InertiaConfig::default_flick(), [0.0_f32, 0.0])
            .with_velocity([100.0, 0.0]);
        inertia.update(FRAME);

        let pos = inertia.position();
        assert!(pos[0] > 0.0, "x should have moved: {:?}", pos);
        assert!((pos[1]).abs() < 1e-6, "y should be ~0: {:?}", pos);
    }

    #[test]
    fn inertia_reset_works() {
        let mut inertia = Inertia::new(InertiaConfig::default_flick())
            .with_velocity(100.0)
            .with_position(50.0);
        inertia.update(0.1);
        inertia.reset();

        assert!(inertia.is_settled());
        assert!((inertia.position()).abs() < 1e-6);
        assert!((inertia.velocity()).abs() < 1e-6);
    }

    #[test]
    fn inertia_frame_rate_independence() {
        let mut a = Inertia::new(InertiaConfig::default_flick()).with_velocity(500.0);
        let mut b = Inertia::new(InertiaConfig::default_flick()).with_velocity(500.0);

        (0..60).for_each(|_| {
            a.update(1.0 / 60.0);
        });
        (0..120).for_each(|_| {
            b.update(1.0 / 120.0);
        });

        let diff = (a.position() - b.position()).abs();
        assert!(
            diff < 5.0,
            "Frame rate independence: 60fps pos={}, 120fps pos={}, diff={}",
            a.position(),
            b.position(),
            diff
        );
    }
}