Skip to main content

burn_vision/backends/cube/
ops.rs

1use crate::{
2    BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
3    IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
4};
5use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
6
7use burn_tensor::{
8    Element,
9    ops::{BoolTensor, IntTensor},
10};
11
12use super::connected_components::hardware_accelerated;
13
14impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
15where
16    R: CubeRuntime,
17    F: FloatElement,
18    I: IntElement,
19    BT: BoolElement,
20{
21    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
22        hardware_accelerated::<R, F, I, BT>(
23            img.clone(),
24            ConnectedStatsOptions::none(),
25            connectivity,
26        )
27        .map(|it| it.0)
28        .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
29    }
30
31    fn connected_components_with_stats(
32        img: BoolTensor<Self>,
33        connectivity: Connectivity,
34        opts: ConnectedStatsOptions,
35    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
36        hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
37            cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
38        })
39    }
40}
41
42impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
43where
44    R: CubeRuntime,
45    F: FloatElement,
46    I: IntElement,
47    BT: BoolElement,
48{
49}
50impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
51where
52    R: CubeRuntime,
53    F: FloatElement,
54    I: IntElement,
55    BT: BoolElement,
56{
57}
58impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
59where
60    R: CubeRuntime,
61    F: FloatElement,
62    I: IntElement,
63    BT: BoolElement,
64{
65}
66impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
67where
68    R: CubeRuntime,
69    F: FloatElement,
70    I: IntElement,
71    BT: BoolElement,
72{
73}
74
/// Vision-op implementations for the `Fusion` backend wrapper.
///
/// Each operation is registered with the fusion client as a custom operation;
/// the call into the wrapped backend `B` happens inside the operation's
/// `execute` method.
#[cfg(feature = "fusion")]
mod fusion {
    use super::*;
    use burn_fusion::{
        Fusion, FusionBackend, FusionRuntime,
        stream::{Operation, OperationStreams},
    };
    use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr};
    use burn_tensor::Shape;

    /// Custom fusion operation that runs `connected_components` on the wrapped backend.
    #[derive(derive_new::new, Clone, Debug)]
    struct ConnComp<B> {
        desc: CustomOpIr,
        conn: Connectivity,
        _b: core::marker::PhantomData<B>,
    }

    impl<B> Operation<B::FusionRuntime> for ConnComp<B>
    where
        B: FusionBackend + BoolVisionOps,
    {
        /// Fetches the input mask from `handles`, labels it with the wrapped
        /// backend, and registers the labels under the output tensor id.
        fn execute(
            &self,
            handles: &mut HandleContainer<<B::FusionRuntime as FusionRuntime>::FusionHandle>,
        ) {
            // One input and one output, matching how the descriptor was built.
            let ([mask_ir], [labels_ir]) = self.desc.as_fixed();

            let mask = handles.get_bool_tensor::<B>(mask_ir);
            let labels = B::connected_components(mask, self.conn);

            handles.register_int_tensor::<B>(&labels_ir.id, labels);
        }
    }

    /// Custom fusion operation that runs `connected_components_with_stats` on
    /// the wrapped backend.
    #[derive(derive_new::new, Clone, Debug)]
    struct ConnCompStats<B> {
        desc: CustomOpIr,
        conn: Connectivity,
        opts: ConnectedStatsOptions,
        _b: core::marker::PhantomData<B>,
    }

    impl<B> Operation<B::FusionRuntime> for ConnCompStats<B>
    where
        B: FusionBackend + BoolVisionOps,
    {
        /// Fetches the input mask from `handles`, computes labels and stats
        /// with the wrapped backend, and registers all seven outputs.
        fn execute(
            &self,
            handles: &mut HandleContainer<<B::FusionRuntime as FusionRuntime>::FusionHandle>,
        ) {
            // One input and seven outputs, in the order the descriptor was built.
            let (
                [mask_ir],
                [labels_ir, area_ir, left_ir, top_ir, right_ir, bottom_ir, max_label_ir],
            ) = self.desc.as_fixed();

            let mask = handles.get_bool_tensor::<B>(mask_ir);
            let (labels, stats) = B::connected_components_with_stats(mask, self.conn, self.opts);

            handles.register_int_tensor::<B>(&labels_ir.id, labels);
            handles.register_int_tensor::<B>(&area_ir.id, stats.area);
            handles.register_int_tensor::<B>(&left_ir.id, stats.left);
            handles.register_int_tensor::<B>(&top_ir.id, stats.top);
            handles.register_int_tensor::<B>(&right_ir.id, stats.right);
            handles.register_int_tensor::<B>(&bottom_ir.id, stats.bottom);
            handles.register_int_tensor::<B>(&max_label_ir.id, stats.max_label);
        }
    }

    impl<B> BoolVisionOps for Fusion<B>
    where
        B: FusionBackend + BoolVisionOps,
    {
        /// Registers a custom operation that labels the connected components
        /// of `img` and returns the fusion tensor for its single output.
        ///
        /// The output is declared with the first two dimensions of `img` as
        /// its shape and `B::IntElem` as its element type.
        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
            // Read everything needed from `img` before it is consumed by `into_ir`.
            let labels_shape = Shape::new([img.shape[0], img.shape[1]]);
            let client = img.client.clone();
            let streams = OperationStreams::with_inputs([&img]);

            let labels_ir = TensorIr::uninit(
                client.create_empty_handle(),
                labels_shape,
                B::IntElem::dtype(),
            );
            let desc = CustomOpIr::new("connected_components", &[img.into_ir()], &[labels_ir]);
            let operation = ConnComp::<B>::new(desc.clone(), conn);

            client
                .register(streams, OperationIr::Custom(desc), operation)
                .output()
        }

        /// Registers a custom operation that labels the connected components
        /// of `img` and computes the statistics selected by `opts`.
        ///
        /// Seven outputs are declared: the labels (first two dimensions of
        /// `img`), five statistics tensors with the flattened label shape
        /// (`area`, `left`, `top`, `right`, `bottom`), and a one-element
        /// `max_label` tensor. All use `B::IntElem` as element type.
        ///
        /// # Panics
        ///
        /// Panics if the registered outputs cannot be converted into an array
        /// of exactly seven tensors.
        fn connected_components_with_stats(
            img: BoolTensor<Self>,
            conn: Connectivity,
            opts: ConnectedStatsOptions,
        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
            // Read everything needed from `img` before it is consumed by `into_ir`.
            let labels_shape = Shape::new([img.shape[0], img.shape[1]]);
            let stats_shape = labels_shape.clone().flatten();
            let client = img.client.clone();
            let dtype = B::IntElem::dtype();
            let streams = OperationStreams::with_inputs([&img]);

            // Handles are created in output order: labels, the five stats, max label.
            let labels_ir = TensorIr::uninit(client.create_empty_handle(), labels_shape, dtype);
            let new_stat_ir =
                || TensorIr::uninit(client.create_empty_handle(), stats_shape.clone(), dtype);
            let area_ir = new_stat_ir();
            let left_ir = new_stat_ir();
            let top_ir = new_stat_ir();
            let right_ir = new_stat_ir();
            let bottom_ir = new_stat_ir();
            let max_label_ir = TensorIr::uninit(client.create_empty_handle(), Shape::new([1]), dtype);

            // NOTE(review): this op is registered under the same id string as the
            // labels-only op above; confirm whether a distinct id is wanted.
            let desc = CustomOpIr::new(
                "connected_components",
                &[img.into_ir()],
                &[
                    labels_ir,
                    area_ir,
                    left_ir,
                    top_ir,
                    right_ir,
                    bottom_ir,
                    max_label_ir,
                ],
            );
            let operation = ConnCompStats::<B>::new(desc.clone(), conn, opts);

            let [labels, area, left, top, right, bottom, max_label] = client
                .register(streams, OperationIr::Custom(desc), operation)
                .try_into()
                .unwrap();

            (
                labels,
                ConnectedStatsPrimitive {
                    area,
                    left,
                    top,
                    right,
                    bottom,
                    max_label,
                },
            )
        }
    }

    impl<B> IntVisionOps for Fusion<B> where B: FusionBackend + IntVisionOps {}
    impl<B> FloatVisionOps for Fusion<B> where B: FusionBackend + FloatVisionOps {}
    impl<B> QVisionOps for Fusion<B> where B: FusionBackend + QVisionOps {}
    impl<B> VisionBackend for Fusion<B> where B: FusionBackend + VisionBackend {}
}