//! View-operation implementations for `TensorMap` (TMA) backed tensor views.
//! Source path: cubecl_std/tensor/view/operations/tensor_map.rs
1use super::*;
2use crate::tensor::layout::*;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
5
// We don't know the linear layout, so only implement TMA loads/stores
//
// Generates fixed-rank `ViewOperations`/`ViewOperationsMut` impls (and their
// expand counterparts) for `TensorMap<T, Tiled>`:
//   `$dim`    - the rank, as a literal; selects the `tma_load_$dim d` /
//               `tma_store_$dim d` intrinsics via `paste!`
//   `$coords` - the coordinate tuple type for that rank
//   `$var`    - one identifier per coordinate, used to destructure `$coords`
macro_rules! impl_tensor_map {
    ($dim: literal, $coords: ty, $($var: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for NativeExpand<TensorMap<T, Tiled>> {
                // Element-wise reads are impossible through a tensor map, so every
                // read method is unimplemented; only the bulk TMA load below works.
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // No linear layout is known for a tensor map, so slicing is unsupported.
                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> NativeExpand<bool> {
                    // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
                    true.into()
                }

                // Expands a bulk TMA load of the tile at `pos` into `shared_memory`,
                // going through `barrier`. Coordinates are cast to `i32`, which is
                // what the TMA load method takes.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_downcast_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    barrier.[<__expand_tma_load_ $dim d_method>]::<T, T>(scope, self.clone(), shared, $($var),*);
                }
            }

            impl<T: CubePrimitive> ViewOperationsMut<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsMutExpand<T, $coords> for NativeExpand<TensorMap<T, Tiled>> {
                // Element-wise writes are impossible through a tensor map; only the
                // bulk TMA store below works.
                fn __expand_write_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_write_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_to_linear_slice_mut_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadWrite> {
                    unimplemented!("Can't write to tensor map");
                }

                // Expands a bulk TMA store of `shared_memory` back to the tile at `pos`;
                // coordinates are cast to `i32` like on the load side.
                #[allow(unused_parens)]
                fn __expand_tensor_map_store_method(
                    &self,
                    scope: &mut Scope,
                    shared_memory: SliceExpand<T, ReadOnly>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_downcast_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    [<tma_store_ $dim d>]::expand::<T, T>(scope, shared, self.clone(), $($var),*);
                }
            }
        }
    };
}
127
// Tiled impls for the `Coords*d` coordinate tuple types, ranks 1-5.
impl_tensor_map!(1, Coords1d, x);
impl_tensor_map!(2, Coords2d, x, y);
impl_tensor_map!(3, Coords3d, x, y, z);
impl_tensor_map!(4, Coords4d, x, y, z, v);
impl_tensor_map!(5, Coords5d, x, y, z, v, w);

// Same impls for the `Coords*i` coordinate tuple types.
impl_tensor_map!(1, Coords1i, x);
impl_tensor_map!(2, Coords2i, x, y);
impl_tensor_map!(3, Coords3i, x, y, z);
impl_tensor_map!(4, Coords4i, x, y, z, v);
impl_tensor_map!(5, Coords5i, x, y, z, v, w);
139
// We don't know the linear layout, so only implement TMA loads
//
// Generates fixed-rank, read-only `ViewOperations` impls for
// `TensorMap<T, Im2col>`. The coordinate type is a pair of tuples: the `$pos`
// idents destructure the base position and the `$offs` idents destructure the
// im2col kernel offsets. `$dim` selects the `tma_load_im2col_$dim d` method.
macro_rules! impl_tensor_map_im2col {
    ($dim: literal, $coords: ty, $($pos: ident),*; $($offs: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Im2col> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for NativeExpand<TensorMap<T, Im2col>> {
                // Element-wise reads are impossible through a tensor map; only
                // the bulk im2col TMA load below is supported.
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // No linear layout is known for a tensor map, so slicing is unsupported.
                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> NativeExpand<bool> {
                    // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
                    true.into()
                }

                // Expands a bulk im2col TMA load of the tile at `pos` into
                // `shared_memory`, going through `barrier`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_downcast_method(scope);
                    // `pos.0` holds the base position (cast to `i32`); `pos.1`
                    // holds the kernel offsets (cast to `u16`), the types the
                    // im2col load method takes.
                    let ($($pos),*) = pos.0;
                    let ($($pos),*) = ($(i32::__expand_cast_from(scope, $pos)),*);
                    let ($($offs),*) = pos.1;
                    let ($($offs),*) = ($(u16::__expand_cast_from(scope, $offs)),*);

                    barrier.[<__expand_tma_load_im2col_ $dim d_method>]::<T, T>(scope, self.clone(), shared, $($pos),*, $($offs),*);
                }
            }
        }
    };
}
221
// Im2col impls: rank 3-5 base positions (n, [d, h,] w, c) paired with 1-3
// kernel offsets, for both `Coords*d` and `Coords*i` position tuples.
impl_tensor_map_im2col!(3, (Coords3d, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4d, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5d, Coords3d), n, d, h, w, c; z, y, x);

