use crate::celt_band_layout::CELT_NUM_BANDS;
/// Number of `lm` values covered by the tables in this file: the outermost
/// dimension of `E_PROB_MODEL` and the length of the inter-mode coefficient
/// arrays. Lookups accept `lm` in `0..E_PROB_MODEL_LM_COUNT`.
pub const E_PROB_MODEL_LM_COUNT: usize = 4;
/// Number of prediction modes (inter and intra): the middle dimension of
/// `E_PROB_MODEL`.
pub const E_PROB_MODEL_MODE_COUNT: usize = 2;
/// Index of the inter-mode row inside `E_PROB_MODEL[lm]`.
pub const E_PROB_MODEL_MODE_INTER: usize = 0;
/// Index of the intra-mode row inside `E_PROB_MODEL[lm]`.
pub const E_PROB_MODEL_MODE_INTRA: usize = 1;
/// Bytes stored per band in a row: `prob` followed by `decay`.
pub const E_PROB_MODEL_BYTES_PER_BAND: usize = 2;
/// Length in bytes of one `(lm, mode)` row: one byte pair per band.
pub const E_PROB_MODEL_BYTES_PER_ROW: usize = CELT_NUM_BANDS * E_PROB_MODEL_BYTES_PER_BAND;
/// Total number of bytes held by `E_PROB_MODEL` across every `lm` and mode.
pub const E_PROB_MODEL_TOTAL_BYTES: usize =
    E_PROB_MODEL_LM_COUNT * E_PROB_MODEL_MODE_COUNT * E_PROB_MODEL_BYTES_PER_ROW;
/// Intra-mode `beta` coefficient as a Q15 numerator (`4915 / Q15_ONE`, about 0.15).
// NOTE(review): the unit tests attribute this value to "the RFC" — presumably
// RFC 6716 (Opus); confirm against the specification.
pub const INTRA_PRED_BETA_Q15: u16 = 4915;
/// Denominator of every `*_Q15` coefficient: a raw value `v` stands for
/// `v / Q15_ONE`.
pub const Q15_ONE: u32 = 32768;
/// Intra-mode `alpha` coefficient as a Q15 numerator; zero.
pub const INTRA_PRED_ALPHA_Q15: u16 = 0;
/// Inter-mode `alpha` coefficients as Q15 numerators, indexed by `lm`.
pub const INTER_PRED_ALPHA_Q15: [u16; E_PROB_MODEL_LM_COUNT] = [29440, 26112, 21248, 16384];
/// Inter-mode `beta` coefficients as Q15 numerators, indexed by `lm`.
pub const INTER_PRED_BETA_Q15: [u16; E_PROB_MODEL_LM_COUNT] = [30147, 22282, 12124, 6554];
/// One band's entry in the energy probability model: the two consecutive
/// bytes stored for that band in an `E_PROB_MODEL` row.
// NOTE(review): the field names suggest the parameters of a Laplace-style
// entropy model (probability of zero and decay rate) — confirm against the
// decoder that consumes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EProbPair {
    /// First byte of the band's pair (row offset `band * 2`).
    pub prob: u8,
    /// Second byte of the band's pair (row offset `band * 2 + 1`).
    pub decay: u8,
}
/// Selects which prediction coefficients and which probability-model row to
/// use for a lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnergyPredictionMode {
    /// Inter mode: coefficients depend on `lm`; table row
    /// `E_PROB_MODEL_MODE_INTER`.
    Inter,
    /// Intra mode: coefficients are the same for every `lm`; table row
    /// `E_PROB_MODEL_MODE_INTRA`.
    Intra,
}
impl EnergyPredictionMode {
    /// Maps an intra flag to a prediction mode: `true` selects
    /// [`EnergyPredictionMode::Intra`], `false` selects
    /// [`EnergyPredictionMode::Inter`].
    pub const fn from_intra_flag(intra_flag: bool) -> Self {
        match intra_flag {
            true => Self::Intra,
            false => Self::Inter,
        }
    }

