use ans::{AnsError, FrequencyTable, RansDecoder, RansEncoder};
/// Demonstrates the mechanics of bits-back coding (BB-ANS) with rANS.
///
/// Steps performed:
/// 1. Encode a "seed" message under `seed_model` to obtain a bitstream.
/// 2. Decode `num_latents` latent symbols out of that bitstream using `prior`
///    (the "bits-back" step: sampling the latents consumes seed bits).
/// 3. Encode the sampled latents under `posterior` with a fresh encoder, then
///    decode them again and assert they round-trip.
///
/// This is a simplified demo: step 3 starts a new encoder rather than
/// continuing from the decoder's leftover state, so the seed bits consumed in
/// step 2 are not actually reused. The last line reports how many bits per
/// latent a full pipeline would get back from the prior-decode step.
///
/// # Errors
///
/// Propagates any [`AnsError`] returned while building frequency tables or
/// while encoding/decoding.
fn main() -> Result<(), AnsError> {
    // Kept as a named array so the free-bits estimate at the end is derived
    // from the same counts the prior table is built from, instead of a
    // hard-coded constant that would go stale if the prior changed.
    let prior_counts = [1, 1];
    let prior = FrequencyTable::from_counts(&prior_counts, 12)?;
    let posterior = FrequencyTable::from_counts(&[8, 2], 12)?;
    let seed_model = FrequencyTable::from_counts(&[3, 7], 12)?;
    let seed_message = [1u32, 0, 1, 1, 0, 1, 0, 0];

    // rANS is last-in/first-out: push symbols in reverse so that decoding
    // yields them in their original order.
    let mut enc = RansEncoder::new();
    for &sym in seed_message.iter().rev() {
        enc.put(sym, &seed_model)?;
    }
    println!("=== Bits-back coding (BB-ANS) ===\n");
    println!("Seed: encoded {} symbols", seed_message.len());
    println!("Encoder state after seed: {}", enc.state());

    let num_latents = 5;
    let mut latent_samples = Vec::with_capacity(num_latents);

    // Decode the latents directly out of the seed bitstream using the prior.
    let midpoint_bytes = enc.finish();
    let mut dec = RansDecoder::new(&midpoint_bytes)?;
    println!("\n--- Bits-back encode ({num_latents} latents) ---");
    for i in 0..num_latents {
        let z = dec.peek(&prior);
        dec.advance(z, &prior)?;
        latent_samples.push(z);
        println!(" latent[{i}]: decoded z={z} from prior");
    }
    let remaining = dec.remaining_bytes();
    let dec_state = dec.state();
    println!("\nAfter prior decodes: state={dec_state}, remaining_bytes={remaining}");

    // Encode the sampled latents under the posterior (reverse order again,
    // because rANS is LIFO). NOTE: a full BB-ANS pipeline would continue from
    // `dec`'s leftover state here; this demo uses a fresh encoder.
    let mut enc2 = RansEncoder::new();
    for &z in latent_samples.iter().rev() {
        enc2.put(z, &posterior)?;
    }
    let posterior_bytes = enc2.finish();
    println!("Posterior encoding: {} bytes", posterior_bytes.len());

    println!("\n--- Bits-back decode ---");
    let mut dec2 = RansDecoder::new(&posterior_bytes)?;
    let mut recovered = Vec::with_capacity(num_latents);
    for i in 0..num_latents {
        let z = dec2.peek(&posterior);
        dec2.advance(z, &posterior)?;
        recovered.push(z);
        println!(" latent[{i}]: decoded z={z} from posterior");
    }
    assert_eq!(recovered, latent_samples);
    println!("\nRecovered latents match originals.");

    let posterior_bits = posterior_bytes.len() as f64 * 8.0;
    // A naive scheme (no bits back) encodes the very same latents under the
    // very same posterior, which is exactly what `enc2` did above. Its cost is
    // therefore `posterior_bits`; re-running the encoder would only duplicate
    // work.
    let naive_bits = posterior_bits;
    println!("\nBits used (posterior only): {posterior_bits:.0}");
    println!("Bits used (naive, same): {naive_bits:.0}");

    // Decoding z from the prior consumes -log2 p(z) = log2(total / count[z])
    // bits of the seed stream. Averaging over the sampled latents gives the
    // bits a full pipeline gets back per latent. This is the ideal figure
    // from the raw counts; the table itself is quantized to 12 bits.
    let prior_total: f64 = prior_counts.iter().map(|&c| c as f64).sum();
    let free_bits_per_latent = if latent_samples.is_empty() {
        0.0
    } else {
        latent_samples
            .iter()
            .map(|&z| (prior_total / prior_counts[z as usize] as f64).log2())
            .sum::<f64>()
            / latent_samples.len() as f64
    };
    println!(
        "In a full BB-ANS pipeline, the prior-decode step extracts ~{:.1} free bits per latent",
        free_bits_per_latent
    );
    Ok(())
}