1use crate::{
2 BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
3 IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
4};
5use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
6
7use burn_tensor::{
8 Element,
9 ops::{BoolTensor, IntTensor},
10};
11
12use super::connected_components::hardware_accelerated;
13
14impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
15where
16 R: CubeRuntime,
17 F: FloatElement,
18 I: IntElement,
19 BT: BoolElement,
20{
21 fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
22 hardware_accelerated::<R, F, I, BT>(
23 img.clone(),
24 ConnectedStatsOptions::none(),
25 connectivity,
26 )
27 .map(|it| it.0)
28 .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
29 }
30
31 fn connected_components_with_stats(
32 img: BoolTensor<Self>,
33 connectivity: Connectivity,
34 opts: ConnectedStatsOptions,
35 ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
36 hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
37 cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
38 })
39 }
40}
41
42impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
43where
44 R: CubeRuntime,
45 F: FloatElement,
46 I: IntElement,
47 BT: BoolElement,
48{
49}
50impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
51where
52 R: CubeRuntime,
53 F: FloatElement,
54 I: IntElement,
55 BT: BoolElement,
56{
57}
58impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
59where
60 R: CubeRuntime,
61 F: FloatElement,
62 I: IntElement,
63 BT: BoolElement,
64{
65}
66impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
67where
68 R: CubeRuntime,
69 F: FloatElement,
70 I: IntElement,
71 BT: BoolElement,
72{
73}
74
#[cfg(feature = "fusion")]
mod fusion {
    //! Vision ops for the `Fusion<B>` backend.
    //!
    //! Each op is recorded on the fusion stream as an opaque `OperationIr::Custom`. When the
    //! stream executes it, the work is forwarded to the inner backend `B`.

    use super::*;
    use burn_fusion::{
        Fusion, FusionBackend, FusionRuntime,
        stream::{Operation, OperationStreams},
    };
    use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr};
    use burn_tensor::Shape;

    impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {
        /// Registers a deferred `connected_components` op on the fusion stream.
        ///
        /// Returns a label tensor of shape `[img.shape[0], img.shape[1]]` with `B::IntElem`
        /// dtype. It is backed by an uninitialized handle until the stream executes the op,
        /// at which point `B::connected_components` fills it.
        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
            // NOTE(review): only the first two dims are read, so `img` is presumably
            // `[height, width]` — confirm against the trait contract.
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            // Deferred call to `B1::connected_components`. `conn` is not encoded in the IR;
            // it travels with this operation object instead.
            #[derive(derive_new::new, Clone, Debug)]
            struct ConnComp<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {
                fn execute(
                    &self,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    let ([img], [labels]) = self.desc.as_fixed();
                    let input = handles.get_bool_tensor::<B1>(img);
                    let output = B1::connected_components(input, self.conn);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                }
            }

            let streams = OperationStreams::with_inputs([&img]);
            let out = TensorIr::uninit(
                client.create_empty_handle(),
                Shape::new([height, width]),
                B::IntElem::dtype(),
            );

            let desc = CustomOpIr::new("connected_components", &[img.into_ir()], &[out]);
            client
                .register(
                    streams,
                    OperationIr::Custom(desc.clone()),
                    ConnComp::<B>::new(desc, conn),
                )
                .output()
        }

        /// Registers a deferred `connected_components_with_stats` op on the fusion stream.
        ///
        /// Allocates seven uninitialized outputs, all of dtype `B::IntElem`:
        /// - the `[height, width]` label image;
        /// - five per-label statistic tensors (`area`, `left`, `top`, `right`, `bottom`), each
        ///   with the flattened image shape;
        /// - a one-element `max_label`.
        ///
        /// `B::connected_components_with_stats` fills them when the stream executes the op.
        ///
        /// # Panics
        ///
        /// Panics if the fusion client does not return exactly seven output tensors.
        fn connected_components_with_stats(
            img: BoolTensor<Self>,
            conn: Connectivity,
            opts: ConnectedStatsOptions,
        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
            // NOTE(review): only the first two dims are read, so `img` is presumably
            // `[height, width]` — confirm against the trait contract.
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            // Deferred call to `B1::connected_components_with_stats`. `conn` and `opts` are not
            // encoded in the IR; they travel with this operation object instead.
            #[derive(derive_new::new, Clone, Debug)]
            struct ConnCompStats<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                opts: ConnectedStatsOptions,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnCompStats<B1> {
                fn execute(
                    &self,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // Output order must match the `outputs` slice given to `CustomOpIr::new`.
                    let ([img], [labels, area, left, top, right, bottom, max_label]) =
                        self.desc.as_fixed();
                    let input = handles.get_bool_tensor::<B1>(img);
                    let (output, stats) =
                        B1::connected_components_with_stats(input, self.conn, self.opts);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                    handles.register_int_tensor::<B1>(&area.id, stats.area);
                    handles.register_int_tensor::<B1>(&left.id, stats.left);
                    handles.register_int_tensor::<B1>(&top.id, stats.top);
                    handles.register_int_tensor::<B1>(&right.id, stats.right);
                    handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);
                    handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);
                }
            }

            let dtype = B::IntElem::dtype();
            let shape = Shape::new([height, width]);
            let shape_flat = shape.clone().flatten();
            let streams = OperationStreams::with_inputs([&img]);
            let out = TensorIr::uninit(client.create_empty_handle(), shape, dtype);
            let area = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
            let left = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
            let top = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
            let right = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
            let bottom = TensorIr::uninit(client.create_empty_handle(), shape_flat, dtype);
            let max_label = TensorIr::uninit(client.create_empty_handle(), [1].into(), dtype);

            // Use a distinct id so this 7-output op is distinguishable in the IR from the
            // label-only `connected_components` op registered above.
            let desc = CustomOpIr::new(
                "connected_components_with_stats",
                &[img.into_ir()],
                &[out, area, left, top, right, bottom, max_label],
            );
            let [out, area, left, top, right, bottom, max_label] = client
                .register(
                    streams,
                    OperationIr::Custom(desc.clone()),
                    ConnCompStats::<B>::new(desc, conn, opts),
                )
                .try_into()
                .expect("connected_components_with_stats registers exactly 7 outputs");

            let stats = ConnectedStatsPrimitive {
                area,
                left,
                top,
                right,
                bottom,
                max_label,
            };
            (out, stats)
        }
    }

    impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}
    impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}
    impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}
    impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}
}