Skip to main content

omni_dev/voice/
wav.rs

1//! Mixdown, resampling, and WAV writing.
2//!
3//! The write-path half of the capture pipeline:
4//!
5//! 1. [`mono_mixdown`] collapses N interleaved channels into a single mono
6//!    stream by averaging.
7//! 2. [`Resampler`] rate-converts mono f32 from the device-native rate to
8//!    `16_000` Hz via `rubato`'s sinc interpolator. Identity-passthrough is
9//!    used when the input is already 16 kHz so the pipeline stays
10//!    bit-exact in that common case.
11//! 3. [`WavWriter`] serialises 16 kHz mono f32 to 16-bit signed PCM WAV via
12//!    `hound`, with clamp-on-cast to handle resampler overshoot at the
13//!    extremes of `[-1.0, 1.0]`.
14//!
15//! Idle detection and trailing-silence trimming live in
16//! [`super::idle`] — they operate on the post-resample 16 kHz stream
17//! and are independent of the writer.
18
19use std::fs::{self, File};
20use std::io::BufWriter;
21use std::path::{Path, PathBuf};
22
23use anyhow::{Context, Result};
24use hound::{SampleFormat, WavSpec};
25use rubato::audioadapter_buffers::direct::InterleavedSlice;
26use rubato::{
27    Async, FixedAsync, Indexing, Resampler as _, SincInterpolationParameters,
28    SincInterpolationType, WindowFunction,
29};
30
/// Sample rate, in Hz, that every stage after the resampler assumes
/// (the rate whisper.cpp expects).
pub const TARGET_SAMPLE_RATE: u32 = 16000;

