/// Pose of a single named joint.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct JointPose {
    /// Joint name; the lookup helpers in this file match joints by this string.
    pub name: String,
    /// Rotation quaternion stored as `[x, y, z, w]`.
    pub rotation: [f32; 4],
    /// Translation as three `f32` components, indexed 0..3.
    pub translation: [f32; 3],
    /// Scalar scale factor.
    pub scale: f32,
}
15
/// Names of two joints that are treated as mirror images of each other.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SymmetryPair {
    /// Name of the joint on the left side.
    pub left_name: String,
    /// Name of the joint on the right side.
    pub right_name: String,
    /// Index of the component that is negated when mirroring: 0, 1 or 2
    /// select the first, second or third (x, y, z) component.
    /// `mirror_joint_rotation` returns its input unchanged for any other value.
    pub mirror_axis: u8,
}
23
/// A pose expressed as a flat list of joint poses.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct PoseSkeleton {
    /// Joints of the pose; functions that combine two skeletons pair them by position in this list.
    pub joints: Vec<JointPose>,
}
29
/// Mirrors an `[x, y, z, w]` quaternion across the plane perpendicular to `axis`.
///
/// For `axis` 0, 1 or 2 the matching vector component and `w` are negated.
/// Since `q` and `-q` encode the same rotation, this is the same rotation as
/// negating the two *other* vector components. Any other `axis` value returns
/// `q` unchanged.
#[allow(dead_code)]
pub fn mirror_joint_rotation(q: [f32; 4], axis: u8) -> [f32; 4] {
    let component = usize::from(axis);
    if component >= 3 {
        // Unknown axis: leave the rotation as it is.
        return q;
    }

    let mut mirrored = q;
    mirrored[component] = -mirrored[component];
    mirrored[3] = -mirrored[3];
    mirrored
}
44
45#[allow(dead_code)]
47pub fn mirror_pose(skeleton: &PoseSkeleton, pairs: &[SymmetryPair]) -> PoseSkeleton {
48 let mut joints = skeleton.joints.clone();
49
50 for pair in pairs {
51 let left_idx = joints.iter().position(|j| j.name == pair.left_name);
52 let right_idx = joints.iter().position(|j| j.name == pair.right_name);
53
54 if let (Some(li), Some(ri)) = (left_idx, right_idx) {
55 let left_rot = joints[li].rotation;
56 let right_rot = joints[ri].rotation;
57 let left_trans = joints[li].translation;
58 let right_trans = joints[ri].translation;
59
60 joints[li].rotation = mirror_joint_rotation(right_rot, pair.mirror_axis);
61 joints[ri].rotation = mirror_joint_rotation(left_rot, pair.mirror_axis);
62
63 let mut new_left_trans = right_trans;
65 let mut new_right_trans = left_trans;
66 let ax = pair.mirror_axis as usize;
67 new_left_trans[ax] = -right_trans[ax];
68 new_right_trans[ax] = -left_trans[ax];
69
70 joints[li].translation = new_left_trans;
71 joints[ri].translation = new_right_trans;
72 }
73 }
74
75 PoseSkeleton { joints }
76}
77
78#[allow(dead_code)]
80pub fn enforce_symmetry_pose(skeleton: &mut PoseSkeleton, pairs: &[SymmetryPair], blend: f32) {
81 let blend = blend.clamp(0.0, 1.0);
82 let mirrored = mirror_pose(skeleton, pairs);
83
84 for (joint, mirrored_joint) in skeleton.joints.iter_mut().zip(mirrored.joints.iter()) {
85 joint.rotation = quat_slerp_pose(joint.rotation, mirrored_joint.rotation, blend * 0.5);
86
87 for i in 0..3 {
88 joint.translation[i] +=
89 (mirrored_joint.translation[i] - joint.translation[i]) * blend * 0.5;
90 }
91 }
92}
93
94#[allow(dead_code)]
96pub fn pose_symmetry_error(skeleton: &PoseSkeleton, pairs: &[SymmetryPair]) -> f32 {
97 let mut sum_sq = 0.0_f32;
98 let mut count = 0;
99
100 for pair in pairs {
101 let left = find_joint_by_name(skeleton, &pair.left_name);
102 let right = find_joint_by_name(skeleton, &pair.right_name);
103
104 if let (Some(l), Some(r)) = (left, right) {
105 let mirrored_r = mirror_joint_rotation(r.rotation, pair.mirror_axis);
106 let dist = quat_angle_distance(l.rotation, mirrored_r);
108 sum_sq += dist * dist;
109 count += 1;
110 }
111 }
112
113 if count == 0 {
114 0.0
115 } else {
116 (sum_sq / count as f32).sqrt()
117 }
118}
119
120#[allow(dead_code)]
122pub fn standard_biped_symmetry_pairs() -> Vec<SymmetryPair> {
123 let pairs_data = [
124 ("LeftArm", "RightArm"),
125 ("LeftForeArm", "RightForeArm"),
126 ("LeftHand", "RightHand"),
127 ("LeftUpLeg", "RightUpLeg"),
128 ("LeftLeg", "RightLeg"),
129 ("LeftFoot", "RightFoot"),
130 ("LeftToeBase", "RightToeBase"),
131 ("LeftShoulder", "RightShoulder"),
132 ("LeftHandThumb1", "RightHandThumb1"),
133 ("LeftHandIndex1", "RightHandIndex1"),
134 ("LeftHandMiddle1", "RightHandMiddle1"),
135 ("LeftHandRing1", "RightHandRing1"),
136 ("LeftHandPinky1", "RightHandPinky1"),
137 ];
138
139 pairs_data
140 .iter()
141 .map(|(l, r)| SymmetryPair {
142 left_name: l.to_string(),
143 right_name: r.to_string(),
144 mirror_axis: 0, })
146 .collect()
147}
148
149#[allow(dead_code)]
151pub fn find_joint_by_name<'a>(skeleton: &'a PoseSkeleton, name: &str) -> Option<&'a JointPose> {
152 skeleton.joints.iter().find(|j| j.name == name)
153}
154
/// Spherical linear interpolation between two `[x, y, z, w]` quaternions.
///
/// * `t` is clamped to `[0, 1]`; `0` yields `a`, `1` yields `b` (possibly
///   sign-flipped, see below).
/// * When the 4-D dot product of the inputs is negative, `b` is negated first
///   so the interpolation follows the shorter arc.
/// * When the (adjusted) dot product exceeds `0.9995` the inputs are nearly
///   parallel and the function falls back to a normalised linear blend, which
///   avoids dividing by a vanishing sine.
///
/// The slerp branch does not renormalise its result.
#[allow(dead_code)]
pub fn quat_slerp_pose(a: [f32; 4], b: [f32; 4], t: f32) -> [f32; 4] {
    let t = t.clamp(0.0, 1.0);

    let mut target = b;
    let mut cos_angle = a[0] * target[0] + a[1] * target[1] + a[2] * target[2] + a[3] * target[3];
    if cos_angle < 0.0 {
        for component in target.iter_mut() {
            *component = -*component;
        }
        cos_angle = -cos_angle;
    }

    if cos_angle > 0.9995 {
        // Nearly identical orientations: linear blend followed by normalisation.
        let mut blended = a;
        for (value, end) in blended.iter_mut().zip(target.iter()) {
            *value += t * (*end - *value);
        }
        let norm = (blended[0] * blended[0]
            + blended[1] * blended[1]
            + blended[2] * blended[2]
            + blended[3] * blended[3])
            .sqrt()
            .max(1e-8);
        for value in blended.iter_mut() {
            *value /= norm;
        }
        return blended;
    }

    let full_angle = cos_angle.acos();
    let partial_angle = full_angle * t;
    let sin_full = full_angle.sin();

    let weight_a = (full_angle - partial_angle).sin() / sin_full;
    let weight_b = partial_angle.sin() / sin_full;

    let mut result = [0.0_f32; 4];
    for (out, (&start, &end)) in result.iter_mut().zip(a.iter().zip(target.iter())) {
        *out = weight_a * start + weight_b * end;
    }
    result
}
196
197#[allow(dead_code)]
199pub fn interpolate_poses(a: &PoseSkeleton, b: &PoseSkeleton, t: f32) -> PoseSkeleton {
200 let t = t.clamp(0.0, 1.0);
201 let joints = a
202 .joints
203 .iter()
204 .zip(b.joints.iter())
205 .map(|(ja, jb)| {
206 let lerp = |x: f32, y: f32| x + (y - x) * t;
207 JointPose {
208 name: ja.name.clone(),
209 rotation: quat_slerp_pose(ja.rotation, jb.rotation, t),
210 translation: [
211 lerp(ja.translation[0], jb.translation[0]),
212 lerp(ja.translation[1], jb.translation[1]),
213 lerp(ja.translation[2], jb.translation[2]),
214 ],
215 scale: lerp(ja.scale, jb.scale),
216 }
217 })
218 .collect();
219 PoseSkeleton { joints }
220}
221
222#[allow(dead_code)]
224pub fn detect_symmetry_pairs(joint_names: &[String]) -> Vec<SymmetryPair> {
225 let mut pairs = Vec::new();
226 for name in joint_names {
227 if let Some(suffix) = name.strip_prefix("Left") {
228 let right_name = format!("Right{suffix}");
229 if joint_names.iter().any(|n| n == &right_name) {
230 pairs.push(SymmetryPair {
231 left_name: name.clone(),
232 right_name,
233 mirror_axis: 0,
234 });
235 }
236 }
237 }
238 pairs
239}
240
241#[allow(dead_code)]
243pub fn pose_distance_sym(a: &PoseSkeleton, b: &PoseSkeleton) -> f32 {
244 let pairs: Vec<_> = a.joints.iter().zip(b.joints.iter()).collect();
245 if pairs.is_empty() {
246 return 0.0;
247 }
248 let sum: f32 = pairs
249 .iter()
250 .map(|(ja, jb)| quat_angle_distance(ja.rotation, jb.rotation))
251 .sum();
252 sum / pairs.len() as f32
253}
254
255#[allow(dead_code)]
257pub fn apply_pose_offset(skeleton: &mut PoseSkeleton, joint_name: &str, rotation_delta: [f32; 4]) {
258 if let Some(joint) = skeleton.joints.iter_mut().find(|j| j.name == joint_name) {
259 joint.rotation = quat_multiply_pose(joint.rotation, rotation_delta);
260 let [x, y, z, w] = joint.rotation;
262 let mag = (x * x + y * y + z * z + w * w).sqrt().max(1e-8);
263 joint.rotation = [x / mag, y / mag, z / mag, w / mag];
264 }
265}
266
/// Hamilton product `a * b` of two `[x, y, z, w]` quaternions.
///
/// The result is not normalised.
fn quat_multiply_pose(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let (lx, ly, lz, lw) = (a[0], a[1], a[2], a[3]);
    let (rx, ry, rz, rw) = (b[0], b[1], b[2], b[3]);

    let x = lw * rx + lx * rw + ly * rz - lz * ry;
    let y = lw * ry - lx * rz + ly * rw + lz * rx;
    let z = lw * rz + lx * ry - ly * rx + lz * rw;
    let w = lw * rw - lx * rx - ly * ry - lz * rz;

    [x, y, z, w]
}
279
/// Angle, in radians, between the rotations encoded by two `[x, y, z, w]`
/// quaternions.
///
/// The absolute value of the dot product is used, so `q` and `-q` are at
/// distance zero; the value is clamped to `[0, 1]` before `acos` to absorb
/// rounding overshoot.
fn quat_angle_distance(a: [f32; 4], b: [f32; 4]) -> f32 {
    let dot = a
        .iter()
        .zip(b.iter())
        .fold(0.0_f32, |acc, (p, q)| acc + p * q);
    let cos_half_angle = dot.abs().clamp(0.0, 1.0);
    cos_half_angle.acos() * 2.0
}
286
// Unit tests for the pose mirroring, symmetry and interpolation helpers above.
#[cfg(test)]
mod tests {
    use super::*;

