use tract_hir::internal::*;
use tract_hir::tract_core::ops::einsum::EinSum;
use tract_hir::tract_core::ops::scan::ScanInfo;
use crate::model::ParsingContext;
use crate::tfpb::tensorflow::NodeDef;
/// Builds the tract expansion for a TensorFlow `BlockLSTM` node.
///
/// Attributes read from the node:
/// * `forget_bias` (defaults to 1.0): constant added to the forget gate
///   pre-activation before the sigmoid,
/// * `cell_clip` (defaults to 3.0): cell state clipping threshold
///   (NOTE(review): stored but not currently applied by the expansion),
/// * `T`: element datum type,
/// * `use_peephole` (defaults to false): peephole connections are not
///   implemented and trigger a panic.
pub fn block_lstm(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    let forget_bias = node.get_attr_opt_float("forget_bias")?.unwrap_or(1.0);
    let cell_clip = node.get_attr_opt_float("cell_clip")?.unwrap_or(3.0);
    let t = node.get_attr_datum_type("T")?;
    let use_peephole = node.get_attr_opt_bool("use_peephole")?.unwrap_or(false);
    if use_peephole {
        // Fixed typo in the panic message ("peeplholes" -> "peepholes").
        unimplemented!("Block LSTM peepholes");
    }
    Ok(expand(BlockLSTM::new(forget_bias, cell_clip, t, use_peephole)))
}
// Expansion op mirroring TensorFlow's BlockLSTM: a full-sequence LSTM whose
// per-step cell is wired as a tract Scan body in `Expansion::wire`.
// `#[derive(new)]` generates `BlockLSTM::new` taking the fields in order.
#[derive(Clone, Debug, new)]
#[allow(dead_code)]
pub struct BlockLSTM {
// Added to the forget-gate pre-activation before the sigmoid (TF `forget_bias`).
forget_bias: f32,
// Cell-state clipping threshold (TF `cell_clip`).
// NOTE(review): never read by `wire` — clipping is not applied; confirm intended.
cell_clip: f32,
// Element datum type from the node's `T` attribute.
t: DatumType,
// TF `use_peephole` attribute; the parser panics when true, so always false here.
use_peephole: bool,
}
impl Expansion for BlockLSTM {
    fn name(&self) -> StaticName {
        "BlockLSTM".into()
    }

