1use super::*;
2use crate::tensor::layout::*;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
5
/// Implements the read (`ViewOperations[Expand]`) and write
/// (`ViewOperationsMut[Expand]`) view traits for a tiled `TensorMap` at one
/// fixed rank, keyed by a tuple coordinate type.
///
/// A tensor map is an opaque TMA descriptor: element-wise reads/writes and
/// slicing have no meaning for it, so those methods `unimplemented!` at
/// expansion (kernel compile) time. Only the bulk `tensor_map_load` /
/// `tensor_map_store` transfers are real operations.
///
/// * `$dim`    - rank (1..=5); selects the `tma_load_*d` / `tma_store_*d` intrinsic
///   via `paste!` identifier concatenation.
/// * `$coords` - the coordinate tuple type (e.g. `Coords3d`).
/// * `$var`    - one identifier per coordinate component, used to destructure `pos`.
macro_rules! impl_tensor_map {
    ($dim: literal, $coords: ty, $($var: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for NativeExpand<TensorMap<T, Tiled>> {
                // --- Element reads/slices: unsupported on a TMA descriptor. ---
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // Always expands to a constant `true`: bounds handling is left
                // to the TMA transfer itself, not checked per-position here.
                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> NativeExpand<bool> {
                    true.into()
                }

                // Expands an async bulk load through `barrier` into the
                // `shared_memory` slice, starting at `pos`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_downcast_method(scope);
                    // TMA coordinates are signed 32-bit: cast every component.
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    barrier.[<__expand_tma_load_ $dim d_method>]::<T, T>(scope, self.clone(), shared, $($var),*);
                }
            }

            impl<T: CubePrimitive> ViewOperationsMut<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsMutExpand<T, $coords> for NativeExpand<TensorMap<T, Tiled>> {
                // --- Element writes/slices: unsupported on a TMA descriptor. ---
                fn __expand_write_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_write_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_to_linear_slice_mut_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadWrite> {
                    unimplemented!("Can't write to tensor map");
                }

                // Expands a bulk store of the `shared_memory` slice out through
                // the tensor map at `pos` (no barrier needed for stores).
                #[allow(unused_parens)]
                fn __expand_tensor_map_store_method(
                    &self,
                    scope: &mut Scope,
                    shared_memory: SliceExpand<T, ReadOnly>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_downcast_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    [<tma_store_ $dim d>]::expand::<T, T>(scope, shared, self.clone(), $($var),*);
                }
            }
        }
    };
}
127
// Tiled tensor-map views for ranks 1-5 with the unsigned coordinate tuples.
impl_tensor_map!(1, Coords1d, x);
impl_tensor_map!(2, Coords2d, x, y);
impl_tensor_map!(3, Coords3d, x, y, z);
impl_tensor_map!(4, Coords4d, x, y, z, v);
impl_tensor_map!(5, Coords5d, x, y, z, v, w);

