use std::fmt;
/// Element data type selector for a tensor map descriptor.
///
/// Every variant maps to a fixed numeric code via `as_u32`; that code is what
/// `TensorMapDescriptor::encode` hands to the driver as the data-type argument.
/// The short descriptions below are read off the variant names.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorMapDataType {
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Signed 32-bit integer.
    Int32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// Signed 64-bit integer.
    Int64,
    /// 16-bit float.
    Float16,
    /// 32-bit float.
    Float32,
    /// 64-bit float.
    Float64,
    /// 16-bit "brain" float.
    BFloat16,
    /// 32-bit float, `Ftz` flavour (presumably flush-to-zero; confirm against the driver docs).
    Float32Ftz,
    /// TensorFloat-32.
    TFloat32,
    /// TensorFloat-32, `Ftz` flavour (presumably flush-to-zero; confirm against the driver docs).
    TFloat32Ftz,
    /// 8-bit float, E4M3 layout.
    Float8E4m3,
    /// 8-bit float, E5M2 layout.
    Float8E5m2,
    /// 4-bit float, E2M1 layout.
    Float4E2m1,
    /// 6-bit float, E2M3 layout.
    Float6E2m3,
    /// 6-bit float, E3M2 layout.
    Float6E3m2,
}
impl TensorMapDataType {
pub fn as_u32(self) -> u32 {
match self {
TensorMapDataType::UInt8 => 0,
TensorMapDataType::UInt16 => 1,
TensorMapDataType::UInt32 => 2,
TensorMapDataType::Int32 => 3,
TensorMapDataType::UInt64 => 4,
TensorMapDataType::Int64 => 5,
TensorMapDataType::Float16 => 6,
TensorMapDataType::Float32 => 7,
TensorMapDataType::Float64 => 8,
TensorMapDataType::BFloat16 => 9,
TensorMapDataType::Float32Ftz => 10,
TensorMapDataType::TFloat32 => 11,
TensorMapDataType::TFloat32Ftz => 12,
TensorMapDataType::Float8E4m3 => 13,
TensorMapDataType::Float8E5m2 => 14,
TensorMapDataType::Float4E2m1 => 15,
TensorMapDataType::Float6E2m3 => 16,
TensorMapDataType::Float6E3m2 => 17,
}
}
}
/// Interleave setting of a tensor map descriptor.
///
/// `TensorMapDescriptor::encode` forwards the value's `as_u32` code to the
/// driver as the interleave argument.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorMapInterleave {
    /// No interleaving (code `0`); the default chosen by `TensorMapDescriptor::new`.
    None,
    /// 16-byte interleave (code `1`).
    Bytes16,
    /// 32-byte interleave (code `2`).
    Bytes32,
}
impl TensorMapInterleave {
pub fn as_u32(self) -> u32 {
match self {
TensorMapInterleave::None => 0,
TensorMapInterleave::Bytes16 => 1,
TensorMapInterleave::Bytes32 => 2,
}
}
}
/// Swizzle setting of a tensor map descriptor.
///
/// `TensorMapDescriptor::encode` forwards the value's `as_u32` code to the
/// driver as the swizzle argument.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorMapSwizzle {
    /// No swizzling (code `0`); the default chosen by `TensorMapDescriptor::new`.
    None,
    /// 32-byte swizzle (code `1`).
    Bytes32,
    /// 64-byte swizzle (code `2`).
    Bytes64,
    /// 128-byte swizzle (code `3`).
    Bytes128,
}
impl TensorMapSwizzle {
pub fn as_u32(self) -> u32 {
match self {
TensorMapSwizzle::None => 0,
TensorMapSwizzle::Bytes32 => 1,
TensorMapSwizzle::Bytes64 => 2,
TensorMapSwizzle::Bytes128 => 3,
}
}
}
/// L2 promotion setting of a tensor map descriptor.
///
/// `TensorMapDescriptor::encode` forwards the value's `as_u32` code to the
/// driver as the L2-promotion argument.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorMapL2Promotion {
    /// No promotion (code `0`); the default chosen by `TensorMapDescriptor::new`.
    None,
    /// 64-byte promotion (code `1`).
    Bytes64,
    /// 128-byte promotion (code `2`).
    Bytes128,
    /// 256-byte promotion (code `3`).
    Bytes256,
}
impl TensorMapL2Promotion {
pub fn as_u32(self) -> u32 {
match self {
TensorMapL2Promotion::None => 0,
TensorMapL2Promotion::Bytes64 => 1,
TensorMapL2Promotion::Bytes128 => 2,
TensorMapL2Promotion::Bytes256 => 3,
}
}
}
/// Out-of-bounds fill setting of a tensor map descriptor.
///
/// `TensorMapDescriptor::encode` forwards the value's `as_u32` code to the
/// driver as the `CUtensorMapFloatOOBfill` argument.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorMapOobFill {
    /// Code `0`; the default chosen by `TensorMapDescriptor::new`.
    /// NOTE(review): presumably the driver's "no NaN fill" mode — confirm.
    NaZero,
    /// Code `1`.
    /// NOTE(review): presumably the driver's NaN-request fill mode — confirm.
    NanRequest,
}
impl TensorMapOobFill {
pub fn as_u32(self) -> u32 {
match self {
TensorMapOobFill::NaZero => 0,
TensorMapOobFill::NanRequest => 1,
}
}
}
/// Host-side description of a tiled tensor map.
///
/// The fields mirror the arguments that `encode` passes to
/// `cuTensorMapEncodeTiled`; `validate` checks their mutual consistency
/// before the driver is called.
#[derive(Clone, Debug)]
pub struct TensorMapDescriptor {
    /// Element data type of the tensor.
    pub data_type: TensorMapDataType,
    /// Address of the tensor, stored as an integer; `validate` requires it to
    /// be a multiple of 16 and `encode` casts it to a raw pointer.
    pub global_address: usize,
    /// Tensor extent per dimension. Its length defines the rank; `validate`
    /// accepts ranks 1 through 5 and rejects zero entries.
    pub global_dim: Vec<u64>,
    /// Tensor strides; `validate` checks the length against `rank - 1`.
    pub global_strides: Vec<u64>,
    /// Box (tile) extent per dimension; one entry per dimension, none zero.
    pub box_dim: Vec<u32>,
    /// Element stride per dimension; one entry per dimension.
    pub element_strides: Vec<u32>,
    /// Interleave setting.
    pub interleave: TensorMapInterleave,
    /// Swizzle setting.
    pub swizzle: TensorMapSwizzle,
    /// L2 promotion setting.
    pub l2_promotion: TensorMapL2Promotion,
    /// Out-of-bounds fill setting.
    pub oob_fill: TensorMapOobFill,
}
impl TensorMapDescriptor {
    /// Creates a descriptor for `data_type` elements located at `global_address`.
    ///
    /// All dimension and stride vectors start empty and every layout option
    /// starts at its first variant (`None` / `NaZero`). The result therefore
    /// has rank 0 and fails [`validate`](Self::validate) until the vectors
    /// are filled in.
    pub fn new(data_type: TensorMapDataType, global_address: usize) -> Self {
        Self {
            data_type,
            global_address,
            global_dim: Vec::new(),
            global_strides: Vec::new(),
            box_dim: Vec::new(),
            element_strides: Vec::new(),
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            l2_promotion: TensorMapL2Promotion::None,
            oob_fill: TensorMapOobFill::NaZero,
        }
    }

