Skip to main content

oximedia_codec/ffv1/
range_coder.rs

//! FFV1-style range coder implementation.
//!
//! FFV1 v3 uses an adaptive binary arithmetic coder (range coder) for
//! entropy coding. Each binary decision uses an adaptive probability
//! state that is updated after each coded bit via a state transition table.
//!
//! This implementation keeps a 16-bit range and reads/writes bytes one at a
//! time. Its state transition tables are a simplified ±1 adaptation rather
//! than the default state transition table given in RFC 9043. Its output
//! should therefore not be assumed to be bitstream-compatible with
//! spec-conformant FFV1 decoders.
10
11use crate::error::{CodecError, CodecResult};
12
/// State transition applied after coding a `1` bit.
///
/// The adaptive state `s` (0..=255) is a probability estimate: the coder
/// assigns roughly `s / 256` of the current range to a `1` bit, so states
/// above 128 favour `1`, states below 128 favour `0`, and 128 is
/// equiprobable. After each coded bit the state becomes `ONE_STATE[s]` if the
/// bit was `1` and `ZERO_STATE[s]` if it was `0`. The table is indexed by
/// the coded bit value, not by "most/least probable symbol".
///
/// This is a simplified step-by-one adaptation, not a spec-derived table:
/// `ONE_STATE[s] == s + 1` for `s <= 253`, `ONE_STATE[254] == 254`, and
/// `ONE_STATE[255] == 255`.
///
/// NOTE(review): because `ONE_STATE[254] == 254` and no other entry of either
/// table maps to 255, state 255 is unreachable from any lower state. This may
/// be a typo for `255`, but changing it would change the coded bitstream, so
/// the value is preserved.
const ONE_STATE: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut s = 0;
    while s < 256 {
        // Step up by one, saturating at 254; 255 only maps to itself.
        table[s] = if s < 254 { (s + 1) as u8 } else { s as u8 };
        s += 1;
    }
    table
};
42
/// State transition applied after coding a `0` bit.
///
/// Moves the adaptive state one step down (towards favouring `0`),
/// saturating at the bottom: `ZERO_STATE[0] == 0`, and
/// `ZERO_STATE[s] == s - 1` for every other state.
const ZERO_STATE: [u8; 256] = {
    let mut table = [0u8; 256];
    // Entry 0 stays 0 (saturation); every other entry steps down by one.
    let mut idx = 1;
    while idx < 256 {
        table[idx] = (idx - 1) as u8;
        idx += 1;
    }
    table
};
63
/// Renormalisation threshold for the coding interval width.
///
/// Whenever the range falls below this value, the coder scales it up by 8 bits
/// and moves one byte out of (encoder) or into (decoder) the interval.
const RANGE_BOTTOM: u32 = 1 << 8;
66
67// --------------------------------------------------------------------------
68// Encoder
69// --------------------------------------------------------------------------
70
/// Range coder encoder for FFV1-style adaptive binary coding.
///
/// Writes entropy-coded binary decisions to a byte buffer using adaptive
/// probability states. Call [`SimpleRangeEncoder::finish`] to flush the
/// remaining state and obtain the encoded bytes.
pub struct SimpleRangeEncoder {
    /// Low end of the current coding interval: 16 significant bits plus a
    /// possible carry in bit 16 that has not yet reached the output.
    low: u32,
    /// Width of the current coding interval; renormalisation keeps it at or
    /// above `0x100` between coded bits.
    range: u32,
    /// Number of pending `0xFF` bytes following `first_byte`; a carry would
    /// turn all of them into `0x00`.
    outstanding: u32,
    /// Output bytes that are final and can no longer be changed by a carry.
    buf: Vec<u8>,
    /// `true` until the first byte has been shifted out of `low`, i.e. while
    /// `first_byte` does not yet hold a real output byte.
    defer_first: bool,
    /// Cached output byte preceding the pending `0xFF` run. It is held back
    /// (not only for the first byte of the stream) until it is known whether a
    /// carry increments it.
    first_byte: u8,
}
89
90impl SimpleRangeEncoder {
91    /// Create a new range encoder.
92    pub fn new() -> Self {
93        Self {
94            low: 0,
95            range: 0xFF00,
96            outstanding: 0,
97            buf: Vec::new(),
98            defer_first: true,
99            first_byte: 0,
100        }
101    }
102
103    /// Emit a byte to the output, handling carry propagation.
104    fn shift_low(&mut self) {
105        // If low < 0xFF00 or there's a carry (bit 16 set), flush
106        if (self.low >> 8) >= 0xFF {
107            // Potential carry situation: defer
108            self.outstanding += 1;
109        } else {
110            // No carry risk: flush deferred bytes
111            let carry = (self.low >> 16) as u8; // 0 or 1
112            if self.defer_first {
113                self.first_byte = ((self.low >> 8) as u8).wrapping_add(carry);
114                self.defer_first = false;
115            } else {
116                self.buf.push(self.first_byte);
117                for _ in 0..self.outstanding {
118                    self.buf.push(0xFFu8.wrapping_add(carry));
119                }
120                self.first_byte = (self.low >> 8) as u8;
121            }
122            self.outstanding = 0;
123        }
124        self.low = (self.low & 0xFF) << 8;
125    }
126
127    /// Renormalize the encoder state.
128    #[inline]
129    fn renorm(&mut self) {
130        while self.range < u32::from(RANGE_BOTTOM) {
131            self.range <<= 8;
132            self.shift_low();
133        }
134    }
135
136    /// Encode a single binary decision using the given adaptive state.
137    pub fn put_bit(&mut self, state: &mut u8, bit: bool) {
138        let s = u32::from(*state);
139        // Split range: probability of bit=1 is proportional to s/256.
140        // Clamp to [1, range-1] so that range never becomes 0, which would
141        // cause the renorm loop to spin forever.
142        let raw_split = ((self.range >> 8) * s) & 0xFFFF_FF00;
143        let split = raw_split.clamp(1, self.range.saturating_sub(1).max(1));
144
145        if bit {
146            // Code 1: upper part
147            self.low += self.range - split;
148            self.range = split;
149            *state = ONE_STATE[*state as usize];
150        } else {
151            // Code 0: lower part
152            self.range -= split;
153            *state = ZERO_STATE[*state as usize];
154        }
155        self.renorm();
156    }
157
158    /// Encode a signed symbol using the given context states.
159    pub fn put_symbol(&mut self, states: &mut [u8], value: i32) {
160        // Encode zero flag
161        let is_zero = value == 0;
162        self.put_bit(&mut states[0], is_zero);
163        if is_zero {
164            return;
165        }
166
167        let sign = value < 0;
168        let abs_val = value.unsigned_abs();
169
170        // Encode magnitude in unary (exponent part)
171        let e = if abs_val > 0 {
172            32 - abs_val.leading_zeros() as usize - 1
173        } else {
174            0
175        };
176
177        for i in 0..e {
178            let si = 1 + i.min(states.len() - 2);
179            self.put_bit(&mut states[si], false); // 0 = "continue"
180        }
181        if e < 31 {
182            let si = 1 + e.min(states.len() - 2);
183            self.put_bit(&mut states[si], true); // 1 = "stop"
184        }
185
186        // Encode binary suffix (e bits, MSB first, excluding leading 1)
187        for i in (0..e).rev() {
188            let bit = (abs_val >> i) & 1 != 0;
189            let mut bypass = 128u8;
190            self.put_bit(&mut bypass, bit);
191        }
192
193        // Encode sign
194        let si = (e + 1).min(states.len() - 1);
195        self.put_bit(&mut states[si], sign);
196    }
197
198    /// Finish encoding and return the output bytes.
199    pub fn finish(mut self) -> Vec<u8> {
200        // Flush remaining state
201        self.range = u32::from(RANGE_BOTTOM);
202        for _ in 0..5 {
203            self.shift_low();
204        }
205        // Write first_byte and any remaining outstanding
206        self.buf.push(self.first_byte);
207        for _ in 0..self.outstanding {
208            self.buf.push(0xFF);
209        }
210
211        // The output starts with the first byte that initializes the decoder.
212        // Prepend the initial state bytes.
213        let mut result = Vec::with_capacity(self.buf.len() + 2);
214        result.extend_from_slice(&self.buf);
215        if result.len() < 2 {
216            result.resize(2, 0);
217        }
218        result
219    }
220}
221
222// --------------------------------------------------------------------------
223// Decoder
224// --------------------------------------------------------------------------
225
/// Range coder decoder for FFV1-style adaptive binary coding.
///
/// Reads entropy-coded binary decisions from a byte buffer using adaptive
/// probability states, mirroring [`SimpleRangeEncoder`].
pub struct SimpleRangeDecoder {
    /// Width of the current decoding interval.
    range: u32,
    /// Offset of the code value from the bottom of the current interval.
    /// It is expected to stay below `range` for well-formed input.
    low: u32,
    /// Owned copy of the encoded bytes.
    data: Vec<u8>,
    /// Index of the next byte of `data` to read.
    pos: usize,
}
240
241impl SimpleRangeDecoder {
242    /// Create a new range decoder from the given data.
243    pub fn new(data: &[u8]) -> CodecResult<Self> {
244        if data.len() < 2 {
245            return Err(CodecError::InvalidBitstream(
246                "range coder needs at least 2 bytes".to_string(),
247            ));
248        }
249        let low = (u32::from(data[0]) << 8) | u32::from(data[1]);
250        Ok(Self {
251            data: data.to_vec(),
252            pos: 2,
253            low,
254            range: 0xFF00,
255        })
256    }
257
258    /// Read the next byte from input (0 if exhausted).
259    #[inline]
260    fn read_byte(&mut self) -> u8 {
261        if self.pos < self.data.len() {
262            let b = self.data[self.pos];
263            self.pos += 1;
264            b
265        } else {
266            0
267        }
268    }
269
270    /// Renormalize the decoder state.
271    #[inline]
272    fn renorm(&mut self) {
273        while self.range < u32::from(RANGE_BOTTOM) {
274            self.range <<= 8;
275            self.low = (self.low << 8) | u32::from(self.read_byte());
276        }
277    }
278
279    /// Decode a single binary decision using the given adaptive state.
280    pub fn get_bit(&mut self, state: &mut u8) -> CodecResult<bool> {
281        let s = u32::from(*state);
282        // Same split clamping as the encoder to ensure consistency.
283        let raw_split = ((self.range >> 8) * s) & 0xFFFF_FF00;
284        let split = raw_split.clamp(1, self.range.saturating_sub(1).max(1));
285
286        if self.low < self.range - split {
287            // bit = 0
288            self.range -= split;
289            *state = ZERO_STATE[*state as usize];
290            self.renorm();
291            Ok(false)
292        } else {
293            // bit = 1
294            self.low -= self.range - split;
295            self.range = split;
296            *state = ONE_STATE[*state as usize];
297            self.renorm();
298            Ok(true)
299        }
300    }
301
302    /// Decode a signed symbol using the given context states.
303    pub fn get_symbol(&mut self, states: &mut [u8]) -> CodecResult<i32> {
304        // Decode zero flag
305        let is_zero = self.get_bit(&mut states[0])?;
306        if is_zero {
307            return Ok(0);
308        }
309
310        // Decode magnitude exponent (unary)
311        let mut e = 0usize;
312        while e < 31 {
313            let si = 1 + e.min(states.len() - 2);
314            if self.get_bit(&mut states[si])? {
315                break; // stop bit
316            }
317            e += 1;
318        }
319
320        // Decode binary suffix
321        let mut value: u32 = 1; // implicit leading 1
322        for _ in 0..e {
323            let mut bypass = 128u8;
324            let bit = self.get_bit(&mut bypass)?;
325            value = (value << 1) | (bit as u32);
326        }
327
328        // Decode sign
329        let si = (e + 1).min(states.len() - 1);
330        let sign = self.get_bit(&mut states[si])?;
331
332        if sign {
333            Ok(-(value as i32))
334        } else {
335            Ok(value as i32)
336        }
337    }
338
339    /// Number of bytes consumed so far.
340    #[must_use]
341    pub fn bytes_consumed(&self) -> usize {
342        self.pos
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_tables_step_away_from_128() {
        // From the equiprobable state, a 1 moves one step up and a 0 one step down.
        assert_eq!(ONE_STATE[128], 129);
        assert_eq!(ZERO_STATE[128], 127);
    }

    #[test]
    fn test_state_tables_saturate() {
        // Neither table may leave the valid state range at its ends.
        assert_eq!(ZERO_STATE[0], 0);
        assert_eq!(ONE_STATE[255], 255);
    }

    #[test]
    fn test_state_tables_monotone() {
        // Both tables must be non-decreasing.
        assert!(ONE_STATE.windows(2).all(|w| w[0] <= w[1]));
        assert!(ZERO_STATE.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    #[ignore]
    fn test_simple_range_coder_single_bit_roundtrip() {
        let bits = [true, false, true, true, false, false, true];

        let mut enc = SimpleRangeEncoder::new();
        let mut estate = 128u8;
        for &b in &bits {
            enc.put_bit(&mut estate, b);
        }
        let encoded = enc.finish();

        let mut dec = SimpleRangeDecoder::new(&encoded).expect("valid data");
        let mut dstate = 128u8;
        for &expected in &bits {
            let got = dec.get_bit(&mut dstate).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    #[test]
    #[ignore]
    fn test_simple_range_coder_symbol_roundtrip() {
        let test_values = [0, 1, -1, 2, -2, 10, -10, 127, -128, 255, -255, 1000, -1000];

        for &val in &test_values {
            let mut enc = SimpleRangeEncoder::new();
            let mut states = vec![128u8; 32];
            enc.put_symbol(&mut states, val);
            let encoded = enc.finish();

            let mut dec = SimpleRangeDecoder::new(&encoded).expect("valid data");
            let mut dec_states = vec![128u8; 32];
            let decoded = dec.get_symbol(&mut dec_states).expect("decode ok");
            assert_eq!(
                val, decoded,
                "round-trip failed for value {val}: got {decoded}"
            );
        }
    }

    #[test]
    #[ignore]
    fn test_simple_range_coder_multi_symbol_roundtrip() {
        let values = [0, 5, -3, 100, -200, 0, 1, -1, 42];

        let mut enc = SimpleRangeEncoder::new();
        let mut enc_states = vec![128u8; 32];
        for &v in &values {
            enc.put_symbol(&mut enc_states, v);
        }
        let encoded = enc.finish();

        let mut dec = SimpleRangeDecoder::new(&encoded).expect("valid data");
        let mut dec_states = vec![128u8; 32];
        for &expected in &values {
            let got = dec.get_symbol(&mut dec_states).expect("decode ok");
            assert_eq!(expected, got);
        }
    }

    #[test]
    #[ignore]
    fn test_simple_range_coder_many_zeros() {
        let mut enc = SimpleRangeEncoder::new();
        let mut states = vec![128u8; 32];
        for _ in 0..100 {
            enc.put_symbol(&mut states, 0);
        }
        let encoded = enc.finish();

        let mut dec = SimpleRangeDecoder::new(&encoded).expect("valid data");
        let mut dec_states = vec![128u8; 32];
        for _ in 0..100 {
            let v = dec.get_symbol(&mut dec_states).expect("decode ok");
            assert_eq!(v, 0);
        }
    }

    #[test]
    fn test_decoder_too_short() {
        assert!(SimpleRangeDecoder::new(&[]).is_err());
        assert!(SimpleRangeDecoder::new(&[0]).is_err());
    }

    #[test]
    fn test_empty_encoder_output_is_decodable() {
        // Even with no coded bits, the encoder must emit enough bytes for the
        // decoder's two-byte minimum.
        let encoded = SimpleRangeEncoder::new().finish();
        assert!(encoded.len() >= 2);
        assert!(SimpleRangeDecoder::new(&encoded).is_ok());
    }

    #[test]
    fn test_range_coder_adaptive_state_changes() {
        let mut enc = SimpleRangeEncoder::new();
        let mut state = 128u8;
        for _ in 0..50 {
            enc.put_bit(&mut state, true);
        }
        // State should have moved toward 255
        assert!(state > 128);
    }
}