Skip to main content

hanzo_audio/
lib.rs

//! Audio utilities for `hanzo`.
//!
//! This crate mirrors the `hanzo-vision` crate and focuses on audio-specific
//! functionality such as reading audio data, resampling, and computing
//! mel spectrogram features.
6
7use anyhow::Result;
8use symphonia::core::{
9    audio::SampleBuffer, codecs::DecoderOptions, formats::FormatOptions, io::MediaSourceStream,
10    meta::MetadataOptions, probe::Hint,
11};
12
/// Decoded audio held in memory.
///
/// Samples are stored as interleaved `f32` PCM: for multi-channel audio each
/// frame contributes `channels` consecutive values to `samples`.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioInput {
    /// Interleaved PCM samples (one value per channel per frame).
    pub samples: Vec<f32>,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels per frame.
    pub channels: u16,
}
20
21impl AudioInput {
22    /// Read a wav file from disk.
23    pub fn read_wav(wav_path: &str) -> Result<Self> {
24        let mut reader = hound::WavReader::open(wav_path)?;
25        let spec = reader.spec();
26        let samples: Vec<f32> = match spec.sample_format {
27            hound::SampleFormat::Float => reader
28                .samples::<f32>()
29                .collect::<std::result::Result<_, _>>()?,
30            hound::SampleFormat::Int => reader
31                .samples::<i16>()
32                // Match libsndfile/soundfile normalization for PCM16 by
33                // dividing by the full signed range, not by `i16::MAX`.
34                .map(|s| s.map(|v| v as f32 / 32768.0))
35                .collect::<std::result::Result<_, _>>()?,
36        };
37        Ok(Self {
38            samples,
39            sample_rate: spec.sample_rate,
40            channels: spec.channels,
41        })
42    }
43
44    /// Decode audio bytes using `symphonia`.
45    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
46        let cursor = std::io::Cursor::new(bytes.to_vec());
47        let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
48        let hint = Hint::new();
49        let probed = symphonia::default::get_probe().format(
50            &hint,
51            mss,
52            &FormatOptions::default(),
53            &MetadataOptions::default(),
54        )?;
55        let mut format = probed.format;
56        let track = format
57            .default_track()
58            .ok_or_else(|| anyhow::anyhow!("no supported audio tracks"))?;
59        let codec_params = &track.codec_params;
60        let sample_rate = codec_params
61            .sample_rate
62            .ok_or_else(|| anyhow::anyhow!("unknown sample rate"))?;
63        #[allow(clippy::cast_possible_truncation)]
64        let channels = codec_params.channels.map(|c| c.count() as u16).unwrap_or(1);
65        let mut decoder =
66            symphonia::default::get_codecs().make(codec_params, &DecoderOptions::default())?;
67        let mut samples = Vec::new();
68        loop {
69            match format.next_packet() {
70                Ok(packet) => {
71                    let decoded = decoder.decode(&packet)?;
72                    let mut buf =
73                        SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
74                    buf.copy_interleaved_ref(decoded);
75                    samples.extend_from_slice(buf.samples());
76                }
77                Err(symphonia::core::errors::Error::IoError(e))
78                    if e.kind() == std::io::ErrorKind::UnexpectedEof =>
79                {
80                    break;
81                }
82                Err(e) => return Err(e.into()),
83            }
84        }
85        Ok(Self {
86            samples,
87            sample_rate,
88            channels,
89        })
90    }
91
92    /// Convert multi channel audio to mono by averaging channels.
93    pub fn to_mono(&self) -> Vec<f32> {
94        if self.channels <= 1 {
95            return self.samples.clone();
96        }
97        let mut mono = vec![0.0; self.samples.len() / self.channels as usize];
98        for (i, sample) in self.samples.iter().enumerate() {
99            mono[i / self.channels as usize] += *sample;
100        }
101        for s in &mut mono {
102            *s /= self.channels as f32;
103        }
104        mono
105    }
106
107    /// Normalize audio to prevent clipping
108    pub fn normalize(&mut self) -> &mut Self {
109        let max_amplitude = self.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
110        if max_amplitude > 0.0 && max_amplitude != 1.0 {
111            let scale = 1.0 / max_amplitude;
112            for sample in &mut self.samples {
113                *sample *= scale;
114            }
115        }
116        self
117    }
118
119    /// Apply fade in/out to reduce audio artifacts
120    pub fn apply_fade(&mut self, fade_in_samples: usize, fade_out_samples: usize) -> &mut Self {
121        let len = self.samples.len();
122        // Fade in
123        for i in 0..fade_in_samples.min(len) {
124            let factor = i as f32 / fade_in_samples as f32;
125            self.samples[i] *= factor;
126        }
127        // Fade out
128        for i in 0..fade_out_samples.min(len) {
129            let factor = (fade_out_samples - i) as f32 / fade_out_samples as f32;
130            self.samples[len - 1 - i] *= factor;
131        }
132        self
133    }
134
135    /// Remove DC offset (audio centered around 0)
136    pub fn remove_dc_offset(&mut self) -> &mut Self {
137        if self.samples.is_empty() {
138            return self;
139        }
140        let mean = self.samples.iter().sum::<f32>() / self.samples.len() as f32;
141        for sample in &mut self.samples {
142            *sample -= mean;
143        }
144        self
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::AudioInput;
    use hound::{SampleFormat, WavSpec, WavWriter};
    use std::io::Cursor;
    use std::path::PathBuf;

    /// Per-process path in the OS temp dir.
    ///
    /// Hard-coded `/tmp/...` paths fail on hosts without `/tmp` and collide
    /// between concurrent test runs.
    fn temp_wav_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("hanzo_audio_{}_{}", std::process::id(), name))
    }

    /// 16-bit PCM spec shared by the WAV fixtures.
    fn pcm16_spec(channels: u16, sample_rate: u32) -> WavSpec {
        WavSpec {
            channels,
            sample_rate,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        }
    }

    #[test]
    fn read_wav_roundtrip() {
        let path = temp_wav_path("roundtrip.wav");
        let mut writer = WavWriter::create(&path, pcm16_spec(1, 16000)).unwrap();
        for _ in 0..160 {
            writer.write_sample::<i16>(0).unwrap();
        }
        writer.finalize().unwrap();

        let input =
            AudioInput::read_wav(path.to_str().expect("temp path is valid UTF-8")).unwrap();
        // Clean up before asserting so a failed assertion doesn't leak the file.
        std::fs::remove_file(&path).unwrap();

        assert_eq!(input.samples.len(), 160);
        assert_eq!(input.sample_rate, 16000);
        assert_eq!(input.channels, 1);
    }

    #[test]
    fn read_wav_matches_pcm16_full_scale_normalization() {
        let path = temp_wav_path("full_scale.wav");
        let mut writer = WavWriter::create(&path, pcm16_spec(1, 16000)).unwrap();
        writer.write_sample::<i16>(i16::MIN).unwrap();
        writer.write_sample::<i16>(i16::MAX).unwrap();
        writer.finalize().unwrap();

        let input =
            AudioInput::read_wav(path.to_str().expect("temp path is valid UTF-8")).unwrap();
        std::fs::remove_file(&path).unwrap();

        assert_eq!(input.samples, vec![-1.0, 32767.0 / 32768.0]);
    }

    #[test]
    fn from_bytes() {
        let mut buffer: Vec<u8> = Vec::new();
        {
            let mut writer = WavWriter::new(Cursor::new(&mut buffer), pcm16_spec(1, 8000)).unwrap();
            for _ in 0..80 {
                writer.write_sample::<i16>(0).unwrap();
            }
            writer.finalize().unwrap();
        }
        let input = AudioInput::from_bytes(&buffer).unwrap();
        assert_eq!(input.samples.len(), 80);
        assert_eq!(input.sample_rate, 8000);
    }

    #[test]
    fn test_normalize() {
        let mut input = AudioInput {
            samples: vec![0.2, -0.5, 0.8, -1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.normalize();
        let max = input.samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        assert!((max - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_remove_dc_offset() {
        let mut input = AudioInput {
            samples: vec![1.0, 1.0, 1.0, 1.0],
            sample_rate: 16000,
            channels: 1,
        };
        input.remove_dc_offset();
        for s in input.samples {
            assert!(s.abs() < 1e-6);
        }
    }

    #[test]
    fn to_mono_averages_interleaved_frames() {
        let input = AudioInput {
            samples: vec![0.5, 1.5, -1.0, 1.0],
            sample_rate: 16000,
            channels: 2,
        };
        assert_eq!(input.to_mono(), vec![1.0, 0.0]);
    }

    #[test]
    fn to_mono_passes_through_single_channel() {
        let input = AudioInput {
            samples: vec![0.1, -0.2, 0.3],
            sample_rate: 16000,
            channels: 1,
        };
        assert_eq!(input.to_mono(), input.samples);
    }

    #[test]
    fn apply_fade_in_ramps_from_zero() {
        let mut input = AudioInput {
            samples: vec![1.0; 4],
            sample_rate: 16000,
            channels: 1,
        };
        input.apply_fade(2, 0);
        assert_eq!(input.samples, vec![0.0, 0.5, 1.0, 1.0]);
    }
}