cubecl_core/frontend/operation/
base.rs1use std::num::NonZeroU8;
2
3use crate::ir::{
4 BinaryOperator, Elem, Instruction, Item, Operation, Operator, UnaryOperator, Variable,
5 VariableKind, Vectorization,
6};
7use crate::prelude::{CubeType, ExpandElementTyped};
8use crate::{
9 frontend::{CubeContext, ExpandElement},
10 prelude::CubeIndex,
11};
12
13pub(crate) fn binary_expand<F>(
14 context: &mut CubeContext,
15 lhs: ExpandElement,
16 rhs: ExpandElement,
17 func: F,
18) -> ExpandElement
19where
20 F: Fn(BinaryOperator) -> Operator,
21{
22 let lhs = lhs.consume();
23 let rhs = rhs.consume();
24
25 let item_lhs = lhs.item;
26 let item_rhs = rhs.item;
27
28 let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);
29
30 let item = Item::vectorized(item_lhs.elem, vectorization);
31
32 let output = context.create_local(item);
33 let out = *output;
34
35 let op = func(BinaryOperator { lhs, rhs });
36
37 context.register(Instruction::new(op, out));
38
39 output
40}
41
42pub(crate) fn binary_expand_fixed_output<F>(
43 context: &mut CubeContext,
44 lhs: ExpandElement,
45 rhs: ExpandElement,
46 out_item: Item,
47 func: F,
48) -> ExpandElement
49where
50 F: Fn(BinaryOperator) -> Operator,
51{
52 let lhs_var = lhs.consume();
53 let rhs_var = rhs.consume();
54
55 let out = context.create_local(out_item);
56
57 let out_var = *out;
58
59 let op = func(BinaryOperator {
60 lhs: lhs_var,
61 rhs: rhs_var,
62 });
63
64 context.register(Instruction::new(op, out_var));
65
66 out
67}
68
69pub(crate) fn binary_expand_no_vec<F>(
70 context: &mut CubeContext,
71 lhs: ExpandElement,
72 rhs: ExpandElement,
73 func: F,
74) -> ExpandElement
75where
76 F: Fn(BinaryOperator) -> Operator,
77{
78 let lhs = lhs.consume();
79 let rhs = rhs.consume();
80
81 let item_lhs = lhs.item;
82
83 let item = Item::new(item_lhs.elem);
84
85 let output = context.create_local(item);
86 let out = *output;
87
88 let op = func(BinaryOperator { lhs, rhs });
89
90 context.register(Instruction::new(op, out));
91
92 output
93}
94
95pub(crate) fn cmp_expand<F>(
96 context: &mut CubeContext,
97 lhs: ExpandElement,
98 rhs: ExpandElement,
99 func: F,
100) -> ExpandElement
101where
102 F: Fn(BinaryOperator) -> Operator,
103{
104 let lhs: Variable = *lhs;
105 let rhs: Variable = *rhs;
106 let item = lhs.item;
107
108 find_vectorization(item.vectorization, rhs.item.vectorization);
109
110 let out_item = Item {
111 elem: Elem::Bool,
112 vectorization: item.vectorization,
113 };
114
115 let out = context.create_local(out_item);
116 let out_var = *out;
117
118 let op = func(BinaryOperator { lhs, rhs });
119
120 context.register(Instruction::new(op, out_var));
121
122 out
123}
124
125pub(crate) fn assign_op_expand<F>(
126 context: &mut CubeContext,
127 lhs: ExpandElement,
128 rhs: ExpandElement,
129 func: F,
130) -> ExpandElement
131where
132 F: Fn(BinaryOperator) -> Operator,
133{
134 let lhs_var: Variable = *lhs;
135 let rhs: Variable = *rhs;
136
137 find_vectorization(lhs_var.item.vectorization, rhs.item.vectorization);
138
139 let op = func(BinaryOperator { lhs: lhs_var, rhs });
140
141 context.register(Instruction::new(op, lhs_var));
142
143 lhs
144}
145
146pub fn unary_expand<F>(context: &mut CubeContext, input: ExpandElement, func: F) -> ExpandElement
147where
148 F: Fn(UnaryOperator) -> Operator,
149{
150 let input = input.consume();
151 let item = input.item;
152
153 let out = context.create_local(item);
154 let out_var = *out;
155
156 let op = func(UnaryOperator { input });
157
158 context.register(Instruction::new(op, out_var));
159
160 out
161}
162
163pub fn unary_expand_fixed_output<F>(
164 context: &mut CubeContext,
165 input: ExpandElement,
166 out_item: Item,
167 func: F,
168) -> ExpandElement
169where
170 F: Fn(UnaryOperator) -> Operator,
171{
172 let input = input.consume();
173 let output = context.create_local(out_item);
174 let out = *output;
175
176 let op = func(UnaryOperator { input });
177
178 context.register(Instruction::new(op, out));
179
180 output
181}
182
183pub fn init_expand<F>(context: &mut CubeContext, input: ExpandElement, func: F) -> ExpandElement
184where
185 F: Fn(Variable) -> Operation,
186{
187 if input.can_mut() {
188 return input;
189 }
190 let input_var: Variable = *input;
191 let item = input.item;
192
193 let out = context.create_local_mut(item); let out_var = *out;
195
196 let op = func(input_var);
197 context.register(Instruction::new(op, out_var));
198
199 out
200}
201
202fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
203 match (lhs, rhs) {
204 (None, None) => None,
205 (None, Some(rhs)) => Some(rhs),
206 (Some(lhs), None) => Some(lhs),
207 (Some(lhs), Some(rhs)) => {
208 if lhs == rhs {
209 Some(lhs)
210 } else if lhs == NonZeroU8::new(1).unwrap() || rhs == NonZeroU8::new(1).unwrap() {
211 Some(core::cmp::max(lhs, rhs))
212 } else {
213 panic!(
214 "Left and right have different vectorizations.
215 Left: {lhs}, right: {rhs}.
216 Auto-matching fixed vectorization currently unsupported."
217 );
218 }
219 }
220 }
221}
222
223pub fn array_assign_binary_op_expand<
224 A: CubeType + CubeIndex<u32>,
225 V: CubeType,
226 F: Fn(BinaryOperator) -> Operator,
227>(
228 context: &mut CubeContext,
229 array: ExpandElementTyped<A>,
230 index: ExpandElementTyped<u32>,
231 value: ExpandElementTyped<V>,
232 func: F,
233) where
234 A::Output: CubeType + Sized,
235{
236 let array: ExpandElement = array.into();
237 let index: ExpandElement = index.into();
238 let value: ExpandElement = value.into();
239
240 let array_item = match array.kind {
241 VariableKind::LocalMut { .. } => array.item.vectorize(None),
243 _ => array.item,
244 };
245 let array_value = context.create_local(array_item);
246
247 let read = Instruction::new(
248 Operator::Index(BinaryOperator {
249 lhs: *array,
250 rhs: *index,
251 }),
252 *array_value,
253 );
254 let array_value = array_value.consume();
255 let op_out = context.create_local(array_item);
256 let calculate = Instruction::new(
257 func(BinaryOperator {
258 lhs: array_value,
259 rhs: *value,
260 }),
261 *op_out,
262 );
263
264 let write = Operator::IndexAssign(BinaryOperator {
265 lhs: *index,
266 rhs: op_out.consume(),
267 });
268 context.register(read);
269 context.register(calculate);
270 context.register(Instruction::new(write, *array));
271}