use crate::op::MaskKind;
use crate::{Graph, NodeId, Op, Shape};
/// Assembles an [`Op::Attention`] value from its individual settings.
///
/// This is a pure constructor: it touches no graph state. The `Graph`
/// attention builders in this module route through it, so every attention
/// node they create is described the same way.
///
/// * `num_heads`, `head_dim` — stored unchanged in the op.
/// * `mask_kind` — stored unchanged; the `Graph` builders pick the node's
///   input list to match it (an extra mask/bias tensor for `Custom`/`Bias`).
/// * `score_scale`, `attn_logit_softcap` — optional settings stored as given;
///   how `None` is interpreted is up to whatever consumes the op.
pub fn attention_kind_op(
    num_heads: usize,
    head_dim: usize,
    mask_kind: MaskKind,
    score_scale: Option<f32>,
    attn_logit_softcap: Option<f32>,
) -> Op {
    Op::Attention {
        mask_kind,
        num_heads,
        head_dim,
        attn_logit_softcap,
        score_scale,
    }
}
impl Graph {
    /// Adds an attention node whose mask is supplied as a tensor
    /// (`MaskKind::Custom`).
    ///
    /// Shorthand for [`Graph::attention_opts`] with no score scale and no
    /// logit soft-cap. The node's inputs are `[q, k, v, mask]`; `shape` is
    /// handed to `Graph::push` for the new node. Returns the new node's id.
    pub fn attention(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        mask: NodeId,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
    ) -> NodeId {
        self.attention_opts(q, k, v, mask, num_heads, head_dim, shape, None, None)
    }

    /// Adds a tensor-masked (`MaskKind::Custom`) attention node with optional
    /// `score_scale` and `attn_logit_softcap` settings.
    ///
    /// The node's inputs are `[q, k, v, mask]`. Returns the new node's id.
    pub fn attention_opts(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        mask: NodeId,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
        score_scale: Option<f32>,
        attn_logit_softcap: Option<f32>,
    ) -> NodeId {
        self.push(
            attention_kind_op(
                num_heads,
                head_dim,
                MaskKind::Custom,
                score_scale,
                attn_logit_softcap,
            ),
            vec![q, k, v, mask],
            shape,
            None,
        )
    }

    /// Adds an attention node whose mask is described by `mask_kind` alone,
    /// with no mask tensor input.
    ///
    /// Shorthand for [`Graph::attention_kind_opts`] with no score scale and
    /// no logit soft-cap.
    ///
    /// # Panics
    ///
    /// Panics if `mask_kind` is `MaskKind::Custom` or `MaskKind::Bias`. Those
    /// kinds need an extra tensor input; use [`Graph::attention`] or
    /// [`Graph::attention_bias`] instead.
    pub fn attention_kind(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        num_heads: usize,
        head_dim: usize,
        mask_kind: MaskKind,
        shape: Shape,
    ) -> NodeId {
        self.attention_kind_opts(q, k, v, num_heads, head_dim, mask_kind, shape, None, None)
    }

    /// Adds a tensor-free attention node (inputs `[q, k, v]`) with optional
    /// `score_scale` and `attn_logit_softcap` settings.
    ///
    /// # Panics
    ///
    /// Panics if `mask_kind` is `MaskKind::Custom` or `MaskKind::Bias`.
    pub fn attention_kind_opts(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        num_heads: usize,
        head_dim: usize,
        mask_kind: MaskKind,
        shape: Shape,
        score_scale: Option<f32>,
        attn_logit_softcap: Option<f32>,
    ) -> NodeId {
        // A real assert (not debug_assert): in release builds this would
        // otherwise record a mask kind that promises a 4th tensor input the
        // node does not have. The check costs nothing next to graph building.
        assert!(
            !matches!(mask_kind, MaskKind::Custom | MaskKind::Bias),
            "attention_kind() requires a non-tensor MaskKind; use attention() for Custom or attention_bias() for Bias"
        );
        self.push(
            attention_kind_op(
                num_heads,
                head_dim,
                mask_kind,
                score_scale,
                attn_logit_softcap,
            ),
            vec![q, k, v],
            shape,
            None,
        )
    }

    /// Adds an attention node whose mask is an additive-style bias tensor
    /// (`MaskKind::Bias`).
    ///
    /// Shorthand for [`Graph::attention_bias_opts`] with no score scale and no
    /// logit soft-cap. The node's inputs are `[q, k, v, bias]`.
    pub fn attention_bias(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        bias: NodeId,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
    ) -> NodeId {
        self.attention_bias_opts(q, k, v, bias, num_heads, head_dim, shape, None, None)
    }

    /// Adds a bias-masked (`MaskKind::Bias`) attention node with optional
    /// `score_scale` and `attn_logit_softcap` settings.
    ///
    /// This mirrors [`Graph::attention_opts`] and [`Graph::attention_kind_opts`],
    /// so every mask flavour can carry the same optional settings. The node's
    /// inputs are `[q, k, v, bias]`. Returns the new node's id.
    pub fn attention_bias_opts(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        bias: NodeId,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
        score_scale: Option<f32>,
        attn_logit_softcap: Option<f32>,
    ) -> NodeId {
        self.push(
            attention_kind_op(
                num_heads,
                head_dim,
                MaskKind::Bias,
                score_scale,
                attn_logit_softcap,
            ),
            vec![q, k, v, bias],
            shape,
            None,
        )
    }
}