1use crate::{
2 BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
3 IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
4};
5use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
6
7use burn_tensor::{
8 Element,
9 ops::{BoolTensor, IntTensor},
10};
11
12use super::connected_components::hardware_accelerated;
13
14impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
15where
16 R: CubeRuntime,
17 F: FloatElement,
18 I: IntElement,
19 BT: BoolElement,
20{
21 fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
22 hardware_accelerated::<R, F, I, BT>(
23 img.clone(),
24 ConnectedStatsOptions::none(),
25 connectivity,
26 )
27 .map(|it| it.0)
28 .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
29 }
30
31 fn connected_components_with_stats(
32 img: BoolTensor<Self>,
33 connectivity: Connectivity,
34 opts: ConnectedStatsOptions,
35 ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
36 hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
37 cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
38 })
39 }
40}
41
42impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
43where
44 R: CubeRuntime,
45 F: FloatElement,
46 I: IntElement,
47 BT: BoolElement,
48{
49}
50impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
51where
52 R: CubeRuntime,
53 F: FloatElement,
54 I: IntElement,
55 BT: BoolElement,
56{
57}
58impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
59where
60 R: CubeRuntime,
61 F: FloatElement,
62 I: IntElement,
63 BT: BoolElement,
64{
65}
66impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
67where
68 R: CubeRuntime,
69 F: FloatElement,
70 I: IntElement,
71 BT: BoolElement,
72{
73}
74
#[cfg(feature = "fusion")]
mod fusion {
    //! Vision ops for the `Fusion<B>` backend.
    //!
    //! Each op follows the same steps:
    //! - allocate uninitialized output tensors;
    //! - record a `CustomOpIr` on the input tensor's stream;
    //! - register an `Operation` whose `execute` calls the wrapped backend `B` and binds the
    //!   results to the recorded output ids.

    use super::*;
    use burn_fusion::{
        Fusion, FusionBackend, FusionRuntime, client::FusionClient, stream::Operation,
    };
    use burn_ir::{CustomOpIr, HandleContainer, OperationIr};

    impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {
        /// Records a connected-components labeling of `img` on its fusion stream.
        ///
        /// Only `img.shape[0]` (height) and `img.shape[1]` (width) are read. The returned
        /// label tensor is allocated as `[height, width]` with `B`'s int dtype, and is bound
        /// to `B::connected_components`' output when the recorded op executes.
        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            /// Deferred op: runs `B::connected_components` on the resolved input handle.
            #[derive(derive_new::new)]
            struct ConnComp<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {
                fn execute(
                    self: Box<Self>,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    let ([img], [labels]) = self.desc.consume();
                    let input = handles.get_bool_tensor::<B1>(&img);
                    let output = B1::connected_components(input, self.conn);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                }
            }

            let stream = img.stream;
            let out = client.tensor_uninitialized(vec![height, width], B::IntElem::dtype());

            let desc =
                CustomOpIr::new("connected_components", &[img.into_ir()], &[out.to_ir_out()]);
            client.register(
                vec![stream],
                OperationIr::Custom(desc.clone()),
                ConnComp::<B>::new(desc, conn),
            );

            out
        }

        /// Records a connected-components labeling of `img` plus per-label statistics.
        ///
        /// Outputs are allocated with `B`'s int dtype:
        /// - labels as `[height, width]`;
        /// - each of `area`/`left`/`top`/`right`/`bottom` as a flat `[height * width]` tensor;
        /// - `max_label` as `[1]`.
        ///
        /// `conn` and `opts` are forwarded unchanged to `B::connected_components_with_stats`
        /// when the recorded op executes.
        fn connected_components_with_stats(
            img: BoolTensor<Self>,
            conn: Connectivity,
            opts: ConnectedStatsOptions,
        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            /// Deferred op: runs `B::connected_components_with_stats` and binds the labels and
            /// every stats tensor to the recorded output ids (same order as the IR outputs).
            #[derive(derive_new::new)]
            struct ConnCompStats<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                opts: ConnectedStatsOptions,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime>
                for ConnCompStats<B1>
            {
                fn execute(
                    self: Box<Self>,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    let ([img], [labels, area, left, top, right, bottom, max_label]) =
                        self.desc.consume();
                    let input = handles.get_bool_tensor::<B1>(&img);
                    let (output, stats) =
                        B1::connected_components_with_stats(input, self.conn, self.opts);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                    handles.register_int_tensor::<B1>(&area.id, stats.area);
                    handles.register_int_tensor::<B1>(&left.id, stats.left);
                    handles.register_int_tensor::<B1>(&top.id, stats.top);
                    handles.register_int_tensor::<B1>(&right.id, stats.right);
                    handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);
                    handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);
                }
            }

            let stream = img.stream;

            // Every output shares `B`'s int dtype; only the shapes differ.
            let alloc_int =
                |shape: Vec<usize>| client.tensor_uninitialized(shape, B::IntElem::dtype());
            let out = alloc_int(vec![height, width]);
            let area = alloc_int(vec![height * width]);
            let left = alloc_int(vec![height * width]);
            let top = alloc_int(vec![height * width]);
            let right = alloc_int(vec![height * width]);
            let bottom = alloc_int(vec![height * width]);
            let max_label = alloc_int(vec![1]);

            // Use an op id distinct from the plain `connected_components` op (previously a
            // copy-pasted duplicate) so the two are distinguishable in the recorded stream.
            let desc = CustomOpIr::new(
                "connected_components_with_stats",
                &[img.into_ir()],
                &[
                    out.to_ir_out(),
                    area.to_ir_out(),
                    left.to_ir_out(),
                    top.to_ir_out(),
                    right.to_ir_out(),
                    bottom.to_ir_out(),
                    max_label.to_ir_out(),
                ],
            );
            client.register(
                vec![stream],
                OperationIr::Custom(desc.clone()),
                ConnCompStats::<B>::new(desc, conn, opts),
            );

            let stats = ConnectedStatsPrimitive {
                area,
                left,
                top,
                right,
                bottom,
                max_label,
            };
            (out, stats)
        }
    }

    impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}
    impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}
    impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}
    impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}
}