// cubecl_core/frontend/container/line/base.rs

1use crate as cubecl;
2use crate::{
3    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
4    prelude::MulHi,
5};
6use crate::{
7    ir::{Arithmetic, BinaryOperator, Instruction, Scope, Type},
8    prelude::{Dot, Numeric, binary_expand_fixed_output},
9    unexpanded,
10};
11use cubecl_ir::{Comparison, ConstantScalarValue, ExpandElement, StorageType};
12use cubecl_macros::{cube, intrinsic};
13use derive_more::derive::Neg;
14/// A contiguous list of elements that supports auto-vectorized operations.
15
16#[derive(Neg)]
17pub struct Line<P> {
18    // Comptime lines only support 1 element.
19    pub(crate) val: P,
20}
21
22type LineExpand<P> = ExpandElementTyped<Line<P>>;
23
24impl<P: CubePrimitive> Clone for Line<P> {
25    fn clone(&self) -> Self {
26        *self
27    }
28}
29impl<P: CubePrimitive> Eq for Line<P> {}
30impl<P: CubePrimitive> Copy for Line<P> {}
31
32/// Module that contains the implementation details of the new function.
33mod new {
34    use cubecl_macros::comptime_type;
35
36    use super::*;
37
38    #[cube]
39    impl<P: CubePrimitive> Line<P> {
40        /// Create a new line of size 1 using the given value.
41        #[allow(unused_variables)]
42        pub fn new(val: P) -> Self {
43            intrinsic!(|_| {
44                let elem: ExpandElementTyped<P> = val;
45                elem.expand.into()
46            })
47        }
48    }
49
50    impl<P: CubePrimitive> Line<P> {
51        /// Get the length of the current line.
52        pub fn line_size(&self) -> comptime_type!(u32) {
53            unexpanded!()
54        }
55    }
56}
57
58/// Module that contains the implementation details of the fill function.
59mod fill {
60    use crate::prelude::cast;
61
62    use super::*;
63
64    #[cube]
65    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
66        /// Fill the line with the given value.
67        ///
68        /// If you want to fill the line with different values, consider using the index API
69        /// instead.
70        ///
71        /// ```rust, ignore
72        /// let mut line = Line::<u32>::empty(2);
73        /// line[0] = 1;
74        /// line[1] = 2;
75        /// ```
76        #[allow(unused_variables)]
77        pub fn fill(self, value: P) -> Self {
78            intrinsic!(|scope| {
79                let length = self.expand.ty.line_size();
80                let output = scope.create_local(Type::new(P::as_type(scope)).line(length));
81
82                cast::expand::<P>(scope, value, output.clone().into());
83
84                output.into()
85            })
86        }
87    }
88}
89
90/// Module that contains the implementation details of the empty function.
91mod empty {
92    use crate::prelude::Cast;
93
94    use super::*;
95
96    #[cube]
97    impl<P: CubePrimitive> Line<P> {
98        /// Create an empty line of the given size.
99        ///
100        /// Note that a line can't change in size once it's fixed.
101        #[allow(unused_variables)]
102        pub fn empty(#[comptime] size: u32) -> Self {
103            let zero = Line::<P>::cast_from(0);
104            intrinsic!(|scope| {
105                // We don't declare const variables in our compilers, only mut variables.
106                // So we need to create the variable as mut here.
107                let var: ExpandElementTyped<Line<P>> = scope
108                    .create_local_mut(Type::new(Self::as_type(scope)).line(size))
109                    .into();
110                cubecl::frontend::assign::expand(scope, zero, var.clone());
111                var
112            })
113        }
114    }
115}
116
117/// Module that contains the implementation details of the size function.
118mod size {
119    use super::*;
120
121    impl<P: CubePrimitive> Line<P> {
122        /// Get the number of individual elements a line contains.
123        ///
124        /// The size is available at comptime and may be used in combination with the comptime
125        /// macro.
126        ///
127        /// ```rust, ignore
128        /// // The if statement is going to be executed at comptime.
129        /// if comptime!(line.size() == 1) {
130        /// }
131        /// ```
132        pub fn size(&self) -> u32 {
133            unexpanded!()
134        }
135
136        /// Expand function of [size](Self::size).
137        pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> u32 {
138            element.__expand_line_size_method(scope)
139        }
140    }
141
142    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
143        /// Comptime version of [size](Line::size).
144        pub fn size(&self) -> u32 {
145            self.expand.ty.line_size()
146        }
147
148        /// Expand method of [size](Line::size).
149        pub fn __expand_size_method(&self, _scope: &mut Scope) -> u32 {
150            self.size()
151        }
152    }
153}
154
155// Implement a comparison operator define in
156macro_rules! impl_line_comparison {
157    ($name:ident, $operator:ident, $comment:literal) => {
158        ::paste::paste! {
159            /// Module that contains the implementation details of the $name function.
160            mod $name {
161
162                use super::*;
163
164                #[cube]
165                impl<P: CubePrimitive> Line<P> {
166                    #[doc = concat!(
167                        "Return a new line with the element-wise comparison of the first line being ",
168                        $comment,
169                        " the second line."
170                    )]
171                    #[allow(unused_variables)]
172                    pub fn $name(self, other: Self) -> Line<bool> {
173                        intrinsic!(|scope| {
174                            let size = self.expand.ty.line_size();
175                            let lhs = self.expand.into();
176                            let rhs = other.expand.into();
177
178                            let output = scope.create_local_mut(Type::new(bool::as_type(scope)).line(size));
179
180                            scope.register(Instruction::new(
181                                Comparison::$operator(BinaryOperator { lhs, rhs }),
182                                output.clone().into(),
183                            ));
184
185                            output.into()
186                        })
187                    }
188                }
189            }
190        }
191
192    };
193}
194
195impl_line_comparison!(equal, Equal, "equal to");
196impl_line_comparison!(not_equal, NotEqual, "not equal to");
197impl_line_comparison!(less_than, Lower, "less than");
198impl_line_comparison!(greater_than, Greater, "greater than");
199impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
200impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
201
202mod bool_and {
203    use cubecl_ir::Operator;
204
205    use crate::prelude::binary_expand;
206
207    use super::*;
208
209    #[cube]
210    impl Line<bool> {
211        /// Return a new line with the element-wise and of the lines
212        #[allow(unused_variables)]
213        pub fn and(self, other: Self) -> Line<bool> {
214            intrinsic!(
215                |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
216            )
217        }
218    }
219}
220
221mod bool_or {
222    use cubecl_ir::Operator;
223
224    use crate::prelude::binary_expand;
225
226    use super::*;
227
228    #[cube]
229    impl Line<bool> {
230        /// Return a new line with the element-wise and of the lines
231        #[allow(unused_variables)]
232        pub fn or(self, other: Self) -> Line<bool> {
233            intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
234        }
235    }
236}
237
238impl<P: CubePrimitive> CubeType for Line<P> {
239    type ExpandType = ExpandElementTyped<Self>;
240}
241
242impl<P: CubePrimitive> CubeType for &Line<P> {
243    type ExpandType = ExpandElementTyped<Line<P>>;
244}
245
246impl<P: CubePrimitive> CubeType for &mut Line<P> {
247    type ExpandType = ExpandElementTyped<Line<P>>;
248}
249
250impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
251    fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
252        P::elem_into_mut(scope, elem)
253    }
254}
255
256impl<P: CubePrimitive> CubePrimitive for Line<P> {
257    fn as_type(scope: &Scope) -> StorageType {
258        P::as_type(scope)
259    }
260
261    fn as_type_native() -> Option<StorageType> {
262        P::as_type_native()
263    }
264
265    fn size() -> Option<usize> {
266        P::size()
267    }
268
269    fn from_const_value(value: ConstantScalarValue) -> Self {
270        Self::new(P::from_const_value(value))
271    }
272}
273
274impl<N: Numeric> Dot for Line<N> {
275    fn dot(self, _rhs: Self) -> Self {
276        unexpanded!()
277    }
278
279    fn __expand_dot(
280        scope: &mut Scope,
281        lhs: ExpandElementTyped<Self>,
282        rhs: ExpandElementTyped<Self>,
283    ) -> ExpandElementTyped<Self> {
284        let lhs: ExpandElement = lhs.into();
285        let item = lhs.ty.storage_type().into();
286        binary_expand_fixed_output(scope, lhs, rhs.into(), item, Arithmetic::Dot).into()
287    }
288}
289
290impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}