Skip to main content

dsi_bitstream/utils/
stats.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Inria
3 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
4 * SPDX-FileCopyrightText: 2024 Tommaso Fontana
5 *
6 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
7 */
8#[cfg(feature = "mem_dbg")]
9use mem_dbg::{MemDbg, MemSize};
10
11use crate::prelude::{
12    Codes, bit_len_vbyte, len_delta, len_exp_golomb, len_gamma, len_golomb, len_omega, len_pi,
13    len_rice, len_zeta,
14};
15use core::fmt::Debug;
16
17#[cfg(feature = "std")]
18use crate::dispatch::{CodesRead, CodesWrite};
19#[cfg(feature = "std")]
20use crate::prelude::Endianness;
21#[cfg(feature = "std")]
22use crate::prelude::{DynamicCodeRead, DynamicCodeWrite, StaticCodeRead, StaticCodeWrite};
23#[cfg(feature = "std")]
24use anyhow::Result;
25#[cfg(feature = "std")]
26use std::sync::Mutex;
27
28#[cfg(feature = "alloc")]
29use alloc::vec;
30#[cfg(feature = "alloc")]
31use alloc::vec::Vec;
32
/// Keeps track of the space needed to store a stream of integers using
/// different codes.
///
/// This structure can be used to determine empirically which code provides the
/// best compression for a given stream. You have to [update the
/// structure](Self::update) with the integers in the stream; at any time, you
/// can examine the statistics or call [`best_code`](Self::best_code) to get the
/// best code.
///
/// The structure keeps track of the codes for which the module
/// [`code_consts`](crate::dispatch::code_consts) provides constants.
///
/// Every field except [`total`](Self::total) is a running sum, in bits, of the
/// lengths of the codewords of the observed values under one code. For the
/// parametric codes each array slot corresponds to one parameter value; the
/// mapping from index to parameter is documented on each field.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
pub struct CodesStats<
    // How many ζ codes to consider.
    const ZETA: usize = 10,
    // How many Golomb codes to consider.
    const GOLOMB: usize = 10,
    // How many Exponential Golomb codes to consider.
    const EXP_GOLOMB: usize = 10,
    // How many Rice codes to consider.
    const RICE: usize = 10,
    // How many Pi and Pi web codes to consider.
    const PI: usize = 10,
