use crate::celt_band_layout::CeltFrameSize;
use crate::celt_log2_frac_table::{log2_frac, Log2FracError};
/// One whole bit expressed in eighth-bit units (the unit every budget in this module uses).
pub const ONE_BIT_EIGHTH_BITS: u32 = 1 << 3;
/// Eighth-bits removed from the post-tell budget before any reservation is made,
/// so the remaining budget is always a conservative estimate.
pub const CONSERVATIVE_DEDUCTION_EIGHTH_BITS: u32 = 1;
/// Eighth-bits in one byte: eight bits, each worth [`ONE_BIT_EIGHTH_BITS`].
pub const EIGHTH_BITS_PER_BYTE: u32 = 8 * ONE_BIT_EIGHTH_BITS;
/// The anti-collapse bit is only considered when the frame-size index (LM) is
/// strictly greater than this value.
pub const ANTI_COLLAPSE_LM_MIN_EXCLUSIVE: u32 = 1;
/// Multiplier applied to `LM + ANTI_COLLAPSE_HEADROOM_LM_OFFSET` to obtain the
/// minimum budget (in eighth-bits) required before the anti-collapse bit is reserved.
pub const ANTI_COLLAPSE_HEADROOM_MULT_EIGHTH_BITS: u32 = ONE_BIT_EIGHTH_BITS;
/// Offset added to LM when computing the anti-collapse headroom threshold.
pub const ANTI_COLLAPSE_HEADROOM_LM_OFFSET: u32 = 2;
/// Reasons why [`reserve_block`] can refuse to compute a reservation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReservationError {
    /// `frame_size_bytes * EIGHTH_BITS_PER_BYTE` does not fit in a `u32`.
    FrameSizeOverflows,
    /// The range-coder position is already past the end of the frame.
    TellExceedsFrame {
        /// Frame size converted to eighth-bits.
        frame_eighth_bits: u32,
        /// Range-coder position supplied by the caller, in eighth-bits.
        ec_tell_frac: u32,
    },
    /// `total_boost` is larger than what is left of the frame after `ec_tell_frac`.
    TotalBoostExceedsFrame {
        /// Frame size converted to eighth-bits.
        frame_eighth_bits: u32,
        /// Range-coder position supplied by the caller, in eighth-bits.
        ec_tell_frac: u32,
        /// Boost total supplied by the caller, in eighth-bits.
        total_boost: u32,
    },
    /// The `log2_frac` lookup used for the intensity reservation failed.
    LogFracLookupFailed(Log2FracError),
}
impl From<Log2FracError> for ReservationError {
fn from(value: Log2FracError) -> Self {
ReservationError::LogFracLookupFailed(value)
}
}
/// Result of the reservation step: each individual reservation plus the budget
/// that is left once all of them have been taken out. All fields are in eighth-bits.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ReservationOutcome {
    /// Reservation for the anti-collapse flag (zero when not reserved).
    pub anti_collapse_rsv: u32,
    /// Reservation for the skip flag (zero when not reserved).
    pub skip_rsv: u32,
    /// Reservation for the intensity-stereo parameter (zero for mono or when it does not fit).
    pub intensity_rsv: u32,
    /// Reservation for the dual-stereo flag (zero for mono or when it does not fit).
    pub dual_stereo_rsv: u32,
    /// Budget remaining after every reservation above has been deducted.
    pub total_remaining_eighth_bits: u32,
}

