Skip to main content

cubecl_core/frontend/container/tensor/
base.rs

1use crate::{
2    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
3    ir::{Metadata, Scope, Type},
4    prelude::{
5        Line, Lined, LinedExpand, List, ListExpand, ListMut, ListMutExpand, index, index_assign,
6        index_unchecked,
7    },
8    unexpanded,
9};
10use cubecl_ir::{ExpandElement, LineSize};
11use cubecl_macros::{cube, intrinsic};
12use std::{
13    marker::PhantomData,
14    ops::{Deref, DerefMut},
15};
16
17use crate as cubecl;
18
19/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more
20/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape).
21#[derive(new, Clone, Copy)]
22pub struct Tensor<T: CubeType> {
23    _val: PhantomData<T>,
24}
25
26type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
27
28/// Module that contains the implementation details of the metadata functions.
29mod metadata {
30    use cubecl_ir::ExpandElement;
31
32    use super::*;
33    use crate::{
34        ir::{Arithmetic, BinaryOperator, Instruction},
35        prelude::Array,
36    };
37
38    #[cube]
39    impl<T: CubeType> Tensor<T> {
40        /// Obtain the stride of input at dimension dim
41        #[allow(unused_variables)]
42        pub fn stride(&self, dim: usize) -> usize {
43            intrinsic!(|scope| {
44                let dim: ExpandElement = dim.into();
45                let out = scope.create_local(Type::new(usize::as_type(scope)));
46                scope.register(Instruction::new(
47                    Metadata::Stride {
48                        dim: *dim,
49                        var: self.expand.into(),
50                    },
51                    out.clone().into(),
52                ));
53                out.into()
54            })
55        }
56
57        /// Obtain the shape of input at dimension dim
58        #[allow(unused_variables)]
59        pub fn shape(&self, dim: usize) -> usize {
60            intrinsic!(|scope| {
61                let dim: ExpandElement = dim.into();
62                let out = scope.create_local(Type::new(usize::as_type(scope)));
63                scope.register(Instruction::new(
64                    Metadata::Shape {
65                        dim: *dim,
66                        var: self.expand.into(),
67                    },
68                    out.clone().into(),
69                ));
70                out.into()
71            })
72        }
73
74        /// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
75        ///
76        /// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
77        /// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
78        #[allow(unused_variables)]
79        pub fn coordinate(&self, index: usize, dim: usize) -> usize {
80            intrinsic!(|scope| {
81                let index: ExpandElement = index.into();
82                let stride = self.clone().__expand_stride_method(scope, dim.clone());
83                let shape = self.clone().__expand_shape_method(scope, dim.clone());
84
85                // Compute `num_strides = index / stride`.
86                let num_strides = scope.create_local(Type::new(usize::as_type(scope)));
87                scope.register(Instruction::new(
88                    Arithmetic::Div(BinaryOperator {
89                        lhs: *index,
90                        rhs: stride.expand.into(),
91                    }),
92                    num_strides.clone().into(),
93                ));
94
95                // Compute `coordinate = num_strides % shape `.
96                let coordinate = scope.create_local(Type::new(usize::as_type(scope)));
97                scope.register(Instruction::new(
98                    Arithmetic::Modulo(BinaryOperator {
99                        lhs: *num_strides,
100                        rhs: shape.expand.into(),
101                    }),
102                    coordinate.clone().into(),
103                ));
104
105                coordinate.into()
106            })
107        }
108
109        /// The number of vectorized elements in the tensor.
110        ///
111        /// # Warning
112        ///
113        /// The length will be affected by the vectorization factor. To obtain the number of elements,
114        /// you should multiply the length by the vectorization factor.
115        #[allow(clippy::len_without_is_empty)]
116        pub fn len(&self) -> usize {
117            intrinsic!(|scope| {
118                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
119                elem.__expand_len_method(scope)
120            })
121        }
122
123        /// The length of the buffer representing the tensor in terms of vectorized elements.
124        ///
125        /// # Warning
126        ///
127        /// The buffer length will be affected by the vectorization factor. To obtain the number of
128        /// elements, you should multiply the length by the vectorization factor.
129        #[allow(clippy::len_without_is_empty)]
130        pub fn buffer_len(&self) -> usize {
131            intrinsic!(|scope| {
132                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
133                elem.__expand_buffer_len_method(scope)
134            })
135        }
136
137        /// Returns the rank of the tensor.
138        pub fn rank(&self) -> usize {
139            intrinsic!(|scope| {
140                let out = scope.create_local(Type::new(usize::as_type(scope)));
141                scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
142                out.into()
143            })
144        }
145    }
146}
147
148/// Module that contains the implementation details of the index functions.
149mod indexation {
150    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
151
152    use crate::{
153        ir::Instruction,
154        prelude::{CubeIndex, CubeIndexMut},
155    };
156
157    use super::*;
158
159    #[cube]
160    impl<E: CubePrimitive> Tensor<E> {
161        /// Perform an unchecked index into the array
162        ///
163        /// # Safety
164        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
165        /// always in bounds
166        #[allow(unused_variables)]
167        pub unsafe fn index_unchecked(&self, i: usize) -> &E
168        where
169            Self: CubeIndex,
170        {
171            intrinsic!(|scope| {
172                let out = scope.create_local(self.expand.ty);
173                scope.register(Instruction::new(
174                    Operator::UncheckedIndex(IndexOperator {
175                        list: *self.expand,
176                        index: i.expand.consume(),
177                        line_size: 0,
178                        unroll_factor: 1,
179                    }),
180                    *out,
181                ));
182                out.into()
183            })
184        }
185
186        /// Perform an unchecked index assignment into the array
187        ///
188        /// # Safety
189        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
190        /// always in bounds
191        #[allow(unused_variables)]
192        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
193        where
194            Self: CubeIndexMut,
195        {
196            intrinsic!(|scope| {
197                scope.register(Instruction::new(
198                    Operator::UncheckedIndexAssign(IndexAssignOperator {
199                        index: i.expand.consume(),
200                        value: value.expand.consume(),
201                        line_size: 0,
202                        unroll_factor: 1,
203                    }),
204                    *self.expand,
205                ));
206            })
207        }
208    }
209}
210
211/// Module that contains the implementation details of the `line_size` function.
212mod line {
213    use super::*;
214
215    impl<P: CubePrimitive> Tensor<Line<P>> {
216        /// Get the size of each line contained in the tensor.
217        ///
218        /// Same as the following:
219        ///
220        /// ```rust, ignore
221        /// let size = tensor[0].size();
222        /// ```
223        pub fn line_size(&self) -> LineSize {
224            unexpanded!()
225        }
226
227        // Expand function of [size](Tensor::line_size).
228        pub fn __expand_line_size(
229            expand: <Self as CubeType>::ExpandType,
230            scope: &mut Scope,
231        ) -> LineSize {
232            expand.__expand_line_size_method(scope)
233        }
234    }
235}
236
237impl<T: CubePrimitive> SizedContainer for Tensor<T> {
238    type Item = T;
239}
240
241impl<T: CubeType> Iterator for &Tensor<T> {
242    type Item = T;
243
244    fn next(&mut self) -> Option<Self::Item> {
245        unexpanded!()
246    }
247}
248
249impl<T: CubeType> CubeType for Tensor<T> {
250    type ExpandType = ExpandElementTyped<Tensor<T>>;
251}
252
253impl<T: CubeType> CubeType for *const Tensor<T> {
254    type ExpandType = ExpandElementTyped<Tensor<T>>;
255}
256
257impl<T: CubeType> CubeType for *mut Tensor<T> {
258    type ExpandType = ExpandElementTyped<Tensor<T>>;
259}
260
261impl<T: CubeType> CubeType for &mut Tensor<T> {
262    type ExpandType = ExpandElementTyped<Tensor<T>>;
263}
264
265impl<T: CubeType> CubeType for &Tensor<T> {
266    type ExpandType = ExpandElementTyped<Tensor<T>>;
267}
268
269impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
270    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
271        elem
272    }
273}
274
275impl<T: CubePrimitive> List<T> for Tensor<T> {
276    fn __expand_read(
277        scope: &mut Scope,
278        this: ExpandElementTyped<Tensor<T>>,
279        idx: ExpandElementTyped<usize>,
280    ) -> ExpandElementTyped<T> {
281        index::expand(scope, this, idx)
282    }
283}
284
285impl<T: CubePrimitive> Deref for Tensor<T> {
286    type Target = [T];
287
288    fn deref(&self) -> &Self::Target {
289        unexpanded!()
290    }
291}
292
293impl<T: CubePrimitive> DerefMut for Tensor<T> {
294    fn deref_mut(&mut self) -> &mut Self::Target {
295        unexpanded!()
296    }
297}
298
299impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
300    fn __expand_read_method(
301        &self,
302        scope: &mut Scope,
303        idx: ExpandElementTyped<usize>,
304    ) -> ExpandElementTyped<T> {
305        index::expand(scope, self.clone(), idx)
306    }
307    fn __expand_read_unchecked_method(
308        &self,
309        scope: &mut Scope,
310        idx: ExpandElementTyped<usize>,
311    ) -> ExpandElementTyped<T> {
312        index_unchecked::expand(scope, self.clone(), idx)
313    }
314
315    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
316        Self::__expand_len(scope, self.clone())
317    }
318}
319
320impl<T: CubePrimitive> Lined for Tensor<T> {}
321impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Tensor<T>> {
322    fn line_size(&self) -> LineSize {
323        self.expand.ty.line_size()
324    }
325}
326
327impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
328    fn __expand_write(
329        scope: &mut Scope,
330        this: ExpandElementTyped<Tensor<T>>,
331        idx: ExpandElementTyped<usize>,
332        value: ExpandElementTyped<T>,
333    ) {
334        index_assign::expand(scope, this, idx, value);
335    }
336}
337
338impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
339    fn __expand_write_method(
340        &self,
341        scope: &mut Scope,
342        idx: ExpandElementTyped<usize>,
343        value: ExpandElementTyped<T>,
344    ) {
345        index_assign::expand(scope, self.clone(), idx, value);
346    }
347}