// splinter_rs/splinter_ref.rs

1use std::{
2    fmt::Debug,
3    ops::{Deref, RangeBounds},
4};
5
6use bytes::Bytes;
7use zerocopy::FromBytes;
8
9use crate::{
10    Splinter,
11    codec::{
12        DecodeErr, Encodable,
13        encoder::Encoder,
14        footer::{Footer, SPLINTER_V2_MAGIC},
15        partition_ref::PartitionRef,
16    },
17    level::High,
18    traits::PartitionRead,
19};
20
/// A zero-copy reference to serialized splinter data.
///
/// `SplinterRef` allows efficient querying of compressed bitmap data without
/// deserializing the underlying structure. It wraps any type that can be
/// dereferenced to `[u8]` and provides all the same read operations as
/// [`Splinter`], but with minimal memory overhead and no allocation during
/// queries.
///
/// This is the preferred type for read-only operations on serialized splinter
/// data, especially when the data comes from files, network, or other external
/// sources.
///
/// # Type Parameter
///
/// - `B`: Any type that implements `Deref<Target = [u8]>` such as `&[u8]`,
///   `Vec<u8>`, `Bytes`, `Arc<[u8]>`, etc.
///
/// # Examples
///
/// Creating from serialized bytes:
///
/// ```
/// use splinter_rs::{Splinter, SplinterRef, PartitionWrite, PartitionRead, Encodable};
///
/// // Create and populate a splinter
/// let mut splinter = Splinter::EMPTY;
/// splinter.insert(100);
/// splinter.insert(200);
///
/// // Serialize it to bytes
/// let bytes = splinter.encode_to_bytes();
///
/// // Create a zero-copy reference
/// let splinter_ref = SplinterRef::from_bytes(bytes).unwrap();
/// assert_eq!(splinter_ref.cardinality(), 2);
/// assert!(splinter_ref.contains(100));
/// ```
///
/// Working with different buffer types:
///
/// ```
/// use splinter_rs::{Splinter, SplinterRef, PartitionWrite, PartitionRead, Encodable};
/// use std::sync::Arc;
///
/// let mut splinter = Splinter::EMPTY;
/// splinter.insert(42);
///
/// let bytes = splinter.encode_to_bytes();
/// let shared_bytes: Arc<[u8]> = bytes.to_vec().into();
///
/// let splinter_ref = SplinterRef::from_bytes(shared_bytes).unwrap();
/// assert!(splinter_ref.contains(42));
/// ```
#[derive(Clone)]
pub struct SplinterRef<B> {
    // Serialized splinter bytes: partition data followed by a trailing
    // `Footer` (see `from_bytes` / `load_unchecked` for how it is sliced).
    pub(crate) data: B,
}
78
79impl<B: Deref<Target = [u8]>> Debug for SplinterRef<B> {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_tuple("SplinterRef")
82            .field(&self.load_unchecked())
83            .finish()
84    }
85}
86
87impl<B> SplinterRef<B> {
88    /// Returns a reference to the underlying data buffer.
89    ///
90    /// This provides access to the raw bytes that store the serialized splinter
91    /// data.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use splinter_rs::{Splinter, SplinterRef, PartitionWrite, Encodable};
97    ///
98    /// let mut splinter = Splinter::EMPTY;
99    /// splinter.insert(42);
100    /// let bytes = splinter.encode_to_bytes();
101    /// let splinter_ref = SplinterRef::from_bytes(bytes).unwrap();
102    ///
103    /// let inner_bytes = splinter_ref.inner();
104    /// assert!(!inner_bytes.is_empty());
105    /// ```
106    #[inline]
107    pub fn inner(&self) -> &B {
108        &self.data
109    }
110
111    /// Consumes the `SplinterRef` and returns the underlying data buffer.
112    ///
113    /// This is useful when you need to take ownership of the underlying data
114    /// after you're done querying the splinter.
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use splinter_rs::{Splinter, SplinterRef, PartitionWrite, Encodable};
120    ///
121    /// let mut splinter = Splinter::EMPTY;
122    /// splinter.insert(42);
123    /// let bytes = splinter.encode_to_bytes();
124    /// let splinter_ref = SplinterRef::from_bytes(bytes.clone()).unwrap();
125    ///
126    /// let recovered_bytes = splinter_ref.into_inner();
127    /// assert_eq!(recovered_bytes, bytes);
128    /// ```
129    #[inline]
130    pub fn into_inner(self) -> B {
131        self.data
132    }
133}
134
135impl SplinterRef<Bytes> {
136    /// Returns a clone of the underlying bytes.
137    ///
138    /// This is efficient for `Bytes` since it uses reference counting
139    /// internally and doesn't actually copy the data.
140    ///
141    /// # Examples
142    ///
143    /// ```
144    /// use splinter_rs::{Splinter, PartitionWrite};
145    ///
146    /// let mut splinter = Splinter::EMPTY;
147    /// splinter.insert(42);
148    /// let splinter_ref = splinter.encode_to_splinter_ref();
149    ///
150    /// let bytes_copy = splinter_ref.encode_to_bytes();
151    /// assert!(!bytes_copy.is_empty());
152    /// ```
153    #[inline]
154    pub fn encode_to_bytes(&self) -> Bytes {
155        self.data.clone()
156    }
157}
158
159impl<B: Deref<Target = [u8]>> Encodable for SplinterRef<B> {
160    #[inline]
161    fn encoded_size(&self) -> usize {
162        self.data.len()
163    }
164
165    #[inline]
166    fn encode<T: bytes::BufMut>(&self, encoder: &mut Encoder<T>) {
167        encoder.write_splinter(&self.data);
168    }
169}
170
171impl<B: Deref<Target = [u8]>> SplinterRef<B> {
172    /// Converts this reference back to an owned [`Splinter`].
173    ///
174    /// This method deserializes the underlying data and creates a new owned
175    /// `Splinter` that supports mutation. This involves iterating through all
176    /// values and rebuilding the data structure, so it has a cost proportional
177    /// to the number of elements.
178    ///
179    /// # Examples
180    ///
181    /// ```
182    /// use splinter_rs::{Splinter, SplinterRef, PartitionWrite, PartitionRead};
183    ///
184    /// let mut original = Splinter::EMPTY;
185    /// original.insert(100);
186    /// original.insert(200);
187    ///
188    /// let splinter_ref = original.encode_to_splinter_ref();
189    /// let decoded = splinter_ref.decode_to_splinter();
190    ///
191    /// assert_eq!(decoded.cardinality(), 2);
192    /// assert!(decoded.contains(100));
193    /// assert!(decoded.contains(200));
194    /// ```
195    pub fn decode_to_splinter(&self) -> Splinter {
196        Splinter::new((&self.load_unchecked()).into())
197    }
198
199    /// Creates a `SplinterRef` from raw bytes, validating the format.
200    ///
201    /// This method parses and validates the serialized splinter format, checking:
202    /// - Sufficient data length
203    /// - Valid magic bytes
204    /// - Correct checksum
205    ///
206    /// IMPORTANT: This method *does not* recursively verify the entire
207    /// splinter, opting instead to rely on the checksum to detect any
208    /// corruption. Do not use Splinter with untrusted data as it's trivial to
209    /// construct a Splinter which will cause your program to panic at runtime.
210    ///
211    /// Returns an error if the data is corrupted or in an invalid format.
212    ///
213    /// # Errors
214    ///
215    /// - [`DecodeErr::Length`]: Not enough bytes in the buffer
216    /// - [`DecodeErr::Magic`]: Invalid magic bytes
217    /// - [`DecodeErr::Checksum`]: Data corruption detected
218    /// - [`DecodeErr::Validity`]: Invalid internal structure
219    /// - [`DecodeErr::SplinterV1`]: Data is from incompatible v1 format
220    ///
221    /// # Examples
222    ///
223    /// ```
224    /// use splinter_rs::{Splinter, SplinterRef, PartitionWrite, PartitionRead, Encodable};
225    ///
226    /// let mut splinter = Splinter::EMPTY;
227    /// splinter.insert(42);
228    /// let bytes = splinter.encode_to_bytes();
229    ///
230    /// let splinter_ref = SplinterRef::from_bytes(bytes).unwrap();
231    /// assert!(splinter_ref.contains(42));
232    /// ```
233    ///
234    /// Error handling:
235    ///
236    /// ```
237    /// use splinter_rs::{SplinterRef, codec::DecodeErr};
238    ///
239    /// let invalid_bytes = vec![0u8; 5]; // Too short
240    /// let result = SplinterRef::from_bytes(invalid_bytes);
241    /// assert!(matches!(result.unwrap_err(), DecodeErr::Length));
242    /// ```
243    pub fn from_bytes(data: B) -> Result<Self, DecodeErr> {
244        pub(crate) const SPLINTER_V1_MAGIC: [u8; 4] = [0xDA, 0xAE, 0x12, 0xDF];
245        if data.len() >= 4
246            && data.starts_with(&SPLINTER_V1_MAGIC)
247            && !data.ends_with(&SPLINTER_V2_MAGIC)
248        {
249            return Err(DecodeErr::SplinterV1);
250        }
251
252        if data.len() < Footer::SIZE {
253            return Err(DecodeErr::Length);
254        }
255        let (partitions, footer) = data.split_at(data.len() - Footer::SIZE);
256        Footer::ref_from_bytes(footer)?.validate(partitions)?;
257        PartitionRef::<High>::from_suffix(partitions)?;
258        Ok(Self { data })
259    }
260
261    pub(crate) fn load_unchecked(&self) -> PartitionRef<'_, High> {
262        let without_footer = &self.data[..(self.data.len() - Footer::SIZE)];
263        PartitionRef::from_suffix(without_footer).unwrap()
264    }
265}
266
// Read operations for a serialized splinter.
//
// Every method re-parses the top-level partition from the raw bytes via
// `load_unchecked` before delegating, so each call pays a small parsing cost
// up front in exchange for keeping `SplinterRef` stateless and zero-copy.
impl<B: Deref<Target = [u8]>> PartitionRead<High> for SplinterRef<B> {
    fn cardinality(&self) -> usize {
        self.load_unchecked().cardinality()
    }

