1use ndarray::{s, Array, Array1, Array2, ArrayRef, Axis, Dimension, Zip};
2use num_traits::{Float, FromPrimitive};
3
4use crate::{array_like, BorderMode};
5
6use super::{con_corr::inner_correlate1d, symmetry::SymmetryStateCheck};
7
8pub fn gaussian_filter<A, D>(
23 data: &ArrayRef<A, D>,
24 sigma: A,
25 order: usize,
26 mode: BorderMode<A>,
27 truncate: usize,
28) -> Array<A, D>
29where
30 A: Float + FromPrimitive + 'static,
31 for<'a> &'a [A]: SymmetryStateCheck,
32 D: Dimension,
33{
34 let weights = weights(sigma, order, truncate);
35 let half = weights.len() / 2;
36
37 let mut data = data.to_owned();
42 let mut output = array_like(&data, data.dim(), A::zero());
43
44 for d in 0..data.ndim() {
45 let n = data.len_of(Axis(d));
50 if half > n {
51 panic!("Data size is too small for the inputs (sigma and truncate)");
52 }
53
54 inner_correlate1d(&data, &weights, Axis(d), mode, 0, &mut output);
55 if d != data.ndim() - 1 {
56 std::mem::swap(&mut output, &mut data);
57 }
58 }
59 output
60}
61
62pub fn gaussian_filter1d<A, D>(
76 data: &ArrayRef<A, D>,
77 sigma: A,
78 axis: Axis,
79 order: usize,
80 mode: BorderMode<A>,
81 truncate: usize,
82) -> Array<A, D>
83where
84 A: Float + FromPrimitive + 'static,
85 for<'a> &'a [A]: SymmetryStateCheck,
86 D: Dimension,
87{
88 let weights = weights(sigma, order, truncate);
89 let mut output = array_like(&data, data.dim(), A::zero());
90 inner_correlate1d(data, &weights, axis, mode, 0, &mut output);
91 output
92}
93
94fn weights<A>(sigma: A, order: usize, truncate: usize) -> Vec<A>
96where
97 A: Float + FromPrimitive + 'static,
98{
99 let radius = (A::from(truncate).unwrap() * sigma + A::from(0.5).unwrap()).to_isize().unwrap();
101
102 let sigma2 = sigma.powi(2);
103 let phi_x = {
104 let m05 = A::from(-0.5).unwrap();
105 let mut phi_x: Vec<_> =
106 (-radius..=radius).map(|x| (m05 / sigma2 * A::from(x.pow(2)).unwrap()).exp()).collect();
107 let sum = phi_x.iter().fold(A::zero(), |acc, &v| acc + v);
108 phi_x.iter_mut().for_each(|v| *v = *v / sum);
109 phi_x
110 };
111
112 if order == 0 {
113 phi_x
114 } else {
115 let mut q = Array1::zeros(order + 1);
116 q[0] = A::one();
117
118 let q_d = {
119 let mut q_d = Array2::<A>::zeros((order + 1, order + 1));
120 for (e, i) in q_d.slice_mut(s![..order, 1..]).diag_mut().iter_mut().zip(1..) {
121 *e = A::from(i).unwrap();
122 }
123
124 q_d.slice_mut(s![1.., ..order]).diag_mut().fill(-sigma2.recip());
125 q_d
126 };
127
128 for _ in 0..order {
129 q = q_d.dot(&q);
130 }
131
132 (-radius..=radius)
133 .zip(phi_x.into_iter())
134 .map(|(x, phi_x)| {
135 Zip::indexed(&q)
136 .fold(A::zero(), |acc, i, &q| acc + q * A::from(x.pow(i as u32)).unwrap())
137 * phi_x
138 })
139 .collect()
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts that `weights(sigma, order, truncate)` matches `expected` element-wise,
    /// using `assert_relative_eq!` with `epsilon = 1e-7`.
    fn assert_kernel(sigma: f64, order: usize, truncate: usize, expected: &[f64]) {
        let kernel = weights(sigma, order, truncate);
        assert_relative_eq!(kernel.as_slice(), expected, epsilon = 1e-7);
    }

    #[test]
    fn test_weights() {
        // Plain Gaussian (order 0) for two truncation radii.
        assert_kernel(
            1.0,
            0,
            3,
            &[0.00443304, 0.05400558, 0.24203622, 0.39905027, 0.24203622, 0.05400558, 0.00443304],
        );
        assert_kernel(
            1.0,
            0,
            4,
            &[
                0.00013383, 0.00443186, 0.05399112, 0.24197144, 0.39894346, 0.24197144, 0.05399112,
                0.00443186, 0.00013383,
            ],
        );

        // First derivative (order 1) for two truncation radii.
        assert_kernel(
            1.0,
            1,
            3,
            &[0.01329914, 0.10801116, 0.24203622, 0.0, -0.24203622, -0.10801116, -0.01329914],
        );
        assert_kernel(
            1.0,
            1,
            4,
            &[
                0.00053532,
                0.01329558,
                0.10798225,
                0.24197144,
                0.0,
                -0.24197144,
                -0.10798225,
                -0.01329558,
                -0.00053532,
            ],
        );

        // Higher orders, including a non-unit sigma.
        assert_kernel(
            1.0,
            2,
            3,
            &[0.03546438, 0.16201674, 0.0, -0.39905027, 0.0, 0.16201674, 0.03546438],
        );
        assert_kernel(0.75, 3, 3, &[0.39498175, -0.84499983, 0.0, 0.84499983, -0.39498175]);
    }
}