cubecl_core/frontend/container/tensor/
base.rs1use crate::{
2 frontend::{
3 CubePrimitive, CubeType, ExpandElementBaseInit, ExpandElementTyped, SizedContainer,
4 indexation::Index,
5 },
6 ir::{Item, Metadata, Scope},
7 prelude::{Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign},
8 unexpanded,
9};
10use cubecl_ir::ExpandElement;
11use std::{marker::PhantomData, num::NonZero};
12
/// Typed tensor handle used inside cube functions.
///
/// On the host the struct holds no data, only a `PhantomData<T>` marker. Its metadata
/// (stride, shape, rank, length) and elements are reached through the `__expand_*`
/// methods implemented for `Tensor<T>` and `ExpandElementTyped<Tensor<T>>` in this file.
#[derive(new)]
pub struct Tensor<T: CubeType> {
    // Zero-sized marker tying the element type `T` to the handle.
    _val: PhantomData<T>,
}
19
20mod metadata {
22 use cubecl_ir::ExpandElement;
23
24 use super::*;
25 use crate::{
26 ir::{Arithmetic, BinaryOperator, Instruction},
27 prelude::Array,
28 };
29
30 impl<T: CubeType> Tensor<T> {
31 pub fn stride<C: Index>(&self, _dim: C) -> u32 {
33 unexpanded!()
34 }
35
36 pub fn shape<C: Index>(&self, _dim: C) -> u32 {
38 unexpanded!()
39 }
40
41 pub fn coordinate<I: Index, D: Index>(&self, _index: I, _dim: D) -> u32 {
46 unexpanded!()
47 }
48
49 #[allow(clippy::len_without_is_empty)]
56 pub fn len(&self) -> u32 {
57 unexpanded!()
58 }
59
60 #[allow(clippy::len_without_is_empty)]
67 pub fn buffer_len(&self) -> u32 {
68 unexpanded!()
69 }
70
71 pub fn rank(&self) -> u32 {
73 unexpanded!()
74 }
75
76 pub fn __expand_stride<C: Index>(
78 scope: &mut Scope,
79 expand: ExpandElementTyped<Tensor<T>>,
80 dim: ExpandElementTyped<u32>,
81 ) -> ExpandElementTyped<u32> {
82 expand.__expand_stride_method(scope, dim)
83 }
84
85 pub fn __expand_shape<C: Index>(
87 scope: &mut Scope,
88 expand: ExpandElementTyped<Tensor<T>>,
89 dim: ExpandElementTyped<u32>,
90 ) -> ExpandElementTyped<u32> {
91 expand.__expand_shape_method(scope, dim)
92 }
93
94 pub fn __expand_coordinate<I: Index, D: Index>(
96 scope: &mut Scope,
97 expand: ExpandElementTyped<Tensor<T>>,
98 index: ExpandElementTyped<u32>,
99 dim: ExpandElementTyped<u32>,
100 ) -> ExpandElementTyped<u32> {
101 expand.__expand_coordinate_method(scope, index, dim)
102 }
103
104 pub fn __expand_len<C: Index>(
106 scope: &mut Scope,
107 expand: ExpandElementTyped<Tensor<T>>,
108 ) -> ExpandElementTyped<u32> {
109 expand.__expand_len_method(scope)
110 }
111
112 pub fn __expand_buffer_len<C: Index>(
114 scope: &mut Scope,
115 expand: ExpandElementTyped<Tensor<T>>,
116 ) -> ExpandElementTyped<u32> {
117 expand.__expand_buffer_len_method(scope)
118 }
119
120 pub fn __expand_rank<C: Index>(
122 scope: &mut Scope,
123 expand: ExpandElementTyped<Tensor<T>>,
124 ) -> ExpandElementTyped<u32> {
125 expand.__expand_rank_method(scope)
126 }
127 }
128
129 impl<T: CubeType> ExpandElementTyped<Tensor<T>> {
130 pub fn __expand_stride_method(
132 self,
133 scope: &mut Scope,
134 dim: ExpandElementTyped<u32>,
135 ) -> ExpandElementTyped<u32> {
136 let dim: ExpandElement = dim.into();
137 let out = scope.create_local(Item::new(u32::as_elem(scope)));
138 scope.register(Instruction::new(
139 Metadata::Stride {
140 dim: *dim,
141 var: self.expand.into(),
142 },
143 out.clone().into(),
144 ));
145 out.into()
146 }
147
148 pub fn __expand_shape_method(
150 self,
151 scope: &mut Scope,
152 dim: ExpandElementTyped<u32>,
153 ) -> ExpandElementTyped<u32> {
154 let dim: ExpandElement = dim.into();
155 let out = scope.create_local(Item::new(u32::as_elem(scope)));
156 scope.register(Instruction::new(
157 Metadata::Shape {
158 dim: *dim,
159 var: self.expand.into(),
160 },
161 out.clone().into(),
162 ));
163 out.into()
164 }
165
166 pub fn __expand_coordinate_method(
168 self,
169 scope: &mut Scope,
170 index: ExpandElementTyped<u32>,
171 dim: ExpandElementTyped<u32>,
172 ) -> ExpandElementTyped<u32> {
173 let index: ExpandElement = index.into();
174 let stride = self.clone().__expand_stride_method(scope, dim.clone());
175 let shape = self.clone().__expand_shape_method(scope, dim.clone());
176
177 let num_strides = scope.create_local(Item::new(u32::as_elem(scope)));
179 scope.register(Instruction::new(
180 Arithmetic::Div(BinaryOperator {
181 lhs: *index,
182 rhs: stride.expand.into(),
183 }),
184 num_strides.clone().into(),
185 ));
186
187 let coordinate = scope.create_local(Item::new(u32::as_elem(scope)));
189 scope.register(Instruction::new(
190 Arithmetic::Modulo(BinaryOperator {
191 lhs: *num_strides,
192 rhs: shape.expand.into(),
193 }),
194 coordinate.clone().into(),
195 ));
196
197 coordinate.into()
198 }
199
200 pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
202 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
203 elem.__expand_len_method(scope)
204 }
205
206 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
208 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
209 elem.__expand_buffer_len_method(scope)
210 }
211
212 pub fn __expand_rank_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
214 let out = scope.create_local(Item::new(u32::as_elem(scope)));
215 scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
216 out.into()
217 }
218 }
219}
220
221mod indexation {
223 use cubecl_ir::Operator;
224
225 use crate::{
226 ir::{BinaryOperator, Instruction},
227 prelude::{CubeIndex, CubeIndexMut},
228 };
229
230 use super::*;
231
232 impl<E: CubePrimitive> Tensor<E> {
233 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
239 where
240 Self: CubeIndex<I>,
241 {
242 unexpanded!()
243 }
244
245 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
251 where
252 Self: CubeIndexMut<I>,
253 {
254 unexpanded!()
255 }
256 }
257
258 impl<E: CubePrimitive> ExpandElementTyped<Tensor<E>> {
259 pub fn __expand_index_unchecked_method(
260 self,
261 scope: &mut Scope,
262 i: ExpandElementTyped<u32>,
263 ) -> ExpandElementTyped<E> {
264 let out = scope.create_local(self.expand.item);
265 scope.register(Instruction::new(
266 Operator::UncheckedIndex(BinaryOperator {
267 lhs: *self.expand,
268 rhs: i.expand.consume(),
269 }),
270 *out,
271 ));
272 out.into()
273 }
274
275 pub fn __expand_index_assign_unchecked_method(
276 self,
277 scope: &mut Scope,
278 i: ExpandElementTyped<u32>,
279 value: ExpandElementTyped<E>,
280 ) {
281 scope.register(Instruction::new(
282 Operator::UncheckedIndexAssign(BinaryOperator {
283 lhs: i.expand.consume(),
284 rhs: value.expand.consume(),
285 }),
286 *self.expand,
287 ));
288 }
289 }
290}
291
292mod line {
294 use super::*;
295
296 impl<P: CubePrimitive> Tensor<Line<P>> {
297 pub fn line_size(&self) -> u32 {
305 unexpanded!()
306 }
307
308 pub fn __expand_line_size(
310 expand: <Self as CubeType>::ExpandType,
311 scope: &mut Scope,
312 ) -> u32 {
313 expand.__expand_line_size_method(scope)
314 }
315 }
316
317 impl<P: CubePrimitive> ExpandElementTyped<Tensor<Line<P>>> {
318 pub fn line_size(&self) -> u32 {
320 self.expand
321 .item
322 .vectorization
323 .unwrap_or(NonZero::new(1).unwrap())
324 .get() as u32
325 }
326
327 pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
329 self.line_size()
330 }
331 }
332}
333
334impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Tensor<T> {
335 type Item = T;
336}
337
338impl<T: CubeType> Iterator for &Tensor<T> {
339 type Item = T;
340
341 fn next(&mut self) -> Option<Self::Item> {
342 unexpanded!()
343 }
344}
345
346impl<T: CubeType> CubeType for Tensor<T> {
347 type ExpandType = ExpandElementTyped<Tensor<T>>;
348}
349
350impl<T: CubeType> CubeType for *const Tensor<T> {
351 type ExpandType = ExpandElementTyped<Tensor<T>>;
352}
353
354impl<T: CubeType> CubeType for *mut Tensor<T> {
355 type ExpandType = ExpandElementTyped<Tensor<T>>;
356}
357
358impl<C: CubeType> ExpandElementBaseInit for Tensor<C> {
359 fn init_elem(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
360 elem
362 }
363}
364
365impl<T: CubePrimitive> List<T> for Tensor<T> {
366 fn __expand_read(
367 scope: &mut Scope,
368 this: ExpandElementTyped<Tensor<T>>,
369 idx: ExpandElementTyped<u32>,
370 ) -> ExpandElementTyped<T> {
371 index::expand(scope, this, idx)
372 }
373}
374
375impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
376 fn __expand_read_method(
377 self,
378 scope: &mut Scope,
379 idx: ExpandElementTyped<u32>,
380 ) -> ExpandElementTyped<T> {
381 index::expand(scope, self, idx)
382 }
383}
384
385impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
386 fn __expand_write(
387 scope: &mut Scope,
388 this: ExpandElementTyped<Tensor<T>>,
389 idx: ExpandElementTyped<u32>,
390 value: ExpandElementTyped<T>,
391 ) {
392 index_assign::expand(scope, this, idx, value);
393 }
394}
395
396impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
397 fn __expand_write_method(
398 self,
399 scope: &mut Scope,
400 idx: ExpandElementTyped<u32>,
401 value: ExpandElementTyped<T>,
402 ) {
403 index_assign::expand(scope, self, idx, value);
404 }
405}