1use cubecl_core as cubecl;
2use cubecl_core::{prelude::*, unexpanded};
3
4use crate::tensor::{View, ViewExpand, layout::*};
5
6pub trait AsView<E: CubePrimitive>:
7 CubeType<ExpandType: AsViewExpand<E, SourceCoords = Self::SourceCoords>>
8{
9 type SourceCoords: Coordinates;
10
11 #[allow(unused)]
12 fn view<C: Coordinates + 'static>(
13 &self,
14 layout: impl Into<VirtualLayout<C, Self::SourceCoords>>,
15 ) -> View<E, C, ReadOnly> {
16 unexpanded!()
17 }
18
19 fn __expand_view<C: Coordinates + 'static>(
20 scope: &mut Scope,
21 this: Self::ExpandType,
22 layout: VirtualLayoutExpand<C, Self::SourceCoords>,
23 ) -> ViewExpand<E, C, ReadOnly> {
24 this.__expand_view_method(scope, layout)
25 }
26}
27
28pub trait AsViewExpand<E: CubePrimitive> {
29 type SourceCoords: Coordinates;
30
31 #[allow(unused)]
32 fn __expand_view_method<C: Coordinates + 'static>(
33 self,
34 scope: &mut Scope,
35 layout: VirtualLayoutExpand<C, Self::SourceCoords>,
36 ) -> ViewExpand<E, C, ReadOnly>;
37}
38
39pub trait AsViewMut<E: CubePrimitive>: AsView<E> {
40 #[allow(unused)]
41 fn view_mut<C: Coordinates + 'static>(
42 &mut self,
43 layout: impl Into<VirtualLayout<C, Self::SourceCoords>>,
44 ) -> View<E, C, ReadWrite> {
45 unexpanded!()
46 }
47}
48
49pub trait AsViewMutExpand<E: CubePrimitive>: AsViewExpand<E> {
50 #[allow(clippy::too_many_arguments)]
51 fn __expand_view_mut_method<C: Coordinates + 'static>(
52 self,
53 scope: &mut Scope,
54 layout: VirtualLayoutExpand<C, Self::SourceCoords>,
55 ) -> ViewExpand<E, C, ReadWrite>;
56}
57
// Implements `AsView`/`AsViewMut` for `$ty<E>` and `AsViewExpand`/`AsViewMutExpand` for its
// `ExpandElementTyped` wrapper, using `$coords` as the source coordinate type throughout.
macro_rules! impl_as_view {
    ($ty: ident, $coords: ty) => {
        // Non-expanded side: only the coordinate type is specified; the methods keep their
        // trait defaults.
        impl<E: CubePrimitive> AsView<E> for $ty<E> {
            type SourceCoords = $coords;
        }
        impl<E: CubePrimitive> AsViewMut<E> for $ty<E> {}

        // Expanded side: build the view through `View`'s expand constructors.
        impl<E: CubePrimitive> AsViewExpand<E> for ExpandElementTyped<$ty<E>> {
            type SourceCoords = $coords;

            fn __expand_view_method<C: Coordinates + 'static>(
                self,
                scope: &mut Scope,
                layout: VirtualLayoutExpand<C, $coords>,
            ) -> ViewExpand<E, C, ReadOnly> {
                View::__expand_new::<$ty<E>, $coords>(scope, self, layout)
            }
        }

        impl<E: CubePrimitive> AsViewMutExpand<E> for ExpandElementTyped<$ty<E>> {
            fn __expand_view_mut_method<C: Coordinates + 'static>(
                self,
                scope: &mut Scope,
                layout: VirtualLayoutExpand<C, $coords>,
            ) -> ViewExpand<E, C, ReadWrite> {
                View::__expand_new_mut::<$ty<E>, $coords>(scope, self, layout)
            }
        }
    };
}
86
87impl_as_view!(Array, Coords1d);
88impl_as_view!(Tensor, Coords1d);
89impl_as_view!(SharedMemory, Coords1d);
90
91impl<E: CubePrimitive, IO: SliceVisibility + 'static> AsView<E> for Slice<E, IO> {
92 type SourceCoords = Coords1d;
93 fn view<C: Coordinates + 'static>(
94 &self,
95 layout: impl Into<VirtualLayout<C, Coords1d>>,
96 ) -> View<E, C, ReadOnly> {
97 View::new::<Slice<E, IO>, Coords1d>(self, layout)
98 }
99}
100
101impl<E: CubePrimitive, IO: SliceVisibility + 'static> AsViewExpand<E> for SliceExpand<E, IO> {
102 type SourceCoords = Coords1d;
103 fn __expand_view_method<C: Coordinates + 'static>(
104 self,
105 scope: &mut Scope,
106 layout: VirtualLayoutExpand<C, Self::SourceCoords>,
107 ) -> ViewExpand<E, C, ReadOnly> {
108 View::__expand_new::<Slice<E, IO>, Self::SourceCoords>(scope, self, layout)
109 }
110}
111
112impl<E: CubePrimitive> AsViewMut<E> for Slice<E, ReadWrite> {
113 fn view_mut<C: Coordinates + 'static>(
114 &mut self,
115 layout: impl Into<VirtualLayout<C, Coords1d>>,
116 ) -> View<E, C, ReadWrite> {
117 View::new_mut::<Slice<E, ReadWrite>, Coords1d>(self, layout)
118 }
119}
120impl<E: CubePrimitive> AsViewMutExpand<E> for SliceExpand<E, ReadWrite> {
121 fn __expand_view_mut_method<C: Coordinates + 'static>(
122 self,
123 scope: &mut cubecl::prelude::Scope,
124 layout: VirtualLayoutExpand<C, Self::SourceCoords>,
125 ) -> ViewExpand<E, C, ReadWrite> {
126 View::__expand_new_mut::<Slice<E, ReadWrite>, Coords1d>(scope, self, layout)
127 }
128}
129
// Generates the tensor-map view traits and their `TensorMap` implementations.
//
// For every `$dim` suffix passed in, each trait receives one method per suffix (for example
// `view_2d`, `view_mut_2d`, `__expand_view_2d_method`) whose layout targets the matching
// `Coords$dim` coordinate type. `paste` is used to assemble those identifiers.
macro_rules! as_view_tensor_map {
    ($($dim: literal),*) => {
        paste::paste! {
            /// Read-only views with one `view_*` method per supported coordinate type.
            pub trait AsTensorView<E: CubePrimitive>:
                CubeType<ExpandType: AsTensorViewExpand<E>>
            {
                $(
                    /// Creates a read-only view whose layout maps `C` onto this method's
                    /// coordinate type. The default body is `unexpanded!()`.
                    #[allow(unused)]
                    fn [<view_ $dim>]<C: Coordinates + 'static>(
                        &self,
                        layout: impl Into<VirtualLayout<C, [<Coords $dim>]>>,
                    ) -> View<E, C, ReadOnly> {
                        unexpanded!()
                    }

                    /// Expanded form of the matching `view_*` method; delegates to the expand
                    /// type's corresponding `__expand_view_*_method`.
                    fn [<__expand_view_ $dim>]<C: Coordinates + 'static>(
                        scope: &mut Scope,
                        this: Self::ExpandType,
                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
                    ) -> ViewExpand<E, C, ReadOnly> {
                        this.[<__expand_view_ $dim _method>](scope, layout)
                    }
                )*
            }

            /// Expand-time counterpart of [`AsTensorView`].
            pub trait AsTensorViewExpand<E: CubePrimitive> {
                $(
                    /// Builds an expanded read-only view of `self` inside `scope`.
                    #[allow(unused)]
                    fn [<__expand_view_ $dim _method>]<C: Coordinates + 'static>(
                        self,
                        scope: &mut Scope,
                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
                    ) -> ViewExpand<E, C, ReadOnly>;
                )*
            }

            /// Read-write views with one `view_mut_*` method per supported coordinate type.
            pub trait AsTensorViewMut<E: CubePrimitive>: AsTensorView<E> {
                $(
                    /// Creates a read-write view whose layout maps `C` onto this method's
                    /// coordinate type. The default body is `unexpanded!()`.
                    #[allow(unused)]
                    fn [<view_mut_ $dim>]<C: Coordinates + 'static>(
                        &mut self,
                        layout: impl Into<VirtualLayout<C, [<Coords $dim>]>>,
                    ) -> View<E, C, ReadWrite> {
                        unexpanded!()
                    }
                )*
            }

            /// Expand-time counterpart of [`AsTensorViewMut`].
            pub trait AsTensorViewMutExpand<E: CubePrimitive>: AsTensorViewExpand<E> {
                $(
                    /// Builds an expanded read-write view of `self` inside `scope`.
                    fn [<__expand_view_mut_ $dim _method>]<C: Coordinates + 'static>(
                        self,
                        scope: &mut Scope,
                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
                    ) -> ViewExpand<E, C, ReadWrite>;
                )*
            }

            impl<E: CubePrimitive> AsTensorView<E> for TensorMap<E> {}
            impl<E: CubePrimitive> AsTensorViewExpand<E> for ExpandElementTyped<TensorMap<E>> {
                $(
                    fn [<__expand_view_ $dim _method>]<C: Coordinates + 'static>(
                        self,
                        scope: &mut Scope,
                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
                    ) -> ViewExpand<E, C, ReadOnly> {
                        View::__expand_new::<TensorMap<E>, [<Coords $dim>]>(scope, self, layout)
                    }
                )*
            }

            impl<E: CubePrimitive> AsTensorViewMut<E> for TensorMap<E> {}
            impl<E: CubePrimitive> AsTensorViewMutExpand<E> for ExpandElementTyped<TensorMap<E>> {
                $(
                    fn [<__expand_view_mut_ $dim _method>]<C: Coordinates + 'static>(
                        self,
                        scope: &mut Scope,
                        layout: VirtualLayoutExpand<C, [<Coords $dim>]>,
                    ) -> ViewExpand<E, C, ReadWrite> {
                        View::__expand_new_mut::<TensorMap<E>, [<Coords $dim>]>(scope, self, layout)
                    }
                )*
            }
        }
    };
}
217
218as_view_tensor_map!(1d, 2d, 3d, 4d, 5d, 1i, 2i, 3i, 4i, 5i);