cubecl_core/frontend/container/tensor/base.rs

1use crate::{
2    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
3    ir::{Metadata, Scope, Type},
4    prelude::{
5        Line, Lined, LinedExpand, List, ListExpand, ListMut, ListMutExpand, index, index_assign,
6        index_unchecked,
7    },
8    unexpanded,
9};
10use cubecl_ir::ExpandElement;
11use cubecl_macros::{cube, intrinsic};
12use std::marker::PhantomData;
13
14use crate as cubecl;
15
16/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more
17/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape).
18#[derive(new, Clone, Copy)]
19pub struct Tensor<T: CubeType> {
20    _val: PhantomData<T>,
21}
22
23type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
24
25/// Module that contains the implementation details of the metadata functions.
26mod metadata {
27    use cubecl_ir::ExpandElement;
28
29    use super::*;
30    use crate::{
31        ir::{Arithmetic, BinaryOperator, Instruction},
32        prelude::Array,
33    };
34
35    #[cube]
36    impl<T: CubeType> Tensor<T> {
37        /// Obtain the stride of input at dimension dim
38        #[allow(unused_variables)]
39        pub fn stride(&self, dim: u32) -> u32 {
40            intrinsic!(|scope| {
41                let dim: ExpandElement = dim.into();
42                let out = scope.create_local(Type::new(u32::as_type(scope)));
43                scope.register(Instruction::new(
44                    Metadata::Stride {
45                        dim: *dim,
46                        var: self.expand.into(),
47                    },
48                    out.clone().into(),
49                ));
50                out.into()
51            })
52        }
53
54        /// Obtain the shape of input at dimension dim
55        #[allow(unused_variables)]
56        pub fn shape(&self, dim: u32) -> u32 {
57            intrinsic!(|scope| {
58                let dim: ExpandElement = dim.into();
59                let out = scope.create_local(Type::new(u32::as_type(scope)));
60                scope.register(Instruction::new(
61                    Metadata::Shape {
62                        dim: *dim,
63                        var: self.expand.into(),
64                    },
65                    out.clone().into(),
66                ));
67                out.into()
68            })
69        }
70
71        /// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
72        ///
73        /// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
74        /// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
75        #[allow(unused_variables)]
76        pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
77            intrinsic!(|scope| {
78                let index: ExpandElement = index.into();
79                let stride = self.clone().__expand_stride_method(scope, dim.clone());
80                let shape = self.clone().__expand_shape_method(scope, dim.clone());
81
82                // Compute `num_strides = index / stride`.
83                let num_strides = scope.create_local(Type::new(u32::as_type(scope)));
84                scope.register(Instruction::new(
85                    Arithmetic::Div(BinaryOperator {
86                        lhs: *index,
87                        rhs: stride.expand.into(),
88                    }),
89                    num_strides.clone().into(),
90                ));
91
92                // Compute `coordinate = num_strides % shape `.
93                let coordinate = scope.create_local(Type::new(u32::as_type(scope)));
94                scope.register(Instruction::new(
95                    Arithmetic::Modulo(BinaryOperator {
96                        lhs: *num_strides,
97                        rhs: shape.expand.into(),
98                    }),
99                    coordinate.clone().into(),
100                ));
101
102                coordinate.into()
103            })
104        }
105
106        /// The number of vectorized elements in the tensor.
107        ///
108        /// # Warning
109        ///
110        /// The length will be affected by the vectorization factor. To obtain the number of elements,
111        /// you should multiply the length by the vectorization factor.
112        #[allow(clippy::len_without_is_empty)]
113        pub fn len(&self) -> u32 {
114            intrinsic!(|scope| {
115                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
116                elem.__expand_len_method(scope)
117            })
118        }
119
120        /// The length of the buffer representing the tensor in terms of vectorized elements.
121        ///
122        /// # Warning
123        ///
124        /// The buffer length will be affected by the vectorization factor. To obtain the number of
125        /// elements, you should multiply the length by the vectorization factor.
126        #[allow(clippy::len_without_is_empty)]
127        pub fn buffer_len(&self) -> u32 {
128            intrinsic!(|scope| {
129                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
130                elem.__expand_buffer_len_method(scope)
131            })
132        }
133
134        /// Returns the rank of the tensor.
135        pub fn rank(&self) -> u32 {
136            intrinsic!(|scope| {
137                let out = scope.create_local(Type::new(u32::as_type(scope)));
138                scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
139                out.into()
140            })
141        }
142    }
143}
144
145/// Module that contains the implementation details of the index functions.
146mod indexation {
147    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
148
149    use crate::{
150        ir::Instruction,
151        prelude::{CubeIndex, CubeIndexMut},
152    };
153
154    use super::*;
155
156    #[cube]
157    impl<E: CubePrimitive> Tensor<E> {
158        /// Perform an unchecked index into the array
159        ///
160        /// # Safety
161        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
162        /// always in bounds
163        #[allow(unused_variables)]
164        pub unsafe fn index_unchecked(&self, i: u32) -> &E
165        where
166            Self: CubeIndex,
167        {
168            intrinsic!(|scope| {
169                let out = scope.create_local(self.expand.ty);
170                scope.register(Instruction::new(
171                    Operator::UncheckedIndex(IndexOperator {
172                        list: *self.expand,
173                        index: i.expand.consume(),
174                        line_size: 0,
175                        unroll_factor: 1,
176                    }),
177                    *out,
178                ));
179                out.into()
180            })
181        }
182
183        /// Perform an unchecked index assignment into the array
184        ///
185        /// # Safety
186        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
187        /// always in bounds
188        #[allow(unused_variables)]
189        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
190        where
191            Self: CubeIndexMut,
192        {
193            intrinsic!(|scope| {
194                scope.register(Instruction::new(
195                    Operator::UncheckedIndexAssign(IndexAssignOperator {
196                        index: i.expand.consume(),
197                        value: value.expand.consume(),
198                        line_size: 0,
199                        unroll_factor: 1,
200                    }),
201                    *self.expand,
202                ));
203            })
204        }
205    }
206}
207
208/// Module that contains the implementation details of the line_size function.
209mod line {
210    use super::*;
211
212    impl<P: CubePrimitive> Tensor<Line<P>> {
213        /// Get the size of each line contained in the tensor.
214        ///
215        /// Same as the following:
216        ///
217        /// ```rust, ignore
218        /// let size = tensor[0].size();
219        /// ```
220        pub fn line_size(&self) -> u32 {
221            unexpanded!()
222        }
223
224        // Expand function of [size](Tensor::line_size).
225        pub fn __expand_line_size(
226            expand: <Self as CubeType>::ExpandType,
227            scope: &mut Scope,
228        ) -> u32 {
229            expand.__expand_line_size_method(scope)
230        }
231    }
232}
233
234impl<T: CubePrimitive> SizedContainer for Tensor<T> {
235    type Item = T;
236}
237
238impl<T: CubeType> Iterator for &Tensor<T> {
239    type Item = T;
240
241    fn next(&mut self) -> Option<Self::Item> {
242        unexpanded!()
243    }
244}
245
246impl<T: CubeType> CubeType for Tensor<T> {
247    type ExpandType = ExpandElementTyped<Tensor<T>>;
248}
249
250impl<T: CubeType> CubeType for *const Tensor<T> {
251    type ExpandType = ExpandElementTyped<Tensor<T>>;
252}
253
254impl<T: CubeType> CubeType for *mut Tensor<T> {
255    type ExpandType = ExpandElementTyped<Tensor<T>>;
256}
257
258impl<T: CubeType> CubeType for &mut Tensor<T> {
259    type ExpandType = ExpandElementTyped<Tensor<T>>;
260}
261
262impl<T: CubeType> CubeType for &Tensor<T> {
263    type ExpandType = ExpandElementTyped<Tensor<T>>;
264}
265
266impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
267    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
268        elem
269    }
270}
271
272impl<T: CubePrimitive> List<T> for Tensor<T> {
273    fn __expand_read(
274        scope: &mut Scope,
275        this: ExpandElementTyped<Tensor<T>>,
276        idx: ExpandElementTyped<u32>,
277    ) -> ExpandElementTyped<T> {
278        index::expand(scope, this, idx)
279    }
280}
281
282impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
283    fn __expand_read_method(
284        &self,
285        scope: &mut Scope,
286        idx: ExpandElementTyped<u32>,
287    ) -> ExpandElementTyped<T> {
288        index::expand(scope, self.clone(), idx)
289    }
290    fn __expand_read_unchecked_method(
291        &self,
292        scope: &mut Scope,
293        idx: ExpandElementTyped<u32>,
294    ) -> ExpandElementTyped<T> {
295        index_unchecked::expand(scope, self.clone(), idx)
296    }
297
298    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
299        Self::__expand_len(scope, self.clone())
300    }
301}
302
303impl<T: CubePrimitive> Lined for Tensor<T> {}
304impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Tensor<T>> {
305    fn line_size(&self) -> u32 {
306        self.expand.ty.line_size()
307    }
308}
309
310impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
311    fn __expand_write(
312        scope: &mut Scope,
313        this: ExpandElementTyped<Tensor<T>>,
314        idx: ExpandElementTyped<u32>,
315        value: ExpandElementTyped<T>,
316    ) {
317        index_assign::expand(scope, this, idx, value);
318    }
319}
320
321impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
322    fn __expand_write_method(
323        &self,
324        scope: &mut Scope,
325        idx: ExpandElementTyped<u32>,
326        value: ExpandElementTyped<T>,
327    ) {
328        index_assign::expand(scope, self.clone(), idx, value);
329    }
330}