divans/probability/
interface.rs

1use core;
2use super::numeric;
/// Integer type used for every probability / CDF value in this module.
pub type Prob = i16; // can be i32

/// Upper bound (0xa00 == 2560) exported for frequentist CDF implementations.
// NOTE(review): the exact meaning of this cap is defined by its users elsewhere — confirm there.
pub const MAX_FREQUENTIST_PROB: Prob = 0xa00;
6#[cfg(feature="billing")]
7use std::io::Write;
// Diagnostic logging hook. Both configurations currently expand to an empty
// block: the stderr write is commented out, so every invocation is a no-op
// and its arguments are discarded without being evaluated.
#[cfg(feature="billing")]
macro_rules! println_stderr {
    ($($val:tt)*) => {{
        // writeln!(&mut ::std::io::stderr(), $($val)*).unwrap();
    }};
}

#[cfg(not(feature="billing"))]
macro_rules! println_stderr {
    ($($val:tt)*) => {{
        // writeln!(&mut ::std::io::stderr(), $($val)*).unwrap();
    }};
}
21
22#[derive(Copy,Clone,PartialEq,Eq,Debug)]
23pub struct ProbRange {
24    pub start: Prob,
25    pub freq: Prob,
26}
27
28#[derive(Copy,Clone,PartialEq,Eq,Debug)]
29pub struct SymStartFreq {
30    pub range: ProbRange,
31    pub sym: u8,
32}
33
/// Base-2 logarithm of `x`, delegating to the standard library.
#[cfg(not(feature="no-stdlib"))]
fn log2(x: f64) -> f64 {
    f64::log2(x)
}

/// Base-2 logarithm substitute for builds without the standard library.
///
/// Hack: truncates `x` to an integer and returns the index of its highest
/// set bit, so only the integer part of the logarithm is produced.
#[cfg(feature="no-stdlib")]
fn log2(x: f64) -> f64 {
    let highest_bit = 63 - (x as u64).leading_zeros(); // hack
    f64::from(highest_bit)
}
/// Fetches the division hint for `cdfmax` from the `numeric` module.
///
/// The returned pair is what `sym_to_start_and_freq_with_div_hint` passes to
/// `numeric::fast_divide_30bit_by_16bit`.
#[inline(always)]
#[cfg(feature="avoid-divide")]
pub fn lookup_divisor(cdfmax: i16) -> (i64, u8) {
    numeric::lookup_divisor(cdfmax)
}

