1use crate::error::CodecError;
51
/// Base-2 logarithm of [`CDF_PROB_TOP`]: CDF entries carry this many bits of precision.
pub const CDF_PROB_BITS: u32 = 15;

/// Total probability mass of a CDF row.
///
/// The range coder divides its current range by this value, and every table
/// row below ends with it as a sentinel.
pub const CDF_PROB_TOP: u16 = 1 << CDF_PROB_BITS;

/// A set of `CTX` cumulative-distribution rows, each holding `N` entries.
///
/// A row with `N` entries describes `N - 1` symbols: symbol `s` occupies the
/// half-open interval `row[s]..row[s + 1]`.
pub type CdfTable<const N: usize, const CTX: usize> = [[u16; N]; CTX];

/// Two-symbol CDF, single context.
/// NOTE(review): presumably models the DC-coefficient skip flag — confirm against the caller.
pub const DC_COEFF_SKIP_CDF: CdfTable<3, 1> = [[0, 20000, CDF_PROB_TOP]];

/// Two-symbol CDF, single context.
/// NOTE(review): presumably models the AC-coefficient skip flag — confirm against the caller.
pub const AC_COEFF_SKIP_CDF: CdfTable<3, 1> = [[0, 14000, CDF_PROB_TOP]];

/// Sixteen-symbol CDF, single context, heavily weighted towards symbol 0.
/// NOTE(review): presumably models the transform type — confirm against the caller.
pub const TRANSFORM_TYPE_CDF: CdfTable<17, 1> = [[
    0,
    26200,
    27340,
    28000,
    28600,
    29100,
    29550,
    29950,
    30310,
    30640,
    30950,
    31240,
    31520,
    31790,
    32060,
    32400,
    CDF_PROB_TOP,
]];

