cubecl_core/frontend/container/tensor/
base.rs1use crate::frontend::{ExpandElementBaseInit, ExpandElementTyped, SizedContainer};
2use crate::prelude::IntoRuntime;
3use crate::{
4 frontend::{indexation::Index, CubeContext, CubePrimitive, CubeType, ExpandElement},
5 ir::{Item, Metadata},
6 prelude::Line,
7 unexpanded,
8};
9use std::{marker::PhantomData, num::NonZero};
10
11#[derive(new)]
14pub struct Tensor<T: CubeType> {
15 _val: PhantomData<T>,
16}
17
18mod metadata {
20 use super::*;
21 use crate::{
22 ir::{BinaryOperator, Instruction, Operator},
23 prelude::Array,
24 };
25
26 impl<T: CubeType> Tensor<T> {
27 pub fn stride<C: Index>(&self, _dim: C) -> u32 {
29 unexpanded!()
30 }
31
32 pub fn shape<C: Index>(&self, _dim: C) -> u32 {
34 unexpanded!()
35 }
36
37 pub fn coordinate<I: Index, D: Index>(&self, _index: I, _dim: D) -> u32 {
42 unexpanded!()
43 }
44
45 #[allow(clippy::len_without_is_empty)]
52 pub fn len(&self) -> u32 {
53 unexpanded!()
54 }
55
56 #[allow(clippy::len_without_is_empty)]
63 pub fn buffer_len(&self) -> u32 {
64 unexpanded!()
65 }
66
67 pub fn rank(&self) -> u32 {
69 unexpanded!()
70 }
71
72 pub fn __expand_stride<C: Index>(
74 context: &mut CubeContext,
75 expand: ExpandElementTyped<Tensor<T>>,
76 dim: ExpandElementTyped<u32>,
77 ) -> ExpandElementTyped<u32> {
78 expand.__expand_stride_method(context, dim)
79 }
80
81 pub fn __expand_shape<C: Index>(
83 context: &mut CubeContext,
84 expand: ExpandElementTyped<Tensor<T>>,
85 dim: ExpandElementTyped<u32>,
86 ) -> ExpandElementTyped<u32> {
87 expand.__expand_shape_method(context, dim)
88 }
89
90 pub fn __expand_coordinate<I: Index, D: Index>(
92 context: &mut CubeContext,
93 expand: ExpandElementTyped<Tensor<T>>,
94 index: ExpandElementTyped<u32>,
95 dim: ExpandElementTyped<u32>,
96 ) -> ExpandElementTyped<u32> {
97 expand.__expand_coordinate_method(context, index, dim)
98 }
99
100 pub fn __expand_len<C: Index>(
102 context: &mut CubeContext,
103 expand: ExpandElementTyped<Tensor<T>>,
104 ) -> ExpandElementTyped<u32> {
105 expand.__expand_len_method(context)
106 }
107
108 pub fn __expand_buffer_len<C: Index>(
110 context: &mut CubeContext,
111 expand: ExpandElementTyped<Tensor<T>>,
112 ) -> ExpandElementTyped<u32> {
113 expand.__expand_buffer_len_method(context)
114 }
115
116 pub fn __expand_rank<C: Index>(
118 context: &mut CubeContext,
119 expand: ExpandElementTyped<Tensor<T>>,
120 ) -> ExpandElementTyped<u32> {
121 expand.__expand_rank_method(context)
122 }
123 }
124
125 impl<T: CubeType> ExpandElementTyped<Tensor<T>> {
126 pub fn __expand_stride_method(
128 self,
129 context: &mut CubeContext,
130 dim: ExpandElementTyped<u32>,
131 ) -> ExpandElementTyped<u32> {
132 let dim: ExpandElement = dim.into();
133 let out = context.create_local(Item::new(u32::as_elem(context)));
134 context.register(Instruction::new(
135 Metadata::Stride {
136 dim: *dim,
137 var: self.expand.into(),
138 },
139 out.clone().into(),
140 ));
141 out.into()
142 }
143
144 pub fn __expand_shape_method(
146 self,
147 context: &mut CubeContext,
148 dim: ExpandElementTyped<u32>,
149 ) -> ExpandElementTyped<u32> {
150 let dim: ExpandElement = dim.into();
151 let out = context.create_local(Item::new(u32::as_elem(context)));
152 context.register(Instruction::new(
153 Metadata::Shape {
154 dim: *dim,
155 var: self.expand.into(),
156 },
157 out.clone().into(),
158 ));
159 out.into()
160 }
161
162 pub fn __expand_coordinate_method(
164 self,
165 context: &mut CubeContext,
166 index: ExpandElementTyped<u32>,
167 dim: ExpandElementTyped<u32>,
168 ) -> ExpandElementTyped<u32> {
169 let index: ExpandElement = index.into();
170 let stride = self.clone().__expand_stride_method(context, dim.clone());
171 let shape = self.clone().__expand_shape_method(context, dim.clone());
172
173 let num_strides = context.create_local(Item::new(u32::as_elem(context)));
175 context.register(Instruction::new(
176 Operator::Div(BinaryOperator {
177 lhs: *index,
178 rhs: stride.expand.into(),
179 }),
180 num_strides.clone().into(),
181 ));
182
183 let coordinate = context.create_local(Item::new(u32::as_elem(context)));
185 context.register(Instruction::new(
186 Operator::Modulo(BinaryOperator {
187 lhs: *num_strides,
188 rhs: shape.expand.into(),
189 }),
190 coordinate.clone().into(),
191 ));
192
193 coordinate.into()
194 }
195
196 pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
198 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
199 elem.__expand_len_method(context)
200 }
201
202 pub fn __expand_buffer_len_method(
204 self,
205 context: &mut CubeContext,
206 ) -> ExpandElementTyped<u32> {
207 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
208 elem.__expand_buffer_len_method(context)
209 }
210
211 pub fn __expand_rank_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
213 let out = context.create_local(Item::new(u32::as_elem(context)));
214 context.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
215 out.into()
216 }
217 }
218}
219
220mod indexation {
222 use crate::{
223 ir::{BinaryOperator, Instruction, Operator},
224 prelude::{CubeIndex, CubeIndexMut},
225 };
226
227 use super::*;
228
229 impl<E: CubePrimitive> Tensor<E> {
230 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
236 where
237 Self: CubeIndex<I>,
238 {
239 unexpanded!()
240 }
241
242 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
248 where
249 Self: CubeIndexMut<I>,
250 {
251 unexpanded!()
252 }
253 }
254
255 impl<E: CubePrimitive> ExpandElementTyped<Tensor<E>> {
256 pub fn __expand_index_unchecked_method(
257 self,
258 context: &mut CubeContext,
259 i: ExpandElementTyped<u32>,
260 ) -> ExpandElementTyped<E> {
261 let out = context.create_local(self.expand.item);
262 context.register(Instruction::new(
263 Operator::UncheckedIndex(BinaryOperator {
264 lhs: *self.expand,
265 rhs: i.expand.consume(),
266 }),
267 *out,
268 ));
269 out.into()
270 }
271
272 pub fn __expand_index_assign_unchecked_method(
273 self,
274 context: &mut CubeContext,
275 i: ExpandElementTyped<u32>,
276 value: ExpandElementTyped<E>,
277 ) {
278 context.register(Instruction::new(
279 Operator::UncheckedIndexAssign(BinaryOperator {
280 lhs: i.expand.consume(),
281 rhs: value.expand.consume(),
282 }),
283 *self.expand,
284 ));
285 }
286 }
287}
288
289mod line {
291 use super::*;
292
293 impl<P: CubePrimitive> Tensor<Line<P>> {
294 pub fn line_size(&self) -> u32 {
302 unexpanded!()
303 }
304
305 pub fn __expand_line_size(
307 expand: <Self as CubeType>::ExpandType,
308 context: &mut CubeContext,
309 ) -> u32 {
310 expand.__expand_line_size_method(context)
311 }
312 }
313
314 impl<P: CubePrimitive> ExpandElementTyped<Tensor<Line<P>>> {
315 pub fn line_size(&self) -> u32 {
317 self.expand
318 .item
319 .vectorization
320 .unwrap_or(NonZero::new(1).unwrap())
321 .get() as u32
322 }
323
324 pub fn __expand_line_size_method(&self, _content: &mut CubeContext) -> u32 {
326 self.line_size()
327 }
328 }
329}
330
331impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Tensor<T> {
332 type Item = T;
333}
334
335impl<T: CubeType> Iterator for &Tensor<T> {
336 type Item = T;
337
338 fn next(&mut self) -> Option<Self::Item> {
339 unexpanded!()
340 }
341}
342
343impl<T: CubeType> CubeType for Tensor<T> {
344 type ExpandType = ExpandElementTyped<Tensor<T>>;
345}
346
347impl<T: CubeType> CubeType for *const Tensor<T> {
348 type ExpandType = ExpandElementTyped<Tensor<T>>;
349}
350
351impl<T: CubeType> CubeType for *mut Tensor<T> {
352 type ExpandType = ExpandElementTyped<Tensor<T>>;
353}
354
355impl<C: CubeType> ExpandElementBaseInit for Tensor<C> {
356 fn init_elem(_context: &mut crate::prelude::CubeContext, elem: ExpandElement) -> ExpandElement {
357 elem
359 }
360}
361
362impl<E: CubePrimitive> IntoRuntime for Tensor<E> {
363 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
364 unimplemented!("Tensor can't exist at compile time")
365 }
366}
367
368impl<E: CubePrimitive> IntoRuntime for *const Tensor<E> {
369 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
370 unimplemented!("Tensor can't exist at compile time")
371 }
372}
373
374impl<E: CubePrimitive> IntoRuntime for *mut Tensor<E> {
375 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
376 unimplemented!("Tensor can't exist at compile time")
377 }
378}