1use crate::error::{IoError, IoResult};
12use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::random::{thread_rng, Distribution, Normal};
14
/// Linear Kalman filter over `f32` state vectors.
///
/// Shapes are validated by [`KalmanFilter::new`], where `n` is the length of
/// the state vector and `m` is the number of rows of the observation matrix.
#[derive(Debug, Clone)]
pub struct KalmanFilter {
    /// Current state estimate (length `n` when built through `new`).
    state: Array1<f32>,
    /// Covariance of the state estimate (`n×n` when built through `new`).
    covariance: Array2<f32>,
    /// State transition matrix (`n×n`), applied in `predict`.
    transition: Array2<f32>,
    /// Observation matrix (`m×n`), maps the state into measurement space.
    observation: Array2<f32>,
    /// Process noise covariance (`n×n`), added to the covariance in `predict`.
    process_noise: Array2<f32>,
    /// Measurement noise covariance (`m×m`), added to the innovation
    /// covariance in `update`.
    measurement_noise: Array2<f32>,
}
34
35impl KalmanFilter {
36 pub fn new(
46 initial_state: Array1<f32>,
47 initial_covariance: Array2<f32>,
48 transition: Array2<f32>,
49 observation: Array2<f32>,
50 process_noise: Array2<f32>,
51 measurement_noise: Array2<f32>,
52 ) -> IoResult<Self> {
53 let n = initial_state.len();
54 let m = observation.shape()[0];
55
56 if initial_covariance.shape() != [n, n] {
57 return Err(IoError::SignalError(
58 "Initial covariance must be n×n".into(),
59 ));
60 }
61 if transition.shape() != [n, n] {
62 return Err(IoError::SignalError("Transition matrix must be n×n".into()));
63 }
64 if observation.shape() != [m, n] {
65 return Err(IoError::SignalError(
66 "Observation matrix must be m×n".into(),
67 ));
68 }
69 if process_noise.shape() != [n, n] {
70 return Err(IoError::SignalError("Process noise must be n×n".into()));
71 }
72 if measurement_noise.shape() != [m, m] {
73 return Err(IoError::SignalError("Measurement noise must be m×m".into()));
74 }
75
76 Ok(Self {
77 state: initial_state,
78 covariance: initial_covariance,
79 transition,
80 observation,
81 process_noise,
82 measurement_noise,
83 })
84 }
85
86 pub fn predict(&mut self) {
88 let x_pred = self.transition.dot(&self.state);
90 let p_temp = self.transition.dot(&self.covariance);
92 let p_pred = p_temp.dot(&self.transition.t()) + &self.process_noise;
93
94 self.state = x_pred;
95 self.covariance = p_pred;
96 }
97
98 pub fn update(&mut self, measurement: &Array1<f32>) -> IoResult<()> {
100 let predicted_measurement = self.observation.dot(&self.state);
102 let innovation = measurement - &predicted_measurement;
103
104 let h_p = self.observation.dot(&self.covariance);
106 let s = h_p.dot(&self.observation.t()) + &self.measurement_noise;
107
108 let s_inv = Self::invert_matrix(&s)?;
110 let p_ht = self.covariance.dot(&self.observation.t());
111 let k = p_ht.dot(&s_inv);
112
113 let state_update = k.dot(&innovation);
115 self.state = &self.state + &state_update;
116
117 let n = self.state.len();
119 let identity = Array2::eye(n);
120 let kh = k.dot(&self.observation);
121 let p_update = (&identity - &kh).dot(&self.covariance);
122 self.covariance = p_update;
123
124 Ok(())
125 }
126
127 pub fn state(&self) -> &Array1<f32> {
129 &self.state
130 }
131
132 pub fn covariance(&self) -> &Array2<f32> {
134 &self.covariance
135 }
136
137 pub fn reset(&mut self, state: Array1<f32>, covariance: Array2<f32>) {
139 self.state = state;
140 self.covariance = covariance;
141 }
142
143 fn invert_matrix(mat: &Array2<f32>) -> IoResult<Array2<f32>> {
145 let shape = mat.shape();
146 if shape[0] != shape[1] {
147 return Err(IoError::SignalError("Matrix must be square".into()));
148 }
149
150 let n = shape[0];
151 if n == 1 {
152 let det = mat[[0, 0]];
153 if det.abs() < 1e-10 {
154 return Err(IoError::SignalError("Singular matrix".into()));
155 }
156 let mut inv = Array2::zeros((1, 1));
157 inv[[0, 0]] = 1.0 / det;
158 return Ok(inv);
159 }
160
161 if n == 2 {
162 let a = mat[[0, 0]];
163 let b = mat[[0, 1]];
164 let c = mat[[1, 0]];
165 let d = mat[[1, 1]];
166
167 let det = a * d - b * c;
168 if det.abs() < 1e-10 {
169 return Err(IoError::SignalError("Singular matrix".into()));
170 }
171
172 let mut inv = Array2::zeros((2, 2));
173 inv[[0, 0]] = d / det;
174 inv[[0, 1]] = -b / det;
175 inv[[1, 0]] = -c / det;
176 inv[[1, 1]] = a / det;
177 return Ok(inv);
178 }
179
180 let mut augmented = Array2::zeros((n, 2 * n));
182 for i in 0..n {
183 for j in 0..n {
184 augmented[[i, j]] = mat[[i, j]];
185 augmented[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
186 }
187 }
188
189 for i in 0..n {
191 let mut max_row = i;
193 for k in i + 1..n {
194 if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
195 max_row = k;
196 }
197 }
198
199 for j in 0..2 * n {
201 let tmp = augmented[[i, j]];
202 augmented[[i, j]] = augmented[[max_row, j]];
203 augmented[[max_row, j]] = tmp;
204 }
205
206 let pivot = augmented[[i, i]];
207 if pivot.abs() < 1e-10 {
208 return Err(IoError::SignalError("Singular matrix".into()));
209 }
210
211 for j in 0..2 * n {
213 augmented[[i, j]] /= pivot;
214 }
215
216 for k in 0..n {
218 if k != i {
219 let factor = augmented[[k, i]];
220 for j in 0..2 * n {
221 augmented[[k, j]] -= factor * augmented[[i, j]];
222 }
223 }
224 }
225 }
226
227 let mut inv = Array2::zeros((n, n));
229 for i in 0..n {
230 for j in 0..n {
231 inv[[i, j]] = augmented[[i, j + n]];
232 }
233 }
234
235 Ok(inv)
236 }
237}
238
/// Least-mean-squares (LMS) adaptive FIR filter.
#[derive(Debug, Clone)]
pub struct LmsFilter {
    /// Filter coefficients, one per tap; start at zero.
    weights: Array1<f32>,
    /// Step size used to scale each weight update.
    mu: f32,
    /// Circular buffer holding the most recent input samples (one per tap).
    buffer: Vec<f32>,
    /// Index in `buffer` where the next input sample is written.
    pos: usize,
}
254
255impl LmsFilter {
256 pub fn new(num_taps: usize, mu: f32) -> IoResult<Self> {
262 if num_taps == 0 {
263 return Err(IoError::SignalError("Number of taps must be > 0".into()));
264 }
265 if mu <= 0.0 {
266 return Err(IoError::SignalError("Step size must be > 0".into()));
267 }
268
269 Ok(Self {
270 weights: Array1::zeros(num_taps),
271 mu,
272 buffer: vec![0.0; num_taps],
273 pos: 0,
274 })
275 }
276
277 pub fn adapt(&mut self, input: f32, desired: f32) -> (f32, f32) {
287 self.buffer[self.pos] = input;
289
290 let mut output = 0.0;
292 let mut buf_idx = self.pos;
293
294 for &weight in self.weights.iter() {
295 output += weight * self.buffer[buf_idx];
296 if buf_idx == 0 {
297 buf_idx = self.buffer.len() - 1;
298 } else {
299 buf_idx -= 1;
300 }
301 }
302
303 let error = desired - output;
305
306 buf_idx = self.pos;
308 for weight in self.weights.iter_mut() {
309 *weight += self.mu * error * self.buffer[buf_idx];
310 if buf_idx == 0 {
311 buf_idx = self.buffer.len() - 1;
312 } else {
313 buf_idx -= 1;
314 }
315 }
316
317 self.pos = (self.pos + 1) % self.buffer.len();
318
319 (output, error)
320 }
321
322 pub fn weights(&self) -> &Array1<f32> {
324 &self.weights
325 }
326
327 pub fn reset(&mut self) {
329 self.weights.fill(0.0);
330 self.buffer.fill(0.0);
331 self.pos = 0;
332 }
333}
334
/// Normalized least-mean-squares (NLMS) adaptive FIR filter.
#[derive(Debug, Clone)]
pub struct NlmsFilter {
    /// Filter coefficients, one per tap; start at zero.
    weights: Array1<f32>,
    /// Step size before normalisation by the buffered input power.
    mu: f32,
    /// Constant added to the input power before dividing (defaults to `1e-6`
    /// when not supplied to `new`).
    epsilon: f32,
    /// Circular buffer holding the most recent input samples (one per tap).
    buffer: Vec<f32>,
    /// Index in `buffer` where the next input sample is written.
    pos: usize,
}
352
353impl NlmsFilter {
354 pub fn new(num_taps: usize, mu: f32, epsilon: Option<f32>) -> IoResult<Self> {
361 if num_taps == 0 {
362 return Err(IoError::SignalError("Number of taps must be > 0".into()));
363 }
364 if mu <= 0.0 {
365 return Err(IoError::SignalError("Step size must be > 0".into()));
366 }
367
368 Ok(Self {
369 weights: Array1::zeros(num_taps),
370 mu,
371 epsilon: epsilon.unwrap_or(1e-6),
372 buffer: vec![0.0; num_taps],
373 pos: 0,
374 })
375 }
376
377 pub fn adapt(&mut self, input: f32, desired: f32) -> (f32, f32) {
379 self.buffer[self.pos] = input;
381
382 let mut output = 0.0;
384 let mut buf_idx = self.pos;
385
386 for &weight in self.weights.iter() {
387 output += weight * self.buffer[buf_idx];
388 if buf_idx == 0 {
389 buf_idx = self.buffer.len() - 1;
390 } else {
391 buf_idx -= 1;
392 }
393 }
394
395 let error = desired - output;
397
398 let power: f32 = self.buffer.iter().map(|x| x * x).sum();
400 let normalized_mu = self.mu / (power + self.epsilon);
401
402 buf_idx = self.pos;
404 for weight in self.weights.iter_mut() {
405 *weight += normalized_mu * error * self.buffer[buf_idx];
406 if buf_idx == 0 {
407 buf_idx = self.buffer.len() - 1;
408 } else {
409 buf_idx -= 1;
410 }
411 }
412
413 self.pos = (self.pos + 1) % self.buffer.len();
414
415 (output, error)
416 }
417
418 pub fn weights(&self) -> &Array1<f32> {
420 &self.weights
421 }
422
423 pub fn reset(&mut self) {
425 self.weights.fill(0.0);
426 self.buffer.fill(0.0);
427 self.pos = 0;
428 }
429}
430
/// Recursive least-squares (RLS) adaptive FIR filter.
#[derive(Debug, Clone)]
pub struct RlsFilter {
    /// Filter coefficients, one per tap; start at zero.
    weights: Array1<f32>,
    /// Square matrix (taps × taps) updated on every `adapt` call; initialised
    /// to the identity scaled by `1 / delta`.
    p_matrix: Array2<f32>,
    /// Forgetting factor, validated by `new` to lie in `(0, 1]`.
    lambda: f32,
    /// Circular buffer holding the most recent input samples (one per tap).
    buffer: Vec<f32>,
    /// Index in `buffer` where the next input sample is written.
    pos: usize,
}
448
449impl RlsFilter {
450 pub fn new(num_taps: usize, lambda: f32, delta: f32) -> IoResult<Self> {
457 if num_taps == 0 {
458 return Err(IoError::SignalError("Number of taps must be > 0".into()));
459 }
460 if lambda <= 0.0 || lambda > 1.0 {
461 return Err(IoError::SignalError(
462 "Forgetting factor must be in (0, 1]".into(),
463 ));
464 }
465 if delta <= 0.0 {
466 return Err(IoError::SignalError("Delta must be > 0".into()));
467 }
468
469 let p_matrix = Array2::eye(num_taps) * (1.0 / delta);
471
472 Ok(Self {
473 weights: Array1::zeros(num_taps),
474 p_matrix,
475 lambda,
476 buffer: vec![0.0; num_taps],
477 pos: 0,
478 })
479 }
480
481 pub fn adapt(&mut self, input: f32, desired: f32) -> (f32, f32) {
483 self.buffer[self.pos] = input;
485
486 let mut x = Array1::zeros(self.buffer.len());
488 let mut buf_idx = self.pos;
489 for i in 0..self.buffer.len() {
490 x[i] = self.buffer[buf_idx];
491 if buf_idx == 0 {
492 buf_idx = self.buffer.len() - 1;
493 } else {
494 buf_idx -= 1;
495 }
496 }
497
498 let output = self.weights.dot(&x);
500
501 let error = desired - output;
503
504 let p_x = self.p_matrix.dot(&x);
506 let denominator = self.lambda + x.dot(&p_x);
507 let k = &p_x / denominator;
508
509 self.weights = &self.weights + &(&k * error);
511
512 let k_reshape = k
514 .clone()
515 .to_shape((k.len(), 1))
516 .expect("Adaptive filter operation must succeed")
517 .to_owned();
518 let x_reshape = x
519 .clone()
520 .to_shape((1, x.len()))
521 .expect("Adaptive filter operation must succeed")
522 .to_owned();
523 let k_xt_p = k_reshape.dot(&x_reshape).dot(&self.p_matrix);
524 self.p_matrix = (&self.p_matrix - &k_xt_p) / self.lambda;
525
526 self.pos = (self.pos + 1) % self.buffer.len();
527
528 (output, error)
529 }
530
531 pub fn weights(&self) -> &Array1<f32> {
533 &self.weights
534 }
535
536 pub fn reset(&mut self, delta: f32) {
538 self.weights.fill(0.0);
539 let n = self.weights.len();
540 self.p_matrix = Array2::eye(n) * (1.0 / delta);
541 self.buffer.fill(0.0);
542 self.pos = 0;
543 }
544}
545
/// A single weighted state hypothesis used by `ParticleFilter`.
#[derive(Debug, Clone)]
struct Particle {
    /// State vector carried by this particle.
    state: Array1<f32>,
    /// Importance weight of this particle.
    weight: f32,
}
554
/// Particle filter with additive Gaussian process noise and a Gaussian
/// measurement likelihood.
#[derive(Debug, Clone)]
pub struct ParticleFilter {
    /// Current particle population.
    particles: Vec<Particle>,
    /// Number of particles requested at construction.
    num_particles: usize,
    /// Standard deviation of the noise added to every state component in
    /// `predict`.
    process_noise_std: f32,
    /// Standard deviation used for the measurement likelihood in `update`.
    measurement_noise_std: f32,
}
570
571impl ParticleFilter {
572 pub fn new(
581 num_particles: usize,
582 initial_state: Array1<f32>,
583 initial_std: f32,
584 process_noise_std: f32,
585 measurement_noise_std: f32,
586 ) -> IoResult<Self> {
587 if num_particles == 0 {
588 return Err(IoError::SignalError(
589 "Number of particles must be > 0".into(),
590 ));
591 }
592
593 let mut rng = thread_rng();
594 let normal = Normal::new(0.0, initial_std as f64).map_err(|e| {
595 IoError::SignalError(format!("Failed to create normal distribution: {}", e))
596 })?;
597
598 let weight = 1.0 / num_particles as f32;
600 let particles: Vec<Particle> = (0..num_particles)
601 .map(|_| {
602 let mut state = initial_state.clone();
603 for s in state.iter_mut() {
604 *s += normal.sample(&mut rng) as f32;
605 }
606 Particle { state, weight }
607 })
608 .collect();
609
610 Ok(Self {
611 particles,
612 num_particles,
613 process_noise_std,
614 measurement_noise_std,
615 })
616 }
617
618 pub fn predict<F>(&mut self, transition_fn: F)
623 where
624 F: Fn(&Array1<f32>) -> Array1<f32>,
625 {
626 let mut rng = thread_rng();
627 let normal = Normal::new(0.0, self.process_noise_std as f64)
628 .expect("Adaptive filter operation must succeed");
629
630 for particle in &mut self.particles {
631 particle.state = transition_fn(&particle.state);
633
634 for s in particle.state.iter_mut() {
636 *s += normal.sample(&mut rng) as f32;
637 }
638 }
639 }
640
641 pub fn update<F>(&mut self, measurement: &Array1<f32>, measurement_fn: F)
647 where
648 F: Fn(&Array1<f32>) -> Array1<f32>,
649 {
650 for particle in &mut self.particles {
652 let predicted_measurement = measurement_fn(&particle.state);
653 let diff = measurement - &predicted_measurement;
654 let distance_sq: f32 = diff.iter().map(|&x| x * x).sum();
655
656 let variance = self.measurement_noise_std * self.measurement_noise_std;
658 particle.weight *= (-distance_sq / (2.0 * variance)).exp();
659 }
660
661 let sum_weights: f32 = self.particles.iter().map(|p| p.weight).sum();
663 if sum_weights > 1e-10 {
664 for particle in &mut self.particles {
665 particle.weight /= sum_weights;
666 }
667 } else {
668 let uniform_weight = 1.0 / self.num_particles as f32;
670 for particle in &mut self.particles {
671 particle.weight = uniform_weight;
672 }
673 }
674
675 self.resample_if_needed();
677 }
678
679 fn resample_if_needed(&mut self) {
681 let sum_sq_weights: f32 = self.particles.iter().map(|p| p.weight * p.weight).sum();
683 let n_eff = 1.0 / sum_sq_weights;
684
685 if n_eff < (self.num_particles as f32 / 2.0) {
687 self.systematic_resample();
688 }
689 }
690
691 fn systematic_resample(&mut self) {
693 let mut rng = thread_rng();
694 let n = self.num_particles;
695 let mut new_particles = Vec::with_capacity(n);
696
697 let mut cumsum = Vec::with_capacity(n);
699 let mut sum = 0.0;
700 for particle in &self.particles {
701 sum += particle.weight;
702 cumsum.push(sum);
703 }
704
705 let step = 1.0 / n as f32;
707 let start: f32 = rng.gen_range(0.0..step);
708
709 let mut i = 0;
710 for j in 0..n {
711 let u = start + j as f32 * step;
712 while i < n - 1 && cumsum[i] < u {
713 i += 1;
714 }
715 let mut new_particle = self.particles[i].clone();
716 new_particle.weight = 1.0 / n as f32;
717 new_particles.push(new_particle);
718 }
719
720 self.particles = new_particles;
721 }
722
723 pub fn mean_state(&self) -> Array1<f32> {
725 let state_dim = self.particles[0].state.len();
726 let mut mean = Array1::zeros(state_dim);
727
728 for particle in &self.particles {
729 mean += &(&particle.state * particle.weight);
730 }
731
732 mean
733 }
734
735 pub fn covariance(&self) -> Array2<f32> {
737 let state_dim = self.particles[0].state.len();
738 let mean = self.mean_state();
739 let mut cov = Array2::zeros((state_dim, state_dim));
740
741 for particle in &self.particles {
742 let diff = &particle.state - &mean;
743 let diff_col = diff
744 .clone()
745 .to_shape((state_dim, 1))
746 .expect("Adaptive filter operation must succeed")
747 .to_owned();
748 let diff_row = diff
749 .to_shape((1, state_dim))
750 .expect("Adaptive filter operation must succeed")
751 .to_owned();
752 let outer = diff_col.dot(&diff_row);
753 cov += &(&outer * particle.weight);
754 }
755
756 cov
757 }
758}
759
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// Message shared by every `expect` in this module.
    const MUST_SUCCEED: &str = "Adaptive filter operation must succeed";

    /// Number of random samples fed to each adaptive filter.
    const NUM_SAMPLES: usize = 100;

    /// Builds a 1×1 matrix holding `value`.
    fn scalar_matrix(value: f32) -> Array2<f32> {
        Array2::from_shape_fn((1, 1), |_| value)
    }

    /// Draws one input sample from `[-1, 1)`.
    fn random_input() -> f32 {
        scirs2_core::random::thread_rng().gen_range(-1.0..1.0)
    }

    /// After one predict/update cycle the estimate must lie strictly between
    /// the prior (0) and the measurement (1).
    #[test]
    fn test_kalman_filter_1d() {
        let mut kf = KalmanFilter::new(
            arr1(&[0.0]),
            scalar_matrix(1.0),
            scalar_matrix(1.0),
            scalar_matrix(1.0),
            scalar_matrix(0.01),
            scalar_matrix(0.1),
        )
        .expect(MUST_SUCCEED);

        let measurement = arr1(&[1.0]);

        kf.predict();
        kf.update(&measurement).expect(MUST_SUCCEED);

        let estimate = kf.state()[0];
        assert!(estimate > 0.0 && estimate < 1.0);
    }

    /// Smoke test: the LMS filter runs on random data and keeps its tap count.
    #[test]
    fn test_lms_filter() {
        let mut lms = LmsFilter::new(4, 0.01).expect(MUST_SUCCEED);

        (0..NUM_SAMPLES).for_each(|_| {
            let input = random_input();
            let desired = input * 0.5;
            lms.adapt(input, desired);
        });

        assert_eq!(lms.weights().len(), 4);
    }

    /// Smoke test: the NLMS filter runs on random data and keeps its tap count.
    #[test]
    fn test_nlms_filter() {
        let mut nlms = NlmsFilter::new(4, 0.5, None).expect(MUST_SUCCEED);

        (0..NUM_SAMPLES).for_each(|_| {
            let input = random_input();
            let desired = input * 0.5;
            nlms.adapt(input, desired);
        });

        assert_eq!(nlms.weights().len(), 4);
    }

    /// Smoke test: the RLS filter runs on random data and keeps its tap count.
    #[test]
    fn test_rls_filter() {
        let mut rls = RlsFilter::new(4, 0.99, 0.1).expect(MUST_SUCCEED);

        (0..NUM_SAMPLES).for_each(|_| {
            let input = random_input();
            let desired = input * 0.5;
            rls.adapt(input, desired);
        });

        assert_eq!(rls.weights().len(), 4);
    }

    /// Construction yields the requested particle count and a mean of the
    /// right dimension.
    #[test]
    fn test_particle_filter() {
        let pf = ParticleFilter::new(100, arr1(&[0.0]), 1.0, 0.1, 0.5).expect(MUST_SUCCEED);

        assert_eq!(pf.particles.len(), 100);

        let mean = pf.mean_state();
        assert_eq!(mean.len(), 1);
    }

    /// `mat * inv(mat)` must equal the identity to within `1e-5`.
    #[test]
    fn test_matrix_inversion_2x2() {
        let mat =
            Array2::from_shape_vec((2, 2), vec![4.0, 7.0, 2.0, 6.0]).expect(MUST_SUCCEED);
        let inv = KalmanFilter::invert_matrix(&mat).expect(MUST_SUCCEED);

        let product = mat.dot(&inv);
        for row in 0..2 {
            for col in 0..2 {
                let expected: f32 = if row == col { 1.0 } else { 0.0 };
                assert!((product[[row, col]] - expected).abs() < 1e-5);
            }
        }
    }
}