use std::time::{Duration, Instant};
use snafu::ResultExt;
use svod_device::Buffer;
use crate::jit::{DeviceSnafu, JitError, Result};
/// Recurrent state of an LSTM step: the `h` and `c` vectors that are fed into
/// the next step and read back after it.
#[derive(Debug, Clone, PartialEq)]
pub struct LstmState {
    /// `h` state values (conventionally the LSTM hidden state).
    pub h: Vec<f32>,
    /// `c` state values (conventionally the LSTM cell state).
    pub c: Vec<f32>,
}

impl LstmState {
    /// Creates a state whose `h` and `c` each contain `total` zeros.
    pub fn zeros(total: usize) -> Self {
        Self { h: vec![0.0f32; total], c: vec![0.0f32; total] }
    }

    /// Overwrites every element of `h` and `c` with `0.0` in place.
    ///
    /// The lengths of both vectors are left unchanged.
    pub fn reset(&mut self) {
        self.h.fill(0.0);
        self.c.fill(0.0);
    }
}
/// Wall-clock durations of the three phases of a single recurrent step.
#[derive(Debug, Clone, Default)]
pub struct StepTiming {
    /// Time spent writing the step's inputs and state into the JIT.
    pub pack: Duration,
    /// Time spent running the compiled step.
    pub exec: Duration,
    /// Time spent copying the step's output back out of the JIT.
    pub read: Duration,
}
/// A compiled recurrent step function that `JitRecurrent` can drive one
/// step at a time.
pub trait RecurrentJit {
    /// Receives the current recurrent state before each step.
    ///
    /// NOTE(review): implementations presumably copy `state` into the step's
    /// inputs — confirm against the concrete implementors.
    fn pack_state(&mut self, state: &LstmState) -> Result<()>;

    /// Runs one step of the compiled function.
    fn execute_step(&mut self) -> Result<()>;

    /// The buffer that step results are read from.
    ///
    /// `JitRecurrent` interprets it as `f32` values laid out as
    /// head values, then `h`, then `c`.
    fn output_buffer(&self) -> Result<&Buffer>;
}
/// Drives a [`RecurrentJit`] one step at a time, carrying the recurrent
/// [`LstmState`] from each step's output into the next step's input.
///
/// The JIT's output buffer must hold exactly `head_len + h.len() + c.len()`
/// `f32` values, laid out as the head values, then `h`, then `c`.
pub struct JitRecurrent<J: RecurrentJit> {
    /// The wrapped compiled step function.
    pub jit: J,
    /// Recurrent state packed before each step and overwritten after it.
    state: LstmState,
    /// Head slice of the most recent step's output; always `head_len` long.
    head_buf: Vec<f32>,
    /// Number of leading output values that belong to the head.
    head_len: usize,
    /// Phase timings of the most recent step that completed successfully.
    pub last_timing: StepTiming,
}

impl<J: RecurrentJit> JitRecurrent<J> {
    /// Wraps `jit` with an initial `state` and a head of `head_len` values.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`RecurrentJit::output_buffer`]. Returns
    /// [`JitError::OutputLayoutMismatch`] if the output buffer's size in
    /// `f32` values is not `head_len + state.h.len() + state.c.len()`.
    pub fn new(jit: J, state: LstmState, head_len: usize) -> Result<Self> {
        let declared_state = state.h.len() + state.c.len();
        let actual = jit.output_buffer()?.size() / std::mem::size_of::<f32>();
        if actual != head_len + declared_state {
            return Err(JitError::OutputLayoutMismatch {
                declared_head: head_len,
                declared_state,
                actual,
            });
        }
        Ok(Self {
            jit,
            state,
            head_buf: vec![0.0; head_len],
            head_len,
            last_timing: StepTiming::default(),
        })
    }

    /// The current recurrent state.
    pub fn state(&self) -> &LstmState {
        &self.state
    }

    /// Mutable access to the current recurrent state.
    ///
    /// Changing the *lengths* of `h` or `c` breaks the output layout checked
    /// in [`JitRecurrent::new`]. The next [`JitRecurrent::step`] then returns
    /// [`JitError::OutputLayoutMismatch`].
    pub fn state_mut(&mut self) -> &mut LstmState {
        &mut self.state
    }

    /// Number of head values returned by each [`JitRecurrent::step`].
    pub fn head_len(&self) -> usize {
        self.head_len
    }

    /// Runs one recurrent step and returns the step's head values.
    ///
    /// The phases run in this order:
    /// 1. `pack_inputs` runs on the JIT so the caller can write per-step inputs.
    /// 2. The current state is packed.
    /// 3. The step executes.
    /// 4. The output is split into head, `h` and `c`, and the new `h` and `c`
    ///    replace the stored state.
    ///
    /// On success, [`JitRecurrent::last_timing`] holds the durations of the
    /// pack, execute and read phases.
    ///
    /// # Errors
    ///
    /// Propagates errors from `pack_inputs` and from the JIT. Returns
    /// [`JitError::OutputLayoutMismatch`] if the output length no longer
    /// matches `head_len + h.len() + c.len()`. In that case neither the state
    /// nor the head buffer is modified.
    ///
    /// # Panics
    ///
    /// Panics if the output array is not contiguous.
    pub fn step<F>(&mut self, pack_inputs: F) -> Result<&[f32]>
    where
        F: FnOnce(&mut J) -> Result<()>,
    {
        let t0 = Instant::now();
        pack_inputs(&mut self.jit)?;
        self.jit.pack_state(&self.state)?;
        let t1 = Instant::now();
        self.jit.execute_step()?;
        let t2 = Instant::now();
        {
            let out = self.jit.output_buffer()?;
            let arr = out.as_array::<f32>().context(DeviceSnafu)?;
            let flat = arr.as_slice().expect("contiguous JIT output");
            let head_len = self.head_len;
            let h_len = self.state.h.len();
            let declared_state = h_len + self.state.c.len();
            // `state_mut` allows `h`/`c` to be resized after `new` validated the
            // layout. Check before copying so a mismatch surfaces as an error
            // rather than a slice panic, and nothing is partially overwritten.
            if flat.len() != head_len + declared_state {
                return Err(JitError::OutputLayoutMismatch {
                    declared_head: head_len,
                    declared_state,
                    actual: flat.len(),
                });
            }
            let (head, state_part) = flat.split_at(head_len);
            let (h, c) = state_part.split_at(h_len);
            self.head_buf.copy_from_slice(head);
            self.state.h.copy_from_slice(h);
            self.state.c.copy_from_slice(c);
        }
        let t3 = Instant::now();
        self.last_timing = StepTiming { pack: t1 - t0, exec: t2 - t1, read: t3 - t2 };
        Ok(&self.head_buf)
    }

    /// Zeros the recurrent state in place.
    ///
    /// The last head values and [`JitRecurrent::last_timing`] are left untouched.
    pub fn reset(&mut self) {
        self.state.reset();
    }
}