1use std::collections::{HashMap, VecDeque};
4
5use nabled_linalg::geometry::Transform3;
6
7use crate::ModelError;
8use crate::joint::{JointLimits, JointType};
9use crate::link::{InertialSpec, LinkSpec};
10
/// Per-joint Denavit–Hartenberg-style parameters.
///
/// NOTE(review): the field names follow DH notation. This file does not establish
/// which convention is intended (classic vs. modified/Craig), nor the length and
/// angle units — confirm against the code that consumes these values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DhParams<T> {
    /// DH `a` parameter (conventionally the link length).
    pub a: T,
    /// DH `alpha` parameter (conventionally the link twist).
    pub alpha: T,
    /// DH `d` parameter (conventionally the link offset).
    pub d: T,
    /// Offset applied to the joint angle `theta` (presumably added to the joint
    /// variable — TODO confirm).
    pub theta_offset: T,
}
24
/// One body of a [`RobotModel`]: a link together with the joint that attaches it
/// to its parent.
#[derive(Debug, Clone, PartialEq)]
pub struct BodySpec<T> {
    /// The link this body introduces; its `name` is the key used by
    /// `RobotModel::body_index_for_link`.
    pub link: LinkSpec,
    /// Name of the parent link. `extract_chain` stops walking up the tree when this
    /// equals the requested base link. Nothing in this file cross-checks it against
    /// the parent body index stored in the model.
    pub parent_link: String,
    /// Joint kind; `JointType::Fixed` joints are excluded from `dof` and
    /// `actuated_indices`.
    pub joint_type: JointType,
    /// Joint axis.
    pub axis: crate::joint::JointAxis,
    /// Optional joint limits, as returned by `RobotModel::limits_for_joint`.
    pub limits: Option<JointLimits<T>>,
    /// Optional inertial properties of the link.
    pub inertial: Option<InertialSpec<T>>,
    /// Joint-frame transform (presumably relative to the parent link frame — TODO
    /// confirm).
    pub joint_origin: Transform3<T>,
    /// Optional DH parameters for this joint.
    pub dh_params: Option<DhParams<T>>,
}
41
/// A collection of bodies addressed by insertion index, intended to form a tree
/// or forest.
///
/// `bodies` and `parents` are parallel vectors, indexed by body index.
/// `RobotModel::validate` checks that parent indices are in range and acyclic.
#[derive(Debug, Clone, PartialEq)]
pub struct RobotModel<T> {
    /// Body specifications in insertion order.
    bodies: Vec<BodySpec<T>>,
    /// Parent body index for each body; `None` marks a root.
    parents: Vec<Option<usize>>,
    /// Link name -> body index. When names repeat, the most recent insertion
    /// (via `add_body` or `update_body`) wins.
    link_to_body: HashMap<String, usize>,
}
48
49impl<T: Clone> Default for RobotModel<T> {
50 fn default() -> Self { Self::new() }
51}
52
53impl<T: Clone> RobotModel<T> {
54 #[must_use]
55 pub fn new() -> Self {
56 Self { bodies: Vec::new(), parents: Vec::new(), link_to_body: HashMap::new() }
57 }
58
59 pub fn add_body(&mut self, parent: Option<usize>, body: BodySpec<T>) -> usize {
60 let index = self.bodies.len();
61 let _ = self.link_to_body.insert(body.link.name.clone(), index);
62 self.bodies.push(body);
63 self.parents.push(parent);
64 index
65 }
66
67 #[must_use]
68 pub fn parent(&self, index: usize) -> Option<usize> {
69 self.parents.get(index).copied().flatten()
70 }
71
72 #[must_use]
73 pub fn joint(&self, index: usize) -> Option<&BodySpec<T>> { self.bodies.get(index) }
74
75 #[must_use]
76 pub fn body_index_for_link(&self, link_name: &str) -> Option<usize> {
77 self.link_to_body.get(link_name).copied()
78 }
79
80 #[must_use]
81 pub fn dof(&self) -> usize {
82 self.bodies.iter().filter(|b| !matches!(b.joint_type, JointType::Fixed)).count()
83 }
84
85 pub fn validate(&self) -> Result<(), ModelError> {
86 if self.bodies.is_empty() {
87 return Err(ModelError::EmptyModel);
88 }
89 for (i, parent) in self.parents.iter().enumerate() {
90 if let Some(p) = parent
91 && *p >= self.bodies.len()
92 {
93 return Err(ModelError::InvalidInput(format!(
94 "parent index {p} out of range for body {i}"
95 )));
96 }
97 }
98 self.validate_acyclic()?;
99 Ok(())
100 }
101
102 fn validate_acyclic(&self) -> Result<(), ModelError> {
103 for start in 0..self.bodies.len() {
104 let mut current = Some(start);
105 let mut steps = 0usize;
106 while let Some(idx) = current {
107 steps += 1;
108 if steps > self.bodies.len() {
109 return Err(ModelError::InvalidInput(
110 "cycle detected in robot tree".to_string(),
111 ));
112 }
113 current = self.parent(idx);
114 }
115 }
116 Ok(())
117 }
118
119 #[must_use]
121 pub fn topological_order(&self) -> Vec<usize> {
122 let mut roots: Vec<usize> = self
123 .parents
124 .iter()
125 .enumerate()
126 .filter_map(|(i, p)| if p.is_none() { Some(i) } else { None })
127 .collect();
128 if roots.is_empty() && !self.bodies.is_empty() {
129 roots.push(0);
130 }
131 let mut order = Vec::with_capacity(self.bodies.len());
132 let mut queue: VecDeque<usize> = roots.into();
133 let mut seen = vec![false; self.bodies.len()];
134 while let Some(index) = queue.pop_front() {
135 if seen[index] {
136 continue;
137 }
138 seen[index] = true;
139 order.push(index);
140 for (child, parent) in self.parents.iter().enumerate() {
141 if parent == &Some(index) && !seen[child] {
142 queue.push_back(child);
143 }
144 }
145 }
146 for (i, was_seen) in seen.iter().enumerate() {
147 if !was_seen {
148 order.push(i);
149 }
150 }
151 order
152 }
153
154 #[must_use]
156 pub fn actuated_indices(&self) -> Vec<usize> {
157 self.topological_order()
158 .into_iter()
159 .filter(|&i| self.joint(i).is_some_and(|b| !matches!(b.joint_type, JointType::Fixed)))
160 .collect()
161 }
162
163 #[must_use]
165 pub fn limits_for_joint(&self, joint_index: usize) -> Option<&JointLimits<T>> {
166 self.actuated_indices()
167 .get(joint_index)
168 .and_then(|&body_index| self.joint(body_index))
169 .and_then(|body| body.limits.as_ref())
170 }
171
172 pub fn update_body(&mut self, index: usize, body: BodySpec<T>) -> Result<(), ModelError> {
173 if index >= self.bodies.len() {
174 return Err(ModelError::InvalidInput(format!("body index {index} out of range")));
175 }
176 let _ = self.link_to_body.insert(body.link.name.clone(), index);
177 self.bodies[index] = body;
178 Ok(())
179 }
180}
181
182pub fn extract_chain<T: Clone>(
184 model: &RobotModel<T>,
185 base_link: &str,
186 ee_link: &str,
187) -> Result<Vec<usize>, ModelError> {
188 let ee = model
189 .body_index_for_link(ee_link)
190 .ok_or_else(|| ModelError::InvalidInput(format!("unknown link {ee_link}")))?;
191
192 let mut chain = Vec::new();
193 let mut current = Some(ee);
194 while let Some(idx) = current {
195 let body = model
196 .joint(idx)
197 .ok_or_else(|| ModelError::InvalidInput(format!("missing body {idx}")))?;
198 chain.push(idx);
199 if body.parent_link == base_link {
200 break;
201 }
202 current = model.parent(idx);
203 }
204
205 let reached_base =
206 model.joint(*chain.last().unwrap()).is_some_and(|body| body.parent_link == base_link);
207 if !reached_base {
208 return Err(ModelError::InvalidInput(format!(
209 "no kinematic path from {base_link} to {ee_link}"
210 )));
211 }
212
213 chain.reverse();
214 Ok(chain)
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::joint::JointAxis;

    /// Builds a revolute body about Z with no limits or inertia. Its origin is
    /// translated by 1.0 along X and its DH parameters have `a = 1.0`.
    fn sample_body(name: &str, parent_link: &str) -> BodySpec<f64> {
        let link = LinkSpec { name: name.to_string() };
        let parent_link = parent_link.to_string();
        let joint_origin = crate::origin::transform_from_xyz_rpy([1.0, 0.0, 0.0], [0.0, 0.0, 0.0])
            .expect("valid origin");
        let dh = DhParams { a: 1.0, alpha: 0.0, d: 0.0, theta_offset: 0.0 };
        BodySpec {
            link,
            parent_link,
            joint_type: JointType::Revolute,
            axis: JointAxis::Z,
            limits: None,
            inertial: None,
            joint_origin,
            dh_params: Some(dh),
        }
    }

    #[test]
    fn topological_order_is_bfs() {
        let mut robot = RobotModel::new();
        let first = robot.add_body(None, sample_body("link1", "base"));
        let _second = robot.add_body(Some(first), sample_body("link2", "link1"));
        let expected = vec![0, 1];
        assert_eq!(robot.topological_order(), expected);
    }

    #[test]
    fn extract_chain_finds_serial_path() {
        let mut robot = RobotModel::new();
        let shoulder = robot.add_body(None, sample_body("link1", "base"));
        let _elbow = robot.add_body(Some(shoulder), sample_body("link2", "link1"));
        let chain = extract_chain(&robot, "base", "link2").unwrap();
        assert_eq!(chain, vec![0, 1]);
    }

    #[test]
    fn rejects_disconnected_extract_chain() {
        let mut robot = RobotModel::new();
        for (name, parent) in [("a", "base"), ("b", "other_base")] {
            let _ = robot.add_body(None, sample_body(name, parent));
        }
        let outcome = extract_chain(&robot, "base", "b");
        assert!(outcome.is_err());
    }
}