baracuda_kernels/attention/hyper_connection.rs
1//! Manifold-Constrained Hyper-Connections (Phase 43, Tier 1).
2//!
3//! Drop-in replacement for the bare `y = x + sublayer(x)` residual
4//! connection in transformer blocks. Mixes `n` parallel residual
5//! streams (`x_expanded[B, n, C]`) through a small Sinkhorn-Knopp-
6//! normalized `(n × n)` matrix `M`, then adds the post-stream
7//! contribution `H_post[i] * RMSNorm(aggregate)`.
8//!
9//! Backed by the vendored mHC.cu (Andre Slavescu, MIT) — see
10//! `crates/baracuda-kernels-sys/vendor/mhc/` for license and provenance.
11//! Paper: DeepSeek-AI, *Manifold-Constrained Hyper-Connections*,
12//! arXiv:2512.24880.
13//!
14//! ## Tier 1 scope
15//!
16//! - **Static-H FW only**. Dynamic-H FW and the BW pass live behind
17//! the same vendored `MHCLayer` class and ship in Tier 2.
18//! - **bf16 weights / f32 activations**. The upstream `floatX` is
19//! hardcoded to `nv_bfloat16`; f16 / f32 paths require additional
20//! convert kernels in the launcher and ship in Tier 3.
21//! - **n ≤ 32**. Above 32 the upstream switches to a cuBLAS-Lt
22//! tensor-core mixing kernel that has not yet been validated in
23//! the C-ABI shim.
24//!
25//! ## Stateful plan
26//!
27//! Unlike most `*Plan` types in this crate, `HyperConnectionPlan`
28//! owns a non-trivial native handle (an `MHCLayer*`) that allocates
29//! ~`B*n*C*sizeof(float)` bytes of device-side scratch on construction.
30//! `Drop` destroys the handle. Reuse the plan across many forward
31//! calls to amortize the alloc cost — that's the whole point of the
32//! stateful design upstream.
33
34use core::marker::PhantomData;
35
36use baracuda_cutlass::{Error, Result};
37use baracuda_driver::Stream;
38use baracuda_kernels_types::{
39 ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
40 OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
41};
42
/// Descriptor for a `HyperConnectionPlan` (static-H FW).
///
/// Describes the problem shape and numerical knobs a plan is built for.
/// `HyperConnectionPlan::select` validates every field and rejects
/// out-of-range values before any native allocation is attempted.
#[derive(Copy, Clone, Debug)]
pub struct HyperConnectionDescriptor {
    /// Batch size — outer dim of `x_expanded`. Must be positive.
    pub batch: i32,
    /// Hidden dim — innermost dim of `x_expanded` / `out` / RMSNorm.
    /// Must be positive and at least `n_streams`.
    pub hidden_dim: i32,
    /// Number of parallel residual streams (`n` in the paper). Must
    /// be in `1..=32`.
    pub n_streams: i32,
    /// Sinkhorn-Knopp iteration count. Paper uses 20; values outside
    /// `1..=1000` are rejected.
    pub sinkhorn_iters: i32,
    /// Epsilon added to the RMSNorm denominator and used as the
    /// Sinkhorn divide-by-zero guard. Paper uses `1e-5`. Must be
    /// finite and strictly inside `(0, 1)`.
    pub eps: f32,
    /// Element type for `x_expanded` / `out` (f32 in Tier 1). Must
    /// match the plan's type parameter `T`.
    pub element: ElementKind,
}
62
/// Args bundle for a `HyperConnectionPlan` launch.
///
/// Below, `B`, `n` and `C` are the `batch`, `n_streams` and
/// `hidden_dim` of the descriptor the plan was built from;
/// `can_implement` compares every tensor shape against them.
pub struct HyperConnectionArgs<'a, T: Element> {
    /// Residual-stream input — `[B, n, C]` row-major contiguous.
    /// The backing buffer must hold at least `B * n * C` elements.
    pub x_expanded: TensorRef<'a, T, 3>,
    /// RMSNorm gamma — `[C]` bf16. **Always bf16 regardless of `T`**
    /// (matches upstream `floatX` typedef).
    pub rmsnorm_weight: TensorRef<'a, half::bf16, 1>,
    /// Pre-mixing logits — `[n]` f32. The kernel passes them through
    /// sigmoid internally.
    pub h_pre: TensorRef<'a, f32, 1>,
    /// Post-mixing logits — `[n]` f32. The kernel passes them
    /// through `2 * sigmoid(.)` internally.
    pub h_post: TensorRef<'a, f32, 1>,
    /// Pre-Sinkhorn residual mixing matrix — `[n, n]` f32. The
    /// kernel passes it through Sinkhorn-Knopp iteration to project
    /// onto the doubly-stochastic manifold before mixing.
    pub h_res: TensorRef<'a, f32, 2>,
    /// Output — `[B, n, C]` row-major contiguous, same dtype as
    /// input. The backing buffer must hold at least `B * n * C`
    /// elements.
    pub out: TensorMut<'a, T, 3>,
}
84
85/// Hyper-Connection forward plan (static-H, bf16 weights, Tier 1).
86///
87/// **Formula** (with `M = Sinkhorn-Knopp(softmax_or_exp(H_res))`,
88/// `s_pre = sigmoid(H_pre)`, `s_post = 2 * sigmoid(H_post)`,
89/// `y_agg[b, c] = Σ_i s_pre[i] * x_expanded[b, i, c]`,
90/// `y_norm = RMSNorm(y_agg)`):
91///
92/// `out[b, i, c] = Σ_j M[i, j] * x_expanded[b, j, c] + s_post[i] * y_norm[b, c]`
93///
94/// **When to use**: replace the bare `x + sublayer(x)` residual in a
95/// transformer block when training a fresh model — mHC reports
96/// improved training stability + downstream task scores in
97/// DeepSeek-AI's experiments.
98///
99/// **Dtypes**: `f32` only in Tier 1. The `rmsnorm_weight` is always
100/// `bf16` regardless of `T`.
101///
102/// **State**: this plan owns a native `MHCLayer*` handle with
103/// ~`B*n*C*sizeof(float)` bytes of GPU scratch. Reuse across many
104/// `run()` calls; construction is heavy.
105pub struct HyperConnectionPlan<T: Element> {
106 desc: HyperConnectionDescriptor,
107 sku: KernelSku,
108 #[cfg(feature = "mhc")]
109 handle: *mut c_void,
110 _marker: PhantomData<T>,
111}
112
// The raw native handle makes the plan `!Send` / `!Sync` by default;
// these impls opt back in.
//
// Send: the handle wraps device memory owned by the current process.
// The original contract is "owner thread + caller stream": moving the
// plan is fine as long as construction / Drop happen on the same
// thread.
//
// Sync: no `&mut`-on-shared-handle API is exposed from Rust.
// NOTE(review): `run(&self)` still passes the same native handle to the
// launcher on every call; if the native layer writes its internal
// scratch per launch, concurrent `run` calls from several threads would
// race on it — confirm the native side tolerates this, or keep launches
// on a single thread / stream.
unsafe impl<T: Element> Send for HyperConnectionPlan<T> {}
unsafe impl<T: Element> Sync for HyperConnectionPlan<T> {}
119
120impl<T: Element> HyperConnectionPlan<T> {
121 /// Construct a plan for the given descriptor. Allocates the
122 /// internal `MHCLayer` scratch on the current CUDA context.
123 /// Returns `Err(Error::Unsupported)` if the `mhc` feature is off
124 /// or the descriptor is outside the Tier-1 SKU matrix.
125 pub fn select(
126 _stream: &Stream,
127 desc: &HyperConnectionDescriptor,
128 _pref: PlanPreference,
129 ) -> Result<Self> {
130 if desc.element != T::KIND {
131 return Err(Error::Unsupported(
132 "baracuda-kernels::HyperConnectionPlan: descriptor element != T",
133 ));
134 }
135 // Tier 1: f32 only.
136 if !matches!(T::KIND, ElementKind::F32) {
137 return Err(Error::Unsupported(
138 "baracuda-kernels::HyperConnectionPlan: Tier 1 wired today: `{f32}` only \
139 (f16 / bf16 deferred to Tier 3)",
140 ));
141 }
142 if desc.batch <= 0 || desc.hidden_dim <= 0 || desc.n_streams <= 0 {
143 return Err(Error::InvalidProblem(
144 "baracuda-kernels::HyperConnectionPlan: batch / hidden_dim / n_streams must be positive",
145 ));
146 }
147 if desc.n_streams >= 32 {
148 return Err(Error::Unsupported(
149 "baracuda-kernels::HyperConnectionPlan: n_streams >= 32 not yet supported \
150 (would activate the cuBLAS-Lt tensor-core mix path)",
151 ));
152 }
153 if desc.hidden_dim < desc.n_streams {
154 return Err(Error::InvalidProblem(
155 "baracuda-kernels::HyperConnectionPlan: hidden_dim < n_streams (need at least \
156 one channel per stream for the aggregate kernel)",
157 ));
158 }
159 if desc.sinkhorn_iters <= 0 || desc.sinkhorn_iters > 1000 {
160 return Err(Error::InvalidProblem(
161 "baracuda-kernels::HyperConnectionPlan: sinkhorn_iters must be in 1..=1000",
162 ));
163 }
164 if !(desc.eps.is_finite() && desc.eps > 0.0 && desc.eps < 1.0) {
165 return Err(Error::InvalidProblem(
166 "baracuda-kernels::HyperConnectionPlan: eps must be finite and in (0, 1)",
167 ));
168 }
169
170 let precision_guarantee = PrecisionGuarantee {
171 math_precision: MathPrecision::F32,
172 accumulator: ElementKind::F32,
173 // Sinkhorn-Knopp + stream-mix are deterministic per launch;
174 // upstream kernels avoid atomicAdd on the FW path. Two
175 // back-to-back launches at the same shape produce bit-equal
176 // outputs on the same hardware.
177 bit_stable_on_same_hardware: true,
178 deterministic: true,
179 };
180 let sku = KernelSku {
181 category: OpCategory::Attention,
182 op: AttentionKind::HyperConnection as u16,
183 element: T::KIND,
184 aux_element: Some(ElementKind::Bf16), // rmsnorm_weight dtype
185 layout: None,
186 epilogue: None,
187 arch: ArchSku::Sm80,
188 backend: BackendKind::Bespoke,
189 precision_guarantee,
190 };
191
192 #[cfg(feature = "mhc")]
193 {
194 // Pre-flight C-side validation (mirrors create's range
195 // checks). Lets us return InvalidArg / Unsupported without
196 // attempting an alloc that would just fail.
197 let probe = unsafe {
198 baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_can_implement(
199 desc.batch,
200 desc.hidden_dim,
201 desc.n_streams,
202 )
203 };
204 super::map_status(probe)?;
205
206 let handle = unsafe {
207 baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_create(
208 desc.batch,
209 desc.hidden_dim,
210 desc.n_streams,
211 desc.sinkhorn_iters,
212 desc.eps,
213 )
214 };
215 if handle.is_null() {
216 return Err(Error::Unsupported(
217 "baracuda-kernels::HyperConnectionPlan: native MHCLayer init failed",
218 ));
219 }
220 Ok(Self {
221 desc: *desc,
222 sku,
223 handle,
224 _marker: PhantomData,
225 })
226 }
227 #[cfg(not(feature = "mhc"))]
228 {
229 let _ = sku; // silence unused warning when feature is off
230 Err(Error::Unsupported(
231 "baracuda-kernels::HyperConnectionPlan: build with the `mhc` cargo feature",
232 ))
233 }
234 }
235
236 /// Validate args against the descriptor.
237 pub fn can_implement(&self, args: &HyperConnectionArgs<'_, T>) -> Result<()> {
238 let b = self.desc.batch;
239 let n = self.desc.n_streams;
240 let c = self.desc.hidden_dim;
241
242 if args.x_expanded.shape != [b, n, c] {
243 return Err(Error::InvalidProblem(
244 "baracuda-kernels::HyperConnectionPlan: x_expanded shape mismatch with [B, n, C]",
245 ));
246 }
247 if args.out.shape != [b, n, c] {
248 return Err(Error::InvalidProblem(
249 "baracuda-kernels::HyperConnectionPlan: out shape mismatch with [B, n, C]",
250 ));
251 }
252 if args.rmsnorm_weight.shape != [c] {
253 return Err(Error::InvalidProblem(
254 "baracuda-kernels::HyperConnectionPlan: rmsnorm_weight shape mismatch with [C]",
255 ));
256 }
257 if args.h_pre.shape != [n] {
258 return Err(Error::InvalidProblem(
259 "baracuda-kernels::HyperConnectionPlan: h_pre shape mismatch with [n]",
260 ));
261 }
262 if args.h_post.shape != [n] {
263 return Err(Error::InvalidProblem(
264 "baracuda-kernels::HyperConnectionPlan: h_post shape mismatch with [n]",
265 ));
266 }
267 if args.h_res.shape != [n, n] {
268 return Err(Error::InvalidProblem(
269 "baracuda-kernels::HyperConnectionPlan: h_res shape mismatch with [n, n]",
270 ));
271 }
272
273 // The upstream kernels assume row-major contiguous layout for
274 // x_expanded / out. We reject non-contiguous strides — adding
275 // strided support is a kernel rewrite.
276 let bnc = (b as i64) * (n as i64) * (c as i64);
277 if (args.x_expanded.data.len() as i64) < bnc {
278 return Err(Error::BufferTooSmall {
279 needed: bnc as usize,
280 got: args.x_expanded.data.len(),
281 });
282 }
283 if (args.out.data.len() as i64) < bnc {
284 return Err(Error::BufferTooSmall {
285 needed: bnc as usize,
286 got: args.out.data.len(),
287 });
288 }
289 Ok(())
290 }
291
292 /// Workspace size in bytes. Always zero — internal scratch lives
293 /// in the native handle (allocated at `select` time).
294 #[inline]
295 pub fn workspace_size(&self) -> usize {
296 0
297 }
298
299 /// Identity of the kernel SKU this plan dispatches to.
300 #[inline]
301 pub fn sku(&self) -> KernelSku {
302 self.sku
303 }
304
305 /// Numerical guarantees — deterministic, bit-stable on the same
306 /// hardware (no atomicAdd on the FW path).
307 #[inline]
308 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
309 self.sku.precision_guarantee
310 }
311
312 /// Launch the kernel against `args`.
313 pub fn run(
314 &self,
315 stream: &Stream,
316 _workspace: Workspace<'_>,
317 args: HyperConnectionArgs<'_, T>,
318 ) -> Result<()> {
319 self.can_implement(&args)?;
320 #[cfg(feature = "mhc")]
321 {
322 let stream_ptr = stream.as_raw() as *mut c_void;
323 let status = unsafe {
324 baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_run(
325 self.handle,
326 args.x_expanded.data.as_raw().0 as *const c_void,
327 args.rmsnorm_weight.data.as_raw().0 as *const c_void,
328 args.h_pre.data.as_raw().0 as *const c_void,
329 args.h_post.data.as_raw().0 as *const c_void,
330 args.h_res.data.as_raw().0 as *const c_void,
331 args.out.data.as_raw().0 as *mut c_void,
332 self.desc.batch,
333 self.desc.hidden_dim,
334 self.desc.n_streams,
335 core::ptr::null_mut(),
336 0,
337 stream_ptr,
338 )
339 };
340 super::map_status(status)
341 }
342 #[cfg(not(feature = "mhc"))]
343 {
344 let _ = stream;
345 Err(Error::Unsupported(
346 "baracuda-kernels::HyperConnectionPlan: build with the `mhc` cargo feature",
347 ))
348 }
349 }
350}
351
352impl<T: Element> Drop for HyperConnectionPlan<T> {
353 fn drop(&mut self) {
354 #[cfg(feature = "mhc")]
355 {
356 if !self.handle.is_null() {
357 unsafe {
358 baracuda_kernels_sys::baracuda_kernels_mhc_layer_static_bf16_destroy(
359 self.handle,
360 );
361 }
362 self.handle = core::ptr::null_mut();
363 }
364 }
365 }
366}