Skip to main content

svod_model/jit/
recurrent.rs

1use std::time::{Duration, Instant};
2
3use snafu::ResultExt;
4use svod_device::Buffer;
5
6use crate::jit::{DeviceSnafu, JitError, Result};
7
/// Flat host-side LSTM state. `h` and `c` are `f32` vectors of equal length;
/// for a single-layer cell `h.len() = hidden_size`, for a multi-layer stack
/// `h.len() = num_layers * hidden_size` (layer-major flat layout matching
/// `Tensor::stack([..], 0).reshape([1, 1, L*P])`).
///
/// The equal-length relationship is a convention of the constructors here;
/// because both fields are public it is not enforced by the type.
#[derive(Debug, Clone, PartialEq)]
pub struct LstmState {
    /// Flattened `h` (hidden) values.
    pub h: Vec<f32>,
    /// Flattened `c` (cell) values.
    pub c: Vec<f32>,
}

impl LstmState {
    /// Build a state whose `h` and `c` vectors each contain `total` zeros.
    pub fn zeros(total: usize) -> Self {
        Self { h: vec![0.0f32; total], c: vec![0.0f32; total] }
    }

    /// Overwrite every element of `h` and `c` with `0.0`. Lengths and
    /// allocations are left untouched.
    pub fn reset(&mut self) {
        self.h.fill(0.0);
        self.c.fill(0.0);
    }
}
27
/// Per-step timing for a [`JitRecurrent::step`] call. `pack` = `pack_inputs` +
/// [`RecurrentJit::pack_state`]; `exec` = JIT execute; `read` = output copy +
/// state update.
///
/// The record is three `Duration`s, so it is `Copy` and comparable: a caller
/// can snapshot `last_timing` by value and diff it against a later step.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub struct StepTiming {
    /// Time spent writing per-step inputs and the recurrent state into the JIT.
    pub pack: Duration,
    /// Time spent inside the JIT execution call.
    pub exec: Duration,
    /// Time spent copying the output head and writing the state tail back.
    pub read: Duration,
}
37
38/// A `jit_wrapper!`-generated JIT that participates in a recurrent loop.
39///
40/// Expected output layout: `[non_state_head | h_flat | c_flat]` concatenated
41/// along the last axis, where `h_flat` and `c_flat` each have length
42/// `state.h.len()` floats. [`JitRecurrent::step`] copies the head into a
43/// wrapper-owned buffer and writes the tail back into the host-side state.
44pub trait RecurrentJit {
45    /// Copy the active state into the JIT's typed state input buffers. Called
46    /// once per step before [`execute_step`](Self::execute_step).
47    fn pack_state(&mut self, state: &LstmState) -> Result<()>;
48
49    /// Run the prepared JIT.
50    fn execute_step(&mut self) -> Result<()>;
51
52    /// Borrow the JIT's output buffer (`[head | h_tail | c_tail]`).
53    fn output_buffer(&self) -> Result<&Buffer>;
54}
55
56/// Wraps a recurrent JIT and its host-side LSTM state. Owns the head read
57/// buffer so callers can borrow the non-state output between steps.
58pub struct JitRecurrent<J: RecurrentJit> {
59    pub jit: J,
60    state: LstmState,
61    head_buf: Vec<f32>,
62    head_len: usize,
63    pub last_timing: StepTiming,
64}
65
66impl<J: RecurrentJit> JitRecurrent<J> {
67    /// Construct from a prepared JIT, the initial state, and the declared
68    /// head length in `f32` elements. The JIT's output buffer is read once
69    /// and its size is checked against `head_len + |h| + |c|` to catch
70    /// `build`-closure layout drift at construction time instead of letting
71    /// silently mis-split output corrupt downstream values.
72    pub fn new(jit: J, state: LstmState, head_len: usize) -> Result<Self> {
73        let declared_state = state.h.len() + state.c.len();
74        let actual = jit.output_buffer()?.size() / std::mem::size_of::<f32>();
75        if actual != head_len + declared_state {
76            return Err(JitError::OutputLayoutMismatch { declared_head: head_len, declared_state, actual });
77        }
78        Ok(Self { jit, state, head_buf: vec![0.0; head_len], head_len, last_timing: StepTiming::default() })
79    }
80
81    pub fn state(&self) -> &LstmState {
82        &self.state
83    }
84    pub fn state_mut(&mut self) -> &mut LstmState {
85        &mut self.state
86    }
87
88    pub fn head_len(&self) -> usize {
89        self.head_len
90    }
91
92    /// One recurrent step. `pack_inputs` writes per-step non-state inputs
93    /// (audio chunk, token id, encoder frame, …) into the JIT. The wrapper
94    /// then packs state, executes, splits the output into
95    /// `[head | h_tail | c_tail]`, and updates host state. Returns the head
96    /// slice borrowed from the wrapper-owned buffer.
97    pub fn step<F>(&mut self, pack_inputs: F) -> Result<&[f32]>
98    where
99        F: FnOnce(&mut J) -> Result<()>,
100    {
101        let t0 = Instant::now();
102        pack_inputs(&mut self.jit)?;
103        self.jit.pack_state(&self.state)?;
104        let t1 = Instant::now();
105        self.jit.execute_step()?;
106        let t2 = Instant::now();
107        {
108            let out = self.jit.output_buffer()?;
109            let arr = out.as_array::<f32>().context(DeviceSnafu)?;
110            let flat = arr.as_slice().expect("contiguous JIT output");
111            let head_len = self.head_len;
112            let h_len = self.state.h.len();
113            self.head_buf.copy_from_slice(&flat[..head_len]);
114            self.state.h.copy_from_slice(&flat[head_len..head_len + h_len]);
115            self.state.c.copy_from_slice(&flat[head_len + h_len..]);
116        }
117        let t3 = Instant::now();
118        self.last_timing = StepTiming { pack: t1 - t0, exec: t2 - t1, read: t3 - t2 };
119        Ok(&self.head_buf)
120    }
121
122    /// Zero the host-side state. The next [`step`](Self::step) writes the
123    /// fresh state into the JIT.
124    pub fn reset(&mut self) {
125        self.state.reset();
126    }
127}