use crate::internal::*;
use crate::ops::cast::cast;
use crate::ops::math::add;
use crate::ops::matmul::quant::{
combine_scales, compensate_zero_points, requant, wire_ensure_q8_flavour,
};
use crate::ops::nn::{Reduce, Reducer};
use super::bilinear::BilinearEinSum;
use super::EinSum;
/// Builds a patch replacing a quantized bilinear einsum `node` with an explicit
/// sequence of nodes: a plain (non-quantized) `EinSum` over the two operands,
/// a bias addition, zero-point compensation and a final requantization.
///
/// The node is expected to have exactly nine inputs, in this order:
/// `a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale`.
///
/// # Returns
///
/// * `Ok(None)` when `bi.k_axes` does not hold exactly one axis (the rewrite
///   is only implemented for a single contraction axis).
/// * `Ok(Some(patch))` with the patch shunting the original node's output to
///   the requantized result otherwise.
///
/// # Errors
///
/// Fails if the node does not have exactly nine inputs, if the einsum carries
/// no quantization parameters (`q_params` is `None`), or if any of the wiring
/// steps fails.
pub fn dequant_bilinear(
    model: &TypedModel,
    node: &TypedNode,
    bi: &BilinearEinSum,
) -> TractResult<Option<TypedModelPatch>> {
    let name = &node.name;
    let mut patch = TypedModelPatch::new("Dequantizing einsum");
    let [k_axis] = &*bi.k_axes else { return Ok(None) };
    let axes = &bi.op.axes;
    let k_axis = axes.axis(*k_axis)?;
    let mut taps = patch.taps(model, &node.inputs)?;
    // Validate arity before indexing into `taps` below: the scale inputs live
    // at positions 4 and 6, and indexing a shorter list would panic.
    if taps.len() != 9 {
        bail!("Expect exactly 9 inputs")
    }
    // Report a missing quantization type as an error instead of panicking
    // at the requantization step.
    let Some(q_params) = bi.op.q_params else {
        bail!("Can not dequantize einsum {}: no quantization parameters", name)
    };
    for ab in [0, 1] {
        // 4 is a_scale, 6 is b_scale (see the destructuring below).
        let scale_input = 4 + ab * 2;
        if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
            // Non-scalar scale: locate where its first axis lands in the output
            // and add one axis per output axis that follows it.
            // NOTE(review): presumably this lets the scale broadcast against
            // the output — confirm against `combine_scales` / `requant`.
            let q_axis_in_output = axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
            let output_rank = node.outputs[0].fact.rank();
            for i in 1..(output_rank - q_axis_in_output) {
                taps[scale_input] = patch.wire_node(
                    format!("{name}.scale_input{ab}_axis_fix_{i}"),
                    AxisOp::Add(i),
                    &[taps[scale_input]],
                )?[0];
            }
        }
    }
    let [mut a, mut b, bias, mut a0, a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
        bail!("Expect exactly 9 inputs")
    };
    // Bring both operands (and their zero points) to the i8 flavour.
    wire_ensure_q8_flavour(&mut patch, name, &mut a, "a", &mut a0, i8::datum_type())?;
    wire_ensure_q8_flavour(&mut patch, name, &mut b, "b", &mut b0, i8::datum_type())?;
    // The raw product: same axes restricted to inputs a and b, no quantization.
    let mut output = patch.wire_node(
        name,
        EinSum {
            q_params: None,
            axes: axes.extract_sub_mapping(&[0, 1], &[0])?,
            operating_dt: bi.op.operating_dt,
        },
        &[a, b],
    )?;
    // Sums of each operand along the contraction axis, computed in i32; they
    // feed the zero-point compensation below.
    let a_i32 = patch.wire_node(format!("{name}.a_as_i32"), cast(i32::datum_type()), &[a])?[0];
    let b_i32 = patch.wire_node(format!("{name}.b_as_i32"), cast(i32::datum_type()), &[b])?[0];
    let sum_a = patch.wire_node(
        format!("{name}.sum_a"),
        Reduce::new(tvec!(k_axis.inputs[0][0]), Reducer::Sum),
        &[a_i32],
    )?;
    let sum_b = patch.wire_node(
        format!("{name}.sum_b"),
        Reduce::new(tvec!(k_axis.inputs[1][0]), Reducer::Sum),
        &[b_i32],
    )?;
    // Re-arrange the axes of the side tensors according to the mapping from
    // their respective input to the output.
    let sum_a =
        wire_axes_fix(&mut patch, name, "sum_a", &axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
    let sum_b =
        wire_axes_fix(&mut patch, name, "sum_b", &axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
    let bias = tvec!(bias);
    let bias =
        wire_axes_fix(&mut patch, name, "bias", &axes.extract_sub_mapping(&[2], &[0])?, bias)?;
    let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;
    output = patch.wire_node(format!("{name}.add_bias"), add(), &[output[0], bias[0]])?;
    // Size of the contraction axis, read from the first input's shape.
    let k = model.outlet_fact(node.inputs[0])?.shape[k_axis.inputs[0][0]].clone();
    let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
        .context("Zero point compensation")?;
    let output = requant(&mut patch, name, output, q_params, abc_scale, c0)?;
    patch.shunt_outside(model, node.id.into(), output)?;
    Ok(Some(patch))
}
/// Wires, one after the other, the axis operations obtained from
/// `mapping.translate_to_axis_ops()` onto `outlet`, and returns the outlet(s)
/// of the last wired node (or `outlet` unchanged if there is no operation).
///
/// Each node is named `{name}.fix_{var}.{ix}`, where `ix` is the position of
/// the operation in the translated sequence.
///
/// # Errors
///
/// Fails if the mapping can not be translated to axis operations, or if
/// wiring any of them into `patch` fails.
fn wire_axes_fix(
    patch: &mut TypedModelPatch,
    name: &str,
    var: &str,
    mapping: &AxesMapping,
    mut outlet: TVec<OutletId>,
) -> TractResult<TVec<OutletId>> {
    for (ix, axis_op) in mapping.translate_to_axis_ops()?.into_iter().enumerate() {
        // The original format string had a stray trailing ')' in the node name.
        outlet = patch.wire_node(format!("{name}.fix_{var}.{ix}"), axis_op, &outlet)?;
    }
    Ok(outlet)
}