// wavekat_turn/audio/pipecat.rs
1//! Pipecat Smart Turn v3 backend.
2//!
3//! Audio-based turn detection using the Smart Turn ONNX model.
4//! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be
5//! upsampled before feeding to this detector.
6//!
7//! # Model
8//!
9//! - Source:  <https://huggingface.co/pipecat-ai/smart-turn-v3>
10//! - File:    `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB)
11//! - License: BSD 2-Clause
12//!
13//! # Tensor specification
14//!
15//! | Role   | Name             | Shape          | Dtype   |
16//! |--------|------------------|----------------|---------|
17//! | Input  | `input_features` | `[B, 80, 800]` | float32 |
18//! | Output | `logits`         | `[B, 1]`       | float32 |
19//!
20//! Despite the name, `logits` is a **sigmoid probability** P(turn complete)
21//! in [0, 1] — the sigmoid is fused into the model before ONNX export.
22//! Threshold: `probability > 0.5` → `TurnState::Finished`.
23//!
24//! # Mel-feature specification
25//!
26//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`:
27//!
28//! | Parameter     | Value                          |
29//! |---------------|--------------------------------|
30//! | Sample rate   | 16 000 Hz                      |
31//! | n_fft         | 400 samples (25 ms)            |
32//! | hop_length    | 160 samples (10 ms)            |
33//! | n_mels        | 80                             |
34//! | Freq range    | 0 – 8 000 Hz                   |
35//! | Mel scale     | Slaney (NOT HTK)               |
36//! | Window        | Hann (periodic, size 400)      |
37//! | Pre-emphasis  | None                           |
38//! | Log           | log10 with ε = 1e-10           |
39//! | Normalization | clamp(max − 8), (x + 4) / 4   |
40//!
41//! # Audio buffer
42//!
43//! - Exactly **8 seconds = 128 000 samples** at 16 kHz.
44//! - Shorter input: **front-padded** with zeros (audio is at the end).
45//! - Longer input: the **last** 8 s is used (oldest samples discarded).
46
47use std::collections::VecDeque;
48use std::path::Path;
49use std::sync::Arc;
50use std::time::Instant;
51
52use ndarray::{s, Array2, Array3};
53use ort::{inputs, value::Tensor};
54use realfft::num_complex::Complex;
55use realfft::{RealFftPlanner, RealToComplex};
56
57use crate::onnx;
58use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
59
60// ---------------------------------------------------------------------------
61// Constants
62// ---------------------------------------------------------------------------
63
/// Sample rate the model expects (Hz). Input at any other rate is rejected.
const SAMPLE_RATE: u32 = 16_000;
/// FFT window size in samples (25 ms at 16 kHz).
const N_FFT: usize = 400;
/// STFT hop length in samples (10 ms at 16 kHz).
const HOP_LENGTH: usize = 160;
/// Number of mel filterbank bins (first dim of the model input).
const N_MELS: usize = 80;
/// Number of STFT frames the model expects (8 s × 100 frames/s).
const N_FRAMES: usize = 800;
/// FFT frequency bins for a real input: N_FFT/2 + 1.
const N_FREQS: usize = N_FFT / 2 + 1; // 201
/// Ring buffer capacity: 8 s × 16 kHz = exactly the model's window.
const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000

