Skip to main content

oxihuman_mesh/
spring_deform.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5
6use crate::mesh::MeshBuffers;
7use crate::normals::compute_normals;
8use std::collections::HashMap;
9
10// ---------------------------------------------------------------------------
11// SpringParams
12// ---------------------------------------------------------------------------
13
/// Parameters controlling the spring-mass simulation.
///
/// All fields are plain values, so the struct is cheap to clone and can be
/// compared or printed for debugging.
#[derive(Debug, Clone, PartialEq)]
pub struct SpringParams {
    /// Spring stiffness `k`; the spring force is `k * (current_len - rest_len)`.
    pub stiffness: f32,
    /// Velocity damping factor (0..1). Velocities are multiplied by this
    /// value once per substep, so `1.0` means no damping.
    pub damping: f32,
    /// Mass shared by every vertex. Forces are divided by this value, so it
    /// should be non-zero.
    pub mass: f32,
    /// Gravity vector in world space, added to every free vertex's
    /// acceleration.
    pub gravity: [f32; 3],
    /// Number of integration substeps per `step()` call.
    pub substeps: usize,
    /// If `true`, boundary vertices are pinned and do not move.
    pub fixed_boundary: bool,
}
29
30impl Default for SpringParams {
31    fn default() -> Self {
32        Self {
33            stiffness: 50.0,
34            damping: 0.9,
35            mass: 1.0,
36            gravity: [0.0, -9.8, 0.0],
37            substeps: 4,
38            fixed_boundary: true,
39        }
40    }
41}
42
43// ---------------------------------------------------------------------------
44// Helper math (inline, no external deps)
45// ---------------------------------------------------------------------------
46
/// Component-wise difference `a - b` of two 3-vectors.
#[inline]
fn vec3_sub(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ax - bx, ay - by, az - bz]
}
51
/// Component-wise sum `a + b` of two 3-vectors.
#[inline]
fn vec3_add(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let mut sum = a;
    for (acc, rhs) in sum.iter_mut().zip(b.iter()) {
        *acc += *rhs;
    }
    sum
}
56
/// Multiply every component of `a` by the scalar `s`.
#[inline]
fn vec3_scale(a: [f32; 3], s: f32) -> [f32; 3] {
    a.map(|component| component * s)
}
61
/// Dot product of two 3-vectors, accumulated in x, y, z order.
#[inline]
fn vec3_dot(a: [f32; 3], b: [f32; 3]) -> f32 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
66
/// Squared Euclidean length of `a` (the dot product of `a` with itself).
#[inline]
fn vec3_len_sq(a: [f32; 3]) -> f32 {
    let [x, y, z] = a;
    x * x + y * y + z * z
}
71
72#[inline]
73fn vec3_len(a: [f32; 3]) -> f32 {
74    vec3_len_sq(a).sqrt()
75}
76
77// ---------------------------------------------------------------------------
78// Build topology helpers
79// ---------------------------------------------------------------------------
80
81/// Build springs from mesh edge topology (deduplicated).
82/// Each spring is `(vertex_a, vertex_b, rest_length)`.
83pub fn build_edge_springs(mesh: &MeshBuffers) -> Vec<(usize, usize, f32)> {
84    // Use a HashMap keyed by (min_idx, max_idx) to deduplicate edges.
85    let mut edge_map: HashMap<(u32, u32), ()> = HashMap::new();
86    let mut springs = Vec::new();
87
88    for tri in mesh.indices.chunks_exact(3) {
89        let verts = [tri[0], tri[1], tri[2]];
90        for i in 0..3 {
91            let a = verts[i];
92            let b = verts[(i + 1) % 3];
93            let key = if a < b { (a, b) } else { (b, a) };
94            if edge_map.insert(key, ()).is_none() {
95                // New edge — compute rest length from initial positions.
96                let pa = mesh.positions[key.0 as usize];
97                let pb = mesh.positions[key.1 as usize];
98                let rest_len = vec3_len(vec3_sub(pa, pb));
99                springs.push((key.0 as usize, key.1 as usize, rest_len));
100            }
101        }
102    }
103
104    springs
105}
106
107/// Detect boundary vertices: a vertex is on the boundary if at least one of
108/// its edges belongs to only one triangle face.
109pub fn find_boundary_vertices(mesh: &MeshBuffers) -> Vec<bool> {
110    let n = mesh.positions.len();
111    // Count how many faces each directed edge appears in.
112    let mut edge_face_count: HashMap<(u32, u32), u32> = HashMap::new();
113
114    for tri in mesh.indices.chunks_exact(3) {
115        let verts = [tri[0], tri[1], tri[2]];
116        for i in 0..3 {
117            let a = verts[i];
118            let b = verts[(i + 1) % 3];
119            // Use undirected key (min, max) for manifold-edge counting.
120            let key = if a < b { (a, b) } else { (b, a) };
121            *edge_face_count.entry(key).or_insert(0) += 1;
122        }
123    }
124
125    let mut is_boundary = vec![false; n];
126    for ((a, b), count) in &edge_face_count {
127        if *count == 1 {
128            // This edge is on the boundary — mark both endpoints.
129            is_boundary[*a as usize] = true;
130            is_boundary[*b as usize] = true;
131        }
132    }
133    is_boundary
134}
135
136// ---------------------------------------------------------------------------
137// SpringSystem
138// ---------------------------------------------------------------------------
139
140/// A spring-mass system attached to mesh vertices for soft-body simulation.
141pub struct SpringSystem {
142    /// Rest positions (used to reset).
143    pub rest_positions: Vec<[f32; 3]>,
144    /// Current positions.
145    pub positions: Vec<[f32; 3]>,
146    /// Current velocities.
147    pub velocities: Vec<[f32; 3]>,
148    /// Springs: (vertex_a, vertex_b, rest_length).
149    pub springs: Vec<(usize, usize, f32)>,
150    /// Fixed (pinned) vertices that do not move.
151    pub fixed: Vec<bool>,
152    /// Simulation parameters.
153    pub params: SpringParams,
154}
155
156impl SpringSystem {
157    /// Construct from a mesh and simulation parameters.
158    pub fn from_mesh(mesh: &MeshBuffers, params: SpringParams) -> Self {
159        let n = mesh.positions.len();
160        let rest_positions = mesh.positions.clone();
161        let positions = mesh.positions.clone();
162        let velocities = vec![[0.0f32; 3]; n];
163        let springs = build_edge_springs(mesh);
164
165        let fixed = if params.fixed_boundary {
166            find_boundary_vertices(mesh)
167        } else {
168            vec![false; n]
169        };
170
171        Self {
172            rest_positions,
173            positions,
174            velocities,
175            springs,
176            fixed,
177            params,
178        }
179    }
180
181    /// Number of vertices in the system.
182    pub fn vertex_count(&self) -> usize {
183        self.positions.len()
184    }
185
186    /// Number of springs in the system.
187    pub fn spring_count(&self) -> usize {
188        self.springs.len()
189    }
190
191    /// Total kinetic energy: sum of 0.5 * mass * |v|^2 over all vertices.
192    pub fn kinetic_energy(&self) -> f32 {
193        let half_m = 0.5 * self.params.mass;
194        self.velocities
195            .iter()
196            .map(|v| half_m * vec3_len_sq(*v))
197            .sum()
198    }
199
200    /// Returns `true` when the kinetic energy is below `threshold`.
201    pub fn is_settled(&self, threshold: f32) -> bool {
202        self.kinetic_energy() < threshold
203    }
204
205    /// Pin or unpin a single vertex.
206    pub fn set_fixed(&mut self, vertex: usize, fixed: bool) {
207        if vertex < self.fixed.len() {
208            self.fixed[vertex] = fixed;
209        }
210    }
211
212    /// Apply an instantaneous velocity impulse to a vertex.
213    pub fn apply_impulse(&mut self, vertex: usize, force: [f32; 3]) {
214        if vertex < self.velocities.len() && !self.fixed[vertex] {
215            self.velocities[vertex] = vec3_add(self.velocities[vertex], force);
216        }
217    }
218
219    /// Accumulate gravity into all non-fixed vertex velocities.
220    pub fn apply_gravity_impulse(&mut self, dt: f32) {
221        let g = self.params.gravity;
222        for (i, vel) in self.velocities.iter_mut().enumerate() {
223            if !self.fixed[i] {
224                *vel = vec3_add(*vel, vec3_scale(g, dt));
225            }
226        }
227    }
228
229    /// Advance the simulation by `dt` seconds (uses `params.substeps` substeps).
230    pub fn step(&mut self, dt: f32) {
231        let sub_dt = dt / self.params.substeps as f32;
232        for _ in 0..self.params.substeps {
233            self.substep(sub_dt);
234        }
235    }
236
237    /// Advance the simulation by `n` steps of `dt` seconds each.
238    pub fn step_n(&mut self, dt: f32, n: usize) {
239        for _ in 0..n {
240            self.step(dt);
241        }
242    }
243
244    /// Return all vertices to their rest positions and zero all velocities.
245    pub fn reset(&mut self) {
246        self.positions = self.rest_positions.clone();
247        let n = self.positions.len();
248        self.velocities = vec![[0.0f32; 3]; n];
249    }
250
251    /// Build a new `MeshBuffers` from the template, with positions replaced by
252    /// the current simulated positions, and normals recomputed.
253    pub fn to_mesh(&self, template: &MeshBuffers) -> MeshBuffers {
254        let mut out = template.clone();
255        out.positions = self.positions.clone();
256        compute_normals(&mut out);
257        out
258    }
259
260    // -----------------------------------------------------------------------
261    // Internal integration step
262    // -----------------------------------------------------------------------
263
264    fn substep(&mut self, dt: f32) {
265        let n = self.positions.len();
266        let mut forces = vec![[0.0f32; 3]; n];
267
268        // Spring forces.
269        let k = self.params.stiffness;
270        for &(a, b, rest_len) in &self.springs {
271            let pa = self.positions[a];
272            let pb = self.positions[b];
273            let diff = vec3_sub(pb, pa);
274            let cur_len = vec3_len(diff);
275            if cur_len < 1e-10 {
276                continue;
277            }
278            let unit = vec3_scale(diff, 1.0 / cur_len);
279            let stretch = cur_len - rest_len;
280            let f = vec3_scale(unit, k * stretch);
281            forces[a] = vec3_add(forces[a], f);
282            forces[b] = vec3_sub(forces[b], f);
283        }
284
285        let g = self.params.gravity;
286        let m = self.params.mass;
287        let damping = self.params.damping;
288
289        // Integrate each non-fixed vertex: semi-implicit Euler.
290        for (i, (pos, vel)) in self
291            .positions
292            .iter_mut()
293            .zip(self.velocities.iter_mut())
294            .enumerate()
295        {
296            if self.fixed[i] {
297                continue;
298            }
299            let acc = [
300                forces[i][0] / m + g[0],
301                forces[i][1] / m + g[1],
302                forces[i][2] / m + g[2],
303            ];
304            let new_vel = vec3_add(*vel, vec3_scale(acc, dt));
305            // Apply damping.
306            let new_vel = vec3_scale(new_vel, damping);
307            *vel = new_vel;
308            *pos = vec3_add(*pos, vec3_scale(new_vel, dt));
309        }
310    }
311}
312
313// ---------------------------------------------------------------------------
314// High-level jiggle deform
315// ---------------------------------------------------------------------------
316
317/// Apply an impulse to a single vertex and simulate until the mesh settles
318/// (kinetic energy < 0.001) or 1000 steps have elapsed, then return the
319/// deformed mesh.
320pub fn jiggle_deform(
321    mesh: &MeshBuffers,
322    impulse_vertex: usize,
323    impulse: [f32; 3],
324    params: &SpringParams,
325) -> MeshBuffers {
326    // Build a clone of params — SpringParams does not implement Clone so we
327    // reconstruct manually.
328    let p = SpringParams {
329        stiffness: params.stiffness,
330        damping: params.damping,
331        mass: params.mass,
332        gravity: params.gravity,
333        substeps: params.substeps,
334        fixed_boundary: params.fixed_boundary,
335    };
336
337    let mut system = SpringSystem::from_mesh(mesh, p);
338    system.apply_impulse(impulse_vertex, impulse);
339
340    const MAX_STEPS: usize = 1000;
341    const DT: f32 = 1.0 / 60.0;
342
343    for _ in 0..MAX_STEPS {
344        if system.is_settled(0.001) {
345            break;
346        }
347        system.step(DT);
348    }
349
350    system.to_mesh(mesh)
351}
352
353// ---------------------------------------------------------------------------
354// Tests
355// ---------------------------------------------------------------------------
356
#[cfg(test)]
mod tests {
    use super::*;
    use oxihuman_morph::engine::MeshBuffers as MB;

