1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use tract_core::internal::*;

use crate::model::ParsingContext;
use crate::model::TfOpRegister;
use crate::tfpb::tensorflow::NodeDef;

/// Adds the fake-quantization operator builders of this module to the TensorFlow op register.
///
/// Currently only `FakeQuantWithMinMaxVars` is handled; its builder is
/// `fake_quant_with_min_max_vars`.
pub fn register_all_ops(reg: &mut TfOpRegister) {
    let builder = fake_quant_with_min_max_vars;
    reg.insert("FakeQuantWithMinMaxVars", builder);
}

fn fake_quant_with_min_max_vars(
    _ctx: &ParsingContext,
    node: &NodeDef,
) -> TractResult<Box<dyn InferenceOp>> {
    let narrow_range = node.get_attr_bool("narrow_range")?;
    let num_bits = node.get_attr_int("num_bits")?;
    Ok(Box::new(FakeQuantWithMinMaxVars::new(narrow_range, num_bits)))
}

/// TensorFlow `FakeQuantWithMinMaxVars` operator.
///
/// The op takes three inputs: the tensor to fake-quantize and two scalar tensors, `min`
/// and `max`, which define the quantization range. The attributes below are read from the
/// TensorFlow node when the op is parsed.
#[derive(Clone, Debug, new)]
struct FakeQuantWithMinMaxVars {
    // TF `narrow_range` attribute: when true, the quantization grid has one fewer step.
    narrow_range: bool,
    // TF `num_bits` attribute: bit width of the simulated quantization.
    num_bits: usize,
}

impl FakeQuantWithMinMaxVars {
    /// Quantization step (scale) for the `[min, max]` range.
    ///
    /// Equal to `(max - min) / (2^num_bits - 1 - narrow_range)`, i.e. the range divided by
    /// the number of intervals between quantization levels. `min` and `max` must be f32
    /// scalar tensors.
    fn step(&self, min: &Tensor, max: &Tensor) -> TractResult<f32> {
        let min = min.to_scalar::<f32>()?;
        let max = max.to_scalar::<f32>()?;
        let amplitude = max - min;
        let scale_len = 2_usize.pow(self.num_bits as u32) - 1 - self.narrow_range as usize;
        Ok(amplitude / scale_len as f32)
    }

    /// Adjusts (`nudges`) the `[min, max]` range the same way TensorFlow does, so that 0.0
    /// falls exactly on a quantization level.
    ///
    /// Returns `(nudged_min, nudged_max, scale)`:
    /// - `scale` is the value returned by [`Self::step`].
    /// - `nudged_min` and `nudged_max` are the bounds of the shifted grid.
    fn nudge(&self, min: &Tensor, max: &Tensor) -> TractResult<(f32, f32, f32)> {
        let scale = self.step(min, max)?;
        let min = *min.to_scalar::<f32>()?;
        // Integer quantization range: [1, 2^n - 1] with narrow_range, [0, 2^n - 1] otherwise.
        let quant_min: f32 = if self.narrow_range { 1.0 } else { 0.0 };
        let quant_max = (2_usize.pow(self.num_bits as u32) - 1) as f32;
        // Zero point implied by `min`, snapped to an integer level inside the quantized range.
        let zero_point_from_min = quant_min - min / scale;
        let nudged_zero_point = if zero_point_from_min < quant_min {
            quant_min
        } else if zero_point_from_min > quant_max {
            quant_max
        } else {
            zero_point_from_min.round()
        };
        let nudged_min = (quant_min - nudged_zero_point) * scale;
        let nudged_max = (quant_max - nudged_zero_point) * scale;
        Ok((nudged_min, nudged_max, scale))
    }
}

impl Op for FakeQuantWithMinMaxVars {
    fn name(&self) -> Cow<str> {
        "tf.FakeQuantWithMinMaxVars".into()
    }

    not_a_typed_op!();
}

