pub struct TensorCore {
    pub dims: (usize, usize, usize),
    pub threads: usize,
    pub elements_per_thread: (usize, usize, usize),
    pub dtype_in: DType,
    pub dtype_out: DType,
    pub opts: SmallVec<[TcOpt; 8]>,
    pub swizzle: ((SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>), (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>)),
    pub pack_a: bool,
    pub tile_grid: (usize, usize),
}
Tensor core configuration for hardware-accelerated matrix multiplication.
Describes a specific matrix multiplication unit with fixed dimensions and data types. Based on NVIDIA’s WMMA (Warp Matrix Multiply-Accumulate) API and similar accelerators.
§Matrix Dimensions
Tensor cores perform: C[M,N] += A[M,K] × B[K,N]
- dims.0 (N): Number of output columns
- dims.1 (M): Number of output rows
- dims.2 (K): Reduction dimension size
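As a point of reference, the update above can be written as a plain scalar loop. The sketch below is not how the hardware or this crate computes it; `tile_mma` is a hypothetical helper, and the row-major layout of all three matrices is an assumption made only for illustration. It does follow the (N, M, K) ordering of `dims`.

```rust
/// Scalar reference for the tile update C[M,N] += A[M,K] × B[K,N].
/// `dims` uses the (N, M, K) ordering described for `TensorCore::dims`.
/// Assumption: A, B, and C are all stored row-major.
fn tile_mma(dims: (usize, usize, usize), a: &[f32], b: &[f32], c: &mut [f32]) {
    let (n, m, k) = dims;
    assert_eq!(a.len(), m * k, "A must be M x K");
    assert_eq!(b.len(), k * n, "B must be K x N");
    assert_eq!(c.len(), m * n, "C must be M x N");
    for row in 0..m {
        for col in 0..n {
            // Reduce over the K dimension for one output element.
            let mut acc = 0.0f32;
            for kk in 0..k {
                acc += a[row * k + kk] * b[kk * n + col];
            }
            // Accumulate into C rather than overwrite it.
            c[row * n + col] += acc;
        }
    }
}
```

Note that N comes first in the tuple even though it indexes the columns of C, so a call such as `tile_mma((3, 2, 1), ..)` describes a 2-row, 3-column output.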
§Example
NVIDIA Tensor Core 16x16x16:
- Processes 16×16 output tile
- Accumulates across 16 K elements
- Uses 32 threads (warp size)
- Each thread handles multiple elements via opts
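To make the per-thread workload in this example concrete, the following sketch computes how many output elements each thread would own if the output tile were divided evenly across the threads. The even split and the helper name `outputs_per_thread` are assumptions for illustration; the actual distribution is whatever `opts` and `elements_per_thread` describe.

```rust
/// Number of output elements each thread owns when an N x M output tile
/// is divided evenly across `threads` threads.
/// Returns None if `threads` is zero or the tile does not divide evenly.
/// `dims` uses the (N, M, K) ordering; K does not affect the output count.
fn outputs_per_thread(dims: (usize, usize, usize), threads: usize) -> Option<usize> {
    let (n, m, _k) = dims;
    let total = n * m;
    if threads == 0 || total % threads != 0 {
        return None;
    }
    Some(total / threads)
}
```

For the 16x16x16 configuration with 32 threads, `outputs_per_thread((16, 16, 16), 32)` yields `Some(8)`: 256 output elements shared by one warp.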
Fields§
dims: (usize, usize, usize)
Matrix dimensions (N, M, K).
threads: usize
Number of threads required (typically warp size: 32 for CUDA, 64 for AMD).
elements_per_thread: (usize, usize, usize)
Elements per thread in each dimension (N, M, K).
Describes how the matrix is distributed across threads. Example: (2, 2, 4) means each thread handles 2×2 output elements and processes 4 K elements.
dtype_in: DType
Input matrix data type (A and B matrices).
dtype_out: DType
Output/accumulator data type (C matrix).
opts: SmallVec<[TcOpt; 8]>
Optimization sequence for tensor core application.
A sequence of operations to transform ranges. Each operation splits a dimension (N, M, or K) and assigns it to a new axis type.
Example: [Upcast(0), Local(0), Local(0), Local(1), Local(1), Local(1), Upcast(1)]
- Upcast N once
- Local split N twice
- Local split M three times
- Upcast M once
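The reading of the example sequence can be checked mechanically. The sketch below uses a hypothetical stand-in enum, `TcOptSketch`, that mirrors only the two variants shown in the example; the real `TcOpt` type may carry different variants or payloads. It tallies how many upcast and local splits are applied to each dimension index (0 = N, 1 = M, 2 = K).

```rust
/// Hypothetical stand-in for `TcOpt`, modeling only the variants that
/// appear in the example. The payload is the dimension index being split.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TcOptSketch {
    Upcast(usize),
    Local(usize),
}

/// Returns, for dimensions 0 (N), 1 (M), and 2 (K), the pair
/// (upcast_count, local_count). Panics if a dimension index exceeds 2.
fn count_splits(opts: &[TcOptSketch]) -> [(usize, usize); 3] {
    let mut counts = [(0usize, 0usize); 3];
    for opt in opts {
        match *opt {
            TcOptSketch::Upcast(dim) => counts[dim].0 += 1,
            TcOptSketch::Local(dim) => counts[dim].1 += 1,
        }
    }
    counts
}
```

Applied to the sequence above, this gives one upcast and two local splits for N, one upcast and three local splits for M, and nothing for K.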
Uses SmallVec to avoid heap allocation for typical tensor cores (≤8 ops).
swizzle: ((SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>), (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>))
Swizzle patterns for input permutation.
Describes how to permute input matrices to match hardware layout. Format: ((A_local, A_upcast, A_reduce), (B_local, B_upcast, B_reduce))
Each tuple contains axis references that describe the permutation pattern for optimal memory access. The first tuple is for matrix A, second for B.
Uses SmallVec to avoid heap allocation for typical swizzles (≤8 axes per vec).
pack_a: bool
Pre-pack operand A into a contiguous scratch buffer before the reduction loop. Beneficial when the A operand has non-unit stride access (e.g., AMX row-major matmul).
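The idea behind pre-packing can be sketched as a strided gather into a dense buffer. This is a generic illustration, not the code the generator emits when `pack_a` is set; `pack_tile` and its stride parameters are hypothetical names.

```rust
/// Copy a `rows x cols` tile of A into a contiguous, row-major scratch buffer.
/// Element (r, c) of the source tile is read from
/// `a[r * row_stride + c * col_stride]`, so the source may have any layout.
fn pack_tile(
    a: &[f32],
    rows: usize,
    cols: usize,
    row_stride: usize,
    col_stride: usize,
) -> Vec<f32> {
    let mut scratch = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        for c in 0..cols {
            // Strided read from the source, unit-stride write to the scratch.
            scratch.push(a[r * row_stride + c * col_stride]);
        }
    }
    scratch
}
```

After packing, the reduction loop can walk `scratch` with unit stride regardless of how A was originally laid out, at the cost of one extra copy of the tile.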
tile_grid: (usize, usize)
Tile grid for multi-FMA batching (tile_y_count, tile_x_count).
When set to a grid larger than (1, 1), the codegen emits load-pair instructions and multiple FMAs per K iteration to compute a grid of output tiles simultaneously. The default, (1, 1), selects single-tile operation.
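A rough way to see why a larger grid helps is to count work per K iteration. The sketch below assumes one A fragment is loaded per tile row, one B fragment per tile column, and one FMA is issued per output tile; these counts and the helper `per_k_iteration_work` are illustrative assumptions, not a description of the emitted instruction stream.

```rust
/// Returns (operand_loads, fmas) for one K iteration of a tile grid,
/// under the assumption of one A load per tile row, one B load per
/// tile column, and one FMA per output tile.
/// `tile_grid` is (tile_y_count, tile_x_count).
fn per_k_iteration_work(tile_grid: (usize, usize)) -> (usize, usize) {
    let (tile_y, tile_x) = tile_grid;
    let operand_loads = tile_y + tile_x;
    let fmas = tile_y * tile_x;
    (operand_loads, fmas)
}
```

Under these assumptions, (1, 1) needs 2 loads for 1 FMA, while (2, 2) needs 4 loads for 4 FMAs, so each loaded fragment is reused across more output tiles.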
Implementations§
impl TensorCore
pub fn get_reduce_axes(&self) -> Vec<(usize, usize)>
Get the axes for reduction unrolling.
Returns pairs of (dimension_index, unroll_amount) for the K dimension. Used during TC application to unroll the reduction dimension.
pub fn upcast_axes(&self) -> (Vec<usize>, Vec<usize>, Vec<usize>)
Get the upcast axes configuration for WMMA construction.
Returns axes configuration for CONTRACT operations. Format: (A_axes, B_axes, output_axes)
pub fn sm75_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for NVIDIA SM75 architecture (Turing).
pub fn sm80_tensor_cores(allow_tf32: bool) -> Vec<TensorCore>
Get all tensor cores for NVIDIA SM80 architecture (Ampere).
pub fn sm89_tensor_cores(allow_tf32: bool) -> Vec<TensorCore>
Get all tensor cores for NVIDIA SM89 architecture (Ada Lovelace).
pub fn rdna3_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for AMD RDNA3 architecture (RX 7000 series).
pub fn rdna4_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for AMD RDNA4 architecture.
pub fn cdna3_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for AMD CDNA3 architecture (MI300).
pub fn cdna4_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for AMD CDNA4 architecture.
pub fn metal_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for Apple Metal (M1/M2/M3).
pub fn amx_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for Apple AMX (M1/M2/M3 matrix accelerators).
pub fn intel_tensor_cores() -> Vec<TensorCore>
Get all tensor cores for Intel Xe architecture.
Trait Implementations§
impl Clone for TensorCore
fn clone(&self) -> TensorCore
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations§
impl Freeze for TensorCore
impl RefUnwindSafe for TensorCore
impl Send for TensorCore
impl Sync for TensorCore
impl Unpin for TensorCore
impl UnsafeUnpin for TensorCore
impl UnwindSafe for TensorCore
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.