tract-hir 0.19.2

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
use crate::internal::*;
use tract_core::ops::math::*;

// Implements `Expansion` for an activation operator.
//
// `$op` is the operator type. `$wire` is an expression (in this file, always a closure)
// that is coerced to a plain `fn` pointer taking the operator, the node name, the typed
// model and the input outlets, and returning the output outlets.
macro_rules! activation {
    ($op: ident, $wire:expr) => {
        impl_dyn_hash!($op);

        impl Expansion for $op {
            // The operator is named after its own type.
            fn name(&self) -> Cow<str> {
                Cow::from(stringify!($op))
            }

            // Inference is delegated to the rules shared by every activation in this file.
            fn rules<'r, 'p: 'r, 's: 'r>(
                &'s self,
                s: &mut Solver<'r>,
                inputs: &'p [TensorProxy],
                outputs: &'p [TensorProxy],
            ) -> InferenceResult {
                simple_unary_rules(s, inputs, outputs)
            }

            fn wire(
                &self,
                name: &str,
                model: &mut TypedModel,
                inputs: &[OutletId],
            ) -> TractResult<TVec<OutletId>> {
                // Binding to an explicit `fn` pointer type fixes the types of the closure
                // parameters the call sites leave unannotated.
                let expand: fn(
                    &$op,
                    &str,
                    &mut TypedModel,
                    &[OutletId],
                ) -> TractResult<TVec<OutletId>> = $wire;
                expand(self, name, model, inputs)
            }
        }
    };
}

// Binds `$id` to the outlet of a new constant node named `<$name>.<$id>`.
//
// The constant holds `$value`, shaped by `broadcast_scalar` against `$inputs`. Both steps
// use `?`, so the macro must be expanded where `?` can propagate the error.
macro_rules! cst {
    ($model: expr, $inputs: expr, $name: expr, $id:ident, $value: expr) => {
        let $id = {
            let scalar = broadcast_scalar($value, $model, $inputs)?;
            $model.add_const(format!("{}.{}", $name, stringify!($id)), scalar)?
        };
    };
}

/// Clip activation.
///
/// The first field is an optional lower bound (wired as a `max` with the input) and the
/// second an optional upper bound (wired as a `min`). A `None` bound is not applied.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct Clip(
    #[educe(Hash(method = "hash_opt_f32"))] Option<f32>,
    #[educe(Hash(method = "hash_opt_f32"))] Option<f32>,
);

// Clip: optionally apply `max(x, low)`, then optionally `min(., high)`.
// With no bound set, the inputs are returned untouched.
activation!(Clip, |op, name: &str, model: &mut TypedModel, inputs| {
    let mut outlets: TVec<OutletId> = inputs.iter().copied().collect();
    if let Some(lower) = op.0 {
        let bound = broadcast_scalar(lower, model, inputs)?;
        let bound = model.add_const(format!("{}.low.cst", name), bound)?;
        outlets = model.wire_node(format!("{}.low", name), max(), &[outlets[0], bound])?;
    }
    if let Some(upper) = op.1 {
        let bound = broadcast_scalar(upper, model, inputs)?;
        let bound = model.add_const(format!("{}.high.cst", name), bound)?;
        outlets = model.wire_node(format!("{}.high", name), min(), &[outlets[0], bound])?;
    }
    Ok(outlets)
});

/// Softplus activation, wired as `ln(exp(x) + 1)`.
#[derive(Debug, Clone, new, Hash)]
pub struct Softplus;

// Softplus: ln(exp(x) + 1).
activation!(Softplus, |_op, name: &str, model: &mut TypedModel, inputs| {
    cst!(model, inputs, name, one, 1.0);
    let exponential = model.wire_node(format!("{}.exp", name), exp(), inputs)?;
    let shifted =
        model.wire_node(format!("{}.plus_one", name), add(), &[exponential[0], one])?;
    model.wire_node(format!("{}.ln", name), ln(), &shifted)
});

/// Softsign activation, wired as `x / (abs(x) + 1)`.
#[derive(Debug, Clone, new, Hash)]
pub struct Softsign;

// Softsign: x / (abs(x) + 1).
activation!(Softsign, |_op, name: &str, model: &mut TypedModel, inputs| {
    cst!(model, inputs, name, one, 1.0);
    let magnitude = model.wire_node(format!("{}.abs", name), abs(), inputs)?;
    let denominator =
        model.wire_node(format!("{}.plus_one", name), add(), &[magnitude[0], one])?;
    model.wire_node(format!("{}.div", name), div(), &[inputs[0], denominator[0]])
});

/// Elu activation; the field is `alpha`.
///
/// Wired as `x` where `0 < x`, and `alpha * (exp(x) - 1)` elsewhere.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct Elu(#[educe(Hash(method = "hash_f32"))] pub f32);

