cubecl_core/frontend/operation/
base.rs1use std::num::{NonZero, NonZeroU8};
2
3use cubecl_ir::{
4 Arithmetic, BinaryOperator, Comparison, Elem, ExpandElement, IndexAssignOperator,
5 IndexOperator, Instruction, Item, Operation, Operator, Scope, UnaryOperator, Variable,
6 VariableKind, Vectorization,
7};
8
9use crate::prelude::{CubeIndex, CubeType, ExpandElementTyped};
10
11pub(crate) fn binary_expand<F, Op>(
12 scope: &mut Scope,
13 lhs: ExpandElement,
14 rhs: ExpandElement,
15 func: F,
16) -> ExpandElement
17where
18 F: Fn(BinaryOperator) -> Op,
19 Op: Into<Operation>,
20{
21 let lhs = lhs.consume();
22 let rhs = rhs.consume();
23
24 let item_lhs = lhs.item;
25 let item_rhs = rhs.item;
26
27 let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);
28
29 let item = Item::vectorized(item_lhs.elem, vectorization);
30
31 let output = scope.create_local(item);
32 let out = *output;
33
34 let op = func(BinaryOperator { lhs, rhs });
35
36 scope.register(Instruction::new(op, out));
37
38 output
39}
40
41pub(crate) fn index_expand_no_vec<F>(
42 scope: &mut Scope,
43 list: ExpandElement,
44 index: ExpandElement,
45 func: F,
46) -> ExpandElement
47where
48 F: Fn(IndexOperator) -> Operator,
49{
50 let list = list.consume();
51 let index = index.consume();
52
53 let item_lhs = list.item;
54
55 let item = Item::new(item_lhs.elem);
56
57 let output = scope.create_local(item);
58 let out = *output;
59
60 let op = func(IndexOperator {
61 list,
62 index,
63 line_size: 0u32,
64 });
65
66 scope.register(Instruction::new(op, out));
67
68 output
69}
70pub(crate) fn index_expand<F, Op>(
71 scope: &mut Scope,
72 list: ExpandElement,
73 index: ExpandElement,
74 line_size: Option<u32>,
75 func: F,
76) -> ExpandElement
77where
78 F: Fn(IndexOperator) -> Op,
79 Op: Into<Operation>,
80{
81 let list = list.consume();
82 let index = index.consume();
83
84 let item_lhs = list.item;
85 let item_rhs = index.item;
86
87 let vectorization = if let Some(line_size) = line_size {
88 NonZero::new(line_size as u8)
89 } else {
90 find_vectorization(item_lhs.vectorization, item_rhs.vectorization)
91 };
92
93 let item = Item::vectorized(item_lhs.elem, vectorization);
94
95 let output = scope.create_local(item);
96 let out = *output;
97
98 let op = func(IndexOperator {
99 list,
100 index,
101 line_size: line_size.unwrap_or(0),
102 });
103
104 scope.register(Instruction::new(op, out));
105
106 output
107}
108
109pub(crate) fn binary_expand_fixed_output<F>(
110 scope: &mut Scope,
111 lhs: ExpandElement,
112 rhs: ExpandElement,
113 out_item: Item,
114 func: F,
115) -> ExpandElement
116where
117 F: Fn(BinaryOperator) -> Arithmetic,
118{
119 let lhs_var = lhs.consume();
120 let rhs_var = rhs.consume();
121
122 let out = scope.create_local(out_item);
123
124 let out_var = *out;
125
126 let op = func(BinaryOperator {
127 lhs: lhs_var,
128 rhs: rhs_var,
129 });
130
131 scope.register(Instruction::new(op, out_var));
132
133 out
134}
135
136pub(crate) fn cmp_expand<F>(
137 scope: &mut Scope,
138 lhs: ExpandElement,
139 rhs: ExpandElement,
140 func: F,
141) -> ExpandElement
142where
143 F: Fn(BinaryOperator) -> Comparison,
144{
145 let lhs: Variable = *lhs;
146 let rhs: Variable = *rhs;
147 let item = lhs.item;
148
149 find_vectorization(item.vectorization, rhs.item.vectorization);
150
151 let out_item = Item {
152 elem: Elem::Bool,
153 vectorization: item.vectorization,
154 };
155
156 let out = scope.create_local(out_item);
157 let out_var = *out;
158
159 let op = func(BinaryOperator { lhs, rhs });
160
161 scope.register(Instruction::new(op, out_var));
162
163 out
164}
165
166pub(crate) fn assign_op_expand<F, Op>(
167 scope: &mut Scope,
168 lhs: ExpandElement,
169 rhs: ExpandElement,
170 func: F,
171) -> ExpandElement
172where
173 F: Fn(BinaryOperator) -> Op,
174 Op: Into<Operation>,
175{
176 if lhs.is_immutable() {
177 panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
178 }
179 let lhs_var: Variable = *lhs;
180 let rhs: Variable = *rhs;
181
182 find_vectorization(lhs_var.item.vectorization, rhs.item.vectorization);
183
184 let op = func(BinaryOperator { lhs: lhs_var, rhs });
185
186 scope.register(Instruction::new(op, lhs_var));
187
188 lhs
189}
190
191pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
192where
193 F: Fn(UnaryOperator) -> Op,
194 Op: Into<Operation>,
195{
196 let input = input.consume();
197 let item = input.item;
198
199 let out = scope.create_local(item);
200 let out_var = *out;
201
202 let op = func(UnaryOperator { input });
203
204 scope.register(Instruction::new(op, out_var));
205
206 out
207}
208
209pub fn unary_expand_fixed_output<F, Op>(
210 scope: &mut Scope,
211 input: ExpandElement,
212 out_item: Item,
213 func: F,
214) -> ExpandElement
215where
216 F: Fn(UnaryOperator) -> Op,
217 Op: Into<Operation>,
218{
219 let input = input.consume();
220 let output = scope.create_local(out_item);
221 let out = *output;
222
223 let op = func(UnaryOperator { input });
224
225 scope.register(Instruction::new(op, out));
226
227 output
228}
229
230pub fn init_expand<F>(
231 scope: &mut Scope,
232 input: ExpandElement,
233 mutable: bool,
234 func: F,
235) -> ExpandElement
236where
237 F: Fn(Variable) -> Operation,
238{
239 let input_var: Variable = *input;
240 let item = input.item;
241
242 let out = if mutable {
243 scope.create_local_mut(item)
244 } else {
245 scope.create_local(item)
246 };
247
248 let out_var = *out;
249
250 let op = func(input_var);
251 scope.register(Instruction::new(op, out_var));
252
253 out
254}
255
256fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
257 match (lhs, rhs) {
258 (None, None) => None,
259 (None, Some(rhs)) => Some(rhs),
260 (Some(lhs), None) => Some(lhs),
261 (Some(lhs), Some(rhs)) => {
262 if lhs == rhs {
263 Some(lhs)
264 } else if lhs == NonZeroU8::new(1).unwrap() || rhs == NonZeroU8::new(1).unwrap() {
265 Some(core::cmp::max(lhs, rhs))
266 } else {
267 panic!(
268 "Left and right have different vectorizations.
269 Left: {lhs}, right: {rhs}.
270 Auto-matching fixed vectorization currently unsupported."
271 );
272 }
273 }
274 }
275}
276
/// Expands a compound assignment on an array element (`array[index] op= value`).
///
/// Three instructions are built, then registered in this order:
/// 1. read `array[index]` (an `Operator::Index` with `line_size: 0`) into a new local;
/// 2. compute `func(element, value)` into a second new local;
/// 3. write that result back into `array` at `index` (an `Operator::IndexAssign`).
pub fn array_assign_binary_op_expand<
    A: CubeType + CubeIndex,
    V: CubeType,
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
>(
    scope: &mut Scope,
    array: ExpandElementTyped<A>,
    index: ExpandElementTyped<u32>,
    value: ExpandElementTyped<V>,
    func: F,
) where
    A::Output: CubeType + Sized,
{
    let array: ExpandElement = array.into();
    let index: ExpandElement = index.into();
    let value: ExpandElement = value.into();

    // Item used for the element read from `array` and for the operation's result.
    let array_item = match array.kind {
        // NOTE(review): presumably a `LocalMut` "array" is a single vectorized line, so one
        // element is its item with the vectorization removed — confirm. The reviewed copy of
        // this file also skipped a line right here; confirm no match arm or comment was lost.
        VariableKind::LocalMut { .. } => array.item.vectorize(None),
        _ => array.item,
    };
    let array_value = scope.create_local(array_item);

    // Instructions are only constructed here; registration happens at the end.
    let read = Instruction::new(
        Operator::Index(IndexOperator {
            list: *array,
            index: *index,
            line_size: 0,
        }),
        *array_value,
    );
    // `array_value` is consumed before `op_out` is created; keep this ordering when editing.
    let array_value = array_value.consume();
    let op_out = scope.create_local(array_item);
    let calculate = Instruction::new(
        func(BinaryOperator {
            lhs: array_value,
            rhs: *value,
        }),
        *op_out,
    );

    // The write uses `array` as the instruction's output variable.
    let write = Operator::IndexAssign(IndexAssignOperator {
        index: *index,
        value: op_out.consume(),
        line_size: 0,
    });
    scope.register(read);
    scope.register(calculate);
    scope.register(Instruction::new(write, *array));
}