1use std::fs::{self, File};
20use std::io::BufWriter;
21use std::path::{Path, PathBuf};
22
23use anyhow::{Context, Result};
24use hound::{SampleFormat, WavSpec};
25use rubato::audioadapter_buffers::direct::InterleavedSlice;
26use rubato::{
27 Async, FixedAsync, Indexing, Resampler as _, SincInterpolationParameters,
28 SincInterpolationType, WindowFunction,
29};
30
/// Sample rate, in Hz, of everything this module emits: the output of
/// [`Resampler`] and the WAV files produced by [`WavWriter`].
pub const TARGET_SAMPLE_RATE: u32 = 16000;

/// Number of input frames handed to the rubato resampler per `process` call.
pub const RESAMPLER_CHUNK_FRAMES: usize = 4_096;
38
/// Averages interleaved multi-channel audio down to a single channel.
///
/// `samples` is read as consecutive frames of `channels` interleaved samples.
/// Each complete frame becomes one output sample equal to the mean of its
/// channels. A trailing partial frame (when `samples.len()` is not a multiple
/// of `channels`) is dropped. When `channels` is 0 or 1 the input is returned
/// unchanged, as an owned copy.
#[must_use]
pub fn mono_mixdown(samples: &[f32], channels: u16) -> Vec<f32> {
    let width = usize::from(channels);
    if width < 2 {
        return samples.to_vec();
    }
    // Multiply by the reciprocal instead of dividing every frame's sum.
    let scale = f32::from(channels).recip();
    samples
        .chunks_exact(width)
        .map(|frame| frame.iter().sum::<f32>() * scale)
        .collect()
}
63
64pub struct Resampler {
75 input_rate: u32,
76 inner: Option<Inner>,
77}
78
79struct Inner {
80 resampler: Async<f32>,
81 pending: Vec<f32>,
83 chunk_frames: usize,
85}
86
87impl Resampler {
88 pub fn new(input_rate: u32) -> Result<Self> {
92 if input_rate == 0 {
93 anyhow::bail!("Resampler input rate must be > 0");
94 }
95 if input_rate == TARGET_SAMPLE_RATE {
96 return Ok(Self {
97 input_rate,
98 inner: None,
99 });
100 }
101 let ratio = f64::from(TARGET_SAMPLE_RATE) / f64::from(input_rate);
102 let params = SincInterpolationParameters {
103 sinc_len: 256,
104 f_cutoff: 0.95,
105 oversampling_factor: 128,
106 interpolation: SincInterpolationType::Linear,
107 window: WindowFunction::BlackmanHarris2,
108 };
109 let resampler = Async::<f32>::new_sinc(
110 ratio,
111 1.0,
112 ¶ms,
113 RESAMPLER_CHUNK_FRAMES,
114 1,
115 FixedAsync::Input,
116 )
117 .with_context(|| {
118 format!("Failed to build resampler for {input_rate} Hz → {TARGET_SAMPLE_RATE} Hz")
119 })?;
120 Ok(Self {
121 input_rate,
122 inner: Some(Inner {
123 resampler,
124 pending: Vec::with_capacity(RESAMPLER_CHUNK_FRAMES * 2),
125 chunk_frames: RESAMPLER_CHUNK_FRAMES,
126 }),
127 })
128 }
129
130 #[must_use]
132 pub fn input_rate(&self) -> u32 {
133 self.input_rate
134 }
135
136 #[must_use]
138 pub fn output_rate(&self) -> u32 {
139 TARGET_SAMPLE_RATE
140 }
141
142 pub fn push(&mut self, mono: &[f32]) -> Result<Vec<f32>> {
147 let Some(inner) = self.inner.as_mut() else {
148 return Ok(mono.to_vec());
149 };
150 inner.pending.extend_from_slice(mono);
151 let mut out = Vec::new();
152 while inner.pending.len() >= inner.chunk_frames {
153 let chunk = &inner.pending[..inner.chunk_frames];
154 let Ok(input_adapter) = InterleavedSlice::new(chunk, 1, inner.chunk_frames) else {
155 unreachable!("chunk.len() == 1 * chunk_frames by construction")
156 };
157 let drained = inner
158 .resampler
159 .process(&input_adapter, 0, None)
160 .context("Resampler chunk processing failed")?;
161 inner.pending.drain(..inner.chunk_frames);
162 out.extend(drained.take_data());
164 }
165 Ok(out)
166 }
167
168 pub fn flush(&mut self) -> Result<Vec<f32>> {
172 let Some(inner) = self.inner.as_mut() else {
173 return Ok(Vec::new());
174 };
175 let tail = std::mem::take(&mut inner.pending);
176 if tail.is_empty() {
177 return Ok(Vec::new());
178 }
179 let Ok(input_adapter) = InterleavedSlice::new(&tail, 1, tail.len()) else {
180 unreachable!("tail.len() == 1 * tail.len() by construction")
181 };
182 let output_capacity = inner.resampler.output_frames_max();
183 let mut output_buf = vec![0.0_f32; output_capacity];
184 let Ok(mut output_adapter) = InterleavedSlice::new_mut(&mut output_buf, 1, output_capacity)
185 else {
186 unreachable!("output_buf.len() == 1 * output_capacity by construction")
187 };
188 let indexing = Indexing {
189 input_offset: 0,
190 output_offset: 0,
191 partial_len: Some(tail.len()),
192 active_channels_mask: None,
193 };
194 let (_in_frames, out_frames) = inner
195 .resampler
196 .process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
197 .context("Resampler flush failed")?;
198 output_buf.truncate(out_frames);
199 Ok(output_buf)
200 }
201}
202
/// Bit depth of the integer PCM samples written by [`WavWriter`], which
/// stores each sample as an `i16`.
pub const OUTPUT_BITS_PER_SAMPLE: u16 = i16::BITS as u16;
205
206pub struct WavWriter {
215 inner: hound::WavWriter<BufWriter<File>>,
216 path: PathBuf,
217 samples_written: u64,
218}
219
220impl WavWriter {
221 pub fn create(path: impl AsRef<Path>) -> Result<Self> {
225 let path = path.as_ref().to_path_buf();
226 if let Some(parent) = path.parent() {
227 if !parent.as_os_str().is_empty() {
228 fs::create_dir_all(parent).with_context(|| {
229 format!("Failed to create parent directory {}", parent.display())
230 })?;
231 }
232 }
233 let spec = WavSpec {
234 channels: 1,
235 sample_rate: TARGET_SAMPLE_RATE,
236 bits_per_sample: OUTPUT_BITS_PER_SAMPLE,
237 sample_format: SampleFormat::Int,
238 };
239 let inner = hound::WavWriter::create(&path, spec)
240 .with_context(|| format!("Failed to create WAV file at {}", path.display()))?;
241 Ok(Self {
242 inner,
243 path,
244 samples_written: 0,
245 })
246 }
247
248 pub fn write_samples(&mut self, samples: &[f32]) -> Result<()> {
251 for s in samples {
252 let clamped = s.clamp(-1.0, 1.0);
253 let scaled = (clamped * f32::from(i16::MAX)).round() as i16;
254 self.inner
255 .write_sample(scaled)
256 .with_context(|| format!("Failed to write sample to {}", self.path.display()))?;
257 }
258 self.samples_written += samples.len() as u64;
259 Ok(())
260 }
261
262 #[must_use]
265 pub fn samples_written(&self) -> u64 {
266 self.samples_written
267 }
268
269 #[must_use]
271 pub fn path(&self) -> &Path {
272 &self.path
273 }
274
275 pub fn finalize(self) -> Result<()> {
279 self.inner
280 .finalize()
281 .with_context(|| format!("Failed to finalize WAV file at {}", self.path.display()))
282 }
283}
284
#[cfg(test)]
// Tests report failure by panicking, so unwrap/expect are the intended tool here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use std::f32::consts::TAU;

    #[test]
    fn mono_mixdown_passes_through_mono_untouched() {
        let input = vec![0.1, -0.2, 0.3, -0.4];
        assert_eq!(mono_mixdown(&input, 1), input);
    }

    #[test]
    fn mono_mixdown_averages_stereo_to_zero_for_inverted_signal() {
        // Left and right are exact opposites, so every frame averages to 0.
        let input = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let out = mono_mixdown(&input, 2);
        assert_eq!(out, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn mono_mixdown_averages_quad_channel() {
        // Frames (1, 2, 3, 4) and (5, 6, 7, 8) average to 2.5 and 6.5.
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let out = mono_mixdown(&input, 4);
        assert_eq!(out, vec![2.5, 6.5]);
    }

    #[test]
    fn mono_mixdown_drops_trailing_partial_frame() {
        // The lone trailing 99.0 is half a stereo frame and must be ignored.
        let input = vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 99.0];
        let out = mono_mixdown(&input, 2);
        assert_eq!(out, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn resampler_identity_path_returns_input_verbatim() -> Result<()> {
        let mut r = Resampler::new(TARGET_SAMPLE_RATE)?;
        assert_eq!(r.input_rate(), TARGET_SAMPLE_RATE);
        assert_eq!(r.output_rate(), TARGET_SAMPLE_RATE);
        let input: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) - 0.5).collect();
        let out = r.push(&input)?;
        assert_eq!(out, input);
        let flushed = r.flush()?;
        assert!(flushed.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_rejects_zero_input_rate() {
        let err = Resampler::new(0).err().expect("must reject zero rate");
        assert!(err.to_string().contains("> 0"));
    }

    /// `duration_s` seconds of a sine at `freq_hz`, sampled at `rate` Hz.
    fn sine_wave(rate: u32, freq_hz: f32, duration_s: f32, amplitude: f32) -> Vec<f32> {
        let n = (rate as f32 * duration_s) as usize;
        (0..n)
            .map(|i| amplitude * (TAU * freq_hz * i as f32 / rate as f32).sin())
            .collect()
    }

    /// Root-mean-square level of `samples`; 0.0 for an empty slice.
    fn rms(samples: &[f32]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
        (sum_sq / samples.len() as f32).sqrt()
    }

    #[test]
    fn resampler_48k_to_16k_preserves_signal_rms() -> Result<()> {
        let input = sine_wave(48_000, 440.0, 2.0, 0.5);
        let mut r = Resampler::new(48_000)?;
        let mut output = r.push(&input)?;
        output.extend(r.flush()?);
        // 2 s of audio at 16 kHz.
        let expected_len: usize = 32_000;
        // Tolerance around the nominal length.
        let slack = 256;
        // The flushed tail may contribute up to one chunk's worth of output.
        let max_overrun = (RESAMPLER_CHUNK_FRAMES as f64 * 16_000.0 / 48_000.0).ceil() as usize;
        let min_len = expected_len.saturating_sub(slack);
        let max_len = expected_len + max_overrun + slack;
        assert!(
            output.len() >= min_len,
            "output too short: got {}, expected ≥ {min_len}",
            output.len()
        );
        assert!(
            output.len() <= max_len,
            "output too long: got {}, expected ≤ {max_len}",
            output.len()
        );
        // Skip the start of the output, where the sinc filter is still warming up.
        let warmup = 800;
        let in_rms = rms(&input);
        let out_rms = rms(&output[warmup..]);
        assert!(
            (in_rms - out_rms).abs() < 0.02,
            "RMS drift too large: in={in_rms}, out={out_rms}"
        );
        Ok(())
    }

    #[test]
    fn resampler_chunked_and_one_shot_match() -> Result<()> {
        let input = sine_wave(48_000, 261.6, 1.0, 0.3);

        let mut one_shot = Resampler::new(48_000)?;
        let mut a = one_shot.push(&input)?;
        a.extend(one_shot.flush()?);

        // 977 is deliberately not a divisor of the chunk size, so chunks
        // straddle push boundaries.
        let mut chunked = Resampler::new(48_000)?;
        let mut b = Vec::new();
        for chunk in input.chunks(977) {
            b.extend(chunked.push(chunk)?);
        }
        b.extend(chunked.flush()?);

        assert_eq!(a.len(), b.len(), "chunked and one-shot length disagree");
        for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < 1e-5,
                "sample {i}: chunked={y}, one-shot={x}"
            );
        }
        Ok(())
    }

    #[test]
    fn resampler_buffers_until_full_chunk() -> Result<()> {
        let mut r = Resampler::new(48_000)?;
        let almost = vec![0.0_f32; RESAMPLER_CHUNK_FRAMES - 1];
        assert!(r.push(&almost)?.is_empty(), "partial chunk must stay buffered");
        assert!(
            !r.push(&[0.0])?.is_empty(),
            "completing the chunk must produce output"
        );
        Ok(())
    }

    #[test]
    fn wav_writer_round_trips_samples() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let path = tmp.path().join("out.wav");
        let mut writer = WavWriter::create(&path)?;
        let samples: Vec<f32> = (0..1000)
            .map(|i| (TAU * 440.0 * i as f32 / 16_000.0).sin() * 0.25)
            .collect();
        writer.write_samples(&samples)?;
        assert_eq!(writer.samples_written(), 1000);
        writer.finalize()?;

        let mut reader = hound::WavReader::open(&path)?;
        let spec = reader.spec();
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.sample_rate, TARGET_SAMPLE_RATE);
        assert_eq!(spec.bits_per_sample, OUTPUT_BITS_PER_SAMPLE);
        assert_eq!(spec.sample_format, SampleFormat::Int);
        let decoded: Vec<f32> = reader
            .samples::<i16>()
            .map(|s| f32::from(s.unwrap()) / f32::from(i16::MAX))
            .collect();
        assert_eq!(decoded.len(), 1000);
        // Rounding to the nearest integer step loses at most half an LSB.
        for (i, (orig, got)) in samples.iter().zip(decoded.iter()).enumerate() {
            assert!(
                (orig - got).abs() < 1.0 / f32::from(i16::MAX),
                "sample {i}: orig={orig}, got={got}"
            );
        }
        Ok(())
    }

    #[test]
    fn wav_writer_clamps_samples_to_int_range() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let path = tmp.path().join("clamp.wav");
        let mut writer = WavWriter::create(&path)?;
        writer.write_samples(&[2.0, -2.0, 0.5, -0.5])?;
        writer.finalize()?;

        let mut reader = hound::WavReader::open(&path)?;
        let decoded: Vec<i16> = reader.samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(decoded[0], i16::MAX, "2.0 should clamp to i16::MAX");
        assert_eq!(decoded[1], -i16::MAX, "-2.0 should clamp to -i16::MAX");
        // ±0.5 * 32767 = ±16383.5, which rounds away from zero to ±16384.
        assert!((decoded[2] - 16384).abs() <= 1);
        assert!((decoded[3] + 16384).abs() <= 1);
        Ok(())
    }

    #[test]
    fn wav_writer_writes_nan_as_silence() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let path = tmp.path().join("nan.wav");
        let mut writer = WavWriter::create(&path)?;
        writer.write_samples(&[f32::NAN])?;
        assert_eq!(writer.samples_written(), 1);
        writer.finalize()?;

        let mut reader = hound::WavReader::open(&path)?;
        let decoded: Vec<i16> = reader.samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(decoded, vec![0]);
        Ok(())
    }

    #[test]
    fn wav_writer_creates_parent_dirs() -> Result<()> {
        let tmp = tempfile::TempDir::new()?;
        let nested = tmp.path().join("a").join("b").join("c");
        let path = nested.join("nested.wav");
        let writer = WavWriter::create(&path)?;
        writer.finalize()?;
        assert!(path.exists());
        Ok(())
    }

    #[test]
    fn mono_mixdown_passes_through_zero_channels_unchanged() {
        // Zero channels is degenerate; the function treats it like mono.
        let input = vec![0.1, 0.2, 0.3];
        assert_eq!(mono_mixdown(&input, 0), input);
    }

    #[test]
    fn resampler_push_empty_returns_empty() -> Result<()> {
        let mut r = Resampler::new(48_000)?;
        let out = r.push(&[])?;
        assert!(out.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_flush_with_no_pending_input_is_empty() -> Result<()> {
        let mut identity = Resampler::new(TARGET_SAMPLE_RATE)?;
        assert!(identity.flush()?.is_empty());
        // Also cover the resampling path, where an empty buffer must take the
        // early return instead of running a zero-length partial chunk.
        let mut resampling = Resampler::new(48_000)?;
        assert!(resampling.flush()?.is_empty());
        Ok(())
    }

    #[test]
    fn resampler_identity_push_empty_returns_empty() -> Result<()> {
        let mut r = Resampler::new(TARGET_SAMPLE_RATE)?;
        let out = r.push(&[])?;
        assert!(out.is_empty());
        Ok(())
    }
}
515}