1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::{CubeType, Scope, *};
4use cubecl_core::{self as cubecl, unexpanded};
5
/// Access marker selecting read-only access for a [`VirtualTensor`].
///
/// Used only as the `IO` type parameter; it carries no data.
#[derive(Clone, Copy, Debug, Default)]
pub struct Read;

/// Access marker selecting read-write access for a [`VirtualTensor`].
///
/// Only `VirtualTensor<E, ReadWrite>` implements `ListMut`, so writes through a
/// read-only virtual tensor do not type-check.
#[derive(Clone, Copy, Debug, Default)]
pub struct ReadWrite;
13
14#[derive(Clone)]
16pub struct VirtualTensor<E: Numeric, IO = Read> {
17 _e: PhantomData<E>,
19 _p: PhantomData<IO>,
20}
21
22impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
23
/// Expand-time form of [`VirtualTensor`] (its `CubeType::ExpandType`).
///
/// Holds the backing tensor's expand value behind a shared trait object, and every
/// operation is forwarded to it. `IO` is the access marker (`Read` / `ReadWrite`).
#[derive(Clone)]
pub struct VirtualTensorExpand<E: Numeric, IO> {
    // Backing tensor's expand value. Cloning this struct only clones the `Arc`,
    // so all clones share one backing.
    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
    // Type-level access marker; no runtime data.
    _p: PhantomData<IO>,
}
30
31impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
32 fn __expand_read(
33 scope: &mut Scope,
34 this: VirtualTensorExpand<E, IO>,
35 index: <u32 as CubeType>::ExpandType,
36 ) -> <Line<E> as CubeType>::ExpandType {
37 this.__expand_read_method(scope, index)
38 }
39}
40
41impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
42 fn __expand_read_method(
43 &self,
44 scope: &mut Scope,
45 index: <u32 as CubeType>::ExpandType,
46 ) -> <Line<E> as CubeType>::ExpandType {
47 self.state.clone().__expand_read_method(scope, index)
48 }
49
50 fn __expand_read_unchecked_method(
51 &self,
52 _scope: &mut Scope,
53 _index: ExpandElementTyped<u32>,
54 ) -> <Line<E> as CubeType>::ExpandType {
55 todo!("VirtualTensor don't support read unchecked yet");
56 }
57}
58
59#[allow(unused, clippy::all)]
60impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
61 pub fn as_tensor_map(&self) -> TensorMap<E> {
62 unexpanded!()
63 }
64 pub fn as_slice(&self, start: u32, end: u32) -> Slice<Line<E>> {
65 unexpanded!();
66 }
67 pub fn shape(&self, axis: u32) -> u32 {
69 unexpanded!();
70 }
71 pub fn stride(&self, axis: u32) -> u32 {
73 unexpanded!();
74 }
75 pub fn rank(&self) -> u32 {
77 unexpanded!();
78 }
79
80 pub fn len(&self) -> u32 {
81 unexpanded!();
82 }
83
84 pub fn buffer_len(&self) -> u32 {
85 unexpanded!();
86 }
87
88 pub fn __expand_as_tensor_map(
89 context: &mut Scope,
90 this: <Self as CubeType>::ExpandType,
91 ) -> <TensorMap<E> as CubeType>::ExpandType {
92 this.__expand_as_tensor_map_method(context)
93 }
94 pub fn __expand_as_slice(
95 context: &mut Scope,
96 this: <Self as CubeType>::ExpandType,
97 start: <u32 as CubeType>::ExpandType,
98 end: <u32 as CubeType>::ExpandType,
99 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
100 this.__expand_as_slice_method(context, start, end)
101 }
102 pub fn __expand_shape(
103 scope: &mut Scope,
104 this: <Self as CubeType>::ExpandType,
105 axis: <u32 as CubeType>::ExpandType,
106 ) -> <u32 as CubeType>::ExpandType {
107 this.__expand_shape_method(scope, axis)
108 }
109 pub fn __expand_stride(
110 scope: &mut Scope,
111 this: <Self as CubeType>::ExpandType,
112 axis: <u32 as CubeType>::ExpandType,
113 ) -> <u32 as CubeType>::ExpandType {
114 this.__expand_stride_method(scope, axis)
115 }
116 pub fn __expand_rank(
117 scope: &mut Scope,
118 this: <Self as CubeType>::ExpandType,
119 ) -> <u32 as CubeType>::ExpandType {
120 this.__expand_rank_method(scope)
121 }
122 pub fn __expand_len(
123 scope: &mut Scope,
124 this: <Self as CubeType>::ExpandType,
125 ) -> <u32 as CubeType>::ExpandType {
126 this.__expand_len_method(scope)
127 }
128 pub fn __expand_buffer_len(
129 scope: &mut Scope,
130 this: <Self as CubeType>::ExpandType,
131 ) -> <u32 as CubeType>::ExpandType {
132 this.__expand_buffer_len_method(scope)
133 }
134}
135
136#[allow(unused, clippy::all)]
137impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
138 pub fn __expand_as_tensor_map_method(
139 self,
140 context: &mut Scope,
141 ) -> <TensorMap<E> as CubeType>::ExpandType {
142 self.state.clone().__expand_as_tensor_map_method(context)
143 }
144
145 pub fn __expand_as_slice_method(
146 self,
147 context: &mut Scope,
148 start: <u32 as CubeType>::ExpandType,
149 end: <u32 as CubeType>::ExpandType,
150 ) -> <Slice<Line<E>> as CubeType>::ExpandType {
151 self.state
152 .clone()
153 .__expand_read_window_method(context, start, end)
154 }
155
156 pub fn __expand_shape_method(
157 self,
158 scope: &mut Scope,
159 axis: <u32 as CubeType>::ExpandType,
160 ) -> <u32 as CubeType>::ExpandType {
161 let _arg_0 = axis;
162 self.state
163 .clone()
164 .__expand_shape_method(scope, _arg_0.into())
165 }
166
167 pub fn __expand_stride_method(
168 self,
169 scope: &mut Scope,
170 axis: <u32 as CubeType>::ExpandType,
171 ) -> <u32 as CubeType>::ExpandType {
172 let _arg_0 = axis;
173 self.state
174 .clone()
175 .__expand_stride_method(scope, _arg_0.into())
176 }
177
178 pub fn __expand_rank_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
179 self.state.clone().__expand_rank_method(scope)
180 }
181
182 pub fn __expand_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
183 self.state.clone().__expand_len_method(scope)
184 }
185
186 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
187 self.state.clone().__expand_buffer_len_method(scope)
188 }
189
190 pub fn __expand_read(
191 scope: &mut Scope,
192 this: Self,
193 index: <u32 as CubeType>::ExpandType,
194 ) -> <Line<E> as CubeType>::ExpandType {
195 VirtualTensor::<E, IO>::__expand_read(scope, this, index)
196 }
197
198 pub fn __expand_shape(
199 scope: &mut Scope,
200 this: Self,
201 axis: <u32 as CubeType>::ExpandType,
202 ) -> <u32 as CubeType>::ExpandType {
203 VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
204 }
205
206 pub fn __expand_stride(
207 scope: &mut Scope,
208 this: Self,
209 axis: <u32 as CubeType>::ExpandType,
210 ) -> <u32 as CubeType>::ExpandType {
211 VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
212 }
213
214 pub fn __expand_rank(scope: &mut Scope, this: Self) -> <u32 as CubeType>::ExpandType {
215 VirtualTensor::<E, IO>::__expand_rank(scope, this)
216 }
217}
218
219#[cube]
220impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
221 pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
222 let num_strides = index / self.stride(dim);
223 num_strides % self.shape(dim)
224 }
225}
226
227impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
228 fn __expand_write(
229 scope: &mut Scope,
230 this: VirtualTensorExpand<E, ReadWrite>,
231 index: <u32 as CubeType>::ExpandType,
232 value: <Line<E> as CubeType>::ExpandType,
233 ) -> <() as CubeType>::ExpandType {
234 this.__expand_write_method(scope, index, value)
235 }
236}
237
238impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
239 fn __expand_write_method(
240 &self,
241 scope: &mut Scope,
242 index: <u32 as CubeType>::ExpandType,
243 value: <Line<E> as CubeType>::ExpandType,
244 ) -> <() as CubeType>::ExpandType {
245 self.state
246 .clone()
247 .__expand_write_method(scope, index, value)
248 }
249}
250
251impl<E: Numeric> VirtualTensor<E, Read> {
252 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
254 unexpanded!()
255 }
256
257 pub fn __expand_new<V>(_scope: &mut Scope, v: V::ExpandType) -> VirtualTensorExpand<E, Read>
259 where
260 V::ExpandType: VirtualTensorOperationsExpand<E>,
261 V: VirtualTensorOperations<E> + CubeType + 'static,
262 {
263 VirtualTensorExpand {
264 state: Arc::new(v),
265 _p: PhantomData,
266 }
267 }
268}
269
270impl<E: Numeric> VirtualTensor<E, ReadWrite> {
271 pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
273 unexpanded!()
274 }
275
276 pub fn __expand_new<V>(
278 _scope: &mut Scope,
279 v: V::ExpandType,
280 ) -> VirtualTensorExpand<E, ReadWrite>
281 where
282 V::ExpandType: VirtualTensorOperationsExpand<E>,
283 V: VirtualTensorOperations<E> + CubeType + 'static,
284 {
285 VirtualTensorExpand {
286 state: Arc::new(v),
287 _p: PhantomData,
288 }
289 }
290}
291
/// Host-side surface of a tensor-like type that can back a [`VirtualTensor`].
///
/// Every method defaults to `unexpanded!()`. The implementations in this file
/// (for `Tensor<Line<E>>` and `TensorMap<E>`) are empty and simply opt in. The
/// operations actually used while expanding live on
/// [`VirtualTensorOperationsExpand`], implemented by the type's expand form.
pub trait VirtualTensorOperations<E: Numeric> {
    /// Reads the line at `_index` (signature only).
    fn read(&self, _index: u32) -> Line<E> {
        unexpanded!()
    }
    /// Writes `_value` at `_index` (signature only).
    fn write(&self, _index: u32, _value: Line<E>) {
        unexpanded!()
    }
    /// Shape along `_axis` (signature only).
    fn shape(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Stride along `_axis` (signature only).
    fn stride(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Rank (signature only).
    fn rank(&self) -> u32 {
        unexpanded!()
    }
}
323
/// Expand-time operations a backing tensor provides so that a
/// `VirtualTensorExpand` can forward to it.
///
/// `VirtualTensorExpand` stores implementors as
/// `Arc<dyn VirtualTensorOperationsExpand<E>>`, so this trait is used as a trait
/// object and every method takes `&self`.
pub trait VirtualTensorOperationsExpand<E: Numeric> {
    /// Returns the backing as a `TensorMap` expand value.
    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<E>>;
    /// Reads the line at `index`; backs `ListExpand` reads.
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<E>>;
    /// Read-only slice between `start` and `end`; backs `VirtualTensor::as_slice`.
    fn __expand_read_window_method(
        &self,
        context: &mut Scope,
        start: ExpandElementTyped<u32>,
        end: ExpandElementTyped<u32>,
    ) -> SliceExpand<Line<E>, ReadOnly>;
    /// Writes `value` at `index`; backs `ListMutExpand` writes on `ReadWrite` tensors.
    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: ExpandElementTyped<Line<E>>,
    );
    /// Shape along `axis`.
    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32>;
    /// Stride along `axis`.
    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32>;
    /// Rank of the backing.
    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
    /// Length of the backing (the `Tensor` impl forwards to the tensor's `len`).
    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
    /// Buffer length of the backing (the `Tensor` impl forwards to the tensor's `buffer_len`).
    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
}
361
// CubeCL type-system glue for `VirtualTensor`.
mod __cube_type {
    use super::*;

