cubecl_std/tensor/virtual.rs

1use alloc::sync::Arc;
2use core::marker::PhantomData;
3use cubecl::prelude::{CubeType, Scope, *};
4use cubecl_core::{self as cubecl, unexpanded};
5
/// The read tag for [virtual tensor](VirtualTensor).
///
/// A zero-sized marker used as the `IO` parameter of a [VirtualTensor] that only exposes
/// read operations.
#[derive(Debug, Clone, Copy)]
pub struct Read;
9
/// The read write tag for [virtual tensor](VirtualTensor).
///
/// A zero-sized marker used as the `IO` parameter of a [VirtualTensor] that exposes both
/// read and write operations.
#[derive(Debug, Clone, Copy)]
pub struct ReadWrite;
13
14/// Tensor representation that is decoupled from how the tensor is stored.
15#[derive(Clone)]
16pub struct VirtualTensor<E: Numeric, IO = Read> {
17    // state: Arc<dyn VirtualTensorOperations<E>>,
18    _e: PhantomData<E>,
19    _p: PhantomData<IO>,
20}
21
22impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
23
24/// Expand type for [VirtualTensor].
25#[derive(Clone)]
26pub struct VirtualTensorExpand<E: Numeric, IO> {
27    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
28    _p: PhantomData<IO>,
29}
30
31impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
32    fn __expand_read(
33        scope: &mut Scope,
34        this: VirtualTensorExpand<E, IO>,
35        index: <u32 as CubeType>::ExpandType,
36    ) -> <Line<E> as CubeType>::ExpandType {
37        this.__expand_read_method(scope, index)
38    }
39}
40
41impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
42    fn __expand_read_method(
43        &self,
44        scope: &mut Scope,
45        index: <u32 as CubeType>::ExpandType,
46    ) -> <Line<E> as CubeType>::ExpandType {
47        self.state.clone().__expand_read_method(scope, index)
48    }
49
50    fn __expand_read_unchecked_method(
51        &self,
52        _scope: &mut Scope,
53        _index: ExpandElementTyped<u32>,
54    ) -> <Line<E> as CubeType>::ExpandType {
55        todo!("VirtualTensor don't support read unchecked yet");
56    }
57}
58
59#[allow(unused, clippy::all)]
60impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
61    pub fn as_tensor_map(&self) -> TensorMap<E> {
62        unexpanded!()
63    }
64    pub fn as_slice(&self, start: u32, end: u32) -> Slice<Line<E>> {
65        unexpanded!();
66    }
67    /// Get the shape of the tensor at the given axis.
68    pub fn shape(&self, axis: u32) -> u32 {
69        unexpanded!();
70    }
71    /// Get the stride of the tensor at the given axis.
72    pub fn stride(&self, axis: u32) -> u32 {
73        unexpanded!();
74    }
75    /// Get the rank of the tensor.
76    pub fn rank(&self) -> u32 {
77        unexpanded!();
78    }
79
80    pub fn len(&self) -> u32 {
81        unexpanded!();
82    }
83
84    pub fn buffer_len(&self) -> u32 {
85        unexpanded!();
86    }
87
88    pub fn __expand_as_tensor_map(
89        context: &mut Scope,
90        this: <Self as CubeType>::ExpandType,
91    ) -> <TensorMap<E> as CubeType>::ExpandType {
92        this.__expand_as_tensor_map_method(context)
93    }
94    pub fn __expand_as_slice(
95        context: &mut Scope,
96        this: <Self as CubeType>::ExpandType,
97        start: <u32 as CubeType>::ExpandType,
98        end: <u32 as CubeType>::ExpandType,
99    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
100        this.__expand_as_slice_method(context, start, end)
101    }
102    pub fn __expand_shape(
103        scope: &mut Scope,
104        this: <Self as CubeType>::ExpandType,
105        axis: <u32 as CubeType>::ExpandType,
106    ) -> <u32 as CubeType>::ExpandType {
107        this.__expand_shape_method(scope, axis)
108    }
109    pub fn __expand_stride(
110        scope: &mut Scope,
111        this: <Self as CubeType>::ExpandType,
112        axis: <u32 as CubeType>::ExpandType,
113    ) -> <u32 as CubeType>::ExpandType {
114        this.__expand_stride_method(scope, axis)
115    }
116    pub fn __expand_rank(
117        scope: &mut Scope,
118        this: <Self as CubeType>::ExpandType,
119    ) -> <u32 as CubeType>::ExpandType {
120        this.__expand_rank_method(scope)
121    }
122    pub fn __expand_len(
123        scope: &mut Scope,
124        this: <Self as CubeType>::ExpandType,
125    ) -> <u32 as CubeType>::ExpandType {
126        this.__expand_len_method(scope)
127    }
128    pub fn __expand_buffer_len(
129        scope: &mut Scope,
130        this: <Self as CubeType>::ExpandType,
131    ) -> <u32 as CubeType>::ExpandType {
132        this.__expand_buffer_len_method(scope)
133    }
134}
135
136#[allow(unused, clippy::all)]
137impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
138    pub fn __expand_as_tensor_map_method(
139        self,
140        context: &mut Scope,
141    ) -> <TensorMap<E> as CubeType>::ExpandType {
142        self.state.clone().__expand_as_tensor_map_method(context)
143    }
144
145    pub fn __expand_as_slice_method(
146        self,
147        context: &mut Scope,
148        start: <u32 as CubeType>::ExpandType,
149        end: <u32 as CubeType>::ExpandType,
150    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
151        self.state
152            .clone()
153            .__expand_read_window_method(context, start, end)
154    }
155
156    pub fn __expand_shape_method(
157        self,
158        scope: &mut Scope,
159        axis: <u32 as CubeType>::ExpandType,
160    ) -> <u32 as CubeType>::ExpandType {
161        let _arg_0 = axis;
162        self.state
163            .clone()
164            .__expand_shape_method(scope, _arg_0.into())
165    }
166
167    pub fn __expand_stride_method(
168        self,
169        scope: &mut Scope,
170        axis: <u32 as CubeType>::ExpandType,
171    ) -> <u32 as CubeType>::ExpandType {
172        let _arg_0 = axis;
173        self.state
174            .clone()
175            .__expand_stride_method(scope, _arg_0.into())
176    }
177
178    pub fn __expand_rank_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
179        self.state.clone().__expand_rank_method(scope)
180    }
181
182    pub fn __expand_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
183        self.state.clone().__expand_len_method(scope)
184    }
185
186    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
187        self.state.clone().__expand_buffer_len_method(scope)
188    }
189
190    pub fn __expand_read(
191        scope: &mut Scope,
192        this: Self,
193        index: <u32 as CubeType>::ExpandType,
194    ) -> <Line<E> as CubeType>::ExpandType {
195        VirtualTensor::<E, IO>::__expand_read(scope, this, index)
196    }
197
198    pub fn __expand_shape(
199        scope: &mut Scope,
200        this: Self,
201        axis: <u32 as CubeType>::ExpandType,
202    ) -> <u32 as CubeType>::ExpandType {
203        VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
204    }
205
206    pub fn __expand_stride(
207        scope: &mut Scope,
208        this: Self,
209        axis: <u32 as CubeType>::ExpandType,
210    ) -> <u32 as CubeType>::ExpandType {
211        VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
212    }
213
214    pub fn __expand_rank(scope: &mut Scope, this: Self) -> <u32 as CubeType>::ExpandType {
215        VirtualTensor::<E, IO>::__expand_rank(scope, this)
216    }
217}
218
219#[cube]
220impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
221    pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
222        let num_strides = index / self.stride(dim);
223        num_strides % self.shape(dim)
224    }
225}
226
227impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
228    fn __expand_write(
229        scope: &mut Scope,
230        this: VirtualTensorExpand<E, ReadWrite>,
231        index: <u32 as CubeType>::ExpandType,
232        value: <Line<E> as CubeType>::ExpandType,
233    ) -> <() as CubeType>::ExpandType {
234        this.__expand_write_method(scope, index, value)
235    }
236}
237
238impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
239    fn __expand_write_method(
240        &self,
241        scope: &mut Scope,
242        index: <u32 as CubeType>::ExpandType,
243        value: <Line<E> as CubeType>::ExpandType,
244    ) -> <() as CubeType>::ExpandType {
245        self.state
246            .clone()
247            .__expand_write_method(scope, index, value)
248    }
249}
250
251impl<E: Numeric> VirtualTensor<E, Read> {
252    /// Create a new [read only](Read) [virtual tensor](VirtualTensor).
253    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
254        unexpanded!()
255    }
256
257    /// Expand function of [Self::new].
258    pub fn __expand_new<V>(_scope: &mut Scope, v: V::ExpandType) -> VirtualTensorExpand<E, Read>
259    where
260        V::ExpandType: VirtualTensorOperationsExpand<E>,
261        V: VirtualTensorOperations<E> + CubeType + 'static,
262    {
263        VirtualTensorExpand {
264            state: Arc::new(v),
265            _p: PhantomData,
266        }
267    }
268}
269
270impl<E: Numeric> VirtualTensor<E, ReadWrite> {
271    /// Create a new [read write](ReadWrite) [virtual tensor](VirtualTensor).
272    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
273        unexpanded!()
274    }
275
276    /// Expand function of [Self::new].
277    pub fn __expand_new<V>(
278        _scope: &mut Scope,
279        v: V::ExpandType,
280    ) -> VirtualTensorExpand<E, ReadWrite>
281    where
282        V::ExpandType: VirtualTensorOperationsExpand<E>,
283        V: VirtualTensorOperations<E> + CubeType + 'static,
284    {
285        VirtualTensorExpand {
286            state: Arc::new(v),
287            _p: PhantomData,
288        }
289    }
290}
291
292/// Trait to be implemented by a type that can become a [virtual tensor](VirtualTensor).
293///
294/// The [expand trait](VirtualTensorOperationsExpand) also need to be implemented for the type's
295/// expand type.
296///
297/// # Warning
298///
299/// This trait is kind of unsafe, [VirtualTensorOperations::write] doesn't follow the mutability
300/// rules, but it won't lead to any undefined behavior.
301pub trait VirtualTensorOperations<E: Numeric> {
302    /// Read the tensor at the given index.
303    fn read(&self, _index: u32) -> Line<E> {
304        unexpanded!()
305    }
306    /// Write the tensor at the given index.
307    fn write(&self, _index: u32, _value: Line<E>) {
308        unexpanded!()
309    }
310    /// Get the shape of the tensor at the given axis.
311    fn shape(&self, _axis: u32) -> u32 {
312        unexpanded!()
313    }
314    /// Get the stride of the tensor at the given axis.
315    fn stride(&self, _axis: u32) -> u32 {
316        unexpanded!()
317    }
318    /// Get the rank of the tensor.
319    fn rank(&self) -> u32 {
320        unexpanded!()
321    }
322}
323
324/// Expand trait for [VirtualTensorOperations].
325///
326/// For now this needs to be manually implemented for any type that needs to become a virtual
327/// tensor.
328pub trait VirtualTensorOperationsExpand<E: Numeric> {
329    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<E>>;
330    fn __expand_read_method(
331        &self,
332        scope: &mut Scope,
333        index: ExpandElementTyped<u32>,
334    ) -> ExpandElementTyped<Line<E>>;
335    fn __expand_read_window_method(
336        &self,
337        context: &mut Scope,
338        start: ExpandElementTyped<u32>,
339        end: ExpandElementTyped<u32>,
340    ) -> SliceExpand<Line<E>, ReadOnly>;
341    fn __expand_write_method(
342        &self,
343        scope: &mut Scope,
344        index: ExpandElementTyped<u32>,
345        value: ExpandElementTyped<Line<E>>,
346    );
347    fn __expand_shape_method(
348        &self,
349        scope: &mut Scope,
350        axis: ExpandElementTyped<u32>,
351    ) -> ExpandElementTyped<u32>;
352    fn __expand_stride_method(
353        &self,
354        scope: &mut Scope,
355        axis: ExpandElementTyped<u32>,
356    ) -> ExpandElementTyped<u32>;
357    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
358    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
359    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
360}
361
362/// Making [virtual tensors](VirtualTensor) a proper [cube type](CubeType).
363mod __cube_type {
364    use super::*;
365
366    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
367        type ExpandType = VirtualTensorExpand<E, IO>;
368    }
369
370    impl<E: Numeric, IO> IntoMut for VirtualTensorExpand<E, IO> {
371        fn into_mut(self, _scope: &mut Scope) -> Self {
372            self
373        }
374    }
375
376    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
377}
378
379/// Enable tensors to be virtual.
380mod __tensor {
381    use super::*;
382
383    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
384    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
385        fn __expand_read_method(
386            &self,
387            scope: &mut Scope,
388            index: ExpandElementTyped<u32>,
389        ) -> ExpandElementTyped<Line<E>> {
390            self.clone().__expand_index_unchecked_method(scope, index)
391        }
392        fn __expand_read_window_method(
393            &self,
394            context: &mut Scope,
395            start: ExpandElementTyped<u32>,
396            end: ExpandElementTyped<u32>,
397        ) -> SliceExpand<Line<E>, ReadOnly> {
398            self.clone().__expand_slice_method(context, start, end)
399        }
400
401        fn __expand_write_method(
402            &self,
403            scope: &mut Scope,
404            index: ExpandElementTyped<u32>,
405            value: ExpandElementTyped<Line<E>>,
406        ) {
407            self.clone()
408                .__expand_index_assign_unchecked_method(scope, index, value)
409        }
410
411        fn __expand_shape_method(
412            &self,
413            scope: &mut Scope,
414            axis: ExpandElementTyped<u32>,
415        ) -> ExpandElementTyped<u32> {
416            self.clone().__expand_shape_method(scope, axis)
417        }
418
419        fn __expand_stride_method(
420            &self,
421            scope: &mut Scope,
422            axis: ExpandElementTyped<u32>,
423        ) -> ExpandElementTyped<u32> {
424            self.clone().__expand_stride_method(scope, axis)
425        }
426
427        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
428            self.clone().__expand_rank_method(scope)
429        }
430        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
431            self.clone().__expand_len_method(scope)
432        }
433        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
434            self.clone().__expand_buffer_len_method(scope)
435        }
436
437        fn __expand_as_tensor_map_method(
438            &self,
439            _scope: &mut Scope,
440        ) -> ExpandElementTyped<TensorMap<E>> {
441            unimplemented!("Can't turn normal tensor into `TensorMap`");
442        }
443    }
444}
445
446/// Enable tensor maps to be virtual.
447mod __tensor_map {
448    use super::*;
449
450    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E> {}
451    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E>> {
452        fn __expand_read_method(
453            &self,
454            _scope: &mut Scope,
455            _index: ExpandElementTyped<u32>,
456        ) -> ExpandElementTyped<Line<E>> {
457            todo!()
458        }
459        fn __expand_read_window_method(
460            &self,
461            _context: &mut Scope,
462            _start: ExpandElementTyped<u32>,
463            _end: ExpandElementTyped<u32>,
464        ) -> SliceExpand<Line<E>, ReadOnly> {
465            todo!()
466        }
467
468        fn __expand_write_method(
469            &self,
470            _scope: &mut Scope,
471            _index: ExpandElementTyped<u32>,
472            _value: ExpandElementTyped<Line<E>>,
473        ) {
474            todo!()
475        }
476
477        fn __expand_shape_method(
478            &self,
479            _scope: &mut Scope,
480            _axis: ExpandElementTyped<u32>,
481        ) -> ExpandElementTyped<u32> {
482            todo!()
483        }
484
485        fn __expand_stride_method(
486            &self,
487            _scope: &mut Scope,
488            _axis: ExpandElementTyped<u32>,
489        ) -> ExpandElementTyped<u32> {
490            todo!()
491        }
492
493        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
494            todo!()
495        }
496        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
497            todo!()
498        }
499        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
500            todo!()
501        }
502
503        fn __expand_as_tensor_map_method(
504            &self,
505            _scope: &mut Scope,
506        ) -> ExpandElementTyped<TensorMap<E>> {
507            self.clone()
508        }
509    }
510}