1use std::collections::HashMap;
11
12use nabled_core::scalar::NabledReal;
13use nabled_kinematics::chain::{ChainSpec, DhConvention, JointType as KinJointType};
14use ndarray::Array1;
15
16use crate::ModelError;
17use crate::joint::JointType;
18use crate::robot::{RobotModel, extract_chain};
19
20fn chain_spec_from_indices<T: NabledReal + Default>(
21 model: &RobotModel<T>,
22 indices: &[usize],
23) -> Result<ChainSpec<T>, ModelError> {
24 model.validate()?;
25 let mut joint_types = Vec::new();
26 let mut a = Vec::new();
27 let mut alpha = Vec::new();
28 let mut d = Vec::new();
29 let mut theta_offset = Vec::new();
30 for &index in indices {
31 let body = model.joint(index).ok_or(ModelError::EmptyModel)?;
32 if matches!(body.joint_type, JointType::Fixed) {
33 continue;
34 }
35 let dh = body.dh_params.ok_or_else(|| {
36 ModelError::InvalidInput(format!(
37 "body {} (link '{}') has no DH parameters; URDF-derived models must use \
38 nabled-kinematics::tree (tree FK/Jacobian/IK) or be loaded via a fixture that \
39 provides explicit DH parameters",
40 index, body.link.name
41 ))
42 })?;
43 joint_types.push(match body.joint_type {
44 JointType::Revolute => KinJointType::Revolute,
45 JointType::Prismatic => KinJointType::Prismatic,
46 JointType::Fixed => unreachable!(),
47 });
48 a.push(dh.a);
49 alpha.push(dh.alpha);
50 d.push(dh.d);
51 theta_offset.push(dh.theta_offset);
52 }
53 ChainSpec::from_dh(
54 DhConvention::Standard,
55 joint_types,
56 Array1::from(a),
57 Array1::from(alpha),
58 Array1::from(d),
59 Array1::from(theta_offset),
60 )
61 .map_err(|_| ModelError::DimensionMismatch)
62}
63
64pub fn to_chain_spec<T: NabledReal + Default>(
72 model: &RobotModel<T>,
73) -> Result<ChainSpec<T>, ModelError> {
74 let order = model.topological_order();
75 chain_spec_from_indices(model, &order)
76}
77
78pub fn extract_chain_spec<T: NabledReal + Default>(
85 model: &RobotModel<T>,
86 base_link: &str,
87 ee_link: &str,
88) -> Result<ChainSpec<T>, ModelError> {
89 let indices = extract_chain(model, base_link, ee_link)?;
90 chain_spec_from_indices(model, &indices)
91}
92
93#[derive(Debug, Clone, PartialEq)]
100pub struct DynamicsBranchSpec<T> {
101 pub chain: ChainSpec<T>,
102 pub q_indices: Vec<usize>,
104 pub body_indices: Vec<usize>,
107}
108
109impl<T: Clone> DynamicsBranchSpec<T> {
110 pub fn branch_q(&self, model: &RobotModel<T>, q: &Array1<T>) -> Result<Array1<T>, ModelError>
116 where
117 T: NabledReal,
118 {
119 if q.len() != model.dof() {
120 return Err(ModelError::DimensionMismatch);
121 }
122 if self.q_indices.len() != self.chain.num_joints() {
123 return Err(ModelError::DimensionMismatch);
124 }
125 Ok(Array1::from(self.q_indices.iter().map(|&index| q[index]).collect::<Vec<_>>()))
126 }
127}
128
129pub fn extract_chain_spec_for_dynamics<T: NabledReal + Default>(
141 model: &RobotModel<T>,
142 base_link: &str,
143 ee_link: &str,
144) -> Result<DynamicsBranchSpec<T>, ModelError> {
145 let indices = extract_chain(model, base_link, ee_link)?;
146 let chain = chain_spec_from_indices(model, &indices)?;
147
148 let actuated = model.actuated_indices();
149 let actuated_map: HashMap<usize, usize> = actuated
150 .iter()
151 .enumerate()
152 .map(|(joint_index, &body_index)| (body_index, joint_index))
153 .collect();
154
155 let mut q_indices = Vec::new();
156 for &body_index in &indices {
157 let body = model.joint(body_index).ok_or(ModelError::EmptyModel)?;
158 if matches!(body.joint_type, JointType::Fixed) {
159 continue;
160 }
161 let joint_index =
162 actuated_map.get(&body_index).copied().ok_or(ModelError::DimensionMismatch)?;
163 q_indices.push(joint_index);
164 }
165
166 if q_indices.len() != chain.num_joints() {
167 return Err(ModelError::DimensionMismatch);
168 }
169
170 Ok(DynamicsBranchSpec { chain, q_indices, body_indices: indices })
171}
172
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    use ndarray::arr1;

    use super::*;
    use crate::joint::JointAxis;
    use crate::link::LinkSpec;
    use crate::origin::joint_origin_from_dh_scalars;
    use crate::robot::{BodySpec, DhParams};

    /// Revolute, Z-axis body whose DH parameters are `a = 1` with all others zero.
    fn sample_body(name: &str, parent_link: &str) -> BodySpec<f64> {
        BodySpec {
            link: LinkSpec { name: name.to_string() },
            parent_link: parent_link.to_string(),
            joint_type: JointType::Revolute,
            axis: JointAxis::Z,
            limits: None,
            inertial: None,
            joint_origin: joint_origin_from_dh_scalars(1.0, 0.0, 0.0, 0.0).unwrap(),
            dh_params: Some(DhParams { a: 1.0, alpha: 0.0, d: 0.0, theta_offset: 0.0 }),
        }
    }

    /// Same body as `sample_body`, but without DH parameters.
    fn sample_body_without_dh(name: &str, parent_link: &str) -> BodySpec<f64> {
        BodySpec { dh_params: None, ..sample_body(name, parent_link) }
    }

    /// Two-body model: `link1` attached to `base`, `link2` attached to `link1`.
    fn two_link_model(make_body: fn(&str, &str) -> BodySpec<f64>) -> RobotModel<f64> {
        let mut model = RobotModel::new();
        let first = model.add_body(None, make_body("link1", "base"));
        let _ = model.add_body(Some(first), make_body("link2", "link1"));
        model
    }

    #[test]
    fn extract_chain_matches_full_serial_model() {
        let model = two_link_model(sample_body);
        let full = to_chain_spec(&model).unwrap();
        let extracted = extract_chain_spec(&model, "base", "link2").unwrap();
        assert_eq!(full, extracted);
        assert_eq!(full.a, arr1(&[1.0, 1.0]));
    }

    #[test]
    fn dynamics_branch_slices_full_q() {
        let model = two_link_model(sample_body);
        let branch = extract_chain_spec_for_dynamics(&model, "base", "link2").unwrap();
        assert_eq!(branch.chain.num_joints(), 2);
        assert_eq!(branch.q_indices, vec![0, 1]);

        let q = arr1(&[0.2, 0.4]);
        let sliced = branch.branch_q(&model, &q).unwrap();
        assert_relative_eq!(sliced, q, epsilon = 1e-12);
    }

    #[test]
    fn dynamics_branch_rejects_dof_mismatch() {
        let model = two_link_model(sample_body);
        let branch = extract_chain_spec_for_dynamics(&model, "base", "link2").unwrap();
        let too_short = arr1(&[0.2]);
        assert!(branch.branch_q(&model, &too_short).is_err());
    }

    #[test]
    fn dynamics_branch_carries_body_indices() {
        let model = two_link_model(sample_body);
        let branch = extract_chain_spec_for_dynamics(&model, "base", "link2").unwrap();
        assert_eq!(branch.body_indices, vec![0, 1]);
    }

    #[test]
    fn to_chain_spec_rejects_bodies_without_dh_params() {
        let model = two_link_model(sample_body_without_dh);

        let err = to_chain_spec(&model).expect_err("no DH params -> error");
        assert!(
            matches!(err, ModelError::InvalidInput(message) if message.contains("no DH parameters"))
        );

        let err = extract_chain_spec(&model, "base", "link2").expect_err("no DH params -> error");
        assert!(matches!(err, ModelError::InvalidInput(_)));

        let err = extract_chain_spec_for_dynamics(&model, "base", "link2")
            .expect_err("no DH params -> error");
        assert!(matches!(err, ModelError::InvalidInput(_)));
    }
}