1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5 self as cubecl,
6 ir::VectorSize,
7 prelude::barrier::{Barrier, BarrierExpand},
8 unexpanded,
9};
10
11use crate::tensor::{
12 ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, VirtualView,
13 VirtualViewMut,
14 layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
15};
16
/// Host-side handle to a typed view over some backing storage.
///
/// Elements have type `E` and are addressed by coordinates of type `C`. `IO` is the access
/// mode: `ReadOnly` by default, or `ReadWrite`, which unlocks the `write*`, `view_mut` and
/// `slice_mut*` methods.
///
/// The struct only holds `PhantomData` markers. Its host-side methods either build another
/// marker (`new`, `new_mut`) or are `unexpanded!()` stubs. The real work happens on
/// [`ViewExpand`] through the matching `__expand_*` functions.
#[derive(Clone)]
pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
    _layout: PhantomData<C>,
    _ty: PhantomData<(E, IO)>,
}
25
// SAFETY: `View` stores no values at all, only `PhantomData` markers, so no `E`, `C` or `IO`
// value can be sent or shared across threads through it. This impl only lifts the auto-trait
// requirement that the type parameters themselves be `Send`.
unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
// SAFETY: same reasoning as the `Send` impl directly above: there is no state to share.
unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
// `View` is zero-sized (two `PhantomData` fields), so copying it is free; `Clone` comes from
// the derive on the struct.
impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
30
/// Type-erased backend of a [`ViewExpand`].
///
/// The variant records which trait object the storage was wrapped in, so write access can be
/// checked when the view is expanded. Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
    /// Backend exposing only the read operations.
    Read(Arc<dyn ViewOperationsExpand<E, C>>),
    /// Backend exposing both read and write operations.
    ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
}
36
37impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
38 pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
40 match self {
41 ViewType::Read(list) => &**list,
42 ViewType::ReadWrite(list) => &**list,
43 }
44 }
45
46 pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
48 match self {
49 ViewType::Read(_) => panic!("Can't write to readonly list"),
50 ViewType::ReadWrite(list) => &**list,
51 }
52 }
53}
54
/// Expansion-time handle of a [`View`]; every operation is delegated to the backend in `inner`.
///
/// `IO` only exists at the type level and selects which methods are available. Note that it is
/// not tied to the variant of `inner`: `new`/`new_mut` exist for any `IO`, so a `ReadWrite`
/// handle built with `new` panics as soon as a write is expanded.
#[derive(Clone)]
pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
    /// Type-erased backend (shared through an `Arc`, so clones are cheap).
    pub(super) inner: ViewType<E, C>,
    /// Access-mode marker.
    pub(super) _io: PhantomData<IO>,
}
61
/// Inside kernels a [`View`] is represented by a [`ViewExpand`].
impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
    type ExpandType = ViewExpand<E, C, IO>;
}

/// Returns the handle unchanged and emits nothing into `scope`.
impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}

