baracuda_driver/tensor_map.rs
1//! Hopper Tensor Memory Accelerator (TMA) descriptors.
2//!
3//! CUDA 12.0+ introduced `cuTensorMapEncodeTiled` / `cuTensorMapEncodeIm2col`
4//! to produce `CUtensorMap` descriptors that TMA instructions in kernels
5//! consume to asynchronously move multi-dimensional tiles between global
6//! and shared memory. This is a Hopper-only hardware feature (SM 9.0+),
7//! but the descriptor *encoding* itself is pure host code and works on
8//! any device.
9//!
10//! See the [`TensorMap`] builder for a typed wrapper around
11//! `cuTensorMapEncodeTiled`.
12
13use baracuda_cuda_sys::types::CUtensorMap;
14use baracuda_cuda_sys::{driver, CUdeviceptr};
15
16use crate::error::{check, Result};
17
18pub use baracuda_cuda_sys::types::{
19 CUtensorMapDataType as DataType, CUtensorMapFloatOOBfill as OOBFill,
20 CUtensorMapInterleave as Interleave, CUtensorMapL2promotion as L2Promotion,
21 CUtensorMapSwizzle as Swizzle,
22};
23
24/// A 128-byte Hopper TMA descriptor. Pass to a kernel as a `__grid_constant__`
25/// parameter of type `CUtensorMap` for use with TMA instructions.
26pub struct TensorMap {
27 inner: CUtensorMap,
28}
29
30impl core::fmt::Debug for TensorMap {
31 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32 f.debug_struct("TensorMap")
33 .field(
34 "non_zero_words",
35 &self.inner.opaque.iter().filter(|w| **w != 0).count(),
36 )
37 .finish_non_exhaustive()
38 }
39}
40
41impl TensorMap {
42 /// Build a tiled TMA descriptor.
43 ///
44 /// - `data_type`: element type (one of the `DataType::*` constants).
45 /// - `global_base`: pointer to the first element of the tensor.
46 /// - `global_dim`: per-axis size of the global tensor (innermost-to-outermost).
47 /// - `global_strides`: per-axis byte strides between successive elements.
48 /// - `box_dim`: per-axis shape of the tile copied at a time.
49 /// - `element_strides`: per-axis element-strides (typically all 1).
50 ///
51 /// All arrays must have length `rank = global_dim.len()`.
52 #[allow(clippy::too_many_arguments)]
53 pub fn encode_tiled(
54 data_type: i32,
55 global_base: CUdeviceptr,
56 global_dim: &[u64],
57 global_strides: &[u64],
58 box_dim: &[u32],
59 element_strides: &[u32],
60 interleave: i32,
61 swizzle: i32,
62 l2_promotion: i32,
63 oob_fill: i32,
64 ) -> Result<Self> {
65 let rank = global_dim.len();
66 assert_eq!(global_strides.len(), rank);
67 assert_eq!(box_dim.len(), rank);
68 assert_eq!(element_strides.len(), rank);
69 let d = driver()?;
70 let cu = d.cu_tensor_map_encode_tiled()?;
71 let mut map = CUtensorMap::default();
72 check(unsafe {
73 cu(
74 &mut map,
75 data_type,
76 rank as core::ffi::c_uint,
77 global_base.0 as *mut core::ffi::c_void,
78 global_dim.as_ptr(),
79 global_strides.as_ptr(),
80 box_dim.as_ptr(),
81 element_strides.as_ptr(),
82 interleave,
83 swizzle,
84 l2_promotion,
85 oob_fill,
86 )
87 })?;
88 Ok(Self { inner: map })
89 }
90
91 /// Swap the global base address of an existing descriptor in place.
92 /// Lets you reuse one `TensorMap` across multiple buffers of the same
93 /// shape/stride.
94 pub fn replace_address(&mut self, new_base: CUdeviceptr) -> Result<()> {
95 let d = driver()?;
96 let cu = d.cu_tensor_map_replace_address()?;
97 check(unsafe { cu(&mut self.inner, new_base.0 as *mut core::ffi::c_void) })
98 }
99
100 /// Raw pointer to the 128-byte descriptor — pass this to kernels that
101 /// take a `CUtensorMap` parameter.
102 #[inline]
103 pub fn as_raw(&self) -> &CUtensorMap {
104 &self.inner
105 }
106
107 /// Mutable raw access (for FFI calls that want `*mut CUtensorMap`).
108 #[inline]
109 pub fn as_raw_mut(&mut self) -> &mut CUtensorMap {
110 &mut self.inner
111 }
112}