/// Input frames handed to the sinc resampler per processing call. Large
/// enough to amortise per-call filter overhead, small enough that streaming
/// output still arrives in modest batches (~85 ms of audio at 48 kHz).
pub const RESAMPLER_CHUNK_FRAMES: usize = 1 << 12;
38
/// Collapses interleaved multi-channel audio to mono by averaging each frame.
///
/// The input is read in frames of `channels` consecutive samples. If
/// `samples.len()` is not a multiple of `channels`, the incomplete trailing
/// frame is discarded instead of causing a panic. cpal callbacks always
/// deliver whole frames, but hand-built fixtures may not.
///
/// For `channels` of `0` or `1` the samples are copied out unchanged, with
/// no averaging arithmetic applied, so already-mono input stays bit-exact.
#[must_use]
pub fn mono_mixdown(samples: &[f32], channels: u16) -> Vec<f32> {
    let width = usize::from(channels);
    if width < 2 {
        return samples.to_vec();
    }
    let scale = 1.0_f32 / f32::from(channels);
    samples
        .chunks_exact(width)
        .map(|frame| frame.iter().copied().sum::<f32>() * scale)
        .collect()
}
63
64/// Streaming resampler from an arbitrary input rate to
65/// [`TARGET_SAMPLE_RATE`] (16 kHz), mono.
66///
67/// The wrapper buffers input frames until it has enough for one
68/// fixed-size `rubato` chunk, processes that chunk, and accumulates the
69/// variable-length output. Callers feed mono f32 via [`Resampler::push`]
70/// and drain a single tail batch via [`Resampler::flush`] at end-of-stream.
71///
72/// At 16 kHz input the resampler is bypassed entirely — input is forwarded
73/// to output verbatim, with zero sinc-filter latency.
74pub struct Resampler {
75    input_rate: u32,
76    inner: Option<Inner>,
77}
78
79struct Inner {
80    resampler: Async<f32>,
81    /// Pending input frames not yet large enough for one chunk.
82    pending: Vec<f32>,
83    /// Required input frames per call (constant for fixed-input `Async`).
84    chunk_frames: usize,
85}
86
87impl Resampler {
88    /// Builds a resampler that converts `input_rate` Hz mono f32 to
89    /// 16 kHz mono f32. Returns an error if the rate is zero or the
90    /// `rubato` constructor rejects the configuration.
91    pub fn new(input_rate: u32) -> Result<Self> {
92        if input_rate == 0 {
93            anyhow::bail!("Resampler input rate must be > 0");
94        }
95        if input_rate == TARGET_SAMPLE_RATE {
96            return Ok(Self {
97                input_rate,
98                inner: None,
99            });
100        }
101        let ratio = f64::from(TARGET_SAMPLE_RATE) / f64::from(input_rate);
102        let params = SincInterpolationParameters {
103            sinc_len: 256,
104            f_cutoff: 0.95,
105            oversampling_factor: 128,
106            interpolation: SincInterpolationType::Linear,
107            window: WindowFunction::BlackmanHarris2,
108        };
109        let resampler = Async::<f32>::new_sinc(
110            ratio,
111            1.0,
112            &params,
113            RESAMPLER_CHUNK_FRAMES,
114            1,
115            FixedAsync::Input,
116        )
117        .with_context(|| {
118            format!("Failed to build resampler for {input_rate} Hz → {TARGET_SAMPLE_RATE} Hz")
119        })?;
120        Ok(Self {
121            input_rate,
122            inner: Some(Inner {
123                resampler,
124                pending: Vec::with_capacity(RESAMPLER_CHUNK_FRAMES * 2),
125                chunk_frames: RESAMPLER_CHUNK_FRAMES,
126            }),
127        })
128    }
129
130    /// The configured input sample rate in Hz.
131    #[must_use]
132    pub fn input_rate(&self) -> u32 {
133        self.input_rate
134    }
135
136    /// Output sample rate (constant — always [`TARGET_SAMPLE_RATE`]).
137    #[must_use]
138    pub fn output_rate(&self) -> u32 {
139        TARGET_SAMPLE_RATE
140    }
141
142    /// Feeds mono samples and returns any 16 kHz output produced this call.
143    ///
144    /// Partial input that doesn't fill a full chunk is buffered internally
145    /// and emitted on a subsequent `push` or `flush`.
146    pub fn push(&mut self, mono: &[f32]) -> Result<Vec<f32>> {
147        let Some(inner) = self.inner.as_mut() else {
148            return Ok(mono.to_vec());
149        };
150        inner.pending.extend_from_slice(mono);
151        let mut out = Vec::new();
152        while inner.pending.len() >= inner.chunk_frames {
153            let chunk = &inner.pending[..inner.chunk_frames];
154            let Ok(input_adapter) = InterleavedSlice::new(chunk, 1, inner.chunk_frames) else {
155                unreachable!("chunk.len() == 1 * chunk_frames by construction")
156            };
157            let drained = inner
158                .resampler
159                .process(&input_adapter, 0, None)
160                .context("Resampler chunk processing failed")?;
161            inner.pending.drain(..inner.chunk_frames);
162            // Mono → interleaved layout is just a flat Vec<f32> of samples.
163            out.extend(drained.take_data());
164        }
165        Ok(out)
166    }
167
168    /// Flushes any buffered samples at end-of-stream using `process_partial`
169    /// (zero-pads internally). Call at most once after the source is
170    /// exhausted.
171    pub fn flush(&mut self) -> Result<Vec<f32>> {
172        let Some(inner) = self.inner.as_mut() else {
173            return Ok(Vec::new());
174        };
175        let tail = std::mem::take(&mut inner.pending);
176        if tail.is_empty() {
177            return Ok(Vec::new());
178        }
179        let Ok(input_adapter) = InterleavedSlice::new(&tail, 1, tail.len()) else {
180            unreachable!("tail.len() == 1 * tail.len() by construction")
181        };
182        let output_capacity = inner.resampler.output_frames_max();
183        let mut output_buf = vec![0.0_f32; output_capacity];
184        let Ok(mut output_adapter) = InterleavedSlice::new_mut(&mut output_buf, 1, output_capacity)
185        else {
186            unreachable!("output_buf.len() == 1 * output_capacity by construction")
187        };
188        let indexing = Indexing {
189            input_offset: 0,
190            output_offset: 0,
191            partial_len: Some(tail.len()),
192            active_channels_mask: None,
193        };
194        let (_in_frames, out_frames) = inner
195            .resampler
196            .process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
197            .context("Resampler flush failed")?;
198        output_buf.truncate(out_frames);
199        Ok(output_buf)
200    }
201}
202
203/// Bit depth of the output WAV (whisper.cpp convention).
204pub const OUTPUT_BITS_PER_SAMPLE: u16 = 16;
205
206/// Streaming WAV writer that accepts mono 16 kHz f32 samples and emits
207/// 16-bit signed PCM.
208///
209/// Samples are clamped into `[-1.0, 1.0]` before the cast to `i16`.
210/// [`WavWriter::finalize`] must be called to flush the header; on drop
211/// without finalisation the file is left with an invalid header (size
212/// field unwritten) — the orchestrator therefore always calls
213/// `finalize`, even on signal-driven shutdown.
214pub struct WavWriter {
215    inner: hound::WavWriter<BufWriter<File>>,
216    path: PathBuf,
217    samples_written: u64,
218}
219
220impl WavWriter {
221    /// Creates a WAV file at `path`, with parent directories created if
222    /// they do not exist. The header is written eagerly; the size field
223    /// is patched up by [`WavWriter::finalize`].
224    pub fn create(path: impl AsRef<Path>) -> Result<Self> {
225        let path = path.as_ref().to_path_buf();
226        if let Some(parent) = path.parent() {
227            if !parent.as_os_str().is_empty() {
228                fs::create_dir_all(parent).with_context(|| {
229                    format!("Failed to create parent directory {}", parent.display())
230                })?;
231            }
232        }
233        let spec = WavSpec {
234            channels: 1,
235            sample_rate: TARGET_SAMPLE_RATE,
236            bits_per_sample: OUTPUT_BITS_PER_SAMPLE,
237            sample_format: SampleFormat::Int,
238        };
239        let inner = hound::WavWriter::create(&path, spec)
240            .with_context(|| format!("Failed to create WAV file at {}", path.display()))?;
241        Ok(Self {
242            inner,
243            path,
244            samples_written: 0,
245        })
246    }
247
248    /// Writes a chunk of mono samples. Values outside `[-1.0, 1.0]` are
249    /// clamped (resampler overshoot near 0 dBFS).
250    pub fn write_samples(&mut self, samples: &[f32]) -> Result<()> {
251        for s in samples {
252            let clamped = s.clamp(-1.0, 1.0);
253            let scaled = (clamped * f32::from(i16::MAX)).round() as i16;
254            self.inner
255                .write_sample(scaled)
256                .with_context(|| format!("Failed to write sample to {}", self.path.display()))?;
257        }
258        self.samples_written += samples.len() as u64;
259        Ok(())
260    }
261
262    /// Number of samples written so far (not counting any clamped value
263    /// any differently).
264    #[must_use]
265    pub fn samples_written(&self) -> u64 {
266        self.samples_written
267    }
268
269    /// Path the WAV was created at.
270    #[must_use]
271    pub fn path(&self) -> &Path {
272        &self.path
273    }
274
275    /// Flushes the WAV header so the file is playable. Consumes `self` so
276    /// double-finalisation cannot occur. Always call this — including on
277    /// signal-driven shutdown — or the on-disk file will be malformed.
278    pub fn finalize(self) -> Result<()> {
279        self.inner
280            .finalize()
281            .with_context(|| format!("Failed to finalize WAV file at {}", self.path.display()))
282    }
283}
284
#[cfg(test)]
// Panicking on unexpected failures is exactly what a failing test should do.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use std::f32::consts::TAU;

    #[test]
    fn mono_mixdown_passes_through_mono_untouched() {
        let input = vec![0.1, -0.2, 0.3, -0.4];
        assert_eq!(mono_mixdown(&input, 1), input);
    }

    #[test]
    fn mono_mixdown_averages_stereo_to_zero_for_inverted_signal() {
        // L = 1.0, R = -1.0 → mean = 0.0 for every frame.
        let input = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let out = mono_mixdown(&input, 2);
        assert_eq!(out, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn mono_mixdown_averages_quad_channel() {
        // 4-channel frame averages: (1+2+3+4)/4 = 2.5, (5+6+7+8)/4 = 6.5
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let out = mono_mixdown(&input, 4);
        assert_eq!(out, vec![2.5, 6.5]);
    }

    #[test]
    fn mono_mixdown_drops_trailing_partial_frame() {
        // 7 samples / 2 channels = 3 full frames, 1 stranded sample.
        let input = vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 99.0];
        let out = mono_mixdown(&input, 2);
        assert_eq!(out, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn resampler_identity_path_returns_input_verbatim() -> Result<()> {
        let mut r = Resampler::new(TARGET_SAMPLE_RATE)?;
        assert_eq!(r.input_rate(), TARGET_SAMPLE_RATE);
        assert_eq!(r.output_rate(), TARGET_SAMPLE_RATE);
        let input: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) - 0.5).collect();
        let out = r.push(&input)?;
        assert_eq!(out, input);
        let flushed = r.flush()?;
        assert!(flushed.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_rejects_zero_input_rate() {
        let err = Resampler::new(0).err().expect("must reject zero rate");
        assert!(err.to_string().contains("> 0"));
    }

    fn sine_wave(rate: u32, freq_hz: f32, duration_s: f32, amplitude: f32) -> Vec<f32> {
        let n = (rate as f32 * duration_s) as usize;
        (0..n)
            .map(|i| amplitude * (TAU * freq_hz * i as f32 / rate as f32).sin())
            .collect()
    }

    fn rms(samples: &[f32]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
        (sum_sq / samples.len() as f32).sqrt()
    }

    #[test]
    fn resampler_48k_to_16k_preserves_signal_rms() -> Result<()> {
        // 2 s of a 440 Hz sine at amplitude 0.5 — well below Nyquist at both rates.
        let input = sine_wave(48_000, 440.0, 2.0, 0.5);
        let mut r = Resampler::new(48_000)?;
        let mut output = r.push(&input)?;
        output.extend(r.flush()?);
        // 2 s @ 16 kHz ≈ 32_000 samples. The flush() call zero-pads any
        // residual input to one full rubato chunk, so the output can run
        // up to one chunk's worth of resampled frames long. Trailing-
        // silence trim (step 4) is responsible for cleaning that up.
        let expected_len: usize = 32_000;
        let slack: usize = 256;
        let max_overrun = (RESAMPLER_CHUNK_FRAMES as f64 * 16_000.0 / 48_000.0).ceil() as usize;
        let min_len = expected_len.saturating_sub(slack);
        let max_len = expected_len + max_overrun + slack;
        assert!(
            output.len() >= min_len,
            "output too short: got {}, expected ≥ {min_len}",
            output.len()
        );
        assert!(
            output.len() <= max_len,
            "output too long: got {}, expected ≤ {max_len}",
            output.len()
        );
        // Ignore the first ~50 ms transient (sinc filter warm-up); compare RMS over the steady-state.
        let warmup = 800; // 50 ms @ 16k
        let in_rms = rms(&input);
        let out_rms = rms(&output[warmup..]);
        assert!(
            (in_rms - out_rms).abs() < 0.02,
            "RMS drift too large: in={in_rms}, out={out_rms}"
        );
        Ok(())
    }

    #[test]
    fn resampler_chunked_and_one_shot_match() -> Result<()> {
        // Two resamplers, identical input, different chunking — outputs must agree.
        let input = sine_wave(48_000, 261.6, 1.0, 0.3);
        let mut one_shot = Resampler::new(48_000)?;
        let mut a = one_shot.push(&input)?;
        a.extend(one_shot.flush()?);

        let mut chunked = Resampler::new(48_000)?;
        let mut b = Vec::new();
        for chunk in input.chunks(977) {
            // odd chunk size — exercises boundary handling
            b.extend(chunked.push(chunk)?);
        }
        b.extend(chunked.flush()?);

        assert_eq!(a.len(), b.len(), "chunked and one-shot length disagree");
        for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < 1e-5,
                "sample {i}: chunked={y}, one-shot={x}"
            );
        }
        Ok(())
    }

    #[test]
    fn wav_writer_round_trips_samples() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let path = tmp.path().join("out.wav");
        let mut writer = WavWriter::create(&path)?;
        let samples: Vec<f32> = (0..1000)
            .map(|i| (TAU * 440.0 * i as f32 / 16_000.0).sin() * 0.25)
            .collect();
        writer.write_samples(&samples)?;
        assert_eq!(writer.samples_written(), 1000);
        writer.finalize()?;

        let mut reader = hound::WavReader::open(&path)?;
        let spec = reader.spec();
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.sample_rate, TARGET_SAMPLE_RATE);
        assert_eq!(spec.bits_per_sample, OUTPUT_BITS_PER_SAMPLE);
        assert_eq!(spec.sample_format, SampleFormat::Int);
        let decoded: Vec<f32> = reader
            .samples::<i16>()
            .map(|s| f32::from(s.unwrap()) / f32::from(i16::MAX))
            .collect();
        assert_eq!(decoded.len(), 1000);
        // Round-trip drift bounded by 1 lsb of 16-bit quantisation.
        for (i, (orig, got)) in samples.iter().zip(decoded.iter()).enumerate() {
            assert!(
                (orig - got).abs() < 1.0 / f32::from(i16::MAX),
                "sample {i}: orig={orig}, got={got}"
            );
        }
        Ok(())
    }

    #[test]
    fn wav_writer_clamps_samples_to_int_range() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let path = tmp.path().join("clamp.wav");
        let mut writer = WavWriter::create(&path)?;
        // Values outside [-1.0, 1.0] should clamp, not wrap.
        writer.write_samples(&[2.0, -2.0, 0.5, -0.5])?;
        writer.finalize()?;

        let mut reader = hound::WavReader::open(&path)?;
        let decoded: Vec<i16> = reader.samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(decoded[0], i16::MAX, "2.0 should clamp to i16::MAX");
        // -1.0 * i16::MAX = -32767, not i16::MIN (-32768) — we scale by
        // i16::MAX, not i16::MIN.abs(), so symmetric range.
        assert_eq!(decoded[1], -i16::MAX, "-2.0 should clamp to -i16::MAX");
        // ±0.5 * 32767 = ±16383.5 exactly; `f32::round` rounds half-way
        // cases away from zero, giving ±16384.
        assert_eq!(decoded[2], 16384);
        assert_eq!(decoded[3], -16384);
        Ok(())
    }

    #[test]
    fn wav_writer_creates_parent_dirs() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let nested = tmp.path().join("a").join("b").join("c");
        let path = nested.join("nested.wav");
        let writer = WavWriter::create(&path)?;
        writer.finalize()?;
        assert!(path.exists());
        Ok(())
    }

    #[test]
    fn mono_mixdown_passes_through_zero_channels_unchanged() {
        // channels == 0 hits the `<= 1` branch and returns the input as-is;
        // documents the no-op behaviour rather than letting it silently
        // change.
        let input = vec![0.1, 0.2, 0.3];
        assert_eq!(mono_mixdown(&input, 0), input);
    }

    #[test]
    fn resampler_push_empty_returns_empty() -> Result<()> {
        let mut r = Resampler::new(48_000)?;
        let out = r.push(&[])?;
        assert!(out.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_flush_with_no_pending_input_is_empty() -> Result<()> {
        // Identity path: no inner resampler, flush always returns empty.
        let mut r = Resampler::new(TARGET_SAMPLE_RATE)?;
        assert!(r.flush()?.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_identity_push_empty_returns_empty() -> Result<()> {
        let mut r = Resampler::new(TARGET_SAMPLE_RATE)?;
        let out = r.push(&[])?;
        assert!(out.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_sinc_flush_with_no_pending_input_is_empty() -> Result<()> {
        // Nothing pushed: the empty-buffer early return applies even when a
        // real `rubato` resampler is configured.
        let mut r = Resampler::new(48_000)?;
        assert!(r.flush()?.is_empty());
        Ok(())
    }

    #[test]
    fn wav_writer_writes_nan_as_silence() -> Result<()> {
        // NaN survives `clamp`, and the float-to-int `as` cast maps it to 0.
        let tmp = tempfile::TempDir::new()?;
        let path = tmp.path().join("nan.wav");
        let mut writer = WavWriter::create(&path)?;
        writer.write_samples(&[f32::NAN])?;
        assert_eq!(writer.samples_written(), 1);
        writer.finalize()?;

        let mut reader = hound::WavReader::open(&path)?;
        let decoded: Vec<i16> = reader.samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(decoded, vec![0]);
        Ok(())
    }
}