cubecl_core/frontend/container/tensor/
base.rs1use crate::{
2 frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
3 ir::{Metadata, Scope, Type},
4 prelude::{
5 Line, Lined, LinedExpand, List, ListExpand, ListMut, ListMutExpand, index, index_assign,
6 index_unchecked,
7 },
8 unexpanded,
9};
10use cubecl_ir::ExpandElement;
11use cubecl_macros::{cube, intrinsic};
12use std::marker::PhantomData;
13
14use crate as cubecl;
15
16#[derive(new, Clone, Copy)]
19pub struct Tensor<T: CubeType> {
20 _val: PhantomData<T>,
21}
22
23type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
24
25mod metadata {
27 use cubecl_ir::ExpandElement;
28
29 use super::*;
30 use crate::{
31 ir::{Arithmetic, BinaryOperator, Instruction},
32 prelude::Array,
33 };
34
35 #[cube]
36 impl<T: CubeType> Tensor<T> {
37 #[allow(unused_variables)]
39 pub fn stride(&self, dim: u32) -> u32 {
40 intrinsic!(|scope| {
41 let dim: ExpandElement = dim.into();
42 let out = scope.create_local(Type::new(u32::as_type(scope)));
43 scope.register(Instruction::new(
44 Metadata::Stride {
45 dim: *dim,
46 var: self.expand.into(),
47 },
48 out.clone().into(),
49 ));
50 out.into()
51 })
52 }
53
54 #[allow(unused_variables)]
56 pub fn shape(&self, dim: u32) -> u32 {
57 intrinsic!(|scope| {
58 let dim: ExpandElement = dim.into();
59 let out = scope.create_local(Type::new(u32::as_type(scope)));
60 scope.register(Instruction::new(
61 Metadata::Shape {
62 dim: *dim,
63 var: self.expand.into(),
64 },
65 out.clone().into(),
66 ));
67 out.into()
68 })
69 }
70
71 #[allow(unused_variables)]
76 pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
77 intrinsic!(|scope| {
78 let index: ExpandElement = index.into();
79 let stride = self.clone().__expand_stride_method(scope, dim.clone());
80 let shape = self.clone().__expand_shape_method(scope, dim.clone());
81
82 let num_strides = scope.create_local(Type::new(u32::as_type(scope)));
84 scope.register(Instruction::new(
85 Arithmetic::Div(BinaryOperator {
86 lhs: *index,
87 rhs: stride.expand.into(),
88 }),
89 num_strides.clone().into(),
90 ));
91
92 let coordinate = scope.create_local(Type::new(u32::as_type(scope)));
94 scope.register(Instruction::new(
95 Arithmetic::Modulo(BinaryOperator {
96 lhs: *num_strides,
97 rhs: shape.expand.into(),
98 }),
99 coordinate.clone().into(),
100 ));
101
102 coordinate.into()
103 })
104 }
105
106 #[allow(clippy::len_without_is_empty)]
113 pub fn len(&self) -> u32 {
114 intrinsic!(|scope| {
115 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
116 elem.__expand_len_method(scope)
117 })
118 }
119
120 #[allow(clippy::len_without_is_empty)]
127 pub fn buffer_len(&self) -> u32 {
128 intrinsic!(|scope| {
129 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
130 elem.__expand_buffer_len_method(scope)
131 })
132 }
133
134 pub fn rank(&self) -> u32 {
136 intrinsic!(|scope| {
137 let out = scope.create_local(Type::new(u32::as_type(scope)));
138 scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
139 out.into()
140 })
141 }
142 }
143}
144
145mod indexation {
147 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
148
149 use crate::{
150 ir::Instruction,
151 prelude::{CubeIndex, CubeIndexMut},
152 };
153
154 use super::*;
155
156 #[cube]
157 impl<E: CubePrimitive> Tensor<E> {
158 #[allow(unused_variables)]
164 pub unsafe fn index_unchecked(&self, i: u32) -> &E
165 where
166 Self: CubeIndex,
167 {
168 intrinsic!(|scope| {
169 let out = scope.create_local(self.expand.ty);
170 scope.register(Instruction::new(
171 Operator::UncheckedIndex(IndexOperator {
172 list: *self.expand,
173 index: i.expand.consume(),
174 line_size: 0,
175 unroll_factor: 1,
176 }),
177 *out,
178 ));
179 out.into()
180 })
181 }
182
183 #[allow(unused_variables)]
189 pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
190 where
191 Self: CubeIndexMut,
192 {
193 intrinsic!(|scope| {
194 scope.register(Instruction::new(
195 Operator::UncheckedIndexAssign(IndexAssignOperator {
196 index: i.expand.consume(),
197 value: value.expand.consume(),
198 line_size: 0,
199 unroll_factor: 1,
200 }),
201 *self.expand,
202 ));
203 })
204 }
205 }
206}
207
208mod line {
210 use super::*;
211
212 impl<P: CubePrimitive> Tensor<Line<P>> {
213 pub fn line_size(&self) -> u32 {
221 unexpanded!()
222 }
223
224 pub fn __expand_line_size(
226 expand: <Self as CubeType>::ExpandType,
227 scope: &mut Scope,
228 ) -> u32 {
229 expand.__expand_line_size_method(scope)
230 }
231 }
232}
233
234impl<T: CubePrimitive> SizedContainer for Tensor<T> {
235 type Item = T;
236}
237
238impl<T: CubeType> Iterator for &Tensor<T> {
239 type Item = T;
240
241 fn next(&mut self) -> Option<Self::Item> {
242 unexpanded!()
243 }
244}
245
246impl<T: CubeType> CubeType for Tensor<T> {
247 type ExpandType = ExpandElementTyped<Tensor<T>>;
248}
249
250impl<T: CubeType> CubeType for *const Tensor<T> {
251 type ExpandType = ExpandElementTyped<Tensor<T>>;
252}
253
254impl<T: CubeType> CubeType for *mut Tensor<T> {
255 type ExpandType = ExpandElementTyped<Tensor<T>>;
256}
257
258impl<T: CubeType> CubeType for &mut Tensor<T> {
259 type ExpandType = ExpandElementTyped<Tensor<T>>;
260}
261
262impl<T: CubeType> CubeType for &Tensor<T> {
263 type ExpandType = ExpandElementTyped<Tensor<T>>;
264}
265
266impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
267 fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
268 elem
269 }
270}
271
272impl<T: CubePrimitive> List<T> for Tensor<T> {
273 fn __expand_read(
274 scope: &mut Scope,
275 this: ExpandElementTyped<Tensor<T>>,
276 idx: ExpandElementTyped<u32>,
277 ) -> ExpandElementTyped<T> {
278 index::expand(scope, this, idx)
279 }
280}
281
282impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
283 fn __expand_read_method(
284 &self,
285 scope: &mut Scope,
286 idx: ExpandElementTyped<u32>,
287 ) -> ExpandElementTyped<T> {
288 index::expand(scope, self.clone(), idx)
289 }
290 fn __expand_read_unchecked_method(
291 &self,
292 scope: &mut Scope,
293 idx: ExpandElementTyped<u32>,
294 ) -> ExpandElementTyped<T> {
295 index_unchecked::expand(scope, self.clone(), idx)
296 }
297
298 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
299 Self::__expand_len(scope, self.clone())
300 }
301}
302
303impl<T: CubePrimitive> Lined for Tensor<T> {}
304impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Tensor<T>> {
305 fn line_size(&self) -> u32 {
306 self.expand.ty.line_size()
307 }
308}
309
310impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
311 fn __expand_write(
312 scope: &mut Scope,
313 this: ExpandElementTyped<Tensor<T>>,
314 idx: ExpandElementTyped<u32>,
315 value: ExpandElementTyped<T>,
316 ) {
317 index_assign::expand(scope, this, idx, value);
318 }
319}
320
321impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
322 fn __expand_write_method(
323 &self,
324 scope: &mut Scope,
325 idx: ExpandElementTyped<u32>,
326 value: ExpandElementTyped<T>,
327 ) {
328 index_assign::expand(scope, self.clone(), idx, value);
329 }
330}