    // During expansion a `VirtualTensor` is represented by `VirtualTensorExpand`.
    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
        type ExpandType = VirtualTensorExpand<E, IO>;
    }

    // Returns `self` unchanged: the "mutable" value is the same handle and
    // shares the same backing `state`.
    impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
        fn into_mut(self, _scope: &mut Scope) -> Self {
            self
        }
    }

    // Empty impl: relies entirely on `CubeDebug`'s default methods.
    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
}
378
// `Tensor<Line<E>>` as a `VirtualTensor` backing.
mod __tensor {
    use super::*;

    // No host-side overrides: every method keeps its `unexpanded!()` default.
    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}

    // Each method forwards to the same-named operation on the tensor's expand value.
    // NOTE(review): the `self.clone()` receivers look load-bearing. They give an
    // owned receiver, so the call presumably resolves to the tensor's own by-value
    // expand method. With a `&Self` receiver, method resolution could pick this
    // trait's `&self` method again and recurse. Confirm before removing them.
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
        // Reads go through the tensor's *unchecked* index path.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Line<E>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        // Window reads are the tensor's `slice(start, end)`.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: ExpandElementTyped<u32>,
            end: ExpandElementTyped<u32>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        // Writes go through the tensor's *unchecked* index-assign path.
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
            value: ExpandElementTyped<Line<E>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_buffer_len_method(scope)
        }

        // A plain tensor cannot be viewed as a `TensorMap`; this always panics.
        fn __expand_as_tensor_map_method(
            &self,
            _scope: &mut Scope,
        ) -> ExpandElementTyped<TensorMap<E>> {
            unimplemented!("Can't turn normal tensor into `TensorMap`");
        }
    }
}
445
// `TensorMap<E>` as a `VirtualTensor` backing.
//
// Only `as_tensor_map` is implemented. Every other operation is a `todo!()` and
// panics if reached during expansion.
mod __tensor_map {
    use super::*;

