1use alloc::{sync::Arc, vec, vec::Vec};
2use core::fmt;
3#[cfg(not(feature = "no_std"))]
4use std::{
5 collections::HashMap,
6 sync::{LazyLock, Mutex},
7};
8
9use crate::{
10 Complex32, Forward, Inverse, Radix, RadixFFT, SampleRate,
11 error::ResampleError,
12 fft::planner::ConversionConfig,
13 window::{WindowType, calculate_cutoff_kaiser, make_sincs_for_kaiser},
14};
15
/// Shape (beta) parameter of the Kaiser window, passed to `calculate_cutoff_kaiser` and
/// `make_sincs_for_kaiser` when the resampling filter is designed in
/// `FftResampler::create_fft_data`.
const KAISER_BETA: f64 = 10.0;
17
18pub(crate) struct FftCacheData {
19 filter_spectrum: Arc<[Complex32]>,
20 fft: Arc<RadixFFT<Forward>>,
21 ifft: Arc<RadixFFT<Inverse>>,
22}
23
24impl Clone for FftCacheData {
25 fn clone(&self) -> Self {
26 Self {
27 filter_spectrum: Arc::clone(&self.filter_spectrum),
28 fft: Arc::clone(&self.fft),
29 ifft: Arc::clone(&self.ifft),
30 }
31 }
32}
33
/// Process-wide cache of `FftCacheData`, keyed by
/// `(sample_rate_input << 32) | sample_rate_output` (built in
/// `FftResampler::get_or_create_fft_data`).
///
/// Only compiled when the `no_std` feature is disabled; the `no_std` build recomputes the
/// data for every resampler instead.
#[cfg(not(feature = "no_std"))]
static FFT_CACHE: LazyLock<Mutex<HashMap<u64, FftCacheData>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
37
/// Multi-channel FFT resampler operating on fixed-size, interleaved chunks.
///
/// Each call to `resample` de-interleaves one chunk into per-channel scratch lanes,
/// runs the single-channel `FftResampler` over every lane, and re-interleaves the result.
pub struct ResamplerFft {
    /// Number of interleaved channels.
    channels: usize,
    /// Single-channel stage applied to each channel lane in turn.
    fft_resampler: FftResampler,
    /// Interleaved input samples consumed per call (`fft_size_input * channels`).
    chunk_size_input: usize,
    /// Interleaved output samples produced per call (`fft_size_output * channels`).
    chunk_size_output: usize,
    /// Samples per channel in one input FFT frame.
    fft_size_input: usize,
    /// Samples per channel in one output FFT frame.
    fft_size_output: usize,
    /// Offset into each output lane where new frames are written; set to 0 in `new` and
    /// not updated by the methods shown here.
    saved_frames: usize,
    /// Overlap-add tails, `fft_size_output` values per channel.
    overlaps: Vec<f32>,
    /// Planar input scratch: one lane of `chunk_size_input + fft_size_input` values per channel.
    input_scratch: Vec<f32>,
    /// Planar output scratch: one lane of `chunk_size_output + fft_size_output` values per channel.
    output_scratch: Vec<f32>,
}
55
56impl fmt::Debug for ResamplerFft {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 f.debug_struct("ResamplerFft")
59 .field("channels", &self.channels)
60 .field("chunk_size_input", &self.chunk_size_input)
61 .field("chunk_size_output", &self.chunk_size_output)
62 .field("fft_size_input", &self.fft_size_input)
63 .field("fft_size_output", &self.fft_size_output)
64 .finish_non_exhaustive()
65 }
66}
67
68impl ResamplerFft {
69 pub fn new(
76 channels: usize,
77 sample_rate_input: SampleRate,
78 sample_rate_output: SampleRate,
79 ) -> Self {
80 let config = ConversionConfig::from_sample_rates(sample_rate_input, sample_rate_output);
83 let (fft_size_input, factors_in, fft_size_output, factors_out) =
84 config.scale_for_throughput();
85
86 let overlaps: Vec<f32> = vec![0.0; fft_size_output * channels];
87
88 let chunk_size_input = fft_size_input * channels;
89 let chunk_size_output = fft_size_output * channels;
90
91 let needed_input_buffer_size = chunk_size_input + fft_size_input;
92 let needed_buffer_size_output = chunk_size_output + fft_size_output;
93 let input_scratch: Vec<f32> = vec![0.0; needed_input_buffer_size * channels];
94 let output_scratch: Vec<f32> = vec![0.0; needed_buffer_size_output * channels];
95
96 let saved_frames = 0;
97
98 let fft_resampler = FftResampler::new(
99 u32::from(sample_rate_input),
100 u32::from(sample_rate_output),
101 fft_size_input,
102 factors_in,
103 fft_size_output,
104 factors_out,
105 );
106
107 ResamplerFft {
108 channels,
109 chunk_size_input,
110 chunk_size_output,
111 fft_size_input,
112 fft_size_output,
113 overlaps,
114 input_scratch,
115 output_scratch,
116 saved_frames,
117 fft_resampler,
118 }
119 }
120
121 fn input_scratch_ch_size(&self) -> usize {
123 self.chunk_size_input + self.fft_size_input
124 }
125
126 fn output_scratch_ch_size(&self) -> usize {
128 self.chunk_size_input + self.fft_size_input
129 }
130
131 pub fn chunk_size_input(&self) -> usize {
136 self.chunk_size_input
137 }
138
139 pub fn chunk_size_output(&self) -> usize {
144 self.chunk_size_output
145 }
146
147 pub fn delay(&self) -> usize {
152 self.fft_size_input / 2
153 }
154
155 pub fn resample(&mut self, input: &[f32], output: &mut [f32]) -> Result<(), ResampleError> {
183 let expected_input_len = self.chunk_size_input;
184 let min_output_len = self.chunk_size_output;
185
186 if input.len() < expected_input_len {
187 return Err(ResampleError::InvalidInputBufferSize);
188 }
189
190 if output.len() < min_output_len {
191 return Err(ResampleError::InvalidOutputBufferSize);
192 }
193
194 let in_scratch_ch_len = self.input_scratch_ch_size();
195 let out_scratch_ch_len = self.output_scratch_ch_size();
196 (0..self.fft_size_input).for_each(|frame_index| {
198 (0..self.channels).for_each(|channel| {
199 self.input_scratch[channel * in_scratch_ch_len + frame_index] =
200 input[frame_index * self.channels + channel];
201 });
202 });
203
204 let (subchunks_to_process, output_scratch_offset) = (
205 self.chunk_size_input / (self.fft_size_input * self.channels),
206 self.saved_frames,
207 );
208
209 for channel in 0..self.channels {
211 let start = channel * in_scratch_ch_len;
212 let end = start + in_scratch_ch_len;
213 for (input_chunk, output_chunk) in self.input_scratch[start..end]
214 .chunks(self.fft_size_input)
215 .take(subchunks_to_process)
216 .zip(
217 self.output_scratch[channel * out_scratch_ch_len + output_scratch_offset..]
218 .chunks_mut(self.fft_size_output),
219 )
220 {
221 let start = self.fft_size_output * channel;
222 let end = start + self.fft_size_output;
223 self.fft_resampler.resample(
224 input_chunk,
225 output_chunk,
226 &mut self.overlaps[start..end],
227 );
228 }
229 }
230
231 (0..self.fft_size_output).for_each(|frame_index| {
233 (0..self.channels).for_each(|channel| {
234 output[frame_index * self.channels + channel] =
235 self.output_scratch[channel * out_scratch_ch_len + frame_index];
236 });
237 });
238
239 Ok(())
240 }
241}
242
/// Single-channel FFT resampling stage.
///
/// Forward-transforms one zero-padded input frame and multiplies it by the filter
/// spectrum. It then copies the retained bins into an output spectrum of a different size
/// and inverse-transforms it, combining the result with the previous tail (overlap-add).
struct FftResampler {
    /// Samples per input frame.
    fft_size_input: usize,
    /// Samples per output frame.
    fft_size_output: usize,
    /// Forward transform plan (reference-counted; may be shared via the cache).
    fft: Arc<RadixFFT<Forward>>,
    /// Inverse transform plan (reference-counted; may be shared via the cache).
    ifft: Arc<RadixFFT<Inverse>>,
    /// Scratch space for `fft`, sized by `fft.scratchpad_size()`.
    scratchpad_forward: Vec<Complex32>,
    /// Scratch space for `ifft`, sized by `ifft.scratchpad_size()`.
    scratchpad_inverse: Vec<Complex32>,
    /// Spectrum of the windowed-sinc filter (`fft_size_input + 1` bins).
    filter_spectrum: Arc<[Complex32]>,
    /// Spectrum of the current input frame (`fft_size_input + 1` bins).
    input_spectrum: Vec<Complex32>,
    /// Spectrum passed to the inverse transform (`fft_size_output + 1` bins).
    output_spectrum: Vec<Complex32>,
    /// Zero-padded time-domain input (`2 * fft_size_input` samples).
    input_buffer: Vec<f32>,
    /// Time-domain inverse-transform result (`2 * fft_size_output` samples).
    output_buffer: Vec<f32>,
}
260
261impl FftResampler {
262 pub(crate) fn new(
263 sample_rate_input: u32,
264 sample_rate_output: u32,
265 fft_size_input: usize,
266 factors_input: Vec<Radix>,
267 fft_size_output: usize,
268 factors_output: Vec<Radix>,
269 ) -> Self {
270 let cached = Self::get_or_create_fft_data(
271 sample_rate_input,
272 sample_rate_output,
273 fft_size_input,
274 factors_input,
275 fft_size_output,
276 factors_output,
277 );
278
279 let input_spectrum: Vec<Complex32> = vec![Complex32::zero(); fft_size_input + 1];
280 let input_buffer: Vec<f32> = vec![0.0; 2 * fft_size_input];
281 let output_spectrum: Vec<Complex32> = vec![Complex32::zero(); fft_size_output + 1];
282 let output_buffer: Vec<f32> = vec![0.0; 2 * fft_size_output];
283
284 let scratchpad_forward = vec![Complex32::zero(); cached.fft.scratchpad_size()];
285 let scratchpad_inverse = vec![Complex32::zero(); cached.ifft.scratchpad_size()];
286
287 FftResampler {
288 fft_size_input,
289 fft_size_output,
290 fft: cached.fft,
291 ifft: cached.ifft,
292 scratchpad_forward,
293 scratchpad_inverse,
294 filter_spectrum: cached.filter_spectrum,
295 input_spectrum,
296 output_spectrum,
297 input_buffer,
298 output_buffer,
299 }
300 }
301
302 #[cfg(not(feature = "no_std"))]
306 fn get_or_create_fft_data(
307 sample_rate_input: u32,
308 sample_rate_output: u32,
309 fft_size_input: usize,
310 factors_in: Vec<Radix>,
311 fft_size_output: usize,
312 factors_out: Vec<Radix>,
313 ) -> FftCacheData {
314 let cache_key = ((sample_rate_input as u64) << 32) | (sample_rate_output as u64);
315 FFT_CACHE
316 .lock()
317 .unwrap()
318 .entry(cache_key)
319 .or_insert_with(|| {
320 Self::create_fft_data(fft_size_input, factors_in, fft_size_output, factors_out)
321 })
322 .clone()
323 }
324
325 #[cfg(feature = "no_std")]
326 fn get_or_create_fft_data(
327 _sample_rate_input: u32,
328 _sample_rate_output: u32,
329 fft_size_input: usize,
330 factors_in: Vec<Radix>,
331 fft_size_output: usize,
332 factors_out: Vec<Radix>,
333 ) -> FftCacheData {
334 Self::create_fft_data(fft_size_input, factors_in, fft_size_output, factors_out)
335 }
336
337 fn create_fft_data(
339 fft_size_input: usize,
340 factors_in: Vec<Radix>,
341 fft_size_output: usize,
342 factors_out: Vec<Radix>,
343 ) -> FftCacheData {
344 let mut fft_factors_input = factors_in;
346 fft_factors_input.push(Radix::Factor2);
347 let mut fft_factors_output = factors_out;
348 fft_factors_output.push(Radix::Factor2);
349
350 let fft = RadixFFT::<Forward>::new(fft_factors_input);
351 let ifft = RadixFFT::<Inverse>::new(fft_factors_output);
352
353 let cutoff = match fft_size_input > fft_size_output {
354 true => {
355 let scale = fft_size_output as f64 / fft_size_input as f64;
356 calculate_cutoff_kaiser(fft_size_output, KAISER_BETA) * scale
357 }
358 false => calculate_cutoff_kaiser(fft_size_input, KAISER_BETA),
359 };
360
361 let sincs = make_sincs_for_kaiser(
362 fft_size_input,
363 1,
364 cutoff as f32,
365 KAISER_BETA,
366 WindowType::Periodic,
367 );
368 let mut filter_time = vec![0.0; 2 * fft_size_input];
369 let mut filter_spectrum = vec![Complex32::zero(); fft_size_input + 1];
370
371 for (index, filter_value) in filter_time.iter_mut().enumerate().take(fft_size_input) {
372 *filter_value = sincs[0][index] / (2 * fft_size_input) as f32;
373 }
374
375 let mut scratchpad = vec![Complex32::zero(); fft.scratchpad_size()];
376 fft.process(&filter_time, &mut filter_spectrum, &mut scratchpad);
377
378 FftCacheData {
379 filter_spectrum: filter_spectrum.into(),
380 fft: Arc::new(fft),
381 ifft: Arc::new(ifft),
382 }
383 }
384
385 fn resample(&mut self, wave_input: &[f32], wave_output: &mut [f32], overlap: &mut [f32]) {
386 self.input_buffer[..self.fft_size_input].copy_from_slice(wave_input);
388 self.input_buffer[self.fft_size_input..].fill(0.0);
389
390 self.fft.process(
391 &self.input_buffer,
392 &mut self.input_spectrum,
393 &mut self.scratchpad_forward,
394 );
395
396 let new_length = match self.fft_size_input < self.fft_size_output {
397 true => self.fft_size_input + 1,
398 false => self.fft_size_output,
399 };
400
401 self.input_spectrum
402 .iter_mut()
403 .take(new_length)
404 .zip(self.filter_spectrum.iter())
405 .for_each(|(spectrum, filter)| *spectrum = spectrum.mul(filter));
406
407 self.output_spectrum[0..new_length].copy_from_slice(&self.input_spectrum[0..new_length]);
408 self.output_spectrum[new_length..].fill(Complex32::zero());
409
410 self.ifft.process(
411 &self.output_spectrum,
412 &mut self.output_buffer,
413 &mut self.scratchpad_inverse,
414 );
415
416 for (index, item) in wave_output
417 .iter_mut()
418 .enumerate()
419 .take(self.fft_size_output)
420 {
421 *item = self.output_buffer[index] + overlap[index];
422 }
423 overlap.copy_from_slice(&self.output_buffer[self.fft_size_output..]);
424 }
425}
426
#[cfg(test)]
mod tests {
    use core::f32::consts::PI;

