Skip to main content

baracuda_kernels/attention/
hyper_connection.rs

1//! Manifold-Constrained Hyper-Connections (Phase 43, Tier 1).
2//!
3//! Drop-in replacement for the bare `y = x + sublayer(x)` residual
4//! connection in transformer blocks. Mixes `n` parallel residual
5//! streams (`x_expanded[B, n, C]`) through a small Sinkhorn-Knopp-
6//! normalized `(n × n)` matrix `M`, then adds the post-stream
7//! contribution `H_post[i] * RMSNorm(aggregate)`.
8//!
9//! Backed by the vendored mHC.cu (Andre Slavescu, MIT) — see
10//! `crates/baracuda-kernels-sys/vendor/mhc/` for license and provenance.
11//! Paper: DeepSeek-AI, *Manifold-Constrained Hyper-Connections*,
12//! arXiv:2512.24880.
13//!
14//! ## Tier 1 scope
15//!
16//! - **Static-H FW only**. Dynamic-H FW and the BW pass live behind
17//!   the same vendored `MHCLayer` class and ship in Tier 2.
18//! - **bf16 weights / f32 activations**. The upstream `floatX` is
19//!   hardcoded to `nv_bfloat16`; f16 / f32 paths require additional
20//!   convert kernels in the launcher and ship in Tier 3.
21//! - **n ≤ 32**. Above 32 the upstream switches to a cuBLAS-Lt
22//!   tensor-core mixing kernel that has not yet been validated in
23//!   the C-ABI shim.
24//!
25//! ## Stateful plan
26//!
27//! Unlike most `*Plan` types in this crate, `HyperConnectionPlan`
28//! owns a non-trivial native handle (an `MHCLayer*`) that allocates
29//! ~`B*n*C*sizeof(float)` bytes of device-side scratch on construction.
30//! `Drop` destroys the handle. Reuse the plan across many forward
31//! calls to amortize the alloc cost — that's the whole point of the
32//! stateful design upstream.
33
34use core::marker::PhantomData;
35
36use baracuda_cutlass::{Error, Result};
37use baracuda_driver::Stream;
38use baracuda_kernels_types::{
39    ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
40    OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
41};
42
43/// Descriptor for a `HyperConnectionPlan` (static-H FW).
44#[derive(Copy, Clone, Debug)]
45pub struct HyperConnectionDescriptor {
46    /// Batch size — outer dim of `x_expanded`.
47    pub batch: i32,
48    /// Hidden dim — innermost dim of `x_expanded` / `out` / RMSNorm.
49    pub hidden_dim: i32,
50    /// Number of parallel residual streams (`n` in the paper). Must
51    /// be in `1..=32`.
52    pub n_streams: i32,
53    /// Sinkhorn-Knopp iteration count. Paper uses 20; the kernel
54    /// rejects anything outside `1..=1000`.
55    pub sinkhorn_iters: i32,
56    /// Epsilon added to the RMSNorm denominator and used as the
57    /// Sinkhorn divide-by-zero guard. Paper uses `1e-5`.
58    pub eps: f32,
59    /// Element type for `x_expanded` / `out` (f32 in Tier 1).
60    pub element: ElementKind,
61}
62
63/// Args bundle for a `HyperConnectionPlan` launch.
64pub struct HyperConnectionArgs<'a, T: Element> {
65    /// Residual-stream input — `[B, n, C]` row-major contiguous.
66    pub x_expanded: TensorRef<'a, T, 3>,
67    /// RMSNorm gamma — `[C]` bf16. **Always bf16 regardless of `T`**
68    /// (matches upstream `floatX` typedef).
69    pub rmsnorm_weight: TensorRef<'a, half::bf16, 1>,
70    /// Pre-mixing logits — `[n]` f32. The kernel passes them through
71    /// sigmoid internally.
72    pub h_pre: TensorRef<'a, f32, 1>,
73    /// Post-mixing logits — `[n]` f32. The kernel passes them
74    /// through `2 * sigmoid(.)` internally.
75    pub h_post: TensorRef<'a, f32, 1>,
76    /// Pre-Sinkhorn residual mixing matrix — `[n, n]` f32. The
77    /// kernel passes it through Sinkhorn-Knopp iteration to project
78    /// onto the doubly-stochastic manifold before mixing.
79    pub h_res: TensorRef<'a, f32, 2>,
80    /// Output — `[B, n, C]` row-major contiguous, same dtype as
81    /// input.
82    pub out: TensorMut<'a, T, 3>,
83}
84
85/// Hyper-Connection forward plan (static-H, bf16 weights, Tier 1).
86///
87/// **Formula** (with `M = Sinkhorn-Knopp(softmax_or_exp(H_res))`,
88/// `s_pre = sigmoid(H_pre)`, `s_post = 2 * sigmoid(H_post)`,
89/// `y_agg[b, c] = Σ_i s_pre[i] * x_expanded[b, i, c]`,
90/// `y_norm = RMSNorm(y_agg)`):
91///
92/// `out[b, i, c] = Σ_j M[i, j] * x_expanded[b, j, c] + s_post[i] * y_norm[b, c]`
93///
94/// **When to use**: replace the bare `x + sublayer(x)` residual in a
95/// transformer block when training a fresh model — mHC reports
96/// improved training stability + downstream task scores in
97/// DeepSeek-AI's experiments.
98///
99/// **Dtypes**: `f32` only in Tier 1. The `rmsnorm_weight` is always
100/// `bf16` regardless of `T`.
101///
102/// **State**: this plan owns a native `MHCLayer*` handle with
103/// ~`B*n*C*sizeof(float)` bytes of GPU scratch. Reuse across many
104/// `run()` calls; construction is heavy.
105pub struct HyperConnectionPlan<T: Element> {
106    desc: HyperConnectionDescriptor,
107    sku: KernelSku,
108    #[cfg(feature = "mhc")]
109    handle: *mut c_void,
110    _marker: PhantomData<T>,
111}
112
113// The handle wraps device memory owned by the current process; safe
114// to send between threads as long as construction / Drop happen on
115// the same one. We don't expose any &mut-on-shared-handle API, so
116// the contract is purely "owner thread + caller stream".
117unsafe impl<T: Element> Send for HyperConnectionPlan<T> {}
118unsafe impl<T: Element> Sync for HyperConnectionPlan<T> {}
119
120impl<T: Element> HyperConnectionPlan<T> {
121    /// Construct a plan for the given descriptor. Allocates the
122    /// internal `MHCLayer` scratch on the current CUDA context.
123    /// Returns `Err(Error::Unsupported)` if the `mhc` feature is off
124    /// or the descriptor is outside the Tier-1 SKU matrix.
125    pub fn select(
126        _stream: &Stream,
127        desc: &HyperConnectionDescriptor,
128        _pref: PlanPreference,
129    ) -> Result<Self> {
130        if desc.element != T::KIND {
131            return Err(Error::Unsupported(
132                "baracuda-kernels::HyperConnectionPlan: descriptor element != T",
133            ));
134        }
135        // Tier 1: f32 only.
136        if !matches!(T::KIND, ElementKind::F32) {
137            return Err(Error::Unsupported(
138                "baracuda-kernels::HyperConnectionPlan: Tier 1 wired today: `{f32}` only \
139                 (f16 / bf16 deferred to Tier 3)",
140            ));
141        }
142        if desc.batch <= 0 || desc.hidden_dim <= 0 || desc.n_streams <= 0 {
143            return Err(Error::InvalidProblem(
144                "baracuda-kernels::HyperConnectionPlan: batch / hidden_dim / n_streams must be positive",
145            ));
146        }
147        if desc.n_streams >= 32 {
148            return Err(Error::Unsupported(
149                "baracuda-kernels::HyperConnectionPlan: n_streams >= 32 not yet supported \
150                 (would activate the cuBLAS-Lt tensor-core mix path)",
151            ));
152        }
153        if desc.hidden_dim < desc.n_streams {
154            return Err(Error::InvalidProblem(
155                "baracuda-kernels::HyperConnectionPlan: hidden_dim < n_streams (need at least \
156                 one channel per stream for the aggregate kernel)",
157            ));
158        }
159        if desc.sinkhorn_iters <= 0 || desc.sinkhorn_iters > 1000 {
160            return Err(Error::InvalidProblem(
161                "baracuda-kernels::HyperConnectionPlan: sinkhorn_iters must be in 1..=1000",
162            ));
163        }
164        if !(desc.eps.is_finite() && desc.eps > 0.0 && desc.eps < 1.0) {
165            return Err(Error::InvalidProblem(
166                "baracuda-kernels::HyperConnectionPlan: eps must be finite and in (0, 1)",
167            ));
168        }
169
170        let precision_guarantee = PrecisionGuarantee {
171            math_precision: MathPrecision::F32,
172            accumulator: ElementKind::F32,
173            // Sinkhorn-Knopp + stream-mix are deterministic per launch;
174            // upstream kernels avoid atomicAdd on the FW path. Two
175            // back-to-back launches at the same shape produce bit-equal
176            // outputs on the same hardware.
177            bit_stable_on_same_hardware: true,
178            deterministic: true,
179        };
180        let sku = KernelSku {
181            category: OpCategory::Attention,
182            op: AttentionKind::HyperConnection as u16,
183            element: T::KIND,
184            aux_element: Some(ElementKind::Bf16), // rmsnorm_weight dtype
185            layout: None,
186            epilogue: None,
187            arch: ArchSku::Sm80,
188            backend: BackendKind::Bespoke,
189            precision_guarantee,
190        };
191
192        #[cfg(feature = "mhc")]
193        {
194            // Pre-flight C-side validation (mirrors create's range
195            // checks). Lets us return InvalidArg / Unsupported without
196            // attempting an alloc that would just fail.
197            let probe = unsafe {
198                baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_can_implement(
199                    desc.batch,
200                    desc.hidden_dim,
201                    desc.n_streams,
202                )
203            };
204            super::map_status(probe)?;
205
206            let handle = unsafe {
207                baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_create(
208                    desc.batch,
209                    desc.hidden_dim,
210                    desc.n_streams,
211                    desc.sinkhorn_iters,
212                    desc.eps,
213                )
214            };
215            if handle.is_null() {
216                return Err(Error::Unsupported(
217                    "baracuda-kernels::HyperConnectionPlan: native MHCLayer init failed",
218                ));
219            }
220            Ok(Self {
221                desc: *desc,
222                sku,
223                handle,
224                _marker: PhantomData,
225            })
226        }
227        #[cfg(not(feature = "mhc"))]
228        {
229            let _ = sku; // silence unused warning when feature is off
230            Err(Error::Unsupported(
231                "baracuda-kernels::HyperConnectionPlan: build with the `mhc` cargo feature",
232            ))
233        }
234    }
235
236    /// Validate args against the descriptor.
237    pub fn can_implement(&self, args: &HyperConnectionArgs<'_, T>) -> Result<()> {
238        let b = self.desc.batch;
239        let n = self.desc.n_streams;
240        let c = self.desc.hidden_dim;
241
242        if args.x_expanded.shape != [b, n, c] {
243            return Err(Error::InvalidProblem(
244                "baracuda-kernels::HyperConnectionPlan: x_expanded shape mismatch with [B, n, C]",
245            ));
246        }
247        if args.out.shape != [b, n, c] {
248            return Err(Error::InvalidProblem(
249                "baracuda-kernels::HyperConnectionPlan: out shape mismatch with [B, n, C]",
250            ));
251        }
252        if args.rmsnorm_weight.shape != [c] {
253            return Err(Error::InvalidProblem(
254                "baracuda-kernels::HyperConnectionPlan: rmsnorm_weight shape mismatch with [C]",
255            ));
256        }
257        if args.h_pre.shape != [n] {
258            return Err(Error::InvalidProblem(
259                "baracuda-kernels::HyperConnectionPlan: h_pre shape mismatch with [n]",
260            ));
261        }
262        if args.h_post.shape != [n] {
263            return Err(Error::InvalidProblem(
264                "baracuda-kernels::HyperConnectionPlan: h_post shape mismatch with [n]",
265            ));
266        }
267        if args.h_res.shape != [n, n] {
268            return Err(Error::InvalidProblem(
269                "baracuda-kernels::HyperConnectionPlan: h_res shape mismatch with [n, n]",
270            ));
271        }
272
273        // The upstream kernels assume row-major contiguous layout for
274        // x_expanded / out. We reject non-contiguous strides — adding
275        // strided support is a kernel rewrite.
276        let bnc = (b as i64) * (n as i64) * (c as i64);
277        if (args.x_expanded.data.len() as i64) < bnc {
278            return Err(Error::BufferTooSmall {
279                needed: bnc as usize,
280                got: args.x_expanded.data.len(),
281            });
282        }
283        if (args.out.data.len() as i64) < bnc {
284            return Err(Error::BufferTooSmall {
285                needed: bnc as usize,
286                got: args.out.data.len(),
287            });
288        }
289        Ok(())
290    }
291
292    /// Workspace size in bytes. Always zero — internal scratch lives
293    /// in the native handle (allocated at `select` time).
294    #[inline]
295    pub fn workspace_size(&self) -> usize {
296        0
297    }
298
299    /// Identity of the kernel SKU this plan dispatches to.
300    #[inline]
301    pub fn sku(&self) -> KernelSku {
302        self.sku
303    }
304
305    /// Numerical guarantees — deterministic, bit-stable on the same
306    /// hardware (no atomicAdd on the FW path).
307    #[inline]
308    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
309        self.sku.precision_guarantee
310    }
311
312    /// Launch the kernel against `args`.
313    pub fn run(
314        &self,
315        stream: &Stream,
316        _workspace: Workspace<'_>,
317        args: HyperConnectionArgs<'_, T>,
318    ) -> Result<()> {
319        self.can_implement(&args)?;
320        #[cfg(feature = "mhc")]
321        {
322            let stream_ptr = stream.as_raw() as *mut c_void;
323            let status = unsafe {
324                baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_run(
325                    self.handle,
326                    args.x_expanded.data.as_raw().0 as *const c_void,
327                    args.rmsnorm_weight.data.as_raw().0 as *const c_void,
328                    args.h_pre.data.as_raw().0 as *const c_void,
329                    args.h_post.data.as_raw().0 as *const c_void,
330                    args.h_res.data.as_raw().0 as *const c_void,
331                    args.out.data.as_raw().0 as *mut c_void,
332                    self.desc.batch,
333                    self.desc.hidden_dim,
334                    self.desc.n_streams,
335                    core::ptr::null_mut(),
336                    0,
337                    stream_ptr,
338                )
339            };
340            super::map_status(status)
341        }
342        #[cfg(not(feature = "mhc"))]
343        {
344            let _ = stream;
345            Err(Error::Unsupported(
346                "baracuda-kernels::HyperConnectionPlan: build with the `mhc` cargo feature",
347            ))
348        }
349    }
350}
351
352impl<T: Element> Drop for HyperConnectionPlan<T> {
353    fn drop(&mut self) {
354        #[cfg(feature = "mhc")]
355        {
356            if !self.handle.is_null() {
357                unsafe {
358                    baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_destroy(
359                        self.handle,
360                    );
361                }
362                self.handle = core::ptr::null_mut();
363            }
364        }
365    }
366}