Skip to main content

baracuda_kernels_types/
layout.rs

//! Layout / arch / epilogue / activation tags shared across kernel
//! families.
//!
//! These are pure descriptor enums that don't carry generic parameters;
//! they appear in plan descriptors, in `KernelSku` (TBD) /
//! `GemmSku`, and in selector preference fields.

/// Operand-orientation SKU for matrix-multiply-shaped kernels: records
/// whether each of `A`, `B`, and `C`/`D` is stored row-major or
/// column-major.
///
/// **Deliberately exhaustive (no `#[non_exhaustive]`).** In practice the
/// GEMM layout space is closed — it is just the row-/column-major
/// permutations of `A`, `B`, and `C/D` — and the two variants below cover
/// everything `baracuda-cutlass` dispatches on. Introducing another variant
/// is meant to be an intentional breaking change shipped with a
/// major-version bump.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum LayoutSku {
    /// `A` row-major `[M, K]`; `B` column-major `[K, N]`; `C/D` row-major `[M, N]`.
    ///
    /// Lets a row-major weight tensor laid out as `[N, K]` be read as the
    /// logical column-major `B = [K, N]` without a transpose copy.
    Rcr,
    /// `A` row-major `[M, K]`; `B` row-major `[K, N]`; `C/D` row-major `[M, N]`.
    ///
    /// Matches the usual ML-graph layout (row-major activations times
    /// row-major weights). Both operands are consumed in their native
    /// row-major storage, so no transpose pass is needed before launch.
    Rrr,
}

/// Compute-capability bucket that the selected kernel was built for.
///
/// **Deliberately exhaustive (no `#[non_exhaustive]`).** The cutlass GEMM
/// dispatchers `match` on every variant to choose a per-arch kernel SKU.
/// A new architecture (Blackwell `Sm100a` is tracked in the ROADMAP) should
/// therefore break the build at each of those match sites, so every site
/// decides whether to JIT-forward or add a dedicated variant. Adding a
/// variant is an intentional breaking-change event.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum ArchSku {
    /// Ampere. Also runs on Ada, and serves as the forward-compatible
    /// fallback on Hopper.
    Sm80,
    /// Ada Lovelace specializations (FP8 tensor cores). The consuming
    /// kernel crate must enable its `sm89` feature.
    Sm89,
    /// Hopper specializations. The consuming kernel crate must enable its
    /// `sm90a` feature.
    Sm90a,
}

/// Epilogue applied after the matrix-multiply accumulation.
///
/// All four `Bias*` variants belong to one kernel family: each fuses the
/// bias add into the output epilogue. `BiasRelu`, `BiasGelu`, and
/// `BiasSilu` additionally apply their named activation function before
/// the store, so they deliver the full `y = activation(W·x + b)`
/// transformer-Linear pipeline in a single kernel pass — no extra memory
/// traffic vs plain `Bias`, which applies no activation.
///
/// **Intentionally NOT `#[non_exhaustive]`** — the cutlass GEMM
/// dispatchers exhaustively match `(LayoutSku, EpilogueKind)` to pick
/// per-fused-epilogue kernel SKUs. Adding a new epilogue (e.g. `BiasTanh`,
/// `BiasSigmoid`) deserves to surface as a build break across every match
/// site so each branch can choose to wire it or reject it. New variants
/// are a deliberate breaking-change event.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum EpilogueKind {
    /// `D = α · (A · B) + β · C` (no activation, no bias).
    Identity,
    /// `D = α · (A · B) + β · C + bias_broadcast(N)`. The bias vector
    /// has length `N` (one element per output column) and is broadcast
    /// across rows. No activation is applied.
    Bias,
    /// `D = relu(α · (A · B) + β · C + bias_broadcast(N))`.
    /// `relu(x) = max(x, 0)`.
    BiasRelu,
    /// `D = gelu(α · (A · B) + β · C + bias_broadcast(N))` using the
    /// exact (erf-based) GELU — matches PyTorch's default `nn.GELU()`.
    BiasGelu,
    /// `D = silu(α · (A · B) + β · C + bias_broadcast(N))` where
    /// `silu(x) = x · sigmoid(x)`. Also known as Swish-1.
    BiasSilu,
}

87impl EpilogueKind {
88    /// `true` if a bias broadcast must be supplied for this epilogue.
89    /// Equivalent to "any `Bias*` variant".
90    #[inline]
91    pub const fn requires_bias(self) -> bool {
92        matches!(
93            self,
94            Self::Bias | Self::BiasRelu | Self::BiasGelu | Self::BiasSilu,
95        )
96    }
97
98    /// Activation function this epilogue applies after the linear
99    /// combination, if any.
100    ///
101    /// Returns `None` for [`Identity`](Self::Identity) and
102    /// [`Bias`](Self::Bias) (both apply no activation); returns the
103    /// corresponding [`ActivationKind`] for the `Bias*Activation`
104    /// variants.
105    #[inline]
106    pub const fn activation(self) -> Option<ActivationKind> {
107        match self {
108            Self::Identity | Self::Bias => None,
109            Self::BiasRelu => Some(ActivationKind::Relu),
110            Self::BiasGelu => Some(ActivationKind::Gelu),
111            Self::BiasSilu => Some(ActivationKind::Silu),
112        }
113    }
114}
115
/// Activation functions applied by the `Bias<Activation>` variants of
/// [`EpilogueKind`] (`BiasRelu`, `BiasGelu`, `BiasSilu`).
///
/// Exposed for telemetry and selector logic only; which kernel runs is
/// decided by the [`EpilogueKind`] variant itself.
///
/// **Deliberately exhaustive (no `#[non_exhaustive]`)** — it mirrors
/// [`EpilogueKind`], which is exhaustive as well. A new activation has to
/// ship together with a matching `Bias<Activation>` epilogue kernel, which
/// is an intentional breaking-change event.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum ActivationKind {
    /// Rectified linear unit: `relu(x) = max(x, 0)`.
    Relu,
    /// Gaussian Error Linear Unit in its exact (erf-based) form — the same
    /// as PyTorch's default `nn.GELU()`.
    Gelu,
    /// Sigmoid-weighted linear unit, `silu(x) = x · sigmoid(x)`; also
    /// called Swish-1.
    Silu,
}