// cubecl_std/tensor/view/base.rs

use std::{marker::PhantomData, sync::Arc};

use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl,
    ir::LineSize,
    prelude::barrier::{Barrier, BarrierExpand},
    unexpanded,
};

use crate::tensor::{
    ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, VirtualView,
    VirtualViewMut,
    layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
};

17/// A conceptual view of an underlying linear storage.
18/// Allows abstract indexing in multiple dimensions, without having to know the data layout or
19/// location.
20#[derive(Clone)]
21pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
22    _layout: PhantomData<C>,
23    _ty: PhantomData<(E, IO)>,
24}
25
26// `View` is a dummy type so it's always send/sync
27unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
28unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
29impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
30
31#[derive(Clone)]
32pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
33    Read(Arc<dyn ViewOperationsExpand<E, C>>),
34    ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
35}
36
37impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
38    /// Dereference in read mode
39    pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
40        match self {
41            ViewType::Read(list) => &**list,
42            ViewType::ReadWrite(list) => &**list,
43        }
44    }
45
46    /// Dereference in write mode
47    pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
48        match self {
49            ViewType::Read(_) => panic!("Can't write to readonly list"),
50            ViewType::ReadWrite(list) => &**list,
51        }
52    }
53}
54
55/// Expand type of [TensorView]
56#[derive(Clone)]
57pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
58    pub(super) inner: ViewType<E, C>,
59    pub(super) _io: PhantomData<IO>,
60}
61
62impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
63    type ExpandType = ViewExpand<E, C, IO>;
64}
65
66impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
67    fn into_mut(self, _scope: &mut Scope) -> Self {
68        self
69    }
70}
71
72impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
73
74impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
75    /// Create a new tensor view from an underlying concrete storage and a layout to map it into
76    /// the target coordinate space
77    #[allow(unused_variables)]
78    pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
79        view: &V,
80        layout: impl Into<VirtualLayout<C, S>>,
81    ) -> Self {
82        View {
83            _layout: PhantomData,
84            _ty: PhantomData,
85        }
86    }
87
88    /// Expand function for [TensorView::new]
89    pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
90        scope: &mut Scope,
91        view: V::ExpandType,
92        layout: VirtualLayoutExpand<C, S>,
93    ) -> ViewExpand<E, C, ReadOnly> {
94        ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
95    }
96}
97
98impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
99    pub fn view<T: Coordinates>(
100        &self,
101        _layout: impl Into<VirtualLayout<T, C>>,
102    ) -> View<E, T, ReadOnly> {
103        unexpanded!()
104    }
105
106    pub fn __expand_view<T: Coordinates + 'static>(
107        scope: &mut Scope,
108        this: ViewExpand<E, C, IO>,
109        layout: VirtualLayoutExpand<T, C>,
110    ) -> ViewExpand<E, T, ReadOnly> {
111        this.__expand_view_method(scope, layout)
112    }
113}
114
115impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
116    pub fn __expand_view_method<T: Coordinates + 'static>(
117        self,
118        scope: &mut Scope,
119        layout: VirtualLayoutExpand<T, C>,
120    ) -> ViewExpand<E, T, ReadOnly> {
121        View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
122    }
123
124    pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
125        ViewExpand {
126            inner: ViewType::Read(Arc::new(view)),
127            _io: PhantomData,
128        }
129    }
130
131    pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
132        ViewExpand {
133            inner: ViewType::ReadWrite(Arc::new(view)),
134            _io: PhantomData,
135        }
136    }
137}
138
139impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
140    pub fn view_mut<T: Coordinates>(
141        &self,
142        _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
143    ) -> View<E, T, ReadWrite> {
144        unexpanded!()
145    }
146
147    pub fn __expand_view_mut<T: Coordinates + 'static>(
148        scope: &mut Scope,
149        this: ViewExpand<E, C, ReadWrite>,
150        layout: VirtualLayoutExpand<T, C>,
151    ) -> ViewExpand<E, T, ReadWrite> {
152        this.__expand_view_mut_method(scope, layout)
153    }
154}
155
156impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
157    pub fn __expand_view_mut_method<T: Coordinates + 'static>(
158        self,
159        scope: &mut Scope,
160        layout: VirtualLayoutExpand<T, C>,
161    ) -> ViewExpand<E, T, ReadWrite> {
162        View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
163    }
164}
165
166impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
167    /// Create a new mutable tensor view from an underlying concrete storage and a layout to map it
168    /// into the target coordinate space
169    pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
170        _view: &mut V,
171        _layout: impl Into<VirtualLayout<C, S>>,
172    ) -> View<E, C, ReadWrite> {
173        View {
174            _ty: PhantomData,
175            _layout: PhantomData,
176        }
177    }
178
179    /// Expand function for [TensorView::new_mut]
180    pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
181        scope: &mut Scope,
182        view: V::ExpandType,
183        layout: VirtualLayoutExpand<C, S>,
184    ) -> ViewExpand<E, C, ReadWrite> {
185        ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
186            scope, view, layout,
187        ))
188    }
189}
190
191impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
192    /// Calls [Layout::shape] on the view's layout
193    pub fn shape(&self) -> C {
194        unexpanded!()
195    }
196
197    /// Calls [Layout::is_in_bounds] on the view's layout
198    pub fn is_in_bounds(&self, _pos: C) -> bool {
199        unexpanded!()
200    }
201
202    pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
203        this.__expand_shape_method(scope)
204    }
205
206    pub fn __expand_is_in_bounds(
207        scope: &mut Scope,
208        this: ViewExpand<E, C, IO>,
209        pos: C::ExpandType,
210    ) -> ExpandElementTyped<bool> {
211        this.__expand_is_in_bounds_method(scope, pos)
212    }
213}
214
215impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
216    pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
217        self.inner.read().__expand_shape_method(scope)
218    }
219
220    pub fn __expand_is_in_bounds_method(
221        &self,
222        scope: &mut Scope,
223        pos: C::ExpandType,
224    ) -> ExpandElementTyped<bool> {
225        self.inner.read().__expand_is_in_bounds_method(scope, pos)
226    }
227}
228
229#[allow(unused_variables)]
230impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
231    /// Read a line at `pos`. The layout handles translation into a concrete index.
232    pub fn read(&self, pos: C) -> E {
233        unexpanded!()
234    }
235
236    /// Read a line at `pos`. The layout handles translation into a concrete index.
237    /// Reading is done unchecked
238    pub fn read_unchecked(&self, pos: C) -> E {
239        unexpanded!()
240    }
241
242    /// Read a line at `pos` if it's in bounds. The layout handles translation into a concrete index.
243    pub fn read_checked(&self, pos: C) -> E {
244        unexpanded!()
245    }
246
247    /// Read a line at `pos` if it's in bounds, returning `mask_value` otherwise. The layout handles translation into a concrete index.
248    pub fn read_masked(&self, pos: C, mask_value: E) -> E {
249        unexpanded!()
250    }
251
252    /// Interpret this view as a linear slice encompassing the entire view.
253    ///
254    /// # Safety
255    ///
256    /// No checking is done on whether the slice is contiguous in memory.
257    pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
258        unexpanded!()
259    }
260
261    pub fn line_size(&self) -> LineSize {
262        unexpanded!()
263    }
264}
265
266impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
267    /// Expand method for [TensorView::read]
268    pub fn __expand_read_method(
269        self,
270        scope: &mut Scope,
271        pos: C::ExpandType,
272    ) -> ExpandElementTyped<E> {
273        self.inner.read().__expand_read_method(scope, pos)
274    }
275
276    /// Expand method for [TensorView::read_unchecked]
277    pub fn __expand_read_unchecked_method(
278        self,
279        scope: &mut Scope,
280        pos: C::ExpandType,
281    ) -> ExpandElementTyped<E> {
282        self.inner.read().__expand_read_unchecked_method(scope, pos)
283    }
284
285    /// Expand method for [TensorView::read_checked]
286    pub fn __expand_read_checked_method(
287        self,
288        scope: &mut Scope,
289        pos: C::ExpandType,
290    ) -> ExpandElementTyped<E> {
291        self.inner.read().__expand_read_checked_method(scope, pos)
292    }
293
294    /// Expand method for [TensorView::read_masked]
295    pub fn __expand_read_masked_method(
296        self,
297        scope: &mut Scope,
298        pos: C::ExpandType,
299        mask_value: E::ExpandType,
300    ) -> ExpandElementTyped<E> {
301        self.inner
302            .read()
303            .__expand_read_masked_method(scope, pos, mask_value)
304    }
305
306    /// Expand method for [TensorView::line_size]
307    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> LineSize {
308        self.inner.read().line_size()
309    }
310
311    pub fn line_size(&self) -> LineSize {
312        self.inner.read().line_size()
313    }
314
315    pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
316        let shape = self.inner.read().__expand_shape_method(scope);
317        let origin = C::__expand_from_int(scope, shape.clone(), 0);
318        // Inclusive end so clamping works correctly
319        let one = C::__expand_from_int(scope, shape.clone(), 1);
320        let shape = C::__expand_max(scope, shape, one.clone());
321        let end = C::__expand_sub(scope, shape, one);
322        self.inner
323            .read()
324            .__expand_to_linear_slice_method(scope, origin, end)
325    }
326
327    pub(super) fn __expand_to_linear_slice_inner_method(
328        self,
329        scope: &mut Scope,
330        pos: C::ExpandType,
331        end: C::ExpandType,
332    ) -> SliceExpand<E, ReadOnly> {
333        self.inner
334            .read()
335            .__expand_to_linear_slice_method(scope, pos, end)
336    }
337}
338
339impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
340    /// Create a slice starting from `pos`, with `size`.
341    /// The layout handles translation into concrete indices.
342    /// Size will be clamped to the current layout size.
343    pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
344        unexpanded!()
345    }
346
347    /// Create a slice starting from `pos`, with `size`.
348    /// The layout handles translation into concrete indices.
349    /// Size and pos will be clamped to the current layout size.
350    /// #Safety
351    /// Access is always unchecked
352    pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
353        unexpanded!()
354    }
355
356    pub fn __expand_slice(
357        scope: &mut Scope,
358        this: ViewExpand<E, C, IO>,
359        pos: C::ExpandType,
360        size: C::ExpandType,
361    ) -> ViewExpand<E, C, ReadOnly> {
362        this.__expand_slice_method(scope, pos, size)
363    }
364
365    pub fn __expand_slice_unchecked(
366        scope: &mut Scope,
367        this: ViewExpand<E, C, IO>,
368        pos: C::ExpandType,
369        size: C::ExpandType,
370    ) -> ViewExpand<E, C, ReadOnly> {
371        this.__expand_slice_unchecked_method(scope, pos, size)
372    }
373}
374
375#[cube]
376impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
377
378impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
379    pub fn __expand_slice_method(
380        &self,
381        scope: &mut Scope,
382        pos: C::ExpandType,
383        size: C::ExpandType,
384    ) -> ViewExpand<E, C, ReadOnly> {
385        self.slice(scope, pos, size, true)
386    }
387
388    pub fn __expand_slice_unchecked_method(
389        &self,
390        scope: &mut Scope,
391        pos: C::ExpandType,
392        size: C::ExpandType,
393    ) -> ViewExpand<E, C, ReadOnly> {
394        self.slice(scope, pos, size, false)
395    }
396
397    fn slice(
398        &self,
399        scope: &mut Scope,
400        pos: C::ExpandType,
401        size: C::ExpandType,
402        checked: bool,
403    ) -> ViewExpand<E, C, ReadOnly> {
404        let shape = self.__expand_shape_method(scope);
405        let pos = C::__expand_min(scope, pos, shape.clone());
406        let max_size = C::__expand_sub(scope, shape, pos.clone());
407        let size = C::__expand_min(scope, size, max_size);
408        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
409        self.clone().__expand_view_method(scope, layout.into())
410    }
411}
412
413#[allow(unused_variables)]
414impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
415    /// Write a line to `pos`. The layout handles translation into a concrete index.
416    pub fn write(&self, pos: C, value: E) {
417        unexpanded!()
418    }
419
420    /// Write a line to `pos` if it's in bounds. The layout handles translation into a concrete index.
421    pub fn write_checked(&self, pos: C, value: E) {
422        unexpanded!()
423    }
424
425    /// Interpret this view as a mutable linear slice encompassing the entire view.
426    ///
427    /// # Safety
428    ///
429    /// No checking is done on whether the slice is contiguous in memory.
430    pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
431        unexpanded!()
432    }
433}
434
435impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
436    /// Expand method for [TensorView::write]
437    pub fn __expand_write_method(
438        self,
439        scope: &mut Scope,
440        pos: C::ExpandType,
441        value: ExpandElementTyped<E>,
442    ) {
443        self.inner.write().__expand_write_method(scope, pos, value);
444    }
445
446    /// Expand method for [TensorView::write_checked]
447    pub fn __expand_write_checked_method(
448        self,
449        scope: &mut Scope,
450        pos: C::ExpandType,
451        value: ExpandElementTyped<E>,
452    ) {
453        self.inner
454            .write()
455            .__expand_write_checked_method(scope, pos, value);
456    }
457
458    pub fn __expand_to_linear_slice_mut_method(
459        self,
460        scope: &mut Scope,
461    ) -> SliceExpand<E, ReadWrite> {
462        let shape = self.inner.read().__expand_shape_method(scope);
463        let origin = C::__expand_from_int(scope, shape.clone(), 0);
464        // Inclusive end so clamping works correctly
465        let one = C::__expand_from_int(scope, shape.clone(), 1);
466        let shape = C::__expand_max(scope, shape, one.clone());
467        let end = C::__expand_sub(scope, shape, one);
468        self.inner
469            .write()
470            .__expand_to_linear_slice_mut_method(scope, origin, end)
471    }
472
473    pub(super) fn __expand_to_linear_slice_mut_inner_method(
474        self,
475        scope: &mut Scope,
476        pos: C::ExpandType,
477        end: C::ExpandType,
478    ) -> SliceExpand<E, ReadWrite> {
479        self.inner
480            .write()
481            .__expand_to_linear_slice_mut_method(scope, pos, end)
482    }
483}
484
485impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
486    /// Create a mutable slice starting from `pos`, with `size`.
487    /// The layout handles translation into concrete indices.
488    /// Size and pos will be clamped to the current layout size.
489    pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
490        unexpanded!()
491    }
492
493    /// Create a mutable slice starting from `pos`, with `size`.
494    /// The layout handles translation into concrete indices.
495    /// Size and pos will be clamped to the current layout size.
496    ///
497    /// # Safety
498    /// Access is always unchecked.
499    pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
500        unexpanded!()
501    }
502
503    pub fn __expand_slice_mut(
504        scope: &mut Scope,
505        this: ViewExpand<E, C, ReadWrite>,
506        pos: C::ExpandType,
507        size: C::ExpandType,
508    ) -> ViewExpand<E, C, ReadWrite> {
509        this.__expand_slice_mut_method(scope, pos, size)
510    }
511
512    pub fn __expand_slice_mut_unchecked(
513        scope: &mut Scope,
514        this: ViewExpand<E, C, ReadWrite>,
515        pos: C::ExpandType,
516        size: C::ExpandType,
517    ) -> ViewExpand<E, C, ReadWrite> {
518        this.__expand_slice_mut_unchecked_method(scope, pos, size)
519    }
520}
521
522#[cube]
523impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
524
525impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
526    pub fn __expand_slice_mut_method(
527        &self,
528        scope: &mut Scope,
529        pos: C::ExpandType,
530        size: C::ExpandType,
531    ) -> ViewExpand<E, C, ReadWrite> {
532        self.slice_mut(scope, pos, size, true)
533    }
534
535    pub fn __expand_slice_mut_unchecked_method(
536        &self,
537        scope: &mut Scope,
538        pos: C::ExpandType,
539        size: C::ExpandType,
540    ) -> ViewExpand<E, C, ReadWrite> {
541        self.slice_mut(scope, pos, size, false)
542    }
543
544    fn slice_mut(
545        &self,
546        scope: &mut Scope,
547        pos: C::ExpandType,
548        size: C::ExpandType,
549        checked: bool,
550    ) -> ViewExpand<E, C, ReadWrite> {
551        let shape = self.__expand_shape_method(scope);
552        let pos = C::__expand_min(scope, pos, shape.clone());
553        let max_size = C::__expand_sub(scope, shape, pos.clone());
554        let size = C::__expand_min(scope, size, max_size);
555        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
556        self.clone().__expand_view_mut_method(scope, layout.into())
557    }
558}
559
560impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
561    ///.Execute a TMA load into shared memory, if the underlying storage supports it.
562    /// Panics if it's unsupported.
563    pub fn tensor_map_load(
564        &self,
565        _barrier: &Barrier,
566        _shared_memory: &mut Slice<E, ReadWrite>,
567        _pos: C,
568    ) -> View<E, C, ReadWrite> {
569        unexpanded!()
570    }
571}
572
573impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
574    pub fn __expand_tensor_map_load_method(
575        self,
576        scope: &mut Scope,
577        barrier: BarrierExpand,
578        shared_memory: SliceExpand<E, ReadWrite>,
579        pos: C::ExpandType,
580    ) {
581        self.inner
582            .read()
583            .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
584    }
585}
586
587impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
588    ///.Execute a TMA store into global memory, if the underlying storage supports it.
589    /// Panics if it's unsupported.
590    pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
591        unexpanded!()
592    }
593}
594
595impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
596    pub fn __expand_tensor_map_store_method(
597        self,
598        scope: &mut Scope,
599        shared_memory: SliceExpand<E, ReadOnly>,
600        pos: C::ExpandType,
601    ) {
602        self.inner
603            .write()
604            .__expand_tensor_map_store_method(scope, shared_memory, pos)
605    }
606}