Skip to main content

cu_pid/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2
3#[cfg(not(feature = "std"))]
4extern crate alloc;
5
6use bincode::de::Decoder;
7use bincode::enc::Encoder;
8use bincode::error::{DecodeError, EncodeError};
9use bincode::{Decode, Encode};
10use core::marker::PhantomData;
11use cu29::prelude::*;
12use cu29::reflect::{Reflect, ReflectTypePath};
13use serde::{Deserialize, Serialize};
14
15#[cfg(not(feature = "std"))]
16use alloc::format;
17
18/// Output of the PID controller.
19#[derive(Debug, Default, Clone, Encode, Decode, Serialize, Deserialize, Reflect)]
20pub struct PIDControlOutputPayload {
21    /// Proportional term
22    pub p: f32,
23    /// Integral term
24    pub i: f32,
25    /// Derivative term
26    pub d: f32,
27    /// Final output
28    pub output: f32,
29}
30
31/// This is the underlying standard PID controller.
32#[derive(Reflect)]
33pub struct PIDController {
34    // Configuration
35    kp: f32,
36    ki: f32,
37    kd: f32,
38    setpoint: f32,
39    p_limit: f32,
40    i_limit: f32,
41    d_limit: f32,
42    output_limit: f32,
43    sampling: CuDuration,
44    // Internal state
45    integral: f32,
46    last_error: f32,
47    elapsed: CuDuration,
48    last_output: PIDControlOutputPayload,
49}
50
51impl PIDController {
52    #[allow(clippy::too_many_arguments)]
53    pub fn new(
54        kp: f32,
55        ki: f32,
56        kd: f32,
57        setpoint: f32,
58        p_limit: f32,
59        i_limit: f32,
60        d_limit: f32,
61        output_limit: f32,
62        sampling: CuDuration, // to avoid oversampling and get a bunch of zeros.
63    ) -> Self {
64        PIDController {
65            kp,
66            ki,
67            kd,
68            setpoint,
69            integral: 0.0,
70            last_error: 0.0,
71            p_limit,
72            i_limit,
73            d_limit,
74            output_limit,
75            elapsed: CuDuration::default(),
76            sampling,
77            last_output: PIDControlOutputPayload::default(),
78        }
79    }
80
81    pub fn reset(&mut self) {
82        self.integral = 0.0f32;
83        self.last_error = 0.0f32;
84    }
85
86    pub fn reset_integral(&mut self) {
87        self.integral = 0.0f32;
88    }
89
90    pub fn init_measurement(&mut self, measurement: f32) {
91        self.last_error = self.setpoint - measurement;
92        self.elapsed = self.sampling; // force the computation on the first next_control_output
93    }
94
95    pub fn next_control_output(
96        &mut self,
97        measurement: f32,
98        dt: CuDuration,
99    ) -> PIDControlOutputPayload {
100        self.elapsed += dt;
101
102        if self.elapsed < self.sampling {
103            // if we bang too fast the PID controller, just keep on giving the same answer
104            return self.last_output.clone();
105        }
106
107        let error = self.setpoint - measurement;
108        let CuDuration(elapsed) = self.elapsed;
109        let dt = elapsed as f32 / 1_000_000f32; // the unit is kind of arbitrary.
110        if dt == 0.0 {
111            return self.last_output.clone();
112        }
113
114        // Proportional term
115        let p_unbounded = self.kp * error;
116        let p = p_unbounded.clamp(-self.p_limit, self.p_limit);
117
118        // Integral term (accumulated over time)
119        self.integral += error * dt;
120        let i_unbounded = self.ki * self.integral;
121        let i = i_unbounded.clamp(-self.i_limit, self.i_limit);
122
123        // Derivative term (rate of change)
124        let derivative = (error - self.last_error) / dt;
125        let d_unbounded = self.kd * derivative;
126        let d = d_unbounded.clamp(-self.d_limit, self.d_limit);
127
128        // Update last error for next calculation
129        self.last_error = error;
130
131        // Final output: sum of P, I, D with output limit
132        let output_unbounded = p + i + d;
133        let output = output_unbounded.clamp(-self.output_limit, self.output_limit);
134
135        let output = PIDControlOutputPayload { p, i, d, output };
136
137        self.last_output = output.clone();
138        self.elapsed = CuDuration::default();
139        output
140    }
141}
142
143/// This is the Copper task encapsulating the PID controller.
144#[derive(Reflect)]
145pub struct GenericPIDTask<I>
146where
147    f32: for<'a> From<&'a I>,
148{
149    #[reflect(ignore)]
150    _marker: PhantomData<fn() -> I>,
151    pid: PIDController,
152    first_run: bool,
153    last_tov: CuTime,
154    setpoint: f32,
155    cutoff: f32,
156}
157
158impl<I> CuTask for GenericPIDTask<I>
159where
160    f32: for<'a> From<&'a I>,
161    I: CuMsgPayload + ReflectTypePath + 'static,
162{
163    type Resources<'r> = ();
164    type Input<'m> = input_msg!(I);
165    type Output<'m> = output_msg!(PIDControlOutputPayload);
166
167    fn new(config: Option<&ComponentConfig>, _resources: Self::Resources<'_>) -> CuResult<Self>
168    where
169        Self: Sized,
170    {
171        match config {
172            Some(config) => {
173                debug!("PIDTask config loaded");
174                let setpoint: f32 = config
175                    .get::<f64>("setpoint")?
176                    .ok_or("'setpoint' not found in config")?
177                    as f32;
178
179                let cutoff: f32 = config.get::<f64>("cutoff")?.ok_or(
180                    "'cutoff' not found in config, please set an operating +/- limit on the input.",
181                )? as f32;
182
183                // p is mandatory
184                let kp = match config.get::<f64>("kp")? {
185                    Some(kp) => Ok(kp as f32),
186                    None => Err(CuError::from(
187                        "'kp' not found in the config. We need at least 'kp' to make the PID algorithm work.",
188                    )),
189                }?;
190
191                let p_limit = getcfg(config, "pl", 2.0f32)?;
192                let ki = getcfg(config, "ki", 0.0f32)?;
193                let i_limit = getcfg(config, "il", 1.0f32)?;
194                let kd = getcfg(config, "kd", 0.0f32)?;
195                let d_limit = getcfg(config, "dl", 2.0f32)?;
196                let output_limit = getcfg(config, "ol", 1.0f32)?;
197
198                let sampling = if let Some(value) = config.get::<u32>("sampling_ms")? {
199                    CuDuration::from(value as u64 * 1_000_000u64)
200                } else {
201                    CuDuration::default()
202                };
203
204                let pid: PIDController = PIDController::new(
205                    kp,
206                    ki,
207                    kd,
208                    setpoint,
209                    p_limit,
210                    i_limit,
211                    d_limit,
212                    output_limit,
213                    sampling,
214                );
215
216                Ok(Self {
217                    _marker: PhantomData,
218                    pid,
219                    first_run: true,
220                    last_tov: CuTime::default(),
221                    setpoint,
222                    cutoff,
223                })
224            }
225            None => Err(CuError::from("PIDTask needs a config.")),
226        }
227    }
228
229    fn process(
230        &mut self,
231        _ctx: &CuContext,
232        input: &Self::Input<'_>,
233        output: &mut Self::Output<'_>,
234    ) -> CuResult<()> {
235        output.tov = input.tov;
236        match input.payload() {
237            Some(payload) => {
238                let tov = match input.tov {
239                    Tov::Time(single) => single,
240                    _ => return Err("Unexpected variant for a TOV of PID".into()),
241                };
242
243                let measure: f32 = payload.into();
244
245                if self.first_run {
246                    self.first_run = false;
247                    self.last_tov = tov;
248                    self.pid.init_measurement(measure);
249                    output.clear_payload();
250                    return Ok(());
251                }
252                let dt = tov - self.last_tov;
253                self.last_tov = tov;
254
255                // update the status of the pid.
256                let state = self.pid.next_control_output(measure, dt);
257                // But safety check if the input is within operational margins and cut power if it is not.
258                let upper_limit = self.setpoint + self.cutoff;
259                let lower_limit = self.setpoint - self.cutoff;
260                if measure > upper_limit {
261                    return Err(format!("{} > {} (cutoff)", measure, upper_limit).into());
262                }
263                if measure < lower_limit {
264                    return Err(format!("{} < {} (cutoff)", measure, lower_limit).into());
265                }
266                output.metadata.set_status(format!(
267                    "{:>5.2} {:>5.2} {:>5.2} {:>5.2}",
268                    &state.output, &state.p, &state.i, &state.d
269                ));
270                output.set_payload(state);
271            }
272            None => output.clear_payload(),
273        };
274        Ok(())
275    }
276
277    fn stop(&mut self, _ctx: &CuContext) -> CuResult<()> {
278        self.pid.reset();
279        self.first_run = true;
280        Ok(())
281    }
282}
283
284/// Store/Restore the internal state of the PID controller.
285impl<I> Freezable for GenericPIDTask<I>
286where
287    f32: for<'a> From<&'a I>,
288{
289    fn freeze<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
290        Encode::encode(&self.pid.integral, encoder)?;
291        Encode::encode(&self.pid.last_error, encoder)?;
292        Encode::encode(&self.pid.elapsed, encoder)?;
293        Encode::encode(&self.pid.last_output, encoder)?;
294        Encode::encode(&self.first_run, encoder)?;
295        Encode::encode(&self.last_tov, encoder)?;
296        Ok(())
297    }
298
299    fn thaw<D: Decoder>(&mut self, decoder: &mut D) -> Result<(), DecodeError> {
300        self.pid.integral = Decode::decode(decoder)?;
301        self.pid.last_error = Decode::decode(decoder)?;
302        self.pid.elapsed = Decode::decode(decoder)?;
303        self.pid.last_output = Decode::decode(decoder)?;
304        self.first_run = Decode::decode(decoder)?;
305        self.last_tov = Decode::decode(decoder)?;
306        Ok(())
307    }
308}
309
310// Small helper befause we do this again and again
311fn getcfg(config: &ComponentConfig, key: &str, default: f32) -> Result<f32, ConfigError> {
312    Ok(config
313        .get::<f64>(key)?
314        .map(|value| value as f32)
315        .unwrap_or(default))
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use bincode::config::standard;
    use bincode::de::DecoderImpl;
    use bincode::de::read::SliceReader;
    use bincode::encode_to_vec;

    /// Dummy payload; every instance converts to a 0.0 measurement.
    #[derive(Clone, Copy)]
    struct TestInput;

    impl From<&TestInput> for f32 {
        fn from(_: &TestInput) -> Self {
            0.0
        }
    }

    /// Builds a task in which every field holds a distinct value, so that a field skipped or
    /// swapped during freeze/thaw shows up as a mismatch.
    fn sample_task() -> GenericPIDTask<TestInput> {
        let last_output = PIDControlOutputPayload {
            p: 13.0,
            i: 14.0,
            d: 15.0,
            output: 16.0,
        };
        let pid = PIDController {
            kp: 1.0,
            ki: 2.0,
            kd: 3.0,
            setpoint: 4.0,
            p_limit: 5.0,
            i_limit: 6.0,
            d_limit: 7.0,
            output_limit: 8.0,
            sampling: CuDuration::from(9),
            integral: 10.0,
            last_error: 11.0,
            elapsed: CuDuration::from(12),
            last_output,
        };
        GenericPIDTask {
            _marker: PhantomData,
            pid,
            first_run: false,
            last_tov: CuTime::from(17_u64),
            setpoint: 18.0,
            cutoff: 19.0,
        }
    }

    /// Overwrites every snapshotted field of `task` with values that differ from `sample_task`.
    fn scramble_frozen_state(task: &mut GenericPIDTask<TestInput>) {
        task.pid.integral = -1.0;
        task.pid.last_error = -2.0;
        task.pid.elapsed = CuDuration::from(999);
        task.pid.last_output = PIDControlOutputPayload {
            p: -3.0,
            i: -4.0,
            d: -5.0,
            output: -6.0,
        };
        task.first_run = true;
        task.last_tov = CuTime::from(1_000_u64);
    }

    #[test]
    fn freeze_thaw_restores_pid_timekeeping_state() {
        let snapshot_source = sample_task();
        let frozen = encode_to_vec(BincodeAdapter(&snapshot_source), standard())
            .expect("encode pid task state");

        let mut target = sample_task();
        scramble_frozen_state(&mut target);

        let mut decoder = DecoderImpl::new(SliceReader::new(&frozen), standard(), ());
        target.thaw(&mut decoder).expect("thaw pid task state");

        let (got, want) = (&target.pid, &snapshot_source.pid);
        assert_eq!(got.integral, want.integral);
        assert_eq!(got.last_error, want.last_error);
        assert_eq!(got.elapsed, want.elapsed);
        assert_eq!(got.last_output.p, want.last_output.p);
        assert_eq!(got.last_output.i, want.last_output.i);
        assert_eq!(got.last_output.d, want.last_output.d);
        assert_eq!(got.last_output.output, want.last_output.output);
        assert_eq!(target.first_run, snapshot_source.first_run);
        assert_eq!(target.last_tov, snapshot_source.last_tov);
    }
}