use snafu::ResultExt;
use svod_dtype::DType;
use svod_tensor::Tensor;
use svod_tensor::nn::LSTMCell;
use crate::init::fan_in_uniform;
use crate::state::{self, HasStateDict, StateDict, get_tensor, prefixed};
use crate::gigaam::Result;
use crate::gigaam::error::TensorSnafu;
/// RNN-T prediction network: a token embedding followed by a stack of LSTM cells.
#[derive(Clone)]
pub struct RnntPredictor {
    /// Token embedding table; `RnntPredictor::empty` creates it as `[num_classes, pred_hidden]`.
    pub embed: Tensor,
    /// LSTM cells applied in order; each layer's new hidden state is the next layer's input.
    pub layers: Vec<LSTMCell>,
    /// Width of the embedding and of every LSTM hidden/cell state.
    pub pred_hidden: usize,
    /// Number of token classes (rows of `embed`), including the blank token.
    pub num_classes: usize,
    /// Index of the blank token; `RnntPredictor::empty` sets it to `num_classes - 1`.
    pub blank_id: usize,
}
impl RnntPredictor {
    /// Builds a predictor with freshly initialised (untrained) parameters.
    ///
    /// Every parameter is drawn with `fan_in_uniform` as `Float32`:
    /// - `embed`: `[num_classes, pred_hidden]` (fan-in `num_classes`);
    /// - per layer: `weight_ih` and `weight_hh` of `[4 * pred_hidden, pred_hidden]`, and
    ///   `bias_ih` and `bias_hh` of `[4 * pred_hidden]` (fan-in `pred_hidden`).
    ///
    /// Every layer takes `pred_hidden` inputs, because the embedding width equals the hidden
    /// width. The last class index (`num_classes - 1`) is used as the blank token. These
    /// weights are presumably placeholders that `load_state_dict` overwrites; confirm at
    /// call sites.
    ///
    /// # Panics
    /// Panics if `num_classes == 0`, since there would be no index left for the blank token.
    pub fn empty(pred_hidden: usize, num_layers: usize, num_classes: usize) -> Self {
        // A plain `num_classes - 1` would overflow here: debug builds panic without context,
        // and release builds wrap to `usize::MAX`. Fail loudly with a clear message instead.
        let blank_id = num_classes
            .checked_sub(1)
            .expect("RnntPredictor::empty requires num_classes >= 1 (blank token)");
        let h4 = 4 * pred_hidden;
        Self {
            embed: fan_in_uniform(&[num_classes, pred_hidden], num_classes, DType::Float32),
            layers: (0..num_layers)
                .map(|_| {
                    LSTMCell::new(
                        fan_in_uniform(&[h4, pred_hidden], pred_hidden, DType::Float32),
                        fan_in_uniform(&[h4, pred_hidden], pred_hidden, DType::Float32),
                        fan_in_uniform(&[h4], pred_hidden, DType::Float32),
                        fan_in_uniform(&[h4], pred_hidden, DType::Float32),
                    )
                })
                .collect(),
            pred_hidden,
            num_classes,
            blank_id,
        }
    }

    /// Runs one prediction-network step for a single token and packs the result.
    ///
    /// Inputs:
    /// - `prev_token` is looked up in `embed`, and axis 1 of the embedding is squeezed.
    ///   Presumably it is a `(1, 1)` index tensor (TODO: confirm).
    /// - Layer `i` reads its previous state from `h_in[i, 0..1, 0..pred_hidden]` and
    ///   `c_in[i, 0..1, 0..pred_hidden]` (axis 0 squeezed away). So `h_in` and `c_in` are
    ///   rank 3, expected to be `(num_layers, 1, pred_hidden)`.
    ///
    /// Output: one tensor, concatenated on axis 2 in the order `[g | h | c]`.
    /// - `g` is the top layer's new hidden state with an axis inserted at position 1.
    /// - `h` and `c` are the new per-layer states, stacked layer-major and reshaped to
    ///   `(1, 1, num_layers * pred_hidden)`.
    ///
    /// The slicing and reshapes hard-code a batch size of 1.
    ///
    /// # Errors
    /// Returns the wrapped tensor error if any tensor operation fails (e.g. a shape mismatch).
    pub fn forward_concat(&self, prev_token: &Tensor, h_in: &Tensor, c_in: &Tensor) -> Result<Tensor> {
        let p = self.pred_hidden as isize;
        let l = self.layers.len() as isize;
        let emb = self.embed.embedding(prev_token).context(TensorSnafu)?;
        let mut layer_in = emb.try_squeeze(Some(1)).context(TensorSnafu)?;
        let mut new_hs: Vec<Tensor> = Vec::with_capacity(self.layers.len());
        let mut new_cs: Vec<Tensor> = Vec::with_capacity(self.layers.len());
        for (i, cell) in self.layers.iter().enumerate() {
            let layer = i as isize;
            let h_i = Self::layer_state(h_in, layer, p)?;
            let c_i = Self::layer_state(c_in, layer, p)?;
            let (new_h, new_c) = cell.step(&layer_in, &h_i, &c_i).context(TensorSnafu)?;
            // `new_h` is both recorded and fed to the next layer, so it is cloned once.
            // `new_c` is only recorded, so it can be moved.
            new_hs.push(new_h.clone());
            new_cs.push(new_c);
            layer_in = new_h;
        }
        let g = layer_in.try_unsqueeze(1).context(TensorSnafu)?;
        let new_h_stacked = Tensor::stack(&new_hs.iter().collect::<Vec<_>>(), 0).context(TensorSnafu)?;
        let new_c_stacked = Tensor::stack(&new_cs.iter().collect::<Vec<_>>(), 0).context(TensorSnafu)?;
        let new_h_flat = new_h_stacked.try_reshape([1, 1, l * p]).context(TensorSnafu)?;
        let new_c_flat = new_c_stacked.try_reshape([1, 1, l * p]).context(TensorSnafu)?;
        Tensor::cat(&[&g, &new_h_flat, &new_c_flat], 2).context(TensorSnafu)
    }