impl StatelessOp for FakeQuantWithMinMaxVars {
    /// Fake-quantizes an f32 tensor.
    ///
    /// Each value is clamped to the nudged `[min, max]` range, snapped to the nearest
    /// quantization level, and mapped back to f32.
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let (input, min, max) = args_3!(inputs);
        let (nudged_min, nudged_max, step) = self.nudge(&min, &max)?;
        let mut tensor = input.into_tensor().into_array::<f32>()?;
        tensor.mapv_inplace(|v: f32| {
            let clamped = v.max(nudged_min).min(nudged_max);
            // `clamped - nudged_min` is never negative, so `round` (half away from zero)
            // matches TensorFlow's `floor(x + 0.5)`.
            ((clamped - nudged_min) / step).round() * step + nudged_min
        });
        Ok(tvec!(tensor.into_arc_tensor()))
    }
}

impl InferenceRulesOp for FakeQuantWithMinMaxVars {
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(&inputs, 3)?;
        check_output_arity(&outputs, 1)?;
        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
        s.equals(&inputs[0].datum_type, &inputs[2].datum_type)?;
        s.equals(&inputs[1].shape, shapefact!())?;
        s.equals(&inputs[2].shape, shapefact!())?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].shape, &outputs[0].shape)?;
        Ok(())
    }

    /// Lowers the op to a chain of elementwise typed ops.
    ///
    /// The chain is: clamp to the nudged range, shift by `-nudged_min`, divide by the step,
    /// round, multiply by the step, then shift back by `+nudged_min`.
    ///
    /// This only works when `min` and `max` are known constants; otherwise it bails.
    fn to_typed(
        &self,
        _source: &InferenceModel,
        node: &InferenceNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        if let (Some(min), Some(max)) = (
            target.outlet_fact(mapping[&node.inputs[1]])?.konst.as_ref(),
            target.outlet_fact(mapping[&node.inputs[2]])?.konst.as_ref(),
        ) {
            let rank = target.outlet_fact(mapping[&node.inputs[0]])?.rank();
            let (nudged_min, nudged_max, step) = self.nudge(&min, &max)?;
            // Reshapes a scalar constant to the input's rank so it broadcasts elementwise.
            let bc = |v| -> TractResult<Arc<Tensor>> {
                let mut t = tensor0(v);
                while t.rank() < rank {
                    t.insert_axis(0)?;
                }
                Ok(t.into_arc_tensor())
            };
            let wire = mapping[&node.inputs[0]];
            let wire = target.wire_node(
                format!("{}-clamp-min", &*node.name),
                tract_core::ops::math::max::unary(bc(nudged_min)?),
                &[wire],
            )?[0];
            let wire = target.wire_node(
                format!("{}-clamp-max", &*node.name),
                tract_core::ops::math::min::unary(bc(nudged_max)?),
                &[wire],
            )?[0];
            let wire = target.wire_node(
                format!("{}-sub-min", &*node.name),
                tract_core::ops::math::add::unary(bc(-nudged_min)?),
                &[wire],
            )?[0];
            let wire = target.wire_node(
                format!("{}-div-step", &*node.name),
                tract_core::ops::math::mul::unary(bc(step.recip())?),
                &[wire],
            )?[0];
            // Values are non-negative after clamping and shifting, so round-half-away-from-zero
            // matches TensorFlow's floor(x + 0.5).
            let wire = target.wire_node(
                format!("{}-round", &*node.name),
                tract_core::ops::math::round(),
                &[wire],
            )?[0];
            let wire = target.wire_node(
                format!("{}-mul-step", &*node.name),
                tract_core::ops::math::mul::unary(bc(step)?),
                &[wire],
            )?[0];
            let wire = target.wire_node(
                format!("{}-add-min", &*node.name),
                tract_core::ops::math::add::unary(bc(nudged_min)?),
                &[wire],
            )?[0];
            return Ok(tvec!(wire));
        }
        bail!("Operator can not be made a TypedOp.")
    }

    inference_op_as_op!();
}