1use std::marker::PhantomData;
2
3use super::*;
4use crate::tensor::layout::{Coordinates, VirtualLayout, VirtualLayoutExpand};
5use cubecl::prelude::*;
6use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
7
/// A view in virtual coordinates `C` over another view `V` that is addressed in source
/// coordinates `S`.
///
/// `layout` translates positions from `C` into `S`. The read operations generated by
/// `impl_virtual_read!` apply that translation and then delegate to `view`.
#[derive(CubeType)]
pub struct VirtualView<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>> {
    /// The wrapped view, addressed in source coordinates `S`.
    // NOTE(review): the fields only appear to be read through the derived expand type,
    // which is presumably why `unused` is allowed on them.
    #[allow(unused)]
    view: V,
    /// Maps virtual coordinates `C` to the wrapped view's source coordinates `S`.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    /// Carries the element type `T`, which no other field stores directly.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
17
#[cube]
impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
    VirtualView<T, C, S, V>
{
    /// Wraps `view` so that it is addressed through `layout`.
    ///
    /// `layout` maps virtual coordinates `C` to the view's source coordinates `S`.
    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
        // NOTE(review): the full type path is spelled out instead of `Self`; presumably the
        // `#[cube]` expansion relies on it — confirm before simplifying.
        VirtualView::<T, C, S, V> {
            view,
            layout,
            _ty: PhantomData,
        }
    }
}
30
31impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
32 VirtualViewExpand<T, C, S, V>
33{
34 pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
35 VirtualViewExpand::<T, C, S, V> {
36 view,
37 layout,
38 _ty: PhantomData,
39 }
40 }
41}
42
/// A writable view in virtual coordinates `C` over another writable view `V` that is
/// addressed in source coordinates `S`.
///
/// `layout` translates positions from `C` into `S` for both the read operations (generated by
/// `impl_virtual_read!`) and the write operations (`ViewOperationsMutExpand`).
#[derive(CubeType)]
pub struct VirtualViewMut<
    T: CubePrimitive,
    C: Coordinates,
    S: Coordinates,
    V: ViewOperationsMut<T, S>,
