/// Number of 64-bit words in the opaque [`CuTensorMap`] payload (16 × 8 = 128 bytes).
pub const CU_TENSOR_MAP_NUM_QWORDS: usize = 16;

/// Opaque tensor-map descriptor storage: `CU_TENSOR_MAP_NUM_QWORDS` `u64` words,
/// laid out as C and aligned to 64 bytes.
///
/// The default value is all-zero words.
///
/// NOTE(review): the name suggests this is the Rust mirror of the CUDA driver's
/// `CUtensorMap`; its 128-byte size and 64-byte alignment are pinned by tests and
/// must stay in sync with `cuda.h`.
#[derive(Clone, Copy, Default)]
#[repr(C, align(64))]
pub struct CuTensorMap {
    /// Raw descriptor words; this type never decodes them.
    pub opaque: [u64; CU_TENSOR_MAP_NUM_QWORDS],
}

impl std::fmt::Debug for CuTensorMap {
    /// Shows only the first two descriptor words; the remainder is elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let [first, second, ..] = &self.opaque;
        let mut out = f.debug_struct("CuTensorMap");
        out.field("opaque[0]", first);
        out.field("opaque[1]", second);
        out.finish_non_exhaustive()
    }
}
/// Element data type of a tensor described by a tensor-map descriptor.
///
/// `#[repr(u32)]` with explicit discriminants, so `variant as u32` is the raw code.
/// NOTE(review): the name suggests this mirrors the CUDA driver's
/// `CUtensorMapDataType`; keep the numeric values in sync with `cuda.h`.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CuTensorMapDataType {
    Uint8 = 0,
    Uint16 = 1,
    Uint32 = 2,
    Int32 = 3,
    Uint64 = 4,
    Int64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    Bfloat16 = 9,
    Float32Ftz = 10,
    TF32 = 11,
    TF32Ftz = 12,
}

impl CuTensorMapDataType {
    /// Size in bytes of one element of this type: 1, 2, 4 or 8.
    ///
    /// `Float32Ftz`, `TF32` and `TF32Ftz` occupy the same 4 bytes as `Float32`.
    /// Every variant is listed explicitly, so adding a data type forces this
    /// table to be revisited.
    #[must_use]
    pub const fn element_size_bytes(self) -> u32 {
        // Each arm names log2 of the width, which keeps the result a power of two.
        let log2_bytes: u32 = match self {
            Self::Uint8 => 0,
            Self::Uint16 | Self::Float16 | Self::Bfloat16 => 1,
            Self::Uint32
            | Self::Int32
            | Self::Float32
            | Self::Float32Ftz
            | Self::TF32
            | Self::TF32Ftz => 2,
            Self::Uint64 | Self::Int64 | Self::Float64 => 3,
        };
        1 << log2_bytes
    }
}
/// Interleave-layout selector for a tensor-map descriptor.
///
/// `#[repr(u32)]` with explicit discriminants, so `variant as u32` is the raw code.
/// NOTE(review): the name suggests this mirrors the CUDA driver's
/// `CUtensorMapInterleave`; keep the numeric values in sync with `cuda.h`.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CuTensorMapInterleave {
    None = 0,
    B16 = 1,
    B32 = 2,
}
/// Swizzle-mode selector for a tensor-map descriptor.
///
/// `#[repr(u32)]` with explicit discriminants, so `variant as u32` is the raw code.
/// NOTE(review): the name suggests this mirrors the CUDA driver's
/// `CUtensorMapSwizzle`; keep the numeric values in sync with `cuda.h`.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CuTensorMapSwizzle {
    None = 0,
    B32 = 1,
    B64 = 2,
    B128 = 3,
}
/// L2 promotion selector for a tensor-map descriptor.
///
/// `#[repr(u32)]` with explicit discriminants, so `variant as u32` is the raw code.
/// NOTE(review): the name suggests this mirrors the CUDA driver's
/// `CUtensorMapL2promotion`; keep the numeric values in sync with `cuda.h`.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CuTensorMapL2Promotion {
    None = 0,
    L2B64 = 1,
    L2B128 = 2,
    L2B256 = 3,
}
/// Out-of-bounds fill selector for a tensor-map descriptor.
///
/// `#[repr(u32)]` with explicit discriminants, so `variant as u32` is the raw code.
/// NOTE(review): the name suggests this mirrors the CUDA driver's
/// `CUtensorMapFloatOOBfill`; keep the numeric values in sync with `cuda.h`.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CuTensorMapFloatOobFill {
    None = 0,
    NanRequestZeroFma = 1,
}
/// Fully-resolved arguments for a tiled tensor-map encode, as produced by
/// [`TmaDescriptorBuilder::params`].
///
/// All per-dimension arrays list the innermost dimension first. Only the leading
/// `num_dims` entries are meaningful; the builder fills the rest with placeholders
/// (1 for extents/strides-in-elements, 0 for byte strides).
///
/// Derives `PartialEq`/`Eq`/`Hash` so two configurations can be compared directly
/// or used as a cache key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TmaEncodeTiledParams {
    /// Element data type of the tensor.
    pub data_type: CuTensorMapDataType,
    /// Number of meaningful leading entries in the dimension arrays.
    pub num_dims: u32,
    /// Global extent of each dimension, in elements, innermost first.
    pub global_dims: [u64; 5],
    /// Byte stride of dimensions `1..num_dims`: entry `i` belongs to dimension `i + 1`.
    /// There is no entry for dimension 0 (presumably contiguous — confirm with consumer).
    pub global_strides: [u64; 4],
    /// Tile (box) extent of each dimension, in elements, innermost first.
    pub box_dims: [u32; 5],
    /// Per-dimension element strides; the builder's 2-D constructor uses 1 everywhere.
    pub element_strides: [u32; 5],
    /// Interleave layout.
    pub interleave: CuTensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: CuTensorMapSwizzle,
    /// L2 promotion setting.
    pub l2_promotion: CuTensorMapL2Promotion,
    /// Out-of-bounds fill mode.
    pub oob_fill: CuTensorMapFloatOobFill,
}
/// Builder for [`TmaEncodeTiledParams`].
///
/// Usage:
/// 1. Construct with [`TmaDescriptorBuilder::new_2d`] or [`TmaDescriptorBuilder::new_nd`].
/// 2. Optionally override the layout and caching knobs with the `with_*` methods.
/// 3. Call [`TmaDescriptorBuilder::params`].
///
/// Defaults for both constructors:
/// - no interleave;
/// - 128-byte swizzle;
/// - 128-byte L2 promotion;
/// - no OOB fill.
///
/// NOTE(review): nothing is validated here, including:
/// - the `num_dims` range;
/// - box sizes;
/// - whether the inner box width fits the chosen swizzle.
///
/// Presumably whatever consumes the params (the driver encode call) rejects invalid
/// combinations — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TmaDescriptorBuilder {
    data_type: CuTensorMapDataType,
    num_dims: u32,
    global_dims: [u64; 5],
    global_strides: [u64; 4],
    box_dims: [u32; 5],
    element_strides: [u32; 5],
    interleave: CuTensorMapInterleave,
    swizzle: CuTensorMapSwizzle,
    l2_promotion: CuTensorMapL2Promotion,
    oob_fill: CuTensorMapFloatOobFill,
}

