1use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*};
2
3use crate::{
4 BoolElement, CubeRuntime,
5 element::CubeElement,
6 ops::{max_line_size_many, numeric::empty_device},
7 tensor::CubeTensor,
8};
9
10#[cube(launch_unchecked)]
11fn mask_fill_readonly_kernel<T: Numeric, B: Int>(
12 input: &Tensor<Line<T>>,
13 mask: &Tensor<Line<B>>,
14 output: &mut Tensor<Line<T>>,
15 value: T,
16 #[comptime] rank: u32,
17) {
18 let pos = ABSOLUTE_POS;
19
20 if pos >= output.len() {
21 terminate!();
22 }
23
24 let index_input = index_offset_with_layout(input, output, pos, 0, rank, false);
25 let index_mask = index_offset_with_layout(mask, output, pos, 0, rank, false);
26
27 let mask = Line::cast_from(mask[index_mask]);
28 let input = input[index_input];
29 let value = Line::new(value);
30
31 output[pos] = select_many(mask, value, input);
32}
33
34#[cube(launch_unchecked)]
35fn mask_fill_inplace_kernel<T: Numeric, B: Int>(
36 input: &mut Tensor<Line<T>>,
37 mask: &Tensor<Line<B>>,
38 value: T,
39 #[comptime] rank: u32,
40) {
41 let pos = ABSOLUTE_POS;
42
43 if pos >= input.len() {
44 terminate!();
45 }
46
47 let index_mask = index_offset_with_layout(mask, input, pos, 0, rank, false);
48 let mask = Line::cast_from(mask[index_mask]);
49 let value = Line::new(value);
50
51 input[pos] = select_many(mask, value, input[pos]);
52}
53
/// Selects which mask-fill kernel [`mask_fill`] launches.
#[derive(Clone, Copy, Debug)]
pub enum MaskFillStrategy {
    /// Write the result into a newly allocated output tensor; the input tensor is only read.
    Readonly,
    /// Write the result back into the input tensor and return that same tensor.
    Inplace,
}
66
67pub fn mask_fill<R: CubeRuntime, E: CubeElement, BT: BoolElement>(
69 input: CubeTensor<R>,
70 mask: CubeTensor<R>,
71 value: E,
72 strategy: MaskFillStrategy,
73) -> CubeTensor<R> {
74 match strategy {
75 MaskFillStrategy::Readonly => mask_fill_readonly::<R, E, BT>(input, mask, value),
76 MaskFillStrategy::Inplace => mask_fill_inplace::<R, E, BT>(input, mask, value),
77 }
78}
79
80fn mask_fill_readonly<R: CubeRuntime, EI: CubeElement, EM: BoolElement>(
81 input: CubeTensor<R>,
82 mask: CubeTensor<R>,
83 value: EI,
84) -> CubeTensor<R> {
85 let ndims = input.shape.num_dims();
86 let output = empty_device::<R, EI>(
87 input.client.clone(),
88 input.device.clone(),
89 input.shape.clone(),
90 );
91
92 let cube_dim = CubeDim::default();
93 let vectorization = max_line_size_many(&[&input, &mask], ndims - 1);
94 let cube_count = calculate_cube_count_elemwise(
95 input.shape.num_elements() / vectorization as usize,
96 cube_dim,
97 );
98
99 unsafe {
100 mask_fill_readonly_kernel::launch_unchecked::<EI, EM, R>(
101 &input.client,
102 cube_count,
103 cube_dim,
104 input.as_tensor_arg::<EI>(vectorization),
105 mask.as_tensor_arg::<EM>(vectorization),
106 output.as_tensor_arg::<EI>(vectorization),
107 ScalarArg::new(value),
108 ndims as u32,
109 );
110 }
111
112 output
113}
114
115fn mask_fill_inplace<R: CubeRuntime, EI: CubeElement, EM: BoolElement>(
116 input: CubeTensor<R>,
117 mask: CubeTensor<R>,
118 value: EI,
119) -> CubeTensor<R> {
120 let ndims = input.shape.num_dims();
121 let cube_dim = CubeDim::default();
122 let vectorization = max_line_size_many(&[&input, &mask], ndims - 1);
123 let cube_count = calculate_cube_count_elemwise(
124 input.shape.num_elements() / vectorization as usize,
125 cube_dim,
126 );
127
128 unsafe {
129 mask_fill_inplace_kernel::launch_unchecked::<EI, EM, R>(
130 &input.client,
131 cube_count,
132 cube_dim,
133 input.as_tensor_arg::<EI>(vectorization),
134 mask.as_tensor_arg::<EM>(vectorization),
135 ScalarArg::new(value),
136 ndims as u32,
137 );
138 }
139
140 input
141}