1#![forbid(unsafe_code)]
/// Tuning coefficients for a [`PidController`].
///
/// [`PidController::new`] rejects any coefficient that is NaN or infinite.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PidGains {
    /// Proportional gain, applied to the current error.
    pub kp: f64,
    /// Integral gain, applied to the accumulated `error * dt`.
    pub ki: f64,
    /// Derivative gain, applied to the change in error divided by `dt`.
    pub kd: f64,
}
28
/// Memory carried between successive [`PidController::update`] calls.
///
/// `Default` yields the zeroed state (both fields `0.0`), the same values a
/// controller is given by [`PidController::new`] and [`PidController::reset`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct PidState {
    /// Error (`setpoint - measured`) recorded by the most recent successful update.
    pub previous_error: f64,
    /// Running sum of `error * dt` accumulated by updates.
    pub integral: f64,
}
34
35#[derive(Debug, Clone, Copy, PartialEq)]
36pub struct PidController {
37 pub gains: PidGains,
38 pub state: PidState,
39}
40
/// Reasons a [`PidController`] operation can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PidError {
    /// At least one gain was NaN or infinite.
    InvalidGains,
    /// `setpoint` or `measured` was NaN or infinite, or their difference overflowed.
    InvalidSignal,
    /// `dt` was NaN, infinite, zero, or negative.
    InvalidTimestep,
    /// The integral, derivative, or combined output was not finite.
    NonFiniteOutput,
}

