Skip to main content

oxicuda_driver/
tma.rs

1//! Tensor Memory Accelerator (TMA) descriptor types for CUDA 12.x / sm_90+.
2//!
3//! The Tensor Memory Accelerator (TMA) is a hardware unit introduced on
4//! Hopper GPUs (sm_90) and extended on Blackwell (sm_100 / sm_120). It
5//! enables high-throughput bulk copies between global and shared memory
6//! using pre-built *tensor map* descriptors that encode address layout,
7//! swizzle, and out-of-bounds fill modes.
8//!
9//! # Descriptor creation
10//!
11//! Descriptors are created on the host by calling
12//! `cuTensorMapEncodeTiled` (exposed as an optional driver function pointer
13//! in [`DriverApi`](crate::loader::DriverApi)). This module provides:
14//!
//! - The opaque [`CuTensorMap`] container (128 bytes, 64-byte aligned).
16//! - Configuration enums for data type, interleave, swizzle, etc.
17//! - [`TmaDescriptorBuilder`] — a typed builder that collects parameters
18//!   and produces a [`TmaEncodeTiledParams`] ready to pass to the driver.
19//!
20//! # Example
21//!
22//! ```rust
23//! use oxicuda_driver::tma::{
24//!     CuTensorMapDataType, CuTensorMapSwizzle, TmaDescriptorBuilder,
25//! };
26//!
27//! // Build a descriptor for a row-major f16 matrix of shape 1024×2048,
28//! // tiled with 64-row × 64-col shared-memory tiles.
29//! let params = TmaDescriptorBuilder::new_2d(
30//!     CuTensorMapDataType::Float16,
31//!     1024,           // rows
32//!     2048,           // cols
33//!     2048 * 2,       // row stride in bytes (2 bytes per f16)
34//!     64, 64,         // tile rows, tile cols
35//! )
36//! .with_swizzle(CuTensorMapSwizzle::B128)
37//! .params();
38//!
39//! assert_eq!(params.num_dims, 2);
40//! assert_eq!(params.global_dims[0], 2048); // cols first in CUDA convention
41//! ```
42//!
43//! # CUDA 12+ and sm_90+ requirement
44//!
//! TMA hardware is present on Hopper (sm_90) and Blackwell GPUs: both the
//! data-center parts (B100/B200, sm_100) and the consumer/workstation parts
//! (sm_120). On older devices the descriptor parameters can still be built
//! on the host, but the resulting descriptor will not be usable in a kernel.
48
49// =========================================================================
50// CuTensorMap — the opaque descriptor container
51// =========================================================================
52
/// Number of 64-bit words in the opaque tensor-map blob
/// (`16 × 8 = 128` bytes, the size of CUDA's `CUtensorMap`).
pub const CU_TENSOR_MAP_NUM_QWORDS: usize = 16;

/// Opaque TMA tensor map descriptor (128 bytes, 64-byte aligned).
///
/// Passed to CUDA kernels via kernel arguments so the TMA hardware can
/// read its encoding. Created on the host with `cuTensorMapEncodeTiled`
/// and must not be modified after the driver populates it.
///
/// # Layout
///
/// The 128-byte, 64-byte-aligned structure matches CUDA's `CUtensorMap`
/// exactly; both properties are enforced at compile time. The internal
/// encoding is private to the driver, so user code should treat this as an
/// opaque blob. [`Default`] produces an all-zero blob.
#[repr(C, align(64))]
#[derive(Clone, Copy, Default)]
pub struct CuTensorMap {
    /// Opaque 128-byte payload.
    pub opaque: [u64; CU_TENSOR_MAP_NUM_QWORDS],
}

// Fail the build (rather than corrupt memory at runtime) if the layout ever
// drifts from the 128-byte / 64-byte-aligned `CUtensorMap` ABI.
const _: () = {
    assert!(std::mem::size_of::<CuTensorMap>() == 128);
    assert!(std::mem::align_of::<CuTensorMap>() == 64);
};