/// Four-symbol CDF, single context.
/// NOTE(review): presumably models the partition type — confirm against the caller.
pub const PARTITION_TYPE_CDF: CdfTable<5, 1> = [[0, 16000, 21000, 26000, CDF_PROB_TOP]];
134
135#[derive(Debug, Clone)]
160pub struct RangeCoder {
161 range: u32,
164
165 low: u32,
168 output: Vec<u8>,
170
171 input: Vec<u8>,
174 read_pos: usize,
176 code: u32,
178 decode_mode: bool,
180}
181
182impl RangeCoder {
183 const BOT: u32 = 1 << 16;
185
186 #[must_use]
188 pub fn new() -> Self {
189 Self {
190 range: u32::MAX,
191 low: 0,
192 output: Vec::new(),
193 input: Vec::new(),
194 read_pos: 0,
195 code: 0,
196 decode_mode: false,
197 }
198 }
199
200 pub fn init_from_slice(&mut self, data: &[u8]) -> Result<(), CodecError> {
206 if data.is_empty() {
207 return Err(CodecError::InvalidBitstream(
208 "RangeCoder: empty bitstream".into(),
209 ));
210 }
211 self.decode_mode = true;
212 self.input = data.to_vec();
213 self.read_pos = 0;
214 self.range = u32::MAX;
215 self.code = 0;
217 for _ in 0..4 {
218 let b = self.read_byte_internal();
219 self.code = (self.code << 8) | u32::from(b);
220 }
221 Ok(())
222 }
223
224 #[must_use]
229 pub fn flush(mut self) -> Vec<u8> {
230 if !self.decode_mode {
231 for _ in 0..4 {
232 self.output.push((self.low >> 24) as u8);
233 self.low = self.low.wrapping_shl(8);
234 }
235 }
236 self.output
237 }
238
239 fn read_byte_internal(&mut self) -> u8 {
242 if self.read_pos < self.input.len() {
243 let b = self.input[self.read_pos];
244 self.read_pos += 1;
245 b
246 } else {
247 0x00 }
249 }
250
251 fn renormalize_encoder(&mut self) {
253 while self.range < Self::BOT {
254 self.output.push((self.low >> 24) as u8);
255 self.low = self.low.wrapping_shl(8);
256 self.range <<= 8;
257 }
258 }
259
260 fn renormalize_decoder(&mut self) {
262 while self.range < Self::BOT {
263 let b = self.read_byte_internal();
264 self.code = (self.code << 8) | u32::from(b);
265 self.range <<= 8;
266 }
267 }
268
269 fn encode_symbol_with_cdf(&mut self, sym: usize, cdf: &[u16]) -> Result<(), CodecError> {
271 let n_syms = cdf.len().saturating_sub(1);
272 if n_syms == 0 {
273 return Err(CodecError::InvalidParameter(
274 "CDF must have at least 2 entries".into(),
275 ));
276 }
277 if sym >= n_syms {
278 return Err(CodecError::InvalidParameter(format!(
279 "symbol {sym} out of range for {n_syms}-symbol CDF"
280 )));
281 }
282
283 let total = u32::from(CDF_PROB_TOP);
284 let cum_lo = u32::from(cdf[sym]);
285 let cum_hi = u32::from(cdf[sym + 1]);
286 let step = self.range / total;
287
288 self.low = self.low.wrapping_add(step * cum_lo);
289 if sym + 1 < n_syms {
291 self.range = step * (cum_hi - cum_lo);
292 } else {
293 self.range -= step * cum_lo;
294 }
295
296 self.renormalize_encoder();
297 Ok(())
298 }
299
300 fn decode_symbol_with_cdf(&mut self, cdf: &[u16]) -> Result<u8, CodecError> {
302 let n_syms = cdf.len().saturating_sub(1);
303 if n_syms == 0 {
304 return Err(CodecError::InvalidBitstream(
305 "CDF must have at least 2 entries".into(),
306 ));
307 }
308
309 let total = u32::from(CDF_PROB_TOP);
310 let step = self.range / total;
311
312 let mut sym = n_syms - 1;
316 for i in 0..n_syms {
317 if self.code < step * u32::from(cdf[i + 1]) {
319 sym = i;
320 break;
321 }
322 }
323
324 let cum_lo = u32::from(cdf[sym]);
325
326 self.code = self.code.wrapping_sub(step * cum_lo);
327 if sym + 1 < n_syms {
328 let cum_hi = u32::from(cdf[sym + 1]);
329 self.range = step * (cum_hi - cum_lo);
330 } else {
331 self.range -= step * cum_lo;
332 }
333
334 self.renormalize_decoder();
335
336 Ok(sym as u8)
337 }
338}
339
340pub fn encode_symbol_table<const N: usize, const CTX: usize>(
351 rc: &mut RangeCoder,
352 sym: u8,
353 ctx: usize,
354 table: &CdfTable<N, CTX>,
355) -> Result<(), CodecError> {
356 if ctx >= CTX {
357 return Err(CodecError::InvalidParameter(format!(
358 "context {ctx} out of range (table has {CTX} contexts)"
359 )));
360 }
361 rc.encode_symbol_with_cdf(sym as usize, &table[ctx])
362}
363
364pub fn decode_symbol_table<const N: usize, const CTX: usize>(
373 rc: &mut RangeCoder,
374 ctx: usize,
375 table: &CdfTable<N, CTX>,
376) -> Result<u8, CodecError> {
377 if ctx >= CTX {
378 return Err(CodecError::InvalidParameter(format!(
379 "context {ctx} out of range (table has {CTX} contexts)"
380 )));
381 }
382 rc.decode_symbol_with_cdf(&table[ctx])
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a decoder that reads from the start of `bytes`.
    fn decoder_over(bytes: &[u8]) -> RangeCoder {
        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(bytes).expect("init");
        decoder
    }

    /// The DC skip row starts at 0, ends at `CDF_PROB_TOP` and never decreases.
    #[test]
    fn dc_coeff_skip_cdf_valid() {
        let row = &DC_COEFF_SKIP_CDF[0];
        assert_eq!(row[0], 0, "first CDF entry must be 0");
        let final_entry = *row.last().expect("non-empty row");
        assert_eq!(final_entry, CDF_PROB_TOP, "last entry must be CDF_PROB_TOP");
        row.windows(2)
            .for_each(|w| assert!(w[0] <= w[1], "CDF must be monotonically non-decreasing"));
    }

    /// The AC skip row starts at 0, ends at `CDF_PROB_TOP` and never decreases.
    #[test]
    fn ac_coeff_skip_cdf_valid() {
        let row = &AC_COEFF_SKIP_CDF[0];
        assert_eq!(row[0], 0);
        let final_entry = *row.last().expect("non-empty");
        assert_eq!(final_entry, CDF_PROB_TOP);
        row.windows(2).for_each(|w| assert!(w[0] <= w[1]));
    }

    /// The transform-type row has 17 entries and is a well-formed CDF.
    #[test]
    fn transform_type_cdf_valid() {
        let row = &TRANSFORM_TYPE_CDF[0];
        assert_eq!(row[0], 0);
        let final_entry = *row.last().expect("non-empty");
        assert_eq!(final_entry, CDF_PROB_TOP);
        assert_eq!(row.len(), 17, "16 symbols + 1 sentinel");
        row.windows(2).for_each(|w| assert!(w[0] <= w[1]));
    }

    /// The partition-type row has 5 entries and is a well-formed CDF.
    #[test]
    fn partition_type_cdf_valid() {
        let row = &PARTITION_TYPE_CDF[0];
        assert_eq!(row[0], 0);
        let final_entry = *row.last().expect("non-empty");
        assert_eq!(final_entry, CDF_PROB_TOP);
        assert_eq!(row.len(), 5, "4 symbols + 1 sentinel");
        row.windows(2).for_each(|w| assert!(w[0] <= w[1]));
    }

    /// A lone symbol 0 survives an encode/decode round trip.
    #[test]
    fn range_coder_dc_skip_roundtrip_zero() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 0, 0, &DC_COEFF_SKIP_CDF).expect("encode sym 0");
        let stream = encoder.flush();

        let mut decoder = decoder_over(&stream);
        let decoded = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
        assert_eq!(decoded, 0, "should decode symbol 0");
    }

    /// A lone symbol 1 survives an encode/decode round trip.
    #[test]
    fn range_coder_dc_skip_roundtrip_one() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 1, 0, &DC_COEFF_SKIP_CDF).expect("encode sym 1");
        let stream = encoder.flush();

        let mut decoder = decoder_over(&stream);
        let decoded = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
        assert_eq!(decoded, 1, "should decode symbol 1");
    }

    /// Each of the four partition symbols round-trips on its own.
    #[test]
    fn range_coder_partition_type_all_symbols() {
        for sym_in in 0u8..4 {
            let mut encoder = RangeCoder::new();
            encode_symbol_table(&mut encoder, sym_in, 0, &PARTITION_TYPE_CDF)
                .expect("encode partition");
            let stream = encoder.flush();

            let mut decoder = decoder_over(&stream);
            let sym_out =
                decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode");
            assert_eq!(
                sym_out, sym_in,
                "partition type {sym_in} must survive round-trip"
            );
        }
    }

    /// Each of the sixteen transform symbols round-trips on its own.
    #[test]
    fn range_coder_transform_type_all_symbols() {
        for sym_in in 0u8..16 {
            let mut encoder = RangeCoder::new();
            encode_symbol_table(&mut encoder, sym_in, 0, &TRANSFORM_TYPE_CDF)
                .expect("encode tx type");
            let stream = encoder.flush();

            let mut decoder = decoder_over(&stream);
            let sym_out =
                decode_symbol_table(&mut decoder, 0, &TRANSFORM_TYPE_CDF).expect("decode tx");
            assert_eq!(
                sym_out, sym_in,
                "transform type {sym_in} must survive round-trip"
            );
        }
    }

    /// A short mixed sequence of AC skip flags round-trips in order.
    #[test]
    fn range_coder_ac_skip_roundtrip() {
        let symbols = [0u8, 1, 0, 0, 1, 1, 0, 1];

        let mut encoder = RangeCoder::new();
        for &sym in symbols.iter() {
            encode_symbol_table(&mut encoder, sym, 0, &AC_COEFF_SKIP_CDF).expect("encode");
        }
        let stream = encoder.flush();

        let mut decoder = decoder_over(&stream);
        for &expected in symbols.iter() {
            let got = decode_symbol_table(&mut decoder, 0, &AC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, expected);
        }
    }

    /// Symbols coded with three different tables interleave correctly.
    #[test]
    fn range_coder_sequence_mixed_tables() {
        // (dc, tx, pt) triples, coded in this order.
        let triples = [(0u8, 0u8, 3u8), (1, 5, 0), (0, 15, 2)];

        let mut encoder = RangeCoder::new();
        for &(dc, tx, pt) in triples.iter() {
            encode_symbol_table(&mut encoder, dc, 0, &DC_COEFF_SKIP_CDF).expect("encode dc");
            encode_symbol_table(&mut encoder, tx, 0, &TRANSFORM_TYPE_CDF).expect("encode tx");
            encode_symbol_table(&mut encoder, pt, 0, &PARTITION_TYPE_CDF).expect("encode pt");
        }
        let stream = encoder.flush();

        let mut decoder = decoder_over(&stream);
        for &(dc_want, tx_want, pt_want) in triples.iter() {
            let dc = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode dc");
            let tx = decode_symbol_table(&mut decoder, 0, &TRANSFORM_TYPE_CDF).expect("decode tx");
            let pt = decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode pt");
            assert_eq!(dc, dc_want);
            assert_eq!(tx, tx_want);
            assert_eq!(pt, pt_want);
        }
    }

    /// One hundred alternating DC skip flags round-trip.
    #[test]
    fn range_coder_long_sequence_dc_skip() {
        let symbols: Vec<u8> = (0u8..100).map(|i| i % 2).collect();

        let mut encoder = RangeCoder::new();
        for &sym in symbols.iter() {
            encode_symbol_table(&mut encoder, sym, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        }
        let stream = encoder.flush();

        let mut decoder = decoder_over(&stream);
        for (i, &expected) in symbols.iter().enumerate() {
            let got = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, expected, "mismatch at symbol {i}");
        }
    }

    /// Fifty repetitions of partition symbol 0 round-trip.
    #[test]
    fn range_coder_all_same_symbol_zero() {
        let n = 50;

        let mut encoder = RangeCoder::new();
        for _ in 0..n {
            encode_symbol_table(&mut encoder, 0, 0, &PARTITION_TYPE_CDF).expect("encode");
        }
        let stream = encoder.flush();

        let mut decoder = decoder_over(&stream);
        for i in 0..n {
            let got = decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode");
            assert_eq!(got, 0u8, "all-zero sequence failed at index {i}");
        }
    }

    /// A context index beyond the table's single row is rejected.
    #[test]
    fn range_coder_context_out_of_range_error() {
        let mut encoder = RangeCoder::new();
        let result = encode_symbol_table(&mut encoder, 0, 1, &DC_COEFF_SKIP_CDF);
        assert!(result.is_err(), "context 1 should be out of range");
    }

    /// A symbol index beyond the two-symbol row is rejected.
    #[test]
    fn range_coder_symbol_out_of_range_error() {
        let mut encoder = RangeCoder::new();
        let result = encode_symbol_table(&mut encoder, 2, 0, &DC_COEFF_SKIP_CDF);
        assert!(result.is_err(), "symbol 2 should be out of range");
    }

    /// Initialising a decoder from an empty slice fails.
    #[test]
    fn range_coder_empty_bitstream_error() {
        let mut decoder = RangeCoder::new();
        let result = decoder.init_from_slice(&[]);
        assert!(result.is_err(), "empty bitstream must return error");
    }

    /// A fresh coder is in encode mode and has produced nothing.
    #[test]
    fn range_coder_new_is_in_encode_mode() {
        let coder = RangeCoder::new();
        assert!(!coder.decode_mode, "new coder should be in encode mode");
        assert_eq!(coder.output.len(), 0, "no output yet");
    }

    /// Flushing after a single symbol yields a non-empty stream.
    #[test]
    fn range_coder_flush_produces_bytes() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 0, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        let stream = encoder.flush();
        assert!(!stream.is_empty(), "flush must produce at least one byte");
    }

    /// Ten thousand alternating flags compress to at most 2500 bytes and decode back.
    #[test]
    fn benchmark_table_vs_scalar_estimate() {
        let symbols: Vec<u8> = (0u8..200).cycle().take(10_000).map(|x| x % 2).collect();

        let mut encoder = RangeCoder::new();
        for &sym in symbols.iter() {
            encode_symbol_table(&mut encoder, sym, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        }
        let stream = encoder.flush();

        assert!(
            stream.len() <= 2500,
            "compressed size {} should be ≤ 2500 bytes for {}-symbol DC skip stream",
            stream.len(),
            symbols.len()
        );

        let mut decoder = decoder_over(&stream);
        for (i, &expected) in symbols.iter().enumerate() {
            let got = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, expected, "bulk decode mismatch at index {i}");
        }
    }
}