1use glam::{vec3, EulerRot, Mat3, Mat4, Quat, Vec3};
2use nalgebra::{DMatrix, Dyn, Matrix, VecStorage, U1};
3use std::f32::consts::PI;
4
/// Damping factor λ of the damped-least-squares pseudo-inverse `Jᵀ (J Jᵀ + λ² I)⁻¹`.
/// Larger values shrink each step and reduce sensitivity near singular configurations.
const DAMPING: f32 = 10.0;
/// Maximum magnitude, in degrees, of any single joint-angle increment per `solve` call
/// (converted with `to_radians` before use).
const THRESHOLD: f32 = 10.5;
7
/// Angular limit for a joint about one axis, enforced at the end of `solve`.
///
/// `solve` measures the joint's rotation relative to its bind-pose local rotation. It takes
/// the angle `swing_twist_decompose` reports about `axis` (radians) and clamps it into
/// `[min, max]`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct JointLimit {
    /// Limit axis; it is rotated by the joint's change from bind pose before use.
    /// NOTE(review): presumably expected to be unit length — confirm.
    pub axis: Vec3,
    /// Lower bound (radians) on the angle about `axis`.
    pub min: f32,
    /// Upper bound (radians) on the angle about `axis`.
    pub max: f32,
}
14
/// Per-joint solver configuration: which joint the solver may rotate and how it is constrained.
#[derive(Debug, PartialEq, Clone)]
pub struct IKJointControl {
    /// Index of the joint in the skeleton slices passed to `solve`.
    pub joint_id: usize,

    /// If set, the joint only rotates about this axis, given in the joint's own frame.
    /// It is rotated by the joint's current rotation before use. Otherwise the solver
    /// picks an axis per goal.
    pub restrict_rotation_axis: Option<Vec3>,

    /// Optional angular limits applied after each solver step.
    pub limits: Option<Vec<JointLimit>>,
}
25
26impl IKJointControl {
27 pub fn new(joint_id: usize) -> Self {
28 IKJointControl {
29 joint_id,
30 restrict_rotation_axis: None,
31 limits: None,
32 }
33 }
34
35 pub fn with_axis_constraint(self, restrict_rotation_axis: Vec3) -> Self {
36 IKJointControl {
37 restrict_rotation_axis: Some(restrict_rotation_axis),
38 ..self
39 }
40 }
41
42 pub fn with_limits(self, limits: &[JointLimit]) -> Self {
43 IKJointControl {
44 limits: Some(limits.to_vec()),
45 ..self
46 }
47 }
48}
49
/// What an [`IKGoal`] asks of its end effector.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum IKGoalKind {
    /// Target position of the end effector, in the same space as the skeleton transforms.
    Position(Vec3),

    /// Target orientation of the end effector, in the same space as the skeleton transforms.
    Rotation(Mat3),

    /// Target yaw: the Y angle of a YXZ Euler decomposition, in radians.
    RotY(f32),
}
61
/// A single solver target: the joint acting as end effector plus what it should reach.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct IKGoal {
    /// Index of the end-effector joint in the skeleton slices passed to `solve`.
    pub end_effector_id: usize,
    /// Goal type and target value.
    pub kind: IKGoalKind,
}
67
68fn is_parallel(a: Vec3, b: Vec3) -> bool {
69 fuzzy_compare_f32(a.normalize().dot(b.normalize()).abs(), 1.0)
70}
71
72fn get_rotation_axis(to_e: Vec3, target_direction: Vec3) -> Vec3 {
73 let raw_axis = if !is_parallel(to_e, target_direction) {
74 target_direction.cross(to_e)
75 } else if !is_parallel(Vec3::Y, target_direction) {
76 target_direction.cross(Vec3::Y)
77 } else {
78 target_direction.cross(Vec3::Z)
79 };
80
81 if raw_axis == Vec3::ZERO {
84 Vec3::ZERO
85 } else {
86 raw_axis.normalize()
87 }
88}
89
90fn get_num_effector_components(goals: &[IKGoal]) -> usize {
91 goals
92 .iter()
93 .map(|g| match g {
94 IKGoal {
95 kind: IKGoalKind::Position(_),
96 ..
97 } => 3,
98 IKGoal {
99 kind: IKGoalKind::Rotation(_),
100 ..
101 } => 3,
102 IKGoal {
103 kind: IKGoalKind::RotY(_),
104 ..
105 } => 1,
106 })
107 .sum()
108}
109
110struct IterableVec3(Vec3);
111
112impl IntoIterator for IterableVec3 {
113 type Item = f32;
114 type IntoIter = std::array::IntoIter<f32, 3>;
115
116 fn into_iter(self) -> Self::IntoIter {
117 [self.0.x, self.0.y, self.0.z].into_iter()
118 }
119}
120
/// How a degree of freedom rotates its joint.
enum DoFKind {
    /// Rotation about a world-space `axis`, applied in `solve` as an axis-angle quaternion
    /// that pre-multiplies the joint's world rotation.
    Quaternion { axis: Vec3 },
}

