tract-gpu 0.23.0-dev.4

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use tract_core::internal::*;
use tract_transformers::ops::sdpa::{Sdpa, SdpaMaskMode, wire_attention_mask};

pub fn rewire_sdpa(model: &mut TypedModel) -> TractResult<()> {
    Rewriter::default().with_rule_for("flatten-sdpa", rewire_sdpa_op).rewrite(&(), model)
}

pub fn rewire_sdpa_op(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    _node_name: &str,
    op: &Sdpa,
) -> TractResult<Option<TypedModelPatch>> {
    op.patch_sdpa(model, node)
}

pub fn create_sdpa_mask_graph(
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &Sdpa,
    mode: SdpaMaskMode,
) -> TractResult<Option<TypedModelPatch>> {
    let in_facts = model.node_input_facts(node.id)?;
    let q_shape = &in_facts[0].shape;
    let k_shape = &in_facts[1].shape;
    let rank = q_shape.len();
    ensure!(k_shape.len() == rank);

    let q_len = &q_shape[rank - 2];
    let k_len = &k_shape[rank - 2];

    let mut patch = TypedModelPatch::default();
    let mut inputs = patch.taps(model, &node.inputs)?;

    let mask =
        wire_attention_mask(&mut patch, &node.name, op.acc_datum_type, mode, rank, q_len, k_len)?;
    inputs.push(mask);

    let mut new_op = op.clone();
    new_op.is_causal = false;
    let new_sdpa = patch.wire_node(node_name, new_op, &inputs)?[0];
    patch.shunt_outside(model, node.id.into(), new_sdpa)?;

    Ok(Some(patch))
}

pub fn neutral_mask_for_full_attn(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &Sdpa,
) -> TractResult<Option<TypedModelPatch>> {
    rule_if!(!op.is_causal && node.inputs.len() == 3);
    create_sdpa_mask_graph(model, node, node_name, op, SdpaMaskMode::Neutral)
}

pub fn causal_mask_as_extern(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &Sdpa,
) -> TractResult<Option<TypedModelPatch>> {
    rule_if!(op.is_causal);
    create_sdpa_mask_graph(model, node, node_name, op, SdpaMaskMode::Causal)
}