snapper-box 0.0.4

Cryptographic storage for snapper
Documentation
//! Abstraction for breaking a file up into a list of segments

use std::{
    borrow::Cow,
    io::{Read, Write},
};

use snafu::{ensure, OptionExt, ResultExt};

use crate::{
    crypto::{CipherText, Nonce},
    error::{BackendError, NoData, NoDataIO, SegmentIO, SegmentLength},
};

/// A data segment within a file, encoded as vector of bytes and its length.
///
/// This type can represent either owned data, for normal copying reads, or borrowed data, for
/// zero-copy reads.
///
/// A valid segment has a length, encoded as an 8-byte little-endian integer, and a bytewise array of
/// data of the specified length.
#[derive(Debug, Hash, Clone, PartialEq, Eq)]
pub struct Segment<'a> {
    /// The length of a segment, encoded as a number of bytes
    length: u64,
    /// The data contained in the segment
    data: Cow<'a, [u8]>,
}

impl<'a> Segment<'a> {
    /// Parses and borrows a `Segment` from the start of the provided data.
    ///
    /// Only the bytes covered by the segment's length tag become part of the segment. Any bytes
    /// after it are ignored, so several segments can be read back-to-back from one buffer by
    /// advancing `source` by [`Segment::total_length`] after each read.
    ///
    /// # Errors
    ///
    /// This will return:
    ///   * `NoData` if `source` is empty
    ///   * `SegmentLength` if
    ///     * `source` is too short to hold the 8-byte length tag
    ///     * The specified length is too big to possibly fit into memory
    ///     * There is not enough data in the slice to fill the data buffer
    pub fn read_borrowed(source: &'a [u8]) -> Result<Self, BackendError> {
        // We need at least 8 bytes for the length tag
        ensure!(!source.is_empty(), NoData);
        ensure!(source.len() >= 8, SegmentLength);
        let (tag, rest) = source.split_at(8);
        // Decode the length
        let mut length_array = [0_u8; 8];
        length_array.copy_from_slice(tag);
        let length = u64::from_le_bytes(length_array);
        // Make sure the length is small enough to fit into memory
        let length_usize: usize = length.try_into().ok().context(SegmentLength)?;
        // Take exactly `length` bytes; anything after them belongs to whatever follows this
        // segment, not to this one
        let data = rest.get(..length_usize).context(SegmentLength)?;
        Ok(Segment {
            length,
            data: Cow::Borrowed(data),
        })
    }

    /// Provides the length, in bytes, that this value will take up if serialized. This count includes the
    /// embedded 8-byte length tag.
    pub fn total_length(&self) -> usize {
        // 8 bytes for the length tag, plus the length of the byte array
        8 + self.data.len()
    }

    /// Writes this segment to the start of an array of bytes, returning the number of bytes
    /// written.
    ///
    /// `dest` may be longer than [`Segment::total_length`]; bytes past the end of the serialized
    /// segment are left untouched.
    ///
    /// # Errors
    ///
    /// Will return `SegmentLength` if the contained data is too big to fit in the buffer.
    pub fn write_ref(&self, dest: &mut [u8]) -> Result<usize, BackendError> {
        // Make sure the buffer is big enough
        let length = self.total_length();
        ensure!(dest.len() >= length, SegmentLength);
        // Restrict the destination to exactly the serialized size, so the data copy below always
        // has matching lengths even when `dest` is larger than needed
        let (tag, body) = dest[..length].split_at_mut(8);
        // First the length
        tag.copy_from_slice(&(self.data.len() as u64).to_le_bytes());
        // Then the data
        body.copy_from_slice(&self.data);
        Ok(length)
    }

    /// Writes this segment to an IO [`Write`] instance, returning the number of bytes written
    ///
    /// # Errors
    ///
    /// Will pass through any underlying IO errors
    pub fn write(&self, dest: &mut impl Write) -> Result<usize, BackendError> {
        // First the length
        let length = self.total_length();
        let length_bytes = (self.data.len() as u64).to_le_bytes();
        dest.write_all(&length_bytes).context(SegmentIO)?;
        // Then the data
        dest.write_all(&self.data).context(SegmentIO)?;
        Ok(length)
    }

