Skip to main content

ndarray_ndimage/filters/
gaussian.rs

1use ndarray::{s, Array, Array1, Array2, ArrayRef, Axis, Dimension, Zip};
2use num_traits::{Float, FromPrimitive};
3
4use crate::{array_like, BorderMode};
5
6use super::{con_corr::inner_correlate1d, symmetry::SymmetryStateCheck};
7
8/// Gaussian filter for n-dimensional arrays.
9///
10/// Currently hardcoded with the `PadMode::Reflect` padding mode.
11///
12/// * `data` - The input N-D data.
13/// * `sigma` - Standard deviation for Gaussian kernel.
14/// * `order` - The order of the filter along all axes. An order of 0 corresponds to a convolution
15///   with a Gaussian kernel. A positive order corresponds to a convolution with that derivative of
16///   a Gaussian.
17/// * `mode` - Method that will be used to select the padded values. See the
18///   [`BorderMode`](crate::BorderMode) enum for more information.
19/// * `truncate` - Truncate the filter at this many standard deviations.
20///
21/// **Panics** if one of the axis' lengths is lower than `truncate * sigma + 0.5`.
22pub fn gaussian_filter<A, D>(
23    data: &ArrayRef<A, D>,
24    sigma: A,
25    order: usize,
26    mode: BorderMode<A>,
27    truncate: usize,
28) -> Array<A, D>
29where
30    A: Float + FromPrimitive + 'static,
31    for<'a> &'a [A]: SymmetryStateCheck,
32    D: Dimension,
33{
34    let weights = weights(sigma, order, truncate);
35    let half = weights.len() / 2;
36
37    // We need 2 buffers because
38    // * We're reading neighbours so we can't read and write on the same location.
39    // * The process is applied for each axis on the result of the previous process.
40    // * It's uglier (using &mut) but much faster than allocating for each axis.
41    let mut data = data.to_owned();
42    let mut output = array_like(&data, data.dim(), A::zero());
43
44    for d in 0..data.ndim() {
45        // TODO This can be made to work if the padding modes (`reflect`, `symmetric`, `wrap`) are
46        // more robust. One just needs to reflect the input data several times if the `weights`
47        // length is greater than the input array. It works in SciPy because they are looping on a
48        // size variable instead of running the algo only once like we do.
49        let n = data.len_of(Axis(d));
50        if half > n {
51            panic!("Data size is too small for the inputs (sigma and truncate)");
52        }
53
54        inner_correlate1d(&data, &weights, Axis(d), mode, 0, &mut output);
55        if d != data.ndim() - 1 {
56            std::mem::swap(&mut output, &mut data);
57        }
58    }
59    output
60}
61
62/// Gaussian filter for 1-dimensional arrays.
63///
64/// * `data` - The input N-D data.
65/// * `sigma` - Standard deviation for Gaussian kernel.
66/// * `axis` - The axis of input along which to calculate.
67/// * `order` - The order of the filter along all axes. An order of 0 corresponds to a convolution
68///   with a Gaussian kernel. A positive order corresponds to a convolution with that derivative of
69///   a Gaussian.
70/// * `mode` - Method that will be used to select the padded values. See the
71///   [`BorderMode`](crate::BorderMode) enum for more information.
72/// * `truncate` - Truncate the filter at this many standard deviations.
73///
74/// **Panics** if the axis length is lower than `truncate * sigma + 0.5`.
75pub fn gaussian_filter1d<A, D>(
76    data: &ArrayRef<A, D>,
77    sigma: A,
78    axis: Axis,
79    order: usize,
80    mode: BorderMode<A>,
81    truncate: usize,
82) -> Array<A, D>
83where
84    A: Float + FromPrimitive + 'static,
85    for<'a> &'a [A]: SymmetryStateCheck,
86    D: Dimension,
87{
88    let weights = weights(sigma, order, truncate);
89    let mut output = array_like(&data, data.dim(), A::zero());
90    inner_correlate1d(data, &weights, axis, mode, 0, &mut output);
91    output
92}
93
94/// Computes a 1-D Gaussian convolution kernel.
95fn weights<A>(sigma: A, order: usize, truncate: usize) -> Vec<A>
96where
97    A: Float + FromPrimitive + 'static,
98{
99    // Make the radius of the filter equal to truncate standard deviations
100    let radius = (A::from(truncate).unwrap() * sigma + A::from(0.5).unwrap()).to_isize().unwrap();
101
102    let sigma2 = sigma.powi(2);
103    let phi_x = {
104        let m05 = A::from(-0.5).unwrap();
105        let mut phi_x: Vec<_> =
106            (-radius..=radius).map(|x| (m05 / sigma2 * A::from(x.pow(2)).unwrap()).exp()).collect();
107        let sum = phi_x.iter().fold(A::zero(), |acc, &v| acc + v);
108        phi_x.iter_mut().for_each(|v| *v = *v / sum);
109        phi_x
110    };
111
112    if order == 0 {
113        phi_x
114    } else {
115        let mut q = Array1::zeros(order + 1);
116        q[0] = A::one();
117
118        let q_d = {
119            let mut q_d = Array2::<A>::zeros((order + 1, order + 1));
120            for (e, i) in q_d.slice_mut(s![..order, 1..]).diag_mut().iter_mut().zip(1..) {
121                *e = A::from(i).unwrap();
122            }
123
124            q_d.slice_mut(s![1.., ..order]).diag_mut().fill(-sigma2.recip());
125            q_d
126        };
127
128        for _ in 0..order {
129            q = q_d.dot(&q);
130        }
131
132        (-radius..=radius)
133            .zip(phi_x.into_iter())
134            .map(|(x, phi_x)| {
135                Zip::indexed(&q)
136                    .fold(A::zero(), |acc, i, &q| acc + q * A::from(x.pow(i as u32)).unwrap())
137                    * phi_x
138            })
139            .collect()
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds `weights(sigma, order, truncate)` and compares it element-wise with `expected`,
    /// using an `approx` epsilon of `1e-7`.
    fn check_kernel(sigma: f64, order: usize, truncate: usize, expected: &[f64]) {
        let kernel = weights(sigma, order, truncate);
        assert_relative_eq!(kernel.as_slice(), expected, epsilon = 1e-7);
    }

    #[test]
    fn test_weights() {
        // Order 0: the normalized Gaussian itself, truncated at 3 and then 4 sigmas.
        check_kernel(
            1.0,
            0,
            3,
            &[0.00443304, 0.05400558, 0.24203622, 0.39905027, 0.24203622, 0.05400558, 0.00443304],
        );
        check_kernel(
            1.0,
            0,
            4,
            &[
                0.00013383, 0.00443186, 0.05399112, 0.24197144, 0.39894346, 0.24197144, 0.05399112,
                0.00443186, 0.00013383,
            ],
        );

        // First derivative: antisymmetric kernel, zero at the centre tap.
        check_kernel(
            1.0,
            1,
            3,
            &[0.01329914, 0.10801116, 0.24203622, 0.0, -0.24203622, -0.10801116, -0.01329914],
        );
        check_kernel(
            1.0,
            1,
            4,
            &[
                0.00053532,
                0.01329558,
                0.10798225,
                0.24197144,
                0.0,
                -0.24197144,
                -0.10798225,
                -0.01329558,
                -0.00053532,
            ],
        );

        // Second derivative: symmetric kernel.
        check_kernel(
            1.0,
            2,
            3,
            &[0.03546438, 0.16201674, 0.0, -0.39905027, 0.0, 0.16201674, 0.03546438],
        );

        // Third derivative with a non-unit sigma (radius 2, so five taps).
        check_kernel(0.75, 3, 3, &[0.39498175, -0.84499983, 0.0, 0.84499983, -0.39498175]);
    }
}