/// Without the `avoid-divide` feature there is no hint to compute; the unit
/// return value keeps call sites identical across both configurations.
#[inline(always)]
#[cfg(not(feature="avoid-divide"))]
pub fn lookup_divisor(_cdfmax: i16) {}
52
53// Common interface for CDF2 and CDF16, with optional methods.
54pub trait BaseCDF {
55
56    // the cardinality of symbols supported. Typical implementation values are 2 and 16.
57    fn num_symbols() -> u8;
58
59    // the cumulative distribution function evaluated at the given symbol.
60    fn cdf(&self, symbol: u8) -> Prob;
61
62    // the probability distribution function evaluated at the given symbol.
63    fn pdf(&self, symbol: u8) -> Prob {
64        debug_assert!(symbol < Self::num_symbols());
65        if symbol == 0 {
66            self.cdf(symbol)
67        } else {
68            self.cdf(symbol) - self.cdf(symbol - 1)
69        }
70    }
71    fn div_by_max(&self, val: i32) -> i32;
72    // the maximum value relative to which cdf() and pdf() values should be normalized.
73    fn max(&self) -> Prob;
74
75    // the base-2 logarithm of max(), if available, to support bit-shifting.
76    fn log_max(&self) -> Option<i8>;
77
78    // returns true if used.
79    fn used(&self) -> bool { false }
80
81    // returns true if valid.
82    fn valid(&self) -> bool { false }
83
84    // returns the entropy of the current distribution.
85    fn entropy(&self) -> f64 {
86        let mut sum = 0.0f64;
87        for i in 0..Self::num_symbols() {
88            let v = self.pdf(i as u8);
89            sum += if v == 0 { 0.0f64 } else {
90                let v_f64 = f64::from(v) / f64::from(self.max());
91                v_f64 * (log2(-v_f64))
92            };
93        }
94        sum
95    }
96    #[inline(always)]
97    fn sym_to_start_and_freq(&self,
98                             sym: u8) -> SymStartFreq {
99        let cdf_sym = self.div_by_max((i32::from(self.cdf(sym)) << LOG2_SCALE));
100        let cdf_prev = if sym != 0 {self.div_by_max(i32::from(self.cdf(sym - 1)) << LOG2_SCALE)} else { 0 };
101        let freq = cdf_sym - cdf_prev;
102        SymStartFreq {
103            range: ProbRange {start: cdf_prev as Prob + 1, // major hax
104                              freq:  freq as Prob - 1, // don't want rounding errors to work out unfavorably
105            },
106            sym: sym,
107        }
108    }
109    #[inline(always)]
110    #[cfg(not(feature="avoid-divide"))]
111    fn sym_to_start_and_freq_with_div_hint(&self,
112                                           sym: u8,
113                                           inv_max_and_bitlen:()) -> SymStartFreq {
114        self.sym_to_start_and_freq(sym)
115    }
116    #[inline(always)]
117    #[cfg(feature="avoid-divide")]
118    fn sym_to_start_and_freq_with_div_hint(&self,
119                                           sym: u8,
120                                           inv_max_and_bitlen:(i64, u8)) -> SymStartFreq {
121        let cdf_sym = numeric::fast_divide_30bit_by_16bit((i32::from(self.cdf(sym)) << LOG2_SCALE), inv_max_and_bitlen);
122        let cdf_prev = if sym != 0 {numeric::fast_divide_30bit_by_16bit(i32::from(self.cdf(sym - 1)) << LOG2_SCALE, inv_max_and_bitlen)} else { 0 };
123        let freq = cdf_sym - cdf_prev;
124        SymStartFreq {
125            range: ProbRange {start: cdf_prev as Prob + 1, // major hax
126                              freq:  freq as Prob - 1, // don't want rounding errors to work out unfavorably
127            },
128            sym: sym,
129        }
130    }
131    #[inline(always)]
132    fn rescaled_cdf(&self, sym: u8) -> i32 {
133        i32::from(self.cdf(sym)) << LOG2_SCALE
134    }
135    #[inline(always)]
136    fn cdf_offset_to_sym_start_and_freq(&self,
137                                        cdf_offset_p: Prob) -> SymStartFreq {
138        let cdfmax = self.max();
139        let inv_max_and_bitlen = lookup_divisor(cdfmax);
140        let rescaled_cdf_offset = ((i32::from(cdf_offset_p) * i32::from(cdfmax)) >> LOG2_SCALE) as i16;
141        /* nice log(n) version which has too much dependent math, apparently, to be efficient
142        let candidate0 = 7u8;
143        let candidate1 = candidate0 - 4 + (((rescaled_cdf_offset >= self.cdf(candidate0)) as u8) << 3); // candidate1=3 or 11
144        let candidate2 = candidate1 - 2 + (((rescaled_cdf_offset >= self.cdf(candidate1)) as u8) << 2); // candidate2=1,5,9 or 13
145        let candidate3 = candidate2 - 1 + (((rescaled_cdf_offset >= self.cdf(candidate2)) as u8) << 1); // candidate3 or 12
146        let final_decision = (rescaled_cdf_offset >= self.cdf(candidate3)) as u8;
147        let sym = candidate3 + final_decision;
148        self.sym_to_start_and_freq(sym)
149         */
150        //        let cdf15 = self.cdf(15);
151        let sym: u8;
152        if rescaled_cdf_offset < self.cdf(0) {
153            return self.sym_to_start_and_freq(0);
154        }
155        if rescaled_cdf_offset < self.cdf(1) {
156            sym = 1;
157        } else if rescaled_cdf_offset < self.cdf(2) {
158            sym = 2;
159        } else if rescaled_cdf_offset < self.cdf(3) {
160            sym = 3;
161        } else if rescaled_cdf_offset < self.cdf(4) {
162            sym = 4;
163        } else if rescaled_cdf_offset < self.cdf(5) {
164            sym = 5;
165        } else if rescaled_cdf_offset < self.cdf(6) {
166            sym = 6;
167        } else if rescaled_cdf_offset < self.cdf(7) {
168            sym = 7;
169        } else if rescaled_cdf_offset < self.cdf(8) {
170            sym = 8;
171        } else if rescaled_cdf_offset < self.cdf(9) {
172            sym = 9;
173        } else if rescaled_cdf_offset < self.cdf(10) {
174            sym = 10;
175        } else if rescaled_cdf_offset < self.cdf(11) {
176            sym = 11;
177        } else if rescaled_cdf_offset < self.cdf(12) {
178            sym = 12;
179        } else if rescaled_cdf_offset < self.cdf(13) {
180            sym = 13;
181        } else if rescaled_cdf_offset < self.cdf(14) {
182            sym = 14;
183        } else {
184            sym = 15;
185        }
186        assert!(sym != 0);
187        assert!(sym <= 15);
188            
189        return self.sym_to_start_and_freq_with_div_hint(sym, inv_max_and_bitlen);
190        /* // this really should be the same speed as above
191        for i in 0..15 {
192            if rescaled_cdf_offset < self.cdf(i as u8) {
193                return self.sym_to_start_and_freq(i);
194            }
195        }
196        self.sym_to_start_and_freq(15)
197*/
198    }
199
200    // These methods are optional because implementing them requires nontrivial bookkeeping.
201    // Only CDFs that are intended for debugging should support them.
202    fn num_samples(&self) -> Option<u32> { None }
203    fn true_entropy(&self) -> Option<f64> { None }
204    fn rolling_entropy(&self) -> Option<f64> { None }
205    fn encoding_cost(&self) -> Option<f64> { None }
206    fn num_variants(&self) -> usize {
207        0
208    }
209    fn variant_cost(&self, variant_index: usize) -> f32 {
210        0.0
211    }
212    fn base_variant_cost(&self) -> f32 {
213        0.0
214    }
215}
216
/// Two-symbol (binary) distribution.
#[derive(Copy, Clone)]
pub struct CDF2 {
    /// Observation counters: index 0 for `false`, index 1 for `true`.
    counts: [u8; 2],
    /// Value returned as `cdf(0)`, on a scale where 256 is the maximum.
    pub prob: u8,
}
222
223impl Default for CDF2 {
224    fn default() -> Self {
225        CDF2 {
226            counts: [1, 1],
227            prob: 128,
228        }
229    }
230}
231
232impl BaseCDF for CDF2 {
233    fn cdf_offset_to_sym_start_and_freq(
234        &self,
235        cdf_offset: Prob) -> SymStartFreq {
236        let bit = ((i32::from(cdf_offset) * i32::from(self.max())) >> LOG2_SCALE) >= i32::from(self.prob);
237        let rescaled_prob = self.div_by_max(i32::from(self.prob) << LOG2_SCALE);
238        SymStartFreq {
239            sym: bit as u8,
240            range: ProbRange {start: if bit {rescaled_prob as Prob} else {0},
241                              freq: if bit {
242                                  ((1 << LOG2_SCALE) - rescaled_prob) as Prob
243                              } else {
244                                  rescaled_prob as Prob
245                              },
246            }
247        }
248    }
249    fn div_by_max(&self, val:i32) -> i32 {
250        return val / i32::from(self.max())
251    }
252    fn num_symbols() -> u8 { 2 }
253    fn cdf(&self, symbol: u8) -> Prob {
254        match symbol {
255            0 => Prob::from(self.prob),
256            1 => 256,
257            _ => { panic!("Symbol out of range"); }
258        }
259    }
260    fn used(&self) -> bool {
261        self.counts[0] != 1 || self.counts[1] != 1
262    }
263    fn max(&self) -> Prob {
264        256
265    }
266    fn log_max(&self) -> Option<i8> {
267        Some(8)
268    }
269}
270
271impl CDF2 {
272    pub fn blend(&mut self, symbol: bool, _speed: &Speed) {
273        let fcount = self.counts[0];
274        let tcount = self.counts[1];
275        debug_assert!(fcount != 0);
276        debug_assert!(tcount != 0);
277
278        let obs = if symbol {1} else {0};
279        let overflow = self.counts[obs] == 0xff;
280        self.counts[obs] = self.counts[obs].wrapping_add(1);
281        if overflow {
282            let not_obs = if symbol {0} else {1};
283            let neverseen = self.counts[not_obs] == 1;
284            if neverseen {
285                self.counts[obs] = 0xff;
286                self.prob = if symbol {0} else {0xff};
287            } else {
288                self.counts[0] = ((1 + u16::from(fcount)) >> 1) as u8;
289                self.counts[1] = ((1 + u16::from(tcount)) >> 1) as u8;
290                self.counts[obs] = 129;
291                self.prob = ((u16::from(self.counts[0]) << 8) / (u16::from(self.counts[0]) + u16::from(self.counts[1]))) as u8;
292            }
293        } else {
294            self.prob = ((u16::from(self.counts[0]) << 8) / (u16::from(fcount) + u16::from(tcount) + 1)) as u8;
295        }
296    }
297}
/// Adaptation speed: field 0 is the increment (`inc`), field 1 the limit (`lim`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Speed(i16, i16);

