use num_traits::{Float, Signed};
/// Discrete PID controller with optional clamping of the integral state and
/// of the output.
///
/// Each call to [`PIDController::compute_correction`] is one control step.
/// Construct one with [`PIDController::new`],
/// [`PIDController::new_static_controller`] or [`PIDControllerBuilder`].
#[derive(Debug, Clone)]
pub struct PIDController<T> {
    /// Target value; each step's error is `setpoint - signal`.
    setpoint: T,
    /// Proportional gain.
    kp: T,
    /// Integral gain.
    ki: T,
    /// Derivative gain, applied to the change in error between calls.
    kd: T,
    /// Asymmetric weight applied to the error before it is integrated:
    /// positive errors are scaled by `1 + error_bias`, all others by
    /// `1 - error_bias`.
    error_bias: T,
    /// If set, the accumulated error is clamped to `[-|limit|, |limit|]`
    /// after each new error is added.
    error_limit: Option<T>,
    /// If set, the returned correction is clamped to `[-|limit|, |limit|]`.
    output_limit: Option<T>,
    /// Integral state: running sum of bias-weighted errors, also adjusted by
    /// the output-limit anti-windup feedback.
    accumulated_error: T,
    /// Error from the previous step (zero initially), used by the derivative
    /// term.
    previous_error: T,
}
impl<T: Float + Signed + Copy> PIDController<T> {
    /// Creates a controller with every gain, bias and limit given explicitly.
    ///
    /// The integral state and the previous error both start at zero.
    ///
    /// * `setpoint` - target value; each step's error is `setpoint - signal`.
    /// * `kp`, `ki`, `kd` - proportional, integral and derivative gains.
    /// * `error_bias` - asymmetric weight applied to the error before it is
    ///   integrated: positive errors are scaled by `1 + error_bias`, all others
    ///   by `1 - error_bias`. A bias of `0` integrates the raw error.
    /// * `error_limit` - if set, the accumulated error is kept within
    ///   `[-|limit|, |limit|]` (only the magnitude is used).
    /// * `output_limit` - if set, the returned correction is clamped to
    ///   `[-|limit|, |limit|]` (only the magnitude is used).
    pub fn new(
        setpoint: T,
        kp: T,
        ki: T,
        kd: T,
        error_bias: T,
        error_limit: Option<T>,
        output_limit: Option<T>,
    ) -> Self {
        PIDController {
            setpoint,
            kp,
            ki,
            kd,
            error_limit,
            output_limit,
            accumulated_error: T::zero(),
            previous_error: T::zero(),
            error_bias,
        }
    }

    /// Creates a controller with all gains zero, no limits and an
    /// `error_bias` of one.
    ///
    /// Because every gain is zero, the correction is zero for finite inputs.
    /// The integral state is still updated on every call.
    ///
    /// NOTE(review): with `error_bias == 1`, non-positive errors are scaled by
    /// `1 - 1 = 0` and never reach the integral. Confirm this default is
    /// intended.
    pub fn new_static_controller(setpoint: T) -> Self {
        PIDController {
            setpoint,
            kp: T::zero(),
            ki: T::zero(),
            kd: T::zero(),
            error_limit: None,
            output_limit: None,
            accumulated_error: T::zero(),
            previous_error: T::zero(),
            error_bias: T::one(),
        }
    }

