Skip to main content

cubecl_core/frontend/container/tensor/
base.rs

1use crate::{
2    frontend::{CubePrimitive, CubeType, NativeExpand, SizedContainer},
3    ir::{Metadata, Scope},
4    prelude::*,
5    unexpanded,
6};
7use core::{
8    marker::PhantomData,
9    ops::{Deref, DerefMut},
10};
11use cubecl_ir::VectorSize;
12use cubecl_macros::{cube, intrinsic};
13
14use crate as cubecl;
15
/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more
/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape).
///
/// The struct itself stores no data: its only field is a zero-sized `PhantomData<T>` marker for
/// the element type. Inside `#[cube]` code a tensor is represented by its expand type,
/// `NativeExpand<Tensor<T>>`.
#[derive(new, Clone, Copy)]
pub struct Tensor<T: CubeType> {
    // Zero-sized marker tying the tensor to its element type `T`.
    _val: PhantomData<T>,
}

// Shorthand for the expand-time representation of a tensor.
// NOTE(review): this alias is not referenced by name anywhere in this file; presumably the
// `#[cube]` impl blocks below rely on a `<Type>Expand` naming convention — confirm before removing.
type TensorExpand<T> = NativeExpand<Tensor<T>>;
24
25/// Module that contains the implementation details of the metadata functions.
26mod metadata {
27    use cubecl_ir::ManagedVariable;
28
29    use super::*;
30    use crate::{
31        ir::{Arithmetic, BinaryOperator, Instruction},
32        prelude::Array,
33    };
34
35    #[cube]
36    impl<T: CubeType> Tensor<T> {
37        /// Obtain the stride of input at dimension dim
38        #[allow(unused_variables)]
39        pub fn stride(&self, dim: usize) -> usize {
40            intrinsic!(|scope| {
41                let dim: ManagedVariable = dim.into();
42                let out = scope.create_local(usize::as_type(scope));
43                scope.register(Instruction::new(
44                    Metadata::Stride {
45                        dim: *dim,
46                        var: self.expand.into(),
47                    },
48                    out.clone().into(),
49                ));
50                out.into()
51            })
52        }
53
54        /// Obtain the shape of input at dimension dim
55        #[allow(unused_variables)]
56        pub fn shape(&self, dim: usize) -> usize {
57            intrinsic!(|scope| {
58                let dim: ManagedVariable = dim.into();
59                let out = scope.create_local(usize::as_type(scope));
60                scope.register(Instruction::new(
61                    Metadata::Shape {
62                        dim: *dim,
63                        var: self.expand.into(),
64                    },
65                    out.clone().into(),
66                ));
67                out.into()
68            })
69        }
70
71        /// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
72        ///
73        /// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
74        /// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
75        #[allow(unused_variables)]
76        pub fn coordinate(&self, index: usize, dim: usize) -> usize {
77            intrinsic!(|scope| {
78                let index: ManagedVariable = index.into();
79                let stride = self.clone().__expand_stride_method(scope, dim.clone());
80                let shape = self.clone().__expand_shape_method(scope, dim.clone());
81
82                // Compute `num_strides = index / stride`.
83                let num_strides = scope.create_local(usize::as_type(scope));
84                scope.register(Instruction::new(
85                    Arithmetic::Div(BinaryOperator {
86                        lhs: *index,
87                        rhs: stride.expand.into(),
88                    }),
89                    num_strides.clone().into(),
90                ));
91
92                // Compute `coordinate = num_strides % shape `.
93                let coordinate = scope.create_local(usize::as_type(scope));
94                scope.register(Instruction::new(
95                    Arithmetic::Modulo(BinaryOperator {
96                        lhs: *num_strides,
97                        rhs: shape.expand.into(),
98                    }),
99                    coordinate.clone().into(),
100                ));
101
102                coordinate.into()
103            })
104        }
105
106        /// The number of vectorized elements in the tensor.
107        ///
108        /// # Warning
109        ///
110        /// The length will be affected by the vectorization factor. To obtain the number of elements,
111        /// you should multiply the length by the vectorization factor.
112        #[allow(clippy::len_without_is_empty)]
113        pub fn len(&self) -> usize {
114            intrinsic!(|scope| {
115                let elem: NativeExpand<Array<u32>> = self.expand.into();
116                elem.__expand_len_method(scope)
117            })
118        }
119
120        /// The length of the buffer representing the tensor in terms of vectorized elements.
121        ///
122        /// # Warning
123        ///
124        /// The buffer length will be affected by the vectorization factor. To obtain the number of
125        /// elements, you should multiply the length by the vectorization factor.
126        #[allow(clippy::len_without_is_empty)]
127        pub fn buffer_len(&self) -> usize {
128            intrinsic!(|scope| {
129                let elem: NativeExpand<Array<u32>> = self.expand.into();
130                elem.__expand_buffer_len_method(scope)
131            })
132        }
133
134        /// Returns the rank of the tensor.
135        pub fn rank(&self) -> usize {
136            intrinsic!(|scope| {
137                let out = scope.create_local(usize::as_type(scope));
138                scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
139                out.into()
140            })
141        }
142    }
143}
144
/// Module that contains the implementation details of the index functions.
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::ir::Instruction;

    use super::*;

    #[cube]
    impl<E: CubePrimitive> Tensor<E> {
        /// Perform an unchecked index into the tensor, reading the element at index `i`.
        ///
        /// At expand time this registers an `Operator::UncheckedIndex` instruction whose result
        /// is written to a fresh local of the tensor variable's type.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: usize) -> &E {
            intrinsic!(|scope| {
                // The result local reuses the tensor variable's type.
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): `0` presumably means "no explicit vector size override";
                        // confirm against the `IndexOperator` definition.
                        vector_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Perform an unchecked index assignment into the tensor, writing `value` at index `i`.
        ///
        /// At expand time this registers an `Operator::UncheckedIndexAssign` instruction whose
        /// output variable is the tensor itself.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): same `vector_size: 0` convention as `index_unchecked`;
                        // confirm its meaning against `IndexAssignOperator`.
                        vector_size: 0,
                        unroll_factor: 1,
                    }),
                    // The tensor being written to is the instruction's output.
                    *self.expand,
                ));
            })
        }
    }
}
198
199/// Module that contains the implementation details of the `vector_size` function.
200mod vector {
201    use super::*;
202
203    impl<P: Scalar, N: Size> Tensor<Vector<P, N>> {
204        /// Get the size of each vector contained in the tensor.
205        ///
206        /// Same as the following:
207        ///
208        /// ```rust, ignore
209        /// let size = tensor[0].size();
210        /// ```
211        pub fn vector_size(&self) -> VectorSize {
212            N::value()
213        }
214
215        // Expand function of [size](Tensor::vector_size).
216        pub fn __expand_vector_size(
217            expand: <Self as CubeType>::ExpandType,
218            scope: &mut Scope,
219        ) -> VectorSize {
220            expand.__expand_vector_size_method(scope)
221        }
222    }
223}
224
// A tensor is a sized container whose items are its elements `T`.
impl<T: CubePrimitive> SizedContainer for Tensor<T> {
    type Item = T;
}

