Skip to main content

baracuda_driver/
tensor_map.rs

1//! Hopper Tensor Memory Accelerator (TMA) descriptors.
2//!
3//! CUDA 12.0+ introduced `cuTensorMapEncodeTiled` / `cuTensorMapEncodeIm2col`
4//! to produce `CUtensorMap` descriptors that TMA instructions in kernels
5//! consume to asynchronously move multi-dimensional tiles between global
6//! and shared memory. This is a Hopper-only hardware feature (SM 9.0+),
7//! but the descriptor *encoding* itself is pure host code and works on
8//! any device.
9//!
//! See [`TensorMap::encode_tiled`] for a thin wrapper around
//! `cuTensorMapEncodeTiled`.
12
13use baracuda_cuda_sys::types::CUtensorMap;
14use baracuda_cuda_sys::{driver, CUdeviceptr};
15
16use crate::error::{check, Result};
17
18pub use baracuda_cuda_sys::types::{
19    CUtensorMapDataType as DataType, CUtensorMapFloatOOBfill as OOBFill,
20    CUtensorMapInterleave as Interleave, CUtensorMapL2promotion as L2Promotion,
21    CUtensorMapSwizzle as Swizzle,
22};
23
24/// A 128-byte Hopper TMA descriptor. Pass to a kernel as a `__grid_constant__`
25/// parameter of type `CUtensorMap` for use with TMA instructions.
26pub struct TensorMap {
27    inner: CUtensorMap,
28}
29
30impl core::fmt::Debug for TensorMap {
31    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32        f.debug_struct("TensorMap")
33            .field(
34                "non_zero_words",
35                &self.inner.opaque.iter().filter(|w| **w != 0).count(),
36            )
37            .finish_non_exhaustive()
38    }
39}
40
41impl TensorMap {
42    /// Build a tiled TMA descriptor.
43    ///
44    /// - `data_type`: element type (one of the `DataType::*` constants).
45    /// - `global_base`: pointer to the first element of the tensor.
46    /// - `global_dim`: per-axis size of the global tensor (innermost-to-outermost).
47    /// - `global_strides`: per-axis byte strides between successive elements.
48    /// - `box_dim`: per-axis shape of the tile copied at a time.
49    /// - `element_strides`: per-axis element-strides (typically all 1).
50    ///
51    /// All arrays must have length `rank = global_dim.len()`.
52    #[allow(clippy::too_many_arguments)]
53    pub fn encode_tiled(
54        data_type: i32,
55        global_base: CUdeviceptr,
56        global_dim: &[u64],
57        global_strides: &[u64],
58        box_dim: &[u32],
59        element_strides: &[u32],
60        interleave: i32,
61        swizzle: i32,
62        l2_promotion: i32,
63        oob_fill: i32,
64    ) -> Result<Self> {
65        let rank = global_dim.len();
66        assert_eq!(global_strides.len(), rank);
67        assert_eq!(box_dim.len(), rank);
68        assert_eq!(element_strides.len(), rank);
69        let d = driver()?;
70        let cu = d.cu_tensor_map_encode_tiled()?;
71        let mut map = CUtensorMap::default();
72        check(unsafe {
73            cu(
74                &mut map,
75                data_type,
76                rank as core::ffi::c_uint,
77                global_base.0 as *mut core::ffi::c_void,
78                global_dim.as_ptr(),
79                global_strides.as_ptr(),
80                box_dim.as_ptr(),
81                element_strides.as_ptr(),
82                interleave,
83                swizzle,
84                l2_promotion,
85                oob_fill,
86            )
87        })?;
88        Ok(Self { inner: map })
89    }
90
91    /// Swap the global base address of an existing descriptor in place.
92    /// Lets you reuse one `TensorMap` across multiple buffers of the same
93    /// shape/stride.
94    pub fn replace_address(&mut self, new_base: CUdeviceptr) -> Result<()> {
95        let d = driver()?;
96        let cu = d.cu_tensor_map_replace_address()?;
97        check(unsafe { cu(&mut self.inner, new_base.0 as *mut core::ffi::c_void) })
98    }
99
100    /// Raw pointer to the 128-byte descriptor — pass this to kernels that
101    /// take a `CUtensorMap` parameter.
102    #[inline]
103    pub fn as_raw(&self) -> &CUtensorMap {
104        &self.inner
105    }
106
107    /// Mutable raw access (for FFI calls that want `*mut CUtensorMap`).
108    #[inline]
109    pub fn as_raw_mut(&mut self) -> &mut CUtensorMap {
110        &mut self.inner
111    }
112}