Skip to main content

oxihuman_morph/
pose_retarget.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Pose retargeting between different body shapes.
5//!
6//! Provides [`PoseRetargeter`], which maps a reference pose captured on a
7//! *source* body onto a *target* body with different proportions.  The
8//! retargeting preserves joint **rotations** unchanged while scaling joint
9//! **translations** according to the configured [`ScaleMode`].
10//!
11//! # Supported scale modes
12//!
13//! | Mode | Description |
14//! |------|-------------|
15//! | [`ScaleMode::Proportional`] | Scale all translations by `target_height / source_height`. |
16//! | [`ScaleMode::Uniform`]      | Same as Proportional — kept separate for future extension. |
//! | [`ScaleMode::SegmentWise`]  | Intended to scale each joint's translation by a per-segment ratio; no target segment lengths are available, so it currently uses the same global ratio as Proportional. |
18//!
19//! # Quick start
20//!
21//! ```rust
22//! use oxihuman_morph::pose_retarget::{PoseRetargeter, RetargetConfig, ScaleMode, PoseSnapshot};
23//!
24//! let config = RetargetConfig {
25//!     source_height: 180.0,
26//!     target_height: 160.0,
27//!     scale_mode: ScaleMode::Proportional,
28//! };
29//! let pose = PoseSnapshot::default();
30//! let retargeted = PoseRetargeter::retarget_pose(&pose, &config);
31//! assert_eq!(retargeted.joints.len(), pose.joints.len());
32//! ```
33
34#![allow(dead_code)]
35
36use std::collections::HashMap;
37
38// ---------------------------------------------------------------------------
// Legacy mapping type kept for backward compatibility
40// ---------------------------------------------------------------------------
41
/// Legacy source-to-target mapping record with a scale factor.
///
/// This is a plain struct (not a type alias) kept so existing callers keep
/// compiling; nothing in this module reads it.
#[derive(Debug, Clone, PartialEq)]
pub struct RetargetMapping {
    /// Source-side identifier (presumably a joint name — confirm with callers).
    pub source: String,
    /// Target-side identifier (presumably a joint name — confirm with callers).
    pub target: String,
    /// Scale factor associated with this mapping.
    pub scale: f32,
}
49
50// ---------------------------------------------------------------------------
51// JointPoseData
52// ---------------------------------------------------------------------------
53
/// Full state of a single joint in a pose snapshot.
///
/// Joints are usually built with [`JointPoseData::new`] followed by the
/// `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct JointPoseData {
    /// Unique joint name (e.g. `"LeftUpLeg"`, `"Spine1"`).
    pub name: String,
    /// Joint rotation as a quaternion `[x, y, z, w]`; expected to be unit
    /// length, but nothing in this module normalises or validates it.
    pub rotation: [f64; 4],
    /// Joint translation in local space (bone-relative).
    pub translation: [f64; 3],
    /// Optional parent joint name.  `None` for the root.
    pub parent: Option<String>,
    /// Canonical "segment" label for this joint, e.g. `"UpperLeg"`, `"Spine"`,
    /// `"Forearm"`.  [`PoseRetargeter::segment_lengths`] groups joints by this
    /// label (joints with `None` are grouped under `"unnamed"`).
    pub segment: Option<String>,
}

