Skip to main content

proof_engine/animation/
ik.rs

1//! Inverse Kinematics solvers.
2//!
3//! Provides three complementary IK algorithms:
4//!
5//! - **FABRIK** (Forward And Backward Reaching IK) -- iterative, handles
6//!   multi-bone chains with constraints, converges fast in ~10 iterations.
7//! - **CCD** (Cyclic Coordinate Descent) -- rotational approach, each
8//!   joint rotates to minimize end-effector error.
9//! - **Analytical 2-bone** -- closed-form solution for shoulder/elbow/wrist
10//!   style rigs, exact, instantaneous, with elbow pole vector control.
11//!
12//! ## Quick Start
13//! ```rust,no_run
14//! use proof_engine::animation::ik::{IkChain, FabrikSolver};
15//! use glam::Vec3;
16//!
17//! let mut chain = IkChain::new(vec![
18//!     Vec3::new(0.0, 0.0, 0.0),
19//!     Vec3::new(0.0, 1.0, 0.0),
20//!     Vec3::new(0.0, 2.0, 0.0),
21//! ]);
22//! FabrikSolver::solve(&mut chain, Vec3::new(1.5, 1.5, 0.0), 10);
23//! ```
24
25use glam::{Vec2, Vec3, Quat};
26use std::f32::consts::PI;
27
28// ── IkJoint ───────────────────────────────────────────────────────────────────
29
30/// A single joint in an IK chain.
31#[derive(Debug, Clone)]
32pub struct IkJoint {
33    /// World-space position of this joint.
34    pub position:    Vec3,
35    /// Length of the bone from this joint to the next.
36    pub bone_length: f32,
37    /// Optional angle constraint: (min_angle, max_angle) in radians relative to parent.
38    pub angle_limit: Option<(f32, f32)>,
39    /// Optional twist limit around the bone axis (radians).
40    pub twist_limit: Option<(f32, f32)>,
41    /// Stiffness [0, 1]: how much this joint resists rotation. 0 = free, 1 = locked.
42    pub stiffness:   f32,
43}
44
45impl IkJoint {
46    pub fn new(position: Vec3, bone_length: f32) -> Self {
47        Self {
48            position,
49            bone_length,
50            angle_limit: None,
51            twist_limit: None,
52            stiffness: 0.0,
53        }
54    }
55
56    pub fn with_angle_limit(mut self, min: f32, max: f32) -> Self {
57        self.angle_limit = Some((min, max));
58        self
59    }
60
61    pub fn with_stiffness(mut self, s: f32) -> Self {
62        self.stiffness = s.clamp(0.0, 1.0);
63        self
64    }
65}
66
67// ── IkChain ───────────────────────────────────────────────────────────────────
68
69/// A chain of joints that IK solvers operate on.
70#[derive(Debug, Clone)]
71pub struct IkChain {
72    pub joints:        Vec<IkJoint>,
73    /// Root joint is pinned to this world-space position.
74    pub root_pin:      Vec3,
75    /// Total reach of the chain.
76    pub total_length:  f32,
77    /// Position tolerance: solver stops when end-effector is within this distance.
78    pub tolerance:     f32,
79    /// Whether the root joint can move (false = fixed root).
80    pub fixed_root:    bool,
81}
82
83impl IkChain {
84    /// Build a chain from a list of joint positions.
85    /// Bone lengths are computed from consecutive positions.
86    pub fn new(positions: Vec<Vec3>) -> Self {
87        assert!(positions.len() >= 2, "IK chain needs at least 2 joints");
88        let mut joints = Vec::with_capacity(positions.len());
89        let mut total = 0.0;
90        for i in 0..positions.len() {
91            let bone_len = if i + 1 < positions.len() {
92                (positions[i+1] - positions[i]).length()
93            } else {
94                0.0
95            };
96            total += bone_len;
97            joints.push(IkJoint::new(positions[i], bone_len));
98        }
99        let root = joints[0].position;
100        Self { joints, root_pin: root, total_length: total, tolerance: 0.001, fixed_root: true }
101    }
102
103    /// Build a chain from a root position, bone directions, and uniform bone length.
104    pub fn uniform(root: Vec3, count: usize, bone_length: f32, direction: Vec3) -> Self {
105        let dir = direction.normalize_or_zero();
106        let positions: Vec<Vec3> = (0..count+1)
107            .map(|i| root + dir * (i as f32 * bone_length))
108            .collect();
109        Self::new(positions)
110    }
111
112    pub fn with_tolerance(mut self, t: f32) -> Self { self.tolerance = t; self }
113
114    /// End-effector position (last joint).
115    pub fn end_effector(&self) -> Vec3 {
116        self.joints.last().map(|j| j.position).unwrap_or(Vec3::ZERO)
117    }
118
119    /// Number of bones in the chain.
120    pub fn bone_count(&self) -> usize {
121        self.joints.len().saturating_sub(1)
122    }
123
124    /// Check if target is reachable.
125    pub fn can_reach(&self, target: Vec3) -> bool {
126        (target - self.root_pin).length() <= self.total_length + self.tolerance
127    }
128
129    /// Reconstruct all joint positions from root_pin, preserving bone lengths.
130    pub fn rebuild_from_root(&mut self) {
131        if self.joints.is_empty() { return; }
132        self.joints[0].position = self.root_pin;
133        for i in 1..self.joints.len() {
134            let prev = self.joints[i-1].position;
135            let dir = (self.joints[i].position - prev).normalize_or_zero();
136            let len = self.joints[i-1].bone_length;
137            self.joints[i].position = prev + dir * len;
138        }
139    }
140}
141
142// ── FABRIK Solver ─────────────────────────────────────────────────────────────
143
144/// FABRIK: Forward And Backward Reaching IK.
145///
146/// Iteratively moves joints forward (end to root) then backward (root to end)
147/// until the end-effector reaches the target or iterations are exhausted.
148pub struct FabrikSolver;
149
150impl FabrikSolver {
151    /// Solve the chain in place.
152    /// Returns the number of iterations taken, or None if target was unreachable.
153    pub fn solve(chain: &mut IkChain, target: Vec3, max_iterations: usize) -> Option<usize> {
154        let n = chain.joints.len();
155        if n < 2 { return None; }
156
157        // If target is unreachable, stretch toward it
158        if !chain.can_reach(target) {
159            Self::stretch_toward(chain, target);
160            return None;
161        }
162
163        let root = chain.root_pin;
164
165        for iter in 0..max_iterations {
166            // Check convergence
167            let err = (chain.end_effector() - target).length();
168            if err <= chain.tolerance { return Some(iter); }
169
170            // Forward pass: move end-effector to target, pull joints
171            chain.joints[n-1].position = target;
172            for i in (0..n-1).rev() {
173                let dir = (chain.joints[i].position - chain.joints[i+1].position)
174                    .normalize_or_zero();
175                let len = chain.joints[i].bone_length;
176                chain.joints[i].position = chain.joints[i+1].position + dir * len;
177            }
178
179            // Backward pass: anchor root, push joints forward
180            if chain.fixed_root {
181                chain.joints[0].position = root;
182            }
183            for i in 0..n-1 {
184                let dir = (chain.joints[i+1].position - chain.joints[i].position)
185                    .normalize_or_zero();
186                let len = chain.joints[i].bone_length;
187                chain.joints[i+1].position = chain.joints[i].position + dir * len;
188            }
189
190            // Apply joint constraints (angle limits)
191            Self::apply_constraints(chain);
192        }
193
194        Some(max_iterations)
195    }
196
197    /// Solve with a pole vector hint for elbow/knee direction.
198    pub fn solve_with_pole(
199        chain: &mut IkChain,
200        target: Vec3,
201        pole: Vec3,
202        max_iterations: usize,
203    ) -> Option<usize> {
204        let result = Self::solve(chain, target, max_iterations);
205
206        // Apply pole vector influence on interior joints
207        let n = chain.joints.len();
208        if n < 3 { return result; }
209
210        for i in 1..n-1 {
211            let root_pos = chain.joints[0].position;
212            let tip_pos  = chain.joints[n-1].position;
213            let joint_pos = chain.joints[i].position;
214
215            // Project joint onto root-to-tip line
216            let bone_dir = (tip_pos - root_pos).normalize_or_zero();
217            let to_joint = joint_pos - root_pos;
218            let proj_len = to_joint.dot(bone_dir);
219            let proj = root_pos + bone_dir * proj_len;
220
221            // Pole plane: perpendicular to bone_dir
222            let to_pole = (pole - proj).normalize_or_zero();
223            let to_joint_norm = (joint_pos - proj).normalize_or_zero();
224
225            if to_pole.length_squared() < 1e-6 || to_joint_norm.length_squared() < 1e-6 {
226                continue;
227            }
228
229            // Rotate joint toward pole
230            let angle = to_joint_norm.dot(to_pole).clamp(-1.0, 1.0).acos();
231            if angle < 1e-4 { continue; }
232
233            let axis = to_joint_norm.cross(to_pole).normalize_or_zero();
234            if axis.length_squared() < 1e-6 { continue; }
235
236            let rotation = Quat::from_axis_angle(axis, angle * 0.5);
237            let dist = (joint_pos - proj).length();
238            let new_dir = rotation * to_joint_norm;
239            chain.joints[i].position = proj + new_dir * dist;
240        }
241
242        result
243    }
244
245    fn stretch_toward(chain: &mut IkChain, target: Vec3) {
246        let dir = (target - chain.root_pin).normalize_or_zero();
247        let n = chain.joints.len();
248        chain.joints[0].position = chain.root_pin;
249        for i in 1..n {
250            let len = chain.joints[i-1].bone_length;
251            chain.joints[i].position = chain.joints[i-1].position + dir * len;
252        }
253    }
254
255    fn apply_constraints(chain: &mut IkChain) {
256        let n = chain.joints.len();
257        for i in 1..n-1 {
258            if let Some((min_a, max_a)) = chain.joints[i].angle_limit {
259                // Compute current angle at this joint
260                let to_prev = (chain.joints[i-1].position - chain.joints[i].position).normalize_or_zero();
261                let to_next = (chain.joints[i+1].position - chain.joints[i].position).normalize_or_zero();
262                let angle = to_prev.dot(to_next).clamp(-1.0, 1.0).acos();
263                let clamped = angle.clamp(min_a, max_a);
264                if (clamped - angle).abs() > 1e-4 {
265                    let axis = to_prev.cross(to_next).normalize_or_zero();
266                    if axis.length_squared() > 1e-6 {
267                        let rot = Quat::from_axis_angle(axis, clamped - angle);
268                        let dist = (chain.joints[i+1].position - chain.joints[i].position).length();
269                        let new_dir = rot * to_next;
270                        chain.joints[i+1].position = chain.joints[i].position + new_dir * dist;
271                    }
272                }
273            }
274        }
275    }
276}
277
278// ── CCD Solver ────────────────────────────────────────────────────────────────
279
280/// CCD: Cyclic Coordinate Descent IK.
281///
282/// Rotates each joint in the chain to minimize the angular error at the
283/// end-effector. Slower convergence than FABRIK but naturally respects
284/// local joint angle limits.
285pub struct CcdSolver;
286
287impl CcdSolver {
288    pub fn solve(chain: &mut IkChain, target: Vec3, max_iterations: usize) -> Option<usize> {
289        let n = chain.joints.len();
290        if n < 2 { return None; }
291
292        for iter in 0..max_iterations {
293            let err = (chain.end_effector() - target).length();
294            if err <= chain.tolerance { return Some(iter); }
295
296            // Iterate from end-1 down to root
297            for j in (0..n-1).rev() {
298                let joint_pos = chain.joints[j].position;
299                let end_pos   = chain.end_effector();
300
301                let to_end    = (end_pos  - joint_pos).normalize_or_zero();
302                let to_target = (target   - joint_pos).normalize_or_zero();
303
304                if to_end.length_squared() < 1e-6 || to_target.length_squared() < 1e-6 {
305                    continue;
306                }
307
308                let dot = to_end.dot(to_target).clamp(-1.0, 1.0);
309                let mut angle = dot.acos();
310
311                // Apply stiffness
312                angle *= 1.0 - chain.joints[j].stiffness;
313
314                if angle < 1e-4 { continue; }
315
316                let axis = to_end.cross(to_target).normalize_or_zero();
317                if axis.length_squared() < 1e-6 { continue; }
318
319                // Clamp to angle limit
320                if let Some((min_a, max_a)) = chain.joints[j].angle_limit {
321                    angle = angle.clamp(min_a, max_a);
322                }
323
324                let rot = Quat::from_axis_angle(axis, angle);
325
326                // Rotate all downstream joints around this joint
327                for k in j+1..n {
328                    let offset = chain.joints[k].position - joint_pos;
329                    chain.joints[k].position = joint_pos + rot * offset;
330                }
331            }
332
333            // Re-anchor root
334            if chain.fixed_root {
335                let offset = chain.root_pin - chain.joints[0].position;
336                for j in &mut chain.joints {
337                    j.position += offset;
338                }
339            }
340        }
341
342        Some(max_iterations)
343    }
344}
345
346// ── 2-Bone Analytical Solver ──────────────────────────────────────────────────
347
348/// Closed-form IK for a 2-bone chain (3 joints: shoulder, elbow, wrist/hand).
349///
350/// Uses the law of cosines for exact, instantaneous solution. Supports
351/// a pole vector to control the elbow/knee direction.
352pub struct TwoBoneSolver;
353
354/// Result of a 2-bone IK solve.
355#[derive(Debug, Clone)]
356pub struct TwoBoneResult {
357    /// World position of the middle joint (elbow/knee).
358    pub mid_position:  Vec3,
359    /// Whether the target was reachable.
360    pub reachable:     bool,
361    /// Elbow angle in radians.
362    pub elbow_angle:   f32,
363}
364
365impl TwoBoneSolver {
366    /// Solve a 2-bone chain.
367    ///
368    /// - `root`    : shoulder/hip position
369    /// - `mid`     : current elbow/knee position (used for initial plane)
370    /// - `len_a`   : upper bone length (shoulder to elbow)
371    /// - `len_b`   : lower bone length (elbow to wrist)
372    /// - `target`  : desired wrist/ankle position
373    /// - `pole`    : pole vector pointing toward desired elbow direction
374    pub fn solve(
375        root:   Vec3,
376        mid:    Vec3,
377        len_a:  f32,
378        len_b:  f32,
379        target: Vec3,
380        pole:   Option<Vec3>,
381    ) -> TwoBoneResult {
382        let target_dist = (target - root).length();
383        let max_reach   = len_a + len_b;
384        let min_reach   = (len_a - len_b).abs();
385
386        let reachable = target_dist >= min_reach && target_dist <= max_reach;
387        let eff_dist  = target_dist.clamp(min_reach + 1e-4, max_reach - 1e-4);
388
389        // Law of cosines: cos(angle_at_root) = (a^2 + c^2 - b^2) / (2ac)
390        let a2 = len_a * len_a;
391        let b2 = len_b * len_b;
392        let c2 = eff_dist * eff_dist;
393        let cos_elbow = ((a2 + b2 - c2) / (2.0 * len_a * len_b)).clamp(-1.0, 1.0);
394        let elbow_angle = cos_elbow.acos();
395
396        // cos(angle_at_root) = (a^2 + c^2 - b^2) / (2ac)
397        let cos_root = ((a2 + c2 - b2) / (2.0 * len_a * eff_dist)).clamp(-1.0, 1.0);
398        let root_angle = cos_root.acos();
399
400        // Direction from root to target
401        let root_to_target = (target - root).normalize_or_zero();
402
403        // Find the plane normal using pole vector or fallback to existing mid joint
404        let plane_normal = {
405            let candidate = if let Some(p) = pole {
406                (p - root).normalize_or_zero()
407            } else {
408                (mid - root).normalize_or_zero()
409            };
410            // Orthogonalize against root_to_target
411            let n = candidate - root_to_target * candidate.dot(root_to_target);
412            n.normalize_or_zero()
413        };
414
415        // The mid joint lives in the root-target-pole plane
416        let mid_dir = if plane_normal.length_squared() > 1e-6 {
417            // Rotate root_to_target by root_angle around plane_normal
418            let rot = Quat::from_axis_angle(plane_normal, root_angle);
419            rot * root_to_target
420        } else {
421            // Degenerate: target is inline with root, project up
422            let up = if root_to_target.dot(Vec3::Y).abs() < 0.99 { Vec3::Y } else { Vec3::Z };
423            (up - root_to_target * up.dot(root_to_target)).normalize_or_zero()
424        };
425
426        let mid_position = root + mid_dir * len_a;
427
428        TwoBoneResult { mid_position, reachable, elbow_angle }
429    }
430
431    /// Apply the solve result to a 3-joint chain in place.
432    pub fn apply(chain: &mut IkChain, target: Vec3, pole: Option<Vec3>) -> TwoBoneResult {
433        assert!(chain.joints.len() == 3, "TwoBoneSolver requires exactly 3 joints");
434        let root  = chain.joints[0].position;
435        let mid   = chain.joints[1].position;
436        let len_a = chain.joints[0].bone_length;
437        let len_b = chain.joints[1].bone_length;
438        let result = Self::solve(root, mid, len_a, len_b, target, pole);
439        chain.joints[1].position = result.mid_position;
440        chain.joints[2].position = target;
441        result
442    }
443}
444
445// ── Look-At IK ────────────────────────────────────────────────────────────────
446
447/// Rotates a joint to aim at a target (look-at constraint).
448///
449/// Used for head/eye tracking, weapon aiming, etc.
450pub struct LookAtSolver;
451
452impl LookAtSolver {
453    /// Compute the rotation quaternion that rotates `forward` to point at `target`
454    /// from `eye_position`, with an `up` hint vector.
455    pub fn look_at_quat(eye_position: Vec3, target: Vec3, forward: Vec3, up: Vec3) -> Quat {
456        let desired_dir = (target - eye_position).normalize_or_zero();
457        if desired_dir.length_squared() < 1e-6 {
458            return Quat::IDENTITY;
459        }
460        let current_dir = forward.normalize_or_zero();
461        if current_dir.length_squared() < 1e-6 {
462            return Quat::IDENTITY;
463        }
464
465        let dot = current_dir.dot(desired_dir).clamp(-1.0, 1.0);
466        let angle = dot.acos();
467        if angle < 1e-5 { return Quat::IDENTITY; }
468
469        let axis = current_dir.cross(desired_dir);
470        if axis.length_squared() < 1e-10 {
471            // 180 degree case: rotate around up vector
472            return Quat::from_axis_angle(up.normalize_or_zero(), PI);
473        }
474        Quat::from_axis_angle(axis.normalize(), angle)
475    }
476
477    /// Partially rotate toward target with given weight [0, 1].
478    pub fn look_at_weighted(
479        eye_position: Vec3,
480        target: Vec3,
481        forward: Vec3,
482        up: Vec3,
483        weight: f32,
484    ) -> Quat {
485        let full = Self::look_at_quat(eye_position, target, forward, up);
486        Quat::IDENTITY.slerp(full, weight.clamp(0.0, 1.0))
487    }
488
489    /// Apply with angle limits: clamp the resulting rotation to ±max_angle.
490    pub fn look_at_clamped(
491        eye_position: Vec3,
492        target: Vec3,
493        forward: Vec3,
494        up: Vec3,
495        max_angle: f32,
496    ) -> Quat {
497        let q = Self::look_at_quat(eye_position, target, forward, up);
498        let (axis, angle) = q.to_axis_angle();
499        let clamped_angle = angle.clamp(-max_angle, max_angle);
500        Quat::from_axis_angle(axis, clamped_angle)
501    }
502}
503
504// ── IkRig ─────────────────────────────────────────────────────────────────────
505
506/// A full character IK rig with multiple named chains.
507pub struct IkRig {
508    pub chains:  HashMap<String, IkChain>,
509    pub targets: HashMap<String, Vec3>,
510    pub poles:   HashMap<String, Vec3>,
511    pub weights: HashMap<String, f32>,
512    pub enabled: bool,
513}
514
515use std::collections::HashMap;
516
517impl IkRig {
518    pub fn new() -> Self {
519        Self {
520            chains:  HashMap::new(),
521            targets: HashMap::new(),
522            poles:   HashMap::new(),
523            weights: HashMap::new(),
524            enabled: true,
525        }
526    }
527
528    pub fn add_chain(&mut self, name: impl Into<String>, chain: IkChain) {
529        let key = name.into();
530        self.weights.insert(key.clone(), 1.0);
531        self.chains.insert(key, chain);
532    }
533
534    pub fn set_target(&mut self, chain: &str, target: Vec3) {
535        self.targets.insert(chain.to_owned(), target);
536    }
537
538    pub fn set_pole(&mut self, chain: &str, pole: Vec3) {
539        self.poles.insert(chain.to_owned(), pole);
540    }
541
542    pub fn set_weight(&mut self, chain: &str, w: f32) {
543        self.weights.insert(chain.to_owned(), w.clamp(0.0, 1.0));
544    }
545
546    /// Solve all chains toward their targets.
547    pub fn solve_all(&mut self, max_iterations: usize) {
548        if !self.enabled { return; }
549        for (name, chain) in &mut self.chains {
550            let target = match self.targets.get(name) {
551                Some(t) => *t,
552                None    => continue,
553            };
554            let pole = self.poles.get(name).copied();
555            let weight = self.weights.get(name).copied().unwrap_or(1.0);
556            if weight < 1e-4 { continue; }
557
558            // Save pre-solve positions for weight blending
559            let before: Vec<Vec3> = chain.joints.iter().map(|j| j.position).collect();
560
561            if let Some(p) = pole {
562                FabrikSolver::solve_with_pole(chain, target, p, max_iterations);
563            } else {
564                FabrikSolver::solve(chain, target, max_iterations);
565            }
566
567            // Blend result with pre-solve by weight
568            if weight < 1.0 {
569                for (j, before_pos) in chain.joints.iter_mut().zip(before.iter()) {
570                    j.position = *before_pos + (j.position - *before_pos) * weight;
571                }
572            }
573        }
574    }
575
576    /// Get the solved end-effector position for a named chain.
577    pub fn end_effector(&self, chain: &str) -> Option<Vec3> {
578        self.chains.get(chain).map(|c| c.end_effector())
579    }
580}
581
582impl Default for IkRig {
583    fn default() -> Self { Self::new() }
584}
585
586// ── 2D IK (Vec2) ──────────────────────────────────────────────────────────────
587
588/// 2D IK chain for flat simulations (Vec2 joints).
589#[derive(Debug, Clone)]
590pub struct IkChain2D {
591    pub positions:    Vec<Vec2>,
592    pub bone_lengths: Vec<f32>,
593    pub root_pin:     Vec2,
594    pub tolerance:    f32,
595}
596
597impl IkChain2D {
598    pub fn new(positions: Vec<Vec2>) -> Self {
599        assert!(positions.len() >= 2);
600        let bone_lengths: Vec<f32> = positions.windows(2)
601            .map(|w| (w[1] - w[0]).length())
602            .collect();
603        let root = positions[0];
604        Self { positions, bone_lengths, root_pin: root, tolerance: 0.001 }
605    }
606
607    pub fn total_length(&self) -> f32 { self.bone_lengths.iter().sum() }
608
609    pub fn end_effector(&self) -> Vec2 { *self.positions.last().unwrap() }
610
611    /// FABRIK solve in 2D.
612    pub fn solve_fabrik(&mut self, target: Vec2, max_iter: usize) -> bool {
613        let n = self.positions.len();
614        let total = self.total_length();
615        let dist  = (target - self.root_pin).length();
616
617        if dist > total {
618            // Stretch toward target
619            let dir = (target - self.root_pin).normalize_or_zero();
620            self.positions[0] = self.root_pin;
621            for i in 1..n {
622                self.positions[i] = self.positions[i-1] + dir * self.bone_lengths[i-1];
623            }
624            return false;
625        }
626
627        for _ in 0..max_iter {
628            if (self.end_effector() - target).length() <= self.tolerance { return true; }
629
630            // Forward pass
631            self.positions[n-1] = target;
632            for i in (0..n-1).rev() {
633                let dir = (self.positions[i] - self.positions[i+1]).normalize_or_zero();
634                self.positions[i] = self.positions[i+1] + dir * self.bone_lengths[i];
635            }
636            // Backward pass
637            self.positions[0] = self.root_pin;
638            for i in 0..n-1 {
639                let dir = (self.positions[i+1] - self.positions[i]).normalize_or_zero();
640                self.positions[i+1] = self.positions[i] + dir * self.bone_lengths[i];
641            }
642        }
643        false
644    }
645}