impl ReservationOutcome {
    /// Sum of all four reservations, saturating at `u32::MAX` instead of overflowing.
    ///
    /// `total_remaining_eighth_bits` is not part of the sum.
    pub const fn reserved_total_eighth_bits(&self) -> u32 {
        let flag_bits = self.anti_collapse_rsv.saturating_add(self.skip_rsv);
        let stereo_bits = self.intensity_rsv.saturating_add(self.dual_stereo_rsv);
        flag_bits.saturating_add(stereo_bits)
    }
}
/// Computes the fixed signalling reservations for one frame and the budget that
/// remains once they are taken out. Every quantity is in eighth-bits.
///
/// Steps, in order:
/// 1. Convert `frame_size_bytes` to eighth-bits and subtract `ec_tell_frac`.
/// 2. Subtract [`CONSERVATIVE_DEDUCTION_EIGHTH_BITS`] (saturating at zero).
/// 3. Reserve one bit for anti-collapse when the frame is transient, the LM index
///    exceeds [`ANTI_COLLAPSE_LM_MIN_EXCLUSIVE`], and the budget covers the headroom.
/// 4. Reserve one bit for the skip flag when strictly more than one bit is left.
/// 5. For stereo only: reserve `log2_frac(coded_bands)` for intensity if it fits,
///    then one bit for dual stereo when strictly more than one bit is left.
///
/// `total_boost` is only validated against the post-tell budget; it is not
/// deducted from the returned remainder. `coded_bands` is only looked up when
/// `is_stereo` is true.
///
/// # Errors
///
/// * [`ReservationError::FrameSizeOverflows`] if the byte-to-eighth-bit conversion overflows.
/// * [`ReservationError::TellExceedsFrame`] if `ec_tell_frac` is past the end of the frame.
/// * [`ReservationError::TotalBoostExceedsFrame`] if `total_boost` exceeds the post-tell budget.
/// * [`ReservationError::LogFracLookupFailed`] if the stereo `log2_frac` lookup fails.
pub fn reserve_block(
    frame_size_bytes: u32,
    ec_tell_frac: u32,
    total_boost: u32,
    lm: CeltFrameSize,
    is_transient: bool,
    is_stereo: bool,
    coded_bands: u32,
) -> Result<ReservationOutcome, ReservationError> {
    let frame_eighth_bits = match frame_size_bytes.checked_mul(EIGHTH_BITS_PER_BYTE) {
        Some(bits) => bits,
        None => return Err(ReservationError::FrameSizeOverflows),
    };

    // `checked_sub` yields `None` exactly when the coder position is past the frame end.
    let budget_after_tell = match frame_eighth_bits.checked_sub(ec_tell_frac) {
        Some(bits) => bits,
        None => {
            return Err(ReservationError::TellExceedsFrame {
                frame_eighth_bits,
                ec_tell_frac,
            })
        }
    };

    if total_boost > budget_after_tell {
        return Err(ReservationError::TotalBoostExceedsFrame {
            frame_eighth_bits,
            ec_tell_frac,
            total_boost,
        });
    }

    let budget = budget_after_tell.saturating_sub(CONSERVATIVE_DEDUCTION_EIGHTH_BITS);

    // Anti-collapse: the headroom is only computed for transient frames with a
    // large enough LM, thanks to the short-circuiting `&&`.
    let lm_idx = lm.column_index() as u32;
    let wants_anti_collapse = is_transient
        && lm_idx > ANTI_COLLAPSE_LM_MIN_EXCLUSIVE
        && budget
            >= (lm_idx + ANTI_COLLAPSE_HEADROOM_LM_OFFSET)
                .saturating_mul(ANTI_COLLAPSE_HEADROOM_MULT_EIGHTH_BITS);
    let anti_collapse_rsv = if wants_anti_collapse {
        ONE_BIT_EIGHTH_BITS
    } else {
        0
    };
    let budget = budget.saturating_sub(anti_collapse_rsv);

    // Skip flag.
    // NOTE(review): the gate is strictly greater-than one bit (pinned by this
    // file's tests); confirm against the reference decoder whether `>=` is intended.
    let skip_rsv = if budget > ONE_BIT_EIGHTH_BITS {
        ONE_BIT_EIGHTH_BITS
    } else {
        0
    };
    let budget = budget.saturating_sub(skip_rsv);

    // Stereo parameters; mono frames never touch the lookup table.
    let (intensity_rsv, dual_stereo_rsv, remaining) = if is_stereo {
        let wanted_intensity = log2_frac(coded_bands)? as u32;
        match budget.checked_sub(wanted_intensity) {
            // Intensity does not fit: reserve nothing for either stereo parameter.
            None => (0, 0, budget),
            Some(left) if left > ONE_BIT_EIGHTH_BITS => (
                wanted_intensity,
                ONE_BIT_EIGHTH_BITS,
                left - ONE_BIT_EIGHTH_BITS,
            ),
            Some(left) => (wanted_intensity, 0, left),
        }
    } else {
        (0, 0, budget)
    };

    Ok(ReservationOutcome {
        anti_collapse_rsv,
        skip_rsv,
        intensity_rsv,
        dual_stereo_rsv,
        total_remaining_eighth_bits: remaining,
    })
}
#[cfg(test)]
mod tests {
    //! Unit tests for the reservation constants, `reserve_block`, and
    //! `ReservationOutcome`. Budgets in comments are in eighth-bits.
    use super::*;

