1use std::{fmt::Debug, ops::RangeBounds};
2
3use bytes::Bytes;
4
5use crate::{
6 Encodable, Optimizable, SplinterRef,
7 codec::{encoder::Encoder, footer::Footer},
8 level::High,
9 partition::Partition,
10 traits::{PartitionRead, PartitionWrite},
11 util::RangeExt,
12};
13
/// A set of `u32` values, implemented as a thin newtype around a top-level
/// [`Partition<High>`].
///
/// Every read, write, encode and optimize operation implemented for this type
/// in this file forwards to the wrapped partition.
#[derive(Clone, PartialEq, Eq, Default, Debug)]
pub struct Splinter(Partition<High>);

// Compile-time guard: the build fails if the in-memory size of `Splinter`
// ever differs from 40 bytes.
static_assertions::const_assert_eq!(std::mem::size_of::<Splinter>(), 40);
59
60impl Splinter {
61 pub const EMPTY: Self = Splinter(Partition::EMPTY);
63
64 pub const FULL: Self = Splinter(Partition::Full);
66
67 pub fn encode_to_splinter_ref(&self) -> SplinterRef<Bytes> {
85 SplinterRef { data: self.encode_to_bytes() }
86 }
87
88 #[inline(always)]
89 pub(crate) fn new(inner: Partition<High>) -> Self {
90 Self(inner)
91 }
92
93 #[inline(always)]
94 pub(crate) fn inner(&self) -> &Partition<High> {
95 &self.0
96 }
97
98 #[inline(always)]
99 pub(crate) fn inner_mut(&mut self) -> &mut Partition<High> {
100 &mut self.0
101 }
102}
103
104impl FromIterator<u32> for Splinter {
105 fn from_iter<I: IntoIterator<Item = u32>>(iter: I) -> Self {
106 Self(Partition::<High>::from_iter(iter))
107 }
108}
109
110impl<R: RangeBounds<u32>> From<R> for Splinter {
111 fn from(range: R) -> Self {
112 if let Some(range) = range.try_into_inclusive() {
113 if range.start() == &u32::MIN && range.end() == &u32::MAX {
114 Self::FULL
115 } else {
116 Self(Partition::<High>::from(range))
117 }
118 } else {
119 Self::EMPTY
121 }
122 }
123}
124
125impl PartitionRead<High> for Splinter {
126 #[inline]
140 fn cardinality(&self) -> usize {
141 self.0.cardinality()
142 }
143
144 #[inline]
158 fn is_empty(&self) -> bool {
159 self.0.is_empty()
160 }
161
162 #[inline]
176 fn contains(&self, value: u32) -> bool {
177 self.0.contains(value)
178 }
179
180 #[inline]
198 fn position(&self, value: u32) -> Option<usize> {
199 self.0.position(value)
200 }
201
202 #[inline]
220 fn rank(&self, value: u32) -> usize {
221 self.0.rank(value)
222 }
223
224 #[inline]
241 fn select(&self, idx: usize) -> Option<u32> {
242 self.0.select(idx)
243 }
244
245 #[inline]
260 fn last(&self) -> Option<u32> {
261 self.0.last()
262 }
263
264 #[inline]
277 fn iter(&self) -> impl Iterator<Item = u32> {
278 self.0.iter()
279 }
280
281 #[inline]
305 fn contains_all<R: RangeBounds<u32>>(&self, values: R) -> bool {
306 self.0.contains_all(values)
307 }
308
309 #[inline]
334 fn contains_any<R: RangeBounds<u32>>(&self, values: R) -> bool {
335 self.0.contains_any(values)
336 }
337}
338
339impl PartitionWrite<High> for Splinter {
340 #[inline]
364 fn insert(&mut self, value: u32) -> bool {
365 self.0.insert(value)
366 }
367
368 #[inline]
391 fn remove(&mut self, value: u32) -> bool {
392 self.0.remove(value)
393 }
394
395 #[inline]
420 fn remove_range<R: RangeBounds<u32>>(&mut self, values: R) {
421 self.0.remove_range(values);
422 }
423}
424
425impl Encodable for Splinter {
426 fn encoded_size(&self) -> usize {
427 self.0.encoded_size() + std::mem::size_of::<Footer>()
428 }
429
430 fn encode<B: bytes::BufMut>(&self, encoder: &mut Encoder<B>) {
431 self.0.encode(encoder);
432 encoder.write_footer();
433 }
434}
435
436impl Optimizable for Splinter {
437 #[inline]
438 fn optimize(&mut self) {
439 self.0.optimize();
440 }
441}
442
443impl Extend<u32> for Splinter {
444 #[inline]
445 fn extend<T: IntoIterator<Item = u32>>(&mut self, iter: T) {
446 self.0.extend(iter);
447 }
448}
449
450#[cfg(test)]
451mod tests {
452 use std::ops::Bound;
453
454 use super::*;
455 use crate::{
456 codec::Encodable,
457 level::{Level, Low},
458 testutil::{SetGen, mksplinter, ratio_to_marks, test_partition_read, test_partition_write},
459 traits::Optimizable,
460 };
461 use itertools::{Itertools, assert_equal};
462 use proptest::{
463 collection::{hash_set, vec},
464 proptest,
465 };
466 use rand::{SeedableRng, seq::index};
467 use roaring::RoaringBitmap;
468
469 #[test]
470 fn test_sanity() {
471 let mut splinter = Splinter::EMPTY;
472
473 assert!(splinter.insert(1));
474 assert!(!splinter.insert(1));
475 assert!(splinter.contains(1));
476
477 let values = [1024, 123, 16384];
478 for v in values {
479 assert!(splinter.insert(v));
480 assert!(splinter.contains(v));
481 assert!(!splinter.contains(v + 1));
482 }
483
484 for i in 0..8192 + 10 {
485 splinter.insert(i);
486 }
487
488 splinter.optimize();
489
490 dbg!(&splinter);
491
492 let expected = splinter.iter().collect_vec();
493 test_partition_read(&splinter, &expected);
494 test_partition_write(&mut splinter);
495 }
496
497 #[test]
498 fn test_wat() {
499 let mut set_gen = SetGen::new(0xDEAD_BEEF);
500 let set = set_gen.random_max(64, 4096);
501 let baseline_size = set.len() * 4;
502
503 let mut splinter = Splinter::from_iter(set.iter().copied());
504 splinter.optimize();
505
506 dbg!(&splinter, splinter.encoded_size(), baseline_size, set.len());
507 itertools::assert_equal(splinter.iter(), set);
508 }
509
510 #[test]
511 fn test_splinter_write() {
512 let mut splinter = Splinter::from_iter(0u32..16384);
513 test_partition_write(&mut splinter);
514 }
515
516 #[test]
517 fn test_splinter_optimize_growth() {
518 let mut splinter = Splinter::EMPTY;
519 let mut rng = rand::rngs::StdRng::seed_from_u64(0xdeadbeef);
520 let set = index::sample(&mut rng, Low::MAX_LEN, 8);
521 dbg!(&splinter);
522 for i in set {
523 splinter.insert(i as u32);
524 dbg!(&splinter);
525 }
526 }
527
528 #[test]
529 fn test_splinter_from_range() {
530 let splinter = Splinter::from(..);
531 assert_eq!(splinter.cardinality(), (u32::MAX as usize) + 1);
532
533 let mut splinter = Splinter::from(1..);
534 assert_eq!(splinter.cardinality(), u32::MAX as usize);
535
536 splinter.remove(1024);
537 assert_eq!(splinter.cardinality(), (u32::MAX as usize) - 1);
538
539 let mut count = 1;
540 for i in (2048..=256000).step_by(1024) {
541 splinter.remove(i);
542 count += 1
543 }
544 assert_eq!(splinter.cardinality(), (u32::MAX as usize) - count);
545 }
546
547 proptest! {
548 #[test]
549 fn test_splinter_read_proptest(set in hash_set(0u32..16384, 0..1024)) {
550 let expected = set.iter().copied().sorted().collect_vec();
551 test_partition_read(&Splinter::from_iter(set), &expected);
552 }
553
554
555 #[test]
556 fn test_splinter_proptest(set in vec(0u32..16384, 0..1024)) {
557 let splinter = mksplinter(&set);
558 if set.is_empty() {
559 assert!(!splinter.contains(123));
560 } else {
561 let lookup = set[set.len() / 3];
562 assert!(splinter.contains(lookup));
563 }
564 }
565
566 #[test]
567 fn test_splinter_opt_proptest(set in vec(0u32..16384, 0..1024)) {
568 let mut splinter = mksplinter(&set);
569 splinter.optimize();
570 if set.is_empty() {
571 assert!(!splinter.contains(123));
572 } else {
573 let lookup = set[set.len() / 3];
574 assert!(splinter.contains(lookup));
575 }
576 }
577
578 #[test]
579 fn test_splinter_eq_proptest(set in vec(0u32..16384, 0..1024)) {
580 let a = mksplinter(&set);
581 assert_eq!(a, a.clone());
582 }
583
584 #[test]
585 fn test_splinter_opt_eq_proptest(set in vec(0u32..16384, 0..1024)) {
586 let mut a = mksplinter(&set);
587 let b = mksplinter(&set);
588 a.optimize();
589 assert_eq!(a, b);
590 }
591
592 #[test]
593 fn test_splinter_remove_range_proptest(set in hash_set(0u32..16384, 0..1024)) {
594 let expected = set.iter().copied().sorted().collect_vec();
595 let mut splinter = mksplinter(&expected);
596 if let Some(last) = expected.last() {
597 splinter.remove_range((Bound::Excluded(last), Bound::Unbounded));
598 assert_equal(splinter.iter(), expected);
599 }
600 }
601 }
602
603 use hegel::generators;
606
607 #[hegel::test]
609 fn test_iter_sorted_and_deduped(tc: hegel::TestCase) {
610 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
611 let splinter = Splinter::from_iter(values);
612 let items: Vec<u32> = splinter.iter().collect();
613 for window in items.windows(2) {
614 assert!(
615 window[0] < window[1],
616 "iter not strictly sorted: {window:?}"
617 );
618 }
619 }
620
621 #[hegel::test]
623 fn test_cardinality_equals_iter_count(tc: hegel::TestCase) {
624 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
625 let splinter = Splinter::from_iter(values);
626 assert_eq!(splinter.cardinality(), splinter.iter().count());
627 }
628
629 #[hegel::test]
631 fn test_contains_all_inserted_values(tc: hegel::TestCase) {
632 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
633 let splinter = Splinter::from_iter(values.iter().copied());
634 for &v in &values {
635 assert!(splinter.contains(v), "missing value {v}");
636 }
637 }
638
639 #[hegel::test]
641 fn test_insert_returns_correct_bool(tc: hegel::TestCase) {
642 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
643 let mut splinter = Splinter::EMPTY;
644 let mut seen = std::collections::HashSet::new();
645 for v in values {
646 let was_new = seen.insert(v);
647 assert_eq!(splinter.insert(v), was_new);
648 }
649 }
650
651 #[hegel::test]
653 fn test_remove_returns_correct_bool(tc: hegel::TestCase) {
654 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
655 let mut splinter = Splinter::from_iter(values.iter().copied());
656 let to_remove: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
657 let mut present: std::collections::HashSet<u32> = values.into_iter().collect();
658 for v in to_remove {
659 let was_present = present.remove(&v);
660 assert_eq!(splinter.remove(v), was_present);
661 }
662 }
663
664 #[hegel::test]
666 fn test_optimize_preserves_elements(tc: hegel::TestCase) {
667 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
668 let mut splinter = Splinter::from_iter(values.iter().copied());
669 let before: Vec<u32> = splinter.iter().collect();
670 splinter.optimize();
671 let after: Vec<u32> = splinter.iter().collect();
672 assert_eq!(before, after);
673 }
674
675 #[hegel::test]
677 fn test_optimize_idempotent(tc: hegel::TestCase) {
678 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
679 let mut splinter = Splinter::from_iter(values);
680 splinter.optimize();
681 let after_first = splinter.clone();
682 splinter.optimize();
683 assert_eq!(splinter, after_first);
684 }
685
686 #[hegel::test]
688 fn test_select_position_inverse(tc: hegel::TestCase) {
689 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()).min_size(1));
690 let splinter = Splinter::from_iter(values);
691 let cardinality = splinter.cardinality();
692 let idx = tc.draw(generators::integers::<usize>().max_value(cardinality - 1));
693 let value = splinter.select(idx).unwrap();
694 assert_eq!(splinter.position(value), Some(idx));
695 }
696
697 #[hegel::test]
699 fn test_rank_consistency(tc: hegel::TestCase) {
700 let values: Vec<u32> =
701 tc.draw(generators::vecs(generators::integers::<u32>().max_value(65535)).min_size(1));
702 let splinter = Splinter::from_iter(values);
703 let query = tc.draw(generators::integers::<u32>().max_value(65535));
704 let rank = splinter.rank(query);
705 let count_leq = splinter.iter().filter(|&v| v <= query).count();
706 assert_eq!(rank, count_leq);
707 }
708
709 #[hegel::test]
711 fn test_encode_decode_roundtrip(tc: hegel::TestCase) {
712 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
713 let mut splinter = Splinter::from_iter(values);
714 splinter.optimize();
715 let encoded = splinter.encode_to_bytes();
716 let splinter_ref = SplinterRef::from_bytes(encoded).unwrap();
717 let decoded = splinter_ref.decode_to_splinter();
718 assert_eq!(splinter, decoded);
719 }
720
721 #[hegel::test]
723 fn test_encoded_size_matches(tc: hegel::TestCase) {
724 let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
725 let mut splinter = Splinter::from_iter(values);
726 splinter.optimize();
727 let declared_size = splinter.encoded_size();
728 let actual_bytes = splinter.encode_to_bytes();
729 assert_eq!(declared_size, actual_bytes.len());
730 }
731
732 #[hegel::test]
734 fn test_from_range_contains_all(tc: hegel::TestCase) {
735 let mut a = tc.draw(generators::integers::<u16>());
736 let mut b = tc.draw(generators::integers::<u16>());
737 if a > b {
738 (a, b) = (b, a);
739 }
740 let start = a as u32;
741 let end = b as u32;
742 let splinter = Splinter::from(start..=end);
743 assert_eq!(splinter.cardinality(), (end - start + 1) as usize);
744 assert!(splinter.contains(start));
745 assert!(splinter.contains(end));
746 if start > 0 {
747 assert!(!splinter.contains(start - 1));
748 }
749 if end < u32::MAX {
750 assert!(!splinter.contains(end + 1));
751 }
752 }
753
    // Regression test for encoded sizes.
    //
    // For a series of generated sets this encodes the set as a Splinter and
    // as a Roaring bitmap, compares both byte counts against hard-coded
    // expectations, LZ4-compresses both encodings, prints a comparison table
    // and finally fails if any byte count deviated from its expectation.
    //
    // NOTE(review): `set_gen` is seeded once and reused for every case, so the
    // expected sizes presumably depend on the order of the calls below —
    // confirm before reordering or inserting cases.
    #[test]
    fn test_expected_compression() {
        // Serializes the (sorted) values as an optimized Roaring bitmap.
        fn to_roaring(set: impl Iterator<Item = u32>) -> Vec<u8> {
            let mut buf = std::io::Cursor::new(Vec::new());
            let mut bmp = RoaringBitmap::from_sorted_iter(set).unwrap();
            bmp.optimize();
            bmp.serialize_into(&mut buf).unwrap();
            buf.into_inner()
        }

        // One row group of the final table.
        struct Report {
            name: String,
            // size of the set stored as raw `u32` values
            baseline: usize,
            // (actual encoded size, expected encoded size)
            splinter: (usize, usize),
            // (actual serialized size, expected serialized size)
            roaring: (usize, usize),

            // sizes after LZ4 block compression of the respective encoding
            splinter_lz4: usize,
            roaring_lz4: usize,
        }

        let mut reports = vec![];

        // Runs one case: checks the set size, round-trips the set through a
        // Splinter, verifies the declared encoded size and the LZ4
        // round-trips, then records the sizes for the report.
        let mut run_test = |name: &str,
                            set: Vec<u32>,
                            expected_set_size: usize,
                            expected_splinter: usize,
                            expected_roaring: usize| {
            assert_eq!(set.len(), expected_set_size, "Set size mismatch");

            let mut splinter = Splinter::from_iter(set.clone());
            splinter.optimize();
            itertools::assert_equal(splinter.iter(), set.iter().copied());

            test_partition_read(&splinter, &set);

            let expected_size = splinter.encoded_size();
            // From here on `splinter` is the encoded byte buffer.
            let splinter = splinter.encode_to_bytes();

            assert_eq!(
                splinter.len(),
                expected_size,
                "actual encoded size does not match declared encoded size"
            );

            let roaring = to_roaring(set.iter().copied());

            let splinter_lz4 = lz4::block::compress(&splinter, None, false).unwrap();
            let roaring_lz4 = lz4::block::compress(&roaring, None, false).unwrap();

            // Both compressed buffers must decompress back to their input.
            assert_eq!(
                splinter,
                lz4::block::decompress(&splinter_lz4, Some(splinter.len() as i32)).unwrap()
            );
            assert_eq!(
                roaring,
                lz4::block::decompress(&roaring_lz4, Some(roaring.len() as i32)).unwrap()
            );

            reports.push(Report {
                name: name.to_owned(),
                baseline: set.len() * std::mem::size_of::<u32>(),
                splinter: (splinter.len(), expected_splinter),
                roaring: (roaring.len(), expected_roaring),

                splinter_lz4: splinter_lz4.len(),
                roaring_lz4: roaring_lz4.len(),
            });
        };

        let mut set_gen = SetGen::new(0xDEAD_BEEF);

        // Arguments after the set are: expected element count, expected
        // Splinter size in bytes, expected Roaring size in bytes.
        run_test("empty", vec![], 0, 13, 8);

        // NOTE(review): judging by the case names, the four `distributed` /
        // `dense` arguments look like per-level counts (high, mid, low,
        // values per block) — confirm against `SetGen`.
        let set = set_gen.distributed(1, 1, 1, 1);
        run_test("1 element", set, 1, 21, 18);

        let set = set_gen.distributed(1, 1, 1, 256);
        run_test("1 dense block", set, 256, 25, 15);

        let set = set_gen.distributed(1, 1, 1, 128);
        run_test("1 half full block", set, 128, 72, 255);

        let set = set_gen.distributed(1, 1, 1, 16);
        run_test("1 sparse block", set, 16, 57, 48);

        let set = set_gen.distributed(1, 1, 8, 128);
        run_test("8 half full blocks", set, 1024, 338, 2003);

        let set = set_gen.distributed(1, 1, 8, 2);
        run_test("8 sparse blocks", set, 16, 67, 48);

        let set = set_gen.distributed(4, 4, 4, 128);
        run_test("64 half full blocks", set, 8192, 2634, 16452);

        let set = set_gen.distributed(4, 4, 4, 2);
        run_test("64 sparse blocks", set, 128, 450, 392);

        let set = set_gen.distributed(4, 8, 8, 128);
        run_test("256 half full blocks", set, 32768, 10074, 65580);

        let set = set_gen.distributed(4, 8, 8, 2);
        run_test("256 sparse blocks", set, 512, 1402, 1288);

        let set = set_gen.distributed(8, 8, 8, 128);
        run_test("512 half full blocks", set, 65536, 20134, 130810);

        let set = set_gen.distributed(8, 8, 8, 2);
        run_test("512 sparse blocks", set, 1024, 2790, 2568);

        // Every case below holds exactly this many elements.
        let elements = 4096;

        let set = set_gen.distributed(1, 1, 16, 256);
        run_test("fully dense", set, elements, 87, 63);

        let set = set_gen.distributed(1, 1, 32, 128);
        run_test("128/block; dense", set, elements, 1250, 8208);

        let set = set_gen.distributed(1, 1, 128, 32);
        run_test("32/block; dense", set, elements, 4802, 8208);

        let set = set_gen.distributed(1, 1, 256, 16);
        run_test("16/block; dense", set, elements, 5666, 8208);

        let set = set_gen.distributed(1, 32, 1, 128);
        run_test("128/block; sparse mid", set, elements, 1529, 8282);

        let set = set_gen.distributed(32, 1, 1, 128);
        run_test("128/block; sparse high", set, elements, 1870, 8224);

        let set = set_gen.distributed(1, 256, 16, 1);
        run_test("1/block; sparse mid", set, elements, 10521, 10248);

        let set = set_gen.distributed(256, 16, 1, 1);
        run_test("1/block; sparse high", set, elements, 15374, 40968);

        let set = set_gen.dense(1, 16, 256, 1);
        run_test("1/block; spread low", set, elements, 8377, 8328);

        let set = set_gen.dense(8, 8, 8, 8);
        run_test("dense throughout", set, elements, 2790, 2700);

        let set = set_gen.dense(1, 1, 64, 64);
        run_test("dense low", set, elements, 291, 267);

        let set = set_gen.dense(1, 32, 16, 8);
        run_test("dense mid/low", set, elements, 2393, 2376);

        // (element count, max passed to `random_max`, expected Splinter size,
        // expected Roaring size)
        let random_cases = [
            (32, High::MAX_LEN, 145, 328),
            (256, High::MAX_LEN, 1041, 2544),
            (1024, High::MAX_LEN, 4113, 10168),
            (4096, High::MAX_LEN, 15374, 40056),
            (16384, High::MAX_LEN, 52238, 148656),
            (65536, High::MAX_LEN, 199694, 461288),
            (32, 65536, 99, 80),
            (256, 65536, 547, 528),
            (1024, 65536, 2083, 2064),
            (4096, 65536, 5666, 8208),
            (65536, 65536, 25, 15),
            (8, 1024, 49, 32),
            (16, 1024, 67, 48),
            (32, 1024, 94, 80),
            (64, 1024, 126, 144),
            (128, 1024, 183, 272),
        ];

        for (count, max, expected_splinter, expected_roaring) in random_cases {
            // The max is only spelled out in the name when it is not the default.
            let name = if max == High::MAX_LEN {
                format!("random/{count}")
            } else {
                format!("random/{count}/{max}")
            };
            run_test(
                &name,
                set_gen.random_max(count, max),
                count,
                expected_splinter,
                expected_roaring,
            );
        }

        // Set while printing; the assertion is deferred to the end so the
        // whole table is printed even when a row fails.
        let mut fail_test = false;

        println!("{}", "-".repeat(83));
        println!(
            "{:30} {:12} {:>6} {:>10} {:>10} {:>10}",
            "test", "bitmap", "size", "expected", "relative", "ok"
        );
        for report in &reports {
            // Splinter row: actual vs expected encoded size.
            println!(
                "{:30} {:12} {:6} {:10} {:>10} {:>10}",
                report.name,
                "Splinter",
                report.splinter.0,
                report.splinter.1,
                "1.00",
                if report.splinter.0 == report.splinter.1 {
                    "ok"
                } else {
                    fail_test = true;
                    "FAIL"
                }
            );

            // Roaring row: size relative to the Splinter encoding.
            let diff = report.roaring.0 as f64 / report.splinter.0 as f64;
            let ok_status = if report.roaring.0 != report.roaring.1 {
                fail_test = true;
                "FAIL".into()
            } else {
                ratio_to_marks(diff)
            };
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "", "Roaring", report.roaring.0, report.roaring.1, diff, ok_status
            );

            // LZ4-compressed Splinter relative to the uncompressed Splinter.
            // These rows have no expectation; the actual size is printed twice.
            let diff = report.splinter_lz4 as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Splinter LZ4",
                report.splinter_lz4,
                report.splinter_lz4,
                diff,
                ratio_to_marks(diff)
            );

            // LZ4-compressed Roaring relative to the LZ4-compressed Splinter.
            let diff = report.roaring_lz4 as f64 / report.splinter_lz4 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Roaring LZ4",
                report.roaring_lz4,
                report.roaring_lz4,
                diff,
                ratio_to_marks(diff)
            );

            // Raw `u32` array relative to the Splinter encoding.
            let diff = report.baseline as f64 / report.splinter.0 as f64;
            println!(
                "{:30} {:12} {:6} {:10} {:>10.2} {:>10}",
                "",
                "Baseline",
                report.baseline,
                report.baseline,
                diff,
                ratio_to_marks(diff)
            );
        }

        // Mean of the per-case LZ4 / raw Splinter size ratios.
        let avg_ratio = reports
            .iter()
            .map(|r| r.splinter_lz4 as f64 / r.splinter.0 as f64)
            .sum::<f64>()
            / reports.len() as f64;

        println!("average compression ratio (splinter_lz4 / splinter): {avg_ratio:.2}");

        assert!(!fail_test, "compression test failed");
    }
1046}