/// Embedded ONNX model bytes, downloaded by build.rs at compile time.
// NOTE(review): compilation fails if build.rs did not place the file in OUT_DIR.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
81
82// ---------------------------------------------------------------------------
83// Mel feature extractor
84// ---------------------------------------------------------------------------
85
/// Pre-computed Whisper-style log-mel feature extractor.
///
/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`].
/// [`MelExtractor::extract`] is then called per inference.
///
/// Not thread-safe by design: `extract` mutates the internal caches, so a
/// single extractor must not be shared across concurrent calls.
struct MelExtractor {
    /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS].
    mel_filters: Array2<f32>,
    /// Periodic Hann window of length N_FFT.
    hann_window: Vec<f32>,
    /// Reusable forward real FFT plan (length N_FFT).
    fft: Arc<dyn RealToComplex<f32>>,
    /// Reusable scratch buffer for the FFT.
    fft_scratch: Vec<Complex<f32>>,
    /// Reusable output spectrum buffer (N_FREQS complex values).
    spectrum_buf: Vec<Complex<f32>>,
    /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call.
    /// Enables incremental STFT: only new frames are recomputed.
    /// `None` after construction or [`MelExtractor::invalidate_cache`].
    cached_power_spec: Option<Array2<f32>>,
    /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call.
    /// Enables incremental mel filterbank: only new columns are recomputed.
    cached_mel_spec: Option<Array2<f32>>,
}
108
109impl MelExtractor {
110    fn new() -> Self {
111        let mel_filters = build_mel_filters(
112            SAMPLE_RATE as usize,
113            N_FFT,
114            N_MELS,
115            0.0,
116            SAMPLE_RATE as f32 / 2.0,
117        );
118        let hann_window = periodic_hann(N_FFT);
119
120        let mut planner = RealFftPlanner::<f32>::new();
121        let fft = planner.plan_fft_forward(N_FFT);
122        let fft_scratch = fft.make_scratch_vec();
123        let spectrum_buf = fft.make_output_vec();
124
125        Self {
126            mel_filters,
127            hann_window,
128            fft,
129            fft_scratch,
130            spectrum_buf,
131            cached_power_spec: None,
132            cached_mel_spec: None,
133        }
134    }
135
    /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly
    /// `RING_CAPACITY` samples of 16 kHz mono audio.
    ///
    /// `shift_frames` is how many STFT frames worth of new audio were added
    /// since the last call. When a valid cache exists and `shift_frames` is
    /// in range, only the last `shift_frames` columns of the power spectrogram
    /// are recomputed; the rest are copied from the shifted cache.
    /// Passing 0 (or an out-of-range value) forces a full recompute.
    ///
    /// NOTE(review): cache reuse is only exact when the audio snapshot
    /// advanced by exactly `shift_frames * HOP_LENGTH` samples since the
    /// previous call — a non-hop-aligned advance would misalign every reused
    /// column by the remainder. Confirm the caller guarantees hop alignment
    /// or passes 0. Additionally, the first ~2 reused frames were computed
    /// against the *previous* call's reflect padding, so they can differ
    /// marginally from a cold-start computation of the same audio.
    fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
        debug_assert_eq!(audio.len(), RING_CAPACITY);

        // ---- Center-pad: N_FFT/2 reflect samples on each side → 128 400 samples ----
        // Matches WhisperFeatureExtractor: np.pad(waveform, n_fft//2, mode="reflect").
        // Reflect (not zero) padding ensures the boundary frames match Python exactly.
        // Gives exactly N_FRAMES + 1 = 801 frames; we discard the last one.
        let pad = N_FFT / 2; // 200
        let n = audio.len(); // 128 000
        let mut padded = vec![0.0f32; pad + n + pad];
        padded[pad..pad + n].copy_from_slice(audio);
        // Left reflect: padded[0..pad] = audio[pad..1] reversed (exclude edge)
        for i in 0..pad {
            padded[i] = audio[pad - i];
        }
        // Right reflect: padded[pad+n..pad+n+pad] = audio[n-2..n-2-pad] reversed
        for i in 0..pad {
            padded[pad + n + i] = audio[n - 2 - i];
        }

        // n_total = (128 400 − 400) / 160 + 1 = 801
        let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;

        // ---- Incremental STFT ----
        // If we have a cached power spec and shift_frames < n_total_frames,
        // reuse the unchanged frames by shifting the cache left and only
        // computing the `shift_frames` new columns at the end.
        let first_new_frame = match &self.cached_power_spec {
            Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => {
                let kept = n_total_frames - shift_frames;
                let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
                power_spec
                    .slice_mut(s![.., ..kept])
                    .assign(&cached.slice(s![.., shift_frames..]));
                self.cached_power_spec = Some(power_spec);
                kept // only compute frames [kept..n_total_frames]
            }
            _ => {
                self.cached_power_spec = Some(Array2::<f32>::zeros((N_FREQS, n_total_frames)));
                0 // cold start: compute all frames
            }
        };

        let power_spec = self.cached_power_spec.as_mut().unwrap();
        let mut frame_buf = vec![0.0f32; N_FFT];

        for frame_idx in first_new_frame..n_total_frames {
            let start = frame_idx * HOP_LENGTH;
            // Apply periodic Hann window
            for (i, (&s, &w)) in padded[start..start + N_FFT]
                .iter()
                .zip(self.hann_window.iter())
                .enumerate()
            {
                frame_buf[i] = s * w;
            }

            // realfft consumes (overwrites) frame_buf and writes N_FREQS bins
            // into spectrum_buf; the scratch buffer avoids per-call allocation.
            self.fft
                .process_with_scratch(
                    &mut frame_buf,
                    &mut self.spectrum_buf,
                    &mut self.fft_scratch,
                )
                .expect("FFT failed: internal buffer size mismatch");

            // Power spectrum: |X[k]|² per frequency bin.
            for (k, c) in self.spectrum_buf.iter().enumerate() {
                power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
            }
        }

        // Take first N_FRAMES columns (drop the trailing frame)
        let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]);

        // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ----
        // Reuse the cached mel columns for the unchanged frames; only multiply
        // the new power-spectrum columns against the filterbank.
        let mel_spec = match &self.cached_mel_spec {
            Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => {
                let kept = N_FRAMES - shift_frames;
                let mut ms = Array2::<f32>::zeros((N_MELS, N_FRAMES));
                // Shift old columns left
                ms.slice_mut(s![.., ..kept])
                    .assign(&cached.slice(s![.., shift_frames..]));
                // Apply filterbank only to the new power-spectrum columns
                let new_power = power_spec_view.slice(s![.., kept..]);
                ms.slice_mut(s![.., kept..])
                    .assign(&self.mel_filters.dot(&new_power));
                ms
            }
            _ => self.mel_filters.dot(&power_spec_view),
        };
        self.cached_mel_spec = Some(mel_spec.clone());

        // ---- Log10 with floor at 1e-10 ----
        let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());

        // ---- Dynamic range compression and normalization ----
        // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4.
        // Note: max is taken over the whole spectrogram, so even cached
        // columns are re-normalised against the current global maximum.
        let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);

        log_mel
    }
246
    /// Invalidate all caches (call on reset).
    ///
    /// After this, the next [`MelExtractor::extract`] call recomputes the
    /// full spectrogram regardless of `shift_frames`.
    fn invalidate_cache(&mut self) {
        self.cached_power_spec = None;
        self.cached_mel_spec = None;
    }
252}
253
254// ---------------------------------------------------------------------------
255// Mel filterbank construction — Slaney scale, slaney norm
256// ---------------------------------------------------------------------------
257
/// Convert Hz to mel (Slaney/librosa scale, NOT HTK).
///
/// The scale is linear below 1 kHz (200/3 Hz per mel) and logarithmic above,
/// matching `librosa.hz_to_mel(htk=False)`.
fn hz_to_mel(hz: f32) -> f32 {
    /// Slope of the linear region: Hz per mel.
    const F_SP: f32 = 200.0 / 3.0;
    /// Breakpoint between the linear and logarithmic regions.
    const MIN_LOG_HZ: f32 = 1000.0;
    /// Mel value at the breakpoint (= 15.0).
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP;
    // Log-region step: ln(6.4) / 27 ≈ 0.068752.
    let logstep = (6.4_f32).ln() / 27.0;

    if hz < MIN_LOG_HZ {
        hz / F_SP
    } else {
        MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep
    }
}
271
/// Convert mel back to Hz (Slaney scale) — the exact inverse of [`hz_to_mel`].
fn mel_to_hz(mel: f32) -> f32 {
    /// Slope of the linear region: Hz per mel.
    const F_SP: f32 = 200.0 / 3.0;
    /// Hz value at the linear/log breakpoint.
    const MIN_LOG_HZ: f32 = 1000.0;
    /// Mel value at the breakpoint (= 15.0).
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP;
    // Log-region step: ln(6.4) / 27.
    let logstep = (6.4_f32).ln() / 27.0;

    if mel < MIN_LOG_MEL {
        mel * F_SP
    } else {
        MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp()
    }
}
284
285/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs].
286///
287/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax,
288///   norm="slaney", dtype=float32)` which is what HuggingFace's
289/// `WhisperFeatureExtractor` uses internally.
290fn build_mel_filters(
291    sr: usize,
292    n_fft: usize,
293    n_mels: usize,
294    f_min: f32,
295    f_max: f32,
296) -> Array2<f32> {
297    let n_freqs = n_fft / 2 + 1;
298
299    // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, …
300    let fft_freqs: Vec<f32> = (0..n_freqs)
301        .map(|i| i as f32 * sr as f32 / n_fft as f32)
302        .collect();
303
304    // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge)
305    let mel_min = hz_to_mel(f_min);
306    let mel_max = hz_to_mel(f_max);
307    let mel_pts: Vec<f32> = (0..=(n_mels + 1))
308        .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
309        .collect();
310    let hz_pts: Vec<f32> = mel_pts.iter().map(|&m| mel_to_hz(m)).collect();
311
312    // Build triangular filters with Slaney normalisation
313    let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
314    for m in 0..n_mels {
315        let f_left = hz_pts[m];
316        let f_center = hz_pts[m + 1];
317        let f_right = hz_pts[m + 2];
318        // Slaney norm: 2 / (right_hz − left_hz)
319        let enorm = 2.0 / (f_right - f_left);
320
321        for (k, &f) in fft_freqs.iter().enumerate() {
322            let w = if f >= f_left && f <= f_center {
323                (f - f_left) / (f_center - f_left)
324            } else if f > f_center && f <= f_right {
325                (f_right - f) / (f_right - f_center)
326            } else {
327                0.0
328            };
329            filters[[m, k]] = w * enorm;
330        }
331    }
332    filters
333}
334
335// ---------------------------------------------------------------------------
336// Hann window
337// ---------------------------------------------------------------------------
338
/// Periodic Hann window of length `n`, matching `torch.hann_window(n, periodic=True)`.
///
/// Formula: `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n.
/// This differs from the symmetric variant (which divides by n−1).
/// Returns an empty vector for `n == 0`.
fn periodic_hann(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let phase = 2.0 * PI * k as f32 / n as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
349
350// ---------------------------------------------------------------------------
351// Audio preparation
352// ---------------------------------------------------------------------------
353
354/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples.
355///
356/// - Longer: keep the **last** 8 s (discard oldest).
357/// - Shorter: **front-pad** with zeros so audio is right-aligned.
358fn prepare_audio(samples: &[f32]) -> Vec<f32> {
359    match samples.len().cmp(&RING_CAPACITY) {
360        std::cmp::Ordering::Equal => samples.to_vec(),
361        std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(),
362        std::cmp::Ordering::Less => {
363            let mut out = vec![0.0f32; RING_CAPACITY - samples.len()];
364            out.extend_from_slice(samples);
365            out
366        }
367    }
368}
369
370// ---------------------------------------------------------------------------
371// PipecatSmartTurn
372// ---------------------------------------------------------------------------
373
/// Pipecat Smart Turn v3 detector.
///
/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires
/// end-of-speech to get a [`TurnPrediction`].
///
/// # Usage with VAD
///
/// ```no_run
/// # #[cfg(feature = "pipecat")]
/// # {
/// use wavekat_turn::audio::PipecatSmartTurn;
/// use wavekat_turn::AudioTurnDetector;
///
/// let mut detector = PipecatSmartTurn::new().unwrap();
/// // ... feed frames via push_audio ...
/// let prediction = detector.predict().unwrap();
/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
/// # }
/// ```
///
/// [`push_audio`]: AudioTurnDetector::push_audio
/// [`predict`]: AudioTurnDetector::predict
pub struct PipecatSmartTurn {
    // ONNX Runtime session holding the loaded Smart Turn model.
    session: ort::session::Session,
    // Rolling window of the last ≤ 8 s of 16 kHz samples.
    ring_buffer: VecDeque<f32>,
    // Reusable mel extractor with incremental-STFT caches.
    mel: MelExtractor,
    /// Counts samples pushed since the last `predict()` call.
    /// Used to compute `shift_frames` for incremental STFT.
    samples_since_predict: usize,
}
405
// SAFETY: ort::Session is Send in ort 2.x. Sync is safe because every
// method that touches the session takes &mut self, preventing concurrent use.
// NOTE(review): these impls also vouch for the remaining fields (VecDeque,
// MelExtractor) — plain owned data with no interior mutability visible here.
// Re-verify this claim if a field with interior mutability is ever added,
// and confirm the Session thread-safety guarantee against the pinned ort
// version.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
410
411impl PipecatSmartTurn {
412    /// Load the Smart Turn v3.2 model embedded at compile time.
413    pub fn new() -> Result<Self, TurnError> {
414        let session = onnx::session_from_memory(MODEL_BYTES)?;
415        Ok(Self::build(session))
416    }
417
418    /// Load a model from a custom path on disk.
419    ///
420    /// Useful for CI environments that supply the model file separately, or
421    /// for evaluating fine-tuned variants without recompiling.
422    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
423        let session = onnx::session_from_file(path)?;
424        Ok(Self::build(session))
425    }
426
427    fn build(session: ort::session::Session) -> Self {
428        Self {
429            session,
430            ring_buffer: VecDeque::with_capacity(RING_CAPACITY),
431            mel: MelExtractor::new(),
432            samples_since_predict: 0,
433        }
434    }
435}
436
437impl AudioTurnDetector for PipecatSmartTurn {
438    /// Append audio to the internal ring buffer.
439    ///
440    /// Frames with a sample rate other than 16 kHz are silently dropped.
441    /// The ring buffer holds at most 8 s; older samples are evicted.
442    fn push_audio(&mut self, frame: &AudioFrame) {
443        if frame.sample_rate() != SAMPLE_RATE {
444            return;
445        }
446        let samples = frame.samples();
447        // Evict oldest samples to make room
448        let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY);
449        if overflow > 0 {
450            self.ring_buffer.drain(..overflow);
451        }
452        self.ring_buffer.extend(samples.iter().copied());
453        self.samples_since_predict += samples.len();
454    }
455
456    /// Run inference on the buffered audio.
457    ///
458    /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts
459    /// Whisper log-mel features, and runs ONNX inference.
460    fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
461        let t_start = Instant::now();
462
463        // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples
464        let shift_frames = self.samples_since_predict / HOP_LENGTH;
465        self.samples_since_predict = 0;
466
467        let buffered: Vec<f32> = self.ring_buffer.iter().copied().collect();
468        let audio = prepare_audio(&buffered);
469        let t_after_audio_prep = Instant::now();
470
471        // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental)
472        let mel_spec = self.mel.extract(&audio, shift_frames);
473        let t_after_mel = Instant::now();
474
475        // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference
476        let (raw, _) = mel_spec.into_raw_vec_and_offset();
477        let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
478            .expect("internal: mel output has wrong element count");
479
480        let input_tensor = Tensor::from_array(input_array)
481            .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
482
483        let outputs = self
484            .session
485            .run(inputs!["input_features" => input_tensor])
486            .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
487        let t_after_onnx = Instant::now();
488
489        // Extract sigmoid probability from the "logits" output
490        let output = outputs
491            .get("logits")
492            .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
493        let (_, data): (_, &[f32]) = output
494            .try_extract_tensor()
495            .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
496        let probability = *data
497            .first()
498            .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;
499
500        let latency_ms = t_start.elapsed().as_millis() as u64;
501
502        let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
503        let stage_times = vec![
504            StageTiming {
505                name: "audio_prep",
506                us: us(t_start, t_after_audio_prep),
507            },
508            StageTiming {
509                name: "mel",
510                us: us(t_after_audio_prep, t_after_mel),
511            },
512            StageTiming {
513                name: "onnx",
514                us: us(t_after_mel, t_after_onnx),
515            },
516        ];
517
518        // probability = P(turn complete); > 0.5 means the speaker has finished
519        let (state, confidence) = if probability > 0.5 {
520            (TurnState::Finished, probability)
521        } else {
522            (TurnState::Unfinished, 1.0 - probability)
523        };
524
525        Ok(TurnPrediction {
526            state,
527            confidence,
528            latency_ms,
529            stage_times,
530        })
531    }
532
533    /// Clear the ring buffer. Call at the start of each new speech turn.
534    fn reset(&mut self) {
535        self.ring_buffer.clear();
536        self.samples_since_predict = 0;
537        self.mel.invalidate_cache();
538    }
539}
540
541// ---------------------------------------------------------------------------
542// Mel comparison tests (unit tests — need access to private MelExtractor)
543// ---------------------------------------------------------------------------
544
#[cfg(test)]
mod mel_tests {
    //! Mel comparison tests (unit tests — need access to private MelExtractor).
    //! Validates the Rust extractor against reference spectrograms generated
    //! by Python (`scripts/gen_reference.py`); fixtures live under
    //! `tests/fixtures` at the repository root.

