1use super::*;
2use crate::{CubeOption, CubeOptionExpand, tensor::layout::*};
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
5
/// Implements the read-only and mutable view-operation traits for [`TensorMap`]
/// at one fixed rank, keyed on a tuple coordinate type.
///
/// * `$dim` — the tensor rank (1-5); pasted into the intrinsic names
///   (`tma_load_{$dim}d` / `tma_store_{$dim}d`).
/// * `$coords` — the coordinate tuple type this impl is keyed on.
/// * `$($var),*` — one identifier per coordinate component, in tuple order.
///
/// A tensor map only supports bulk tile transfers through shared memory, so
/// every element-wise accessor is `unimplemented!`; only the bulk
/// `tensor_map_load` / `tensor_map_store` entry points and `as_tensor_map`
/// do real work.
macro_rules! impl_tensor_map {
    ($dim: literal, $coords: ty, $($var: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for ExpandElementTyped<TensorMap<T>> {
                // Element-wise reads are not possible on a tensor map.
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // No linear-slice view exists over a tensor map.
                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                // A tensor map trivially exposes itself.
                fn __expand_as_tensor_map_method(
                    &self,
                    scope: &mut Scope,
                ) -> CubeOptionExpand<TensorMap<T>> {
                    CubeOption::__expand_new_Some(scope, self.clone())
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> ExpandElementTyped<bool> {
                    unimplemented!("Can't read from tensor map");
                }

                // Bulk load: destructure the coordinate tuple, cast each
                // component to `i32` (the type the TMA intrinsics take), then
                // emit the rank-specific `tma_load_{$dim}d` barrier call.
                // `unused_parens` fires on the 1-component destructuring.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    barrier.[<__expand_tma_load_ $dim d_method>]::<T>(scope, self.clone(), shared, $($var),*);
                }
            }

            impl<T: CubePrimitive> ViewOperationsMut<T, $coords> for TensorMap<T> {}
            impl<T: CubePrimitive> ViewOperationsMutExpand<T, $coords> for ExpandElementTyped<TensorMap<T>> {
                // Element-wise writes are not possible on a tensor map.
                fn __expand_write_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_write_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_to_linear_slice_mut_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadWrite> {
                    unimplemented!("Can't write to tensor map");
                }

                // Bulk store: mirror of the load above, dispatching to the
                // free function `tma_store_{$dim}d`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_store_method(
                    &self,
                    scope: &mut Scope,
                    shared_memory: SliceExpand<T, ReadOnly>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    [<tma_store_ $dim d>]::expand(scope, shared, self.clone(), $($var),*);
                }
            }
        }
    };
}
133
// Fixed-rank tuple-coordinate impls, ranks 1-5, over the `Coords*d` types.
impl_tensor_map!(1, Coords1d, x);
impl_tensor_map!(2, Coords2d, x, y);
impl_tensor_map!(3, Coords3d, x, y, z);
impl_tensor_map!(4, Coords4d, x, y, z, v);
impl_tensor_map!(5, Coords5d, x, y, z, v, w);

