Skip to main content

any_tts/
audio.rs

1//! Audio output types and utilities.
2
3use std::io::{Read, Seek};
4
5mod decode;
6mod denoise;
7
8use decode::{decode_audio_bytes, decode_audio_stream, decode_wav_bytes};
9use denoise::denoise_audio_samples;
10pub use denoise::DenoiseOptions;
11
/// Raw audio samples produced by TTS synthesis.
///
/// Samples are stored as `f32` PCM. When `channels` is greater than 1 the
/// samples are expected to be interleaved frame by frame (one sample per
/// channel, in order), which is the layout [`AudioSamples::get_wav`] encodes.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioSamples {
    /// Raw PCM samples as f32, nominally in the range [-1.0, 1.0].
    /// Out-of-range values are clamped when converted to integer PCM.
    pub samples: Vec<f32>,
    /// Sample rate in Hz (e.g. 24000).
    pub sample_rate: u32,
    /// Number of audio channels. [`AudioSamples::new`] and the decoding
    /// constructors produce mono audio, so this is 1 unless set explicitly.
    pub channels: u16,
}
22
23impl AudioSamples {
24    /// Create a new `AudioSamples` instance.
25    pub fn new(samples: Vec<f32>, sample_rate: u32) -> Self {
26        Self {
27            samples,
28            sample_rate,
29            channels: 1,
30        }
31    }
32
33    /// Duration of the audio in seconds.
34    pub fn duration_secs(&self) -> f32 {
35        if self.sample_rate == 0 {
36            return 0.0;
37        }
38        self.samples.len() as f32 / self.sample_rate as f32
39    }
40
41    /// Number of samples.
42    pub fn len(&self) -> usize {
43        self.samples.len()
44    }
45
46    /// Whether the audio is empty.
47    pub fn is_empty(&self) -> bool {
48        self.samples.is_empty()
49    }
50
51    /// Decode a WAV file from bytes.
52    ///
53    /// Supports RIFF/WAVE PCM integer data (8/16/24/32-bit) and 32-bit float.
54    /// Multi-channel audio is downmixed to mono to match the library's output
55    /// convention.
56    pub fn from_wav_bytes(bytes: &[u8]) -> Result<Self, crate::TtsError> {
57        decode_wav_bytes(bytes)
58    }
59
60    /// Decode a WAV file from disk.
61    pub fn from_wav_file(path: impl AsRef<std::path::Path>) -> Result<Self, crate::TtsError> {
62        let data = std::fs::read(path)?;
63        Self::from_wav_bytes(&data)
64    }
65
66    /// Decode a WAV or MP3 stream into mono PCM samples.
67    ///
68    /// The input format is auto-detected. WAV is decoded directly and MP3 is
69    /// decoded with the built-in MP3 decoder.
70    pub fn from_audio_stream<R>(stream: R) -> Result<Self, crate::TtsError>
71    where
72        R: Read + Seek + Send + Sync + 'static,
73    {
74        decode_audio_stream(stream)
75    }
76
77    /// Decode a WAV or MP3 byte buffer into mono PCM samples.
78    pub fn from_audio_bytes(bytes: &[u8]) -> Result<Self, crate::TtsError> {
79        decode_audio_bytes(bytes)
80    }
81
82    /// Decode a WAV or MP3 file from disk.
83    pub fn from_audio_file(path: impl AsRef<std::path::Path>) -> Result<Self, crate::TtsError> {
84        let data = std::fs::read(path)?;
85        Self::from_audio_bytes(&data)
86    }
87
88    /// Decode a WAV or MP3 stream and apply speech-focused denoising.
89    pub fn denoise_audio_stream<R>(
90        stream: R,
91        options: DenoiseOptions,
92    ) -> Result<Self, crate::TtsError>
93    where
94        R: Read + Seek + Send + Sync + 'static,
95    {
96        Ok(Self::from_audio_stream(stream)?.denoise_speech(options))
97    }
98
99    /// Decode a WAV or MP3 byte buffer and apply speech-focused denoising.
100    pub fn denoise_audio_bytes(
101        bytes: &[u8],
102        options: DenoiseOptions,
103    ) -> Result<Self, crate::TtsError> {
104        Ok(Self::from_audio_bytes(bytes)?.denoise_speech(options))
105    }
106
107    /// Apply speech-focused denoising to the current audio samples.
108    ///
109    /// This is a classical DSP pass, not a learned source-separation model.
110    /// It works best on mono spoken audio with steady background noise or music.
111    pub fn denoise_speech(&self, options: DenoiseOptions) -> Self {
112        denoise_audio_samples(self, options)
113    }
114
115    /// Convert samples to i16 PCM (for WAV output).
116    pub fn to_i16(&self) -> Vec<i16> {
117        self.samples
118            .iter()
119            .map(|&s| {
120                let clamped = s.clamp(-1.0, 1.0);
121                (clamped * i16::MAX as f32) as i16
122            })
123            .collect()
124    }
125
126    // -----------------------------------------------------------------------
127    // WAV
128    // -----------------------------------------------------------------------
129
130    /// Encode the audio as a 16-bit PCM WAV and return the raw bytes.
131    ///
132    /// The returned `Vec<u8>` contains a complete RIFF WAV file that can be
133    /// written to disk, sent over the network, or played back directly.
134    pub fn get_wav(&self) -> Vec<u8> {
135        let pcm = self.to_i16();
136        let data_len = (pcm.len() * 2) as u32;
137        let file_len = 36 + data_len;
138        let byte_rate = self.sample_rate * self.channels as u32 * 2;
139        let block_align = self.channels * 2;
140
141        // Pre-allocate: 44 bytes header + PCM data
142        let mut buf = Vec::with_capacity(44 + data_len as usize);
143
144        // RIFF header
145        buf.extend_from_slice(b"RIFF");
146        buf.extend_from_slice(&file_len.to_le_bytes());
147        buf.extend_from_slice(b"WAVE");
148
149        // fmt chunk
150        buf.extend_from_slice(b"fmt ");
151        buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
152        buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
153        buf.extend_from_slice(&self.channels.to_le_bytes());
154        buf.extend_from_slice(&self.sample_rate.to_le_bytes());
155        buf.extend_from_slice(&byte_rate.to_le_bytes());
156        buf.extend_from_slice(&block_align.to_le_bytes());
157        buf.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
158
159        // data chunk
160        buf.extend_from_slice(b"data");
161        buf.extend_from_slice(&data_len.to_le_bytes());
162        for sample in &pcm {
163            buf.extend_from_slice(&sample.to_le_bytes());
164        }
165
166        buf
167    }
168
169    /// Save the audio to a WAV file (16-bit PCM).
170    ///
171    /// Creates parent directories automatically.
172    pub fn save_wav(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
173        let path = path.as_ref();
174        if let Some(parent) = path.parent() {
175            std::fs::create_dir_all(parent)?;
176        }
177        std::fs::write(path, self.get_wav())
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use std::f32::consts::PI;
    use std::io::Cursor;

    /// Tiny MP3 clip embedded as base64 so the MP3 test needs no fixture file.
    /// `test_from_audio_stream_decodes_mp3` expects it to decode at 24 kHz.
    const MP3_FIXTURE_BASE64: &str = "SUQzBAAAAAAAIlRTU0UAAAAOAAADTGF2ZjYxLjcuMTAwAAAAAAAAAAAAAAD/86TEAAWgBuJhQQABkQMiLzhh4efgHh55+AAAGe2Hh5//gIUXgBEZiAUDAUCAQBgSX26mpsGnjrCGMjzZfQYxMiAUGUJo1P2AU0EJBOfw5QWoJ8O3/EZC6juGGGG/8xLpImReL34lCQNN/KhIGgsCytAADosjVUE6KILUMMPJqgsA7cE8gj4UH7g5CxDs06FQV0J0jeIMIP3a+DMVJWtq2CCJ0AAOX8IbfFCA0OI7z4wAwAsGoGkFgIBAAAbZsDnojFGdFf5b8hEmQjmET/5fiiAS0Liv8CAXkcbcaSq/5aFvVafhTsOf/53jm8iAqSPZAPA/+TSQTEFNRTMuMTAwVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8zTE/BKA8rr5mmkAVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8yTE7QSAPuMBzQAAVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8xTE/gKwPt6AAFgFVVVVVVVVVVVVVVX/8xTE/gKoQtqAA1gIVVVVVVVVVVVVVVX/8xTE/gJQQuaAApgIVVVVVVVVVVVVVVX/8xTE/wP4OuMBSwABVVVVVVVVVVVVVVX/8zTE+hHAsssZmnkAVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8xTE7gAAA0gBwAAAVVVVVVVVVVVVVVU=";

    #[test]
    fn test_duration_calculation() {
        let audio = AudioSamples::new(vec![0.0; 24000], 24000);
        assert!((audio.duration_secs() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_duration_zero_sample_rate() {
        let audio = AudioSamples {
            samples: vec![0.0; 100],
            sample_rate: 0,
            channels: 1,
        };
        assert!((audio.duration_secs()).abs() < f32::EPSILON);
    }

    #[test]
    fn test_to_i16_conversion() {
        let audio = AudioSamples::new(vec![0.0, 1.0, -1.0, 0.5], 24000);
        let pcm = audio.to_i16();
        assert_eq!(pcm[0], 0);
        assert_eq!(pcm[1], i16::MAX);
        assert_eq!(pcm[2], -i16::MAX);
        // 0.5 * 32767 = 16383 (truncated)
        assert_eq!(pcm[3], 16383);
    }

    #[test]
    fn test_to_i16_clamping() {
        let audio = AudioSamples::new(vec![2.0, -2.0], 24000);
        let pcm = audio.to_i16();
        assert_eq!(pcm[0], i16::MAX);
        assert_eq!(pcm[1], -i16::MAX);
    }

    #[test]
    fn test_empty_audio() {
        let audio = AudioSamples::new(vec![], 24000);
        assert!(audio.is_empty());
        assert_eq!(audio.len(), 0);
        assert!((audio.duration_secs()).abs() < f32::EPSILON);
    }

    #[test]
    fn test_wav_roundtrip() {
        let original = AudioSamples::new(vec![0.0, 0.25, -0.25, 1.0, -1.0], 24000);
        let decoded = AudioSamples::from_wav_bytes(&original.get_wav()).unwrap();

        assert_eq!(decoded.sample_rate, 24000);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.samples.len(), original.samples.len());
        // 16-bit quantization error is ~3e-5, well inside this tolerance, so
        // every sample (including the full-scale endpoints) must survive.
        for (index, (got, want)) in decoded.samples.iter().zip(&original.samples).enumerate() {
            assert!(
                (got - want).abs() < 1e-3,
                "sample {index}: decoded {got} vs original {want}"
            );
        }
    }

    #[test]
    fn test_get_wav_header_layout() {
        let audio = AudioSamples::new(vec![0.0, 0.5, -0.5], 22_050);
        let wav = audio.get_wav();
        let u16_at = |offset: usize| u16::from_le_bytes([wav[offset], wav[offset + 1]]);
        let u32_at = |offset: usize| {
            u32::from_le_bytes([wav[offset], wav[offset + 1], wav[offset + 2], wav[offset + 3]])
        };

        assert_eq!(wav.len(), 44 + 3 * 2);
        assert_eq!(&wav[0..4], b"RIFF");
        assert_eq!(u32_at(4), 36 + 3 * 2);
        assert_eq!(&wav[8..12], b"WAVE");
        assert_eq!(&wav[12..16], b"fmt ");
        assert_eq!(u32_at(16), 16); // fmt chunk size
        assert_eq!(u16_at(20), 1); // PCM
        assert_eq!(u16_at(22), 1); // channels
        assert_eq!(u32_at(24), 22_050); // sample rate
        assert_eq!(u32_at(28), 44_100); // byte rate
        assert_eq!(u16_at(32), 2); // block align
        assert_eq!(u16_at(34), 16); // bits per sample
        assert_eq!(&wav[36..40], b"data");
        assert_eq!(u32_at(40), 3 * 2);
        assert_eq!(i16::from_le_bytes([wav[44], wav[45]]), 0);
        assert_eq!(i16::from_le_bytes([wav[46], wav[47]]), 16383);
        assert_eq!(i16::from_le_bytes([wav[48], wav[49]]), -16383);
    }

    #[test]
    fn test_invalid_wav_rejected() {
        let err = AudioSamples::from_wav_bytes(b"not a wav").unwrap_err();
        assert!(err.to_string().contains("Invalid WAV header"));
    }

    #[test]
    fn test_from_audio_bytes_auto_detects_wav() {
        let original = AudioSamples::new(vec![0.0, 0.2, -0.2, 0.5, -0.5], 16_000);
        let decoded = AudioSamples::from_audio_bytes(&original.get_wav()).unwrap();

        assert_eq!(decoded.sample_rate, original.sample_rate);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.samples.len(), original.samples.len());
    }

    #[test]
    fn test_denoise_audio_stream_decodes_wav() {
        let original = AudioSamples::new(synthetic_voice_like_signal(16_000, 1.0), 16_000);
        let cleaned = AudioSamples::denoise_audio_stream(
            Cursor::new(original.get_wav()),
            DenoiseOptions::default(),
        )
        .unwrap();

        assert_eq!(cleaned.sample_rate, original.sample_rate);
        assert_eq!(cleaned.channels, 1);
        assert_eq!(cleaned.samples.len(), original.samples.len());
    }

    #[test]
    fn test_denoise_speech_improves_snr_on_synthetic_mix() {
        let sample_rate = 16_000;
        let clean = synthetic_voice_like_signal(sample_rate, 2.0);
        let noisy = mix_background_music(&clean, sample_rate);
        let audio = AudioSamples::new(noisy.clone(), sample_rate);
        // With reduction disabled the denoiser only band-limits, giving a
        // reference that differs from the input solely by that filtering.
        let reference = AudioSamples::new(clean, sample_rate).denoise_speech(DenoiseOptions {
            noise_reduction: 0.0,
            residual_floor: 1.0,
            wet_mix: 1.0,
            ..DenoiseOptions::default()
        });
        let band_limited_noisy =
            AudioSamples::new(noisy, sample_rate).denoise_speech(DenoiseOptions {
                noise_reduction: 0.0,
                residual_floor: 1.0,
                wet_mix: 1.0,
                ..DenoiseOptions::default()
            });
        let cleaned = audio.denoise_speech(DenoiseOptions::default());

        let snr_before = snr_db(&reference.samples, &band_limited_noisy.samples);
        let snr_after = snr_db(&reference.samples, &cleaned.samples);

        assert!(
            snr_after > snr_before + 0.5,
            "Expected denoiser to improve SNR, before={snr_before:.2} dB after={snr_after:.2} dB"
        );
    }

    #[test]
    fn test_denoise_speech_reduces_quiet_region_noise_floor() {
        let sample_rate = 16_000;
        let quiet_prefix_len = sample_rate as usize / 2;
        let mut clean = vec![0.0; quiet_prefix_len];
        clean.extend(synthetic_voice_like_signal(sample_rate, 1.5));

        let noisy = mix_background_music(&clean, sample_rate);
        let baseline =
            AudioSamples::new(noisy.clone(), sample_rate).denoise_speech(DenoiseOptions {
                noise_reduction: 0.0,
                residual_floor: 1.0,
                wet_mix: 1.0,
                ..DenoiseOptions::default()
            });
        let cleaned =
            AudioSamples::new(noisy, sample_rate).denoise_speech(DenoiseOptions::default());

        let baseline_quiet_rms = rms(&baseline.samples[..quiet_prefix_len]);
        let cleaned_quiet_rms = rms(&cleaned.samples[..quiet_prefix_len]);
        let baseline_speech_rms = rms(&baseline.samples[quiet_prefix_len..]);
        let cleaned_speech_rms = rms(&cleaned.samples[quiet_prefix_len..]);

        assert!(
            cleaned_quiet_rms < baseline_quiet_rms * 0.7,
            "Expected denoiser to lower the quiet-region RMS, before={baseline_quiet_rms:.4} after={cleaned_quiet_rms:.4}"
        );
        assert!(
            cleaned_speech_rms > baseline_speech_rms * 0.45,
            "Expected denoiser to preserve speech energy, before={baseline_speech_rms:.4} after={cleaned_speech_rms:.4}"
        );
    }

    #[test]
    fn test_from_audio_stream_decodes_mp3() {
        let mp3 = base64::engine::general_purpose::STANDARD
            .decode(MP3_FIXTURE_BASE64)
            .unwrap();
        let decoded = AudioSamples::from_audio_stream(Cursor::new(mp3)).unwrap();

        assert_eq!(decoded.sample_rate, 24_000);
        assert!(!decoded.samples.is_empty());
        assert!(decoded.samples.iter().any(|sample| sample.abs() > 1e-3));
    }

    /// Deterministic speech-like test tone: 180/360/1200 Hz partials shaped by
    /// slow "phrase" and "syllable" amplitude envelopes.
    fn synthetic_voice_like_signal(sample_rate: u32, duration_secs: f32) -> Vec<f32> {
        let sample_count = (sample_rate as f32 * duration_secs) as usize;
        (0..sample_count)
            .map(|index| {
                let time = index as f32 / sample_rate as f32;
                let phrase = (2.0 * PI * 1.15 * time).sin().max(0.0).powf(1.8);
                let syllable = (2.0 * PI * 2.6 * time).sin().abs().powf(0.8);
                let clean = 0.45 * (2.0 * PI * 180.0 * time).sin()
                    + 0.25 * (2.0 * PI * 360.0 * time).sin()
                    + 0.08 * (2.0 * PI * 1_200.0 * time).sin();
                clean * phrase * (0.2 + 0.8 * syllable)
            })
            .collect()
    }

    /// Add steady tonal "music" plus LCG-generated noise to `clean`, clamping
    /// the mix to [-1.0, 1.0]. The fixed seed keeps the output deterministic.
    fn mix_background_music(clean: &[f32], sample_rate: u32) -> Vec<f32> {
        let mut state = 0x1234_5678u32;
        clean
            .iter()
            .enumerate()
            .map(|(index, &sample)| {
                let time = index as f32 / sample_rate as f32;
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                let pseudo_noise = ((state >> 8) as f32 / (u32::MAX >> 8) as f32) * 2.0 - 1.0;
                let music = 0.18 * (2.0 * PI * 110.0 * time).sin()
                    + 0.12 * (2.0 * PI * 220.0 * time).sin()
                    + 0.08 * (2.0 * PI * 3_600.0 * time).sin()
                    + 0.04 * pseudo_noise;
                (sample + music).clamp(-1.0, 1.0)
            })
            .collect()
    }

    /// Signal-to-noise ratio in dB of `observed` against `reference`, where the
    /// noise is the sample-wise difference. Noise power is floored at 1e-9.
    fn snr_db(reference: &[f32], observed: &[f32]) -> f32 {
        let signal_power =
            reference.iter().map(|sample| sample * sample).sum::<f32>() / reference.len() as f32;
        let noise_power = reference
            .iter()
            .zip(observed)
            .map(|(reference, observed)| {
                let error = observed - reference;
                error * error
            })
            .sum::<f32>()
            / reference.len() as f32;

        10.0 * (signal_power / noise_power.max(1e-9)).log10()
    }

    /// Root-mean-square level of `samples`; 0.0 for an empty slice.
    fn rms(samples: &[f32]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }

        (samples.iter().map(|sample| sample * sample).sum::<f32>() / samples.len() as f32).sqrt()
    }
}