/// A single joint of a skeleton hierarchy, described by its rest-pose transform.
// NOTE(review): `dead_code` is presumably allowed because not every item is used outside
// the tests yet — confirm, and drop the allow once it is.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct Joint {
    /// Joint name; joints are matched across skeletons by this string.
    pub name: String,
    /// Index of the parent joint within the owning `joints` vector, or `None` for a root.
    pub parent: Option<usize>,
    /// Rest-pose rotation quaternion, stored as `[x, y, z, w]`.
    pub rest_rot: [f32; 4],
    /// Rest-pose offset from the parent joint, stored as `[x, y, z]`.
    pub rest_pos: [f32; 3],
}
19
/// A skeleton together with one pose: per-joint local rotations plus a root translation.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SkeletonPose {
    /// Joint hierarchy; parents are referenced by index into this vector.
    pub joints: Vec<Joint>,
    /// Local rotation quaternion (`[x, y, z, w]`) per joint, indexed like `joints`.
    pub local_rots: Vec<[f32; 4]>,
    /// Translation of the skeleton root, stored as `[x, y, z]`.
    pub root_pos: [f32; 3],
}
29
/// Name-based joint correspondence between a source and a target skeleton.
///
/// `source_joints[i]` is paired with `target_joints[i]`; entries beyond the end of the
/// shorter list have no partner and are ignored.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct RetargetMap {
    /// Joint names looked up in the source skeleton.
    pub source_joints: Vec<String>,
    /// Joint names looked up in the target skeleton, paired positionally with `source_joints`.
    pub target_joints: Vec<String>,
    /// Scale factor carried with the map.
    // NOTE(review): `retarget_pose_adv` never reads this field (root translation is scaled
    // by the skeleton-height ratio instead) — confirm whether it is meant to be applied.
    pub scale: f32,
}
38
/// Hamilton product `a * b` of two quaternions stored as `[x, y, z, w]`.
///
/// Under the usual `q v q⁻¹` rotation convention, the product rotates by `b` first and
/// then by `a`.
#[allow(dead_code)]
pub fn quat_multiply(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let (x1, y1, z1, w1) = (a[0], a[1], a[2], a[3]);
    let (x2, y2, z2, w2) = (b[0], b[1], b[2], b[3]);

    let x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2;
    let y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2;
    let z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2;
    let w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2;
    [x, y, z, w]
}
53
/// Inverse of a **unit** quaternion `[x, y, z, w]`, computed as its conjugate.
///
/// No division by the squared norm is performed, so the result is the true inverse only
/// when `q` has unit length; pass normalized quaternions.
#[allow(dead_code)]
pub fn quat_inverse(q: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = q;
    [-x, -y, -z, w]
}
59
/// Scales `q` to unit length.
///
/// A quaternion whose length is below `1e-10` (including the zero quaternion) has no
/// meaningful direction, so the identity `[0, 0, 0, 1]` is returned instead. NaN
/// components propagate to the result.
#[allow(dead_code)]
pub fn quat_normalize(q: [f32; 4]) -> [f32; 4] {
    let norm_sq = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3];
    let len = norm_sq.sqrt();
    // Written as `<` (not `>=` on the other branch) so a NaN length falls through to the
    // division and yields NaN rather than a silent identity.
    if len < 1e-10 {
        [0.0, 0.0, 0.0, 1.0]
    } else {
        q.map(|component| component / len)
    }
}
69
70#[allow(dead_code)]
72pub fn quat_slerp(a: [f32; 4], b: [f32; 4], t: f32) -> [f32; 4] {
73 let dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
74 let (b, dot) = if dot < 0.0 {
76 ([-b[0], -b[1], -b[2], -b[3]], -dot)
77 } else {
78 (b, dot)
79 };
80 let dot = dot.min(1.0);
81 if dot > 0.9995 {
82 return quat_normalize([
84 a[0] + t * (b[0] - a[0]),
85 a[1] + t * (b[1] - a[1]),
86 a[2] + t * (b[2] - a[2]),
87 a[3] + t * (b[3] - a[3]),
88 ]);
89 }
90 let theta_0 = dot.acos();
91 let theta = theta_0 * t;
92 let sin_theta = theta.sin();
93 let sin_theta_0 = theta_0.sin();
94 let s0 = ((1.0 - t) * theta_0).sin() / sin_theta_0;
95 let s1 = sin_theta / sin_theta_0;
96 quat_normalize([
97 s0 * a[0] + s1 * b[0],
98 s0 * a[1] + s1 * b[1],
99 s0 * a[2] + s1 * b[2],
100 s0 * a[3] + s1 * b[3],
101 ])
102}
103
104#[allow(dead_code)]
107pub fn quat_to_swing_twist(q: [f32; 4], twist_axis: [f32; 3]) -> ([f32; 4], [f32; 4]) {
108 let [x, y, z, w] = q;
109 let [ax, ay, az] = twist_axis;
110 let dot = x * ax + y * ay + z * az;
112 let twist = quat_normalize([dot * ax, dot * ay, dot * az, w]);
113 let swing = quat_multiply(q, quat_inverse(twist));
114 (quat_normalize(swing), twist)
115}
116
117#[allow(dead_code)]
123pub fn retarget_joint_rotation(
124 src_rot: [f32; 4],
125 src_rest: [f32; 4],
126 tgt_rest: [f32; 4],
127) -> [f32; 4] {
128 let delta = quat_multiply(src_rot, quat_inverse(src_rest));
130 quat_normalize(quat_multiply(delta, tgt_rest))
132}
133
134#[allow(dead_code)]
136pub fn retarget_pose_adv(
137 src: &SkeletonPose,
138 tgt_rest: &SkeletonPose,
139 map: &RetargetMap,
140) -> SkeletonPose {
141 let mut out = tgt_rest.clone();
142 out.root_pos = scale_root_translation(
143 src.root_pos,
144 compute_skeleton_height(src),
145 compute_skeleton_height(tgt_rest),
146 );
147
148 for (si, src_name) in map.source_joints.iter().enumerate() {
149 if let Some(tgt_name) = map.target_joints.get(si) {
150 let src_idx = src
152 .joints
153 .iter()
154 .position(|j| &j.name == src_name)
155 .unwrap_or(usize::MAX);
156 let tgt_idx = tgt_rest
158 .joints
159 .iter()
160 .position(|j| &j.name == tgt_name)
161 .unwrap_or(usize::MAX);
162
163 if src_idx < src.joints.len()
164 && tgt_idx < tgt_rest.joints.len()
165 && src_idx < src.local_rots.len()
166 {
167 let src_rot = src.local_rots[src_idx];
168 let src_rest_rot = src.joints[src_idx].rest_rot;
169 let tgt_rest_rot = tgt_rest.joints[tgt_idx].rest_rot;
170 out.local_rots[tgt_idx] =
171 retarget_joint_rotation(src_rot, src_rest_rot, tgt_rest_rot);
172 }
173 }
174 }
175 out
176}
177
/// Scales a root translation by `tgt_height / src_height`, so the motion keeps its
/// proportions on a skeleton of a different size.
///
/// If `src_height` is below `1e-6` (including zero or negative heights), the ratio is
/// unusable and `pos` is returned unchanged.
#[allow(dead_code)]
pub fn scale_root_translation(pos: [f32; 3], src_height: f32, tgt_height: f32) -> [f32; 3] {
    if src_height < 1e-6 {
        pos
    } else {
        let ratio = tgt_height / src_height;
        pos.map(|axis| axis * ratio)
    }
}
187
188#[allow(dead_code)]
190pub fn blend_poses(a: &SkeletonPose, b: &SkeletonPose, t: f32) -> SkeletonPose {
191 let joints = a.joints.clone();
192 let n = joints.len().min(a.local_rots.len()).min(b.local_rots.len());
193 let local_rots = (0..n)
194 .map(|i| quat_slerp(a.local_rots[i], b.local_rots[i], t))
195 .collect();
196 let root_pos = [
197 a.root_pos[0] + t * (b.root_pos[0] - a.root_pos[0]),
198 a.root_pos[1] + t * (b.root_pos[1] - a.root_pos[1]),
199 a.root_pos[2] + t * (b.root_pos[2] - a.root_pos[2]),
200 ];
201 SkeletonPose {
202 joints,
203 local_rots,
204 root_pos,
205 }
206}
207
208#[allow(dead_code)]
211pub fn compute_skeleton_height(pose: &SkeletonPose) -> f32 {
212 let mut max_y = 0.0_f32;
213 let mut world_y = vec![0.0_f32; pose.joints.len()];
215 for (i, joint) in pose.joints.iter().enumerate() {
216 let parent_y = joint.parent.map_or(0.0, |p| world_y[p]);
217 world_y[i] = parent_y + joint.rest_pos[1];
218 max_y = max_y.max(world_y[i]);
219 }
220 max_y.max(0.001)
221}
222
223#[allow(dead_code)]
225pub fn standard_biped_retarget_map() -> RetargetMap {
226 let joints = vec![
227 "Hips",
228 "Spine",
229 "Spine1",
230 "Neck",
231 "Head",
232 "LeftArm",
233 "LeftForeArm",
234 "LeftHand",
235 "RightArm",
236 "RightForeArm",
237 "RightHand",
238 "LeftUpLeg",
239 "LeftLeg",
240 "RightUpLeg",
241 ];
242 RetargetMap {
243 source_joints: joints.iter().map(|s| s.to_string()).collect(),
244 target_joints: joints.iter().map(|s| s.to_string()).collect(),
245 scale: 1.0,
246 }
247}
248
/// The identity rotation quaternion, `[x, y, z, w] = [0, 0, 0, 1]`.
#[allow(dead_code)]
fn identity_quat() -> [f32; 4] {
    const IDENTITY: [f32; 4] = [0.0, 0.0, 0.0, 1.0];
    IDENTITY
}
255
256#[allow(dead_code)]
257fn make_test_pose(n: usize) -> SkeletonPose {
258 let joints = (0..n)
259 .map(|i| Joint {
260 name: format!("Joint{i}"),
261 parent: if i == 0 { None } else { Some(i - 1) },
262 rest_rot: identity_quat(),
263 rest_pos: [0.0, 0.1 * i as f32, 0.0],
264 })
265 .collect();
266 let local_rots = vec![identity_quat(); n];
267 SkeletonPose {
268 joints,
269 local_rots,
270 root_pos: [0.0, 0.0, 0.0],
271 }
272}
273
#[cfg(test)]
mod tests {
    // Unit tests for the quaternion helpers and pose retargeting/blending functions.
    use super::*;

