cubecl_core/frontend/container/line/base.rs

1use std::num::NonZero;
2
3use crate::{
4    frontend::{CubePrimitive, CubeType, ExpandElementBaseInit, ExpandElementTyped},
5    prelude::MulHi,
6};
7use crate::{
8    ir::{Arithmetic, BinaryOperator, Elem, Instruction, Item, Scope},
9    prelude::{Dot, Numeric, binary_expand_fixed_output},
10    unexpanded,
11};
12use cubecl_ir::{Comparison, ExpandElement};
13use derive_more::derive::Neg;
/// A contiguous list of elements that supports auto-vectorized operations.
///
/// Outside of kernel expansion a `Line` only stores a single value of type `P`; the number of
/// elements (the vectorization factor) is tracked on its expanded form,
/// `ExpandElementTyped<Line<P>>`.
#[derive(Neg)]
pub struct Line<P> {
    // Comptime lines only support 1 element, so a single value is stored.
    pub(crate) val: P,
}
21
22impl<P: CubePrimitive> Clone for Line<P> {
23    fn clone(&self) -> Self {
24        *self
25    }
26}
27impl<P: CubePrimitive> Eq for Line<P> {}
28impl<P: CubePrimitive> Copy for Line<P> {}
29
30/// Module that contains the implementation details of the new function.
31mod new {
32    use super::*;
33
34    impl<P: CubePrimitive> Line<P> {
35        /// Create a new line of size 1 using the given value.
36        pub fn new(val: P) -> Self {
37            Self { val }
38        }
39
40        /// Expand function of [Self::new].
41        pub fn __expand_new(_scope: &mut Scope, val: P::ExpandType) -> ExpandElementTyped<Self> {
42            let elem: ExpandElementTyped<P> = val;
43            elem.expand.into()
44        }
45    }
46}
47
48/// Module that contains the implementation details of the fill function.
49mod fill {
50    use crate::prelude::cast;
51
52    use super::*;
53
54    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
55        /// Fill the line with the given value.
56        ///
57        /// If you want to fill the line with different values, consider using the index API
58        /// instead.
59        ///
60        /// ```rust, ignore
61        /// let mut line = Line::<u32>::empty(2);
62        /// line[0] = 1;
63        /// line[1] = 2;
64        /// ```
65        #[allow(unused_variables)]
66        pub fn fill(mut self, value: P) -> Self {
67            self.val = value;
68            self
69        }
70
71        /// Expand function of [fill](Self::fill).
72        pub fn __expand_fill(
73            scope: &mut Scope,
74            line: ExpandElementTyped<Self>,
75            value: ExpandElementTyped<P>,
76        ) -> ExpandElementTyped<Self> {
77            line.__expand_fill_method(scope, value)
78        }
79    }
80
81    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
82        /// Expand method of [fill](Line::fill).
83        pub fn __expand_fill_method(self, scope: &mut Scope, value: ExpandElementTyped<P>) -> Self {
84            let length = self.expand.item.vectorization;
85            let output = scope.create_local(Item::vectorized(P::as_elem(scope), length));
86
87            cast::expand::<P>(scope, value, output.clone().into());
88
89            output.into()
90        }
91    }
92}
93
94/// Module that contains the implementation details of the empty function.
95mod empty {
96    use super::*;
97
98    impl<P: CubePrimitive> Line<P> {
99        /// Create an empty line of the given size.
100        ///
101        /// Note that a line can't change in size once it's fixed.
102        #[allow(unused_variables)]
103        pub fn empty(size: u32) -> Self {
104            unexpanded!()
105        }
106
107        /// Expand function of [empty](Self::empty).
108        pub fn __expand_empty(scope: &mut Scope, length: u32) -> ExpandElementTyped<Self> {
109            let length = NonZero::new(length as u8);
110            scope
111                .create_local_mut(Item::vectorized(Self::as_elem(scope), length))
112                .into()
113        }
114    }
115}
116
117/// Module that contains the implementation details of the size function.
118mod size {
119    use super::*;
120
121    impl<P: CubePrimitive> Line<P> {
122        /// Get the number of individual elements a line contains.
123        ///
124        /// The size is available at comptime and may be used in combination with the comptime
125        /// macro.
126        ///
127        /// ```rust, ignore
128        /// // The if statement is going to be executed at comptime.
129        /// if comptime!(line.size() == 1) {
130        /// }
131        /// ```
132        pub fn size(&self) -> u32 {
133            unexpanded!()
134        }
135
136        /// Expand function of [size](Self::size).
137        pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> u32 {
138            element.__expand_vectorization_factor_method(scope)
139        }
140    }
141
142    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
143        /// Comptime version of [size](Line::size).
144        pub fn size(&self) -> u32 {
145            self.expand
146                .item
147                .vectorization
148                .unwrap_or(NonZero::new(1).unwrap())
149                .get() as u32
150        }
151
152        /// Expand method of [size](Line::size).
153        pub fn __expand_size_method(&self, _scope: &mut Scope) -> u32 {
154            self.size()
155        }
156    }
157}
158
// Implements an element-wise comparison between two lines of the same type, producing a
// `Line<bool>`.
//
// * `$name`     - name of the generated method and of the module holding its implementation.
// * `$operator` - `Comparison` variant emitted in the IR.
// * `$comment`  - description of the comparison inserted in the generated documentation.
//
// Metavariables are not substituted inside `///` comments, so every doc string mentioning
// `$name` is built with `concat!`/`stringify!`.
macro_rules! impl_line_comparison {
    ($name:ident, $operator:ident, $comment:literal) => {
        ::paste::paste! {
            #[doc = concat!(
                "Module that contains the implementation details of the ",
                stringify!($name),
                " function."
            )]
            mod $name {

                use super::*;

                impl<P: CubePrimitive> Line<P> {
                    #[doc = concat!(
                        "Return a new line with the element-wise comparison of the first line being ",
                        $comment,
                        " the second line."
                    )]
                    pub fn $name(self, _other: Self) -> Line<bool> {
                        unexpanded!()
                    }

                    #[doc = concat!(
                        "Expand function of [",
                        stringify!($name),
                        "](Self::",
                        stringify!($name),
                        ")."
                    )]
                    pub fn [< __expand_ $name >](
                        scope: &mut Scope,
                        lhs: ExpandElementTyped<Self>,
                        rhs: ExpandElementTyped<Self>,
                    ) -> ExpandElementTyped<Line<bool>> {
                        lhs.[< __expand_ $name _method >](scope, rhs)
                    }
                }

                impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
                    #[doc = concat!(
                        "Expand method of [",
                        stringify!($name),
                        "](Line::",
                        stringify!($name),
                        ")."
                    )]
                    pub fn [< __expand_ $name _method >](
                        self,
                        scope: &mut Scope,
                        rhs: Self,
                    ) -> ExpandElementTyped<Line<bool>> {
                        // The output keeps the vectorization of `self`, with a `bool` element.
                        let size = self.expand.item.vectorization;
                        let lhs = self.expand.into();
                        let rhs = rhs.expand.into();

                        let output = scope.create_local_mut(Item::vectorized(bool::as_elem(scope), size));

                        scope.register(Instruction::new(
                            Comparison::$operator(BinaryOperator { lhs, rhs }),
                            output.clone().into(),
                        ));

                        output.into()
                    }
                }
            }
        }
    };
}
214
// Element-wise comparison methods on `Line`. Each invocation generates the `Line::<name>`
// method, its `__expand_<name>` function and the `__expand_<name>_method` on the expanded
// line, which registers the given `Comparison` variant.
impl_line_comparison!(equal, Equal, "equal to");
impl_line_comparison!(not_equal, NotEqual, "not equal to");
impl_line_comparison!(less_than, Lower, "less than");
impl_line_comparison!(greater_than, Greater, "greater than");
impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
221
222mod bool_and {
223    use cubecl_ir::Operator;
224
225    use crate::prelude::binary_expand;
226
227    use super::*;
228
229    impl Line<bool> {
230        /// Return a new line with the element-wise and of the lines
231        pub fn and(self, _other: Self) -> Line<bool> {
232            unexpanded!()
233        }
234
235        /// Expand function of [and](Self::and).
236        pub fn __expand_and(
237            scope: &mut Scope,
238            lhs: ExpandElementTyped<Self>,
239            rhs: ExpandElementTyped<Self>,
240        ) -> ExpandElementTyped<Line<bool>> {
241            lhs.__expand_and_method(scope, rhs)
242        }
243    }
244
245    impl ExpandElementTyped<Line<bool>> {
246        /// Expand method of [equal](Line::equal).
247        pub fn __expand_and_method(
248            self,
249            scope: &mut Scope,
250            rhs: Self,
251        ) -> ExpandElementTyped<Line<bool>> {
252            binary_expand(scope, self.expand, rhs.expand, Operator::And).into()
253        }
254    }
255}
256
257mod bool_or {
258    use cubecl_ir::Operator;
259
260    use crate::prelude::binary_expand;
261
262    use super::*;
263
264    impl Line<bool> {
265        /// Return a new line with the element-wise and of the lines
266        pub fn or(self, _other: Self) -> Line<bool> {
267            unexpanded!()
268        }
269
270        /// Expand function of [and](Self::and).
271        pub fn __expand_or(
272            scope: &mut Scope,
273            lhs: ExpandElementTyped<Self>,
274            rhs: ExpandElementTyped<Self>,
275        ) -> ExpandElementTyped<Line<bool>> {
276            lhs.__expand_and_method(scope, rhs)
277        }
278    }
279
280    impl ExpandElementTyped<Line<bool>> {
281        /// Expand method of [equal](Line::equal).
282        pub fn __expand_or_method(
283            self,
284            scope: &mut Scope,
285            rhs: Self,
286        ) -> ExpandElementTyped<Line<bool>> {
287            binary_expand(scope, self.expand, rhs.expand, Operator::Or).into()
288        }
289    }
290}
291
292impl<P: CubePrimitive> CubeType for Line<P> {
293    type ExpandType = ExpandElementTyped<Self>;
294}
295
296impl<P: CubePrimitive> ExpandElementBaseInit for Line<P> {
297    fn init_elem(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
298        P::init_elem(scope, elem)
299    }
300}
301
302impl<P: CubePrimitive> CubePrimitive for Line<P> {
303    fn as_elem(scope: &Scope) -> Elem {
304        P::as_elem(scope)
305    }
306
307    fn as_elem_native() -> Option<Elem> {
308        P::as_elem_native()
309    }
310
311    fn size() -> Option<usize> {
312        P::size()
313    }
314}
315
316impl<N: Numeric> Dot for Line<N> {
317    fn dot(self, _rhs: Self) -> Self {
318        unexpanded!()
319    }
320
321    fn __expand_dot(
322        scope: &mut Scope,
323        lhs: ExpandElementTyped<Self>,
324        rhs: ExpandElementTyped<Self>,
325    ) -> ExpandElementTyped<Self> {
326        let lhs: ExpandElement = lhs.into();
327        let mut item = lhs.item;
328        item.vectorization = None;
329        binary_expand_fixed_output(scope, lhs, rhs.into(), item, Arithmetic::Dot).into()
330    }
331}
332
333impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}