1use burn_tensor::{
2 BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
3 ops::BoolTensor,
4};
5
6use crate::{
7 BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, VisionBackend,
8};
9
/// Connected-component labeling for boolean tensors.
///
/// This file implements it for 2-D `Bool` tensors by delegating to the backend's
/// [`BoolVisionOps`] implementation.
pub trait ConnectedComponents<B: Backend> {
    // NOTE(review): how labels are numbered (e.g. whether background is 0, whether labels are
    // contiguous) is decided by the backend op — confirm before relying on it.
    /// Labels the connected components of `self` using the given `connectivity`.
    ///
    /// Returns a rank-2 `Int` tensor of component labels.
    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;

    /// Labels the connected components of `self` and also returns per-component statistics.
    ///
    /// `options` configures the statistics computation. Returns `(labels, stats)`, where
    /// `labels` is a rank-2 `Int` tensor and `stats` is a [`ConnectedStats`].
    fn connected_components_with_stats(
        self,
        connectivity: Connectivity,
        options: ConnectedStatsOptions,
    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
}
29
/// Morphological erosion and dilation with a boolean kernel.
///
/// This file implements it for rank-3 tensors whose kind implements [`MorphologyKind`]
/// (`Float`, `Int` and `Bool`).
// NOTE(review): the axis layout expected for the rank-3 input (e.g. channels first or last)
// is decided by the backend ops — confirm before documenting it here.
pub trait Morphology<B: Backend, K: TensorKind<B>> {
    /// Erodes `self` with the 2-D boolean `kernel`, using the settings in `opts`.
    ///
    /// Returns a tensor of the same type as `self`.
    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;

    /// Dilates `self` with the 2-D boolean `kernel`, using the settings in `opts`.
    ///
    /// Returns a tensor of the same type as `self`.
    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
}
39
/// Per-tensor-kind dispatch for morphology operations on backend primitives.
///
/// Each implementation forwards to the backend operation matching its primitive type. This
/// file provides implementations for `Float`, `Int` and `Bool`. The tensor-level
/// [`Morphology`] impl calls into this trait after unwrapping tensors into primitives.
pub trait MorphologyKind<B: Backend>: BasicOps<B> {
    /// Erodes the primitive `tensor` with the boolean `kernel` primitive, using `opts`.
    ///
    /// Returns a primitive of the same kind as the input.
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor<B>,
        opts: MorphOptions<B, Self>,
    ) -> Self::Primitive;

    /// Dilates the primitive `tensor` with the boolean `kernel` primitive, using `opts`.
    ///
    /// Returns a primitive of the same kind as the input.
    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor<B>,
        opts: MorphOptions<B, Self>,
    ) -> Self::Primitive;
}
55
56impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
57 fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
58 Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
59 }
60
61 fn connected_components_with_stats(
62 self,
63 connectivity: Connectivity,
64 options: ConnectedStatsOptions,
65 ) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
66 let (labels, stats) =
67 B::connected_components_with_stats(self.into_primitive(), connectivity, options);
68 (Tensor::from_primitive(labels), stats.into())
69 }
70}
71
72impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
73 fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
74 Tensor::new(K::erode(
75 self.into_primitive(),
76 kernel.into_primitive(),
77 opts,
78 ))
79 }
80
81 fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
82 Tensor::new(K::dilate(
83 self.into_primitive(),
84 kernel.into_primitive(),
85 opts,
86 ))
87 }
88}
89
90impl<B: VisionBackend> MorphologyKind<B> for Float {
91 fn erode(
92 tensor: Self::Primitive,
93 kernel: BoolTensor<B>,
94 opts: MorphOptions<B, Self>,
95 ) -> Self::Primitive {
96 match tensor {
97 TensorPrimitive::Float(tensor) => {
98 TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
99 }
100 TensorPrimitive::QFloat(tensor) => {
101 TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
102 }
103 }
104 }
105
106 fn dilate(
107 tensor: Self::Primitive,
108 kernel: BoolTensor<B>,
109 opts: MorphOptions<B, Self>,
110 ) -> Self::Primitive {
111 match tensor {
112 TensorPrimitive::Float(tensor) => {
113 TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
114 }
115 TensorPrimitive::QFloat(tensor) => {
116 TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
117 }
118 }
119 }
120}
121
122impl<B: VisionBackend> MorphologyKind<B> for Int {
123 fn erode(
124 tensor: Self::Primitive,
125 kernel: BoolTensor<B>,
126 opts: MorphOptions<B, Self>,
127 ) -> Self::Primitive {
128 B::int_erode(tensor, kernel, opts)
129 }
130
131 fn dilate(
132 tensor: Self::Primitive,
133 kernel: BoolTensor<B>,
134 opts: MorphOptions<B, Self>,
135 ) -> Self::Primitive {
136 B::int_dilate(tensor, kernel, opts)
137 }
138}
139
140impl<B: VisionBackend> MorphologyKind<B> for Bool {
141 fn erode(
142 tensor: Self::Primitive,
143 kernel: BoolTensor<B>,
144 opts: MorphOptions<B, Self>,
145 ) -> Self::Primitive {
146 B::bool_erode(tensor, kernel, opts)
147 }
148
149 fn dilate(
150 tensor: Self::Primitive,
151 kernel: BoolTensor<B>,
152 opts: MorphOptions<B, Self>,
153 ) -> Self::Primitive {
154 B::bool_dilate(tensor, kernel, opts)
155 }
156}