use std::{
fmt::Debug,
ops::{Deref, RangeBounds},
};
use bytes::Bytes;
use zerocopy::FromBytes;
use crate::{
Splinter,
codec::{
DecodeErr, Encodable,
encoder::Encoder,
footer::{Footer, SPLINTER_V2_MAGIC},
partition_ref::PartitionRef,
},
level::High,
traits::PartitionRead,
};
/// A splinter kept in its encoded byte form.
///
/// `B` is the container holding the encoded bytes. The decoding and query
/// operations defined in this file are available when `B` dereferences to
/// `[u8]`.
#[derive(Clone, Copy)]
pub struct SplinterRef<B> {
    /// The encoded bytes; the last `Footer::SIZE` bytes are treated as the footer.
    pub(crate) data: B,
}
impl<B> Debug for SplinterRef<B>
where
    B: Deref<Target = [u8]>,
{
    /// Writes `SplinterRef(<partition>)`, where `<partition>` is the `Debug`
    /// rendering of the partition view returned by `load_unchecked`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let partition = self.load_unchecked();
        let mut tuple = f.debug_tuple("SplinterRef");
        tuple.field(&partition);
        tuple.finish()
    }
}
impl<B> SplinterRef<B> {
    /// Borrows the container that holds the encoded bytes.
    #[inline]
    pub fn inner(&self) -> &B {
        let Self { data } = self;
        data
    }

    /// Consumes this `SplinterRef` and hands back the container it wrapped.
    #[inline]
    pub fn into_inner(self) -> B {
        let Self { data } = self;
        data
    }
}
impl SplinterRef<Bytes> {
#[inline]
pub fn encode_to_bytes(&self) -> Bytes {
self.data.clone()
}
}
impl<B> Encodable for SplinterRef<B>
where
    B: Deref<Target = [u8]>,
{
    /// The encoded size is the length of the stored byte slice.
    #[inline]
    fn encoded_size(&self) -> usize {
        let bytes: &[u8] = &self.data;
        bytes.len()
    }

    /// Passes the stored bytes to `encoder` via `write_splinter`.
    #[inline]
    fn encode<T: bytes::BufMut>(&self, encoder: &mut Encoder<T>) {
        encoder.write_splinter(&self.data);
    }
}
impl<B> SplinterRef<B>
where
    B: Deref<Target = [u8]>,
{
    /// Builds an owned [`Splinter`] from the encoded data.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `load_unchecked`.
    pub fn decode_to_splinter(&self) -> Splinter {
        let partition = self.load_unchecked();
        Splinter::new((&partition).into())
    }

    /// Validates `data` as an encoded splinter and wraps it on success.
    ///
    /// # Errors
    ///
    /// - [`DecodeErr::SplinterV1`] if `data` begins with the V1 magic bytes
    ///   and does not end with the V2 magic bytes.
    /// - [`DecodeErr::Length`] if `data` is shorter than `Footer::SIZE`.
    /// - Any error produced while parsing or validating the footer, or while
    ///   parsing the partition bytes that precede it.
    pub fn from_bytes(data: B) -> Result<Self, DecodeErr> {
        const SPLINTER_V1_MAGIC: [u8; 4] = [0xDA, 0xAE, 0x12, 0xDF];

        // `starts_with` is already false when `data` is shorter than the
        // magic, so no separate length guard is needed here.
        let looks_like_v1 =
            data.starts_with(&SPLINTER_V1_MAGIC) && !data.ends_with(&SPLINTER_V2_MAGIC);
        if looks_like_v1 {
            return Err(DecodeErr::SplinterV1);
        }

        // `checked_sub` doubles as the "is there room for a footer" test.
        let Some(body_len) = data.len().checked_sub(Footer::SIZE) else {
            return Err(DecodeErr::Length);
        };

        let (partitions, footer) = data.split_at(body_len);
        Footer::ref_from_bytes(footer)?.validate(partitions)?;
        // Parse once purely for validation; the view is rebuilt on demand by
        // `load_unchecked`.
        PartitionRef::<High>::from_suffix(partitions)?;
        Ok(Self { data })
    }

    /// Returns a partition view over the bytes that precede the footer,
    /// without re-running the footer validation done by `from_bytes`.
    ///
    /// # Panics
    ///
    /// Panics if the stored data is shorter than `Footer::SIZE` or if the
    /// bytes before the footer cannot be parsed as a partition.
    pub(crate) fn load_unchecked(&self) -> PartitionRef<'_, High> {
        let body_len = self
            .data
            .len()
            .checked_sub(Footer::SIZE)
            .expect("SplinterRef data must be at least as long as the footer");
        PartitionRef::from_suffix(&self.data[..body_len])
            .expect("SplinterRef must wrap a well-formed splinter encoding")
    }
}
// Every method below loads the partition view via `load_unchecked` and
// forwards the call to it unchanged.
impl<B> PartitionRead<High> for SplinterRef<B>
where
    B: Deref<Target = [u8]>,
{
    fn cardinality(&self) -> usize {
        let partition = self.load_unchecked();
        partition.cardinality()
    }

    fn is_empty(&self) -> bool {
        let partition = self.load_unchecked();
        partition.is_empty()
    }

    fn contains(&self, value: u32) -> bool {
        let partition = self.load_unchecked();
        partition.contains(value)
    }

    fn position(&self, value: u32) -> Option<usize> {
        let partition = self.load_unchecked();
        partition.position(value)
    }

    fn rank(&self, value: u32) -> usize {
        let partition = self.load_unchecked();
        partition.rank(value)
    }

    fn select(&self, idx: usize) -> Option<u32> {
        let partition = self.load_unchecked();
        partition.select(idx)
    }

    fn last(&self) -> Option<u32> {
        let partition = self.load_unchecked();
        partition.last()
    }

    fn iter(&self) -> impl Iterator<Item = u32> {
        let partition = self.load_unchecked();
        partition.into_iter()
    }

    fn contains_all<R: RangeBounds<u32>>(&self, values: R) -> bool {
        let partition = self.load_unchecked();
        partition.contains_all(values)
    }

    fn contains_any<R: RangeBounds<u32>>(&self, values: R) -> bool {
        let partition = self.load_unchecked();
        partition.contains_any(values)
    }
}
#[cfg(test)]
mod test {
    use hegel::generators;
    use proptest::{collection::vec, prop_assume, proptest};