/// Number of entries in a `SpeedPalette`.
pub const SPEED_PALETTE_SIZE: usize = 15;

/// Fixed-size table of candidate speeds.
pub type SpeedPalette = [Speed; SPEED_PALETTE_SIZE];
302impl Speed {
303    pub const ENCODER_DEFAULT_PALETTE: SpeedPalette = [
304        Speed(0, 1024),
305        Speed(2, 1024),
306        Speed(1, 128),
307        Speed(1, 16384),
308        Speed(2, 2048),
309        Speed(4, 1024),
310        Speed(8, 8192),
311        Speed(16, 48),
312        Speed(16, 8192),// old mud
313        Speed(32, 4096),
314        Speed(64, 16384),
315        Speed(128, 256),
316        Speed(128, 16384),
317        Speed(512, 16384),
318        //Speed(1024, 16384),
319        Speed(1664, 16384),
320        ];
321    pub const GEOLOGIC: Speed = Speed(0x0001, 0x4000);
322    pub const GLACIAL: Speed = Speed(0x0004, 0x0a00);
323    pub const MUD: Speed =   Speed(0x0010, 0x2000);
324    pub const SLOW: Speed =  Speed(0x0020, 0x1000);
325    pub const MED: Speed =   Speed(0x0030, 0x4000);
326    pub const FAST: Speed =  Speed(0x0060, 0x4000);
327    pub const PLANE: Speed = Speed(0x0080, 0x4000);
328    pub const ROCKET: Speed =Speed(0x0180, 0x4000);
329    pub fn to_f8_tuple(&self) -> (u8, u8) {
330        (speed_to_u8(self.inc()), speed_to_u8(self.lim()))
331    }
332    pub fn from_f8_tuple(inp: (u8, u8)) -> Self {
333        Speed::new(u8_to_speed(inp.0), u8_to_speed(inp.1))
334    }
335    #[inline(never)]
336    #[cold]
337    pub fn cold_new(inc:i16, max: i16) -> Speed {
338        Self::new(inc, max)
339    }
340    #[inline(always)]
341    pub fn new(inc:i16, max: i16) -> Speed {
342        debug_assert!(inc <= 0x4000); // otherwise some sse hax fail
343        debug_assert!(max <= 0x4000); // otherwise some sse hax fail
344        Speed(inc, max)
345    }
346    #[inline(always)]
347    pub fn lim(&self) -> i16 {
348        let ret = self.1;
349        debug_assert!(ret <= 0x4000); // otherwise some sse hax fail
350        ret
351    }
352    #[inline(always)]
353    pub fn inc_and_gets(&mut self, ander: i16) -> Speed {
354        self.0 &= ander;
355        *self
356    }
357    #[inline(always)]
358    pub fn lim_or_gets(&mut self, orer: i16) {
359        self.1 |= orer;
360    }
361    #[inline(always)]
362    pub fn inc(&self) -> i16 {
363        self.0
364    }
365    #[inline(always)]
366    pub fn set_lim(&mut self, lim: i16) {
367        debug_assert!(lim <= 0x4000); // otherwise some sse hax fail
368        self.1 = lim;
369    }
370    #[inline(always)]
371    pub fn set_inc(&mut self, inc: i16) {
372        debug_assert!(inc <= 0x4000);
373        self.0 = inc;
374    }
375}
376impl core::str::FromStr for Speed {
377    type Err = core::num::ParseIntError;
378    fn from_str(inp:&str) -> Result<Speed, Self::Err> {
379        match inp {
380            "GEOLOGIC" => Ok(Speed::GEOLOGIC),
381            "GLACIAL" => Ok(Speed::GLACIAL),
382            "MUD" => Ok(Speed::MUD),
383            "SLOW" => Ok(Speed::SLOW),
384            "MED" => Ok(Speed::MED),
385            "FAST" => Ok(Speed::FAST),
386            "PLANE" => Ok(Speed::PLANE),
387            "ROCKET" => Ok(Speed::ROCKET),
388            _ => {
389               let mut split_location = 0;
390               for (index, item) in inp.chars().enumerate() {
391                  if item == ',' {
392                     split_location = index;
393                     break
394                  }
395               }
396               let first_num_str = inp.split_at(split_location).0;
397               let second_num_str = inp.split_at(split_location + 1).1;
398               let conv_inc = u16::from_str(first_num_str);
399               let conv_lim = u16::from_str(second_num_str);
400               match conv_inc {
401                  Err(e) => return Err(e),
402                  Ok(inc) => match conv_lim {
403                     Err(e) => return Err(e),
404                     Ok(lim) => {
405                         if lim <= 16384 && inc < 16384 {
406                         return Ok(Speed::new(inc as i16, lim as i16));
407                         } else {
408                             match "65537".parse::<u16>() {
409                                Err(e) => return Err(e),
410                                Ok(f) => unreachable!(),
411                             }
412                         }
413                     },
414                  },
415               }
416            },
417        }
418    }
419}
420
421pub trait CDF16: Sized + Default + Copy + BaseCDF {
422    fn blend(&mut self, symbol: u8, dyn:Speed);
423    fn average(&self, other: &Self, mix_rate: i32) ->Self;
424}
425
426pub const BLEND_FIXED_POINT_PRECISION : i8 = 15;
427pub const CDF_BITS : usize = 15; // 15 bits
428pub const LOG2_SCALE: u32 = CDF_BITS as u32;
429pub const CDF_MAX : Prob = 32_767; // last value is implicitly 32768
430const CDF_LIMIT : i64 = (CDF_MAX as i64) + 1;
431
432
433
434
435#[allow(unused)]
436fn gt(a:Prob, b:Prob) -> Prob {
437    (-((a > b) as i64)) as Prob
438}
439#[allow(unused)]
440fn gte_bool(a:Prob, b:Prob) -> Prob {
441    (a >= b) as Prob
442}
443
444
445
/// Wraps a `CDF16` and keeps extra bookkeeping for the optional debugging
/// queries of `BaseCDF`.
#[cfg(feature="debug_entropy")]
#[derive(Copy, Clone, Default)]
pub struct DebugWrapperCDF16<Cdf16: CDF16> {
    /// The wrapped distribution.
    pub cdf: Cdf16,
    /// How many times each of the 16 symbols has been blended in.
    pub counts: [u32; 16],
    /// Running total of `-log2(p)` over all blended symbols.
    cost: f64,
    /// Running total of `true_entropy()` sampled after each blend.
    rolling_entropy_sum: f64,
}
454
#[cfg(feature="debug_entropy")]
impl<Cdf16> CDF16 for DebugWrapperCDF16<Cdf16> where Cdf16: CDF16 {
    /// Updates the bookkeeping for `symbol`, then forwards the blend to the
    /// wrapped CDF.
    fn blend(&mut self, symbol: u8, speed: Speed) {
        self.counts[symbol as usize] += 1;
        // Cost is charged using the probability *before* the wrapped CDF adapts.
        let numerator = self.cdf.pdf(symbol) as f64;
        let denominator = self.cdf.max() as f64;
        let p = numerator / denominator;
        self.cost += -log2(p);
        if let Some(e) = self.true_entropy() {
            self.rolling_entropy_sum += e;
        }
        self.cdf.blend(symbol, speed);
    }

