Skip to main content

oxihuman_morph/
mocap_retarget_adv.rs

// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Advanced skeleton-space BVH retargeting with twist decomposition.

// ── Joint / Pose ──────────────────────────────────────────────────────────────
7
/// One skeleton joint, described by its rest (bind) pose.
///
/// Joints refer to their parent by index into the owning joint list;
/// `None` marks a root joint.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct Joint {
    /// Joint name; retargeting looks joints up by this name.
    pub name: String,
    /// Index of the parent joint, or `None` for a root.
    pub parent: Option<usize>,
    /// Rest rotation, quaternion laid out as `[x, y, z, w]`.
    pub rest_rot: [f32; 4],
    /// Rest translation, expressed as a local offset from the parent joint.
    pub rest_pos: [f32; 3],
}
19
20/// A full skeleton pose (rest skeleton + per-joint local rotations + root translation).
21#[allow(dead_code)]
22#[derive(Debug, Clone)]
23pub struct SkeletonPose {
24    pub joints: Vec<Joint>,
25    /// Local rotation per joint [x, y, z, w].
26    pub local_rots: Vec<[f32; 4]>,
27    pub root_pos: [f32; 3],
28}
29
/// Name-based joint correspondence between a source and a target skeleton.
///
/// `source_joints[i]` is paired with `target_joints[i]`.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct RetargetMap {
    /// Joint names looked up on the source skeleton.
    pub source_joints: Vec<String>,
    /// Joint names looked up on the target skeleton, paired by index.
    pub target_joints: Vec<String>,
    /// Uniform scale factor.
    /// NOTE(review): `retarget_pose_adv` does not read this field; confirm
    /// its intended use.
    pub scale: f32,
}
38
// ── Quaternion math ───────────────────────────────────────────────────────────
40
/// Hamilton product of two `[x, y, z, w]` quaternions, returning `a * b`.
///
/// Quaternion multiplication is not commutative: `a * b != b * a` in general.
#[allow(dead_code)]
pub fn quat_multiply(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let (lx, ly, lz, lw) = (a[0], a[1], a[2], a[3]);
    let (rx, ry, rz, rw) = (b[0], b[1], b[2], b[3]);

    let x = lw * rx + lx * rw + ly * rz - lz * ry;
    let y = lw * ry - lx * rz + ly * rw + lz * rx;
    let z = lw * rz + lx * ry - ly * rx + lz * rw;
    let w = lw * rw - lx * rx - ly * ry - lz * rz;

    [x, y, z, w]
}
53
/// Conjugate of `q`: the vector part is negated and `w` is kept.
///
/// This equals the true inverse only when `q` is unit length; no division by
/// the squared norm is performed.
#[allow(dead_code)]
pub fn quat_inverse(q: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = q;
    [-x, -y, -z, w]
}
59
/// Scale `q` to unit length.
///
/// A quaternion whose length is below `1e-10` cannot be normalised
/// meaningfully; the identity `[0, 0, 0, 1]` is returned instead.
#[allow(dead_code)]
pub fn quat_normalize(q: [f32; 4]) -> [f32; 4] {
    let norm = q.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm < 1e-10 {
        [0.0, 0.0, 0.0, 1.0]
    } else {
        q.map(|c| c / norm)
    }
}
69
70/// Spherical linear interpolation between two quaternions.
71#[allow(dead_code)]
72pub fn quat_slerp(a: [f32; 4], b: [f32; 4], t: f32) -> [f32; 4] {
73    let dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
74    // Ensure shortest path
75    let (b, dot) = if dot < 0.0 {
76        ([-b[0], -b[1], -b[2], -b[3]], -dot)
77    } else {
78        (b, dot)
79    };
80    let dot = dot.min(1.0);
81    if dot > 0.9995 {
82        // Linear interpolation fallback
83        return quat_normalize([
84            a[0] + t * (b[0] - a[0]),
85            a[1] + t * (b[1] - a[1]),
86            a[2] + t * (b[2] - a[2]),
87            a[3] + t * (b[3] - a[3]),
88        ]);
89    }
90    let theta_0 = dot.acos();
91    let theta = theta_0 * t;
92    let sin_theta = theta.sin();
93    let sin_theta_0 = theta_0.sin();
94    let s0 = ((1.0 - t) * theta_0).sin() / sin_theta_0;
95    let s1 = sin_theta / sin_theta_0;
96    quat_normalize([
97        s0 * a[0] + s1 * b[0],
98        s0 * a[1] + s1 * b[1],
99        s0 * a[2] + s1 * b[2],
100        s0 * a[3] + s1 * b[3],
101    ])
102}
103
104/// Decompose q into swing and twist about `twist_axis`.
105/// Returns (swing, twist).
106#[allow(dead_code)]
107pub fn quat_to_swing_twist(q: [f32; 4], twist_axis: [f32; 3]) -> ([f32; 4], [f32; 4]) {
108    let [x, y, z, w] = q;
109    let [ax, ay, az] = twist_axis;
110    // Project rotation axis onto twist axis
111    let dot = x * ax + y * ay + z * az;
112    let twist = quat_normalize([dot * ax, dot * ay, dot * az, w]);
113    let swing = quat_multiply(q, quat_inverse(twist));
114    (quat_normalize(swing), twist)
115}
116
// ── Retargeting logic ─────────────────────────────────────────────────────────
118
119/// Retarget a single joint rotation from source skeleton space to target skeleton space.
120/// `src_rot` is the local rotation in source, `src_rest` is the source rest rotation,
121/// `tgt_rest` is the target rest rotation.
122#[allow(dead_code)]
123pub fn retarget_joint_rotation(
124    src_rot: [f32; 4],
125    src_rest: [f32; 4],
126    tgt_rest: [f32; 4],
127) -> [f32; 4] {
128    // Convert source local rotation to rest-relative delta
129    let delta = quat_multiply(src_rot, quat_inverse(src_rest));
130    // Apply delta in target rest space
131    quat_normalize(quat_multiply(delta, tgt_rest))
132}
133
134/// Retarget a full pose from source to target skeleton using the provided joint map.
135#[allow(dead_code)]
136pub fn retarget_pose_adv(
137    src: &SkeletonPose,
138    tgt_rest: &SkeletonPose,
139    map: &RetargetMap,
140) -> SkeletonPose {
141    let mut out = tgt_rest.clone();
142    out.root_pos = scale_root_translation(
143        src.root_pos,
144        compute_skeleton_height(src),
145        compute_skeleton_height(tgt_rest),
146    );
147
148    for (si, src_name) in map.source_joints.iter().enumerate() {
149        if let Some(tgt_name) = map.target_joints.get(si) {
150            // Find source joint index
151            let src_idx = src
152                .joints
153                .iter()
154                .position(|j| &j.name == src_name)
155                .unwrap_or(usize::MAX);
156            // Find target joint index
157            let tgt_idx = tgt_rest
158                .joints
159                .iter()
160                .position(|j| &j.name == tgt_name)
161                .unwrap_or(usize::MAX);
162
163            if src_idx < src.joints.len()
164                && tgt_idx < tgt_rest.joints.len()
165                && src_idx < src.local_rots.len()
166            {
167                let src_rot = src.local_rots[src_idx];
168                let src_rest_rot = src.joints[src_idx].rest_rot;
169                let tgt_rest_rot = tgt_rest.joints[tgt_idx].rest_rot;
170                out.local_rots[tgt_idx] =
171                    retarget_joint_rotation(src_rot, src_rest_rot, tgt_rest_rot);
172            }
173        }
174    }
175    out
176}
177
/// Scale a root translation by the ratio `tgt_height / src_height`.
///
/// If `src_height` is below `1e-6`, the ratio is not computed (to avoid
/// dividing by ~0) and `pos` is returned unchanged.
#[allow(dead_code)]
pub fn scale_root_translation(pos: [f32; 3], src_height: f32, tgt_height: f32) -> [f32; 3] {
    if src_height < 1e-6 {
        pos
    } else {
        let ratio = tgt_height / src_height;
        pos.map(|c| c * ratio)
    }
}
187
188/// Blend two skeleton poses by SLERPing all joint rotations.
189#[allow(dead_code)]
190pub fn blend_poses(a: &SkeletonPose, b: &SkeletonPose, t: f32) -> SkeletonPose {
191    let joints = a.joints.clone();
192    let n = joints.len().min(a.local_rots.len()).min(b.local_rots.len());
193    let local_rots = (0..n)
194        .map(|i| quat_slerp(a.local_rots[i], b.local_rots[i], t))
195        .collect();
196    let root_pos = [
197        a.root_pos[0] + t * (b.root_pos[0] - a.root_pos[0]),
198        a.root_pos[1] + t * (b.root_pos[1] - a.root_pos[1]),
199        a.root_pos[2] + t * (b.root_pos[2] - a.root_pos[2]),
200    ];
201    SkeletonPose {
202        joints,
203        local_rots,
204        root_pos,
205    }
206}
207
208/// Compute approximate skeleton height as max Y extent of rest positions
209/// accumulated from root.
210#[allow(dead_code)]
211pub fn compute_skeleton_height(pose: &SkeletonPose) -> f32 {
212    let mut max_y = 0.0_f32;
213    // Accumulate world-space Y positions
214    let mut world_y = vec![0.0_f32; pose.joints.len()];
215    for (i, joint) in pose.joints.iter().enumerate() {
216        let parent_y = joint.parent.map_or(0.0, |p| world_y[p]);
217        world_y[i] = parent_y + joint.rest_pos[1];
218        max_y = max_y.max(world_y[i]);
219    }
220    max_y.max(0.001)
221}
222
223/// Build a standard 14-joint biped retarget map.
224#[allow(dead_code)]
225pub fn standard_biped_retarget_map() -> RetargetMap {
226    let joints = vec![
227        "Hips",
228        "Spine",
229        "Spine1",
230        "Neck",
231        "Head",
232        "LeftArm",
233        "LeftForeArm",
234        "LeftHand",
235        "RightArm",
236        "RightForeArm",
237        "RightHand",
238        "LeftUpLeg",
239        "LeftLeg",
240        "RightUpLeg",
241    ];
242    RetargetMap {
243        source_joints: joints.iter().map(|s| s.to_string()).collect(),
244        target_joints: joints.iter().map(|s| s.to_string()).collect(),
245        scale: 1.0,
246    }
247}
248
249// ── Helper: build a minimal test skeleton ────────────────────────────────────
250
251#[allow(dead_code)]
252fn identity_quat() -> [f32; 4] {
253    [0.0, 0.0, 0.0, 1.0]
254}
255
256#[allow(dead_code)]
257fn make_test_pose(n: usize) -> SkeletonPose {
258    let joints = (0..n)
259        .map(|i| Joint {
260            name: format!("Joint{i}"),
261            parent: if i == 0 { None } else { Some(i - 1) },
262            rest_rot: identity_quat(),
263            rest_pos: [0.0, 0.1 * i as f32, 0.0],
264        })
265        .collect();
266    let local_rots = vec![identity_quat(); n];
267    SkeletonPose {
268        joints,
269        local_rots,
270        root_pos: [0.0, 0.0, 0.0],
271    }
272}
273
274// ── Tests ─────────────────────────────────────────────────────────────────────
275
276#[cfg(test)]
277mod tests {
278    use super::*;
279
280    fn id() -> [f32; 4] {
281        [0.0, 0.0, 0.0, 1.0]
282    }
283
284    fn nearly_eq(a: [f32; 4], b: [f32; 4]) -> bool {
285        (0..4).all(|i| (a[i] - b[i]).abs() < 1e-4)
286    }
287
288    fn nearly_eq3(a: [f32; 3], b: [f32; 3]) -> bool {
289        (0..3).all(|i| (a[i] - b[i]).abs() < 1e-4)
290    }
291
292    #[test]
293    fn test_quat_multiply_identity_left() {
294        let q = [0.1, 0.2, 0.3, 0.927];
295        let q = quat_normalize(q);
296        let result = quat_multiply(id(), q);
297        assert!(nearly_eq(result, q));
298    }
299
300    #[test]
301    fn test_quat_multiply_identity_right() {
302        let q = quat_normalize([0.1, 0.2, 0.3, 0.927]);
303        let result = quat_multiply(q, id());
304        assert!(nearly_eq(result, q));
305    }
306
307    #[test]
308    fn test_quat_inverse_composed_is_identity() {
309        let q = quat_normalize([0.1, 0.2, 0.3, 0.927]);
310        let qi = quat_inverse(q);
311        let result = quat_normalize(quat_multiply(q, qi));
312        assert!(nearly_eq(result, id()));
313    }
314
315    #[test]
316    fn test_quat_inverse_conjugate() {
317        let q = [0.1, 0.2, 0.3, 0.9];
318        let qi = quat_inverse(q);
319        assert_eq!(qi, [-0.1, -0.2, -0.3, 0.9]);
320    }
321
322    #[test]
323    fn test_quat_normalize_length_one() {
324        let q = quat_normalize([1.0, 2.0, 3.0, 4.0]);
325        let len = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
326        assert!((len - 1.0).abs() < 1e-6);
327    }
328
329    #[test]
330    fn test_quat_normalize_zero_returns_identity() {
331        let q = quat_normalize([0.0, 0.0, 0.0, 0.0]);
332        assert_eq!(q, [0.0, 0.0, 0.0, 1.0]);
333    }
334
335    #[test]
336    fn test_quat_slerp_t0() {
337        let a = id();
338        let frac = std::f32::consts::FRAC_1_SQRT_2;
339        let b = quat_normalize([0.0, frac, 0.0, frac]);
340        let result = quat_slerp(a, b, 0.0);
341        assert!(nearly_eq(result, a));
342    }
343
344    #[test]
345    fn test_quat_slerp_t1() {
346        let a = id();
347        let frac = std::f32::consts::FRAC_1_SQRT_2;
348        let b = quat_normalize([0.0, frac, 0.0, frac]);
349        let result = quat_slerp(a, b, 1.0);
350        assert!(nearly_eq(result, b));
351    }
352
353    #[test]
354    fn test_quat_slerp_t_half_normalized() {
355        let a = id();
356        let b = id();
357        let result = quat_slerp(a, b, 0.5);
358        assert!(nearly_eq(result, id()));
359    }
360
361    #[test]
362    fn test_swing_twist_roundtrip() {
363        let q = quat_normalize([0.1, 0.2, 0.0, 0.974]);
364        let axis = [0.0, 1.0, 0.0];
365        let (swing, twist) = quat_to_swing_twist(q, axis);
366        let composed = quat_normalize(quat_multiply(swing, twist));
367        assert!(nearly_eq(composed, quat_normalize(q)));
368    }
369
370    #[test]
371    fn test_swing_twist_pure_twist() {
372        // A rotation purely about Y axis — swing should be ~identity
373        let q = quat_normalize([0.0, 0.5, 0.0, 0.866]);
374        let (swing, _twist) = quat_to_swing_twist(q, [0.0, 1.0, 0.0]);
375        assert!((swing[3] - 1.0).abs() < 0.1); // swing w close to 1
376    }
377
378    #[test]
379    fn test_retarget_pose_no_nan() {
380        let src = make_test_pose(5);
381        let tgt = make_test_pose(5);
382        let map = RetargetMap {
383            source_joints: src.joints.iter().map(|j| j.name.clone()).collect(),
384            target_joints: tgt.joints.iter().map(|j| j.name.clone()).collect(),
385            scale: 1.0,
386        };
387        let out = retarget_pose_adv(&src, &tgt, &map);
388        for r in &out.local_rots {
389            for v in r {
390                assert!(!v.is_nan());
391            }
392        }
393    }
394
395    #[test]
396    fn test_blend_poses_t0() {
397        let a = make_test_pose(4);
398        let b = make_test_pose(4);
399        let out = blend_poses(&a, &b, 0.0);
400        for i in 0..4 {
401            assert!(nearly_eq(out.local_rots[i], a.local_rots[i]));
402        }
403    }
404
405    #[test]
406    fn test_blend_poses_t1() {
407        let a = make_test_pose(4);
408        let b = make_test_pose(4);
409        let out = blend_poses(&a, &b, 1.0);
410        for i in 0..4 {
411            assert!(nearly_eq(out.local_rots[i], b.local_rots[i]));
412        }
413    }
414
415    #[test]
416    fn test_blend_poses_root_lerp() {
417        let mut a = make_test_pose(2);
418        let mut b = make_test_pose(2);
419        a.root_pos = [0.0, 0.0, 0.0];
420        b.root_pos = [2.0, 4.0, 6.0];
421        let out = blend_poses(&a, &b, 0.5);
422        assert!(nearly_eq3(out.root_pos, [1.0, 2.0, 3.0]));
423    }
424
425    #[test]
426    fn test_compute_skeleton_height_positive() {
427        let pose = make_test_pose(5);
428        let h = compute_skeleton_height(&pose);
429        assert!(h > 0.0);
430    }
431
432    #[test]
433    fn test_scale_root_translation_proportional() {
434        let pos = [1.0, 2.0, 3.0];
435        let out = scale_root_translation(pos, 1.0, 2.0);
436        assert!(nearly_eq3(out, [2.0, 4.0, 6.0]));
437    }
438
439    #[test]
440    fn test_standard_biped_retarget_map_14_joints() {
441        let map = standard_biped_retarget_map();
442        assert_eq!(map.source_joints.len(), 14);
443        assert_eq!(map.target_joints.len(), 14);
444    }
445
446    #[test]
447    fn test_retarget_joint_rotation_identity_pass_through() {
448        let rot = quat_normalize([0.1, 0.2, 0.3, 0.9]);
449        let rest = id();
450        let result = retarget_joint_rotation(rot, rest, rest);
451        assert!(nearly_eq(quat_normalize(result), quat_normalize(rot)));
452    }
453}