1use crate::interpolate::{Keyframe, MorphTrack};
5use crate::mocap_bvh::{parse_bvh, BvhChannel};
6use crate::params::ParamState;
7use std::collections::HashMap;
8
/// Affine remapping rule for one scalar body parameter.
///
/// [`apply`](ParamRetarget::apply) computes `value * scale + offset` and, when
/// `clamp` is set, limits the result to `[0.0, 1.0]`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ParamRetarget {
    /// Multiplier applied to the incoming value.
    pub scale: f32,
    /// Constant added after scaling.
    pub offset: f32,
    /// Whether the mapped value is clamped into `[0.0, 1.0]`.
    pub clamp: bool,
}

impl ParamRetarget {
    /// Rule with scale `1.0`, offset `0.0` and clamping enabled.
    ///
    /// Values already inside `[0.0, 1.0]` pass through unchanged; values outside
    /// that range are still clamped.
    #[allow(dead_code)]
    pub fn identity() -> Self {
        Self::scaled(1.0)
    }

    /// Rule that multiplies by `factor`, with no offset and clamping enabled.
    #[allow(dead_code)]
    pub fn scaled(factor: f32) -> Self {
        ParamRetarget {
            scale: factor,
            offset: 0.0,
            clamp: true,
        }
    }

    /// Map `value` through this rule: `value * scale + offset`, then clamp to
    /// `[0.0, 1.0]` if `clamp` is set.
    #[allow(dead_code)]
    pub fn apply(&self, value: f32) -> f32 {
        let mapped = value * self.scale + self.offset;
        if !self.clamp {
            return mapped;
        }
        mapped.clamp(0.0, 1.0)
    }
}
53
54#[allow(dead_code)]
56pub struct AnimRetargetConfig {
57 pub height: ParamRetarget,
58 pub weight: ParamRetarget,
59 pub muscle: ParamRetarget,
60 pub age: ParamRetarget,
61}
62
63impl AnimRetargetConfig {
64 #[allow(dead_code)]
66 pub fn identity() -> Self {
67 Self {
68 height: ParamRetarget::identity(),
69 weight: ParamRetarget::identity(),
70 muscle: ParamRetarget::identity(),
71 age: ParamRetarget::identity(),
72 }
73 }
74
75 #[allow(dead_code)]
77 pub fn apply(&self, state: &ParamState) -> ParamState {
78 ParamState {
79 height: self.height.apply(state.height),
80 weight: self.weight.apply(state.weight),
81 muscle: self.muscle.apply(state.muscle),
82 age: self.age.apply(state.age),
83 extra: state.extra.clone(),
84 }
85 }
86}
87
88#[allow(dead_code)]
90pub fn retarget_keyframe(kf: &Keyframe, config: &AnimRetargetConfig) -> Keyframe {
91 Keyframe {
92 time: kf.time,
93 params: config.apply(&kf.params),
94 label: kf.label.clone(),
95 }
96}
97
98#[allow(dead_code)]
100pub fn retarget_track(track: &MorphTrack, config: &AnimRetargetConfig) -> MorphTrack {
101 let mut new_track = MorphTrack::new(track.name.clone());
102 for kf in track.keyframes_iter() {
103 new_track.add_keyframe(retarget_keyframe(kf, config));
104 }
105 new_track
106}
107
108#[allow(dead_code)]
111pub fn scale_track_time(track: &MorphTrack, time_scale: f32) -> MorphTrack {
112 let mut new_track = MorphTrack::new(track.name.clone());
113 for kf in track.keyframes_iter() {
114 new_track.add_keyframe(Keyframe {
115 time: kf.time * time_scale,
116 params: kf.params.clone(),
117 label: kf.label.clone(),
118 });
119 }
120 new_track
121}
122
123#[allow(dead_code)]
126pub fn trim_track(track: &MorphTrack, start_time: f32, end_time: f32) -> MorphTrack {
127 let mut new_track = MorphTrack::new(track.name.clone());
128 for kf in track.keyframes_iter() {
129 if kf.time >= start_time && kf.time <= end_time {
130 new_track.add_keyframe(kf.clone());
131 }
132 }
133 new_track
134}
135
136#[allow(dead_code)]
138pub fn reverse_track(track: &MorphTrack) -> MorphTrack {
139 let mut new_track = MorphTrack::new(track.name.clone());
140 if track.is_empty() {
141 return new_track;
142 }
143 let kfs: Vec<&Keyframe> = track.keyframes_iter().collect();
144 let last_time = kfs[kfs.len() - 1].time;
145 for kf in kfs {
146 new_track.add_keyframe(Keyframe {
147 time: last_time - kf.time,
148 params: kf.params.clone(),
149 label: kf.label.clone(),
150 });
151 }
152 new_track
153}
154
155#[allow(dead_code)]
157pub fn concat_tracks(first: &MorphTrack, second: &MorphTrack) -> MorphTrack {
158 let mut new_track = MorphTrack::new(first.name.clone());
159
160 for kf in first.keyframes_iter() {
161 new_track.add_keyframe(kf.clone());
162 }
163
164 let offset = first
165 .keyframes_iter()
166 .last()
167 .map(|kf| kf.time)
168 .unwrap_or(0.0);
169
170 for kf in second.keyframes_iter() {
171 new_track.add_keyframe(Keyframe {
172 time: kf.time + offset,
173 params: kf.params.clone(),
174 label: kf.label.clone(),
175 });
176 }
177
178 new_track
179}
180
/// Local transform of one animated joint in one BVH motion frame.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct BvhJointFrame {
    /// Joint name as declared in the BVH hierarchy.
    pub joint_name: String,
    /// Local rotation as a quaternion laid out `[x, y, z, w]`.
    pub local_rotation: [f32; 4],
    /// Local translation `[x, y, z]`; `parse_bvh_text` fills this only for the
    /// root joint and leaves it zero for every other joint.
    pub local_position: [f32; 3],
}

