cubecl_std/tensor/
virtual.rs

1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::{CubeType, Scope, *};
4use cubecl_core::{self as cubecl, unexpanded};
5
6use crate::{
7    CubeOption,
8    tensor::{
9        ViewExpand,
10        layout::{
11            Coordinates, Coords1d, Layout, VirtualLayout, VirtualLayoutExpand, simple::SimpleLayout,
12        },
13        view::View,
14    },
15};
16
/// Tensor representation that is decoupled from how the tensor is stored.
///
/// The struct itself only carries the element type `E` and the access marker `IO`
/// (defaulting to [`ReadOnly`]) as [`PhantomData`]; it holds no runtime data.
#[derive(Clone)]
pub struct VirtualTensor<E: Numeric, IO = ReadOnly> {
    // No state is stored on this side: the type-erased implementation
    // (`Arc<dyn VirtualTensorOperationsExpand<E>>`) lives in `VirtualTensorExpand`.
    _e: PhantomData<E>,
    _p: PhantomData<IO>,
}

// Only phantom markers are stored, so the value can be copied freely.
impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
26
/// Expand type for [VirtualTensor].
///
/// Wraps a type-erased [`VirtualTensorOperationsExpand`] implementation; the methods
/// implemented on this type forward to it.
#[derive(Clone)]
pub struct VirtualTensorExpand<E: Numeric, IO> {
    // Shared, type-erased backing implementation (a tensor, a tensor map, ...).
    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
    // Access marker (`ReadOnly` / `ReadWrite`); type-level only.
    _p: PhantomData<IO>,
}
33
34impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
35    fn __expand_read(
36        scope: &mut Scope,
37        this: VirtualTensorExpand<E, IO>,
38        index: <u32 as CubeType>::ExpandType,
39    ) -> <Line<E> as CubeType>::ExpandType {
40        this.__expand_read_method(scope, index)
41    }
42}
43
44impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
45    fn __expand_read_method(
46        &self,
47        scope: &mut Scope,
48        index: <u32 as CubeType>::ExpandType,
49    ) -> <Line<E> as CubeType>::ExpandType {
50        self.state.clone().__expand_read_method(scope, index)
51    }
52
53    fn __expand_read_unchecked_method(
54        &self,
55        _scope: &mut Scope,
56        _index: ExpandElementTyped<u32>,
57    ) -> <Line<E> as CubeType>::ExpandType {
58        todo!("VirtualTensor don't support read unchecked yet");
59    }
60
61    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
62        self.state.clone().__expand_len_method(scope)
63    }
64}
65
66impl<E: Numeric, IO: Clone> Lined for VirtualTensor<E, IO> {}
67impl<E: Numeric, IO: Clone> LinedExpand for VirtualTensorExpand<E, IO> {
68    fn line_size(&self) -> u32 {
69        self.state.clone().line_size()
70    }
71}
72
73impl<E: Numeric, IO: Clone> SliceOperator<Line<E>> for VirtualTensor<E, IO> {}
74impl<E: Numeric, IO: Clone> SliceOperatorExpand<Line<E>> for VirtualTensorExpand<E, IO> {
75    fn __expand_slice_method(
76        &self,
77        scope: &mut Scope,
78        start: ExpandElementTyped<u32>,
79        end: ExpandElementTyped<u32>,
80    ) -> SliceExpand<Line<E>, ReadOnly> {
81        self.state
82            .clone()
83            .__expand_read_window_method(scope, start, end)
84    }
85
86    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<Line<E>, ReadOnly> {
87        let end = self.clone().__expand_buffer_len_method(scope);
88        self.state
89            .clone()
90            .__expand_read_window_method(scope, 0.into(), end)
91    }
92}
93
94#[allow(unused, clippy::all)]
95impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
96    pub fn as_tensor_map(&self) -> CubeOption<TensorMap<E>> {
97        unexpanded!()
98    }
99    pub fn as_slice(&self, start: u32, end: u32) -> Slice<Line<E>> {
100        unexpanded!();
101    }
102    /// Get the shape of the tensor at the given axis.
103    pub fn shape(&self, axis: u32) -> u32 {
104        unexpanded!();
105    }
106    /// Get the stride of the tensor at the given axis.
107    pub fn stride(&self, axis: u32) -> u32 {
108        unexpanded!();
109    }
110    /// Get the rank of the tensor.
111    pub fn rank(&self) -> u32 {
112        unexpanded!();
113    }
114
115    pub fn buffer_len(&self) -> u32 {
116        unexpanded!();
117    }
118
119    pub fn __expand_as_tensor_map(
120        context: &mut Scope,
121        this: <Self as CubeType>::ExpandType,
122    ) -> <CubeOption<TensorMap<E>> as CubeType>::ExpandType {
123        this.__expand_as_tensor_map_method(context)
124    }
125    pub fn __expand_as_slice(
126        context: &mut Scope,
127        this: <Self as CubeType>::ExpandType,
128        start: <u32 as CubeType>::ExpandType,
129        end: <u32 as CubeType>::ExpandType,
130    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
131        this.__expand_as_slice_method(context, start, end)
132    }
133    pub fn __expand_shape(
134        scope: &mut Scope,
135        this: <Self as CubeType>::ExpandType,
136        axis: <u32 as CubeType>::ExpandType,
137    ) -> <u32 as CubeType>::ExpandType {
138        this.__expand_shape_method(scope, axis)
139    }
140    pub fn __expand_stride(
141        scope: &mut Scope,
142        this: <Self as CubeType>::ExpandType,
143        axis: <u32 as CubeType>::ExpandType,
144    ) -> <u32 as CubeType>::ExpandType {
145        this.__expand_stride_method(scope, axis)
146    }
147    pub fn __expand_rank(
148        scope: &mut Scope,
149        this: <Self as CubeType>::ExpandType,
150    ) -> <u32 as CubeType>::ExpandType {
151        this.__expand_rank_method(scope)
152    }
153    pub fn __expand_buffer_len(
154        scope: &mut Scope,
155        this: <Self as CubeType>::ExpandType,
156    ) -> <u32 as CubeType>::ExpandType {
157        this.__expand_buffer_len_method(scope)
158    }
159}
160
161#[allow(unused, clippy::all)]
162impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
163    pub fn __expand_as_tensor_map_method(
164        self,
165        context: &mut Scope,
166    ) -> <CubeOption<TensorMap<E>> as CubeType>::ExpandType {
167        self.state.clone().__expand_as_tensor_map_method(context)
168    }
169
170    pub fn __expand_as_slice_method(
171        self,
172        context: &mut Scope,
173        start: <u32 as CubeType>::ExpandType,
174        end: <u32 as CubeType>::ExpandType,
175    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
176        self.state
177            .clone()
178            .__expand_read_window_method(context, start, end)
179    }
180
181    pub fn __expand_shape_method(
182        self,
183        scope: &mut Scope,
184        axis: <u32 as CubeType>::ExpandType,
185    ) -> <u32 as CubeType>::ExpandType {
186        let _arg_0 = axis;
187        self.state
188            .clone()
189            .__expand_shape_method(scope, _arg_0.into())
190    }
191
192    pub fn __expand_stride_method(
193        self,
194        scope: &mut Scope,
195        axis: <u32 as CubeType>::ExpandType,
196    ) -> <u32 as CubeType>::ExpandType {
197        let _arg_0 = axis;
198        self.state
199            .clone()
200            .__expand_stride_method(scope, _arg_0.into())
201    }
202
203    pub fn __expand_rank_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
204        self.state.clone().__expand_rank_method(scope)
205    }
206
207    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
208        self.state.clone().__expand_buffer_len_method(scope)
209    }
210
211    pub fn __expand_read(
212        scope: &mut Scope,
213        this: Self,
214        index: <u32 as CubeType>::ExpandType,
215    ) -> <Line<E> as CubeType>::ExpandType {
216        VirtualTensor::<E, IO>::__expand_read(scope, this, index)
217    }
218
219    pub fn __expand_shape(
220        scope: &mut Scope,
221        this: Self,
222        axis: <u32 as CubeType>::ExpandType,
223    ) -> <u32 as CubeType>::ExpandType {
224        VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
225    }
226
227    pub fn __expand_stride(
228        scope: &mut Scope,
229        this: Self,
230        axis: <u32 as CubeType>::ExpandType,
231    ) -> <u32 as CubeType>::ExpandType {
232        VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
233    }
234
235    pub fn __expand_rank(scope: &mut Scope, this: Self) -> <u32 as CubeType>::ExpandType {
236        VirtualTensor::<E, IO>::__expand_rank(scope, this)
237    }
238}
239
240impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
241    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
242    /// layouts
243    pub fn view<C: Coordinates + 'static>(
244        &self,
245        layout: impl Into<VirtualLayout<C, Coords1d>>,
246    ) -> View<Line<E>, C, ReadOnly> {
247        View::new::<VirtualTensor<E, IO>, Coords1d>(self, layout)
248    }
249}
250
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Create a conceptual view over this tensor, with a simple linear layout
    ///
    /// The layout is a [`SimpleLayout`] built from `self.len() * line_size` and
    /// `line_size`, and the view is indexed with a single `u32` coordinate.
    pub fn as_view(&self) -> View<Line<E>, u32, ReadOnly> {
        let line_size = self.line_size();
        // NOTE(review): `len() * line_size` is presumably the length counted in
        // scalar elements rather than lines — confirm against `SimpleLayout::new`.
        View::new::<VirtualTensor<E, IO>, u32>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
262
263impl<E: Numeric, IO: Clone + 'static> VirtualTensorExpand<E, IO> {
264    /// Create a conceptual view over this tensor, allowing for multi-dimensional indexing with custom
265    /// layouts
266    pub fn __expand_view_method<C: Coordinates + 'static>(
267        &self,
268        scope: &mut Scope,
269        layout: VirtualLayoutExpand<C, Coords1d>,
270    ) -> ViewExpand<Line<E>, C, ReadOnly> {
271        View::__expand_new::<VirtualTensor<E, IO>, Coords1d>(scope, self.clone(), layout)
272    }
273}
274
275impl<E: Numeric> VirtualTensor<E, ReadWrite> {
276    #[doc = " Create a mutable conceptual view over this tensor, allowing for multi-dimensional indexing"]
277    #[doc = " with custom layouts"]
278    pub fn view_mut<C: Coordinates + 'static>(
279        &self,
280        layout: impl Layout<Coordinates = C, SourceCoordinates = Coords1d> + 'static,
281    ) -> View<Line<E>, C, ReadWrite> {
282        let mut this: VirtualTensor<E, ReadWrite> = *self;
283        View::new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(&mut this, layout)
284    }
285    pub fn __expand_view_mut<C: Coordinates + 'static>(
286        scope: &mut Scope,
287        this: VirtualTensorExpand<E, ReadWrite>,
288        layout: VirtualLayoutExpand<C, Coords1d>,
289    ) -> ViewExpand<Line<E>, C, ReadWrite> {
290        this.__expand_view_mut_method::<C>(scope, layout)
291    }
292}
293impl<E: Numeric> VirtualTensorExpand<E, ReadWrite> {
294    pub fn __expand_view_mut_method<C: Coordinates + 'static>(
295        self,
296        scope: &mut Scope,
297        layout: VirtualLayoutExpand<C, Coords1d>,
298    ) -> ViewExpand<Line<E>, C, ReadWrite> {
299        View::__expand_new_mut::<VirtualTensor<E, ReadWrite>, Coords1d>(scope, self, layout)
300    }
301}
302
#[cube]
impl<E: Numeric> VirtualTensor<E, ReadWrite> {
    /// Create a conceptual mutable view over this tensor, with a simple linear layout
    ///
    /// The layout is a [`SimpleLayout`] built from `self.len() * line_size` and
    /// `line_size`, and the view is indexed with a single `u32` coordinate.
    pub fn as_view_mut(&mut self) -> View<Line<E>, u32, ReadWrite> {
        let line_size = self.line_size();
        // NOTE(review): `len() * line_size` is presumably the length counted in
        // scalar elements rather than lines — confirm against `SimpleLayout::new`.
        View::new_mut::<VirtualTensor<E, ReadWrite>, u32>(
            self,
            SimpleLayout::new(self.len() * line_size, line_size),
        )
    }
}
314
#[cube]
impl<E: Numeric, IO: Clone + 'static> VirtualTensor<E, IO> {
    /// Get the coordinate along `dim` that corresponds to the linear `index`.
    ///
    /// Computed as `(index / stride(dim)) % shape(dim)`.
    pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
        // Number of whole strides of `dim` contained in `index`.
        let num_strides = index / self.stride(dim);
        // Wrap by the extent of `dim` to drop the contribution of other dimensions.
        num_strides % self.shape(dim)
    }
}
322
323impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
324    fn __expand_write(
325        scope: &mut Scope,
326        this: VirtualTensorExpand<E, ReadWrite>,
327        index: <u32 as CubeType>::ExpandType,
328        value: <Line<E> as CubeType>::ExpandType,
329    ) -> <() as CubeType>::ExpandType {
330        this.__expand_write_method(scope, index, value)
331    }
332}
333
334impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
335    fn __expand_write_method(
336        &self,
337        scope: &mut Scope,
338        index: <u32 as CubeType>::ExpandType,
339        value: <Line<E> as CubeType>::ExpandType,
340    ) -> <() as CubeType>::ExpandType {
341        self.state
342            .clone()
343            .__expand_write_method(scope, index, value)
344    }
345}
346
347impl<E: Numeric> SliceMutOperator<Line<E>> for VirtualTensor<E, ReadWrite> {}
348impl<E: Numeric> SliceMutOperatorExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
349    #[allow(unused_variables)]
350    fn __expand_slice_mut_method(
351        &self,
352        scope: &mut Scope,
353        start: ExpandElementTyped<u32>,
354        end: ExpandElementTyped<u32>,
355    ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
356        todo!("VirtualTensor don't support slice mut yet");
357    }
358
359    #[allow(unused_variables)]
360    fn __expand_to_slice_mut_method(
361        &self,
362        scope: &mut Scope,
363    ) -> SliceExpand<Line<E>, cubecl_core::prelude::ReadWrite> {
364        todo!("VirtualTensor don't support slice mut yet");
365    }
366}
367
368impl<E: Numeric> VirtualTensor<E, ReadOnly> {
369    /// Create a new [read only](Read) [virtual tensor](VirtualTensor).
370    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
371        unexpanded!()
372    }
373
374    /// Expand function of [Self::new].
375    pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
376        _scope: &mut Scope,
377        v: V::ExpandType,
378    ) -> VirtualTensorExpand<E, ReadOnly> {
379        VirtualTensorExpand {
380            state: Arc::new(v),
381            _p: PhantomData,
382        }
383    }
384}
385
386impl<E: Numeric> VirtualTensor<E, ReadWrite> {
387    /// Create a new [read write](ReadWrite) [virtual tensor](VirtualTensor).
388    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
389        unexpanded!()
390    }
391
392    /// Expand function of [Self::new].
393    pub fn __expand_new<V: VirtualTensorOperations<E> + 'static>(
394        _scope: &mut Scope,
395        v: V::ExpandType,
396    ) -> VirtualTensorExpand<E, ReadWrite> {
397        VirtualTensorExpand {
398            state: Arc::new(v),
399            _p: PhantomData,
400        }
401    }
402}
403
/// Trait to be implemented by a type that can become a [virtual tensor](VirtualTensor).
///
/// The [expand trait](VirtualTensorOperationsExpand) also needs to be implemented for the type's
/// expand type.
///
/// Every default body here is `unexpanded!()`.
///
/// # Warning
///
/// This trait is kind of unsafe, [VirtualTensorOperations::write] doesn't follow the mutability
/// rules, but it won't lead to any undefined behavior.
#[cube(self_type = "ref", expand_base_traits = "LinedExpand")]
pub trait VirtualTensorOperations<E: Numeric>: Lined {
    /// Get the tensor as a [`TensorMap`], wrapped in a [`CubeOption`].
    fn as_tensor_map(&self) -> CubeOption<TensorMap<E>> {
        unexpanded!()
    }
    /// Read the tensor at the given index.
    fn read(&self, _index: u32) -> Line<E> {
        unexpanded!()
    }
    /// Get a read-only slice of the tensor delimited by `_start` and `_end`.
    fn read_window(&self, _start: u32, _end: u32) -> Slice<Line<E>, ReadOnly> {
        unexpanded!()
    }
    /// Write the tensor at the given index.
    ///
    /// Takes `&self` even though it writes; see the warning on the trait.
    fn write(&self, _index: u32, _value: Line<E>) {
        unexpanded!()
    }
    /// Get the shape of the tensor at the given axis.
    fn shape(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Get the stride of the tensor at the given axis.
    fn stride(&self, _axis: u32) -> u32 {
        unexpanded!()
    }
    /// Get the rank of the tensor.
    fn rank(&self) -> u32 {
        unexpanded!()
    }
    /// Get the length of the tensor.
    // NOTE(review): the unit (lines vs. scalar elements) is not visible here — confirm.
    fn len(&self) -> u32 {
        unexpanded!()
    }
    /// Get the length of the underlying buffer.
    // NOTE(review): how this differs from `len` is not visible here — confirm.
    fn buffer_len(&self) -> u32 {
        unexpanded!()
    }
}
448
/// Making [virtual tensors](VirtualTensor) a proper [cube type](CubeType).
mod __cube_type {
    use super::*;

