#![allow(missing_docs)]
#![allow(dead_code)]
use serde::{Deserialize, Serialize};
/// Dot product of two 3-vectors.
#[inline]
fn dot(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
/// Component-wise sum `a + b`.
#[inline]
fn add(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    std::array::from_fn(|axis| a[axis] + b[axis])
}
/// Component-wise difference `a - b`.
#[inline]
fn sub(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ax - bx, ay - by, az - bz]
}
/// Multiplies every component of `a` by the scalar `s`.
#[inline]
fn scale(a: [f64; 3], s: f64) -> [f64; 3] {
    a.map(|component| component * s)
}
/// Squared Euclidean length of `a` (the dot product of `a` with itself).
#[inline]
fn len_sq(a: [f64; 3]) -> f64 {
    let [x, y, z] = a;
    x * x + y * y + z * z
}
/// Euclidean length (magnitude) of `a`.
#[inline]
fn len(a: [f64; 3]) -> f64 {
    f64::sqrt(len_sq(a))
}
/// Euclidean distance between the points `a` and `b`.
#[inline]
fn dist(a: [f64; 3], b: [f64; 3]) -> f64 {
    let offset = sub(b, a);
    len(offset)
}
/// Returns `a` scaled to unit length, or `None` when its length is below
/// `1e-12` and therefore too short to define a direction.
#[inline]
fn normalize(a: [f64; 3]) -> Option<[f64; 3]> {
    match len(a) {
        magnitude if magnitude < 1e-12 => None,
        magnitude => Some(scale(a, 1.0 / magnitude)),
    }
}
/// Cross product `a × b`.
#[inline]
fn cross(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
/// Returns a unit vector perpendicular to `v`.
///
/// `v` is crossed with the coordinate axis that matches its smallest-magnitude
/// component (ties favour x, then y). That axis is the least parallel to `v`,
/// which keeps the cross product well away from zero. A zero-length (or
/// numerically tiny) `v` yields `[0.0, 1.0, 0.0]`.
fn any_perpendicular(v: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = v.map(f64::abs);
    let axis = if ax <= ay && ax <= az {
        0
    } else if ay <= ax && ay <= az {
        1
    } else {
        2
    };
    let mut helper = [0.0; 3];
    helper[axis] = 1.0;
    normalize(cross(v, helper)).unwrap_or([0.0, 1.0, 0.0])
}
/// One bone of an [`IkChain`] together with its angular limits.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IkJoint {
    /// Offset from the previous chain position to this bone's end point.
    /// The solvers add these offsets directly, starting at
    /// `IkChain::root_world` and without applying any parent rotation. They
    /// overwrite the offsets with the solved bone vectors.
    pub local_offset: [f64; 3],
    /// Cone half-angle limit in radians, enforced by the FABRIK solver.
    /// Values of at least `PI - 1e-6` disable the limit.
    pub cone_limit_rad: f64,
    /// Twist limit in radians.
    /// NOTE(review): no solver in this file reads this field.
    pub twist_limit_rad: f64,
}
/// A serial chain of bones hanging from a fixed world-space root.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IkChain {
    /// World-space position of the root; the solvers read it but never move it.
    pub root_world: [f64; 3],
    /// Bones in order from root to tip.
    pub joints: Vec<IkJoint>,
    /// Length of each bone; expected to correspond one-to-one with `joints`.
    pub segment_lengths: Vec<f64>,
    /// Optional point that selects the bend plane of the two-bone solver.
    /// The FABRIK solver does not read it.
    pub pole_target: Option<[f64; 3]>,
}
impl IkChain {
    /// Builds a chain rooted at `root_world` from one
    /// `(length, cone_limit_rad, twist_limit_rad)` triple per bone.
    ///
    /// Every bone starts out pointing straight down the -Y axis
    /// (`local_offset = [0, -length, 0]`), and no pole target is set.
    pub fn new(root_world: [f64; 3], segments: &[(f64, f64, f64)]) -> Self {
        let (joints, segment_lengths): (Vec<IkJoint>, Vec<f64>) = segments
            .iter()
            .map(|&(length, cone_limit_rad, twist_limit_rad)| {
                let joint = IkJoint {
                    local_offset: [0.0, -length, 0.0],
                    cone_limit_rad,
                    twist_limit_rad,
                };
                (joint, length)
            })
            .unzip();
        Self {
            root_world,
            joints,
            segment_lengths,
            pole_target: None,
        }
    }

