1use anyhow::Result;
8use symphonia::core::{
9 audio::SampleBuffer, codecs::DecoderOptions, formats::FormatOptions, io::MediaSourceStream,
10 meta::MetadataOptions, probe::Hint,
11};
12
/// Audio held in memory as a flat buffer of `f32` samples plus its format metadata.
///
/// With more than one channel, `samples` is interleaved frame by frame
/// (e.g. `L R L R ...` for stereo). This is the layout `AudioInput::to_mono` relies on.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioInput {
    /// Interleaved sample data.
    pub samples: Vec<f32>,
    /// Sample rate as reported by the source file.
    pub sample_rate: u32,
    /// Number of interleaved channels in `samples`.
    pub channels: u16,
}
20
21impl AudioInput {
22 pub fn read_wav(wav_path: &str) -> Result<Self> {
24 let mut reader = hound::WavReader::open(wav_path)?;
25 let spec = reader.spec();
26 let samples: Vec<f32> = match spec.sample_format {
27 hound::SampleFormat::Float => reader
28 .samples::<f32>()
29 .collect::<std::result::Result<_, _>>()?,
30 hound::SampleFormat::Int => reader
31 .samples::<i16>()
32 .map(|s| s.map(|v| v as f32 / 32768.0))
35 .collect::<std::result::Result<_, _>>()?,
36 };
37 Ok(Self {
38 samples,
39 sample_rate: spec.sample_rate,
40 channels: spec.channels,
41 })
42 }
43
44 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
46 let cursor = std::io::Cursor::new(bytes.to_vec());
47 let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
48 let hint = Hint::new();
49 let probed = symphonia::default::get_probe().format(
50 &hint,
51 mss,
52 &FormatOptions::default(),
53 &MetadataOptions::default(),
54 )?;
55 let mut format = probed.format;
56 let track = format
57 .default_track()
58 .ok_or_else(|| anyhow::anyhow!("no supported audio tracks"))?;
59 let codec_params = &track.codec_params;
60 let sample_rate = codec_params
61 .sample_rate
62 .ok_or_else(|| anyhow::anyhow!("unknown sample rate"))?;
63 #[allow(clippy::cast_possible_truncation)]
64 let channels = codec_params.channels.map(|c| c.count() as u16).unwrap_or(1);
65 let mut decoder =
66 symphonia::default::get_codecs().make(codec_params, &DecoderOptions::default())?;
67 let mut samples = Vec::new();
68 loop {
69 match format.next_packet() {
70 Ok(packet) => {
71 let decoded = decoder.decode(&packet)?;
72 let mut buf =
73 SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
74 buf.copy_interleaved_ref(decoded);
75 samples.extend_from_slice(buf.samples());
76 }
77 Err(symphonia::core::errors::Error::IoError(e))
78 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
79 {
80 break;
81 }
82 Err(e) => return Err(e.into()),
83 }
84 }
85 Ok(Self {
86 samples,
87 sample_rate,
88 channels,
89 })
90 }
91
92 pub fn to_mono(&self) -> Vec<f32> {
94 if self.channels <= 1 {
95 return self.samples.clone();
96 }
97 let mut mono = vec![0.0; self.samples.len() / self.channels as usize];
98 for (i, sample) in self.samples.iter().enumerate() {
99 mono[i / self.channels as usize] += *sample;
100 }
101 for s in &mut mono {
102 *s /= self.channels as f32;
103 }
104 mono
105 }
106
107 pub fn normalize(&mut self) -> &mut Self {
109 let max_amplitude = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
110 if max_amplitude > 0.0 && max_amplitude != 1.0 {
111 let scale = 1.0 / max_amplitude;
112 for sample in &mut self.samples {
113 *sample *= scale;
114 }
115 }
116 self
117 }
118
119 pub fn apply_fade(&mut self, fade_in_samples: usize, fade_out_samples: usize) -> &mut Self {
121 let len = self.samples.len();
122 for i in 0..fade_in_samples.min(len) {
124 let factor = i as f32 / fade_in_samples as f32;
125 self.samples[i] *= factor;
126 }
127 for i in 0..fade_out_samples.min(len) {
129 let factor = (fade_out_samples - i) as f32 / fade_out_samples as f32;
130 self.samples[len - 1 - i] *= factor;
131 }
132 self
133 }
134
135 pub fn remove_dc_offset(&mut self) -> &mut Self {
137 if self.samples.is_empty() {
138 return self;
139 }
140 let mean = self.samples.iter().sum::<f32>() / self.samples.len() as f32;
141 for sample in &mut self.samples {
142 *sample -= mean;
143 }
144 self
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::AudioInput;
    use hound::{SampleFormat, WavSpec, WavWriter};
    use std::io::Cursor;
    use std::path::PathBuf;

    /// Deletes the wrapped path on drop, so a failing assertion does not leave files behind.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Path in the OS temp dir, unique per process so parallel runs do not collide.
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!(
                "audio_input_{}_{}.wav",
                std::process::id(),
                name
            )))
        }

        fn as_str(&self) -> &str {
            self.0.to_str().expect("temp path is valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// 16-bit signed integer mono spec at the given sample rate.
    fn pcm16_mono(sample_rate: u32) -> WavSpec {
        WavSpec {
            channels: 1,
            sample_rate,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        }
    }

    #[test]
    fn read_wav_roundtrip() {
        let file = TempFile::new("roundtrip");
        let mut writer = WavWriter::create(&file.0, pcm16_mono(16000)).unwrap();
        for _ in 0..160 {
            writer.write_sample::<i16>(0).unwrap();
        }
        writer.finalize().unwrap();
        let input = AudioInput::read_wav(file.as_str()).unwrap();
        assert_eq!(input.samples.len(), 160);
        assert_eq!(input.sample_rate, 16000);
    }

    #[test]
    fn read_wav_matches_pcm16_full_scale_normalization() {
        let file = TempFile::new("full_scale");
        let mut writer = WavWriter::create(&file.0, pcm16_mono(16000)).unwrap();
        writer.write_sample::<i16>(i16::MIN).unwrap();
        writer.write_sample::<i16>(i16::MAX).unwrap();
        writer.finalize().unwrap();

        let input = AudioInput::read_wav(file.as_str()).unwrap();
        assert_eq!(input.samples, vec![-1.0, 32767.0 / 32768.0]);
    }

    #[test]
    fn from_bytes() {
        let mut buffer: Vec<u8> = Vec::new();
        {
            let mut writer = WavWriter::new(Cursor::new(&mut buffer), pcm16_mono(8000)).unwrap();
            for _ in 0..80 {
                writer.write_sample::<i16>(0).unwrap();
            }
            writer.finalize().unwrap();
        }
        let input = AudioInput::from_bytes(&buffer).unwrap();
        assert_eq!(input.samples.len(), 80);
        assert_eq!(input.sample_rate, 8000);
    }

    #[test]
    fn to_mono_averages_interleaved_frames() {
        let input = AudioInput {
            samples: vec![1.0, 3.0, -2.0, 0.0],
            sample_rate: 16000,
            channels: 2,
        };
        assert_eq!(input.to_mono(), vec![2.0, -1.0]);
    }

    #[test]
    fn test_normalize() {
        let mut input = AudioInput {
            samples: vec![0.2, -0.5, 0.8, -1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.normalize();
        let max = input.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        assert!((max - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_remove_dc_offset() {
        let mut input = AudioInput {
            samples: vec![1.0, 1.0, 1.0, 1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.remove_dc_offset();
        for s in input.samples {
            assert!(s.abs() < 1e-6);
        }
    }
}