cubecl_core/frontend/container/tensor/
base.rs1use crate::{
2 frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
3 ir::{Metadata, Scope, Type},
4 prelude::{
5 Line, Lined, LinedExpand, List, ListExpand, ListMut, ListMutExpand, index, index_assign,
6 index_unchecked,
7 },
8 unexpanded,
9};
10use cubecl_ir::{ExpandElement, LineSize};
11use cubecl_macros::{cube, intrinsic};
12use std::{
13 marker::PhantomData,
14 ops::{Deref, DerefMut},
15};
16
17use crate as cubecl;
18
19#[derive(new, Clone, Copy)]
22pub struct Tensor<T: CubeType> {
23 _val: PhantomData<T>,
24}
25
26type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
27
28mod metadata {
30 use cubecl_ir::ExpandElement;
31
32 use super::*;
33 use crate::{
34 ir::{Arithmetic, BinaryOperator, Instruction},
35 prelude::Array,
36 };
37
38 #[cube]
39 impl<T: CubeType> Tensor<T> {
40 #[allow(unused_variables)]
42 pub fn stride(&self, dim: usize) -> usize {
43 intrinsic!(|scope| {
44 let dim: ExpandElement = dim.into();
45 let out = scope.create_local(Type::new(usize::as_type(scope)));
46 scope.register(Instruction::new(
47 Metadata::Stride {
48 dim: *dim,
49 var: self.expand.into(),
50 },
51 out.clone().into(),
52 ));
53 out.into()
54 })
55 }
56
57 #[allow(unused_variables)]
59 pub fn shape(&self, dim: usize) -> usize {
60 intrinsic!(|scope| {
61 let dim: ExpandElement = dim.into();
62 let out = scope.create_local(Type::new(usize::as_type(scope)));
63 scope.register(Instruction::new(
64 Metadata::Shape {
65 dim: *dim,
66 var: self.expand.into(),
67 },
68 out.clone().into(),
69 ));
70 out.into()
71 })
72 }
73
74 #[allow(unused_variables)]
79 pub fn coordinate(&self, index: usize, dim: usize) -> usize {
80 intrinsic!(|scope| {
81 let index: ExpandElement = index.into();
82 let stride = self.clone().__expand_stride_method(scope, dim.clone());
83 let shape = self.clone().__expand_shape_method(scope, dim.clone());
84
85 let num_strides = scope.create_local(Type::new(usize::as_type(scope)));
87 scope.register(Instruction::new(
88 Arithmetic::Div(BinaryOperator {
89 lhs: *index,
90 rhs: stride.expand.into(),
91 }),
92 num_strides.clone().into(),
93 ));
94
95 let coordinate = scope.create_local(Type::new(usize::as_type(scope)));
97 scope.register(Instruction::new(
98 Arithmetic::Modulo(BinaryOperator {
99 lhs: *num_strides,
100 rhs: shape.expand.into(),
101 }),
102 coordinate.clone().into(),
103 ));
104
105 coordinate.into()
106 })
107 }
108
109 #[allow(clippy::len_without_is_empty)]
116 pub fn len(&self) -> usize {
117 intrinsic!(|scope| {
118 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
119 elem.__expand_len_method(scope)
120 })
121 }
122
123 #[allow(clippy::len_without_is_empty)]
130 pub fn buffer_len(&self) -> usize {
131 intrinsic!(|scope| {
132 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
133 elem.__expand_buffer_len_method(scope)
134 })
135 }
136
137 pub fn rank(&self) -> usize {
139 intrinsic!(|scope| {
140 let out = scope.create_local(Type::new(usize::as_type(scope)));
141 scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
142 out.into()
143 })
144 }
145 }
146}
147
148mod indexation {
150 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
151
152 use crate::{
153 ir::Instruction,
154 prelude::{CubeIndex, CubeIndexMut},
155 };
156
157 use super::*;
158
159 #[cube]
160 impl<E: CubePrimitive> Tensor<E> {
161 #[allow(unused_variables)]
167 pub unsafe fn index_unchecked(&self, i: usize) -> &E
168 where
169 Self: CubeIndex,
170 {
171 intrinsic!(|scope| {
172 let out = scope.create_local(self.expand.ty);
173 scope.register(Instruction::new(
174 Operator::UncheckedIndex(IndexOperator {
175 list: *self.expand,
176 index: i.expand.consume(),
177 line_size: 0,
178 unroll_factor: 1,
179 }),
180 *out,
181 ));
182 out.into()
183 })
184 }
185
186 #[allow(unused_variables)]
192 pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
193 where
194 Self: CubeIndexMut,
195 {
196 intrinsic!(|scope| {
197 scope.register(Instruction::new(
198 Operator::UncheckedIndexAssign(IndexAssignOperator {
199 index: i.expand.consume(),
200 value: value.expand.consume(),
201 line_size: 0,
202 unroll_factor: 1,
203 }),
204 *self.expand,
205 ));
206 })
207 }
208 }
209}
210
211mod line {
213 use super::*;
214
215 impl<P: CubePrimitive> Tensor<Line<P>> {
216 pub fn line_size(&self) -> LineSize {
224 unexpanded!()
225 }
226
227 pub fn __expand_line_size(
229 expand: <Self as CubeType>::ExpandType,
230 scope: &mut Scope,
231 ) -> LineSize {
232 expand.__expand_line_size_method(scope)
233 }
234 }
235}
236
237impl<T: CubePrimitive> SizedContainer for Tensor<T> {
238 type Item = T;
239}
240
241impl<T: CubeType> Iterator for &Tensor<T> {
242 type Item = T;
243
244 fn next(&mut self) -> Option<Self::Item> {
245 unexpanded!()
246 }
247}
248
249impl<T: CubeType> CubeType for Tensor<T> {
250 type ExpandType = ExpandElementTyped<Tensor<T>>;
251}
252
253impl<T: CubeType> CubeType for *const Tensor<T> {
254 type ExpandType = ExpandElementTyped<Tensor<T>>;
255}
256
257impl<T: CubeType> CubeType for *mut Tensor<T> {
258 type ExpandType = ExpandElementTyped<Tensor<T>>;
259}
260
261impl<T: CubeType> CubeType for &mut Tensor<T> {
262 type ExpandType = ExpandElementTyped<Tensor<T>>;
263}
264
265impl<T: CubeType> CubeType for &Tensor<T> {
266 type ExpandType = ExpandElementTyped<Tensor<T>>;
267}
268
269impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
270 fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
271 elem
272 }
273}
274
275impl<T: CubePrimitive> List<T> for Tensor<T> {
276 fn __expand_read(
277 scope: &mut Scope,
278 this: ExpandElementTyped<Tensor<T>>,
279 idx: ExpandElementTyped<usize>,
280 ) -> ExpandElementTyped<T> {
281 index::expand(scope, this, idx)
282 }
283}
284
285impl<T: CubePrimitive> Deref for Tensor<T> {
286 type Target = [T];
287
288 fn deref(&self) -> &Self::Target {
289 unexpanded!()
290 }
291}
292
293impl<T: CubePrimitive> DerefMut for Tensor<T> {
294 fn deref_mut(&mut self) -> &mut Self::Target {
295 unexpanded!()
296 }
297}
298
299impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
300 fn __expand_read_method(
301 &self,
302 scope: &mut Scope,
303 idx: ExpandElementTyped<usize>,
304 ) -> ExpandElementTyped<T> {
305 index::expand(scope, self.clone(), idx)
306 }
307 fn __expand_read_unchecked_method(
308 &self,
309 scope: &mut Scope,
310 idx: ExpandElementTyped<usize>,
311 ) -> ExpandElementTyped<T> {
312 index_unchecked::expand(scope, self.clone(), idx)
313 }
314
315 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
316 Self::__expand_len(scope, self.clone())
317 }
318}
319
320impl<T: CubePrimitive> Lined for Tensor<T> {}
321impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Tensor<T>> {
322 fn line_size(&self) -> LineSize {
323 self.expand.ty.line_size()
324 }
325}
326
327impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
328 fn __expand_write(
329 scope: &mut Scope,
330 this: ExpandElementTyped<Tensor<T>>,
331 idx: ExpandElementTyped<usize>,
332 value: ExpandElementTyped<T>,
333 ) {
334 index_assign::expand(scope, this, idx, value);
335 }
336}
337
338impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
339 fn __expand_write_method(
340 &self,
341 scope: &mut Scope,
342 idx: ExpandElementTyped<usize>,
343 value: ExpandElementTyped<T>,
344 ) {
345 index_assign::expand(scope, self.clone(), idx, value);
346 }
347}