    /// Returns the tensor rank, i.e. the number of entries in `global_dim`.
    pub fn rank(&self) -> usize {
        self.global_dim.len()
    }

    /// Returns `Ok(())` when `got == expected`, otherwise a `Mismatch` error
    /// naming the offending field.
    fn expect_len(what: &'static str, expected: usize, got: usize) -> Result<(), TmaEncodeError> {
        if got == expected {
            Ok(())
        } else {
            Err(TmaEncodeError::Mismatch {
                what,
                expected,
                got,
            })
        }
    }

    /// Checks that the descriptor is internally consistent.
    ///
    /// The checks run in this order and the first failure is returned:
    ///
    /// 1. rank (length of `global_dim`) is in `1..=5` — `BadRank`;
    /// 2. `box_dim` and `element_strides` have `rank` entries — `Mismatch`;
    /// 3. `global_strides` has exactly `rank - 1` entries — `Mismatch`;
    /// 4. `global_address` is a multiple of 16 — `UnalignedAddress`;
    /// 5. `global_dim`, `box_dim` and `element_strides` contain no zero —
    ///    `ZeroDim`.
    ///
    /// Constraints not listed here are left to the driver, whose failures
    /// `encode` reports as `DriverError`.
    ///
    /// # Errors
    ///
    /// Returns the [`TmaEncodeError`] variant named next to the failing check.
    pub fn validate(&self) -> Result<(), TmaEncodeError> {
        let r = self.rank();
        if r == 0 || r > 5 {
            return Err(TmaEncodeError::BadRank(r));
        }
        Self::expect_len("box_dim", r, self.box_dim.len())?;
        Self::expect_len("element_strides", r, self.element_strides.len())?;
        // An empty stride vector is only acceptable for rank 1. `encode` hands
        // the vector's raw pointer to the driver, so for higher ranks an empty
        // vector would make the driver read strides through a dangling pointer.
        // `r >= 1` here, so `r - 1` cannot underflow.
        Self::expect_len("global_strides", r - 1, self.global_strides.len())?;
        if self.global_address % 16 != 0 {
            return Err(TmaEncodeError::UnalignedAddress(self.global_address));
        }
        if self.global_dim.contains(&0) {
            return Err(TmaEncodeError::ZeroDim("global_dim"));
        }
        if self.box_dim.contains(&0) {
            return Err(TmaEncodeError::ZeroDim("box_dim"));
        }
        // Element strides must be non-zero as well; catch it here instead of
        // receiving an opaque driver error code.
        if self.element_strides.contains(&0) {
            return Err(TmaEncodeError::ZeroDim("element_strides"));
        }
        Ok(())
    }