// Relies entirely on the trait's default methods.
impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
73
74impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
75 #[allow(unused_variables)]
78 pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
79 view: &V,
80 layout: impl Into<VirtualLayout<C, S>>,
81 ) -> Self {
82 View {
83 _layout: PhantomData,
84 _ty: PhantomData,
85 }
86 }
87
88 pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
90 scope: &mut Scope,
91 view: V::ExpandType,
92 layout: VirtualLayoutExpand<C, S>,
93 ) -> ViewExpand<E, C, ReadOnly> {
94 ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
95 }
96}
97
impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
    /// Re-addresses this view through `layout`, which maps the new `T` coordinates onto this
    /// view's `C` coordinates. The result is read-only whatever `IO` is.
    ///
    /// Host-side stub; see [`View::__expand_view`].
    pub fn view<T: Coordinates>(
        &self,
        _layout: impl Into<VirtualLayout<T, C>>,
    ) -> View<E, T, ReadOnly> {
        unexpanded!()
    }

    /// Expanded form of [`View::view`]; forwards to `__expand_view_method` on the handle.
    pub fn __expand_view<T: Coordinates + 'static>(
        scope: &mut Scope,
        this: ViewExpand<E, C, IO>,
        layout: VirtualLayoutExpand<T, C>,
    ) -> ViewExpand<E, T, ReadOnly> {
        this.__expand_view_method(scope, layout)
    }
}
114
115impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
116 pub fn __expand_view_method<T: Coordinates + 'static>(
117 self,
118 scope: &mut Scope,
119 layout: VirtualLayoutExpand<T, C>,
120 ) -> ViewExpand<E, T, ReadOnly> {
121 View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
122 }
123
124 pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
125 ViewExpand {
126 inner: ViewType::Read(Arc::new(view)),
127 _io: PhantomData,
128 }
129 }
130
131 pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
132 ViewExpand {
133 inner: ViewType::ReadWrite(Arc::new(view)),
134 _io: PhantomData,
135 }
136 }
137}
138
impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
    /// Re-addresses this writable view through `layout` (mapping `T` onto `C`), keeping
    /// write access.
    ///
    /// Host-side stub; see [`View::__expand_view_mut`].
    ///
    /// NOTE(review): unlike [`View::view`], which takes `impl Into<VirtualLayout<T, C>>`, this
    /// stub takes `impl Layout<..>` directly, while the expanded form takes a
    /// `VirtualLayoutExpand<T, C>`. Confirm whether the two signatures are meant to differ.
    pub fn view_mut<T: Coordinates>(
        &self,
        _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
    ) -> View<E, T, ReadWrite> {
        unexpanded!()
    }

    /// Expanded form of [`View::view_mut`]; forwards to `__expand_view_mut_method` on the
    /// handle.
    pub fn __expand_view_mut<T: Coordinates + 'static>(
        scope: &mut Scope,
        this: ViewExpand<E, C, ReadWrite>,
        layout: VirtualLayoutExpand<T, C>,
    ) -> ViewExpand<E, T, ReadWrite> {
        this.__expand_view_mut_method(scope, layout)
    }
}
155
156impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
157 pub fn __expand_view_mut_method<T: Coordinates + 'static>(
158 self,
159 scope: &mut Scope,
160 layout: VirtualLayoutExpand<T, C>,
161 ) -> ViewExpand<E, T, ReadWrite> {
162 View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
163 }
164}
165
166impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
167 pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
170 _view: &mut V,
171 _layout: impl Into<VirtualLayout<C, S>>,
172 ) -> View<E, C, ReadWrite> {
173 View {
174 _ty: PhantomData,
175 _layout: PhantomData,
176 }
177 }
178
179 pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
181 scope: &mut Scope,
182 view: V::ExpandType,
183 layout: VirtualLayoutExpand<C, S>,
184 ) -> ViewExpand<E, C, ReadWrite> {
185 ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
186 scope, view, layout,
187 ))
188 }
189}
190
impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
    /// Extent of the view in its own coordinate space. Host-side stub.
    pub fn shape(&self) -> C {
        unexpanded!()
    }

    /// Whether `_pos` lies inside the view, as decided by the backend. Host-side stub.
    pub fn is_in_bounds(&self, _pos: C) -> bool {
        unexpanded!()
    }

    /// Expanded form of [`View::shape`]; forwards to the handle's `__expand_shape_method`.
    pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
        this.__expand_shape_method(scope)
    }

    /// Expanded form of [`View::is_in_bounds`]; forwards to the handle's
    /// `__expand_is_in_bounds_method`.
    pub fn __expand_is_in_bounds(
        scope: &mut Scope,
        this: ViewExpand<E, C, IO>,
        pos: C::ExpandType,
    ) -> NativeExpand<bool> {
        this.__expand_is_in_bounds_method(scope, pos)
    }
}
214
215impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
216 pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
217 self.inner.read().__expand_shape_method(scope)
218 }
219
220 pub fn __expand_is_in_bounds_method(
221 &self,
222 scope: &mut Scope,
223 pos: C::ExpandType,
224 ) -> NativeExpand<bool> {
225 self.inner.read().__expand_is_in_bounds_method(scope, pos)
226 }
227}
228
229#[allow(unused_variables)]
230impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
231 pub fn read(&self, pos: C) -> E {
233 unexpanded!()
234 }
235
236 pub fn read_unchecked(&self, pos: C) -> E {
239 unexpanded!()
240 }
241
242 pub fn read_checked(&self, pos: C) -> E {
244 unexpanded!()
245 }
246
247 pub fn read_masked(&self, pos: C, mask_value: E) -> E {
249 unexpanded!()
250 }
251
252 pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
258 unexpanded!()
259 }
260
261 pub fn vector_size(&self) -> VectorSize {
262 unexpanded!()
263 }
264}
265
266impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
267 pub fn __expand_read_method(self, scope: &mut Scope, pos: C::ExpandType) -> NativeExpand<E> {
269 self.inner.read().__expand_read_method(scope, pos)
270 }
271
272 pub fn __expand_read_unchecked_method(
274 self,
275 scope: &mut Scope,
276 pos: C::ExpandType,
277 ) -> NativeExpand<E> {
278 self.inner.read().__expand_read_unchecked_method(scope, pos)
279 }
280
281 pub fn __expand_read_checked_method(
283 self,
284 scope: &mut Scope,
285 pos: C::ExpandType,
286 ) -> NativeExpand<E> {
287 self.inner.read().__expand_read_checked_method(scope, pos)
288 }
289
290 pub fn __expand_read_masked_method(
292 self,
293 scope: &mut Scope,
294 pos: C::ExpandType,
295 mask_value: E::ExpandType,
296 ) -> NativeExpand<E> {
297 self.inner
298 .read()
299 .__expand_read_masked_method(scope, pos, mask_value)
300 }
301
302 pub fn __expand_vector_size_method(&self, _scope: &mut Scope) -> VectorSize {
304 self.inner.read().vector_size()
305 }
306
307 pub fn vector_size(&self) -> VectorSize {
308 self.inner.read().vector_size()
309 }
310
311 pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
312 let shape = self.inner.read().__expand_shape_method(scope);
313 let origin = C::__expand_from_int(scope, shape.clone(), 0);
314 let one = C::__expand_from_int(scope, shape.clone(), 1);
316 let shape = C::__expand_max(scope, shape, one.clone());
317 let end = C::__expand_sub(scope, shape, one);
318 self.inner
319 .read()
320 .__expand_to_linear_slice_method(scope, origin, end)
321 }
322
323 pub(super) fn __expand_to_linear_slice_inner_method(
324 self,
325 scope: &mut Scope,
326 pos: C::ExpandType,
327 end: C::ExpandType,
328 ) -> SliceExpand<E, ReadOnly> {
329 self.inner
330 .read()
331 .__expand_to_linear_slice_method(scope, pos, end)
332 }
333}
334
impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
    /// Read-only sub-view starting at `_pos` with extent `_size`. Both are clamped to the
    /// view's shape during expansion.
    ///
    /// Host-side stub; see [`View::__expand_slice`].
    pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
        unexpanded!()
    }

    /// Same as [`View::slice`], but the slice layout is built with `checked = false`.
    ///
    /// Host-side stub; see [`View::__expand_slice_unchecked`].
    pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
        unexpanded!()
    }

    /// Expanded form of [`View::slice`]; forwards to the handle's `__expand_slice_method`.
    pub fn __expand_slice(
        scope: &mut Scope,
        this: ViewExpand<E, C, IO>,
        pos: C::ExpandType,
        size: C::ExpandType,
    ) -> ViewExpand<E, C, ReadOnly> {
        this.__expand_slice_method(scope, pos, size)
    }

    /// Expanded form of [`View::slice_unchecked`]; forwards to the handle's
    /// `__expand_slice_unchecked_method`.
    pub fn __expand_slice_unchecked(
        scope: &mut Scope,
        this: ViewExpand<E, C, IO>,
        pos: C::ExpandType,
        size: C::ExpandType,
    ) -> ViewExpand<E, C, ReadOnly> {
        this.__expand_slice_unchecked_method(scope, pos, size)
    }
}
370
// NOTE(review): this `#[cube]` impl block is empty and looks like a leftover; confirm nothing
// depends on the macro output before removing it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
373
374impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
375 pub fn __expand_slice_method(
376 &self,
377 scope: &mut Scope,
378 pos: C::ExpandType,
379 size: C::ExpandType,
380 ) -> ViewExpand<E, C, ReadOnly> {
381 self.slice(scope, pos, size, true)
382 }
383
384 pub fn __expand_slice_unchecked_method(
385 &self,
386 scope: &mut Scope,
387 pos: C::ExpandType,
388 size: C::ExpandType,
389 ) -> ViewExpand<E, C, ReadOnly> {
390 self.slice(scope, pos, size, false)
391 }
392
393 fn slice(
394 &self,
395 scope: &mut Scope,
396 pos: C::ExpandType,
397 size: C::ExpandType,
398 checked: bool,
399 ) -> ViewExpand<E, C, ReadOnly> {
400 let shape = self.__expand_shape_method(scope);
401 let pos = C::__expand_min(scope, pos, shape.clone());
402 let max_size = C::__expand_sub(scope, shape, pos.clone());
403 let size = C::__expand_min(scope, size, max_size);
404 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
405 self.clone().__expand_view_method(scope, layout.into())
406 }
407}
408
409#[allow(unused_variables)]
410impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
411 pub fn write(&self, pos: C, value: E) {
413 unexpanded!()
414 }
415
416 pub fn write_checked(&self, pos: C, value: E) {
418 unexpanded!()
419 }
420
421 pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
427 unexpanded!()
428 }
429}
430
431impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
432 pub fn __expand_write_method(
434 self,
435 scope: &mut Scope,
436 pos: C::ExpandType,
437 value: NativeExpand<E>,
438 ) {
439 self.inner.write().__expand_write_method(scope, pos, value);
440 }
441
442 pub fn __expand_write_checked_method(
444 self,
445 scope: &mut Scope,
446 pos: C::ExpandType,
447 value: NativeExpand<E>,
448 ) {
449 self.inner
450 .write()
451 .__expand_write_checked_method(scope, pos, value);
452 }
453
454 pub fn __expand_to_linear_slice_mut_method(
455 self,
456 scope: &mut Scope,
457 ) -> SliceExpand<E, ReadWrite> {
458 let shape = self.inner.read().__expand_shape_method(scope);
459 let origin = C::__expand_from_int(scope, shape.clone(), 0);
460 let one = C::__expand_from_int(scope, shape.clone(), 1);
462 let shape = C::__expand_max(scope, shape, one.clone());
463 let end = C::__expand_sub(scope, shape, one);
464 self.inner
465 .write()
466 .__expand_to_linear_slice_mut_method(scope, origin, end)
467 }
468
469 pub(super) fn __expand_to_linear_slice_mut_inner_method(
470 self,
471 scope: &mut Scope,
472 pos: C::ExpandType,
473 end: C::ExpandType,
474 ) -> SliceExpand<E, ReadWrite> {
475 self.inner
476 .write()
477 .__expand_to_linear_slice_mut_method(scope, pos, end)
478 }
479}
480
impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
    /// Writable sub-view starting at `_pos` with extent `_size`. Both are clamped to the
    /// view's shape during expansion.
    ///
    /// Host-side stub; see [`View::__expand_slice_mut`].
    pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
        unexpanded!()
    }

    /// Same as [`View::slice_mut`], but the slice layout is built with `checked = false`.
    ///
    /// Host-side stub; see [`View::__expand_slice_mut_unchecked`].
    pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
        unexpanded!()
    }

    /// Expanded form of [`View::slice_mut`]; forwards to the handle's
    /// `__expand_slice_mut_method`.
    pub fn __expand_slice_mut(
        scope: &mut Scope,
        this: ViewExpand<E, C, ReadWrite>,
        pos: C::ExpandType,
        size: C::ExpandType,
    ) -> ViewExpand<E, C, ReadWrite> {
        this.__expand_slice_mut_method(scope, pos, size)
    }

    /// Expanded form of [`View::slice_mut_unchecked`]; forwards to the handle's
    /// `__expand_slice_mut_unchecked_method`.
    pub fn __expand_slice_mut_unchecked(
        scope: &mut Scope,
        this: ViewExpand<E, C, ReadWrite>,
        pos: C::ExpandType,
        size: C::ExpandType,
    ) -> ViewExpand<E, C, ReadWrite> {
        this.__expand_slice_mut_unchecked_method(scope, pos, size)
    }
}
517
// NOTE(review): this `#[cube]` impl block is empty and looks like a leftover; confirm nothing
// depends on the macro output before removing it.
#[cube]
impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
520
521impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
522 pub fn __expand_slice_mut_method(
523 &self,
524 scope: &mut Scope,
525 pos: C::ExpandType,
526 size: C::ExpandType,
527 ) -> ViewExpand<E, C, ReadWrite> {
528 self.slice_mut(scope, pos, size, true)
529 }
530
531 pub fn __expand_slice_mut_unchecked_method(
532 &self,
533 scope: &mut Scope,
534 pos: C::ExpandType,
535 size: C::ExpandType,
536 ) -> ViewExpand<E, C, ReadWrite> {
537 self.slice_mut(scope, pos, size, false)
538 }
539
540 fn slice_mut(
541 &self,
542 scope: &mut Scope,
543 pos: C::ExpandType,
544 size: C::ExpandType,
545 checked: bool,
546 ) -> ViewExpand<E, C, ReadWrite> {
547 let shape = self.__expand_shape_method(scope);
548 let pos = C::__expand_min(scope, pos, shape.clone());
549 let max_size = C::__expand_sub(scope, shape, pos.clone());
550 let size = C::__expand_min(scope, size, max_size);
551 let layout = SliceLayout::__expand_new(scope, pos, size, checked);
552 self.clone().__expand_view_mut_method(scope, layout.into())
553 }
554}
555
impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
    /// Tensor-map load of the region at `_pos` into `_shared_memory`, taking `_barrier`.
    ///
    /// Host-side stub; expands to [`ViewExpand::__expand_tensor_map_load_method`], which
    /// forwards to the backend.
    ///
    /// NOTE(review): this stub is declared to return a `View`, but the expanded method returns
    /// `()`. Confirm which is intended before relying on the return value.
    pub fn tensor_map_load(
        &self,
        _barrier: &Barrier,
        _shared_memory: &mut Slice<E, ReadWrite>,
        _pos: C,
    ) -> View<E, C, ReadWrite> {
        unexpanded!()
    }
}
568
569impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
570 pub fn __expand_tensor_map_load_method(
571 self,
572 scope: &mut Scope,
573 barrier: BarrierExpand,
574 shared_memory: SliceExpand<E, ReadWrite>,
575 pos: C::ExpandType,
576 ) {
577 self.inner
578 .read()
579 .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
580 }
581}
582
impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
    /// Tensor-map store of `_shared_memory` into the region at `_pos`.
    ///
    /// Host-side stub; expands to [`ViewExpand::__expand_tensor_map_store_method`], which
    /// forwards to the writable backend.
    ///
    /// NOTE(review): this stub is declared to return a `View`, but the expanded method returns
    /// `()`. Confirm which is intended before relying on the return value.
    pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
        unexpanded!()
    }
}
590
591impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
592 pub fn __expand_tensor_map_store_method(
593 self,
594 scope: &mut Scope,
595 shared_memory: SliceExpand<E, ReadOnly>,
596 pos: C::ExpandType,
597 ) {
598 self.inner
599 .write()
600 .__expand_tensor_map_store_method(scope, shared_memory, pos)
601 }
602}