use smol_str::format_smolstr;
use crate::{
Array,
error::{ArithmeticOverflowPayload, Error, OutOfRangePayload, RankMismatchPayload, Result},
ops::shape::{concatenate, pad},
};
/// Backend that encodes one fixed-size window of mel frames.
///
/// `StreamingEncoder` reads [`window_size`](Self::window_size) once at
/// construction. It then calls [`encode_window`](Self::encode_window) with
/// windows of that many rows. Partial windows are zero-padded up to that size.
pub trait StreamingEncoderBackend {
    /// Number of mel-frame rows (axis 0) the backend expects per window.
    fn window_size(&self) -> usize;
    /// Encodes one window of mel frames.
    ///
    /// `valid_frames` is the number of leading rows that carry real input.
    /// For full windows it equals `window_size()`. For a padded partial window
    /// it is smaller, and the remaining rows are zero padding.
    ///
    /// `StreamingEncoder` concatenates results along axis 0 and counts axis 0
    /// as tokens, so the returned array's first axis should be the token axis.
    fn encode_window(&self, mel_window: &Array, valid_frames: usize) -> Result<Array>;
}
/// Incremental encoder that buffers incoming mel frames and encodes them
/// window by window through a [`StreamingEncoderBackend`].
pub struct StreamingEncoder<B> {
    /// Backend that encodes individual windows.
    encoder: B,
    /// Rows per window, read from `encoder.window_size()` at construction.
    window_size: usize,
    /// Rows the buffer advances after each full window (`window_size - overlap`, at least 1).
    window_stride: usize,
    /// Upper bound on the number of entries kept in `cached_windows`.
    max_cached_windows: usize,
    /// Most recent encoded windows, oldest first; excess entries are evicted from the front.
    cached_windows: Vec<Array>,
    /// Encoded windows not yet handed out by `drain_newly_encoded_windows`.
    newly_encoded_windows: Vec<Array>,
    /// Running (saturating) window count reported by `encoded_window_count`.
    total_encoded_windows: usize,
    /// Frames received but not yet consumed by a full window (may be `None`).
    pending_frames: Option<Array>,
    /// Row count of `pending_frames`; 0 when nothing is pending.
    pending_frame_count: usize,
}
impl<B: StreamingEncoderBackend> StreamingEncoder<B> {
pub fn new(encoder: B, max_cached_windows: usize, overlap_frames: usize) -> Self {
let window_size = encoder.window_size();
let clamped_overlap = overlap_frames.min(window_size.saturating_sub(1));
let window_stride = window_size.saturating_sub(clamped_overlap).max(1);
Self {
encoder,
window_size,
window_stride,
max_cached_windows,
cached_windows: Vec::new(),
newly_encoded_windows: Vec::new(),
total_encoded_windows: 0,
pending_frames: None,
pending_frame_count: 0,
}
}
#[inline(always)]
pub fn backend(&self) -> &B {
&self.encoder
}
#[inline(always)]
pub fn window_size(&self) -> usize {
self.window_size
}
#[inline(always)]
pub fn window_stride(&self) -> usize {
self.window_stride
}
#[inline(always)]
pub fn encoded_window_count(&self) -> usize {
self.total_encoded_windows
}
#[inline(always)]
pub fn has_pending_frames(&self) -> bool {
self.pending_frame_count > 0
}
pub fn total_cached_tokens(&self) -> usize {
self
.cached_windows
.iter()
.map(|a| a.shape().first().copied().unwrap_or(0))
.sum()
}
pub fn feed(&mut self, mel_frames: &Array) -> Result<usize> {
if mel_frames.ndim() != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"StreamingEncoder::feed: mel_frames must be rank-2 (frames, n_mels)",
mel_frames.ndim() as u32,
mel_frames.shape(),
)));
}
let window_size_i32 = i32::try_from(self.window_size).map_err(|_| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"StreamingEncoder::feed: window_size does not fit i32",
"i32",
[("window_size", self.window_size as u64)],
))
})?;
let stride_i32 = i32::try_from(self.window_stride).map_err(|_| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"StreamingEncoder::feed: window_stride does not fit i32",
"i32",
[("window_stride", self.window_stride as u64)],
))
})?;
let combined = match self.pending_frames.as_ref() {
Some(existing) => concatenate(&[existing, mel_frames], 0)?,
None => mel_frames.try_clone()?,
};
let mut staged_count = combined.shape().first().copied().unwrap_or(0);
let mut staged_pending: Option<Array> = Some(combined);
let mut staged_cached: Vec<Array> = Vec::new();
let mut staged_newly: Vec<Array> = Vec::new();
while staged_count >= self.window_size {
let frames = staged_pending
.take()
.expect("staged_pending was non-empty when staged_count >= window_size");
let window = frames.slice(&[0i32, 0i32], &[window_size_i32, i32::MAX], &[1i32, 1i32])?;
let mut encoded = self.encoder.encode_window(&window, self.window_size)?;
encoded.eval()?;
let cached_handle = encoded.try_clone()?;
staged_cached.push(cached_handle);
staged_newly.push(encoded);
if staged_count > self.window_stride {
let remainder = frames.slice(&[stride_i32, 0i32], &[i32::MAX, i32::MAX], &[1i32, 1i32])?;
staged_count = remainder.shape().first().copied().unwrap_or(0);
staged_pending = Some(remainder);
} else {
staged_pending = None;
staged_count = 0;
}
}
let new_windows = staged_cached.len();
self.pending_frames = staged_pending;
self.pending_frame_count = staged_count;
for window in staged_cached {
self.cached_windows.push(window);
if self.cached_windows.len() > self.max_cached_windows {
self.cached_windows.remove(0);
}
}
self.newly_encoded_windows.extend(staged_newly);
self.total_encoded_windows = self.total_encoded_windows.saturating_add(new_windows);
Ok(new_windows)
}
pub fn flush_partial(&mut self) -> Result<usize> {
let Some(frames) = self.pending_frames.take() else {
return Ok(0);
};
let valid_frames = self.pending_frame_count;
if valid_frames == 0 {
return Ok(0);
}
let padded = pad_to_window_size(&frames, valid_frames, self.window_size)?;
let mut encoded = self.encoder.encode_window(&padded, valid_frames)?;
encoded.eval()?;
self.cached_windows.push(encoded);
self.pending_frame_count = 0;
if self.cached_windows.len() > self.max_cached_windows {
self.cached_windows.remove(0);
}
Ok(1)
}
pub fn cached_encoder_output(&self) -> Result<Option<Array>> {
self.cached_encoder_output_from_window(0)
}
pub fn cached_encoder_output_from_window(&self, start_window: usize) -> Result<Option<Array>> {
if start_window >= self.cached_windows.len() {
return Ok(None);
}
let slice = &self.cached_windows[start_window..];
if slice.is_empty() {
return Ok(None);
}
if slice.len() == 1 {
return Ok(Some(slice[0].try_clone()?));
}
let refs: Vec<&Array> = slice.iter().collect();
Ok(Some(concatenate(&refs, 0)?))
}
pub fn encode_pending(&self) -> Result<Option<Array>> {
let Some(frames) = self.pending_frames.as_ref() else {
return Ok(None);
};
let valid_frames = self.pending_frame_count;
if valid_frames == 0 {
return Ok(None);
}
let padded = pad_to_window_size(frames, valid_frames, self.window_size)?;
let mut encoded = self.encoder.encode_window(&padded, valid_frames)?;
encoded.eval()?;
Ok(Some(encoded))
}
pub fn full_encoder_output(&self, from_window: Option<usize>) -> Result<Option<Array>> {
let cached = match from_window {
Some(start) => self.cached_encoder_output_from_window(start)?,
None => self.cached_encoder_output()?,
};
let pending = self.encode_pending()?;
match (cached, pending) {
(None, None) => Ok(None),
(Some(c), None) => Ok(Some(c)),
(None, Some(p)) => Ok(Some(p)),
(Some(c), Some(p)) => Ok(Some(concatenate(&[&c, &p], 0)?)),
}
}
pub fn drain_newly_encoded_windows(&mut self) -> Vec<Array> {
std::mem::take(&mut self.newly_encoded_windows)
}
pub fn reset(&mut self) {
self.cached_windows.clear();
self.newly_encoded_windows.clear();
self.total_encoded_windows = 0;
self.pending_frames = None;
self.pending_frame_count = 0;
}
}
fn pad_to_window_size(frames: &Array, valid_frames: usize, window_size: usize) -> Result<Array> {
if valid_frames >= window_size {
return frames.try_clone();
}
let high = window_size - valid_frames;
let high_i32 = i32::try_from(high).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"StreamingEncoder: pad-high count for window-size padding",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{high}"),
))
})?;
let pad_value = Array::zeros::<f32>(&[0i32; 0])?;
pad(
frames,
&[0_i32],
&[0_i32],
&[high_i32],
&pad_value,
c"constant",
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test backend that records `(rows, valid_frames)` for every call and
    /// returns a `(rows, 2)` array whose first column is the row index.
    struct MockEncoder {
        window_size: usize,
        // `encode_window` takes `&self`, so call recording needs interior mutability.
        calls: Mutex<Vec<(usize, usize)>>,
    }
    impl MockEncoder {
        fn new(window_size: usize) -> Self {
            Self {
                window_size,
                calls: Mutex::new(Vec::new()),
            }
        }
        /// Number of `encode_window` calls observed so far.
        fn call_count(&self) -> usize {
            self.calls.lock().unwrap().len()
        }
        /// Row count of the most recent window passed to the backend.
        fn last_call_rows(&self) -> Option<usize> {
            self.calls.lock().unwrap().last().map(|(rows, _)| *rows)
        }
        /// `valid_frames` argument of the most recent call.
        fn last_call_valid_frames(&self) -> Option<usize> {
            self.calls.lock().unwrap().last().map(|(_, valid)| *valid)
        }
    }
    impl StreamingEncoderBackend for MockEncoder {
        fn window_size(&self) -> usize {
            self.window_size
        }
        fn encode_window(&self, mel_window: &Array, valid_frames: usize) -> Result<Array> {
            let rows = mel_window.shape().first().copied().unwrap_or(0);
            self.calls.lock().unwrap().push((rows, valid_frames));
            // One output row per input row: [row_index, 0.0].
            let mut buf: Vec<f32> = Vec::with_capacity(rows * 2);
            for i in 0..rows {
                buf.push(i as f32);
                buf.push(0.0);
            }
            Array::from_slice::<f32>(&buf, &[rows as i32, 2i32])
        }
    }
    /// All-zero `(rows, n_mels)` mel input.
    fn zero_mel(rows: usize, n_mels: usize) -> Array {
        let buf = vec![0.0_f32; rows * n_mels];
        Array::from_slice::<f32>(&buf, &[rows as i32, n_mels as i32]).unwrap()
    }
    #[test]
    fn feed_accumulates_until_window_full_then_calls_backend_once() {
        let encoder = MockEncoder::new(16);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        // 8 of 16 rows: buffered, backend not called yet.
        assert_eq!(stream.feed(&zero_mel(8, 4)).unwrap(), 0);
        assert_eq!(stream.backend().call_count(), 0);
        assert!(stream.has_pending_frames());
        // 16 of 16 rows: exactly one full window is encoded.
        assert_eq!(stream.feed(&zero_mel(8, 4)).unwrap(), 1);
        assert_eq!(stream.backend().call_count(), 1);
        assert_eq!(stream.backend().last_call_rows(), Some(16));
        assert_eq!(stream.backend().last_call_valid_frames(), Some(16));
        assert!(!stream.has_pending_frames());
        assert_eq!(stream.encoded_window_count(), 1);
    }
    #[test]
    fn feed_emits_multiple_windows_when_input_exceeds_one_window() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        // 24 rows / 8-row windows with no overlap -> 3 windows, nothing left over.
        let new_windows = stream.feed(&zero_mel(24, 4)).unwrap();
        assert_eq!(new_windows, 3);
        assert_eq!(stream.encoded_window_count(), 3);
        assert!(!stream.has_pending_frames());
    }
    #[test]
    fn feed_with_overlap_advances_by_stride_not_full_window() {
        let encoder = MockEncoder::new(8);
        // Window 8, overlap 2 -> stride 6.
        let mut stream = StreamingEncoder::new(encoder, 4, 2);
        assert_eq!(stream.window_stride(), 6);
        // 14 rows: window over rows 0..8 leaves 8, window over 6..14 leaves 2 pending.
        let n = stream.feed(&zero_mel(14, 4)).unwrap();
        assert_eq!(n, 2);
        assert!(stream.has_pending_frames());
    }
    #[test]
    fn feed_rejects_non_2d_input() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        let one_d = Array::from_slice::<f32>(&[0.0_f32; 8], &[8i32]).unwrap();
        let err = stream.feed(&one_d).unwrap_err();
        assert!(matches!(err, Error::RankMismatch(ref p)
            if p.actual() == 1 && p.context().contains("rank-2")));
    }
    #[test]
    fn flush_partial_encodes_remaining_pending_frames() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        assert_eq!(stream.feed(&zero_mel(5, 4)).unwrap(), 0);
        // The 5 pending rows are padded to 8 and flagged as 5 valid frames.
        let flushed = stream.flush_partial().unwrap();
        assert_eq!(flushed, 1);
        assert_eq!(stream.backend().call_count(), 1);
        assert_eq!(stream.backend().last_call_rows(), Some(8));
        assert_eq!(stream.backend().last_call_valid_frames(), Some(5));
        assert!(!stream.has_pending_frames());
    }
    #[test]
    fn flush_partial_on_empty_buffer_is_noop() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        assert_eq!(stream.flush_partial().unwrap(), 0);
        assert_eq!(stream.backend().call_count(), 0);
    }
    #[test]
    fn cache_evicts_oldest_window_when_max_exceeded() {
        let encoder = MockEncoder::new(8);
        // Cache holds at most 2 windows.
        let mut stream = StreamingEncoder::new(encoder, 2, 0);
        let n = stream.feed(&zero_mel(24, 4)).unwrap();
        assert_eq!(n, 3);
        assert_eq!(stream.encoded_window_count(), 3);
        let cached = stream.cached_encoder_output().unwrap().unwrap();
        // 2 retained windows * 8 rows each.
        assert_eq!(cached.shape()[0], 16);
    }
    #[test]
    fn drain_newly_encoded_windows_returns_each_window_once() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 10, 0);
        // 16 rows -> 2 windows.
        let _ = stream.feed(&zero_mel(16, 4)).unwrap();
        let first = stream.drain_newly_encoded_windows();
        assert_eq!(first.len(), 2);
        // A second drain without new input yields nothing.
        let second = stream.drain_newly_encoded_windows();
        assert_eq!(second.len(), 0);
        // 8 more rows -> 1 window.
        let _ = stream.feed(&zero_mel(8, 4)).unwrap();
        let third = stream.drain_newly_encoded_windows();
        assert_eq!(third.len(), 1);
    }
    #[test]
    fn encode_pending_does_not_consume_pending_frames() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        let _ = stream.feed(&zero_mel(5, 4)).unwrap();
        let pending_before = stream.has_pending_frames();
        assert!(pending_before);
        let out = stream.encode_pending().unwrap().unwrap();
        assert_eq!(out.shape()[0], 8);
        // Pending frames survive; only the backend call is recorded.
        assert!(stream.has_pending_frames());
        assert_eq!(stream.backend().call_count(), 1);
        assert_eq!(stream.backend().last_call_rows(), Some(8));
        assert_eq!(stream.backend().last_call_valid_frames(), Some(5));
    }
    #[test]
    fn reset_clears_state_for_new_session() {
        let encoder = MockEncoder::new(8);
        let mut stream = StreamingEncoder::new(encoder, 4, 0);
        let _ = stream.feed(&zero_mel(16, 4)).unwrap();
        assert_eq!(stream.encoded_window_count(), 2);
        stream.reset();
        assert_eq!(stream.encoded_window_count(), 0);
        assert_eq!(stream.total_cached_tokens(), 0);
        assert!(!stream.has_pending_frames());
    }
    /// Backend that panics (via `assert!`) unless every window is 2-D with
    /// exactly `window_size` rows and `n_mels` columns, and
    /// `valid_frames <= window_size`.
    struct StrictContractEncoder {
        window_size: usize,
        n_mels: usize,
        calls: Mutex<Vec<(usize, usize)>>,
    }
    impl StrictContractEncoder {
        fn new(window_size: usize, n_mels: usize) -> Self {
            Self {
                window_size,
                n_mels,
                calls: Mutex::new(Vec::new()),
            }
        }
    }
    impl StreamingEncoderBackend for StrictContractEncoder {
        fn window_size(&self) -> usize {
            self.window_size
        }
        fn encode_window(&self, mel_window: &Array, valid_frames: usize) -> Result<Array> {
            let shape = mel_window.shape();
            assert_eq!(
                shape.len(),
                2,
                "strict contract: window must be 2-D, got shape={shape:?}"
            );
            assert_eq!(
                shape[0], self.window_size,
                "strict contract: window row count must equal window_size={}, got {}",
                self.window_size, shape[0]
            );
            assert_eq!(
                shape[1], self.n_mels,
                "strict contract: window col count must equal n_mels={}, got {}",
                self.n_mels, shape[1]
            );
            assert!(
                valid_frames <= self.window_size,
                "strict contract: valid_frames {valid_frames} must be <= window_size {}",
                self.window_size
            );
            self.calls.lock().unwrap().push((shape[0], valid_frames));
            Array::from_slice::<f32>(&vec![0.0_f32; shape[0] * 2], &[shape[0] as i32, 2i32])
        }
    }
    #[test]
    fn streaming_encoder_pads_partial_window_with_zeros_and_signals_valid_frames() {
        let mut stream = StreamingEncoder::new(StrictContractEncoder::new(8, 4), 4, 0);
        let _ = stream.feed(&zero_mel(3, 4)).unwrap();
        let _ = stream.encode_pending().unwrap().unwrap();
        let calls = stream.backend().calls.lock().unwrap();
        let (rows, valid_frames) = calls[0];
        assert_eq!(rows, 8, "rows must be padded to window_size");
        assert_eq!(valid_frames, 3, "valid_frames must equal the partial count");
        assert!(
            valid_frames < rows,
            "partial case: valid_frames < rows; got valid_frames={valid_frames} rows={rows}"
        );
    }
    #[test]
    fn streaming_encoder_full_window_passes_valid_frames_equals_window_size() {
        let mut stream = StreamingEncoder::new(StrictContractEncoder::new(8, 4), 4, 0);
        let new_windows = stream.feed(&zero_mel(8, 4)).unwrap();
        assert_eq!(new_windows, 1);
        let calls = stream.backend().calls.lock().unwrap();
        let (rows, valid_frames) = calls[0];
        assert_eq!(rows, 8);
        assert_eq!(
            valid_frames, 8,
            "full-window path must pass valid_frames == window_size"
        );
    }
}