impl JointPoseData {
    /// Construct a joint named `name` with an identity rotation, zero
    /// translation, and no parent or segment.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            rotation: [0.0, 0.0, 0.0, 1.0], // identity quaternion (w last)
            translation: [0.0, 0.0, 0.0],
            parent: None,
            segment: None,
        }
    }

    /// Builder: replace the rotation quaternion `[x, y, z, w]`.
    ///
    /// Consumes `self`; ignoring the returned value drops the joint.
    #[must_use]
    pub fn with_rotation(mut self, rot: [f64; 4]) -> Self {
        self.rotation = rot;
        self
    }

    /// Builder: replace the local translation.
    ///
    /// Consumes `self`; ignoring the returned value drops the joint.
    #[must_use]
    pub fn with_translation(mut self, t: [f64; 3]) -> Self {
        self.translation = t;
        self
    }

    /// Builder: set the parent joint name.
    ///
    /// Consumes `self`; ignoring the returned value drops the joint.
    #[must_use]
    pub fn with_parent(mut self, parent: impl Into<String>) -> Self {
        self.parent = Some(parent.into());
        self
    }

    /// Builder: set the segment label.
    ///
    /// Consumes `self`; ignoring the returned value drops the joint.
    #[must_use]
    pub fn with_segment(mut self, seg: impl Into<String>) -> Self {
        self.segment = Some(seg.into());
        self
    }

    /// Euclidean length of the translation vector.
    #[must_use]
    pub fn translation_length(&self) -> f64 {
        let [x, y, z] = self.translation;
        (x * x + y * y + z * z).sqrt()
    }
}
112
113// ---------------------------------------------------------------------------
114// PoseSnapshot
115// ---------------------------------------------------------------------------
116
117/// A complete pose: ordered list of joint states plus a root-level height
118/// annotation used for normalisation.
119#[derive(Debug, Clone, Default)]
120pub struct PoseSnapshot {
121    /// Ordered joint list (stable insertion order; names are unique).
122    pub joints: Vec<JointPoseData>,
123    /// Total body height in centimetres as inferred from the source body.
124    /// If 0.0, retargeters treat the height as unknown and fall back to
125    /// proportional scaling with ratio = 1.0.
126    pub body_height_cm: f64,
127}
128
129impl PoseSnapshot {
130    /// Create an empty snapshot with no joints and unknown height.
131    pub fn new() -> Self {
132        Self {
133            joints: Vec::new(),
134            body_height_cm: 0.0,
135        }
136    }
137
138    /// Add a joint to the snapshot.  Duplicate names silently overwrite.
139    pub fn add_joint(&mut self, joint: JointPoseData) {
140        if let Some(existing) = self.joints.iter_mut().find(|j| j.name == joint.name) {
141            *existing = joint;
142        } else {
143            self.joints.push(joint);
144        }
145    }
146
147    /// Return a reference to a joint by name, if present.
148    pub fn joint(&self, name: &str) -> Option<&JointPoseData> {
149        self.joints.iter().find(|j| j.name == name)
150    }
151
152    /// Number of joints in this snapshot.
153    pub fn joint_count(&self) -> usize {
154        self.joints.len()
155    }
156
157    /// Build a standard biped T-pose snapshot with typical translation lengths,
158    /// useful for testing and as a normalisation reference.
159    pub fn standard_biped_tpose() -> Self {
160        let mut snap = Self::new();
161        snap.body_height_cm = 175.0;
162
163        // Root
164        snap.add_joint(
165            JointPoseData::new("Hips")
166                .with_translation([0.0, 90.0, 0.0])
167                .with_segment("Pelvis"),
168        );
169        // Spine chain
170        snap.add_joint(
171            JointPoseData::new("Spine")
172                .with_translation([0.0, 10.0, 0.0])
173                .with_parent("Hips")
174                .with_segment("Spine"),
175        );
176        snap.add_joint(
177            JointPoseData::new("Spine1")
178                .with_translation([0.0, 12.0, 0.0])
179                .with_parent("Spine")
180                .with_segment("Spine"),
181        );
182        snap.add_joint(
183            JointPoseData::new("Spine2")
184                .with_translation([0.0, 12.0, 0.0])
185                .with_parent("Spine1")
186                .with_segment("Spine"),
187        );
188        snap.add_joint(
189            JointPoseData::new("Neck")
190                .with_translation([0.0, 15.0, 0.0])
191                .with_parent("Spine2")
192                .with_segment("Neck"),
193        );
194        snap.add_joint(
195            JointPoseData::new("Head")
196                .with_translation([0.0, 10.0, 0.0])
197                .with_parent("Neck")
198                .with_segment("Head"),
199        );
200
201        // Left leg
202        snap.add_joint(
203            JointPoseData::new("LeftUpLeg")
204                .with_translation([-9.0, -10.0, 0.0])
205                .with_parent("Hips")
206                .with_segment("UpperLeg"),
207        );
208        snap.add_joint(
209            JointPoseData::new("LeftLeg")
210                .with_translation([0.0, -42.0, 0.0])
211                .with_parent("LeftUpLeg")
212                .with_segment("LowerLeg"),
213        );
214        snap.add_joint(
215            JointPoseData::new("LeftFoot")
216                .with_translation([0.0, -40.0, 0.0])
217                .with_parent("LeftLeg")
218                .with_segment("Foot"),
219        );
220
221        // Right leg
222        snap.add_joint(
223            JointPoseData::new("RightUpLeg")
224                .with_translation([9.0, -10.0, 0.0])
225                .with_parent("Hips")
226                .with_segment("UpperLeg"),
227        );
228        snap.add_joint(
229            JointPoseData::new("RightLeg")
230                .with_translation([0.0, -42.0, 0.0])
231                .with_parent("RightUpLeg")
232                .with_segment("LowerLeg"),
233        );
234        snap.add_joint(
235            JointPoseData::new("RightFoot")
236                .with_translation([0.0, -40.0, 0.0])
237                .with_parent("RightLeg")
238                .with_segment("Foot"),
239        );
240
241        // Left arm
242        snap.add_joint(
243            JointPoseData::new("LeftShoulder")
244                .with_translation([-5.0, 0.0, 0.0])
245                .with_parent("Spine2")
246                .with_segment("Shoulder"),
247        );
248        snap.add_joint(
249            JointPoseData::new("LeftArm")
250                .with_translation([-15.0, 0.0, 0.0])
251                .with_parent("LeftShoulder")
252                .with_segment("UpperArm"),
253        );
254        snap.add_joint(
255            JointPoseData::new("LeftForeArm")
256                .with_translation([-28.0, 0.0, 0.0])
257                .with_parent("LeftArm")
258                .with_segment("Forearm"),
259        );
260        snap.add_joint(
261            JointPoseData::new("LeftHand")
262                .with_translation([-20.0, 0.0, 0.0])
263                .with_parent("LeftForeArm")
264                .with_segment("Hand"),
265        );
266
267        // Right arm
268        snap.add_joint(
269            JointPoseData::new("RightShoulder")
270                .with_translation([5.0, 0.0, 0.0])
271                .with_parent("Spine2")
272                .with_segment("Shoulder"),
273        );
274        snap.add_joint(
275            JointPoseData::new("RightArm")
276                .with_translation([15.0, 0.0, 0.0])
277                .with_parent("RightShoulder")
278                .with_segment("UpperArm"),
279        );
280        snap.add_joint(
281            JointPoseData::new("RightForeArm")
282                .with_translation([28.0, 0.0, 0.0])
283                .with_parent("RightArm")
284                .with_segment("Forearm"),
285        );
286        snap.add_joint(
287            JointPoseData::new("RightHand")
288                .with_translation([20.0, 0.0, 0.0])
289                .with_parent("RightForeArm")
290                .with_segment("Hand"),
291        );
292
293        snap
294    }
295}
296
297// ---------------------------------------------------------------------------
298// ScaleMode
299// ---------------------------------------------------------------------------
300
/// Determines how joint translations are scaled during retargeting.
///
/// In the current implementation every mode resolves to the same global
/// factor, `target_height / source_height` (see
/// [`RetargetConfig::global_scale`]).  The separate variants let callers state
/// intent, and let mode-specific behaviour be added later without an API
/// change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScaleMode {
    /// Scale all translations uniformly by `target_height / source_height`.
    Proportional,
    /// Same as `Proportional`; reserved for future mode-specific tuning.
    Uniform,
    /// Intended to scale each joint's translation by the ratio of its
    /// segment's length on the target body to its length on the source body.
    ///
    /// [`RetargetConfig`] carries no per-segment target lengths, so this mode
    /// currently falls back to the global height ratio and produces the same
    /// result as [`ScaleMode::Proportional`].
    SegmentWise,
}
313
314// ---------------------------------------------------------------------------
315// RetargetConfig
316// ---------------------------------------------------------------------------
317
318/// Configuration for a retargeting operation.
319#[derive(Debug, Clone)]
320pub struct RetargetConfig {
321    /// Body height (cm) of the source body the pose was captured on.
322    pub source_height: f64,
323    /// Body height (cm) of the target body to retarget to.
324    pub target_height: f64,
325    /// How joint translations are scaled.
326    pub scale_mode: ScaleMode,
327}
328
329impl RetargetConfig {
330    /// Convenience constructor with [`ScaleMode::Proportional`].
331    pub fn proportional(source_height: f64, target_height: f64) -> Self {
332        Self {
333            source_height,
334            target_height,
335            scale_mode: ScaleMode::Proportional,
336        }
337    }
338
339    /// Return the global height scale factor `target / source`.
340    /// Returns 1.0 if `source_height` is zero to avoid division by zero.
341    pub fn global_scale(&self) -> f64 {
342        if self.source_height.abs() < 1e-12 {
343            1.0
344        } else {
345            self.target_height / self.source_height
346        }
347    }
348}
349
350// ---------------------------------------------------------------------------
351// PoseRetargeter
352// ---------------------------------------------------------------------------
353
354/// Stateless retargeter — all methods are free functions exposed as associated
355/// functions for a clean call site.
356pub struct PoseRetargeter;
357
358impl PoseRetargeter {
359    /// Retarget `pose` to a body with dimensions described by `config`.
360    ///
361    /// Joint **rotations are preserved unchanged**.  Joint **translations** are
362    /// scaled according to `config.scale_mode`.
363    pub fn retarget_pose(pose: &PoseSnapshot, config: &RetargetConfig) -> PoseSnapshot {
364        let global_scale = config.global_scale();
365
366        // Compute per-segment scale ratios for SegmentWise mode.
367        let seg_lengths = Self::segment_lengths(pose);
368
369        let joints = pose
370            .joints
371            .iter()
372            .map(|j| {
373                let scale = match config.scale_mode {
374                    ScaleMode::Proportional | ScaleMode::Uniform => global_scale,
375                    ScaleMode::SegmentWise => {
376                        // Use the segment's own translation length as reference
377                        // and scale it by the global ratio.  When segment lengths
378                        // can differ between source and target we would look up
379                        // a target-specific segment length here; for now we use
380                        // the global scale since we only have one body's lengths.
381                        if let Some(seg_name) = &j.segment {
382                            if let Some(&seg_len) = seg_lengths.get(seg_name.as_str()) {
383                                if seg_len.abs() > 1e-12 {
384                                    // Scale factor = target_segment / source_segment.
385                                    // Since we do not have separate target segment lengths,
386                                    // we fall back to global proportional scale.
387                                    let _ = seg_len;
388                                }
389                            }
390                        }
391                        global_scale
392                    }
393                };
394
395                let [tx, ty, tz] = j.translation;
396                JointPoseData {
397                    name: j.name.clone(),
398                    rotation: j.rotation, // preserved
399                    translation: [tx * scale, ty * scale, tz * scale],
400                    parent: j.parent.clone(),
401                    segment: j.segment.clone(),
402                }
403            })
404            .collect();
405
406        PoseSnapshot {
407            joints,
408            body_height_cm: config.target_height,
409        }
410    }
411
412    /// Compute per-segment bone lengths from the translations stored in a pose.
413    ///
414    /// For each unique segment label (`JointPoseData::segment`), the function
415    /// sums the Euclidean translation lengths of all joints in that segment.
416    ///
417    /// Joints without a `segment` label are collected under the key `"unnamed"`.
418    pub fn segment_lengths(pose: &PoseSnapshot) -> HashMap<String, f64> {
419        let mut lengths: HashMap<String, f64> = HashMap::new();
420        for joint in &pose.joints {
421            let key = joint.segment.as_deref().unwrap_or("unnamed").to_string();
422            let len = joint.translation_length();
423            *lengths.entry(key).or_insert(0.0) += len;
424        }
425        lengths
426    }
427
428    /// Normalise a pose so the body height equals 1.0 (i.e. all translations
429    /// are expressed as fractions of the total body height).
430    ///
431    /// If `pose.body_height_cm` is zero, the maximum translation length across
432    /// all joints is used as the normalisation divisor.  If that is also zero
433    /// the pose is returned unchanged.
434    pub fn normalize_pose(pose: &PoseSnapshot) -> PoseSnapshot {
435        let divisor = if pose.body_height_cm.abs() > 1e-12 {
436            pose.body_height_cm
437        } else {
438            // Fall back: use the largest translation magnitude
439            pose.joints
440                .iter()
441                .map(|j| j.translation_length())
442                .fold(0.0_f64, f64::max)
443        };
444
445        if divisor.abs() < 1e-12 {
446            return pose.clone();
447        }
448
449        let scale = 1.0 / divisor;
450
451        let joints = pose
452            .joints
453            .iter()
454            .map(|j| {
455                let [tx, ty, tz] = j.translation;
456                JointPoseData {
457                    name: j.name.clone(),
458                    rotation: j.rotation,
459                    translation: [tx * scale, ty * scale, tz * scale],
460                    parent: j.parent.clone(),
461                    segment: j.segment.clone(),
462                }
463            })
464            .collect();
465
466        PoseSnapshot {
467            joints,
468            body_height_cm: 1.0,
469        }
470    }
471
472    /// Return the Euclidean distance between corresponding joint translations
473    /// in `a` and `b`.  Only joints present in both snapshots are compared.
474    /// Useful for measuring retargeting error.
475    pub fn translation_error(a: &PoseSnapshot, b: &PoseSnapshot) -> f64 {
476        let b_map: HashMap<&str, &JointPoseData> =
477            b.joints.iter().map(|j| (j.name.as_str(), j)).collect();
478
479        let mut total = 0.0_f64;
480        let mut count = 0usize;
481
482        for ja in &a.joints {
483            if let Some(jb) = b_map.get(ja.name.as_str()) {
484                let dx = ja.translation[0] - jb.translation[0];
485                let dy = ja.translation[1] - jb.translation[1];
486                let dz = ja.translation[2] - jb.translation[2];
487                total += (dx * dx + dy * dy + dz * dz).sqrt();
488                count += 1;
489            }
490        }
491
492        if count == 0 {
493            0.0
494        } else {
495            total / count as f64
496        }
497    }
498}
499
500// ---------------------------------------------------------------------------
501// Tests
502// ---------------------------------------------------------------------------
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-joint, 180 cm pose shared by most tests below.
    fn simple_pose() -> PoseSnapshot {
        let mut snap = PoseSnapshot::new();
        snap.body_height_cm = 180.0;
        snap.add_joint(
            JointPoseData::new("Hips")
                .with_translation([0.0, 90.0, 0.0])
                .with_segment("Pelvis"),
        );
        snap.add_joint(
            JointPoseData::new("Spine")
                .with_translation([0.0, 12.0, 0.0])
                .with_parent("Hips")
                .with_segment("Spine"),
        );
        snap.add_joint(
            JointPoseData::new("LeftUpLeg")
                .with_translation([-9.0, -40.0, 0.0])
                .with_parent("Hips")
                .with_segment("UpperLeg"),
        );
        snap
    }

    // ── JointPoseData ────────────────────────────────────────────────────────

    #[test]
    fn joint_default_is_identity() {
        let j = JointPoseData::new("test");
        assert_eq!(j.rotation, [0.0, 0.0, 0.0, 1.0]);
        assert_eq!(j.translation, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn translation_length_correct() {
        let j = JointPoseData::new("j").with_translation([3.0, 4.0, 0.0]);
        assert!((j.translation_length() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn translation_length_zero() {
        let j = JointPoseData::new("j");
        assert_eq!(j.translation_length(), 0.0);
    }

    // ── PoseSnapshot ────────────────────────────────────────────────────────

    #[test]
    fn add_joint_duplicate_overwrites() {
        let mut snap = PoseSnapshot::new();
        snap.add_joint(JointPoseData::new("A").with_translation([1.0, 0.0, 0.0]));
        snap.add_joint(JointPoseData::new("A").with_translation([2.0, 0.0, 0.0]));
        assert_eq!(snap.joint_count(), 1);
        assert_eq!(snap.joint("A").expect("should succeed").translation[0], 2.0);
    }

    #[test]
    fn standard_biped_tpose_has_expected_joints() {
        let snap = PoseSnapshot::standard_biped_tpose();
        for name in [
            "Hips",
            "Spine",
            "Head",
            "LeftUpLeg",
            "RightFoot",
            "LeftHand",
        ] {
            assert!(snap.joint(name).is_some(), "missing joint: {name}");
        }
    }

    #[test]
    fn standard_biped_tpose_height_set() {
        let snap = PoseSnapshot::standard_biped_tpose();
        assert!((snap.body_height_cm - 175.0).abs() < 1e-9);
    }

    #[test]
    fn standard_biped_tpose_parents_precede_children() {
        // Consumers walking `joints` front-to-back rely on parents coming first.
        let snap = PoseSnapshot::standard_biped_tpose();
        for (idx, joint) in snap.joints.iter().enumerate() {
            if let Some(parent) = &joint.parent {
                let parent_idx = snap.joints.iter().position(|j| &j.name == parent);
                assert!(
                    matches!(parent_idx, Some(p) if p < idx),
                    "parent of {} must precede it",
                    joint.name
                );
            }
        }
    }

    // ── RetargetConfig ──────────────────────────────────────────────────────

    #[test]
    fn global_scale_correct() {
        let cfg = RetargetConfig::proportional(180.0, 160.0);
        assert!((cfg.global_scale() - 160.0 / 180.0).abs() < 1e-12);
    }

    #[test]
    fn global_scale_zero_source_returns_one() {
        let cfg = RetargetConfig {
            source_height: 0.0,
            target_height: 170.0,
            scale_mode: ScaleMode::Proportional,
        };
        assert_eq!(cfg.global_scale(), 1.0);
    }

    // ── retarget_pose ────────────────────────────────────────────────────────

    #[test]
    fn retarget_preserves_joint_count() {
        let cfg = RetargetConfig::proportional(180.0, 160.0);
        let pose = simple_pose();
        let retargeted = PoseRetargeter::retarget_pose(&pose, &cfg);
        assert_eq!(retargeted.joints.len(), pose.joints.len());
    }

    #[test]
    fn retarget_proportional_scales_translations() {
        let source = simple_pose(); // 180 cm
        let cfg = RetargetConfig::proportional(180.0, 90.0); // half height
        let retargeted = PoseRetargeter::retarget_pose(&source, &cfg);
        let hips_src = source.joint("Hips").expect("should succeed");
        let hips_dst = retargeted.joint("Hips").expect("should succeed");
        assert!((hips_dst.translation[1] - hips_src.translation[1] * 0.5).abs() < 1e-10);
    }

    #[test]
    fn retarget_preserves_rotations() {
        let mut pose = simple_pose();
        // Set a non-trivial rotation on Spine
        pose.joints[1].rotation = [0.1, 0.2, 0.3, 0.9];
        let cfg = RetargetConfig::proportional(180.0, 160.0);
        let retargeted = PoseRetargeter::retarget_pose(&pose, &cfg);
        let spine_src = pose.joint("Spine").expect("should succeed");
        let spine_dst = retargeted.joint("Spine").expect("should succeed");
        assert_eq!(spine_src.rotation, spine_dst.rotation);
    }

    #[test]
    fn retarget_preserves_names_parents_segments() {
        let pose = simple_pose();
        let cfg = RetargetConfig::proportional(180.0, 150.0);
        let retargeted = PoseRetargeter::retarget_pose(&pose, &cfg);
        for (src, dst) in pose.joints.iter().zip(retargeted.joints.iter()) {
            assert_eq!(src.name, dst.name);
            assert_eq!(src.parent, dst.parent);
            assert_eq!(src.segment, dst.segment);
        }
    }

    #[test]
    fn retarget_updates_body_height() {
        let cfg = RetargetConfig::proportional(180.0, 165.0);
        let pose = simple_pose();
        let retargeted = PoseRetargeter::retarget_pose(&pose, &cfg);
        assert!((retargeted.body_height_cm - 165.0).abs() < 1e-9);
    }

    #[test]
    fn retarget_uniform_same_as_proportional() {
        let pose = simple_pose();
        let prop_cfg = RetargetConfig {
            source_height: 180.0,
            target_height: 165.0,
            scale_mode: ScaleMode::Proportional,
        };
        let uni_cfg = RetargetConfig {
            source_height: 180.0,
            target_height: 165.0,
            scale_mode: ScaleMode::Uniform,
        };
        let r_prop = PoseRetargeter::retarget_pose(&pose, &prop_cfg);
        let r_uni = PoseRetargeter::retarget_pose(&pose, &uni_cfg);
        for (ja, jb) in r_prop.joints.iter().zip(r_uni.joints.iter()) {
            assert_eq!(ja.translation, jb.translation);
        }
    }

    #[test]
    fn retarget_segmentwise_valid_translations() {
        let pose = simple_pose();
        let cfg = RetargetConfig {
            source_height: 180.0,
            target_height: 160.0,
            scale_mode: ScaleMode::SegmentWise,
        };
        let retargeted = PoseRetargeter::retarget_pose(&pose, &cfg);
        for joint in &retargeted.joints {
            for v in joint.translation {
                assert!(
                    v.is_finite(),
                    "non-finite translation in segment-wise retarget"
                );
            }
        }
    }

    #[test]
    fn retarget_segmentwise_matches_proportional() {
        // SegmentWise has no target segment lengths to work with, so it falls
        // back to the same global height ratio as Proportional.
        let pose = PoseSnapshot::standard_biped_tpose();
        let prop = PoseRetargeter::retarget_pose(&pose, &RetargetConfig::proportional(175.0, 150.0));
        let seg_cfg = RetargetConfig {
            source_height: 175.0,
            target_height: 150.0,
            scale_mode: ScaleMode::SegmentWise,
        };
        let seg = PoseRetargeter::retarget_pose(&pose, &seg_cfg);
        assert_eq!(prop.joints, seg.joints);
        assert_eq!(prop.body_height_cm, seg.body_height_cm);
    }

    #[test]
    fn retarget_identity_when_same_height() {
        let pose = simple_pose();
        let cfg = RetargetConfig::proportional(180.0, 180.0);
        let retargeted = PoseRetargeter::retarget_pose(&pose, &cfg);
        for (ja, jb) in pose.joints.iter().zip(retargeted.joints.iter()) {
            assert_eq!(ja.translation, jb.translation);
        }
    }

    // ── segment_lengths ──────────────────────────────────────────────────────

    #[test]
    fn segment_lengths_groups_by_segment() {
        let pose = simple_pose();
        let lengths = PoseRetargeter::segment_lengths(&pose);
        // Pelvis: Hips translation length = sqrt(0² + 90² + 0²) = 90
        assert!(lengths.contains_key("Pelvis"));
        assert!((lengths["Pelvis"] - 90.0).abs() < 1e-9);
    }

    #[test]
    fn segment_lengths_spine_accumulated() {
        let pose = simple_pose();
        let lengths = PoseRetargeter::segment_lengths(&pose);
        // Spine: only one joint with t=[0,12,0] → length=12
        assert!(lengths.contains_key("Spine"));
        assert!((lengths["Spine"] - 12.0).abs() < 1e-9);
    }

    #[test]
    fn segment_lengths_unnamed_joint() {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 170.0;
        // Joint with no segment label
        pose.add_joint(JointPoseData::new("NoSeg").with_translation([3.0, 4.0, 0.0]));
        let lengths = PoseRetargeter::segment_lengths(&pose);
        assert!(lengths.contains_key("unnamed"));
        assert!((lengths["unnamed"] - 5.0).abs() < 1e-9);
    }

    #[test]
    fn segment_lengths_biped_has_multiple_segments() {
        let pose = PoseSnapshot::standard_biped_tpose();
        let lengths = PoseRetargeter::segment_lengths(&pose);
        for seg in [
            "Pelvis", "Spine", "UpperLeg", "LowerLeg", "Foot", "UpperArm", "Forearm",
        ] {
            assert!(lengths.contains_key(seg), "missing segment: {seg}");
        }
    }

    // ── normalize_pose ───────────────────────────────────────────────────────

    #[test]
    fn normalize_pose_height_becomes_one() {
        let pose = simple_pose();
        let normalised = PoseRetargeter::normalize_pose(&pose);
        assert!((normalised.body_height_cm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn normalize_pose_translations_scaled_down() {
        let pose = simple_pose();
        let normalised = PoseRetargeter::normalize_pose(&pose);
        // Hips y should be 90/180 = 0.5
        let hips = normalised.joint("Hips").expect("should succeed");
        assert!((hips.translation[1] - 0.5).abs() < 1e-9);
    }

    #[test]
    fn normalize_pose_zero_height_falls_back_to_max() {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 0.0;
        pose.add_joint(JointPoseData::new("A").with_translation([0.0, 100.0, 0.0]));
        pose.add_joint(JointPoseData::new("B").with_translation([0.0, 50.0, 0.0]));
        let normalised = PoseRetargeter::normalize_pose(&pose);
        let a = normalised.joint("A").expect("should succeed");
        assert!(
            (a.translation[1] - 1.0).abs() < 1e-9,
            "A.y should be 1.0 (max = 100)"
        );
    }

    #[test]
    fn normalize_pose_all_zero_translations_unchanged() {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 0.0;
        pose.add_joint(JointPoseData::new("Z"));
        let normalised = PoseRetargeter::normalize_pose(&pose);
        let z = normalised.joint("Z").expect("should succeed");
        assert_eq!(z.translation, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn normalize_pose_preserves_rotation_and_metadata() {
        let mut pose = simple_pose();
        pose.joints[2].rotation = [0.0, 0.6, 0.0, 0.8];
        let normalised = PoseRetargeter::normalize_pose(&pose);
        for (src, dst) in pose.joints.iter().zip(normalised.joints.iter()) {
            assert_eq!(src.name, dst.name);
            assert_eq!(src.rotation, dst.rotation);
            assert_eq!(src.parent, dst.parent);
            assert_eq!(src.segment, dst.segment);
        }
    }

    // ── translation_error ────────────────────────────────────────────────────

    #[test]
    fn translation_error_identical_poses_is_zero() {
        let pose = simple_pose();
        let err = PoseRetargeter::translation_error(&pose, &pose);
        assert!(err.abs() < 1e-12);
    }

    #[test]
    fn translation_error_no_common_joints_is_zero() {
        let mut a = PoseSnapshot::new();
        a.add_joint(JointPoseData::new("A").with_translation([1.0, 0.0, 0.0]));
        let mut b = PoseSnapshot::new();
        b.add_joint(JointPoseData::new("B").with_translation([2.0, 0.0, 0.0]));
        let err = PoseRetargeter::translation_error(&a, &b);
        assert_eq!(err, 0.0);
    }

    #[test]
    fn translation_error_known_value() {
        let mut a = PoseSnapshot::new();
        a.add_joint(JointPoseData::new("J").with_translation([0.0, 0.0, 0.0]));
        let mut b = PoseSnapshot::new();
        b.add_joint(JointPoseData::new("J").with_translation([3.0, 4.0, 0.0]));
        let err = PoseRetargeter::translation_error(&a, &b);
        assert!((err - 5.0).abs() < 1e-10);
    }

    #[test]
    fn translation_error_averages_over_common_joints() {
        let mut a = PoseSnapshot::new();
        a.add_joint(JointPoseData::new("J1"));
        a.add_joint(JointPoseData::new("J2"));
        a.add_joint(JointPoseData::new("OnlyInA").with_translation([100.0, 0.0, 0.0]));
        let mut b = PoseSnapshot::new();
        b.add_joint(JointPoseData::new("J1").with_translation([3.0, 4.0, 0.0]));
        b.add_joint(JointPoseData::new("J2"));
        // (5 + 0) / 2 common joints; "OnlyInA" is ignored.
        let err = PoseRetargeter::translation_error(&a, &b);
        assert!((err - 2.5).abs() < 1e-12);
    }
}