    fn id() -> [f32; 4] {
        [0.0, 0.0, 0.0, 1.0]
    }

    /// Component-wise quaternion comparison with a `1e-4` absolute tolerance.
    fn nearly_eq(a: [f32; 4], b: [f32; 4]) -> bool {
        (0..4).all(|i| (a[i] - b[i]).abs() < 1e-4)
    }

    /// Component-wise 3-vector comparison with a `1e-4` absolute tolerance.
    fn nearly_eq3(a: [f32; 3], b: [f32; 3]) -> bool {
        (0..3).all(|i| (a[i] - b[i]).abs() < 1e-4)
    }

    #[test]
    fn test_quat_multiply_identity_left() {
        let q = [0.1, 0.2, 0.3, 0.927];
        let q = quat_normalize(q);
        let result = quat_multiply(id(), q);
        assert!(nearly_eq(result, q));
    }

    #[test]
    fn test_quat_multiply_identity_right() {
        let q = quat_normalize([0.1, 0.2, 0.3, 0.927]);
        let result = quat_multiply(q, id());
        assert!(nearly_eq(result, q));
    }

    // Identity products cannot catch sign errors in the cross terms; i*j = k and j*i = -k can.
    #[test]
    fn test_quat_multiply_i_times_j_is_k() {
        let i = [1.0, 0.0, 0.0, 0.0];
        let j = [0.0, 1.0, 0.0, 0.0];
        assert!(nearly_eq(quat_multiply(i, j), [0.0, 0.0, 1.0, 0.0]));
        assert!(nearly_eq(quat_multiply(j, i), [0.0, 0.0, -1.0, 0.0]));
    }