impl std::fmt::Debug for CuTensorMap {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only the first two words; the remainder of the
        // driver-private encoding is elided as `..`.
        f.debug_struct("CuTensorMap")
            .field("opaque[0]", &self.opaque[0])
            .field("opaque[1]", &self.opaque[1])
            .finish_non_exhaustive()
    }
}
90
91// =========================================================================
92// Configuration enums
93// =========================================================================
94
/// Element type of the tensor addressed by a TMA descriptor.
///
/// Mirrors `CUtensorMapDataType` from the CUDA Driver API; each explicit
/// discriminant is the raw value handed to the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapDataType {
    /// 8-bit unsigned integer (`u8`).
    Uint8 = 0,
    /// 16-bit unsigned integer (`u16`).
    Uint16 = 1,
    /// 32-bit unsigned integer (`u32`).
    Uint32 = 2,
    /// 32-bit signed integer (`i32`).
    Int32 = 3,
    /// 64-bit unsigned integer (`u64`).
    Uint64 = 4,
    /// 64-bit signed integer (`i64`).
    Int64 = 5,
    /// IEEE-754 binary16 half-precision float.
    Float16 = 6,
    /// IEEE-754 binary32 single-precision float.
    Float32 = 7,
    /// IEEE-754 binary64 double-precision float.
    Float64 = 8,
    /// bfloat16 (brain floating point, 16 bits).
    Bfloat16 = 9,
    /// Single-precision float with subnormals flushed to zero.
    Float32Ftz = 10,
    /// TensorFloat-32.
    TF32 = 11,
    /// TensorFloat-32 with subnormals flushed to zero.
    TF32Ftz = 12,
}

