1use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*};
2
3use crate::{
4 element::JitElement,
5 ops::{max_vectorization, numeric::empty_device},
6 tensor::JitTensor,
7 BoolElement, JitRuntime,
8};
9
10#[cube(launch)]
11fn mask_fill_readonly_kernel<T: Numeric, B: Int>(
12 input: &Tensor<Line<T>>,
13 mask: &Tensor<Line<B>>,
14 output: &mut Tensor<Line<T>>,
15 value: T,
16 #[comptime] rank: u32,
17) {
18 if ABSOLUTE_POS >= output.len() {
19 return;
20 }
21
22 let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true);
23 let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true);
24
25 let mask = Line::cast_from(mask[index_mask]);
26
27 output[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[index_input]);
28}
29
30#[cube(launch)]
31fn mask_fill_inplace_kernel<T: Numeric, B: Int>(
32 input: &mut Tensor<Line<T>>,
33 mask: &Tensor<Line<B>>,
34 value: T,
35 #[comptime] rank: u32,
36) {
37 if ABSOLUTE_POS >= input.len() {
38 return;
39 }
40
41 let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true);
42 let mask = Line::cast_from(mask[index_mask]);
43
44 input[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[ABSOLUTE_POS]);
45}
46
/// Selects which kernel `mask_fill` uses to produce its result.
///
/// # Notes
///
/// Choosing a strategy performs no validation of shapes or dtypes. That is presumably
/// expected to happen in the caller before a strategy is picked (TODO confirm).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MaskFillStrategy {
    /// Leave `input` untouched and write the result into a newly allocated tensor.
    Readonly,
    /// Overwrite `input` in place and return it as the result.
    Inplace,
}
59
60pub fn mask_fill<R: JitRuntime, E: JitElement, BT: BoolElement>(
62 input: JitTensor<R>,
63 mask: JitTensor<R>,
64 value: E,
65 strategy: MaskFillStrategy,
66) -> JitTensor<R> {
67 match strategy {
68 MaskFillStrategy::Readonly => mask_fill_readonly::<R, E, BT>(input, mask, value),
69 MaskFillStrategy::Inplace => mask_fill_inplace::<R, E, BT>(input, mask, value),
70 }
71}
72
73fn mask_fill_readonly<R: JitRuntime, EI: JitElement, EM: BoolElement>(
74 input: JitTensor<R>,
75 mask: JitTensor<R>,
76 value: EI,
77) -> JitTensor<R> {
78 let ndims = input.shape.num_dims();
79 let output = empty_device::<R, EI>(
80 input.client.clone(),
81 input.device.clone(),
82 input.shape.clone(),
83 );
84
85 let cube_dim = CubeDim::default();
86 let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
87 let vectorization = max_vectorization(&input);
88
89 mask_fill_readonly_kernel::launch::<EI, EM, R>(
90 &input.client,
91 cube_count,
92 cube_dim,
93 input.as_tensor_arg::<EI>(vectorization),
94 mask.as_tensor_arg::<EM>(vectorization),
95 output.as_tensor_arg::<EI>(vectorization),
96 ScalarArg::new(value),
97 ndims as u32,
98 );
99
100 output
101}
102
103fn mask_fill_inplace<R: JitRuntime, EI: JitElement, EM: BoolElement>(
104 input: JitTensor<R>,
105 mask: JitTensor<R>,
106 value: EI,
107) -> JitTensor<R> {
108 let ndims = input.shape.num_dims();
109 let cube_dim = CubeDim::default();
110 let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
111 let vectorization = max_vectorization(&input);
112
113 mask_fill_inplace_kernel::launch::<EI, EM, R>(
114 &input.client,
115 cube_count,
116 cube_dim,
117 input.as_tensor_arg::<EI>(vectorization),
118 mask.as_tensor_arg::<EM>(vectorization),
119 ScalarArg::new(value),
120 ndims as u32,
121 );
122
123 input
124}