1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use alloc::vec;
use alloc::vec::Vec;
use crate::arithmetic_decoder::{ArithmeticDecoder, Context};
/// The integer arithmetic decoder (A.2).
pub(crate) struct IntegerDecoder {
/// `CX` - Context memory for the integer decoder.
contexts: Vec<Context>,
}
impl IntegerDecoder {
/// Create a new integer decoder with fresh contexts.
#[inline(always)]
pub(crate) fn new() -> Self {
Self {
// A.2: "Each arithmetic integer decoding procedure requires 512 bytes of
// storage for its context memory."
contexts: vec![Context::default(); 512],
}
}
/// The integer arithmetic decoding procedure (A.2, Figure A.1).
///
/// Returns `Some(value)` on success, or `None` if OOB (out-of-band) is decoded.
#[inline(always)]
pub(crate) fn decode(&mut self, decoder: &mut ArithmeticDecoder<'_>) -> Option<i32> {
// A.2 step 1: "Set: PREV = 1"
// `PREV` - Context prefix, contains bits decoded so far plus a leading 1.
let mut prev: u32 = 1;
// A.2 step 2: "Follow the flowchart in Figure A.1. Decode each bit with
// CX equal to 'IAx + PREV' where 'IAx' represents the identifier of the
// current arithmetic integer decoding procedure, '+' represents
// concatenation, and the rightmost 9 bits of PREV are used."
// `S` - Sign bit.
let s = self.decode_bit(decoder, &mut prev);
// `V` - Magnitude value, decoded according to Figure A.1 flowchart.
#[expect(
clippy::same_functions_in_if_condition,
reason = "each call mutates `prev`"
)]
let v = if self.decode_bit(decoder, &mut prev) == 0 {
// Figure A.1: "V = next 2 bits"
self.decode_n_bits(decoder, &mut prev, 2)
} else if self.decode_bit(decoder, &mut prev) == 0 {
// Figure A.1: "V = (next 4 bits) + 4"
self.decode_n_bits(decoder, &mut prev, 4) + 4
} else if self.decode_bit(decoder, &mut prev) == 0 {
// Figure A.1: "V = (next 6 bits) + 20"
self.decode_n_bits(decoder, &mut prev, 6) + 20
} else if self.decode_bit(decoder, &mut prev) == 0 {
// Figure A.1: "V = (next 8 bits) + 84"
self.decode_n_bits(decoder, &mut prev, 8) + 84
} else if self.decode_bit(decoder, &mut prev) == 0 {
// Figure A.1: "V = (next 12 bits) + 340"
self.decode_n_bits(decoder, &mut prev, 12) + 340
} else {
// Figure A.1: "V = (next 32 bits) + 4436"
self.decode_n_bits(decoder, &mut prev, 32) + 4436
};
// A.2: "The result of the integer arithmetic decoding procedure is equal to:
// - V if S = 0
// - -V if S = 1 and V > 0
// - OOB if S = 1 and V = 0"
if s == 0 {
Some(v as i32)
} else if v > 0 {
Some(-(v as i32))
} else {
None
}
}
/// Decode a single bit and update `PREV` (A.2 step 3).
///
/// A.2 step 3: "After each bit is decoded: If PREV < 256 set:
/// PREV = (PREV << 1) OR D. Otherwise set:
/// PREV = (((PREV << 1) OR D) AND 511) OR 256"
#[inline(always)]
fn decode_bit(&mut self, decoder: &mut ArithmeticDecoder<'_>, prev: &mut u32) -> u32 {
let ctx_idx = (*prev & 0x1FF) as usize;
// `D` - The just-decoded bit.
let d = decoder.decode(&mut self.contexts[ctx_idx]);
// A.2 step 3: Update PREV.
if *prev < 256 {
*prev = (*prev << 1) | d;
} else {
// A.2: "PREV always contains the values of the eight most-recently-
// decoded bits, plus a leading 1 bit, which is used to indicate the
// number of bits decoded so far."
*prev = (((*prev << 1) | d) & 511) | 256;
}
d
}
/// Decode `n` bits and update `PREV` for each.
#[inline(always)]
fn decode_n_bits(
&mut self,
decoder: &mut ArithmeticDecoder<'_>,
prev: &mut u32,
n: usize,
) -> u32 {
let mut value = 0_u32;
for _ in 0..n {
let bit = self.decode_bit(decoder, prev);
value = (value << 1) | bit;
}
value
}
}