Skip to main content

oxihuman_morph/
corrective_shapes.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Corrective blend shapes (CBS) / pose-space deformation system.
5
6use std::collections::HashMap;
7
8// ---------------------------------------------------------------------------
9// Structs
10// ---------------------------------------------------------------------------
11
/// One corrective blend shape: per-vertex offsets bound to a target pose.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct CorrectiveShape {
    /// Identifier of the shape; reported back for shapes that are active.
    pub name: String,
    /// Target pose: parameter name → value at which the shape is fully active.
    pub driver_params: HashMap<String, f32>,
    /// Three-component offset per vertex, applied in full at full activation.
    pub deltas: Vec<[f32; 3]>,
    /// Width of the Gaussian RBF activation falloff (the example shapes in
    /// this file use 1.0).
    pub influence_radius: f32,
}
24
25/// Library of corrective shapes.
26#[allow(dead_code)]
27#[derive(Debug, Clone)]
28pub struct CorrectiveShapeLibrary {
29    pub shapes: Vec<CorrectiveShape>,
30    pub vertex_count: usize,
31}
32
/// Outcome of evaluating a shape library against one set of parameters.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct CorrectiveEvalResult {
    /// Weighted sum of the active shapes' deltas, one entry per vertex.
    pub combined_deltas: Vec<[f32; 3]>,
    /// `(shape_name, weight)` for each shape whose weight exceeded 0.01.
    pub active_shapes: Vec<(String, f32)>,
}
41
42// ---------------------------------------------------------------------------
43// Implementations
44// ---------------------------------------------------------------------------
45
46impl CorrectiveShapeLibrary {
47    #[allow(dead_code)]
48    pub fn new(vertex_count: usize) -> Self {
49        Self {
50            shapes: Vec::new(),
51            vertex_count,
52        }
53    }
54
55    #[allow(dead_code)]
56    pub fn add_shape(&mut self, shape: CorrectiveShape) {
57        self.shapes.push(shape);
58    }
59
60    /// Evaluate all shapes against `current_params`, returning combined deltas.
61    #[allow(dead_code)]
62    pub fn evaluate(&self, current_params: &HashMap<String, f32>) -> CorrectiveEvalResult {
63        let mut pairs: Vec<(Vec<[f32; 3]>, f32)> = Vec::new();
64        let mut active_shapes = Vec::new();
65
66        for shape in &self.shapes {
67            let dist = corrective_distance(current_params, &shape.driver_params);
68            let w = corrective_weight(dist, shape.influence_radius);
69            if w > 0.01 {
70                active_shapes.push((shape.name.clone(), w));
71                pairs.push((shape.deltas.clone(), w));
72            }
73        }
74
75        let combined_deltas = combine_corrective_deltas(&pairs, self.vertex_count);
76        CorrectiveEvalResult {
77            combined_deltas,
78            active_shapes,
79        }
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Free functions
85// ---------------------------------------------------------------------------
86
/// Euclidean (L2) distance between `current` and the pose described by `driver`.
///
/// Only the keys of `driver` contribute. A key absent from `current` is read
/// as `0.0`, and keys present only in `current` are ignored. An empty `driver`
/// yields `0.0`.
#[allow(dead_code)]
pub fn corrective_distance(current: &HashMap<String, f32>, driver: &HashMap<String, f32>) -> f32 {
    driver
        .iter()
        .fold(0.0f32, |acc, (key, &target)| {
            let diff = current.get(key).copied().unwrap_or(0.0) - target;
            acc + diff * diff
        })
        .sqrt()
}
97
/// Gaussian RBF activation: `exp(-(distance / radius)²)`.
///
/// `radius` is clamped to at least `f32::EPSILON` before dividing, so a zero
/// or negative radius does not cause a division by zero.
#[allow(dead_code)]
pub fn corrective_weight(distance: f32, radius: f32) -> f32 {
    let scaled = distance / radius.max(f32::EPSILON);
    f32::exp(-scaled * scaled)
}
105
/// Sums `weight * deltas` over all `(deltas, weight)` pairs into a buffer of
/// `vertex_count` entries.
///
/// A delta array shorter than `vertex_count` contributes nothing to the
/// remaining vertices; entries beyond `vertex_count` are ignored.
#[allow(dead_code)]
pub fn combine_corrective_deltas(
    deltas_and_weights: &[(Vec<[f32; 3]>, f32)],
    vertex_count: usize,
) -> Vec<[f32; 3]> {
    let mut combined = vec![[0.0f32; 3]; vertex_count];
    for (shape_deltas, weight) in deltas_and_weights {
        let weight = *weight;
        // `zip` ends at the shorter of the two sequences, which gives both the
        // zero-padding and the truncation behaviour.
        for (slot, delta) in combined.iter_mut().zip(shape_deltas) {
            for (acc, component) in slot.iter_mut().zip(delta) {
                *acc += component * weight;
            }
        }
    }
    combined
}
123
124/// Apply corrective deltas on top of the base mesh positions.
125#[allow(dead_code)]
126pub fn apply_corrective_to_mesh(base: &[[f32; 3]], result: &CorrectiveEvalResult) -> Vec<[f32; 3]> {
127    let mut out: Vec<[f32; 3]> = base.to_vec();
128    for (out_v, delta_v) in out.iter_mut().zip(result.combined_deltas.iter()) {
129        out_v[0] += delta_v[0];
130        out_v[1] += delta_v[1];
131        out_v[2] += delta_v[2];
132    }
133    out
134}
135
136/// A small example library with 4 corrective shapes.
137#[allow(dead_code)]
138pub fn standard_corrective_shapes(vertex_count: usize) -> CorrectiveShapeLibrary {
139    let mut lib = CorrectiveShapeLibrary::new(vertex_count);
140
141    // 1 – Shoulder raise (left)
142    {
143        let mut driver = HashMap::new();
144        driver.insert("shoulder_raise_l".into(), 1.0);
145        let deltas: Vec<[f32; 3]> = (0..vertex_count)
146            .map(|i| {
147                let t = (i as f32) / (vertex_count.max(1) as f32);
148                [0.0, t * 0.02, 0.0]
149            })
150            .collect();
151        lib.add_shape(CorrectiveShape {
152            name: "shoulder_raise_left".into(),
153            driver_params: driver,
154            deltas,
155            influence_radius: 1.0,
156        });
157    }
158
159    // 2 – Elbow bend (right)
160    {
161        let mut driver = HashMap::new();
162        driver.insert("elbow_bend_r".into(), 1.0);
163        let deltas: Vec<[f32; 3]> = (0..vertex_count)
164            .map(|i| {
165                let t = (i as f32) / (vertex_count.max(1) as f32);
166                [t * 0.01, 0.0, 0.0]
167            })
168            .collect();
169        lib.add_shape(CorrectiveShape {
170            name: "elbow_bend_right".into(),
171            driver_params: driver,
172            deltas,
173            influence_radius: 1.0,
174        });
175    }
176
177    // 3 – Squat knee
178    {
179        let mut driver = HashMap::new();
180        driver.insert("knee_bend".into(), 1.0);
181        let deltas: Vec<[f32; 3]> = (0..vertex_count)
182            .map(|i| {
183                let t = (i as f32) / (vertex_count.max(1) as f32);
184                [0.0, 0.0, t * 0.015]
185            })
186            .collect();
187        lib.add_shape(CorrectiveShape {
188            name: "squat_knee".into(),
189            driver_params: driver,
190            deltas,
191            influence_radius: 1.0,
192        });
193    }
194
195    // 4 – Heavy belly
196    {
197        let mut driver = HashMap::new();
198        driver.insert("belly_weight".into(), 1.0);
199        let deltas: Vec<[f32; 3]> = (0..vertex_count)
200            .map(|i| {
201                let t = (i as f32) / (vertex_count.max(1) as f32);
202                [0.0, -t * 0.01, t * 0.03]
203            })
204            .collect();
205        lib.add_shape(CorrectiveShape {
206            name: "heavy_belly".into(),
207            driver_params: driver,
208            deltas,
209            influence_radius: 1.0,
210        });
211    }
212
213    lib
214}
215
216// ---------------------------------------------------------------------------
217// Tests
218// ---------------------------------------------------------------------------
219
#[cfg(test)]
mod tests {
    //! Unit tests for the corrective-shape distance, weighting, combination
    //! and evaluation logic.

    use super::*;

    #[test]
    fn test_corrective_weight_at_zero() {
        assert!((corrective_weight(0.0, 1.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_corrective_weight_at_radius() {
        let w = corrective_weight(1.0, 1.0);
        // exp(-1) ≈ 0.3679
        assert!(w < 0.37 && w > 0.35, "w={w}");
    }

    #[test]
    fn test_corrective_weight_large_distance() {
        let w = corrective_weight(100.0, 1.0);
        assert!(w < 1e-10, "w={w}");
    }

    #[test]
    fn test_corrective_distance_same_params() {
        let mut p = HashMap::new();
        p.insert("a".into(), 1.0);
        p.insert("b".into(), 2.0);
        assert!(corrective_distance(&p, &p) < 1e-6);
    }

    #[test]
    fn test_corrective_distance_different_params() {
        let mut current = HashMap::new();
        current.insert("x".into(), 0.0);
        let mut driver = HashMap::new();
        driver.insert("x".into(), 3.0);
        // "y" is missing in `current` and is therefore read as 0.
        driver.insert("y".into(), 4.0);
        // sqrt(9 + 16) = 5
        let d = corrective_distance(&current, &driver);
        assert!((d - 5.0).abs() < 1e-5, "d={d}");
    }

    #[test]
    fn test_combine_corrective_deltas_single_weight() {
        let deltas = vec![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let combined = combine_corrective_deltas(&[(deltas, 0.5)], 2);
        assert!((combined[0][0] - 0.5).abs() < 1e-5);
        assert!((combined[1][2] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_combine_corrective_deltas_two_shapes() {
        let d1 = vec![[1.0f32, 0.0, 0.0]];
        let d2 = vec![[0.0f32, 1.0, 0.0]];
        let combined = combine_corrective_deltas(&[(d1, 1.0), (d2, 1.0)], 1);
        assert!((combined[0][0] - 1.0).abs() < 1e-5);
        assert!((combined[0][1] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_combine_corrective_deltas_length_mismatch() {
        // Shorter than vertex_count: remaining vertices stay zero.
        let short = vec![[1.0f32, 1.0, 1.0]];
        let combined = combine_corrective_deltas(&[(short, 1.0)], 2);
        assert_eq!(combined, vec![[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]);

        // Longer than vertex_count: surplus entries are ignored.
        let long = vec![[1.0f32, 1.0, 1.0]; 3];
        let combined = combine_corrective_deltas(&[(long, 1.0)], 1);
        assert_eq!(combined, vec![[1.0, 1.0, 1.0]]);
    }

    #[test]
    fn test_evaluate_matching_params() {
        let lib = standard_corrective_shapes(4);
        let mut params = HashMap::new();
        params.insert("shoulder_raise_l".into(), 1.0);
        let result = lib.evaluate(&params);
        // The driver is hit exactly, so the shape must be active at weight 1.
        let shoulder = result
            .active_shapes
            .iter()
            .find(|(n, _)| n == "shoulder_raise_left");
        match shoulder {
            Some((_, w)) => assert!((w - 1.0).abs() < 1e-6, "w={w}"),
            None => panic!("shoulder_raise_left should be active"),
        }
    }

    #[test]
    fn test_evaluate_no_matching_params() {
        let lib = standard_corrective_shapes(4);
        let params = HashMap::new();
        let result = lib.evaluate(&params);
        // Every driver targets 1.0 and missing params read as 0.0, so each
        // shape sits at distance 1.0 and gets weight exp(-1) ≈ 0.368, which is
        // above the 0.01 activity cutoff.
        assert_eq!(result.active_shapes.len(), 4);
        let expected = (-1.0f32).exp();
        for (name, w) in &result.active_shapes {
            assert!((w - expected).abs() < 1e-6, "{name}: w={w}");
        }
        assert_eq!(result.combined_deltas.len(), 4);
    }

    #[test]
    fn test_evaluate_far_params_near_zero() {
        let lib = standard_corrective_shapes(4);
        let mut params = HashMap::new();
        params.insert("shoulder_raise_l".into(), 1000.0); // very far from driver=1.0
        let result = lib.evaluate(&params);
        // Distance 999 drives the weight to zero, so the shoulder shape must
        // be filtered out of the active list.
        assert!(result
            .active_shapes
            .iter()
            .all(|(n, _)| n != "shoulder_raise_left"));
        // The other three shapes are at distance 1.0 and remain active.
        assert_eq!(result.active_shapes.len(), 3);
        assert_eq!(result.combined_deltas.len(), 4);
    }

    #[test]
    fn test_standard_corrective_shapes_has_4() {
        let lib = standard_corrective_shapes(10);
        assert_eq!(lib.shapes.len(), 4);
    }

    #[test]
    fn test_apply_corrective_to_mesh_adds_deltas() {
        let base = vec![[1.0f32, 1.0, 1.0], [2.0, 2.0, 2.0]];
        let combined_deltas = vec![[0.1f32, 0.2, 0.3], [0.4, 0.5, 0.6]];
        let result = CorrectiveEvalResult {
            combined_deltas,
            active_shapes: Vec::new(),
        };
        let out = apply_corrective_to_mesh(&base, &result);
        assert!((out[0][0] - 1.1).abs() < 1e-5);
        assert!((out[1][2] - 2.6).abs() < 1e-5);
    }

    #[test]
    fn test_apply_corrective_zero_weight_no_change() {
        let base = vec![[5.0f32, 5.0, 5.0]];
        let combined_deltas = vec![[0.0f32, 0.0, 0.0]];
        let result = CorrectiveEvalResult {
            combined_deltas,
            active_shapes: Vec::new(),
        };
        let out = apply_corrective_to_mesh(&base, &result);
        assert!((out[0][0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_combine_corrective_deltas_empty() {
        let combined = combine_corrective_deltas(&[], 3);
        assert_eq!(combined.len(), 3);
        assert_eq!(combined[0], [0.0, 0.0, 0.0]);
    }
}