use oxideav_celt::range_decoder::RangeDecoder;
use oxideav_celt::range_encoder::RangeEncoder;
use oxideav_core::Result;
use crate::silk::tables;
use crate::toc::OpusBandwidth;
/// Inclusive `(min, max)` pitch-lag bounds for a bandwidth.
///
/// The bounds are `2 * k` and `18 * k`, where `k` is 8 for narrowband,
/// 12 for mediumband and 16 otherwise. That gives (16, 144), (24, 216) and
/// (32, 288). Every bandwidth other than NB/MB shares the wideband bounds.
pub fn pitch_lag_bounds(bw: OpusBandwidth) -> (i32, i32) {
    let k: i32 = match bw {
        OpusBandwidth::Narrowband => 8,
        OpusBandwidth::Mediumband => 12,
        // Wideband and anything wider.
        _ => 16,
    };
    (2 * k, 18 * k)
}
/// Decodes an absolutely-coded primary pitch lag.
///
/// Reads a "high" symbol and then a "low" symbol. The lag is
/// `min_lag + 4 * high + low`, clamped to [`pitch_lag_bounds`] for `bw`.
///
/// NOTE(review): the narrowband high/low tables and the x4 scale are used
/// for every bandwidth. Only `min_lag` / `max_lag` depend on `bw`.
pub fn decode_absolute_pitch_lag(rc: &mut RangeDecoder<'_>, bw: OpusBandwidth) -> Result<i32> {
    let (lower, upper) = pitch_lag_bounds(bw);
    // Symbol order matters: the high part precedes the low part in the stream.
    let coarse = rc.decode_icdf(&tables::PITCH_LAG_NB_HIGH_ICDF, 8) as i32;
    let fine = rc.decode_icdf(&tables::PITCH_LAG_NB_LOW_ICDF, 8) as i32;
    let unclamped = lower + coarse * 4 + fine;
    Ok(unclamped.clamp(lower, upper))
}
/// Decodes a delta-coded pitch lag relative to the previous lag.
///
/// The decoded symbol is biased by 9, so the returned delta is
/// `symbol - 9`. The caller adds it to the previous lag.
pub fn decode_delta_pitch_lag(rc: &mut RangeDecoder<'_>) -> Result<i32> {
    const DELTA_BIAS: i32 = 9;
    let symbol = rc.decode_icdf(&tables::PITCH_DELTA_ICDF, 8) as i32;
    Ok(symbol - DELTA_BIAS)
}
/// Decodes the pitch-contour codebook index.
///
/// The narrowband 20 ms contour table is always used; the bandwidth
/// argument is currently ignored.
pub fn decode_pitch_contour(rc: &mut RangeDecoder<'_>, _bw: OpusBandwidth) -> Result<usize> {
    let contour = rc.decode_icdf(&tables::PITCH_CONTOUR_NB_20MS_ICDF, 8);
    Ok(contour)
}
/// Expands a primary pitch lag into per-subframe lags.
///
/// Every entry of `lags` is set to `primary_lag` clamped to
/// [`pitch_lag_bounds`] for `bw`. The contour index is currently ignored,
/// so no per-subframe offsets are applied.
pub fn expand_pitch_contour(
    primary_lag: i32,
    _contour_idx: usize,
    bw: OpusBandwidth,
    lags: &mut [i32],
) {
    let (lower, upper) = pitch_lag_bounds(bw);
    lags.fill(primary_lag.clamp(lower, upper));
}
/// Decodes an LTP filter index and returns the corresponding 5 taps.
///
/// `periodicity` selects the ICDF table (0, 1, or anything >= 2). The
/// decoded index is turned into taps by [`ltp_filter_from_index`].
pub fn decode_ltp_filter(rc: &mut RangeDecoder<'_>, periodicity: usize) -> [f32; 5] {
    let table: &[u8] = if periodicity == 0 {
        &tables::LTP_FILTER_P0_ICDF
    } else if periodicity == 1 {
        &tables::LTP_FILTER_P1_ICDF
    } else {
        &tables::LTP_FILTER_P2_ICDF
    };
    ltp_filter_from_index(rc.decode_icdf(table, 8), periodicity)
}
/// Maps an LTP filter index to a symmetric 5-tap filter.
///
/// The taps come from a linear formula in `idx`, not from a codebook
/// table. With `step = (idx - 4) / 32`:
/// - the outer taps are `-0.05 - 0.02 * step`;
/// - the inner taps are fixed at `0.10`;
/// - the centre tap is `0.70 + 0.10 * step`.
///
/// `periodicity` is accepted for API symmetry but currently unused.
/// NOTE(review): presumably a stand-in for the real LTP codebooks — confirm.
pub fn ltp_filter_from_index(idx: usize, periodicity: usize) -> [f32; 5] {
    let _ = periodicity;
    let step = (idx as f32 - 4.0) / 32.0;
    let edge = -0.05 - step * 0.02;
    let centre = 0.70 + step * 0.10;
    let inner = 0.10;
    [edge, inner, centre, inner, edge]
}
/// Number of LTP filter indices available for a periodicity class.
///
/// Returns 8, 16 and 32 for periodicity 0, 1 and 2. Any periodicity above 2
/// is treated as 2.
pub fn ltp_filter_index_count(periodicity: usize) -> usize {
    const CODEBOOK_SIZES: [usize; 3] = [8, 16, 32];
    CODEBOOK_SIZES[periodicity.min(CODEBOOK_SIZES.len() - 1)]
}
/// Encodes the primary pitch lag, choosing between delta and absolute coding.
///
/// `lag` is first clamped to [`pitch_lag_bounds`] for `bw`. A flag bit is
/// then written with `encode_bit_logp(use_abs, 1)`:
/// - `false` (delta coding): the lag is coded as `delta = lag - prev_lag`,
///   using symbol `delta + 9` in `PITCH_DELTA_ICDF`. It is chosen whenever
///   `prev_lag != 0` and `delta` is in `-9..=11`.
/// - `true` (absolute coding): the lag is coded as `offset = lag - min_lag`.
///   It is split into `offset / 4` (`PITCH_LAG_NB_HIGH_ICDF`) followed by
///   `offset % 4` (`PITCH_LAG_NB_LOW_ICDF`).
///
/// `prev_lag == 0` means "no previous lag" and forces absolute coding.
///
/// NOTE: the absolute offset is capped at 127. Any lag more than 127 above
/// `min_lag` is therefore encoded as `min_lag + 127`. For narrowband this
/// turns 144 into 143; for MB/WB it affects the upper part of the range.
/// The same NB tables and x4 scale are used for every bandwidth.
pub fn encode_primary_pitch_lag(
    enc: &mut RangeEncoder,
    bw: OpusBandwidth,
    lag: i32,
    prev_lag: i32,
) {
    let (min_lag, max_lag) = pitch_lag_bounds(bw);
    let lag = lag.clamp(min_lag, max_lag);
    let delta = lag - prev_lag;
    let use_abs = prev_lag == 0 || !(-9..=11).contains(&delta);
    enc.encode_bit_logp(use_abs, 1);
    if use_abs {
        // `lag >= min_lag` after clamping, so the offset is non-negative.
        // Capping it at 127 bounds `offset >> 2` to 0..=31 and
        // `offset & 3` to 0..=3.
        let offset = (lag - min_lag).clamp(0, 127);
        let high = (offset >> 2) as usize;
        let low = (offset & 0x3) as usize;
        enc.encode_icdf(high, &tables::PITCH_LAG_NB_HIGH_ICDF, 8);
        enc.encode_icdf(low, &tables::PITCH_LAG_NB_LOW_ICDF, 8);
    } else {
        // `delta` is in -9..=11 on this branch, so the symbol is in 0..=20.
        let sym = (delta + 9) as usize;
        enc.encode_icdf(sym, &tables::PITCH_DELTA_ICDF, 8);
    }
}
/// Encodes the pitch-contour index.
///
/// Index 0 of the narrowband 20 ms contour table is always written; the
/// bandwidth argument is currently ignored.
pub fn encode_pitch_contour(enc: &mut RangeEncoder, _bw: OpusBandwidth) {
    const FIXED_CONTOUR: usize = 0;
    enc.encode_icdf(FIXED_CONTOUR, &tables::PITCH_CONTOUR_NB_20MS_ICDF, 8);
}
/// Encodes the LTP scaling factor as an index into `LTP_SCALING_ICDF`.
///
/// Index 0, 1 and 2 stand for the Q14 scales 15565, 12288 and 8192.
/// `scale_q14` is quantised to the nearest of these values, so out-of-table
/// inputs such as 16384 map to the closest scale (index 0). They no longer
/// collapse to the smallest one. Exact table values map exactly, and ties
/// resolve to the lower index.
pub fn encode_ltp_scaling(enc: &mut RangeEncoder, scale_q14: i32) {
    const LTP_SCALES_Q14: [i32; 3] = [15565, 12288, 8192];
    // `abs_diff` returns an unsigned distance, so extreme inputs such as
    // `i32::MIN` cannot overflow.
    let idx = (0..LTP_SCALES_Q14.len())
        .min_by_key(|&i| LTP_SCALES_Q14[i].abs_diff(scale_q14))
        .unwrap_or(2);
    enc.encode_icdf(idx, &tables::LTP_SCALING_ICDF, 8);
}
/// Encodes the LTP periodicity class with `LTP_PERIODICITY_ICDF`.
///
/// Any value above 2 is saturated to 2 before encoding.
pub fn encode_ltp_periodicity(enc: &mut RangeEncoder, periodicity: usize) {
    let class = if periodicity > 2 { 2 } else { periodicity };
    enc.encode_icdf(class, &tables::LTP_PERIODICITY_ICDF, 8);
}
/// Encodes an LTP filter index with the ICDF table for `periodicity`.
///
/// The table is chosen as P0, P1, or P2 for any periodicity >= 2. `idx` is
/// clamped to the last position of that table before encoding.
pub fn encode_ltp_filter_index(enc: &mut RangeEncoder, periodicity: usize, idx: usize) {
    let table: &[u8] = if periodicity == 0 {
        &tables::LTP_FILTER_P0_ICDF
    } else if periodicity == 1 {
        &tables::LTP_FILTER_P1_ICDF
    } else {
        &tables::LTP_FILTER_P2_ICDF
    };
    let last = table.len() - 1;
    enc.encode_icdf(idx.min(last), table, 8);
}
/// Picks an LTP filter index by scaling a correlation onto the codebook.
///
/// `correlation` is clamped to `[0, 1]`, scaled by `count - 1` (where
/// `count = ltp_filter_index_count(periodicity)`), rounded, and capped at
/// `count - 1`.
pub fn pick_ltp_filter_index(correlation: f32, periodicity: usize) -> usize {
    let count = ltp_filter_index_count(periodicity);
    let last = count - 1;
    let scaled = correlation.clamp(0.0, 1.0) * (count as f32 - 1.0);
    let rounded = scaled.round() as usize;
    rounded.min(last)
}
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `RangeEncoder`, `RangeDecoder`,
    // `tables` and `OpusBandwidth` imports.
    use super::*;

    /// Absolute coding of in-range narrowband lags must round-trip exactly.
    #[test]
    fn absolute_pitch_lag_roundtrip_nb() {
        for lag in [16, 20, 40, 80, 120, 143] {
            let mut enc = RangeEncoder::new(64);
            encode_primary_pitch_lag(&mut enc, OpusBandwidth::Narrowband, lag, 0);
            let buf = enc.done().unwrap();
            let mut dec = RangeDecoder::new(&buf);
            let abs_flag = dec.decode_bit_logp(1);
            assert!(abs_flag, "expected abs flag for prev_lag=0");
            let got = decode_absolute_pitch_lag(&mut dec, OpusBandwidth::Narrowband).unwrap();
            assert_eq!(got, lag, "NB lag {lag} did not round-trip (got {got})");
        }
    }

    /// Every delta in the delta-coding window must round-trip.
    #[test]
    fn delta_pitch_lag_roundtrip() {
        let prev = 80i32;
        for delta in -9..=11 {
            let lag = prev + delta;
            let mut enc = RangeEncoder::new(64);
            encode_primary_pitch_lag(&mut enc, OpusBandwidth::Narrowband, lag, prev);
            let buf = enc.done().unwrap();
            let mut dec = RangeDecoder::new(&buf);
            let abs_flag = dec.decode_bit_logp(1);
            assert!(!abs_flag, "expected delta for delta={delta}");
            let d = decode_delta_pitch_lag(&mut dec).unwrap();
            assert_eq!(d, delta, "delta {delta} did not round-trip");
        }
    }

    /// A jump outside the delta window must fall back to absolute coding.
    #[test]
    fn out_of_range_delta_uses_abs() {
        let mut enc = RangeEncoder::new(64);
        encode_primary_pitch_lag(&mut enc, OpusBandwidth::Narrowband, 140, 20);
        let buf = enc.done().unwrap();
        let mut dec = RangeDecoder::new(&buf);
        let abs_flag = dec.decode_bit_logp(1);
        assert!(abs_flag);
        let got = decode_absolute_pitch_lag(&mut dec, OpusBandwidth::Narrowband).unwrap();
        assert_eq!(got, 140);
    }

    /// LTP filter indices (first, second, middle, last) must round-trip per periodicity.
    #[test]
    fn ltp_filter_index_roundtrip() {
        for periodicity in 0..3 {
            let n = ltp_filter_index_count(periodicity);
            for idx in [0, 1, n / 2, n - 1] {
                let mut enc = RangeEncoder::new(64);
                encode_ltp_filter_index(&mut enc, periodicity, idx);
                let buf = enc.done().unwrap();
                let mut dec = RangeDecoder::new(&buf);
                let icdf: &[u8] = match periodicity {
                    0 => &tables::LTP_FILTER_P0_ICDF,
                    1 => &tables::LTP_FILTER_P1_ICDF,
                    _ => &tables::LTP_FILTER_P2_ICDF,
                };
                let got = dec.decode_icdf(icdf, 8);
                assert_eq!(
                    got, idx,
                    "periodicity {periodicity} filter idx {idx} did not round-trip"
                );
                let enc_taps = ltp_filter_from_index(idx, periodicity);
                let dec_taps = ltp_filter_from_index(got, periodicity);
                for (e, d) in enc_taps.iter().zip(dec_taps.iter()) {
                    assert!((e - d).abs() < 1e-6);
                }
            }
        }
    }

    /// The three exact LTP scaling values must round-trip through their indices.
    #[test]
    fn ltp_scaling_roundtrip() {
        for &scale in &[15565i32, 12288, 8192] {
            let mut enc = RangeEncoder::new(32);
            encode_ltp_scaling(&mut enc, scale);
            let buf = enc.done().unwrap();
            let mut dec = RangeDecoder::new(&buf);
            let idx = dec.decode_icdf(&tables::LTP_SCALING_ICDF, 8);
            let got = match idx {
                0 => 15565,
                1 => 12288,
                _ => 8192,
            };
            assert_eq!(got, scale);
        }
    }

    /// Every subframe lag is the primary lag clamped to the bandwidth's bounds.
    #[test]
    fn expand_pitch_contour_fills_with_clamped_lag() {
        let mut lags = [0i32; 4];
        expand_pitch_contour(200, 3, OpusBandwidth::Narrowband, &mut lags);
        assert_eq!(lags, [144; 4]);
        expand_pitch_contour(40, 0, OpusBandwidth::Wideband, &mut lags);
        assert_eq!(lags, [40; 4]);
        expand_pitch_contour(5, 0, OpusBandwidth::Mediumband, &mut lags);
        assert_eq!(lags, [24; 4]);
        // An empty slice is a no-op.
        expand_pitch_contour(80, 0, OpusBandwidth::Narrowband, &mut []);
    }

    /// Correlation endpoints (and out-of-range values) map to the codebook ends.
    #[test]
    fn pick_ltp_filter_index_spans_codebook() {
        for periodicity in 0..3 {
            let n = ltp_filter_index_count(periodicity);
            assert_eq!(pick_ltp_filter_index(-1.0, periodicity), 0);
            assert_eq!(pick_ltp_filter_index(0.0, periodicity), 0);
            assert_eq!(pick_ltp_filter_index(1.0, periodicity), n - 1);
            assert_eq!(pick_ltp_filter_index(5.0, periodicity), n - 1);
        }
    }
}