// Elu: x where 0 < x, alpha * (exp(x) - 1) elsewhere.
activation!(Elu, |op, name: &str, model: &mut TypedModel, inputs| {
    use tract_core::ops::logic::{less, Iff};

    cst!(model, inputs, name, zero, 0.0);
    cst!(model, inputs, name, one, 1.0);
    cst!(model, inputs, name, alpha, op.0);
    let x = inputs[0];
    let exponential = model.wire_node(format!("{}.exp", name), exp(), inputs)?;
    let shifted =
        model.wire_node(format!("{}.minus_one", name), sub(), &[exponential[0], one])?;
    let negative =
        model.wire_node(format!("{}.mul_alpha", name), mul(), &[alpha, shifted[0]])?;
    // Select the identity branch for strictly positive inputs.
    let is_positive = model.wire_node(format!("{}.test", name), less(), &[zero, x])?;
    model.wire_node(format!("{}.iff", name), Iff, &[is_positive[0], x, negative[0]])
});

/// HardSigmoid activation; the fields are `alpha` and `beta`, in that order.
///
/// Wired as `max(0, min(1, alpha * x + beta))`.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct HardSigmoid(
    #[educe(Hash(method = "hash_f32"))] pub f32,
    #[educe(Hash(method = "hash_f32"))] pub f32,
);

// HardSigmoid: max(0, min(1, alpha * x + beta)).
activation!(HardSigmoid, |op, name: &str, model: &mut TypedModel, inputs| {
    cst!(model, inputs, name, zero, 0.0);
    cst!(model, inputs, name, one, 1.0);
    cst!(model, inputs, name, alpha, op.0);
    cst!(model, inputs, name, beta, op.1);
    let x = inputs[0];
    let scaled = model.wire_node(format!("{}.mul_alpha", name), mul(), &[alpha, x])?;
    let affine = model.wire_node(format!("{}.add_beta", name), add(), &[beta, scaled[0]])?;
    // Saturate at one first, then at zero.
    let capped = model.wire_node(format!("{}.sat-one", name), min(), &[one, affine[0]])?;
    model.wire_node(format!("{}.sat-zero", name), max(), &[zero, capped[0]])
});

/// LeakyRelu activation; the field is passed as-is to `tract_core::ops::nn::leaky_relu`.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct LeakyRelu(#[educe(Hash(method = "hash_f32"))] pub f32);

// LeakyRelu: a single tract-core node, wired under the expansion's own name.
activation!(LeakyRelu, |op, name: &str, model: &mut TypedModel, inputs| {
    let leaky = tract_core::ops::nn::leaky_relu(op.0);
    model.wire_node(name, leaky, inputs)
});

/// ParametricSoftplus activation; the fields are `alpha` and `beta`, in that order.
///
/// Wired as `alpha * ln(1 + exp(beta * x))`.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct ParametricSoftplus(
    #[educe(Hash(method = "hash_f32"))] pub f32,
    #[educe(Hash(method = "hash_f32"))] pub f32,
);

// ParametricSoftplus: alpha * ln(1 + exp(beta * x)).
activation!(ParametricSoftplus, |op, name: &str, model: &mut TypedModel, inputs| {
    cst!(model, inputs, name, one, 1.0);
    cst!(model, inputs, name, alpha, op.0);
    cst!(model, inputs, name, beta, op.1);
    let scaled = model.wire_node(format!("{}.mul_beta", name), mul(), &[beta, inputs[0]])?;
    let exponential = model.wire_node(format!("{}.exp", name), exp(), &scaled)?;
    let shifted =
        model.wire_node(format!("{}.plus_one", name), add(), &[one, exponential[0]])?;
    let logarithm = model.wire_node(format!("{}.ln", name), ln(), &shifted)?;
    model.wire_node(format!("{}.mul_alpha", name), mul(), &[alpha, logarithm[0]])
});

/// ScaledTanh activation; the fields are `alpha` and `beta`, in that order.
///
/// Wired as `alpha * tanh(beta * x)`.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct ScaledTanh(
    #[educe(Hash(method = "hash_f32"))] pub f32,
    #[educe(Hash(method = "hash_f32"))] pub f32,
);

// ScaledTanh: alpha * tanh(beta * x).
activation!(ScaledTanh, |op, name: &str, model: &mut TypedModel, inputs| {
    cst!(model, inputs, name, alpha, op.0);
    cst!(model, inputs, name, beta, op.1);
    let scaled = model.wire_node(format!("{}.mul_beta", name), mul(), &[beta, inputs[0]])?;
    let squashed = model.wire_node(format!("{}.tanh", name), tanh(), &scaled)?;
    model.wire_node(format!("{}.mul_alpha", name), mul(), &[alpha, squashed[0]])
});

/// Selu activation; the fields are `alpha` and `gamma`, in that order.
///
/// Wired as `gamma * x` where `0 < x`, and `gamma * (alpha * exp(x) - alpha)` elsewhere.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct Selu(
    #[educe(Hash(method = "hash_f32"))] pub f32,
    #[educe(Hash(method = "hash_f32"))] pub f32,
);

