Skip to main content

oxihuman_morph/
pose_symmetry.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Pose/skeleton symmetry enforcement and mirror analysis.
5//! Note: body_symmetry.rs handles mesh vertex symmetry; this module covers joint pose symmetry.
6
/// Transform of a single named joint in a posed skeleton.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct JointPose {
    /// Joint identifier; used for lookups and for matching left/right counterparts.
    pub name: String,
    /// Rotation quaternion stored in `[x, y, z, w]` order.
    pub rotation: [f32; 4],
    /// Translation component `[x, y, z]`.
    pub translation: [f32; 3],
    /// Single scalar scale factor.
    pub scale: f32,
}
15
/// Names of two joints that mirror each other, plus the axis they mirror across.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SymmetryPair {
    /// Name of the left-side joint.
    pub left_name: String,
    /// Name of the right-side joint.
    pub right_name: String,
    /// Mirror axis: 0 = X, 1 = Y, 2 = Z. The translation component on this axis is negated
    /// when the pair is mirrored.
    pub mirror_axis: u8,
}
23
24#[allow(dead_code)]
25#[derive(Debug, Clone)]
26pub struct PoseSkeleton {
27    pub joints: Vec<JointPose>,
28}
29
/// Mirror a joint rotation through the plane perpendicular to `axis` (0 = X, 1 = Y, 2 = Z).
///
/// Reflecting a rotation `R` through that plane gives `M R M`. Its quaternion keeps the
/// imaginary component along `axis` and negates the other two; for X that is `[x, -y, -z, w]`.
/// This function returns the antipodal quaternion instead: the `axis` component and `w` are
/// negated, e.g. `[-x, y, z, -w]` for X. Because `q` and `-q` encode the same rotation, the
/// result is the correct mirrored rotation, but with the opposite sign. Compare results with
/// a sign-insensitive metric such as `quat_angle_distance`, not component-wise.
///
/// Any `axis` value other than 0, 1 or 2 returns `q` unchanged.
#[allow(dead_code)]
pub fn mirror_joint_rotation(q: [f32; 4], axis: u8) -> [f32; 4] {
    let [x, y, z, w] = q;
    // Negating the axis component and `w` equals `-(reflected quaternion)`: the same rotation
    // as negating the two other imaginary components.
    match axis {
        0 => [-x, y, z, -w],
        1 => [x, -y, z, -w],
        2 => [x, y, -z, -w],
        _ => [x, y, z, w],
    }
}
44
45/// Produce a mirrored copy of the skeleton, swapping left/right joints.
46#[allow(dead_code)]
47pub fn mirror_pose(skeleton: &PoseSkeleton, pairs: &[SymmetryPair]) -> PoseSkeleton {
48    let mut joints = skeleton.joints.clone();
49
50    for pair in pairs {
51        let left_idx = joints.iter().position(|j| j.name == pair.left_name);
52        let right_idx = joints.iter().position(|j| j.name == pair.right_name);
53
54        if let (Some(li), Some(ri)) = (left_idx, right_idx) {
55            let left_rot = joints[li].rotation;
56            let right_rot = joints[ri].rotation;
57            let left_trans = joints[li].translation;
58            let right_trans = joints[ri].translation;
59
60            joints[li].rotation = mirror_joint_rotation(right_rot, pair.mirror_axis);
61            joints[ri].rotation = mirror_joint_rotation(left_rot, pair.mirror_axis);
62
63            // Mirror translation across the axis
64            let mut new_left_trans = right_trans;
65            let mut new_right_trans = left_trans;
66            let ax = pair.mirror_axis as usize;
67            new_left_trans[ax] = -right_trans[ax];
68            new_right_trans[ax] = -left_trans[ax];
69
70            joints[li].translation = new_left_trans;
71            joints[ri].translation = new_right_trans;
72        }
73    }
74
75    PoseSkeleton { joints }
76}
77
78/// Blend the skeleton toward its symmetric version by `blend` (0 = original, 1 = fully symmetric).
79#[allow(dead_code)]
80pub fn enforce_symmetry_pose(skeleton: &mut PoseSkeleton, pairs: &[SymmetryPair], blend: f32) {
81    let blend = blend.clamp(0.0, 1.0);
82    let mirrored = mirror_pose(skeleton, pairs);
83
84    for (joint, mirrored_joint) in skeleton.joints.iter_mut().zip(mirrored.joints.iter()) {
85        joint.rotation = quat_slerp_pose(joint.rotation, mirrored_joint.rotation, blend * 0.5);
86
87        for i in 0..3 {
88            joint.translation[i] +=
89                (mirrored_joint.translation[i] - joint.translation[i]) * blend * 0.5;
90        }
91    }
92}
93
94/// Compute RMS asymmetry error across all symmetry pairs.
95#[allow(dead_code)]
96pub fn pose_symmetry_error(skeleton: &PoseSkeleton, pairs: &[SymmetryPair]) -> f32 {
97    let mut sum_sq = 0.0_f32;
98    let mut count = 0;
99
100    for pair in pairs {
101        let left = find_joint_by_name(skeleton, &pair.left_name);
102        let right = find_joint_by_name(skeleton, &pair.right_name);
103
104        if let (Some(l), Some(r)) = (left, right) {
105            let mirrored_r = mirror_joint_rotation(r.rotation, pair.mirror_axis);
106            // Quaternion angle distance
107            let dist = quat_angle_distance(l.rotation, mirrored_r);
108            sum_sq += dist * dist;
109            count += 1;
110        }
111    }
112
113    if count == 0 {
114        0.0
115    } else {
116        (sum_sq / count as f32).sqrt()
117    }
118}
119
120/// Return canonical left/right joint pairs for a standard biped skeleton.
121#[allow(dead_code)]
122pub fn standard_biped_symmetry_pairs() -> Vec<SymmetryPair> {
123    let pairs_data = [
124        ("LeftArm", "RightArm"),
125        ("LeftForeArm", "RightForeArm"),
126        ("LeftHand", "RightHand"),
127        ("LeftUpLeg", "RightUpLeg"),
128        ("LeftLeg", "RightLeg"),
129        ("LeftFoot", "RightFoot"),
130        ("LeftToeBase", "RightToeBase"),
131        ("LeftShoulder", "RightShoulder"),
132        ("LeftHandThumb1", "RightHandThumb1"),
133        ("LeftHandIndex1", "RightHandIndex1"),
134        ("LeftHandMiddle1", "RightHandMiddle1"),
135        ("LeftHandRing1", "RightHandRing1"),
136        ("LeftHandPinky1", "RightHandPinky1"),
137    ];
138
139    pairs_data
140        .iter()
141        .map(|(l, r)| SymmetryPair {
142            left_name: l.to_string(),
143            right_name: r.to_string(),
144            mirror_axis: 0, // X axis for biped
145        })
146        .collect()
147}
148
149/// Find a joint by name in a skeleton.
150#[allow(dead_code)]
151pub fn find_joint_by_name<'a>(skeleton: &'a PoseSkeleton, name: &str) -> Option<&'a JointPose> {
152    skeleton.joints.iter().find(|j| j.name == name)
153}
154
/// Spherically interpolate from quaternion `a` to `b` (both `[x, y, z, w]`) by `t`.
///
/// `t` is clamped to `[0, 1]`. If the 4D dot product is negative, `b` is negated first so the
/// interpolation follows the shorter arc (`b` and `-b` encode the same rotation). When the dot
/// product exceeds `0.9995` the inputs are nearly parallel. In that case a normalized linear
/// interpolation is returned, avoiding a division by a vanishing `sin`.
#[allow(dead_code)]
pub fn quat_slerp_pose(a: [f32; 4], b: [f32; 4], t: f32) -> [f32; 4] {
    const NEARLY_PARALLEL: f32 = 0.9995;
    let t = t.clamp(0.0, 1.0);

    let raw_dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
    let (end, cos_omega) = if raw_dot < 0.0 {
        (b.map(|c| -c), -raw_dot)
    } else {
        (b, raw_dot)
    };

    if cos_omega > NEARLY_PARALLEL {
        let mut blended = a;
        for (out, target) in blended.iter_mut().zip(end) {
            *out += t * (target - *out);
        }
        let len = (blended[0] * blended[0]
            + blended[1] * blended[1]
            + blended[2] * blended[2]
            + blended[3] * blended[3])
            .sqrt()
            .max(1e-8);
        return blended.map(|c| c / len);
    }

    let omega = cos_omega.acos();
    let sin_omega = omega.sin();
    let swept = omega * t;
    let weight_a = (omega - swept).sin() / sin_omega;
    let weight_b = swept.sin() / sin_omega;
    [0_usize, 1, 2, 3].map(|i| weight_a * a[i] + weight_b * end[i])
}
196
197/// Interpolate two skeletons joint-by-joint using slerp for rotations.
198#[allow(dead_code)]
199pub fn interpolate_poses(a: &PoseSkeleton, b: &PoseSkeleton, t: f32) -> PoseSkeleton {
200    let t = t.clamp(0.0, 1.0);
201    let joints = a
202        .joints
203        .iter()
204        .zip(b.joints.iter())
205        .map(|(ja, jb)| {
206            let lerp = |x: f32, y: f32| x + (y - x) * t;
207            JointPose {
208                name: ja.name.clone(),
209                rotation: quat_slerp_pose(ja.rotation, jb.rotation, t),
210                translation: [
211                    lerp(ja.translation[0], jb.translation[0]),
212                    lerp(ja.translation[1], jb.translation[1]),
213                    lerp(ja.translation[2], jb.translation[2]),
214                ],
215                scale: lerp(ja.scale, jb.scale),
216            }
217        })
218        .collect();
219    PoseSkeleton { joints }
220}
221
222/// Auto-detect symmetry pairs from joint names containing "Left" and "Right".
223#[allow(dead_code)]
224pub fn detect_symmetry_pairs(joint_names: &[String]) -> Vec<SymmetryPair> {
225    let mut pairs = Vec::new();
226    for name in joint_names {
227        if let Some(suffix) = name.strip_prefix("Left") {
228            let right_name = format!("Right{suffix}");
229            if joint_names.iter().any(|n| n == &right_name) {
230                pairs.push(SymmetryPair {
231                    left_name: name.clone(),
232                    right_name,
233                    mirror_axis: 0,
234                });
235            }
236        }
237    }
238    pairs
239}
240
241/// Mean quaternion rotation distance across matching joints.
242#[allow(dead_code)]
243pub fn pose_distance_sym(a: &PoseSkeleton, b: &PoseSkeleton) -> f32 {
244    let pairs: Vec<_> = a.joints.iter().zip(b.joints.iter()).collect();
245    if pairs.is_empty() {
246        return 0.0;
247    }
248    let sum: f32 = pairs
249        .iter()
250        .map(|(ja, jb)| quat_angle_distance(ja.rotation, jb.rotation))
251        .sum();
252    sum / pairs.len() as f32
253}
254
255/// Apply a rotation delta to a specific joint via quaternion composition.
256#[allow(dead_code)]
257pub fn apply_pose_offset(skeleton: &mut PoseSkeleton, joint_name: &str, rotation_delta: [f32; 4]) {
258    if let Some(joint) = skeleton.joints.iter_mut().find(|j| j.name == joint_name) {
259        joint.rotation = quat_multiply_pose(joint.rotation, rotation_delta);
260        // Normalize
261        let [x, y, z, w] = joint.rotation;
262        let mag = (x * x + y * y + z * z + w * w).sqrt().max(1e-8);
263        joint.rotation = [x / mag, y / mag, z / mag, w / mag];
264    }
265}
266
267// Internal helpers
268
/// Hamilton product `a * b` of two quaternions stored as `[x, y, z, w]`.
fn quat_multiply_pose(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let (x1, y1, z1, w1) = (a[0], a[1], a[2], a[3]);
    let (x2, y2, z2, w2) = (b[0], b[1], b[2], b[3]);

    let x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2;
    let y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2;
    let z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2;
    let w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2;
    [x, y, z, w]
}
279
/// Angle in radians between the rotations encoded by `a` and `b`, ignoring quaternion sign.
///
/// Computes `2 * acos(|a · b|)`, with `|a · b|` clamped to `[0, 1]` to keep `acos` in its
/// domain. Meaningful for unit quaternions.
fn quat_angle_distance(a: [f32; 4], b: [f32; 4]) -> f32 {
    let cos_half = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
    let cos_half = cos_half.abs().clamp(0.0, 1.0);
    2.0 * cos_half.acos()
}
286
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f32 = 1e-4;

    fn identity_quat() -> [f32; 4] {
        [0.0, 0.0, 0.0, 1.0]
    }

    fn make_joint(name: &str) -> JointPose {
        JointPose {
            name: name.to_string(),
            rotation: identity_quat(),
            translation: [0.0, 0.0, 0.0],
            scale: 1.0,
        }
    }

    fn make_simple_skeleton() -> PoseSkeleton {
        PoseSkeleton {
            joints: vec![
                make_joint("LeftArm"),
                make_joint("RightArm"),
                make_joint("Spine"),
            ],
        }
    }

    /// A single LeftArm/RightArm pair mirrored across X.
    fn arm_pairs() -> Vec<SymmetryPair> {
        vec![SymmetryPair {
            left_name: "LeftArm".to_string(),
            right_name: "RightArm".to_string(),
            mirror_axis: 0,
        }]
    }

    /// Arms whose translations mirror each other but whose rotations do not.
    fn asymmetric_arms() -> PoseSkeleton {
        PoseSkeleton {
            joints: vec![
                JointPose {
                    name: "LeftArm".to_string(),
                    rotation: [0.1, 0.0, 0.0, (1.0_f32 - 0.01_f32).sqrt()],
                    translation: [1.0, 0.0, 0.0],
                    scale: 1.0,
                },
                JointPose {
                    name: "RightArm".to_string(),
                    rotation: identity_quat(),
                    translation: [-1.0, 0.0, 0.0],
                    scale: 1.0,
                },
            ],
        }
    }

    /// Assert two float slices are equal component-wise within `EPS`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < EPS, "{actual:?} != {expected:?}");
        }
    }

    #[test]
    fn test_mirror_identity_quat() {
        let mirrored = mirror_joint_rotation(identity_quat(), 0);
        // The antipodal representative of the identity is returned.
        assert_eq!(mirrored, [0.0, 0.0, 0.0, -1.0]);
    }

    #[test]
    fn test_mirror_rotation_is_involution_and_ignores_bad_axis() {
        let q = [0.1, 0.2, 0.3, 0.927_36];
        for axis in 0..3_u8 {
            assert_eq!(mirror_joint_rotation(mirror_joint_rotation(q, axis), axis), q);
        }
        assert_eq!(mirror_joint_rotation(q, 3), q);
    }

    #[test]
    fn test_mirror_pose_swaps_joints() {
        let mut skel = make_simple_skeleton();
        skel.joints[0].translation = [1.0, 2.0, 3.0]; // LeftArm
        skel.joints[1].translation = [-4.0, 5.0, 6.0]; // RightArm
        skel.joints[1].rotation = [0.1, 0.2, 0.3, 0.927_36];
        skel.joints[2].translation = [0.5, 0.5, 0.5]; // Spine

        let mirrored = mirror_pose(&skel, &arm_pairs());

        // Each side takes the other side's translation with X negated.
        assert_close(&mirrored.joints[0].translation, &[4.0, 5.0, 6.0]);
        assert_close(&mirrored.joints[1].translation, &[-1.0, 2.0, 3.0]);
        assert_eq!(
            mirrored.joints[0].rotation,
            mirror_joint_rotation(skel.joints[1].rotation, 0)
        );
        assert_eq!(
            mirrored.joints[1].rotation,
            mirror_joint_rotation(skel.joints[0].rotation, 0)
        );
        // Unpaired joints are copied unchanged.
        assert_close(&mirrored.joints[2].translation, &[0.5, 0.5, 0.5]);
        assert_eq!(mirrored.joints[2].rotation, identity_quat());
    }

    #[test]
    fn test_mirror_pose_skips_pair_with_missing_joint() {
        let mut skel = make_simple_skeleton();
        skel.joints[0].translation = [1.0, 0.0, 0.0];
        let pairs = vec![SymmetryPair {
            left_name: "LeftArm".to_string(),
            right_name: "RightLeg".to_string(),
            mirror_axis: 0,
        }];
        let mirrored = mirror_pose(&skel, &pairs);
        assert_close(&mirrored.joints[0].translation, &[1.0, 0.0, 0.0]);
        assert_eq!(mirrored.joints[0].rotation, identity_quat());
    }

    #[test]
    fn test_enforce_symmetry_reduces_error() {
        let mut skel = asymmetric_arms();
        let pairs = arm_pairs();
        let err_before = pose_symmetry_error(&skel, &pairs);
        enforce_symmetry_pose(&mut skel, &pairs, 1.0);
        let err_after = pose_symmetry_error(&skel, &pairs);
        assert!(err_before > 0.1, "fixture should start asymmetric: {err_before}");
        assert!(
            err_after < 1e-2,
            "full blend should make the pair symmetric: {err_after}"
        );
    }

    #[test]
    fn test_enforce_symmetry_zero_blend_is_noop() {
        let original = asymmetric_arms();
        let mut skel = original.clone();
        enforce_symmetry_pose(&mut skel, &arm_pairs(), 0.0);
        for (after, before) in skel.joints.iter().zip(&original.joints) {
            assert_close(&after.rotation, &before.rotation);
            assert_close(&after.translation, &before.translation);
        }
    }

    #[test]
    fn test_pose_symmetry_error_symmetric_skeleton() {
        let skel = make_simple_skeleton();
        let pairs = standard_biped_symmetry_pairs();
        // LeftArm/RightArm match a biped pair and both sit at rest, so they mirror exactly.
        assert!(pose_symmetry_error(&skel, &pairs).abs() < 1e-6);
    }

    #[test]
    fn test_pose_symmetry_error_without_matching_pairs() {
        let skel = PoseSkeleton {
            joints: vec![make_joint("Spine"), make_joint("Hips")],
        };
        let err = pose_symmetry_error(&skel, &standard_biped_symmetry_pairs());
        assert_eq!(err, 0.0);
    }

    #[test]
    fn test_standard_biped_symmetry_pairs() {
        let pairs = standard_biped_symmetry_pairs();
        assert_eq!(pairs.len(), 13);
        assert!(pairs
            .iter()
            .any(|p| p.left_name == "LeftArm" && p.right_name == "RightArm"));
        for p in &pairs {
            assert_eq!(p.mirror_axis, 0);
            assert_eq!(
                p.left_name.strip_prefix("Left"),
                p.right_name.strip_prefix("Right")
            );
        }
    }

    #[test]
    fn test_find_joint_by_name() {
        let skel = make_simple_skeleton();
        let joint = find_joint_by_name(&skel, "Spine");
        assert!(joint.is_some());
        assert_eq!(joint.expect("Spine exists").name, "Spine");
    }

    #[test]
    fn test_find_joint_missing() {
        let skel = make_simple_skeleton();
        assert!(find_joint_by_name(&skel, "NonExistent").is_none());
    }

    #[test]
    fn test_quat_slerp_t0() {
        let a = identity_quat();
        let b = [0.0, 0.0, 1.0_f32.sin(), 1.0_f32.cos()];
        assert_close(&quat_slerp_pose(a, b, 0.0), &a);
    }

    #[test]
    fn test_quat_slerp_t1() {
        let a = identity_quat();
        let b = [0.0, 0.0, 0.5_f32.sin(), 0.5_f32.cos()];
        assert_close(&quat_slerp_pose(a, b, 1.0), &b);
    }

    #[test]
    fn test_quat_slerp_midpoint_halves_angle() {
        let a = identity_quat();
        let b = [0.0, 0.0, 0.5_f32.sin(), 0.5_f32.cos()];
        let expected = [0.0, 0.0, 0.25_f32.sin(), 0.25_f32.cos()];
        assert_close(&quat_slerp_pose(a, b, 0.5), &expected);
    }

    #[test]
    fn test_quat_slerp_takes_shortest_arc() {
        let a = identity_quat();
        let b = [0.0, 0.0, 0.5_f32.sin(), 0.5_f32.cos()];
        let neg_b = [-b[0], -b[1], -b[2], -b[3]];
        // -b is the same rotation as b, so the interpolated result must match.
        assert_close(&quat_slerp_pose(a, neg_b, 0.5), &quat_slerp_pose(a, b, 0.5));
    }

    #[test]
    fn test_interpolate_poses_midpoint() {
        let a = PoseSkeleton {
            joints: vec![JointPose {
                name: "Root".to_string(),
                rotation: identity_quat(),
                translation: [0.0, 0.0, 0.0],
                scale: 1.0,
            }],
        };
        let b = PoseSkeleton {
            joints: vec![JointPose {
                name: "Root".to_string(),
                rotation: identity_quat(),
                translation: [2.0, 0.0, 0.0],
                scale: 2.0,
            }],
        };
        let mid = interpolate_poses(&a, &b, 0.5);
        assert_eq!(mid.joints.len(), 1);
        assert_eq!(mid.joints[0].name, "Root");
        assert_close(&mid.joints[0].translation, &[1.0, 0.0, 0.0]);
        assert!((mid.joints[0].scale - 1.5).abs() < EPS);
        assert_close(&mid.joints[0].rotation, &identity_quat());
    }

    #[test]
    fn test_detect_symmetry_pairs() {
        let names: Vec<String> = vec![
            "LeftArm".to_string(),
            "RightArm".to_string(),
            "LeftLeg".to_string(),
            "RightLeg".to_string(),
            "Spine".to_string(),
        ];
        let pairs = detect_symmetry_pairs(&names);
        assert_eq!(pairs.len(), 2);
        assert!(pairs
            .iter()
            .any(|p| p.left_name == "LeftArm" && p.right_name == "RightArm"));
        assert!(pairs
            .iter()
            .any(|p| p.left_name == "LeftLeg" && p.right_name == "RightLeg"));
    }

    #[test]
    fn test_detect_symmetry_pairs_no_match() {
        let names: Vec<String> = vec!["Spine".to_string(), "Hips".to_string()];
        let pairs = detect_symmetry_pairs(&names);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_pose_distance_sym_identity() {
        let a = make_simple_skeleton();
        let b = a.clone();
        assert!(pose_distance_sym(&a, &b) < EPS);
    }

    #[test]
    fn test_pose_distance_sym_half_turn() {
        let a = PoseSkeleton {
            joints: vec![make_joint("Root")],
        };
        let mut b = a.clone();
        b.joints[0].rotation = [0.0, 0.0, 1.0, 0.0]; // 180 degrees about Z
        assert!((pose_distance_sym(&a, &b) - std::f32::consts::PI).abs() < EPS);
    }

    #[test]
    fn test_pose_distance_sym_empty() {
        let empty = PoseSkeleton { joints: Vec::new() };
        assert_eq!(pose_distance_sym(&empty, &empty), 0.0);
    }

    #[test]
    fn test_apply_pose_offset() {
        let mut skel = make_simple_skeleton();
        let delta = [0.0, 0.0, 0.1_f32.sin(), 0.1_f32.cos()];
        apply_pose_offset(&mut skel, "LeftArm", delta);
        // identity * delta == delta, which is already unit length.
        let joint = find_joint_by_name(&skel, "LeftArm").expect("LeftArm exists");
        assert_close(&joint.rotation, &delta);
        let other = find_joint_by_name(&skel, "RightArm").expect("RightArm exists");
        assert_eq!(other.rotation, identity_quat());
    }

    #[test]
    fn test_apply_pose_offset_missing_joint() {
        let mut skel = make_simple_skeleton();
        // Must neither panic nor touch any joint.
        apply_pose_offset(&mut skel, "NonExistent", [0.0, 0.0, 1.0, 0.0]);
        assert!(skel.joints.iter().all(|j| j.rotation == identity_quat()));
    }

    #[test]
    fn test_quat_multiply_basis() {
        let i = [1.0, 0.0, 0.0, 0.0];
        let j = [0.0, 1.0, 0.0, 0.0];
        assert_close(&quat_multiply_pose(i, j), &[0.0, 0.0, 1.0, 0.0]);
        assert_close(&quat_multiply_pose(j, i), &[0.0, 0.0, -1.0, 0.0]);
        assert_close(&quat_multiply_pose(identity_quat(), j), &j);
    }
}
488}