pub enum MaskKind {
None,
Causal,
SlidingWindow(usize),
Custom,
Bias,
}
What kind of attention mask the kernel should apply.
Borrowed from MAX’s nn/attention/mha_mask.mojo pattern (#20 in
PLAN.md): one attention kernel handles all variants by branching on
the mask kind, instead of forcing every caller to materialize a mask
tensor. The win is two-fold:
- None (single unpadded sequence): no mask load, no per-key compare in the inner loop.
- Causal (autoregressive decode): the kernel generates the upper-triangular fill from (qi, ki) directly; no seq² mask tensor ever exists.
Custom is the existing path — read mask values from the 4th input.
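A minimal sketch of the branching idea, not the crate's kernel: for the position-based variants, the mask reduces to a predicate on (qi, ki), so no mask tensor has to be built. PosMask, attends, and render are hypothetical names introduced for illustration. Clamping the window's lower bound at 0 with saturating_sub is an assumption about how positions with qi < w are handled.

```rust
/// Hypothetical local stand-in for the position-based mask variants,
/// declared here only so the sketch compiles on its own.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PosMask {
    None,
    Causal,
    SlidingWindow(usize),
}

/// Returns true when query position `qi` may attend to key position `ki`.
pub fn attends(mask: PosMask, qi: usize, ki: usize) -> bool {
    match mask {
        // No masking: every pair is allowed, so nothing is loaded or compared.
        PosMask::None => true,
        // Causal: only keys at or before the query position.
        PosMask::Causal => ki <= qi,
        // Window [qi - w, qi]; saturating_sub avoids usize underflow
        // when qi < w (assumed clamping behavior).
        PosMask::SlidingWindow(w) => ki <= qi && ki >= qi.saturating_sub(w),
    }
}

/// Renders the implied mask as rows of '1' (attend) and '0' (masked).
/// Only for inspection; a kernel would call `attends` inline instead.
pub fn render(mask: PosMask, seq: usize) -> Vec<String> {
    (0..seq)
        .map(|qi| {
            (0..seq)
                .map(|ki| if attends(mask, qi, ki) { '1' } else { '0' })
                .collect()
        })
        .collect()
}
```

Calling render(PosMask::Causal, 3) yields the rows "100", "110", "111": the lower triangle is kept and the upper triangle is masked, without a seq² tensor ever being allocated by the caller.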
Variants§
None
No masking — every position attends to every position.
Causal
Causal (autoregressive) — position qi attends only to ki <= qi.
SlidingWindow(usize)
Sliding window — position qi attends to ki ∈ [qi - w, qi].
Custom
Read mask values from the input tensor (default; matches BERT
padding-mask behavior). Tensor shape [batch, key_len] with
1.0 = valid, <0.5 = ignored.
Bias
Additive per-head, per-query bias tensor
[batch, num_heads, query_len, key_len] added to the
QK^T · scale scores before softmax. Lets DETR-style boxRPB
and other learned position biases reuse the fast Op::Attention
path instead of decomposing into matmul + add + softmax + matmul.
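The two tensor-backed variants can be illustrated on a single row of scores. This is a sketch under stated assumptions, not the kernel's actual code: ignored keys are set to negative infinity before softmax, the bias tensor is assumed row-major and contiguous, and all function names are hypothetical.

```rust
/// Custom-style padding mask: `mask_row` is the [key_len] slice for one
/// batch entry. Values below 0.5 mark ignored keys.
/// Assumption: ignored keys are set to -inf so softmax gives them weight 0.
pub fn apply_padding_mask(scores: &mut [f32], mask_row: &[f32]) {
    for (s, &m) in scores.iter_mut().zip(mask_row) {
        if m < 0.5 {
            *s = f32::NEG_INFINITY;
        }
    }
}

/// Flat offset of the [key_len] bias row for (b, h, qi), assuming a
/// row-major contiguous [batch, num_heads, query_len, key_len] tensor.
pub fn bias_row_offset(
    b: usize,
    h: usize,
    qi: usize,
    num_heads: usize,
    query_len: usize,
    key_len: usize,
) -> usize {
    ((b * num_heads + h) * query_len + qi) * key_len
}

/// Bias-style mask: add the bias row to the already scaled QK^T scores.
pub fn add_bias(scores: &mut [f32], bias_row: &[f32]) {
    for (s, &b) in scores.iter_mut().zip(bias_row) {
        *s += b;
    }
}

/// Numerically stable softmax over one row.
/// Caveat: a row where every key is masked produces NaN in this sketch.
pub fn softmax(scores: &mut [f32]) {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for s in scores.iter_mut() {
        *s = (*s - max).exp();
        sum += *s;
    }
    for s in scores.iter_mut() {
        *s /= sum;
    }
}
```

In both cases the adjustment happens on the scores before softmax, which is why a fused attention path can absorb it instead of running separate matmul, add, and softmax ops.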
Trait Implementations§
impl<'de> Deserialize<'de> for MaskKind
fn deserialize<__D>(
__deserializer: __D,
) -> Result<MaskKind, <__D as Deserializer<'de>>::Error>
where
__D: Deserializer<'de>,
impl Serialize for MaskKind
fn serialize<__S>(
&self,
__serializer: __S,
) -> Result<<__S as Serializer>::Ok, <__S as Serializer>::Error>
where
__S: Serializer,
impl Copy for MaskKind
impl Eq for MaskKind
impl StructuralPartialEq for MaskKind
Auto Trait Implementations§
impl Freeze for MaskKind
impl RefUnwindSafe for MaskKind
impl Send for MaskKind
impl Sync for MaskKind
impl Unpin for MaskKind
impl UnsafeUnpin for MaskKind
impl UnwindSafe for MaskKind
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
T: Clone,
impl<Q, K> Equivalent<K> for Q
fn equivalent(&self, key: &K) -> bool
Compare self to key and return true if they are equal.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.