Struct PreparedGroupedGemm 

pub struct PreparedGroupedGemm<'a, T>
where T: Element,
{ /* private fields */ }

A GroupedGemmPlan bound to a concrete set of per-group problems.

Owns a PinnedBuffer<u8> holding the packed metadata (problem sizes, pointer arrays, leading dimensions). Because that buffer is pinned host memory, the host-to-device (H2D) copy inside run is genuinely asynchronous and can therefore be captured safely into a CUDA graph. Owns no device memory; the caller supplies it via Workspace::Borrowed at run time.

Lifetime contract

PreparedGroupedGemm extracts raw device pointers from the input GroupedProblem slice during prepare and stores them in pinned memory; it does not hold a Rust borrow on the input buffers afterwards. This is required for stream capture: the captured graph references the pinned buffer (for the metadata H2D) and the device buffers (via the pointer arrays) by raw address, not by Rust lifetime. The borrow checker therefore cannot enforce validity here, and the caller must keep both this PreparedGroupedGemm and the underlying device buffers alive for as long as any captured graph that references them is in use.

In practice the pattern is: build groups, call prepare, capture into a graph, then keep PreparedGroupedGemm plus the input/output device buffers alive for the lifetime of the captured graph.
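One way to make that pattern hard to get wrong is to bundle the graph together with everything it references by raw address. The following is a minimal host-only sketch of that idea; `DeviceBuffer`, `Prepared`, `CapturedGraph`, and `GraphBundle` are hypothetical stand-ins invented for illustration and are not this crate's types. It relies only on the fact that Rust drops struct fields in declaration order.

```rust
// Hypothetical stand-ins for illustration only; none of these are the crate's real types.
struct DeviceBuffer {
    bytes: Vec<u8>, // models a device allocation with a stable address
}

struct Prepared {
    pinned_meta: Vec<u8>, // models the pinned metadata buffer owned by the prepared plan
}

// Models a captured graph: it remembers raw addresses, not Rust borrows.
struct CapturedGraph {
    meta_addr: usize,
    buf_addrs: Vec<usize>,
}

// Fields are dropped in declaration order, so the graph is destroyed
// before the prepared plan and the buffers it points into.
struct GraphBundle {
    graph: CapturedGraph,
    prepared: Prepared,
    buffers: Vec<DeviceBuffer>,
}

impl GraphBundle {
    fn new(prepared: Prepared, buffers: Vec<DeviceBuffer>) -> Self {
        // "Capture": record raw addresses of everything the graph will touch.
        let graph = CapturedGraph {
            meta_addr: prepared.pinned_meta.as_ptr() as usize,
            buf_addrs: buffers.iter().map(|b| b.bytes.as_ptr() as usize).collect(),
        };
        GraphBundle { graph, prepared, buffers }
    }

    // True while every recorded address still matches a live allocation in the bundle.
    fn addresses_still_valid(&self) -> bool {
        self.graph.meta_addr == self.prepared.pinned_meta.as_ptr() as usize
            && self.graph.buf_addrs.len() == self.buffers.len()
            && self
                .graph
                .buf_addrs
                .iter()
                .zip(&self.buffers)
                .all(|(addr, buf)| *addr == buf.bytes.as_ptr() as usize)
    }
}

fn main() {
    let prepared = Prepared { pinned_meta: vec![0u8; 64] };
    let buffers = vec![
        DeviceBuffer { bytes: vec![0u8; 16] },
        DeviceBuffer { bytes: vec![0u8; 16] },
    ];
    let bundle = GraphBundle::new(prepared, buffers);
    assert!(bundle.addresses_still_valid());
}
```

Moving the bundle around does not move the heap allocations, so the recorded addresses stay valid until the bundle itself is dropped. In real code the same shape would hold the actual graph handle, the PreparedGroupedGemm, and the device buffers.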

Implementations

impl<'a, T> PreparedGroupedGemm<'a, T>
where T: Element,

pub fn workspace_size(&self) -> usize

Total bytes of device workspace this plan needs at run time.

Includes the packed metadata layout, CUTLASS’s internal scratch tail, and the alignment padding between them.
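As a rough illustration of how such a total could be composed, the sketch below rounds the metadata size up to the scratch alignment and then adds the scratch size. The helper names, the `[metadata | padding | scratch]` ordering of the arithmetic, and all numbers are assumptions made for this example; the actual sizes and alignment come from the plan and are not documented here.

```rust
/// Round `offset` up to the next multiple of `align`.
/// `align` must be a power of two.
fn align_up(offset: usize, align: usize) -> usize {
    assert!(align.is_power_of_two());
    (offset + align - 1) & !(align - 1)
}

/// Hypothetical layout: [ metadata | padding | scratch ].
/// The scratch tail starts at the first `scratch_align`-aligned offset
/// at or after the end of the metadata.
fn total_workspace_bytes(
    metadata_bytes: usize,
    scratch_bytes: usize,
    scratch_align: usize,
) -> usize {
    align_up(metadata_bytes, scratch_align) + scratch_bytes
}

fn main() {
    // Made-up numbers: 200 bytes of metadata, 1024 bytes of scratch,
    // scratch aligned to 128 bytes.
    let total = total_workspace_bytes(200, 1024, 128);
    // 200 rounds up to 256, plus 1024 bytes of scratch.
    println!("workspace bytes: {total}");
}
```

A caller would normally not recompute this; workspace_size already reports the total, and the sketch only shows why the result can exceed the sum of the two parts.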

pub fn sku(&self) -> GemmSku

Identity of the kernel this plan chose (forwarded from the parent GroupedGemmPlan).

pub fn group_count(&self) -> usize

Group count this plan was prepared for.

pub fn run(&self, stream: &Stream, workspace: Workspace<'_>) -> Result<(), Error>

Launch the grouped GEMM.

Uploads the packed metadata to the start of workspace via async H2D on stream, then enqueues the grouped kernel using the remainder of the workspace as CUTLASS internal scratch.
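The head/tail split described above can be modelled on the host with plain byte slices. This is only a sketch: `stage_and_split` and `scratch_offset` are hypothetical names, the copy here is a synchronous `copy_from_slice` rather than an async H2D on a stream, and the real workspace is device memory.

```rust
/// Host-side model of the workspace split (hypothetical helper, not the crate's API).
/// Copies `metadata` to the start of `workspace` and returns the tail that
/// begins at `scratch_offset` for use as scratch space.
fn stage_and_split<'w>(
    workspace: &'w mut [u8],
    metadata: &[u8],
    scratch_offset: usize,
) -> Result<&'w mut [u8], String> {
    if scratch_offset < metadata.len() {
        return Err("scratch region would overlap the metadata".to_string());
    }
    if workspace.len() < scratch_offset {
        return Err("workspace is too small for the metadata region".to_string());
    }
    // Head holds the metadata plus any alignment padding; tail is scratch.
    let (head, scratch) = workspace.split_at_mut(scratch_offset);
    head[..metadata.len()].copy_from_slice(metadata);
    Ok(scratch)
}

fn main() {
    let mut workspace = vec![0u8; 32];
    let metadata = [1u8, 2, 3];
    // Made-up offset: scratch starts 8 bytes in, leaving 5 bytes of padding.
    match stage_and_split(&mut workspace, &metadata, 8) {
        Ok(scratch) => println!("scratch bytes available: {}", scratch.len()),
        Err(e) => println!("error: {e}"),
    }
}
```

The size check mirrors why a caller should allocate at least workspace_size bytes before calling run: a workspace that cannot hold the metadata region leaves nothing for the kernel's scratch.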

Trait Implementations

impl<'a, T> Debug for PreparedGroupedGemm<'a, T>
where T: Debug + Element,

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter.

Auto Trait Implementations

impl<'a, T> Freeze for PreparedGroupedGemm<'a, T>

impl<'a, T> RefUnwindSafe for PreparedGroupedGemm<'a, T>
where T: RefUnwindSafe,

impl<'a, T> Send for PreparedGroupedGemm<'a, T>
where T: Send + Sync,

impl<'a, T> Sync for PreparedGroupedGemm<'a, T>
where T: Sync,

impl<'a, T> Unpin for PreparedGroupedGemm<'a, T>
where T: Unpin,

impl<'a, T> UnsafeUnpin for PreparedGroupedGemm<'a, T>

impl<'a, T> UnwindSafe for PreparedGroupedGemm<'a, T>

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.