// tract_tensorflow/ops/rec/block_lstm.rs

1use tract_hir::internal::*;
2use tract_hir::tract_core::ops::einsum::EinSum;
3use tract_hir::tract_core::ops::scan::ScanInfo;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub fn block_lstm(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
9    let forget_bias = node.get_attr_opt_float("forget_bias")?.unwrap_or(1.0);
10    let cell_clip = node.get_attr_opt_float("cell_clip")?.unwrap_or(3.0);
11    let t = node.get_attr_datum_type("T")?;
12    let use_peephole = node.get_attr_opt_bool("use_peephole")?.unwrap_or(false);
13    if use_peephole {
14        unimplemented!("Block LSTM peeplholes");
15    }
16    Ok(expand(BlockLSTM::new(forget_bias, cell_clip, t, use_peephole)))
17}
18
/// TensorFlow `BlockLSTM`: one op running a full LSTM over a time sequence.
/// Expanded into a `Scan` over the time axis by the `Expansion` impl below.
#[derive(Clone, Debug, new)]
#[allow(dead_code)]
pub struct BlockLSTM {
    // Bias added to the forget gate pre-activation (TF default: 1.0).
    forget_bias: f32,
    // Cell-state clipping threshold. NOTE(review): parsed but never applied
    // in wire() — confirm whether clipping is intentionally unsupported.
    cell_clip: f32,
    // Element type of the float-like inputs/outputs (TF attribute "T").
    t: DatumType,
    // Peephole connections flag; `true` is rejected at parse time.
    use_peephole: bool,
}
27
28impl Expansion for BlockLSTM {
    /// Display/registration name of the expansion.
    fn name(&self) -> StaticName {
        "BlockLSTM".into()
    }
32
33    fn rules<'r, 'p: 'r, 's: 'r>(
34        &'s self,
35        s: &mut Solver<'r>,
36        inputs: &'p [TensorProxy],
37        outputs: &'p [TensorProxy],
38    ) -> InferenceResult {
39        check_input_arity(inputs, 9)?;
40        check_input_arity(outputs, 7)?;
41
42        s.equals(&inputs[0].rank, 0)?; // seq_len_max
43        s.equals(&inputs[0].datum_type, i64::datum_type())?;
44
45        // other inputs and outputs are consistent float-like
46        s.equals_all((1..=7).map(move |i| (&inputs[i].datum_type).bex()).collect())?;
47
48        s.equals(&inputs[1].rank, 3)?; // x:  [ time, batch, cell_size ]
49        s.equals(&inputs[2].rank, 2)?; // cs_prev: [batch, cell_size]
50        s.equals(&inputs[3].rank, 2)?; // h_prev: [batch, cell_size]
51        s.equals(&inputs[4].rank, 2)?; // w: []
52        s.equals(&inputs[5].rank, 1)?; // peephole input
53        s.equals(&inputs[6].rank, 1)?; // peephole forget
54        s.equals(&inputs[7].rank, 1)?; // peephole output
55        s.equals(&inputs[8].rank, 1)?; // bias: [ 4*cell_size ]
56        s.equals(&inputs[8].shape[0], 4 * inputs[1].shape[2].bex())?; // bias: [ 4*cell_size ]
57
58        // i, cs, f, o, ci, co, h
59        for (i, output) in outputs.iter().take(7).enumerate() {
60            s.equals(&inputs[1].datum_type, &output.datum_type)?;
61            s.equals(&outputs[i].shape, &inputs[1].shape)?;
62        }
63
64        Ok(())
65    }
66
    /// BlockLSTM exposes 7 outputs: i, cs, f, o, ci, co, h.
    fn nboutputs(&self) -> TractResult<usize> {
        Ok(7)
    }
