use super::laplace::ec_laplace_decode;
use super::modes::{BETA_COEF, BETA_INTRA, E_PROB_MODEL, MAX_FINE_BITS, NB_EBANDS, PRED_COEF};
use crate::range::RangeDecoder;
/// Inverse CDF of the three-symbol alphabet that `decode_coarse_energy` falls back to when
/// between 2 and 14 bits of budget remain. Decoded symbols 0, 1 and 2 are zig-zag mapped to
/// deltas 0, -1 and +1.
const SMALL_ENERGY_ICDF: [u8; 3] = [2, 1, 0];

/// Band-energy memory shared by the coarse, fine and finalise decoding passes.
///
/// Entries are addressed as `old_ebands[channel][band]`. `decode_coarse_energy` reads the value
/// already stored for a band as its prediction input before overwriting it, so whatever a
/// previous call left here feeds into the next one. `Default` starts both channels at zero.
#[derive(Clone, Debug, Default)]
pub struct EnergyState {
    /// Per-channel (at most two) energy of every band, `[channel][band]`.
    pub old_ebands: [[f32; NB_EBANDS]; 2],
}
/// Decodes the coarse band energies of bands `start..end` into `state.old_ebands`
/// (mirrors the reference `unquant_coarse_energy`).
///
/// Every band/channel pair receives an integer delta `q`, and its stored energy becomes
/// `coef * max(old, -9) + prev[c] + q`:
/// - `old` is the value already held in `state`.
/// - `prev[c]` is a per-channel inter-band accumulator updated as `prev[c] += q - beta * q`.
///
/// In intra mode `coef = 0` and `beta = BETA_INTRA`. Otherwise `PRED_COEF[lm]` / `BETA_COEF[lm]`
/// are used.
///
/// The way `q` is read depends on `budget_bits - dec.tell()`:
/// * at least 15 bits: Laplace-coded, using the `E_PROB_MODEL[lm][intra]` pair for the band
///   (bands above 20 reuse band 20's pair);
/// * 2 to 14 bits: one `SMALL_ENERGY_ICDF` symbol, zig-zag mapped to 0, -1 or +1;
/// * exactly 1 bit: one bit, mapped to 0 or -1;
/// * nothing left: `-1`, without touching the stream.
///
/// Bands are processed in increasing order, channel 0 before channel 1 within each band.
///
/// # Panics
/// Debug builds assert `start <= end <= NB_EBANDS`, `channels` in {1, 2} and `lm < 4`.
#[allow(
    clippy::too_many_arguments,
    reason = "mirrors the reference unquant_coarse_energy signature"
)]
pub fn decode_coarse_energy(
    dec: &mut RangeDecoder,
    state: &mut EnergyState,
    start: usize,
    end: usize,
    intra: bool,
    channels: usize,
    lm: usize,
    budget_bits: u32,
) {
    debug_assert!(end <= NB_EBANDS && start <= end);
    debug_assert!(channels == 1 || channels == 2);
    debug_assert!(lm < 4);

    let model = &E_PROB_MODEL[lm][usize::from(intra)];
    let (alpha, beta) = if intra {
        (0.0, BETA_INTRA)
    } else {
        (PRED_COEF[lm], BETA_COEF[lm])
    };
    let budget = i64::from(budget_bits);
    let mut carry = [0.0f32; 2];

    for band in start..end {
        // Pair each channel's accumulator with its energy row; `take` stops after `channels`.
        for (acc, row) in carry.iter_mut().zip(state.old_ebands.iter_mut()).take(channels) {
            let remaining = budget - i64::from(dec.tell());
            let delta: i32 = match remaining {
                15.. => {
                    let slot = 2 * band.min(20);
                    let fs = u32::from(model[slot]) << 7;
                    let decay = u32::from(model[slot + 1]) << 6;
                    ec_laplace_decode(dec, fs, decay)
                }
                2..=14 => {
                    // Zig-zag: symbol 0 -> 0, 1 -> -1, 2 -> +1.
                    let sym = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2) as i32;
                    (sym >> 1) ^ -(sym & 1)
                }
                1 => -i32::from(dec.decode_bit_logp(1)),
                _ => -1,
            };
            let step = delta as f32;
            let predicted = alpha * row[band].max(-9.0) + *acc;
            row[band] = predicted + step;
            *acc += step - beta * step;
        }
    }
}
/// Adds the fine-energy refinement to each band in `start..end` that has a positive
/// `fine_quant` entry.
///
/// For such a band, let `b = fine_quant[band]`. Each of the first `channels` channels reads `b`
/// raw bits `q` and adds `(q + 0.5) / 2^b - 0.5` to its stored energy.
///
/// Bands with `fine_quant <= 0` read nothing. Band indices at or beyond `fine_quant.len()` are
/// silently ignored.
pub fn decode_fine_energy(
    dec: &mut RangeDecoder,
    state: &mut EnergyState,
    start: usize,
    end: usize,
    fine_quant: &[i32],
    channels: usize,
) {
    // Clamp the band window to the allocation slice; an inverted window yields no bands.
    let stop = end.min(fine_quant.len());
    let window = fine_quant.get(start..stop).unwrap_or(&[]);
    for (offset, &bits) in window.iter().enumerate().filter(|&(_, &bits)| bits > 0) {
        let band = start + offset;
        for c in 0..channels {
            let raw = dec.decode_raw_bits(bits as u32) as f32;
            let levels = (1 << bits) as f32;
            state.old_ebands[c][band] += (raw + 0.5) / levels - 0.5;
        }
    }
}
/// Spends leftover bits on one extra refinement bit per band and channel
/// (mirrors the reference `unquant_energy_finalise`).
///
/// Two passes are made over `start..end`:
/// - The first pass visits bands whose `fine_priority` entry is `false`.
/// - The second pass visits bands whose `fine_priority` entry is `true`.
///
/// Bands whose `fine_quant` is already `MAX_FINE_BITS` or more are skipped. Each visited band
/// reads one raw bit per channel and adds `(bit - 0.5) / 2^(fine_quant + 1)` to the stored
/// energy, charging one bit of `bits_left` per channel. A pass ends as soon as fewer than
/// `channels` bits remain.
#[allow(
    clippy::too_many_arguments,
    reason = "mirrors the reference unquant_energy_finalise signature"
)]
pub fn decode_energy_finalise(
    dec: &mut RangeDecoder,
    state: &mut EnergyState,
    start: usize,
    end: usize,
    fine_quant: &[i32],
    fine_priority: &[bool],
    mut bits_left: i32,
    channels: usize,
) {
    let per_band = channels as i32;
    for wanted in [false, true] {
        for band in start..end {
            // A band is only refined when every channel can get its bit.
            if bits_left < per_band {
                break;
            }
            let depth = fine_quant[band];
            if depth >= MAX_FINE_BITS || fine_priority[band] != wanted {
                continue;
            }
            for c in 0..channels {
                let bit = dec.decode_raw_bits(1) as f32;
                state.old_ebands[c][band] += (bit - 0.5) / (1 << (depth + 1)) as f32;
                bits_left -= 1;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    extern crate alloc;
    use alloc::vec;
    use alloc::vec::Vec;
    use super::*;
    use crate::celt::laplace::ec_laplace_encode;
    use crate::range::{RangeDecoder, RangeEncoder};

    /// Pure-arithmetic model of the coarse-energy recurrence used by `decode_coarse_energy`.
    ///
    /// `deltas` are consumed band-major (all channels of band 0, then band 1, ...) over all
    /// `NB_EBANDS` bands, starting from an all-zero energy state. The result is what the decoder
    /// should produce when it decodes exactly these deltas.
    #[allow(clippy::needless_range_loop, reason = "indices mirror the decoder loop under test")]
    fn predict(deltas: &[i32], lm: usize, intra: bool, channels: usize) -> Vec<[f32; NB_EBANDS]> {
        let (coef, beta) = if intra {
            (0.0, BETA_INTRA)
        } else {
            (PRED_COEF[lm], BETA_COEF[lm])
        };
        let mut out = vec![[0.0f32; NB_EBANDS]; channels];
        let mut prev = [0.0f32; 2];
        let mut it = deltas.iter();
        for i in 0..NB_EBANDS {
            for (c, prev_c) in prev.iter_mut().enumerate().take(channels) {
                let q = *it.next().expect("enough deltas") as f32;
                let old = out[c][i].max(-9.0);
                out[c][i] = coef * old + *prev_c + q;
                *prev_c += q - beta * q;
            }
        }
        out
    }

    /// Laplace-encodes a fixed pattern of small deltas for every `lm` / intra / channel-count
    /// combination. Checks that decoding reproduces both the range-coder state and the
    /// energies computed by `predict`.
    #[test]
    #[allow(clippy::needless_range_loop, reason = "indices mirror the decoder loop under test")]
    fn coarse_energy_round_trip() {
        for lm in 0..4usize {
            for intra in [false, true] {
                for channels in [1usize, 2] {
                    let prob = &E_PROB_MODEL[lm][usize::from(intra)];
                    let deltas: Vec<i32> =
                        (0..NB_EBANDS * channels).map(|i| ((i as i32 * 5) % 7) - 3).collect();
                    let mut enc = RangeEncoder::new(256);
                    let mut coded = Vec::new();
                    let mut it = deltas.iter();
                    for i in 0..NB_EBANDS {
                        for _ in 0..channels {
                            let pi = 2 * i.min(20);
                            coded.push(ec_laplace_encode(
                                &mut enc,
                                *it.next().expect("delta"),
                                u32::from(prob[pi]) << 7,
                                u32::from(prob[pi + 1]) << 6,
                            ));
                        }
                    }
                    assert_eq!(coded, deltas, "small deltas never saturate");
                    let enc_rng = enc.range_size();
                    let buf = enc.finalize().expect("within budget");
                    let mut dec = RangeDecoder::new(&buf);
                    let mut state = EnergyState::default();
                    decode_coarse_energy(
                        &mut dec,
                        &mut state,
                        0,
                        NB_EBANDS,
                        intra,
                        channels,
                        lm,
                        buf.len() as u32 * 8,
                    );
                    assert_eq!(dec.range_size(), enc_rng, "lm={lm} intra={intra} C={channels}");
                    let expected = predict(&deltas, lm, intra, channels);
                    for c in 0..channels {
                        for i in 0..NB_EBANDS {
                            assert!(
                                (state.old_ebands[c][i] - expected[c][i]).abs() < 1e-5,
                                "lm={lm} intra={intra} c={c} band {i}: {} vs {}",
                                state.old_ebands[c][i],
                                expected[c][i]
                            );
                        }
                    }
                }
            }
        }
    }

    /// Decodes every stereo band from a 2-byte buffer with only a 16-bit budget. The only
    /// property checked is that no resulting energy is non-finite.
    #[test]
    fn coarse_energy_respects_starved_budget() {
        let buf = [0xA5u8, 0x3C];
        let mut dec = RangeDecoder::new(&buf);
        let mut state = EnergyState::default();
        decode_coarse_energy(&mut dec, &mut state, 0, NB_EBANDS, false, 2, 3, 16);
        for c in 0..2 {
            for i in 0..NB_EBANDS {
                assert!(state.old_ebands[c][i].is_finite());
            }
        }
    }

    /// A 3-bit raw value of 5 refines band 0 by `(5 + 0.5) / 8 - 0.5 = 0.1875`.
    #[test]
    fn fine_energy_refines_toward_centre() {
        let mut enc = RangeEncoder::new(16);
        enc.encode_raw_bits(5, 3);
        let buf = enc.finalize().expect("fits");
        let mut dec = RangeDecoder::new(&buf);
        let mut state = EnergyState::default();
        let mut fine_quant = vec![0i32; NB_EBANDS];
        fine_quant[0] = 3;
        decode_fine_energy(&mut dec, &mut state, 0, 1, &fine_quant, 1);
        assert!((state.old_ebands[0][0] - 0.1875).abs() < 1e-6);
    }

    /// Band 1 is the only priority-0 band, so the single spare bit must land there. With
    /// `fine_quant = 0`, a 1 bit gives an offset of `(1 - 0.5) / 2 = 0.25`.
    #[test]
    fn finalise_spends_priority_zero_first() {
        let mut enc = RangeEncoder::new(16);
        enc.encode_raw_bits(1, 1);
        let buf = enc.finalize().expect("fits");
        let mut dec = RangeDecoder::new(&buf);
        let mut state = EnergyState::default();
        let fine_quant = vec![0i32; NB_EBANDS];
        let mut fine_priority = vec![true; NB_EBANDS];
        fine_priority[1] = false;
        decode_energy_finalise(&mut dec, &mut state, 0, 2, &fine_quant, &fine_priority, 1, 1);
        assert_eq!(state.old_ebands[0][0], 0.0, "priority-1 band untouched");
        assert!((state.old_ebands[0][1] - 0.25).abs() < 1e-6, "q2=1 at B=0: (1-0.5)/2");
    }

    /// Bands whose `fine_quant` entry is zero or negative carry no fine bits and must be left
    /// unchanged, whatever the stream contains.
    #[test]
    fn fine_energy_skips_unallocated_bands() {
        let buf = [0xFFu8; 4];
        let mut dec = RangeDecoder::new(&buf);
        let mut state = EnergyState::default();
        let mut fine_quant = vec![0i32; NB_EBANDS];
        fine_quant[3] = -2;
        decode_fine_energy(&mut dec, &mut state, 0, NB_EBANDS, &fine_quant, 2);
        assert!(state.old_ebands.iter().flatten().all(|&e| e == 0.0));
    }

    /// Finalisation needs one spare bit per channel. With 1 bit left for a stereo frame, no
    /// band may be touched.
    #[test]
    fn finalise_requires_a_bit_per_channel() {
        let buf = [0xFFu8; 4];
        let mut dec = RangeDecoder::new(&buf);
        let mut state = EnergyState::default();
        let fine_quant = vec![0i32; NB_EBANDS];
        let fine_priority = vec![false; NB_EBANDS];
        decode_energy_finalise(
            &mut dec,
            &mut state,
            0,
            NB_EBANDS,
            &fine_quant,
            &fine_priority,
            1,
            2,
        );
        assert!(state.old_ebands.iter().flatten().all(|&e| e == 0.0));
    }
}