baracuda_kernels/gemm/sparse24.rs
//! 2:4 Structured Sparsity GEMM forward plan (Phase 54).
//!
//! Clean-room hand-port of the facebookresearch/xFormers `sparse24/`
//! algorithmic reference (BSD-3-Clause). See
//! `crates/baracuda-kernels-sys/vendor/xformers/VENDOR.md` for the
//! attribution and cherry-pick scope documentation.
//!
//! ## 2:4 pattern
//!
//! In every group of 4 consecutive weight elements along `K`, at most 2
//! are non-zero. Compressed weight format:
//!
//! | tensor | shape | dtype |
//! |--------|-------|-------|
//! | `W_compressed` | `[M, K/2]` | `T` (the GEMM dtype) |
//! | `W_metadata` | `[M, K/8]` | `u16` (two 4-groups per `u16`; one byte per 4-group) |
//!
//! The dense weight reconstruction is:
//!
//! ```text
//! for m in 0..M:
//!     for k_group in 0..K/4:
//!         meta_byte = ((k_group & 1) == 0) ? metadata[m, k_group/2] & 0xFF
//!                                          : (metadata[m, k_group/2] >> 8) & 0xFF
//!         pos0 = meta_byte & 0x3
//!         pos1 = (meta_byte >> 2) & 0x3
//!         w_dense[m, k_group*4 + pos0] = compressed[m, k_group*2 + 0]
//!         w_dense[m, k_group*4 + pos1] = compressed[m, k_group*2 + 1]
//!         (the other 2 positions in the 4-group are 0)
//! ```
//!
//! Output:
//!
//! ```text
//! Y[N, M] = X[N, K] @ W_dense^T
//! ```
//!
//! (This follows the PyTorch/xFormers convention: the weight is
//! `[out_features, in_features]`, and `X @ W^T` is the canonical
//! Linear-layer dispatch.)
//!
//! ## Tier-1 implementation strategy
//!
//! **Inflate-then-dense-matmul**: launch an inflation kernel that
//! reconstructs `W_dense` in a caller-supplied workspace buffer
//! (`M * K * sizeof(T)` bytes), then run a reference dense GEMM. This
//! is **correctness first**; the sparse-tensor-core (`mma.sp.sync.aligned`)
//! hardware speedup is deferred to Tier 2 alongside cuSPARSELt
//! integration.
//!
//! The Tier-1 path is **not faster than dense cuBLAS**: it is slower
//! at most shapes because the reference matmul is a naive triple-loop
//! kernel (no tensor cores). The API and compression format are the
//! Phase 54 deliverable; performance lands with the Tier-2
//! cuSPARSELt-or-PTX backend.
//!
//! ## Constraints
//!
//! - `K` must be a multiple of 8 (each `u16` metadata word covers two
//!   4-groups, i.e. 8 elements along `K`).
//! - Wired dtypes: `{f32, f16, bf16}`.
//!
//! ## Workspace
//!
//! `M * K * sizeof(T)` bytes for the inflated dense `W`. Query via
//! [`GemmSparse24Plan::workspace_size`].