    fn is_empty(&self) -> bool {
        self.load_unchecked().is_empty()
    }

    fn contains(&self, value: u32) -> bool {
        self.load_unchecked().contains(value)
    }

    fn position(&self, value: u32) -> Option<usize> {
        self.load_unchecked().position(value)
    }

    fn rank(&self, value: u32) -> usize {
        self.load_unchecked().rank(value)
    }

    fn select(&self, idx: usize) -> Option<u32> {
        self.load_unchecked().select(idx)
    }

    fn last(&self) -> Option<u32> {
        self.load_unchecked().last()
    }

    // `into_iter` consumes the freshly parsed `PartitionRef`, not `self`.
    fn iter(&self) -> impl Iterator<Item = u32> {
        self.load_unchecked().into_iter()
    }

    fn contains_all<R: RangeBounds<u32>>(&self, values: R) -> bool {
        self.load_unchecked().contains_all(values)
    }

    fn contains_any<R: RangeBounds<u32>>(&self, values: R) -> bool {
        self.load_unchecked().contains_any(values)
    }
}
308
#[cfg(test)]
mod test {
    use proptest::{collection::vec, prop_assume, proptest};

    use crate::{
        Optimizable, PartitionRead, Splinter,
        testutil::{SetGen, mksplinter},
    };

