use crate::math::MulAdd;
macro_rules! impl_kalman1d {
    ($name:ident, $builder:ident, $ty:ty) => {
        /// One-dimensional constant-velocity Kalman filter.
        ///
        /// Tracks the state `[position, velocity]` assuming one time unit between
        /// consecutive measurements (the prediction is `position += velocity`).
        /// The symmetric 2x2 state covariance is stored as its upper triangle
        /// (`p00`, `p01`, `p11`). Construct instances with `builder()`.
        #[derive(Debug, Clone)]
        pub struct $name {
            // State: [position, velocity]
            x0: $ty,
            x1: $ty,
            // Upper triangle of the symmetric state covariance P.
            p00: $ty,
            p01: $ty,
            p11: $ty,
            // Process noise, added to both diagonal entries of P on every prediction.
            q: $ty,
            // Measurement noise variance, added to the innovation variance.
            r: $ty,
            // Measurements seen since construction/reset (starts at `min_samples` when seeded).
            count: u64,
            // Number of measurements required before estimates are reported.
            min_samples: u64,
            // False until the first measurement (or a builder seed) establishes the state.
            initialized: bool,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`].")]
        ///
        /// `process_noise` and `measurement_noise` are required; `min_samples`
        /// defaults to 1 and the filter starts unseeded.
        #[derive(Debug, Clone)]
        pub struct $builder {
            q: Option<$ty>,
            r: Option<$ty>,
            min_samples: u64,
            seed_pos: Option<$ty>,
            seed_vel: Option<$ty>,
        }

        impl $name {
            /// Returns a builder with no noise parameters, `min_samples = 1` and no seed.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    q: None,
                    r: None,
                    min_samples: 1,
                    seed_pos: None,
                    seed_vel: None,
                }
            }

            /// Feeds one measurement into the filter.
            ///
            /// On an unseeded filter the first measurement only initialises the
            /// state to `[measurement, 0]` with position variance `r`. Every later
            /// measurement runs a predict step followed by a correction step.
            ///
            /// Returns `Some((position, velocity))` once the filter is primed
            /// (see [`Self::is_primed`]), otherwise `None`. Non-finite
            /// measurements are not rejected and will propagate into the state.
            #[inline]
            #[must_use]
            pub fn update(&mut self, measurement: $ty) -> Option<($ty, $ty)> {
                self.count = self.count.saturating_add(1);

                if !self.initialized {
                    self.x0 = measurement;
                    self.x1 = 0.0 as $ty;
                    self.p00 = self.r;
                    self.p01 = 0.0 as $ty;
                    self.p11 = 1.0 as $ty;
                    self.initialized = true;
                    return self.primed((self.x0, self.x1));
                }

                // Predict with F = [[1, 1], [0, 1]]: P' = F P F^T + diag(q, q).
                // `a.fma(b, c)` is used here as `a * b + c` (MulAdd trait in scope).
                let pred_x0 = self.x0 + self.x1;
                let pred_x1 = self.x1;
                let pred_p00 = (2.0 as $ty).fma(self.p01, self.p00) + self.p11 + self.q;
                let pred_p01 = self.p01 + self.p11;
                let pred_p11 = self.p11 + self.q;

                // Correct with H = [1, 0]: innovation y, its variance s, gain K = P' H^T / s.
                let y = measurement - pred_x0;
                let s = pred_p00 + self.r;
                let k0 = pred_p00 / s;
                let k1 = pred_p01 / s;
                self.x0 = k0.fma(y, pred_x0);
                self.x1 = k1.fma(y, pred_x1);

                // P = (I - K H) P'
                self.p00 = (1.0 as $ty - k0) * pred_p00;
                self.p01 = (1.0 as $ty - k0) * pred_p01;
                self.p11 = pred_p11 - k1 * pred_p01;

                self.primed((self.x0, self.x1))
            }

            /// Current position estimate, or `None` until the filter is primed.
            #[inline]
            #[must_use]
            pub fn position(&self) -> Option<$ty> {
                self.primed(self.x0)
            }

            /// Current velocity estimate (position units per update), or `None`
            /// until the filter is primed.
            #[inline]
            #[must_use]
            pub fn velocity(&self) -> Option<$ty> {
                self.primed(self.x1)
            }

            /// Position variance `P[0][0]`.
            ///
            /// Not gated on priming: before any measurement it is the initial value 1.
            #[inline]
            #[must_use]
            pub fn uncertainty(&self) -> $ty {
                self.p00
            }

            /// Measurements passed to `update` since construction or the last
            /// `reset`. A seeded filter starts counting at `min_samples`.
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// True once `count() >= min_samples`.
            ///
            /// With `min_samples == 0` this holds immediately, and `position` /
            /// `velocity` then report the zero-initialised state.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.count >= self.min_samples
            }

            /// Returns the filter to its unseeded, uninitialised state.
            ///
            /// Noise parameters and `min_samples` are kept. A seed given to the
            /// builder is *not* restored, so the next measurement re-initialises
            /// the state.
            #[inline]
            pub fn reset(&mut self) {
                self.x0 = 0.0 as $ty;
                self.x1 = 0.0 as $ty;
                self.p00 = 1.0 as $ty;
                self.p01 = 0.0 as $ty;
                self.p11 = 1.0 as $ty;
                self.count = 0;
                self.initialized = false;
            }

            /// Wraps `value` in `Some` when the filter is primed, `None` otherwise.
            #[inline]
            fn primed<T>(&self, value: T) -> Option<T> {
                if self.is_primed() {
                    Some(value)
                } else {
                    None
                }
            }
        }

        impl $builder {
            /// Sets the process noise `q` (required; must be finite and > 0).
            #[inline]
            #[must_use]
            pub fn process_noise(mut self, q: $ty) -> Self {
                self.q = Some(q);
                self
            }

            /// Sets the measurement noise variance `r` (required; must be finite and > 0).
            #[inline]
            #[must_use]
            pub fn measurement_noise(mut self, r: $ty) -> Self {
                self.r = Some(r);
                self
            }

            /// Sets how many measurements are needed before estimates are reported
            /// (default 1).
            #[inline]
            #[must_use]
            pub fn min_samples(mut self, min: u64) -> Self {
                self.min_samples = min;
                self
            }

            /// Starts the filter at the given state and marks it primed immediately.
            /// Both values must be finite.
            #[inline]
            #[must_use]
            pub fn seed(mut self, position: $ty, velocity: $ty) -> Self {
                self.seed_pos = Some(position);
                self.seed_vel = Some(velocity);
                self
            }

            /// Validates the configuration and builds the filter.
            ///
            /// # Errors
            ///
            /// - `ConfigError::Missing` if `process_noise` or `measurement_noise`
            ///   was never set.
            /// - `ConfigError::Invalid` if a noise value is NaN, not strictly
            ///   positive, or infinite, or if a seed component is not finite.
            #[inline]
            pub fn build(self) -> Result<$name, crate::ConfigError> {
                let q = self.q.ok_or(crate::ConfigError::Missing("process_noise"))?;
                let r = self
                    .r
                    .ok_or(crate::ConfigError::Missing("measurement_noise"))?;

                // NaN compares false against everything, so `q <= 0.0` alone lets it through.
                if q.is_nan() || q <= 0.0 as $ty {
                    return Err(crate::ConfigError::Invalid(
                        "process_noise must be positive",
                    ));
                }
                // An infinite q makes the first gain `inf / inf = NaN`.
                if q.is_infinite() {
                    return Err(crate::ConfigError::Invalid("process_noise must be finite"));
                }
                if r.is_nan() || r <= 0.0 as $ty {
                    return Err(crate::ConfigError::Invalid(
                        "measurement_noise must be positive",
                    ));
                }
                if r.is_infinite() {
                    return Err(crate::ConfigError::Invalid(
                        "measurement_noise must be finite",
                    ));
                }

                let (x0, x1, count, initialized) = match (self.seed_pos, self.seed_vel) {
                    (Some(pos), Some(vel)) => {
                        // A NaN/inf seed would contaminate every subsequent estimate.
                        if !pos.is_finite() || !vel.is_finite() {
                            return Err(crate::ConfigError::Invalid("seed must be finite"));
                        }
                        (pos, vel, self.min_samples, true)
                    }
                    _ => (0.0 as $ty, 0.0 as $ty, 0, false),
                };

                Ok($name {
                    x0,
                    x1,
                    p00: 1.0 as $ty,
                    p01: 0.0 as $ty,
                    p11: 1.0 as $ty,
                    q,
                    r,
                    count,
                    min_samples: self.min_samples,
                    initialized,
                })
            }
        }
    };
}
// Concrete filter + builder pairs, one per floating-point width.
impl_kalman1d! { Kalman1dF64, Kalman1dF64Builder, f64 }
impl_kalman1d! { Kalman1dF32, Kalman1dF32Builder, f32 }
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn converges_on_constant() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        for _ in 0..100 {
            let _ = kf.update(50.0);
        }
        let pos = kf.position().unwrap();
        assert!(
            (pos - 50.0).abs() < 1.0,
            "should converge to ~50, got {pos}"
        );
    }

    #[test]
    fn tracks_linear_trend() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.1)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        for i in 0..100 {
            let _ = kf.update(i as f64 * 10.0);
        }
        let vel = kf.velocity().unwrap();
        assert!(
            (vel - 10.0).abs() < 2.0,
            "velocity should be ~10, got {vel}"
        );
    }

    #[test]
    fn high_process_noise_reactive() {
        let mut reactive = Kalman1dF64::builder()
            .process_noise(10.0)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        let mut smooth = Kalman1dF64::builder()
            .process_noise(0.001)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        for _ in 0..20 {
            let _ = reactive.update(100.0);
            let _ = smooth.update(100.0);
        }
        let _ = reactive.update(200.0);
        let _ = smooth.update(200.0);
        let r_pos = reactive.position().unwrap();
        let s_pos = smooth.position().unwrap();
        assert!(
            r_pos > s_pos,
            "reactive ({r_pos}) should track faster than smooth ({s_pos})"
        );
    }

    #[test]
    fn uncertainty_decreases() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        let _ = kf.update(50.0);
        let u1 = kf.uncertainty();
        for _ in 0..50 {
            let _ = kf.update(50.0);
        }
        let u2 = kf.uncertainty();
        assert!(u2 < u1, "uncertainty should decrease, was {u1} now {u2}");
    }

    #[test]
    fn first_update_initialises_state() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        assert!(kf.position().is_none(), "no estimate before any measurement");
        // First measurement sets position directly, velocity to zero, variance to r.
        assert_eq!(kf.update(5.0), Some((5.0, 0.0)));
        assert_eq!(kf.uncertainty(), 1.0);
    }

    #[test]
    fn min_samples_gates_output() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .min_samples(3)
            .build()
            .unwrap();
        assert_eq!(kf.update(1.0), None);
        assert_eq!(kf.update(1.0), None);
        assert!(!kf.is_primed());
        assert!(kf.position().is_none());
        assert!(kf.update(1.0).is_some());
        assert!(kf.is_primed());
    }

    #[test]
    fn seeded_startup() {
        let kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .seed(100.0, 5.0)
            .build()
            .unwrap();
        assert!(kf.is_primed());
        let pos = kf.position().unwrap();
        assert!((pos - 100.0).abs() < 1e-10);
        assert_eq!(kf.velocity(), Some(5.0));
    }

    #[test]
    fn reset() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        for _ in 0..50 {
            let _ = kf.update(100.0);
        }
        kf.reset();
        assert_eq!(kf.count(), 0);
        assert!(!kf.is_primed(), "reset filter must not report estimates");
        assert!(kf.position().is_none());
        // The next measurement re-initialises from scratch.
        assert_eq!(kf.update(7.0), Some((7.0, 0.0)));
    }

    #[test]
    fn f32_basic() {
        let mut kf = Kalman1dF32::builder()
            .process_noise(0.1)
            .measurement_noise(1.0)
            .build()
            .unwrap();
        let _ = kf.update(50.0);
        assert!(kf.position().is_some());
    }

    #[test]
    fn seed_zero_zero_works() {
        let mut kf = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(1.0)
            .seed(0.0, 0.0)
            .build()
            .unwrap();
        assert!(kf.is_primed());
        let (pos, _vel) = kf.update(10.0).unwrap();
        assert!(pos > 0.0, "should track toward 10, got {pos}");
    }

    #[test]
    fn errors_without_process_noise() {
        let result = Kalman1dF64::builder().measurement_noise(1.0).build();
        assert!(matches!(
            result,
            Err(crate::ConfigError::Missing("process_noise"))
        ));
    }

    #[test]
    fn errors_without_measurement_noise() {
        let result = Kalman1dF64::builder().process_noise(0.01).build();
        assert!(matches!(
            result,
            Err(crate::ConfigError::Missing("measurement_noise"))
        ));
    }

    #[test]
    fn rejects_non_positive_noise() {
        let zero_q = Kalman1dF64::builder()
            .process_noise(0.0)
            .measurement_noise(1.0)
            .build();
        assert!(matches!(zero_q, Err(crate::ConfigError::Invalid(_))));

        let negative_r = Kalman1dF64::builder()
            .process_noise(0.01)
            .measurement_noise(-1.0)
            .build();
        assert!(matches!(negative_r, Err(crate::ConfigError::Invalid(_))));
    }
}