    /// Returns the index of this mode's row within `E_PROB_MODEL[lm]`.
    pub const fn table_index(self) -> usize {
        // Kept as an exhaustive match so that adding a variant forces this
        // mapping to be revisited.
        match self {
            Self::Intra => E_PROB_MODEL_MODE_INTRA,
            Self::Inter => E_PROB_MODEL_MODE_INTER,
        }
    }
}
/// A pair of energy prediction coefficients stored as Q15 numerators
/// (each raw value `v` stands for `v / Q15_ONE`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EnergyPredCoef {
    /// Raw Q15 numerator of the `alpha` coefficient.
    pub alpha_q15: u16,
    /// Raw Q15 numerator of the `beta` coefficient.
    pub beta_q15: u16,
}
impl EnergyPredCoef {
    /// Converts a raw Q15 numerator into the real value `raw / Q15_ONE`.
    fn q15_ratio(raw: u16) -> f64 {
        let numerator: f64 = raw.into();
        let denominator: f64 = Q15_ONE.into();
        numerator / denominator
    }

    /// Returns `alpha` as a floating-point value (`alpha_q15 / Q15_ONE`).
    pub fn alpha(self) -> f64 {
        Self::q15_ratio(self.alpha_q15)
    }

    /// Returns `beta` as a floating-point value (`beta_q15 / Q15_ONE`).
    pub fn beta(self) -> f64 {
        Self::q15_ratio(self.beta_q15)
    }
}
/// Looks up the prediction coefficients for a given `lm` and mode.
///
/// Inter mode reads the per-`lm` entries of `INTER_PRED_ALPHA_Q15` and
/// `INTER_PRED_BETA_Q15`; intra mode always yields `INTRA_PRED_ALPHA_Q15` and
/// `INTRA_PRED_BETA_Q15`.
///
/// # Errors
///
/// Returns [`EProbModelError::LmOutOfRange`] when
/// `lm >= E_PROB_MODEL_LM_COUNT`. The check applies in both modes, even
/// though the intra coefficients do not depend on `lm`.
pub fn energy_pred_coef(
    lm: u32,
    mode: EnergyPredictionMode,
) -> Result<EnergyPredCoef, EProbModelError> {
    if lm < E_PROB_MODEL_LM_COUNT as u32 {
        // Range-checked above, so the widening cast cannot go out of bounds.
        let idx = lm as usize;
        let (alpha_q15, beta_q15) = match mode {
            EnergyPredictionMode::Intra => (INTRA_PRED_ALPHA_Q15, INTRA_PRED_BETA_Q15),
            EnergyPredictionMode::Inter => (INTER_PRED_ALPHA_Q15[idx], INTER_PRED_BETA_Q15[idx]),
        };
        Ok(EnergyPredCoef {
            alpha_q15,
            beta_q15,
        })
    } else {
        Err(EProbModelError::LmOutOfRange { lm })
    }
}
/// Energy probability model table, indexed as `E_PROB_MODEL[lm][mode][byte]`.
///
/// * `lm` ranges over `0..E_PROB_MODEL_LM_COUNT`.
/// * `mode` is `E_PROB_MODEL_MODE_INTER` (0) or `E_PROB_MODEL_MODE_INTRA` (1).
/// * Each row holds `E_PROB_MODEL_BYTES_PER_ROW` bytes: for band `b`, byte
///   `2 * b` is the `prob` value and byte `2 * b + 1` is the `decay` value.
// NOTE(review): the unit tests refer to these rows as "csv rows" — presumably
// the data was transcribed from an external table; confirm the source before
// editing any value.
pub const E_PROB_MODEL: [[[u8; E_PROB_MODEL_BYTES_PER_ROW]; E_PROB_MODEL_MODE_COUNT];
    E_PROB_MODEL_LM_COUNT] = [
    // lm = 0
    [
        // inter
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128, 64, 128, 92, 78, 92, 79,
            92, 78, 90, 79, 116, 41, 115, 40, 114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10,
            177, 11,
        ],
        // intra
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132, 55, 132, 61, 114, 70,
            96, 74, 88, 75, 88, 87, 74, 89, 66, 91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43,
            78, 50,
        ],
    ],
    // lm = 1
    [
        // inter
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74, 93, 74, 109, 40, 114, 36, 117,
            34, 117, 34, 143, 17, 145, 18, 146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177,
            9,
        ],
        // intra
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91, 73, 91, 78, 89, 86, 80, 92,
            66, 93, 64, 102, 59, 103, 60, 104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77,
            45,
        ],
    ],
    // lm = 2
    [
        // inter
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38, 112, 38, 124, 26, 132, 27,
            136, 19, 140, 20, 155, 14, 159, 16, 158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9,
            159, 10,
        ],
        // intra
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73, 87, 72, 92, 75, 98, 72, 105,
            58, 107, 54, 115, 52, 114, 55, 112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77,
            42,
        ],
    ],
    // lm = 3
    [
        // inter
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36, 119, 33, 127, 33, 134,
            34, 139, 21, 147, 23, 152, 20, 158, 25, 154, 26, 166, 21, 173, 16, 184, 13, 184, 10,
            150, 13, 139, 15,
        ],
        // intra
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72, 96, 67, 101, 73, 107, 72,
            113, 55, 118, 52, 125, 52, 118, 52, 117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97,
            33, 77, 40,
        ],
    ],
];
/// Errors returned by the lookups in this file when an index is outside the
/// range covered by the tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EProbModelError {
    /// The requested `lm` is not a valid first index into the tables.
    LmOutOfRange { lm: u32 },
    /// The requested `band` is not a valid band index for a table row.
    BandOutOfRange { band: u32 },
}

