charon_audio/
audio.rs

1//! Audio I/O and processing utilities
2
3use crate::error::{CharonError, Result};
4use hound::{WavSpec, WavWriter};
5use ndarray::Array2;
6use rubato::{
7    Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
8};
9use std::path::Path;
10use symphonia::core::audio::{AudioBufferRef, Signal};
11use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
12use symphonia::core::conv::IntoSample;
13use symphonia::core::formats::FormatOptions;
14use symphonia::core::io::MediaSourceStream;
15use symphonia::core::meta::MetadataOptions;
16use symphonia::core::probe::Hint;
17
18/// Audio buffer holding multi-channel audio data
19#[derive(Debug, Clone)]
20pub struct AudioBuffer {
21    /// Audio samples [channels, samples]
22    pub data: Array2<f32>,
23    /// Sample rate in Hz
24    pub sample_rate: u32,
25}
26
27impl AudioBuffer {
28    /// Create a new audio buffer
29    pub fn new(data: Array2<f32>, sample_rate: u32) -> Self {
30        Self { data, sample_rate }
31    }
32
33    /// Get number of channels
34    pub fn channels(&self) -> usize {
35        self.data.nrows()
36    }
37
38    /// Get number of samples per channel
39    pub fn samples(&self) -> usize {
40        self.data.ncols()
41    }
42
43    /// Get duration in seconds
44    pub fn duration(&self) -> f64 {
45        self.samples() as f64 / self.sample_rate as f64
46    }
47
48    /// Convert to mono by averaging channels
49    pub fn to_mono(&self) -> Array2<f32> {
50        let mono = self.data.mean_axis(ndarray::Axis(0)).unwrap();
51        mono.insert_axis(ndarray::Axis(0))
52    }
53
54    /// Resample to target sample rate
55    pub fn resample(&self, target_rate: u32) -> Result<Self> {
56        if self.sample_rate == target_rate {
57            return Ok(self.clone());
58        }
59
60        let params = SincInterpolationParameters {
61            sinc_len: 256,
62            f_cutoff: 0.95,
63            interpolation: SincInterpolationType::Linear,
64            oversampling_factor: 256,
65            window: WindowFunction::BlackmanHarris2,
66        };
67
68        let mut resampler = SincFixedIn::<f32>::new(
69            target_rate as f64 / self.sample_rate as f64,
70            2.0,
71            params,
72            self.samples(),
73            self.channels(),
74        )
75        .map_err(|e| CharonError::Resampling(e.to_string()))?;
76
77        // Convert to channel-major format for rubato
78        let mut input_data: Vec<Vec<f32>> = Vec::new();
79        for ch in 0..self.channels() {
80            input_data.push(self.data.row(ch).to_vec());
81        }
82
83        let output_data = resampler
84            .process(&input_data, None)
85            .map_err(|e| CharonError::Resampling(e.to_string()))?;
86
87        // Convert back to ndarray format
88        let output_samples = output_data[0].len();
89        let mut data = Array2::zeros((self.channels(), output_samples));
90        for (ch, channel_data) in output_data.iter().enumerate() {
91            for (i, &sample) in channel_data.iter().enumerate() {
92                data[[ch, i]] = sample;
93            }
94        }
95
96        Ok(AudioBuffer::new(data, target_rate))
97    }
98
99    /// Convert number of channels
100    pub fn convert_channels(&self, target_channels: usize) -> Result<Self> {
101        if self.channels() == target_channels {
102            return Ok(self.clone());
103        }
104
105        let data = match (self.channels(), target_channels) {
106            (1, 2) => {
107                // Mono to stereo: duplicate channel
108                let mono = self.data.row(0);
109                ndarray::stack![ndarray::Axis(0), mono, mono]
110            }
111            (2, 1) => {
112                // Stereo to mono: average channels
113                self.to_mono()
114            }
115            (n, 1) if n > 1 => {
116                // Multi-channel to mono: average all channels
117                self.to_mono()
118            }
119            (n, m) if n > m => {
120                // Downmix: take first m channels
121                self.data.slice(ndarray::s![0..m, ..]).to_owned()
122            }
123            _ => {
124                return Err(CharonError::Audio(format!(
125                    "Unsupported channel conversion from {} to {}",
126                    self.channels(),
127                    target_channels
128                )))
129            }
130        };
131
132        Ok(AudioBuffer::new(data, self.sample_rate))
133    }
134
135    /// Normalize audio to [-1, 1] range
136    pub fn normalize(&mut self) {
137        let max_val = self.data.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
138        if max_val > 0.0 {
139            self.data /= max_val;
140        }
141    }
142
143    /// Apply gain (in dB)
144    pub fn apply_gain(&mut self, gain_db: f32) {
145        let gain = 10.0f32.powf(gain_db / 20.0);
146        self.data *= gain;
147    }
148}
149
/// Audio file format
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioFormat {
    /// `.wav` files
    Wav,
    /// `.mp3` files
    Mp3,
    /// `.flac` files
    Flac,
    /// `.ogg` files
    Ogg,
    /// Extension missing or not recognized
    Auto,
}

