Skip to main content

burn_vision/backends/cube/connected_components/
mod.rs

1mod hardware_accelerated;
2
3/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops
4/// to really use it in a general case. Needs more work to use as a normal tensor method.
5mod prefix_sum;
6
7use burn_cubecl::{
8    BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement,
9    ops::numeric::{full_client, zeros_client},
10    tensor::CubeTensor,
11};
12use burn_tensor::Shape;
13pub use hardware_accelerated::*;
14
15use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};
16
17pub(crate) fn stats_from_opts<R, F, I, BT>(
18    l: CubeTensor<R>,
19    opts: ConnectedStatsOptions,
20) -> ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>
21where
22    R: CubeRuntime,
23    F: FloatElement,
24    I: IntElement,
25    BT: BoolElement,
26{
27    let [height, width] = l.shape.dims();
28    let shape = Shape::new([height * width]);
29    let zeros = || {
30        zeros_client::<R>(
31            l.client.clone(),
32            l.device.clone(),
33            shape.clone(),
34            I::dtype(),
35        )
36    };
37    let max = I::max_value();
38    let max = || full_client::<R, I>(l.client.clone(), shape.clone(), l.device.clone(), max);
39    let dummy = || {
40        CubeTensor::new_contiguous(
41            l.client.clone(),
42            l.device.clone(),
43            shape.clone(),
44            l.handle.clone(),
45            l.dtype,
46        )
47    };
48    ConnectedStatsPrimitive {
49        area: (opts != ConnectedStatsOptions::none())
50            .then(zeros)
51            .unwrap_or_else(dummy),
52        left: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
53        top: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
54        right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
55        bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
56        max_label: zeros_client::<R>(
57            l.client.clone(),
58            l.device.clone(),
59            Shape::new([1]),
60            I::dtype(),
61        ),
62    }
63}