    /// Validates the descriptor and encodes it with `cuTensorMapEncodeTiled`.
    ///
    /// Only available with the `hopper` feature.
    ///
    /// # Errors
    ///
    /// Returns whatever [`validate`](Self::validate) reports, or
    /// `DriverError(code)` when the driver call does not return
    /// `CUDA_SUCCESS`.
    #[cfg(feature = "hopper")]
    pub fn encode(&self) -> Result<TensorMap, TmaEncodeError> {
        use cudarc::driver::sys as cu;
        self.validate()?;
        // SAFETY: NOTE(review) assumes `CUtensorMap` is a plain-data struct for
        // which the all-zero bit pattern is a valid value — confirm against the
        // generated bindings.
        let mut tm: cu::CUtensorMap = unsafe { std::mem::zeroed() };
        // SAFETY: `validate` succeeded, so `global_dim`, `box_dim` and
        // `element_strides` each hold `rank` entries and `global_strides` holds
        // `rank - 1`; all four vectors are borrowed from `self` and outlive the
        // call, and `tm` is a live local.
        // NOTE(review): each `transmute` assumes the `as_u32` code is a valid
        // discriminant of the corresponding binding enum — confirm, in
        // particular for the data-type codes above 12.
        let res = unsafe {
            cu::cuTensorMapEncodeTiled(
                &mut tm,
                std::mem::transmute::<u32, cu::CUtensorMapDataType>(self.data_type.as_u32()),
                self.rank() as cu::cuuint32_t,
                self.global_address as *mut _,
                self.global_dim.as_ptr(),
                self.global_strides.as_ptr(),
                self.box_dim.as_ptr(),
                self.element_strides.as_ptr(),
                std::mem::transmute::<u32, cu::CUtensorMapInterleave>(self.interleave.as_u32()),
                std::mem::transmute::<u32, cu::CUtensorMapSwizzle>(self.swizzle.as_u32()),
                std::mem::transmute::<u32, cu::CUtensorMapL2promotion>(self.l2_promotion.as_u32()),
                std::mem::transmute::<u32, cu::CUtensorMapFloatOOBfill>(self.oob_fill.as_u32()),
            )
        };
        if res != cu::CUresult::CUDA_SUCCESS {
            return Err(TmaEncodeError::DriverError(res as i32));
        }
        Ok(TensorMap(tm))
    }
}
/// Encoded tensor map as returned by `TensorMapDescriptor::encode`.
///
/// Newtype around the driver's `CUtensorMap`; the inner value is public.
/// Only compiled with the `hopper` feature.
#[cfg(feature = "hopper")]
pub struct TensorMap(pub cudarc::driver::sys::CUtensorMap);
#[cfg(feature = "hopper")]
impl TensorMap {
    /// Returns a raw const pointer to the wrapped `CUtensorMap`.
    ///
    /// The pointer is derived from a borrow of `self`, so it must not be used
    /// after this `TensorMap` has been moved or dropped.
    pub fn as_ptr(&self) -> *const cudarc::driver::sys::CUtensorMap {
        let Self(raw) = self;
        raw as *const cudarc::driver::sys::CUtensorMap
    }
}
/// Errors reported while validating or encoding a `TensorMapDescriptor`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TmaEncodeError {
    /// The rank (carried as the payload) is outside `1..=5`.
    BadRank(usize),
    /// The vector named by `what` has `got` entries where `expected` were required.
    Mismatch {
        /// Name of the offending descriptor field.
        what: &'static str,
        /// Required number of entries.
        expected: usize,
        /// Actual number of entries.
        got: usize,
    },
    /// The global address (carried as the payload) is not a multiple of 16.
    UnalignedAddress(usize),
    /// The descriptor field named by the payload contains a zero entry.
    ZeroDim(&'static str),
    /// The driver call did not succeed; the payload is its result code.
    DriverError(i32),
}
impl fmt::Display for TmaEncodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TmaEncodeError::BadRank(r) => write!(f, "TMA rank {r} out of [1,5]"),
TmaEncodeError::Mismatch {
what,
expected,
got,
} => {
write!(
f,
"TMA descriptor: {what}.len() = {got}, expected {expected}"
)
}
TmaEncodeError::UnalignedAddress(a) => {
write!(f, "TMA global_address 0x{a:x} is not 16-byte aligned")
}
TmaEncodeError::ZeroDim(field) => write!(f, "TMA descriptor: {field} contains a zero"),
TmaEncodeError::DriverError(c) => write!(f, "cuTensorMapEncodeTiled returned {c}"),
}
}
}
// Marker impl so `TmaEncodeError` can be used as a `std::error::Error`
// (e.g. boxed as `Box<dyn Error>`); all methods keep their defaults.
impl std::error::Error for TmaEncodeError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a valid rank-2 `Float16` descriptor (1024 x 1024 tensor,
    /// 64 x 64 box) that the tests below use as their baseline.
    fn sample_2d_descriptor() -> TensorMapDescriptor {
        TensorMapDescriptor {
            global_dim: vec![1024, 1024],
            global_strides: vec![1024 * 2],
            box_dim: vec![64, 64],
            element_strides: vec![1, 1],
            swizzle: TensorMapSwizzle::Bytes128,
            l2_promotion: TensorMapL2Promotion::Bytes128,
            // `new` supplies `interleave: None` and `oob_fill: NaZero`.
            ..TensorMapDescriptor::new(TensorMapDataType::Float16, 0x1_0000)
        }
    }

    /// The baseline validates, exposes the configured codes, and each kind of
    /// corruption is reported through the matching error variant.
    #[test]
    fn tensor_map_encode_descriptor_round_trip() {
        let good = sample_2d_descriptor();
        assert_eq!(good.validate(), Ok(()), "sample descriptor must validate");
        assert_eq!(good.rank(), 2);
        assert_eq!(
            good.data_type.as_u32(),
            TensorMapDataType::Float16.as_u32()
        );
        assert_eq!(good.swizzle.as_u32(), TensorMapSwizzle::Bytes128.as_u32());
        assert_eq!(
            good.l2_promotion.as_u32(),
            TensorMapL2Promotion::Bytes128.as_u32()
        );
        assert_eq!(good.oob_fill.as_u32(), TensorMapOobFill::NaZero.as_u32());

        // Address that is off by one byte from a 16-byte boundary.
        let misaligned = TensorMapDescriptor {
            global_address: 0x1_0001,
            ..good.clone()
        };
        assert!(matches!(
            misaligned.validate(),
            Err(TmaEncodeError::UnalignedAddress(_))
        ));

        // One box dimension more than the rank.
        let mut extra_box = good.clone();
        extra_box.box_dim.push(32);
        assert!(matches!(
            extra_box.validate(),
            Err(TmaEncodeError::Mismatch {
                what: "box_dim",
                ..
            })
        ));

        // A freshly constructed descriptor has rank 0.
        let empty = TensorMapDescriptor::new(TensorMapDataType::Float32, 0x10);
        assert!(matches!(
            empty.validate(),
            Err(TmaEncodeError::BadRank(0))
        ));

        // Rank 6 is above the supported maximum.
        let too_deep = TensorMapDescriptor {
            global_dim: vec![1; 6],
            global_strides: vec![4; 5],
            box_dim: vec![1; 6],
            element_strides: vec![1; 6],
            ..TensorMapDescriptor::new(TensorMapDataType::Float32, 0x10)
        };
        assert!(matches!(
            too_deep.validate(),
            Err(TmaEncodeError::BadRank(6))
        ));
    }

    /// No two data-type variants may share a numeric code.
    #[test]
    fn enum_discriminants_are_unique() {
        use std::collections::HashSet;

        let dts = [
            TensorMapDataType::UInt8,
            TensorMapDataType::UInt16,
            TensorMapDataType::UInt32,
            TensorMapDataType::Int32,
            TensorMapDataType::UInt64,
            TensorMapDataType::Int64,
            TensorMapDataType::Float16,
            TensorMapDataType::Float32,
            TensorMapDataType::Float64,
            TensorMapDataType::BFloat16,
            TensorMapDataType::Float32Ftz,
            TensorMapDataType::TFloat32,
            TensorMapDataType::TFloat32Ftz,
            TensorMapDataType::Float8E4m3,
            TensorMapDataType::Float8E5m2,
            TensorMapDataType::Float4E2m1,
            TensorMapDataType::Float6E2m3,
            TensorMapDataType::Float6E3m2,
        ];
        let mut codes: HashSet<u32> = HashSet::new();
        for d in dts.iter().copied() {
            let is_new = codes.insert(d.as_u32());
            assert!(is_new, "duplicate dtype discriminant for {d:?}");
        }
    }
}