    // ---- constants ----

    #[test]
    fn one_bit_constant_matches_rfc() {
        assert_eq!(ONE_BIT_EIGHTH_BITS, 8);
        assert_eq!(ONE_BIT_EIGHTH_BITS / 8, 1);
    }

    #[test]
    fn conservative_deduction_constant_is_one() {
        assert_eq!(CONSERVATIVE_DEDUCTION_EIGHTH_BITS, 1);
    }

    #[test]
    fn eighth_bits_per_byte_matches_trim_module() {
        assert_eq!(EIGHTH_BITS_PER_BYTE, 64);
        assert_eq!(
            EIGHTH_BITS_PER_BYTE,
            crate::celt_alloc_trim::EIGHTH_BITS_PER_BYTE
        );
    }

    #[test]
    fn anti_collapse_lm_threshold_matches_rfc() {
        assert_eq!(ANTI_COLLAPSE_LM_MIN_EXCLUSIVE, 1);
        assert_eq!(CeltFrameSize::Ms2_5.column_index() as u32, 0);
        assert_eq!(CeltFrameSize::Ms5.column_index() as u32, 1);
        assert_eq!(CeltFrameSize::Ms10.column_index() as u32, 2);
        assert_eq!(CeltFrameSize::Ms20.column_index() as u32, 3);
    }

    #[test]
    fn anti_collapse_headroom_constants_match_rfc() {
        assert_eq!(ANTI_COLLAPSE_HEADROOM_LM_OFFSET, 2);
        assert_eq!(ANTI_COLLAPSE_HEADROOM_MULT_EIGHTH_BITS, 8);
    }

    /// Baseline argument tuple, in `reserve_block` parameter order: a roomy
    /// 100-byte mono, non-transient, 20 ms frame with 21 coded bands.
    fn small_mono_inputs() -> (u32, u32, u32, CeltFrameSize, bool, bool, u32) {
        (100, 0, 0, CeltFrameSize::Ms20, false, false, 21)
    }

    // ---- anti-collapse reservation ----

