1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5 self as cubecl,
6 prelude::barrier::{Barrier, BarrierExpand},
7 unexpanded,
8};
9
10use crate::tensor::{
11 ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, VirtualView,
12 VirtualViewMut,
13 layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
14};
15
16#[derive(Clone)]
20pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
21 _layout: PhantomData<C>,
22 _ty: PhantomData<(E, IO)>,
23}
24
25unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
27unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
28impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
29
30#[derive(Clone)]
31pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
32 Read(Arc<dyn ViewOperationsExpand<E, C>>),
33 ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
34}
35
36impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
37 pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
39 match self {
40 ViewType::Read(list) => &**list,
41 ViewType::ReadWrite(list) => &**list,
42 }
43 }
44
45 pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
47 match self {
48 ViewType::Read(_) => panic!("Can't write to readonly list"),
49 ViewType::ReadWrite(list) => &**list,
50 }
51 }
52}
53
54#[derive(Clone)]
56pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
57 pub(super) inner: ViewType<E, C>,
58 pub(super) _io: PhantomData<IO>,
59}
60
61impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
62 type ExpandType = ViewExpand<E, C, IO>;
63}
64
65impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
66 fn into_mut(self, _scope: &mut Scope) -> Self {
67 self
68 }
69}
70
71impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
72
73impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
74 #[allow(unused_variables)]
77 pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
78 view: &V,
79 layout: impl Into<VirtualLayout<C, S>>,
80 ) -> Self {
81 View {
82 _layout: PhantomData,
83 _ty: PhantomData,
84 }
85 }
86
87 pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
89 scope: &mut Scope,
90 view: V::ExpandType,
91 layout: VirtualLayoutExpand<C, S>,
92 ) -> ViewExpand<E, C, ReadOnly> {
93 ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
94 }
95}
96
97impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
98 pub fn view<T: Coordinates>(
99 &self,
100 _layout: impl Into<VirtualLayout<T, C>>,
101 ) -> View<E, T, ReadOnly> {
102 unexpanded!()
103 }
104
105 pub fn __expand_view<T: Coordinates + 'static>(
106 scope: &mut Scope,
107 this: ViewExpand<E, C, IO>,
108 layout: VirtualLayoutExpand<T, C>,
109 ) -> ViewExpand<E, T, ReadOnly> {
110 this.__expand_view_method(scope, layout)
111 }
112}
113
114impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
115 pub fn __expand_view_method<T: Coordinates + 'static>(
116 self,
117 scope: &mut Scope,
118 layout: VirtualLayoutExpand<T, C>,
119 ) -> ViewExpand<E, T, ReadOnly> {
120 View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
121 }
122
123 pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
124 ViewExpand {
125 inner: ViewType::Read(Arc::new(view)),
126 _io: PhantomData,
127 }
128 }
129
130 pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
131 ViewExpand {
132 inner: ViewType::ReadWrite(Arc::new(view)),
133 _io: PhantomData,
134 }
135 }
136}
137
138impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
139 pub fn view_mut<T: Coordinates>(
140 &self,
141 _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
142 ) -> View<E, T, ReadWrite> {
143 unexpanded!()
144 }
145
146 pub fn __expand_view_mut<T: Coordinates + 'static>(
147 scope: &mut Scope,
148 this: ViewExpand<E, C, ReadWrite>,
149 layout: VirtualLayoutExpand<T, C>,
150 ) -> ViewExpand<E, T, ReadWrite> {
151 this.__expand_view_mut_method(scope, layout)
152 }
153}
154
155impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
156 pub fn __expand_view_mut_method<T: Coordinates + 'static>(
157 self,
158 scope: &mut Scope,
159 layout: VirtualLayoutExpand<T, C>,
160 ) -> ViewExpand<E, T, ReadWrite> {
161 View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
162 }
163}
164
165impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
166 pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
169 _view: &mut V,
170 _layout: impl Into<VirtualLayout<C, S>>,
171 ) -> View<E, C, ReadWrite> {
172 View {
173 _ty: PhantomData,
174 _layout: PhantomData,
175 }
176 }
177
178 pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
180 scope: &mut Scope,
181 view: V::ExpandType,
182 layout: VirtualLayoutExpand<C, S>,
183 ) -> ViewExpand<E, C, ReadWrite> {
184 ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
185 scope, view, layout,
186 ))
187 }
188}
189
190impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
191 pub fn shape(&self) -> C {
193 unexpanded!()
194 }
195
196 pub fn is_in_bounds(&self, _pos: C) -> bool {
198 unexpanded!()
199 }
200
201 pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
202 this.__expand_shape_method(scope)
203 }
204
205 pub fn __expand_is_in_bounds(
206 scope: &mut Scope,
207 this: ViewExpand<E, C, IO>,
208 pos: C::ExpandType,
209 ) -> ExpandElementTyped<bool> {
210 this.__expand_is_in_bounds_method(scope, pos)
211 }
212}
213
214impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
215 pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
216 self.inner.read().__expand_shape_method(scope)
217 }
218
219 pub fn __expand_is_in_bounds_method(
220 &self,
221 scope: &mut Scope,
222 pos: C::ExpandType,
223 ) -> ExpandElementTyped<bool> {
224 self.inner.read().__expand_is_in_bounds_method(scope, pos)
225 }
226}
227
228#[allow(unused_variables)]
229impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
230 pub fn read(&self, pos: C) -> E {
232 unexpanded!()
233 }
234
235 pub fn read_unchecked(&self, pos: C) -> E {
238 unexpanded!()
239 }
240
241 pub fn read_checked(&self, pos: C) -> E {
243 unexpanded!()
244 }
245
246 pub fn read_masked(&self, pos: C, mask_value: E) -> E {
248 unexpanded!()
249 }
250
251 pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
257 unexpanded!()
258 }
259
260 pub fn line_size(&self) -> u32 {
261 unexpanded!()
262 }
263}
264
265impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
266 pub fn __expand_read_method(
268 self,
269 scope: &mut Scope,
270 pos: C::ExpandType,
271 ) -> ExpandElementTyped<E> {
272 self.inner.read().__expand_read_method(scope, pos)
273 }
274
275 pub fn __expand_read_unchecked_method(
277 self,
278 scope: &mut Scope,
279 pos: C::ExpandType,
280 ) -> ExpandElementTyped<E> {
281 self.inner.read().__expand_read_unchecked_method(scope, pos)
282 }
283
284 pub fn __expand_read_checked_method(
286 self,
287 scope: &mut Scope,
288 pos: C::ExpandType,
289 ) -> ExpandElementTyped<E> {
290 self.inner.read().__expand_read_checked_method(scope, pos)
291 }
292
293 pub fn __expand_read_masked_method(
295 self,
296 scope: &mut Scope,
297 pos: C::ExpandType,
298 mask_value: E::ExpandType,
299 ) -> ExpandElementTyped<E> {
300 self.inner
301 .read()
302 .__expand_read_masked_method(scope, pos, mask_value)
303 }
304
305 pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
307 self.inner.read().line_size()
308 }
309
310 pub fn line_size(&self) -> u32 {
311 self.inner.read().line_size()
312 }
313
314 pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
315 let shape = self.inner.read().__expand_shape_method(scope);
316 let origin = C::__expand_from_int(scope, shape.clone(), 0);
317 let one = C::__expand_from_int(scope, shape.clone(), 1);
319 let shape = C::__expand_max(scope, shape, one.clone());
320 let end = C::__expand_sub(scope, shape, one);
321 self.inner
322 .read()
323 .__expand_to_linear_slice_method(scope, origin, end)
324 }
325
326 pub(super) fn __expand_to_linear_slice_inner_method(
327 self,
328 scope: &mut Scope,
329 pos: C::ExpandType,
330 end: C::ExpandType,
331 ) -> SliceExpand<E, ReadOnly> {
332 self.inner
333 .read()
334 .__expand_to_linear_slice_method(scope, pos, end)
335 }
336}
337
338impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
339 pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
343 unexpanded!()
344 }
345
346 pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
352 unexpanded!()
353 }
354
355 pub fn __expand_slice(
356 scope: &mut Scope,
357 this: ViewExpand<E, C, IO>,
358 pos: C::ExpandType,
359 size: C::ExpandType,
360 ) -> ViewExpand<E, C, ReadOnly> {
361 this.__expand_slice_method(scope, pos, size)
362 }
363
364 pub fn __expand_slice_unchecked(
365 scope: &mut Scope,
366 this: ViewExpand<E, C, IO>,
367 pos: C::ExpandType,
368 size: C::ExpandType,
369 ) -> ViewExpand<E, C, ReadOnly> {
370 this.__expand_slice_unchecked_method(scope, pos, size)
371 }
372}
373
// NOTE(review): this `#[cube]` impl has no items. It looks like a leftover and presumably
// expands to nothing — confirm what the macro generates before removing it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
376
377impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
378 pub fn __expand_slice_method(
379 &self,
380 scope: &mut Scope,
381 pos: C::ExpandType,
382 size: C::ExpandType,
383 ) -> ViewExpand<E, C, ReadOnly> {
384 self.slice(scope, pos, size, true)
385 }
386
387 pub fn __expand_slice_unchecked_method(
388 &self,
389 scope: &mut Scope,
390 pos: C::ExpandType,
391 size: C::ExpandType,
392 ) -> ViewExpand<E, C, ReadOnly> {
393 self.slice(scope, pos, size, false)
394 }
395
396 fn slice(
397 &self,
398 scope: &mut Scope,
399 pos: C::ExpandType,
400 size: C::ExpandType,
401 checked: bool,
402 ) -> ViewExpand<E, C, ReadOnly> {
403 let shape = self.__expand_shape_method(scope);
404 let pos = C::__expand_min(scope, pos, shape.clone());
405 let max_size = C::__expand_sub(scope, shape, pos.clone());
406 let size = C::__expand_min(scope, size, max_size);
407 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
408 self.clone().__expand_view_method(scope, layout.into())
409 }
410}
411
412#[allow(unused_variables)]
413impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
414 pub fn write(&self, pos: C, value: E) {
416 unexpanded!()
417 }
418
419 pub fn write_checked(&self, pos: C, value: E) {
421 unexpanded!()
422 }
423
424 pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
430 unexpanded!()
431 }
432}
433
434impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
435 pub fn __expand_write_method(
437 self,
438 scope: &mut Scope,
439 pos: C::ExpandType,
440 value: ExpandElementTyped<E>,
441 ) {
442 self.inner.write().__expand_write_method(scope, pos, value);
443 }
444
445 pub fn __expand_write_checked_method(
447 self,
448 scope: &mut Scope,
449 pos: C::ExpandType,
450 value: ExpandElementTyped<E>,
451 ) {
452 self.inner
453 .write()
454 .__expand_write_checked_method(scope, pos, value);
455 }
456
457 pub fn __expand_to_linear_slice_mut_method(
458 self,
459 scope: &mut Scope,
460 ) -> SliceExpand<E, ReadWrite> {
461 let shape = self.inner.read().__expand_shape_method(scope);
462 let origin = C::__expand_from_int(scope, shape.clone(), 0);
463 let one = C::__expand_from_int(scope, shape.clone(), 1);
465 let shape = C::__expand_max(scope, shape, one.clone());
466 let end = C::__expand_sub(scope, shape, one);
467 self.inner
468 .write()
469 .__expand_to_linear_slice_mut_method(scope, origin, end)
470 }
471
472 pub(super) fn __expand_to_linear_slice_mut_inner_method(
473 self,
474 scope: &mut Scope,
475 pos: C::ExpandType,
476 end: C::ExpandType,
477 ) -> SliceExpand<E, ReadWrite> {
478 self.inner
479 .write()
480 .__expand_to_linear_slice_mut_method(scope, pos, end)
481 }
482}
483
484impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
485 pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
489 unexpanded!()
490 }
491
492 pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
499 unexpanded!()
500 }
501
502 pub fn __expand_slice_mut(
503 scope: &mut Scope,
504 this: ViewExpand<E, C, ReadWrite>,
505 pos: C::ExpandType,
506 size: C::ExpandType,
507 ) -> ViewExpand<E, C, ReadWrite> {
508 this.__expand_slice_mut_method(scope, pos, size)
509 }
510
511 pub fn __expand_slice_mut_unchecked(
512 scope: &mut Scope,
513 this: ViewExpand<E, C, ReadWrite>,
514 pos: C::ExpandType,
515 size: C::ExpandType,
516 ) -> ViewExpand<E, C, ReadWrite> {
517 this.__expand_slice_mut_unchecked_method(scope, pos, size)
518 }
519}
520
// NOTE(review): this `#[cube]` impl has no items. It looks like a leftover and presumably
// expands to nothing — confirm what the macro generates before removing it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
523
524impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
525 pub fn __expand_slice_mut_method(
526 &self,
527 scope: &mut Scope,
528 pos: C::ExpandType,
529 size: C::ExpandType,
530 ) -> ViewExpand<E, C, ReadWrite> {
531 self.slice_mut(scope, pos, size, true)
532 }
533
534 pub fn __expand_slice_mut_unchecked_method(
535 &self,
536 scope: &mut Scope,
537 pos: C::ExpandType,
538 size: C::ExpandType,
539 ) -> ViewExpand<E, C, ReadWrite> {
540 self.slice_mut(scope, pos, size, false)
541 }
542
543 fn slice_mut(
544 &self,
545 scope: &mut Scope,
546 pos: C::ExpandType,
547 size: C::ExpandType,
548 checked: bool,
549 ) -> ViewExpand<E, C, ReadWrite> {
550 let shape = self.__expand_shape_method(scope);
551 let pos = C::__expand_min(scope, pos, shape.clone());
552 let max_size = C::__expand_sub(scope, shape, pos.clone());
553 let size = C::__expand_min(scope, size, max_size);
554 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
555 self.clone().__expand_view_mut_method(scope, layout.into())
556 }
557}
558
559impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
560 pub fn tensor_map_load(
563 &self,
564 _barrier: &Barrier,
565 _shared_memory: &mut Slice<E, ReadWrite>,
566 _pos: C,
567 ) -> View<E, C, ReadWrite> {
568 unexpanded!()
569 }
570}
571
572impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
573 pub fn __expand_tensor_map_load_method(
574 self,
575 scope: &mut Scope,
576 barrier: BarrierExpand,
577 shared_memory: SliceExpand<E, ReadWrite>,
578 pos: C::ExpandType,
579 ) {
580 self.inner
581 .read()
582 .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
583 }
584}
585
586impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
587 pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
590 unexpanded!()
591 }
592}
593
594impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
595 pub fn __expand_tensor_map_store_method(
596 self,
597 scope: &mut Scope,
598 shared_memory: SliceExpand<E, ReadOnly>,
599 pos: C::ExpandType,
600 ) {
601 self.inner
602 .write()
603 .__expand_tensor_map_store_method(scope, shared_memory, pos)
604 }
605}