    /// Advances the controller by one step and returns the correction for
    /// `signal`.
    ///
    /// The correction is `kp * e + ki * I + kd * (e - e_prev)`, where:
    /// * `e` is `setpoint - signal`;
    /// * `I` is the accumulated, bias-weighted error, clamped to
    ///   `error_limit`;
    /// * `e_prev` is the error from the previous call (zero before the first
    ///   call).
    ///
    /// No time step is involved, so the gains are per call. The result is
    /// clamped to `output_limit` if one is set.
    ///
    /// When the output clamp is active, the excess is fed back into the
    /// integral (back-calculation anti-windup). The integral is then
    /// re-clamped to `error_limit`. When `ki` is zero the integral cannot
    /// influence the output, so this feedback is skipped. Without that check
    /// it would divide by zero and poison the state with `inf`/`NaN`.
    pub fn compute_correction(&mut self, signal: impl Into<T>) -> T {
        let error = self.setpoint - signal.into();
        let p = self.kp * error;

        let biased_error = if error.is_positive() {
            error * (T::one() + self.error_bias)
        } else {
            error * (T::one() - self.error_bias)
        };
        self.accumulated_error =
            Self::clamp_to_limit(self.accumulated_error + biased_error, self.error_limit);

        let i = self.ki * self.accumulated_error;
        let d = self.kd * (error - self.previous_error);
        let correction = p + i + d;
        let clamped_correction = Self::clamp_to_limit(correction, self.output_limit);

        if correction != clamped_correction && self.ki != T::zero() {
            let feedback = correction - clamped_correction;
            // Re-apply the integral bound so the feedback cannot wind the
            // integral past the limit in the opposite direction.
            self.accumulated_error = Self::clamp_to_limit(
                self.accumulated_error - feedback / self.ki,
                self.error_limit,
            );
        }

        self.previous_error = error;
        clamped_correction
    }

    /// Returns the current integral state (the accumulated, bias-weighted
    /// error).
    pub fn accumulated_error(&self) -> T {
        self.accumulated_error
    }

    /// Returns the target value the controller drives toward.
    pub fn setpoint(&self) -> T {
        self.setpoint
    }

    /// Clamps `value` to `[-|limit|, |limit|]`, or returns it unchanged when
    /// `limit` is `None`.
    fn clamp_to_limit(value: T, limit: Option<T>) -> T {
        match limit {
            Some(limit) => {
                let bound = limit.abs();
                num_traits::clamp(value, -bound, bound)
            }
            None => value,
        }
    }
}
/// Fluent builder for [`PIDController`].
///
/// Start with [`PIDControllerBuilder::new`], chain setters for the gains,
/// bias and limits you need, then call [`PIDControllerBuilder::build`].
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct PIDControllerBuilder<T> {
    /// Target value for the controller being built.
    setpoint: T,
    /// Proportional gain.
    kp: T,
    /// Integral gain.
    ki: T,
    /// Derivative gain.
    kd: T,
    /// Asymmetric weight applied to the error before integration.
    error_bias: T,
    /// Optional symmetric bound on the accumulated error.
    error_limit: Option<T>,
    /// Optional symmetric bound on the returned correction.
    output_limit: Option<T>,
}
impl<T: Float + Signed + Copy> PIDControllerBuilder<T> {
pub fn new(setpoint: impl Into<T>) -> Self {
PIDControllerBuilder {
setpoint: setpoint.into(),
kp: T::zero(),
ki: T::zero(),
kd: T::zero(),
error_bias: T::one(),
error_limit: None,
output_limit: None,
}
}
pub fn kp(mut self, kp: impl Into<T>) -> Self {
self.kp = kp.into();
self
}
pub fn ki(mut self, ki: impl Into<T>) -> Self {
self.ki = ki.into();
self
}
pub fn kd(mut self, kd: impl Into<T>) -> Self {
self.kd = kd.into();
self
}
pub fn error_bias(mut self, error_bias: impl Into<T>) -> Self {
self.error_bias = error_bias.into();
self
}
pub fn error_limit(mut self, error_limit: impl Into<T>) -> Self {
self.error_limit = Some(error_limit.into());
self
}
pub fn output_limit(mut self, output_limit: impl Into<T>) -> Self {
self.output_limit = Some(output_limit.into());
self
}
pub fn build(self) -> PIDController<T> {
PIDController {
setpoint: self.setpoint,
kp: self.kp,
ki: self.ki,
kd: self.kd,
error_bias: self.error_bias,
error_limit: self.error_limit,
output_limit: self.output_limit,
accumulated_error: T::zero(),
previous_error: T::zero(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a controller through [`PIDControllerBuilder`], applying each
    /// optional limit only when it is `Some`.
    fn create_pid_controller<T: Float + Signed + Copy>(
        setpoint: T,
        kp: T,
        ki: T,
        kd: T,
        error_bias: T,
        error_limit: Option<T>,
        output_limit: Option<T>,
    ) -> PIDController<T> {
        let mut pid_controller_builder = PIDControllerBuilder::new(setpoint)
            .kp(kp)
            .ki(ki)
            .kd(kd)
            .error_bias(error_bias);
        if let Some(error_limit) = error_limit {
            pid_controller_builder = pid_controller_builder.error_limit(error_limit);
        }
        if let Some(output_limit) = output_limit {
            pid_controller_builder = pid_controller_builder.output_limit(output_limit);
        }
        pid_controller_builder.build()
    }

    #[test]
    fn test_pid_initialization() {
        let pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, Some(10.0), Some(5.0));
        assert_eq!(pid.setpoint, 1.0);
        assert_eq!(pid.kp, 2.0);
        assert_eq!(pid.ki, 3.0);
        assert_eq!(pid.kd, 4.0);
        assert_eq!(pid.error_bias, 0.5);
        assert_eq!(pid.error_limit, Some(10.0));
        assert_eq!(pid.output_limit, Some(5.0));
        assert_eq!(pid.accumulated_error, 0.0);
        assert_eq!(pid.previous_error, 0.0);
    }

    #[test]
    fn test_new_matches_builder() {
        let direct = PIDController::new(1.0, 2.0, 3.0, 4.0, 0.5, Some(10.0), Some(5.0));
        let built = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, Some(10.0), Some(5.0));
        assert_eq!(format!("{:?}", direct), format!("{:?}", built));
    }

    #[test]
    fn test_pid_compute_correction() {
        // e = 0.5: p = 2 * 0.5, i = 3 * (0.5 * 1.5), d = 4 * (0.5 - 0).
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, None, None);
        let correction = pid.compute_correction(0.5);
        assert_eq!(correction, 5.25);
    }

