cubecl_core/frontend/container/line/base.rs

1use crate::{self as cubecl, prelude::FloatOps};
2use crate::{
3    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
4    prelude::MulHi,
5};
6use crate::{
7    ir::{BinaryOperator, Instruction, Scope, Type},
8    prelude::Dot,
9    unexpanded,
10};
11use cubecl_ir::{Comparison, ConstantValue, ExpandElement, StorageType};
12use cubecl_macros::{cube, intrinsic};
13use derive_more::derive::Neg;
14/// A contiguous list of elements that supports auto-vectorized operations.
15
16#[derive(Neg)]
17pub struct Line<P> {
18    // Comptime lines only support 1 element.
19    pub(crate) val: P,
20}
21
22type LineExpand<P> = ExpandElementTyped<Line<P>>;
23
24impl<P: CubePrimitive> Clone for Line<P> {
25    fn clone(&self) -> Self {
26        *self
27    }
28}
29impl<P: CubePrimitive> Eq for Line<P> {}
30impl<P: CubePrimitive> Copy for Line<P> {}
31
32/// Module that contains the implementation details of the new function.
33mod new {
34    use cubecl_ir::LineSize;
35    use cubecl_macros::comptime_type;
36
37    use super::*;
38
39    #[cube]
40    impl<P: CubePrimitive> Line<P> {
41        /// Create a new line of size 1 using the given value.
42        #[allow(unused_variables)]
43        pub fn new(val: P) -> Self {
44            intrinsic!(|_| {
45                let elem: ExpandElementTyped<P> = val;
46                elem.expand.into()
47            })
48        }
49    }
50
51    impl<P: CubePrimitive> Line<P> {
52        /// Get the length of the current line.
53        pub fn line_size(&self) -> comptime_type!(LineSize) {
54            unexpanded!()
55        }
56    }
57}
58
59/// Module that contains the implementation details of the fill function.
60mod fill {
61    use crate::prelude::cast;
62
63    use super::*;
64
65    #[cube]
66    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
67        /// Fill the line with the given value.
68        ///
69        /// If you want to fill the line with different values, consider using the index API
70        /// instead.
71        ///
72        /// ```rust, ignore
73        /// let mut line = Line::<u32>::empty(2);
74        /// line[0] = 1;
75        /// line[1] = 2;
76        /// ```
77        #[allow(unused_variables)]
78        pub fn fill(self, value: P) -> Self {
79            intrinsic!(|scope| {
80                let length = self.expand.ty.line_size();
81                let output = scope.create_local(Type::new(P::as_type(scope)).line(length));
82
83                cast::expand::<P, Line<P>>(scope, value, output.clone().into());
84
85                output.into()
86            })
87        }
88    }
89}
90
91/// Module that contains the implementation details of the empty function.
92mod empty {
93    use crate::prelude::Cast;
94
95    use super::*;
96
97    #[cube]
98    impl<P: CubePrimitive> Line<P> {
99        /// Create an empty line of the given size.
100        ///
101        /// Note that a line can't change in size once it's fixed.
102        #[allow(unused_variables)]
103        pub fn empty(#[comptime] size: usize) -> Self {
104            let zero = Line::<P>::cast_from(0);
105            intrinsic!(|scope| {
106                // We don't declare const variables in our compilers, only mut variables.
107                // So we need to create the variable as mut here.
108                let var: ExpandElementTyped<Line<P>> = scope
109                    .create_local_mut(Type::new(Self::as_type(scope)).line(size))
110                    .into();
111                cubecl::frontend::assign::expand(scope, zero, var.clone());
112                var
113            })
114        }
115    }
116}
117
118/// Module that contains the implementation details of the size function.
119mod size {
120    use cubecl_ir::LineSize;
121
122    use super::*;
123
124    impl<P: CubePrimitive> Line<P> {
125        /// Get the number of individual elements a line contains.
126        ///
127        /// The size is available at comptime and may be used in combination with the comptime
128        /// macro.
129        ///
130        /// ```rust, ignore
131        /// // The if statement is going to be executed at comptime.
132        /// if comptime!(line.size() == 1) {
133        /// }
134        /// ```
135        pub fn size(&self) -> LineSize {
136            unexpanded!()
137        }
138
139        /// Expand function of [size](Self::size).
140        pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> LineSize {
141            element.__expand_line_size_method(scope)
142        }
143    }
144
145    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
146        /// Comptime version of [size](Line::size).
147        pub fn size(&self) -> LineSize {
148            self.expand.ty.line_size()
149        }
150
151        /// Expand method of [size](Line::size).
152        pub fn __expand_size_method(&self, _scope: &mut Scope) -> LineSize {
153            self.size()
154        }
155    }
156}
157
// Implements an element-wise comparison method on `Line<P>`.
//
// - `$name`: name of the generated method, and of the private module that holds it.
// - `$operator`: the `Comparison` IR variant registered by the method.
// - `$comment`: phrase spliced into the generated method documentation.
macro_rules! impl_line_comparison {
    ($name:ident, $operator:ident, $comment:literal) => {
        ::paste::paste! {
            // Metavariables are not substituted inside `///` doc comments (those are plain
            // string literals), so the module doc is assembled with `concat!`/`stringify!`.
            #[doc = concat!(
                "Module that contains the implementation details of the `",
                stringify!($name),
                "` function."
            )]
            mod $name {
                use super::*;

                #[cube]
                impl<P: CubePrimitive> Line<P> {
                    #[doc = concat!(
                        "Return a new line with the element-wise comparison of the first line being ",
                        $comment,
                        " the second line."
                    )]
                    #[allow(unused_variables)]
                    pub fn $name(self, other: Self) -> Line<bool> {
                        intrinsic!(|scope| {
                            // The result takes its line size from `self`, with `bool` lanes.
                            let size = self.expand.ty.line_size();
                            let lhs = self.expand.into();
                            let rhs = other.expand.into();

                            let output = scope.create_local_mut(Type::new(bool::as_type(scope)).line(size));

                            scope.register(Instruction::new(
                                Comparison::$operator(BinaryOperator { lhs, rhs }),
                                output.clone().into(),
                            ));

                            output.into()
                        })
                    }
                }
            }
        }
    };
}
197
198impl_line_comparison!(equal, Equal, "equal to");
199impl_line_comparison!(not_equal, NotEqual, "not equal to");
200impl_line_comparison!(less_than, Lower, "less than");
201impl_line_comparison!(greater_than, Greater, "greater than");
202impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
203impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
204
205mod bool_and {
206    use cubecl_ir::Operator;
207
208    use crate::prelude::binary_expand;
209
210    use super::*;
211
212    #[cube]
213    impl Line<bool> {
214        /// Return a new line with the element-wise and of the lines
215        #[allow(unused_variables)]
216        pub fn and(self, other: Self) -> Line<bool> {
217            intrinsic!(
218                |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
219            )
220        }
221    }
222}
223
224mod bool_or {
225    use cubecl_ir::Operator;
226
227    use crate::prelude::binary_expand;
228
229    use super::*;
230
231    #[cube]
232    impl Line<bool> {
233        /// Return a new line with the element-wise and of the lines
234        #[allow(unused_variables)]
235        pub fn or(self, other: Self) -> Line<bool> {
236            intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
237        }
238    }
239}
240
241impl<P: CubePrimitive> CubeType for Line<P> {
242    type ExpandType = ExpandElementTyped<Self>;
243}
244
245impl<P: CubePrimitive> CubeType for &Line<P> {
246    type ExpandType = ExpandElementTyped<Line<P>>;
247}
248
249impl<P: CubePrimitive> CubeType for &mut Line<P> {
250    type ExpandType = ExpandElementTyped<Line<P>>;
251}
252
253impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
254    fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
255        P::elem_into_mut(scope, elem)
256    }
257}
258
259impl<P: CubePrimitive> CubePrimitive for Line<P> {
260    fn as_type(scope: &Scope) -> StorageType {
261        P::as_type(scope)
262    }
263
264    fn as_type_native() -> Option<StorageType> {
265        P::as_type_native()
266    }
267
268    fn size() -> Option<usize> {
269        P::size()
270    }
271
272    fn from_const_value(value: ConstantValue) -> Self {
273        Self::new(P::from_const_value(value))
274    }
275}
276
277impl<N: Dot + CubePrimitive> Dot for Line<N> {}
278impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}
279impl<N: FloatOps + CubePrimitive> FloatOps for Line<N> {}