// libflo_audio/core/rice.rs

//! Rice coding implementation for residual compression.
2
3use super::audio_constants::{f32_to_i32, i32_to_f32, I16_MAX_F64};
4
5/// Estimate optimal Rice parameter from float residuals
6/// Residuals are expected to be in -1.0 to 1.0 range and will be scaled to 16-bit
7pub fn estimate_rice_parameter(residuals: &[f32]) -> u8 {
8    if residuals.is_empty() {
9        return 10; // Good default for 16-bit audio
10    }
11
12    // Scale to 16-bit range and calculate mean absolute value
13    let mean_abs: f64 = residuals
14        .iter()
15        .map(|&r| (r * I16_MAX_F64 as f32).abs() as f64)
16        .sum::<f64>()
17        / residuals.len() as f64;
18
19    if mean_abs > 1.0 {
20        // Rice parameter k where 2^k approximates mean_abs
21        (mean_abs.log2().round() as u8).clamp(4, 14)
22    } else {
23        4
24    }
25}
26
/// Estimate a Rice parameter `k` for integer residuals.
///
/// The result is the larger of two candidates, capped at 15:
/// - a correctness floor: the smallest `k` for which every zigzag-encoded
///   residual has a quotient (`value >> k`) of at most 255. The encoder
///   (`encode_sample`) caps the unary quotient at 255, so larger quotients do
///   not round-trip.
/// - an efficiency estimate: the bit length of the mean absolute residual.
///
/// Because of the cap at 15, the floor can only be honored while the largest
/// residual magnitude is below 2^22. Beyond that, some quotients may exceed 255
/// and those samples will be truncated by the encoder.
///
/// Returns 4 for an empty slice and 0 when every residual is zero.
pub fn estimate_rice_parameter_i32(residuals: &[i32]) -> u8 {
    // Upper bound on the returned parameter.
    const MAX_K: u8 = 15;
    // A quotient of at most 255 fits in 8 bits.
    const QUOTIENT_BITS: u32 = 8;

    if residuals.is_empty() {
        return 4;
    }

    let max_abs = residuals
        .iter()
        .map(|r| r.unsigned_abs())
        .max()
        .unwrap_or(0);
    if max_abs == 0 {
        return 0;
    }

    // Zigzag maps a value of magnitude m to at most 2*m, so this bounds every
    // unsigned value the encoder will see (u64 so i32::MIN cannot overflow).
    let max_unsigned = 2 * u64::from(max_abs);

    // quotient = unsigned >> k is <= 255 iff unsigned < 2^(k + 8), i.e.
    // k >= bit_length(unsigned) - 8. Saturation gives 0 for values <= 255.
    // bit_length <= 33 here, so the cast to u8 cannot truncate.
    let bit_len = u64::BITS - max_unsigned.leading_zeros();
    let min_k = bit_len.saturating_sub(QUOTIENT_BITS) as u8;

    // Efficiency: k ~ bit length of the mean magnitude (0 when the mean is 0).
    // The mean is <= 2^31, so its bit length (<= 32) fits in u8.
    let sum: u64 = residuals.iter().map(|r| u64::from(r.unsigned_abs())).sum();
    let mean = sum / residuals.len() as u64;
    let mean_k = (u64::BITS - mean.leading_zeros()) as u8;

    min_k.max(mean_k).min(MAX_K)
}
70
71/// Rice encode float residuals (quantizes to 16-bit)
72pub fn encode(residuals: &[f32], k: u8) -> Vec<u8> {
73    let mut bits = BitWriter::new();
74
75    for &residual in residuals {
76        let sample = f32_to_i32(residual);
77        encode_sample(&mut bits, sample, k);
78    }
79
80    bits.into_bytes()
81}
82
83/// Rice encode integer residuals directly
84pub fn encode_i32(residuals: &[i32], k: u8) -> Vec<u8> {
85    let mut bits = BitWriter::new();
86
87    for &sample in residuals {
88        encode_sample(&mut bits, sample, k);
89    }
90
91    bits.into_bytes()
92}
93
94fn encode_sample(bits: &mut BitWriter, sample: i32, k: u8) {
95    // Zigzag encode: map signed to unsigned
96    // 0 → 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, ...
97    let unsigned = ((sample << 1) ^ (sample >> 31)) as u32;
98
99    // Rice coding: quotient and remainder
100    let quotient = unsigned >> k;
101    let remainder = unsigned & ((1 << k) - 1);
102
103    // Unary code for quotient (capped to prevent huge outputs)
104    let q_capped = quotient.min(255);
105    for _ in 0..q_capped {
106        bits.write_bit(1);
107    }
108    bits.write_bit(0);
109
110    // Binary code for remainder
111    for i in (0..k).rev() {
112        bits.write_bit((remainder >> i) & 1);
113    }
114}
115
116/// Rice decode to float residuals
117pub fn decode(encoded: &[u8], k: u8, target_len: usize) -> Vec<f32> {
118    let decoded_i32 = decode_i32(encoded, k, target_len);
119    decoded_i32.iter().map(|&s| i32_to_f32(s)).collect()
120}
121
122/// Rice decode to integer residuals
123pub fn decode_i32(encoded: &[u8], k: u8, target_len: usize) -> Vec<i32> {
124    let mut bits = BitReader::new(encoded);
125    let mut residuals = Vec::with_capacity(target_len);
126
127    for _ in 0..target_len {
128        if bits.is_exhausted() {
129            residuals.push(0);
130            continue;
131        }
132
133        // Read unary quotient
134        let mut quotient = 0u32;
135        while !bits.is_exhausted() && bits.read_bit() == 1 {
136            quotient += 1;
137            if quotient > 255 {
138                break;
139            }
140        }
141
142        // Read binary remainder
143        let mut remainder = 0u32;
144        for _ in 0..k {
145            remainder = (remainder << 1) | bits.read_bit();
146        }
147
148        // Reconstruct unsigned value
149        let unsigned = (quotient << k) | remainder;
150
151        // Zigzag decode
152        // 0 → 0, 1 → -1, 2 → 1, 3 → -2, 4 → 2, ...
153        let signed = ((unsigned >> 1) as i32) ^ (-((unsigned & 1) as i32));
154
155        residuals.push(signed);
156    }
157
158    residuals
159}
160
/// Bit-level writer that packs bits MSB-first into bytes.
#[derive(Debug, Clone, Default)]
pub struct BitWriter {
    /// Completed bytes.
    bytes: Vec<u8>,
    /// Byte currently being filled, starting from its most significant bit.
    current_byte: u8,
    /// Number of bits already written into `current_byte` (0..=7).
    bit_pos: u8,
}

