/// Number of bits in one range-coder output symbol (one byte).
pub const EC_SYM_BITS: u32 = 8;
/// Width in bits of the range-coder state word.
pub const EC_CODE_BITS: u32 = 32;
/// Largest value a single output symbol can take (all `EC_SYM_BITS` bits set).
pub const EC_SYM_MAX: u32 = !(u32::MAX << EC_SYM_BITS);
/// Right shift that brings the next output symbol (plus carry bit) of `val` down to bit 0.
pub const EC_CODE_SHIFT: u32 = EC_CODE_BITS - 1 - EC_SYM_BITS;
/// State word with only its top bit set; `EC_CODE_TOP - 1` masks the live bits of `val`.
pub const EC_CODE_TOP: u32 = 1 << (EC_CODE_BITS - 1);
/// Renormalisation threshold: the coders shift in/out a symbol while `rng <= EC_CODE_BOT`.
pub const EC_CODE_BOT: u32 = 1 << EC_CODE_SHIFT;
/// Number of code bits taken from the first byte when the decoder is initialised.
pub const EC_CODE_EXTRA: u32 = 1 + (EC_CODE_BITS - 2) % EC_SYM_BITS;
/// Fractional-bit resolution of `tell_frac` (results are in units of 1/2^BITRES bit).
pub const BITRES: i32 = 3;
9
/// State shared by the range encoder and the range decoder.
///
/// Range-coded bytes live at the front of `buf`; raw bits are packed from the
/// back of `buf` towards the front.
#[derive(Clone, Debug)]
pub struct RangeCoder {
    /// Backing byte buffer (output when encoding, input when decoding).
    pub buf: Vec<u8>,
    /// Number of usable bytes in `buf`.
    pub storage: u32,
    /// Number of bytes already written to / read from the end of `buf`.
    pub end_offs: u32,
    /// Bit window holding raw bits that have not yet been flushed / consumed.
    pub end_window: u32,
    /// Number of valid bits currently held in `end_window`.
    pub nend_bits: i32,
    /// Running bit counter that `tell` / `tell_frac` are derived from.
    pub nbits_total: i32,
    /// Number of bytes already written to / read from the front of `buf`.
    pub offs: u32,
    /// Size of the current coding range.
    pub rng: u32,
    /// Encoder: low end of the current range. Decoder: distance of the code
    /// value from the top of the current range.
    pub val: u64,
    /// Encoder: count of pending `0xFF` bytes awaiting carry resolution.
    /// Decoder: `rng / ft` saved by `decode` for the following `update`.
    pub ext: u32,
    /// Encoder: buffered output byte awaiting carry propagation (`-1` = none).
    /// Decoder: the most recently read input byte.
    pub rem: i32,
    /// Non-zero once an error (e.g. buffer overflow) has been detected.
    pub error: i32,
}
25
26impl RangeCoder {
27 pub fn new_encoder(size: u32) -> Self {
28 RangeCoder {
29 buf: vec![0; size as usize],
30 storage: size,
31 end_offs: 0,
32 end_window: 0,
33 nend_bits: 0,
34 nbits_total: 33,
35 offs: 0,
36 rng: 1 << 31,
37 val: 0,
38 ext: 0,
39 rem: -1,
40 error: 0,
41 }
42 }
43
44 pub fn new_decoder(data: &[u8]) -> Self {
45 let storage = data.len() as u32;
46 let buf = data.to_vec();
47 let mut rc = RangeCoder {
48 buf,
49 storage,
50 end_offs: 0,
51 end_window: 0,
52 nend_bits: 0,
53 nbits_total: (EC_CODE_BITS + 1
54 - ((EC_CODE_BITS - EC_CODE_EXTRA) / EC_SYM_BITS) * EC_SYM_BITS)
55 as i32,
56 offs: 0,
57 rng: 1 << EC_CODE_EXTRA,
58 val: 0,
59 ext: 0,
60 rem: 0,
61 error: 0,
62 };
63
64 rc.rem = rc.read_byte() as i32;
65 rc.val = (rc.rng - 1 - (rc.rem as u32 >> (EC_SYM_BITS - EC_CODE_EXTRA))) as u64;
66
67 rc.normalize_decoder();
68 rc
69 }
70
71 fn normalize_decoder(&mut self) {
72 while self.rng <= EC_CODE_BOT {
73 self.nbits_total += EC_SYM_BITS as i32;
74 self.rng <<= EC_SYM_BITS;
75 if self.rng == 0 {
76 debug_assert!(
77 false,
78 "normalize_decoder: rng=0 after shift, corrupt bitstream"
79 );
80 self.error = 1;
81 self.rng = 1;
82 return;
83 }
84
85 let sym = self.rem;
86 self.rem = self.read_byte() as i32;
87
88 let combined_sym = ((sym << EC_SYM_BITS) | self.rem) >> (EC_SYM_BITS - EC_CODE_EXTRA);
89 self.val = ((self.val << EC_SYM_BITS) + (EC_SYM_MAX & !combined_sym as u32) as u64)
90 & (EC_CODE_TOP as u64 - 1);
91 }
92 }
93
94 fn read_byte(&mut self) -> u8 {
95 if self.offs < self.storage {
96 let b = self.buf[self.offs as usize];
97 self.offs += 1;
98 b
99 } else {
100 0
101 }
102 }
103
104 pub fn enc_uint(&mut self, fl: u32, ft: u32) {
105 if ft > (1 << 8) {
106 let mut ft = ft - 1;
107 let s = 32 - ft.leading_zeros() as i32 - 8;
108 self.enc_bits(fl & ((1 << s) - 1), s as u32);
109 let fl = fl >> s;
110 ft >>= s;
111 ft += 1;
112 self.encode(fl, fl + 1, ft);
113 } else if ft > 1 {
114 self.encode(fl, fl + 1, ft);
115 }
116 }
117
118 pub fn dec_uint(&mut self, ft: u32) -> u32 {
119 if ft > (1 << 8) {
120 let mut ft = ft - 1;
121 let s = 32 - ft.leading_zeros() as i32 - 8;
122 let r = self.dec_bits(s as u32);
123 ft >>= s;
124 ft += 1;
125 let fs = self.decode(ft);
126 self.update(fs, fs + 1, ft);
127 (fs << s) | r
128 } else if ft > 1 {
129 let fs = self.decode(ft);
130 self.update(fs, fs + 1, ft);
131 fs
132 } else {
133 0
134 }
135 }
136
137 pub fn enc_bits(&mut self, val: u32, bits: u32) {
138 if bits == 0 {
139 return;
140 }
141 let mut window = self.end_window;
142 let mut used = self.nend_bits;
143 if (used as u32) + bits > EC_CODE_BITS {
144 while used >= EC_SYM_BITS as i32 {
145 self.write_byte_at_end((window & EC_SYM_MAX) as u8);
146 window >>= EC_SYM_BITS;
147 used -= EC_SYM_BITS as i32;
148 }
149 }
150 window |= (val & ((1 << bits) - 1)) << used;
151 used += bits as i32;
152 self.end_window = window;
153 self.nend_bits = used;
154 self.nbits_total += bits as i32;
155 }
156
157 pub fn dec_bits(&mut self, bits: u32) -> u32 {
158 if bits == 0 {
159 return 0;
160 }
161 let mut window = self.end_window;
162 let mut used = self.nend_bits;
163 if used < bits as i32 {
164 loop {
165 let byte = if self.end_offs < self.storage {
166 self.end_offs += 1;
167 self.buf[(self.storage - self.end_offs) as usize]
168 } else {
169 0
170 };
171 window |= (byte as u32) << used;
172 used += 8;
173 if used > 32 - 8 {
174 break;
175 }
176 }
177 }
178 let ret = window & ((1 << bits) - 1);
179 self.end_window = window >> bits;
180 self.nend_bits = used - bits as i32;
181 self.nbits_total += bits as i32;
182 ret
183 }
184
185 pub fn tell_frac(&self) -> i32 {
186 static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
187 let nbits = self.nbits_total << BITRES;
188 let mut l = 32 - self.rng.leading_zeros() as i32;
189 let r = self.rng >> (l - 16);
190 let mut b = (r >> 12) - 8;
191 if b < 7 && r > CORRECTION[b as usize] {
192 b += 1;
193 }
194 l = (l << 3) + b as i32;
195 nbits - l
196 }
197
198 pub fn tell(&self) -> i32 {
199 (self.tell_frac() + 7) >> 3
200 }
201
202 fn write_byte(&mut self, value: u8) {
203 if self.offs + self.end_offs < self.storage {
204 self.buf[self.offs as usize] = value;
205 self.offs += 1;
206 } else {
207 self.error = 1;
208 }
209 }
210
211 fn carry_out(&mut self, c: i32) {
212 if c != EC_SYM_MAX as i32 {
213 let carry = c >> EC_SYM_BITS;
214 if self.rem >= 0 {
215 self.write_byte((self.rem + carry) as u8);
216 }
217 if self.ext > 0 {
218 let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
219 for _ in 0..self.ext {
220 self.write_byte(sym as u8);
221 }
222 self.ext = 0;
223 }
224 self.rem = c & EC_SYM_MAX as i32;
225 } else {
226 self.ext += 1;
227 }
228 }
229
230 pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
231 if ft == 0 {
232 return;
233 }
234 let r = self.rng / ft;
235 if fl > 0 {
236 self.val += (self.rng - r * (ft - fl)) as u64;
237 self.rng = r * (fh - fl);
238 } else {
239 self.rng -= r * (ft - fh);
240 }
241 self.normalize_encoder();
242 }
243
244 fn normalize_encoder(&mut self) {
245 if self.rng == 0 {
246 self.error = 1;
247 self.rng = 1;
248 return;
249 }
250 while self.rng <= EC_CODE_BOT {
251 self.carry_out((self.val >> EC_CODE_SHIFT) as i32);
252 self.val = (self.val << EC_SYM_BITS) & (EC_CODE_TOP as u64 - 1);
253 self.rng <<= EC_SYM_BITS;
254 self.nbits_total = self.nbits_total.wrapping_add(EC_SYM_BITS as i32);
255 }
256 }
257
258 pub fn encode_bit_logp(&mut self, val: bool, logp: u32) {
259 let s = self.rng >> logp;
260 let r = self.rng - s;
261 if val {
262 self.val += r as u64;
263 self.rng = s;
264 } else {
265 self.rng = r;
266 }
267 self.normalize_encoder();
268 }
269
270 pub fn encode_icdf(&mut self, s: i32, icdf: &[u8], ftb: u32) {
271 let r = self.rng >> ftb;
272 if s > 0 {
273 let val = icdf[(s - 1) as usize] as u32;
274 self.val += (self.rng as u64) - (r as u64 * val as u64);
275 let lower = icdf.get(s as usize).copied().unwrap_or(0) as u32;
277 let diff = val - lower;
278 debug_assert!(
279 diff > 0,
280 "encode_icdf: zero-probability symbol s={s}, icdf={icdf:?}, ftb={ftb} \
281 (icdf[{prev}]={val} == icdf[{s}]={lower})",
282 prev = s - 1,
283 );
284 self.rng = r * diff;
285 } else {
286 let val = icdf[s as usize] as u32;
287 let full = 1u32 << ftb;
288 debug_assert!(
289 val < full,
290 "encode_icdf: zero-probability symbol s=0, icdf={icdf:?}, ftb={ftb} \
291 (icdf[0]={val} == 2^ftb={full}, symbol has zero probability)"
292 );
293 self.rng -= r * val;
294 }
295 self.normalize_encoder();
296 }
297
298 pub fn decode_bit_logp(&mut self, logp: u32) -> bool {
299 let s = self.rng >> logp;
300 let ret = self.val < s as u64;
301 if !ret {
302 self.val -= s as u64;
303 self.rng -= s;
304 } else {
305 self.rng = s;
306 }
307 self.normalize_decoder();
308 ret
309 }
310
311 pub fn decode_icdf(&mut self, icdf: &[u8], ftb: u32) -> i32 {
314 let mut s = self.rng;
315 let d = self.val as u32;
316 let r = s >> ftb;
317 let mut ret = 0;
318 let mut t;
319
320 loop {
323 t = s;
324 s = r * (icdf[ret] as u32);
325 ret += 1;
326 if d >= s {
327 break;
328 }
329 }
330
331 self.val = (d - s) as u64;
332 self.rng = t - s;
333 self.normalize_decoder();
334 (ret - 1) as i32
335 }
336
337 pub fn decode(&mut self, ft: u32) -> u32 {
338 let r = self.rng / ft;
339 self.ext = r;
340 let s = (self.val / r as u64) as u32;
341 ft - ft.min(s + 1)
342 }
343
344 pub fn update(&mut self, fl: u32, fh: u32, ft: u32) {
345 let s = self.ext * (ft - fh);
346 self.val -= s as u64;
347 self.rng = if fl > 0 {
348 self.ext * (fh - fl)
349 } else {
350 self.rng - s
351 };
352
353 self.normalize_decoder();
354 }
355
356 pub fn laplace_encode(&mut self, value: &mut i32, fs: u32, decay: i32) {
357 let mut val = *value;
358 let mut fl = 0;
359 let mut fs_val = fs;
360
361 if val != 0 {
362 let s = if val < 0 { -1 } else { 0 };
363 val = (val + s) ^ s;
364 fl = fs_val;
365 fs_val = self.laplace_get_freq1(fs_val, decay);
366
367 let mut i = 1;
368 while fs_val > 0 && i < val {
369 fs_val *= 2;
370 fl += fs_val + 2;
371 fs_val = ((fs_val as i32 * decay) >> 15) as u32;
372 i += 1;
373 }
374
375 if fs_val == 0 {
376 let ndi_max = 32768 - fl + 1 - 1;
377 let ndi_max = (ndi_max as i32 - s) >> 1;
378 let di = (val - i).min(ndi_max - 1);
379 fl += (2 * di + 1 + s) as u32;
380 fs_val = 1u32.min(32768 - fl);
381 *value = (i + di + s) ^ s;
382 } else {
383 fs_val += 1;
384 fl += fs_val & (!s as u32);
385 }
386 }
387 self.encode(fl, fl + fs_val, 1 << 15);
388 }
389
390 fn laplace_get_freq1(&self, fs0: u32, decay: i32) -> u32 {
391 let ft = 32768 - (2 * 16) - fs0;
392 ((ft as i32 * (16384 - decay)) >> 15) as u32
393 }
394
395 pub fn laplace_decode(&mut self, fs: u32, decay: i32) -> i32 {
396 let fm = self.decode(1 << 15);
397 let mut fl = 0;
398 let mut fs_val = fs;
399 let mut val = 0;
400
401 if fm >= fs_val {
402 val += 1;
403 fl = fs_val;
404 fs_val = self.laplace_get_freq1(fs_val, decay) + 1;
405
406 while fs_val > 1 && fm >= fl + 2 * fs_val {
407 fs_val *= 2;
408 fl += fs_val;
409 fs_val = (((fs_val as i32 - 2) * decay) >> 15) as u32 + 1;
410 val += 1;
411 }
412
413 if fs_val <= 1 {
414 let di = (fm - fl) >> 1;
415 val += di as i32;
416 fl += 2 * di;
417 }
418
419 if fm < fl + fs_val {
420 val = -val;
421 } else {
422 fl += fs_val;
423 }
424 }
425
426 self.update(fl, fl + fs_val.min(32768 - fl), 1 << 15);
427 val
428 }
429
430 fn write_byte_at_end(&mut self, value: u8) {
431 if self.offs + self.end_offs < self.storage {
432 self.end_offs += 1;
433 let idx = (self.storage - self.end_offs) as usize;
434 self.buf[idx] = value;
435 } else {
436 self.error = 1;
437 }
438 }
439
440 pub fn patch_initial_bits(&mut self, val: u32, nbits: u32) {
441 let shift = EC_SYM_BITS - nbits;
442 let mask = ((1u32 << nbits) - 1) << shift;
443 if self.offs > 0 {
444 self.buf[0] = ((self.buf[0] as u32 & !mask) | (val << shift)) as u8;
445 } else if self.rem >= 0 {
446 self.rem = ((self.rem as u32 & !mask) | (val << shift)) as i32;
447 } else if self.rng <= (EC_CODE_TOP >> nbits) {
448 let mask64 = (mask as u64) << EC_CODE_SHIFT;
449 self.val = (self.val & !mask64) | ((val as u64) << (EC_CODE_SHIFT + shift));
450 } else {
451 self.error = -1;
452 }
453 }
454
455 pub fn done(&mut self) {
456 let ilog = 32 - self.rng.leading_zeros();
457 let mut l = (EC_CODE_BITS - ilog) as i32;
458 let mut msk = (EC_CODE_TOP as u64 - 1) >> l;
459 let mut end = (self.val + msk) & !msk;
460
461 if (end | msk) >= self.val + self.rng as u64 {
462 l += 1;
463 msk >>= 1;
464 end = (self.val + msk) & !msk;
465 }
466
467 while l > 0 {
468 self.carry_out((end >> EC_CODE_SHIFT) as i32);
469 end = (end << EC_SYM_BITS) & (EC_CODE_TOP as u64 - 1);
470 l -= EC_SYM_BITS as i32;
471 }
472
473 if self.rem >= 0 || self.ext > 0 {
474 self.carry_out(0);
475 }
476
477 let mut window = self.end_window;
478 let mut used = self.nend_bits;
479 while used >= EC_SYM_BITS as i32 {
480 self.write_byte_at_end((window & EC_SYM_MAX) as u8);
481 window >>= EC_SYM_BITS;
482 used -= EC_SYM_BITS as i32;
483 }
484
485 if self.error == 0 {
486 for i in self.offs..(self.storage - self.end_offs) {
487 self.buf[i as usize] = 0;
488 }
489
490 if used > 0 {
491 if self.end_offs >= self.storage {
492 self.error = -1;
493 } else {
494 let idx = (self.storage - self.end_offs - 1) as usize;
495 self.buf[idx] |= window as u8;
496
497 self.end_offs += 1;
498 }
499 }
500 }
501 }
502
503 pub fn finish(&mut self) -> Vec<u8> {
504 self.done();
505
506 let mut result = Vec::with_capacity((self.offs + self.end_offs) as usize);
507 result.extend_from_slice(&self.buf[0..self.offs as usize]);
508 result.extend_from_slice(
509 &self.buf[(self.storage - self.end_offs) as usize..self.storage as usize],
510 );
511 result
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Finalises `enc` and returns a copy of the bytes written at the front
    /// of its buffer (raw bits at the back are not included).
    fn finalized_front(enc: &mut RangeCoder) -> Vec<u8> {
        enc.done();
        enc.buf[..enc.offs as usize].to_vec()
    }

    /// A single Laplace-coded value produces the expected byte and decodes back.
    #[test]
    fn test_laplace() {
        let fs = 100 << 7;
        let decay = 120 << 6;
        let mut value = -3;

        let mut enc = RangeCoder::new_encoder(100);
        enc.laplace_encode(&mut value, fs, decay);
        enc.done();

        assert_eq!(enc.offs, 1);
        assert_eq!(enc.buf[0], 224);

        let written = enc.offs as usize;
        let mut dec = RangeCoder::new_decoder(&enc.buf[..written]);
        assert_eq!(dec.laplace_decode(fs, decay), -3);
    }

    /// Every symbol of a three-entry table survives an encode/decode round trip.
    #[test]
    fn test_icdf_consistency() {
        let icdf = [2, 1, 0];

        let mut enc = RangeCoder::new_encoder(1024);
        for sym in 0..3 {
            enc.encode_icdf(sym, &icdf, 2);
        }
        let data = finalized_front(&mut enc);

        let mut dec = RangeCoder::new_decoder(&data);
        let decoded: Vec<i32> = (0..3).map(|_| dec.decode_icdf(&icdf, 2)).collect();

        assert_eq!(decoded[0], 0);
        assert_eq!(decoded[1], 1);
        assert_eq!(decoded[2], 2);
    }

    /// Each symbol, including the last one, round-trips on its own.
    #[test]
    fn test_icdf_last_symbol_no_oob() {
        let icdf: &[u8] = &[170, 85, 0];
        let ftb = 8u32;

        for sym in 0..3i32 {
            let mut enc = RangeCoder::new_encoder(256);
            enc.encode_icdf(sym, icdf, ftb);
            let data = finalized_front(&mut enc);

            let mut dec = RangeCoder::new_decoder(&data);
            let decoded = dec.decode_icdf(icdf, ftb);
            assert_eq!(decoded, sym, "往返失败: 编码 symbol={sym} 解码得 {decoded}");
        }
    }

    /// A sequence covering every symbol of a four-entry table decodes in order.
    #[test]
    fn test_icdf_decode_terminates() {
        let icdf: &[u8] = &[192, 128, 64, 0];
        let ftb = 8u32;
        let symbols = [0i32, 1, 2, 3];

        let mut enc = RangeCoder::new_encoder(256);
        symbols.iter().for_each(|&s| enc.encode_icdf(s, icdf, ftb));
        let data = finalized_front(&mut enc);

        let mut dec = RangeCoder::new_decoder(&data);
        for &expected in &symbols {
            let got = dec.decode_icdf(icdf, ftb);
            assert_eq!(got, expected, "解码器输出 {got},期望 {expected}");
        }
    }

    /// Raw bits alone round-trip through `finish` and a fresh decoder.
    #[test]
    fn test_bits_only() {
        // (value, bit width) pairs, written and read back in this order.
        let fields: [(u32, u32); 4] = [(1, 1), (5, 3), (7, 3), (0, 2)];

        let mut enc = RangeCoder::new_encoder(1024);
        for &(value, width) in &fields {
            enc.enc_bits(value, width);
        }

        let data = enc.finish();
        let mut dec = RangeCoder::new_decoder(&data);

        let read: Vec<u32> = fields
            .iter()
            .map(|&(_, width)| dec.dec_bits(width))
            .collect();

        for (got, &(want, _)) in read.iter().zip(fields.iter()) {
            assert_eq!(*got, want);
        }
    }

    /// Raw bits and range-coded symbols can be interleaved freely.
    #[test]
    fn test_interleaved_bits_entropy() {
        let mut enc = RangeCoder::new_encoder(1024);
        enc.enc_bits(1, 1);
        enc.encode(10, 20, 100);
        enc.enc_bits(5, 3);
        enc.encode(50, 60, 100);
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);

        let b1 = dec.dec_bits(1);
        let d1 = dec.decode(100);
        dec.update(10, 20, 100);

        let b2 = dec.dec_bits(3);
        let d2 = dec.decode(100);
        dec.update(50, 60, 100);

        assert_eq!(b1, 1);
        assert!((10..20).contains(&d1), "d1={} expected in [10, 20)", d1);
        assert_eq!(b2, 5);
        assert!((50..60).contains(&d2), "d2={} expected in [50, 60)", d2);
    }
}