resampler/
resampler_fft.rs

1use alloc::{sync::Arc, vec, vec::Vec};
2use core::fmt;
3#[cfg(not(feature = "no_std"))]
4use std::{
5    collections::HashMap,
6    sync::{LazyLock, Mutex},
7};
8
9use crate::{
10    Complex32, Forward, Inverse, Radix, RadixFFT, SampleRate,
11    error::ResampleError,
12    fft::planner::ConversionConfig,
13    window::{WindowType, calculate_cutoff_kaiser, make_sincs_for_kaiser},
14};
15
/// Kaiser window shape parameter (beta) used when designing the resampling filter.
///
/// The same value is passed to both `calculate_cutoff_kaiser` and `make_sincs_for_kaiser`
/// so that the cutoff and the windowed sinc are computed for the same window.
const KAISER_BETA: f64 = 10.0;
17
18pub(crate) struct FftCacheData {
19    filter_spectrum: Arc<[Complex32]>,
20    fft: Arc<RadixFFT<Forward>>,
21    ifft: Arc<RadixFFT<Inverse>>,
22}
23
24impl Clone for FftCacheData {
25    fn clone(&self) -> Self {
26        Self {
27            filter_spectrum: Arc::clone(&self.filter_spectrum),
28            fft: Arc::clone(&self.fft),
29            ifft: Arc::clone(&self.ifft),
30        }
31    }
32}
33
/// Process-wide cache of FFT data shared between resampler instances.
///
/// Entries are keyed by a `u64` that packs the input sample rate into the upper 32 bits and
/// the output sample rate into the lower 32 bits (see `FftResampler::get_or_create_fft_data`).
/// The cache only exists when the `no_std` feature is disabled.
#[cfg(not(feature = "no_std"))]
static FFT_CACHE: LazyLock<Mutex<HashMap<u64, FftCacheData>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
37
/// High-quality and high-performance FFT-based audio resampler supporting multi-channel audio.
///
/// `ResamplerFft` uses the overlap-add FFT method with Kaiser windowing to convert audio
/// between different sample rates. The number of audio channels (e.g., 1 for mono, 2 for
/// stereo) is chosen at construction time and stored in the `channels` field.
pub struct ResamplerFft {
    /// Number of interleaved audio channels.
    channels: usize,
    /// Single-channel overlap-add engine, reused for every channel in turn.
    fft_resampler: FftResampler,
    /// Interleaved input values consumed per call (`fft_size_input * channels`).
    chunk_size_input: usize,
    /// Interleaved output values produced per call (`fft_size_output * channels`).
    chunk_size_output: usize,
    /// Input frames per channel processed by one FFT block.
    fft_size_input: usize,
    /// Output frames per channel produced by one FFT block.
    fft_size_output: usize,
    /// Start offset into each channel's output scratch region; initialised to 0 in `new`.
    saved_frames: usize,
    /// Overlap-add tails carried between calls, `fft_size_output` values per channel.
    overlaps: Vec<f32>,
    /// Deinterleaved (per-channel) input staging buffer.
    input_scratch: Vec<f32>,
    /// Deinterleaved (per-channel) output staging buffer.
    output_scratch: Vec<f32>,
}
55
56impl fmt::Debug for ResamplerFft {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        f.debug_struct("ResamplerFft")
59            .field("channels", &self.channels)
60            .field("chunk_size_input", &self.chunk_size_input)
61            .field("chunk_size_output", &self.chunk_size_output)
62            .field("fft_size_input", &self.fft_size_input)
63            .field("fft_size_output", &self.fft_size_output)
64            .finish_non_exhaustive()
65    }
66}
67
68impl ResamplerFft {
69    /// Create a new [`ResamplerFft`].
70    ///
71    /// Parameters are:
72    /// - `channels`: The channel count.
73    /// - `sample_rate_input`: Input sample rate.
74    /// - `sample_rate_output`: Output sample rate.
75    pub fn new(
76        channels: usize,
77        sample_rate_input: SampleRate,
78        sample_rate_output: SampleRate,
79    ) -> Self {
80        // Get the optimized FFT sizes and factors directly from the conversion table.
81        // These sizes are carefully chosen for efficient factorization and minimal latency.
82        let config = ConversionConfig::from_sample_rates(sample_rate_input, sample_rate_output);
83        let (fft_size_input, factors_in, fft_size_output, factors_out) =
84            config.scale_for_throughput();
85
86        let overlaps: Vec<f32> = vec![0.0; fft_size_output * channels];
87
88        let chunk_size_input = fft_size_input * channels;
89        let chunk_size_output = fft_size_output * channels;
90
91        let needed_input_buffer_size = chunk_size_input + fft_size_input;
92        let needed_buffer_size_output = chunk_size_output + fft_size_output;
93        let input_scratch: Vec<f32> = vec![0.0; needed_input_buffer_size * channels];
94        let output_scratch: Vec<f32> = vec![0.0; needed_buffer_size_output * channels];
95
96        let saved_frames = 0;
97
98        let fft_resampler = FftResampler::new(
99            u32::from(sample_rate_input),
100            u32::from(sample_rate_output),
101            fft_size_input,
102            factors_in,
103            fft_size_output,
104            factors_out,
105        );
106
107        ResamplerFft {
108            channels,
109            chunk_size_input,
110            chunk_size_output,
111            fft_size_input,
112            fft_size_output,
113            overlaps,
114            input_scratch,
115            output_scratch,
116            saved_frames,
117            fft_resampler,
118        }
119    }
120
121    /// Returns the size used to store input scratch for 1 channel
122    fn input_scratch_ch_size(&self) -> usize {
123        self.chunk_size_input + self.fft_size_input
124    }
125
126    /// Returns the size used to store output scratch for 1 channel
127    fn output_scratch_ch_size(&self) -> usize {
128        self.chunk_size_input + self.fft_size_input
129    }
130
131    /// Returns the required input buffer size in total f32 values (including all channels).
132    ///
133    /// For example, with a stereo resampler (CHANNEL=2), this returns the total number
134    /// of f32 values needed in the interleaved input buffer [L0, R0, L1, R1, ...].
135    pub fn chunk_size_input(&self) -> usize {
136        self.chunk_size_input
137    }
138
139    /// Returns the required output buffer size in total f32 values (including all channels).
140    ///
141    /// For example, with a stereo resampler (CHANNEL=2), this returns the total number
142    /// of f32 values needed in the interleaved output buffer [L0, R0, L1, R1, ...].
143    pub fn chunk_size_output(&self) -> usize {
144        self.chunk_size_output
145    }
146
147    /// Returns the algorithmic delay (latency) of the resampler in input samples.
148    ///
149    /// This delay is inherent to the FFT-based overlap-add process and equals
150    /// half the FFT input size due to the windowing operation.
151    pub fn delay(&self) -> usize {
152        self.fft_size_input / 2
153    }
154
155    /// Processes one chunk of audio, resampling from input to output sample rate.
156    ///
157    /// Input and output must be interleaved f32 slices with all channels interleaved.
158    /// For stereo audio, the format is `[L0, R0, L1, R1, ...]`. For mono, it's `[S0, S1, S2, ...]`.
159    ///
160    /// ## Parameters
161    ///
162    /// - `input`: Interleaved input samples. Must contain at least [`chunk_size_input()`](Self::chunk_size_input) values.
163    /// - `output`: Interleaved output buffer. Must have capacity for at least [`chunk_size_output()`](Self::chunk_size_output) values.
164    ///
165    /// ## Example
166    ///
167    /// ```rust
168    /// use resampler::{ResamplerFft, SampleRate};
169    ///
170    /// let mut resampler = ResamplerFft::new(1, SampleRate::Hz48000, SampleRate::Hz44100);
171    ///
172    /// let input = vec![0.0f32; resampler.chunk_size_input()];
173    /// let mut output = vec![0.0f32; resampler.chunk_size_output()];
174    ///
175    /// match resampler.resample(&input, &mut output) {
176    ///     Ok(()) => {
177    ///         println!("Resample successfully");
178    ///     }
179    ///     Err(error) => eprintln!("Resampling error: {error:?}"),
180    /// }
181    /// ```
182    pub fn resample(&mut self, input: &[f32], output: &mut [f32]) -> Result<(), ResampleError> {
183        let expected_input_len = self.chunk_size_input;
184        let min_output_len = self.chunk_size_output;
185
186        if input.len() < expected_input_len {
187            return Err(ResampleError::InvalidInputBufferSize);
188        }
189
190        if output.len() < min_output_len {
191            return Err(ResampleError::InvalidOutputBufferSize);
192        }
193
194        let in_scratch_ch_len = self.input_scratch_ch_size();
195        let out_scratch_ch_len = self.output_scratch_ch_size();
196        // Deinterleave input into per-channel scratch buffers.
197        (0..self.fft_size_input).for_each(|frame_index| {
198            (0..self.channels).for_each(|channel| {
199                self.input_scratch[channel * in_scratch_ch_len + frame_index] =
200                    input[frame_index * self.channels + channel];
201            });
202        });
203
204        let (subchunks_to_process, output_scratch_offset) = (
205            self.chunk_size_input / (self.fft_size_input * self.channels),
206            self.saved_frames,
207        );
208
209        // Resample between input and output scratch buffers.
210        for channel in 0..self.channels {
211            let start = channel * in_scratch_ch_len;
212            let end = start + in_scratch_ch_len;
213            for (input_chunk, output_chunk) in self.input_scratch[start..end]
214                .chunks(self.fft_size_input)
215                .take(subchunks_to_process)
216                .zip(
217                    self.output_scratch[channel * out_scratch_ch_len + output_scratch_offset..]
218                        .chunks_mut(self.fft_size_output),
219                )
220            {
221                let start = self.fft_size_output * channel;
222                let end = start + self.fft_size_output;
223                self.fft_resampler.resample(
224                    input_chunk,
225                    output_chunk,
226                    &mut self.overlaps[start..end],
227                );
228            }
229        }
230
231        // Deinterleave output from per-channel scratch buffers.
232        (0..self.fft_size_output).for_each(|frame_index| {
233            (0..self.channels).for_each(|channel| {
234                output[frame_index * self.channels + channel] =
235                    self.output_scratch[channel * out_scratch_ch_len + frame_index];
236            });
237        });
238
239        Ok(())
240    }
241}
242
/// FFT-based resampler using overlap-add reconstruction.
///
/// The overlap-add resampling approach is based on the Rubato crate:
/// <https://github.com/HEnquist/rubato>
struct FftResampler {
    /// Number of input samples consumed per call to `resample`.
    fft_size_input: usize,
    /// Number of output samples produced per call to `resample`.
    fft_size_output: usize,
    /// Forward transform applied to the zero-padded input buffer.
    fft: Arc<RadixFFT<Forward>>,
    /// Inverse transform producing the zero-padded output buffer.
    ifft: Arc<RadixFFT<Inverse>>,
    /// Work area for the forward transform, sized by `fft.scratchpad_size()`.
    scratchpad_forward: Vec<Complex32>,
    /// Work area for the inverse transform, sized by `ifft.scratchpad_size()`.
    scratchpad_inverse: Vec<Complex32>,
    /// Precomputed spectrum of the windowed-sinc filter.
    filter_spectrum: Arc<[Complex32]>,
    /// Spectrum of the current input block (`fft_size_input + 1` bins).
    input_spectrum: Vec<Complex32>,
    /// Truncated or zero-extended spectrum fed to the inverse transform
    /// (`fft_size_output + 1` bins).
    output_spectrum: Vec<Complex32>,
    /// Time-domain input block followed by zero padding (`2 * fft_size_input` samples).
    input_buffer: Vec<f32>,
    /// Time-domain result of the inverse transform (`2 * fft_size_output` samples).
    output_buffer: Vec<f32>,
}
260
261impl FftResampler {
262    pub(crate) fn new(
263        sample_rate_input: u32,
264        sample_rate_output: u32,
265        fft_size_input: usize,
266        factors_input: Vec<Radix>,
267        fft_size_output: usize,
268        factors_output: Vec<Radix>,
269    ) -> Self {
270        let cached = Self::get_or_create_fft_data(
271            sample_rate_input,
272            sample_rate_output,
273            fft_size_input,
274            factors_input,
275            fft_size_output,
276            factors_output,
277        );
278
279        let input_spectrum: Vec<Complex32> = vec![Complex32::zero(); fft_size_input + 1];
280        let input_buffer: Vec<f32> = vec![0.0; 2 * fft_size_input];
281        let output_spectrum: Vec<Complex32> = vec![Complex32::zero(); fft_size_output + 1];
282        let output_buffer: Vec<f32> = vec![0.0; 2 * fft_size_output];
283
284        let scratchpad_forward = vec![Complex32::zero(); cached.fft.scratchpad_size()];
285        let scratchpad_inverse = vec![Complex32::zero(); cached.ifft.scratchpad_size()];
286
287        FftResampler {
288            fft_size_input,
289            fft_size_output,
290            fft: cached.fft,
291            ifft: cached.ifft,
292            scratchpad_forward,
293            scratchpad_inverse,
294            filter_spectrum: cached.filter_spectrum,
295            input_spectrum,
296            output_spectrum,
297            input_buffer,
298            output_buffer,
299        }
300    }
301
302    /// Retrieves or creates FFT data. By default, this uses a global cache to share FFT
303    /// objects across multiple Resampler instances. With the "no_std" feature, it creates
304    /// new FFT objects each time.
305    #[cfg(not(feature = "no_std"))]
306    fn get_or_create_fft_data(
307        sample_rate_input: u32,
308        sample_rate_output: u32,
309        fft_size_input: usize,
310        factors_in: Vec<Radix>,
311        fft_size_output: usize,
312        factors_out: Vec<Radix>,
313    ) -> FftCacheData {
314        let cache_key = ((sample_rate_input as u64) << 32) | (sample_rate_output as u64);
315        FFT_CACHE
316            .lock()
317            .unwrap()
318            .entry(cache_key)
319            .or_insert_with(|| {
320                Self::create_fft_data(fft_size_input, factors_in, fft_size_output, factors_out)
321            })
322            .clone()
323    }
324
325    #[cfg(feature = "no_std")]
326    fn get_or_create_fft_data(
327        _sample_rate_input: u32,
328        _sample_rate_output: u32,
329        fft_size_input: usize,
330        factors_in: Vec<Radix>,
331        fft_size_output: usize,
332        factors_out: Vec<Radix>,
333    ) -> FftCacheData {
334        Self::create_fft_data(fft_size_input, factors_in, fft_size_output, factors_out)
335    }
336
337    /// Creates FFT objects and filter spectrum. This is the no-std compatible core logic.
338    fn create_fft_data(
339        fft_size_input: usize,
340        factors_in: Vec<Radix>,
341        fft_size_output: usize,
342        factors_out: Vec<Radix>,
343    ) -> FftCacheData {
344        // Scale factors for the 2x windowing multiplier.
345        let mut fft_factors_input = factors_in;
346        fft_factors_input.push(Radix::Factor2);
347        let mut fft_factors_output = factors_out;
348        fft_factors_output.push(Radix::Factor2);
349
350        let fft = RadixFFT::<Forward>::new(fft_factors_input);
351        let ifft = RadixFFT::<Inverse>::new(fft_factors_output);
352
353        let cutoff = match fft_size_input > fft_size_output {
354            true => {
355                let scale = fft_size_output as f64 / fft_size_input as f64;
356                calculate_cutoff_kaiser(fft_size_output, KAISER_BETA) * scale
357            }
358            false => calculate_cutoff_kaiser(fft_size_input, KAISER_BETA),
359        };
360
361        let sincs = make_sincs_for_kaiser(
362            fft_size_input,
363            1,
364            cutoff as f32,
365            KAISER_BETA,
366            WindowType::Periodic,
367        );
368        let mut filter_time = vec![0.0; 2 * fft_size_input];
369        let mut filter_spectrum = vec![Complex32::zero(); fft_size_input + 1];
370
371        for (index, filter_value) in filter_time.iter_mut().enumerate().take(fft_size_input) {
372            *filter_value = sincs[0][index] / (2 * fft_size_input) as f32;
373        }
374
375        let mut scratchpad = vec![Complex32::zero(); fft.scratchpad_size()];
376        fft.process(&filter_time, &mut filter_spectrum, &mut scratchpad);
377
378        FftCacheData {
379            filter_spectrum: filter_spectrum.into(),
380            fft: Arc::new(fft),
381            ifft: Arc::new(ifft),
382        }
383    }
384
385    fn resample(&mut self, wave_input: &[f32], wave_output: &mut [f32], overlap: &mut [f32]) {
386        // Copy input and clear padding.
387        self.input_buffer[..self.fft_size_input].copy_from_slice(wave_input);
388        self.input_buffer[self.fft_size_input..].fill(0.0);
389
390        self.fft.process(
391            &self.input_buffer,
392            &mut self.input_spectrum,
393            &mut self.scratchpad_forward,
394        );
395
396        let new_length = match self.fft_size_input < self.fft_size_output {
397            true => self.fft_size_input + 1,
398            false => self.fft_size_output,
399        };
400
401        self.input_spectrum
402            .iter_mut()
403            .take(new_length)
404            .zip(self.filter_spectrum.iter())
405            .for_each(|(spectrum, filter)| *spectrum = spectrum.mul(filter));
406
407        self.output_spectrum[0..new_length].copy_from_slice(&self.input_spectrum[0..new_length]);
408        self.output_spectrum[new_length..].fill(Complex32::zero());
409
410        self.ifft.process(
411            &self.output_spectrum,
412            &mut self.output_buffer,
413            &mut self.scratchpad_inverse,
414        );
415
416        for (index, item) in wave_output
417            .iter_mut()
418            .enumerate()
419            .take(self.fft_size_output)
420        {
421            *item = self.output_buffer[index] + overlap[index];
422        }
423        overlap.copy_from_slice(&self.output_buffer[self.fft_size_output..]);
424    }
425}
426
#[cfg(test)]
mod tests {
    use core::f32::consts::PI;

