#![no_std] #![forbid(unsafe_code)]
/// Scalar type used by the controller: `f32` when the `f32` feature is enabled.
#[cfg(feature = "f32")]
pub type Float = f32;
/// Scalar type used by the controller: `f64` unless the `f32` feature is enabled.
#[cfg(not(feature = "f32"))]
pub type Float = f64;

/// Second-order (mass-spring-damper style) smoothing controller.
///
/// The output `y` follows `gain * setpoint` according to
///
/// ```text
/// y'' = w_n^2 * (gain * setpoint - y) - 2 * zeta * w_n * y'
/// ```
///
/// integrated with semi-implicit (symplectic) Euler: the velocity is updated
/// first and the output then moves with the new velocity. Optional symmetric
/// limits clamp the acceleration and the velocity of the output.
///
/// Without active limits the discrete update is stable only while
/// `w_n * dt < 2 * (sqrt(zeta^2 + 1) - zeta)`, e.g. `w_n * dt < 1.236` at
/// `zeta = 0.5` and `w_n * dt < 0.828` at `zeta = 1.0`; pick `dt` accordingly.
#[derive(Debug, Clone, Copy)]
pub struct So2Controller {
    /// Natural frequency, in radians per unit of `dt`.
    pub w_n: Float,
    /// Damping ratio: `< 1` underdamped, `1` critically damped, `> 1` overdamped.
    pub zeta: Float,
    /// Current output.
    y: Float,
    /// Current rate of change of the output, per unit of `dt`.
    v: Float,
    /// Most recent command; the output converges toward `gain * setpoint`.
    pub setpoint: Float,
    /// Scale applied to `setpoint` to form the tracked target.
    pub gain: Float,
    /// Optional bound on `|y'|`. The sign is ignored; `None` or NaN disables it.
    pub max_velocity: Option<Float>,
    /// Optional bound on `|y''|`. The sign is ignored; `None` or NaN disables it.
    pub max_acceleration: Option<Float>,
}

impl So2Controller {
    /// Largest time step integrated per `update` call; longer steps are truncated.
    const MAX_DT: Float = 0.1;

    /// Creates a controller at rest (zero velocity) at `initial_value`, with the
    /// setpoint also set to `initial_value` and no velocity/acceleration limits.
    #[inline]
    pub fn new(w_n: Float, zeta: Float, initial_value: Float, gain: Float) -> Self {
        Self {
            w_n,
            zeta,
            y: initial_value,
            v: 0.0,
            setpoint: initial_value,
            gain,
            max_velocity: None,
            max_acceleration: None,
        }
    }

    /// Advances the controller by `dt` toward `gain * input` and returns the new output.
    ///
    /// * `input` - new setpoint. Non-finite values (NaN / +-inf) are ignored and the
    ///   previous setpoint is kept, so a single bad sample cannot poison the state.
    /// * `dt` - elapsed time. Non-finite or non-positive values return the current
    ///   output without touching any state (the setpoint included). Values above
    ///   `0.1` are truncated to `0.1` to bound the integration step; the excess
    ///   time is dropped.
    #[inline]
    pub fn update(&mut self, input: Float, dt: Float) -> Float {
        if !dt.is_finite() || dt <= 0.0 {
            return self.y;
        }
        if input.is_finite() {
            self.setpoint = input;
        }
        let h = dt.min(Self::MAX_DT);

        let error = self.gain * self.setpoint - self.y;
        let accel = self.w_n * self.w_n * error - 2.0 * self.zeta * self.w_n * self.v;
        let accel = Self::clamp_magnitude(accel, self.max_acceleration);

        // Velocity is explicit state rather than a difference of past positions, so the
        // damping term stays correct when `dt` varies between calls, and the full
        // acceleration is applied (no hidden factor of 0.5 on w_n^2 or zeta).
        self.v = Self::clamp_magnitude(self.v + accel * h, self.max_velocity);
        self.y += self.v * h;
        self.y
    }

    /// Sets the setpoint without advancing time.
    ///
    /// `update` overwrites the setpoint with its `input` whenever that input is
    /// finite, so a value set here is only used if the next `update` receives a
    /// non-finite input.
    #[inline]
    pub fn set_target(&mut self, target: Float) {
        self.setpoint = target;
    }

    /// Moves the output to `value` at rest (zero velocity) and sets the setpoint to `value`.
    #[inline]
    pub fn reset(&mut self, value: Float) {
        self.y = value;
        self.v = 0.0;
        self.setpoint = value;
    }

    /// Limits `|y'|` to `|max_v|`.
    #[inline]
    pub fn set_max_velocity(&mut self, max_v: Float) {
        self.max_velocity = Some(max_v);
    }

    /// Limits `|y''|` to `|max_a|`.
    #[inline]
    pub fn set_max_acceleration(&mut self, max_a: Float) {
        self.max_acceleration = Some(max_a);
    }

    /// Current output (the value most recently returned by `update`, or the reset value).
    #[inline]
    #[must_use]
    pub fn output(&self) -> Float {
        self.y
    }

    /// Current rate of change of the output, per unit of `dt`.
    #[inline]
    #[must_use]
    pub fn velocity(&self) -> Float {
        self.v
    }

    /// Clamps `value` to `[-|bound|, |bound|]`; `None` or a NaN bound leaves it unchanged.
    #[inline]
    fn clamp_magnitude(value: Float, bound: Option<Float>) -> Float {
        match bound {
            None => value,
            Some(bound) => {
                // Written with plain comparisons so it only needs `core`; a NaN bound
                // fails every comparison and therefore disables the limit.
                let limit = if bound < 0.0 { -bound } else { bound };
                if value > limit {
                    limit
                } else if value < -limit {
                    -limit
                } else {
                    value
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;

    /// Feeds the same `command` to `ctrl` for `steps` ticks of length `dt` and
    /// returns the output produced by the last tick.
    fn run_for(ctrl: &mut So2Controller, command: Float, dt: Float, steps: usize) -> Float {
        (0..steps).fold(0.0, |_, _| ctrl.update(command, dt))
    }

    /// An underdamped (zeta = 0.5) controller settles near a step to 10.0
    /// after 200 ticks of 0.005.
    #[test]
    fn test_so2_stability() {
        let mut ctrl = So2Controller::new(20.0, 0.5, 0.0, 1.0);
        let settled = run_for(&mut ctrl, 10.0, 0.005, 200);
        assert!((settled - 10.0).abs() < 0.1);
    }

    /// A zero time step returns the current output unchanged.
    #[test]
    fn test_zero_dt_integrity() {
        let mut ctrl = So2Controller::new(10.0, 1.0, 5.0, 1.0);
        assert_eq!(ctrl.update(10.0, 0.0), 5.0);
    }

    /// A zeta = 0.7 controller tracks a unit step after 100 ticks of 0.01.
    #[test]
    fn test_setpoint_tracking() {
        let mut ctrl = So2Controller::new(10.0, 0.7, 0.0, 1.0);
        let settled = run_for(&mut ctrl, 1.0, 0.01, 100);
        assert!((settled - 1.0).abs() < 0.1);
    }
}