    #[test]
    fn test_empty() {
        let splinter = mksplinter(&[]).encode_to_splinter_ref();

        assert_eq!(splinter.decode_to_splinter(), Splinter::EMPTY);
        assert!(!splinter.contains(0));
        assert_eq!(splinter.cardinality(), 0);
        assert_eq!(splinter.last(), None);
    }

    /// This is a regression test for a bug in the SplinterRef encoding. The bug
    /// was that we used LittleEndian encoded values to store unaligned values,
    /// which sort in reverse order from what we expect.
    #[test]
    fn test_contains_bug() {
        let mut set_gen = SetGen::new(0xDEAD_BEEF);
        let set = set_gen.random(1024);
        // `set.len() / 3` is already `usize`; the previous `as usize` cast was
        // redundant.
        let lookup = set[set.len() / 3];
        let splinter = mksplinter(&set).encode_to_splinter_ref();
        assert!(splinter.contains(lookup))
    }

    proptest! {
        #[test]
        fn test_splinter_ref_proptest(set in vec(0u32..16384, 0..1024)) {
            let splinter = mksplinter(&set).encode_to_splinter_ref();
            if set.is_empty() {
                assert!(!splinter.contains(123))
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup))
            }
        }

        #[test]
        fn test_splinter_opt_ref_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut splinter = mksplinter(&set);
            splinter.optimize();
            let splinter = splinter.encode_to_splinter_ref();
            if set.is_empty() {
                assert!(!splinter.contains(123))
            } else {
                let lookup = set[set.len() / 3];
                assert!(splinter.contains(lookup))
            }
        }

        #[test]
        fn test_splinter_ref_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let ref1 = mksplinter(&set).encode_to_splinter_ref();
            let ref2 = mksplinter(&set).encode_to_splinter_ref();
            assert_eq!(ref1, ref2)
        }

        #[test]
        fn test_splinter_opt_ref_eq_proptest(set in vec(0u32..16384, 0..1024)) {
            let mut ref1 = mksplinter(&set);
            ref1.optimize();
            let ref1 = ref1.encode_to_splinter_ref();
            let ref2 = mksplinter(&set).encode_to_splinter_ref();
            assert_eq!(ref1, ref2)
        }

        #[test]
        fn test_splinter_ref_ne_proptest(
            set1 in vec(0u32..16384, 0..1024),
            set2 in vec(0u32..16384, 0..1024),
        ) {
            prop_assume!(set1 != set2);

            let ref1 = mksplinter(&set1).encode_to_splinter_ref();
            let ref2 = mksplinter(&set2).encode_to_splinter_ref();
            assert_ne!(ref1, ref2)
        }

        #[test]
        fn test_splinter_opt_ref_ne_proptest(
            set1 in vec(0u32..16384, 0..1024),
            set2 in vec(0u32..16384, 0..1024),
        ) {
            prop_assume!(set1 != set2);

            let mut ref1 = mksplinter(&set1);
            ref1.optimize();
            let ref1 = ref1.encode_to_splinter_ref();
            let ref2 = mksplinter(&set2).encode_to_splinter_ref();
            assert_ne!(ref1, ref2)
        }
    }

    #[test]
    fn test_ref_wat() {
        #[rustfmt::skip]
        let set = [ 6400, 11776, 768, 15872, 6912, 0, 11008, 769, 770, 11009, 4608, 771, 0, 768, 6401, 0, 8192, 8192, 4609, 772, 4610, 0, 0, 0, 0, 0, 768, 773, 774, 14336, 0, 0, 0, 15872, 11010, 775, 0, 768, 11777, 776, 0, 0, 0, 6400, 14337, 8193, 0, 0, 0, 0, 0, 0, 0, ];
        let mut ref1 = mksplinter(&set);
        ref1.optimize();
        let ref1 = ref1.encode_to_splinter_ref();
        let ref2 = mksplinter(&set).encode_to_splinter_ref();
        assert_eq!(ref1, ref2)
    }
}