Skip to main content

oxihuman_morph/
pose_blend.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5
/// A single corrective shape driven by the rotation of one joint.
///
/// Each entry in [`deltas`](Self::deltas) is scaled by a weight derived from the
/// driving joint's rotation projected onto [`axis`](Self::axis). The weight is
/// mapped between [`angle_min`](Self::angle_min) and
/// [`angle_max`](Self::angle_max).
#[derive(Debug, Clone, PartialEq)]
pub struct PoseCorrectiveShape {
    /// Identifier used for name-based lookup in [`PoseBlendLibrary`].
    pub name: String,
    /// Name of the joint whose rotation drives this shape.
    pub joint_name: String,
    /// Axis the joint rotation is projected onto before being mapped to a weight.
    pub axis: [f32; 3],
    /// Projected angle at (or below) which the weight is 0.
    pub angle_min: f32,
    /// Projected angle at (or above) which the weight reaches 1.
    pub angle_max: f32,
    /// Sparse per-vertex offsets `(vertex_index, [dx, dy, dz])`, applied in full at weight 1.
    pub deltas: Vec<(u32, [f32; 3])>,
}
15
16impl PoseCorrectiveShape {
17    pub fn new(name: impl Into<String>, joint_name: impl Into<String>) -> Self {
18        Self {
19            name: name.into(),
20            joint_name: joint_name.into(),
21            axis: [0.0, 0.0, 1.0],
22            angle_min: 0.0,
23            angle_max: std::f32::consts::PI,
24            deltas: Vec::new(),
25        }
26    }
27
28    /// Compute weight [0..1] given current joint angle along axis
29    pub fn weight(&self, current_angle: f32) -> f32 {
30        angle_to_weight(
31            current_angle,
32            self.angle_min,
33            self.angle_max,
34            BlendInterpolation::Linear,
35        )
36    }
37}
38
/// Rotation state of one named joint, in axis-angle form.
#[derive(Debug, Clone, PartialEq)]
pub struct JointRotation {
    /// Joint this rotation belongs to; matched against `PoseCorrectiveShape::joint_name`.
    pub joint_name: String,
    /// Rotation axis. It is not normalised here, so a non-unit axis scales
    /// the value returned by `projected_angle`.
    pub axis: [f32; 3],
    /// Rotation magnitude about `axis`.
    pub angle: f32,
}
45
46impl JointRotation {
47    pub fn new(joint_name: impl Into<String>, axis: [f32; 3], angle: f32) -> Self {
48        Self {
49            joint_name: joint_name.into(),
50            axis,
51            angle,
52        }
53    }
54
55    /// Angle projected onto a specific axis (dot product of axes * angle)
56    pub fn projected_angle(&self, target_axis: [f32; 3]) -> f32 {
57        let dot = self.axis[0] * target_axis[0]
58            + self.axis[1] * target_axis[1]
59            + self.axis[2] * target_axis[2];
60        dot * self.angle
61    }
62}
63
64/// Library of corrective shapes
65pub struct PoseBlendLibrary {
66    shapes: Vec<PoseCorrectiveShape>,
67}
68
69impl PoseBlendLibrary {
70    pub fn new() -> Self {
71        Self { shapes: Vec::new() }
72    }
73
74    pub fn add_shape(&mut self, shape: PoseCorrectiveShape) {
75        self.shapes.push(shape);
76    }
77
78    pub fn shape_count(&self) -> usize {
79        self.shapes.len()
80    }
81
82    pub fn shapes_for_joint(&self, joint_name: &str) -> Vec<&PoseCorrectiveShape> {
83        self.shapes
84            .iter()
85            .filter(|s| s.joint_name == joint_name)
86            .collect()
87    }
88
89    pub fn get_shape(&self, name: &str) -> Option<&PoseCorrectiveShape> {
90        self.shapes.iter().find(|s| s.name == name)
91    }
92
93    pub fn remove_shape(&mut self, name: &str) -> bool {
94        let before = self.shapes.len();
95        self.shapes.retain(|s| s.name != name);
96        self.shapes.len() < before
97    }
98
99    /// Compute active weights for all shapes given current joint rotations
100    pub fn compute_weights<'a>(
101        &'a self,
102        rotations: &[JointRotation],
103    ) -> Vec<(&'a PoseCorrectiveShape, f32)> {
104        self.shapes
105            .iter()
106            .map(|shape| {
107                let weight = rotations
108                    .iter()
109                    .find(|r| r.joint_name == shape.joint_name)
110                    .map(|r| {
111                        let projected = r.projected_angle(shape.axis);
112                        shape.weight(projected)
113                    })
114                    .unwrap_or(0.0);
115                (shape, weight)
116            })
117            .collect()
118    }
119
120    /// Apply all active corrective shapes to vertex positions.
121    /// Returns new positions with corrections applied.
122    pub fn apply_corrections(
123        &self,
124        positions: &[[f32; 3]],
125        rotations: &[JointRotation],
126    ) -> Vec<[f32; 3]> {
127        let mut result = positions.to_vec();
128        let weights = self.compute_weights(rotations);
129
130        for (shape, w) in weights {
131            if w <= 0.0 {
132                continue;
133            }
134            for &(vi, [dx, dy, dz]) in &shape.deltas {
135                let idx = vi as usize;
136                if idx < result.len() {
137                    result[idx][0] += dx * w;
138                    result[idx][1] += dy * w;
139                    result[idx][2] += dz * w;
140                }
141            }
142        }
143
144        result
145    }
146}
147
148impl Default for PoseBlendLibrary {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
/// Interpolation curves used to shape a normalised angle `t ∈ [0, 1]` into a weight.
///
/// The enum is `Copy`, so it can be passed by value repeatedly without moving.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendInterpolation {
    /// Identity ramp: weight equals `t`.
    Linear,
    /// Hermite smoothstep `3t² − 2t³` (zero slope at both ends).
    SmoothStep,
    /// Despite the name, the quintic "smootherstep" `6t⁵ − 15t⁴ + 10t³`
    /// (zero first and second derivatives at both ends).
    Cubic,
}
160
161/// Map angle to weight using specified interpolation
162pub fn angle_to_weight(angle: f32, min: f32, max: f32, mode: BlendInterpolation) -> f32 {
163    let range = max - min;
164    let t = if range == 0.0 {
165        0.0
166    } else {
167        ((angle - min) / range).clamp(0.0, 1.0)
168    };
169
170    match mode {
171        BlendInterpolation::Linear => t,
172        BlendInterpolation::SmoothStep => t * t * (3.0 - 2.0 * t),
173        BlendInterpolation::Cubic => t * t * t * (10.0 - 15.0 * t + 6.0 * t * t),
174    }
175}
176
177/// Create a common elbow corrective shape (example factory)
178pub fn make_elbow_corrective(joint_name: impl Into<String>) -> PoseCorrectiveShape {
179    PoseCorrectiveShape {
180        name: "elbow_corrective".to_string(),
181        joint_name: joint_name.into(),
182        axis: [0.0, 0.0, 1.0],
183        angle_min: 0.0,
184        angle_max: std::f32::consts::PI * 0.8,
185        deltas: Vec::new(),
186    }
187}
188
189/// Create a common shoulder corrective shape
190pub fn make_shoulder_corrective(joint_name: impl Into<String>) -> PoseCorrectiveShape {
191    PoseCorrectiveShape {
192        name: "shoulder_corrective".to_string(),
193        joint_name: joint_name.into(),
194        axis: [1.0, 0.0, 0.0],
195        angle_min: 0.0,
196        angle_max: std::f32::consts::FRAC_PI_2,
197        deltas: Vec::new(),
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::{FRAC_PI_2, PI};

    /// Builds a delta-free shape on joint `"j"`, driven about +Z, with the given range.
    fn shape_with_range(angle_min: f32, angle_max: f32) -> PoseCorrectiveShape {
        PoseCorrectiveShape {
            name: "s".to_string(),
            joint_name: "j".to_string(),
            axis: [0.0, 0.0, 1.0],
            angle_min,
            angle_max,
            deltas: Vec::new(),
        }
    }

    #[test]
    fn test_pose_corrective_shape_new() {
        let shape = PoseCorrectiveShape::new("test_shape", "elbow_joint");
        assert_eq!(shape.name, "test_shape");
        assert_eq!(shape.joint_name, "elbow_joint");
        assert_eq!(shape.axis, [0.0, 0.0, 1.0]);
        assert_eq!(shape.angle_min, 0.0);
        assert_eq!(shape.angle_max, PI);
        assert!(shape.deltas.is_empty());
    }

    #[test]
    fn test_weight_at_min_angle() {
        let w = shape_with_range(0.0, PI).weight(0.0);
        assert!((w - 0.0).abs() < 1e-6, "weight at min should be 0, got {w}");
    }

    #[test]
    fn test_weight_at_max_angle() {
        let w = shape_with_range(0.0, PI).weight(PI);
        assert!((w - 1.0).abs() < 1e-6, "weight at max should be 1, got {w}");
    }

    #[test]
    fn test_weight_midpoint() {
        let w = shape_with_range(0.0, PI).weight(PI / 2.0);
        assert!(
            (w - 0.5).abs() < 1e-5,
            "weight at midpoint should be 0.5, got {w}"
        );
    }

    #[test]
    fn test_weight_clamped_below() {
        let w = shape_with_range(1.0, 2.0).weight(-1.0);
        assert!(
            (w - 0.0).abs() < 1e-6,
            "weight below min should clamp to 0, got {w}"
        );
    }

    #[test]
    fn test_weight_clamped_above() {
        let w = shape_with_range(0.0, 1.0).weight(100.0);
        assert!(
            (w - 1.0).abs() < 1e-6,
            "weight above max should clamp to 1, got {w}"
        );
    }

    #[test]
    fn test_joint_rotation_projected_angle() {
        let rot = JointRotation::new("shoulder", [1.0, 0.0, 0.0], FRAC_PI_2);
        // Projection onto the same axis should give the full angle
        let proj = rot.projected_angle([1.0, 0.0, 0.0]);
        assert!(
            (proj - FRAC_PI_2).abs() < 1e-5,
            "projected angle should equal angle, got {proj}"
        );

        // Projection onto orthogonal axis should be 0
        let proj_orth = rot.projected_angle([0.0, 1.0, 0.0]);
        assert!(
            proj_orth.abs() < 1e-6,
            "orthogonal projection should be 0, got {proj_orth}"
        );
    }

    #[test]
    fn test_library_add_and_count() {
        let mut lib = PoseBlendLibrary::new();
        assert_eq!(lib.shape_count(), 0);
        lib.add_shape(PoseCorrectiveShape::new("s1", "j1"));
        lib.add_shape(PoseCorrectiveShape::new("s2", "j2"));
        assert_eq!(lib.shape_count(), 2);
    }

    #[test]
    fn test_library_default_is_empty() {
        let lib = PoseBlendLibrary::default();
        assert_eq!(lib.shape_count(), 0);
        assert!(lib.get_shape("anything").is_none());
    }

    #[test]
    fn test_library_shapes_for_joint() {
        let mut lib = PoseBlendLibrary::new();
        lib.add_shape(PoseCorrectiveShape::new("s1", "elbow"));
        lib.add_shape(PoseCorrectiveShape::new("s2", "elbow"));
        lib.add_shape(PoseCorrectiveShape::new("s3", "shoulder"));

        let elbow_shapes = lib.shapes_for_joint("elbow");
        assert_eq!(elbow_shapes.len(), 2);

        let shoulder_shapes = lib.shapes_for_joint("shoulder");
        assert_eq!(shoulder_shapes.len(), 1);

        let missing = lib.shapes_for_joint("knee");
        assert!(missing.is_empty());
    }

    #[test]
    fn test_library_get_and_remove_shape() {
        let mut lib = PoseBlendLibrary::new();
        lib.add_shape(PoseCorrectiveShape::new("s1", "j1"));
        lib.add_shape(PoseCorrectiveShape::new("s2", "j2"));

        assert_eq!(
            lib.get_shape("s2").map(|s| s.joint_name.as_str()),
            Some("j2")
        );
        assert!(lib.get_shape("missing").is_none());

        assert!(lib.remove_shape("s1"));
        assert_eq!(lib.shape_count(), 1);
        assert!(lib.get_shape("s1").is_none());
        // Removing again reports that nothing matched.
        assert!(!lib.remove_shape("s1"));
    }

    #[test]
    fn test_library_compute_weights() {
        let mut lib = PoseBlendLibrary::new();
        let mut shape = PoseCorrectiveShape::new("elbow_corr", "elbow");
        shape.angle_min = 0.0;
        shape.angle_max = PI;
        shape.axis = [0.0, 0.0, 1.0];
        lib.add_shape(shape);

        let rotations = vec![JointRotation::new("elbow", [0.0, 0.0, 1.0], PI / 2.0)];
        let weights = lib.compute_weights(&rotations);
        assert_eq!(weights.len(), 1);
        let (s, w) = &weights[0];
        assert_eq!(s.name, "elbow_corr");
        assert!((w - 0.5).abs() < 1e-5, "expected weight ~0.5, got {w}");
    }

    #[test]
    fn test_compute_weights_without_matching_rotation_is_zero() {
        let mut lib = PoseBlendLibrary::new();
        lib.add_shape(PoseCorrectiveShape::new("elbow_corr", "elbow"));

        let rotations = vec![JointRotation::new("knee", [0.0, 0.0, 1.0], PI)];
        let weights = lib.compute_weights(&rotations);
        assert_eq!(weights.len(), 1);
        assert_eq!(weights[0].1, 0.0);
    }

    #[test]
    fn test_library_apply_corrections() {
        let mut lib = PoseBlendLibrary::new();
        let mut shape = PoseCorrectiveShape::new("corr", "elbow");
        shape.angle_min = 0.0;
        shape.angle_max = PI;
        shape.axis = [0.0, 0.0, 1.0];
        // vertex 0 gets +1 on x at full weight
        shape.deltas = vec![(0, [1.0, 0.0, 0.0])];
        lib.add_shape(shape);

        let positions = vec![[0.0_f32, 0.0, 0.0], [1.0, 1.0, 1.0]];
        // Full angle => weight = 1.0
        let rotations = vec![JointRotation::new("elbow", [0.0, 0.0, 1.0], PI)];
        let result = lib.apply_corrections(&positions, &rotations);
        assert_eq!(result.len(), 2);
        assert!(
            (result[0][0] - 1.0).abs() < 1e-5,
            "vertex 0 x should be 1.0, got {}",
            result[0][0]
        );
        assert!((result[0][1]).abs() < 1e-5);
        assert!(
            (result[1][0] - 1.0).abs() < 1e-5,
            "vertex 1 x unchanged, got {}",
            result[1][0]
        );
    }

    #[test]
    fn test_apply_corrections_skips_inactive_shapes() {
        let mut lib = PoseBlendLibrary::new();
        let mut shape = PoseCorrectiveShape::new("corr", "elbow");
        shape.deltas = vec![(0, [1.0, 1.0, 1.0])];
        lib.add_shape(shape);

        let positions = vec![[0.25_f32, 0.5, 0.75]];
        // Zero angle => zero weight => positions untouched.
        let rotations = vec![JointRotation::new("elbow", [0.0, 0.0, 1.0], 0.0)];
        assert_eq!(lib.apply_corrections(&positions, &rotations), positions);
    }

    #[test]
    fn test_apply_corrections_ignores_out_of_range_vertices() {
        let mut lib = PoseBlendLibrary::new();
        let mut shape = PoseCorrectiveShape::new("corr", "elbow");
        shape.deltas = vec![(7, [1.0, 1.0, 1.0]), (0, [0.0, 2.0, 0.0])];
        lib.add_shape(shape);

        let positions = vec![[0.0_f32, 0.0, 0.0]];
        let rotations = vec![JointRotation::new("elbow", [0.0, 0.0, 1.0], PI)];
        let result = lib.apply_corrections(&positions, &rotations);
        assert_eq!(result, vec![[0.0, 2.0, 0.0]]);
    }

    #[test]
    fn test_angle_to_weight_linear() {
        let w0 = angle_to_weight(0.0, 0.0, 1.0, BlendInterpolation::Linear);
        let w1 = angle_to_weight(1.0, 0.0, 1.0, BlendInterpolation::Linear);
        let wh = angle_to_weight(0.5, 0.0, 1.0, BlendInterpolation::Linear);
        assert!((w0 - 0.0).abs() < 1e-6);
        assert!((w1 - 1.0).abs() < 1e-6);
        assert!((wh - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_angle_to_weight_smoothstep() {
        let w0 = angle_to_weight(0.0, 0.0, 1.0, BlendInterpolation::SmoothStep);
        let w1 = angle_to_weight(1.0, 0.0, 1.0, BlendInterpolation::SmoothStep);
        let wh = angle_to_weight(0.5, 0.0, 1.0, BlendInterpolation::SmoothStep);
        assert!((w0 - 0.0).abs() < 1e-6, "smoothstep at 0 should be 0");
        assert!((w1 - 1.0).abs() < 1e-6, "smoothstep at 1 should be 1");
        // smoothstep(0.5) = 0.5*0.5*(3 - 2*0.5) = 0.25 * 2 = 0.5
        assert!(
            (wh - 0.5).abs() < 1e-6,
            "smoothstep at 0.5 should be 0.5, got {wh}"
        );
    }

    #[test]
    fn test_angle_to_weight_cubic() {
        assert_eq!(angle_to_weight(0.0, 0.0, 1.0, BlendInterpolation::Cubic), 0.0);
        assert_eq!(angle_to_weight(1.0, 0.0, 1.0, BlendInterpolation::Cubic), 1.0);
        // 0.125 * (10 - 7.5 + 1.5) = 0.5
        let wh = angle_to_weight(0.5, 0.0, 1.0, BlendInterpolation::Cubic);
        assert!((wh - 0.5).abs() < 1e-6, "cubic at 0.5 should be 0.5, got {wh}");
        // Out-of-range input is clamped before shaping.
        assert_eq!(angle_to_weight(-3.0, 0.0, 1.0, BlendInterpolation::Cubic), 0.0);
    }

    #[test]
    fn test_angle_to_weight_degenerate_range_is_zero() {
        for angle in [-1.0_f32, 1.0, 5.0] {
            assert_eq!(
                angle_to_weight(angle, 1.0, 1.0, BlendInterpolation::Linear),
                0.0
            );
        }
    }

    #[test]
    fn test_make_elbow_corrective() {
        let shape = make_elbow_corrective("elbow_L");
        assert_eq!(shape.joint_name, "elbow_L");
        assert_eq!(shape.axis, [0.0, 0.0, 1.0]);
        assert!((shape.angle_min - 0.0).abs() < 1e-6);
        assert!((shape.angle_max - PI * 0.8).abs() < 1e-5);
        assert!(shape.deltas.is_empty());
    }

    #[test]
    fn test_make_shoulder_corrective() {
        let shape = make_shoulder_corrective("shoulder_R");
        assert_eq!(shape.joint_name, "shoulder_R");
        assert_eq!(shape.axis, [1.0, 0.0, 0.0]);
        assert!((shape.angle_min - 0.0).abs() < 1e-6);
        assert!((shape.angle_max - FRAC_PI_2).abs() < 1e-5);
        assert!(shape.deltas.is_empty());
    }
}