use alloc::vec::Vec;
use crate::{Result, StreamingFingerprinter, TimestampMs};
use super::embedder::{EmbedderCore, NeuralEmbedderConfig};
/// Streaming wrapper around [`EmbedderCore`] that buffers incoming audio and
/// emits one embedding per completed analysis window.
pub struct StreamingNeuralEmbedder {
    /// Model plus windowing parameters, taken from the offline embedder.
    core: EmbedderCore,
    /// Samples received but not yet consumed by a full window.
    sample_carry: Vec<f32>,
    /// Stream position (in samples) of the first buffered sample; drives the
    /// per-window output timestamps. Advances by one hop per emitted window.
    samples_consumed: u64,
}
impl StreamingNeuralEmbedder {
    /// Builds a streaming embedder from the same configuration as the offline
    /// [`super::embedder::NeuralEmbedder`], taking ownership of its core.
    ///
    /// # Errors
    /// Propagates constructor failures, e.g. a missing model file.
    pub fn new(cfg: NeuralEmbedderConfig) -> Result<Self> {
        let inner = super::embedder::NeuralEmbedder::new(cfg)?;
        Ok(Self {
            core: inner.core,
            sample_carry: Vec::new(),
            samples_consumed: 0,
        })
    }

    /// Returns the configuration the embedder was built with.
    #[must_use]
    pub fn config(&self) -> &NeuralEmbedderConfig {
        &self.core.cfg
    }

    /// Dimensionality of each emitted embedding vector.
    #[must_use]
    pub fn embedding_dim(&self) -> usize {
        self.core.embedding_dim
    }

    /// Number of samples per analysis window.
    #[must_use]
    pub fn window_samples(&self) -> usize {
        self.core.window_samples
    }

    /// Number of samples the window advances between consecutive embeddings.
    #[must_use]
    pub fn hop_samples(&self) -> usize {
        self.core.hop_samples
    }

    /// Pushes samples and collects every completed embedding as an owned
    /// `(timestamp, vector)` pair.
    ///
    /// # Errors
    /// Propagates inference failures from the underlying model.
    pub fn try_push(&mut self, samples: &[f32]) -> Result<Vec<(TimestampMs, Vec<f32>)>> {
        let mut out = Vec::new();
        self.try_push_with(samples, |t, v| {
            out.push((t, v.to_vec()));
        })?;
        Ok(out)
    }

    /// Pushes samples, invoking `callback` once per completed window with the
    /// window's start timestamp and a borrowed embedding slice. Returns the
    /// number of embeddings emitted.
    ///
    /// Unconsumed samples carry over to the next push, so the emitted sequence
    /// is invariant to how the input is chunked.
    ///
    /// # Errors
    /// Propagates inference failures; embeddings already delivered to
    /// `callback` before the failure are not rolled back.
    pub fn try_push_with<F: FnMut(TimestampMs, &[f32])>(
        &mut self,
        samples: &[f32],
        mut callback: F,
    ) -> Result<usize> {
        self.sample_carry.extend_from_slice(samples);
        let window_samples = self.core.window_samples;
        let hop_samples = self.core.hop_samples;
        // Clamp to avoid a divide-by-zero on a degenerate sample rate,
        // mirroring the guard already used in `latency_ms`.
        let sr = (self.core.cfg.sample_rate as u64).max(1);
        // One scratch buffer reused across windows to avoid per-window
        // allocation.
        let mut buf = Vec::with_capacity(self.core.embedding_dim);
        let mut emitted = 0usize;
        while self.sample_carry.len() >= window_samples {
            {
                let window = &self.sample_carry[..window_samples];
                self.core.embed_window_into(window, &mut buf)?;
            }
            let t_start = TimestampMs(self.samples_consumed * 1000 / sr);
            callback(t_start, &buf);
            emitted += 1;
            // NOTE(review): assumes hop_samples <= window_samples; with a
            // larger hop this drain would panic once the carry is shorter
            // than the hop — confirm the config validates that invariant.
            self.sample_carry.drain(..hop_samples);
            self.samples_consumed += hop_samples as u64;
        }
        Ok(emitted)
    }