/// Decoded BVH motion: a frame rate plus per-frame joint transforms.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct BvhData {
    /// Frames per second of the motion data.
    pub fps: f32,
    /// One entry per motion frame, each listing the transforms of the animated
    /// joints in that frame.
    pub frames: Vec<Vec<BvhJointFrame>>,
}
204
/// Lookup table from BVH joint names to target-rig joint names.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SkeletonMapping {
    /// BVH joint name -> target joint name.
    pub map: HashMap<String, String>,
}

impl SkeletonMapping {
    /// Default mapping for CMU-style BVH joint names (`Hips`, `Spine`,
    /// `LeftUpLeg`, ...) onto the target rig's joint names.
    #[allow(dead_code)]
    pub fn default_cmu() -> Self {
        const PAIRS: [(&str, &str); 18] = [
            ("Hips", "pelvis"),
            ("Spine", "torso"),
            ("Spine1", "spine_02"),
            ("Spine2", "spine_03"),
            ("Neck", "neck_01"),
            ("Head", "head"),
            ("LeftArm", "left_shoulder"),
            ("LeftForeArm", "left_elbow"),
            ("LeftHand", "left_wrist"),
            ("RightArm", "right_shoulder"),
            ("RightForeArm", "right_elbow"),
            ("RightHand", "right_wrist"),
            ("LeftUpLeg", "left_hip"),
            ("LeftLeg", "left_knee"),
            ("LeftFoot", "left_ankle"),
            ("RightUpLeg", "right_hip"),
            ("RightLeg", "right_knee"),
            ("RightFoot", "right_ankle"),
        ];

        let mut map = HashMap::with_capacity(PAIRS.len());
        for &(bvh_name, target_name) in PAIRS.iter() {
            map.insert(bvh_name.to_string(), target_name.to_string());
        }
        SkeletonMapping { map }
    }

