1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use std::ops::{Deref, DerefMut};
6
7use crate::tensor::{
8 ViewExpand,
9 layout::{
10 Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
11 },
12 view::View,
13};
14
/// Type-erased handle over any tensor-like container of `Vector<E, N>` elements.
///
/// The host type carries no runtime state of its own (only `PhantomData`); all
/// behavior lives behind the expand type's erased operations trait object.
/// `IO` is a marker (`ReadOnly` by default, or `ReadWrite`) restricting which
/// operations are available.
#[derive(Clone)]
pub struct VirtualTensor<E: Numeric, N: Size, IO = ReadOnly> {
    _e: PhantomData<E>,
    _n: PhantomData<N>,
    _p: PhantomData<IO>,
}
23
// `VirtualTensor` holds only `PhantomData`, so it is trivially copyable.
impl<E: Numeric, N: Size, IO: Clone> Copy for VirtualTensor<E, N, IO> {}
25
/// Expand (compile-time) counterpart of [`VirtualTensor`].
///
/// Wraps the concrete tensor's expand state behind a shared trait object so
/// that different backing containers can be used interchangeably.
#[derive(Clone)]
pub struct VirtualTensorExpand<E: Numeric, N: Size, IO> {
    // Erased operations of the underlying tensor; `Arc` makes cloning cheap.
    state: Arc<dyn VirtualTensorOperationsExpand<E, N>>,
    _p: PhantomData<IO>,
}
32
// Read access as a flat list of vectors; delegates to the expand type's method form.
impl<E: Numeric, N: Size, IO: Clone> List<Vector<E, N>> for VirtualTensor<E, N, IO> {
    fn __expand_read(
        scope: &mut Scope,
        this: VirtualTensorExpand<E, N, IO>,
        index: <usize as CubeType>::ExpandType,
    ) -> <Vector<E, N> as CubeType>::ExpandType {
        this.__expand_read_method(scope, index)
    }
}
42
impl<T: Numeric, N: Size, IO: Clone> Deref for VirtualTensor<T, N, IO> {
    type Target = [Vector<T, N>];

    // Host-side stub: only meaningful inside a kernel expansion, never executed on host.
    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
50
// Mutable dereference is only offered for writable tensors; host-side stub.
impl<T: Numeric, N: Size> DerefMut for VirtualTensor<T, N, ReadWrite> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
56
57impl<E: Numeric, N: Size, IO: Clone> ListExpand<Vector<E, N>> for VirtualTensorExpand<E, N, IO> {
58 fn __expand_read_method(
59 &self,
60 scope: &mut Scope,
61 index: <usize as CubeType>::ExpandType,
62 ) -> <Vector<E, N> as CubeType>::ExpandType {
63 self.state.clone().__expand_read_method(scope, index)
64 }
65
66 fn __expand_read_unchecked_method(
67 &self,
68 _scope: &mut Scope,
69 _index: NativeExpand<usize>,
70 ) -> <Vector<E, N> as CubeType>::ExpandType {
71 todo!("VirtualTensor don't support read unchecked yet");
72 }
73
74 fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
75 self.state.clone().__expand_len_method(scope)
76 }
77}
78
// Vectorization metadata is supplied by the erased state.
impl<E: Numeric, N: Size, IO: Clone> Vectorized for VirtualTensor<E, N, IO> {}
impl<E: Numeric, N: Size, IO: Clone> VectorizedExpand for VirtualTensorExpand<E, N, IO> {
    fn vector_size(&self) -> VectorSize {
        self.state.clone().vector_size()
    }
}
85
// Read-only slicing: windows are obtained through the erased state's
// `read_window` primitive.
impl<E: Numeric, N: Size, IO: Clone> SliceOperator<Vector<E, N>> for VirtualTensor<E, N, IO> {}
impl<E: Numeric, N: Size, IO: Clone> SliceOperatorExpand<Vector<E, N>>
    for VirtualTensorExpand<E, N, IO>
{
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<Vector<E, N>, ReadOnly> {
        self.state
            .clone()
            .__expand_read_window_method(scope, start, end)
    }

    // Full-range slice, covering `0..buffer_len`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Vector<E, N>, ReadOnly> {
        let end = self.clone().__expand_buffer_len_method(scope);
        self.state
            .clone()
            .__expand_read_window_method(scope, 0.into(), end)
    }
}
108
#[allow(unused, clippy::all)]
impl<E: Numeric, N: Size, IO: Clone> VirtualTensor<E, N, IO> {
    /// Underlying TMA descriptor, if the backing container is a tensor map.
    /// Host-side stub; note the expand form returns a `ComptimeOption` instead.
    pub fn as_tensor_map(&self) -> Option<TensorMap<E, Tiled>> {
        unexpanded!()
    }
    /// Read-only slice over the linear element range `start..end`.
    pub fn as_slice(&self, start: usize, end: usize) -> Slice<Vector<E, N>> {
        unexpanded!();
    }
    /// Size of the tensor along `axis`.
    pub fn shape(&self, axis: usize) -> usize {
        unexpanded!();
    }
    /// Stride of the tensor along `axis`.
    pub fn stride(&self, axis: usize) -> usize {
        unexpanded!();
    }
    /// Number of axes of the tensor.
    pub fn rank(&self) -> usize {
        unexpanded!();
    }