    /// Averages the wrapped CDFs and sums every bookkeeping field.
    fn average(&self, other: &Self, mix_rate: i32) -> Self {
        // NOTE(jongmin): The notion of averaging for a debug CDF is not well-formed
        // because its private fields depend on the blend history that's not preserved in averaging.
        let mut counts_both = [0u32; 16];
        for (slot, (mine, theirs)) in counts_both
            .iter_mut()
            .zip(self.counts.iter().zip(other.counts.iter()))
        {
            *slot = *mine + *theirs;
        }
        let mixed_cdf = self.cdf.average(&other.cdf, mix_rate);
        Self {
            cdf: mixed_cdf,
            counts: counts_both,
            cost: self.cost + other.cost,
            rolling_entropy_sum: self.rolling_entropy_sum + other.rolling_entropy_sum,
        }
    }
}
482
/// Forwards the distribution queries to the wrapped CDF and answers the
/// optional debugging queries from the wrapper's own bookkeeping.
#[cfg(feature="debug_entropy")]
impl<Cdf16> BaseCDF for DebugWrapperCDF16<Cdf16> where Cdf16: CDF16 + BaseCDF {
    fn num_symbols() -> u8 { 16 }
    fn cdf(&self, symbol: u8) -> Prob { self.cdf.cdf(symbol) }
    fn pdf(&self, symbol: u8) -> Prob { self.cdf.pdf(symbol) }
    fn max(&self) -> Prob { self.cdf.max() }
    fn log_max(&self) -> Option<i8> { self.cdf.log_max() }
    fn entropy(&self) -> f64 { self.cdf.entropy() }
    fn valid(&self) -> bool { self.cdf.valid() }
    fn div_by_max(&self, val: i32) -> i32 {self.cdf.div_by_max(val)}

