1use std::fmt::Debug;
2
3use burn_tensor::{
4 BasicOps, Bool, DType, Element, Shape, Tensor, TensorData, backend::Backend, cast::ToElement,
5 ops::BoolTensor,
6};
7use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};
8use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};
9use macerator::{Simd, VOrd};
10
11use crate::{BorderType, MorphOptions, Point, Size};
12
13use super::MinMax;
14
15mod filter;
16mod filter_engine;
17
/// Morphological operation applied by [`morph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MorphOp {
    /// Erosion, implemented with the `MinOp` filter operator.
    Erode,
    /// Dilation, implemented with the `MaxOp` filter operator.
    Dilate,
}
25
26pub enum MorphKernel<B: Element> {
27 Rect {
28 size: Size,
29 anchor: Point,
30 },
31 Other {
32 kernel: Vec<B>,
33 size: Size,
34 anchor: Point,
35 },
36}
37
38pub fn morph<B: Backend, K: BasicOps<B>>(
39 input: Tensor<B, 3, K>,
40 kernel: BoolTensor<B>,
41 op: MorphOp,
42 opts: MorphOptions<B, K>,
43) -> Tensor<B, 3, K> {
44 let device = input.device();
45
46 let kernel = Tensor::<B, 2, Bool>::new(kernel);
47 let kshape = kernel.shape().dims();
48 let [kh, kw] = kshape;
49
50 let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();
51 let is_rect = kernel.iter().all(|it| it.to_bool());
52 let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));
53 let iter = opts.iterations;
54 let btype = opts.border_type;
55 let bvalue = opts.border_value.map(|it| it.into_data());
56
57 let size = Size::new(kw, kh);
58 let kernel = if is_rect {
59 MorphKernel::Rect { size, anchor }
60 } else {
61 MorphKernel::Other {
62 kernel,
63 size,
64 anchor,
65 }
66 };
67
68 let shape = input.shape();
69 let data = input.into_data();
70 match data.dtype {
71 DType::F64 => {
72 morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)
73 }
74 DType::F32 | DType::Flex32 => {
75 morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)
76 }
77 DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(
78 data.convert::<f32>(),
79 shape,
80 kernel,
81 op,
82 iter,
83 btype,
84 bvalue,
85 &device,
86 ),
87 DType::I64 => {
88 morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)
89 }
90 DType::I32 => {
91 morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)
92 }
93 DType::I16 => {
94 morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)
95 }
96 DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),
97 DType::U64 => {
98 morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)
99 }
100 DType::U32 => {
101 morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)
102 }
103 DType::U16 => {
104 morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)
105 }
106 DType::U8 => morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
107 DType::Bool => morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device),
108 DType::QFloat(_) => unimplemented!(),
109 }
110}
111
112#[allow(clippy::too_many_arguments)]
113fn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(
114 mut input: TensorData,
115 shape: Shape,
116 kernel: MorphKernel<B::BoolElem>,
117 op: MorphOp,
118 iter: usize,
119 btype: BorderType,
120 bvalue: Option<TensorData>,
121 device: &B::Device,
122) -> Tensor<B, 3, K> {
123 let data = input.as_mut_slice::<T>().unwrap();
124 let bvalue = border_value(btype, bvalue, op, &shape);
125 run_morph(data, shape, kernel, op, iter, btype, &bvalue);
126 Tensor::from_data(input, device)
127}
128
129#[allow(clippy::too_many_arguments)]
130fn morph_bool<B: Backend, K: BasicOps<B>>(
131 mut input: TensorData,
132 shape: Shape,
133 kernel: MorphKernel<B::BoolElem>,
134 op: MorphOp,
135 iter: usize,
136 btype: BorderType,
137 bvalue: Option<TensorData>,
138 device: &B::Device,
139) -> Tensor<B, 3, K> {
140 let data = input.as_mut_slice::<bool>().unwrap();
141 let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };
143 let bvalue = border_value(btype, bvalue, op, &shape);
144 run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);
145 Tensor::from_data(input, device)
146}
147
148fn border_value<T: Element>(
149 btype: BorderType,
150 bvalue: Option<TensorData>,
151 op: MorphOp,
152 shape: &Shape,
153) -> Vec<T> {
154 let [_, _, ch] = shape.dims();
155 match (btype, bvalue) {
156 (BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),
157 (BorderType::Constant, None) => match op {
158 MorphOp::Erode => vec![T::MAX; ch],
159 MorphOp::Dilate => vec![T::MIN; ch],
160 },
161 _ => vec![],
162 }
163}
164
165fn run_morph<T: VOrd + MinMax + Element, B: Element>(
166 input: &mut [T],
167 shape: Shape,
168 kernel: MorphKernel<B>,
169 op: MorphOp,
170 iter: usize,
171 btype: BorderType,
172 bvalue: &[T],
173) {
174 match op {
175 MorphOp::Erode => {
176 let filter = filter::<T, MinOp, B>(kernel);
177 dispatch_morph(input, shape, filter, btype, bvalue, iter);
178 }
179 MorphOp::Dilate => {
180 let filter = filter::<T, MaxOp, B>(kernel);
181 dispatch_morph(input, shape, filter, btype, bvalue, iter);
182 }
183 };
184}
185
186fn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(
187 kernel: MorphKernel<B>,
188) -> Filter<T, Op> {
189 match kernel {
190 MorphKernel::Rect { size, anchor } => {
191 let row_filter = RowFilter::new(size.width, anchor.x);
192 let col_filter = ColFilter::new(size.height, anchor.y);
193 Filter::Separable {
194 row_filter,
195 col_filter,
196 }
197 }
198 MorphKernel::Other {
199 kernel,
200 size,
201 anchor,
202 } => {
203 let filter = Filter2D::new(&kernel, size, anchor);
204 Filter::Fallback(filter)
205 }
206 }
207}
208
/// Applies `filter` in place to `buffer`, `iterations` times, using the SIMD type `S`
/// (presumably chosen by `#[macerator::with_simd]`; the macro's expansion is not visible here).
///
/// `buffer_shape` is the image shape; its last dimension is passed to the engine as the
/// channel count. `border_type` and `border_value` are forwarded to `FilterEngine::new`.
///
/// Note: the filter is applied once before the loop, so `iterations == 0` behaves the same
/// as `iterations == 1`.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
#[macerator::with_simd]
fn dispatch_morph<
    'a,
    S: Simd,
    T: VOrd + MinMax + Debug,
    Op: MorphOperator<T> + VecMorphOperator<T>,
