use nalgebra::{UnitQuaternion, Vector3, Vector4};
use super::spatial::{SpatialTransform, SpatialVector};
/// Generalized joint connecting a body to its parent.
///
/// Each variant fixes how many position coordinates ([`GenJoint::n_pos`]) and
/// velocity coordinates ([`GenJoint::dof`]) the joint uses; the two counts differ
/// for joints whose orientation is stored as a quaternion.
#[derive(Clone, Debug, PartialEq)]
pub enum GenJoint {
    /// Rigid attachment with no degrees of freedom.
    Fixed,
    /// Single rotation by `q[0]` about `axis`.
    Revolute {
        /// Rotation axis. NOTE(review): presumably expected to be unit length — confirm.
        axis: Vector3<f32>,
    },
    /// Single translation: the joint displaces by `axis * q[0]`.
    Prismatic {
        /// Translation direction; its length scales the displacement.
        axis: Vector3<f32>,
    },
    /// Free rotation; `q` holds a quaternion `[w, x, y, z]`.
    Spherical,
    /// Free rigid motion; `q` holds a translation `[x, y, z]` followed by a
    /// quaternion `[w, x, y, z]`.
    Floating,
    /// Planar motion: two translations (`q[0]`, `q[1]`) within the plane orthogonal
    /// to `normal`, plus a rotation `q[2]` about `normal`.
    Planar {
        /// Plane normal, also used as the rotation axis.
        normal: Vector3<f32>,
    },
}
impl GenJoint {
pub fn dof(&self) -> usize {
match self {
GenJoint::Fixed => 0,
GenJoint::Revolute { .. } | GenJoint::Prismatic { .. } => 1,
GenJoint::Spherical | GenJoint::Planar { .. } => 3,
GenJoint::Floating => 6,
}
}
pub fn n_pos(&self) -> usize {
match self {
GenJoint::Fixed => 0,
GenJoint::Revolute { .. } | GenJoint::Prismatic { .. } => 1,
GenJoint::Planar { .. } => 3,
GenJoint::Spherical => 4,
GenJoint::Floating => 7,
}
}
pub fn transform(&self, q: &[f32]) -> SpatialTransform {
match self {
GenJoint::Fixed => SpatialTransform::identity(),
GenJoint::Revolute { axis } => {
let angle = q.first().copied().unwrap_or(0.0);
SpatialTransform::from_axis_angle_translation(*axis, angle, Vector3::zeros())
}
GenJoint::Prismatic { axis } => {
let displacement = q.first().copied().unwrap_or(0.0);
SpatialTransform::from_translation(*axis * displacement)
}
GenJoint::Spherical => {
let rot = quat_from_slice(q);
SpatialTransform::from_rotation(rot.to_rotation_matrix().into_inner())
}
GenJoint::Floating => {
if q.len() >= 7 {
let trans = Vector3::new(q[0], q[1], q[2]);
let rot = quat_from_slice(&q[3..]);
SpatialTransform::from_rotation_translation(
rot.to_rotation_matrix().into_inner(),
trans,
)
} else {
SpatialTransform::identity()
}
}
GenJoint::Planar { normal } => {
if q.len() >= 3 {
let (t1, t2) = plane_basis(normal);
let trans = t1 * q[0] + t2 * q[1];
let rot = SpatialTransform::from_axis_angle_translation(
*normal,
q[2],
Vector3::zeros(),
);
SpatialTransform::from_rotation_translation(rot.rotation, trans)
} else {
SpatialTransform::identity()
}
}
}
}
pub fn motion_subspace(&self, _q: &[f32]) -> Vec<SpatialVector> {
match self {
GenJoint::Fixed => vec![],
GenJoint::Revolute { axis } => {
vec![SpatialVector::new(*axis, Vector3::zeros())]
}
GenJoint::Prismatic { axis } => {
vec![SpatialVector::new(Vector3::zeros(), *axis)]
}
GenJoint::Spherical => {
vec![
SpatialVector::new(Vector3::x(), Vector3::zeros()),
SpatialVector::new(Vector3::y(), Vector3::zeros()),
SpatialVector::new(Vector3::z(), Vector3::zeros()),
]
}
GenJoint::Floating => {
vec![
SpatialVector::new(Vector3::zeros(), Vector3::x()),
SpatialVector::new(Vector3::zeros(), Vector3::y()),
SpatialVector::new(Vector3::zeros(), Vector3::z()),
SpatialVector::new(Vector3::x(), Vector3::zeros()),
SpatialVector::new(Vector3::y(), Vector3::zeros()),
SpatialVector::new(Vector3::z(), Vector3::zeros()),
]
}
GenJoint::Planar { normal } => {
let (t1, t2) = plane_basis(normal);
vec![
SpatialVector::new(Vector3::zeros(), t1),
SpatialVector::new(Vector3::zeros(), t2),
SpatialVector::new(*normal, Vector3::zeros()),
]
}
}
}
pub fn velocity_from_qdot(&self, q: &[f32], qdot: &[f32]) -> Vec<f32> {
match self {
GenJoint::Fixed => vec![],
GenJoint::Revolute { .. } | GenJoint::Prismatic { .. } => {
vec![qdot.first().copied().unwrap_or(0.0)]
}
GenJoint::Planar { .. } => {
let mut v = vec![0.0; 3];
let n = 3.min(qdot.len());
v[..n].copy_from_slice(&qdot[..n]);
v
}
GenJoint::Spherical => {
let quat = quat_from_slice(q);
let qdot_quat = Vector4::new(
qdot.first().copied().unwrap_or(0.0), qdot.get(1).copied().unwrap_or(0.0), qdot.get(2).copied().unwrap_or(0.0), qdot.get(3).copied().unwrap_or(0.0), );
let omega = quat_deriv_to_omega(quat, qdot_quat);
vec![omega.x, omega.y, omega.z]
}
GenJoint::Floating => {
let vx = qdot.first().copied().unwrap_or(0.0);
let vy = qdot.get(1).copied().unwrap_or(0.0);
let vz = qdot.get(2).copied().unwrap_or(0.0);
let quat = quat_from_slice(if q.len() >= 7 { &q[3..] } else { &[] });
let qdot_quat = Vector4::new(
qdot.get(3).copied().unwrap_or(0.0),
qdot.get(4).copied().unwrap_or(0.0),
qdot.get(5).copied().unwrap_or(0.0),
qdot.get(6).copied().unwrap_or(0.0),
);
let omega = quat_deriv_to_omega(quat, qdot_quat);
vec![vx, vy, vz, omega.x, omega.y, omega.z]
}
}
}
pub fn qdot_from_velocity(&self, q: &[f32], v: &[f32]) -> Vec<f32> {
match self {
GenJoint::Fixed => vec![],
GenJoint::Revolute { .. } | GenJoint::Prismatic { .. } => {
vec![v.first().copied().unwrap_or(0.0)]
}
GenJoint::Planar { .. } => {
let mut result = vec![0.0; 3];
let n = 3.min(v.len());
result[..n].copy_from_slice(&v[..n]);
result
}
GenJoint::Spherical => {
let quat = quat_from_slice(q);
let omega = Vector3::new(
v.first().copied().unwrap_or(0.0),
v.get(1).copied().unwrap_or(0.0),
v.get(2).copied().unwrap_or(0.0),
);
let qdot = omega_to_quat_deriv(quat, omega);
vec![qdot.w, qdot.i, qdot.j, qdot.k]
}
GenJoint::Floating => {
let vx = v.first().copied().unwrap_or(0.0);
let vy = v.get(1).copied().unwrap_or(0.0);
let vz = v.get(2).copied().unwrap_or(0.0);
let quat = quat_from_slice(if q.len() >= 7 { &q[3..] } else { &[] });
let omega = Vector3::new(
v.get(3).copied().unwrap_or(0.0),
v.get(4).copied().unwrap_or(0.0),
v.get(5).copied().unwrap_or(0.0),
);
let qdot = omega_to_quat_deriv(quat, omega);
vec![vx, vy, vz, qdot.w, qdot.i, qdot.j, qdot.k]
}
}
}
pub fn normalize_q(&self, q: &mut [f32]) {
match self {
GenJoint::Spherical if q.len() >= 4 => {
normalize_quat_slice(&mut q[0..4]);
}
GenJoint::Floating if q.len() >= 7 => {
normalize_quat_slice(&mut q[3..7]);
}
_ => {} }
}
pub fn default_q(&self) -> Vec<f32> {
match self {
GenJoint::Fixed => vec![],
GenJoint::Revolute { .. } | GenJoint::Prismatic { .. } => vec![0.0],
GenJoint::Planar { .. } => vec![0.0, 0.0, 0.0],
GenJoint::Spherical => vec![1.0, 0.0, 0.0, 0.0], GenJoint::Floating => vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], }
}
}
/// Builds a unit quaternion from the first four entries of `q`, read as `[w, x, y, z]`.
///
/// Returns the identity rotation in three cases:
/// - `q` has fewer than four entries;
/// - the quaternion's norm is below `1e-10`;
/// - the norm is not finite (NaN/inf components).
///
/// This keeps a corrupted state from propagating NaN into joint transforms,
/// consistent with `normalize_quat_slice`. Entries beyond the fourth are ignored.
fn quat_from_slice(q: &[f32]) -> UnitQuaternion<f32> {
    match q {
        [w, x, y, z, ..] => {
            let quat = nalgebra::Quaternion::new(*w, *x, *y, *z);
            let norm = quat.norm();
            if norm.is_finite() && norm >= 1e-10 {
                UnitQuaternion::new_normalize(quat)
            } else {
                UnitQuaternion::identity()
            }
        }
        _ => UnitQuaternion::identity(),
    }
}
/// Recovers an angular velocity from a quaternion time-derivative.
///
/// Computes `ω = 2 · vec(conj(q) ⊗ q̇)`, where `qdot` holds the derivative
/// components in `[w, x, y, z]` order (index 0 is the scalar part).
fn quat_deriv_to_omega(q: UnitQuaternion<f32>, qdot: Vector4<f32>) -> Vector3<f32> {
    let (dw, dx, dy, dz) = (qdot[0], qdot[1], qdot[2], qdot[3]);
    let rate = nalgebra::Quaternion::new(dw, dx, dy, dz);
    let local = q.into_inner().conjugate() * rate;
    local.imag() * 2.0
}
/// Quaternion time-derivative produced by angular velocity `omega`.
///
/// Computes `q̇ = ½ · q ⊗ (0, ω)`. For a unit `q`, this inverts `quat_deriv_to_omega`.
fn omega_to_quat_deriv(q: UnitQuaternion<f32>, omega: Vector3<f32>) -> nalgebra::Quaternion<f32> {
    let pure = nalgebra::Quaternion::from_parts(0.0, omega);
    let product = q.into_inner() * pure;
    product * 0.5
}
/// Normalizes, in place, the quaternion stored in the first four entries of `q`.
///
/// The quaternion is reset to the identity `[1, 0, 0, 0]` when it cannot be
/// meaningfully rescaled: its norm is not finite (NaN/inf components) or is at most
/// `1e-10`. With the `tracing` feature enabled, a warning is logged in that case.
/// Entries after the fourth are left untouched.
///
/// # Panics
///
/// Panics if `q` has fewer than four entries.
fn normalize_quat_slice(q: &mut [f32]) {
    let quat = &mut q[..4];
    let norm = quat.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm.is_finite() && norm > 1e-10 {
        let inv = 1.0 / norm;
        for c in quat.iter_mut() {
            *c *= inv;
        }
    } else {
        #[cfg(feature = "tracing")]
        tracing::warn!("Degenerate quaternion detected (norm={:.2e}), resetting to identity", norm);
        quat.copy_from_slice(&[1.0, 0.0, 0.0, 0.0]);
    }
}
/// Returns two unit tangent vectors `(t1, t2)` spanning the plane orthogonal to `normal`.
///
/// With `n̂` the normalized `normal`, the pair satisfies `t1 × t2 = n̂`, so
/// `(t1, t2, n̂)` is a right-handed frame. A (near-)zero `normal` (squared norm below
/// `1e-12`) falls back to the XY basis `(x̂, ŷ)`.
pub fn plane_basis(normal: &Vector3<f32>) -> (Vector3<f32>, Vector3<f32>) {
    let degenerate = normal.norm_squared() < 1e-12;
    if degenerate {
        (Vector3::x(), Vector3::y())
    } else {
        let unit_normal = normal.normalize();
        // Cross with an axis that is far from parallel to the normal so the first
        // tangent is well-conditioned.
        let reference = if unit_normal.x.abs() < 0.9 { Vector3::x() } else { Vector3::y() };
        let first = unit_normal.cross(&reference).normalize();
        let second = unit_normal.cross(&first).normalize();
        (first, second)
    }
}
/// Unit, intent, and property tests for [`GenJoint`] and its helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra::Matrix3;
    use std::f32::consts::FRAC_PI_2;

    #[test]
    fn test_dof_and_npos() {
        assert_eq!(GenJoint::Fixed.dof(), 0);
        assert_eq!(GenJoint::Fixed.n_pos(), 0);
        let rev = GenJoint::Revolute {
            axis: Vector3::z(),
        };
        assert_eq!(rev.dof(), 1);
        assert_eq!(rev.n_pos(), 1);
        let pris = GenJoint::Prismatic {
            axis: Vector3::x(),
        };
        assert_eq!(pris.dof(), 1);
        assert_eq!(pris.n_pos(), 1);
        assert_eq!(GenJoint::Spherical.dof(), 3);
        assert_eq!(GenJoint::Spherical.n_pos(), 4);
        assert_eq!(GenJoint::Floating.dof(), 6);
        assert_eq!(GenJoint::Floating.n_pos(), 7);
        let planar = GenJoint::Planar {
            normal: Vector3::z(),
        };
        assert_eq!(planar.dof(), 3);
        assert_eq!(planar.n_pos(), 3);
    }

    #[test]
    fn test_revolute_transform_90deg() {
        let jt = GenJoint::Revolute {
            axis: Vector3::z(),
        };
        let x = jt.transform(&[FRAC_PI_2]);
        let p = x.rotation * Vector3::x();
        assert_relative_eq!(p.y, 1.0, epsilon = 1e-5);
        assert!(p.x.abs() < 1e-5);
    }

    #[test]
    fn test_prismatic_transform() {
        let jt = GenJoint::Prismatic {
            axis: Vector3::x(),
        };
        let x = jt.transform(&[0.5]);
        assert_relative_eq!(x.translation.x, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_fixed_identity() {
        let x = GenJoint::Fixed.transform(&[]);
        assert!((x.rotation - Matrix3::identity()).norm() < 1e-10);
    }

    #[test]
    fn test_spherical_identity_at_unit_quat() {
        let x = GenJoint::Spherical.transform(&[1.0, 0.0, 0.0, 0.0]);
        assert!((x.rotation - Matrix3::identity()).norm() < 1e-6);
    }

    #[test]
    fn test_spherical_90deg_z() {
        let half_angle = FRAC_PI_2 / 2.0;
        let w = half_angle.cos();
        let z = half_angle.sin();
        let x = GenJoint::Spherical.transform(&[w, 0.0, 0.0, z]);
        let p = x.rotation * Vector3::x();
        assert_relative_eq!(p.y, 1.0, epsilon = 1e-5);
        assert!(p.x.abs() < 1e-5);
    }

    #[test]
    fn test_floating_identity() {
        let x = GenJoint::Floating.transform(&[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
        assert!((x.rotation - Matrix3::identity()).norm() < 1e-6);
        assert!(x.translation.norm() < 1e-10);
    }

    #[test]
    fn test_floating_translation() {
        let x = GenJoint::Floating.transform(&[1.0, 2.0, 3.0, 1.0, 0.0, 0.0, 0.0]);
        assert_relative_eq!(x.translation.x, 1.0, epsilon = 1e-6);
        assert_relative_eq!(x.translation.y, 2.0, epsilon = 1e-6);
        assert_relative_eq!(x.translation.z, 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_transform_inverse_round_trip() {
        let joints_and_qs: Vec<(GenJoint, Vec<f32>)> = vec![
            (GenJoint::Revolute { axis: Vector3::z() }, vec![0.7]),
            (GenJoint::Prismatic { axis: Vector3::x() }, vec![1.5]),
            (GenJoint::Spherical, vec![0.9239, 0.0, 0.0, 0.3827]),
            (
                GenJoint::Floating,
                vec![1.0, 2.0, 3.0, 0.9239, 0.0, 0.0, 0.3827],
            ),
            (GenJoint::Planar { normal: Vector3::z() }, vec![1.0, 2.0, 0.5]),
        ];
        for (joint, q) in &joints_and_qs {
            let x = joint.transform(q);
            let x_inv = x.inverse();
            let composed = x.compose(&x_inv);
            assert!(
                (composed.rotation - Matrix3::identity()).norm() < 1e-5,
                "Round-trip failed for {joint:?}"
            );
            assert!(
                composed.translation.norm() < 1e-5,
                "Round-trip translation failed for {joint:?}"
            );
        }
    }

    #[test]
    fn test_motion_subspace_dimensions() {
        let joints: Vec<GenJoint> = vec![
            GenJoint::Fixed,
            GenJoint::Revolute { axis: Vector3::z() },
            GenJoint::Prismatic { axis: Vector3::y() },
            GenJoint::Spherical,
            GenJoint::Floating,
            GenJoint::Planar { normal: Vector3::z() },
        ];
        for j in &joints {
            let s = j.motion_subspace(&j.default_q());
            assert_eq!(s.len(), j.dof(), "S columns should equal dof() for {j:?}");
        }
    }

    #[test]
    fn test_motion_subspace_revolute() {
        let jt = GenJoint::Revolute {
            axis: Vector3::z(),
        };
        let s = jt.motion_subspace(&[0.0]);
        assert_eq!(s.len(), 1);
        assert_relative_eq!(s[0].angular().z, 1.0, epsilon = 1e-10);
        assert!(s[0].linear().norm() < 1e-10);
    }

    #[test]
    fn test_motion_subspace_prismatic() {
        let jt = GenJoint::Prismatic {
            axis: Vector3::y(),
        };
        let s = jt.motion_subspace(&[0.0]);
        assert!(s[0].angular().norm() < 1e-10);
        assert_relative_eq!(s[0].linear().y, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_velocity_round_trip_revolute() {
        let jt = GenJoint::Revolute {
            axis: Vector3::z(),
        };
        let q = vec![0.5];
        let omega = vec![1.0];
        let qdot = jt.qdot_from_velocity(&q, &omega);
        let omega_back = jt.velocity_from_qdot(&q, &qdot);
        assert_relative_eq!(omega_back[0], omega[0], epsilon = 1e-10);
    }

    #[test]
    fn test_velocity_round_trip_spherical() {
        let jt = GenJoint::Spherical;
        let q = vec![1.0, 0.0, 0.0, 0.0];
        let omega = vec![0.5, -0.3, 0.7];
        let qdot = jt.qdot_from_velocity(&q, &omega);
        let omega_back = jt.velocity_from_qdot(&q, &qdot);
        for i in 0..3 {
            assert_relative_eq!(omega_back[i], omega[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_velocity_round_trip_floating() {
        let jt = GenJoint::Floating;
        let q = vec![1.0, 2.0, 3.0, 1.0, 0.0, 0.0, 0.0];
        let v = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let qdot = jt.qdot_from_velocity(&q, &v);
        let v_back = jt.velocity_from_qdot(&q, &qdot);
        for i in 0..6 {
            assert_relative_eq!(v_back[i], v[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_normalize_spherical() {
        let jt = GenJoint::Spherical;
        let mut q = vec![2.0, 0.0, 0.0, 0.0];
        jt.normalize_q(&mut q);
        let norm = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
        assert_relative_eq!(norm, 1.0, epsilon = 1e-7);
    }

    #[test]
    fn test_normalize_floating() {
        let jt = GenJoint::Floating;
        let mut q = vec![1.0, 2.0, 3.0, 0.5, 0.5, 0.5, 0.5];
        jt.normalize_q(&mut q);
        assert_relative_eq!(q[0], 1.0);
        assert_relative_eq!(q[1], 2.0);
        assert_relative_eq!(q[2], 3.0);
        let norm = (q[3] * q[3] + q[4] * q[4] + q[5] * q[5] + q[6] * q[6]).sqrt();
        assert_relative_eq!(norm, 1.0, epsilon = 1e-7);
    }

    #[test]
    fn test_default_q_identity_transform() {
        let joints: Vec<GenJoint> = vec![
            GenJoint::Fixed,
            GenJoint::Revolute { axis: Vector3::z() },
            GenJoint::Prismatic { axis: Vector3::x() },
            GenJoint::Spherical,
            GenJoint::Floating,
            GenJoint::Planar { normal: Vector3::z() },
        ];
        for j in &joints {
            let default_q = j.default_q();
            assert_eq!(default_q.len(), j.n_pos(), "default_q len for {j:?}");
            let x = j.transform(&default_q);
            assert!(
                (x.rotation - Matrix3::identity()).norm() < 1e-6,
                "default_q should give identity rotation for {j:?}"
            );
            assert!(
                x.translation.norm() < 1e-10,
                "default_q should give zero translation for {j:?}"
            );
        }
    }

    // A single `#[test]` attribute: duplicating it registers the test twice.
    #[test]
    fn intent_joint_dof_matches_motion_subspace_len() {
        let joints = [
            GenJoint::Fixed,
            GenJoint::Revolute { axis: Vector3::z() },
            GenJoint::Prismatic { axis: Vector3::x() },
            GenJoint::Spherical,
            GenJoint::Floating,
            GenJoint::Planar { normal: Vector3::y() },
        ];
        for j in &joints {
            let q = j.default_q();
            let s = j.motion_subspace(&q);
            assert_eq!(
                s.len(), j.dof(),
                "{:?}: motion_subspace len ({}) != dof ({})", j, s.len(), j.dof()
            );
        }
    }

    #[test]
    fn intent_joint_npos_geq_dof() {
        let joints = [
            GenJoint::Fixed,
            GenJoint::Revolute { axis: Vector3::z() },
            GenJoint::Prismatic { axis: Vector3::x() },
            GenJoint::Spherical,
            GenJoint::Floating,
            GenJoint::Planar { normal: Vector3::y() },
        ];
        for j in &joints {
            assert!(
                j.n_pos() >= j.dof(),
                "{:?}: n_pos ({}) must be >= dof ({})", j, j.n_pos(), j.dof()
            );
        }
    }

    #[test]
    fn intent_spherical_quaternion_normalize_preserves_unit() {
        let j = GenJoint::Spherical;
        let mut q = vec![1.0_f32, 0.0, 0.0, 0.0];
        j.normalize_q(&mut q);
        let norm: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "unit quat preserved: norm={}", norm);
    }

    #[test]
    fn intent_spherical_quaternion_normalize_fixes_unnormalized() {
        let j = GenJoint::Spherical;
        let mut q = vec![2.0_f32, 0.0, 0.0, 0.0];
        j.normalize_q(&mut q);
        let norm: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "unnormalized fixed: norm={}", norm);
    }

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_revolute_transform_at_zero_is_identity(
            ax in -1.0f32..1.0, ay in -1.0f32..1.0, az in 0.1f32..1.0,
        ) {
            let len = (ax * ax + ay * ay + az * az).sqrt();
            let axis = Vector3::new(ax / len, ay / len, az / len);
            let joint = GenJoint::Revolute { axis };
            let t = joint.transform(&[0.0]);
            let diff = (t.rotation - Matrix3::identity()).norm();
            prop_assert!(diff < 1e-5, "Transform at q=0 should be identity, got rotation diff={diff}");
            prop_assert!(t.translation.norm() < 1e-5, "Translation at q=0 should be zero");
        }

        #[test]
        fn prop_revolute_velocity_roundtrip(
            angle in -3.0f32..3.0,
            vel in -10.0f32..10.0,
        ) {
            let joint = GenJoint::Revolute { axis: Vector3::z() };
            let q = vec![angle];
            let v_in = vec![vel];
            let qdot = joint.qdot_from_velocity(&q, &v_in);
            let v_out = joint.velocity_from_qdot(&q, &qdot);
            prop_assert!((v_out[0] - vel).abs() < 1e-4,
                "Velocity roundtrip: in={vel}, out={}", v_out[0]);
        }

        #[test]
        fn prop_prismatic_default_q_gives_identity(
            ax in -1.0f32..1.0, ay in -1.0f32..1.0, az in 0.1f32..1.0,
        ) {
            let len = (ax * ax + ay * ay + az * az).sqrt();
            let axis = Vector3::new(ax / len, ay / len, az / len);
            let joint = GenJoint::Prismatic { axis };
            let q = joint.default_q();
            let t = joint.transform(&q);
            prop_assert!(t.translation.norm() < 1e-5, "Default prismatic should have zero translation");
        }

        #[test]
        fn prop_motion_subspace_dimension_equals_dof(
            angle in -3.0f32..3.0,
        ) {
            let joints = vec![
                GenJoint::Revolute { axis: Vector3::z() },
                GenJoint::Prismatic { axis: Vector3::y() },
                GenJoint::Spherical,
                GenJoint::Fixed,
            ];
            for j in &joints {
                // Feed the sampled value to 1-DOF joints so the property is also
                // exercised away from the default configuration.
                let q = if j.n_pos() == 1 { vec![angle] } else { j.default_q() };
                let s = j.motion_subspace(&q);
                prop_assert_eq!(s.len(), j.dof(),
                    "{:?}: motion subspace columns ({}) != dof ({})", j, s.len(), j.dof());
            }
        }
    }

    #[test]
    fn intent_all_joints_default_q_produces_identity_transform() {
        let joints: Vec<GenJoint> = vec![
            GenJoint::Fixed,
            GenJoint::Revolute { axis: Vector3::z() },
            GenJoint::Prismatic { axis: Vector3::y() },
            GenJoint::Spherical,
            GenJoint::Floating,
            GenJoint::Planar { normal: Vector3::y() },
        ];
        for j in &joints {
            let q = j.default_q();
            let t = j.transform(&q);
            let rot_diff = (t.rotation - Matrix3::identity()).norm();
            let trans_diff = t.translation.norm();
            assert!(rot_diff < 1e-4, "{:?}: default rotation should be identity, diff={rot_diff}", j);
            assert!(trans_diff < 1e-4, "{:?}: default translation should be zero, diff={trans_diff}", j);
        }
    }

    #[test]
    fn test_planar_joint_dof_and_npos() {
        let j = GenJoint::Planar { normal: Vector3::y() };
        assert_eq!(j.dof(), 3, "planar = 3 DOF");
        assert_eq!(j.n_pos(), 3, "planar = 3 pos params");
    }

    #[test]
    fn test_floating_joint_velocity_roundtrip() {
        let j = GenJoint::Floating;
        let q = j.default_q();
        let v = vec![1.0_f32, 2.0, 3.0, 0.1, 0.2, 0.3];
        let qdot = j.qdot_from_velocity(&q, &v);
        let v_back = j.velocity_from_qdot(&q, &qdot);
        for i in 0..6 {
            assert!((v_back[i] - v[i]).abs() < 0.01,
                "floating velocity roundtrip failed at [{i}]: in={}, out={}", v[i], v_back[i]);
        }
    }

    #[test]
    fn test_spherical_joint_3dof_motion_subspace() {
        let j = GenJoint::Spherical;
        let q = j.default_q();
        let s = j.motion_subspace(&q);
        assert_eq!(s.len(), 3, "spherical should have 3 motion subspace columns");
    }
}