pub struct HyperConnectionArgs<'a, T: Element> {
pub x_expanded: TensorRef<'a, T, 3>,
pub rmsnorm_weight: TensorRef<'a, bf16, 1>,
pub h_pre: TensorRef<'a, f32, 1>,
pub h_post: TensorRef<'a, f32, 1>,
pub h_res: TensorRef<'a, f32, 2>,
pub out: TensorMut<'a, T, 3>,
}
Args bundle for a HyperConnectionPlan launch.
Fields
x_expanded: TensorRef<'a, T, 3>
Residual-stream input: [B, n, C] row-major contiguous.
rmsnorm_weight: TensorRef<'a, bf16, 1>
RMSNorm gamma: [C] bf16. Always bf16 regardless of T (matches upstream floatX typedef).
h_pre: TensorRef<'a, f32, 1>
Pre-mixing logits: [n] f32. The kernel passes them through sigmoid internally.
h_post: TensorRef<'a, f32, 1>
Post-mixing logits: [n] f32. The kernel passes them through 2 * sigmoid(.) internally.
h_res: TensorRef<'a, f32, 2>
Pre-Sinkhorn residual mixing matrix: [n, n] f32. The kernel passes it through Sinkhorn-Knopp iteration to project onto the doubly-stochastic manifold before mixing.
out: TensorMut<'a, T, 3>
Output: [B, n, C] row-major contiguous, same dtype as input.
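The transforms that the field descriptions attribute to the kernel can be sketched on the CPU as a rough reference. This is a minimal sketch, not the kernel's implementation: the function names (sigmoid, pre_gates, post_gates, sinkhorn_knopp) are hypothetical, plain f32 slices stand in for TensorRef, and the Sinkhorn details (exponentiating the entries first, a fixed iteration count, rows normalized before columns) are assumptions that the documentation above does not specify.

```rust
/// Logistic sigmoid, 1 / (1 + e^-x).
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Pre-mixing gates: sigmoid of each logit, so every gate lies in (0, 1).
fn pre_gates(h_pre: &[f32]) -> Vec<f32> {
    h_pre.iter().map(|&x| sigmoid(x)).collect()
}

/// Post-mixing gates: 2 * sigmoid of each logit, so every gate lies in (0, 2).
fn post_gates(h_post: &[f32]) -> Vec<f32> {
    h_post.iter().map(|&x| 2.0 * sigmoid(x)).collect()
}

/// Sinkhorn-Knopp on a row-major n x n matrix.
/// Assumption: entries are made strictly positive with exp() before the
/// alternating row/column normalization; the real kernel may differ.
fn sinkhorn_knopp(h_res: &[f32], n: usize, iters: usize) -> Vec<f32> {
    assert_eq!(h_res.len(), n * n, "h_res must hold n * n entries");
    let mut m: Vec<f32> = h_res.iter().map(|&x| x.exp()).collect();
    for _ in 0..iters {
        // Scale each row so that it sums to 1.
        for r in 0..n {
            let s: f32 = m[r * n..(r + 1) * n].iter().sum();
            for v in &mut m[r * n..(r + 1) * n] {
                *v /= s;
            }
        }
        // Scale each column so that it sums to 1.
        for c in 0..n {
            let s: f32 = (0..n).map(|r| m[r * n + c]).sum();
            for r in 0..n {
                m[r * n + c] /= s;
            }
        }
    }
    m
}
```

With zero logits, pre_gates yields 0.5 and post_gates yields 1.0 for every stream, which is one way to read the factor of 2: a zero post-mixing logit leaves the stream unscaled. After enough iterations, the rows and columns of the sinkhorn_knopp result each sum to approximately 1.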
Auto Trait Implementations
impl<'a, T> !UnwindSafe for HyperConnectionArgs<'a, T>
impl<'a, T> Freeze for HyperConnectionArgs<'a, T>
impl<'a, T> RefUnwindSafe for HyperConnectionArgs<'a, T>
where
    T: RefUnwindSafe,
impl<'a, T> Send for HyperConnectionArgs<'a, T>
impl<'a, T> Sync for HyperConnectionArgs<'a, T>
where
    T: Sync,
impl<'a, T> Unpin for HyperConnectionArgs<'a, T>
impl<'a, T> UnsafeUnpin for HyperConnectionArgs<'a, T>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.