// cubecl_std/tensor/virtual.rs

1use cubecl::prelude::{CubeType, Scope, *};
2use cubecl_core::{self as cubecl, unexpanded};
3use std::{marker::PhantomData, sync::Arc};
4
/// Access tag selecting read-only access for a [virtual tensor](VirtualTensor).
///
/// This is also the default value of the `IO` parameter of [VirtualTensor].
#[derive(Clone)]
pub struct Read;
8
/// Access tag selecting read and write access for a [virtual tensor](VirtualTensor).
///
/// Writing through [ListMut] is only implemented for `VirtualTensor<E, ReadWrite>`.
#[derive(Clone)]
pub struct ReadWrite;
12
13/// Tensor representation that is decoupled from how the tensor is stored.
14#[derive(Clone)]
15pub struct VirtualTensor<E: Numeric, IO = Read> {
16    // state: Arc<dyn VirtualTensorOperations<E>>,
17    _e: PhantomData<E>,
18    _p: PhantomData<IO>,
19}
20
21impl<E: Numeric, IO: Clone> Copy for VirtualTensor<E, IO> {}
22
23/// Expand type for [VirtualTensor].
24#[derive(Clone)]
25pub struct VirtualTensorExpand<E: Numeric, IO> {
26    state: Arc<dyn VirtualTensorOperationsExpand<E>>,
27    _p: PhantomData<IO>,
28}
29
30impl<E: Numeric, IO: Clone> List<Line<E>> for VirtualTensor<E, IO> {
31    fn __expand_read(
32        scope: &mut Scope,
33        this: VirtualTensorExpand<E, IO>,
34        index: <u32 as CubeType>::ExpandType,
35    ) -> <Line<E> as CubeType>::ExpandType {
36        this.__expand_read_method(scope, index)
37    }
38}
39
40impl<E: Numeric, IO: Clone> ListExpand<Line<E>> for VirtualTensorExpand<E, IO> {
41    fn __expand_read_method(
42        self,
43        scope: &mut Scope,
44        index: <u32 as CubeType>::ExpandType,
45    ) -> <Line<E> as CubeType>::ExpandType {
46        self.state.clone().__expand_read_method(scope, index)
47    }
48}
49
50#[allow(unused, clippy::all)]
51impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
52    pub fn as_tensor_map(&self) -> TensorMap<E> {
53        unexpanded!()
54    }
55    pub fn as_slice(&self, start: u32, end: u32) -> Slice<Line<E>> {
56        unexpanded!();
57    }
58    /// Get the shape of the tensor at the given axis.
59    pub fn shape(&self, axis: u32) -> u32 {
60        unexpanded!();
61    }
62    /// Get the stride of the tensor at the given axis.
63    pub fn stride(&self, axis: u32) -> u32 {
64        unexpanded!();
65    }
66    /// Get the rank of the tensor.
67    pub fn rank(&self) -> u32 {
68        unexpanded!();
69    }
70
71    pub fn len(&self) -> u32 {
72        unexpanded!();
73    }
74
75    pub fn buffer_len(&self) -> u32 {
76        unexpanded!();
77    }
78
79    pub fn __expand_as_tensor_map(
80        context: &mut Scope,
81        this: <Self as CubeType>::ExpandType,
82    ) -> <TensorMap<E> as CubeType>::ExpandType {
83        this.__expand_as_tensor_map_method(context)
84    }
85    pub fn __expand_as_slice(
86        context: &mut Scope,
87        this: <Self as CubeType>::ExpandType,
88        start: <u32 as CubeType>::ExpandType,
89        end: <u32 as CubeType>::ExpandType,
90    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
91        this.__expand_as_slice_method(context, start, end)
92    }
93    pub fn __expand_shape(
94        scope: &mut Scope,
95        this: <Self as CubeType>::ExpandType,
96        axis: <u32 as CubeType>::ExpandType,
97    ) -> <u32 as CubeType>::ExpandType {
98        this.__expand_shape_method(scope, axis)
99    }
100    pub fn __expand_stride(
101        scope: &mut Scope,
102        this: <Self as CubeType>::ExpandType,
103        axis: <u32 as CubeType>::ExpandType,
104    ) -> <u32 as CubeType>::ExpandType {
105        this.__expand_stride_method(scope, axis)
106    }
107    pub fn __expand_rank(
108        scope: &mut Scope,
109        this: <Self as CubeType>::ExpandType,
110    ) -> <u32 as CubeType>::ExpandType {
111        this.__expand_rank_method(scope)
112    }
113    pub fn __expand_len(
114        scope: &mut Scope,
115        this: <Self as CubeType>::ExpandType,
116    ) -> <u32 as CubeType>::ExpandType {
117        this.__expand_len_method(scope)
118    }
119    pub fn __expand_buffer_len(
120        scope: &mut Scope,
121        this: <Self as CubeType>::ExpandType,
122    ) -> <u32 as CubeType>::ExpandType {
123        this.__expand_buffer_len_method(scope)
124    }
125}
126
127#[allow(unused, clippy::all)]
128impl<E: Numeric, IO: Clone> VirtualTensorExpand<E, IO> {
129    pub fn __expand_as_tensor_map_method(
130        self,
131        context: &mut Scope,
132    ) -> <TensorMap<E> as CubeType>::ExpandType {
133        self.state.clone().__expand_as_tensor_map_method(context)
134    }
135
136    pub fn __expand_as_slice_method(
137        self,
138        context: &mut Scope,
139        start: <u32 as CubeType>::ExpandType,
140        end: <u32 as CubeType>::ExpandType,
141    ) -> <Slice<Line<E>> as CubeType>::ExpandType {
142        self.state
143            .clone()
144            .__expand_read_window_method(context, start, end)
145    }
146
147    pub fn __expand_shape_method(
148        self,
149        scope: &mut Scope,
150        axis: <u32 as CubeType>::ExpandType,
151    ) -> <u32 as CubeType>::ExpandType {
152        let _arg_0 = axis;
153        self.state
154            .clone()
155            .__expand_shape_method(scope, _arg_0.into())
156    }
157
158    pub fn __expand_stride_method(
159        self,
160        scope: &mut Scope,
161        axis: <u32 as CubeType>::ExpandType,
162    ) -> <u32 as CubeType>::ExpandType {
163        let _arg_0 = axis;
164        self.state
165            .clone()
166            .__expand_stride_method(scope, _arg_0.into())
167    }
168
169    pub fn __expand_rank_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
170        self.state.clone().__expand_rank_method(scope)
171    }
172
173    pub fn __expand_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
174        self.state.clone().__expand_len_method(scope)
175    }
176
177    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> <u32 as CubeType>::ExpandType {
178        self.state.clone().__expand_buffer_len_method(scope)
179    }
180
181    pub fn __expand_read(
182        scope: &mut Scope,
183        this: Self,
184        index: <u32 as CubeType>::ExpandType,
185    ) -> <Line<E> as CubeType>::ExpandType {
186        VirtualTensor::<E, IO>::__expand_read(scope, this, index)
187    }
188
189    pub fn __expand_shape(
190        scope: &mut Scope,
191        this: Self,
192        axis: <u32 as CubeType>::ExpandType,
193    ) -> <u32 as CubeType>::ExpandType {
194        VirtualTensor::<E, IO>::__expand_shape(scope, this, axis)
195    }
196
197    pub fn __expand_stride(
198        scope: &mut Scope,
199        this: Self,
200        axis: <u32 as CubeType>::ExpandType,
201    ) -> <u32 as CubeType>::ExpandType {
202        VirtualTensor::<E, IO>::__expand_stride(scope, this, axis)
203    }
204
205    pub fn __expand_rank(scope: &mut Scope, this: Self) -> <u32 as CubeType>::ExpandType {
206        VirtualTensor::<E, IO>::__expand_rank(scope, this)
207    }
208}
209
210#[cube]
211impl<E: Numeric, IO: Clone> VirtualTensor<E, IO> {
212    pub fn coordinate(&self, index: u32, dim: u32) -> u32 {
213        let num_strides = index / self.stride(dim);
214        num_strides % self.shape(dim)
215    }
216}
217
218impl<E: Numeric> ListMut<Line<E>> for VirtualTensor<E, ReadWrite> {
219    fn __expand_write(
220        scope: &mut Scope,
221        this: VirtualTensorExpand<E, ReadWrite>,
222        index: <u32 as CubeType>::ExpandType,
223        value: <Line<E> as CubeType>::ExpandType,
224    ) -> <() as CubeType>::ExpandType {
225        this.__expand_write_method(scope, index, value)
226    }
227}
228
229impl<E: Numeric> ListMutExpand<Line<E>> for VirtualTensorExpand<E, ReadWrite> {
230    fn __expand_write_method(
231        self,
232        scope: &mut Scope,
233        index: <u32 as CubeType>::ExpandType,
234        value: <Line<E> as CubeType>::ExpandType,
235    ) -> <() as CubeType>::ExpandType {
236        self.state
237            .clone()
238            .__expand_write_method(scope, index, value)
239    }
240}
241
242impl<E: Numeric> VirtualTensor<E, Read> {
243    /// Create a new [read only](Read) [virtual tensor](VirtualTensor).
244    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &V) -> Self {
245        unexpanded!()
246    }
247
248    /// Expand function of [Self::new].
249    pub fn __expand_new<V>(_scope: &mut Scope, v: V::ExpandType) -> VirtualTensorExpand<E, Read>
250    where
251        V::ExpandType: VirtualTensorOperationsExpand<E>,
252        V: VirtualTensorOperations<E> + CubeType + 'static,
253    {
254        VirtualTensorExpand {
255            state: Arc::new(v),
256            _p: PhantomData,
257        }
258    }
259}
260
261impl<E: Numeric> VirtualTensor<E, ReadWrite> {
262    /// Create a new [read write](ReadWrite) [virtual tensor](VirtualTensor).
263    pub fn new<V: VirtualTensorOperations<E> + 'static>(_v: &mut V) -> Self {
264        unexpanded!()
265    }
266
267    /// Expand function of [Self::new].
268    pub fn __expand_new<V>(
269        _scope: &mut Scope,
270        v: V::ExpandType,
271    ) -> VirtualTensorExpand<E, ReadWrite>
272    where
273        V::ExpandType: VirtualTensorOperationsExpand<E>,
274        V: VirtualTensorOperations<E> + CubeType + 'static,
275    {
276        VirtualTensorExpand {
277            state: Arc::new(v),
278            _p: PhantomData,
279        }
280    }
281}
282
283/// Trait to be implemented by a type that can become a [virtual tensor](VirtualTensor).
284///
285/// The [expand trait](VirtualTensorOperationsExpand) also need to be implemented for the type's
286/// expand type.
287///
288/// # Warning
289///
290/// This trait is kind of unsafe, [VirtualTensorOperations::write] doesn't follow the mutability
291/// rules, but it won't lead to any undefined behavior.
292pub trait VirtualTensorOperations<E: Numeric> {
293    /// Read the tensor at the given index.
294    fn read(&self, _index: u32) -> Line<E> {
295        unexpanded!()
296    }
297    /// Write the tensor at the given index.
298    fn write(&self, _index: u32, _value: Line<E>) {
299        unexpanded!()
300    }
301    /// Get the shape of the tensor at the given axis.
302    fn shape(&self, _axis: u32) -> u32 {
303        unexpanded!()
304    }
305    /// Get the stride of the tensor at the given axis.
306    fn stride(&self, _axis: u32) -> u32 {
307        unexpanded!()
308    }
309    /// Get the rank of the tensor.
310    fn rank(&self) -> u32 {
311        unexpanded!()
312    }
313}
314
315/// Expand trait for [VirtualTensorOperations].
316///
317/// For now this needs to be manually implemented for any type that needs to become a virtual
318/// tensor.
319pub trait VirtualTensorOperationsExpand<E: Numeric> {
320    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<E>>;
321    fn __expand_read_method(
322        &self,
323        scope: &mut Scope,
324        index: ExpandElementTyped<u32>,
325    ) -> ExpandElementTyped<Line<E>>;
326    fn __expand_read_window_method(
327        &self,
328        context: &mut Scope,
329        start: ExpandElementTyped<u32>,
330        end: ExpandElementTyped<u32>,
331    ) -> ExpandElementTyped<Slice<Line<E>>>;
332    fn __expand_write_method(
333        &self,
334        scope: &mut Scope,
335        index: ExpandElementTyped<u32>,
336        value: ExpandElementTyped<Line<E>>,
337    );
338    fn __expand_shape_method(
339        &self,
340        scope: &mut Scope,
341        axis: ExpandElementTyped<u32>,
342    ) -> ExpandElementTyped<u32>;
343    fn __expand_stride_method(
344        &self,
345        scope: &mut Scope,
346        axis: ExpandElementTyped<u32>,
347    ) -> ExpandElementTyped<u32>;
348    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
349    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
350    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32>;
351}
352
353/// Making [virtual tensors](VirtualTensor) a proper [cube type](CubeType).
354mod __cube_type {
355    use super::*;
356
357    impl<E: Numeric, IO: Clone> CubeType for VirtualTensor<E, IO> {
358        type ExpandType = VirtualTensorExpand<E, IO>;
359    }
360
361    impl<E: Numeric, IO> Init for VirtualTensorExpand<E, IO> {
362        fn init(self, _scope: &mut Scope) -> Self {
363            self
364        }
365    }
366
367    impl<E: Numeric, IO> CubeDebug for VirtualTensorExpand<E, IO> {}
368}
369
370/// Enable tensors to be virtual.
371mod __tensor {
372    use super::*;
373
374    impl<E: Numeric> VirtualTensorOperations<E> for Tensor<Line<E>> {}
375    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<Tensor<Line<E>>> {
376        fn __expand_read_method(
377            &self,
378            scope: &mut Scope,
379            index: ExpandElementTyped<u32>,
380        ) -> ExpandElementTyped<Line<E>> {
381            self.clone().__expand_index_unchecked_method(scope, index)
382        }
383        fn __expand_read_window_method(
384            &self,
385            context: &mut Scope,
386            start: ExpandElementTyped<u32>,
387            end: ExpandElementTyped<u32>,
388        ) -> ExpandElementTyped<Slice<Line<E>>> {
389            self.clone().__expand_slice_method(context, start, end)
390        }
391
392        fn __expand_write_method(
393            &self,
394            scope: &mut Scope,
395            index: ExpandElementTyped<u32>,
396            value: ExpandElementTyped<Line<E>>,
397        ) {
398            self.clone()
399                .__expand_index_assign_unchecked_method(scope, index, value)
400        }
401
402        fn __expand_shape_method(
403            &self,
404            scope: &mut Scope,
405            axis: ExpandElementTyped<u32>,
406        ) -> ExpandElementTyped<u32> {
407            self.clone().__expand_shape_method(scope, axis)
408        }
409
410        fn __expand_stride_method(
411            &self,
412            scope: &mut Scope,
413            axis: ExpandElementTyped<u32>,
414        ) -> ExpandElementTyped<u32> {
415            self.clone().__expand_stride_method(scope, axis)
416        }
417
418        fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
419            self.clone().__expand_rank_method(scope)
420        }
421        fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
422            self.clone().__expand_len_method(scope)
423        }
424        fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
425            self.clone().__expand_buffer_len_method(scope)
426        }
427
428        fn __expand_as_tensor_map_method(
429            &self,
430            _scope: &mut Scope,
431        ) -> ExpandElementTyped<TensorMap<E>> {
432            unimplemented!("Can't turn normal tensor into `TensorMap`");
433        }
434    }
435}
436
437/// Enable tensor maps to be virtual.
438mod __tensor_map {
439    use super::*;
440
441    impl<E: Numeric> VirtualTensorOperations<E> for TensorMap<E> {}
442    impl<E: Numeric> VirtualTensorOperationsExpand<E> for ExpandElementTyped<TensorMap<E>> {
443        fn __expand_read_method(
444            &self,
445            _scope: &mut Scope,
446            _index: ExpandElementTyped<u32>,
447        ) -> ExpandElementTyped<Line<E>> {
448            todo!()
449        }
450        fn __expand_read_window_method(
451            &self,
452            _context: &mut Scope,
453            _start: ExpandElementTyped<u32>,
454            _end: ExpandElementTyped<u32>,
455        ) -> ExpandElementTyped<Slice<Line<E>>> {
456            todo!()
457        }
458
459        fn __expand_write_method(
460            &self,
461            _scope: &mut Scope,
462            _index: ExpandElementTyped<u32>,
463            _value: ExpandElementTyped<Line<E>>,
464        ) {
465            todo!()
466        }
467
468        fn __expand_shape_method(
469            &self,
470            _scope: &mut Scope,
471            _axis: ExpandElementTyped<u32>,
472        ) -> ExpandElementTyped<u32> {
473            todo!()
474        }
475
476        fn __expand_stride_method(
477            &self,
478            _scope: &mut Scope,
479            _axis: ExpandElementTyped<u32>,
480        ) -> ExpandElementTyped<u32> {
481            todo!()
482        }
483
484        fn __expand_rank_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
485            todo!()
486        }
487        fn __expand_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
488            todo!()
489        }
490        fn __expand_buffer_len_method(&self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
491            todo!()
492        }
493
494        fn __expand_as_tensor_map_method(
495            &self,
496            _scope: &mut Scope,
497        ) -> ExpandElementTyped<TensorMap<E>> {
498            self.clone()
499        }
500    }
501}