burn_vision/backends/cube/connected_components/
mod.rs1mod hardware_accelerated;
2
3mod prefix_sum;
6
7use burn_cubecl::{
8 BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement,
9 ops::numeric::{full_client, zeros_client},
10 tensor::CubeTensor,
11};
12use burn_tensor::Shape;
13pub use hardware_accelerated::*;
14
15use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};
16
17pub(crate) fn stats_from_opts<R, F, I, BT>(
18 l: CubeTensor<R>,
19 opts: ConnectedStatsOptions,
20) -> ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>
21where
22 R: CubeRuntime,
23 F: FloatElement,
24 I: IntElement,
25 BT: BoolElement,
26{
27 let [height, width] = l.shape.dims();
28 let shape = Shape::new([height * width]);
29 let zeros = || {
30 zeros_client::<R>(
31 l.client.clone(),
32 l.device.clone(),
33 shape.clone(),
34 I::dtype(),
35 )
36 };
37 let max = I::max_value();
38 let max = || full_client::<R, I>(l.client.clone(), shape.clone(), l.device.clone(), max);
39 let dummy = || {
40 CubeTensor::new_contiguous(
41 l.client.clone(),
42 l.device.clone(),
43 shape.clone(),
44 l.handle.clone(),
45 l.dtype,
46 )
47 };
48 ConnectedStatsPrimitive {
49 area: (opts != ConnectedStatsOptions::none())
50 .then(zeros)
51 .unwrap_or_else(dummy),
52 left: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
53 top: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
54 right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
55 bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
56 max_label: zeros_client::<R>(
57 l.client.clone(),
58 l.device.clone(),
59 Shape::new([1]),
60 I::dtype(),
61 ),
62 }
63}