cubecl_core/frontend/plane.rs

1use cubecl_ir::ExpandElement;
2
3use super::{CubePrimitive, Line};
4use crate::prelude::ExpandElementTyped;
5use crate::{
6    ir::{Elem, Instruction, Item, Plane, Scope, UnaryOperator},
7    unexpanded,
8};
9
10/// Returns true if the cube unit has the lowest plane_unit_id among active unit in the plane
11pub fn plane_elect() -> bool {
12    unexpanded!()
13}
14
15/// Module containing the expand function for [plane_elect()].
16pub mod plane_elect {
17
18    use super::*;
19
20    /// Expand method of [plane_elect()].
21    pub fn expand(scope: &mut Scope) -> ExpandElementTyped<bool> {
22        let output = scope.create_local(Item::new(Elem::Bool));
23        let out = *output;
24
25        scope.register(Instruction::new(Plane::Elect, out));
26
27        output.into()
28    }
29}
30
31/// Broadcasts the value from the specified plane unit at the given index
32/// to all active units within that plane.
33#[allow(unused_variables)]
34pub fn plane_broadcast<E: CubePrimitive>(value: E, index: u32) -> E {
35    unexpanded!()
36}
37
38/// Module containing the expand function for [plane_broadcast()].
39pub mod plane_broadcast {
40
41    use super::*;
42
43    /// Expand method of [plane_broadcast()].
44    pub fn expand<E: CubePrimitive>(
45        scope: &mut Scope,
46        value: ExpandElementTyped<E>,
47        id: ExpandElementTyped<u32>,
48    ) -> ExpandElementTyped<E> {
49        let output = scope.create_local(value.expand.item);
50        let out = *output;
51        let lhs = *value.expand;
52        let rhs = *id.expand;
53
54        scope.register(Instruction::new(
55            Plane::Broadcast(crate::ir::BinaryOperator { lhs, rhs }),
56            out,
57        ));
58
59        output.into()
60    }
61}
62
63/// Perform a reduce sum operation across all units in a plane.
64#[allow(unused_variables)]
65pub fn plane_sum<E: CubePrimitive>(value: E) -> E {
66    unexpanded!()
67}
68
69/// Module containing the expand function for [plane_sum()].
70pub mod plane_sum {
71    use super::*;
72
73    /// Expand method of [plane_sum()].
74    pub fn expand<E: CubePrimitive>(
75        scope: &mut Scope,
76        elem: ExpandElementTyped<E>,
77    ) -> ExpandElementTyped<E> {
78        let elem: ExpandElement = elem.into();
79        let output = scope.create_local(elem.item);
80
81        let out = *output;
82        let input = *elem;
83
84        scope.register(Instruction::new(Plane::Sum(UnaryOperator { input }), out));
85
86        output.into()
87    }
88}
89
90/// Perform an inclusive sum operation across all units in a plane.
91/// This sums all values to the "left" of the unit, including this unit's value.
92/// Also known as "prefix sum" or "inclusive scan".
93///
94/// # Example
95/// `inclusive_sum([1, 2, 3, 4, 5]) == [1, 3, 6, 10, 15]`
96#[allow(unused_variables)]
97pub fn plane_inclusive_sum<E: CubePrimitive>(value: E) -> E {
98    unexpanded!()
99}
100
101/// Module containing the expand function for [plane_inclusive_sum()].
102pub mod plane_inclusive_sum {
103    use super::*;
104
105    /// Expand method of [plane_inclusive_sum()].
106    pub fn expand<E: CubePrimitive>(
107        scope: &mut Scope,
108        elem: ExpandElementTyped<E>,
109    ) -> ExpandElementTyped<E> {
110        let elem: ExpandElement = elem.into();
111        let output = scope.create_local(elem.item);
112
113        let out = *output;
114        let input = *elem;
115
116        scope.register(Instruction::new(
117            Plane::InclusiveSum(UnaryOperator { input }),
118            out,
119        ));
120
121        output.into()
122    }
123}
124
125/// Perform an exclusive sum operation across all units in a plane.
126/// This sums all values to the "left" of the unit, excluding this unit's value. The 0th unit will
127/// be set to `E::zero()`.
128/// Also known as "exclusive prefix sum" or "exclusive scan".
129///
130/// # Example
131/// `exclusive_sum([1, 2, 3, 4, 5]) == [0, 1, 3, 6, 10]`
132#[allow(unused_variables)]
133pub fn plane_exclusive_sum<E: CubePrimitive>(value: E) -> E {
134    unexpanded!()
135}
136
137/// Module containing the expand function for [plane_exclusive_sum()].
138pub mod plane_exclusive_sum {
139    use super::*;
140
141    /// Expand method of [plane_exclusive_sum()].
142    pub fn expand<E: CubePrimitive>(
143        scope: &mut Scope,
144        elem: ExpandElementTyped<E>,
145    ) -> ExpandElementTyped<E> {
146        let elem: ExpandElement = elem.into();
147        let output = scope.create_local(elem.item);
148
149        let out = *output;
150        let input = *elem;
151
152        scope.register(Instruction::new(
153            Plane::ExclusiveSum(UnaryOperator { input }),
154            out,
155        ));
156
157        output.into()
158    }
159}
160
161/// Perform a reduce prod operation across all units in a plane.
162pub fn plane_prod<E: CubePrimitive>(_elem: E) -> E {
163    unexpanded!()
164}
165
166/// Module containing the expand function for [plane_prod()].
167pub mod plane_prod {
168    use super::*;
169
170    /// Expand method of [plane_prod()].
171    pub fn expand<E: CubePrimitive>(
172        scope: &mut Scope,
173        elem: ExpandElementTyped<E>,
174    ) -> ExpandElementTyped<E> {
175        let elem: ExpandElement = elem.into();
176        let output = scope.create_local(elem.item);
177
178        let out = *output;
179        let input = *elem;
180
181        scope.register(Instruction::new(Plane::Prod(UnaryOperator { input }), out));
182
183        output.into()
184    }
185}
186
187/// Perform an inclusive product operation across all units in a plane.
188/// This multiplies all values to the "left" of the unit, including this unit's value.
189/// Also known as "prefix product" or "inclusive scan".
190///
191/// # Example
192/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 2, 6, 24, 120]`
193#[allow(unused_variables)]
194pub fn plane_inclusive_prod<E: CubePrimitive>(value: E) -> E {
195    unexpanded!()
196}
197
198/// Module containing the expand function for [plane_inclusive_prod()].
199pub mod plane_inclusive_prod {
200    use super::*;
201
202    /// Expand method of [plane_inclusive_prod()].
203    pub fn expand<E: CubePrimitive>(
204        scope: &mut Scope,
205        elem: ExpandElementTyped<E>,
206    ) -> ExpandElementTyped<E> {
207        let elem: ExpandElement = elem.into();
208        let output = scope.create_local(elem.item);
209
210        let out = *output;
211        let input = *elem;
212
213        scope.register(Instruction::new(
214            Plane::InclusiveProd(UnaryOperator { input }),
215            out,
216        ));
217
218        output.into()
219    }
220}
221
222/// Perform an exclusive product operation across all units in a plane.
223/// This multiplies all values to the "left" of the unit, excluding this unit's value. The 0th unit
224/// will be set to `E::one()`.
225/// Also known as "exclusive prefix product" or "exclusive scan".
226///
227/// # Example
228/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 1, 2, 6, 24]`
229#[allow(unused_variables)]
230pub fn plane_exclusive_prod<E: CubePrimitive>(value: E) -> E {
231    unexpanded!()
232}
233
234/// Module containing the expand function for [plane_exclusive_prod()].
235pub mod plane_exclusive_prod {
236    use super::*;
237
238    /// Expand method of [plane_exclusive_prod()].
239    pub fn expand<E: CubePrimitive>(
240        scope: &mut Scope,
241        elem: ExpandElementTyped<E>,
242    ) -> ExpandElementTyped<E> {
243        let elem: ExpandElement = elem.into();
244        let output = scope.create_local(elem.item);
245
246        let out = *output;
247        let input = *elem;
248
249        scope.register(Instruction::new(
250            Plane::ExclusiveProd(UnaryOperator { input }),
251            out,
252        ));
253
254        output.into()
255    }
256}
257
258/// Perform a reduce max operation across all units in a plane.
259pub fn plane_max<E: CubePrimitive>(_elem: E) -> E {
260    unexpanded!()
261}
262
263/// Module containing the expand function for [plane_max()].
264pub mod plane_max {
265    use super::*;
266
267    /// Expand method of [plane_max()].
268    pub fn expand<E: CubePrimitive>(
269        scope: &mut Scope,
270        elem: ExpandElementTyped<E>,
271    ) -> ExpandElementTyped<E> {
272        let elem: ExpandElement = elem.into();
273        let output = scope.create_local(elem.item);
274
275        let out = *output;
276        let input = *elem;
277
278        scope.register(Instruction::new(Plane::Max(UnaryOperator { input }), out));
279
280        output.into()
281    }
282}
283
284/// Perform a reduce min operation across all units in a plane.
285pub fn plane_min<E: CubePrimitive>(_elem: E) -> E {
286    unexpanded!()
287}
288
289/// Module containing the expand function for [plane_min()].
290pub mod plane_min {
291    use super::*;
292
293    /// Expand method of [plane_min()].
294    pub fn expand<E: CubePrimitive>(
295        scope: &mut Scope,
296        elem: ExpandElementTyped<E>,
297    ) -> ExpandElementTyped<E> {
298        let elem: ExpandElement = elem.into();
299        let output = scope.create_local(elem.item);
300
301        let out = *output;
302        let input = *elem;
303
304        scope.register(Instruction::new(Plane::Min(UnaryOperator { input }), out));
305
306        output.into()
307    }
308}
309
310/// Perform a reduce all operation across all units in a plane.
311pub fn plane_all(_elem: bool) -> bool {
312    unexpanded!()
313}
314
315/// Module containing the expand function for [plane_all()].
316pub mod plane_all {
317
318    use super::*;
319
320    /// Expand method of [plane_all()].
321    pub fn expand(scope: &mut Scope, elem: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
322        let elem: ExpandElement = elem.into();
323        let output = scope.create_local(elem.item);
324
325        let out = *output;
326        let input = *elem;
327
328        scope.register(Instruction::new(Plane::All(UnaryOperator { input }), out));
329
330        output.into()
331    }
332}
333
334/// Perform a reduce any operation across all units in a plane.
335pub fn plane_any(_elem: bool) -> bool {
336    unexpanded!()
337}
338
339/// Module containing the expand function for [plane_any()].
340pub mod plane_any {
341
342    use super::*;
343
344    /// Expand method of [plane_any()].
345    pub fn expand(scope: &mut Scope, elem: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
346        let elem: ExpandElement = elem.into();
347        let output = scope.create_local(elem.item);
348
349        let out = *output;
350        let input = *elem;
351
352        scope.register(Instruction::new(Plane::Any(UnaryOperator { input }), out));
353
354        output.into()
355    }
356}
357
358/// Perform a ballot operation across all units in a plane.
359/// Returns a set of 32-bit bitfields as a [`Line`], with each element containing the value from 32
360/// invocations.
361/// Note that line size will always be set to 4 even for `PLANE_DIM <= 64`, because we can't
362/// retrieve the actual plane size at expand time. Use the runtime`PLANE_DIM` to index appropriately.
363pub fn plane_ballot(_elem: bool) -> Line<u32> {
364    unexpanded!()
365}
366
367/// Module containing the expand function for [plane_ballot()].
368pub mod plane_ballot {
369
370    use std::num::NonZero;
371
372    use cubecl_ir::UIntKind;
373
374    use super::*;
375
376    /// Expand method of [plane_ballot()].
377    pub fn expand(
378        scope: &mut Scope,
379        elem: ExpandElementTyped<bool>,
380    ) -> ExpandElementTyped<Line<u32>> {
381        let elem: ExpandElement = elem.into();
382        let out_item = Item::vectorized(Elem::UInt(UIntKind::U32), NonZero::new(4));
383        let output = scope.create_local(out_item);
384
385        let out = *output;
386        let input = *elem;
387
388        scope.register(Instruction::new(
389            Plane::Ballot(UnaryOperator { input }),
390            out,
391        ));
392
393        output.into()
394    }
395}