Struct GemmPlan

pub struct GemmPlan<T>
where T: Element,
{ /* private fields */ }

Selected GEMM kernel and the host-side metadata needed to launch it.

Plans are cheap to construct, hold no device memory, and are Send + Sync (whenever T is) because they are pure host data. The Phase-30 cuBLAS fast-path adds no per-plan state: cuBLAS handles live in a thread-local cache, so the plan itself stays trivially thread-safe.
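A minimal sketch of the thread-local pattern described above, in plain Rust with no CUDA dependency. FakeBlasHandle, with_thread_handle and PlanLike are hypothetical stand-ins, not this crate's API. The fake handle holds a raw pointer (so it is neither Send nor Sync, as an FFI handle wrapper typically is) and lives only in a thread_local! slot, while the plan-like struct carries plain host data and therefore stays Send + Sync.

```rust
use std::cell::RefCell;
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical stand-in for a cuBLAS handle. The raw pointer makes it
// neither Send nor Sync, like a typical FFI handle wrapper.
struct FakeBlasHandle {
    _raw: *mut u8,
    id: usize,
}

// Global counter so each created handle gets a distinct id.
static HANDLES_CREATED: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    // At most one handle per OS thread, created lazily, never shared.
    static HANDLE_CACHE: RefCell<Option<FakeBlasHandle>> = RefCell::new(None);
}

/// Runs `f` with the calling thread's handle, creating it on first use.
fn with_thread_handle<R>(f: impl FnOnce(&FakeBlasHandle) -> R) -> R {
    HANDLE_CACHE.with(|cell| {
        let mut slot = cell.borrow_mut();
        let handle = slot.get_or_insert_with(|| FakeBlasHandle {
            _raw: std::ptr::null_mut(),
            id: HANDLES_CREATED.fetch_add(1, Ordering::SeqCst),
        });
        f(&*handle)
    })
}

/// Hypothetical plan-like type: only host-side data, no handles.
#[derive(Clone, Debug)]
struct PlanLike {
    m: usize,
    n: usize,
    k: usize,
}

/// Compiles only if `T` is Send + Sync.
fn assert_send_sync<T: Send + Sync>() {}
```

Because the handle never leaves the thread that created it, the plan can be moved to or shared with other threads without any locking around the handle.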

See the crate root for usage; key methods:

  • select — pick a kernel for a problem shape.
  • can_implement — host-side validation.
  • workspace_size — bytes of scratch needed.
  • run — launch on a stream.
  • sku — identity of the chosen kernel.
  • backend — which backend (CUTLASS / cuBLAS) was picked. Phase 30 added the cuBLAS fast-path for f16/bf16 low-M decode shapes; the heuristic is documented on should_use_cublas_for_fp.

Implementations

impl<T> GemmPlan<T>
where T: Element,


pub fn select(
    stream: &Stream,
    desc: &GemmDescriptor,
    pref: PlanPreference,
) -> Result<GemmPlan<T>, Error>

Pick a kernel for desc.

Queries the stream’s device for its compute capability and selects among three candidates: CUTLASS-sm_80 (forward-compatible across Ampere / Ada / Hopper), CUTLASS-sm_90a (Hopper-specialized; used only when the feature is enabled and the device actually is Hopper), and the Phase-30 cuBLAS fast-path. Build features determine which kernels are available; the device’s compute capability and the f16/bf16 low-M heuristic decide which one is used. See should_use_cublas_for_fp for the dispatch rules, and override the heuristic via PlanPreference::prefer_backend.
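The selection order described above can be sketched as a pure function. Everything here is hypothetical: the DType, Backend and Features types are illustrative, the LOW_M_THRESHOLD value is an assumption (the real rule is documented on should_use_cublas_for_fp), and this page does not say whether the real select ignores or rejects a preference for an unavailable backend; this sketch ignores it and falls through to the heuristic.

```rust
/// Hypothetical element types relevant to dispatch.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DType {
    F16,
    Bf16,
    F32,
}

/// Hypothetical labels for the three kernel families named in the docs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Backend {
    CutlassSm80,
    CutlassSm90a,
    Cublas,
}

/// Which optional kernels were compiled in (build features).
struct Features {
    sm90a: bool,
    cublas: bool,
}

/// ASSUMED cutoff for a "low-M decode shape"; not taken from the crate.
const LOW_M_THRESHOLD: usize = 16;

/// Returns None when the device is too old for the sketch's kernels.
fn select_backend(
    cc: (u32, u32),
    dtype: DType,
    m: usize,
    feats: &Features,
    prefer: Option<Backend>,
) -> Option<Backend> {
    // Sketch's choice: refuse pre-Ampere devices, so the CUTLASS-sm_80
    // path is always viable as a fallback.
    if cc < (8, 0) {
        return None;
    }
    // sm_90a is only usable on Hopper (compute capability 9.0) and only if compiled in.
    let sm90a_ok = feats.sm90a && cc == (9, 0);
    // An explicit preference wins, if that backend is actually usable.
    if let Some(p) = prefer {
        let usable = match p {
            Backend::CutlassSm80 => true,
            Backend::CutlassSm90a => sm90a_ok,
            Backend::Cublas => feats.cublas,
        };
        if usable {
            return Some(p);
        }
    }
    // Heuristic: f16/bf16 with small M goes to cuBLAS when it is compiled in.
    let low_m_fp = matches!(dtype, DType::F16 | DType::Bf16) && m <= LOW_M_THRESHOLD;
    if feats.cublas && low_m_fp {
        return Some(Backend::Cublas);
    }
    Some(if sm90a_ok { Backend::CutlassSm90a } else { Backend::CutlassSm80 })
}
```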


pub fn backend(&self) -> BackendKind

Which backend this plan picked.

CUTLASS (the default path) or cuBLAS (the Phase-30 fast-path for f16/bf16 low-M decode shapes). The dispatch heuristic is documented on should_use_cublas_for_fp.


pub fn can_implement(&self, args: &GemmArgs<'_, T>) -> Result<(), Error>

Validate that this plan can actually launch with args.

Two-stage check:

  1. Host-side: shape/stride/buffer-size validation in pure Rust.
  2. Kernel-side: calls CUTLASS’s Gemm::can_implement host adapter through a no-launch FFI symbol to catch alignment and kernel-support issues that the host-side check can’t see (e.g., the selected tile’s elements-per-access requirement on lda/ldb).

Returns without launching a kernel and without touching the device. Use this as a clean pre-launch branch point: if it returns Ok, run will succeed barring runtime CUDA errors.
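The host-side stage can be approximated as follows. This is a sketch under stated assumptions, not the crate’s check: it assumes row-major operands A (M x K), B (K x N) and C (M x N), each described by a leading dimension and a buffer length in elements, and it stands in for the kernel-side elements-per-access requirement with a single assumed `align` that every leading dimension must be divisible by. Operand, CheckError and can_implement_sketch are hypothetical names.

```rust
/// Hypothetical operand description: leading dimension and buffer length, in elements.
struct Operand {
    ld: usize,
    len: usize,
}

#[derive(Debug, PartialEq, Eq)]
enum CheckError {
    LeadingDimTooSmall { which: &'static str, ld: usize, min: usize },
    BufferTooSmall { which: &'static str, len: usize, need: usize },
    Misaligned { which: &'static str, ld: usize, align: usize },
}

/// Elements spanned by a row-major `rows x cols` matrix with leading dimension `ld`.
fn min_len(rows: usize, cols: usize, ld: usize) -> usize {
    if rows == 0 || cols == 0 {
        0
    } else {
        (rows - 1) * ld + cols
    }
}

fn check_operand(
    which: &'static str,
    rows: usize,
    cols: usize,
    op: &Operand,
    align: usize,
) -> Result<(), CheckError> {
    // Shape/stride: each row must fit inside one leading-dimension stride.
    if op.ld < cols {
        return Err(CheckError::LeadingDimTooSmall { which, ld: op.ld, min: cols });
    }
    // Buffer size: the last row must end inside the buffer.
    let need = min_len(rows, cols, op.ld);
    if op.len < need {
        return Err(CheckError::BufferTooSmall { which, len: op.len, need });
    }
    // Stand-in for the kernel-side alignment rule (align clamped to at least 1).
    let align = align.max(1);
    if op.ld % align != 0 {
        return Err(CheckError::Misaligned { which, ld: op.ld, align });
    }
    Ok(())
}

/// Checks A (m x k), B (k x n) and C (m x n) in order, stopping at the first failure.
fn can_implement_sketch(
    m: usize,
    n: usize,
    k: usize,
    a: &Operand,
    b: &Operand,
    c: &Operand,
    align: usize,
) -> Result<(), CheckError> {
    check_operand("A", m, k, a, align)?;
    check_operand("B", k, n, b, align)?;
    check_operand("C", m, n, c, align)
}
```

In the real plan, only the first two kinds of error come from pure-Rust checks; the alignment rule is whatever the selected CUTLASS kernel reports through its can_implement adapter.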


pub fn workspace_size(&self) -> usize

Bytes of device scratch this plan needs at run time.

Returns 0 when the kernel’s launch is workspace-free; pass Workspace::None in that case.

Phase 30 note: even when the plan picked the cuBLAS backend (which manages its own scratch internally), this method reports the CUTLASS-side workspace requirement. The cuBLAS path falls back to CUTLASS under graph capture (cuBLAS-classic calls aren’t capture-safe), so the caller must size the workspace for the CUTLASS path in case the fallback triggers. In practice CUTLASS Identity-epilogue GEMM on sm_80 reports 0 bytes for most (M, N, K), so this conservative reporting rarely costs anything.
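The conservative sizing rule can be checked with a small model. PlanInfo and its methods are hypothetical; the sketch assumes, per the note above, that a cuBLAS plan runs on CUTLASS while the stream is being captured and that cuBLAS needs no caller-provided scratch. It shows why reporting the CUTLASS figure unconditionally is always enough.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Backend {
    Cutlass,
    Cublas,
}

/// Hypothetical host-side summary of a plan.
struct PlanInfo {
    backend: Backend,
    cutlass_workspace_bytes: usize,
}

impl PlanInfo {
    /// Mirrors the documented policy: always report the CUTLASS requirement.
    fn workspace_size(&self) -> usize {
        self.cutlass_workspace_bytes
    }

    /// Backend that actually executes; a cuBLAS plan falls back under graph capture.
    fn effective_backend(&self, capturing: bool) -> Backend {
        match (self.backend, capturing) {
            (Backend::Cublas, true) => Backend::Cutlass,
            (b, _) => b,
        }
    }

    /// Caller-provided scratch needed by whichever backend executes.
    fn scratch_needed(&self, capturing: bool) -> usize {
        match self.effective_backend(capturing) {
            Backend::Cutlass => self.cutlass_workspace_bytes,
            // Assumed: cuBLAS manages its own scratch internally.
            Backend::Cublas => 0,
        }
    }
}
```

Since workspace_size() equals the larger of the two possible requirements, a buffer sized once covers both the normal cuBLAS launch and the capture-time CUTLASS fallback.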


pub fn sku(&self) -> GemmSku

Identity of the kernel this plan chose.


pub fn precision_guarantee(&self) -> PrecisionGuarantee

Numerical guarantees this plan’s kernel provides.

Convenience wrapper, equivalent to calling GemmSku::precision_guarantee on this plan’s SKU. Useful for callers that maintain a per-decision-point alternatives table (e.g. picking between cuBLAS and CUTLASS for a given precision contract) without having to re-derive the guarantees from per-kernel documentation.
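A per-decision-point alternatives table might look like the sketch below. The fields of the real PrecisionGuarantee are not shown on this page, so Guarantee and its two boolean axes (fp32 accumulation and run-to-run determinism) are purely hypothetical, as are the candidate names used with it; the pattern is simply to list candidates in preference order and take the first whose guarantee meets the contract.

```rust
/// Hypothetical stand-in for PrecisionGuarantee; the real fields may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Guarantee {
    fp32_accumulate: bool,
    deterministic: bool,
}

impl Guarantee {
    /// True if `self` is at least as strong as `required` on every axis.
    fn satisfies(&self, required: &Guarantee) -> bool {
        (self.fp32_accumulate || !required.fp32_accumulate)
            && (self.deterministic || !required.deterministic)
    }
}

/// Returns the first candidate, in preference order, that meets the contract.
fn pick<'a>(candidates: &'a [(&'a str, Guarantee)], required: &Guarantee) -> Option<&'a str> {
    candidates
        .iter()
        .find(|(_, g)| g.satisfies(required))
        .map(|(name, _)| *name)
}
```

In real code each table entry’s guarantee would come from precision_guarantee() on the corresponding plan rather than being written by hand.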


pub fn run(
    &self,
    stream: &Stream,
    workspace: Workspace<'_>,
    args: GemmArgs<'_, T>,
) -> Result<(), Error>

Launch the kernel.

If workspace_size() is non-zero, workspace must be a buffer of at least that many bytes; if it is zero, pass Workspace::None. The stream must belong to the same CUDA context as the device buffers in args.
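The workspace contract for run can be expressed as a small check. WorkspaceSketch, WsError, check_workspace and alloc_scratch are hypothetical, and the device buffer is modeled as a host byte slice. The sketch also assumes that passing a buffer when none is required is harmless, which the text above does not state.

```rust
/// Hypothetical mirror of the Workspace argument; a host slice stands in for device memory.
enum WorkspaceSketch<'a> {
    None,
    Buffer(&'a mut [u8]),
}

#[derive(Debug, PartialEq, Eq)]
enum WsError {
    Missing { need: usize },
    TooSmall { have: usize, need: usize },
}

/// Validates `ws` against the plan's reported requirement in bytes.
fn check_workspace(required: usize, ws: &WorkspaceSketch<'_>) -> Result<(), WsError> {
    match ws {
        // Zero requirement: Workspace::None is the expected argument.
        WorkspaceSketch::None if required == 0 => Ok(()),
        WorkspaceSketch::None => Err(WsError::Missing { need: required }),
        // Non-zero requirement: the buffer must be at least `required` bytes.
        WorkspaceSketch::Buffer(b) if b.len() >= required => Ok(()),
        WorkspaceSketch::Buffer(b) => Err(WsError::TooSmall { have: b.len(), need: required }),
    }
}

/// Allocates scratch only when needed, mirroring "Workspace::None when zero".
fn alloc_scratch(required: usize) -> Option<Vec<u8>> {
    if required == 0 {
        None
    } else {
        Some(vec![0u8; required])
    }
}
```

In a real call sequence the required size would come from workspace_size() after can_implement has returned Ok.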

Trait Implementations

impl<T> Debug for GemmPlan<T>
where T: Debug + Element,


fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter.

Auto Trait Implementations

impl<T> Freeze for GemmPlan<T>


impl<T> RefUnwindSafe for GemmPlan<T>
where T: RefUnwindSafe,


impl<T> Send for GemmPlan<T>
where T: Send,


impl<T> Sync for GemmPlan<T>
where T: Sync,


impl<T> Unpin for GemmPlan<T>
where T: Unpin,


impl<T> UnsafeUnpin for GemmPlan<T>


impl<T> UnwindSafe for GemmPlan<T>
where T: UnwindSafe,

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,


fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,


fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,


fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T


fn from(t: T) -> T

Returns the argument unchanged.


impl<T, U> Into<U> for T
where U: From<T>,


fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.


impl<T, U> TryFrom<U> for T
where U: Into<T>,


type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,


type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.