    /// Sum of all segment lengths: the farthest the tip can get from the root.
    pub fn total_reach(&self) -> f64 {
        self.segment_lengths.iter().copied().sum()
    }
}
/// Outcome of one [`IkSolver::solve`] call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SolveReport {
    /// Iterations performed. For FABRIK this counts backward/forward sweeps,
    /// and is 0 for an empty chain or an out-of-reach target. For TwoBone it
    /// is 1 on a valid chain and 0 when the chain is rejected.
    pub iterations: usize,
    /// Whether the tip ended close enough to the target. FABRIK compares
    /// against its `tolerance`; TwoBone uses a fixed `1e-4` threshold.
    pub converged: bool,
    /// Final distance between the chain tip and the target.
    /// TwoBone reports `f64::INFINITY` when it rejects the chain.
    pub residual: f64,
}
/// Errors describing why an IK chain cannot be solved.
///
/// NOTE(review): no function in this file returns this type yet; the solvers
/// currently report failure through [`SolveReport`] instead.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum IkError {
    /// The two-bone solver needs a chain with exactly two segments.
    TwoBoneRequiresTwoSegments,
    /// Presumably: the chain has too few segments for the requested solver.
    ChainTooShort,
}

// `Display` + `std::error::Error` let callers propagate `IkError` with `?`
// into `Box<dyn Error>` or other error-wrapping types.
impl std::fmt::Display for IkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            IkError::TwoBoneRequiresTwoSegments => {
                "two-bone IK requires a chain with exactly two segments"
            }
            IkError::ChainTooShort => "IK chain is too short",
        };
        f.write_str(message)
    }
}

