use crate::codecs::g711::*;
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
/// Header of a canonical 44-byte PCM WAV file: the `RIFF`/`WAVE` preamble, a
/// 16-byte `fmt ` chunk, and the `data` chunk header.
///
/// No `#[repr(packed)]`: fields are serialized one at a time with
/// `to_le_bytes`, so the in-memory layout is irrelevant. Packing would also make
/// any reference to a multi-byte field (e.g. the one `assert_eq!` takes) a
/// compile error.
#[derive(Debug, Clone, Copy)]
struct WavHeader {
    riff_id: [u8; 4],
    /// RIFF payload size: total file length minus the 8-byte `RIFF` preamble.
    file_size: u32,
    wave_id: [u8; 4],
    fmt_id: [u8; 4],
    /// Size of the `fmt ` chunk body (16 for plain PCM).
    fmt_size: u32,
    /// 1 = integer PCM.
    audio_format: u16,
    num_channels: u16,
    sample_rate: u32,
    /// Bytes per second: `sample_rate * block_align`.
    byte_rate: u32,
    /// Bytes per sample frame: `num_channels * bits_per_sample / 8`.
    block_align: u16,
    bits_per_sample: u16,
    data_id: [u8; 4],
    /// Length in bytes of the sample data that follows the header.
    data_size: u32,
}

impl WavHeader {
    /// Sample rate written into every header, in Hz.
    const SAMPLE_RATE: u32 = 8000;
    /// Mono.
    const NUM_CHANNELS: u16 = 1;
    /// Signed 16-bit samples.
    const BITS_PER_SAMPLE: u16 = 16;
    /// Bytes counted by `file_size` besides the sample data:
    /// "WAVE" (4) + `fmt ` chunk header and body (8 + 16) + `data` chunk header (8).
    const HEADER_OVERHEAD: u32 = 36;

