1use super::linalg::{cholesky, inverse, matmul_rect, sub, transpose_rect};
8use crate::error::{SeqError, SeqResult};
9
/// Scaling parameters of the unscented transform.
///
/// They enter the sigma-point spread and the weights through
/// `lambda = alpha^2 * (n + kappa) - n`, where `n` is the state dimension.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UkfParams {
    /// Controls the spread of the sigma points; the filter accepts values in `(0, 1]`.
    pub alpha: f64,
    /// Extra term added to the covariance weight of the central sigma point;
    /// the filter rejects negative values.
    pub beta: f64,
    /// Secondary scaling term added to the state dimension; the filter rejects
    /// negative values.
    pub kappa: f64,
}
20
21impl Default for UkfParams {
22 fn default() -> Self {
23 Self {
24 alpha: 1e-3,
25 beta: 2.0,
26 kappa: 0.0,
27 }
28 }
29}
30
/// Per-step output of [`UnscentedKalmanFilter::run`].
///
/// Every vector holds one entry per processed observation, in time order.
/// Covariances are stored flat in row-major order.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UkfResult {
    /// State mean after the measurement update (length `dim_x` each).
    pub means: Vec<Vec<f64>>,
    /// State covariance after the measurement update (`dim_x * dim_x` values each).
    pub covs: Vec<Vec<f64>>,
    /// State mean after the time update, before the observation is used.
    pub pred_means: Vec<Vec<f64>>,
    /// State covariance after the time update, process noise included.
    pub pred_covs: Vec<Vec<f64>>,
}
43
44pub struct UnscentedKalmanFilter<'a> {
49 pub dim_x: usize,
51 pub dim_z: usize,
53 pub f: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
55 pub h: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
57 pub q: Vec<f64>,
59 pub r: Vec<f64>,
61 pub x0: Vec<f64>,
63 pub p0: Vec<f64>,
65 pub params: UkfParams,
67}
68
69fn compute_weights(n: usize, p: &UkfParams) -> (Vec<f64>, Vec<f64>) {
77 let lambda = p.alpha * p.alpha * (n as f64 + p.kappa) - n as f64;
78 let denom = n as f64 + lambda;
79 let n_pts = 2 * n + 1;
80
81 let mut wm = vec![0.0; n_pts];
82 let mut wc = vec![0.0; n_pts];
83
84 wm[0] = lambda / denom;
86 wc[0] = lambda / denom + (1.0 - p.alpha * p.alpha + p.beta);
87
88 let w_sym = 0.5 / denom;
90 for i in 1..n_pts {
91 wm[i] = w_sym;
92 wc[i] = w_sym;
93 }
94 (wm, wc)
95}
96
97fn sigma_points(
104 x_bar: &[f64],
105 p: &[f64],
106 n: usize,
107 p_params: &UkfParams,
108) -> SeqResult<Vec<Vec<f64>>> {
109 let lambda = p_params.alpha * p_params.alpha * (n as f64 + p_params.kappa) - n as f64;
110 let gamma = (n as f64 + lambda).sqrt();
111
112 let l = cholesky(p, n)?;
113
114 let n_pts = 2 * n + 1;
115 let mut pts: Vec<Vec<f64>> = Vec::with_capacity(n_pts);
116
117 pts.push(x_bar.to_vec());
119
120 for i in 0..n {
122 let col_i: Vec<f64> = (0..n).map(|r| l[r * n + i]).collect();
123 let pt: Vec<f64> = x_bar
124 .iter()
125 .zip(col_i.iter())
126 .map(|(m, c)| m + gamma * c)
127 .collect();
128 pts.push(pt);
129 }
130
131 for i in 0..n {
133 let col_i: Vec<f64> = (0..n).map(|r| l[r * n + i]).collect();
134 let pt: Vec<f64> = x_bar
135 .iter()
136 .zip(col_i.iter())
137 .map(|(m, c)| m - gamma * c)
138 .collect();
139 pts.push(pt);
140 }
141
142 Ok(pts)
143}
144
/// Returns the weighted sum `sum_i w[i] * pts[i]` over the first `dim`
/// components of every point.
///
/// Panics if `w` has fewer entries than `pts` or a point has fewer than `dim`
/// components (when `dim > 0`).
fn weighted_mean(pts: &[Vec<f64>], w: &[f64], dim: usize) -> Vec<f64> {
    pts.iter()
        .enumerate()
        .fold(vec![0.0; dim], |mut acc, (idx, point)| {
            for (d, slot) in acc.iter_mut().enumerate() {
                *slot += w[idx] * point[d];
            }
            acc
        })
}
155
/// Accumulates the weighted cross-covariance
/// `sum_i wc[i] * (u_i - u_bar) * (v_i - v_bar)^T`.
///
/// The result is a flat, row-major `dim_u x dim_v` matrix. Passing the same
/// point set and mean for `u` and `v` yields a plain weighted covariance.
fn weighted_cross_cov(
    u_pts: &[Vec<f64>],
    u_bar: &[f64],
    v_pts: &[Vec<f64>],
    v_bar: &[f64],
    wc: &[f64],
    dim_u: usize,
    dim_v: usize,
) -> Vec<f64> {
    // Component-wise deviation of one point from its mean.
    let deviation = |pt: &[f64], centre: &[f64]| -> Vec<f64> {
        pt.iter().zip(centre).map(|(x, c)| x - c).collect()
    };

    let mut cov = vec![0.0; dim_u * dim_v];
    for (i, u_pt) in u_pts.iter().enumerate() {
        let du = deviation(u_pt, u_bar);
        let dv = deviation(&v_pts[i], v_bar);
        // Walk the flat matrix and recover (row, column) from the index;
        // `cov` is empty whenever `dim_v == 0`, so the division is safe.
        for (flat, entry) in cov.iter_mut().enumerate() {
            let (r, c) = (flat / dim_v, flat % dim_v);
            *entry += wc[i] * du[r] * dv[c];
        }
    }
    cov
}
189
190impl<'a> UnscentedKalmanFilter<'a> {
195 fn validate(&self, z: &[f64]) -> SeqResult<()> {
197 if z.is_empty() {
198 return Err(SeqError::EmptyInput);
199 }
200 if z.len() % self.dim_z != 0 {
201 return Err(SeqError::DimensionMismatch {
202 a: z.len(),
203 b: self.dim_z,
204 });
205 }
206 if self.q.len() != self.dim_x * self.dim_x {
207 return Err(SeqError::ShapeMismatch {
208 expected: self.dim_x * self.dim_x,
209 got: self.q.len(),
210 });
211 }
212 if self.r.len() != self.dim_z * self.dim_z {
213 return Err(SeqError::ShapeMismatch {
214 expected: self.dim_z * self.dim_z,
215 got: self.r.len(),
216 });
217 }
218 if self.x0.len() != self.dim_x {
219 return Err(SeqError::ShapeMismatch {
220 expected: self.dim_x,
221 got: self.x0.len(),
222 });
223 }
224 if self.p0.len() != self.dim_x * self.dim_x {
225 return Err(SeqError::ShapeMismatch {
226 expected: self.dim_x * self.dim_x,
227 got: self.p0.len(),
228 });
229 }
230 if self.params.alpha <= 0.0 || self.params.alpha > 1.0 {
231 return Err(SeqError::InvalidParameter {
232 name: "alpha".to_string(),
233 value: self.params.alpha,
234 });
235 }
236 if self.params.beta < 0.0 {
237 return Err(SeqError::InvalidParameter {
238 name: "beta".to_string(),
239 value: self.params.beta,
240 });
241 }
242 if self.params.kappa < 0.0 {
243 return Err(SeqError::InvalidParameter {
244 name: "kappa".to_string(),
245 value: self.params.kappa,
246 });
247 }
248 Ok(())
249 }
250
251 pub fn run(&self, z: &[f64]) -> SeqResult<UkfResult> {
256 self.validate(z)?;
257
258 let nx = self.dim_x;
259 let nz = self.dim_z;
260 let t_max = z.len() / nz;
261
262 let (wm, wc) = compute_weights(nx, &self.params);
263
264 let mut x = self.x0.clone();
265 let mut p = self.p0.clone();
266
267 let mut means = Vec::with_capacity(t_max);
268 let mut covs = Vec::with_capacity(t_max);
269 let mut pred_means = Vec::with_capacity(t_max);
270 let mut pred_covs = Vec::with_capacity(t_max);
271
272 for t in 0..t_max {
273 let chi = sigma_points(&x, &p, nx, &self.params)?;
277
278 let gamma_pts: Vec<Vec<f64>> = chi.iter().map(|s| (self.f)(s)).collect();
282 let x_pred = weighted_mean(&gamma_pts, &wm, nx);
283
284 let mut p_pred =
286 weighted_cross_cov(&gamma_pts, &x_pred, &gamma_pts, &x_pred, &wc, nx, nx);
287 for k in 0..p_pred.len() {
288 p_pred[k] += self.q[k];
289 }
290
291 pred_means.push(x_pred.clone());
292 pred_covs.push(p_pred.clone());
293
294 let chi_pred = sigma_points(&x_pred, &p_pred, nx, &self.params)?;
299 let upsilon_pts: Vec<Vec<f64>> = chi_pred.iter().map(|s| (self.h)(s)).collect();
300 let y_pred = weighted_mean(&upsilon_pts, &wm, nz);
301
302 let mut s_mat =
304 weighted_cross_cov(&upsilon_pts, &y_pred, &upsilon_pts, &y_pred, &wc, nz, nz);
305 for k in 0..s_mat.len() {
306 s_mat[k] += self.r[k];
307 }
308
309 let p_xy = weighted_cross_cov(&chi_pred, &x_pred, &upsilon_pts, &y_pred, &wc, nx, nz);
311
312 let s_inv = inverse(&s_mat, nz)?;
314 let k_gain = matmul_rect(&p_xy, &s_inv, nx, nz, nz);
315
316 let z_t = &z[t * nz..(t + 1) * nz];
318 let nu = sub(z_t, &y_pred);
319
320 let k_nu = matmul_rect(&k_gain, &nu, nx, nz, 1);
322 x = x_pred.iter().zip(k_nu.iter()).map(|(a, b)| a + b).collect();
323
324 let ks = matmul_rect(&k_gain, &s_mat, nx, nz, nz);
326 let k_t = transpose_rect(&k_gain, nx, nz);
327 let kskt = matmul_rect(&ks, &k_t, nx, nz, nx);
328 p = sub(&p_pred, &kskt);
329
330 means.push(x.clone());
331 covs.push(p.clone());
332 }
333
334 Ok(UkfResult {
335 means,
336 covs,
337 pred_means,
338 pred_covs,
339 })
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kalman::kalman_filter::KalmanFilter;

    // Identity model on the first state component, shared by the scalar tests.
    fn first_component(x: &[f64]) -> Vec<f64> {
        vec![x[0]]
    }

    // Scalar filter with identity dynamics, identity observation and default tuning.
    fn make_1d_ukf<'a>(q_val: f64, r_val: f64, x0: f64, p0: f64) -> UnscentedKalmanFilter<'a> {
        UnscentedKalmanFilter {
            dim_x: 1,
            dim_z: 1,
            f: Box::new(first_component),
            h: Box::new(first_component),
            q: vec![q_val],
            r: vec![r_val],
            x0: vec![x0],
            p0: vec![p0],
            params: UkfParams::default(),
        }
    }

    #[test]
    fn default_params_ok() {
        let UkfParams { alpha, beta, kappa } = UkfParams::default();
        assert!((alpha - 1e-3).abs() < 1e-15);
        assert!((beta - 2.0).abs() < 1e-15);
        assert!(kappa.abs() < 1e-15);
    }

    // For a linear model the UKF must reproduce the ordinary Kalman filter.
    #[test]
    fn ukf_linear_matches_kf() {
        let z = vec![1.0, 1.05, 0.95, 1.02, 1.0];
        let (q_val, r_val) = (0.01, 0.05);

        let ukf_res = make_1d_ukf(q_val, r_val, 0.0, 1.0)
            .run(&z)
            .expect("UKF run failed");

        let kf = KalmanFilter::new(
            1,
            1,
            vec![1.0],
            vec![1.0],
            vec![q_val],
            vec![r_val],
            vec![0.0],
            vec![1.0],
        )
        .expect("ok");
        let kf_res = kf.filter(&z).expect("KF run failed");

        for t in 0..z.len() {
            let ukf_mean = ukf_res.means[t][0];
            let kf_mean = kf_res.means[t][0];
            let diff = (ukf_mean - kf_mean).abs();
            assert!(
                diff < 1e-6,
                "step {t}: UKF={:.10} KF={:.10} diff={:.2e}",
                ukf_mean,
                kf_mean,
                diff
            );
        }
    }

    // A precise sensor reading a constant must pull the state onto that constant.
    #[test]
    fn ukf_identity_state() {
        let z = vec![2.0; 5];
        let res = make_1d_ukf(0.001, 1e-6, 0.0, 10.0).run(&z).expect("ok");
        let final_idx = res.means.len() - 1;
        let last = res.means[final_idx][0];
        assert!((last - 2.0).abs() < 0.01, "expected ~2.0 got {last}");
    }

    #[test]
    fn ukf_output_length() {
        let z: Vec<f64> = (0..7).map(|i| i as f64 * 0.1).collect();
        let res = make_1d_ukf(0.01, 0.05, 0.0, 1.0).run(&z).expect("ok");
        assert_eq!(res.means.len(), 7);
        for (t, mean) in res.means.iter().enumerate() {
            assert_eq!(mean.len(), 1, "means dim mismatch at t={t}");
        }
    }

    #[test]
    fn ukf_cov_positive_diagonal() {
        let z: Vec<f64> = (0..10).map(|i| (i as f64) * 0.1 + 1.0).collect();
        let res = make_1d_ukf(0.01, 0.05, 0.0, 1.0).run(&z).expect("ok");
        res.covs.iter().enumerate().for_each(|(t, cov)| {
            assert!(cov[0] > 0.0, "non-positive diagonal at t={t}: {}", cov[0]);
        });
    }

    #[test]
    fn ukf_pred_means_correct_length() {
        let z = vec![1.0, 2.0, 3.0];
        let res = make_1d_ukf(0.01, 0.1, 0.0, 1.0).run(&z).expect("ok");
        assert_eq!(res.pred_means.len(), 3);
        for (t, pred) in res.pred_means.iter().enumerate() {
            assert_eq!(pred.len(), 1, "pred_means dim at t={t}");
        }
    }

    // Nonlinear dynamics x' = cos(x): the run must succeed and keep the variance positive.
    #[test]
    fn ukf_nonlinear_cos() {
        let z = vec![0.9, 0.95, 0.98, 0.97, 0.96];
        let ukf = UnscentedKalmanFilter {
            dim_x: 1,
            dim_z: 1,
            f: Box::new(|x: &[f64]| vec![x[0].cos()]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            q: vec![0.1],
            r: vec![0.5],
            x0: vec![0.5],
            p0: vec![1.0],
            params: UkfParams::default(),
        };
        let res = ukf.run(&z).expect("nonlinear UKF failed");
        assert_eq!(res.means.len(), 5);
        res.covs.iter().enumerate().for_each(|(t, cov)| {
            assert!(cov[0] > 0.0, "negative cov at t={t}");
        });
    }

    #[test]
    fn ukf_tracks_slowly_varying() {
        let z: Vec<f64> = (0..20).map(|i| 1.0 + i as f64 * 0.05).collect();
        let res = make_1d_ukf(0.01, 0.05, 1.0, 1.0).run(&z).expect("ok");

        let true_val = z[19];
        let last_mean = res.means[19][0];
        let last_std = res.covs[19][0].sqrt();
        let error = (last_mean - true_val).abs();
        assert!(
            error < 3.0 * last_std + 0.5,
            "drifted too far: mean={last_mean:.4} true={true_val:.4} std={last_std:.4}"
        );
    }

    #[test]
    fn err_empty_obs() {
        let outcome = make_1d_ukf(0.01, 0.05, 0.0, 1.0).run(&[]);
        assert!(matches!(outcome, Err(SeqError::EmptyInput)));
    }

    // Three values cannot be split into observations of size two.
    #[test]
    fn err_z_len_not_multiple_of_dim_z() {
        let ukf = UnscentedKalmanFilter {
            dim_x: 1,
            dim_z: 2,
            f: Box::new(|x: &[f64]| x.to_vec()),
            h: Box::new(|x: &[f64]| vec![x[0], x[0]]),
            q: vec![0.01, 0.0, 0.0, 0.01],
            r: vec![0.1, 0.0, 0.0, 0.1],
            x0: vec![0.0],
            p0: vec![1.0],
            params: UkfParams::default(),
        };
        let outcome = ukf.run(&[1.0, 2.0, 3.0]);
        assert!(matches!(outcome, Err(SeqError::DimensionMismatch { .. })));
    }

    // `q` has one entry where a 2x2 matrix is required.
    #[test]
    fn err_q_wrong_shape() {
        let ukf = UnscentedKalmanFilter {
            dim_x: 2,
            dim_z: 1,
            f: Box::new(|x: &[f64]| x.to_vec()),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            q: vec![0.01],
            r: vec![0.1],
            x0: vec![0.0, 0.0],
            p0: vec![1.0, 0.0, 0.0, 1.0],
            params: UkfParams::default(),
        };
        let outcome = ukf.run(&[1.0, 2.0]);
        assert!(matches!(outcome, Err(SeqError::ShapeMismatch { .. })));
    }

    // `r` has three entries where a 1x1 matrix is required.
    #[test]
    fn err_r_wrong_shape() {
        let ukf = UnscentedKalmanFilter {
            dim_x: 1,
            dim_z: 1,
            f: Box::new(|x: &[f64]| vec![x[0]]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            q: vec![0.01],
            r: vec![0.1, 0.0, 0.0],
            x0: vec![0.0],
            p0: vec![1.0],
            params: UkfParams::default(),
        };
        let outcome = ukf.run(&[1.0, 2.0]);
        assert!(matches!(outcome, Err(SeqError::ShapeMismatch { .. })));
    }

    // `x0` has one entry where two are required.
    #[test]
    fn err_x0_wrong_len() {
        let ukf = UnscentedKalmanFilter {
            dim_x: 2,
            dim_z: 1,
            f: Box::new(|x: &[f64]| x.to_vec()),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            q: vec![0.01, 0.0, 0.0, 0.01],
            r: vec![0.1],
            x0: vec![0.0],
            p0: vec![1.0, 0.0, 0.0, 1.0],
            params: UkfParams::default(),
        };
        let outcome = ukf.run(&[1.0, 2.0]);
        assert!(matches!(outcome, Err(SeqError::ShapeMismatch { .. })));
    }

    // Full 2-D state with 2-D observations: three observation pairs give three steps.
    #[test]
    fn sigma_point_count() {
        let z = vec![1.0, 2.0, 1.1, 2.1, 0.9, 1.9];
        let ukf = UnscentedKalmanFilter {
            dim_x: 2,
            dim_z: 2,
            f: Box::new(|x: &[f64]| vec![x[0], x[1]]),
            h: Box::new(|x: &[f64]| vec![x[0], x[1]]),
            q: vec![0.01, 0.0, 0.0, 0.01],
            r: vec![0.1, 0.0, 0.0, 0.1],
            x0: vec![0.0, 0.0],
            p0: vec![1.0, 0.0, 0.0, 1.0],
            params: UkfParams::default(),
        };
        let res = ukf.run(&z).expect("sigma_point_count test failed");
        assert_eq!(res.means.len(), 3);
    }

    // Constant-velocity model where only the position is observed.
    #[test]
    fn ukf_2d_state_1d_obs() {
        let dt = 1.0_f64;
        let z: Vec<f64> = (0..8).map(|t| t as f64 * 1.0).collect();
        let ukf = UnscentedKalmanFilter {
            dim_x: 2,
            dim_z: 1,
            f: Box::new(move |x: &[f64]| vec![x[0] + dt * x[1], x[1]]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            q: vec![0.01, 0.0, 0.0, 0.01],
            r: vec![0.5],
            x0: vec![0.0, 1.0],
            p0: vec![1.0, 0.0, 0.0, 1.0],
            params: UkfParams::default(),
        };
        let res = ukf.run(&z).expect("2d state 1d obs failed");
        assert_eq!(res.means.len(), 8);
        res.means.iter().enumerate().for_each(|(t, m)| {
            assert_eq!(m.len(), 2, "state dim at t={t}");
        });
    }

    #[test]
    fn ukf_dim_x_1_dim_z_1() {
        let z = vec![1.0; 20];
        let res = make_1d_ukf(0.01, 0.1, 0.0, 1.0)
            .run(&z)
            .expect("simplest case failed");
        let last = res.means[19][0];
        assert!((last - 1.0).abs() < 0.05, "did not converge: {last}");
    }

    #[test]
    fn ukf_weights_sum_to_one() {
        let p = UkfParams::default();
        for n in [1usize, 2, 3, 5, 10] {
            let (wm, _wc) = compute_weights(n, &p);
            let sum: f64 = wm.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-9,
                "weights don't sum to 1 for n={n}: sum={sum}"
            );
        }
    }
}