pub struct GemmPlan<T>
where
    T: Element,
{ /* private fields */ }
Selected GEMM kernel and the host-side metadata needed to launch it.
Plans are cheap to construct and hold no device memory. Because they
are pure host data, they are also Send + Sync. The Phase-30 cuBLAS
fast-path adds no per-plan state: cuBLAS handles live in a
thread-local cache, so the plan itself stays trivially thread-safe.
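A minimal sketch of why keeping handles out of the plan preserves thread safety, assuming a handle type that must stay on the thread that created it. FakeHandle, PlanStandIn, NEXT_ID, HANDLE, and with_handle are hypothetical illustrations, not this crate’s internals. A struct of plain host data gets Send + Sync from the compiler automatically, while the thread-bound handle lives in a thread_local! slot that each thread fills lazily on first use.

```rust
use std::cell::RefCell;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread;

// Hypothetical stand-in for a library handle bound to one thread.
// PhantomData<*const ()> makes the type !Send and !Sync.
struct FakeHandle {
    id: u64,
    _not_send: PhantomData<*const ()>,
}

// Hypothetical source of distinct handle IDs, so threads can be told apart.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);

thread_local! {
    // One lazily created handle per thread; plans never store it.
    static HANDLE: RefCell<Option<FakeHandle>> = RefCell::new(None);
}

// Hypothetical plan stand-in: only plain host data, so the compiler
// derives Send + Sync automatically.
#[derive(Clone, Debug)]
struct PlanStandIn {
    m: usize,
    n: usize,
    k: usize,
}

// Compiles only if T is Send + Sync.
fn assert_send_sync<T: Send + Sync>() {}

// Run `f` with this thread's handle, creating the handle on first use.
fn with_handle<R>(f: impl FnOnce(&FakeHandle) -> R) -> R {
    HANDLE.with(|cell| {
        let mut slot = cell.borrow_mut();
        let handle = slot.get_or_insert_with(|| FakeHandle {
            id: NEXT_ID.fetch_add(1, Ordering::Relaxed),
            _not_send: PhantomData,
        });
        f(&*handle)
    })
}

fn main() {
    assert_send_sync::<PlanStandIn>();
    let plan = PlanStandIn { m: 1, n: 4096, k: 4096 };
    let main_id = with_handle(|h| h.id);

    // The plan moves to another thread freely; that thread gets its own handle.
    let worker = thread::spawn(move || {
        println!("worker plan: {:?}", plan);
        with_handle(|h| h.id)
    });
    let worker_id = worker.join().unwrap();

    assert_eq!(main_id, with_handle(|h| h.id)); // cached per thread
    assert_ne!(main_id, worker_id); // distinct per thread
    println!("main handle {main_id}, worker handle {worker_id}");
}
```

Keeping the handle out of the struct is what lets the compiler prove thread safety; storing FakeHandle as a plan field would make the plan !Send.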
See the crate root for usage. Key methods:
- select: pick a kernel for a problem shape.
- can_implement: host-side validation.
- workspace_size: bytes of scratch needed.
- run: launch on a stream.
- sku: identity of the chosen kernel.
- backend: which backend (CUTLASS or cuBLAS) was picked.
Phase 30 added the cuBLAS fast-path for f16/bf16 low-M decode shapes;
the heuristic is documented on should_use_cublas_for_fp.
Implementations
impl<T> GemmPlan<T>
where
    T: Element,
pub fn select(
    stream: &Stream,
    desc: &GemmDescriptor,
    pref: PlanPreference,
) -> Result<GemmPlan<T>, Error>
Pick a kernel for desc.
Queries the stream’s device for its compute capability and selects
among three paths: CUTLASS sm_80 (forward-compatible across Ampere,
Ada, and Hopper), CUTLASS sm_90a (Hopper-specialized; used only when
the feature is enabled and the device actually is Hopper), and the
Phase-30 cuBLAS fast-path. Build features determine which kernels
are available; the device’s compute capability and the f16/bf16
low-M heuristic decide which one is used. See should_use_cublas_for_fp
for the dispatch rules. Override the heuristic via
PlanPreference::prefer_backend.
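The selection order described above can be sketched as a pure function, under stated assumptions. Kernel, Pref, and pick_kernel are hypothetical; `cublas_heuristic` is a boolean standing in for the real should_use_cublas_for_fp rule, whose conditions are not reproduced here, and rejecting devices below compute capability 8.0 is an assumption of the sketch rather than documented behavior.

```rust
// Hypothetical stand-ins for the kernels and backend preference described
// above; the crate's real types and rules are not reproduced here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Kernel {
    CutlassSm80,
    CutlassSm90a,
    Cublas,
}

#[derive(Clone, Copy, Debug)]
enum Pref {
    Cutlass,
    Cublas,
}

/// Sketch of the selection order: an explicit preference wins, otherwise
/// the cuBLAS heuristic decides, otherwise CUTLASS. `cublas_heuristic`
/// stands in for the real should_use_cublas_for_fp rule.
fn pick_kernel(
    cc: (u32, u32),         // device compute capability (major, minor)
    sm90a_built: bool,      // was the sm_90a kernel compiled in?
    cublas_heuristic: bool, // did the f16/bf16 low-M heuristic fire?
    pref: Option<Pref>,
) -> Option<Kernel> {
    // Assumption for this sketch: nothing below sm_80 is supported.
    if cc < (8, 0) {
        return None;
    }
    // sm_90a code runs only on Hopper itself (compute capability 9.0).
    let cutlass = if sm90a_built && cc == (9, 0) {
        Kernel::CutlassSm90a
    } else {
        Kernel::CutlassSm80
    };
    let use_cublas = match pref {
        Some(Pref::Cublas) => true,
        Some(Pref::Cutlass) => false,
        None => cublas_heuristic,
    };
    Some(if use_cublas { Kernel::Cublas } else { cutlass })
}

fn main() {
    // Ada (8.9) with the sm_90a feature enabled still uses the sm_80 kernel.
    println!("{:?}", pick_kernel((8, 9), true, false, None));
    // Hopper with the feature enabled gets the specialized kernel.
    println!("{:?}", pick_kernel((9, 0), true, false, None));
    // A preference overrides the heuristic.
    println!("{:?}", pick_kernel((9, 0), true, true, Some(Pref::Cutlass)));
}
```

Separating "which CUTLASS kernel" from "CUTLASS or cuBLAS" mirrors the text: build features and the device decide the former, while the heuristic or an explicit preference decides the latter.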
pub fn backend(&self) -> BackendKind
Which backend this plan picked: CUTLASS (the default path) or
cuBLAS (the Phase-30 fast-path for f16/bf16 low-M decode shapes).
The dispatch heuristic is documented on should_use_cublas_for_fp.
pub fn can_implement(&self, args: &GemmArgs<'_, T>) -> Result<(), Error>
Validate that this plan can actually launch with args.
Two-stage check:
- Host-side: shape/stride/buffer-size validation in pure Rust.
- Kernel-side: calls CUTLASS’s Gemm::can_implement host adapter via a no-launch FFI symbol to catch alignment and kernel-support issues that the host can’t see (e.g., the selected tile’s elements-per-access requirement on lda/ldb).
Returns without launching a kernel and without touching the device.
Use this as a clean pre-launch branch point: if it returns Ok, the
run call will succeed barring runtime CUDA errors.
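A simplified sketch of the two-stage check for a single operand, assuming a row-major layout and 8 elements per access (128-bit loads of 16-bit elements). check_operand and CheckError are hypothetical; the second stage only approximates the kind of alignment rule CUTLASS enforces and is not its actual check, which also depends on the selected kernel.

```rust
// Hypothetical error type for the sketch.
#[derive(Debug, PartialEq, Eq)]
enum CheckError {
    BadLeadingDim,
    BufferTooSmall,
    Misaligned,
}

/// Host-side check for one row-major operand (`rows` x `cols`, leading
/// dimension `ld`) stored in a buffer of `buf_len` elements, followed by an
/// alignment check standing in for the kernel-side query.
/// `elems_per_access` must be non-zero.
fn check_operand(
    rows: usize,
    cols: usize,
    ld: usize,
    buf_len: usize,
    elems_per_access: usize,
) -> Result<(), CheckError> {
    // Stage 1 (host-side): each row must fit within its stride...
    if ld < cols {
        return Err(CheckError::BadLeadingDim);
    }
    // ...and the last element, at (rows - 1) * ld + (cols - 1), must be in bounds.
    let needed = if rows == 0 || cols == 0 { 0 } else { (rows - 1) * ld + cols };
    if buf_len < needed {
        return Err(CheckError::BufferTooSmall);
    }
    // Stage 2 (stand-in for the kernel-side check): vectorized loads need
    // every row to start on an access boundary, so ld must be a multiple.
    if ld % elems_per_access != 0 {
        return Err(CheckError::Misaligned);
    }
    Ok(())
}

fn main() {
    // A: M x K with M = 4, K = 8, lda = 8, f16 with 8 elements per access.
    println!("{:?}", check_operand(4, 8, 8, 32, 8));
    // Same shape with lda = 12: in bounds, but rows no longer start on
    // an 8-element boundary.
    println!("{:?}", check_operand(4, 8, 12, 64, 8));
}
```

The second call passes every host-side bounds check yet still fails, which is the class of problem the kernel-side stage exists to catch before any launch.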
pub fn workspace_size(&self) -> usize
Bytes of device scratch this plan needs at run time.
Returns 0 when the kernel’s launch is workspace-free; pass
Workspace::None in that case.
Phase 30 note: even when the plan picked the cuBLAS backend (which manages its own scratch internally), this method reports the CUTLASS-side workspace requirement. The cuBLAS path falls back to CUTLASS under graph capture (cuBLAS-classic calls aren’t capture-safe), so the caller must size the workspace for the CUTLASS path in case the fallback triggers. In practice CUTLASS Identity-epilogue GEMM on sm_80 reports 0 bytes for most (M, N, K), so this conservative reporting rarely costs anything.
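The reasoning in the Phase 30 note can be checked exhaustively in a few lines, assuming only what the note states: cuBLAS falls back to CUTLASS under graph capture, and cuBLAS uses no caller-provided scratch. Backend, launch_backend, and caller_scratch_needed are hypothetical helpers, and the 256-byte figure is an arbitrary example value.

```rust
// Hypothetical stand-in for BackendKind.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Backend {
    Cutlass,
    Cublas,
}

/// Backend that actually launches: cuBLAS is not capture-safe, so a plan
/// that picked cuBLAS falls back to CUTLASS while a graph is being captured.
fn launch_backend(picked: Backend, capturing: bool) -> Backend {
    if capturing { Backend::Cutlass } else { picked }
}

/// Caller-provided scratch the launch consumes: CUTLASS uses the caller's
/// workspace, while cuBLAS manages its own scratch internally.
fn caller_scratch_needed(launched: Backend, cutlass_bytes: usize) -> usize {
    match launched {
        Backend::Cutlass => cutlass_bytes,
        Backend::Cublas => 0,
    }
}

fn main() {
    let cutlass_bytes = 256; // hypothetical non-zero CUTLASS requirement
    for picked in [Backend::Cutlass, Backend::Cublas] {
        for capturing in [false, true] {
            let launched = launch_backend(picked, capturing);
            let need = caller_scratch_needed(launched, cutlass_bytes);
            // Reporting the CUTLASS size covers every combination.
            assert!(need <= cutlass_bytes);
            println!("picked {picked:?}, capturing {capturing}: launches {launched:?}, needs {need} B");
        }
    }
}
```

Reporting the CUTLASS size is the smallest single number that covers all four combinations, since the (cuBLAS, capturing) case launches CUTLASS.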
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Numerical guarantees this plan’s kernel provides.
Convenience for GemmSku::precision_guarantee applied to this
plan’s SKU. Useful for callers that maintain a per-decision-point
alternatives table (e.g. picking between cuBLAS and CUTLASS for a
given precision contract) without having to re-derive the
guarantees from per-kernel documentation.
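A sketch of the per-decision-point alternatives table mentioned above. The real PrecisionGuarantee fields are not shown on this page, so Guarantee, its two boolean properties, meets, and pick are hypothetical, and the table entries are made-up values rather than claims about cuBLAS or CUTLASS.

```rust
// Hypothetical stand-in for PrecisionGuarantee; the real type's fields are
// not documented on this page, and these two properties are illustrative.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Guarantee {
    f32_accumulate: bool,
    deterministic: bool,
}

/// True if `have` provides every property that `need` asks for.
fn meets(have: &Guarantee, need: &Guarantee) -> bool {
    (!need.f32_accumulate || have.f32_accumulate) && (!need.deterministic || have.deterministic)
}

/// First alternative (in preference order) whose guarantee meets `need`.
fn pick<'a>(table: &[(&'a str, Guarantee)], need: &Guarantee) -> Option<&'a str> {
    table.iter().find(|entry| meets(&entry.1, need)).map(|entry| entry.0)
}

fn main() {
    // Made-up entries: in practice each would come from a plan's
    // precision_guarantee() rather than being written by hand.
    let table = [
        ("fast_path", Guarantee { f32_accumulate: true, deterministic: false }),
        ("default_path", Guarantee { f32_accumulate: true, deterministic: true }),
    ];
    let need = Guarantee { f32_accumulate: true, deterministic: true };
    println!("{:?}", pick(&table, &need));
}
```

Ordering the table by preference means the first entry that satisfies the contract wins, so the fastest acceptable option is chosen without re-reading per-kernel documentation.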
pub fn run(
    &self,
    stream: &Stream,
    workspace: Workspace<'_>,
    args: GemmArgs<'_, T>,
) -> Result<(), Error>
Launch the kernel.
workspace must be Workspace::None when workspace_size returns 0;
otherwise it must provide at least workspace_size bytes. The stream
must be in the same context as the device buffers in args.
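The full call sequence (select, can_implement, workspace_size, run) can be sketched with a host-only mock. MockPlan, MockError, gemm, and the mock Workspace (including its Buffer variant) are hypothetical stand-ins, not this crate’s types: no GPU work happens, the simplified signatures drop the stream and args, and the "large K needs scratch" rule is invented purely to exercise both workspace forms.

```rust
// Hypothetical mock of the call sequence on this page; none of these types
// are the crate's real ones, and no GPU work happens.
#[derive(Debug, PartialEq, Eq)]
enum MockError {
    Invalid(&'static str),
    WorkspaceTooSmall { need: usize, got: usize },
}

// Hypothetical mirror of Workspace; only the None variant appears on this page.
enum Workspace<'a> {
    None,
    Buffer(&'a mut [u8]),
}

struct MockPlan {
    workspace_bytes: usize,
}

impl MockPlan {
    fn select(m: usize, n: usize, k: usize) -> Result<Self, MockError> {
        if m == 0 || n == 0 || k == 0 {
            return Err(MockError::Invalid("empty shape"));
        }
        // Invented rule for the mock: very large K needs scratch.
        let workspace_bytes = if k > 8192 { m * n * 4 } else { 0 };
        Ok(MockPlan { workspace_bytes })
    }

    // Pure host-side check: nothing is launched.
    fn can_implement(&self, lda: usize) -> Result<(), MockError> {
        if lda % 8 != 0 {
            return Err(MockError::Invalid("lda not a multiple of 8"));
        }
        Ok(())
    }

    fn workspace_size(&self) -> usize {
        self.workspace_bytes
    }

    fn run(&self, workspace: Workspace<'_>) -> Result<(), MockError> {
        let got = match workspace {
            Workspace::None => 0,
            Workspace::Buffer(buf) => buf.len(),
        };
        if got < self.workspace_bytes {
            return Err(MockError::WorkspaceTooSmall { need: self.workspace_bytes, got });
        }
        Ok(()) // a real plan would enqueue the kernel on the stream here
    }
}

/// select -> can_implement -> workspace_size -> run, in that order.
fn gemm(m: usize, n: usize, k: usize, lda: usize) -> Result<(), MockError> {
    let plan = MockPlan::select(m, n, k)?;
    plan.can_implement(lda)?; // clean pre-launch branch point
    let bytes = plan.workspace_size();
    let mut scratch = vec![0u8; bytes]; // stands in for a device allocation
    let workspace = if bytes == 0 {
        Workspace::None
    } else {
        Workspace::Buffer(&mut scratch[..])
    };
    plan.run(workspace)
}

fn main() {
    println!("{:?}", gemm(16, 4096, 4096, 4096)); // workspace-free
    println!("{:?}", gemm(16, 64, 16384, 16384)); // needs scratch
    println!("{:?}", gemm(16, 64, 64, 66)); // rejected before launch
}
```

Sizing the workspace from the plan, rather than from the shape, keeps the caller correct even if a later version changes which kernels need scratch.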