Skip to main content

yscv_track/
kalman.rs

1//! Linear Kalman filter for bounding-box tracking.
2//!
3//! State vector: `[cx, cy, w, h, vx, vy, vw, vh]` — center position, size,
4//! and their velocities.  The measurement is `[cx, cy, w, h]`.
5//!
6//! All matrices are stored as flat `[f32; N*N]` arrays for simplicity
7//! (no external linear-algebra dependency).
8
9use yscv_detect::BoundingBox;
10
/// Number of measured quantities per observation: `[cx, cy, w, h]`.
const MEAS_DIM: usize = 4;
/// Number of state variables: every measured quantity plus its velocity.
const STATE_DIM: usize = 2 * MEAS_DIM;
15
16/// A 2-D Kalman filter for bounding-box tracking.
17#[derive(Debug, Clone)]
18pub struct KalmanFilter {
19    /// State estimate `[cx, cy, w, h, vx, vy, vw, vh]`.
20    pub(crate) x: [f32; STATE_DIM],
21    /// Error covariance matrix (8×8, row-major).
22    pub(crate) p: [f32; STATE_DIM * STATE_DIM],
23    /// Process noise covariance (8×8).
24    q: [f32; STATE_DIM * STATE_DIM],
25    /// Measurement noise covariance (4×4).
26    r: [f32; MEAS_DIM * MEAS_DIM],
27}
28
29impl KalmanFilter {
30    /// Create a Kalman filter initialized from a bounding box.
31    pub fn new(bbox: BoundingBox) -> Self {
32        let cx = (bbox.x1 + bbox.x2) * 0.5;
33        let cy = (bbox.y1 + bbox.y2) * 0.5;
34        let w = bbox.width();
35        let h = bbox.height();
36
37        let mut x = [0.0f32; STATE_DIM];
38        x[0] = cx;
39        x[1] = cy;
40        x[2] = w;
41        x[3] = h;
42        // velocities start at zero
43
44        // Initial covariance: large uncertainty on velocities.
45        let mut p = [0.0f32; STATE_DIM * STATE_DIM];
46        for i in 0..4 {
47            p[i * STATE_DIM + i] = 10.0;
48        }
49        for i in 4..8 {
50            p[i * STATE_DIM + i] = 100.0;
51        }
52
53        // Process noise
54        let mut q = [0.0f32; STATE_DIM * STATE_DIM];
55        for i in 0..4 {
56            q[i * STATE_DIM + i] = 1.0;
57        }
58        for i in 4..8 {
59            q[i * STATE_DIM + i] = 0.01;
60        }
61
62        // Measurement noise
63        let mut r = [0.0f32; MEAS_DIM * MEAS_DIM];
64        for i in 0..MEAS_DIM {
65            r[i * MEAS_DIM + i] = 1.0;
66        }
67
68        Self { x, p, q, r }
69    }
70
71    /// Predict the next state (one time step dt=1).
72    #[allow(clippy::needless_range_loop)]
73    pub fn predict(&mut self) {
74        // x' = F * x  where F is identity + velocity rows
75        // x[0..4] += x[4..8]
76        for i in 0..4 {
77            self.x[i] += self.x[i + 4];
78        }
79
80        // P' = F * P * F^T + Q
81        // F has 1s on diagonal plus 1s at (i, i+4) for i in 0..4
82        let f = transition_matrix();
83        let ft = transpose_8x8(&f);
84        let fp = mat_mul_8x8(&f, &self.p);
85        let fpft = mat_mul_8x8(&fp, &ft);
86        for i in 0..STATE_DIM * STATE_DIM {
87            self.p[i] = fpft[i] + self.q[i];
88        }
89    }
90
91    /// Update the filter with a measurement `[cx, cy, w, h]`.
92    pub fn update(&mut self, measurement: [f32; MEAS_DIM]) {
93        // H is the 4×8 measurement matrix: identity for first 4 cols, zeros for rest.
94        // Innovation: y = z - H * x
95        let mut y = [0.0f32; MEAS_DIM];
96        for i in 0..MEAS_DIM {
97            y[i] = measurement[i] - self.x[i];
98        }
99
100        // S = H * P * H^T + R  (4×4)
101        // Since H selects first 4 rows/cols of P:
102        let mut s = [0.0f32; MEAS_DIM * MEAS_DIM];
103        for i in 0..MEAS_DIM {
104            for j in 0..MEAS_DIM {
105                s[i * MEAS_DIM + j] = self.p[i * STATE_DIM + j] + self.r[i * MEAS_DIM + j];
106            }
107        }
108
109        // K = P * H^T * S^{-1}  (8×4)
110        // P * H^T is first 4 columns of P (8×4)
111        let s_inv = invert_4x4(&s);
112        let mut k = [0.0f32; STATE_DIM * MEAS_DIM];
113        for i in 0..STATE_DIM {
114            for j in 0..MEAS_DIM {
115                let mut sum = 0.0f32;
116                for m in 0..MEAS_DIM {
117                    sum += self.p[i * STATE_DIM + m] * s_inv[m * MEAS_DIM + j];
118                }
119                k[i * MEAS_DIM + j] = sum;
120            }
121        }
122
123        // x = x + K * y
124        for i in 0..STATE_DIM {
125            let mut sum = 0.0f32;
126            for j in 0..MEAS_DIM {
127                sum += k[i * MEAS_DIM + j] * y[j];
128            }
129            self.x[i] += sum;
130        }
131
132        // P = (I - K * H) * P
133        // K * H is 8×8 where (K*H)[i][j] = K[i][j] for j < 4, 0 otherwise
134        let mut kh = [0.0f32; STATE_DIM * STATE_DIM];
135        for i in 0..STATE_DIM {
136            for j in 0..MEAS_DIM {
137                kh[i * STATE_DIM + j] = k[i * MEAS_DIM + j];
138            }
139        }
140        // I - K*H
141        let mut i_kh = [0.0f32; STATE_DIM * STATE_DIM];
142        for i in 0..STATE_DIM {
143            for j in 0..STATE_DIM {
144                i_kh[i * STATE_DIM + j] = if i == j { 1.0 } else { 0.0 } - kh[i * STATE_DIM + j];
145            }
146        }
147        let new_p = mat_mul_8x8(&i_kh, &self.p);
148        self.p = new_p;
149    }
150
151    /// Get current state as bounding box.
152    pub fn bbox(&self) -> BoundingBox {
153        let cx = self.x[0];
154        let cy = self.x[1];
155        let w = self.x[2].max(1e-3);
156        let h = self.x[3].max(1e-3);
157        BoundingBox {
158            x1: cx - w * 0.5,
159            y1: cy - h * 0.5,
160            x2: cx + w * 0.5,
161            y2: cy + h * 0.5,
162        }
163    }
164
165    /// Get predicted bbox without mutating state.
166    pub fn predicted_bbox(&self) -> BoundingBox {
167        let cx = self.x[0] + self.x[4];
168        let cy = self.x[1] + self.x[5];
169        let w = (self.x[2] + self.x[6]).max(1e-3);
170        let h = (self.x[3] + self.x[7]).max(1e-3);
171        BoundingBox {
172            x1: cx - w * 0.5,
173            y1: cy - h * 0.5,
174            x2: cx + w * 0.5,
175            y2: cy + h * 0.5,
176        }
177    }
178}
179
180// ── Small matrix helpers (8×8 and 4×4) ─────────────────────────────
181
182fn transition_matrix() -> [f32; STATE_DIM * STATE_DIM] {
183    let mut f = [0.0f32; STATE_DIM * STATE_DIM];
184    // Identity
185    for i in 0..STATE_DIM {
186        f[i * STATE_DIM + i] = 1.0;
187    }
188    // Position += velocity (dt=1)
189    for i in 0..4 {
190        f[i * STATE_DIM + i + 4] = 1.0;
191    }
192    f
193}
194
195fn transpose_8x8(a: &[f32; STATE_DIM * STATE_DIM]) -> [f32; STATE_DIM * STATE_DIM] {
196    let mut out = [0.0f32; STATE_DIM * STATE_DIM];
197    for i in 0..STATE_DIM {
198        for j in 0..STATE_DIM {
199            out[j * STATE_DIM + i] = a[i * STATE_DIM + j];
200        }
201    }
202    out
203}
204
205fn mat_mul_8x8(
206    a: &[f32; STATE_DIM * STATE_DIM],
207    b: &[f32; STATE_DIM * STATE_DIM],
208) -> [f32; STATE_DIM * STATE_DIM] {
209    let mut out = [0.0f32; STATE_DIM * STATE_DIM];
210    for i in 0..STATE_DIM {
211        for j in 0..STATE_DIM {
212            let mut sum = 0.0f32;
213            for k in 0..STATE_DIM {
214                sum += a[i * STATE_DIM + k] * b[k * STATE_DIM + j];
215            }
216            out[i * STATE_DIM + j] = sum;
217        }
218    }
219    out
220}
221
222/// Invert a 4×4 matrix using the adjugate method.
223#[allow(clippy::needless_range_loop)]
224fn invert_4x4(m: &[f32; MEAS_DIM * MEAS_DIM]) -> [f32; MEAS_DIM * MEAS_DIM] {
225    let n = MEAS_DIM;
226    // Gauss-Jordan elimination on augmented matrix
227    let mut aug = [[0.0f32; 8]; 4];
228    for i in 0..n {
229        for j in 0..n {
230            aug[i][j] = m[i * n + j];
231        }
232        aug[i][n + i] = 1.0;
233    }
234
235    for col in 0..n {
236        // Partial pivoting
237        let mut max_row = col;
238        let mut max_val = aug[col][col].abs();
239        for row in col + 1..n {
240            let v = aug[row][col].abs();
241            if v > max_val {
242                max_val = v;
243                max_row = row;
244            }
245        }
246        aug.swap(col, max_row);
247
248        let pivot = aug[col][col];
249        if pivot.abs() < 1e-12 {
250            // Singular — return identity as fallback
251            let mut result = [0.0f32; MEAS_DIM * MEAS_DIM];
252            for i in 0..n {
253                result[i * n + i] = 1.0;
254            }
255            return result;
256        }
257
258        let inv_pivot = 1.0 / pivot;
259        for j in 0..2 * n {
260            aug[col][j] *= inv_pivot;
261        }
262
263        for row in 0..n {
264            if row == col {
265                continue;
266            }
267            let factor = aug[row][col];
268            for j in 0..2 * n {
269                aug[row][j] -= factor * aug[col][j];
270            }
271        }
272    }
273
274    let mut result = [0.0f32; MEAS_DIM * MEAS_DIM];
275    for i in 0..n {
276        for j in 0..n {
277            result[i * n + j] = aug[i][n + j];
278        }
279    }
280    result
281}