// oxicuda_driver/tma.rs
1//! Tensor Memory Accelerator (TMA) descriptor types for CUDA 12.x / sm_90+.
2//!
3//! The Tensor Memory Accelerator (TMA) is a hardware unit introduced on
4//! Hopper GPUs (sm_90) and extended on Blackwell (sm_100 / sm_120). It
5//! enables high-throughput bulk copies between global and shared memory
6//! using pre-built *tensor map* descriptors that encode address layout,
7//! swizzle, and out-of-bounds fill modes.
8//!
9//! # Descriptor creation
10//!
11//! Descriptors are created on the host by calling
12//! `cuTensorMapEncodeTiled` (exposed as an optional driver function pointer
13//! in [`DriverApi`](crate::loader::DriverApi)). This module provides:
14//!
//! - The opaque [`CuTensorMap`] container (128 bytes, 64-byte aligned).
16//! - Configuration enums for data type, interleave, swizzle, etc.
17//! - [`TmaDescriptorBuilder`] — a typed builder that collects parameters
18//! and produces a [`TmaEncodeTiledParams`] ready to pass to the driver.
19//!
20//! # Example
21//!
22//! ```rust
23//! use oxicuda_driver::tma::{
24//! CuTensorMapDataType, CuTensorMapSwizzle, TmaDescriptorBuilder,
25//! };
26//!
27//! // Build a descriptor for a row-major f16 matrix of shape 1024×2048,
28//! // tiled with 64-row × 64-col shared-memory tiles.
29//! let params = TmaDescriptorBuilder::new_2d(
30//! CuTensorMapDataType::Float16,
31//! 1024, // rows
32//! 2048, // cols
33//! 2048 * 2, // row stride in bytes (2 bytes per f16)
34//! 64, 64, // tile rows, tile cols
35//! )
36//! .with_swizzle(CuTensorMapSwizzle::B128)
37//! .params();
38//!
39//! assert_eq!(params.num_dims, 2);
40//! assert_eq!(params.global_dims[0], 2048); // cols first in CUDA convention
41//! ```
42//!
43//! # CUDA 12+ and sm_90+ requirement
44//!
//! TMA hardware is only present on Hopper (sm_90) and Blackwell GPUs
//! (sm_100 data-center, sm_120 consumer). On older devices the descriptor can
47//! still be built on the host but will not be usable in a kernel.
48
49// =========================================================================
50// CuTensorMap — the opaque descriptor container
51// =========================================================================
52
/// Number of 64-bit words in the opaque tensor-map blob.
pub const CU_TENSOR_MAP_NUM_QWORDS: usize = 16;

/// Opaque TMA tensor map descriptor (128 bytes, 64-byte aligned).
///
/// Passed to CUDA kernels via kernel arguments so the TMA hardware can
/// read its encoding. Created on the host with `cuTensorMapEncodeTiled`
/// and must not be modified after the driver populates it.
///
/// # Layout
///
/// The 128-byte structure (16 × `u64`) matches CUDA's `CUtensorMap`
/// exactly, including the 64-byte alignment requirement. The internal
/// encoding is private to the driver; user code should treat this as an
/// opaque blob.
#[repr(C, align(64))]
#[derive(Clone, Copy)]
pub struct CuTensorMap {
    /// Opaque 128-byte payload populated by the CUDA driver.
    pub opaque: [u64; CU_TENSOR_MAP_NUM_QWORDS],
}

impl Default for CuTensorMap {
    /// Returns an all-zero (not-yet-encoded) descriptor.
    fn default() -> Self {
        Self {
            opaque: [0u64; CU_TENSOR_MAP_NUM_QWORDS],
        }
    }
}

