use std::time::{Duration, Instant};
use snafu::ResultExt;
use svod_arch::rnnt::JointStep;
use crate::jit::{DeviceSnafu, InputSpec, JitError, JitRecurrent, LstmState, RecurrentJit};
use super::jit::{RnntJointStepJit, RnntPredictorStepJit};
use crate::gigaam::model::GigaAm;
impl RecurrentJit for RnntPredictorStepJit {
    /// Uploads an LSTM state into the predictor's `h_in` / `c_in` input buffers.
    ///
    /// # Errors
    /// Propagates errors from obtaining the buffers or viewing them as `f32`
    /// arrays (the latter wrapped via `DeviceSnafu`).
    ///
    /// # Panics
    /// Panics if a buffer view is not contiguous, or if its length differs from
    /// the corresponding state vector (`copy_from_slice`).
    fn pack_state(&mut self, s: &LstmState) -> crate::jit::Result<()> {
        // Each upload is a single statement, so the array view (and whatever the
        // buffer accessor returns) is dropped before the next buffer is borrowed.
        self.h_in_mut()?
            .as_array_mut::<f32>()
            .context(DeviceSnafu)?
            .as_slice_mut()
            .expect("contiguous h_in")
            .copy_from_slice(&s.h);
        self.c_in_mut()?
            .as_array_mut::<f32>()
            .context(DeviceSnafu)?
            .as_slice_mut()
            .expect("contiguous c_in")
            .copy_from_slice(&s.c);
        Ok(())
    }

    /// Runs one predictor step by delegating to the JIT's `execute`.
    fn execute_step(&mut self) -> crate::jit::Result<()> {
        self.execute()
    }

