// lux_ik/ik.rs

use glam::{vec3, EulerRot, Mat3, Mat4, Quat, Vec3};
use nalgebra::{DMatrix, Dyn, Matrix, VecStorage, U1};
use std::f32::consts::{PI, TAU};
/// Damping factor (lambda) of the damped-least-squares pseudo-inverse. It is squared and added
/// to the diagonal of `J * J^T`, so larger values produce smaller, more stable steps.
const DAMPING: f32 = 10.0;

/// Largest rotation, in degrees, that any single degree of freedom may take in one `solve`
/// step. When the raw step is larger, the whole step is scaled down uniformly.
const THRESHOLD: f32 = 10.5;
8#[derive(Debug, PartialEq, Clone, Copy)]
9pub struct JointLimit {
10    pub axis: Vec3,
11    pub min: f32,
12    pub max: f32,
13}
14
15#[derive(Debug, PartialEq, Clone)]
16pub struct IKJointControl {
17    pub joint_id: usize,
18
19    // if set, limits the rotation axis to this (in joint space)
20    pub restrict_rotation_axis: Option<Vec3>,
21
22    pub limits: Option<Vec<JointLimit>>,
23    // TODO: Weights
24}
25
26impl IKJointControl {
27    pub fn new(joint_id: usize) -> Self {
28        IKJointControl {
29            joint_id,
30            restrict_rotation_axis: None,
31            limits: None,
32        }
33    }
34
35    pub fn with_axis_constraint(self, restrict_rotation_axis: Vec3) -> Self {
36        IKJointControl {
37            restrict_rotation_axis: Some(restrict_rotation_axis),
38            ..self
39        }
40    }
41
42    pub fn with_limits(self, limits: &[JointLimit]) -> Self {
43        IKJointControl {
44            limits: Some(limits.to_vec()),
45            ..self
46        }
47    }
48}
49
50#[derive(Debug, PartialEq, Clone, Copy)]
51pub enum IKGoalKind {
52    Position(Vec3),
53
54    // TODO: Should probably be a Quat
55    Rotation(Mat3),
56
57    // Orientation goal that only cares about orientation around the world Y axis but leaves the
58    // other axes free.
59    RotY(f32),
60}
61
62#[derive(Debug, PartialEq, Clone, Copy)]
63pub struct IKGoal {
64    pub end_effector_id: usize,
65    pub kind: IKGoalKind,
66}
67
68fn is_parallel(a: Vec3, b: Vec3) -> bool {
69    fuzzy_compare_f32(a.normalize().dot(b.normalize()).abs(), 1.0)
70}
71
72fn get_rotation_axis(to_e: Vec3, target_direction: Vec3) -> Vec3 {
73    let raw_axis = if !is_parallel(to_e, target_direction) {
74        target_direction.cross(to_e)
75    } else if !is_parallel(Vec3::Y, target_direction) {
76        target_direction.cross(Vec3::Y)
77    } else {
78        target_direction.cross(Vec3::Z)
79    };
80
81    // if we are very close to the end effector, then there's no more useful rotation axis.
82    // Returning ZERO will cause the influence to be 0. This also prevents normalize from being NaN.
83    if raw_axis == Vec3::ZERO {
84        Vec3::ZERO
85    } else {
86        raw_axis.normalize()
87    }
88}
89
90fn get_num_effector_components(goals: &[IKGoal]) -> usize {
91    goals
92        .iter()
93        .map(|g| match g {
94            IKGoal {
95                kind: IKGoalKind::Position(_),
96                ..
97            } => 3,
98            IKGoal {
99                kind: IKGoalKind::Rotation(_),
100                ..
101            } => 3,
102            IKGoal {
103                kind: IKGoalKind::RotY(_),
104                ..
105            } => 1,
106        })
107        .sum()
108}
109
110struct IterableVec3(Vec3);
111
112impl IntoIterator for IterableVec3 {
113    type Item = f32;
114    type IntoIter = std::array::IntoIter<f32, 3>;
115
116    fn into_iter(self) -> Self::IntoIter {
117        [self.0.x, self.0.y, self.0.z].into_iter()
118    }
119}
120
121// the different methods of joint movement
122enum DoFKind {
123    Quaternion { axis: Vec3 },
124}
125
126struct DegreeOfFreedom {
127    joint_id: usize,
128    kind: DoFKind,
129    influences: Vec<f32>, // one per end effector
130}
131
/// Signed shortest angular difference from `a` to `b`, in radians, wrapped into `[-PI, PI)`.
///
/// `rem_euclid` is required here. `%` keeps the sign of the dividend, so a difference below
/// `-PI` would be returned unwrapped. For example, `a = 3.0`, `b = -3.0` straddle the ±PI seam,
/// and `%` gives -6.0 instead of ~0.283.
fn ang_diff(a: f32, b: f32) -> f32 {
    let delta = b - a;
    (delta + PI).rem_euclid(TAU) - PI
}
/// Selects the quaternion (free-axis) joint model in `build_dof_data`. It is currently the only
/// model implemented, so turning it off yields no degrees of freedom.
const USE_QUATERNIONS: bool = true;
139fn build_dof_data(
140    full_skeleton: &[Mat4],
141    affected_joints: &[IKJointControl],
142    goals: &[IKGoal],
143    use_quaternions: bool,
144) -> Vec<DegreeOfFreedom> {
145    let mut dof_data: Vec<DegreeOfFreedom> = Vec::new();
146
147    for joint in affected_joints {
148        // TODO: Use w_axis.truncate(); instead of transform_point3
149        let origin_of_rotation = full_skeleton[joint.joint_id].transform_point3(Vec3::ZERO);
150
151        // Quaternion joints are a special case because they allow infinite number of rotation axes.
152        if use_quaternions {
153            for (goal_idx, goal) in goals.iter().enumerate() {
154                let mut influences: Vec<f32> = Vec::new();
155
156                let (temp_influences, axis) = match goal.kind {
157                    IKGoalKind::Position(goal_position) => {
158                        let end_effector_pos =
159                            full_skeleton[goal.end_effector_id].transform_point3(Vec3::ZERO);
160                        let target_direction = goal_position - end_effector_pos;
161
162                        let to_e = end_effector_pos - origin_of_rotation;
163                        let axis_of_rotation = if let Some(axis) = joint.restrict_rotation_axis {
164                            let joint_rot = Quat::from_mat4(&full_skeleton[joint.joint_id]);
165                            joint_rot * axis
166                        } else {
167                            get_rotation_axis(to_e, target_direction)
168                        };
169
170                        let influence = axis_of_rotation.cross(to_e);
171
172                        (
173                            vec![influence.x, influence.y, influence.z],
174                            axis_of_rotation,
175                        )
176                    }
177                    IKGoalKind::Rotation(goal_rotation) => {
178                        let end_effector_rot =
179                            Quat::from_mat4(&full_skeleton[goal.end_effector_id]);
180
181                        let rotation = Quat::from_mat3(&goal_rotation) * end_effector_rot.inverse();
182
183                        let axis_of_rotation = if let Some(axis) = joint.restrict_rotation_axis {
184                            let joint_rot = Quat::from_mat4(&full_skeleton[joint.joint_id]);
185                            joint_rot * axis
186                        } else {
187                            let (axis_of_rotation, angle) = rotation.to_axis_angle_180();
188
189                            if angle < 0.0 {
190                                -axis_of_rotation
191                            } else {
192                                axis_of_rotation
193                            }
194                        };
195
196                        let influence = axis_of_rotation;
197                        (
198                            vec![influence.x, influence.y, influence.z],
199                            axis_of_rotation,
200                        )
201                    }
202                    IKGoalKind::RotY(_) => {
203                        let axis_of_rotation = if let Some(axis) = joint.restrict_rotation_axis {
204                            let joint_rot = Quat::from_mat4(&full_skeleton[joint.joint_id]);
205                            joint_rot * axis
206                        } else {
207                            Vec3::Y
208                        };
209
210                        let influence = Quat::from_axis_angle(axis_of_rotation, 1.0)
211                            .to_euler(EulerRot::YXZ)
212                            .0;
213
214                        (vec![influence], axis_of_rotation)
215                    }
216                };
217
218                for g2_idx in 0..goals.len() {
219                    if g2_idx == goal_idx {
220                        influences.extend(&temp_influences);
221                    } else {
222                        match goals[g2_idx].kind {
223                            IKGoalKind::Position(_) => {
224                                // TODO: Calculate these values in order to make the IK converge faster
225                                // let end_effector_pos = full_skeleton[goals[g2_idx].end_effector_id]
226                                //     .transform_point3(Vec3::ZERO);
227                                // let to_e = end_effector_pos - origin_of_rotation;
228                                // let influence = axis.cross(to_e);
229                                // influences.push(influence.x);
230                                // influences.push(influence.y);
231                                // influences.push(influence.z);
232                                influences.push(0.0);
233                                influences.push(0.0);
234                                influences.push(0.0);
235                            }
236                            IKGoalKind::Rotation(_) => {
237                                // influences.push(axis.x);
238                                // influences.push(axis.y);
239                                // influences.push(axis.z);
240                                influences.push(0.0);
241                                influences.push(0.0);
242                                influences.push(0.0);
243                            }
244                            IKGoalKind::RotY(_) => {
245                                influences.push(0.0); // TODO
246                            }
247                        }
248                    }
249                }
250
251                // Quaternion joints require us to create new DegreesOfFreedom on demand for each new rotation axis.
252                dof_data.push(DegreeOfFreedom {
253                    joint_id: joint.joint_id,
254                    kind: DoFKind::Quaternion { axis },
255                    influences,
256                });
257            }
258        }
259    }
260
261    dof_data
262}
263
264fn pseudo_inverse_damped_least_squares(
265    jacobian: &DMatrix<f32>,
266    num_effectors: usize,
267) -> DMatrix<f32> {
268    let jac_transp = jacobian.transpose();
269
270    let damping_ident_matrix = nalgebra::DMatrix::<f32>::identity(num_effectors, num_effectors);
271
272    let jacobian_square = jacobian * &jac_transp;
273
274    let jacobian_square = jacobian_square + DAMPING * DAMPING * damping_ident_matrix;
275
276    let jac_inv = jac_transp * jacobian_square.try_inverse().unwrap(); // TODO: Handle error
277
278    jac_inv
279}
280
281// Important: The joint chain must be in topological order
282pub fn solve(
283    full_skeleton: &mut [Mat4],
284    local_bind_pose: &[Mat4],
285    parents: &[i32],
286    affected_joints: &[IKJointControl],
287    goals: &[IKGoal],
288) {
289    let dof_data = build_dof_data(full_skeleton, affected_joints, goals, USE_QUATERNIONS);
290
291    let num_goal_components = get_num_effector_components(goals);
292    let num_dof_components = dof_data.len();
293
294    let jacobian = DMatrix::<f32>::from_fn(num_goal_components, num_dof_components, |i, j| {
295        dof_data[j].influences[i]
296    });
297
298    let jac_inv = pseudo_inverse_damped_least_squares(&jacobian, num_goal_components);
299
300    let effectors: Vec<_> = goals
301        .iter()
302        .map(|goal| match goal.kind {
303            IKGoalKind::Position(goal_position) => {
304                let end_effector_pos =
305                    full_skeleton[goal.end_effector_id].transform_point3(Vec3::ZERO);
306                IterableVec3(goal_position - end_effector_pos)
307                    .into_iter()
308                    .collect::<Vec<_>>()
309            }
310            IKGoalKind::Rotation(goal_rotation) => {
311                let end_effector_rot = Quat::from_mat4(&full_skeleton[goal.end_effector_id]);
312
313                let r = Quat::from_mat3(&goal_rotation) * end_effector_rot.inverse();
314
315                let (axis_of_rotation, angle) = r.to_axis_angle_180();
316
317                let scaled_axis = axis_of_rotation * angle;
318
319                IterableVec3(scaled_axis).into_iter().collect::<Vec<_>>()
320            }
321            IKGoalKind::RotY(goal_rot_y) => {
322                let end_effector_rot = Quat::from_mat4(&full_skeleton[goal.end_effector_id])
323                    .to_euler(EulerRot::YXZ)
324                    .0;
325                let delta = ang_diff(end_effector_rot, goal_rot_y);
326                vec![delta]
327            }
328        })
329        .flatten()
330        .collect();
331
332    let effector_vec = Matrix::<f32, Dyn, U1, VecStorage<f32, Dyn, U1>>::from_vec(effectors);
333
334    // Theta is the resulting angles that we want to rotate by
335    let theta = &jac_inv * &effector_vec;
336    let threshold = THRESHOLD.to_radians();
337    let max_angle = theta.amax();
338    let beta = threshold / f32::max(max_angle, threshold);
339
340    // Need to remember the original joint transforms
341    let previous_skeleton = full_skeleton.to_vec();
342
343    // Our rotation axis is in world space, but during the rotation our position needs to stay fixed.
344    let mut raw_joint_xforms = full_skeleton.to_vec();
345
346    for (theta_idx, dof) in dof_data.iter().enumerate() {
347        let joint_xform = raw_joint_xforms[dof.joint_id];
348        let translation = joint_xform.transform_point3(Vec3::ZERO);
349        let rotation = Quat::from_mat4(&joint_xform);
350
351        #[allow(irrefutable_let_patterns)]
352        if let DoFKind::Quaternion { axis } = dof.kind {
353            let world_rot = Quat::from_axis_angle(axis, beta * theta[theta_idx]);
354
355            let end_rot = world_rot * rotation;
356
357            raw_joint_xforms[dof.joint_id] = Mat4::from_rotation_translation(end_rot, translation);
358        }
359    }
360
361    // Correct the rotation of the other joints
362    for dof in &dof_data {
363        let parent_id = parents[dof.joint_id];
364        let parent_xform = *full_skeleton
365            .get(parent_id as usize)
366            .unwrap_or(&Mat4::IDENTITY);
367        let parent_xform_old = *previous_skeleton
368            .get(parent_id as usize)
369            .unwrap_or(&Mat4::IDENTITY);
370
371        let local_rot =
372            Quat::from_mat4(&(parent_xform_old.inverse() * raw_joint_xforms[dof.joint_id]))
373                .normalize();
374
375        let local_translation = local_bind_pose[dof.joint_id].transform_point3(Vec3::ZERO);
376        let mut world_xform =
377            parent_xform * Mat4::from_rotation_translation(local_rot, local_translation);
378
379        let joint_cfg = affected_joints
380            .iter()
381            .find(|joint| joint.joint_id == dof.joint_id)
382            .unwrap();
383
384        if let Some(limits) = &joint_cfg.limits {
385            let local_xform = Quat::from_mat4(&(parent_xform.inverse() * world_xform));
386            let local_bind_xform = Quat::from_mat4(&(local_bind_pose[dof.joint_id]));
387
388            // limits are relative to the local bind pose of the joints
389            let local_change = local_bind_xform.inverse() * local_xform;
390            for limit in limits.iter() {
391                let custom_axis = local_change * limit.axis;
392                let angle = swing_twist_decompose(local_change, custom_axis);
393
394                if angle < limit.min {
395                    world_xform = world_xform
396                        * Mat4::from_quat(Quat::from_axis_angle(custom_axis, limit.min - angle));
397                } else if angle > limit.max {
398                    world_xform = world_xform
399                        * Mat4::from_quat(Quat::from_axis_angle(custom_axis, limit.max - angle));
400                }
401            }
402        }
403
404        full_skeleton[dof.joint_id] = world_xform;
405    }
406}
407
408// Retrieve the angle of rotation around the given axis
409// https://stackoverflow.com/questions/3684269/component-of-a-quaternion-rotation-around-an-axis
410fn swing_twist_decompose(q: Quat, dir: Vec3) -> f32 {
411    let rotation_axis = vec3(q.x, q.y, q.z);
412    let dot_prod = dir.dot(rotation_axis);
413    let p = dir * dot_prod;
414    let mut twist = Quat::from_xyzw(p.x, p.y, p.z, q.w).normalize();
415
416    if dot_prod < 0.0 {
417        twist = -twist;
418    }
419
420    twist.to_axis_angle_180().1
421}
422
423trait ToAxisAngle180 {
424    fn to_axis_angle_180(self) -> (Vec3, f32);
425}
426
427// Ensures that the angle is between -PI and PI
428impl ToAxisAngle180 for Quat {
429    fn to_axis_angle_180(self) -> (Vec3, f32) {
430        let (axis, angle) = self.to_axis_angle();
431        let angle = (angle + PI) % (2.0 * PI) - PI;
432        (axis, angle)
433    }
434}
435
436#[allow(unused)] // TODO
437fn fuzzy_compare_vec3(a: Vec3, b: Vec3) -> bool {
438    let epsilon = 0.01;
439    (a.x - b.x).abs() < epsilon && (a.y - b.y).abs() < epsilon && (a.z - b.z).abs() < epsilon
440}
441
/// Approximate equality of two floats, with an absolute tolerance of `0.0001`.
fn fuzzy_compare_f32(a: f32, b: f32) -> bool {
    const EPSILON: f32 = 0.0001;
    let difference = (a - b).abs();
    difference < EPSILON
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three joints on the Z axis, all controlled by the solver. The end effector must end up
    /// rotated 45 degrees around Y.
    #[test]
    fn test_solve_orient() {
        let mut skeleton = [0.0, 1.0, 2.0].map(|z| Mat4::from_translation(Vec3::new(0.0, 0.0, z)));
        let bind_pose = skeleton.to_vec();
        let parents = [-1, 0, 1];
        let affected_joints: Vec<IKJointControl> = (0..3).map(IKJointControl::new).collect();

        let target = Mat3::from_axis_angle(Vec3::Y, 45f32.to_radians());
        let goals = [IKGoal {
            end_effector_id: 2,
            kind: IKGoalKind::Rotation(target),
        }];

        for _iteration in 0..200 {
            solve(&mut skeleton, &bind_pose, &parents, &affected_joints, &goals);
        }

        let expected = Quat::from_mat3(&target);
        let actual = Quat::from_mat4(&skeleton[2]);

        assert!(
            fuzzy_compare_vec3(expected.to_scaled_axis(), actual.to_scaled_axis()),
            "Expected: {:?}, Actual: {:?}",
            expected.to_axis_angle(),
            actual.to_axis_angle()
        );
    }
}