Skip to main content

burn_vision/backends/cpu/morphology/
mod.rs

1use std::fmt::Debug;
2
3use burn_tensor::{
4    BasicOps, Bool, DType, Element, Shape, Tensor, TensorData, backend::Backend, cast::ToElement,
5    ops::BoolTensor,
6};
7use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};
8use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};
9use macerator::{Simd, VOrd};
10
11use crate::{BorderType, MorphOptions, Point, Size};
12
13use super::MinMax;
14
15mod filter;
16mod filter_engine;
17
/// A morphology operation.
///
/// TODO: Implement composite ops
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MorphOp {
    /// Erosion: every output pixel becomes the minimum over the structuring element.
    Erode,
    /// Dilation: every output pixel becomes the maximum over the structuring element.
    Dilate,
}
25
/// Structuring element in the form consumed by the CPU filter builder.
///
/// `morph` produces `Rect` when every entry of the kernel mask is set, which lets the filter
/// run as separate row and column passes. Any other mask is kept as `Other` and handled by a
/// generic 2D filter.
pub enum MorphKernel<B: Element> {
    /// A fully set rectangular kernel.
    Rect {
        /// Kernel width and height.
        size: Size,
        /// Anchor point of the kernel (`x` = column, `y` = row).
        anchor: Point,
    },
    /// An arbitrary kernel mask.
    Other {
        /// Mask entries flattened from the `[height, width]` kernel tensor.
        kernel: Vec<B>,
        /// Kernel width and height.
        size: Size,
        /// Anchor point of the kernel (`x` = column, `y` = row).
        anchor: Point,
    },
}
37
38pub fn morph<B: Backend, K: BasicOps<B>>(
39    input: Tensor<B, 3, K>,
40    kernel: BoolTensor<B>,
41    op: MorphOp,
42    opts: MorphOptions<B, K>,
43) -> Tensor<B, 3, K> {
44    let device = input.device();
45
46    let kernel = Tensor::<B, 2, Bool>::new(kernel);
47    let kshape = kernel.shape().dims();
48    let [kh, kw] = kshape;
49
50    let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();
51    let is_rect = kernel.iter().all(|it| it.to_bool());
52    let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));
53    let iter = opts.iterations;
54    let btype = opts.border_type;
55    let bvalue = opts.border_value.map(|it| it.into_data());
56
57    let size = Size::new(kw, kh);
58    let kernel = if is_rect {
59        MorphKernel::Rect { size, anchor }
60    } else {
61        MorphKernel::Other {
62            kernel,
63            size,
64            anchor,
65        }
66    };
67
68    let shape = input.shape();
69    let data = input.into_data();
70    match data.dtype {
71        DType::F64 => {
72            morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)
73        }
74        DType::F32 | DType::Flex32 => {
75            morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)
76        }
77        DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(
78            data.convert::<f32>(),
79            shape,
80            kernel,
81            op,
82            iter,
83            btype,
84            bvalue,
85            &device,
86        ),
87        DType::I64 => {
88            morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)
89        }
90        DType::I32 => {
91            morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)
92        }
93        DType::I16 => {
94            morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)
95        }
96        DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),
97        DType::U64 => {
98            morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)
99        }
100        DType::U32 => {
101            morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)
102        }
103        DType::U16 => {
104            morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)
105        }
106        DType::U8 => morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
107        DType::Bool => morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device),
108        DType::QFloat(_) => unimplemented!(),
109    }
110}
111
112#[allow(clippy::too_many_arguments)]
113fn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(
114    mut input: TensorData,
115    shape: Shape,
116    kernel: MorphKernel<B::BoolElem>,
117    op: MorphOp,
118    iter: usize,
119    btype: BorderType,
120    bvalue: Option<TensorData>,
121    device: &B::Device,
122) -> Tensor<B, 3, K> {
123    let data = input.as_mut_slice::<T>().unwrap();
124    let bvalue = border_value(btype, bvalue, op, &shape);
125    run_morph(data, shape, kernel, op, iter, btype, &bvalue);
126    Tensor::from_data(input, device)
127}
128
129#[allow(clippy::too_many_arguments)]
130fn morph_bool<B: Backend, K: BasicOps<B>>(
131    mut input: TensorData,
132    shape: Shape,
133    kernel: MorphKernel<B::BoolElem>,
134    op: MorphOp,
135    iter: usize,
136    btype: BorderType,
137    bvalue: Option<TensorData>,
138    device: &B::Device,
139) -> Tensor<B, 3, K> {
140    let data = input.as_mut_slice::<bool>().unwrap();
141    // SAFETY: Morph can't produce invalid boolean values
142    let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };
143    let bvalue = border_value(btype, bvalue, op, &shape);
144    run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);
145    Tensor::from_data(input, device)
146}
147
148fn border_value<T: Element>(
149    btype: BorderType,
150    bvalue: Option<TensorData>,
151    op: MorphOp,
152    shape: &Shape,
153) -> Vec<T> {
154    let [_, _, ch] = shape.dims();
155    match (btype, bvalue) {
156        (BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),
157        (BorderType::Constant, None) => match op {
158            MorphOp::Erode => vec![T::MAX; ch],
159            MorphOp::Dilate => vec![T::MIN; ch],
160        },
161        _ => vec![],
162    }
163}
164
165fn run_morph<T: VOrd + MinMax + Element, B: Element>(
166    input: &mut [T],
167    shape: Shape,
168    kernel: MorphKernel<B>,
169    op: MorphOp,
170    iter: usize,
171    btype: BorderType,
172    bvalue: &[T],
173) {
174    match op {
175        MorphOp::Erode => {
176            let filter = filter::<T, MinOp, B>(kernel);
177            dispatch_morph(input, shape, filter, btype, bvalue, iter);
178        }
179        MorphOp::Dilate => {
180            let filter = filter::<T, MaxOp, B>(kernel);
181            dispatch_morph(input, shape, filter, btype, bvalue, iter);
182        }
183    };
184}
185
186fn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(
187    kernel: MorphKernel<B>,
188) -> Filter<T, Op> {
189    match kernel {
190        MorphKernel::Rect { size, anchor } => {
191            let row_filter = RowFilter::new(size.width, anchor.x);
192            let col_filter = ColFilter::new(size.height, anchor.y);
193            Filter::Separable {
194                row_filter,
195                col_filter,
196            }
197        }
198        MorphKernel::Other {
199            kernel,
200            size,
201            anchor,
202        } => {
203            let filter = Filter2D::new(&kernel, size, anchor);
204            Filter::Fallback(filter)
205        }
206    }
207}
208
// Build a `FilterEngine` for `filter` and apply it in place to `buffer`.
//
// `buffer_shape` describes `buffer`; its last dimension is the channel count handed to the
// engine. `border_value` holds the constant-border fill values. Callers in this module pass
// an empty slice for non-constant border types.
//
// The `S: Simd` type parameter is supplied by `#[macerator::with_simd]`; callers such as
// `run_morph` never name it.
//
// NOTE(review): the engine always runs at least once, so `iterations == 0` behaves like
// `iterations == 1`. Confirm this is intended (OpenCV treats 0 iterations as a plain copy).
#[inline(always)]
#[allow(clippy::too_many_arguments)]
#[macerator::with_simd]
fn dispatch_morph<
    'a,
    S: Simd,
    T: VOrd + MinMax + Debug,
    Op: MorphOperator<T> + VecMorphOperator<T>,
