cubecl_core/frontend/operation/
base.rs

1use std::num::{NonZero, NonZeroU8};
2
3use cubecl_ir::{
4    Arithmetic, BinaryOperator, Comparison, Elem, ExpandElement, IndexAssignOperator,
5    IndexOperator, Instruction, Item, Operation, Operator, Scope, UnaryOperator, Variable,
6    VariableKind, Vectorization,
7};
8
9use crate::prelude::{CubeIndex, CubeType, ExpandElementTyped};
10
11pub(crate) fn binary_expand<F, Op>(
12    scope: &mut Scope,
13    lhs: ExpandElement,
14    rhs: ExpandElement,
15    func: F,
16) -> ExpandElement
17where
18    F: Fn(BinaryOperator) -> Op,
19    Op: Into<Operation>,
20{
21    let lhs = lhs.consume();
22    let rhs = rhs.consume();
23
24    let item_lhs = lhs.item;
25    let item_rhs = rhs.item;
26
27    let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);
28
29    let item = Item::vectorized(item_lhs.elem, vectorization);
30
31    let output = scope.create_local(item);
32    let out = *output;
33
34    let op = func(BinaryOperator { lhs, rhs });
35
36    scope.register(Instruction::new(op, out));
37
38    output
39}
40
41pub(crate) fn index_expand_no_vec<F>(
42    scope: &mut Scope,
43    list: ExpandElement,
44    index: ExpandElement,
45    func: F,
46) -> ExpandElement
47where
48    F: Fn(IndexOperator) -> Operator,
49{
50    let list = list.consume();
51    let index = index.consume();
52
53    let item_lhs = list.item;
54
55    let item = Item::new(item_lhs.elem);
56
57    let output = scope.create_local(item);
58    let out = *output;
59
60    let op = func(IndexOperator {
61        list,
62        index,
63        line_size: 0u32,
64    });
65
66    scope.register(Instruction::new(op, out));
67
68    output
69}
70pub(crate) fn index_expand<F, Op>(
71    scope: &mut Scope,
72    list: ExpandElement,
73    index: ExpandElement,
74    line_size: Option<u32>,
75    func: F,
76) -> ExpandElement
77where
78    F: Fn(IndexOperator) -> Op,
79    Op: Into<Operation>,
80{
81    let list = list.consume();
82    let index = index.consume();
83
84    let item_lhs = list.item;
85    let item_rhs = index.item;
86
87    let vectorization = if let Some(line_size) = line_size {
88        NonZero::new(line_size as u8)
89    } else {
90        find_vectorization(item_lhs.vectorization, item_rhs.vectorization)
91    };
92
93    let item = Item::vectorized(item_lhs.elem, vectorization);
94
95    let output = scope.create_local(item);
96    let out = *output;
97
98    let op = func(IndexOperator {
99        list,
100        index,
101        line_size: line_size.unwrap_or(0),
102    });
103
104    scope.register(Instruction::new(op, out));
105
106    output
107}
108
109pub(crate) fn binary_expand_fixed_output<F>(
110    scope: &mut Scope,
111    lhs: ExpandElement,
112    rhs: ExpandElement,
113    out_item: Item,
114    func: F,
115) -> ExpandElement
116where
117    F: Fn(BinaryOperator) -> Arithmetic,
118{
119    let lhs_var = lhs.consume();
120    let rhs_var = rhs.consume();
121
122    let out = scope.create_local(out_item);
123
124    let out_var = *out;
125
126    let op = func(BinaryOperator {
127        lhs: lhs_var,
128        rhs: rhs_var,
129    });
130
131    scope.register(Instruction::new(op, out_var));
132
133    out
134}
135
136pub(crate) fn cmp_expand<F>(
137    scope: &mut Scope,
138    lhs: ExpandElement,
139    rhs: ExpandElement,
140    func: F,
141) -> ExpandElement
142where
143    F: Fn(BinaryOperator) -> Comparison,
144{
145    let lhs: Variable = *lhs;
146    let rhs: Variable = *rhs;
147    let item = lhs.item;
148
149    find_vectorization(item.vectorization, rhs.item.vectorization);
150
151    let out_item = Item {
152        elem: Elem::Bool,
153        vectorization: item.vectorization,
154    };
155
156    let out = scope.create_local(out_item);
157    let out_var = *out;
158
159    let op = func(BinaryOperator { lhs, rhs });
160
161    scope.register(Instruction::new(op, out_var));
162
163    out
164}
165
166pub(crate) fn assign_op_expand<F, Op>(
167    scope: &mut Scope,
168    lhs: ExpandElement,
169    rhs: ExpandElement,
170    func: F,
171) -> ExpandElement
172where
173    F: Fn(BinaryOperator) -> Op,
174    Op: Into<Operation>,
175{
176    if lhs.is_immutable() {
177        panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
178    }
179    let lhs_var: Variable = *lhs;
180    let rhs: Variable = *rhs;
181
182    find_vectorization(lhs_var.item.vectorization, rhs.item.vectorization);
183
184    let op = func(BinaryOperator { lhs: lhs_var, rhs });
185
186    scope.register(Instruction::new(op, lhs_var));
187
188    lhs
189}
190
191pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
192where
193    F: Fn(UnaryOperator) -> Op,
194    Op: Into<Operation>,
195{
196    let input = input.consume();
197    let item = input.item;
198
199    let out = scope.create_local(item);
200    let out_var = *out;
201
202    let op = func(UnaryOperator { input });
203
204    scope.register(Instruction::new(op, out_var));
205
206    out
207}
208
209pub fn unary_expand_fixed_output<F, Op>(
210    scope: &mut Scope,
211    input: ExpandElement,
212    out_item: Item,
213    func: F,
214) -> ExpandElement
215where
216    F: Fn(UnaryOperator) -> Op,
217    Op: Into<Operation>,
218{
219    let input = input.consume();
220    let output = scope.create_local(out_item);
221    let out = *output;
222
223    let op = func(UnaryOperator { input });
224
225    scope.register(Instruction::new(op, out));
226
227    output
228}
229
230pub fn init_expand<F>(
231    scope: &mut Scope,
232    input: ExpandElement,
233    mutable: bool,
234    func: F,
235) -> ExpandElement
236where
237    F: Fn(Variable) -> Operation,
238{
239    let input_var: Variable = *input;
240    let item = input.item;
241
242    let out = if mutable {
243        scope.create_local_mut(item)
244    } else {
245        scope.create_local(item)
246    };
247
248    let out_var = *out;
249
250    let op = func(input_var);
251    scope.register(Instruction::new(op, out_var));
252
253    out
254}
255
256fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
257    match (lhs, rhs) {
258        (None, None) => None,
259        (None, Some(rhs)) => Some(rhs),
260        (Some(lhs), None) => Some(lhs),
261        (Some(lhs), Some(rhs)) => {
262            if lhs == rhs {
263                Some(lhs)
264            } else if lhs == NonZeroU8::new(1).unwrap() || rhs == NonZeroU8::new(1).unwrap() {
265                Some(core::cmp::max(lhs, rhs))
266            } else {
267                panic!(
268                    "Left and right have different vectorizations.
269                    Left: {lhs}, right: {rhs}.
270                    Auto-matching fixed vectorization currently unsupported."
271                );
272            }
273        }
274    }
275}
276
277pub fn array_assign_binary_op_expand<
278    A: CubeType + CubeIndex,
279    V: CubeType,
280    F: Fn(BinaryOperator) -> Op,
281    Op: Into<Operation>,
282>(
283    scope: &mut Scope,
284    array: ExpandElementTyped<A>,
285    index: ExpandElementTyped<u32>,
286    value: ExpandElementTyped<V>,
287    func: F,
288) where
289    A::Output: CubeType + Sized,
290{
291    let array: ExpandElement = array.into();
292    let index: ExpandElement = index.into();
293    let value: ExpandElement = value.into();
294
295    let array_item = match array.kind {
296        // In that case, the array is a line.
297        VariableKind::LocalMut { .. } => array.item.vectorize(None),
298        _ => array.item,
299    };
300    let array_value = scope.create_local(array_item);
301
302    let read = Instruction::new(
303        Operator::Index(IndexOperator {
304            list: *array,
305            index: *index,
306            line_size: 0,
307        }),
308        *array_value,
309    );
310    let array_value = array_value.consume();
311    let op_out = scope.create_local(array_item);
312    let calculate = Instruction::new(
313        func(BinaryOperator {
314            lhs: array_value,
315            rhs: *value,
316        }),
317        *op_out,
318    );
319
320    let write = Operator::IndexAssign(IndexAssignOperator {
321        index: *index,
322        value: op_out.consume(),
323        line_size: 0,
324    });
325    scope.register(read);
326    scope.register(calculate);
327    scope.register(Instruction::new(write, *array));
328}