nabled_model/
tree_model.rs1use nabled_core::scalar::NabledReal;
4use nabled_kinematics::error::KinematicsError;
5use nabled_kinematics::tree::{KinematicTreeModel, TreeJointType};
6use nabled_linalg::geometry::Transform3;
7
8use crate::joint::JointType;
9use crate::robot::{RobotModel, extract_chain};
10
11impl<T: NabledReal + Clone> KinematicTreeModel<T> for RobotModel<T> {
12 fn validate_tree(&self) -> Result<(), KinematicsError> {
13 self.validate().map_err(|err| KinematicsError::InvalidInput(err.to_string()))
14 }
15
16 fn dof(&self) -> usize { RobotModel::dof(self) }
17
18 fn actuated_indices(&self) -> Vec<usize> { RobotModel::actuated_indices(self) }
19
20 fn topological_order(&self) -> Vec<usize> { RobotModel::topological_order(self) }
21
22 fn body_index_for_link(&self, link_name: &str) -> Option<usize> {
23 RobotModel::body_index_for_link(self, link_name)
24 }
25
26 fn parent_link(&self, body_index: usize) -> &str {
27 &self.joint(body_index).expect("valid body index").parent_link
28 }
29
30 fn child_link(&self, body_index: usize) -> &str {
31 &self.joint(body_index).expect("valid body index").link.name
32 }
33
34 fn joint_type(&self, body_index: usize) -> TreeJointType {
35 match self.joint(body_index).expect("valid body index").joint_type {
36 JointType::Revolute => TreeJointType::Revolute,
37 JointType::Prismatic => TreeJointType::Prismatic,
38 JointType::Fixed => TreeJointType::Fixed,
39 }
40 }
41
42 fn joint_origin(&self, body_index: usize) -> &Transform3<T> {
43 &self.joint(body_index).expect("valid body index").joint_origin
44 }
45
46 fn joint_axis(&self, body_index: usize) -> [T; 3] {
47 self.joint(body_index).expect("valid body index").axis.unit_vector()
48 }
49
50 fn chain_indices(&self, base_link: &str, ee_link: &str) -> Result<Vec<usize>, KinematicsError> {
51 extract_chain(self, base_link, ee_link)
52 .map_err(|err| KinematicsError::InvalidInput(err.to_string()))
53 }
54
55 fn joint_limits(&self, joint_index: usize) -> Option<(T, T)> {
56 self.limits_for_joint(joint_index).map(|limits| (limits.lower, limits.upper))
57 }
58}
59
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    use nabled_kinematics::tree::{KinematicTreeModel, TreeJointType};

    use super::*;
    use crate::urdf::from_urdf_file;

    // NOTE: `RobotModel` also has inherent `dof`, `actuated_indices`,
    // `topological_order` and `body_index_for_link` methods; the trait impl in this
    // file delegates to them. Method-call syntax (`model.dof()`) resolves to inherent
    // methods before trait methods. Those calls are therefore routed through the trait
    // path (`KinematicTreeModel::<f64>::…`), so the `KinematicTreeModel` impl is what
    // these tests actually exercise.

    /// Loads the Y-branch URDF fixture shared by every test in this module.
    ///
    /// # Panics
    ///
    /// Panics if `from_urdf_file` fails on the fixture.
    fn y_branch_model() -> RobotModel<f64> {
        let urdf_path = concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../nabled/tests/fixtures/physical_ai/Y_branch.urdf"
        );
        from_urdf_file(urdf_path).expect("Y-branch URDF")
    }

    /// Trait-level metadata (validation, DOF, index orderings, link lookup) for the
    /// Y-branch fixture.
    #[test]
    fn y_branch_tree_model_metadata() {
        let model = y_branch_model();
        model.validate_tree().expect("valid tree");
        assert_eq!(KinematicTreeModel::<f64>::dof(&model), 3);
        assert_eq!(KinematicTreeModel::<f64>::actuated_indices(&model), vec![0, 1, 2]);
        assert_eq!(KinematicTreeModel::<f64>::topological_order(&model), vec![0, 1, 2]);
        assert_eq!(KinematicTreeModel::<f64>::body_index_for_link(&model, "left_ee"), Some(1));
        assert_eq!(KinematicTreeModel::<f64>::body_index_for_link(&model, "right_ee"), Some(2));
    }

    /// Both branches share body 0 and then diverge to their own end-effector body.
    #[test]
    fn y_branch_chain_indices_for_branches() {
        let model = y_branch_model();
        let left = model.chain_indices("base", "left_ee").expect("left chain");
        let right = model.chain_indices("base", "right_ee").expect("right chain");
        assert_eq!(left, vec![0, 1]);
        assert_eq!(right, vec![0, 2]);
    }

    /// Every joint of the fixture is limited to roughly `[-pi, pi]`.
    #[test]
    fn y_branch_joint_limits_from_urdf() {
        let model = y_branch_model();
        let dof = KinematicTreeModel::<f64>::dof(&model);
        for joint_index in 0..dof {
            let (lower, upper) = model.joint_limits(joint_index).expect("limits");
            assert!(lower < upper);
            assert_relative_eq!(lower, -std::f64::consts::PI, epsilon = 1e-4);
            assert_relative_eq!(upper, std::f64::consts::PI, epsilon = 1e-4);
        }
    }

    /// An unknown end link surfaces as `KinematicsError::InvalidInput`.
    #[test]
    fn rejects_unknown_link_in_chain_indices() {
        let model = y_branch_model();
        let err = model.chain_indices("base", "missing").unwrap_err();
        assert!(matches!(err, KinematicsError::InvalidInput(_)));
    }

    /// Smoke-tests every per-body accessor on each actuated body.
    #[test]
    fn trait_accessors_exercise_all_body_fields() {
        let model = y_branch_model();
        for body_index in KinematicTreeModel::<f64>::actuated_indices(&model) {
            assert!(!model.parent_link(body_index).is_empty());
            assert!(!model.child_link(body_index).is_empty());
            assert!(matches!(
                model.joint_type(body_index),
                TreeJointType::Revolute | TreeJointType::Prismatic | TreeJointType::Fixed
            ));
            let origin = model.joint_origin(body_index);
            assert!(origin.translation.iter().all(|v| v.is_finite()));
            let axis = model.joint_axis(body_index);
            let axis_len = (axis[0] * axis[0] + axis[1] * axis[1] + axis[2] * axis[2]).sqrt();
            assert_relative_eq!(axis_len, 1.0, epsilon = 1e-6);
        }
    }
}