cubecl_core/frontend/operation/base.rs

1use std::num::NonZeroU8;
2
3use crate::ir::{
4    BinaryOperator, Elem, Instruction, Item, Operation, Operator, UnaryOperator, Variable,
5    VariableKind, Vectorization,
6};
7use crate::prelude::{CubeType, ExpandElementTyped};
8use crate::{
9    frontend::{CubeContext, ExpandElement},
10    prelude::CubeIndex,
11};
12
13pub(crate) fn binary_expand<F>(
14    context: &mut CubeContext,
15    lhs: ExpandElement,
16    rhs: ExpandElement,
17    func: F,
18) -> ExpandElement
19where
20    F: Fn(BinaryOperator) -> Operator,
21{
22    let lhs = lhs.consume();
23    let rhs = rhs.consume();
24
25    let item_lhs = lhs.item;
26    let item_rhs = rhs.item;
27
28    let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);
29
30    let item = Item::vectorized(item_lhs.elem, vectorization);
31
32    let output = context.create_local(item);
33    let out = *output;
34
35    let op = func(BinaryOperator { lhs, rhs });
36
37    context.register(Instruction::new(op, out));
38
39    output
40}
41
42pub(crate) fn binary_expand_fixed_output<F>(
43    context: &mut CubeContext,
44    lhs: ExpandElement,
45    rhs: ExpandElement,
46    out_item: Item,
47    func: F,
48) -> ExpandElement
49where
50    F: Fn(BinaryOperator) -> Operator,
51{
52    let lhs_var = lhs.consume();
53    let rhs_var = rhs.consume();
54
55    let out = context.create_local(out_item);
56
57    let out_var = *out;
58
59    let op = func(BinaryOperator {
60        lhs: lhs_var,
61        rhs: rhs_var,
62    });
63
64    context.register(Instruction::new(op, out_var));
65
66    out
67}
68
69pub(crate) fn binary_expand_no_vec<F>(
70    context: &mut CubeContext,
71    lhs: ExpandElement,
72    rhs: ExpandElement,
73    func: F,
74) -> ExpandElement
75where
76    F: Fn(BinaryOperator) -> Operator,
77{
78    let lhs = lhs.consume();
79    let rhs = rhs.consume();
80
81    let item_lhs = lhs.item;
82
83    let item = Item::new(item_lhs.elem);
84
85    let output = context.create_local(item);
86    let out = *output;
87
88    let op = func(BinaryOperator { lhs, rhs });
89
90    context.register(Instruction::new(op, out));
91
92    output
93}
94
95pub(crate) fn cmp_expand<F>(
96    context: &mut CubeContext,
97    lhs: ExpandElement,
98    rhs: ExpandElement,
99    func: F,
100) -> ExpandElement
101where
102    F: Fn(BinaryOperator) -> Operator,
103{
104    let lhs: Variable = *lhs;
105    let rhs: Variable = *rhs;
106    let item = lhs.item;
107
108    find_vectorization(item.vectorization, rhs.item.vectorization);
109
110    let out_item = Item {
111        elem: Elem::Bool,
112        vectorization: item.vectorization,
113    };
114
115    let out = context.create_local(out_item);
116    let out_var = *out;
117
118    let op = func(BinaryOperator { lhs, rhs });
119
120    context.register(Instruction::new(op, out_var));
121
122    out
123}
124
125pub(crate) fn assign_op_expand<F>(
126    context: &mut CubeContext,
127    lhs: ExpandElement,
128    rhs: ExpandElement,
129    func: F,
130) -> ExpandElement
131where
132    F: Fn(BinaryOperator) -> Operator,
133{
134    let lhs_var: Variable = *lhs;
135    let rhs: Variable = *rhs;
136
137    find_vectorization(lhs_var.item.vectorization, rhs.item.vectorization);
138
139    let op = func(BinaryOperator { lhs: lhs_var, rhs });
140
141    context.register(Instruction::new(op, lhs_var));
142
143    lhs
144}
145
146pub fn unary_expand<F>(context: &mut CubeContext, input: ExpandElement, func: F) -> ExpandElement
147where
148    F: Fn(UnaryOperator) -> Operator,
149{
150    let input = input.consume();
151    let item = input.item;
152
153    let out = context.create_local(item);
154    let out_var = *out;
155
156    let op = func(UnaryOperator { input });
157
158    context.register(Instruction::new(op, out_var));
159
160    out
161}
162
163pub fn unary_expand_fixed_output<F>(
164    context: &mut CubeContext,
165    input: ExpandElement,
166    out_item: Item,
167    func: F,
168) -> ExpandElement
169where
170    F: Fn(UnaryOperator) -> Operator,
171{
172    let input = input.consume();
173    let output = context.create_local(out_item);
174    let out = *output;
175
176    let op = func(UnaryOperator { input });
177
178    context.register(Instruction::new(op, out));
179
180    output
181}
182
183pub fn init_expand<F>(context: &mut CubeContext, input: ExpandElement, func: F) -> ExpandElement
184where
185    F: Fn(Variable) -> Operation,
186{
187    if input.can_mut() {
188        return input;
189    }
190    let input_var: Variable = *input;
191    let item = input.item;
192
193    let out = context.create_local_mut(item); // TODO: The mut is safe, but unecessary if the variable is immutable.
194    let out_var = *out;
195
196    let op = func(input_var);
197    context.register(Instruction::new(op, out_var));
198
199    out
200}
201
202fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
203    match (lhs, rhs) {
204        (None, None) => None,
205        (None, Some(rhs)) => Some(rhs),
206        (Some(lhs), None) => Some(lhs),
207        (Some(lhs), Some(rhs)) => {
208            if lhs == rhs {
209                Some(lhs)
210            } else if lhs == NonZeroU8::new(1).unwrap() || rhs == NonZeroU8::new(1).unwrap() {
211                Some(core::cmp::max(lhs, rhs))
212            } else {
213                panic!(
214                    "Left and right have different vectorizations.
215                    Left: {lhs}, right: {rhs}.
216                    Auto-matching fixed vectorization currently unsupported."
217                );
218            }
219        }
220    }
221}
222
/// Expands a compound assignment on an indexed element, i.e. `array[index] = func(array[index], value)`.
///
/// Builds three instructions and registers them in this order:
/// 1. an `Operator::Index` read of `array[index]` into a fresh local,
/// 2. `func(read_value, value)` into a second fresh local of the same item,
/// 3. an `Operator::IndexAssign` with `index` as lhs, the computed value as rhs, and `array` as
///    the output, writing the result back.
pub fn array_assign_binary_op_expand<
    A: CubeType + CubeIndex<u32>,
    V: CubeType,
    F: Fn(BinaryOperator) -> Operator,
