// gamut_bitstream/symbol.rs

/// Number of low-order bits discarded from a 15-bit inverse cumulative
/// frequency before it is multiplied with the top byte of the range
/// (see `encode_q15`). The product is then shifted right by
/// `7 - EC_PROB_SHIFT`.
const EC_PROB_SHIFT: u32 = 6;
/// Range width added per symbol that follows the boundary being scaled, so
/// that every symbol keeps a non-zero share of the range after the lossy
/// multiplication above.
const EC_MIN_PROB: u32 = 4;
/// Total probability mass of a CDF table (`2^15`). The last entry of every
/// table passed to `encode_symbol` must equal this value, and it is also the
/// encoder's initial range.
const CDF_PROB_TOP: u32 = 1 << 15;

/// State of a range encoder that turns symbols coded against 15-bit CDF
/// tables into a byte stream.
///
/// Symbols are appended with `encode_symbol` / `encode_literal`; the finished
/// byte stream is obtained by consuming the encoder with `finish`.
#[derive(Debug, Clone)]
pub struct SymbolEncoder {
    /// Low end of the current coding interval, holding the bits that have
    /// not yet been moved into `precarry`.
    low: u64,
    /// Width of the current coding interval. Renormalised after every symbol
    /// so that it is at least `2^15` on entry to the next one.
    rng: u32,
    /// Bit counter that decides when `normalize` flushes output: it starts
    /// negative and a flush happens once it would reach zero or more.
    cnt: i32,
    /// Flushed output units before carry propagation. Each entry carries one
    /// output byte in its low 8 bits plus any pending carry above them; the
    /// carries are resolved in `finish`.
    precarry: Vec<u16>,
}

