Skip to main content

piper_client/control/
pid.rs

1//! PID Controller - 比例-积分-微分控制器
2//!
3//! 实现经典的 PID 控制算法,适用于关节位置控制。
4//!
5//! # 算法
6//!
7//! ```text
8//! output = Kp * e + Ki * ∫e dt + Kd * de/dt
9//! ```
10//!
11//! 其中:
12//! - `e` = 目标位置 - 当前位置(误差)
13//! - `∫e dt` = 累积误差(积分项)
14//! - `de/dt` = 误差变化率(微分项)
15//!
16//! # 特性
17//!
18//! - **积分饱和保护**: 限制积分项累积,防止积分饱和(Integral Windup)
19//! - **时间跳变处理**: 正确处理 `dt` 异常,只重置微分项,保留积分项
20//! - **强类型单位**: 使用 `Rad` 和 `NewtonMeter` 确保单位正确
21//!
22//! # 示例
23//!
24//! ```rust,no_run
25//! use piper_client::control::{PidController, Controller};
26//! use piper_client::types::{JointArray, Rad};
27//!
28//! // 创建 PID 控制器
29//! let target = JointArray::from([Rad(1.0); 6]);
30//! let mut pid = PidController::new(target)
31//!     .with_gains(10.0, 0.5, 0.1)
32//!     .with_integral_limit(5.0)
33//!     .with_output_limit(50.0);
34//!
35//! // 在控制循环中使用
36//! # use std::time::Duration;
37//! # let current = JointArray::from([Rad(0.5); 6]);
38//! # let dt = Duration::from_millis(10);
39//! let output = pid.tick(&current, dt).unwrap();
40//! ```
41
42use super::controller::Controller;
43use crate::types::{JointArray, NewtonMeter, Rad};
44use std::time::Duration;
45
46/// PID 控制器
47///
48/// 实现经典的比例-积分-微分控制算法。
49#[derive(Debug, Clone)]
50pub struct PidController {
51    /// 目标位置
52    target: JointArray<Rad>,
53
54    /// 比例增益 (Kp)
55    kp: f64,
56
57    /// 积分增益 (Ki)
58    ki: f64,
59
60    /// 微分增益 (Kd)
61    kd: f64,
62
63    /// 积分项累积值
64    integral: JointArray<f64>,
65
66    /// 上一次的误差(用于计算微分)
67    last_error: JointArray<f64>,
68
69    /// 积分项限制(防止积分饱和)
70    integral_limit: f64,
71
72    /// 输出力矩限制
73    output_limit: f64,
74}
75
76impl PidController {
77    /// 创建新的 PID 控制器
78    ///
79    /// # 参数
80    ///
81    /// - `target`: 目标关节位置
82    ///
83    /// # 默认参数
84    ///
85    /// - Kp = 0.0, Ki = 0.0, Kd = 0.0(需要手动设置)
86    /// - 积分限制 = 10.0
87    /// - 输出限制 = 100.0 Nm
88    ///
89    /// # 示例
90    ///
91    /// ```rust
92    /// # use piper_client::control::PidController;
93    /// # use piper_client::types::{JointArray, Rad};
94    /// let target = JointArray::from([Rad(1.0); 6]);
95    /// let pid = PidController::new(target);
96    /// ```
97    pub fn new(target: JointArray<Rad>) -> Self {
98        PidController {
99            target,
100            kp: 0.0,
101            ki: 0.0,
102            kd: 0.0,
103            integral: JointArray::from([0.0; 6]),
104            last_error: JointArray::from([0.0; 6]),
105            integral_limit: 10.0,
106            output_limit: 100.0,
107        }
108    }
109
110    /// 设置 PID 增益
111    ///
112    /// # 参数
113    ///
114    /// - `kp`: 比例增益
115    /// - `ki`: 积分增益
116    /// - `kd`: 微分增益
117    ///
118    /// # 示例
119    ///
120    /// ```rust
121    /// # use piper_client::control::PidController;
122    /// # use piper_client::types::{JointArray, Rad};
123    /// # let target = JointArray::from([Rad(1.0); 6]);
124    /// let pid = PidController::new(target)
125    ///     .with_gains(10.0, 0.5, 0.1);
126    /// ```
127    pub fn with_gains(mut self, kp: f64, ki: f64, kd: f64) -> Self {
128        self.kp = kp;
129        self.ki = ki;
130        self.kd = kd;
131        self
132    }
133
134    /// 设置积分项限制
135    ///
136    /// 防止积分饱和(Integral Windup)。
137    ///
138    /// # 参数
139    ///
140    /// - `limit`: 积分项绝对值的最大值
141    ///
142    /// # 示例
143    ///
144    /// ```rust
145    /// # use piper_client::control::PidController;
146    /// # use piper_client::types::{JointArray, Rad};
147    /// # let target = JointArray::from([Rad(1.0); 6]);
148    /// let pid = PidController::new(target)
149    ///     .with_integral_limit(5.0);
150    /// ```
151    pub fn with_integral_limit(mut self, limit: f64) -> Self {
152        self.integral_limit = limit;
153        self
154    }
155
156    /// 设置输出力矩限制
157    ///
158    /// # 参数
159    ///
160    /// - `limit`: 输出力矩绝对值的最大值(Nm)
161    ///
162    /// # 示例
163    ///
164    /// ```rust
165    /// # use piper_client::control::PidController;
166    /// # use piper_client::types::{JointArray, Rad};
167    /// # let target = JointArray::from([Rad(1.0); 6]);
168    /// let pid = PidController::new(target)
169    ///     .with_output_limit(50.0);
170    /// ```
171    pub fn with_output_limit(mut self, limit: f64) -> Self {
172        self.output_limit = limit;
173        self
174    }
175
176    /// 更新目标位置
177    ///
178    /// # 参数
179    ///
180    /// - `target`: 新的目标关节位置
181    ///
182    /// # 示例
183    ///
184    /// ```rust
185    /// # use piper_client::control::PidController;
186    /// # use piper_client::types::{JointArray, Rad};
187    /// # let target = JointArray::from([Rad(1.0); 6]);
188    /// let mut pid = PidController::new(target);
189    /// pid.set_target(JointArray::from([Rad(2.0); 6]));
190    /// ```
191    pub fn set_target(&mut self, target: JointArray<Rad>) {
192        self.target = target;
193    }
194
195    /// 获取当前目标位置
196    pub fn target(&self) -> JointArray<Rad> {
197        self.target
198    }
199
200    /// 获取当前积分项
201    ///
202    /// 用于调试和监控。
203    pub fn integral(&self) -> JointArray<f64> {
204        self.integral
205    }
206}
207
208impl Controller for PidController {
209    type Error = std::io::Error;
210
211    fn tick(
212        &mut self,
213        current: &JointArray<Rad>,
214        dt: Duration,
215    ) -> Result<JointArray<NewtonMeter>, Self::Error> {
216        let dt_sec = dt.as_secs_f64();
217
218        // 防止除零
219        if dt_sec <= 0.0 {
220            tracing::warn!(
221                "PID controller received zero or negative dt: {:?}, returning zero output",
222                dt
223            );
224            return Ok(JointArray::from([NewtonMeter(0.0); 6]));
225        }
226
227        // 1. 计算误差
228        let error = self.target.map_with(*current, |t, c| (t - c).0);
229
230        // 2. 比例项(P)
231        let p_term = error.map(|e| self.kp * e);
232
233        // 3. 积分项(I)+ 饱和保护
234        self.integral = self.integral.map_with(error, |i, e| {
235            let new_i = i + e * dt_sec;
236            // 钳位到 [-integral_limit, +integral_limit]
237            new_i.clamp(-self.integral_limit, self.integral_limit)
238        });
239        let i_term = self.integral.map(|i| self.ki * i);
240
241        // 4. 微分项(D)
242        let d_term = error.map_with(self.last_error, |e, le| self.kd * (e - le) / dt_sec);
243
244        // 5. 更新上一次误差
245        self.last_error = error;
246
247        // 6. 计算总输出
248        let output = p_term.map_with(i_term, |p, i| p + i).map_with(d_term, |pi, d| pi + d);
249
250        // 7. 钳位输出
251        let clamped_output =
252            output.map(|o| NewtonMeter(o.clamp(-self.output_limit, self.output_limit)));
253
254        Ok(clamped_output)
255    }
256
257    fn on_time_jump(&mut self, dt: Duration) -> Result<(), Self::Error> {
258        tracing::warn!(
259            "PID controller detected time jump: {:?}, resetting derivative term only",
260            dt
261        );
262
263        // ✅ 只重置微分项
264        self.last_error = JointArray::from([0.0; 6]);
265
266        // ❌ 不要清零积分项!
267        // 原因:机械臂可能依赖积分项对抗重力
268        // 清零会导致机械臂瞬间下坠(Sagging)
269
270        Ok(())
271    }
272
273    fn reset(&mut self) -> Result<(), Self::Error> {
274        // 完全重置控制器状态
275        self.integral = JointArray::from([0.0; 6]);
276        self.last_error = JointArray::from([0.0; 6]);
277        Ok(())
278    }
279}
280
281#[cfg(test)]
282mod tests {
283    use super::*;
284
285    #[test]
286    fn test_pid_new() {
287        let target = JointArray::from([Rad(1.0); 6]);
288        let pid = PidController::new(target);
289
290        assert_eq!(pid.kp, 0.0);
291        assert_eq!(pid.ki, 0.0);
292        assert_eq!(pid.kd, 0.0);
293        assert_eq!(pid.integral_limit, 10.0);
294        assert_eq!(pid.output_limit, 100.0);
295    }
296
297    #[test]
298    fn test_pid_builder() {
299        let target = JointArray::from([Rad(1.0); 6]);
300        let pid = PidController::new(target)
301            .with_gains(10.0, 0.5, 0.1)
302            .with_integral_limit(5.0)
303            .with_output_limit(50.0);
304
305        assert_eq!(pid.kp, 10.0);
306        assert_eq!(pid.ki, 0.5);
307        assert_eq!(pid.kd, 0.1);
308        assert_eq!(pid.integral_limit, 5.0);
309        assert_eq!(pid.output_limit, 50.0);
310    }
311
312    #[test]
313    fn test_pid_proportional_only() {
314        let target = JointArray::from([Rad(1.0); 6]);
315        let mut pid = PidController::new(target).with_gains(10.0, 0.0, 0.0);
316
317        let current = JointArray::from([Rad(0.5); 6]);
318        let dt = Duration::from_millis(10);
319
320        let output = pid.tick(&current, dt).unwrap();
321
322        // 误差 = 1.0 - 0.5 = 0.5
323        // 输出 = 10.0 * 0.5 = 5.0
324        assert!((output[0].0 - 5.0).abs() < 1e-10);
325    }
326
327    #[test]
328    fn test_pid_integral_accumulation() {
329        let target = JointArray::from([Rad(1.0); 6]);
330        let mut pid = PidController::new(target).with_gains(0.0, 1.0, 0.0); // 只有积分项
331
332        let current = JointArray::from([Rad(0.5); 6]);
333        let dt = Duration::from_millis(100); // 0.1 秒
334
335        // 第一次 tick
336        let output1 = pid.tick(&current, dt).unwrap();
337        // 误差 = 0.5, 积分 = 0.5 * 0.1 = 0.05
338        // 输出 = 1.0 * 0.05 = 0.05
339        assert!((output1[0].0 - 0.05).abs() < 1e-10);
340
341        // 第二次 tick
342        let output2 = pid.tick(&current, dt).unwrap();
343        // 积分 = 0.05 + 0.5 * 0.1 = 0.1
344        // 输出 = 1.0 * 0.1 = 0.1
345        assert!((output2[0].0 - 0.1).abs() < 1e-10);
346    }
347
348    #[test]
349    fn test_pid_integral_saturation() {
350        let target = JointArray::from([Rad(1.0); 6]);
351        let mut pid = PidController::new(target).with_gains(0.0, 1.0, 0.0).with_integral_limit(0.5); // 积分限制
352
353        let current = JointArray::from([Rad(0.0); 6]);
354        let dt = Duration::from_secs(1);
355
356        // 误差 = 1.0, 每秒累积 1.0
357        // 但积分被限制在 0.5
358        for _ in 0..10 {
359            pid.tick(&current, dt).unwrap();
360        }
361
362        // 积分应该被钳位到 0.5
363        assert!((pid.integral()[0] - 0.5).abs() < 1e-10);
364    }
365
366    #[test]
367    fn test_pid_derivative_term() {
368        let target = JointArray::from([Rad(1.0); 6]);
369        let mut pid = PidController::new(target).with_gains(0.0, 0.0, 1.0); // 只有微分项
370
371        let dt = Duration::from_millis(100);
372
373        // 第一次:误差从 0 变化
374        let current1 = JointArray::from([Rad(0.5); 6]);
375        let output1 = pid.tick(&current1, dt).unwrap();
376        // 误差 = 0.5, 上次误差 = 0, 变化率 = 0.5 / 0.1 = 5.0
377        // 输出 = 1.0 * 5.0 = 5.0
378        assert!((output1[0].0 - 5.0).abs() < 1e-10);
379
380        // 第二次:误差不变
381        let output2 = pid.tick(&current1, dt).unwrap();
382        // 误差变化 = 0, 输出 = 0
383        assert!((output2[0].0 - 0.0).abs() < 1e-10);
384    }
385
386    #[test]
387    fn test_pid_output_clamping() {
388        let target = JointArray::from([Rad(100.0); 6]);
389        let mut pid =
390            PidController::new(target).with_gains(100.0, 0.0, 0.0).with_output_limit(50.0);
391
392        let current = JointArray::from([Rad(0.0); 6]);
393        let dt = Duration::from_millis(10);
394
395        let output = pid.tick(&current, dt).unwrap();
396
397        // 理论输出 = 100.0 * 100.0 = 10000.0
398        // 但被钳位到 50.0
399        assert!((output[0].0 - 50.0).abs() < 1e-10);
400    }
401
402    #[test]
403    fn test_pid_on_time_jump_preserves_integral() {
404        let target = JointArray::from([Rad(1.0); 6]);
405        let mut pid = PidController::new(target).with_gains(0.0, 1.0, 1.0);
406
407        let current = JointArray::from([Rad(0.5); 6]);
408        let dt = Duration::from_secs(1);
409
410        // 累积一些积分
411        pid.tick(&current, dt).unwrap();
412        let integral_before = pid.integral()[0];
413        assert!(integral_before > 0.0);
414
415        // 调用 on_time_jump
416        pid.on_time_jump(Duration::from_secs(10)).unwrap();
417
418        // ✅ 积分应该保留
419        let integral_after = pid.integral()[0];
420        assert_eq!(integral_before, integral_after);
421
422        // ✅ 微分项应该被重置
423        assert_eq!(pid.last_error[0], 0.0);
424    }
425
426    #[test]
427    fn test_pid_reset() {
428        let target = JointArray::from([Rad(1.0); 6]);
429        let mut pid = PidController::new(target).with_gains(1.0, 1.0, 1.0);
430
431        let current = JointArray::from([Rad(0.5); 6]);
432        let dt = Duration::from_secs(1);
433
434        // 累积一些状态
435        pid.tick(&current, dt).unwrap();
436        assert!(pid.integral()[0] != 0.0);
437        assert!(pid.last_error[0] != 0.0);
438
439        // 重置
440        pid.reset().unwrap();
441
442        // 所有状态应该被清零
443        assert_eq!(pid.integral()[0], 0.0);
444        assert_eq!(pid.last_error[0], 0.0);
445    }
446
447    #[test]
448    fn test_pid_set_target() {
449        let target1 = JointArray::from([Rad(1.0); 6]);
450        let mut pid = PidController::new(target1);
451
452        let target2 = JointArray::from([Rad(2.0); 6]);
453        pid.set_target(target2);
454
455        assert_eq!(pid.target()[0].0, 2.0);
456    }
457
458    #[test]
459    fn test_pid_zero_dt() {
460        let target = JointArray::from([Rad(1.0); 6]);
461        let mut pid = PidController::new(target).with_gains(10.0, 1.0, 1.0);
462
463        let current = JointArray::from([Rad(0.5); 6]);
464        let dt = Duration::from_secs(0);
465
466        // dt = 0 应该返回零输出
467        let output = pid.tick(&current, dt).unwrap();
468        assert_eq!(output[0].0, 0.0);
469    }
470}