1use cubecl::prelude::{CubeType, Scope, *};
2use cubecl_core::{self as cubecl, unexpanded};
3use std::{marker::PhantomData, sync::Arc};
4
/// Access tag for a [`VirtualTensor`] that can only be read.
///
/// Zero-sized marker; it carries no data and is only used as the `IO` type parameter.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Read;

/// Access tag for a [`VirtualTensor`] that can be both read and written.
///
/// Zero-sized marker; write support (`ListMut`) is only implemented for tensors tagged with it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ReadWrite;
12
/// Storage-agnostic tensor handle for kernel code.
///
/// The struct itself holds no data, only `PhantomData` markers. At expand time it is
/// represented by `VirtualTensorExpand`, which owns the type-erased backing. `IO` is an
/// access tag: `Read` by default, or `ReadWrite`. Only `ReadWrite` tensors implement
/// `ListMut`.
#[derive(Clone)]
pub struct VirtualTensor<E: Numeric, IO = Read> {
    // Element-type marker.
    _e: PhantomData<E>,
    // Access-tag marker (`Read` / `ReadWrite`).
    _p: PhantomData<IO>,
}
20
// Hand-written rather than `#[derive(Copy)]`: the derive would add `E: Copy` and `IO: Copy`
// bounds, while this impl only asks for `IO: Clone`. The struct holds nothing but
// `PhantomData`, so copying it is trivially sound.
impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
22
/// Expand-time representation of a [`VirtualTensor`].
///
/// The concrete backing's expand handle is stored behind a shared trait object, so
/// different storage kinds share this single type. Cloning only clones the `Arc`
/// (and the zero-sized marker), not the backing itself.
#[derive(Clone)]
pub struct VirtualTensorExpand<E: Numeric, IO> {
    // Type-erased backing that every `__expand_*` method delegates to.
    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
    // Access-tag marker mirroring `VirtualTensor`'s `IO` parameter.
    _p: PhantomData<IO>,
}
29
30impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
31 fn __expand_read(
32 scope: &mut Scope,
33 this: VirtualTensorExpand<E, IO>,
34 index: <u32 as CubeType>::ExpandType,
35 ) -> <Line<E> as CubeType>::ExpandType {
36 this.__expand_read_method(scope, index)
37 }
38}
39
40impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
41 fn __expand_read_method(
42 self,
43 scope: &mut Scope,
44 index: <u32 as CubeType>::ExpandType,
45 ) -> <Line<E> as CubeType>::ExpandType {
46 self.state.clone().__expand_read_method(scope, index)
47 }
48}
49
50#[allow(unused, clippy::all)]
51impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
52 pub fn as_tensor_map(&self) -> TensorMap<E> {
53 unexpanded!()
54 }
55 pub fn as_slice(&self, start: u32, end: u32) -> Slice<Line<E>> {
56 unexpanded!();
57 }
58 pub fn shape(&self, axis: u32) -> u32 {
60 unexpanded!();
61 }
62 pub fn stride(&self, axis: u32) -> u32 {
64 unexpanded!();
65 }
66 pub fn rank(&self) -> u32 {
68 unexpanded!();
69 }
70
71 pub fn len(&self) -> u32 {
72 unexpanded!();
73 }
74
75 pub fn buffer_len(&self) -> u32 {
76 unexpanded!();
77 }
78
79 pub fn __expand_as_tensor_map(
80 context: &mut Scope,
81 this: <Self as CubeType>::ExpandType,
82 ) -> <TensorMap<E> as CubeType>::ExpandType {
83 this.__expand_as_tensor_map_method(context)
84 }
85 pub fn __expand_as_slice(
86 context: &mut Scope,
87 this: <Self as CubeType>::ExpandType,
88 start: <u32 as CubeType>::ExpandType,
89 end: <u32 as CubeType>::ExpandType,
90 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
91 this.__expand_as_slice_method(context, start, end)
92 }
93 pub fn __expand_shape(
94 scope: &mut Scope,
95 this: <Self as CubeType>::ExpandType,
96 axis: <u32 as CubeType>::ExpandType,
97 ) -> <u32 as CubeType>::ExpandType {
98 this.__expand_shape_method(scope, axis)
99 }
100 pub fn __expand_stride(
101 scope: &mut Scope,
102 this: <Self as CubeType>::ExpandType,
103 axis: <u32 as CubeType>::ExpandType,
104 ) -> <u32 as CubeType>::ExpandType {
105 this.__expand_stride_method(scope, axis)
106 }
107 pub fn __expand_rank(
108 scope: &mut Scope,
109 this: <Self as CubeType>::ExpandType,
110 ) -> <u32 as CubeType>::ExpandType {
111 this.__expand_rank_method(scope)
112 }
113 pub fn __expand_len(
114 scope: &mut Scope,
115 this: <Self as CubeType>::ExpandType,
116 ) -> <u32 as CubeType>::ExpandType {
117 this.__expand_len_method(scope)
118 }
119 pub fn __expand_buffer_len(
120 scope: &mut Scope,
121 this: <Self as CubeType>::ExpandType,
122 ) -> <u32 as CubeType>::ExpandType {
123 this.__expand_buffer_len_method(scope)
124 }
125}
126
127#[allow(unused, clippy::all)]
128impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
129 pub fn __expand_as_tensor_map_method(
130 self,
131 context: &mut Scope,
132 ) -> <TensorMap<E> as CubeType>::ExpandType {
133 self.state.clone().__expand_as_tensor_map_method(context)
134 }
135
136 pub fn __expand_as_slice_method(
137 self,
138 context: &mut Scope,
139 start: <u32 as CubeType>::ExpandType,
140 end: <u32 as CubeType>::ExpandType,
141 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
142 self.state
143 .clone()
144 .__expand_read_window_method(context, start, end)
145 }
146
147 pub fn __expand_shape_method(
148 self,
149 scope: &mut Scope,
150 axis: <u32 as CubeType>::ExpandType,
151 ) -> <u32 as CubeType>::ExpandType {
152 let _arg_0 = axis;
153 self.state
154 .clone()
155 .__expand_shape_method(scope, _arg_0.into())
156 }
157
158 pub fn __expand_stride_method(
159 self,
160 scope: &mut Scope,
161 axis: <u32 as CubeType>::ExpandType,
162 ) -> <u32 as CubeType>::ExpandType {
163 let _arg_0 = axis;
164 self.state
165 .clone()
166 .__expand_stride_method(scope, _arg_0.into())
167 }
168
169 pub fn __expand_rank_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
170 self.state.clone().__expand_rank_method(scope)
171 }
172
173 pub fn __expand_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
174 self.state.clone().__expand_len_method(scope)
175 }
176
177 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
178 self.state.clone().__expand_buffer_len_method(scope)
179 }
180
181 pub fn __expand_read(
182 scope: &mut Scope,
183 this: Self,
184 index: <u32 as CubeType>::ExpandType,
185 ) -> <Line<E> as CubeType>::ExpandType {
186 VirtualTensor::<E, IO>::__expand_read(scope, this, index)
187 }
188
189 pub fn __expand_shape(
190 scope: &mut Scope,
191 this: Self,
192 axis: <u32 as CubeType>::ExpandType,
193 ) -> <u32 as CubeType>::ExpandType {
194 VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
195 }
196
197 pub fn __expand_stride(
198 scope: &mut Scope,
199 this: Self,
200 axis: <u32 as CubeType>::ExpandType,
201 ) -> <u32 as CubeType>::ExpandType {
202 VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
203 }
204
205 pub fn __expand_rank(scope: &mut Scope, this: Self) -> <u32 as CubeType>::ExpandType {
206 VirtualTensor::<E, IO>::__expand_rank(scope, this)
207 }
208}
209
210#[cube]
211impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
212 pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
213 let num_strides = index / self.stride(dim);
214 num_strides % self.shape(dim)
215 }
216}
217
218impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
219 fn __expand_write(
220 scope: &mut Scope,
221 this: VirtualTensorExpand<E, ReadWrite>,
222 index: <u32 as CubeType>::ExpandType,
223 value: <Line<E> as CubeType>::ExpandType,
224 ) -> <() as CubeType>::ExpandType {
225 this.__expand_write_method(scope, index, value)
226 }
227}
228
229impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
230 fn __expand_write_method(
231 self,
232 scope: &mut Scope,
233 index: <u32 as CubeType>::ExpandType,
234 value: <Line<E> as CubeType>::ExpandType,
235 ) -> <() as CubeType>::ExpandType {
236 self.state
237 .clone()
238 .__expand_write_method(scope, index, value)
239 }
240}
241
242impl<E: Numeric> VirtualTensor<E, Read> {
243 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
245 unexpanded!()
246 }
247
248 pub fn __expand_new<V>(_scope: &mut Scope, v: V::ExpandType) -> VirtualTensorExpand<E, Read>
250 where
251 V::ExpandType: VirtualTensorOperationsExpand<E>,
252 V: VirtualTensorOperations<E> + CubeType + 'static,
253 {
254 VirtualTensorExpand {
255 state: Arc::new(v),
256 _p: PhantomData,
257 }
258 }
259}
260
261impl<E: Numeric> VirtualTensor<E, ReadWrite> {
262 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
264 unexpanded!()
265 }
266
267 pub fn __expand_new<V>(
269 _scope: &mut Scope,
270 v: V::ExpandType,
271 ) -> VirtualTensorExpand<E, ReadWrite>
272 where
273 V::ExpandType: VirtualTensorOperationsExpand<E>,
274 V: VirtualTensorOperations<E> + CubeType + 'static,
275 {
276 VirtualTensorExpand {
277 state: Arc::new(v),
278 _p: PhantomData,
279 }
280 }
281}
282
/// Storage types that can back a [`VirtualTensor`].
///
/// Every method here is a kernel-side placeholder whose default body is `unexpanded!()`.
/// The working implementation is provided by [`VirtualTensorOperationsExpand`] on the
/// backing's expand type; `VirtualTensor::__expand_new` requires both.
pub trait VirtualTensorOperations<E: Numeric> {
    /// Placeholder for reading the line at `_index`.
    fn read(&self, _index: u32) -> Line<E> {
        unexpanded!()
    }
    /// Placeholder for writing `_value` at `_index`.
    fn write(&self, _index: u32, _value: Line<E>) {
        unexpanded!()
    }
    /// Placeholder for the extent of dimension `_axis`.
    fn shape(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Placeholder for the stride of dimension `_axis`.
    fn stride(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Placeholder for the number of dimensions.
    fn rank(&self) -> u32 {
        unexpanded!()
    }
}
314
/// Expand-time operations a [`VirtualTensor`] backing must provide.
///
/// `VirtualTensorExpand` stores implementors as `Arc<dyn VirtualTensorOperationsExpand<E>>`,
/// so this trait must stay object-safe (no generic methods, no `Self`-by-value receivers).
pub trait VirtualTensorOperationsExpand<E: Numeric> {
    /// Returns the backing as a `TensorMap` expand handle, where the backing supports it.
    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<E>>;
    /// Expand-time read of the line at `index`; backs `ListExpand::__expand_read_method`.
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<E>>;
    /// Expand-time window from `start` to `end`; backs `VirtualTensor::as_slice`.
    fn __expand_read_window_method(
        &self,
        context: &mut Scope,
        start: ExpandElementTyped<u32>,
        end: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Slice<Line<E>>>;
    /// Expand-time write of `value` at `index`; only reachable through `ReadWrite` tensors.
    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: ExpandElementTyped<Line<E>>,
    );
    /// Expand-time extent of dimension `axis`.
    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32>;
    /// Expand-time stride of dimension `axis`.
    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32>;
    /// Expand-time number of dimensions.
    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
    /// Expand-time length as reported by the backing.
    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
    /// Expand-time buffer length as reported by the backing.
    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
}
352
353mod __cube_type {
355 use super::*;
356
357 impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
358 type ExpandType = VirtualTensorExpand<E, IO>;
359 }
360
361 impl<E: Numeric, IO> Init for VirtualTensorExpand<E, IO> {
362 fn init(self, _scope: &mut Scope) -> Self {
363 self
364 }
365 }
366
367 impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
368}
369
// Lets a plain `Tensor<Line<E>>` back a `VirtualTensor`.
//
// NOTE(review): every `self.clone()` below appears to be load-bearing. It produces an
// owned `ExpandElementTyped` receiver, so method resolution presumably selects the tensor's
// own `__expand_*_method` rather than re-entering the `&self` trait method of the same
// name, which would recurse. Do not "simplify" it away without checking.
mod __tensor {
    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
        // Uses the `_unchecked` indexing variant; presumably no bounds check is emitted,
        // so callers must keep `index` in range — TODO confirm.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Line<E>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        // Window reads map onto the tensor's own slicing between `start` and `end`.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: ExpandElementTyped<u32>,
            end: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Slice<Line<E>>> {
            self.clone().__expand_slice_method(context, start, end)
        }

