1use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*};
2
3use crate::{
4 element::JitElement,
5 ops::{max_vectorization, numeric::empty_device},
6 tensor::JitTensor,
7 BoolElement, JitRuntime,
8};
9
/// Masked selection written into a separate `output` tensor.
///
/// Each invocation handles the single output line at `ABSOLUTE_POS`. It computes
/// the matching offsets into `input`, `mask` and `value` through `output`'s layout.
/// Lanes where the cast mask is `true` receive `value`; all other lanes receive `input`.
///
/// * `input` - values kept where the mask lane is `false`.
/// * `mask` - integer-typed mask (`B: Int`), cast before being used as the
///   `select_many` condition.
/// * `value` - values written where the mask lane is `true`.
/// * `output` - destination tensor; its length bounds the launch and its layout
///   is used to locate the other operands.
/// * `rank` - comptime number of dimensions passed to the offset computation.
#[cube(launch)]
fn mask_where_readonly_kernel<T: CubePrimitive, B: Int>(
    input: &Tensor<Line<T>>,
    mask: &Tensor<Line<B>>,
    value: &Tensor<Line<T>>,
    output: &mut Tensor<Line<T>>,
    #[comptime] rank: u32,
) {
    // Surplus threads (launch size rounded up to whole cubes) exit immediately.
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    // Map the output position onto each operand's own offset. This is presumably
    // what lets operands whose strides differ from `output`'s be read correctly.
    let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true);
    let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true);
    let index_value = index_offset_with_layout(value, output, ABSOLUTE_POS, 0, rank, true);
    // Shadows the `mask` tensor with the cast mask line used as the select condition.
    let mask = Line::cast_from(mask[index_mask]);

    // Lane-wise select: `value` where the mask lane is true, otherwise `input`.
    output[ABSOLUTE_POS] = select_many(mask, value[index_value], input[index_input]);
}
29
/// Masked selection that overwrites `input` in place.
///
/// Each invocation handles the line of `input` at `ABSOLUTE_POS`. Where the
/// corresponding `mask` line differs from `reverse`, the line is replaced by the
/// matching line of `value`; otherwise the current `input` line is kept.
///
/// With `reverse` set to the "false" encoding, the result is "take `value` where
/// the mask is set". With the "true" encoding, it is "take `value` where the mask
/// is not set". This lets the caller write into either operand of the select.
///
/// NOTE(review): the condition is an equality test against `reverse`, not a
/// truthiness test. Mask entries are therefore assumed to use exactly the same
/// encoding as the `reverse` scalar provided by the launcher. Confirm this for
/// every bool element type.
///
/// * `input` - tensor that is read and overwritten; it also defines the
///   iteration space and the layout used to locate `mask` and `value`.
/// * `mask` - integer-typed mask compared against `reverse`.
/// * `value` - replacement values.
/// * `reverse` - mask encoding that means "keep `input`".
/// * `rank` - comptime number of dimensions passed to the offset computation.
#[cube(launch)]
fn mask_where_inplace_kernel<T: CubePrimitive, B: Int>(
    input: &mut Tensor<Line<T>>,
    mask: &Tensor<Line<B>>,
    value: &Tensor<Line<T>>,
    reverse: B,
    #[comptime] rank: u32,
) {
    // Surplus threads (launch size rounded up to whole cubes) exit immediately.
    if ABSOLUTE_POS >= input.len() {
        return;
    }

    // `input` is both the destination and the reference layout here.
    let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true);
    let index_value = index_offset_with_layout(value, input, ABSOLUTE_POS, 0, rank, true);

    // Keep the existing lane unless the mask differs from the "keep" encoding.
    input[ABSOLUTE_POS] = select(
        mask[index_mask] != Line::new(reverse),
        value[index_value],
        input[ABSOLUTE_POS],
    );
}
51
/// How [`mask_where`] produces its result.
///
/// Every strategy computes "`value` where `mask` is set, otherwise `input`". The
/// strategies differ only in which buffer receives the result.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MaskWhereStrategy {
    /// Allocate a fresh output tensor; `input` and `value` are left untouched.
    Readonly,
    /// Write the result into the `input` tensor's buffer and return it.
    InplaceLhs,
    /// Write the result into the `value` tensor's buffer and return it (the
    /// mask is interpreted in reverse so the selection stays the same).
    InplaceRhs,
}
66
67pub fn mask_where<R: JitRuntime, E: JitElement, BT: BoolElement>(
69 input: JitTensor<R>,
70 mask: JitTensor<R>,
71 value: JitTensor<R>,
72 strategy: MaskWhereStrategy,
73) -> JitTensor<R> {
74 match strategy {
75 MaskWhereStrategy::Readonly => mask_where_readonly::<R, E, BT>(input, mask, value),
76 MaskWhereStrategy::InplaceLhs => mask_where_inplace::<R, E, BT>(input, mask, value, false),
77 MaskWhereStrategy::InplaceRhs => mask_where_inplace::<R, E, BT>(value, mask, input, true),
78 }
79}
80
81fn mask_where_readonly<R: JitRuntime, EI: JitElement, EM: BoolElement>(
82 input: JitTensor<R>,
83 mask: JitTensor<R>,
84 value: JitTensor<R>,
85) -> JitTensor<R> {
86 let ndims = input.shape.num_dims();
87 let output = empty_device::<R, EI>(
88 input.client.clone(),
89 input.device.clone(),
90 input.shape.clone(),
91 );
92
93 let cube_dim = CubeDim::default();
94 let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
95 let vectorization = max_vectorization(&input);
96
97 mask_where_readonly_kernel::launch::<EI, EM, R>(
98 &input.client,
99 cube_count,
100 cube_dim,
101 input.as_tensor_arg::<EI>(vectorization),
102 mask.as_tensor_arg::<EM>(vectorization),
103 value.as_tensor_arg::<EI>(vectorization),
104 output.as_tensor_arg::<EI>(vectorization),
105 ndims as u32,
106 );
107
108 output
109}
110
111fn mask_where_inplace<R: JitRuntime, EI: JitElement, EM: BoolElement>(
112 input: JitTensor<R>,
113 mask: JitTensor<R>,
114 value: JitTensor<R>,
115 reverse: bool,
116) -> JitTensor<R> {
117 let ndims = input.shape.num_dims();
118 let cube_dim = CubeDim::default();
119 let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim);
120 let vectorization = max_vectorization(&input);
121
122 mask_where_inplace_kernel::launch::<EI, EM, R>(
123 &input.client,
124 cube_count,
125 cube_dim,
126 input.as_tensor_arg::<EI>(vectorization),
127 mask.as_tensor_arg::<EM>(vectorization),
128 value.as_tensor_arg::<EI>(vectorization),
129 ScalarArg::new(EM::new_bool(reverse)),
130 ndims as u32,
131 );
132
133 input
134}