tract_tensorflow/ops/rec/
block_lstm.rs1use tract_hir::internal::*;
2use tract_hir::tract_core::ops::einsum::EinSum;
3use tract_hir::tract_core::ops::scan::ScanInfo;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub fn block_lstm(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
9 let forget_bias = node.get_attr_opt_float("forget_bias")?.unwrap_or(1.0);
10 let cell_clip = node.get_attr_opt_float("cell_clip")?.unwrap_or(3.0);
11 let t = node.get_attr_datum_type("T")?;
12 let use_peephole = node.get_attr_opt_bool("use_peephole")?.unwrap_or(false);
13 if use_peephole {
14 unimplemented!("Block LSTM peeplholes");
15 }
16 Ok(expand(BlockLSTM::new(forget_bias, cell_clip, t, use_peephole)))
17}
18
/// Inference-level expansion of TensorFlow's `BlockLSTM` op, holding the node attributes.
#[derive(Clone, Debug, new)]
// Not every field is read when the expansion is wired (`use_peephole`, for instance,
// is only checked while parsing), but all attributes are kept on the op.
#[allow(dead_code)]
pub struct BlockLSTM {
    /// `forget_bias` attribute: added to the forget-gate pre-activation.
    forget_bias: f32,
    /// `cell_clip` attribute.
    cell_clip: f32,
    /// `T` attribute: datum type declared by the node.
    t: DatumType,
    /// `use_peephole` attribute.
    use_peephole: bool,
}
27
28impl Expansion for BlockLSTM {
29 fn name(&self) -> StaticName {
30 "BlockLSTM".into()
31 }
32
33 fn rules<'r, 'p: 'r, 's: 'r>(
34 &'s self,
35 s: &mut Solver<'r>,
36 inputs: &'p [TensorProxy],
37 outputs: &'p [TensorProxy],
38 ) -> InferenceResult {
39 check_input_arity(inputs, 9)?;
40 check_input_arity(outputs, 7)?;
41
42 s.equals(&inputs[0].rank, 0)?; s.equals(&inputs[0].datum_type, i64::datum_type())?;
44
45 s.equals_all((1..=7).map(move |i| (&inputs[i].datum_type).bex()).collect())?;
47
48 s.equals(&inputs[1].rank, 3)?; s.equals(&inputs[2].rank, 2)?; s.equals(&inputs[3].rank, 2)?; s.equals(&inputs[4].rank, 2)?; s.equals(&inputs[5].rank, 1)?; s.equals(&inputs[6].rank, 1)?; s.equals(&inputs[7].rank, 1)?; s.equals(&inputs[8].rank, 1)?; s.equals(&inputs[8].shape[0], 4 * inputs[1].shape[2].bex())?; for (i, output) in outputs.iter().take(7).enumerate() {
60 s.equals(&inputs[1].datum_type, &output.datum_type)?;
61 s.equals(&outputs[i].shape, &inputs[1].shape)?;
62 }
63
64 Ok(())
65 }
66
67 fn nboutputs(&self) -> TractResult<usize> {
68 Ok(7)
69 }
70
71 fn wire(
72 &self,
73 prefix: &str,
74 model: &mut TypedModel,
75 inputs: &[OutletId],
76 ) -> TractResult<TVec<OutletId>> {
77 use tract_hir::tract_core::ops::{array, math, nn, scan};
78
79 let mut body = TypedModel::default();
80 let mut outer_inputs = vec![];
81 let mut input_mapping = vec![];
82 let mut output_mapping = vec![];
83
84 let w = model.outlet_fact(inputs[4])?.konst.clone().context("W must be cosntant")?;
85 let b = model.outlet_fact(inputs[8])?.konst.clone().context("B must be constant")?;
86 let cell_size = w.shape()[1] / 4;
87 let mut b = b.into_tensor();
88 b.insert_axis(0)?;
89
90 macro_rules! wire {
91 ($name: ident = $op: expr, $($param: expr),*) => {
92 let $name = body.wire_node(
93 format!("{}-{}", prefix, stringify!($name)),
94 $op, [$($param),*].as_ref())?[0];
95 }
96 }
97
98 outer_inputs.push(inputs[1]);
100 input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 0, chunk: 1 }));
101 let mut x_source_fact = model.outlet_fact(inputs[1])?.clone();
102 x_source_fact.shape.set(0, 1.to_dim());
103 let x_source = body.add_source("x_source", x_source_fact)?;
104 wire!(x = AxisOp::Rm(0), x_source);
105
106 let cs = model.wire_node(format!("{prefix}.cs-axis"), AxisOp::Add(0), &[inputs[2]])?[0];
108 outer_inputs.push(cs);
109 let cs_fact = model.outlet_fact(cs)?.clone();
110 let cs_source = body.add_source("cs_source", cs_fact)?;
111 input_mapping.push(scan::InputMapping::State);
112 wire!(cs_prev = AxisOp::Rm(0), cs_source);
113
114 let h = model.wire_node(format!("{prefix}.h-axis"), AxisOp::Add(0), &[inputs[3]])?[0];
116 outer_inputs.push(h);
117 let h_fact = model.outlet_fact(h)?.clone();
118 let h_source = body.add_source("h_source", h_fact)?;
119 input_mapping.push(scan::InputMapping::State);
120 wire!(h_prev = AxisOp::Rm(0), h_source);
121
122 wire!(xh = array::TypedConcat::new(1), x, h_prev);
123
124 let w = body.add_const(format!("{prefix}-w"), w)?;
125 let b = body.add_const(format!("{prefix}-b"), b)?;
126 wire!(i_ci_f_o_1 = EinSum::new("mk,kn->mn".parse()?, f32::datum_type()), xh, w);
127 wire!(i_ci_f_o = math::add(), b, i_ci_f_o_1);
128
129 wire!(i_1 = array::Slice::new(1, 0, cell_size), i_ci_f_o);
130 wire!(i = nn::sigmoid(), i_1);
131
132 wire!(f_1 = array::Slice::new(1, 2 * cell_size, 3 * cell_size), i_ci_f_o);
133 let bias = body.add_const(format!("{prefix}-bias"), rctensor2(&[[self.forget_bias]]))?;
134 wire!(f_2 = math::add(), f_1, bias);
135 wire!(f = nn::sigmoid(), f_2);
136
137 wire!(ci_1 = array::Slice::new(1, cell_size, 2 * cell_size), i_ci_f_o);
138 wire!(ci = math::tanh(), ci_1);
139
140 wire!(o_1 = array::Slice::new(1, 3 * cell_size, 4 * cell_size), i_ci_f_o);
141 wire!(o = nn::sigmoid(), o_1);
142
143 wire!(ci_i = math::mul(), ci, i);
144 wire!(cs_1 = math::mul(), cs_prev, f);
145 wire!(cs = math::add(), cs_1, ci_i);
146
147 wire!(co = math::tanh(), cs);
148 wire!(h = math::mul(), co, o);
149
150 wire!(i_ = AxisOp::Add(0), i);
151 wire!(cs_ = AxisOp::Add(0), cs);
152 wire!(f_ = AxisOp::Add(0), f);
153 wire!(o_ = AxisOp::Add(0), o);
154 wire!(ci_ = AxisOp::Add(0), ci);
155 wire!(co_ = AxisOp::Add(0), co);
156 wire!(h_ = AxisOp::Add(0), h);
157 body.set_output_outlets(&[i_, cs_, f_, o_, ci_, co_, h_])?;
158 for ix in 0..7 {
159 output_mapping.push(scan::OutputMapping::<TDim> {
160 state: ix == 1 || ix == 6,
161 full_dim_hint: None,
162 last_value_slot: None,
163 scan: Some((ix, ScanInfo { axis: 0, chunk: 1 })),
164 })
165 }
166
167 let Some(seqlen) = &model.outlet_fact(inputs[0])?.konst else {
168 bail!("Non constant seq_len is not supported");
169 };
170 let Some(seqlen) = seqlen.as_uniform() else {
171 bail!("Non uniform seq_len is not supported");
172 };
173 let seqlen = seqlen.cast_to::<TDim>()?;
174 if seqlen.to_scalar::<TDim>()? != &model.outlet_fact(inputs[1])?.shape[0] {
175 bail!("seq_len only supported for trivial noop case");
176 };
177 let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?;
178 model.wire_node(prefix, scan, &outer_inputs)
179 }
180}
181
182