1use super::*;
2use crate::{CubeOption, CubeOptionExpand, tensor::layout::*};
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
5
/// Implements [`ViewOperations`]/[`ViewOperationsMut`] (and their expand
/// counterparts) for [`TensorMap`] at one fixed coordinate rank.
///
/// * `$dim` - rank literal (1-5); pasted into the `tma_load_{$dim}d` /
///   `tma_store_{$dim}d` intrinsic names via `paste!`.
/// * `$coords` - the coordinate tuple type for this rank (e.g. `Coords2d`).
/// * `$var` - one identifier per coordinate component, used to destructure
///   the position tuple.
///
/// A tensor map only supports bulk TMA transfers, so every element-wise
/// read/write and slicing method is left as `unimplemented!`; the real
/// implementations are `as_tensor_map`, `tensor_map_load` and
/// `tensor_map_store`.
macro_rules! impl_tensor_map {
    ($dim: literal, $coords: ty, $($var: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for ExpandElementTyped<TensorMap<T>> {
                // Element-wise reads are not possible on a tensor map handle.
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                // A view backed by a tensor map can always hand back its own
                // handle.
                fn __expand_as_tensor_map_method(
                    &self,
                    scope: &mut Scope,
                ) -> CubeOptionExpand<TensorMap<T>> {
                    CubeOption::__expand_new_Some(scope, self.clone())
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> ExpandElementTyped<bool> {
                    unimplemented!("Can't read from tensor map");
                }

                // Bulk TMA load into `shared_memory`, synchronized by
                // `barrier`. `unused_parens` is allowed because the
                // destructuring patterns below generate `(x)` for the 1D case.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    // Destructure the coordinate tuple, then cast every
                    // component to `i32` — the index type the intrinsics take.
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    barrier.[<__expand_tma_load_ $dim d_method>]::<T>(scope, self.clone(), shared, $($var),*);
                }
            }

            impl<T: CubePrimitive> ViewOperationsMut<T, $coords> for TensorMap<T> {}
            impl<T: CubePrimitive> ViewOperationsMutExpand<T, $coords> for ExpandElementTyped<TensorMap<T>> {
                // Element-wise writes are not possible on a tensor map handle.
                fn __expand_write_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_write_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_to_linear_slice_mut_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadWrite> {
                    unimplemented!("Can't write to tensor map");
                }

                // Bulk TMA store from `shared_memory`, mirroring the load
                // path above.
                #[allow(unused_parens)]
                fn __expand_tensor_map_store_method(
                    &self,
                    scope: &mut Scope,
                    shared_memory: SliceExpand<T, ReadOnly>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    [<tma_store_ $dim d>]::expand(scope, shared, self.clone(), $($var),*);
                }
            }
        }
    };
}
133
// Fixed-rank tuple coordinates, `Coords*d` variants, ranks 1 through 5.
impl_tensor_map!(1, Coords1d, x);
impl_tensor_map!(2, Coords2d, x, y);
impl_tensor_map!(3, Coords3d, x, y, z);
impl_tensor_map!(4, Coords4d, x, y, z, v);
impl_tensor_map!(5, Coords5d, x, y, z, v, w);