    // No host-side overrides: every method keeps its `unexpanded!()` default.
    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E> {}
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E>> {
        // Not yet supported for tensor maps.
        fn __expand_read_method(
            &self,
            _scope: &mut Scope,
            _index: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Line<E>> {
            todo!()
        }
        // Not yet supported for tensor maps.
        fn __expand_read_window_method(
            &self,
            _context: &mut Scope,
            _start: ExpandElementTyped<u32>,
            _end: ExpandElementTyped<u32>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            todo!()
        }

        // Not yet supported for tensor maps.
        fn __expand_write_method(
            &self,
            _scope: &mut Scope,
            _index: ExpandElementTyped<u32>,
            _value: ExpandElementTyped<Line<E>>,
        ) {
            todo!()
        }

        // Not yet supported for tensor maps.
        fn __expand_shape_method(
            &self,
            _scope: &mut Scope,
            _axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            todo!()
        }

        // Not yet supported for tensor maps.
        fn __expand_stride_method(
            &self,
            _scope: &mut Scope,
            _axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            todo!()
        }

        // Not yet supported for tensor maps.
        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
            todo!()
        }
        // Not yet supported for tensor maps.
        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
            todo!()
        }
        // Not yet supported for tensor maps.
        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
            todo!()
        }

        // The backing already is a tensor map: return a clone of its expand value.
        fn __expand_as_tensor_map_method(
            &self,
            _scope: &mut Scope,
        ) -> ExpandElementTyped<TensorMap<E>> {
            self.clone()
        }
    }
}