wavekat_turn/audio/pipecat.rs
1//! Pipecat Smart Turn v3 backend.
2//!
3//! Audio-based turn detection using the Smart Turn ONNX model.
4//! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be
5//! upsampled before feeding to this detector.
6//!
7//! # Model
8//!
9//! - Source: <https://huggingface.co/pipecat-ai/smart-turn-v3>
10//! - File: `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB)
11//! - License: BSD 2-Clause
12//!
13//! # Tensor specification
14//!
15//! | Role | Name | Shape | Dtype |
16//! |--------|------------------|----------------|---------|
17//! | Input | `input_features` | `[B, 80, 800]` | float32 |
18//! | Output | `logits` | `[B, 1]` | float32 |
19//!
20//! Despite the name, `logits` is a **sigmoid probability** P(turn complete)
21//! in [0, 1] — the sigmoid is fused into the model before ONNX export.
22//! Threshold: `probability > 0.5` → `TurnState::Finished`.
23//!
24//! # Mel-feature specification
25//!
26//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`:
27//!
28//! | Parameter | Value |
29//! |---------------|--------------------------------|
30//! | Sample rate | 16 000 Hz |
31//! | n_fft | 400 samples (25 ms) |
32//! | hop_length | 160 samples (10 ms) |
33//! | n_mels | 80 |
34//! | Freq range | 0 – 8 000 Hz |
35//! | Mel scale | Slaney (NOT HTK) |
36//! | Window | Hann (periodic, size 400) |
37//! | Pre-emphasis | None |
38//! | Log | log10 with ε = 1e-10 |
39//! | Normalization | clamp(max − 8), (x + 4) / 4 |
40//!
41//! # Audio buffer
42//!
43//! - Exactly **8 seconds = 128 000 samples** at 16 kHz.
44//! - Shorter input: **front-padded** with zeros (audio is at the end).
45//! - Longer input: the **last** 8 s is used (oldest samples discarded).
46
47use std::collections::VecDeque;
48use std::path::Path;
49use std::sync::Arc;
50use std::time::Instant;
51
52use ndarray::{s, Array2, Array3};
53use ort::{inputs, value::Tensor};
54use realfft::num_complex::Complex;
55use realfft::{RealFftPlanner, RealToComplex};
56
57use crate::onnx;
58use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
59
60// ---------------------------------------------------------------------------
61// Model variants
62// ---------------------------------------------------------------------------
63
/// Language for a WaveKat fine-tune of Pipecat Smart Turn.
///
/// Each variant resolves to a `<lang>/smart-turn-cpu.onnx` file inside the
/// language-agnostic HuggingFace repo `wavekat/smart-turn-ONNX`. The enum is
/// `#[non_exhaustive]` so adding a new language later is not a breaking
/// change for downstream matches.
#[cfg(feature = "wavekat-smart-turn")]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmartTurnLang {
    /// Mandarin Chinese.
    Zh,
}
77
/// Which set of Smart Turn weights to load.
///
/// All variants share the same architecture (Whisper-Tiny encoder + binary
/// classification head) and the same ONNX tensor contract — only the trained
/// weights differ, so variants are interchangeable at inference time.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmartTurnVariant {
    /// Upstream multilingual Pipecat Smart Turn v3 (embedded in the crate,
    /// so no network access is needed).
    PipecatV3,
    /// WaveKat language-specialized fine-tune. Resolved at runtime through
    /// HuggingFace (cached under `$HF_HOME/hub/`) and overridable via the
    /// `WAVEKAT_TURN_MODEL_DIR` environment variable.
    #[cfg(feature = "wavekat-smart-turn")]
    Wavekat(SmartTurnLang),
}
93
94// ---------------------------------------------------------------------------
95// Constants
96// ---------------------------------------------------------------------------
97
/// Sample rate the model expects; frames at any other rate are rejected.
const SAMPLE_RATE: u32 = 16_000;
/// FFT window size in samples (25 ms at 16 kHz).
const N_FFT: usize = 400;
/// STFT hop length in samples (10 ms at 16 kHz → 100 frames per second).
const HOP_LENGTH: usize = 160;
/// Number of mel filterbank bins.
const N_MELS: usize = 80;
/// Number of STFT frames the model expects (8 s × 100 fps
/// = RING_CAPACITY / HOP_LENGTH).
const N_FRAMES: usize = 800;
/// Number of FFT frequency bins: N_FFT/2 + 1 (real FFT output length).
const N_FREQS: usize = N_FFT / 2 + 1; // 201
/// Ring buffer capacity in samples: 8 s × 16 kHz.
const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000

