// cubecl_std/tensor/virtual.rs

1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use std::ops::{Deref, DerefMut};
6
7use crate::tensor::{
8    ViewExpand,
9    layout::{
10        Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
11    },
12    view::View,
13};
14
15/// Tensor representation that is decoupled from how the tensor is stored.
16#[derive(Clone)]
17pub struct VirtualTensor<E: Numeric, N: Size, IO = ReadOnly> {
18    // state: Arc<dyn VirtualTensorOperations<E>>,
19    _e: PhantomData<E>,
20    _n: PhantomData<N>,
21    _p: PhantomData<IO>,
22}
23
24impl<E: Numeric, N: Size, IO: Clone> Copy for VirtualTensor<E, N, IO> {}
25
26/// Expand type for [`VirtualTensor`].
27#[derive(Clone)]
28pub struct VirtualTensorExpand<E: Numeric, N: Size, IO> {
29    state: Arc<dyn VirtualTensorOperationsExpand<E, N>>,
30    _p: PhantomData<IO>,
31}
32
33impl<E: Numeric, N: Size, IO: Clone> List<Vector<E, N>> for VirtualTensor<E, N, IO> {
34    fn __expand_read(
35        scope: &mut Scope,
36        this: VirtualTensorExpand<E, N, IO>,
37        index: <usize as CubeType>::ExpandType,
38    ) -> <Vector<E, N> as CubeType>::ExpandType {
39        this.__expand_read_method(scope, index)
40    }
41}
42
43impl<T: Numeric, N: Size, IO: Clone> Deref for VirtualTensor<T, N, IO> {
44    type Target = [Vector<T, N>];
45
46    fn deref(&self) -> &Self::Target {
47        unexpanded!()
48    }
49}
50
51impl<T: Numeric, N: Size> DerefMut for VirtualTensor<T, N, ReadWrite> {
52    fn deref_mut(&mut self) -> &mut Self::Target {
53        unexpanded!()
54    }
55}
56
57impl<E: Numeric, N: Size, IO: Clone> ListExpand<Vector<E, N>> for VirtualTensorExpand<E, N, IO> {
58    fn __expand_read_method(
59        &self,
60        scope: &mut Scope,
61        index: <usize as CubeType>::ExpandType,
62    ) -> <Vector<E, N> as CubeType>::ExpandType {
63        self.state.clone().__expand_read_method(scope, index)
64    }
65
66    fn __expand_read_unchecked_method(
67        &self,
68        _scope: &mut Scope,
69        _index: NativeExpand<usize>,
70    ) -> <Vector<E, N> as CubeType>::ExpandType {
71        todo!("VirtualTensor don't support read unchecked yet");
72    }
73
74    fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
75        self.state.clone().__expand_len_method(scope)
76    }
77}
78
79impl<E: Numeric, N: Size, IO: Clone> Vectorized for VirtualTensor<E, N, IO> {}
80impl<E: Numeric, N: Size, IO: Clone> VectorizedExpand for VirtualTensorExpand<E, N, IO> {
81    fn vector_size(&self) -> VectorSize {
82        self.state.clone().vector_size()
83    }
84}
85
86impl<E: Numeric, N: Size, IO: Clone> SliceOperator<Vector<E, N>> for VirtualTensor<E, N, IO> {}
87impl<E: Numeric, N: Size, IO: Clone> SliceOperatorExpand<Vector<E, N>>
88    for VirtualTensorExpand<E, N, IO>
89{
90    fn __expand_slice_method(
91        &self,
92        scope: &mut Scope,
93        start: NativeExpand<usize>,
94        end: NativeExpand<usize>,
95    ) -> SliceExpand<Vector<E, N>, ReadOnly> {
96        self.state
97            .clone()
98            .__expand_read_window_method(scope, start, end)
99    }
100
101    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Vector<E, N>, ReadOnly> {
102        let end = self.clone().__expand_buffer_len_method(scope);
103        self.state
104            .clone()
105            .__expand_read_window_method(scope, 0.into(), end)
106    }
107}
108
109#[allow(unused, clippy::all)]
110impl<E: Numeric, N: Size, IO: Clone> VirtualTensor<E, N, IO> {
111    pub fn as_tensor_map(&self) -> Option<TensorMap<E, Tiled>> {
112        unexpanded!()
113    }
114    pub fn as_slice(&self, start: usize, end: usize) -> Slice<Vector<E, N>> {
115        unexpanded!();
116    }
117    /// Get the shape of the tensor at the given axis.
118    pub fn shape(&self, axis: usize) -> usize {
119        unexpanded!();
120    }
121    /// Get the stride of the tensor at the given axis.
122    pub fn stride(&self, axis: usize) -> usize {
123        unexpanded!();
124    }
125    /// Get the rank of the tensor.
126    pub fn rank(&self) -> usize {
127        unexpanded!();
128    }
129
130    pub fn buffer_len(&self) -> usize {
131        unexpanded!();
132    }
133
134    pub fn __expand_as_tensor_map(
135        context: &mut Scope,
136        this: <Self as CubeType>::ExpandType,
137    ) -> <ComptimeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
138        this.__expand_as_tensor_map_method(context)
139    }
140    pub fn __expand_as_slice(
141        context: &mut Scope,
142        this: <Self as CubeType>::ExpandType,
143        start: <usize as CubeType>::ExpandType,
144        end: <usize as CubeType>::ExpandType,
145    ) -> <Slice<Vector<E, N>> as CubeType>::ExpandType {
146        this.__expand_as_slice_method(context, start, end)
147    }
148    pub fn __expand_shape(
149        scope: &mut Scope,
150        this: <Self as CubeType>::ExpandType,
151        axis: <usize as CubeType>::ExpandType,
152    ) -> <usize as CubeType>::ExpandType {
153        this.__expand_shape_method(scope, axis)
154    }
155    pub fn __expand_stride(
156        scope: &mut Scope,
157        this: <Self as CubeType>::ExpandType,
158        axis: <usize as CubeType>::ExpandType,
159    ) -> <usize as CubeType>::ExpandType {
160        this.__expand_stride_method(scope, axis)
161    }
162    pub fn __expand_rank(
163        scope: &mut Scope,
164        this: <Self as CubeType>::ExpandType,
165    ) -> <usize as CubeType>::ExpandType {
166        this.__expand_rank_method(scope)
167    }
168    pub fn __expand_buffer_len(
169        scope: &mut Scope,
170        this: <Self as CubeType>::ExpandType,
171    ) -> <usize as CubeType>::ExpandType {
172        this.__expand_buffer_len_method(scope)
173    }
174}
175
176#[allow(unused, clippy::all)]
177impl<E: Numeric, N: Size, IO: Clone> VirtualTensorExpand<E, N, IO> {
178    pub fn __expand_as_tensor_map_method(
179        self,
180        context: &mut Scope,
181    ) -> <ComptimeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
182        self.state.clone().__expand_as_tensor_map_method(context)
183    }
184
185    pub fn __expand_as_slice_method(
186        self,
187        context: &mut Scope,
188        start: <usize as CubeType>::ExpandType,
189        end: <usize as CubeType>::ExpandType,
190    ) -> <Slice<Vector<E, N>> as CubeType>::ExpandType {
191        self.state
192            .clone()
193            .__expand_read_window_method(context, start, end)
194    }
195
196    pub fn __expand_shape_method(
197        self,
198        scope: &mut Scope,
199        axis: <usize as CubeType>::ExpandType,
200    ) -> <usize as CubeType>::ExpandType {
201        let _arg_0 = axis;
202        self.state
203            .clone()
204            .__expand_shape_method(scope, _arg_0.into())
205    }
206
207    pub fn __expand_stride_method(
208        self,
209        scope: &mut Scope,
210        axis: <usize as CubeType>::ExpandType,
211    ) -> <usize as CubeType>::ExpandType {
212        let _arg_0 = axis;
213        self.state
214            .clone()
215            .__expand_stride_method(scope, _arg_0.into())
216    }
217
218    pub fn __expand_rank_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
219        self.state.clone().__expand_rank_method(scope)
220    }
221
222    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
223        self.state.clone().__expand_buffer_len_method(scope)
224    }
225
226    pub fn __expand_read(
227        scope: &mut Scope,
228        this: Self,
229        index: <usize as CubeType>::ExpandType,
230    ) -> <Vector<E, N> as CubeType>::ExpandType {
231        VirtualTensor::<E, N, IO>::__expand_read(scope, this, index)
232    }
233
234    pub fn __expand_shape(
235        scope: &mut Scope,
236        this: Self,
237        axis: <usize as CubeType>::ExpandType,
238    ) -> <usize as CubeType>::ExpandType {
239        VirtualTensor::<E, N, IO>::__expand_shape(scope, this, axis)
240    }
241
242    pub fn __expand_stride(
243        scope: &mut Scope,
244        this: Self,
245        axis: <usize as CubeType>::ExpandType,
246    ) -> <usize as CubeType>::ExpandType {
247        VirtualTensor::<E, N, IO>::__expand_stride(scope, this, axis)
248    }
249
250    pub fn __expand_rank(scope: &mut Scope, this: Self) -> <usize as CubeType>::ExpandType {
251        VirtualTensor::<E, N, IO>::__expand_rank(scope, this)
252    }
253}
254
255impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensor<E, N, IO> {
256    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
257    /// layouts
258    pub fn view<C: Coordinates + 'static>(
259        &self,
260        layout: impl Into<VirtualLayout<C, Coords1d>>,
261    ) -> View<Vector<E, N>, C, ReadOnly> {
262        View::new::<VirtualTensor<E, N, IO>, Coords1d>(self, layout)
263    }
264}
265
#[cube]
impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensor<E, N, IO> {
    /// Create a conceptual view over this tensor, with a simple linear layout.
    ///
    /// The view uses `usize` coordinates mapped by a [`SimpleLayout`] built from
    /// `self.len() * vector_size` and `vector_size`.
    /// NOTE(review): this presumably means "total scalar elements, elements per vector", i.e. it
    /// assumes `len()` counts vectors — confirm against `SimpleLayout::new`.
    pub fn as_view(&self) -> View<Vector<E, N>, usize, ReadOnly> {
        let vector_size = self.vector_size();
        View::new::<VirtualTensor<E, N, IO>, usize>(
            self,
            SimpleLayout::new(self.len() * vector_size, vector_size),
        )
    }
}
277
278impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensorExpand<E, N, IO> {
279    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
280    /// layouts
281    pub fn __expand_view_method<C: Coordinates + 'static>(
282        &self,
283        scope: &mut Scope,
284        layout: VirtualLayoutExpand<C, Coords1d>,
285    ) -> ViewExpand<Vector<E, N>, C, ReadOnly> {
286        View::__expand_new::<VirtualTensor<E, N, IO>, Coords1d>(scope, self.clone(), layout)
287    }
288}
289
290impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadWrite> {
291    #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
292    #[doc = " with custom layouts"]
293    pub fn view_mut<C: Coordinates + 'static>(
294        &self,
295        layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
296    ) -> View<Vector<E, N>, C, ReadWrite> {
297        let mut this: VirtualTensor<E, N, ReadWrite> = *self;
298        View::new_mut::<VirtualTensor<E, N, ReadWrite>, Coords1d>(&mut this, layout)
299    }
300    pub fn __expand_view_mut<C: Coordinates + 'static>(
301        scope: &mut Scope,
302        this: VirtualTensorExpand<E, N, ReadWrite>,
303        layout: VirtualLayoutExpand<C, Coords1d>,
304    ) -> ViewExpand<Vector<E, N>, C, ReadWrite> {
305        this.__expand_view_mut_method::<C>(scope, layout)
306    }
307}
308impl<E: Numeric, N: Size> VirtualTensorExpand<E, N, ReadWrite> {
309    pub fn __expand_view_mut_method<C: Coordinates + 'static>(
310        self,
311        scope: &mut Scope,
312        layout: VirtualLayoutExpand<C, Coords1d>,
313    ) -> ViewExpand<Vector<E, N>, C, ReadWrite> {
314        View::__expand_new_mut::<VirtualTensor<E, N, ReadWrite>, Coords1d>(scope, self, layout)
315    }
316}
317
#[cube]
impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadWrite> {
    /// Create a conceptual mutable view over this tensor, with a simple linear layout.
    ///
    /// Mirrors `as_view`, but produces a [`ReadWrite`] view. It therefore needs `&mut self`
    /// and is only available on read-write tensors.
    /// NOTE(review): as in `as_view`, the layout arguments assume `len()` counts vectors — confirm.
    pub fn as_view_mut(&mut self) -> View<Vector<E, N>, usize, ReadWrite> {
        let vector_size = self.vector_size();
        View::new_mut::<VirtualTensor<E, N, ReadWrite>, usize>(
            self,
            SimpleLayout::new(self.len() * vector_size, vector_size),
        )
    }
}
329
#[cube]
impl<E: Numeric, N: Size, IO: Clone + 'static> VirtualTensor<E, N, IO> {
    /// Coordinate along axis `dim` of the linear offset `index`.
    ///
    /// Computes `(index / stride(dim)) % shape(dim)`.
    /// NOTE(review): this assumes `index` is measured in the same units as `stride(dim)`. If a
    /// caller passes a vector index (the unit `read` uses), the result may be off by the vector
    /// size — confirm at call sites.
    pub fn coordinate(&self, index: usize, dim: usize) -> usize {
        let num_strides = index / self.stride(dim);
        num_strides % self.shape(dim)
    }
}
337
338impl<E: Numeric, N: Size> ListMut<Vector<E, N>> for VirtualTensor<E, N, ReadWrite> {
339    fn __expand_write(
340        scope: &mut Scope,
341        this: VirtualTensorExpand<E, N, ReadWrite>,
342        index: <usize as CubeType>::ExpandType,
343        value: <Vector<E, N> as CubeType>::ExpandType,
344    ) -> <() as CubeType>::ExpandType {
345        this.__expand_write_method(scope, index, value)
346    }
347}
348
349impl<E: Numeric, N: Size> ListMutExpand<Vector<E, N>> for VirtualTensorExpand<E, N, ReadWrite> {
350    fn __expand_write_method(
351        &self,
352        scope: &mut Scope,
353        index: <usize as CubeType>::ExpandType,
354        value: <Vector<E, N> as CubeType>::ExpandType,
355    ) -> <() as CubeType>::ExpandType {
356        self.state
357            .clone()
358            .__expand_write_method(scope, index, value)
359    }
360}
361
362impl<E: Numeric, N: Size> SliceMutOperator<Vector<E, N>> for VirtualTensor<E, N, ReadWrite> {}
363impl<E: Numeric, N: Size> SliceMutOperatorExpand<Vector<E, N>>
364    for VirtualTensorExpand<E, N, ReadWrite>
365{
366    #[allow(unused_variables)]
367    fn __expand_slice_mut_method(
368        &self,
369        scope: &mut Scope,
370        start: NativeExpand<usize>,
371        end: NativeExpand<usize>,
372    ) -> SliceExpand<Vector<E, N>, cubecl_core::prelude::ReadWrite> {
373        todo!("VirtualTensor don't support slice mut yet");
374    }
375
376    #[allow(unused_variables)]
377    fn __expand_to_slice_mut_method(
378        &self,
379        scope: &mut Scope,
380    ) -> SliceExpand<Vector<E, N>, cubecl_core::prelude::ReadWrite> {
381        todo!("VirtualTensor don't support slice mut yet");
382    }
383}
384
385impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadOnly> {
386    /// Create a new [read only](ReadOnly) [virtual tensor](VirtualTensor).
387    pub fn new<V: VirtualTensorOperations<E, N> + 'static>(_v: &V) -> Self {
388        unexpanded!()
389    }
390
391    /// Expand function of [`Self::new`].
392    pub fn __expand_new<V: VirtualTensorOperations<E, N> + 'static>(
393        _scope: &mut Scope,
394        v: V::ExpandType,
395    ) -> VirtualTensorExpand<E, N, ReadOnly> {
396        VirtualTensorExpand {
397            state: Arc::new(v),
398            _p: PhantomData,
399        }
400    }
401}
402
403impl<E: Numeric, N: Size> VirtualTensor<E, N, ReadWrite> {
404    /// Create a new [read write](ReadWrite) [virtual tensor](VirtualTensor).
405    pub fn new<V: VirtualTensorOperations<E, N> + 'static>(_v: &mut V) -> Self {
406        unexpanded!()
407    }
408
409    /// Expand function of [`Self::new`].
410    pub fn __expand_new<V: VirtualTensorOperations<E, N> + 'static>(
411        _scope: &mut Scope,
412        v: V::ExpandType,
413    ) -> VirtualTensorExpand<E, N, ReadWrite> {
414        VirtualTensorExpand {
415            state: Arc::new(v),
416            _p: PhantomData,
417        }
418    }
419}
420
/// Trait to be implemented by a type that can become a [virtual tensor](VirtualTensor).
///
/// The [expand trait](VirtualTensorOperationsExpand) also needs to be implemented for the type's
/// expand type. `VirtualTensor::__expand_new` stores that expand value behind an `Arc` as the
/// tensor's backing state.
///
/// # Warning
///
/// This trait is kind of unsafe: [`VirtualTensorOperations::write`] takes `&self` even though it
/// writes to the tensor, so it doesn't follow Rust's mutability rules. It won't lead to any
/// undefined behavior, though.
#[cube(self_type = "ref", expand_base_traits = "VectorizedExpand")]
pub trait VirtualTensorOperations<E: Numeric, N: Size>: Vectorized {
    /// Get the tensor map backing this storage, if any.
    ///
    /// The `Tensor` implementation in this file returns `None`; the `TensorMap` one returns
    /// `Some(self)`.
    fn as_tensor_map(&self) -> ComptimeOption<TensorMap<E, Tiled>> {
        unexpanded!()
    }
    /// Read the tensor at the given index.
    fn read(&self, _index: usize) -> Vector<E, N> {
        unexpanded!()
    }
    /// Get a read-only slice of the vectors between `_start` and `_end`.
    fn read_window(&self, _start: usize, _end: usize) -> Slice<Vector<E, N>, ReadOnly> {
        unexpanded!()
    }
    /// Write the tensor at the given index.
    fn write(&self, _index: usize, _value: Vector<E, N>) {
        unexpanded!()
    }
    /// Get the shape of the tensor at the given axis.
    fn shape(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Get the stride of the tensor at the given axis.
    fn stride(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Get the rank of the tensor.
    fn rank(&self) -> usize {
        unexpanded!()
    }
    /// Get the length of the tensor.
    // NOTE(review): `VirtualTensor::as_view` multiplies this by the vector size, suggesting it
    // counts vectors — confirm.
    fn len(&self) -> usize {
        unexpanded!()
    }
    /// Get the length of the underlying buffer.
    ///
    /// `VirtualTensor`'s `to_slice` uses it as the end of the full-tensor slice.
    fn buffer_len(&self) -> usize {
        unexpanded!()
    }
}
465
466/// Making [virtual tensors](VirtualTensor) a proper [cube type](CubeType).
467mod __cube_type {
468    use super::*;
469
470    impl<E: Numeric, N: Size, IO: Clone> CubeType for VirtualTensor<E, N, IO> {
471        type ExpandType = VirtualTensorExpand<E, N, IO>;
472    }
473
474    impl<E: Numeric, N: Size, IO> IntoMut for VirtualTensorExpand<E, N, IO> {
475        fn into_mut(self, _scope: &mut Scope) -> Self {
476            self
477        }
478    }
479
480    impl<E: Numeric, N: Size, IO> CubeDebug for VirtualTensorExpand<E, N, IO> {}
481}
482
/// Enable tensors to be virtual: a [`Tensor`] of vectors can back a [`VirtualTensor`].
mod __tensor {
    use super::*;

    impl<E: Numeric, N: Size> VirtualTensorOperations<E, N> for Tensor<Vector<E, N>> {}
    // NOTE(review): every method below calls `self.clone()` before delegating, which gives a
    // by-value receiver. Presumably this makes the same-named calls (`__expand_shape_method`,
    // `__expand_len_method`, ...) resolve to `NativeExpand<Tensor<_>>`'s own methods rather than
    // recursing into this impl. Keep the clones unless that is confirmed otherwise.
    impl<E: Numeric, N: Size> VirtualTensorOperationsExpand<E, N>
        for NativeExpand<Tensor<Vector<E, N>>>
    {
        /// Reads through the tensor's `index_unchecked` path, i.e. without a bounds check as the
        /// name indicates.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: NativeExpand<usize>,
        ) -> NativeExpand<Vector<E, N>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        /// Read window implemented as a regular tensor slice.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: NativeExpand<usize>,
            end: NativeExpand<usize>,
        ) -> SliceExpand<Vector<E, N>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        /// Writes through the tensor's `index_assign_unchecked` path (no bounds check).
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: NativeExpand<usize>,
            value: NativeExpand<Vector<E, N>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
            self.clone().__expand_buffer_len_method(scope)
        }

