Skip to main content

piper_client/control/
trajectory.rs

1//! Trajectory Planner - 轨迹规划器
2//!
3//! 使用三次样条插值生成平滑的关节空间轨迹。
4//!
5//! # 算法
6//!
7//! 使用三次多项式插值:
8//! ```text
9//! p(t) = a0 + a1*t + a2*t² + a3*t³
10//! v(t) = a1 + 2*a2*t + 3*a3*t²
11//! ```
12//!
13//! 边界条件:起止速度为 0
14//!
15//! # 特性
16//!
17//! - **Iterator 模式**: 按需生成轨迹点,内存高效
18//! - **平滑性保证**: C² 连续(加速度连续)
19//! - **强类型**: 使用 `Rad` 确保单位正确
20//!
21//! # 示例
22//!
23//! ```rust,no_run
24//! use piper_client::control::TrajectoryPlanner;
25//! use piper_client::types::{JointArray, Rad};
26//! use std::time::Duration;
27//!
28//! let start = JointArray::from([Rad(0.0); 6]);
29//! let end = JointArray::from([Rad(1.0); 6]);
30//! let duration = Duration::from_secs(5);
31//! let frequency_hz = 100.0;  // 100Hz
32//!
33//! let mut planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
34//!
35//! for (position, velocity) in &mut planner {
36//!     // 使用 position 和 velocity
37//!     println!("pos: {:?}, vel: {:?}", position, velocity);
38//! }
39//! ```
40
41use crate::types::{JointArray, Rad};
42use std::time::Duration;
43
44/// 三次样条系数
45///
46/// 表示 `p(t) = a0 + a1*t + a2*t² + a3*t³`
47#[derive(Debug, Clone, Copy)]
48struct CubicCoeffs {
49    a0: f64,
50    a1: f64,
51    a2: f64,
52    a3: f64,
53}
54
55impl CubicCoeffs {
56    /// 在归一化时间 t ∈ [0, 1] 处计算位置
57    fn position(&self, t: f64) -> f64 {
58        self.a0 + self.a1 * t + self.a2 * t * t + self.a3 * t * t * t
59    }
60
61    /// 在归一化时间 t ∈ [0, 1] 处计算速度
62    ///
63    /// 注意:这是对归一化时间的导数,需要除以实际时间长度
64    fn velocity(&self, t: f64) -> f64 {
65        self.a1 + 2.0 * self.a2 * t + 3.0 * self.a3 * t * t
66    }
67}
68
69/// 轨迹规划器
70///
71/// 生成从起点到终点的平滑轨迹。
72pub struct TrajectoryPlanner {
73    /// 每个关节的样条系数
74    spline_coeffs: JointArray<CubicCoeffs>,
75
76    /// 轨迹总时长
77    duration: Duration,
78
79    /// 当前迭代索引
80    current_index: usize,
81
82    /// 总采样点数
83    total_samples: usize,
84}
85
86impl TrajectoryPlanner {
87    /// 创建新的轨迹规划器
88    ///
89    /// # 参数
90    ///
91    /// - `start`: 起始位置
92    /// - `end`: 终止位置
93    /// - `duration`: 轨迹时长
94    /// - `frequency_hz`: 采样频率(Hz)
95    ///
96    /// # 边界条件
97    ///
98    /// - 起始速度: 0
99    /// - 终止速度: 0
100    ///
101    /// # 错误
102    ///
103    /// 如果 `frequency_hz` 不是正数,将 panic。
104    ///
105    /// # 示例
106    ///
107    /// ```rust
108    /// # use piper_client::control::TrajectoryPlanner;
109    /// # use piper_client::types::{JointArray, Rad};
110    /// # use std::time::Duration;
111    /// let start = JointArray::from([Rad(0.0); 6]);
112    /// let end = JointArray::from([Rad(1.57); 6]);
113    /// let planner = TrajectoryPlanner::new(
114    ///     start,
115    ///     end,
116    ///     Duration::from_secs(3),
117    ///     100.0,
118    /// );
119    /// ```
120    pub fn new(
121        start: JointArray<Rad>,
122        end: JointArray<Rad>,
123        duration: Duration,
124        frequency_hz: f64,
125    ) -> Self {
126        // ✅ 输入验证
127        assert!(
128            frequency_hz > 0.0,
129            "frequency_hz must be positive, got: {}",
130            frequency_hz
131        );
132
133        let duration_sec = duration.as_secs_f64();
134
135        // ⚠️ 重要:未来支持 Via Points(途径点)时,
136        // 需要将物理速度乘以 duration_sec 进行时间缩放
137        // 例如: v_start_normalized = v_start_physical * duration_sec
138        let v_start = 0.0; // 当前:起始速度为 0
139        let v_end = 0.0; // 当前:终止速度为 0
140
141        // 为每个关节计算样条系数
142        let spline_coeffs = start.map_with(end, |s, e| {
143            Self::compute_cubic_spline(s.0, v_start, e.0, v_end)
144        });
145
146        let total_samples = (duration_sec * frequency_hz).ceil() as usize;
147
148        TrajectoryPlanner {
149            spline_coeffs,
150            duration,
151            current_index: 0,
152            total_samples,
153        }
154    }
155
156    /// 计算三次样条系数
157    ///
158    /// 给定边界条件 `p(0) = p0`, `v(0) = v0`, `p(1) = p1`, `v(1) = v1`,
159    /// 计算 `p(t) = a0 + a1*t + a2*t² + a3*t³` 的系数。
160    ///
161    /// # 参数
162    ///
163    /// - `p0`: 起始位置
164    /// - `v0`: 起始速度(归一化时间)
165    /// - `p1`: 终止位置
166    /// - `v1`: 终止速度(归一化时间)
167    ///
168    /// # 返回
169    ///
170    /// 三次样条系数
171    fn compute_cubic_spline(p0: f64, v0: f64, p1: f64, v1: f64) -> CubicCoeffs {
172        // 边界条件:
173        // p(0) = a0 = p0
174        // v(0) = a1 = v0
175        // p(1) = a0 + a1 + a2 + a3 = p1
176        // v(1) = a1 + 2*a2 + 3*a3 = v1
177
178        let a0 = p0;
179        let a1 = v0;
180
181        // 解线性方程组:
182        // a2 + a3 = p1 - p0 - v0
183        // 2*a2 + 3*a3 = v1 - v0
184        // =>
185        // a3 = -2*p1 + 2*p0 + v0 + v1
186        // a2 = 3*p1 - 3*p0 - 2*v0 - v1
187
188        let a2 = 3.0 * (p1 - p0) - 2.0 * v0 - v1;
189        let a3 = -2.0 * (p1 - p0) + v0 + v1;
190
191        CubicCoeffs { a0, a1, a2, a3 }
192    }
193
194    /// 在指定时间计算位置和速度
195    ///
196    /// # 参数
197    ///
198    /// - `t`: 归一化时间 [0, 1]
199    ///
200    /// # 返回
201    ///
202    /// `(position, velocity)` 元组
203    fn evaluate_at(&self, t: f64) -> (JointArray<Rad>, JointArray<f64>) {
204        let duration_sec = self.duration.as_secs_f64();
205
206        let position = self.spline_coeffs.map(|coeff| Rad(coeff.position(t)));
207
208        // 速度:需要除以时间长度(从归一化时间导数转换为物理速度)
209        let velocity = self.spline_coeffs.map(|coeff| coeff.velocity(t) / duration_sec);
210
211        (position, velocity)
212    }
213
214    /// 重置迭代器到起点
215    pub fn reset(&mut self) {
216        self.current_index = 0;
217    }
218
219    /// 获取总采样点数
220    pub fn total_samples(&self) -> usize {
221        self.total_samples
222    }
223
224    /// 获取当前进度(0.0 到 1.0)
225    pub fn progress(&self) -> f64 {
226        if self.total_samples == 0 {
227            1.0
228        } else {
229            (self.current_index as f64) / (self.total_samples as f64)
230        }
231    }
232}
233
234impl Iterator for TrajectoryPlanner {
235    type Item = (JointArray<Rad>, JointArray<f64>);
236
237    fn next(&mut self) -> Option<Self::Item> {
238        if self.current_index >= self.total_samples {
239            return None;
240        }
241
242        // 计算归一化时间 t ∈ [0, 1]
243        let t = if self.total_samples <= 1 {
244            1.0
245        } else {
246            (self.current_index as f64) / ((self.total_samples - 1) as f64)
247        };
248
249        let result = self.evaluate_at(t);
250        self.current_index += 1;
251
252        Some(result)
253    }
254}
255
256#[cfg(test)]
257mod tests {
258    use super::*;
259
260    #[test]
261    fn test_cubic_coeffs_position() {
262        let coeffs = CubicCoeffs {
263            a0: 0.0,
264            a1: 0.0,
265            a2: 3.0,
266            a3: -2.0,
267        };
268
269        // t=0: p = 0
270        assert!((coeffs.position(0.0) - 0.0).abs() < 1e-10);
271
272        // t=1: p = 0 + 0 + 3 - 2 = 1
273        assert!((coeffs.position(1.0) - 1.0).abs() < 1e-10);
274    }
275
276    #[test]
277    fn test_cubic_coeffs_velocity() {
278        let coeffs = CubicCoeffs {
279            a0: 0.0,
280            a1: 0.0,
281            a2: 3.0,
282            a3: -2.0,
283        };
284
285        // v(t) = 0 + 6*t - 6*t²
286        // t=0: v = 0
287        assert!((coeffs.velocity(0.0) - 0.0).abs() < 1e-10);
288
289        // t=1: v = 6 - 6 = 0
290        assert!((coeffs.velocity(1.0) - 0.0).abs() < 1e-10);
291    }
292
293    #[test]
294    fn test_compute_cubic_spline_zero_velocity() {
295        let coeffs = TrajectoryPlanner::compute_cubic_spline(0.0, 0.0, 1.0, 0.0);
296
297        // 边界条件检查
298        assert!((coeffs.position(0.0) - 0.0).abs() < 1e-10);
299        assert!((coeffs.position(1.0) - 1.0).abs() < 1e-10);
300        assert!((coeffs.velocity(0.0) - 0.0).abs() < 1e-10);
301        assert!((coeffs.velocity(1.0) - 0.0).abs() < 1e-10);
302    }
303
304    #[test]
305    fn test_trajectory_planner_new() {
306        let start = JointArray::from([Rad(0.0); 6]);
307        let end = JointArray::from([Rad(1.0); 6]);
308        let duration = Duration::from_secs(1);
309        let frequency_hz = 10.0;
310
311        let planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
312
313        assert_eq!(planner.total_samples, 10);
314        assert_eq!(planner.current_index, 0);
315    }
316
317    #[test]
318    fn test_trajectory_iterator_basic() {
319        let start = JointArray::from([Rad(0.0); 6]);
320        let end = JointArray::from([Rad(1.0); 6]);
321        let duration = Duration::from_secs(1);
322        let frequency_hz = 5.0; // 5 个采样点
323
324        let planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
325
326        let mut count = 0;
327        for (pos, _vel) in planner {
328            count += 1;
329            // 位置应该在 [0, 1] 范围内
330            assert!(pos[0].0 >= -0.1 && pos[0].0 <= 1.1);
331        }
332
333        assert_eq!(count, 5);
334    }
335
336    #[test]
337    fn test_trajectory_boundary_conditions() {
338        let start = JointArray::from([Rad(0.0); 6]);
339        let end = JointArray::from([Rad(1.57); 6]);
340        let duration = Duration::from_secs(2);
341        let frequency_hz = 100.0;
342
343        let mut planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
344
345        // 第一个点
346        let (first_pos, first_vel) = planner.next().unwrap();
347        assert!((first_pos[0].0 - 0.0).abs() < 1e-6);
348        assert!(first_vel[0].abs() < 1e-6); // 起始速度应该接近 0
349
350        // 跳到最后一个点
351        let mut last = None;
352        for item in planner {
353            last = Some(item);
354        }
355
356        let (last_pos, last_vel) = last.unwrap();
357        assert!((last_pos[0].0 - 1.57).abs() < 1e-6);
358        assert!(last_vel[0].abs() < 1e-6); // 终止速度应该接近 0
359    }
360
361    #[test]
362    fn test_trajectory_reset() {
363        let start = JointArray::from([Rad(0.0); 6]);
364        let end = JointArray::from([Rad(1.0); 6]);
365        let duration = Duration::from_secs(1);
366        let frequency_hz = 10.0;
367
368        let mut planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
369
370        // 迭代几次
371        planner.next();
372        planner.next();
373        assert_eq!(planner.current_index, 2);
374
375        // 重置
376        planner.reset();
377        assert_eq!(planner.current_index, 0);
378
379        // 应该可以重新迭代
380        let (pos, _vel) = planner.next().unwrap();
381        assert!((pos[0].0 - 0.0).abs() < 1e-6);
382    }
383
384    #[test]
385    fn test_trajectory_progress() {
386        let start = JointArray::from([Rad(0.0); 6]);
387        let end = JointArray::from([Rad(1.0); 6]);
388        let duration = Duration::from_secs(1);
389        let frequency_hz = 10.0;
390
391        let mut planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
392
393        assert!((planner.progress() - 0.0).abs() < 1e-10);
394
395        planner.next();
396        assert!(planner.progress() > 0.0 && planner.progress() < 1.0);
397
398        // 迭代到最后
399        while planner.next().is_some() {}
400        assert!((planner.progress() - 1.0).abs() < 1e-10);
401    }
402
403    #[test]
404    fn test_trajectory_smoothness() {
405        let start = JointArray::from([Rad(0.0); 6]);
406        let end = JointArray::from([Rad(1.0); 6]);
407        let duration = Duration::from_secs(1);
408        let frequency_hz = 1000.0; // 高频采样
409
410        let planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
411
412        let mut last_vel: Option<f64> = None;
413        let mut max_accel: f64 = 0.0;
414        let dt: f64 = 1.0 / frequency_hz;
415
416        for (_pos, vel) in planner {
417            if let Some(lv) = last_vel {
418                let accel: f64 = (vel[0] - lv) / dt;
419                max_accel = max_accel.max(accel.abs());
420            }
421            last_vel = Some(vel[0]);
422        }
423
424        // 加速度应该是有界的(对于这个简单的轨迹)
425        // 由于数值微分可能引入噪声,我们使用一个较宽松的阈值
426        assert!(max_accel < 100.0, "Max accel: {}", max_accel);
427    }
428
429    #[test]
430    fn test_trajectory_single_point() {
431        let start = JointArray::from([Rad(0.0); 6]);
432        let end = JointArray::from([Rad(0.0); 6]);
433        let duration = Duration::from_millis(10);
434        let frequency_hz = 100.0;
435
436        let mut planner = TrajectoryPlanner::new(start, end, duration, frequency_hz);
437
438        // 即使起点和终点相同,也应该生成轨迹
439        let mut count = 0;
440        while planner.next().is_some() {
441            count += 1;
442        }
443
444        assert!(count > 0);
445    }
446}