Skip to main content

oximedia_codec/flac/
rice.rs

//! Rice coding for FLAC residuals.
//!
//! Rice coding is a special case of Golomb coding in which the divisor is a power of two, `m = 2^k`.
//! FLAC uses partitioned Rice coding to compress the LPC residuals: the residuals are split into
//! partitions, and each partition's Rice parameter `k` is optimised to minimise bit usage.
7
8#![forbid(unsafe_code)]
9#![allow(clippy::cast_possible_truncation)]
10#![allow(clippy::cast_sign_loss)]
11#![allow(clippy::cast_possible_wrap)]
12
/// Largest Rice parameter `k` that [`optimal_rice_param`] will try.
///
/// FLAC permits parameters 0-14 with the 4-bit (Rice1) partition encoding and
/// 0-30 with the 5-bit (Rice2) encoding; this module stays within the Rice1 range.
pub const MAX_RICE_PARAM: u8 = 14;
15
/// Map a signed residual to an unsigned zigzag-encoded value.
///
/// FLAC encodes signed residuals as zigzag: `0 → 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, ...`,
/// i.e. a non-negative `v` maps to `2v` and a negative `v` maps to `-2v - 1`.
///
/// Defined for the whole `i32` range, including `i32::MIN` (which maps to `u32::MAX`).
#[inline]
pub fn zigzag_encode(v: i32) -> u32 {
    // `v >> 31` is an arithmetic shift: all ones for negative `v`, zero otherwise.
    // XOR-ing it into the doubled value moves negatives onto the odd codes. Unlike
    // computing `-v - 1`, this never overflows (the old form panicked on i32::MIN).
    ((v as u32) << 1) ^ ((v >> 31) as u32)
}
27
/// Invert [`zigzag_encode`]: even codes become non-negative values and odd codes
/// become negative ones (`0 → 0, 1 → -1, 2 → 1, 3 → -2, ...`).
#[inline]
pub fn zigzag_decode(u: u32) -> i32 {
    // Halve to recover the magnitude, then flip every bit when the low (sign) bit is set.
    let magnitude = (u >> 1) as i32;
    let sign_mask = -((u & 1) as i32);
    magnitude ^ sign_mask
}
37
38/// Compute the Rice bit cost for encoding `residuals` with parameter `k`.
39///
40/// `cost = sum(1 + k + (zigzag(r) >> k))` bits per sample.
41#[must_use]
42pub fn rice_bit_cost(residuals: &[i32], k: u8) -> u64 {
43    residuals
44        .iter()
45        .map(|&r| {
46            let u = zigzag_encode(r);
47            let quotient = u >> k;
48            1u64 + u64::from(k) + u64::from(quotient)
49        })
50        .sum()
51}
52
53/// Select the optimal Rice parameter for a partition of residuals.
54///
55/// Tests `k = 0..=MAX_RICE_PARAM` and returns the best.
56#[must_use]
57pub fn optimal_rice_param(residuals: &[i32]) -> u8 {
58    if residuals.is_empty() {
59        return 0;
60    }
61    let mut best_k = 0u8;
62    let mut best_cost = u64::MAX;
63    for k in 0..=MAX_RICE_PARAM {
64        let cost = rice_bit_cost(residuals, k);
65        if cost < best_cost {
66            best_cost = cost;
67            best_k = k;
68        }
69    }
70    best_k
71}
72
73/// Encode residuals using Rice coding with parameter `k`.
74///
75/// Returns the packed bit stream as a `Vec<u8>` (MSB-first, zero-padded to byte boundary).
76#[must_use]
77pub fn rice_encode(residuals: &[i32], k: u8) -> Vec<u8> {
78    let mut bits: Vec<bool> = Vec::new();
79
80    for &r in residuals {
81        let u = zigzag_encode(r);
82        let quotient = u >> k;
83        let remainder = u & ((1u32 << k) - 1);
84
85        // Unary-coded quotient: `quotient` ones followed by a zero
86        for _ in 0..quotient {
87            bits.push(true);
88        }
89        bits.push(false);
90
91        // Binary `k` bits of remainder (MSB first)
92        for bit_idx in (0..k).rev() {
93            bits.push((remainder >> bit_idx) & 1 != 0);
94        }
95    }
96
97    // Pack bits into bytes (MSB-first)
98    let mut out = Vec::with_capacity((bits.len() + 7) / 8);
99    let mut byte = 0u8;
100    let mut fill = 0u8;
101    for bit in bits {
102        byte = (byte << 1) | u8::from(bit);
103        fill += 1;
104        if fill == 8 {
105            out.push(byte);
106            byte = 0;
107            fill = 0;
108        }
109    }
110    if fill > 0 {
111        out.push(byte << (8 - fill));
112    }
113    out
114}
115
/// Rice decoder state: a bit-level cursor over a Rice-coded byte stream.
///
/// Bits are consumed MSB-first, the same order in which [`rice_encode`] packs them.
#[derive(Debug, Clone)]
pub struct RiceDecoder<'a> {
    /// The Rice-coded input bytes.
    data: &'a [u8],
    /// Index of the byte that holds the next unread bit.
    byte_pos: usize,
    /// Offset of the next unread bit within `data[byte_pos]`, counted from the MSB (0..=7).
    bit_pos: u8,
}
122
123impl<'a> RiceDecoder<'a> {
124    /// Create a decoder over a Rice-coded byte stream.
125    #[must_use]
126    pub fn new(data: &'a [u8]) -> Self {
127        Self {
128            data,
129            byte_pos: 0,
130            bit_pos: 0,
131        }
132    }
133
134    fn read_bit(&mut self) -> Option<bool> {
135        if self.byte_pos >= self.data.len() {
136            return None;
137        }
138        let bit = (self.data[self.byte_pos] >> (7 - self.bit_pos)) & 1 != 0;
139        self.bit_pos += 1;
140        if self.bit_pos == 8 {
141            self.byte_pos += 1;
142            self.bit_pos = 0;
143        }
144        Some(bit)
145    }
146
147    /// Decode one Rice-coded residual with parameter `k`.
148    pub fn decode_one(&mut self, k: u8) -> Option<i32> {
149        // Read unary quotient
150        let mut quotient = 0u32;
151        loop {
152            let bit = self.read_bit()?;
153            if !bit {
154                break;
155            }
156            quotient += 1;
157            if quotient > 1024 * 1024 {
158                return None; // guard against corrupt data
159            }
160        }
161
162        // Read `k` remainder bits
163        let mut remainder = 0u32;
164        for _ in 0..k {
165            let bit = self.read_bit()?;
166            remainder = (remainder << 1) | u32::from(bit);
167        }
168
169        let u = (quotient << k) | remainder;
170        Some(zigzag_decode(u))
171    }
172
173    /// Decode `count` residuals with parameter `k`.
174    pub fn decode_n(&mut self, count: usize, k: u8) -> Vec<i32> {
175        (0..count).map_while(|_| self.decode_one(k)).collect()
176    }
177}
178
179// =============================================================================
180// Tests
181// =============================================================================
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zigzag_encode_decode_identity() {
        for v in [
            -100i32,
            -1,
            0,
            1,
            100,
            i32::from(i16::MIN),
            i32::from(i16::MAX),
        ] {
            let u = zigzag_encode(v);
            let back = zigzag_decode(u);
            assert_eq!(back, v, "zigzag roundtrip failed for {v}");
        }
    }

    #[test]
    fn test_zigzag_interleaves_signs() {
        // Expected mapping: 0 → 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, -3 → 5, 3 → 6.
        let signed = [0i32, -1, 1, -2, 2, -3, 3];
        for (expected, &v) in (0u32..).zip(signed.iter()) {
            assert_eq!(zigzag_encode(v), expected, "zigzag_encode({v})");
            assert_eq!(zigzag_decode(expected), v, "zigzag_decode({expected})");
        }
    }

    #[test]
    fn test_rice_bit_cost_zero_residuals() {
        let res = vec![0i32; 16];
        let cost = rice_bit_cost(&res, 0);
        // Each 0 costs 1 (unary 0) + 0 (k=0) = 1 bit → 16 total
        assert_eq!(cost, 16);
    }

    #[test]
    fn test_rice_bit_cost_matches_encoded_length() {
        // The cost model must agree with what the encoder actually emits.
        let residuals = [0i32, 3, -7, 12, -40, 100, -1, 5];
        for k in 0..=MAX_RICE_PARAM {
            let bits = rice_bit_cost(&residuals, k);
            let bytes = rice_encode(&residuals, k).len() as u64;
            assert_eq!(bytes, (bits + 7) / 8, "k = {k}");
        }
    }

    #[test]
    fn test_rice_encode_decode_roundtrip() {
        let residuals = vec![0i32, 1, -1, 2, -2, 5, -5, 10, -10];
        let k = optimal_rice_param(&residuals);
        let encoded = rice_encode(&residuals, k);
        let mut dec = RiceDecoder::new(&encoded);
        let decoded = dec.decode_n(residuals.len(), k);
        assert_eq!(decoded, residuals, "Rice roundtrip must be lossless");
    }

    #[test]
    fn test_rice_roundtrip_every_param() {
        let residuals = [0i32, 1, -1, 7, -8, 31, -32, 200, -200];
        for k in 0..=MAX_RICE_PARAM {
            let encoded = rice_encode(&residuals, k);
            let decoded = RiceDecoder::new(&encoded).decode_n(residuals.len(), k);
            assert_eq!(decoded, residuals, "roundtrip failed at k = {k}");
        }
    }

    #[test]
    fn test_rice_encode_empty() {
        let encoded = rice_encode(&[], 4);
        assert!(encoded.is_empty());
    }

    #[test]
    fn test_optimal_rice_param_small_residuals() {
        // Small residuals → small k is optimal
        let residuals = vec![0i32; 32];
        let k = optimal_rice_param(&residuals);
        assert_eq!(k, 0, "All-zero residuals → k=0 is optimal");
    }

    #[test]
    fn test_optimal_rice_param_large_residuals() {
        // Large residuals → larger k is better
        let residuals: Vec<i32> = (0..32).map(|i| i * 1000).collect();
        let k_large = optimal_rice_param(&residuals);
        let k_small = optimal_rice_param(&vec![0i32; 32]);
        assert!(k_large >= k_small, "Large residuals should use larger k");
    }

    #[test]
    fn test_rice_decode_n_partial() {
        // Asking for more residuals than the stream holds stops at the end of the
        // data; the zero-bit padding of the final byte may decode as extra zeros.
        let residuals = vec![1i32, 2, 3];
        let k = 1;
        let encoded = rice_encode(&residuals, k);
        let mut dec = RiceDecoder::new(&encoded);
        let decoded = dec.decode_n(100, k);
        assert!(decoded.len() >= residuals.len());
        assert!(decoded.len() < 100, "decoding must stop when the data runs out");
        assert_eq!(&decoded[..residuals.len()], &residuals[..]);
        assert!(
            decoded[residuals.len()..].iter().all(|&v| v == 0),
            "padding bits decode as zero residuals"
        );
    }

    #[test]
    fn test_rice_decode_truncated_unary() {
        // Eight one-bits with no terminating zero: the codeword is incomplete.
        let mut dec = RiceDecoder::new(&[0xFF]);
        assert_eq!(dec.decode_one(0), None);
    }
}