impl std::error::Error for IkError {}
/// Selects the IK algorithm run by [`IkSolver::solve`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum IkSolver {
    /// Iterative FABRIK solver that honours per-joint cone limits.
    Fabrik {
        /// Upper bound on the number of backward/forward sweeps.
        max_iterations: usize,
        /// Tip-to-target distance below which the solve counts as converged.
        tolerance: f64,
    },
    /// Analytic two-bone solver (law of cosines). It only accepts chains with
    /// exactly two joints and two segment lengths.
    TwoBone,
}
/// Limits the bone running from `joint` to `*child` to a cone of half-angle
/// `cone_limit_rad` around the direction of the incoming bone
/// (`parent` -> `joint`).
///
/// `child` is left untouched in any of these cases:
/// - the limit is at least `PI - 1e-6`, so the cone admits every direction;
/// - there is no `parent`;
/// - either bone has (numerically) zero length;
/// - the bone already lies inside the cone.
///
/// Otherwise the bone direction is rotated onto the cone's surface, and
/// `*child` is placed `seg_len` away from `joint` along that direction. This
/// also applies when the bone points straight back along the parent.
fn clamp_cone(
    parent: Option<[f64; 3]>,
    joint: [f64; 3],
    child: &mut [f64; 3],
    cone_limit_rad: f64,
    seg_len: f64,
) {
    use std::f64::consts::PI;
    if cone_limit_rad >= PI - 1e-6 {
        return;
    }
    let Some(parent_pos) = parent else { return };
    let Some(ref_dir) = normalize(sub(joint, parent_pos)) else {
        return;
    };
    let Some(bone_dir) = normalize(sub(*child, joint)) else {
        return;
    };
    let angle = dot(ref_dir, bone_dir).clamp(-1.0, 1.0).acos();
    if angle <= cone_limit_rad {
        return;
    }
    // Rotate about the normal of the plane spanned by the two bones. When the
    // bone is (anti)parallel to `ref_dir`, that plane is undefined and the
    // cross product vanishes. Any axis perpendicular to `ref_dir` then yields
    // a valid point on the cone. Bailing out instead would let a fully
    // folded-back bone escape the limit.
    let axis = match normalize(cross(ref_dir, bone_dir)) {
        Some(a) => a,
        None => {
            // `ref_dir` is a unit vector. If |x| < 0.9, the X axis is far from
            // parallel to it; otherwise |y| <= ~0.44, so the Y axis is.
            let helper = if ref_dir[0].abs() < 0.9 {
                [1.0, 0.0, 0.0]
            } else {
                [0.0, 1.0, 0.0]
            };
            let Some(a) = normalize(cross(ref_dir, helper)) else {
                return;
            };
            a
        }
    };
    // Rodrigues' formula: rotate `ref_dir` about `axis` by `cone_limit_rad`.
    // The last term is ~0 because `axis` is perpendicular to `ref_dir`.
    let cos_c = cone_limit_rad.cos();
    let sin_c = cone_limit_rad.sin();
    let clamped_dir = add(
        add(scale(ref_dir, cos_c), scale(cross(axis, ref_dir), sin_c)),
        scale(axis, dot(axis, ref_dir) * (1.0 - cos_c)),
    );
    let clamped_dir = normalize(clamped_dir).unwrap_or(ref_dir);
    *child = add(joint, scale(clamped_dir, seg_len));
}
impl IkSolver {
    /// Runs the selected algorithm on `chain`, moving its tip towards `target`.
    ///
    /// The solver writes its result into the chain's joint offsets. The returned
    /// [`SolveReport`] gives the iteration count, convergence flag and final
    /// tip-to-target distance.
    pub fn solve(&self, chain: &mut IkChain, target: [f64; 3]) -> SolveReport {
        match *self {
            Self::Fabrik {
                max_iterations,
                tolerance,
            } => solve_fabrik(chain, target, max_iterations, tolerance),
            Self::TwoBone => solve_two_bone(chain, target),
        }
    }
}
/// Iteratively solves `chain` towards `target` using FABRIK (Forward And
/// Backward Reaching Inverse Kinematics) with per-joint cone limits.
///
/// Joint positions are rebuilt from `chain.root_world` by summing each joint's
/// `local_offset`. Bone `i` runs from position `i` to position `i + 1` and has
/// length `segment_lengths[i]`. Its direction relative to bone `i - 1` is
/// limited by `joints[i].cone_limit_rad`. The root bone has no parent bone, so
/// `joints[0].cone_limit_rad` is unused.
///
/// Behaviour:
/// - empty chain: nothing moves; the residual is the root-to-target distance;
/// - `segment_lengths.len() != joints.len()`: the chain is left untouched and
///   the residual is `f64::INFINITY`;
/// - target farther than [`IkChain::total_reach`]: the chain is laid out in a
///   straight line towards it without applying cone limits, `iterations` is 0
///   and `converged` is false;
/// - otherwise: up to `max_iterations` backward/forward sweeps run, stopping
///   early once the tip is within `tolerance` of `target`.
///
/// In the last two cases, every `local_offset` is overwritten with its solved
/// bone vector. The reported `residual` is the final tip-to-target distance.
fn solve_fabrik(
    chain: &mut IkChain,
    target: [f64; 3],
    max_iterations: usize,
    tolerance: f64,
) -> SolveReport {
    let n = chain.joints.len();
    if n == 0 {
        return SolveReport {
            iterations: 0,
            converged: false,
            residual: dist(chain.root_world, target),
        };
    }
    // The fields are public, so they can disagree in length. Reject the chain
    // the way the TwoBone solver does, instead of panicking on an index below.
    if chain.segment_lengths.len() != n {
        return SolveReport {
            iterations: 0,
            converged: false,
            residual: f64::INFINITY,
        };
    }

    // World-space joint positions: positions[0] is the root, positions[n] the tip.
    let mut positions: Vec<[f64; 3]> = Vec::with_capacity(n + 1);
    positions.push(chain.root_world);
    for i in 0..n {
        let prev = positions[i];
        positions.push(add(prev, chain.joints[i].local_offset));
    }

    // Unreachable target: stretch the chain straight towards it.
    if dist(chain.root_world, target) > chain.total_reach() {
        let dir = normalize(sub(target, chain.root_world)).unwrap_or([1.0, 0.0, 0.0]);
        for i in 0..n {
            positions[i + 1] = add(positions[i], scale(dir, chain.segment_lengths[i]));
        }
        for i in 0..n {
            chain.joints[i].local_offset = sub(positions[i + 1], positions[i]);
        }
        return SolveReport {
            iterations: 0,
            converged: false,
            residual: dist(positions[n], target),
        };
    }

    let mut converged = false;
    let mut iters = 0;
    for _ in 0..max_iterations {
        iters += 1;

        // Backward pass: pin the tip to the target and walk towards the root.
        positions[n] = target;
        for i in (0..n).rev() {
            let dir = normalize(sub(positions[i], positions[i + 1])).unwrap_or([0.0, 1.0, 0.0]);
            positions[i] = add(positions[i + 1], scale(dir, chain.segment_lengths[i]));
            if i > 0 {
                // positions[i - 1] has not been revisited yet in this pass, so
                // bone i is limited relative to the previous sweep's bone i - 1.
                clamp_cone(
                    Some(positions[i - 1]),
                    positions[i],
                    &mut positions[i + 1],
                    chain.joints[i].cone_limit_rad,
                    chain.segment_lengths[i],
                );
            }
        }

        // Forward pass: re-anchor the root and walk back out to the tip.
        positions[0] = chain.root_world;
        for i in 0..n {
            let dir = normalize(sub(positions[i + 1], positions[i])).unwrap_or([0.0, -1.0, 0.0]);
            positions[i + 1] = add(positions[i], scale(dir, chain.segment_lengths[i]));
            if i + 1 < n {
                // Limit bone i + 1 against the just-placed bone i. Use bone
                // i + 1's own cone limit and length, the same indexing
                // convention as the backward pass.
                clamp_cone(
                    Some(positions[i]),
                    positions[i + 1],
                    &mut positions[i + 2],
                    chain.joints[i + 1].cone_limit_rad,
                    chain.segment_lengths[i + 1],
                );
            }
        }

        if dist(positions[n], target) < tolerance {
            converged = true;
            break;
        }
    }

    for i in 0..n {
        chain.joints[i].local_offset = sub(positions[i + 1], positions[i]);
    }
    SolveReport {
        iterations: iters,
        converged,
        residual: dist(positions[n], target),
    }
}
/// Analytic two-bone IK based on the law of cosines.
///
/// Places the elbow so that bone 0 (`segment_lengths[0]`) and bone 1
/// (`segment_lengths[1]`) reach towards `target`. The bend plane contains
/// `chain.pole_target` when one is set and is not collinear with the root and
/// target; otherwise an arbitrary perpendicular plane is used. The root-target
/// distance is clamped into the span the bones can cover, from `|l1 - l2|` to
/// just under `l1 + l2`. An out-of-range target therefore yields a best-effort
/// pose with a non-zero residual.
///
/// On success, `joints[0].local_offset` and `joints[1].local_offset` are
/// overwritten with the solved bone vectors, and `converged` means the
/// residual is below `1e-4`. A chain without exactly two joints and two
/// segment lengths is left untouched. In that case the result has
/// `residual = f64::INFINITY` and `converged = false`.
fn solve_two_bone(chain: &mut IkChain, target: [f64; 3]) -> SolveReport {
    if chain.joints.len() != 2 || chain.segment_lengths.len() != 2 {
        return SolveReport {
            iterations: 0,
            converged: false,
            residual: f64::INFINITY,
        };
    }
    let l1 = chain.segment_lengths[0];
    let l2 = chain.segment_lengths[1];
    let root = chain.root_world;

    // Clamp the root-target distance into the range the two bones can span.
    // The upper bound sits a hair below full extension so the triangle does
    // not fully flatten. It is floored at the lower bound because `f64::clamp`
    // panics when min > max, e.g. with a zero-length bone.
    let min_d = (l1 - l2).abs();
    let max_d = (l1 + l2 - 1e-9).max(min_d);
    let d = dist(root, target).clamp(min_d, max_d);

    // Cosine of the angle at the root between root->target and root->elbow.
    // When the triangle collapses (d == 0 or l1 == 0), the formula is 0/0.
    // Any elbow direction works then, so take the straight one instead of
    // propagating NaN.
    let denom = 2.0 * l1 * d;
    let cos_alpha = if denom > 0.0 {
        ((l1 * l1 + d * d - l2 * l2) / denom).clamp(-1.0, 1.0)
    } else {
        1.0
    };
    let sin_alpha = (1.0 - cos_alpha * cos_alpha).max(0.0).sqrt();

    let along = normalize(sub(target, root)).unwrap_or([1.0, 0.0, 0.0]);
    // Bend direction: the component of (pole - root) perpendicular to `along`,
    // obtained as (target x pole) x along. Fall back to any perpendicular when
    // there is no pole or it is collinear with the root and target.
    let perp = chain
        .pole_target
        .and_then(|pole| normalize(cross(sub(target, root), sub(pole, root))))
        .and_then(|plane_normal| normalize(cross(plane_normal, along)))
        .unwrap_or_else(|| any_perpendicular(along));

    let elbow_pos = add(
        add(root, scale(along, l1 * cos_alpha)),
        scale(perp, l1 * sin_alpha),
    );
    let to_target_from_elbow = normalize(sub(target, elbow_pos)).unwrap_or(along);
    let tip_pos = add(elbow_pos, scale(to_target_from_elbow, l2));
    chain.joints[0].local_offset = sub(elbow_pos, root);
    chain.joints[1].local_offset = sub(tip_pos, elbow_pos);

    let residual = dist(tip_pos, target);
    SolveReport {
        iterations: 1,
        converged: residual < 1e-4,
        residual,
    }
}
// Unit tests for the FABRIK and two-bone solvers, plus serde round trips.
// A cone limit of PI disables `clamp_cone` (it returns early at `PI - 1e-6`),
// so most tests exercise the solvers without joint constraints.
#[cfg(test)]
mod tests {
    use super::*;
    const PI: f64 = std::f64::consts::PI;
    // Reconstructs the tip position by summing every joint offset from the root.
    fn tip(chain: &IkChain) -> [f64; 3] {
        let mut pos = chain.root_world;
        for j in &chain.joints {
            pos = add(pos, j.local_offset);
        }
        pos
    }
    // Target straight up (+Y) at exactly full extension. `IkChain::new` makes
    // the chain hang down -Y, so FABRIK has to flip it completely.
    #[test]
    fn test_fabrik_straight_chain() {
        let mut chain = IkChain::new(
            [0.0, 0.0, 0.0],
            &[(1.0, PI, 0.0), (1.0, PI, 0.0), (1.0, PI, 0.0)],
        );
        let solver = IkSolver::Fabrik {
            max_iterations: 50,
            tolerance: 1e-6,
        };
        let report = solver.solve(&mut chain, [0.0, 3.0, 0.0]);
        assert!(
            report.converged,
            "Should converge; residual={:.2e}",
            report.residual
        );
        assert!(
            report.residual < 1e-6,
            "Residual too large: {:.2e}",
            report.residual
        );
    }
    // Target at distance 10 exceeds the total reach of 2. The solver must
    // report non-convergence yet still stretch the chain towards +X.
    #[test]
    fn test_fabrik_unreachable() {
        let mut chain = IkChain::new([0.0, 0.0, 0.0], &[(1.0, PI, 0.0), (1.0, PI, 0.0)]);
        let solver = IkSolver::Fabrik {
            max_iterations: 50,
            tolerance: 1e-6,
        };
        let report = solver.solve(&mut chain, [10.0, 0.0, 0.0]);
        assert!(!report.converged, "Should NOT converge");
        let tip_pos = tip(&chain);
        assert!(
            tip_pos[0] > 1.8,
            "Chain should extend toward target; tip_x={:.3}",
            tip_pos[0]
        );
    }
    // Target 2.5 away, inside the 3.0 reach: reachable, but the chain must bend
    // rather than straighten. The residual should fall within the tolerance.
    #[test]
    fn test_fabrik_partial_reach() {
        let mut chain = IkChain::new(
            [0.0, 0.0, 0.0],
            &[(1.0, PI, 0.0), (1.0, PI, 0.0), (1.0, PI, 0.0)],
        );
        let solver = IkSolver::Fabrik {
            max_iterations: 100,
            tolerance: 1e-4,
        };
        let report = solver.solve(&mut chain, [2.5, 0.0, 0.0]);
        assert!(
            report.residual < 1e-4,
            "Residual too large for partial-reach: {:.2e}",
            report.residual
        );
    }
    // Two unit bones reaching a target sqrt(2) away (a right-angle elbow).
    // Also checks that both bone lengths are preserved.
    #[test]
    fn test_two_bone_two_segment() {
        let mut chain = IkChain::new([0.0, 0.0, 0.0], &[(1.0, PI, 0.0), (1.0, PI, 0.0)]);
        let solver = IkSolver::TwoBone;
        let target = [2_f64.sqrt(), 0.0, 0.0];
        let report = solver.solve(&mut chain, target);
        assert!(
            report.residual < 1e-4,
            "TwoBone residual too large: {:.2e}",
            report.residual
        );
        assert!(
            (len(chain.joints[0].local_offset) - 1.0).abs() < 1e-6,
            "Segment 0 length changed"
        );
        assert!(
            (len(chain.joints[1].local_offset) - 1.0).abs() < 1e-6,
            "Segment 1 length changed"
        );
    }
    // The two-bone solver must reject a three-segment chain with an infinite residual.
    #[test]
    fn test_two_bone_wrong_length() {
        let mut chain = IkChain::new(
            [0.0, 0.0, 0.0],
            &[(1.0, PI, 0.0), (1.0, PI, 0.0), (1.0, PI, 0.0)],
        );
        let solver = IkSolver::TwoBone;
        let report = solver.solve(&mut chain, [1.0, 0.0, 0.0]);
        assert!(!report.converged);
        assert!(
            report.residual.is_infinite(),
            "Expected INFINITY residual, got {}",
            report.residual
        );
    }
    // With 0.1 rad cones on every joint, the angle between consecutive bones
    // must stay within the limit (plus 1e-3 slack) after solving.
    #[test]
    fn test_cone_limit() {
        let cone = 0.1_f64;
        let mut chain = IkChain::new(
            [0.0, 0.0, 0.0],
            &[(1.0, cone, 0.0), (1.0, cone, 0.0), (1.0, cone, 0.0)],
        );
        let solver = IkSolver::Fabrik {
            max_iterations: 50,
            tolerance: 1e-6,
        };
        solver.solve(&mut chain, [3.0, 0.0, 0.0]);
        // Rebuild world-space joint positions from the solved offsets.
        let mut positions = vec![chain.root_world];
        for j in &chain.joints {
            let prev = *positions.last().expect("positions non-empty");
            positions.push(add(prev, j.local_offset));
        }
        // Compare each bone i's direction with that of bone i - 1.
        for i in 1..chain.joints.len() {
            let d0 = normalize(sub(positions[i], positions[i - 1])).unwrap_or([0.0, 1.0, 0.0]);
            let d1 = normalize(sub(positions[i + 1], positions[i])).unwrap_or([0.0, 1.0, 0.0]);
            let angle = dot(d0, d1).clamp(-1.0, 1.0).acos();
            assert!(
                angle <= cone + 1e-3,
                "Bone {} angle {:.4} rad exceeds cone limit {:.4} rad",
                i,
                angle,
                cone
            );
        }
    }
    // JSON round trip for every serializable public type.
    #[test]
    fn test_serde_round_trip() {
        let chain = IkChain::new([1.0, 2.0, 3.0], &[(1.5, 0.5, 0.1), (2.0, PI, 0.0)]);
        let json = serde_json::to_string(&chain).expect("serialize IkChain");
        let chain2: IkChain = serde_json::from_str(&json).expect("deserialize IkChain");
        assert_eq!(chain, chain2);
        let solvers = [
            IkSolver::Fabrik {
                max_iterations: 20,
                tolerance: 1e-5,
            },
            IkSolver::TwoBone,
        ];
        for solver in &solvers {
            let json = serde_json::to_string(solver).expect("serialize IkSolver");
            let solver2: IkSolver = serde_json::from_str(&json).expect("deserialize IkSolver");
            assert_eq!(solver, &solver2);
        }
        let report = SolveReport {
            iterations: 5,
            converged: true,
            residual: 1e-7,
        };
        let json = serde_json::to_string(&report).expect("serialize SolveReport");
        let report2: SolveReport = serde_json::from_str(&json).expect("deserialize SolveReport");
        assert_eq!(report, report2);
        let err = IkError::TwoBoneRequiresTwoSegments;
        let json = serde_json::to_string(&err).expect("serialize IkError");
        let err2: IkError = serde_json::from_str(&json).expect("deserialize IkError");
        assert_eq!(err, err2);
    }
}