// Host-side placeholder: iterating `&Tensor<T>` has no runtime meaning outside expansion, so
// `next` is `unexpanded!()`. Presumably this exists so iteration syntax over a tensor
// type-checks in `#[cube]` functions before they are expanded — confirm.
impl<T: CubeType> Iterator for &Tensor<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
236
237impl<T: CubeType> CubeType for Tensor<T> {
238    type ExpandType = NativeExpand<Tensor<T>>;
239}
240
241impl<T: CubeType> CubeType for *const Tensor<T> {
242    type ExpandType = NativeExpand<Tensor<T>>;
243}
244
245impl<T: CubeType> CubeType for *mut Tensor<T> {
246    type ExpandType = NativeExpand<Tensor<T>>;
247}
248
249impl<T: CubeType> CubeType for &mut Tensor<T> {
250    type ExpandType = NativeExpand<Tensor<T>>;
251}
252
253impl<T: CubeType> CubeType for &Tensor<T> {
254    type ExpandType = NativeExpand<Tensor<T>>;
255}
256
// Turning a tensor expansion into its mutable form is a no-op: the same expand value is returned
// unchanged and the scope is not touched.
impl<C: CubeType> IntoMut for NativeExpand<Tensor<C>> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
262
// Read access for tensors, routed through the generic `index` expansion.
impl<T: CubePrimitive> List<T> for Tensor<T> {
    /// Expands a read of element `idx` from `this` via `index::expand`.
    fn __expand_read(
        scope: &mut Scope,
        this: NativeExpand<Tensor<T>>,
        idx: NativeExpand<usize>,
    ) -> NativeExpand<T> {
        index::expand(scope, this, idx)
    }
}
272
// Host-side placeholders: viewing a tensor as a slice has no runtime meaning outside expansion,
// so both bodies are `unexpanded!()`. Presumably these exist so slice-style syntax on a tensor
// type-checks inside `#[cube]` functions before expansion — confirm.
impl<T: CubePrimitive> Deref for Tensor<T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}

impl<T: CubePrimitive> DerefMut for Tensor<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
286
// Method-style list access on an expanded tensor.
impl<T: CubePrimitive> ListExpand<T> for NativeExpand<Tensor<T>> {
    /// Expands a read of element `idx` via `index::expand`.
    fn __expand_read_method(&self, scope: &mut Scope, idx: NativeExpand<usize>) -> NativeExpand<T> {
        index::expand(scope, self.clone(), idx)
    }
    /// Expands a read of element `idx` via `index_unchecked::expand`; the caller is responsible
    /// for keeping `idx` in bounds.
    fn __expand_read_unchecked_method(
        &self,
        scope: &mut Scope,
        idx: NativeExpand<usize>,
    ) -> NativeExpand<T> {
        index_unchecked::expand(scope, self.clone(), idx)
    }

    /// Expands the tensor's length query by delegating to `Self::__expand_len`.
    fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
        Self::__expand_len(scope, self.clone())
    }
}
303
// Marker impl: tensors participate in vectorization.
impl<T: CubePrimitive> Vectorized for Tensor<T> {}
impl<T: CubePrimitive> VectorizedExpand for NativeExpand<Tensor<T>> {
    /// Reads the vector size from the type of the expanded variable (`self.expand.ty`).
    fn vector_size(&self) -> VectorSize {
        self.expand.ty.vector_size()
    }
}
310
// Write access for tensors, routed through the generic `index_assign` expansion.
impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
    /// Expands a write of `value` to element `idx` of `this` via `index_assign::expand`.
    fn __expand_write(
        scope: &mut Scope,
        this: NativeExpand<Tensor<T>>,
        idx: NativeExpand<usize>,
        value: NativeExpand<T>,
    ) {
        index_assign::expand(scope, this, idx, value);
    }
}

// Method-style write access on an expanded tensor; same expansion as `ListMut::__expand_write`.
impl<T: CubePrimitive> ListMutExpand<T> for NativeExpand<Tensor<T>> {
    /// Expands a write of `value` to element `idx` via `index_assign::expand`.
    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        idx: NativeExpand<usize>,
        value: NativeExpand<T>,
    ) {
        index_assign::expand(scope, self.clone(), idx, value);
    }
}