burn_vision/backends/cube/
ops.rs

1use crate::{
2    BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
3    IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
4};
5use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
6
7use burn_tensor::{
8    Element,
9    ops::{BoolTensor, IntTensor},
10};
11
12use super::connected_components::hardware_accelerated;
13
14impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
15where
16    R: CubeRuntime,
17    F: FloatElement,
18    I: IntElement,
19    BT: BoolElement,
20{
21    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
22        hardware_accelerated::<R, F, I, BT>(
23            img.clone(),
24            ConnectedStatsOptions::none(),
25            connectivity,
26        )
27        .map(|it| it.0)
28        .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
29    }
30
31    fn connected_components_with_stats(
32        img: BoolTensor<Self>,
33        connectivity: Connectivity,
34        opts: ConnectedStatsOptions,
35    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
36        hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
37            cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
38        })
39    }
40}
41
42impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
43where
44    R: CubeRuntime,
45    F: FloatElement,
46    I: IntElement,
47    BT: BoolElement,
48{
49}
50impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
51where
52    R: CubeRuntime,
53    F: FloatElement,
54    I: IntElement,
55    BT: BoolElement,
56{
57}
58impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
59where
60    R: CubeRuntime,
61    F: FloatElement,
62    I: IntElement,
63    BT: BoolElement,
64{
65}
66impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
67where
68    R: CubeRuntime,
69    F: FloatElement,
70    I: IntElement,
71    BT: BoolElement,
72{
73}
74
#[cfg(feature = "fusion")]
mod fusion {
    //! Vision ops for [`Fusion`]-wrapped backends.
    //!
    //! Each op allocates uninitialized output tensors, registers a custom operation on
    //! the fusion stream, and returns the outputs immediately. The wrapped backend `B`
    //! performs the actual computation when the registered operation's `execute` runs.

    use super::*;
    use burn_fusion::{
        Fusion, FusionBackend, FusionRuntime,
        client::FusionClient,
        stream::{Operation, OperationStreams},
    };
    use burn_ir::{CustomOpIr, HandleContainer, OperationIr};

    /// Custom-op id for labelling without statistics (1 input, 1 output).
    const CONNECTED_COMPONENTS_ID: &str = "connected_components";

    /// Custom-op id for labelling with statistics (1 input, 7 outputs).
    ///
    /// Kept distinct from [`CONNECTED_COMPONENTS_ID`] because the two ops have
    /// different output arities and should be distinguishable in the recorded IR.
    const CONNECTED_COMPONENTS_WITH_STATS_ID: &str = "connected_components_with_stats";

    impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {
        /// Records a deferred [`BoolVisionOps::connected_components`] call.
        ///
        /// Returns an int tensor of shape `[height, width]`, where `height` and `width`
        /// are `img.shape[0]` and `img.shape[1]`. Its contents are produced by
        /// `B::connected_components` when the stream executes the registered op.
        ///
        /// # Panics
        ///
        /// Panics (via indexing) if `img.shape` has fewer than two entries.
        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            /// Deferred op that labels the materialised input using backend `B`.
            #[derive(derive_new::new, Clone, Debug)]
            struct ConnComp<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {
                fn execute(
                    &self,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // Arity (1 input, 1 output) mirrors the `CustomOpIr` built below.
                    let ([img], [labels]) = self.desc.as_fixed();
                    let input = handles.get_bool_tensor::<B1>(img);
                    let output = B1::connected_components(input, self.conn);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                }
            }

            let mut streams = OperationStreams::default();
            streams.tensor(&img);
            let out = client.tensor_uninitialized(vec![height, width], B::IntElem::dtype());

            let desc = CustomOpIr::new(
                CONNECTED_COMPONENTS_ID,
                &[img.into_ir()],
                &[out.to_ir_out()],
            );
            client.register(
                streams,
                OperationIr::Custom(desc.clone()),
                ConnComp::<B>::new(desc, conn),
            );

            out
        }

        /// Records a deferred [`BoolVisionOps::connected_components_with_stats`] call.
        ///
        /// Returns the `[height, width]` label tensor together with the statistics
        /// tensors. All outputs are allocated uninitialized here and filled by
        /// `B::connected_components_with_stats` when the stream executes the op.
        ///
        /// # Panics
        ///
        /// Panics (via indexing) if `img.shape` has fewer than two entries.
        fn connected_components_with_stats(
            img: BoolTensor<Self>,
            conn: Connectivity,
            opts: ConnectedStatsOptions,
        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            /// Deferred op that labels the input and gathers per-label statistics
            /// using backend `B`.
            #[derive(derive_new::new, Clone, Debug)]
            struct ConnCompStats<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                opts: ConnectedStatsOptions,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnCompStats<B1> {
                fn execute(
                    &self,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // Output order must match the `outputs` slice given to
                    // `CustomOpIr::new` below.
                    let ([img], [labels, area, left, top, right, bottom, max_label]) =
                        self.desc.as_fixed();
                    let input = handles.get_bool_tensor::<B1>(img);
                    let (output, stats) =
                        B1::connected_components_with_stats(input, self.conn, self.opts);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                    handles.register_int_tensor::<B1>(&area.id, stats.area);
                    handles.register_int_tensor::<B1>(&left.id, stats.left);
                    handles.register_int_tensor::<B1>(&top.id, stats.top);
                    handles.register_int_tensor::<B1>(&right.id, stats.right);
                    handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);
                    handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);
                }
            }

            let mut streams = OperationStreams::default();
            streams.tensor(&img);

            let out = client.tensor_uninitialized(vec![height, width], B::IntElem::dtype());
            // Each per-label statistic is given `height * width` slots.
            // NOTE(review): presumably an upper bound on the label count; confirm it
            // matches the shapes `B` actually returns (including any CPU fallback path).
            let per_label_buffer =
                || client.tensor_uninitialized(vec![height * width], B::IntElem::dtype());
            let area = per_label_buffer();
            let left = per_label_buffer();
            let top = per_label_buffer();
            let right = per_label_buffer();
            let bottom = per_label_buffer();
            let max_label = client.tensor_uninitialized(vec![1], B::IntElem::dtype());

            let desc = CustomOpIr::new(
                CONNECTED_COMPONENTS_WITH_STATS_ID,
                &[img.into_ir()],
                &[
                    out.to_ir_out(),
                    area.to_ir_out(),
                    left.to_ir_out(),
                    top.to_ir_out(),
                    right.to_ir_out(),
                    bottom.to_ir_out(),
                    max_label.to_ir_out(),
                ],
            );
            client.register(
                streams,
                OperationIr::Custom(desc.clone()),
                ConnCompStats::<B>::new(desc, conn, opts),
            );

            let stats = ConnectedStatsPrimitive {
                area,
                left,
                top,
                right,
                bottom,
                max_label,
            };
            (out, stats)
        }
    }

    // The remaining vision traits override nothing; they simply forward the marker
    // impls from the wrapped backend `B` to `Fusion<B>`.
    impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}
    impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}
    impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}
    impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}
}