//! TMA view operations for `TensorMap` — `cubecl_std/tensor/view/operations/tensor_map.rs`

1use super::*;
2use crate::tensor::layout::*;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
5
// We don't know the linear layout, so only implement TMA loads/stores
//
// Generates the `ViewOperations(Mut)` / `ViewOperations(Mut)Expand` impls for
// `TensorMap<T, Tiled>` at one fixed rank:
// - `$dim`: the rank (1-5), pasted into the `tma_load_*d` / `tma_store_*d` names.
// - `$coords`: the coordinate tuple type used as the view's position type.
// - `$var`: one identifier per coordinate, used to destructure the position tuple.
//
// Element-wise reads/writes and slicing are `unimplemented!`: data can only move
// through a tensor map via bulk TMA transfers between global and shared memory.
macro_rules! impl_tensor_map {
    ($dim: literal, $coords: ty, $($var: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for ExpandElementTyped<TensorMap<T, Tiled>> {
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> ExpandElementTyped<bool> {
                    // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
                    true.into()
                }

                // Expands an async bulk load global -> shared, completed through `barrier`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    // Reinterpret the destination slice for the raw TMA transfer.
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    // Unpack the coordinate tuple and cast each component to i32,
                    // the index type the TMA intrinsics take.
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    barrier.[<__expand_tma_load_ $dim d_method>]::<T>(scope, self.clone(), shared, $($var),*);
                }
            }

            impl<T: CubePrimitive> ViewOperationsMut<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsMutExpand<T, $coords> for ExpandElementTyped<TensorMap<T, Tiled>> {
                fn __expand_write_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_write_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_to_linear_slice_mut_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadWrite> {
                    unimplemented!("Can't write to tensor map");
                }

                // Expands a bulk store shared -> global at the given coordinates.
                #[allow(unused_parens)]
                fn __expand_tensor_map_store_method(
                    &self,
                    scope: &mut Scope,
                    shared_memory: SliceExpand<T, ReadOnly>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    // Reinterpret the source slice for the raw TMA transfer.
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    // Coordinates are cast to i32, as for loads.
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    [<tma_store_ $dim d>]::expand(scope, shared, self.clone(), $($var),*);
                }
            }
        }
    };
}
127
// Fixed-rank tiled impls for unsigned (`Coords*d`) and signed (`Coords*i`)
// coordinate tuples, ranks 1 through 5.
impl_tensor_map!(1, Coords1d, x);
impl_tensor_map!(2, Coords2d, x, y);
impl_tensor_map!(3, Coords3d, x, y, z);
impl_tensor_map!(4, Coords4d, x, y, z, v);
impl_tensor_map!(5, Coords5d, x, y, z, v, w);

impl_tensor_map!(1, Coords1i, x);
impl_tensor_map!(2, Coords2i, x, y);
impl_tensor_map!(3, Coords3i, x, y, z);
impl_tensor_map!(4, Coords4i, x, y, z, v);
impl_tensor_map!(5, Coords5i, x, y, z, v, w);
139
// We don't know the linear layout, so only implement TMA loads
//
// Generates `ViewOperations` / `ViewOperationsExpand` impls for
// `TensorMap<T, Im2col>` at one fixed rank:
// - `$dim`: the rank (3-5), pasted into the `tma_load_im2col_*d` intrinsic name.
// - `$coords`: a pair type `(base coordinates, kernel offsets)`.
// - `$pos` / `$offs`: identifiers used to destructure the two tuples.
//
// No `ViewOperationsMut` impl is generated: im2col maps are load-only here.
macro_rules! impl_tensor_map_im2col {
    ($dim: literal, $coords: ty, $($pos: ident),*; $($offs: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Im2col> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for ExpandElementTyped<TensorMap<T, Im2col>> {
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> ExpandElementTyped<bool> {
                    // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
                    true.into()
                }

                // Expands an async im2col bulk load global -> shared, completed
                // through `barrier`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    // Reinterpret the destination slice for the raw TMA transfer.
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    // `pos.0`: base coordinates, cast to i32.
                    let ($($pos),*) = pos.0;
                    let ($($pos),*) = ($(i32::__expand_cast_from(scope, $pos)),*);
                    // `pos.1`: intra-kernel offsets, cast to u16 as the im2col
                    // intrinsic requires.
                    let ($($offs),*) = pos.1;
                    let ($($offs),*) = ($(u16::__expand_cast_from(scope, $offs)),*);

                    barrier.[<__expand_tma_load_im2col_ $dim d_method>]::<T>(scope, self.clone(), shared, $($pos),*, $($offs),*);
                }
            }
        }
    };
}
221
// Fixed-rank im2col impls, ranks 3-5, for unsigned and signed base coordinates.
// Identifier names suggest N(D)(H)WC ordering (batch, spatial dims, channels)
// with one kernel offset per spatial dimension — confirm against the im2col
// tensor-map descriptor.
impl_tensor_map_im2col!(3, (Coords3d, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4d, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5d, Coords3d), n, d, h, w, c; z, y, x);

