cubecl_std/tensor/view/operations/virtual_view.rs

1use std::marker::PhantomData;
2
3use super::*;
4use crate::{
5    CubeOptionExpand,
6    tensor::layout::{Coordinates, VirtualLayout, VirtualLayoutExpand},
7};
8use cubecl::prelude::*;
9use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
10
/// A read-only view over a wrapped view `V`, addressed through a [`VirtualLayout`].
///
/// Callers index this view with coordinates of type `C`. `layout` translates each of them
/// into the source coordinates `S` understood by `view`, and the operation is then forwarded
/// to `view`. Only read operations are implemented for this type; `VirtualViewMut` is the
/// writable counterpart.
#[derive(CubeType)]
pub struct VirtualView<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>> {
    // Wrapped view that ultimately serves the reads. The operations are implemented on the
    // generated `VirtualViewExpand`, not on this struct, which is presumably why the fields
    // carry `allow(unused)`.
    #[allow(unused)]
    view: V,
    // Maps virtual coordinates `C` to the wrapped view's source coordinates `S`.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    // Comptime marker that ties the element type `T` to the struct.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
20
#[cube]
impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
    VirtualView<T, C, S, V>
{
    /// Wraps `view` so that it is addressed through the coordinates of `layout`.
    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
        // NOTE(review): the struct is named explicitly instead of using `Self`, presumably
        // because of how `#[cube]` expands struct literals. Confirm before simplifying.
        VirtualView::<T, C, S, V> {
            view,
            layout,
            _ty: PhantomData,
        }
    }
}
33
34impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
35    VirtualViewExpand<T, C, S, V>
36{
37    pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
38        VirtualViewExpand::<T, C, S, V> {
39            view,
40            layout,
41            _ty: PhantomData,
42        }
43    }
44}
45
/// A writable view over a wrapped view `V`, addressed through a [`VirtualLayout`].
///
/// Callers index this view with coordinates of type `C`. `layout` translates each of them
/// into the source coordinates `S` understood by `view`, and the read or write is then
/// forwarded to `view`. Both the read operations and the mutable operations are implemented
/// for this type.
#[derive(CubeType)]
pub struct VirtualViewMut<
    T: CubePrimitive,
    C: Coordinates,
    S: Coordinates,
    V: ViewOperationsMut<T, S>,
