use std::collections::HashMap;
use hisab::{Mat4, Quat, Vec3};
use tracing::instrument;
use crate::pose::Pose;
use crate::skeleton::{BoneId, Skeleton};
/// World-space transforms produced by [`forward_kinematics`], one matrix per bone.
///
/// Matrices are stored densely in the order the bones were evaluated; `index`
/// maps each [`BoneId`] to its slot in `matrices`.
#[derive(Debug, Clone)]
pub struct WorldTransforms {
    /// World matrix for each evaluated bone, in evaluation order.
    matrices: Vec<Mat4>,
    /// Maps a bone's id to its position in `matrices`.
    index: HashMap<BoneId, usize>,
}
impl WorldTransforms {
    /// Borrows the world matrix stored for `bone_id`, or `None` if the bone was
    /// not part of the evaluated skeleton.
    #[inline]
    fn stored_matrix(&self, bone_id: BoneId) -> Option<&Mat4> {
        let &slot = self.index.get(&bone_id)?;
        Some(&self.matrices[slot])
    }

    /// World-space position of `bone_id`, taken from the translation column
    /// (column 3) of its world matrix.
    ///
    /// Returns `None` if the bone was not part of the evaluated skeleton.
    #[must_use]
    #[inline]
    pub fn position(&self, bone_id: BoneId) -> Option<Vec3> {
        self.stored_matrix(bone_id)
            .map(|world| world.col(3).truncate())
    }

    /// World-space orientation of `bone_id`.
    ///
    /// The first three columns of the world matrix are each normalized (removing
    /// per-axis scale) and the resulting basis is converted to a quaternion.
    /// Returns `None` if the bone was not part of the evaluated skeleton.
    ///
    /// NOTE(review): a zero-length axis is passed to `normalize` unchecked.
    #[must_use]
    pub fn rotation(&self, bone_id: BoneId) -> Option<Quat> {
        let world = self.stored_matrix(bone_id)?;
        let unit_axis = |column: usize| world.col(column).truncate().normalize();
        let basis = hisab::Mat3::from_cols(unit_axis(0), unit_axis(1), unit_axis(2));
        Some(Quat::from_mat3(&basis))
    }

    /// Full world matrix of `bone_id` (copied), or `None` if the bone is unknown.
    #[must_use]
    #[inline]
    pub fn matrix(&self, bone_id: BoneId) -> Option<Mat4> {
        self.stored_matrix(bone_id).copied()
    }

    /// Number of bones that have a stored world transform.
    #[must_use]
    #[inline]
    pub fn len(&self) -> usize {
        self.matrices.len()
    }