impl_tensor_map_im2col!(3, (Coords3i, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4i, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5i, Coords3d), n, d, h, w, c; z, y, x);
229
230fn as_i32<T: CubePrimitive>(
231    scope: &mut Scope,
232    pos: &SequenceExpand<T>,
233    i: u32,
234) -> ExpandElementTyped<i32> {
235    let x = pos.__expand_index_method(scope, i.into());
236    i32::__expand_cast_from(scope, x)
237}
238
239fn as_u16<T: CubePrimitive>(
240    scope: &mut Scope,
241    offs: &SequenceExpand<T>,
242    i: u32,
243) -> ExpandElementTyped<u16> {
244    let x = offs.__expand_index_method(scope, i.into());
245    u16::__expand_cast_from(scope, x)
246}
247
// Marker impl: runtime-rank (`Sequence`) coordinates for tiled tensor maps.
// All behavior lives in the corresponding `ViewOperationsExpand` impl.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperations<T, Sequence<N>>
    for TensorMap<T, Tiled>
{
}
// `Sequence`-coordinate (rank chosen at expansion time) expansion for tiled
// tensor maps. As with the fixed-rank impls, only bulk TMA loads are supported;
// element-wise reads and slicing are `unimplemented!`.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsExpand<T, Sequence<N>>
    for ExpandElementTyped<TensorMap<T, Tiled>>
{
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _end: SequenceExpand<N>,
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> SequenceExpand<N> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> ExpandElementTyped<bool> {
        // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
        true.into()
    }

    // Dispatches on the sequence length (known at expansion time) to the
    // matching fixed-rank TMA load intrinsic. Elements are cast to i32 in
    // sequence order and forwarded in that same order.
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: SequenceExpand<N>,
    ) {
        // Reinterpret the destination slice for the raw TMA transfer.
        let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
        let rank = pos.len();
        let pos = &pos;
        match rank {
            1 => {
                let x = as_i32(scope, pos, 0);
                barrier.__expand_tma_load_1d_method(scope, self.clone(), shared, x);
            }
            2 => {
                let y = as_i32(scope, pos, 0);
                let x = as_i32(scope, pos, 1);
                barrier.__expand_tma_load_2d_method(scope, self.clone(), shared, y, x);
            }
            3 => {
                let z = as_i32(scope, pos, 0);
                let y = as_i32(scope, pos, 1);
                let x = as_i32(scope, pos, 2);
                barrier.__expand_tma_load_3d_method(scope, self.clone(), shared, z, y, x);
            }
            4 => {
                let w = as_i32(scope, pos, 0);
                let z = as_i32(scope, pos, 1);
                let y = as_i32(scope, pos, 2);
                let x = as_i32(scope, pos, 3);
                barrier.__expand_tma_load_4d_method(scope, self.clone(), shared, w, z, y, x);
            }
            5 => {
                let v = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let z = as_i32(scope, pos, 2);
                let y = as_i32(scope, pos, 3);
                let x = as_i32(scope, pos, 4);
                barrier.__expand_tma_load_5d_method(scope, self.clone(), shared, v, w, z, y, x);
            }
            _ => panic!("TMA only supports 1D-5D loads"),
        }
    }
}
356
// Marker impl: runtime-rank (`Sequence`) coordinates for tiled tensor-map
// stores. All behavior lives in the corresponding `ViewOperationsMutExpand` impl.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMut<T, Sequence<N>>
    for TensorMap<T, Tiled>
{
}
361impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMutExpand<T, Sequence<N>>
362    for ExpandElementTyped<TensorMap<T, Tiled>>
363{
364    fn __expand_write_method(
365        &self,
366        _scope: &mut Scope,
367        _pos: SequenceExpand<N>,
368        _value: <T as CubeType>::ExpandType,
369    ) {
370        unimplemented!("Can't write to tensor map");
371    }
372
373    fn __expand_write_checked_method(
374        &self,
375        _scope: &mut Scope,
376        _pos: SequenceExpand<N>,
377        _value: <T as CubeType>::ExpandType,
378    ) {
379        unimplemented!("Can't write to tensor map");
380    }
381
382    fn __expand_to_linear_slice_mut_method(
383        &self,
384        _scope: &mut Scope,
385        _pos: SequenceExpand<N>,
386        _end: SequenceExpand<N>,
387    ) -> SliceExpand<T, ReadWrite> {
388        unimplemented!("Can't write to tensor map");
389    }
390
391    #[allow(unused_parens)]
392    fn __expand_tensor_map_store_method(
393        &self,
394        scope: &mut Scope,
395        shared_memory: SliceExpand<T, ReadOnly>,
396        pos: SequenceExpand<N>,
397    ) {
398        let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
399        let rank = pos.len();
400        let pos = &pos;
401        match rank {
402            1 => {
403                let x = as_i32(scope, pos, 0);
404                tma_store_1d::expand(scope, shared, self.clone(), x);
405            }
406            2 => {
407                let y = as_i32(scope, pos, 0);
408                let x = as_i32(scope, pos, 1);
409                tma_store_2d::expand(scope, shared, self.clone(), y, x);
410            }
411            3 => {
412                let z = as_i32(scope, pos, 0);
413                let y = as_i32(scope, pos, 1);
414                let x = as_i32(scope, pos, 2);
415                tma_store_3d::expand(scope, shared, self.clone(), z, y, x);
416            }
417            4 => {
418                let w = as_i32(scope, pos, 0);
419                let z = as_i32(scope, pos, 1);
420                let y = as_i32(scope, pos, 2);
421                let x = as_i32(scope, pos, 3);
422                tma_store_4d::expand(scope, shared, self.clone(), w, z, y, x);
423            }
424            5 => {
425                let v = as_i32(scope, pos, 0);
426                let w = as_i32(scope, pos, 1);
427                let z = as_i32(scope, pos, 2);
428                let y = as_i32(scope, pos, 3);
429                let x = as_i32(scope, pos, 4);
430                tma_store_5d::expand(scope, shared, self.clone(), v, w, z, y, x);
431            }
432            _ => panic!("TMA store supports 1D-5D loads"),
433        }
434    }
435}
436
// Marker impl: runtime-rank (`Sequence`) base coordinates plus kernel offsets
// for im2col tensor maps. All behavior lives in the corresponding
// `ViewOperationsExpand` impl.
impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
    ViewOperations<T, (Sequence<P>, Sequence<O>)> for TensorMap<T, Im2col>
{
}
// `Sequence`-coordinate im2col expansion. The position type is a pair:
// `.0` holds the base coordinates (cast to i32), `.1` holds the intra-kernel
// offsets (cast to u16). Only bulk TMA im2col loads are supported.
impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
    ViewOperationsExpand<T, (Sequence<P>, Sequence<O>)>
    for ExpandElementTyped<TensorMap<T, Im2col>>
{
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
        _end: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> (SequenceExpand<P>, SequenceExpand<O>) {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> ExpandElementTyped<bool> {
        // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
        true.into()
    }

    // Dispatches on the base-coordinate sequence length (known at expansion
    // time) to the matching fixed-rank im2col load. The offset sequence is
    // expected to hold `rank - 2` elements (one per spatial dimension) —
    // shorter sequences would panic inside `as_u16`'s index expansion.
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) {
        // Reinterpret the destination slice for the raw TMA transfer.
        let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
        let (pos, offs) = &pos;
        let rank = pos.len();

        match rank {
            3 => {
                let n = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let c = as_i32(scope, pos, 2);
                let x = as_u16(scope, offs, 0);
                barrier.__expand_tma_load_im2col_3d_method(scope, self.clone(), shared, n, w, c, x);
            }
            4 => {
                let n = as_i32(scope, pos, 0);
                let h = as_i32(scope, pos, 1);
                let w = as_i32(scope, pos, 2);
                let c = as_i32(scope, pos, 3);
                let y = as_u16(scope, offs, 0);
                let x = as_u16(scope, offs, 1);
                barrier.__expand_tma_load_im2col_4d_method(
                    scope,
                    self.clone(),
                    shared,
                    n,
                    h,
                    w,
                    c,
                    y,
                    x,
                );
            }
            5 => {
                let n = as_i32(scope, pos, 0);
                let d = as_i32(scope, pos, 1);
                let h = as_i32(scope, pos, 2);
                let w = as_i32(scope, pos, 3);
                let c = as_i32(scope, pos, 4);
                let z = as_u16(scope, offs, 0);
                let y = as_u16(scope, offs, 1);
                let x = as_u16(scope, offs, 2);
                barrier.__expand_tma_load_im2col_5d_method(
                    scope,
                    self.clone(),
                    shared,
                    n,
                    d,
                    h,
                    w,
                    c,
                    z,
                    y,
                    x,
                );
            }
            _ => panic!("TMA im2col only supports 3D-5D loads"),
        }
    }
}
565}