    // Ties `VirtualTensor` to its expand representation.
    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
        type ExpandType = VirtualTensorExpand<E, IO>;
    }

    // `into_mut` is the identity for this type: the value is returned unchanged.
    impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
        fn into_mut(self, _scope: &mut Scope) -> Self {
            self
        }
    }

    // Empty impl: relies on whatever defaults `CubeDebug` provides.
    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
}
465
/// Enable tensors to be virtual.
///
/// Each method forwards to the same-named operation on a clone of the tensor's
/// expand value, except `as_tensor_map`, which returns `None`.
mod __tensor {
    use crate::CubeOptionExpand;

    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
    // NOTE(review): several methods below share their name with the operation they call
    // on `self.clone()`. The clone presumably gives an owned receiver so the call resolves
    // to the tensor's own by-value method instead of recursing into this trait method —
    // confirm before removing any of these clones.
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
        // Reads use the unchecked index operation.
        fn __expand_read_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Line<E>> {
            self.clone().__expand_index_unchecked_method(scope, index)
        }
        // A read window is a plain slice of the tensor.
        fn __expand_read_window_method(
            &self,
            context: &mut Scope,
            start: ExpandElementTyped<u32>,
            end: ExpandElementTyped<u32>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            self.clone().__expand_slice_method(context, start, end)
        }

