burn_vision/ops/
base.rs

1use crate::{
2    Point,
3    backends::cpu::{self, MorphOp, morph},
4};
5use bon::Builder;
6use burn_tensor::{
7    Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
8    backend::Backend,
9    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
10};
11
/// Pixel neighbourhood used when grouping foreground pixels into connected components
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Connectivity {
    /// 4-connectivity: pixels are connected only along the cardinal directions
    Four,
    /// 8-connectivity: a pixel is connected if any of the 8 pixels surrounding it is in the
    /// foreground
    Eight,
}
20
/// Selects which stats `connected_components_with_stats` should compute.
///
/// At present only the GPU implementation consults these flags; it uses them to skip atomic
/// operations for stats that are not needed.
///
/// A stat that is disabled is aliased to the labels tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectedStatsOptions {
    /// Compute the bounding boxes
    pub bounds_enabled: bool,
    /// Compute the max label
    pub max_label_enabled: bool,
    /// Require labels to be contiguous, starting at 1
    pub compact_labels: bool,
}
34
35/// Options for morphology ops
36#[derive(Clone, Debug, Builder)]
37pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
38    /// Anchor position within the kernel. Defaults to the center.
39    pub anchor: Option<Point>,
40    /// Number of iterations to apply
41    #[builder(default = 1)]
42    pub iterations: usize,
43    /// Border type. Default: constant based on operation
44    #[builder(default)]
45    pub border_type: BorderType,
46    /// Value of each channel for constant border type
47    pub border_value: Option<Tensor<B, 1, K>>,
48}
49
50impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
51    fn default() -> Self {
52        Self {
53            anchor: Default::default(),
54            iterations: 1,
55            border_type: Default::default(),
56            border_value: Default::default(),
57        }
58    }
59}
60
/// Border handling mode for morphology ops
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BorderType {
    /// Fill the border with a constant per-channel value. When no value is provided, one is
    /// picked based on the morph op. This is the default variant.
    #[default]
    Constant,
    /// Repeat the first/last element
    Replicate,
    /// Mirror the start/end elements
    Reflect,
    /// Mirror the start/end elements, ignoring the first/last element
    Reflect101,
    /// Not supported for erode/dilate
    Wrap,
}
77
78/// Stats collected by the connected components analysis
79///
80/// Disabled analyses may be aliased to labels
81#[derive(Clone, Debug)]
82pub struct ConnectedStats<B: Backend> {
83    /// Total area of each component
84    pub area: Tensor<B, 1, Int>,
85    /// Topmost y coordinate in the component
86    pub top: Tensor<B, 1, Int>,
87    /// Leftmost x coordinate in the component
88    pub left: Tensor<B, 1, Int>,
89    /// Rightmost x coordinate in the component
90    pub right: Tensor<B, 1, Int>,
91    /// Bottommost y coordinate in the component
92    pub bottom: Tensor<B, 1, Int>,
93    /// Scalar tensor of the max label
94    pub max_label: Tensor<B, 1, Int>,
95}
96
97/// Primitive version of [`ConnectedStats`], to be returned by the backend
98pub struct ConnectedStatsPrimitive<B: Backend> {
99    /// Total area of each component
100    pub area: IntTensor<B>,
101    /// Leftmost x coordinate in the component
102    pub left: IntTensor<B>,
103    /// Topmost y coordinate in the component
104    pub top: IntTensor<B>,
105    /// Rightmost x coordinate in the component
106    pub right: IntTensor<B>,
107    /// Bottommost y coordinate in the component
108    pub bottom: IntTensor<B>,
109    /// Scalar tensor of the max label
110    pub max_label: IntTensor<B>,
111}
112
113impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
114    fn from(value: ConnectedStatsPrimitive<B>) -> Self {
115        ConnectedStats {
116            area: Tensor::from_primitive(value.area),
117            top: Tensor::from_primitive(value.top),
118            left: Tensor::from_primitive(value.left),
119            right: Tensor::from_primitive(value.right),
120            bottom: Tensor::from_primitive(value.bottom),
121            max_label: Tensor::from_primitive(value.max_label),
122        }
123    }
124}
125
126impl<B: Backend> ConnectedStats<B> {
127    /// Convert a connected stats into the corresponding primitive
128    pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
129        ConnectedStatsPrimitive {
130            area: self.area.into_primitive(),
131            top: self.top.into_primitive(),
132            left: self.left.into_primitive(),
133            right: self.right.into_primitive(),
134            bottom: self.bottom.into_primitive(),
135            max_label: self.max_label.into_primitive(),
136        }
137    }
138}
139
140impl Default for ConnectedStatsOptions {
141    fn default() -> Self {
142        Self::all()
143    }
144}
145
146impl ConnectedStatsOptions {
147    /// Don't collect any stats
148    pub fn none() -> Self {
149        Self {
150            bounds_enabled: false,
151            max_label_enabled: false,
152            compact_labels: false,
153        }
154    }
155
156    /// Collect all stats
157    pub fn all() -> Self {
158        Self {
159            bounds_enabled: true,
160            max_label_enabled: true,
161            compact_labels: true,
162        }
163    }
164}
165
166/// Vision capable backend, implemented by each backend
167pub trait VisionBackend:
168    BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
169{
170}
171
172/// Vision ops on bool tensors
173pub trait BoolVisionOps: Backend {
174    /// Computes the connected components labeled image of boolean image with 4 or 8 way
175    /// connectivity - returns a tensor of the component label of each pixel.
176    ///
177    /// `img`- The boolean image tensor in the format [batches, height, width]
178    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
179        cpu::connected_components::<Self>(img, connectivity)
180    }
181
182    /// Computes the connected components labeled image of boolean image with 4 or 8 way
183    /// connectivity and collects statistics on each component - returns a tensor of the component
184    /// label of each pixel, along with stats collected for each component.
185    ///
186    /// `img`- The boolean image tensor in the format [batches, height, width]
187    fn connected_components_with_stats(
188        img: BoolTensor<Self>,
189        connectivity: Connectivity,
190        opts: ConnectedStatsOptions,
191    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
192        cpu::connected_components_with_stats(img, connectivity, opts)
193    }
194
195    /// Erodes an input tensor with the specified kernel.
196    fn bool_erode(
197        input: BoolTensor<Self>,
198        kernel: BoolTensor<Self>,
199        opts: MorphOptions<Self, Bool>,
200    ) -> BoolTensor<Self> {
201        let input = Tensor::<Self, 3, Bool>::from_primitive(input);
202        morph(input, kernel, MorphOp::Erode, opts).into_primitive()
203    }
204
205    /// Dilates an input tensor with the specified kernel.
206    fn bool_dilate(
207        input: BoolTensor<Self>,
208        kernel: BoolTensor<Self>,
209        opts: MorphOptions<Self, Bool>,
210    ) -> BoolTensor<Self> {
211        let input = Tensor::<Self, 3, Bool>::from_primitive(input);
212        morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
213    }
214}
215
216/// Vision ops on int tensors
217pub trait IntVisionOps: Backend {
218    /// Erodes an input tensor with the specified kernel.
219    fn int_erode(
220        input: IntTensor<Self>,
221        kernel: BoolTensor<Self>,
222        opts: MorphOptions<Self, Int>,
223    ) -> IntTensor<Self> {
224        let input = Tensor::<Self, 3, Int>::from_primitive(input);
225        morph(input, kernel, MorphOp::Erode, opts).into_primitive()
226    }
227
228    /// Dilates an input tensor with the specified kernel.
229    fn int_dilate(
230        input: IntTensor<Self>,
231        kernel: BoolTensor<Self>,
232        opts: MorphOptions<Self, Int>,
233    ) -> IntTensor<Self> {
234        let input = Tensor::<Self, 3, Int>::from_primitive(input);
235        morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
236    }
237}
238
239/// Vision ops on float tensors
240pub trait FloatVisionOps: Backend {
241    /// Erodes an input tensor with the specified kernel.
242    fn float_erode(
243        input: FloatTensor<Self>,
244        kernel: BoolTensor<Self>,
245        opts: MorphOptions<Self, Float>,
246    ) -> FloatTensor<Self> {
247        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
248
249        morph(input, kernel, MorphOp::Erode, opts)
250            .into_primitive()
251            .tensor()
252    }
253
254    /// Dilates an input tensor with the specified kernel.
255    fn float_dilate(
256        input: FloatTensor<Self>,
257        kernel: BoolTensor<Self>,
258        opts: MorphOptions<Self, Float>,
259    ) -> FloatTensor<Self> {
260        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
261        morph(input, kernel, MorphOp::Dilate, opts)
262            .into_primitive()
263            .tensor()
264    }
265}
266
267/// Vision ops on quantized float tensors
268pub trait QVisionOps: Backend {
269    /// Erodes an input tensor with the specified kernel.
270    fn q_erode(
271        input: QuantizedTensor<Self>,
272        kernel: BoolTensor<Self>,
273        opts: MorphOptions<Self, Float>,
274    ) -> QuantizedTensor<Self> {
275        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
276        match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
277            TensorPrimitive::QFloat(tensor) => tensor,
278            _ => unreachable!(),
279        }
280    }
281
282    /// Dilates an input tensor with the specified kernel.
283    fn q_dilate(
284        input: QuantizedTensor<Self>,
285        kernel: BoolTensor<Self>,
286        opts: MorphOptions<Self, Float>,
287    ) -> QuantizedTensor<Self> {
288        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
289        match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
290            TensorPrimitive::QFloat(tensor) => tensor,
291            _ => unreachable!(),
292        }
293    }
294}