    /// True once at least one symbol has been blended in.
    fn used(&self) -> bool {
        self.num_samples().unwrap() > 0
    }

    /// Total number of blended symbols; always `Some`.
    fn num_samples(&self) -> Option<u32> {
        let mut sum : u32 = 0;
        for i in 0..16 {
            sum += self.counts[i];
        }
        Some(sum)
    }

    /// Entropy (in bits) of the empirical symbol frequencies, or `None`
    /// before any symbol has been observed.
    fn true_entropy(&self) -> Option<f64> {
        let num_samples = self.num_samples().unwrap();
        if num_samples > 0 {
            let mut sum : f64 = 0.0;
            for i in 0..16 {
                sum += if self.counts[i] == 0 { 0.0f64 } else {
                    let p = (self.counts[i] as f64) / (num_samples as f64);
                    // The negation must apply to the logarithm, not its
                    // argument: log2 of a negative number is NaN.
                    p * -log2(p)
                };
            }
            Some(sum)
        } else {
            None
        }
    }

    /// Mean of the `true_entropy()` values accumulated by `blend`.
    fn rolling_entropy(&self) -> Option<f64> {
        match self.num_samples() {
            None => None,
            Some(n) => Some(self.rolling_entropy_sum / n as f64)
        }
    }

    /// Accumulated `-log2(p)` cost of every blended symbol.
    fn encoding_cost(&self) -> Option<f64> {
        Some(self.cost)
    }
}
529
#[cfg(feature="debug_entropy")]
impl<Cdf16> DebugWrapperCDF16<Cdf16> where Cdf16: CDF16 {
    /// Wraps `cdf` with all bookkeeping zeroed.
    fn new(cdf: Cdf16) -> Self {
        let no_observations = [0u32; 16];
        Self {
            cdf: cdf,
            counts: no_observations,
            cost: 0.0,
            rolling_entropy_sum: 0.0,
        }
    }
}
541
// Named distinctly from the other test module in this file: two modules both
// called `test` collide when `debug_entropy` is enabled in a test build.
#[cfg(test)]
#[cfg(feature="debug_entropy")]
mod debug_wrapper_test {
    use super::{BaseCDF, CDF16, Speed};
    use super::super::{DebugWrapperCDF16, FrequentistCDF16, };
    type DebugWrapperCDF16Impl = DebugWrapperCDF16<FrequentistCDF16>;
    declare_common_tests!(DebugWrapperCDF16Impl);