    /// Length of the underlying storage buffer, as reported by the erased state.
    pub fn buffer_len(&self) -> usize {
        unexpanded!();
    }

    // Associated expand functions below simply delegate to the corresponding
    // `*_method` on the expand type; this is the calling convention the cube
    // macro emits for method calls.
    pub fn __expand_as_tensor_map(
        context: &mut Scope,
        this: <Self as CubeType>::ExpandType,
    ) -> <ComptimeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
        this.__expand_as_tensor_map_method(context)
    }
    pub fn __expand_as_slice(
        context: &mut Scope,
        this: <Self as CubeType>::ExpandType,
        start: <usize as CubeType>::ExpandType,
        end: <usize as CubeType>::ExpandType,
    ) -> <Slice<Vector<E, N>> as CubeType>::ExpandType {
        this.__expand_as_slice_method(context, start, end)
    }
    pub fn __expand_shape(
        scope: &mut Scope,
        this: <Self as CubeType>::ExpandType,
        axis: <usize as CubeType>::ExpandType,
    ) -> <usize as CubeType>::ExpandType {
        this.__expand_shape_method(scope, axis)
    }
    pub fn __expand_stride(
        scope: &mut Scope,
        this: <Self as CubeType>::ExpandType,
        axis: <usize as CubeType>::ExpandType,
    ) -> <usize as CubeType>::ExpandType {
        this.__expand_stride_method(scope, axis)
    }
    pub fn __expand_rank(
        scope: &mut Scope,
        this: <Self as CubeType>::ExpandType,
    ) -> <usize as CubeType>::ExpandType {
        this.__expand_rank_method(scope)
    }
    pub fn __expand_buffer_len(
        scope: &mut Scope,
        this: <Self as CubeType>::ExpandType,
    ) -> <usize as CubeType>::ExpandType {
        this.__expand_buffer_len_method(scope)
    }
}
175
176#[allow(unused, clippy::all)]
177impl<E: Numeric, N: Size, IO: Clone> VirtualTensorExpand<E, N, IO> {
178 pub fn __expand_as_tensor_map_method(
179 self,
180 context: &mut Scope,
181 ) -> <ComptimeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
182 self.state.clone().__expand_as_tensor_map_method(context)
183 }
184
185 pub fn __expand_as_slice_method(
186 self,
187 context: &mut Scope,
188 start: <usize as CubeType>::ExpandType,
189 end: <usize as CubeType>::ExpandType,
190 ) -> <Slice<Vector<E, N>> as CubeType>::ExpandType {
191 self.state
192 .clone()
193 .__expand_read_window_method(context, start, end)
194 }
195
196 pub fn __expand_shape_method(
197 self,
198 scope: &mut Scope,
199 axis: <usize as CubeType>::ExpandType,
200 ) -> <usize as CubeType>::ExpandType {
201 let _arg_0 = axis;
202 self.state
203 .clone()
204 .__expand_shape_method(scope, _arg_0.into())
205 }
206
207 pub fn __expand_stride_method(
208 self,
209 scope: &mut Scope,
210 axis: <usize as CubeType>::ExpandType,
211 ) -> <usize as CubeType>::ExpandType {
212 let _arg_0 = axis;
213 self.state
214 .clone()
215 .__expand_stride_method(scope, _arg_0.into())
216 }
217
218 pub fn __expand_rank_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
219 self.state.clone().__expand_rank_method(scope)
220 }
221
222 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
223 self.state.clone().__expand_buffer_len_method(scope)
224 }
225
226 pub fn __expand_read(
227 scope: &mut Scope,
228 this: Self,
229 index: <usize as CubeType>::ExpandType,
230 ) -> <Vector<E, N> as CubeType>::ExpandType {
231 VirtualTensor::<E, N, IO>::__expand_read(scope, this, index)
232 }
233
234 pub fn __expand_shape(
235 scope: &mut Scope,
236 this: Self,
237 axis: <usize as CubeType>::ExpandType,
238 ) -> <usize as CubeType>::ExpandType {
239 VirtualTensor::<E, N, IO>::__expand_shape(scope, this, axis)
240 }
241
242 pub fn __expand_stride(
243 scope: &mut Scope,
244 this: Self,
245 axis: <usize as CubeType>::ExpandType,
246 ) -> <usize as CubeType>::ExpandType {
247 VirtualTensor::<E, N, IO>::__expand_stride(scope, this, axis)
248 }
249
250 pub fn __expand_rank(scope: &mut Scope, this: Self) -> <usize as CubeType>::ExpandType {
251 VirtualTensor::<E, N, IO>::__expand_rank(scope, this)
252 }
253}
254
impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensor<E, N, IO> {
    /// Create a read-only conceptual view over this tensor, allowing
    /// multi-dimensional indexing through a custom layout that maps `C`
    /// coordinates to the tensor's linear (1D) storage.
    pub fn view<C: Coordinates + 'static>(
        &self,
        layout: impl Into<VirtualLayout<C, Coords1d>>,
    ) -> View<Vector<E, N>, C, ReadOnly> {
        View::new::<VirtualTensor<E, N, IO>, Coords1d>(self, layout)
    }
}
265
#[cube]
impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensor<E, N, IO> {
    // View of the whole tensor as a flat 1D range, using a `SimpleLayout` over
    // `len * vector_size` scalar elements.
    pub fn as_view(&self) -> View<Vector<E, N>, usize, ReadOnly> {
        let vector_size = self.vector_size();
        View::new::<VirtualTensor<E, N, IO>, usize>(
            self,
            SimpleLayout::new(self.len() * vector_size, vector_size),
        )
    }
}
277
impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensorExpand<E, N, IO> {
    /// Expand of [`VirtualTensor::view`].
    pub fn __expand_view_method<C: Coordinates + 'static>(
        &self,
        scope: &mut Scope,
        layout: VirtualLayoutExpand<C, Coords1d>,
    ) -> ViewExpand<Vector<E, N>, C, ReadOnly> {
        View::__expand_new::<VirtualTensor<E, N, IO>, Coords1d>(scope, self.clone(), layout)
    }
}
289
impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadWrite> {
    /// Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing
    /// with custom layouts
    pub fn view_mut<C: Coordinates + 'static>(
        &self,
        layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
    ) -> View<Vector<E, N>, C, ReadWrite> {
        // `VirtualTensor` is `Copy` and stateless on the host, so a local copy
        // exists only to satisfy `View::new_mut`'s `&mut` receiver.
        let mut this: VirtualTensor<E, N, ReadWrite> = *self;
        View::new_mut::<VirtualTensor<E, N, ReadWrite>, Coords1d>(&mut this, layout)
    }
    /// Expand of [`Self::view_mut`].
    pub fn __expand_view_mut<C: Coordinates + 'static>(
        scope: &mut Scope,
        this: VirtualTensorExpand<E, N, ReadWrite>,
        layout: VirtualLayoutExpand<C, Coords1d>,
    ) -> ViewExpand<Vector<E, N>, C, ReadWrite> {
        this.__expand_view_mut_method::<C>(scope, layout)
    }
}
impl<E: Numeric, N: Size> VirtualTensorExpand<E, N, ReadWrite> {
    /// Expand of [`VirtualTensor::view_mut`].
    pub fn __expand_view_mut_method<C: Coordinates + 'static>(
        self,
        scope: &mut Scope,
        layout: VirtualLayoutExpand<C, Coords1d>,
    ) -> ViewExpand<Vector<E, N>, C, ReadWrite> {
        View::__expand_new_mut::<VirtualTensor<E, N, ReadWrite>, Coords1d>(scope, self, layout)
    }
}
317
#[cube]
impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadWrite> {
    // Mutable 1D view of the whole tensor; mirrors `as_view` but writable.
    pub fn as_view_mut(&mut self) -> View<Vector<E, N>, usize, ReadWrite> {
        let vector_size = self.vector_size();
        View::new_mut::<VirtualTensor<E, N, ReadWrite>, usize>(
            self,
            SimpleLayout::new(self.len() * vector_size, vector_size),
        )
    }
}
329
#[cube]
impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensor<E, N, IO> {
    // Recover the coordinate along `dim` from a linear `index` of a strided
    // layout: (index / stride) % shape.
    pub fn coordinate(&self, index: usize, dim: usize) -> usize {
        let num_strides = index / self.stride(dim);
        num_strides % self.shape(dim)
    }
}
337
// Write access is only provided for `ReadWrite` tensors.
impl<E: Numeric, N: Size> ListMut<Vector<E, N>> for VirtualTensor<E, N, ReadWrite> {
    fn __expand_write(
        scope: &mut Scope,
        this: VirtualTensorExpand<E, N, ReadWrite>,
        index: <usize as CubeType>::ExpandType,
        value: <Vector<E, N> as CubeType>::ExpandType,
    ) -> <() as CubeType>::ExpandType {
        this.__expand_write_method(scope, index, value)
    }
}
348
// Expand-side write: forwarded to the erased state.
impl<E: Numeric, N: Size> ListMutExpand<Vector<E, N>> for VirtualTensorExpand<E, N, ReadWrite> {
    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        index: <usize as CubeType>::ExpandType,
        value: <Vector<E, N> as CubeType>::ExpandType,
    ) -> <() as CubeType>::ExpandType {
        self.state
            .clone()
            .__expand_write_method(scope, index, value)
    }
}
361
362impl<E: Numeric, N: Size> SliceMutOperator<Vector<E, N>> for VirtualTensor<E, N, ReadWrite> {}
363impl<E: Numeric, N: Size> SliceMutOperatorExpand<Vector<E, N>>
364 for VirtualTensorExpand<E, N, ReadWrite>
365{
366 #[allow(unused_variables)]
367 fn __expand_slice_mut_method(
368 &self,
369 scope: &mut Scope,
370 start: NativeExpand<usize>,
371 end: NativeExpand<usize>,
372 ) -> SliceExpand<Vector<E, N>, cubecl_core::prelude::ReadWrite> {
373 todo!("VirtualTensor don't support slice mut yet");
374 }
375
376 #[allow(unused_variables)]
377 fn __expand_to_slice_mut_method(
378 &self,
379 scope: &mut Scope,
380 ) -> SliceExpand<Vector<E, N>, cubecl_core::prelude::ReadWrite> {
381 todo!("VirtualTensor don't support slice mut yet");
382 }
383}
384
impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadOnly> {
    /// Wrap any `VirtualTensorOperations` implementor as a read-only virtual tensor.
    pub fn new<V: VirtualTensorOperations<E, N> + 'static>(_v: &V) -> Self {
        unexpanded!()
    }

    /// Expand of [`Self::new`]: erases the concrete expand type behind an `Arc`.
    pub fn __expand_new<V: VirtualTensorOperations<E, N> + 'static>(
        _scope: &mut Scope,
        v: V::ExpandType,
    ) -> VirtualTensorExpand<E, N, ReadOnly> {
        VirtualTensorExpand {
            state: Arc::new(v),
            _p: PhantomData,
        }
    }
}
402
impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadWrite> {
    /// Wrap any `VirtualTensorOperations` implementor as a writable virtual
    /// tensor; takes `&mut V` to signal exclusive access at the call site.
    pub fn new<V: VirtualTensorOperations<E, N> + 'static>(_v: &mut V) -> Self {
        unexpanded!()
    }

    /// Expand of [`Self::new`]: erases the concrete expand type behind an `Arc`.
    pub fn __expand_new<V: VirtualTensorOperations<E, N> + 'static>(
        _scope: &mut Scope,
        v: V::ExpandType,
    ) -> VirtualTensorExpand<E, N, ReadWrite> {
        VirtualTensorExpand {
            state: Arc::new(v),
            _p: PhantomData,
        }
    }
}
420
// Contract a container must implement to back a `VirtualTensor`. All methods
// are host-side stubs; the cube macro generates the matching expand trait
// (`VirtualTensorOperationsExpand`) that performs the real work.
#[cube(self_type = "ref", expand_base_traits = "VectorizedExpand")]
pub trait VirtualTensorOperations<E: Numeric, N: Size>: Vectorized {
    // TMA descriptor, if the backing container is a tensor map.
    fn as_tensor_map(&self) -> ComptimeOption<TensorMap<E, Tiled>> {
        unexpanded!()
    }
    // Read one vectorized element at a linear index.
    fn read(&self, _index: usize) -> Vector<E, N> {
        unexpanded!()
    }
    // Read-only window over the linear range `start..end`.
    fn read_window(&self, _start: usize, _end: usize) -> Slice<Vector<E, N>, ReadOnly> {
        unexpanded!()
    }
    // Write one vectorized element at a linear index.
    fn write(&self, _index: usize, _value: Vector<E, N>) {
        unexpanded!()
    }
    // Size along an axis.
    fn shape(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    // Stride along an axis.
    fn stride(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    // Number of axes.
    fn rank(&self) -> usize {
        unexpanded!()
    }
    // Logical number of (vectorized) elements.
    fn len(&self) -> usize {
        unexpanded!()
    }
    // Length of the underlying storage buffer.
    fn buffer_len(&self) -> usize {
        unexpanded!()
    }
}
465
// `CubeType` plumbing tying `VirtualTensor` to its expand type.
mod __cube_type {
    use super::*;

    impl<E: Numeric, N: Size, IO: Clone> CubeType for VirtualTensor<E, N, IO> {
        type ExpandType = VirtualTensorExpand<E, N, IO>;
    }

    impl<E: Numeric, N: Size, IO> IntoMut for VirtualTensorExpand<E, N, IO> {
        // The expand handle is already shareable; nothing to convert.
        fn into_mut(self, _scope: &mut Scope) -> Self {
            self
        }
    }

    impl<E: Numeric, N: Size, IO> CubeDebug for VirtualTensorExpand<E, N, IO> {}
}
482
// Adapter making a plain `Tensor<Vector<E, N>>` usable as a `VirtualTensor`.
mod __tensor {
    use super::*;

    impl<E: Numeric, N: Size> VirtualTensorOperations<E, N> for Tensor<Vector<E, N>> {}
    impl<E: Numeric, N: Size> VirtualTensorOperationsExpand<E, N>
        for NativeExpand<Tensor<Vector<E, N>>>
    {
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: NativeExpand<usize>,
        ) -> NativeExpand<Vector<E, N>> {
            // NOTE(review): reads are emitted without bounds checks — callers
            // are expected to keep `index` in range.
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: NativeExpand<usize>,
            end: NativeExpand<usize>,
        ) -> SliceExpand<Vector<E, N>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: NativeExpand<usize>,
            value: NativeExpand<Vector<E, N>>,
        ) {
            // Writes are likewise unchecked.
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
            self.clone().__expand_buffer_len_method(scope)
        }

        // A plain tensor has no TMA descriptor.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> ComptimeOptionExpand<TensorMap<E, Tiled>> {
            ComptimeOption::__expand_new_None(scope)
        }
    }
}
551
// Adapter for `TensorMap`-backed virtual tensors. Only `as_tensor_map` is
// functional; element-wise and metadata accessors are not implemented yet.
mod __tensor_map {
    use super::*;

