cubecl_core/frontend/container/vector/
base.rs1use core::{marker::PhantomData, ops::Neg};
2
3use crate::frontend::{CubePrimitive, CubeType, NativeAssign, NativeExpand};
4use crate::ir::{BinaryOperator, Instruction, Scope, Type};
5use crate::{self as cubecl, prelude::*};
6use cubecl_ir::{Comparison, ConstantValue, ManagedVariable};
7use cubecl_macros::{cube, intrinsic};
8
/// A vector of `N` lanes of the scalar type `P`, usable from `#[cube]` code.
///
/// The host-side value stores a single scalar `val`; the lane count is carried only at the
/// type level by `N` (the `CubePrimitive::as_type` impl applies `N`'s value as the vector
/// size of `P`'s IR type).
#[derive(Debug)]
pub struct Vector<P: Scalar, N: Size> {
    /// The scalar held by the host-side value (e.g. the argument of `Vector::new`).
    pub(crate) val: P,
    /// Zero-sized marker tying the vector to its size parameter `N`.
    pub(crate) _size: PhantomData<N>,
}
16
/// Shorthand for the expand type of [`Vector`], i.e. `NativeExpand<Vector<P, N>>`, which is
/// the `CubeType::ExpandType` of `Vector`.
type VectorExpand<P, N> = NativeExpand<Vector<P, N>>;
18
19impl<P: Scalar, N: Size> Clone for Vector<P, N> {
20 fn clone(&self) -> Self {
21 *self
22 }
23}
24impl<P: Scalar, N: Size> Eq for Vector<P, N> {}
25impl<P: Scalar, N: Size> Copy for Vector<P, N> {}
26impl<P: Scalar + Neg<Output = P>, N: Size> Neg for Vector<P, N> {
27 type Output = Self;
28
29 fn neg(self) -> Self::Output {
30 Self {
31 val: -self.val,
32 _size: PhantomData,
33 }
34 }
35}
36
37mod new {
39 use cubecl_ir::VectorSize;
40 use cubecl_macros::comptime_type;
41
42 use crate::prelude::Cast;
43
44 use super::*;
45
46 impl<P: Scalar, N: Size> Vector<P, N> {
47 #[allow(unused_variables)]
49 pub fn new(val: P) -> Self {
50 Self {
51 val,
52 _size: PhantomData,
53 }
54 }
55
56 pub fn __expand_new(scope: &mut Scope, val: NativeExpand<P>) -> VectorExpand<P, N> {
57 Vector::<P, N>::__expand_cast_from(scope, val)
58 }
59 }
60
61 impl<P: Scalar, N: Size> Vector<P, N> {
62 pub fn vector_size(&self) -> comptime_type!(VectorSize) {
64 N::value()
65 }
66 }
67}
68
69mod numeric {
70 use super::*;
71
72 #[cube]
73 impl<P: Numeric, N: Size> Vector<P, N> {
74 pub fn min_value() -> Self {
75 Self::new(P::min_value())
76 }
77 pub fn max_value() -> Self {
78 Self::new(P::max_value())
79 }
80
81 pub fn from_int(val: i64) -> Self {
90 Self::new(P::from_int(val))
91 }
92 }
93}
94
95mod fill {
97 use crate::prelude::cast;
98
99 use super::*;
100
101 #[cube]
102 impl<P: Scalar, N: Size> Vector<P, N> {
103 #[allow(unused_variables)]
114 pub fn fill(self, value: P) -> Self {
115 intrinsic!(|scope| {
116 let output = scope.create_local(Vector::<P, N>::as_type(scope));
117
118 cast::expand::<P, Vector<P, N>>(scope, value, output.clone().into());
119
120 output.into()
121 })
122 }
123 }
124}
125
126mod empty {
128 use bytemuck::Zeroable;
129
130 use super::*;
131
132 #[cube]
133 impl<P: Scalar, N: Size> Vector<P, N> {
134 pub fn empty() -> Self {
135 intrinsic!(|scope| {
136 let value = Self::__expand_default(scope);
137 value.into_mut(scope)
138 })
139 }
140 }
141
142 #[cube]
143 impl<P: Scalar + Zeroable, N: Size> Vector<P, N> {
144 pub fn zeroed() -> Self {
145 intrinsic!(|scope| {
146 let zeroed = P::zeroed().__expand_runtime_method(scope);
147 Self::__expand_cast_from(scope, zeroed)
148 })
149 }
150 }
151}
152
153mod size {
155 use cubecl_ir::VectorSize;
156
157 use super::*;
158
159 impl<P: Scalar, N: Size> Vector<P, N> {
160 pub fn size(&self) -> VectorSize {
171 N::value()
172 }
173
174 pub fn __expand_size(scope: &mut Scope, element: NativeExpand<Vector<P, N>>) -> VectorSize {
176 element.__expand_vector_size_method(scope)
177 }
178 }
179
180 impl<P: Scalar, N: Size> NativeExpand<Vector<P, N>> {
181 pub fn size(&self) -> VectorSize {
183 self.expand.ty.vector_size()
184 }
185
186 pub fn __expand_size_method(&self, _scope: &mut Scope) -> VectorSize {
188 self.size()
189 }
190 }
191}
192
// Generates a module `$name` containing the element-wise comparison method `Vector::$name`.
// The method registers a `Comparison::$operator` instruction on the two operands and
// returns a `Vector<bool, N>`. `$comment` is spliced into the generated method's docs.
macro_rules! impl_vector_comparison {
    ($name:ident, $operator:ident, $comment:literal) => {
        ::paste::paste! {
            mod $name {
                use super::*;

                #[cube]
                impl<P: Scalar, N: Size> Vector<P, N> {
                    #[doc = concat!(
                        "Return a new vector with the element-wise comparison of the first vector being ",
                        $comment,
                        " the second vector."
                    )]
                    #[allow(unused_variables)]
                    pub fn $name(self, other: Self) -> Vector<bool, N> {
                        intrinsic!(|scope| {
                            // The previously computed (and never used) `vector_size()` local
                            // was dead code and has been removed.
                            let lhs = self.expand.into();
                            let rhs = other.expand.into();

                            // The result keeps the lane count `N` but has boolean lanes.
                            let output = scope.create_local_mut(Vector::<bool, N>::as_type(scope));

                            scope.register(Instruction::new(
                                Comparison::$operator(BinaryOperator { lhs, rhs }),
                                output.clone().into(),
                            ));

                            output.into()
                        })
                    }
                }
            }
        }
    };
}
232
233impl_vector_comparison!(equal, Equal, "equal to");
234impl_vector_comparison!(not_equal, NotEqual, "not equal to");
235impl_vector_comparison!(less_than, Lower, "less than");
236impl_vector_comparison!(greater_than, Greater, "greater than");
237impl_vector_comparison!(less_equal, LowerEqual, "less than or equal to");
238impl_vector_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
239
240mod bool_and {
241 use cubecl_ir::Operator;
242
243 use crate::prelude::binary_expand;
244
245 use super::*;
246
247 #[cube]
248 impl<N: Size> Vector<bool, N> {
249 #[allow(unused_variables)]
251 pub fn and(self, other: Self) -> Vector<bool, N> {
252 intrinsic!(
253 |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
254 )
255 }
256 }
257}
258
259mod bool_or {
260 use cubecl_ir::Operator;
261
262 use crate::prelude::binary_expand;
263
264 use super::*;
265
266 #[cube]
267 impl<N: Size> Vector<bool, N> {
268 #[allow(unused_variables)]
270 pub fn or(self, other: Self) -> Vector<bool, N> {
271 intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
272 }
273 }
274}
275
276impl<P: Scalar, N: Size> CubeType for Vector<P, N> {
277 type ExpandType = NativeExpand<Self>;
278}
279
280impl<P: Scalar, N: Size> CubeType for &Vector<P, N> {
281 type ExpandType = NativeExpand<Vector<P, N>>;
282}
283
284impl<P: Scalar, N: Size> CubeType for &mut Vector<P, N> {
285 type ExpandType = NativeExpand<Vector<P, N>>;
286}
287
288impl<P: Scalar, N: Size> NativeAssign for Vector<P, N> {
289 fn elem_init_mut(scope: &mut crate::ir::Scope, elem: ManagedVariable) -> ManagedVariable {
290 P::elem_init_mut(scope, elem)
291 }
292}
293
294impl<P: Scalar, N: Size> CubePrimitive for Vector<P, N> {
295 type Scalar = P;
296 type Size = N;
297 type WithScalar<S: Scalar> = Vector<S, N>;
298
299 fn as_type(scope: &Scope) -> Type {
300 P::as_type(scope).with_vector_size(N::__expand_value(scope))
301 }
302
303 fn as_type_native() -> Option<Type> {
304 P::as_type_native().and_then(|ty| {
305 let vector_size = N::try_value_const()?;
306 Some(ty.with_vector_size(vector_size))
307 })
308 }
309
310 fn from_const_value(value: ConstantValue) -> Self {
311 Self::new(P::from_const_value(value))
312 }
313}
314
315impl<T: Dot + Scalar, N: Size> Dot for Vector<T, N> {}
316impl<T: MulHi + Scalar, N: Size> MulHi for Vector<T, N> {}
317impl<T: FloatOps + Scalar, N: Size> FloatOps for Vector<T, N> {}
318impl<T: Hypot + Scalar, N: Size> Hypot for Vector<T, N> {}
319impl<T: Rhypot + Scalar, N: Size> Rhypot for Vector<T, N> {}