Skip to main content

baracuda_kernels/gemm/
sparse24.rs

1//! 2:4 Structured Sparsity GEMM forward plan (Phase 54).
2//!
//! Clean-room hand-port of the facebookresearch/xFormers `sparse24/`
//! algorithmic reference (BSD-3-Clause). See
//! `crates/baracuda-kernels-sys/vendor/xformers/VENDOR.md` for the
//! attribution and cherry-pick scope documentation.
7//!
8//! ## 2:4 pattern
9//!
10//! In every 4 consecutive weight cells, at most 2 are non-zero.
11//! Compressed weight format:
12//!
13//! | tensor | shape | dtype |
14//! |--------|-------|-------|
15//! | `W_compressed` | `[M, K/2]` | `T` (the GEMM dtype) |
//! | `W_metadata`   | `[M, K/8]` | `u16` (two 4-groups per `u16`, one byte each, low byte first) |
17//!
18//! The dense weight reconstruction is:
19//!
20//! ```text
21//! for m in 0..M:
22//!     for k_group in 0..K/4:
//!         meta_byte = ((k_group & 1) == 0) ? metadata[m, k_group/2] & 0xFF
//!                                          : (metadata[m, k_group/2] >> 8) & 0xFF
25//!         pos0 = meta_byte & 0x3
26//!         pos1 = (meta_byte >> 2) & 0x3
27//!         w_dense[m, k_group*4 + pos0] = compressed[m, k_group*2 + 0]
28//!         w_dense[m, k_group*4 + pos1] = compressed[m, k_group*2 + 1]
29//!         (other 2 positions in the 4-group are 0)
30//! ```
31//!
32//! Output:
33//!
34//! ```text
35//! Y[N, M] = X[N, K] @ W_dense^T
36//! ```
37//!
38//! (Following PyTorch/xFormers convention — weight is `[out_features,
39//! in_features]`; `X @ W^T` is the canonical Linear-layer dispatch.)
40//!
41//! ## Tier-1 implementation strategy
42//!
43//! **Inflate-then-dense-matmul**: launch an inflation kernel that
44//! reconstructs `W_dense` in a caller-supplied workspace buffer
45//! (`M * K * sizeof(T)` bytes), then run a reference dense GEMM. This
46//! is **correctness first**; the sparse-tensor-core (`mma.sp.sync.aligned`)
47//! hardware speedup is deferred to Tier 2 alongside cuSPARSELt
48//! integration.
49//!
50//! The Tier-1 path is **not faster than dense cuBLAS** — it's slower
51//! at most shapes because the reference matmul is a naive triple-loop
52//! kernel (no tensor cores). The API + compression format are the
53//! Phase 54 deliverable; performance lands with the Tier-2
54//! cuSPARSELt-or-PTX backend.
55//!
56//! ## Constraints
57//!
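//! - `N`, `M`, `K` must be non-negative; `K` must be a multiple of 8.
//! - Wired dtypes: `{f32, f16, bf16}`.
//! - All tensors must be row-major contiguous (Tier 1).
//! - Launching requires the `xformers_sparse24` cargo feature.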
60//!
61//! ## Workspace
62//!
63//! `M * K * sizeof(T)` bytes for the inflated dense W. Query via
64//! [`GemmSparse24Plan::workspace_size`].
65
66use core::marker::PhantomData;
67
68use baracuda_cutlass::{Error, Result};
69use baracuda_driver::Stream;
70use baracuda_kernels_types::{
71    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
72    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
73};
74
75/// Descriptor for a 2:4 sparse GEMM forward op.
76#[derive(Copy, Clone, Debug)]
77pub struct GemmSparse24Descriptor {
78    /// Batch / sequence dimension (`N` rows of the X / Y tensors).
79    pub n: i32,
80    /// Output feature dimension (`M` — number of W rows; cols of Y).
81    pub m: i32,
82    /// Input feature dimension (`K` — cols of X; cols of W). Must be
83    /// a multiple of 8.
84    pub k: i32,
85    /// Element type — must match the plan's type parameter.
86    pub element: ElementKind,
87}
88
89/// Args bundle for a 2:4 sparse GEMM launch.
90pub struct GemmSparse24Args<'a, T: Element> {
91    /// Compressed weight — `[M, K/2]`, row-major contiguous.
92    pub w_compressed: TensorRef<'a, T, 2>,
93    /// Metadata — `[M, K/8]` `u16`, row-major contiguous. Each
94    /// `u16` carries 2 4-groups (one per byte; low byte first).
95    /// Per-byte encoding: bits `[0:1]` = pos0, bits `[2:3]` = pos1.
96    pub w_metadata: TensorRef<'a, u16, 2>,
97    /// Activation — `[N, K]`, row-major contiguous.
98    pub x: TensorRef<'a, T, 2>,
99    /// Output — `[N, M]`, row-major contiguous.
100    pub y: TensorMut<'a, T, 2>,
101}
102
103/// 2:4 structured-sparsity GEMM forward plan.
104///
105/// **When to use**: post-pruning inference where weights have been
106/// offline-compressed to the 2:4 format (e.g. via xFormers'
107/// `sparse24.compress` utility or NVIDIA's `apex.contrib.sparsity`).
108///
109/// **Dtypes**: `f32`, `f16`, `bf16`.
110///
111/// **Workspace**: `M * K * sizeof(T)` bytes — required for the inflated
112/// dense `W` reconstruction.
113///
114/// **Tier-1 caveat**: the trailblazer inflates-then-dense-matmuls;
115/// performance is NOT competitive with cuBLAS at this stage. The
116/// API + compression format are what Phase 54 delivers; the
117/// sparse-tensor-core speedup lands in Tier 2.
118pub struct GemmSparse24Plan<T: Element> {
119    desc: GemmSparse24Descriptor,
120    sku: KernelSku,
121    _marker: PhantomData<T>,
122}
123
124impl<T: Element> GemmSparse24Plan<T> {
125    /// Pick a kernel.
126    pub fn select(
127        _stream: &Stream,
128        desc: &GemmSparse24Descriptor,
129        _pref: PlanPreference,
130    ) -> Result<Self> {
131        if desc.element != T::KIND {
132            return Err(Error::Unsupported(
133                "baracuda-kernels::GemmSparse24Plan: descriptor element != T",
134            ));
135        }
136        if desc.n < 0 || desc.m < 0 || desc.k < 0 {
137            return Err(Error::InvalidProblem(
138                "baracuda-kernels::GemmSparse24Plan: N, M, K must be non-negative",
139            ));
140        }
141        if (desc.k & 7) != 0 {
142            return Err(Error::Unsupported(
143                "baracuda-kernels::GemmSparse24Plan: K must be a multiple of 8",
144            ));
145        }
146        let dtype_in_scope = matches!(
147            T::KIND,
148            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16
149        );
150        if !dtype_in_scope {
151            return Err(Error::Unsupported(
152                "baracuda-kernels::GemmSparse24Plan: wired today: `{f32, f16, bf16}`",
153            ));
154        }
155
156        // Pre-flight C-side validation.
157        #[cfg(feature = "xformers_sparse24")]
158        {
159            let probe = unsafe {
160                match T::KIND {
161                    ElementKind::F32 =>
162                        baracuda_kernels_sys::baracuda_kernels_gemm_f32_sparse24_gemm_can_implement(
163                            desc.n, desc.m, desc.k,
164                        ),
165                    ElementKind::F16 =>
166                        baracuda_kernels_sys::baracuda_kernels_gemm_f16_sparse24_gemm_can_implement(
167                            desc.n, desc.m, desc.k,
168                        ),
169                    ElementKind::Bf16 =>
170                        baracuda_kernels_sys::baracuda_kernels_gemm_bf16_sparse24_gemm_can_implement(
171                            desc.n, desc.m, desc.k,
172                        ),
173                    _ => 3,
174                }
175            };
176            super::super::attention::map_status_pub(probe)?;
177        }
178
179        let precision_guarantee = PrecisionGuarantee {
180            math_precision: MathPrecision::F32,
181            accumulator: ElementKind::F32,
182            // Reference GEMM is deterministic per-cell (no atomicAdd).
183            bit_stable_on_same_hardware: true,
184            deterministic: true,
185        };
186        let sku = KernelSku {
187            category: OpCategory::Gemm,
188            op: 0,
189            element: T::KIND,
190            aux_element: None,
191            layout: None,
192            epilogue: None,
193            arch: ArchSku::Sm80,
194            backend: BackendKind::Bespoke,
195            precision_guarantee,
196        };
197        Ok(Self {
198            desc: *desc,
199            sku,
200            _marker: PhantomData,
201        })
202    }
203
204    /// Validate args against the descriptor.
205    pub fn can_implement(&self, args: &GemmSparse24Args<'_, T>) -> Result<()> {
206        if args.w_compressed.shape != [self.desc.m, self.desc.k / 2] {
207            return Err(Error::InvalidProblem(
208                "baracuda-kernels::GemmSparse24Plan: w_compressed shape must be [M, K/2]",
209            ));
210        }
211        if args.w_metadata.shape != [self.desc.m, self.desc.k / 8] {
212            return Err(Error::InvalidProblem(
213                "baracuda-kernels::GemmSparse24Plan: w_metadata shape must be [M, K/8]",
214            ));
215        }
216        if args.x.shape != [self.desc.n, self.desc.k] {
217            return Err(Error::InvalidProblem(
218                "baracuda-kernels::GemmSparse24Plan: x shape must be [N, K]",
219            ));
220        }
221        if args.y.shape != [self.desc.n, self.desc.m] {
222            return Err(Error::InvalidProblem(
223                "baracuda-kernels::GemmSparse24Plan: y shape must be [N, M]",
224            ));
225        }
226        if !args.w_compressed.is_contiguous()
227            || !args.w_metadata.is_contiguous()
228            || !args.x.is_contiguous()
229            || !args.y.is_contiguous()
230        {
231            return Err(Error::Unsupported(
232                "baracuda-kernels::GemmSparse24Plan: all tensors must be contiguous in Tier 1",
233            ));
234        }
235        Ok(())
236    }
237
238    /// Workspace size in bytes — `M * K * sizeof(T)` for the inflated
239    /// dense W.
240    #[inline]
241    pub fn workspace_size(&self) -> usize {
242        (self.desc.m as usize)
243            * (self.desc.k as usize)
244            * core::mem::size_of::<T>()
245    }
246
247    /// SKU identity.
248    #[inline]
249    pub fn sku(&self) -> KernelSku {
250        self.sku
251    }
252
253    /// Numerical guarantees.
254    #[inline]
255    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
256        self.sku.precision_guarantee
257    }
258
259    /// Run the inflate-then-matmul reference path.
260    pub fn run(
261        &self,
262        stream: &Stream,
263        workspace: Workspace<'_>,
264        args: GemmSparse24Args<'_, T>,
265    ) -> Result<()> {
266        self.can_implement(&args)?;
267        if args.y.numel() == 0 {
268            return Ok(());
269        }
270        #[cfg(feature = "xformers_sparse24")]
271        {
272            let needed = self.workspace_size();
273            let (ws_ptr, ws_bytes) = match workspace {
274                Workspace::None => {
275                    return Err(Error::WorkspaceTooSmall {
276                        needed,
277                        got: 0,
278                    });
279                }
280                Workspace::Borrowed(bytes) => {
281                    let got = bytes.len();
282                    if got < needed {
283                        return Err(Error::WorkspaceTooSmall {
284                            needed,
285                            got,
286                        });
287                    }
288                    (bytes.as_raw().0 as *mut c_void, got as u64)
289                }
290            };
291            let stream_ptr = stream.as_raw() as *mut c_void;
292            let x_ptr = args.x.data.as_raw().0 as *const c_void;
293            let wc_ptr = args.w_compressed.data.as_raw().0 as *const c_void;
294            let wm_ptr = args.w_metadata.data.as_raw().0 as *const c_void;
295            let y_ptr = args.y.data.as_raw().0 as *mut c_void;
296            let status = unsafe {
297                match T::KIND {
298                    ElementKind::F32 =>
299                        baracuda_kernels_sys::baracuda_kernels_gemm_f32_sparse24_gemm_run(
300                            self.desc.n, self.desc.m, self.desc.k,
301                            x_ptr, wc_ptr, wm_ptr, y_ptr,
302                            ws_ptr, ws_bytes, stream_ptr,
303                        ),
304                    ElementKind::F16 =>
305                        baracuda_kernels_sys::baracuda_kernels_gemm_f16_sparse24_gemm_run(
306                            self.desc.n, self.desc.m, self.desc.k,
307                            x_ptr, wc_ptr, wm_ptr, y_ptr,
308                            ws_ptr, ws_bytes, stream_ptr,
309                        ),
310                    ElementKind::Bf16 =>
311                        baracuda_kernels_sys::baracuda_kernels_gemm_bf16_sparse24_gemm_run(
312                            self.desc.n, self.desc.m, self.desc.k,
313                            x_ptr, wc_ptr, wm_ptr, y_ptr,
314                            ws_ptr, ws_bytes, stream_ptr,
315                        ),
316                    _ => return Err(Error::Unsupported(
317                        "baracuda-kernels::GemmSparse24Plan::run reached an unimplemented dtype",
318                    )),
319                }
320            };
321            super::super::attention::map_status_pub(status)
322        }
323        #[cfg(not(feature = "xformers_sparse24"))]
324        {
325            let _ = (stream, workspace);
326            Err(Error::Unsupported(
327                "baracuda-kernels::GemmSparse24Plan: build with the `xformers_sparse24` cargo feature",
328            ))
329        }
330    }
331}