1use audioadapter_buffers::direct::InterleavedSlice;
2use num::Float;
3use rubato::Fft;
4use rubato::Resampler as _;
5use thiserror::Error;
6
7#[derive(Debug, Error)]
8pub enum ResampleError {
9 #[error("could not create resampler")]
10 Construction(#[from] rubato::ResamplerConstructionError),
11 #[error("could not resample audio")]
12 Process(#[from] rubato::ResampleError),
13}
14
15pub fn resample<F: Float + rubato::Sample>(
16 audio_interleaved: &[F],
17 num_channels: usize,
18 sr_in: u32,
19 sr_out: u32,
20) -> Result<Vec<F>, ResampleError> {
21 let mut resampler = Fft::new(
22 sr_in as usize,
23 sr_out as usize,
24 1024,
25 2,
26 num_channels,
27 rubato::FixedSync::Both,
28 )?;
29
30 let num_input_frames = audio_interleaved.len() / num_channels;
31 let buffer_in = InterleavedSlice::new(audio_interleaved, num_channels, num_input_frames)
32 .expect("Should be the right size");
33
34 let num_output_frames = resampler.process_all_needed_output_len(num_input_frames);
35 let mut out_slice = vec![F::zero(); num_output_frames * num_channels];
36 let mut buffer_out = InterleavedSlice::new_mut(&mut out_slice, num_channels, num_output_frames)
37 .expect("should be the right size");
38
39 let (_, actual_output_frames) =
42 resampler.process_all_into_buffer(&buffer_in, &mut buffer_out, num_input_frames, None)?;
43
44 out_slice.truncate(actual_output_frames * num_channels);
46 Ok(out_slice)
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Resamples the 4-channel 48 kHz fixture to 22.05 kHz and checks that each
    /// channel still matches a unit-amplitude, zero-phase sine at its expected
    /// frequency.
    #[test]
    fn test_resample_preserves_frequency() {
        use crate::reader::{ReadConfig, read};
        use audio_blocks::{AudioBlock, InterleavedView};

        // Sine frequency (Hz) each channel is compared against, in channel order.
        const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
        // Sample rate the fixture is converted to.
        const TARGET_RATE: u32 = 22050;
        // Largest tolerated absolute per-sample deviation from the reference sine.
        const MAX_ALLOWED_ERROR: f32 = 0.02;

        let source = read::<f32>("test_data/test_4ch.wav", ReadConfig::default()).unwrap();

        // Sanity-check the fixture before relying on its properties.
        assert_eq!(source.sample_rate, 48000);
        assert_eq!(source.num_channels, 4);

        let converted = resample(
            &source.samples_interleaved,
            source.num_channels as usize,
            source.sample_rate,
            TARGET_RATE,
        )
        .unwrap();

        let view = InterleavedView::from_slice(&converted, source.num_channels);

        let expected_frames = 22050usize;
        assert_eq!(
            view.num_frames(),
            expected_frames,
            "Expected {} frames, got {}",
            expected_frames,
            view.num_frames()
        );

        // Compare a 1000-frame window that starts 100 frames into the signal
        // (presumably to stay clear of edge effects at the very start).
        let first_frame = 100;
        let frames_to_check = 1000;

        for (ch, &freq) in FREQUENCIES.iter().enumerate() {
            let omega = 2.0 * std::f64::consts::PI * freq;
            let max_error = (first_frame..first_frame + frames_to_check)
                .map(|frame| {
                    let reference = (omega * frame as f64 / TARGET_RATE as f64).sin() as f32;
                    (view.sample(ch as u16, frame) - reference).abs()
                })
                .fold(0.0_f32, f32::max);

            assert!(
                max_error < MAX_ALLOWED_ERROR,
                "Channel {} ({}Hz): max error {} exceeds threshold",
                ch,
                freq,
                max_error
            );
        }
    }
}