Skip to main content

oxiphysics_gpu/
gpu_cloth.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! GPU-accelerated cloth simulation (CPU mock implementation).
6//!
7//! Provides a complete cloth simulation pipeline including spring–damper edges,
8//! bending constraints, XPBD position solving, and collision response against
9//! analytic primitives (sphere, plane).
10
11use crate::{cross3, dot3, length3, normalize3};
12
13// ---------------------------------------------------------------------------
14// Helper math
15// ---------------------------------------------------------------------------
16
/// Component-wise difference `a - b` of two 3-vectors.
#[inline]
fn sub3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ax - bx, ay - by, az - bz]
}
21
/// Component-wise sum `a + b` of two 3-vectors.
#[inline]
fn add3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ax + bx, ay + by, az + bz]
}
26
/// Multiply every component of the 3-vector `v` by the scalar `s`.
#[inline]
fn scale3(v: [f64; 3], s: f64) -> [f64; 3] {
    v.map(|component| component * s)
}
31
32// ---------------------------------------------------------------------------
33// triangle_area
34// ---------------------------------------------------------------------------
35
36/// Compute the area of a triangle defined by three 3-D vertices.
37///
38/// Uses the cross-product formula: `area = ||(v1-v0) × (v2-v0)|| / 2`.
39pub fn triangle_area(v0: [f64; 3], v1: [f64; 3], v2: [f64; 3]) -> f64 {
40    let e1 = sub3(v1, v0);
41    let e2 = sub3(v2, v0);
42    length3(cross3(e1, e2)) * 0.5
43}
44
45// ---------------------------------------------------------------------------
46// dihedral_angle
47// ---------------------------------------------------------------------------
48
49/// Compute the dihedral angle (in radians) between two triangles sharing edge
50/// `(v1, v2)`.
51///
52/// `v0` and `v3` are the "wing" vertices of each triangle respectively.
53/// Returns a value in `[0, π]`.
54pub fn dihedral_angle(v0: [f64; 3], v1: [f64; 3], v2: [f64; 3], v3: [f64; 3]) -> f64 {
55    let e = sub3(v2, v1);
56    let n1 = cross3(sub3(v0, v1), e);
57    let n2 = cross3(e, sub3(v3, v1));
58    let len1 = length3(n1);
59    let len2 = length3(n2);
60    if len1 < 1e-15 || len2 < 1e-15 {
61        return 0.0;
62    }
63    let cos_a = dot3(n1, n2) / (len1 * len2);
64    cos_a.clamp(-1.0, 1.0).acos()
65}
66
67// ---------------------------------------------------------------------------
68// ClothVertex
69// ---------------------------------------------------------------------------
70
/// One particle of the cloth mesh.
#[derive(Debug, Clone)]
pub struct ClothVertex {
    /// World-space position `[x, y, z]`.
    pub position: [f64; 3],
    /// World-space velocity `[vx, vy, vz]`.
    pub velocity: [f64; 3],
    /// Particle mass (kg).
    pub mass: f64,
    /// When `true` the vertex is held fixed and skipped by the integrators.
    pub pinned: bool,
}

