Skip to main content

cubek_std/launch/
tma.rs

1//! Helpers for building TMA (Tensor Memory Accelerator) descriptors.
2
3use cubecl::ir::StorageType;
4use cubecl::prelude::*;
5use cubecl::server::TensorMapMeta;
6pub use cubecl::zspace::metadata::Metadata;
7use cubecl::zspace::{Shape, Strides};
8
9use crate::MatrixLayout;
10
11/// CUDA's TMA loads f32 as tf32 internally; remap explicitly so the descriptor matches.
12pub fn remap_storage_for_tma(ty: StorageType) -> StorageType {
13    if ty == f32::as_type_native_unchecked().storage_type() {
14        tf32::as_type_native_unchecked().storage_type()
15    } else {
16        ty
17    }
18}
19
20/// TMA assumes the last stride is contiguous and discards it. For ColMajor inputs we therefore
21/// swap the inner two dims so the contiguous one ends up last. The tensor's own metadata stays
22/// in its original layout — only the TMA descriptor sees the transposed form.
23///
24/// `shape` and `strides` may have different ranks (the matmul builder constructs them
25/// transiently mismatched and aligns them afterwards). Each is swapped on its own inner pair.
26///
27/// Returns `true` if a swap occurred.
28pub fn transpose_inner_for_tma(
29    shape: &mut Shape,
30    strides: &mut Strides,
31    layout: MatrixLayout,
32) -> bool {
33    if matches!(layout, MatrixLayout::ColMajor) {
34        let s_rank = shape.num_dims();
35        let t_rank = strides.rank();
36        shape.swap(s_rank - 1, s_rank - 2);
37        strides.swap(t_rank - 1, t_rank - 2);
38        true
39    } else {
40        false
41    }
42}
43
44/// Build a tiled [`TensorMapMeta`] with the defaults shared by every current call site
45/// (no interleave, no prefetch, OOB-fill = zero, elem_stride = `[1; rank]`).
46pub fn tma_meta_tiled(
47    metadata: Metadata,
48    tile_size: Shape,
49    storage_ty: StorageType,
50    swizzle: TensorMapSwizzle,
51) -> TensorMapMeta {
52    let rank = metadata.rank();
53    TensorMapMeta {
54        format: TensorMapFormat::Tiled(TiledArgs { tile_size }),
55        metadata,
56        elem_stride: Strides::new(&vec![1; rank]),
57        interleave: TensorMapInterleave::None,
58        swizzle,
59        prefetch: TensorMapPrefetch::None,
60        oob_fill: OobFill::Zero,
61        storage_ty,
62    }
63}