burn_jit/kernel/mask/
mask_fill.rs

1use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*};
2
3use crate::{
4    element::JitElement,
5    ops::{max_vectorization, numeric::empty_device},
6    tensor::JitTensor,
7    BoolElement, JitRuntime,
8};
9
10#[cube(launch)]
11fn mask_fill_readonly_kernel<T: Numeric, B: Int>(
12    input: &Tensor<Line<T>>,
13    mask: &Tensor<Line<B>>,
14    output: &mut Tensor<Line<T>>,
15    value: T,
16    #[comptime] rank: u32,
17) {
18    if ABSOLUTE_POS >= output.len() {
19        return;
20    }
21
22    let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true);
23    let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true);
24
25    let mask = Line::cast_from(mask[index_mask]);
26
27    output[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[index_input]);
28}
29
30#[cube(launch)]
31fn mask_fill_inplace_kernel<T: Numeric, B: Int>(
32    input: &mut Tensor<Line<T>>,
33    mask: &Tensor<Line<B>>,
34    value: T,
35    #[comptime] rank: u32,
36) {
37    if ABSOLUTE_POS >= input.len() {
38        return;
39    }
40
41    let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true);
42    let mask = Line::cast_from(mask[index_mask]);
43
44    input[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[ABSOLUTE_POS]);
45}
46
/// Selects which variant of the mask fill kernel gets launched.
///
/// # Notes
///
/// Every assertion is expected to be performed before a strategy is chosen.
#[derive(Debug, Clone, Copy)]
pub enum MaskFillStrategy {
    /// Leave every input unmodified.
    Readonly,
    /// Write the result back into the input tensor.
    Inplace,
}
59
60/// Execute the mask fill kernel with the given strategy.
61pub fn mask_fill<R: JitRuntime, E: JitElement, BT: BoolElement>(
62    input: JitTensor<R>,
63    mask: JitTensor<R>,
64    value: E,
65    strategy: MaskFillStrategy,
66) -> JitTensor<R> {
67    match strategy {
68        MaskFillStrategy::Readonly => mask_fill_readonly::<R, E, BT>(input, mask, value),
69        MaskFillStrategy::Inplace => mask_fill_inplace::<R, E, BT>(input, mask, value),
70    }
71}
72
73fn mask_fill_readonly<R: JitRuntime, EI: JitElement, EM: BoolElement>(
74    input: JitTensor<R>,
75    mask: JitTensor<R>,
76    value: EI,
77) -> JitTensor<R> {
78    let ndims = input.shape.num_dims();
79    let output = empty_device::<R, EI>(
80        input.client.clone(),
81        input.device.clone(),
82        input.shape.clone(),
83    );
84
85    let cube_dim = CubeDim::default();
86    let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
87    let vectorization = max_vectorization(&input);
88
89    mask_fill_readonly_kernel::launch::<EI, EM, R>(
90        &input.client,
91        cube_count,
92        cube_dim,
93        input.as_tensor_arg::<EI>(vectorization),
94        mask.as_tensor_arg::<EM>(vectorization),
95        output.as_tensor_arg::<EI>(vectorization),
96        ScalarArg::new(value),
97        ndims as u32,
98    );
99
100    output
101}
102
103fn mask_fill_inplace<R: JitRuntime, EI: JitElement, EM: BoolElement>(
104    input: JitTensor<R>,
105    mask: JitTensor<R>,
106    value: EI,
107) -> JitTensor<R> {
108    let ndims = input.shape.num_dims();
109    let cube_dim = CubeDim::default();
110    let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
111    let vectorization = max_vectorization(&input);
112
113    mask_fill_inplace_kernel::launch::<EI, EM, R>(
114        &input.client,
115        cube_count,
116        cube_dim,
117        input.as_tensor_arg::<EI>(vectorization),
118        mask.as_tensor_arg::<EM>(vectorization),
119        ScalarArg::new(value),
120        ndims as u32,
121    );
122
123    input
124}