snapper_box/file/
segment.rs

1//! Abstraction for breaking a file up into a list of segments
2
3use std::{
4    borrow::Cow,
5    io::{Read, Write},
6};
7
8use snafu::{ensure, OptionExt, ResultExt};
9
10use crate::{
11    crypto::{CipherText, Nonce},
12    error::{BackendError, NoData, NoDataIO, SegmentIO, SegmentLength},
13};
14
/// A data segment within a file, encoded as a vector of bytes and its length.
///
/// This type can represent either owned data, for normal copying reads, or borrowed data, for
/// zero-copy reads.
///
/// A valid serialized segment consists of a length, encoded as an 8-byte little-endian integer,
/// followed by a bytewise array of data of the specified length.
#[derive(Debug, Hash, Clone, PartialEq, Eq)]
pub struct Segment<'a> {
    /// The length of the segment's data, as a number of bytes (not counting the 8-byte length tag)
    length: u64,
    /// The data contained in the segment
    data: Cow<'a, [u8]>,
}
29
30impl<'a> Segment<'a> {
31    /// Parses and borrows a `Segment` from the provided data.
32    ///
33    /// # Errors
34    ///
35    /// This will return `Err(Error::SegmentLength)` if
36    ///   * The specified length is too big to possibly fit into memory
37    ///   * There is not enough data in the slice to fill the data buffer
38    pub fn read_borrowed(source: &'a [u8]) -> Result<Self, BackendError> {
39        // First make sure that the slice is long enough to contain data
40        // We need at least 8 bytes for the length tag
41        ensure!(!source.is_empty(), NoData);
42        ensure!(source.len() >= 8, SegmentLength);
43        // Decode the length
44        let mut length_array = [0_u8; 8];
45        length_array.copy_from_slice(&source[0..8]);
46        let length = u64::from_le_bytes(length_array);
47        // Make sure the length is small enough to fit into memory
48        let length_usize: usize = length.try_into().ok().context(SegmentLength)?;
49        // Make sure the data is big enough to contain the specified number of bytes
50        let data = &source[8..];
51        ensure!(data.len() >= length_usize, SegmentLength);
52        Ok(Segment {
53            length,
54            data: Cow::Borrowed(data),
55        })
56    }
57
58    /// Provides the length, in bytes, that this value will take up if serialized. This count includes the
59    /// embedded 8-byte length tag.
60    pub fn total_length(&self) -> usize {
61        // 8 bytes for the length tag, plus the length of the byte array
62        8 + self.data.len()
63    }
64
65    /// Writes this segment to an array of bytes, returning the number of bytes written
66    ///
67    /// # Errors
68    ///
69    /// Will return `SegmentLength` if the contained data is too big to fit in the buffer.
70    pub fn write_ref(&self, dest: &mut [u8]) -> Result<usize, BackendError> {
71        // Make sure the buffer is big enough
72        let length = self.total_length();
73        ensure!(dest.len() >= length, SegmentLength);
74        // Write to the buffer
75        // First the length
76        let length_bytes = (self.data.len() as u64).to_le_bytes();
77        (&mut dest[0..8]).copy_from_slice(&length_bytes);
78        // Then the data
79        let data = &mut dest[8..];
80        data.copy_from_slice(&self.data);
81        Ok(length)
82    }
83
84    /// Writes this segment to an IO [`Write`] instance, returning the number of bytes written
85    ///
86    /// # Errors
87    ///
88    /// Will pass through any underlying IO errors
89    pub fn write(&self, dest: &mut impl Write) -> Result<usize, BackendError> {
90        // First the length
91        let length = self.total_length();
92        let length_bytes = (self.data.len() as u64).to_le_bytes();
93        dest.write_all(&length_bytes).context(SegmentIO)?;
94        // Then the data
95        dest.write_all(&self.data).context(SegmentIO)?;
96        Ok(length)
97    }
98
99    /// Constructs a new segment from some borrowed data
100    pub fn new_borrowed(data: &'a [u8]) -> Self {
101        Self {
102            length: data.len().try_into().expect("Impossibly large data"),
103            data: Cow::Borrowed(data),
104        }
105    }
106
107    /// Gets a reference to the inner data
108    pub fn data(&self) -> &[u8] {
109        self.data.as_ref()
110    }
111}
112
113impl Segment<'static> {
114    /// Copies a `Segment` from the provided IO [`Read`]
115    ///
116    /// # Errors
117    ///
118    /// This will return `Err(Error::SegmentIo)` if
119    ///   * The specified length is too big to possibly fit into memory
120    ///   * There is not enough data in the slice to fill the data buffer
121    pub fn read_owned(source: &mut impl Read) -> Result<Self, BackendError> {
122        // Decode the length
123        let mut length_array = [0_u8; 8];
124        source.read_exact(&mut length_array).context(NoDataIO)?;
125        let length = u64::from_le_bytes(length_array);
126        // Make sure the length is small enough to fit into memory
127        let length_usize: usize = length.try_into().ok().context(SegmentLength)?;
128        // Create a buffer of the correct length to write the data into
129        let mut data = vec![0_u8; length_usize];
130        // Read the data into the buffer
131        source
132            .read_exact(&mut data[0..length_usize])
133            .context(SegmentIO)?;
134        Ok(Segment {
135            length,
136            data: Cow::from(data),
137        })
138    }
139
140    /// Constructs a new segment from some data
141    pub fn new(data: impl AsRef<[u8]>) -> Self {
142        let data = data.as_ref().to_vec();
143        Self {
144            length: data.len().try_into().expect("Impossibly large data"),
145            data: Cow::from(data),
146        }
147    }
148}
149
150impl<'a> From<CipherText<'a>> for Segment<'static> {
151    /// Encode a [`CipherText`] in binary form as a segment.
152    ///
153    /// This will encode:
154    ///   * The `compressed` flag - `0_u8` being false and `1_u8` being true
155    ///   * The rest of the fields as a concatenation of their bytes
156    fn from(x: CipherText<'a>) -> Self {
157        let mut buffer = vec![];
158        // Push the compression flag
159        if x.compressed {
160            buffer.push(1_u8);
161        } else {
162            buffer.push(0_u8);
163        };
164        // Push the nonce
165        buffer.extend(&*x.nonce.0);
166        // Push the HMAC
167        buffer.extend(&*x.hmac);
168        // Push the data
169        buffer.extend(&*x.payload);
170        Segment {
171            length: buffer.len() as u64,
172            data: buffer.into(),
173        }
174    }
175}
176
177impl<'a> TryFrom<Segment<'a>> for CipherText<'a> {
178    type Error = BackendError;
179
180    /// Attempt to decode a [`Segment`] as a [`CipherText`].
181    ///
182    /// # Errors
183    ///
184    ///   * `Error::SegmentLength` if there is a length mismatch
185    ///   * `Error::InvalidCompression` if the compression flag is invalid
186    fn try_from(value: Segment<'a>) -> Result<Self, Self::Error> {
187        let mut data: &[u8] = value.data.as_ref();
188        // Read the compression flag
189        let compressed: bool = match data[0] {
190            0_u8 => false,
191            1_u8 => true,
192            _ => return Err(BackendError::InvalidCompression),
193        };
194        data = &data[1..];
195        // Read the nonce
196        let mut nonce = [0_u8; 24];
197        ensure!(data.len() >= 24, SegmentLength);
198        nonce.copy_from_slice(&data[0..24]);
199        data = &data[24..];
200        // Read the hmac
201        let mut hmac = [0_u8; 32];
202        ensure!(data.len() >= 32, SegmentLength);
203        hmac.copy_from_slice(&data[0..32]);
204        Ok(CipherText {
205            compressed,
206            nonce: Nonce(nonce.into()),
207            hmac: hmac.into(),
208            payload: match value.data {
209                Cow::Borrowed(data) => Cow::Borrowed(&data[57..]),
210                Cow::Owned(data) => data[57..].to_vec().into(),
211            },
212        })
213    }
214}
215
/// Unit tests
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::{ClearText, RootKey};
    use proptest::prelude::*;
    use std::io::{Cursor, Seek, SeekFrom};
    proptest! {
        /// Test round trip in borrowed mode
        #[test]
        fn borrowed_round_trip(bytes: Vec<u8>) {
            // Make the segment
            let segment = Segment::new_borrowed(&bytes);
            // Test round trip via IO
            let mut cursor = Cursor::new(Vec::<u8>::new());
            segment.write(&mut cursor).expect("Failed to write to cursor");
            // Test round trip via a buffer
            let total_length = segment.total_length();
            let mut buffer = vec![0_u8; total_length];
            segment.write_ref(&mut buffer).expect("Failed to write to buffer");
            // Reread the segment from IO
            let cursor_buff = cursor.into_inner();
            let cursor_segment = Segment::read_borrowed(&cursor_buff[..])
                .expect("Failed to read cursor segment");
            assert_eq!(cursor_segment, segment);
            // Reread the segment from buffer
            let buffer_segment = Segment::read_borrowed(&buffer[..])
                .expect("Failed to read buffer segment");
            assert_eq!(buffer_segment, segment);
        }
        /// Test round trip in owned mode
        #[test]
        fn owned_round_trip(bytes: Vec<u8>) {
            // Make the segment
            let segment = Segment::new(&bytes);
            // Test round trip via IO
            let mut cursor = Cursor::new(Vec::<u8>::new());
            segment.write(&mut cursor).expect("Failed to write to cursor");
            // Seek back to start of cursor so we will be able to read it later
            cursor.seek(SeekFrom::Start(0)).unwrap();
            // Test round trip via a buffer
            let total_length = segment.total_length();
            let mut buffer = vec![0_u8; total_length];
            segment.write_ref(&mut buffer).expect("Failed to write to buffer");
            // Reread the segment from IO
            let cursor_segment = Segment::read_owned(&mut cursor)
                .expect("Failed to read cursor segment");
            assert_eq!(cursor_segment, segment);
            // Reread the segment from buffer
            let buffer_segment = Segment::read_borrowed(&buffer[..])
                .expect("Failed to read buffer segment");
            assert_eq!(buffer_segment, segment);
        }
    }
    /// Reading owned segments back to back must consume exactly one segment per call
    #[test]
    fn owned_read_stops_at_segment_boundary() -> Result<(), BackendError> {
        let first = Segment::new([1_u8, 2, 3]);
        let second = Segment::new([4_u8, 5]);
        let mut cursor = Cursor::new(Vec::<u8>::new());
        first.write(&mut cursor)?;
        second.write(&mut cursor)?;
        cursor.seek(SeekFrom::Start(0)).unwrap();
        assert_eq!(Segment::read_owned(&mut cursor)?, first);
        assert_eq!(Segment::read_owned(&mut cursor)?, second);
        Ok(())
    }
    /// A length tag promising more data than the reader holds must be an error
    #[test]
    fn owned_read_rejects_truncated_data() {
        let mut bytes = 10_u64.to_le_bytes().to_vec();
        bytes.extend_from_slice(&[1, 2, 3]);
        assert!(Segment::read_owned(&mut Cursor::new(bytes)).is_err());
    }
    /// Test round trip of cipher text, without compression
    #[test]
    fn cipher_text_round_trip() -> Result<(), BackendError> {
        // Get a cipher text
        let root_key = RootKey::random();
        let data = vec![1_u8; 256];
        let plaintext = ClearText::new(&data)?;
        let ciphertext = plaintext.clone().encrypt(&root_key, None)?;
        // Get the segment
        let segment: Segment<'_> = ciphertext.clone().into();
        // Convert it back to a ciphertext
        let recovered: CipherText<'_> = segment.try_into()?;
        assert_eq!(recovered, ciphertext);
        // Decrypt it
        let recovered_plaintext = recovered.decrypt(&root_key)?;
        assert_eq!(recovered_plaintext.payload, plaintext.payload);
        // Deser it
        let recovered_data: Vec<u8> = recovered_plaintext.deserialize()?;
        assert_eq!(recovered_data, data);

        Ok(())
    }
    /// Test round trip of cipher text, with compression
    #[test]
    fn cipher_text_round_trip_compress() -> Result<(), BackendError> {
        // Get a cipher text
        let root_key = RootKey::random();
        let data = vec![1_u8; 256];
        let plaintext = ClearText::new(&data)?;
        let ciphertext = plaintext.clone().encrypt(&root_key, Some(1))?;
        // Get the segment
        let segment: Segment<'_> = ciphertext.clone().into();
        // Convert it back to a ciphertext
        let recovered: CipherText<'_> = segment.try_into()?;
        assert_eq!(recovered, ciphertext);
        // Decrypt it
        let recovered_plaintext = recovered.decrypt(&root_key)?;
        assert_eq!(recovered_plaintext.payload, plaintext.payload);
        // Deser it
        let recovered_data: Vec<u8> = recovered_plaintext.deserialize()?;
        assert_eq!(recovered_data, data);

        Ok(())
    }
}