43impl Default for SymbolEncoder {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl SymbolEncoder {
50 #[must_use]
52 pub fn new() -> Self {
53 Self {
54 low: 0,
55 rng: CDF_PROB_TOP,
56 cnt: -9,
57 precarry: Vec::new(),
58 }
59 }
60
61 pub fn encode_symbol(&mut self, symbol: usize, cdf: &[u16]) {
67 let nsyms = cdf.len();
68 debug_assert!(symbol < nsyms);
69 debug_assert_eq!(u32::from(cdf[nsyms - 1]), CDF_PROB_TOP);
70 let fl = if symbol > 0 {
73 CDF_PROB_TOP - u32::from(cdf[symbol - 1])
74 } else {
75 CDF_PROB_TOP
76 };
77 let fh = CDF_PROB_TOP - u32::from(cdf[symbol]);
78 self.encode_q15(fl, fh, symbol as u32, nsyms as u32);
79 }
80
81 pub fn encode_literal(&mut self, value: u32, n: u32) {
86 const BOOL_CDF: [u16; 2] = [1 << 14, 1 << 15];
87 for i in (0..n).rev() {
88 self.encode_symbol(((value >> i) & 1) as usize, &BOOL_CDF);
89 }
90 }
91
92 fn encode_q15(&mut self, fl: u32, fh: u32, s: u32, nsyms: u32) {
95 let mut low = self.low;
96 let mut r = self.rng;
97 debug_assert!(r >= CDF_PROB_TOP);
98 let n = nsyms - 1;
99 if fl < CDF_PROB_TOP {
100 let u = (((r >> 8) * (fl >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT))
101 + EC_MIN_PROB * (n - (s - 1));
102 let v =
103 (((r >> 8) * (fh >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT)) + EC_MIN_PROB * (n - s);
104 debug_assert!(u <= r && v < u);
105 low += u64::from(r - u);
106 r = u - v;
107 } else {
108 let v =
110 (((r >> 8) * (fh >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT)) + EC_MIN_PROB * (n - s);
111 debug_assert!(v < r);
112 r -= v;
113 }
114 self.normalize(low, r);
115 }
116
117 fn normalize(&mut self, mut low: u64, rng: u32) {
120 let d = rng.leading_zeros() - 16;
122 let mut c = self.cnt;
123 let mut s = c + d as i32;
124 if s >= 0 {
125 c += 16;
126 let mut m = (1u64 << c) - 1;
127 if s >= 8 {
128 self.precarry.push((low >> c) as u16);
129 low &= m;
130 c -= 8;
131 m = (1u64 << c) - 1;
132 }
133 self.precarry.push((low >> c) as u16);
134 s = c + d as i32 - 24;
135 low &= m;
136 }
137 self.low = low << d;
138 self.rng = rng << d;
139 self.cnt = s;
140 }
141
142 #[must_use]
146 pub fn finish(mut self) -> Vec<u8> {
147 let l = self.low;
148 let mut c = self.cnt;
149 let mut s = 10 + c;
150 let m: u64 = 0x3FFF;
151 let mut e = ((l + m) & !m) | (m + 1);
152 if s > 0 {
153 let mut n = (1u64 << (c + 16)) - 1;
154 loop {
155 self.precarry.push((e >> (c + 16)) as u16);
156 e &= n;
157 s -= 8;
158 c -= 8;
159 n >>= 8;
160 if s <= 0 {
161 break;
162 }
163 }
164 }
165 let mut out = vec![0u8; self.precarry.len()];
167 let mut carry: u32 = 0;
168 for i in (0..self.precarry.len()).rev() {
169 let val = u32::from(self.precarry[i]) + carry;
170 out[i] = (val & 0xff) as u8;
171 carry = val >> 8;
172 }
173 out
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference decoder used only to check the encoder's output.
    ///
    /// NOTE(review): the structure appears to follow the symbol decoder of
    /// the AV1 bitstream specification — confirm against the spec.
    struct SymbolDecoder<'a> {
        /// Encoded bytes being read.
        data: &'a [u8],
        /// Index of the next bit to read, counted MSB-first from the start.
        bit_pos: usize,
        /// Current decoder window (stored inverted relative to the raw bits).
        value: u32,
        /// Current range width; at least `2^15` between symbols.
        range: u32,
        /// Real data bits still available; may go negative once the decoder
        /// has run past the end and is reading implicit zero padding.
        max_bits: i64,
    }

    impl<'a> SymbolDecoder<'a> {
        /// Reads `n` bits MSB-first; bits past the end of `data` read as 0.
        fn read_f(&mut self, n: u32) -> u32 {
            let mut x = 0u32;
            for _ in 0..n {
                let shift = 7 - (self.bit_pos & 7);
                let bit = self
                    .data
                    .get(self.bit_pos >> 3)
                    .map_or(0, |&byte| (byte >> shift) & 1);
                x = (x << 1) | u32::from(bit);
                self.bit_pos += 1;
            }
            x
        }

        /// Primes the decoder with up to the first 15 bits of `data`.
        fn new(data: &'a [u8]) -> Self {
            let sz = data.len();
            let mut d = Self {
                data,
                bit_pos: 0,
                value: 0,
                range: 1 << 15,
                max_bits: 8 * sz as i64 - 15,
            };
            let num_bits = (sz * 8).min(15) as u32;
            let buf = d.read_f(num_bits);
            // Left-align short inputs to 15 bits, then store inverted.
            let padded = buf << (15 - num_bits);
            d.value = ((1 << 15) - 1) ^ padded;
            d
        }

        /// Decodes one symbol coded against `cdf` and renormalises.
        fn read_symbol(&mut self, cdf: &[u16]) -> usize {
            let n = cdf.len() as u32;
            let mut symbol = 0usize;
            // `prev` / `cur` are the scaled upper and lower edge of the
            // candidate symbol; the first candidate's upper edge is the
            // whole range.
            let mut prev = self.range;
            let mut cur;
            loop {
                let f = (1u32 << 15) - u32::from(cdf[symbol]);
                cur = ((self.range >> 8) * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT);
                cur += EC_MIN_PROB * (n - symbol as u32 - 1);
                if self.value >= cur {
                    break;
                }
                prev = cur;
                symbol += 1;
            }
            self.range = prev - cur;
            self.value -= cur;
            // Shift that restores 16 significant bits in `range`.
            let bits = 15 - (31 - self.range.leading_zeros());
            self.range <<= bits;
            // Only read bits that really exist; the rest is zero padding.
            let num_bits = i64::from(bits).min(self.max_bits.max(0)) as u32;
            let new_data = self.read_f(num_bits);
            let padded = new_data << (bits - num_bits);
            self.value = padded ^ (((self.value + 1) << bits) - 1);
            self.max_bits -= i64::from(bits);
            symbol
        }

        /// Decodes an `n`-bit literal, most significant bit first.
        fn read_literal(&mut self, n: u32) -> u32 {
            const BOOL_CDF: [u16; 2] = [1 << 14, 1 << 15];
            let mut x = 0;
            for _ in 0..n {
                x = (x << 1) | self.read_symbol(&BOOL_CDF) as u32;
            }
            x
        }
    }

    /// Small deterministic linear congruential generator for test inputs.
    struct Lcg(u64);

    impl Lcg {
        /// Advances the state and returns its upper 32 bits.
        fn next_u32(&mut self) -> u32 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 32) as u32
        }

        /// Returns a value in `0..bound` (`bound` must be non-zero).
        fn below(&mut self, bound: u32) -> u32 {
            self.next_u32() % bound
        }
    }

    /// Builds a strictly increasing CDF with `nsyms` entries ending at 32768
    /// from `nsyms - 1` distinct random cut points in `1..=32767`.
    fn random_cdf(rng: &mut Lcg, nsyms: usize) -> Vec<u16> {
        let mut points = Vec::new();
        while points.len() < nsyms - 1 {
            let p = 1 + rng.below(32767) as u16;
            if !points.contains(&p) {
                points.push(p);
            }
        }
        points.sort_unstable();
        points.push(32768);
        points
    }

    #[test]
    fn empty_stream_roundtrips() {
        let enc = SymbolEncoder::new();
        let bytes = enc.finish();
        // With nothing encoded, only the terminating flush byte is written.
        assert_eq!(bytes, vec![0x80u8]);
        // The decoder must accept such a stream without panicking.
        let _ = SymbolDecoder::new(&bytes);
    }

    #[test]
    fn single_symbol_streams_roundtrip() {
        for nsyms in 2..=12usize {
            // Uniform CDF over `nsyms` symbols.
            let mut cdf: Vec<u16> = (1..nsyms).map(|i| (i * 32768 / nsyms) as u16).collect();
            cdf.push(32768);
            for s in 0..nsyms {
                let mut enc = SymbolEncoder::new();
                enc.encode_symbol(s, &cdf);
                let bytes = enc.finish();
                let mut dec = SymbolDecoder::new(&bytes);
                assert_eq!(dec.read_symbol(&cdf), s, "nsyms={nsyms} s={s}");
            }
        }
    }

    #[test]
    fn long_random_symbol_stream_roundtrips() {
        let mut rng = Lcg(0x1234_5678_9abc_def0);
        let cdfs: Vec<Vec<u16>> = (2..=14).map(|n| random_cdf(&mut rng, n)).collect();
        // Each event records the symbol and the index of the table it used.
        let mut events = Vec::new();
        let mut enc = SymbolEncoder::new();
        for _ in 0..20_000 {
            let which = rng.below(cdfs.len() as u32) as usize;
            let cdf = &cdfs[which];
            let s = rng.below(cdf.len() as u32) as usize;
            enc.encode_symbol(s, cdf);
            events.push((s, which));
        }
        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (i, &(s, which)) in events.iter().enumerate() {
            assert_eq!(dec.read_symbol(&cdfs[which]), s, "event {i}");
        }
    }

    #[test]
    fn literals_roundtrip() {
        let mut rng = Lcg(0xdead_beef_0bad_f00d);
        let mut enc = SymbolEncoder::new();
        let mut events = Vec::new();
        for _ in 0..5000 {
            let n = 1 + rng.below(16);
            let v = rng.next_u32() & ((1u32 << n) - 1);
            enc.encode_literal(v, n);
            events.push((v, n));
        }
        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (v, n) in events {
            assert_eq!(dec.read_literal(n), v);
        }
    }

    #[test]
    fn mixed_symbols_and_literals_roundtrip() {
        let mut rng = Lcg(0x0f0f_0f0f_1234_9999);
        let cdf = random_cdf(&mut rng, 8);
        let mut enc = SymbolEncoder::new();
        // `(true, v)` is an 8-bit literal, `(false, s)` a symbol from `cdf`.
        let mut events: Vec<(bool, u32)> = Vec::new();
        for _ in 0..8000 {
            if rng.next_u32() & 1 == 0 {
                let s = rng.below(cdf.len() as u32);
                enc.encode_symbol(s as usize, &cdf);
                events.push((false, s));
            } else {
                let v = rng.next_u32() & 0xff;
                enc.encode_literal(v, 8);
                events.push((true, v));
            }
        }
        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (is_lit, payload) in events {
            if is_lit {
                assert_eq!(dec.read_literal(8), payload);
            } else {
                assert_eq!(dec.read_symbol(&cdf) as u32, payload);
            }
        }
    }
}