Skip to main content

audio_codec/
resampler.rs

1use super::{PcmBuf, Sample};
2use std::f64::consts::PI as PI_F64;
3
/// Streaming polyphase FIR sample-rate converter.
///
/// The filter bank is built once in `new`; `resample` can then be called
/// repeatedly on consecutive chunks because `history` and `current_pos`
/// carry the filter state from one call to the next.
pub struct Resampler {
    /// Input sample rate exactly as passed to `new`.
    input_rate: usize,
    /// Output sample rate exactly as passed to `new`.
    output_rate: usize,
    /// `output_rate / input_rate`, computed once in `new`.
    ratio: f64,
    /// Filter bank stored phase-major: the taps of phase `p` occupy
    /// `coeffs[p * taps_per_phase..(p + 1) * taps_per_phase]`.
    coeffs: Vec<f32>,
    /// Number of phases stored in `coeffs`.
    num_phases: usize,
    /// Number of taps in every phase; also the length of `history`.
    taps_per_phase: usize,
    /// The most recent `taps_per_phase` input samples, oldest first
    /// (`resample` shifts left by one and writes the newest sample last).
    history: Vec<f32>,
    /// Position of the next output sample measured in input samples relative
    /// to the newest entry of `history`; an output is emitted while it is < 1.
    current_pos: f64,
}
14
/// Evaluates the zeroth-order modified Bessel function of the first kind,
/// I0(x), from its power series `sum_{k>=0} ((x/2)^(2k)) / (k!)^2`.
///
/// At most 30 terms beyond the leading `1` are added; summation stops early
/// once the newest term is below `1e-15` times the running total.
fn bessel_i0(x: f64) -> f64 {
    const MAX_TERMS: u32 = 30;
    const REL_TOLERANCE: f64 = 1e-15;

    // (x / 2)^2, the factor shared by every term of the series.
    let quarter_x_sq = x * x * 0.25;

    let mut total = 1.0_f64;
    let mut current = 1.0_f64;
    let mut k = 1_u32;

    while k <= MAX_TERMS {
        // term_k = term_{k-1} * (x/2)^2 / k^2
        current *= quarter_x_sq / f64::from(k * k);
        total += current;
        if current < REL_TOLERANCE * total {
            break;
        }
        k += 1;
    }

    total
}
29
30fn kaiser_window(n: usize, n_total: usize, beta: f64) -> f64 {
31    if n_total <= 1 {
32        return 1.0;
33    }
34    let alpha = (n_total - 1) as f64 / 2.0;
35    let x = (n as f64 - alpha) / alpha;
36    let arg = beta * (1.0 - x * x).sqrt();
37    bessel_i0(arg) / bessel_i0(beta)
38}
39
40impl Resampler {
41    pub fn new(input_rate: usize, output_rate: usize) -> Self {
42        let ratio = output_rate as f64 / input_rate as f64;
43
44        const NUM_PHASES: usize = 256;
45        const TAPS_PER_PHASE: usize = 24;
46        const KAISER_BETA: f64 = 7.0;
47
48        let num_phases = NUM_PHASES;
49        let taps_per_phase = TAPS_PER_PHASE;
50        let filter_len = num_phases * taps_per_phase;
51
52        let mut raw_coeffs = vec![0.0_f32; filter_len];
53
54        let cutoff = if ratio < 1.0 {
55            ratio * 0.5 * 0.95
56        } else {
57            0.5 * 0.95
58        };
59
60        let center = (taps_per_phase as f64 - 1.0) / 2.0;
61
62        // Design the polyphase filter
63        for p in 0..num_phases {
64            let mut phase_coeffs = vec![0.0_f64; taps_per_phase];
65            let mut sum = 0.0_f64;
66
67            for t in 0..taps_per_phase {
68                let x = t as f64 - center - (p as f64 / num_phases as f64);
69
70                let sinc_val = if x.abs() < 1e-10 {
71                    2.0 * cutoff
72                } else {
73                    let x_pi = x * PI_F64;
74                    (x_pi * 2.0 * cutoff).sin() / x_pi
75                };
76
77                let full_filter_idx = t * num_phases + p;
78                let window = kaiser_window(full_filter_idx, filter_len, KAISER_BETA);
79
80                phase_coeffs[t] = sinc_val * window;
81                sum += phase_coeffs[t];
82            }
83
84            for t in 0..taps_per_phase {
85                let normalized = (phase_coeffs[t] / sum) as f32;
86                raw_coeffs[p * taps_per_phase + t] = normalized;
87            }
88        }
89
90        Self {
91            input_rate,
92            output_rate,
93            ratio,
94            coeffs: raw_coeffs,
95            num_phases,
96            taps_per_phase,
97            history: vec![0.0; taps_per_phase],
98            current_pos: 0.0,
99        }
100    }
101
102    #[inline(always)]
103    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
104        debug_assert_eq!(a.len(), 24);
105        debug_assert_eq!(b.len(), 24);
106
107        #[cfg(target_arch = "aarch64")]
108        {
109            // ARM NEON: 24 taps = 6 iterations of 4-wide vectors
110            unsafe {
111                use std::arch::aarch64::*;
112                let mut sumv = vdupq_n_f32(0.0);
113                for i in (0..24).step_by(4) {
114                    let av = vld1q_f32(a.as_ptr().add(i));
115                    let bv = vld1q_f32(b.as_ptr().add(i));
116                    sumv = vfmaq_f32(sumv, av, bv);
117                }
118                vaddvq_f32(sumv)
119            }
120        }
121        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
122        {
123            unsafe {
124                use std::arch::x86_64::*;
125                let mut sumv = _mm256_setzero_ps();
126                for i in (0..24).step_by(8) {
127                    let av = _mm256_loadu_ps(a.as_ptr().add(i));
128                    let bv = _mm256_loadu_ps(b.as_ptr().add(i));
129                    sumv = _mm256_add_ps(sumv, _mm256_mul_ps(av, bv));
130                }
131                // Horizontal sum
132                let x128 = _mm_add_ps(_mm256_extractf128_ps(sumv, 1), _mm256_castps256_ps128(sumv));
133                let x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
134                let x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
135                _mm_cvtss_f32(x32)
136            }
137        }
138        #[cfg(all(
139            target_arch = "x86_64",
140            target_feature = "sse2",
141            not(target_feature = "avx")
142        ))]
143        {
144            unsafe {
145                use std::arch::x86_64::*;
146                let mut sumv = _mm_setzero_ps();
147                for i in (0..24).step_by(4) {
148                    let av = _mm_loadu_ps(a.as_ptr().add(i));
149                    let bv = _mm_loadu_ps(b.as_ptr().add(i));
150                    sumv = _mm_add_ps(sumv, _mm_mul_ps(av, bv));
151                }
152                let x64 = _mm_add_ps(sumv, _mm_shuffle_ps(sumv, sumv, 0x4e));
153                let x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x11));
154                _mm_cvtss_f32(x32)
155            }
156        }
157        #[cfg(not(any(
158            target_arch = "aarch64",
159            all(target_arch = "x86_64", target_feature = "sse2")
160        )))]
161        {
162            a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
163        }
164    }
165
166    pub fn resample(&mut self, input: &[Sample]) -> PcmBuf {
167        if self.input_rate == self.output_rate {
168            return input.to_vec();
169        }
170
171        let mut output = Vec::with_capacity((input.len() as f64 * self.ratio) as usize + 1);
172        let inv_ratio = 1.0 / self.ratio;
173        let taps = self.taps_per_phase;
174        let num_phases_f = self.num_phases as f64;
175
176        for &sample in input {
177            self.history.copy_within(1..taps, 0);
178            self.history[taps - 1] = sample as f32;
179
180            while self.current_pos < 1.0 {
181                let phase_idx = (self.current_pos * num_phases_f) as usize;
182                let phase_idx = phase_idx.min(self.num_phases - 1); // Safety clamp
183                let offset = phase_idx * taps;
184                let phase_coeffs = &self.coeffs[offset..offset + taps];
185
186                let out_sample = Self::dot_product(phase_coeffs, &self.history);
187
188                output.push(out_sample.clamp(i16::MIN as f32, i16::MAX as f32) as i16);
189                self.current_pos += inv_ratio;
190            }
191            self.current_pos -= 1.0;
192        }
193
194        output
195    }
196
197    pub fn reset(&mut self) {
198        self.history.fill(0.0);
199        self.current_pos = 0.0;
200    }
201}
202
203pub fn resample(input: &[Sample], input_sample_rate: u32, output_sample_rate: u32) -> PcmBuf {
204    if input_sample_rate == output_sample_rate {
205        return input.to_vec();
206    }
207    let mut r = Resampler::new(input_sample_rate as usize, output_sample_rate as usize);
208    r.resample(input)
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::PI as PI_F32;
    use std::time::Instant;

    /// Root-mean-square of a block of samples: `sqrt(mean(s^2))`.
    /// Returns 0.0 for an empty block so callers never divide by zero.
    fn rms(samples: &[i16]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_sq: f32 = samples.iter().map(|&s| (s as f32).powi(2)).sum();
        (sum_sq / samples.len() as f32).sqrt()
    }

    /// Upsampling a constant signal by 2x doubles the sample count and keeps
    /// the level once the filter history has filled.
    #[test]
    fn test_resample_8k_to_16k() {
        let mut resampler = Resampler::new(8000, 16000);
        let input = vec![1000i16; 80];
        let output = resampler.resample(&input);
        assert!(output.len() >= 150 && output.len() <= 170);
        for &s in &output[48..output.len().saturating_sub(48)] {
            assert!((s - 1000).abs() < 100, "Value {} is too far from 1000", s);
        }
    }

    /// Downsampling a constant signal by 2x halves the sample count and keeps
    /// the level away from the start-up transient.
    #[test]
    fn test_resample_16k_to_8k() {
        let mut resampler = Resampler::new(16000, 8000);
        let input = vec![1000i16; 160];
        let output = resampler.resample(&input);
        assert!(output.len() >= 75 && output.len() <= 85);
        let skip = output.len() / 4;
        for &s in &output[skip..output.len() - skip] {
            assert!((s - 1000).abs() < 100, "Value {} is too far from 1000", s);
        }
    }

    /// A tone well inside the output passband must survive downsampling.
    #[test]
    fn test_frequency_response_downsample() {
        let mut resampler = Resampler::new(16000, 8000);
        let freq = 2000.0_f32; // Well below 4kHz Nyquist
        let samples: Vec<i16> = (0..160)
            .map(|i| ((i as f32 * freq * 2.0 * PI_F32 / 16000.0).sin() * 10000.0) as i16)
            .collect();

        let output = resampler.resample(&samples);

        // Output should have similar amplitude (allowing for some attenuation).
        // Both levels are true RMS values, so the comparison does not depend
        // on the input and output having different lengths.
        let input_rms = rms(&samples);
        let output_rms = rms(&output);

        assert!(
            output_rms > input_rms * 0.7,
            "Too much attenuation: input_rms={}, output_rms={}",
            input_rms,
            output_rms
        );
    }

    /// A tone above the output Nyquist frequency must be strongly attenuated
    /// rather than aliased into the output band.
    #[test]
    fn test_aliasing_suppression() {
        let mut resampler = Resampler::new(16000, 8000);
        let freq = 7000.0_f32; // Above 4kHz Nyquist of output
        let samples: Vec<i16> = (0..1600)
            .map(|i| ((i as f32 * freq * 2.0 * PI_F32 / 16000.0).sin() * 10000.0) as i16)
            .collect();

        let output = resampler.resample(&samples);

        let output_rms = rms(&output);
        let input_rms: f32 = 10000.0 / 1.414; // Expected RMS of sine wave with amplitude 10000

        assert!(
            output_rms < input_rms / 50.0,
            "Aliasing not sufficiently suppressed: output_rms={}",
            output_rms
        );
    }

    /// Coarse wall-clock guard: one second of 48 kHz audio must convert in
    /// well under 100 ms on average.
    #[test]
    fn test_performance_48k_to_8k() {
        let mut resampler = Resampler::new(48000, 8000);
        let input = vec![0i16; 48000];

        let start = Instant::now();
        let iterations = 100;
        for _ in 0..iterations {
            let _ = resampler.resample(&input);
            resampler.reset();
        }
        let duration = start.elapsed();
        let per_second = duration.as_secs_f64() / iterations as f64;
        println!(
            "Resampling 1s of 48kHz to 8kHz (24 taps) took: {:.4}ms",
            per_second * 1000.0
        );
        assert!(
            per_second < 0.1,
            "Performance regression: {}ms",
            per_second * 1000.0
        );
    }

    /// Feeding a stream in two chunks must give the same result as feeding it
    /// in one piece, because the resampler keeps its state between calls.
    #[test]
    fn test_continuity_between_chunks() {
        let input_rate = 16000;
        let output_rate = 8000;

        let freq = 1000.0_f32;
        let total_samples = 3200;
        let input: Vec<i16> = (0..total_samples)
            .map(|i| ((i as f32 * freq * 2.0 * PI_F32 / input_rate as f32).sin() * 5000.0) as i16)
            .collect();

        let mut resampler1 = Resampler::new(input_rate, output_rate);
        let output1 = resampler1.resample(&input);

        let mut resampler2 = Resampler::new(input_rate, output_rate);
        let mid = input.len() / 2;
        let mut output2 = resampler2.resample(&input[..mid]);
        output2.extend_from_slice(&resampler2.resample(&input[mid..]));

        assert_eq!(output1.len(), output2.len(), "Output lengths differ");

        let max_diff: i16 = output1
            .iter()
            .zip(output2.iter())
            .map(|(a, b)| (a - b).abs())
            .max()
            .unwrap_or(0);

        assert!(
            max_diff < 100,
            "Large discontinuity between chunks: max_diff={}",
            max_diff
        );
    }
}