tract-tensorflow 0.22.1

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use tract_hir::internal::*;
use tract_ndarray::prelude::*;

use tract_hir::tract_core::ops::cnn::{Conv, PoolSpec};

#[derive(Debug, Copy, Clone, Hash)]
pub enum PaddingStrat {
    FlexFixed(usize),
    FixedFlex(usize),
    FixedFixed(usize, usize),
}

#[derive(Debug, Clone, new, Hash)]
pub struct SpaceToBatchUnary {
    pub datum_type: DatumType,
    pub space_shape: TVec<TDim>,
    pub batch_shape: TVec<TDim>,
    pub block_shape: Array1<i32>,
    pub pad: TVec<PaddingStrat>,
}

impl Op for SpaceToBatchUnary {
    fn name(&self) -> StaticName {
        "SpaceToBatchUnary".into()
    }

    op_as_typed_op!();
}

impl EvalOp for SpaceToBatchUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let mut paddings = Array2::zeros((self.block_shape.len(), 2));
        for (ax, &strat) in self.pad.iter().enumerate() {
            let spread = (self.batch_shape[2 + ax].clone() * self.block_shape[ax]
                - &self.space_shape[2 + ax])
                .to_usize()?;
            let (bef, aft) = match strat {
                PaddingStrat::FlexFixed(f) => (spread - f, f),
                PaddingStrat::FixedFlex(f) => (f, spread - f),
                PaddingStrat::FixedFixed(a, b) => (a, b),
            };
            paddings[(ax, 0)] = bef as i32;
            paddings[(ax, 1)] = aft as i32;
        }
        let r = dispatch_numbers!(super::space_to_batch(input.datum_type())(
            input,
            &self.block_shape.view(),
            &paddings.view()
        ))?;
        Ok(tvec!(r))
    }
}

impl TypedOp for SpaceToBatchUnary {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].datum_type.fact(&self.batch_shape)))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let [succ] = &*model.node(node.id).outputs[0].successors else { return Ok(None) };
        let conv_node = model.node(succ.node);
        let Some(conv_op) = conv_node.op_as::<Conv>() else { return Ok(None) };
        let [succ] = &*conv_node.outputs[0].successors else { return Ok(None) };
        let b2s_node = model.node(succ.node);
        let Some(_bs2_op) = b2s_node.op_as::<BatchToSpaceUnary>() else { return Ok(None) };
        let op = Conv {
            pool_spec: PoolSpec {
                dilations: Some(self.block_shape.iter().map(|&i| i as usize).collect()),
                ..conv_op.pool_spec.clone()
            },
            ..conv_op.clone()
        };
        let mut patch = TypedModelPatch::default();
        let taps_s2b = patch.taps(model, &node.inputs)?;
        let mut taps_conv = patch.taps(model, &conv_node.inputs)?;
        taps_conv[0] = taps_s2b[0];
        let out = patch.model.wire_node(&*conv_node.name, op, &taps_conv)?[0];
        patch.shunt_outside(model, OutletId::new(b2s_node.id, 0), out)?;
        Ok(Some(patch))
    }

    as_op!();
}

#[derive(Debug, Clone, new, Hash)]
pub struct BatchToSpaceUnary {
    datum_type: DatumType,
    batch_shape: TVec<TDim>,
    space_shape: TVec<TDim>,
    block_shape: Array1<i32>,
    pad: Vec<PaddingStrat>,
}

impl Op for BatchToSpaceUnary {
    fn name(&self) -> StaticName {
        "BatchToSpaceUnary".into()
    }

    op_as_typed_op!();
}

impl EvalOp for BatchToSpaceUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let mut paddings = Array2::zeros((self.block_shape.len(), 2));
        for (ax, &strat) in self.pad.iter().enumerate() {
            let spread = (self.batch_shape[2 + ax].clone() * self.block_shape[ax]
                - &self.space_shape[2 + ax])
                .to_usize()?;
            let (bef, aft) = match strat {
                PaddingStrat::FlexFixed(f) => (spread - f, f),
                PaddingStrat::FixedFlex(f) => (f, spread - f),
                PaddingStrat::FixedFixed(a, b) => (a, b),
            };
            paddings[(ax, 0)] = bef as i32;
            paddings[(ax, 1)] = aft as i32;
        }
        let r = dispatch_numbers!(super::batch_to_space(input.datum_type())(
            input,
            &self.block_shape.view(),
            &paddings.view()
        ))?;
        Ok(tvec!(r))
    }
}

impl TypedOp for BatchToSpaceUnary {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].datum_type.fact(&self.space_shape)))
    }
}