Skip to main content

cubecl_core/frontend/
plane.rs

1use cubecl_ir::ManagedVariable;
2
3use super::{CubePrimitive, Vector};
4use crate::prelude::*;
5use crate::{
6    ir::{ElemType, Instruction, Plane, Scope, Type, UnaryOperator},
7    unexpanded,
8};
9
10/// Returns true if the cube unit has the lowest `plane_unit_id` among active unit in the plane
11pub fn plane_elect() -> bool {
12    unexpanded!()
13}
14
15/// Module containing the expand function for [`plane_elect()`].
16pub mod plane_elect {
17
18    use super::*;
19
20    /// Expand method of [`plane_elect()`].
21    pub fn expand(scope: &mut Scope) -> NativeExpand<bool> {
22        let output = scope.create_local(Type::scalar(ElemType::Bool));
23        let out = *output;
24
25        scope.register(Instruction::new(Plane::Elect, out));
26
27        output.into()
28    }
29}
30
31/// Broadcasts the value from the specified plane unit at the given index
32/// to all active units within that plane. Requires a constant index. For non-constant indices,
33/// use [`plane_shuffle()`].
34#[allow(unused_variables)]
35pub fn plane_broadcast<E: CubePrimitive>(value: E, index: u32) -> E {
36    unexpanded!()
37}
38
39/// Module containing the expand function for [`plane_broadcast()`].
40pub mod plane_broadcast {
41
42    use super::*;
43
44    /// Expand method of [`plane_broadcast()`].
45    pub fn expand<E: CubePrimitive>(
46        scope: &mut Scope,
47        value: NativeExpand<E>,
48        id: u32,
49    ) -> NativeExpand<E> {
50        let output = scope.create_local(value.expand.ty);
51        let out = *output;
52        let lhs = *value.expand;
53        let rhs = id.into();
54
55        scope.register(Instruction::new(
56            Plane::Broadcast(crate::ir::BinaryOperator { lhs, rhs }),
57            out,
58        ));
59
60        output.into()
61    }
62}
63
64/// Perform an arbitrary lane shuffle operation across the plane.
65/// Each unit reads the value from the specified source lane.
66///
67/// # Example
68/// `plane_shuffle(value, 0)` - all lanes read from lane 0 (same as broadcast)
69/// `plane_shuffle(value, lane_id ^ 1)` - butterfly pattern (same as `shuffle_xor`)
70#[allow(unused_variables)]
71pub fn plane_shuffle<E: CubePrimitive>(value: E, src_lane: u32) -> E {
72    unexpanded!()
73}
74
75/// Module containing the expand function for [`plane_shuffle()`].
76pub mod plane_shuffle {
77
78    use super::*;
79
80    /// Expand method of [`plane_shuffle()`].
81    pub fn expand<E: CubePrimitive>(
82        scope: &mut Scope,
83        value: NativeExpand<E>,
84        src_lane: NativeExpand<u32>,
85    ) -> NativeExpand<E> {
86        let output = scope.create_local(value.expand.ty);
87        let out = *output;
88        let lhs = *value.expand;
89        let rhs = *src_lane.expand;
90
91        scope.register(Instruction::new(
92            Plane::Shuffle(crate::ir::BinaryOperator { lhs, rhs }),
93            out,
94        ));
95
96        output.into()
97    }
98}
99
100/// Perform a shuffle XOR operation across the plane.
101/// Each unit exchanges its value with another unit at an index determined by XOR with the mask.
102/// This is useful for butterfly reduction patterns.
103///
104/// # Example
105/// For a 32-lane warp with mask=1:
106/// - Lane 0 gets value from lane 1, lane 1 gets value from lane 0
107/// - Lane 2 gets value from lane 3, lane 3 gets value from lane 2
108/// - etc.
109#[allow(unused_variables)]
110pub fn plane_shuffle_xor<E: CubePrimitive>(value: E, mask: u32) -> E {
111    unexpanded!()
112}
113
114/// Module containing the expand function for [`plane_shuffle_xor()`].
115pub mod plane_shuffle_xor {
116
117    use super::*;
118
119    /// Expand method of [`plane_shuffle_xor()`].
120    pub fn expand<E: CubePrimitive>(
121        scope: &mut Scope,
122        value: NativeExpand<E>,
123        mask: NativeExpand<u32>,
124    ) -> NativeExpand<E> {
125        let output = scope.create_local(value.expand.ty);
126        let out = *output;
127        let lhs = *value.expand;
128        let rhs = *mask.expand;
129
130        scope.register(Instruction::new(
131            Plane::ShuffleXor(crate::ir::BinaryOperator { lhs, rhs }),
132            out,
133        ));
134
135        output.into()
136    }
137}
138
139/// Perform a shuffle up operation across the plane.
140/// Each unit reads the value from a unit with a lower lane ID (`current_id` - delta).
141/// Units with `lane_id` < delta will read from themselves (no change).
142///
143/// # Example
144/// For delta=1: `[a, b, c, d] -> [a, a, b, c]`
145#[allow(unused_variables)]
146pub fn plane_shuffle_up<E: CubePrimitive>(value: E, delta: u32) -> E {
147    unexpanded!()
148}
149
150/// Module containing the expand function for [`plane_shuffle_up()`].
151pub mod plane_shuffle_up {
152
153    use super::*;
154
155    /// Expand method of [`plane_shuffle_up()`].
156    pub fn expand<E: CubePrimitive>(
157        scope: &mut Scope,
158        value: NativeExpand<E>,
159        delta: NativeExpand<u32>,
160    ) -> NativeExpand<E> {
161        let output = scope.create_local(value.expand.ty);
162        let out = *output;
163        let lhs = *value.expand;
164        let rhs = *delta.expand;
165
166        scope.register(Instruction::new(
167            Plane::ShuffleUp(crate::ir::BinaryOperator { lhs, rhs }),
168            out,
169        ));
170
171        output.into()
172    }
173}
174
175/// Perform a shuffle down operation across the plane.
176/// Each unit reads the value from a unit with a higher lane ID (`current_id` + delta).
177/// Units at the end will read from themselves if (`lane_id` + delta >= `plane_dim`).
178///
179/// # Example
180/// For delta=1: `[a, b, c, d] -> [b, c, d, d]`
181#[allow(unused_variables)]
182pub fn plane_shuffle_down<E: CubePrimitive>(value: E, delta: u32) -> E {
183    unexpanded!()
184}
185
186/// Module containing the expand function for [`plane_shuffle_down()`].
187pub mod plane_shuffle_down {
188
189    use super::*;
190
191    /// Expand method of [`plane_shuffle_down()`].
192    pub fn expand<E: CubePrimitive>(
193        scope: &mut Scope,
194        value: NativeExpand<E>,
195        delta: NativeExpand<u32>,
196    ) -> NativeExpand<E> {
197        let output = scope.create_local(value.expand.ty);
198        let out = *output;
199        let lhs = *value.expand;
200        let rhs = *delta.expand;
201
202        scope.register(Instruction::new(
203            Plane::ShuffleDown(crate::ir::BinaryOperator { lhs, rhs }),
204            out,
205        ));
206
207        output.into()
208    }
209}
210
211/// Perform a reduce sum operation across all units in a plane.
212#[allow(unused_variables)]
213pub fn plane_sum<E: CubePrimitive>(value: E) -> E {
214    unexpanded!()
215}
216
217/// Module containing the expand function for [`plane_sum()`].
218pub mod plane_sum {
219    use super::*;
220
221    /// Expand method of [`plane_sum()`].
222    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
223        let elem: ManagedVariable = elem.into();
224        let output = scope.create_local(elem.ty);
225
226        let out = *output;
227        let input = *elem;
228
229        scope.register(Instruction::new(Plane::Sum(UnaryOperator { input }), out));
230
231        output.into()
232    }
233}
234
235/// Perform an inclusive sum operation across all units in a plane.
236/// This sums all values to the "left" of the unit, including this unit's value.
237/// Also known as "prefix sum" or "inclusive scan".
238///
239/// # Example
240/// `inclusive_sum([1, 2, 3, 4, 5]) == [1, 3, 6, 10, 15]`
241#[allow(unused_variables)]
242pub fn plane_inclusive_sum<E: CubePrimitive>(value: E) -> E {
243    unexpanded!()
244}
245
246/// Module containing the expand function for [`plane_inclusive_sum()`].
247pub mod plane_inclusive_sum {
248    use super::*;
249
250    /// Expand method of [`plane_inclusive_sum()`].
251    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
252        let elem: ManagedVariable = elem.into();
253        let output = scope.create_local(elem.ty);
254
255        let out = *output;
256        let input = *elem;
257
258        scope.register(Instruction::new(
259            Plane::InclusiveSum(UnaryOperator { input }),
260            out,
261        ));
262
263        output.into()
264    }
265}
266
267/// Perform an exclusive sum operation across all units in a plane.
268/// This sums all values to the "left" of the unit, excluding this unit's value. The 0th unit will
269/// be set to `E::zero()`.
270/// Also known as "exclusive prefix sum" or "exclusive scan".
271///
272/// # Example
273/// `exclusive_sum([1, 2, 3, 4, 5]) == [0, 1, 3, 6, 10]`
274#[allow(unused_variables)]
275pub fn plane_exclusive_sum<E: CubePrimitive>(value: E) -> E {
276    unexpanded!()
277}
278
279/// Module containing the expand function for [`plane_exclusive_sum()`].
280pub mod plane_exclusive_sum {
281    use super::*;
282
283    /// Expand method of [`plane_exclusive_sum()`].
284    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
285        let elem: ManagedVariable = elem.into();
286        let output = scope.create_local(elem.ty);
287
288        let out = *output;
289        let input = *elem;
290
291        scope.register(Instruction::new(
292            Plane::ExclusiveSum(UnaryOperator { input }),
293            out,
294        ));
295
296        output.into()
297    }
298}
299
300/// Perform a reduce prod operation across all units in a plane.
301pub fn plane_prod<E: CubePrimitive>(_elem: E) -> E {
302    unexpanded!()
303}
304
305/// Module containing the expand function for [`plane_prod()`].
306pub mod plane_prod {
307    use super::*;
308
309    /// Expand method of [`plane_prod()`].
310    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
311        let elem: ManagedVariable = elem.into();
312        let output = scope.create_local(elem.ty);
313
314        let out = *output;
315        let input = *elem;
316
317        scope.register(Instruction::new(Plane::Prod(UnaryOperator { input }), out));
318
319        output.into()
320    }
321}
322
323/// Perform an inclusive product operation across all units in a plane.
324/// This multiplies all values to the "left" of the unit, including this unit's value.
325/// Also known as "prefix product" or "inclusive scan".
326///
327/// # Example
328/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 2, 6, 24, 120]`
329#[allow(unused_variables)]
330pub fn plane_inclusive_prod<E: CubePrimitive>(value: E) -> E {
331    unexpanded!()
332}
333
334/// Module containing the expand function for [`plane_inclusive_prod()`].
335pub mod plane_inclusive_prod {
336    use super::*;
337
338    /// Expand method of [`plane_inclusive_prod()`].
339    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
340        let elem: ManagedVariable = elem.into();
341        let output = scope.create_local(elem.ty);
342
343        let out = *output;
344        let input = *elem;
345
346        scope.register(Instruction::new(
347            Plane::InclusiveProd(UnaryOperator { input }),
348            out,
349        ));
350
351        output.into()
352    }
353}
354
355/// Perform an exclusive product operation across all units in a plane.
356/// This multiplies all values to the "left" of the unit, excluding this unit's value. The 0th unit
357/// will be set to `E::one()`.
358/// Also known as "exclusive prefix product" or "exclusive scan".
359///
360/// # Example
361/// `exclusive_prod([1, 2, 3, 4, 5]) == [1, 1, 2, 6, 24]`
362#[allow(unused_variables)]
363pub fn plane_exclusive_prod<E: CubePrimitive>(value: E) -> E {
364    unexpanded!()
365}
366
367/// Module containing the expand function for [`plane_exclusive_prod()`].
368pub mod plane_exclusive_prod {
369    use super::*;
370
371    /// Expand method of [`plane_exclusive_prod()`].
372    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
373        let elem: ManagedVariable = elem.into();
374        let output = scope.create_local(elem.ty);
375
376        let out = *output;
377        let input = *elem;
378
379        scope.register(Instruction::new(
380            Plane::ExclusiveProd(UnaryOperator { input }),
381            out,
382        ));
383
384        output.into()
385    }
386}
387
388/// Perform a reduce max operation across all units in a plane.
389pub fn plane_max<E: CubePrimitive>(_elem: E) -> E {
390    unexpanded!()
391}
392
393/// Module containing the expand function for [`plane_max()`].
394pub mod plane_max {
395    use super::*;
396
397    /// Expand method of [`plane_max()`].
398    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
399        let elem: ManagedVariable = elem.into();
400        let output = scope.create_local(elem.ty);
401
402        let out = *output;
403        let input = *elem;
404
405        scope.register(Instruction::new(Plane::Max(UnaryOperator { input }), out));
406
407        output.into()
408    }
409}
410
411/// Perform a reduce min operation across all units in a plane.
412pub fn plane_min<E: CubePrimitive>(_elem: E) -> E {
413    unexpanded!()
414}
415
416/// Module containing the expand function for [`plane_min()`].
417pub mod plane_min {
418    use super::*;
419
420    /// Expand method of [`plane_min()`].
421    pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: NativeExpand<E>) -> NativeExpand<E> {
422        let elem: ManagedVariable = elem.into();
423        let output = scope.create_local(elem.ty);
424
425        let out = *output;
426        let input = *elem;
427
428        scope.register(Instruction::new(Plane::Min(UnaryOperator { input }), out));
429
430        output.into()
431    }
432}
433
434/// Perform a reduce all operation across all units in a plane.
435pub fn plane_all(_elem: bool) -> bool {
436    unexpanded!()
437}
438
439/// Module containing the expand function for [`plane_all()`].
440pub mod plane_all {
441
442    use super::*;
443
444    /// Expand method of [`plane_all()`].
445    pub fn expand(scope: &mut Scope, elem: NativeExpand<bool>) -> NativeExpand<bool> {
446        let elem: ManagedVariable = elem.into();
447        let output = scope.create_local(elem.ty);
448
449        let out = *output;
450        let input = *elem;
451
452        scope.register(Instruction::new(Plane::All(UnaryOperator { input }), out));
453
454        output.into()
455    }
456}
457
458/// Perform a reduce any operation across all units in a plane.
459pub fn plane_any(_elem: bool) -> bool {
460    unexpanded!()
461}
462
463/// Module containing the expand function for [`plane_any()`].
464pub mod plane_any {
465
466    use super::*;
467
468    /// Expand method of [`plane_any()`].
469    pub fn expand(scope: &mut Scope, elem: NativeExpand<bool>) -> NativeExpand<bool> {
470        let elem: ManagedVariable = elem.into();
471        let output = scope.create_local(elem.ty);
472
473        let out = *output;
474        let input = *elem;
475
476        scope.register(Instruction::new(Plane::Any(UnaryOperator { input }), out));
477
478        output.into()
479    }
480}
481
482/// Perform a ballot operation across all units in a plane.
483/// Returns a set of 32-bit bitfields as a [`Vector`], with each element containing the value from 32
484/// invocations.
485/// Note that vector size will always be set to 4 even for `PLANE_DIM <= 64`, because we can't
486/// retrieve the actual plane size at expand time. Use the runtime`PLANE_DIM` to index appropriately.
487pub fn plane_ballot(_elem: bool) -> Vector<u32, Const<4>> {
488    unexpanded!()
489}
490
491/// Module containing the expand function for [`plane_ballot()`].
492pub mod plane_ballot {
493    use cubecl_ir::UIntKind;
494
495    use super::*;
496
497    /// Expand method of [`plane_ballot()`].
498    pub fn expand(
499        scope: &mut Scope,
500        elem: NativeExpand<bool>,
501    ) -> NativeExpand<Vector<u32, Const<4>>> {
502        let elem: ManagedVariable = elem.into();
503        let out_item = Type::scalar(ElemType::UInt(UIntKind::U32)).with_vector_size(4);
504        let output = scope.create_local(out_item);
505
506        let out = *output;
507        let input = *elem;
508
509        scope.register(Instruction::new(
510            Plane::Ballot(UnaryOperator { input }),
511            out,
512        ));
513
514        output.into()
515    }
516}