use alloc::vec::Vec;
use crate::lpc::{
LpcLevels, MAX_COEFFICIENT_SHIFT, compute_residuals_into, lpc_analyze_levels_into,
lpc_synthesize_into,
};
use crate::rice::{estimate_cost, rice_decode_into, rice_encode_zigzag_into, zigzag};
use crate::{MAX_LPC_ORDER, MAX_PARTITION_ORDER};
/// Marker that opens every LAC frame; serialized big-endian as the first two header bytes.
pub const SYNC_WORD: u16 = 0x1A_CC;
/// Reasons a LAC frame can fail to parse or decode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecodeError {
    /// The first two bytes were not the expected sync word; `got` is what was read.
    BadSyncWord { got: u16 },
    /// The header's prediction order is larger than the crate's `MAX_LPC_ORDER`.
    InvalidPredictionOrder { got: u8 },
    /// The header's partition order is larger than the crate's `MAX_PARTITION_ORDER`.
    InvalidPartitionOrder { got: u8 },
    /// The header's coefficient shift is larger than `MAX_COEFFICIENT_SHIFT`.
    InvalidCoefficientShift { got: u8 },
    /// A non-zero coefficient shift was given for a frame with prediction order 0.
    CoefficientShiftWithoutOrder { shift: u8 },
    /// The input ended before all required bytes were read.
    Truncated,
    /// A field is syntactically valid but semantically out of range
    /// (e.g. a zero sample count, or a count not divisible by the partition count).
    InvalidParameter,
    /// A stream feature this decoder does not handle (not produced in this module).
    Unsupported,
}
impl core::fmt::Display for DecodeError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
DecodeError::BadSyncWord { got } => {
write!(
f,
"bad sync word: got {got:#06x}, expected {SYNC_WORD:#06x}"
)
}
DecodeError::InvalidPredictionOrder { got } => {
write!(f, "prediction_order {got} exceeds max {MAX_LPC_ORDER}")
}
DecodeError::InvalidPartitionOrder { got } => {
write!(f, "partition_order {got} exceeds max {MAX_PARTITION_ORDER}")
}
DecodeError::InvalidCoefficientShift { got } => {
write!(
f,
"coefficient_shift {got} exceeds max {MAX_COEFFICIENT_SHIFT}"
)
}
DecodeError::CoefficientShiftWithoutOrder { shift } => {
write!(f, "coefficient_shift is {shift} but prediction_order is 0")
}
DecodeError::Truncated => f.write_str("bitstream truncated"),
DecodeError::InvalidParameter => f.write_str("decoded parameter out of range"),
DecodeError::Unsupported => f.write_str("unsupported stream feature"),
}
}
}
impl core::error::Error for DecodeError {}
/// Decoded fixed-layout header of a LAC frame, as produced by `parse_header`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct AudioFrameHeader {
    /// Number of LPC coefficients; 0 means the residuals are the samples themselves.
    /// `parse_header` guarantees this equals `lpc_coefficients.len()`.
    pub prediction_order: u8,
    /// Log2 of the number of Rice partitions the residuals are split into.
    pub partition_order: u8,
    /// Shift applied with the LPC coefficients (presumably the quantization shift of the
    /// prediction — confirm against `lpc_synthesize_into`); must be 0 when the order is 0.
    pub coefficient_shift: u8,
    /// Samples in the frame; `parse_header` guarantees it is non-zero and a multiple of
    /// `2^partition_order`.
    pub frame_sample_count: u16,
    /// Quantized LPC coefficients, read as big-endian `i16`s.
    pub lpc_coefficients: Vec<i16>,
}
/// Parses and validates the header at the start of `data`.
///
/// Layout: sync word (u16 BE), prediction order (u8), partition order (u8),
/// coefficient shift (u8), sample count (u16 BE), then `prediction_order`
/// big-endian `i16` coefficients.
///
/// Returns the header and the byte offset at which the residual payload begins.
///
/// # Errors
///
/// Checks run in this order, and the first failing one is reported:
/// - fewer than 7 bytes → [`DecodeError::Truncated`];
/// - wrong sync word, out-of-range order, partition order or shift, or a shift
///   with order 0 → the matching [`DecodeError`] variant;
/// - zero samples, or a count not divisible by `2^partition_order` →
///   [`DecodeError::InvalidParameter`];
/// - missing coefficient bytes → [`DecodeError::Truncated`].
pub fn parse_header(data: &[u8]) -> Result<(AudioFrameHeader, usize), DecodeError> {
    const FIXED_LEN: usize = 7;

    let [sync_hi, sync_lo, order, part, shift, count_hi, count_lo, rest @ ..] = data else {
        return Err(DecodeError::Truncated);
    };
    let (prediction_order, partition_order, coefficient_shift) = (*order, *part, *shift);

    let sync = u16::from_be_bytes([*sync_hi, *sync_lo]);
    if sync != SYNC_WORD {
        return Err(DecodeError::BadSyncWord { got: sync });
    }
    if prediction_order > MAX_LPC_ORDER {
        return Err(DecodeError::InvalidPredictionOrder {
            got: prediction_order,
        });
    }
    if partition_order > MAX_PARTITION_ORDER {
        return Err(DecodeError::InvalidPartitionOrder {
            got: partition_order,
        });
    }
    if coefficient_shift > MAX_COEFFICIENT_SHIFT {
        return Err(DecodeError::InvalidCoefficientShift {
            got: coefficient_shift,
        });
    }
    if prediction_order == 0 && coefficient_shift != 0 {
        return Err(DecodeError::CoefficientShiftWithoutOrder {
            shift: coefficient_shift,
        });
    }

    let frame_sample_count = u16::from_be_bytes([*count_hi, *count_lo]);
    if frame_sample_count == 0 {
        return Err(DecodeError::InvalidParameter);
    }
    // Every Rice partition must hold the same number of samples.
    let partition_count = 1u32 << partition_order;
    if !u32::from(frame_sample_count).is_multiple_of(partition_count) {
        return Err(DecodeError::InvalidParameter);
    }

    let coeff_len = usize::from(prediction_order) * 2;
    let Some(coeff_bytes) = rest.get(..coeff_len) else {
        return Err(DecodeError::Truncated);
    };
    let lpc_coefficients: Vec<i16> = coeff_bytes
        .chunks_exact(2)
        .map(|pair| i16::from_be_bytes([pair[0], pair[1]]))
        .collect();

    let header = AudioFrameHeader {
        prediction_order,
        partition_order,
        coefficient_shift,
        frame_sample_count,
        lpc_coefficients,
    };
    Ok((header, FIXED_LEN + coeff_len))
}
/// Encodes `samples` as one LAC frame and returns the bytes in a freshly allocated buffer.
///
/// Convenience wrapper around [`encode_frame_into`].
///
/// # Panics
///
/// Panics if `samples` is empty, longer than `u16::MAX`, or contains a value
/// whose magnitude exceeds `2^23 - 1`.
pub fn encode_frame(samples: &[i32]) -> Vec<u8> {
    let mut encoded: Vec<u8> = Vec::new();
    encode_frame_into(samples, &mut encoded);
    encoded
}
/// Encodes `samples` as one LAC frame into `out`, replacing its previous contents.
///
/// Uses the default sparse LPC order grid and stops the order search after two
/// consecutive non-improving orders.
///
/// # Panics
///
/// Panics if `samples` is empty, longer than `u16::MAX`, or contains a value
/// whose magnitude exceeds `2^23 - 1`.
pub fn encode_frame_into(samples: &[i32], out: &mut Vec<u8>) {
    // Sparse subset of orders 0..=32: dense at low orders, coarser above 12.
    const SPARSE_ORDER_GRID: &[u8] = &[0, 2, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32];
    // Consecutive grid entries allowed to fail to improve before the search stops.
    const PATIENCE: u8 = 2;

    encode_frame_with_grid(samples, SPARSE_ORDER_GRID, PATIENCE, out);
}
/// Encodes `samples` as a single LAC frame into `out`, replacing its contents.
///
/// Candidates are evaluated in this order:
/// 1. each LPC order in `order_grid`, using coefficients from one
///    `lpc_analyze_levels_into` run up to `MAX_LPC_ORDER`. The search stops
///    early once `patience` consecutive entries fail to beat the best total so far.
/// 2. the fixed polynomial predictors of orders 1–4, unless the frame is silent.
///
/// Each candidate is scored as header bits plus the `estimate_cost` Rice bits of
/// its best partition order. The strictly cheapest candidate wins, so the earliest
/// one wins ties. If no candidate can be scored (for example, an empty `order_grid`
/// on a silent frame), the frame is coded verbatim at partition order 0, so the
/// output always carries one residual per sample.
///
/// # Panics
///
/// Panics if `samples` is empty, longer than `u16::MAX`, or contains a value
/// whose magnitude exceeds `2^23 - 1`.
pub(crate) fn encode_frame_with_grid(
    samples: &[i32],
    order_grid: &[u8],
    patience: u8,
    out: &mut Vec<u8>,
) {
    assert!(
        samples.len() <= u16::MAX as usize,
        "LAC frame_sample_count {} exceeds u16::MAX ({}) — chunk the input into smaller frames",
        samples.len(),
        u16::MAX
    );
    assert!(!samples.is_empty(), "LAC cannot encode a zero-sample frame");
    assert!(
        samples
            .iter()
            .all(|&s| (-((1 << 23) - 1)..=((1 << 23) - 1)).contains(&s)),
        "LAC encoder input must satisfy |sample| ≤ 2^23 - 1 (spec §1); found out-of-range value"
    );

    // (order, coefficients, shift) for the fixed polynomial predictors. The coefficient
    // ratios are 1; 2,-1; 3,-3,1; 4,-6,4,-1, scaled into the LPC coefficient format
    // that `compute_residuals_into` consumes together with the shift.
    const FIXED_PREDICTORS: &[(u8, &[i16], u8)] = &[
        (1, &[16_384], 1),
        (2, &[16_384, -8_192], 2),
        (3, &[24_576, -24_576, 8_192], 2),
        (4, &[16_384, -24_576, 16_384, -4_096], 3),
    ];
    const EMPTY_COEFFS: &[i16] = &[];

    // Zero energy means every sample is 0. Every predictor then yields all-zero
    // residuals, so LPC analysis and the fixed predictors are skipped.
    let r0: i64 = samples.iter().map(|&s| i64::from(s) * i64::from(s)).sum();
    let mut levels = LpcLevels::new();
    let levels_valid = r0 != 0 && lpc_analyze_levels_into(samples, MAX_LPC_ORDER, &mut levels);

    // `None` until at least one candidate has been scored.
    let mut best_total_bits: Option<usize> = None;
    let mut best_order = 0u8;
    let mut best_partition_order = 0u8;
    let mut best_shift = 0u8;
    let mut best_fixed_coeffs: Option<&'static [i16]> = None;

    let mut residuals_buf: Vec<i32> = Vec::with_capacity(samples.len());
    let mut zigzag_buf: Vec<u32> = Vec::with_capacity(samples.len());
    let mut best_zigzag: Vec<u32> = Vec::with_capacity(samples.len());

    let mut stale_orders = 0u8;
    for &order in order_grid {
        let (coeffs, shift): (&[i16], u8) = if order == 0 || !levels_valid {
            (EMPTY_COEFFS, 0)
        } else {
            let view = levels.get(order);
            (view.coefficients, view.shift)
        };
        zigzag_residuals_into(samples, coeffs, shift, &mut residuals_buf, &mut zigzag_buf);

        let candidate = cheapest_partitioning(&zigzag_buf, frame_header_bits(coeffs.len()));
        let improved = match candidate {
            Some((total, po)) if best_total_bits.is_none_or(|best| total < best) => {
                best_total_bits = Some(total);
                best_order = order;
                best_partition_order = po;
                best_shift = shift;
                best_fixed_coeffs = None;
                // Keep the winning residuals; the previous best becomes scratch space.
                core::mem::swap(&mut best_zigzag, &mut zigzag_buf);
                true
            }
            _ => false,
        };

        // Without LPC coefficients every remaining grid entry would evaluate the same
        // empty predictor again and could never strictly improve, so stop here.
        // This also covers the silent-frame case (r0 == 0 implies !levels_valid).
        if !levels_valid {
            break;
        }
        if improved {
            stale_orders = 0;
        } else {
            stale_orders += 1;
            if stale_orders >= patience {
                break;
            }
        }
    }

    if r0 != 0 {
        for &(fp_order, fp_coeffs, fp_shift) in FIXED_PREDICTORS {
            zigzag_residuals_into(samples, fp_coeffs, fp_shift, &mut residuals_buf, &mut zigzag_buf);
            let candidate = cheapest_partitioning(&zigzag_buf, frame_header_bits(fp_coeffs.len()));
            match candidate {
                Some((total, po)) if best_total_bits.is_none_or(|best| total < best) => {
                    best_total_bits = Some(total);
                    best_order = fp_order;
                    best_partition_order = po;
                    best_shift = fp_shift;
                    best_fixed_coeffs = Some(fp_coeffs);
                    core::mem::swap(&mut best_zigzag, &mut zigzag_buf);
                }
                _ => {}
            }
        }
    }

    if best_total_bits.is_none() {
        // Nothing could be scored: an empty `order_grid` on a silent frame, or
        // `estimate_cost` rejecting every partition order. Code the samples verbatim
        // at partition order 0 instead of emitting a frame with no residual payload.
        zigzag_residuals_into(samples, EMPTY_COEFFS, 0, &mut residuals_buf, &mut best_zigzag);
        best_order = 0;
        best_partition_order = 0;
        best_shift = 0;
    }

    let best_coeffs: &[i16] = match best_fixed_coeffs {
        Some(fixed) => fixed,
        None if best_order == 0 || !levels_valid => EMPTY_COEFFS,
        None => levels.get(best_order).coefficients,
    };
    // A predictor with no coefficients is written as order 0 / shift 0, which
    // `parse_header` requires.
    if best_coeffs.is_empty() {
        best_order = 0;
        best_shift = 0;
    }

    let header_len = 7 + best_coeffs.len() * 2;
    // The winning estimate already includes the header bits; reserving it up front
    // avoids regrowing `out` while the Rice payload is written.
    let estimated_len = best_total_bits.map_or(0, |bits| bits.div_ceil(8));
    out.clear();
    out.reserve(header_len.max(estimated_len));
    out.extend_from_slice(&SYNC_WORD.to_be_bytes());
    out.push(best_order);
    out.push(best_partition_order);
    out.push(best_shift);
    // Cannot truncate: the first assertion bounds `samples.len()` by `u16::MAX`.
    out.extend_from_slice(&(samples.len() as u16).to_be_bytes());
    for &c in best_coeffs {
        out.extend_from_slice(&c.to_be_bytes());
    }
    rice_encode_zigzag_into(&best_zigzag, best_partition_order, out);
}

