cubecl_core/frontend/operation/
base.rs1use std::num::NonZeroU8;
2
3use cubecl_ir::{
4 Arithmetic, BinaryOperator, Comparison, Elem, ExpandElement, Instruction, Item, Operation,
5 Operator, Scope, UnaryOperator, Variable, VariableKind, Vectorization,
6};
7
8use crate::prelude::{CubeIndex, CubeType, ExpandElementTyped};
9
10pub(crate) fn binary_expand<F, Op>(
11 scope: &mut Scope,
12 lhs: ExpandElement,
13 rhs: ExpandElement,
14 func: F,
15) -> ExpandElement
16where
17 F: Fn(BinaryOperator) -> Op,
18 Op: Into<Operation>,
19{
20 let lhs = lhs.consume();
21 let rhs = rhs.consume();
22
23 let item_lhs = lhs.item;
24 let item_rhs = rhs.item;
25
26 let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);
27
28 let item = Item::vectorized(item_lhs.elem, vectorization);
29
30 let output = scope.create_local(item);
31 let out = *output;
32
33 let op = func(BinaryOperator { lhs, rhs });
34
35 scope.register(Instruction::new(op, out));
36
37 output
38}
39
40pub(crate) fn binary_expand_fixed_output<F>(
41 scope: &mut Scope,
42 lhs: ExpandElement,
43 rhs: ExpandElement,
44 out_item: Item,
45 func: F,
46) -> ExpandElement
47where
48 F: Fn(BinaryOperator) -> Arithmetic,
49{
50 let lhs_var = lhs.consume();
51 let rhs_var = rhs.consume();
52
53 let out = scope.create_local(out_item);
54
55 let out_var = *out;
56
57 let op = func(BinaryOperator {
58 lhs: lhs_var,
59 rhs: rhs_var,
60 });
61
62 scope.register(Instruction::new(op, out_var));
63
64 out
65}
66
67pub(crate) fn binary_expand_no_vec<F>(
68 scope: &mut Scope,
69 lhs: ExpandElement,
70 rhs: ExpandElement,
71 func: F,
72) -> ExpandElement
73where
74 F: Fn(BinaryOperator) -> Operator,
75{
76 let lhs = lhs.consume();
77 let rhs = rhs.consume();
78
79 let item_lhs = lhs.item;
80
81 let item = Item::new(item_lhs.elem);
82
83 let output = scope.create_local(item);
84 let out = *output;
85
86 let op = func(BinaryOperator { lhs, rhs });
87
88 scope.register(Instruction::new(op, out));
89
90 output
91}
92
93pub(crate) fn cmp_expand<F>(
94 scope: &mut Scope,
95 lhs: ExpandElement,
96 rhs: ExpandElement,
97 func: F,
98) -> ExpandElement
99where
100 F: Fn(BinaryOperator) -> Comparison,
101{
102 let lhs: Variable = *lhs;
103 let rhs: Variable = *rhs;
104 let item = lhs.item;
105
106 find_vectorization(item.vectorization, rhs.item.vectorization);
107
108 let out_item = Item {
109 elem: Elem::Bool,
110 vectorization: item.vectorization,
111 };
112
113 let out = scope.create_local(out_item);
114 let out_var = *out;
115
116 let op = func(BinaryOperator { lhs, rhs });
117
118 scope.register(Instruction::new(op, out_var));
119
120 out
121}
122
123pub(crate) fn assign_op_expand<F, Op>(
124 scope: &mut Scope,
125 lhs: ExpandElement,
126 rhs: ExpandElement,
127 func: F,
128) -> ExpandElement
129where
130 F: Fn(BinaryOperator) -> Op,
131 Op: Into<Operation>,
132{
133 let lhs_var: Variable = *lhs;
134 let rhs: Variable = *rhs;
135
136 find_vectorization(lhs_var.item.vectorization, rhs.item.vectorization);
137
138 let op = func(BinaryOperator { lhs: lhs_var, rhs });
139
140 scope.register(Instruction::new(op, lhs_var));
141
142 lhs
143}
144
145pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
146where
147 F: Fn(UnaryOperator) -> Op,
148 Op: Into<Operation>,
149{
150 let input = input.consume();
151 let item = input.item;
152
153 let out = scope.create_local(item);
154 let out_var = *out;
155
156 let op = func(UnaryOperator { input });
157
158 scope.register(Instruction::new(op, out_var));
159
160 out
161}
162
163pub fn unary_expand_fixed_output<F, Op>(
164 scope: &mut Scope,
165 input: ExpandElement,
166 out_item: Item,
167 func: F,
168) -> ExpandElement
169where
170 F: Fn(UnaryOperator) -> Op,
171 Op: Into<Operation>,
172{
173 let input = input.consume();
174 let output = scope.create_local(out_item);
175 let out = *output;
176
177 let op = func(UnaryOperator { input });
178
179 scope.register(Instruction::new(op, out));
180
181 output
182}
183
184pub fn init_expand<F>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
185where
186 F: Fn(Variable) -> Operation,
187{
188 if input.can_mut() {
189 return input;
190 }
191 let input_var: Variable = *input;
192 let item = input.item;
193
194 let out = scope.create_local_mut(item); let out_var = *out;
196
197 let op = func(input_var);
198 scope.register(Instruction::new(op, out_var));
199
200 out
201}
202
203fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
204 match (lhs, rhs) {
205 (None, None) => None,
206 (None, Some(rhs)) => Some(rhs),
207 (Some(lhs), None) => Some(lhs),
208 (Some(lhs), Some(rhs)) => {
209 if lhs == rhs {
210 Some(lhs)
211 } else if lhs == NonZeroU8::new(1).unwrap() || rhs == NonZeroU8::new(1).unwrap() {
212 Some(core::cmp::max(lhs, rhs))
213 } else {
214 panic!(
215 "Left and right have different vectorizations.
216 Left: {lhs}, right: {rhs}.
217 Auto-matching fixed vectorization currently unsupported."
218 );
219 }
220 }
221 }
222}
223
224pub fn array_assign_binary_op_expand<
225 A: CubeType + CubeIndex<u32>,
226 V: CubeType,
227 F: Fn(BinaryOperator) -> Op,
228 Op: Into<Operation>,
229>(
230 scope: &mut Scope,
231 array: ExpandElementTyped<A>,
232 index: ExpandElementTyped<u32>,
233 value: ExpandElementTyped<V>,
234 func: F,
235) where
236 A::Output: CubeType + Sized,
237{
238 let array: ExpandElement = array.into();
239 let index: ExpandElement = index.into();
240 let value: ExpandElement = value.into();
241
242 let array_item = match array.kind {
243 VariableKind::LocalMut { .. } => array.item.vectorize(None),
245 _ => array.item,
246 };
247 let array_value = scope.create_local(array_item);
248
249 let read = Instruction::new(
250 Operator::Index(BinaryOperator {
251 lhs: *array,
252 rhs: *index,
253 }),
254 *array_value,
255 );
256 let array_value = array_value.consume();
257 let op_out = scope.create_local(array_item);
258 let calculate = Instruction::new(
259 func(BinaryOperator {
260 lhs: array_value,
261 rhs: *value,
262 }),
263 *op_out,
264 );
265
266 let write = Operator::IndexAssign(BinaryOperator {
267 lhs: *index,
268 rhs: op_out.consume(),
269 });
270 scope.register(read);
271 scope.register(calculate);
272 scope.register(Instruction::new(write, *array));
273}