1use crate::error::{CharonError, Result};
4use hound::{WavSpec, WavWriter};
5use ndarray::Array2;
6use rubato::{
7 Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
8};
9use std::path::Path;
10use symphonia::core::audio::{AudioBufferRef, Signal};
11use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
12use symphonia::core::conv::IntoSample;
13use symphonia::core::formats::FormatOptions;
14use symphonia::core::io::MediaSourceStream;
15use symphonia::core::meta::MetadataOptions;
16use symphonia::core::probe::Hint;
17
18#[derive(Debug, Clone)]
20pub struct AudioBuffer {
21 pub data: Array2<f32>,
23 pub sample_rate: u32,
25}
26
27impl AudioBuffer {
28 pub fn new(data: Array2<f32>, sample_rate: u32) -> Self {
30 Self { data, sample_rate }
31 }
32
33 pub fn channels(&self) -> usize {
35 self.data.nrows()
36 }
37
38 pub fn samples(&self) -> usize {
40 self.data.ncols()
41 }
42
43 pub fn duration(&self) -> f64 {
45 self.samples() as f64 / self.sample_rate as f64
46 }
47
48 pub fn to_mono(&self) -> Array2<f32> {
50 let mono = self.data.mean_axis(ndarray::Axis(0)).unwrap();
51 mono.insert_axis(ndarray::Axis(0))
52 }
53
54 pub fn resample(&self, target_rate: u32) -> Result<Self> {
56 if self.sample_rate == target_rate {
57 return Ok(self.clone());
58 }
59
60 let params = SincInterpolationParameters {
61 sinc_len: 256,
62 f_cutoff: 0.95,
63 interpolation: SincInterpolationType::Linear,
64 oversampling_factor: 256,
65 window: WindowFunction::BlackmanHarris2,
66 };
67
68 let mut resampler = SincFixedIn::<f32>::new(
69 target_rate as f64 / self.sample_rate as f64,
70 2.0,
71 params,
72 self.samples(),
73 self.channels(),
74 )
75 .map_err(|e| CharonError::Resampling(e.to_string()))?;
76
77 let mut input_data: Vec<Vec<f32>> = Vec::new();
79 for ch in 0..self.channels() {
80 input_data.push(self.data.row(ch).to_vec());
81 }
82
83 let output_data = resampler
84 .process(&input_data, None)
85 .map_err(|e| CharonError::Resampling(e.to_string()))?;
86
87 let output_samples = output_data[0].len();
89 let mut data = Array2::zeros((self.channels(), output_samples));
90 for (ch, channel_data) in output_data.iter().enumerate() {
91 for (i, &sample) in channel_data.iter().enumerate() {
92 data[[ch, i]] = sample;
93 }
94 }
95
96 Ok(AudioBuffer::new(data, target_rate))
97 }
98
99 pub fn convert_channels(&self, target_channels: usize) -> Result<Self> {
101 if self.channels() == target_channels {
102 return Ok(self.clone());
103 }
104
105 let data = match (self.channels(), target_channels) {
106 (1, 2) => {
107 let mono = self.data.row(0);
109 ndarray::stack![ndarray::Axis(0), mono, mono]
110 }
111 (2, 1) => {
112 self.to_mono()
114 }
115 (n, 1) if n > 1 => {
116 self.to_mono()
118 }
119 (n, m) if n > m => {
120 self.data.slice(ndarray::s![0..m, ..]).to_owned()
122 }
123 _ => {
124 return Err(CharonError::Audio(format!(
125 "Unsupported channel conversion from {} to {}",
126 self.channels(),
127 target_channels
128 )))
129 }
130 };
131
132 Ok(AudioBuffer::new(data, self.sample_rate))
133 }
134
135 pub fn normalize(&mut self) {
137 let max_val = self.data.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
138 if max_val > 0.0 {
139 self.data /= max_val;
140 }
141 }
142
143 pub fn apply_gain(&mut self, gain_db: f32) {
145 let gain = 10.0f32.powf(gain_db / 20.0);
146 self.data *= gain;
147 }
148}
149
/// Audio formats recognised from a file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioFormat {
    /// `.wav`
    Wav,
    /// `.mp3`
    Mp3,
    /// `.flac`
    Flac,
    /// `.ogg`
    Ogg,
    /// Missing, non-UTF-8 or unrecognised extension.
    Auto,
}