/// Size in bits of a frame header carrying `n_coeffs` 16-bit LPC coefficients.
fn frame_header_bits(n_coeffs: usize) -> usize {
    // 7 fixed bytes (sync, order, partition order, shift, sample count) + coefficients.
    (7 + n_coeffs * 2) * 8
}

/// Computes the residuals of `samples` under (`coeffs`, `shift`) and stores their
/// zig-zag mapping in `zigzag_out`, replacing its contents. `residuals` is scratch space.
fn zigzag_residuals_into(
    samples: &[i32],
    coeffs: &[i16],
    shift: u8,
    residuals: &mut Vec<i32>,
    zigzag_out: &mut Vec<u32>,
) {
    compute_residuals_into(samples, coeffs, shift, residuals);
    zigzag_out.clear();
    zigzag_out.extend(residuals.iter().map(|&r| zigzag(r)));
}

/// Returns the cheapest `(total_bits, partition_order)` for coding `zigzag` after a
/// header of `header_bits`. The search covers every partition order that
/// `estimate_cost` accepts, and the lowest partition order wins ties.
/// Returns `None` if `estimate_cost` accepts no partition order.
fn cheapest_partitioning(zigzag: &[u32], header_bits: usize) -> Option<(usize, u8)> {
    (0..=MAX_PARTITION_ORDER)
        .filter_map(|po| estimate_cost(zigzag, po).map(|rice_bits| (header_bits + rice_bits, po)))
        .min_by_key(|&(total, _)| total)
}
/// Decodes one LAC frame from `data` into a freshly allocated sample vector.
///
/// Convenience wrapper around [`decode_frame_into`].
///
/// # Errors
///
/// Propagates any [`DecodeError`] returned by [`decode_frame_into`].
pub fn decode_frame(data: &[u8]) -> Result<Vec<i32>, DecodeError> {
    let mut samples = Vec::new();
    decode_frame_into(data, &mut samples).map(|()| samples)
}
/// Decodes one LAC frame from `data`, writing the reconstructed samples to `out`.
///
/// Bytes after the header are handed to the Rice decoder as the residual payload.
/// NOTE(review): whether `out` is cleared first is decided by `lpc_synthesize_into` — confirm.
///
/// # Errors
///
/// Returns the [`DecodeError`] from [`parse_header`], or the error reported by
/// `rice_decode_into` (converted via `?`).
pub fn decode_frame_into(data: &[u8], out: &mut Vec<i32>) -> Result<(), DecodeError> {
    let (header, payload_start) = parse_header(data)?;
    let sample_count = usize::from(header.frame_sample_count);

    let mut residuals: Vec<i32> = Vec::with_capacity(sample_count);
    // `parse_header` guarantees `payload_start <= data.len()`.
    let payload = &data[payload_start..];
    rice_decode_into(payload, header.partition_order, sample_count, &mut residuals)?;

    lpc_synthesize_into(
        &residuals,
        &header.lpc_coefficients,
        header.coefficient_shift,
        out,
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;
    use crate::test_signals::{angular_step_q32, sine_samples as int_sine_samples};

    /// Builds an `n`-sample integer test sine via `crate::test_signals`.
    fn sine_samples(n: usize, freq_hz: u64, sample_rate: u64, amplitude: i32) -> Vec<i32> {
        let step = angular_step_q32(freq_hz, sample_rate);
        int_sine_samples(n, step, amplitude)
    }

    #[test]
    fn roundtrip_sine_440hz() {
        let samples = sine_samples(960, 440, 48_000, 100_000);
        let encoded = encode_frame(&samples);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples, "sine frame roundtrip failed");
    }

    #[test]
    fn roundtrip_silence() {
        let samples = vec![0i32; 960];
        let encoded = encode_frame(&samples);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn silence_uses_order_zero() {
        let samples = vec![0i32; 960];
        let encoded = encode_frame(&samples);
        let (header, _) = parse_header(&encoded).unwrap();
        assert_eq!(header.prediction_order, 0);
    }

    #[test]
    fn roundtrip_short_frame() {
        let samples = sine_samples(128, 220, 48_000, 50_000);
        let encoded = encode_frame(&samples);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn roundtrip_non_power_of_two_length() {
        // 137 is odd, so only a single partition (order 0) divides it evenly.
        let samples = sine_samples(137, 220, 48_000, 50_000);
        let encoded = encode_frame(&samples);
        let (header, _) = parse_header(&encoded).unwrap();
        assert_eq!(header.partition_order, 0);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn roundtrip_full_scale() {
        let samples: Vec<i32> = (0..1024)
            .map(|i| {
                if i % 2 == 0 {
                    (1 << 23) - 1
                } else {
                    -((1 << 23) - 1)
                }
            })
            .collect();
        let encoded = encode_frame(&samples);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn sync_word_present() {
        let samples = sine_samples(1024, 1000, 48_000, 10_000);
        let encoded = encode_frame(&samples);
        let sync = u16::from_be_bytes([encoded[0], encoded[1]]);
        assert_eq!(sync, SYNC_WORD);
    }

    #[test]
    fn decode_rejects_bad_sync() {
        let samples = sine_samples(960, 440, 48_000, 1_000);
        let mut encoded = encode_frame(&samples);
        encoded[0] = 0xFF;
        assert_eq!(
            decode_frame(&encoded),
            Err(DecodeError::BadSyncWord { got: 0xFFCC })
        );
    }

    #[test]
    fn decode_rejects_order_above_max() {
        let mut encoded = encode_frame(&sine_samples(960, 440, 48_000, 1_000));
        encoded[2] = MAX_LPC_ORDER + 1;
        assert_eq!(
            decode_frame(&encoded),
            Err(DecodeError::InvalidPredictionOrder {
                got: MAX_LPC_ORDER + 1
            })
        );
    }

    #[test]
    fn decode_rejects_partition_order_above_max() {
        let mut encoded = encode_frame(&sine_samples(960, 440, 48_000, 1_000));
        encoded[3] = MAX_PARTITION_ORDER + 1;
        assert_eq!(
            decode_frame(&encoded),
            Err(DecodeError::InvalidPartitionOrder {
                got: MAX_PARTITION_ORDER + 1
            })
        );
    }

    #[test]
    fn decode_rejects_mismatched_partition_count() {
        // 7 samples cannot be split into 2^3 = 8 equal partitions.
        let mut buf = Vec::new();
        buf.extend_from_slice(&SYNC_WORD.to_be_bytes());
        buf.push(0); // prediction_order
        buf.push(3); // partition_order
        buf.push(0); // coefficient_shift
        buf.extend_from_slice(&7u16.to_be_bytes()); // frame_sample_count
        assert_eq!(parse_header(&buf), Err(DecodeError::InvalidParameter));
    }

    #[test]
    fn decode_rejects_coefficient_shift_above_max() {
        let mut encoded = encode_frame(&sine_samples(960, 440, 48_000, 10_000));
        let (hdr, _) = parse_header(&encoded).unwrap();
        // The shift check only matters with a predictor; fail loudly rather than
        // passing vacuously if the encoder ever picks order 0 for this tone.
        assert!(
            hdr.prediction_order > 0,
            "test precondition: expected a predictive frame, got order 0"
        );
        encoded[4] = MAX_COEFFICIENT_SHIFT + 1;
        assert_eq!(
            decode_frame(&encoded),
            Err(DecodeError::InvalidCoefficientShift {
                got: MAX_COEFFICIENT_SHIFT + 1
            })
        );
    }

    #[test]
    fn decode_rejects_coefficient_shift_without_order() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&SYNC_WORD.to_be_bytes());
        buf.push(0); // prediction_order
        buf.push(0); // partition_order
        buf.push(3); // coefficient_shift (illegal with order 0)
        buf.extend_from_slice(&320u16.to_be_bytes()); // frame_sample_count
        assert_eq!(
            parse_header(&buf),
            Err(DecodeError::CoefficientShiftWithoutOrder { shift: 3 })
        );
    }

    #[test]
    fn decode_rejects_truncated_header() {
        let encoded = encode_frame(&sine_samples(960, 440, 48_000, 1_000));
        let truncated = &encoded[..6];
        assert_eq!(decode_frame(truncated), Err(DecodeError::Truncated));
    }

    #[test]
    fn decode_rejects_truncated_coefficients() {
        let encoded = encode_frame(&sine_samples(960, 440, 48_000, 10_000));
        let (hdr, _) = parse_header(&encoded).unwrap();
        // Needs at least one coefficient to cut into; fail instead of passing vacuously.
        assert!(
            hdr.prediction_order > 0,
            "test precondition: expected a predictive frame, got order 0"
        );
        // Drop the last byte of the coefficient table.
        let cut_at = 7 + hdr.prediction_order as usize * 2 - 1;
        assert_eq!(
            decode_frame(&encoded[..cut_at]),
            Err(DecodeError::Truncated)
        );
    }

    #[test]
    fn higher_order_chosen_for_tonal_signal() {
        let samples = sine_samples(960, 440, 48_000, 500_000);
        let encoded = encode_frame(&samples);
        let (header, _) = parse_header(&encoded).unwrap();
        assert!(
            header.prediction_order > 0,
            "expected non-zero order for tonal signal, got {}",
            header.prediction_order
        );
    }

    #[test]
    fn roundtrip_various_lengths() {
        for &n in &[64usize, 120, 256, 480, 960, 1024, 2048, 4096] {
            let samples = sine_samples(n, 1000, 48_000, 10_000);
            let encoded = encode_frame(&samples);
            let decoded = decode_frame(&encoded).unwrap();
            assert_eq!(decoded, samples, "roundtrip failed at n={n}");
        }
    }

    #[test]
    fn roundtrip_transient_burst() {
        let mut samples = Vec::with_capacity(1024);
        // Quiet low-level pattern.
        for i in 0..256i32 {
            samples.push((i % 13) - 6);
        }
        // Loud sawtooth-like burst.
        for i in 0..256i32 {
            samples.push(((i * 31) % 400_000) - 200_000);
        }
        // Linearly decaying tone.
        let step = crate::test_signals::angular_step_q32(200, 6283);
        let mut phase: u32 = 0;
        for i in 0..512i32 {
            let decay = 50_000 * (512 - i) / 512;
            let s = crate::test_signals::sin_q15(phase);
            let sample = ((s as i64 * decay as i64 + (1 << 14)) >> 15) as i32;
            samples.push(sample);
            phase = phase.wrapping_add(step);
        }
        let encoded = encode_frame(&samples);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn decode_rejects_zero_sample_count() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&SYNC_WORD.to_be_bytes());
        buf.push(0); // prediction_order
        buf.push(0); // partition_order
        buf.push(0); // coefficient_shift
        buf.extend_from_slice(&0u16.to_be_bytes()); // frame_sample_count
        assert_eq!(
            parse_header(&buf),
            Err(DecodeError::InvalidParameter),
            "a zero-sample frame must be rejected"
        );
        assert_eq!(
            decode_frame(&buf),
            Err(DecodeError::InvalidParameter),
            "decode_frame must surface the zero-count rejection"
        );
    }

    #[test]
    fn roundtrip_single_sample() {
        for v in [0i32, 1, -1, 123_456, -((1 << 23) - 1), (1 << 23) - 1] {
            let samples = vec![v];
            let encoded = encode_frame(&samples);
            let (hdr, _) = parse_header(&encoded).unwrap();
            assert_eq!(hdr.frame_sample_count, 1);
            assert_eq!(hdr.partition_order, 0);
            let decoded = decode_frame(&encoded).unwrap();
            assert_eq!(decoded, samples, "single-sample roundtrip failed for v={v}");
        }
    }

    #[test]
    fn roundtrip_single_sample_forced_high_order() {
        for &forced_order in &[16u8, 32u8] {
            for v in [0i32, 1, -1, 123_456, (1 << 23) - 1] {
                let samples = vec![v];
                let mut encoded = Vec::new();
                encode_frame_with_grid(&samples, &[forced_order], u8::MAX, &mut encoded);
                let (hdr, _) =
                    parse_header(&encoded).expect("forced-order encoder output must parse");
                assert_eq!(hdr.frame_sample_count, 1);
                let decoded = decode_frame(&encoded).unwrap();
                assert_eq!(
                    decoded, samples,
                    "single-sample round-trip failed at forced order {forced_order} for v={v}"
                );
            }
        }
    }

    #[test]
    fn roundtrip_frame_at_u16_max() {
        let samples: Vec<i32> = (0..u16::MAX as i32)
            .map(|i| (i.wrapping_mul(17)) & 0xFFFF)
            .collect();
        assert_eq!(samples.len(), u16::MAX as usize);
        let encoded = encode_frame(&samples);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    #[should_panic(expected = "exceeds u16::MAX")]
    fn encode_panics_at_frame_above_u16_max() {
        let samples = vec![0i32; u16::MAX as usize + 1];
        let _ = encode_frame(&samples);
    }

    #[test]
    fn decode_panic_free_on_adversarial_coefficients_and_residuals() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&SYNC_WORD.to_be_bytes());
        buf.push(32); // prediction_order
        buf.push(0); // partition_order (single partition)
        buf.push(5); // coefficient_shift
        let n_samples = 64u16;
        buf.extend_from_slice(&n_samples.to_be_bytes());
        // Alternating extreme coefficients to provoke overflow in synthesis.
        for j in 0..32 {
            let c = if j & 1 == 0 { i16::MAX } else { i16::MIN };
            buf.extend_from_slice(&c.to_be_bytes());
        }
        {
            let mut w = crate::bit_io::BitWriter::new(&mut buf);
            // Presumably the 5-bit Rice parameter of the single partition (k = 0).
            w.write_bits(0, 5);
            // Each residual: a long run of zero bits then a one bit, i.e. a huge
            // unary-coded value per sample.
            for _ in 0..n_samples {
                for _ in 0..8192 {
                    w.write_bit(false);
                }
                w.write_bit(true);
            }
            w.finish();
        }
        // Only panic-freedom is asserted; any Ok/Err outcome is acceptable.
        let _ = decode_frame(&buf);
    }

    #[test]
    fn sparse_vs_exhaustive_on_headset_speech() {
        extern crate std;
        use hound::WavReader;
        use std::eprintln;
        use std::path::Path;
        const CORPUS_PATH: &str = "corpus/ES2002a.Headset-0.wav";
        const FRAME_SIZE: usize = 4096;
        let path = Path::new(CORPUS_PATH);
        if !path.exists() {
            eprintln!("skipping: corpus file not found: {}", path.display());
            return;
        }
        let mut reader = WavReader::open(path).expect("open headset wav");
        let channel: Vec<i32> = reader
            .samples::<i32>()
            .collect::<Result<Vec<_>, _>>()
            .expect("parse samples");
        // Limit the corpus to the first 60 s at 16 kHz to keep the test quick.
        let cap = (16_000 * 60).min(channel.len());
        let channel = &channel[..cap];
        let exhaustive_grid: Vec<u8> = (0u8..=32).collect();
        let sparse_grid: &[u8] = &[0, 2, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32];
        let mut sparse_bytes = 0usize;
        let mut exhaustive_bytes = 0usize;
        let mut sparse = Vec::new();
        let mut exhaustive = Vec::new();
        for chunk in channel.chunks(FRAME_SIZE) {
            encode_frame_with_grid(chunk, sparse_grid, 2, &mut sparse);
            encode_frame_with_grid(chunk, &exhaustive_grid, u8::MAX, &mut exhaustive);
            assert_eq!(decode_frame(&sparse).unwrap(), chunk);
            assert_eq!(decode_frame(&exhaustive).unwrap(), chunk);
            sparse_bytes += sparse.len();
            exhaustive_bytes += exhaustive.len();
        }
        assert!(
            sparse_bytes >= exhaustive_bytes,
            "sparse smaller than exhaustive? sparse={} exhaustive={}",
            sparse_bytes,
            exhaustive_bytes
        );
        let excess = (sparse_bytes as f64 / exhaustive_bytes as f64) - 1.0;
        eprintln!(
            "sparse_vs_exhaustive_on_headset_speech sparse={} exhaustive={} excess={:.2}%",
            sparse_bytes,
            exhaustive_bytes,
            excess * 100.0
        );
        assert!(
            excess < 0.005,
            "sparse grid is {:.2}% larger than exhaustive (budget 0.5%)",
            excess * 100.0
        );
    }
}