    /// Type/shape inference rules for the 9 inputs and 7 outputs of TF's
    /// BlockLSTM: (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b) ->
    /// (i, cs, f, o, ci, co, h).
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 9)?;
        // Fix: the outputs slice must be checked with check_output_arity;
        // using check_input_arity here reported a misleading "input" error.
        check_output_arity(outputs, 7)?;
        // Input 0 is seq_len_max: an i64 scalar.
        s.equals(&inputs[0].rank, 0)?;
        s.equals(&inputs[0].datum_type, i64::datum_type())?;
        // Inputs 1..=7 (x, cs_prev, h_prev, w, wci, wcf, wco) share one datum type.
        s.equals_all((1..=7).map(move |i| (&inputs[i].datum_type).bex()).collect())?;
        s.equals(&inputs[1].rank, 3)?; // x: [time, batch, input_size]
        s.equals(&inputs[2].rank, 2)?; // cs_prev: [batch, cell_size]
        s.equals(&inputs[3].rank, 2)?; // h_prev: [batch, cell_size]
        s.equals(&inputs[4].rank, 2)?; // w: [input_size + cell_size, 4 * cell_size]
        s.equals(&inputs[5].rank, 1)?; // wci (peephole weights, unused)
        s.equals(&inputs[6].rank, 1)?; // wcf (peephole weights, unused)
        s.equals(&inputs[7].rank, 1)?; // wco (peephole weights, unused)
        s.equals(&inputs[8].rank, 1)?; // b: [4 * cell_size]
        // NOTE(review): this ties b's length to 4 * x.shape[2], which only holds
        // when input_size == cell_size — same implicit assumption as the output
        // shape rule below; confirm against the models this path supports.
        s.equals(&inputs[8].shape[0], 4 * inputs[1].shape[2].bex())?;
        // All 7 outputs are full-sequence tensors shaped like x.
        for (i, output) in outputs.iter().take(7).enumerate() {
            s.equals(&inputs[1].datum_type, &output.datum_type)?;
            s.equals(&outputs[i].shape, &inputs[1].shape)?;
        }
        Ok(())
    }

    fn nboutputs(&self) -> TractResult<usize> {
        Ok(7)
    }

    /// Wires the BlockLSTM as a tract `Scan` op: a one-time-step LSTM cell
    /// body scanned over axis 0 of x, with cs and h threaded through as state.
    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        use tract_hir::tract_core::ops::{array, math, nn, scan};
        let mut body = TypedModel::default();
        let mut outer_inputs = vec![];
        let mut input_mapping = vec![];
        let mut output_mapping = vec![];
        // W and B must be constants so they can be embedded in the scan body.
        // (Fixed typo in the error message: "cosntant" -> "constant".)
        let w = model.outlet_fact(inputs[4])?.konst.clone().context("W must be constant")?;
        let b = model.outlet_fact(inputs[8])?.konst.clone().context("B must be constant")?;
        // W's columns hold the four gates (i, ci, f, o) concatenated.
        let cell_size = w.shape()[1] / 4;
        let mut b = b.into_tensor();
        // Make B rank-2 so it broadcasts against the [batch, 4*cell_size] pre-activations.
        b.insert_axis(0)?;
        // Helper: wire one node into the scan body and bind its first output.
        macro_rules! wire {
            ($name: ident = $op: expr, $($param: expr),*) => {
                let $name = body.wire_node(
                    format!("{}-{}", prefix, stringify!($name)),
                    $op, [$($param),*].as_ref())?[0];
            }
        }
        // x is the scanned input: one time step (chunk of 1) per iteration.
        outer_inputs.push(inputs[1]);
        input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 0, chunk: 1 }));
        let mut x_source_fact = model.outlet_fact(inputs[1])?.clone();
        x_source_fact.shape.set(0, 1.to_dim());
        let x_source = body.add_source("x_source", x_source_fact)?;
        wire!(x = AxisOp::Rm(0), x_source);
        // cs (cell state) enters as scan state, with a leading axis added outside.
        let cs = model.wire_node(format!("{prefix}.cs-axis"), AxisOp::Add(0), &[inputs[2]])?[0];
        outer_inputs.push(cs);
        let cs_fact = model.outlet_fact(cs)?.clone();
        let cs_source = body.add_source("cs_source", cs_fact)?;
        input_mapping.push(scan::InputMapping::State);
        wire!(cs_prev = AxisOp::Rm(0), cs_source);
        // h (hidden state) enters as scan state the same way.
        let h = model.wire_node(format!("{prefix}.h-axis"), AxisOp::Add(0), &[inputs[3]])?[0];
        outer_inputs.push(h);
        let h_fact = model.outlet_fact(h)?.clone();
        let h_source = body.add_source("h_source", h_fact)?;
        input_mapping.push(scan::InputMapping::State);
        wire!(h_prev = AxisOp::Rm(0), h_source);
        // One matmul of [x | h_prev] against W plus bias computes all four gates at once.
        wire!(xh = array::TypedConcat::new(1), x, h_prev);
        let w = body.add_const(format!("{prefix}-w"), w)?;
        let b = body.add_const(format!("{prefix}-b"), b)?;
        wire!(i_ci_f_o_1 = EinSum::new("mk,kn->mn".parse()?, f32::datum_type()), xh, w);
        wire!(i_ci_f_o = math::add(), b, i_ci_f_o_1);
        // Slice the gates out in TF's icfo column order.
        wire!(i_1 = array::Slice::new(1, 0, cell_size), i_ci_f_o);
        wire!(i = nn::sigmoid(), i_1);
        wire!(f_1 = array::Slice::new(1, 2 * cell_size, 3 * cell_size), i_ci_f_o);
        // forget_bias is added to the forget-gate pre-activation before the sigmoid.
        let bias = body.add_const(format!("{prefix}-bias"), rctensor2(&[[self.forget_bias]]))?;
        wire!(f_2 = math::add(), f_1, bias);
        wire!(f = nn::sigmoid(), f_2);
        wire!(ci_1 = array::Slice::new(1, cell_size, 2 * cell_size), i_ci_f_o);
        wire!(ci = math::tanh(), ci_1);
        wire!(o_1 = array::Slice::new(1, 3 * cell_size, 4 * cell_size), i_ci_f_o);
        wire!(o = nn::sigmoid(), o_1);
        // cs = cs_prev * f + ci * i ; co = tanh(cs) ; h = co * o
        // NOTE(review): self.cell_clip is never applied to cs here — confirm
        // whether clipping (TF applies it when cell_clip > 0) is needed.
        wire!(ci_i = math::mul(), ci, i);
        wire!(cs_1 = math::mul(), cs_prev, f);
        wire!(cs = math::add(), cs_1, ci_i);
        wire!(co = math::tanh(), cs);
        wire!(h = math::mul(), co, o);
        // Re-add the time axis on every per-step value so the scan can stack them.
        wire!(i_ = AxisOp::Add(0), i);
        wire!(cs_ = AxisOp::Add(0), cs);
        wire!(f_ = AxisOp::Add(0), f);
        wire!(o_ = AxisOp::Add(0), o);
        wire!(ci_ = AxisOp::Add(0), ci);
        wire!(co_ = AxisOp::Add(0), co);
        wire!(h_ = AxisOp::Add(0), h);
        body.set_output_outlets(&[i_, cs_, f_, o_, ci_, co_, h_])?;
        for ix in 0..7 {
            output_mapping.push(scan::OutputMapping::<TDim> {
                // Outputs 1 (cs) and 6 (h) are also carried back as scan state,
                // matching the two InputMapping::State entries above.
                state: ix == 1 || ix == 6,
                full_dim_hint: None,
                last_value_slot: None,
                scan: Some((ix, ScanInfo { axis: 0, chunk: 1 })),
            })
        }
        // seq_len is only supported as a constant equal to the full time
        // dimension, i.e. the scan always runs over the whole sequence.
        let Some(seqlen) = &model.outlet_fact(inputs[0])?.konst else {
            bail!("Non constant seq_len is not supported");
        };
        let Some(seqlen) = seqlen.as_uniform() else {
            bail!("Non uniform seq_len is not supported");
        };
        let seqlen = seqlen.cast_to::<TDim>()?;
        if seqlen.to_scalar::<TDim>()? != &model.outlet_fact(inputs[1])?.shape[0] {
            bail!("seq_len only supported for trivial noop case");
        };
        let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?;
        model.wire_node(prefix, scan, &outer_inputs)
    }
}