    #[test]
    fn mono_nontransient_short_lm_yields_no_anti_collapse_no_stereo_rsv() {
        // `is_transient` must be false here so the non-transient path is exercised;
        // the transient + LM0 case is covered by
        // `mono_transient_lm0_yields_no_anti_collapse`.
        let outcome = reserve_block(
            100,
            0,
            0,
            CeltFrameSize::Ms2_5,
            false,
            false,
            21,
        )
        .unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 0);
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 0);
    }

    #[test]
    fn mono_nontransient_long_lm_yields_no_anti_collapse() {
        let (bytes, tell, boost, lm, _, stereo, cb) = small_mono_inputs();
        let outcome = reserve_block(bytes, tell, boost, lm, false, stereo, cb).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 0);
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 0);
    }

    #[test]
    fn mono_transient_lm0_yields_no_anti_collapse() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms2_5, true, false, 21).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 0);
    }

    #[test]
    fn mono_transient_lm1_yields_no_anti_collapse() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms5, true, false, 21).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 0);
    }

    #[test]
    fn mono_transient_lm2_with_room_yields_anti_collapse() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms10, true, false, 21).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 8);
        assert_eq!(outcome.skip_rsv, 8);
    }

    #[test]
    fn mono_transient_lm3_with_room_yields_anti_collapse() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, true, false, 21).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 8);
    }

    #[test]
    fn anti_collapse_threshold_exact_match_passes() {
        let frame_size_bytes = 1u32;
        let frame_eighth = frame_size_bytes * 64;
        // Headroom for LM = 2 is (2 + 2) * 8 = 32; the extra `- 1` cancels the
        // conservative deduction so the budget lands exactly on the threshold.
        let want_total = (2 + 2) * 8;
        let ec_tell = frame_eighth - want_total - 1;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms10,
            true,
            false,
            21,
        )
        .unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 8);
    }

    #[test]
    fn anti_collapse_threshold_one_short_fails() {
        let frame_size_bytes = 1u32;
        let frame_eighth = frame_size_bytes * 64;
        // Same as the exact-match case, but one more eighth-bit already consumed,
        // which leaves the budget at 31, one short of the 32 threshold.
        let want_total = (2 + 2) * 8;
        let ec_tell = frame_eighth - want_total - 1 + 1;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms10,
            true,
            false,
            21,
        )
        .unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 0);
    }

    // ---- skip reservation ----

    #[test]
    fn skip_rsv_set_when_total_above_eight() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, false, 21).unwrap();
        assert_eq!(outcome.skip_rsv, 8);
    }

    #[test]
    fn skip_rsv_threshold_strictly_greater_than_eight() {
        let frame_size_bytes = 1u32;
        let frame_eighth = frame_size_bytes * 64;
        // Budget of exactly 8 after the conservative deduction: not enough.
        let want_total = 8u32;
        let ec_tell = frame_eighth - want_total - 1;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms5,
            false,
            false,
            21,
        )
        .unwrap();
        assert_eq!(outcome.skip_rsv, 0);
    }

    #[test]
    fn skip_rsv_threshold_one_above_eight() {
        let frame_size_bytes = 1u32;
        let frame_eighth = frame_size_bytes * 64;
        // Budget of 9 after the conservative deduction: just enough.
        let want_total = 9u32;
        let ec_tell = frame_eighth - want_total - 1;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms5,
            false,
            false,
            21,
        )
        .unwrap();
        assert_eq!(outcome.skip_rsv, 8);
    }

    #[test]
    fn anti_collapse_deducts_from_total_before_skip_gate() {
        let frame_size_bytes = 1u32;
        let frame_eighth = frame_size_bytes * 64;
        // 64 - 23 - 1 = 40; anti-collapse takes 8, skip takes 8, 24 remain.
        let ec_tell = 23u32;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms10,
            true,
            false,
            21,
        )
        .unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 8);
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.total_remaining_eighth_bits, 24);
        assert_eq!(frame_eighth - ec_tell - 1, 16 + 24);
    }

    // ---- stereo reservations ----

    #[test]
    fn stereo_with_budget_sets_intensity_and_dual() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, true, 21).unwrap();
        assert_eq!(outcome.intensity_rsv, 36);
        assert_eq!(outcome.dual_stereo_rsv, 8);
    }

    #[test]
    fn stereo_intensity_above_total_resets_to_zero() {
        let frame_size_bytes = 1u32;
        // 64 - 25 - 1 = 38; skip leaves 30, which is less than the 36 wanted.
        let ec_tell = 25u32;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms5,
            false,
            true,
            21,
        )
        .unwrap();
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 0);
        assert_eq!(outcome.total_remaining_eighth_bits, 30);
    }

    #[test]
    fn stereo_intensity_consumes_total_no_dual_stereo() {
        let frame_size_bytes = 1u32;
        // 64 - 12 - 1 = 51; skip leaves 43; intensity leaves 7, too little for dual stereo.
        let ec_tell = 12u32;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms5,
            false,
            true,
            21,
        )
        .unwrap();
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.intensity_rsv, 36);
        assert_eq!(outcome.dual_stereo_rsv, 0);
        assert_eq!(outcome.total_remaining_eighth_bits, 7);
    }

    #[test]
    fn stereo_intensity_consumes_total_dual_stereo_just_fits() {
        let frame_size_bytes = 1u32;
        // 64 - 10 - 1 = 53; skip leaves 45; intensity leaves 9; dual stereo leaves 1.
        let ec_tell = 10u32;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms5,
            false,
            true,
            21,
        )
        .unwrap();
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.intensity_rsv, 36);
        assert_eq!(outcome.dual_stereo_rsv, 8);
        assert_eq!(outcome.total_remaining_eighth_bits, 1);
    }

    #[test]
    fn mono_skips_stereo_branches_even_with_budget() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, false, 21).unwrap();
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 0);
    }

    #[test]
    fn hybrid_window_intensity_uses_log2_frac_table_at_four() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, true, 4).unwrap();
        assert_eq!(outcome.intensity_rsv, 19);
        assert_eq!(outcome.dual_stereo_rsv, 8);
    }

    #[test]
    fn coded_bands_one_intensity_is_eight() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, true, 1).unwrap();
        assert_eq!(outcome.intensity_rsv, 8);
    }

    #[test]
    fn coded_bands_zero_intensity_is_zero() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, true, 0).unwrap();
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 8);
    }

    // ---- invariant: remaining + reserved == frame - tell - 1 ----

    #[test]
    fn reservation_invariant_holds_for_stereo_with_room() {
        let outcome = reserve_block(50, 0, 0, CeltFrameSize::Ms20, true, true, 21).unwrap();
        let frame_eighth: u32 = 50 * 64;
        let conservative_total = frame_eighth - 1;
        let reserved = outcome.reserved_total_eighth_bits();
        assert_eq!(
            outcome.total_remaining_eighth_bits + reserved,
            conservative_total
        );
    }

    #[test]
    fn reservation_invariant_holds_for_mono_with_no_anti_collapse() {
        let outcome = reserve_block(200, 0, 0, CeltFrameSize::Ms20, false, false, 21).unwrap();
        let frame_eighth: u32 = 200 * 64;
        let conservative_total = frame_eighth - 1;
        let reserved = outcome.reserved_total_eighth_bits();
        assert_eq!(
            outcome.total_remaining_eighth_bits + reserved,
            conservative_total
        );
    }

    #[test]
    fn reservation_invariant_holds_with_nonzero_tell() {
        // `total_boost` (24) is validated but not deducted, so it does not
        // appear in the expected total.
        let outcome = reserve_block(100, 137, 24, CeltFrameSize::Ms20, true, true, 21).unwrap();
        let frame_eighth = 100 * 64;
        let conservative_total = frame_eighth - 137 - 1;
        let reserved = outcome.reserved_total_eighth_bits();
        assert_eq!(
            outcome.total_remaining_eighth_bits + reserved,
            conservative_total
        );
    }

    #[test]
    fn reservation_invariant_holds_when_intensity_resets() {
        let frame_size_bytes = 1u32;
        let ec_tell = 25u32;
        let outcome = reserve_block(
            frame_size_bytes,
            ec_tell,
            0,
            CeltFrameSize::Ms5,
            false,
            true,
            21,
        )
        .unwrap();
        assert_eq!(outcome.intensity_rsv, 0);
        let frame_eighth = frame_size_bytes * 64;
        let conservative_total = frame_eighth - ec_tell - 1;
        let reserved = outcome.reserved_total_eighth_bits();
        assert_eq!(
            outcome.total_remaining_eighth_bits + reserved,
            conservative_total
        );
    }

    // ---- error paths ----

    #[test]
    fn frame_size_overflow_rejected() {
        let err = reserve_block(u32::MAX, 0, 0, CeltFrameSize::Ms20, false, false, 21).unwrap_err();
        assert_eq!(err, ReservationError::FrameSizeOverflows);
    }

    #[test]
    fn tell_exceeds_frame_rejected() {
        let err =
            reserve_block(10, 10 * 64 + 1, 0, CeltFrameSize::Ms20, false, false, 21).unwrap_err();
        assert!(matches!(err, ReservationError::TellExceedsFrame { .. }));
    }

    #[test]
    fn total_boost_exceeds_frame_rejected() {
        let err =
            reserve_block(10, 0, 10 * 64 + 1, CeltFrameSize::Ms20, false, false, 21).unwrap_err();
        assert!(matches!(
            err,
            ReservationError::TotalBoostExceedsFrame { .. }
        ));
    }

    #[test]
    fn coded_bands_above_table_rejected() {
        let err = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, true, 24).unwrap_err();
        assert!(matches!(
            err,
            ReservationError::LogFracLookupFailed(Log2FracError::CodedBandsOutOfRange { .. })
        ));
    }

    #[test]
    fn mono_with_oversize_coded_bands_does_not_lookup() {
        // The lookup is only performed for stereo, so an out-of-range band
        // count is harmless for mono.
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, false, false, 999).unwrap();
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 0);
    }

    // ---- boundary frames ----

    #[test]
    fn zero_frame_yields_zero_reservations_and_zero_total() {
        let outcome = reserve_block(0, 0, 0, CeltFrameSize::Ms20, true, true, 21).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 0);
        assert_eq!(outcome.skip_rsv, 0);
        assert_eq!(outcome.intensity_rsv, 0);
        assert_eq!(outcome.dual_stereo_rsv, 0);
        assert_eq!(outcome.total_remaining_eighth_bits, 0);
    }

    #[test]
    fn frame_one_byte_minus_one_yields_no_reservations() {
        // Despite the name, the skip bit IS reserved here: 64 - 1 = 63 > 8,
        // leaving 55. Only the anti-collapse and stereo reservations are absent.
        let outcome = reserve_block(1, 0, 0, CeltFrameSize::Ms2_5, false, false, 21).unwrap();
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.total_remaining_eighth_bits, 55);
    }

    #[test]
    fn maximum_frame_yields_correct_remaining() {
        // 1275 * 64 - 1 = 81599; minus 8 + 8 + 37 + 8 = 61 leaves 81538.
        let outcome = reserve_block(1275, 0, 0, CeltFrameSize::Ms20, true, true, 23).unwrap();
        assert_eq!(outcome.anti_collapse_rsv, 8);
        assert_eq!(outcome.skip_rsv, 8);
        assert_eq!(outcome.intensity_rsv, 37);
        assert_eq!(outcome.dual_stereo_rsv, 8);
        assert_eq!(outcome.total_remaining_eighth_bits, 81538);
    }

    // ---- derived traits and conversions ----

    #[test]
    fn outcome_default_is_all_zero() {
        let d = ReservationOutcome::default();
        assert_eq!(d.anti_collapse_rsv, 0);
        assert_eq!(d.skip_rsv, 0);
        assert_eq!(d.intensity_rsv, 0);
        assert_eq!(d.dual_stereo_rsv, 0);
        assert_eq!(d.total_remaining_eighth_bits, 0);
        assert_eq!(d.reserved_total_eighth_bits(), 0);
    }

    #[test]
    fn debug_format_renders() {
        let outcome = reserve_block(100, 0, 0, CeltFrameSize::Ms20, true, true, 21).unwrap();
        let s = format!("{outcome:?}");
        assert!(s.contains("anti_collapse_rsv"));
        assert!(s.contains("intensity_rsv"));
    }

    #[test]
    fn determinism_across_repeats() {
        let a = reserve_block(137, 91, 24, CeltFrameSize::Ms20, true, true, 17).unwrap();
        let b = reserve_block(137, 91, 24, CeltFrameSize::Ms20, true, true, 17).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn from_log_frac_error_round_trip() {
        let inner = Log2FracError::CodedBandsOutOfRange { coded_bands: 100 };
        let outer: ReservationError = inner.into();
        assert_eq!(outer, ReservationError::LogFracLookupFailed(inner));
    }
}