impl_tensor_map_im2col!(3, (Coords3i, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4i, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5i, Coords3d), n, d, h, w, c; z, y, x);
229
230fn as_i32<T: CubePrimitive>(
231    scope: &mut Scope,
232    pos: &SequenceExpand<T>,
233    i: usize,
234) -> NativeExpand<i32> {
235    let x = pos.__expand_index_method(scope, i);
236    i32::__expand_cast_from(scope, x)
237}
238
239fn as_u16<T: CubePrimitive>(
240    scope: &mut Scope,
241    offs: &SequenceExpand<T>,
242    i: usize,
243) -> NativeExpand<u16> {
244    let x = offs.__expand_index_method(scope, i);
245    u16::__expand_cast_from(scope, x)
246}
247
// Tiled tensor maps also accept variable-length `Sequence` coordinates; the
// actual rank is dispatched in the expand impl below.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperations<T, Sequence<N>>
    for TensorMap<T, Tiled>
{
}
// Read-side expand impl for tiled tensor maps with `Sequence` coordinates.
// Element-wise reads and slicing are unsupported; only bulk TMA loads work.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsExpand<T, Sequence<N>>
    for NativeExpand<TensorMap<T, Tiled>>
{
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    // No linear layout is known for a tensor map, so slicing is unsupported.
    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _end: SequenceExpand<N>,
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> SequenceExpand<N> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> NativeExpand<bool> {
        // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
        true.into()
    }

    // Expands a bulk TMA load of the tile at `pos` into `shared_memory`. The
    // sequence length is available at expansion time, so we can dispatch to the
    // matching fixed-rank load method; elements are read in order (index 0
    // first) and cast to `i32` via `as_i32`.
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: SequenceExpand<N>,
    ) {
        let shared: SliceExpand<T, ReadWrite> =
            shared_memory.__expand_downcast_unchecked_method(scope);
        let rank = pos.len();
        let pos = &pos;
        match rank {
            1 => {
                let x = as_i32(scope, pos, 0);
                barrier.__expand_tma_load_1d_method(scope, self.clone(), shared, x);
            }
            2 => {
                let y = as_i32(scope, pos, 0);
                let x = as_i32(scope, pos, 1);
                barrier.__expand_tma_load_2d_method(scope, self.clone(), shared, y, x);
            }
            3 => {
                let z = as_i32(scope, pos, 0);
                let y = as_i32(scope, pos, 1);
                let x = as_i32(scope, pos, 2);
                barrier.__expand_tma_load_3d_method(scope, self.clone(), shared, z, y, x);
            }
            4 => {
                let w = as_i32(scope, pos, 0);
                let z = as_i32(scope, pos, 1);
                let y = as_i32(scope, pos, 2);
                let x = as_i32(scope, pos, 3);
                barrier.__expand_tma_load_4d_method(scope, self.clone(), shared, w, z, y, x);
            }
            5 => {
                let v = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let z = as_i32(scope, pos, 2);
                let y = as_i32(scope, pos, 3);
                let x = as_i32(scope, pos, 4);
                barrier.__expand_tma_load_5d_method(scope, self.clone(), shared, v, w, z, y, x);
            }
            _ => panic!("TMA only supports 1D-5D loads"),
        }
    }
}
357
// Write side of the `Sequence`-coordinate support for tiled tensor maps; the
// actual rank is dispatched in the expand impl below.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMut<T, Sequence<N>>
    for TensorMap<T, Tiled>
{
}
362impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMutExpand<T, Sequence<N>>
363    for NativeExpand<TensorMap<T, Tiled>>
364{
365    fn __expand_write_method(
366        &self,
367        _scope: &mut Scope,
368        _pos: SequenceExpand<N>,
369        _value: <T as CubeType>::ExpandType,
370    ) {
371        unimplemented!("Can't write to tensor map");
372    }
373
374    fn __expand_write_checked_method(
375        &self,
376        _scope: &mut Scope,
377        _pos: SequenceExpand<N>,
378        _value: <T as CubeType>::ExpandType,
379    ) {
380        unimplemented!("Can't write to tensor map");
381    }
382
383    fn __expand_to_linear_slice_mut_method(
384        &self,
385        _scope: &mut Scope,
386        _pos: SequenceExpand<N>,
387        _end: SequenceExpand<N>,
388    ) -> SliceExpand<T, ReadWrite> {
389        unimplemented!("Can't write to tensor map");
390    }
391
392    #[allow(unused_parens)]
393    fn __expand_tensor_map_store_method(
394        &self,
395        scope: &mut Scope,
396        shared_memory: SliceExpand<T, ReadOnly>,
397        pos: SequenceExpand<N>,
398    ) {
399        let shared: SliceExpand<T, ReadOnly> =
400            shared_memory.__expand_downcast_unchecked_method(scope);
401        let rank = pos.len();
402        let pos = &pos;
403        match rank {
404            1 => {
405                let x = as_i32(scope, pos, 0);
406                tma_store_1d::expand(scope, shared, self.clone(), x);
407            }
408            2 => {
409                let y = as_i32(scope, pos, 0);
410                let x = as_i32(scope, pos, 1);
411                tma_store_2d::expand(scope, shared, self.clone(), y, x);
412            }
413            3 => {
414                let z = as_i32(scope, pos, 0);
415                let y = as_i32(scope, pos, 1);
416                let x = as_i32(scope, pos, 2);
417                tma_store_3d::expand(scope, shared, self.clone(), z, y, x);
418            }
419            4 => {
420                let w = as_i32(scope, pos, 0);
421                let z = as_i32(scope, pos, 1);
422                let y = as_i32(scope, pos, 2);
423                let x = as_i32(scope, pos, 3);
424                tma_store_4d::expand(scope, shared, self.clone(), w, z, y, x);
425            }
426            5 => {
427                let v = as_i32(scope, pos, 0);
428                let w = as_i32(scope, pos, 1);
429                let z = as_i32(scope, pos, 2);
430                let y = as_i32(scope, pos, 3);
431                let x = as_i32(scope, pos, 4);
432                tma_store_5d::expand(scope, shared, self.clone(), v, w, z, y, x);
433            }
434            _ => panic!("TMA store supports 1D-5D loads"),
435        }
436    }
437}
438
// Im2col tensor maps with variable-length `Sequence` coordinates: the position
// is a pair of (base position, kernel offsets). Read-only; the actual rank is
// dispatched in the expand impl below.
impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
    ViewOperations<T, (Sequence<P>, Sequence<O>)> for TensorMap<T, Im2col>
{
}
// Read-side expand impl for im2col tensor maps with `Sequence` coordinates.
// Element-wise reads and slicing are unsupported; only bulk im2col TMA loads work.
impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
    ViewOperationsExpand<T, (Sequence<P>, Sequence<O>)> for NativeExpand<TensorMap<T, Im2col>>
{
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    // No linear layout is known for a tensor map, so slicing is unsupported.
    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
        _end: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> (SequenceExpand<P>, SequenceExpand<O>) {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> NativeExpand<bool> {
        // Bounds checks are done in hardware, so treat them as always in bounds for the kernels
        true.into()
    }

    // Expands a bulk im2col TMA load into `shared_memory`. The base-position
    // sequence length is available at expansion time and selects the 3D/4D/5D
    // variant; base coordinates are cast to `i32` and kernel offsets to `u16`.
    // Note: only `pos` (not `offs`) determines the rank here.
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) {
        let shared: SliceExpand<T, ReadWrite> =
            shared_memory.__expand_downcast_unchecked_method(scope);
        let (pos, offs) = &pos;
        let rank = pos.len();

        match rank {
            3 => {
                let n = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let c = as_i32(scope, pos, 2);
                let x = as_u16(scope, offs, 0);
                barrier.__expand_tma_load_im2col_3d_method(scope, self.clone(), shared, n, w, c, x);
            }
            4 => {
                let n = as_i32(scope, pos, 0);
                let h = as_i32(scope, pos, 1);
                let w = as_i32(scope, pos, 2);
                let c = as_i32(scope, pos, 3);
                let y = as_u16(scope, offs, 0);
                let x = as_u16(scope, offs, 1);
                barrier.__expand_tma_load_im2col_4d_method(
                    scope,
                    self.clone(),
                    shared,
                    n,
                    h,
                    w,
                    c,
                    y,
                    x,
                );
            }
            5 => {
                let n = as_i32(scope, pos, 0);
                let d = as_i32(scope, pos, 1);
                let h = as_i32(scope, pos, 2);
                let w = as_i32(scope, pos, 3);
                let c = as_i32(scope, pos, 4);
                let z = as_u16(scope, offs, 0);
                let y = as_u16(scope, offs, 1);
                let x = as_u16(scope, offs, 2);
                barrier.__expand_tma_load_im2col_5d_method(
                    scope,
                    self.clone(),
                    shared,
                    n,
                    d,
                    h,
                    w,
                    c,
                    z,
                    y,
                    x,
                );
            }
            _ => panic!("TMA im2col only supports 3D-5D loads"),
        }
    }
}