burn_vision/backends/cpu/morphology/mod.rs

1use std::fmt::Debug;
2
3use burn_tensor::{
4    BasicOps, Bool, DType, Element, Shape, Tensor, TensorData,
5    backend::Backend,
6    cast::ToElement,
7    ops::BoolTensor,
8    quantization::{QuantizationScheme, QuantizationType},
9};
10use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};
11use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};
12use macerator::{Simd, VOrd};
13
14use crate::{BorderType, MorphOptions, Point, Size};
15
16use super::MinMax;
17
18mod filter;
19mod filter_engine;
20
/// A basic morphological operation applied with a structuring element.
///
/// TODO: Implement composite ops
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MorphOp {
    /// Erosion: each output pixel is the minimum over the kernel footprint.
    Erode,
    /// Dilation: each output pixel is the maximum over the kernel footprint.
    Dilate,
}
28
29pub enum MorphKernel<B: Element> {
30    Rect {
31        size: Size,
32        anchor: Point,
33    },
34    Other {
35        kernel: Vec<B>,
36        size: Size,
37        anchor: Point,
38    },
39}
40
41pub fn morph<B: Backend, K: BasicOps<B>>(
42    input: Tensor<B, 3, K>,
43    kernel: BoolTensor<B>,
44    op: MorphOp,
45    opts: MorphOptions<B, K>,
46) -> Tensor<B, 3, K> {
47    let device = input.device();
48
49    let kernel = Tensor::<B, 2, Bool>::new(kernel);
50    let kshape = kernel.shape().dims();
51    let [kh, kw] = kshape;
52
53    let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();
54    let is_rect = kernel.iter().all(|it| it.to_bool());
55    let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));
56    let iter = opts.iterations;
57    let btype = opts.border_type;
58    let bvalue = opts.border_value.map(|it| it.into_data());
59
60    let size = Size::new(kw, kh);
61    let kernel = if is_rect {
62        MorphKernel::Rect { size, anchor }
63    } else {
64        MorphKernel::Other {
65            kernel,
66            size,
67            anchor,
68        }
69    };
70
71    let shape = input.shape();
72    let data = input.into_data();
73    match data.dtype {
74        DType::F64 => {
75            morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)
76        }
77        DType::F32 | DType::Flex32 => {
78            morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)
79        }
80        DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(
81            data.convert::<f32>(),
82            shape,
83            kernel,
84            op,
85            iter,
86            btype,
87            bvalue,
88            &device,
89        ),
90        DType::I64 => {
91            morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)
92        }
93        DType::I32 => {
94            morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)
95        }
96        DType::I16 => {
97            morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)
98        }
99        DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),
100        DType::U64 => {
101            morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)
102        }
103        DType::U32 => {
104            morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)
105        }
106        DType::U16 => {
107            morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)
108        }
109        DType::U8 => morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
110        DType::Bool => morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device),
111        DType::QFloat(scheme) => match scheme {
112            QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
113                morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device)
114            }
115        },
116    }
117}
118
119#[allow(clippy::too_many_arguments)]
120fn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(
121    mut input: TensorData,
122    shape: Shape,
123    kernel: MorphKernel<B::BoolElem>,
124    op: MorphOp,
125    iter: usize,
126    btype: BorderType,
127    bvalue: Option<TensorData>,
128    device: &B::Device,
129) -> Tensor<B, 3, K> {
130    let data = input.as_mut_slice::<T>().unwrap();
131    let bvalue = border_value(btype, bvalue, op, &shape);
132    run_morph(data, shape, kernel, op, iter, btype, &bvalue);
133    Tensor::from_data(input, device)
134}
135
136#[allow(clippy::too_many_arguments)]
137fn morph_bool<B: Backend, K: BasicOps<B>>(
138    mut input: TensorData,
139    shape: Shape,
140    kernel: MorphKernel<B::BoolElem>,
141    op: MorphOp,
142    iter: usize,
143    btype: BorderType,
144    bvalue: Option<TensorData>,
145    device: &B::Device,
146) -> Tensor<B, 3, K> {
147    let data = input.as_mut_slice::<bool>().unwrap();
148    // SAFETY: Morph can't produce invalid boolean values
149    let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };
150    let bvalue = border_value(btype, bvalue, op, &shape);
151    run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);
152    Tensor::from_data(input, device)
153}
154
155fn border_value<T: Element>(
156    btype: BorderType,
157    bvalue: Option<TensorData>,
158    op: MorphOp,
159    shape: &Shape,
160) -> Vec<T> {
161    let [_, _, ch] = shape.dims();
162    match (btype, bvalue) {
163        (BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),
164        (BorderType::Constant, None) => match op {
165            MorphOp::Erode => vec![T::MAX; ch],
166            MorphOp::Dilate => vec![T::MIN; ch],
167        },
168        _ => vec![],
169    }
170}
171
172fn run_morph<T: VOrd + MinMax + Element, B: Element>(
173    input: &mut [T],
174    shape: Shape,
175    kernel: MorphKernel<B>,
176    op: MorphOp,
177    iter: usize,
178    btype: BorderType,
179    bvalue: &[T],
180) {
181    match op {
182        MorphOp::Erode => {
183            let filter = filter::<T, MinOp, B>(kernel);
184            dispatch_morph(input, shape, filter, btype, bvalue, iter);
185        }
186        MorphOp::Dilate => {
187            let filter = filter::<T, MaxOp, B>(kernel);
188            dispatch_morph(input, shape, filter, btype, bvalue, iter);
189        }
190    };
191}
192
193fn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(
194    kernel: MorphKernel<B>,
195) -> Filter<T, Op> {
196    match kernel {
197        MorphKernel::Rect { size, anchor } => {
198            let row_filter = RowFilter::new(size.width, anchor.x);
199            let col_filter = ColFilter::new(size.height, anchor.y);
200            Filter::Separable {
201                row_filter,
202                col_filter,
203            }
204        }
205        MorphKernel::Other {
206            kernel,
207            size,
208            anchor,
209        } => {
210            let filter = Filter2D::new(&kernel, size, anchor);
211            Filter::Fallback(filter)
212        }
213    }
214}
215
216#[inline(always)]
217#[allow(clippy::too_many_arguments)]
218#[macerator::with_simd]
219fn dispatch_morph<
220    'a,
221    S: Simd,
222    T: VOrd + MinMax + Debug,
223    Op: MorphOperator<T> + VecMorphOperator<T>,
224>(
225    buffer: &'a mut [T],
226    buffer_shape: Shape,
227    filter: filter_engine::Filter<T, Op>,
228    border_type: BorderType,
229    border_value: &'a [T],
230    iterations: usize,
231) where
232    'a: 'a,
233{
234    let [_, _, ch] = buffer_shape.dims();
235    let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);
236    engine.apply(buffer, buffer_shape.clone());
237    for _ in 1..iterations {
238        engine.apply(buffer, buffer_shape.clone());
239    }
240}
241
/// Shape of the structuring element built by [`create_structuring_element`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelShape {
    /// Every element of the kernel is set.
    Rect,
    /// Only the anchor's row and the anchor's column are set.
    Cross,
    /// The ellipse inscribed in the kernel rectangle is set.
    Ellipse,
}
252
253/// Create a structuring element tensor for use with morphology ops
254pub fn create_structuring_element<B: Backend>(
255    shape: KernelShape,
256    ksize: Size,
257    anchor: Option<Point>,
258    device: &B::Device,
259) -> Tensor<B, 2, Bool> {
260    fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {
261        let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));
262        let mut r = 0;
263        let mut c = 0;
264        let mut inv_r2 = 0.0;
265
266        if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {
267            return vec![true; ksize.height * ksize.width];
268        }
269
270        if shape == KernelShape::Ellipse {
271            r = ksize.height / 2;
272            c = ksize.width / 2;
273            inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }
274        }
275
276        let mut elem = vec![false; ksize.height * ksize.width];
277
278        for i in 0..ksize.height {
279            let mut j1 = 0;
280            let mut j2 = 0;
281            if shape == KernelShape::Cross && i == anchor.y {
282                j2 = ksize.width;
283            } else if shape == KernelShape::Cross {
284                j1 = anchor.x;
285                j2 = j1 + 1;
286            } else {
287                let dy = i as isize - r as isize;
288                if dy.abs() <= r as isize {
289                    let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())
290                        .round() as isize;
291                    j1 = (c as isize - dx).max(0) as usize;
292                    j2 = (c + dx as usize + 1).min(ksize.width);
293                }
294            }
295
296            for j in j1..j2 {
297                elem[i * ksize.width + j] = true;
298            }
299        }
300        elem
301    }
302
303    let elem = create_kernel(shape, ksize, anchor);
304
305    let data = TensorData::new(elem, [ksize.height, ksize.width]);
306    Tensor::from_data(data, device)
307}