        // Writes use the unchecked index-assign operation.
        fn __expand_write_method(
            &self,
            scope: &mut Scope,
            index: ExpandElementTyped<u32>,
            value: ExpandElementTyped<Line<E>>,
        ) {
            self.clone()
                .__expand_index_assign_unchecked_method(scope, index, value)
        }

        fn __expand_shape_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_shape_method(scope, axis)
        }

        fn __expand_stride_method(
            &self,
            scope: &mut Scope,
            axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            self.clone().__expand_stride_method(scope, axis)
        }

        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_rank_method(scope)
        }
        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_len_method(scope)
        }
        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
            self.clone().__expand_buffer_len_method(scope)
        }

        // A plain tensor has no tensor map to expose.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> CubeOptionExpand<TensorMap<E>> {
            CubeOption::__expand_new_None(scope)
        }
    }
}
534
/// Enable tensor maps to be virtual.
///
/// Only `as_tensor_map` is implemented; every other operation is a `todo!()` stub
/// that panics when called.
mod __tensor_map {
    use crate::CubeOptionExpand;

    use super::*;

    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E> {}
    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E>> {
        // Not implemented yet: panics.
        fn __expand_read_method(
            &self,
            _scope: &mut Scope,
            _index: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<Line<E>> {
            todo!()
        }
        // Not implemented yet: panics.
        fn __expand_read_window_method(
            &self,
            _context: &mut Scope,
            _start: ExpandElementTyped<u32>,
            _end: ExpandElementTyped<u32>,
        ) -> SliceExpand<Line<E>, ReadOnly> {
            todo!()
        }

        // Not implemented yet: panics.
        fn __expand_write_method(
            &self,
            _scope: &mut Scope,
            _index: ExpandElementTyped<u32>,
            _value: ExpandElementTyped<Line<E>>,
        ) {
            todo!()
        }

        // Not implemented yet: panics.
        fn __expand_shape_method(
            &self,
            _scope: &mut Scope,
            _axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            todo!()
        }

        // Not implemented yet: panics.
        fn __expand_stride_method(
            &self,
            _scope: &mut Scope,
            _axis: ExpandElementTyped<u32>,
        ) -> ExpandElementTyped<u32> {
            todo!()
        }

        // Not implemented yet: panics.
        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
            todo!()
        }
        // Not implemented yet: panics.
        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
            todo!()
        }
        // Not implemented yet: panics.
        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
            todo!()
        }

        // A tensor map exposes itself: returns `Some` holding a clone of `self`.
        fn __expand_as_tensor_map_method(
            &self,
            scope: &mut Scope,
        ) -> CubeOptionExpand<TensorMap<E>> {
            CubeOption::__expand_new_Some(scope, self.clone())
        }
    }
}