1use crate::{
2 Point,
3 backends::cpu::{self, MorphOp, morph},
4};
5use bon::Builder;
6use burn_tensor::{
7 Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
8 backend::Backend,
9 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
10};
11
/// Connectivity mode accepted by the connected-components operations
/// (`BoolVisionOps::connected_components` and
/// `BoolVisionOps::connected_components_with_stats`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Connectivity {
    /// Four-way connectivity.
    ///
    /// NOTE(review): presumably only edge-adjacent (up/down/left/right) pixels count
    /// as neighbours — confirm against the `backends::cpu` implementation.
    Four,
    /// Eight-way connectivity.
    ///
    /// NOTE(review): presumably diagonal neighbours are included as well — confirm
    /// against the `backends::cpu` implementation.
    Eight,
}
20
/// Switches passed to `BoolVisionOps::connected_components_with_stats`.
///
/// The `Default` implementation and the `none()` / `all()` presets are defined
/// further down in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectedStatsOptions {
    /// NOTE(review): presumably toggles the bounding-box statistics
    /// (`top` / `left` / `right` / `bottom`) — confirm against the backend.
    pub bounds_enabled: bool,
    /// NOTE(review): presumably toggles computing `max_label` — confirm against the backend.
    pub max_label_enabled: bool,
    /// NOTE(review): presumably requests labels to be renumbered without gaps —
    /// confirm against the backend.
    pub compact_labels: bool,
}
34
35#[derive(Clone, Debug, Builder)]
37pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
38 pub anchor: Option<Point>,
40 #[builder(default = 1)]
42 pub iterations: usize,
43 #[builder(default)]
45 pub border_type: BorderType,
46 pub border_value: Option<Tensor<B, 1, K>>,
48}
49
50impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
51 fn default() -> Self {
52 Self {
53 anchor: Default::default(),
54 iterations: 1,
55 border_type: Default::default(),
56 border_value: Default::default(),
57 }
58 }
59}
60
/// Border handling mode selected through `MorphOptions::border_type`.
///
/// NOTE(review): the variant names follow the usual OpenCV-style border modes; the
/// per-variant notes below are assumptions to confirm against the morph implementation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BorderType {
    /// The default mode.
    ///
    /// Presumably pads with a constant value (see `MorphOptions::border_value`) — confirm.
    #[default]
    Constant,
    /// Presumably repeats the outermost pixel (`aaa|abcd|ddd`) — confirm.
    Replicate,
    /// Presumably mirrors including the edge pixel (`cba|abcd|dcb`) — confirm.
    Reflect,
    /// Presumably mirrors excluding the edge pixel (`dcb|abcd|cba`) — confirm.
    Reflect101,
    /// Presumably wraps around to the opposite side (`bcd|abcd|abc`) — confirm.
    Wrap,
}
77
78#[derive(Clone, Debug)]
82pub struct ConnectedStats<B: Backend> {
83 pub area: Tensor<B, 1, Int>,
85 pub top: Tensor<B, 1, Int>,
87 pub left: Tensor<B, 1, Int>,
89 pub right: Tensor<B, 1, Int>,
91 pub bottom: Tensor<B, 1, Int>,
93 pub max_label: Tensor<B, 1, Int>,
95}
96
97pub struct ConnectedStatsPrimitive<B: Backend> {
99 pub area: IntTensor<B>,
101 pub left: IntTensor<B>,
103 pub top: IntTensor<B>,
105 pub right: IntTensor<B>,
107 pub bottom: IntTensor<B>,
109 pub max_label: IntTensor<B>,
111}
112
113impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
114 fn from(value: ConnectedStatsPrimitive<B>) -> Self {
115 ConnectedStats {
116 area: Tensor::from_primitive(value.area),
117 top: Tensor::from_primitive(value.top),
118 left: Tensor::from_primitive(value.left),
119 right: Tensor::from_primitive(value.right),
120 bottom: Tensor::from_primitive(value.bottom),
121 max_label: Tensor::from_primitive(value.max_label),
122 }
123 }
124}
125
126impl<B: Backend> ConnectedStats<B> {
127 pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
129 ConnectedStatsPrimitive {
130 area: self.area.into_primitive(),
131 top: self.top.into_primitive(),
132 left: self.left.into_primitive(),
133 right: self.right.into_primitive(),
134 bottom: self.bottom.into_primitive(),
135 max_label: self.max_label.into_primitive(),
136 }
137 }
138}
139
140impl Default for ConnectedStatsOptions {
141 fn default() -> Self {
142 Self::all()
143 }
144}
145
146impl ConnectedStatsOptions {
147 pub fn none() -> Self {
149 Self {
150 bounds_enabled: false,
151 max_label_enabled: false,
152 compact_labels: false,
153 }
154 }
155
156 pub fn all() -> Self {
158 Self {
159 bounds_enabled: true,
160 max_label_enabled: true,
161 compact_labels: true,
162 }
163 }
164}
165
166pub trait VisionBackend:
168 BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
169{
170}
171
172pub trait BoolVisionOps: Backend {
174 fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
179 cpu::connected_components::<Self>(img, connectivity)
180 }
181
182 fn connected_components_with_stats(
188 img: BoolTensor<Self>,
189 connectivity: Connectivity,
190 opts: ConnectedStatsOptions,
191 ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
192 cpu::connected_components_with_stats(img, connectivity, opts)
193 }
194
195 fn bool_erode(
197 input: BoolTensor<Self>,
198 kernel: BoolTensor<Self>,
199 opts: MorphOptions<Self, Bool>,
200 ) -> BoolTensor<Self> {
201 let input = Tensor::<Self, 3, Bool>::from_primitive(input);
202 morph(input, kernel, MorphOp::Erode, opts).into_primitive()
203 }
204
205 fn bool_dilate(
207 input: BoolTensor<Self>,
208 kernel: BoolTensor<Self>,
209 opts: MorphOptions<Self, Bool>,
210 ) -> BoolTensor<Self> {
211 let input = Tensor::<Self, 3, Bool>::from_primitive(input);
212 morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
213 }
214}
215
216pub trait IntVisionOps: Backend {
218 fn int_erode(
220 input: IntTensor<Self>,
221 kernel: BoolTensor<Self>,
222 opts: MorphOptions<Self, Int>,
223 ) -> IntTensor<Self> {
224 let input = Tensor::<Self, 3, Int>::from_primitive(input);
225 morph(input, kernel, MorphOp::Erode, opts).into_primitive()
226 }
227
228 fn int_dilate(
230 input: IntTensor<Self>,
231 kernel: BoolTensor<Self>,
232 opts: MorphOptions<Self, Int>,
233 ) -> IntTensor<Self> {
234 let input = Tensor::<Self, 3, Int>::from_primitive(input);
235 morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
236 }
237}
238
239pub trait FloatVisionOps: Backend {
241 fn float_erode(
243 input: FloatTensor<Self>,
244 kernel: BoolTensor<Self>,
245 opts: MorphOptions<Self, Float>,
246 ) -> FloatTensor<Self> {
247 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
248
249 morph(input, kernel, MorphOp::Erode, opts)
250 .into_primitive()
251 .tensor()
252 }
253
254 fn float_dilate(
256 input: FloatTensor<Self>,
257 kernel: BoolTensor<Self>,
258 opts: MorphOptions<Self, Float>,
259 ) -> FloatTensor<Self> {
260 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
261 morph(input, kernel, MorphOp::Dilate, opts)
262 .into_primitive()
263 .tensor()
264 }
265}
266
267pub trait QVisionOps: Backend {
269 fn q_erode(
271 input: QuantizedTensor<Self>,
272 kernel: BoolTensor<Self>,
273 opts: MorphOptions<Self, Float>,
274 ) -> QuantizedTensor<Self> {
275 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
276 match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
277 TensorPrimitive::QFloat(tensor) => tensor,
278 _ => unreachable!(),
279 }
280 }
281
282 fn q_dilate(
284 input: QuantizedTensor<Self>,
285 kernel: BoolTensor<Self>,
286 opts: MorphOptions<Self, Float>,
287 ) -> QuantizedTensor<Self> {
288 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
289 match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
290 TensorPrimitive::QFloat(tensor) => tensor,
291 _ => unreachable!(),
292 }
293 }
294}