    /// Exposes the JIT's output buffer as the recurrent step's output.
    fn output_buffer(&self) -> crate::jit::Result<&svod_device::Buffer> {
        self.output()
    }
}
/// JIT-backed single-step RNN-T decoder: predictor LSTM plus joint network,
/// with a committed/tentative split of the predictor state.
pub struct RnntStepBackend {
    /// Predictor step JIT wrapped with its live (tentative) LSTM state.
    predictor: JitRecurrent<RnntPredictorStepJit>,
    /// Joint network JIT: one encoder frame combined with one predictor output.
    joint_jit: RnntJointStepJit,
    /// LSTM state restored into the predictor at the start of every step;
    /// replaced only by `commit`.
    committed: LstmState,
    /// Predictor output from the most recent step (`pred_hidden` floats).
    g_tentative: Vec<f32>,
    /// Index of the blank class (the last class of the RNN-T head).
    blank_id: usize,
    /// Width of one encoder frame (the model's `d_model`).
    enc_hidden: usize,
    /// Number of output classes, including blank.
    total_vocab: usize,
    /// Per-phase timing and call counters, accumulated across steps.
    pub stats: StepStats,
}
/// Cumulative counters and wall-clock timings for an RNN-T step backend.
///
/// All durations are running totals; a freshly defaulted value is all zeros.
#[derive(Clone, Debug, Default)]
pub struct StepStats {
    /// Number of `step` calls.
    pub n_steps: u64,
    /// Number of `commit` calls.
    pub n_commits: u64,
    /// Number of `reset` calls.
    pub n_resets: u64,
    /// Predictor input packing, including restoring the committed LSTM state.
    pub t_pred_pack: Duration,
    /// Predictor graph execution.
    pub t_pred_exec: Duration,
    /// Predictor output readback, including the copy into the tentative buffer.
    pub t_pred_read: Duration,
    /// Writing the encoder frame and predictor output into the joint inputs.
    pub t_joint_pack: Duration,
    /// Joint graph execution.
    pub t_joint_exec: Duration,
    /// Copying the joint logits back to the caller.
    pub t_joint_read: Duration,
}
impl RnntStepBackend {
    /// Builds a single-step RNN-T decoding backend from a GigaAM model.
    ///
    /// Two JIT graphs are prepared:
    /// * the predictor step, for an `i64[1, 1]` token id and LSTM `h`/`c`
    ///   state of shape `f32[pred_rnn_layers, 1, pred_hidden]`;
    /// * the joint step, for an `f32[1, 1, d_model]` encoder frame and an
    ///   `f32[1, 1, pred_hidden]` predictor output.
    ///
    /// The blank token is taken to be the last class (`num_classes - 1`). Both the
    /// committed and the predictor LSTM states start as `LstmState::zeros`.
    ///
    /// # Errors
    /// Returns `JitError::Build` in two cases:
    /// * `expect_rnnt` rejects the model head;
    /// * the head reports zero output classes, leaving no slot for a blank.
    ///
    /// Any error from preparing the JIT graphs or building the recurrent wrapper
    /// is propagated unchanged.
    pub fn from_model(model: GigaAm) -> crate::jit::Result<Self> {
        let (rnnt_head, _) = model
            .head
            .expect_rnnt("RnntStepBackend")
            .map_err(|e| JitError::Build { source: Box::new(e) })?;
        let pred_hidden = rnnt_head.pred_hidden;
        let pred_rnn_layers = rnnt_head.pred_rnn_layers;
        let total_vocab = rnnt_head.num_classes;
        // The blank is the last class. Reject an empty vocabulary before any JIT
        // work. Otherwise `total_vocab - 1` would panic in debug builds and wrap to
        // `usize::MAX` in release, which later becomes token id -1 for the predictor.
        let blank_id = total_vocab.checked_sub(1).ok_or_else(|| JitError::Build {
            source: "RnntStepBackend: RNN-T head reports zero output classes".into(),
        })?;
        let enc_hidden = model.config.d_model;
        // Flattened LSTM state length: one `pred_hidden` vector per layer.
        let lp = pred_rnn_layers * pred_hidden;

        let mut predictor_jit = RnntPredictorStepJit::new(model.clone());
        predictor_jit.prepare(
            InputSpec::i64(&[1, 1]),
            InputSpec::f32(&[pred_rnn_layers, 1, pred_hidden]),
            InputSpec::f32(&[pred_rnn_layers, 1, pred_hidden]),
        )?;

        let mut joint_jit = RnntJointStepJit::new(model);
        joint_jit.prepare(
            InputSpec::f32(&[1, 1, enc_hidden]),
            InputSpec::f32(&[1, 1, pred_hidden]),
        )?;

        Ok(Self {
            predictor: JitRecurrent::new(predictor_jit, LstmState::zeros(lp), pred_hidden)?,
            joint_jit,
            committed: LstmState::zeros(lp),
            g_tentative: vec![0.0f32; pred_hidden],
            blank_id,
            enc_hidden,
            total_vocab,
            stats: StepStats::default(),
        })
    }
}
impl JointStep for RnntStepBackend {
type Error = JitError;
fn step(
&mut self,
encoder_frame: &[f32],
prev_token: Option<usize>,
logits_out: &mut [f32],
) -> std::result::Result<(), Self::Error> {
debug_assert_eq!(encoder_frame.len(), self.enc_hidden);
debug_assert_eq!(logits_out.len(), self.total_vocab);
let tok_value = prev_token.unwrap_or(self.blank_id) as i64;
let t_state_copy = Instant::now();
self.predictor.state_mut().h.copy_from_slice(&self.committed.h);
self.predictor.state_mut().c.copy_from_slice(&self.committed.c);
let state_copy = t_state_copy.elapsed();
let g = self.predictor.step(|jit| {
let buf = jit.prev_token_mut()?;
let mut view = buf.as_array_mut::<i64>().context(DeviceSnafu)?;
view.as_slice_mut().expect("contiguous prev_token")[0] = tok_value;
Ok(())
})?;
let t_g_copy = Instant::now();
self.g_tentative.copy_from_slice(g);
let g_copy = t_g_copy.elapsed();
let pred_timing = self.predictor.last_timing.clone();
let t0 = Instant::now();
{
let buf = self.joint_jit.enc_t_mut()?;
let mut view = buf.as_array_mut::<f32>().context(DeviceSnafu)?;
view.as_slice_mut().expect("contiguous enc_t").copy_from_slice(encoder_frame);
}
{
let buf = self.joint_jit.g_mut()?;
let mut view = buf.as_array_mut::<f32>().context(DeviceSnafu)?;
view.as_slice_mut().expect("contiguous g").copy_from_slice(&self.g_tentative);
}
let t1 = Instant::now();
self.joint_jit.execute()?;
let t2 = Instant::now();
{
let out = self.joint_jit.output()?;
let arr = out.as_array::<f32>().context(DeviceSnafu)?;
let flat = arr.as_slice().expect("contiguous joint output");
logits_out.copy_from_slice(&flat[..self.total_vocab]);
}
let t3 = Instant::now();
self.stats.n_steps += 1;
self.stats.t_pred_pack += pred_timing.pack + state_copy;
self.stats.t_pred_exec += pred_timing.exec;
self.stats.t_pred_read += pred_timing.read + g_copy;
self.stats.t_joint_pack += t1 - t0;
self.stats.t_joint_exec += t2 - t1;
self.stats.t_joint_read += t3 - t2;
Ok(())
}
fn commit(&mut self) {
self.stats.n_commits += 1;
let state = self.predictor.state_mut();
self.committed.h.copy_from_slice(&state.h);
self.committed.c.copy_from_slice(&state.c);
}
fn reset(&mut self) {
self.stats.n_resets += 1;
self.committed.reset();
self.predictor.reset();
}
}