/// Embedded ONNX model bytes, downloaded by build.rs at compile time.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
115
116// ---------------------------------------------------------------------------
117// Mel feature extractor
118// ---------------------------------------------------------------------------
119
/// Pre-computed Whisper-style log-mel feature extractor.
///
/// All expensive setup (filterbank, window, FFT plan) happens once in
/// [`MelExtractor::new`]; [`MelExtractor::extract`] is then called once per
/// inference and reuses the buffers below.
struct MelExtractor {
    /// Slaney-normalised mel filterbank, shape `[N_MELS, N_FREQS]`.
    mel_filters: Array2<f32>,
    /// Periodic Hann window of length `N_FFT`.
    hann_window: Vec<f32>,
    /// Reusable forward real-FFT plan (length `N_FFT`).
    fft: Arc<dyn RealToComplex<f32>>,
    /// Reusable scratch buffer for the FFT.
    fft_scratch: Vec<Complex<f32>>,
    /// Reusable output spectrum buffer (`N_FREQS` complex values).
    spectrum_buf: Vec<Complex<f32>>,
    /// Cached power spectrogram `[N_FREQS × (N_FRAMES + 1)]` from the
    /// previous call. Enables incremental STFT: frames unaffected by new
    /// audio are shifted out of this cache instead of recomputed.
    cached_power_spec: Option<Array2<f32>>,
    /// Cached *linear* (pre-log) mel spectrogram `[N_MELS × N_FRAMES]` from
    /// the previous call; lets the filterbank multiply only new columns.
    cached_mel_spec: Option<Array2<f32>>,
}
142
143impl MelExtractor {
144 fn new() -> Self {
145 let mel_filters = build_mel_filters(
146 SAMPLE_RATE as usize,
147 N_FFT,
148 N_MELS,
149 0.0,
150 SAMPLE_RATE as f32 / 2.0,
151 );
152 let hann_window = periodic_hann(N_FFT);
153
154 let mut planner = RealFftPlanner::<f32>::new();
155 let fft = planner.plan_fft_forward(N_FFT);
156 let fft_scratch = fft.make_scratch_vec();
157 let spectrum_buf = fft.make_output_vec();
158
159 Self {
160 mel_filters,
161 hann_window,
162 fft,
163 fft_scratch,
164 spectrum_buf,
165 cached_power_spec: None,
166 cached_mel_spec: None,
167 }
168 }
169
170 /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly
171 /// `RING_CAPACITY` samples of 16 kHz mono audio.
172 ///
173 /// `shift_frames` is how many STFT frames worth of new audio were added
174 /// since the last call. When a valid cache exists and `shift_frames` is
175 /// in range, only the last `shift_frames` columns of the power spectrogram
176 /// are recomputed; the rest are copied from the shifted cache.
177 fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
178 debug_assert_eq!(audio.len(), RING_CAPACITY);
179
180 // ---- Center-pad: N_FFT/2 reflect samples on each side → 128 400 samples ----
181 // Matches WhisperFeatureExtractor: np.pad(waveform, n_fft//2, mode="reflect").
182 // Reflect (not zero) padding ensures the boundary frames match Python exactly.
183 // Gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
184 let pad = N_FFT / 2; // 200
185 let n = audio.len(); // 128 000
186 let mut padded = vec![0.0f32; pad + n + pad];
187 padded[pad..pad + n].copy_from_slice(audio);
188 // Left reflect: padded[0..pad] = audio[pad..1] reversed (exclude edge)
189 for i in 0..pad {
190 padded[i] = audio[pad - i];
191 }
192 // Right reflect: padded[pad+n..pad+n+pad] = audio[n-2..n-2-pad] reversed
193 for i in 0..pad {
194 padded[pad + n + i] = audio[n - 2 - i];
195 }
196
197 // n_total = (128 400 − 400) / 160 + 1 = 801
198 let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
199
200 // ---- Incremental STFT ----
201 // If we have a cached power spec and shift_frames < n_total_frames,
202 // reuse the unchanged frames by shifting the cache left and only
203 // computing the `shift_frames` new columns at the end.
204 let first_new_frame = match &self.cached_power_spec {
205 Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => {
206 let kept = n_total_frames - shift_frames;
207 let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
208 power_spec
209 .slice_mut(s![.., ..kept])
210 .assign(&cached.slice(s![.., shift_frames..]));
211 self.cached_power_spec = Some(power_spec);
212 kept // only compute frames [kept..n_total_frames]
213 }
214 _ => {
215 self.cached_power_spec = Some(Array2::<f32>::zeros((N_FREQS, n_total_frames)));
216 0 // cold start: compute all frames
217 }
218 };
219
220 let power_spec = self.cached_power_spec.as_mut().unwrap();
221 let mut frame_buf = vec![0.0f32; N_FFT];
222
223 for frame_idx in first_new_frame..n_total_frames {
224 let start = frame_idx * HOP_LENGTH;
225 // Apply periodic Hann window
226 for (i, (&s, &w)) in padded[start..start + N_FFT]
227 .iter()
228 .zip(self.hann_window.iter())
229 .enumerate()
230 {
231 frame_buf[i] = s * w;
232 }
233
234 self.fft
235 .process_with_scratch(
236 &mut frame_buf,
237 &mut self.spectrum_buf,
238 &mut self.fft_scratch,
239 )
240 .expect("FFT failed: internal buffer size mismatch");
241
242 for (k, c) in self.spectrum_buf.iter().enumerate() {
243 power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
244 }
245 }
246
247 // Take first N_FRAMES columns (drop the trailing frame)
248 let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]);
249
250 // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ----
251 // Reuse the cached mel columns for the unchanged frames; only multiply
252 // the new power-spectrum columns against the filterbank.
253 let mel_spec = match &self.cached_mel_spec {
254 Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => {
255 let kept = N_FRAMES - shift_frames;
256 let mut ms = Array2::<f32>::zeros((N_MELS, N_FRAMES));
257 // Shift old columns left
258 ms.slice_mut(s![.., ..kept])
259 .assign(&cached.slice(s![.., shift_frames..]));
260 // Apply filterbank only to the new power-spectrum columns
261 let new_power = power_spec_view.slice(s![.., kept..]);
262 ms.slice_mut(s![.., kept..])
263 .assign(&self.mel_filters.dot(&new_power));
264 ms
265 }
266 _ => self.mel_filters.dot(&power_spec_view),
267 };
268 self.cached_mel_spec = Some(mel_spec.clone());
269
270 // ---- Log10 with floor at 1e-10 ----
271 let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());
272
273 // ---- Dynamic range compression and normalization ----
274 // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4
275 let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
276 log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);
277
278 log_mel
279 }
280
281 /// Invalidate all caches (call on reset).
282 fn invalidate_cache(&mut self) {
283 self.cached_power_spec = None;
284 self.cached_mel_spec = None;
285 }
286}
287
288// ---------------------------------------------------------------------------
289// Mel filterbank construction — Slaney scale, slaney norm
290// ---------------------------------------------------------------------------
291
/// Convert Hz to mel on the Slaney/librosa scale (NOT HTK).
///
/// Linear below 1 kHz at 200/3 Hz per mel, logarithmic above, matching
/// `librosa.hz_to_mel(htk=False)`.
fn hz_to_mel(hz: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0; // Hz per mel in the linear region
    const MIN_LOG_HZ: f32 = 1000.0; // crossover into the log region
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0
    let logstep = (6.4_f32).ln() / 27.0; // ≈ 0.068752 per mel step
    if hz < MIN_LOG_HZ {
        hz / F_SP
    } else {
        MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep
    }
}
305
/// Convert mel back to Hz (Slaney scale); the exact inverse of `hz_to_mel`.
fn mel_to_hz(mel: f32) -> f32 {
    const F_SP: f32 = 200.0 / 3.0;
    const MIN_LOG_HZ: f32 = 1000.0;
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0
    let logstep = (6.4_f32).ln() / 27.0;
    match mel {
        // Linear region below 15 mel (< 1 kHz).
        m if m < MIN_LOG_MEL => m * F_SP,
        // Log region at and above 15 mel.
        m => MIN_LOG_HZ * ((m - MIN_LOG_MEL) * logstep).exp(),
    }
}
318
319/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs].
320///
321/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax,
322/// norm="slaney", dtype=float32)` which is what HuggingFace's
323/// `WhisperFeatureExtractor` uses internally.
324fn build_mel_filters(
325 sr: usize,
326 n_fft: usize,
327 n_mels: usize,
328 f_min: f32,
329 f_max: f32,
330) -> Array2<f32> {
331 let n_freqs = n_fft / 2 + 1;
332
333 // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, …
334 let fft_freqs: Vec<f32> = (0..n_freqs)
335 .map(|i| i as f32 * sr as f32 / n_fft as f32)
336 .collect();
337
338 // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge)
339 let mel_min = hz_to_mel(f_min);
340 let mel_max = hz_to_mel(f_max);
341 let mel_pts: Vec<f32> = (0..=(n_mels + 1))
342 .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
343 .collect();
344 let hz_pts: Vec<f32> = mel_pts.iter().map(|&m| mel_to_hz(m)).collect();
345
346 // Build triangular filters with Slaney normalisation
347 let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
348 for m in 0..n_mels {
349 let f_left = hz_pts[m];
350 let f_center = hz_pts[m + 1];
351 let f_right = hz_pts[m + 2];
352 // Slaney norm: 2 / (right_hz − left_hz)
353 let enorm = 2.0 / (f_right - f_left);
354
355 for (k, &f) in fft_freqs.iter().enumerate() {
356 let w = if f >= f_left && f <= f_center {
357 (f - f_left) / (f_center - f_left)
358 } else if f > f_center && f <= f_right {
359 (f_right - f) / (f_right - f_center)
360 } else {
361 0.0
362 };
363 filters[[m, k]] = w * enorm;
364 }
365 }
366 filters
367}
368
369// ---------------------------------------------------------------------------
370// Hann window
371// ---------------------------------------------------------------------------
372
/// Periodic Hann window of length `n`, matching
/// `torch.hann_window(n, periodic=True)`.
///
/// `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n. The divisor is `n`
/// rather than `n − 1` (the symmetric variant), which is what STFT analysis
/// with hop-based overlap expects.
fn periodic_hann(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let value = 0.5 * (1.0 - (2.0 * PI * k as f32 / n as f32).cos());
        window.push(value);
    }
    window
}
383
384// ---------------------------------------------------------------------------
385// Audio preparation
386// ---------------------------------------------------------------------------
387
388/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples.
389///
390/// - Longer: keep the **last** 8 s (discard oldest).
391/// - Shorter: **front-pad** with zeros so audio is right-aligned.
392fn prepare_audio(samples: &[f32]) -> Vec<f32> {
393 match samples.len().cmp(&RING_CAPACITY) {
394 std::cmp::Ordering::Equal => samples.to_vec(),
395 std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(),
396 std::cmp::Ordering::Less => {
397 let mut out = vec![0.0f32; RING_CAPACITY - samples.len()];
398 out.extend_from_slice(samples);
399 out
400 }
401 }
402}
403
404// ---------------------------------------------------------------------------
405// PipecatSmartTurn
406// ---------------------------------------------------------------------------
407
/// Pipecat Smart Turn v3 detector.
///
/// Wraps the Smart Turn v3 architecture (Whisper-Tiny encoder + binary
/// classification head). Use [`new`] for the embedded upstream weights, or
/// [`with_variant`] to pick a WaveKat fine-tune at runtime.
///
/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires
/// end-of-speech to get a [`TurnPrediction`].
///
/// # Usage with VAD
///
/// ```no_run
/// # #[cfg(feature = "pipecat")]
/// # {
/// use wavekat_turn::audio::PipecatSmartTurn;
/// use wavekat_turn::AudioTurnDetector;
///
/// let mut detector = PipecatSmartTurn::new().unwrap();
/// // ... feed frames via push_audio ...
/// let prediction = detector.predict().unwrap();
/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
/// # }
/// ```
///
/// [`new`]: Self::new
/// [`with_variant`]: Self::with_variant
/// [`push_audio`]: AudioTurnDetector::push_audio
/// [`predict`]: AudioTurnDetector::predict
pub struct PipecatSmartTurn {
    // ONNX Runtime session holding the loaded Smart Turn weights.
    session: ort::session::Session,
    // Rolling window of the most recent ≤ 8 s of 16 kHz samples.
    ring_buffer: VecDeque<f32>,
    // Whisper-style log-mel extractor with incremental STFT/mel caches.
    mel: MelExtractor,
    /// Counts samples pushed since the last `predict()` call.
    /// Used to compute `shift_frames` for incremental STFT.
    samples_since_predict: usize,
}
445
// SAFETY: ort::Session is Send in ort 2.x, and the remaining fields
// (VecDeque<f32>, MelExtractor, usize) carry no thread-affine state that the
// visible code depends on, so moving the detector between threads is sound.
// The Sync impl leans on the fact that this type currently exposes only
// `&mut self` methods: a shared `&PipecatSmartTurn` gives no way to touch
// the session or caches concurrently.
// NOTE(review): this argument silently breaks if a `&self` method is ever
// added, and it assumes the realfft plan behind `Arc<dyn RealToComplex>` is
// safe to share — confirm against the pinned ort/realfft versions, and drop
// these impls entirely if the compiler can auto-derive them.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
450
451impl PipecatSmartTurn {
452 /// Load the upstream Pipecat Smart Turn v3.2 model embedded at compile time.
453 ///
454 /// Equivalent to [`with_variant(SmartTurnVariant::PipecatV3)`](Self::with_variant).
455 pub fn new() -> Result<Self, TurnError> {
456 let session = onnx::session_from_memory(MODEL_BYTES)?;
457 Ok(Self::build(session))
458 }
459
460 /// Load a specific variant of the Smart Turn model.
461 ///
462 /// - [`SmartTurnVariant::PipecatV3`] uses the embedded ONNX bytes — no
463 /// network required.
464 /// - [`SmartTurnVariant::Wavekat`] (when the `wavekat-smart-turn` feature
465 /// is enabled) downloads the corresponding language file from the
466 /// `wavekat/smart-turn-ONNX` HuggingFace repo and caches it under
467 /// `$HF_HOME/hub/`. Set `WAVEKAT_TURN_MODEL_DIR` to point at a
468 /// pre-populated directory (offline / CI use).
469 pub fn with_variant(variant: SmartTurnVariant) -> Result<Self, TurnError> {
470 match variant {
471 SmartTurnVariant::PipecatV3 => Self::new(),
472 #[cfg(feature = "wavekat-smart-turn")]
473 SmartTurnVariant::Wavekat(lang) => {
474 let path = crate::audio::wavekat_download::resolve_model(lang)?;
475 Self::from_file(path)
476 }
477 }
478 }
479
480 /// Load a model from a custom path on disk.
481 ///
482 /// Useful for CI environments that supply the model file separately, or
483 /// for evaluating fine-tuned variants without recompiling.
484 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
485 let session = onnx::session_from_file(path)?;
486 Ok(Self::build(session))
487 }
488
489 fn build(session: ort::session::Session) -> Self {
490 Self {
491 session,
492 ring_buffer: VecDeque::with_capacity(RING_CAPACITY),
493 mel: MelExtractor::new(),
494 samples_since_predict: 0,
495 }
496 }
497}
498
499impl AudioTurnDetector for PipecatSmartTurn {
500 /// Append audio to the internal ring buffer.
501 ///
502 /// Frames with a sample rate other than 16 kHz are silently dropped.
503 /// The ring buffer holds at most 8 s; older samples are evicted.
504 fn push_audio(&mut self, frame: &AudioFrame) {
505 if frame.sample_rate() != SAMPLE_RATE {
506 return;
507 }
508 let samples = frame.samples();
509 // Evict oldest samples to make room
510 let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY);
511 if overflow > 0 {
512 self.ring_buffer.drain(..overflow);
513 }
514 self.ring_buffer.extend(samples.iter().copied());
515 self.samples_since_predict += samples.len();
516 }
517
518 /// Run inference on the buffered audio.
519 ///
520 /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts
521 /// Whisper log-mel features, and runs ONNX inference.
522 fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
523 let t_start = Instant::now();
524
525 // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples
526 let shift_frames = self.samples_since_predict / HOP_LENGTH;
527 self.samples_since_predict = 0;
528
529 let buffered: Vec<f32> = self.ring_buffer.iter().copied().collect();
530 let audio = prepare_audio(&buffered);
531 let t_after_audio_prep = Instant::now();
532
533 // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental)
534 let mel_spec = self.mel.extract(&audio, shift_frames);
535 let t_after_mel = Instant::now();
536
537 // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference
538 let (raw, _) = mel_spec.into_raw_vec_and_offset();
539 let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
540 .expect("internal: mel output has wrong element count");
541
542 let input_tensor = Tensor::from_array(input_array)
543 .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
544
545 let outputs = self
546 .session
547 .run(inputs!["input_features" => input_tensor])
548 .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
549 let t_after_onnx = Instant::now();
550
551 // Extract sigmoid probability from the "logits" output
552 let output = outputs
553 .get("logits")
554 .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
555 let (_, data): (_, &[f32]) = output
556 .try_extract_tensor()
557 .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
558 let probability = *data
559 .first()
560 .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;
561
562 let latency_ms = t_start.elapsed().as_millis() as u64;
563
564 let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
565 let stage_times = vec![
566 StageTiming {
567 name: "audio_prep",
568 us: us(t_start, t_after_audio_prep),
569 },
570 StageTiming {
571 name: "mel",
572 us: us(t_after_audio_prep, t_after_mel),
573 },
574 StageTiming {
575 name: "onnx",
576 us: us(t_after_mel, t_after_onnx),
577 },
578 ];
579
580 // probability = P(turn complete); > 0.5 means the speaker has finished
581 let (state, confidence) = if probability > 0.5 {
582 (TurnState::Finished, probability)
583 } else {
584 (TurnState::Unfinished, 1.0 - probability)
585 };
586
587 let audio_duration_ms = (self.ring_buffer.len() as u64 * 1000) / SAMPLE_RATE as u64;
588
589 Ok(TurnPrediction {
590 state,
591 confidence,
592 latency_ms,
593 stage_times,
594 audio_duration_ms,
595 })
596 }
597
598 /// Clear the ring buffer. Call at the start of each new speech turn.
599 fn reset(&mut self) {
600 self.ring_buffer.clear();
601 self.samples_since_predict = 0;
602 self.mel.invalidate_cache();
603 }
604}
605
606// ---------------------------------------------------------------------------
607// Mel comparison tests (unit tests — need access to private MelExtractor)
608// ---------------------------------------------------------------------------
609
#[cfg(test)]
mod mel_tests {
    //! Mel comparison tests (unit tests — they need access to the private
    //! `MelExtractor`): compare the Rust log-mel output against reference
    //! spectrograms produced by Python's `WhisperFeatureExtractor`, stored
    //! as `.npy` fixtures generated by `scripts/gen_reference.py`.