// Same ranks with the signed (`Coords*i`) coordinate tuples.
impl_tensor_map!(1, Coords1i, x);
impl_tensor_map!(2, Coords2i, x, y);
impl_tensor_map!(3, Coords3i, x, y, z);
impl_tensor_map!(4, Coords4i, x, y, z, v);
impl_tensor_map!(5, Coords5i, x, y, z, v, w);
139
/// Implements the read-side view traits for an im2col `TensorMap` at one fixed
/// rank. Coordinates are a pair: `pos.0` is the base position tuple and
/// `pos.1` the kernel-offset tuple.
///
/// Like the tiled variant, element access is meaningless on the descriptor and
/// panics at expansion time; only the bulk im2col load exists. No
/// `ViewOperationsMut` is generated - im2col maps are load-only here.
///
/// * `$dim`    - rank (3..=5); selects `tma_load_im2col_*d` via `paste!`.
/// * `$coords` - the `(positions, offsets)` coordinate pair type.
/// * `$pos`    - identifiers destructuring `pos.0` (cast to `i32`).
/// * `$offs`   - identifiers destructuring `pos.1` (cast to `u16`).
macro_rules! impl_tensor_map_im2col {
    ($dim: literal, $coords: ty, $($pos: ident),*; $($offs: ident),*) => {
        paste::paste! {
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Im2col> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for NativeExpand<TensorMap<T, Im2col>> {
                // --- Element reads/slices: unsupported on a TMA descriptor. ---
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // Always expands to a constant `true`; bounds handling is left
                // to the TMA transfer itself.
                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> NativeExpand<bool> {
                    true.into()
                }

                // Expands an async im2col bulk load through `barrier` into the
                // `shared_memory` slice. Base positions become `i32`, kernel
                // offsets become `u16`, matching the intrinsic's signature.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_downcast_method(scope);
                    let ($($pos),*) = pos.0;
                    let ($($pos),*) = ($(i32::__expand_cast_from(scope, $pos)),*);
                    let ($($offs),*) = pos.1;
                    let ($($offs),*) = ($(u16::__expand_cast_from(scope, $offs)),*);

                    barrier.[<__expand_tma_load_im2col_ $dim d_method>]::<T, T>(scope, self.clone(), shared, $($pos),*, $($offs),*);
                }
            }
        }
    };
}
221
// Im2col views for ranks 3-5: `(positions, offsets)` coordinate pairs.
// Identifier names suggest N(D)HWC-style layouts (n, [d, h,] w, c) with one
// kernel offset per spatial dimension - TODO confirm against the callers.
impl_tensor_map_im2col!(3, (Coords3d, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4d, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5d, Coords3d), n, d, h, w, c; z, y, x);