    /// Build a simple 2-triangle mesh (4 vertices, 2 faces) split along the
    /// 1-3 diagonal.
    ///
    /// ```text
    ///  3---2
    ///  |\  |
    ///  | \ |
    ///  |  \|
    ///  0---1
    /// ```
    fn two_tri_mesh() -> MeshBuffers {
        MeshBuffers::from_morph(MB {
            positions: vec![
                [0.0, 0.0, 0.0], // 0
                [1.0, 0.0, 0.0], // 1
                [1.0, 1.0, 0.0], // 2
                [0.0, 1.0, 0.0], // 3
            ],
            normals: vec![[0.0, 0.0, 1.0]; 4],
            uvs: vec![[0.0, 0.0]; 4],
            // Two triangles sharing edge 1-3.
            indices: vec![0, 1, 3, 1, 2, 3],
            has_suit: false,
        })
    }

    /// Default params with gravity disabled and nothing pinned, so a system
    /// left alone stays exactly at rest.
    fn no_gravity_params() -> SpringParams {
        SpringParams {
            gravity: [0.0, 0.0, 0.0],
            fixed_boundary: false,
            ..Default::default()
        }
    }

    // -----------------------------------------------------------------------

    #[test]
    fn test_spring_system_from_mesh() {
        let mesh = two_tri_mesh();
        let sys = SpringSystem::from_mesh(&mesh, SpringParams::default());
        assert_eq!(sys.rest_positions.len(), 4);
        assert_eq!(sys.positions.len(), 4);
        assert_eq!(sys.velocities.len(), 4);
    }

