Struct TensorCore 

pub struct TensorCore {
    pub dims: (usize, usize, usize),
    pub threads: usize,
    pub elements_per_thread: (usize, usize, usize),
    pub dtype_in: DType,
    pub dtype_out: DType,
    pub opts: SmallVec<[TcOpt; 8]>,
    pub swizzle: ((SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>), (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>)),
    pub pack_a: bool,
    pub tile_grid: (usize, usize),
}

Tensor core configuration for hardware-accelerated matrix multiplication.

Describes a specific matrix multiplication unit with fixed dimensions and data types. Modeled on NVIDIA’s WMMA (Warp Matrix Multiply-Accumulate) API and on similar matrix units in other accelerators.

§Matrix Dimensions

Tensor cores perform: C[M,N] += A[M,K] × B[K,N]

  • dims.0 (N): Number of output columns
  • dims.1 (M): Number of output rows
  • dims.2 (K): Reduction dimension size
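The accumulate formula above can be written out as a scalar reference loop. This is only a sketch of the arithmetic, not how the hardware or this crate computes it; the function name `mma_reference`, the row-major layout, and the `f32` element type are assumptions made for illustration.

```rust
/// Reference semantics of one tile operation: C[M,N] += A[M,K] x B[K,N].
/// `dims` uses the documented (N, M, K) order. All slices are assumed
/// to be dense and row-major (an assumption of this sketch).
fn mma_reference(dims: (usize, usize, usize), a: &[f32], b: &[f32], c: &mut [f32]) {
    let (n, m, k) = dims;
    assert_eq!(a.len(), m * k, "A must hold M*K elements");
    assert_eq!(b.len(), k * n, "B must hold K*N elements");
    assert_eq!(c.len(), m * n, "C must hold M*N elements");

    for row in 0..m {
        for col in 0..n {
            // Start from the existing accumulator value (the `+=`).
            let mut acc = c[row * n + col];
            for r in 0..k {
                acc += a[row * k + r] * b[r * n + col];
            }
            c[row * n + col] = acc;
        }
    }
}
```

Note that `dims.0` is the column count N, so a tuple such as `(2, 1, 3)` describes a 1×2 output tile with a reduction length of 3.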

§Example

NVIDIA Tensor Core 16x16x16:

  • Processes 16×16 output tile
  • Accumulates across 16 K elements
  • Uses 32 threads (warp size)
  • Each thread handles multiple elements via opts
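As a quick check of the numbers in this example, a 16×16 output tile spread evenly over 32 threads leaves 8 accumulator elements per thread. The helper below is a hypothetical illustration of that arithmetic and is not part of the crate.

```rust
/// Output elements each thread owns when the N x M accumulator tile is
/// divided evenly among `threads` threads. Returns `None` when the tile
/// does not divide evenly. `dims` uses the documented (N, M, K) order.
fn outputs_per_thread(dims: (usize, usize, usize), threads: usize) -> Option<usize> {
    let (n, m, _k) = dims;
    let total = n * m;
    if threads == 0 || total % threads != 0 {
        return None;
    }
    Some(total / threads)
}
```

Calling `outputs_per_thread((16, 16, 16), 32)` yields `Some(8)`.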

Fields§

§dims: (usize, usize, usize)

Matrix dimensions (N, M, K).

§threads: usize

Number of threads required (typically the warp or wavefront size: 32 on NVIDIA GPUs, 32 or 64 on AMD GPUs depending on the architecture).

§elements_per_thread: (usize, usize, usize)

Elements per thread in each dimension (N, M, K).

Describes how the matrix is distributed across threads. Example: (2, 2, 4) means each thread handles 2×2 output elements and processes 4 K elements.

§dtype_in: DType

Input matrix data type (A and B matrices).

§dtype_out: DType

Output/accumulator data type (C matrix).

§opts: SmallVec<[TcOpt; 8]>

Optimization sequence for tensor core application.

A sequence of operations to transform ranges. Each operation splits a dimension (N, M, or K) and assigns it to a new axis type.

Example (index 0 refers to N, index 1 to M): [Upcast(0), Local(0), Local(0), Local(1), Local(1), Local(1), Upcast(1)]

  • Upcast N once
  • Local split N twice
  • Local split M three times
  • Upcast M once

Uses SmallVec to avoid heap allocation for typical tensor cores (≤8 ops).
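To make the example sequence concrete, the sketch below tallies the splits per dimension. The `TcOpt` enum here is a local stand-in with only the two variants that appear in the example; the real type may have more. The thread-count helper additionally assumes that every split has a factor of 2, which this page does not state, so treat that part as a guess.

```rust
/// Stand-in for the crate's `TcOpt`, limited to the variants shown above.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TcOpt {
    Upcast(usize),
    Local(usize),
}

/// Returns (upcast_splits, local_splits) applied to dimension `dim`
/// (0 = N, 1 = M).
fn count_splits(opts: &[TcOpt], dim: usize) -> (usize, usize) {
    let upcasts = opts.iter().filter(|&&o| o == TcOpt::Upcast(dim)).count();
    let locals = opts.iter().filter(|&&o| o == TcOpt::Local(dim)).count();
    (upcasts, locals)
}

/// Threads implied by the Local splits, ASSUMING each split halves the
/// range (factor 2). This factor is an assumption of the sketch.
fn implied_threads(opts: &[TcOpt]) -> usize {
    let locals = opts.iter().filter(|o| matches!(o, TcOpt::Local(_))).count();
    1usize << locals
}
```

Under the factor-2 assumption, the five `Local` entries in the example imply 2^5 = 32 threads, which matches a typical warp size.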

§swizzle: ((SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>), (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>))

Swizzle patterns for input permutation.

Describes how to permute input matrices to match hardware layout. Format: ((A_local, A_upcast, A_reduce), (B_local, B_upcast, B_reduce))

Each tuple contains axis references that describe the permutation pattern for optimal memory access. The first tuple is for matrix A, second for B.

Uses SmallVec to avoid heap allocation for typical swizzles (≤8 axes per vec).

§pack_a: bool

Pre-pack operand A into a contiguous scratch buffer before the reduction loop. Beneficial when the A operand is accessed with a non-unit stride (e.g., AMX row-major matmul).
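The idea behind packing can be sketched as a strided gather into a dense buffer. The function `pack_a_tile`, its stride parameters, and the `f32` element type are hypothetical and chosen for illustration; the code the crate actually generates may look quite different.

```rust
/// Copies an `m x k` tile of A from a strided view into a dense,
/// row-major scratch buffer. Strides are measured in elements.
fn pack_a_tile(
    src: &[f32],
    m: usize,
    k: usize,
    row_stride: usize,
    col_stride: usize,
) -> Vec<f32> {
    let mut packed = Vec::with_capacity(m * k);
    for row in 0..m {
        for col in 0..k {
            // Gather with the original strides, store contiguously.
            packed.push(src[row * row_stride + col * col_stride]);
        }
    }
    packed
}
```

After packing, the reduction loop can read A with unit stride, at the cost of one extra pass over the tile.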

§tile_grid: (usize, usize)

Tile grid for multi-FMA batching (tile_y_count, tile_x_count).

When either count is greater than 1, the codegen emits load-pair instructions and multiple FMAs per K iteration to compute a grid of output tiles simultaneously. The default, (1, 1), selects single-tile operation.
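One way to picture the tile grid is as the list of output tiles updated in a single K iteration. The helper below is a hypothetical sketch that only enumerates tile coordinates; it makes no claim about the instruction order the codegen actually emits.

```rust
/// Lists the (tile_y, tile_x) coordinates of the output tiles that would
/// receive one FMA each per K iteration for a given `tile_grid`.
fn fma_schedule(tile_grid: (usize, usize)) -> Vec<(usize, usize)> {
    let (tile_y_count, tile_x_count) = tile_grid;
    let mut fmas = Vec::with_capacity(tile_y_count * tile_x_count);
    for ty in 0..tile_y_count {
        for tx in 0..tile_x_count {
            fmas.push((ty, tx));
        }
    }
    fmas
}
```

A grid of (2, 2) therefore corresponds to four FMAs per K iteration, while the default (1, 1) corresponds to one.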

Implementations§

impl TensorCore

pub fn get_reduce_axes(&self) -> Vec<(usize, usize)>

Get the axes for reduction unrolling.

Returns pairs of (dimension_index, unroll_amount) for the K dimension. Used during TC application to unroll the reduction dimension.
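The shape of the return value can be illustrated with a guess at one plausible scheme: unroll K in factors of 2, giving log2(K) pairs. This page does not confirm that scheme, so `reduce_axes_sketch` below is a hypothetical stand-in and not the method's actual body.

```rust
/// Hypothetical sketch: one (dimension_index, unroll_amount) pair per
/// factor of 2 in `k`. ASSUMES `k` is a power of two and that unrolling
/// proceeds in steps of 2; neither is stated by the documentation.
fn reduce_axes_sketch(k: usize) -> Vec<(usize, usize)> {
    assert!(k.is_power_of_two(), "sketch only handles power-of-two K");
    let steps = k.trailing_zeros() as usize; // log2(k)
    (0..steps).map(|i| (i, 2)).collect()
}
```

Under these assumptions, K = 16 gives four pairs whose unroll amounts multiply back to 16.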

pub fn upcast_axes(&self) -> (Vec<usize>, Vec<usize>, Vec<usize>)

Get the upcast axes configuration for WMMA construction.

Returns axes configuration for CONTRACT operations. Format: (A_axes, B_axes, output_axes)

pub fn sm75_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for NVIDIA SM75 architecture (Turing).

pub fn sm80_tensor_cores(allow_tf32: bool) -> Vec<TensorCore>

Get all tensor cores for NVIDIA SM80 architecture (Ampere).

pub fn sm89_tensor_cores(allow_tf32: bool) -> Vec<TensorCore>

Get all tensor cores for NVIDIA SM89 architecture (Ada Lovelace).

pub fn rdna3_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for AMD RDNA3 architecture (RX 7000 series).

pub fn rdna4_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for AMD RDNA4 architecture.

pub fn cdna3_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for AMD CDNA3 architecture (MI300).

pub fn cdna4_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for AMD CDNA4 architecture.

pub fn metal_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for Apple Metal (M1/M2/M3).

pub fn amx_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for Apple AMX (M1/M2/M3 matrix accelerators).

pub fn intel_tensor_cores() -> Vec<TensorCore>

Get all tensor cores for Intel Xe architecture.
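Each constructor above returns a list of candidate configurations, so a caller typically has to pick one that fits the problem at hand. The sketch below shows one possible selection rule using local stand-in types; `DType`, `TcConfig`, and `pick_tensor_core` are hypothetical and only mirror the `dims`, `dtype_in`, and `dtype_out` fields documented here.

```rust
/// Hypothetical stand-in for the crate's `DType`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F16,
    F32,
}

/// Reduced stand-in for `TensorCore` carrying only the fields used here.
#[derive(Clone, Debug, PartialEq)]
struct TcConfig {
    dims: (usize, usize, usize), // (N, M, K), assumed non-zero
    dtype_in: DType,
    dtype_out: DType,
}

/// Returns the first candidate whose data types match and whose tile
/// dimensions evenly divide the problem size (also given as (N, M, K)).
fn pick_tensor_core(
    candidates: &[TcConfig],
    dtype_in: DType,
    dtype_out: DType,
    problem: (usize, usize, usize),
) -> Option<&TcConfig> {
    candidates.iter().find(|tc| {
        tc.dtype_in == dtype_in
            && tc.dtype_out == dtype_out
            && problem.0 % tc.dims.0 == 0
            && problem.1 % tc.dims.1 == 0
            && problem.2 % tc.dims.2 == 0
    })
}
```

With the real type, the candidate list would come from a call such as `TensorCore::sm80_tensor_cores(false)`; the divisibility rule is an assumption, and a real selector might pad the problem instead of rejecting it.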

Trait Implementations§

impl Clone for TensorCore

fn clone(&self) -> TensorCore

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for TensorCore

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.