use std::ops::Range;
use crate::options::SampleRate;
pub(crate) const STATE_LAYERS: usize = 2;
pub(crate) const STATE_HIDDEN_DIM: usize = 128;
pub(crate) const STATE_VALUES: usize = STATE_LAYERS * STATE_HIDDEN_DIM;
pub(crate) const MAX_CONTEXT_SAMPLES: usize = SampleRate::Rate16k.context_samples();
pub(crate) const MAX_CHUNK_SAMPLES: usize = SampleRate::Rate16k.chunk_samples();
#[derive(Debug, Clone)]
pub struct StreamState {
sample_rate: SampleRate,
rnn_state: [f32; STATE_VALUES],
context: [f32; MAX_CONTEXT_SAMPLES],
pending: [f32; MAX_CHUNK_SAMPLES],
pending_len: usize,
}
impl StreamState {
pub fn new(sample_rate: SampleRate) -> Self {
Self {
sample_rate,
rnn_state: [0.0; STATE_VALUES],
context: [0.0; MAX_CONTEXT_SAMPLES],
pending: [0.0; MAX_CHUNK_SAMPLES],
pending_len: 0,
}
}
#[inline]
pub const fn sample_rate(&self) -> SampleRate {
self.sample_rate
}
#[inline]
pub fn set_sample_rate(&mut self, sample_rate: SampleRate) {
if self.sample_rate != sample_rate {
self.sample_rate = sample_rate;
self.reset();
}
}
#[inline]
pub fn reset(&mut self) {
self.rnn_state.fill(0.0);
self.context.fill(0.0);
self.pending_len = 0;
}
#[inline]
pub fn pending_len(&self) -> usize {
self.pending_len
}
#[inline]
pub fn has_pending(&self) -> bool {
self.pending_len != 0
}
#[inline]
pub(crate) fn context(&self) -> &[f32] {
&self.context[..self.sample_rate.context_samples()]
}
#[inline]
pub(crate) fn context_mut(&mut self) -> &mut [f32] {
let context_len = self.sample_rate.context_samples();
&mut self.context[..context_len]
}
#[inline]
pub(crate) fn pending(&self) -> &[f32] {
&self.pending[..self.pending_len]
}
#[inline]
pub(crate) fn append_pending(&mut self, samples: &[f32]) {
let end = self.pending_len + samples.len();
debug_assert!(end <= self.sample_rate.chunk_samples());
self.pending[self.pending_len..end].copy_from_slice(samples);
self.pending_len = end;
}
#[inline]
pub(crate) fn clear_pending(&mut self) {
self.pending_len = 0;
}
#[inline]
pub(crate) fn layer(&self, layer: usize) -> &[f32] {
&self.rnn_state[layer_range(layer)]
}
#[inline]
pub(crate) fn layer_mut(&mut self, layer: usize) -> &mut [f32] {
&mut self.rnn_state[layer_range(layer)]
}
}
impl Default for StreamState {
#[inline]
fn default() -> Self {
Self::new(SampleRate::default())
}
}
#[inline]
fn layer_range(layer: usize) -> Range<usize> {
let start = layer * STATE_HIDDEN_DIM;
start..start + STATE_HIDDEN_DIM
}
#[cfg(test)]
mod tests {
use crate::options::SampleRate;
use super::StreamState;
#[test]
fn reset_clears_state_and_pending() {
let mut state = StreamState::new(SampleRate::Rate16k);
state.layer_mut(0).fill(1.0);
state.context_mut().fill(1.0);
state.append_pending(&[0.1, 0.2]);
state.reset();
assert!(state.layer(0).iter().all(|value| *value == 0.0));
assert!(state.context().iter().all(|value| *value == 0.0));
assert!(state.pending().is_empty());
}
#[test]
fn sample_rate_switch_reinitializes_context_shape() {
let mut state = StreamState::new(SampleRate::Rate16k);
assert_eq!(state.context().len(), 64);
state.set_sample_rate(SampleRate::Rate8k);
assert_eq!(state.context().len(), 32);
assert!(state.context().iter().all(|value| *value == 0.0));
}
}