impl std::fmt::Debug for CuTensorMap {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the first two qwords are shown; the remaining words carry no
        // user-interpretable meaning and would just be noise in logs.
        f.debug_struct("CuTensorMap")
            .field("opaque[0]", &self.opaque[0])
            .field("opaque[1]", &self.opaque[1])
            .finish_non_exhaustive()
    }
}
90
91// =========================================================================
92// Configuration enums
93// =========================================================================
94
/// Element data type for the TMA descriptor.
///
/// Mirrors `CUtensorMapDataType` from the CUDA Driver API; the `u32`
/// discriminants match the C enum values one-to-one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapDataType {
    /// Unsigned 8-bit integer.
    Uint8 = 0,
    /// Unsigned 16-bit integer.
    Uint16 = 1,
    /// Unsigned 32-bit integer.
    Uint32 = 2,
    /// Signed 32-bit integer.
    Int32 = 3,
    /// Unsigned 64-bit integer.
    Uint64 = 4,
    /// Signed 64-bit integer.
    Int64 = 5,
    /// IEEE-754 half-precision float (f16).
    Float16 = 6,
    /// IEEE-754 single-precision float (f32).
    Float32 = 7,
    /// IEEE-754 double-precision float (f64).
    Float64 = 8,
    /// Brain float 16 (bfloat16).
    Bfloat16 = 9,
    /// f32 with flush-to-zero for subnormals.
    Float32Ftz = 10,
    /// TensorFloat-32 (TF32).
    TF32 = 11,
    /// TF32 with flush-to-zero.
    TF32Ftz = 12,
}