    use super::*;

    /// Maximum absolute deviation tolerated between a measured and an expected amplitude.
    const EPSILON: f32 = 0.02;

    /// How many times the same chunk is pushed through the resampler before the output of
    /// the last pass is inspected.
    const PASSES: usize = 5;

    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        (a - b).abs() < epsilon
    }

    /// Feeds `input` through `resampler` [`PASSES`] times, leaving the last result in
    /// `output`. Any resampling error fails the test immediately instead of being ignored.
    fn resample_repeatedly(
        resampler: &mut ResamplerFft,
        input: &[f32],
        output: &mut [f32],
        desc: &str,
    ) {
        for pass in 0..PASSES {
            if let Err(error) = resampler.resample(input, output) {
                panic!("{desc}: resample failed on pass {pass}: {error:?}");
            }
        }
    }

    #[test]
    fn test_dc_signal_amplitude_preservation() {
        let test_cases = [
            (SampleRate::Hz48000, SampleRate::Hz44100, "48kHz -> 44.1kHz"),
            (SampleRate::Hz44100, SampleRate::Hz48000, "44.1kHz -> 48kHz"),
            (SampleRate::Hz48000, SampleRate::Hz32000, "48kHz -> 32kHz"),
            (SampleRate::Hz32000, SampleRate::Hz48000, "32kHz -> 48kHz"),
            (SampleRate::Hz96000, SampleRate::Hz48000, "96kHz -> 48kHz"),
            (SampleRate::Hz48000, SampleRate::Hz96000, "48kHz -> 96kHz"),
        ];

        for (input_rate, output_rate, desc) in test_cases {
            let mut resampler = ResamplerFft::new(1, input_rate, output_rate);

            let dc_amplitude = 0.5f32;
            let input = vec![dc_amplitude; resampler.chunk_size_input()];
            let mut output = vec![0.0f32; resampler.chunk_size_output()];

            resample_repeatedly(&mut resampler, &input, &mut output, desc);

            // Only inspect the middle of the chunk, away from the block edges.
            let delay = resampler.delay();
            let check_start = delay.min(output.len() / 4);
            let check_end = output.len() * 3 / 4;

            for (i, &sample) in output[check_start..check_end].iter().enumerate() {
                assert!(
                    approx_eq(sample, dc_amplitude, EPSILON),
                    "{desc}: DC amplitude not preserved at sample {}: expected {dc_amplitude}, got {sample} (error: {:.2}%)",
                    i + check_start,
                    ((sample - dc_amplitude) / dc_amplitude * 100.0).abs()
                );
            }
        }
    }

    #[test]
    fn test_sine_wave_amplitude_preservation() {
        let test_cases = [
            (SampleRate::Hz48000, SampleRate::Hz44100, "48kHz -> 44.1kHz"),
            (SampleRate::Hz44100, SampleRate::Hz48000, "44.1kHz -> 48kHz"),
            (SampleRate::Hz48000, SampleRate::Hz32000, "48kHz -> 32kHz"),
        ];

        for (input_rate, output_rate, desc) in test_cases {
            let mut resampler = ResamplerFft::new(1, input_rate, output_rate);

            let amplitude = 0.5f32;
            let frequency = 1000.0f32;
            let input_rate_hz = u32::from(input_rate) as f32;

            let chunk_size = resampler.chunk_size_input();

            // Generate one chunk of a 1 kHz sine by accumulating the phase per sample.
            let mut phase = 0.0f32;
            let phase_increment = 2.0 * PI * frequency / input_rate_hz;
            let input: Vec<f32> = (0..chunk_size)
                .map(|_| {
                    let sample = amplitude * phase.sin();
                    phase += phase_increment;
                    sample
                })
                .collect();

            let mut output = vec![0.0f32; resampler.chunk_size_output()];

            resample_repeatedly(&mut resampler, &input, &mut output, desc);

            let delay = resampler.delay();
            let check_start = delay.min(output.len() / 4);
            let check_end = output.len() * 3 / 4;

            let peak = output[check_start..check_end]
                .iter()
                .map(|&x| x.abs())
                .fold(0.0f32, f32::max);

            assert!(
                approx_eq(peak, amplitude, EPSILON),
                "{desc}: Sine wave amplitude not preserved: expected {amplitude}, got {peak} (error: {:.2}%)",
                ((peak - amplitude) / amplitude * 100.0).abs()
            );
        }
    }

    #[test]
    fn test_stereo_dc_amplitude_preservation() {
        let mut resampler = ResamplerFft::new(2, SampleRate::Hz48000, SampleRate::Hz44100);

        let dc_amplitude_left = 0.3f32;
        let dc_amplitude_right = 0.6f32;
        let chunk_size = resampler.chunk_size_input();

        // Interleaved stereo input: a different constant level on each channel.
        let mut input = vec![0.0f32; chunk_size];
        for frame in input.chunks_exact_mut(2) {
            frame[0] = dc_amplitude_left;
            frame[1] = dc_amplitude_right;
        }

        let mut output = vec![0.0f32; resampler.chunk_size_output()];

        resample_repeatedly(&mut resampler, &input, &mut output, "stereo 48kHz -> 44.1kHz");

        // `check_start` is kept even so that index `i` is always a left-channel sample.
        let delay = resampler.delay();
        let check_start = delay.min(output.len() / 8) * 2;
        let check_end = output.len() * 3 / 4;

        for i in (check_start..check_end).step_by(2) {
            let left_sample = output[i];
            let right_sample = output[i + 1];

            assert!(
                approx_eq(left_sample, dc_amplitude_left, EPSILON),
                "Stereo left channel DC not preserved at frame {}: expected {dc_amplitude_left}, got {left_sample}",
                i / 2
            );

            assert!(
                approx_eq(right_sample, dc_amplitude_right, EPSILON),
                "Stereo right channel DC not preserved at frame {}: expected {dc_amplitude_right}, got {right_sample}",
                i / 2
            );
        }
    }
}
567}