// burn_vision/tensor.rs

1use burn_tensor::{
2    BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
3    ops::BoolTensor,
4};
5
6use crate::{
7    BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
8    VisionBackend,
9};
10
11/// Connected components tensor extensions
12pub trait ConnectedComponents<B: Backend> {
13    /// Computes the connected components labeled image of boolean image with 4 or 8 way
14    /// connectivity - returns a tensor of the component label of each pixel.
15    ///
16    /// `img`- The boolean image tensor in the format [batches, height, width]
17    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;
18
19    /// Computes the connected components labeled image of boolean image with 4 or 8 way
20    /// connectivity and collects statistics on each component - returns a tensor of the component
21    /// label of each pixel, along with stats collected for each component.
22    ///
23    /// `img`- The boolean image tensor in the format [batches, height, width]
24    fn connected_components_with_stats(
25        self,
26        connectivity: Connectivity,
27        options: ConnectedStatsOptions,
28    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
29}
30
31/// Morphology tensor operations
32pub trait Morphology<B: Backend, K: TensorKind<B>> {
33    /// Erodes this tensor using the specified kernel.
34    /// Assumes NHWC layout.
35    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
36    /// Dilates this tensor using the specified kernel.
37    /// Assumes NHWC layout.
38    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
39}
40
41/// Morphology tensor operations
42pub trait MorphologyKind<B: Backend>: BasicOps<B> {
43    /// Erodes this tensor using the specified kernel
44    fn erode(
45        tensor: Self::Primitive,
46        kernel: BoolTensor<B>,
47        opts: MorphOptions<B, Self>,
48    ) -> Self::Primitive;
49    /// Dilates this tensor using the specified kernel
50    fn dilate(
51        tensor: Self::Primitive,
52        kernel: BoolTensor<B>,
53        opts: MorphOptions<B, Self>,
54    ) -> Self::Primitive;
55}
56
57/// Non-maximum suppression tensor operations
58pub trait Nms<B: Backend> {
59    /// Perform Non-Maximum Suppression on this tensor of bounding boxes.
60    ///
61    /// Returns indices of kept boxes after suppressing overlapping detections.
62    /// Boxes are processed in descending score order; a box suppresses all
63    /// lower-scoring boxes with IoU > threshold.
64    ///
65    /// # Arguments
66    /// * `self` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
67    /// * `scores` - Confidence scores as \[N\] tensor
68    /// * `options` - NMS options (IoU threshold, score threshold, max boxes)
69    ///
70    /// # Returns
71    /// Indices of kept boxes as \[M\] tensor where M <= N
72    fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;
73}
74
75impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
76    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
77        Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
78    }
79
80    fn connected_components_with_stats(
81        self,
82        connectivity: Connectivity,
83        options: ConnectedStatsOptions,
84    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
85        let (labels, stats) =
86            B::connected_components_with_stats(self.into_primitive(), connectivity, options);
87        (Tensor::from_primitive(labels), stats.into())
88    }
89}
90
91impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
92    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
93        Tensor::new(K::erode(
94            self.into_primitive(),
95            kernel.into_primitive(),
96            opts,
97        ))
98    }
99
100    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
101        Tensor::new(K::dilate(
102            self.into_primitive(),
103            kernel.into_primitive(),
104            opts,
105        ))
106    }
107}
108
109impl<B: VisionBackend> MorphologyKind<B> for Float {
110    fn erode(
111        tensor: Self::Primitive,
112        kernel: BoolTensor<B>,
113        opts: MorphOptions<B, Self>,
114    ) -> Self::Primitive {
115        match tensor {
116            TensorPrimitive::Float(tensor) => {
117                TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
118            }
119            TensorPrimitive::QFloat(tensor) => {
120                TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
121            }
122        }
123    }
124
125    fn dilate(
126        tensor: Self::Primitive,
127        kernel: BoolTensor<B>,
128        opts: MorphOptions<B, Self>,
129    ) -> Self::Primitive {
130        match tensor {
131            TensorPrimitive::Float(tensor) => {
132                TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
133            }
134            TensorPrimitive::QFloat(tensor) => {
135                TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
136            }
137        }
138    }
139}
140
141impl<B: VisionBackend> MorphologyKind<B> for Int {
142    fn erode(
143        tensor: Self::Primitive,
144        kernel: BoolTensor<B>,
145        opts: MorphOptions<B, Self>,
146    ) -> Self::Primitive {
147        B::int_erode(tensor, kernel, opts)
148    }
149
150    fn dilate(
151        tensor: Self::Primitive,
152        kernel: BoolTensor<B>,
153        opts: MorphOptions<B, Self>,
154    ) -> Self::Primitive {
155        B::int_dilate(tensor, kernel, opts)
156    }
157}
158
159impl<B: VisionBackend> MorphologyKind<B> for Bool {
160    fn erode(
161        tensor: Self::Primitive,
162        kernel: BoolTensor<B>,
163        opts: MorphOptions<B, Self>,
164    ) -> Self::Primitive {
165        B::bool_erode(tensor, kernel, opts)
166    }
167
168    fn dilate(
169        tensor: Self::Primitive,
170        kernel: BoolTensor<B>,
171        opts: MorphOptions<B, Self>,
172    ) -> Self::Primitive {
173        B::bool_dilate(tensor, kernel, opts)
174    }
175}
176
177impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
178    fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
179        match (self.into_primitive(), scores.into_primitive()) {
180            (TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
181                Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
182            }
183            _ => todo!("Quantized inputs are not yet supported"),
184        }
185    }
186}