pub struct GemmArgs<'a, T>
where
    T: Element,
{
    pub a: MatrixRef<'a, T>,
    pub b: MatrixRef<'a, T>,
    pub c: Option<MatrixRef<'a, T>>,
    pub d: MatrixMut<'a, T>,
    pub bias: Option<VectorRef<'a, T>>,
    pub alpha: <T as Element>::Scalar,
    pub beta: <T as Element>::Scalar,
}

Per-launch arguments for a GemmPlan::run call.
c is optional. When it is None, β is ignored at the safe layer (treated
as 0) and the kernel computes D = α · A · B. When it is Some, the
kernel computes D = α · A · B + β · C, including the in-place
accumulation case where c.data == d.data.
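The accumulation rules above can be checked against a plain CPU reference. The sketch below is illustrative only and is not the crate's kernel: it assumes f32 data and a row-major B, and it replaces Option<MatrixRef> with a hypothetical AccumSource enum so that the in-place case (c aliasing d) can be expressed without breaking Rust's borrowing rules. AccumSource and gemm_ref are names introduced here for illustration.

```rust
/// Hypothetical stand-in for `c`. The documented API uses
/// `Option<MatrixRef>` and detects in-place use via `c.data == d.data`.
#[derive(Clone, Copy)]
enum AccumSource<'a> {
    /// `c` is None: beta is ignored and D = alpha * A * B.
    Absent,
    /// `c` is a separate row-major [M, N] buffer.
    Separate(&'a [f32]),
    /// `c` aliases `d`: the current contents of D are the accumulation source.
    InPlace,
}

/// Reference GEMM (hypothetical helper). A is row-major [M, K],
/// B is row-major [K, N], D is row-major [M, N].
fn gemm_ref(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: AccumSource<'_>,
    d: &mut [f32],
    alpha: f32,
    beta: f32,
) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(d.len(), m * n);
    if let AccumSource::Separate(cs) = c {
        assert_eq!(cs.len(), m * n);
    }
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of A with column j of B.
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            let idx = i * n + j;
            // Each element reads its C value before D[idx] is overwritten,
            // so the in-place case is well defined element by element.
            let c_term = match c {
                AccumSource::Absent => 0.0, // beta treated as 0
                AccumSource::Separate(cs) => beta * cs[idx],
                AccumSource::InPlace => beta * d[idx],
            };
            d[idx] = alpha * acc + c_term;
        }
    }
}
```

With A = [[1, 2], [3, 4]], B = identity, α = 1 and AccumSource::Absent, D equals A regardless of β. With AccumSource::InPlace, β = 2 and D pre-filled with ones, D becomes [[3, 4], [5, 6]].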
bias is required iff the descriptor’s epilogue is one of the
Bias* variants, in which case the kernel computes
D = activation(α · A · B + β · C + bias_broadcast(N)),
where the β · C term is dropped when c is None.
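A minimal host-side sketch of the Bias* epilogue, assuming the accumulator α · A · B + β · C has already been formed as a row-major [M, N] f32 buffer. The Activation enum (Identity and ReLU) and bias_epilogue are hypothetical; this section does not list the crate's actual Bias* variants or which activations they select.

```rust
/// Hypothetical activation choices; the real `Bias*` variants may differ.
#[derive(Clone, Copy)]
enum Activation {
    Identity,
    Relu,
}

impl Activation {
    fn apply(self, x: f32) -> f32 {
        match self {
            Activation::Identity => x,
            Activation::Relu => x.max(0.0),
        }
    }
}

/// `acc` holds alpha * A * B + beta * C as a row-major [M, N] matrix.
/// The length-N bias is broadcast across all M rows (column j of every
/// row gets bias[j]); the activation is applied last.
fn bias_epilogue(m: usize, n: usize, acc: &mut [f32], bias: &[f32], act: Activation) {
    assert_eq!(acc.len(), m * n);
    assert_eq!(bias.len(), n, "bias must have length N");
    for i in 0..m {
        for j in 0..n {
            let idx = i * n + j;
            acc[idx] = act.apply(acc[idx] + bias[j]);
        }
    }
}
```

For acc = [[1, -5], [2, -1]] and bias = [1, 2], Identity yields [[2, -3], [3, 1]] and ReLU yields [[2, 0], [3, 1]].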
Fields

a: MatrixRef<'a, T>
Left input. Row-major [M, K].

b: MatrixRef<'a, T>
Right input. Layout depends on the descriptor’s LayoutSku:
column-major [K, N] for LayoutSku::Rcr, row-major [K, N]
for LayoutSku::Rrr.
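The two B layouts differ only in how the logical element B[p][j] (p < K, j < N) maps to a linear offset. The sketch below assumes dense storage with no padding between rows or columns (no leading-dimension parameter). BLayout, b_offset and row_to_col_major are hypothetical helpers, not part of the documented API.

```rust
/// Hypothetical mirror of the two layout SKUs named in the docs.
#[derive(Clone, Copy)]
enum BLayout {
    /// LayoutSku::Rcr: B is column-major [K, N].
    ColMajor,
    /// LayoutSku::Rrr: B is row-major [K, N].
    RowMajor,
}

/// Linear offset of logical element B[p][j], with p < K and j < N.
fn b_offset(p: usize, j: usize, k: usize, n: usize, layout: BLayout) -> usize {
    debug_assert!(p < k && j < n);
    match layout {
        // Column-major: stepping p by 1 moves to the next stored element.
        BLayout::ColMajor => j * k + p,
        // Row-major: stepping j by 1 moves to the next stored element.
        BLayout::RowMajor => p * n + j,
    }
}

/// Repack a dense row-major [K, N] buffer into column-major storage.
fn row_to_col_major(b_row: &[f32], k: usize, n: usize) -> Vec<f32> {
    assert_eq!(b_row.len(), k * n);
    let mut out = vec![0.0f32; k * n];
    for p in 0..k {
        for j in 0..n {
            let src = b_offset(p, j, k, n, BLayout::RowMajor);
            let dst = b_offset(p, j, k, n, BLayout::ColMajor);
            out[dst] = b_row[src];
        }
    }
    out
}
```

For K = 2, N = 3 and row-major data [1, 2, 3, 4, 5, 6], row_to_col_major returns [1, 4, 2, 5, 3, 6].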

c: Option<MatrixRef<'a, T>>
Optional accumulation source. Row-major [M, N].

d: MatrixMut<'a, T>
Output. Row-major [M, N].

bias: Option<VectorRef<'a, T>>
Optional bias vector. Required (Some) when the descriptor’s
epilogue is any Bias* variant; must be None for
EpilogueKind::Identity. Must be a length-N, contiguous (stride 1)
vector in device memory; it is broadcast across the M rows of D.
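The bias rules amount to a small precondition check. This sketch models the descriptor's epilogue with a hypothetical Epilogue enum (only EpilogueKind::Identity is named here; Bias and BiasRelu stand in for the Bias* family) and models the bias vector as a (length, stride) pair. check_bias and BiasError are likewise illustrative and are not the crate's error type.

```rust
/// Hypothetical epilogue descriptor; only the Identity vs Bias* split matters here.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Epilogue {
    Identity,
    Bias,
    BiasRelu,
}

impl Epilogue {
    /// Every variant other than Identity belongs to the Bias* family.
    fn needs_bias(self) -> bool {
        !matches!(self, Epilogue::Identity)
    }
}

#[derive(Debug, PartialEq)]
enum BiasError {
    Missing,
    Unexpected,
    WrongLength { expected: usize, got: usize },
    NotContiguous { stride: usize },
}

/// `bias` is modeled as Some((length, stride)) of a device vector, or None.
fn check_bias(epilogue: Epilogue, n: usize, bias: Option<(usize, usize)>) -> Result<(), BiasError> {
    match (epilogue.needs_bias(), bias) {
        (true, None) => Err(BiasError::Missing),
        (false, Some(_)) => Err(BiasError::Unexpected),
        (false, None) => Ok(()),
        (true, Some((len, stride))) => {
            if len != n {
                Err(BiasError::WrongLength { expected: n, got: len })
            } else if stride != 1 {
                Err(BiasError::NotContiguous { stride })
            } else {
                Ok(())
            }
        }
    }
}
```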

alpha: <T as Element>::Scalar
Multiplier on the matrix-multiply accumulator. Its type is
T::Scalar: f32 for f16/bf16/f32/F32Strict, and f64 for f64.
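The α/β scalar mapping is an associated type on Element. The sketch below shows one way such a mapping can be written; the Element trait here is a hypothetical stand-in, and F16, Bf16 and F32Strict are placeholder marker types (stable Rust has no built-in f16 or bf16), not the crate's definitions.

```rust
use std::any::TypeId;

/// Hypothetical stand-in for the crate's `Element` trait: each element
/// type names the scalar type used for alpha and beta.
trait Element {
    type Scalar: Copy + 'static;
}

// Placeholder marker types for illustration only.
struct F16;
struct Bf16;
struct F32Strict;

// Reduced-precision and f32 element types all use f32 scalars.
impl Element for F16 { type Scalar = f32; }
impl Element for Bf16 { type Scalar = f32; }
impl Element for f32 { type Scalar = f32; }
impl Element for F32Strict { type Scalar = f32; }
// f64 elements use f64 scalars.
impl Element for f64 { type Scalar = f64; }

/// Returns true when T's scalar type is exactly S.
fn scalar_is<T: Element, S: 'static>() -> bool {
    TypeId::of::<T::Scalar>() == TypeId::of::<S>()
}
```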

beta: <T as Element>::Scalar
Multiplier on c. Forced to 0 internally when c is None,
so callers don’t need to pre-zero it for the no-accumulate case.