use crate::op::MaskKind;
use crate::{Graph, NodeId, Op, Shape};
impl Graph {
    /// Pushes an attention node whose mask is supplied as an explicit tensor.
    ///
    /// The node is an [`Op::Attention`] tagged with [`MaskKind::Custom`].
    /// Its inputs are `[q, k, v, mask]`, in that order.
    ///
    /// * `num_heads`, `head_dim` — recorded on the op unchanged.
    /// * `shape` — forwarded to `push` unchanged. It is not checked against
    ///   the input nodes here.
    ///
    /// Returns the id of the newly pushed node.
    // The argument list is public API, so the lint is silenced instead of
    // reshaping the signature.
    #[allow(clippy::too_many_arguments)]
    pub fn attention(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        mask: NodeId,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
    ) -> NodeId {
        self.push_attention_op(
            MaskKind::Custom,
            vec![q, k, v, mask],
            num_heads,
            head_dim,
            shape,
        )
    }

    /// Pushes an attention node whose mask is described by `mask_kind` alone,
    /// with no mask tensor.
    ///
    /// The node is an [`Op::Attention`] carrying `mask_kind`. Its inputs are
    /// exactly `[q, k, v]`. `num_heads`, `head_dim` and `shape` are handled as
    /// in [`Graph::attention`].
    ///
    /// # Panics
    ///
    /// In debug builds only, panics if `mask_kind` is [`MaskKind::Custom`] or
    /// [`MaskKind::Bias`]. Those kinds need an extra tensor input. Use
    /// [`Graph::attention`] or [`Graph::attention_bias`] for them instead.
    // The argument list is public API, so the lint is silenced instead of
    // reshaping the signature.
    #[allow(clippy::too_many_arguments)]
    pub fn attention_kind(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        num_heads: usize,
        head_dim: usize,
        mask_kind: MaskKind,
        shape: Shape,
    ) -> NodeId {
        // Only three inputs are pushed below. A tensor-backed kind would
        // therefore yield a node that is missing its mask/bias input.
        debug_assert!(
            !matches!(mask_kind, MaskKind::Custom | MaskKind::Bias),
            "attention_kind() requires a non-tensor MaskKind; use attention() for Custom or attention_bias() for Bias"
        );
        self.push_attention_op(mask_kind, vec![q, k, v], num_heads, head_dim, shape)
    }

    /// Pushes an attention node that takes an additive/bias tensor as its
    /// fourth input.
    ///
    /// The node is an [`Op::Attention`] tagged with [`MaskKind::Bias`]. Its
    /// inputs are `[q, k, v, bias]`, in that order. `num_heads`, `head_dim`
    /// and `shape` are handled as in [`Graph::attention`].
    // The argument list is public API, so the lint is silenced instead of
    // reshaping the signature.
    #[allow(clippy::too_many_arguments)]
    pub fn attention_bias(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        bias: NodeId,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
    ) -> NodeId {
        self.push_attention_op(
            MaskKind::Bias,
            vec![q, k, v, bias],
            num_heads,
            head_dim,
            shape,
        )
    }

    /// Shared tail of every `attention*` constructor. It wraps the head
    /// configuration and mask kind in an [`Op::Attention`] and pushes it with
    /// `inputs`.
    fn push_attention_op(
        &mut self,
        mask_kind: MaskKind,
        inputs: Vec<NodeId>,
        num_heads: usize,
        head_dim: usize,
        shape: Shape,
    ) -> NodeId {
        self.push(
            Op::Attention {
                num_heads,
                head_dim,
                mask_kind,
            },
            inputs,
            shape,
            // NOTE(review): the meaning of `push`'s trailing optional argument
            // is not visible from this file. Every attention constructor
            // passes `None`.
            None,
        )
    }
}