use crate::{SensorFusion, SensorFusionMath};
use num_traits::{ConstOne, ConstZero, float::FloatCore};
use vqm::{Quaternion, QuaternionMath, SqrtMethods, TrigonometricMethods, Vector3d};
pub type MadgwickFilterf32 = MadgwickFilter<f32>;
pub type MadgwickFilterf64 = MadgwickFilter<f64>;
pub trait ConstFour {
const FOUR: Self;
}
impl ConstFour for f32 {
const FOUR: Self = 4.0;
}
impl ConstFour for f64 {
const FOUR: Self = 4.0;
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MadgwickFilter<T> {
q: Quaternion<T>,
max_acc_magnitude_squared: T,
beta: T,
beta_yaw: T,
}
impl<T> Default for MadgwickFilter<T>
where
T: Copy + ConstZero + ConstOne + ConstFour,
{
fn default() -> Self {
Self::new()
}
}
impl<T> MadgwickFilter<T>
where
T: Copy + ConstZero + ConstOne + ConstFour,
{
pub const fn with_orientation_and_beta(orientation: Quaternion<T>, beta: T) -> Self {
MadgwickFilter { q: orientation, max_acc_magnitude_squared: T::FOUR, beta, beta_yaw: T::ZERO }
}
pub const fn with_orientation(orientation: Quaternion<T>) -> Self {
MadgwickFilter { q: orientation, max_acc_magnitude_squared: T::FOUR, beta: T::ONE, beta_yaw: T::ZERO }
}
pub const fn new() -> Self {
MadgwickFilter {
q: Quaternion { w: T::ONE, x: T::ZERO, y: T::ZERO, z: T::ZERO },
max_acc_magnitude_squared: T::FOUR,
beta: T::ONE,
beta_yaw: T::ZERO,
}
}
}
impl<T: Copy> MadgwickFilter<T> {
pub fn set_orientation(&mut self, orientation: Quaternion<T>) {
self.q = orientation;
}
pub fn set_beta(&mut self, beta: T) {
self.beta = beta;
}
pub fn set_beta_yaw(&mut self, beta_yaw: T) {
self.beta_yaw = beta_yaw;
}
}
impl<T> MadgwickFilter<T>
where
T: Copy + FloatCore + QuaternionMath + SqrtMethods + SensorFusionMath + TrigonometricMethods,
{
pub fn correct_yaw(&mut self, yaw_radians: T, delta_t: T) -> Quaternion<T> {
self.correct_yaw_with_gain(yaw_radians, self.beta_yaw, delta_t)
}
}
impl<T> SensorFusion<T> for MadgwickFilter<T>
where
T: Copy + FloatCore + QuaternionMath + SqrtMethods + SensorFusionMath + TrigonometricMethods,
{
fn set_gains(&mut self, gain0: T, gain1: T) {
self.beta = gain0;
self.beta_yaw = gain1;
}
fn gains(&self) -> (T, T) {
(self.beta, self.beta_yaw)
}
fn requires_initialization() -> bool {
true
}
fn fuse_acc_gyro(&mut self, acc: Vector3d<T>, gyro_rps: Vector3d<T>, delta_t: T) -> Quaternion<T> {
let step = SensorFusionMath::madgwick_step_acc(self.q, acc, self.max_acc_magnitude_squared);
let q_dot = SensorFusionMath::derivative(self.q, gyro_rps) - step * self.beta;
self.q += q_dot * delta_t;
self.q = self.q.normalize();
self.q
}
fn fuse_acc_gyro_mag(
&mut self,
acc: Vector3d<T>,
gyro_rps: Vector3d<T>,
mag: Vector3d<T>,
delta_t: T,
) -> Quaternion<T> {
let step = SensorFusionMath::madgwick_step_acc_mag(self.q, acc, mag, self.max_acc_magnitude_squared);
let q_dot = SensorFusionMath::derivative(self.q, gyro_rps) - step * self.beta;
self.q += q_dot * delta_t;
self.q = self.q.normalize();
self.q
}
fn correct_yaw_with_gain(&mut self, yaw_radians: T, gain: T, delta_t: T) -> Quaternion<T> {
let mut target = Quaternion::<T>::from_yaw_radians(yaw_radians);
if self.q.dot(target) < T::zero() {
target = -target;
}
let error = self.q * target.conjugate();
let omega = Quaternion::<T>::new(T::zero(), T::zero(), T::zero(), -error.z);
let q_dot = self.q * omega;
self.q += q_dot * (delta_t * gain);
self.q = self.q.normalize();
self.q
}
}
#[cfg(any(debug_assertions, test))]
mod tests {
#![allow(unused)]
use crate::FuseAccGyro;
use super::*;
use vqm::{Quaternionf32, Vector3df32};
fn is_normal<T: Sized + Send + Sync + Unpin>() {}
fn is_full<T: Sized + Send + Sync + Unpin + Copy + Clone + Default + PartialEq>() {}
#[test]
#[allow(clippy::float_cmp)]
fn q_dot() {
let q = Quaternionf32::new(2.0, 3.0, 5.0, 7.0);
let mut target = Quaternionf32::new(11.0, 13.0, 17.0, 19.0);
if q.dot(target) < 0.0 {
target = -target;
}
let error = q * target.conjugate();
assert_eq!(53.0, error.z);
let omega = Quaternionf32::new(0.0, 0.0, 0.0, -error.z);
let q_dot = q * omega;
assert_eq!(Quaternionf32 { w: 371.0, x: -265.0, y: 159.0, z: -106.0 }, q_dot);
let error_z = -q.w * target.z - q.x * target.y + q.y * target.x + q.z * target.w;
assert_eq!(53.0, error_z);
assert_eq!(error.z, error_z);
let q_dot2 = Quaternion { w: q.z * error_z, x: -q.y * error_z, y: q.x * error_z, z: -q.w * error_z };
assert_eq!(q_dot, q_dot2);
}
#[test]
fn normal_types() {
is_full::<MadgwickFilter<f32>>();
}
#[test]
fn readme() {
let mut madgwick_filter = MadgwickFilterf32::default();
let dt: f32 = 0.001;
let acc = Vector3df32::default();
let gyro_rps = Vector3df32::default();
let orientation = madgwick_filter.fuse_acc_gyro(acc, gyro_rps, dt);
assert_eq!(orientation, Quaternion { w: 1.0, x: 0.0, y: 0.0, z: 0.0 });
}
#[test]
fn update_orientation() {
let mut madgwick_filter = MadgwickFilterf32::default();
let requires_initialization = MadgwickFilterf32::requires_initialization();
assert!(requires_initialization);
madgwick_filter.set_beta(1.0);
let delta_t: f32 = 0.0;
let acc = Vector3df32::default();
let gyro_rps = Vector3df32::default();
let orientation = madgwick_filter.fuse_acc_gyro(acc, gyro_rps, delta_t);
assert_eq!(orientation, Quaternion { w: 1.0, x: 0.0, y: 0.0, z: 0.0 });
}
#[test]
fn fuse_acc_gyro_using() {
let mut madgwick_filter = MadgwickFilterf32::default();
let delta_t: f32 = 0.0;
let acc = Vector3df32::default();
let gyro_rps = Vector3df32::default();
let orientation = (acc, gyro_rps).fuse_acc_gyro_using(&mut madgwick_filter, delta_t);
assert_eq!(orientation, Quaternion { w: 1.0, x: 0.0, y: 0.0, z: 0.0 });
}
#[allow(clippy::float_cmp)]
#[test]
fn yaw_correction() {
let initial_yaw = 179.0f32;
let initial_q = Quaternionf32::from_yaw_degrees(initial_yaw);
let mut filter = MadgwickFilterf32::with_orientation(initial_q);
let gain = 5.0;
let delta_t = 0.1;
filter.set_beta_yaw(gain);
let target_yaw = 178.0_f32;
let filtered_q = filter.correct_yaw_with_gain(target_yaw.to_radians(), gain, delta_t);
let final_yaw = filtered_q.calculate_yaw_degrees();
println!("Yaw, initial: {initial_yaw:.4}° | target: {target_yaw:.4}° | filtered: {final_yaw:.4}°");
assert_eq!(178.5, final_yaw);
filter.set_orientation(initial_q);
let target_yaw = 179.5_f32;
let filtered_q = filter.correct_yaw_with_gain(target_yaw.to_radians(), gain, delta_t);
let final_yaw = filtered_q.calculate_yaw_degrees();
println!("Yaw, initial: {initial_yaw:.4}° | target: {target_yaw:.4}° | filtered: {final_yaw:.4}°");
assert!((179.25 - final_yaw).abs() < 2e-5);
let initial_yaw = 179.75f32;
let initial_q = Quaternionf32::from_yaw_degrees(initial_yaw);
filter.set_orientation(initial_q);
let target_yaw = 181.0_f32;
let filtered_q = filter.correct_yaw_with_gain(target_yaw.to_radians(), gain, delta_t);
let final_yaw = filtered_q.calculate_yaw_degrees();
println!("Yaw, initial: {initial_yaw:.4}° | target: {target_yaw:.4}° | filtered: {final_yaw:.4}°");
assert!((-179.625 - final_yaw).abs() < 4e-5);
assert!(
!(0.0..=179.0).contains(&final_yaw),
"\n**** Filter took the long way around! Yaw dropped down to {final_yaw:.2}°"
);
}
}