impl TmaDescriptorBuilder {
    /// Starts a 2-D, row-major configuration.
    ///
    /// * `rows`, `cols` — global extent in elements.
    /// * `row_stride_bytes` — byte distance between the starts of consecutive rows.
    /// * `box_rows`, `box_cols` — tile extent in elements.
    ///
    /// Dimension 0 is the column (innermost) dimension. `cols`/`box_cols` land at
    /// index 0 and `rows`/`box_rows` at index 1. Trailing dimensions are filled with 1
    /// and element strides are all 1.
    #[must_use]
    pub fn new_2d(
        data_type: CuTensorMapDataType,
        rows: u64,
        cols: u64,
        row_stride_bytes: u64,
        box_rows: u32,
        box_cols: u32,
    ) -> Self {
        // Delegate so the default knobs live in exactly one constructor.
        Self::new_nd(
            data_type,
            2,
            [cols, rows, 1, 1, 1],
            [row_stride_bytes, 0, 0, 0],
            [box_cols, box_rows, 1, 1, 1],
            [1; 5],
        )
    }

    /// Starts a configuration of arbitrary rank (the arrays hold up to 5 dimensions).
    ///
    /// * `num_dims` — number of meaningful leading entries in each array; stored as
    ///   given and not range-checked.
    /// * `global_dims` — global extent per dimension, in elements, innermost first.
    /// * `global_strides` — byte stride of dimensions `1..num_dims`
    ///   (`global_strides[i]` belongs to dimension `i + 1`).
    /// * `box_dims` — tile extent per dimension, in elements.
    /// * `element_strides` — per-dimension element strides.
    #[must_use]
    pub fn new_nd(
        data_type: CuTensorMapDataType,
        num_dims: u32,
        global_dims: [u64; 5],
        global_strides: [u64; 4],
        box_dims: [u32; 5],
        element_strides: [u32; 5],
    ) -> Self {
        Self {
            data_type,
            num_dims,
            global_dims,
            global_strides,
            box_dims,
            element_strides,
            interleave: CuTensorMapInterleave::None,
            swizzle: CuTensorMapSwizzle::B128,
            l2_promotion: CuTensorMapL2Promotion::L2B128,
            oob_fill: CuTensorMapFloatOobFill::None,
        }
    }