    #[test]
    fn test_quat_inverse_composed_is_identity() {
        let q = quat_normalize([0.1, 0.2, 0.3, 0.927]);
        let qi = quat_inverse(q);
        let result = quat_normalize(quat_multiply(q, qi));
        assert!(nearly_eq(result, id()));
    }

    #[test]
    fn test_quat_inverse_conjugate() {
        let q = [0.1, 0.2, 0.3, 0.9];
        let qi = quat_inverse(q);
        assert_eq!(qi, [-0.1, -0.2, -0.3, 0.9]);
    }

    #[test]
    fn test_quat_normalize_length_one() {
        let q = quat_normalize([1.0, 2.0, 3.0, 4.0]);
        let len = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
        assert!((len - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_quat_normalize_zero_returns_identity() {
        let q = quat_normalize([0.0, 0.0, 0.0, 0.0]);
        assert_eq!(q, [0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_quat_slerp_t0() {
        let a = id();
        let frac = std::f32::consts::FRAC_1_SQRT_2;
        let b = quat_normalize([0.0, frac, 0.0, frac]);
        let result = quat_slerp(a, b, 0.0);
        assert!(nearly_eq(result, a));
    }

    #[test]
    fn test_quat_slerp_t1() {
        let a = id();
        let frac = std::f32::consts::FRAC_1_SQRT_2;
        let b = quat_normalize([0.0, frac, 0.0, frac]);
        let result = quat_slerp(a, b, 1.0);
        assert!(nearly_eq(result, b));
    }

    #[test]
    fn test_quat_slerp_t_half_normalized() {
        let a = id();
        let b = id();
        let result = quat_slerp(a, b, 0.5);
        assert!(nearly_eq(result, id()));
    }

    #[test]
    fn test_quat_slerp_midpoint_of_quarter_turn() {
        let frac = std::f32::consts::FRAC_1_SQRT_2;
        let b = quat_normalize([0.0, frac, 0.0, frac]);
        let result = quat_slerp(id(), b, 0.5);
        // Halfway through a 90° turn about Y is a 45° turn, i.e. a half-angle of 22.5° (pi/8).
        let half = std::f32::consts::FRAC_PI_8;
        assert!(nearly_eq(result, [0.0, half.sin(), 0.0, half.cos()]));
    }

    #[test]
    fn test_swing_twist_roundtrip() {
        let q = quat_normalize([0.1, 0.2, 0.0, 0.974]);
        let axis = [0.0, 1.0, 0.0];
        let (swing, twist) = quat_to_swing_twist(q, axis);
        let composed = quat_normalize(quat_multiply(swing, twist));
        assert!(nearly_eq(composed, quat_normalize(q)));
    }

    #[test]
    fn test_swing_twist_pure_twist() {
        let q = quat_normalize([0.0, 0.5, 0.0, 0.866]);
        let (swing, twist) = quat_to_swing_twist(q, [0.0, 1.0, 0.0]);
        // A rotation purely about the twist axis has no swing; check at full tolerance
        // (the previous 0.1 bound on `w` alone could hide real decomposition errors).
        assert!(nearly_eq(swing, id()));
        assert!(nearly_eq(twist, q));
    }

    #[test]
    fn test_retarget_pose_no_nan() {
        let src = make_test_pose(5);
        let tgt = make_test_pose(5);
        let map = RetargetMap {
            source_joints: src.joints.iter().map(|j| j.name.clone()).collect(),
            target_joints: tgt.joints.iter().map(|j| j.name.clone()).collect(),
            scale: 1.0,
        };
        let out = retarget_pose_adv(&src, &tgt, &map);
        for r in &out.local_rots {
            for v in r {
                assert!(!v.is_nan());
            }
        }
    }

    #[test]
    fn test_retarget_pose_transfers_mapped_rotation() {
        let mut src = make_test_pose(3);
        let tgt = make_test_pose(3);
        let rot = quat_normalize([0.1, 0.2, 0.3, 0.9]);
        src.local_rots[1] = rot;
        let map = RetargetMap {
            source_joints: vec!["Joint1".to_string()],
            target_joints: vec!["Joint1".to_string()],
            scale: 1.0,
        };
        let out = retarget_pose_adv(&src, &tgt, &map);
        // With identity rest rotations the mapped joint copies the source rotation...
        assert!(nearly_eq(out.local_rots[1], rot));
        // ...while unmapped joints keep the target's rest rotation.
        assert!(nearly_eq(out.local_rots[0], id()));
        assert!(nearly_eq(out.local_rots[2], id()));
    }

    #[test]
    fn test_blend_poses_t0() {
        let a = make_test_pose(4);
        let b = make_test_pose(4);
        let out = blend_poses(&a, &b, 0.0);
        for i in 0..4 {
            assert!(nearly_eq(out.local_rots[i], a.local_rots[i]));
        }
    }

    #[test]
    fn test_blend_poses_t1() {
        let a = make_test_pose(4);
        let b = make_test_pose(4);
        let out = blend_poses(&a, &b, 1.0);
        for i in 0..4 {
            assert!(nearly_eq(out.local_rots[i], b.local_rots[i]));
        }
    }

    #[test]
    fn test_blend_poses_root_lerp() {
        let mut a = make_test_pose(2);
        let mut b = make_test_pose(2);
        a.root_pos = [0.0, 0.0, 0.0];
        b.root_pos = [2.0, 4.0, 6.0];
        let out = blend_poses(&a, &b, 0.5);
        assert!(nearly_eq3(out.root_pos, [1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_compute_skeleton_height_positive() {
        let pose = make_test_pose(5);
        let h = compute_skeleton_height(&pose);
        assert!(h > 0.0);
    }

    #[test]
    fn test_scale_root_translation_proportional() {
        let pos = [1.0, 2.0, 3.0];
        let out = scale_root_translation(pos, 1.0, 2.0);
        assert!(nearly_eq3(out, [2.0, 4.0, 6.0]));
    }

    #[test]
    fn test_standard_biped_retarget_map_14_joints() {
        let map = standard_biped_retarget_map();
        assert_eq!(map.source_joints.len(), 14);
        assert_eq!(map.target_joints.len(), 14);
    }

    #[test]
    fn test_retarget_joint_rotation_identity_pass_through() {
        let rot = quat_normalize([0.1, 0.2, 0.3, 0.9]);
        let rest = id();
        let result = retarget_joint_rotation(rot, rest, rest);
        assert!(nearly_eq(quat_normalize(result), quat_normalize(rot)));
    }
}