/// One sample of a tracking stream: position, orientation and a confidence value.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct TrackingFrame {
    /// Frame index; every filter mode copies it through unchanged.
    pub frame: u64,
    /// Position as three components (units are whatever the producer uses).
    pub position: [f32; 3],
    /// Orientation quaternion. The identity is written `[0.0, 0.0, 0.0, 1.0]` throughout
    /// this file, so the scalar part appears to be the last element — confirm with the
    /// producer of the data.
    pub rotation_quat: [f32; 4],
    /// Producer-supplied confidence; the filters copy it through untouched.
    pub confidence: f32,
}
/// Smoothing algorithm applied by `TrackingFilter`.
///
/// Every mode except `PassThrough` filters each position axis and each quaternion
/// component independently and then renormalises the quaternion.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum FilterMode {
    /// Frames are returned unchanged (the quaternion is not renormalised).
    PassThrough,
    /// First-order exponential (single-pole IIR) low-pass.
    LowPass {
        /// Cutoff frequency in Hz; a non-positive value disables smoothing.
        cutoff_hz: f32,
        /// Rate at which frames arrive, in Hz; a non-positive value disables smoothing.
        sample_rate_hz: f32,
    },
    /// Independent scalar Kalman filter per component, predicting a constant value.
    Kalman {
        /// Variance added to the estimate's uncertainty before every update.
        process_noise: f32,
        /// Assumed variance of each measurement.
        measurement_noise: f32,
    },
    /// One Euro adaptive low-pass. The sample rate is not configurable here; the
    /// filter assumes frames arrive at 60 Hz.
    OneEuro {
        /// Cutoff (Hz) used when the signal is stationary.
        min_cutoff: f32,
        /// Gain by which the cutoff rises with the smoothed speed of the signal.
        beta: f32,
    },
}
/// Frame-to-frame motion statistics produced by `JitterMetrics::compute`.
///
/// All fields are `0.0` when fewer than two frames were supplied.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct JitterReport {
    /// Root-mean-square of the Euclidean position change between consecutive frames.
    pub rms_position_jitter: f32,
    /// Largest Euclidean position change between consecutive frames.
    pub max_position_jitter: f32,
    /// Root-mean-square of the rotation angle (radians) between consecutive frames.
    pub rms_rotation_jitter: f32,
}
/// One-dimensional Kalman filter whose prediction step keeps the state constant and only
/// inflates its variance by the process noise.
#[derive(Debug, Clone)]
struct ScalarKalman {
    /// Current state estimate.
    x: f32,
    /// Variance of the current estimate.
    p: f32,
    /// Process-noise variance added during each prediction step.
    q: f32,
    /// Measurement-noise variance.
    r: f32,
    /// `false` until the first measurement has seeded the estimate.
    initialised: bool,
}

impl ScalarKalman {
    /// Creates an unseeded filter with process noise `q` and measurement noise `r`.
    fn new(q: f32, r: f32) -> Self {
        Self {
            x: 0.0,
            p: 1.0,
            q,
            r,
            initialised: false,
        }
    }

    /// Feeds one measurement and returns the updated estimate.
    ///
    /// The first measurement is adopted verbatim (variance reset to 1.0) so the output
    /// does not ramp up from the zero initial state.
    fn update(&mut self, measurement: f32) -> f32 {
        if !self.initialised {
            self.x = measurement;
            self.p = 1.0;
            self.initialised = true;
            return self.x;
        }
        let p_pred = self.p + self.q;
        let innovation_var = p_pred + self.r;
        // With q == 0 and r == 0 the variance collapses to 0 after one update, and the
        // gain would become 0/0 = NaN, poisoning the estimate permanently. A non-positive
        // (or NaN) innovation variance only comes from zero/invalid noise settings; fall
        // back to trusting the measurement fully.
        let k = if innovation_var > 0.0 {
            p_pred / innovation_var
        } else {
            1.0
        };
        self.x += k * (measurement - self.x);
        self.p = (1.0 - k) * p_pred;
        self.x
    }
}
/// Per-channel state of a One Euro filter: an exponential low-pass whose cutoff grows
/// with the (itself low-passed) rate of change of the signal.
#[derive(Debug, Clone)]
struct OneEuroScalar {
    /// Cutoff (Hz) applied when the signal is stationary.
    min_cutoff: f32,
    /// How strongly the cutoff grows with the smoothed derivative magnitude.
    beta: f32,
    /// Cutoff (Hz) of the low-pass applied to the derivative estimate.
    d_cutoff: f32,
    /// Last filtered value; `None` until the first sample arrives.
    x_prev: Option<f32>,
    /// Last smoothed derivative (units per second).
    dx_prev: f32,
}

