use cubecl::ir::StorageType;
use cubecl::prelude::*;
use cubecl::server::TensorMapMeta;
pub use cubecl::zspace::metadata::Metadata;
use cubecl::zspace::{Shape, Strides};
use crate::MatrixLayout;
/// Returns the storage type to use when describing a tensor to TMA.
///
/// The `f32` storage type is replaced by the `tf32` storage type; every other
/// storage type is returned unchanged.
// NOTE(review): the reason for the f32 -> tf32 substitution (presumably a TMA
// element-format requirement) is not visible here — confirm against callers.
pub fn remap_storage_for_tma(ty: StorageType) -> StorageType {
    let f32_storage = f32::as_type_native_unchecked().storage_type();
    match ty {
        t if t == f32_storage => tf32::as_type_native_unchecked().storage_type(),
        other => other,
    }
}
/// Swaps the two innermost dimensions of `shape` and `strides` when `layout` is
/// [`MatrixLayout::ColMajor`].
///
/// Returns `true` if the swap was performed. Returns `false` for any other
/// layout, in which case `shape` and `strides` are left untouched.
///
/// # Panics
///
/// Panics if `layout` is `ColMajor` and either `shape` or `strides` has fewer
/// than two dimensions. The check happens before anything is mutated, so
/// neither argument is left half-transposed.
pub fn transpose_inner_for_tma(
    shape: &mut Shape,
    strides: &mut Strides,
    layout: MatrixLayout,
) -> bool {
    if !matches!(layout, MatrixLayout::ColMajor) {
        return false;
    }

    let shape_rank = shape.num_dims();
    let stride_rank = strides.rank();
    // Without this check `rank - 2` underflows `usize`. Debug builds panic with
    // an opaque overflow message, and release builds wrap to a huge index. Check
    // both ranks up front so a failure cannot leave `shape` swapped while
    // `strides` is not.
    assert!(
        shape_rank >= 2 && stride_rank >= 2,
        "transpose_inner_for_tma: column-major layout requires rank >= 2 \
         (shape rank {shape_rank}, strides rank {stride_rank})"
    );

    shape.swap(shape_rank - 1, shape_rank - 2);
    strides.swap(stride_rank - 1, stride_rank - 2);
    true
}
/// Builds a [`TensorMapMeta`] for a tiled tensor map.
///
/// `tile_size` is wrapped in [`TensorMapFormat::Tiled`]. `metadata`,
/// `storage_ty` and `swizzle` are moved into the result unchanged. The remaining
/// fields are fixed:
/// - element stride of 1 on every dimension (one entry per `metadata.rank()`),
/// - `TensorMapInterleave::None`,
/// - `TensorMapPrefetch::None`,
/// - `OobFill::Zero`.
pub fn tma_meta_tiled(
    metadata: Metadata,
    tile_size: Shape,
    storage_ty: StorageType,
    swizzle: TensorMapSwizzle,
) -> TensorMapMeta {
    // The rank must be read before `metadata` is moved into the result below.
    let unit_strides = vec![1; metadata.rank()];
    let elem_stride = Strides::new(&unit_strides);
    let format = TensorMapFormat::Tiled(TiledArgs { tile_size });

    TensorMapMeta {
        format,
        metadata,
        elem_stride,
        interleave: TensorMapInterleave::None,
        swizzle,
        prefetch: TensorMapPrefetch::None,
        oob_fill: OobFill::Zero,
        storage_ty,
    }
}