Skip to main content

cubecl_std/tensor/
virtual.rs

1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::{CubeType, Scope, *};
4use cubecl_core::{self as cubecl, unexpanded};
5use std::ops::{Deref, DerefMut};
6
7use crate::{
8    CubeOption,
9    tensor::{
10        ViewExpand,
11        layout::{
12            Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
13        },
14        view::View,
15    },
16};
17
18/// Tensor representation that is decoupled from how the tensor is stored.
19#[derive(Clone)]
20pub struct VirtualTensor<E: Numeric, IO = ReadOnly> {
21    // state: Arc<dyn VirtualTensorOperations<E>>,
22    _e: PhantomData<E>,
23    _p: PhantomData<IO>,
24}
25
26impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
27
28/// Expand type for [`VirtualTensor`].
29#[derive(Clone)]
30pub struct VirtualTensorExpand<E: Numeric, IO> {
31    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
32    _p: PhantomData<IO>,
33}
34
35impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
36    fn __expand_read(
37        scope: &mut Scope,
38        this: VirtualTensorExpand<E, IO>,
39        index: <usize as CubeType>::ExpandType,
40    ) -> <Line<E> as CubeType>::ExpandType {
41        this.__expand_read_method(scope, index)
42    }
43}
44
45impl<T: Numeric, IO: Clone> Deref for VirtualTensor<T, IO> {
46    type Target = [Line<T>];
47
48    fn deref(&self) -> &Self::Target {
49        unexpanded!()
50    }
51}
52
53impl<T: Numeric> DerefMut for VirtualTensor<T, ReadWrite> {
54    fn deref_mut(&mut self) -> &mut Self::Target {
55        unexpanded!()
56    }
57}
58
59impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
60    fn __expand_read_method(
61        &self,
62        scope: &mut Scope,
63        index: <usize as CubeType>::ExpandType,
64    ) -> <Line<E> as CubeType>::ExpandType {
65        self.state.clone().__expand_read_method(scope, index)
66    }
67
68    fn __expand_read_unchecked_method(
69        &self,
70        _scope: &mut Scope,
71        _index: ExpandElementTyped<usize>,
72    ) -> <Line<E> as CubeType>::ExpandType {
73        todo!("VirtualTensor don't support read unchecked yet");
74    }
75
76    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
77        self.state.clone().__expand_len_method(scope)
78    }
79}
80
81impl<E: Numeric, IO: Clone> Lined for VirtualTensor<E, IO> {}
82impl<E: Numeric, IO: Clone> LinedExpand for VirtualTensorExpand<E, IO> {
83    fn line_size(&self) -> LineSize {
84        self.state.clone().line_size()
85    }
86}
87
88impl<E: Numeric, IO: Clone> SliceOperator<Line<E>> for VirtualTensor<E, IO> {}
89impl<E: Numeric, IO: Clone> SliceOperatorExpand<Line<E>> for VirtualTensorExpand<E, IO> {
90    fn __expand_slice_method(
91        &self,
92        scope: &mut Scope,
93        start: ExpandElementTyped<usize>,
94        end: ExpandElementTyped<usize>,
95    ) -> SliceExpand<Line<E>, ReadOnly> {
96        self.state
97            .clone()
98            .__expand_read_window_method(scope, start, end)
99    }
100
101    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Line<E>, ReadOnly> {
102        let end = self.clone().__expand_buffer_len_method(scope);
103        self.state
104            .clone()
105            .__expand_read_window_method(scope, 0.into(), end)
106    }
107}
108
109#[allow(unused, clippy::all)]
110impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
111    pub fn as_tensor_map(&self) -> CubeOption<TensorMap<E, Tiled>> {
112        unexpanded!()
113    }
114    pub fn as_slice(&self, start: usize, end: usize) -> Slice<Line<E>> {
115        unexpanded!();
116    }
117    /// Get the shape of the tensor at the given axis.
118    pub fn shape(&self, axis: usize) -> usize {
119        unexpanded!();
120    }
121    /// Get the stride of the tensor at the given axis.
122    pub fn stride(&self, axis: usize) -> usize {
123        unexpanded!();
124    }
125    /// Get the rank of the tensor.
126    pub fn rank(&self) -> usize {
127        unexpanded!();
128    }
129
130    pub fn buffer_len(&self) -> usize {
131        unexpanded!();
132    }
133
134    pub fn __expand_as_tensor_map(
135        context: &mut Scope,
136        this: <Self as CubeType>::ExpandType,
137    ) -> <CubeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
138        this.__expand_as_tensor_map_method(context)
139    }
140    pub fn __expand_as_slice(
141        context: &mut Scope,
142        this: <Self as CubeType>::ExpandType,
143        start: <usize as CubeType>::ExpandType,
144        end: <usize as CubeType>::ExpandType,
145    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
146        this.__expand_as_slice_method(context, start, end)
147    }
148    pub fn __expand_shape(
149        scope: &mut Scope,
150        this: <Self as CubeType>::ExpandType,
151        axis: <usize as CubeType>::ExpandType,
152    ) -> <usize as CubeType>::ExpandType {
153        this.__expand_shape_method(scope, axis)
154    }
155    pub fn __expand_stride(
156        scope: &mut Scope,
157        this: <Self as CubeType>::ExpandType,
158        axis: <usize as CubeType>::ExpandType,
159    ) -> <usize as CubeType>::ExpandType {
160        this.__expand_stride_method(scope, axis)
161    }
162    pub fn __expand_rank(
163        scope: &mut Scope,
164        this: <Self as CubeType>::ExpandType,
165    ) -> <usize as CubeType>::ExpandType {
166        this.__expand_rank_method(scope)
167    }
168    pub fn __expand_buffer_len(
169        scope: &mut Scope,
170        this: <Self as CubeType>::ExpandType,
171    ) -> <usize as CubeType>::ExpandType {
172        this.__expand_buffer_len_method(scope)
173    }
174}
175
176#[allow(unused, clippy::all)]
177impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
178    pub fn __expand_as_tensor_map_method(
179        self,
180        context: &mut Scope,
181    ) -> <CubeOption<TensorMap<E, Tiled>> as CubeType>::ExpandType {
182        self.state.clone().__expand_as_tensor_map_method(context)
183    }
184
185    pub fn __expand_as_slice_method(
186        self,
187        context: &mut Scope,
188        start: <usize as CubeType>::ExpandType,
189        end: <usize as CubeType>::ExpandType,
190    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
191        self.state
192            .clone()
193            .__expand_read_window_method(context, start, end)
194    }
195
196    pub fn __expand_shape_method(
197        self,
198        scope: &mut Scope,
199        axis: <usize as CubeType>::ExpandType,
200    ) -> <usize as CubeType>::ExpandType {
201        let _arg_0 = axis;
202        self.state
203            .clone()
204            .__expand_shape_method(scope, _arg_0.into())
205    }
206
207    pub fn __expand_stride_method(
208        self,
209        scope: &mut Scope,
210        axis: <usize as CubeType>::ExpandType,
211    ) -> <usize as CubeType>::ExpandType {
212        let _arg_0 = axis;
213        self.state
214            .clone()
215            .__expand_stride_method(scope, _arg_0.into())
216    }
217
218    pub fn __expand_rank_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
219        self.state.clone().__expand_rank_method(scope)
220    }
221
222    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
223        self.state.clone().__expand_buffer_len_method(scope)
224    }
225
226    pub fn __expand_read(
227        scope: &mut Scope,
228        this: Self,
229        index: <usize as CubeType>::ExpandType,
230    ) -> <Line<E> as CubeType>::ExpandType {
231        VirtualTensor::<E, IO>::__expand_read(scope, this, index)
232    }
233
234    pub fn __expand_shape(
235        scope: &mut Scope,
236        this: Self,
237        axis: <usize as CubeType>::ExpandType,
238    ) -> <usize as CubeType>::ExpandType {
239        VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
240    }
241
242    pub fn __expand_stride(
243        scope: &mut Scope,
244        this: Self,
245        axis: <usize as CubeType>::ExpandType,
246    ) -> <usize as CubeType>::ExpandType {
247        VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
248    }
249
250    pub fn __expand_rank(scope: &mut Scope, this: Self) -> <usize as CubeType>::ExpandType {
251        VirtualTensor::<E, IO>::__expand_rank(scope, this)
252    }
253}
254
255impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
256    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
257    /// layouts
258    pub fn view<C: Coordinates + 'static>(
259        &self,
260        layout: impl Into<VirtualLayout<C, Coords1d>>,
261    ) -> View<Line<E>, C, ReadOnly> {
262        View::new::<VirtualTensor<E, IO>, Coords1d>(self, layout)
263    }
264}
265
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Create a conceptual view over this tensor, with a simple linear layout
    ///
    /// The view is indexed with plain `usize` coordinates. Its [`SimpleLayout`] is built from
    /// `self.len() * line_size` and `line_size` (presumably the total scalar count and the
    /// vectorization width — TODO confirm against `SimpleLayout::new`).
    pub fn as_view(&self) -> View<Line<E>, usize, ReadOnly> {
        let line_size = self.line_size();
        View::new::<VirtualTensor<E, IO>, usize>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
277
278impl<E: Numeric, IO: Clone + 'static> VirtualTensorExpand<E, IO> {
279    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
280    /// layouts
281    pub fn __expand_view_method<C: Coordinates + 'static>(
282        &self,
283        scope: &mut Scope,
284        layout: VirtualLayoutExpand<C, Coords1d>,
285    ) -> ViewExpand<Line<E>, C, ReadOnly> {
286        View::__expand_new::<VirtualTensor<E, IO>, Coords1d>(scope, self.clone(), layout)
287    }
288}
289
290impl<E: Numeric> VirtualTensor<E, ReadWrite> {
291    #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
292    #[doc = " with custom layouts"]
293    pub fn view_mut<C: Coordinates + 'static>(
294        &self,
295        layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
296    ) -> View<Line<E>, C, ReadWrite> {
297        let mut this: VirtualTensor<E, ReadWrite> = *self;
298        View::new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(&mut this, layout)
299    }
300    pub fn __expand_view_mut<C: Coordinates + 'static>(
301        scope: &mut Scope,
302        this: VirtualTensorExpand<E, ReadWrite>,
303        layout: VirtualLayoutExpand<C, Coords1d>,
304    ) -> ViewExpand<Line<E>, C, ReadWrite> {
305        this.__expand_view_mut_method::<C>(scope, layout)
306    }
307}
308impl<E: Numeric> VirtualTensorExpand<E, ReadWrite> {
309    pub fn __expand_view_mut_method<C: Coordinates + 'static>(
310        self,
311        scope: &mut Scope,
312        layout: VirtualLayoutExpand<C, Coords1d>,
313    ) -> ViewExpand<Line<E>, C, ReadWrite> {
314        View::__expand_new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(scope, self, layout)
315    }
316}
317
#[cube]
impl<E: Numeric> VirtualTensor<E, ReadWrite> {
    /// Create a conceptual mutable view over this tensor, with a simple linear layout
    ///
    /// Read-write counterpart of `as_view`: the view is indexed with `usize` coordinates, and its
    /// [`SimpleLayout`] is built from `self.len() * line_size` and `line_size`. Only available
    /// for [`ReadWrite`] virtual tensors.
    pub fn as_view_mut(&mut self) -> View<Line<E>, usize, ReadWrite> {
        let line_size = self.line_size();
        View::new_mut::<VirtualTensor<E, ReadWrite>, usize>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
329
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Coordinate of the linear `index` along dimension `dim`.
    ///
    /// Computed as `(index / stride(dim)) % shape(dim)`. Assumes `index` is expressed in the same
    /// units as the tensor's strides — TODO confirm for line-vectorized tensors.
    pub fn coordinate(&self, index: usize, dim: usize) -> usize {
        // Number of whole `dim`-strides contained in `index`...
        let num_strides = index / self.stride(dim);
        // ...wrapped by the extent of `dim` to give the coordinate along it.
        num_strides % self.shape(dim)
    }
}
337
338impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
339    fn __expand_write(
340        scope: &mut Scope,
341        this: VirtualTensorExpand<E, ReadWrite>,
342        index: <usize as CubeType>::ExpandType,
343        value: <Line<E> as CubeType>::ExpandType,
344    ) -> <() as CubeType>::ExpandType {
345        this.__expand_write_method(scope, index, value)
346    }
347}
348
349impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
350    fn __expand_write_method(
351        &self,
352        scope: &mut Scope,
353        index: <usize as CubeType>::ExpandType,
354        value: <Line<E> as CubeType>::ExpandType,
355    ) -> <() as CubeType>::ExpandType {
356        self.state
357            .clone()
358            .__expand_write_method(scope, index, value)
359    }
360}
361
362impl<E: Numeric> SliceMutOperator<Line<E>> for VirtualTensor<E, ReadWrite> {}
363impl<E: Numeric> SliceMutOperatorExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
364    #[allow(unused_variables)]
365    fn __expand_slice_mut_method(
366        &self,
367        scope: &mut Scope,
368        start: ExpandElementTyped<usize>,
369        end: ExpandElementTyped<usize>,
370    ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
371        todo!("VirtualTensor don't support slice mut yet");
372    }
373
374    #[allow(unused_variables)]
375    fn __expand_to_slice_mut_method(
376        &self,
377        scope: &mut Scope,
378    ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
379        todo!("VirtualTensor don't support slice mut yet");
380    }
381}
382
383impl<E: Numeric> VirtualTensor<E, ReadOnly> {
384    /// Create a new [read only](ReadOnly) [virtual tensor](VirtualTensor).
385    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
386        unexpanded!()
387    }
388
389    /// Expand function of [`Self::new`].
390    pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
391        _scope: &mut Scope,
392        v: V::ExpandType,
393    ) -> VirtualTensorExpand<E, ReadOnly> {
394        VirtualTensorExpand {
395            state: Arc::new(v),
396            _p: PhantomData,
397        }
398    }
399}
400
401impl<E: Numeric> VirtualTensor<E, ReadWrite> {
402    /// Create a new [read write](ReadWrite) [virtual tensor](VirtualTensor).
403    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
404        unexpanded!()
405    }
406
407    /// Expand function of [`Self::new`].
408    pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
409        _scope: &mut Scope,
410        v: V::ExpandType,
411    ) -> VirtualTensorExpand<E, ReadWrite> {
412        VirtualTensorExpand {
413            state: Arc::new(v),
414            _p: PhantomData,
415        }
416    }
417}
418
/// Trait to be implemented by a type that can become a [virtual tensor](VirtualTensor).
///
/// The [expand trait](VirtualTensorOperationsExpand) also needs to be implemented for the type's
/// expand type.
///
/// Every default method body here is `unexpanded!()`. The `Tensor` and `TensorMap`
/// implementations in this file leave this trait's impl empty and provide only the expand trait.
///
/// # Warning
///
/// This trait is somewhat unsafe: [`VirtualTensorOperations::write`] takes `&self`, so it doesn't
/// follow Rust's mutability rules. Per the original design, this won't lead to any undefined
/// behavior.
#[cube(self_type = "ref", expand_base_traits = "LinedExpand")]
pub trait VirtualTensorOperations<E: Numeric>: Lined {
    /// Get the tensor map backing this tensor, if it is backed by one.
    fn as_tensor_map(&self) -> CubeOption<TensorMap<E, Tiled>> {
        unexpanded!()
    }
    /// Read the tensor at the given index.
    fn read(&self, _index: usize) -> Line<E> {
        unexpanded!()
    }
    /// Get a read-only slice of the tensor between `_start` and `_end`.
    fn read_window(&self, _start: usize, _end: usize) -> Slice<Line<E>, ReadOnly> {
        unexpanded!()
    }
    /// Write the tensor at the given index.
    ///
    /// Takes `&self` even though it writes; see the warning on the trait.
    fn write(&self, _index: usize, _value: Line<E>) {
        unexpanded!()
    }
    /// Get the shape of the tensor at the given axis.
    fn shape(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Get the stride of the tensor at the given axis.
    fn stride(&self, _axis: usize) -> usize {
        unexpanded!()
    }
    /// Get the rank of the tensor.
    fn rank(&self) -> usize {
        unexpanded!()
    }
    /// Get the length of the tensor, as reported by the virtual tensor's `len`.
    fn len(&self) -> usize {
        unexpanded!()
    }
    /// Get the length of the underlying buffer; the virtual tensor's `to_slice` uses it as the
    /// end of the full slice.
    fn buffer_len(&self) -> usize {
        unexpanded!()
    }
}
463
464/// Making [virtual tensors](VirtualTensor) a proper [cube type](CubeType).
465mod __cube_type {
466    use super::*;
467
468    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
469        type ExpandType = VirtualTensorExpand<E, IO>;
470    }
471
472    impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
473        fn into_mut(self, _scope: &mut Scope) -> Self {
474            self
475        }
476    }
477
478    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
479}
480
481/// Enable tensors to be virtual.
482mod __tensor {
483    use crate::CubeOptionExpand;
484
485    use super::*;
486
487    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
488    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
489        fn __expand_read_method(
490            &self,
491            scope: &mut Scope,
492            index: ExpandElementTyped<usize>,
493        ) -> ExpandElementTyped<Line<E>> {
494            self.clone().__expand_index_unchecked_method(scope, index)
495        }
496        fn __expand_read_window_method(
497            &self,
498            context: &mut Scope,
499            start: ExpandElementTyped<usize>,
500            end: ExpandElementTyped<usize>,
501        ) -> SliceExpand<Line<E>, ReadOnly> {
502            self.clone().__expand_slice_method(context, start, end)
503        }
504
505        fn __expand_write_method(
506            &self,
507            scope: &mut Scope,
508            index: ExpandElementTyped<usize>,
509            value: ExpandElementTyped<Line<E>>,
510        ) {
511            self.clone()
512                .__expand_index_assign_unchecked_method(scope, index, value)
513        }
514
515        fn __expand_shape_method(
516            &self,
517            scope: &mut Scope,
518            axis: ExpandElementTyped<usize>,
519        ) -> ExpandElementTyped<usize> {
520            self.clone().__expand_shape_method(scope, axis)
521        }
522
523        fn __expand_stride_method(
524            &self,
525            scope: &mut Scope,
526            axis: ExpandElementTyped<usize>,
527        ) -> ExpandElementTyped<usize> {
528            self.clone().__expand_stride_method(scope, axis)
529        }
530
531        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
532            self.clone().__expand_rank_method(scope)
533        }
534        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
535            self.clone().__expand_len_method(scope)
536        }
537        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
538            self.clone().__expand_buffer_len_method(scope)
539        }
540
541        fn __expand_as_tensor_map_method(
542            &self,
543            scope: &mut Scope,
544        ) -> CubeOptionExpand<TensorMap<E, Tiled>> {
545            CubeOption::__expand_new_None(scope)
546        }
547    }
548}
549
550/// Enable tensor maps to be virtual.
551mod __tensor_map {
552    use crate::CubeOptionExpand;
553
554    use super::*;
555
556    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E, Tiled> {}
557    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E, Tiled>> {
558        fn __expand_read_method(
559            &self,
560            _scope: &mut Scope,
561            _index: ExpandElementTyped<usize>,
562        ) -> ExpandElementTyped<Line<E>> {
563            todo!()
564        }
565        fn __expand_read_window_method(
566            &self,
567            _context: &mut Scope,
568            _start: ExpandElementTyped<usize>,
569            _end: ExpandElementTyped<usize>,
570        ) -> SliceExpand<Line<E>, ReadOnly> {
571            todo!()
572        }
573
574        fn __expand_write_method(
575            &self,
576            _scope: &mut Scope,
577            _index: ExpandElementTyped<usize>,
578            _value: ExpandElementTyped<Line<E>>,
579        ) {
580            todo!()
581        }
582
583        fn __expand_shape_method(
584            &self,
585            _scope: &mut Scope,
586            _axis: ExpandElementTyped<usize>,
587        ) -> ExpandElementTyped<usize> {
588            todo!()
589        }
590
591        fn __expand_stride_method(
592            &self,
593            _scope: &mut Scope,
594            _axis: ExpandElementTyped<usize>,
595        ) -> ExpandElementTyped<usize> {
596            todo!()
597        }
598
599        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
600            todo!()
601        }
602        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
603            todo!()
604        }
605        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
606            todo!()
607        }
608
609        fn __expand_as_tensor_map_method(
610            &self,
611            scope: &mut Scope,
612        ) -> CubeOptionExpand<TensorMap<E, Tiled>> {
613            CubeOption::__expand_new_Some(scope, self.clone())
614        }
615    }
616}