Skip to main content

wavekat_turn/audio/
pipecat.rs

1//! Pipecat Smart Turn v3 backend.
2//!
3//! Audio-based turn detection using the Smart Turn ONNX model.
4//! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be
5//! upsampled before feeding to this detector.
6//!
7//! # Model
8//!
9//! - Source:  <https://huggingface.co/pipecat-ai/smart-turn-v3>
10//! - File:    `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB)
11//! - License: BSD 2-Clause
12//!
13//! # Tensor specification
14//!
15//! | Role   | Name             | Shape          | Dtype   |
16//! |--------|------------------|----------------|---------|
17//! | Input  | `input_features` | `[B, 80, 800]` | float32 |
18//! | Output | `logits`         | `[B, 1]`       | float32 |
19//!
20//! Despite the name, `logits` is a **sigmoid probability** P(turn complete)
21//! in [0, 1] — the sigmoid is fused into the model before ONNX export.
22//! Threshold: `probability > 0.5` → `TurnState::Finished`.
23//!
24//! # Mel-feature specification
25//!
26//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`:
27//!
28//! | Parameter     | Value                          |
29//! |---------------|--------------------------------|
30//! | Sample rate   | 16 000 Hz                      |
31//! | n_fft         | 400 samples (25 ms)            |
32//! | hop_length    | 160 samples (10 ms)            |
33//! | n_mels        | 80                             |
34//! | Freq range    | 0 – 8 000 Hz                   |
35//! | Mel scale     | Slaney (NOT HTK)               |
36//! | Window        | Hann (periodic, size 400)      |
37//! | Pre-emphasis  | None                           |
38//! | Log           | log10 with ε = 1e-10           |
//! | Normalization | max(x, global_max − 8), then (x + 4) / 4 |
40//!
41//! # Audio buffer
42//!
43//! - Exactly **8 seconds = 128 000 samples** at 16 kHz.
44//! - Shorter input: **front-padded** with zeros (audio is at the end).
45//! - Longer input: the **last** 8 s is used (oldest samples discarded).
46
47use std::collections::VecDeque;
48use std::path::Path;
49use std::sync::Arc;
50use std::time::Instant;
51
52use ndarray::{s, Array2, Array3};
53use ort::{inputs, value::Tensor};
54use realfft::num_complex::Complex;
55use realfft::{RealFftPlanner, RealToComplex};
56
57use crate::onnx;
58use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
59
60// ---------------------------------------------------------------------------
61// Model variants
62// ---------------------------------------------------------------------------
63
/// Target language of a WaveKat-trained Smart Turn fine-tune.
///
/// Every language maps to a `<lang>/smart-turn-cpu.onnx` file inside the
/// shared, language-agnostic HuggingFace repository `wavekat/smart-turn-ONNX`.
/// The enum is `#[non_exhaustive]` so that shipping another language later is
/// not a breaking change for downstream code.
#[cfg(feature = "wavekat-smart-turn")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SmartTurnLang {
    /// Mandarin Chinese.
    Zh,
}
77
/// Selects which Smart Turn weights a detector loads.
///
/// Every variant uses the same network (a Whisper-Tiny encoder feeding a
/// binary classification head) and the same ONNX input/output contract; the
/// variants differ only in their trained weights.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SmartTurnVariant {
    /// Upstream multilingual Pipecat Smart Turn v3, compiled into the crate.
    PipecatV3,
    /// A WaveKat language-specific fine-tune. Fetched from HuggingFace at
    /// runtime (cached under `$HF_HOME/hub/`); the location can be overridden
    /// through `WAVEKAT_TURN_MODEL_DIR`.
    #[cfg(feature = "wavekat-smart-turn")]
    Wavekat(SmartTurnLang),
}
93
94// ---------------------------------------------------------------------------
95// Constants
96// ---------------------------------------------------------------------------
97
/// Sample rate the model expects, in Hz.
const SAMPLE_RATE: u32 = 16_000;
/// Ring buffer capacity in samples: 8 s at `SAMPLE_RATE` (128 000).
const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize;
/// STFT window / FFT size in samples (25 ms at 16 kHz).
const N_FFT: usize = 400;
/// Number of one-sided FFT bins, `N_FFT / 2 + 1` (201).
const N_FREQS: usize = N_FFT / 2 + 1;
/// STFT hop in samples (10 ms at 16 kHz).
const HOP_LENGTH: usize = 160;
/// STFT frames fed to the model: 8 s at 100 frames/s (800).
const N_FRAMES: usize = RING_CAPACITY / HOP_LENGTH;
/// Number of mel filterbank bins.
const N_MELS: usize = 80;
112
/// Upstream Pipecat Smart Turn v3.2 ONNX model bytes, embedded at compile time.
///
/// build.rs downloads `smart-turn-v3.2-cpu.onnx` into `OUT_DIR` during the
/// build; `include_bytes!` then bakes the file into the binary, which is why
/// [`PipecatSmartTurn::new`] can build its session straight from memory.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
115
116// ---------------------------------------------------------------------------
117// Mel feature extractor
118// ---------------------------------------------------------------------------
119
120/// Pre-computed Whisper-style log-mel feature extractor.
121///
122/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`].
123/// [`MelExtractor::extract`] is then called per inference.
124struct MelExtractor {
125    /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS].
126    mel_filters: Array2<f32>,
127    /// Periodic Hann window of length N_FFT.
128    hann_window: Vec<f32>,
129    /// Reusable forward real FFT plan.
130    fft: Arc<dyn RealToComplex<f32>>,
131    /// Reusable scratch buffer for the FFT.
132    fft_scratch: Vec<Complex<f32>>,
133    /// Reusable output spectrum buffer (N_FREQS complex values).
134    spectrum_buf: Vec<Complex<f32>>,
135    /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call.
136    /// Enables incremental STFT: only new frames are recomputed.
137    cached_power_spec: Option<Array2<f32>>,
138    /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call.
139    /// Enables incremental mel filterbank: only new columns are recomputed.
140    cached_mel_spec: Option<Array2<f32>>,
141}
142
143impl MelExtractor {
144    fn new() -> Self {
145        let mel_filters = build_mel_filters(
146            SAMPLE_RATE as usize,
147            N_FFT,
148            N_MELS,
149            0.0,
150            SAMPLE_RATE as f32 / 2.0,
151        );
152        let hann_window = periodic_hann(N_FFT);
153
154        let mut planner = RealFftPlanner::<f32>::new();
155        let fft = planner.plan_fft_forward(N_FFT);
156        let fft_scratch = fft.make_scratch_vec();
157        let spectrum_buf = fft.make_output_vec();
158
159        Self {
160            mel_filters,
161            hann_window,
162            fft,
163            fft_scratch,
164            spectrum_buf,
165            cached_power_spec: None,
166            cached_mel_spec: None,
167        }
168    }
169
170    /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly
171    /// `RING_CAPACITY` samples of 16 kHz mono audio.
172    ///
173    /// `shift_frames` is how many STFT frames worth of new audio were added
174    /// since the last call. When a valid cache exists and `shift_frames` is
175    /// in range, only the last `shift_frames` columns of the power spectrogram
176    /// are recomputed; the rest are copied from the shifted cache.
177    fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
178        debug_assert_eq!(audio.len(), RING_CAPACITY);
179
180        // ---- Center-pad: N_FFT/2 reflect samples on each side → 128 400 samples ----
181        // Matches WhisperFeatureExtractor: np.pad(waveform, n_fft//2, mode="reflect").
182        // Reflect (not zero) padding ensures the boundary frames match Python exactly.
183        // Gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
184        let pad = N_FFT / 2; // 200
185        let n = audio.len(); // 128 000
186        let mut padded = vec![0.0f32; pad + n + pad];
187        padded[pad..pad + n].copy_from_slice(audio);
188        // Left reflect: padded[0..pad] = audio[pad..1] reversed (exclude edge)
189        for i in 0..pad {
190            padded[i] = audio[pad - i];
191        }
192        // Right reflect: padded[pad+n..pad+n+pad] = audio[n-2..n-2-pad] reversed
193        for i in 0..pad {
194            padded[pad + n + i] = audio[n - 2 - i];
195        }
196
197        // n_total = (128 400 − 400) / 160 + 1 = 801
198        let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
199
200        // ---- Incremental STFT ----
201        // If we have a cached power spec and shift_frames < n_total_frames,
202        // reuse the unchanged frames by shifting the cache left and only
203        // computing the `shift_frames` new columns at the end.
204        let first_new_frame = match &self.cached_power_spec {
205            Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => {
206                let kept = n_total_frames - shift_frames;
207                let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
208                power_spec
209                    .slice_mut(s![.., ..kept])
210                    .assign(&cached.slice(s![.., shift_frames..]));
211                self.cached_power_spec = Some(power_spec);
212                kept // only compute frames [kept..n_total_frames]
213            }
214            _ => {
215                self.cached_power_spec = Some(Array2::<f32>::zeros((N_FREQS, n_total_frames)));
216                0 // cold start: compute all frames
217            }
218        };
219
220        let power_spec = self.cached_power_spec.as_mut().unwrap();
221        let mut frame_buf = vec![0.0f32; N_FFT];
222
223        for frame_idx in first_new_frame..n_total_frames {
224            let start = frame_idx * HOP_LENGTH;
225            // Apply periodic Hann window
226            for (i, (&s, &w)) in padded[start..start + N_FFT]
227                .iter()
228                .zip(self.hann_window.iter())
229                .enumerate()
230            {
231                frame_buf[i] = s * w;
232            }
233
234            self.fft
235                .process_with_scratch(
236                    &mut frame_buf,
237                    &mut self.spectrum_buf,
238                    &mut self.fft_scratch,
239                )
240                .expect("FFT failed: internal buffer size mismatch");
241
242            for (k, c) in self.spectrum_buf.iter().enumerate() {
243                power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
244            }
245        }
246
247        // Take first N_FRAMES columns (drop the trailing frame)
248        let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]);
249
250        // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ----
251        // Reuse the cached mel columns for the unchanged frames; only multiply
252        // the new power-spectrum columns against the filterbank.
253        let mel_spec = match &self.cached_mel_spec {
254            Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => {
255                let kept = N_FRAMES - shift_frames;
256                let mut ms = Array2::<f32>::zeros((N_MELS, N_FRAMES));
257                // Shift old columns left
258                ms.slice_mut(s![.., ..kept])
259                    .assign(&cached.slice(s![.., shift_frames..]));
260                // Apply filterbank only to the new power-spectrum columns
261                let new_power = power_spec_view.slice(s![.., kept..]);
262                ms.slice_mut(s![.., kept..])
263                    .assign(&self.mel_filters.dot(&new_power));
264                ms
265            }
266            _ => self.mel_filters.dot(&power_spec_view),
267        };
268        self.cached_mel_spec = Some(mel_spec.clone());
269
270        // ---- Log10 with floor at 1e-10 ----
271        let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());
272
273        // ---- Dynamic range compression and normalization ----
274        // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4
275        let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
276        log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);
277
278        log_mel
279    }
280
281    /// Invalidate all caches (call on reset).
282    fn invalidate_cache(&mut self) {
283        self.cached_power_spec = None;
284        self.cached_mel_spec = None;
285    }
286}
287
288// ---------------------------------------------------------------------------
289// Mel filterbank construction — Slaney scale, slaney norm
290// ---------------------------------------------------------------------------
291
/// Map a frequency in Hz onto the Slaney (librosa default) mel scale.
///
/// The scale is linear below 1 kHz (3 mel per 200 Hz) and logarithmic above
/// it, where every 27 mel span a frequency factor of 6.4. This is *not* the
/// HTK formula `2595 · log10(1 + f / 700)`.
fn hz_to_mel(hz: f32) -> f32 {
    const HZ_PER_MEL: f32 = 200.0 / 3.0;
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = BREAK_HZ / HZ_PER_MEL; // 15 mel

    if hz < BREAK_HZ {
        return hz / HZ_PER_MEL;
    }
    // ln(6.4) / 27 ≈ 0.068752: natural-log step per mel in the log region.
    let log_step = 6.4_f32.ln() / 27.0;
    BREAK_MEL + (hz / BREAK_HZ).ln() / log_step
}
305
/// Map a Slaney-scale mel value back to Hz (inverse of the Slaney mapping).
///
/// Linear below 15 mel (200/3 Hz per mel), exponential above it.
fn mel_to_hz(mel: f32) -> f32 {
    const HZ_PER_MEL: f32 = 200.0 / 3.0;
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = BREAK_HZ / HZ_PER_MEL;

    if mel < BREAK_MEL {
        return mel * HZ_PER_MEL;
    }
    let log_step = 6.4_f32.ln() / 27.0;
    BREAK_HZ * ((mel - BREAK_MEL) * log_step).exp()
}
318
319/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs].
320///
321/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax,
322///   norm="slaney", dtype=float32)` which is what HuggingFace's
323/// `WhisperFeatureExtractor` uses internally.
324fn build_mel_filters(
325    sr: usize,
326    n_fft: usize,
327    n_mels: usize,
328    f_min: f32,
329    f_max: f32,
330) -> Array2<f32> {
331    let n_freqs = n_fft / 2 + 1;
332
333    // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, …
334    let fft_freqs: Vec<f32> = (0..n_freqs)
335        .map(|i| i as f32 * sr as f32 / n_fft as f32)
336        .collect();
337
338    // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge)
339    let mel_min = hz_to_mel(f_min);
340    let mel_max = hz_to_mel(f_max);
341    let mel_pts: Vec<f32> = (0..=(n_mels + 1))
342        .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
343        .collect();
344    let hz_pts: Vec<f32> = mel_pts.iter().map(|&m| mel_to_hz(m)).collect();
345
346    // Build triangular filters with Slaney normalisation
347    let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
348    for m in 0..n_mels {
349        let f_left = hz_pts[m];
350        let f_center = hz_pts[m + 1];
351        let f_right = hz_pts[m + 2];
352        // Slaney norm: 2 / (right_hz − left_hz)
353        let enorm = 2.0 / (f_right - f_left);
354
355        for (k, &f) in fft_freqs.iter().enumerate() {
356            let w = if f >= f_left && f <= f_center {
357                (f - f_left) / (f_center - f_left)
358            } else if f > f_center && f <= f_right {
359                (f_right - f) / (f_right - f_center)
360            } else {
361                0.0
362            };
363            filters[[m, k]] = w * enorm;
364        }
365    }
366    filters
367}
368
369// ---------------------------------------------------------------------------
370// Hann window
371// ---------------------------------------------------------------------------
372
/// Periodic Hann window of length `n`, equivalent to
/// `torch.hann_window(n, periodic=True)`.
///
/// `w[k] = 0.5 · (1 − cos(2π·k / n))` for `k` in `0..n`; the symmetric
/// variant would divide by `n − 1` instead.
fn periodic_hann(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let phase = 2.0 * PI * k as f32 / n as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
383
384// ---------------------------------------------------------------------------
385// Audio preparation
386// ---------------------------------------------------------------------------
387
388/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples.
389///
390/// - Longer: keep the **last** 8 s (discard oldest).
391/// - Shorter: **front-pad** with zeros so audio is right-aligned.
392fn prepare_audio(samples: &[f32]) -> Vec<f32> {
393    match samples.len().cmp(&RING_CAPACITY) {
394        std::cmp::Ordering::Equal => samples.to_vec(),
395        std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(),
396        std::cmp::Ordering::Less => {
397            let mut out = vec![0.0f32; RING_CAPACITY - samples.len()];
398            out.extend_from_slice(samples);
399            out
400        }
401    }
402}
403
404// ---------------------------------------------------------------------------
405// PipecatSmartTurn
406// ---------------------------------------------------------------------------
407
408/// Pipecat Smart Turn v3 detector.
409///
410/// Wraps the Smart Turn v3 architecture (Whisper-Tiny encoder + binary
411/// classification head). Use [`new`] for the embedded upstream weights, or
412/// [`with_variant`] to pick a WaveKat fine-tune at runtime.
413///
414/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
415/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires
416/// end-of-speech to get a [`TurnPrediction`].
417///
418/// # Usage with VAD
419///
420/// ```no_run
421/// # #[cfg(feature = "pipecat")]
422/// # {
423/// use wavekat_turn::audio::PipecatSmartTurn;
424/// use wavekat_turn::AudioTurnDetector;
425///
426/// let mut detector = PipecatSmartTurn::new().unwrap();
427/// // ... feed frames via push_audio ...
428/// let prediction = detector.predict().unwrap();
429/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
430/// # }
431/// ```
432///
433/// [`new`]: Self::new
434/// [`with_variant`]: Self::with_variant
435/// [`push_audio`]: AudioTurnDetector::push_audio
436/// [`predict`]: AudioTurnDetector::predict
437pub struct PipecatSmartTurn {
438    session: ort::session::Session,
439    ring_buffer: VecDeque<f32>,
440    mel: MelExtractor,
441    /// Counts samples pushed since the last `predict()` call.
442    /// Used to compute `shift_frames` for incremental STFT.
443    samples_since_predict: usize,
444}
445
// SAFETY: ort::Session is Send in ort 2.x. Sync is safe because every
// method that touches the session takes &mut self, preventing concurrent use.
//
// NOTE(review): these impls cover every field, including `MelExtractor`
// (which holds an `Arc<dyn RealToComplex<f32>>`), so they also assert
// thread-safety for it. If `ort::session::Session` and the realfft plan
// trait object are already `Send + Sync` in the dependency versions in use,
// the compiler would derive both traits automatically and these `unsafe`
// impls could be dropped — confirm before relying on that.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
450
451impl PipecatSmartTurn {
452    /// Load the upstream Pipecat Smart Turn v3.2 model embedded at compile time.
453    ///
454    /// Equivalent to [`with_variant(SmartTurnVariant::PipecatV3)`](Self::with_variant).
455    pub fn new() -> Result<Self, TurnError> {
456        let session = onnx::session_from_memory(MODEL_BYTES)?;
457        Ok(Self::build(session))
458    }
459
460    /// Load a specific variant of the Smart Turn model.
461    ///
462    /// - [`SmartTurnVariant::PipecatV3`] uses the embedded ONNX bytes — no
463    ///   network required.
464    /// - [`SmartTurnVariant::Wavekat`] (when the `wavekat-smart-turn` feature
465    ///   is enabled) downloads the corresponding language file from the
466    ///   `wavekat/smart-turn-ONNX` HuggingFace repo and caches it under
467    ///   `$HF_HOME/hub/`. Set `WAVEKAT_TURN_MODEL_DIR` to point at a
468    ///   pre-populated directory (offline / CI use).
469    pub fn with_variant(variant: SmartTurnVariant) -> Result<Self, TurnError> {
470        match variant {
471            SmartTurnVariant::PipecatV3 => Self::new(),
472            #[cfg(feature = "wavekat-smart-turn")]
473            SmartTurnVariant::Wavekat(lang) => {
474                let path = crate::audio::wavekat_download::resolve_model(lang)?;
475                Self::from_file(path)
476            }
477        }
478    }
479
480    /// Load a model from a custom path on disk.
481    ///
482    /// Useful for CI environments that supply the model file separately, or
483    /// for evaluating fine-tuned variants without recompiling.
484    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
485        let session = onnx::session_from_file(path)?;
486        Ok(Self::build(session))
487    }
488
489    fn build(session: ort::session::Session) -> Self {
490        Self {
491            session,
492            ring_buffer: VecDeque::with_capacity(RING_CAPACITY),
493            mel: MelExtractor::new(),
494            samples_since_predict: 0,
495        }
496    }
497}
498
499impl AudioTurnDetector for PipecatSmartTurn {
500    /// Append audio to the internal ring buffer.
501    ///
502    /// Frames with a sample rate other than 16 kHz are silently dropped.
503    /// The ring buffer holds at most 8 s; older samples are evicted.
504    fn push_audio(&mut self, frame: &AudioFrame) {
505        if frame.sample_rate() != SAMPLE_RATE {
506            return;
507        }
508        let samples = frame.samples();
509        // Evict oldest samples to make room
510        let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY);
511        if overflow > 0 {
512            self.ring_buffer.drain(..overflow);
513        }
514        self.ring_buffer.extend(samples.iter().copied());
515        self.samples_since_predict += samples.len();
516    }
517
518    /// Run inference on the buffered audio.
519    ///
520    /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts
521    /// Whisper log-mel features, and runs ONNX inference.
522    fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
523        let t_start = Instant::now();
524
525        // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples
526        let shift_frames = self.samples_since_predict / HOP_LENGTH;
527        self.samples_since_predict = 0;
528
529        let buffered: Vec<f32> = self.ring_buffer.iter().copied().collect();
530        let audio = prepare_audio(&buffered);
531        let t_after_audio_prep = Instant::now();
532
533        // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental)
534        let mel_spec = self.mel.extract(&audio, shift_frames);
535        let t_after_mel = Instant::now();
536
537        // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference
538        let (raw, _) = mel_spec.into_raw_vec_and_offset();
539        let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
540            .expect("internal: mel output has wrong element count");
541
542        let input_tensor = Tensor::from_array(input_array)
543            .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
544
545        let outputs = self
546            .session
547            .run(inputs!["input_features" => input_tensor])
548            .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
549        let t_after_onnx = Instant::now();
550
551        // Extract sigmoid probability from the "logits" output
552        let output = outputs
553            .get("logits")
554            .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
555        let (_, data): (_, &[f32]) = output
556            .try_extract_tensor()
557            .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
558        let probability = *data
559            .first()
560            .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;
561
562        let latency_ms = t_start.elapsed().as_millis() as u64;
563
564        let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
565        let stage_times = vec![
566            StageTiming {
567                name: "audio_prep",
568                us: us(t_start, t_after_audio_prep),
569            },
570            StageTiming {
571                name: "mel",
572                us: us(t_after_audio_prep, t_after_mel),
573            },
574            StageTiming {
575                name: "onnx",
576                us: us(t_after_mel, t_after_onnx),
577            },
578        ];
579
580        // probability = P(turn complete); > 0.5 means the speaker has finished
581        let (state, confidence) = if probability > 0.5 {
582            (TurnState::Finished, probability)
583        } else {
584            (TurnState::Unfinished, 1.0 - probability)
585        };
586
587        let audio_duration_ms = (self.ring_buffer.len() as u64 * 1000) / SAMPLE_RATE as u64;
588
589        Ok(TurnPrediction {
590            state,
591            confidence,
592            latency_ms,
593            stage_times,
594            audio_duration_ms,
595        })
596    }
597
598    /// Clear the ring buffer. Call at the start of each new speech turn.
599    fn reset(&mut self) {
600        self.ring_buffer.clear();
601        self.samples_since_predict = 0;
602        self.mel.invalidate_cache();
603    }
604}
605
606// ---------------------------------------------------------------------------
607// Mel comparison tests (unit tests — need access to private MelExtractor)
608// ---------------------------------------------------------------------------
609
#[cfg(test)]
mod mel_tests {
    use std::path::{Path, PathBuf};

    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;