    /// Constructs a new segment from some borrowed data, without copying it
    pub fn new_borrowed(data: &'a [u8]) -> Self {
        Self {
            length: data.len().try_into().expect("Impossibly large data"),
            data: Cow::Borrowed(data),
        }
    }

    /// Gets a reference to the inner data
    pub fn data(&self) -> &[u8] {
        self.data.as_ref()
    }
}

impl Segment<'static> {
    /// Copies a `Segment` from the provided IO [`Read`]
    ///
    /// # Errors
    ///
    /// This will return `Err(Error::SegmentIo)` if
    ///   * The specified length is too big to possibly fit into memory
    ///   * There is not enough data in the slice to fill the data buffer
    pub fn read_owned(source: &mut impl Read) -> Result<Self, BackendError> {
        // Decode the length
        let mut length_array = [0_u8; 8];
        source.read_exact(&mut length_array).context(NoDataIO)?;
        let length = u64::from_le_bytes(length_array);
        // Make sure the length is small enough to fit into memory
        let length_usize: usize = length.try_into().ok().context(SegmentLength)?;
        // Create a buffer of the correct length to write the data into
        let mut data = vec![0_u8; length_usize];
        // Read the data into the buffer
        source
            .read_exact(&mut data[0..length_usize])
            .context(SegmentIO)?;
        Ok(Segment {
            length,
            data: Cow::from(data),
        })
    }

    /// Constructs a new segment from some data
    pub fn new(data: impl AsRef<[u8]>) -> Self {
        let data = data.as_ref().to_vec();
        Self {
            length: data.len().try_into().expect("Impossibly large data"),
            data: Cow::from(data),
        }
    }
}

impl<'a> From<CipherText<'a>> for Segment<'static> {
    /// Encode a [`CipherText`] in binary form as a segment.
    ///
    /// This will encode:
    ///   * The `compressed` flag - `0_u8` being false and `1_u8` being true
    ///   * The rest of the fields as a concatenation of their bytes
    fn from(x: CipherText<'a>) -> Self {
        let mut buffer = vec![];
        // Push the compression flag
        if x.compressed {
            buffer.push(1_u8);
        } else {
            buffer.push(0_u8);
        };
        // Push the nonce
        buffer.extend(&*x.nonce.0);
        // Push the HMAC
        buffer.extend(&*x.hmac);
        // Push the data
        buffer.extend(&*x.payload);
        Segment {
            length: buffer.len() as u64,
            data: buffer.into(),
        }
    }
}

impl<'a> TryFrom<Segment<'a>> for CipherText<'a> {
    type Error = BackendError;

    /// Attempt to decode a [`Segment`] as a [`CipherText`].
    ///
    /// # Errors
    ///
    ///   * `Error::SegmentLength` if there is a length mismatch
    ///   * `Error::InvalidCompression` if the compression flag is invalid
    fn try_from(value: Segment<'a>) -> Result<Self, Self::Error> {
        let mut data: &[u8] = value.data.as_ref();
        // Read the compression flag
        let compressed: bool = match data[0] {
            0_u8 => false,
            1_u8 => true,
            _ => return Err(BackendError::InvalidCompression),
        };
        data = &data[1..];
        // Read the nonce
        let mut nonce = [0_u8; 24];
        ensure!(data.len() >= 24, SegmentLength);
        nonce.copy_from_slice(&data[0..24]);
        data = &data[24..];
        // Read the hmac
        let mut hmac = [0_u8; 32];
        ensure!(data.len() >= 32, SegmentLength);
        hmac.copy_from_slice(&data[0..32]);
        Ok(CipherText {
            compressed,
            nonce: Nonce(nonce.into()),
            hmac: hmac.into(),
            payload: match value.data {
                Cow::Borrowed(data) => Cow::Borrowed(&data[57..]),
                Cow::Owned(data) => data[57..].to_vec().into(),
            },
        })
    }
}

