1use std::marker::PhantomData;
2
3use super::*;
4use crate::{
5 CubeOptionExpand,
6 tensor::layout::{Coordinates, VirtualLayout, VirtualLayoutExpand},
7};
8use cubecl::prelude::*;
9use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
10
/// Read-only view that exposes an inner view `V` (addressed by source coordinates `S`)
/// through virtual coordinates `C`.
///
/// Every position of type `C` is translated into an `S` position by `layout` before the
/// inner `view` is accessed. `T` is the element type read through the view.
#[derive(CubeType)]
pub struct VirtualView<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>> {
    // The trait impls in this file only read this field on the generated expand type,
    // so the field of the plain struct is never read directly.
    #[allow(unused)]
    view: V,
    // Same as `view`: only accessed through the generated expand type.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    // Type-level marker for the element type `T`; marked comptime, and the expand
    // constructor sets it to a plain `PhantomData`.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
20
#[cube]
impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
    VirtualView<T, C, S, V>
{
    /// Wraps `view` so that it is addressed through the virtual coordinates of `layout`.
    ///
    /// # Arguments
    /// * `view` - inner view, addressed by source coordinates `S`.
    /// * `layout` - maps virtual coordinates `C` to source coordinates `S`.
    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
        // NOTE(review): the fully qualified struct path (instead of `Self`) is presumably
        // needed by the `#[cube]` expansion of struct literals — confirm before changing.
        VirtualView::<T, C, S, V> {
            view,
            layout,
            _ty: PhantomData,
        }
    }
}
33
34impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
35 VirtualViewExpand<T, C, S, V>
36{
37 pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
38 VirtualViewExpand::<T, C, S, V> {
39 view,
40 layout,
41 _ty: PhantomData,
42 }
43 }
44}
45
/// Read/write view that exposes an inner mutable view `V` (addressed by source
/// coordinates `S`) through virtual coordinates `C`.
///
/// Reads and writes at a position of type `C` are translated into an `S` position by
/// `layout` before the inner `view` is accessed. `T` is the element type.
#[derive(CubeType)]
pub struct VirtualViewMut<
    T: CubePrimitive,
    C: Coordinates,
    S: Coordinates,
    V: ViewOperationsMut<T, S>,