>(
    context: &mut CubeContext,
    array: ExpandElementTyped<A>,
    index: ExpandElementTyped<u32>,
    value: ExpandElementTyped<V>,
    func: F,
) where
    A::Output: CubeType + Sized,
{
    let array: ExpandElement = array.into();
    let index: ExpandElement = index.into();
    let value: ExpandElement = value.into();

    // Item of a single element read out of `array`; used for both temporaries below.
    let array_item = match array.kind {
        // A `LocalMut` container is indexed as a line, so its item is taken with
        // `vectorize(None)` (presumably the scalar lane type — confirm). Other kinds use their
        // item unchanged.
        VariableKind::LocalMut { .. } => array.item.vectorize(None),
        _ => array.item,
    };
    // Temporary that receives `array[index]`.
    let array_value = context.create_local(array_item);

    let read = Instruction::new(
        Operator::Index(BinaryOperator {
            lhs: *array,
            rhs: *index,
        }),
        *array_value,
    );
    // The read's output variable has been captured above; consume the temporary before the
    // next local is created. The order of these create_local/consume calls may affect local
    // allocation, so keep it as is.
    let array_value = array_value.consume();
    // Temporary that receives `func(array[index], value)`.
    let op_out = context.create_local(array_item);
    let calculate = Instruction::new(
        func(BinaryOperator {
            lhs: array_value,
            rhs: *value,
        }),
        *op_out,
    );

    // Write-back: IndexAssign takes the index as `lhs`, the new value as `rhs`, and the
    // container as the instruction output.
    let write = Operator::IndexAssign(BinaryOperator {
        lhs: *index,
        rhs: op_out.consume(),
    });
    // Instructions are only registered here, in read -> calculate -> write order.
    context.register(read);
    context.register(calculate);
    context.register(Instruction::new(write, *array));
}