cubecl_core/frontend/operation/
base.rs1use cubecl_ir::{
2 Arithmetic, BinaryOperator, Comparison, ElemType, ExpandElement, IndexAssignOperator,
3 IndexOperator, Instruction, LineSize, Operation, Operator, Scope, Type, UnaryOperator,
4 Variable, VariableKind,
5};
6use cubecl_macros::cube;
7
8use crate::{
9 self as cubecl,
10 prelude::{CubeIndex, CubeType, ExpandElementTyped, Int, eq, rem},
11};
12
13pub(crate) fn binary_expand<F, Op>(
14 scope: &mut Scope,
15 lhs: ExpandElement,
16 rhs: ExpandElement,
17 func: F,
18) -> ExpandElement
19where
20 F: Fn(BinaryOperator) -> Op,
21 Op: Into<Operation>,
22{
23 let lhs = lhs.consume();
24 let rhs = rhs.consume();
25
26 let item_lhs = lhs.ty;
27 let item_rhs = rhs.ty;
28
29 let line_size = find_vectorization(item_lhs, item_rhs);
30
31 let item = item_lhs.line(line_size);
32
33 let output = scope.create_local(item);
34 let out = *output;
35
36 let op = func(BinaryOperator { lhs, rhs });
37
38 scope.register(Instruction::new(op, out));
39
40 output
41}
42
43pub(crate) fn index_expand_no_vec<F>(
44 scope: &mut Scope,
45 list: ExpandElement,
46 index: ExpandElement,
47 func: F,
48) -> ExpandElement
49where
50 F: Fn(IndexOperator) -> Operator,
51{
52 let list = list.consume();
53 let index = index.consume();
54
55 let item_lhs = list.ty;
56
57 let item = item_lhs.line(0);
58
59 let output = scope.create_local(item);
60 let out = *output;
61
62 let op = func(IndexOperator {
63 list,
64 index,
65 line_size: 0u32,
66 unroll_factor: 1,
67 });
68
69 scope.register(Instruction::new(op, out));
70
71 output
72}
73pub(crate) fn index_expand<F, Op>(
74 scope: &mut Scope,
75 list: ExpandElement,
76 index: ExpandElement,
77 line_size: Option<u32>,
78 func: F,
79) -> ExpandElement
80where
81 F: Fn(IndexOperator) -> Op,
82 Op: Into<Operation>,
83{
84 let list = list.consume();
85 let index = index.consume();
86
87 let item_lhs = list.ty;
88 let item_rhs = index.ty;
89
90 let vec = if let Some(line_size) = line_size {
91 line_size
92 } else {
93 find_vectorization(item_lhs, item_rhs)
94 };
95
96 let item = item_lhs.line(vec);
97
98 let output = scope.create_local(item);
99 let out = *output;
100
101 let op = func(IndexOperator {
102 list,
103 index,
104 line_size: line_size.unwrap_or(0),
105 unroll_factor: 1,
106 });
107
108 scope.register(Instruction::new(op, out));
109
110 output
111}
112
113pub(crate) fn binary_expand_fixed_output<F>(
114 scope: &mut Scope,
115 lhs: ExpandElement,
116 rhs: ExpandElement,
117 out_item: Type,
118 func: F,
119) -> ExpandElement
120where
121 F: Fn(BinaryOperator) -> Arithmetic,
122{
123 let lhs_var = lhs.consume();
124 let rhs_var = rhs.consume();
125
126 let out = scope.create_local(out_item);
127
128 let out_var = *out;
129
130 let op = func(BinaryOperator {
131 lhs: lhs_var,
132 rhs: rhs_var,
133 });
134
135 scope.register(Instruction::new(op, out_var));
136
137 out
138}
139
140pub(crate) fn cmp_expand<F>(
141 scope: &mut Scope,
142 lhs: ExpandElement,
143 rhs: ExpandElement,
144 func: F,
145) -> ExpandElement
146where
147 F: Fn(BinaryOperator) -> Comparison,
148{
149 let lhs = lhs.consume();
150 let rhs = rhs.consume();
151
152 let item_lhs = lhs.ty;
153 let item_rhs = rhs.ty;
154
155 let line_size = find_vectorization(item_lhs, item_rhs);
156
157 let out_item = Type::scalar(ElemType::Bool).line(line_size);
158
159 let out = scope.create_local(out_item);
160 let out_var = *out;
161
162 let op = func(BinaryOperator { lhs, rhs });
163
164 scope.register(Instruction::new(op, out_var));
165
166 out
167}
168
169pub(crate) fn assign_op_expand<F, Op>(
170 scope: &mut Scope,
171 lhs: ExpandElement,
172 rhs: ExpandElement,
173 func: F,
174) -> ExpandElement
175where
176 F: Fn(BinaryOperator) -> Op,
177 Op: Into<Operation>,
178{
179 if lhs.is_immutable() {
180 panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
181 }
182 let lhs_var: Variable = *lhs;
183 let rhs: Variable = *rhs;
184
185 let op = func(BinaryOperator { lhs: lhs_var, rhs });
186
187 scope.register(Instruction::new(op, lhs_var));
188
189 lhs
190}
191
192pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
193where
194 F: Fn(UnaryOperator) -> Op,
195 Op: Into<Operation>,
196{
197 let input = input.consume();
198 let item = input.ty;
199
200 let out = scope.create_local(item);
201 let out_var = *out;
202
203 let op = func(UnaryOperator { input });
204
205 scope.register(Instruction::new(op, out_var));
206
207 out
208}
209
210pub fn unary_expand_fixed_output<F, Op>(
211 scope: &mut Scope,
212 input: ExpandElement,
213 out_item: Type,
214 func: F,
215) -> ExpandElement
216where
217 F: Fn(UnaryOperator) -> Op,
218 Op: Into<Operation>,
219{
220 let input = input.consume();
221 let output = scope.create_local(out_item);
222 let out = *output;
223
224 let op = func(UnaryOperator { input });
225
226 scope.register(Instruction::new(op, out));
227
228 output
229}
230
231pub fn init_expand<F>(
232 scope: &mut Scope,
233 input: ExpandElement,
234 mutable: bool,
235 func: F,
236) -> ExpandElement
237where
238 F: Fn(Variable) -> Operation,
239{
240 let input_var: Variable = *input;
241 let item = input.ty;
242
243 let out = if mutable {
244 scope.create_local_mut(item)
245 } else {
246 scope.create_local(item)
247 };
248
249 let out_var = *out;
250
251 let op = func(input_var);
252 scope.register(Instruction::new(op, out_var));
253
254 out
255}
256
257fn find_vectorization(lhs: Type, rhs: Type) -> LineSize {
258 if matches!(lhs, Type::Scalar(_)) && matches!(rhs, Type::Scalar(_)) {
259 0
260 } else {
261 lhs.line_size().max(rhs.line_size())
262 }
263}
264
/// Expands a compound assignment on an indexed element: `array[index] = array[index] <op> value`.
///
/// Three instructions are built and then registered in this order:
/// 1. a read of `array[index]` into a fresh local (`Operator::Index`);
/// 2. `func` applied to that local (as `lhs`) and `value` (as `rhs`), writing into a second
///    fresh local;
/// 3. a write of the second local back to `array[index]` (`Operator::IndexAssign`, with `array`
///    as the instruction output).
///
/// Both index operators are built with `line_size: 0` and `unroll_factor: 1`.
pub fn array_assign_binary_op_expand<
    A: CubeType + CubeIndex,
    V: CubeType,
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
>(
    scope: &mut Scope,
    array: ExpandElementTyped<A>,
    index: ExpandElementTyped<u32>,
    value: ExpandElementTyped<V>,
    func: F,
) where
    A::Output: CubeType + Sized,
{
    let array: ExpandElement = array.into();
    let index: ExpandElement = index.into();
    let value: ExpandElement = value.into();

    // Type used for the read element and for the operation result. For a `LocalMut` variable
    // the type is re-lined with line size 0; every other kind uses `array.ty` unchanged.
    // NOTE(review): presumably a `LocalMut` here is a line being indexed component-wise — confirm.
    let array_item = match array.kind {
        VariableKind::LocalMut { .. } => array.ty.line(0),
        _ => array.ty,
    };
    let array_value = scope.create_local(array_item);

    // Read `array[index]` into `array_value` (registered further below).
    let read = Instruction::new(
        Operator::Index(IndexOperator {
            list: *array,
            index: *index,
            line_size: 0,
            unroll_factor: 1,
        }),
        *array_value,
    );
    // NOTE(review): `array_value` is consumed before `op_out` is created; this ordering is
    // presumably deliberate (it may let the allocator reuse the slot) — keep it when editing.
    let array_value = array_value.consume();
    let op_out = scope.create_local(array_item);
    // `op_out = func(array_value, value)`.
    let calculate = Instruction::new(
        func(BinaryOperator {
            lhs: array_value,
            rhs: *value,
        }),
        *op_out,
    );

    // Store the result back at `index`; `array` is the destination of this instruction.
    let write = Operator::IndexAssign(IndexAssignOperator {
        index: *index,
        value: op_out.consume(),
        line_size: 0,
        unroll_factor: 1,
    });
    scope.register(read);
    scope.register(calculate);
    scope.register(Instruction::new(write, *array));
}
319
320impl<E: Int> ExpandElementTyped<E> {
322 pub fn __expand_div_ceil_method(
323 self,
324 scope: &mut Scope,
325 divisor: ExpandElementTyped<E>,
326 ) -> ExpandElementTyped<E> {
327 div_ceil::expand::<E>(scope, self, divisor)
328 }
329
330 pub fn __expand_is_multiple_of_method(
331 self,
332 scope: &mut Scope,
333 factor: ExpandElementTyped<E>,
334 ) -> ExpandElementTyped<bool> {
335 let modulo = rem::expand(scope, self, factor);
336 eq::expand(scope, modulo, E::from_int(0).into())
337 }
338}
339
340#[cube]
341pub fn div_ceil<E: Int>(a: E, b: E) -> E {
342 (a + b - E::new(1)) / b
343}