impl BitWriter {
    /// Create an empty writer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append one bit. Any non-zero `bit` is written as 1.
    pub fn write_bit(&mut self, bit: u32) {
        if bit != 0 {
            self.current_byte |= 1 << (7 - self.bit_pos);
        }

        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bytes.push(self.current_byte);
            self.current_byte = 0;
            self.bit_pos = 0;
        }
    }

    /// Append the low `num_bits` bits of `value`, most significant first.
    ///
    /// If `num_bits` exceeds 32, the extra leading positions are written as 0
    /// instead of panicking on an oversized shift.
    // Public helper API; allowed so builds that do not call it stay warning-free.
    #[allow(dead_code)]
    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
        for i in (0..num_bits).rev() {
            let bit = value.checked_shr(u32::from(i)).unwrap_or(0) & 1;
            self.write_bit(bit);
        }
    }

    /// Finish writing and return the bytes, zero-padding a partial last byte.
    pub fn into_bytes(mut self) -> Vec<u8> {
        if self.bit_pos > 0 {
            self.bytes.push(self.current_byte);
        }
        self.bytes
    }

    /// Number of bytes `into_bytes` would return right now.
    // Public helper API; allowed so builds that do not call it stay warning-free.
    #[allow(dead_code)]
    pub fn byte_count(&self) -> usize {
        self.bytes.len() + usize::from(self.bit_pos > 0)
    }
}
215
/// Bit-level reader that consumes a byte slice MSB-first.
///
/// Reads past the end of the slice yield 0 bits rather than failing.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    bytes: &'a [u8],
    /// Index of the byte currently being read.
    byte_pos: usize,
    /// Bit offset within `bytes[byte_pos]`; 0 is the most significant bit.
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Create a reader positioned at the first bit of `bytes`.
    pub fn new(bytes: &'a [u8]) -> Self {
        BitReader {
            bytes,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Read the next bit (0 or 1). Returns 0 once the input is exhausted.
    pub fn read_bit(&mut self) -> u32 {
        let byte = match self.bytes.get(self.byte_pos) {
            Some(&b) => b,
            None => return 0,
        };

        let bit = (byte >> (7 - self.bit_pos)) & 1;

        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bit_pos = 0;
            self.byte_pos += 1;
        }

        u32::from(bit)
    }

    /// Read `num_bits` bits, most significant first, into the low bits of the
    /// result. If `num_bits` exceeds 32, only the last 32 bits read are kept.
    // Public helper API; allowed so builds that do not call it stay warning-free.
    #[allow(dead_code)]
    pub fn read_bits(&mut self, num_bits: u8) -> u32 {
        let mut value = 0u32;
        for _ in 0..num_bits {
            value = (value << 1) | self.read_bit();
        }
        value
    }

    /// True once every bit of the input has been consumed.
    pub fn is_exhausted(&self) -> bool {
        self.byte_pos >= self.bytes.len()
    }

    /// Bytes not yet fully consumed, counting a partially read current byte.
    // Public helper API; allowed so builds that do not call it stay warning-free.
    #[allow(dead_code)]
    pub fn remaining_bytes(&self) -> usize {
        self.bytes.len().saturating_sub(self.byte_pos)
    }
}