ndarray_vision/processing/
kernels.rs

1use crate::processing::Error;
2use core::ops::Neg;
3use ndarray::prelude::*;
4use ndarray::IntoDimension;
5use num_traits::{cast::FromPrimitive, float::Float, sign::Signed, Num, NumAssignOps, NumOps};
6
7/// Builds a convolutioon kernel given a shape and optional parameters
8pub trait KernelBuilder<T> {
9    /// Parameters used in construction of the kernel
10    type Params;
11    /// Build a kernel with a given dimension given sensible defaults for any
12    /// parameters
13    fn build<D>(shape: D) -> Result<Array3<T>, Error>
14    where
15        D: Copy + IntoDimension<Dim = Ix3>;
16    /// For kernels with optional parameters use build with params otherwise
17    /// appropriate default parameters will be chosen
18    fn build_with_params<D>(shape: D, _p: Self::Params) -> Result<Array3<T>, Error>
19    where
20        D: Copy + IntoDimension<Dim = Ix3>,
21    {
22        Self::build(shape)
23    }
24}
25
26/// Create a kernel with a fixed dimension
27pub trait FixedDimensionKernelBuilder<T> {
28    /// Parameters used in construction of the kernel
29    type Params;
30    /// Build a fixed size kernel
31    fn build() -> Result<Array3<T>, Error>;
32    /// Build a fixed size kernel with the given parameters
33    fn build_with_params(_p: Self::Params) -> Result<Array3<T>, Error> {
34        Self::build()
35    }
36}
37
/// Builder for Laplacian kernels, which approximate the second spatial
/// derivative of an image.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LaplaceFilter;
42
/// Selects which variant of the Laplacian kernel is generated.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum LaplaceType {
    /// The standard filter, and the default parameter choice. It only uses the
    /// four direct neighbours; for a 3x3x1 matrix it is:
    /// ```text
    /// [0, -1, 0]
    /// [-1, 4, -1]
    /// [0, -1, 0]
    /// ```
    Standard,
    /// The diagonal filter additionally includes the derivatives along the
    /// diagonals; for a 3x3x1 matrix it is:
    /// ```text
    /// [-1, -1, -1]
    /// [-1, 8, -1]
    /// [-1, -1, -1]
    /// ```
    Diagonal,
}
62
63impl<T> FixedDimensionKernelBuilder<T> for LaplaceFilter
64where
65    T: Copy + Clone + Num + NumOps + Signed + FromPrimitive,
66{
67    /// Type of Laplacian filter to construct
68    type Params = LaplaceType;
69
70    fn build() -> Result<Array3<T>, Error> {
71        Self::build_with_params(LaplaceType::Standard)
72    }
73
74    fn build_with_params(p: Self::Params) -> Result<Array3<T>, Error> {
75        let res = match p {
76            LaplaceType::Standard => {
77                let m_1 = -T::one();
78                let p_4 = T::from_u8(4).ok_or(Error::NumericError)?;
79                let z = T::zero();
80
81                arr2(&[[z, m_1, z], [m_1, p_4, m_1], [z, m_1, z]])
82            }
83            LaplaceType::Diagonal => {
84                let m_1 = -T::one();
85                let p_8 = T::from_u8(8).ok_or(Error::NumericError)?;
86
87                arr2(&[[m_1, m_1, m_1], [m_1, p_8, m_1], [m_1, m_1, m_1]])
88            }
89        };
90        Ok(res.insert_axis(Axis(2)))
91    }
92}
93
/// Builder for Gaussian kernels. Its parameter is the covariance, supplied as
/// two values: the variance along x and the variance along y.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GaussianFilter;
98
99impl<T> KernelBuilder<T> for GaussianFilter
100where
101    T: Copy + Clone + FromPrimitive + Num,
102{
103    /// The parameter for the Gaussian filter is the horizontal and vertical
104    /// covariances to form the covariance matrix.
105    /// ```text
106    /// [ Params[0], 0]
107    /// [ 0, Params[1]]
108    /// ```
109    type Params = [f64; 2];
110
111    fn build<D>(shape: D) -> Result<Array3<T>, Error>
112    where
113        D: Copy + IntoDimension<Dim = Ix3>,
114    {
115        // This recommendation was taken from OpenCV 2.4 docs
116        let s = shape.into_dimension();
117        let sig = 0.3 * (((std::cmp::max(s[0], 1) - 1) as f64) * 0.5 - 1.0) + 0.8;
118        Self::build_with_params(shape, [sig, sig])
119    }
120
121    fn build_with_params<D>(shape: D, covar: Self::Params) -> Result<Array3<T>, Error>
122    where
123        D: Copy + IntoDimension<Dim = Ix3>,
124    {
125        let is_even = |x| x & 1 == 0;
126        let s = shape.into_dimension();
127        if is_even(s[0]) || is_even(s[1]) || s[0] != s[1] || s[2] == 0 {
128            Err(Error::InvalidDimensions)
129        } else if covar[0] <= 0.0f64 || covar[1] <= 0.0f64 {
130            Err(Error::InvalidParameter)
131        } else {
132            let centre: isize = (s[0] as isize + 1) / 2 - 1;
133            let gauss = |coord, covar| ((coord - centre) as f64).powi(2) / (2.0f64 * covar);
134
135            let mut temp = Array2::from_shape_fn((s[0], s[1]), |(r, c)| {
136                f64::exp(-(gauss(r as isize, covar[1]) + gauss(c as isize, covar[0])))
137            });
138
139            let sum = temp.sum();
140
141            temp *= 1.0f64 / sum;
142
143            let temp = temp.mapv(T::from_f64);
144
145            if temp.iter().any(|x| x.is_none()) {
146                Err(Error::NumericError)
147            } else {
148                let temp = temp.mapv(|x| x.unwrap());
149                Ok(Array3::from_shape_fn(shape, |(r, c, _)| temp[[r, c]]))
150            }
151        }
152    }
153}
154
/// Builder for box linear kernels, roughly `1/(R*C)*Array2::ones((R, C))`.
/// The same box kernel is produced for every colour channel requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BoxLinearFilter;
159
160impl<T> KernelBuilder<T> for BoxLinearFilter
161where
162    T: Float + Num + NumAssignOps + FromPrimitive,
163{
164    /// If false the kernel will not be normalised - this means that pixel bounds
165    /// may be exceeded and overflow may occur
166    type Params = bool;
167
168    fn build<D>(shape: D) -> Result<Array3<T>, Error>
169    where
170        D: Copy + IntoDimension<Dim = Ix3>,
171    {
172        Self::build_with_params(shape, true)
173    }
174
175    fn build_with_params<D>(shape: D, normalise: Self::Params) -> Result<Array3<T>, Error>
176    where
177        D: Copy + IntoDimension<Dim = Ix3>,
178    {
179        let shape = shape.into_dimension();
180        if shape[0] < 1 || shape[1] < 1 || shape[2] < 1 {
181            Err(Error::InvalidDimensions)
182        } else if normalise {
183            let weight = 1.0f64 / ((shape[0] * shape[1]) as f64);
184            match T::from_f64(weight) {
185                Some(weight) => Ok(Array3::from_elem(shape, weight)),
186                None => Err(Error::NumericError),
187            }
188        } else {
189            Ok(Array3::ones(shape))
190        }
191    }
192}
193
/// Builder producing the horizontal or the vertical kernel of the Sobel
/// operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SobelFilter;
198
/// Direction along which a filter takes its derivative.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Orientation {
    /// Extract the vertical derivatives of an image
    Vertical,
    /// Extract the horizontal derivatives of an image
    Horizontal,
}
207
208impl<T> FixedDimensionKernelBuilder<T> for SobelFilter
209where
210    T: Copy + Clone + Num + Neg<Output = T> + FromPrimitive,
211{
212    /// Orientation of the filter. Default is vertical
213    type Params = Orientation;
214    /// Build a fixed size kernel
215    fn build() -> Result<Array3<T>, Error> {
216        // Arbitary decision
217        Self::build_with_params(Orientation::Vertical)
218    }
219
220    /// Build a fixed size kernel with the given parameters
221    fn build_with_params(p: Self::Params) -> Result<Array3<T>, Error> {
222        let two = T::from_i8(2).ok_or(Error::NumericError)?;
223        // Gets the gradient along the horizontal axis
224        #[rustfmt::skip]
225        let horz_sobel = arr2(&[
226            [T::one(),  T::zero(), -T::one()],
227            [two,       T::zero(), -two],
228            [T::one(),  T::zero(), -T::one()],
229        ]);
230        let sobel = match p {
231            Orientation::Vertical => horz_sobel.t().to_owned(),
232            Orientation::Horizontal => horz_sobel,
233        };
234        Ok(sobel.insert_axis(Axis(2)))
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr3;

    /// A normalised box filter holds `1 / (rows * cols)` everywhere, and a
    /// zero-sized dimension is rejected.
    #[test]
    fn test_box_linear_filter() {
        let expected = Array3::from_elem((2, 2, 3), 0.25f64);
        let built: Array3<f64> = BoxLinearFilter::build(Ix3(2, 2, 3)).unwrap();
        assert_eq!(built, expected);

        let empty: Result<Array3<f64>, Error> = BoxLinearFilter::build(Ix3(0, 2, 3));
        assert!(empty.is_err());
    }

    /// Both Sobel orientations match their textbook matrices.
    #[test]
    fn test_sobel_filter() {
        // Every Sobel weight is a small integer, which floats represent
        // exactly, so direct equality comparisons are fine here.
        let horizontal: Array3<f32> =
            SobelFilter::build_with_params(Orientation::Horizontal).unwrap();
        let expected_horizontal = arr3(&[
            [[1.0f32], [0.0f32], [-1.0f32]],
            [[2.0f32], [0.0f32], [-2.0f32]],
            [[1.0f32], [0.0f32], [-1.0f32]],
        ]);
        assert_eq!(horizontal, expected_horizontal);

        let vertical: Array3<f32> =
            SobelFilter::build_with_params(Orientation::Vertical).unwrap();
        let expected_vertical = arr3(&[
            [[1.0f32], [2.0f32], [1.0f32]],
            [[0.0f32], [0.0f32], [0.0f32]],
            [[-1.0f32], [-2.0f32], [-1.0f32]],
        ]);
        assert_eq!(vertical, expected_vertical);
    }

    /// Invalid shapes are rejected, each channel sums to one, and a very
    /// narrow Gaussian collapses onto the centre element.
    #[test]
    fn test_gaussian_filter() {
        // Non-square, even-sized and zero-sized kernels are all invalid.
        for &dims in [Ix3(3, 5, 2), Ix3(4, 4, 2), Ix3(4, 0, 2)].iter() {
            let rejected: Result<Array3<f64>, _> = GaussianFilter::build(dims);
            assert_eq!(rejected, Err(Error::InvalidDimensions));
        }

        let channels = 2;
        let normalised: Array3<f64> =
            GaussianFilter::build_with_params(Ix3(3, 3, channels), [0.3, 0.3]).unwrap();
        assert_eq!(normalised.sum().round(), channels as f64);

        let narrow: Array3<f64> =
            GaussianFilter::build_with_params(Ix3(3, 3, 1), [0.05, 0.05]).unwrap();
        let rounded = narrow.mapv(|w| w.round() as u8);
        // Need to do a proper test but this should cover enough
        let expected = arr3(&[[[0], [0], [0]], [[0], [1], [0]], [[0], [0], [0]]]);
        assert_eq!(rounded, expected);
    }

    /// Both Laplacian variants match their documented matrices.
    #[test]
    fn test_laplace_filters() {
        let standard: Array3<i64> = LaplaceFilter::build().unwrap();
        let expected_standard = arr3(&[[[0], [-1], [0]], [[-1], [4], [-1]], [[0], [-1], [0]]]);
        assert_eq!(standard, expected_standard);

        let diagonal: Array3<i64> =
            LaplaceFilter::build_with_params(LaplaceType::Diagonal).unwrap();
        let expected_diagonal =
            arr3(&[[[-1], [-1], [-1]], [[-1], [8], [-1]], [[-1], [-1], [-1]]]);
        assert_eq!(diagonal, expected_diagonal);
    }
}