    /// `true` when no bone transforms are stored.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[must_use]
#[instrument(skip_all, fields(skeleton = %skeleton.name, bone_count = skeleton.bone_count()))]
pub fn forward_kinematics(
skeleton: &Skeleton,
pose: &Pose,
root_position: Vec3,
root_rotation: Quat,
) -> WorldTransforms {
let bones = skeleton.bones();
let mut matrices = Vec::with_capacity(bones.len());
let mut index = HashMap::with_capacity(bones.len());
let root_mat = Mat4::from_rotation_translation(root_rotation, root_position);
for (i, bone) in bones.iter().enumerate() {
index.insert(bone.id, i);
let parent_mat = bone
.parent
.and_then(|pid| index.get(&pid))
.map(|&pi| matrices[pi])
.unwrap_or(root_mat);
let pose_rot = pose.get_joint(bone.id);
let combined_rot = bone.local_rotation * pose_rot;
let local_mat = Mat4::from_rotation_translation(combined_rot, bone.local_position);
let world_mat = parent_mat * local_mat;
matrices.push(world_mat);
}
WorldTransforms { matrices, index }
}
/// Mass-weighted average of the world-space bone positions in `transforms`.
///
/// Every bone of `skeleton` that has an entry in `transforms` contributes its
/// position weighted by `bone.mass`; bones without an entry are skipped. Returns
/// [`Vec3::ZERO`] unless the accumulated mass is strictly greater than zero
/// (e.g. for an empty skeleton).
#[must_use]
pub fn world_center_of_mass(skeleton: &Skeleton, transforms: &WorldTransforms) -> Vec3 {
    let (moment, mass_sum) = skeleton
        .bones()
        .iter()
        .filter_map(|bone| Some((bone.mass, transforms.position(bone.id)?)))
        .fold((Vec3::ZERO, 0.0_f32), |(moment_acc, mass_acc), (mass, pos)| {
            (moment_acc + mass * pos, mass_acc + mass)
        });
    let has_mass = mass_sum > 0.0;
    if has_mass {
        moment / mass_sum
    } else {
        Vec3::ZERO
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skeleton::Bone;

    /// Builds a three-bone chain: "root" at the origin, then "mid" and "tip",
    /// each offset one unit along +Y from its parent, all with unit mass.
    fn make_chain_skeleton() -> Skeleton {
        Skeleton {
            name: "chain".into(),
            bones: vec![
                Bone {
                    id: BoneId(0),
                    name: "root".into(),
                    parent: None,
                    length: 1.0,
                    mass: 1.0,
                    local_position: Vec3::ZERO,
                    local_rotation: Quat::IDENTITY,
                },
                Bone {
                    id: BoneId(1),
                    name: "mid".into(),
                    parent: Some(BoneId(0)),
                    length: 1.0,
                    mass: 1.0,
                    local_position: Vec3::new(0.0, 1.0, 0.0),
                    local_rotation: Quat::IDENTITY,
                },
                Bone {
                    id: BoneId(2),
                    name: "tip".into(),
                    parent: Some(BoneId(1)),
                    length: 1.0,
                    mass: 1.0,
                    local_position: Vec3::new(0.0, 1.0, 0.0),
                    local_rotation: Quat::IDENTITY,
                },
            ],
        }
    }

    #[test]
    fn rest_pose_preserves_local_transforms() {
        let skeleton = make_chain_skeleton();
        let pose = Pose::rest(3);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        assert_eq!(transforms.len(), 3);
        let p0 = transforms.position(BoneId(0)).unwrap();
        let p1 = transforms.position(BoneId(1)).unwrap();
        let p2 = transforms.position(BoneId(2)).unwrap();
        assert!((p0 - Vec3::ZERO).length() < 1e-5, "root at origin");
        assert!(
            (p1 - Vec3::new(0.0, 1.0, 0.0)).length() < 1e-5,
            "mid at y=1"
        );
        assert!(
            (p2 - Vec3::new(0.0, 2.0, 0.0)).length() < 1e-5,
            "tip at y=2"
        );
    }

    #[test]
    fn root_offset_shifts_all_bones() {
        let skeleton = make_chain_skeleton();
        let pose = Pose::rest(3);
        let offset = Vec3::new(10.0, 0.0, 0.0);
        let transforms = forward_kinematics(&skeleton, &pose, offset, Quat::IDENTITY);
        let p0 = transforms.position(BoneId(0)).unwrap();
        let p2 = transforms.position(BoneId(2)).unwrap();
        assert!(
            (p0 - Vec3::new(10.0, 0.0, 0.0)).length() < 1e-5,
            "root shifted by offset"
        );
        assert!(
            (p2 - Vec3::new(10.0, 2.0, 0.0)).length() < 1e-5,
            "tip shifted by offset"
        );
    }

    #[test]
    fn pose_rotation_affects_children() {
        let skeleton = make_chain_skeleton();
        let mut pose = Pose::rest(3);
        pose.set_joint(
            BoneId(0),
            Quat::from_rotation_z(std::f32::consts::FRAC_PI_2),
        );
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        let p1 = transforms.position(BoneId(1)).unwrap();
        assert!(
            (p1 - Vec3::new(-1.0, 0.0, 0.0)).length() < 1e-4,
            "mid bone should be at (-1,0,0) after 90° Z rotation, got ({:.3},{:.3},{:.3})",
            p1.x,
            p1.y,
            p1.z
        );
        let p2 = transforms.position(BoneId(2)).unwrap();
        assert!(
            (p2 - Vec3::new(-2.0, 0.0, 0.0)).length() < 1e-4,
            "tip should be at (-2,0,0), got ({:.3},{:.3},{:.3})",
            p2.x,
            p2.y,
            p2.z
        );
    }

    #[test]
    fn root_rotation_rotates_entire_skeleton() {
        let skeleton = make_chain_skeleton();
        let pose = Pose::rest(3);
        let root_rot = Quat::from_rotation_z(std::f32::consts::FRAC_PI_2);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, root_rot);
        let p2 = transforms.position(BoneId(2)).unwrap();
        assert!(
            (p2 - Vec3::new(-2.0, 0.0, 0.0)).length() < 1e-4,
            "tip should be at (-2,0,0) with root rotated 90°"
        );
    }

    #[test]
    fn world_com_at_rest() {
        let skeleton = make_chain_skeleton();
        let pose = Pose::rest(3);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        let com = world_center_of_mass(&skeleton, &transforms);
        assert!(
            (com - Vec3::new(0.0, 1.0, 0.0)).length() < 1e-5,
            "CoM should be at (0,1,0), got ({:.3},{:.3},{:.3})",
            com.x,
            com.y,
            com.z
        );
    }

    #[test]
    fn rotation_extraction() {
        let skeleton = make_chain_skeleton();
        let mut pose = Pose::rest(3);
        let rot45 = Quat::from_rotation_z(std::f32::consts::FRAC_PI_4);
        pose.set_joint(BoneId(0), rot45);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        let got = transforms.rotation(BoneId(0)).unwrap();
        assert!(
            got.dot(rot45).abs() > 0.999,
            "extracted rotation should match pose rotation"
        );
    }

    #[test]
    fn empty_skeleton_produces_empty_transforms() {
        let skeleton = Skeleton::new("empty");
        let pose = Pose::rest(0);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        assert!(transforms.is_empty());
    }

    // Lookups for an id that was never evaluated must return `None`, not panic.
    #[test]
    fn unknown_bone_lookups_return_none() {
        let skeleton = make_chain_skeleton();
        let pose = Pose::rest(3);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        let missing = BoneId(99);
        assert!(transforms.position(missing).is_none());
        assert!(transforms.rotation(missing).is_none());
        assert!(transforms.matrix(missing).is_none());
    }

    // `position` must agree with the translation column of `matrix`.
    #[test]
    fn matrix_translation_matches_position() {
        let skeleton = make_chain_skeleton();
        let pose = Pose::rest(3);
        let offset = Vec3::new(1.0, 2.0, 3.0);
        let transforms = forward_kinematics(&skeleton, &pose, offset, Quat::IDENTITY);
        for id in [BoneId(0), BoneId(1), BoneId(2)] {
            let from_matrix = transforms.matrix(id).unwrap().col(3).truncate();
            let from_position = transforms.position(id).unwrap();
            assert!(
                (from_matrix - from_position).length() < 1e-6,
                "matrix translation and position disagree"
            );
        }
    }

    // With unequal masses the CoM must be pulled toward the heavier bone.
    #[test]
    fn world_com_weights_by_mass() {
        let mut skeleton = make_chain_skeleton();
        skeleton.bones[2].mass = 2.0;
        let pose = Pose::rest(3);
        let transforms = forward_kinematics(&skeleton, &pose, Vec3::ZERO, Quat::IDENTITY);
        let com = world_center_of_mass(&skeleton, &transforms);
        // (0*1 + 1*1 + 2*2) / (1 + 1 + 2) = 1.25 along Y.
        assert!(
            (com - Vec3::new(0.0, 1.25, 0.0)).length() < 1e-5,
            "CoM should be at (0,1.25,0), got ({:.3},{:.3},{:.3})",
            com.x,
            com.y,
            com.z
        );
    }

    // No mass at all must yield the origin rather than a division by zero.
    #[test]
    fn world_com_of_empty_skeleton_is_origin() {
        let skeleton = Skeleton::new("empty");
        let pose = Pose::rest(0);
        let transforms =
            forward_kinematics(&skeleton, &pose, Vec3::new(5.0, 0.0, 0.0), Quat::IDENTITY);
        let com = world_center_of_mass(&skeleton, &transforms);
        assert!(com.length() < 1e-6, "empty skeleton CoM should be origin");
    }
}