> {
    /// The wrapped writable view, addressed in source coordinates `S`.
    // NOTE(review): the fields only appear to be read through the derived expand type,
    // which is presumably why `unused` is allowed on them.
    #[allow(unused)]
    view: V,
    /// Maps virtual coordinates `C` to the wrapped view's source coordinates `S`.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    /// Carries the element type `T`, which no other field stores directly.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
57
#[cube]
impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
    VirtualViewMut<T, C, S, V>
{
    /// Wraps the writable `view` so that it is addressed through `layout`.
    ///
    /// `layout` maps virtual coordinates `C` to the view's source coordinates `S`.
    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
        // NOTE(review): the full type path is spelled out instead of `Self`; presumably the
        // `#[cube]` expansion relies on it — confirm before simplifying.
        VirtualViewMut::<T, C, S, V> {
            view,
            layout,
            _ty: PhantomData,
        }
    }
}
70
71impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
72 VirtualViewMutExpand<T, C, S, V>
73{
74 pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
75 VirtualViewMutExpand::<T, C, S, V> {
76 view,
77 layout,
78 _ty: PhantomData,
79 }
80 }
81}
82
/// Implements the read-side traits (`Lined`/`LinedExpand` and
/// `ViewOperations`/`ViewOperationsExpand`, both in virtual coordinates `C`) for a virtual view.
///
/// * `$ty` - the virtual view struct.
/// * `$expand` - its expand type, as produced by `#[derive(CubeType)]`.
/// * `$trait` - the bound the wrapped view `V` must satisfy in source coordinates `S`.
///
/// Every operation maps positions from `C` to `S` through `self.layout` and then delegates to
/// `self.view`. The checked, masked and bounds-query variants treat a position as valid only
/// when both the layout and the wrapped view accept it.
macro_rules! impl_virtual_read {
    ($ty: ident, $expand: ident, $trait: ident) => {
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> Lined for $ty<T, C, S, V> where
            V: $trait<T, S>
        {
        }
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> LinedExpand
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            /// Reports the wrapped view's line size; the coordinate mapping does not alter it.
            fn line_size(&self) -> u32 {
                self.view.line_size()
            }
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperations<T, C>
            for $ty<T, C, S, V>
        where
            V: $trait<T, S>,
        {
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsExpand<T, C>
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            /// Maps `pos` through the layout (no layout bounds check) and reads from the
            /// wrapped view.
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_method(scope, pos)
            }

            /// Returns zero when the layout rejects `pos`. Otherwise returns the wrapped view's
            /// checked read at the mapped position.
            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let zero = T::__expand_cast_from(scope, 0.into());
                let value = self.view.__expand_read_checked_method(scope, read_pos);
                select::expand::<T>(scope, in_bounds, value, zero)
            }

            /// Returns `mask_value` when either the layout or the wrapped view rejects `pos`.
            /// Otherwise returns the value read at the mapped position.
            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                // Mask in the wrapped view as well. A position the layout accepts can still lie
                // outside the source view, and that case must also produce `mask_value`. A plain
                // checked read would produce the view's own out-of-bounds fill value instead.
                // Masking in both places keeps this consistent with `is_in_bounds` below.
                let value =
                    self.view
                        .__expand_read_masked_method(scope, read_pos, mask_value.clone());
                select::expand::<T>(scope, in_bounds, value, mask_value)
            }

            /// Maps `pos` through the layout and performs an unchecked read on the wrapped view.
            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_unchecked_method(scope, pos)
            }

            /// Maps both `pos` and `end` through the layout and slices the wrapped view between
            /// the mapped positions.
            // NOTE(review): `end` uses the same unchecked mapping as `pos`; confirm this fits
            // the end-bound convention of `to_linear_slice` for layouts that are not linear.
            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                end: <C>::ExpandType,
            ) -> SliceExpand<T, ReadOnly> {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                let end = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, end);
                self.view.__expand_to_linear_slice_method(scope, pos, end)
            }

            /// The virtual shape is the layout's shape, not the wrapped view's.
            fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
                self.layout.clone().__expand_shape_method(scope)
            }

            /// `pos` is in bounds only if the layout accepts it AND the wrapped view accepts the
            /// mapped position.
            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: C::ExpandType,
            ) -> ExpandElementTyped<bool> {
                let (pos, in_bounds_layout) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let in_bounds_view = self.view.__expand_is_in_bounds_method(scope, pos);
                and::expand(scope, in_bounds_layout, in_bounds_view)
            }

            /// Maps `pos` through the layout and forwards the tensor-map load to the wrapped
            /// view.
            fn __expand_tensor_map_load_method(
                &self,
                scope: &mut Scope,
                barrier: BarrierExpand,
                shared_memory: SliceExpand<T, ReadWrite>,
                pos: C::ExpandType,
            ) {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view
                    .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos);
            }
        }
    };
}
214
// Both virtual view flavours get the same read-side operations. The mutable flavour's write
// operations are implemented separately, on `VirtualViewMutExpand`.
impl_virtual_read!(VirtualView, VirtualViewExpand, ViewOperations);
impl_virtual_read!(VirtualViewMut, VirtualViewMutExpand, ViewOperationsMut);
217
218impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMut<T, C>
219 for VirtualViewMut<T, C, S, V>
220where
221 V: ViewOperationsMut<T, S>,
222{
223}
224
225impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMutExpand<T, C>
226 for VirtualViewMutExpand<T, C, S, V>
227where
228 V: ViewOperationsMut<T, S>,
229{
230 fn __expand_write_method(
231 &self,
232 scope: &mut Scope,
233 pos: <C>::ExpandType,
234 value: <T>::ExpandType,
235 ) {
236 let pos = self
237 .layout
238 .clone()
239 .__expand_to_source_pos_method(scope, pos);
240 self.view.__expand_write_method(scope, pos, value);
241 }
242
243 fn __expand_write_checked_method(
244 &self,
245 scope: &mut Scope,
246 pos: <C>::ExpandType,
247 value: <T>::ExpandType,
248 ) {
249 let (pos, in_bounds) = self
250 .layout
251 .clone()
252 .__expand_to_source_pos_checked_method(scope, pos);
253 if_expand(scope, in_bounds.into(), |scope| {
254 self.view.__expand_write_checked_method(scope, pos, value);
255 });
256 }
257
258 fn __expand_to_linear_slice_mut_method(
259 &self,
260 scope: &mut Scope,
261 pos: <C>::ExpandType,
262 end: <C>::ExpandType,
263 ) -> SliceExpand<T, ReadWrite> {
264 let pos = self
265 .layout
266 .clone()
267 .__expand_to_source_pos_method(scope, pos);
268 let end = self
269 .layout
270 .clone()
271 .__expand_to_source_pos_method(scope, end);
272 self.view
273 .__expand_to_linear_slice_mut_method(scope, pos, end)
274 }
275
276 fn __expand_tensor_map_store_method(
277 &self,
278 scope: &mut Scope,
279 shared_memory: SliceExpand<T, ReadOnly>,
280 pos: C::ExpandType,
281 ) {
282 let pos = self
283 .layout
284 .clone()
285 .__expand_to_source_pos_method(scope, pos);
286 self.view
287 .__expand_tensor_map_store_method(scope, shared_memory, pos);
288 }
289}