cubecl_core/frontend/operation/
base.rs

1use cubecl_ir::{
2    Arithmetic, BinaryOperator, Comparison, ElemType, ExpandElement, IndexAssignOperator,
3    IndexOperator, Instruction, LineSize, Operation, Operator, Scope, Type, UnaryOperator,
4    Variable, VariableKind,
5};
6use cubecl_macros::cube;
7
8use crate::{
9    self as cubecl,
10    prelude::{CubeIndex, CubeType, ExpandElementTyped, Int, eq, rem},
11};
12
13pub(crate) fn binary_expand<F, Op>(
14    scope: &mut Scope,
15    lhs: ExpandElement,
16    rhs: ExpandElement,
17    func: F,
18) -> ExpandElement
19where
20    F: Fn(BinaryOperator) -> Op,
21    Op: Into<Operation>,
22{
23    let lhs = lhs.consume();
24    let rhs = rhs.consume();
25
26    let item_lhs = lhs.ty;
27    let item_rhs = rhs.ty;
28
29    let line_size = find_vectorization(item_lhs, item_rhs);
30
31    let item = item_lhs.line(line_size);
32
33    let output = scope.create_local(item);
34    let out = *output;
35
36    let op = func(BinaryOperator { lhs, rhs });
37
38    scope.register(Instruction::new(op, out));
39
40    output
41}
42
43pub(crate) fn index_expand_no_vec<F>(
44    scope: &mut Scope,
45    list: ExpandElement,
46    index: ExpandElement,
47    func: F,
48) -> ExpandElement
49where
50    F: Fn(IndexOperator) -> Operator,
51{
52    let list = list.consume();
53    let index = index.consume();
54
55    let item_lhs = list.ty;
56
57    let item = item_lhs.line(0);
58
59    let output = scope.create_local(item);
60    let out = *output;
61
62    let op = func(IndexOperator {
63        list,
64        index,
65        line_size: 0u32,
66        unroll_factor: 1,
67    });
68
69    scope.register(Instruction::new(op, out));
70
71    output
72}
73pub(crate) fn index_expand<F, Op>(
74    scope: &mut Scope,
75    list: ExpandElement,
76    index: ExpandElement,
77    line_size: Option<u32>,
78    func: F,
79) -> ExpandElement
80where
81    F: Fn(IndexOperator) -> Op,
82    Op: Into<Operation>,
83{
84    let list = list.consume();
85    let index = index.consume();
86
87    let item_lhs = list.ty;
88    let item_rhs = index.ty;
89
90    let vec = if let Some(line_size) = line_size {
91        line_size
92    } else {
93        find_vectorization(item_lhs, item_rhs)
94    };
95
96    let item = item_lhs.line(vec);
97
98    let output = scope.create_local(item);
99    let out = *output;
100
101    let op = func(IndexOperator {
102        list,
103        index,
104        line_size: line_size.unwrap_or(0),
105        unroll_factor: 1,
106    });
107
108    scope.register(Instruction::new(op, out));
109
110    output
111}
112
113pub(crate) fn binary_expand_fixed_output<F>(
114    scope: &mut Scope,
115    lhs: ExpandElement,
116    rhs: ExpandElement,
117    out_item: Type,
118    func: F,
119) -> ExpandElement
120where
121    F: Fn(BinaryOperator) -> Arithmetic,
122{
123    let lhs_var = lhs.consume();
124    let rhs_var = rhs.consume();
125
126    let out = scope.create_local(out_item);
127
128    let out_var = *out;
129
130    let op = func(BinaryOperator {
131        lhs: lhs_var,
132        rhs: rhs_var,
133    });
134
135    scope.register(Instruction::new(op, out_var));
136
137    out
138}
139
140pub(crate) fn cmp_expand<F>(
141    scope: &mut Scope,
142    lhs: ExpandElement,
143    rhs: ExpandElement,
144    func: F,
145) -> ExpandElement
146where
147    F: Fn(BinaryOperator) -> Comparison,
148{
149    let lhs = lhs.consume();
150    let rhs = rhs.consume();
151
152    let item_lhs = lhs.ty;
153    let item_rhs = rhs.ty;
154
155    let line_size = find_vectorization(item_lhs, item_rhs);
156
157    let out_item = Type::scalar(ElemType::Bool).line(line_size);
158
159    let out = scope.create_local(out_item);
160    let out_var = *out;
161
162    let op = func(BinaryOperator { lhs, rhs });
163
164    scope.register(Instruction::new(op, out_var));
165
166    out
167}
168
169pub(crate) fn assign_op_expand<F, Op>(
170    scope: &mut Scope,
171    lhs: ExpandElement,
172    rhs: ExpandElement,
173    func: F,
174) -> ExpandElement
175where
176    F: Fn(BinaryOperator) -> Op,
177    Op: Into<Operation>,
178{
179    if lhs.is_immutable() {
180        panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
181    }
182    let lhs_var: Variable = *lhs;
183    let rhs: Variable = *rhs;
184
185    let op = func(BinaryOperator { lhs: lhs_var, rhs });
186
187    scope.register(Instruction::new(op, lhs_var));
188
189    lhs
190}
191
192pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
193where
194    F: Fn(UnaryOperator) -> Op,
195    Op: Into<Operation>,
196{
197    let input = input.consume();
198    let item = input.ty;
199
200    let out = scope.create_local(item);
201    let out_var = *out;
202
203    let op = func(UnaryOperator { input });
204
205    scope.register(Instruction::new(op, out_var));
206
207    out
208}
209
210pub fn unary_expand_fixed_output<F, Op>(
211    scope: &mut Scope,
212    input: ExpandElement,
213    out_item: Type,
214    func: F,
215) -> ExpandElement
216where
217    F: Fn(UnaryOperator) -> Op,
218    Op: Into<Operation>,
219{
220    let input = input.consume();
221    let output = scope.create_local(out_item);
222    let out = *output;
223
224    let op = func(UnaryOperator { input });
225
226    scope.register(Instruction::new(op, out));
227
228    output
229}
230
231pub fn init_expand<F>(
232    scope: &mut Scope,
233    input: ExpandElement,
234    mutable: bool,
235    func: F,
236) -> ExpandElement
237where
238    F: Fn(Variable) -> Operation,
239{
240    let input_var: Variable = *input;
241    let item = input.ty;
242
243    let out = if mutable {
244        scope.create_local_mut(item)
245    } else {
246        scope.create_local(item)
247    };
248
249    let out_var = *out;
250
251    let op = func(input_var);
252    scope.register(Instruction::new(op, out_var));
253
254    out
255}
256
257fn find_vectorization(lhs: Type, rhs: Type) -> LineSize {
258    if matches!(lhs, Type::Scalar(_)) && matches!(rhs, Type::Scalar(_)) {
259        0
260    } else {
261        lhs.line_size().max(rhs.line_size())
262    }
263}
264
/// Expands `array[index] = array[index] <op> value` as three instructions registered in
/// order.
///
/// 1. An indexed read of `array[index]` into a new local.
/// 2. The binary operation built by `func`, applied to that local and `value`, into a second
///    local.
/// 3. An indexed write of the result back into `array` at `index`.
///
/// All three instructions use `line_size: 0` and `unroll_factor: 1`.
pub fn array_assign_binary_op_expand<
    A: CubeType + CubeIndex,
    V: CubeType,
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
>(
    scope: &mut Scope,
    array: ExpandElementTyped<A>,
    index: ExpandElementTyped<u32>,
    value: ExpandElementTyped<V>,
    func: F,
) where
    A::Output: CubeType + Sized,
{
    let array: ExpandElement = array.into();
    let index: ExpandElement = index.into();
    let value: ExpandElement = value.into();

    // Type of a single indexed element, used for both the read and the operation result.
    let array_item = match array.kind {
        // In that case, the array is a line.
        // The element is therefore taken with line size 0.
        VariableKind::LocalMut { .. } => array.ty.line(0),
        _ => array.ty,
    };
    let array_value = scope.create_local(array_item);

    // Step 1: array_value = array[index]. Built here but registered below.
    let read = Instruction::new(
        Operator::Index(IndexOperator {
            list: *array,
            index: *index,
            line_size: 0,
            unroll_factor: 1,
        }),
        *array_value,
    );
    // NOTE(review): `array_value` is consumed before `op_out` is created. If the scope recycles
    // consumed locals, `op_out` may reuse its slot, so keep this ordering unless that is
    // confirmed irrelevant.
    let array_value = array_value.consume();
    let op_out = scope.create_local(array_item);
    // Step 2: op_out = array_value <op> value.
    let calculate = Instruction::new(
        func(BinaryOperator {
            lhs: array_value,
            rhs: *value,
        }),
        *op_out,
    );

    // Step 3: array[index] = op_out. The destination of this instruction is `array` itself.
    let write = Operator::IndexAssign(IndexAssignOperator {
        index: *index,
        value: op_out.consume(),
        line_size: 0,
        unroll_factor: 1,
    });
    scope.register(read);
    scope.register(calculate);
    scope.register(Instruction::new(write, *array));
}
319
320// Utilities for clippy lint compatibility
321impl<E: Int> ExpandElementTyped<E> {
322    pub fn __expand_div_ceil_method(
323        self,
324        scope: &mut Scope,
325        divisor: ExpandElementTyped<E>,
326    ) -> ExpandElementTyped<E> {
327        div_ceil::expand::<E>(scope, self, divisor)
328    }
329
330    pub fn __expand_is_multiple_of_method(
331        self,
332        scope: &mut Scope,
333        factor: ExpandElementTyped<E>,
334    ) -> ExpandElementTyped<bool> {
335        let modulo = rem::expand(scope, self, factor);
336        eq::expand(scope, modulo, E::from_int(0).into())
337    }
338}
339
/// Integer division of `a` by `b`, rounded up, computed as `(a + b - 1) / b`.
///
/// NOTE(review): `a + b` is evaluated before the `- 1`. When `a` is close to the maximum
/// value of `E`, that intermediate sum can exceed `E`'s range and the result is then wrong.
/// Rust's `u32::div_ceil` does not have this limitation.
///
/// The "rounded up" reading also assumes `b > 0` and `a >= 0`. Confirm the intended
/// semantics for signed `E` with negative inputs.
#[cube]
pub fn div_ceil<E: Int>(a: E, b: E) -> E {
    (a + b - E::new(1)) / b
}