use super::silk_decoder::SilkSignalType;
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NsqMode {
Greedy,
TrellisDelDec,
}
impl Default for NsqMode {
fn default() -> Self {
NsqMode::TrellisDelDec
}
}
#[derive(Debug, Clone)]
pub struct NsqState {
pub slpc: Vec<f32>,
pub sltp: Vec<f32>,
pub slf_ar_shp: f32,
pub slf_ma_shp: f32,
pub prev_gain: f32,
pub mode: NsqMode,
}
impl NsqState {
pub fn new(lpc_order: usize, ltp_max_lag: usize) -> Self {
Self {
slpc: vec![0.0; lpc_order],
sltp: vec![0.0; ltp_max_lag + 8],
slf_ar_shp: 0.0,
slf_ma_shp: 0.0,
prev_gain: 1.0,
mode: NsqMode::default(),
}
}
pub fn reset(&mut self) {
self.slpc.fill(0.0);
self.sltp.fill(0.0);
self.slf_ar_shp = 0.0;
self.slf_ma_shp = 0.0;
self.prev_gain = 1.0;
}
}
impl Default for NsqState {
fn default() -> Self {
Self::new(16, 288)
}
}
#[derive(Clone)]
struct TrellisPath {
cost: f64,
slpc: Vec<f32>,
sltp: Vec<f32>,
slf_ar_shp: f32,
slf_ma_shp: f32,
pulses: Vec<i32>,
}
pub fn bandwidth_expand(a: &[f32], gamma: f32) -> Vec<f32> {
a.iter()
.enumerate()
.map(|(k, &c)| c * gamma.powi(k as i32 + 1))
.collect()
}
pub fn silk_warped_lpc_analysis_filter(a: &[f32], lambda: f32) -> Vec<f32> {
let n = a.len();
if n == 0 {
return Vec::new();
}
let mut out = vec![0.0f32; n];
let mut state = 0.0f32;
for k in 0..n {
let new_state = a[k] + lambda * state;
out[k] = new_state;
state = new_state;
}
let mut warped = out.clone();
for k in (1..n).rev() {
warped[k - 1] = out[k - 1] - lambda * warped[k];
}
warped
}
const E_RAW_SCALE: f32 = 32768.0;
pub fn process_subframe(
input: &[f32],
lpc_coeffs: &[f32],
ltp_coeffs: &[f32; 5],
ltp_lag: usize,
quant_gain: f32,
signal_type: SilkSignalType,
state: &mut NsqState,
) -> Vec<f32> {
match state.mode {
NsqMode::Greedy => process_subframe_greedy(
input,
lpc_coeffs,
ltp_coeffs,
ltp_lag,
quant_gain,
signal_type,
state,
),
NsqMode::TrellisDelDec => {
let pulses = trellis_process_subframe(
input,
lpc_coeffs,
ltp_coeffs,
ltp_lag,
quant_gain,
signal_type,
state,
);
pulses
.iter()
.map(|&p| (p as f32 / E_RAW_SCALE).clamp(-1.0, 1.0))
.collect()
}
}
}
fn process_subframe_greedy(
input: &[f32],
lpc_coeffs: &[f32],
ltp_coeffs: &[f32; 5],
ltp_lag: usize,
quant_gain: f32,
signal_type: SilkSignalType,
state: &mut NsqState,
) -> Vec<f32> {
let n = input.len();
let lpc_order = lpc_coeffs.len().min(state.slpc.len());
let ltp_buf_len = state.sltp.len();
let (ar_coeff, ma_coeff) = match signal_type {
SilkSignalType::Voiced => (0.04f32, -0.03f32),
SilkSignalType::Unvoiced => (0.02f32, -0.02f32),
SilkSignalType::Inactive => (0.005f32, -0.005f32),
};
let safe_gain = if quant_gain.abs() > 1e-9 {
quant_gain
} else {
1e-9
};
let inv_gain = 1.0f32 / safe_gain;
let mut excitation = Vec::with_capacity(n);
for t in 0..n {
let mut p_lpc_neg = 0.0f32;
for k in 0..lpc_order {
p_lpc_neg -= lpc_coeffs[k] * state.slpc[k];
}
let mut p_ltp = 0.0f32;
if ltp_lag >= 3 && ltp_buf_len > ltp_lag + 3 {
for k in 0..5usize {
let src_idx = ltp_lag.saturating_sub(3).saturating_add(k);
if src_idx < ltp_buf_len {
p_ltp += ltp_coeffs[k] * state.sltp[src_idx];
}
}
}
let p_ar = ar_coeff * state.slf_ar_shp;
let p_ma = ma_coeff * state.slf_ma_shp;
let prediction = p_ltp + p_lpc_neg + p_ar + p_ma;
let err = input[t] - prediction;
let e_raw_raw = (err * inv_gain * E_RAW_SCALE).round() as i32;
let e_raw_clamped = e_raw_raw.clamp(-2047, 2047);
let lambda = 0.5f32 / E_RAW_SCALE;
let e_raw_chosen = {
let mut best = e_raw_clamped;
let mut best_cost = f32::INFINITY;
for delta in [-1i32, 0, 1] {
let cand = (e_raw_clamped + delta).clamp(-2047, 2047);
let approx_exc = cand as f32 / E_RAW_SCALE;
let predicted_out = approx_exc * safe_gain + prediction;
let dist = (input[t] - predicted_out).powi(2);
let rate_proxy = (cand.abs() as f32) * lambda;
let cost = dist + rate_proxy;
if cost < best_cost {
best_cost = cost;
best = cand;
}
}
best
};
let float_exc = (e_raw_chosen as f32 / E_RAW_SCALE).clamp(-1.0, 1.0);
let lpc_residual_q = float_exc * safe_gain + p_ltp;
let xq_out = lpc_residual_q + p_lpc_neg;
if lpc_order > 1 {
state.slpc.copy_within(0..lpc_order - 1, 1);
}
if lpc_order > 0 {
state.slpc[0] = xq_out;
}
if ltp_buf_len > 1 {
state.sltp.copy_within(0..ltp_buf_len - 1, 1);
}
if ltp_buf_len > 0 {
state.sltp[0] = lpc_residual_q;
}
let quant_err = err - float_exc * safe_gain;
state.slf_ar_shp = xq_out;
state.slf_ma_shp = quant_err;
excitation.push(float_exc);
}
excitation
}
const TRELLIS_N: usize = 4;
const TRELLIS_K: usize = 5;
pub fn trellis_process_subframe(
input: &[f32],
lpc_coeffs: &[f32],
ltp_coeffs: &[f32; 5],
ltp_lag: usize,
quant_gain: f32,
signal_type: SilkSignalType,
state: &mut NsqState,
) -> Vec<i32> {
let n = input.len();
let lpc_order = lpc_coeffs.len().min(state.slpc.len());
let ltp_buf_len = state.sltp.len();
let safe_gain = if quant_gain.abs() > 1e-9 {
quant_gain
} else {
1e-9
};
let inv_gain = 1.0f32 / safe_gain;
let (ar_coeff, ma_coeff) = match signal_type {
SilkSignalType::Voiced => (0.04f32, -0.03f32),
SilkSignalType::Unvoiced => (0.02f32, -0.02f32),
SilkSignalType::Inactive => (0.005f32, -0.005f32),
};
let lambda_base: f64 = match signal_type {
SilkSignalType::Voiced => 2e-12,
SilkSignalType::Unvoiced => 1.5e-12,
SilkSignalType::Inactive => 1e-12,
};
let gain_sq = f64::from(safe_gain) * f64::from(safe_gain);
let lambda = lambda_base * gain_sq.max(1e-20);
let mut paths: Vec<TrellisPath> = (0..TRELLIS_N)
.map(|_| TrellisPath {
cost: 0.0,
slpc: state.slpc.clone(),
sltp: state.sltp.clone(),
slf_ar_shp: state.slf_ar_shp,
slf_ma_shp: state.slf_ma_shp,
pulses: Vec::with_capacity(n),
})
.collect();
let mut candidates: Vec<(f64, usize, i32)> = Vec::with_capacity(TRELLIS_N * TRELLIS_K);
for t in 0..n {
candidates.clear();
for (pi, path) in paths.iter().enumerate() {
let mut p_lpc_neg = 0.0f32;
for k in 0..lpc_order {
p_lpc_neg -= lpc_coeffs[k] * path.slpc[k];
}
let mut p_ltp = 0.0f32;
if ltp_lag >= 3 && ltp_buf_len > ltp_lag + 3 {
for k in 0..5usize {
let src_idx = ltp_lag.saturating_sub(3).saturating_add(k);
if src_idx < ltp_buf_len {
p_ltp += ltp_coeffs[k] * path.sltp[src_idx];
}
}
}
let p_ar = ar_coeff * path.slf_ar_shp;
let p_ma = ma_coeff * path.slf_ma_shp;
let prediction = p_ltp + p_lpc_neg + p_ar + p_ma;
let err = input[t] - prediction;
let center = (err * inv_gain * E_RAW_SCALE).round() as i32;
let center = center.clamp(-2047, 2047);
let prev_pulse = path.pulses.last().copied().unwrap_or(0);
for delta in -2i32..=2i32 {
let pulse = (center + delta).clamp(-2047, 2047);
let float_exc = pulse as f32 / E_RAW_SCALE;
let reconstructed = float_exc * safe_gain + prediction;
let d = f64::from(input[t] - reconstructed).powi(2);
let sign_change = if pulse != 0 && pulse.signum() != prev_pulse.signum() {
1.0f64
} else {
0.0
};
let r = (pulse.unsigned_abs() as f64).powf(0.55) + sign_change;
let cand_cost = path.cost + d + lambda * r;
candidates.push((cand_cost, pi, pulse));
}
}
candidates.sort_unstable_by(|a, b| {
a.0.partial_cmp(&b.0)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.2.unsigned_abs().cmp(&b.2.unsigned_abs()))
});
let mut new_paths: Vec<TrellisPath> = Vec::with_capacity(TRELLIS_N);
let mut used_pulses = std::collections::HashSet::with_capacity(TRELLIS_N);
let mut used = 0usize;
for &(cand_cost, pi, pulse) in candidates.iter() {
if used >= TRELLIS_N {
break;
}
if used_pulses.contains(&pulse) {
continue; }
used_pulses.insert(pulse);
let src = &paths[pi];
let float_exc = pulse as f32 / E_RAW_SCALE;
let mut p_lpc_neg_here = 0.0f32;
for k in 0..lpc_order {
p_lpc_neg_here -= lpc_coeffs[k] * src.slpc[k];
}
let mut p_ltp_here = 0.0f32;
if ltp_lag >= 3 && ltp_buf_len > ltp_lag + 3 {
for k in 0..5usize {
let src_idx = ltp_lag.saturating_sub(3).saturating_add(k);
if src_idx < ltp_buf_len {
p_ltp_here += ltp_coeffs[k] * src.sltp[src_idx];
}
}
}
let p_ar_here = ar_coeff * src.slf_ar_shp;
let p_ma_here = ma_coeff * src.slf_ma_shp;
let prediction_here = p_ltp_here + p_lpc_neg_here + p_ar_here + p_ma_here;
let err_here = input[t] - prediction_here;
let lpc_residual_q = float_exc * safe_gain + p_ltp_here;
let xq_out = lpc_residual_q + p_lpc_neg_here;
let quant_err = err_here - float_exc * safe_gain;
let mut new_slpc = src.slpc.clone();
let mut new_sltp = src.sltp.clone();
if lpc_order > 1 {
new_slpc.copy_within(0..lpc_order - 1, 1);
}
if lpc_order > 0 {
new_slpc[0] = xq_out;
}
if ltp_buf_len > 1 {
new_sltp.copy_within(0..ltp_buf_len - 1, 1);
}
if ltp_buf_len > 0 {
new_sltp[0] = lpc_residual_q;
}
let mut new_pulses = src.pulses.clone();
new_pulses.push(pulse);
new_paths.push(TrellisPath {
cost: cand_cost,
slpc: new_slpc,
sltp: new_sltp,
slf_ar_shp: xq_out,
slf_ma_shp: quant_err,
pulses: new_pulses,
});
used += 1;
}
for &(cand_cost, pi, pulse) in candidates.iter() {
if used >= TRELLIS_N {
break;
}
let src = &paths[pi];
let float_exc = pulse as f32 / E_RAW_SCALE;
let mut p_lpc_neg_here = 0.0f32;
for k in 0..lpc_order {
p_lpc_neg_here -= lpc_coeffs[k] * src.slpc[k];
}
let mut p_ltp_here = 0.0f32;
if ltp_lag >= 3 && ltp_buf_len > ltp_lag + 3 {
for k in 0..5usize {
let src_idx = ltp_lag.saturating_sub(3).saturating_add(k);
if src_idx < ltp_buf_len {
p_ltp_here += ltp_coeffs[k] * src.sltp[src_idx];
}
}
}
let p_ar_here = ar_coeff * src.slf_ar_shp;
let p_ma_here = ma_coeff * src.slf_ma_shp;
let prediction_here = p_ltp_here + p_lpc_neg_here + p_ar_here + p_ma_here;
let err_here = input[t] - prediction_here;
let lpc_residual_q = float_exc * safe_gain + p_ltp_here;
let xq_out = lpc_residual_q + p_lpc_neg_here;
let quant_err = err_here - float_exc * safe_gain;
let mut new_slpc = src.slpc.clone();
let mut new_sltp = src.sltp.clone();
if lpc_order > 1 {
new_slpc.copy_within(0..lpc_order - 1, 1);
}
if lpc_order > 0 {
new_slpc[0] = xq_out;
}
if ltp_buf_len > 1 {
new_sltp.copy_within(0..ltp_buf_len - 1, 1);
}
if ltp_buf_len > 0 {
new_sltp[0] = lpc_residual_q;
}
let mut new_pulses = src.pulses.clone();
new_pulses.push(pulse);
new_paths.push(TrellisPath {
cost: cand_cost,
slpc: new_slpc,
sltp: new_sltp,
slf_ar_shp: xq_out,
slf_ma_shp: quant_err,
pulses: new_pulses,
});
used += 1;
}
paths = new_paths;
}
let winner = paths
.into_iter()
.min_by(|a, b| {
a.cost
.partial_cmp(&b.cost)
.unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or_else(|| {
TrellisPath {
cost: 0.0,
slpc: state.slpc.clone(),
sltp: state.sltp.clone(),
slf_ar_shp: state.slf_ar_shp,
slf_ma_shp: state.slf_ma_shp,
pulses: vec![0; n],
}
});
state.slpc = winner.slpc;
state.sltp = winner.sltp;
state.slf_ar_shp = winner.slf_ar_shp;
state.slf_ma_shp = winner.slf_ma_shp;
winner.pulses
}
pub fn segmental_snr_db(original: &[f32], reconstructed: &[f32], frame_len: usize) -> f32 {
if frame_len == 0 || original.is_empty() || reconstructed.is_empty() {
return 0.0;
}
let len = original.len().min(reconstructed.len());
let mut total_snr_db = 0.0f64;
let mut frames_counted = 0usize;
let mut offset = 0usize;
while offset + frame_len <= len {
let orig = &original[offset..offset + frame_len];
let recon = &reconstructed[offset..offset + frame_len];
let sig_e: f64 = orig.iter().map(|&s| f64::from(s) * f64::from(s)).sum();
let noise_e: f64 = orig
.iter()
.zip(recon.iter())
.map(|(&o, &r)| {
let d = f64::from(o) - f64::from(r);
d * d
})
.sum();
if sig_e < 1e-10 {
offset += frame_len;
continue;
}
let frame_snr = if noise_e < 1e-30 {
120.0
} else {
10.0 * (sig_e / noise_e).log10()
};
total_snr_db += frame_snr;
frames_counted += 1;
offset += frame_len;
}
if frames_counted == 0 {
return 0.0;
}
(total_snr_db / frames_counted as f64) as f32
}
#[cfg(test)]
mod tests {
use super::super::packet::OpusBandwidth;
use super::super::silk::{SilkDecoder, SilkEncoder};
use super::*;
#[test]
fn test_warped_lpc_zero_input() {
let a = vec![0.0f32; 8];
let out = silk_warped_lpc_analysis_filter(&a, 0.26);
assert_eq!(out.len(), 8);
for v in &out {
assert!(v.abs() < 1e-9, "expected zero, got {v}");
}
}
#[test]
fn test_warped_lpc_impulse_nonzero() {
let mut a = vec![0.0f32; 8];
a[0] = 1.0;
let out = silk_warped_lpc_analysis_filter(&a, 0.26);
assert_eq!(out.len(), 8);
let norm: f32 = out.iter().map(|v| v * v).sum::<f32>().sqrt();
assert!(norm > 0.1, "warped output norm should be > 0.1, got {norm}");
let any_nonzero = out.iter().any(|v| v.abs() > 1e-6);
assert!(
any_nonzero,
"warped impulse should produce at least one non-zero coefficient"
);
}
#[test]
fn test_bandwidth_expand_known() {
let a = vec![1.0f32, -0.5f32];
let out = bandwidth_expand(&a, 0.9);
assert_eq!(out.len(), 2);
assert!((out[0] - 0.9).abs() < 0.001, "out[0] = {}", out[0]);
assert!((out[1] - (-0.405)).abs() < 0.001, "out[1] = {}", out[1]);
}
#[test]
fn test_nsq_roundtrip_white_noise_snr_positive() {
const SR: u32 = 16000;
const FRAME: usize = 320; let mut encoder = SilkEncoder::new(SR, 1, OpusBandwidth::Wideband);
let mut decoder = SilkDecoder::new(SR, 1, OpusBandwidth::Wideband);
let mut buf = vec![0u8; 4096];
let mut out = vec![0.0f32; FRAME];
let silence = vec![0.0f32; FRAME];
for _ in 0..4 {
let _ = encoder.encode(&silence, &mut buf, FRAME);
}
let mut seed: u32 = 0xDEAD_BEEF;
let noise: Vec<f32> = (0..FRAME * 5)
.map(|_| {
seed = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
((seed >> 1) as f32) / (i32::MAX as f32) - 0.5
})
.collect();
let mut all_orig = Vec::new();
let mut all_recon = Vec::new();
for k in 0..5 {
let sl = &noise[k * FRAME..(k + 1) * FRAME];
let n = encoder.encode(sl, &mut buf, FRAME).expect("enc");
decoder.decode(&buf[..n], &mut out, FRAME).expect("dec");
all_orig.extend_from_slice(sl);
all_recon.extend_from_slice(&out);
}
let snr = segmental_snr_db(&all_orig, &all_recon, FRAME);
println!("white noise SNR = {snr:.2} dB");
assert!(snr > 0.0, "white noise SNR = {snr:.2} dB, expected > 0 dB");
}
#[test]
fn test_nsq_snr_440hz_tone() {
const SR: u32 = 16000;
const FRAME: usize = 320;
let mut encoder = SilkEncoder::new(SR, 1, OpusBandwidth::Wideband);
let mut decoder = SilkDecoder::new(SR, 1, OpusBandwidth::Wideband);
let mut buf = vec![0u8; 4096];
let mut out = vec![0.0f32; FRAME];
let silence = vec![0.0f32; FRAME];
for _ in 0..6 {
let _ = encoder.encode(&silence, &mut buf, FRAME);
}
let tone: Vec<f32> = (0..FRAME * 8)
.map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / SR as f32).sin() * 0.5)
.collect();
let mut all_orig = Vec::new();
let mut all_recon = Vec::new();
for k in 0..8 {
let sl = &tone[k * FRAME..(k + 1) * FRAME];
let n = encoder.encode(sl, &mut buf, FRAME).expect("enc");
decoder.decode(&buf[..n], &mut out, FRAME).expect("dec");
all_orig.extend_from_slice(sl);
all_recon.extend_from_slice(&out);
}
let snr = segmental_snr_db(&all_orig, &all_recon, FRAME);
println!("440 Hz tone SNR = {snr:.2} dB");
assert!(
snr > 4.5,
"440 Hz tone SNR = {snr:.2} dB (expected > 4.5 dB)"
);
}
#[test]
fn test_nsq_snr_1khz_tone() {
const SR: u32 = 16000;
const FRAME: usize = 320;
let mut encoder = SilkEncoder::new(SR, 1, OpusBandwidth::Wideband);
let mut decoder = SilkDecoder::new(SR, 1, OpusBandwidth::Wideband);
let mut buf = vec![0u8; 4096];
let mut out = vec![0.0f32; FRAME];
let silence = vec![0.0f32; FRAME];
for _ in 0..6 {
let _ = encoder.encode(&silence, &mut buf, FRAME);
}
let tone: Vec<f32> = (0..FRAME * 8)
.map(|i| (2.0 * std::f32::consts::PI * 1000.0 * i as f32 / SR as f32).sin() * 0.5)
.collect();
let mut all_orig = Vec::new();
let mut all_recon = Vec::new();
for k in 0..8 {
let sl = &tone[k * FRAME..(k + 1) * FRAME];
let n = encoder.encode(sl, &mut buf, FRAME).expect("enc");
decoder.decode(&buf[..n], &mut out, FRAME).expect("dec");
all_orig.extend_from_slice(sl);
all_recon.extend_from_slice(&out);
}
let snr = segmental_snr_db(&all_orig, &all_recon, FRAME);
println!("1 kHz tone NSQ SNR (trellis): {snr:.2} dB");
assert!(snr.is_finite(), "1 kHz NSQ SNR must be finite: {snr}");
assert!(snr > 0.0, "1 kHz NSQ SNR = {snr:.2} dB (expected > 0 dB)");
}
#[test]
fn test_nsq_mode_default_is_trellis() {
let state = NsqState::default();
assert_eq!(state.mode, NsqMode::TrellisDelDec);
}
#[test]
fn test_nsq_snr_440hz_trellis() {
const SR: u32 = 16000;
const FRAME: usize = 320;
let mut encoder = SilkEncoder::new(SR, 1, OpusBandwidth::Wideband);
let mut decoder = SilkDecoder::new(SR, 1, OpusBandwidth::Wideband);
let mut buf = vec![0u8; 4096];
let mut out = vec![0.0f32; FRAME];
let silence = vec![0.0f32; FRAME];
for _ in 0..6 {
let _ = encoder.encode(&silence, &mut buf, FRAME);
}
let tone: Vec<f32> = (0..FRAME * 8)
.map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / SR as f32).sin() * 0.5)
.collect();
let mut all_orig = Vec::new();
let mut all_recon = Vec::new();
for k in 0..8 {
let sl = &tone[k * FRAME..(k + 1) * FRAME];
let n = encoder.encode(sl, &mut buf, FRAME).expect("enc");
decoder.decode(&buf[..n], &mut out, FRAME).expect("dec");
all_orig.extend_from_slice(sl);
all_recon.extend_from_slice(&out);
}
let snr = segmental_snr_db(&all_orig, &all_recon, FRAME);
println!("440 Hz trellis SNR = {snr:.2} dB");
assert!(
snr >= 4.5,
"440 Hz trellis SNR = {snr:.2} dB (expected >= 4.5 dB)"
);
}
#[test]
fn test_nsq_snr_1khz_trellis() {
const SR: u32 = 16000;
const FRAME: usize = 320;
let mut encoder = SilkEncoder::new(SR, 1, OpusBandwidth::Wideband);
let mut decoder = SilkDecoder::new(SR, 1, OpusBandwidth::Wideband);
let mut buf = vec![0u8; 4096];
let mut out = vec![0.0f32; FRAME];
let silence = vec![0.0f32; FRAME];
for _ in 0..6 {
let _ = encoder.encode(&silence, &mut buf, FRAME);
}
let tone: Vec<f32> = (0..FRAME * 8)
.map(|i| (2.0 * std::f32::consts::PI * 1000.0 * i as f32 / SR as f32).sin() * 0.5)
.collect();
let mut all_orig = Vec::new();
let mut all_recon = Vec::new();
for k in 0..8 {
let sl = &tone[k * FRAME..(k + 1) * FRAME];
let n = encoder.encode(sl, &mut buf, FRAME).expect("enc");
decoder.decode(&buf[..n], &mut out, FRAME).expect("dec");
all_orig.extend_from_slice(sl);
all_recon.extend_from_slice(&out);
}
let snr = segmental_snr_db(&all_orig, &all_recon, FRAME);
println!("1 kHz trellis SNR = {snr:.2} dB");
assert!(snr.is_finite(), "1 kHz trellis SNR must be finite");
assert!(
snr >= 3.0,
"1 kHz trellis SNR = {snr:.2} dB (expected >= 3 dB structural floor)"
);
}
#[test]
fn test_nsq_snr_white_noise_trellis() {
const SR: u32 = 16000;
const FRAME: usize = 320;
let mut encoder = SilkEncoder::new(SR, 1, OpusBandwidth::Wideband);
let mut decoder = SilkDecoder::new(SR, 1, OpusBandwidth::Wideband);
let mut buf = vec![0u8; 4096];
let mut out = vec![0.0f32; FRAME];
let silence = vec![0.0f32; FRAME];
for _ in 0..4 {
let _ = encoder.encode(&silence, &mut buf, FRAME);
}
let mut seed: u32 = 0xCAFE_BABE;
let noise: Vec<f32> = (0..FRAME * 5)
.map(|_| {
seed = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
((seed >> 1) as f32 / i32::MAX as f32) - 0.5
})
.collect();
let mut all_orig = Vec::new();
let mut all_recon = Vec::new();
for k in 0..5 {
let sl = &noise[k * FRAME..(k + 1) * FRAME];
let n = encoder.encode(sl, &mut buf, FRAME).expect("enc");
decoder.decode(&buf[..n], &mut out, FRAME).expect("dec");
all_orig.extend_from_slice(sl);
all_recon.extend_from_slice(&out);
}
let snr = segmental_snr_db(&all_orig, &all_recon, FRAME);
println!("White noise trellis SNR = {snr:.2} dB");
assert!(snr.is_finite(), "white noise trellis SNR must be finite");
assert!(
snr >= 0.0,
"white noise trellis SNR = {snr:.2} dB (expected ≥ 0 dB)"
);
}
#[test]
fn test_nsq_trellis_vs_greedy() {
const N: usize = 80; let gain = 0.005f32;
let sig_scale = gain * 5.0 / E_RAW_SCALE;
let mut seed: u32 = 0xDEAD_C0DE;
let mut lcg = || -> f32 {
seed = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
((seed >> 1) as f32 / i32::MAX as f32) - 0.5
};
let mut trellis_wins = 0usize;
for _ in 0..100 {
let lpc: Vec<f32> = (0..10).map(|_| lcg() * 0.05).collect();
let ltp = [0.0f32; 5];
let input: Vec<f32> = (0..N).map(|_| lcg() * sig_scale).collect();
let mut st = NsqState::new(10, 100);
st.mode = NsqMode::TrellisDelDec;
let trellis_exc = process_subframe(
&input,
&lpc,
<p,
0,
gain,
SilkSignalType::Unvoiced,
&mut st,
);
let mut sg = NsqState::new(10, 100);
sg.mode = NsqMode::Greedy;
let greedy_exc = process_subframe(
&input,
&lpc,
<p,
0,
gain,
SilkSignalType::Unvoiced,
&mut sg,
);
let trellis_dist: f64 = input
.iter()
.zip(&trellis_exc)
.map(|(&s, &e)| {
let d = f64::from(s) - f64::from(e * gain);
d * d
})
.sum();
let greedy_dist: f64 = input
.iter()
.zip(&greedy_exc)
.map(|(&s, &e)| {
let d = f64::from(s) - f64::from(e * gain);
d * d
})
.sum();
let threshold = greedy_dist * 1.01 + 1e-30;
assert!(
trellis_dist <= threshold,
"trellis dist {trellis_dist:.2e} > greedy dist {greedy_dist:.2e}"
);
if trellis_dist < greedy_dist * 0.999 {
trellis_wins += 1;
}
}
assert!(
trellis_wins >= 60,
"trellis only beat greedy {trellis_wins}/100 times (need ≥ 60)"
);
}
#[test]
fn test_nsq_mode_is_copy() {
let mode = NsqMode::TrellisDelDec;
let mode2 = mode; assert_eq!(mode, mode2);
}
}