1use crate::error::{SeqError, SeqResult};
9use crate::handle::LcgRng;
10
/// Size and tuning parameters for a [`ParticleFilter`] run.
#[derive(Clone, Debug)]
pub struct ParticleConfig {
    /// Number of particles carried through the filter (the filter rejects values below 2).
    pub n_particles: usize,
    /// Length of each particle's state vector.
    pub dim_x: usize,
    /// Number of observed values per time step; the observation buffer is split
    /// into consecutive chunks of this size.
    pub dim_z: usize,
    /// Resampling is triggered when the effective sample size falls below
    /// `resample_threshold * n_particles`.
    pub resample_threshold: f64,
}

impl Default for ParticleConfig {
    /// 100 particles, scalar state and observation, resample below half the particle count.
    fn default() -> Self {
        ParticleConfig {
            resample_threshold: 0.5,
            dim_z: 1,
            dim_x: 1,
            n_particles: 100,
        }
    }
}
34
/// Per-step output of [`ParticleFilter::run`].
#[derive(Clone, Debug)]
pub struct ParticleResult {
    /// Weighted particle mean (length `dim_x`) for every time step, computed
    /// before any resampling performed at that step.
    pub means: Vec<Vec<f64>>,
    /// Effective sample size `1 / Σ wᵢ²` of the normalised weights at every step.
    pub eff_sizes: Vec<f64>,
    /// How many time steps triggered a resampling pass.
    pub n_resamples: usize,
}
45
46pub struct ParticleFilter<'a> {
51 pub cfg: ParticleConfig,
53 pub q_chol: Vec<f64>,
55 pub r: Vec<f64>,
57 pub f: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
59 pub h: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
61 pub x0: Vec<f64>,
63 pub p0_chol: Vec<f64>,
65}
66
67fn sample_gaussian(mean: &[f64], l_chol: &[f64], dim: usize, rng: &mut LcgRng) -> Vec<f64> {
74 let eps: Vec<f64> = (0..dim).map(|_| rng.next_normal()).collect();
76 let mut noise = vec![0.0; dim];
78 for i in 0..dim {
79 let mut s = 0.0;
80 for k in 0..=i {
81 s += l_chol[i * dim + k] * eps[k];
82 }
83 noise[i] = s;
84 }
85 mean.iter().zip(noise.iter()).map(|(m, n)| m + n).collect()
86}
87
/// Unnormalised Gaussian log-likelihood `-½ Σ (z_obs[k] - z_pred[k])² / r[k][k]`
/// using only the diagonal of the row-major `dim_z × dim_z` matrix `r`.
///
/// Components whose variance `r[k][k]` is not strictly positive contribute
/// nothing. The `-½ log(2π r)` normalisation term is omitted. It is the same for
/// every particle, so it does not affect normalised weights.
fn log_likelihood_diag(z_obs: &[f64], z_pred: &[f64], r: &[f64], dim_z: usize) -> f64 {
    (0..dim_z)
        .map(|k| (k, r[k * dim_z + k]))
        .filter(|&(_, variance)| variance > 0.0)
        .fold(0.0, |acc, (k, variance)| {
            let residual = z_obs[k] - z_pred[k];
            acc - 0.5 * residual * residual / variance
        })
}
102
103fn systematic_resample(weights: &[f64], rng: &mut LcgRng) -> Vec<usize> {
108 let n = weights.len();
109 let mut cumsum = vec![0.0; n];
111 cumsum[0] = weights[0];
112 for i in 1..n {
113 cumsum[i] = cumsum[i - 1] + weights[i];
114 }
115 let u0 = rng.next_f64() / n as f64;
117 let mut indices = vec![0usize; n];
118 let mut k = 0usize;
119 for j in 0..n {
120 let u_j = u0 + j as f64 / n as f64;
121 while k < n - 1 && cumsum[k] < u_j {
122 k += 1;
123 }
124 indices[j] = k;
125 }
126 indices
127}
128
129impl<'a> ParticleFilter<'a> {
134 fn validate(&self, z: &[f64]) -> SeqResult<()> {
136 if self.cfg.n_particles < 2 {
137 return Err(SeqError::InvalidConfiguration(format!(
138 "n_particles must be >= 2, got {}",
139 self.cfg.n_particles
140 )));
141 }
142 if z.is_empty() {
143 return Err(SeqError::EmptyInput);
144 }
145 if z.len() % self.cfg.dim_z != 0 {
146 return Err(SeqError::DimensionMismatch {
147 a: z.len(),
148 b: self.cfg.dim_z,
149 });
150 }
151 let nx = self.cfg.dim_x;
152 let nz = self.cfg.dim_z;
153 if self.q_chol.len() != nx * nx {
154 return Err(SeqError::ShapeMismatch {
155 expected: nx * nx,
156 got: self.q_chol.len(),
157 });
158 }
159 if self.r.len() != nz * nz {
160 return Err(SeqError::ShapeMismatch {
161 expected: nz * nz,
162 got: self.r.len(),
163 });
164 }
165 if self.x0.len() != nx {
166 return Err(SeqError::ShapeMismatch {
167 expected: nx,
168 got: self.x0.len(),
169 });
170 }
171 if self.p0_chol.len() != nx * nx {
172 return Err(SeqError::ShapeMismatch {
173 expected: nx * nx,
174 got: self.p0_chol.len(),
175 });
176 }
177 Ok(())
178 }
179
180 pub fn run(&self, z: &[f64], rng: &mut LcgRng) -> SeqResult<ParticleResult> {
184 self.validate(z)?;
185
186 let n = self.cfg.n_particles;
187 let nx = self.cfg.dim_x;
188 let nz = self.cfg.dim_z;
189 let t_max = z.len() / nz;
190 let resample_threshold = self.cfg.resample_threshold * n as f64;
191
192 let mut particles: Vec<Vec<f64>> = (0..n)
196 .map(|_| sample_gaussian(&self.x0, &self.p0_chol, nx, rng))
197 .collect();
198
199 let log_n = (n as f64).ln();
201 let mut log_weights: Vec<f64> = vec![-log_n; n];
202
203 let mut means = Vec::with_capacity(t_max);
204 let mut eff_sizes = Vec::with_capacity(t_max);
205 let mut n_resamples = 0usize;
206
207 for t in 0..t_max {
208 let z_t = &z[t * nz..(t + 1) * nz];
209
210 for i in 0..n {
214 let mu_i = (self.f)(&particles[i]);
215 particles[i] = sample_gaussian(&mu_i, &self.q_chol, nx, rng);
216 }
217
218 for i in 0..n {
222 let z_pred = (self.h)(&particles[i]);
223 let ll = log_likelihood_diag(z_t, &z_pred, &self.r, nz);
224 log_weights[i] += ll;
225 }
226
227 let log_max = log_weights
229 .iter()
230 .cloned()
231 .fold(f64::NEG_INFINITY, f64::max);
232 let shifted: Vec<f64> = log_weights.iter().map(|&lw| (lw - log_max).exp()).collect();
233 let sum_shifted: f64 = shifted.iter().sum();
234 let weights: Vec<f64> = shifted.iter().map(|&s| s / sum_shifted).collect();
235
236 for i in 0..n {
238 log_weights[i] = weights[i].max(f64::MIN_POSITIVE).ln();
239 }
240
241 let mut mean = vec![0.0; nx];
245 for i in 0..n {
246 for d in 0..nx {
247 mean[d] += weights[i] * particles[i][d];
248 }
249 }
250
251 let n_eff = 1.0 / weights.iter().map(|&w| w * w).sum::<f64>();
252 means.push(mean);
253 eff_sizes.push(n_eff);
254
255 if n_eff < resample_threshold {
259 let indices = systematic_resample(&weights, rng);
260 let old_particles = particles.clone();
262 for i in 0..n {
263 particles[i] = old_particles[indices[i]].clone();
264 }
265 for i in 0..n {
267 log_weights[i] = -log_n;
268 }
269 n_resamples += 1;
270 }
271 }
272
273 Ok(ParticleResult {
274 means,
275 eff_sizes,
276 n_resamples,
277 })
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Shorthand for a `ParticleConfig` with every field given explicitly.
    fn config(
        n_particles: usize,
        dim_x: usize,
        dim_z: usize,
        resample_threshold: f64,
    ) -> ParticleConfig {
        ParticleConfig {
            n_particles,
            dim_x,
            dim_z,
            resample_threshold,
        }
    }

    /// Scalar random-walk model: identity transition and identity observation.
    fn scalar_walk(
        n: usize,
        threshold: f64,
        q_chol: f64,
        r: f64,
        x0: f64,
        p0_chol: f64,
    ) -> ParticleFilter<'static> {
        ParticleFilter {
            cfg: config(n, 1, 1, threshold),
            q_chol: vec![q_chol],
            r: vec![r],
            f: Box::new(|x: &[f64]| vec![x[0]]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            x0: vec![x0],
            p0_chol: vec![p0_chol],
        }
    }

    /// 1-D random walk parameterised by variances; starts at N(0, 1), threshold 0.5.
    fn make_1d_pf(n: usize, q_var: f64, r_var: f64) -> ParticleFilter<'static> {
        scalar_walk(n, 0.5, q_var.sqrt(), r_var, 0.0, 1.0)
    }

    #[test]
    fn new_config_default_ok() {
        let cfg = ParticleConfig::default();
        assert_eq!(cfg.n_particles, 100);
        assert_eq!(cfg.dim_x, 1);
        assert_eq!(cfg.dim_z, 1);
        assert!((cfg.resample_threshold - 0.5).abs() < 1e-15);
    }

    #[test]
    fn pf_output_length() {
        let mut rng = LcgRng::new(42);
        let observations: Vec<f64> = vec![1.0; 10];
        let res = make_1d_pf(50, 0.1, 0.5)
            .run(&observations, &mut rng)
            .expect("ok");
        assert_eq!(res.means.len(), 10);
    }

    #[test]
    fn pf_means_dim_correct() {
        let mut rng = LcgRng::new(7);
        let observations = vec![1.0; 5];
        let res = make_1d_pf(50, 0.1, 0.5)
            .run(&observations, &mut rng)
            .expect("ok");
        for (t, m) in res.means.iter().enumerate() {
            assert_eq!(m.len(), 1, "dim mismatch at t={t}");
        }
    }

    #[test]
    fn pf_constant_obs_converges() {
        let pf = scalar_walk(500, 0.5, 0.01f64.sqrt(), 0.01, 1.0, 0.1);
        let observations = vec![1.0_f64; 15];
        let mut rng = LcgRng::new(12345);
        let res = pf.run(&observations, &mut rng).expect("ok");
        let last = res.means[14][0];
        assert!((last - 1.0).abs() < 0.5, "did not converge: {last}");
    }

    #[test]
    fn pf_zero_innovation() {
        let target = 3.0_f64;
        // Constant observation model: every particle predicts `target` exactly.
        let mut pf = scalar_walk(100, 0.5, 0.01, 0.1, 0.0, 1.0);
        pf.h = Box::new(move |_x: &[f64]| vec![target]);
        let observations = vec![target; 10];
        let mut rng = LcgRng::new(99);
        let res = pf.run(&observations, &mut rng).expect("ok");
        assert_eq!(res.means.len(), 10);
    }

    #[test]
    fn pf_eff_size_bounded() {
        let mut rng = LcgRng::new(77);
        let observations = vec![1.0; 8];
        let res = make_1d_pf(100, 0.1, 0.5)
            .run(&observations, &mut rng)
            .expect("ok");
        for (t, &neff) in res.eff_sizes.iter().enumerate() {
            assert!(neff > 0.0, "N_eff <= 0 at t={t}: {neff}");
            assert!(neff <= 100.0 + 1e-6, "N_eff > N at t={t}: {neff}");
        }
    }

    #[test]
    fn pf_weights_normalize() {
        // Normalised weights imply 0 < N_eff <= N.
        let mut rng = LcgRng::new(11);
        let observations = vec![1.0; 6];
        let res = make_1d_pf(50, 0.05, 0.1)
            .run(&observations, &mut rng)
            .expect("ok");
        for &neff in &res.eff_sizes {
            assert!(neff > 0.0 && neff <= 50.0 + 1e-6);
        }
    }

    #[test]
    fn pf_resamples_occur() {
        // Very tight observation noise and a high threshold force degeneracy.
        let pf = scalar_walk(50, 0.9, 0.01, 1e-6, 0.0, 1.0);
        let observations = vec![1.0_f64; 20];
        let mut rng = LcgRng::new(42);
        let res = pf.run(&observations, &mut rng).expect("ok");
        assert!(
            res.n_resamples > 0,
            "expected at least one resampling event"
        );
    }

    #[test]
    fn pf_deterministic_same_seed() {
        let observations: Vec<f64> = (0..10).map(|i| i as f64 * 0.2).collect();
        let run_with_seed = || {
            let mut rng = LcgRng::new(999);
            make_1d_pf(80, 0.1, 0.3)
                .run(&observations, &mut rng)
                .expect("ok")
        };
        let first = run_with_seed();
        let second = run_with_seed();
        for t in 0..10 {
            assert!(
                (first.means[t][0] - second.means[t][0]).abs() < 1e-15,
                "mismatch at t={t}"
            );
        }
    }

    #[test]
    fn pf_1d_random_walk() {
        let pf = scalar_walk(300, 0.5, 0.1, 0.25, 0.0, 0.5);
        let observations: Vec<f64> = vec![0.1, 0.2, 0.15, 0.3, 0.25, 0.4, 0.5, 0.45, 0.6, 0.55];
        let mut rng = LcgRng::new(314);
        let res = pf.run(&observations, &mut rng).expect("ok");
        assert_eq!(res.means.len(), 10);
        let last = res.means[9][0];
        assert!(last.abs() < 2.0, "estimate out of range: {last}");
    }

    #[test]
    fn pf_2d_state_2d_obs() {
        let pf = ParticleFilter {
            cfg: config(50, 2, 2, 0.5),
            q_chol: vec![0.1, 0.0, 0.0, 0.1],
            r: vec![0.5, 0.0, 0.0, 0.5],
            f: Box::new(|x: &[f64]| vec![x[0], x[1]]),
            h: Box::new(|x: &[f64]| vec![x[0], x[1]]),
            x0: vec![0.0, 0.0],
            p0_chol: vec![1.0, 0.0, 0.0, 1.0],
        };
        let observations: Vec<f64> = (0..5)
            .flat_map(|i| vec![i as f64 * 0.2, i as f64 * 0.1])
            .collect();
        let mut rng = LcgRng::new(55);
        let res = pf.run(&observations, &mut rng).expect("2d test failed");
        assert_eq!(res.means.len(), 5);
        for (t, m) in res.means.iter().enumerate() {
            assert_eq!(m.len(), 2, "state dim at t={t}");
        }
    }

    #[test]
    fn err_empty_obs() {
        let mut rng = LcgRng::new(1);
        let result = make_1d_pf(50, 0.1, 0.5).run(&[], &mut rng);
        assert!(matches!(result, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn err_n_particles_lt_2() {
        let pf = scalar_walk(1, 0.5, 0.1, 0.5, 0.0, 1.0);
        let mut rng = LcgRng::new(1);
        let result = pf.run(&[1.0], &mut rng);
        assert!(matches!(result, Err(SeqError::InvalidConfiguration(_))));
    }

    #[test]
    fn err_z_len_not_multiple_of_dim_z() {
        let pf = ParticleFilter {
            cfg: config(10, 1, 2, 0.5),
            q_chol: vec![0.1],
            r: vec![0.5, 0.0, 0.0, 0.5],
            f: Box::new(|x: &[f64]| vec![x[0]]),
            h: Box::new(|x: &[f64]| vec![x[0], x[0]]),
            x0: vec![0.0],
            p0_chol: vec![1.0],
        };
        let mut rng = LcgRng::new(1);
        let result = pf.run(&[1.0, 2.0, 3.0], &mut rng);
        assert!(matches!(result, Err(SeqError::DimensionMismatch { .. })));
    }

    #[test]
    fn err_q_chol_wrong_shape() {
        // dim_x = 2 needs a 4-element q_chol; only 1 is supplied.
        let pf = ParticleFilter {
            cfg: config(10, 2, 1, 0.5),
            q_chol: vec![0.1],
            r: vec![0.5],
            f: Box::new(|x: &[f64]| x.to_vec()),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            x0: vec![0.0, 0.0],
            p0_chol: vec![1.0, 0.0, 0.0, 1.0],
        };
        let mut rng = LcgRng::new(1);
        let result = pf.run(&[1.0], &mut rng);
        assert!(matches!(result, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn err_x0_wrong_len() {
        // dim_x = 2 but x0 has a single entry.
        let pf = ParticleFilter {
            cfg: config(10, 2, 1, 0.5),
            q_chol: vec![0.1, 0.0, 0.0, 0.1],
            r: vec![0.5],
            f: Box::new(|x: &[f64]| x.to_vec()),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            x0: vec![0.0],
            p0_chol: vec![1.0, 0.0, 0.0, 1.0],
        };
        let mut rng = LcgRng::new(1);
        let result = pf.run(&[1.0], &mut rng);
        assert!(matches!(result, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn pf_systematic_resampling_valid() {
        let weights: Vec<f64> = vec![1.0 / 10.0; 10];
        let mut rng = LcgRng::new(42);
        let indices = systematic_resample(&weights, &mut rng);
        for &idx in &indices {
            assert!(idx < 10, "index out of range: {idx}");
        }
        assert_eq!(indices.len(), 10);
    }

    #[test]
    fn pf_n_particles_100_no_panic() {
        let pf = scalar_walk(100, 0.5, 0.1, 0.3, 0.0, 1.0);
        let observations: Vec<f64> = (0..10).map(|i| i as f64 * 0.1 + 0.5).collect();
        let mut rng = LcgRng::new(2025);
        let res = pf
            .run(&observations, &mut rng)
            .expect("100 particles clean run failed");
        assert_eq!(res.means.len(), 10);
        assert_eq!(res.eff_sizes.len(), 10);
        for &neff in &res.eff_sizes {
            assert!(neff > 0.0 && neff <= 100.0 + 1e-6);
        }
    }
}