1use rubato::{
28 Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
29};
30
31use crate::audio::{AudioError, AudioFrame};
32
/// Value passed as `sinc_len` in the rubato `SincInterpolationParameters`
/// built by `AudioResampler::new` (rubato documents it as the length of the
/// windowed sinc interpolation filter).
const SINC_LEN: usize = 256;
/// Value passed as `f_cutoff` in the same parameter struct (rubato documents
/// it as the filter cutoff relative to the lower of the two Nyquist
/// frequencies).
const F_CUTOFF: f32 = 0.95;
/// Value passed as `oversampling_factor` in the same parameter struct (rubato
/// documents it as the number of intermediate points used for interpolation).
const OVERSAMPLING: usize = 256;
44
45pub struct AudioResampler {
46 resampler: SincFixedIn<f32>,
47 in_rate: u32,
48 out_rate: u32,
49 channels: u8,
50 chunk_size: usize,
51 deinterleaved: Vec<Vec<f32>>,
54 carry: Vec<Vec<f32>>,
58}
59
60impl AudioResampler {
61 pub fn new(
65 in_rate: u32,
66 out_rate: u32,
67 channels: u8,
68 chunk_size: usize,
69 ) -> Result<Self, AudioError> {
70 if in_rate == 0 || out_rate == 0 {
71 return Err(AudioError::Resample(format!(
72 "invalid sample rate {in_rate} -> {out_rate}"
73 )));
74 }
75 if channels == 0 || channels > 8 {
84 return Err(AudioError::Unsupported(format!(
85 "resampler channel count {channels} (must be 1..=8)"
86 )));
87 }
88 if chunk_size == 0 {
89 return Err(AudioError::Resample("chunk_size must be > 0".to_string()));
90 }
91
92 let params = SincInterpolationParameters {
93 sinc_len: SINC_LEN,
94 f_cutoff: F_CUTOFF,
95 interpolation: SincInterpolationType::Cubic,
96 oversampling_factor: OVERSAMPLING,
97 window: WindowFunction::BlackmanHarris2,
98 };
99
100 let ratio = f64::from(out_rate) / f64::from(in_rate);
101 let resampler = SincFixedIn::<f32>::new(ratio, 2.0, params, chunk_size, channels as usize)
102 .map_err(|e| AudioError::Resample(format!("rubato init: {e:?}")))?;
103
104 let deinterleaved = vec![vec![0.0f32; chunk_size]; channels as usize];
105 let carry = vec![Vec::new(); channels as usize];
106
107 Ok(Self {
108 resampler,
109 in_rate,
110 out_rate,
111 channels,
112 chunk_size,
113 deinterleaved,
114 carry,
115 })
116 }
117
118 pub fn in_rate(&self) -> u32 {
119 self.in_rate
120 }
121 pub fn out_rate(&self) -> u32 {
122 self.out_rate
123 }
124 pub fn channels(&self) -> u8 {
125 self.channels
126 }
127 pub fn chunk_size(&self) -> usize {
128 self.chunk_size
129 }
130
131 pub fn process(&mut self, frame: &AudioFrame, out: &mut Vec<f32>) -> Result<(), AudioError> {
140 if frame.channels != self.channels {
141 return Err(AudioError::Resample(format!(
142 "channel mismatch: resampler={}, frame={}",
143 self.channels, frame.channels
144 )));
145 }
146 if frame.sample_rate != self.in_rate {
147 return Err(AudioError::Resample(format!(
148 "sample rate mismatch: resampler in_rate={}, frame={}",
149 self.in_rate, frame.sample_rate
150 )));
151 }
152
153 let chans = self.channels as usize;
155 let frames = frame.samples.len() / chans;
156 for ch in 0..chans {
157 let base = self.carry[ch].len();
158 self.carry[ch].reserve(frames);
159 for i in 0..frames {
160 self.carry[ch].push(frame.samples[i * chans + ch]);
161 }
162 debug_assert_eq!(self.carry[ch].len(), base + frames);
166 }
167
168 while self.carry[0].len() >= self.chunk_size {
170 for ch in 0..chans {
171 self.deinterleaved[ch].copy_from_slice(&self.carry[ch][..self.chunk_size]);
172 }
173 for ch in 0..chans {
174 self.carry[ch].drain(..self.chunk_size);
175 }
176 let result = self
177 .resampler
178 .process(&self.deinterleaved, None)
179 .map_err(|e| AudioError::Resample(format!("rubato process: {e:?}")))?;
180 let n_out = result[0].len();
182 out.reserve(n_out * chans);
183 for i in 0..n_out {
184 for ch in 0..chans {
185 out.push(result[ch][i]);
186 }
187 }
188 }
189
190 Ok(())
191 }
192
193 pub fn flush(&mut self, out: &mut Vec<f32>) -> Result<(), AudioError> {
196 let chans = self.channels as usize;
197 let n = self.carry[0].len();
198 if n == 0 {
199 return Ok(());
200 }
201 for ch in 0..chans {
202 self.carry[ch].resize(self.chunk_size, 0.0);
203 self.deinterleaved[ch].copy_from_slice(&self.carry[ch][..self.chunk_size]);
204 self.carry[ch].clear();
205 }
206 let result = self
207 .resampler
208 .process(&self.deinterleaved, None)
209 .map_err(|e| AudioError::Resample(format!("rubato flush: {e:?}")))?;
210 let n_out = result[0].len();
211 out.reserve(n_out * chans);
212 for i in 0..n_out {
213 for ch in 0..chans {
214 out.push(result[ch][i]);
215 }
216 }
217 Ok(())
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps interleaved `samples` in an `AudioFrame` with `pts` set to 0.
    fn frame_of(samples: Vec<f32>, sample_rate: u32, channels: u8) -> AudioFrame {
        AudioFrame {
            samples,
            sample_rate,
            channels,
            pts: 0,
        }
    }

    #[test]
    fn resample_44100_to_48000_preserves_sample_count_within_tolerance() {
        // One chunk of exactly one second of mono input.
        let chunk = 44100;
        let mut r = AudioResampler::new(44100, 48000, 1, chunk).expect("resampler");
        let frame = frame_of(vec![0.0f32; chunk], 44100, 1);
        let mut out = Vec::new();
        r.process(&frame, &mut out).expect("process");
        let diff = (out.len() as i64 - 48000).abs();
        assert!(
            diff <= 480,
            "expected ~48000 output samples, got {} (diff {} — sinc filter delay from SINC_LEN)",
            out.len(),
            diff
        );
    }

    #[test]
    fn resample_rejects_zero_rates() {
        assert!(AudioResampler::new(0, 48000, 1, 1024).is_err());
        assert!(AudioResampler::new(44100, 0, 1, 1024).is_err());
    }

    #[test]
    fn resample_rejects_unsupported_channels() {
        assert!(AudioResampler::new(44100, 48000, 0, 1024).is_err());
        assert!(AudioResampler::new(44100, 48000, 9, 1024).is_err());
        assert!(AudioResampler::new(44100, 48000, 6, 1024).is_ok());
    }

    #[test]
    fn resample_input_validation_catches_channel_mismatch() {
        let mut r = AudioResampler::new(44100, 48000, 2, 1024).expect("resampler");
        let frame = frame_of(vec![0.0f32; 1024], 44100, 1);
        let mut out = Vec::new();
        assert!(r.process(&frame, &mut out).is_err());
    }

    #[test]
    fn resample_input_validation_catches_rate_mismatch() {
        let mut r = AudioResampler::new(44100, 48000, 2, 1024).expect("resampler");
        let frame = frame_of(vec![0.0f32; 2048], 22050, 2);
        let mut out = Vec::new();
        assert!(r.process(&frame, &mut out).is_err());
    }

    #[test]
    fn resample_stereo_44100_to_48000_interleaved_layout_preserved() {
        let chunk = 44100;
        let mut r = AudioResampler::new(44100, 48000, 2, chunk).expect("resampler");
        // Constant +0.1 on the left channel and -0.1 on the right.
        let mut samples = Vec::with_capacity(chunk * 2);
        for _ in 0..chunk {
            samples.push(0.1f32);
            samples.push(-0.1f32);
        }
        let frame = frame_of(samples, 44100, 2);
        let mut out = Vec::new();
        r.process(&frame, &mut out).expect("process");
        assert!(out.len() % 2 == 0, "stereo output must be even");
        // Skip the first samples, where the filter output has not settled yet.
        let warmup = 512;
        let mut ok_l = 0;
        let mut ok_r = 0;
        for i in (warmup..out.len()).step_by(2) {
            if (out[i] - 0.1).abs() < 0.05 {
                ok_l += 1;
            }
            if (out[i + 1] - (-0.1)).abs() < 0.05 {
                ok_r += 1;
            }
        }
        assert!(
            ok_l > 100,
            "L channel should converge near 0.1; got {ok_l} matches"
        );
        assert!(
            ok_r > 100,
            "R channel should converge near -0.1; got {ok_r} matches"
        );
    }

    #[test]
    fn flush_without_buffered_input_emits_nothing() {
        let mut r = AudioResampler::new(44100, 48000, 2, 1024).expect("resampler");
        let mut out = Vec::new();
        r.flush(&mut out).expect("flush");
        assert!(out.is_empty(), "flush with nothing buffered must not emit samples");
    }

    #[test]
    fn partial_chunk_is_held_until_flush() {
        let mut r = AudioResampler::new(44100, 48000, 2, 1024).expect("resampler");
        // 100 stereo frames: fewer than one 1024-frame chunk.
        let frame = frame_of(vec![0.0f32; 200], 44100, 2);
        let mut out = Vec::new();
        r.process(&frame, &mut out).expect("process");
        assert!(out.is_empty(), "less than one chunk must stay buffered");

        r.flush(&mut out).expect("flush");
        assert!(!out.is_empty(), "flush must emit the buffered remainder");
        assert!(out.len() % 2 == 0, "stereo output must be even");

        // The buffer was emptied by the first flush, so a second one is a no-op.
        let emitted = out.len();
        r.flush(&mut out).expect("second flush");
        assert_eq!(out.len(), emitted);
    }
}