/// Unit tests
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::{ClearText, RootKey};
    use proptest::prelude::*;
    use std::io::{Cursor, Seek, SeekFrom};
    proptest! {
        /// Round trip a borrowed segment through both `write` and `write_ref`, reading each
        /// serialized form back with `read_borrowed`
        #[test]
        fn borrowed_round_trip(bytes: Vec<u8>) {
            // Make the segment
            let segment = Segment::new_borrowed(&bytes);
            // Serialize via IO
            let mut cursor = Cursor::new(Vec::<u8>::new());
            segment.write(&mut cursor).expect("Failed to write to cursor");
            // Serialize via a buffer
            let total_length = segment.total_length();
            let mut buffer = vec![0_u8; total_length];
            segment.write_ref(&mut buffer[0..total_length]).expect("Failed to write to buffer");
            // Reread the segment from IO
            let cursor_buff = cursor.into_inner();
            let cursor_segment = Segment::read_borrowed(&cursor_buff[..])
                .expect("Failed to read cursor segment");
            assert_eq!(cursor_segment, segment);
            // Reread the segment from buffer
            let buffer_segment = Segment::read_borrowed(&buffer[..])
                .expect("Failed to read buffer segment");
            assert_eq!(buffer_segment, segment);
        }
        /// Round trip an owned segment through both `write` and `write_ref`, reading the IO
        /// form back with `read_owned` and the buffer form with `read_borrowed`
        #[test]
        fn owned_round_trip(bytes: Vec<u8>) {
            // Make the segment
            let segment = Segment::new(&bytes);
            // Serialize via IO
            let mut cursor = Cursor::new(Vec::<u8>::new());
            segment.write(&mut cursor).expect("Failed to write to cursor");
            // Seek back to start of cursor so we will be able to read it later
            cursor.seek(SeekFrom::Start(0)).unwrap();
            // Serialize via a buffer
            let total_length = segment.total_length();
            let mut buffer = vec![0_u8; total_length];
            segment.write_ref(&mut buffer[0..total_length]).expect("Failed to write to buffer");
            // Reread the segment from IO
            let cursor_segment = Segment::read_owned(&mut cursor)
                .expect("Failed to read cursor segment");
            assert_eq!(cursor_segment, segment);
            // Reread the segment from buffer
            let buffer_segment = Segment::read_borrowed(&buffer[..])
                .expect("Failed to read buffer segment");
            assert_eq!(buffer_segment, segment);
        }
    }
    /// Test round trip of cipher text, without compression
    #[test]
    fn cipher_text_round_trip() -> Result<(), BackendError> {
        // Get a cipher text
        let root_key = RootKey::random();
        let data = vec![1_u8; 256];
        let plaintext = ClearText::new(&data)?;
        let ciphertext = plaintext.clone().encrypt(&root_key, None)?;
        // Get the segment
        let segment: Segment<'_> = ciphertext.clone().into();
        // Convert it back to a ciphertext
        let recovered: CipherText<'_> = segment.try_into()?;
        assert_eq!(recovered, ciphertext);
        // Decrypt it
        let recovered_plaintext = recovered.decrypt(&root_key)?;
        assert_eq!(recovered_plaintext.payload, plaintext.payload);
        // Deserialize it
        let recovered_data: Vec<u8> = recovered_plaintext.deserialize()?;
        assert_eq!(recovered_data, data);

        Ok(())
    }
    /// Test round trip of cipher text, with compression
    #[test]
    fn cipher_text_round_trip_compress() -> Result<(), BackendError> {
        // Get a cipher text
        let root_key = RootKey::random();
        let data = vec![1_u8; 256];
        let plaintext = ClearText::new(&data)?;
        let ciphertext = plaintext.clone().encrypt(&root_key, Some(1))?;
        // Get the segment
        let segment: Segment<'_> = ciphertext.clone().into();
        // Convert it back to a ciphertext
        let recovered: CipherText<'_> = segment.try_into()?;
        assert_eq!(recovered, ciphertext);
        // Decrypt it
        let recovered_plaintext = recovered.decrypt(&root_key)?;
        assert_eq!(recovered_plaintext.payload, plaintext.payload);
        // Deserialize it
        let recovered_data: Vec<u8> = recovered_plaintext.deserialize()?;
        assert_eq!(recovered_data, data);

        Ok(())
    }
}