1use super::*;
2use crate::tensor::layout::*;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::BarrierExpand};
5
/// Implements the read-side and write-side view operations for a tiled
/// [`TensorMap`] addressed by a fixed-arity coordinate tuple.
///
/// * `$dim` - tensor rank (1-5); spliced into the `tma_load_{dim}d` /
///   `tma_store_{dim}d` intrinsic names via `paste!`.
/// * `$coords` - the coordinate tuple type for this rank.
/// * `$var` - one identifier per coordinate component, listed in the order the
///   components are forwarded to the TMA intrinsics.
///
/// A tensor map can only be transferred through TMA bulk copies, so every
/// element-wise read/write and slicing method is `unimplemented!`; only the
/// `tensor_map_load` / `tensor_map_store` expansions emit real IR.
macro_rules! impl_tensor_map {
    ($dim: literal, $coords: ty, $($var: ident),*) => {
        paste::paste! {
            // Marker impl: the non-expand trait adds no required methods here.
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for ExpandElementTyped<TensorMap<T, Tiled>> {
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // Always reports in-bounds: out-of-bounds handling is left to
                // the TMA copy itself rather than the view
                // (NOTE(review): confirm against the barrier intrinsics).
                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> ExpandElementTyped<bool> {
                    true.into()
                }

                // Expands an async TMA load of the tile at `pos` into
                // `shared_memory`, tracked by `barrier`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    // Destructure the coordinate tuple and cast each component
                    // to the `i32` index type expected by the TMA intrinsic.
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    barrier.[<__expand_tma_load_ $dim d_method>]::<T>(scope, self.clone(), shared, $($var),*);
                }
            }

            impl<T: CubePrimitive> ViewOperationsMut<T, $coords> for TensorMap<T, Tiled> {}
            impl<T: CubePrimitive> ViewOperationsMutExpand<T, $coords> for ExpandElementTyped<TensorMap<T, Tiled>> {
                fn __expand_write_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_write_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _value: <T as CubeType>::ExpandType,
                ) {
                    unimplemented!("Can't write to tensor map");
                }

                fn __expand_to_linear_slice_mut_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadWrite> {
                    unimplemented!("Can't write to tensor map");
                }

                // Expands a TMA store of `shared_memory` into the tile at `pos`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_store_method(
                    &self,
                    scope: &mut Scope,
                    shared_memory: SliceExpand<T, ReadOnly>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    let ($($var),*) = pos;
                    let ($($var),*) = ($(i32::__expand_cast_from(scope, $var)),*);
                    [<tma_store_ $dim d>]::expand(scope, shared, self.clone(), $($var),*);
                }
            }
        }
    };
}
127
// Tiled tensor maps indexed by the `Coords*d` coordinate tuples, ranks 1-5.
impl_tensor_map!(1, Coords1d, x);
impl_tensor_map!(2, Coords2d, x, y);
impl_tensor_map!(3, Coords3d, x, y, z);
impl_tensor_map!(4, Coords4d, x, y, z, v);
impl_tensor_map!(5, Coords5d, x, y, z, v, w);