    #[test]
    fn test_vertex_count() {
        let mesh = two_tri_mesh();
        let sys = SpringSystem::from_mesh(&mesh, SpringParams::default());
        assert_eq!(sys.vertex_count(), 4);
    }

    #[test]
    fn test_spring_count() {
        let mesh = two_tri_mesh();
        let sys = SpringSystem::from_mesh(&mesh, SpringParams::default());
        // 2 triangles sharing one edge → exactly 5 unique edges
        // (0-1, 1-3, 0-3, 1-2, 2-3).
        assert_eq!(sys.spring_count(), 5);
    }

    #[test]
    fn test_reset() {
        let mesh = two_tri_mesh();
        let mut sys = SpringSystem::from_mesh(&mesh, no_gravity_params());
        sys.apply_impulse(0, [1.0, 0.0, 0.0]);
        sys.step(0.1);
        sys.reset();
        for i in 0..sys.vertex_count() {
            for j in 0..3 {
                assert!(
                    (sys.positions[i][j] - sys.rest_positions[i][j]).abs() < 1e-6,
                    "position not reset at vertex {i}"
                );
                assert!(
                    sys.velocities[i][j].abs() < 1e-6,
                    "velocity not zeroed at vertex {i}"
                );
            }
        }
    }

    #[test]
    fn test_step_moves_unfixed_vertices() {
        let mesh = two_tri_mesh();
        let params = SpringParams {
            fixed_boundary: false,
            gravity: [0.0, -9.8, 0.0],
            damping: 1.0, // no damping so movement is clear
            substeps: 1,
            ..Default::default()
        };
        let mut sys = SpringSystem::from_mesh(&mesh, params);
        let orig = sys.positions.clone();
        sys.step(0.05);
        // With gravity, at least some vertices should have moved.
        let moved = sys
            .positions
            .iter()
            .zip(orig.iter())
            .any(|(a, b)| vec3_len(vec3_sub(*a, *b)) > 1e-6);
        assert!(moved, "no vertices moved after step with gravity");
    }

