1use burn_tensor::{
2 BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
3 ops::BoolTensor,
4};
5
6use crate::{
7 BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
8 VisionBackend,
9};
10
11pub trait ConnectedComponents<B: Backend> {
13 fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;
18
19 fn connected_components_with_stats(
25 self,
26 connectivity: Connectivity,
27 options: ConnectedStatsOptions,
28 ) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
29}
30
31pub trait Morphology<B: Backend, K: TensorKind<B>> {
33 fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
36 fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
39}
40
41pub trait MorphologyKind<B: Backend>: BasicOps<B> {
43 fn erode(
45 tensor: Self::Primitive,
46 kernel: BoolTensor<B>,
47 opts: MorphOptions<B, Self>,
48 ) -> Self::Primitive;
49 fn dilate(
51 tensor: Self::Primitive,
52 kernel: BoolTensor<B>,
53 opts: MorphOptions<B, Self>,
54 ) -> Self::Primitive;
55}
56
57pub trait Nms<B: Backend> {
59 fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;
73}
74
75impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
76 fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
77 Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
78 }
79
80 fn connected_components_with_stats(
81 self,
82 connectivity: Connectivity,
83 options: ConnectedStatsOptions,
84 ) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
85 let (labels, stats) =
86 B::connected_components_with_stats(self.into_primitive(), connectivity, options);
87 (Tensor::from_primitive(labels), stats.into())
88 }
89}
90
91impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
92 fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
93 Tensor::new(K::erode(
94 self.into_primitive(),
95 kernel.into_primitive(),
96 opts,
97 ))
98 }
99
100 fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
101 Tensor::new(K::dilate(
102 self.into_primitive(),
103 kernel.into_primitive(),
104 opts,
105 ))
106 }
107}
108
109impl<B: VisionBackend> MorphologyKind<B> for Float {
110 fn erode(
111 tensor: Self::Primitive,
112 kernel: BoolTensor<B>,
113 opts: MorphOptions<B, Self>,
114 ) -> Self::Primitive {
115 match tensor {
116 TensorPrimitive::Float(tensor) => {
117 TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
118 }
119 TensorPrimitive::QFloat(tensor) => {
120 TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
121 }
122 }
123 }
124
125 fn dilate(
126 tensor: Self::Primitive,
127 kernel: BoolTensor<B>,
128 opts: MorphOptions<B, Self>,
129 ) -> Self::Primitive {
130 match tensor {
131 TensorPrimitive::Float(tensor) => {
132 TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
133 }
134 TensorPrimitive::QFloat(tensor) => {
135 TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
136 }
137 }
138 }
139}
140
141impl<B: VisionBackend> MorphologyKind<B> for Int {
142 fn erode(
143 tensor: Self::Primitive,
144 kernel: BoolTensor<B>,
145 opts: MorphOptions<B, Self>,
146 ) -> Self::Primitive {
147 B::int_erode(tensor, kernel, opts)
148 }
149
150 fn dilate(
151 tensor: Self::Primitive,
152 kernel: BoolTensor<B>,
153 opts: MorphOptions<B, Self>,
154 ) -> Self::Primitive {
155 B::int_dilate(tensor, kernel, opts)
156 }
157}
158
159impl<B: VisionBackend> MorphologyKind<B> for Bool {
160 fn erode(
161 tensor: Self::Primitive,
162 kernel: BoolTensor<B>,
163 opts: MorphOptions<B, Self>,
164 ) -> Self::Primitive {
165 B::bool_erode(tensor, kernel, opts)
166 }
167
168 fn dilate(
169 tensor: Self::Primitive,
170 kernel: BoolTensor<B>,
171 opts: MorphOptions<B, Self>,
172 ) -> Self::Primitive {
173 B::bool_dilate(tensor, kernel, opts)
174 }
175}
176
177impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
178 fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
179 match (self.into_primitive(), scores.into_primitive()) {
180 (TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
181 Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
182 }
183 _ => todo!("Quantized inputs are not yet supported"),
184 }
185 }
186}