    /// The wrapper must count every blended sample and leave the wrapped CDF
    /// in the same state as an unwrapped CDF fed the same symbols.
    #[test]
    fn test_debug_info() {
        let mut wrapper_cdf = DebugWrapperCDF16::<FrequentistCDF16>::default();
        let mut reference_cdf = FrequentistCDF16::default();
        let num_samples = 1234usize;
        for i in 0..num_samples {
            wrapper_cdf.blend((i & 0xf) as u8, Speed::MED);
            reference_cdf.blend((i & 0xf) as u8, Speed::MED);
        }
        assert!(wrapper_cdf.num_samples().is_some());
        assert_eq!(wrapper_cdf.num_samples().unwrap(), num_samples as u32);

        use super::super::common_tests;
        common_tests::assert_cdf_eq(&reference_cdf, &wrapper_cdf.cdf);
    }
}
/// Packs a speed component into one byte using a tiny floating-point format.
///
/// The upper five bits hold the bit length of `data` (the position of its
/// highest set bit, plus one) and the lower three bits hold the three bits
/// that follow the leading one. Zero encodes as `0`. Bits below the
/// three-bit mantissa are dropped, so only values with at most four
/// significant bits round-trip exactly through `u8_to_speed`.
///
/// `data` is treated as its 16-bit pattern; speeds are expected to be
/// non-negative.
pub fn speed_to_u8(data: i16) -> u8 {
    let length = (16 - data.leading_zeros()) as u8;
    if length == 0 {
        return 0;
    }
    // Work in 32 bits: shifting the remainder left by three overflows `i16`
    // once the remainder reaches 4096, which corrupted the mantissa for
    // inputs such as 12288.
    let bits = u32::from(data as u16);
    let top = u32::from(length) - 1;
    let rem = bits - (1u32 << top);
    let mantissa = ((rem << 3) >> top) as u8;
    (length << 3) | mantissa
}
576
/// Expands a byte produced by `speed_to_u8` back into a speed component.
///
/// Codes below 8 decode to `0`. Otherwise the upper five bits minus one give
/// the position of the leading one bit, and the lower three bits are the
/// mantissa bits placed directly beneath it.
///
/// Codes of 128 and above describe magnitudes that do not fit in `i16`; the
/// result is then truncated to its low 16 bits instead of overflowing a shift.
pub fn u8_to_speed(data: u8) -> i16 {
    if data < 8 {
        return 0;
    }
    let log_val = u32::from(data >> 3) - 1;
    // Work in 64 bits: `mantissa << log_val` overflows `i16` for the larger
    // exponents, which corrupted values such as 12288 (code 116).
    let mantissa = u64::from(data & 0x7);
    let value = (1u64 << log_val) | ((mantissa << log_val) >> 3);
    value as i16
}
#[cfg(test)]
mod test {
    use super::speed_to_u8;
    use super::u8_to_speed;

    /// Asserts that `data` survives an encode/decode round trip unchanged.
    fn tst_u8_to_speed(data: i16) {
        let encoded = speed_to_u8(data);
        let decoded = u8_to_speed(encoded);
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_u8_to_speed() {
        // Every value here has at most four significant bits, so the
        // three-bit mantissa represents it exactly.
        let exact_values: &[i16] = &[
            0, 1, 2, 3, 4, 5, 6, 7, 8,
            10, 12, 16, 24, 32, 48, 64, 96,
            768, 1280, 1536, 1664,
        ];
        for &value in exact_values.iter() {
            tst_u8_to_speed(value);
        }
    }
}