burn_cubecl/kernel/mask/
mask_fill.rs1use burn_backend::DType;
2use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
3
4use crate::{
5 CubeRuntime,
6 kernel::utils::{linear_view, linear_view_alias, linear_view_ref},
7 ops::{max_line_size_many, numeric::empty_device_dtype},
8 tensor::CubeTensor,
9};
10
11#[cube(launch_unchecked)]
12fn mask_fill_kernel<T: Numeric, B: Int>(
13 input: &LinearView<Line<T>>,
14 mask: &LinearView<Line<B>>,
15 output: &mut LinearView<Line<T>, ReadWrite>,
16 value: InputScalar,
17 #[define(T, B)] _dtypes: [StorageType; 2],
18) {
19 if !output.is_in_bounds(ABSOLUTE_POS) {
20 terminate!();
21 }
22
23 let mask = Line::cast_from(mask[ABSOLUTE_POS]);
24 let input = input[ABSOLUTE_POS];
25 let value = Line::new(value.get::<T>());
26
27 output[ABSOLUTE_POS] = select_many(mask, value, input);
28}
29
/// Selects where [`mask_fill`] writes its result.
///
/// Choosing [`MaskFillStrategy::Inplace`] is the caller's decision. `mask_fill` does not
/// check whether the input tensor may safely be overwritten.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MaskFillStrategy {
    /// Allocate a fresh output tensor with the input's client, device, shape and dtype.
    /// The input is only read.
    Readonly,
    /// Reuse the input tensor as the output. The kernel's output view aliases its input.
    Inplace,
}
42
43pub fn mask_fill<R: CubeRuntime>(
45 input: CubeTensor<R>,
46 mask: CubeTensor<R>,
47 value: InputScalar,
48 strategy: MaskFillStrategy,
49 dtype_bool: DType,
50) -> CubeTensor<R> {
51 let ndims = input.shape.num_dims();
52 let output = match strategy {
53 MaskFillStrategy::Readonly => empty_device_dtype(
54 input.client.clone(),
55 input.device.clone(),
56 input.shape.clone(),
57 input.dtype,
58 ),
59 MaskFillStrategy::Inplace => input.clone(),
60 };
61
62 let line_size = max_line_size_many(&[&input, &mask], ndims - 1);
63 let working_units = input.shape.num_elements() / line_size as usize;
64 let cube_dim = CubeDim::new(&input.client, working_units);
65 let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
66
67 let out_arg = match strategy {
68 MaskFillStrategy::Readonly => linear_view(&output, line_size),
69 MaskFillStrategy::Inplace => linear_view_alias(&output, line_size, 0),
70 };
71
72 unsafe {
73 mask_fill_kernel::launch_unchecked(
74 &input.client,
75 cube_count,
76 cube_dim,
77 linear_view(&input, line_size),
78 linear_view_ref(&mask, &input, line_size),
79 out_arg,
80 value,
81 [output.dtype.into(), dtype_bool.into()],
82 )
83 .expect("Kernel to never fail");
84 }
85
86 output
87}