use crate::registration::{TransformMatrix, TransformationType};
/// One timestamped gyroscope reading.
#[derive(Debug, Clone, Copy)]
pub struct GyroSample {
/// Sample timestamp in nanoseconds (the integrator in this file scales it by `1e-9` to get seconds).
pub t_ns: i64,
/// Angular rate about the x, y and z axes, in that order.
/// NOTE(review): presumably rad/s — the tests in this file build rates with `to_radians()`; confirm against the sensor source.
pub omega_xyz: [f32; 3],
}
/// Camera intrinsic parameters used to turn integrated gyro angles into image-plane shifts.
///
/// Within this file only `fx` and `fy` are read (as multipliers of the integrated angle);
/// `cx` and `cy` are stored but not consulted by the visible code.
#[derive(Debug, Clone, Copy)]
pub struct CameraIntrinsics {
/// Horizontal focal length. NOTE(review): presumably in pixels — confirm with the calibration source.
pub fx: f32,
/// Vertical focal length. NOTE(review): presumably in pixels — confirm with the calibration source.
pub fy: f32,
/// Principal point x-coordinate (by naming convention; not read in this file).
pub cx: f32,
/// Principal point y-coordinate (by naming convention; not read in this file).
pub cy: f32,
}
impl CameraIntrinsics {
    /// Returns a fixed set of intrinsics: both focal lengths are `1500.0` and the
    /// principal point is `(960.0, 540.0)`, i.e. half of 1920 × 1080.
    #[must_use]
    pub fn default_1080p() -> Self {
        const FOCAL: f32 = 1500.0;
        const CENTRE: (f32, f32) = (960.0, 540.0);

        let (cx, cy) = CENTRE;
        Self {
            fx: FOCAL,
            fy: FOCAL,
            cx,
            cy,
        }
    }
}
/// Internal state carried by the fusion filter between frames.
#[derive(Debug, Clone)]
struct FusionState {
/// Accumulated x-translation (same units as the translation written into the output matrix).
tx: f64,
/// Accumulated y-translation (same units as the translation written into the output matrix).
ty: f64,
/// Rotation slot; initialised to zero and not updated anywhere in the visible code.
theta: f64,
}
impl FusionState {
    /// Returns the neutral state: zero translation and zero rotation.
    fn identity() -> Self {
        let [tx, ty, theta] = [0.0_f64; 3];
        Self { tx, ty, theta }
    }
}
/// Blends a gyro-predicted translation with a visually estimated translation using a fixed weight.
#[derive(Debug, Clone)]
pub struct GyroFusionFilter {
/// Weight given to the gyro prediction; the visual estimate receives `1 - alpha`.
/// Clamped to `[0, 1]` by the constructor.
alpha: f64,
/// Camera parameters handed to the gyro integrator.
intrinsics: CameraIntrinsics,
/// Frame rate; its reciprocal is used as the per-frame duration.
fps: f64,
/// Translation accumulated so far.
state: FusionState,
}
impl GyroFusionFilter {
#[must_use]
pub fn new(alpha: f32, intrinsics: CameraIntrinsics, fps: f32) -> Self {
debug_assert!(fps > 0.0, "fps must be positive");
Self {
alpha: alpha.clamp(0.0, 1.0) as f64,
intrinsics,
fps: fps as f64,
state: FusionState::identity(),
}
}
pub fn step(&mut self, visual: TransformMatrix, gyro: &[GyroSample]) -> TransformMatrix {
let dt = 1.0 / self.fps;
let (delta_tx, delta_ty, _delta_theta) = integrate_gyro(gyro, dt, &self.intrinsics);
let pred_tx = self.state.tx + delta_tx;
let pred_ty = self.state.ty + delta_ty;
let (vtx, vty) = visual.get_translation();
let fused_tx = self.alpha * pred_tx + (1.0 - self.alpha) * vtx;
let fused_ty = self.alpha * pred_ty + (1.0 - self.alpha) * vty;
self.state.tx = fused_tx;
self.state.ty = fused_ty;
TransformMatrix {
data: [1.0, 0.0, fused_tx, 0.0, 1.0, fused_ty, 0.0, 0.0, 1.0],
transform_type: TransformationType::Translation,
}
}
pub fn predict_only(&mut self, gyro: &[GyroSample]) -> TransformMatrix {
let dt = 1.0 / self.fps;
let (delta_tx, delta_ty, _delta_theta) = integrate_gyro(gyro, dt, &self.intrinsics);
self.state.tx += delta_tx;
self.state.ty += delta_ty;
TransformMatrix {
data: [
1.0,
0.0,
self.state.tx,
0.0,
1.0,
self.state.ty,
0.0,
0.0,
1.0,
],
transform_type: TransformationType::Translation,
}
}
pub fn reset(&mut self) {
self.state = FusionState::identity();
}
#[must_use]
pub fn current_translation(&self) -> (f64, f64) {
(self.state.tx, self.state.ty)
}
}
/// Integrates gyro samples over one frame and converts the result to an image-plane shift.
///
/// Returns `(delta_tx, delta_ty, delta_theta)` where
/// `delta_tx = angle_y * fx`, `delta_ty = angle_x * fy` and `delta_theta = angle_z`;
/// `angle_*` is the angular rate integrated over `dt_frame` seconds.
///
/// * No samples: returns all zeros.
/// * One sample: its rate is held constant for the whole frame.
/// * Several samples: trapezoidal integration over consecutive pairs (intervals that run
///   backwards in time contribute nothing), then rescaled by `dt_frame / span` so the
///   result always corresponds to exactly one frame duration.
/// * Several samples whose first-to-last span is not larger than one nanosecond: the
///   timestamps carry no usable timing, so the arithmetic mean rate is held for the whole
///   frame, consistent with the single-sample case.
fn integrate_gyro(
    samples: &[GyroSample],
    dt_frame: f64,
    intrinsics: &CameraIntrinsics,
) -> (f64, f64, f64) {
    const NS_TO_S: f64 = 1e-9;
    const MIN_SPAN_S: f64 = 1e-9;

    let (first, last) = match (samples.first(), samples.last()) {
        (Some(first), Some(last)) => (first, last),
        _ => return (0.0, 0.0, 0.0),
    };

    // Subtract in the integer domain *before* converting to floating point. Epoch-scale
    // nanosecond timestamps (~1e18) do not fit an f64 mantissa, so converting each
    // timestamp first would quantise it to hundreds of nanoseconds and distort the
    // interval weights.
    let seconds_between = |from: &GyroSample, to: &GyroSample| -> f64 {
        to.t_ns.saturating_sub(from.t_ns) as f64 * NS_TO_S
    };

    // Integrated angle about x, y, z.
    let mut angle = [0.0_f64; 3];

    if samples.len() == 1 {
        for (acc, &rate) in angle.iter_mut().zip(first.omega_xyz.iter()) {
            *acc = f64::from(rate) * dt_frame;
        }
    } else {
        for pair in samples.windows(2) {
            let (earlier, later) = (&pair[0], &pair[1]);
            let sub_dt = seconds_between(earlier, later).max(0.0);
            let rates = earlier.omega_xyz.iter().zip(later.omega_xyz.iter());
            for (acc, (&r0, &r1)) in angle.iter_mut().zip(rates) {
                // Average in f64 so the f32 sum cannot lose precision.
                *acc += (f64::from(r0) + f64::from(r1)) * 0.5 * sub_dt;
            }
        }

        let span = seconds_between(first, last);
        if span > MIN_SPAN_S {
            let scale = dt_frame / span;
            for acc in &mut angle {
                *acc *= scale;
            }
        } else {
            // Degenerate timing (e.g. all samples share one timestamp): hold the mean rate
            // for the frame instead of silently reporting no motion.
            let count = samples.len() as f64;
            for (axis, acc) in angle.iter_mut().enumerate() {
                let rate_sum: f64 = samples.iter().map(|s| f64::from(s.omega_xyz[axis])).sum();
                *acc = rate_sum / count * dt_frame;
            }
        }
    }

    let [angle_x, angle_y, angle_z] = angle;
    let delta_tx = angle_y * f64::from(intrinsics.fx);
    let delta_ty = angle_x * f64::from(intrinsics.fy);
    (delta_tx, delta_ty, angle_z)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registration::TransformMatrix;

    /// Intrinsics with equal focal lengths and the principal point at the origin.
    fn make_intrinsics(fx: f32) -> CameraIntrinsics {
        let (fy, cx, cy) = (fx, 0.0, 0.0);
        CameraIntrinsics { fx, fy, cx, cy }
    }

    /// A one-element gyro batch stamped at time zero.
    fn single_sample(omega_xyz: [f32; 3]) -> [GyroSample; 1] {
        [GyroSample { t_ns: 0, omega_xyz }]
    }

    /// With full gyro trust, one degree of y-rotation per frame maps to `fx * 1°` pixels in x.
    #[test]
    fn test_gyro_fusion_pure_rotation() {
        let mut filter = GyroFusionFilter::new(1.0, make_intrinsics(500.0), 60.0);
        // One degree per frame at 60 frames per second, expressed as a rate.
        let rate_y = (1.0_f32).to_radians() * 60.0;
        let gyro = single_sample([0.0, rate_y, 0.0]);

        let fused = filter.step(TransformMatrix::translation(0.0, 0.0), &gyro);
        let (tx, _) = fused.get_translation();

        let expected = std::f64::consts::PI / 180.0 * 500.0;
        let error = (tx - expected).abs();
        assert!(
            error < 1.0,
            "Expected x-translation ≈{expected:.3}, got {tx:.3}"
        );
    }

    /// A high alpha keeps the output near the (motionless) gyro prediction despite a large visual shift.
    #[test]
    fn test_gyro_visual_disagree_alpha_high_trusts_gyro() {
        let mut filter = GyroFusionFilter::new(0.97, make_intrinsics(500.0), 30.0);
        let still = single_sample([0.0; 3]);

        let fused = filter.step(TransformMatrix::translation(100.0, 0.0), &still);
        let (tx, _) = fused.get_translation();

        assert!(
            tx < 5.0,
            "Expected fused tx < 5 px (alpha trusts gyro), got {tx:.3}"
        );
    }

    /// Ten gyro-only frames at 10 fps and unit rate accumulate one unit of angle times `fx`.
    #[test]
    fn test_gyro_only_predict_during_dropout() {
        let mut filter = GyroFusionFilter::new(1.0, make_intrinsics(500.0), 10.0);
        let gyro = single_sample([0.0, 1.0, 0.0]);

        (0..10).for_each(|_| {
            filter.predict_only(&gyro);
        });

        let (tx, _) = filter.current_translation();
        let expected = 1.0_f64 * 500.0;
        assert!(
            (tx - expected).abs() < 2.0,
            "Expected cumulative tx ≈{expected:.1} px, got {tx:.3}"
        );
    }

    /// `reset` zeroes the accumulated translation, and a following motionless step stays at zero.
    #[test]
    fn test_gyro_reset_clears_state() {
        let mut filter = GyroFusionFilter::new(0.97, make_intrinsics(500.0), 30.0);
        let visual = TransformMatrix::identity();

        filter.predict_only(&single_sample([0.0, 1.0, 0.0]));
        filter.reset();

        let (tx, ty) = filter.current_translation();
        assert!((tx).abs() < 1e-9);
        assert!((ty).abs() < 1e-9);

        let out = filter.step(visual, &single_sample([0.0, 0.0, 0.0]));
        let (tx2, _) = out.get_translation();
        assert!(tx2.abs() < 1e-6);
    }

    /// Two samples one second apart at unit rate integrate to one unit of angle times `fx`.
    #[test]
    fn test_gyro_multi_sample_integration() {
        let mut filter = GyroFusionFilter::new(1.0, make_intrinsics(1000.0), 1.0);
        let gyro: Vec<GyroSample> = [0_i64, 1_000_000_000]
            .iter()
            .map(|&t_ns| GyroSample {
                t_ns,
                omega_xyz: [0.0, 1.0, 0.0],
            })
            .collect();

        let fused = filter.step(TransformMatrix::translation(0.0, 0.0), &gyro);
        let (tx, _) = fused.get_translation();

        assert!(
            (tx - 1000.0).abs() < 2.0,
            "Expected tx ≈ 1000 px, got {tx:.3}"
        );
    }
}