mod hardware_accelerated;
mod prefix_sum;
use burn_cubecl::{
BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement,
ops::numeric::{full_client, zeros_client},
tensor::CubeTensor,
};
use burn_tensor::Shape;
pub use hardware_accelerated::*;
use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};
/// Allocates the initial statistics buffers for a connected-components pass.
///
/// The 2D shape of `l` is flattened to a single dimension of `height * width`
/// elements, and each statistics field gets a tensor of that flattened shape:
///
/// * `area` is zero-filled unless `opts` equals [`ConnectedStatsOptions::none`].
/// * `left` and `top` are filled with `I::max_value()`, `right` and `bottom`
///   are zero-filled, all only when `opts.bounds_enabled` is set.
/// * `max_label` is always a fresh single-element zero tensor.
///
/// A field whose statistic is disabled receives a placeholder tensor that
/// reuses the handle and dtype of `l` instead of a new buffer.
// NOTE(review): presumably `l` is the label tensor and the placeholders are
// never read or written for disabled statistics — confirm against the callers.
pub(crate) fn stats_from_opts<R, F, I, BT>(
    l: CubeTensor<R>,
    opts: ConnectedStatsOptions,
) -> ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    let [height, width] = l.shape.dims();
    let flat = Shape::new([height * width]);

    // Fresh buffer of the flattened shape, initialised to zero.
    let zero_filled =
        || zeros_client::<R>(l.client.clone(), l.device.clone(), flat.clone(), I::dtype());

    // Fresh buffer of the flattened shape, initialised to the largest `I` value.
    let int_max = I::max_value();
    let max_filled =
        || full_client::<R, I>(l.client.clone(), flat.clone(), l.device.clone(), int_max);

    // Stand-in for a disabled statistic: wraps the existing handle of `l`
    // under the flattened shape rather than creating a new buffer.
    let placeholder = || {
        CubeTensor::new_contiguous(
            l.client.clone(),
            l.device.clone(),
            flat.clone(),
            l.handle.clone(),
            l.dtype,
        )
    };

    let any_stats = opts != ConnectedStatsOptions::none();
    let bounds = opts.bounds_enabled;

    ConnectedStatsPrimitive {
        area: if any_stats {
            zero_filled()
        } else {
            placeholder()
        },
        left: if bounds { max_filled() } else { placeholder() },
        top: if bounds { max_filled() } else { placeholder() },
        right: if bounds { zero_filled() } else { placeholder() },
        bottom: if bounds { zero_filled() } else { placeholder() },
        max_label: zeros_client::<R>(
            l.client.clone(),
            l.device.clone(),
            Shape::new([1]),
            I::dtype(),
        ),
    }
}