        // Writes use the `_unchecked` index-assign variant (see the read note above).
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
            value: ExpandElementTyped<Line<E>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_buffer_len_method(scope)
        }

        // A plain tensor has no `TensorMap` form; reaching this at expand time panics.
        fn __expand_as_tensor_map_method(
            &self,
            _scope: &mut Scope,
        ) -> ExpandElementTyped<TensorMap<E>> {
            unimplemented!("Can't turn normal tensor into `TensorMap`");
        }
    }
}
436
437mod __tensor_map {
439 use super::*;
440
441 impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E> {}
442 impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E>> {
443 fn __expand_read_method(
444 &self,
445 _scope: &mut Scope,
446 _index: ExpandElementTyped<u32>,
447 ) -> ExpandElementTyped<Line<E>> {
448 todo!()
449 }
450 fn __expand_read_window_method(
451 &self,
452 _context: &mut Scope,
453 _start: ExpandElementTyped<u32>,
454 _end: ExpandElementTyped<u32>,
455 ) -> ExpandElementTyped<Slice<Line<E>>> {
456 todo!()
457 }
458
459 fn __expand_write_method(
460 &self,
461 _scope: &mut Scope,
462 _index: ExpandElementTyped<u32>,
463 _value: ExpandElementTyped<Line<E>>,
464 ) {
465 todo!()
466 }
467
468 fn __expand_shape_method(
469 &self,
470 _scope: &mut Scope,
471 _axis: ExpandElementTyped<u32>,
472 ) -> ExpandElementTyped<u32> {
473 todo!()
474 }
475
476 fn __expand_stride_method(
477 &self,
478 _scope: &mut Scope,
479 _axis: ExpandElementTyped<u32>,
480 ) -> ExpandElementTyped<u32> {
481 todo!()
482 }
483
484 fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
485 todo!()
486 }
487 fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
488 todo!()
489 }
490 fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
491 todo!()
492 }
493
494 fn __expand_as_tensor_map_method(
495 &self,
496 _scope: &mut Scope,
497 ) -> ExpandElementTyped<TensorMap<E>> {
498 self.clone()
499 }
500 }
501}