// burn_jit/kernel/mask/mask_where.rs

1use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*};
2
3use crate::{
4    element::JitElement,
5    ops::{max_vectorization, numeric::empty_device},
6    tensor::JitTensor,
7    BoolElement, JitRuntime,
8};
9
/// Mask-where kernel that writes its result into a separate `output` tensor.
///
/// For every output position, the result is taken from `value` where the mask is set and from
/// `input` otherwise. `input`, `mask` and `value` are addressed through
/// `index_offset_with_layout` relative to `output`'s layout, so their strides may differ from
/// `output`'s.
///
/// NOTE(review): presumably `index_offset_with_layout` also broadcasts size-1 dimensions of
/// `mask`/`value` onto `output`'s shape, and its trailing `true` is a comptime unroll flag —
/// confirm against the cubecl-linalg definition.
#[cube(launch)]
fn mask_where_readonly_kernel<T: CubePrimitive, B: Int>(
    input: &Tensor<Line<T>>,
    mask: &Tensor<Line<B>>,
    value: &Tensor<Line<T>>,
    output: &mut Tensor<Line<T>>,
    #[comptime] rank: u32,
) {
    // Invocations past the end of the output have nothing to write.
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    // Translate the output position into an offset within each (possibly differently laid
    // out) source tensor, over dimensions `0..rank`.
    let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true);
    let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true);
    let index_value = index_offset_with_layout(value, output, ABSOLUTE_POS, 0, rank, true);
    // The mask is stored as an integer line; cast it to the condition line `select_many` takes.
    let mask = Line::cast_from(mask[index_mask]);

    // Lane-wise: take `value` where the mask is set, otherwise keep `input`.
    output[ABSOLUTE_POS] = select_many(mask, value[index_value], input[index_input]);
}
29
/// In-place mask-where kernel: overwrites `input` with `value` wherever the condition holds.
///
/// The condition is `mask != reverse`.
/// - With `reverse` encoding `false`, positions where the mask is set receive `value`.
/// - With `reverse` encoding `true`, positions where the mask is *not* set receive `value`.
///
/// The host code uses the reversed form when it swaps the two tensors so that the result is
/// written into the original `value` tensor instead of the original `input` tensor.
///
/// NOTE(review): this comparison assumes mask entries are exactly the encodings produced by
/// `BoolElement::new_bool`. The readonly kernel uses a cast instead, which may differ for
/// other non-zero mask values — confirm mask tensors are always canonical.
#[cube(launch)]
fn mask_where_inplace_kernel<T: CubePrimitive, B: Int>(
    input: &mut Tensor<Line<T>>,
    mask: &Tensor<Line<B>>,
    value: &Tensor<Line<T>>,
    reverse: B,
    #[comptime] rank: u32,
) {
    // Invocations past the end of the tensor being overwritten have nothing to do.
    if ABSOLUTE_POS >= input.len() {
        return;
    }

    // `input` is both a source and the destination, so only `mask` and `value` need their
    // offsets translated into `input`'s layout.
    let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true);
    let index_value = index_offset_with_layout(value, input, ABSOLUTE_POS, 0, rank, true);

    input[ABSOLUTE_POS] = select(
        // Comparing against `reverse` lets one kernel serve both operand orders.
        mask[index_mask] != Line::new(reverse),
        value[index_value],
        input[ABSOLUTE_POS],
    );
}
51
/// Selects how the mask-where kernel produces its result.
///
/// # Notes
///
/// Any shape or layout validation must happen before a strategy is picked; the strategies
/// themselves perform no checks.
#[derive(Clone, Copy, Debug)]
pub enum MaskWhereStrategy {
    /// Allocate a new output tensor and leave every input untouched.
    Readonly,
    /// Write the result into the left-hand (`input`) tensor in place.
    InplaceLhs,
    /// Write the result into the right-hand (`value`) tensor in place.
    InplaceRhs,
}
66
67/// Execute the mask where kernel with the given strategy.
68pub fn mask_where<R: JitRuntime, E: JitElement, BT: BoolElement>(
69    input: JitTensor<R>,
70    mask: JitTensor<R>,
71    value: JitTensor<R>,
72    strategy: MaskWhereStrategy,
73) -> JitTensor<R> {
74    match strategy {
75        MaskWhereStrategy::Readonly => mask_where_readonly::<R, E, BT>(input, mask, value),
76        MaskWhereStrategy::InplaceLhs => mask_where_inplace::<R, E, BT>(input, mask, value, false),
77        MaskWhereStrategy::InplaceRhs => mask_where_inplace::<R, E, BT>(value, mask, input, true),
78    }
79}
80
81fn mask_where_readonly<R: JitRuntime, EI: JitElement, EM: BoolElement>(
82    input: JitTensor<R>,
83    mask: JitTensor<R>,
84    value: JitTensor<R>,
85) -> JitTensor<R> {
86    let ndims = input.shape.num_dims();
87    let output = empty_device::<R, EI>(
88        input.client.clone(),
89        input.device.clone(),
90        input.shape.clone(),
91    );
92
93    let cube_dim = CubeDim::default();
94    let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
95    let vectorization = max_vectorization(&input);
96
97    mask_where_readonly_kernel::launch::<EI, EM, R>(
98        &input.client,
99        cube_count,
100        cube_dim,
101        input.as_tensor_arg::<EI>(vectorization),
102        mask.as_tensor_arg::<EM>(vectorization),
103        value.as_tensor_arg::<EI>(vectorization),
104        output.as_tensor_arg::<EI>(vectorization),
105        ndims as u32,
106    );
107
108    output
109}
110
111fn mask_where_inplace<R: JitRuntime, EI: JitElement, EM: BoolElement>(
112    input: JitTensor<R>,
113    mask: JitTensor<R>,
114    value: JitTensor<R>,
115    reverse: bool,
116) -> JitTensor<R> {
117    let ndims = input.shape.num_dims();
118    let cube_dim = CubeDim::default();
119    let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
120    let vectorization = max_vectorization(&input);
121
122    mask_where_inplace_kernel::launch::<EI, EM, R>(
123        &input.client,
124        cube_count,
125        cube_dim,
126        input.as_tensor_arg::<EI>(vectorization),
127        mask.as_tensor_arg::<EM>(vectorization),
128        value.as_tensor_arg::<EI>(vectorization),
129        ScalarArg::new(EM::new_bool(reverse)),
130        ndims as u32,
131    );
132
133    input
134}