Skip to main content

proof_engine/physics/
soft_body.rs

//! Mass-spring soft body simulation.
//!
//! Models deformable objects as networks of point masses connected by damped
//! Hookean springs (structural, shear, and bend). Each step integrates with
//! semi-implicit Euler and a global per-step velocity damping factor, then
//! runs iterative position-based projection along the springs.
//!
//! ## Typical Usage
//! ```rust,no_run
//! use proof_engine::physics::soft_body::SoftBody;
//! let mut cloth = SoftBody::grid(8, 8, 1.0);
//! cloth.pin(0); // pin top-left corner
//! cloth.step(0.016, [0.0, -9.8].into());
//! ```
14
15use glam::Vec2;
16
17// ── SoftNode ──────────────────────────────────────────────────────────────────
18
19/// A point mass in the soft body network.
20#[derive(Debug, Clone)]
21pub struct SoftNode {
22    pub position:  Vec2,
23    pub velocity:  Vec2,
24    /// Accumulated force for the current integration step.
25    pub force:     Vec2,
26    pub mass:      f32,
27    pub inv_mass:  f32,  // 0 = pinned/static
28    /// Whether this node is pinned (fixed in space).
29    pub pinned:    bool,
30    /// Optional user tag.
31    pub tag:       u32,
32}
33
34impl SoftNode {
35    pub fn new(position: Vec2, mass: f32) -> Self {
36        Self {
37            position,
38            velocity: Vec2::ZERO,
39            force:    Vec2::ZERO,
40            mass,
41            inv_mass: if mass > 0.0 { 1.0 / mass } else { 0.0 },
42            pinned:   false,
43            tag:      0,
44        }
45    }
46
47    pub fn pinned(mut self) -> Self {
48        self.pinned = true;
49        self.inv_mass = 0.0;
50        self
51    }
52}
53
54// ── Spring ────────────────────────────────────────────────────────────────────
55
/// Category of a spring, used for classification and visual rendering.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpringKind {
    /// Connects directly adjacent nodes.
    Structural,
    /// Connects diagonal neighbours.
    Shear,
    /// Connects nodes that skip one neighbour.
    Bend,
    /// Default kind assigned by `Spring::new`.
    Custom,
}
64
65/// A Hookean spring connecting two nodes.
66#[derive(Debug, Clone)]
67pub struct Spring {
68    pub a:            usize,
69    pub b:            usize,
70    pub rest_length:  f32,
71    /// Spring stiffness coefficient (N/m).
72    pub stiffness:    f32,
73    /// Damping coefficient for velocity along the spring axis.
74    pub damping:      f32,
75    pub kind:         SpringKind,
76    /// If true, the spring only resists compression (tension-only).
77    pub tension_only: bool,
78    /// Break threshold — spring removed if stretch exceeds this fraction. 0 = never breaks.
79    pub break_at:     f32,
80    pub broken:       bool,
81}
82
83impl Spring {
84    pub fn new(a: usize, b: usize, rest_length: f32, stiffness: f32) -> Self {
85        Self {
86            a, b, rest_length, stiffness,
87            damping:      0.1,
88            kind:         SpringKind::Custom,
89            tension_only: false,
90            break_at:     0.0,
91            broken:       false,
92        }
93    }
94
95    /// Compute force on node `a` from this spring.
96    fn compute_force(&self, pa: Vec2, pb: Vec2, va: Vec2, vb: Vec2) -> Vec2 {
97        let delta = pb - pa;
98        let dist  = delta.length();
99        if dist < 1e-6 { return Vec2::ZERO; }
100        let dir = delta / dist;
101        let stretch = dist - self.rest_length;
102        if self.tension_only && stretch < 0.0 { return Vec2::ZERO; }
103        let spring_force  = self.stiffness * stretch;
104        let rel_vel       = (vb - va).dot(dir);
105        let damping_force = self.damping * rel_vel;
106        dir * (spring_force + damping_force)
107    }
108}
109
110// ── SoftBody ──────────────────────────────────────────────────────────────────
111
112/// A mass-spring soft body.
113#[derive(Debug, Clone)]
114pub struct SoftBody {
115    pub nodes:          Vec<SoftNode>,
116    pub springs:        Vec<Spring>,
117    /// Global friction / air drag coefficient.
118    pub damping:        f32,
119    /// Contact restitution (for floor/ceiling collisions).
120    pub restitution:    f32,
121    /// Iteration count for constraint projection.
122    pub iterations:     usize,
123    /// User label.
124    pub label:          String,
125}
126
127impl SoftBody {
128    // ── Constructors ───────────────────────────────────────────────────────────
129
130    pub fn new() -> Self {
131        Self {
132            nodes:       Vec::new(),
133            springs:     Vec::new(),
134            damping:     0.98,
135            restitution: 0.3,
136            iterations:  4,
137            label:       String::new(),
138        }
139    }
140
141    /// Create a 1D rope of `n` nodes spanning `length`.
142    pub fn rope(n: usize, length: f32, mass_per_node: f32, stiffness: f32) -> Self {
143        let mut sb = Self::new();
144        let seg = length / (n - 1).max(1) as f32;
145        for i in 0..n {
146            sb.nodes.push(SoftNode::new(Vec2::new(i as f32 * seg, 0.0), mass_per_node));
147        }
148        for i in 0..n - 1 {
149            let mut s = Spring::new(i, i + 1, seg, stiffness);
150            s.kind = SpringKind::Structural;
151            sb.springs.push(s);
152        }
153        sb
154    }
155
156    /// Create a 2D cloth grid of `cols × rows` nodes, spaced `cell_size`.
157    pub fn grid(cols: usize, rows: usize, cell_size: f32) -> Self {
158        Self::grid_with_params(cols, rows, cell_size, 1.0, 800.0)
159    }
160
161    /// Create a cloth grid with custom mass and stiffness.
162    pub fn grid_with_params(
163        cols: usize,
164        rows: usize,
165        cell_size: f32,
166        mass: f32,
167        stiffness: f32,
168    ) -> Self {
169        let mut sb = Self::new();
170        let diag = cell_size * std::f32::consts::SQRT_2;
171        let double = cell_size * 2.0;
172
173        // Nodes
174        for r in 0..rows {
175            for c in 0..cols {
176                sb.nodes.push(SoftNode::new(
177                    Vec2::new(c as f32 * cell_size, -(r as f32) * cell_size),
178                    mass,
179                ));
180            }
181        }
182
183        let idx = |r: usize, c: usize| r * cols + c;
184
185        // Structural springs (horizontal + vertical)
186        for r in 0..rows {
187            for c in 0..cols {
188                if c + 1 < cols {
189                    let mut s = Spring::new(idx(r, c), idx(r, c + 1), cell_size, stiffness);
190                    s.kind = SpringKind::Structural;
191                    sb.springs.push(s);
192                }
193                if r + 1 < rows {
194                    let mut s = Spring::new(idx(r, c), idx(r + 1, c), cell_size, stiffness);
195                    s.kind = SpringKind::Structural;
196                    sb.springs.push(s);
197                }
198            }
199        }
200
201        // Shear springs (diagonal)
202        for r in 0..rows - 1 {
203            for c in 0..cols - 1 {
204                let mut s1 = Spring::new(idx(r, c), idx(r + 1, c + 1), diag, stiffness * 0.7);
205                s1.kind = SpringKind::Shear;
206                sb.springs.push(s1);
207                let mut s2 = Spring::new(idx(r, c + 1), idx(r + 1, c), diag, stiffness * 0.7);
208                s2.kind = SpringKind::Shear;
209                sb.springs.push(s2);
210            }
211        }
212
213        // Bend springs (skip-one)
214        for r in 0..rows {
215            for c in 0..cols {
216                if c + 2 < cols {
217                    let mut s = Spring::new(idx(r, c), idx(r, c + 2), double, stiffness * 0.3);
218                    s.kind = SpringKind::Bend;
219                    sb.springs.push(s);
220                }
221                if r + 2 < rows {
222                    let mut s = Spring::new(idx(r, c), idx(r + 2, c), double, stiffness * 0.3);
223                    s.kind = SpringKind::Bend;
224                    sb.springs.push(s);
225                }
226            }
227        }
228
229        sb
230    }
231
232    /// Create a circular blob of `n` nodes with internal cross-springs.
233    pub fn blob(n: usize, radius: f32, mass: f32, stiffness: f32) -> Self {
234        let mut sb = Self::new();
235        let tau = std::f32::consts::TAU;
236
237        // Outer ring
238        for i in 0..n {
239            let angle = i as f32 / n as f32 * tau;
240            sb.nodes.push(SoftNode::new(
241                Vec2::new(angle.cos() * radius, angle.sin() * radius),
242                mass,
243            ));
244        }
245        // Center node
246        sb.nodes.push(SoftNode::new(Vec2::ZERO, mass * 2.0));
247        let center = n;
248
249        let arc = radius * tau / n as f32;
250
251        // Ring springs
252        for i in 0..n {
253            let j = (i + 1) % n;
254            let mut s = Spring::new(i, j, arc, stiffness);
255            s.kind = SpringKind::Structural;
256            sb.springs.push(s);
257        }
258
259        // Spoke springs (ring → center)
260        for i in 0..n {
261            let mut s = Spring::new(i, center, radius, stiffness * 0.8);
262            s.kind = SpringKind::Structural;
263            sb.springs.push(s);
264        }
265
266        // Cross springs (skip-one ring)
267        for i in 0..n {
268            let j = (i + 2) % n;
269            let p1 = sb.nodes[i].position;
270            let p2 = sb.nodes[j].position;
271            let len = p1.distance(p2);
272            let mut s = Spring::new(i, j, len, stiffness * 0.4);
273            s.kind = SpringKind::Bend;
274            sb.springs.push(s);
275        }
276
277        sb
278    }
279
280    // ── Node manipulation ──────────────────────────────────────────────────────
281
282    /// Add a node and return its index.
283    pub fn add_node(&mut self, position: Vec2, mass: f32) -> usize {
284        let idx = self.nodes.len();
285        self.nodes.push(SoftNode::new(position, mass));
286        idx
287    }
288
289    /// Add a spring between two nodes and return its index.
290    pub fn add_spring(&mut self, a: usize, b: usize, stiffness: f32) -> usize {
291        let rest = self.nodes[a].position.distance(self.nodes[b].position);
292        let idx = self.springs.len();
293        self.springs.push(Spring::new(a, b, rest, stiffness));
294        idx
295    }
296
297    /// Pin a node (fix it in space).
298    pub fn pin(&mut self, node: usize) {
299        if let Some(n) = self.nodes.get_mut(node) {
300            n.pinned = true;
301            n.inv_mass = 0.0;
302        }
303    }
304
305    /// Unpin a node.
306    pub fn unpin(&mut self, node: usize) {
307        if let Some(n) = self.nodes.get_mut(node) {
308            n.pinned = false;
309            n.inv_mass = if n.mass > 0.0 { 1.0 / n.mass } else { 0.0 };
310        }
311    }
312
313    /// Apply an impulse to a node.
314    pub fn apply_impulse(&mut self, node: usize, impulse: Vec2) {
315        if let Some(n) = self.nodes.get_mut(node) {
316            if !n.pinned {
317                n.velocity += impulse * n.inv_mass;
318            }
319        }
320    }
321
322    /// Apply force to all nodes in a radius.
323    pub fn apply_force_radius(&mut self, origin: Vec2, radius: f32, force: Vec2) {
324        for n in &mut self.nodes {
325            if n.pinned { continue; }
326            let d = n.position.distance(origin);
327            if d < radius {
328                let factor = 1.0 - d / radius;
329                n.force += force * factor;
330            }
331        }
332    }
333
334    // ── Simulation ─────────────────────────────────────────────────────────────
335
336    /// Step the simulation by `dt` seconds.
337    pub fn step(&mut self, dt: f32, gravity: Vec2) {
338        self.accumulate_forces(gravity);
339        self.integrate(dt);
340        self.solve_constraints(dt);
341        self.clear_forces();
342        self.remove_broken_springs();
343    }
344
345    fn accumulate_forces(&mut self, gravity: Vec2) {
346        // Apply spring forces
347        let positions: Vec<Vec2> = self.nodes.iter().map(|n| n.position).collect();
348        let velocities: Vec<Vec2> = self.nodes.iter().map(|n| n.velocity).collect();
349
350        for spring in &mut self.springs {
351            if spring.broken { continue; }
352            let a = spring.a;
353            let b = spring.b;
354            let force = spring.compute_force(positions[a], positions[b],
355                                             velocities[a], velocities[b]);
356
357            // Check break condition
358            if spring.break_at > 0.0 {
359                let dist = positions[a].distance(positions[b]);
360                let stretch_ratio = (dist - spring.rest_length).abs() / spring.rest_length.max(1e-6);
361                if stretch_ratio > spring.break_at {
362                    spring.broken = true;
363                    continue;
364                }
365            }
366
367            // Accumulate into node forces (deferred — nodes are not mutably aliased here)
368            let _ = (a, b, force); // forces accumulated below
369        }
370
371        // Re-compute without borrow issue
372        let n = self.nodes.len();
373        let mut forces = vec![Vec2::ZERO; n];
374        let positions: Vec<Vec2> = self.nodes.iter().map(|nd| nd.position).collect();
375        let velocities: Vec<Vec2> = self.nodes.iter().map(|nd| nd.velocity).collect();
376
377        for spring in &self.springs {
378            if spring.broken { continue; }
379            let force = spring.compute_force(
380                positions[spring.a], positions[spring.b],
381                velocities[spring.a], velocities[spring.b],
382            );
383            forces[spring.a] += force;
384            forces[spring.b] -= force;
385        }
386
387        // Add gravity and accumulated forces
388        for (i, node) in self.nodes.iter_mut().enumerate() {
389            if node.pinned { continue; }
390            node.force += forces[i] + gravity * node.mass;
391        }
392    }
393
394    fn integrate(&mut self, dt: f32) {
395        for node in &mut self.nodes {
396            if node.pinned { continue; }
397            // Semi-implicit Euler
398            let acc = node.force * node.inv_mass;
399            node.velocity = (node.velocity + acc * dt) * self.damping;
400            node.position += node.velocity * dt;
401        }
402    }
403
404    fn solve_constraints(&mut self, _dt: f32) {
405        // Iterative position correction (XPBD-lite)
406        for _ in 0..self.iterations {
407            let positions: Vec<Vec2> = self.nodes.iter().map(|n| n.position).collect();
408            let inv_masses: Vec<f32> = self.nodes.iter().map(|n| n.inv_mass).collect();
409
410            let mut deltas = vec![Vec2::ZERO; self.nodes.len()];
411            let mut counts = vec![0u32; self.nodes.len()];
412
413            for spring in &self.springs {
414                if spring.broken { continue; }
415                let pa = positions[spring.a];
416                let pb = positions[spring.b];
417                let delta = pb - pa;
418                let dist = delta.length();
419                if dist < 1e-6 { continue; }
420                let error = dist - spring.rest_length;
421                let dir = delta / dist;
422                let w_a = inv_masses[spring.a];
423                let w_b = inv_masses[spring.b];
424                let w_sum = w_a + w_b;
425                if w_sum < 1e-10 { continue; }
426                let correction = error / w_sum;
427                deltas[spring.a] += dir *  w_a * correction;
428                deltas[spring.b] -= dir *  w_b * correction;
429                counts[spring.a] += 1;
430                counts[spring.b] += 1;
431            }
432
433            // Apply average correction
434            for (i, node) in self.nodes.iter_mut().enumerate() {
435                if node.pinned || counts[i] == 0 { continue; }
436                node.position += deltas[i] / counts[i] as f32 * 0.5;
437            }
438        }
439    }
440
441    fn clear_forces(&mut self) {
442        for node in &mut self.nodes {
443            node.force = Vec2::ZERO;
444        }
445    }
446
447    fn remove_broken_springs(&mut self) {
448        self.springs.retain(|s| !s.broken);
449    }
450
451    // ── Collisions ─────────────────────────────────────────────────────────────
452
453    /// Resolve simple floor collision (y >= floor_y, normal = up).
454    pub fn resolve_floor(&mut self, floor_y: f32) {
455        for node in &mut self.nodes {
456            if node.pinned { continue; }
457            if node.position.y < floor_y {
458                node.position.y = floor_y;
459                if node.velocity.y < 0.0 {
460                    node.velocity.y = -node.velocity.y * self.restitution;
461                    node.velocity.x *= 0.9; // friction
462                }
463            }
464        }
465    }
466
467    /// Resolve ceiling collision.
468    pub fn resolve_ceiling(&mut self, ceiling_y: f32) {
469        for node in &mut self.nodes {
470            if node.pinned { continue; }
471            if node.position.y > ceiling_y {
472                node.position.y = ceiling_y;
473                if node.velocity.y > 0.0 {
474                    node.velocity.y = -node.velocity.y * self.restitution;
475                }
476            }
477        }
478    }
479
480    /// Resolve left wall.
481    pub fn resolve_wall_left(&mut self, x: f32) {
482        for node in &mut self.nodes {
483            if node.pinned { continue; }
484            if node.position.x < x {
485                node.position.x = x;
486                if node.velocity.x < 0.0 {
487                    node.velocity.x = -node.velocity.x * self.restitution;
488                }
489            }
490        }
491    }
492
493    /// Resolve right wall.
494    pub fn resolve_wall_right(&mut self, x: f32) {
495        for node in &mut self.nodes {
496            if node.pinned { continue; }
497            if node.position.x > x {
498                node.position.x = x;
499                if node.velocity.x > 0.0 {
500                    node.velocity.x = -node.velocity.x * self.restitution;
501                }
502            }
503        }
504    }
505
506    /// Push nodes out of a circle.
507    pub fn resolve_circle_obstacle(&mut self, center: Vec2, radius: f32) {
508        for node in &mut self.nodes {
509            if node.pinned { continue; }
510            let d = node.position - center;
511            let dist = d.length();
512            if dist < radius && dist > 1e-6 {
513                let pen = radius - dist;
514                let n = d / dist;
515                node.position += n * pen;
516                let vn = node.velocity.dot(n);
517                if vn < 0.0 {
518                    node.velocity -= n * vn * (1.0 + self.restitution);
519                }
520            }
521        }
522    }
523
524    // ── Queries ────────────────────────────────────────────────────────────────
525
526    /// Axis-aligned bounding box of all nodes.
527    pub fn aabb(&self) -> Option<(Vec2, Vec2)> {
528        if self.nodes.is_empty() { return None; }
529        let mut min = self.nodes[0].position;
530        let mut max = self.nodes[0].position;
531        for n in &self.nodes {
532            min = min.min(n.position);
533            max = max.max(n.position);
534        }
535        Some((min, max))
536    }
537
538    /// Centroid of all node positions.
539    pub fn centroid(&self) -> Vec2 {
540        if self.nodes.is_empty() { return Vec2::ZERO; }
541        let sum: Vec2 = self.nodes.iter().map(|n| n.position).sum();
542        sum / self.nodes.len() as f32
543    }
544
545    /// Total kinetic energy.
546    pub fn kinetic_energy(&self) -> f32 {
547        self.nodes.iter().map(|n| 0.5 * n.mass * n.velocity.length_squared()).sum()
548    }
549
550    /// Total potential energy from spring stretch.
551    pub fn spring_potential_energy(&self) -> f32 {
552        self.springs.iter().map(|s| {
553            if s.broken { return 0.0; }
554            let pa = self.nodes[s.a].position;
555            let pb = self.nodes[s.b].position;
556            let stretch = pa.distance(pb) - s.rest_length;
557            0.5 * s.stiffness * stretch * stretch
558        }).sum()
559    }
560
561    /// Number of active (non-broken) springs.
562    pub fn active_spring_count(&self) -> usize {
563        self.springs.iter().filter(|s| !s.broken).count()
564    }
565
566    /// Translate all nodes.
567    pub fn translate(&mut self, offset: Vec2) {
568        for n in &mut self.nodes {
569            n.position += offset;
570        }
571    }
572
573    /// Scale positions around centroid.
574    pub fn scale(&mut self, factor: f32) {
575        let c = self.centroid();
576        for n in &mut self.nodes {
577            n.position = c + (n.position - c) * factor;
578        }
579    }
580
581    /// Zero all velocities (instant freeze).
582    pub fn freeze(&mut self) {
583        for n in &mut self.nodes {
584            n.velocity = Vec2::ZERO;
585        }
586    }
587
588    /// Collect all edge node positions (ring periphery for convex hulls or rendering).
589    pub fn positions(&self) -> Vec<Vec2> {
590        self.nodes.iter().map(|n| n.position).collect()
591    }
592
593    /// Closest node index to a world position.
594    pub fn nearest_node(&self, point: Vec2) -> Option<usize> {
595        self.nodes.iter().enumerate().min_by(|(_, a), (_, b)| {
596            let da = a.position.distance_squared(point);
597            let db = b.position.distance_squared(point);
598            da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
599        }).map(|(i, _)| i)
600    }
601}
602
603impl Default for SoftBody {
604    fn default() -> Self { Self::new() }
605}
606
607// ── SoftBodyConstraint ────────────────────────────────────────────────────────
608
609/// Additional constraint types that can be layered onto a SoftBody.
610#[derive(Debug, Clone)]
611pub enum SoftConstraint {
612    /// Keep node at a fixed world position.
613    FixedPoint { node: usize, target: Vec2, stiffness: f32 },
614    /// Keep two nodes at a fixed distance.
615    DistanceFixed { a: usize, b: usize, length: f32 },
616    /// Keep node within a circle.
617    CircleBound { node: usize, center: Vec2, radius: f32 },
618}
619
620impl SoftConstraint {
621    pub fn apply(&self, body: &mut SoftBody) {
622        match self {
623            SoftConstraint::FixedPoint { node, target, stiffness } => {
624                if let Some(n) = body.nodes.get_mut(*node) {
625                    let err = *target - n.position;
626                    n.position += err * *stiffness;
627                }
628            }
629            SoftConstraint::DistanceFixed { a, b, length } => {
630                if *a < body.nodes.len() && *b < body.nodes.len() {
631                    let pa = body.nodes[*a].position;
632                    let pb = body.nodes[*b].position;
633                    let delta = pb - pa;
634                    let dist = delta.length();
635                    if dist < 1e-6 { return; }
636                    let error = dist - *length;
637                    let dir = delta / dist;
638                    let correction = dir * error * 0.5;
639                    let inv_a = body.nodes[*a].inv_mass;
640                    let inv_b = body.nodes[*b].inv_mass;
641                    let w = inv_a + inv_b;
642                    if w < 1e-10 { return; }
643                    if !body.nodes[*a].pinned {
644                        body.nodes[*a].position += correction * (inv_a / w);
645                    }
646                    if !body.nodes[*b].pinned {
647                        body.nodes[*b].position -= correction * (inv_b / w);
648                    }
649                }
650            }
651            SoftConstraint::CircleBound { node, center, radius } => {
652                if let Some(n) = body.nodes.get_mut(*node) {
653                    if n.pinned { return; }
654                    let d = n.position - *center;
655                    let dist = d.length();
656                    if dist > *radius {
657                        n.position = *center + d / dist * *radius;
658                        n.velocity = Vec2::ZERO;
659                    }
660                }
661            }
662        }
663    }
664}
665
666// ── Unit tests ─────────────────────────────────────────────────────────────────
667
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rope_creation() {
        let rope = SoftBody::rope(5, 4.0, 1.0, 100.0);
        assert_eq!(rope.nodes.len(), 5);
        assert_eq!(rope.springs.len(), 4);
    }

    #[test]
    fn test_grid_creation() {
        let cloth = SoftBody::grid(4, 4, 1.0);
        assert_eq!(cloth.nodes.len(), 4 * 4);
        // Structural, shear and bend springs together exceed 20 on a 4×4 grid.
        assert!(cloth.springs.len() > 20);
    }

    #[test]
    fn test_blob_creation() {
        let blob = SoftBody::blob(8, 1.0, 1.0, 200.0);
        assert_eq!(blob.nodes.len(), 8 + 1); // ring + center
        assert!(blob.springs.len() > 8);
    }

    #[test]
    fn test_pin_unpin() {
        let mut rope = SoftBody::rope(3, 2.0, 1.0, 100.0);

        rope.pin(0);
        let first = &rope.nodes[0];
        assert!(first.pinned);
        assert_eq!(first.inv_mass, 0.0);

        rope.unpin(0);
        let first = &rope.nodes[0];
        assert!(!first.pinned);
        assert!((first.inv_mass - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_step_gravity() {
        let mut rope = SoftBody::rope(3, 2.0, 1.0, 500.0);
        rope.pin(0);
        let start_y = rope.nodes[2].position.y;
        let gravity = Vec2::new(0.0, -9.8);
        (0..20).for_each(|_| rope.step(0.016, gravity));
        // The free end must have dropped.
        assert!(rope.nodes[2].position.y < start_y);
    }

    #[test]
    fn test_floor_collision() {
        let mut body = SoftBody::new();
        let i = body.add_node(Vec2::new(0.0, 1.0), 1.0);
        {
            let node = &mut body.nodes[i];
            node.velocity = Vec2::new(0.0, -5.0);
            node.force = Vec2::ZERO;
            // Place the node below the floor by hand.
            node.position.y = -0.5;
        }
        body.resolve_floor(0.0);
        assert!(body.nodes[i].position.y >= 0.0);
    }

    #[test]
    fn test_kinetic_energy() {
        let mut rope = SoftBody::rope(2, 1.0, 1.0, 100.0);
        rope.nodes[0].velocity = Vec2::new(1.0, 0.0);
        rope.nodes[1].velocity = Vec2::new(0.0, 1.0);
        // ½·1·1 + ½·1·1 = 1
        assert!((rope.kinetic_energy() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_centroid() {
        let mut body = SoftBody::new();
        for x in [-1.0, 1.0] {
            body.add_node(Vec2::new(x, 0.0), 1.0);
        }
        assert!(body.centroid().x.abs() < 1e-6);
    }

    #[test]
    fn test_nearest_node() {
        let mut body = SoftBody::new();
        body.add_node(Vec2::new(0.0, 0.0), 1.0);
        body.add_node(Vec2::new(5.0, 0.0), 1.0);
        assert_eq!(body.nearest_node(Vec2::new(0.1, 0.0)), Some(0));
        assert_eq!(body.nearest_node(Vec2::new(4.9, 0.0)), Some(1));
    }

    #[test]
    fn test_translate() {
        let mut rope = SoftBody::rope(2, 1.0, 1.0, 100.0);
        let before_x = rope.nodes[0].position.x;
        rope.translate(Vec2::new(3.0, 0.0));
        assert!((rope.nodes[0].position.x - (before_x + 3.0)).abs() < 1e-6);
    }

    #[test]
    fn test_spring_potential_energy() {
        let mut body = SoftBody::new();
        let a = body.add_node(Vec2::new(0.0, 0.0), 1.0);
        let b = body.add_node(Vec2::new(2.0, 0.0), 1.0);
        body.add_spring(a, b, 100.0);
        // At rest length (2.0) the spring stores no energy.
        assert!(body.spring_potential_energy() < 1e-6);
        // Stretch by 1.0: PE = ½·100·1² = 50.
        body.nodes[b].position.x = 3.0;
        assert!((body.spring_potential_energy() - 50.0).abs() < 1e-4);
    }

    #[test]
    fn test_fixed_point_constraint() {
        let mut body = SoftBody::new();
        body.add_node(Vec2::new(5.0, 0.0), 1.0);
        SoftConstraint::FixedPoint {
            node: 0,
            target: Vec2::new(0.0, 0.0),
            stiffness: 1.0,
        }
        .apply(&mut body);
        assert!(body.nodes[0].position.x.abs() < 1e-6);
    }
}