// Same impls for the `Coords*i` tuples (presumably the signed counterparts —
// verify in `crate::tensor::layout`).
impl_tensor_map!(1, Coords1i, x);
impl_tensor_map!(2, Coords2i, x, y);
impl_tensor_map!(3, Coords3i, x, y, z);
impl_tensor_map!(4, Coords4i, x, y, z, v);
impl_tensor_map!(5, Coords5i, x, y, z, v, w);
139
/// Implements the read-side view operations for an im2col [`TensorMap`].
///
/// The coordinate type is a pair: the absolute tensor position (`$pos`
/// components, cast to `i32`) and the im2col kernel offsets (`$offs`
/// components, cast to `u16`), forwarded in that order to the
/// `tma_load_im2col_{dim}d` intrinsic selected by `$dim` via `paste!`.
///
/// Im2col maps are load-only here: all element-wise reads and slicing are
/// `unimplemented!`, and no `ViewOperationsMut` impl is generated.
macro_rules! impl_tensor_map_im2col {
    ($dim: literal, $coords: ty, $($pos: ident),*; $($offs: ident),*) => {
        paste::paste! {
            // Marker impl: the non-expand trait adds no required methods here.
            impl<T: CubePrimitive> ViewOperations<T, $coords> for TensorMap<T, Im2col> {}
            impl<T: CubePrimitive> ViewOperationsExpand<T, $coords> for ExpandElementTyped<TensorMap<T, Im2col>> {
                fn __expand_read_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_checked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_masked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _mask_value: <T as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_read_unchecked_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> <T as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_to_linear_slice_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                    _end: <$coords as CubeType>::ExpandType,
                ) -> SliceExpand<T, ReadOnly> {
                    unimplemented!("Can't read from tensor map");
                }

                fn __expand_shape_method(&self, _scope: &mut Scope) -> <$coords as CubeType>::ExpandType {
                    unimplemented!("Can't read from tensor map");
                }

                // Always reports in-bounds: out-of-bounds handling is left to
                // the TMA copy itself rather than the view
                // (NOTE(review): confirm against the barrier intrinsics).
                fn __expand_is_in_bounds_method(
                    &self,
                    _scope: &mut Scope,
                    _pos: <$coords as CubeType>::ExpandType,
                ) -> ExpandElementTyped<bool> {
                    true.into()
                }

                // Expands an async im2col TMA load into `shared_memory`,
                // tracked by `barrier`.
                #[allow(unused_parens)]
                fn __expand_tensor_map_load_method(
                    &self,
                    scope: &mut Scope,
                    barrier: BarrierExpand,
                    shared_memory: SliceExpand<T, ReadWrite>,
                    pos: <$coords as CubeType>::ExpandType,
                ) {
                    let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
                    // `pos.0`: absolute tensor coordinates, cast to `i32`.
                    let ($($pos),*) = pos.0;
                    let ($($pos),*) = ($(i32::__expand_cast_from(scope, $pos)),*);
                    // `pos.1`: im2col kernel offsets, cast to `u16`.
                    let ($($offs),*) = pos.1;
                    let ($($offs),*) = ($(u16::__expand_cast_from(scope, $offs)),*);

                    barrier.[<__expand_tma_load_im2col_ $dim d_method>]::<T>(scope, self.clone(), shared, $($pos),*, $($offs),*);
                }
            }
        }
    };
}
221
// Im2col maps take a coordinate pair: absolute position `(n, [d, h,] w, c)`
// plus per-load kernel offsets `([z, y,] x)` — two fewer offsets than the rank.
impl_tensor_map_im2col!(3, (Coords3d, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4d, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5d, Coords3d), n, d, h, w, c; z, y, x);

