1use std::marker::PhantomData;
2
3use super::*;
4use crate::{
5 CubeOptionExpand,
6 tensor::layout::{Coordinates, VirtualLayout, VirtualLayoutExpand},
7};
8use cubecl::prelude::*;
9use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
10
/// A view whose positions are expressed in virtual coordinates `C`.
///
/// Each position is mapped into the source coordinate space `S` by `layout`,
/// and the resulting access is forwarded to the wrapped `view`. Only the
/// read-side view traits are implemented for this type in this file.
#[derive(CubeType)]
pub struct VirtualView<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>> {
    // Wrapped view, addressed in source coordinates `S`. In this file it is only
    // read through the expand type, so the runtime struct never touches it.
    #[allow(unused)]
    view: V,
    // Translates virtual coordinates `C` into source coordinates `S`.
    // Same `unused` rationale as `view`.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    // Element-type marker only; declared comptime so it carries no runtime value.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
20
21#[cube]
22impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
23 VirtualView<T, C, S, V>
24{
25 pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
26 VirtualView::<T, C, S, V> {
27 view,
28 layout,
29 _ty: PhantomData,
30 }
31 }
32}
33
34impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperations<T, S>>
35 VirtualViewExpand<T, C, S, V>
36{
37 pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
38 VirtualViewExpand::<T, C, S, V> {
39 view,
40 layout,
41 _ty: PhantomData,
42 }
43 }
44}
45
/// A writable view whose positions are expressed in virtual coordinates `C`.
///
/// Each position is mapped into the source coordinate space `S` by `layout`,
/// and the resulting read or write is forwarded to the wrapped `view`.
#[derive(CubeType)]
pub struct VirtualViewMut<
    T: CubePrimitive,
    C: Coordinates,
    S: Coordinates,
    V: ViewOperationsMut<T, S>,
