Skip to main content

burn_cubecl/kernel/mask/
mask_where.rs

1use burn_backend::DType;
2use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
3
4use crate::{
5    CubeRuntime,
6    kernel::utils::{address_type, broadcast_shape},
7    ops::{max_vector_size_many, numeric::empty_device_dtype},
8    tensor::CubeTensor,
9};
10
/// Element-wise "mask where" kernel: for every position of `output`, picks between
/// `value` and `input` according to `mask`.
///
/// * `input`, `value` - the two candidate sources, read at the same position as the output.
/// * `mask` - integer-typed mask, cast with `Vector::cast_from` before being used as the
///   selection condition.
/// * `output` - destination view; it is declared `ReadWrite`, and the launcher may alias it
///   with one of the sources.
/// * `_dtypes` - the storage types supplied for `T` and `B` (in that order) at launch time.
#[cube(launch, address_type = "dynamic")]
fn mask_where_kernel<T: Numeric, B: Int, N: Size>(
    input: &LinearView<Vector<T, N>>,
    value: &LinearView<Vector<T, N>>,
    mask: &LinearView<Vector<B, N>>,
    output: &mut LinearView<Vector<T, N>, ReadWrite>,
    #[define(T, B)] _dtypes: [StorageType; 2],
) {
    let pos = ABSOLUTE_POS;
    // Positions that fall outside `output` have nothing to write, so stop early.
    if !output.is_in_bounds(pos) {
        terminate!();
    }

    // NOTE(review): presumably `select_many(cond, a, b)` yields `a` where `cond` holds and
    // `b` otherwise, i.e. `value` where the mask is set and `input` elsewhere — confirm
    // against the cubecl documentation.
    output[pos] = select_many(Vector::cast_from(mask[pos]), value[pos], input[pos]);
}
26
/// Selects where the result of the mask where kernel is written.
///
/// # Notes
///
/// The strategy is applied as given: every assertion about whether a tensor may be
/// reused has to be performed before the strategy is chosen.
#[derive(Debug, Clone, Copy)]
pub enum MaskWhereStrategy {
    /// Write into a newly allocated output; no input tensor is mutated.
    Readonly,
    /// Write the result in place into the lhs tensor (the `input` of `mask_where`).
    InplaceLhs,
    /// Write the result in place into the rhs tensor (the `value` of `mask_where`).
    InplaceRhs,
}
41
42/// Execute the mask where kernel with the given strategy.
43pub fn mask_where<R: CubeRuntime>(
44    input: CubeTensor<R>,
45    mask: CubeTensor<R>,
46    value: CubeTensor<R>,
47    strategy: MaskWhereStrategy,
48    dtype_bool: DType,
49) -> CubeTensor<R> {
50    let vector_size = max_vector_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1);
51
52    let working_units = input.meta.num_elements() / vector_size as usize;
53    let cube_dim = CubeDim::new(&input.client, working_units);
54    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
55
56    let out_shape = broadcast_shape(&[&input, &mask, &value]);
57
58    let output = match strategy {
59        MaskWhereStrategy::Readonly => empty_device_dtype(
60            input.client.clone(),
61            input.device.clone(),
62            out_shape,
63            input.dtype,
64        ),
65        MaskWhereStrategy::InplaceLhs => input.clone(),
66        MaskWhereStrategy::InplaceRhs => value.clone(),
67    };
68
69    let out = match strategy {
70        MaskWhereStrategy::Readonly => output.clone().into_linear_view(),
71        MaskWhereStrategy::InplaceLhs => output.as_linear_view_alias(0),
72        MaskWhereStrategy::InplaceRhs => output.as_linear_view_alias(1),
73    };
74
75    mask_where_kernel::launch(
76        &output.client,
77        cube_count,
78        cube_dim,
79        address_type!(input, value, mask, output),
80        vector_size,
81        input.into_linear_view_like(&output),
82        value.into_linear_view_like(&output),
83        mask.into_linear_view_like(&output),
84        out,
85        [output.dtype.into(), dtype_bool.into()],
86    );
87
88    output
89}