> {
    // The trait impls in this file only read this field on the generated expand type,
    // so the field of the plain struct is never read directly.
    #[allow(unused)]
    view: V,
    // Same as `view`: only accessed through the generated expand type.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    // Type-level marker for the element type `T`; marked comptime, and the expand
    // constructor sets it to a plain `PhantomData`.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
60
#[cube]
impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
    VirtualViewMut<T, C, S, V>
{
    /// Wraps the mutable `view` so that it is addressed through the virtual coordinates
    /// of `layout`.
    ///
    /// # Arguments
    /// * `view` - inner mutable view, addressed by source coordinates `S`.
    /// * `layout` - maps virtual coordinates `C` to source coordinates `S`.
    pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
        // NOTE(review): the fully qualified struct path (instead of `Self`) is presumably
        // needed by the `#[cube]` expansion of struct literals — confirm before changing.
        VirtualViewMut::<T, C, S, V> {
            view,
            layout,
            _ty: PhantomData,
        }
    }
}
73
74impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
75 VirtualViewMutExpand<T, C, S, V>
76{
77 pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
78 VirtualViewMutExpand::<T, C, S, V> {
79 view,
80 layout,
81 _ty: PhantomData,
82 }
83 }
84}
85
/// Implements the read side of a virtual view wrapper.
///
/// For a wrapper `$ty<T, C, S, V>` with expand type `$expand<T, C, S, V>`, where the inner
/// view `V` implements `$trait<T, S>`, this generates:
///
/// - `Lined` / `LinedExpand`, taking the line size from the inner view;
/// - `ViewOperations<T, C>` / `ViewOperationsExpand<T, C>`. Every position of type `C` is
///   translated into a source position of type `S` through the wrapper's layout before the
///   inner view is accessed.
///
/// Arguments: `$ty` is the wrapper type, `$expand` its expand type, and `$trait` the view
/// trait the inner view `V` must implement.
macro_rules! impl_virtual_read {
    ($ty: ident, $expand: ident, $trait: ident) => {
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> Lined for $ty<T, C, S, V> where
            V: $trait<T, S>
        {
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> LinedExpand
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            // The line size comes straight from the inner view; the layout is not consulted.
            fn line_size(&self) -> u32 {
                self.view.line_size()
            }
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperations<T, C>
            for $ty<T, C, S, V>
        where
            V: $trait<T, S>,
        {
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsExpand<T, C>
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                // The inner view's expand methods take `&self`; no clone of the view needed.
                self.view.__expand_read_method(scope, pos)
            }

            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                // The layout reports whether `pos` maps to a valid source position. The value
                // is read through the inner view's checked path and replaced by zero when the
                // layout check fails.
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let zero = T::__expand_cast_from(scope, 0.into());
                let value = self.view.__expand_read_checked_method(scope, read_pos);
                select::expand::<T>(scope, in_bounds, value, zero)
            }

            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                // Same as the checked read, but the caller supplies the fallback value.
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let value = self.view.__expand_read_checked_method(scope, read_pos);
                select::expand::<T>(scope, in_bounds, value, mask_value)
            }

            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_unchecked_method(scope, pos)
            }

            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                end: <C>::ExpandType,
            ) -> SliceExpand<T, ReadOnly> {
                // NOTE(review): `pos` and `end` are translated independently with the same
                // mapping; this assumes the layout maps the virtual range onto a contiguous
                // source range — confirm.
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                let end = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, end);
                self.view.__expand_to_linear_slice_method(scope, pos, end)
            }

            fn __expand_as_tensor_map_method(
                &self,
                scope: &mut Scope,
            ) -> CubeOptionExpand<TensorMap<T>> {
                // Forwarded as-is. Positions given to `tensor_map_load` are translated into
                // source coordinates before reaching the inner view.
                self.view.__expand_as_tensor_map_method(scope)
            }

            fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
                // The visible shape is the layout's (virtual) shape, not the inner view's.
                self.layout.clone().__expand_shape_method(scope)
            }

            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: C::ExpandType,
            ) -> ExpandElementTyped<bool> {
                // In bounds only if both the layout mapping and the inner view accept it.
                let (pos, in_bounds_layout) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let in_bounds_view = self.view.__expand_is_in_bounds_method(scope, pos);
                and::expand(scope, in_bounds_layout, in_bounds_view)
            }

            fn __expand_tensor_map_load_method(
                &self,
                scope: &mut Scope,
                barrier: BarrierExpand,
                shared_memory: SliceExpand<T, ReadWrite>,
                pos: C::ExpandType,
            ) {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view
                    .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos);
            }
        }
    };
}
224
225impl_virtual_read!(VirtualView, VirtualViewExpand, ViewOperations);
226impl_virtual_read!(VirtualViewMut, VirtualViewMutExpand, ViewOperationsMut);
227
228impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMut<T, C>
229 for VirtualViewMut<T, C, S, V>
230where
231 V: ViewOperationsMut<T, S>,
232{
233}
234
235impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMutExpand<T, C>
236 for VirtualViewMutExpand<T, C, S, V>
237where
238 V: ViewOperationsMut<T, S>,
239{
240 fn __expand_write_method(
241 &self,
242 scope: &mut Scope,
243 pos: <C>::ExpandType,
244 value: <T>::ExpandType,
245 ) {
246 let pos = self
247 .layout
248 .clone()
249 .__expand_to_source_pos_method(scope, pos);
250 self.view.__expand_write_method(scope, pos, value);
251 }
252
253 fn __expand_write_checked_method(
254 &self,
255 scope: &mut Scope,
256 pos: <C>::ExpandType,
257 value: <T>::ExpandType,
258 ) {
259 let (pos, in_bounds) = self
260 .layout
261 .clone()
262 .__expand_to_source_pos_checked_method(scope, pos);
263 if_expand(scope, in_bounds.into(), |scope| {
264 self.view.__expand_write_checked_method(scope, pos, value);
265 });
266 }
267
268 fn __expand_to_linear_slice_mut_method(
269 &self,
270 scope: &mut Scope,
271 pos: <C>::ExpandType,
272 end: <C>::ExpandType,
273 ) -> SliceExpand<T, ReadWrite> {
274 let pos = self
275 .layout
276 .clone()
277 .__expand_to_source_pos_method(scope, pos);
278 let end = self
279 .layout
280 .clone()
281 .__expand_to_source_pos_method(scope, end);
282 self.view
283 .__expand_to_linear_slice_mut_method(scope, pos, end)
284 }
285
286 fn __expand_tensor_map_store_method(
287 &self,
288 scope: &mut Scope,
289 shared_memory: SliceExpand<T, ReadOnly>,
290 pos: C::ExpandType,
291 ) {
292 let pos = self
293 .layout
294 .clone()
295 .__expand_to_source_pos_method(scope, pos);
296 self.view
297 .__expand_tensor_map_store_method(scope, shared_memory, pos);
298 }
299}