cubecl_core/frontend/container/line/
base.rs

1use std::num::NonZero;
2
3use crate as cubecl;
4use crate::{
5    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
6    prelude::MulHi,
7};
8use crate::{
9    ir::{Arithmetic, BinaryOperator, Elem, Instruction, Item, Scope},
10    prelude::{Dot, Numeric, binary_expand_fixed_output},
11    unexpanded,
12};
13use cubecl_ir::{Comparison, ExpandElement};
14use cubecl_macros::{cube, intrinsic};
15use derive_more::derive::Neg;
16/// A contiguous list of elements that supports auto-vectorized operations.
17
18#[derive(Neg)]
19pub struct Line<P> {
20    // Comptime lines only support 1 element.
21    pub(crate) val: P,
22}
23
24type LineExpand<P> = ExpandElementTyped<Line<P>>;
25
26impl<P: CubePrimitive> Clone for Line<P> {
27    fn clone(&self) -> Self {
28        *self
29    }
30}
31impl<P: CubePrimitive> Eq for Line<P> {}
32impl<P: CubePrimitive> Copy for Line<P> {}
33
34/// Module that contains the implementation details of the new function.
35mod new {
36    use cubecl_macros::comptime_type;
37
38    use super::*;
39
40    #[cube]
41    impl<P: CubePrimitive> Line<P> {
42        /// Create a new line of size 1 using the given value.
43        #[allow(unused_variables)]
44        pub fn new(val: P) -> Self {
45            intrinsic!(|_| {
46                let elem: ExpandElementTyped<P> = val;
47                elem.expand.into()
48            })
49        }
50
51        #[allow(unused_variables)]
52        /// Get the length of the current line.
53        pub fn line_size(&self) -> comptime_type!(u32) {
54            intrinsic!(|_| {
55                let elem: ExpandElementTyped<Line<P>> = self;
56                elem.expand
57                    .item
58                    .vectorization
59                    .map(|a| a.get() as u32)
60                    .unwrap_or(1)
61            })
62        }
63    }
64}
65
66/// Module that contains the implementation details of the fill function.
67mod fill {
68    use crate::prelude::cast;
69
70    use super::*;
71
72    #[cube]
73    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
74        /// Fill the line with the given value.
75        ///
76        /// If you want to fill the line with different values, consider using the index API
77        /// instead.
78        ///
79        /// ```rust, ignore
80        /// let mut line = Line::<u32>::empty(2);
81        /// line[0] = 1;
82        /// line[1] = 2;
83        /// ```
84        #[allow(unused_variables)]
85        pub fn fill(self, value: P) -> Self {
86            intrinsic!(|scope| {
87                let length = self.expand.item.vectorization;
88                let output = scope.create_local(Item::vectorized(P::as_elem(scope), length));
89
90                cast::expand::<P>(scope, value, output.clone().into());
91
92                output.into()
93            })
94        }
95    }
96}
97
98/// Module that contains the implementation details of the empty function.
99mod empty {
100    use crate::prelude::Cast;
101
102    use super::*;
103
104    #[cube]
105    impl<P: CubePrimitive> Line<P> {
106        /// Create an empty line of the given size.
107        ///
108        /// Note that a line can't change in size once it's fixed.
109        #[allow(unused_variables)]
110        pub fn empty(#[comptime] size: u32) -> Self {
111            let zero = Line::<P>::cast_from(0);
112            intrinsic!(|scope| {
113                let length = NonZero::new(size as u8);
114                // We don't declare const variables in our compilers, only mut variables.
115                // So we need to create the variable as mut here.
116                let var: ExpandElementTyped<Line<P>> = scope
117                    .create_local_mut(Item::vectorized(Self::as_elem(scope), length))
118                    .into();
119                cubecl::frontend::assign::expand(scope, zero, var.clone());
120                var
121            })
122        }
123    }
124}
125
126/// Module that contains the implementation details of the size function.
127mod size {
128    use super::*;
129
130    impl<P: CubePrimitive> Line<P> {
131        /// Get the number of individual elements a line contains.
132        ///
133        /// The size is available at comptime and may be used in combination with the comptime
134        /// macro.
135        ///
136        /// ```rust, ignore
137        /// // The if statement is going to be executed at comptime.
138        /// if comptime!(line.size() == 1) {
139        /// }
140        /// ```
141        pub fn size(&self) -> u32 {
142            unexpanded!()
143        }
144
145        /// Expand function of [size](Self::size).
146        pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> u32 {
147            element.__expand_vectorization_factor_method(scope)
148        }
149    }
150
151    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
152        /// Comptime version of [size](Line::size).
153        pub fn size(&self) -> u32 {
154            self.expand
155                .item
156                .vectorization
157                .unwrap_or(NonZero::new(1).unwrap())
158                .get() as u32
159        }
160
161        /// Expand method of [size](Line::size).
162        pub fn __expand_size_method(&self, _scope: &mut Scope) -> u32 {
163            self.size()
164        }
165    }
166}
167
168// Implement a comparison operator define in
169macro_rules! impl_line_comparison {
170    ($name:ident, $operator:ident, $comment:literal) => {
171        ::paste::paste! {
172            /// Module that contains the implementation details of the $name function.
173            mod $name {
174
175                use super::*;
176
177                #[cube]
178                impl<P: CubePrimitive> Line<P> {
179                    #[doc = concat!(
180                        "Return a new line with the element-wise comparison of the first line being ",
181                        $comment,
182                        " the second line."
183                    )]
184                    #[allow(unused_variables)]
185                    pub fn $name(self, other: Self) -> Line<bool> {
186                        intrinsic!(|scope| {
187                            let size = self.expand.item.vectorization;
188                            let lhs = self.expand.into();
189                            let rhs = other.expand.into();
190
191                            let output = scope.create_local_mut(Item::vectorized(bool::as_elem(scope), size));
192
193                            scope.register(Instruction::new(
194                                Comparison::$operator(BinaryOperator { lhs, rhs }),
195                                output.clone().into(),
196                            ));
197
198                            output.into()
199                        })
200                    }
201                }
202            }
203        }
204
205    };
206}
207
208impl_line_comparison!(equal, Equal, "equal to");
209impl_line_comparison!(not_equal, NotEqual, "not equal to");
210impl_line_comparison!(less_than, Lower, "less than");
211impl_line_comparison!(greater_than, Greater, "greater than");
212impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
213impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
214
215mod bool_and {
216    use cubecl_ir::Operator;
217
218    use crate::prelude::binary_expand;
219
220    use super::*;
221
222    #[cube]
223    impl Line<bool> {
224        /// Return a new line with the element-wise and of the lines
225        #[allow(unused_variables)]
226        pub fn and(self, other: Self) -> Line<bool> {
227            intrinsic!(
228                |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
229            )
230        }
231    }
232}
233
234mod bool_or {
235    use cubecl_ir::Operator;
236
237    use crate::prelude::binary_expand;
238
239    use super::*;
240
241    #[cube]
242    impl Line<bool> {
243        /// Return a new line with the element-wise and of the lines
244        #[allow(unused_variables)]
245        pub fn or(self, other: Self) -> Line<bool> {
246            intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
247        }
248    }
249}
250
251impl<P: CubePrimitive> CubeType for Line<P> {
252    type ExpandType = ExpandElementTyped<Self>;
253}
254
255impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
256    fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
257        P::elem_into_mut(scope, elem)
258    }
259}
260
261impl<P: CubePrimitive> CubePrimitive for Line<P> {
262    fn as_elem(scope: &Scope) -> Elem {
263        P::as_elem(scope)
264    }
265
266    fn as_elem_native() -> Option<Elem> {
267        P::as_elem_native()
268    }
269
270    fn size() -> Option<usize> {
271        P::size()
272    }
273}
274
275impl<N: Numeric> Dot for Line<N> {
276    fn dot(self, _rhs: Self) -> Self {
277        unexpanded!()
278    }
279
280    fn __expand_dot(
281        scope: &mut Scope,
282        lhs: ExpandElementTyped<Self>,
283        rhs: ExpandElementTyped<Self>,
284    ) -> ExpandElementTyped<Self> {
285        let lhs: ExpandElement = lhs.into();
286        let mut item = lhs.item;
287        item.vectorization = None;
288        binary_expand_fixed_output(scope, lhs, rhs.into(), item, Arithmetic::Dot).into()
289    }
290}
291
292impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}