1use std::io::{Read, Seek};
4
5mod decode;
6mod denoise;
7
8use decode::{decode_audio_bytes, decode_audio_stream, decode_wav_bytes};
9use denoise::denoise_audio_samples;
10pub use denoise::DenoiseOptions;
11
/// Decoded audio held in memory as 32-bit floating-point samples.
#[derive(Debug, Clone)]
pub struct AudioSamples {
    /// Raw sample values. `to_i16` clamps each value to `[-1.0, 1.0]` before quantizing,
    /// so that is the nominal range.
    /// NOTE(review): presumably interleaved per frame when `channels > 1`, since `get_wav`
    /// writes the whole vector under a header that advertises `channels` — confirm against
    /// the decoders.
    pub samples: Vec<f32>,
    /// Sampling rate in samples per second; `duration_secs` divides by it to obtain seconds.
    pub sample_rate: u32,
    /// Channel count advertised in the WAV header. `AudioSamples::new` always sets it to 1.
    pub channels: u16,
}
22
23impl AudioSamples {
24 pub fn new(samples: Vec<f32>, sample_rate: u32) -> Self {
26 Self {
27 samples,
28 sample_rate,
29 channels: 1,
30 }
31 }
32
33 pub fn duration_secs(&self) -> f32 {
35 if self.sample_rate == 0 {
36 return 0.0;
37 }
38 self.samples.len() as f32 / self.sample_rate as f32
39 }
40
41 pub fn len(&self) -> usize {
43 self.samples.len()
44 }
45
46 pub fn is_empty(&self) -> bool {
48 self.samples.is_empty()
49 }
50
51 pub fn from_wav_bytes(bytes: &[u8]) -> Result<Self, crate::TtsError> {
57 decode_wav_bytes(bytes)
58 }
59
60 pub fn from_wav_file(path: impl AsRef<std::path::Path>) -> Result<Self, crate::TtsError> {
62 let data = std::fs::read(path)?;
63 Self::from_wav_bytes(&data)
64 }
65
66 pub fn from_audio_stream<R>(stream: R) -> Result<Self, crate::TtsError>
71 where
72 R: Read + Seek + Send + Sync + 'static,
73 {
74 decode_audio_stream(stream)
75 }
76
77 pub fn from_audio_bytes(bytes: &[u8]) -> Result<Self, crate::TtsError> {
79 decode_audio_bytes(bytes)
80 }
81
82 pub fn from_audio_file(path: impl AsRef<std::path::Path>) -> Result<Self, crate::TtsError> {
84 let data = std::fs::read(path)?;
85 Self::from_audio_bytes(&data)
86 }
87
88 pub fn denoise_audio_stream<R>(
90 stream: R,
91 options: DenoiseOptions,
92 ) -> Result<Self, crate::TtsError>
93 where
94 R: Read + Seek + Send + Sync + 'static,
95 {
96 Ok(Self::from_audio_stream(stream)?.denoise_speech(options))
97 }
98
99 pub fn denoise_audio_bytes(
101 bytes: &[u8],
102 options: DenoiseOptions,
103 ) -> Result<Self, crate::TtsError> {
104 Ok(Self::from_audio_bytes(bytes)?.denoise_speech(options))
105 }
106
107 pub fn denoise_speech(&self, options: DenoiseOptions) -> Self {
112 denoise_audio_samples(self, options)
113 }
114
115 pub fn to_i16(&self) -> Vec<i16> {
117 self.samples
118 .iter()
119 .map(|&s| {
120 let clamped = s.clamp(-1.0, 1.0);
121 (clamped * i16::MAX as f32) as i16
122 })
123 .collect()
124 }
125
126 pub fn get_wav(&self) -> Vec<u8> {
135 let pcm = self.to_i16();
136 let data_len = (pcm.len() * 2) as u32;
137 let file_len = 36 + data_len;
138 let byte_rate = self.sample_rate * self.channels as u32 * 2;
139 let block_align = self.channels * 2;
140
141 let mut buf = Vec::with_capacity(44 + data_len as usize);
143
144 buf.extend_from_slice(b"RIFF");
146 buf.extend_from_slice(&file_len.to_le_bytes());
147 buf.extend_from_slice(b"WAVE");
148
149 buf.extend_from_slice(b"fmt ");
151 buf.extend_from_slice(&16u32.to_le_bytes()); buf.extend_from_slice(&1u16.to_le_bytes()); buf.extend_from_slice(&self.channels.to_le_bytes());
154 buf.extend_from_slice(&self.sample_rate.to_le_bytes());
155 buf.extend_from_slice(&byte_rate.to_le_bytes());
156 buf.extend_from_slice(&block_align.to_le_bytes());
157 buf.extend_from_slice(&16u16.to_le_bytes()); buf.extend_from_slice(b"data");
161 buf.extend_from_slice(&data_len.to_le_bytes());
162 for sample in &pcm {
163 buf.extend_from_slice(&sample.to_le_bytes());
164 }
165
166 buf
167 }
168
169 pub fn save_wav(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
173 let path = path.as_ref();
174 if let Some(parent) = path.parent() {
175 std::fs::create_dir_all(parent)?;
176 }
177 std::fs::write(path, self.get_wav())
178 }
179}
180
181#[cfg(test)]
182mod tests {
183 use super::*;
184 use base64::Engine;
185 use std::f32::consts::PI;
186 use std::io::Cursor;
187
188 const MP3_FIXTURE_BASE64: &str = "SUQzBAAAAAAAIlRTU0UAAAAOAAADTGF2ZjYxLjcuMTAwAAAAAAAAAAAAAAD/86TEAAWgBuJhQQABkQMiLzhh4efgHh55+AAAGe2Hh5//gIUXgBEZiAUDAUCAQBgSX26mpsGnjrCGMjzZfQYxMiAUGUJo1P2AU0EJBOfw5QWoJ8O3/EZC6juGGGG/8xLpImReL34lCQNN/KhIGgsCytAADosjVUE6KILUMMPJqgsA7cE8gj4UH7g5CxDs06FQV0J0jeIMIP3a+DMVJWtq2CCJ0AAOX8IbfFCA0OI7z4wAwAsGoGkFgIBAAAbZsDnojFGdFf5b8hEmQjmET/5fiiAS0Liv8CAXkcbcaSq/5aFvVafhTsOf/53jm8iAqSPZAPA/+TSQTEFNRTMuMTAwVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8zTE/BKA8rr5mmkAVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8yTE7QSAPuMBzQAAVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8xTE/gKwPt6AAFgFVVVVVVVVVVVVVVX/8xTE/gKoQtqAA1gIVVVVVVVVVVVVVVX/8xTE/gJQQuaAApgIVVVVVVVVVVVVVVX/8xTE/wP4OuMBSwABVVVVVVVVVVVVVVX/8zTE+hHAsssZmnkAVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVX/8xTE7gAAA0gBwAAAVVVVVVVVVVVVVVU=";
189
190 #[test]
191 fn test_duration_calculation() {
192 let audio = AudioSamples::new(vec![0.0; 24000], 24000);
193 assert!((audio.duration_secs() - 1.0).abs() < f32::EPSILON);
194 }
195
196 #[test]
197 fn test_duration_zero_sample_rate() {
198 let audio = AudioSamples {
199 samples: vec![0.0; 100],
200 sample_rate: 0,
201 channels: 1,
202 };
203 assert!((audio.duration_secs()).abs() < f32::EPSILON);
204 }
205
206 #[test]
207 fn test_to_i16_conversion() {
208 let audio = AudioSamples::new(vec![0.0, 1.0, -1.0, 0.5], 24000);
209 let pcm = audio.to_i16();
210 assert_eq!(pcm[0], 0);
211 assert_eq!(pcm[1], i16::MAX);
212 assert_eq!(pcm[2], -i16::MAX);
213 assert_eq!(pcm[3], 16383);
215 }
216
217 #[test]
218 fn test_to_i16_clamping() {
219 let audio = AudioSamples::new(vec![2.0, -2.0], 24000);
220 let pcm = audio.to_i16();
221 assert_eq!(pcm[0], i16::MAX);
222 assert_eq!(pcm[1], -i16::MAX);
223 }
224
225 #[test]
226 fn test_empty_audio() {
227 let audio = AudioSamples::new(vec![], 24000);
228 assert!(audio.is_empty());
229 assert_eq!(audio.len(), 0);
230 assert!((audio.duration_secs()).abs() < f32::EPSILON);
231 }
232
233 #[test]
234 fn test_wav_roundtrip() {
235 let original = AudioSamples::new(vec![0.0, 0.25, -0.25, 1.0, -1.0], 24000);
236 let decoded = AudioSamples::from_wav_bytes(&original.get_wav()).unwrap();
237
238 assert_eq!(decoded.sample_rate, 24000);
239 assert_eq!(decoded.channels, 1);
240 assert_eq!(decoded.samples.len(), original.samples.len());
241 assert!((decoded.samples[1] - original.samples[1]).abs() < 1e-3);
242 assert!((decoded.samples[2] - original.samples[2]).abs() < 1e-3);
243 }
244
245 #[test]
246 fn test_invalid_wav_rejected() {
247 let err = AudioSamples::from_wav_bytes(b"not a wav").unwrap_err();
248 assert!(err.to_string().contains("Invalid WAV header"));
249 }
250
251 #[test]
252 fn test_from_audio_bytes_auto_detects_wav() {
253 let original = AudioSamples::new(vec![0.0, 0.2, -0.2, 0.5, -0.5], 16_000);
254 let decoded = AudioSamples::from_audio_bytes(&original.get_wav()).unwrap();
255
256 assert_eq!(decoded.sample_rate, original.sample_rate);
257 assert_eq!(decoded.channels, 1);
258 assert_eq!(decoded.samples.len(), original.samples.len());
259 }
260
261 #[test]
262 fn test_denoise_audio_stream_decodes_wav() {
263 let original = AudioSamples::new(synthetic_voice_like_signal(16_000, 1.0), 16_000);
264 let cleaned = AudioSamples::denoise_audio_stream(
265 Cursor::new(original.get_wav()),
266 DenoiseOptions::default(),
267 )
268 .unwrap();
269
270 assert_eq!(cleaned.sample_rate, original.sample_rate);
271 assert_eq!(cleaned.channels, 1);
272 assert_eq!(cleaned.samples.len(), original.samples.len());
273 }
274
275 #[test]
276 fn test_denoise_speech_improves_snr_on_synthetic_mix() {
277 let sample_rate = 16_000;
278 let clean = synthetic_voice_like_signal(sample_rate, 2.0);
279 let noisy = mix_background_music(&clean, sample_rate);
280 let audio = AudioSamples::new(noisy.clone(), sample_rate);
281 let reference = AudioSamples::new(clean, sample_rate).denoise_speech(DenoiseOptions {
282 noise_reduction: 0.0,
283 residual_floor: 1.0,
284 wet_mix: 1.0,
285 ..DenoiseOptions::default()
286 });
287 let band_limited_noisy =
288 AudioSamples::new(noisy.clone(), sample_rate).denoise_speech(DenoiseOptions {
289 noise_reduction: 0.0,
290 residual_floor: 1.0,
291 wet_mix: 1.0,
292 ..DenoiseOptions::default()
293 });
294 let cleaned = audio.denoise_speech(DenoiseOptions::default());
295
296 let snr_before = snr_db(&reference.samples, &band_limited_noisy.samples);
297 let snr_after = snr_db(&reference.samples, &cleaned.samples);
298
299 assert!(
300 snr_after > snr_before + 0.5,
301 "Expected denoiser to improve SNR, before={snr_before:.2} dB after={snr_after:.2} dB"
302 );
303 }
304
305 #[test]
306 fn test_denoise_speech_reduces_quiet_region_noise_floor() {
307 let sample_rate = 16_000;
308 let quiet_prefix_len = sample_rate as usize / 2;
309 let mut clean = vec![0.0; quiet_prefix_len];
310 clean.extend(synthetic_voice_like_signal(sample_rate, 1.5));
311
312 let noisy = mix_background_music(&clean, sample_rate);
313 let baseline =
314 AudioSamples::new(noisy.clone(), sample_rate).denoise_speech(DenoiseOptions {
315 noise_reduction: 0.0,
316 residual_floor: 1.0,
317 wet_mix: 1.0,
318 ..DenoiseOptions::default()
319 });
320 let cleaned =
321 AudioSamples::new(noisy, sample_rate).denoise_speech(DenoiseOptions::default());
322
323 let baseline_quiet_rms = rms(&baseline.samples[..quiet_prefix_len]);
324 let cleaned_quiet_rms = rms(&cleaned.samples[..quiet_prefix_len]);
325 let baseline_speech_rms = rms(&baseline.samples[quiet_prefix_len..]);
326 let cleaned_speech_rms = rms(&cleaned.samples[quiet_prefix_len..]);
327
328 assert!(
329 cleaned_quiet_rms < baseline_quiet_rms * 0.7,
330 "Expected denoiser to lower the quiet-region RMS, before={baseline_quiet_rms:.4} after={cleaned_quiet_rms:.4}"
331 );
332 assert!(
333 cleaned_speech_rms > baseline_speech_rms * 0.45,
334 "Expected denoiser to preserve speech energy, before={baseline_speech_rms:.4} after={cleaned_speech_rms:.4}"
335 );
336 }
337
338 #[test]
339 fn test_from_audio_stream_decodes_mp3() {
340 let mp3 = base64::engine::general_purpose::STANDARD
341 .decode(MP3_FIXTURE_BASE64)
342 .unwrap();
343 let decoded = AudioSamples::from_audio_stream(Cursor::new(mp3)).unwrap();
344
345 assert_eq!(decoded.sample_rate, 24_000);
346 assert!(!decoded.samples.is_empty());
347 assert!(decoded.samples.iter().any(|sample| sample.abs() > 1e-3));
348 }
349
350 fn synthetic_voice_like_signal(sample_rate: u32, duration_secs: f32) -> Vec<f32> {
351 let sample_count = (sample_rate as f32 * duration_secs) as usize;
352 (0..sample_count)
353 .map(|index| {
354 let time = index as f32 / sample_rate as f32;
355 let phrase = (2.0 * PI * 1.15 * time).sin().max(0.0).powf(1.8);
356 let syllable = (2.0 * PI * 2.6 * time).sin().abs().powf(0.8);
357 let clean = 0.45 * (2.0 * PI * 180.0 * time).sin()
358 + 0.25 * (2.0 * PI * 360.0 * time).sin()
359 + 0.08 * (2.0 * PI * 1_200.0 * time).sin();
360 clean * phrase * (0.2 + 0.8 * syllable)
361 })
362 .collect()
363 }
364
365 fn mix_background_music(clean: &[f32], sample_rate: u32) -> Vec<f32> {
366 let mut state = 0x1234_5678u32;
367 clean
368 .iter()
369 .enumerate()
370 .map(|(index, &sample)| {
371 let time = index as f32 / sample_rate as f32;
372 state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
373 let pseudo_noise = ((state >> 8) as f32 / (u32::MAX >> 8) as f32) * 2.0 - 1.0;
374 let music = 0.18 * (2.0 * PI * 110.0 * time).sin()
375 + 0.12 * (2.0 * PI * 220.0 * time).sin()
376 + 0.08 * (2.0 * PI * 3_600.0 * time).sin()
377 + 0.04 * pseudo_noise;
378 (sample + music).clamp(-1.0, 1.0)
379 })
380 .collect()
381 }
382
383 fn snr_db(reference: &[f32], observed: &[f32]) -> f32 {
384 let signal_power =
385 reference.iter().map(|sample| sample * sample).sum::<f32>() / reference.len() as f32;
386 let noise_power = reference
387 .iter()
388 .zip(observed)
389 .map(|(reference, observed)| {
390 let error = observed - reference;
391 error * error
392 })
393 .sum::<f32>()
394 / reference.len() as f32;
395
396 10.0 * (signal_power / noise_power.max(1e-9)).log10()
397 }
398 fn rms(samples: &[f32]) -> f32 {
399 if samples.is_empty() {
400 return 0.0;
401 }
402
403 (samples.iter().map(|sample| sample * sample).sum::<f32>() / samples.len() as f32).sqrt()
404 }
405}