1use std::io;
6use thiserror::Error;
7
/// Errors reported by the WAV reading, writing, and parsing routines in this module.
#[derive(Debug, Error)]
pub enum WavError {
    /// An underlying I/O operation failed; converted automatically from [`io::Error`].
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// The byte stream is not a well-formed WAV container; the payload names the problem
    /// (for example a missing `RIFF`/`WAVE`/`fmt ` tag or a missing `data` chunk).
    #[error("invalid WAV header: {0}")]
    InvalidHeader(String),
    /// The container was recognised but uses an encoding the decoder rejects
    /// (anything other than format tag 1 with 16 bits per sample).
    #[error("unsupported format: {0}")]
    UnsupportedFormat(String),
}
17
/// Format parameters that accompany a block of PCM samples.
#[derive(Debug, Clone)]
pub struct WavHeader {
    /// Samples per second.
    pub sample_rate: u32,
    /// Channel count.
    pub channels: u16,
    /// Width of one sample in bits.
    pub bits_per_sample: u16,
    /// Number of samples described by this header; both constructors below set it to zero.
    pub num_samples: usize,
}

impl WavHeader {
    /// Preset for 8 kHz, single-channel, 16-bit audio.
    pub fn telephony() -> Self {
        Self::mono(8000)
    }

    /// Single-channel, 16-bit header at the requested `sample_rate`, with no samples counted yet.
    pub fn mono(sample_rate: u32) -> Self {
        let (channels, bits_per_sample, num_samples) = (1, 16, 0);
        Self {
            sample_rate,
            channels,
            bits_per_sample,
            num_samples,
        }
    }
}
48
49pub fn encode_wav(samples: &[i16], header: &WavHeader) -> Vec<u8> {
51 let data_size = samples.len() * 2; let byte_rate = header.sample_rate * header.channels as u32 * (header.bits_per_sample as u32 / 8);
53 let block_align = header.channels * (header.bits_per_sample / 8);
54
55 let mut buf = Vec::with_capacity(44 + data_size);
56
57 buf.extend_from_slice(b"RIFF");
59 buf.extend_from_slice(&((36 + data_size) as u32).to_le_bytes());
60 buf.extend_from_slice(b"WAVE");
61
62 buf.extend_from_slice(b"fmt ");
64 buf.extend_from_slice(&16u32.to_le_bytes()); buf.extend_from_slice(&1u16.to_le_bytes()); buf.extend_from_slice(&header.channels.to_le_bytes());
67 buf.extend_from_slice(&header.sample_rate.to_le_bytes());
68 buf.extend_from_slice(&byte_rate.to_le_bytes());
69 buf.extend_from_slice(&block_align.to_le_bytes());
70 buf.extend_from_slice(&header.bits_per_sample.to_le_bytes());
71
72 buf.extend_from_slice(b"data");
74 buf.extend_from_slice(&(data_size as u32).to_le_bytes());
75 for &sample in samples {
76 buf.extend_from_slice(&sample.to_le_bytes());
77 }
78
79 buf
80}
81
82pub fn write_wav(path: &str, samples: &[i16], header: &WavHeader) -> Result<(), WavError> {
84 let data = encode_wav(samples, header);
85 std::fs::write(path, data)?;
86 Ok(())
87}
88
89pub fn decode_wav(data: &[u8]) -> Result<(WavHeader, Vec<i16>), WavError> {
91 if data.len() < 44 {
92 return Err(WavError::InvalidHeader("too short".to_string()));
93 }
94
95 if &data[0..4] != b"RIFF" {
97 return Err(WavError::InvalidHeader("missing RIFF".to_string()));
98 }
99 if &data[8..12] != b"WAVE" {
100 return Err(WavError::InvalidHeader("missing WAVE".to_string()));
101 }
102
103 if &data[12..16] != b"fmt " {
105 return Err(WavError::InvalidHeader("missing fmt chunk".to_string()));
106 }
107
108 let format = u16::from_le_bytes([data[20], data[21]]);
109 if format != 1 {
110 return Err(WavError::UnsupportedFormat(format!(
111 "not PCM (format={})",
112 format
113 )));
114 }
115
116 let channels = u16::from_le_bytes([data[22], data[23]]);
117 let sample_rate = u32::from_le_bytes([data[24], data[25], data[26], data[27]]);
118 let bits_per_sample = u16::from_le_bytes([data[34], data[35]]);
119
120 if bits_per_sample != 16 {
121 return Err(WavError::UnsupportedFormat(format!(
122 "not 16-bit (bits={})",
123 bits_per_sample
124 )));
125 }
126
127 let mut pos = 12;
129 loop {
130 if pos + 8 > data.len() {
131 return Err(WavError::InvalidHeader("missing data chunk".to_string()));
132 }
133 let chunk_id = &data[pos..pos + 4];
134 let chunk_size = u32::from_le_bytes([
135 data[pos + 4],
136 data[pos + 5],
137 data[pos + 6],
138 data[pos + 7],
139 ]) as usize;
140
141 if chunk_id == b"data" {
142 let sample_data = &data[pos + 8..pos + 8 + chunk_size.min(data.len() - pos - 8)];
143 let samples: Vec<i16> = sample_data
144 .chunks_exact(2)
145 .map(|c| i16::from_le_bytes([c[0], c[1]]))
146 .collect();
147
148 let header = WavHeader {
149 sample_rate,
150 channels,
151 bits_per_sample,
152 num_samples: samples.len(),
153 };
154
155 return Ok((header, samples));
156 }
157
158 pos += 8 + chunk_size;
159 if chunk_size % 2 != 0 {
161 pos += 1;
162 }
163 }
164}
165
166pub fn read_wav(path: &str) -> Result<(WavHeader, Vec<i16>), WavError> {
168 let data = std::fs::read(path)?;
169 decode_wav(&data)
170}
171
172#[derive(Debug, Clone)]
174pub struct AudioRecorder {
175 samples: Vec<i16>,
176 sample_rate: u32,
177}
178
179impl AudioRecorder {
180 pub fn new(sample_rate: u32) -> Self {
181 Self {
182 samples: Vec::new(),
183 sample_rate,
184 }
185 }
186
187 pub fn record_frame(&mut self, frame: &[i16]) {
189 self.samples.extend_from_slice(frame);
190 }
191
192 pub fn samples(&self) -> &[i16] {
194 &self.samples
195 }
196
197 pub fn duration_ms(&self) -> u64 {
199 (self.samples.len() as u64 * 1000) / self.sample_rate as u64
200 }
201
202 pub fn frame_count(&self) -> usize {
204 let samples_per_frame = (self.sample_rate as usize * 20) / 1000;
205 if samples_per_frame == 0 {
206 return 0;
207 }
208 self.samples.len() / samples_per_frame
209 }
210
211 pub fn to_wav(&self) -> Vec<u8> {
213 let header = WavHeader {
214 sample_rate: self.sample_rate,
215 channels: 1,
216 bits_per_sample: 16,
217 num_samples: self.samples.len(),
218 };
219 encode_wav(&self.samples, &header)
220 }
221
222 pub fn save_wav(&self, path: &str) -> Result<(), WavError> {
224 let header = WavHeader::mono(self.sample_rate);
225 write_wav(path, &self.samples, &header)
226 }
227
228 pub fn clear(&mut self) {
230 self.samples.clear();
231 }
232
233 pub fn is_empty(&self) -> bool {
235 self.samples.is_empty()
236 }
237
238 pub fn len(&self) -> usize {
240 self.samples.len()
241 }
242}
243
/// Synthesises a sine wave of `frequency` Hz lasting `duration_ms` milliseconds.
///
/// The result holds `sample_rate * duration_ms / 1000` samples (rounded down); sample `i` is
/// `sin(2π · frequency · i / sample_rate) · amplitude`, converted to `i16`.
pub fn generate_sine_tone(frequency: f64, sample_rate: u32, duration_ms: u32, amplitude: i16) -> Vec<i16> {
    let total = (sample_rate as u64 * duration_ms as u64 / 1000) as usize;
    let rate = sample_rate as f64;
    let peak = amplitude as f64;
    // Angular frequency in radians per second.
    let omega = 2.0 * std::f64::consts::PI * frequency;

    let mut tone = Vec::with_capacity(total);
    for i in 0..total {
        let t = i as f64 / rate;
        tone.push(((omega * t).sin() * peak) as i16);
    }
    tone
}
254
/// Synthesises the average of several sine waves, one per entry in `frequencies`.
///
/// The result holds `sample_rate * duration_ms / 1000` samples (rounded down). Each sample is
/// the sum of the individual sines divided by the number of frequencies, then scaled by
/// `amplitude` and converted to `i16`.
pub fn generate_multi_tone(
    frequencies: &[f64],
    sample_rate: u32,
    duration_ms: u32,
    amplitude: i16,
) -> Vec<i16> {
    let total = (sample_rate as u64 * duration_ms as u64 / 1000) as usize;
    let tone_count = frequencies.len() as f64;
    let scale = 1.0 / tone_count;
    let two_pi = 2.0 * std::f64::consts::PI;
    let peak = amplitude as f64;

    let mut mixed = Vec::with_capacity(total);
    for i in 0..total {
        let t = i as f64 / sample_rate as f64;
        let sum: f64 = frequencies
            .iter()
            .map(|&freq| (two_pi * freq * t).sin())
            .sum();
        mixed.push((sum * scale * peak) as i16);
    }
    mixed
}
275
/// Signal-to-noise ratio, in decibels, of `received` measured against `original`.
///
/// Only the overlapping prefix of the two slices is compared. Signal power is the sum of
/// squared `original` samples; noise power is the sum of squared differences.
///
/// Returns `0.0` when either slice is empty and `100.0` when the accumulated noise power is
/// below 1 (the signals are effectively identical).
pub fn compute_snr(original: &[i16], received: &[i16]) -> f64 {
    if original.is_empty() || received.is_empty() {
        return 0.0;
    }

    let (signal_power, noise_power) = original.iter().zip(received).fold(
        (0.0f64, 0.0f64),
        |(signal, noise), (&reference, &observed)| {
            let s = reference as f64;
            let n = s - observed as f64;
            (signal + s * s, noise + n * n)
        },
    );

    if noise_power < 1.0 {
        // Avoid dividing by (nearly) zero; report a fixed ceiling instead.
        100.0
    } else {
        10.0 * (signal_power / noise_power).log10()
    }
}
300
/// Normalised zero-lag cross-correlation of `a` and `b`: `Σab / sqrt(Σa² · Σb²)`.
///
/// Only the overlapping prefix of the two slices is used. Returns `0.0` when either slice is
/// empty or when the normalising denominator is below 1 (e.g. one signal is all zeros).
pub fn cross_correlation(a: &[i16], b: &[i16]) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0f64;
    let mut energy_a = 0.0f64;
    let mut energy_b = 0.0f64;
    for (&x, &y) in a.iter().zip(b) {
        let (va, vb) = (x as f64, y as f64);
        dot += va * vb;
        energy_a += va * va;
        energy_b += vb * vb;
    }

    let denom = (energy_a * energy_b).sqrt();
    if denom < 1.0 {
        0.0
    } else {
        dot / denom
    }
}
328
/// Largest absolute per-sample difference between `original` and `received`.
///
/// Only the overlapping prefix of the two slices is compared; the result is `0` when that
/// prefix is empty.
pub fn max_sample_error(original: &[i16], received: &[i16]) -> i32 {
    original
        .iter()
        .zip(received)
        // Widen to i32 first so the subtraction cannot overflow.
        .map(|(&expected, &actual)| (expected as i32 - actual as i32).abs())
        .max()
        .unwrap_or(0)
}
341
/// Root-mean-square of the per-sample differences between `original` and `received`.
///
/// Only the overlapping prefix of the two slices is compared; returns `0.0` when that prefix
/// is empty.
pub fn rms_error(original: &[i16], received: &[i16]) -> f64 {
    let compared = original.len().min(received.len());
    if compared == 0 {
        return 0.0;
    }

    let squared_total: f64 = original
        .iter()
        .zip(received)
        .map(|(&expected, &actual)| {
            let delta = expected as f64 - actual as f64;
            delta * delta
        })
        .sum();

    let mean_square = squared_total / compared as f64;
    mean_square.sqrt()
}
356
/// Unit tests for the WAV codec, the recorder, the tone generators and the signal metrics.
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding in memory returns the original samples and format fields.
    #[test]
    fn test_wav_roundtrip() {
        let samples: Vec<i16> = (0..8000)
            .map(|i| ((i as f64 / 8000.0 * std::f64::consts::TAU * 440.0).sin() * 16000.0) as i16)
            .collect();

        let header = WavHeader::telephony();
        let encoded = encode_wav(&samples, &header);
        let (decoded_header, decoded_samples) = decode_wav(&encoded).unwrap();

        assert_eq!(decoded_header.sample_rate, 8000);
        assert_eq!(decoded_header.channels, 1);
        assert_eq!(decoded_header.bits_per_sample, 16);
        assert_eq!(decoded_samples, samples);
    }

    /// Writing to disk and reading back returns the original samples.
    #[test]
    fn test_wav_file_roundtrip() {
        let samples = generate_sine_tone(440.0, 8000, 100, 16000);
        let header = WavHeader::telephony();

        // Use the platform temp directory and a per-process file name so the test is
        // portable and concurrent test runs do not overwrite each other's file.
        let file = std::env::temp_dir().join(format!(
            "siphone_test_wav_roundtrip_{}.wav",
            std::process::id()
        ));
        let path = file
            .to_str()
            .expect("temporary directory path should be valid UTF-8");

        write_wav(path, &samples, &header).unwrap();
        let (_, read_samples) = read_wav(path).unwrap();
        assert_eq!(read_samples, samples);

        std::fs::remove_file(path).ok();
    }

    /// Inputs that are not WAV data are rejected.
    #[test]
    fn test_wav_invalid() {
        assert!(decode_wav(b"NOT A WAV").is_err());
        assert!(decode_wav(&[0; 10]).is_err());
    }

    /// A generated sine tone has the expected length and peak amplitude.
    #[test]
    fn test_generate_sine_tone() {
        let tone = generate_sine_tone(440.0, 8000, 100, 16000);
        assert_eq!(tone.len(), 800);
        assert!(tone.iter().any(|&s| s != 0));

        let max = tone.iter().map(|s| s.abs()).max().unwrap();
        assert!(max > 15000 && max <= 16000);
    }

    /// A generated multi-tone has the expected length and is not silent.
    #[test]
    fn test_generate_multi_tone() {
        let tone = generate_multi_tone(&[300.0, 500.0, 700.0], 8000, 100, 16000);
        assert_eq!(tone.len(), 800);
        assert!(tone.iter().any(|&s| s != 0));
    }

    /// Recording frames updates length, duration and frame count.
    #[test]
    fn test_audio_recorder() {
        let mut recorder = AudioRecorder::new(8000);
        assert!(recorder.is_empty());
        assert_eq!(recorder.duration_ms(), 0);

        let frame = vec![1000i16; 160];
        recorder.record_frame(&frame);
        assert_eq!(recorder.len(), 160);
        assert_eq!(recorder.duration_ms(), 20);
        assert_eq!(recorder.frame_count(), 1);

        recorder.record_frame(&frame);
        assert_eq!(recorder.len(), 320);
        assert_eq!(recorder.frame_count(), 2);
        assert_eq!(recorder.duration_ms(), 40);
    }

    /// The recorder's WAV export decodes back to the recorded samples.
    #[test]
    fn test_recorder_to_wav() {
        let mut recorder = AudioRecorder::new(8000);
        let tone = generate_sine_tone(440.0, 8000, 100, 16000);
        for frame in tone.chunks(160) {
            recorder.record_frame(frame);
        }

        let wav = recorder.to_wav();
        let (header, samples) = decode_wav(&wav).unwrap();
        assert_eq!(header.sample_rate, 8000);
        assert_eq!(samples, recorder.samples());
    }

    /// Identical signals report the SNR ceiling.
    #[test]
    fn test_compute_snr_identical() {
        let signal = generate_sine_tone(440.0, 8000, 100, 16000);
        let snr = compute_snr(&signal, &signal);
        assert!(snr > 90.0, "SNR for identical signals should be very high, got {}", snr);
    }

    /// Low-level additive noise still yields a reasonable SNR.
    #[test]
    fn test_compute_snr_with_noise() {
        let signal = generate_sine_tone(440.0, 8000, 100, 16000);
        let noisy: Vec<i16> = signal
            .iter()
            .enumerate()
            .map(|(i, &s)| {
                let noise = ((i as f64 * 0.1).sin() * 100.0) as i16;
                s.saturating_add(noise)
            })
            .collect();

        let snr = compute_snr(&signal, &noisy);
        assert!(snr > 20.0, "SNR should be decent, got {}", snr);
    }

    /// A signal correlates perfectly with itself.
    #[test]
    fn test_cross_correlation_identical() {
        let signal = generate_sine_tone(440.0, 8000, 100, 16000);
        let corr = cross_correlation(&signal, &signal);
        assert!((corr - 1.0).abs() < 0.001, "Self-correlation should be ~1.0, got {}", corr);
    }

    /// Tones at different frequencies correlate weakly.
    #[test]
    fn test_cross_correlation_different() {
        let sig_a = generate_sine_tone(440.0, 8000, 100, 16000);
        let sig_b = generate_sine_tone(880.0, 8000, 100, 16000);
        let corr = cross_correlation(&sig_a, &sig_b);
        assert!(corr < 0.5, "Different tones should have low correlation, got {}", corr);
    }

    /// The largest per-sample difference is reported.
    #[test]
    fn test_max_sample_error() {
        let a = vec![100i16, 200, 300, 400, 500];
        let b = vec![110i16, 190, 310, 350, 510];
        let err = max_sample_error(&a, &b);
        assert_eq!(err, 50);
    }

    /// Identical signals have zero RMS error.
    #[test]
    fn test_rms_error_identical() {
        let a = generate_sine_tone(440.0, 8000, 100, 16000);
        let rms = rms_error(&a, &a);
        assert!(rms < 0.001, "RMS error for identical signals should be ~0, got {}", rms);
    }

    /// Clearing the recorder discards its samples.
    #[test]
    fn test_recorder_clear() {
        let mut recorder = AudioRecorder::new(8000);
        recorder.record_frame(&[100i16; 160]);
        assert!(!recorder.is_empty());
        recorder.clear();
        assert!(recorder.is_empty());
    }

    /// The header constructors populate rate and channel count as documented.
    #[test]
    fn test_wav_header_constructors() {
        let h = WavHeader::telephony();
        assert_eq!(h.sample_rate, 8000);
        assert_eq!(h.channels, 1);

        let h = WavHeader::mono(48000);
        assert_eq!(h.sample_rate, 48000);
        assert_eq!(h.channels, 1);
    }
}