    use crate::{
        Optimizable, PartitionRead, Splinter,
        testutil::{SetGen, mksplinter},
    };

    // An empty input decodes to `Splinter::EMPTY` and reports no contents.
    #[test]
    fn test_empty() {
        let splinter = mksplinter(&[]).encode_to_splinter_ref();

        assert_eq!(splinter.decode_to_splinter(), Splinter::EMPTY);

        assert!(!splinter.contains(0));
        assert_eq!(splinter.cardinality(), 0);
        assert_eq!(splinter.last(), None);
    }

    // A value picked from a generated set must be found through the ref.
    // NOTE(review): presumably `random(1024)` yields 1024 values — confirm
    // against `SetGen`.
    #[test]
    fn test_contains_bug() {
        let mut set_gen = SetGen::new(0xDEAD_BEEF);
        let set = set_gen.random(1024);
        let lookup = set[set.len() / 3];
        let splinter = mksplinter(&set).encode_to_splinter_ref();
        assert!(splinter.contains(lookup))
    }

    proptest! {
        // A value taken from the input set is contained; an empty set
        // contains nothing.
        #[test]
        fn test_splinter_ref_proptest(set in vec(0u32..16384, 0..1024)) {
            let splinter = mksplinter(&set).encode_to_splinter_ref();
            if set.is_empty() {
                assert!(!splinter.contains(123))
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup))
            }
        }

        // Same membership check, but the splinter is optimized before encoding.
        #[test]
        fn test_splinter_opt_ref_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut splinter = mksplinter(&set);
            splinter.optimize();
            let splinter = splinter.encode_to_splinter_ref();
            if set.is_empty() {
                assert!(!splinter.contains(123))
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup))
            }
        }

        // Two refs built from the same input compare equal.
        #[test]
        fn test_splinter_ref_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let ref1 = mksplinter(&set).encode_to_splinter_ref();
            let ref2 = mksplinter(&set).encode_to_splinter_ref();
            assert_eq!(ref1, ref2)
        }

        // An optimized and an unoptimized ref of the same input compare equal.
        #[test]
        fn test_splinter_opt_ref_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut ref1 = mksplinter(&set);
            ref1.optimize();
            let ref1 = ref1.encode_to_splinter_ref();
            let ref2 = mksplinter(&set).encode_to_splinter_ref();
            assert_eq!(ref1, ref2)
        }

        // Refs built from differing inputs compare unequal.
        #[test]
        fn test_splinter_ref_ne_proptest(
            set1 in vec(0u32..16384, 0..1024),
            set2 in vec(0u32..16384, 0..1024),
        ) {
            prop_assume!(set1 != set2);

            let ref1 = mksplinter(&set1).encode_to_splinter_ref();
            let ref2 = mksplinter(&set2).encode_to_splinter_ref();
            assert_ne!(ref1, ref2)
        }

        // Same inequality check with the first splinter optimized.
        #[test]
        fn test_splinter_opt_ref_ne_proptest(
            set1 in vec(0u32..16384, 0..1024),
            set2 in vec(0u32..16384, 0..1024),
        ) {
            prop_assume!(set1 != set2);

            let mut ref1 = mksplinter(&set1);
            ref1.optimize();
            let ref1 = ref1.encode_to_splinter_ref();
            let ref2 = mksplinter(&set2).encode_to_splinter_ref();
            assert_ne!(ref1, ref2)
        }
    }

    // Fixed input for which the optimized and unoptimized refs must compare
    // equal.
    #[test]
    fn test_ref_wat() {
        #[rustfmt::skip]
        let set = [ 6400, 11776, 768, 15872, 6912, 0, 11008, 769, 770, 11009, 4608, 771, 0, 768, 6401, 0, 8192, 8192, 4609, 772, 4610, 0, 0, 0, 0, 0, 768, 773, 774, 14336, 0, 0, 0, 15872, 11010, 775, 0, 768, 11777, 776, 0, 0, 0, 6400, 14337, 8193, 0, 0, 0, 0, 0, 0, 0, ];

        let mut ref1 = mksplinter(&set);
        ref1.optimize();
        let ref1 = ref1.encode_to_splinter_ref();
        let ref2 = mksplinter(&set).encode_to_splinter_ref();
        assert_eq!(ref1, ref2)
    }

    // The bytes returned by `encode_to_bytes` match the stored data, pass
    // `from_bytes` validation again, and decode to the same set.
    #[test]
    fn test_splinter_ref_reencode_is_valid() {
        let splinter_ref = mksplinter(&[1, 2, 3]).encode_to_splinter_ref();
        let bytes = splinter_ref.encode_to_bytes();
        assert_eq!(splinter_ref.data, bytes);

        // Comparing against `data` alone only restates how `encode_to_bytes`
        // is implemented; re-parse the bytes to check they are actually valid.
        let Ok(reparsed) = super::SplinterRef::from_bytes(bytes) else {
            panic!("re-encoded bytes were rejected by SplinterRef::from_bytes");
        };
        assert_eq!(
            reparsed.decode_to_splinter(),
            splinter_ref.decode_to_splinter()
        );
    }

    // The ref agrees with the owned splinter on cardinality, emptiness, last
    // element, and membership of every owned value.
    #[hegel::test]
    fn test_splinter_ref_contains_same_values(tc: hegel::TestCase) {
        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
        let splinter = mksplinter(&values);
        let splinter_ref = splinter.encode_to_splinter_ref();

        assert_eq!(splinter.cardinality(), splinter_ref.cardinality());
        assert_eq!(splinter.is_empty(), splinter_ref.is_empty());
        assert_eq!(splinter.last(), splinter_ref.last());

        for v in splinter.iter() {
            assert!(splinter_ref.contains(v));
        }
    }

    // Iterating the ref yields the same sequence as iterating the owned
    // splinter.
    #[hegel::test]
    fn test_splinter_ref_iter_matches(tc: hegel::TestCase) {
        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
        let splinter = mksplinter(&values);
        let splinter_ref = splinter.encode_to_splinter_ref();

        let owned_items: Vec<u32> = splinter.iter().collect();
        let ref_items: Vec<u32> = splinter_ref.iter().collect();
        assert_eq!(owned_items, ref_items);
    }

    // Optimizing before encoding does not change equality of the resulting ref.
    #[hegel::test]
    fn test_optimized_splinter_ref_equivalence(tc: hegel::TestCase) {
        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
        let mut optimized = mksplinter(&values);
        optimized.optimize();
        let unoptimized = mksplinter(&values);

        let opt_ref = optimized.encode_to_splinter_ref();
        let unopt_ref = unoptimized.encode_to_splinter_ref();
        assert_eq!(opt_ref, unopt_ref);
    }

    // Encoding and decoding twice yields the same splinter as doing it once.
    #[hegel::test]
    fn test_double_roundtrip(tc: hegel::TestCase) {
        let values: Vec<u32> = tc.draw(generators::vecs(generators::integers::<u32>()));
        let mut splinter = mksplinter(&values);
        splinter.optimize();

        let ref1 = splinter.encode_to_splinter_ref();
        let decoded1 = ref1.decode_to_splinter();
        let ref2 = decoded1.encode_to_splinter_ref();
        let decoded2 = ref2.decode_to_splinter();
        assert_eq!(decoded1, decoded2);
    }

    // `select`, `rank`, and `position` agree between the owned splinter and
    // the ref for an index drawn below the cardinality.
    #[hegel::test]
    fn test_splinter_ref_select_rank(tc: hegel::TestCase) {
        // `min_size(1)` keeps `cardinality - 1` below from underflowing.
        let values: Vec<u32> =
            tc.draw(generators::vecs(generators::integers::<u32>()).min_size(1));
        let splinter = mksplinter(&values);
        let splinter_ref = splinter.encode_to_splinter_ref();

        let cardinality = splinter.cardinality();
        let idx = tc.draw(generators::integers::<usize>().max_value(cardinality - 1));

        assert_eq!(splinter.select(idx), splinter_ref.select(idx));

        let val = splinter.select(idx).unwrap();
        assert_eq!(splinter.rank(val), splinter_ref.rank(val));
        assert_eq!(splinter.position(val), splinter_ref.position(val));
    }
}