// Same impls with the `Coords*i` position tuples.
impl_tensor_map_im2col!(3, (Coords3i, Coords1d), n, w, c; x);
impl_tensor_map_im2col!(4, (Coords4i, Coords2d), n, h, w, c; y, x);
impl_tensor_map_im2col!(5, (Coords5i, Coords3d), n, d, h, w, c; z, y, x);
229
230fn as_i32<T: CubePrimitive>(
231 scope: &mut Scope,
232 pos: &SequenceExpand<T>,
233 i: u32,
234) -> ExpandElementTyped<i32> {
235 let x = pos.__expand_index_method(scope, i.into());
236 i32::__expand_cast_from(scope, x)
237}
238
239fn as_u16<T: CubePrimitive>(
240 scope: &mut Scope,
241 offs: &SequenceExpand<T>,
242 i: u32,
243) -> ExpandElementTyped<u16> {
244 let x = offs.__expand_index_method(scope, i.into());
245 u16::__expand_cast_from(scope, x)
246}
247
248impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperations<T, Sequence<N>>
249 for TensorMap<T, Tiled>
250{
251}
252impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsExpand<T, Sequence<N>>
253 for ExpandElementTyped<TensorMap<T, Tiled>>
254{
255 fn __expand_read_method(
256 &self,
257 _scope: &mut Scope,
258 _pos: SequenceExpand<N>,
259 ) -> <T as CubeType>::ExpandType {
260 unimplemented!("Can't read from tensor map");
261 }
262
263 fn __expand_read_checked_method(
264 &self,
265 _scope: &mut Scope,
266 _pos: SequenceExpand<N>,
267 ) -> <T as CubeType>::ExpandType {
268 unimplemented!("Can't read from tensor map");
269 }
270
271 fn __expand_read_masked_method(
272 &self,
273 _scope: &mut Scope,
274 _pos: SequenceExpand<N>,
275 _mask_value: <T as CubeType>::ExpandType,
276 ) -> <T as CubeType>::ExpandType {
277 unimplemented!("Can't read from tensor map");
278 }
279
280 fn __expand_read_unchecked_method(
281 &self,
282 _scope: &mut Scope,
283 _pos: SequenceExpand<N>,
284 ) -> <T as CubeType>::ExpandType {
285 unimplemented!("Can't read from tensor map");
286 }
287
288 fn __expand_to_linear_slice_method(
289 &self,
290 _scope: &mut Scope,
291 _pos: SequenceExpand<N>,
292 _end: SequenceExpand<N>,
293 ) -> SliceExpand<T, ReadOnly> {
294 unimplemented!("Can't read from tensor map");
295 }
296
297 fn __expand_shape_method(&self, _scope: &mut Scope) -> SequenceExpand<N> {
298 unimplemented!("Can't read from tensor map");
299 }
300
301 fn __expand_is_in_bounds_method(
302 &self,
303 _scope: &mut Scope,
304 _pos: SequenceExpand<N>,
305 ) -> ExpandElementTyped<bool> {
306 true.into()
308 }
309
310 #[allow(unused_parens)]
311 fn __expand_tensor_map_load_method(
312 &self,
313 scope: &mut Scope,
314 barrier: BarrierExpand,
315 shared_memory: SliceExpand<T, ReadWrite>,
316 pos: SequenceExpand<N>,
317 ) {
318 let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
319 let rank = pos.len();
320 let pos = &pos;
321 match rank {
322 1 => {
323 let x = as_i32(scope, pos, 0);
324 barrier.__expand_tma_load_1d_method(scope, self.clone(), shared, x);
325 }
326 2 => {
327 let y = as_i32(scope, pos, 0);
328 let x = as_i32(scope, pos, 1);
329 barrier.__expand_tma_load_2d_method(scope, self.clone(), shared, y, x);
330 }
331 3 => {
332 let z = as_i32(scope, pos, 0);
333 let y = as_i32(scope, pos, 1);
334 let x = as_i32(scope, pos, 2);
335 barrier.__expand_tma_load_3d_method(scope, self.clone(), shared, z, y, x);
336 }
337 4 => {
338 let w = as_i32(scope, pos, 0);
339 let z = as_i32(scope, pos, 1);
340 let y = as_i32(scope, pos, 2);
341 let x = as_i32(scope, pos, 3);
342 barrier.__expand_tma_load_4d_method(scope, self.clone(), shared, w, z, y, x);
343 }
344 5 => {
345 let v = as_i32(scope, pos, 0);
346 let w = as_i32(scope, pos, 1);
347 let z = as_i32(scope, pos, 2);
348 let y = as_i32(scope, pos, 3);
349 let x = as_i32(scope, pos, 4);
350 barrier.__expand_tma_load_5d_method(scope, self.clone(), shared, v, w, z, y, x);
351 }
352 _ => panic!("TMA only supports 1D-5D loads"),
353 }
354 }
355}
356
357impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMut<T, Sequence<N>>
358 for TensorMap<T, Tiled>
359{
360}
361impl<T: CubePrimitive, N: CubePrimitive + Coordinates> ViewOperationsMutExpand<T, Sequence<N>>
362 for ExpandElementTyped<TensorMap<T, Tiled>>
363{
364 fn __expand_write_method(
365 &self,
366 _scope: &mut Scope,
367 _pos: SequenceExpand<N>,
368 _value: <T as CubeType>::ExpandType,
369 ) {
370 unimplemented!("Can't write to tensor map");
371 }
372
373 fn __expand_write_checked_method(
374 &self,
375 _scope: &mut Scope,
376 _pos: SequenceExpand<N>,
377 _value: <T as CubeType>::ExpandType,
378 ) {
379 unimplemented!("Can't write to tensor map");
380 }
381
382 fn __expand_to_linear_slice_mut_method(
383 &self,
384 _scope: &mut Scope,
385 _pos: SequenceExpand<N>,
386 _end: SequenceExpand<N>,
387 ) -> SliceExpand<T, ReadWrite> {
388 unimplemented!("Can't write to tensor map");
389 }
390
391 #[allow(unused_parens)]
392 fn __expand_tensor_map_store_method(
393 &self,
394 scope: &mut Scope,
395 shared_memory: SliceExpand<T, ReadOnly>,
396 pos: SequenceExpand<N>,
397 ) {
398 let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
399 let rank = pos.len();
400 let pos = &pos;
401 match rank {
402 1 => {
403 let x = as_i32(scope, pos, 0);
404 tma_store_1d::expand(scope, shared, self.clone(), x);
405 }
406 2 => {
407 let y = as_i32(scope, pos, 0);
408 let x = as_i32(scope, pos, 1);
409 tma_store_2d::expand(scope, shared, self.clone(), y, x);
410 }
411 3 => {
412 let z = as_i32(scope, pos, 0);
413 let y = as_i32(scope, pos, 1);
414 let x = as_i32(scope, pos, 2);
415 tma_store_3d::expand(scope, shared, self.clone(), z, y, x);
416 }
417 4 => {
418 let w = as_i32(scope, pos, 0);
419 let z = as_i32(scope, pos, 1);
420 let y = as_i32(scope, pos, 2);
421 let x = as_i32(scope, pos, 3);
422 tma_store_4d::expand(scope, shared, self.clone(), w, z, y, x);
423 }
424 5 => {
425 let v = as_i32(scope, pos, 0);
426 let w = as_i32(scope, pos, 1);
427 let z = as_i32(scope, pos, 2);
428 let y = as_i32(scope, pos, 3);
429 let x = as_i32(scope, pos, 4);
430 tma_store_5d::expand(scope, shared, self.clone(), v, w, z, y, x);
431 }
432 _ => panic!("TMA store supports 1D-5D loads"),
433 }
434 }
435}
436
437impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
438 ViewOperations<T, (Sequence<P>, Sequence<O>)> for TensorMap<T, Im2col>
439{
440}
441impl<T: CubePrimitive, P: CubePrimitive + Coordinates, O: CubePrimitive + Coordinates>
442 ViewOperationsExpand<T, (Sequence<P>, Sequence<O>)>
443 for ExpandElementTyped<TensorMap<T, Im2col>>
444{
445 fn __expand_read_method(
446 &self,
447 _scope: &mut Scope,
448 _pos: (SequenceExpand<P>, SequenceExpand<O>),
449 ) -> <T as CubeType>::ExpandType {
450 unimplemented!("Can't read from tensor map");
451 }
452
453 fn __expand_read_checked_method(
454 &self,
455 _scope: &mut Scope,
456 _pos: (SequenceExpand<P>, SequenceExpand<O>),
457 ) -> <T as CubeType>::ExpandType {
458 unimplemented!("Can't read from tensor map");
459 }
460
461 fn __expand_read_masked_method(
462 &self,
463 _scope: &mut Scope,
464 _pos: (SequenceExpand<P>, SequenceExpand<O>),
465 _mask_value: <T as CubeType>::ExpandType,
466 ) -> <T as CubeType>::ExpandType {
467 unimplemented!("Can't read from tensor map");
468 }
469
470 fn __expand_read_unchecked_method(
471 &self,
472 _scope: &mut Scope,
473 _pos: (SequenceExpand<P>, SequenceExpand<O>),
474 ) -> <T as CubeType>::ExpandType {
475 unimplemented!("Can't read from tensor map");
476 }
477
478 fn __expand_to_linear_slice_method(
479 &self,
480 _scope: &mut Scope,
481 _pos: (SequenceExpand<P>, SequenceExpand<O>),
482 _end: (SequenceExpand<P>, SequenceExpand<O>),
483 ) -> SliceExpand<T, ReadOnly> {
484 unimplemented!("Can't read from tensor map");
485 }
486
487 fn __expand_shape_method(&self, _scope: &mut Scope) -> (SequenceExpand<P>, SequenceExpand<O>) {
488 unimplemented!("Can't read from tensor map");
489 }
490
491 fn __expand_is_in_bounds_method(
492 &self,
493 _scope: &mut Scope,
494 _pos: (SequenceExpand<P>, SequenceExpand<O>),
495 ) -> ExpandElementTyped<bool> {
496 true.into()
498 }
499
500 #[allow(unused_parens)]
501 fn __expand_tensor_map_load_method(
502 &self,
503 scope: &mut Scope,
504 barrier: BarrierExpand,
505 shared_memory: SliceExpand<T, ReadWrite>,
506 pos: (SequenceExpand<P>, SequenceExpand<O>),
507 ) {
508 let shared = shared_memory.__expand_try_cast_unchecked_method(scope);
509 let (pos, offs) = &pos;
510 let rank = pos.len();
511
512 match rank {
513 3 => {
514 let n = as_i32(scope, pos, 0);
515 let w = as_i32(scope, pos, 1);
516 let c = as_i32(scope, pos, 2);
517 let x = as_u16(scope, offs, 0);
518 barrier.__expand_tma_load_im2col_3d_method(scope, self.clone(), shared, n, w, c, x);
519 }
520 4 => {
521 let n = as_i32(scope, pos, 0);
522 let h = as_i32(scope, pos, 1);
523 let w = as_i32(scope, pos, 2);
524 let c = as_i32(scope, pos, 3);
525 let y = as_u16(scope, offs, 0);
526 let x = as_u16(scope, offs, 1);
527 barrier.__expand_tma_load_im2col_4d_method(
528 scope,
529 self.clone(),
530 shared,
531 n,
532 h,
533 w,
534 c,
535 y,
536 x,
537 );
538 }
539 5 => {
540 let n = as_i32(scope, pos, 0);
541 let d = as_i32(scope, pos, 1);
542 let h = as_i32(scope, pos, 2);
543 let w = as_i32(scope, pos, 3);
544 let c = as_i32(scope, pos, 4);
545 let z = as_u16(scope, offs, 0);
546 let y = as_u16(scope, offs, 1);
547 let x = as_u16(scope, offs, 2);
548 barrier.__expand_tma_load_im2col_5d_method(
549 scope,
550 self.clone(),
551 shared,
552 n,
553 d,
554 h,
555 w,
556 c,
557 z,
558 y,
559 x,
560 );
561 }
562 _ => panic!("TMA im2col only supports 3D-5D loads"),
563 }
564 }
565}