impl std::fmt::Display for EProbModelError {
    /// Writes a short human-readable description that includes the rejected
    /// index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::LmOutOfRange { lm } => {
                write!(f, "LM index {lm} is out of range for the energy probability model")
            }
            Self::BandOutOfRange { band } => {
                write!(f, "band index {band} is out of range for the energy probability model")
            }
        }
    }
}

// Implementing the standard error trait lets callers propagate this type with
// `?` into `Box<dyn std::error::Error>` and similar error containers.
impl std::error::Error for EProbModelError {}
/// Looks up the `(prob, decay)` byte pair for one band of one `(lm, mode)`
/// row of `E_PROB_MODEL`.
///
/// # Errors
///
/// * [`EProbModelError::LmOutOfRange`] when `lm >= E_PROB_MODEL_LM_COUNT`.
/// * [`EProbModelError::BandOutOfRange`] when `band >= CELT_NUM_BANDS`.
///
/// When both indices are invalid, the LM error is the one reported.
pub fn e_prob_pair(
    lm: u32,
    mode: EnergyPredictionMode,
    band: u32,
) -> Result<EProbPair, EProbModelError> {
    let lm_ok = lm < E_PROB_MODEL_LM_COUNT as u32;
    let band_ok = band < CELT_NUM_BANDS as u32;
    match (lm_ok, band_ok) {
        // An invalid LM takes precedence over an invalid band.
        (false, _) => Err(EProbModelError::LmOutOfRange { lm }),
        (true, false) => Err(EProbModelError::BandOutOfRange { band }),
        (true, true) => {
            let row = &E_PROB_MODEL[lm as usize][mode.table_index()];
            // Each band owns two consecutive bytes: prob first, then decay.
            let prob_at = (band as usize) * E_PROB_MODEL_BYTES_PER_BAND;
            let decay_at = prob_at + 1;
            Ok(EProbPair {
                prob: row[prob_at],
                decay: row[decay_at],
            })
        }
    }
}
pub fn e_prob_row(
lm: u32,
mode: EnergyPredictionMode,
) -> Result<&'static [u8; E_PROB_MODEL_BYTES_PER_ROW], EProbModelError> {
if lm >= E_PROB_MODEL_LM_COUNT as u32 {
return Err(EProbModelError::LmOutOfRange { lm });
}
Ok(&E_PROB_MODEL[lm as usize][mode.table_index()])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both prediction modes, in table-index order.
    const BOTH_MODES: [EnergyPredictionMode; 2] =
        [EnergyPredictionMode::Inter, EnergyPredictionMode::Intra];

    /// Builds the expected `(prob, decay)` pair for a spot check.
    fn expected_pair(prob: u8, decay: u8) -> EProbPair {
        EProbPair { prob, decay }
    }

    // ---- Table shape -----------------------------------------------------

    #[test]
    fn table_shape_constants_match_struct() {
        let shape = [
            (E_PROB_MODEL_LM_COUNT, 4),
            (E_PROB_MODEL_MODE_COUNT, 2),
            (E_PROB_MODEL_BYTES_PER_BAND, 2),
            (E_PROB_MODEL_BYTES_PER_ROW, 42),
            (E_PROB_MODEL_TOTAL_BYTES, 336),
        ];
        for (actual, expected) in shape {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn table_inner_row_length_matches_band_count_times_two() {
        for (lm, mode, row) in E_PROB_MODEL.iter().enumerate().flat_map(|(lm, by_lm)| {
            by_lm
                .iter()
                .enumerate()
                .map(move |(mode, row)| (lm, mode, row))
        }) {
            let row_len = row.len();
            assert_eq!(
                row_len, E_PROB_MODEL_BYTES_PER_ROW,
                "(lm={lm},mode={mode}) inner row length mismatch"
            );
            assert_eq!(
                row_len,
                CELT_NUM_BANDS * 2,
                "row should be 21 bands × 2 bytes"
            );
        }
    }

    #[test]
    fn table_total_bytes_matches_lm_times_mode_times_row() {
        let mut total: usize = 0;
        for by_lm in E_PROB_MODEL.iter() {
            for row in by_lm.iter() {
                total += row.len();
            }
        }
        assert_eq!(total, E_PROB_MODEL_TOTAL_BYTES);
    }

    // ---- Prediction coefficient constants ----------------------------------

    #[test]
    fn intra_alpha_is_zero_per_rfc() {
        let alpha = INTRA_PRED_ALPHA_Q15;
        assert_eq!(alpha, 0);
    }

    #[test]
    fn intra_beta_is_4915_over_32768_per_rfc() {
        let (numerator, denominator) = (INTRA_PRED_BETA_Q15, Q15_ONE);
        assert_eq!(numerator, 4915);
        assert_eq!(denominator, 32768);
    }

    #[test]
    fn inter_alpha_q15_values_per_appendix_a() {
        let expected: [u16; 4] = [29440, 26112, 21248, 16384];
        assert_eq!(INTER_PRED_ALPHA_Q15, expected);
    }

    #[test]
    fn inter_beta_q15_values_per_appendix_a() {
        let expected: [u16; 4] = [30147, 22282, 12124, 6554];
        assert_eq!(INTER_PRED_BETA_Q15, expected);
    }

    #[test]
    fn inter_alpha_lm3_is_exactly_one_half() {
        let doubled = 2 * u32::from(INTER_PRED_ALPHA_Q15[3]);
        assert_eq!(doubled, Q15_ONE);
    }

    #[test]
    fn inter_coefficients_strictly_decrease_with_frame_size() {
        // Walk adjacent (LM, LM + 1) pairs of both tables in lock-step.
        for (lm, (alpha, beta)) in INTER_PRED_ALPHA_Q15
            .windows(2)
            .zip(INTER_PRED_BETA_Q15.windows(2))
            .enumerate()
        {
            let next = lm + 1;
            assert!(
                alpha[0] > alpha[1],
                "alpha should strictly decrease between LM={lm} and LM={}",
                next
            );
            assert!(
                beta[0] > beta[1],
                "beta should strictly decrease between LM={lm} and LM={}",
                next
            );
        }
    }

    #[test]
    fn inter_beta_always_exceeds_intra_beta() {
        for inter_beta in INTER_PRED_BETA_Q15.iter().copied() {
            assert!(inter_beta > INTRA_PRED_BETA_Q15);
        }
    }

    // ---- energy_pred_coef --------------------------------------------------

    #[test]
    fn energy_pred_coef_inter_matches_tables_for_every_lm() {
        for (lm, (&alpha, &beta)) in INTER_PRED_ALPHA_Q15
            .iter()
            .zip(INTER_PRED_BETA_Q15.iter())
            .enumerate()
        {
            let coef = energy_pred_coef(lm as u32, EnergyPredictionMode::Inter).unwrap();
            assert_eq!(coef.alpha_q15, alpha);
            assert_eq!(coef.beta_q15, beta);
        }
    }

    #[test]
    fn energy_pred_coef_intra_is_lm_independent() {
        let expected = EnergyPredCoef {
            alpha_q15: 0,
            beta_q15: 4915,
        };
        for lm in 0..E_PROB_MODEL_LM_COUNT as u32 {
            let coef = energy_pred_coef(lm, EnergyPredictionMode::Intra).unwrap();
            assert_eq!(coef, expected);
        }
    }

    #[test]
    fn energy_pred_coef_rejects_lm_out_of_range_in_both_modes() {
        for mode in BOTH_MODES {
            for bad_lm in [4, u32::MAX] {
                let err = energy_pred_coef(bad_lm, mode).unwrap_err();
                assert_eq!(err, EProbModelError::LmOutOfRange { lm: bad_lm });
            }
        }
    }

    #[test]
    fn energy_pred_coef_float_views_match_q15_ratios() {
        let inter_lm3 = energy_pred_coef(3, EnergyPredictionMode::Inter).unwrap();
        assert_eq!(inter_lm3.alpha(), 0.5);
        assert_eq!(inter_lm3.beta(), 6554.0 / 32768.0);

        let intra_lm0 = energy_pred_coef(0, EnergyPredictionMode::Intra).unwrap();
        assert_eq!(intra_lm0.alpha(), 0.0);
        assert_eq!(intra_lm0.beta(), 4915.0 / 32768.0);
    }

    #[test]
    fn inter_q15_approximations_documented_in_doc_comments() {
        // (alpha, beta) decimal approximations, one entry per LM.
        let documented = [
            (0.898, 0.920),
            (0.797, 0.680),
            (0.648, 0.370),
            (0.500, 0.200),
        ];
        let one = f64::from(Q15_ONE);
        for (lm, &(alpha_doc, beta_doc)) in documented.iter().enumerate() {
            let a = f64::from(INTER_PRED_ALPHA_Q15[lm]) / one;
            let b = f64::from(INTER_PRED_BETA_Q15[lm]) / one;
            assert!((a - alpha_doc).abs() < 5e-4, "alpha LM={lm}");
            assert!((b - beta_doc).abs() < 5e-4, "beta LM={lm}");
        }
    }

    // ---- EnergyPredictionMode ----------------------------------------------

    #[test]
    fn intra_flag_true_routes_to_intra() {
        let mode = EnergyPredictionMode::from_intra_flag(true);
        assert_eq!(mode, EnergyPredictionMode::Intra);
    }

    #[test]
    fn intra_flag_false_routes_to_inter() {
        let mode = EnergyPredictionMode::from_intra_flag(false);
        assert_eq!(mode, EnergyPredictionMode::Inter);
    }

    #[test]
    fn mode_table_indices_match_csv_layout() {
        let inter = EnergyPredictionMode::Inter.table_index();
        let intra = EnergyPredictionMode::Intra.table_index();
        assert_eq!(inter, 0);
        assert_eq!(intra, 1);
        assert_eq!(inter, E_PROB_MODEL_MODE_INTER);
        assert_eq!(intra, E_PROB_MODEL_MODE_INTRA);
    }

    // ---- Spot checks of individual table cells -------------------------------

    #[test]
    fn csv_row_0_lm0_inter_first_pair_band_0() {
        let found = e_prob_pair(0, EnergyPredictionMode::Inter, 0).unwrap();
        assert_eq!(found, expected_pair(72, 127));
    }

    #[test]
    fn csv_row_0_lm0_inter_last_pair_band_20() {
        let found = e_prob_pair(0, EnergyPredictionMode::Inter, 20).unwrap();
        assert_eq!(found, expected_pair(177, 11));
    }

    #[test]
    fn csv_row_1_lm0_intra_first_pair_band_0() {
        let found = e_prob_pair(0, EnergyPredictionMode::Intra, 0).unwrap();
        assert_eq!(found, expected_pair(24, 179));
    }

    #[test]
    fn csv_row_3_lm1_intra_band_5() {
        let found = e_prob_pair(1, EnergyPredictionMode::Intra, 5).unwrap();
        assert_eq!(found, expected_pair(74, 89));
    }

    #[test]
    fn csv_row_4_lm2_inter_band_10() {
        let found = e_prob_pair(2, EnergyPredictionMode::Inter, 10).unwrap();
        assert_eq!(found, expected_pair(136, 19));
    }

    #[test]
    fn csv_row_6_lm3_inter_first_pair_band_0() {
        let found = e_prob_pair(3, EnergyPredictionMode::Inter, 0).unwrap();
        assert_eq!(found, expected_pair(42, 121));
    }

    #[test]
    fn csv_row_7_lm3_intra_last_pair_band_20() {
        let found = e_prob_pair(3, EnergyPredictionMode::Intra, 20).unwrap();
        assert_eq!(found, expected_pair(77, 40));
    }

    // ---- Error paths ---------------------------------------------------------

    #[test]
    fn e_prob_pair_rejects_lm_out_of_range() {
        let cases = [
            (4, EnergyPredictionMode::Inter),
            (u32::MAX, EnergyPredictionMode::Intra),
        ];
        for (bad_lm, mode) in cases {
            let err = e_prob_pair(bad_lm, mode, 0).unwrap_err();
            assert_eq!(err, EProbModelError::LmOutOfRange { lm: bad_lm });
        }
    }

    #[test]
    fn e_prob_pair_rejects_band_out_of_range() {
        let cases = [
            (0, EnergyPredictionMode::Inter, 21),
            (2, EnergyPredictionMode::Intra, 100),
        ];
        for (lm, mode, bad_band) in cases {
            let err = e_prob_pair(lm, mode, bad_band).unwrap_err();
            assert_eq!(err, EProbModelError::BandOutOfRange { band: bad_band });
        }
    }

    #[test]
    fn e_prob_row_returns_full_42_byte_row() {
        let row = e_prob_row(0, EnergyPredictionMode::Inter).unwrap();
        assert_eq!(row.len(), 42);
        // First and last byte pairs of the LM=0 inter row.
        for (offset, expected) in [(0, 72), (1, 127), (40, 177), (41, 11)] {
            assert_eq!(row[offset], expected);
        }
    }

    #[test]
    fn e_prob_row_rejects_lm_out_of_range() {
        let outcome = e_prob_row(99, EnergyPredictionMode::Inter);
        assert_eq!(outcome.unwrap_err(), EProbModelError::LmOutOfRange { lm: 99 });
    }

    // ---- Exhaustive sweeps ---------------------------------------------------

    #[test]
    fn every_lm_mode_band_lookup_succeeds() {
        for lm in 0..E_PROB_MODEL_LM_COUNT as u32 {
            for mode in BOTH_MODES {
                for band in 0..CELT_NUM_BANDS as u32 {
                    if let Err(e) = e_prob_pair(lm, mode, band) {
                        panic!("lookup failed for (lm={lm},mode={mode:?},band={band}): {e:?}")
                    }
                }
            }
        }
    }

    #[test]
    fn pair_lookup_matches_row_lookup_for_every_cell() {
        for lm in 0..E_PROB_MODEL_LM_COUNT as u32 {
            for mode in BOTH_MODES {
                let row = e_prob_row(lm, mode).unwrap();
                // Each two-byte chunk of the row is one band's (prob, decay).
                for (band, cell) in row.chunks_exact(2).enumerate() {
                    let looked_up = e_prob_pair(lm, mode, band as u32).unwrap();
                    assert_eq!(
                        looked_up.prob, cell[0],
                        "(lm={lm},mode={mode:?},band={band}) prob mismatch"
                    );
                    assert_eq!(
                        looked_up.decay, cell[1],
                        "(lm={lm},mode={mode:?},band={band}) decay mismatch"
                    );
                }
            }
        }
    }

    #[test]
    fn intra_rows_have_lower_band0_probability_than_inter() {
        for lm in 0..E_PROB_MODEL_LM_COUNT as u32 {
            let inter_prob = e_prob_pair(lm, EnergyPredictionMode::Inter, 0)
                .unwrap()
                .prob;
            let intra_prob = e_prob_pair(lm, EnergyPredictionMode::Intra, 0)
                .unwrap()
                .prob;
            assert!(
                intra_prob < inter_prob,
                "(lm={lm}) intra band-0 prob {} should be < inter band-0 prob {}",
                intra_prob,
                inter_prob
            );
        }
    }
}