piper_client/control/
trajectory.rs1use crate::types::{JointArray, Rad};
42use std::time::Duration;
43
/// Coefficients of the cubic polynomial `a0 + a1*t + a2*t^2 + a3*t^3`.
#[derive(Clone, Copy, Debug)]
struct CubicCoeffs {
    /// Constant term.
    a0: f64,
    /// Coefficient of `t`.
    a1: f64,
    /// Coefficient of `t^2`.
    a2: f64,
    /// Coefficient of `t^3`.
    a3: f64,
}

impl CubicCoeffs {
    /// Evaluates the polynomial at parameter `t`.
    fn position(&self, t: f64) -> f64 {
        // Terms are built and summed in ascending-degree order.
        let linear = self.a1 * t;
        let quadratic = self.a2 * t * t;
        let cubic = self.a3 * t * t * t;
        self.a0 + linear + quadratic + cubic
    }

    /// Evaluates the first derivative of the polynomial with respect to `t`,
    /// i.e. `a1 + 2*a2*t + 3*a3*t^2`.
    fn velocity(&self, t: f64) -> f64 {
        let linear = 2.0 * self.a2 * t;
        let quadratic = 3.0 * self.a3 * t * t;
        self.a1 + linear + quadratic
    }
}
68
69pub struct TrajectoryPlanner {
73 spline_coeffs: JointArray<CubicCoeffs>,
75
76 duration: Duration,
78
79 current_index: usize,
81
82 total_samples: usize,
84}
85
86impl TrajectoryPlanner {
87 pub fn new(
121 start: JointArray<Rad>,
122 end: JointArray<Rad>,
123 duration: Duration,
124 frequency_hz: f64,
125 ) -> Self {
126 assert!(
128 frequency_hz > 0.0,
129 "frequency_hz must be positive, got: {}",
130 frequency_hz
131 );
132
133 let duration_sec = duration.as_secs_f64();
134
135 let v_start = 0.0; let v_end = 0.0; let spline_coeffs = start.map_with(end, |s, e| {
143 Self::compute_cubic_spline(s.0, v_start, e.0, v_end)
144 });
145
146 let total_samples = (duration_sec * frequency_hz).ceil() as usize;
147
148 TrajectoryPlanner {
149 spline_coeffs,
150 duration,
151 current_index: 0,
152 total_samples,
153 }
154 }
155
156 fn compute_cubic_spline(p0: f64, v0: f64, p1: f64, v1: f64) -> CubicCoeffs {
172 let a0 = p0;
179 let a1 = v0;
180
181 let a2 = 3.0 * (p1 - p0) - 2.0 * v0 - v1;
189 let a3 = -2.0 * (p1 - p0) + v0 + v1;
190
191 CubicCoeffs { a0, a1, a2, a3 }
192 }
193
194 fn evaluate_at(&self, t: f64) -> (JointArray<Rad>, JointArray<f64>) {
204 let duration_sec = self.duration.as_secs_f64();
205
206 let position = self.spline_coeffs.map(|coeff| Rad(coeff.position(t)));
207
208 let velocity = self.spline_coeffs.map(|coeff| coeff.velocity(t) / duration_sec);
210
211 (position, velocity)
212 }
213
214 pub fn reset(&mut self) {
216 self.current_index = 0;
217 }
218
219 pub fn total_samples(&self) -> usize {
221 self.total_samples
222 }
223
224 pub fn progress(&self) -> f64 {
226 if self.total_samples == 0 {
227 1.0
228 } else {
229 (self.current_index as f64) / (self.total_samples as f64)
230 }
231 }
232}
233
234impl Iterator for TrajectoryPlanner {
235 type Item = (JointArray<Rad>, JointArray<f64>);
236
237 fn next(&mut self) -> Option<Self::Item> {
238 if self.current_index >= self.total_samples {
239 return None;
240 }
241
242 let t = if self.total_samples <= 1 {
244 1.0
245 } else {
246 (self.current_index as f64) / ((self.total_samples - 1) as f64)
247 };
248
249 let result = self.evaluate_at(t);
250 self.current_index += 1;
251
252 Some(result)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for checks on direct polynomial evaluation.
    const TIGHT: f64 = 1e-10;
    /// Tolerance for checks on sampled trajectory values.
    const LOOSE: f64 = 1e-6;

    /// Returns `true` when `actual` is within `tolerance` of `expected`.
    fn close(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Builds a planner whose six joints all move from `start` to `end`.
    fn uniform_planner(
        start: f64,
        end: f64,
        duration: Duration,
        frequency_hz: f64,
    ) -> TrajectoryPlanner {
        TrajectoryPlanner::new(
            JointArray::from([Rad(start); 6]),
            JointArray::from([Rad(end); 6]),
            duration,
            frequency_hz,
        )
    }

    /// Coefficients of the rest-to-rest cubic from 0 to 1 (`3t^2 - 2t^3`).
    fn unit_step_coeffs() -> CubicCoeffs {
        CubicCoeffs {
            a0: 0.0,
            a1: 0.0,
            a2: 3.0,
            a3: -2.0,
        }
    }

    #[test]
    fn test_cubic_coeffs_position() {
        let cubic = unit_step_coeffs();

        assert!(close(cubic.position(0.0), 0.0, TIGHT));
        assert!(close(cubic.position(1.0), 1.0, TIGHT));
    }

    #[test]
    fn test_cubic_coeffs_velocity() {
        let cubic = unit_step_coeffs();

        assert!(close(cubic.velocity(0.0), 0.0, TIGHT));
        assert!(close(cubic.velocity(1.0), 0.0, TIGHT));
    }

    #[test]
    fn test_compute_cubic_spline_zero_velocity() {
        let cubic = TrajectoryPlanner::compute_cubic_spline(0.0, 0.0, 1.0, 0.0);

        // Both end points are hit, each with zero velocity.
        assert!(close(cubic.position(0.0), 0.0, TIGHT));
        assert!(close(cubic.position(1.0), 1.0, TIGHT));
        assert!(close(cubic.velocity(0.0), 0.0, TIGHT));
        assert!(close(cubic.velocity(1.0), 0.0, TIGHT));
    }

    #[test]
    fn test_trajectory_planner_new() {
        let planner = uniform_planner(0.0, 1.0, Duration::from_secs(1), 10.0);

        assert_eq!(planner.total_samples, 10);
        assert_eq!(planner.current_index, 0);
    }

    #[test]
    fn test_trajectory_iterator_basic() {
        let planner = uniform_planner(0.0, 1.0, Duration::from_secs(1), 5.0);

        let first_joint: Vec<f64> = planner.map(|(pos, _vel)| pos[0].0).collect();

        // Every sample stays within a small margin of the commanded range.
        assert!(first_joint.iter().all(|p| (-0.1..=1.1).contains(p)));
        assert_eq!(first_joint.len(), 5);
    }

    #[test]
    fn test_trajectory_boundary_conditions() {
        let mut planner = uniform_planner(0.0, 1.57, Duration::from_secs(2), 100.0);

        // The first sample sits on the start point, at rest.
        let (first_pos, first_vel) = planner.next().unwrap();
        assert!(close(first_pos[0].0, 0.0, LOOSE));
        assert!(close(first_vel[0], 0.0, LOOSE));

        // The final sample sits on the end point, at rest.
        let (last_pos, last_vel) = planner.last().unwrap();
        assert!(close(last_pos[0].0, 1.57, LOOSE));
        assert!(close(last_vel[0], 0.0, LOOSE));
    }

    #[test]
    fn test_trajectory_reset() {
        let mut planner = uniform_planner(0.0, 1.0, Duration::from_secs(1), 10.0);

        planner.next();
        planner.next();
        assert_eq!(planner.current_index, 2);

        planner.reset();
        assert_eq!(planner.current_index, 0);

        // After a reset the planner starts over from the start point.
        let (pos, _vel) = planner.next().unwrap();
        assert!(close(pos[0].0, 0.0, LOOSE));
    }

    #[test]
    fn test_trajectory_progress() {
        let mut planner = uniform_planner(0.0, 1.0, Duration::from_secs(1), 10.0);

        assert!(close(planner.progress(), 0.0, TIGHT));

        planner.next();
        let partial = planner.progress();
        assert!(partial > 0.0 && partial < 1.0);

        // Drain the remaining samples.
        planner.by_ref().for_each(drop);
        assert!(close(planner.progress(), 1.0, TIGHT));
    }

    #[test]
    fn test_trajectory_smoothness() {
        let frequency_hz = 1000.0;
        let dt: f64 = 1.0 / frequency_hz;
        let planner = uniform_planner(0.0, 1.0, Duration::from_secs(1), frequency_hz);

        let velocities: Vec<f64> = planner.map(|(_pos, vel)| vel[0]).collect();

        // Finite-difference acceleration between consecutive samples.
        let max_accel = velocities
            .windows(2)
            .map(|pair| ((pair[1] - pair[0]) / dt).abs())
            .fold(0.0_f64, f64::max);

        assert!(max_accel < 100.0, "Max accel: {}", max_accel);
    }

    #[test]
    fn test_trajectory_single_point() {
        let planner = uniform_planner(0.0, 0.0, Duration::from_millis(10), 100.0);

        let count = planner.count();

        assert!(count > 0);
    }
}