1#![allow(dead_code)]
35
36use std::collections::HashMap;
37
/// Pairs an identifier on the source side with one on the target side,
/// together with a scale factor.
///
/// This is plain data: the code visible alongside this definition never
/// reads it, so the precise meaning of each field is defined by whoever
/// consumes the mapping.
#[derive(Debug, Clone)]
pub struct RetargetMapping {
    /// Source-side identifier (presumably a joint name — confirm with callers).
    pub source: String,
    /// Target-side identifier (presumably a joint name — confirm with callers).
    pub target: String,
    /// Scale factor attached to this source/target pair.
    pub scale: f32,
}
49
/// Pose of a single named joint.
#[derive(Debug, Clone, PartialEq)]
pub struct JointPoseData {
    /// Name identifying the joint.
    pub name: String,
    /// Rotation as four components. New joints start at `[0, 0, 0, 1]`
    /// (presumably an `(x, y, z, w)` identity quaternion — confirm).
    pub rotation: [f64; 4],
    /// Translation as `[x, y, z]`. New joints start at the origin.
    pub translation: [f64; 3],
    /// Name of the parent joint, or `None` when no parent was assigned.
    pub parent: Option<String>,
    /// Label of the body segment this joint belongs to, if any.
    pub segment: Option<String>,
}