    // Identity rotation in the `[x, y, z, w]` layout used by this file.
    fn identity_quat() -> [f32; 4] {
        [0.0, 0.0, 0.0, 1.0]
    }

    // Joint with identity rotation, zero translation and scale 1.
    fn make_joint(name: &str) -> JointPose {
        JointPose {
            name: name.to_string(),
            rotation: identity_quat(),
            translation: [0.0, 0.0, 0.0],
            scale: 1.0,
        }
    }

    // Three-joint skeleton: one left/right pair plus an unpaired joint.
    fn make_simple_skeleton() -> PoseSkeleton {
        PoseSkeleton {
            joints: vec![
                make_joint("LeftArm"),
                make_joint("RightArm"),
                make_joint("Spine"),
            ],
        }
    }

    // Mirroring the identity across axis 0 negates x (zero here) and w.
    #[test]
    fn test_mirror_identity_quat() {
        let q = identity_quat();
        let mirrored = mirror_joint_rotation(q, 0);
        assert_eq!(mirrored, [0.0, 0.0, 0.0, -1.0]);
    }

    // The left joint receives the right joint's translation with x negated:
    // -(-1.0) == 1.0.
    #[test]
    fn test_mirror_pose_swaps_joints() {
        let mut skel = make_simple_skeleton();
        skel.joints[0].translation = [1.0, 0.0, 0.0];
        skel.joints[1].translation = [-1.0, 0.0, 0.0];
        let pairs = vec![SymmetryPair {
            left_name: "LeftArm".to_string(),
            right_name: "RightArm".to_string(),
            mirror_axis: 0,
        }];

        let mirrored = mirror_pose(&skel, &pairs);
        assert!((mirrored.joints[0].translation[0] - 1.0).abs() < 1e-4);
    }

