1use audioadapter_buffers::direct::InterleavedSlice;
5use rubato::{
6 Async, FixedAsync, PolynomialDegree, SincInterpolationParameters, SincInterpolationType,
7 WindowFunction,
8};
9use wedeo_core::error::{Error, Result};
10
/// Speed/fidelity trade-off used when constructing a [`Resampler`].
///
/// `Fast` selects rubato's polynomial interpolator, while `Normal` and `High`
/// select windowed-sinc interpolation with progressively longer filters.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quality {
    /// Cubic polynomial interpolation (cheapest).
    Fast,
    /// Windowed sinc, 128-tap filter, linear interpolation of the sinc table.
    Normal,
    /// Windowed sinc, 256-tap filter, cubic interpolation of the sinc table.
    High,
}
22
23pub struct Resampler {
28 inner: Box<dyn rubato::Resampler<f32>>,
29 channels: usize,
30 from_rate: u32,
31 to_rate: u32,
32 chunk_size: usize,
34 pending: Vec<f32>,
36}
37
38const DEFAULT_CHUNK_SIZE: usize = 1024;
40
41fn map_rubato_err<E: std::fmt::Display>(e: E) -> Error {
42 Error::Other(format!("resample: {e}"))
43}
44
45impl Resampler {
46 pub fn new(from_rate: u32, to_rate: u32, channels: usize, quality: Quality) -> Result<Self> {
52 if from_rate == 0 || to_rate == 0 {
53 return Err(Error::InvalidArgument);
54 }
55 if channels == 0 {
56 return Err(Error::InvalidArgument);
57 }
58
59 let ratio = to_rate as f64 / from_rate as f64;
60 let chunk_size = DEFAULT_CHUNK_SIZE;
61
62 let inner: Box<dyn rubato::Resampler<f32>> = match quality {
63 Quality::Fast => Box::new(
64 Async::<f32>::new_poly(
65 ratio,
66 1.0,
67 PolynomialDegree::Cubic,
68 chunk_size,
69 channels,
70 FixedAsync::Input,
71 )
72 .map_err(map_rubato_err)?,
73 ),
74 Quality::Normal => {
75 let params = SincInterpolationParameters {
76 sinc_len: 128,
77 f_cutoff: 0.95,
78 interpolation: SincInterpolationType::Linear,
79 oversampling_factor: 128,
80 window: WindowFunction::BlackmanHarris2,
81 };
82 Box::new(
83 Async::<f32>::new_sinc(
84 ratio,
85 1.0,
86 ¶ms,
87 chunk_size,
88 channels,
89 FixedAsync::Input,
90 )
91 .map_err(map_rubato_err)?,
92 )
93 }
94 Quality::High => {
95 let params = SincInterpolationParameters {
96 sinc_len: 256,
97 f_cutoff: 0.95,
98 interpolation: SincInterpolationType::Cubic,
99 oversampling_factor: 256,
100 window: WindowFunction::BlackmanHarris2,
101 };
102 Box::new(
103 Async::<f32>::new_sinc(
104 ratio,
105 1.0,
106 ¶ms,
107 chunk_size,
108 channels,
109 FixedAsync::Input,
110 )
111 .map_err(map_rubato_err)?,
112 )
113 }
114 };
115
116 Ok(Self {
117 inner,
118 channels,
119 from_rate,
120 to_rate,
121 chunk_size,
122 pending: Vec::new(),
123 })
124 }
125
126 pub fn process(&mut self, input: &[f32]) -> Result<Vec<f32>> {
135 if !input.len().is_multiple_of(self.channels) {
136 return Err(Error::InvalidArgument);
137 }
138
139 self.pending.extend_from_slice(input);
141
142 let samples_per_chunk = self.chunk_size * self.channels;
143 let mut output = Vec::new();
144
145 while self.pending.len() >= samples_per_chunk {
147 let rest = self.pending.split_off(samples_per_chunk);
150 let chunk = std::mem::replace(&mut self.pending, rest);
151 self.process_chunk(&chunk, &mut output)?;
152 }
153
154 Ok(output)
155 }
156
157 pub fn flush(&mut self) -> Result<Vec<f32>> {
162 if self.pending.is_empty() {
163 return Ok(Vec::new());
164 }
165
166 let samples_per_chunk = self.chunk_size * self.channels;
167
168 let pending_frames = self.pending.len() / self.channels;
170 let partial_samples = self.pending.len() % self.channels;
171
172 if partial_samples != 0 {
174 self.pending
175 .resize(self.pending.len() + (self.channels - partial_samples), 0.0);
176 }
177
178 self.pending.resize(samples_per_chunk, 0.0);
180
181 let mut output = Vec::new();
182
183 let input_adapter = InterleavedSlice::new(&self.pending, self.channels, self.chunk_size)
184 .map_err(map_rubato_err)?;
185
186 let result = self
187 .inner
188 .process(&input_adapter, 0, None)
189 .map_err(map_rubato_err)?;
190
191 let out_data = result.take_data();
193 let out_frames = self.inner.output_frames_next();
194 let valid_samples = out_frames.min(out_data.len() / self.channels) * self.channels;
196 output.extend_from_slice(&out_data[..valid_samples]);
197
198 let expected_out_frames = self.output_frames_estimate(pending_frames);
201 let expected_samples = expected_out_frames * self.channels;
202 if output.len() > expected_samples {
203 output.truncate(expected_samples);
204 }
205
206 self.pending.clear();
207 Ok(output)
208 }
209
210 pub fn output_frames_estimate(&self, input_frames: usize) -> usize {
212 (input_frames as u64 * self.to_rate as u64).div_ceil(self.from_rate as u64) as usize
213 }
214
215 pub fn from_rate(&self) -> u32 {
217 self.from_rate
218 }
219
220 pub fn to_rate(&self) -> u32 {
222 self.to_rate
223 }
224
225 pub fn channels(&self) -> usize {
227 self.channels
228 }
229
230 pub fn reset(&mut self) {
232 self.inner.reset();
233 self.pending.clear();
234 }
235
236 fn process_chunk(&mut self, chunk: &[f32], output: &mut Vec<f32>) -> Result<()> {
239 let frames = chunk.len() / self.channels;
240 let input_adapter =
241 InterleavedSlice::new(chunk, self.channels, frames).map_err(map_rubato_err)?;
242
243 let result = self
244 .inner
245 .process(&input_adapter, 0, None)
246 .map_err(map_rubato_err)?;
247
248 let out_data = result.take_data();
249 let out_frames = out_data.len() / self.channels;
250 output.extend_from_slice(&out_data[..out_frames * self.channels]);
251 Ok(())
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_resampler_all_qualities() {
        for quality in [Quality::Fast, Quality::Normal, Quality::High] {
            let r = Resampler::new(44100, 48000, 2, quality);
            assert!(r.is_ok(), "Failed to create resampler with {quality:?}");
            let r = r.unwrap();
            assert_eq!(r.from_rate(), 44100);
            assert_eq!(r.to_rate(), 48000);
            assert_eq!(r.channels(), 2);
        }
    }

    #[test]
    fn invalid_params() {
        assert!(Resampler::new(0, 48000, 2, Quality::Fast).is_err());
        assert!(Resampler::new(44100, 0, 2, Quality::Fast).is_err());
        assert!(Resampler::new(44100, 48000, 0, Quality::Fast).is_err());
    }

    #[test]
    fn process_silence() {
        let mut r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        let input = vec![0.0f32; 44100];
        let output = r.process(&input).unwrap();
        let tail = r.flush().unwrap();
        // Mono stream, so the sample count equals the frame count.
        let total_frames = output.len() + tail.len();
        assert!(
            total_frames > 40000 && total_frames < 56000,
            "Unexpected output length: {total_frames}"
        );
        for &s in output.iter().chain(tail.iter()) {
            assert!(s.abs() < 1e-6, "Non-silent sample in silence resample: {s}");
        }
    }

    #[test]
    fn process_non_multiple_of_chunk() {
        let mut r = Resampler::new(48000, 16000, 2, Quality::Fast).unwrap();
        let frames = 3000;
        let input = vec![0.0f32; frames * 2];
        let output = r.process(&input).unwrap();
        let tail = r.flush().unwrap();
        let total_frames = (output.len() + tail.len()) / 2;
        let expected = r.output_frames_estimate(frames);
        assert!(
            (total_frames as f64 - expected as f64).abs() / expected as f64 <= 0.2,
            "Output frame count {total_frames} too far from estimate {expected}"
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        let _ = r.process(&[0.5f32; 500]).unwrap();
        r.reset();
        let tail = r.flush().unwrap();
        assert!(tail.is_empty());
    }

    #[test]
    fn channel_mismatch_rejected() {
        let mut r = Resampler::new(44100, 48000, 2, Quality::Fast).unwrap();
        // Three samples cannot form whole stereo frames.
        let result = r.process(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }

    #[test]
    fn empty_input_is_noop() {
        let mut r = Resampler::new(44100, 48000, 2, Quality::Fast).unwrap();
        assert!(r.process(&[]).unwrap().is_empty());
        assert!(r.flush().unwrap().is_empty());
    }

    #[test]
    fn second_flush_is_empty() {
        let mut r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        let _ = r.process(&[0.0f32; 300]).unwrap();
        let _ = r.flush().unwrap();
        assert!(r.flush().unwrap().is_empty());
    }
}
331}