use ndarray::*;
use crate::tfpb::node_def::NodeDef;
use tract_core::internal::*;
use tract_core::ops::nn::sigmoid::sigmoid_f32;
use tract_core::ops::nn::tanh::tanh_f32;
pub fn block_lstm(node: &NodeDef) -> TractResult<Box<Op>> {
let forget_bias = node.get_attr_opt_float("forget_bias")?.unwrap_or(1.0);
let cell_clip = node.get_attr_opt_float("cell_clip")?.unwrap_or(3.0);
let t = node.get_attr_datum_type("T")?;
let use_peephole = node.get_attr_opt_bool("use_peephole")?.unwrap_or(false);
Ok(Box::new(BlockLSTM::new(forget_bias, cell_clip, t, use_peephole)))
}
/// Operator state for the TensorFlow `BlockLSTM` node.
///
/// The fields mirror the node attributes parsed by `block_lstm`; the derived
/// `new` constructor takes them in declaration order.
#[derive(Clone, Debug, new)]
pub struct BlockLSTM {
    /// Value of the `forget_bias` attribute (1.0 when the node omits it).
    forget_bias: f32,
    /// Value of the `cell_clip` attribute (3.0 when the node omits it).
    cell_clip: f32,
    /// Value of the `T` attribute.
    t: DatumType,
    /// Value of the `use_peephole` attribute (false when the node omits it).
    use_peephole: bool,
}
impl Op for BlockLSTM {
    /// Returns the operator name, `tf.BlockLSTM`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("tf.BlockLSTM")
    }
}
impl StatelessOp for BlockLSTM {
fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
let len = *inputs[0].cast_to::<i32>()?.to_scalar::<i32>()? as usize;
let x = inputs[1].to_array_view::<f32>()?.into_dimensionality::<Ix3>()?;
let cell_size = x.shape()[2];
let cs_prev = inputs[2].to_array_view::<f32>()?;
let h_prev = inputs[3].to_array_view::<f32>()?.into_dimensionality::<Ix2>()?;
let w = inputs[4].to_array_view::<f32>()?.into_dimensionality::<Ix2>()?;
let bias = inputs[8].to_array_view::<f32>()?;
let outputs_shape = x.shape();
let mut i = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut cs = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut f = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut o = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut ci = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut co = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut h = unsafe { ArrayD::<f32>::uninitialized(&*outputs_shape) };
let mut h_prev = h_prev.to_owned();
let mut cs_prev = cs_prev.to_owned();
for n in 0..len {
let x = x.index_axis(Axis(0), n);
let mut i = i.index_axis_mut(Axis(0), n);
let mut cs = cs.index_axis_mut(Axis(0), n);
let mut f = f.index_axis_mut(Axis(0), n);
let mut o = o.index_axis_mut(Axis(0), n);
let mut ci = ci.index_axis_mut(Axis(0), n);
let mut co = co.index_axis_mut(Axis(0), n);
let mut h = h.index_axis_mut(Axis(0), n);
let xh = ndarray::stack(Axis(1), &[x, h_prev.view()])?;
let i_ci_f_o = xh.dot(&w) + &bias;
i.assign(&i_ci_f_o.slice_axis(Axis(1), (0..cell_size).into()));
i.mapv_inplace(sigmoid_f32);
f.assign(&i_ci_f_o.slice_axis(Axis(1), (2 * cell_size..3 * cell_size).into()));
f.mapv_inplace(|x| sigmoid_f32(x + self.forget_bias));
ci.assign(&i_ci_f_o.slice_axis(Axis(1), (cell_size..2 * cell_size).into()));
ci.mapv_inplace(tanh_f32);
cs_prev *= &f;
cs_prev += &(ci.to_owned() * &i);
cs.assign(&cs_prev);
o.assign(&i_ci_f_o.slice_axis(Axis(1), (3 * cell_size..4 * cell_size).into()));
o.mapv_inplace(sigmoid_f32);
co.assign(&cs);
co.mapv_inplace(tanh_f32);
h_prev.assign(&co);
h_prev *= &o;
h.assign(&h_prev);
}
if x.shape()[0] > len as usize {
i.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
cs.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
f.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
o.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
ci.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
co.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
h.slice_axis_mut(Axis(0), (len..).into()).fill(0.0);
}
Ok(tvec!(
i.into_arc_tensor(),
cs.into_arc_tensor(),
f.into_arc_tensor(),
o.into_arc_tensor(),
ci.into_arc_tensor(),
co.into_arc_tensor(),
h.into_arc_tensor()
))
}
}
impl InferenceRulesOp for BlockLSTM {
fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> InferenceResult {
check_input_arity(&inputs, 9)?;
check_input_arity(&outputs, 7)?;
s.equals(&inputs[0].rank, 0)?; s.equals(&inputs[0].datum_type, i64::datum_type())?;
s.equals_all((1..=7).map(move |i| (&inputs[i].datum_type).bex()).collect())?;
s.equals(&inputs[1].rank, 3)?; s.equals(&inputs[2].rank, 2)?; s.equals(&inputs[3].rank, 2)?; s.equals(&inputs[4].rank, 2)?; s.equals(&inputs[5].rank, 1)?; s.equals(&inputs[6].rank, 1)?; s.equals(&inputs[7].rank, 1)?; s.equals(&inputs[8].rank, 1)?; s.equals(&inputs[8].shape[0], 4 * inputs[1].shape[2].bex())?;
for i in 0..7 {
s.equals(&inputs[1].datum_type, &outputs[i].datum_type)?;
s.equals(&outputs[i].shape, &inputs[1].shape)?;
}
Ok(())
}
}