1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5 self as cubecl,
6 prelude::barrier::{Barrier, BarrierExpand},
7 unexpanded,
8};
9
10use crate::{
11 CubeOption, CubeOptionExpand,
12 tensor::{
13 ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
14 VirtualView, VirtualViewMut,
15 layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
16 },
17};
18
/// A view over some backing storage, addressed by coordinates of type `C` and yielding elements
/// of type `E`. `IO` marks the access mode (`ReadOnly` by default, or `ReadWrite`).
///
/// The struct holds only `PhantomData` markers. Most of its methods are `unexpanded!()` stubs
/// whose real implementation is the matching `__expand_*` function operating on [`ViewExpand`].
#[derive(Clone)]
pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
    // Marker for the coordinate type.
    _layout: PhantomData<C>,
    // Marker for the element type and the access mode.
    _ty: PhantomData<(E, IO)>,
}

// SAFETY: `View` only contains `PhantomData` fields, so there is no data that could be accessed
// unsynchronized from another thread.
unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
// SAFETY: same reasoning as the `Send` impl directly above — the type carries no data.
unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
// `View` is a zero-sized marker, so a bitwise copy is always valid.
impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
32
33#[derive(Clone)]
34pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
35 Read(Arc<dyn ViewOperationsExpand<E, C>>),
36 ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
37}
38
39impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
40 pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
42 match self {
43 ViewType::Read(list) => &**list,
44 ViewType::ReadWrite(list) => &**list,
45 }
46 }
47
48 pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
50 match self {
51 ViewType::Read(_) => panic!("Can't write to readonly list"),
52 ViewType::ReadWrite(list) => &**list,
53 }
54 }
55}
56
/// Expand-time representation of a [`View`] (it is `View`'s `CubeType::ExpandType`).
///
/// Holds type-erased backing storage; `IO` records the access mode at the type level only.
/// NOTE(review): `ViewExpand::new` can pair any `IO` (including `ReadWrite`) with read-only
/// storage, in which case write operations panic at expand time via `ViewType::write`.
#[derive(Clone)]
pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
    // Backing storage; the `ViewType::ReadWrite` variant is required for write operations.
    pub(super) inner: ViewType<E, C>,
    // Type-level access-mode marker.
    pub(super) _io: PhantomData<IO>,
}
63
64impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
65 type ExpandType = ViewExpand<E, C, IO>;
66}
67
68impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
69 fn into_mut(self, _scope: &mut Scope) -> Self {
70 self
71 }
72}
73
74impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
75
76impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
77 #[allow(unused_variables)]
80 pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
81 view: &V,
82 layout: impl Into<VirtualLayout<C, S>>,
83 ) -> Self {
84 View {
85 _layout: PhantomData,
86 _ty: PhantomData,
87 }
88 }
89
90 pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
92 scope: &mut Scope,
93 view: V::ExpandType,
94 layout: VirtualLayoutExpand<C, S>,
95 ) -> ViewExpand<E, C, ReadOnly> {
96 ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
97 }
98}
99
100impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
101 pub fn view<T: Coordinates>(
102 &self,
103 _layout: impl Into<VirtualLayout<T, C>>,
104 ) -> View<E, T, ReadOnly> {
105 unexpanded!()
106 }
107
108 pub fn __expand_view<T: Coordinates + 'static>(
109 scope: &mut Scope,
110 this: ViewExpand<E, C, IO>,
111 layout: VirtualLayoutExpand<T, C>,
112 ) -> ViewExpand<E, T, ReadOnly> {
113 this.__expand_view_method(scope, layout)
114 }
115}
116
117impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
118 pub fn __expand_view_method<T: Coordinates + 'static>(
119 self,
120 scope: &mut Scope,
121 layout: VirtualLayoutExpand<T, C>,
122 ) -> ViewExpand<E, T, ReadOnly> {
123 View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
124 }
125
126 pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
127 ViewExpand {
128 inner: ViewType::Read(Arc::new(view)),
129 _io: PhantomData,
130 }
131 }
132
133 pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
134 ViewExpand {
135 inner: ViewType::ReadWrite(Arc::new(view)),
136 _io: PhantomData,
137 }
138 }
139}
140
141impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
142 pub fn view_mut<T: Coordinates>(
143 &self,
144 _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
145 ) -> View<E, T, ReadWrite> {
146 unexpanded!()
147 }
148
149 pub fn __expand_view_mut<T: Coordinates + 'static>(
150 scope: &mut Scope,
151 this: ViewExpand<E, C, ReadWrite>,
152 layout: VirtualLayoutExpand<T, C>,
153 ) -> ViewExpand<E, T, ReadWrite> {
154 this.__expand_view_mut_method(scope, layout)
155 }
156}
157
158impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
159 pub fn __expand_view_mut_method<T: Coordinates + 'static>(
160 self,
161 scope: &mut Scope,
162 layout: VirtualLayoutExpand<T, C>,
163 ) -> ViewExpand<E, T, ReadWrite> {
164 View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
165 }
166}
167
168impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
169 pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
172 _view: &mut V,
173 _layout: impl Into<VirtualLayout<C, S>>,
174 ) -> View<E, C, ReadWrite> {
175 View {
176 _ty: PhantomData,
177 _layout: PhantomData,
178 }
179 }
180
181 pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
183 scope: &mut Scope,
184 view: V::ExpandType,
185 layout: VirtualLayoutExpand<C, S>,
186 ) -> ViewExpand<E, C, ReadWrite> {
187 ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
188 scope, view, layout,
189 ))
190 }
191}
192
193impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
194 pub fn shape(&self) -> C {
196 unexpanded!()
197 }
198
199 pub fn is_in_bounds(&self, _pos: C) -> bool {
201 unexpanded!()
202 }
203
204 pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
205 this.__expand_shape_method(scope)
206 }
207
208 pub fn __expand_is_in_bounds(
209 scope: &mut Scope,
210 this: ViewExpand<E, C, IO>,
211 pos: C::ExpandType,
212 ) -> ExpandElementTyped<bool> {
213 this.__expand_is_in_bounds_method(scope, pos)
214 }
215}
216
217impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
218 pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
219 self.inner.read().__expand_shape_method(scope)
220 }
221
222 pub fn __expand_is_in_bounds_method(
223 &self,
224 scope: &mut Scope,
225 pos: C::ExpandType,
226 ) -> ExpandElementTyped<bool> {
227 self.inner.read().__expand_is_in_bounds_method(scope, pos)
228 }
229}
230
231#[allow(unused_variables)]
232impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
233 pub fn read(&self, pos: C) -> E {
235 unexpanded!()
236 }
237
238 pub fn read_unchecked(&self, pos: C) -> E {
241 unexpanded!()
242 }
243
244 pub fn read_checked(&self, pos: C) -> E {
246 unexpanded!()
247 }
248
249 pub fn read_masked(&self, pos: C, mask_value: E) -> E {
251 unexpanded!()
252 }
253
254 pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
260 unexpanded!()
261 }
262
263 pub fn line_size(&self) -> u32 {
264 unexpanded!()
265 }
266}
267
268impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
269 pub fn __expand_read_method(
271 self,
272 scope: &mut Scope,
273 pos: C::ExpandType,
274 ) -> ExpandElementTyped<E> {
275 self.inner.read().__expand_read_method(scope, pos)
276 }
277
278 pub fn __expand_read_unchecked_method(
280 self,
281 scope: &mut Scope,
282 pos: C::ExpandType,
283 ) -> ExpandElementTyped<E> {
284 self.inner.read().__expand_read_unchecked_method(scope, pos)
285 }
286
287 pub fn __expand_read_checked_method(
289 self,
290 scope: &mut Scope,
291 pos: C::ExpandType,
292 ) -> ExpandElementTyped<E> {
293 self.inner.read().__expand_read_checked_method(scope, pos)
294 }
295
296 pub fn __expand_read_masked_method(
298 self,
299 scope: &mut Scope,
300 pos: C::ExpandType,
301 mask_value: E::ExpandType,
302 ) -> ExpandElementTyped<E> {
303 self.inner
304 .read()
305 .__expand_read_masked_method(scope, pos, mask_value)
306 }
307
308 pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
310 self.inner.read().line_size()
311 }
312
313 pub fn line_size(&self) -> u32 {
314 self.inner.read().line_size()
315 }
316
317 pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
318 let shape = self.inner.read().__expand_shape_method(scope);
319 let origin = C::__expand_from_int(scope, shape.clone(), 0);
320 let one = C::__expand_from_int(scope, shape.clone(), 1);
322 let shape = C::__expand_max(scope, shape, one.clone());
323 let end = C::__expand_sub(scope, shape, one);
324 self.inner
325 .read()
326 .__expand_to_linear_slice_method(scope, origin, end)
327 }
328
329 pub(super) fn __expand_to_linear_slice_inner_method(
330 self,
331 scope: &mut Scope,
332 pos: C::ExpandType,
333 end: C::ExpandType,
334 ) -> SliceExpand<E, ReadOnly> {
335 self.inner
336 .read()
337 .__expand_to_linear_slice_method(scope, pos, end)
338 }
339}
340
341impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
342 pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
346 unexpanded!()
347 }
348
349 pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
355 unexpanded!()
356 }
357
358 pub fn __expand_slice(
359 scope: &mut Scope,
360 this: ViewExpand<E, C, IO>,
361 pos: C::ExpandType,
362 size: C::ExpandType,
363 ) -> ViewExpand<E, C, ReadOnly> {
364 this.__expand_slice_method(scope, pos, size)
365 }
366
367 pub fn __expand_slice_unchecked(
368 scope: &mut Scope,
369 this: ViewExpand<E, C, IO>,
370 pos: C::ExpandType,
371 size: C::ExpandType,
372 ) -> ViewExpand<E, C, ReadOnly> {
373 this.__expand_slice_unchecked_method(scope, pos, size)
374 }
375}
376
// NOTE(review): this `#[cube]` impl block is empty. It may be a leftover from methods that were
// moved into the hand-written impls around it; confirm that nothing relies on the macro's
// expansion of an empty block before deleting it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
379
380impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
381 pub fn __expand_slice_method(
382 &self,
383 scope: &mut Scope,
384 pos: C::ExpandType,
385 size: C::ExpandType,
386 ) -> ViewExpand<E, C, ReadOnly> {
387 self.slice(scope, pos, size, true)
388 }
389
390 pub fn __expand_slice_unchecked_method(
391 &self,
392 scope: &mut Scope,
393 pos: C::ExpandType,
394 size: C::ExpandType,
395 ) -> ViewExpand<E, C, ReadOnly> {
396 self.slice(scope, pos, size, false)
397 }
398
399 fn slice(
400 &self,
401 scope: &mut Scope,
402 pos: C::ExpandType,
403 size: C::ExpandType,
404 checked: bool,
405 ) -> ViewExpand<E, C, ReadOnly> {
406 let shape = self.__expand_shape_method(scope);
407 let pos = C::__expand_min(scope, pos, shape.clone());
408 let max_size = C::__expand_sub(scope, shape, pos.clone());
409 let size = C::__expand_min(scope, size, max_size);
410 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
411 self.clone().__expand_view_method(scope, layout.into())
412 }
413}
414
415#[allow(unused_variables)]
416impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
417 pub fn write(&self, pos: C, value: E) {
419 unexpanded!()
420 }
421
422 pub fn write_checked(&self, pos: C, value: E) {
424 unexpanded!()
425 }
426
427 pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
433 unexpanded!()
434 }
435}
436
437impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
438 pub fn __expand_write_method(
440 self,
441 scope: &mut Scope,
442 pos: C::ExpandType,
443 value: ExpandElementTyped<E>,
444 ) {
445 self.inner.write().__expand_write_method(scope, pos, value);
446 }
447
448 pub fn __expand_write_checked_method(
450 self,
451 scope: &mut Scope,
452 pos: C::ExpandType,
453 value: ExpandElementTyped<E>,
454 ) {
455 self.inner
456 .write()
457 .__expand_write_checked_method(scope, pos, value);
458 }
459
460 pub fn __expand_to_linear_slice_mut_method(
461 self,
462 scope: &mut Scope,
463 ) -> SliceExpand<E, ReadWrite> {
464 let shape = self.inner.read().__expand_shape_method(scope);
465 let origin = C::__expand_from_int(scope, shape.clone(), 0);
466 let one = C::__expand_from_int(scope, shape.clone(), 1);
468 let shape = C::__expand_max(scope, shape, one.clone());
469 let end = C::__expand_sub(scope, shape, one);
470 self.inner
471 .write()
472 .__expand_to_linear_slice_mut_method(scope, origin, end)
473 }
474
475 pub(super) fn __expand_to_linear_slice_mut_inner_method(
476 self,
477 scope: &mut Scope,
478 pos: C::ExpandType,
479 end: C::ExpandType,
480 ) -> SliceExpand<E, ReadWrite> {
481 self.inner
482 .write()
483 .__expand_to_linear_slice_mut_method(scope, pos, end)
484 }
485}
486
487impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
488 pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
492 unexpanded!()
493 }
494
495 pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
502 unexpanded!()
503 }
504
505 pub fn __expand_slice_mut(
506 scope: &mut Scope,
507 this: ViewExpand<E, C, ReadWrite>,
508 pos: C::ExpandType,
509 size: C::ExpandType,
510 ) -> ViewExpand<E, C, ReadWrite> {
511 this.__expand_slice_mut_method(scope, pos, size)
512 }
513
514 pub fn __expand_slice_mut_unchecked(
515 scope: &mut Scope,
516 this: ViewExpand<E, C, ReadWrite>,
517 pos: C::ExpandType,
518 size: C::ExpandType,
519 ) -> ViewExpand<E, C, ReadWrite> {
520 this.__expand_slice_mut_unchecked_method(scope, pos, size)
521 }
522}
523
// NOTE(review): this `#[cube]` impl block is empty. It may be a leftover from methods that were
// moved into the hand-written impls around it; confirm that nothing relies on the macro's
// expansion of an empty block before deleting it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
526
527impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
528 pub fn __expand_slice_mut_method(
529 &self,
530 scope: &mut Scope,
531 pos: C::ExpandType,
532 size: C::ExpandType,
533 ) -> ViewExpand<E, C, ReadWrite> {
534 self.slice_mut(scope, pos, size, true)
535 }
536
537 pub fn __expand_slice_mut_unchecked_method(
538 &self,
539 scope: &mut Scope,
540 pos: C::ExpandType,
541 size: C::ExpandType,
542 ) -> ViewExpand<E, C, ReadWrite> {
543 self.slice_mut(scope, pos, size, false)
544 }
545
546 fn slice_mut(
547 &self,
548 scope: &mut Scope,
549 pos: C::ExpandType,
550 size: C::ExpandType,
551 checked: bool,
552 ) -> ViewExpand<E, C, ReadWrite> {
553 let shape = self.__expand_shape_method(scope);
554 let pos = C::__expand_min(scope, pos, shape.clone());
555 let max_size = C::__expand_sub(scope, shape, pos.clone());
556 let size = C::__expand_min(scope, size, max_size);
557 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
558 self.clone().__expand_view_mut_method(scope, layout.into())
559 }
560}
561
562impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
563 pub fn tensor_map_load(
566 &self,
567 _barrier: &Barrier,
568 _shared_memory: &mut Slice<E, ReadWrite>,
569 _pos: C,
570 ) -> View<E, C, ReadWrite> {
571 unexpanded!()
572 }
573
574 pub fn as_tensor_map(&self) -> CubeOption<TensorMap<E>> {
577 unexpanded!()
578 }
579}
580
581impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
582 pub fn __expand_tensor_map_load_method(
583 self,
584 scope: &mut Scope,
585 barrier: BarrierExpand,
586 shared_memory: SliceExpand<E, ReadWrite>,
587 pos: C::ExpandType,
588 ) {
589 self.inner
590 .read()
591 .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
592 }
593
594 pub fn __expand_as_tensor_map_method(
595 self,
596 scope: &mut Scope,
597 ) -> CubeOptionExpand<TensorMap<E>> {
598 self.inner.read().__expand_as_tensor_map_method(scope)
599 }
600}
601
602impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
603 pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
606 unexpanded!()
607 }
608}
609
610impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
611 pub fn __expand_tensor_map_store_method(
612 self,
613 scope: &mut Scope,
614 shared_memory: SliceExpand<E, ReadOnly>,
615 pos: C::ExpandType,
616 ) {
617 self.inner
618 .write()
619 .__expand_tensor_map_store_method(scope, shared_memory, pos)
620 }
621}