// Source: cubecl_core/frontend/container/vector/base.rs

1use core::{marker::PhantomData, ops::Neg};
2
3use crate::frontend::{CubePrimitive, CubeType, NativeAssign, NativeExpand};
4use crate::ir::{BinaryOperator, Instruction, Scope, Type};
5use crate::{self as cubecl, prelude::*};
6use cubecl_ir::{Comparison, ConstantValue, ManagedVariable};
7use cubecl_macros::{cube, intrinsic};
8
9/// A contiguous list of elements that supports auto-vectorized operations.
10#[derive(Debug)]
11pub struct Vector<P: Scalar, N: Size> {
12    // Comptime vectors only support 1 element.
13    pub(crate) val: P,
14    pub(crate) _size: PhantomData<N>,
15}
16
17type VectorExpand<P, N> = NativeExpand<Vector<P, N>>;
18
19impl<P: Scalar, N: Size> Clone for Vector<P, N> {
20    fn clone(&self) -> Self {
21        *self
22    }
23}
24impl<P: Scalar, N: Size> Eq for Vector<P, N> {}
25impl<P: Scalar, N: Size> Copy for Vector<P, N> {}
26impl<P: Scalar + Neg<Output = P>, N: Size> Neg for Vector<P, N> {
27    type Output = Self;
28
29    fn neg(self) -> Self::Output {
30        Self {
31            val: -self.val,
32            _size: PhantomData,
33        }
34    }
35}
36
37/// Module that contains the implementation details of the new function.
38mod new {
39    use cubecl_ir::VectorSize;
40    use cubecl_macros::comptime_type;
41
42    use crate::prelude::Cast;
43
44    use super::*;
45
46    impl<P: Scalar, N: Size> Vector<P, N> {
47        /// Create a new vector of size 1 using the given value.
48        #[allow(unused_variables)]
49        pub fn new(val: P) -> Self {
50            Self {
51                val,
52                _size: PhantomData,
53            }
54        }
55
56        pub fn __expand_new(scope: &mut Scope, val: NativeExpand<P>) -> VectorExpand<P, N> {
57            Vector::<P, N>::__expand_cast_from(scope, val)
58        }
59    }
60
61    impl<P: Scalar, N: Size> Vector<P, N> {
62        /// Get the length of the current vector.
63        pub fn vector_size(&self) -> comptime_type!(VectorSize) {
64            N::value()
65        }
66    }
67}
68
/// Module that contains the numeric constructors (`min_value`, `max_value`, `from_int`).
mod numeric {
    use super::*;

    #[cube]
    impl<P: Numeric, N: Size> Vector<P, N> {
        /// Create a vector from `P::min_value()` via [`Vector::new`].
        pub fn min_value() -> Self {
            Self::new(P::min_value())
        }
        /// Create a vector from `P::max_value()` via [`Vector::new`].
        pub fn max_value() -> Self {
            Self::new(P::max_value())
        }

        /// Create a new constant vector from an integer.
        ///
        /// Note: since this must work for both integer and float element types, only the
        /// less expressive of both can be created (int). If a number with decimals is needed,
        /// use `Float::new`.
        ///
        /// The integer is converted with `P::from_int` and then passed to [`Vector::new`].
        /// NOTE(review): confirm whether this panics outside of kernel expansion; that
        /// depends on `P::from_int`, which is not visible here.
        pub fn from_int(val: i64) -> Self {
            Self::new(P::from_int(val))
        }
    }
}
94
95/// Module that contains the implementation details of the fill function.
96mod fill {
97    use crate::prelude::cast;
98
99    use super::*;
100
101    #[cube]
102    impl<P: Scalar, N: Size> Vector<P, N> {
103        /// Fill the vector with the given value.
104        ///
105        /// If you want to fill the vector with different values, consider using the index API
106        /// instead.
107        ///
108        /// ```rust, ignore
109        /// let mut vector = Vector::<u32>::empty(2);
110        /// vector[0] = 1;
111        /// vector[1] = 2;
112        /// ```
113        #[allow(unused_variables)]
114        pub fn fill(self, value: P) -> Self {
115            intrinsic!(|scope| {
116                let output = scope.create_local(Vector::<P, N>::as_type(scope));
117
118                cast::expand::<P, Vector<P, N>>(scope, value, output.clone().into());
119
120                output.into()
121            })
122        }
123    }
124}
125
/// Module that contains the implementation details of the empty function.
mod empty {
    use bytemuck::Zeroable;

    use super::*;

    #[cube]
    impl<P: Scalar, N: Size> Vector<P, N> {
        /// Create a vector intended to be filled in afterwards (e.g. through indexing).
        ///
        /// During expansion this takes `Self::__expand_default` and converts it with
        /// `into_mut`, presumably so the result is a writable variable.
        /// NOTE(review): treat the initial lane values as unspecified unless
        /// `__expand_default` is confirmed to guarantee them.
        pub fn empty() -> Self {
            intrinsic!(|scope| {
                let value = Self::__expand_default(scope);
                value.into_mut(scope)
            })
        }
    }

