cubecl_std/tensor/view/
base.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5    self as cubecl,
6    prelude::barrier::{Barrier, BarrierExpand},
7    unexpanded,
8};
9
10use crate::tensor::{
11    ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, VirtualView,
12    VirtualViewMut,
13    layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
14};
15
16/// A conceptual view of an underlying linear storage.
17/// Allows abstract indexing in multiple dimensions, without having to know the data layout or
18/// location.
19#[derive(Clone)]
20pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
21    _layout: PhantomData<C>,
22    _ty: PhantomData<(E, IO)>,
23}
24
25// `View` is a dummy type so it's always send/sync
26unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
27unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
28impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
29
30#[derive(Clone)]
31pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
32    Read(Arc<dyn ViewOperationsExpand<E, C>>),
33    ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
34}
35
36impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
37    /// Dereference in read mode
38    pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
39        match self {
40            ViewType::Read(list) => &**list,
41            ViewType::ReadWrite(list) => &**list,
42        }
43    }
44
45    /// Dereference in write mode
46    pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
47        match self {
48            ViewType::Read(_) => panic!("Can't write to readonly list"),
49            ViewType::ReadWrite(list) => &**list,
50        }
51    }
52}
53
54/// Expand type of [TensorView]
55#[derive(Clone)]
56pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
57    pub(super) inner: ViewType<E, C>,
58    pub(super) _io: PhantomData<IO>,
59}
60
61impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
62    type ExpandType = ViewExpand<E, C, IO>;
63}
64
65impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
66    fn into_mut(self, _scope: &mut Scope) -> Self {
67        self
68    }
69}
70
71impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
72
73impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
74    /// Create a new tensor view from an underlying concrete storage and a layout to map it into
75    /// the target coordinate space
76    #[allow(unused_variables)]
77    pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
78        view: &V,
79        layout: impl Into<VirtualLayout<C, S>>,
80    ) -> Self {
81        View {
82            _layout: PhantomData,
83            _ty: PhantomData,
84        }
85    }
86
87    /// Expand function for [TensorView::new]
88    pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
89        scope: &mut Scope,
90        view: V::ExpandType,
91        layout: VirtualLayoutExpand<C, S>,
92    ) -> ViewExpand<E, C, ReadOnly> {
93        ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
94    }
95}
96
97impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
98    pub fn view<T: Coordinates>(
99        &self,
100        _layout: impl Into<VirtualLayout<T, C>>,
101    ) -> View<E, T, ReadOnly> {
102        unexpanded!()
103    }
104
105    pub fn __expand_view<T: Coordinates + 'static>(
106        scope: &mut Scope,
107        this: ViewExpand<E, C, IO>,
108        layout: VirtualLayoutExpand<T, C>,
109    ) -> ViewExpand<E, T, ReadOnly> {
110        this.__expand_view_method(scope, layout)
111    }
112}
113
114impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
115    pub fn __expand_view_method<T: Coordinates + 'static>(
116        self,
117        scope: &mut Scope,
118        layout: VirtualLayoutExpand<T, C>,
119    ) -> ViewExpand<E, T, ReadOnly> {
120        View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
121    }
122
123    pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
124        ViewExpand {
125            inner: ViewType::Read(Arc::new(view)),
126            _io: PhantomData,
127        }
128    }
129
130    pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
131        ViewExpand {
132            inner: ViewType::ReadWrite(Arc::new(view)),
133            _io: PhantomData,
134        }
135    }
136}
137
138impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
139    pub fn view_mut<T: Coordinates>(
140        &self,
141        _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
142    ) -> View<E, T, ReadWrite> {
143        unexpanded!()
144    }
145
146    pub fn __expand_view_mut<T: Coordinates + 'static>(
147        scope: &mut Scope,
148        this: ViewExpand<E, C, ReadWrite>,
149        layout: VirtualLayoutExpand<T, C>,
150    ) -> ViewExpand<E, T, ReadWrite> {
151        this.__expand_view_mut_method(scope, layout)
152    }
153}
154
155impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
156    pub fn __expand_view_mut_method<T: Coordinates + 'static>(
157        self,
158        scope: &mut Scope,
159        layout: VirtualLayoutExpand<T, C>,
160    ) -> ViewExpand<E, T, ReadWrite> {
161        View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
162    }
163}
164
165impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
166    /// Create a new mutable tensor view from an underlying concrete storage and a layout to map it
167    /// into the target coordinate space
168    pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
169        _view: &mut V,
170        _layout: impl Into<VirtualLayout<C, S>>,
171    ) -> View<E, C, ReadWrite> {
172        View {
173            _ty: PhantomData,
174            _layout: PhantomData,
175        }
176    }
177
178    /// Expand function for [TensorView::new_mut]
179    pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
180        scope: &mut Scope,
181        view: V::ExpandType,
182        layout: VirtualLayoutExpand<C, S>,
183    ) -> ViewExpand<E, C, ReadWrite> {
184        ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
185            scope, view, layout,
186        ))
187    }
188}
189
190impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
191    /// Calls [Layout::shape] on the view's layout
192    pub fn shape(&self) -> C {
193        unexpanded!()
194    }
195
196    /// Calls [Layout::is_in_bounds] on the view's layout
197    pub fn is_in_bounds(&self, _pos: C) -> bool {
198        unexpanded!()
199    }
200
201    pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
202        this.__expand_shape_method(scope)
203    }
204
205    pub fn __expand_is_in_bounds(
206        scope: &mut Scope,
207        this: ViewExpand<E, C, IO>,
208        pos: C::ExpandType,
209    ) -> ExpandElementTyped<bool> {
210        this.__expand_is_in_bounds_method(scope, pos)
211    }
212}
213
214impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
215    pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
216        self.inner.read().__expand_shape_method(scope)
217    }
218
219    pub fn __expand_is_in_bounds_method(
220        &self,
221        scope: &mut Scope,
222        pos: C::ExpandType,
223    ) -> ExpandElementTyped<bool> {
224        self.inner.read().__expand_is_in_bounds_method(scope, pos)
225    }
226}
227
228#[allow(unused_variables)]
229impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
230    /// Read a line at `pos`. The layout handles translation into a concrete index.
231    pub fn read(&self, pos: C) -> E {
232        unexpanded!()
233    }
234
235    /// Read a line at `pos`. The layout handles translation into a concrete index.
236    /// Reading is done unchecked
237    pub fn read_unchecked(&self, pos: C) -> E {
238        unexpanded!()
239    }
240
241    /// Read a line at `pos` if it's in bounds. The layout handles translation into a concrete index.
242    pub fn read_checked(&self, pos: C) -> E {
243        unexpanded!()
244    }
245
246    /// Read a line at `pos` if it's in bounds, returning `mask_value` otherwise. The layout handles translation into a concrete index.
247    pub fn read_masked(&self, pos: C, mask_value: E) -> E {
248        unexpanded!()
249    }
250
251    /// Interpret this view as a linear slice encompassing the entire view.
252    ///
253    /// # Safety
254    ///
255    /// No checking is done on whether the slice is contiguous in memory.
256    pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
257        unexpanded!()
258    }
259
260    pub fn line_size(&self) -> u32 {
261        unexpanded!()
262    }
263}
264
265impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
266    /// Expand method for [TensorView::read]
267    pub fn __expand_read_method(
268        self,
269        scope: &mut Scope,
270        pos: C::ExpandType,
271    ) -> ExpandElementTyped<E> {
272        self.inner.read().__expand_read_method(scope, pos)
273    }
274
275    /// Expand method for [TensorView::read_unchecked]
276    pub fn __expand_read_unchecked_method(
277        self,
278        scope: &mut Scope,
279        pos: C::ExpandType,
280    ) -> ExpandElementTyped<E> {
281        self.inner.read().__expand_read_unchecked_method(scope, pos)
282    }
283
284    /// Expand method for [TensorView::read_checked]
285    pub fn __expand_read_checked_method(
286        self,
287        scope: &mut Scope,
288        pos: C::ExpandType,
289    ) -> ExpandElementTyped<E> {
290        self.inner.read().__expand_read_checked_method(scope, pos)
291    }
292
293    /// Expand method for [TensorView::read_masked]
294    pub fn __expand_read_masked_method(
295        self,
296        scope: &mut Scope,
297        pos: C::ExpandType,
298        mask_value: E::ExpandType,
299    ) -> ExpandElementTyped<E> {
300        self.inner
301            .read()
302            .__expand_read_masked_method(scope, pos, mask_value)
303    }
304
305    /// Expand method for [TensorView::line_size]
306    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
307        self.inner.read().line_size()
308    }
309
310    pub fn line_size(&self) -> u32 {
311        self.inner.read().line_size()
312    }
313
314    pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
315        let shape = self.inner.read().__expand_shape_method(scope);
316        let origin = C::__expand_from_int(scope, shape.clone(), 0);
317        // Inclusive end so clamping works correctly
318        let one = C::__expand_from_int(scope, shape.clone(), 1);
319        let shape = C::__expand_max(scope, shape, one.clone());
320        let end = C::__expand_sub(scope, shape, one);
321        self.inner
322            .read()
323            .__expand_to_linear_slice_method(scope, origin, end)
324    }
325
326    pub(super) fn __expand_to_linear_slice_inner_method(
327        self,
328        scope: &mut Scope,
329        pos: C::ExpandType,
330        end: C::ExpandType,
331    ) -> SliceExpand<E, ReadOnly> {
332        self.inner
333            .read()
334            .__expand_to_linear_slice_method(scope, pos, end)
335    }
336}
337
338impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
339    /// Create a slice starting from `pos`, with `size`.
340    /// The layout handles translation into concrete indices.
341    /// Size will be clamped to the current layout size.
342    pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
343        unexpanded!()
344    }
345
346    /// Create a slice starting from `pos`, with `size`.
347    /// The layout handles translation into concrete indices.
348    /// Size and pos will be clamped to the current layout size.
349    /// #Safety
350    /// Access is always unchecked
351    pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
352        unexpanded!()
353    }
354
355    pub fn __expand_slice(
356        scope: &mut Scope,
357        this: ViewExpand<E, C, IO>,
358        pos: C::ExpandType,
359        size: C::ExpandType,
360    ) -> ViewExpand<E, C, ReadOnly> {
361        this.__expand_slice_method(scope, pos, size)
362    }
363
364    pub fn __expand_slice_unchecked(
365        scope: &mut Scope,
366        this: ViewExpand<E, C, IO>,
367        pos: C::ExpandType,
368        size: C::ExpandType,
369    ) -> ViewExpand<E, C, ReadOnly> {
370        this.__expand_slice_unchecked_method(scope, pos, size)
371    }
372}
373
374#[cube]
375impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
376
377impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
378    pub fn __expand_slice_method(
379        &self,
380        scope: &mut Scope,
381        pos: C::ExpandType,
382        size: C::ExpandType,
383    ) -> ViewExpand<E, C, ReadOnly> {
384        self.slice(scope, pos, size, true)
385    }
386
387    pub fn __expand_slice_unchecked_method(
388        &self,
389        scope: &mut Scope,
390        pos: C::ExpandType,
391        size: C::ExpandType,
392    ) -> ViewExpand<E, C, ReadOnly> {
393        self.slice(scope, pos, size, false)
394    }
395
396    fn slice(
397        &self,
398        scope: &mut Scope,
399        pos: C::ExpandType,
400        size: C::ExpandType,
401        checked: bool,
402    ) -> ViewExpand<E, C, ReadOnly> {
403        let shape = self.__expand_shape_method(scope);
404        let pos = C::__expand_min(scope, pos, shape.clone());
405        let max_size = C::__expand_sub(scope, shape, pos.clone());
406        let size = C::__expand_min(scope, size, max_size);
407        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
408        self.clone().__expand_view_method(scope, layout.into())
409    }
410}
411
412#[allow(unused_variables)]
413impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
414    /// Write a line to `pos`. The layout handles translation into a concrete index.
415    pub fn write(&self, pos: C, value: E) {
416        unexpanded!()
417    }
418
419    /// Write a line to `pos` if it's in bounds. The layout handles translation into a concrete index.
420    pub fn write_checked(&self, pos: C, value: E) {
421        unexpanded!()
422    }
423
424    /// Interpret this view as a mutable linear slice encompassing the entire view.
425    ///
426    /// # Safety
427    ///
428    /// No checking is done on whether the slice is contiguous in memory.
429    pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
430        unexpanded!()
431    }
432}
433
434impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
435    /// Expand method for [TensorView::write]
436    pub fn __expand_write_method(
437        self,
438        scope: &mut Scope,
439        pos: C::ExpandType,
440        value: ExpandElementTyped<E>,
441    ) {
442        self.inner.write().__expand_write_method(scope, pos, value);
443    }
444
445    /// Expand method for [TensorView::write_checked]
446    pub fn __expand_write_checked_method(
447        self,
448        scope: &mut Scope,
449        pos: C::ExpandType,
450        value: ExpandElementTyped<E>,
451    ) {
452        self.inner
453            .write()
454            .__expand_write_checked_method(scope, pos, value);
455    }
456
457    pub fn __expand_to_linear_slice_mut_method(
458        self,
459        scope: &mut Scope,
460    ) -> SliceExpand<E, ReadWrite> {
461        let shape = self.inner.read().__expand_shape_method(scope);
462        let origin = C::__expand_from_int(scope, shape.clone(), 0);
463        // Inclusive end so clamping works correctly
464        let one = C::__expand_from_int(scope, shape.clone(), 1);
465        let shape = C::__expand_max(scope, shape, one.clone());
466        let end = C::__expand_sub(scope, shape, one);
467        self.inner
468            .write()
469            .__expand_to_linear_slice_mut_method(scope, origin, end)
470    }
471
472    pub(super) fn __expand_to_linear_slice_mut_inner_method(
473        self,
474        scope: &mut Scope,
475        pos: C::ExpandType,
476        end: C::ExpandType,
477    ) -> SliceExpand<E, ReadWrite> {
478        self.inner
479            .write()
480            .__expand_to_linear_slice_mut_method(scope, pos, end)
481    }
482}
483
484impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
485    /// Create a mutable slice starting from `pos`, with `size`.
486    /// The layout handles translation into concrete indices.
487    /// Size and pos will be clamped to the current layout size.
488    pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
489        unexpanded!()
490    }
491
492    /// Create a mutable slice starting from `pos`, with `size`.
493    /// The layout handles translation into concrete indices.
494    /// Size and pos will be clamped to the current layout size.
495    ///
496    /// # Safety
497    /// Access is always unchecked.
498    pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
499        unexpanded!()
500    }
501
502    pub fn __expand_slice_mut(
503        scope: &mut Scope,
504        this: ViewExpand<E, C, ReadWrite>,
505        pos: C::ExpandType,
506        size: C::ExpandType,
507    ) -> ViewExpand<E, C, ReadWrite> {
508        this.__expand_slice_mut_method(scope, pos, size)
509    }
510
511    pub fn __expand_slice_mut_unchecked(
512        scope: &mut Scope,
513        this: ViewExpand<E, C, ReadWrite>,
514        pos: C::ExpandType,
515        size: C::ExpandType,
516    ) -> ViewExpand<E, C, ReadWrite> {
517        this.__expand_slice_mut_unchecked_method(scope, pos, size)
518    }
519}
520
521#[cube]
522impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
523
524impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
525    pub fn __expand_slice_mut_method(
526        &self,
527        scope: &mut Scope,
528        pos: C::ExpandType,
529        size: C::ExpandType,
530    ) -> ViewExpand<E, C, ReadWrite> {
531        self.slice_mut(scope, pos, size, true)
532    }
533
534    pub fn __expand_slice_mut_unchecked_method(
535        &self,
536        scope: &mut Scope,
537        pos: C::ExpandType,
538        size: C::ExpandType,
539    ) -> ViewExpand<E, C, ReadWrite> {
540        self.slice_mut(scope, pos, size, false)
541    }
542
543    fn slice_mut(
544        &self,
545        scope: &mut Scope,
546        pos: C::ExpandType,
547        size: C::ExpandType,
548        checked: bool,
549    ) -> ViewExpand<E, C, ReadWrite> {
550        let shape = self.__expand_shape_method(scope);
551        let pos = C::__expand_min(scope, pos, shape.clone());
552        let max_size = C::__expand_sub(scope, shape, pos.clone());
553        let size = C::__expand_min(scope, size, max_size);
554        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
555        self.clone().__expand_view_mut_method(scope, layout.into())
556    }
557}
558
559impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
560    ///.Execute a TMA load into shared memory, if the underlying storage supports it.
561    /// Panics if it's unsupported.
562    pub fn tensor_map_load(
563        &self,
564        _barrier: &Barrier,
565        _shared_memory: &mut Slice<E, ReadWrite>,
566        _pos: C,
567    ) -> View<E, C, ReadWrite> {
568        unexpanded!()
569    }
570}
571
572impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
573    pub fn __expand_tensor_map_load_method(
574        self,
575        scope: &mut Scope,
576        barrier: BarrierExpand,
577        shared_memory: SliceExpand<E, ReadWrite>,
578        pos: C::ExpandType,
579    ) {
580        self.inner
581            .read()
582            .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
583    }
584}
585
586impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
587    ///.Execute a TMA store into global memory, if the underlying storage supports it.
588    /// Panics if it's unsupported.
589    pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
590        unexpanded!()
591    }
592}
593
594impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
595    pub fn __expand_tensor_map_store_method(
596        self,
597        scope: &mut Scope,
598        shared_memory: SliceExpand<E, ReadOnly>,
599        pos: C::ExpandType,
600    ) {
601        self.inner
602            .write()
603            .__expand_tensor_map_store_method(scope, shared_memory, pos)
604    }
605}