    use std::path::{Path, PathBuf};

    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;

    use super::{prepare_audio, MelExtractor, RING_CAPACITY, SAMPLE_RATE};

    /// Max allowed element-wise absolute difference between Rust and Python mel.
    const MEL_TOLERANCE: f32 = 0.05;

    /// Resolve `<repo root>/tests/fixtures` from this crate's manifest dir.
    // NOTE(review): assumes the crate sits two levels below the repo root
    // (e.g. crates/<name>/) — adjust if the workspace layout changes.
    fn fixtures_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap() // crates/
            .parent()
            .unwrap() // repo root
            .join("tests/fixtures")
    }

    /// Load 16 kHz mono WAV as f32 in [-1, 1], normalised the same way as
    /// Python's soundfile (divide by 32768, not i16::MAX).
    ///
    /// Panics if the file is missing or is not 16 kHz mono.
    fn load_wav_f32(path: &Path) -> Vec<f32> {
        let mut reader = hound::WavReader::open(path)
            .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, SAMPLE_RATE, "expected 16 kHz");
        assert_eq!(spec.channels, 1, "expected mono");
        match spec.sample_format {
            // i16 PCM: divide by 32768 to mirror soundfile's normalisation.
            hound::SampleFormat::Int => reader
                .samples::<i16>()
                .map(|s| s.unwrap() as f32 / 32768.0)
                .collect(),
            hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
        }
    }

    /// Load the Python-generated reference mel (`<clip>.mel.npy`).
    /// Panics with a hint to regenerate fixtures if the file is absent.
    fn load_python_mel(clip: &str) -> Array2<f32> {
        let path = fixtures_dir().join(format!("{clip}.mel.npy"));
        let file = std::fs::File::open(&path).unwrap_or_else(|_| {
            panic!(
                "missing {}: run `python scripts/gen_reference.py` first",
                path.display()
            )
        });
        Array2::<f32>::read_npy(file).expect("failed to parse .npy")
    }

    /// Summary statistics of the element-wise |Rust − Python| difference.
    struct MelDiff {
        // Largest single element difference.
        max_diff: f32,
        // Mean element difference over the whole spectrogram.
        mean_diff: f32,
        /// (mel_bin, frame) of the single largest diff
        max_at: (usize, usize),
        /// fraction of elements with diff > 0.01
        outlier_frac: f32,
    }

    /// Compute Rust-vs-Python mel diff statistics for one fixture clip.
    /// Uses a cold extractor (`shift_frames = 0`), i.e. no incremental cache.
    fn compare_mel(clip: &str) -> MelDiff {
        let samples = load_wav_f32(&fixtures_dir().join(clip));
        let audio = prepare_audio(&samples);
        assert_eq!(audio.len(), RING_CAPACITY);

        let mut extractor = MelExtractor::new();
        let rust_mel = extractor.extract(&audio, 0);
        let python_mel = load_python_mel(clip);

        assert_eq!(
            rust_mel.shape(),
            python_mel.shape(),
            "{clip}: mel shape mismatch"
        );

        let shape = rust_mel.shape();
        let (n_mels, n_frames) = (shape[0], shape[1]);

        let mut max_diff = 0.0f32;
        let mut max_at = (0, 0);
        let mut sum_diff = 0.0f32;
        let mut outliers = 0usize;

        for m in 0..n_mels {
            for t in 0..n_frames {
                let d = (rust_mel[[m, t]] - python_mel[[m, t]]).abs();
                sum_diff += d;
                if d > max_diff {
                    max_diff = d;
                    max_at = (m, t);
                }
                if d > 0.01 {
                    outliers += 1;
                }
            }
        }

        let total = (n_mels * n_frames) as f32;
        MelDiff {
            max_diff,
            mean_diff: sum_diff / total,
            max_at,
            outlier_frac: outliers as f32 / total,
        }
    }

    /// Print a markdown table of mel-level diffs between Rust and Python.
    /// Run with: `make mel`
    // Ignored by default: requires the WAV/NPY fixtures to be present.
    #[test]
    #[ignore]
    fn mel_report() {
        let clips = ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"];

        println!();
        println!("MEL_TOLERANCE={MEL_TOLERANCE}");
        println!();
        println!("| Clip | Max Diff | Mean Diff | Max at (mel,frame) | Outliers >0.01 | Status |");
        println!("|------|----------|-----------|---------------------|----------------|--------|");
        for clip in clips {
            let d = compare_mel(clip);
            let status = if d.max_diff <= MEL_TOLERANCE {
                "PASS"
            } else {
                "FAIL"
            };
            println!(
                "| `{clip}` | {:.6} | {:.6} | ({},{}) | {:.2}% | {status} |",
                d.max_diff,
                d.mean_diff,
                d.max_at.0,
                d.max_at.1,
                d.outlier_frac * 100.0,
            );
        }
        println!();
    }
}
679}