impl OneEuroScalar {
    /// Creates an unseeded filter; the derivative cutoff is fixed at 1 Hz.
    fn new(min_cutoff: f32, beta: f32) -> Self {
        Self {
            x_prev: None,
            dx_prev: 0.0,
            d_cutoff: 1.0,
            min_cutoff,
            beta,
        }
    }

    /// Exponential-smoothing weight `dt / (tau + dt)` for a first-order low-pass at
    /// `cutoff_hz`, where `tau = 1 / (2π·cutoff_hz)` and `dt = 1 / sample_rate_hz`.
    /// Returns 1.0 (no smoothing) when either argument is non-positive.
    fn alpha(cutoff_hz: f32, sample_rate_hz: f32) -> f32 {
        let degenerate = sample_rate_hz <= 0.0 || cutoff_hz <= 0.0;
        if degenerate {
            1.0
        } else {
            let period = 1.0 / sample_rate_hz;
            let time_constant = 1.0 / (2.0 * std::f32::consts::PI * cutoff_hz);
            period / (time_constant + period)
        }
    }

    /// Filters one sample taken at `sample_rate_hz` and returns the smoothed value.
    /// The very first sample seeds the filter and is returned unchanged.
    fn filter(&mut self, raw: f32, sample_rate_hz: f32) -> f32 {
        if let Some(last) = self.x_prev {
            // Smooth the finite-difference speed estimate first...
            let speed = (raw - last) * sample_rate_hz;
            let speed_weight = Self::alpha(self.d_cutoff, sample_rate_hz);
            self.dx_prev = speed_weight * speed + (1.0 - speed_weight) * self.dx_prev;
            // ...then open the cutoff proportionally to how fast the signal moves.
            let cutoff = self.min_cutoff + self.beta * self.dx_prev.abs();
            let weight = Self::alpha(cutoff, sample_rate_hz);
            let smoothed = weight * raw + (1.0 - weight) * last;
            self.x_prev = Some(smoothed);
            smoothed
        } else {
            self.x_prev = Some(raw);
            raw
        }
    }
}
/// Single-pole exponential low-pass filter for one scalar channel.
#[derive(Debug, Clone)]
struct IirScalar {
    /// Last output; `None` until the first sample has been seen.
    prev: Option<f32>,
    /// Weight given to each new sample (1.0 means no smoothing).
    alpha: f32,
}

impl IirScalar {
    /// Builds a filter with weight `dt / (tau + dt)`, where `tau = 1 / (2π·cutoff_hz)` and
    /// `dt = 1 / sample_rate_hz`. Non-positive (or NaN) arguments give a weight of 1.0,
    /// i.e. samples pass straight through.
    fn new(cutoff_hz: f32, sample_rate_hz: f32) -> Self {
        let usable = sample_rate_hz > 0.0 && cutoff_hz > 0.0;
        let alpha = if !usable {
            1.0
        } else {
            let dt = 1.0 / sample_rate_hz;
            let tau = 1.0 / (2.0 * std::f32::consts::PI * cutoff_hz);
            dt / (tau + dt)
        };
        Self { alpha, prev: None }
    }

