use rubato::{
Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
};
use crate::audio::{AudioError, AudioFrame};
/// Oversampling factor of the precomputed sinc table; passed as rubato's `oversampling_factor`.
const OVERSAMPLING: usize = 1 << 8;
/// Length of the windowed sinc filter; passed as rubato's `sinc_len`.
const SINC_LEN: usize = 1 << 8;
/// Filter cutoff passed as rubato's `f_cutoff` (per rubato docs, relative to Nyquist).
const F_CUTOFF: f32 = 0.95_f32;
/// Streaming sample-rate converter for interleaved `f32` audio, backed by rubato's
/// [`SincFixedIn`].
///
/// Incoming frames are split into per-channel buffers and handed to rubato in fixed chunks of
/// `chunk_size` frames; frames that do not yet fill a chunk are held until the next call.
pub struct AudioResampler {
    // --- configuration (fixed at construction) ---
    /// Sample rate every input frame must carry.
    in_rate: u32,
    /// Sample rate of the produced output.
    out_rate: u32,
    /// Interleaved channel count (validated to 1..=8 by the constructor).
    channels: u8,
    /// Number of frames per channel fed to rubato in a single call.
    chunk_size: usize,

    // --- processing state ---
    /// The underlying rubato sinc resampler.
    resampler: SincFixedIn<f32>,
    /// Scratch input for rubato: one `chunk_size`-long plane per channel.
    deinterleaved: Vec<Vec<f32>>,
    /// Per-channel frames received but not yet resampled.
    carry: Vec<Vec<f32>>,
}
impl AudioResampler {
    /// Creates a resampler converting interleaved `f32` audio from `in_rate` to `out_rate`.
    ///
    /// `chunk_size` is the number of frames per channel handed to rubato per call; input is
    /// buffered internally until a full chunk is available.
    ///
    /// # Errors
    ///
    /// - [`AudioError::Resample`] if either rate is zero, `chunk_size` is zero, or rubato rejects
    ///   the configuration.
    /// - [`AudioError::Unsupported`] if `channels` is not in `1..=8`.
    pub fn new(
        in_rate: u32,
        out_rate: u32,
        channels: u8,
        chunk_size: usize,
    ) -> Result<Self, AudioError> {
        if in_rate == 0 || out_rate == 0 {
            return Err(AudioError::Resample(format!(
                "invalid sample rate {in_rate} -> {out_rate}"
            )));
        }
        if channels == 0 || channels > 8 {
            return Err(AudioError::Unsupported(format!(
                "resampler channel count {channels} (must be 1..=8)"
            )));
        }
        if chunk_size == 0 {
            return Err(AudioError::Resample("chunk_size must be > 0".to_string()));
        }
        let params = SincInterpolationParameters {
            sinc_len: SINC_LEN,
            f_cutoff: F_CUTOFF,
            interpolation: SincInterpolationType::Cubic,
            oversampling_factor: OVERSAMPLING,
            window: WindowFunction::BlackmanHarris2,
        };
        let ratio = f64::from(out_rate) / f64::from(in_rate);
        let chans = usize::from(channels);
        // 2.0 is rubato's headroom for later ratio changes (`set_resample_ratio`); this type
        // never changes the ratio after construction.
        let resampler = SincFixedIn::<f32>::new(ratio, 2.0, params, chunk_size, chans)
            .map_err(|e| AudioError::Resample(format!("rubato init: {e:?}")))?;
        let deinterleaved = vec![vec![0.0f32; chunk_size]; chans];
        let carry = vec![Vec::new(); chans];
        Ok(Self {
            resampler,
            in_rate,
            out_rate,
            channels,
            chunk_size,
            deinterleaved,
            carry,
        })
    }

    /// Input sample rate this resampler expects.
    #[must_use]
    pub fn in_rate(&self) -> u32 {
        self.in_rate
    }

    /// Output sample rate this resampler produces.
    #[must_use]
    pub fn out_rate(&self) -> u32 {
        self.out_rate
    }

    /// Interleaved channel count of both input and output.
    #[must_use]
    pub fn channels(&self) -> u8 {
        self.channels
    }