impl std::fmt::Display for PidError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            PidError::InvalidGains => "PID gains must all be finite",
            PidError::InvalidSignal => {
                "setpoint and measurement must be finite with a finite difference"
            }
            PidError::InvalidTimestep => "timestep must be finite and strictly positive",
            PidError::NonFiniteOutput => "PID update produced a non-finite value",
        };
        f.write_str(message)
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for PidError {}
48
49impl PidController {
50 pub fn new(gains: PidGains) -> Result<Self, PidError> {
51 if !gains.kp.is_finite() || !gains.ki.is_finite() || !gains.kd.is_finite() {
52 return Err(PidError::InvalidGains);
53 }
54
55 Ok(Self {
56 gains,
57 state: PidState {
58 previous_error: 0.0,
59 integral: 0.0,
60 },
61 })
62 }
63
64 pub fn update(&mut self, setpoint: f64, measured: f64, dt: f64) -> Result<f64, PidError> {
65 if !setpoint.is_finite() || !measured.is_finite() {
66 return Err(PidError::InvalidSignal);
67 }
68
69 if !dt.is_finite() || dt <= 0.0 {
70 return Err(PidError::InvalidTimestep);
71 }
72
73 let current_error = setpoint - measured;
74 if !current_error.is_finite() {
75 return Err(PidError::InvalidSignal);
76 }
77
78 self.state.integral += current_error * dt;
79 let derivative = (current_error - self.state.previous_error) / dt;
80 let output = self.gains.kp * current_error
81 + self.gains.ki * self.state.integral
82 + self.gains.kd * derivative;
83
84 if !self.state.integral.is_finite() || !derivative.is_finite() || !output.is_finite() {
85 return Err(PidError::NonFiniteOutput);
86 }
87
88 self.state.previous_error = current_error;
89 Ok(output)
90 }
91
92 pub fn reset(&mut self) {
93 self.state = PidState {
94 previous_error: 0.0,
95 integral: 0.0,
96 };
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::{PidController, PidError, PidGains};

    /// Builds a controller from gains the test knows to be finite.
    fn make_controller(kp: f64, ki: f64, kd: f64) -> PidController {
        PidController::new(PidGains { kp, ki, kd }).expect("test gains are finite")
    }

    #[test]
    fn proportional_only_behavior_stays_simple() {
        let mut controller = make_controller(2.0, 0.0, 0.0);

        assert_eq!(controller.update(5.0, 3.0, 0.5).unwrap(), 4.0);
    }

    #[test]
    fn integral_term_accumulates_over_time() {
        let mut controller = make_controller(0.0, 1.0, 0.0);

        assert_eq!(controller.update(2.0, 0.0, 0.5).unwrap(), 1.0);
        assert_eq!(controller.update(2.0, 0.0, 0.5).unwrap(), 2.0);
        assert_eq!(controller.state.integral, 2.0);
    }

    #[test]
    fn derivative_term_uses_error_difference_over_dt() {
        let mut controller = make_controller(0.0, 0.0, 1.0);

        // First step differentiates against the initial previous error of 0.0.
        assert_eq!(controller.update(1.0, 0.0, 0.5).unwrap(), 2.0);
        // (2.0 - 1.0) / 0.25
        assert_eq!(controller.update(2.0, 0.0, 0.25).unwrap(), 4.0);
        assert_eq!(controller.state.previous_error, 2.0);
    }

    #[test]
    fn reset_clears_integral_and_previous_error() {
        let mut controller = make_controller(1.0, 1.0, 1.0);
        controller.update(4.0, 1.0, 0.5).unwrap();

        controller.reset();

        assert_eq!(controller.state.previous_error, 0.0);
        assert_eq!(controller.state.integral, 0.0);
    }

    #[test]
    fn rejects_invalid_inputs() {
        assert_eq!(
            PidController::new(PidGains {
                kp: f64::NAN,
                ki: 0.0,
                kd: 0.0,
            }),
            Err(PidError::InvalidGains)
        );

        let mut controller = make_controller(1.0, 0.0, 0.0);

        assert_eq!(
            controller.update(1.0, 0.0, 0.0),
            Err(PidError::InvalidTimestep)
        );
        assert_eq!(
            controller.update(f64::NAN, 0.0, 1.0),
            Err(PidError::InvalidSignal)
        );
    }

    #[test]
    fn rejects_infinite_gains() {
        assert_eq!(
            PidController::new(PidGains {
                kp: 0.0,
                ki: f64::INFINITY,
                kd: 0.0,
            }),
            Err(PidError::InvalidGains)
        );
        assert_eq!(
            PidController::new(PidGains {
                kp: 0.0,
                ki: 0.0,
                kd: f64::NEG_INFINITY,
            }),
            Err(PidError::InvalidGains)
        );
    }

    #[test]
    fn rejects_non_positive_and_non_finite_timesteps() {
        let mut controller = make_controller(1.0, 0.0, 0.0);

        for dt in [-0.5, -0.0, f64::NAN, f64::INFINITY] {
            assert_eq!(
                controller.update(1.0, 0.0, dt),
                Err(PidError::InvalidTimestep),
                "dt = {dt}"
            );
        }
    }

    #[test]
    fn rejects_invalid_signals_before_checking_timestep() {
        let mut controller = make_controller(1.0, 0.0, 0.0);

        assert_eq!(
            controller.update(0.0, f64::NAN, 1.0),
            Err(PidError::InvalidSignal)
        );
        assert_eq!(
            controller.update(f64::INFINITY, 0.0, 1.0),
            Err(PidError::InvalidSignal)
        );
        // Finite inputs whose difference overflows.
        assert_eq!(
            controller.update(f64::MAX, -f64::MAX, 1.0),
            Err(PidError::InvalidSignal)
        );
        // The signal check runs first even when dt is also invalid.
        assert_eq!(
            controller.update(f64::NAN, 0.0, 0.0),
            Err(PidError::InvalidSignal)
        );
    }

    #[test]
    fn reports_non_finite_output() {
        // Integral overflow: MAX * 2.0 is infinite.
        let mut integrating = make_controller(0.0, 1.0, 0.0);
        assert_eq!(
            integrating.update(f64::MAX, 0.0, 2.0),
            Err(PidError::NonFiniteOutput)
        );

        // Derivative overflow: MAX / 0.5 is infinite.
        let mut differentiating = make_controller(0.0, 0.0, 1.0);
        assert_eq!(
            differentiating.update(f64::MAX, 0.0, 0.5),
            Err(PidError::NonFiniteOutput)
        );
    }
}