Skip to main content

burn_vision/ops/
base.rs

1use crate::{
2    Point,
3    backends::cpu::{self, MorphOp, morph},
4};
5use bon::Builder;
6use burn_tensor::{
7    Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
8    backend::Backend,
9    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
10};
11
/// Pixel adjacency rule used by connected components labeling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Connectivity {
    /// 4-connectivity: pixels are only connected along the cardinal directions.
    Four,
    /// 8-connectivity: a pixel is connected if any of its 8 surrounding pixels is in the
    /// foreground.
    Eight,
}
20
/// Selects which statistics `connected_components_with_stats` should compute.
///
/// At the moment only the GPU implementation consults these flags, so that it can skip the
/// atomic operations needed for stats nobody asked for.
///
/// Any stat that is disabled is aliased to the labels tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectedStatsOptions {
    /// Enables collection of the bounding boxes.
    pub bounds_enabled: bool,
    /// Enables collection of the max label.
    pub max_label_enabled: bool,
    /// Requires labels to be contiguous, starting at 1.
    pub compact_labels: bool,
}
34
35/// Options for morphology ops
36#[derive(Clone, Debug, Builder)]
37pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
38    /// Anchor position within the kernel. Defaults to the center.
39    pub anchor: Option<Point>,
40    /// Number of iterations to apply
41    #[builder(default = 1)]
42    pub iterations: usize,
43    /// Border type. Default: constant based on operation
44    #[builder(default)]
45    pub border_type: BorderType,
46    /// Value of each channel for constant border type
47    pub border_value: Option<Tensor<B, 1, K>>,
48}
49
50impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
51    fn default() -> Self {
52        Self {
53            anchor: Default::default(),
54            iterations: 1,
55            border_type: Default::default(),
56            border_value: Default::default(),
57        }
58    }
59}
60
/// Border handling mode for morphology operations.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BorderType {
    /// Constant border using a per-channel value. When no value is supplied, one is chosen
    /// based on the morph op. This is the default mode.
    #[default]
    Constant,
    /// Replicates the first/last element.
    Replicate,
    /// Reflects the start/end elements.
    Reflect,
    /// Reflects the start/end elements, ignoring the first/last element.
    Reflect101,
    /// Wrapping border mode. Not supported for erode/dilate.
    Wrap,
}
77
78/// Stats collected by the connected components analysis
79///
80/// Disabled analyses may be aliased to labels
81#[derive(Clone, Debug)]
82pub struct ConnectedStats<B: Backend> {
83    /// Total area of each component
84    pub area: Tensor<B, 1, Int>,
85    /// Topmost y coordinate in the component
86    pub top: Tensor<B, 1, Int>,
87    /// Leftmost x coordinate in the component
88    pub left: Tensor<B, 1, Int>,
89    /// Rightmost x coordinate in the component
90    pub right: Tensor<B, 1, Int>,
91    /// Bottommost y coordinate in the component
92    pub bottom: Tensor<B, 1, Int>,
93    /// Scalar tensor of the max label
94    pub max_label: Tensor<B, 1, Int>,
95}
96
97/// Primitive version of [`ConnectedStats`], to be returned by the backend
98pub struct ConnectedStatsPrimitive<B: Backend> {
99    /// Total area of each component
100    pub area: IntTensor<B>,
101    /// Leftmost x coordinate in the component
102    pub left: IntTensor<B>,
103    /// Topmost y coordinate in the component
104    pub top: IntTensor<B>,
105    /// Rightmost x coordinate in the component
106    pub right: IntTensor<B>,
107    /// Bottommost y coordinate in the component
108    pub bottom: IntTensor<B>,
109    /// Scalar tensor of the max label
110    pub max_label: IntTensor<B>,
111}
112
113impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
114    fn from(value: ConnectedStatsPrimitive<B>) -> Self {
115        ConnectedStats {
116            area: Tensor::from_primitive(value.area),
117            top: Tensor::from_primitive(value.top),
118            left: Tensor::from_primitive(value.left),
119            right: Tensor::from_primitive(value.right),
120            bottom: Tensor::from_primitive(value.bottom),
121            max_label: Tensor::from_primitive(value.max_label),
122        }
123    }
124}
125
126impl<B: Backend> ConnectedStats<B> {
127    /// Convert a connected stats into the corresponding primitive
128    pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
129        ConnectedStatsPrimitive {
130            area: self.area.into_primitive(),
131            top: self.top.into_primitive(),
132            left: self.left.into_primitive(),
133            right: self.right.into_primitive(),
134            bottom: self.bottom.into_primitive(),
135            max_label: self.max_label.into_primitive(),
136        }
137    }
138}
139
140impl Default for ConnectedStatsOptions {
141    fn default() -> Self {
142        Self::all()
143    }
144}
145
146impl ConnectedStatsOptions {
147    /// Don't collect any stats
148    pub fn none() -> Self {
149        Self {
150            bounds_enabled: false,
151            max_label_enabled: false,
152            compact_labels: false,
153        }
154    }
155
156    /// Collect all stats
157    pub fn all() -> Self {
158        Self {
159            bounds_enabled: true,
160            max_label_enabled: true,
161            compact_labels: true,
162        }
163    }
164}
165
/// Options controlling Non-Maximum Suppression (NMS).
#[derive(Debug, Clone, Copy)]
pub struct NmsOptions {
    /// IoU threshold used for suppression (default: 0.5).
    /// A box is suppressed when its IoU with a higher-scoring box is > this threshold.
    pub iou_threshold: f32,
    /// Score threshold applied before NMS (default: 0.0, i.e., no filtering).
    /// Boxes whose score is < `score_threshold` are discarded.
    pub score_threshold: f32,
    /// Upper bound on the number of boxes to keep (0 = unlimited).
    pub max_output_boxes: usize,
}
178
179impl Default for NmsOptions {
180    fn default() -> Self {
181        Self {
182            iou_threshold: 0.5,
183            score_threshold: 0.0,
184            max_output_boxes: 0,
185        }
186    }
187}
188
189/// Vision capable backend, implemented by each backend
190pub trait VisionBackend:
191    BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
192{
193}
194
195/// Vision ops on bool tensors
196pub trait BoolVisionOps: Backend {
197    /// Computes the connected components labeled image of boolean image with 4 or 8 way
198    /// connectivity - returns a tensor of the component label of each pixel.
199    ///
200    /// `img`- The boolean image tensor in the format [batches, height, width]
201    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
202        cpu::connected_components::<Self>(img, connectivity)
203    }
204
205    /// Computes the connected components labeled image of boolean image with 4 or 8 way
206    /// connectivity and collects statistics on each component - returns a tensor of the component
207    /// label of each pixel, along with stats collected for each component.
208    ///
209    /// `img`- The boolean image tensor in the format [batches, height, width]
210    fn connected_components_with_stats(
211        img: BoolTensor<Self>,
212        connectivity: Connectivity,
213        opts: ConnectedStatsOptions,
214    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
215        cpu::connected_components_with_stats(img, connectivity, opts)
216    }
217
218    /// Erodes an input tensor with the specified kernel.
219    fn bool_erode(
220        input: BoolTensor<Self>,
221        kernel: BoolTensor<Self>,
222        opts: MorphOptions<Self, Bool>,
223    ) -> BoolTensor<Self> {
224        let input = Tensor::<Self, 3, Bool>::from_primitive(input);
225        morph(input, kernel, MorphOp::Erode, opts).into_primitive()
226    }
227
228    /// Dilates an input tensor with the specified kernel.
229    fn bool_dilate(
230        input: BoolTensor<Self>,
231        kernel: BoolTensor<Self>,
232        opts: MorphOptions<Self, Bool>,
233    ) -> BoolTensor<Self> {
234        let input = Tensor::<Self, 3, Bool>::from_primitive(input);
235        morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
236    }
237}
238
239/// Vision ops on int tensors
240pub trait IntVisionOps: Backend {
241    /// Erodes an input tensor with the specified kernel.
242    fn int_erode(
243        input: IntTensor<Self>,
244        kernel: BoolTensor<Self>,
245        opts: MorphOptions<Self, Int>,
246    ) -> IntTensor<Self> {
247        let input = Tensor::<Self, 3, Int>::from_primitive(input);
248        morph(input, kernel, MorphOp::Erode, opts).into_primitive()
249    }
250
251    /// Dilates an input tensor with the specified kernel.
252    fn int_dilate(
253        input: IntTensor<Self>,
254        kernel: BoolTensor<Self>,
255        opts: MorphOptions<Self, Int>,
256    ) -> IntTensor<Self> {
257        let input = Tensor::<Self, 3, Int>::from_primitive(input);
258        morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
259    }
260}
261
262/// Vision ops on float tensors
263pub trait FloatVisionOps: Backend {
264    /// Erodes an input tensor with the specified kernel.
265    fn float_erode(
266        input: FloatTensor<Self>,
267        kernel: BoolTensor<Self>,
268        opts: MorphOptions<Self, Float>,
269    ) -> FloatTensor<Self> {
270        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
271
272        morph(input, kernel, MorphOp::Erode, opts)
273            .into_primitive()
274            .tensor()
275    }
276
277    /// Dilates an input tensor with the specified kernel.
278    fn float_dilate(
279        input: FloatTensor<Self>,
280        kernel: BoolTensor<Self>,
281        opts: MorphOptions<Self, Float>,
282    ) -> FloatTensor<Self> {
283        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
284        morph(input, kernel, MorphOp::Dilate, opts)
285            .into_primitive()
286            .tensor()
287    }
288
289    /// Perform Non-Maximum Suppression on bounding boxes.
290    ///
291    /// Returns indices of kept boxes after suppressing overlapping detections.
292    /// Boxes are processed in descending score order; a box suppresses all
293    /// lower-scoring boxes with IoU > threshold.
294    ///
295    /// # Arguments
296    /// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
297    /// * `scores` - Confidence scores as \[N\] tensor
298    /// * `options` - NMS options (IoU threshold, score threshold, max boxes)
299    ///
300    /// # Returns
301    /// Indices of kept boxes as \[M\] tensor where M <= N
302    fn nms(
303        boxes: FloatTensor<Self>,
304        scores: FloatTensor<Self>,
305        options: NmsOptions,
306    ) -> IntTensor<Self> {
307        let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
308        let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
309        cpu::nms::<Self>(boxes, scores, options).into_primitive()
310    }
311}
312
313/// Vision ops on quantized float tensors
314pub trait QVisionOps: Backend {
315    /// Erodes an input tensor with the specified kernel.
316    fn q_erode(
317        input: QuantizedTensor<Self>,
318        kernel: BoolTensor<Self>,
319        opts: MorphOptions<Self, Float>,
320    ) -> QuantizedTensor<Self> {
321        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
322        match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
323            TensorPrimitive::QFloat(tensor) => tensor,
324            _ => unreachable!(),
325        }
326    }
327
328    /// Dilates an input tensor with the specified kernel.
329    fn q_dilate(
330        input: QuantizedTensor<Self>,
331        kernel: BoolTensor<Self>,
332        opts: MorphOptions<Self, Float>,
333    ) -> QuantizedTensor<Self> {
334        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
335        match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
336            TensorPrimitive::QFloat(tensor) => tensor,
337            _ => unreachable!(),
338        }
339    }
340}