    impl<E: Numeric, N: Size> VirtualTensorOperations<E, N> for TensorMap<E, Tiled> {}
    impl<E: Numeric, N: Size> VirtualTensorOperationsExpand<E, N>
        for NativeExpand<TensorMap<E, Tiled>>
    {
        fn __expand_read_method(
            &self,
            _scope: &mut Scope,
            _index: NativeExpand<usize>,
        ) -> NativeExpand<Vector<E, N>> {
            todo!()
        }
        fn __expand_read_window_method(
            &self,
            _context: &mut Scope,
            _start: NativeExpand<usize>,
            _end: NativeExpand<usize>,
        ) -> SliceExpand<Vector<E, N>, ReadOnly> {
            todo!()
        }

        fn __expand_write_method(
            &self,
            _scope: &mut Scope,
            _index: NativeExpand<usize>,
            _value: NativeExpand<Vector<E, N>>,
        ) {
            todo!()
        }

        fn __expand_shape_method(
            &self,
            _scope: &mut Scope,
            _axis: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            todo!()
        }

        fn __expand_stride_method(
            &self,
            _scope: &mut Scope,
            _axis: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            todo!()
        }

        fn __expand_rank_method(&self, _scope: &mut Scope) -> NativeExpand<usize> {
            todo!()
        }
        fn __expand_len_method(&self, _scope: &mut Scope) -> NativeExpand<usize> {
            todo!()
        }
        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> NativeExpand<usize> {
            todo!()
        }

        // The only supported operation: expose the map itself.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> ComptimeOptionExpand<TensorMap<E, Tiled>> {
            ComptimeOption::__expand_new_Some(scope, self.clone())
        }
    }
}