// cubecl_std/tensor/view/operations/virtual_view.rs

1use std::marker::PhantomData;
2
3use super::*;
4use crate::{
5    CubeOptionExpand,
6    tensor::layout::{Coordinates, VirtualLayout, VirtualLayoutExpand},
7};
8use cubecl::prelude::*;
9use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
10
11#[derive(CubeType)]
12pub struct VirtualView<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>> {
13    #[allow(unused)]
14    view: V,
15    #[allow(unused)]
16    layout: VirtualLayout<C, S>,
17    #[cube(comptime)]
18    _ty: PhantomData<T>,
19}
20
21#[cube]
22impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
23    VirtualView<T, C, S, V>
24{
25    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
26        VirtualView::<T, C, S, V> {
27            view,
28            layout,
29            _ty: PhantomData,
30        }
31    }
32}
33
34impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
35    VirtualViewExpand<T, C, S, V>
36{
37    pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
38        VirtualViewExpand::<T, C, S, V> {
39            view,
40            layout,
41            _ty: PhantomData,
42        }
43    }
44}
45
46#[derive(CubeType)]
47pub struct VirtualViewMut<
48    T: CubePrimitive,
49    C: Coordinates,
50    S: Coordinates,
51    V: ViewOperationsMut<T, S>,
52> {
53    #[allow(unused)]
54    view: V,
55    #[allow(unused)]
56    layout: VirtualLayout<C, S>,
57    #[cube(comptime)]
58    _ty: PhantomData<T>,
59}
60
61#[cube]
62impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
63    VirtualViewMut<T, C, S, V>
64{
65    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
66        VirtualViewMut::<T, C, S, V> {
67            view,
68            layout,
69            _ty: PhantomData,
70        }
71    }
72}
73
74impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
75    VirtualViewMutExpand<T, C, S, V>
76{
77    pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
78        VirtualViewMutExpand::<T, C, S, V> {
79            view,
80            layout,
81            _ty: PhantomData,
82        }
83    }
84}
85
// Implements the read-side traits (`Lined`/`LinedExpand` and
// `ViewOperations`/`ViewOperationsExpand` over virtual coordinates `C`) for a virtual view.
//
// * `$ty`     - the host-side view type.
// * `$expand` - its expand type (the one that owns `view` and `layout` at expand time).
// * `$trait`  - the bound the wrapped view `V` must satisfy over source coordinates `S`.
//
// Every positional operation maps the virtual position to a source position with the
// stored layout and then forwards to the wrapped view.
macro_rules! impl_virtual_read {
    ($ty: ident, $expand: ident, $trait: ident) => {
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> Lined for $ty<T, C, S, V> where
            V: $trait<T, S>
        {
        }
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> LinedExpand
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            // Forwarded unchanged from the wrapped view.
            fn line_size(&self) -> u32 {
                self.view.line_size()
            }
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperations<T, C>
            for $ty<T, C, S, V>
        where
            V: $trait<T, S>,
        {
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsExpand<T, C>
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            // Unchecked mapping of `pos`, then a plain read from the wrapped view.
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_method(scope, pos)
            }

            // Positions the layout reports as out of bounds yield zero; the wrapped view's
            // checked read guards the source access itself.
            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let zero = T::__expand_cast_from(scope, 0.into());
                let value = self.view.__expand_read_checked_method(scope, read_pos);
                select::expand::<T>(scope, in_bounds, value, zero)
            }

            // Positions out of bounds in either the virtual layout or the wrapped view
            // yield `mask_value`.
            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                // Pass the mask down as well. A checked read would fall back to its own
                // default (zero in this file's checked read) for positions that are valid
                // virtually but out of bounds in the wrapped view, instead of `mask_value`.
                let value =
                    self.view
                        .__expand_read_masked_method(scope, read_pos, mask_value.clone());
                select::expand::<T>(scope, in_bounds, value, mask_value)
            }

            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_unchecked_method(scope, pos)
            }

            // Both endpoints go through the same layout mapping. NOTE(review): this assumes
            // the layout maps the virtual range [pos, end) onto a contiguous source range;
            // confirm for non-trivial layouts.
            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                end: <C>::ExpandType,
            ) -> SliceExpand<T, ReadOnly> {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                let end = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, end);
                self.view.__expand_to_linear_slice_method(scope, pos, end)
            }

            fn __expand_as_tensor_map_method(
                &self,
                scope: &mut Scope,
            ) -> CubeOptionExpand<TensorMap<T>> {
                self.view.__expand_as_tensor_map_method(scope)
            }

            // The shape is the virtual shape reported by the layout, not the source shape.
            fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
                self.layout.clone().__expand_shape_method(scope)
            }

            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: C::ExpandType,
            ) -> ExpandElementTyped<bool> {
                self.layout.clone().__expand_is_in_bounds_method(scope, pos)
            }

            fn __expand_tensor_map_load_method(
                &self,
                scope: &mut Scope,
                barrier: BarrierExpand,
                shared_memory: SliceExpand<T, ReadWrite>,
                pos: C::ExpandType,
            ) {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view
                    .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos);
            }
        }
    };
}
219
220impl_virtual_read!(VirtualView, VirtualViewExpand, ViewOperations);
221impl_virtual_read!(VirtualViewMut, VirtualViewMutExpand, ViewOperationsMut);
222
223impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMut<T, C>
224    for VirtualViewMut<T, C, S, V>
225where
226    V: ViewOperationsMut<T, S>,
227{
228}
229
230impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMutExpand<T, C>
231    for VirtualViewMutExpand<T, C, S, V>
232where
233    V: ViewOperationsMut<T, S>,
234{
235    fn __expand_write_method(
236        &self,
237        scope: &mut Scope,
238        pos: <C>::ExpandType,
239        value: <T>::ExpandType,
240    ) {
241        let pos = self
242            .layout
243            .clone()
244            .__expand_to_source_pos_method(scope, pos);
245        self.view.__expand_write_method(scope, pos, value);
246    }
247
248    fn __expand_write_checked_method(
249        &self,
250        scope: &mut Scope,
251        pos: <C>::ExpandType,
252        value: <T>::ExpandType,
253    ) {
254        let (pos, in_bounds) = self
255            .layout
256            .clone()
257            .__expand_to_source_pos_checked_method(scope, pos);
258        if_expand(scope, in_bounds.into(), |scope| {
259            self.view.__expand_write_checked_method(scope, pos, value);
260        });
261    }
262
263    fn __expand_to_linear_slice_mut_method(
264        &self,
265        scope: &mut Scope,
266        pos: <C>::ExpandType,
267        end: <C>::ExpandType,
268    ) -> SliceExpand<T, ReadWrite> {
269        let pos = self
270            .layout
271            .clone()
272            .__expand_to_source_pos_method(scope, pos);
273        let end = self
274            .layout
275            .clone()
276            .__expand_to_source_pos_method(scope, end);
277        self.view
278            .__expand_to_linear_slice_mut_method(scope, pos, end)
279    }
280
281    fn __expand_tensor_map_store_method(
282        &self,
283        scope: &mut Scope,
284        shared_memory: SliceExpand<T, ReadOnly>,
285        pos: C::ExpandType,
286    ) {
287        let pos = self
288            .layout
289            .clone()
290            .__expand_to_source_pos_method(scope, pos);
291        self.view
292            .__expand_tensor_map_store_method(scope, shared_memory, pos);
293    }
294}