impl AudioFormat {
    /// Guesses the format from `path`'s extension, ignoring ASCII case.
    ///
    /// For example, `song.WAV` and `song.wav` both map to [`AudioFormat::Wav`].
    /// Returns [`AudioFormat::Auto`] when the extension is missing, is not valid
    /// UTF-8, or is not one of the recognised formats.
    pub fn from_path(path: &Path) -> Self {
        // Case-insensitive so that e.g. `OUT.MP3` is not classified as an unknown
        // extension, which `AudioFile::write` treats the same as `.wav`.
        let ext = path
            .extension()
            .and_then(|s| s.to_str())
            .map(str::to_ascii_lowercase);

        match ext.as_deref() {
            Some("wav") => AudioFormat::Wav,
            Some("mp3") => AudioFormat::Mp3,
            Some("flac") => AudioFormat::Flac,
            Some("ogg") => AudioFormat::Ogg,
            _ => AudioFormat::Auto,
        }
    }
}
172
173pub struct AudioFile;
175
176impl AudioFile {
177 pub fn read<P: AsRef<Path>>(path: P) -> Result<AudioBuffer> {
179 let path = path.as_ref();
180 let file = std::fs::File::open(path)?;
181 let mss = MediaSourceStream::new(Box::new(file), Default::default());
182
183 let mut hint = Hint::new();
184 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
185 hint.with_extension(ext);
186 }
187
188 let meta_opts = MetadataOptions::default();
189 let fmt_opts = FormatOptions::default();
190
191 let probed = symphonia::default::get_probe()
192 .format(&hint, mss, &fmt_opts, &meta_opts)
193 .map_err(|e| CharonError::Audio(e.to_string()))?;
194
195 let mut format = probed.format;
196 let track = format
197 .tracks()
198 .iter()
199 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
200 .ok_or_else(|| CharonError::Audio("No supported audio track found".to_string()))?;
201
202 let dec_opts = DecoderOptions::default();
203 let mut decoder = symphonia::default::get_codecs()
204 .make(&track.codec_params, &dec_opts)
205 .map_err(|e| CharonError::Audio(e.to_string()))?;
206
207 let sample_rate = track
208 .codec_params
209 .sample_rate
210 .ok_or_else(|| CharonError::Audio("Sample rate not found".to_string()))?;
211
212 let channels = track
213 .codec_params
214 .channels
215 .ok_or_else(|| CharonError::Audio("Channel info not found".to_string()))?
216 .count();
217
218 let mut samples: Vec<Vec<f32>> = vec![Vec::new(); channels];
219
220 while let Ok(packet) = format.next_packet() {
221 let decoded = match decoder.decode(&packet) {
222 Ok(decoded) => decoded,
223 Err(_) => continue,
224 };
225
226 Self::copy_samples(&decoded, &mut samples);
227 }
228
229 let num_samples = samples[0].len();
231 let mut data = Array2::zeros((channels, num_samples));
232 for (ch, channel_samples) in samples.iter().enumerate() {
233 for (i, &sample) in channel_samples.iter().enumerate() {
234 data[[ch, i]] = sample;
235 }
236 }
237
238 Ok(AudioBuffer::new(data, sample_rate))
239 }
240
241 fn copy_samples(decoded: &AudioBufferRef, output: &mut [Vec<f32>]) {
242 match decoded {
243 AudioBufferRef::F32(buf) => {
244 for (ch, out_ch) in output
245 .iter_mut()
246 .enumerate()
247 .take(buf.spec().channels.count())
248 {
249 out_ch.extend_from_slice(buf.chan(ch));
250 }
251 }
252 AudioBufferRef::S32(buf) => {
253 for (ch, out_ch) in output
254 .iter_mut()
255 .enumerate()
256 .take(buf.spec().channels.count())
257 {
258 out_ch.extend(
259 buf.chan(ch)
260 .iter()
261 .map(|&s| IntoSample::<f32>::into_sample(s)),
262 );
263 }
264 }
265 AudioBufferRef::S16(buf) => {
266 for (ch, out_ch) in output
267 .iter_mut()
268 .enumerate()
269 .take(buf.spec().channels.count())
270 {
271 out_ch.extend(
272 buf.chan(ch)
273 .iter()
274 .map(|&s| IntoSample::<f32>::into_sample(s)),
275 );
276 }
277 }
278 AudioBufferRef::U8(buf) => {
279 for (ch, out_ch) in output
280 .iter_mut()
281 .enumerate()
282 .take(buf.spec().channels.count())
283 {
284 out_ch.extend(
285 buf.chan(ch)
286 .iter()
287 .map(|&s| IntoSample::<f32>::into_sample(s)),
288 );
289 }
290 }
291 _ => {}
292 }
293 }
294
295 pub fn write_wav<P: AsRef<Path>>(path: P, buffer: &AudioBuffer) -> Result<()> {
297 let spec = WavSpec {
298 channels: buffer.channels() as u16,
299 sample_rate: buffer.sample_rate,
300 bits_per_sample: 32,
301 sample_format: hound::SampleFormat::Float,
302 };
303
304 let mut writer =
305 WavWriter::create(path, spec).map_err(|e| CharonError::Audio(e.to_string()))?;
306
307 for i in 0..buffer.samples() {
309 for ch in 0..buffer.channels() {
310 writer
311 .write_sample(buffer.data[[ch, i]])
312 .map_err(|e| CharonError::Audio(e.to_string()))?;
313 }
314 }
315
316 writer
317 .finalize()
318 .map_err(|e| CharonError::Audio(e.to_string()))?;
319 Ok(())
320 }
321
322 pub fn write<P: AsRef<Path>>(path: P, buffer: &AudioBuffer) -> Result<()> {
324 let format = AudioFormat::from_path(path.as_ref());
325 match format {
326 AudioFormat::Wav | AudioFormat::Auto => Self::write_wav(path, buffer),
327 _ => Err(CharonError::NotSupported(
328 "Only WAV output is currently supported".to_string(),
329 )),
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_audio_buffer_creation() {
        let data = Array2::zeros((2, 1000));
        let buffer = AudioBuffer::new(data, 44100);
        assert_eq!(buffer.channels(), 2);
        assert_eq!(buffer.samples(), 1000);
        assert_eq!(buffer.sample_rate, 44100);
    }

    #[test]
    fn test_duration_calculation() {
        let data = Array2::zeros((2, 44100));
        let buffer = AudioBuffer::new(data, 44100);
        assert_abs_diff_eq!(buffer.duration(), 1.0, epsilon = 0.001);
    }

    #[test]
    fn test_mono_conversion() {
        let mut data = Array2::zeros((2, 100));
        data.row_mut(0).fill(1.0);
        data.row_mut(1).fill(3.0);

        let buffer = AudioBuffer::new(data, 44100);
        let mono = buffer.to_mono();

        assert_eq!(mono.nrows(), 1);
        assert_abs_diff_eq!(mono[[0, 0]], 2.0, epsilon = 0.001);
    }

    #[test]
    fn test_resample_same_rate_is_identity() {
        let data = ndarray::arr2(&[[0.1f32, -0.2, 0.3], [0.4, -0.5, 0.6]]);
        let buffer = AudioBuffer::new(data.clone(), 48000);

        let resampled = buffer.resample(48000).expect("same-rate resample must succeed");

        assert_eq!(resampled.sample_rate, 48000);
        assert_eq!(resampled.data, data);
    }

    #[test]
    fn test_convert_mono_to_stereo_duplicates_channel() {
        let buffer = AudioBuffer::new(ndarray::arr2(&[[1.0f32, 2.0, 3.0]]), 44100);

        let stereo = buffer.convert_channels(2).expect("mono -> stereo is supported");

        assert_eq!(stereo.channels(), 2);
        assert_eq!(stereo.data.row(0), buffer.data.row(0));
        assert_eq!(stereo.data.row(1), buffer.data.row(0));
    }

    #[test]
    fn test_convert_stereo_to_mono_averages() {
        let buffer = AudioBuffer::new(ndarray::arr2(&[[1.0f32, 0.0], [3.0, 2.0]]), 44100);

        let mono = buffer.convert_channels(1).expect("stereo -> mono is supported");

        assert_eq!(mono.channels(), 1);
        assert_abs_diff_eq!(mono.data[[0, 0]], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(mono.data[[0, 1]], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_convert_downmix_keeps_leading_channels() {
        let buffer = AudioBuffer::new(ndarray::arr2(&[[1.0f32], [2.0], [3.0], [4.0]]), 44100);

        let stereo = buffer.convert_channels(2).expect("4 -> 2 channels is supported");

        assert_eq!(stereo.data, ndarray::arr2(&[[1.0f32], [2.0]]));
    }

    #[test]
    fn test_convert_stereo_upmix_is_rejected() {
        let buffer = AudioBuffer::new(Array2::zeros((2, 4)), 44100);
        assert!(buffer.convert_channels(3).is_err());
    }

    #[test]
    fn test_normalize_scales_peak_to_one() {
        let mut buffer = AudioBuffer::new(ndarray::arr2(&[[-4.0f32, 2.0], [1.0, 0.0]]), 44100);

        buffer.normalize();

        assert_abs_diff_eq!(buffer.data[[0, 0]], -1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(buffer.data[[0, 1]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(buffer.data[[1, 0]], 0.25, epsilon = 1e-6);
    }

    #[test]
    fn test_normalize_leaves_silence_untouched() {
        let mut buffer = AudioBuffer::new(Array2::zeros((1, 8)), 44100);
        buffer.normalize();
        assert!(buffer.data.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_apply_gain_in_decibels() {
        let mut buffer = AudioBuffer::new(Array2::from_elem((1, 4), 0.5f32), 44100);

        // +20 dB is a 10x amplitude increase.
        buffer.apply_gain(20.0);

        assert_abs_diff_eq!(buffer.data[[0, 3]], 5.0, epsilon = 1e-4);
    }

    #[test]
    fn test_format_from_path() {
        assert_eq!(AudioFormat::from_path(Path::new("a.wav")), AudioFormat::Wav);
        assert_eq!(AudioFormat::from_path(Path::new("b.mp3")), AudioFormat::Mp3);
        assert_eq!(AudioFormat::from_path(Path::new("c.flac")), AudioFormat::Flac);
        assert_eq!(AudioFormat::from_path(Path::new("d.ogg")), AudioFormat::Ogg);
        assert_eq!(AudioFormat::from_path(Path::new("e.txt")), AudioFormat::Auto);
        assert_eq!(AudioFormat::from_path(Path::new("noext")), AudioFormat::Auto);
    }

    #[test]
    fn test_write_rejects_non_wav_extension() {
        // Rejected before any file is created.
        let buffer = AudioBuffer::new(Array2::zeros((1, 4)), 44100);
        assert!(AudioFile::write("charon_test_output.flac", &buffer).is_err());
    }
}