cubecl_std/tensor/view/base.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5    self as cubecl,
6    prelude::barrier::{Barrier, BarrierExpand},
7    unexpanded,
8};
9
10use crate::{
11    CubeOption, CubeOptionExpand,
12    tensor::{
13        ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
14        VirtualView, VirtualViewMut,
15        layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
16    },
17};
18
19/// A conceptual view of an underlying linear storage.
20/// Allows abstract indexing in multiple dimensions, without having to know the data layout or
21/// location.
22#[derive(Clone)]
23pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
24    _layout: PhantomData<C>,
25    _ty: PhantomData<(E, IO)>,
26}
27
28// `View` is a dummy type so it's always send/sync
29unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
30unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
31impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
32
33#[derive(Clone)]
34pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
35    Read(Arc<dyn ViewOperationsExpand<E, C>>),
36    ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
37}
38
39impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
40    /// Dereference in read mode
41    pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
42        match self {
43            ViewType::Read(list) => &**list,
44            ViewType::ReadWrite(list) => &**list,
45        }
46    }
47
48    /// Dereference in write mode
49    pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
50        match self {
51            ViewType::Read(_) => panic!("Can't write to readonly list"),
52            ViewType::ReadWrite(list) => &**list,
53        }
54    }
55}
56
57/// Expand type of [TensorView]
58#[derive(Clone)]
59pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
60    pub(super) inner: ViewType<E, C>,
61    pub(super) _io: PhantomData<IO>,
62}
63
64impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
65    type ExpandType = ViewExpand<E, C, IO>;
66}
67
68impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
69    fn into_mut(self, _scope: &mut Scope) -> Self {
70        self
71    }
72}
73
74impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
75
76impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
77    /// Create a new tensor view from an underlying concrete storage and a layout to map it into
78    /// the target coordinate space
79    #[allow(unused_variables)]
80    pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
81        view: &V,
82        layout: impl Into<VirtualLayout<C, S>>,
83    ) -> Self {
84        View {
85            _layout: PhantomData,
86            _ty: PhantomData,
87        }
88    }
89
90    /// Expand function for [TensorView::new]
91    pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
92        scope: &mut Scope,
93        view: V::ExpandType,
94        layout: VirtualLayoutExpand<C, S>,
95    ) -> ViewExpand<E, C, ReadOnly> {
96        ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
97    }
98}
99
100impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
101    pub fn view<T: Coordinates>(
102        &self,
103        _layout: impl Into<VirtualLayout<T, C>>,
104    ) -> View<E, T, ReadOnly> {
105        unexpanded!()
106    }
107
108    pub fn __expand_view<T: Coordinates + 'static>(
109        scope: &mut Scope,
110        this: ViewExpand<E, C, IO>,
111        layout: VirtualLayoutExpand<T, C>,
112    ) -> ViewExpand<E, T, ReadOnly> {
113        this.__expand_view_method(scope, layout)
114    }
115}
116
117impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
118    pub fn __expand_view_method<T: Coordinates + 'static>(
119        self,
120        scope: &mut Scope,
121        layout: VirtualLayoutExpand<T, C>,
122    ) -> ViewExpand<E, T, ReadOnly> {
123        View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
124    }
125
126    pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
127        ViewExpand {
128            inner: ViewType::Read(Arc::new(view)),
129            _io: PhantomData,
130        }
131    }
132
133    pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
134        ViewExpand {
135            inner: ViewType::ReadWrite(Arc::new(view)),
136            _io: PhantomData,
137        }
138    }
139}
140
141impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
142    pub fn view_mut<T: Coordinates>(
143        &self,
144        _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
145    ) -> View<E, T, ReadWrite> {
146        unexpanded!()
147    }
148
149    pub fn __expand_view_mut<T: Coordinates + 'static>(
150        scope: &mut Scope,
151        this: ViewExpand<E, C, ReadWrite>,
152        layout: VirtualLayoutExpand<T, C>,
153    ) -> ViewExpand<E, T, ReadWrite> {
154        this.__expand_view_mut_method(scope, layout)
155    }
156}
157
158impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
159    pub fn __expand_view_mut_method<T: Coordinates + 'static>(
160        self,
161        scope: &mut Scope,
162        layout: VirtualLayoutExpand<T, C>,
163    ) -> ViewExpand<E, T, ReadWrite> {
164        View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
165    }
166}
167
168impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
169    /// Create a new mutable tensor view from an underlying concrete storage and a layout to map it
170    /// into the target coordinate space
171    pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
172        _view: &mut V,
173        _layout: impl Into<VirtualLayout<C, S>>,
174    ) -> View<E, C, ReadWrite> {
175        View {
176            _ty: PhantomData,
177            _layout: PhantomData,
178        }
179    }
180
181    /// Expand function for [TensorView::new_mut]
182    pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
183        scope: &mut Scope,
184        view: V::ExpandType,
185        layout: VirtualLayoutExpand<C, S>,
186    ) -> ViewExpand<E, C, ReadWrite> {
187        ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
188            scope, view, layout,
189        ))
190    }
191}
192
193impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
194    /// Calls [Layout::shape] on the view's layout
195    pub fn shape(&self) -> C {
196        unexpanded!()
197    }
198
199    /// Calls [Layout::is_in_bounds] on the view's layout
200    pub fn is_in_bounds(&self, _pos: C) -> bool {
201        unexpanded!()
202    }
203
204    pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
205        this.__expand_shape_method(scope)
206    }
207
208    pub fn __expand_is_in_bounds(
209        scope: &mut Scope,
210        this: ViewExpand<E, C, IO>,
211        pos: C::ExpandType,
212    ) -> ExpandElementTyped<bool> {
213        this.__expand_is_in_bounds_method(scope, pos)
214    }
215}
216
217impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
218    pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
219        self.inner.read().__expand_shape_method(scope)
220    }
221
222    pub fn __expand_is_in_bounds_method(
223        &self,
224        scope: &mut Scope,
225        pos: C::ExpandType,
226    ) -> ExpandElementTyped<bool> {
227        self.inner.read().__expand_is_in_bounds_method(scope, pos)
228    }
229}
230
231#[allow(unused_variables)]
232impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
233    /// Read a line at `pos`. The layout handles translation into a concrete index.
234    pub fn read(&self, pos: C) -> E {
235        unexpanded!()
236    }
237
238    /// Read a line at `pos`. The layout handles translation into a concrete index.
239    /// Reading is done unchecked
240    pub fn read_unchecked(&self, pos: C) -> E {
241        unexpanded!()
242    }
243
244    /// Read a line at `pos` if it's in bounds. The layout handles translation into a concrete index.
245    pub fn read_checked(&self, pos: C) -> E {
246        unexpanded!()
247    }
248
249    /// Read a line at `pos` if it's in bounds, returning `mask_value` otherwise. The layout handles translation into a concrete index.
250    pub fn read_masked(&self, pos: C, mask_value: E) -> E {
251        unexpanded!()
252    }
253
254    /// Interpret this view as a linear slice encompassing the entire view.
255    ///
256    /// # Safety
257    ///
258    /// No checking is done on whether the slice is contiguous in memory.
259    pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
260        unexpanded!()
261    }
262
263    pub fn line_size(&self) -> u32 {
264        unexpanded!()
265    }
266}
267
268impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
269    /// Expand method for [TensorView::read]
270    pub fn __expand_read_method(
271        self,
272        scope: &mut Scope,
273        pos: C::ExpandType,
274    ) -> ExpandElementTyped<E> {
275        self.inner.read().__expand_read_method(scope, pos)
276    }
277
278    /// Expand method for [TensorView::read_unchecked]
279    pub fn __expand_read_unchecked_method(
280        self,
281        scope: &mut Scope,
282        pos: C::ExpandType,
283    ) -> ExpandElementTyped<E> {
284        self.inner.read().__expand_read_unchecked_method(scope, pos)
285    }
286
287    /// Expand method for [TensorView::read_checked]
288    pub fn __expand_read_checked_method(
289        self,
290        scope: &mut Scope,
291        pos: C::ExpandType,
292    ) -> ExpandElementTyped<E> {
293        self.inner.read().__expand_read_checked_method(scope, pos)
294    }
295
296    /// Expand method for [TensorView::read_masked]
297    pub fn __expand_read_masked_method(
298        self,
299        scope: &mut Scope,
300        pos: C::ExpandType,
301        mask_value: E::ExpandType,
302    ) -> ExpandElementTyped<E> {
303        self.inner
304            .read()
305            .__expand_read_masked_method(scope, pos, mask_value)
306    }
307
308    /// Expand method for [TensorView::line_size]
309    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
310        self.inner.read().line_size()
311    }
312
313    pub fn line_size(&self) -> u32 {
314        self.inner.read().line_size()
315    }
316
317    pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
318        let shape = self.inner.read().__expand_shape_method(scope);
319        let origin = C::__expand_from_int(scope, shape.clone(), 0);
320        // Inclusive end so clamping works correctly
321        let one = C::__expand_from_int(scope, shape.clone(), 1);
322        let shape = C::__expand_max(scope, shape, one.clone());
323        let end = C::__expand_sub(scope, shape, one);
324        self.inner
325            .read()
326            .__expand_to_linear_slice_method(scope, origin, end)
327    }
328
329    pub(super) fn __expand_to_linear_slice_inner_method(
330        self,
331        scope: &mut Scope,
332        pos: C::ExpandType,
333        end: C::ExpandType,
334    ) -> SliceExpand<E, ReadOnly> {
335        self.inner
336            .read()
337            .__expand_to_linear_slice_method(scope, pos, end)
338    }
339}
340
341impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
342    /// Create a slice starting from `pos`, with `size`.
343    /// The layout handles translation into concrete indices.
344    /// Size will be clamped to the current layout size.
345    pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
346        unexpanded!()
347    }
348
349    /// Create a slice starting from `pos`, with `size`.
350    /// The layout handles translation into concrete indices.
351    /// Size and pos will be clamped to the current layout size.
352    /// #Safety
353    /// Access is always unchecked
354    pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
355        unexpanded!()
356    }
357
358    pub fn __expand_slice(
359        scope: &mut Scope,
360        this: ViewExpand<E, C, IO>,
361        pos: C::ExpandType,
362        size: C::ExpandType,
363    ) -> ViewExpand<E, C, ReadOnly> {
364        this.__expand_slice_method(scope, pos, size)
365    }
366
367    pub fn __expand_slice_unchecked(
368        scope: &mut Scope,
369        this: ViewExpand<E, C, IO>,
370        pos: C::ExpandType,
371        size: C::ExpandType,
372    ) -> ViewExpand<E, C, ReadOnly> {
373        this.__expand_slice_unchecked_method(scope, pos, size)
374    }
375}
376
377#[cube]
378impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
379
380impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
381    pub fn __expand_slice_method(
382        &self,
383        scope: &mut Scope,
384        pos: C::ExpandType,
385        size: C::ExpandType,
386    ) -> ViewExpand<E, C, ReadOnly> {
387        self.slice(scope, pos, size, true)
388    }
389
390    pub fn __expand_slice_unchecked_method(
391        &self,
392        scope: &mut Scope,
393        pos: C::ExpandType,
394        size: C::ExpandType,
395    ) -> ViewExpand<E, C, ReadOnly> {
396        self.slice(scope, pos, size, false)
397    }
398
399    fn slice(
400        &self,
401        scope: &mut Scope,
402        pos: C::ExpandType,
403        size: C::ExpandType,
404        checked: bool,
405    ) -> ViewExpand<E, C, ReadOnly> {
406        let shape = self.__expand_shape_method(scope);
407        let pos = C::__expand_min(scope, pos, shape.clone());
408        let max_size = C::__expand_sub(scope, shape, pos.clone());
409        let size = C::__expand_min(scope, size, max_size);
410        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
411        self.clone().__expand_view_method(scope, layout.into())
412    }
413}
414
415#[allow(unused_variables)]
416impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
417    /// Write a line to `pos`. The layout handles translation into a concrete index.
418    pub fn write(&self, pos: C, value: E) {
419        unexpanded!()
420    }
421
422    /// Write a line to `pos` if it's in bounds. The layout handles translation into a concrete index.
423    pub fn write_checked(&self, pos: C, value: E) {
424        unexpanded!()
425    }
426
427    /// Interpret this view as a mutable linear slice encompassing the entire view.
428    ///
429    /// # Safety
430    ///
431    /// No checking is done on whether the slice is contiguous in memory.
432    pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
433        unexpanded!()
434    }
435}
436
437impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
438    /// Expand method for [TensorView::write]
439    pub fn __expand_write_method(
440        self,
441        scope: &mut Scope,
442        pos: C::ExpandType,
443        value: ExpandElementTyped<E>,
444    ) {
445        self.inner.write().__expand_write_method(scope, pos, value);
446    }
447
448    /// Expand method for [TensorView::write_checked]
449    pub fn __expand_write_checked_method(
450        self,
451        scope: &mut Scope,
452        pos: C::ExpandType,
453        value: ExpandElementTyped<E>,
454    ) {
455        self.inner
456            .write()
457            .__expand_write_checked_method(scope, pos, value);
458    }
459
460    pub fn __expand_to_linear_slice_mut_method(
461        self,
462        scope: &mut Scope,
463    ) -> SliceExpand<E, ReadWrite> {
464        let shape = self.inner.read().__expand_shape_method(scope);
465        let origin = C::__expand_from_int(scope, shape.clone(), 0);
466        // Inclusive end so clamping works correctly
467        let one = C::__expand_from_int(scope, shape.clone(), 1);
468        let shape = C::__expand_max(scope, shape, one.clone());
469        let end = C::__expand_sub(scope, shape, one);
470        self.inner
471            .write()
472            .__expand_to_linear_slice_mut_method(scope, origin, end)
473    }
474
475    pub(super) fn __expand_to_linear_slice_mut_inner_method(
476        self,
477        scope: &mut Scope,
478        pos: C::ExpandType,
479        end: C::ExpandType,
480    ) -> SliceExpand<E, ReadWrite> {
481        self.inner
482            .write()
483            .__expand_to_linear_slice_mut_method(scope, pos, end)
484    }
485}
486
487impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
488    /// Create a mutable slice starting from `pos`, with `size`.
489    /// The layout handles translation into concrete indices.
490    /// Size and pos will be clamped to the current layout size.
491    pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
492        unexpanded!()
493    }
494
495    /// Create a mutable slice starting from `pos`, with `size`.
496    /// The layout handles translation into concrete indices.
497    /// Size and pos will be clamped to the current layout size.
498    ///
499    /// # Safety
500    /// Access is always unchecked.
501    pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
502        unexpanded!()
503    }
504
505    pub fn __expand_slice_mut(
506        scope: &mut Scope,
507        this: ViewExpand<E, C, ReadWrite>,
508        pos: C::ExpandType,
509        size: C::ExpandType,
510    ) -> ViewExpand<E, C, ReadWrite> {
511        this.__expand_slice_mut_method(scope, pos, size)
512    }
513
514    pub fn __expand_slice_mut_unchecked(
515        scope: &mut Scope,
516        this: ViewExpand<E, C, ReadWrite>,
517        pos: C::ExpandType,
518        size: C::ExpandType,
519    ) -> ViewExpand<E, C, ReadWrite> {
520        this.__expand_slice_mut_unchecked_method(scope, pos, size)
521    }
522}
523
524#[cube]
525impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
526
527impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
528    pub fn __expand_slice_mut_method(
529        &self,
530        scope: &mut Scope,
531        pos: C::ExpandType,
532        size: C::ExpandType,
533    ) -> ViewExpand<E, C, ReadWrite> {
534        self.slice_mut(scope, pos, size, true)
535    }
536
537    pub fn __expand_slice_mut_unchecked_method(
538        &self,
539        scope: &mut Scope,
540        pos: C::ExpandType,
541        size: C::ExpandType,
542    ) -> ViewExpand<E, C, ReadWrite> {
543        self.slice_mut(scope, pos, size, false)
544    }
545
546    fn slice_mut(
547        &self,
548        scope: &mut Scope,
549        pos: C::ExpandType,
550        size: C::ExpandType,
551        checked: bool,
552    ) -> ViewExpand<E, C, ReadWrite> {
553        let shape = self.__expand_shape_method(scope);
554        let pos = C::__expand_min(scope, pos, shape.clone());
555        let max_size = C::__expand_sub(scope, shape, pos.clone());
556        let size = C::__expand_min(scope, size, max_size);
557        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
558        self.clone().__expand_view_mut_method(scope, layout.into())
559    }
560}
561
562impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
563    ///.Execute a TMA load into shared memory, if the underlying storage supports it.
564    /// Panics if it's unsupported.
565    pub fn tensor_map_load(
566        &self,
567        _barrier: &Barrier,
568        _shared_memory: &mut Slice<E, ReadWrite>,
569        _pos: C,
570    ) -> View<E, C, ReadWrite> {
571        unexpanded!()
572    }
573
574    /// Fetches the underlying tensor map if present, to override the layout and apply a different
575    /// base dimensionality.
576    pub fn as_tensor_map(&self) -> CubeOption<TensorMap<E>> {
577        unexpanded!()
578    }
579}
580
581impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
582    pub fn __expand_tensor_map_load_method(
583        self,
584        scope: &mut Scope,
585        barrier: BarrierExpand,
586        shared_memory: SliceExpand<E, ReadWrite>,
587        pos: C::ExpandType,
588    ) {
589        self.inner
590            .read()
591            .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
592    }
593
594    pub fn __expand_as_tensor_map_method(
595        self,
596        scope: &mut Scope,
597    ) -> CubeOptionExpand<TensorMap<E>> {
598        self.inner.read().__expand_as_tensor_map_method(scope)
599    }
600}
601
602impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
603    ///.Execute a TMA store into global memory, if the underlying storage supports it.
604    /// Panics if it's unsupported.
605    pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
606        unexpanded!()
607    }
608}
609
610impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
611    pub fn __expand_tensor_map_store_method(
612        self,
613        scope: &mut Scope,
614        shared_memory: SliceExpand<E, ReadOnly>,
615        pos: C::ExpandType,
616    ) {
617        self.inner
618            .write()
619            .__expand_tensor_map_store_method(scope, shared_memory, pos)
620    }
621}