Skip to main content

cubecl_std/tensor/
virtual.rs

1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5use std::ops::{Deref, DerefMut};
6
7use crate::tensor::{
8    ViewExpand,
9    layout::{
10        Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
11    },
12    view::View,
13};
14
15/// Tensor representation that is decoupled from how the tensor is stored.
16#[derive(Clone)]
17pub struct VirtualTensor<E: Numeric, IO = ReadOnly> {
18    // state: Arc<dyn VirtualTensorOperations<E>>,
19    _e: PhantomData<E>,
20    _p: PhantomData<IO>,
21}
22
23impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
24
25/// Expand type for [`VirtualTensor`].
26#[derive(Clone)]
27pub struct VirtualTensorExpand<E: Numeric, IO> {
28    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
29    _p: PhantomData<IO>,
30}
31
32impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
33    fn __expand_read(
34        scope: &mut Scope,
35        this: VirtualTensorExpand<E, IO>,
36        index: <usize as CubeType>::ExpandType,
37    ) -> <Line<E> as CubeType>::ExpandType {
38        this.__expand_read_method(scope, index)
39    }
40}
41
42impl<T: Numeric, IO: Clone> Deref for VirtualTensor<T, IO> {
43    type Target = [Line<T>];
44
45    fn deref(&self) -> &Self::Target {
46        unexpanded!()
47    }
48}
49
50impl<T: Numeric> DerefMut for VirtualTensor<T, ReadWrite> {
51    fn deref_mut(&mut self) -> &mut Self::Target {
52        unexpanded!()
53    }
54}
55
56impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
57    fn __expand_read_method(
58        &self,
59        scope: &mut Scope,
60        index: <usize as CubeType>::ExpandType,
61    ) -> <Line<E> as CubeType>::ExpandType {
62        self.state.clone().__expand_read_method(scope, index)
63    }
64
65    fn __expand_read_unchecked_method(
66        &self,
67        _scope: &mut Scope,
68        _index: ExpandElementTyped<usize>,
69    ) -> <Line<E> as CubeType>::ExpandType {
70        todo!("VirtualTensor don't support read unchecked yet");
71    }
72
73    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
74        self.state.clone().__expand_len_method(scope)
75    }
76}
77
78impl<E: Numeric, IO: Clone> Lined for VirtualTensor<E, IO> {}
79impl<E: Numeric, IO: Clone> LinedExpand for VirtualTensorExpand<E, IO> {
80    fn line_size(&self) -> LineSize {
81        self.state.clone().line_size()
82    }
83}
84
85impl<E: Numeric, IO: Clone> SliceOperator<Line<E>> for VirtualTensor<E, IO> {}
86impl<E: Numeric, IO: Clone> SliceOperatorExpand<Line<E>> for VirtualTensorExpand<E, IO> {
87    fn __expand_slice_method(
88        &self,
89        scope: &mut Scope,
90        start: ExpandElementTyped<usize>,
91        end: ExpandElementTyped<usize>,
92    ) -> SliceExpand<Line<E>, ReadOnly> {
93        self.state
94            .clone()
95            .__expand_read_window_method(scope, start, end)
96    }
97
98    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Line<E>, ReadOnly> {
99        let end = self.clone().__expand_buffer_len_method(scope);
100        self.state
101            .clone()
102            .__expand_read_window_method(scope, 0.into(), end)
103    }
104}
105
106#[allow(unused, clippy::all)]
107impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
108    pub fn as_tensor_map(&self) -> Option<TensorMap<E, Tiled>> {
109        unexpanded!()
110    }
111    pub fn as_slice(&self, start: usize, end: usize) -> Slice<Line<E>> {
112        unexpanded!();
113    }
114    /// Get the shape of the tensor at the given axis.
115    pub fn shape(&self, axis: usize) -> usize {
116        unexpanded!();
117    }
118    /// Get the stride of the tensor at the given axis.
119    pub fn stride(&self, axis: usize) -> usize {
120        unexpanded!();
121    }
122    /// Get the rank of the tensor.
123    pub fn rank(&self) -> usize {
124        unexpanded!();
125    }
126
127    pub fn buffer_len(&self) -> usize {
128        unexpanded!();
129    }
130
131    pub fn __expand_as_tensor_map(
132        context: &mut Scope,
133        this: <Self as CubeType>::ExpandType,
134    ) -> <Option<TensorMap<E, Tiled>> as CubeType>::ExpandType {
135        this.__expand_as_tensor_map_method(context)
136    }
137    pub fn __expand_as_slice(
138        context: &mut Scope,
139        this: <Self as CubeType>::ExpandType,
140        start: <usize as CubeType>::ExpandType,
141        end: <usize as CubeType>::ExpandType,
142    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
143        this.__expand_as_slice_method(context, start, end)
144    }
145    pub fn __expand_shape(
146        scope: &mut Scope,
147        this: <Self as CubeType>::ExpandType,
148        axis: <usize as CubeType>::ExpandType,
149    ) -> <usize as CubeType>::ExpandType {
150        this.__expand_shape_method(scope, axis)
151    }
152    pub fn __expand_stride(
153        scope: &mut Scope,
154        this: <Self as CubeType>::ExpandType,
155        axis: <usize as CubeType>::ExpandType,
156    ) -> <usize as CubeType>::ExpandType {
157        this.__expand_stride_method(scope, axis)
158    }
159    pub fn __expand_rank(
160        scope: &mut Scope,
161        this: <Self as CubeType>::ExpandType,
162    ) -> <usize as CubeType>::ExpandType {
163        this.__expand_rank_method(scope)
164    }
165    pub fn __expand_buffer_len(
166        scope: &mut Scope,
167        this: <Self as CubeType>::ExpandType,
168    ) -> <usize as CubeType>::ExpandType {
169        this.__expand_buffer_len_method(scope)
170    }
171}
172
173#[allow(unused, clippy::all)]
174impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
175    pub fn __expand_as_tensor_map_method(
176        self,
177        context: &mut Scope,
178    ) -> <Option<TensorMap<E, Tiled>> as CubeType>::ExpandType {
179        self.state.clone().__expand_as_tensor_map_method(context)
180    }
181
182    pub fn __expand_as_slice_method(
183        self,
184        context: &mut Scope,
185        start: <usize as CubeType>::ExpandType,
186        end: <usize as CubeType>::ExpandType,
187    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
188        self.state
189            .clone()
190            .__expand_read_window_method(context, start, end)
191    }
192
193    pub fn __expand_shape_method(
194        self,
195        scope: &mut Scope,
196        axis: <usize as CubeType>::ExpandType,
197    ) -> <usize as CubeType>::ExpandType {
198        let _arg_0 = axis;
199        self.state
200            .clone()
201            .__expand_shape_method(scope, _arg_0.into())
202    }
203
204    pub fn __expand_stride_method(
205        self,
206        scope: &mut Scope,
207        axis: <usize as CubeType>::ExpandType,
208    ) -> <usize as CubeType>::ExpandType {
209        let _arg_0 = axis;
210        self.state
211            .clone()
212            .__expand_stride_method(scope, _arg_0.into())
213    }
214
215    pub fn __expand_rank_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
216        self.state.clone().__expand_rank_method(scope)
217    }
218
219    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <usize as CubeType>::ExpandType {
220        self.state.clone().__expand_buffer_len_method(scope)
221    }
222
223    pub fn __expand_read(
224        scope: &mut Scope,
225        this: Self,
226        index: <usize as CubeType>::ExpandType,
227    ) -> <Line<E> as CubeType>::ExpandType {
228        VirtualTensor::<E, IO>::__expand_read(scope, this, index)
229    }
230
231    pub fn __expand_shape(
232        scope: &mut Scope,
233        this: Self,
234        axis: <usize as CubeType>::ExpandType,
235    ) -> <usize as CubeType>::ExpandType {
236        VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
237    }
238
239    pub fn __expand_stride(
240        scope: &mut Scope,
241        this: Self,
242        axis: <usize as CubeType>::ExpandType,
243    ) -> <usize as CubeType>::ExpandType {
244        VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
245    }
246
247    pub fn __expand_rank(scope: &mut Scope, this: Self) -> <usize as CubeType>::ExpandType {
248        VirtualTensor::<E, IO>::__expand_rank(scope, this)
249    }
250}
251
252impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
253    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
254    /// layouts
255    pub fn view<C: Coordinates + 'static>(
256        &self,
257        layout: impl Into<VirtualLayout<C, Coords1d>>,
258    ) -> View<Line<E>, C, ReadOnly> {
259        View::new::<VirtualTensor<E, IO>, Coords1d>(self, layout)
260    }
261}
262
263#[cube]
264impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
265    /// Create a conceptual view over this tensor, with a simple linear layout
266    pub fn as_view(&self) -> View<Line<E>, usize, ReadOnly> {
267        let line_size = self.line_size();
268        View::new::<VirtualTensor<E, IO>, usize>(
269            self,
270            SimpleLayout::new(self.len() * line_size, line_size),
271        )
272    }
273}
274
275impl<E: Numeric, IO: Clone + 'static> VirtualTensorExpand<E, IO> {
276    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
277    /// layouts
278    pub fn __expand_view_method<C: Coordinates + 'static>(
279        &self,
280        scope: &mut Scope,
281        layout: VirtualLayoutExpand<C, Coords1d>,
282    ) -> ViewExpand<Line<E>, C, ReadOnly> {
283        View::__expand_new::<VirtualTensor<E, IO>, Coords1d>(scope, self.clone(), layout)
284    }
285}
286
287impl<E: Numeric> VirtualTensor<E, ReadWrite> {
288    #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
289    #[doc = " with custom layouts"]
290    pub fn view_mut<C: Coordinates + 'static>(
291        &self,
292        layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
293    ) -> View<Line<E>, C, ReadWrite> {
294        let mut this: VirtualTensor<E, ReadWrite> = *self;
295        View::new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(&mut this, layout)
296    }
297    pub fn __expand_view_mut<C: Coordinates + 'static>(
298        scope: &mut Scope,
299        this: VirtualTensorExpand<E, ReadWrite>,
300        layout: VirtualLayoutExpand<C, Coords1d>,
301    ) -> ViewExpand<Line<E>, C, ReadWrite> {
302        this.__expand_view_mut_method::<C>(scope, layout)
303    }
304}
305impl<E: Numeric> VirtualTensorExpand<E, ReadWrite> {
306    pub fn __expand_view_mut_method<C: Coordinates + 'static>(
307        self,
308        scope: &mut Scope,
309        layout: VirtualLayoutExpand<C, Coords1d>,
310    ) -> ViewExpand<Line<E>, C, ReadWrite> {
311        View::__expand_new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(scope, self, layout)
312    }
313}
314
315#[cube]
316impl<E: Numeric> VirtualTensor<E, ReadWrite> {
317    /// Create a conceptual mutable view over this tensor, with a simple linear layout
318    pub fn as_view_mut(&mut self) -> View<Line<E>, usize, ReadWrite> {
319        let line_size = self.line_size();
320        View::new_mut::<VirtualTensor<E, ReadWrite>, usize>(
321            self,
322            SimpleLayout::new(self.len() * line_size, line_size),
323        )
324    }
325}
326
327#[cube]
328impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
329    pub fn coordinate(&self, index: usize, dim: usize) -> usize {
330        let num_strides = index / self.stride(dim);
331        num_strides % self.shape(dim)
332    }
333}
334
335impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
336    fn __expand_write(
337        scope: &mut Scope,
338        this: VirtualTensorExpand<E, ReadWrite>,
339        index: <usize as CubeType>::ExpandType,
340        value: <Line<E> as CubeType>::ExpandType,
341    ) -> <() as CubeType>::ExpandType {
342        this.__expand_write_method(scope, index, value)
343    }
344}
345
346impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
347    fn __expand_write_method(
348        &self,
349        scope: &mut Scope,
350        index: <usize as CubeType>::ExpandType,
351        value: <Line<E> as CubeType>::ExpandType,
352    ) -> <() as CubeType>::ExpandType {
353        self.state
354            .clone()
355            .__expand_write_method(scope, index, value)
356    }
357}
358
359impl<E: Numeric> SliceMutOperator<Line<E>> for VirtualTensor<E, ReadWrite> {}
360impl<E: Numeric> SliceMutOperatorExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
361    #[allow(unused_variables)]
362    fn __expand_slice_mut_method(
363        &self,
364        scope: &mut Scope,
365        start: ExpandElementTyped<usize>,
366        end: ExpandElementTyped<usize>,
367    ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
368        todo!("VirtualTensor don't support slice mut yet");
369    }
370
371    #[allow(unused_variables)]
372    fn __expand_to_slice_mut_method(
373        &self,
374        scope: &mut Scope,
375    ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
376        todo!("VirtualTensor don't support slice mut yet");
377    }
378}
379
380impl<E: Numeric> VirtualTensor<E, ReadOnly> {
381    /// Create a new [read only](ReadOnly) [virtual tensor](VirtualTensor).
382    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
383        unexpanded!()
384    }
385
386    /// Expand function of [`Self::new`].
387    pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
388        _scope: &mut Scope,
389        v: V::ExpandType,
390    ) -> VirtualTensorExpand<E, ReadOnly> {
391        VirtualTensorExpand {
392            state: Arc::new(v),
393            _p: PhantomData,
394        }
395    }
396}
397
398impl<E: Numeric> VirtualTensor<E, ReadWrite> {
399    /// Create a new [read write](ReadWrite) [virtual tensor](VirtualTensor).
400    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
401        unexpanded!()
402    }
403
404    /// Expand function of [`Self::new`].
405    pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
406        _scope: &mut Scope,
407        v: V::ExpandType,
408    ) -> VirtualTensorExpand<E, ReadWrite> {
409        VirtualTensorExpand {
410            state: Arc::new(v),
411            _p: PhantomData,
412        }
413    }
414}
415
416/// Trait to be implemented by a type that can become a [virtual tensor](VirtualTensor).
417///
418/// The [expand trait](VirtualTensorOperationsExpand) also need to be implemented for the type's
419/// expand type.
420///
421/// # Warning
422///
423/// This trait is kind of unsafe, [`VirtualTensorOperations::write`] doesn't follow the mutability
424/// rules, but it won't lead to any undefined behavior.
425#[cube(self_type = "ref", expand_base_traits = "LinedExpand")]
426pub trait VirtualTensorOperations<E: Numeric>: Lined {
427    fn as_tensor_map(&self) -> Option<TensorMap<E, Tiled>> {
428        unexpanded!()
429    }
430    /// Read the tensor at the given index.
431    fn read(&self, _index: usize) -> Line<E> {
432        unexpanded!()
433    }
434    fn read_window(&self, _start: usize, _end: usize) -> Slice<Line<E>, ReadOnly> {
435        unexpanded!()
436    }
437    /// Write the tensor at the given index.
438    fn write(&self, _index: usize, _value: Line<E>) {
439        unexpanded!()
440    }
441    /// Get the shape of the tensor at the given axis.
442    fn shape(&self, _axis: usize) -> usize {
443        unexpanded!()
444    }
445    /// Get the stride of the tensor at the given axis.
446    fn stride(&self, _axis: usize) -> usize {
447        unexpanded!()
448    }
449    /// Get the rank of the tensor.
450    fn rank(&self) -> usize {
451        unexpanded!()
452    }
453    fn len(&self) -> usize {
454        unexpanded!()
455    }
456    fn buffer_len(&self) -> usize {
457        unexpanded!()
458    }
459}
460
461/// Making [virtual tensors](VirtualTensor) a proper [cube type](CubeType).
462mod __cube_type {
463    use super::*;
464
465    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
466        type ExpandType = VirtualTensorExpand<E, IO>;
467    }
468
469    impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
470        fn into_mut(self, _scope: &mut Scope) -> Self {
471            self
472        }
473    }
474
475    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
476}
477
478/// Enable tensors to be virtual.
479mod __tensor {
480    use super::*;
481
482    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
483    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
484        fn __expand_read_method(
485            &self,
486            scope: &mut Scope,
487            index: ExpandElementTyped<usize>,
488        ) -> ExpandElementTyped<Line<E>> {
489            self.clone().__expand_index_unchecked_method(scope, index)
490        }
491        fn __expand_read_window_method(
492            &self,
493            context: &mut Scope,
494            start: ExpandElementTyped<usize>,
495            end: ExpandElementTyped<usize>,
496        ) -> SliceExpand<Line<E>, ReadOnly> {
497            self.clone().__expand_slice_method(context, start, end)
498        }
499
500        fn __expand_write_method(
501            &self,
502            scope: &mut Scope,
503            index: ExpandElementTyped<usize>,
504            value: ExpandElementTyped<Line<E>>,
505        ) {
506            self.clone()
507                .__expand_index_assign_unchecked_method(scope, index, value)
508        }
509
510        fn __expand_shape_method(
511            &self,
512            scope: &mut Scope,
513            axis: ExpandElementTyped<usize>,
514        ) -> ExpandElementTyped<usize> {
515            self.clone().__expand_shape_method(scope, axis)
516        }
517
518        fn __expand_stride_method(
519            &self,
520            scope: &mut Scope,
521            axis: ExpandElementTyped<usize>,
522        ) -> ExpandElementTyped<usize> {
523            self.clone().__expand_stride_method(scope, axis)
524        }
525
526        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
527            self.clone().__expand_rank_method(scope)
528        }
529        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
530            self.clone().__expand_len_method(scope)
531        }
532        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
533            self.clone().__expand_buffer_len_method(scope)
534        }
535
536        fn __expand_as_tensor_map_method(
537            &self,
538            scope: &mut Scope,
539        ) -> OptionExpand<TensorMap<E, Tiled>> {
540            Option::__expand_new_None(scope)
541        }
542    }
543}
544
545/// Enable tensor maps to be virtual.
546mod __tensor_map {
547    use super::*;
548
549    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E, Tiled> {}
550    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E, Tiled>> {
551        fn __expand_read_method(
552            &self,
553            _scope: &mut Scope,
554            _index: ExpandElementTyped<usize>,
555        ) -> ExpandElementTyped<Line<E>> {
556            todo!()
557        }
558        fn __expand_read_window_method(
559            &self,
560            _context: &mut Scope,
561            _start: ExpandElementTyped<usize>,
562            _end: ExpandElementTyped<usize>,
563        ) -> SliceExpand<Line<E>, ReadOnly> {
564            todo!()
565        }
566
567        fn __expand_write_method(
568            &self,
569            _scope: &mut Scope,
570            _index: ExpandElementTyped<usize>,
571            _value: ExpandElementTyped<Line<E>>,
572        ) {
573            todo!()
574        }
575
576        fn __expand_shape_method(
577            &self,
578            _scope: &mut Scope,
579            _axis: ExpandElementTyped<usize>,
580        ) -> ExpandElementTyped<usize> {
581            todo!()
582        }
583
584        fn __expand_stride_method(
585            &self,
586            _scope: &mut Scope,
587            _axis: ExpandElementTyped<usize>,
588        ) -> ExpandElementTyped<usize> {
589            todo!()
590        }
591
592        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
593            todo!()
594        }
595        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
596            todo!()
597        }
598        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<usize> {
599            todo!()
600        }
601
602        fn __expand_as_tensor_map_method(
603            &self,
604            scope: &mut Scope,
605        ) -> OptionExpand<TensorMap<E, Tiled>> {
606            Option::__expand_new_Some(scope, self.clone())
607        }
608    }
609}