Skip to main content

splinter_rs/
codec.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use thiserror::Error;
3use zerocopy::{ConvertError, SizeError};
4
5use crate::codec::encoder::Encoder;
6
7pub mod encoder;
8
9pub(crate) mod footer;
10pub(crate) mod partition_ref;
11pub(crate) mod runs_ref;
12pub(crate) mod tree_ref;
13
14/// Trait for types that can be encoded into a binary format.
15pub trait Encodable {
16    /// Returns the number of bytes required to encode this value.
17    ///
18    /// This should return the exact number of bytes that [`encode`](Self::encode)
19    /// will write, allowing for efficient buffer pre-allocation.
20    ///
21    /// Note: This function traverses the entire datastructure which scales with cardinality.
22    fn encoded_size(&self) -> usize;
23
24    /// Encodes this value into the provided encoder.
25    fn encode<B: BufMut>(&self, encoder: &mut Encoder<B>);
26
27    /// Convenience method that encodes this value to a [`Bytes`] buffer.
28    ///
29    /// This is the easiest way to serialize splinter data. It allocates
30    /// a buffer of the exact required size and encodes the value into it.
31    ///
32    /// # Examples
33    ///
34    /// ```
35    /// use splinter_rs::{Splinter, Encodable, PartitionWrite};
36    ///
37    /// let splinter = Splinter::from_iter([8, 42, 16]);
38    /// let bytes = splinter.encode_to_bytes();
39    /// assert!(!bytes.is_empty());
40    /// assert_eq!(bytes.len(), splinter.encoded_size());
41    /// ```
42    fn encode_to_bytes(&self) -> Bytes {
43        let size = self.encoded_size();
44        let mut encoder = Encoder::new(BytesMut::with_capacity(size));
45        self.encode(&mut encoder);
46        encoder.into_inner().freeze()
47    }
48}
49
50/// Errors that can occur when deserializing splinter data from bytes.
51///
52/// These errors indicate various types of corruption or invalid data that can
53/// be encountered when attempting to decode serialized splinter data.
54#[derive(Debug, Error)]
55pub enum DecodeErr {
56    /// The buffer does not contain enough bytes to decode the expected data.
57    ///
58    /// This error occurs when the buffer is truncated or smaller than the
59    /// minimum required size for a valid splinter.
60    #[error("not enough bytes")]
61    Length,
62
63    /// The data contains invalid or corrupted encoding structures.
64    ///
65    /// This error indicates that while the buffer has sufficient length and
66    /// correct magic bytes, the internal data structures are malformed or
67    /// contain invalid values.
68    #[error("invalid encoding")]
69    Validity,
70
71    /// The buffer does not end with the expected magic bytes.
72    ///
73    /// Splinter data ends with specific magic bytes to identify the format.
74    /// This error indicates the buffer does not contain valid splinter data
75    /// or has been corrupted at the end.
76    #[error("unknown magic value")]
77    Magic,
78
79    /// The calculated checksum does not match the stored checksum.
80    ///
81    /// This error indicates data corruption has occurred somewhere in the
82    /// buffer, as the integrity check has failed.
83    #[error("invalid checksum")]
84    Checksum,
85
86    /// The buffer contains data from the incompatible Splinter V1 format.
87    ///
88    /// This version of splinter-rs can only decode V2 format data. To decode
89    /// V1 data, use splinter-rs version 0.3.3 or earlier.
90    #[error("buffer contains serialized Splinter V1, decode using splinter-rs:v0.3.3")]
91    SplinterV1,
92}
93
94impl DecodeErr {
95    #[inline]
96    fn ensure_bytes_available(data: &[u8], len: usize) -> Result<(), DecodeErr> {
97        if data.len() < len {
98            Err(Self::Length)
99        } else {
100            Ok(())
101        }
102    }
103}
104
105impl<S, D> From<SizeError<S, D>> for DecodeErr {
106    #[track_caller]
107    fn from(_: SizeError<S, D>) -> Self {
108        DecodeErr::Length
109    }
110}
111
112impl<A, S, V> From<ConvertError<A, S, V>> for DecodeErr {
113    #[track_caller]
114    fn from(err: ConvertError<A, S, V>) -> Self {
115        match err {
116            ConvertError::Alignment(_) => panic!("All zerocopy transmutations must be unaligned"),
117            ConvertError::Size(_) => DecodeErr::Length,
118            ConvertError::Validity(_) => DecodeErr::Validity,
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use itertools::Itertools;
    use proptest::proptest;

    use crate::{
        Encodable, Splinter, SplinterRef, assert_error,
        codec::{
            DecodeErr,
            encoder::Encoder,
            footer::{Footer, SPLINTER_V2_MAGIC},
            partition_ref::PartitionRef,
        },
        level::{Block, Level, Low},
        partition_kind::PartitionKind,
        testutil::{
            LevelSetGen, mkpartition, mkpartition_buf, mksplinter_buf, mksplinter_manual,
            test_partition_read,
        },
        traits::{Optimizable, PartitionRead, TruncateFrom},
    };

    // Round-trips every partition kind over a spread of value sets: edge values,
    // random sets, run-heavy sets, and the full `Low` domain. Checks both the
    // reported encoded size and the decoded contents.
    #[test]
    fn test_encode_decode_direct() {
        let mut setgen = LevelSetGen::<Low>::new(0xDEADBEEF);
        let kinds = [
            PartitionKind::Bitmap,
            PartitionKind::Vec,
            PartitionKind::Run,
            PartitionKind::Tree,
        ];
        let sets = &[
            vec![0],
            vec![0, 1],
            vec![0, u16::MAX],
            vec![u16::MAX],
            setgen.random(8),
            setgen.random(4096),
            setgen.runs(4096, 0.01),
            setgen.runs(4096, 0.2),
            setgen.runs(4096, 0.5),
            setgen.runs(4096, 0.9),
            (0..Low::MAX_LEN)
                .map(|v| <Low as Level>::Value::truncate_from(v))
                .collect_vec(),
        ];

        for kind in kinds {
            for (i, set) in sets.iter().enumerate() {
                println!("Testing partition kind: {kind:?} with set {i}");

                let partition = mkpartition::<Low>(kind, &set);
                let buf = partition.encode_to_bytes();
                assert_eq!(
                    partition.encoded_size(),
                    buf.len(),
                    "encoded_size doesn't match actual size"
                );

                let partition_ref = PartitionRef::<'_, Low>::from_suffix(&buf).unwrap();

                assert_eq!(partition_ref.kind(), kind);
                test_partition_read(&partition_ref, &set);
            }
        }
    }

    proptest! {
        #[test]
        fn test_encode_decode_proptest(
            values in proptest::collection::vec(0u32..16384, 0..1024),
        ) {
            let expected = values.iter().copied().sorted().dedup().collect_vec();
            let mut splinter = Splinter::from_iter(values);
            splinter.optimize();
            let buf = splinter.encode_to_bytes();
            assert_eq!(
                buf.len(),
                splinter.encoded_size(),
                "encoded_size doesn't match actual size"
            );
            let splinter_ref = SplinterRef::from_bytes(buf).unwrap();

            test_partition_read(&splinter_ref, &expected);
        }
    }

    #[test]
    fn test_dense_splinter_roundtrip_7936_boundary() {
        let encoded = (1u32..=7936).collect::<Splinter>().encode_to_bytes();
        let decoded = SplinterRef::from_bytes(encoded).expect("decode");

        assert_eq!(decoded.cardinality(), 7936);
        assert_eq!(decoded.select(0), Some(1));
        assert_eq!(decoded.last(), Some(7936));
        assert!(!decoded.contains(0));
    }

    // Every buffer shorter than a footer must be rejected as truncated.
    #[test]
    fn test_length_corruption() {
        for i in 0..Footer::SIZE {
            let truncated = [0].repeat(i);
            assert_error!(
                SplinterRef::from_bytes(truncated),
                DecodeErr::Length,
                "Failed for truncated buffer of size {}",
                i
            );
        }
    }

    #[test]
    fn test_corrupted_root_partition_kind() {
        let mut buf = mksplinter_buf(&[1, 2, 3]);

        // Overwrite the root partition's kind byte (the last byte before the
        // footer) with 10, which is not a valid partition kind, and then re-wrap
        // the partition bytes. The expected error is `Validity` rather than
        // `Checksum`, so `mksplinter_manual` must write a footer that matches the
        // corrupted bytes. Only the kind byte is therefore invalid.
        let footer_offset = buf.len() - Footer::SIZE;
        let partitions = &mut buf[0..footer_offset];
        partitions[partitions.len() - 1] = 10;
        let corrupted = mksplinter_manual(partitions);

        assert_error!(SplinterRef::from_bytes(corrupted), DecodeErr::Validity);
    }

    #[test]
    fn test_corrupted_magic() {
        let mut buf = mksplinter_buf(&[1, 2, 3]);

        // Zero the whole magic trailer. The fill length follows SPLINTER_V2_MAGIC
        // rather than a hard-coded size.
        let magic_offset = buf.len() - SPLINTER_V2_MAGIC.len();
        buf[magic_offset..].fill(0);

        assert_error!(SplinterRef::from_bytes(buf), DecodeErr::Magic);
    }

    #[test]
    fn test_corrupted_data() {
        let mut buf = mksplinter_buf(&[1, 2, 3]);
        // Change a payload byte without touching the footer; the checksum must catch it.
        buf[0] = 123;
        assert_error!(SplinterRef::from_bytes(buf), DecodeErr::Checksum);
    }

    #[test]
    fn test_corrupted_checksum() {
        let mut buf = mksplinter_buf(&[1, 2, 3]);
        // The footer begins with the stored checksum; damage it directly.
        let checksum_offset = buf.len() - Footer::SIZE;
        buf[checksum_offset] = 123;
        assert_error!(SplinterRef::from_bytes(buf), DecodeErr::Checksum);
    }

    #[test]
    fn test_corrupted_vec_partition() {
        let mut buf = mkpartition_buf::<Block>(PartitionKind::Vec, &[1, 2, 3]);

        //                            1     2     3  len-1  kind
        assert_eq!(buf.as_ref(), &[0x01, 0x02, 0x03, 0x02, 0x03]);

        // Claim 6 elements (5 + 1) while only 3 value bytes are present.
        buf[3] = 5;

        assert_error!(PartitionRef::<Block>::from_suffix(&buf), DecodeErr::Length);
    }

    #[test]
    fn test_corrupted_run_partition() {
        let mut buf = mkpartition_buf::<Block>(PartitionKind::Run, &[1, 2, 3]);

        //                         start   end  len-1  kind
        assert_eq!(buf.as_ref(), &[0x01, 0x03, 0x00, 0x04]);

        // Claim 6 runs (5 + 1) while only one run's bytes are present.
        buf[2] = 5;

        assert_error!(PartitionRef::<Block>::from_suffix(&buf), DecodeErr::Length);
    }

    #[test]
    fn test_corrupted_tree_partition() {
        let mut buf = mkpartition_buf::<Low>(PartitionKind::Tree, &[1, 2]);

        assert_eq!(
            buf.as_ref(),
            &[
                // Vec partition (child)
                // 1     2  len-1  kind
                0x01, 0x02, 0x01, 0x03,
                // Tree partition
                // offsets (u16), cumulative_cardinalities-1 (u16), segments, len-1, kind
                0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x05
            ]
        );

        // Claim 6 segments (5 + 1), far more than the buffer can hold.
        buf[9] = 5;

        // The buffer was encoded at the `Low` level, so decode it at that level too.
        assert_error!(PartitionRef::<Low>::from_suffix(&buf), DecodeErr::Length);
    }

    #[test]
    fn test_vec_byteorder() {
        let buf = mkpartition_buf::<Low>(PartitionKind::Vec, &[0x01_00, 0x02_00]);
        assert_eq!(
            buf.as_ref(),
            &[
                0x01, 0x00, // first value (big-endian)
                0x02, 0x00, // second value (big-endian)
                0x00, 0x01, // length minus one (big-endian)
                0x03, // kind
            ]
        );
    }

    #[test]
    fn test_run_byteorder() {
        let buf = mkpartition_buf::<Low>(PartitionKind::Run, &[0x01_00, 0x02_00]);
        assert_eq!(
            buf.as_ref(),
            &[
                0x01, 0x00, 0x01, 0x00, // first run (start, end)
                0x02, 0x00, 0x02, 0x00, // second run (start, end)
                0x00, 0x01, // run count minus one (big-endian)
                0x04, // kind
            ]
        );
    }

    #[test]
    fn test_detect_splinter_v1() {
        let empty_splinter_v1 = b"\xda\xae\x12\xdf\0\0\0\0";
        assert_error!(
            SplinterRef::from_bytes(empty_splinter_v1.as_slice()),
            DecodeErr::SplinterV1
        );
    }

    #[test]
    #[should_panic(expected = "footer already present")]
    fn test_encoder_panics_when_footer_is_written_after_splinter_blob() {
        let mut buf = BytesMut::new();
        let mut encoder = Encoder::new(&mut buf);
        encoder.write_splinter(&[1, 2, 3]);
        encoder.write_footer();
    }

    #[test]
    #[should_panic(expected = "footer already present")]
    fn test_encoder_panics_when_footer_is_written_twice() {
        let mut buf = BytesMut::new();
        let mut encoder = Encoder::new(&mut buf);
        encoder.write_footer();
        encoder.write_footer();
    }
}