    // Enforcing symmetry at full blend must not make the symmetry error worse.
    #[test]
    fn test_enforce_symmetry_reduces_error() {
        let mut skel = PoseSkeleton {
            joints: vec![
                JointPose {
                    name: "LeftArm".to_string(),
                    // Unit quaternion: w is chosen so that x^2 + w^2 == 1.
                    rotation: [0.1, 0.0, 0.0, (1.0_f32 - 0.01_f32).sqrt()],
                    translation: [1.0, 0.0, 0.0],
                    scale: 1.0,
                },
                JointPose {
                    name: "RightArm".to_string(),
                    rotation: identity_quat(),
                    translation: [-1.0, 0.0, 0.0],
                    scale: 1.0,
                },
            ],
        };
        let pairs = vec![SymmetryPair {
            left_name: "LeftArm".to_string(),
            right_name: "RightArm".to_string(),
            mirror_axis: 0,
        }];
        let err_before = pose_symmetry_error(&skel, &pairs);
        enforce_symmetry_pose(&mut skel, &pairs, 1.0);
        let err_after = pose_symmetry_error(&skel, &pairs);
        assert!(
            err_after <= err_before + 1e-4,
            "symmetry error should not increase"
        );
    }

    // All rotations are identity, so the one pair present in the skeleton
    // (LeftArm/RightArm) contributes zero; the other standard pairs are
    // skipped because their joints are missing.
    #[test]
    fn test_pose_symmetry_error_symmetric_skeleton() {
        let skel = make_simple_skeleton();
        let pairs = standard_biped_symmetry_pairs();
        let err = pose_symmetry_error(&skel, &pairs);
        assert_eq!(err, 0.0);
    }

