1use super::controller::Controller;
43use crate::types::{JointArray, NewtonMeter, Rad};
44use std::time::Duration;
45
46#[derive(Debug, Clone)]
50pub struct PidController {
51 target: JointArray<Rad>,
53
54 kp: f64,
56
57 ki: f64,
59
60 kd: f64,
62
63 integral: JointArray<f64>,
65
66 last_error: JointArray<f64>,
68
69 integral_limit: f64,
71
72 output_limit: f64,
74}
75
76impl PidController {
77 pub fn new(target: JointArray<Rad>) -> Self {
98 PidController {
99 target,
100 kp: 0.0,
101 ki: 0.0,
102 kd: 0.0,
103 integral: JointArray::from([0.0; 6]),
104 last_error: JointArray::from([0.0; 6]),
105 integral_limit: 10.0,
106 output_limit: 100.0,
107 }
108 }
109
110 pub fn with_gains(mut self, kp: f64, ki: f64, kd: f64) -> Self {
128 self.kp = kp;
129 self.ki = ki;
130 self.kd = kd;
131 self
132 }
133
134 pub fn with_integral_limit(mut self, limit: f64) -> Self {
152 self.integral_limit = limit;
153 self
154 }
155
156 pub fn with_output_limit(mut self, limit: f64) -> Self {
172 self.output_limit = limit;
173 self
174 }
175
176 pub fn set_target(&mut self, target: JointArray<Rad>) {
192 self.target = target;
193 }
194
195 pub fn target(&self) -> JointArray<Rad> {
197 self.target
198 }
199
200 pub fn integral(&self) -> JointArray<f64> {
204 self.integral
205 }
206}
207
208impl Controller for PidController {
209 type Error = std::io::Error;
210
211 fn tick(
212 &mut self,
213 current: &JointArray<Rad>,
214 dt: Duration,
215 ) -> Result<JointArray<NewtonMeter>, Self::Error> {
216 let dt_sec = dt.as_secs_f64();
217
218 if dt_sec <= 0.0 {
220 tracing::warn!(
221 "PID controller received zero or negative dt: {:?}, returning zero output",
222 dt
223 );
224 return Ok(JointArray::from([NewtonMeter(0.0); 6]));
225 }
226
227 let error = self.target.map_with(*current, |t, c| (t - c).0);
229
230 let p_term = error.map(|e| self.kp * e);
232
233 self.integral = self.integral.map_with(error, |i, e| {
235 let new_i = i + e * dt_sec;
236 new_i.clamp(-self.integral_limit, self.integral_limit)
238 });
239 let i_term = self.integral.map(|i| self.ki * i);
240
241 let d_term = error.map_with(self.last_error, |e, le| self.kd * (e - le) / dt_sec);
243
244 self.last_error = error;
246
247 let output = p_term.map_with(i_term, |p, i| p + i).map_with(d_term, |pi, d| pi + d);
249
250 let clamped_output =
252 output.map(|o| NewtonMeter(o.clamp(-self.output_limit, self.output_limit)));
253
254 Ok(clamped_output)
255 }
256
257 fn on_time_jump(&mut self, dt: Duration) -> Result<(), Self::Error> {
258 tracing::warn!(
259 "PID controller detected time jump: {:?}, resetting derivative term only",
260 dt
261 );
262
263 self.last_error = JointArray::from([0.0; 6]);
265
266 Ok(())
271 }
272
273 fn reset(&mut self) -> Result<(), Self::Error> {
274 self.integral = JointArray::from([0.0; 6]);
276 self.last_error = JointArray::from([0.0; 6]);
277 Ok(())
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_pid_new() {
        let target = JointArray::from([Rad(1.0); 6]);
        let pid = PidController::new(target);

        assert_eq!(pid.kp, 0.0);
        assert_eq!(pid.ki, 0.0);
        assert_eq!(pid.kd, 0.0);
        assert_eq!(pid.integral_limit, 10.0);
        assert_eq!(pid.output_limit, 100.0);
    }

    #[test]
    fn test_pid_builder() {
        let target = JointArray::from([Rad(1.0); 6]);
        let pid = PidController::new(target)
            .with_gains(10.0, 0.5, 0.1)
            .with_integral_limit(5.0)
            .with_output_limit(50.0);

        assert_eq!(pid.kp, 10.0);
        assert_eq!(pid.ki, 0.5);
        assert_eq!(pid.kd, 0.1);
        assert_eq!(pid.integral_limit, 5.0);
        assert_eq!(pid.output_limit, 50.0);
    }

    #[test]
    fn test_pid_default_gains_output_zero() {
        let mut pid = PidController::new(JointArray::from([Rad(1.0); 6]));
        let current = JointArray::from([Rad(0.5); 6]);

        let output = pid.tick(&current, Duration::from_millis(10)).unwrap();

        assert_eq!(output[0].0, 0.0);
    }

    #[test]
    fn test_pid_proportional_only() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target).with_gains(10.0, 0.0, 0.0);

        let current = JointArray::from([Rad(0.5); 6]);
        let dt = Duration::from_millis(10);

        let output = pid.tick(&current, dt).unwrap();

        // kp * error = 10.0 * 0.5
        assert!((output[0].0 - 5.0).abs() < EPS);
    }

    #[test]
    fn test_pid_integral_accumulation() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target).with_gains(0.0, 1.0, 0.0);
        let current = JointArray::from([Rad(0.5); 6]);
        let dt = Duration::from_millis(100);

        // Each tick adds error * dt = 0.5 * 0.1.
        let output1 = pid.tick(&current, dt).unwrap();
        assert!((output1[0].0 - 0.05).abs() < EPS);

        let output2 = pid.tick(&current, dt).unwrap();
        assert!((output2[0].0 - 0.1).abs() < EPS);
    }

    #[test]
    fn test_pid_integral_saturation() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target)
            .with_gains(0.0, 1.0, 0.0)
            .with_integral_limit(0.5);
        let current = JointArray::from([Rad(0.0); 6]);
        let dt = Duration::from_secs(1);

        for _ in 0..10 {
            pid.tick(&current, dt).unwrap();
        }

        assert!((pid.integral()[0] - 0.5).abs() < EPS);
    }

    #[test]
    fn test_pid_integral_saturation_negative() {
        let target = JointArray::from([Rad(0.0); 6]);
        let mut pid = PidController::new(target)
            .with_gains(0.0, 1.0, 0.0)
            .with_integral_limit(0.5);
        let current = JointArray::from([Rad(1.0); 6]);
        let dt = Duration::from_secs(1);

        for _ in 0..10 {
            pid.tick(&current, dt).unwrap();
        }

        assert!((pid.integral()[0] + 0.5).abs() < EPS);
    }

    #[test]
    fn test_pid_derivative_term() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target).with_gains(0.0, 0.0, 1.0);
        let dt = Duration::from_millis(100);

        let current1 = JointArray::from([Rad(0.5); 6]);

        // last_error starts at 0, so the first derivative is 0.5 / 0.1.
        let output1 = pid.tick(&current1, dt).unwrap();
        assert!((output1[0].0 - 5.0).abs() < EPS);

        // An unchanged error yields a zero derivative.
        let output2 = pid.tick(&current1, dt).unwrap();
        assert!((output2[0].0 - 0.0).abs() < EPS);
    }

    #[test]
    fn test_pid_output_clamping() {
        let target = JointArray::from([Rad(100.0); 6]);
        let mut pid =
            PidController::new(target).with_gains(100.0, 0.0, 0.0).with_output_limit(50.0);

        let current = JointArray::from([Rad(0.0); 6]);
        let dt = Duration::from_millis(10);

        let output = pid.tick(&current, dt).unwrap();

        assert!((output[0].0 - 50.0).abs() < EPS);
    }

    #[test]
    fn test_pid_output_clamping_negative() {
        let target = JointArray::from([Rad(-100.0); 6]);
        let mut pid =
            PidController::new(target).with_gains(100.0, 0.0, 0.0).with_output_limit(50.0);

        let current = JointArray::from([Rad(0.0); 6]);
        let dt = Duration::from_millis(10);

        let output = pid.tick(&current, dt).unwrap();

        assert!((output[0].0 + 50.0).abs() < EPS);
    }

    #[test]
    #[should_panic]
    fn test_pid_negative_integral_limit_panics() {
        let current = JointArray::from([Rad(0.0); 6]);
        let mut pid = PidController::new(JointArray::from([Rad(1.0); 6]))
            .with_gains(0.0, 1.0, 0.0)
            .with_integral_limit(-1.0);
        // A negative limit is invalid; it must panic no later than the first tick using it.
        let _ = pid.tick(&current, Duration::from_millis(10));
    }

    #[test]
    #[should_panic]
    fn test_pid_negative_output_limit_panics() {
        let current = JointArray::from([Rad(0.0); 6]);
        let mut pid = PidController::new(JointArray::from([Rad(1.0); 6]))
            .with_gains(1.0, 0.0, 0.0)
            .with_output_limit(-1.0);
        // A negative limit is invalid; it must panic no later than the first tick using it.
        let _ = pid.tick(&current, Duration::from_millis(10));
    }

    #[test]
    fn test_pid_on_time_jump_preserves_integral() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target).with_gains(0.0, 1.0, 1.0);

        let current = JointArray::from([Rad(0.5); 6]);
        let dt = Duration::from_secs(1);

        pid.tick(&current, dt).unwrap();

        let integral_before = pid.integral()[0];
        assert!(integral_before > 0.0);

        pid.on_time_jump(Duration::from_secs(10)).unwrap();

        let integral_after = pid.integral()[0];
        assert_eq!(integral_before, integral_after);

        assert_eq!(pid.last_error[0], 0.0);
    }

    #[test]
    fn test_pid_reset() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target).with_gains(1.0, 1.0, 1.0);

        let current = JointArray::from([Rad(0.5); 6]);
        let dt = Duration::from_secs(1);

        pid.tick(&current, dt).unwrap();

        assert!(pid.integral()[0] != 0.0);
        assert!(pid.last_error[0] != 0.0);

        pid.reset().unwrap();

        assert_eq!(pid.integral()[0], 0.0);
        assert_eq!(pid.last_error[0], 0.0);
    }

    #[test]
    fn test_pid_set_target() {
        let target1 = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target1);

        let target2 = JointArray::from([Rad(2.0); 6]);
        pid.set_target(target2);

        assert_eq!(pid.target()[0].0, 2.0);
    }

    #[test]
    fn test_pid_zero_dt() {
        let target = JointArray::from([Rad(1.0); 6]);
        let mut pid = PidController::new(target).with_gains(10.0, 1.0, 1.0);

        let current = JointArray::from([Rad(0.5); 6]);
        let dt = Duration::from_secs(0);

        let output = pid.tick(&current, dt).unwrap();
        assert_eq!(output[0].0, 0.0);

        // A rejected tick must leave the controller state untouched.
        assert_eq!(pid.integral()[0], 0.0);
        assert_eq!(pid.last_error[0], 0.0);
    }
}