1use crate::{
2 Point,
3 backends::cpu::{self, MorphOp, morph},
4};
5use bon::Builder;
6use burn_tensor::{
7 Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
8 backend::Backend,
9 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
10};
11
/// Pixel neighbourhood used when labelling connected components
/// (see [`BoolVisionOps::connected_components`]).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Connectivity {
    /// 4-connectivity — presumably only horizontal and vertical neighbours are adjacent.
    Four,
    /// 8-connectivity — presumably diagonal neighbours are adjacent as well.
    Eight,
}
20
/// Selects which optional outputs
/// [`BoolVisionOps::connected_components_with_stats`] should produce.
///
/// The `Default` impl enables everything (same as `ConnectedStatsOptions::all()`).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct ConnectedStatsOptions {
    /// Whether to compute bounding boxes (presumably the `top`/`left`/`right`/`bottom`
    /// statistics of [`ConnectedStats`]).
    pub bounds_enabled: bool,
    /// Whether to compute the `max_label` statistic (presumably).
    pub max_label_enabled: bool,
    /// Whether to relabel components into a compact range — presumably consecutive labels
    /// without gaps; confirm against the backend.
    pub compact_labels: bool,
}
34
35#[derive(Clone, Debug, Builder)]
37pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
38 pub anchor: Option<Point>,
40 #[builder(default = 1)]
42 pub iterations: usize,
43 #[builder(default)]
45 pub border_type: BorderType,
46 pub border_value: Option<Tensor<B, 1, K>>,
48}
49
50impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
51 fn default() -> Self {
52 Self {
53 anchor: Default::default(),
54 iterations: 1,
55 border_type: Default::default(),
56 border_value: Default::default(),
57 }
58 }
59}
60
/// How a morphology op treats pixels that fall outside the image.
///
/// The variant names mirror OpenCV's border modes. NOTE(review): the exact padding behaviour
/// is defined by the morphology backend, so confirm it matches the patterns sketched below.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub enum BorderType {
    /// Pad with a constant value (presumably `MorphOptions::border_value`). This is the
    /// default.
    #[default]
    Constant,
    /// Presumably repeats the edge pixel: `aaa|abcd|ddd`.
    Replicate,
    /// Presumably mirrors including the edge pixel: `cba|abcd|dcb`.
    Reflect,
    /// Presumably mirrors excluding the edge pixel: `dcb|abcd|cba`.
    Reflect101,
    /// Presumably wraps around to the opposite edge: `bcd|abcd|abc`.
    Wrap,
}
77
78#[derive(Clone, Debug)]
82pub struct ConnectedStats<B: Backend> {
83 pub area: Tensor<B, 1, Int>,
85 pub top: Tensor<B, 1, Int>,
87 pub left: Tensor<B, 1, Int>,
89 pub right: Tensor<B, 1, Int>,
91 pub bottom: Tensor<B, 1, Int>,
93 pub max_label: Tensor<B, 1, Int>,
95}
96
97pub struct ConnectedStatsPrimitive<B: Backend> {
99 pub area: IntTensor<B>,
101 pub left: IntTensor<B>,
103 pub top: IntTensor<B>,
105 pub right: IntTensor<B>,
107 pub bottom: IntTensor<B>,
109 pub max_label: IntTensor<B>,
111}
112
113impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
114 fn from(value: ConnectedStatsPrimitive<B>) -> Self {
115 ConnectedStats {
116 area: Tensor::from_primitive(value.area),
117 top: Tensor::from_primitive(value.top),
118 left: Tensor::from_primitive(value.left),
119 right: Tensor::from_primitive(value.right),
120 bottom: Tensor::from_primitive(value.bottom),
121 max_label: Tensor::from_primitive(value.max_label),
122 }
123 }
124}
125
126impl<B: Backend> ConnectedStats<B> {
127 pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
129 ConnectedStatsPrimitive {
130 area: self.area.into_primitive(),
131 top: self.top.into_primitive(),
132 left: self.left.into_primitive(),
133 right: self.right.into_primitive(),
134 bottom: self.bottom.into_primitive(),
135 max_label: self.max_label.into_primitive(),
136 }
137 }
138}
139
140impl Default for ConnectedStatsOptions {
141 fn default() -> Self {
142 Self::all()
143 }
144}
145
146impl ConnectedStatsOptions {
147 pub fn none() -> Self {
149 Self {
150 bounds_enabled: false,
151 max_label_enabled: false,
152 compact_labels: false,
153 }
154 }
155
156 pub fn all() -> Self {
158 Self {
159 bounds_enabled: true,
160 max_label_enabled: true,
161 compact_labels: true,
162 }
163 }
164}
165
/// Parameters for [`FloatVisionOps::nms`] (non-maximum suppression).
///
/// Defaults: `iou_threshold = 0.5`, `score_threshold = 0.0`, `max_output_boxes = 0`.
#[derive(Clone, Copy, Debug)]
pub struct NmsOptions {
    /// Overlap threshold used for suppression — presumably a box whose IoU with an
    /// already-kept box exceeds this is dropped.
    pub iou_threshold: f32,
    /// Score cut-off — presumably boxes scoring below it are discarded before suppression.
    pub score_threshold: f32,
    /// Cap on the number of boxes returned; `0` presumably means "no limit" — confirm
    /// against `cpu::nms`.
    pub max_output_boxes: usize,
}
178
179impl Default for NmsOptions {
180 fn default() -> Self {
181 Self {
182 iou_threshold: 0.5,
183 score_threshold: 0.0,
184 max_output_boxes: 0,
185 }
186 }
187}
188
189pub trait VisionBackend:
191 BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
192{
193}
194
195pub trait BoolVisionOps: Backend {
197 fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
202 cpu::connected_components::<Self>(img, connectivity)
203 }
204
205 fn connected_components_with_stats(
211 img: BoolTensor<Self>,
212 connectivity: Connectivity,
213 opts: ConnectedStatsOptions,
214 ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
215 cpu::connected_components_with_stats(img, connectivity, opts)
216 }
217
218 fn bool_erode(
220 input: BoolTensor<Self>,
221 kernel: BoolTensor<Self>,
222 opts: MorphOptions<Self, Bool>,
223 ) -> BoolTensor<Self> {
224 let input = Tensor::<Self, 3, Bool>::from_primitive(input);
225 morph(input, kernel, MorphOp::Erode, opts).into_primitive()
226 }
227
228 fn bool_dilate(
230 input: BoolTensor<Self>,
231 kernel: BoolTensor<Self>,
232 opts: MorphOptions<Self, Bool>,
233 ) -> BoolTensor<Self> {
234 let input = Tensor::<Self, 3, Bool>::from_primitive(input);
235 morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
236 }
237}
238
239pub trait IntVisionOps: Backend {
241 fn int_erode(
243 input: IntTensor<Self>,
244 kernel: BoolTensor<Self>,
245 opts: MorphOptions<Self, Int>,
246 ) -> IntTensor<Self> {
247 let input = Tensor::<Self, 3, Int>::from_primitive(input);
248 morph(input, kernel, MorphOp::Erode, opts).into_primitive()
249 }
250
251 fn int_dilate(
253 input: IntTensor<Self>,
254 kernel: BoolTensor<Self>,
255 opts: MorphOptions<Self, Int>,
256 ) -> IntTensor<Self> {
257 let input = Tensor::<Self, 3, Int>::from_primitive(input);
258 morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
259 }
260}
261
262pub trait FloatVisionOps: Backend {
264 fn float_erode(
266 input: FloatTensor<Self>,
267 kernel: BoolTensor<Self>,
268 opts: MorphOptions<Self, Float>,
269 ) -> FloatTensor<Self> {
270 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
271
272 morph(input, kernel, MorphOp::Erode, opts)
273 .into_primitive()
274 .tensor()
275 }
276
277 fn float_dilate(
279 input: FloatTensor<Self>,
280 kernel: BoolTensor<Self>,
281 opts: MorphOptions<Self, Float>,
282 ) -> FloatTensor<Self> {
283 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
284 morph(input, kernel, MorphOp::Dilate, opts)
285 .into_primitive()
286 .tensor()
287 }
288
289 fn nms(
303 boxes: FloatTensor<Self>,
304 scores: FloatTensor<Self>,
305 options: NmsOptions,
306 ) -> IntTensor<Self> {
307 let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
308 let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
309 cpu::nms::<Self>(boxes, scores, options).into_primitive()
310 }
311}
312
313pub trait QVisionOps: Backend {
315 fn q_erode(
317 input: QuantizedTensor<Self>,
318 kernel: BoolTensor<Self>,
319 opts: MorphOptions<Self, Float>,
320 ) -> QuantizedTensor<Self> {
321 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
322 match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
323 TensorPrimitive::QFloat(tensor) => tensor,
324 _ => unreachable!(),
325 }
326 }
327
328 fn q_dilate(
330 input: QuantizedTensor<Self>,
331 kernel: BoolTensor<Self>,
332 opts: MorphOptions<Self, Float>,
333 ) -> QuantizedTensor<Self> {
334 let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
335 match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
336 TensorPrimitive::QFloat(tensor) => tensor,
337 _ => unreachable!(),
338 }
339 }
340}