cubecl_core/frontend/container/tensor/
base.rs1use crate::{
2 frontend::{CubePrimitive, CubeType, NativeExpand, SizedContainer},
3 ir::{Metadata, Scope},
4 prelude::*,
5 unexpanded,
6};
7use core::{
8 marker::PhantomData,
9 ops::{Deref, DerefMut},
10};
11use cubecl_ir::VectorSize;
12use cubecl_macros::{cube, intrinsic};
13
14use crate as cubecl;
15
16#[derive(new, Clone, Copy)]
19pub struct Tensor<T: CubeType> {
20 _val: PhantomData<T>,
21}
22
23type TensorExpand<T> = NativeExpand<Tensor<T>>;
24
25mod metadata {
27 use cubecl_ir::ManagedVariable;
28
29 use super::*;
30 use crate::{
31 ir::{Arithmetic, BinaryOperator, Instruction},
32 prelude::Array,
33 };
34
35 #[cube]
36 impl<T: CubeType> Tensor<T> {
37 #[allow(unused_variables)]
39 pub fn stride(&self, dim: usize) -> usize {
40 intrinsic!(|scope| {
41 let dim: ManagedVariable = dim.into();
42 let out = scope.create_local(usize::as_type(scope));
43 scope.register(Instruction::new(
44 Metadata::Stride {
45 dim: *dim,
46 var: self.expand.into(),
47 },
48 out.clone().into(),
49 ));
50 out.into()
51 })
52 }
53
54 #[allow(unused_variables)]
56 pub fn shape(&self, dim: usize) -> usize {
57 intrinsic!(|scope| {
58 let dim: ManagedVariable = dim.into();
59 let out = scope.create_local(usize::as_type(scope));
60 scope.register(Instruction::new(
61 Metadata::Shape {
62 dim: *dim,
63 var: self.expand.into(),
64 },
65 out.clone().into(),
66 ));
67 out.into()
68 })
69 }
70
71 #[allow(unused_variables)]
76 pub fn coordinate(&self, index: usize, dim: usize) -> usize {
77 intrinsic!(|scope| {
78 let index: ManagedVariable = index.into();
79 let stride = self.clone().__expand_stride_method(scope, dim.clone());
80 let shape = self.clone().__expand_shape_method(scope, dim.clone());
81
82 let num_strides = scope.create_local(usize::as_type(scope));
84 scope.register(Instruction::new(
85 Arithmetic::Div(BinaryOperator {
86 lhs: *index,
87 rhs: stride.expand.into(),
88 }),
89 num_strides.clone().into(),
90 ));
91
92 let coordinate = scope.create_local(usize::as_type(scope));
94 scope.register(Instruction::new(
95 Arithmetic::Modulo(BinaryOperator {
96 lhs: *num_strides,
97 rhs: shape.expand.into(),
98 }),
99 coordinate.clone().into(),
100 ));
101
102 coordinate.into()
103 })
104 }
105
106 #[allow(clippy::len_without_is_empty)]
113 pub fn len(&self) -> usize {
114 intrinsic!(|scope| {
115 let elem: NativeExpand<Array<u32>> = self.expand.into();
116 elem.__expand_len_method(scope)
117 })
118 }
119
120 #[allow(clippy::len_without_is_empty)]
127 pub fn buffer_len(&self) -> usize {
128 intrinsic!(|scope| {
129 let elem: NativeExpand<Array<u32>> = self.expand.into();
130 elem.__expand_buffer_len_method(scope)
131 })
132 }
133
134 pub fn rank(&self) -> usize {
136 intrinsic!(|scope| {
137 let out = scope.create_local(usize::as_type(scope));
138 scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
139 out.into()
140 })
141 }
142 }
143}
144
145mod indexation {
147 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
148
149 use crate::ir::Instruction;
150
151 use super::*;
152
153 #[cube]
154 impl<E: CubePrimitive> Tensor<E> {
155 #[allow(unused_variables)]
161 pub unsafe fn index_unchecked(&self, i: usize) -> &E {
162 intrinsic!(|scope| {
163 let out = scope.create_local(self.expand.ty);
164 scope.register(Instruction::new(
165 Operator::UncheckedIndex(IndexOperator {
166 list: *self.expand,
167 index: i.expand.consume(),
168 vector_size: 0,
169 unroll_factor: 1,
170 }),
171 *out,
172 ));
173 out.into()
174 })
175 }
176
177 #[allow(unused_variables)]
183 pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
184 intrinsic!(|scope| {
185 scope.register(Instruction::new(
186 Operator::UncheckedIndexAssign(IndexAssignOperator {
187 index: i.expand.consume(),
188 value: value.expand.consume(),
189 vector_size: 0,
190 unroll_factor: 1,
191 }),
192 *self.expand,
193 ));
194 })
195 }
196 }
197}
198
199mod vector {
201 use super::*;
202
203 impl<P: Scalar, N: Size> Tensor<Vector<P, N>> {
204 pub fn vector_size(&self) -> VectorSize {
212 N::value()
213 }
214
215 pub fn __expand_vector_size(
217 expand: <Self as CubeType>::ExpandType,
218 scope: &mut Scope,
219 ) -> VectorSize {
220 expand.__expand_vector_size_method(scope)
221 }
222 }
223}
224
225impl<T: CubePrimitive> SizedContainer for Tensor<T> {
226 type Item = T;
227}
228
229impl<T: CubeType> Iterator for &Tensor<T> {
230 type Item = T;
231
232 fn next(&mut self) -> Option<Self::Item> {
233 unexpanded!()
234 }
235}
236
237impl<T: CubeType> CubeType for Tensor<T> {
238 type ExpandType = NativeExpand<Tensor<T>>;
239}
240
241impl<T: CubeType> CubeType for *const Tensor<T> {
242 type ExpandType = NativeExpand<Tensor<T>>;
243}
244
245impl<T: CubeType> CubeType for *mut Tensor<T> {
246 type ExpandType = NativeExpand<Tensor<T>>;
247}
248
249impl<T: CubeType> CubeType for &mut Tensor<T> {
250 type ExpandType = NativeExpand<Tensor<T>>;
251}
252
253impl<T: CubeType> CubeType for &Tensor<T> {
254 type ExpandType = NativeExpand<Tensor<T>>;
255}
256
257impl<C: CubeType> IntoMut for NativeExpand<Tensor<C>> {
258 fn into_mut(self, _scope: &mut Scope) -> Self {
259 self
260 }
261}
262
263impl<T: CubePrimitive> List<T> for Tensor<T> {
264 fn __expand_read(
265 scope: &mut Scope,
266 this: NativeExpand<Tensor<T>>,
267 idx: NativeExpand<usize>,
268 ) -> NativeExpand<T> {
269 index::expand(scope, this, idx)
270 }
271}
272
273impl<T: CubePrimitive> Deref for Tensor<T> {
274 type Target = [T];
275
276 fn deref(&self) -> &Self::Target {
277 unexpanded!()
278 }
279}
280
281impl<T: CubePrimitive> DerefMut for Tensor<T> {
282 fn deref_mut(&mut self) -> &mut Self::Target {
283 unexpanded!()
284 }
285}
286
287impl<T: CubePrimitive> ListExpand<T> for NativeExpand<Tensor<T>> {
288 fn __expand_read_method(&self, scope: &mut Scope, idx: NativeExpand<usize>) -> NativeExpand<T> {
289 index::expand(scope, self.clone(), idx)
290 }
291 fn __expand_read_unchecked_method(
292 &self,
293 scope: &mut Scope,
294 idx: NativeExpand<usize>,
295 ) -> NativeExpand<T> {
296 index_unchecked::expand(scope, self.clone(), idx)
297 }
298
299 fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
300 Self::__expand_len(scope, self.clone())
301 }
302}
303
304impl<T: CubePrimitive> Vectorized for Tensor<T> {}
305impl<T: CubePrimitive> VectorizedExpand for NativeExpand<Tensor<T>> {
306 fn vector_size(&self) -> VectorSize {
307 self.expand.ty.vector_size()
308 }
309}
310
311impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
312 fn __expand_write(
313 scope: &mut Scope,
314 this: NativeExpand<Tensor<T>>,
315 idx: NativeExpand<usize>,
316 value: NativeExpand<T>,
317 ) {
318 index_assign::expand(scope, this, idx, value);
319 }
320}
321
322impl<T: CubePrimitive> ListMutExpand<T> for NativeExpand<Tensor<T>> {
323 fn __expand_write_method(
324 &self,
325 scope: &mut Scope,
326 idx: NativeExpand<usize>,
327 value: NativeExpand<T>,
328 ) {
329 index_assign::expand(scope, self.clone(), idx, value);
330 }
331}