Skip to main content

nabled_model/
joint.rs

//! Joint types, axes, and limits.
2
3use nabled_core::scalar::NabledReal;
4
5use crate::ModelError;
6
/// Kind of motion a joint permits.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so joint kinds can be
/// used as keys in hash maps/sets (the enum has no float payloads, so total
/// equality is sound).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JointType {
    /// Rotational joint (conventionally rotation about its [`JointAxis`]).
    Revolute,
    /// Sliding joint (conventionally translation along its [`JointAxis`]).
    Prismatic,
    /// Rigid joint that permits no motion.
    Fixed,
}
13
/// Direction associated with a joint's motion.
///
/// The three cardinal variants name the positive basis directions. `Custom`
/// holds raw `[x, y, z]` components that need not be unit length;
/// `JointAxis::unit_vector` normalizes them whenever it is called.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum JointAxis {
    /// The +X direction, `[1, 0, 0]`.
    X,
    /// The +Y direction, `[0, 1, 0]`.
    Y,
    /// The +Z direction, `[0, 0, 1]`.
    Z,
    /// An arbitrary direction given by its `[x, y, z]` components.
    Custom([f64; 3]),
}
21
22impl JointAxis {
23    /// Unit axis vector for revolute/prismatic motion.
24    #[must_use]
25    pub fn unit_vector<T: NabledReal>(&self) -> [T; 3] {
26        match self {
27            JointAxis::X => [T::one(), T::zero(), T::zero()],
28            JointAxis::Y => [T::zero(), T::one(), T::zero()],
29            JointAxis::Z => [T::zero(), T::zero(), T::one()],
30            JointAxis::Custom(v) => {
31                let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
32                if norm <= 1e-12 {
33                    return [T::zero(), T::zero(), T::one()];
34                }
35                [
36                    T::from_f64(v[0] / norm).unwrap_or(T::zero()),
37                    T::from_f64(v[1] / norm).unwrap_or(T::zero()),
38                    T::from_f64(v[2] / norm).unwrap_or(T::zero()),
39                ]
40            }
41        }
42    }
43
44    /// Parse `<axis xyz="..."/>` into a normalized axis.
45    #[must_use]
46    pub fn from_xyz(x: f64, y: f64, z: f64) -> Self {
47        let norm = (x * x + y * y + z * z).sqrt();
48        if norm <= 1e-12 {
49            return JointAxis::Z;
50        }
51        let nx = x / norm;
52        let ny = y / norm;
53        let nz = z / norm;
54        if (nx - 1.0).abs() < 1e-6 && ny.abs() < 1e-6 && nz.abs() < 1e-6 {
55            JointAxis::X
56        } else if nx.abs() < 1e-6 && (ny - 1.0).abs() < 1e-6 && nz.abs() < 1e-6 {
57            JointAxis::Y
58        } else if nx.abs() < 1e-6 && ny.abs() < 1e-6 && (nz - 1.0).abs() < 1e-6 {
59            JointAxis::Z
60        } else {
61            JointAxis::Custom([nx, ny, nz])
62        }
63    }
64}
65
/// Limits attached to a joint.
///
/// Units are not fixed by this type; presumably they follow the model source
/// (e.g. URDF `<limit>` attributes) — confirm against the parser.
#[derive(Clone, Debug, PartialEq)]
pub struct JointLimits<T> {
    /// Lower bound of the joint range; `validate_limits` rejects `lower > upper`.
    pub lower: T,
    /// Upper bound of the joint range.
    pub upper: T,
    /// Velocity limit (not checked by `validate_limits`).
    pub velocity: T,
    /// Effort limit (not checked by `validate_limits`).
    pub effort: T,
}
73
74/// Validate joint limits ordering.
75pub fn validate_limits<T: PartialOrd>(limits: &JointLimits<T>) -> Result<(), ModelError> {
76    if limits.lower > limits.upper {
77        return Err(ModelError::InvalidInput("lower limit exceeds upper limit".to_string()));
78    }
79    Ok(())
80}
81
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Euclidean length of a 3-vector, used to check normalization.
    fn length(v: [f64; 3]) -> f64 {
        v.iter().map(|c| c * c).sum::<f64>().sqrt()
    }

    #[test]
    fn cardinal_axes_are_unit_length() {
        let cardinals = [JointAxis::X, JointAxis::Y, JointAxis::Z];
        for axis in &cardinals {
            assert_relative_eq!(length(axis.unit_vector::<f64>()), 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn from_xyz_normalizes_custom_axis() {
        let scaled = JointAxis::from_xyz(0.0, 2.0, 0.0).unit_vector::<f64>();
        assert_relative_eq!(scaled[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn limits_validate_ordering() {
        let make = |lower: f64, upper: f64| JointLimits { lower, upper, velocity: 2.0, effort: 10.0 };
        validate_limits(&make(-1.0, 1.0)).unwrap();
        assert!(validate_limits(&make(1.0, -1.0)).is_err());
    }

    #[test]
    fn from_xyz_maps_cardinal_axes() {
        let cases = [
            ((1.0, 0.0, 0.0), JointAxis::X),
            ((0.0, 1.0, 0.0), JointAxis::Y),
            ((0.0, 0.0, 1.0), JointAxis::Z),
        ];
        for &((x, y, z), expected) in &cases {
            assert_eq!(JointAxis::from_xyz(x, y, z), expected);
        }
    }

    #[test]
    fn zero_axis_falls_back_to_z() {
        let fallback = JointAxis::from_xyz(0.0, 0.0, 0.0);
        assert_eq!(fallback, JointAxis::Z);
        assert_relative_eq!(fallback.unit_vector::<f64>()[2], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn custom_axis_normalizes_non_unit_input() {
        let [x, _, z] = JointAxis::Custom([3.0, 0.0, 4.0]).unit_vector::<f64>();
        assert_relative_eq!(x, 0.6, epsilon = 1e-12);
        assert_relative_eq!(z, 0.8, epsilon = 1e-12);
    }
}