/// One Jacobian column: a rotation axis on a joint and its effect on every goal component.
struct DegreeOfFreedom {
    /// Joint that this DoF rotates.
    joint_id: usize,
    /// Rotation parameterization (currently always a world-space axis).
    kind: DoFKind,
    /// Column entries, one per stacked goal component (3 for Position/Rotation, 1 for RotY);
    /// zero in the rows of every goal other than the one this DoF was built for.
    influences: Vec<f32>, }
131
/// Signed shortest angular difference `b - a`, in radians, wrapped into `[-π, π)`.
///
/// This uses `rem_euclid` rather than `%`. Rust's float `%` keeps the sign of the dividend,
/// so `(delta + π) % 2π` stays negative when `delta < -π` and the result falls below `-π`.
/// For example, going from 170° to -170° would yield -340° instead of +20°.
fn ang_diff(a: f32, b: f32) -> f32 {
    let delta = b - a;
    (delta + PI).rem_euclid(2.0 * PI) - PI
}
136
/// Selects the quaternion (world-axis) DoF representation in `build_dof_data`. It is the only
/// representation implemented there, so `false` would produce no degrees of freedom.
const USE_QUATERNIONS: bool = true;
138
139fn build_dof_data(
140 full_skeleton: &[Mat4],
141 affected_joints: &[IKJointControl],
142 goals: &[IKGoal],
143 use_quaternions: bool,
144) -> Vec<DegreeOfFreedom> {
145 let mut dof_data: Vec<DegreeOfFreedom> = Vec::new();
146
147 for joint in affected_joints {
148 let origin_of_rotation = full_skeleton[joint.joint_id].transform_point3(Vec3::ZERO);
150
151 if use_quaternions {
153 for (goal_idx, goal) in goals.iter().enumerate() {
154 let mut influences: Vec<f32> = Vec::new();
155
156 let (temp_influences, axis) = match goal.kind {
157 IKGoalKind::Position(goal_position) => {
158 let end_effector_pos =
159 full_skeleton[goal.end_effector_id].transform_point3(Vec3::ZERO);
160 let target_direction = goal_position - end_effector_pos;
161
162 let to_e = end_effector_pos - origin_of_rotation;
163 let axis_of_rotation = if let Some(axis) = joint.restrict_rotation_axis {
164 let joint_rot = Quat::from_mat4(&full_skeleton[joint.joint_id]);
165 joint_rot * axis
166 } else {
167 get_rotation_axis(to_e, target_direction)
168 };
169
170 let influence = axis_of_rotation.cross(to_e);
171
172 (
173 vec![influence.x, influence.y, influence.z],
174 axis_of_rotation,
175 )
176 }
177 IKGoalKind::Rotation(goal_rotation) => {
178 let end_effector_rot =
179 Quat::from_mat4(&full_skeleton[goal.end_effector_id]);
180
181 let rotation = Quat::from_mat3(&goal_rotation) * end_effector_rot.inverse();
182
183 let axis_of_rotation = if let Some(axis) = joint.restrict_rotation_axis {
184 let joint_rot = Quat::from_mat4(&full_skeleton[joint.joint_id]);
185 joint_rot * axis
186 } else {
187 let (axis_of_rotation, angle) = rotation.to_axis_angle_180();
188
189 if angle < 0.0 {
190 -axis_of_rotation
191 } else {
192 axis_of_rotation
193 }
194 };
195
196 let influence = axis_of_rotation;
197 (
198 vec![influence.x, influence.y, influence.z],
199 axis_of_rotation,
200 )
201 }
202 IKGoalKind::RotY(_) => {
203 let axis_of_rotation = if let Some(axis) = joint.restrict_rotation_axis {
204 let joint_rot = Quat::from_mat4(&full_skeleton[joint.joint_id]);
205 joint_rot * axis
206 } else {
207 Vec3::Y
208 };
209
210 let influence = Quat::from_axis_angle(axis_of_rotation, 1.0)
211 .to_euler(EulerRot::YXZ)
212 .0;
213
214 (vec![influence], axis_of_rotation)
215 }
216 };
217
218 for g2_idx in 0..goals.len() {
219 if g2_idx == goal_idx {
220 influences.extend(&temp_influences);
221 } else {
222 match goals[g2_idx].kind {
223 IKGoalKind::Position(_) => {
224 influences.push(0.0);
233 influences.push(0.0);
234 influences.push(0.0);
235 }
236 IKGoalKind::Rotation(_) => {
237 influences.push(0.0);
241 influences.push(0.0);
242 influences.push(0.0);
243 }
244 IKGoalKind::RotY(_) => {
245 influences.push(0.0); }
247 }
248 }
249 }
250
251 dof_data.push(DegreeOfFreedom {
253 joint_id: joint.joint_id,
254 kind: DoFKind::Quaternion { axis },
255 influences,
256 });
257 }
258 }
259 }
260
261 dof_data
262}
263
264fn pseudo_inverse_damped_least_squares(
265 jacobian: &DMatrix<f32>,
266 num_effectors: usize,
267) -> DMatrix<f32> {
268 let jac_transp = jacobian.transpose();
269
270 let damping_ident_matrix = nalgebra::DMatrix::<f32>::identity(num_effectors, num_effectors);
271
272 let jacobian_square = jacobian * &jac_transp;
273
274 let jacobian_square = jacobian_square + DAMPING * DAMPING * damping_ident_matrix;
275
276 let jac_inv = jac_transp * jacobian_square.try_inverse().unwrap(); jac_inv
279}
280
/// Runs one damped-least-squares IK iteration and updates `full_skeleton` in place.
///
/// The iteration proceeds in four steps:
/// 1. Builds a Jacobian from `build_dof_data`, with one column per affected joint per goal.
/// 2. Solves for joint-angle increments and scales them uniformly so that none exceeds
///    `THRESHOLD` degrees.
/// 3. Applies the increments.
/// 4. Rebuilds the transforms of the affected joints from their parents, applying joint
///    limits where configured.
///
/// # Parameters
/// - `full_skeleton`: per-joint transforms, composed as `parent * local`. Read and overwritten.
/// - `local_bind_pose`: per-joint transform relative to the parent. Its translation is always
///   read; its rotation is read only for joints with limits.
/// - `parents`: parent index per joint. A negative index means "no parent" (identity).
/// - `affected_joints`: joints the solver may rotate.
/// - `goals`: end-effector goals.
///
/// # Panics
/// Panics if any joint or effector index is out of range for the slices. It also panics if
/// `pseudo_inverse_damped_least_squares` fails to invert its damped matrix.
///
/// NOTE(review): only joints listed in `affected_joints` are rewritten. Descendants that are
/// not listed keep their old transforms, including an end effector that is not itself
/// affected. The rebuild pass reads `full_skeleton[parent]`, so it appears to assume
/// `affected_joints` lists parents before children. Confirm both points against callers.
pub fn solve(
    full_skeleton: &mut [Mat4],
    local_bind_pose: &[Mat4],
    parents: &[i32],
    affected_joints: &[IKJointControl],
    goals: &[IKGoal],
) {
    let dof_data = build_dof_data(full_skeleton, affected_joints, goals, USE_QUATERNIONS);

    let num_goal_components = get_num_effector_components(goals);
    let num_dof_components = dof_data.len();

    // Jacobian: rows = stacked goal components, columns = degrees of freedom.
    let jacobian = DMatrix::<f32>::from_fn(num_goal_components, num_dof_components, |i, j| {
        dof_data[j].influences[i]
    });

    let jac_inv = pseudo_inverse_damped_least_squares(&jacobian, num_goal_components);

    // Error vector, stacked in the same goal order and row layout as the Jacobian.
    let effectors: Vec<_> = goals
        .iter()
        .map(|goal| match goal.kind {
            IKGoalKind::Position(goal_position) => {
                let end_effector_pos =
                    full_skeleton[goal.end_effector_id].transform_point3(Vec3::ZERO);
                IterableVec3(goal_position - end_effector_pos)
                    .into_iter()
                    .collect::<Vec<_>>()
            }
            IKGoalKind::Rotation(goal_rotation) => {
                let end_effector_rot = Quat::from_mat4(&full_skeleton[goal.end_effector_id]);

                // Rotation taking the current orientation to the goal orientation.
                let r = Quat::from_mat3(&goal_rotation) * end_effector_rot.inverse();

                let (axis_of_rotation, angle) = r.to_axis_angle_180();

                // Rotation-vector (scaled-axis) error.
                let scaled_axis = axis_of_rotation * angle;

                IterableVec3(scaled_axis).into_iter().collect::<Vec<_>>()
            }
            IKGoalKind::RotY(goal_rot_y) => {
                let end_effector_rot = Quat::from_mat4(&full_skeleton[goal.end_effector_id])
                    .to_euler(EulerRot::YXZ)
                    .0;
                let delta = ang_diff(end_effector_rot, goal_rot_y);
                vec![delta]
            }
        })
        .flatten()
        .collect();

    let effector_vec = Matrix::<f32, Dyn, U1, VecStorage<f32, Dyn, U1>>::from_vec(effectors);

    let theta = &jac_inv * &effector_vec;
    // Scale the whole step uniformly so the largest |increment| is at most THRESHOLD degrees.
    let threshold = THRESHOLD.to_radians();
    let max_angle = theta.amax();
    let beta = threshold / f32::max(max_angle, threshold);

    // Transforms before this step; local rotations below are taken relative to the
    // parent's *old* transform.
    let previous_skeleton = full_skeleton.to_vec();

    // Working copy to which each DoF's rotation increment is applied. A joint with several
    // DoFs (one per goal) accumulates all of its increments here.
    let mut raw_joint_xforms = full_skeleton.to_vec();

    for (theta_idx, dof) in dof_data.iter().enumerate() {
        let joint_xform = raw_joint_xforms[dof.joint_id];
        let translation = joint_xform.transform_point3(Vec3::ZERO);
        let rotation = Quat::from_mat4(&joint_xform);

        // DoFKind currently has a single variant, so this pattern always matches.
        #[allow(irrefutable_let_patterns)]
        if let DoFKind::Quaternion { axis } = dof.kind {
            // Pre-multiply: rotate about the world-space axis by the scaled increment.
            let world_rot = Quat::from_axis_angle(axis, beta * theta[theta_idx]);

            let end_rot = world_rot * rotation;

            raw_joint_xforms[dof.joint_id] = Mat4::from_rotation_translation(end_rot, translation);
        }
    }

    // Rebuild each affected joint from three pieces:
    // - its parent's current transform (already rewritten if the parent was processed earlier),
    // - the local rotation implied by the raw result,
    // - the bind-pose local translation (the current local translation is discarded).
    // A joint with several DoFs is visited once per DoF.
    for dof in &dof_data {
        // A negative parent index casts to an out-of-range usize, so `.get` falls back to
        // the identity transform for root joints.
        let parent_id = parents[dof.joint_id];
        let parent_xform = *full_skeleton
            .get(parent_id as usize)
            .unwrap_or(&Mat4::IDENTITY);
        let parent_xform_old = *previous_skeleton
            .get(parent_id as usize)
            .unwrap_or(&Mat4::IDENTITY);

        let local_rot =
            Quat::from_mat4(&(parent_xform_old.inverse() * raw_joint_xforms[dof.joint_id]))
                .normalize();

        let local_translation = local_bind_pose[dof.joint_id].transform_point3(Vec3::ZERO);
        let mut world_xform =
            parent_xform * Mat4::from_rotation_translation(local_rot, local_translation);

        // Every DoF was built from an entry of `affected_joints`, so this lookup succeeds.
        let joint_cfg = affected_joints
            .iter()
            .find(|joint| joint.joint_id == dof.joint_id)
            .unwrap();

        if let Some(limits) = &joint_cfg.limits {
            // Joint's local rotation relative to its bind-pose local rotation.
            let local_xform = Quat::from_mat4(&(parent_xform.inverse() * world_xform));
            let local_bind_xform = Quat::from_mat4(&(local_bind_pose[dof.joint_id]));

            let local_change = local_bind_xform.inverse() * local_xform;
            // Every limit is evaluated against the same, unclamped `local_change`.
            // Corrections are post-multiplied, i.e. applied in the joint's own frame.
            for limit in limits.iter() {
                let custom_axis = local_change * limit.axis;
                let angle = swing_twist_decompose(local_change, custom_axis);

                if angle < limit.min {
                    world_xform = world_xform
                        * Mat4::from_quat(Quat::from_axis_angle(custom_axis, limit.min - angle));
                } else if angle > limit.max {
                    world_xform = world_xform
                        * Mat4::from_quat(Quat::from_axis_angle(custom_axis, limit.max - angle));
                }
            }
        }

        full_skeleton[dof.joint_id] = world_xform;
    }
}
407
408fn swing_twist_decompose(q: Quat, dir: Vec3) -> f32 {
411 let rotation_axis = vec3(q.x, q.y, q.z);
412 let dot_prod = dir.dot(rotation_axis);
413 let p = dir * dot_prod;
414 let mut twist = Quat::from_xyzw(p.x, p.y, p.z, q.w).normalize();
415
416 if dot_prod < 0.0 {
417 twist = -twist;
418 }
419
420 twist.to_axis_angle_180().1
421}
422
/// Axis-angle conversion with the angle wrapped into a signed half-turn range.
trait ToAxisAngle180 {
    /// Returns `(axis, angle)`, with `angle` in radians wrapped into `[-π, π)`.
    fn to_axis_angle_180(self) -> (Vec3, f32);
}
426
427impl ToAxisAngle180 for Quat {
429 fn to_axis_angle_180(self) -> (Vec3, f32) {
430 let (axis, angle) = self.to_axis_angle();
431 let angle = (angle + PI) % (2.0 * PI) - PI;
432 (axis, angle)
433 }
434}
435
436#[allow(unused)] fn fuzzy_compare_vec3(a: Vec3, b: Vec3) -> bool {
438 let epsilon = 0.01;
439 (a.x - b.x).abs() < epsilon && (a.y - b.y).abs() < epsilon && (a.z - b.z).abs() < epsilon
440}
441
/// Approximate equality for scalars: `true` when |a − b| is strictly below 1e-4.
/// Any NaN operand yields `false`.
fn fuzzy_compare_f32(a: f32, b: f32) -> bool {
    const TOLERANCE: f32 = 0.0001;
    let gap = (a - b).abs();
    gap < TOLERANCE
}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// A straight three-joint chain along +Z, with every joint affected and a single
    /// orientation goal on the last joint, should converge to that orientation.
    #[test]
    fn test_solve_orient() {
        // Rest pose, composed from the root (same space `solve` writes into `full_skeleton`).
        let mut skeleton = [
            Mat4::from_translation(Vec3::new(0.0, 0.0, 0.0)),
            Mat4::from_translation(Vec3::new(0.0, 0.0, 1.0)),
            Mat4::from_translation(Vec3::new(0.0, 0.0, 2.0)),
        ];
        let parents: [i32; 3] = [-1, 0, 1];
        // `solve` reads the bind pose as transforms *relative to each parent*. Derive those
        // here rather than passing the composed transforms directly.
        let bind_pose: Vec<Mat4> = skeleton
            .iter()
            .zip(parents.iter())
            .map(|(world, &parent)| {
                if parent < 0 {
                    *world
                } else {
                    skeleton[parent as usize].inverse() * *world
                }
            })
            .collect();
        let affected_joints = [
            IKJointControl::new(0),
            IKJointControl::new(1),
            IKJointControl::new(2),
        ];

        let expected_rot_mat = Mat3::from_axis_angle(Vec3::Y, 45f32.to_radians());

        let goals = [IKGoal {
            end_effector_id: 2,
            kind: IKGoalKind::Rotation(expected_rot_mat),
        }];

        for _ in 0..200 {
            solve(
                &mut skeleton,
                &bind_pose,
                &parents,
                &affected_joints,
                &goals,
            );
        }

        let expected_rot = Quat::from_mat3(&expected_rot_mat);
        let actual_rot = Quat::from_mat4(&skeleton[2]);

        assert!(
            fuzzy_compare_vec3(expected_rot.to_scaled_axis(), actual_rot.to_scaled_axis()),
            "Expected: {:?}, Actual: {:?}",
            expected_rot.to_axis_angle(),
            actual_rot.to_axis_angle()
        );
    }
}