// splinter_rs/splinter.rs

use bytes::{Bytes, BytesMut};
use culprit::Culprit;
use either::Either;
use std::{
    fmt::Debug,
    ops::{Bound, RangeBounds, RangeInclusive},
};
use zerocopy::{
    ConvertError, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned,
    little_endian::{U16, U32},
};

use crate::{
    DecodeErr, Segment, SplinterRead, SplinterWrite,
    bitmap::{BitmapExt, BitmapMutExt},
    block::{Block, BlockRef},
    partition::{Partition, PartitionRef},
    relational::Relation,
    util::{CopyToOwned, FromSuffix, SerializeContainer},
};

mod cmp;
mod cut;
mod intersection;
mod merge;
mod union;

pub const SPLINTER_MAGIC: [u8; 4] = [0xDA, 0xAE, 0x12, 0xDF];

pub const SPLINTER_MAX_VALUE: u32 = u32::MAX;

#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C)]
struct Header {
    magic: [u8; 4],
}

impl Header {
    const DEFAULT: Header = Header { magic: SPLINTER_MAGIC };

    fn serialize<B: bytes::BufMut>(&self, out: &mut B) -> usize {
        out.put_slice(self.as_bytes());
        Self::serialized_size()
    }

    const fn serialized_size() -> usize {
        size_of::<Header>()
    }
}

#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C)]
struct Footer {
    partitions: U16,
    unused: [u8; 2],
}

impl Footer {
    fn new(partitions: u16) -> Self {
        Self {
            partitions: partitions.into(),
            unused: [0; 2],
        }
    }

    fn serialize<B: bytes::BufMut>(&self, out: &mut B) -> usize {
        out.put_slice(self.as_bytes());
        Self::serialized_size()
    }

    const fn serialized_size() -> usize {
        size_of::<Footer>()
    }
}
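
// Serialized layout implied by the types above: a 4-byte magic `Header`,
// followed by the serialized partition tree, followed by a 4-byte `Footer`
// carrying the top-level partition count (little-endian u16) plus two unused
// padding bytes.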

/// An owned, compressed bitmap for u32 keys
#[derive(Default, Clone)]
pub struct Splinter {
    partitions: Partition<U32, Partition<U32, Partition<U16, Block>>>,
}

impl Splinter {
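    /// Builds a [`Splinter`] from a slice of keys.
    ///
    /// # Examples
    ///
    /// A minimal sketch based on the API in this module:
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead};
    /// let splinter = Splinter::from_slice(&[1, 2, 3]);
    ///
    /// assert_eq!(3, splinter.cardinality());
    /// assert!(splinter.contains(2));
    /// ```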
    pub fn from_slice(data: &[u32]) -> Self {
        let mut splinter = Self::default();
        for &key in data {
            splinter.insert(key);
        }
        splinter
    }

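    /// Decodes a [`Splinter`] from its serialized byte representation.
    ///
    /// # Examples
    ///
    /// A round-trip sketch using the API in this module:
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    /// let mut original = Splinter::default();
    /// original.insert(42);
    ///
    /// let decoded = Splinter::from_bytes(original.serialize_to_bytes()).unwrap();
    /// assert!(decoded.contains(42));
    /// ```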
    pub fn from_bytes<T: AsRef<[u8]>>(data: T) -> Result<Self, Culprit<DecodeErr>> {
        SplinterRef::from_bytes(data).map(Into::into)
    }

    fn insert_block(&mut self, a: u8, b: u8, c: u8, block: Block) {
        let partition = self.partitions.get_or_init(a);
        let partition = partition.get_or_init(b);
        partition.insert(c, block);
    }

    /// Computes the serialized size of this Splinter
    pub fn serialized_size(&self) -> usize {
        Header::serialized_size() + self.partitions.serialized_size() + Footer::serialized_size()
    }

    pub fn serialize<B: bytes::BufMut>(&self, out: &mut B) -> usize {
        let header_size = Header::DEFAULT.serialize(out);
        let (cardinality, partitions_size) = self.partitions.serialize(out);
        let footer_size =
            Footer::new(cardinality.try_into().expect("cardinality overflow")).serialize(out);
        header_size + partitions_size + footer_size
    }

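    /// Serializes this [`Splinter`] into a freshly allocated [`Bytes`] buffer.
    ///
    /// # Examples
    ///
    /// A sketch relating the buffer length to [`Splinter::serialized_size`]:
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterWrite};
    /// let mut splinter = Splinter::default();
    /// splinter.insert(7);
    ///
    /// let bytes = splinter.serialize_to_bytes();
    /// assert_eq!(splinter.serialized_size(), bytes.len());
    /// ```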
    pub fn serialize_to_bytes(&self) -> Bytes {
        let mut buf = BytesMut::new();
        self.serialize(&mut buf);
        buf.freeze()
    }

    pub fn serialize_to_splinter_ref(&self) -> SplinterRef<Bytes> {
        SplinterRef::from_bytes(self.serialize_to_bytes()).expect("serialization roundtrip failed")
    }
}

impl Debug for Splinter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Splinter")
            .field("num_partitions", &self.partitions.len())
            .field("cardinality", &self.cardinality())
            .finish()
    }
}

impl<K: Into<u32>> FromIterator<K> for Splinter {
    fn from_iter<T: IntoIterator<Item = K>>(iter: T) -> Self {
        let mut splinter = Self::default();
        for key in iter {
            splinter.insert(key.into());
        }
        splinter
    }
}

impl SplinterRead for Splinter {
    #[inline]
    fn is_empty(&self) -> bool {
        self.partitions.is_empty()
    }

    fn contains(&self, key: u32) -> bool {
        let [a, b, c, d] = segments(key);

        if let Some(partition) = self.partitions.get(a) {
            if let Some(partition) = partition.get(b) {
                if let Some(block) = partition.get(c) {
                    return block.contains(d);
                }
            }
        }

        false
    }

    fn cardinality(&self) -> usize {
        self.partitions
            .iter()
            .flat_map(|(_, p)| p.iter())
            .flat_map(|(_, p)| p.iter())
            .map(|(_, b)| b.cardinality())
            .sum()
    }

    fn iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.partitions
            .iter()
            .flat_map(|(a, p)| p.iter().map(move |(b, p)| (a, b, p)))
            .flat_map(|(a, b, p)| p.iter().map(move |(c, p)| (a, b, c, p)))
            .flat_map(|(a, b, c, p)| p.segments().map(move |d| combine_segments(a, b, c, d)))
    }

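    /// Returns a sorted iterator over all keys contained by the provided range.
    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the [`SplinterRef`] examples below:
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead};
    /// let splinter = Splinter::from_slice(&[1, 3, 5, 9]);
    ///
    /// assert_eq!(&[3, 5], &*splinter.range(2..=5).collect::<Vec<_>>());
    /// ```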
    fn range<'a, R>(&'a self, range: R) -> impl Iterator<Item = u32> + 'a
    where
        R: RangeBounds<u32> + 'a,
    {
        // compute the high, mid, low, and block ranges
        let Some([ra, rb, rc, rd]) = segment_ranges(range) else {
            return Either::Left(std::iter::empty());
        };
        Either::Right(
            self.partitions
                .range(ra.into())
                .flat_map(move |(a, p)| {
                    p.range(inner_range(a, ra, rb)).map(move |(b, p)| (a, b, p))
                })
                .flat_map(move |(a, b, p)| {
                    p.range(inner_range(b, rb, rc))
                        .map(move |(c, p)| (a, b, c, p))
                })
                .flat_map(move |(a, b, c, p)| {
                    p.range(inner_range(c, rc, rd))
                        .map(move |d| combine_segments(a, b, c, d))
                }),
        )
    }

    fn last(&self) -> Option<u32> {
        let (a, p) = self.partitions.last()?;
        let (b, p) = p.last()?;
        let (c, p) = p.last()?;
        let d = p.last()?;
        Some(combine_segments(a, b, c, d))
    }
}

impl SplinterWrite for Splinter {
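    /// Inserts a key into the splinter.
    ///
    /// # Examples
    ///
    /// A minimal sketch using this module's API:
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    /// let mut splinter = Splinter::default();
    /// splinter.insert(10);
    ///
    /// assert!(splinter.contains(10));
    /// assert_eq!(1, splinter.cardinality());
    /// ```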
    fn insert(&mut self, key: u32) -> bool {
        let [a, b, c, d] = segments(key);
        let partition = self.partitions.get_or_init(a);
        let partition = partition.get_or_init(b);
        let block = partition.get_or_init(c);
        block.insert(d)
    }
}

/// A compressed bitmap for u32 keys operating directly on a slice of bytes
#[derive(Clone)]
pub struct SplinterRef<T> {
    data: T,
    partitions: usize,
}

impl<T> SplinterRef<T> {
    pub fn inner(&self) -> &T {
        &self.data
    }

    pub fn into_inner(self) -> T {
        self.data
    }
}

impl<T> SplinterRef<T>
where
    T: AsRef<[u8]>,
{
    pub fn from_bytes(data: T) -> Result<Self, Culprit<DecodeErr>> {
        use DecodeErr::*;

        let (header, _) = Ref::<_, Header>::from_prefix(data.as_ref()).map_err(|err| {
            debug_assert!(matches!(err, ConvertError::Size(_)));
            InvalidHeader
        })?;
        if header.magic != SPLINTER_MAGIC {
            return Err(InvalidMagic.into());
        }

        let (_, footer) = Ref::<_, Footer>::from_suffix(data.as_ref()).map_err(|err| {
            debug_assert!(matches!(err, ConvertError::Size(_)));
            InvalidFooter
        })?;
        let partitions = footer.partitions.get() as usize;

        Ok(SplinterRef { data, partitions })
    }

    /// Returns the size of this SplinterRef's serialized bytes
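    ///
    /// # Examples
    ///
    /// A sketch relating [`SplinterRef::size`] to [`Splinter::serialized_size`]:
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterWrite};
    /// let mut splinter = Splinter::default();
    /// splinter.insert(5);
    ///
    /// let splinter_ref = splinter.serialize_to_splinter_ref();
    /// assert_eq!(splinter.serialized_size(), splinter_ref.size());
    /// ```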
    pub fn size(&self) -> usize {
        self.data.as_ref().len()
    }

    pub(crate) fn load_partitions(
        &self,
    ) -> PartitionRef<'_, U32, PartitionRef<'_, U32, PartitionRef<'_, U16, BlockRef<'_>>>> {
        let data = self.data.as_ref();
        let slice = &data[..data.len() - size_of::<Footer>()];
        PartitionRef::from_suffix(slice, self.partitions)
    }
}

impl<T: AsRef<[u8]>> Debug for SplinterRef<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SplinterRef")
            .field("num_partitions", &self.partitions)
            .field("cardinality", &self.cardinality())
            .finish()
    }
}

impl<T: AsRef<[u8]>> From<SplinterRef<T>> for Splinter {
    fn from(value: SplinterRef<T>) -> Self {
        value.copy_to_owned()
    }
}

impl<T: AsRef<[u8]>> CopyToOwned for SplinterRef<T> {
    type Owned = Splinter;

    fn copy_to_owned(&self) -> Self::Owned {
        let partitions = self.load_partitions().copy_to_owned();
        Splinter { partitions }
    }
}

impl<T: AsRef<[u8]>> SplinterRead for SplinterRef<T> {
    /// Returns `true` if the splinter is empty.
    ///
    /// # Examples
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    ///
    /// let splinter = Splinter::default().serialize_to_splinter_ref();
    /// assert!(splinter.is_empty());
    ///
    /// let mut splinter = Splinter::default();
    /// splinter.insert(1);
    /// let splinter = splinter.serialize_to_splinter_ref();
    /// assert!(!splinter.is_empty());
    /// ```
    #[inline]
    fn is_empty(&self) -> bool {
        self.load_partitions().is_empty()
    }

    /// Returns `true` if the splinter contains the given key.
    ///
    /// # Examples
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    ///
    /// let mut splinter = Splinter::default();
    /// splinter.insert(1);
    /// splinter.insert(3);
    /// let splinter = splinter.serialize_to_splinter_ref();
    ///
    /// assert!(splinter.contains(1));
    /// assert!(!splinter.contains(2));
    /// assert!(splinter.contains(3));
    /// ```
    fn contains(&self, key: u32) -> bool {
        let [a, b, c, d] = segments(key);

        if let Some(partition) = self.load_partitions().get(a) {
            if let Some(partition) = partition.get(b) {
                if let Some(block) = partition.get(c) {
                    return block.contains(d);
                }
            }
        }

        false
    }

    /// Calculates the total number of values stored in the set.
    ///
    /// # Examples
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    ///
    /// let mut splinter = Splinter::default();
    /// splinter.insert(6);
    /// splinter.insert(1);
    /// splinter.insert(3);
    /// let splinter = splinter.serialize_to_splinter_ref();
    ///
    /// assert_eq!(3, splinter.cardinality());
    /// ```
    fn cardinality(&self) -> usize {
        let mut sum = 0;
        for (_, partition) in self.load_partitions().iter() {
            for (_, partition) in partition.iter() {
                sum += partition.cardinality();
            }
        }
        sum
    }

    /// Returns a sorted [`Iterator`] over all keys.
    ///
    /// # Examples
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    ///
    /// let mut splinter = Splinter::default();
    /// splinter.insert(6);
    /// splinter.insert(1);
    /// splinter.insert(3);
    /// let splinter = splinter.serialize_to_splinter_ref();
    ///
    /// assert_eq!(&[1, 3, 6], &*splinter.iter().collect::<Vec<_>>());
    /// ```
    fn iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.load_partitions()
            .into_iter()
            .flat_map(|(a, p)| p.into_iter().map(move |(b, p)| (a, b, p)))
            .flat_map(|(a, b, p)| p.into_iter().map(move |(c, p)| (a, b, c, p)))
            .flat_map(|(a, b, c, p)| p.into_segments().map(move |d| combine_segments(a, b, c, d)))
    }

    /// Returns a sorted [`Iterator`] over all keys contained by the provided range.
    ///
    /// # Examples
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    ///
    /// let mut splinter = Splinter::default();
    /// splinter.insert(6);
    /// splinter.insert(1);
    /// splinter.insert(3);
    /// splinter.insert(5);
    /// splinter.insert(9);
    /// let splinter = splinter.serialize_to_splinter_ref();
    ///
    /// assert_eq!(&[3, 5, 6], &*splinter.range(3..=6).collect::<Vec<_>>());
    /// ```
    fn range<'a, R>(&'a self, range: R) -> impl Iterator<Item = u32> + 'a
    where
        R: RangeBounds<u32> + 'a,
    {
        // compute the high, mid, low, and block ranges
        let Some([ra, rb, rc, rd]) = segment_ranges(range) else {
            return Either::Left(std::iter::empty());
        };
        Either::Right(
            self.load_partitions()
                .into_range(ra.into())
                .flat_map(move |(a, p)| {
                    p.into_range(inner_range(a, ra, rb))
                        .map(move |(b, p)| (a, b, p))
                })
                .flat_map(move |(a, b, p)| {
                    p.into_range(inner_range(b, rb, rc))
                        .map(move |(c, p)| (a, b, c, p))
                })
                .flat_map(move |(a, b, c, p)| {
                    p.into_range(inner_range(c, rc, rd))
                        .map(move |d| combine_segments(a, b, c, d))
                }),
        )
    }

    /// Returns the last key in the set
    ///
    /// # Examples
    ///
    /// ```
    /// # use splinter_rs::{Splinter, SplinterRead, SplinterWrite};
    ///
    /// let mut splinter = Splinter::default();
    /// splinter.insert(6);
    /// splinter.insert(1);
    /// splinter.insert(3);
    ///
    /// let splinter = splinter.serialize_to_splinter_ref();
    /// assert_eq!(Some(6), splinter.last());
    /// ```
    fn last(&self) -> Option<u32> {
        let (a, p) = self.load_partitions().last()?;
        let (b, p) = p.last()?;
        let (c, p) = p.last()?;
        let d = p.last()?;
        Some(combine_segments(a, b, c, d))
    }
}

/// split the key into 4 8-bit segments
#[inline]
fn segments(key: u32) -> [Segment; 4] {
    key.to_be_bytes()
}

#[inline]
fn combine_segments(a: Segment, b: Segment, c: Segment, d: Segment) -> u32 {
    u32::from_be_bytes([a, b, c, d])
}
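
// For example, key 0x0102_0304 splits into segments [0x01, 0x02, 0x03, 0x04]
// (big-endian byte order), and `combine_segments(0x01, 0x02, 0x03, 0x04)`
// reconstructs 0x0102_0304.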

#[derive(Debug, Clone, Copy)]
struct SegmentRange {
    start: Segment,
    end: Segment,
}

impl From<SegmentRange> for RangeInclusive<Segment> {
    fn from(val: SegmentRange) -> Self {
        val.start..=val.end
    }
}

/// Split a range of keys into 4 inclusive ranges corresponding to the high,
/// mid, low, and block segments.
///
/// Returns None if the input range is empty.
#[inline]
fn segment_ranges<R: RangeBounds<u32>>(range: R) -> Option<[SegmentRange; 4]> {
    use Bound::*;
    let (start_bound, end_bound) = (range.start_bound().cloned(), range.end_bound().cloned());
    let is_empty = match (start_bound, end_bound) {
        (_, Excluded(u32::MIN)) | (Excluded(u32::MAX), _) => true,
        (Included(start), Excluded(end))
        | (Excluded(start), Included(end))
        | (Excluded(start), Excluded(end)) => start >= end,
        (Included(start), Included(end)) => start > end,
        _ => false,
    };
    if is_empty {
        return None;
    }

    let start = match start_bound {
        Unbounded => [0; 4],
        Included(segment) => segments(segment),
        Excluded(segment) => segments(segment.saturating_add(1)),
    };
    let end = match end_bound {
        Unbounded => [u8::MAX; 4],
        Included(segment) => segments(segment),
        Excluded(segment) => segments(segment.checked_sub(1).expect("end segment underflow")),
    };
    // zip the two arrays together
    Some(std::array::from_fn(|i| SegmentRange {
        start: start[i],
        end: end[i],
    }))
}

#[inline]
fn inner_range(
    key: Segment,
    key_range: SegmentRange,
    inner_range: SegmentRange,
) -> RangeInclusive<Segment> {
    let SegmentRange { start: s, end: e } = key_range;
    if key == s && key == e {
        inner_range.into()
    } else if key == s {
        inner_range.start..=u8::MAX
    } else if key == e {
        0..=inner_range.end
    } else {
        0..=u8::MAX
    }
}

#[cfg(test)]
mod tests {
    use std::io;

    use crate::testutil::{SetGen, mksplinter, mksplinter_ref};

    use super::*;
    use roaring::RoaringBitmap;

    #[test]
    fn test_splinter_sanity() {
        // fill up the first partition and sparse fill up the second partition
        let values = (0..65535)
            .chain((65536..85222).step_by(7))
            .collect::<Vec<_>>();

        // build a splinter from the values
        let splinter = mksplinter(values.iter().copied());

        // check that all expected keys are present
        for &i in &values {
            if !splinter.contains(i) {
                splinter.contains(i); // break here for debugging
                panic!("missing key: {i}");
            }
        }

        // check that some keys are not present
        assert!(!splinter.contains(65535), "unexpected key: 65535");
        assert!(!splinter.contains(90999), "unexpected key: 90999");
    }

    #[test]
    fn test_roundtrip_sanity() {
        let assert_round_trip = |splinter: Splinter| {
            let estimated_size = splinter.serialized_size();
            let splinter_ref = SplinterRef::from_bytes(splinter.serialize_to_bytes()).unwrap();
            assert_eq!(
                splinter_ref.size(),
                estimated_size,
                "serialized size matches estimated size"
            );
            assert_eq!(
                splinter.cardinality(),
                splinter_ref.cardinality(),
                "cardinality equal"
            );
            assert_eq!(splinter, splinter_ref, "Splinter == SplinterRef");
            assert_eq!(
                splinter,
                splinter_ref.copy_to_owned(),
                "Splinter == Splinter"
            );
            assert_eq!(
                splinter_ref.copy_to_owned().serialize_to_bytes(),
                splinter.serialize_to_bytes(),
                "deterministic serialization"
            );
        };

        assert_round_trip(mksplinter(0..0));
        assert_round_trip(mksplinter(0..1));
        assert_round_trip(mksplinter(u32::MAX - 10..u32::MAX));
        assert_round_trip(mksplinter(0..10));
        assert_round_trip(mksplinter(0..=255));
        assert_round_trip(mksplinter(0..=4096));
        assert_round_trip(mksplinter(0..=16384));
        assert_round_trip(mksplinter(1512..=3258));
        assert_round_trip(mksplinter((0..=16384).step_by(7)));
    }

    #[test]
    fn test_splinter_ref_sanity() {
        // fill up the first partition and sparse fill up the second partition
        let values = (0..65535)
            .chain((65536..85222).step_by(7))
            .collect::<Vec<_>>();

        // build a splinter from the values
        let splinter = mksplinter_ref(values.iter().copied());

        // check that all expected keys are present
        for &i in &values {
            if !splinter.contains(i) {
                splinter.contains(i); // break here for debugging
                panic!("missing key: {i}");
            }
        }

        // check that the splinter can enumerate all keys
        assert!(itertools::equal(values, splinter.iter()));

        // check that some keys are not present
        assert!(!splinter.contains(65535), "unexpected key: 65535");
        assert!(!splinter.contains(90999), "unexpected key: 90999");
    }
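
    /// Sketch of a property check added alongside the original suite: splitting
    /// a key into big-endian segments and recombining them must round-trip.
    #[test]
    fn test_segments_roundtrip() {
        for key in [0u32, 1, 0x0102_0304, u32::MAX] {
            let [a, b, c, d] = segments(key);
            assert_eq!(key, combine_segments(a, b, c, d));
        }
    }

    /// Sketch of a property check added alongside the original suite: empty input
    /// ranges produce `None`, while an unbounded range covers the full segment space.
    #[test]
    fn test_segment_ranges_bounds() {
        assert!(segment_ranges(..0u32).is_none());
        assert!(segment_ranges((Bound::Excluded(u32::MAX), Bound::Unbounded)).is_none());

        let ranges = segment_ranges(..).expect("unbounded range should be non-empty");
        for range in ranges {
            assert_eq!(0, range.start);
            assert_eq!(u8::MAX, range.end);
        }
    }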

    /// verify Splinter::range and SplinterRef::range
    #[test]
    pub fn test_range() {
        #[track_caller]
        fn case<I1, R, I2>(name: &str, set: I1, range: R, expected: I2)
        where
            I1: IntoIterator<Item = u32> + Clone,
            R: RangeBounds<u32> + Clone,
            I2: IntoIterator<Item = u32> + Clone,
        {
            let expected = expected.into_iter().collect::<Vec<_>>();

            let output = mksplinter(set.clone())
                .range(range.clone())
                .collect::<Vec<_>>();
            assert!(
                output == expected,
                "Splinter::range failed for case: {name}; output: {:?}; expected: {:?}",
                (output.first(), output.last(), output.len()),
                (expected.first(), expected.last(), expected.len()),
            );

            let output = mksplinter_ref(set).range(range).collect::<Vec<_>>();
            assert!(
                output == expected,
                "SplinterRef::range failed for case: {name}; output: {:?}; expected: {:?}",
                (output.first(), output.last(), output.len()),
                (expected.first(), expected.last(), expected.len()),
            );
        }

        case("empty", [], .., []);
        case("one element", [156106], .., [156106]);
        case(
            "one element, inclusive",
            [156106],
            156105..=156106,
            [156106],
        );
        case("one element, exclusive", [156106], 156105..156107, [156106]);

        case("zero", [0], .., [0]);
        case("zero, inclusive end", [0], ..=0, [0]);
        case("zero, inclusive start", [0], 0.., [0]);
        case("zero, exclusive end", [0], ..0, []);
        case("zero, exclusive start", [0], 1.., []);

        case("max element", [u32::MAX], .., [u32::MAX]);
        case(
            "max element, inclusive end",
            [u32::MAX],
            ..=u32::MAX,
            [u32::MAX],
        );
        case(
            "max element, inclusive start",
            [u32::MAX],
            u32::MAX..,
            [u32::MAX],
        );
        case("max element, exclusive end", [u32::MAX], ..u32::MAX, []);
        case(
            "max element, exclusive start",
            [u32::MAX],
            u32::MAX - 1..,
            [u32::MAX],
        );

        case(
            "simple set",
            [12, 16, 19, 1000002, 1000016, 1000046],
            ..,
            [12, 16, 19, 1000002, 1000016, 1000046],
        );
        case(
            "simple set, inclusive",
            [12, 16, 19, 1000002, 1000016, 1000046],
            19..=1000016,
            [19, 1000002, 1000016],
        );
        case(
            "simple set, exclusive",
            [12, 16, 19, 1000002, 1000016, 1000046],
            19..1000016,
            [19, 1000002],
        );

        let mut set_gen = SetGen::new(0xDEAD_BEEF);

        let set = set_gen.distributed(4, 8, 8, 128, 32768);
        let expected = set[1024..16384].to_vec();
        let range = expected[0]..=expected[expected.len() - 1];
        case("256 half full blocks", set.clone(), range, expected);

        let expected = set[1024..].to_vec();
        let range = expected[0]..;
        case(
            "256 half full blocks, unbounded right",
            set.clone(),
            range,
            expected,
        );

        let expected = set[..16384].to_vec();
        let range = ..=expected[expected.len() - 1];
        case(
            "256 half full blocks, unbounded left",
            set.clone(),
            range,
            expected,
        );
    }

    /// Heuristic analyzer: prints patterns found in the data which could be
    /// exploited by lz4 to improve compression
    pub fn analyze_compression_patterns(data: &[u8]) {
        use std::collections::HashMap;

        let len = data.len();
        if len == 0 {
            println!("empty slice");
            return;
        }
        println!("length: {len} bytes");

        // --- zeros ---
        let (mut zeros, mut longest_run, mut run) = (0usize, 0usize, 0usize);
        for &b in data {
            if b == 0 {
                zeros += 1;
                run += 1;
                longest_run = longest_run.max(run);
            } else {
                run = 0;
            }
        }
        println!(
            "zeros: {zeros} ({:.2}%), longest run: {longest_run}",
            zeros as f64 * 100.0 / len as f64
        );

        // --- histogram / entropy ---
        let mut freq = [0u32; 256];
        for &b in data {
            freq[b as usize] += 1;
        }
        let entropy: f64 = freq
            .iter()
            .filter(|&&c| c != 0)
            .map(|&c| {
                let p = c as f64 / len as f64;
                -p * p.log2()
            })
            .sum();
        println!("shannon entropy ≈ {entropy:.3} bits/byte (max 8)");

        // --- repeated 8-byte blocks ---
        const BLOCK: usize = 8;
        if len >= BLOCK {
            let mut map: HashMap<&[u8], u32> = HashMap::new();
            for chunk in data.chunks_exact(BLOCK) {
                *map.entry(chunk).or_default() += 1;
            }

            let mut duplicate_bytes = 0u32;
            let mut top: Option<(&[u8], u32)> = None;

            for (&k, &v) in map.iter() {
                if v > 1 {
                    duplicate_bytes += (v - 1) * BLOCK as u32;
                    if top.map_or(true, |(_, max)| v > max) {
                        top = Some((k, v));
                    }
                }
            }

            if let Some((bytes, count)) = top {
                println!(
                    "repeated 8-byte blocks: {} duplicate bytes; most common occurs {count}× (bytes {:02X?})",
                    duplicate_bytes, bytes
                );
            } else {
                println!("no duplicated 8-byte blocks");
            }
        }

        println!("analysis complete");
    }

    #[test]
    fn test_expected_compression() {
        let to_roaring = |set: Vec<u32>| {
            let mut buf = io::Cursor::new(Vec::new());
            let mut bmp = RoaringBitmap::from_sorted_iter(set).unwrap();
            bmp.optimize();
            bmp.serialize_into(&mut buf).unwrap();
            buf.into_inner()
        };

        struct Report {
            name: &'static str,
            baseline: usize,
            //        (actual, expected)
            splinter: (usize, usize),
            roaring: (usize, usize),

            splinter_lz4: usize,
            roaring_lz4: usize,
        }

        let mut reports = vec![];

        let mut run_test = |name: &'static str,
                            set: Vec<u32>,
                            expected_splinter: usize,
                            expected_roaring: usize| {
            println!("-------------------------------------");
            println!("running test: {name}");

            let splinter = mksplinter(set.clone()).serialize_to_bytes();
            let roaring = to_roaring(set.clone());

            analyze_compression_patterns(&splinter);

            let splinter_lz4 = lz4::block::compress(&splinter, None, false).unwrap();
            let roaring_lz4 = lz4::block::compress(&roaring, None, false).unwrap();

            // verify round trip
            assert_eq!(
                splinter,
                lz4::block::decompress(&splinter_lz4, Some(splinter.len() as i32)).unwrap()
            );
            assert_eq!(
                roaring,
                lz4::block::decompress(&roaring_lz4, Some(roaring.len() as i32)).unwrap()
            );

            reports.push(Report {
                name,
                baseline: set.len() * std::mem::size_of::<u32>(),
                splinter: (splinter.len(), expected_splinter),
                roaring: (roaring.len(), expected_roaring),

                splinter_lz4: splinter_lz4.len(),
                roaring_lz4: roaring_lz4.len(),
            });
        };

        let mut set_gen = SetGen::new(0xDEAD_BEEF);

        // empty splinter
        run_test("empty", vec![], 8, 8);

        // 1 element in set
        let set = set_gen.distributed(1, 1, 1, 1, 1);
        run_test("1 element", set, 25, 18);

        // 1 fully dense block
        let set = set_gen.distributed(1, 1, 1, 256, 256);
        run_test("1 dense block", set, 24, 15);

        // 1 half full block
        let set = set_gen.distributed(1, 1, 1, 128, 128);
        run_test("1 half full block", set, 56, 247);

        // 1 sparse block
        let set = set_gen.distributed(1, 1, 1, 16, 16);
        run_test("1 sparse block", set, 40, 48);

        // 8 half full blocks
        let set = set_gen.distributed(1, 1, 8, 128, 1024);
        run_test("8 half full blocks", set, 308, 2064);

        // 8 sparse blocks
        let set = set_gen.distributed(1, 1, 8, 2, 16);
        run_test("8 sparse blocks", set, 68, 48);

        // 64 half full blocks
        let set = set_gen.distributed(4, 4, 4, 128, 8192);
        run_test("64 half full blocks", set, 2432, 16486);

        // 64 sparse blocks
        let set = set_gen.distributed(4, 4, 4, 2, 128);
        run_test("64 sparse blocks", set, 512, 392);

        // 256 half full blocks
        let set = set_gen.distributed(4, 8, 8, 128, 32768);
        run_test("256 half full blocks", set, 9440, 65520);

        // 256 sparse blocks
        let set = set_gen.distributed(4, 8, 8, 2, 512);
        run_test("256 sparse blocks", set, 1760, 1288);

        // 512 half full blocks
        let set = set_gen.distributed(8, 8, 8, 128, 65536);
        run_test("512 half full blocks", set, 18872, 130742);

        // 512 sparse blocks
        let set = set_gen.distributed(8, 8, 8, 2, 1024);
        run_test("512 sparse blocks", set, 3512, 2568);

        // the rest of the compression tests use 4k elements
        let elements = 4096;

        // fully dense splinter
        let set = set_gen.distributed(1, 1, 16, 256, elements);
        run_test("fully dense", set, 84, 75);

        // 128 elements per block; dense partitions
        let set = set_gen.distributed(1, 1, 32, 128, elements);
        run_test("128/block; dense", set, 1172, 8195);

        // 32 elements per block; dense partitions
        let set = set_gen.distributed(1, 1, 128, 32, elements);
        run_test("32/block; dense", set, 4532, 8208);

        // 16 elements per block; dense low partitions
        let set = set_gen.distributed(1, 1, 256, 16, elements);
        run_test("16/block; dense", set, 4884, 8208);

        // 128 elements per block; sparse mid partitions
        let set = set_gen.distributed(1, 32, 1, 128, elements);
        run_test("128/block; sparse mid", set, 1358, 8300);

        // 128 elements per block; sparse high partitions
        let set = set_gen.distributed(32, 1, 1, 128, elements);
        run_test("128/block; sparse high", set, 1544, 8290);

        // 1 element per block; sparse mid partitions
        let set = set_gen.distributed(1, 256, 16, 1, elements);
        run_test("1/block; sparse mid", set, 21774, 10248);

        // 1 element per block; sparse high partitions
        let set = set_gen.distributed(256, 16, 1, 1, elements);
        run_test("1/block; sparse high", set, 46344, 40968);

        // 1/block; spread low
        let set = set_gen.dense(1, 16, 256, 1, elements);
        run_test("1/block; spread low", set, 16494, 8328);

        // each partition is dense
        let set = set_gen.dense(8, 8, 8, 8, elements);
        run_test("dense throughout", set, 6584, 2700);

        // the lowest partitions are dense
        let set = set_gen.dense(1, 1, 64, 64, elements);
        run_test("dense low", set, 2292, 267);

        // the mid and low partitions are dense
        let set = set_gen.dense(1, 32, 16, 8, elements);
        run_test("dense mid/low", set, 6350, 2376);

        // fully random sets of varying sizes
        run_test("random/32", set_gen.random(32), 546, 328);
        run_test("random/256", set_gen.random(256), 3655, 2560);
        run_test("random/1024", set_gen.random(1024), 12499, 10168);
        run_test("random/4096", set_gen.random(4096), 45582, 39952);
        run_test("random/16384", set_gen.random(16384), 163758, 148600);
        run_test("random/65535", set_gen.random(65535), 543584, 462190);

        let mut fail_test = false;

        println!(
            "{:30} {:12} {:>6} {:>10} {:>10} {:>10}",
            "test", "bitmap", "size", "expected", "relative", "ok"
        );
        for report in &reports {
            println!(
                "{:30} {:12} {:6} {:10} {:>10} {:>10}",
                report.name,
                "Splinter",
                report.splinter.0,
                report.splinter.1,
                "1.00",
                if report.splinter.0 == report.splinter.1 {
                    "ok"
                } else {
                    fail_test = true;
                    "FAIL"
                }
            );
            let diff = report.roaring.0 as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Roaring",
                report.roaring.0,
                report.roaring.1,
                diff,
                if report.roaring.0 != report.roaring.1 {
                    fail_test = true;
                    "FAIL"
                } else if diff < 1.0 {
                    "<"
                } else {
                    "ok"
                }
            );
            let diff = report.splinter_lz4 as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Splinter LZ4",
                report.splinter_lz4,
                report.splinter_lz4,
                diff,
                if report.splinter.0 <= report.splinter_lz4 {
                    ">"
                } else {
                    "<"
                }
            );
            let diff = report.roaring_lz4 as f64 / report.splinter_lz4 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Roaring LZ4",
                report.roaring_lz4,
                report.roaring_lz4,
                diff,
                if report.splinter_lz4 <= report.roaring_lz4 {
                    "ok"
                } else {
                    "<"
                }
            );
            let diff = report.baseline as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Baseline",
                report.baseline,
                report.baseline,
                diff,
                if report.splinter.0 <= report.baseline {
                    "ok"
                } else {
                    // we don't fail the test, just report for informational purposes
                    "<"
                }
            );
        }

        // calculate average compression ratio (splinter_lz4 / splinter)
        let avg_ratio = reports
            .iter()
            .map(|r| r.splinter_lz4 as f64 / r.splinter.0 as f64)
            .sum::<f64>()
            / reports.len() as f64;

        println!("average compression ratio (splinter_lz4 / splinter): {avg_ratio:.2}");

        assert!(!fail_test, "compression test failed");
    }
}