tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
use crate::internal::*;

pub mod conv;
pub mod deconv;
mod maxpool;
mod padding;
mod patch_axis;
mod patches;
pub mod pools;
mod sumpool;

pub use self::conv::{Conv, KernelFormat};
pub use self::deconv::Deconv;
pub use self::maxpool::MaxPool;
pub use self::padding::PaddingSpec;
pub use self::patch_axis::PatchAxis;
pub use self::patches::{Patch, PatchSpec};
pub use self::pools::PoolSpec;
pub use self::sumpool::SumPool;

use super::array::MultiBroadcastTo;

/// Wires the nodes needed to present a bias outlet as a vector of
/// `output_channels` elements.
///
/// Two optional steps are applied, in order:
///
/// 1. If the bias has a non-zero rank but a volume of one, it is reshaped to
///    a rank-0 tensor (node `{name}.bias.make_scalar`).
/// 2. If the bias is rank 0 at that point, it is broadcast to the shape
///    `[output_channels]` (node `{name}.bias.broadcast`).
///
/// A bias matching neither condition is passed through untouched.
///
/// # Errors
///
/// Propagates any error from looking up the outlet facts or from wiring the
/// new nodes into `model`.
pub fn wire_reshape_bias_as_vector(
    model: &mut TypedModel,
    name: impl AsRef<str>,
    outlet: OutletId,
    output_channels: usize,
) -> TractResult<TVec<OutletId>> {
    let prefix = name.as_ref();
    let input_fact = model.outlet_fact(outlet)?.clone();
    let mut wires = tvec!(outlet);

    // Collapse a one-element, non-scalar bias down to a true scalar first.
    let single_element = input_fact.shape.volume().is_one();
    if input_fact.rank() > 0 && single_element {
        let to_scalar = AxisOp::Reshape(0, input_fact.shape.to_tvec(), tvec![]);
        wires = model.wire_node(format!("{prefix}.bias.make_scalar"), to_scalar, &wires)?;
    }

    // Re-read the fact: the reshape above may have changed the rank.
    let is_scalar = model.outlet_fact(wires[0])?.rank() == 0;
    if is_scalar {
        let broadcast = MultiBroadcastTo { shape: tvec!(output_channels).into() };
        wires = model.wire_node(format!("{prefix}.bias.broadcast"), broadcast, &wires)?;
    }

    Ok(wires)
}

/// Wires the nodes needed to reshape a bias outlet to a `rank`-dimensional
/// shape that is `output_channels` along `c_axis` and `1` everywhere else.
///
/// The bias is first normalised by [`wire_reshape_bias_as_vector`]. If its
/// shape already equals the target shape no further node is added; otherwise
/// a reshape node named `{name}.bias` is wired.
///
/// # Panics
///
/// Panics if `c_axis >= rank`, because the target shape is indexed by
/// `c_axis`.
///
/// # Errors
///
/// Propagates any error from looking up the outlet fact or from wiring nodes
/// into `model`.
pub fn wire_reshape_bias_for_bin(
    model: &mut TypedModel,
    name: impl AsRef<str>,
    outlet: OutletId,
    rank: usize,
    c_axis: usize,
    output_channels: usize,
) -> TractResult<TVec<OutletId>> {
    let prefix = name.as_ref();
    let vector = wire_reshape_bias_as_vector(model, prefix, outlet, output_channels)?;
    let vector_fact = model.outlet_fact(vector[0])?.clone();

    // All-ones shape of the requested rank, except on the channel axis.
    let mut target_shape = tvec![1.to_dim(); rank];
    target_shape[c_axis] = output_channels.to_dim();

    // Nothing to wire when the bias already has the wanted shape.
    if *target_shape == *vector_fact.shape {
        return Ok(vector);
    }

    let reshape = AxisOp::Reshape(0, vector_fact.shape.to_tvec(), target_shape);
    model.wire_node(format!("{prefix}.bias"), reshape, &vector)
}

/// Rewrites a [`Conv`] whose data format has no N axis into an equivalent
/// sub-graph that uses the `with_n()` variant of that data format.
///
/// The produced patch wires, in order: an `AxisOp::Add(0)` on the first
/// input (`{name}.add_n`), the cloned convolution with the updated data
/// format (`name`), and an `AxisOp::Rm(0)` on its output (`{name}.rm_n`).
/// The original node's output is then shunted to the result.
///
/// Returns `Ok(None)` when the data format already reports an N axis.
pub fn rewrite_conv_with_n_axis(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    conv: &Conv,
) -> TractResult<Option<TypedModelPatch>> {
    let data_format = &conv.pool_spec.data_format;
    if data_format.has_n() {
        return Ok(None);
    }

    let mut conv_with_n = conv.clone();
    conv_with_n.pool_spec.data_format = data_format.with_n();

    let mut patch = TypedModelPatch::default();
    let mut inputs = patch.taps(model, &node.inputs)?;

    // Only the first input gets the extra leading axis.
    let expanded = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[inputs[0]])?[0];
    inputs[0] = expanded;

    let conv_out = patch.wire_node(name, conv_with_n, &inputs)?;
    let squeezed = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &conv_out)?;
    patch.shunt_outside(model, node.id.into(), squeezed[0])?;

    Ok(Some(patch))
}

/// Rewrites a [`Deconv`] whose data format has no N axis into an equivalent
/// sub-graph that uses the `with_n()` variant of that data format.
///
/// The produced patch wires, in order: an `AxisOp::Add(0)` on the first
/// input (`{name}.add_n`), the cloned deconvolution with the updated data
/// format (`name`), and an `AxisOp::Rm(0)` on its output (`{name}.rm_n`).
/// The original node's output is then shunted to the result.
///
/// Returns `Ok(None)` when the data format already reports an N axis.
pub fn rewrite_deconv_with_n_axis(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    deconv: &Deconv,
) -> TractResult<Option<TypedModelPatch>> {
    let data_format = &deconv.pool_spec.data_format;
    if data_format.has_n() {
        return Ok(None);
    }

    let mut deconv_with_n = deconv.clone();
    deconv_with_n.pool_spec.data_format = data_format.with_n();

    let mut patch = TypedModelPatch::default();
    let mut inputs = patch.taps(model, &node.inputs)?;

    // Only the first input gets the extra leading axis.
    let expanded = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[inputs[0]])?[0];
    inputs[0] = expanded;

    let deconv_out = patch.wire_node(name, deconv_with_n, &inputs)?;
    let squeezed = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &deconv_out)?;
    patch.shunt_outside(model, node.id.into(), squeezed[0])?;

    Ok(Some(patch))
}