tract-transformers 0.22.1

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
pub mod apply_rope;
pub mod dyn_kv_cache;
pub mod gelu_approximate;
pub mod rms_norm;
pub mod scaled_masked_softmax;
pub mod sdpa;
pub mod silu;

use tract_core::internal::*;
use tract_core::ops::konst::Const;
use tract_nnef::tract_core;

pub use apply_rope::{apply_rope_rule, rotate_half_rule};
pub use dyn_kv_cache::replace_kv_cache;
pub use gelu_approximate::gelu_approx_rule;
pub use rms_norm::rms_norm_rule;
pub use scaled_masked_softmax::scaled_masked_softmax_rule;
pub use sdpa::fuse_kv_cache_broadcast_rule;
pub use silu::silu_rule;

use tract_core::ops::binary::TypedBinOp;
use tract_core::ops::math::{Add, Mul};

/// Bails out of the enclosing function with `Ok(None)` when `$cond` is false.
///
/// Meant to be used as a statement inside functions whose return type accepts
/// `Ok(None)`; when the condition holds, execution simply continues.
#[macro_export]
macro_rules! rule_ensure {
    ($cond:expr) => {
        // Parentheses make it explicit that the whole condition is negated.
        if !($cond) {
            return Ok(None);
        }
    };
}

/// Returns the unique successor of `node`, if it has exactly one.
///
/// Successors are counted across all output slots of `node`. `None` is
/// returned when `node` has no successor at all or more than one.
fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
    let mut successors = node.outputs.iter().flat_map(|of| of.successors.iter());
    // Pick the successor from whichever output slot actually holds it:
    // indexing `outputs[0].successors[0]` would panic when the only consumer
    // is attached to a later output slot.
    let only = successors.next()?;
    if successors.next().is_some() {
        return None;
    }
    Some(&model.nodes()[only.node])
}

/// Returns the node feeding `node` when `node` has exactly one input.
///
/// Yields `None` for nodes with zero or several inputs.
fn previous_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
    match &node.inputs[..] {
        [only] => Some(&model.nodes()[only.node]),
        _ => None,
    }
}

/// Collects the nodes that feed each input of `node`, in input order.
///
/// A node wired to several inputs appears once per input it feeds.
fn previous_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a TypedNode> {
    let all_nodes = model.nodes();
    node.inputs.iter().map(|outlet| &all_nodes[outlet.node]).collect()
}

/// Gathers the `Const` ops found among the nodes feeding `node`.
///
/// Inputs whose producing node is not a `Const` are skipped; the relative
/// order of the remaining ones follows the input order of `node`.
fn collect_node_const_inputs<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a Const> {
    let all_nodes = model.nodes();
    node.inputs.iter().filter_map(|outlet| all_nodes[outlet.node].op_as::<Const>()).collect()
}

/// Finds the one input of `node` whose producing node has an op of type `O`.
///
/// Returns the input index together with the producing node. `None` is
/// returned when no input matches or when more than one does.
fn single_prev_node_as<'a, O: TypedOp>(
    model: &'a TypedModel,
    node: &TypedNode,
) -> Option<(usize, &'a TypedNode)> {
    let all_nodes = model.nodes();
    let hits: TVec<(usize, &'a TypedNode)> = node
        .inputs
        .iter()
        .enumerate()
        .map(|(slot, outlet)| (slot, &all_nodes[outlet.node]))
        .filter(|(_, producer)| producer.op_is::<O>())
        .collect();

    // Exactly one match is required; anything else is ambiguous or absent.
    match &hits[..] {
        [only] => Some(*only),
        _ => None,
    }
}

/// Looks for a `Mul` node directly following `node` that multiplies by `konst`.
///
/// The successor must be the unique successor of `node` (see `next_node`), be a
/// `TypedBinOp` wrapping `Mul`, and satisfy `matches_single_input_const` for
/// `konst`. Returns that successor, or `None` if any condition fails.
fn find_succ_mul_with_const<'a>(
    model: &'a TypedModel,
    node: &'a TypedNode,
    konst: f32,
) -> Option<&'a TypedNode> {
    let candidate = next_node(model, node)?;
    let bin_op = candidate.op_as::<TypedBinOp>()?;
    if !bin_op.0.is::<Mul>() {
        return None;
    }
    if matches_single_input_const(model, candidate, konst) {
        Some(candidate)
    } else {
        None
    }
}

/// Looks for an `Add` node directly following `node` that also consumes
/// `outled_id`.
///
/// The successor must be the unique successor of `node` (see `next_node`), be a
/// `TypedBinOp` wrapping `Add`, and list `outled_id` among its inputs. Returns
/// that successor, or `None` if any condition fails.
fn find_succ_add_with<'a>(
    model: &'a TypedModel,
    node: &'a TypedNode,
    outled_id: &OutletId,
) -> Option<&'a TypedNode> {
    let candidate = next_node(model, node)?;
    let bin_op = candidate.op_as::<TypedBinOp>()?;
    if !bin_op.0.is::<Add>() {
        return None;
    }
    candidate.inputs.contains(outled_id).then_some(candidate)
}

/// Tells whether `node` has exactly one `Const` input whose value matches
/// `konst`.
///
/// The constant is cast to `F32`, converted with `to_scalar_tensor`, and then
/// compared to `konst` using `Approximation::Approximate`. Any failure along
/// the way (zero or several constant inputs, failed cast, failed scalar
/// conversion) yields `false`.
fn matches_single_input_const(model: &TypedModel, node: &TypedNode, konst: f32) -> bool {
    let const_inputs = collect_node_const_inputs(model, node);
    let [only] = &const_inputs[..] else {
        return false;
    };
    let Ok(as_f32) = only.val().cast_to_dt(DatumType::F32) else {
        return false;
    };
    match as_f32.to_scalar_tensor() {
        Ok(scalar) => scalar.close_enough(&tensor0(konst), Approximation::Approximate).is_ok(),
        Err(_) => false,
    }
}

/// Looks for an `Add` node directly following `node` that adds `konst`.
///
/// The successor must be the unique successor of `node` (see `next_node`), be a
/// `TypedBinOp` wrapping `Add`, and satisfy `matches_single_input_const` for
/// `konst`. Returns that successor, or `None` if any condition fails.
fn find_succ_add_with_const<'a>(
    model: &'a TypedModel,
    node: &'a TypedNode,
    konst: f32,
) -> Option<&'a TypedNode> {
    let add_coef_a = next_node(model, node)?;
    let add_coef_a_op = add_coef_a.op_as::<TypedBinOp>()?;
    // A single `Add` check is enough; the short-circuit keeps the constant
    // lookup from running on non-`Add` successors, as in the `Mul` variant.
    (add_coef_a_op.0.is::<Add>() && matches_single_input_const(model, add_coef_a, konst))
        .then_some(add_coef_a)
}