Skip to main content

rtp_core/
wav.rs

1//! WAV file reading and writing for audio recording and playback.
2//!
3//! Supports 16-bit PCM mono WAV files at any sample rate (typically 8000 Hz for telephony).
4
5use std::io;
6use thiserror::Error;
7
/// Errors produced while reading or writing WAV data.
#[derive(Debug, Error)]
pub enum WavError {
    /// An underlying I/O operation failed; converted automatically from [`io::Error`].
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// The bytes are not a well-formed RIFF/WAVE container; the payload says what is wrong.
    #[error("invalid WAV header: {0}")]
    InvalidHeader(String),
    /// The container parsed, but its audio encoding is not handled; the payload says why.
    #[error("unsupported format: {0}")]
    UnsupportedFormat(String),
}
17
/// Format parameters describing a WAV stream.
#[derive(Debug, Clone)]
pub struct WavHeader {
    /// Sampling frequency in Hz.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u16,
    /// Width of a single sample in bits.
    pub bits_per_sample: u16,
    /// Number of samples; the constructors below start it at zero.
    pub num_samples: usize,
}

impl WavHeader {
    /// Header for standard telephony audio: 8 kHz, mono, 16-bit.
    pub fn telephony() -> Self {
        Self::mono(8000)
    }

    /// Header for mono 16-bit audio at the given sample rate.
    pub fn mono(sample_rate: u32) -> Self {
        WavHeader {
            num_samples: 0,
            bits_per_sample: 16,
            channels: 1,
            sample_rate,
        }
    }
}
48
49/// Write PCM samples to a WAV file (in memory as bytes).
50pub fn encode_wav(samples: &[i16], header: &WavHeader) -> Vec<u8> {
51    let data_size = samples.len() * 2; // 16-bit = 2 bytes per sample
52    let byte_rate = header.sample_rate * header.channels as u32 * (header.bits_per_sample as u32 / 8);
53    let block_align = header.channels * (header.bits_per_sample / 8);
54
55    let mut buf = Vec::with_capacity(44 + data_size);
56
57    // RIFF header
58    buf.extend_from_slice(b"RIFF");
59    buf.extend_from_slice(&((36 + data_size) as u32).to_le_bytes());
60    buf.extend_from_slice(b"WAVE");
61
62    // fmt chunk
63    buf.extend_from_slice(b"fmt ");
64    buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
65    buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
66    buf.extend_from_slice(&header.channels.to_le_bytes());
67    buf.extend_from_slice(&header.sample_rate.to_le_bytes());
68    buf.extend_from_slice(&byte_rate.to_le_bytes());
69    buf.extend_from_slice(&block_align.to_le_bytes());
70    buf.extend_from_slice(&header.bits_per_sample.to_le_bytes());
71
72    // data chunk
73    buf.extend_from_slice(b"data");
74    buf.extend_from_slice(&(data_size as u32).to_le_bytes());
75    for &sample in samples {
76        buf.extend_from_slice(&sample.to_le_bytes());
77    }
78
79    buf
80}
81
82/// Write PCM samples to a WAV file on disk.
83pub fn write_wav(path: &str, samples: &[i16], header: &WavHeader) -> Result<(), WavError> {
84    let data = encode_wav(samples, header);
85    std::fs::write(path, data)?;
86    Ok(())
87}
88
89/// Read PCM samples from WAV bytes.
90pub fn decode_wav(data: &[u8]) -> Result<(WavHeader, Vec<i16>), WavError> {
91    if data.len() < 44 {
92        return Err(WavError::InvalidHeader("too short".to_string()));
93    }
94
95    // Verify RIFF header
96    if &data[0..4] != b"RIFF" {
97        return Err(WavError::InvalidHeader("missing RIFF".to_string()));
98    }
99    if &data[8..12] != b"WAVE" {
100        return Err(WavError::InvalidHeader("missing WAVE".to_string()));
101    }
102
103    // Parse fmt chunk
104    if &data[12..16] != b"fmt " {
105        return Err(WavError::InvalidHeader("missing fmt chunk".to_string()));
106    }
107
108    let format = u16::from_le_bytes([data[20], data[21]]);
109    if format != 1 {
110        return Err(WavError::UnsupportedFormat(format!(
111            "not PCM (format={})",
112            format
113        )));
114    }
115
116    let channels = u16::from_le_bytes([data[22], data[23]]);
117    let sample_rate = u32::from_le_bytes([data[24], data[25], data[26], data[27]]);
118    let bits_per_sample = u16::from_le_bytes([data[34], data[35]]);
119
120    if bits_per_sample != 16 {
121        return Err(WavError::UnsupportedFormat(format!(
122            "not 16-bit (bits={})",
123            bits_per_sample
124        )));
125    }
126
127    // Find data chunk (skip any extra fmt data or other chunks)
128    let mut pos = 12;
129    loop {
130        if pos + 8 > data.len() {
131            return Err(WavError::InvalidHeader("missing data chunk".to_string()));
132        }
133        let chunk_id = &data[pos..pos + 4];
134        let chunk_size = u32::from_le_bytes([
135            data[pos + 4],
136            data[pos + 5],
137            data[pos + 6],
138            data[pos + 7],
139        ]) as usize;
140
141        if chunk_id == b"data" {
142            let sample_data = &data[pos + 8..pos + 8 + chunk_size.min(data.len() - pos - 8)];
143            let samples: Vec<i16> = sample_data
144                .chunks_exact(2)
145                .map(|c| i16::from_le_bytes([c[0], c[1]]))
146                .collect();
147
148            let header = WavHeader {
149                sample_rate,
150                channels,
151                bits_per_sample,
152                num_samples: samples.len(),
153            };
154
155            return Ok((header, samples));
156        }
157
158        pos += 8 + chunk_size;
159        // Align to even boundary
160        if chunk_size % 2 != 0 {
161            pos += 1;
162        }
163    }
164}
165
166/// Read PCM samples from a WAV file on disk.
167pub fn read_wav(path: &str) -> Result<(WavHeader, Vec<i16>), WavError> {
168    let data = std::fs::read(path)?;
169    decode_wav(&data)
170}
171
172/// An audio recorder that accumulates PCM samples.
173#[derive(Debug, Clone)]
174pub struct AudioRecorder {
175    samples: Vec<i16>,
176    sample_rate: u32,
177}
178
179impl AudioRecorder {
180    pub fn new(sample_rate: u32) -> Self {
181        Self {
182            samples: Vec::new(),
183            sample_rate,
184        }
185    }
186
187    /// Record a frame of PCM samples.
188    pub fn record_frame(&mut self, frame: &[i16]) {
189        self.samples.extend_from_slice(frame);
190    }
191
192    /// Get all recorded samples.
193    pub fn samples(&self) -> &[i16] {
194        &self.samples
195    }
196
197    /// Get duration in milliseconds.
198    pub fn duration_ms(&self) -> u64 {
199        (self.samples.len() as u64 * 1000) / self.sample_rate as u64
200    }
201
202    /// Get number of frames recorded (assuming 20ms frames).
203    pub fn frame_count(&self) -> usize {
204        let samples_per_frame = (self.sample_rate as usize * 20) / 1000;
205        if samples_per_frame == 0 {
206            return 0;
207        }
208        self.samples.len() / samples_per_frame
209    }
210
211    /// Export as WAV bytes.
212    pub fn to_wav(&self) -> Vec<u8> {
213        let header = WavHeader {
214            sample_rate: self.sample_rate,
215            channels: 1,
216            bits_per_sample: 16,
217            num_samples: self.samples.len(),
218        };
219        encode_wav(&self.samples, &header)
220    }
221
222    /// Save to a WAV file.
223    pub fn save_wav(&self, path: &str) -> Result<(), WavError> {
224        let header = WavHeader::mono(self.sample_rate);
225        write_wav(path, &self.samples, &header)
226    }
227
228    /// Clear all recorded samples.
229    pub fn clear(&mut self) {
230        self.samples.clear();
231    }
232
233    /// Check if any audio has been recorded.
234    pub fn is_empty(&self) -> bool {
235        self.samples.is_empty()
236    }
237
238    /// Total number of samples recorded.
239    pub fn len(&self) -> usize {
240        self.samples.len()
241    }
242}
243
/// Synthesize a pure sine tone (intended for tests).
///
/// Produces `sample_rate * duration_ms / 1000` samples of a sine wave at
/// `frequency` Hz whose peak value is `amplitude`.
pub fn generate_sine_tone(frequency: f64, sample_rate: u32, duration_ms: u32, amplitude: i16) -> Vec<i16> {
    let total = (u64::from(sample_rate) * u64::from(duration_ms) / 1000) as usize;
    let rate_hz = f64::from(sample_rate);
    let peak = f64::from(amplitude);

    let mut tone = Vec::with_capacity(total);
    for n in 0..total {
        let seconds = n as f64 / rate_hz;
        let phase = 2.0 * std::f64::consts::PI * frequency * seconds;
        tone.push((phase.sin() * peak) as i16);
    }
    tone
}
254
/// Synthesize the average of several sine tones (intended for fidelity tests).
///
/// Each entry of `frequencies` contributes a sine at that frequency in Hz; the
/// sum is divided by the number of tones and scaled by `amplitude`. The result
/// holds `sample_rate * duration_ms / 1000` samples.
pub fn generate_multi_tone(
    frequencies: &[f64],
    sample_rate: u32,
    duration_ms: u32,
    amplitude: i16,
) -> Vec<i16> {
    let total = (u64::from(sample_rate) * u64::from(duration_ms) / 1000) as usize;
    let rate_hz = f64::from(sample_rate);
    let peak = f64::from(amplitude);
    let scale = 1.0 / frequencies.len() as f64;

    let mut mix = Vec::with_capacity(total);
    for n in 0..total {
        let seconds = n as f64 / rate_hz;
        let sum: f64 = frequencies
            .iter()
            .map(|&hz| (2.0 * std::f64::consts::PI * hz * seconds).sin())
            .sum();
        mix.push((sum * scale * peak) as i16);
    }
    mix
}
275
/// Signal-to-noise ratio in dB of `received` measured against `original`.
///
/// Only the overlapping prefix of the two slices is compared. Returns `0.0`
/// when there is no overlap and `100.0` when the accumulated squared error is
/// below 1 (treated as a perfect match). Higher is better; above 20 dB is
/// typically acceptable for telephony.
pub fn compute_snr(original: &[i16], received: &[i16]) -> f64 {
    if original.is_empty() || received.is_empty() {
        return 0.0;
    }

    let (signal_power, noise_power) = original.iter().zip(received).fold(
        (0.0f64, 0.0f64),
        |(signal, noise), (&reference, &actual)| {
            let reference = f64::from(reference);
            let error = reference - f64::from(actual);
            (signal + reference * reference, noise + error * error)
        },
    );

    if noise_power < 1.0 {
        // Perfect match
        100.0
    } else {
        10.0 * (signal_power / noise_power).log10()
    }
}
300
/// Normalized zero-lag cross-correlation of two signals.
///
/// Only the overlapping prefix of the slices is used. The result lies between
/// -1.0 and 1.0, with values above 0.9 indicating strong similarity. Returns
/// `0.0` when there is no overlap or when the normalizing term is below 1
/// (for example when either signal is silent).
pub fn cross_correlation(a: &[i16], b: &[i16]) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    let (dot, energy_a, energy_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, ea, eb), (&x, &y)| {
            let x = f64::from(x);
            let y = f64::from(y);
            (dot + x * y, ea + x * x, eb + y * y)
        },
    );

    let norm = (energy_a * energy_b).sqrt();
    if norm < 1.0 {
        0.0
    } else {
        dot / norm
    }
}
328
/// Largest absolute per-sample difference between two signals.
///
/// Only the overlapping prefix of the slices is compared; returns 0 when there
/// is no overlap.
pub fn max_sample_error(original: &[i16], received: &[i16]) -> i32 {
    original
        .iter()
        .zip(received)
        .map(|(&reference, &actual)| (i32::from(reference) - i32::from(actual)).abs())
        .max()
        .unwrap_or(0)
}
341
/// Root-mean-square difference between two signals.
///
/// Only the overlapping prefix of the slices is compared; returns `0.0` when
/// there is no overlap.
pub fn rms_error(original: &[i16], received: &[i16]) -> f64 {
    let overlap = original.len().min(received.len());
    if overlap == 0 {
        return 0.0;
    }

    let squared_error: f64 = original
        .iter()
        .zip(received)
        .map(|(&reference, &actual)| {
            let delta = f64::from(reference) - f64::from(actual);
            delta * delta
        })
        .sum();

    (squared_error / overlap as f64).sqrt()
}
356
/// Unit tests for WAV encoding/decoding, the recorder, the tone generators and
/// the signal-comparison helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding in memory returns the same header fields and samples.
    #[test]
    fn test_wav_roundtrip() {
        let samples: Vec<i16> = (0..8000)
            .map(|i| ((i as f64 / 8000.0 * std::f64::consts::TAU * 440.0).sin() * 16000.0) as i16)
            .collect();

        let header = WavHeader::telephony();
        let encoded = encode_wav(&samples, &header);
        let (decoded_header, decoded_samples) = decode_wav(&encoded).unwrap();

        assert_eq!(decoded_header.sample_rate, 8000);
        assert_eq!(decoded_header.channels, 1);
        assert_eq!(decoded_header.bits_per_sample, 16);
        assert_eq!(decoded_samples, samples);
    }

    /// Writing to disk and reading back returns the same samples.
    #[test]
    fn test_wav_file_roundtrip() {
        let samples = generate_sine_tone(440.0, 8000, 100, 16000);
        let header = WavHeader::telephony();

        // Use the platform temp directory and a per-process name so the test is
        // portable and concurrent test runs do not overwrite each other's file.
        let path = std::env::temp_dir().join(format!(
            "siphone_test_wav_roundtrip_{}.wav",
            std::process::id()
        ));
        let path = path.to_str().expect("temp dir path should be valid UTF-8");

        write_wav(path, &samples, &header).unwrap();
        let read_back = read_wav(path);
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(path).ok();

        let (_, read_samples) = read_back.unwrap();
        assert_eq!(read_samples, samples);
    }

    /// Inputs that are not WAV data are rejected.
    #[test]
    fn test_wav_invalid() {
        assert!(decode_wav(b"NOT A WAV").is_err());
        assert!(decode_wav(&[0; 10]).is_err());
    }

    #[test]
    fn test_generate_sine_tone() {
        let tone = generate_sine_tone(440.0, 8000, 100, 16000);
        assert_eq!(tone.len(), 800); // 8000 * 0.1 = 800 samples

        // Should have non-zero samples
        assert!(tone.iter().any(|&s| s != 0));

        // Max amplitude should be close to 16000
        let max = tone.iter().map(|s| s.abs()).max().unwrap();
        assert!(max > 15000 && max <= 16000);
    }

    #[test]
    fn test_generate_multi_tone() {
        let tone = generate_multi_tone(&[300.0, 500.0, 700.0], 8000, 100, 16000);
        assert_eq!(tone.len(), 800);
        assert!(tone.iter().any(|&s| s != 0));
    }

    /// Length, duration and frame count grow as frames are recorded.
    #[test]
    fn test_audio_recorder() {
        let mut recorder = AudioRecorder::new(8000);
        assert!(recorder.is_empty());
        assert_eq!(recorder.duration_ms(), 0);

        let frame = vec![1000i16; 160];
        recorder.record_frame(&frame);
        assert_eq!(recorder.len(), 160);
        assert_eq!(recorder.duration_ms(), 20); // 160 / 8000 * 1000 = 20ms
        assert_eq!(recorder.frame_count(), 1);

        recorder.record_frame(&frame);
        assert_eq!(recorder.len(), 320);
        assert_eq!(recorder.frame_count(), 2);
        assert_eq!(recorder.duration_ms(), 40);
    }

    /// The recorder's WAV export decodes back to the recorded samples.
    #[test]
    fn test_recorder_to_wav() {
        let mut recorder = AudioRecorder::new(8000);
        let tone = generate_sine_tone(440.0, 8000, 100, 16000);
        for frame in tone.chunks(160) {
            recorder.record_frame(frame);
        }

        let wav = recorder.to_wav();
        let (header, samples) = decode_wav(&wav).unwrap();
        assert_eq!(header.sample_rate, 8000);
        assert_eq!(samples, recorder.samples());
    }

    #[test]
    fn test_compute_snr_identical() {
        let signal = generate_sine_tone(440.0, 8000, 100, 16000);
        let snr = compute_snr(&signal, &signal);
        assert!(snr > 90.0, "SNR for identical signals should be very high, got {}", snr);
    }

    #[test]
    fn test_compute_snr_with_noise() {
        let signal = generate_sine_tone(440.0, 8000, 100, 16000);
        // Add a small deterministic disturbance (peak 100) to every sample.
        let noisy: Vec<i16> = signal
            .iter()
            .enumerate()
            .map(|(i, &s)| {
                let noise = ((i as f64 * 0.1).sin() * 100.0) as i16;
                s.saturating_add(noise)
            })
            .collect();

        let snr = compute_snr(&signal, &noisy);
        assert!(snr > 20.0, "SNR should be decent, got {}", snr);
    }

    #[test]
    fn test_cross_correlation_identical() {
        let signal = generate_sine_tone(440.0, 8000, 100, 16000);
        let corr = cross_correlation(&signal, &signal);
        assert!((corr - 1.0).abs() < 0.001, "Self-correlation should be ~1.0, got {}", corr);
    }

    #[test]
    fn test_cross_correlation_different() {
        let sig_a = generate_sine_tone(440.0, 8000, 100, 16000);
        let sig_b = generate_sine_tone(880.0, 8000, 100, 16000);
        let corr = cross_correlation(&sig_a, &sig_b);
        // Different frequencies should have lower correlation
        assert!(corr < 0.5, "Different tones should have low correlation, got {}", corr);
    }

    #[test]
    fn test_max_sample_error() {
        let a = vec![100i16, 200, 300, 400, 500];
        let b = vec![110i16, 190, 310, 350, 510];
        let err = max_sample_error(&a, &b);
        assert_eq!(err, 50); // 400 - 350
    }

    #[test]
    fn test_rms_error_identical() {
        let a = generate_sine_tone(440.0, 8000, 100, 16000);
        let rms = rms_error(&a, &a);
        assert!(rms < 0.001, "RMS error for identical signals should be ~0, got {}", rms);
    }

    #[test]
    fn test_recorder_clear() {
        let mut recorder = AudioRecorder::new(8000);
        recorder.record_frame(&[100i16; 160]);
        assert!(!recorder.is_empty());
        recorder.clear();
        assert!(recorder.is_empty());
    }

    #[test]
    fn test_wav_header_constructors() {
        let h = WavHeader::telephony();
        assert_eq!(h.sample_rate, 8000);
        assert_eq!(h.channels, 1);

        let h = WavHeader::mono(48000);
        assert_eq!(h.sample_rate, 48000);
        assert_eq!(h.channels, 1);
    }
}