use oxideav_celt::range_decoder::RangeDecoder;
use oxideav_celt::range_encoder::RangeEncoder;
use oxideav_core::Result;
use crate::silk::tables;
use crate::toc::OpusBandwidth;
/// Returns the inclusive `(min_lag, max_lag)` pitch-lag range, in samples,
/// for the given bandwidth.
///
/// Both bounds scale with the number of samples per millisecond (8 for
/// narrowband, 12 for mediumband, 16 for wideband): the minimum is 2x and the
/// maximum 18x that figure. Every bandwidth wider than mediumband uses the
/// wideband range.
pub fn pitch_lag_bounds(bw: OpusBandwidth) -> (i32, i32) {
    let samples_per_ms: i32 = match bw {
        OpusBandwidth::Narrowband => 8,
        OpusBandwidth::Mediumband => 12,
        // Wideband and anything wider.
        _ => 16,
    };
    (2 * samples_per_ms, 18 * samples_per_ms)
}
/// Decodes an absolutely-coded primary pitch lag.
///
/// A "high" and a "low" symbol are read, in that order, with the narrowband
/// tables regardless of `bw`. `bw` only selects the offset and the clamp range
/// from [`pitch_lag_bounds`]. The result is `min_lag + 4 * high + low`,
/// clamped to `min_lag..=max_lag`.
///
/// This currently always returns `Ok`.
pub fn decode_absolute_pitch_lag(rc: &mut RangeDecoder<'_>, bw: OpusBandwidth) -> Result<i32> {
    let (lo_bound, hi_bound) = pitch_lag_bounds(bw);
    let coarse = rc.decode_icdf(&tables::PITCH_LAG_NB_HIGH_ICDF, 8) as i32;
    let fine = rc.decode_icdf(&tables::PITCH_LAG_NB_LOW_ICDF, 8) as i32;
    // `coarse` is non-negative, so shifting left by 2 equals multiplying by 4.
    let unclamped = lo_bound + (coarse << 2) + fine;
    Ok(unclamped.clamp(lo_bound, hi_bound))
}
/// Decodes a relative pitch-lag delta.
///
/// The symbol read from the delta table is re-centred by subtracting 9.
/// `encode_primary_pitch_lag` only writes symbols `0..=20`, so its deltas come
/// back in `-9..=11`. The caller is presumably responsible for adding the
/// delta to the previous lag.
///
/// This currently always returns `Ok`.
pub fn decode_delta_pitch_lag(rc: &mut RangeDecoder<'_>) -> Result<i32> {
    const DELTA_BIAS: i32 = 9;
    let symbol = rc.decode_icdf(&tables::PITCH_DELTA_ICDF, 8) as i32;
    Ok(symbol - DELTA_BIAS)
}
/// Decodes the pitch-contour codebook index.
///
/// The narrowband 20 ms contour table is always used; `_bw` is currently
/// ignored. This always returns `Ok`.
pub fn decode_pitch_contour(rc: &mut RangeDecoder<'_>, _bw: OpusBandwidth) -> Result<usize> {
    let contour_idx = rc.decode_icdf(&tables::PITCH_CONTOUR_NB_20MS_ICDF, 8);
    Ok(contour_idx)
}
/// Fills `lags` with the per-subframe pitch lags derived from `primary_lag`.
///
/// Contour offsets are not applied yet (`_contour_idx` is ignored). Every
/// entry receives `primary_lag` clamped to the [`pitch_lag_bounds`] for `bw`.
pub fn expand_pitch_contour(
    primary_lag: i32,
    _contour_idx: usize,
    bw: OpusBandwidth,
    lags: &mut [i32],
) {
    let (lo, hi) = pitch_lag_bounds(bw);
    let clamped = primary_lag.clamp(lo, hi);
    lags.fill(clamped);
}
/// Decodes an LTP filter index for the given periodicity class and returns
/// the corresponding five taps as `f32`.
///
/// Periodicity 0 and 1 select their own ICDF tables; any larger value uses
/// the periodicity-2 table. The taps come from [`ltp_filter_from_index`].
pub fn decode_ltp_filter(rc: &mut RangeDecoder<'_>, periodicity: usize) -> [f32; 5] {
    let table: &[u8] = if periodicity == 0 {
        &tables::LTP_FILTER_P0_ICDF
    } else if periodicity == 1 {
        &tables::LTP_FILTER_P1_ICDF
    } else {
        &tables::LTP_FILTER_P2_ICDF
    };
    let codebook_idx = rc.decode_icdf(table, 8);
    ltp_filter_from_index(codebook_idx, periodicity)
}
/// LTP filter codebook for periodicity class 0: 8 entries of five taps each,
/// stored in Q7 (`ltp_filter_from_index` divides by 128 to get the real tap).
// NOTE(review): the values appear to match libopus `silk_LTP_gain_vq_0`
// (RFC 6716 LTP filter codebooks); confirm against the spec before editing.
const LTP_P0_Q7: [[i8; 5]; 8] = [
    [4, 6, 24, 7, 5],
    [0, 0, 2, 0, 0],
    [12, 28, 41, 13, -4],
    [-9, 15, 42, 25, 14],
    [1, -2, 62, 41, -9],
    [-10, 37, 65, -4, 3],
    [-6, 4, 66, 7, -8],
    [16, 14, 38, -3, 33],
];
/// LTP filter codebook for periodicity class 1: 16 entries of five taps each,
/// stored in Q7 (`ltp_filter_from_index` divides by 128 to get the real tap).
// NOTE(review): the values appear to match libopus `silk_LTP_gain_vq_1`
// (RFC 6716 LTP filter codebooks); confirm against the spec before editing.
const LTP_P1_Q7: [[i8; 5]; 16] = [
    [13, 22, 39, 23, 12],
    [-1, 36, 64, 27, -6],
    [-7, 10, 55, 43, 17],
    [1, 1, 8, 1, 1],
    [6, -11, 74, 53, -9],
    [-12, 55, 76, -12, 8],
    [-3, 3, 93, 27, -4],
    [26, 39, 59, 3, -8],
    [2, 0, 77, 11, 9],
    [-8, 22, 44, -6, 7],
    [40, 9, 26, 3, 9],
    [-7, 20, 101, -7, 4],
    [3, -8, 42, 26, 0],
    [-15, 33, 68, 2, 23],
    [-2, 55, 46, -2, 15],
    [3, -1, 21, 16, 41],
];
/// LTP filter codebook for periodicity class 2 (and higher): 32 entries of
/// five taps each, stored in Q7 (`ltp_filter_from_index` divides by 128).
// NOTE(review): the values appear to match libopus `silk_LTP_gain_vq_2`
// (RFC 6716 LTP filter codebooks); confirm against the spec before editing.
const LTP_P2_Q7: [[i8; 5]; 32] = [
    [-6, 27, 61, 39, 5],
    [-11, 42, 88, 4, 1],
    [-2, 60, 65, 6, -4],
    [-1, -5, 73, 56, 1],
    [-9, 19, 94, 29, -9],
    [0, 12, 99, 6, 4],
    [8, -19, 102, 46, -13],
    [3, 2, 13, 3, 2],
    [9, -21, 84, 72, -18],
    [-11, 46, 104, -22, 8],
    [18, 38, 48, 23, 0],
    [-16, 70, 83, -21, 11],
    [5, -11, 117, 22, -8],
    [-6, 23, 117, -12, 3],
    [3, -8, 95, 28, 4],
    [-10, 15, 77, 60, -15],
    [-1, 4, 124, 2, -4],
    [3, 38, 84, 24, -25],
    [2, 13, 42, 13, 31],
    [21, -4, 56, 46, -1],
    [-1, 35, 79, -13, 19],
    [-7, 65, 88, -9, -14],
    [20, 4, 81, 49, -29],
    [20, 0, 75, 3, -17],
    [5, -9, 44, 92, -8],
    [1, -3, 22, 69, 31],
    [-6, 95, 41, -12, 5],
    [39, 67, 16, -4, 1],
    [0, -6, 120, 55, -36],
    [-13, 44, 122, 4, -24],
    [81, 5, 11, 3, 7],
    [2, 0, 9, 10, 88],
];
/// Converts an LTP codebook index into five `f32` taps (each Q7 value / 128).
///
/// `periodicity` 0, 1 and 2+ select the 8-, 16- and 32-entry codebooks.
/// An out-of-range `idx` is clamped to the last entry instead of panicking.
pub fn ltp_filter_from_index(idx: usize, periodicity: usize) -> [f32; 5] {
    let codebook: &[[i8; 5]] = match periodicity {
        0 => &LTP_P0_Q7,
        1 => &LTP_P1_Q7,
        _ => &LTP_P2_Q7,
    };
    let entry = codebook[idx.min(codebook.len() - 1)];
    entry.map(|q7| f32::from(q7) / 128.0)
}
/// Returns the number of entries in the LTP codebook for a periodicity class:
/// 8, 16 or 32 for periodicity 0, 1 and 2+ respectively.
pub fn ltp_filter_index_count(periodicity: usize) -> usize {
    const CODEBOOK_SIZES: [usize; 3] = [8, 16, 32];
    CODEBOOK_SIZES[periodicity.min(2)]
}
/// Encodes the primary pitch lag, either relative to `prev_lag` or absolutely.
///
/// `lag` is first clamped to [`pitch_lag_bounds`]. A 1-bit flag (`logp = 1`)
/// then selects the mode. The lag is coded absolutely when `prev_lag == 0` or
/// when the delta from `prev_lag` lies outside `-9..=11`, and relatively
/// otherwise.
///
/// * Absolute: `lag - min_lag` is clamped to `0..=127` and split into a high
///   part (`>> 2`, masked to 5 bits) and a 2-bit low part. Both are coded with
///   the narrowband tables whatever `bw` is. Lags more than 127 above
///   `min_lag` therefore cannot be represented exactly; for example, NB lag
///   144 is sent as 143.
/// * Relative: `delta + 9` (a value in `0..=20`) is coded with the delta table.
pub fn encode_primary_pitch_lag(
    enc: &mut RangeEncoder,
    bw: OpusBandwidth,
    lag: i32,
    prev_lag: i32,
) {
    let (min_lag, max_lag) = pitch_lag_bounds(bw);
    let target = lag.clamp(min_lag, max_lag);
    let delta = target - prev_lag;
    let absolute = prev_lag == 0 || delta < -9 || delta > 11;
    enc.encode_bit_logp(absolute, 1);

    if !absolute {
        let symbol = (delta + 9).clamp(0, 20) as usize;
        enc.encode_icdf(symbol, &tables::PITCH_DELTA_ICDF, 8);
        return;
    }

    let offset = (target - min_lag).clamp(0, 127);
    let high = ((offset >> 2) & 0x1f) as usize;
    let low = (offset & 0x3) as usize;
    enc.encode_icdf(high, &tables::PITCH_LAG_NB_HIGH_ICDF, 8);
    enc.encode_icdf(low, &tables::PITCH_LAG_NB_LOW_ICDF, 8);
}
/// Encodes the pitch contour.
///
/// Contour index 0 is always written with the narrowband 20 ms table; `_bw`
/// is currently ignored.
pub fn encode_pitch_contour(enc: &mut RangeEncoder, _bw: OpusBandwidth) {
    const CONTOUR_INDEX: usize = 0;
    enc.encode_icdf(CONTOUR_INDEX, &tables::PITCH_CONTOUR_NB_20MS_ICDF, 8);
}
/// Encodes the LTP scaling factor, given in Q14.
///
/// Only exact matches are recognised: 15565 maps to index 0 and 12288 to
/// index 1. Every other value, including 8192, is coded as index 2.
pub fn encode_ltp_scaling(enc: &mut RangeEncoder, scale_q14: i32) {
    const EXPLICIT_SCALES_Q14: [i32; 2] = [15565, 12288];
    let idx = EXPLICIT_SCALES_Q14
        .iter()
        .position(|&candidate| candidate == scale_q14)
        .unwrap_or(2);
    enc.encode_icdf(idx, &tables::LTP_SCALING_ICDF, 8);
}
/// Encodes the LTP periodicity class. Values above 2 are saturated to 2.
pub fn encode_ltp_periodicity(enc: &mut RangeEncoder, periodicity: usize) {
    let class = if periodicity > 2 { 2 } else { periodicity };
    enc.encode_icdf(class, &tables::LTP_PERIODICITY_ICDF, 8);
}
/// Encodes an LTP filter (codebook) index for the given periodicity class.
///
/// Periodicity 0 and 1 select their own ICDF tables; any larger value uses
/// the periodicity-2 table. `idx` is clamped to `table.len() - 1`, so an
/// out-of-range index cannot be written.
pub fn encode_ltp_filter_index(enc: &mut RangeEncoder, periodicity: usize, idx: usize) {
    let table: &[u8] = if periodicity == 0 {
        &tables::LTP_FILTER_P0_ICDF
    } else if periodicity == 1 {
        &tables::LTP_FILTER_P1_ICDF
    } else {
        &tables::LTP_FILTER_P2_ICDF
    };
    let last_symbol = table.len() - 1;
    enc.encode_icdf(idx.min(last_symbol), table, 8);
}
/// Maps a correlation linearly onto the LTP codebook indices of a
/// periodicity class.
///
/// `correlation` is first clamped to `[0, 1]`. A value of 0.0 maps to index 0
/// and 1.0 to the last index, with rounding to the nearest index in between.
pub fn pick_ltp_filter_index(correlation: f32, periodicity: usize) -> usize {
    let count = ltp_filter_index_count(periodicity);
    let last = count - 1;
    let scaled = correlation.clamp(0.0, 1.0) * (count as f32 - 1.0);
    let rounded = scaled.round() as usize;
    if rounded > last {
        last
    } else {
        rounded
    }
}
pub fn pick_ltp_filter_from_history(
pcm: &[f32],
ltp_history: &[f32],
lag: i32,
periodicity: usize,
) -> usize {
if lag <= 2 || ltp_history.is_empty() {
return 0;
}
let hist_len = ltp_history.len() as i32;
let n_hist = ((lag - 2) as usize).min(pcm.len());
if n_hist == 0 {
return 0;
}
let mut xcorr = [0.0f32; 5];
let mut lag_energy = [0.0f32; 5];
for n in 0..n_hist {
let xn = pcm[n];
for k in 0..5 {
let abs_j = hist_len + n as i32 - lag + 2 - k as i32;
let h = if abs_j >= 0 && (abs_j as usize) < ltp_history.len() {
ltp_history[abs_j as usize]
} else {
0.0
};
xcorr[k] += xn * h;
lag_energy[k] += h * h;
}
}
let n_cand = ltp_filter_index_count(periodicity);
let mut best_idx = 0usize;
let mut best_score = f32::NEG_INFINITY;
for idx in 0..n_cand {
let row: [i8; 5] = match periodicity {
0 => LTP_P0_Q7[idx],
1 => LTP_P1_Q7[idx],
_ => LTP_P2_Q7[idx],
};
let taps: [f32; 5] = [
row[0] as f32 / 128.0,
row[1] as f32 / 128.0,
row[2] as f32 / 128.0,
row[3] as f32 / 128.0,
row[4] as f32 / 128.0,
];
let b_xcorr: f32 = taps.iter().zip(xcorr.iter()).map(|(b, x)| b * x).sum();
let b_energy: f32 = taps
.iter()
.zip(lag_energy.iter())
.map(|(b, e)| b * b * e)
.sum();
let score = if b_energy > 1e-9 {
(b_xcorr * b_xcorr) / b_energy
} else {
0.0
};
if score > best_score {
best_score = score;
best_idx = idx;
}
}
best_idx
}
// Round-trip tests. Each encoder helper is paired with the matching decoder
// call, and the decoded value is checked against the encoded input.
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the parent module already imports `RangeEncoder`; this
    // explicit import looks redundant with `use super::*` but is harmless.
    use oxideav_celt::range_encoder::RangeEncoder;

    // prev_lag = 0 forces the absolute path. All lags lie in 16..=143, the NB
    // range that absolute coding can represent exactly.
    #[test]
    fn absolute_pitch_lag_roundtrip_nb() {
        for lag in [16, 20, 40, 80, 120, 143] {
            let mut enc = RangeEncoder::new(64);
            encode_primary_pitch_lag(&mut enc, OpusBandwidth::Narrowband, lag, 0);
            let buf = enc.done().unwrap();
            let mut dec = RangeDecoder::new(&buf);
            // Mode flag written by encode_primary_pitch_lag before the lag.
            let abs_flag = dec.decode_bit_logp(1);
            assert!(abs_flag, "expected abs flag for prev_lag=0");
            let got = decode_absolute_pitch_lag(&mut dec, OpusBandwidth::Narrowband).unwrap();
            assert_eq!(got, lag, "NB lag {lag} did not round-trip (got {got})");
        }
    }

    // Every delta inside the relative window -9..=11 must take the relative
    // path and decode back to the same delta.
    #[test]
    fn delta_pitch_lag_roundtrip() {
        let prev = 80i32;
        for delta in -9..=11 {
            let lag = prev + delta;
            let mut enc = RangeEncoder::new(64);
            encode_primary_pitch_lag(&mut enc, OpusBandwidth::Narrowband, lag, prev);
            let buf = enc.done().unwrap();
            let mut dec = RangeDecoder::new(&buf);
            let abs_flag = dec.decode_bit_logp(1);
            assert!(!abs_flag, "expected delta for delta={delta}");
            let d = decode_delta_pitch_lag(&mut dec).unwrap();
            assert_eq!(d, delta, "delta {delta} did not round-trip");
        }
    }

    // A jump of +120 from prev_lag 20 is outside -9..=11, so the absolute
    // path must be used and lag 140 recovered exactly.
    #[test]
    fn out_of_range_delta_uses_abs() {
        let mut enc = RangeEncoder::new(64);
        encode_primary_pitch_lag(&mut enc, OpusBandwidth::Narrowband, 140, 20);
        let buf = enc.done().unwrap();
        let mut dec = RangeDecoder::new(&buf);
        let abs_flag = dec.decode_bit_logp(1);
        assert!(abs_flag);
        let got = decode_absolute_pitch_lag(&mut dec, OpusBandwidth::Narrowband).unwrap();
        assert_eq!(got, 140);
    }

    // For each periodicity class, the first, second, middle and last codebook
    // index must survive encode/decode. The taps looked up from the encoded
    // and decoded indices must also agree.
    #[test]
    fn ltp_filter_index_roundtrip() {
        for periodicity in 0..3 {
            let n = ltp_filter_index_count(periodicity);
            for idx in [0, 1, n / 2, n - 1] {
                let mut enc = RangeEncoder::new(64);
                encode_ltp_filter_index(&mut enc, periodicity, idx);
                let buf = enc.done().unwrap();
                let mut dec = RangeDecoder::new(&buf);
                // Same table selection as encode_ltp_filter_index.
                let icdf: &[u8] = match periodicity {
                    0 => &tables::LTP_FILTER_P0_ICDF,
                    1 => &tables::LTP_FILTER_P1_ICDF,
                    _ => &tables::LTP_FILTER_P2_ICDF,
                };
                let got = dec.decode_icdf(icdf, 8);
                assert_eq!(
                    got, idx,
                    "periodicity {periodicity} filter idx {idx} did not round-trip"
                );
                let enc_taps = ltp_filter_from_index(idx, periodicity);
                let dec_taps = ltp_filter_from_index(got, periodicity);
                for k in 0..5 {
                    assert!((enc_taps[k] - dec_taps[k]).abs() < 1e-6);
                }
            }
        }
    }

    // The index -> Q14 mapping below is the decoder-side inverse of
    // encode_ltp_scaling for its three scale values.
    #[test]
    fn ltp_scaling_roundtrip() {
        for &scale in &[15565i32, 12288, 8192] {
            let mut enc = RangeEncoder::new(32);
            encode_ltp_scaling(&mut enc, scale);
            let buf = enc.done().unwrap();
            let mut dec = RangeDecoder::new(&buf);
            let idx = dec.decode_icdf(&tables::LTP_SCALING_ICDF, 8);
            let got = match idx {
                0 => 15565,
                1 => 12288,
                _ => 8192,
            };
            assert_eq!(got, scale);
        }
    }
}