> {
    /// The total number of elements observed.
    pub total: u64,
    /// The total space used to store the elements if
    /// they were stored using the unary code.
    pub unary: u64,
    /// The total space used to store the elements if
    /// they were stored using the gamma code.
    pub gamma: u64,
    /// The total space used to store the elements if
    /// they were stored using the delta code.
    pub delta: u64,
    /// The total space used to store the elements if
    /// they were stored using the omega code.
    pub omega: u64,
    /// The total space used to store the elements if
    /// they were stored using the variable byte code.
    ///
    /// The same total is reported for both the big-endian and the
    /// little-endian variant of the code.
    pub vbyte: u64,
    /// The total space used to store the elements if
    /// they were stored using the zeta code.
    ///
    /// Slot `i` refers to the code with parameter `k = i + 1`.
    pub zeta: [u64; ZETA],
    /// The total space used to store the elements if
    /// they were stored using the Golomb code.
    ///
    /// Slot `i` refers to the code with modulus `b = i + 1`.
    pub golomb: [u64; GOLOMB],
    /// The total space used to store the elements if
    /// they were stored using the exponential Golomb code.
    ///
    /// Slot `i` refers to the code with parameter `k = i`.
    pub exp_golomb: [u64; EXP_GOLOMB],
    /// The total space used to store the elements if
    /// they were stored using the Rice code.
    ///
    /// Slot `i` refers to the code with parameter `log2_b = i`.
    pub rice: [u64; RICE],
    /// The total space used to store the elements if
    /// they were stored using the Pi code.
    ///
    /// Slot `i` refers to the code with parameter `k = i + 2`; smaller
    /// parameters are skipped because π0 = gamma and π1 = zeta_2.
    pub pi: [u64; PI],
}
91
92impl<
93    const ZETA: usize,
94    const GOLOMB: usize,
95    const EXP_GOLOMB: usize,
96    const RICE: usize,
97    const PI: usize,
98> core::default::Default for CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
99{
100    fn default() -> Self {
101        Self {
102            total: 0,
103            unary: 0,
104            gamma: 0,
105            delta: 0,
106            omega: 0,
107            vbyte: 0,
108            zeta: [0; ZETA],
109            golomb: [0; GOLOMB],
110            exp_golomb: [0; EXP_GOLOMB],
111            rice: [0; RICE],
112            pi: [0; PI],
113        }
114    }
115}
116
117impl<
118    const ZETA: usize,
119    const GOLOMB: usize,
120    const EXP_GOLOMB: usize,
121    const RICE: usize,
122    const PI: usize,
123> CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
124{
125    /// Update the stats with the lengths of the codes for `n` and return
126    /// `n` for convenience.
127    pub fn update(&mut self, n: u64) -> u64 {
128        self.update_many(n, 1)
129    }
130
131    #[inline]
132    pub fn update_many(&mut self, n: u64, count: u64) -> u64 {
133        self.total += count;
134        self.unary += (n + 1) * count;
135        self.gamma += len_gamma(n) as u64 * count;
136        self.delta += len_delta(n) as u64 * count;
137        self.omega += len_omega(n) as u64 * count;
138        self.vbyte += bit_len_vbyte(n) as u64 * count;
139
140        for (k, val) in self.zeta.iter_mut().enumerate() {
141            *val += (len_zeta(n, (k + 1) as _) as u64) * count;
142        }
143        for (b, val) in self.golomb.iter_mut().enumerate() {
144            *val += (len_golomb(n, (b + 1) as _) as u64) * count;
145        }
146        for (k, val) in self.exp_golomb.iter_mut().enumerate() {
147            *val += (len_exp_golomb(n, k as _) as u64) * count;
148        }
149        for (log2_b, val) in self.rice.iter_mut().enumerate() {
150            *val += (len_rice(n, log2_b as _) as u64) * count;
151        }
152        // +2 because π0 = gamma and π1 = zeta_2
153        for (k, val) in self.pi.iter_mut().enumerate() {
154            *val += (len_pi(n, (k + 2) as _) as u64) * count;
155        }
156        n
157    }
158
159    // Combines additively this stats with another one.
160    pub fn add(&mut self, rhs: &Self) {
161        self.total += rhs.total;
162        self.unary += rhs.unary;
163        self.gamma += rhs.gamma;
164        self.delta += rhs.delta;
165        self.omega += rhs.omega;
166        self.vbyte += rhs.vbyte;
167        for (a, b) in self.zeta.iter_mut().zip(rhs.zeta.iter()) {
168            *a += *b;
169        }
170        for (a, b) in self.golomb.iter_mut().zip(rhs.golomb.iter()) {
171            *a += *b;
172        }
173        for (a, b) in self.exp_golomb.iter_mut().zip(rhs.exp_golomb.iter()) {
174            *a += *b;
175        }
176        for (a, b) in self.rice.iter_mut().zip(rhs.rice.iter()) {
177            *a += *b;
178        }
179        for (a, b) in self.pi.iter_mut().zip(rhs.pi.iter()) {
180            *a += *b;
181        }
182    }
183
184    /// Returns the best code for the stream and its space usage.
185    pub fn best_code(&self) -> (Codes, u64) {
186        let mut best = (Codes::Unary, self.unary);
187        if self.gamma < best.1 {
188            best = (Codes::Gamma, self.gamma);
189        }
190        if self.delta < best.1 {
191            best = (Codes::Delta, self.delta);
192        }
193        if self.omega < best.1 {
194            best = (Codes::Omega, self.omega);
195        }
196        if self.vbyte < best.1 {
197            best = (Codes::VByteBe, self.vbyte);
198        }
199        for (k, val) in self.zeta.iter().enumerate() {
200            if *val < best.1 {
201                best = (Codes::Zeta((k + 1) as _), *val);
202            }
203        }
204        for (b, val) in self.golomb.iter().enumerate() {
205            if *val < best.1 {
206                best = (Codes::Golomb((b + 1) as _), *val);
207            }
208        }
209        for (k, val) in self.exp_golomb.iter().enumerate() {
210            if *val < best.1 {
211                best = (Codes::ExpGolomb(k as _), *val);
212            }
213        }
214        for (log2_b, val) in self.rice.iter().enumerate() {
215            if *val < best.1 {
216                best = (Codes::Rice(log2_b as _), *val);
217            }
218        }
219        for (k, val) in self.pi.iter().enumerate() {
220            if *val < best.1 {
221                best = (Codes::Pi((k + 2) as _), *val);
222            }
223        }
224        best
225    }
226
227    /// Returns a vector of all codes and their space usage, in ascending order by space usage.
228    #[cfg(feature = "alloc")]
229    pub fn get_codes(&self) -> Vec<(Codes, u64)> {
230        let mut codes = vec![
231            (Codes::Unary, self.unary),
232            (Codes::Gamma, self.gamma),
233            (Codes::Delta, self.delta),
234            (Codes::Omega, self.omega),
235            (Codes::VByteBe, self.vbyte),
236        ];
237        for (k, val) in self.zeta.iter().enumerate() {
238            codes.push((Codes::Zeta((k + 1) as _), *val));
239        }
240        for (b, val) in self.golomb.iter().enumerate() {
241            codes.push((Codes::Golomb((b + 1) as _), *val));
242        }
243        for (k, val) in self.exp_golomb.iter().enumerate() {
244            codes.push((Codes::ExpGolomb(k as _), *val));
245        }
246        for (log2_b, val) in self.rice.iter().enumerate() {
247            codes.push((Codes::Rice(log2_b as _), *val));
248        }
249        for (k, val) in self.pi.iter().enumerate() {
250            codes.push((Codes::Pi((k + 2) as _), *val));
251        }
252        // sort them by length
253        codes.sort_by_key(|&(_, len)| len);
254        codes
255    }
256
257    /// Returns the number of bits used by the given code.
258    pub fn bits_for(&self, code: Codes) -> Option<u64> {
259        match code {
260            Codes::Unary => Some(self.unary),
261            Codes::Gamma => Some(self.gamma),
262            Codes::Delta => Some(self.delta),
263            Codes::Omega => Some(self.omega),
264            Codes::VByteBe | Codes::VByteLe => Some(self.vbyte),
265            Codes::Zeta(k) => self.zeta.get(k.checked_sub(1)?).copied(),
266            Codes::Golomb(b) => self.golomb.get(b.checked_sub(1)? as usize).copied(),
267            Codes::ExpGolomb(k) => self.exp_golomb.get(k).copied(),
268            Codes::Rice(log2_b) => self.rice.get(log2_b).copied(),
269            Codes::Pi(k) => self.pi.get(k.checked_sub(2)?).copied(),
270        }
271    }
272}
273
274/// Combines additively this stats with another one.
275impl<
276    const ZETA: usize,
277    const GOLOMB: usize,
278    const EXP_GOLOMB: usize,
279    const RICE: usize,
280    const PI: usize,
281> core::ops::AddAssign for CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
282{
283    fn add_assign(&mut self, rhs: Self) {
284        self.add(&rhs);
285    }
286}
287
288/// Combines additively this stats with another one creating a new one.
289impl<
290    const ZETA: usize,
291    const GOLOMB: usize,
292    const EXP_GOLOMB: usize,
293    const RICE: usize,
294    const PI: usize,
295> core::ops::Add for CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
296{
297    type Output = Self;
298
299    fn add(self, rhs: Self) -> Self {
300        let mut res = self;
301        res += rhs;
302        res
303    }
304}
305
306/// Allow to call .sum() on an iterator of CodesStats.
307impl<
308    const ZETA: usize,
309    const GOLOMB: usize,
310    const EXP_GOLOMB: usize,
311    const RICE: usize,
312    const PI: usize,
313> core::iter::Sum for CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
314{
315    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
316        iter.fold(Self::default(), |a, b| a + b)
317    }
318}
319
#[cfg(feature = "serde")]
impl<
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
> serde::Serialize for CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Serializes the statistics as a struct named `CodesStats` with one
    /// entry per field, in declaration order.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        let mut out = serializer.serialize_struct("CodesStats", 11)?;

        // Scalar counters.
        out.serialize_field("total", &self.total)?;
        out.serialize_field("unary", &self.unary)?;
        out.serialize_field("gamma", &self.gamma)?;
        out.serialize_field("delta", &self.delta)?;
        out.serialize_field("omega", &self.omega)?;
        out.serialize_field("vbyte", &self.vbyte)?;

        // Per-parameter totals: arrays whose length is a const parameter do
        // not play well with serde, so they are serialized as slices.
        out.serialize_field("zeta", &self.zeta[..])?;
        out.serialize_field("golomb", &self.golomb[..])?;
        out.serialize_field("exp_golomb", &self.exp_golomb[..])?;
        out.serialize_field("rice", &self.rice[..])?;
        out.serialize_field("pi", &self.pi[..])?;

        out.end()
    }
}
348
#[cfg(feature = "serde")]
impl<
    'de,
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
> serde::Deserialize<'de> for CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Deserializes the statistics from a map keyed by field name.
    ///
    /// Fields missing from the input are set to zero, unknown fields are
    /// skipped, and if a field appears more than once the last occurrence
    /// wins.
    ///
    /// # Errors
    ///
    /// Fails if the input is not a map, if a value has the wrong type, or if
    /// one of the per-parameter sequences does not contain exactly as many
    /// elements as the corresponding const parameter.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{MapAccess, Visitor};

        // Names of the serialized fields, in serialization order.
        const FIELDS: &[&str] = &[
            "total",
            "unary",
            "gamma",
            "delta",
            "omega",
            "vbyte",
            "zeta",
            "golomb",
            "exp_golomb",
            "rice",
            "pi",
        ];

        // Identifier of a serialized field. Keys are decoded into this enum
        // rather than into `&str`: a `&str` key can only be produced by
        // deserializers able to lend a slice of their input, which fails, for
        // example, when reading from an `io::Read` or when a key contains
        // escape sequences.
        enum Field {
            Total,
            Unary,
            Gamma,
            Delta,
            Omega,
            VByte,
            Zeta,
            Golomb,
            ExpGolomb,
            Rice,
            Pi,
            // Any name that is not a field of `CodesStats`.
            Ignored,
        }

        impl Field {
            // Maps a field name, given as bytes, to its identifier.
            fn from_name(name: &[u8]) -> Self {
                match name {
                    b"total" => Field::Total,
                    b"unary" => Field::Unary,
                    b"gamma" => Field::Gamma,
                    b"delta" => Field::Delta,
                    b"omega" => Field::Omega,
                    b"vbyte" => Field::VByte,
                    b"zeta" => Field::Zeta,
                    b"golomb" => Field::Golomb,
                    b"exp_golomb" => Field::ExpGolomb,
                    b"rice" => Field::Rice,
                    b"pi" => Field::Pi,
                    _ => Field::Ignored,
                }
            }
        }

        struct FieldVisitor;

        impl<'k> Visitor<'k> for FieldVisitor {
            type Value = Field;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_str("a field name of struct CodesStats")
            }

            // Also reached for borrowed and owned strings, whose default
            // visitor methods forward here.
            fn visit_str<E: serde::de::Error>(self, name: &str) -> Result<Field, E> {
                Ok(Field::from_name(name.as_bytes()))
            }

            fn visit_bytes<E: serde::de::Error>(self, name: &[u8]) -> Result<Field, E> {
                Ok(Field::from_name(name))
            }
        }

        impl<'k> serde::Deserialize<'k> for Field {
            fn deserialize<K: serde::Deserializer<'k>>(deserializer: K) -> Result<Self, K::Error> {
                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct CodesStatsVisitor<
            const ZETA: usize,
            const GOLOMB: usize,
            const EXP_GOLOMB: usize,
            const RICE: usize,
            const PI: usize,
        >;

        impl<
            'de,
            const ZETA: usize,
            const GOLOMB: usize,
            const EXP_GOLOMB: usize,
            const RICE: usize,
            const PI: usize,
        > Visitor<'de> for CodesStatsVisitor<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
        {
            type Value = CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_str("struct CodesStats")
            }

            fn visit_map<V: MapAccess<'de>>(self, mut map: V) -> Result<Self::Value, V::Error> {
                let mut total = None;
                let mut unary = None;
                let mut gamma = None;
                let mut delta = None;
                let mut omega = None;
                let mut vbyte = None;
                let mut zeta: Option<[u64; ZETA]> = None;
                let mut golomb: Option<[u64; GOLOMB]> = None;
                let mut exp_golomb: Option<[u64; EXP_GOLOMB]> = None;
                let mut rice: Option<[u64; RICE]> = None;
                let mut pi: Option<[u64; PI]> = None;

                // Helper to deserialize a Vec<u64> into a fixed-size array;
                // a length mismatch is reported as an `invalid_length` error.
                fn vec_to_array<E: serde::de::Error, const N: usize>(
                    v: Vec<u64>,
                ) -> Result<[u64; N], E> {
                    v.try_into().map_err(|v: Vec<u64>| {
                        serde::de::Error::invalid_length(v.len(), &N.to_string().as_str())
                    })
                }

                while let Some(field) = map.next_key::<Field>()? {
                    match field {
                        Field::Total => total = Some(map.next_value()?),
                        Field::Unary => unary = Some(map.next_value()?),
                        Field::Gamma => gamma = Some(map.next_value()?),
                        Field::Delta => delta = Some(map.next_value()?),
                        Field::Omega => omega = Some(map.next_value()?),
                        Field::VByte => vbyte = Some(map.next_value()?),
                        Field::Zeta => zeta = Some(vec_to_array(map.next_value()?)?),
                        Field::Golomb => golomb = Some(vec_to_array(map.next_value()?)?),
                        Field::ExpGolomb => exp_golomb = Some(vec_to_array(map.next_value()?)?),
                        Field::Rice => rice = Some(vec_to_array(map.next_value()?)?),
                        Field::Pi => pi = Some(vec_to_array(map.next_value()?)?),
                        Field::Ignored => {
                            // The value must still be consumed to keep the
                            // map access in sync.
                            let _ = map.next_value::<serde::de::IgnoredAny>()?;
                        }
                    }
                }

                Ok(CodesStats {
                    total: total.unwrap_or_default(),
                    unary: unary.unwrap_or_default(),
                    gamma: gamma.unwrap_or_default(),
                    delta: delta.unwrap_or_default(),
                    omega: omega.unwrap_or_default(),
                    vbyte: vbyte.unwrap_or_default(),
                    zeta: zeta.unwrap_or([0; ZETA]),
                    golomb: golomb.unwrap_or([0; GOLOMB]),
                    exp_golomb: exp_golomb.unwrap_or([0; EXP_GOLOMB]),
                    rice: rice.unwrap_or([0; RICE]),
                    pi: pi.unwrap_or([0; PI]),
                })
            }
        }

        deserializer.deserialize_struct(
            "CodesStats",
            FIELDS,
            CodesStatsVisitor::<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>,
        )
    }
}
461
#[derive(Debug)]
#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
#[cfg(feature = "std")]
/// A struct that can wrap `Code` and compute `CodesStats` for a given stream.
///
/// The wrapper forwards every read or write to the wrapped value and, when
/// the operation succeeds, updates its internal [`CodesStats`] with the value
/// that was read or written. The statistics can be inspected through
/// [`stats`](Self::stats) or recovered with [`into_inner`](Self::into_inner).
pub struct CodesStatsWrapper<
    W,
    // How many ζ codes to consider.
    const ZETA: usize = 10,
    // How many Golomb codes to consider.
    // NOTE(review): this default (20) differs from the default number of
    // Golomb codes of `CodesStats` (10) — confirm that this is intentional.
    const GOLOMB: usize = 20,
    // How many Exponential Golomb codes to consider.
    const EXP_GOLOMB: usize = 10,
    // How many Rice codes to consider.
    const RICE: usize = 10,
    // How many Pi and Pi web codes to consider.
    const PI: usize = 10,