    /// Frames per channel handed to rubato per call.
    #[must_use]
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Resamples one interleaved frame, appending interleaved output samples to `out`.
    ///
    /// Input is buffered per channel; rubato runs once for every complete `chunk_size`-frame
    /// chunk, and any remainder is kept for the next call (or for [`Self::flush`]). Nothing is
    /// appended while less than a full chunk is buffered.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Resample`] if the frame's channel count or sample rate differs from
    /// this resampler's, if `frame.samples.len()` is not a multiple of the channel count, or if
    /// rubato fails. When rubato fails, the chunk being processed (and all chunks before it in
    /// this call) are discarded. Output already appended for those earlier chunks stays in `out`.
    pub fn process(&mut self, frame: &AudioFrame, out: &mut Vec<f32>) -> Result<(), AudioError> {
        if frame.channels != self.channels {
            return Err(AudioError::Resample(format!(
                "channel mismatch: resampler={}, frame={}",
                self.channels, frame.channels
            )));
        }
        if frame.sample_rate != self.in_rate {
            return Err(AudioError::Resample(format!(
                "sample rate mismatch: resampler in_rate={}, frame={}",
                self.in_rate, frame.sample_rate
            )));
        }
        let chans = usize::from(self.channels);
        // A ragged tail would otherwise be dropped silently.
        if frame.samples.len() % chans != 0 {
            return Err(AudioError::Resample(format!(
                "frame sample count {} is not a multiple of channel count {}",
                frame.samples.len(),
                chans
            )));
        }

        // De-interleave, appending after any frames left over from earlier calls.
        for (ch, carry) in self.carry.iter_mut().enumerate() {
            carry.extend(frame.samples.iter().skip(ch).step_by(chans).copied());
        }
        debug_assert!(self.carry.iter().all(|c| c.len() == self.carry[0].len()));

        // Walk the buffered frames with a read offset and drop the consumed prefix once at the
        // end. Draining after every chunk would shift the remaining data on each iteration,
        // which is quadratic for large frames.
        let buffered = self.carry[0].len();
        let mut consumed = 0;
        while buffered - consumed >= self.chunk_size {
            let end = consumed + self.chunk_size;
            for (dst, src) in self.deinterleaved.iter_mut().zip(&self.carry) {
                dst.copy_from_slice(&src[consumed..end]);
            }
            consumed = end;
            let result = match self.resampler.process(&self.deinterleaved, None) {
                Ok(planes) => planes,
                Err(e) => {
                    // Never re-feed chunks on the next call: drop everything taken so far,
                    // including the chunk that failed.
                    self.discard_carry(consumed);
                    return Err(AudioError::Resample(format!("rubato process: {e:?}")));
                }
            };
            Self::interleave_into(&result, out);
        }
        self.discard_carry(consumed);
        Ok(())
    }

    /// Resamples whatever is still buffered, zero-padded to a full chunk, and appends it to `out`.
    ///
    /// Does nothing if no frames are buffered. The whole padded chunk is resampled, so the
    /// appended output covers roughly `chunk_size * out_rate / in_rate` frames however few real
    /// frames were buffered. The tail of that output is the resampled zero padding.
    ///
    /// NOTE(review): frames still inside rubato's sinc delay line are not explicitly drained, and
    /// the resampler is not reset afterwards. Confirm whether callers need an extra zero chunk or a
    /// fresh instance after flushing.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Resample`] if rubato fails; the buffered frames are discarded either
    /// way.
    pub fn flush(&mut self, out: &mut Vec<f32>) -> Result<(), AudioError> {
        if self.carry[0].is_empty() {
            return Ok(());
        }
        let chunk = self.chunk_size;
        for (dst, src) in self.deinterleaved.iter_mut().zip(self.carry.iter_mut()) {
            // At most one chunk is consumed; anything beyond it (only possible after a failed
            // `process`) is discarded, as before.
            let n = src.len().min(chunk);
            dst[..n].copy_from_slice(&src[..n]);
            dst[n..].fill(0.0);
            src.clear();
        }
        let result = self
            .resampler
            .process(&self.deinterleaved, None)
            .map_err(|e| AudioError::Resample(format!("rubato flush: {e:?}")))?;
        Self::interleave_into(&result, out);
        Ok(())
    }

    /// Removes the first `frames` frames from every channel's carry buffer.
    fn discard_carry(&mut self, frames: usize) {
        for carry in &mut self.carry {
            carry.drain(..frames);
        }
    }