> {
    // Wrapped view that ultimately serves reads and writes. The operations are implemented on
    // the generated `VirtualViewMutExpand`, not on this struct, which is presumably why the
    // fields carry `allow(unused)`.
    #[allow(unused)]
    view: V,
    // Maps virtual coordinates `C` to the wrapped view's source coordinates `S`.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    // Comptime marker that ties the element type `T` to the struct.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
60
#[cube]
impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
    VirtualViewMut<T, C, S, V>
{
    /// Wraps the writable `view` so that it is addressed through the coordinates of `layout`.
    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
        // NOTE(review): the struct is named explicitly instead of using `Self`, presumably
        // because of how `#[cube]` expands struct literals. Confirm before simplifying.
        VirtualViewMut::<T, C, S, V> {
            view,
            layout,
            _ty: PhantomData,
        }
    }
}
73
74impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
75    VirtualViewMutExpand<T, C, S, V>
76{
77    pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
78        VirtualViewMutExpand::<T, C, S, V> {
79            view,
80            layout,
81            _ty: PhantomData,
82        }
83    }
84}
85
// Implements the read-side view traits for a virtual view type:
// `Lined`/`LinedExpand` plus `ViewOperations`/`ViewOperationsExpand` over the virtual
// coordinate type `C`.
//
// * `$ty`     - the cube-side struct (`VirtualView` or `VirtualViewMut`).
// * `$expand` - its generated expand type (`VirtualViewExpand` or `VirtualViewMutExpand`).
// * `$trait`  - the trait the wrapped view `V` must implement over source coordinates `S`.
//
// Every position-taking operation maps `pos` (and `end`, where present) through `layout`
// into source coordinates, then forwards the call to the wrapped `view`. `shape` comes from
// the layout; `line_size` and `as_tensor_map` come from the wrapped view.
macro_rules! impl_virtual_read {
    ($ty: ident, $expand: ident, $trait: ident) => {
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> Lined for $ty<T, C, S, V> where
            V: $trait<T, S>
        {
        }
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> LinedExpand
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            // Line size is taken directly from the wrapped view.
            fn line_size(&self) -> u32 {
                self.view.line_size()
            }
        }

        // Marker impl: the operations themselves are implemented on the expand type below.
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperations<T, C>
            for $ty<T, C, S, V>
        where
            V: $trait<T, S>,
        {
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsExpand<T, C>
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            // Unchecked mapping through the layout, then the wrapped view's plain `read`.
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_method(scope, pos)
            }

            // Positions the layout reports as out of bounds yield zero. Bounds on the wrapped
            // view are left to that view's own `read_checked`.
            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let zero = T::__expand_cast_from(scope, 0.into());
                let value = self.view.__expand_read_checked_method(scope, read_pos);
                select::expand::<T>(scope, in_bounds, value, zero)
            }

            // Positions that are out of bounds, for either the layout or the wrapped view,
            // yield `mask_value`.
            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                // Forward the mask to the wrapped view too. Otherwise a position that is in
                // bounds for the layout but out of bounds for the view would return the
                // view's `read_checked` fallback instead of `mask_value`.
                let value = self
                    .view
                    .__expand_read_masked_method(scope, read_pos, mask_value.clone());
                select::expand::<T>(scope, in_bounds, value, mask_value)
            }

            // Unchecked mapping through the layout, then the wrapped view's unchecked read.
            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_unchecked_method(scope, pos)
            }

            // Both endpoints are mapped independently through the layout before slicing the
            // wrapped view.
            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                end: <C>::ExpandType,
            ) -> SliceExpand<T, ReadOnly> {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                let end = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, end);
                self.view.__expand_to_linear_slice_method(scope, pos, end)
            }

            fn __expand_as_tensor_map_method(
                &self,
                scope: &mut Scope,
            ) -> CubeOptionExpand<TensorMap<T>> {
                self.view.__expand_as_tensor_map_method(scope)
            }

            // The virtual shape is the layout's shape, not the wrapped view's.
            fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
                self.layout.clone().__expand_shape_method(scope)
            }

            // In bounds only if both the layout and the wrapped view accept the position.
            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: C::ExpandType,
            ) -> ExpandElementTyped<bool> {
                let (pos, in_bounds_layout) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let in_bounds_view = self.view.__expand_is_in_bounds_method(scope, pos);
                and::expand(scope, in_bounds_layout, in_bounds_view)
            }

            fn __expand_tensor_map_load_method(
                &self,
                scope: &mut Scope,
                barrier: BarrierExpand,
                shared_memory: SliceExpand<T, ReadWrite>,
                pos: C::ExpandType,
            ) {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view
                    .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos);
            }
        }
    };
}
224
225impl_virtual_read!(VirtualView, VirtualViewExpand, ViewOperations);
226impl_virtual_read!(VirtualViewMut, VirtualViewMutExpand, ViewOperationsMut);
227
228impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMut<T, C>
229    for VirtualViewMut<T, C, S, V>
230where
231    V: ViewOperationsMut<T, S>,
232{
233}
234
235impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMutExpand<T, C>
236    for VirtualViewMutExpand<T, C, S, V>
237where
238    V: ViewOperationsMut<T, S>,
239{
240    fn __expand_write_method(
241        &self,
242        scope: &mut Scope,
243        pos: <C>::ExpandType,
244        value: <T>::ExpandType,
245    ) {
246        let pos = self
247            .layout
248            .clone()
249            .__expand_to_source_pos_method(scope, pos);
250        self.view.__expand_write_method(scope, pos, value);
251    }
252
253    fn __expand_write_checked_method(
254        &self,
255        scope: &mut Scope,
256        pos: <C>::ExpandType,
257        value: <T>::ExpandType,
258    ) {
259        let (pos, in_bounds) = self
260            .layout
261            .clone()
262            .__expand_to_source_pos_checked_method(scope, pos);
263        if_expand(scope, in_bounds.into(), |scope| {
264            self.view.__expand_write_checked_method(scope, pos, value);
265        });
266    }
267
268    fn __expand_to_linear_slice_mut_method(
269        &self,
270        scope: &mut Scope,
271        pos: <C>::ExpandType,
272        end: <C>::ExpandType,
273    ) -> SliceExpand<T, ReadWrite> {
274        let pos = self
275            .layout
276            .clone()
277            .__expand_to_source_pos_method(scope, pos);
278        let end = self
279            .layout
280            .clone()
281            .__expand_to_source_pos_method(scope, end);
282        self.view
283            .__expand_to_linear_slice_mut_method(scope, pos, end)
284    }
285
286    fn __expand_tensor_map_store_method(
287        &self,
288        scope: &mut Scope,
289        shared_memory: SliceExpand<T, ReadOnly>,
290        pos: C::ExpandType,
291    ) {
292        let pos = self
293            .layout
294            .clone()
295            .__expand_to_source_pos_method(scope, pos);
296        self.view
297            .__expand_tensor_map_store_method(scope, shared_memory, pos);
298    }
299}