    /// Clears all buffered samples and rewinds the stream clock to zero.
    pub fn reset(&mut self) {
        self.sample_carry.clear();
        self.samples_consumed = 0;
    }

    /// Test-only: current number of buffered (unconsumed) samples.
    #[cfg(test)]
    pub(crate) fn carry_len(&self) -> usize {
        self.sample_carry.len()
    }

    /// Test-only: total samples the stream clock has advanced past.
    #[cfg(test)]
    pub(crate) fn samples_consumed(&self) -> u64 {
        self.samples_consumed
    }

    /// Test-only: builds a streaming embedder directly from a core, bypassing
    /// the model-loading constructor.
    #[cfg(test)]
    pub(crate) fn __from_core_for_test(core: EmbedderCore) -> Self {
        Self {
            core,
            sample_carry: Vec::new(),
            samples_consumed: 0,
        }
    }
}
impl StreamingFingerprinter for StreamingNeuralEmbedder {
    type Frame = Vec<f32>;

    /// Infallible push required by the trait: an inference error aborts the
    /// process via panic rather than being surfaced to the caller.
    fn push(&mut self, samples: &[f32]) -> Vec<(TimestampMs, Self::Frame)> {
        match self.try_push(samples) {
            Ok(frames) => frames,
            Err(e) => panic!("neural inference failed during push: {e}"),
        }
    }

    /// Discards any buffered partial window; no trailing frame is emitted.
    fn flush(&mut self) -> Vec<(TimestampMs, Self::Frame)> {
        self.sample_carry.clear();
        Vec::new()
    }