// Same impls for the `Coords*i` coordinate types (presumably the signed
// index variants — confirm in `tensor::layout`).
impl_tensor_map!(1, Coords1i, x);
impl_tensor_map!(2, Coords2i, x, y);
impl_tensor_map!(3, Coords3i, x, y, z);
impl_tensor_map!(4, Coords4i, x, y, z, v);
impl_tensor_map!(5, Coords5i, x, y, z, v, w);
145
146fn as_i32<T: CubePrimitive>(
147 scope: &mut Scope,
148 pos: &SequenceExpand<T>,
149 i: u32,
150) -> ExpandElementTyped<i32> {
151 let x = pos.__expand_index_method(scope, i.into());
152 i32::__expand_cast_from(scope, x)
153}
154
/// Marker impl: reads are also supported through rank-generic [`Sequence`]
/// coordinates (rank resolved at expansion time, see the expand impl below).
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperations<T, Sequence<N>>
    for TensorMap<T>
{
}
/// Rank-generic read-side expand implementation: the coordinate rank is
/// recovered at expansion time from the length of the [`Sequence`], then
/// dispatched to the matching fixed-rank TMA load intrinsic.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsExpand<T, Sequence<N>>
    for ExpandElementTyped<TensorMap<T>>
{
    // Element-wise reads are not possible on a tensor map handle; only the
    // bulk TMA transfer at the bottom of this impl is implemented.
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _end: SequenceExpand<N>,
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    // A view backed by a tensor map can always hand back its own handle.
    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<T>> {
        CubeOption::__expand_new_Some(scope, self.clone())
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> SequenceExpand<N> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> ExpandElementTyped<bool> {
        unimplemented!("Can't read from tensor map");
    }

    /// Issues a rank-dispatched TMA load into `shared_memory`, synchronized
    /// by `barrier`. The sequence holds coordinates outermost-first (e.g.
    /// `[y, x]` for 2D) and they are forwarded to the intrinsic in that
    /// order, cast to `i32`.
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: SequenceExpand<N>,
    ) {
        let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
        // Rank is known at expansion time from the sequence length.
        let rank = pos.len();
        let pos = &pos;
        match rank {
            1 => {
                let x = as_i32(scope, pos, 0);
                barrier.__expand_tma_load_1d_method(scope, self.clone(), shared, x);
            }
            2 => {
                let y = as_i32(scope, pos, 0);
                let x = as_i32(scope, pos, 1);
                barrier.__expand_tma_load_2d_method(scope, self.clone(), shared, y, x);
            }
            3 => {
                let z = as_i32(scope, pos, 0);
                let y = as_i32(scope, pos, 1);
                let x = as_i32(scope, pos, 2);
                barrier.__expand_tma_load_3d_method(scope, self.clone(), shared, z, y, x);
            }
            4 => {
                let w = as_i32(scope, pos, 0);
                let z = as_i32(scope, pos, 1);
                let y = as_i32(scope, pos, 2);
                let x = as_i32(scope, pos, 3);
                barrier.__expand_tma_load_4d_method(scope, self.clone(), shared, w, z, y, x);
            }
            5 => {
                let v = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let z = as_i32(scope, pos, 2);
                let y = as_i32(scope, pos, 3);
                let x = as_i32(scope, pos, 4);
                barrier.__expand_tma_load_5d_method(scope, self.clone(), shared, v, w, z, y, x);
            }
            // Expansion-time error: no TMA intrinsic exists beyond rank 5.
            _ => panic!("TMA only supports 1D-5D loads"),
        }
    }
}
266
/// Marker impl: writes are also supported through rank-generic [`Sequence`]
/// coordinates (rank resolved at expansion time, see the expand impl below).
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMut<T, Sequence<N>>
    for TensorMap<T>
{
}
271impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMutExpand<T, Sequence<N>>
272 for ExpandElementTyped<TensorMap<T>>
273{
274 fn __expand_write_method(
275 &self,
276 _scope: &mut Scope,
277 _pos: SequenceExpand<N>,
278 _value: <T as CubeType>::ExpandType,
279 ) {
280 unimplemented!("Can't write to tensor map");
281 }
282
283 fn __expand_write_checked_method(
284 &self,
285 _scope: &mut Scope,
286 _pos: SequenceExpand<N>,
287 _value: <T as CubeType>::ExpandType,
288 ) {
289 unimplemented!("Can't write to tensor map");
290 }
291
292 fn __expand_to_linear_slice_mut_method(
293 &self,
294 _scope: &mut Scope,
295 _pos: SequenceExpand<N>,
296 _end: SequenceExpand<N>,
297 ) -> SliceExpand<T, ReadWrite> {
298 unimplemented!("Can't write to tensor map");
299 }
300
301 #[allow(unused_parens)]
302 fn __expand_tensor_map_store_method(
303 &self,
304 scope: &mut Scope,
305 shared_memory: SliceExpand<T, ReadOnly>,
306 pos: SequenceExpand<N>,
307 ) {
308 let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
309 let rank = pos.len();
310 let pos = &pos;
311 match rank {
312 1 => {
313 let x = as_i32(scope, pos, 0);
314 tma_store_1d::expand(scope, shared, self.clone(), x);
315 }
316 2 => {
317 let y = as_i32(scope, pos, 0);
318 let x = as_i32(scope, pos, 1);
319 tma_store_2d::expand(scope, shared, self.clone(), y, x);
320 }
321 3 => {
322 let z = as_i32(scope, pos, 0);
323 let y = as_i32(scope, pos, 1);
324 let x = as_i32(scope, pos, 2);
325 tma_store_3d::expand(scope, shared, self.clone(), z, y, x);
326 }
327 4 => {
328 let w = as_i32(scope, pos, 0);
329 let z = as_i32(scope, pos, 1);
330 let y = as_i32(scope, pos, 2);
331 let x = as_i32(scope, pos, 3);
332 tma_store_4d::expand(scope, shared, self.clone(), w, z, y, x);
333 }
334 5 => {
335 let v = as_i32(scope, pos, 0);
336 let w = as_i32(scope, pos, 1);
337 let z = as_i32(scope, pos, 2);
338 let y = as_i32(scope, pos, 3);
339 let x = as_i32(scope, pos, 4);
340 tma_store_5d::expand(scope, shared, self.clone(), v, w, z, y, x);
341 }
342 _ => panic!("TMA store supports 1D-5D loads"),
343 }
344 }
345}