    #[cube]
    impl<P: Scalar + Zeroable, N: Size> Vector<P, N> {
        /// Create a vector from `P::zeroed()`.
        ///
        /// The zeroed scalar is turned into a runtime value and then cast to
        /// `Vector<P, N>` (presumably broadcasting it to every lane).
        pub fn zeroed() -> Self {
            intrinsic!(|scope| {
                let zeroed = P::zeroed().__expand_runtime_method(scope);
                Self::__expand_cast_from(scope, zeroed)
            })
        }
    }
}
152
153/// Module that contains the implementation details of the size function.
154mod size {
155    use cubecl_ir::VectorSize;
156
157    use super::*;
158
159    impl<P: Scalar, N: Size> Vector<P, N> {
160        /// Get the number of individual elements a vector contains.
161        ///
162        /// The size is available at comptime and may be used in combination with the comptime
163        /// macro.
164        ///
165        /// ```rust, ignore
166        /// // The if statement is going to be executed at comptime.
167        /// if comptime!(vector.size() == 1) {
168        /// }
169        /// ```
170        pub fn size(&self) -> VectorSize {
171            N::value()
172        }
173
174        /// Expand function of [size](Self::size).
175        pub fn __expand_size(scope: &mut Scope, element: NativeExpand<Vector<P, N>>) -> VectorSize {
176            element.__expand_vector_size_method(scope)
177        }
178    }
179
180    impl<P: Scalar, N: Size> NativeExpand<Vector<P, N>> {
181        /// Comptime version of [size](Vector::size).
182        pub fn size(&self) -> VectorSize {
183            self.expand.ty.vector_size()
184        }
185
186        /// Expand method of [size](Vector::size).
187        pub fn __expand_size_method(&self, _scope: &mut Scope) -> VectorSize {
188            self.size()
189        }
190    }
191}
192
193// Implement a comparison operator define in
194macro_rules! impl_vector_comparison {
195    ($name:ident, $operator:ident, $comment:literal) => {
196        ::paste::paste! {
197            /// Module that contains the implementation details of the $name function.
198            mod $name {
199
200                use super::*;
201
202                #[cube]
203                impl<P: Scalar, N: Size> Vector<P, N> {
204                    #[doc = concat!(
205                        "Return a new vector with the element-wise comparison of the first vector being ",
206                        $comment,
207                        " the second vector."
208                    )]
209                    #[allow(unused_variables)]
210                    pub fn $name(self, other: Self) -> Vector<bool, N> {
211                        intrinsic!(|scope| {
212                            let size = self.expand.ty.vector_size();
213                            let lhs = self.expand.into();
214                            let rhs = other.expand.into();
215
216                            let output = scope.create_local_mut(Vector::<bool, N>::as_type(scope));
217
218                            scope.register(Instruction::new(
219                                Comparison::$operator(BinaryOperator { lhs, rhs }),
220                                output.clone().into(),
221                            ));
222
223                            output.into()
224                        })
225                    }
226                }
227            }
228        }
229
230    };
231}
232
233impl_vector_comparison!(equal, Equal, "equal to");
234impl_vector_comparison!(not_equal, NotEqual, "not equal to");
235impl_vector_comparison!(less_than, Lower, "less than");
236impl_vector_comparison!(greater_than, Greater, "greater than");
237impl_vector_comparison!(less_equal, LowerEqual, "less than or equal to");
238impl_vector_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
239
240mod bool_and {
241    use cubecl_ir::Operator;
242
243    use crate::prelude::binary_expand;
244
245    use super::*;
246
247    #[cube]
248    impl<N: Size> Vector<bool, N> {
249        /// Return a new vector with the element-wise and of the vectors
250        #[allow(unused_variables)]
251        pub fn and(self, other: Self) -> Vector<bool, N> {
252            intrinsic!(
253                |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
254            )
255        }
256    }
257}
258
259mod bool_or {
260    use cubecl_ir::Operator;
261
262    use crate::prelude::binary_expand;
263
264    use super::*;
265
266    #[cube]
267    impl<N: Size> Vector<bool, N> {
268        /// Return a new vector with the element-wise and of the vectors
269        #[allow(unused_variables)]
270        pub fn or(self, other: Self) -> Vector<bool, N> {
271            intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
272        }
273    }
274}
275
276impl<P: Scalar, N: Size> CubeType for Vector<P, N> {
277    type ExpandType = NativeExpand<Self>;
278}
279
280impl<P: Scalar, N: Size> CubeType for &Vector<P, N> {
281    type ExpandType = NativeExpand<Vector<P, N>>;
282}
283
284impl<P: Scalar, N: Size> CubeType for &mut Vector<P, N> {
285    type ExpandType = NativeExpand<Vector<P, N>>;
286}
287
288impl<P: Scalar, N: Size> NativeAssign for Vector<P, N> {
289    fn elem_init_mut(scope: &mut crate::ir::Scope, elem: ManagedVariable) -> ManagedVariable {
290        P::elem_init_mut(scope, elem)
291    }
292}
293
294impl<P: Scalar, N: Size> CubePrimitive for Vector<P, N> {
295    type Scalar = P;
296    type Size = N;
297    type WithScalar<S: Scalar> = Vector<S, N>;
298
299    fn as_type(scope: &Scope) -> Type {
300        P::as_type(scope).with_vector_size(N::__expand_value(scope))
301    }
302
303    fn as_type_native() -> Option<Type> {
304        P::as_type_native().and_then(|ty| {
305            let vector_size = N::try_value_const()?;
306            Some(ty.with_vector_size(vector_size))
307        })
308    }
309
310    fn from_const_value(value: ConstantValue) -> Self {
311        Self::new(P::from_const_value(value))
312    }
313}
314
315impl<T: Dot + Scalar, N: Size> Dot for Vector<T, N> {}
316impl<T: MulHi + Scalar, N: Size> MulHi for Vector<T, N> {}
317impl<T: FloatOps + Scalar, N: Size> FloatOps for Vector<T, N> {}
318impl<T: Hypot + Scalar, N: Size> Hypot for Vector<T, N> {}
319impl<T: Rhypot + Scalar, N: Size> Rhypot for Vector<T, N> {}