use crate::range_decoder::RangeDecoder;
/// Post-filter parameters carried in a CELT frame header.
///
/// Filled in by [`CeltHeaderPrefix::decode`] when the post-filter flag bit is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CeltPostFilter {
    /// Octave index; expected in `0..=5` because `decode` reads it with `dec_uint(6)`.
    pub octave: u8,
    /// Pitch period derived via [`CeltPostFilter::pitch_period`] from the octave and fine pitch.
    pub period: u16,
    /// Raw 3-bit gain index (`0..=7`) exactly as read from the stream.
    pub gain_index: u8,
    /// Tapset symbol (`0..=2`) decoded with the tapset inverse CDF.
    pub tapset: u8,
}
impl CeltPostFilter {
    /// Returns the post-filter pitch period `(16 << octave) + fine_pitch - 1`.
    ///
    /// For `octave` in `0..=5` and `fine_pitch < 1 << (4 + octave)` (the width of the raw
    /// field it is read from), the result covers `15..=1022`. Arguments outside that range
    /// are not validated: the `u16` shift/add/subtract may overflow, which panics in debug
    /// builds and wraps in release builds.
    pub fn pitch_period(octave: u8, fine_pitch: u16) -> u16 {
        let octave_base = 16u16 << octave;
        octave_base + fine_pitch - 1
    }
}
/// Leading flags of a CELT frame header, as produced by [`CeltHeaderPrefix::decode`].
///
/// When `silence` is set, `decode` returns before reading anything else, so the other
/// fields are always `None` / `false` in that case.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CeltHeaderPrefix {
    /// True when the silence symbol decoded as 1.
    pub silence: bool,
    /// Post-filter parameters; `Some` only when the post-filter flag bit was set.
    pub post_filter: Option<CeltPostFilter>,
    /// Transient flag, read with `dec_bit_logp(3)`.
    /// NOTE(review): presumably selects short blocks downstream — confirm with the caller.
    pub transient: bool,
    /// Intra flag, read with `dec_bit_logp(3)`.
    /// NOTE(review): presumably the intra coarse-energy flag — confirm with the caller.
    pub intra: bool,
}
/// Precision (log2 of the total frequency) of the silence-flag distribution: 32768.
const SILENCE_FTB: u32 = 15;
/// Inverse CDF of the silence flag, PDF `{32767, 1} / 32768` (symbol 1 = silent).
const SILENCE_ICDF: &[u8] = &[1, 0];
/// Precision (log2 of the total frequency) of the tapset distribution: 4.
const TAPSET_FTB: u32 = 2;
/// Inverse CDF of the post-filter tapset, PDF `{2, 1, 1} / 4`.
const TAPSET_ICDF: &[u8] = &[2, 1, 0];
impl CeltHeaderPrefix {
pub fn decode(rd: &mut RangeDecoder<'_>) -> Self {
let silence = rd.dec_icdf(SILENCE_ICDF, SILENCE_FTB) == 1;
if silence {
return Self {
silence: true,
post_filter: None,
transient: false,
intra: false,
};
}
let post_filter_enabled = rd.dec_bit_logp(1) == 1;
let post_filter = if post_filter_enabled {
let octave = rd.dec_uint(6).unwrap_or(0) as u8;
let raw_bits = 4 + u32::from(octave);
let fine_pitch = rd.dec_bits(raw_bits) as u16;
let period = CeltPostFilter::pitch_period(octave, fine_pitch);
let gain_index = rd.dec_bits(3) as u8;
let tapset = rd.dec_icdf(TAPSET_ICDF, TAPSET_FTB) as u8;
Some(CeltPostFilter {
octave,
period,
gain_index,
tapset,
})
} else {
None
};
let transient = rd.dec_bit_logp(3) == 1;
let intra = rd.dec_bit_logp(3) == 1;
Self {
silence: false,
post_filter,
transient,
intra,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Rebuilds the PDF encoded by an inverse-CDF table. Entry `i` is the drop from the
    /// previous threshold, starting from the total `1 << ftb`.
    fn pdf_from_icdf(icdf: &[u8], ftb: u32) -> Vec<u32> {
        let mut prev = 1u32 << ftb;
        icdf.iter()
            .map(|&threshold| {
                let threshold = u32::from(threshold);
                let p = prev - threshold;
                prev = threshold;
                p
            })
            .collect()
    }

    /// Checks the shape every ICDF table needs: it ends at 0 and is strictly decreasing,
    /// so no symbol gets zero probability.
    fn assert_well_formed_icdf(icdf: &[u8]) {
        assert_eq!(icdf.last(), Some(&0), "ICDF must end at 0: {icdf:?}");
        assert!(
            icdf.windows(2).all(|w| w[0] > w[1]),
            "ICDF must be strictly decreasing: {icdf:?}"
        );
    }

    /// Pads `bytes` with zeros to at least two bytes.
    /// NOTE(review): presumably a minimum length expected by `RangeDecoder::new` — confirm.
    fn buf(bytes: &[u8]) -> Vec<u8> {
        let mut v = bytes.to_vec();
        v.resize(v.len().max(2), 0);
        v
    }

    /// Asserts invariants every decoded prefix must satisfy, regardless of the input bytes.
    fn assert_prefix_consistent(hp: &CeltHeaderPrefix) {
        if hp.silence {
            assert_eq!(hp.post_filter, None, "silent frame carries a post-filter");
            assert!(!hp.transient, "silent frame is flagged transient");
            assert!(!hp.intra, "silent frame is flagged intra");
        }
        if let Some(pf) = hp.post_filter {
            assert!(pf.octave <= 5, "octave {} out of range", pf.octave);
            assert!((15..=1022).contains(&pf.period), "period {} oob", pf.period);
            // The fine pitch has 4 + octave raw bits, so the period must lie inside the
            // span covered by its own octave.
            let lo = CeltPostFilter::pitch_period(pf.octave, 0);
            let hi = CeltPostFilter::pitch_period(pf.octave, (1u16 << (4 + pf.octave)) - 1);
            assert!(
                (lo..=hi).contains(&pf.period),
                "period {} outside octave {}",
                pf.period,
                pf.octave
            );
            assert!(pf.gain_index <= 7, "gain_index {} oob", pf.gain_index);
            assert!(pf.tapset <= 2, "tapset {} oob", pf.tapset);
        }
    }

    #[test]
    fn silence_pdf_sums_to_32768() {
        let pdf = pdf_from_icdf(SILENCE_ICDF, SILENCE_FTB);
        assert_eq!(pdf, [32767, 1]);
        assert_eq!(pdf.iter().sum::<u32>(), 1u32 << SILENCE_FTB);
        assert_well_formed_icdf(SILENCE_ICDF);
    }

    #[test]
    fn tapset_pdf_sums_to_4() {
        let pdf = pdf_from_icdf(TAPSET_ICDF, TAPSET_FTB);
        assert_eq!(pdf, [2, 1, 1]);
        assert_eq!(pdf.iter().sum::<u32>(), 1u32 << TAPSET_FTB);
        assert_well_formed_icdf(TAPSET_ICDF);
    }

    #[test]
    fn pitch_period_minimum_is_15() {
        assert_eq!(CeltPostFilter::pitch_period(0, 0), 15);
    }

    #[test]
    fn pitch_period_maximum_is_1022() {
        assert_eq!(CeltPostFilter::pitch_period(5, 511), 1022);
    }

    #[test]
    fn pitch_period_octave_boundaries() {
        let lower: [u16; 6] = [15, 31, 63, 127, 255, 511];
        let upper: [u16; 6] = [30, 62, 126, 254, 510, 1022];
        for (octave, (&lo, &hi)) in (0u8..).zip(lower.iter().zip(&upper)) {
            let max_fine = (1u16 << (4 + octave)) - 1;
            assert_eq!(CeltPostFilter::pitch_period(octave, 0), lo);
            assert_eq!(CeltPostFilter::pitch_period(octave, max_fine), hi);
        }
    }

    #[test]
    fn pitch_period_octaves_tile_without_gaps() {
        for octave in 0..5u8 {
            let max_fine = (1u16 << (4 + octave)) - 1;
            let top = CeltPostFilter::pitch_period(octave, max_fine);
            let next_bottom = CeltPostFilter::pitch_period(octave + 1, 0);
            assert_eq!(
                top + 1,
                next_bottom,
                "gap between octave {octave} and {}",
                octave + 1
            );
        }
    }

    #[test]
    fn decode_terminates_on_all_zero_buffer() {
        let b = buf(&[0x00; 6]);
        let mut rd = RangeDecoder::new(&b);
        let hp = CeltHeaderPrefix::decode(&mut rd);
        assert_eq!(
            hp,
            CeltHeaderPrefix {
                silence: false,
                post_filter: None,
                transient: false,
                intra: false,
            }
        );
    }

    #[test]
    fn decode_terminates_on_all_ones_buffer() {
        let b = buf(&[0xFF; 8]);
        let mut rd = RangeDecoder::new(&b);
        let hp = CeltHeaderPrefix::decode(&mut rd);
        assert_prefix_consistent(&hp);
    }

    #[test]
    fn decode_advances_tell() {
        let b = buf(&[0x80, 0x00, 0x00, 0x00, 0x00, 0x00]);
        let mut rd = RangeDecoder::new(&b);
        let t0 = rd.tell();
        let _ = CeltHeaderPrefix::decode(&mut rd);
        let t1 = rd.tell();
        assert!(t1 > t0, "tell did not advance: {t0} -> {t1}");
    }

    #[test]
    fn decode_post_filter_field_ranges_swept() {
        for byte in 0..=255u8 {
            let b = buf(&[byte, byte ^ 0xA5, byte.wrapping_add(0x33), 0x00]);
            let mut rd = RangeDecoder::new(&b);
            let hp = CeltHeaderPrefix::decode(&mut rd);
            assert_prefix_consistent(&hp);
        }
    }

    #[test]
    fn silence_shortcircuits_other_symbols() {
        // `decode` returns before reading anything after a silent flag. Every input that
        // decodes as silent must therefore report exactly the all-cleared prefix.
        let silent = CeltHeaderPrefix {
            silence: true,
            post_filter: None,
            transient: false,
            intra: false,
        };
        for byte in 0..=255u8 {
            let b = buf(&[byte, 0xFF, 0xFF, 0xFF]);
            let mut rd = RangeDecoder::new(&b);
            let hp = CeltHeaderPrefix::decode(&mut rd);
            if hp.silence {
                assert_eq!(hp, silent, "leading byte {byte:#04x}");
            }
        }
    }
}