    #[test]
    fn test_fixed_vertex_stays_fixed() {
        let mesh = two_tri_mesh();
        let params = SpringParams {
            fixed_boundary: false,
            gravity: [0.0, -9.8, 0.0],
            ..Default::default()
        };
        let mut sys = SpringSystem::from_mesh(&mesh, params);
        // Manually pin vertex 0.
        sys.set_fixed(0, true);
        let orig0 = sys.positions[0];
        sys.step_n(0.016, 20);
        // Vertex 0 must not have moved.
        for (j, &orig) in orig0.iter().enumerate() {
            assert!(
                (sys.positions[0][j] - orig).abs() < 1e-6,
                "fixed vertex moved at component {j}"
            );
        }
    }

    #[test]
    fn test_apply_impulse() {
        let mesh = two_tri_mesh();
        let mut sys = SpringSystem::from_mesh(&mesh, no_gravity_params());
        // Vertex 0 is not fixed (fixed_boundary=false).
        sys.apply_impulse(0, [5.0, 0.0, 0.0]);
        assert!((sys.velocities[0][0] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_kinetic_energy() {
        let mesh = two_tri_mesh();
        let mut sys = SpringSystem::from_mesh(&mesh, no_gravity_params());
        // At rest, KE should be zero.
        assert!(sys.kinetic_energy() < 1e-10);
        sys.apply_impulse(0, [1.0, 0.0, 0.0]);
        assert!(sys.kinetic_energy() > 0.0);
    }

    #[test]
    fn test_is_settled() {
        let mesh = two_tri_mesh();
        let mut sys = SpringSystem::from_mesh(&mesh, no_gravity_params());
        // At rest, settled with any positive threshold.
        assert!(sys.is_settled(1e-3));
        sys.apply_impulse(0, [100.0, 0.0, 0.0]);
        // Large impulse → not settled.
        assert!(!sys.is_settled(1e-3));
    }

    #[test]
    fn test_build_edge_springs() {
        let mesh = two_tri_mesh();
        let springs = build_edge_springs(&mesh);
        assert_eq!(springs.len(), 5, "expected one spring per unique edge");
        // Verify no duplicate edges and canonical (min, max) ordering.
        let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
        for &(a, b, _) in &springs {
            assert!(a < b, "edge ({a},{b}) is not stored as (min, max)");
            assert!(seen.insert((a, b)), "duplicate edge ({a},{b})");
        }
        // All rest lengths should be positive.
        for &(_, _, len) in &springs {
            assert!(len > 0.0, "non-positive rest length");
        }
    }

    #[test]
    fn test_find_boundary_vertices() {
        let mesh = two_tri_mesh();
        let boundary = find_boundary_vertices(&mesh);
        assert_eq!(boundary.len(), 4);
        // All 4 vertices are on the boundary of the 2-triangle mesh.
        for (i, &b) in boundary.iter().enumerate() {
            assert!(b, "vertex {i} should be boundary");
        }
    }

    #[test]
    fn test_to_mesh() {
        let mesh = two_tri_mesh();
        let mut sys = SpringSystem::from_mesh(
            &mesh,
            SpringParams {
                fixed_boundary: false,
                gravity: [0.0, -9.8, 0.0],
                ..Default::default()
            },
        );
        sys.step(0.1);
        let out = sys.to_mesh(&mesh);
        // Output mesh should have same topology.
        assert_eq!(out.indices, mesh.indices);
        assert_eq!(out.positions.len(), mesh.positions.len());
        // Normals must be recomputed (not all-zero).
        let all_zero = out
            .normals
            .iter()
            .all(|n| n.iter().all(|c| c.abs() < 1e-10));
        assert!(!all_zero, "normals should be non-zero after recompute");
    }

    #[test]
    fn test_jiggle_deform() {
        let mesh = two_tri_mesh();
        let params = SpringParams {
            fixed_boundary: false,
            gravity: [0.0, 0.0, 0.0],
            damping: 0.5,
            stiffness: 50.0,
            ..Default::default()
        };
        let result = jiggle_deform(&mesh, 0, [0.5, 0.0, 0.0], &params);
        // Result must have same topology and vertex count.
        assert_eq!(result.indices, mesh.indices);
        assert_eq!(result.positions.len(), mesh.positions.len());
    }
}
587}