1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5 self as cubecl,
6 ir::LineSize,
7 prelude::barrier::{Barrier, BarrierExpand},
8 unexpanded,
9};
10
11use crate::tensor::{
12 ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, VirtualView,
13 VirtualViewMut,
14 layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
15};
16
17#[derive(Clone)]
21pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
22 _layout: PhantomData<C>,
23 _ty: PhantomData<(E, IO)>,
24}
25
26unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
28unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
29impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
30
31#[derive(Clone)]
32pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
33 Read(Arc<dyn ViewOperationsExpand<E, C>>),
34 ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
35}
36
37impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
38 pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
40 match self {
41 ViewType::Read(list) => &**list,
42 ViewType::ReadWrite(list) => &**list,
43 }
44 }
45
46 pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
48 match self {
49 ViewType::Read(_) => panic!("Can't write to readonly list"),
50 ViewType::ReadWrite(list) => &**list,
51 }
52 }
53}
54
55#[derive(Clone)]
57pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
58 pub(super) inner: ViewType<E, C>,
59 pub(super) _io: PhantomData<IO>,
60}
61
62impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
63 type ExpandType = ViewExpand<E, C, IO>;
64}
65
66impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
67 fn into_mut(self, _scope: &mut Scope) -> Self {
68 self
69 }
70}
71
72impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
73
74impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
75 #[allow(unused_variables)]
78 pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
79 view: &V,
80 layout: impl Into<VirtualLayout<C, S>>,
81 ) -> Self {
82 View {
83 _layout: PhantomData,
84 _ty: PhantomData,
85 }
86 }
87
88 pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
90 scope: &mut Scope,
91 view: V::ExpandType,
92 layout: VirtualLayoutExpand<C, S>,
93 ) -> ViewExpand<E, C, ReadOnly> {
94 ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
95 }
96}
97
98impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
99 pub fn view<T: Coordinates>(
100 &self,
101 _layout: impl Into<VirtualLayout<T, C>>,
102 ) -> View<E, T, ReadOnly> {
103 unexpanded!()
104 }
105
106 pub fn __expand_view<T: Coordinates + 'static>(
107 scope: &mut Scope,
108 this: ViewExpand<E, C, IO>,
109 layout: VirtualLayoutExpand<T, C>,
110 ) -> ViewExpand<E, T, ReadOnly> {
111 this.__expand_view_method(scope, layout)
112 }
113}
114
115impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
116 pub fn __expand_view_method<T: Coordinates + 'static>(
117 self,
118 scope: &mut Scope,
119 layout: VirtualLayoutExpand<T, C>,
120 ) -> ViewExpand<E, T, ReadOnly> {
121 View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
122 }
123
124 pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
125 ViewExpand {
126 inner: ViewType::Read(Arc::new(view)),
127 _io: PhantomData,
128 }
129 }
130
131 pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
132 ViewExpand {
133 inner: ViewType::ReadWrite(Arc::new(view)),
134 _io: PhantomData,
135 }
136 }
137}
138
139impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
140 pub fn view_mut<T: Coordinates>(
141 &self,
142 _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
143 ) -> View<E, T, ReadWrite> {
144 unexpanded!()
145 }
146
147 pub fn __expand_view_mut<T: Coordinates + 'static>(
148 scope: &mut Scope,
149 this: ViewExpand<E, C, ReadWrite>,
150 layout: VirtualLayoutExpand<T, C>,
151 ) -> ViewExpand<E, T, ReadWrite> {
152 this.__expand_view_mut_method(scope, layout)
153 }
154}
155
156impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
157 pub fn __expand_view_mut_method<T: Coordinates + 'static>(
158 self,
159 scope: &mut Scope,
160 layout: VirtualLayoutExpand<T, C>,
161 ) -> ViewExpand<E, T, ReadWrite> {
162 View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
163 }
164}
165
166impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
167 pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
170 _view: &mut V,
171 _layout: impl Into<VirtualLayout<C, S>>,
172 ) -> View<E, C, ReadWrite> {
173 View {
174 _ty: PhantomData,
175 _layout: PhantomData,
176 }
177 }
178
179 pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
181 scope: &mut Scope,
182 view: V::ExpandType,
183 layout: VirtualLayoutExpand<C, S>,
184 ) -> ViewExpand<E, C, ReadWrite> {
185 ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
186 scope, view, layout,
187 ))
188 }
189}
190
191impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
192 pub fn shape(&self) -> C {
194 unexpanded!()
195 }
196
197 pub fn is_in_bounds(&self, _pos: C) -> bool {
199 unexpanded!()
200 }
201
202 pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
203 this.__expand_shape_method(scope)
204 }
205
206 pub fn __expand_is_in_bounds(
207 scope: &mut Scope,
208 this: ViewExpand<E, C, IO>,
209 pos: C::ExpandType,
210 ) -> ExpandElementTyped<bool> {
211 this.__expand_is_in_bounds_method(scope, pos)
212 }
213}
214
215impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
216 pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
217 self.inner.read().__expand_shape_method(scope)
218 }
219
220 pub fn __expand_is_in_bounds_method(
221 &self,
222 scope: &mut Scope,
223 pos: C::ExpandType,
224 ) -> ExpandElementTyped<bool> {
225 self.inner.read().__expand_is_in_bounds_method(scope, pos)
226 }
227}
228
229#[allow(unused_variables)]
230impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
231 pub fn read(&self, pos: C) -> E {
233 unexpanded!()
234 }
235
236 pub fn read_unchecked(&self, pos: C) -> E {
239 unexpanded!()
240 }
241
242 pub fn read_checked(&self, pos: C) -> E {
244 unexpanded!()
245 }
246
247 pub fn read_masked(&self, pos: C, mask_value: E) -> E {
249 unexpanded!()
250 }
251
252 pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
258 unexpanded!()
259 }
260
261 pub fn line_size(&self) -> LineSize {
262 unexpanded!()
263 }
264}
265
266impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
267 pub fn __expand_read_method(
269 self,
270 scope: &mut Scope,
271 pos: C::ExpandType,
272 ) -> ExpandElementTyped<E> {
273 self.inner.read().__expand_read_method(scope, pos)
274 }
275
276 pub fn __expand_read_unchecked_method(
278 self,
279 scope: &mut Scope,
280 pos: C::ExpandType,
281 ) -> ExpandElementTyped<E> {
282 self.inner.read().__expand_read_unchecked_method(scope, pos)
283 }
284
285 pub fn __expand_read_checked_method(
287 self,
288 scope: &mut Scope,
289 pos: C::ExpandType,
290 ) -> ExpandElementTyped<E> {
291 self.inner.read().__expand_read_checked_method(scope, pos)
292 }
293
294 pub fn __expand_read_masked_method(
296 self,
297 scope: &mut Scope,
298 pos: C::ExpandType,
299 mask_value: E::ExpandType,
300 ) -> ExpandElementTyped<E> {
301 self.inner
302 .read()
303 .__expand_read_masked_method(scope, pos, mask_value)
304 }
305
306 pub fn __expand_line_size_method(self, _scope: &mut Scope) -> LineSize {
308 self.inner.read().line_size()
309 }
310
311 pub fn line_size(&self) -> LineSize {
312 self.inner.read().line_size()
313 }
314
315 pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
316 let shape = self.inner.read().__expand_shape_method(scope);
317 let origin = C::__expand_from_int(scope, shape.clone(), 0);
318 let one = C::__expand_from_int(scope, shape.clone(), 1);
320 let shape = C::__expand_max(scope, shape, one.clone());
321 let end = C::__expand_sub(scope, shape, one);
322 self.inner
323 .read()
324 .__expand_to_linear_slice_method(scope, origin, end)
325 }
326
327 pub(super) fn __expand_to_linear_slice_inner_method(
328 self,
329 scope: &mut Scope,
330 pos: C::ExpandType,
331 end: C::ExpandType,
332 ) -> SliceExpand<E, ReadOnly> {
333 self.inner
334 .read()
335 .__expand_to_linear_slice_method(scope, pos, end)
336 }
337}
338
339impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
340 pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
344 unexpanded!()
345 }
346
347 pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
353 unexpanded!()
354 }
355
356 pub fn __expand_slice(
357 scope: &mut Scope,
358 this: ViewExpand<E, C, IO>,
359 pos: C::ExpandType,
360 size: C::ExpandType,
361 ) -> ViewExpand<E, C, ReadOnly> {
362 this.__expand_slice_method(scope, pos, size)
363 }
364
365 pub fn __expand_slice_unchecked(
366 scope: &mut Scope,
367 this: ViewExpand<E, C, IO>,
368 pos: C::ExpandType,
369 size: C::ExpandType,
370 ) -> ViewExpand<E, C, ReadOnly> {
371 this.__expand_slice_unchecked_method(scope, pos, size)
372 }
373}
374
// NOTE(review): this `#[cube]` impl has an empty body. This file does not show what, if
// anything, the macro generates for it; confirm it is still needed before removing it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
377
378impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
379 pub fn __expand_slice_method(
380 &self,
381 scope: &mut Scope,
382 pos: C::ExpandType,
383 size: C::ExpandType,
384 ) -> ViewExpand<E, C, ReadOnly> {
385 self.slice(scope, pos, size, true)
386 }
387
388 pub fn __expand_slice_unchecked_method(
389 &self,
390 scope: &mut Scope,
391 pos: C::ExpandType,
392 size: C::ExpandType,
393 ) -> ViewExpand<E, C, ReadOnly> {
394 self.slice(scope, pos, size, false)
395 }
396
397 fn slice(
398 &self,
399 scope: &mut Scope,
400 pos: C::ExpandType,
401 size: C::ExpandType,
402 checked: bool,
403 ) -> ViewExpand<E, C, ReadOnly> {
404 let shape = self.__expand_shape_method(scope);
405 let pos = C::__expand_min(scope, pos, shape.clone());
406 let max_size = C::__expand_sub(scope, shape, pos.clone());
407 let size = C::__expand_min(scope, size, max_size);
408 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
409 self.clone().__expand_view_method(scope, layout.into())
410 }
411}
412
413#[allow(unused_variables)]
414impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
415 pub fn write(&self, pos: C, value: E) {
417 unexpanded!()
418 }
419
420 pub fn write_checked(&self, pos: C, value: E) {
422 unexpanded!()
423 }
424
425 pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
431 unexpanded!()
432 }
433}
434
435impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
436 pub fn __expand_write_method(
438 self,
439 scope: &mut Scope,
440 pos: C::ExpandType,
441 value: ExpandElementTyped<E>,
442 ) {
443 self.inner.write().__expand_write_method(scope, pos, value);
444 }
445
446 pub fn __expand_write_checked_method(
448 self,
449 scope: &mut Scope,
450 pos: C::ExpandType,
451 value: ExpandElementTyped<E>,
452 ) {
453 self.inner
454 .write()
455 .__expand_write_checked_method(scope, pos, value);
456 }
457
458 pub fn __expand_to_linear_slice_mut_method(
459 self,
460 scope: &mut Scope,
461 ) -> SliceExpand<E, ReadWrite> {
462 let shape = self.inner.read().__expand_shape_method(scope);
463 let origin = C::__expand_from_int(scope, shape.clone(), 0);
464 let one = C::__expand_from_int(scope, shape.clone(), 1);
466 let shape = C::__expand_max(scope, shape, one.clone());
467 let end = C::__expand_sub(scope, shape, one);
468 self.inner
469 .write()
470 .__expand_to_linear_slice_mut_method(scope, origin, end)
471 }
472
473 pub(super) fn __expand_to_linear_slice_mut_inner_method(
474 self,
475 scope: &mut Scope,
476 pos: C::ExpandType,
477 end: C::ExpandType,
478 ) -> SliceExpand<E, ReadWrite> {
479 self.inner
480 .write()
481 .__expand_to_linear_slice_mut_method(scope, pos, end)
482 }
483}
484
485impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
486 pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
490 unexpanded!()
491 }
492
493 pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
500 unexpanded!()
501 }
502
503 pub fn __expand_slice_mut(
504 scope: &mut Scope,
505 this: ViewExpand<E, C, ReadWrite>,
506 pos: C::ExpandType,
507 size: C::ExpandType,
508 ) -> ViewExpand<E, C, ReadWrite> {
509 this.__expand_slice_mut_method(scope, pos, size)
510 }
511
512 pub fn __expand_slice_mut_unchecked(
513 scope: &mut Scope,
514 this: ViewExpand<E, C, ReadWrite>,
515 pos: C::ExpandType,
516 size: C::ExpandType,
517 ) -> ViewExpand<E, C, ReadWrite> {
518 this.__expand_slice_mut_unchecked_method(scope, pos, size)
519 }
520}
521
// NOTE(review): this `#[cube]` impl has an empty body. This file does not show what, if
// anything, the macro generates for it; confirm it is still needed before removing it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
524
525impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
526 pub fn __expand_slice_mut_method(
527 &self,
528 scope: &mut Scope,
529 pos: C::ExpandType,
530 size: C::ExpandType,
531 ) -> ViewExpand<E, C, ReadWrite> {
532 self.slice_mut(scope, pos, size, true)
533 }
534
535 pub fn __expand_slice_mut_unchecked_method(
536 &self,
537 scope: &mut Scope,
538 pos: C::ExpandType,
539 size: C::ExpandType,
540 ) -> ViewExpand<E, C, ReadWrite> {
541 self.slice_mut(scope, pos, size, false)
542 }
543
544 fn slice_mut(
545 &self,
546 scope: &mut Scope,
547 pos: C::ExpandType,
548 size: C::ExpandType,
549 checked: bool,
550 ) -> ViewExpand<E, C, ReadWrite> {
551 let shape = self.__expand_shape_method(scope);
552 let pos = C::__expand_min(scope, pos, shape.clone());
553 let max_size = C::__expand_sub(scope, shape, pos.clone());
554 let size = C::__expand_min(scope, size, max_size);
555 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
556 self.clone().__expand_view_mut_method(scope, layout.into())
557 }
558}
559
560impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
561 pub fn tensor_map_load(
564 &self,
565 _barrier: &Barrier,
566 _shared_memory: &mut Slice<E, ReadWrite>,
567 _pos: C,
568 ) -> View<E, C, ReadWrite> {
569 unexpanded!()
570 }
571}
572
573impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
574 pub fn __expand_tensor_map_load_method(
575 self,
576 scope: &mut Scope,
577 barrier: BarrierExpand,
578 shared_memory: SliceExpand<E, ReadWrite>,
579 pos: C::ExpandType,
580 ) {
581 self.inner
582 .read()
583 .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
584 }
585}
586
587impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
588 pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
591 unexpanded!()
592 }
593}
594
595impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
596 pub fn __expand_tensor_map_store_method(
597 self,
598 scope: &mut Scope,
599 shared_memory: SliceExpand<E, ReadOnly>,
600 pos: C::ExpandType,
601 ) {
602 self.inner
603 .write()
604 .__expand_tensor_map_store_method(scope, shared_memory, pos)
605 }
606}