    // The built-in pair list is non-empty and contains an arm pair.
    #[test]
    fn test_standard_biped_symmetry_pairs_not_empty() {
        let pairs = standard_biped_symmetry_pairs();
        assert!(!pairs.is_empty());
        assert!(pairs.iter().any(|p| p.left_name.contains("Arm")));
    }

    // Lookup by name returns the matching joint.
    #[test]
    fn test_find_joint_by_name() {
        let skel = make_simple_skeleton();
        let joint = find_joint_by_name(&skel, "Spine");
        assert!(joint.is_some());
        assert_eq!(joint.expect("should succeed").name, "Spine");
    }

    // Lookup of an unknown name returns `None`.
    #[test]
    fn test_find_joint_missing() {
        let skel = make_simple_skeleton();
        assert!(find_joint_by_name(&skel, "NonExistent").is_none());
    }

    // At t = 0 the slerp result matches the first input (checked on w only).
    #[test]
    fn test_quat_slerp_t0() {
        let a = identity_quat();
        let b = [0.0, 0.0, 1.0_f32.sin(), 1.0_f32.cos()];
        let result = quat_slerp_pose(a, b, 0.0);
        assert!((result[3] - a[3]).abs() < 1e-4);
    }

    // At t = 1 the slerp result matches the second input (checked on z and w).
    #[test]
    fn test_quat_slerp_t1() {
        let a = identity_quat();
        let b = [0.0, 0.0, (0.5_f32).sin(), (0.5_f32).cos()];
        let result = quat_slerp_pose(a, b, 1.0);
        assert!((result[2] - b[2]).abs() < 1e-4);
        assert!((result[3] - b[3]).abs() < 1e-4);
    }