    /// Overrides the swizzle mode (default: [`CuTensorMapSwizzle::B128`]).
    #[must_use]
    pub fn with_swizzle(mut self, swizzle: CuTensorMapSwizzle) -> Self {
        self.swizzle = swizzle;
        self
    }

    /// Overrides the interleave layout (default: [`CuTensorMapInterleave::None`]).
    #[must_use]
    pub fn with_interleave(mut self, interleave: CuTensorMapInterleave) -> Self {
        self.interleave = interleave;
        self
    }

    /// Overrides the L2 promotion setting (default: [`CuTensorMapL2Promotion::L2B128`]).
    #[must_use]
    pub fn with_l2_promotion(mut self, l2_promotion: CuTensorMapL2Promotion) -> Self {
        self.l2_promotion = l2_promotion;
        self
    }

    /// Overrides the out-of-bounds fill mode (default: [`CuTensorMapFloatOobFill::None`]).
    #[must_use]
    pub fn with_oob_fill(mut self, oob_fill: CuTensorMapFloatOobFill) -> Self {
        self.oob_fill = oob_fill;
        self
    }

    /// Consumes the builder and returns the resolved parameter set.
    #[must_use]
    pub fn params(self) -> TmaEncodeTiledParams {
        // Exhaustive destructuring: adding a builder field without forwarding it
        // here becomes a compile error instead of a silently dropped setting.
        let Self {
            data_type,
            num_dims,
            global_dims,
            global_strides,
            box_dims,
            element_strides,
            interleave,
            swizzle,
            l2_promotion,
            oob_fill,
        } = self;
        TmaEncodeTiledParams {
            data_type,
            num_dims,
            global_dims,
            global_strides,
            box_dims,
            element_strides,
            interleave,
            swizzle,
            l2_promotion,
            oob_fill,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cu_tensor_map_size_and_alignment() {
        assert_eq!(std::mem::size_of::<CuTensorMap>(), 128);
        assert_eq!(std::mem::align_of::<CuTensorMap>(), 64);
        assert_eq!(
            std::mem::size_of::<CuTensorMap>(),
            CU_TENSOR_MAP_NUM_QWORDS * std::mem::size_of::<u64>()
        );
    }

    #[test]
    fn test_cu_tensor_map_default_is_zero() {
        let m = CuTensorMap::default();
        assert!(m.opaque.iter().all(|&v| v == 0));
    }

    #[test]
    fn test_cu_tensor_map_debug_shows_first_two_words() {
        let mut m = CuTensorMap::default();
        m.opaque[0] = 0xAB;
        m.opaque[1] = 7;
        m.opaque[2] = 99; // must not appear in the output
        assert_eq!(
            format!("{:?}", m),
            "CuTensorMap { opaque[0]: 171, opaque[1]: 7, .. }"
        );
    }

    #[test]
    fn test_data_type_element_sizes() {
        // Cover every variant so a wrong arm in the size table cannot slip through.
        assert_eq!(CuTensorMapDataType::Uint8.element_size_bytes(), 1);
        assert_eq!(CuTensorMapDataType::Uint16.element_size_bytes(), 2);
        assert_eq!(CuTensorMapDataType::Float16.element_size_bytes(), 2);
        assert_eq!(CuTensorMapDataType::Bfloat16.element_size_bytes(), 2);
        assert_eq!(CuTensorMapDataType::Uint32.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::Int32.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::Float32.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::Float32Ftz.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::TF32.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::TF32Ftz.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::Uint64.element_size_bytes(), 8);
        assert_eq!(CuTensorMapDataType::Int64.element_size_bytes(), 8);
        assert_eq!(CuTensorMapDataType::Float64.element_size_bytes(), 8);
    }

    #[test]
    fn test_tma_builder_2d_dimension_ordering() {
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Float16,
            1024,     // rows
            2048,     // cols
            2048 * 2, // row stride in bytes
            64,       // box rows
            128,      // box cols
        )
        .params();
        assert_eq!(params.num_dims, 2);
        // Innermost (column) dimension comes first; unused dimensions are 1.
        assert_eq!(params.global_dims, [2048, 1024, 1, 1, 1]);
        assert_eq!(params.box_dims, [128, 64, 1, 1, 1]);
        // The row stride describes dimension 1 and lives at index 0.
        assert_eq!(params.global_strides, [2048 * 2, 0, 0, 0]);
    }

    #[test]
    fn test_tma_builder_defaults() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Float32, 64, 64, 64 * 4, 16, 16)
                .params();
        assert_eq!(params.data_type, CuTensorMapDataType::Float32);
        assert_eq!(params.element_strides, [1; 5]);
        assert_eq!(params.interleave, CuTensorMapInterleave::None);
        assert_eq!(params.swizzle, CuTensorMapSwizzle::B128);
        assert_eq!(params.l2_promotion, CuTensorMapL2Promotion::L2B128);
        assert_eq!(params.oob_fill, CuTensorMapFloatOobFill::None);
    }

    #[test]
    fn test_tma_builder_swizzle_override() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Float32, 64, 64, 64 * 4, 16, 16)
                .with_swizzle(CuTensorMapSwizzle::B64)
                .params();
        assert_eq!(params.swizzle, CuTensorMapSwizzle::B64);
        // Other knobs keep their defaults.
        assert_eq!(params.l2_promotion, CuTensorMapL2Promotion::L2B128);
    }

    #[test]
    fn test_tma_builder_interleave_and_oob() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Uint8, 256, 256, 256, 32, 32)
                .with_interleave(CuTensorMapInterleave::B16)
                .with_oob_fill(CuTensorMapFloatOobFill::NanRequestZeroFma)
                .params();
        assert_eq!(params.interleave, CuTensorMapInterleave::B16);
        assert_eq!(
            params.oob_fill,
            CuTensorMapFloatOobFill::NanRequestZeroFma
        );
        assert_eq!(params.swizzle, CuTensorMapSwizzle::B128);
    }

    #[test]
    fn test_tma_builder_l2_promotion() {
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Bfloat16,
            512,
            1024,
            1024 * 2,
            64,
            64,
        )
        .with_l2_promotion(CuTensorMapL2Promotion::L2B256)
        .params();
        assert_eq!(params.l2_promotion, CuTensorMapL2Promotion::L2B256);
    }

    #[test]
    fn test_enum_repr_values() {
        assert_eq!(CuTensorMapDataType::Uint8 as u32, 0);
        assert_eq!(CuTensorMapDataType::Uint16 as u32, 1);
        assert_eq!(CuTensorMapDataType::Uint32 as u32, 2);
        assert_eq!(CuTensorMapDataType::Int32 as u32, 3);
        assert_eq!(CuTensorMapDataType::Uint64 as u32, 4);
        assert_eq!(CuTensorMapDataType::Int64 as u32, 5);
        assert_eq!(CuTensorMapDataType::Float16 as u32, 6);
        assert_eq!(CuTensorMapDataType::Float32 as u32, 7);
        assert_eq!(CuTensorMapDataType::Float64 as u32, 8);
        assert_eq!(CuTensorMapDataType::Bfloat16 as u32, 9);
        assert_eq!(CuTensorMapDataType::Float32Ftz as u32, 10);
        assert_eq!(CuTensorMapDataType::TF32 as u32, 11);
        assert_eq!(CuTensorMapDataType::TF32Ftz as u32, 12);

        assert_eq!(CuTensorMapInterleave::None as u32, 0);
        assert_eq!(CuTensorMapInterleave::B16 as u32, 1);
        assert_eq!(CuTensorMapInterleave::B32 as u32, 2);

        assert_eq!(CuTensorMapSwizzle::None as u32, 0);
        assert_eq!(CuTensorMapSwizzle::B32 as u32, 1);
        assert_eq!(CuTensorMapSwizzle::B64 as u32, 2);
        assert_eq!(CuTensorMapSwizzle::B128 as u32, 3);

        assert_eq!(CuTensorMapL2Promotion::None as u32, 0);
        assert_eq!(CuTensorMapL2Promotion::L2B64 as u32, 1);
        assert_eq!(CuTensorMapL2Promotion::L2B128 as u32, 2);
        assert_eq!(CuTensorMapL2Promotion::L2B256 as u32, 3);

        assert_eq!(CuTensorMapFloatOobFill::None as u32, 0);
        assert_eq!(CuTensorMapFloatOobFill::NanRequestZeroFma as u32, 1);
    }

    #[test]
    fn test_nd_builder() {
        let params = TmaDescriptorBuilder::new_nd(
            CuTensorMapDataType::Float32,
            3,
            [512, 256, 128, 1, 1],
            [512 * 4, 512 * 256 * 4, 0, 0],
            [32, 16, 8, 1, 1],
            [1, 1, 1, 1, 1],
        )
        .params();
        assert_eq!(params.num_dims, 3);
        assert_eq!(params.global_dims, [512, 256, 128, 1, 1]);
        assert_eq!(params.global_strides, [512 * 4, 512 * 256 * 4, 0, 0]);
        assert_eq!(params.box_dims, [32, 16, 8, 1, 1]);
        assert_eq!(params.element_strides, [1; 5]);
        // Same defaults as the 2-D constructor.
        assert_eq!(params.interleave, CuTensorMapInterleave::None);
        assert_eq!(params.swizzle, CuTensorMapSwizzle::B128);
        assert_eq!(params.l2_promotion, CuTensorMapL2Promotion::L2B128);
        assert_eq!(params.oob_fill, CuTensorMapFloatOobFill::None);
    }
}