// cubecl_std/tensor/view/as_view.rs

1use cubecl_core as cubecl;
2use cubecl_core::{prelude::*, unexpanded};
3
4use crate::tensor::{View, ViewExpand, layout::*};
5
6pub trait AsView<E: CubePrimitive>:
7    CubeType<ExpandType: AsViewExpand<E, SourceCoords = Self::SourceCoords>>
8{
9    type SourceCoords: Coordinates;
10
11    #[allow(unused)]
12    fn view<C: Coordinates + 'static>(
13        &self,
14        layout: impl Into<VirtualLayout<C, Self::SourceCoords>>,
15    ) -> View<E, C, ReadOnly> {
16        unexpanded!()
17    }
18
19    fn __expand_view<C: Coordinates + 'static>(
20        scope: &mut Scope,
21        this: Self::ExpandType,
22        layout: VirtualLayoutExpand<C, Self::SourceCoords>,
23    ) -> ViewExpand<E, C, ReadOnly> {
24        this.__expand_view_method(scope, layout)
25    }
26}
27
28pub trait AsViewExpand<E: CubePrimitive> {
29    type SourceCoords: Coordinates;
30
31    #[allow(unused)]
32    fn __expand_view_method<C: Coordinates + 'static>(
33        self,
34        scope: &mut Scope,
35        layout: VirtualLayoutExpand<C, Self::SourceCoords>,
36    ) -> ViewExpand<E, C, ReadOnly>;
37}
38
39pub trait AsViewMut<E: CubePrimitive>: AsView<E> {
40    #[allow(unused)]
41    fn view_mut<C: Coordinates + 'static>(
42        &mut self,
43        layout: impl Into<VirtualLayout<C, Self::SourceCoords>>,
44    ) -> View<E, C, ReadWrite> {
45        unexpanded!()
46    }
47}
48
49pub trait AsViewMutExpand<E: CubePrimitive>: AsViewExpand<E> {
50    #[allow(clippy::too_many_arguments)]
51    fn __expand_view_mut_method<C: Coordinates + 'static>(
52        self,
53        scope: &mut Scope,
54        layout: VirtualLayoutExpand<C, Self::SourceCoords>,
55    ) -> ViewExpand<E, C, ReadWrite>;
56}
57
58macro_rules! impl_as_view {
59    ($ty: ident, $coords: ty) => {
60        impl<E: CubePrimitive> AsView<E> for $ty<E> {
61            type SourceCoords = $coords;
62        }
63        impl<E: CubePrimitive> AsViewExpand<E> for ExpandElementTyped<$ty<E>> {
64            type SourceCoords = $coords;
65            fn __expand_view_method<C: Coordinates + 'static>(
66                self,
67                scope: &mut Scope,
68                layout: VirtualLayoutExpand<C, $coords>,
69            ) -> super::ViewExpand<E, C, ReadOnly> {
70                View::__expand_new::<$ty<E>, $coords>(scope, self, layout)
71            }
72        }
73
74        impl<E: CubePrimitive> AsViewMut<E> for $ty<E> {}
75        impl<E: CubePrimitive> AsViewMutExpand<E> for ExpandElementTyped<$ty<E>> {
76            fn __expand_view_mut_method<C: Coordinates + 'static>(
77                self,
78                scope: &mut Scope,
79                layout: VirtualLayoutExpand<C, $coords>,
80            ) -> super::ViewExpand<E, C, ReadWrite> {
81                View::__expand_new_mut::<$ty<E>, $coords>(scope, self, layout)
82            }
83        }
84    };
85}
86
87impl_as_view!(Array, Coords1d);
88impl_as_view!(Tensor, Coords1d);
89impl_as_view!(SharedMemory, Coords1d);
90
91impl<E: CubePrimitive, IO: SliceVisibility + 'static> AsView<E> for Slice<E, IO> {
92    type SourceCoords = Coords1d;
93    fn view<C: Coordinates + 'static>(
94        &self,
95        layout: impl Into<VirtualLayout<C, Coords1d>>,
96    ) -> View<E, C, ReadOnly> {
97        View::new::<Slice<E, IO>, Coords1d>(self, layout)
98    }
99}
100
101impl<E: CubePrimitive, IO: SliceVisibility + 'static> AsViewExpand<E> for SliceExpand<E, IO> {
102    type SourceCoords = Coords1d;
103    fn __expand_view_method<C: Coordinates + 'static>(
104        self,
105        scope: &mut Scope,
106        layout: VirtualLayoutExpand<C, Self::SourceCoords>,
107    ) -> ViewExpand<E, C, ReadOnly> {
108        View::__expand_new::<Slice<E, IO>, Self::SourceCoords>(scope, self, layout)
109    }
110}
111
112impl<E: CubePrimitive> AsViewMut<E> for Slice<E, ReadWrite> {
113    fn view_mut<C: Coordinates + 'static>(
114        &mut self,
115        layout: impl Into<VirtualLayout<C, Coords1d>>,
116    ) -> View<E, C, ReadWrite> {
117        View::new_mut::<Slice<E, ReadWrite>, Coords1d>(self, layout)
118    }
119}
120impl<E: CubePrimitive> AsViewMutExpand<E> for SliceExpand<E, ReadWrite> {
121    fn __expand_view_mut_method<C: Coordinates + 'static>(
122        self,
123        scope: &mut cubecl::prelude::Scope,
124        layout: VirtualLayoutExpand<C, Self::SourceCoords>,
125    ) -> ViewExpand<E, C, ReadWrite> {
126        View::__expand_new_mut::<Slice<E, ReadWrite>, Coords1d>(scope, self, layout)
127    }
128}
129
130macro_rules! as_view_tensor_map {
131    ($($dim: literal),*) => {
132        paste::paste! {
133            pub trait AsTensorView<E: CubePrimitive>:
134                CubeType<ExpandType: AsTensorViewExpand<E>>
135            {
136                $(
137                    #[allow(unused)]
138                    fn [<view_ $dim>]<C: Coordinates + 'static>(
139                        &self,
140                        layout: impl Into<VirtualLayout<C, [<Coords $dim>]>>,
141                    ) -> View<E, C, ReadOnly> {
142                        unexpanded!()
143                    }
144
145                    fn [<__expand_view_ $dim>]<C: Coordinates + 'static>(
146                        scope: &mut Scope,
147                        this: Self::ExpandType,
148                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
149                    ) -> ViewExpand<E, C, ReadOnly> {
150                        this.[<__expand_view_ $dim _method>](scope, layout)
151                    }
152                )*
153            }
154
155            pub trait AsTensorViewExpand<E: CubePrimitive> {
156                $(
157                    #[allow(unused)]
158                    fn [<__expand_view_ $dim _method>]<C: Coordinates + 'static>(
159                        self,
160                        scope: &mut Scope,
161                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
162                    ) -> ViewExpand<E, C, ReadOnly>;
163                )*
164            }
165
166            pub trait AsTensorViewMut<E: CubePrimitive>: AsTensorView<E> {
167                $(
168                    #[allow(unused)]
169                    fn [<view_mut_ $dim>]<C: Coordinates + 'static>(
170                        &mut self,
171                        layout: impl Into<VirtualLayout<C, [<Coords $dim>]>>,
172                    ) -> View<E, C, ReadWrite> {
173                        unexpanded!()
174                    }
175                )*
176            }
177
178            pub trait AsTensorViewMutExpand<E: CubePrimitive>: AsTensorViewExpand<E> {
179                $(
180                    #[allow(clippy::too_many_arguments)]
181                    fn [<__expand_view_mut_ $dim _method>]<C: Coordinates + 'static>(
182                        self,
183                        scope: &mut Scope,
184                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
185                    ) -> ViewExpand<E, C, ReadWrite>;
186                )*
187            }
188
189            impl<E: CubePrimitive> AsTensorView<E> for TensorMap<E> {}
190            impl<E: CubePrimitive> AsTensorViewExpand<E> for ExpandElementTyped<TensorMap<E>> {
191                $(
192                    fn [<__expand_view_ $dim _method>]<C: Coordinates + 'static>(
193                        self,
194                        scope: &mut Scope,
195                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
196                    ) -> super::ViewExpand<E, C, ReadOnly> {
197                        View::__expand_new::<TensorMap<E>, [<Coords $dim>]>(scope, self, layout)
198                    }
199                )*
200            }
201
202            impl<E: CubePrimitive> AsTensorViewMut<E> for TensorMap<E> {}
203            impl<E: CubePrimitive> AsTensorViewMutExpand<E> for ExpandElementTyped<TensorMap<E>> {
204                $(
205                    fn [<__expand_view_mut_ $dim _method>]<C: Coordinates + 'static>(
206                        self,
207                        scope: &mut Scope,
208                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
209                    ) -> super::ViewExpand<E, C, ReadWrite> {
210                        View::__expand_new_mut::<TensorMap<E>, [<Coords $dim>]>(scope, self, layout)
211                    }
212                )*
213            }
214        }
215    };
216}
217
218as_view_tensor_map!(1d, 2d, 3d, 4d, 5d, 1i, 2i, 3i, 4i, 5i);