70
71    fn wire(
72        &self,
73        prefix: &str,
74        model: &mut TypedModel,
75        inputs: &[OutletId],
76    ) -> TractResult<TVec<OutletId>> {
77        use tract_hir::tract_core::ops::{array, math, nn, scan};
78
79        let mut body = TypedModel::default();
80        let mut outer_inputs = vec![];
81        let mut input_mapping = vec![];
82        let mut output_mapping = vec![];
83
84        let w = model.outlet_fact(inputs[4])?.konst.clone().context("W must be cosntant")?;
85        let b = model.outlet_fact(inputs[8])?.konst.clone().context("B must be constant")?;
86        let cell_size = w.shape()[1] / 4;
87        let mut b = b.into_tensor();
88        b.insert_axis(0)?;
89
90        macro_rules! wire {
91            ($name: ident = $op: expr, $($param: expr),*) => {
92                let $name = body.wire_node(
93                    format!("{}-{}", prefix, stringify!($name)),
94                    $op, [$($param),*].as_ref())?[0];
95            }
96        }
97
98        // X: body input 0: X, new outside input 0 (was 1)
99        outer_inputs.push(inputs[1]);
100        input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 0, chunk: 1 }));
101        let mut x_source_fact = model.outlet_fact(inputs[1])?.clone();
102        x_source_fact.shape.set(0, 1.to_dim());
103        let x_source = body.add_source("x_source", x_source_fact)?;
104        wire!(x = AxisOp::Rm(0), x_source);
105
106        // CS: body input 1
107        let cs = model.wire_node(format!("{prefix}.cs-axis"), AxisOp::Add(0), &[inputs[2]])?[0];
108        outer_inputs.push(cs);
109        let cs_fact = model.outlet_fact(cs)?.clone();
110        let cs_source = body.add_source("cs_source", cs_fact)?;
111        input_mapping.push(scan::InputMapping::State);
112        wire!(cs_prev = AxisOp::Rm(0), cs_source);
113
114        // H: body input 2
115        let h = model.wire_node(format!("{prefix}.h-axis"), AxisOp::Add(0), &[inputs[3]])?[0];
116        outer_inputs.push(h);
117        let h_fact = model.outlet_fact(h)?.clone();
118        let h_source = body.add_source("h_source", h_fact)?;
119        input_mapping.push(scan::InputMapping::State);
120        wire!(h_prev = AxisOp::Rm(0), h_source);
121
122        wire!(xh = array::TypedConcat::new(1), x, h_prev);
123
124        let w = body.add_const(format!("{prefix}-w"), w)?;
125        let b = body.add_const(format!("{prefix}-b"), b)?;
126        wire!(i_ci_f_o_1 = EinSum::new("mk,kn->mn".parse()?, f32::datum_type()), xh, w);
127        wire!(i_ci_f_o = math::add(), b, i_ci_f_o_1);
128
129        wire!(i_1 = array::Slice::new(1, 0, cell_size), i_ci_f_o);
130        wire!(i = nn::sigmoid(), i_1);
131
132        wire!(f_1 = array::Slice::new(1, 2 * cell_size, 3 * cell_size), i_ci_f_o);
133        let bias = body.add_const(format!("{prefix}-bias"), rctensor2(&[[self.forget_bias]]))?;
134        wire!(f_2 = math::add(), f_1, bias);
135        wire!(f = nn::sigmoid(), f_2);
136
137        wire!(ci_1 = array::Slice::new(1, cell_size, 2 * cell_size), i_ci_f_o);
138        wire!(ci = math::tanh(), ci_1);
139
140        wire!(o_1 = array::Slice::new(1, 3 * cell_size, 4 * cell_size), i_ci_f_o);
141        wire!(o = nn::sigmoid(), o_1);
142
143        wire!(ci_i = math::mul(), ci, i);
144        wire!(cs_1 = math::mul(), cs_prev, f);
145        wire!(cs = math::add(), cs_1, ci_i);
146
147        wire!(co = math::tanh(), cs);
148        wire!(h = math::mul(), co, o);
149
150        wire!(i_ = AxisOp::Add(0), i);
151        wire!(cs_ = AxisOp::Add(0), cs);
152        wire!(f_ = AxisOp::Add(0), f);
153        wire!(o_ = AxisOp::Add(0), o);
154        wire!(ci_ = AxisOp::Add(0), ci);
155        wire!(co_ = AxisOp::Add(0), co);
156        wire!(h_ = AxisOp::Add(0), h);
157        body.set_output_outlets(&[i_, cs_, f_, o_, ci_, co_, h_])?;
158        for ix in 0..7 {
159            output_mapping.push(scan::OutputMapping::<TDim> {
160                state: ix == 1 || ix == 6,
161                full_dim_hint: None,
162                last_value_slot: None,
163                scan: Some((ix, ScanInfo { axis: 0, chunk: 1 })),
164            })
165        }
166
167        let Some(seqlen) = &model.outlet_fact(inputs[0])?.konst else {
168                bail!("Non constant seq_len is not supported");
169            };
170        let Some(seqlen) = seqlen.as_uniform() else {
171                bail!("Non uniform seq_len is not supported");
172            };
173        let seqlen = seqlen.cast_to::<TDim>()?;
174        if seqlen.to_scalar::<TDim>()? != &model.outlet_fact(inputs[1])?.shape[0] {
175            bail!("seq_len only supported for trivial noop case");
176        };
177        let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?;
178        model.wire_node(prefix, scan, &outer_inputs)
179    }
180}
181
182/*
183// TODO: rewrite this logic as a tf.Assign declutter ?
184impl BlockLSTM {
185fn inline_var_assign(
186&self,
187model: &TypedModel,
188node: &TypedNode,
189input_id: usize,
190output_id: usize,
191patch: &mut TypedModelPatch,
192) -> TractResult<Option<Arc<Tensor>>> {
193let var_2 = model.node(node.inputs[input_id].node);
194let var_2_op = if let Some(op) = var_2.op_as::<crate::ops::vars::VariableV2>() {
195op
196} else {
197return Ok(None);
198};
199if var_2.outputs[0].successors.len() != 2 {
200return Ok(None);
201}
202let assign = if let Some(assign_node) = var_2.outputs[0]
203.successors
204.iter()
205.map(|s| model.node(s.node))
206.filter(|s| s.op_is::<crate::ops::vars::Assign>())
207.next()
208{
209assign_node
210} else {
211return Ok(None);
212};
213let rm_axis_node = model.node(assign.inputs[1].node);
214let rm_axis_op = if let Some(op) = rm_axis_node.op_as::<tract_hir::internal::AxisOp>() {
215op
216} else {
217return Ok(None);
218};
219if rm_axis_op != &tract_hir::internal::AxisOp::Rm(0) {
220return Ok(None);
221}
222let slice_node = model.node(rm_axis_node.inputs[0].node);
223let slice_op = if let Some(op) = slice_node.op_as::<tract_hir::ops::array::Slice<usize>>() {
224op
225} else {
226return Ok(None);
227};
228if slice_node.inputs[0] != (node.id, output_id).into() {
229return Ok(None);
230}
231let lstm_output_fact = model.outlet_fact(slice_node.inputs[0])?;
232if slice_op.axis != 0
233|| slice_op.end != slice_op.start + 1
234|| slice_op.end.to_dim() != lstm_output_fact.shape.dim(0)
235{
236return Ok(None);
237}
238let tap = patch.tap_model(model, rm_axis_node.id.into())?;
239patch.shunt_outside(model, assign.id.into(), tap)?;
240Ok(var_2_op.initializer.clone())
241}
242}
243*/