cubecl_core/frontend/container/tensor/
base.rs

1use crate::frontend::{ExpandElementBaseInit, ExpandElementTyped, SizedContainer};
2use crate::prelude::IntoRuntime;
3use crate::{
4    frontend::{indexation::Index, CubeContext, CubePrimitive, CubeType, ExpandElement},
5    ir::{Item, Metadata},
6    prelude::Line,
7    unexpanded,
8};
9use std::{marker::PhantomData, num::NonZero};
10
11/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more
12/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape).
13#[derive(new)]
14pub struct Tensor<T: CubeType> {
15    _val: PhantomData<T>,
16}
17
18/// Module that contains the implementation details of the metadata functions.
19mod metadata {
20    use super::*;
21    use crate::{
22        ir::{BinaryOperator, Instruction, Operator},
23        prelude::Array,
24    };
25
26    impl<T: CubeType> Tensor<T> {
27        /// Obtain the stride of input at dimension dim
28        pub fn stride<C: Index>(&self, _dim: C) -> u32 {
29            unexpanded!()
30        }
31
32        /// Obtain the shape of input at dimension dim
33        pub fn shape<C: Index>(&self, _dim: C) -> u32 {
34            unexpanded!()
35        }
36
37        /// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
38        ///
39        /// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
40        /// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
41        pub fn coordinate<I: Index, D: Index>(&self, _index: I, _dim: D) -> u32 {
42            unexpanded!()
43        }
44
45        /// The number of vectorized elements in the tensor.
46        ///
47        /// # Warning
48        ///
49        /// The length will be affected by the vectorization factor. To obtain the number of elements,
50        /// you should multiply the length by the vectorization factor.
51        #[allow(clippy::len_without_is_empty)]
52        pub fn len(&self) -> u32 {
53            unexpanded!()
54        }
55
56        /// The length of the buffer representing the tensor in terms of vectorized elements.
57        ///
58        /// # Warning
59        ///
60        /// The buffer length will be affected by the vectorization factor. To obtain the number of
61        /// elements, you should multiply the length by the vectorization factor.
62        #[allow(clippy::len_without_is_empty)]
63        pub fn buffer_len(&self) -> u32 {
64            unexpanded!()
65        }
66
67        /// Returns the rank of the tensor.
68        pub fn rank(&self) -> u32 {
69            unexpanded!()
70        }
71
72        // Expand function of [stride](Tensor::stride).
73        pub fn __expand_stride<C: Index>(
74            context: &mut CubeContext,
75            expand: ExpandElementTyped<Tensor<T>>,
76            dim: ExpandElementTyped<u32>,
77        ) -> ExpandElementTyped<u32> {
78            expand.__expand_stride_method(context, dim)
79        }
80
81        // Expand function of [shape](Tensor::shape).
82        pub fn __expand_shape<C: Index>(
83            context: &mut CubeContext,
84            expand: ExpandElementTyped<Tensor<T>>,
85            dim: ExpandElementTyped<u32>,
86        ) -> ExpandElementTyped<u32> {
87            expand.__expand_shape_method(context, dim)
88        }
89
90        // Expand function of [coordinate](Tensor::coordinate).
91        pub fn __expand_coordinate<I: Index, D: Index>(
92            context: &mut CubeContext,
93            expand: ExpandElementTyped<Tensor<T>>,
94            index: ExpandElementTyped<u32>,
95            dim: ExpandElementTyped<u32>,
96        ) -> ExpandElementTyped<u32> {
97            expand.__expand_coordinate_method(context, index, dim)
98        }
99
100        // Expand function of [len](Tensor::len).
101        pub fn __expand_len<C: Index>(
102            context: &mut CubeContext,
103            expand: ExpandElementTyped<Tensor<T>>,
104        ) -> ExpandElementTyped<u32> {
105            expand.__expand_len_method(context)
106        }
107
108        // Expand function of [buffer_len](Tensor::buffer_len).
109        pub fn __expand_buffer_len<C: Index>(
110            context: &mut CubeContext,
111            expand: ExpandElementTyped<Tensor<T>>,
112        ) -> ExpandElementTyped<u32> {
113            expand.__expand_buffer_len_method(context)
114        }
115
116        // Expand function of [rank](Tensor::rank).
117        pub fn __expand_rank<C: Index>(
118            context: &mut CubeContext,
119            expand: ExpandElementTyped<Tensor<T>>,
120        ) -> ExpandElementTyped<u32> {
121            expand.__expand_rank_method(context)
122        }
123    }
124
125    impl<T: CubeType> ExpandElementTyped<Tensor<T>> {
126        // Expand method of [stride](Tensor::stride).
127        pub fn __expand_stride_method(
128            self,
129            context: &mut CubeContext,
130            dim: ExpandElementTyped<u32>,
131        ) -> ExpandElementTyped<u32> {
132            let dim: ExpandElement = dim.into();
133            let out = context.create_local(Item::new(u32::as_elem(context)));
134            context.register(Instruction::new(
135                Metadata::Stride {
136                    dim: *dim,
137                    var: self.expand.into(),
138                },
139                out.clone().into(),
140            ));
141            out.into()
142        }
143
144        // Expand method of [shape](Tensor::shape).
145        pub fn __expand_shape_method(
146            self,
147            context: &mut CubeContext,
148            dim: ExpandElementTyped<u32>,
149        ) -> ExpandElementTyped<u32> {
150            let dim: ExpandElement = dim.into();
151            let out = context.create_local(Item::new(u32::as_elem(context)));
152            context.register(Instruction::new(
153                Metadata::Shape {
154                    dim: *dim,
155                    var: self.expand.into(),
156                },
157                out.clone().into(),
158            ));
159            out.into()
160        }
161
162        // Expand method of [coordinate](Tensor::coordinate).
163        pub fn __expand_coordinate_method(
164            self,
165            context: &mut CubeContext,
166            index: ExpandElementTyped<u32>,
167            dim: ExpandElementTyped<u32>,
168        ) -> ExpandElementTyped<u32> {
169            let index: ExpandElement = index.into();
170            let stride = self.clone().__expand_stride_method(context, dim.clone());
171            let shape = self.clone().__expand_shape_method(context, dim.clone());
172
173            // Compute `num_strides = index / stride`.
174            let num_strides = context.create_local(Item::new(u32::as_elem(context)));
175            context.register(Instruction::new(
176                Operator::Div(BinaryOperator {
177                    lhs: *index,
178                    rhs: stride.expand.into(),
179                }),
180                num_strides.clone().into(),
181            ));
182
183            // Compute `coordinate = num_strides % shape `.
184            let coordinate = context.create_local(Item::new(u32::as_elem(context)));
185            context.register(Instruction::new(
186                Operator::Modulo(BinaryOperator {
187                    lhs: *num_strides,
188                    rhs: shape.expand.into(),
189                }),
190                coordinate.clone().into(),
191            ));
192
193            coordinate.into()
194        }
195
196        // Expand method of [len](Tensor::len).
197        pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
198            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
199            elem.__expand_len_method(context)
200        }
201
202        // Expand method of [buffer_len](Tensor::buffer_len).
203        pub fn __expand_buffer_len_method(
204            self,
205            context: &mut CubeContext,
206        ) -> ExpandElementTyped<u32> {
207            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
208            elem.__expand_buffer_len_method(context)
209        }
210
211        // Expand method of [rank](Tensor::rank).
212        pub fn __expand_rank_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
213            let out = context.create_local(Item::new(u32::as_elem(context)));
214            context.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
215            out.into()
216        }
217    }
218}
219
220/// Module that contains the implementation details of the index functions.
221mod indexation {
222    use crate::{
223        ir::{BinaryOperator, Instruction, Operator},
224        prelude::{CubeIndex, CubeIndexMut},
225    };
226
227    use super::*;
228
229    impl<E: CubePrimitive> Tensor<E> {
230        /// Perform an unchecked index into the array
231        ///
232        /// # Safety
233        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
234        /// always in bounds
235        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
236        where
237            Self: CubeIndex<I>,
238        {
239            unexpanded!()
240        }
241
242        /// Perform an unchecked index assignment into the array
243        ///
244        /// # Safety
245        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
246        /// always in bounds
247        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
248        where
249            Self: CubeIndexMut<I>,
250        {
251            unexpanded!()
252        }
253    }
254
255    impl<E: CubePrimitive> ExpandElementTyped<Tensor<E>> {
256        pub fn __expand_index_unchecked_method(
257            self,
258            context: &mut CubeContext,
259            i: ExpandElementTyped<u32>,
260        ) -> ExpandElementTyped<E> {
261            let out = context.create_local(self.expand.item);
262            context.register(Instruction::new(
263                Operator::UncheckedIndex(BinaryOperator {
264                    lhs: *self.expand,
265                    rhs: i.expand.consume(),
266                }),
267                *out,
268            ));
269            out.into()
270        }
271
272        pub fn __expand_index_assign_unchecked_method(
273            self,
274            context: &mut CubeContext,
275            i: ExpandElementTyped<u32>,
276            value: ExpandElementTyped<E>,
277        ) {
278            context.register(Instruction::new(
279                Operator::UncheckedIndexAssign(BinaryOperator {
280                    lhs: i.expand.consume(),
281                    rhs: value.expand.consume(),
282                }),
283                *self.expand,
284            ));
285        }
286    }
287}
288
289/// Module that contains the implementation details of the line_size function.
290mod line {
291    use super::*;
292
293    impl<P: CubePrimitive> Tensor<Line<P>> {
294        /// Get the size of each line contained in the tensor.
295        ///
296        /// Same as the following:
297        ///
298        /// ```rust, ignore
299        /// let size = tensor[0].size();
300        /// ```
301        pub fn line_size(&self) -> u32 {
302            unexpanded!()
303        }
304
305        // Expand function of [size](Tensor::line_size).
306        pub fn __expand_line_size(
307            expand: <Self as CubeType>::ExpandType,
308            context: &mut CubeContext,
309        ) -> u32 {
310            expand.__expand_line_size_method(context)
311        }
312    }
313
314    impl<P: CubePrimitive> ExpandElementTyped<Tensor<Line<P>>> {
315        /// Comptime version of [size](Tensor::line_size).
316        pub fn line_size(&self) -> u32 {
317            self.expand
318                .item
319                .vectorization
320                .unwrap_or(NonZero::new(1).unwrap())
321                .get() as u32
322        }
323
324        // Expand method of [size](Tensor::line_size).
325        pub fn __expand_line_size_method(&self, _content: &mut CubeContext) -> u32 {
326            self.line_size()
327        }
328    }
329}
330
331impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Tensor<T> {
332    type Item = T;
333}
334
335impl<T: CubeType> Iterator for &Tensor<T> {
336    type Item = T;
337
338    fn next(&mut self) -> Option<Self::Item> {
339        unexpanded!()
340    }
341}
342
343impl<T: CubeType> CubeType for Tensor<T> {
344    type ExpandType = ExpandElementTyped<Tensor<T>>;
345}
346
347impl<T: CubeType> CubeType for *const Tensor<T> {
348    type ExpandType = ExpandElementTyped<Tensor<T>>;
349}
350
351impl<T: CubeType> CubeType for *mut Tensor<T> {
352    type ExpandType = ExpandElementTyped<Tensor<T>>;
353}
354
355impl<C: CubeType> ExpandElementBaseInit for Tensor<C> {
356    fn init_elem(_context: &mut crate::prelude::CubeContext, elem: ExpandElement) -> ExpandElement {
357        // The type can't be deeply cloned/copied.
358        elem
359    }
360}
361
362impl<E: CubePrimitive> IntoRuntime for Tensor<E> {
363    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
364        unimplemented!("Tensor can't exist at compile time")
365    }
366}
367
368impl<E: CubePrimitive> IntoRuntime for *const Tensor<E> {
369    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
370        unimplemented!("Tensor can't exist at compile time")
371    }
372}
373
374impl<E: CubePrimitive> IntoRuntime for *mut Tensor<E> {
375    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
376        unimplemented!("Tensor can't exist at compile time")
377    }
378}