1use yscv_detect::BoundingBox;
10
/// Number of measured quantities: box centre `cx`, `cy`, width `w` and height `h`.
const MEAS_DIM: usize = 4;
/// Length of the filter state: the four measured quantities followed by one
/// per-step rate for each of them.
const STATE_DIM: usize = 2 * MEAS_DIM;
15
16#[derive(Debug, Clone)]
18pub struct KalmanFilter {
19 pub(crate) x: [f32; STATE_DIM],
21 pub(crate) p: [f32; STATE_DIM * STATE_DIM],
23 q: [f32; STATE_DIM * STATE_DIM],
25 r: [f32; MEAS_DIM * MEAS_DIM],
27}
28
29impl KalmanFilter {
30 pub fn new(bbox: BoundingBox) -> Self {
32 let cx = (bbox.x1 + bbox.x2) * 0.5;
33 let cy = (bbox.y1 + bbox.y2) * 0.5;
34 let w = bbox.width();
35 let h = bbox.height();
36
37 let mut x = [0.0f32; STATE_DIM];
38 x[0] = cx;
39 x[1] = cy;
40 x[2] = w;
41 x[3] = h;
42 let mut p = [0.0f32; STATE_DIM * STATE_DIM];
46 for i in 0..4 {
47 p[i * STATE_DIM + i] = 10.0;
48 }
49 for i in 4..8 {
50 p[i * STATE_DIM + i] = 100.0;
51 }
52
53 let mut q = [0.0f32; STATE_DIM * STATE_DIM];
55 for i in 0..4 {
56 q[i * STATE_DIM + i] = 1.0;
57 }
58 for i in 4..8 {
59 q[i * STATE_DIM + i] = 0.01;
60 }
61
62 let mut r = [0.0f32; MEAS_DIM * MEAS_DIM];
64 for i in 0..MEAS_DIM {
65 r[i * MEAS_DIM + i] = 1.0;
66 }
67
68 Self { x, p, q, r }
69 }
70
71 #[allow(clippy::needless_range_loop)]
73 pub fn predict(&mut self) {
74 for i in 0..4 {
77 self.x[i] += self.x[i + 4];
78 }
79
80 let f = transition_matrix();
83 let ft = transpose_8x8(&f);
84 let fp = mat_mul_8x8(&f, &self.p);
85 let fpft = mat_mul_8x8(&fp, &ft);
86 for i in 0..STATE_DIM * STATE_DIM {
87 self.p[i] = fpft[i] + self.q[i];
88 }
89 }
90
91 pub fn update(&mut self, measurement: [f32; MEAS_DIM]) {
93 let mut y = [0.0f32; MEAS_DIM];
96 for i in 0..MEAS_DIM {
97 y[i] = measurement[i] - self.x[i];
98 }
99
100 let mut s = [0.0f32; MEAS_DIM * MEAS_DIM];
103 for i in 0..MEAS_DIM {
104 for j in 0..MEAS_DIM {
105 s[i * MEAS_DIM + j] = self.p[i * STATE_DIM + j] + self.r[i * MEAS_DIM + j];
106 }
107 }
108
109 let s_inv = invert_4x4(&s);
112 let mut k = [0.0f32; STATE_DIM * MEAS_DIM];
113 for i in 0..STATE_DIM {
114 for j in 0..MEAS_DIM {
115 let mut sum = 0.0f32;
116 for m in 0..MEAS_DIM {
117 sum += self.p[i * STATE_DIM + m] * s_inv[m * MEAS_DIM + j];
118 }
119 k[i * MEAS_DIM + j] = sum;
120 }
121 }
122
123 for i in 0..STATE_DIM {
125 let mut sum = 0.0f32;
126 for j in 0..MEAS_DIM {
127 sum += k[i * MEAS_DIM + j] * y[j];
128 }
129 self.x[i] += sum;
130 }
131
132 let mut kh = [0.0f32; STATE_DIM * STATE_DIM];
135 for i in 0..STATE_DIM {
136 for j in 0..MEAS_DIM {
137 kh[i * STATE_DIM + j] = k[i * MEAS_DIM + j];
138 }
139 }
140 let mut i_kh = [0.0f32; STATE_DIM * STATE_DIM];
142 for i in 0..STATE_DIM {
143 for j in 0..STATE_DIM {
144 i_kh[i * STATE_DIM + j] = if i == j { 1.0 } else { 0.0 } - kh[i * STATE_DIM + j];
145 }
146 }
147 let new_p = mat_mul_8x8(&i_kh, &self.p);
148 self.p = new_p;
149 }
150
151 pub fn bbox(&self) -> BoundingBox {
153 let cx = self.x[0];
154 let cy = self.x[1];
155 let w = self.x[2].max(1e-3);
156 let h = self.x[3].max(1e-3);
157 BoundingBox {
158 x1: cx - w * 0.5,
159 y1: cy - h * 0.5,
160 x2: cx + w * 0.5,
161 y2: cy + h * 0.5,
162 }
163 }
164
165 pub fn predicted_bbox(&self) -> BoundingBox {
167 let cx = self.x[0] + self.x[4];
168 let cy = self.x[1] + self.x[5];
169 let w = (self.x[2] + self.x[6]).max(1e-3);
170 let h = (self.x[3] + self.x[7]).max(1e-3);
171 BoundingBox {
172 x1: cx - w * 0.5,
173 y1: cy - h * 0.5,
174 x2: cx + w * 0.5,
175 y2: cy + h * 0.5,
176 }
177 }
178}
179
180fn transition_matrix() -> [f32; STATE_DIM * STATE_DIM] {
183 let mut f = [0.0f32; STATE_DIM * STATE_DIM];
184 for i in 0..STATE_DIM {
186 f[i * STATE_DIM + i] = 1.0;
187 }
188 for i in 0..4 {
190 f[i * STATE_DIM + i + 4] = 1.0;
191 }
192 f
193}
194
195fn transpose_8x8(a: &[f32; STATE_DIM * STATE_DIM]) -> [f32; STATE_DIM * STATE_DIM] {
196 let mut out = [0.0f32; STATE_DIM * STATE_DIM];
197 for i in 0..STATE_DIM {
198 for j in 0..STATE_DIM {
199 out[j * STATE_DIM + i] = a[i * STATE_DIM + j];
200 }
201 }
202 out
203}
204
205fn mat_mul_8x8(
206 a: &[f32; STATE_DIM * STATE_DIM],
207 b: &[f32; STATE_DIM * STATE_DIM],
208) -> [f32; STATE_DIM * STATE_DIM] {
209 let mut out = [0.0f32; STATE_DIM * STATE_DIM];
210 for i in 0..STATE_DIM {
211 for j in 0..STATE_DIM {
212 let mut sum = 0.0f32;
213 for k in 0..STATE_DIM {
214 sum += a[i * STATE_DIM + k] * b[k * STATE_DIM + j];
215 }
216 out[i * STATE_DIM + j] = sum;
217 }
218 }
219 out
220}
221
222#[allow(clippy::needless_range_loop)]
224fn invert_4x4(m: &[f32; MEAS_DIM * MEAS_DIM]) -> [f32; MEAS_DIM * MEAS_DIM] {
225 let n = MEAS_DIM;
226 let mut aug = [[0.0f32; 8]; 4];
228 for i in 0..n {
229 for j in 0..n {
230 aug[i][j] = m[i * n + j];
231 }
232 aug[i][n + i] = 1.0;
233 }
234
235 for col in 0..n {
236 let mut max_row = col;
238 let mut max_val = aug[col][col].abs();
239 for row in col + 1..n {
240 let v = aug[row][col].abs();
241 if v > max_val {
242 max_val = v;
243 max_row = row;
244 }
245 }
246 aug.swap(col, max_row);
247
248 let pivot = aug[col][col];
249 if pivot.abs() < 1e-12 {
250 let mut result = [0.0f32; MEAS_DIM * MEAS_DIM];
252 for i in 0..n {
253 result[i * n + i] = 1.0;
254 }
255 return result;
256 }
257
258 let inv_pivot = 1.0 / pivot;
259 for j in 0..2 * n {
260 aug[col][j] *= inv_pivot;
261 }
262
263 for row in 0..n {
264 if row == col {
265 continue;
266 }
267 let factor = aug[row][col];
268 for j in 0..2 * n {
269 aug[row][j] -= factor * aug[col][j];
270 }
271 }
272 }
273
274 let mut result = [0.0f32; MEAS_DIM * MEAS_DIM];
275 for i in 0..n {
276 for j in 0..n {
277 result[i * n + j] = aug[i][n + j];
278 }
279 }
280 result
281}