1use std::fmt::Debug;
2
3use burn_tensor::{
4 BasicOps, Bool, DType, Element, Shape, Tensor, TensorData,
5 backend::Backend,
6 cast::ToElement,
7 ops::BoolTensor,
8 quantization::{QuantizationScheme, QuantizationType},
9};
10use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};
11use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};
12use macerator::{Simd, VOrd};
13
14use crate::{BorderType, MorphOptions, Point, Size};
15
16use super::MinMax;
17
18mod filter;
19mod filter_engine;
20
/// Morphological operation applied by [`morph`].
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
pub enum MorphOp {
    /// Erosion: filtered with `MinOp`, i.e. the local minimum under the kernel.
    Erode,
    /// Dilation: filtered with `MaxOp`, i.e. the local maximum under the kernel.
    Dilate,
}
28
29pub enum MorphKernel<B: Element> {
30 Rect {
31 size: Size,
32 anchor: Point,
33 },
34 Other {
35 kernel: Vec<B>,
36 size: Size,
37 anchor: Point,
38 },
39}
40
41pub fn morph<B: Backend, K: BasicOps<B>>(
42 input: Tensor<B, 3, K>,
43 kernel: BoolTensor<B>,
44 op: MorphOp,
45 opts: MorphOptions<B, K>,
46) -> Tensor<B, 3, K> {
47 let device = input.device();
48
49 let kernel = Tensor::<B, 2, Bool>::new(kernel);
50 let kshape = kernel.shape().dims();
51 let [kh, kw] = kshape;
52
53 let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();
54 let is_rect = kernel.iter().all(|it| it.to_bool());
55 let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));
56 let iter = opts.iterations;
57 let btype = opts.border_type;
58 let bvalue = opts.border_value.map(|it| it.into_data());
59
60 let size = Size::new(kw, kh);
61 let kernel = if is_rect {
62 MorphKernel::Rect { size, anchor }
63 } else {
64 MorphKernel::Other {
65 kernel,
66 size,
67 anchor,
68 }
69 };
70
71 let shape = input.shape();
72 let data = input.into_data();
73 match data.dtype {
74 DType::F64 => {
75 morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)
76 }
77 DType::F32 | DType::Flex32 => {
78 morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)
79 }
80 DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(
81 data.convert::<f32>(),
82 shape,
83 kernel,
84 op,
85 iter,
86 btype,
87 bvalue,
88 &device,
89 ),
90 DType::I64 => {
91 morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)
92 }
93 DType::I32 => {
94 morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)
95 }
96 DType::I16 => {
97 morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)
98 }
99 DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),
100 DType::U64 => {
101 morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)
102 }
103 DType::U32 => {
104 morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)
105 }
106 DType::U16 => {
107 morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)
108 }
109 DType::U8 => morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
110 DType::Bool => morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device),
111 DType::QFloat(scheme) => match scheme {
112 QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
113 morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device)
114 }
115 },
116 }
117}
118
119#[allow(clippy::too_many_arguments)]
120fn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(
121 mut input: TensorData,
122 shape: Shape,
123 kernel: MorphKernel<B::BoolElem>,
124 op: MorphOp,
125 iter: usize,
126 btype: BorderType,
127 bvalue: Option<TensorData>,
128 device: &B::Device,
129) -> Tensor<B, 3, K> {
130 let data = input.as_mut_slice::<T>().unwrap();
131 let bvalue = border_value(btype, bvalue, op, &shape);
132 run_morph(data, shape, kernel, op, iter, btype, &bvalue);
133 Tensor::from_data(input, device)
134}
135
136#[allow(clippy::too_many_arguments)]
137fn morph_bool<B: Backend, K: BasicOps<B>>(
138 mut input: TensorData,
139 shape: Shape,
140 kernel: MorphKernel<B::BoolElem>,
141 op: MorphOp,
142 iter: usize,
143 btype: BorderType,
144 bvalue: Option<TensorData>,
145 device: &B::Device,
146) -> Tensor<B, 3, K> {
147 let data = input.as_mut_slice::<bool>().unwrap();
148 let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };
150 let bvalue = border_value(btype, bvalue, op, &shape);
151 run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);
152 Tensor::from_data(input, device)
153}
154
155fn border_value<T: Element>(
156 btype: BorderType,
157 bvalue: Option<TensorData>,
158 op: MorphOp,
159 shape: &Shape,
160) -> Vec<T> {
161 let [_, _, ch] = shape.dims();
162 match (btype, bvalue) {
163 (BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),
164 (BorderType::Constant, None) => match op {
165 MorphOp::Erode => vec![T::MAX; ch],
166 MorphOp::Dilate => vec![T::MIN; ch],
167 },
168 _ => vec![],
169 }
170}
171
172fn run_morph<T: VOrd + MinMax + Element, B: Element>(
173 input: &mut [T],
174 shape: Shape,
175 kernel: MorphKernel<B>,
176 op: MorphOp,
177 iter: usize,
178 btype: BorderType,
179 bvalue: &[T],
180) {
181 match op {
182 MorphOp::Erode => {
183 let filter = filter::<T, MinOp, B>(kernel);
184 dispatch_morph(input, shape, filter, btype, bvalue, iter);
185 }
186 MorphOp::Dilate => {
187 let filter = filter::<T, MaxOp, B>(kernel);
188 dispatch_morph(input, shape, filter, btype, bvalue, iter);
189 }
190 };
191}
192
193fn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(
194 kernel: MorphKernel<B>,
195) -> Filter<T, Op> {
196 match kernel {
197 MorphKernel::Rect { size, anchor } => {
198 let row_filter = RowFilter::new(size.width, anchor.x);
199 let col_filter = ColFilter::new(size.height, anchor.y);
200 Filter::Separable {
201 row_filter,
202 col_filter,
203 }
204 }
205 MorphKernel::Other {
206 kernel,
207 size,
208 anchor,
209 } => {
210 let filter = Filter2D::new(&kernel, size, anchor);
211 Filter::Fallback(filter)
212 }
213 }
214}
215
216#[inline(always)]
217#[allow(clippy::too_many_arguments)]
218#[macerator::with_simd]
219fn dispatch_morph<
220 'a,
221 S: Simd,
222 T: VOrd + MinMax + Debug,
223 Op: MorphOperator<T> + VecMorphOperator<T>,
224>(
225 buffer: &'a mut [T],
226 buffer_shape: Shape,
227 filter: filter_engine::Filter<T, Op>,
228 border_type: BorderType,
229 border_value: &'a [T],
230 iterations: usize,
231) where
232 'a: 'a,
233{
234 let [_, _, ch] = buffer_shape.dims();
235 let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);
236 engine.apply(buffer, buffer_shape.clone());
237 for _ in 1..iterations {
238 engine.apply(buffer, buffer_shape.clone());
239 }
240}
241
/// Shape of a structuring element built by [`create_structuring_element`].
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
pub enum KernelShape {
    /// Every element set.
    Rect,
    /// The anchor row and the anchor column set.
    Cross,
    /// Each row spans an ellipse fitted to the kernel rectangle.
    Ellipse,
}
252
253pub fn create_structuring_element<B: Backend>(
255 shape: KernelShape,
256 ksize: Size,
257 anchor: Option<Point>,
258 device: &B::Device,
259) -> Tensor<B, 2, Bool> {
260 fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {
261 let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));
262 let mut r = 0;
263 let mut c = 0;
264 let mut inv_r2 = 0.0;
265
266 if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {
267 return vec![true; ksize.height * ksize.width];
268 }
269
270 if shape == KernelShape::Ellipse {
271 r = ksize.height / 2;
272 c = ksize.width / 2;
273 inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }
274 }
275
276 let mut elem = vec![false; ksize.height * ksize.width];
277
278 for i in 0..ksize.height {
279 let mut j1 = 0;
280 let mut j2 = 0;
281 if shape == KernelShape::Cross && i == anchor.y {
282 j2 = ksize.width;
283 } else if shape == KernelShape::Cross {
284 j1 = anchor.x;
285 j2 = j1 + 1;
286 } else {
287 let dy = i as isize - r as isize;
288 if dy.abs() <= r as isize {
289 let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())
290 .round() as isize;
291 j1 = (c as isize - dx).max(0) as usize;
292 j2 = (c + dx as usize + 1).min(ksize.width);
293 }
294 }
295
296 for j in j1..j2 {
297 elem[i * ksize.width + j] = true;
298 }
299 }
300 elem
301 }
302
303 let elem = create_kernel(shape, ksize, anchor);
304
305 let data = TensorData::new(elem, [ksize.height, ksize.width]);
306 Tensor::from_data(data, device)
307}