    #[test]
    fn test_pid_compute_correction_with_error_limit() {
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, Some(0.1), None);
        let correction = pid.compute_correction(0.5);
        assert!(correction > 0.0);
        // The raw integral (0.75) is clamped to the 0.1 limit.
        assert_eq!(pid.accumulated_error, 0.1);
    }

    #[test]
    fn test_pid_negative_error_limit_uses_magnitude() {
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, Some(-0.1), None);
        pid.compute_correction(0.5);
        assert_eq!(pid.accumulated_error(), 0.1);
    }

    #[test]
    fn test_pid_compute_correction_with_output_limit() {
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, None, Some(0.1));
        let correction = pid.compute_correction(0.5);
        assert_eq!(correction, 0.1);
    }

    #[test]
    fn test_pid_compute_correction_with_output_limit_negative() {
        // e = -0.5: raw correction is -3.75, clamped to the lower bound.
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, None, Some(0.1));
        let correction = pid.compute_correction(1.5);
        assert_eq!(correction, -0.1);
    }

    #[test]
    fn test_pid_zero_gains() {
        let mut pid = create_pid_controller(1.0, 0.0, 0.0, 0.0, 0.0, None, None);
        let correction = pid.compute_correction(0.5);
        assert_eq!(correction, 0.0);
    }

    #[test]
    fn test_static_controller_outputs_zero() {
        let mut pid = PIDController::<f64>::new_static_controller(1.0);
        assert_eq!(pid.compute_correction(0.5), 0.0);
        assert_eq!(pid.setpoint(), 1.0);
    }

    #[test]
    fn test_pid_negative_feedback() {
        let mut pid = create_pid_controller(1.0, -2.0, -3.0, -4.0, 0.5, None, None);
        let correction = pid.compute_correction(0.5);
        assert_eq!(correction, -5.25);
    }

    #[test]
    fn test_pid_anti_windup() {
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, Some(0.1), Some(0.5));
        pid.compute_correction(0.5);
        let correction = pid.compute_correction(0.5);
        assert_eq!(correction, 0.5);
    }

    #[test]
    fn test_pid_accumulated_error() {
        let mut pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, None, None);
        pid.compute_correction(0.5);
        // Bias-weighted error: 0.5 * (1 + 0.5).
        assert_eq!(pid.accumulated_error(), 0.75);
    }

    #[test]
    fn test_pid_setpoint() {
        let pid = create_pid_controller(1.0, 2.0, 3.0, 4.0, 0.5, None, None);
        assert_eq!(pid.setpoint(), 1.0);
    }
}