ndarray_ndimage/filters/
gaussian.rs

1use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, Dimension, Zip};
2use num_traits::{Float, FromPrimitive};
3
4use crate::{array_like, BorderMode};
5
6use super::{con_corr::inner_correlate1d, symmetry::SymmetryStateCheck};
7
8/// Gaussian filter for n-dimensional arrays.
9///
10/// Currently hardcoded with the `PadMode::Reflect` padding mode.
11///
12/// * `data` - The input N-D data.
13/// * `sigma` - Standard deviation for Gaussian kernel.
14/// * `order` - The order of the filter along all axes. An order of 0 corresponds to a convolution
15///   with a Gaussian kernel. A positive order corresponds to a convolution with that derivative of
16///   a Gaussian.
17/// * `mode` - Method that will be used to select the padded values. See the
18///   [`BorderMode`](crate::BorderMode) enum for more information.
19/// * `truncate` - Truncate the filter at this many standard deviations.
20///
21/// **Panics** if one of the axis' lengths is lower than `truncate * sigma + 0.5`.
22pub fn gaussian_filter<S, A, D>(
23    data: &ArrayBase<S, D>,
24    sigma: A,
25    order: usize,
26    mode: BorderMode<A>,
27    truncate: usize,
28) -> Array<A, D>
29where
30    S: Data<Elem = A>,
31    A: Float + FromPrimitive + 'static,
32    for<'a> &'a [A]: SymmetryStateCheck,
33    D: Dimension,
34{
35    let weights = weights(sigma, order, truncate);
36    let half = weights.len() / 2;
37
38    // We need 2 buffers because
39    // * We're reading neighbours so we can't read and write on the same location.
40    // * The process is applied for each axis on the result of the previous process.
41    // * It's uglier (using &mut) but much faster than allocating for each axis.
42    let mut data = data.to_owned();
43    let mut output = array_like(&data, data.dim(), A::zero());
44
45    for d in 0..data.ndim() {
46        // TODO This can be made to work if the padding modes (`reflect`, `symmetric`, `wrap`) are
47        // more robust. One just needs to reflect the input data several times if the `weights`
48        // length is greater than the input array. It works in SciPy because they are looping on a
49        // size variable instead of running the algo only once like we do.
50        let n = data.len_of(Axis(d));
51        if half > n {
52            panic!("Data size is too small for the inputs (sigma and truncate)");
53        }
54
55        inner_correlate1d(&data, &weights, Axis(d), mode, 0, &mut output);
56        if d != data.ndim() - 1 {
57            std::mem::swap(&mut output, &mut data);
58        }
59    }
60    output
61}
62
63/// Gaussian filter for 1-dimensional arrays.
64///
65/// * `data` - The input N-D data.
66/// * `sigma` - Standard deviation for Gaussian kernel.
67/// * `axis` - The axis of input along which to calculate.
68/// * `order` - The order of the filter along all axes. An order of 0 corresponds to a convolution
69///   with a Gaussian kernel. A positive order corresponds to a convolution with that derivative of
70///   a Gaussian.
71/// * `mode` - Method that will be used to select the padded values. See the
72///   [`BorderMode`](crate::BorderMode) enum for more information.
73/// * `truncate` - Truncate the filter at this many standard deviations.
74///
75/// **Panics** if the axis length is lower than `truncate * sigma + 0.5`.
76pub fn gaussian_filter1d<S, A, D>(
77    data: &ArrayBase<S, D>,
78    sigma: A,
79    axis: Axis,
80    order: usize,
81    mode: BorderMode<A>,
82    truncate: usize,
83) -> Array<A, D>
84where
85    S: Data<Elem = A>,
86    A: Float + FromPrimitive + 'static,
87    for<'a> &'a [A]: SymmetryStateCheck,
88    D: Dimension,
89{
90    let weights = weights(sigma, order, truncate);
91    let mut output = array_like(&data, data.dim(), A::zero());
92    inner_correlate1d(data, &weights, axis, mode, 0, &mut output);
93    output
94}
95
96/// Computes a 1-D Gaussian convolution kernel.
97fn weights<A>(sigma: A, order: usize, truncate: usize) -> Vec<A>
98where
99    A: Float + FromPrimitive + 'static,
100{
101    // Make the radius of the filter equal to truncate standard deviations
102    let radius = (A::from(truncate).unwrap() * sigma + A::from(0.5).unwrap()).to_isize().unwrap();
103
104    let sigma2 = sigma.powi(2);
105    let phi_x = {
106        let m05 = A::from(-0.5).unwrap();
107        let mut phi_x: Vec<_> =
108            (-radius..=radius).map(|x| (m05 / sigma2 * A::from(x.pow(2)).unwrap()).exp()).collect();
109        let sum = phi_x.iter().fold(A::zero(), |acc, &v| acc + v);
110        phi_x.iter_mut().for_each(|v| *v = *v / sum);
111        phi_x
112    };
113
114    if order == 0 {
115        phi_x
116    } else {
117        let mut q = Array1::zeros(order + 1);
118        q[0] = A::one();
119
120        let q_d = {
121            let mut q_d = Array2::<A>::zeros((order + 1, order + 1));
122            for (e, i) in q_d.slice_mut(s![..order, 1..]).diag_mut().iter_mut().zip(1..) {
123                *e = A::from(i).unwrap();
124            }
125
126            q_d.slice_mut(s![1.., ..order]).diag_mut().fill(-sigma2.recip());
127            q_d
128        };
129
130        for _ in 0..order {
131            q = q_d.dot(&q);
132        }
133
134        (-radius..=radius)
135            .zip(phi_x.into_iter())
136            .map(|(x, phi_x)| {
137                Zip::indexed(&q)
138                    .fold(A::zero(), |acc, i, &q| acc + q * A::from(x.pow(i as u32)).unwrap())
139                    * phi_x
140            })
141            .collect()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts that the kernel built by `weights(sigma, order, truncate)` matches `expected`
    /// element-wise, compared with `epsilon = 1e-7`.
    fn assert_weights(sigma: f64, order: usize, truncate: usize, expected: &[f64]) {
        let actual = weights(sigma, order, truncate);
        assert_relative_eq!(actual.as_slice(), expected, epsilon = 1e-7);
    }

    #[test]
    fn test_weights() {
        // Order 0: the normalized Gaussian, truncated at 3 and then 4 standard deviations.
        assert_weights(
            1.0,
            0,
            3,
            &[0.00443304, 0.05400558, 0.24203622, 0.39905027, 0.24203622, 0.05400558, 0.00443304],
        );
        assert_weights(
            1.0,
            0,
            4,
            &[
                0.00013383, 0.00443186, 0.05399112, 0.24197144, 0.39894346, 0.24197144, 0.05399112,
                0.00443186, 0.00013383,
            ],
        );

        // Order 1: first derivative, antisymmetric around the center tap.
        assert_weights(
            1.0,
            1,
            3,
            &[0.01329914, 0.10801116, 0.24203622, 0.0, -0.24203622, -0.10801116, -0.01329914],
        );
        assert_weights(
            1.0,
            1,
            4,
            &[
                0.00053532, 0.01329558, 0.10798225, 0.24197144, 0.0, -0.24197144, -0.10798225,
                -0.01329558, -0.00053532,
            ],
        );

        // Higher orders.
        assert_weights(
            1.0,
            2,
            3,
            &[0.03546438, 0.16201674, 0.0, -0.39905027, 0.0, 0.16201674, 0.03546438],
        );
        assert_weights(
            0.75,
            3,
            3,
            &[0.39498175, -0.84499983, 0.0, 0.84499983, -0.39498175],
        );
    }
}