1use crate::{
2 BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
3 IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
4};
5use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
6
7use burn_tensor::{
8 Element,
9 ops::{BoolTensor, IntTensor},
10};
11
12use super::connected_components::hardware_accelerated;
13
14impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
15where
16 R: CubeRuntime,
17 F: FloatElement,
18 I: IntElement,
19 BT: BoolElement,
20{
21 fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
22 hardware_accelerated::<R, F, I, BT>(
23 img.clone(),
24 ConnectedStatsOptions::none(),
25 connectivity,
26 )
27 .map(|it| it.0)
28 .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
29 }
30
31 fn connected_components_with_stats(
32 img: BoolTensor<Self>,
33 connectivity: Connectivity,
34 opts: ConnectedStatsOptions,
35 ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
36 hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
37 cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
38 })
39 }
40}
41
42impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
43where
44 R: CubeRuntime,
45 F: FloatElement,
46 I: IntElement,
47 BT: BoolElement,
48{
49}
50impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
51where
52 R: CubeRuntime,
53 F: FloatElement,
54 I: IntElement,
55 BT: BoolElement,
56{
57}
58impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
59where
60 R: CubeRuntime,
61 F: FloatElement,
62 I: IntElement,
63 BT: BoolElement,
64{
65}
66impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
67where
68 R: CubeRuntime,
69 F: FloatElement,
70 I: IntElement,
71 BT: BoolElement,
72{
73}
74
/// Vision trait implementations for the `Fusion` backend wrapper.
///
/// Compiled only when the `fusion` feature is enabled.
#[cfg(feature = "fusion")]
mod fusion {
    use super::*;
    use burn_fusion::{
        Fusion, FusionBackend, FusionRuntime,
        client::FusionClient,
        stream::{Operation, OperationStreams},
    };
    use burn_ir::{CustomOpIr, HandleContainer, OperationIr};

    /// Each method registers a custom operation with the fusion client; the
    /// operation's `execute` calls the same method on the inner backend `B`.
    impl<B> BoolVisionOps for Fusion<B>
    where
        B: FusionBackend + BoolVisionOps,
    {
        /// Registers a `"connected_components"` custom op and returns its
        /// `[shape[0], shape[1]]` integer output tensor.
        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
            /// Custom operation that runs `connected_components` on the inner
            /// backend when executed.
            #[derive(derive_new::new, Clone, Debug)]
            struct ConnComp<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1> Operation<B1::FusionRuntime> for ConnComp<B1>
            where
                B1: FusionBackend + BoolVisionOps,
            {
                fn execute(
                    &self,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // One input (the image) and one output (the labels).
                    let ([img_ir], [labels_ir]) = self.desc.as_fixed();
                    let mask = handles.get_bool_tensor::<B1>(img_ir);
                    let labels = B1::connected_components(mask, self.conn);

                    handles.register_int_tensor::<B1>(&labels_ir.id, labels);
                }
            }

            let (rows, cols) = (img.shape[0], img.shape[1]);
            let client = img.client.clone();

            let mut streams = OperationStreams::default();
            streams.tensor(&img);

            let out = client.tensor_uninitialized(vec![rows, cols], B::IntElem::dtype());

            let inputs = [img.into_ir()];
            let outputs = [out.to_ir_out()];
            let desc = CustomOpIr::new("connected_components", &inputs, &outputs);

            let ir = OperationIr::Custom(desc.clone());
            let operation = ConnComp::<B>::new(desc, conn);
            client.register(streams, ir, operation);

            out
        }

        /// Registers a custom op with seven outputs and returns them as the
        /// label tensor plus a `ConnectedStatsPrimitive`.
        ///
        /// The five per-component outputs are each allocated with
        /// `shape[0] * shape[1]` elements, and `max_label` with one element.
        fn connected_components_with_stats(
            img: BoolTensor<Self>,
            conn: Connectivity,
            opts: ConnectedStatsOptions,
        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
            /// Custom operation that runs `connected_components_with_stats`
            /// on the inner backend when executed.
            #[derive(derive_new::new, Clone, Debug)]
            struct ConnCompStats<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                opts: ConnectedStatsOptions,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1> Operation<B1::FusionRuntime> for ConnCompStats<B1>
            where
                B1: FusionBackend + BoolVisionOps,
            {
                fn execute(
                    &self,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // Output order must match the order used to build `desc` below.
                    let (
                        [img_ir],
                        [labels_ir, area_ir, left_ir, top_ir, right_ir, bottom_ir, max_label_ir],
                    ) = self.desc.as_fixed();

                    let mask = handles.get_bool_tensor::<B1>(img_ir);
                    let (labels, stats) =
                        B1::connected_components_with_stats(mask, self.conn, self.opts);

                    let slots = [
                        labels_ir,
                        area_ir,
                        left_ir,
                        top_ir,
                        right_ir,
                        bottom_ir,
                        max_label_ir,
                    ];
                    let tensors = [
                        labels,
                        stats.area,
                        stats.left,
                        stats.top,
                        stats.right,
                        stats.bottom,
                        stats.max_label,
                    ];
                    for (slot, tensor) in slots.into_iter().zip(tensors) {
                        handles.register_int_tensor::<B1>(&slot.id, tensor);
                    }
                }
            }

            let (rows, cols) = (img.shape[0], img.shape[1]);
            let client = img.client.clone();

            let mut streams = OperationStreams::default();
            streams.tensor(&img);

            // Every output uses the inner backend's integer element type.
            let alloc_int =
                |dims: Vec<usize>| client.tensor_uninitialized(dims, B::IntElem::dtype());

            let out = alloc_int(vec![rows, cols]);
            let per_pixel = rows * cols;
            let area = alloc_int(vec![per_pixel]);
            let left = alloc_int(vec![per_pixel]);
            let top = alloc_int(vec![per_pixel]);
            let right = alloc_int(vec![per_pixel]);
            let bottom = alloc_int(vec![per_pixel]);
            let max_label = alloc_int(vec![1]);

            // NOTE(review): this op reuses the id "connected_components" of the
            // stats-less variant — confirm whether a distinct id is wanted.
            let inputs = [img.into_ir()];
            let outputs = [
                out.to_ir_out(),
                area.to_ir_out(),
                left.to_ir_out(),
                top.to_ir_out(),
                right.to_ir_out(),
                bottom.to_ir_out(),
                max_label.to_ir_out(),
            ];
            let desc = CustomOpIr::new("connected_components", &inputs, &outputs);

            let ir = OperationIr::Custom(desc.clone());
            let operation = ConnCompStats::<B>::new(desc, conn, opts);
            client.register(streams, ir, operation);

            (
                out,
                ConnectedStatsPrimitive {
                    area,
                    left,
                    top,
                    right,
                    bottom,
                    max_label,
                },
            )
        }
    }

    // The remaining impls are empty and rely on the traits' default items.
    impl<B> IntVisionOps for Fusion<B> where B: FusionBackend + IntVisionOps {}
    impl<B> FloatVisionOps for Fusion<B> where B: FusionBackend + FloatVisionOps {}
    impl<B> QVisionOps for Fusion<B> where B: FusionBackend + QVisionOps {}
    impl<B> VisionBackend for Fusion<B> where B: FusionBackend + VisionBackend {}
}