audio_codec/
resampler.rs

1use super::{PcmBuf, Sample};
2use std::f32::consts::PI;
3
/// A Polyphase FIR Resampler suitable for VoIP.
///
/// The struct owns a precomputed bank of windowed-sinc coefficients together
/// with the streaming state (input history and fractional read position), so
/// consecutive calls to `resample` continue where the previous call stopped.
#[derive(Debug, Clone)]
pub struct Resampler {
    /// Sample rate of the incoming audio, as passed to `Resampler::new`.
    input_rate: usize,
    /// Sample rate of the produced audio, as passed to `Resampler::new`.
    output_rate: usize,
    /// `output_rate / input_rate`; values below 1.0 mean decimation.
    ratio: f64,
    /// Flattened coefficient table: `num_phases` consecutive rows of
    /// `taps_per_phase` taps, each row normalized to sum to one.
    coeffs: Vec<f32>,
    /// Number of fractional-delay phases stored in `coeffs`.
    num_phases: usize,
    /// Number of filter taps in each phase (and length of `history`).
    taps_per_phase: usize,
    /// The most recent `taps_per_phase` input samples, oldest first.
    history: Vec<f32>,
    /// Fractional position of the next output sample, in input-sample units,
    /// measured from the newest sample in `history`. Outputs are emitted
    /// while this value is below 1.0.
    current_pos: f64,
}
15
16impl Resampler {
17    pub fn new(input_rate: usize, output_rate: usize) -> Self {
18        let ratio = output_rate as f64 / input_rate as f64;
19        let num_phases = 128;
20        let taps_per_phase = 16; // Increased to 16 for SIMD alignment (4x f32)
21        let filter_len = num_phases * taps_per_phase;
22
23        let mut raw_coeffs = vec![0.0f32; filter_len];
24        let cutoff = if ratio < 1.0 {
25            ratio as f32 * 0.45
26        } else {
27            0.45f32
28        };
29
30        let center = (taps_per_phase as f32 - 1.0) / 2.0;
31
32        for p in 0..num_phases {
33            let mut sum = 0.0;
34            let mut phase_coeffs = vec![0.0f32; taps_per_phase];
35            for t in 0..taps_per_phase {
36                let x = t as f32 - center - (p as f32 / num_phases as f32);
37                let val = if x.abs() < 1e-6 {
38                    2.0 * cutoff
39                } else {
40                    let x_pi = x * PI;
41                    (x_pi * 2.0 * cutoff).sin() / x_pi
42                };
43                let window = 0.54
44                    - 0.46
45                        * (2.0 * PI * (t as f32 * num_phases as f32 + p as f32)
46                            / (filter_len as f32 - 1.0))
47                            .cos();
48                phase_coeffs[t] = val * window;
49                sum += phase_coeffs[t];
50            }
51            for t in 0..taps_per_phase {
52                raw_coeffs[p * taps_per_phase + t] = phase_coeffs[t] / sum;
53            }
54        }
55
56        Self {
57            input_rate,
58            output_rate,
59            ratio,
60            coeffs: raw_coeffs,
61            num_phases,
62            taps_per_phase,
63            history: vec![0.0; taps_per_phase],
64            current_pos: 0.0,
65        }
66    }
67
68    #[inline(always)]
69    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
70        #[cfg(target_arch = "aarch64")]
71        {
72            // Use ARM Neon intrinsics for aarch64 (Apple Silicon)
73            unsafe {
74                use std::arch::aarch64::*;
75                let mut sumv = vdupq_n_f32(0.0);
76                for i in (0..16).step_by(4) {
77                    let av = vld1q_f32(a.as_ptr().add(i));
78                    let bv = vld1q_f32(b.as_ptr().add(i));
79                    sumv = vfmaq_f32(sumv, av, bv);
80                }
81                vaddvq_f32(sumv)
82            }
83        }
84        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
85        {
86            // Use AVX for x86_64
87            unsafe {
88                use std::arch::x86_64::*;
89                let av = _mm256_loadu_ps(a.as_ptr());
90                let bv = _mm256_loadu_ps(b.as_ptr());
91                let mut sumv = _mm256_mul_ps(av, bv);
92
93                let av2 = _mm256_loadu_ps(a.as_ptr().add(8));
94                let bv2 = _mm256_loadu_ps(b.as_ptr().add(8));
95                sumv = _mm256_add_ps(sumv, _mm256_mul_ps(av2, bv2));
96
97                let x128 = _mm_add_ps(_mm256_extractf128_ps(sumv, 1), _mm256_castps256_ps128(sumv));
98                let x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
99                let x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
100                _mm_cvtss_f32(x32)
101            }
102        }
103        #[cfg(not(any(
104            target_arch = "aarch64",
105            all(target_arch = "x86_64", target_feature = "avx")
106        )))]
107        {
108            // Fallback to auto-vectorized iterator
109            a.iter().zip(b).map(|(x, y)| x * y).sum()
110        }
111    }
112
113    pub fn resample(&mut self, input: &[Sample]) -> PcmBuf {
114        if self.input_rate == self.output_rate {
115            return input.to_vec();
116        }
117
118        let mut output = Vec::with_capacity((input.len() as f64 * self.ratio) as usize + 1);
119        let inv_ratio = 1.0 / self.ratio;
120        let taps = self.taps_per_phase;
121        let num_phases_f = self.num_phases as f64;
122
123        for &sample in input {
124            self.history.copy_within(1..taps, 0);
125            self.history[taps - 1] = sample as f32;
126
127            while self.current_pos < 1.0 {
128                let phase_idx = (self.current_pos * num_phases_f) as usize;
129                let offset = phase_idx * taps;
130                let phase_coeffs = &self.coeffs[offset..offset + taps];
131
132                let out_sample = Self::dot_product(phase_coeffs, &self.history);
133
134                output.push(out_sample.clamp(i16::MIN as f32, i16::MAX as f32) as i16);
135                self.current_pos += inv_ratio;
136            }
137            self.current_pos -= 1.0;
138        }
139
140        output
141    }
142}
143
144pub fn resample(input: &[Sample], input_sample_rate: u32, output_sample_rate: u32) -> PcmBuf {
145    if input_sample_rate == output_sample_rate {
146        return input.to_vec();
147    }
148    let mut r = Resampler::new(input_sample_rate as usize, output_sample_rate as usize);
149    r.resample(input)
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Upsampling a constant signal by two should roughly double the sample
    /// count and keep the level close to the input once the filter settles.
    #[test]
    fn test_resample_8k_to_16k() {
        let input = vec![1000i16; 80];
        let mut resampler = Resampler::new(8000, 16000);
        let output = resampler.resample(&input);

        let produced = output.len();
        assert!((150..=170).contains(&produced));

        // Skip the first and last 20 samples before checking the level.
        for &s in &output[20..produced - 20] {
            assert!((s - 1000).abs() < 100, "Value {} is too far from 1000", s);
        }
    }

    /// Downsampling a constant signal by two should roughly halve the sample
    /// count and keep the level close to the input once the filter settles.
    #[test]
    fn test_resample_16k_to_8k() {
        let input = vec![1000i16; 160];
        let mut resampler = Resampler::new(16000, 8000);
        let output = resampler.resample(&input);

        let produced = output.len();
        assert!((75..=85).contains(&produced));

        // Skip the first and last 20 samples before checking the level.
        for &s in &output[20..produced - 20] {
            assert!((s - 1000).abs() < 100, "Value {} is too far from 1000", s);
        }
    }

    /// Coarse timing check: converting 48000 samples from 48 kHz to 8 kHz
    /// must average under 50 ms per pass.
    #[test]
    fn test_performance() {
        let input = vec![0i16; 48000];
        let mut resampler = Resampler::new(48000, 8000);
        let iterations = 100;

        let start = Instant::now();
        for _ in 0..iterations {
            let _ = resampler.resample(&input);
        }
        let elapsed = start.elapsed();

        // Average wall-clock time for one pass over the input.
        let per_second = elapsed.as_secs_f64() / iterations as f64;
        println!(
            "Resampling 1s of 48kHz to 8kHz took: {:.4}ms",
            per_second * 1000.0
        );
        assert!(per_second < 0.05);
    }
}