> {
    // TODO!: figure out how we can do this without a lock.
    // This is needed because the [`DynamicCodeRead`] and [`DynamicCodeWrite`] traits must have
    // &self and not &mut self.
    stats: Mutex<CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>>,
    // The value every read and write is forwarded to.
    wrapped: W,
}
485
#[cfg(feature = "std")]
impl<
    W,
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
> CodesStatsWrapper<W, ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Wraps `wrapped`, starting from empty statistics.
    pub fn new(wrapped: W) -> Self {
        let stats = Mutex::new(CodesStats::default());
        Self { stats, wrapped }
    }

    /// Gives access to the mutex guarding the statistics gathered so far.
    pub fn stats(&self) -> &Mutex<CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>> {
        &self.stats
    }

    /// Dismantles the wrapper, handing back the wrapped value together with
    /// the statistics gathered so far.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the statistics is poisoned.
    pub fn into_inner(self) -> (W, CodesStats<ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>) {
        let Self { stats, wrapped } = self;
        let stats = stats.into_inner().unwrap();
        (wrapped, stats)
    }
}
514
#[cfg(feature = "std")]
impl<
    W: DynamicCodeRead,
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
> DynamicCodeRead for CodesStatsWrapper<W, ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Reads a value through the wrapped code and records it in the
    /// statistics; a failed read leaves the statistics untouched.
    #[inline]
    fn read<E: Endianness, CR: CodesRead<E> + ?Sized>(
        &self,
        reader: &mut CR,
    ) -> Result<u64, CR::Error> {
        let value = self.wrapped.read(reader)?;
        let mut stats = self.stats.lock().unwrap();
        stats.update(value);
        Ok(value)
    }
}
535
#[cfg(feature = "std")]
impl<
    W: StaticCodeRead<E, CR>,
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
    E: Endianness,
    CR: CodesRead<E> + ?Sized,