    use std::path::{Path, PathBuf};

    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;

    use super::{prepare_audio, MelExtractor, RING_CAPACITY, SAMPLE_RATE};

    /// Max allowed element-wise absolute difference between Rust and Python mel.
    const MEL_TOLERANCE: f32 = 0.05;

    /// Resolve the repo-root `tests/fixtures` directory, two levels above
    /// this crate's manifest.
    fn fixtures_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap() // crates/
            .parent()
            .unwrap() // repo root
            .join("tests/fixtures")
    }

    /// Load 16 kHz mono WAV as f32 in [-1, 1], normalised the same way as
    /// Python's soundfile (divide by 32768, not i16::MAX) so both pipelines
    /// see identical input.
    fn load_wav_f32(path: &Path) -> Vec<f32> {
        let mut reader = hound::WavReader::open(path)
            .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, SAMPLE_RATE, "expected 16 kHz");
        assert_eq!(spec.channels, 1, "expected mono");
        match spec.sample_format {
            hound::SampleFormat::Int => reader
                .samples::<i16>()
                .map(|s| s.unwrap() as f32 / 32768.0)
                .collect(),
            hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
        }
    }

    /// Load the Python-generated reference mel (`<clip>.mel.npy`), with a
    /// helpful message when the fixture has not been generated yet.
    fn load_python_mel(clip: &str) -> Array2<f32> {
        let path = fixtures_dir().join(format!("{clip}.mel.npy"));
        let file = std::fs::File::open(&path).unwrap_or_else(|_| {
            panic!(
                "missing {}: run `python scripts/gen_reference.py` first",
                path.display()
            )
        });
        Array2::<f32>::read_npy(file).expect("failed to parse .npy")
    }

    /// Summary statistics of the element-wise |Rust − Python| difference.
    struct MelDiff {
        // Largest single element-wise difference.
        max_diff: f32,
        // Mean element-wise difference over the whole spectrogram.
        mean_diff: f32,
        /// (mel_bin, frame) of the single largest diff
        max_at: (usize, usize),
        /// fraction of elements with diff > 0.01
        outlier_frac: f32,
    }

    /// Run the Rust extractor on `clip` (cold start, no incremental cache)
    /// and diff the result element-wise against the Python reference mel.
    fn compare_mel(clip: &str) -> MelDiff {
        let samples = load_wav_f32(&fixtures_dir().join(clip));
        let audio = prepare_audio(&samples);
        assert_eq!(audio.len(), RING_CAPACITY);

        // shift_frames = 0 → full recompute; matches the Python one-shot run.
        let mut extractor = MelExtractor::new();
        let rust_mel = extractor.extract(&audio, 0);
        let python_mel = load_python_mel(clip);

        assert_eq!(
            rust_mel.shape(),
            python_mel.shape(),
            "{clip}: mel shape mismatch"
        );

        let shape = rust_mel.shape();
        let (n_mels, n_frames) = (shape[0], shape[1]);

        let mut max_diff = 0.0f32;
        let mut max_at = (0, 0);
        let mut sum_diff = 0.0f32;
        let mut outliers = 0usize;

        for m in 0..n_mels {
            for t in 0..n_frames {
                let d = (rust_mel[[m, t]] - python_mel[[m, t]]).abs();
                sum_diff += d;
                if d > max_diff {
                    max_diff = d;
                    max_at = (m, t);
                }
                if d > 0.01 {
                    outliers += 1;
                }
            }
        }

        let total = (n_mels * n_frames) as f32;
        MelDiff {
            max_diff,
            mean_diff: sum_diff / total,
            max_at,
            outlier_frac: outliers as f32 / total,
        }
    }

    /// Print a markdown table of mel-level diffs between Rust and Python.
    /// Run with: `make mel`
    #[test]
    #[ignore]
    fn mel_report() {
        let clips = ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"];

        println!();
        println!("MEL_TOLERANCE={MEL_TOLERANCE}");
        println!();
        println!("| Clip | Max Diff | Mean Diff | Max at (mel,frame) | Outliers >0.01 | Status |");
        println!("|------|----------|-----------|---------------------|----------------|--------|");
        for clip in clips {
            let d = compare_mel(clip);
            let status = if d.max_diff <= MEL_TOLERANCE {
                "PASS"
            } else {
                "FAIL"
            };
            println!(
                "| `{clip}` | {:.6} | {:.6} | ({},{}) | {:.2}% | {status} |",
                d.max_diff,
                d.mean_diff,
                d.max_at.0,
                d.max_at.1,
                d.outlier_frac * 100.0,
            );
        }
        println!();
    }
}