    /// Blends `raw` into the running output and returns the new output; the first
    /// sample is returned unchanged.
    fn filter(&mut self, raw: f32) -> f32 {
        let alpha = self.alpha;
        let out = self
            .prev
            .map_or(raw, |last| alpha * raw + (1.0 - alpha) * last);
        self.prev = Some(out);
        out
    }
}
/// Scales `q` to unit length.
///
/// A quaternion whose squared length is below `1e-12` cannot be normalised reliably and
/// is replaced by the identity `[0.0, 0.0, 0.0, 1.0]`.
fn quat_normalise(q: [f32; 4]) -> [f32; 4] {
    let len_sq: f32 = q.iter().map(|c| c * c).sum();
    if len_sq < 1e-12 {
        [0.0, 0.0, 0.0, 1.0]
    } else {
        let scale = 1.0 / len_sq.sqrt();
        q.map(|c| c * scale)
    }
}
/// Runtime state backing each `FilterMode`: one scalar filter per position axis (3) and
/// one per quaternion component (4).
#[derive(Debug, Clone)]
enum FilterState {
    /// No state; frames are cloned through.
    PassThrough,
    /// Exponential low-pass state per component.
    LowPass {
        pos: [IirScalar; 3],
        rot: [IirScalar; 4],
    },
    /// Scalar Kalman state per component.
    Kalman {
        pos: [ScalarKalman; 3],
        rot: [ScalarKalman; 4],
    },
    /// One Euro state per component.
    OneEuro {
        pos: [OneEuroScalar; 3],
        rot: [OneEuroScalar; 4],
        /// Sample rate (Hz) passed to every One Euro update.
        sample_rate_hz: f32,
    },
}
/// Stateful smoother for a stream of `TrackingFrame`s.
///
/// Construct it with a `FilterMode`, then feed frames in order through
/// [`TrackingFilter::filter`]. Every smoothing mode filters the three position axes and
/// the four quaternion components independently, then renormalises the quaternion.
pub struct TrackingFilter {
    /// Configuration used to (re)build `state`.
    mode: FilterMode,
    /// Per-component filter state for `mode`.
    state: FilterState,
    /// Last quaternion emitted by a smoothing mode. It is the sign reference for the next
    /// input. `None` after construction or reset, and always `None` in `PassThrough` mode.
    last_rot: Option<[f32; 4]>,
}

impl TrackingFilter {
    /// Sample rate (Hz) assumed by the One Euro filter. `FilterMode::OneEuro` does not
    /// carry a rate, so frames are treated as arriving at this fixed rate.
    const ONE_EURO_SAMPLE_RATE_HZ: f32 = 60.0;

    /// Creates a filter for `mode` with empty history. The first frame's position passes
    /// through unchanged.
    #[must_use]
    pub fn new(mode: FilterMode) -> Self {
        let state = Self::build_state(&mode);
        Self {
            mode,
            state,
            last_rot: None,
        }
    }

    /// Discards all history so that the next frame is treated as the first one.
    pub fn reset(&mut self) {
        self.state = Self::build_state(&self.mode);
        self.last_rot = None;
    }

    /// Smooths one frame and returns the filtered copy.
    ///
    /// `frame` and `confidence` are copied unchanged. In `PassThrough` mode the input is
    /// cloned verbatim, and its quaternion is not normalised.
    ///
    /// In the other modes, the incoming quaternion is first sign-aligned with the
    /// previous output (see [`Self::align_hemisphere`]). The filtered result is then
    /// normalised. As a consequence, the output may be the negation of the input
    /// quaternion, which encodes the same rotation.
    pub fn filter(&mut self, frame: &TrackingFrame) -> TrackingFrame {
        let q_raw = match self.last_rot {
            Some(prev) => Self::align_hemisphere(frame.rotation_quat, prev),
            None => frame.rotation_quat,
        };
        let (position, rot_raw): ([f32; 3], [f32; 4]) = match &mut self.state {
            FilterState::PassThrough => return frame.clone(),
            FilterState::LowPass { pos, rot } => (
                std::array::from_fn(|i| pos[i].filter(frame.position[i])),
                std::array::from_fn(|i| rot[i].filter(q_raw[i])),
            ),
            FilterState::Kalman { pos, rot } => (
                std::array::from_fn(|i| pos[i].update(frame.position[i])),
                std::array::from_fn(|i| rot[i].update(q_raw[i])),
            ),
            FilterState::OneEuro {
                pos,
                rot,
                sample_rate_hz,
            } => {
                let sr = *sample_rate_hz;
                (
                    std::array::from_fn(|i| pos[i].filter(frame.position[i], sr)),
                    std::array::from_fn(|i| rot[i].filter(q_raw[i], sr)),
                )
            }
        };
        let rotation_quat = quat_normalise(rot_raw);
        self.last_rot = Some(rotation_quat);
        TrackingFrame {
            frame: frame.frame,
            position,
            rotation_quat,
            confidence: frame.confidence,
        }
    }

