cubecl_core/frontend/container/line/
base.rs

1use std::num::NonZero;
2
3use crate::{
4    ir::{BinaryOperator, ConstantScalarValue, Elem, Instruction, Item, Operator},
5    prelude::{binary_expand_fixed_output, CubeContext, Dot, ExpandElement, Numeric},
6    unexpanded,
7};
8
9use crate::frontend::{
10    CubePrimitive, CubeType, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime,
11};
12
13/// A contiguous list of elements that supports auto-vectorized operations.
14pub struct Line<P> {
15    // Comptime lines only support 1 element.
16    pub(crate) val: P,
17}
18
19impl<P: CubePrimitive> Clone for Line<P> {
20    fn clone(&self) -> Self {
21        *self
22    }
23}
24impl<P: CubePrimitive> Eq for Line<P> {}
25impl<P: CubePrimitive> Copy for Line<P> {}
26
27/// Module that contains the implementation details of the new function.
28mod new {
29    use super::*;
30
31    impl<P: CubePrimitive> Line<P> {
32        /// Create a new line of size 1 using the given value.
33        pub fn new(val: P) -> Self {
34            Self { val }
35        }
36
37        /// Expand function of [Self::new].
38        pub fn __expand_new(
39            _context: &mut CubeContext,
40            val: P::ExpandType,
41        ) -> ExpandElementTyped<Self> {
42            let elem: ExpandElementTyped<P> = val;
43            elem.expand.into()
44        }
45    }
46}
47
48/// Module that contains the implementation details of the fill function.
49mod fill {
50    use crate::prelude::cast;
51
52    use super::*;
53
54    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
55        /// Fill the line with the given value.
56        ///
57        /// If you want to fill the line with different values, consider using the index API
58        /// instead.
59        ///
60        /// ```rust, ignore
61        /// let mut line = Line::<u32>::empty(2);
62        /// line[0] = 1;
63        /// line[1] = 2;
64        /// ```
65        #[allow(unused_variables)]
66        pub fn fill(mut self, value: P) -> Self {
67            self.val = value;
68            self
69        }
70
71        /// Expand function of [fill](Self::fill).
72        pub fn __expand_fill(
73            context: &mut CubeContext,
74            line: ExpandElementTyped<Self>,
75            value: ExpandElementTyped<P>,
76        ) -> ExpandElementTyped<Self> {
77            line.__expand_fill_method(context, value)
78        }
79    }
80
81    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
82        /// Expand method of [fill](Line::fill).
83        pub fn __expand_fill_method(
84            self,
85            context: &mut CubeContext,
86            value: ExpandElementTyped<P>,
87        ) -> Self {
88            let length = self.expand.item.vectorization;
89            let output = context.create_local(Item::vectorized(P::as_elem(context), length));
90
91            cast::expand::<P>(context, value, output.clone().into());
92
93            output.into()
94        }
95    }
96}
97
98/// Module that contains the implementation details of the empty function.
99mod empty {
100    use super::*;
101
102    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
103        /// Create an empty line of the given size.
104        ///
105        /// Note that a line can't change in size once it's fixed.
106        #[allow(unused_variables)]
107        pub fn empty(size: u32) -> Self {
108            unexpanded!()
109        }
110
111        /// Expand function of [empty](Self::empty).
112        pub fn __expand_empty(
113            context: &mut CubeContext,
114            length: ExpandElementTyped<u32>,
115        ) -> ExpandElementTyped<Self> {
116            let length = match length.expand.as_const() {
117                Some(val) => match val {
118                    ConstantScalarValue::Int(val, _) => NonZero::new(val)
119                        .map(|val| val.get() as u8)
120                        .map(|val| NonZero::new(val).unwrap()),
121                    ConstantScalarValue::Float(val, _) => NonZero::new(val as i64)
122                        .map(|val| val.get() as u8)
123                        .map(|val| NonZero::new(val).unwrap()),
124                    ConstantScalarValue::UInt(val, _) => NonZero::new(val as u8),
125                    ConstantScalarValue::Bool(_) => None,
126                },
127                None => None,
128            };
129            context
130                .create_local_mut(Item::vectorized(Self::as_elem(context), length))
131                .into()
132        }
133    }
134}
135
136/// Module that contains the implementation details of the size function.
137mod size {
138    use super::*;
139
140    impl<P: CubePrimitive> Line<P> {
141        /// Get the number of individual elements a line contains.
142        ///
143        /// The size is available at comptime and may be used in combination with the comptime
144        /// macro.
145        ///
146        /// ```rust, ignore
147        /// // The if statement is going to be executed at comptime.
148        /// if comptime!(line.size() == 1) {
149        /// }
150        /// ```
151        pub fn size(&self) -> u32 {
152            unexpanded!()
153        }
154
155        /// Expand function of [size](Self::size).
156        pub fn __expand_size(context: &mut CubeContext, element: ExpandElementTyped<P>) -> u32 {
157            element.__expand_vectorization_factor_method(context)
158        }
159    }
160
161    impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
162        /// Comptime version of [size](Line::size).
163        pub fn size(&self) -> u32 {
164            self.expand
165                .item
166                .vectorization
167                .unwrap_or(NonZero::new(1).unwrap())
168                .get() as u32
169        }
170
171        /// Expand method of [size](Line::size).
172        pub fn __expand_size_method(&self, _context: &mut CubeContext) -> u32 {
173            self.size()
174        }
175    }
176}
177
178// Implement a comparison operator define in
179macro_rules! impl_line_comparison {
180    ($name:ident, $operator:ident, $comment:literal) => {
181        ::paste::paste! {
182            /// Module that contains the implementation details of the $name function.
183            mod $name {
184
185                use super::*;
186
187                impl<P: CubePrimitive> Line<P> {
188                    #[doc = concat!(
189                        "Return a new line with the element-wise comparison of the first line being ",
190                        $comment,
191                        " the second line."
192                    )]
193                    pub fn $name(self, _other: Self) -> Line<bool> {
194                        unexpanded!()
195                    }
196
197                    /// Expand function of [$name](Self::$name).
198                    pub fn [< __expand_ $name >](
199                        context: &mut CubeContext,
200                        lhs: ExpandElementTyped<Self>,
201                        rhs: ExpandElementTyped<Self>,
202                    ) -> ExpandElementTyped<Line<bool>> {
203                        lhs.[< __expand_ $name _method >](context, rhs)
204                    }
205                }
206
207                impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
208                    /// Expand method of [equal](Line::equal).
209                    pub fn [< __expand_ $name _method >](
210                        self,
211                        context: &mut CubeContext,
212                        rhs: Self,
213                    ) -> ExpandElementTyped<Line<bool>> {
214                        let size = self.expand.item.vectorization;
215                        let lhs = self.expand.into();
216                        let rhs = rhs.expand.into();
217
218                        let output = context.create_local_mut(Item::vectorized(bool::as_elem(context), size));
219
220                        context.register(Instruction::new(
221                            Operator::$operator(BinaryOperator { lhs, rhs }),
222                            output.clone().into(),
223                        ));
224
225                        output.into()
226                    }
227                }
228            }
229        }
230
231    };
232}
233
234impl_line_comparison!(equal, Equal, "equal to");
235impl_line_comparison!(not_equal, NotEqual, "not equal to");
236impl_line_comparison!(less_than, Lower, "less than");
237impl_line_comparison!(greater_than, Greater, "greater than");
238impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
239impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
240
241impl<P: CubePrimitive> CubeType for Line<P> {
242    type ExpandType = ExpandElementTyped<Self>;
243}
244
245impl<P: CubePrimitive> ExpandElementBaseInit for Line<P> {
246    fn init_elem(context: &mut crate::prelude::CubeContext, elem: ExpandElement) -> ExpandElement {
247        P::init_elem(context, elem)
248    }
249}
250
251impl<P: CubePrimitive> IntoRuntime for Line<P> {
252    fn __expand_runtime_method(
253        self,
254        context: &mut crate::prelude::CubeContext,
255    ) -> Self::ExpandType {
256        self.val.__expand_runtime_method(context).expand.into()
257    }
258}
259
260impl<P: CubePrimitive> CubePrimitive for Line<P> {
261    fn as_elem(context: &CubeContext) -> Elem {
262        P::as_elem(context)
263    }
264
265    fn as_elem_native() -> Option<Elem> {
266        P::as_elem_native()
267    }
268
269    fn size() -> Option<usize> {
270        P::size()
271    }
272}
273
274impl<N: Numeric> Dot for Line<N> {
275    fn dot(self, _rhs: Self) -> Self {
276        unexpanded!()
277    }
278
279    fn __expand_dot(
280        context: &mut CubeContext,
281        lhs: ExpandElementTyped<Self>,
282        rhs: ExpandElementTyped<Self>,
283    ) -> ExpandElementTyped<Self> {
284        let lhs: ExpandElement = lhs.into();
285        let mut item = lhs.item;
286        item.vectorization = None;
287        binary_expand_fixed_output(context, lhs, rhs.into(), item, Operator::Dot).into()
288    }
289}