use crate::traits::FloatScalar;
#[derive(Debug, Clone, Copy)]
pub struct Pid<T> {
kp: T,
ki: T,
kd: T,
dt: T,
output_min: T,
output_max: T,
tau_d: T,
kb: T,
integral: T,
prev_error: T,
prev_measurement: T,
prev_derivative: T,
initialized: bool,
}
impl<T: FloatScalar> Pid<T> {
pub fn new(kp: T, ki: T, kd: T, dt: T) -> Self {
assert!(dt > T::zero() && dt.is_finite(), "dt must be positive and finite");
let kb = if kp != T::zero() { ki / kp } else { ki };
Self {
kp,
ki,
kd,
dt,
output_min: T::neg_infinity(),
output_max: T::infinity(),
tau_d: T::zero(),
kb,
integral: T::zero(),
prev_error: T::zero(),
prev_measurement: T::zero(),
prev_derivative: T::zero(),
initialized: false,
}
}
pub fn with_output_limits(mut self, min: T, max: T) -> Self {
assert!(min < max, "output_min must be less than output_max");
self.output_min = min;
self.output_max = max;
self
}
pub fn with_derivative_filter(mut self, tau: T) -> Self {
assert!(!(tau < T::zero()), "derivative filter time constant must be non-negative");
self.tau_d = tau;
self
}
pub fn with_back_calculation_gain(mut self, kb: T) -> Self {
self.kb = kb;
self
}
#[inline]
pub fn tick(&mut self, setpoint: T, measurement: T) -> T {
let error = setpoint - measurement;
let two = T::one() + T::one();
let p_term = self.kp * error;
if self.initialized {
self.integral = self.integral + self.ki * (error + self.prev_error) * self.dt / two;
}
let d_raw = if self.initialized {
-(measurement - self.prev_measurement) / self.dt
} else {
T::zero()
};
let d_filtered = if self.tau_d > T::zero() && self.initialized {
let alpha = self.dt / (self.tau_d + self.dt);
self.prev_derivative + alpha * (d_raw - self.prev_derivative)
} else {
d_raw
};
let d_term = self.kd * d_filtered;
let u_unclamped = p_term + self.integral + d_term;
let u_clamped = if u_unclamped > self.output_max {
self.output_max
} else if u_unclamped < self.output_min {
self.output_min
} else {
u_unclamped
};
if self.initialized {
self.integral =
self.integral + self.kb * (u_clamped - u_unclamped) * self.dt;
}
self.prev_error = error;
self.prev_measurement = measurement;
self.prev_derivative = d_filtered;
self.initialized = true;
u_clamped
}
pub fn reset(&mut self) {
self.integral = T::zero();
self.prev_error = T::zero();
self.prev_measurement = T::zero();
self.prev_derivative = T::zero();
self.initialized = false;
}
pub fn gains(&self) -> (T, T, T) {
(self.kp, self.ki, self.kd)
}
pub fn set_gains(&mut self, kp: T, ki: T, kd: T) {
self.kp = kp;
self.ki = ki;
self.kd = kd;
}
pub fn integral(&self) -> T {
self.integral
}
pub fn set_integral(&mut self, value: T) {
self.integral = value;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons against closed-form expected values.
    const TOL: f64 = 1e-12;

    /// Asserts `|a - b| < tol`, reporting both values and their difference.
    fn assert_near(a: f64, b: f64, tol: f64, msg: &str) {
        assert!(
            (a - b).abs() < tol,
            "{}: {} vs {} (diff {})",
            msg,
            a,
            b,
            (a - b).abs()
        );
    }

    #[test]
    fn test_new() {
        let pid = Pid::new(1.0, 2.0, 3.0, 0.01);
        assert_eq!(pid.gains(), (1.0, 2.0, 3.0));
        assert_eq!(pid.integral(), 0.0);
    }

    #[test]
    #[should_panic]
    fn test_zero_dt_panics() {
        Pid::new(1.0, 0.0, 0.0, 0.0);
    }

    #[test]
    #[should_panic]
    fn test_negative_dt_panics() {
        Pid::new(1.0, 0.0, 0.0, -0.01);
    }

    #[test]
    #[should_panic]
    fn test_invalid_limits_panics() {
        Pid::new(1.0, 0.0, 0.0, 0.01).with_output_limits(5.0, 5.0);
    }

    #[test]
    #[should_panic]
    fn test_invalid_limits_reversed_panics() {
        Pid::new(1.0, 0.0, 0.0, 0.01).with_output_limits(5.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_negative_tau_panics() {
        Pid::new(1.0, 0.0, 0.0, 0.01).with_derivative_filter(-0.01);
    }

    #[test]
    fn test_builder_chaining() {
        let pid = Pid::new(1.0, 0.5, 0.1, 0.01)
            .with_output_limits(-10.0, 10.0)
            .with_derivative_filter(0.02)
            .with_back_calculation_gain(1.5);
        assert_eq!(pid.gains(), (1.0, 0.5, 0.1));
    }

    #[test]
    fn test_p_only() {
        let mut pid = Pid::new(2.5, 0.0, 0.0, 0.01);
        let u = pid.tick(10.0, 3.0);
        assert_near(u, 2.5 * 7.0, TOL, "P-only output");
    }

    #[test]
    fn test_p_only_negative_error() {
        let mut pid = Pid::new(3.0, 0.0, 0.0, 0.01);
        let u = pid.tick(1.0, 5.0);
        assert_near(u, 3.0 * (-4.0), TOL, "P-only negative error");
    }

    #[test]
    fn test_p_only_zero_error() {
        let mut pid = Pid::new(3.0, 0.0, 0.0, 0.01);
        let u = pid.tick(5.0, 5.0);
        assert_near(u, 0.0, TOL, "P-only zero error");
    }

    #[test]
    fn test_i_only_trapezoidal() {
        let dt = 0.01;
        let ki = 2.0;
        let mut pid = Pid::new(0.0, ki, 0.0, dt);
        // No previous error on the first tick, so nothing is integrated.
        let u1 = pid.tick(10.0, 0.0);
        assert_near(u1, 0.0, TOL, "I-only first tick (no integration)");
        // ki * (10 + 10) * dt / 2 = 0.2 per tick.
        let u2 = pid.tick(10.0, 0.0);
        assert_near(u2, 0.2, TOL, "I-only second tick");
        assert_near(pid.integral(), 0.2, TOL, "integral value after 2 ticks");
        let u3 = pid.tick(10.0, 0.0);
        assert_near(u3, 0.4, TOL, "I-only third tick");
    }

    #[test]
    fn test_i_only_varying_error() {
        let dt = 0.1;
        let ki = 1.0;
        let mut pid = Pid::new(0.0, ki, 0.0, dt);
        pid.tick(10.0, 0.0);
        // ki * (5 + 10) * dt / 2 = 0.75.
        let u = pid.tick(5.0, 0.0);
        assert_near(u, 0.75, TOL, "I-only varying error trapezoidal");
    }

    #[test]
    fn test_d_only_step_on_measurement() {
        let dt = 0.01;
        let kd = 0.5;
        let mut pid = Pid::new(0.0, 0.0, kd, dt);
        let u1 = pid.tick(0.0, 0.0);
        assert_near(u1, 0.0, TOL, "D first tick zero");
        // -kd * (1 - 0) / dt = -50.
        let u2 = pid.tick(0.0, 1.0);
        assert_near(u2, -50.0, TOL, "D step on measurement");
        let u3 = pid.tick(0.0, 1.0);
        assert_near(u3, 0.0, TOL, "D constant measurement");
    }

    #[test]
    fn test_d_no_derivative_kick_on_setpoint_change() {
        let dt = 0.01;
        let kd = 1.0;
        let mut pid = Pid::new(0.0, 0.0, kd, dt);
        pid.tick(0.0, 5.0);
        // Only the setpoint moves; the derivative acts on the measurement.
        let u = pid.tick(100.0, 5.0);
        assert_near(u, 0.0, TOL, "no derivative kick on setpoint change");
    }

    #[test]
    fn test_d_filter_smoothing() {
        let dt = 0.01;
        let kd = 1.0;
        let tau = 0.05;
        let mut pid = Pid::new(0.0, 0.0, kd, dt).with_derivative_filter(tau);
        pid.tick(0.0, 0.0);
        // Raw derivative -1 / dt = -100, alpha = dt / (tau + dt) = 1/6.
        let u = pid.tick(0.0, 1.0);
        let expected = kd * (-100.0 / 6.0);
        assert_near(u, expected, TOL, "D filtered step (attenuated)");
        assert!(u.abs() < 100.0, "filter reduces derivative impulse");
    }

    #[test]
    fn test_pid_first_order_plant_convergence() {
        let dt = 0.001;
        let mut pid = Pid::new(5.0, 10.0, 0.1, dt);
        let setpoint = 1.0;
        let mut y = 0.0;
        // Forward-Euler simulation of the plant y' = -y + u.
        for _ in 0..10_000 {
            let u = pid.tick(setpoint, y);
            y = y + dt * (-y + u);
        }
        assert_near(y, setpoint, 1e-4, "PID converges to setpoint");
    }

    #[test]
    fn test_pid_step_response_integrator_eliminates_offset() {
        let dt = 0.001;
        let mut pid = Pid::new(1.0, 5.0, 0.0, dt);
        let setpoint = 1.0;
        let mut y = 0.0;
        for _ in 0..20_000 {
            let u = pid.tick(setpoint, y);
            y = y + dt * (-y + u);
        }
        assert_near(y, setpoint, 1e-4, "PI eliminates steady-state error");
    }

    #[test]
    fn test_output_clamped_upper() {
        let mut pid = Pid::new(100.0, 0.0, 0.0, 0.01).with_output_limits(-5.0, 5.0);
        let u = pid.tick(10.0, 0.0);
        assert_near(u, 5.0, TOL, "clamped to upper limit");
    }

    #[test]
    fn test_output_clamped_lower() {
        let mut pid = Pid::new(100.0, 0.0, 0.0, 0.01).with_output_limits(-5.0, 5.0);
        let u = pid.tick(0.0, 10.0);
        assert_near(u, -5.0, TOL, "clamped to lower limit");
    }

    #[test]
    fn test_output_within_range_not_clamped() {
        let mut pid = Pid::new(1.0, 0.0, 0.0, 0.01).with_output_limits(-100.0, 100.0);
        let u = pid.tick(5.0, 3.0);
        assert_near(u, 2.0, TOL, "within range, not clamped");
    }

    #[test]
    fn test_anti_windup_faster_recovery() {
        let dt: f64 = 0.01;
        let kp = 1.0;
        let ki = 10.0;
        let limit = 1.0;
        let setpoint = 10.0;
        let mut pid_aw = Pid::new(kp, ki, 0.0, dt).with_output_limits(-limit, limit);
        let mut pid_no_aw = Pid::new(kp, ki, 0.0, dt)
            .with_output_limits(-limit, limit)
            .with_back_calculation_gain(0.0);

        // Windup phase: an unreachable setpoint keeps both outputs saturated.
        for _ in 0..100 {
            pid_aw.tick(setpoint, 0.0);
            pid_no_aw.tick(setpoint, 0.0);
        }
        assert!(
            pid_no_aw.integral().abs() > pid_aw.integral().abs(),
            "without anti-windup, integral winds up more: {} vs {}",
            pid_no_aw.integral(),
            pid_aw.integral()
        );

        // Recovery phase: the setpoint drops to 0. Count the ticks until each
        // output leaves the upper saturation limit. The plant starts at
        // y = 0 here, so a threshold on |y| alone would be met at i = 0 by
        // both controllers and would not discriminate anything.
        let mut y_aw = 0.0;
        let mut y_no_aw = 0.0;
        let mut recovery_aw = usize::MAX;
        let mut recovery_no_aw = usize::MAX;
        for i in 0..1000 {
            let u_aw = pid_aw.tick(0.0, y_aw);
            let u_no_aw = pid_no_aw.tick(0.0, y_no_aw);
            y_aw = y_aw + dt * (-y_aw + u_aw);
            y_no_aw = y_no_aw + dt * (-y_no_aw + u_no_aw);
            if recovery_aw == usize::MAX && u_aw < limit {
                recovery_aw = i;
            }
            if recovery_no_aw == usize::MAX && u_no_aw < limit {
                recovery_no_aw = i;
            }
        }
        assert!(
            recovery_aw != usize::MAX,
            "anti-windup controller never left saturation"
        );
        assert!(
            recovery_aw < recovery_no_aw,
            "anti-windup should recover faster: {} vs {}",
            recovery_aw,
            recovery_no_aw
        );
    }

    #[test]
    fn test_reset_clears_state() {
        let mut pid = Pid::new(1.0, 1.0, 1.0, 0.01);
        pid.tick(10.0, 0.0);
        pid.tick(10.0, 1.0);
        assert!(pid.integral() != 0.0);
        pid.reset();
        assert_eq!(pid.integral(), 0.0);
        assert_eq!(pid.gains(), (1.0, 1.0, 1.0));
    }

    #[test]
    fn test_reset_preserves_config() {
        let mut pid = Pid::new(2.0, 3.0, 4.0, 0.005)
            .with_output_limits(-10.0, 10.0)
            .with_derivative_filter(0.02);
        pid.tick(5.0, 0.0);
        pid.reset();
        let u = pid.tick(5.0, 0.0);
        assert_near(u, 2.0 * 5.0, TOL, "after reset, first tick is P-only");
    }

    #[test]
    fn test_set_gains() {
        let mut pid = Pid::new(1.0, 0.0, 0.0, 0.01);
        pid.set_gains(5.0, 2.0, 1.0);
        assert_eq!(pid.gains(), (5.0, 2.0, 1.0));
        let u = pid.tick(3.0, 0.0);
        assert_near(u, 5.0 * 3.0, TOL, "new kp takes effect");
    }

    #[test]
    fn test_set_integral() {
        let mut pid = Pid::new(1.0, 1.0, 0.0, 0.01);
        pid.set_integral(5.0);
        assert_eq!(pid.integral(), 5.0);
        // First tick: P term 1.0 plus the preloaded integral 5.0.
        let u = pid.tick(1.0, 0.0);
        assert_near(u, 6.0, TOL, "set_integral adds to output");
    }

    #[test]
    fn test_f32() {
        let mut pid = Pid::new(1.0_f32, 0.5, 0.1, 0.01)
            .with_output_limits(-10.0, 10.0)
            .with_derivative_filter(0.02);
        let u = pid.tick(5.0_f32, 0.0);
        assert!((u - 5.0_f32).abs() < 1e-5, "f32 P-only output");
        for _ in 0..10 {
            pid.tick(5.0, u);
        }
        assert!(pid.integral().is_finite(), "f32 integral is finite");
    }
}