66use core::marker::PhantomData;
67
68use baracuda_cutlass::{Error, Result};
69use baracuda_driver::Stream;
70use baracuda_kernels_types::{
71 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
72 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
73};
74
75/// Descriptor for a 2:4 sparse GEMM forward op.
76#[derive(Copy, Clone, Debug)]
77pub struct GemmSparse24Descriptor {
78 /// Batch / sequence dimension (`N` rows of the X / Y tensors).
79 pub n: i32,
80 /// Output feature dimension (`M` — number of W rows; cols of Y).
81 pub m: i32,
82 /// Input feature dimension (`K` — cols of X; cols of W). Must be
83 /// a multiple of 8.
84 pub k: i32,
85 /// Element type — must match the plan's type parameter.
86 pub element: ElementKind,
87}
88
89/// Args bundle for a 2:4 sparse GEMM launch.
90pub struct GemmSparse24Args<'a, T: Element> {
91 /// Compressed weight — `[M, K/2]`, row-major contiguous.
92 pub w_compressed: TensorRef<'a, T, 2>,
93 /// Metadata — `[M, K/8]` `u16`, row-major contiguous. Each
94 /// `u16` carries 2 4-groups (one per byte; low byte first).
95 /// Per-byte encoding: bits `[0:1]` = pos0, bits `[2:3]` = pos1.
96 pub w_metadata: TensorRef<'a, u16, 2>,
97 /// Activation — `[N, K]`, row-major contiguous.
98 pub x: TensorRef<'a, T, 2>,
99 /// Output — `[N, M]`, row-major contiguous.
100 pub y: TensorMut<'a, T, 2>,
101}
102
103/// 2:4 structured-sparsity GEMM forward plan.
104///
105/// **When to use**: post-pruning inference where weights have been
106/// offline-compressed to the 2:4 format (e.g. via xFormers'
107/// `sparse24.compress` utility or NVIDIA's `apex.contrib.sparsity`).
108///
109/// **Dtypes**: `f32`, `f16`, `bf16`.
110///
111/// **Workspace**: `M * K * sizeof(T)` bytes — required for the inflated
112/// dense `W` reconstruction.
113///
114/// **Tier-1 caveat**: the trailblazer inflates-then-dense-matmuls;
115/// performance is NOT competitive with cuBLAS at this stage. The
116/// API + compression format are what Phase 54 delivers; the
117/// sparse-tensor-core speedup lands in Tier 2.
118pub struct GemmSparse24Plan<T: Element> {
119 desc: GemmSparse24Descriptor,
120 sku: KernelSku,
121 _marker: PhantomData<T>,
122}
123
124impl<T: Element> GemmSparse24Plan<T> {
125 /// Pick a kernel.
126 pub fn select(
127 _stream: &Stream,
128 desc: &GemmSparse24Descriptor,
129 _pref: PlanPreference,
130 ) -> Result<Self> {
131 if desc.element != T::KIND {
132 return Err(Error::Unsupported(
133 "baracuda-kernels::GemmSparse24Plan: descriptor element != T",
134 ));
135 }
136 if desc.n < 0 || desc.m < 0 || desc.k < 0 {
137 return Err(Error::InvalidProblem(
138 "baracuda-kernels::GemmSparse24Plan: N, M, K must be non-negative",
139 ));
140 }
141 if (desc.k & 7) != 0 {
142 return Err(Error::Unsupported(
143 "baracuda-kernels::GemmSparse24Plan: K must be a multiple of 8",
144 ));
145 }
146 let dtype_in_scope = matches!(
147 T::KIND,
148 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16
149 );
150 if !dtype_in_scope {
151 return Err(Error::Unsupported(
152 "baracuda-kernels::GemmSparse24Plan: wired today: `{f32, f16, bf16}`",
153 ));
154 }
155
156 // Pre-flight C-side validation.
157 #[cfg(feature = "xformers_sparse24")]
158 {
159 let probe = unsafe {
160 match T::KIND {
161 ElementKind::F32 =>
162 baracuda_kernels_sys::baracuda_kernels_gemm_f32_sparse24_gemm_can_implement(
163 desc.n, desc.m, desc.k,
164 ),
165 ElementKind::F16 =>
166 baracuda_kernels_sys::baracuda_kernels_gemm_f16_sparse24_gemm_can_implement(
167 desc.n, desc.m, desc.k,
168 ),
169 ElementKind::Bf16 =>
170 baracuda_kernels_sys::baracuda_kernels_gemm_bf16_sparse24_gemm_can_implement(
171 desc.n, desc.m, desc.k,
172 ),
173 _ => 3,
174 }
175 };
176 super::super::attention::map_status_pub(probe)?;
177 }
178
179 let precision_guarantee = PrecisionGuarantee {
180 math_precision: MathPrecision::F32,
181 accumulator: ElementKind::F32,
182 // Reference GEMM is deterministic per-cell (no atomicAdd).
183 bit_stable_on_same_hardware: true,
184 deterministic: true,
185 };
186 let sku = KernelSku {
187 category: OpCategory::Gemm,
188 op: 0,
189 element: T::KIND,
190 aux_element: None,
191 layout: None,
192 epilogue: None,
193 arch: ArchSku::Sm80,
194 backend: BackendKind::Bespoke,
195 precision_guarantee,
196 };
197 Ok(Self {
198 desc: *desc,
199 sku,
200 _marker: PhantomData,
201 })
202 }
203
204 /// Validate args against the descriptor.
205 pub fn can_implement(&self, args: &GemmSparse24Args<'_, T>) -> Result<()> {
206 if args.w_compressed.shape != [self.desc.m, self.desc.k / 2] {
207 return Err(Error::InvalidProblem(
208 "baracuda-kernels::GemmSparse24Plan: w_compressed shape must be [M, K/2]",
209 ));
210 }
211 if args.w_metadata.shape != [self.desc.m, self.desc.k / 8] {
212 return Err(Error::InvalidProblem(
213 "baracuda-kernels::GemmSparse24Plan: w_metadata shape must be [M, K/8]",
214 ));
215 }
216 if args.x.shape != [self.desc.n, self.desc.k] {
217 return Err(Error::InvalidProblem(
218 "baracuda-kernels::GemmSparse24Plan: x shape must be [N, K]",
219 ));
220 }
221 if args.y.shape != [self.desc.n, self.desc.m] {
222 return Err(Error::InvalidProblem(
223 "baracuda-kernels::GemmSparse24Plan: y shape must be [N, M]",
224 ));
225 }
226 if !args.w_compressed.is_contiguous()
227 || !args.w_metadata.is_contiguous()
228 || !args.x.is_contiguous()
229 || !args.y.is_contiguous()
230 {
231 return Err(Error::Unsupported(
232 "baracuda-kernels::GemmSparse24Plan: all tensors must be contiguous in Tier 1",
233 ));
234 }
235 Ok(())
236 }
237
238 /// Workspace size in bytes — `M * K * sizeof(T)` for the inflated
239 /// dense W.
240 #[inline]
241 pub fn workspace_size(&self) -> usize {
242 (self.desc.m as usize)
243 * (self.desc.k as usize)
244 * core::mem::size_of::<T>()
245 }
246
247 /// SKU identity.
248 #[inline]
249 pub fn sku(&self) -> KernelSku {
250 self.sku
251 }
252
253 /// Numerical guarantees.
254 #[inline]
255 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
256 self.sku.precision_guarantee
257 }
258
259 /// Run the inflate-then-matmul reference path.
260 pub fn run(
261 &self,
262 stream: &Stream,
263 workspace: Workspace<'_>,
264 args: GemmSparse24Args<'_, T>,
265 ) -> Result<()> {
266 self.can_implement(&args)?;
267 if args.y.numel() == 0 {
268 return Ok(());
269 }
270 #[cfg(feature = "xformers_sparse24")]
271 {
272 let needed = self.workspace_size();
273 let (ws_ptr, ws_bytes) = match workspace {
274 Workspace::None => {
275 return Err(Error::WorkspaceTooSmall {
276 needed,
277 got: 0,
278 });
279 }
280 Workspace::Borrowed(bytes) => {
281 let got = bytes.len();
282 if got < needed {
283 return Err(Error::WorkspaceTooSmall {
284 needed,
285 got,
286 });
287 }
288 (bytes.as_raw().0 as *mut c_void, got as u64)
289 }
290 };
291 let stream_ptr = stream.as_raw() as *mut c_void;
292 let x_ptr = args.x.data.as_raw().0 as *const c_void;
293 let wc_ptr = args.w_compressed.data.as_raw().0 as *const c_void;
294 let wm_ptr = args.w_metadata.data.as_raw().0 as *const c_void;
295 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
296 let status = unsafe {
297 match T::KIND {
298 ElementKind::F32 =>
299 baracuda_kernels_sys::baracuda_kernels_gemm_f32_sparse24_gemm_run(
300 self.desc.n, self.desc.m, self.desc.k,
301 x_ptr, wc_ptr, wm_ptr, y_ptr,
302 ws_ptr, ws_bytes, stream_ptr,
303 ),
304 ElementKind::F16 =>
305 baracuda_kernels_sys::baracuda_kernels_gemm_f16_sparse24_gemm_run(
306 self.desc.n, self.desc.m, self.desc.k,
307 x_ptr, wc_ptr, wm_ptr, y_ptr,
308 ws_ptr, ws_bytes, stream_ptr,
309 ),
310 ElementKind::Bf16 =>
311 baracuda_kernels_sys::baracuda_kernels_gemm_bf16_sparse24_gemm_run(
312 self.desc.n, self.desc.m, self.desc.k,
313 x_ptr, wc_ptr, wm_ptr, y_ptr,
314 ws_ptr, ws_bytes, stream_ptr,
315 ),
316 _ => return Err(Error::Unsupported(
317 "baracuda-kernels::GemmSparse24Plan::run reached an unimplemented dtype",
318 )),
319 }
320 };
321 super::super::attention::map_status_pub(status)
322 }
323 #[cfg(not(feature = "xformers_sparse24"))]
324 {
325 let _ = (stream, workspace);
326 Err(Error::Unsupported(
327 "baracuda-kernels::GemmSparse24Plan: build with the `xformers_sparse24` cargo feature",
328 ))
329 }
330 }
331}