>(
    buffer: &'a mut [T],
    buffer_shape: Shape,
    filter: filter_engine::Filter<T, Op>,
    border_type: BorderType,
    border_value: &'a [T],
    iterations: usize,
) where
    // NOTE(review): trivially true bound, presumably required by `#[macerator::with_simd]`;
    // confirm before removing.
    'a: 'a,
{
    let [_, _, ch] = buffer_shape.dims();
    let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);
    // First pass, then `iterations - 1` further passes over the already-filtered buffer.
    engine.apply(buffer, buffer_shape.clone());
    for _ in 1..iterations {
        engine.apply(buffer, buffer_shape.clone());
    }
}
234
/// Shape of the structuring element built by [`create_structuring_element`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelShape {
    /// Rectangular kernel: every pixel belongs to the element.
    Rect,
    /// Cross shaped kernel: the full row and the full column passing through the anchor.
    Cross,
    /// Ellipse shaped kernel: an ellipse inscribed in the kernel rectangle.
    Ellipse,
}
245
246/// Create a structuring element tensor for use with morphology ops
247pub fn create_structuring_element<B: Backend>(
248    shape: KernelShape,
249    ksize: Size,
250    anchor: Option<Point>,
251    device: &B::Device,
252) -> Tensor<B, 2, Bool> {
253    fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {
254        let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));
255        let mut r = 0;
256        let mut c = 0;
257        let mut inv_r2 = 0.0;
258
259        if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {
260            return vec![true; ksize.height * ksize.width];
261        }
262
263        if shape == KernelShape::Ellipse {
264            r = ksize.height / 2;
265            c = ksize.width / 2;
266            inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }
267        }
268
269        let mut elem = vec![false; ksize.height * ksize.width];
270
271        for i in 0..ksize.height {
272            let mut j1 = 0;
273            let mut j2 = 0;
274            if shape == KernelShape::Cross && i == anchor.y {
275                j2 = ksize.width;
276            } else if shape == KernelShape::Cross {
277                j1 = anchor.x;
278                j2 = j1 + 1;
279            } else {
280                let dy = i as isize - r as isize;
281                if dy.abs() <= r as isize {
282                    let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())
283                        .round() as isize;
284                    j1 = (c as isize - dx).max(0) as usize;
285                    j2 = (c + dx as usize + 1).min(ksize.width);
286                }
287            }
288
289            for j in j1..j2 {
290                elem[i * ksize.width + j] = true;
291            }
292        }
293        elem
294    }
295
296    let elem = create_kernel(shape, ksize, anchor);
297
298    let data = TensorData::new(elem, [ksize.height, ksize.width]);
299    Tensor::from_data(data, device)
300}