impl JointPoseData {
    /// Creates a joint called `name` with rotation `[0, 0, 0, 1]`, a zero
    /// translation, and neither parent nor segment.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            rotation: [0.0, 0.0, 0.0, 1.0],
            translation: [0.0; 3],
            parent: None,
            segment: None,
        }
    }

    /// Builder step: returns the joint with its rotation replaced by `rot`.
    pub fn with_rotation(self, rot: [f64; 4]) -> Self {
        Self {
            rotation: rot,
            ..self
        }
    }

    /// Builder step: returns the joint with its translation replaced by `t`.
    pub fn with_translation(self, t: [f64; 3]) -> Self {
        Self {
            translation: t,
            ..self
        }
    }

    /// Builder step: returns the joint with `parent` recorded as its parent name.
    pub fn with_parent(self, parent: impl Into<String>) -> Self {
        Self {
            parent: Some(parent.into()),
            ..self
        }
    }

    /// Builder step: returns the joint with `seg` recorded as its segment label.
    pub fn with_segment(self, seg: impl Into<String>) -> Self {
        Self {
            segment: Some(seg.into()),
            ..self
        }
    }

    /// Euclidean length of the translation vector.
    pub fn translation_length(&self) -> f64 {
        // Squares are accumulated in x, y, z order.
        let squared: f64 = self.translation.iter().map(|c| c * c).sum();
        squared.sqrt()
    }
}
112
113#[derive(Debug, Clone, Default)]
120pub struct PoseSnapshot {
121 pub joints: Vec<JointPoseData>,
123 pub body_height_cm: f64,
127}
128
129impl PoseSnapshot {
130 pub fn new() -> Self {
132 Self {
133 joints: Vec::new(),
134 body_height_cm: 0.0,
135 }
136 }
137
138 pub fn add_joint(&mut self, joint: JointPoseData) {
140 if let Some(existing) = self.joints.iter_mut().find(|j| j.name == joint.name) {
141 *existing = joint;
142 } else {
143 self.joints.push(joint);
144 }
145 }
146
147 pub fn joint(&self, name: &str) -> Option<&JointPoseData> {
149 self.joints.iter().find(|j| j.name == name)
150 }
151
152 pub fn joint_count(&self) -> usize {
154 self.joints.len()
155 }
156
157 pub fn standard_biped_tpose() -> Self {
160 let mut snap = Self::new();
161 snap.body_height_cm = 175.0;
162
163 snap.add_joint(
165 JointPoseData::new("Hips")
166 .with_translation([0.0, 90.0, 0.0])
167 .with_segment("Pelvis"),
168 );
169 snap.add_joint(
171 JointPoseData::new("Spine")
172 .with_translation([0.0, 10.0, 0.0])
173 .with_parent("Hips")
174 .with_segment("Spine"),
175 );
176 snap.add_joint(
177 JointPoseData::new("Spine1")
178 .with_translation([0.0, 12.0, 0.0])
179 .with_parent("Spine")
180 .with_segment("Spine"),
181 );
182 snap.add_joint(
183 JointPoseData::new("Spine2")
184 .with_translation([0.0, 12.0, 0.0])
185 .with_parent("Spine1")
186 .with_segment("Spine"),
187 );
188 snap.add_joint(
189 JointPoseData::new("Neck")
190 .with_translation([0.0, 15.0, 0.0])
191 .with_parent("Spine2")
192 .with_segment("Neck"),
193 );
194 snap.add_joint(
195 JointPoseData::new("Head")
196 .with_translation([0.0, 10.0, 0.0])
197 .with_parent("Neck")
198 .with_segment("Head"),
199 );
200
201 snap.add_joint(
203 JointPoseData::new("LeftUpLeg")
204 .with_translation([-9.0, -10.0, 0.0])
205 .with_parent("Hips")
206 .with_segment("UpperLeg"),
207 );
208 snap.add_joint(
209 JointPoseData::new("LeftLeg")
210 .with_translation([0.0, -42.0, 0.0])
211 .with_parent("LeftUpLeg")
212 .with_segment("LowerLeg"),
213 );
214 snap.add_joint(
215 JointPoseData::new("LeftFoot")
216 .with_translation([0.0, -40.0, 0.0])
217 .with_parent("LeftLeg")
218 .with_segment("Foot"),
219 );
220
221 snap.add_joint(
223 JointPoseData::new("RightUpLeg")
224 .with_translation([9.0, -10.0, 0.0])
225 .with_parent("Hips")
226 .with_segment("UpperLeg"),
227 );
228 snap.add_joint(
229 JointPoseData::new("RightLeg")
230 .with_translation([0.0, -42.0, 0.0])
231 .with_parent("RightUpLeg")
232 .with_segment("LowerLeg"),
233 );
234 snap.add_joint(
235 JointPoseData::new("RightFoot")
236 .with_translation([0.0, -40.0, 0.0])
237 .with_parent("RightLeg")
238 .with_segment("Foot"),
239 );
240
241 snap.add_joint(
243 JointPoseData::new("LeftShoulder")
244 .with_translation([-5.0, 0.0, 0.0])
245 .with_parent("Spine2")
246 .with_segment("Shoulder"),
247 );
248 snap.add_joint(
249 JointPoseData::new("LeftArm")
250 .with_translation([-15.0, 0.0, 0.0])
251 .with_parent("LeftShoulder")
252 .with_segment("UpperArm"),
253 );
254 snap.add_joint(
255 JointPoseData::new("LeftForeArm")
256 .with_translation([-28.0, 0.0, 0.0])
257 .with_parent("LeftArm")
258 .with_segment("Forearm"),
259 );
260 snap.add_joint(
261 JointPoseData::new("LeftHand")
262 .with_translation([-20.0, 0.0, 0.0])
263 .with_parent("LeftForeArm")
264 .with_segment("Hand"),
265 );
266
267 snap.add_joint(
269 JointPoseData::new("RightShoulder")
270 .with_translation([5.0, 0.0, 0.0])
271 .with_parent("Spine2")
272 .with_segment("Shoulder"),
273 );
274 snap.add_joint(
275 JointPoseData::new("RightArm")
276 .with_translation([15.0, 0.0, 0.0])
277 .with_parent("RightShoulder")
278 .with_segment("UpperArm"),
279 );
280 snap.add_joint(
281 JointPoseData::new("RightForeArm")
282 .with_translation([28.0, 0.0, 0.0])
283 .with_parent("RightArm")
284 .with_segment("Forearm"),
285 );
286 snap.add_joint(
287 JointPoseData::new("RightHand")
288 .with_translation([20.0, 0.0, 0.0])
289 .with_parent("RightForeArm")
290 .with_segment("Hand"),
291 );
292
293 snap
294 }
295}
296
/// Selects how joint translations are scaled during retargeting.
///
/// In the retargeting code visible in this file, all three variants
/// currently resolve to the same global height ratio
/// ([`RetargetConfig::global_scale`]).
#[derive(Debug, Clone, PartialEq)]
pub enum ScaleMode {
    /// Scale every translation by the target/source height ratio.
    Proportional,
    /// Handled identically to [`ScaleMode::Proportional`] by the visible retargeter.
    Uniform,
    /// Going by the name, meant for per-segment scaling; the visible
    /// retargeter applies the global ratio for this variant as well.
    SegmentWise,
}