// Same impls over the `Coords*i` variants (signed-index counterparts, per the
// `i` suffix — confirm in the layout module).
impl_tensor_map!(1, Coords1i, x);
impl_tensor_map!(2, Coords2i, x, y);
impl_tensor_map!(3, Coords3i, x, y, z);
impl_tensor_map!(4, Coords4i, x, y, z, v);
impl_tensor_map!(5, Coords5i, x, y, z, v, w);
145
146fn as_i32<T: CubePrimitive>(
147    scope: &mut Scope,
148    pos: &SequenceExpand<T>,
149    i: u32,
150) -> ExpandElementTyped<i32> {
151    let x = pos.__expand_index_method(scope, i.into());
152    i32::__expand_cast_from(scope, x)
153}
154
// Marker impl: views over a `TensorMap` may also be addressed by a dynamic
// coordinate `Sequence`; the expand impl below dispatches on its length.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperations<T, Sequence<N>>
    for TensorMap<T>
{
}
159impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsExpand<T, Sequence<N>>
160    for ExpandElementTyped<TensorMap<T>>
161{
162    fn __expand_read_method(
163        &self,
164        _scope: &mut Scope,
165        _pos: SequenceExpand<N>,
166    ) -> <T as CubeType>::ExpandType {
167        unimplemented!("Can't read from tensor map");
168    }
169
170    fn __expand_read_checked_method(
171        &self,
172        _scope: &mut Scope,
173        _pos: SequenceExpand<N>,
174    ) -> <T as CubeType>::ExpandType {
175        unimplemented!("Can't read from tensor map");
176    }
177
178    fn __expand_read_masked_method(
179        &self,
180        _scope: &mut Scope,
181        _pos: SequenceExpand<N>,
182        _mask_value: <T as CubeType>::ExpandType,
183    ) -> <T as CubeType>::ExpandType {
184        unimplemented!("Can't read from tensor map");
185    }
186
187    fn __expand_read_unchecked_method(
188        &self,
189        _scope: &mut Scope,
190        _pos: SequenceExpand<N>,
191    ) -> <T as CubeType>::ExpandType {
192        unimplemented!("Can't read from tensor map");
193    }
194
195    fn __expand_to_linear_slice_method(
196        &self,
197        _scope: &mut Scope,
198        _pos: SequenceExpand<N>,
199        _end: SequenceExpand<N>,
200    ) -> SliceExpand<T, ReadOnly> {
201        unimplemented!("Can't read from tensor map");
202    }
203
204    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<T>> {
205        CubeOption::__expand_new_Some(scope, self.clone())
206    }
207
208    fn __expand_shape_method(&self, _scope: &mut Scope) -> SequenceExpand<N> {
209        unimplemented!("Can't read from tensor map");
210    }
211
212    fn __expand_is_in_bounds_method(
213        &self,
214        _scope: &mut Scope,
215        _pos: SequenceExpand<N>,
216    ) -> ExpandElementTyped<bool> {
217        unimplemented!("Can't read from tensor map");
218    }
219
220    #[allow(unused_parens)]
221    fn __expand_tensor_map_load_method(
222        &self,
223        scope: &mut Scope,
224        barrier: BarrierExpand,
225        shared_memory: SliceExpand<T, ReadWrite>,
226        pos: SequenceExpand<N>,
227    ) {
228        let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
229        let rank = pos.len();
230        let pos = &pos;
231        match rank {
232            1 => {
233                let x = as_i32(scope, pos, 0);
234                barrier.__expand_tma_load_1d_method(scope, self.clone(), shared, x);
235            }
236            2 => {
237                let y = as_i32(scope, pos, 0);
238                let x = as_i32(scope, pos, 1);
239                barrier.__expand_tma_load_2d_method(scope, self.clone(), shared, y, x);
240            }
241            3 => {
242                let z = as_i32(scope, pos, 0);
243                let y = as_i32(scope, pos, 1);
244                let x = as_i32(scope, pos, 2);
245                barrier.__expand_tma_load_3d_method(scope, self.clone(), shared, z, y, x);
246            }
247            4 => {
248                let w = as_i32(scope, pos, 0);
249                let z = as_i32(scope, pos, 1);
250                let y = as_i32(scope, pos, 2);
251                let x = as_i32(scope, pos, 3);
252                barrier.__expand_tma_load_4d_method(scope, self.clone(), shared, w, z, y, x);
253            }
254            5 => {
255                let v = as_i32(scope, pos, 0);
256                let w = as_i32(scope, pos, 1);
257                let z = as_i32(scope, pos, 2);
258                let y = as_i32(scope, pos, 3);
259                let x = as_i32(scope, pos, 4);
260                barrier.__expand_tma_load_5d_method(scope, self.clone(), shared, v, w, z, y, x);
261            }
262            _ => panic!("TMA only supports 1D-5D loads"),
263        }
264    }
265}
266
// Marker impl: mutable sequence-addressed views delegate to the expand impl
// below, which dispatches on the runtime sequence length.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMut<T, Sequence<N>>
    for TensorMap<T>
{
}
271impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMutExpand<T, Sequence<N>>
272    for ExpandElementTyped<TensorMap<T>>
273{
274    fn __expand_write_method(
275        &self,
276        _scope: &mut Scope,
277        _pos: SequenceExpand<N>,
278        _value: <T as CubeType>::ExpandType,
279    ) {
280        unimplemented!("Can't write to tensor map");
281    }
282
283    fn __expand_write_checked_method(
284        &self,
285        _scope: &mut Scope,
286        _pos: SequenceExpand<N>,
287        _value: <T as CubeType>::ExpandType,
288    ) {
289        unimplemented!("Can't write to tensor map");
290    }
291
292    fn __expand_to_linear_slice_mut_method(
293        &self,
294        _scope: &mut Scope,
295        _pos: SequenceExpand<N>,
296        _end: SequenceExpand<N>,
297    ) -> SliceExpand<T, ReadWrite> {
298        unimplemented!("Can't write to tensor map");
299    }
300
301    #[allow(unused_parens)]
302    fn __expand_tensor_map_store_method(
303        &self,
304        scope: &mut Scope,
305        shared_memory: SliceExpand<T, ReadOnly>,
306        pos: SequenceExpand<N>,
307    ) {
308        let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
309        let rank = pos.len();
310        let pos = &pos;
311        match rank {
312            1 => {
313                let x = as_i32(scope, pos, 0);
314                tma_store_1d::expand(scope, shared, self.clone(), x);
315            }
316            2 => {
317                let y = as_i32(scope, pos, 0);
318                let x = as_i32(scope, pos, 1);
319                tma_store_2d::expand(scope, shared, self.clone(), y, x);
320            }
321            3 => {
322                let z = as_i32(scope, pos, 0);
323                let y = as_i32(scope, pos, 1);
324                let x = as_i32(scope, pos, 2);
325                tma_store_3d::expand(scope, shared, self.clone(), z, y, x);
326            }
327            4 => {
328                let w = as_i32(scope, pos, 0);
329                let z = as_i32(scope, pos, 1);
330                let y = as_i32(scope, pos, 2);
331                let x = as_i32(scope, pos, 3);
332                tma_store_4d::expand(scope, shared, self.clone(), w, z, y, x);
333            }
334            5 => {
335                let v = as_i32(scope, pos, 0);
336                let w = as_i32(scope, pos, 1);
337                let z = as_i32(scope, pos, 2);
338                let y = as_i32(scope, pos, 3);
339                let x = as_i32(scope, pos, 4);
340                tma_store_5d::expand(scope, shared, self.clone(), v, w, z, y, x);
341            }
342            _ => panic!("TMA store supports 1D-5D loads"),
343        }
344    }
345}