impl AudioFormat {
    /// Detect format from file extension.
    ///
    /// The comparison ignores ASCII case, so `song.WAV` and `song.wav` are
    /// both reported as [`AudioFormat::Wav`]. Paths with no extension, a
    /// non-UTF-8 extension, or an unrecognized one yield [`AudioFormat::Auto`].
    pub fn from_path(path: &Path) -> Self {
        match path.extension().and_then(|ext| ext.to_str()) {
            Some(ext) if ext.eq_ignore_ascii_case("wav") => AudioFormat::Wav,
            Some(ext) if ext.eq_ignore_ascii_case("mp3") => AudioFormat::Mp3,
            Some(ext) if ext.eq_ignore_ascii_case("flac") => AudioFormat::Flac,
            Some(ext) if ext.eq_ignore_ascii_case("ogg") => AudioFormat::Ogg,
            _ => AudioFormat::Auto,
        }
    }
}
172
/// Audio file reader/writer.
///
/// A stateless unit type that groups the file I/O functions as associated
/// functions.
#[derive(Debug, Clone, Copy)]
pub struct AudioFile;
175
176impl AudioFile {
177    /// Read audio file with automatic format detection
178    pub fn read<P: AsRef<Path>>(path: P) -> Result<AudioBuffer> {
179        let path = path.as_ref();
180        let file = std::fs::File::open(path)?;
181        let mss = MediaSourceStream::new(Box::new(file), Default::default());
182
183        let mut hint = Hint::new();
184        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
185            hint.with_extension(ext);
186        }
187
188        let meta_opts = MetadataOptions::default();
189        let fmt_opts = FormatOptions::default();
190
191        let probed = symphonia::default::get_probe()
192            .format(&hint, mss, &fmt_opts, &meta_opts)
193            .map_err(|e| CharonError::Audio(e.to_string()))?;
194
195        let mut format = probed.format;
196        let track = format
197            .tracks()
198            .iter()
199            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
200            .ok_or_else(|| CharonError::Audio("No supported audio track found".to_string()))?;
201
202        let dec_opts = DecoderOptions::default();
203        let mut decoder = symphonia::default::get_codecs()
204            .make(&track.codec_params, &dec_opts)
205            .map_err(|e| CharonError::Audio(e.to_string()))?;
206
207        let sample_rate = track
208            .codec_params
209            .sample_rate
210            .ok_or_else(|| CharonError::Audio("Sample rate not found".to_string()))?;
211
212        let channels = track
213            .codec_params
214            .channels
215            .ok_or_else(|| CharonError::Audio("Channel info not found".to_string()))?
216            .count();
217
218        let mut samples: Vec<Vec<f32>> = vec![Vec::new(); channels];
219
220        while let Ok(packet) = format.next_packet() {
221            let decoded = match decoder.decode(&packet) {
222                Ok(decoded) => decoded,
223                Err(_) => continue,
224            };
225
226            Self::copy_samples(&decoded, &mut samples);
227        }
228
229        // Convert to ndarray
230        let num_samples = samples[0].len();
231        let mut data = Array2::zeros((channels, num_samples));
232        for (ch, channel_samples) in samples.iter().enumerate() {
233            for (i, &sample) in channel_samples.iter().enumerate() {
234                data[[ch, i]] = sample;
235            }
236        }
237
238        Ok(AudioBuffer::new(data, sample_rate))
239    }
240
241    fn copy_samples(decoded: &AudioBufferRef, output: &mut [Vec<f32>]) {
242        match decoded {
243            AudioBufferRef::F32(buf) => {
244                for (ch, out_ch) in output
245                    .iter_mut()
246                    .enumerate()
247                    .take(buf.spec().channels.count())
248                {
249                    out_ch.extend_from_slice(buf.chan(ch));
250                }
251            }
252            AudioBufferRef::S32(buf) => {
253                for (ch, out_ch) in output
254                    .iter_mut()
255                    .enumerate()
256                    .take(buf.spec().channels.count())
257                {
258                    out_ch.extend(
259                        buf.chan(ch)
260                            .iter()
261                            .map(|&s| IntoSample::<f32>::into_sample(s)),
262                    );
263                }
264            }
265            AudioBufferRef::S16(buf) => {
266                for (ch, out_ch) in output
267                    .iter_mut()
268                    .enumerate()
269                    .take(buf.spec().channels.count())
270                {
271                    out_ch.extend(
272                        buf.chan(ch)
273                            .iter()
274                            .map(|&s| IntoSample::<f32>::into_sample(s)),
275                    );
276                }
277            }
278            AudioBufferRef::U8(buf) => {
279                for (ch, out_ch) in output
280                    .iter_mut()
281                    .enumerate()
282                    .take(buf.spec().channels.count())
283                {
284                    out_ch.extend(
285                        buf.chan(ch)
286                            .iter()
287                            .map(|&s| IntoSample::<f32>::into_sample(s)),
288                    );
289                }
290            }
291            _ => {}
292        }
293    }
294
295    /// Write audio buffer to WAV file
296    pub fn write_wav<P: AsRef<Path>>(path: P, buffer: &AudioBuffer) -> Result<()> {
297        let spec = WavSpec {
298            channels: buffer.channels() as u16,
299            sample_rate: buffer.sample_rate,
300            bits_per_sample: 32,
301            sample_format: hound::SampleFormat::Float,
302        };
303
304        let mut writer =
305            WavWriter::create(path, spec).map_err(|e| CharonError::Audio(e.to_string()))?;
306
307        // Interleave samples
308        for i in 0..buffer.samples() {
309            for ch in 0..buffer.channels() {
310                writer
311                    .write_sample(buffer.data[[ch, i]])
312                    .map_err(|e| CharonError::Audio(e.to_string()))?;
313            }
314        }
315
316        writer
317            .finalize()
318            .map_err(|e| CharonError::Audio(e.to_string()))?;
319        Ok(())
320    }
321
322    /// Write audio buffer to file (format detected from extension)
323    pub fn write<P: AsRef<Path>>(path: P, buffer: &AudioBuffer) -> Result<()> {
324        let format = AudioFormat::from_path(path.as_ref());
325        match format {
326            AudioFormat::Wav | AudioFormat::Auto => Self::write_wav(path, buffer),
327            _ => Err(CharonError::NotSupported(
328                "Only WAV output is currently supported".to_string(),
329            )),
330        }
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// A freshly built buffer reports the shape and rate it was given.
    #[test]
    fn test_audio_buffer_creation() {
        let buffer = AudioBuffer::new(Array2::zeros((2, 1000)), 44100);

        assert_eq!(buffer.channels(), 2);
        assert_eq!(buffer.samples(), 1000);
        assert_eq!(buffer.sample_rate, 44100);
    }

    /// 44100 samples at 44100 Hz last one second.
    #[test]
    fn test_duration_calculation() {
        let one_second = AudioBuffer::new(Array2::zeros((2, 44100)), 44100);

        assert_abs_diff_eq!(one_second.duration(), 1.0, epsilon = 0.001);
    }

    /// Averaging a channel of 1.0 with a channel of 3.0 gives 2.0.
    #[test]
    fn test_mono_conversion() {
        let stereo = Array2::from_shape_fn((2, 100), |(ch, _)| if ch == 0 { 1.0 } else { 3.0 });

        let mono = AudioBuffer::new(stereo, 44100).to_mono();

        assert_eq!(mono.nrows(), 1);
        assert_abs_diff_eq!(mono[[0, 0]], 2.0, epsilon = 0.001);
    }
}
367}