burn_vision/backends/cube/ops.rs

1use crate::{
2    BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
3    IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
4};
5use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
6
7use burn_tensor::{
8    Element,
9    ops::{BoolTensor, IntTensor},
10};
11
12use super::connected_components::hardware_accelerated;
13
14impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
15where
16    R: CubeRuntime,
17    F: FloatElement,
18    I: IntElement,
19    BT: BoolElement,
20{
21    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
22        hardware_accelerated::<R, F, I, BT>(
23            img.clone(),
24            ConnectedStatsOptions::none(),
25            connectivity,
26        )
27        .map(|it| it.0)
28        .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
29    }
30
31    fn connected_components_with_stats(
32        img: BoolTensor<Self>,
33        connectivity: Connectivity,
34        opts: ConnectedStatsOptions,
35    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
36        hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
37            cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
38        })
39    }
40}
41
42impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
43where
44    R: CubeRuntime,
45    F: FloatElement,
46    I: IntElement,
47    BT: BoolElement,
48{
49}
50impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
51where
52    R: CubeRuntime,
53    F: FloatElement,
54    I: IntElement,
55    BT: BoolElement,
56{
57}
58impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
59where
60    R: CubeRuntime,
61    F: FloatElement,
62    I: IntElement,
63    BT: BoolElement,
64{
65}
66impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
67where
68    R: CubeRuntime,
69    F: FloatElement,
70    I: IntElement,
71    BT: BoolElement,
72{
73}
74
/// Vision ops for the fusion backend.
///
/// Each op is recorded on the fusion stream as an `OperationIr::Custom`. When the stream
/// executes, the recorded operation forwards the call to the wrapped backend `B`.
#[cfg(feature = "fusion")]
mod fusion {
    use super::*;
    use burn_fusion::{
        Fusion, FusionBackend, FusionRuntime, client::FusionClient, stream::Operation,
    };
    use burn_ir::{CustomOpIr, HandleContainer, OperationIr};

    impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {
        /// Queues a `connected_components` custom op.
        ///
        /// Returns the `[height, width]` label tensor. Its contents are produced by the inner
        /// backend when the stream executes.
        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
            // Only dims 0 and 1 are read; assumes `img` is a 2D [height, width] mask — TODO confirm.
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            // Deferred op: runs `B1::connected_components` on the inner backend when the
            // fusion stream executes it.
            #[derive(derive_new::new)]
            struct ConnComp<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {
                fn execute(
                    self: Box<Self>,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // Input/output arity and order match the `CustomOpIr` built below.
                    let ([img], [labels]) = self.desc.consume();
                    let input = handles.get_bool_tensor::<B1>(&img);
                    let output = B1::connected_components(input, self.conn);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                }
            }

            let stream = img.stream;
            let out = client.tensor_uninitialized(vec![height, width], B::IntElem::dtype());

            let desc =
                CustomOpIr::new("connected_components", &[img.into_ir()], &[out.to_ir_out()]);
            client.register(
                vec![stream],
                OperationIr::Custom(desc.clone()),
                ConnComp::<B>::new(desc, conn),
            );

            out
        }

        /// Queues a `connected_components_with_stats` custom op.
        ///
        /// Returns the `[height, width]` label tensor and the statistics primitive. All of these
        /// are filled in by the inner backend when the stream executes.
        fn connected_components_with_stats(
            img: BoolTensor<Self>,
            conn: Connectivity,
            opts: ConnectedStatsOptions,
        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
            // Only dims 0 and 1 are read; assumes `img` is a 2D [height, width] mask — TODO confirm.
            let height = img.shape[0];
            let width = img.shape[1];
            let client = img.client.clone();

            // Deferred op: runs `B1::connected_components_with_stats` on the inner backend
            // when the fusion stream executes it.
            #[derive(derive_new::new)]
            struct ConnCompStats<B> {
                desc: CustomOpIr,
                conn: Connectivity,
                opts: ConnectedStatsOptions,
                _b: core::marker::PhantomData<B>,
            }

            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnCompStats<B1> {
                fn execute(
                    self: Box<Self>,
                    handles: &mut HandleContainer<
                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,
                    >,
                ) {
                    // Output order must match the `CustomOpIr` outputs built below.
                    let ([img], [labels, area, left, top, right, bottom, max_label]) =
                        self.desc.consume();
                    let input = handles.get_bool_tensor::<B1>(&img);
                    let (output, stats) =
                        B1::connected_components_with_stats(input, self.conn, self.opts);

                    handles.register_int_tensor::<B1>(&labels.id, output);
                    handles.register_int_tensor::<B1>(&area.id, stats.area);
                    handles.register_int_tensor::<B1>(&left.id, stats.left);
                    handles.register_int_tensor::<B1>(&top.id, stats.top);
                    handles.register_int_tensor::<B1>(&right.id, stats.right);
                    handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);
                    handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);
                }
            }

            let stream = img.stream;
            let out = client.tensor_uninitialized(vec![height, width], B::IntElem::dtype());
            // Each per-component stat gets one slot per pixel.
            // NOTE(review): these placeholder shapes must agree with what the inner backend's
            // `connected_components_with_stats` actually returns — confirm.
            let stat_buffer =
                || client.tensor_uninitialized(vec![height * width], B::IntElem::dtype());
            let area = stat_buffer();
            let left = stat_buffer();
            let top = stat_buffer();
            let right = stat_buffer();
            let bottom = stat_buffer();
            let max_label = client.tensor_uninitialized(vec![1], B::IntElem::dtype());

            // Registered under its own id so it is distinguishable from the plain
            // `connected_components` op when inspecting the fusion stream.
            let desc = CustomOpIr::new(
                "connected_components_with_stats",
                &[img.into_ir()],
                &[
                    out.to_ir_out(),
                    area.to_ir_out(),
                    left.to_ir_out(),
                    top.to_ir_out(),
                    right.to_ir_out(),
                    bottom.to_ir_out(),
                    max_label.to_ir_out(),
                ],
            );
            client.register(
                vec![stream],
                OperationIr::Custom(desc.clone()),
                ConnCompStats::<B>::new(desc, conn, opts),
            );

            let stats = ConnectedStatsPrimitive {
                area,
                left,
                top,
                right,
                bottom,
                max_label,
            };
            (out, stats)
        }
    }

    impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}
    impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}
    impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}
    impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}
}