    /// Wrap a caller-supplied name table without modification.
    #[allow(dead_code)]
    pub fn from_map(map: HashMap<String, String>) -> Self {
        SkeletonMapping { map }
    }
}
250
/// Convert BVH Euler angles in degrees to a unit quaternion laid out `[x, y, z, w]`.
///
/// The elementary rotations are composed as `q = q_z * q_x * q_y` (Hamilton
/// product), which corresponds to the channel order
/// `Zrotation Xrotation Yrotation`. An axis without a channel is passed as `0.0`.
///
/// The result is re-normalised. If its length is below `1e-9`, the identity
/// `[0, 0, 0, 1]` is returned instead.
fn euler_zxy_to_quat(rx_deg: f32, ry_deg: f32, rz_deg: f32) -> [f32; 4] {
    // deg * (PI / 180) / 2: every elementary quaternion uses the half angle.
    let half = std::f32::consts::PI / 360.0;
    let (sx, cx) = (rx_deg * half).sin_cos();
    let (sy, cy) = (ry_deg * half).sin_cos();
    let (sz, cz) = (rz_deg * half).sin_cos();

    // q_zx = q_z * q_x, with q_z = [0, 0, sz, cz] and q_x = [sx, 0, 0, cx].
    let zx_x = cz * sx;
    let zx_y = sz * sx;
    let zx_z = sz * cx;
    let zx_w = cz * cx;

    // q = q_zx * q_y, with q_y = [0, sy, 0, cy].
    let x = zx_x * cy - zx_z * sy;
    let y = zx_w * sy + zx_y * cy;
    let z = zx_x * sy + zx_z * cy;
    let w = zx_w * cy - zx_y * sy;

    let len = (x * x + y * y + z * z + w * w).sqrt();
    if len < 1e-9 {
        [0.0, 0.0, 0.0, 1.0]
    } else {
        [x / len, y / len, z / len, w / len]
    }
}
311
312#[allow(dead_code)]
317pub fn parse_bvh_text(bvh: &str) -> anyhow::Result<BvhData> {
318 let bvh_file = parse_bvh(bvh).map_err(|e| anyhow::anyhow!("BVH parse error: {}", e))?;
319
320 let fps = if bvh_file.frame_time > 0.0 {
321 1.0 / bvh_file.frame_time
322 } else {
323 30.0
324 };
325
326 let skeleton = &bvh_file.skeleton;
328 let active_joints: Vec<usize> = skeleton
329 .joints
330 .iter()
331 .enumerate()
332 .filter(|(_, j)| !j.channels.is_empty())
333 .map(|(i, _)| i)
334 .collect();
335
336 let root_idx = skeleton.root_index;
337
338 let mut all_frames: Vec<Vec<BvhJointFrame>> = Vec::with_capacity(bvh_file.frames.len());
339
340 for bvh_frame in &bvh_file.frames {
341 let mut joint_frames: Vec<BvhJointFrame> = Vec::with_capacity(active_joints.len());
342
343 for &joint_idx in &active_joints {
344 let joint = &skeleton.joints[joint_idx];
345 let ch_offset = skeleton.channel_offset_for(joint_idx);
346 let is_root = joint_idx == root_idx;
347
348 let mut rx = 0.0_f32;
350 let mut ry = 0.0_f32;
351 let mut rz = 0.0_f32;
352 let mut tx = 0.0_f32;
353 let mut ty = 0.0_f32;
354 let mut tz = 0.0_f32;
355
356 for (i, ch) in joint.channels.iter().enumerate() {
357 let val = bvh_frame.values.get(ch_offset + i).copied().unwrap_or(0.0);
358 match ch {
359 BvhChannel::Xrotation => rx = val,
360 BvhChannel::Yrotation => ry = val,
361 BvhChannel::Zrotation => rz = val,
362 BvhChannel::Xposition => tx = val,
363 BvhChannel::Yposition => ty = val,
364 BvhChannel::Zposition => tz = val,
365 }
366 }
367
368 let local_rotation = euler_zxy_to_quat(rx, ry, rz);
369 let local_position = if is_root {
370 [tx, ty, tz]
371 } else {
372 [0.0, 0.0, 0.0]
373 };
374
375 joint_frames.push(BvhJointFrame {
376 joint_name: joint.name.clone(),
377 local_rotation,
378 local_position,
379 });
380 }
381
382 all_frames.push(joint_frames);
383 }
384
385 Ok(BvhData {
386 fps,
387 frames: all_frames,
388 })
389}
390
391#[allow(dead_code)]
399pub fn retarget_bvh_to_param_tracks(
400 bvh: &BvhData,
401 mapping: &SkeletonMapping,
402) -> HashMap<String, Vec<f32>> {
403 let mut result: HashMap<String, Vec<f32>> = HashMap::new();
404
405 for (bvh_name, target_name) in &mapping.map {
406 let mut raw_values: Vec<f32> = Vec::with_capacity(bvh.frames.len());
409
410 for frame_joints in &bvh.frames {
411 let jf = frame_joints.iter().find(|jf| &jf.joint_name == bvh_name);
413 let dominant = match jf {
414 None => 0.0_f32,
415 Some(jf) => {
416 let [qx, qy, qz, _qw] = jf.local_rotation;
418 let ax = qx.abs();
419 let ay = qy.abs();
420 let az = qz.abs();
421 if ax >= ay && ax >= az {
422 qx
423 } else if ay >= az {
424 qy
425 } else {
426 qz
427 }
428 }
429 };
430 raw_values.push(dominant);
431 }
432
433 if raw_values.is_empty() {
434 continue;
435 }
436
437 let normalised: Vec<f32> = raw_values.iter().map(|&v| (v + 1.0) * 0.5).collect();
440
441 result.insert(target_name.clone(), normalised);
442 }
443
444 result
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interpolate::{Keyframe, MorphTrack};
    use crate::params::ParamState;

    // Shorthand for a `ParamState` built from height, weight, muscle, age.
    fn make_params(h: f32, w: f32, m: f32, a: f32) -> ParamState {
        ParamState::new(h, w, m, a)
    }

    // Two keyframes at t = 0.0 and t = 1.0 with distinct parameters.
    fn make_track_two_kf() -> MorphTrack {
        let mut track = MorphTrack::new("test");
        track.add_keyframe(Keyframe::new(0.0, make_params(0.2, 0.3, 0.4, 0.5)));
        track.add_keyframe(Keyframe::new(1.0, make_params(0.6, 0.7, 0.8, 0.9)));
        track
    }

    #[test]
    fn param_retarget_identity_unchanged() {
        let r = ParamRetarget::identity();
        assert!((r.apply(0.7) - 0.7).abs() < 1e-6);
    }

    #[test]
    fn param_retarget_scaled_doubles() {
        let r = ParamRetarget::scaled(2.0);
        assert!((r.apply(0.3) - 0.6).abs() < 1e-6);
    }

    #[test]
    fn param_retarget_clamps_above_one() {
        let r = ParamRetarget::scaled(3.0);
        assert!((r.apply(0.8) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn param_retarget_clamps_below_zero() {
        let r = ParamRetarget {
            scale: 1.0,
            offset: -0.5,
            clamp: true,
        };
        assert!((r.apply(0.2) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn anim_retarget_config_identity_preserves_params() {
        let config = AnimRetargetConfig::identity();
        let state = make_params(0.1, 0.2, 0.3, 0.4);
        let result = config.apply(&state);
        assert!((result.height - 0.1).abs() < 1e-6);
        assert!((result.weight - 0.2).abs() < 1e-6);
        assert!((result.muscle - 0.3).abs() < 1e-6);
        assert!((result.age - 0.4).abs() < 1e-6);
    }

    #[test]
    fn anim_retarget_config_scale_weight() {
        let config = AnimRetargetConfig {
            height: ParamRetarget::identity(),
            weight: ParamRetarget::scaled(0.5),
            muscle: ParamRetarget::identity(),
            age: ParamRetarget::identity(),
        };
        let state = make_params(0.4, 0.8, 0.6, 0.5);
        let result = config.apply(&state);
        assert!((result.height - 0.4).abs() < 1e-6);
        assert!((result.weight - 0.4).abs() < 1e-6);
    }

    #[test]
    fn retarget_keyframe_applies_config() {
        let config = AnimRetargetConfig {
            height: ParamRetarget::scaled(0.5),
            weight: ParamRetarget::identity(),
            muscle: ParamRetarget::identity(),
            age: ParamRetarget::identity(),
        };
        let kf = Keyframe::new(2.5, make_params(1.0, 0.5, 0.5, 0.5));
        let result = retarget_keyframe(&kf, &config);
        assert!((result.time - 2.5).abs() < 1e-6);
        assert!((result.params.height - 0.5).abs() < 1e-6);
    }

    #[test]
    fn retarget_track_preserves_length() {
        let track = make_track_two_kf();
        let config = AnimRetargetConfig::identity();
        let result = retarget_track(&track, &config);
        assert_eq!(result.len(), track.len());
    }

    #[test]
    fn scale_track_time_doubles_durations() {
        let track = make_track_two_kf();
        let original_duration = track.duration();
        let scaled = scale_track_time(&track, 2.0);
        assert!((scaled.duration() - original_duration * 2.0).abs() < 1e-5);
    }

    #[test]
    fn trim_track_removes_outside_keyframes() {
        let mut track = MorphTrack::new("trim_test");
        track.add_keyframe(Keyframe::new(0.0, make_params(0.1, 0.1, 0.1, 0.1)));
        track.add_keyframe(Keyframe::new(1.0, make_params(0.2, 0.2, 0.2, 0.2)));
        track.add_keyframe(Keyframe::new(2.0, make_params(0.3, 0.3, 0.3, 0.3)));
        track.add_keyframe(Keyframe::new(3.0, make_params(0.4, 0.4, 0.4, 0.4)));
        let trimmed = trim_track(&track, 1.0, 2.0);
        assert_eq!(trimmed.len(), 2);
    }

    #[test]
    fn trim_track_keeps_inside_keyframes() {
        let mut track = MorphTrack::new("trim_keep");
        track.add_keyframe(Keyframe::new(0.5, make_params(0.1, 0.1, 0.1, 0.1)));
        track.add_keyframe(Keyframe::new(1.5, make_params(0.5, 0.5, 0.5, 0.5)));
        track.add_keyframe(Keyframe::new(2.5, make_params(0.9, 0.9, 0.9, 0.9)));
        let trimmed = trim_track(&track, 0.0, 3.0);
        assert_eq!(trimmed.len(), 3);
    }

    #[test]
    fn reverse_track_flips_order() {
        let track = make_track_two_kf();
        let reversed = reverse_track(&track);
        assert_eq!(reversed.len(), 2);
        let original_last = track
            .keyframes_iter()
            .last()
            .expect("should succeed")
            .params
            .clone();
        let reversed_first = reversed
            .keyframes_iter()
            .next()
            .expect("should succeed")
            .params
            .clone();
        assert!((original_last.height - reversed_first.height).abs() < 1e-6);
    }

    #[test]
    fn concat_tracks_total_length() {
        let first = make_track_two_kf();
        let second = make_track_two_kf();
        let combined = concat_tracks(&first, &second);
        assert_eq!(combined.len(), 4);
    }

    #[test]
    fn concat_tracks_second_offset_correctly() {
        let first = make_track_two_kf();
        let mut second = MorphTrack::new("second");
        second.add_keyframe(Keyframe::new(0.0, make_params(0.1, 0.1, 0.1, 0.1)));
        second.add_keyframe(Keyframe::new(0.5, make_params(0.9, 0.9, 0.9, 0.9)));
        let combined = concat_tracks(&first, &second);
        let last_kf = combined.keyframes_iter().last().expect("should succeed");
        assert!((last_kf.time - 1.5).abs() < 1e-6);
    }

    #[test]
    fn euler_zxy_zero_angles_is_identity() {
        let q = euler_zxy_to_quat(0.0, 0.0, 0.0);
        assert!(q[0].abs() < 1e-6 && q[1].abs() < 1e-6 && q[2].abs() < 1e-6);
        assert!((q[3] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn euler_zxy_single_axis_uses_half_angle() {
        let s = std::f32::consts::FRAC_1_SQRT_2;
        let q = euler_zxy_to_quat(90.0, 0.0, 0.0);
        assert!((q[0] - s).abs() < 1e-5 && (q[3] - s).abs() < 1e-5);
        assert!(q[1].abs() < 1e-5 && q[2].abs() < 1e-5);
    }

    // Two-joint CMU-style BVH: a 6-channel root `Hips` and a 3-channel `Spine`.
    // Frame 1 is the rest pose; frame 2 rotates Hips (Z = 10, X = 5) and Spine (Z = 5).
    fn minimal_bvh_bridge() -> &'static str {
        "HIERARCHY
ROOT Hips
{
  OFFSET 0.00 0.00 0.00
  CHANNELS 6 Xposition Yposition Zposition Zrotation Xrotation Yrotation
  JOINT Spine
  {
    OFFSET 0.00 5.21 0.00
    CHANNELS 3 Zrotation Xrotation Yrotation
    End Site
    {
      OFFSET 0.00 5.00 0.00
    }
  }
}
MOTION
Frames: 2
Frame Time: 0.033333
0.00 94.26 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.00 94.26 0.00 10.00 5.00 0.00 5.00 0.00 0.00
"
    }

    #[test]
    fn bvh_bridge_parse_no_error() {
        let result = parse_bvh_text(minimal_bvh_bridge());
        assert!(result.is_ok(), "parse_bvh_text returned Err: {:?}", result);
    }

    #[test]
    fn bvh_bridge_joint_count_per_frame() {
        let bvh = parse_bvh_text(minimal_bvh_bridge()).expect("parse failed");
        for frame in &bvh.frames {
            assert_eq!(
                frame.len(),
                2,
                "expected 2 joints per frame, got {}",
                frame.len()
            );
        }
    }

    #[test]
    fn bvh_bridge_default_cmu_has_hips() {
        let mapping = SkeletonMapping::default_cmu();
        assert!(
            mapping.map.contains_key("Hips"),
            "SkeletonMapping::default_cmu() must contain 'Hips'"
        );
    }

    #[test]
    fn bvh_bridge_retarget_nonempty_tracks() {
        let bvh = parse_bvh_text(minimal_bvh_bridge()).expect("parse failed");
        let mapping = SkeletonMapping::default_cmu();
        let tracks = retarget_bvh_to_param_tracks(&bvh, &mapping);
        assert!(
            !tracks.is_empty(),
            "retarget_bvh_to_param_tracks must produce at least one track"
        );
    }

    #[test]
    fn bvh_bridge_track_length_equals_frame_count() {
        let bvh = parse_bvh_text(minimal_bvh_bridge()).expect("parse failed");
        let frame_count = bvh.frames.len();
        let mapping = SkeletonMapping::default_cmu();
        let tracks = retarget_bvh_to_param_tracks(&bvh, &mapping);
        for (name, track) in &tracks {
            assert_eq!(
                track.len(),
                frame_count,
                "track '{}' has {} entries, expected {}",
                name,
                track.len(),
                frame_count
            );
        }
    }

    #[test]
    fn bvh_bridge_fps_from_frame_time() {
        let bvh = parse_bvh_text(minimal_bvh_bridge()).expect("parse failed");
        assert!((bvh.fps - 30.0).abs() < 0.01, "fps = {}", bvh.fps);
    }

    #[test]
    fn bvh_bridge_only_root_keeps_translation() {
        let bvh = parse_bvh_text(minimal_bvh_bridge()).expect("parse failed");
        let first = &bvh.frames[0];
        let hips = first
            .iter()
            .find(|jf| jf.joint_name == "Hips")
            .expect("Hips present");
        let spine = first
            .iter()
            .find(|jf| jf.joint_name == "Spine")
            .expect("Spine present");
        assert!(hips.local_position[0].abs() < 1e-6);
        assert!((hips.local_position[1] - 94.26).abs() < 1e-3);
        assert!(hips.local_position[2].abs() < 1e-6);
        assert_eq!(spine.local_position, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn bvh_bridge_pelvis_track_follows_rotation() {
        let bvh = parse_bvh_text(minimal_bvh_bridge()).expect("parse failed");
        let tracks = retarget_bvh_to_param_tracks(&bvh, &SkeletonMapping::default_cmu());
        let pelvis = tracks.get("pelvis").expect("pelvis track");
        // Rest pose maps to the neutral 0.5; the Z = 10 deg rotation in frame 2 raises it.
        assert!((pelvis[0] - 0.5).abs() < 1e-6, "frame 0 = {}", pelvis[0]);
        assert!(pelvis[1] > pelvis[0] + 0.01, "frame 1 = {}", pelvis[1]);
    }
}