cubecl_core/frontend/container/tensor/
base.rs1use crate::{
2 frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
3 ir::{Item, Metadata, Scope},
4 prelude::{
5 Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
6 },
7 unexpanded,
8};
9use cubecl_ir::ExpandElement;
10use cubecl_macros::{cube, intrinsic};
11use std::{marker::PhantomData, num::NonZero};
12
13use crate as cubecl;
14
15#[derive(new)]
18pub struct Tensor<T: CubeType> {
19 _val: PhantomData<T>,
20}
21
22type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
23
24mod metadata {
26 use cubecl_ir::ExpandElement;
27
28 use super::*;
29 use crate::{
30 ir::{Arithmetic, BinaryOperator, Instruction},
31 prelude::Array,
32 };
33
34 #[cube]
35 impl<T: CubeType> Tensor<T> {
36 #[allow(unused_variables)]
38 pub fn stride(&self, dim: u32) -> u32 {
39 intrinsic!(|scope| {
40 let dim: ExpandElement = dim.into();
41 let out = scope.create_local(Item::new(u32::as_elem(scope)));
42 scope.register(Instruction::new(
43 Metadata::Stride {
44 dim: *dim,
45 var: self.expand.into(),
46 },
47 out.clone().into(),
48 ));
49 out.into()
50 })
51 }
52
53 #[allow(unused_variables)]
55 pub fn shape(&self, dim: u32) -> u32 {
56 intrinsic!(|scope| {
57 let dim: ExpandElement = dim.into();
58 let out = scope.create_local(Item::new(u32::as_elem(scope)));
59 scope.register(Instruction::new(
60 Metadata::Shape {
61 dim: *dim,
62 var: self.expand.into(),
63 },
64 out.clone().into(),
65 ));
66 out.into()
67 })
68 }
69
70 #[allow(unused_variables)]
75 pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
76 intrinsic!(|scope| {
77 let index: ExpandElement = index.into();
78 let stride = self.clone().__expand_stride_method(scope, dim.clone());
79 let shape = self.clone().__expand_shape_method(scope, dim.clone());
80
81 let num_strides = scope.create_local(Item::new(u32::as_elem(scope)));
83 scope.register(Instruction::new(
84 Arithmetic::Div(BinaryOperator {
85 lhs: *index,
86 rhs: stride.expand.into(),
87 }),
88 num_strides.clone().into(),
89 ));
90
91 let coordinate = scope.create_local(Item::new(u32::as_elem(scope)));
93 scope.register(Instruction::new(
94 Arithmetic::Modulo(BinaryOperator {
95 lhs: *num_strides,
96 rhs: shape.expand.into(),
97 }),
98 coordinate.clone().into(),
99 ));
100
101 coordinate.into()
102 })
103 }
104
105 #[allow(clippy::len_without_is_empty)]
112 pub fn len(&self) -> u32 {
113 intrinsic!(|scope| {
114 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
115 elem.__expand_len_method(scope)
116 })
117 }
118
119 #[allow(clippy::len_without_is_empty)]
126 pub fn buffer_len(&self) -> u32 {
127 intrinsic!(|scope| {
128 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
129 elem.__expand_buffer_len_method(scope)
130 })
131 }
132
133 pub fn rank(&self) -> u32 {
135 intrinsic!(|scope| {
136 let out = scope.create_local(Item::new(u32::as_elem(scope)));
137 scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
138 out.into()
139 })
140 }
141 }
142}
143
144mod indexation {
146 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
147
148 use crate::{
149 ir::Instruction,
150 prelude::{CubeIndex, CubeIndexMut},
151 };
152
153 use super::*;
154
155 #[cube]
156 impl<E: CubePrimitive> Tensor<E> {
157 #[allow(unused_variables)]
163 pub unsafe fn index_unchecked(&self, i: u32) -> &E
164 where
165 Self: CubeIndex,
166 {
167 intrinsic!(|scope| {
168 let out = scope.create_local(self.expand.item);
169 scope.register(Instruction::new(
170 Operator::UncheckedIndex(IndexOperator {
171 list: *self.expand,
172 index: i.expand.consume(),
173 line_size: 0,
174 }),
175 *out,
176 ));
177 out.into()
178 })
179 }
180
181 #[allow(unused_variables)]
187 pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
188 where
189 Self: CubeIndexMut,
190 {
191 intrinsic!(|scope| {
192 scope.register(Instruction::new(
193 Operator::UncheckedIndexAssign(IndexAssignOperator {
194 index: i.expand.consume(),
195 value: value.expand.consume(),
196 line_size: 0,
197 }),
198 *self.expand,
199 ));
200 })
201 }
202 }
203}
204
205mod line {
207 use super::*;
208
209 impl<P: CubePrimitive> Tensor<Line<P>> {
210 pub fn line_size(&self) -> u32 {
218 unexpanded!()
219 }
220
221 pub fn __expand_line_size(
223 expand: <Self as CubeType>::ExpandType,
224 scope: &mut Scope,
225 ) -> u32 {
226 expand.__expand_line_size_method(scope)
227 }
228 }
229
230 impl<P: CubePrimitive> ExpandElementTyped<Tensor<Line<P>>> {
231 pub fn line_size(&self) -> u32 {
233 self.expand
234 .item
235 .vectorization
236 .unwrap_or(NonZero::new(1).unwrap())
237 .get() as u32
238 }
239
240 pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
242 self.line_size()
243 }
244 }
245}
246
247impl<T: CubePrimitive> SizedContainer for Tensor<T> {
248 type Item = T;
249}
250
251impl<T: CubeType> Iterator for &Tensor<T> {
252 type Item = T;
253
254 fn next(&mut self) -> Option<Self::Item> {
255 unexpanded!()
256 }
257}
258
259impl<T: CubeType> CubeType for Tensor<T> {
260 type ExpandType = ExpandElementTyped<Tensor<T>>;
261}
262
263impl<T: CubeType> CubeType for *const Tensor<T> {
264 type ExpandType = ExpandElementTyped<Tensor<T>>;
265}
266
267impl<T: CubeType> CubeType for *mut Tensor<T> {
268 type ExpandType = ExpandElementTyped<Tensor<T>>;
269}
270
271impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
272 fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
273 elem
274 }
275}
276
277impl<T: CubePrimitive> List<T> for Tensor<T> {
278 fn __expand_read(
279 scope: &mut Scope,
280 this: ExpandElementTyped<Tensor<T>>,
281 idx: ExpandElementTyped<u32>,
282 ) -> ExpandElementTyped<T> {
283 index::expand(scope, this, idx)
284 }
285}
286
287impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
288 fn __expand_read_method(
289 &self,
290 scope: &mut Scope,
291 idx: ExpandElementTyped<u32>,
292 ) -> ExpandElementTyped<T> {
293 index::expand(scope, self.clone(), idx)
294 }
295 fn __expand_read_unchecked_method(
296 &self,
297 scope: &mut Scope,
298 idx: ExpandElementTyped<u32>,
299 ) -> ExpandElementTyped<T> {
300 index_unchecked::expand(scope, self.clone(), idx)
301 }
302}
303
304impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
305 fn __expand_write(
306 scope: &mut Scope,
307 this: ExpandElementTyped<Tensor<T>>,
308 idx: ExpandElementTyped<u32>,
309 value: ExpandElementTyped<T>,
310 ) {
311 index_assign::expand(scope, this, idx, value);
312 }
313}
314
315impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
316 fn __expand_write_method(
317 &self,
318 scope: &mut Scope,
319 idx: ExpandElementTyped<u32>,
320 value: ExpandElementTyped<T>,
321 ) {
322 index_assign::expand(scope, self.clone(), idx, value);
323 }
324}