    // Halfway between two poses, translation and scale are the arithmetic means.
    #[test]
    fn test_interpolate_poses_midpoint() {
        let a = PoseSkeleton {
            joints: vec![JointPose {
                name: "Root".to_string(),
                rotation: identity_quat(),
                translation: [0.0, 0.0, 0.0],
                scale: 1.0,
            }],
        };
        let b = PoseSkeleton {
            joints: vec![JointPose {
                name: "Root".to_string(),
                rotation: identity_quat(),
                translation: [2.0, 0.0, 0.0],
                scale: 2.0,
            }],
        };
        let mid = interpolate_poses(&a, &b, 0.5);
        assert!((mid.joints[0].translation[0] - 1.0).abs() < 1e-4);
        assert!((mid.joints[0].scale - 1.5).abs() < 1e-4);
    }

    // Two Left*/Right* name pairs are detected; "Spine" is ignored.
    #[test]
    fn test_detect_symmetry_pairs() {
        let names: Vec<String> = vec![
            "LeftArm".to_string(),
            "RightArm".to_string(),
            "LeftLeg".to_string(),
            "RightLeg".to_string(),
            "Spine".to_string(),
        ];
        let pairs = detect_symmetry_pairs(&names);
        assert_eq!(pairs.len(), 2);
        assert!(pairs.iter().any(|p| p.left_name == "LeftArm"));
    }

    // No name starts with "Left", so nothing is detected.
    #[test]
    fn test_detect_symmetry_pairs_no_match() {
        let names: Vec<String> = vec!["Spine".to_string(), "Hips".to_string()];
        let pairs = detect_symmetry_pairs(&names);
        assert!(pairs.is_empty());
    }

    // A pose is at (near) zero distance from its own clone.
    #[test]
    fn test_pose_distance_sym_identity() {
        let a = make_simple_skeleton();
        let b = a.clone();
        let dist = pose_distance_sym(&a, &b);
        assert!(dist < 1e-4);
    }

    // Applying a non-identity delta must move the joint's w away from ~1;
    // the second operand only excuses the case where the delta itself is
    // (almost) the identity.
    #[test]
    fn test_apply_pose_offset() {
        let mut skel = make_simple_skeleton();
        let delta = [0.0, 0.0, (0.1_f32).sin(), (0.1_f32).cos()];
        apply_pose_offset(&mut skel, "LeftArm", delta);
        let joint = find_joint_by_name(&skel, "LeftArm").expect("should succeed");
        let still_identity = joint.rotation[3].abs() > 0.9999;
        assert!(!still_identity || delta[3] > 0.9999);
    }

    // Smoke test: an unknown joint name must be ignored without panicking.
    #[test]
    fn test_apply_pose_offset_missing_joint() {
        let mut skel = make_simple_skeleton();
        apply_pose_offset(&mut skel, "NonExistent", identity_quat());
    }
}