use opus_rs::celt::{CeltDecoder, CeltEncoder};
use opus_rs::modes::default_mode;
use opus_rs::range_coder::RangeCoder;
#[test]
fn test_celt_loopback() {
let mode = default_mode();
let channels = 1;
let frame_size = mode.mdct.n / 2;
let mut encoder = CeltEncoder::new(mode, channels);
let mut decoder = CeltDecoder::new(mode, channels);
let nb_frames = 12;
let mut all_in = Vec::new();
let mut all_out = Vec::new();
for f in 0..nb_frames {
let mut pcm_in = vec![0.0f32; frame_size * channels];
for i in 0..frame_size {
pcm_in[i] = ((f * frame_size + i) as f32 * 0.1).sin();
}
all_in.extend_from_slice(&pcm_in);
let rc_size: u32 = 2048;
let mut rc = RangeCoder::new_encoder(rc_size);
encoder.encode(&pcm_in, frame_size, &mut rc);
rc.done();
let mut compressed = vec![0u8; rc_size as usize];
compressed.copy_from_slice(&rc.buf[..rc_size as usize]);
let mut pcm_out = vec![0.0f32; frame_size * channels];
let decoded_len = decoder.decode(&compressed, frame_size, &mut pcm_out);
assert_eq!(decoded_len, frame_size);
all_out.extend_from_slice(&pcm_out);
println!("Frame {}:", f);
println!(" pcm_in[0..5] = {:?}", &pcm_in[0..5]);
println!(" pcm_out[0..5] = {:?}", &pcm_out[0..5]);
let delay = 0; if all_out.len() >= delay + frame_size {
let start_out = all_out.len() - frame_size;
let start_in = start_out - delay;
let mut sq_err = 0.0;
let mut sq_sig = 0.0;
for i in 0..frame_size {
let s_in = all_in[start_in + i];
let s_out = all_out[start_out + i];
sq_err += (s_in - s_out) * (s_in - s_out);
sq_sig += s_in * s_in;
}
let snr = 10.0 * (sq_sig / (sq_err + 1e-10)).log10();
println!("Frame {} SNR (0 delay): {:.2} dB", f, snr);
}
}
let mut best_snr = -100.0f32;
let mut best_delay = 0;
let start_idx = 4 * frame_size;
let end_idx = 10 * frame_size;
for delay in 0..2000 {
let mut s_e = 0.0;
let mut n_e = 0.0;
let mut count = 0;
for i in start_idx..end_idx {
if i + delay >= all_out.len() {
break;
}
let s = all_in[i];
let d = all_out[i + delay];
s_e += (s as f64) * (s as f64);
n_e += ((s - d) as f64) * ((s - d) as f64);
count += 1;
}
if count < frame_size {
continue;
}
let snr = 10.0 * (s_e / (n_e + 1e-10)).log10() as f32;
if snr > best_snr {
best_snr = snr;
best_delay = delay;
}
}
let snr_0 = calculate_snr(&all_in, &all_out, 0, start_idx, end_idx);
println!("SNR at delay 0: {:.2} dB", snr_0);
println!(
"Loopback Global Best SNR: {:.2} dB at delay {}",
best_snr, best_delay
);
println!("Samples at delay 0:");
for i in 0..10 {
let idx = start_idx + i;
if idx < all_in.len() && idx < all_out.len() {
println!(
" {}: in={:10.6}, out={:10.6}",
idx, all_in[idx], all_out[idx]
);
}
}
println!("Samples at delay {}:", best_delay);
for i in 0..10 {
let idx = start_idx + i;
if idx < all_in.len() && idx + best_delay < all_out.len() {
println!(
" {}: in={:10.6}, out={:10.6}",
idx,
all_in[idx],
all_out[idx + best_delay]
);
}
}
assert!(best_snr > 1.8, "SNR too low: {:.2} dB", best_snr);
}
fn calculate_snr(all_in: &[f32], all_out: &[f32], delay: usize, start: usize, end: usize) -> f32 {
let mut s_e = 0.0;
let mut n_e = 0.0;
for i in start..end {
if i + delay >= all_out.len() {
break;
}
let s = all_in[i];
let d = all_out[i + delay];
s_e += (s as f64) * (s as f64);
n_e += ((s - d) as f64) * ((s - d) as f64);
}
10.0 * (s_e / (n_e + 1e-10)).log10() as f32
}