impl ClothVertex {
    /// Build a vertex located at `position` with the given `mass` and pin
    /// state; its velocity starts at zero.
    pub fn new(position: [f64; 3], mass: f64, pinned: bool) -> Self {
        let at_rest = [0.0, 0.0, 0.0];
        ClothVertex {
            velocity: at_rest,
            position,
            mass,
            pinned,
        }
    }
}
95
96// ---------------------------------------------------------------------------
97// ClothEdge
98// ---------------------------------------------------------------------------
99
100/// A structural spring edge connecting two cloth vertices.
101#[derive(Debug, Clone)]
102pub struct ClothEdge {
103    /// Index of the first vertex.
104    pub v0: usize,
105    /// Index of the second vertex.
106    pub v1: usize,
107    /// Rest length of the spring (m).
108    pub rest_length: f64,
109    /// Spring stiffness (N/m).
110    pub stiffness: f64,
111    /// Damping coefficient (N·s/m).
112    pub damping: f64,
113}
114
115impl ClothEdge {
116    /// Create a new edge.
117    pub fn new(v0: usize, v1: usize, rest_length: f64, stiffness: f64, damping: f64) -> Self {
118        Self {
119            v0,
120            v1,
121            rest_length,
122            stiffness,
123            damping,
124        }
125    }
126
127    /// Compute the spring force acting **on** vertex `v0` (force on `v1` is negated).
128    ///
129    /// Includes both Hooke's law and linear damping.
130    pub fn spring_force(&self, verts: &[ClothVertex]) -> [f64; 3] {
131        let p0 = verts[self.v0].position;
132        let p1 = verts[self.v1].position;
133        let vel0 = verts[self.v0].velocity;
134        let vel1 = verts[self.v1].velocity;
135
136        let delta = sub3(p1, p0);
137        let dist = length3(delta);
138        if dist < 1e-15 {
139            return [0.0; 3];
140        }
141        let dir = scale3(delta, 1.0 / dist);
142
143        // Relative velocity projected onto the edge direction
144        let rel_vel = dot3(sub3(vel1, vel0), dir);
145
146        let spring_f = self.stiffness * (dist - self.rest_length);
147        let damp_f = self.damping * rel_vel;
148        scale3(dir, spring_f + damp_f)
149    }
150}
151
152// ---------------------------------------------------------------------------
153// BendingConstraint
154// ---------------------------------------------------------------------------
155
156/// A bending constraint connecting four vertices that share a common edge.
157///
158/// The constraint resists deviation from the `rest_angle` dihedral angle.
159#[derive(Debug, Clone)]
160pub struct BendingConstraint {
161    /// First wing vertex index.
162    pub v0: usize,
163    /// Edge vertex index (first).
164    pub v1: usize,
165    /// Edge vertex index (second).
166    pub v2: usize,
167    /// Second wing vertex index.
168    pub v3: usize,
169    /// Rest dihedral angle (radians).
170    pub rest_angle: f64,
171    /// Bending stiffness (N·m/rad).
172    pub stiffness: f64,
173}
174
175impl BendingConstraint {
176    /// Create a new bending constraint.
177    pub fn new(
178        v0: usize,
179        v1: usize,
180        v2: usize,
181        v3: usize,
182        rest_angle: f64,
183        stiffness: f64,
184    ) -> Self {
185        Self {
186            v0,
187            v1,
188            v2,
189            v3,
190            rest_angle,
191            stiffness,
192        }
193    }
194
195    /// Compute restoring forces on the four vertices.
196    ///
197    /// Returns `[f_v0, f_v1, f_v2, f_v3]` — one `[f64;3]` force per vertex.
198    pub fn compute_force(&self, verts: &[ClothVertex]) -> Vec<[f64; 3]> {
199        let p0 = verts[self.v0].position;
200        let p1 = verts[self.v1].position;
201        let p2 = verts[self.v2].position;
202        let p3 = verts[self.v3].position;
203
204        let current_angle = dihedral_angle(p0, p1, p2, p3);
205        let angle_diff = current_angle - self.rest_angle;
206
207        if angle_diff.abs() < 1e-12 {
208            return vec![[0.0; 3]; 4];
209        }
210
211        // Gradient magnitude
212        let mag = self.stiffness * angle_diff;
213
214        // Edge and normals
215        let e = sub3(p2, p1);
216        let e_len = length3(e);
217        if e_len < 1e-15 {
218            return vec![[0.0; 3]; 4];
219        }
220
221        let n1 = cross3(sub3(p0, p1), e);
222        let n2 = cross3(e, sub3(p3, p1));
223        let len1 = length3(n1);
224        let len2 = length3(n2);
225        if len1 < 1e-15 || len2 < 1e-15 {
226            return vec![[0.0; 3]; 4];
227        }
228
229        // Gradient of the dihedral angle w.r.t. wing vertices
230        let g0 = scale3(normalize3(n1), -mag / len1);
231        let g3 = scale3(normalize3(n2), -mag / len2);
232
233        // Distribute to edge vertices proportionally
234        let g1 = scale3(add3(g0, g3), -0.5);
235        let g2 = scale3(add3(g0, g3), -0.5);
236
237        vec![g0, g1, g2, g3]
238    }
239}
240
241// ---------------------------------------------------------------------------
242// ClothMesh
243// ---------------------------------------------------------------------------
244
245/// The cloth mesh consisting of vertices and spring edges.
246#[derive(Debug, Clone)]
247pub struct ClothMesh {
248    /// Mesh vertices (particles).
249    pub vertices: Vec<ClothVertex>,
250    /// Spring edges.
251    pub edges: Vec<ClothEdge>,
252}
253
254impl ClothMesh {
255    /// Create an empty cloth mesh.
256    pub fn new() -> Self {
257        Self {
258            vertices: Vec::new(),
259            edges: Vec::new(),
260        }
261    }
262
263    /// Build a regular rectangular grid of `rows × cols` vertices with given
264    /// `spacing` (m) between adjacent vertices.
265    ///
266    /// The grid is laid out in the XZ plane (Y = 0).  The top row (`row == 0`)
267    /// is pinned.
268    pub fn build_grid(&mut self, rows: usize, cols: usize, spacing: f64) {
269        self.vertices.clear();
270        self.edges.clear();
271
272        // Create vertices
273        for r in 0..rows {
274            for c in 0..cols {
275                let pos = [c as f64 * spacing, 0.0, r as f64 * spacing];
276                let pinned = r == 0;
277                self.vertices.push(ClothVertex::new(pos, 1.0, pinned));
278            }
279        }
280
281        let idx = |r: usize, c: usize| r * cols + c;
282
283        // Structural edges (horizontal + vertical)
284        for r in 0..rows {
285            for c in 0..cols {
286                if c + 1 < cols {
287                    self.edges.push(ClothEdge::new(
288                        idx(r, c),
289                        idx(r, c + 1),
290                        spacing,
291                        1000.0,
292                        0.5,
293                    ));
294                }
295                if r + 1 < rows {
296                    self.edges.push(ClothEdge::new(
297                        idx(r, c),
298                        idx(r + 1, c),
299                        spacing,
300                        1000.0,
301                        0.5,
302                    ));
303                }
304            }
305        }
306
307        // Shear edges (diagonals)
308        let diag = spacing * std::f64::consts::SQRT_2;
309        for r in 0..rows {
310            for c in 0..cols {
311                if r + 1 < rows && c + 1 < cols {
312                    self.edges.push(ClothEdge::new(
313                        idx(r, c),
314                        idx(r + 1, c + 1),
315                        diag,
316                        500.0,
317                        0.2,
318                    ));
319                    self.edges.push(ClothEdge::new(
320                        idx(r + 1, c),
321                        idx(r, c + 1),
322                        diag,
323                        500.0,
324                        0.2,
325                    ));
326                }
327            }
328        }
329    }
330
331    /// Advance the simulation by one time step `dt` (seconds) with gravity `[gx, gy, gz]`.
332    ///
333    /// Uses symplectic Euler integration with explicit spring forces.
334    pub fn step(&mut self, dt: f64, gravity: [f64; 3]) {
335        let n = self.vertices.len();
336        let mut forces = vec![[0.0f64; 3]; n];
337
338        // Gravity
339        for (i, v) in self.vertices.iter().enumerate() {
340            if !v.pinned {
341                forces[i] = add3(forces[i], scale3(gravity, v.mass));
342            }
343        }
344
345        // Spring forces
346        for edge in &self.edges {
347            let f = edge.spring_force(&self.vertices);
348            if !self.vertices[edge.v0].pinned {
349                forces[edge.v0] = add3(forces[edge.v0], f);
350            }
351            if !self.vertices[edge.v1].pinned {
352                forces[edge.v1] = sub3(forces[edge.v1], f);
353            }
354        }
355
356        // Integrate
357        for (i, v) in self.vertices.iter_mut().enumerate() {
358            if v.pinned {
359                continue;
360            }
361            let inv_m = 1.0 / v.mass;
362            let accel = scale3(forces[i], inv_m);
363            v.velocity = add3(v.velocity, scale3(accel, dt));
364            v.position = add3(v.position, scale3(v.velocity, dt));
365        }
366    }
367}
368
369impl Default for ClothMesh {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
375// ---------------------------------------------------------------------------
376// ClothCollider
377// ---------------------------------------------------------------------------
378
379/// An analytic collision primitive for cloth–object interaction.
380#[derive(Debug, Clone)]
381pub enum ClothCollider {
382    /// Sphere collider.
383    Sphere {
384        /// Center of the sphere.
385        center: [f64; 3],
386        /// Radius of the sphere.
387        radius: f64,
388    },
389    /// Half-space plane collider.  Points on the side `dot(p, normal) >= d`
390    /// are outside.
391    Plane {
392        /// Outward-facing unit normal.
393        normal: [f64; 3],
394        /// Signed distance from the origin to the plane along the normal.
395        d: f64,
396    },
397}
398
399impl ClothCollider {
400    /// Test whether a point `p` is inside (penetrating) this collider.
401    ///
402    /// Returns the penetration depth (positive means inside) and the outward
403    /// push direction.  Returns `None` if there is no penetration.
404    pub fn penetration(&self, p: [f64; 3]) -> Option<(f64, [f64; 3])> {
405        match self {
406            ClothCollider::Sphere { center, radius } => {
407                let delta = sub3(p, *center);
408                let dist = length3(delta);
409                if dist < *radius {
410                    let depth = radius - dist;
411                    let dir = if dist < 1e-15 {
412                        [0.0, 1.0, 0.0]
413                    } else {
414                        scale3(delta, 1.0 / dist)
415                    };
416                    Some((depth, dir))
417                } else {
418                    None
419                }
420            }
421            ClothCollider::Plane { normal, d } => {
422                let signed = dot3(*normal, p) - d;
423                if signed < 0.0 {
424                    Some((-signed, *normal))
425                } else {
426                    None
427                }
428            }
429        }
430    }
431}
432
433// ---------------------------------------------------------------------------
434// GpuClothSolver
435// ---------------------------------------------------------------------------
436
437/// XPBD-style cloth solver (CPU mock implementation).
438///
439/// Wraps a `ClothMesh` together with a set of collision primitives and runs
440/// position-based dynamic iterations per time step.
441#[derive(Debug, Clone)]
442pub struct GpuClothSolver {
443    /// The cloth mesh being simulated.
444    pub mesh: ClothMesh,
445    /// Collision primitives.
446    pub colliders: Vec<ClothCollider>,
447    /// Number of XPBD constraint-projection iterations per time step.
448    pub xpbd_iterations: usize,
449}
450
451impl GpuClothSolver {
452    /// Create a new solver with the given mesh.
453    pub fn new(mesh: ClothMesh) -> Self {
454        Self {
455            mesh,
456            colliders: Vec::new(),
457            xpbd_iterations: 8,
458        }
459    }
460
461    /// Add a collision primitive to the solver.
462    pub fn add_collider(&mut self, collider: ClothCollider) {
463        self.colliders.push(collider);
464    }
465
466    /// Advance the simulation by `dt` seconds.
467    ///
468    /// 1. Integrates velocities and positions with gravity.
469    /// 2. Projects spring constraints (`xpbd_iterations` times).
470    /// 3. Resolves collisions.
471    pub fn solve(&mut self, dt: f64) {
472        let gravity = [0.0, -9.81, 0.0];
473
474        // --- 1. Semi-implicit Euler predict ---
475        let n = self.mesh.vertices.len();
476        let mut pred_pos: Vec<[f64; 3]> = self
477            .mesh
478            .vertices
479            .iter()
480            .map(|v| {
481                if v.pinned {
482                    v.position
483                } else {
484                    add3(
485                        v.position,
486                        scale3(add3(v.velocity, scale3(gravity, dt)), dt),
487                    )
488                }
489            })
490            .collect();
491
492        // --- 2. XPBD constraint projection ---
493        for _ in 0..self.xpbd_iterations {
494            for edge in &self.mesh.edges {
495                let i = edge.v0;
496                let j = edge.v1;
497                let pi = pred_pos[i];
498                let pj = pred_pos[j];
499                let delta = sub3(pj, pi);
500                let dist = length3(delta);
501                if dist < 1e-15 {
502                    continue;
503                }
504                let constraint = dist - edge.rest_length;
505                let dir = scale3(delta, 1.0 / dist);
506
507                let wi = if self.mesh.vertices[i].pinned {
508                    0.0
509                } else {
510                    1.0 / self.mesh.vertices[i].mass
511                };
512                let wj = if self.mesh.vertices[j].pinned {
513                    0.0
514                } else {
515                    1.0 / self.mesh.vertices[j].mass
516                };
517                let w_total = wi + wj;
518                if w_total < 1e-15 {
519                    continue;
520                }
521
522                let alpha = 1.0 / (edge.stiffness * dt * dt);
523                let lambda = -constraint / (w_total + alpha);
524
525                if !self.mesh.vertices[i].pinned {
526                    pred_pos[i] = sub3(pred_pos[i], scale3(dir, wi * lambda));
527                }
528                if !self.mesh.vertices[j].pinned {
529                    pred_pos[j] = add3(pred_pos[j], scale3(dir, wj * lambda));
530                }
531            }
532        }
533
534        // --- 3. Collision resolution ---
535        for collider in &self.colliders {
536            for i in 0..n {
537                if self.mesh.vertices[i].pinned {
538                    continue;
539                }
540                if let Some((depth, dir)) = collider.penetration(pred_pos[i]) {
541                    pred_pos[i] = add3(pred_pos[i], scale3(dir, depth));
542                }
543            }
544        }
545
546        // --- 4. Update velocities and positions ---
547        for i in 0..n {
548            if !self.mesh.vertices[i].pinned {
549                let old_pos = self.mesh.vertices[i].position;
550                self.mesh.vertices[i].velocity = scale3(sub3(pred_pos[i], old_pos), 1.0 / dt);
551                self.mesh.vertices[i].position = pred_pos[i];
552            }
553        }
554    }
555}
556
557// ---------------------------------------------------------------------------
558// Tests
559// ---------------------------------------------------------------------------
560
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Build a fresh `rows × cols` grid mesh with the given vertex spacing.
    fn grid(rows: usize, cols: usize, spacing: f64) -> ClothMesh {
        let mut mesh = ClothMesh::new();
        mesh.build_grid(rows, cols, spacing);
        mesh
    }

    // --- triangle_area ---

    #[test]
    fn test_triangle_area_unit() {
        let v0 = [0.0, 0.0, 0.0];
        let v1 = [1.0, 0.0, 0.0];
        let v2 = [0.0, 1.0, 0.0];
        let area = triangle_area(v0, v1, v2);
        assert!((area - 0.5).abs() < 1e-12, "area={area}");
    }

    #[test]
    fn test_triangle_area_degenerate() {
        // Collinear points → area == 0
        let v0 = [0.0, 0.0, 0.0];
        let v1 = [1.0, 0.0, 0.0];
        let v2 = [2.0, 0.0, 0.0];
        assert!(triangle_area(v0, v1, v2) < 1e-12);
    }

    #[test]
    fn test_triangle_area_equilateral() {
        // Side = 2 → area = sqrt(3)
        let v0 = [0.0, 0.0, 0.0];
        let v1 = [2.0, 0.0, 0.0];
        let v2 = [1.0, f64::sqrt(3.0), 0.0];
        let expected = f64::sqrt(3.0);
        assert!((triangle_area(v0, v1, v2) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_triangle_area_3d() {
        // Right triangle in 3-D: legs along X and Z of length 1
        let v0 = [0.0, 0.0, 0.0];
        let v1 = [1.0, 0.0, 0.0];
        let v2 = [0.0, 0.0, 1.0];
        let area = triangle_area(v0, v1, v2);
        assert!((area - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_triangle_area_large() {
        let v0 = [0.0, 0.0, 0.0];
        let v1 = [10.0, 0.0, 0.0];
        let v2 = [0.0, 10.0, 0.0];
        let area = triangle_area(v0, v1, v2);
        assert!((area - 50.0).abs() < 1e-10);
    }

    // --- dihedral_angle ---

    #[test]
    fn test_dihedral_angle_flat() {
        // Two triangles coplanar in XZ → angle = 0
        let v0 = [0.0, 0.0, -1.0];
        let v1 = [-1.0, 0.0, 0.0];
        let v2 = [1.0, 0.0, 0.0];
        let v3 = [0.0, 0.0, 1.0];
        let angle = dihedral_angle(v0, v1, v2, v3);
        assert!(angle < 1e-10, "angle={angle}");
    }

    #[test]
    fn test_dihedral_angle_ninety_degrees() {
        // First triangle lies in the XY plane, second in the XZ plane, so the
        // two normals are exactly perpendicular.
        let v0 = [0.0, 1.0, 0.0];
        let v1 = [0.0, 0.0, 0.0];
        let v2 = [1.0, 0.0, 0.0];
        let v3 = [0.5, 0.0, 1.0];
        let angle = dihedral_angle(v0, v1, v2, v3);
        assert!((angle - PI / 2.0).abs() < 1e-12, "angle={angle}");
    }

    #[test]
    fn test_dihedral_angle_degenerate_edge() {
        // v1 == v2 → degenerate, should not panic
        let v0 = [0.0, 1.0, 0.0];
        let v1 = [0.0, 0.0, 0.0];
        let v2 = [0.0, 0.0, 0.0];
        let v3 = [0.0, -1.0, 0.0];
        let angle = dihedral_angle(v0, v1, v2, v3);
        assert!(angle.is_finite());
    }

    #[test]
    fn test_dihedral_angle_range() {
        let v0 = [1.0, 1.0, 0.0];
        let v1 = [0.0, 0.0, 0.0];
        let v2 = [1.0, 0.0, 0.0];
        let v3 = [1.0, -1.0, 0.0];
        let angle = dihedral_angle(v0, v1, v2, v3);
        assert!((0.0..=PI + 1e-10).contains(&angle), "angle={angle}");
    }

    // --- ClothVertex ---

    #[test]
    fn test_cloth_vertex_new() {
        let v = ClothVertex::new([1.0, 2.0, 3.0], 2.5, false);
        assert_eq!(v.position, [1.0, 2.0, 3.0]);
        assert_eq!(v.velocity, [0.0; 3]);
        assert!((v.mass - 2.5).abs() < 1e-12);
        assert!(!v.pinned);
    }

    #[test]
    fn test_cloth_vertex_pinned() {
        let v = ClothVertex::new([0.0; 3], 1.0, true);
        assert!(v.pinned);
    }

    // --- ClothEdge spring force ---

    #[test]
    fn test_spring_force_at_rest() {
        let verts = vec![
            ClothVertex::new([0.0, 0.0, 0.0], 1.0, false),
            ClothVertex::new([1.0, 0.0, 0.0], 1.0, false),
        ];
        let edge = ClothEdge::new(0, 1, 1.0, 1000.0, 0.5);
        let f = edge.spring_force(&verts);
        // At rest length, spring force is zero (and zero relative velocity)
        assert!(length3(f) < 1e-10, "f={f:?}");
    }

    #[test]
    fn test_spring_force_stretched() {
        let verts = vec![
            ClothVertex::new([0.0, 0.0, 0.0], 1.0, false),
            ClothVertex::new([2.0, 0.0, 0.0], 1.0, false),
        ];
        // rest = 1, stretched to 2 → force pulls v0 toward v1
        let edge = ClothEdge::new(0, 1, 1.0, 1000.0, 0.0);
        let f = edge.spring_force(&verts);
        assert!(f[0] > 0.0, "force should be positive (toward v1)");
        assert!(f[1].abs() < 1e-12);
        assert!(f[2].abs() < 1e-12);
    }

    #[test]
    fn test_spring_force_compressed() {
        let verts = vec![
            ClothVertex::new([0.0, 0.0, 0.0], 1.0, false),
            ClothVertex::new([0.5, 0.0, 0.0], 1.0, false),
        ];
        // rest = 1, compressed to 0.5 → force pushes v0 away
        let edge = ClothEdge::new(0, 1, 1.0, 1000.0, 0.0);
        let f = edge.spring_force(&verts);
        assert!(f[0] < 0.0, "force should be negative (push back)");
    }

    #[test]
    fn test_spring_force_with_damping() {
        let mut v0 = ClothVertex::new([0.0, 0.0, 0.0], 1.0, false);
        let mut v1 = ClothVertex::new([2.0, 0.0, 0.0], 1.0, false);
        v0.velocity = [-1.0, 0.0, 0.0];
        v1.velocity = [1.0, 0.0, 0.0];
        let verts = vec![v0, v1];
        let edge = ClothEdge::new(0, 1, 1.0, 0.0, 10.0); // only damping
        let f = edge.spring_force(&verts);
        // Relative velocity along edge: (1 - (-1)) = 2, damping force = 10 * 2 = 20
        assert!((f[0] - 20.0).abs() < 1e-10, "f={f:?}");
    }

    #[test]
    fn test_spring_force_zero_length() {
        // Both vertices at same position → no force (no panic)
        let verts = vec![
            ClothVertex::new([0.0, 0.0, 0.0], 1.0, false),
            ClothVertex::new([0.0, 0.0, 0.0], 1.0, false),
        ];
        let edge = ClothEdge::new(0, 1, 1.0, 1000.0, 0.5);
        let f = edge.spring_force(&verts);
        assert!(length3(f) < 1e-12);
    }

    // --- ClothMesh ---

    #[test]
    fn test_build_grid_vertex_count() {
        let mesh = grid(4, 5, 0.1);
        assert_eq!(mesh.vertices.len(), 20);
    }

    #[test]
    fn test_build_grid_top_row_pinned() {
        let mesh = grid(4, 5, 0.1);
        for c in 0..5 {
            assert!(
                mesh.vertices[c].pinned,
                "vertex {c} in row 0 should be pinned"
            );
        }
        for r in 1..4 {
            for c in 0..5 {
                assert!(!mesh.vertices[r * 5 + c].pinned);
            }
        }
    }

    #[test]
    fn test_build_grid_spacing() {
        let mesh = grid(2, 2, 0.5);
        // Horizontal neighbor distance
        let d = sub3(mesh.vertices[1].position, mesh.vertices[0].position);
        assert!((length3(d) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_cloth_grid_has_edges() {
        let mesh = grid(3, 3, 0.5);
        assert!(!mesh.edges.is_empty());
    }

    #[test]
    fn test_mesh_step_gravity() {
        let mut mesh = grid(2, 1, 1.0);
        // vertex 1 is not pinned
        let y_before = mesh.vertices[1].position[1];
        mesh.step(0.01, [0.0, -9.81, 0.0]);
        let y_after = mesh.vertices[1].position[1];
        assert!(y_after < y_before, "unpinned vertex should fall");
    }

    #[test]
    fn test_mesh_step_pinned_unchanged() {
        let mut mesh = grid(2, 1, 1.0);
        let pos_before = mesh.vertices[0].position;
        mesh.step(0.01, [0.0, -9.81, 0.0]);
        assert_eq!(mesh.vertices[0].position, pos_before);
    }

    #[test]
    fn test_mesh_default() {
        let mesh = ClothMesh::default();
        assert!(mesh.vertices.is_empty());
        assert!(mesh.edges.is_empty());
    }

    // --- BendingConstraint ---

    #[test]
    fn test_bending_at_rest_angle() {
        // Compute rest angle, then check zero force
        let p0 = [0.0, 1.0, 0.0];
        let p1 = [0.0, 0.0, 0.0];
        let p2 = [1.0, 0.0, 0.0];
        let p3 = [1.0, 0.0, -1.0];

        let rest = dihedral_angle(p0, p1, p2, p3);

        let verts = vec![
            ClothVertex::new(p0, 1.0, false),
            ClothVertex::new(p1, 1.0, false),
            ClothVertex::new(p2, 1.0, false),
            ClothVertex::new(p3, 1.0, false),
        ];

        let bc = BendingConstraint::new(0, 1, 2, 3, rest, 100.0);
        let forces = bc.compute_force(&verts);
        for f in &forces {
            assert!(length3(*f) < 1e-8, "force should be ~zero at rest");
        }
    }

    #[test]
    fn test_bending_constraint_forces_len() {
        let verts: Vec<ClothVertex> = (0..4)
            .map(|i| ClothVertex::new([i as f64, 0.0, 0.0], 1.0, false))
            .collect();
        let bc = BendingConstraint::new(0, 1, 2, 3, 0.0, 10.0);
        let forces = bc.compute_force(&verts);
        assert_eq!(forces.len(), 4);
    }

    #[test]
    fn test_dihedral_symmetric() {
        // Swapping the two wing vertices should give the same angle
        let p0 = [0.0, 1.0, 0.0];
        let p1 = [-1.0, 0.0, 0.0];
        let p2 = [1.0, 0.0, 0.0];
        let p3 = [0.0, -1.0, 0.5];
        let a1 = dihedral_angle(p0, p1, p2, p3);
        let a2 = dihedral_angle(p3, p1, p2, p0);
        assert!((a1 - a2).abs() < 1e-10, "a1={a1} a2={a2}");
    }

    // --- ClothCollider ---

    #[test]
    fn test_sphere_collider_inside() {
        let col = ClothCollider::Sphere {
            center: [0.0; 3],
            radius: 1.0,
        };
        let (depth, _dir) = col.penetration([0.5, 0.0, 0.0]).unwrap();
        assert!((depth - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_sphere_collider_outside() {
        let col = ClothCollider::Sphere {
            center: [0.0; 3],
            radius: 1.0,
        };
        assert!(col.penetration([2.0, 0.0, 0.0]).is_none());
    }

    #[test]
    fn test_sphere_collider_direction() {
        let col = ClothCollider::Sphere {
            center: [0.0; 3],
            radius: 2.0,
        };
        let (_depth, dir) = col.penetration([1.0, 0.0, 0.0]).unwrap();
        assert!((dir[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_sphere_collider_center() {
        // Point at center → degenerate, should not panic
        let col = ClothCollider::Sphere {
            center: [0.0; 3],
            radius: 1.0,
        };
        let result = col.penetration([0.0; 3]);
        assert!(result.is_some());
    }

    #[test]
    fn test_plane_collider_below() {
        // Plane y=0, normal=(0,1,0), d=0 → point at y=-0.5 is below
        let col = ClothCollider::Plane {
            normal: [0.0, 1.0, 0.0],
            d: 0.0,
        };
        let (depth, dir) = col.penetration([0.0, -0.5, 0.0]).unwrap();
        assert!((depth - 0.5).abs() < 1e-10);
        assert!((dir[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_plane_collider_above() {
        let col = ClothCollider::Plane {
            normal: [0.0, 1.0, 0.0],
            d: 0.0,
        };
        assert!(col.penetration([0.0, 1.0, 0.0]).is_none());
    }

    // --- GpuClothSolver ---

    #[test]
    fn test_solver_new() {
        let solver = GpuClothSolver::new(ClothMesh::new());
        assert_eq!(solver.xpbd_iterations, 8);
        assert!(solver.colliders.is_empty());
    }

    #[test]
    fn test_solver_add_collider() {
        let mut solver = GpuClothSolver::new(ClothMesh::new());
        solver.add_collider(ClothCollider::Plane {
            normal: [0.0, 1.0, 0.0],
            d: -1.0,
        });
        assert_eq!(solver.colliders.len(), 1);
    }

    #[test]
    fn test_solver_solve_no_penetration() {
        let mut solver = GpuClothSolver::new(grid(2, 2, 0.5));
        // Ground plane well below
        solver.add_collider(ClothCollider::Plane {
            normal: [0.0, 1.0, 0.0],
            d: -10.0,
        });
        solver.solve(0.01);
        // Unpinned vertices should still be above ground
        for v in &solver.mesh.vertices {
            if !v.pinned {
                assert!(v.position[1] > -10.0);
            }
        }
    }

    #[test]
    fn test_solver_sphere_prevents_penetration() {
        let mut mesh = grid(2, 2, 0.1);
        // Move free vertices inside a sphere
        for v in mesh.vertices.iter_mut() {
            if !v.pinned {
                v.position = [0.0, 0.0, 0.0];
            }
        }
        let mut solver = GpuClothSolver::new(mesh);
        solver.add_collider(ClothCollider::Sphere {
            center: [0.0; 3],
            radius: 5.0,
        });
        solver.solve(0.01);
        // All unpinned verts should now be at distance >= radius from center
        for v in &solver.mesh.vertices {
            if !v.pinned {
                let dist = length3(v.position);
                assert!(dist >= 5.0 - 1e-6, "dist={dist}");
            }
        }
    }

    #[test]
    fn test_solver_pinned_stays() {
        let mesh = grid(3, 3, 0.5);
        let pinned_positions = |m: &ClothMesh| -> Vec<[f64; 3]> {
            m.vertices
                .iter()
                .filter(|v| v.pinned)
                .map(|v| v.position)
                .collect()
        };
        let pin_pos = pinned_positions(&mesh);

        let mut solver = GpuClothSolver::new(mesh);
        for _ in 0..10 {
            solver.solve(0.01);
        }

        assert_eq!(pin_pos, pinned_positions(&solver.mesh));
    }

    #[test]
    fn test_multiple_steps_energy_decreases() {
        // One pinned and one free vertex joined by a single spring of rest
        // length 1; the free vertex is moved so the spring is stretched to 2.
        // Over the first few steps the spring must relax, i.e. its elastic
        // energy must drop, and nothing may become NaN/inf.
        let mut mesh = grid(2, 1, 1.0);
        mesh.vertices[1].position = [0.0, 0.0, 2.0];
        let stiffness = mesh.edges[0].stiffness;
        let rest = mesh.edges[0].rest_length;
        let elastic_energy = |m: &ClothMesh| {
            let len = length3(sub3(m.vertices[1].position, m.vertices[0].position));
            0.5 * stiffness * (len - rest).powi(2)
        };
        let before = elastic_energy(&mesh);

        let mut solver = GpuClothSolver::new(mesh);
        for _ in 0..5 {
            solver.solve(0.001);
        }

        for v in &solver.mesh.vertices {
            assert!(v.position.iter().all(|x| x.is_finite()));
            assert!(v.velocity.iter().all(|x| x.is_finite()));
        }
        let after = elastic_energy(&solver.mesh);
        assert!(after < before, "before={before} after={after}");
    }

    // --- Clone impls ---

    #[test]
    fn test_cloth_vertex_clone() {
        let v = ClothVertex::new([1.0, 2.0, 3.0], 1.0, false);
        let v2 = v.clone();
        assert_eq!(v.position, v2.position);
    }

    #[test]
    fn test_cloth_edge_clone() {
        let e = ClothEdge::new(0, 1, 1.0, 100.0, 0.1);
        let e2 = e.clone();
        assert_eq!(e.v0, e2.v0);
        assert_eq!(e.rest_length, e2.rest_length);
    }

    #[test]
    fn test_collider_clone() {
        let c = ClothCollider::Sphere {
            center: [1.0, 2.0, 3.0],
            radius: 0.5,
        };
        let c2 = c.clone();
        if let ClothCollider::Sphere { radius, .. } = c2 {
            assert!((radius - 0.5).abs() < 1e-12);
        } else {
            panic!("wrong variant");
        }
    }
}