cubecl_core/frontend/
plane.rs

1use cubecl_ir::ExpandElement;
2
3use super::{CubePrimitive, Line};
4use crate::prelude::ExpandElementTyped;
5use crate::{
6    ir::{ElemType, Instruction, Plane, Scope, Type, UnaryOperator},
7    unexpanded,
8};
9
10/// Returns true if the cube unit has the lowest plane_unit_id among active unit in the plane
11pub fn plane_elect() -> bool {
12    unexpanded!()
13}
14
15/// Module containing the expand function for [plane_elect()].
16pub mod plane_elect {
17
18    use super::*;
19
20    /// Expand method of [plane_elect()].
21    pub fn expand(scope: &mut Scope) -> ExpandElementTyped<bool> {
22        let output = scope.create_local(Type::scalar(ElemType::Bool));
23        let out = *output;
24
25        scope.register(Instruction::new(Plane::Elect, out));
26
27        output.into()
28    }
29}
30
31/// Broadcasts the value from the specified plane unit at the given index
32/// to all active units within that plane.
33#[allow(unused_variables)]
34pub fn plane_broadcast<E: CubePrimitive>(value: E, index: u32) -> E {
35    unexpanded!()
36}
37
38/// Module containing the expand function for [plane_broadcast()].
39pub mod plane_broadcast {
40
41    use super::*;
42
43    /// Expand method of [plane_broadcast()].
44    pub fn expand<E: CubePrimitive>(
45        scope: &mut Scope,
46        value: ExpandElementTyped<E>,
47        id: ExpandElementTyped<u32>,
48    ) -> ExpandElementTyped<E> {
49        let output = scope.create_local(value.expand.ty);
50        let out = *output;
51        let lhs = *value.expand;
52        let rhs = *id.expand;
53
54        scope.register(Instruction::new(
55            Plane::Broadcast(crate::ir::BinaryOperator { lhs, rhs }),
56            out,
57        ));
58
59        output.into()
60    }
61}
62
63/// Perform an arbitrary lane shuffle operation across the plane.
64/// Each unit reads the value from the specified source lane.
65///
66/// # Example
67/// `plane_shuffle(value, 0)` - all lanes read from lane 0 (same as broadcast)
68/// `plane_shuffle(value, lane_id ^ 1)` - butterfly pattern (same as shuffle_xor)
69#[allow(unused_variables)]
70pub fn plane_shuffle<E: CubePrimitive>(value: E, src_lane: u32) -> E {
71    unexpanded!()
72}
73
74/// Module containing the expand function for [plane_shuffle()].
75pub mod plane_shuffle {
76
77    use super::*;
78
79    /// Expand method of [plane_shuffle()].
80    pub fn expand<E: CubePrimitive>(
81        scope: &mut Scope,
82        value: ExpandElementTyped<E>,
83        src_lane: ExpandElementTyped<u32>,
84    ) -> ExpandElementTyped<E> {
85        let output = scope.create_local(value.expand.ty);
86        let out = *output;
87        let lhs = *value.expand;
88        let rhs = *src_lane.expand;
89
90        scope.register(Instruction::new(
91            Plane::Shuffle(crate::ir::BinaryOperator { lhs, rhs }),
92            out,
93        ));
94
95        output.into()
96    }
97}
98
99/// Perform a shuffle XOR operation across the plane.
100/// Each unit exchanges its value with another unit at an index determined by XOR with the mask.
101/// This is useful for butterfly reduction patterns.
102///
103/// # Example
104/// For a 32-lane warp with mask=1:
105/// - Lane 0 gets value from lane 1, lane 1 gets value from lane 0
106/// - Lane 2 gets value from lane 3, lane 3 gets value from lane 2
107/// - etc.
108#[allow(unused_variables)]
109pub fn plane_shuffle_xor<E: CubePrimitive>(value: E, mask: u32) -> E {
110    unexpanded!()
111}
112
113/// Module containing the expand function for [plane_shuffle_xor()].
114pub mod plane_shuffle_xor {
115
116    use super::*;
117
118    /// Expand method of [plane_shuffle_xor()].
119    pub fn expand<E: CubePrimitive>(
120        scope: &mut Scope,
121        value: ExpandElementTyped<E>,
122        mask: ExpandElementTyped<u32>,
123    ) -> ExpandElementTyped<E> {
124        let output = scope.create_local(value.expand.ty);
125        let out = *output;
126        let lhs = *value.expand;
127        let rhs = *mask.expand;
128
129        scope.register(Instruction::new(
130            Plane::ShuffleXor(crate::ir::BinaryOperator { lhs, rhs }),
131            out,
132        ));
133
134        output.into()
135    }
136}
137
138/// Perform a shuffle up operation across the plane.
139/// Each unit reads the value from a unit with a lower lane ID (current_id - delta).
140/// Units with lane_id < delta will read from themselves (no change).
141///
142/// # Example
143/// For delta=1: `[a, b, c, d] -> [a, a, b, c]`
144#[allow(unused_variables)]
145pub fn plane_shuffle_up<E: CubePrimitive>(value: E, delta: u32) -> E {
146    unexpanded!()
147}
148
149/// Module containing the expand function for [plane_shuffle_up()].
150pub mod plane_shuffle_up {
151
152    use super::*;
153
154    /// Expand method of [plane_shuffle_up()].
155    pub fn expand<E: CubePrimitive>(
156        scope: &mut Scope,
157        value: ExpandElementTyped<E>,
158        delta: ExpandElementTyped<u32>,
159    ) -> ExpandElementTyped<E> {
160        let output = scope.create_local(value.expand.ty);
161        let out = *output;
162        let lhs = *value.expand;
163        let rhs = *delta.expand;
164
165        scope.register(Instruction::new(
166            Plane::ShuffleUp(crate::ir::BinaryOperator { lhs, rhs }),
167            out,
168        ));
169
170        output.into()
171    }
172}
173
174/// Perform a shuffle down operation across the plane.
175/// Each unit reads the value from a unit with a higher lane ID (current_id + delta).
176/// Units at the end will read from themselves if (lane_id + delta >= plane_dim).
177///
178/// # Example
179/// For delta=1: `[a, b, c, d] -> [b, c, d, d]`
180#[allow(unused_variables)]
181pub fn plane_shuffle_down<E: CubePrimitive>(value: E, delta: u32) -> E {
182    unexpanded!()
183}
184
185/// Module containing the expand function for [plane_shuffle_down()].
186pub mod plane_shuffle_down {
187
188    use super::*;
189
190    /// Expand method of [plane_shuffle_down()].
191    pub fn expand<E: CubePrimitive>(
192        scope: &mut Scope,
193        value: ExpandElementTyped<E>,
194        delta: ExpandElementTyped<u32>,
195    ) -> ExpandElementTyped<E> {
196        let output = scope.create_local(value.expand.ty);
197        let out = *output;
198        let lhs = *value.expand;
199        let rhs = *delta.expand;
200
201        scope.register(Instruction::new(
202            Plane::ShuffleDown(crate::ir::BinaryOperator { lhs, rhs }),
203            out,
204        ));
205
206        output.into()
207    }
208}
209
210/// Perform a reduce sum operation across all units in a plane.
211#[allow(unused_variables)]
212pub fn plane_sum<E: CubePrimitive>(value: E) -> E {
213    unexpanded!()
214}
215
216/// Module containing the expand function for [plane_sum()].
217pub mod plane_sum {
218    use super::*;
219
220    /// Expand method of [plane_sum()].
221    pub fn expand<E: CubePrimitive>(
222        scope: &mut Scope,
223        elem: ExpandElementTyped<E>,
224    ) -> ExpandElementTyped<E> {
225        let elem: ExpandElement = elem.into();
226        let output = scope.create_local(elem.ty);
227
228        let out = *output;
229        let input = *elem;
230
231        scope.register(Instruction::new(Plane::Sum(UnaryOperator { input }), out));
232
233        output.into()
234    }
235}
236
237/// Perform an inclusive sum operation across all units in a plane.
238/// This sums all values to the "left" of the unit, including this unit's value.
239/// Also known as "prefix sum" or "inclusive scan".
240///
241/// # Example
242/// `inclusive_sum([1, 2, 3, 4, 5]) == [1, 3, 6, 10, 15]`
243#[allow(unused_variables)]
244pub fn plane_inclusive_sum<E: CubePrimitive>(value: E) -> E {
245    unexpanded!()
246}
247
248/// Module containing the expand function for [plane_inclusive_sum()].
249pub mod plane_inclusive_sum {
250    use super::*;
251
252    /// Expand method of [plane_inclusive_sum()].
253    pub fn expand<E: CubePrimitive>(
254        scope: &mut Scope,
255        elem: ExpandElementTyped<E>,
256    ) -> ExpandElementTyped<E> {
257        let elem: ExpandElement = elem.into();
258        let output = scope.create_local(elem.ty);
259
260        let out = *output;
261        let input = *elem;
262
263        scope.register(Instruction::new(
264            Plane::InclusiveSum(UnaryOperator { input }),
265            out,
266        ));
267
268        output.into()
269    }
270}
271
272/// Perform an exclusive sum operation across all units in a plane.
273/// This sums all values to the "left" of the unit, excluding this unit's value. The 0th unit will
274/// be set to `E::zero()`.
275/// Also known as "exclusive prefix sum" or "exclusive scan".
276///
277/// # Example
278/// `exclusive_sum([1, 2, 3, 4, 5]) == [0, 1, 3, 6, 10]`
279#[allow(unused_variables)]
280pub fn plane_exclusive_sum<E: CubePrimitive>(value: E) -> E {
281    unexpanded!()
282}
283
284/// Module containing the expand function for [plane_exclusive_sum()].
285pub mod plane_exclusive_sum {
286    use super::*;
287
288    /// Expand method of [plane_exclusive_sum()].
289    pub fn expand<E: CubePrimitive>(
290        scope: &mut Scope,
291        elem: ExpandElementTyped<E>,
292    ) -> ExpandElementTyped<E> {
293        let elem: ExpandElement = elem.into();
294        let output = scope.create_local(elem.ty);
295
296        let out = *output;
297        let input = *elem;
298
299        scope.register(Instruction::new(
300            Plane::ExclusiveSum(UnaryOperator { input }),
301            out,
302        ));
303
304        output.into()
305    }
306}
307
308/// Perform a reduce prod operation across all units in a plane.
309pub fn plane_prod<E: CubePrimitive>(_elem: E) -> E {
310    unexpanded!()
311}
312
313/// Module containing the expand function for [plane_prod()].
314pub mod plane_prod {
315    use super::*;
316
317    /// Expand method of [plane_prod()].
318    pub fn expand<E: CubePrimitive>(
319        scope: &mut Scope,
320        elem: ExpandElementTyped<E>,
321    ) -> ExpandElementTyped<E> {
322        let elem: ExpandElement = elem.into();
323        let output = scope.create_local(elem.ty);
324
325        let out = *output;
326        let input = *elem;
327
328        scope.register(Instruction::new(Plane::Prod(UnaryOperator { input }), out));
329
330        output.into()
331    }
332}
333
334/// Perform an inclusive product operation across all units in a plane.
335/// This multiplies all values to the "left" of the unit, including this unit's value.
336/// Also known as "prefix product" or "inclusive scan".
337///
338/// # Example
339/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 2, 6, 24, 120]`
340#[allow(unused_variables)]
341pub fn plane_inclusive_prod<E: CubePrimitive>(value: E) -> E {
342    unexpanded!()
343}
344
345/// Module containing the expand function for [plane_inclusive_prod()].
346pub mod plane_inclusive_prod {
347    use super::*;
348
349    /// Expand method of [plane_inclusive_prod()].
350    pub fn expand<E: CubePrimitive>(
351        scope: &mut Scope,
352        elem: ExpandElementTyped<E>,
353    ) -> ExpandElementTyped<E> {
354        let elem: ExpandElement = elem.into();
355        let output = scope.create_local(elem.ty);
356
357        let out = *output;
358        let input = *elem;
359
360        scope.register(Instruction::new(
361            Plane::InclusiveProd(UnaryOperator { input }),
362            out,
363        ));
364
365        output.into()
366    }
367}
368
369/// Perform an exclusive product operation across all units in a plane.
370/// This multiplies all values to the "left" of the unit, excluding this unit's value. The 0th unit
371/// will be set to `E::one()`.
372/// Also known as "exclusive prefix product" or "exclusive scan".
373///
374/// # Example
375/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 1, 2, 6, 24]`
376#[allow(unused_variables)]
377pub fn plane_exclusive_prod<E: CubePrimitive>(value: E) -> E {
378    unexpanded!()
379}
380
381/// Module containing the expand function for [plane_exclusive_prod()].
382pub mod plane_exclusive_prod {
383    use super::*;
384
385    /// Expand method of [plane_exclusive_prod()].
386    pub fn expand<E: CubePrimitive>(
387        scope: &mut Scope,
388        elem: ExpandElementTyped<E>,
389    ) -> ExpandElementTyped<E> {
390        let elem: ExpandElement = elem.into();
391        let output = scope.create_local(elem.ty);
392
393        let out = *output;
394        let input = *elem;
395
396        scope.register(Instruction::new(
397            Plane::ExclusiveProd(UnaryOperator { input }),
398            out,
399        ));
400
401        output.into()
402    }
403}
404
405/// Perform a reduce max operation across all units in a plane.
406pub fn plane_max<E: CubePrimitive>(_elem: E) -> E {
407    unexpanded!()
408}
409
410/// Module containing the expand function for [plane_max()].
411pub mod plane_max {
412    use super::*;
413
414    /// Expand method of [plane_max()].
415    pub fn expand<E: CubePrimitive>(
416        scope: &mut Scope,
417        elem: ExpandElementTyped<E>,
418    ) -> ExpandElementTyped<E> {
419        let elem: ExpandElement = elem.into();
420        let output = scope.create_local(elem.ty);
421
422        let out = *output;
423        let input = *elem;
424
425        scope.register(Instruction::new(Plane::Max(UnaryOperator { input }), out));
426
427        output.into()
428    }
429}
430
431/// Perform a reduce min operation across all units in a plane.
432pub fn plane_min<E: CubePrimitive>(_elem: E) -> E {
433    unexpanded!()
434}
435
436/// Module containing the expand function for [plane_min()].
437pub mod plane_min {
438    use super::*;
439
440    /// Expand method of [plane_min()].
441    pub fn expand<E: CubePrimitive>(
442        scope: &mut Scope,
443        elem: ExpandElementTyped<E>,
444    ) -> ExpandElementTyped<E> {
445        let elem: ExpandElement = elem.into();
446        let output = scope.create_local(elem.ty);
447
448        let out = *output;
449        let input = *elem;
450
451        scope.register(Instruction::new(Plane::Min(UnaryOperator { input }), out));
452
453        output.into()
454    }
455}
456
457/// Perform a reduce all operation across all units in a plane.
458pub fn plane_all(_elem: bool) -> bool {
459    unexpanded!()
460}
461
462/// Module containing the expand function for [plane_all()].
463pub mod plane_all {
464
465    use super::*;
466
467    /// Expand method of [plane_all()].
468    pub fn expand(scope: &mut Scope, elem: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
469        let elem: ExpandElement = elem.into();
470        let output = scope.create_local(elem.ty);
471
472        let out = *output;
473        let input = *elem;
474
475        scope.register(Instruction::new(Plane::All(UnaryOperator { input }), out));
476
477        output.into()
478    }
479}
480
481/// Perform a reduce any operation across all units in a plane.
482pub fn plane_any(_elem: bool) -> bool {
483    unexpanded!()
484}
485
486/// Module containing the expand function for [plane_any()].
487pub mod plane_any {
488
489    use super::*;
490
491    /// Expand method of [plane_any()].
492    pub fn expand(scope: &mut Scope, elem: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
493        let elem: ExpandElement = elem.into();
494        let output = scope.create_local(elem.ty);
495
496        let out = *output;
497        let input = *elem;
498
499        scope.register(Instruction::new(Plane::Any(UnaryOperator { input }), out));
500
501        output.into()
502    }
503}
504
505/// Perform a ballot operation across all units in a plane.
506/// Returns a set of 32-bit bitfields as a [`Line`], with each element containing the value from 32
507/// invocations.
508/// Note that line size will always be set to 4 even for `PLANE_DIM <= 64`, because we can't
509/// retrieve the actual plane size at expand time. Use the runtime`PLANE_DIM` to index appropriately.
510pub fn plane_ballot(_elem: bool) -> Line<u32> {
511    unexpanded!()
512}
513
514/// Module containing the expand function for [plane_ballot()].
515pub mod plane_ballot {
516    use cubecl_ir::UIntKind;
517
518    use super::*;
519
520    /// Expand method of [plane_ballot()].
521    pub fn expand(
522        scope: &mut Scope,
523        elem: ExpandElementTyped<bool>,
524    ) -> ExpandElementTyped<Line<u32>> {
525        let elem: ExpandElement = elem.into();
526        let out_item = Type::scalar(ElemType::UInt(UIntKind::U32)).line(4);
527        let output = scope.create_local(out_item);
528
529        let out = *output;
530        let input = *elem;
531
532        scope.register(Instruction::new(
533            Plane::Ballot(UnaryOperator { input }),
534            out,
535        ));
536
537        output.into()
538    }
539}