1use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, Dimension, Zip};
2use num_traits::{Float, FromPrimitive};
3
4use crate::{array_like, BorderMode};
5
6use super::{con_corr::inner_correlate1d, symmetry::SymmetryStateCheck};
7
8pub fn gaussian_filter<S, A, D>(
23 data: &ArrayBase<S, D>,
24 sigma: A,
25 order: usize,
26 mode: BorderMode<A>,
27 truncate: usize,
28) -> Array<A, D>
29where
30 S: Data<Elem = A>,
31 A: Float + FromPrimitive + 'static,
32 for<'a> &'a [A]: SymmetryStateCheck,
33 D: Dimension,
34{
35 let weights = weights(sigma, order, truncate);
36 let half = weights.len() / 2;
37
38 let mut data = data.to_owned();
43 let mut output = array_like(&data, data.dim(), A::zero());
44
45 for d in 0..data.ndim() {
46 let n = data.len_of(Axis(d));
51 if half > n {
52 panic!("Data size is too small for the inputs (sigma and truncate)");
53 }
54
55 inner_correlate1d(&data, &weights, Axis(d), mode, 0, &mut output);
56 if d != data.ndim() - 1 {
57 std::mem::swap(&mut output, &mut data);
58 }
59 }
60 output
61}
62
63pub fn gaussian_filter1d<S, A, D>(
77 data: &ArrayBase<S, D>,
78 sigma: A,
79 axis: Axis,
80 order: usize,
81 mode: BorderMode<A>,
82 truncate: usize,
83) -> Array<A, D>
84where
85 S: Data<Elem = A>,
86 A: Float + FromPrimitive + 'static,
87 for<'a> &'a [A]: SymmetryStateCheck,
88 D: Dimension,
89{
90 let weights = weights(sigma, order, truncate);
91 let mut output = array_like(&data, data.dim(), A::zero());
92 inner_correlate1d(data, &weights, axis, mode, 0, &mut output);
93 output
94}
95
96fn weights<A>(sigma: A, order: usize, truncate: usize) -> Vec<A>
98where
99 A: Float + FromPrimitive + 'static,
100{
101 let radius = (A::from(truncate).unwrap() * sigma + A::from(0.5).unwrap()).to_isize().unwrap();
103
104 let sigma2 = sigma.powi(2);
105 let phi_x = {
106 let m05 = A::from(-0.5).unwrap();
107 let mut phi_x: Vec<_> =
108 (-radius..=radius).map(|x| (m05 / sigma2 * A::from(x.pow(2)).unwrap()).exp()).collect();
109 let sum = phi_x.iter().fold(A::zero(), |acc, &v| acc + v);
110 phi_x.iter_mut().for_each(|v| *v = *v / sum);
111 phi_x
112 };
113
114 if order == 0 {
115 phi_x
116 } else {
117 let mut q = Array1::zeros(order + 1);
118 q[0] = A::one();
119
120 let q_d = {
121 let mut q_d = Array2::<A>::zeros((order + 1, order + 1));
122 for (e, i) in q_d.slice_mut(s![..order, 1..]).diag_mut().iter_mut().zip(1..) {
123 *e = A::from(i).unwrap();
124 }
125
126 q_d.slice_mut(s![1.., ..order]).diag_mut().fill(-sigma2.recip());
127 q_d
128 };
129
130 for _ in 0..order {
131 q = q_d.dot(&q);
132 }
133
134 (-radius..=radius)
135 .zip(phi_x.into_iter())
136 .map(|(x, phi_x)| {
137 Zip::indexed(&q)
138 .fold(A::zero(), |acc, i, &q| acc + q * A::from(x.pow(i as u32)).unwrap())
139 * phi_x
140 })
141 .collect()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts that `weights(sigma, order, truncate)` matches `expected` element-wise
    /// (same length, each tap within an absolute epsilon of 1e-7).
    fn check_weights(sigma: f64, order: usize, truncate: usize, expected: &[f64]) {
        let actual = weights(sigma, order, truncate);
        assert_relative_eq!(actual.as_slice(), expected, epsilon = 1e-7);
    }

    #[test]
    fn test_weights() {
        // Plain Gaussian kernels (order 0).
        check_weights(
            1.0,
            0,
            3,
            &[0.00443304, 0.05400558, 0.24203622, 0.39905027, 0.24203622, 0.05400558, 0.00443304],
        );
        check_weights(
            1.0,
            0,
            4,
            &[
                0.00013383, 0.00443186, 0.05399112, 0.24197144, 0.39894346, 0.24197144, 0.05399112,
                0.00443186, 0.00013383,
            ],
        );

        // First-derivative kernels.
        check_weights(
            1.0,
            1,
            3,
            &[0.01329914, 0.10801116, 0.24203622, 0.0, -0.24203622, -0.10801116, -0.01329914],
        );
        check_weights(
            1.0,
            1,
            4,
            &[
                0.00053532,
                0.01329558,
                0.10798225,
                0.24197144,
                0.0,
                -0.24197144,
                -0.10798225,
                -0.01329558,
                -0.00053532,
            ],
        );

        // Higher-order derivative kernels.
        check_weights(
            1.0,
            2,
            3,
            &[0.03546438, 0.16201674, 0.0, -0.39905027, 0.0, 0.16201674, 0.03546438],
        );
        check_weights(0.75, 3, 3, &[0.39498175, -0.84499983, 0.0, 0.84499983, -0.39498175]);
    }
}