    use super::{prepare_audio, MelExtractor, RING_CAPACITY, SAMPLE_RATE};

    /// Max allowed element-wise absolute difference between Rust and Python mel.
    const MEL_TOLERANCE: f32 = 0.05;

    /// `tests/fixtures` two directories above this crate's manifest dir
    /// (the crate is expected to live under `<repo root>/crates/`).
    fn fixtures_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap() // crates/
            .parent()
            .unwrap() // repo root
            .join("tests/fixtures")
    }

    /// Load 16 kHz mono WAV as f32 in [-1, 1], normalised the same way as
    /// Python's soundfile (divide by 32768, not i16::MAX).
    ///
    /// # Panics
    ///
    /// If the file cannot be opened, is not 16 kHz mono, or is integer PCM
    /// with a bit depth other than 16.
    fn load_wav_f32(path: &Path) -> Vec<f32> {
        let mut reader = hound::WavReader::open(path)
            .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, SAMPLE_RATE, "expected 16 kHz");
        assert_eq!(spec.channels, 1, "expected mono");
        match spec.sample_format {
            hound::SampleFormat::Int => {
                // Wider PCM cannot be read as i16 at all, and the /32768
                // scaling is only correct for 16-bit data.
                assert_eq!(spec.bits_per_sample, 16, "expected 16-bit integer PCM");
                reader
                    .samples::<i16>()
                    .map(|s| s.unwrap() as f32 / 32768.0)
                    .collect()
            }
            hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
        }
    }

    /// Load the Python-generated reference mel for `clip` (`<clip>.mel.npy`).
    fn load_python_mel(clip: &str) -> Array2<f32> {
        let path = fixtures_dir().join(format!("{clip}.mel.npy"));
        let file = std::fs::File::open(&path).unwrap_or_else(|_| {
            panic!(
                "missing {}: run `python scripts/gen_reference.py` first",
                path.display()
            )
        });
        Array2::<f32>::read_npy(file).expect("failed to parse .npy")
    }

    struct MelDiff {
        max_diff: f32,
        mean_diff: f32,
        /// (mel_bin, frame) of the single largest diff
        max_at: (usize, usize),
        /// fraction of elements with diff > 0.01
        outlier_frac: f32,
    }

    /// Compare a cold-start Rust mel against the Python reference for `clip`.
    fn compare_mel(clip: &str) -> MelDiff {
        let samples = load_wav_f32(&fixtures_dir().join(clip));
        let audio = prepare_audio(&samples);
        assert_eq!(audio.len(), RING_CAPACITY);

        let mut extractor = MelExtractor::new();
        let rust_mel = extractor.extract(&audio, 0);
        let python_mel = load_python_mel(clip);

        assert_eq!(
            rust_mel.shape(),
            python_mel.shape(),
            "{clip}: mel shape mismatch"
        );

        let shape = rust_mel.shape();
        let (n_mels, n_frames) = (shape[0], shape[1]);

        let mut max_diff = 0.0f32;
        let mut max_at = (0, 0);
        // Accumulate in f64: 64 000 small f32 terms lose precision otherwise.
        let mut sum_diff = 0.0f64;
        let mut outliers = 0usize;

        for m in 0..n_mels {
            for t in 0..n_frames {
                let d = (rust_mel[[m, t]] - python_mel[[m, t]]).abs();
                sum_diff += f64::from(d);
                if d > max_diff {
                    max_diff = d;
                    max_at = (m, t);
                }
                if d > 0.01 {
                    outliers += 1;
                }
            }
        }

        let total = n_mels * n_frames;
        MelDiff {
            max_diff,
            mean_diff: (sum_diff / total as f64) as f32,
            max_at,
            outlier_frac: outliers as f32 / total as f32,
        }
    }

    /// Print a markdown table of mel-level diffs between Rust and Python.
    /// Run with: `make mel`
    #[test]
    #[ignore]
    fn mel_report() {
        let clips = ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"];

        println!();
        println!("MEL_TOLERANCE={MEL_TOLERANCE}");
        println!();
        println!("| Clip | Max Diff | Mean Diff | Max at (mel,frame) | Outliers >0.01 | Status |");
        println!("|------|----------|-----------|---------------------|----------------|--------|");
        for clip in clips {
            let d = compare_mel(clip);
            let status = if d.max_diff <= MEL_TOLERANCE {
                "PASS"
            } else {
                "FAIL"
            };
            println!(
                "| `{clip}` | {:.6} | {:.6} | ({},{}) | {:.2}% | {status} |",
                d.max_diff,
                d.mean_diff,
                d.max_at.0,
                d.max_at.1,
                d.outlier_frac * 100.0,
            );
        }
        println!();
    }
}