Skip to main content

cubecl_std/tensor/view/
base.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use cubecl::prelude::*;
4use cubecl_core::{
5    self as cubecl,
6    ir::VectorSize,
7    prelude::barrier::{Barrier, BarrierExpand},
8    unexpanded,
9};
10
11use crate::tensor::{
12    ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, VirtualView,
13    VirtualViewMut,
14    layout::{Coordinates, Layout, VirtualLayout, VirtualLayoutExpand, slice::SliceLayout},
15};
16
17/// A conceptual view of an underlying linear storage.
18/// Allows abstract indexing in multiple dimensions, without having to know the data layout or
19/// location.
20#[derive(Clone)]
21pub struct View<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
22    _layout: PhantomData<C>,
23    _ty: PhantomData<(E, IO)>,
24}
25
26// `View` is a dummy type so it's always send/sync
27unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Send for View<E, C, IO> {}
28unsafe impl<E: CubePrimitive, C: Coordinates, IO: Clone> Sync for View<E, C, IO> {}
29impl<E: CubePrimitive, C: Coordinates, IO: Clone> Copy for View<E, C, IO> {}
30
31#[derive(Clone)]
32pub(super) enum ViewType<E: CubePrimitive, C: Coordinates> {
33    Read(Arc<dyn ViewOperationsExpand<E, C>>),
34    ReadWrite(Arc<dyn ViewOperationsMutExpand<E, C>>),
35}
36
37impl<E: CubePrimitive, C: Coordinates> ViewType<E, C> {
38    /// Dereference in read mode
39    pub fn read(&self) -> &dyn ViewOperationsExpand<E, C> {
40        match self {
41            ViewType::Read(list) => &**list,
42            ViewType::ReadWrite(list) => &**list,
43        }
44    }
45
46    /// Dereference in write mode
47    pub fn write(&self) -> &dyn ViewOperationsMutExpand<E, C> {
48        match self {
49            ViewType::Read(_) => panic!("Can't write to readonly list"),
50            ViewType::ReadWrite(list) => &**list,
51        }
52    }
53}
54
55/// Expand type of [`View`]
56#[derive(Clone)]
57pub struct ViewExpand<E: CubePrimitive, C: Coordinates, IO: Clone = ReadOnly> {
58    pub(super) inner: ViewType<E, C>,
59    pub(super) _io: PhantomData<IO>,
60}
61
62impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeType for View<E, C, IO> {
63    type ExpandType = ViewExpand<E, C, IO>;
64}
65
66impl<E: CubePrimitive, C: Coordinates, IO: Clone> IntoMut for ViewExpand<E, C, IO> {
67    fn into_mut(self, _scope: &mut Scope) -> Self {
68        self
69    }
70}
71
72impl<E: CubePrimitive, C: Coordinates, IO: Clone> CubeDebug for ViewExpand<E, C, IO> {}
73
74impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadOnly> {
75    /// Create a new tensor view from an underlying concrete storage and a layout to map it into
76    /// the target coordinate space
77    #[allow(unused_variables)]
78    pub fn new<V: ViewOperations<E, S>, S: Coordinates>(
79        view: &V,
80        layout: impl Into<VirtualLayout<C, S>>,
81    ) -> Self {
82        View {
83            _layout: PhantomData,
84            _ty: PhantomData,
85        }
86    }
87
88    /// Expand function for [`View::new`]
89    pub fn __expand_new<V: ViewOperations<E, S> + 'static, S: Coordinates + 'static>(
90        scope: &mut Scope,
91        view: V::ExpandType,
92        layout: VirtualLayoutExpand<C, S>,
93    ) -> ViewExpand<E, C, ReadOnly> {
94        ViewExpand::new(VirtualView::<E, C, S, V>::__expand_new(scope, view, layout))
95    }
96}
97
98impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
99    pub fn view<T: Coordinates>(
100        &self,
101        _layout: impl Into<VirtualLayout<T, C>>,
102    ) -> View<E, T, ReadOnly> {
103        unexpanded!()
104    }
105
106    pub fn __expand_view<T: Coordinates + 'static>(
107        scope: &mut Scope,
108        this: ViewExpand<E, C, IO>,
109        layout: VirtualLayoutExpand<T, C>,
110    ) -> ViewExpand<E, T, ReadOnly> {
111        this.__expand_view_method(scope, layout)
112    }
113}
114
115impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
116    pub fn __expand_view_method<T: Coordinates + 'static>(
117        self,
118        scope: &mut Scope,
119        layout: VirtualLayoutExpand<T, C>,
120    ) -> ViewExpand<E, T, ReadOnly> {
121        View::__expand_new::<View<E, C, IO>, C>(scope, self, layout)
122    }
123
124    pub fn new<V: ViewOperationsExpand<E, C> + 'static>(view: V) -> Self {
125        ViewExpand {
126            inner: ViewType::Read(Arc::new(view)),
127            _io: PhantomData,
128        }
129    }
130
131    pub fn new_mut<V: ViewOperationsMutExpand<E, C> + 'static>(view: V) -> Self {
132        ViewExpand {
133            inner: ViewType::ReadWrite(Arc::new(view)),
134            _io: PhantomData,
135        }
136    }
137}
138
139impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
140    pub fn view_mut<T: Coordinates>(
141        &self,
142        _layout: impl Layout<Coordinates = T, SourceCoordinates = C>,
143    ) -> View<E, T, ReadWrite> {
144        unexpanded!()
145    }
146
147    pub fn __expand_view_mut<T: Coordinates + 'static>(
148        scope: &mut Scope,
149        this: ViewExpand<E, C, ReadWrite>,
150        layout: VirtualLayoutExpand<T, C>,
151    ) -> ViewExpand<E, T, ReadWrite> {
152        this.__expand_view_mut_method(scope, layout)
153    }
154}
155
156impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
157    pub fn __expand_view_mut_method<T: Coordinates + 'static>(
158        self,
159        scope: &mut Scope,
160        layout: VirtualLayoutExpand<T, C>,
161    ) -> ViewExpand<E, T, ReadWrite> {
162        View::__expand_new_mut::<View<E, C, ReadWrite>, C>(scope, self, layout)
163    }
164}
165
166impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
167    /// Create a new mutable tensor view from an underlying concrete storage and a layout to map it
168    /// into the target coordinate space
169    pub fn new_mut<V: ViewOperationsMut<E, S>, S: Coordinates>(
170        _view: &mut V,
171        _layout: impl Into<VirtualLayout<C, S>>,
172    ) -> View<E, C, ReadWrite> {
173        View {
174            _ty: PhantomData,
175            _layout: PhantomData,
176        }
177    }
178
179    /// Expand function for [`View::new_mut`]
180    pub fn __expand_new_mut<V: ViewOperationsMut<E, S> + 'static, S: Coordinates + 'static>(
181        scope: &mut Scope,
182        view: V::ExpandType,
183        layout: VirtualLayoutExpand<C, S>,
184    ) -> ViewExpand<E, C, ReadWrite> {
185        ViewExpand::new_mut(VirtualViewMut::<E, C, S, V>::__expand_new(
186            scope, view, layout,
187        ))
188    }
189}
190
191impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
192    /// Calls [`Layout::shape`] on the view's layout
193    pub fn shape(&self) -> C {
194        unexpanded!()
195    }
196
197    /// Calls [`Layout::is_in_bounds`] on the view's layout
198    pub fn is_in_bounds(&self, _pos: C) -> bool {
199        unexpanded!()
200    }
201
202    pub fn __expand_shape(scope: &mut Scope, this: ViewExpand<E, C, IO>) -> C::ExpandType {
203        this.__expand_shape_method(scope)
204    }
205
206    pub fn __expand_is_in_bounds(
207        scope: &mut Scope,
208        this: ViewExpand<E, C, IO>,
209        pos: C::ExpandType,
210    ) -> NativeExpand<bool> {
211        this.__expand_is_in_bounds_method(scope, pos)
212    }
213}
214
215impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
216    pub fn __expand_shape_method(&self, scope: &mut Scope) -> C::ExpandType {
217        self.inner.read().__expand_shape_method(scope)
218    }
219
220    pub fn __expand_is_in_bounds_method(
221        &self,
222        scope: &mut Scope,
223        pos: C::ExpandType,
224    ) -> NativeExpand<bool> {
225        self.inner.read().__expand_is_in_bounds_method(scope, pos)
226    }
227}
228
229#[allow(unused_variables)]
230impl<E: CubePrimitive, C: Coordinates, IO: Clone> View<E, C, IO> {
231    /// Read a value at `pos`. The layout handles translation into a concrete index.
232    pub fn read(&self, pos: C) -> E {
233        unexpanded!()
234    }
235
236    /// Read a value at `pos`. The layout handles translation into a concrete index.
237    /// Reading is done unchecked
238    pub fn read_unchecked(&self, pos: C) -> E {
239        unexpanded!()
240    }
241
242    /// Read a value at `pos` if it's in bounds. The layout handles translation into a concrete index.
243    pub fn read_checked(&self, pos: C) -> E {
244        unexpanded!()
245    }
246
247    /// Read a value at `pos` if it's in bounds, returning `mask_value` otherwise. The layout handles translation into a concrete index.
248    pub fn read_masked(&self, pos: C, mask_value: E) -> E {
249        unexpanded!()
250    }
251
252    /// Interpret this view as a linear slice encompassing the entire view.
253    ///
254    /// # Safety
255    ///
256    /// No checking is done on whether the slice is contiguous in memory.
257    pub fn to_linear_slice(&self) -> Slice<E, ReadOnly> {
258        unexpanded!()
259    }
260
261    pub fn vector_size(&self) -> VectorSize {
262        unexpanded!()
263    }
264}
265
266impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
267    /// Expand method for [`View::read`]
268    pub fn __expand_read_method(self, scope: &mut Scope, pos: C::ExpandType) -> NativeExpand<E> {
269        self.inner.read().__expand_read_method(scope, pos)
270    }
271
272    /// Expand method for [`View::read_unchecked`]
273    pub fn __expand_read_unchecked_method(
274        self,
275        scope: &mut Scope,
276        pos: C::ExpandType,
277    ) -> NativeExpand<E> {
278        self.inner.read().__expand_read_unchecked_method(scope, pos)
279    }
280
281    /// Expand method for [`View::read_checked`]
282    pub fn __expand_read_checked_method(
283        self,
284        scope: &mut Scope,
285        pos: C::ExpandType,
286    ) -> NativeExpand<E> {
287        self.inner.read().__expand_read_checked_method(scope, pos)
288    }
289
290    /// Expand method for [`View::read_masked`]
291    pub fn __expand_read_masked_method(
292        self,
293        scope: &mut Scope,
294        pos: C::ExpandType,
295        mask_value: E::ExpandType,
296    ) -> NativeExpand<E> {
297        self.inner
298            .read()
299            .__expand_read_masked_method(scope, pos, mask_value)
300    }
301
302    /// Expand method for [`View::vector_size`]
303    pub fn __expand_vector_size_method(&self, _scope: &mut Scope) -> VectorSize {
304        self.inner.read().vector_size()
305    }
306
307    pub fn vector_size(&self) -> VectorSize {
308        self.inner.read().vector_size()
309    }
310
311    pub fn __expand_to_linear_slice_method(self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
312        let shape = self.inner.read().__expand_shape_method(scope);
313        let origin = C::__expand_from_int(scope, shape.clone(), 0);
314        // Inclusive end so clamping works correctly
315        let one = C::__expand_from_int(scope, shape.clone(), 1);
316        let shape = C::__expand_max(scope, shape, one.clone());
317        let end = C::__expand_sub(scope, shape, one);
318        self.inner
319            .read()
320            .__expand_to_linear_slice_method(scope, origin, end)
321    }
322
323    pub(super) fn __expand_to_linear_slice_inner_method(
324        self,
325        scope: &mut Scope,
326        pos: C::ExpandType,
327        end: C::ExpandType,
328    ) -> SliceExpand<E, ReadOnly> {
329        self.inner
330            .read()
331            .__expand_to_linear_slice_method(scope, pos, end)
332    }
333}
334
335impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {
336    /// Create a slice starting from `pos`, with `size`.
337    /// The layout handles translation into concrete indices.
338    /// Size will be clamped to the current layout size.
339    pub fn slice(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
340        unexpanded!()
341    }
342
343    /// Create a slice starting from `pos`, with `size`.
344    /// The layout handles translation into concrete indices.
345    /// Size and pos will be clamped to the current layout size.
346    /// #Safety
347    /// Access is always unchecked
348    pub fn slice_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadOnly> {
349        unexpanded!()
350    }
351
352    pub fn __expand_slice(
353        scope: &mut Scope,
354        this: ViewExpand<E, C, IO>,
355        pos: C::ExpandType,
356        size: C::ExpandType,
357    ) -> ViewExpand<E, C, ReadOnly> {
358        this.__expand_slice_method(scope, pos, size)
359    }
360
361    pub fn __expand_slice_unchecked(
362        scope: &mut Scope,
363        this: ViewExpand<E, C, IO>,
364        pos: C::ExpandType,
365        size: C::ExpandType,
366    ) -> ViewExpand<E, C, ReadOnly> {
367        this.__expand_slice_unchecked_method(scope, pos, size)
368    }
369}
370
371#[cube]
372impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> View<E, C, IO> {}
373
374impl<E: CubePrimitive, C: Coordinates + 'static, IO: Clone + 'static> ViewExpand<E, C, IO> {
375    pub fn __expand_slice_method(
376        &self,
377        scope: &mut Scope,
378        pos: C::ExpandType,
379        size: C::ExpandType,
380    ) -> ViewExpand<E, C, ReadOnly> {
381        self.slice(scope, pos, size, true)
382    }
383
384    pub fn __expand_slice_unchecked_method(
385        &self,
386        scope: &mut Scope,
387        pos: C::ExpandType,
388        size: C::ExpandType,
389    ) -> ViewExpand<E, C, ReadOnly> {
390        self.slice(scope, pos, size, false)
391    }
392
393    fn slice(
394        &self,
395        scope: &mut Scope,
396        pos: C::ExpandType,
397        size: C::ExpandType,
398        checked: bool,
399    ) -> ViewExpand<E, C, ReadOnly> {
400        let shape = self.__expand_shape_method(scope);
401        let pos = C::__expand_min(scope, pos, shape.clone());
402        let max_size = C::__expand_sub(scope, shape, pos.clone());
403        let size = C::__expand_min(scope, size, max_size);
404        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
405        self.clone().__expand_view_method(scope, layout.into())
406    }
407}
408
409#[allow(unused_variables)]
410impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
411    /// Write a value to `pos`. The layout handles translation into a concrete index.
412    pub fn write(&self, pos: C, value: E) {
413        unexpanded!()
414    }
415
416    /// Write a value to `pos` if it's in bounds. The layout handles translation into a concrete index.
417    pub fn write_checked(&self, pos: C, value: E) {
418        unexpanded!()
419    }
420
421    /// Interpret this view as a mutable linear slice encompassing the entire view.
422    ///
423    /// # Safety
424    ///
425    /// No checking is done on whether the slice is contiguous in memory.
426    pub fn to_linear_slice_mut(&self) -> Slice<E, ReadWrite> {
427        unexpanded!()
428    }
429}
430
431impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
432    /// Expand method for [`View::write`]
433    pub fn __expand_write_method(
434        self,
435        scope: &mut Scope,
436        pos: C::ExpandType,
437        value: NativeExpand<E>,
438    ) {
439        self.inner.write().__expand_write_method(scope, pos, value);
440    }
441
442    /// Expand method for [`View::write_checked`]
443    pub fn __expand_write_checked_method(
444        self,
445        scope: &mut Scope,
446        pos: C::ExpandType,
447        value: NativeExpand<E>,
448    ) {
449        self.inner
450            .write()
451            .__expand_write_checked_method(scope, pos, value);
452    }
453
454    pub fn __expand_to_linear_slice_mut_method(
455        self,
456        scope: &mut Scope,
457    ) -> SliceExpand<E, ReadWrite> {
458        let shape = self.inner.read().__expand_shape_method(scope);
459        let origin = C::__expand_from_int(scope, shape.clone(), 0);
460        // Inclusive end so clamping works correctly
461        let one = C::__expand_from_int(scope, shape.clone(), 1);
462        let shape = C::__expand_max(scope, shape, one.clone());
463        let end = C::__expand_sub(scope, shape, one);
464        self.inner
465            .write()
466            .__expand_to_linear_slice_mut_method(scope, origin, end)
467    }
468
469    pub(super) fn __expand_to_linear_slice_mut_inner_method(
470        self,
471        scope: &mut Scope,
472        pos: C::ExpandType,
473        end: C::ExpandType,
474    ) -> SliceExpand<E, ReadWrite> {
475        self.inner
476            .write()
477            .__expand_to_linear_slice_mut_method(scope, pos, end)
478    }
479}
480
481impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {
482    /// Create a mutable slice starting from `pos`, with `size`.
483    /// The layout handles translation into concrete indices.
484    /// Size and pos will be clamped to the current layout size.
485    pub fn slice_mut(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
486        unexpanded!()
487    }
488
489    /// Create a mutable slice starting from `pos`, with `size`.
490    /// The layout handles translation into concrete indices.
491    /// Size and pos will be clamped to the current layout size.
492    ///
493    /// # Safety
494    /// Access is always unchecked.
495    pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View<E, C, ReadWrite> {
496        unexpanded!()
497    }
498
499    pub fn __expand_slice_mut(
500        scope: &mut Scope,
501        this: ViewExpand<E, C, ReadWrite>,
502        pos: C::ExpandType,
503        size: C::ExpandType,
504    ) -> ViewExpand<E, C, ReadWrite> {
505        this.__expand_slice_mut_method(scope, pos, size)
506    }
507
508    pub fn __expand_slice_mut_unchecked(
509        scope: &mut Scope,
510        this: ViewExpand<E, C, ReadWrite>,
511        pos: C::ExpandType,
512        size: C::ExpandType,
513    ) -> ViewExpand<E, C, ReadWrite> {
514        this.__expand_slice_mut_unchecked_method(scope, pos, size)
515    }
516}
517
518#[cube]
519impl<E: CubePrimitive, C: Coordinates + 'static> View<E, C, ReadWrite> {}
520
521impl<E: CubePrimitive, C: Coordinates + 'static> ViewExpand<E, C, ReadWrite> {
522    pub fn __expand_slice_mut_method(
523        &self,
524        scope: &mut Scope,
525        pos: C::ExpandType,
526        size: C::ExpandType,
527    ) -> ViewExpand<E, C, ReadWrite> {
528        self.slice_mut(scope, pos, size, true)
529    }
530
531    pub fn __expand_slice_mut_unchecked_method(
532        &self,
533        scope: &mut Scope,
534        pos: C::ExpandType,
535        size: C::ExpandType,
536    ) -> ViewExpand<E, C, ReadWrite> {
537        self.slice_mut(scope, pos, size, false)
538    }
539
540    fn slice_mut(
541        &self,
542        scope: &mut Scope,
543        pos: C::ExpandType,
544        size: C::ExpandType,
545        checked: bool,
546    ) -> ViewExpand<E, C, ReadWrite> {
547        let shape = self.__expand_shape_method(scope);
548        let pos = C::__expand_min(scope, pos, shape.clone());
549        let max_size = C::__expand_sub(scope, shape, pos.clone());
550        let size = C::__expand_min(scope, size, max_size);
551        let layout = SliceLayout::__expand_new(scope, pos, size, checked);
552        self.clone().__expand_view_mut_method(scope, layout.into())
553    }
554}
555
556impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> View<E, C, IO> {
557    ///.Execute a TMA load into shared memory, if the underlying storage supports it.
558    /// Panics if it's unsupported.
559    pub fn tensor_map_load(
560        &self,
561        _barrier: &Barrier,
562        _shared_memory: &mut Slice<E, ReadWrite>,
563        _pos: C,
564    ) -> View<E, C, ReadWrite> {
565        unexpanded!()
566    }
567}
568
569impl<E: CubePrimitive, C: Coordinates, IO: Clone> ViewExpand<E, C, IO> {
570    pub fn __expand_tensor_map_load_method(
571        self,
572        scope: &mut Scope,
573        barrier: BarrierExpand,
574        shared_memory: SliceExpand<E, ReadWrite>,
575        pos: C::ExpandType,
576    ) {
577        self.inner
578            .read()
579            .__expand_tensor_map_load_method(scope, barrier, shared_memory, pos)
580    }
581}
582
583impl<E: CubePrimitive, C: Coordinates> View<E, C, ReadWrite> {
584    ///.Execute a TMA store into global memory, if the underlying storage supports it.
585    /// Panics if it's unsupported.
586    pub fn tensor_map_store(&self, _shared_memory: &Slice<E>, _pos: C) -> View<E, C, ReadWrite> {
587        unexpanded!()
588    }
589}
590
591impl<E: CubePrimitive, C: Coordinates> ViewExpand<E, C, ReadWrite> {
592    pub fn __expand_tensor_map_store_method(
593        self,
594        scope: &mut Scope,
595        shared_memory: SliceExpand<E, ReadOnly>,
596        pos: C::ExpandType,
597    ) {
598        self.inner
599            .write()
600            .__expand_tensor_map_store_method(scope, shared_memory, pos)
601    }
602}