impl CuTensorMapDataType {
    /// Size in bytes of one element of this data type.
    ///
    /// Useful for computing row strides (`cols * element_size_bytes()`).
    #[must_use]
    pub const fn element_size_bytes(self) -> u32 {
        // Grouped by width; the match is exhaustive so adding a variant
        // forces this function to be revisited.
        match self {
            Self::Uint8 => 1,
            Self::Uint16 | Self::Float16 | Self::Bfloat16 => 2,
            Self::Uint64 | Self::Int64 | Self::Float64 => 8,
            Self::Uint32
            | Self::Int32
            | Self::Float32
            | Self::Float32Ftz
            | Self::TF32
            | Self::TF32Ftz => 4,
        }
    }
}
146
/// Interleave pattern for the TMA descriptor.
///
/// Corresponds to `CUtensorMapInterleave` in the CUDA Driver API; the
/// `u32` discriminants match the C enum values.
/// Controls how elements from different warp lanes are interleaved.
/// `None` is safe for most use cases; `B16`/`B32` are only needed for
/// 1D interleaving of narrow data types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapInterleave {
    /// No interleaving.
    None = 0,
    /// 16-byte interleaving stride.
    B16 = 1,
    /// 32-byte interleaving stride.
    B32 = 2,
}
162
/// Swizzle pattern applied to shared-memory tiles.
///
/// Corresponds to `CUtensorMapSwizzle` in the CUDA Driver API; the `u32`
/// discriminants match the C enum values.
/// Swizzling re-maps addresses within the tile to avoid shared-memory
/// bank conflicts. Choose based on the tile width and element type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapSwizzle {
    /// No swizzle (row-major, bank conflicts possible).
    None = 0,
    /// 32-byte swizzle sector.
    B32 = 1,
    /// 64-byte swizzle sector.
    B64 = 2,
    /// 128-byte swizzle sector. Default for most f16/bf16 workloads
    /// (and the default chosen by [`TmaDescriptorBuilder`]).
    B128 = 3,
}
179
/// L2 cache promotion hint for TMA loads.
///
/// Corresponds to `CUtensorMapL2promotion` in the CUDA Driver API; the
/// `u32` discriminants match the C enum values.
/// Instructs the L2 cache to pre-fetch TMA data in larger chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapL2Promotion {
    /// No L2 promotion.
    None = 0,
    /// Promote to 64-byte L2 lines.
    L2B64 = 1,
    /// Promote to 128-byte L2 lines.
    L2B128 = 2,
    /// Promote to 256-byte L2 lines.
    L2B256 = 3,
}
195
/// Out-of-bounds fill mode for TMA loads.
///
/// Corresponds to `CUtensorMapFloatOOBfill` in the CUDA Driver API; the
/// `u32` discriminants match the C enum values.
/// When a TMA access goes out of the declared tensor bounds, this controls
/// whether out-of-bounds elements return zero or NaN.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CuTensorMapFloatOobFill {
    /// No special fill — out-of-bounds elements read as zero
    /// (per the "zero or NaN" contract above; confirm against the
    /// CUDA Driver API docs for non-float element types).
    None = 0,
    /// Out-of-bounds float reads return NaN; FMA returns zero.
    NanRequestZeroFma = 1,
}
208
209// =========================================================================
210// TmaEncodeTiledParams — flattened parameters for the driver call
211// =========================================================================
212
/// All parameters required by `cuTensorMapEncodeTiled`.
///
/// Produced by [`TmaDescriptorBuilder`]. Pass these to the driver function
/// pointer available in
/// [`DriverApi::cu_tensor_map_encode_tiled`](crate::loader::DriverApi::cu_tensor_map_encode_tiled).
///
/// # Dimension ordering
///
/// CUDA TMA uses **column-major** ordering for dimensions: `global_dims[0]`
/// is the *innermost* (fastest-varying) dimension. For a row-major matrix
/// of shape `R × C`, set `global_dims[0] = C` (cols) and
/// `global_dims[1] = R` (rows).
///
/// NOTE(review): arrays are fixed at the 5-dimension maximum; presumably
/// only the first `num_dims` entries (`num_dims - 1` for strides) are
/// consumed by the driver — confirm against the Driver API docs.
#[derive(Debug, Clone)]
pub struct TmaEncodeTiledParams {
    /// Element data type.
    pub data_type: CuTensorMapDataType,
    /// Number of tensor dimensions (1–5).
    pub num_dims: u32,
    /// Size of each global tensor dimension (innermost first).
    pub global_dims: [u64; 5],
    /// Byte stride between elements in outer dimensions (innermost stride omitted).
    pub global_strides: [u64; 4],
    /// Size of each tile dimension (must fit in shared memory).
    pub box_dims: [u32; 5],
    /// Element stride within each tile dimension (typically all-ones).
    pub element_strides: [u32; 5],
    /// Interleave mode.
    pub interleave: CuTensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: CuTensorMapSwizzle,
    /// L2 promotion hint.
    pub l2_promotion: CuTensorMapL2Promotion,
    /// Out-of-bounds fill mode for float elements.
    pub oob_fill: CuTensorMapFloatOobFill,
}
248
249// =========================================================================
250// TmaDescriptorBuilder
251// =========================================================================
252
/// Typed builder for TMA tensor-map descriptors.
///
/// Collects parameters in a convenient Rust API and produces a
/// [`TmaEncodeTiledParams`] struct suitable for passing to the CUDA driver's
/// `cuTensorMapEncodeTiled` entry point.
///
/// # Example
///
/// ```rust
/// use oxicuda_driver::tma::{
///     CuTensorMapDataType, CuTensorMapSwizzle, TmaDescriptorBuilder,
/// };
///
/// let params = TmaDescriptorBuilder::new_2d(
///     CuTensorMapDataType::Bfloat16,
///     512, 1024,      // rows × cols
///     1024 * 2,       // row stride in bytes
///     64, 64,         // tile rows × tile cols
/// )
/// .with_swizzle(CuTensorMapSwizzle::B128)
/// .params();
///
/// assert_eq!(params.num_dims, 2);
/// assert_eq!(params.global_dims[0], 1024); // cols (innermost)
/// assert_eq!(params.global_dims[1], 512);  // rows
/// assert_eq!(params.box_dims[0], 64);      // tile cols
/// assert_eq!(params.box_dims[1], 64);      // tile rows
/// ```
#[derive(Debug, Clone)]
pub struct TmaDescriptorBuilder {
    // Private mirror of `TmaEncodeTiledParams`; fields are filled by the
    // constructors / `with_*` setters and flattened by `params()`.
    data_type: CuTensorMapDataType,
    num_dims: u32,
    global_dims: [u64; 5],
    global_strides: [u64; 4],
    box_dims: [u32; 5],
    element_strides: [u32; 5],
    interleave: CuTensorMapInterleave,
    swizzle: CuTensorMapSwizzle,
    l2_promotion: CuTensorMapL2Promotion,
    oob_fill: CuTensorMapFloatOobFill,
}
294
295impl TmaDescriptorBuilder {
296 /// Create a 2-D tiled TMA descriptor for a row-major matrix.
297 ///
298 /// # Parameters
299 ///
300 /// * `data_type` — element type.
301 /// * `rows` — number of rows in the global tensor.
302 /// * `cols` — number of columns in the global tensor.
303 /// * `row_stride_bytes` — byte offset between consecutive rows in global
304 /// memory (often `cols * element_size`).
305 /// * `box_rows` — tile height (rows per block in shared memory).
306 /// * `box_cols` — tile width (cols per block in shared memory).
307 ///
308 /// # Panics
309 ///
310 /// Does not panic; invalid parameters will be caught by the driver when
311 /// `cuTensorMapEncodeTiled` is called.
312 #[must_use]
313 pub fn new_2d(
314 data_type: CuTensorMapDataType,
315 rows: u64,
316 cols: u64,
317 row_stride_bytes: u64,
318 box_rows: u32,
319 box_cols: u32,
320 ) -> Self {
321 Self {
322 data_type,
323 num_dims: 2,
324 // CUDA uses innermost-first ordering.
325 global_dims: [cols, rows, 1, 1, 1],
326 global_strides: [row_stride_bytes, 0, 0, 0],
327 box_dims: [box_cols, box_rows, 1, 1, 1],
328 element_strides: [1, 1, 1, 1, 1],
329 interleave: CuTensorMapInterleave::None,
330 swizzle: CuTensorMapSwizzle::B128,
331 l2_promotion: CuTensorMapL2Promotion::L2B128,
332 oob_fill: CuTensorMapFloatOobFill::None,
333 }
334 }
335
336 /// Create an N-dimensional tiled TMA descriptor (N ≤ 5).
337 ///
338 /// # Parameters
339 ///
340 /// * `data_type` — element type.
341 /// * `num_dims` — number of tensor dimensions (1–5).
342 /// * `global_dims` — size of each dimension, innermost (col) first.
343 /// * `global_strides` — byte stride for each outer dimension (`num_dims - 1` entries).
344 /// * `box_dims` — tile extent per dimension.
345 /// * `element_strides` — stride between elements in each tile dimension.
346 #[must_use]
347 #[allow(clippy::too_many_arguments)]
348 pub fn new_nd(
349 data_type: CuTensorMapDataType,
350 num_dims: u32,
351 global_dims: [u64; 5],
352 global_strides: [u64; 4],
353 box_dims: [u32; 5],
354 element_strides: [u32; 5],
355 ) -> Self {
356 Self {
357 data_type,
358 num_dims,
359 global_dims,
360 global_strides,
361 box_dims,
362 element_strides,
363 interleave: CuTensorMapInterleave::None,
364 swizzle: CuTensorMapSwizzle::B128,
365 l2_promotion: CuTensorMapL2Promotion::L2B128,
366 oob_fill: CuTensorMapFloatOobFill::None,
367 }
368 }
369
370 /// Override the swizzle pattern (default: [`CuTensorMapSwizzle::B128`]).
371 #[must_use]
372 pub fn with_swizzle(mut self, swizzle: CuTensorMapSwizzle) -> Self {
373 self.swizzle = swizzle;
374 self
375 }
376
377 /// Override the interleave mode (default: [`CuTensorMapInterleave::None`]).
378 #[must_use]
379 pub fn with_interleave(mut self, interleave: CuTensorMapInterleave) -> Self {
380 self.interleave = interleave;
381 self
382 }
383
384 /// Override the L2 promotion hint (default: [`CuTensorMapL2Promotion::L2B128`]).
385 #[must_use]
386 pub fn with_l2_promotion(mut self, l2_promotion: CuTensorMapL2Promotion) -> Self {
387 self.l2_promotion = l2_promotion;
388 self
389 }
390
391 /// Override the out-of-bounds fill mode
392 /// (default: [`CuTensorMapFloatOobFill::None`]).
393 #[must_use]
394 pub fn with_oob_fill(mut self, oob_fill: CuTensorMapFloatOobFill) -> Self {
395 self.oob_fill = oob_fill;
396 self
397 }
398
399 /// Finalise the builder and return the flat parameter struct.
400 ///
401 /// Pass the fields of the returned [`TmaEncodeTiledParams`] directly to
402 /// `cuTensorMapEncodeTiled`.
403 #[must_use]
404 pub fn params(self) -> TmaEncodeTiledParams {
405 TmaEncodeTiledParams {
406 data_type: self.data_type,
407 num_dims: self.num_dims,
408 global_dims: self.global_dims,
409 global_strides: self.global_strides,
410 box_dims: self.box_dims,
411 element_strides: self.element_strides,
412 interleave: self.interleave,
413 swizzle: self.swizzle,
414 l2_promotion: self.l2_promotion,
415 oob_fill: self.oob_fill,
416 }
417 }
418}
419
420// =========================================================================
421// Tests
422// =========================================================================
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cu_tensor_map_size_and_alignment() {
        // Must be exactly 128 bytes and 64-byte aligned to match the
        // CUDA `CUtensorMap` ABI.
        assert_eq!(std::mem::size_of::<CuTensorMap>(), 128);
        assert_eq!(std::mem::align_of::<CuTensorMap>(), 64);
    }

    #[test]
    fn test_cu_tensor_map_default_is_zero() {
        let m = CuTensorMap::default();
        assert!(m.opaque.iter().all(|&v| v == 0));
    }

    #[test]
    fn test_data_type_element_sizes() {
        assert_eq!(CuTensorMapDataType::Uint8.element_size_bytes(), 1);
        assert_eq!(CuTensorMapDataType::Float16.element_size_bytes(), 2);
        assert_eq!(CuTensorMapDataType::Bfloat16.element_size_bytes(), 2);
        assert_eq!(CuTensorMapDataType::Float32.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::Int32.element_size_bytes(), 4);
        assert_eq!(CuTensorMapDataType::Float64.element_size_bytes(), 8);
        assert_eq!(CuTensorMapDataType::Uint64.element_size_bytes(), 8);
    }

    #[test]
    fn test_tma_builder_2d_dimension_ordering() {
        // CUDA uses innermost-first — cols come before rows.
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Float16,
            1024,     // rows
            2048,     // cols
            2048 * 2, // row stride in bytes (2 bytes per f16)
            64,       // tile rows
            128,      // tile cols
        )
        .params();

        assert_eq!(params.num_dims, 2);
        assert_eq!(params.global_dims[0], 2048); // cols first
        assert_eq!(params.global_dims[1], 1024); // rows second
        assert_eq!(params.box_dims[0], 128); // tile cols
        assert_eq!(params.box_dims[1], 64); // tile rows
    }

    #[test]
    fn test_tma_builder_swizzle_override() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Float32, 64, 64, 64 * 4, 16, 16)
                .with_swizzle(CuTensorMapSwizzle::B64)
                .params();

        assert!(matches!(params.swizzle, CuTensorMapSwizzle::B64));
    }

    #[test]
    fn test_tma_builder_interleave_and_oob() {
        let params =
            TmaDescriptorBuilder::new_2d(CuTensorMapDataType::Uint8, 256, 256, 256, 32, 32)
                .with_interleave(CuTensorMapInterleave::B16)
                .with_oob_fill(CuTensorMapFloatOobFill::NanRequestZeroFma)
                .params();

        assert!(matches!(params.interleave, CuTensorMapInterleave::B16));
        assert!(matches!(
            params.oob_fill,
            CuTensorMapFloatOobFill::NanRequestZeroFma
        ));
    }

    #[test]
    fn test_tma_builder_l2_promotion() {
        let params = TmaDescriptorBuilder::new_2d(
            CuTensorMapDataType::Bfloat16,
            512,
            1024,
            1024 * 2,
            64,
            64,
        )
        .with_l2_promotion(CuTensorMapL2Promotion::L2B256)
        .params();

        assert!(matches!(
            params.l2_promotion,
            CuTensorMapL2Promotion::L2B256
        ));
    }

    #[test]
    fn test_enum_repr_values() {
        // Discriminants must stay in lock-step with the CUDA C enums —
        // these values are passed straight to the driver.
        assert_eq!(CuTensorMapDataType::Uint8 as u32, 0);
        assert_eq!(CuTensorMapDataType::Float16 as u32, 6);
        assert_eq!(CuTensorMapDataType::Bfloat16 as u32, 9);
        assert_eq!(CuTensorMapDataType::TF32 as u32, 11);
        assert_eq!(CuTensorMapInterleave::None as u32, 0);
        assert_eq!(CuTensorMapInterleave::B32 as u32, 2);
        assert_eq!(CuTensorMapSwizzle::B128 as u32, 3);
        assert_eq!(CuTensorMapL2Promotion::L2B256 as u32, 3);
        assert_eq!(CuTensorMapFloatOobFill::NanRequestZeroFma as u32, 1);
    }

    #[test]
    fn test_nd_builder() {
        let params = TmaDescriptorBuilder::new_nd(
            CuTensorMapDataType::Float32,
            3,
            [512, 256, 128, 1, 1],
            [512 * 4, 512 * 256 * 4, 0, 0],
            [32, 16, 8, 1, 1],
            [1, 1, 1, 1, 1],
        )
        .params();

        assert_eq!(params.num_dims, 3);
        assert_eq!(params.global_dims[0], 512);
        assert_eq!(params.global_dims[1], 256);
        assert_eq!(params.global_dims[2], 128);
    }
}
546}