cubecl_core/frontend/
plane.rs

1use cubecl_ir::ExpandElement;
2
3use super::{CubePrimitive, Line};
4use crate::prelude::ExpandElementTyped;
5use crate::{
6    ir::{ElemType, Instruction, Plane, Scope, Type, UnaryOperator},
7    unexpanded,
8};
9
10/// Returns true if the cube unit has the lowest plane_unit_id among active unit in the plane
11pub fn plane_elect() -> bool {
12    unexpanded!()
13}
14
15/// Module containing the expand function for [plane_elect()].
16pub mod plane_elect {
17
18    use super::*;
19
20    /// Expand method of [plane_elect()].
21    pub fn expand(scope: &mut Scope) -> ExpandElementTyped<bool> {
22        let output = scope.create_local(Type::scalar(ElemType::Bool));
23        let out = *output;
24
25        scope.register(Instruction::new(Plane::Elect, out));
26
27        output.into()
28    }
29}
30
31/// Broadcasts the value from the specified plane unit at the given index
32/// to all active units within that plane. Requires a constant index. For non-constant indices,
33/// use [`plane_shuffle`].
34#[allow(unused_variables)]
35pub fn plane_broadcast<E: CubePrimitive>(value: E, index: u32) -> E {
36    unexpanded!()
37}
38
39/// Module containing the expand function for [plane_broadcast()].
40pub mod plane_broadcast {
41
42    use super::*;
43
44    /// Expand method of [plane_broadcast()].
45    pub fn expand<E: CubePrimitive>(
46        scope: &mut Scope,
47        value: ExpandElementTyped<E>,
48        id: u32,
49    ) -> ExpandElementTyped<E> {
50        let output = scope.create_local(value.expand.ty);
51        let out = *output;
52        let lhs = *value.expand;
53        let rhs = id.into();
54
55        scope.register(Instruction::new(
56            Plane::Broadcast(crate::ir::BinaryOperator { lhs, rhs }),
57            out,
58        ));
59
60        output.into()
61    }
62}
63
64/// Perform an arbitrary lane shuffle operation across the plane.
65/// Each unit reads the value from the specified source lane.
66///
67/// # Example
68/// `plane_shuffle(value, 0)` - all lanes read from lane 0 (same as broadcast)
69/// `plane_shuffle(value, lane_id ^ 1)` - butterfly pattern (same as shuffle_xor)
70#[allow(unused_variables)]
71pub fn plane_shuffle<E: CubePrimitive>(value: E, src_lane: u32) -> E {
72    unexpanded!()
73}
74
75/// Module containing the expand function for [plane_shuffle()].
76pub mod plane_shuffle {
77
78    use super::*;
79
80    /// Expand method of [plane_shuffle()].
81    pub fn expand<E: CubePrimitive>(
82        scope: &mut Scope,
83        value: ExpandElementTyped<E>,
84        src_lane: ExpandElementTyped<u32>,
85    ) -> ExpandElementTyped<E> {
86        let output = scope.create_local(value.expand.ty);
87        let out = *output;
88        let lhs = *value.expand;
89        let rhs = *src_lane.expand;
90
91        scope.register(Instruction::new(
92            Plane::Shuffle(crate::ir::BinaryOperator { lhs, rhs }),
93            out,
94        ));
95
96        output.into()
97    }
98}
99
100/// Perform a shuffle XOR operation across the plane.
101/// Each unit exchanges its value with another unit at an index determined by XOR with the mask.
102/// This is useful for butterfly reduction patterns.
103///
104/// # Example
105/// For a 32-lane warp with mask=1:
106/// - Lane 0 gets value from lane 1, lane 1 gets value from lane 0
107/// - Lane 2 gets value from lane 3, lane 3 gets value from lane 2
108/// - etc.
109#[allow(unused_variables)]
110pub fn plane_shuffle_xor<E: CubePrimitive>(value: E, mask: u32) -> E {
111    unexpanded!()
112}
113
114/// Module containing the expand function for [plane_shuffle_xor()].
115pub mod plane_shuffle_xor {
116
117    use super::*;
118
119    /// Expand method of [plane_shuffle_xor()].
120    pub fn expand<E: CubePrimitive>(
121        scope: &mut Scope,
122        value: ExpandElementTyped<E>,
123        mask: ExpandElementTyped<u32>,
124    ) -> ExpandElementTyped<E> {
125        let output = scope.create_local(value.expand.ty);
126        let out = *output;
127        let lhs = *value.expand;
128        let rhs = *mask.expand;
129
130        scope.register(Instruction::new(
131            Plane::ShuffleXor(crate::ir::BinaryOperator { lhs, rhs }),
132            out,
133        ));
134
135        output.into()
136    }
137}
138
139/// Perform a shuffle up operation across the plane.
140/// Each unit reads the value from a unit with a lower lane ID (current_id - delta).
141/// Units with lane_id < delta will read from themselves (no change).
142///
143/// # Example
144/// For delta=1: `[a, b, c, d] -> [a, a, b, c]`
145#[allow(unused_variables)]
146pub fn plane_shuffle_up<E: CubePrimitive>(value: E, delta: u32) -> E {
147    unexpanded!()
148}
149
150/// Module containing the expand function for [plane_shuffle_up()].
151pub mod plane_shuffle_up {
152
153    use super::*;
154
155    /// Expand method of [plane_shuffle_up()].
156    pub fn expand<E: CubePrimitive>(
157        scope: &mut Scope,
158        value: ExpandElementTyped<E>,
159        delta: ExpandElementTyped<u32>,
160    ) -> ExpandElementTyped<E> {
161        let output = scope.create_local(value.expand.ty);
162        let out = *output;
163        let lhs = *value.expand;
164        let rhs = *delta.expand;
165
166        scope.register(Instruction::new(
167            Plane::ShuffleUp(crate::ir::BinaryOperator { lhs, rhs }),
168            out,
169        ));
170
171        output.into()
172    }
173}
174
175/// Perform a shuffle down operation across the plane.
176/// Each unit reads the value from a unit with a higher lane ID (current_id + delta).
177/// Units at the end will read from themselves if (lane_id + delta >= plane_dim).
178///
179/// # Example
180/// For delta=1: `[a, b, c, d] -> [b, c, d, d]`
181#[allow(unused_variables)]
182pub fn plane_shuffle_down<E: CubePrimitive>(value: E, delta: u32) -> E {
183    unexpanded!()
184}
185
186/// Module containing the expand function for [plane_shuffle_down()].
187pub mod plane_shuffle_down {
188
189    use super::*;
190
191    /// Expand method of [plane_shuffle_down()].
192    pub fn expand<E: CubePrimitive>(
193        scope: &mut Scope,
194        value: ExpandElementTyped<E>,
195        delta: ExpandElementTyped<u32>,
196    ) -> ExpandElementTyped<E> {
197        let output = scope.create_local(value.expand.ty);
198        let out = *output;
199        let lhs = *value.expand;
200        let rhs = *delta.expand;
201
202        scope.register(Instruction::new(
203            Plane::ShuffleDown(crate::ir::BinaryOperator { lhs, rhs }),
204            out,
205        ));
206
207        output.into()
208    }
209}
210
211/// Perform a reduce sum operation across all units in a plane.
212#[allow(unused_variables)]
213pub fn plane_sum<E: CubePrimitive>(value: E) -> E {
214    unexpanded!()
215}
216
217/// Module containing the expand function for [plane_sum()].
218pub mod plane_sum {
219    use super::*;
220
221    /// Expand method of [plane_sum()].
222    pub fn expand<E: CubePrimitive>(
223        scope: &mut Scope,
224        elem: ExpandElementTyped<E>,
225    ) -> ExpandElementTyped<E> {
226        let elem: ExpandElement = elem.into();
227        let output = scope.create_local(elem.ty);
228
229        let out = *output;
230        let input = *elem;
231
232        scope.register(Instruction::new(Plane::Sum(UnaryOperator { input }), out));
233
234        output.into()
235    }
236}
237
238/// Perform an inclusive sum operation across all units in a plane.
239/// This sums all values to the "left" of the unit, including this unit's value.
240/// Also known as "prefix sum" or "inclusive scan".
241///
242/// # Example
243/// `inclusive_sum([1, 2, 3, 4, 5]) == [1, 3, 6, 10, 15]`
244#[allow(unused_variables)]
245pub fn plane_inclusive_sum<E: CubePrimitive>(value: E) -> E {
246    unexpanded!()
247}
248
249/// Module containing the expand function for [plane_inclusive_sum()].
250pub mod plane_inclusive_sum {
251    use super::*;
252
253    /// Expand method of [plane_inclusive_sum()].
254    pub fn expand<E: CubePrimitive>(
255        scope: &mut Scope,
256        elem: ExpandElementTyped<E>,
257    ) -> ExpandElementTyped<E> {
258        let elem: ExpandElement = elem.into();
259        let output = scope.create_local(elem.ty);
260
261        let out = *output;
262        let input = *elem;
263
264        scope.register(Instruction::new(
265            Plane::InclusiveSum(UnaryOperator { input }),
266            out,
267        ));
268
269        output.into()
270    }
271}
272
273/// Perform an exclusive sum operation across all units in a plane.
274/// This sums all values to the "left" of the unit, excluding this unit's value. The 0th unit will
275/// be set to `E::zero()`.
276/// Also known as "exclusive prefix sum" or "exclusive scan".
277///
278/// # Example
279/// `exclusive_sum([1, 2, 3, 4, 5]) == [0, 1, 3, 6, 10]`
280#[allow(unused_variables)]
281pub fn plane_exclusive_sum<E: CubePrimitive>(value: E) -> E {
282    unexpanded!()
283}
284
285/// Module containing the expand function for [plane_exclusive_sum()].
286pub mod plane_exclusive_sum {
287    use super::*;
288
289    /// Expand method of [plane_exclusive_sum()].
290    pub fn expand<E: CubePrimitive>(
291        scope: &mut Scope,
292        elem: ExpandElementTyped<E>,
293    ) -> ExpandElementTyped<E> {
294        let elem: ExpandElement = elem.into();
295        let output = scope.create_local(elem.ty);
296
297        let out = *output;
298        let input = *elem;
299
300        scope.register(Instruction::new(
301            Plane::ExclusiveSum(UnaryOperator { input }),
302            out,
303        ));
304
305        output.into()
306    }
307}
308
309/// Perform a reduce prod operation across all units in a plane.
310pub fn plane_prod<E: CubePrimitive>(_elem: E) -> E {
311    unexpanded!()
312}
313
314/// Module containing the expand function for [plane_prod()].
315pub mod plane_prod {
316    use super::*;
317
318    /// Expand method of [plane_prod()].
319    pub fn expand<E: CubePrimitive>(
320        scope: &mut Scope,
321        elem: ExpandElementTyped<E>,
322    ) -> ExpandElementTyped<E> {
323        let elem: ExpandElement = elem.into();
324        let output = scope.create_local(elem.ty);
325
326        let out = *output;
327        let input = *elem;
328
329        scope.register(Instruction::new(Plane::Prod(UnaryOperator { input }), out));
330
331        output.into()
332    }
333}
334
335/// Perform an inclusive product operation across all units in a plane.
336/// This multiplies all values to the "left" of the unit, including this unit's value.
337/// Also known as "prefix product" or "inclusive scan".
338///
339/// # Example
340/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 2, 6, 24, 120]`
341#[allow(unused_variables)]
342pub fn plane_inclusive_prod<E: CubePrimitive>(value: E) -> E {
343    unexpanded!()
344}
345
346/// Module containing the expand function for [plane_inclusive_prod()].
347pub mod plane_inclusive_prod {
348    use super::*;
349
350    /// Expand method of [plane_inclusive_prod()].
351    pub fn expand<E: CubePrimitive>(
352        scope: &mut Scope,
353        elem: ExpandElementTyped<E>,
354    ) -> ExpandElementTyped<E> {
355        let elem: ExpandElement = elem.into();
356        let output = scope.create_local(elem.ty);
357
358        let out = *output;
359        let input = *elem;
360
361        scope.register(Instruction::new(
362            Plane::InclusiveProd(UnaryOperator { input }),
363            out,
364        ));
365
366        output.into()
367    }
368}
369
370/// Perform an exclusive product operation across all units in a plane.
371/// This multiplies all values to the "left" of the unit, excluding this unit's value. The 0th unit
372/// will be set to `E::one()`.
373/// Also known as "exclusive prefix product" or "exclusive scan".
374///
375/// # Example
376/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 1, 2, 6, 24]`
377#[allow(unused_variables)]
378pub fn plane_exclusive_prod<E: CubePrimitive>(value: E) -> E {
379    unexpanded!()
380}
381
382/// Module containing the expand function for [plane_exclusive_prod()].
383pub mod plane_exclusive_prod {
384    use super::*;
385
386    /// Expand method of [plane_exclusive_prod()].
387    pub fn expand<E: CubePrimitive>(
388        scope: &mut Scope,
389        elem: ExpandElementTyped<E>,
390    ) -> ExpandElementTyped<E> {
391        let elem: ExpandElement = elem.into();
392        let output = scope.create_local(elem.ty);
393
394        let out = *output;
395        let input = *elem;
396
397        scope.register(Instruction::new(
398            Plane::ExclusiveProd(UnaryOperator { input }),
399            out,
400        ));
401
402        output.into()
403    }
404}
405
406/// Perform a reduce max operation across all units in a plane.
407pub fn plane_max<E: CubePrimitive>(_elem: E) -> E {
408    unexpanded!()
409}
410
411/// Module containing the expand function for [plane_max()].
412pub mod plane_max {
413    use super::*;
414
415    /// Expand method of [plane_max()].
416    pub fn expand<E: CubePrimitive>(
417        scope: &mut Scope,
418        elem: ExpandElementTyped<E>,
419    ) -> ExpandElementTyped<E> {
420        let elem: ExpandElement = elem.into();
421        let output = scope.create_local(elem.ty);
422
423        let out = *output;
424        let input = *elem;
425
426        scope.register(Instruction::new(Plane::Max(UnaryOperator { input }), out));
427
428        output.into()
429    }
430}
431
432/// Perform a reduce min operation across all units in a plane.
433pub fn plane_min<E: CubePrimitive>(_elem: E) -> E {
434    unexpanded!()
435}
436
437/// Module containing the expand function for [plane_min()].
438pub mod plane_min {
439    use super::*;
440
441    /// Expand method of [plane_min()].
442    pub fn expand<E: CubePrimitive>(
443        scope: &mut Scope,
444        elem: ExpandElementTyped<E>,
445    ) -> ExpandElementTyped<E> {
446        let elem: ExpandElement = elem.into();
447        let output = scope.create_local(elem.ty);
448
449        let out = *output;
450        let input = *elem;
451
452        scope.register(Instruction::new(Plane::Min(UnaryOperator { input }), out));
453
454        output.into()
455    }
456}
457
458/// Perform a reduce all operation across all units in a plane.
459pub fn plane_all(_elem: bool) -> bool {
460    unexpanded!()
461}
462
463/// Module containing the expand function for [plane_all()].
464pub mod plane_all {
465
466    use super::*;
467
468    /// Expand method of [plane_all()].
469    pub fn expand(scope: &mut Scope, elem: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
470        let elem: ExpandElement = elem.into();
471        let output = scope.create_local(elem.ty);
472
473        let out = *output;
474        let input = *elem;
475
476        scope.register(Instruction::new(Plane::All(UnaryOperator { input }), out));
477
478        output.into()
479    }
480}
481
482/// Perform a reduce any operation across all units in a plane.
483pub fn plane_any(_elem: bool) -> bool {
484    unexpanded!()
485}
486
487/// Module containing the expand function for [plane_any()].
488pub mod plane_any {
489
490    use super::*;
491
492    /// Expand method of [plane_any()].
493    pub fn expand(scope: &mut Scope, elem: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
494        let elem: ExpandElement = elem.into();
495        let output = scope.create_local(elem.ty);
496
497        let out = *output;
498        let input = *elem;
499
500        scope.register(Instruction::new(Plane::Any(UnaryOperator { input }), out));
501
502        output.into()
503    }
504}
505
506/// Perform a ballot operation across all units in a plane.
507/// Returns a set of 32-bit bitfields as a [`Line`], with each element containing the value from 32
508/// invocations.
509/// Note that line size will always be set to 4 even for `PLANE_DIM <= 64`, because we can't
510/// retrieve the actual plane size at expand time. Use the runtime`PLANE_DIM` to index appropriately.
511pub fn plane_ballot(_elem: bool) -> Line<u32> {
512    unexpanded!()
513}
514
515/// Module containing the expand function for [plane_ballot()].
516pub mod plane_ballot {
517    use cubecl_ir::UIntKind;
518
519    use super::*;
520
521    /// Expand method of [plane_ballot()].
522    pub fn expand(
523        scope: &mut Scope,
524        elem: ExpandElementTyped<bool>,
525    ) -> ExpandElementTyped<Line<u32>> {
526        let elem: ExpandElement = elem.into();
527        let out_item = Type::scalar(ElemType::UInt(UIntKind::U32)).line(4);
528        let output = scope.create_local(out_item);
529
530        let out = *output;
531        let input = *elem;
532
533        scope.register(Instruction::new(
534            Plane::Ballot(UnaryOperator { input }),
535            out,
536        ));
537
538        output.into()
539    }
540}