use crate::internal::*;
use tract_linalg::element_wise::ElementWise;
/// Element-wise tail of an LSTM step.
///
/// Takes the gate pre-activations and the previous cell state and produces the new hidden
/// state and the new cell state. Each pre-activation row holds `4 * hidden` values laid out
/// as `[input gate | output gate | forget gate | candidate]`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct LstmEpilogue {
/// Number of hidden units, i.e. the width of one gate inside a pre-activation row.
pub hidden: usize,
}
impl Op for LstmEpilogue {
    op_as_typed_op!();

    /// Name under which this operator is reported.
    fn name(&self) -> StaticName {
        "LstmEpilogue".into()
    }

    /// Human-readable details: a single line giving the hidden size.
    fn info(&self) -> TractResult<Vec<String>> {
        let details = vec![format!("hidden={}", self.hidden)];
        Ok(details)
    }
}
impl EvalOp for LstmEpilogue {
fn is_stateless(&self) -> bool {
true
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let ops = tract_linalg::ops();
match inputs[0].datum_type().unquantized() {
DatumType::F32 => self.eval_t::<f32>(inputs, (ops.sigmoid_f32)(), (ops.tanh_f32)()),
DatumType::F16 => self.eval_t::<f16>(inputs, (ops.sigmoid_f16)(), (ops.tanh_f16)()),
dt => bail!("LstmEpilogue only supports f32 and f16 preactivations, got {dt:?}"),
}
}
}
impl LstmEpilogue {
fn eval_t<T>(
&self,
inputs: TVec<TValue>,
sigmoid: Box<dyn ElementWise<T>>,
tanh: Box<dyn ElementWise<T>>,
) -> TractResult<TVec<TValue>>
where
T: Datum + Copy + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
{
let h = self.hidden;
let c_prev = &inputs[1]; let cp = unsafe { c_prev.as_slice_unchecked::<T>() };
let rows = inputs[0].len() / (4 * h); let mut pre_t = inputs[0].clone().into_tensor();
let pre = unsafe { pre_t.as_slice_mut_unchecked::<T>() };
let mut ht = unsafe { Tensor::uninitialized_dt(T::datum_type(), c_prev.shape())? };
let mut ct = unsafe { Tensor::uninitialized_dt(T::datum_type(), c_prev.shape())? };
{
let hs = unsafe { ht.as_slice_mut_unchecked::<T>() };
let cs = unsafe { ct.as_slice_mut_unchecked::<T>() };
for r in 0..rows {
let pb = r * 4 * h;
let cb = r * h;
let row = &mut pre[pb..pb + 4 * h];
sigmoid.run(&mut row[0..3 * h])?;
tanh.run(&mut row[3 * h..4 * h])?;
for j in 0..h {
cs[cb + j] = row[2 * h + j] * cp[cb + j] + row[j] * row[3 * h + j];
}
hs[cb..cb + h].copy_from_slice(&cs[cb..cb + h]);
tanh.run(&mut hs[cb..cb + h])?;
for j in 0..h {
hs[cb + j] = hs[cb + j] * row[h + j];
}
}
}
Ok(tvec!(ht.into_tvalue(), ct.into_tvalue()))
}
}
impl TypedOp for LstmEpilogue {
    as_op!();

    /// Declares the two outputs (`h_t`, `c_t`); both take the datum type and shape of
    /// `c_prev`.
    ///
    /// # Errors
    ///
    /// Fails when the op is not wired with exactly `[preact, c_prev]`, or when the two
    /// inputs disagree on datum type.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 2, "LstmEpilogue expects [preact, c_prev]");
        let (preact, c_prev) = (inputs[0], inputs[1]);
        // Evaluation picks its element type from `preact` while the facts below are built
        // from `c_prev`; reject a mismatch here rather than at run time.
        ensure!(
            preact.datum_type == c_prev.datum_type,
            "LstmEpilogue expects preact and c_prev of the same datum type, got {:?} and {:?}",
            preact.datum_type,
            c_prev.datum_type
        );
        let fact = c_prev.datum_type.fact(c_prev.shape.clone());
        Ok(tvec!(fact.clone(), fact))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the op on a small deterministic f32 batch and compares both outputs against a
    /// plain scalar evaluation of the same formulas.
    #[test]
    fn epilogue_matches_scalar_reference() {
        const TOLERANCE: f32 = 1e-3;

        fn logistic(x: f32) -> f32 {
            1.0 / (1.0 + (-x).exp())
        }

        let h = 6usize;
        let batch = 3usize;
        let preact: Vec<f32> =
            (0..batch * 4 * h).map(|i| ((i * 7 % 29) as f32 - 14.0) * 0.25).collect();
        let cprev: Vec<f32> = (0..batch * h).map(|i| ((i * 5 % 17) as f32 - 8.0) * 0.2).collect();

        let preact_value = Tensor::from_shape(&[batch, 4 * h], &preact).unwrap().into_tvalue();
        let cprev_value = Tensor::from_shape(&[batch, h], &cprev).unwrap().into_tvalue();
        let epilogue = LstmEpilogue { hidden: h };
        let out = epilogue.eval(tvec!(preact_value, cprev_value)).unwrap();
        let ht = unsafe { out[0].as_slice_unchecked::<f32>() };
        let ct = unsafe { out[1].as_slice_unchecked::<f32>() };

        // Walk the reference one row at a time: `gates` is the row's 4 * h pre-activations,
        // `c_old` the matching slice of the previous cell state.
        for (r, (gates, c_old)) in preact.chunks(4 * h).zip(cprev.chunks(h)).enumerate() {
            for j in 0..h {
                let it = logistic(gates[j]);
                let ot = logistic(gates[h + j]);
                let ft = logistic(gates[2 * h + j]);
                let cc = gates[3 * h + j].tanh();
                let c_ref = ft * c_old[j] + it * cc;
                let h_ref = ot * c_ref.tanh();
                let at = r * h + j;
                assert!((ct[at] - c_ref).abs() < TOLERANCE, "Ct mismatch at ({r},{j})");
                assert!((ht[at] - h_ref).abs() < TOLERANCE, "Ht mismatch at ({r},{j})");
            }
        }
    }
}