use burn_backend::DType;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
use crate::{
CubeRuntime,
kernel::utils::{
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
},
ops::{max_line_size_many, numeric::empty_device_dtype},
tensor::CubeTensor,
};
/// Element-wise masked select over line-vectorized linear views.
///
/// Each unit handles the line at `ABSOLUTE_POS`. The mask line is cast to a
/// boolean line and passed to `select_many` as the condition. `value` is the
/// operand chosen where the condition holds; `input` is chosen otherwise.
/// NOTE(review): assumes a nonzero mask element casts to `true` — confirm
/// against cubecl's `Int` -> `bool` cast semantics.
///
/// `_dtypes` supplies, in order, the runtime storage types bound to `T` and
/// `B` through `#[define(T, B)]`.
#[cube(launch, address_type = "dynamic")]
fn mask_where_kernel<T: Numeric, B: Int>(
input: &LinearView<Line<T>>,
value: &LinearView<Line<T>>,
mask: &LinearView<Line<B>>,
output: &mut LinearView<Line<T>, ReadWrite>,
#[define(T, B)] _dtypes: [StorageType; 2],
) {
let pos = ABSOLUTE_POS;
// Units whose position falls outside the output exit without writing.
if !output.is_in_bounds(pos) {
terminate!();
}
// Condition first, then the "true" operand (`value`), then the "false" operand (`input`).
output[pos] = select_many(Line::cast_from(mask[pos]), value[pos], input[pos]);
}
/// Selects where `mask_where` writes its result.
///
/// The in-place variants avoid an allocation by reusing the buffer of one of
/// the operands as the output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MaskWhereStrategy {
    /// Allocate a fresh output tensor; neither operand's buffer is overwritten.
    Readonly,
    /// Reuse the `input` (left-hand) tensor's buffer as the output.
    InplaceLhs,
    /// Reuse the `value` (right-hand) tensor's buffer as the output.
    InplaceRhs,
}
/// Launches `mask_where_kernel`: picks `value` where `mask` is set and `input` elsewhere.
///
/// The output shape is the broadcast of `input`, `mask` and `value`, computed
/// by `broadcast_shape`. Each operand is read through a view broadcast onto
/// the output. Where the result is written depends on `strategy`:
/// - `Readonly`: a new buffer with dtype `input.dtype` is allocated.
/// - `InplaceLhs`: the result overwrites `input`'s buffer.
/// - `InplaceRhs`: the result overwrites `value`'s buffer.
///
/// `dtype_bool` is the storage type the kernel uses to read `mask`.
///
/// NOTE(review): the in-place strategies presumably require the reused tensor
/// to already have the broadcast output shape — confirm at call sites.
///
/// # Panics
///
/// Panics if the kernel launch returns an error.
pub fn mask_where<R: CubeRuntime>(
    input: CubeTensor<R>,
    mask: CubeTensor<R>,
    value: CubeTensor<R>,
    strategy: MaskWhereStrategy,
    dtype_bool: DType,
) -> CubeTensor<R> {
    let line_size = max_line_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1);
    let out_shape = broadcast_shape(&[&input, &mask, &value]);

    let output = match strategy {
        MaskWhereStrategy::Readonly => empty_device_dtype(
            input.client.clone(),
            input.device.clone(),
            out_shape,
            input.dtype,
        ),
        MaskWhereStrategy::InplaceLhs => input.clone(),
        MaskWhereStrategy::InplaceRhs => value.clone(),
    };

    // Size the launch from the output, not from `input`. When `input` is the
    // broadcast operand it has fewer elements than the output, and sizing
    // from it would leave the tail of the output unwritten.
    let working_units = output.meta.num_elements() / line_size as usize;
    let cube_dim = CubeDim::new(&input.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);

    // In-place outputs alias the operand they reuse (argument 0 = input, 1 = value).
    let out = match strategy {
        MaskWhereStrategy::Readonly => linear_view(&output, line_size),
        MaskWhereStrategy::InplaceLhs => linear_view_alias(&output, line_size, 0),
        MaskWhereStrategy::InplaceRhs => linear_view_alias(&output, line_size, 1),
    };

    mask_where_kernel::launch(
        &input.client,
        cube_count,
        cube_dim,
        address_type!(input, value, mask, output),
        linear_view_ref(&input, &output, line_size),
        linear_view_ref(&value, &output, line_size),
        linear_view_ref(&mask, &output, line_size),
        out,
        // Order must match `#[define(T, B)]` on the kernel: element type, then mask type.
        [output.dtype.into(), dtype_bool.into()],
    )
    .expect("Kernel to never fail");

    output
}