>(
    buffer: &'a mut [T],
    buffer_shape: Shape,
    filter: filter_engine::Filter<T, Op>,
    border_type: BorderType,
    border_value: &'a [T],
    iterations: usize,
) where
    // NOTE(review): this bound is trivially true; presumably it exists for the benefit of
    // `macerator::with_simd` (e.g. needing a `where` clause) — confirm before removing.
    'a: 'a,
{
    let [_, _, ch] = buffer_shape.dims();
    // One engine is built and reused for every iteration.
    let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);
    engine.apply(buffer, buffer_shape.clone());
    // Remaining `iterations - 1` passes (none when `iterations <= 1`).
    for _ in 1..iterations {
        engine.apply(buffer, buffer_shape.clone());
    }
}
234
/// Shape of a structuring element built by [`create_structuring_element`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelShape {
    /// Every cell of the kernel is set.
    Rect,
    /// Only the anchor row and the anchor column are set.
    Cross,
    /// The cells of an ellipse inscribed in the kernel rectangle are set.
    Ellipse,
}
245
246pub fn create_structuring_element<B: Backend>(
248 shape: KernelShape,
249 ksize: Size,
250 anchor: Option<Point>,
251 device: &B::Device,
252) -> Tensor<B, 2, Bool> {
253 fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {
254 let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));
255 let mut r = 0;
256 let mut c = 0;
257 let mut inv_r2 = 0.0;
258
259 if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {
260 return vec![true; ksize.height * ksize.width];
261 }
262
263 if shape == KernelShape::Ellipse {
264 r = ksize.height / 2;
265 c = ksize.width / 2;
266 inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }
267 }
268
269 let mut elem = vec![false; ksize.height * ksize.width];
270
271 for i in 0..ksize.height {
272 let mut j1 = 0;
273 let mut j2 = 0;
274 if shape == KernelShape::Cross && i == anchor.y {
275 j2 = ksize.width;
276 } else if shape == KernelShape::Cross {
277 j1 = anchor.x;
278 j2 = j1 + 1;
279 } else {
280 let dy = i as isize - r as isize;
281 if dy.abs() <= r as isize {
282 let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())
283 .round() as isize;
284 j1 = (c as isize - dx).max(0) as usize;
285 j2 = (c + dx as usize + 1).min(ksize.width);
286 }
287 }
288
289 for j in j1..j2 {
290 elem[i * ksize.width + j] = true;
291 }
292 }
293 elem
294 }
295
296 let elem = create_kernel(shape, ksize, anchor);
297
298 let data = TensorData::new(elem, [ksize.height, ksize.width]);
299 Tensor::from_data(data, device)
300}