// cubecl_core/frontend/operation/base.rs

1use std::num::NonZeroU8;
2
3use cubecl_ir::{
4    Arithmetic, BinaryOperator, Comparison, Elem, ExpandElement, Instruction, Item, Operation,
5    Operator, Scope, UnaryOperator, Variable, VariableKind, Vectorization,
6};
7
8use crate::prelude::{CubeIndex, CubeType, ExpandElementTyped};
9
10pub(crate) fn binary_expand<F, Op>(
11    scope: &mut Scope,
12    lhs: ExpandElement,
13    rhs: ExpandElement,
14    func: F,
15) -> ExpandElement
16where
17    F: Fn(BinaryOperator) -> Op,
18    Op: Into<Operation>,
19{
20    let lhs = lhs.consume();
21    let rhs = rhs.consume();
22
23    let item_lhs = lhs.item;
24    let item_rhs = rhs.item;
25
26    let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);
27
28    let item = Item::vectorized(item_lhs.elem, vectorization);
29
30    let output = scope.create_local(item);
31    let out = *output;
32
33    let op = func(BinaryOperator { lhs, rhs });
34
35    scope.register(Instruction::new(op, out));
36
37    output
38}
39
40pub(crate) fn binary_expand_fixed_output<F>(
41    scope: &mut Scope,
42    lhs: ExpandElement,
43    rhs: ExpandElement,
44    out_item: Item,
45    func: F,
46) -> ExpandElement
47where
48    F: Fn(BinaryOperator) -> Arithmetic,
49{
50    let lhs_var = lhs.consume();
51    let rhs_var = rhs.consume();
52
53    let out = scope.create_local(out_item);
54
55    let out_var = *out;
56
57    let op = func(BinaryOperator {
58        lhs: lhs_var,
59        rhs: rhs_var,
60    });
61
62    scope.register(Instruction::new(op, out_var));
63
64    out
65}
66
67pub(crate) fn binary_expand_no_vec<F>(
68    scope: &mut Scope,
69    lhs: ExpandElement,
70    rhs: ExpandElement,
71    func: F,
72) -> ExpandElement
73where
74    F: Fn(BinaryOperator) -> Operator,
75{
76    let lhs = lhs.consume();
77    let rhs = rhs.consume();
78
79    let item_lhs = lhs.item;
80
81    let item = Item::new(item_lhs.elem);
82
83    let output = scope.create_local(item);
84    let out = *output;
85
86    let op = func(BinaryOperator { lhs, rhs });
87
88    scope.register(Instruction::new(op, out));
89
90    output
91}
92
93pub(crate) fn cmp_expand<F>(
94    scope: &mut Scope,
95    lhs: ExpandElement,
96    rhs: ExpandElement,
97    func: F,
98) -> ExpandElement
99where
100    F: Fn(BinaryOperator) -> Comparison,
101{
102    let lhs: Variable = *lhs;
103    let rhs: Variable = *rhs;
104    let item = lhs.item;
105
106    find_vectorization(item.vectorization, rhs.item.vectorization);
107
108    let out_item = Item {
109        elem: Elem::Bool,
110        vectorization: item.vectorization,
111    };
112
113    let out = scope.create_local(out_item);
114    let out_var = *out;
115
116    let op = func(BinaryOperator { lhs, rhs });
117
118    scope.register(Instruction::new(op, out_var));
119
120    out
121}
122
123pub(crate) fn assign_op_expand<F, Op>(
124    scope: &mut Scope,
125    lhs: ExpandElement,
126    rhs: ExpandElement,
127    func: F,
128) -> ExpandElement
129where
130    F: Fn(BinaryOperator) -> Op,
131    Op: Into<Operation>,
132{
133    let lhs_var: Variable = *lhs;
134    let rhs: Variable = *rhs;
135
136    find_vectorization(lhs_var.item.vectorization, rhs.item.vectorization);
137
138    let op = func(BinaryOperator { lhs: lhs_var, rhs });
139
140    scope.register(Instruction::new(op, lhs_var));
141
142    lhs
143}
144
145pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
146where
147    F: Fn(UnaryOperator) -> Op,
148    Op: Into<Operation>,
149{
150    let input = input.consume();
151    let item = input.item;
152
153    let out = scope.create_local(item);
154    let out_var = *out;
155
156    let op = func(UnaryOperator { input });
157
158    scope.register(Instruction::new(op, out_var));
159
160    out
161}
162
163pub fn unary_expand_fixed_output<F, Op>(
164    scope: &mut Scope,
165    input: ExpandElement,
166    out_item: Item,
167    func: F,
168) -> ExpandElement
169where
170    F: Fn(UnaryOperator) -> Op,
171    Op: Into<Operation>,
172{
173    let input = input.consume();
174    let output = scope.create_local(out_item);
175    let out = *output;
176
177    let op = func(UnaryOperator { input });
178
179    scope.register(Instruction::new(op, out));
180
181    output
182}
183
184pub fn init_expand<F>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
185where
186    F: Fn(Variable) -> Operation,
187{
188    if input.can_mut() {
189        return input;
190    }
191    let input_var: Variable = *input;
192    let item = input.item;
193
194    let out = scope.create_local_mut(item); // TODO: The mut is safe, but unnecessary if the variable is immutable.
195    let out_var = *out;
196
197    let op = func(input_var);
198    scope.register(Instruction::new(op, out_var));
199
200    out
201}
202
203fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
204    match (lhs, rhs) {
205        (None, None) => None,
206        (None, Some(rhs)) => Some(rhs),
207        (Some(lhs), None) => Some(lhs),
208        (Some(lhs), Some(rhs)) => {
209            if lhs == rhs {
210                Some(lhs)
211            } else if lhs == NonZeroU8::new(1).unwrap() || rhs == NonZeroU8::new(1).unwrap() {
212                Some(core::cmp::max(lhs, rhs))
213            } else {
214                panic!(
215                    "Left and right have different vectorizations.
216                    Left: {lhs}, right: {rhs}.
217                    Auto-matching fixed vectorization currently unsupported."
218                );
219            }
220        }
221    }
222}
223
224pub fn array_assign_binary_op_expand<
225    A: CubeType + CubeIndex<u32>,
226    V: CubeType,
227    F: Fn(BinaryOperator) -> Op,
228    Op: Into<Operation>,
229>(
230    scope: &mut Scope,
231    array: ExpandElementTyped<A>,
232    index: ExpandElementTyped<u32>,
233    value: ExpandElementTyped<V>,
234    func: F,
235) where
236    A::Output: CubeType + Sized,
237{
238    let array: ExpandElement = array.into();
239    let index: ExpandElement = index.into();
240    let value: ExpandElement = value.into();
241
242    let array_item = match array.kind {
243        // In that case, the array is a line.
244        VariableKind::LocalMut { .. } => array.item.vectorize(None),
245        _ => array.item,
246    };
247    let array_value = scope.create_local(array_item);
248
249    let read = Instruction::new(
250        Operator::Index(BinaryOperator {
251            lhs: *array,
252            rhs: *index,
253        }),
254        *array_value,
255    );
256    let array_value = array_value.consume();
257    let op_out = scope.create_local(array_item);
258    let calculate = Instruction::new(
259        func(BinaryOperator {
260            lhs: array_value,
261            rhs: *value,
262        }),
263        *op_out,
264    );
265
266    let write = Operator::IndexAssign(BinaryOperator {
267        lhs: *index,
268        rhs: op_out.consume(),
269    });
270    scope.register(read);
271    scope.register(calculate);
272    scope.register(Instruction::new(write, *array));
273}