// cubecl_std/tensor/view/operations/virtual_view.rs

1use std::marker::PhantomData;
2
3use super::*;
4use crate::tensor::layout::{Coordinates, VirtualLayout, VirtualLayoutExpand};
5use cubecl::prelude::*;
6use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
7
8#[derive(CubeType)]
9pub struct VirtualView<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>> {
10    #[allow(unused)]
11    view: V,
12    #[allow(unused)]
13    layout: VirtualLayout<C, S>,
14    #[cube(comptime)]
15    _ty: PhantomData<T>,
16}
17
18#[cube]
19impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
20    VirtualView<T, C, S, V>
21{
22    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
23        VirtualView::<T, C, S, V> {
24            view,
25            layout,
26            _ty: PhantomData,
27        }
28    }
29}
30
31impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
32    VirtualViewExpand<T, C, S, V>
33{
34    pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
35        VirtualViewExpand::<T, C, S, V> {
36            view,
37            layout,
38            _ty: PhantomData,
39        }
40    }
41}
42
43#[derive(CubeType)]
44pub struct VirtualViewMut<
45    T: CubePrimitive,
46    C: Coordinates,
47    S: Coordinates,
48    V: ViewOperationsMut<T, S>,
49> {
50    #[allow(unused)]
51    view: V,
52    #[allow(unused)]
53    layout: VirtualLayout<C, S>,
54    #[cube(comptime)]
55    _ty: PhantomData<T>,
56}
57
58#[cube]
59impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
60    VirtualViewMut<T, C, S, V>
61{
62    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
63        VirtualViewMut::<T, C, S, V> {
64            view,
65            layout,
66            _ty: PhantomData,
67        }
68    }
69}
70
71impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
72    VirtualViewMutExpand<T, C, S, V>
73{
74    pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
75        VirtualViewMutExpand::<T, C, S, V> {
76            view,
77            layout,
78            _ty: PhantomData,
79        }
80    }
81}
82
83macro_rules! impl_virtual_read {
84    ($ty: ident, $expand: ident, $trait: ident) => {
85        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> Lined for $ty<T, C, S, V> where
86            V: $trait<T, S>
87        {
88        }
89        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> LinedExpand
90            for $expand<T, C, S, V>
91        where
92            V: $trait<T, S>,
93        {
94            fn line_size(&self) -> u32 {
95                self.view.line_size()
96            }
97        }
98
99        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperations<T, C>
100            for $ty<T, C, S, V>
101        where
102            V: $trait<T, S>,
103        {
104        }
105
106        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsExpand<T, C>
107            for $expand<T, C, S, V>
108        where
109            V: $trait<T, S>,
110        {
111            fn __expand_read_method(
112                &self,
113                scope: &mut Scope,
114                pos: <C>::ExpandType,
115            ) -> <T>::ExpandType {
116                let pos = self
117                    .layout
118                    .clone()
119                    .__expand_to_source_pos_method(scope, pos);
120                self.view.clone().__expand_read_method(scope, pos)
121            }
122
123            fn __expand_read_checked_method(
124                &self,
125                scope: &mut Scope,
126                pos: <C>::ExpandType,
127            ) -> <T>::ExpandType {
128                let (read_pos, in_bounds) = self
129                    .layout
130                    .clone()
131                    .__expand_to_source_pos_checked_method(scope, pos);
132                let zero = T::__expand_cast_from(scope, 0.into());
133                let value = self.view.__expand_read_checked_method(scope, read_pos);
134                select::expand::<T>(scope, in_bounds, value, zero)
135            }
136
137            fn __expand_read_masked_method(
138                &self,
139                scope: &mut Scope,
140                pos: <C>::ExpandType,
141                mask_value: <T>::ExpandType,
142            ) -> <T>::ExpandType {
143                let (read_pos, in_bounds) = self
144                    .layout
145                    .clone()
146                    .__expand_to_source_pos_checked_method(scope, pos);
147                let value = self.view.__expand_read_checked_method(scope, read_pos);
148                select::expand::<T>(scope, in_bounds, value, mask_value)
149            }
150
151            fn __expand_read_unchecked_method(
152                &self,
153                scope: &mut Scope,
154                pos: <C>::ExpandType,
155            ) -> <T>::ExpandType {
156                let pos = self
157                    .layout
158                    .clone()
159                    .__expand_to_source_pos_method(scope, pos);
160                self.view.__expand_read_unchecked_method(scope, pos)
161            }
162
163            fn __expand_to_linear_slice_method(
164                &self,
165                scope: &mut Scope,
166                pos: <C>::ExpandType,
167                end: <C>::ExpandType,
168            ) -> SliceExpand<T, ReadOnly> {
169                let pos = self
170                    .layout
171                    .clone()
172                    .__expand_to_source_pos_method(scope, pos);
173                let end = self
174                    .layout
175                    .clone()
176                    .__expand_to_source_pos_method(scope, end);
177                self.view.__expand_to_linear_slice_method(scope, pos, end)
178            }
179
180            fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
181                self.layout.clone().__expand_shape_method(scope)
182            }
183
184            fn __expand_is_in_bounds_method(
185                &self,
186                scope: &mut Scope,
187                pos: C::ExpandType,
188            ) -> ExpandElementTyped<bool> {
189                let (pos, in_bounds_layout) = self
190                    .layout
191                    .clone()
192                    .__expand_to_source_pos_checked_method(scope, pos);
193                let in_bounds_view = self.view.clone().__expand_is_in_bounds_method(scope, pos);
194                and::expand(scope, in_bounds_layout, in_bounds_view)
195            }
196
197            fn __expand_tensor_map_load_method(
198                &self,
199                scope: &mut Scope,
200                barrier: BarrierExpand,
201                shared_memory: SliceExpand<T, ReadWrite>,
202                pos: C::ExpandType,
203            ) {
204                let pos = self
205                    .layout
206                    .clone()
207                    .__expand_to_source_pos_method(scope, pos);
208                self.view
209                    .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos);
210            }
211        }
212    };
213}
214
215impl_virtual_read!(VirtualView, VirtualViewExpand, ViewOperations);
216impl_virtual_read!(VirtualViewMut, VirtualViewMutExpand, ViewOperationsMut);
217
218impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMut<T, C>
219    for VirtualViewMut<T, C, S, V>
220where
221    V: ViewOperationsMut<T, S>,
222{
223}
224
225impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMutExpand<T, C>
226    for VirtualViewMutExpand<T, C, S, V>
227where
228    V: ViewOperationsMut<T, S>,
229{
230    fn __expand_write_method(
231        &self,
232        scope: &mut Scope,
233        pos: <C>::ExpandType,
234        value: <T>::ExpandType,
235    ) {
236        let pos = self
237            .layout
238            .clone()
239            .__expand_to_source_pos_method(scope, pos);
240        self.view.__expand_write_method(scope, pos, value);
241    }
242
243    fn __expand_write_checked_method(
244        &self,
245        scope: &mut Scope,
246        pos: <C>::ExpandType,
247        value: <T>::ExpandType,
248    ) {
249        let (pos, in_bounds) = self
250            .layout
251            .clone()
252            .__expand_to_source_pos_checked_method(scope, pos);
253        if_expand(scope, in_bounds.into(), |scope| {
254            self.view.__expand_write_checked_method(scope, pos, value);
255        });
256    }
257
258    fn __expand_to_linear_slice_mut_method(
259        &self,
260        scope: &mut Scope,
261        pos: <C>::ExpandType,
262        end: <C>::ExpandType,
263    ) -> SliceExpand<T, ReadWrite> {
264        let pos = self
265            .layout
266            .clone()
267            .__expand_to_source_pos_method(scope, pos);
268        let end = self
269            .layout
270            .clone()
271            .__expand_to_source_pos_method(scope, end);
272        self.view
273            .__expand_to_linear_slice_mut_method(scope, pos, end)
274    }
275
276    fn __expand_tensor_map_store_method(
277        &self,
278        scope: &mut Scope,
279        shared_memory: SliceExpand<T, ReadOnly>,
280        pos: C::ExpandType,
281    ) {
282        let pos = self
283            .layout
284            .clone()
285            .__expand_to_source_pos_method(scope, pos);
286        self.view
287            .__expand_tensor_map_store_method(scope, shared_memory, pos);
288    }
289}