/// Parameters for retargeting a pose from one body height to another.
#[derive(Debug, Clone)]
pub struct RetargetConfig {
    /// Height of the body the pose was captured on.
    pub source_height: f64,
    /// Height of the body the pose is mapped onto. The visible retargeter
    /// writes this value into the resulting snapshot's `body_height_cm`.
    pub target_height: f64,
    /// Scaling strategy to apply.
    pub scale_mode: ScaleMode,
}

impl RetargetConfig {
    /// Creates a configuration that uses [`ScaleMode::Proportional`].
    pub fn proportional(source_height: f64, target_height: f64) -> Self {
        Self {
            source_height,
            target_height,
            scale_mode: ScaleMode::Proportional,
        }
    }

    /// Returns `target_height / source_height`.
    ///
    /// Falls back to the neutral scale `1.0` whenever that ratio is not
    /// usable: the source height is (nearly) zero, or either height is NaN
    /// or infinite. Without the fallback a single bad height would turn
    /// every scaled translation into NaN/infinity or collapse it to zero.
    pub fn global_scale(&self) -> f64 {
        /// Source heights with a magnitude below this are treated as zero.
        const MIN_SOURCE_HEIGHT: f64 = 1e-12;

        let ratio = self.target_height / self.source_height;
        // With a finite, non-zero source, `ratio` is finite exactly when
        // the target height is finite and the division does not overflow.
        let usable = self.source_height.is_finite()
            && self.source_height.abs() >= MIN_SOURCE_HEIGHT
            && ratio.is_finite();

        if usable {
            ratio
        } else {
            1.0
        }
    }
}
349
350pub struct PoseRetargeter;
357
358impl PoseRetargeter {
359 pub fn retarget_pose(pose: &PoseSnapshot, config: &RetargetConfig) -> PoseSnapshot {
364 let global_scale = config.global_scale();
365
366 let seg_lengths = Self::segment_lengths(pose);
368
369 let joints = pose
370 .joints
371 .iter()
372 .map(|j| {
373 let scale = match config.scale_mode {
374 ScaleMode::Proportional | ScaleMode::Uniform => global_scale,
375 ScaleMode::SegmentWise => {
376 if let Some(seg_name) = &j.segment {
382 if let Some(&seg_len) = seg_lengths.get(seg_name.as_str()) {
383 if seg_len.abs() > 1e-12 {
384 let _ = seg_len;
388 }
389 }
390 }
391 global_scale
392 }
393 };
394
395 let [tx, ty, tz] = j.translation;
396 JointPoseData {
397 name: j.name.clone(),
398 rotation: j.rotation, translation: [tx * scale, ty * scale, tz * scale],
400 parent: j.parent.clone(),
401 segment: j.segment.clone(),
402 }
403 })
404 .collect();
405
406 PoseSnapshot {
407 joints,
408 body_height_cm: config.target_height,
409 }
410 }
411
412 pub fn segment_lengths(pose: &PoseSnapshot) -> HashMap<String, f64> {
419 let mut lengths: HashMap<String, f64> = HashMap::new();
420 for joint in &pose.joints {
421 let key = joint.segment.as_deref().unwrap_or("unnamed").to_string();
422 let len = joint.translation_length();
423 *lengths.entry(key).or_insert(0.0) += len;
424 }
425 lengths
426 }
427
428 pub fn normalize_pose(pose: &PoseSnapshot) -> PoseSnapshot {
435 let divisor = if pose.body_height_cm.abs() > 1e-12 {
436 pose.body_height_cm
437 } else {
438 pose.joints
440 .iter()
441 .map(|j| j.translation_length())
442 .fold(0.0_f64, f64::max)
443 };
444
445 if divisor.abs() < 1e-12 {
446 return pose.clone();
447 }
448
449 let scale = 1.0 / divisor;
450
451 let joints = pose
452 .joints
453 .iter()
454 .map(|j| {
455 let [tx, ty, tz] = j.translation;
456 JointPoseData {
457 name: j.name.clone(),
458 rotation: j.rotation,
459 translation: [tx * scale, ty * scale, tz * scale],
460 parent: j.parent.clone(),
461 segment: j.segment.clone(),
462 }
463 })
464 .collect();
465
466 PoseSnapshot {
467 joints,
468 body_height_cm: 1.0,
469 }
470 }
471
472 pub fn translation_error(a: &PoseSnapshot, b: &PoseSnapshot) -> f64 {
476 let b_map: HashMap<&str, &JointPoseData> =
477 b.joints.iter().map(|j| (j.name.as_str(), j)).collect();
478
479 let mut total = 0.0_f64;
480 let mut count = 0usize;
481
482 for ja in &a.joints {
483 if let Some(jb) = b_map.get(ja.name.as_str()) {
484 let dx = ja.translation[0] - jb.translation[0];
485 let dy = ja.translation[1] - jb.translation[1];
486 let dz = ja.translation[2] - jb.translation[2];
487 total += (dx * dx + dy * dy + dz * dz).sqrt();
488 count += 1;
489 }
490 }
491
492 if count == 0 {
493 0.0
494 } else {
495 total / count as f64
496 }
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds a configuration with explicit heights and scale mode.
    fn cfg_with(source_height: f64, target_height: f64, scale_mode: ScaleMode) -> RetargetConfig {
        RetargetConfig {
            source_height,
            target_height,
            scale_mode,
        }
    }

    /// Three-joint fixture (Hips, Spine, LeftUpLeg) with a body height of 180.
    fn simple_pose() -> PoseSnapshot {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 180.0;

        let hips = JointPoseData::new("Hips")
            .with_translation([0.0, 90.0, 0.0])
            .with_segment("Pelvis");
        let spine = JointPoseData::new("Spine")
            .with_translation([0.0, 12.0, 0.0])
            .with_parent("Hips")
            .with_segment("Spine");
        let left_up_leg = JointPoseData::new("LeftUpLeg")
            .with_translation([-9.0, -40.0, 0.0])
            .with_parent("Hips")
            .with_segment("UpperLeg");

        pose.add_joint(hips);
        pose.add_joint(spine);
        pose.add_joint(left_up_leg);
        pose
    }

    // --- JointPoseData -----------------------------------------------------

    #[test]
    fn joint_default_is_identity() {
        let joint = JointPoseData::new("test");
        assert_eq!(joint.rotation, [0.0, 0.0, 0.0, 1.0]);
        assert_eq!(joint.translation, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn translation_length_correct() {
        let joint = JointPoseData::new("j").with_translation([3.0, 4.0, 0.0]);
        assert!(close(joint.translation_length(), 5.0, 1e-10));
    }

    #[test]
    fn translation_length_zero() {
        let joint = JointPoseData::new("j");
        assert_eq!(joint.translation_length(), 0.0);
    }

    // --- PoseSnapshot ------------------------------------------------------

    #[test]
    fn add_joint_duplicate_overwrites() {
        let mut pose = PoseSnapshot::new();
        for x in [1.0, 2.0] {
            pose.add_joint(JointPoseData::new("A").with_translation([x, 0.0, 0.0]));
        }
        assert_eq!(pose.joint_count(), 1);
        assert_eq!(pose.joint("A").expect("should succeed").translation[0], 2.0);
    }

    #[test]
    fn standard_biped_tpose_has_expected_joints() {
        let pose = PoseSnapshot::standard_biped_tpose();
        let expected = [
            "Hips",
            "Spine",
            "Head",
            "LeftUpLeg",
            "RightFoot",
            "LeftHand",
        ];
        for name in expected {
            assert!(pose.joint(name).is_some(), "missing joint: {name}");
        }
    }

    #[test]
    fn standard_biped_tpose_height_set() {
        let pose = PoseSnapshot::standard_biped_tpose();
        assert!(close(pose.body_height_cm, 175.0, 1e-9));
    }

    // --- RetargetConfig ----------------------------------------------------

    #[test]
    fn global_scale_correct() {
        let config = RetargetConfig::proportional(180.0, 160.0);
        assert!(close(config.global_scale(), 160.0 / 180.0, 1e-12));
    }

    #[test]
    fn global_scale_zero_source_returns_one() {
        let config = cfg_with(0.0, 170.0, ScaleMode::Proportional);
        assert_eq!(config.global_scale(), 1.0);
    }

    // --- retarget_pose -----------------------------------------------------

    #[test]
    fn retarget_preserves_joint_count() {
        let config = RetargetConfig::proportional(180.0, 160.0);
        let source = simple_pose();
        let result = PoseRetargeter::retarget_pose(&source, &config);
        assert_eq!(result.joints.len(), source.joints.len());
    }

    #[test]
    fn retarget_proportional_scales_translations() {
        let source = simple_pose();
        let config = RetargetConfig::proportional(180.0, 90.0);
        let result = PoseRetargeter::retarget_pose(&source, &config);

        let before = source.joint("Hips").expect("should succeed");
        let after = result.joint("Hips").expect("should succeed");
        assert!(close(after.translation[1], before.translation[1] * 0.5, 1e-10));
    }

    #[test]
    fn retarget_preserves_rotations() {
        let mut source = simple_pose();
        // Index 1 is the "Spine" joint of the fixture.
        source.joints[1].rotation = [0.1, 0.2, 0.3, 0.9];

        let config = RetargetConfig::proportional(180.0, 160.0);
        let result = PoseRetargeter::retarget_pose(&source, &config);

        let before = source.joint("Spine").expect("should succeed");
        let after = result.joint("Spine").expect("should succeed");
        assert_eq!(before.rotation, after.rotation);
    }

    #[test]
    fn retarget_updates_body_height() {
        let config = RetargetConfig::proportional(180.0, 165.0);
        let source = simple_pose();
        let result = PoseRetargeter::retarget_pose(&source, &config);
        assert!(close(result.body_height_cm, 165.0, 1e-9));
    }

    #[test]
    fn retarget_uniform_same_as_proportional() {
        let source = simple_pose();
        let proportional = PoseRetargeter::retarget_pose(
            &source,
            &cfg_with(180.0, 165.0, ScaleMode::Proportional),
        );
        let uniform =
            PoseRetargeter::retarget_pose(&source, &cfg_with(180.0, 165.0, ScaleMode::Uniform));

        for (p, u) in proportional.joints.iter().zip(uniform.joints.iter()) {
            assert_eq!(p.translation, u.translation);
        }
    }

    #[test]
    fn retarget_segmentwise_valid_translations() {
        let source = simple_pose();
        let config = cfg_with(180.0, 160.0, ScaleMode::SegmentWise);
        let result = PoseRetargeter::retarget_pose(&source, &config);

        for component in result.joints.iter().flat_map(|joint| joint.translation) {
            assert!(
                component.is_finite(),
                "non-finite translation in segment-wise retarget"
            );
        }
    }

    #[test]
    fn retarget_identity_when_same_height() {
        let source = simple_pose();
        let config = RetargetConfig::proportional(180.0, 180.0);
        let result = PoseRetargeter::retarget_pose(&source, &config);

        for (before, after) in source.joints.iter().zip(result.joints.iter()) {
            assert_eq!(before.translation, after.translation);
        }
    }

    // --- segment_lengths ---------------------------------------------------

    #[test]
    fn segment_lengths_groups_by_segment() {
        let lengths = PoseRetargeter::segment_lengths(&simple_pose());
        assert!(lengths.contains_key("Pelvis"));
        assert!(close(lengths["Pelvis"], 90.0, 1e-9));
    }

    #[test]
    fn segment_lengths_spine_accumulated() {
        let lengths = PoseRetargeter::segment_lengths(&simple_pose());
        assert!(lengths.contains_key("Spine"));
        assert!(close(lengths["Spine"], 12.0, 1e-9));
    }

    #[test]
    fn segment_lengths_unnamed_joint() {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 170.0;
        pose.add_joint(JointPoseData::new("NoSeg").with_translation([3.0, 4.0, 0.0]));

        let lengths = PoseRetargeter::segment_lengths(&pose);
        assert!(lengths.contains_key("unnamed"));
        assert!(close(lengths["unnamed"], 5.0, 1e-9));
    }

    #[test]
    fn segment_lengths_biped_has_multiple_segments() {
        let lengths = PoseRetargeter::segment_lengths(&PoseSnapshot::standard_biped_tpose());
        let expected = [
            "Pelvis", "Spine", "UpperLeg", "LowerLeg", "Foot", "UpperArm", "Forearm",
        ];
        for seg in expected {
            assert!(lengths.contains_key(seg), "missing segment: {seg}");
        }
    }

    // --- normalize_pose ----------------------------------------------------

    #[test]
    fn normalize_pose_height_becomes_one() {
        let normalised = PoseRetargeter::normalize_pose(&simple_pose());
        assert!(close(normalised.body_height_cm, 1.0, 1e-12));
    }

    #[test]
    fn normalize_pose_translations_scaled_down() {
        let normalised = PoseRetargeter::normalize_pose(&simple_pose());
        let hips = normalised.joint("Hips").expect("should succeed");
        // 90 / 180
        assert!(close(hips.translation[1], 0.5, 1e-9));
    }

    #[test]
    fn normalize_pose_zero_height_falls_back_to_max() {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 0.0;
        pose.add_joint(JointPoseData::new("A").with_translation([0.0, 100.0, 0.0]));
        pose.add_joint(JointPoseData::new("B").with_translation([0.0, 50.0, 0.0]));

        let normalised = PoseRetargeter::normalize_pose(&pose);
        let a = normalised.joint("A").expect("should succeed");
        assert!(
            close(a.translation[1], 1.0, 1e-9),
            "A.y should be 1.0 (max = 100)"
        );
    }

    #[test]
    fn normalize_pose_all_zero_translations_unchanged() {
        let mut pose = PoseSnapshot::new();
        pose.body_height_cm = 0.0;
        pose.add_joint(JointPoseData::new("Z"));

        let normalised = PoseRetargeter::normalize_pose(&pose);
        let z = normalised.joint("Z").expect("should succeed");
        assert_eq!(z.translation, [0.0, 0.0, 0.0]);
    }

    // --- translation_error -------------------------------------------------

    #[test]
    fn translation_error_identical_poses_is_zero() {
        let pose = simple_pose();
        let error = PoseRetargeter::translation_error(&pose, &pose);
        assert!(error.abs() < 1e-12);
    }

    #[test]
    fn translation_error_no_common_joints_is_zero() {
        let mut left = PoseSnapshot::new();
        left.add_joint(JointPoseData::new("A").with_translation([1.0, 0.0, 0.0]));
        let mut right = PoseSnapshot::new();
        right.add_joint(JointPoseData::new("B").with_translation([2.0, 0.0, 0.0]));

        assert_eq!(PoseRetargeter::translation_error(&left, &right), 0.0);
    }

    #[test]
    fn translation_error_known_value() {
        let mut left = PoseSnapshot::new();
        left.add_joint(JointPoseData::new("J").with_translation([0.0, 0.0, 0.0]));
        let mut right = PoseSnapshot::new();
        right.add_joint(JointPoseData::new("J").with_translation([3.0, 4.0, 0.0]));

        let error = PoseRetargeter::translation_error(&left, &right);
        assert!(close(error, 5.0, 1e-10));
    }
}