pub struct PreparedGroupedGemm<'a, T>
where
    T: Element,
{ /* private fields */ }
A GroupedGemmPlan bound to a concrete set of per-group problems.
Owns a PinnedBuffer<u8> holding the packed metadata (problem
sizes, pointer arrays, leading dimensions). Pinned host memory is
what makes the host-to-device (H2D) copy inside run truly asynchronous, and
therefore safe to capture into a CUDA graph. Owns no device memory;
the caller supplies that via Workspace::Borrowed at run time.
Lifetime contract
PreparedGroupedGemm extracts raw device pointers from the input
GroupedProblem slice during prepare and stores them in pinned memory.
It does not hold a Rust borrow
on the input buffers afterwards. This is required for stream capture:
the captured graph references the pinned buffer (for the metadata
H2D) and the device buffers (via the pointer arrays) by raw address,
not by Rust lifetime. The caller must therefore keep both this
PreparedGroupedGemm and the underlying device buffers alive for as
long as any captured graph that references them is in use.
In practice the pattern is: build groups, call prepare, capture
into a graph, then keep PreparedGroupedGemm plus the input/output
device buffers alive for the lifetime of the captured graph.
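The contract above can be illustrated with a minimal, CPU-only sketch. Everything in it is a hypothetical stand-in, not this crate's API: `DeviceBuffer`, `Problem`, `Prepared`, and the free function `prepare` only mimic the shape of the real types, heap memory plays the role of device memory, and the stand-in drops the `'a` parameter that the real struct carries. What it shows is that once addresses are copied out as plain integers, the borrow checker no longer ties the prepared object to the buffers, so keeping them alive becomes the caller's job.

```rust
/// Hypothetical stand-in for a device allocation; heap memory plays the
/// role of device memory here.
struct DeviceBuffer {
    data: Vec<f32>,
}

impl DeviceBuffer {
    /// Address of the allocation as a plain integer (no lifetime attached).
    fn as_raw(&self) -> usize {
        self.data.as_ptr() as usize
    }
}

/// Hypothetical stand-in for one per-group problem; it borrows its operands.
struct Problem<'a> {
    a: &'a DeviceBuffer,
    b: &'a DeviceBuffer,
    c: &'a DeviceBuffer,
}

/// Hypothetical stand-in for the prepared object: raw addresses only.
struct Prepared {
    pointer_table: Vec<usize>,
}

/// Copies the raw addresses out of `groups`. The returned value holds no
/// borrow, so the compiler will not stop the buffers from being dropped.
fn prepare(groups: &[Problem<'_>]) -> Prepared {
    let mut pointer_table = Vec::with_capacity(groups.len() * 3);
    for g in groups {
        pointer_table.extend([g.a.as_raw(), g.b.as_raw(), g.c.as_raw()]);
    }
    Prepared { pointer_table }
}

/// Builds buffers, prepares, and returns both so that the caller can keep
/// them alive together for as long as the addresses are in use.
fn demo() -> (Prepared, Vec<DeviceBuffer>) {
    let buffers: Vec<DeviceBuffer> = (0..3)
        .map(|_| DeviceBuffer { data: vec![0.0; 16] })
        .collect();
    let prepared = {
        let groups = [Problem { a: &buffers[0], b: &buffers[1], c: &buffers[2] }];
        prepare(&groups)
        // The borrow of `buffers` ends here; `prepared` keeps only integers.
    };
    (prepared, buffers)
}
```

Returning the prepared object and the buffers as one tuple is one way to make the "keep both alive" rule hard to violate by accident; bundling them in a single owning struct next to the captured graph would serve the same purpose.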
Implementations
impl<'a, T> PreparedGroupedGemm<'a, T>
where
    T: Element,
pub fn workspace_size(&self) -> usize
Total bytes of device workspace this plan needs at run time.
Includes both the packed metadata layout and CUTLASS’s internal scratch tail, with alignment padding between them.
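As a rough illustration of how such a total could be computed, the sketch below lays out a workspace as metadata, then padding, then scratch. The function names, the parameters, and the 256-byte alignment used in the usage note are assumptions made for the example; the actual layout and alignment used by this crate and by CUTLASS are not stated here.

```rust
/// Rounds `n` up to the next multiple of `align`.
/// `align` must be a power of two.
fn align_up(n: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    (n + align - 1) & !(align - 1)
}

/// Hypothetical layout: [ metadata | padding | scratch ].
/// The metadata region is padded so that the scratch tail starts on an
/// `align`-byte boundary; the total is the padded metadata plus the scratch.
fn total_workspace_bytes(metadata_bytes: usize, scratch_bytes: usize, align: usize) -> usize {
    align_up(metadata_bytes, align) + scratch_bytes
}
```

Under these assumptions, `total_workspace_bytes(100, 512, 256)` pads the 100 metadata bytes to 256 and adds the 512-byte tail, giving 768 bytes. A caller would allocate at least `workspace_size()` bytes of device memory and pass it in as the borrowed workspace.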
pub fn sku(&self) -> GemmSku
Identity of the kernel this plan chose (forwarded from the parent
GroupedGemmPlan).
pub fn group_count(&self) -> usize
Group count this plan was prepared for.