1use crate::processing::Error;
2use core::ops::Neg;
3use ndarray::prelude::*;
4use ndarray::IntoDimension;
5use num_traits::{cast::FromPrimitive, float::Float, sign::Signed, Num, NumAssignOps, NumOps};
6
7pub trait KernelBuilder<T> {
9 type Params;
11 fn build<D>(shape: D) -> Result<Array3<T>, Error>
14 where
15 D: Copy + IntoDimension<Dim = Ix3>;
16 fn build_with_params<D>(shape: D, _p: Self::Params) -> Result<Array3<T>, Error>
19 where
20 D: Copy + IntoDimension<Dim = Ix3>,
21 {
22 Self::build(shape)
23 }
24}
25
26pub trait FixedDimensionKernelBuilder<T> {
28 type Params;
30 fn build() -> Result<Array3<T>, Error>;
32 fn build_with_params(_p: Self::Params) -> Result<Array3<T>, Error> {
34 Self::build()
35 }
36}
37
/// Builder for fixed-size 3x3 Laplacian kernels; the variant is chosen with `LaplaceType`.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct LaplaceFilter;
42
/// Selects which 3x3 Laplacian kernel `LaplaceFilter` builds.
///
/// Derives `Debug` like every other public type in this module so values can be
/// printed and used in `assert_eq!`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum LaplaceType {
    /// `-1` at the four edge-adjacent neighbours, `4` at the centre and `0` in the
    /// corners. This is what `LaplaceFilter::build` returns.
    Standard,
    /// `-1` at all eight neighbours (including the diagonals) and `8` at the centre.
    Diagonal,
}
62
63impl<T> FixedDimensionKernelBuilder<T> for LaplaceFilter
64where
65 T: Copy + Clone + Num + NumOps + Signed + FromPrimitive,
66{
67 type Params = LaplaceType;
69
70 fn build() -> Result<Array3<T>, Error> {
71 Self::build_with_params(LaplaceType::Standard)
72 }
73
74 fn build_with_params(p: Self::Params) -> Result<Array3<T>, Error> {
75 let res = match p {
76 LaplaceType::Standard => {
77 let m_1 = -T::one();
78 let p_4 = T::from_u8(4).ok_or(Error::NumericError)?;
79 let z = T::zero();
80
81 arr2(&[[z, m_1, z], [m_1, p_4, m_1], [z, m_1, z]])
82 }
83 LaplaceType::Diagonal => {
84 let m_1 = -T::one();
85 let p_8 = T::from_u8(8).ok_or(Error::NumericError)?;
86
87 arr2(&[[m_1, m_1, m_1], [m_1, p_8, m_1], [m_1, m_1, m_1]])
88 }
89 };
90 Ok(res.insert_axis(Axis(2)))
91 }
92}
93
/// Builder for normalised Gaussian kernels with an odd, square spatial size.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct GaussianFilter;
98
99impl<T> KernelBuilder<T> for GaussianFilter
100where
101 T: Copy + Clone + FromPrimitive + Num,
102{
103 type Params = [f64; 2];
110
111 fn build<D>(shape: D) -> Result<Array3<T>, Error>
112 where
113 D: Copy + IntoDimension<Dim = Ix3>,
114 {
115 let s = shape.into_dimension();
117 let sig = 0.3 * (((std::cmp::max(s[0], 1) - 1) as f64) * 0.5 - 1.0) + 0.8;
118 Self::build_with_params(shape, [sig, sig])
119 }
120
121 fn build_with_params<D>(shape: D, covar: Self::Params) -> Result<Array3<T>, Error>
122 where
123 D: Copy + IntoDimension<Dim = Ix3>,
124 {
125 let is_even = |x| x & 1 == 0;
126 let s = shape.into_dimension();
127 if is_even(s[0]) || is_even(s[1]) || s[0] != s[1] || s[2] == 0 {
128 Err(Error::InvalidDimensions)
129 } else if covar[0] <= 0.0f64 || covar[1] <= 0.0f64 {
130 Err(Error::InvalidParameter)
131 } else {
132 let centre: isize = (s[0] as isize + 1) / 2 - 1;
133 let gauss = |coord, covar| ((coord - centre) as f64).powi(2) / (2.0f64 * covar);
134
135 let mut temp = Array2::from_shape_fn((s[0], s[1]), |(r, c)| {
136 f64::exp(-(gauss(r as isize, covar[1]) + gauss(c as isize, covar[0])))
137 });
138
139 let sum = temp.sum();
140
141 temp *= 1.0f64 / sum;
142
143 let temp = temp.mapv(T::from_f64);
144
145 if temp.iter().any(|x| x.is_none()) {
146 Err(Error::NumericError)
147 } else {
148 let temp = temp.mapv(|x| x.unwrap());
149 Ok(Array3::from_shape_fn(shape, |(r, c, _)| temp[[r, c]]))
150 }
151 }
152 }
153}
154
/// Builder for box (uniform-weight) kernels, optionally normalised by the spatial area.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct BoxLinearFilter;
159
160impl<T> KernelBuilder<T> for BoxLinearFilter
161where
162 T: Float + Num + NumAssignOps + FromPrimitive,
163{
164 type Params = bool;
167
168 fn build<D>(shape: D) -> Result<Array3<T>, Error>
169 where
170 D: Copy + IntoDimension<Dim = Ix3>,
171 {
172 Self::build_with_params(shape, true)
173 }
174
175 fn build_with_params<D>(shape: D, normalise: Self::Params) -> Result<Array3<T>, Error>
176 where
177 D: Copy + IntoDimension<Dim = Ix3>,
178 {
179 let shape = shape.into_dimension();
180 if shape[0] < 1 || shape[1] < 1 || shape[2] < 1 {
181 Err(Error::InvalidDimensions)
182 } else if normalise {
183 let weight = 1.0f64 / ((shape[0] * shape[1]) as f64);
184 match T::from_f64(weight) {
185 Some(weight) => Ok(Array3::from_elem(shape, weight)),
186 None => Err(Error::NumericError),
187 }
188 } else {
189 Ok(Array3::ones(shape))
190 }
191 }
192}
193
/// Builder for fixed-size 3x3 Sobel kernels; the direction is chosen with `Orientation`.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct SobelFilter;
198
/// Selects which of the two Sobel kernels `SobelFilter` builds.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum Orientation {
    /// The transpose of `Horizontal`: rows `[1, 2, 1]`, `[0, 0, 0]`, `[-1, -2, -1]`.
    /// This is what `SobelFilter::build` returns.
    Vertical,
    /// Rows `[1, 0, -1]`, `[2, 0, -2]`, `[1, 0, -1]`.
    Horizontal,
}
207
208impl<T> FixedDimensionKernelBuilder<T> for SobelFilter
209where
210 T: Copy + Clone + Num + Neg<Output = T> + FromPrimitive,
211{
212 type Params = Orientation;
214 fn build() -> Result<Array3<T>, Error> {
216 Self::build_with_params(Orientation::Vertical)
218 }
219
220 fn build_with_params(p: Self::Params) -> Result<Array3<T>, Error> {
222 let two = T::from_i8(2).ok_or(Error::NumericError)?;
223 #[rustfmt::skip]
225 let horz_sobel = arr2(&[
226 [T::one(), T::zero(), -T::one()],
227 [two, T::zero(), -two],
228 [T::one(), T::zero(), -T::one()],
229 ]);
230 let sobel = match p {
231 Orientation::Vertical => horz_sobel.t().to_owned(),
232 Orientation::Horizontal => horz_sobel,
233 };
234 Ok(sobel.insert_axis(Axis(2)))
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr3;

    #[test]
    fn test_box_linear_filter() {
        // A normalised 2x2 box puts 1 / (2 * 2) on every tap of every channel.
        let normalised: Array3<f64> = BoxLinearFilter::build(Ix3(2, 2, 3)).unwrap();
        let expected = Array3::from_elem((2, 2, 3), 0.25f64);
        assert_eq!(normalised, expected);

        // A zero-length axis is rejected.
        let degenerate: Result<Array3<f64>, Error> = BoxLinearFilter::build(Ix3(0, 2, 3));
        assert!(degenerate.is_err());
    }

    #[test]
    fn test_sobel_filter() {
        let horizontal: Array3<f32> =
            SobelFilter::build_with_params(Orientation::Horizontal).unwrap();
        let expected_horizontal = arr3(&[
            [[1.0f32], [0.0f32], [-1.0f32]],
            [[2.0f32], [0.0f32], [-2.0f32]],
            [[1.0f32], [0.0f32], [-1.0f32]],
        ]);
        assert_eq!(horizontal, expected_horizontal);

        let vertical: Array3<f32> = SobelFilter::build_with_params(Orientation::Vertical).unwrap();
        let expected_vertical = arr3(&[
            [[1.0f32], [2.0f32], [1.0f32]],
            [[0.0f32], [0.0f32], [0.0f32]],
            [[-1.0f32], [-2.0f32], [-1.0f32]],
        ]);
        assert_eq!(vertical, expected_vertical);
    }

    #[test]
    fn test_gaussian_filter() {
        // Non-square, even-sized and zero-sized spatial shapes are all rejected.
        for &bad_shape in &[Ix3(3, 5, 2), Ix3(4, 4, 2), Ix3(4, 0, 2)] {
            let result: Result<Array3<f64>, _> = GaussianFilter::build(bad_shape);
            assert_eq!(result, Err(Error::InvalidDimensions));
        }

        // The normalised 2-D kernel is repeated per channel, so the total is one per channel.
        let channels = 2;
        let wide: Array3<f64> =
            GaussianFilter::build_with_params(Ix3(3, 3, channels), [0.3, 0.3]).unwrap();
        assert_eq!(wide.sum().round(), channels as f64);

        // A tiny variance concentrates almost all of the weight on the centre tap.
        let narrow: Array3<f64> =
            GaussianFilter::build_with_params(Ix3(3, 3, 1), [0.05, 0.05]).unwrap();
        let rounded = narrow.mapv(|x| x.round() as u8);
        let expected = arr3(&[[[0], [0], [0]], [[0], [1], [0]], [[0], [0], [0]]]);
        assert_eq!(rounded, expected);
    }

    #[test]
    fn test_laplace_filters() {
        let standard: Array3<i64> = LaplaceFilter::build().unwrap();
        let expected_standard = arr3(&[[[0], [-1], [0]], [[-1], [4], [-1]], [[0], [-1], [0]]]);
        assert_eq!(standard, expected_standard);

        let diagonal: Array3<i64> =
            LaplaceFilter::build_with_params(LaplaceType::Diagonal).unwrap();
        let expected_diagonal =
            arr3(&[[[-1], [-1], [-1]], [[-1], [8], [-1]], [[-1], [-1], [-1]]]);
        assert_eq!(diagonal, expected_diagonal);
    }
}