Skip to main content

burn_cubecl/kernel/mask/
mask_fill.rs

1use burn_backend::{DType, TensorMetadata};
2use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
3
4use crate::{
5    CubeRuntime,
6    kernel::utils::address_type,
7    ops::{max_vector_size_many, numeric::empty_device_dtype},
8    tensor::CubeTensor,
9};
10
11#[cube(launch_unchecked, address_type = "dynamic")]
12fn mask_fill_kernel<T: Numeric, B: Int, N: Size>(
13    input: &LinearView<Vector<T, N>>,
14    mask: &LinearView<Vector<B, N>>,
15    output: &mut LinearView<Vector<T, N>, ReadWrite>,
16    value: InputScalar,
17    #[define(T, B)] _dtypes: [StorageType; 2],
18) {
19    if !output.is_in_bounds(ABSOLUTE_POS) {
20        terminate!();
21    }
22
23    let mask = Vector::cast_from(mask[ABSOLUTE_POS]);
24    let input = input[ABSOLUTE_POS];
25    let value = Vector::new(value.get::<T>());
26
27    output[ABSOLUTE_POS] = select_many(mask, value, input);
28}
29
/// Selects how the mask fill kernel is run.
///
/// # Notes
///
/// Every assertion is expected to have been performed before a strategy is chosen.
#[derive(Clone, Copy, Debug)]
pub enum MaskFillStrategy {
    /// Leave every input untouched.
    Readonly,
    /// Write the result into the input tensor itself.
    Inplace,
}
42
43/// Execute the mask fill kernel with the given strategy.
44pub fn mask_fill<R: CubeRuntime>(
45    input: CubeTensor<R>,
46    mask: CubeTensor<R>,
47    value: InputScalar,
48    strategy: MaskFillStrategy,
49    dtype_bool: DType,
50) -> CubeTensor<R> {
51    let ndims = input.meta.num_dims();
52    let output = match strategy {
53        MaskFillStrategy::Readonly => empty_device_dtype(
54            input.client.clone(),
55            input.device.clone(),
56            input.shape(),
57            input.dtype,
58        ),
59        MaskFillStrategy::Inplace => input.clone(),
60    };
61
62    let vector_size = max_vector_size_many(&[&input, &mask], ndims - 1);
63    let working_units = input.meta.num_elements() / vector_size as usize;
64    let cube_dim = CubeDim::new(&input.client, working_units);
65    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
66
67    let out_arg = match strategy {
68        MaskFillStrategy::Readonly => output.clone().into_linear_view(),
69        MaskFillStrategy::Inplace => output.as_linear_view_alias(0),
70    };
71
72    let at = address_type!(input, mask, output);
73    let mask = mask.into_linear_view_like(&input);
74
75    unsafe {
76        mask_fill_kernel::launch_unchecked(
77            &output.client,
78            cube_count,
79            cube_dim,
80            at,
81            vector_size,
82            input.into_linear_view(),
83            mask,
84            out_arg,
85            value,
86            [output.dtype.into(), dtype_bool.into()],
87        );
88    }
89
90    output
91}