cubecl_core/frontend/
plane.rs

1use super::{CubeContext, CubePrimitive, ExpandElement};
2use crate::{ir::Operation, prelude::ExpandElementTyped};
3use crate::{
4    ir::{Elem, Instruction, Item, Plane, UnaryOperator},
5    unexpanded,
6};
7
8/// Returns true if the cube unit has the lowest plane_unit_id among active unit in the plane
9pub fn plane_elect() -> bool {
10    unexpanded!()
11}
12
13/// Module containing the expand function for [plane_elect()].
14pub mod plane_elect {
15
16    use super::*;
17
18    /// Expand method of [plane_elect()].
19    pub fn expand(context: &mut CubeContext) -> ExpandElementTyped<bool> {
20        let output = context.create_local(Item::new(Elem::Bool));
21        let out = *output;
22
23        context.register(Instruction::new(Plane::Elect, out));
24
25        output.into()
26    }
27}
28
29/// Broadcasts the value from the specified plane unit at the given index
30/// to all active units within that plane.
31#[allow(unused_variables)]
32pub fn plane_broadcast<E: CubePrimitive>(value: E, index: u32) -> E {
33    unexpanded!()
34}
35
36/// Module containing the expand function for [plane_broadcast()].
37pub mod plane_broadcast {
38
39    use super::*;
40
41    /// Expand method of [plane_broadcast()].
42    pub fn expand<E: CubePrimitive>(
43        context: &mut CubeContext,
44        value: ExpandElementTyped<E>,
45        id: ExpandElementTyped<u32>,
46    ) -> ExpandElementTyped<E> {
47        let output = context.create_local(value.expand.item);
48        let out = *output;
49        let lhs = *value.expand;
50        let rhs = *id.expand;
51
52        context.register(Instruction::new(
53            Plane::Broadcast(crate::ir::BinaryOperator { lhs, rhs }),
54            out,
55        ));
56
57        output.into()
58    }
59}
60
61/// Perform a reduce sum operation across all units in a plane.
62#[allow(unused_variables)]
63pub fn plane_sum<E: CubePrimitive>(value: E) -> E {
64    unexpanded!()
65}
66
67/// Module containing the expand function for [plane_sum()].
68pub mod plane_sum {
69    use super::*;
70
71    /// Expand method of [plane_sum()].
72    pub fn expand<E: CubePrimitive>(
73        context: &mut CubeContext,
74        elem: ExpandElementTyped<E>,
75    ) -> ExpandElementTyped<E> {
76        let elem: ExpandElement = elem.into();
77        let output = context.create_local(elem.item);
78
79        let out = *output;
80        let input = *elem;
81
82        context.register(Instruction::new(Plane::Sum(UnaryOperator { input }), out));
83
84        output.into()
85    }
86}
87
88/// Perform a reduce prod operation across all units in a plane.
89pub fn plane_prod<E: CubePrimitive>(_elem: E) -> E {
90    unexpanded!()
91}
92
93/// Module containing the expand function for [plane_prod()].
94pub mod plane_prod {
95    use super::*;
96
97    /// Expand method of [plane_prod()].
98    pub fn expand<E: CubePrimitive>(
99        context: &mut CubeContext,
100        elem: ExpandElementTyped<E>,
101    ) -> ExpandElementTyped<E> {
102        let elem: ExpandElement = elem.into();
103        let output = context.create_local(elem.item);
104
105        let out = *output;
106        let input = *elem;
107
108        context.register(Instruction::new(Plane::Prod(UnaryOperator { input }), out));
109
110        output.into()
111    }
112}
113
114/// Perform a reduce max operation across all units in a plane.
115pub fn plane_max<E: CubePrimitive>(_elem: E) -> E {
116    unexpanded!()
117}
118
119/// Module containing the expand function for [plane_max()].
120pub mod plane_max {
121    use super::*;
122
123    /// Expand method of [plane_max()].
124    pub fn expand<E: CubePrimitive>(
125        context: &mut CubeContext,
126        elem: ExpandElementTyped<E>,
127    ) -> ExpandElementTyped<E> {
128        let elem: ExpandElement = elem.into();
129        let output = context.create_local(elem.item);
130
131        let out = *output;
132        let input = *elem;
133
134        context.register(Instruction::new(Plane::Max(UnaryOperator { input }), out));
135
136        output.into()
137    }
138}
139
140/// Perform a reduce min operation across all units in a plane.
141pub fn plane_min<E: CubePrimitive>(_elem: E) -> E {
142    unexpanded!()
143}
144
145/// Module containing the expand function for [plane_min()].
146pub mod plane_min {
147    use super::*;
148
149    /// Expand method of [plane_min()].
150    pub fn expand<E: CubePrimitive>(
151        context: &mut CubeContext,
152        elem: ExpandElementTyped<E>,
153    ) -> ExpandElementTyped<E> {
154        let elem: ExpandElement = elem.into();
155        let output = context.create_local(elem.item);
156
157        let out = *output;
158        let input = *elem;
159
160        context.register(Instruction::new(Plane::Min(UnaryOperator { input }), out));
161
162        output.into()
163    }
164}
165
166/// Perform a reduce all operation across all units in a plane.
167pub fn plane_all(_elem: bool) -> bool {
168    unexpanded!()
169}
170
171/// Module containing the expand function for [plane_all()].
172pub mod plane_all {
173
174    use super::*;
175
176    /// Expand method of [plane_all()].
177    pub fn expand(
178        context: &mut CubeContext,
179        elem: ExpandElementTyped<bool>,
180    ) -> ExpandElementTyped<bool> {
181        let elem: ExpandElement = elem.into();
182        let output = context.create_local(elem.item);
183
184        let out = *output;
185        let input = *elem;
186
187        context.register(Instruction::new(Plane::All(UnaryOperator { input }), out));
188
189        output.into()
190    }
191}
192
193/// Perform a reduce any operation across all units in a plane.
194pub fn plane_any(_elem: bool) -> bool {
195    unexpanded!()
196}
197
198/// Module containing the expand function for [plane_any()].
199pub mod plane_any {
200
201    use super::*;
202
203    /// Expand method of [plane_any()].
204    pub fn expand(
205        context: &mut CubeContext,
206        elem: ExpandElementTyped<bool>,
207    ) -> ExpandElementTyped<bool> {
208        let elem: ExpandElement = elem.into();
209        let output = context.create_local(elem.item);
210
211        let out = *output;
212        let input = *elem;
213
214        context.register(Instruction::new(Plane::Any(UnaryOperator { input }), out));
215
216        output.into()
217    }
218}
219
220impl From<Plane> for Operation {
221    fn from(value: Plane) -> Self {
222        Operation::Plane(value)
223    }
224}