resampler/
resampler.rs

1
2use std::{cmp::min, sync::Arc, fmt::{self, Debug, Formatter}};
3use rustfft::{FftPlanner, Fft, num_complex::Complex};
4
/// Errors reported by the resampler.
#[derive(Debug, Clone)]
pub enum ResamplerError {
    /// A buffer length is incompatible with the configured FFT size.
    /// The payload is a human-readable description of the offending sizes.
    SizeError(String),
}

impl fmt::Display for ResamplerError {
    /// Writes the human-readable message carried by the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::SizeError(message) => f.write_str(message),
        }
    }
}

// Implementing `Error` lets callers use `?` with `Box<dyn Error>` and error-handling crates.
impl std::error::Error for ResamplerError {}
9
/// FFT-based audio resampler.
///
/// # How the resampler works
///
/// For audio stretching (the output is longer than the input):
///   1. The input audio keeps its original length and is zero-padded at the end up to the FFT size.
///   2. A forward FFT is performed to obtain the frequency domain.
///   3. In the frequency domain, the frequency values are scaled down proportionally (shifted lower).
///   4. An inverse FFT is performed to obtain the stretched audio.
///
/// For audio compression (the output is shorter than the input):
///   1. Take the input audio.
///   2. Perform the forward FFT.
///   3. In the frequency domain, scale the frequency values up proportionally (shift them higher).
///   4. Perform the inverse FFT to obtain audio with increased pitch but unchanged length.
///   5. Truncate the audio to shorten its duration.
///
/// This implies that the FFT length must be at least as long as the longest buffer involved.
#[derive(Clone)]
pub struct Resampler {
    /// Forward FFT plan for `fft_size` points.
    fft_forward: Arc<dyn Fft<f64>>,
    /// Inverse FFT plan for `fft_size` points.
    fft_inverse: Arc<dyn Fft<f64>>,
    /// Number of points processed by both FFT plans.
    fft_size: usize,
    /// `1.0 / fft_size`; multiplied into every output sample after the inverse FFT.
    normalize_scaler: f64,
}
32
33fn get_average(complex: &[Complex<f64>]) -> Complex<f64> {
34    let sum: Complex<f64> = complex.iter().copied().sum();
35    let scaler = 1.0 / complex.len() as f64;
36    Complex::<f64> {
37        re: sum.re * scaler,
38        im: sum.im * scaler,
39    }
40}
41
42fn interpolate(c1: Complex<f64>, c2: Complex<f64>, s: f64) -> Complex<f64> {
43    c1 + (c2 - c1) * s
44}
45
46impl Resampler {
47    pub fn new(fft_size: usize) -> Self {
48        let mut planner = FftPlanner::new();
49        if fft_size & 1 != 0 {
50            panic!("The input size and the output size must be times of 2, got {fft_size}");
51        }
52        Self {
53            fft_forward: planner.plan_fft_forward(fft_size),
54            fft_inverse: planner.plan_fft_inverse(fft_size),
55            fft_size,
56            normalize_scaler: 1.0 / fft_size as f64,
57        }
58    }
59
60    /// * The fft size can be any number greater than the sample rate of the encoder or the decoder.
61    /// * It is for the resampler. A greater number results in better resample quality, but the process could be slower.
62    /// * In most cases, the audio sampling rate is about `11025` to `48000`, so `65536` is the best number for the resampler.
63    pub fn get_rounded_up_fft_size(sample_rate: u32) -> usize {
64        for i in 0..31 {
65            let fft_size = 1usize << i;
66            if fft_size >= sample_rate as usize {
67                return fft_size;
68            }
69        }
70        0x1_00000000_usize
71    }
72
73    /// Turn real numbers into complex numbers with conj
74    pub fn real_to_complex(samples: &[f32]) -> Vec<Complex<f64>> {
75        let n = samples.len();
76        let half = n / 2;
77        let back = n - 1;
78        let mut ret = vec![Complex::default(); n];
79        for i in 0..half {
80            ret[i] = Complex::new(samples[i * 2] as f64, samples[i * 2 + 1] as f64);
81            ret[back - i] = ret[i].conj();
82        }
83        if n & 1 == 1 {
84            ret[half] = Complex::new(samples[back] as f64, 0.0);
85        }
86        ret
87    }
88
89    /// Turn comples numbers into real numbers
90    pub fn complex_to_real(complex: &[Complex<f64>]) -> Vec<f64> {
91        let n = complex.len();
92        let half = n / 2;
93        let back = n - 1;
94        let mut ret = vec![0.0; n];
95        for i in 0..half {
96            ret[i * 2] = complex[i].re;
97            ret[i * 2 + 1] = complex[i].im;
98        }
99        if n & 1 == 1 {
100            ret[back] = complex[half].re;
101        }
102        ret
103    }
104
105    /// `desired_length`: The target audio length to achieve, which must not exceed the FFT size.
106    /// When samples.len() < desired_length, it indicates audio stretching to desired_length.
107    /// When samples.len() > desired_length, it indicates audio compression to desired_length.
108    pub fn resample_core(&self, samples: &[f32], desired_length: usize) -> Result<Vec<f32>, ResamplerError> {
109        const INTERPOLATE_UPSCALE: bool = true;
110        const INTERPOLATE_DNSCALE: bool = true;
111
112        let input_size = samples.len();
113        if input_size == desired_length {
114            return Ok(samples.to_vec());
115        }
116
117        if desired_length > self.fft_size {
118            return Err(ResamplerError::SizeError(format!("The desired size {desired_length} must not exceed the FFT size {}", self.fft_size)));
119        }
120
121        let mut fftbuf: Vec<Complex<f64>> = Self::real_to_complex(samples);
122
123        if fftbuf.len() <= self.fft_size {
124            fftbuf.resize(self.fft_size, Complex{re: 0.0, im: 0.0});
125        } else {
126            return Err(ResamplerError::SizeError(format!("The input size {} must not exceed the FFT size {}", fftbuf.len(), self.fft_size)));
127        }
128
129        // 进行 FFT 正向变换
130        self.fft_forward.process(&mut fftbuf);
131
132        // 准备进行插值
133        let mut fftdst = vec![Complex::<f64>{re: 0.0, im: 0.0}; self.fft_size];
134
135        let half = self.fft_size / 2;
136        let back = self.fft_size - 1;
137        let scaling = desired_length as f64 / input_size as f64;
138        if input_size > desired_length {
139            // Input size exceeds output size, indicating audio compression.
140            // This implies stretching in the frequency domain (scaling up).
141            for i in 0..half {
142                let scaled = i as f64 * scaling;
143                let i1 = scaled.trunc() as usize;
144                let i2 = i1 + 1;
145                let s = scaled.fract();
146                if INTERPOLATE_DNSCALE {
147                    fftdst[i] = interpolate(fftbuf[i1], fftbuf[i2], s);
148                    fftdst[back - i] = interpolate(fftbuf[back - i1], fftbuf[back - i2], s);
149                } else {
150                    fftdst[i] = fftbuf[i1];
151                    fftdst[back - i] = fftbuf[back - i1];
152                }
153            }
154        } else {
155            // Input size is smaller than the output size, indicating audio stretching.
156            // This implies compression in the frequency domain (scaling down).
157            for i in 0..half {
158                let i1 = (i as f64 * scaling).trunc() as usize;
159                let i2 = ((i + 1) as f64 * scaling).trunc() as usize;
160                if i2 >= half {break;}
161                let j1 = back - i2;
162                let j2 = back - i1;
163                if INTERPOLATE_UPSCALE {
164                    fftdst[i] = get_average(&fftbuf[i1..i2]);
165                    fftdst[back - i] = get_average(&fftbuf[j1..j2]);
166                } else {
167                    fftdst[i] = fftbuf[i1];
168                    fftdst[back - i] = fftbuf[back - i1];
169                }
170            }
171        }
172
173        self.fft_inverse.process(&mut fftdst);
174
175        let mut real_ret = Self::complex_to_real(&fftdst);
176
177        // Truncate at the waveform output stage.
178        real_ret.truncate(desired_length);
179
180        Ok(real_ret.into_iter().map(|r|(r * self.normalize_scaler) as f32).collect())
181    }
182
183    /// The processing unit size should be adjusted to work in "chunks per second", 
184    /// and artifacts will vanish when the chunk count aligns with the maximum infrasonic frequency.
185    /// Calling `self.get_desired_length()` determines the processed chunk size calculated based on the target sample rate.
186    pub fn get_process_size(&self, orig_size: usize, src_sample_rate: u32, dst_sample_rate: u32) -> usize {
187        const MAX_INFRASOUND_FREQ: usize = 20;
188        if src_sample_rate == dst_sample_rate {
189            min(self.fft_size, orig_size)
190        } else {
191            min(self.fft_size, src_sample_rate as usize / MAX_INFRASOUND_FREQ)
192        }
193    }
194
195    /// Get the processed chunk size calculated based on the target sample rate.
196    pub fn get_desired_length(&self, proc_size: usize, src_sample_rate: u32, dst_sample_rate: u32) -> usize {
197        min(self.fft_size, proc_size * dst_sample_rate as usize / src_sample_rate as usize)
198    }
199
200    pub fn resample(&self, input: &[f32], src_sample_rate: u32, dst_sample_rate: u32) -> Result<Vec<f32>, ResamplerError> {
201        if src_sample_rate == dst_sample_rate {
202            Ok(input.to_vec())
203        } else {
204            let proc_size = self.get_process_size(self.fft_size, src_sample_rate, dst_sample_rate);
205            let desired_length = self.get_desired_length(proc_size, src_sample_rate, dst_sample_rate);
206            if input.len() > proc_size {
207                Err(ResamplerError::SizeError(format!("To resize the waveform, the input size should be {proc_size}, not {}", input.len())))
208            } else if src_sample_rate > dst_sample_rate {
209                // Source sample rate is higher than the target, indicating waveform compression.
210                self.resample_core(input, desired_length)
211            } else {
212                // Source sample rate is lower than the target, indicating waveform stretching.
213                // When the input length is less than the desired length, zero-padding is applied to the end.
214                input.to_vec().resize(proc_size, 0.0);
215                self.resample_core(input, desired_length)
216            }
217        }
218    }
219
    /// Returns the FFT size this resampler was constructed with.
    pub fn get_fft_size(&self) -> usize {
        self.fft_size
    }
223}
224
225impl Debug for Resampler {
226    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
227        fmt.debug_struct("Resampler")
228            .field("fft_forward", &format_args!("..."))
229            .field("fft_inverse", &format_args!("..."))
230            .field("fft_size", &self.fft_size)
231            .field("normalize_scaler", &self.normalize_scaler)
232            .finish()
233    }
234}