rust_h265 0.1.0

Pure Rust H.265/HEVC video decoder (Main and Main 10 profile, 8-bit and 10-bit 4:2:0)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
//! HEVC CABAC (Context-Adaptive Binary Arithmetic Coding) decoder.
//!
//! Implements the binary arithmetic decoder per HEVC spec section 9.3.4.
//! The arithmetic engine itself (rangeTabLPS, transIdxLPS/MPS, renormalization)
//! is mathematically identical to H.264 CABAC; HEVC inherited it verbatim.
//! Only the per-syntax-element context **init values** and the per-element
//! **context derivation rules** differ between H.264 and HEVC.
//!
//! State encoding follows the FFmpeg-style packed convention used by
//! `rust_h264`: a single `u8` per context where the low bit is `valMps`
//! (0 = MPS is 0, 1 = MPS is 1) and the upper bits encode `pStateIdx`.
//! Concretely the combined value is `2 * pStateIdx + valMps`, ranging
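//! 0..=125 for legal states. State transitions go through `MLPS_STATE`,
//! which has 256 entries: an MPS transition reads index `128 + state`
//! (upper half, 128..=255), while an LPS transition reads index
//! `127 - state` (lower half, 0..=127, stored in reverse order).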

use crate::cabac_tables::{HEVC_CONTEXTS, INIT_VALUES, LPS_RANGE, MLPS_STATE, NORM_SHIFT};
use crate::slice::SliceType;

/// Number of look-ahead bits kept in the `low` register underneath the
/// active range.
const CABAC_BITS: u32 = 16;
/// Mask selecting those `CABAC_BITS` buffered bits of `low` (0xFFFF).
const CABAC_MASK: u32 = !(u32::MAX << CABAC_BITS);

/// Left-shift that renormalization applies to `range`, read from the
/// `NORM_SHIFT` table indexed by the current range value.
#[inline]
fn norm_shift(range: u32) -> u32 {
    let shift = NORM_SHIFT[range as usize];
    shift as u32
}

/// Look up `rangeLPS` for the current `range` and packed context `state`.
#[inline]
fn lps_range_lookup(range: u32, state: u8) -> u32 {
    // qRangeIdx = (range >> 6) & 3 comes from bits 6..7 of `range`. Each
    // qRangeIdx owns a 128-entry row indexed by the packed state, so the row
    // base is qRangeIdx * 128, which is exactly `(range & 0xC0) << 1`.
    let row_base = ((range & 0xC0) as usize) << 1;
    LPS_RANGE[row_base + usize::from(state)] as u32
}

/// HEVC CABAC binary arithmetic decoder.
///
/// Decodes slice data from RBSP bytes that follow the slice header's byte
/// alignment. The caller passes the whole RBSP together with the offset of
/// the first byte CABAC initialization should consume (see
/// [`CabacReader::new`]).
pub struct CabacReader<'a> {
    /// RBSP bytes being decoded; refills treat bytes past the end as zero.
    data: &'a [u8],
    /// Index in `data` of the next byte pair a refill will read.
    pos: usize,
    /// Offset register, compared against `range << (CABAC_BITS + 1)`, with
    /// not-yet-consumed stream bits buffered underneath.
    low: u32,
    /// Current interval width; `0x1FE` right after initialization.
    range: u32,
}

impl<'a> CabacReader<'a> {
    /// Initialize the CABAC decoder from RBSP data at a given byte position
    /// (arithmetic decoding engine initialization, HEVC spec clause 9.3.2).
    ///
    /// The 2-byte aligned initialization matches the FFmpeg-style code: load
    /// two bytes into `low`, plant a fixed bias of `1 << 9` in place of the
    /// "third" byte, and let the first renormalization refill consume the
    /// next two bytes naturally.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer than two bytes starting at `byte_offset`.
    pub fn new(data: &'a [u8], byte_offset: usize) -> Self {
        CabacReader {
            // Evaluated first, so the bounds assert runs before `pos` is built.
            low: Self::initial_low(data, byte_offset, "init"),
            range: 0x1FE,
            data,
            pos: byte_offset + 2,
        }
    }

    /// Shared by [`Self::new`] and [`Self::reinit_at`]: check that two bytes
    /// are available at `byte_offset` and build the initial `low` register
    /// from them. `what` only names the caller in the panic message.
    fn initial_low(data: &[u8], byte_offset: usize, what: &str) -> u32 {
        // Compare against `len - 2` instead of computing `byte_offset + 2`,
        // which could overflow for an absurd offset.
        assert!(
            data.len() >= 2 && byte_offset <= data.len() - 2,
            "CABAC {what} needs at least 2 bytes"
        );
        (u32::from(data[byte_offset]) << 18)
            + (u32::from(data[byte_offset + 1]) << 10)
            + (1 << 9)
    }

    /// Byte at `index`, widened to `u32`; bytes past the end of the RBSP
    /// read as zero.
    #[inline]
    fn byte_at(&self, index: usize) -> u32 {
        self.data.get(index).copied().map_or(0, u32::from)
    }

    /// The next two RBSP bytes packed as `b0 << 9 | b1 << 1`, i.e. aligned to
    /// the 16-bit buffered region of `low` just above bit 0.
    #[inline]
    fn next_pair(&self) -> u32 {
        (self.byte_at(self.pos) << 9) | (self.byte_at(self.pos + 1) << 1)
    }

    /// Refill the buffered region of `low` after a context-coded bin's
    /// renormalization emptied it (the `read_bits()` step of the
    /// renormalization process, HEVC spec 9.3.4.3.3).
    ///
    /// The new bits are shifted left by how far `low`'s lowest set bit sits
    /// above bit `CABAC_BITS`; callers only invoke this when
    /// `low & CABAC_MASK == 0`, so that distance is non-negative.
    #[inline]
    fn refill2(&mut self) {
        let shift = self.low.trailing_zeros().wrapping_sub(CABAC_BITS);
        let fresh = self.next_pair().wrapping_sub(CABAC_MASK);
        self.low = self.low.wrapping_add(fresh << shift);
        self.pos += 2;
    }

    /// Refill used by bypass / terminate paths (single renormalization step).
    #[inline]
    fn refill(&mut self) {
        self.low = self
            .low
            .wrapping_add(self.next_pair())
            .wrapping_sub(CABAC_MASK);
        self.pos += 2;
    }

    /// Decode a context-coded bin (HEVC spec 9.3.4.3.2). Returns 0 or 1 and
    /// updates the context state.
    #[inline]
    pub fn decode_bin(&mut self, state: &mut u8) -> u32 {
        let s = *state;
        let range_lps = lps_range_lookup(self.range, s);

        self.range -= range_lps;
        // FFmpeg-style branchless MPS/LPS resolution: build a sign-extended
        // mask that's 0 for MPS and 0xFFFFFFFF for LPS.
        let lps_mask =
            (((self.range << (CABAC_BITS + 1)).wrapping_sub(self.low)) as i32 >> 31) as u32;

        self.low = self
            .low
            .wrapping_sub((self.range << (CABAC_BITS + 1)) & lps_mask);
        self.range = self
            .range
            .wrapping_add(range_lps.wrapping_sub(self.range) & lps_mask);

        // MPS: index 128 + s; LPS: s ^ -1 == -s - 1, i.e. index 127 - s.
        let s_signed = (s as i32) ^ (lps_mask as i32);
        *state = MLPS_STATE[(128 + s_signed) as usize];
        let bit = (s_signed & 1) as u32;

        // Renormalization (spec 9.3.4.3.3).
        let shift = norm_shift(self.range);
        self.range <<= shift;
        self.low = self.low.wrapping_shl(shift);
        if self.low & CABAC_MASK == 0 {
            self.refill2();
        }
        bit
    }

    /// Decode a bypass bin (HEVC spec 9.3.4.3.4). Equiprobable, no context.
    #[inline]
    pub fn decode_bypass(&mut self) -> u32 {
        self.low = self.low.wrapping_add(self.low);
        if self.low & CABAC_MASK == 0 {
            self.refill();
        }
        let range = self.range << (CABAC_BITS + 1);
        if self.low < range {
            0
        } else {
            self.low = self.low.wrapping_sub(range);
            1
        }
    }

    /// Decode `n` bypass bins as an unsigned integer (MSB first). If `n`
    /// exceeds 32, only the last 32 bins are kept.
    pub fn decode_bypass_bits(&mut self, n: u8) -> u32 {
        let mut val = 0u32;
        for _ in 0..n {
            val = (val << 1) | self.decode_bypass();
        }
        val
    }

    /// Decode a bin in terminate mode (HEVC spec 9.3.4.3.5), as used for
    /// `end_of_slice_segment_flag` and `pcm_flag`.
    /// Returns 1 if the bin is 1 (e.g. the slice is terminating), 0 otherwise.
    pub fn decode_terminate(&mut self) -> u32 {
        self.range -= 2;
        if self.low < self.range << (CABAC_BITS + 1) {
            // Renormalize once: shift by 1 only if range dropped below 0x100.
            let shift = (self.range.wrapping_sub(0x100)) >> 31;
            self.range <<= shift;
            self.low = self.low.wrapping_shl(shift);
            if self.low & CABAC_MASK == 0 {
                self.refill();
            }
            0
        } else {
            1
        }
    }

    /// Current byte position in the RBSP (used for diagnostics).
    pub fn position(&self) -> usize {
        self.pos
    }

    /// Borrow the underlying RBSP byte slice (used by PCM sample decoding
    /// after `pcm_byte_position` identifies where the raw bytes start).
    pub fn rbsp(&self) -> &[u8] {
        self.data
    }

    /// After decoding a `pcm_flag = 1` terminate bin, compute the byte offset
    /// in the underlying RBSP where the first raw PCM byte lives. Mirrors
    /// FFmpeg's `skip_bytes` (`libavcodec/cabac_functions.h`) which walks the
    /// `bytestream` pointer back by up to two bytes based on the current
    /// renormalization state (`low & 0x1` / `low & 0x1FF`).
    ///
    /// Background: `self.pos` is the byte CABAC would next refill from. But
    /// because of the 16-bit buffered window and the post-decode
    /// renormalization, anywhere from zero to two of those "already-consumed"
    /// bytes are still sitting in `low` unused when the terminate bin fires.
    /// Those unused bytes become the PCM byte stream; subtracting them from
    /// `self.pos` gives the address of the first PCM byte.
    pub fn pcm_byte_position(&self) -> usize {
        let mut ptr = self.pos;
        if self.low & 0x1 != 0 {
            ptr -= 1;
        }
        // CABAC_BITS == 16 in our implementation, so the second check always
        // applies (see FFmpeg's `#if CABAC_BITS == 16`).
        if self.low & 0x1FF != 0 {
            ptr -= 1;
        }
        ptr
    }

    /// Re-initialize the CABAC engine at a new byte offset. Used after the
    /// PCM block's raw bytes have been consumed, to resume CABAC decoding at
    /// the next byte boundary (HEVC spec 7.3.8.5 / FFmpeg `ff_init_cabac_decoder`).
    ///
    /// # Panics
    ///
    /// Panics if fewer than two bytes are available at `byte_offset`.
    pub fn reinit_at(&mut self, byte_offset: usize) {
        self.low = Self::initial_low(self.data, byte_offset, "reinit");
        self.range = 0x1FE;
        self.pos = byte_offset + 2;
    }
}

/// Derive the initial state of one CABAC context from its packed HEVC init
/// value and the slice QP (spec 9.3.4.2.1).
///
/// `init_value` carries `slopeIdx` in its high nibble and `offsetIdx` in its
/// low nibble. They give `m = slopeIdx * 5 - 45` and `n = (offsetIdx << 3) - 16`,
/// and `preCtxState = ((m * Clip3(0, 51, QP)) >> 4) + n`, as in H.264; only the
/// nibble packing is HEVC-specific.
///
/// The result uses the FFmpeg packed form (`2 * pStateIdx + valMps`) that
/// [`CabacReader::decode_bin`] consumes.
pub fn init_state(init_value: u8, slice_qp: i32) -> u8 {
    let m = i32::from(init_value >> 4) * 5 - 45;
    let n = (i32::from(init_value & 0x0F) << 3) - 16;
    let pre_ctx_state = ((m * slice_qp.clamp(0, 51)) >> 4) + n;

    // With preCtxState <= 63 the packed value is 126 - 2*pre (valMps = 0);
    // with preCtxState >= 64 it is 2*pre - 127 (valMps = 1). Both follow from
    // x = 2*pre - 127: keep x when non-negative, otherwise take !x == -x - 1.
    let signed = 2 * pre_ctx_state - 127;
    let folded = if signed < 0 { !signed } else { signed };

    // Spec clipping of preCtxState to [1, 126]: anything above 124 becomes
    // 124 (even, pre <= 1) or 125 (odd, pre >= 126).
    let packed = if folded > 124 { 124 | (folded & 1) } else { folded };
    packed as u8
}

/// Initialize several contexts at once with [`init_state`]: `state[i]` is
/// derived from `init_values[i]`.
///
/// The two slices are expected to have equal length (checked in debug
/// builds). Entries of `state` beyond `init_values.len()` are left untouched.
pub fn init_states(init_values: &[u8], slice_qp: i32, state: &mut [u8]) {
    debug_assert_eq!(init_values.len(), state.len());
    let targets = &mut state[..init_values.len()];
    for (slot, &value) in targets.iter_mut().zip(init_values) {
        *slot = init_state(value, slice_qp);
    }
}

/// Select the row of [`INIT_VALUES`] (HEVC `initType`) for a slice, given
/// its slice type and `cabac_init_flag` (spec 9.3.2.2 / FFmpeg
/// `cabac_init_state`).
///
/// Mapping:
/// - I slice → 0 (`cabac_init_flag` is not present in I slices, ignored)
/// - P slice → 1, or 2 if `cabac_init_flag` is set
/// - B slice → 2, or 1 if `cabac_init_flag` is set
pub fn init_type_for_slice(slice_type: SliceType, cabac_init_flag: bool) -> usize {
    // `SliceType` uses the HEVC slice_type codes (B = 0, P = 1, I = 2), so
    // `2 - code` yields the default initType: I → 0, P → 1, B → 2.
    let default_init_type = 2 - slice_type as i32;
    let swap_tables = cabac_init_flag && slice_type != SliceType::I;
    let chosen = if swap_tables {
        // XOR with 3 exchanges the P and B rows (1 <-> 2).
        default_init_type ^ 3
    } else {
        default_init_type
    };
    chosen as usize
}

/// Full set of HEVC CABAC contexts for one slice. Layout matches FFmpeg's
/// `cabac_state[HEVC_CONTEXTS]` exactly so the offsets in
/// [`crate::cabac_tables::ctx`] index directly into [`Self::state`].
pub struct CabacContexts {
    /// Packed context states (`2 * pStateIdx + valMps`), one per context.
    pub state: [u8; HEVC_CONTEXTS],
}

impl CabacContexts {
    /// Initialize all [`HEVC_CONTEXTS`] contexts from the slice QP, slice type
    /// and `cabac_init_flag` (HEVC spec 9.3.2.2).
    ///
    /// The init-value row is picked with [`init_type_for_slice`], and each
    /// context is then derived from its init value with [`init_state`].
    pub fn init(slice_qp: i32, slice_type: SliceType, cabac_init_flag: bool) -> Self {
        let init_type = init_type_for_slice(slice_type, cabac_init_flag);
        let row = &INIT_VALUES[init_type];
        let mut state = [0u8; HEVC_CONTEXTS];
        // Indexing `row` keeps the original panic if a table row were short.
        for (i, slot) in state.iter_mut().enumerate() {
            *slot = init_state(row[i], slice_qp);
        }
        Self { state }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Bypass-only stream whose leading byte is `0b1011_0100 = 0xB4`, used to
    /// verify the bypass path without depending on context tables.
    ///
    /// Bypass bins never change `range`, which stays at its initial 510. Each
    /// bin doubles the offset, shifts in the next stream bit, and emits 1 when
    /// the offset reaches the range (subtracting it). Since 510 is close to
    /// 512 the bins track the stream bits, but only approximately in general.
    /// For this input, in the spec's 9-bit units the offset starts at 360
    /// (first 9 stream bits) and is 720, 420, 840, 660, 300, 600, 180, 360
    /// before each comparison. None of these is close to 510, so the eight
    /// bins are exactly the bits of `0xB4`, MSB first.
    #[test]
    fn test_decode_bypass_round_trip() {
        // 8 bytes comfortably cover CABAC's 2-byte init footprint plus the
        // first 2-byte refill, so no reads fall past the end of the buffer.
        let data = [0xB4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        let mut cabac = CabacReader::new(&data, 0);
        // After reading 0xB4 = 0b1011_0100 we should observe these bits in order.
        let expected = [1, 0, 1, 1, 0, 1, 0, 0];
        for &b in &expected {
            assert_eq!(cabac.decode_bypass(), b, "bypass bin mismatch");
        }
    }

    /// Decoding `n` bypass bits should equal the multi-bit helper.
    #[test]
    fn test_decode_bypass_bits_matches_loop() {
        let data = [0xCA, 0xFE, 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC];
        let mut a = CabacReader::new(&data, 0);
        let mut b = CabacReader::new(&data, 0);
        let n = 12;
        let v1 = a.decode_bypass_bits(n);
        let mut v2 = 0u32;
        for _ in 0..n {
            v2 = (v2 << 1) | b.decode_bypass();
        }
        assert_eq!(v1, v2);
        // An n-bin result must fit in n bits: everything from bit n up is 0.
        assert_eq!(v1 >> n, 0);
    }

    /// Decoding past the end of the RBSP must not panic: refills read the
    /// missing bytes as zero and keep advancing.
    #[test]
    fn test_bypass_past_end_of_data_does_not_panic() {
        let data = [0x12, 0x34];
        let mut cabac = CabacReader::new(&data, 0);
        for _ in 0..64 {
            assert!(cabac.decode_bypass() <= 1);
        }
        assert!(cabac.position() > data.len());
    }

    /// `init_state` clamps QP to [0, 51] and produces a state in [0, 125].
    #[test]
    fn test_init_state_in_range() {
        for &iv in &[0u8, 0x10, 0x77, 0x88, 0xC9, 0xFE, 0xFF] {
            for &qp in &[-10i32, 0, 1, 26, 51, 60] {
                let s = init_state(iv, qp);
                assert!(s <= 125, "state {} out of range for iv={} qp={}", s, iv, qp);
            }
        }
    }

    /// Hand-computed initial states for the Phase 1 fixture's slice QP (25)
    /// and a handful of context offsets. Each derivation is in a comment for
    /// auditability — if any of these break, the table or formula is wrong.
    #[test]
    fn test_init_known_states_at_qp25() {
        use crate::cabac_tables::ctx;
        let ctxs = CabacContexts::init(25, SliceType::I, false);

        // split_cu_flag[0] (init_value = 139): slope=8 offset=11
        //   m = -5, n = 72
        //   pre_inner = (-5 * 25) >> 4 = -8;  pre_inner + n = 64
        //   2*64 - 127 = 1;  state = 1
        assert_eq!(ctxs.state[ctx::SPLIT_CODING_UNIT_FLAG], 1);

        // prev_intra_luma_pred_flag (init_value = 184): slope=11 offset=8
        //   m = 10, n = 48
        //   pre_inner = (10 * 25) >> 4 = 15;  pre_inner + n = 63
        //   2*63 - 127 = -1; |.| = 0;  state = 0
        assert_eq!(ctxs.state[ctx::PREV_INTRA_LUMA_PRED_FLAG], 0);

        // cbf_luma[0] (init_value = 111): slope=6 offset=15
        //   m = -15, n = 104
        //   pre_inner = (-15 * 25) >> 4 = -24;  pre_inner + n = 80
        //   2*80 - 127 = 33;  state = 33
        assert_eq!(ctxs.state[ctx::CBF_LUMA], 33);

        // intra_chroma_pred_mode[0] (init_value = 63): slope=3 offset=15
        //   m = -30, n = 104
        //   pre_inner = (-30 * 25) >> 4 = -47;  pre_inner + n = 57
        //   2*57 - 127 = -13;  state = 12
        assert_eq!(ctxs.state[ctx::INTRA_CHROMA_PRED_MODE], 12);
    }

    /// `init_type_for_slice` mapping must match HEVC spec 9.3.2.2.
    #[test]
    fn test_init_type_for_slice() {
        // I-slice ignores cabac_init_flag and always picks 0.
        assert_eq!(init_type_for_slice(SliceType::I, false), 0);
        assert_eq!(init_type_for_slice(SliceType::I, true), 0);
        // P-slice: 1 by default, 2 with cabac_init_flag.
        assert_eq!(init_type_for_slice(SliceType::P, false), 1);
        assert_eq!(init_type_for_slice(SliceType::P, true), 2);
        // B-slice: 2 by default, 1 with cabac_init_flag.
        assert_eq!(init_type_for_slice(SliceType::B, false), 2);
        assert_eq!(init_type_for_slice(SliceType::B, true), 1);
    }

    /// `init_state` matches the explicit two-step HEVC formula.
    #[test]
    fn test_init_state_matches_spec_formula() {
        // Recreate (pStateIdx, valMps) by hand and compare to packed encoding.
        for iv in 0u8..=255 {
            for &qp in &[0, 13, 26, 37, 51] {
                let slope_idx = (iv >> 4) as i32;
                let offset_idx = (iv & 0x0F) as i32;
                let m = slope_idx * 5 - 45;
                let n = (offset_idx << 3) - 16;
                let pre = (((m * qp) >> 4) + n).clamp(1, 126);
                let (p_state_idx, val_mps): (i32, i32) = if pre <= 63 {
                    (63 - pre, 0)
                } else {
                    (pre - 64, 1)
                };
                let expected_packed = (2 * p_state_idx + val_mps) as u8;
                // The packed encoding uses the same formula but via the
                // FFmpeg `|2*pre - 127|` trick. Verify equivalence.
                let actual = init_state(iv, qp);
                assert_eq!(
                    actual, expected_packed,
                    "init_state(iv={iv:#04x}, qp={qp}) packed mismatch: got {actual}, want {expected_packed}",
                );
            }
        }
    }

    /// `reinit_at` should reset the arithmetic engine to a fresh state,
    /// equivalent to constructing a new `CabacReader` at the same offset.
    #[test]
    fn test_reinit_at_matches_fresh_new() {
        let data = [0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67];
        let mut a = CabacReader::new(&data, 0);
        // Consume a few bypass bins to advance state.
        let _ = a.decode_bypass_bits(4);
        a.reinit_at(4);
        let b = CabacReader::new(&data, 4);
        assert_eq!(a.low, b.low);
        assert_eq!(a.range, b.range);
        assert_eq!(a.pos, b.pos);
    }
}