// burn_vision/tensor.rs

1use burn_tensor::{
2    BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
3    ops::BoolTensor,
4};
5
6use crate::{
7    BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, VisionBackend,
8};
9
10/// Connected components tensor extensions
11pub trait ConnectedComponents<B: Backend> {
12    /// Computes the connected components labeled image of boolean image with 4 or 8 way
13    /// connectivity - returns a tensor of the component label of each pixel.
14    ///
15    /// `img`- The boolean image tensor in the format [batches, height, width]
16    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;
17
18    /// Computes the connected components labeled image of boolean image with 4 or 8 way
19    /// connectivity and collects statistics on each component - returns a tensor of the component
20    /// label of each pixel, along with stats collected for each component.
21    ///
22    /// `img`- The boolean image tensor in the format [batches, height, width]
23    fn connected_components_with_stats(
24        self,
25        connectivity: Connectivity,
26        options: ConnectedStatsOptions,
27    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
28}
29
30/// Morphology tensor operations
31pub trait Morphology<B: Backend, K: TensorKind<B>> {
32    /// Erodes this tensor using the specified kernel.
33    /// Assumes NHWC layout.
34    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
35    /// Dilates this tensor using the specified kernel.
36    /// Assumes NHWC layout.
37    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
38}
39
40/// Morphology tensor operations
41pub trait MorphologyKind<B: Backend>: BasicOps<B> {
42    /// Erodes this tensor using the specified kernel
43    fn erode(
44        tensor: Self::Primitive,
45        kernel: BoolTensor<B>,
46        opts: MorphOptions<B, Self>,
47    ) -> Self::Primitive;
48    /// Dilates this tensor using the specified kernel
49    fn dilate(
50        tensor: Self::Primitive,
51        kernel: BoolTensor<B>,
52        opts: MorphOptions<B, Self>,
53    ) -> Self::Primitive;
54}
55
56impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
57    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
58        Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
59    }
60
61    fn connected_components_with_stats(
62        self,
63        connectivity: Connectivity,
64        options: ConnectedStatsOptions,
65    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
66        let (labels, stats) =
67            B::connected_components_with_stats(self.into_primitive(), connectivity, options);
68        (Tensor::from_primitive(labels), stats.into())
69    }
70}
71
72impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
73    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
74        Tensor::new(K::erode(
75            self.into_primitive(),
76            kernel.into_primitive(),
77            opts,
78        ))
79    }
80
81    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
82        Tensor::new(K::dilate(
83            self.into_primitive(),
84            kernel.into_primitive(),
85            opts,
86        ))
87    }
88}
89
90impl<B: VisionBackend> MorphologyKind<B> for Float {
91    fn erode(
92        tensor: Self::Primitive,
93        kernel: BoolTensor<B>,
94        opts: MorphOptions<B, Self>,
95    ) -> Self::Primitive {
96        match tensor {
97            TensorPrimitive::Float(tensor) => {
98                TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
99            }
100            TensorPrimitive::QFloat(tensor) => {
101                TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
102            }
103        }
104    }
105
106    fn dilate(
107        tensor: Self::Primitive,
108        kernel: BoolTensor<B>,
109        opts: MorphOptions<B, Self>,
110    ) -> Self::Primitive {
111        match tensor {
112            TensorPrimitive::Float(tensor) => {
113                TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
114            }
115            TensorPrimitive::QFloat(tensor) => {
116                TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
117            }
118        }
119    }
120}
121
122impl<B: VisionBackend> MorphologyKind<B> for Int {
123    fn erode(
124        tensor: Self::Primitive,
125        kernel: BoolTensor<B>,
126        opts: MorphOptions<B, Self>,
127    ) -> Self::Primitive {
128        B::int_erode(tensor, kernel, opts)
129    }
130
131    fn dilate(
132        tensor: Self::Primitive,
133        kernel: BoolTensor<B>,
134        opts: MorphOptions<B, Self>,
135    ) -> Self::Primitive {
136        B::int_dilate(tensor, kernel, opts)
137    }
138}
139
140impl<B: VisionBackend> MorphologyKind<B> for Bool {
141    fn erode(
142        tensor: Self::Primitive,
143        kernel: BoolTensor<B>,
144        opts: MorphOptions<B, Self>,
145    ) -> Self::Primitive {
146        B::bool_erode(tensor, kernel, opts)
147    }
148
149    fn dilate(
150        tensor: Self::Primitive,
151        kernel: BoolTensor<B>,
152        opts: MorphOptions<B, Self>,
153    ) -> Self::Primitive {
154        B::bool_dilate(tensor, kernel, opts)
155    }
156}