    /// Builds the header for `num_samples` mono 16-bit samples at 8 kHz.
    ///
    /// # Panics
    ///
    /// Panics if the sample data cannot be described by the `u32` RIFF size
    /// fields, rather than silently writing a wrapped (corrupt) size.
    fn new(num_samples: usize) -> Self {
        let block_align = Self::NUM_CHANNELS * (Self::BITS_PER_SAMPLE / 8);
        // Largest data size that still leaves room for the header overhead in
        // `file_size`.
        let max_data_bytes = (u32::MAX - Self::HEADER_OVERHEAD) as usize;
        let data_size = match num_samples.checked_mul(usize::from(block_align)) {
            // Lossless cast: `bytes` is bounded by a value that came from a u32.
            Some(bytes) if bytes <= max_data_bytes => bytes as u32,
            _ => panic!(
                "too many samples for a WAV file: {} samples overflow the u32 RIFF size field",
                num_samples
            ),
        };
        WavHeader {
            riff_id: *b"RIFF",
            file_size: data_size + Self::HEADER_OVERHEAD,
            wave_id: *b"WAVE",
            fmt_id: *b"fmt ",
            fmt_size: 16,
            audio_format: 1,
            num_channels: Self::NUM_CHANNELS,
            sample_rate: Self::SAMPLE_RATE,
            byte_rate: Self::SAMPLE_RATE * u32::from(block_align),
            block_align,
            bits_per_sample: Self::BITS_PER_SAMPLE,
            data_id: *b"data",
            data_size,
        }
    }
}
/// Downloads `url` to `path` using the external `curl` binary.
///
/// The body is written to a uniquely named sibling temp file and then renamed
/// onto `path`. Concurrent callers targeting the same path (the two roundtrip
/// tests may run in parallel) and interrupted downloads therefore never expose a
/// half-written file at `path`.
///
/// `--fail` makes curl exit non-zero on HTTP errors (4xx/5xx) instead of saving
/// the server's error page as if it were the requested file. `--show-error`
/// keeps curl's error text on stderr despite `--silent`, so the returned error
/// is not empty.
///
/// Note: despite being `async`, this runs curl through `std::process::Command`,
/// which blocks the calling thread until curl exits.
///
/// # Errors
///
/// Returns an error if `path` has no file name, curl cannot be spawned or exits
/// unsuccessfully, or the downloaded temp file cannot be renamed onto `path`.
async fn download_file(url: &str, path: &Path) -> Result<(), Box<dyn std::error::Error>> {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Distinguishes concurrent downloads within this process; the pid
    // distinguishes concurrent processes.
    static NEXT_TMP_ID: AtomicUsize = AtomicUsize::new(0);

    let mut tmp_name = path
        .file_name()
        .ok_or("Download path has no file name")?
        .to_os_string();
    tmp_name.push(format!(
        ".{}.{}.part",
        std::process::id(),
        NEXT_TMP_ID.fetch_add(1, Ordering::Relaxed)
    ));
    let tmp_path = path.with_file_name(tmp_name);

    let response = std::process::Command::new("curl")
        .args(["--silent", "--show-error", "--fail", "--location", "--output"])
        .arg(&tmp_path)
        .arg(url)
        .output()?;
    if !response.status.success() {
        // Best effort: curl may have left an empty or partial temp file behind.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(format!(
            "Failed to download file: {}",
            String::from_utf8_lossy(&response.stderr)
        )
        .into());
    }
    if let Err(err) = std::fs::rename(&tmp_path, path) {
        let _ = std::fs::remove_file(&tmp_path);
        return Err(err.into());
    }
    Ok(())
}
/// Reads a 16-bit, mono, 8000 Hz PCM WAV file and returns its samples.
///
/// The RIFF container is walked chunk by chunk. Files that carry extra chunks
/// (e.g. `LIST` metadata) or an extended `fmt ` chunk (size > 16) are therefore
/// read correctly, instead of being misparsed by assuming a fixed 44-byte
/// header. The whole file is read in one call rather than one `read_exact` per
/// sample.
///
/// # Errors
///
/// Returns an error if the file cannot be read; is not a RIFF/WAVE file; has a
/// truncated chunk; lacks a `fmt ` or `data` chunk; or is not 16-bit mono
/// 8000 Hz PCM.
fn read_wav_file(path: &Path) -> Result<Vec<i16>, Box<dyn std::error::Error>> {
    let bytes = std::fs::read(path)?;
    if bytes.len() < 12 || &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
        return Err("Not a valid WAV file".into());
    }

    // Each chunk is a 4-byte id, a little-endian u32 body size, the body, and
    // one pad byte when the body size is odd.
    let mut fmt_chunk: Option<&[u8]> = None;
    let mut data_chunk: Option<&[u8]> = None;
    let mut pos = 12;
    while pos + 8 <= bytes.len() {
        if fmt_chunk.is_some() && data_chunk.is_some() {
            // Ignore anything (possibly junk) after the chunks we need.
            break;
        }
        let id = &bytes[pos..pos + 4];
        let size = u32::from_le_bytes([
            bytes[pos + 4],
            bytes[pos + 5],
            bytes[pos + 6],
            bytes[pos + 7],
        ]) as usize;
        let body_start = pos + 8;
        let body_end = body_start
            .checked_add(size)
            .filter(|&end| end <= bytes.len())
            .ok_or("Truncated WAV chunk")?;
        let body = &bytes[body_start..body_end];
        if id == b"fmt " && fmt_chunk.is_none() {
            fmt_chunk = Some(body);
        } else if id == b"data" && data_chunk.is_none() {
            data_chunk = Some(body);
        }
        pos = body_end + (size & 1);
    }

    let fmt = fmt_chunk.ok_or("Missing fmt chunk")?;
    if fmt.len() < 16 {
        return Err("Invalid fmt chunk".into());
    }
    let audio_format = u16::from_le_bytes([fmt[0], fmt[1]]);
    let num_channels = u16::from_le_bytes([fmt[2], fmt[3]]);
    let sample_rate = u32::from_le_bytes([fmt[4], fmt[5], fmt[6], fmt[7]]);
    let bits_per_sample = u16::from_le_bytes([fmt[14], fmt[15]]);
    if audio_format != 1 {
        return Err("Only PCM format is supported".into());
    }
    if num_channels != 1 {
        return Err("Only mono audio is supported".into());
    }
    if sample_rate != 8000 {
        return Err("Only 8000 Hz sample rate is supported".into());
    }
    if bits_per_sample != 16 {
        return Err("Only 16-bit samples are supported".into());
    }

    let data = data_chunk.ok_or("Missing data chunk")?;
    // A trailing odd byte cannot form a sample and is ignored.
    Ok(data
        .chunks_exact(2)
        .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
        .collect())
}
/// Writes `samples` to `path` as a 16-bit, mono, 8 kHz PCM WAV file with a
/// canonical 44-byte header built by `WavHeader::new`.
///
/// Any existing file at `path` is truncated. Output goes through a `BufWriter`,
/// so the header fields and the per-sample loop do not each issue a separate
/// `write` syscall.
///
/// # Errors
///
/// Returns an error if the file cannot be created or if any write or the final
/// flush fails.
fn write_wav_file(path: &Path, samples: &[i16]) -> Result<(), Box<dyn std::error::Error>> {
    let mut out = std::io::BufWriter::new(File::create(path)?);
    let header = WavHeader::new(samples.len());

    // RIFF preamble.
    out.write_all(&header.riff_id)?;
    out.write_all(&header.file_size.to_le_bytes())?;
    out.write_all(&header.wave_id)?;

    // `fmt ` chunk.
    out.write_all(&header.fmt_id)?;
    out.write_all(&header.fmt_size.to_le_bytes())?;
    out.write_all(&header.audio_format.to_le_bytes())?;
    out.write_all(&header.num_channels.to_le_bytes())?;
    out.write_all(&header.sample_rate.to_le_bytes())?;
    out.write_all(&header.byte_rate.to_le_bytes())?;
    out.write_all(&header.block_align.to_le_bytes())?;
    out.write_all(&header.bits_per_sample.to_le_bytes())?;

    // `data` chunk.
    out.write_all(&header.data_id)?;
    out.write_all(&header.data_size.to_le_bytes())?;
    for sample in samples {
        out.write_all(&sample.to_le_bytes())?;
    }

    // Flush explicitly: BufWriter's Drop also flushes but silently discards errors.
    out.flush()?;
    Ok(())
}
/// Signal-to-noise ratio of `decoded` relative to `original`, in decibels.
///
/// The noise at each position is `original - decoded`. Signal and noise powers
/// are exact `i64` sums of squares. When the noise power is zero (perfect
/// reconstruction, or two empty inputs) the result is `f64::INFINITY`.
///
/// # Panics
///
/// Panics if the two slices have different lengths.
fn calculate_snr(original: &[i16], decoded: &[i16]) -> f64 {
    assert_eq!(
        original.len(),
        decoded.len(),
        "Signals must have same length"
    );

    let (signal_power, noise_power) = original.iter().zip(decoded).fold(
        (0i64, 0i64),
        |(sig_acc, noise_acc), (&orig, &dec)| {
            let reference = i64::from(orig);
            let error = reference - i64::from(dec);
            (sig_acc + reference * reference, noise_acc + error * error)
        },
    );

    if noise_power == 0 {
        f64::INFINITY
    } else {
        let ratio = signal_power as f64 / noise_power as f64;
        10.0 * ratio.log10()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_g711_alaw_roundtrip_real_audio() {
        test_g711_roundtrip_real_audio(G711Variant::ALaw, "alaw").await;
    }

    #[tokio::test]
    async fn test_g711_ulaw_roundtrip_real_audio() {
        test_g711_roundtrip_real_audio(G711Variant::MuLaw, "ulaw").await;
    }

    /// Encodes and decodes a real 8 kHz speech recording with `variant` and
    /// checks the reconstruction quality (sample count, SNR > 20 dB, not mostly
    /// silence).
    ///
    /// On first use the recording is downloaded (needs network access and
    /// `curl`) and cached under `src/codecs/g711/tests/test_data`, which is
    /// resolved relative to the current working directory. The decoded audio is
    /// written next to it so it can be listened to manually.
    async fn test_g711_roundtrip_real_audio(variant: G711Variant, variant_name: &str) {
        const WAV_URL: &str =
            "https://www.voiptroubleshooter.com/open_speech/american/OSR_us_000_0010_8k.wav";
        let test_dir = Path::new("src/codecs/g711/tests/test_data");
        std::fs::create_dir_all(test_dir).expect("Failed to create test directory");

        let original_wav_path = test_dir.join("OSR_us_000_0010_8k.wav");
        if !original_wav_path.exists() {
            println!("Downloading WAV file from: {}", WAV_URL);
            download_file(WAV_URL, &original_wav_path)
                .await
                .expect("Failed to download WAV file");
            println!("Downloaded WAV file to: {:?}", original_wav_path);
        } else {
            println!("Using existing WAV file: {:?}", original_wav_path);
        }

        let original_samples = read_wav_file(&original_wav_path).expect("Failed to read WAV file");
        println!("Loaded {} samples from WAV file", original_samples.len());

        let codec = G711Codec::new(variant);
        let encoded = codec
            .compress(&original_samples)
            .expect("Failed to encode samples");
        println!(
            "Encoded {} samples to {} bytes using {}",
            original_samples.len(),
            encoded.len(),
            variant_name
        );

        let decoded_samples = codec.expand(&encoded).expect("Failed to decode samples");
        println!(
            "Decoded {} bytes to {} samples using {}",
            encoded.len(),
            decoded_samples.len(),
            variant_name
        );

        assert_eq!(
            original_samples.len(),
            decoded_samples.len(),
            "Sample count mismatch after roundtrip"
        );
        // Check this before the SNR: calculate_snr returns +inf for two empty
        // signals, which would sail past the SNR threshold below.
        assert!(
            !decoded_samples.is_empty(),
            "Decoded samples should not be empty"
        );

        let snr = calculate_snr(&original_samples, &decoded_samples);
        println!("Signal-to-Noise Ratio: {:.2} dB", snr);
        assert!(snr > 20.0, "SNR too low: {:.2} dB", snr);

        let output_wav_path =
            test_dir.join(format!("OSR_us_000_0010_8k_roundtrip_{}.wav", variant_name));
        write_wav_file(&output_wav_path, &decoded_samples)
            .expect("Failed to write output WAV file");
        println!("Saved roundtrip result to: {:?}", output_wav_path);

        let non_zero_samples = decoded_samples.iter().filter(|&&s| s != 0).count();
        assert!(
            non_zero_samples > original_samples.len() / 10,
            "Too many zero samples, audio might be corrupted"
        );

        // G.711 maps each 16-bit linear PCM sample to one 8-bit codeword, so the
        // byte-level ratio is 2:1; report the measured value.
        let original_bytes = original_samples.len() * std::mem::size_of::<i16>();
        println!("✓ G.711 {} roundtrip test passed!", variant_name);
        println!(" - Original samples: {}", original_samples.len());
        println!(" - Encoded bytes: {}", encoded.len());
        println!(
            " - Compression ratio: {:.1}:1 ({} -> {} bytes)",
            original_bytes as f64 / encoded.len() as f64,
            original_bytes,
            encoded.len()
        );
        println!(" - SNR: {:.2} dB", snr);
        println!(
            " - Non-zero samples: {}%",
            (non_zero_samples * 100) / original_samples.len()
        );
    }
}