// Signed-coordinate variants of the same ranks.
impl_tensor_map_im2col!(3, (Coords3i, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4i, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5i, Coords3d), n, d, h, w, c; z, y, x);
229
230fn as_i32<T: CubePrimitive>(
231 scope: &mut Scope,
232 pos: &SequenceExpand<T>,
233 i: usize,
234) -> NativeExpand<i32> {
235 let x = pos.__expand_index_method(scope, i);
236 i32::__expand_cast_from(scope, x)
237}
238
239fn as_u16<T: CubePrimitive>(
240 scope: &mut Scope,
241 offs: &SequenceExpand<T>,
242 i: usize,
243) -> NativeExpand<u16> {
244 let x = offs.__expand_index_method(scope, i);
245 u16::__expand_cast_from(scope, x)
246}
247
// Marker impl: enables rank-generic (runtime-length `Sequence`) coordinates on
// a tiled tensor map. The actual behavior lives in the matching `Expand` impl.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperations<T, Sequence<N>>
    for TensorMap<T, Tiled>
{
}
/// Rank-generic read-side expansion for a tiled tensor map, dispatching on the
/// runtime length of the coordinate `Sequence` instead of a fixed tuple rank.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsExpand<T, Sequence<N>>
    for NativeExpand<TensorMap<T, Tiled>>
{
    // --- Element reads/slices: unsupported on a TMA descriptor. ---
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
        _end: SequenceExpand<N>,
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> SequenceExpand<N> {
        unimplemented!("Can't read from tensor map");
    }

    // Always expands to a constant `true`; bounds handling is left to the TMA
    // transfer itself.
    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: SequenceExpand<N>,
    ) -> NativeExpand<bool> {
        true.into()
    }

    // Expands an async bulk load into `shared_memory`. The sequence length is
    // known at expansion time, so we select the fixed-rank intrinsic here;
    // elements are forwarded in sequence order (index 0 first) as `i32`.
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: SequenceExpand<N>,
    ) {
        let shared: SliceExpand<T, ReadWrite> =
            shared_memory.__expand_downcast_unchecked_method(scope);
        let rank = pos.len();
        let pos = &pos;
        match rank {
            1 => {
                let x = as_i32(scope, pos, 0);
                barrier.__expand_tma_load_1d_method(scope, self.clone(), shared, x);
            }
            2 => {
                let y = as_i32(scope, pos, 0);
                let x = as_i32(scope, pos, 1);
                barrier.__expand_tma_load_2d_method(scope, self.clone(), shared, y, x);
            }
            3 => {
                let z = as_i32(scope, pos, 0);
                let y = as_i32(scope, pos, 1);
                let x = as_i32(scope, pos, 2);
                barrier.__expand_tma_load_3d_method(scope, self.clone(), shared, z, y, x);
            }
            4 => {
                let w = as_i32(scope, pos, 0);
                let z = as_i32(scope, pos, 1);
                let y = as_i32(scope, pos, 2);
                let x = as_i32(scope, pos, 3);
                barrier.__expand_tma_load_4d_method(scope, self.clone(), shared, w, z, y, x);
            }
            5 => {
                let v = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let z = as_i32(scope, pos, 2);
                let y = as_i32(scope, pos, 3);
                let x = as_i32(scope, pos, 4);
                barrier.__expand_tma_load_5d_method(scope, self.clone(), shared, v, w, z, y, x);
            }
            _ => panic!("TMA only supports 1D-5D loads"),
        }
    }
}
357
// Marker impl: enables rank-generic (runtime-length `Sequence`) coordinates for
// writes on a tiled tensor map. The behavior lives in the matching `Expand` impl.
impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMut<T, Sequence<N>>
    for TensorMap<T, Tiled>
{
}
362impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMutExpand<T, Sequence<N>>
363 for NativeExpand<TensorMap<T, Tiled>>
364{
365 fn __expand_write_method(
366 &self,
367 _scope: &mut Scope,
368 _pos: SequenceExpand<N>,
369 _value: <T as CubeType>::ExpandType,
370 ) {
371 unimplemented!("Can't write to tensor map");
372 }
373
374 fn __expand_write_checked_method(
375 &self,
376 _scope: &mut Scope,
377 _pos: SequenceExpand<N>,
378 _value: <T as CubeType>::ExpandType,
379 ) {
380 unimplemented!("Can't write to tensor map");
381 }
382
383 fn __expand_to_linear_slice_mut_method(
384 &self,
385 _scope: &mut Scope,
386 _pos: SequenceExpand<N>,
387 _end: SequenceExpand<N>,
388 ) -> SliceExpand<T, ReadWrite> {
389 unimplemented!("Can't write to tensor map");
390 }
391
392 #[allow(unused_parens)]
393 fn __expand_tensor_map_store_method(
394 &self,
395 scope: &mut Scope,
396 shared_memory: SliceExpand<T, ReadOnly>,
397 pos: SequenceExpand<N>,
398 ) {
399 let shared: SliceExpand<T, ReadOnly> =
400 shared_memory.__expand_downcast_unchecked_method(scope);
401 let rank = pos.len();
402 let pos = &pos;
403 match rank {
404 1 => {
405 let x = as_i32(scope, pos, 0);
406 tma_store_1d::expand(scope, shared, self.clone(), x);
407 }
408 2 => {
409 let y = as_i32(scope, pos, 0);
410 let x = as_i32(scope, pos, 1);
411 tma_store_2d::expand(scope, shared, self.clone(), y, x);
412 }
413 3 => {
414 let z = as_i32(scope, pos, 0);
415 let y = as_i32(scope, pos, 1);
416 let x = as_i32(scope, pos, 2);
417 tma_store_3d::expand(scope, shared, self.clone(), z, y, x);
418 }
419 4 => {
420 let w = as_i32(scope, pos, 0);
421 let z = as_i32(scope, pos, 1);
422 let y = as_i32(scope, pos, 2);
423 let x = as_i32(scope, pos, 3);
424 tma_store_4d::expand(scope, shared, self.clone(), w, z, y, x);
425 }
426 5 => {
427 let v = as_i32(scope, pos, 0);
428 let w = as_i32(scope, pos, 1);
429 let z = as_i32(scope, pos, 2);
430 let y = as_i32(scope, pos, 3);
431 let x = as_i32(scope, pos, 4);
432 tma_store_5d::expand(scope, shared, self.clone(), v, w, z, y, x);
433 }
434 _ => panic!("TMA store supports 1D-5D loads"),
435 }
436 }
437}
438
// Marker impl: enables rank-generic `(positions, offsets)` sequence-pair
// coordinates on an im2col tensor map. Behavior lives in the `Expand` impl.
impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
    ViewOperations<T, (Sequence<P>, Sequence<O>)> for TensorMap<T, Im2col>
{
}
/// Rank-generic read-side expansion for an im2col tensor map. Coordinates are
/// a pair of sequences: base positions (`i32`) and kernel offsets (`u16`);
/// dispatch is on the runtime length of the position sequence. Load-only: no
/// `ViewOperationsMut` exists for im2col maps.
impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
    ViewOperationsExpand<T, (Sequence<P>, Sequence<O>)> for NativeExpand<TensorMap<T, Im2col>>
{
    // --- Element reads/slices: unsupported on a TMA descriptor. ---
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
        _mask_value: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> <T as CubeType>::ExpandType {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
        _end: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> SliceExpand<T, ReadOnly> {
        unimplemented!("Can't read from tensor map");
    }

    fn __expand_shape_method(&self, _scope: &mut Scope) -> (SequenceExpand<P>, SequenceExpand<O>) {
        unimplemented!("Can't read from tensor map");
    }

    // Always expands to a constant `true`; bounds handling is left to the TMA
    // transfer itself.
    fn __expand_is_in_bounds_method(
        &self,
        _scope: &mut Scope,
        _pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) -> NativeExpand<bool> {
        true.into()
    }

    // Expands an async im2col bulk load into `shared_memory`. Local names
    // mirror the fixed-rank macro invocations (n, [d, h,] w, c positions plus
    // one offset per spatial dimension).
    #[allow(unused_parens)]
    fn __expand_tensor_map_load_method(
        &self,
        scope: &mut Scope,
        barrier: BarrierExpand,
        shared_memory: SliceExpand<T, ReadWrite>,
        pos: (SequenceExpand<P>, SequenceExpand<O>),
    ) {
        let shared: SliceExpand<T, ReadWrite> =
            shared_memory.__expand_downcast_unchecked_method(scope);
        let (pos, offs) = &pos;
        let rank = pos.len();

        match rank {
            3 => {
                let n = as_i32(scope, pos, 0);
                let w = as_i32(scope, pos, 1);
                let c = as_i32(scope, pos, 2);
                let x = as_u16(scope, offs, 0);
                barrier.__expand_tma_load_im2col_3d_method(scope, self.clone(), shared, n, w, c, x);
            }
            4 => {
                let n = as_i32(scope, pos, 0);
                let h = as_i32(scope, pos, 1);
                let w = as_i32(scope, pos, 2);
                let c = as_i32(scope, pos, 3);
                let y = as_u16(scope, offs, 0);
                let x = as_u16(scope, offs, 1);
                barrier.__expand_tma_load_im2col_4d_method(
                    scope,
                    self.clone(),
                    shared,
                    n,
                    h,
                    w,
                    c,
                    y,
                    x,
                );
            }
            5 => {
                let n = as_i32(scope, pos, 0);
                let d = as_i32(scope, pos, 1);
                let h = as_i32(scope, pos, 2);
                let w = as_i32(scope, pos, 3);
                let c = as_i32(scope, pos, 4);
                let z = as_u16(scope, offs, 0);
                let y = as_u16(scope, offs, 1);
                let x = as_u16(scope, offs, 2);
                barrier.__expand_tma_load_im2col_5d_method(
                    scope,
                    self.clone(),
                    shared,
                    n,
                    d,
                    h,
                    w,
                    c,
                    z,
                    y,
                    x,
                );
            }
            _ => panic!("TMA im2col only supports 3D-5D loads"),
        }
    }
}