// Selu: gamma * (x where 0 < x, alpha * exp(x) - alpha elsewhere).
activation!(Selu, |op, name: &str, model: &mut TypedModel, inputs| {
    use tract_core::ops::logic::{less, Iff};

    cst!(model, inputs, name, zero, 0.0);
    cst!(model, inputs, name, alpha, op.0);
    cst!(model, inputs, name, gamma, op.1);
    let x = inputs[0];
    let exponential = model.wire_node(format!("{}.exp", name), exp(), inputs)?;
    let scaled =
        model.wire_node(format!("{}.mul_alpha", name), mul(), &[exponential[0], alpha])?;
    let negative =
        model.wire_node(format!("{}.sub_alpha", name), sub(), &[scaled[0], alpha])?;
    // Pick the identity branch for strictly positive inputs before the final gamma scaling.
    let is_positive = model.wire_node(format!("{}.test", name), less(), &[zero, x])?;
    let selected =
        model.wire_node(format!("{}.iff", name), Iff, &[is_positive[0], x, negative[0]])?;
    model.wire_node(format!("{}.mul_gamma", name), mul(), &[gamma, selected[0]])
});

/// Shrink activation; the fields are `bias` and `lambda`, in that order.
///
/// Wired as `x + bias` where `-lambda > x`, otherwise `x - bias` where `lambda < x`,
/// otherwise `0`.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct Shrink(
    #[educe(Hash(method = "hash_f32"))] pub f32,
    #[educe(Hash(method = "hash_f32"))] pub f32,
);

// Shrink: x + bias where -lambda > x, x - bias where lambda < x, 0 otherwise.
activation!(Shrink, |op, name: &str, model: &mut TypedModel, inputs| {
    use tract_core::ops::logic::{greater, less, Iff};

    cst!(model, inputs, name, bias, op.0);
    cst!(model, inputs, name, lambda, op.1);
    cst!(model, inputs, name, minus_lambda, -op.1);
    cst!(model, inputs, name, zero, 0.0);
    let x = inputs[0];
    let above = model.wire_node(format!("{}.test_pos", name), less(), &[lambda, x])?;
    let lowered = model.wire_node(format!("{}.pos", name), sub(), &[x, bias])?;
    let below =
        model.wire_node(format!("{}.test_neg", name), greater(), &[minus_lambda, x])?;
    let raised = model.wire_node(format!("{}.neg", name), add(), &[bias, x])?;
    // Resolve the positive side against zero first, then let the negative side override it.
    let positive_or_zero =
        model.wire_node(format!("{}.if_pos", name), Iff, &[above[0], lowered[0], zero])?;
    model.wire_node(
        format!("{}.if_neg", name),
        Iff,
        &[below[0], raised[0], positive_or_zero[0]],
    )
});

/// ThresholdRelu activation; the field is the threshold `alpha`.
///
/// Wired as `x` where `alpha < x`, and `0` elsewhere.
#[derive(Debug, Clone, new, Educe)]
#[educe(Hash)]
pub struct ThresholdRelu(#[educe(Hash(method = "hash_f32"))] pub f32);

// ThresholdRelu: x where alpha < x, 0 elsewhere.
activation!(ThresholdRelu, |op, name: &str, model: &mut TypedModel, inputs| {
    use tract_core::ops::logic::{less, Iff};

    cst!(model, inputs, name, zero, 0.0);
    cst!(model, inputs, name, alpha, op.0);
    let x = inputs[0];
    let above_threshold = model.wire_node(format!("{}.test", name), less(), &[alpha, x])?;
    model.wire_node(format!("{}.iff", name), Iff, &[above_threshold[0], x, zero])
});

/// Inference rules shared by the activations in this file.
///
/// Requires exactly one input and one output, and constrains the output to have the same
/// datum type and the same shape as the input.
fn simple_unary_rules<'r, 'p: 'r, 's: 'r>(
    s: &mut Solver<'r>,
    inputs: &'p [TensorProxy],
    outputs: &'p [TensorProxy],
) -> InferenceResult {
    check_input_arity(inputs, 1)?;
    check_output_arity(outputs, 1)?;
    let (input, output) = (&inputs[0], &outputs[0]);
    s.equals(&input.datum_type, &output.datum_type)?;
    s.equals(&input.shape, &output.shape)?;
    Ok(())
}

/// Builds a constant tensor holding `f`, shaped to combine with the first input.
///
/// The scalar is cast to the datum type of `inputs[0]`, then axes are inserted at
/// position 0 until its rank equals the rank of that input.
///
/// # Errors
///
/// Propagates failures from the fact lookup, the cast, or the axis insertion.
///
/// # Panics
///
/// Panics if `inputs` is empty, since the first input is accessed by index.
pub fn broadcast_scalar(
    f: f32,
    model: &TypedModel,
    inputs: &[OutletId],
) -> TractResult<Arc<Tensor>> {
    let input_fact = model.outlet_fact(inputs[0])?;
    let target_rank = input_fact.rank();
    let mut scalar = tensor0(f).cast_to_dt(input_fact.datum_type)?.into_owned();
    // Each inserted axis raises the rank by one, so pad exactly the missing count.
    for _ in scalar.rank()..target_rank {
        scalar.insert_axis(0)?;
    }
    Ok(scalar.into_arc_tensor())
}