cubecl_core/frontend/container/tensor/base.rs

1use crate::{
2    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
3    ir::{Item, Metadata, Scope},
4    prelude::{
5        Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
6    },
7    unexpanded,
8};
9use cubecl_ir::ExpandElement;
10use cubecl_macros::{cube, intrinsic};
11use std::{marker::PhantomData, num::NonZero};
12
13use crate as cubecl;
14
/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more
/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape).
///
/// The struct holds no data of its own: its only field is a `PhantomData` marker for the
/// element type `T`.
#[derive(new)]
pub struct Tensor<T: CubeType> {
    // Zero-sized marker tying the struct to its element type.
    _val: PhantomData<T>,
}

/// Expand-time handle of a [`Tensor`]; the same type as the `CubeType::ExpandType` declared
/// for `Tensor<T>` in this file.
// NOTE(review): presumably code generated by `#[cube]` on the `impl Tensor<T>` blocks refers to
// this `TensorExpand` name by convention — confirm before renaming or removing it.
type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
23
24/// Module that contains the implementation details of the metadata functions.
25mod metadata {
26    use cubecl_ir::ExpandElement;
27
28    use super::*;
29    use crate::{
30        ir::{Arithmetic, BinaryOperator, Instruction},
31        prelude::Array,
32    };
33
34    #[cube]
35    impl<T: CubeType> Tensor<T> {
36        /// Obtain the stride of input at dimension dim
37        #[allow(unused_variables)]
38        pub fn stride(&self, dim: u32) -> u32 {
39            intrinsic!(|scope| {
40                let dim: ExpandElement = dim.into();
41                let out = scope.create_local(Item::new(u32::as_elem(scope)));
42                scope.register(Instruction::new(
43                    Metadata::Stride {
44                        dim: *dim,
45                        var: self.expand.into(),
46                    },
47                    out.clone().into(),
48                ));
49                out.into()
50            })
51        }
52
53        /// Obtain the shape of input at dimension dim
54        #[allow(unused_variables)]
55        pub fn shape(&self, dim: u32) -> u32 {
56            intrinsic!(|scope| {
57                let dim: ExpandElement = dim.into();
58                let out = scope.create_local(Item::new(u32::as_elem(scope)));
59                scope.register(Instruction::new(
60                    Metadata::Shape {
61                        dim: *dim,
62                        var: self.expand.into(),
63                    },
64                    out.clone().into(),
65                ));
66                out.into()
67            })
68        }
69
70        /// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
71        ///
72        /// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
73        /// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
74        #[allow(unused_variables)]
75        pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
76            intrinsic!(|scope| {
77                let index: ExpandElement = index.into();
78                let stride = self.clone().__expand_stride_method(scope, dim.clone());
79                let shape = self.clone().__expand_shape_method(scope, dim.clone());
80
81                // Compute `num_strides = index / stride`.
82                let num_strides = scope.create_local(Item::new(u32::as_elem(scope)));
83                scope.register(Instruction::new(
84                    Arithmetic::Div(BinaryOperator {
85                        lhs: *index,
86                        rhs: stride.expand.into(),
87                    }),
88                    num_strides.clone().into(),
89                ));
90
91                // Compute `coordinate = num_strides % shape `.
92                let coordinate = scope.create_local(Item::new(u32::as_elem(scope)));
93                scope.register(Instruction::new(
94                    Arithmetic::Modulo(BinaryOperator {
95                        lhs: *num_strides,
96                        rhs: shape.expand.into(),
97                    }),
98                    coordinate.clone().into(),
99                ));
100
101                coordinate.into()
102            })
103        }
104
105        /// The number of vectorized elements in the tensor.
106        ///
107        /// # Warning
108        ///
109        /// The length will be affected by the vectorization factor. To obtain the number of elements,
110        /// you should multiply the length by the vectorization factor.
111        #[allow(clippy::len_without_is_empty)]
112        pub fn len(&self) -> u32 {
113            intrinsic!(|scope| {
114                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
115                elem.__expand_len_method(scope)
116            })
117        }
118
119        /// The length of the buffer representing the tensor in terms of vectorized elements.
120        ///
121        /// # Warning
122        ///
123        /// The buffer length will be affected by the vectorization factor. To obtain the number of
124        /// elements, you should multiply the length by the vectorization factor.
125        #[allow(clippy::len_without_is_empty)]
126        pub fn buffer_len(&self) -> u32 {
127            intrinsic!(|scope| {
128                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
129                elem.__expand_buffer_len_method(scope)
130            })
131        }
132
133        /// Returns the rank of the tensor.
134        pub fn rank(&self) -> u32 {
135            intrinsic!(|scope| {
136                let out = scope.create_local(Item::new(u32::as_elem(scope)));
137                scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
138                out.into()
139            })
140        }
141    }
142}
143
144/// Module that contains the implementation details of the index functions.
145mod indexation {
146    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
147
148    use crate::{
149        ir::Instruction,
150        prelude::{CubeIndex, CubeIndexMut},
151    };
152
153    use super::*;
154
155    #[cube]
156    impl<E: CubePrimitive> Tensor<E> {
157        /// Perform an unchecked index into the array
158        ///
159        /// # Safety
160        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
161        /// always in bounds
162        #[allow(unused_variables)]
163        pub unsafe fn index_unchecked(&self, i: u32) -> &E
164        where
165            Self: CubeIndex,
166        {
167            intrinsic!(|scope| {
168                let out = scope.create_local(self.expand.item);
169                scope.register(Instruction::new(
170                    Operator::UncheckedIndex(IndexOperator {
171                        list: *self.expand,
172                        index: i.expand.consume(),
173                        line_size: 0,
174                    }),
175                    *out,
176                ));
177                out.into()
178            })
179        }
180
181        /// Perform an unchecked index assignment into the array
182        ///
183        /// # Safety
184        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
185        /// always in bounds
186        #[allow(unused_variables)]
187        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
188        where
189            Self: CubeIndexMut,
190        {
191            intrinsic!(|scope| {
192                scope.register(Instruction::new(
193                    Operator::UncheckedIndexAssign(IndexAssignOperator {
194                        index: i.expand.consume(),
195                        value: value.expand.consume(),
196                        line_size: 0,
197                    }),
198                    *self.expand,
199                ));
200            })
201        }
202    }
203}
204
205/// Module that contains the implementation details of the line_size function.
206mod line {
207    use super::*;
208
209    impl<P: CubePrimitive> Tensor<Line<P>> {
210        /// Get the size of each line contained in the tensor.
211        ///
212        /// Same as the following:
213        ///
214        /// ```rust, ignore
215        /// let size = tensor[0].size();
216        /// ```
217        pub fn line_size(&self) -> u32 {
218            unexpanded!()
219        }
220
221        // Expand function of [size](Tensor::line_size).
222        pub fn __expand_line_size(
223            expand: <Self as CubeType>::ExpandType,
224            scope: &mut Scope,
225        ) -> u32 {
226            expand.__expand_line_size_method(scope)
227        }
228    }
229
230    impl<P: CubePrimitive> ExpandElementTyped<Tensor<Line<P>>> {
231        /// Comptime version of [size](Tensor::line_size).
232        pub fn line_size(&self) -> u32 {
233            self.expand
234                .item
235                .vectorization
236                .unwrap_or(NonZero::new(1).unwrap())
237                .get() as u32
238        }
239
240        // Expand method of [size](Tensor::line_size).
241        pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
242            self.line_size()
243        }
244    }
245}
246
247impl<T: CubePrimitive> SizedContainer for Tensor<T> {
248    type Item = T;
249}
250
251impl<T: CubeType> Iterator for &Tensor<T> {
252    type Item = T;
253
254    fn next(&mut self) -> Option<Self::Item> {
255        unexpanded!()
256    }
257}
258
259impl<T: CubeType> CubeType for Tensor<T> {
260    type ExpandType = ExpandElementTyped<Tensor<T>>;
261}
262
263impl<T: CubeType> CubeType for *const Tensor<T> {
264    type ExpandType = ExpandElementTyped<Tensor<T>>;
265}
266
267impl<T: CubeType> CubeType for *mut Tensor<T> {
268    type ExpandType = ExpandElementTyped<Tensor<T>>;
269}
270
271impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
272    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
273        elem
274    }
275}
276
277impl<T: CubePrimitive> List<T> for Tensor<T> {
278    fn __expand_read(
279        scope: &mut Scope,
280        this: ExpandElementTyped<Tensor<T>>,
281        idx: ExpandElementTyped<u32>,
282    ) -> ExpandElementTyped<T> {
283        index::expand(scope, this, idx)
284    }
285}
286
287impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
288    fn __expand_read_method(
289        &self,
290        scope: &mut Scope,
291        idx: ExpandElementTyped<u32>,
292    ) -> ExpandElementTyped<T> {
293        index::expand(scope, self.clone(), idx)
294    }
295    fn __expand_read_unchecked_method(
296        &self,
297        scope: &mut Scope,
298        idx: ExpandElementTyped<u32>,
299    ) -> ExpandElementTyped<T> {
300        index_unchecked::expand(scope, self.clone(), idx)
301    }
302}
303
304impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
305    fn __expand_write(
306        scope: &mut Scope,
307        this: ExpandElementTyped<Tensor<T>>,
308        idx: ExpandElementTyped<u32>,
309        value: ExpandElementTyped<T>,
310    ) {
311        index_assign::expand(scope, this, idx, value);
312    }
313}
314
315impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
316    fn __expand_write_method(
317        &self,
318        scope: &mut Scope,
319        idx: ExpandElementTyped<u32>,
320        value: ExpandElementTyped<T>,
321    ) {
322        index_assign::expand(scope, self.clone(), idx, value);
323    }
324}