    /// Returns `state[layer, 0..1, 0..width]` with axis 0 squeezed away. This is one layer's
    /// slice of a rank-3 recurrent-state tensor.
    fn layer_state(state: &Tensor, layer: isize, width: isize) -> Result<Tensor> {
        state
            .try_shrink([(layer, layer + 1), (0, 1), (0, width)])
            .context(TensorSnafu)?
            .try_squeeze(Some(0))
            .context(TensorSnafu)
    }

    /// Zeroes the `blank_id` row of `embed` and realizes the result.
    ///
    /// A `[num_classes, 1]` mask is built: 1 everywhere and 0 at `blank_id`, cast to the
    /// embedding's dtype. `embed` is replaced by `embed * mask`. After this, looking up the
    /// blank token yields a zero vector. Presumably this mirrors a `padding_idx = blank`
    /// embedding in the reference model (TODO: confirm). Applying it repeatedly has the same
    /// effect as applying it once.
    ///
    /// # Panics
    /// Panics if `blank_id >= num_classes`, because the mask is indexed at `blank_id`.
    ///
    /// # Errors
    /// Returns the wrapped tensor error if building the mask, the multiply, or `realize` fails.
    pub(crate) fn prepare_for_inference(&mut self) -> Result<()> {
        let mut mask_data = vec![1.0_f32; self.num_classes];
        mask_data[self.blank_id] = 0.0;
        let embed_dtype = self.embed.uop().dtype();
        let mask = Tensor::from_slice(&mask_data)
            .try_reshape([self.num_classes, 1])
            .context(TensorSnafu)?
            .cast(embed_dtype)
            .context(TensorSnafu)?;
        self.embed = self.embed.try_mul(&mask).context(TensorSnafu)?;
        self.embed.realize().context(TensorSnafu)?;
        Ok(())
    }
}
impl HasStateDict for RnntPredictor {
fn state_dict(&self, prefix: &str) -> StateDict {
let mut sd = StateDict::new();
sd.insert(prefixed(prefix, "embed"), self.embed.clone());
for (i, cell) in self.layers.iter().enumerate() {
let p = prefixed(prefix, &format!("lstm.{i}"));
sd.insert(prefixed(&p, "w_ih"), cell.weight_ih.clone());
sd.insert(prefixed(&p, "w_hh"), cell.weight_hh.clone());
sd.insert(prefixed(&p, "b_ih"), cell.bias_ih.clone());
sd.insert(prefixed(&p, "b_hh"), cell.bias_hh.clone());
}
sd
}
fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
self.embed = get_tensor(sd, &prefixed(prefix, "embed"))?;
for (i, cell) in self.layers.iter_mut().enumerate() {
let p = prefixed(prefix, &format!("lstm.{i}"));
cell.weight_ih = get_tensor(sd, &prefixed(&p, "w_ih"))?;
cell.weight_hh = get_tensor(sd, &prefixed(&p, "w_hh"))?;
cell.bias_ih = get_tensor(sd, &prefixed(&p, "b_ih"))?;
cell.bias_hh = get_tensor(sd, &prefixed(&p, "b_hh"))?;
}
Ok(())
}
}