> {
    // Wrapped writable view, addressed in source coordinates `S`. In this file it
    // is only accessed through the expand type, so the runtime struct never reads it.
    #[allow(unused)]
    view: V,
    // Translates virtual coordinates `C` into source coordinates `S`.
    // Same `unused` rationale as `view`.
    #[allow(unused)]
    layout: VirtualLayout<C, S>,
    // Element-type marker only; declared comptime so it carries no runtime value.
    #[cube(comptime)]
    _ty: PhantomData<T>,
}
60
61#[cube]
62impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
63 VirtualViewMut<T, C, S, V>
64{
65 pub fn new(view: V, layout: VirtualLayout<C, S>) -> Self {
66 VirtualViewMut::<T, C, S, V> {
67 view,
68 layout,
69 _ty: PhantomData,
70 }
71 }
72}
73
74impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V: ViewOperationsMut<T, S>>
75 VirtualViewMutExpand<T, C, S, V>
76{
77 pub fn new(view: V::ExpandType, layout: VirtualLayoutExpand<C, S>) -> Self {
78 VirtualViewMutExpand::<T, C, S, V> {
79 view,
80 layout,
81 _ty: PhantomData,
82 }
83 }
84}
85
/// Implements the read-side view traits (`Lined`/`LinedExpand` and
/// `ViewOperations`/`ViewOperationsExpand`) for a virtual view type.
///
/// * `$ty` – the runtime view type (e.g. `VirtualView`).
/// * `$expand` – its matching expand type (e.g. `VirtualViewExpand`).
/// * `$trait` – the bound the wrapped view `V` must satisfy in source coordinates.
///
/// Every position arrives in virtual coordinates `C` and is translated into the
/// source coordinates `S` by the virtual layout before being forwarded to the
/// wrapped view. Shape and bounds queries are answered by the layout itself.
macro_rules! impl_virtual_read {
    ($ty: ident, $expand: ident, $trait: ident) => {
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> Lined for $ty<T, C, S, V> where
            V: $trait<T, S>
        {
        }
        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> LinedExpand
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            // The virtual layout only remaps positions; the line size is that of the source.
            fn line_size(&self) -> u32 {
                self.view.line_size()
            }
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperations<T, C>
            for $ty<T, C, S, V>
        where
            V: $trait<T, S>,
        {
        }

        impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsExpand<T, C>
            for $expand<T, C, S, V>
        where
            V: $trait<T, S>,
        {
            fn __expand_read_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                // The expand-side methods take `&self`, so no clone of the view is needed
                // (consistent with the other forwarding methods below).
                self.view.__expand_read_method(scope, pos)
            }

            fn __expand_read_checked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                let zero = T::__expand_cast_from(scope, 0.into());
                // The source read is itself checked; the virtual bounds check then
                // replaces the value with zero when `pos` is outside the virtual layout.
                let value = self.view.__expand_read_checked_method(scope, read_pos);
                select::expand::<T>(scope, in_bounds, value, zero)
            }

            fn __expand_read_masked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                mask_value: <T>::ExpandType,
            ) -> <T>::ExpandType {
                let (read_pos, in_bounds) = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_checked_method(scope, pos);
                // Forward the mask to the source view too. Otherwise a position that is valid
                // in the virtual layout but rejected by the source would produce the
                // source's checked-read fallback instead of `mask_value`.
                let value =
                    self.view
                        .__expand_read_masked_method(scope, read_pos, mask_value.clone());
                select::expand::<T>(scope, in_bounds, value, mask_value)
            }

            fn __expand_read_unchecked_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
            ) -> <T>::ExpandType {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view.__expand_read_unchecked_method(scope, pos)
            }

            fn __expand_to_linear_slice_method(
                &self,
                scope: &mut Scope,
                pos: <C>::ExpandType,
                end: <C>::ExpandType,
            ) -> SliceExpand<T, ReadOnly> {
                // Both endpoints are translated independently through the layout.
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                let end = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, end);
                self.view.__expand_to_linear_slice_method(scope, pos, end)
            }

            fn __expand_as_tensor_map_method(
                &self,
                scope: &mut Scope,
            ) -> CubeOptionExpand<TensorMap<T>> {
                self.view.__expand_as_tensor_map_method(scope)
            }

            // Shape and bounds are defined by the virtual layout, not the source view.
            fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
                self.layout.clone().__expand_shape_method(scope)
            }

            fn __expand_is_in_bounds_method(
                &self,
                scope: &mut Scope,
                pos: C::ExpandType,
            ) -> ExpandElementTyped<bool> {
                self.layout.clone().__expand_is_in_bounds_method(scope, pos)
            }

            fn __expand_tensor_map_load_method(
                &self,
                scope: &mut Scope,
                barrier: BarrierExpand,
                shared_memory: SliceExpand<T, ReadWrite>,
                pos: C::ExpandType,
            ) {
                let pos = self
                    .layout
                    .clone()
                    .__expand_to_source_pos_method(scope, pos);
                self.view
                    .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos);
            }
        }
    };
}
219
220impl_virtual_read!(VirtualView, VirtualViewExpand, ViewOperations);
221impl_virtual_read!(VirtualViewMut, VirtualViewMutExpand, ViewOperationsMut);
222
223impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMut<T, C>
224 for VirtualViewMut<T, C, S, V>
225where
226 V: ViewOperationsMut<T, S>,
227{
228}
229
230impl<T: CubePrimitive, C: Coordinates, S: Coordinates, V> ViewOperationsMutExpand<T, C>
231 for VirtualViewMutExpand<T, C, S, V>
232where
233 V: ViewOperationsMut<T, S>,
234{
235 fn __expand_write_method(
236 &self,
237 scope: &mut Scope,
238 pos: <C>::ExpandType,
239 value: <T>::ExpandType,
240 ) {
241 let pos = self
242 .layout
243 .clone()
244 .__expand_to_source_pos_method(scope, pos);
245 self.view.__expand_write_method(scope, pos, value);
246 }
247
248 fn __expand_write_checked_method(
249 &self,
250 scope: &mut Scope,
251 pos: <C>::ExpandType,
252 value: <T>::ExpandType,
253 ) {
254 let (pos, in_bounds) = self
255 .layout
256 .clone()
257 .__expand_to_source_pos_checked_method(scope, pos);
258 if_expand(scope, in_bounds.into(), |scope| {
259 self.view.__expand_write_checked_method(scope, pos, value);
260 });
261 }
262
263 fn __expand_to_linear_slice_mut_method(
264 &self,
265 scope: &mut Scope,
266 pos: <C>::ExpandType,
267 end: <C>::ExpandType,
268 ) -> SliceExpand<T, ReadWrite> {
269 let pos = self
270 .layout
271 .clone()
272 .__expand_to_source_pos_method(scope, pos);
273 let end = self
274 .layout
275 .clone()
276 .__expand_to_source_pos_method(scope, end);
277 self.view
278 .__expand_to_linear_slice_mut_method(scope, pos, end)
279 }
280
281 fn __expand_tensor_map_store_method(
282 &self,
283 scope: &mut Scope,
284 shared_memory: SliceExpand<T, ReadOnly>,
285 pos: C::ExpandType,
286 ) {
287 let pos = self
288 .layout
289 .clone()
290 .__expand_to_source_pos_method(scope, pos);
291 self.view
292 .__expand_tensor_map_store_method(scope, shared_memory, pos);
293 }
294}