1use nabled_core::scalar::NabledReal;
4
5use crate::ModelError;
6
/// The kind of a joint.
///
/// This is a plain tag with no payload, so it is `Copy`, `Eq`, and `Hash`. That makes it usable
/// directly as a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JointType {
    /// A revolute joint (presumably rotation about its [`JointAxis`]; confirm with the
    /// kinematics code).
    Revolute,
    /// A prismatic joint (presumably translation along its [`JointAxis`]; confirm with the
    /// kinematics code).
    Prismatic,
    /// A fixed joint.
    Fixed,
}
13
14#[derive(Debug, Clone, Copy, PartialEq)]
15pub enum JointAxis {
16 X,
17 Y,
18 Z,
19 Custom([f64; 3]),
20}
21
22impl JointAxis {
23 #[must_use]
25 pub fn unit_vector<T: NabledReal>(&self) -> [T; 3] {
26 match self {
27 JointAxis::X => [T::one(), T::zero(), T::zero()],
28 JointAxis::Y => [T::zero(), T::one(), T::zero()],
29 JointAxis::Z => [T::zero(), T::zero(), T::one()],
30 JointAxis::Custom(v) => {
31 let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
32 if norm <= 1e-12 {
33 return [T::zero(), T::zero(), T::one()];
34 }
35 [
36 T::from_f64(v[0] / norm).unwrap_or(T::zero()),
37 T::from_f64(v[1] / norm).unwrap_or(T::zero()),
38 T::from_f64(v[2] / norm).unwrap_or(T::zero()),
39 ]
40 }
41 }
42 }
43
44 #[must_use]
46 pub fn from_xyz(x: f64, y: f64, z: f64) -> Self {
47 let norm = (x * x + y * y + z * z).sqrt();
48 if norm <= 1e-12 {
49 return JointAxis::Z;
50 }
51 let nx = x / norm;
52 let ny = y / norm;
53 let nz = z / norm;
54 if (nx - 1.0).abs() < 1e-6 && ny.abs() < 1e-6 && nz.abs() < 1e-6 {
55 JointAxis::X
56 } else if nx.abs() < 1e-6 && (ny - 1.0).abs() < 1e-6 && nz.abs() < 1e-6 {
57 JointAxis::Y
58 } else if nx.abs() < 1e-6 && ny.abs() < 1e-6 && (nz - 1.0).abs() < 1e-6 {
59 JointAxis::Z
60 } else {
61 JointAxis::Custom([nx, ny, nz])
62 }
63 }
64}
65
66#[derive(Debug, Clone, PartialEq)]
67pub struct JointLimits<T> {
68 pub lower: T,
69 pub upper: T,
70 pub velocity: T,
71 pub effort: T,
72}
73
74pub fn validate_limits<T: PartialOrd>(limits: &JointLimits<T>) -> Result<(), ModelError> {
76 if limits.lower > limits.upper {
77 return Err(ModelError::InvalidInput("lower limit exceeds upper limit".to_string()));
78 }
79 Ok(())
80}
81
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Absolute tolerance shared by every floating-point comparison in this module.
    const EPS: f64 = 1e-12;

    /// Euclidean length of a 3-vector, used to confirm normalization.
    fn norm3(v: [f64; 3]) -> f64 {
        v.iter().map(|c| c * c).sum::<f64>().sqrt()
    }

    #[test]
    fn cardinal_axes_are_unit_length() {
        let axes = [JointAxis::X, JointAxis::Y, JointAxis::Z];
        for length in axes.iter().map(|axis| norm3(axis.unit_vector::<f64>())) {
            assert_relative_eq!(length, 1.0, epsilon = EPS);
        }
    }

    #[test]
    fn from_xyz_normalizes_custom_axis() {
        let unit = JointAxis::from_xyz(0.0, 2.0, 0.0).unit_vector::<f64>();
        assert_relative_eq!(unit[1], 1.0, epsilon = EPS);
    }

    #[test]
    fn limits_validate_ordering() {
        let make = |lower: f64, upper: f64| JointLimits { lower, upper, velocity: 2.0, effort: 10.0 };
        validate_limits(&make(-1.0, 1.0)).unwrap();
        assert!(validate_limits(&make(1.0, -1.0)).is_err());
    }

    #[test]
    fn from_xyz_maps_cardinal_axes() {
        let cases = [
            ([1.0, 0.0, 0.0], JointAxis::X),
            ([0.0, 1.0, 0.0], JointAxis::Y),
            ([0.0, 0.0, 1.0], JointAxis::Z),
        ];
        for &([x, y, z], expected) in &cases {
            assert_eq!(JointAxis::from_xyz(x, y, z), expected);
        }
    }

    #[test]
    fn zero_axis_falls_back_to_z() {
        let fallback = JointAxis::from_xyz(0.0, 0.0, 0.0);
        assert_eq!(fallback, JointAxis::Z);
        assert_relative_eq!(fallback.unit_vector::<f64>()[2], 1.0, epsilon = EPS);
    }

    #[test]
    fn custom_axis_normalizes_non_unit_input() {
        let [x, _, z] = JointAxis::Custom([3.0, 0.0, 4.0]).unit_vector::<f64>();
        assert_relative_eq!(x, 0.6, epsilon = EPS);
        assert_relative_eq!(z, 0.8, epsilon = EPS);
    }
}