> StaticCodeRead<E, CR> for CodesStatsWrapper<W, ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Reads a value through the wrapped code and records it in the
    /// statistics; a failed read leaves the statistics untouched.
    #[inline]
    fn read(&self, reader: &mut CR) -> Result<u64, CR::Error> {
        let value = self.wrapped.read(reader)?;
        let mut stats = self.stats.lock().unwrap();
        stats.update(value);
        Ok(value)
    }
}
555
#[cfg(feature = "std")]
impl<
    W: DynamicCodeWrite,
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
> DynamicCodeWrite for CodesStatsWrapper<W, ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Writes `value` through the wrapped code and records it in the
    /// statistics; a failed write leaves the statistics untouched.
    ///
    /// Returns whatever the wrapped code returned for the write.
    #[inline]
    fn write<E: Endianness, CW: CodesWrite<E> + ?Sized>(
        &self,
        writer: &mut CW,
        value: u64,
    ) -> Result<usize, CW::Error> {
        let written = self.wrapped.write(writer, value)?;
        let mut stats = self.stats.lock().unwrap();
        stats.update(value);
        Ok(written)
    }
}
577
#[cfg(feature = "std")]
impl<
    W: StaticCodeWrite<E, CW>,
    const ZETA: usize,
    const GOLOMB: usize,
    const EXP_GOLOMB: usize,
    const RICE: usize,
    const PI: usize,
    E: Endianness,
    CW: CodesWrite<E> + ?Sized,