    /// Returns `q` or `-q`, whichever lies in the same hemisphere as `reference`.
    ///
    /// `q` and `-q` encode the same rotation, but the per-component filters treat them as
    /// far apart. Blending `q` with something close to `-q` pulls the components towards
    /// zero, and normalising that near-zero vector yields an arbitrary rotation. Trackers
    /// routinely flip quaternion sign between frames, so inputs are aligned first.
    fn align_hemisphere(q: [f32; 4], reference: [f32; 4]) -> [f32; 4] {
        let dot: f32 = q.iter().zip(reference.iter()).map(|(a, b)| a * b).sum();
        if dot < 0.0 {
            q.map(|c| -c)
        } else {
            q
        }
    }

    /// Builds fresh, unseeded per-component state for `mode`.
    fn build_state(mode: &FilterMode) -> FilterState {
        match mode {
            FilterMode::PassThrough => FilterState::PassThrough,
            FilterMode::LowPass {
                cutoff_hz,
                sample_rate_hz,
            } => FilterState::LowPass {
                pos: std::array::from_fn(|_| IirScalar::new(*cutoff_hz, *sample_rate_hz)),
                rot: std::array::from_fn(|_| IirScalar::new(*cutoff_hz, *sample_rate_hz)),
            },
            FilterMode::Kalman {
                process_noise,
                measurement_noise,
            } => FilterState::Kalman {
                pos: std::array::from_fn(|_| {
                    ScalarKalman::new(*process_noise, *measurement_noise)
                }),
                rot: std::array::from_fn(|_| {
                    ScalarKalman::new(*process_noise, *measurement_noise)
                }),
            },
            FilterMode::OneEuro { min_cutoff, beta } => FilterState::OneEuro {
                pos: std::array::from_fn(|_| OneEuroScalar::new(*min_cutoff, *beta)),
                rot: std::array::from_fn(|_| OneEuroScalar::new(*min_cutoff, *beta)),
                sample_rate_hz: Self::ONE_EURO_SAMPLE_RATE_HZ,
            },
        }
    }
}
/// Computes frame-to-frame jitter statistics over a recorded sequence of frames.
pub struct JitterMetrics;

impl JitterMetrics {
    /// Summarises the motion between each pair of consecutive frames in `frames`.
    ///
    /// Position jitter is the Euclidean distance between consecutive positions. Rotation
    /// jitter is the angle (radians, in `[0, π]`) between consecutive orientations, with
    /// `q` and `-q` treated as the same rotation. Returns an all-zero report when fewer
    /// than two frames are given.
    #[must_use]
    pub fn compute(frames: &[TrackingFrame]) -> JitterReport {
        if frames.len() < 2 {
            return JitterReport {
                rms_position_jitter: 0.0,
                max_position_jitter: 0.0,
                rms_rotation_jitter: 0.0,
            };
        }
        // Accumulate in f64 so long recordings do not lose small per-step contributions
        // to f32 rounding in the running sums.
        let steps = (frames.len() - 1) as f64;
        let mut sum_pos_sq = 0.0_f64;
        let mut max_pos = 0.0_f64;
        let mut sum_rot_sq = 0.0_f64;
        for pair in frames.windows(2) {
            let (a, b) = (&pair[0], &pair[1]);
            let pos_sq: f64 = a
                .position
                .iter()
                .zip(&b.position)
                .map(|(&pa, &pb)| {
                    let d = f64::from(pb) - f64::from(pa);
                    d * d
                })
                .sum();
            sum_pos_sq += pos_sq;
            max_pos = max_pos.max(pos_sq.sqrt());
            let rot = Self::rotation_angle(a.rotation_quat, b.rotation_quat);
            sum_rot_sq += rot * rot;
        }
        JitterReport {
            rms_position_jitter: (sum_pos_sq / steps).sqrt() as f32,
            max_position_jitter: max_pos as f32,
            rms_rotation_jitter: (sum_rot_sq / steps).sqrt() as f32,
        }
    }