        /// A plain tensor has no tensor map, so this is always `None`.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> ComptimeOptionExpand<TensorMap<E, Tiled>> {
            ComptimeOption::__expand_new_None(scope)
        }
    }
}
551
552/// Enable tensor maps to be virtual.
553mod __tensor_map {
554    use super::*;
555
556    impl<E: Numeric, N: Size> VirtualTensorOperations<E, N> for TensorMap<E, Tiled> {}
557    impl<E: Numeric, N: Size> VirtualTensorOperationsExpand<E, N>
558        for NativeExpand<TensorMap<E, Tiled>>
559    {
560        fn __expand_read_method(
561            &self,
562            _scope: &mut Scope,
563            _index: NativeExpand<usize>,
564        ) -> NativeExpand<Vector<E, N>> {
565            todo!()
566        }
567        fn __expand_read_window_method(
568            &self,
569            _context: &mut Scope,
570            _start: NativeExpand<usize>,
571            _end: NativeExpand<usize>,
572        ) -> SliceExpand<Vector<E, N>, ReadOnly> {
573            todo!()
574        }
575
576        fn __expand_write_method(
577            &self,
578            _scope: &mut Scope,
579            _index: NativeExpand<usize>,
580            _value: NativeExpand<Vector<E, N>>,
581        ) {
582            todo!()
583        }
584
585        fn __expand_shape_method(
586            &self,
587            _scope: &mut Scope,
588            _axis: NativeExpand<usize>,
589        ) -> NativeExpand<usize> {
590            todo!()
591        }
592
593        fn __expand_stride_method(
594            &self,
595            _scope: &mut Scope,
596            _axis: NativeExpand<usize>,
597        ) -> NativeExpand<usize> {
598            todo!()
599        }
600
601        fn __expand_rank_method(&self, _scope: &mut Scope) -> NativeExpand<usize> {
602            todo!()
603        }
604        fn __expand_len_method(&self, _scope: &mut Scope) -> NativeExpand<usize> {
605            todo!()
606        }
607        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> NativeExpand<usize> {
608            todo!()
609        }
610
611        fn __expand_as_tensor_map_method(
612            &self,
613            scope: &mut Scope,
614        ) -> ComptimeOptionExpand<TensorMap<E, Tiled>> {
615            ComptimeOption::__expand_new_Some(scope, self.clone())
616        }
617    }
618}