burn_cubecl/kernel/mask/
mask_where.rs1use burn_backend::DType;
2use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
3
4use crate::{
5 CubeRuntime,
6 kernel::utils::{broadcast_shape, linear_view, linear_view_alias, linear_view_ref},
7 ops::{max_line_size_many, numeric::empty_device_dtype},
8 tensor::CubeTensor,
9};
10
11#[cube(launch)]
12fn mask_where_kernel<T: Numeric, B: Int>(
13 input: &LinearView<Line<T>>,
14 value: &LinearView<Line<T>>,
15 mask: &LinearView<Line<B>>,
16 output: &mut LinearView<Line<T>, ReadWrite>,
17 #[define(T, B)] _dtypes: [StorageType; 2],
18) {
19 let pos = ABSOLUTE_POS;
20 if !output.is_in_bounds(pos) {
21 terminate!();
22 }
23
24 output[pos] = select_many(Line::cast_from(mask[pos]), value[pos], input[pos]);
25}
26
/// Selects where [`mask_where`] writes its result.
///
/// The strategy only controls which buffer backs the output; it does not
/// validate that the chosen buffer is suitable for in-place use, so that
/// decision has to be made by the caller before picking a variant.
#[derive(Debug, Clone, Copy)]
pub enum MaskWhereStrategy {
    /// Leave every operand untouched and write into a freshly allocated tensor.
    Readonly,
    /// Write the result into the buffer of the `input` (left-hand side) tensor.
    InplaceLhs,
    /// Write the result into the buffer of the `value` (right-hand side) tensor.
    InplaceRhs,
}
41
42pub fn mask_where<R: CubeRuntime>(
44 input: CubeTensor<R>,
45 mask: CubeTensor<R>,
46 value: CubeTensor<R>,
47 strategy: MaskWhereStrategy,
48 dtype_bool: DType,
49) -> CubeTensor<R> {
50 let line_size = max_line_size_many(&[&input, &mask, &value], input.shape.num_dims() - 1);
51
52 let working_units = input.shape.num_elements() / line_size as usize;
53 let cube_dim = CubeDim::new(&input.client, working_units);
54 let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
55
56 let out_shape = broadcast_shape(&[&input, &mask, &value]);
57
58 let output = match strategy {
59 MaskWhereStrategy::Readonly => empty_device_dtype(
60 input.client.clone(),
61 input.device.clone(),
62 out_shape,
63 input.dtype,
64 ),
65 MaskWhereStrategy::InplaceLhs => input.clone(),
66 MaskWhereStrategy::InplaceRhs => value.clone(),
67 };
68
69 let out = match strategy {
70 MaskWhereStrategy::Readonly => linear_view(&output, line_size),
71 MaskWhereStrategy::InplaceLhs => linear_view_alias(&output, line_size, 0),
72 MaskWhereStrategy::InplaceRhs => linear_view_alias(&output, line_size, 1),
73 };
74
75 mask_where_kernel::launch(
76 &input.client,
77 cube_count,
78 cube_dim,
79 linear_view_ref(&input, &output, line_size),
80 linear_view_ref(&value, &output, line_size),
81 linear_view_ref(&mask, &output, line_size),
82 out,
83 [output.dtype.into(), dtype_bool.into()],
84 )
85 .expect("Kernel to never fail");
86
87 output
88}