burn_cubecl/kernel/mask/
mask_fill.rs

1use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*};
2
3use crate::{
4    BoolElement, CubeRuntime,
5    element::CubeElement,
6    ops::{max_line_size_many, numeric::empty_device},
7    tensor::CubeTensor,
8};
9
10#[cube(launch_unchecked)]
11fn mask_fill_readonly_kernel<T: Numeric, B: Int>(
12    input: &Tensor<Line<T>>,
13    mask: &Tensor<Line<B>>,
14    output: &mut Tensor<Line<T>>,
15    value: T,
16    #[comptime] rank: u32,
17) {
18    let pos = ABSOLUTE_POS;
19
20    if pos >= output.len() {
21        terminate!();
22    }
23
24    let index_input = index_offset_with_layout(input, output, pos, 0, rank, false);
25    let index_mask = index_offset_with_layout(mask, output, pos, 0, rank, false);
26
27    let mask = Line::cast_from(mask[index_mask]);
28    let input = input[index_input];
29    let value = Line::new(value);
30
31    output[pos] = select_many(mask, value, input);
32}
33
34#[cube(launch_unchecked)]
35fn mask_fill_inplace_kernel<T: Numeric, B: Int>(
36    input: &mut Tensor<Line<T>>,
37    mask: &Tensor<Line<B>>,
38    value: T,
39    #[comptime] rank: u32,
40) {
41    let pos = ABSOLUTE_POS;
42
43    if pos >= input.len() {
44        terminate!();
45    }
46
47    let index_mask = index_offset_with_layout(mask, input, pos, 0, rank, false);
48    let mask = Line::cast_from(mask[index_mask]);
49    let value = Line::new(value);
50
51    input[pos] = select_many(mask, value, input[pos]);
52}
53
/// Define how to run the mask fill kernel.
///
/// # Notes
///
/// All assertions should be done before choosing the strategy.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MaskFillStrategy {
    /// Don't mutate any input: the result is written to a newly allocated tensor.
    Readonly,
    /// Reuse the input tensor inplace: the kernel writes into the input's buffer, and that
    /// same tensor is returned as the result.
    Inplace,
}
66
67/// Execute the mask fill kernel with the given strategy.
68pub fn mask_fill<R: CubeRuntime, E: CubeElement, BT: BoolElement>(
69    input: CubeTensor<R>,
70    mask: CubeTensor<R>,
71    value: E,
72    strategy: MaskFillStrategy,
73) -> CubeTensor<R> {
74    match strategy {
75        MaskFillStrategy::Readonly => mask_fill_readonly::<R, E, BT>(input, mask, value),
76        MaskFillStrategy::Inplace => mask_fill_inplace::<R, E, BT>(input, mask, value),
77    }
78}
79
80fn mask_fill_readonly<R: CubeRuntime, EI: CubeElement, EM: BoolElement>(
81    input: CubeTensor<R>,
82    mask: CubeTensor<R>,
83    value: EI,
84) -> CubeTensor<R> {
85    let ndims = input.shape.num_dims();
86    let output = empty_device::<R, EI>(
87        input.client.clone(),
88        input.device.clone(),
89        input.shape.clone(),
90    );
91
92    let cube_dim = CubeDim::default();
93    let vectorization = max_line_size_many(&[&input, &mask], ndims - 1);
94    let cube_count = calculate_cube_count_elemwise(
95        input.shape.num_elements() / vectorization as usize,
96        cube_dim,
97    );
98
99    unsafe {
100        mask_fill_readonly_kernel::launch_unchecked::<EI, EM, R>(
101            &input.client,
102            cube_count,
103            cube_dim,
104            input.as_tensor_arg::<EI>(vectorization),
105            mask.as_tensor_arg::<EM>(vectorization),
106            output.as_tensor_arg::<EI>(vectorization),
107            ScalarArg::new(value),
108            ndims as u32,
109        );
110    }
111
112    output
113}
114
115fn mask_fill_inplace<R: CubeRuntime, EI: CubeElement, EM: BoolElement>(
116    input: CubeTensor<R>,
117    mask: CubeTensor<R>,
118    value: EI,
119) -> CubeTensor<R> {
120    let ndims = input.shape.num_dims();
121    let cube_dim = CubeDim::default();
122    let vectorization = max_line_size_many(&[&input, &mask], ndims - 1);
123    let cube_count = calculate_cube_count_elemwise(
124        input.shape.num_elements() / vectorization as usize,
125        cube_dim,
126    );
127
128    unsafe {
129        mask_fill_inplace_kernel::launch_unchecked::<EI, EM, R>(
130            &input.client,
131            cube_count,
132            cube_dim,
133            input.as_tensor_arg::<EI>(vectorization),
134            mask.as_tensor_arg::<EM>(vectorization),
135            ScalarArg::new(value),
136            ndims as u32,
137        );
138    }
139
140    input
141}