    /// Angle (radians, in `[0, π]`) of the rotation taking `a` to `b`.
    ///
    /// Both quaternions are laid out `[x, y, z, w]`. The angle is `2·atan2(|v|, |w|)` of
    /// the relative quaternion `conj(a)·b`, whose vector part is `v` and scalar part `w`.
    /// Compared with `2·acos(|a·b|)`, this form:
    /// - stays accurate for tiny angles, where `acos` near 1 loses most of its precision;
    /// - is invariant to the scale of non-normalised inputs;
    /// - folds the `q`/`-q` ambiguity via `|w|`.
    fn rotation_angle(a: [f32; 4], b: [f32; 4]) -> f64 {
        let [ax, ay, az, aw] = a.map(f64::from);
        let [bx, by, bz, bw] = b.map(f64::from);
        // Vector part of conj(a) * b = aw*b_v - bw*a_v - a_v x b_v.
        let vx = aw * bx - bw * ax - (ay * bz - az * by);
        let vy = aw * by - bw * ay - (az * bx - ax * bz);
        let vz = aw * bz - bw * az - (ax * by - ay * bx);
        // Scalar part of conj(a) * b is the 4-D dot product.
        let w = aw * bw + ax * bx + ay * by + az * bz;
        let v_norm = (vx * vx + vy * vy + vz * vz).sqrt();
        2.0 * v_norm.atan2(w.abs())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Frame number `frame` at the origin with the identity orientation.
    fn identity_frame(frame: u64) -> TrackingFrame {
        TrackingFrame {
            frame,
            position: [0.0, 0.0, 0.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        }
    }

    #[test]
    fn test_passthrough_unchanged_position() {
        let mut f = TrackingFilter::new(FilterMode::PassThrough);
        let frame = TrackingFrame {
            frame: 0,
            position: [1.0, 2.0, 3.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 0.9,
        };
        let out = f.filter(&frame);
        assert_eq!(out.position, frame.position);
        assert_eq!(out.rotation_quat, frame.rotation_quat);
        assert_eq!(out.confidence, frame.confidence);
        assert_eq!(out.frame, frame.frame);
    }

    #[test]
    fn test_passthrough_unchanged_rotation() {
        let mut f = TrackingFilter::new(FilterMode::PassThrough);
        // Deliberately not exactly unit length: PassThrough must not renormalise.
        let frame = TrackingFrame {
            frame: 7,
            position: [0.0; 3],
            rotation_quat: [0.707, 0.0, 0.707, 0.0],
            confidence: 1.0,
        };
        let out = f.filter(&frame);
        assert_eq!(out, frame, "PassThrough must return an exact copy");
    }

    #[test]
    fn test_lowpass_first_frame_passthrough() {
        let mut f = TrackingFilter::new(FilterMode::LowPass {
            cutoff_hz: 5.0,
            sample_rate_hz: 60.0,
        });
        let frame = TrackingFrame {
            frame: 0,
            position: [3.0, -1.0, 2.5],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        };
        let out = f.filter(&frame);
        assert_eq!(out.position, frame.position);
        assert_eq!(out.rotation_quat, frame.rotation_quat);
    }

    #[test]
    fn test_lowpass_smoothes_step_change() {
        let mut f = TrackingFilter::new(FilterMode::LowPass {
            cutoff_hz: 2.0,
            sample_rate_hz: 60.0,
        });
        for i in 0..30_u64 {
            f.filter(&identity_frame(i));
        }
        let step_frame = TrackingFrame {
            frame: 30,
            position: [10.0, 0.0, 0.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        };
        let out = f.filter(&step_frame);
        assert!(
            out.position[0] > 0.0 && out.position[0] < 10.0,
            "smoothed position should be between 0 and 10, got {}",
            out.position[0]
        );
    }

    #[test]
    fn test_kalman_converges_to_stationary_position() {
        let mut f = TrackingFilter::new(FilterMode::Kalman {
            process_noise: 1e-3,
            measurement_noise: 0.1,
        });
        let mut pos_x_sum = 0.0_f32;
        let noises = [0.05, -0.03, 0.07, -0.06, 0.02, -0.04, 0.08, -0.01];
        for i in 0..60_u64 {
            let noise = noises[(i as usize) % noises.len()];
            let frame = TrackingFrame {
                frame: i,
                position: [5.0 + noise, 0.0, 0.0],
                rotation_quat: [0.0, 0.0, 0.0, 1.0],
                confidence: 1.0,
            };
            let out = f.filter(&frame);
            if i >= 30 {
                pos_x_sum += out.position[0];
            }
        }
        let avg = pos_x_sum / 30.0;
        assert!(
            (avg - 5.0).abs() < 0.15,
            "Kalman average should be close to 5.0, got {avg}"
        );
    }

    #[test]
    fn test_kalman_reduces_rms_jitter() {
        let noisy_frames: Vec<TrackingFrame> = (0..120_u64)
            .map(|i| {
                let noise = if i % 2 == 0 { 0.1 } else { -0.1 };
                TrackingFrame {
                    frame: i,
                    position: [noise, 0.0, 0.0],
                    rotation_quat: [0.0, 0.0, 0.0, 1.0],
                    confidence: 1.0,
                }
            })
            .collect();
        let raw_jitter = JitterMetrics::compute(&noisy_frames).rms_position_jitter;
        let mut f = TrackingFilter::new(FilterMode::Kalman {
            process_noise: 1e-4,
            measurement_noise: 0.05,
        });
        let filtered_frames: Vec<TrackingFrame> =
            noisy_frames.iter().map(|fr| f.filter(fr)).collect();
        let filtered_jitter = JitterMetrics::compute(&filtered_frames).rms_position_jitter;
        assert!(
            filtered_jitter < raw_jitter,
            "Kalman should reduce jitter: {filtered_jitter} < {raw_jitter}"
        );
    }

    #[test]
    fn test_one_euro_first_frame_passthrough() {
        let mut f = TrackingFilter::new(FilterMode::OneEuro {
            min_cutoff: 1.0,
            beta: 0.007,
        });
        let frame = TrackingFrame {
            frame: 0,
            position: [7.0, -2.0, 1.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        };
        let out = f.filter(&frame);
        assert_eq!(
            out.position, frame.position,
            "first frame must pass through"
        );
    }

    #[test]
    fn test_one_euro_smoothes_slow_jitter() {
        let mut f = TrackingFilter::new(FilterMode::OneEuro {
            min_cutoff: 1.0,
            beta: 0.0,
        });
        for i in 0..30_u64 {
            f.filter(&identity_frame(i));
        }
        let jitter = TrackingFrame {
            frame: 30,
            position: [0.5, 0.0, 0.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        };
        let out = f.filter(&jitter);
        assert!(
            out.position[0] < 0.5,
            "One-Euro should smooth small jitter: {}",
            out.position[0]
        );
    }

    #[test]
    fn test_one_euro_fast_motion_lower_lag_than_lowpass() {
        let mut lp = TrackingFilter::new(FilterMode::LowPass {
            cutoff_hz: 2.0,
            sample_rate_hz: 60.0,
        });
        let mut oe = TrackingFilter::new(FilterMode::OneEuro {
            min_cutoff: 2.0,
            beta: 1.0,
        });
        for i in 0..30_u64 {
            lp.filter(&identity_frame(i));
            oe.filter(&identity_frame(i));
        }
        let step = TrackingFrame {
            frame: 30,
            position: [100.0, 0.0, 0.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        };
        let lp_out = lp.filter(&step);
        let oe_out = oe.filter(&step);
        assert!(
            oe_out.position[0] >= lp_out.position[0],
            "One-Euro should track fast motion better: oe={} lp={}",
            oe_out.position[0],
            lp_out.position[0]
        );
    }

    #[test]
    fn test_reset_restores_first_frame_passthrough() {
        let mut f = TrackingFilter::new(FilterMode::LowPass {
            cutoff_hz: 2.0,
            sample_rate_hz: 60.0,
        });
        for i in 0..10_u64 {
            f.filter(&identity_frame(i));
        }
        f.reset();
        let frame = TrackingFrame {
            frame: 10,
            position: [4.0, 5.0, 6.0],
            rotation_quat: [0.0, 0.0, 0.0, 1.0],
            confidence: 1.0,
        };
        let out = f.filter(&frame);
        assert_eq!(
            out.position, frame.position,
            "reset must clear filter history"
        );
    }

    #[test]
    fn test_smoothing_modes_emit_unit_quaternions_and_keep_metadata() {
        let modes = [
            FilterMode::LowPass {
                cutoff_hz: 2.0,
                sample_rate_hz: 60.0,
            },
            FilterMode::Kalman {
                process_noise: 1e-3,
                measurement_noise: 0.1,
            },
            FilterMode::OneEuro {
                min_cutoff: 1.0,
                beta: 0.007,
            },
        ];
        // Non-unit inputs: smoothing modes must still emit unit quaternions.
        let raws: [[f32; 4]; 3] = [
            [0.0, 0.0, 0.3, 0.4],
            [0.0, 0.0, 0.6, 0.8],
            [0.1, 0.0, 0.5, 0.9],
        ];
        for mode in modes {
            let mut f = TrackingFilter::new(mode.clone());
            for (i, q) in raws.iter().enumerate() {
                let out = f.filter(&TrackingFrame {
                    frame: i as u64,
                    position: [0.0; 3],
                    rotation_quat: *q,
                    confidence: 0.25,
                });
                let len = out.rotation_quat.iter().map(|c| c * c).sum::<f32>().sqrt();
                assert!((len - 1.0).abs() < 1e-5, "{mode:?}: |q| = {len}");
                assert_eq!(out.frame, i as u64);
                assert_eq!(out.confidence, 0.25);
            }
        }
    }

    #[test]
    fn test_jitter_metrics_empty_or_single() {
        let report = JitterMetrics::compute(&[]);
        assert_eq!(report.rms_position_jitter, 0.0);
        assert_eq!(report.max_position_jitter, 0.0);
        assert_eq!(report.rms_rotation_jitter, 0.0);
        let single = [identity_frame(0)];
        let r2 = JitterMetrics::compute(&single);
        assert_eq!(r2.rms_position_jitter, 0.0);
        assert_eq!(r2.max_position_jitter, 0.0);
        assert_eq!(r2.rms_rotation_jitter, 0.0);
    }

    #[test]
    fn test_jitter_metrics_static_signal_zero_jitter() {
        let frames: Vec<TrackingFrame> = (0..10)
            .map(|i| TrackingFrame {
                frame: i,
                position: [1.0, 2.0, 3.0],
                rotation_quat: [0.0, 0.0, 0.0, 1.0],
                confidence: 1.0,
            })
            .collect();
        let report = JitterMetrics::compute(&frames);
        assert!(
            report.rms_position_jitter < 1e-6,
            "static signal: rms={}",
            report.rms_position_jitter
        );
        assert!(report.max_position_jitter < 1e-6);
        assert!(report.rms_rotation_jitter < 1e-6);
    }

    #[test]
    fn test_jitter_metrics_known_displacement() {
        let frames = vec![
            TrackingFrame {
                frame: 0,
                position: [0.0, 0.0, 0.0],
                rotation_quat: [0.0, 0.0, 0.0, 1.0],
                confidence: 1.0,
            },
            TrackingFrame {
                frame: 1,
                position: [3.0, 4.0, 0.0],
                rotation_quat: [0.0, 0.0, 0.0, 1.0],
                confidence: 1.0,
            },
        ];
        let report = JitterMetrics::compute(&frames);
        assert!(
            (report.rms_position_jitter - 5.0).abs() < 1e-4,
            "expected 5.0, got {}",
            report.rms_position_jitter
        );
        assert!((report.max_position_jitter - 5.0).abs() < 1e-4);
    }

    #[test]
    fn test_jitter_rotation_nonzero_when_rotating() {
        let half = std::f32::consts::FRAC_PI_4;
        let frames = vec![
            TrackingFrame {
                frame: 0,
                position: [0.0; 3],
                rotation_quat: [0.0, 0.0, 0.0, 1.0],
                confidence: 1.0,
            },
            TrackingFrame {
                frame: 1,
                position: [0.0; 3],
                rotation_quat: [0.0, 0.0, half.sin(), half.cos()],
                confidence: 1.0,
            },
        ];
        let report = JitterMetrics::compute(&frames);
        // A half-angle of π/4 about z is a 90° (π/2 rad) rotation.
        assert!(
            (report.rms_rotation_jitter - std::f32::consts::FRAC_PI_2).abs() < 1e-4,
            "expected π/2 rad of rotation jitter, got {}",
            report.rms_rotation_jitter
        );
    }
}