    /// Appends per-channel `planes` to `out` in interleaved (frame-major) order.
    fn interleave_into(planes: &[Vec<f32>], out: &mut Vec<f32>) {
        let n_frames = planes.first().map_or(0, Vec::len);
        out.reserve(n_frames * planes.len());
        for i in 0..n_frames {
            out.extend(planes.iter().map(|plane| plane[i]));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame at 44.1 kHz with the given interleaved samples.
    fn frame_44k(samples: Vec<f32>, channels: u8) -> AudioFrame {
        AudioFrame {
            samples,
            sample_rate: 44100,
            channels,
            pts: 0,
        }
    }

    #[test]
    fn resample_44100_to_48000_preserves_sample_count_within_tolerance() {
        let chunk = 44100;
        let mut r = AudioResampler::new(44100, 48000, 1, chunk).expect("resampler");
        let frame = AudioFrame {
            samples: vec![0.0f32; chunk],
            sample_rate: 44100,
            channels: 1,
            pts: 0,
        };
        let mut out = Vec::new();
        r.process(&frame, &mut out).expect("process");
        let diff = (out.len() as i64 - 48000).abs();
        assert!(
            diff <= 480,
            "expected ~48000 output samples, got {} (diff {} — sinc filter delay from SINC_LEN)",
            out.len(),
            diff
        );
    }

    #[test]
    fn resample_rejects_zero_rates() {
        assert!(AudioResampler::new(0, 48000, 1, 1024).is_err());
        assert!(AudioResampler::new(44100, 0, 1, 1024).is_err());
    }

    #[test]
    fn resample_rejects_unsupported_channels() {
        assert!(AudioResampler::new(44100, 48000, 0, 1024).is_err());
        assert!(AudioResampler::new(44100, 48000, 9, 1024).is_err());
        assert!(AudioResampler::new(44100, 48000, 6, 1024).is_ok());
    }

    #[test]
    fn resample_input_validation_catches_channel_mismatch() {
        let mut r = AudioResampler::new(44100, 48000, 2, 1024).expect("resampler");
        let frame = AudioFrame {
            samples: vec![0.0f32; 1024],
            sample_rate: 44100,
            channels: 1,
            pts: 0,
        };
        let mut out = Vec::new();
        assert!(r.process(&frame, &mut out).is_err());
    }

    #[test]
    fn resample_input_validation_catches_rate_mismatch() {
        let mut r = AudioResampler::new(44100, 48000, 2, 1024).expect("resampler");
        let frame = AudioFrame {
            samples: vec![0.0f32; 2048],
            sample_rate: 22050,
            channels: 2,
            pts: 0,
        };
        let mut out = Vec::new();
        assert!(r.process(&frame, &mut out).is_err());
    }

    #[test]
    fn resample_stereo_44100_to_48000_interleaved_layout_preserved() {
        let chunk = 44100;
        let mut r = AudioResampler::new(44100, 48000, 2, chunk).expect("resampler");
        let mut samples = Vec::with_capacity(chunk * 2);
        for _ in 0..chunk {
            samples.push(0.1f32);
            samples.push(-0.1f32);
        }
        let frame = AudioFrame {
            samples,
            sample_rate: 44100,
            channels: 2,
            pts: 0,
        };
        let mut out = Vec::new();
        r.process(&frame, &mut out).expect("process");
        assert!(out.len() % 2 == 0, "stereo output must be even");
        // Must stay even so that index i is always the left channel.
        let warmup = 512;
        let mut ok_l = 0;
        let mut ok_r = 0;
        for i in (warmup..out.len()).step_by(2) {
            if (out[i] - 0.1).abs() < 0.05 {
                ok_l += 1;
            }
            if (out[i + 1] - (-0.1)).abs() < 0.05 {
                ok_r += 1;
            }
        }
        assert!(
            ok_l > 100,
            "L channel should converge near 0.1; got {ok_l} matches"
        );
        assert!(
            ok_r > 100,
            "R channel should converge near -0.1; got {ok_r} matches"
        );
    }

    #[test]
    fn resample_split_input_matches_single_call() {
        // Feeding the same signal in uneven pieces must hit rubato with identical chunks, so the
        // output is bit-identical to a single large call.
        let chunk = 1024;
        let total_frames = chunk * 3 + 17;
        let samples: Vec<f32> = (0..total_frames)
            .map(|i| (i as f32 * 0.05).sin())
            .collect();

        let mut whole = AudioResampler::new(44100, 48000, 1, chunk).expect("resampler");
        let mut out_whole = Vec::new();
        whole
            .process(&frame_44k(samples.clone(), 1), &mut out_whole)
            .expect("process");

        let mut split = AudioResampler::new(44100, 48000, 1, chunk).expect("resampler");
        let mut out_split = Vec::new();
        for piece in samples.chunks(500) {
            split
                .process(&frame_44k(piece.to_vec(), 1), &mut out_split)
                .expect("process");
        }

        assert!(!out_whole.is_empty());
        assert_eq!(out_whole, out_split);
    }

    #[test]
    fn flush_emits_leftover_once_then_nothing() {
        let mut r = AudioResampler::new(44100, 48000, 1, 1024).expect("resampler");
        let mut out = Vec::new();

        r.flush(&mut out).expect("flush on empty resampler");
        assert!(out.is_empty(), "nothing buffered, nothing to flush");

        r.process(&frame_44k(vec![0.25f32; 100], 1), &mut out)
            .expect("process");
        assert!(out.is_empty(), "less than one chunk buffered; nothing emitted yet");

        r.flush(&mut out).expect("flush");
        assert!(!out.is_empty(), "flush should resample the buffered remainder");

        let after_first = out.len();
        r.flush(&mut out).expect("second flush");
        assert_eq!(out.len(), after_first, "second flush has nothing left to emit");
    }
}