    /// Latency is the duration of one full analysis window in milliseconds,
    /// saturated to `u32::MAX` on overflow.
    fn latency_ms(&self) -> u32 {
        let rate = self.core.cfg.sample_rate.max(1) as u64;
        let window_ms = self.core.window_samples as u64 * 1000 / rate;
        u32::try_from(window_ms).unwrap_or(u32::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AfpError;
    use crate::Fingerprinter;
    use crate::neural::test_support::{passthrough_streaming, small_cfg, synth_audio};

    #[test]
    fn missing_model_propagates_through_streaming_constructor() {
        let cfg = NeuralEmbedderConfig::new("/definitely/does/not/exist.onnx");
        match StreamingNeuralEmbedder::new(cfg) {
            Err(AfpError::ModelNotFound(_)) => {}
            Err(e) => panic!("expected ModelNotFound, got {e:?}"),
            Ok(_) => panic!("expected ModelNotFound, got Ok"),
        }
    }

    #[test]
    fn empty_path_propagates_through_streaming_constructor() {
        let cfg = NeuralEmbedderConfig::new("");
        match StreamingNeuralEmbedder::new(cfg) {
            Err(AfpError::ModelNotFound(p)) => assert!(p.is_empty()),
            Err(e) => panic!("expected ModelNotFound, got {e:?}"),
            Ok(_) => panic!("expected ModelNotFound, got Ok"),
        }
    }

    /// Common fixture: a passthrough (identity) model on the small config.
    fn fixture() -> StreamingNeuralEmbedder {
        passthrough_streaming(small_cfg()).expect("fixture builds")
    }

    #[test]
    fn embedding_dim_matches_passthrough_shape() {
        let s = fixture();
        assert_eq!(s.embedding_dim(), 8 * 30);
    }

    #[test]
    fn latency_ms_matches_window_duration() {
        let s = fixture();
        assert_eq!(s.latency_ms(), 250);
    }

    #[test]
    fn empty_push_emits_nothing_and_does_not_buffer() {
        let mut s = fixture();
        let out = s.push(&[]);
        assert!(out.is_empty());
        assert_eq!(s.carry_len(), 0);
    }

    #[test]
    fn sub_window_push_only_buffers() {
        let mut s = fixture();
        let half = s.window_samples() / 2;
        let chunk = vec![0.0_f32; half];
        let out = s.push(&chunk);
        assert!(out.is_empty(), "no embedding before window full");
        assert_eq!(s.carry_len(), half);
        assert_eq!(s.samples_consumed(), 0);
    }

    #[test]
    fn one_full_window_emits_one_embedding() {
        let mut s = fixture();
        let chunk = synth_audio(1, s.window_samples(), 16_000);
        let out = s.push(&chunk);
        assert_eq!(out.len(), 1);
        let (t, vec) = &out[0];
        assert_eq!(t.0, 0);
        assert_eq!(vec.len(), s.embedding_dim());
        assert_eq!(s.carry_len(), 0);
    }

    #[test]
    fn timestamps_advance_by_hop_secs() {
        let mut s = fixture();
        let n = 4 * s.window_samples();
        let chunk = synth_audio(2, n, 16_000);
        let out = s.push(&chunk);
        assert_eq!(out.len(), 4);
        for (i, (t, _)) in out.iter().enumerate() {
            assert_eq!(t.0, (i as u64) * 250);
        }
    }

    #[test]
    fn carry_is_bounded_under_arbitrarily_large_pushes() {
        let mut s = fixture();
        let n = 100 * s.window_samples() + 17;
        let chunk = synth_audio(3, n, 16_000);
        let out = s.push(&chunk);
        assert_eq!(out.len(), 100);
        assert!(
            s.carry_len() < s.window_samples(),
            "carry {} >= window {}",
            s.carry_len(),
            s.window_samples(),
        );
        assert_eq!(s.carry_len(), 17);
    }

    #[test]
    fn flush_clears_carry_and_emits_nothing() {
        let mut s = fixture();
        let chunk = vec![0.0_f32; s.window_samples() / 2];
        s.push(&chunk);
        assert!(s.carry_len() > 0);
        let f = s.flush();
        assert!(f.is_empty());
        assert_eq!(s.carry_len(), 0);
    }

    #[test]
    fn reset_clears_carry_and_consumed_count() {
        let mut s = fixture();
        let chunk = synth_audio(4, 3 * s.window_samples(), 16_000);
        let _ = s.push(&chunk);
        assert!(s.samples_consumed() > 0);
        s.reset();
        assert_eq!(s.carry_len(), 0);
        assert_eq!(s.samples_consumed(), 0);
    }

    #[test]
    fn try_push_returns_same_as_push_on_success() {
        let mut a = fixture();
        let mut b = fixture();
        let chunk = synth_audio(5, 2 * a.window_samples(), 16_000);
        let out_push = a.push(&chunk);
        let out_try = b.try_push(&chunk).expect("try_push ok");
        assert_eq!(out_push.len(), out_try.len());
        for ((t1, v1), (t2, v2)) in out_push.iter().zip(out_try.iter()) {
            assert_eq!(t1.0, t2.0);
            assert_eq!(v1, v2);
        }
    }

    #[test]
    fn streaming_matches_offline_on_full_buffer() {
        let cfg = small_cfg();
        let n = 5 * (cfg.window_secs * cfg.sample_rate as f32) as usize;
        let audio = synth_audio(7, n, cfg.sample_rate);
        let mut off = crate::neural::test_support::passthrough_embedder(cfg.clone()).unwrap();
        let buf = crate::AudioBuffer {
            samples: &audio,
            rate: crate::SampleRate::HZ_16000,
        };
        let off_fp = off.extract(buf).unwrap();
        let mut s = passthrough_streaming(cfg).unwrap();
        let s_out = s.push(&audio);
        assert_eq!(off_fp.embeddings.len(), s_out.len());
        for (e, (t, v)) in off_fp.embeddings.iter().zip(s_out.iter()) {
            assert_eq!(e.t_start.0, t.0);
            assert_eq!(&e.vector, v);
        }
    }

    #[test]
    fn streaming_chunk_size_invariant() {
        let cfg = small_cfg();
        let n = 8 * (cfg.window_secs * cfg.sample_rate as f32) as usize + 23;
        let audio = synth_audio(11, n, cfg.sample_rate);
        // Reference: the whole buffer pushed at once.
        let reference = {
            let mut s = passthrough_streaming(cfg.clone()).unwrap();
            s.push(&audio)
        };
        // Any chunking of the same buffer must yield the same sequence.
        for chunk_size in [1, 7, 17, 256, 1024, 8_191] {
            let mut s = passthrough_streaming(cfg.clone()).unwrap();
            let mut collected = Vec::new();
            let mut start = 0;
            while start < audio.len() {
                let end = (start + chunk_size).min(audio.len());
                collected.extend(s.push(&audio[start..end]));
                start = end;
            }
            assert_eq!(
                collected.len(),
                reference.len(),
                "chunk_size={chunk_size}: count mismatch",
            );
            for ((t1, v1), (t2, v2)) in collected.iter().zip(reference.iter()) {
                assert_eq!(t1.0, t2.0, "chunk_size={chunk_size}: timestamp drift");
                assert_eq!(v1, v2, "chunk_size={chunk_size}: embedding drift");
            }
        }
    }

    #[test]
    fn overlapping_window_streaming_matches_offline() {
        let mut cfg = small_cfg();
        cfg.window_secs = 0.5;
        cfg.hop_secs = 0.25;
        let n = 4 * (cfg.window_secs * cfg.sample_rate as f32) as usize;
        let audio = synth_audio(13, n, cfg.sample_rate);
        let mut off = crate::neural::test_support::passthrough_embedder(cfg.clone()).unwrap();
        let buf = crate::AudioBuffer {
            samples: &audio,
            rate: crate::SampleRate::HZ_16000,
        };
        let off_fp = off.extract(buf).unwrap();
        let mut s = passthrough_streaming(cfg).unwrap();
        let s_out = s.push(&audio);
        assert_eq!(off_fp.embeddings.len(), s_out.len());
        for (e, (t, v)) in off_fp.embeddings.iter().zip(s_out.iter()) {
            assert_eq!(e.t_start.0, t.0);
            assert_eq!(&e.vector, v);
        }
    }

    #[test]
    fn l2_normalization_actually_normalises() {
        let mut cfg = small_cfg();
        cfg.l2_normalize = true;
        let mut s = passthrough_streaming(cfg).unwrap();
        let chunk = synth_audio(17, s.window_samples(), 16_000);
        let out = s.push(&chunk);
        let v = &out[0].1;
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "expected L2 norm ~1, got {norm}",);
    }

    #[test]
    fn try_push_with_callback_matches_try_push_collected() {
        let mut a = fixture();
        let mut b = fixture();
        let chunk = synth_audio(23, 3 * a.window_samples(), 16_000);
        let collected_via_vec = a.try_push(&chunk).unwrap();
        let mut collected_via_cb: Vec<(TimestampMs, Vec<f32>)> = Vec::new();
        let n = b
            .try_push_with(&chunk, |t, v| collected_via_cb.push((t, v.to_vec())))
            .unwrap();
        assert_eq!(n, collected_via_vec.len());
        assert_eq!(collected_via_cb.len(), collected_via_vec.len());
        for ((t1, v1), (t2, v2)) in collected_via_vec.iter().zip(collected_via_cb.iter()) {
            assert_eq!(t1.0, t2.0);
            assert_eq!(v1, v2);
        }
    }

    #[test]
    fn no_l2_normalization_preserves_magnitude() {
        let mut cfg = small_cfg();
        cfg.l2_normalize = false;
        let mut s = passthrough_streaming(cfg).unwrap();
        let chunk = synth_audio(19, s.window_samples(), 16_000);
        let out = s.push(&chunk);
        let v = &out[0].1;
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() > 0.5,
            "unexpected near-unit norm without L2: {norm}",
        );
    }
}