impl CuTensorMapDataType {
    /// Width of one element, in bytes.
    ///
    /// The TF32 and flush-to-zero variants report 4, the same as
    /// [`Float32`](Self::Float32).
    #[must_use]
    pub const fn element_size_bytes(self) -> u32 {
        match self {
            Self::Uint64 | Self::Int64 | Self::Float64 => 8,
            Self::Uint32
            | Self::Int32
            | Self::Float32
            | Self::Float32Ftz
            | Self::TF32
            | Self::TF32Ftz => 4,
            Self::Uint16 | Self::Float16 | Self::Bfloat16 => 2,
            Self::Uint8 => 1,
        }
    }
}
146
/// Interleaved global-memory layout addressed by the TMA descriptor.
///
/// Corresponds to `CUtensorMapInterleave` in the CUDA Driver API.
/// [`None`](Self::None) describes an ordinary dense tensor and is the right
/// choice for almost all use cases. The other variants describe
/// channel-interleaved layouts such as `NC/8HWC8` (eight 2-byte channels
/// packed into 16 bytes) or `NC/16HWC16` (sixteen 2-byte channels in
/// 32 bytes).
///
/// The driver only accepts an interleaved layout when the tensor rank
/// (`num_dims`) is at least 3.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapInterleave {
    /// Dense, non-interleaved layout.
    None = 0,
    /// 16-byte interleave (e.g. `NC/8HWC8` with 2-byte channels).
    B16 = 1,
    /// 32-byte interleave (e.g. `NC/16HWC16` with 2-byte channels); the
    /// global base address must then be 32-byte aligned.
    B32 = 2,
}
162
/// Shared-memory swizzle pattern applied by TMA.
///
/// Corresponds to `CUtensorMapSwizzle` in the CUDA Driver API. Swizzling
/// permutes addresses inside each tile so that accesses spread across
/// shared-memory banks instead of colliding. Choose the variant from the
/// tile row width in bytes: when a swizzle is selected and interleave is
/// `None`, the driver requires the tile's inner extent
/// (`box_dims[0] × element size`) to be no larger than the swizzle span.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapSwizzle {
    /// Swizzling disabled: a plain row-major tile, bank conflicts possible.
    None = 0,
    /// Swizzle within 32-byte spans.
    B32 = 1,
    /// Swizzle within 64-byte spans.
    B64 = 2,
    /// Swizzle within 128-byte spans; the usual choice for f16/bf16 tiles.
    B128 = 3,
}
179
/// L2 promotion hint attached to TMA loads.
///
/// Corresponds to `CUtensorMapL2promotion` in the CUDA Driver API. It asks
/// the L2 cache to fetch TMA traffic at a coarser granularity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapL2Promotion {
    /// Leave the L2 fetch granularity unchanged.
    None = 0,
    /// Promote accesses to 64-byte L2 granularity.
    L2B64 = 1,
    /// Promote accesses to 128-byte L2 granularity.
    L2B128 = 2,
    /// Promote accesses to 256-byte L2 granularity.
    L2B256 = 3,
}
195
/// Out-of-bounds fill mode for TMA loads.
///
/// Corresponds to `CUtensorMapFloatOOBfill` in the CUDA Driver API. When a
/// TMA load covers elements outside the declared tensor bounds, this selects
/// the value written for those elements: zero ([`None`](Self::None)) or a
/// special NaN constant ([`NanRequestZeroFma`](Self::NanRequestZeroFma)).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapFloatOobFill {
    /// Out-of-bounds elements are filled with zero.
    None = 0,
    /// Out-of-bounds elements are filled with a special NaN constant.
    ///
    /// The driver only accepts this mode when the descriptor's data type is
    /// a floating-point type.
    NanRequestZeroFma = 1,
}
208
209// =========================================================================
210// TmaEncodeTiledParams — flattened parameters for the driver call
211// =========================================================================
212
213/// All parameters required by `cuTensorMapEncodeTiled`.
214///
215/// Produced by [`TmaDescriptorBuilder`]. Pass these to the driver function
216/// pointer available in
217/// [`DriverApi::cu_tensor_map_encode_tiled`](crate::loader::DriverApi::cu_tensor_map_encode_tiled).
218///
219/// # Dimension ordering
220///
221/// CUDA TMA uses **column-major** ordering for dimensions: `global_dims[0]`
222/// is the *innermost* (fastest-varying) dimension. For a row-major matrix
223/// of shape `R × C`, set `global_dims[0] = C` (cols) and
224/// `global_dims[1] = R` (rows).
225#[derive(Debug, Clone)]
226pub struct TmaEncodeTiledParams {
227    /// Element data type.
228    pub data_type: CuTensorMapDataType,
229    /// Number of tensor dimensions (1–5).
230    pub num_dims: u32,
231    /// Size of each global tensor dimension (innermost first).
232    pub global_dims: [u64; 5],
233    /// Byte stride between elements in outer dimensions (innermost stride omitted).
234    pub global_strides: [u64; 4],
235    /// Size of each tile dimension (must fit in shared memory).
236    pub box_dims: [u32; 5],
237    /// Element stride within each tile dimension (typically all-ones).
238    pub element_strides: [u32; 5],
239    /// Interleave mode.
240    pub interleave: CuTensorMapInterleave,
241    /// Swizzle mode.
242    pub swizzle: CuTensorMapSwizzle,
243    /// L2 promotion hint.
244    pub l2_promotion: CuTensorMapL2Promotion,
245    /// Out-of-bounds fill mode for float elements.
246    pub oob_fill: CuTensorMapFloatOobFill,
247}
248
249// =========================================================================
250// TmaDescriptorBuilder
251// =========================================================================
252
253/// Typed builder for TMA tensor-map descriptors.
254///
255/// Collects parameters in a convenient Rust API and produces a
256/// [`TmaEncodeTiledParams`] struct suitable for passing to the CUDA driver's
257/// `cuTensorMapEncodeTiled` entry point.
258///
259/// # Example
260///
261/// ```rust
262/// use oxicuda_driver::tma::{
263///     CuTensorMapDataType, CuTensorMapSwizzle, TmaDescriptorBuilder,
264/// };
265///
266/// let params = TmaDescriptorBuilder::new_2d(
267///     CuTensorMapDataType::Bfloat16,
268///     512, 1024,           // rows × cols
269///     1024 * 2,            // row stride in bytes
270///     64, 64,              // tile rows × tile cols
271/// )
272/// .with_swizzle(CuTensorMapSwizzle::B128)
273/// .params();
274///
275/// assert_eq!(params.num_dims, 2);
276/// assert_eq!(params.global_dims[0], 1024); // cols (innermost)
277/// assert_eq!(params.global_dims[1], 512);  // rows
278/// assert_eq!(params.box_dims[0], 64);      // tile cols
279/// assert_eq!(params.box_dims[1], 64);      // tile rows
280/// ```
281#[derive(Debug, Clone)]
282pub struct TmaDescriptorBuilder {
283    data_type: CuTensorMapDataType,
284    num_dims: u32,
285    global_dims: [u64; 5],
286    global_strides: [u64; 4],
287    box_dims: [u32; 5],
288    element_strides: [u32; 5],
289    interleave: CuTensorMapInterleave,
290    swizzle: CuTensorMapSwizzle,
291    l2_promotion: CuTensorMapL2Promotion,
292    oob_fill: CuTensorMapFloatOobFill,
293}
294
295impl TmaDescriptorBuilder {
296    /// Create a 2-D tiled TMA descriptor for a row-major matrix.
297    ///
298    /// # Parameters
299    ///
300    /// * `data_type` — element type.
301    /// * `rows` — number of rows in the global tensor.
302    /// * `cols` — number of columns in the global tensor.
303    /// * `row_stride_bytes` — byte offset between consecutive rows in global
304    ///   memory (often `cols * element_size`).
305    /// * `box_rows` — tile height (rows per block in shared memory).
306    /// * `box_cols` — tile width (cols per block in shared memory).
307    ///
308    /// # Panics
309    ///
310    /// Does not panic; invalid parameters will be caught by the driver when
311    /// `cuTensorMapEncodeTiled` is called.
312    #[must_use]
313    pub fn new_2d(
314        data_type: CuTensorMapDataType,
315        rows: u64,
316        cols: u64,
317        row_stride_bytes: u64,
318        box_rows: u32,
319        box_cols: u32,
320    ) -> Self {
321        Self {
322            data_type,
323            num_dims: 2,
324            // CUDA uses innermost-first ordering.
325            global_dims: [cols, rows, 1, 1, 1],
326            global_strides: [row_stride_bytes, 0, 0, 0],
327            box_dims: [box_cols, box_rows, 1, 1, 1],
328            element_strides: [1, 1, 1, 1, 1],
329            interleave: CuTensorMapInterleave::None,
330            swizzle: CuTensorMapSwizzle::B128,
331            l2_promotion: CuTensorMapL2Promotion::L2B128,
332            oob_fill: CuTensorMapFloatOobFill::None,
333        }
334    }
335
336    /// Create an N-dimensional tiled TMA descriptor (N ≤ 5).
337    ///
338    /// # Parameters
339    ///
340    /// * `data_type` — element type.
341    /// * `num_dims` — number of tensor dimensions (1–5).
342    /// * `global_dims` — size of each dimension, innermost (col) first.
343    /// * `global_strides` — byte stride for each outer dimension (`num_dims - 1` entries).
344    /// * `box_dims` — tile extent per dimension.
345    /// * `element_strides` — stride between elements in each tile dimension.
346    #[must_use]
347    #[allow(clippy::too_many_arguments)]
348    pub fn new_nd(
349        data_type: CuTensorMapDataType,
350        num_dims: u32,
351        global_dims: [u64; 5],
352        global_strides: [u64; 4],
353        box_dims: [u32; 5],
354        element_strides: [u32; 5],
355    ) -> Self {
356        Self {
357            data_type,
358            num_dims,
359            global_dims,
360            global_strides,
361            box_dims,
362            element_strides,
363            interleave: CuTensorMapInterleave::None,
364            swizzle: CuTensorMapSwizzle::B128,
365            l2_promotion: CuTensorMapL2Promotion::L2B128,
366            oob_fill: CuTensorMapFloatOobFill::None,
367        }
368    }
369
370    /// Override the swizzle pattern (default: [`CuTensorMapSwizzle::B128`]).
371    #[must_use]
372    pub fn with_swizzle(mut self, swizzle: CuTensorMapSwizzle) -> Self {
373        self.swizzle = swizzle;
374        self
375    }
376
377    /// Override the interleave mode (default: [`CuTensorMapInterleave::None`]).
378    #[must_use]
379    pub fn with_interleave(mut self, interleave: CuTensorMapInterleave) -> Self {
380        self.interleave = interleave;
381        self
382    }
383
384    /// Override the L2 promotion hint (default: [`CuTensorMapL2Promotion::L2B128`]).
385    #[must_use]
386    pub fn with_l2_promotion(mut self, l2_promotion: CuTensorMapL2Promotion) -> Self {
387        self.l2_promotion = l2_promotion;
388        self
389    }
390
391    /// Override the out-of-bounds fill mode
392    /// (default: [`CuTensorMapFloatOobFill::None`]).
393    #[must_use]
394    pub fn with_oob_fill(mut self, oob_fill: CuTensorMapFloatOobFill) -> Self {
395        self.oob_fill = oob_fill;
396        self
397    }
398
399    /// Finalise the builder and return the flat parameter struct.
400    ///
401    /// Pass the fields of the returned [`TmaEncodeTiledParams`] directly to
402    /// `cuTensorMapEncodeTiled`.
403    #[must_use]
404    pub fn params(self) -> TmaEncodeTiledParams {
405        TmaEncodeTiledParams {
406            data_type: self.data_type,
407            num_dims: self.num_dims,
408            global_dims: self.global_dims,
409            global_strides: self.global_strides,
410            box_dims: self.box_dims,
411            element_strides: self.element_strides,
412            interleave: self.interleave,
413            swizzle: self.swizzle,
414            l2_promotion: self.l2_promotion,
415            oob_fill: self.oob_fill,
416        }
417    }
418}
419
420// =========================================================================
421// Tests
422// =========================================================================
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cu_tensor_map_size_and_alignment() {
        // Must match CUDA's `CUtensorMap`: exactly 128 bytes, 64-byte aligned.
        assert_eq!(std::mem::size_of::<CuTensorMap>(), 128);
        assert_eq!(std::mem::align_of::<CuTensorMap>(), 64);
    }

    #[test]
    fn test_cu_tensor_map_default_is_zero() {
        let m = CuTensorMap::default();
        assert!(m.opaque.iter().all(|&v| v == 0));
    }

    #[test]
    fn test_data_type_element_sizes() {
        // Cover every variant so a mis-sized arm cannot slip through.
        let cases = [
            (CuTensorMapDataType::Uint8, 1_u32),
            (CuTensorMapDataType::Uint16, 2),
            (CuTensorMapDataType::Uint32, 4),
            (CuTensorMapDataType::Int32, 4),
            (CuTensorMapDataType::Uint64, 8),
            (CuTensorMapDataType::Int64, 8),
            (CuTensorMapDataType::Float16, 2),
            (CuTensorMapDataType::Float32, 4),
            (CuTensorMapDataType::Float64, 8),
            (CuTensorMapDataType::Bfloat16, 2),
            (CuTensorMapDataType::Float32Ftz, 4),
            (CuTensorMapDataType::TF32, 4),
            (CuTensorMapDataType::TF32Ftz, 4),
        ];
        for &(ty, size) in &cases {
            assert_eq!(ty.element_size_bytes(), size, "{:?}", ty);
        }
    }

    #[test]
    fn test_tma_builder_2d_dimension_ordering() {
        // CUDA uses innermost-first — cols come before rows.
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Float16,
            1024, // rows
            2048, // cols
            2048 * 2,
            64,
            128,
        )
        .params();

        assert_eq!(params.num_dims, 2);
        assert_eq!(params.global_dims, [2048_u64, 1024, 1, 1, 1]); // cols, rows
        assert_eq!(params.global_strides, [2048_u64 * 2, 0, 0, 0]); // row stride
        assert_eq!(params.box_dims, [128_u32, 64, 1, 1, 1]); // tile cols, tile rows
        assert_eq!(params.element_strides, [1_u32; 5]);
    }

    #[test]
    fn test_tma_builder_defaults() {
        // The documented defaults shared by both constructors.
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Bfloat16,
            128,
            256,
            256 * 2,
            64,
            64,
        )
        .params();

        assert_eq!(params.data_type, CuTensorMapDataType::Bfloat16);
        assert_eq!(params.swizzle, CuTensorMapSwizzle::B128);
        assert_eq!(params.interleave, CuTensorMapInterleave::None);
        assert_eq!(params.l2_promotion, CuTensorMapL2Promotion::L2B128);
        assert_eq!(params.oob_fill, CuTensorMapFloatOobFill::None);
    }

    #[test]
    fn test_new_2d_matches_equivalent_new_nd() {
        let via_2d =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Float32, 64, 32, 32 * 4, 16, 8);
        let via_nd = TmaDescriptorBuilder::new_nd(
            CuTensorMapDataType::Float32,
            2,
            [32, 64, 1, 1, 1],
            [32 * 4, 0, 0, 0],
            [8, 16, 1, 1, 1],
            [1; 5],
        );
        // `Debug` renders every builder field, so equal output means equal state.
        assert_eq!(format!("{:?}", via_2d), format!("{:?}", via_nd));
    }

    #[test]
    fn test_tma_builder_swizzle_override() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Float32, 64, 64, 64 * 4, 16, 16)
                .with_swizzle(CuTensorMapSwizzle::B64)
                .params();

        assert_eq!(params.swizzle, CuTensorMapSwizzle::B64);
    }

    #[test]
    fn test_tma_builder_interleave_and_oob() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Uint8, 256, 256, 256, 32, 32)
                .with_interleave(CuTensorMapInterleave::B16)
                .with_oob_fill(CuTensorMapFloatOobFill::NanRequestZeroFma)
                .params();

        assert_eq!(params.interleave, CuTensorMapInterleave::B16);
        assert_eq!(params.oob_fill, CuTensorMapFloatOobFill::NanRequestZeroFma);
    }

    #[test]
    fn test_tma_builder_l2_promotion() {
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Bfloat16,
            512,
            1024,
            1024 * 2,
            64,
            64,
        )
        .with_l2_promotion(CuTensorMapL2Promotion::L2B256)
        .params();

        assert_eq!(params.l2_promotion, CuTensorMapL2Promotion::L2B256);
    }

    #[test]
    fn test_enum_repr_values() {
        assert_eq!(CuTensorMapDataType::Uint8 as u32, 0);
        assert_eq!(CuTensorMapDataType::Float16 as u32, 6);
        assert_eq!(CuTensorMapDataType::Bfloat16 as u32, 9);
        assert_eq!(CuTensorMapDataType::TF32 as u32, 11);
        assert_eq!(CuTensorMapDataType::TF32Ftz as u32, 12);
        assert_eq!(CuTensorMapInterleave::None as u32, 0);
        assert_eq!(CuTensorMapInterleave::B16 as u32, 1);
        assert_eq!(CuTensorMapInterleave::B32 as u32, 2);
        assert_eq!(CuTensorMapSwizzle::None as u32, 0);
        assert_eq!(CuTensorMapSwizzle::B128 as u32, 3);
        assert_eq!(CuTensorMapL2Promotion::None as u32, 0);
        assert_eq!(CuTensorMapL2Promotion::L2B256 as u32, 3);
        assert_eq!(CuTensorMapFloatOobFill::None as u32, 0);
        assert_eq!(CuTensorMapFloatOobFill::NanRequestZeroFma as u32, 1);
    }

    #[test]
    fn test_nd_builder() {
        let params = TmaDescriptorBuilder::new_nd(
            CuTensorMapDataType::Float32,
            3,
            [512, 256, 128, 1, 1],
            [512 * 4, 512 * 256 * 4, 0, 0],
            [32, 16, 8, 1, 1],
            [1, 1, 1, 1, 1],
        )
        .params();

        assert_eq!(params.num_dims, 3);
        assert_eq!(params.global_dims, [512_u64, 256, 128, 1, 1]);
        assert_eq!(params.global_strides, [512_u64 * 4, 512 * 256 * 4, 0, 0]);
        assert_eq!(params.box_dims, [32_u32, 16, 8, 1, 1]);
        assert_eq!(params.element_strides, [1_u32; 5]);
    }
}