    use super::*;

    /// Absolute tolerance shared by every amplitude comparison below.
    const EPSILON: f32 = 0.02;

    /// How many times each test pushes the same input chunk through the resampler; only
    /// the output of the final call is inspected.
    const REPETITIONS: usize = 5;

    /// `true` when `a` and `b` differ by strictly less than `epsilon`.
    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        let difference = (a - b).abs();
        difference < epsilon
    }

    /// Runs `input` through `resampler` `REPETITIONS` times, leaving the last result in
    /// `output`. Errors are deliberately discarded; the assertions inspect `output`.
    fn feed_repeatedly(resampler: &mut ResamplerFft, input: &[f32], output: &mut [f32]) {
        for _ in 0..REPETITIONS {
            let _ = resampler.resample(input, output);
        }
    }

    /// Mono inspection window: starts at the delay (clamped to a quarter of the buffer)
    /// and ends at three quarters of the buffer.
    fn mono_window(resampler: &ResamplerFft, output_len: usize) -> (usize, usize) {
        (resampler.delay().min(output_len / 4), output_len * 3 / 4)
    }

    #[test]
    fn test_dc_signal_amplitude_preservation() {
        let conversions = vec![
            (SampleRate::Hz48000, SampleRate::Hz44100, "48kHz -> 44.1kHz"),
            (SampleRate::Hz44100, SampleRate::Hz48000, "44.1kHz -> 48kHz"),
            (SampleRate::Hz48000, SampleRate::Hz32000, "48kHz -> 32kHz"),
            (SampleRate::Hz32000, SampleRate::Hz48000, "32kHz -> 48kHz"),
            (SampleRate::Hz96000, SampleRate::Hz48000, "96kHz -> 48kHz"),
            (SampleRate::Hz48000, SampleRate::Hz96000, "48kHz -> 96kHz"),
        ];

        for (input_rate, output_rate, desc) in conversions {
            let mut resampler = ResamplerFft::new(1, input_rate, output_rate);

            let dc_amplitude = 0.5f32;
            let input = vec![dc_amplitude; resampler.chunk_size_input()];
            let mut output = vec![0.0f32; resampler.chunk_size_output()];

            feed_repeatedly(&mut resampler, &input, &mut output);

            let (check_start, check_end) = mono_window(&resampler, output.len());
            for (index, &sample) in (check_start..).zip(&output[check_start..check_end]) {
                assert!(
                    approx_eq(sample, dc_amplitude, EPSILON),
                    "{desc}: DC amplitude not preserved at sample {}: expected {dc_amplitude}, got {sample} (error: {:.2}%)",
                    index,
                    ((sample - dc_amplitude) / dc_amplitude * 100.0).abs()
                );
            }
        }
    }

    #[test]
    fn test_sine_wave_amplitude_preservation() {
        let conversions = vec![
            (SampleRate::Hz48000, SampleRate::Hz44100, "48kHz -> 44.1kHz"),
            (SampleRate::Hz44100, SampleRate::Hz48000, "44.1kHz -> 48kHz"),
            (SampleRate::Hz48000, SampleRate::Hz32000, "48kHz -> 32kHz"),
        ];

        for (input_rate, output_rate, desc) in conversions {
            let mut resampler = ResamplerFft::new(1, input_rate, output_rate);

            let amplitude = 0.5f32;
            let frequency = 1000.0f32;
            let input_rate_hz = u32::from(input_rate) as f32;
            let chunk_size = resampler.chunk_size_input();

            // Accumulate the phase sample by sample (same float sequence as a running sum).
            let phase_increment = 2.0 * PI * frequency / input_rate_hz;
            let mut input = Vec::with_capacity(chunk_size);
            let mut phase = 0.0f32;
            for _ in 0..chunk_size {
                input.push(amplitude * phase.sin());
                phase += phase_increment;
            }

            let mut output = vec![0.0f32; resampler.chunk_size_output()];
            feed_repeatedly(&mut resampler, &input, &mut output);

            let (check_start, check_end) = mono_window(&resampler, output.len());
            let peak = output[check_start..check_end]
                .iter()
                .fold(0.0f32, |acc, &x| acc.max(x.abs()));

            assert!(
                approx_eq(peak, amplitude, EPSILON),
                "{desc}: Sine wave amplitude not preserved: expected {amplitude}, got {peak} (error: {:.2}%)",
                ((peak - amplitude) / amplitude * 100.0).abs()
            );
        }
    }

    #[test]
    fn test_stereo_dc_amplitude_preservation() {
        let mut resampler = ResamplerFft::new(2, SampleRate::Hz48000, SampleRate::Hz44100);

        let dc_amplitude_left = 0.3f32;
        let dc_amplitude_right = 0.6f32;
        let chunk_size = resampler.chunk_size_input();

        // Interleave a constant left/right pair into every stereo frame.
        let mut input = vec![0.0f32; chunk_size];
        for frame in input.chunks_exact_mut(2) {
            frame[0] = dc_amplitude_left;
            frame[1] = dc_amplitude_right;
        }

        let mut output = vec![0.0f32; resampler.chunk_size_output()];
        feed_repeatedly(&mut resampler, &input, &mut output);

        // Clamp the delay to an eighth of the buffer, then double it to step over whole
        // interleaved stereo frames.
        let check_start = resampler.delay().min(output.len() / 8) * 2;
        let check_end = output.len() * 3 / 4;

        for position in (check_start..check_end).step_by(2) {
            let (left_sample, right_sample) = (output[position], output[position + 1]);

            assert!(
                approx_eq(left_sample, dc_amplitude_left, EPSILON),
                "Stereo left channel DC not preserved at frame {}: expected {dc_amplitude_left}, got {left_sample}",
                position / 2
            );

            assert!(
                approx_eq(right_sample, dc_amplitude_right, EPSILON),
                "Stereo right channel DC not preserved at frame {}: expected {dc_amplitude_right}, got {right_sample}",
                position / 2
            );
        }
    }
}