1use cubecl::ir::StorageType;
4use cubecl::prelude::*;
5use cubecl::server::TensorMapMeta;
6pub use cubecl::zspace::metadata::Metadata;
7use cubecl::zspace::{Shape, Strides};
8
9use crate::MatrixLayout;
10
11pub fn remap_storage_for_tma(ty: StorageType) -> StorageType {
13 if ty == f32::as_type_native_unchecked().storage_type() {
14 tf32::as_type_native_unchecked().storage_type()
15 } else {
16 ty
17 }
18}
19
20pub fn transpose_inner_for_tma(
29 shape: &mut Shape,
30 strides: &mut Strides,
31 layout: MatrixLayout,
32) -> bool {
33 if matches!(layout, MatrixLayout::ColMajor) {
34 let s_rank = shape.num_dims();
35 let t_rank = strides.rank();
36 shape.swap(s_rank - 1, s_rank - 2);
37 strides.swap(t_rank - 1, t_rank - 2);
38 true
39 } else {
40 false
41 }
42}
43
44pub fn tma_meta_tiled(
47 metadata: Metadata,
48 tile_size: Shape,
49 storage_ty: StorageType,
50 swizzle: TensorMapSwizzle,
51) -> TensorMapMeta {
52 let rank = metadata.rank();
53 TensorMapMeta {
54 format: TensorMapFormat::Tiled(TiledArgs { tile_size }),
55 metadata,
56 elem_stride: Strides::new(&vec![1; rank]),
57 interleave: TensorMapInterleave::None,
58 swizzle,
59 prefetch: TensorMapPrefetch::None,
60 oob_fill: OobFill::Zero,
61 storage_ty,
62 }
63}