> StaticCodeWrite<E, CW> for CodesStatsWrapper<W, ZETA, GOLOMB, EXP_GOLOMB, RICE, PI>
{
    /// Writes `value` through the wrapped code and records it in the
    /// statistics; a failed write leaves the statistics untouched.
    ///
    /// Returns whatever the wrapped code returned for the write.
    #[inline]
    fn write(&self, writer: &mut CW, value: u64) -> Result<usize, CW::Error> {
        let written = self.wrapped.write(writer, value)?;
        let mut stats = self.stats.lock().unwrap();
        stats.update(value);
        Ok(written)
    }
}
597
#[cfg(test)]
#[cfg(feature = "serde")]
mod serde_tests {
    use super::*;

    /// Statistics with the default sizes survive a JSON round trip.
    #[test]
    fn test_serde_code_stats() {
        let mut stats: CodesStats = CodesStats::default();
        for i in 0..100 {
            stats.update(i);
        }
        let json = serde_json::to_string(&stats).unwrap();
        let deserialized: CodesStats = serde_json::from_str(&json).unwrap();
        assert_eq!(stats, deserialized);
    }

    /// Statistics with non-default sizes survive a JSON round trip.
    #[test]
    fn test_roundtrip_different_sizes() {
        let mut stats: CodesStats<10, 20, 5, 8, 6> = CodesStats::default();
        for i in 0..1000 {
            stats.update(i);
        }
        let json = serde_json::to_string_pretty(&stats).unwrap();
        let deserialized: CodesStats<10, 20, 5, 8, 6> = serde_json::from_str(&json).unwrap();
        assert_eq!(stats, deserialized);
    }

    /// Deserializing into a type whose array sizes do not match the data
    /// is reported as an error.
    ///
    /// The error is checked explicitly rather than through `#[should_panic]`,
    /// so that a panic while building or serializing the statistics cannot
    /// make the test pass by accident.
    #[test]
    fn test_mismatched_sizes() {
        let mut stats: CodesStats<10, 20, 5, 8, 6> = CodesStats::default();
        for i in 0..1000 {
            stats.update(i);
        }
        let json = serde_json::to_string_pretty(&stats).unwrap();
        // The JSON has 20 golomb values but the target type expects 21
        assert!(
            serde_json::from_str::<CodesStats<10, 21, 5, 8, 6>>(&json).is_err(),
            "deserializing 20 golomb values into 21 slots must fail"
        );
    }
}