cubecl_core/frontend/container/tensor/base.rs

1use crate::{
2    frontend::{
3        CubePrimitive, CubeType, ExpandElementBaseInit, ExpandElementTyped, SizedContainer,
4        indexation::Index,
5    },
6    ir::{Item, Metadata, Scope},
7    prelude::{Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign},
8    unexpanded,
9};
10use cubecl_ir::ExpandElement;
11use std::{marker::PhantomData, num::NonZero};
12
13/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more
14/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape).
15#[derive(new)]
16pub struct Tensor<T: CubeType> {
17    _val: PhantomData<T>,
18}
19
20/// Module that contains the implementation details of the metadata functions.
21mod metadata {
22    use cubecl_ir::ExpandElement;
23
24    use super::*;
25    use crate::{
26        ir::{Arithmetic, BinaryOperator, Instruction},
27        prelude::Array,
28    };
29
30    impl<T: CubeType> Tensor<T> {
31        /// Obtain the stride of input at dimension dim
32        pub fn stride<C: Index>(&self, _dim: C) -> u32 {
33            unexpanded!()
34        }
35
36        /// Obtain the shape of input at dimension dim
37        pub fn shape<C: Index>(&self, _dim: C) -> u32 {
38            unexpanded!()
39        }
40
41        /// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
42        ///
43        /// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
44        /// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
45        pub fn coordinate<I: Index, D: Index>(&self, _index: I, _dim: D) -> u32 {
46            unexpanded!()
47        }
48
49        /// The number of vectorized elements in the tensor.
50        ///
51        /// # Warning
52        ///
53        /// The length will be affected by the vectorization factor. To obtain the number of elements,
54        /// you should multiply the length by the vectorization factor.
55        #[allow(clippy::len_without_is_empty)]
56        pub fn len(&self) -> u32 {
57            unexpanded!()
58        }
59
60        /// The length of the buffer representing the tensor in terms of vectorized elements.
61        ///
62        /// # Warning
63        ///
64        /// The buffer length will be affected by the vectorization factor. To obtain the number of
65        /// elements, you should multiply the length by the vectorization factor.
66        #[allow(clippy::len_without_is_empty)]
67        pub fn buffer_len(&self) -> u32 {
68            unexpanded!()
69        }
70
71        /// Returns the rank of the tensor.
72        pub fn rank(&self) -> u32 {
73            unexpanded!()
74        }
75
76        // Expand function of [stride](Tensor::stride).
77        pub fn __expand_stride<C: Index>(
78            scope: &mut Scope,
79            expand: ExpandElementTyped<Tensor<T>>,
80            dim: ExpandElementTyped<u32>,
81        ) -> ExpandElementTyped<u32> {
82            expand.__expand_stride_method(scope, dim)
83        }
84
85        // Expand function of [shape](Tensor::shape).
86        pub fn __expand_shape<C: Index>(
87            scope: &mut Scope,
88            expand: ExpandElementTyped<Tensor<T>>,
89            dim: ExpandElementTyped<u32>,
90        ) -> ExpandElementTyped<u32> {
91            expand.__expand_shape_method(scope, dim)
92        }
93
94        // Expand function of [coordinate](Tensor::coordinate).
95        pub fn __expand_coordinate<I: Index, D: Index>(
96            scope: &mut Scope,
97            expand: ExpandElementTyped<Tensor<T>>,
98            index: ExpandElementTyped<u32>,
99            dim: ExpandElementTyped<u32>,
100        ) -> ExpandElementTyped<u32> {
101            expand.__expand_coordinate_method(scope, index, dim)
102        }
103
104        // Expand function of [len](Tensor::len).
105        pub fn __expand_len<C: Index>(
106            scope: &mut Scope,
107            expand: ExpandElementTyped<Tensor<T>>,
108        ) -> ExpandElementTyped<u32> {
109            expand.__expand_len_method(scope)
110        }
111
112        // Expand function of [buffer_len](Tensor::buffer_len).
113        pub fn __expand_buffer_len<C: Index>(
114            scope: &mut Scope,
115            expand: ExpandElementTyped<Tensor<T>>,
116        ) -> ExpandElementTyped<u32> {
117            expand.__expand_buffer_len_method(scope)
118        }
119
120        // Expand function of [rank](Tensor::rank).
121        pub fn __expand_rank<C: Index>(
122            scope: &mut Scope,
123            expand: ExpandElementTyped<Tensor<T>>,
124        ) -> ExpandElementTyped<u32> {
125            expand.__expand_rank_method(scope)
126        }
127    }
128
129    impl<T: CubeType> ExpandElementTyped<Tensor<T>> {
130        // Expand method of [stride](Tensor::stride).
131        pub fn __expand_stride_method(
132            self,
133            scope: &mut Scope,
134            dim: ExpandElementTyped<u32>,
135        ) -> ExpandElementTyped<u32> {
136            let dim: ExpandElement = dim.into();
137            let out = scope.create_local(Item::new(u32::as_elem(scope)));
138            scope.register(Instruction::new(
139                Metadata::Stride {
140                    dim: *dim,
141                    var: self.expand.into(),
142                },
143                out.clone().into(),
144            ));
145            out.into()
146        }
147
148        // Expand method of [shape](Tensor::shape).
149        pub fn __expand_shape_method(
150            self,
151            scope: &mut Scope,
152            dim: ExpandElementTyped<u32>,
153        ) -> ExpandElementTyped<u32> {
154            let dim: ExpandElement = dim.into();
155            let out = scope.create_local(Item::new(u32::as_elem(scope)));
156            scope.register(Instruction::new(
157                Metadata::Shape {
158                    dim: *dim,
159                    var: self.expand.into(),
160                },
161                out.clone().into(),
162            ));
163            out.into()
164        }
165
166        // Expand method of [coordinate](Tensor::coordinate).
167        pub fn __expand_coordinate_method(
168            self,
169            scope: &mut Scope,
170            index: ExpandElementTyped<u32>,
171            dim: ExpandElementTyped<u32>,
172        ) -> ExpandElementTyped<u32> {
173            let index: ExpandElement = index.into();
174            let stride = self.clone().__expand_stride_method(scope, dim.clone());
175            let shape = self.clone().__expand_shape_method(scope, dim.clone());
176
177            // Compute `num_strides = index / stride`.
178            let num_strides = scope.create_local(Item::new(u32::as_elem(scope)));
179            scope.register(Instruction::new(
180                Arithmetic::Div(BinaryOperator {
181                    lhs: *index,
182                    rhs: stride.expand.into(),
183                }),
184                num_strides.clone().into(),
185            ));
186
187            // Compute `coordinate = num_strides % shape `.
188            let coordinate = scope.create_local(Item::new(u32::as_elem(scope)));
189            scope.register(Instruction::new(
190                Arithmetic::Modulo(BinaryOperator {
191                    lhs: *num_strides,
192                    rhs: shape.expand.into(),
193                }),
194                coordinate.clone().into(),
195            ));
196
197            coordinate.into()
198        }
199
200        // Expand method of [len](Tensor::len).
201        pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
202            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
203            elem.__expand_len_method(scope)
204        }
205
206        // Expand method of [buffer_len](Tensor::buffer_len).
207        pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
208            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
209            elem.__expand_buffer_len_method(scope)
210        }
211
212        // Expand method of [rank](Tensor::rank).
213        pub fn __expand_rank_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
214            let out = scope.create_local(Item::new(u32::as_elem(scope)));
215            scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
216            out.into()
217        }
218    }
219}
220
221/// Module that contains the implementation details of the index functions.
222mod indexation {
223    use cubecl_ir::Operator;
224
225    use crate::{
226        ir::{BinaryOperator, Instruction},
227        prelude::{CubeIndex, CubeIndexMut},
228    };
229
230    use super::*;
231
232    impl<E: CubePrimitive> Tensor<E> {
233        /// Perform an unchecked index into the array
234        ///
235        /// # Safety
236        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
237        /// always in bounds
238        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
239        where
240            Self: CubeIndex<I>,
241        {
242            unexpanded!()
243        }
244
245        /// Perform an unchecked index assignment into the array
246        ///
247        /// # Safety
248        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
249        /// always in bounds
250        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
251        where
252            Self: CubeIndexMut<I>,
253        {
254            unexpanded!()
255        }
256    }
257
258    impl<E: CubePrimitive> ExpandElementTyped<Tensor<E>> {
259        pub fn __expand_index_unchecked_method(
260            self,
261            scope: &mut Scope,
262            i: ExpandElementTyped<u32>,
263        ) -> ExpandElementTyped<E> {
264            let out = scope.create_local(self.expand.item);
265            scope.register(Instruction::new(
266                Operator::UncheckedIndex(BinaryOperator {
267                    lhs: *self.expand,
268                    rhs: i.expand.consume(),
269                }),
270                *out,
271            ));
272            out.into()
273        }
274
275        pub fn __expand_index_assign_unchecked_method(
276            self,
277            scope: &mut Scope,
278            i: ExpandElementTyped<u32>,
279            value: ExpandElementTyped<E>,
280        ) {
281            scope.register(Instruction::new(
282                Operator::UncheckedIndexAssign(BinaryOperator {
283                    lhs: i.expand.consume(),
284                    rhs: value.expand.consume(),
285                }),
286                *self.expand,
287            ));
288        }
289    }
290}
291
292/// Module that contains the implementation details of the line_size function.
293mod line {
294    use super::*;
295
296    impl<P: CubePrimitive> Tensor<Line<P>> {
297        /// Get the size of each line contained in the tensor.
298        ///
299        /// Same as the following:
300        ///
301        /// ```rust, ignore
302        /// let size = tensor[0].size();
303        /// ```
304        pub fn line_size(&self) -> u32 {
305            unexpanded!()
306        }
307
308        // Expand function of [size](Tensor::line_size).
309        pub fn __expand_line_size(
310            expand: <Self as CubeType>::ExpandType,
311            scope: &mut Scope,
312        ) -> u32 {
313            expand.__expand_line_size_method(scope)
314        }
315    }
316
317    impl<P: CubePrimitive> ExpandElementTyped<Tensor<Line<P>>> {
318        /// Comptime version of [size](Tensor::line_size).
319        pub fn line_size(&self) -> u32 {
320            self.expand
321                .item
322                .vectorization
323                .unwrap_or(NonZero::new(1).unwrap())
324                .get() as u32
325        }
326
327        // Expand method of [size](Tensor::line_size).
328        pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
329            self.line_size()
330        }
331    }
332}
333
334impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Tensor<T> {
335    type Item = T;
336}
337
338impl<T: CubeType> Iterator for &Tensor<T> {
339    type Item = T;
340
341    fn next(&mut self) -> Option<Self::Item> {
342        unexpanded!()
343    }
344}
345
346impl<T: CubeType> CubeType for Tensor<T> {
347    type ExpandType = ExpandElementTyped<Tensor<T>>;
348}
349
350impl<T: CubeType> CubeType for *const Tensor<T> {
351    type ExpandType = ExpandElementTyped<Tensor<T>>;
352}
353
354impl<T: CubeType> CubeType for *mut Tensor<T> {
355    type ExpandType = ExpandElementTyped<Tensor<T>>;
356}
357
358impl<C: CubeType> ExpandElementBaseInit for Tensor<C> {
359    fn init_elem(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
360        // The type can't be deeply cloned/copied.
361        elem
362    }
363}
364
365impl<T: CubePrimitive> List<T> for Tensor<T> {
366    fn __expand_read(
367        scope: &mut Scope,
368        this: ExpandElementTyped<Tensor<T>>,
369        idx: ExpandElementTyped<u32>,
370    ) -> ExpandElementTyped<T> {
371        index::expand(scope, this, idx)
372    }
373}
374
375impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
376    fn __expand_read_method(
377        self,
378        scope: &mut Scope,
379        idx: ExpandElementTyped<u32>,
380    ) -> ExpandElementTyped<T> {
381        index::expand(scope, self, idx)
382    }
383}
384
385impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
386    fn __expand_write(
387        scope: &mut Scope,
388        this: ExpandElementTyped<Tensor<T>>,
389        idx: ExpandElementTyped<u32>,
390        value: ExpandElementTyped<T>,
391    ) {
392        index_assign::expand(scope, this, idx, value);
393    }
394}
395
396impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
397    fn __expand_write_method(
398        self,
399        scope: &mut Scope,
400        idx: ExpandElementTyped<u32>,
401        value: ExpandElementTyped<T>,
402    ) {
403        index_assign::expand(scope, self, idx, value);
404    }
405}