use std::{
fmt,
ops::{Bound, Range, RangeBounds},
};
use bytes::{Bytes, BytesMut};
use serde::{
Deserialize, Deserializer, Serialize, Serializer,
de::{self, Visitor},
};
use thiserror::Error;
use mountpoint_s3_client::checksums::{
crc32c::{self, Crc32c},
crc32c_from_base64, crc32c_to_base64,
};
/// Reports whether download integrity validation has been switched off.
///
/// Validation is disabled when `std::env::var` succeeds for
/// `EXPERIMENTAL_MOUNTPOINT_NO_DOWNLOAD_INTEGRITY_VALIDATION` (i.e. the variable is set to a
/// valid Unicode value). The environment is re-read on every call.
fn is_integrity_validation_disabled() -> bool {
    const ENV_VAR: &str = "EXPERIMENTAL_MOUNTPOINT_NO_DOWNLOAD_INTEGRITY_VALIDATION";
    matches!(std::env::var(ENV_VAR), Ok(_))
}
/// A byte slice paired with a CRC32C checksum used to detect data corruption.
///
/// `buffer` is the backing storage, and `checksum` covers all of `buffer` (when integrity
/// validation is enabled). `range` selects the bytes that are actually exposed. `slice` and
/// `split_off` only narrow `range` and reuse the existing checksum. The checksum is recomputed
/// or combined only when the data is shrunk, extended or consumed.
#[derive(Clone, Debug)]
#[must_use]
pub struct ChecksummedBytes {
    /// Backing buffer; may contain bytes outside `range`.
    buffer: Bytes,
    /// Sub-range of `buffer` exposed by this value.
    range: Range<usize>,
    /// Checksum of the entire `buffer`, not just of `range`.
    checksum: Crc32c,
}
impl ChecksummedBytes {
    /// Creates a new `ChecksummedBytes` from `bytes` and a `checksum` expected to cover all of
    /// `bytes`.
    ///
    /// The checksum is taken as given and is not verified here. A mismatch is reported by a
    /// later call to [`validate`](Self::validate) or by any operation that validates.
    pub fn new_from_inner_data(bytes: Bytes, checksum: Crc32c) -> Self {
        let full_range = 0..bytes.len();
        Self {
            buffer: bytes,
            range: full_range,
            checksum,
        }
    }

    /// Creates a new `ChecksummedBytes` by computing the CRC32C checksum of `bytes`.
    ///
    /// When integrity validation is disabled, no checksum is computed and a zero checksum is
    /// stored instead.
    pub fn new(bytes: Bytes) -> Self {
        let checksum = if is_integrity_validation_disabled() {
            Crc32c::new(0)
        } else {
            crc32c::checksum(&bytes)
        };
        Self::new_from_inner_data(bytes, checksum)
    }

    /// Validates the data and returns the visible bytes.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrityError::ChecksumMismatch`] if the backing buffer does not match the
    /// stored checksum.
    pub fn into_bytes(self) -> Result<Bytes, IntegrityError> {
        self.validate()?;
        Ok(self.buffer_slice())
    }

    /// Returns the number of visible bytes.
    ///
    /// This is the length of the current range, which may be shorter than the backing buffer.
    pub fn len(&self) -> usize {
        self.range.len()
    }

    /// Returns `true` if there are no visible bytes.
    pub fn is_empty(&self) -> bool {
        self.range.is_empty()
    }

    /// Splits the visible bytes in two at index `at`.
    ///
    /// Afterwards `self` contains bytes `[0, at)` and the returned value contains bytes
    /// `[at, len)`. Both halves keep sharing the original backing buffer and checksum, so no
    /// data is copied, validated or re-checksummed.
    ///
    /// # Panics
    ///
    /// Panics if `at > self.len()`. Splitting at `self.len()` is allowed and leaves the
    /// returned value empty, mirroring [`Bytes::split_off`].
    pub fn split_off(&mut self, at: usize) -> ChecksummedBytes {
        let len = self.len();
        // `at == len` is a valid split point (empty suffix); only `at > len` is out of bounds.
        assert!(at <= len, "split_off out of bounds: {at:?} <= {len:?}");
        let start = self.range.start;
        let prefix_range = start..(start + at);
        let suffix_range = (start + at)..self.range.end;
        self.range = prefix_range;
        Self {
            buffer: self.buffer.clone(),
            range: suffix_range,
            checksum: self.checksum,
        }
    }

    /// Returns a new `ChecksummedBytes` exposing `range` of the currently visible bytes.
    ///
    /// `range` is relative to the visible bytes, not to the backing buffer. The result shares
    /// the backing buffer and checksum with `self`. Nothing is copied or validated.
    ///
    /// # Panics
    ///
    /// Panics if the range start is greater than its end, if the end is past `self.len()`,
    /// or if an inclusive end / exclusive start bound is `usize::MAX`.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let sliced_range = {
            let original_len = self.len();
            let original_start = self.range.start;
            let slice_start_offset = match range.start_bound() {
                Bound::Included(&n) => n,
                Bound::Excluded(&n) => n.checked_add(1).expect("range start greater than maximum usize"),
                Bound::Unbounded => 0,
            };
            let slice_end_offset = match range.end_bound() {
                Bound::Included(&n) => n.checked_add(1).expect("range end greater than maximum usize"),
                Bound::Excluded(&n) => n,
                Bound::Unbounded => original_len,
            };
            assert!(
                slice_start_offset <= slice_end_offset,
                "range start must not be greater than end: {slice_start_offset:?} <= {slice_end_offset:?}",
            );
            assert!(
                slice_end_offset <= original_len,
                "range end out of bounds: {slice_end_offset:?} <= {original_len:?}",
            );
            // Translate offsets relative to the visible bytes into offsets into `buffer`.
            (original_start + slice_start_offset)..(original_start + slice_end_offset)
        };
        Self {
            buffer: self.buffer.clone(),
            range: sliced_range,
            checksum: self.checksum,
        }
    }

    /// Shrinks the backing buffer to exactly the visible bytes and recomputes the checksum.
    ///
    /// This is a no-op, and performs no validation, when the visible bytes already span the
    /// whole buffer. Otherwise the original buffer is validated before it is replaced.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrityError::ChecksumMismatch`] if shrinking was needed and the original
    /// buffer failed validation. In that case `self` is left unchanged.
    pub fn shrink_to_fit(&mut self) -> Result<(), IntegrityError> {
        if self.len() == self.buffer.len() {
            return Ok(());
        }
        // `Bytes::slice` only bumps a reference count; no data is copied.
        let bytes = self.buffer_slice();
        // The sub-slice is checksummed first, and only then is the full buffer validated. This
        // ensures that the bytes the new checksum is derived from are covered by that validation.
        let checksum = crc32c::checksum(&bytes);
        self.validate()?;
        *self = Self {
            buffer: bytes,
            range: 0..self.len(),
            checksum,
        };
        Ok(())
    }

    /// Appends the visible bytes of `extend` to `self`.
    ///
    /// If either side is empty, no data is copied. The empty side is still validated before
    /// being discarded. Otherwise both sides are first shrunk to their visible bytes, which
    /// validates any side that needed shrinking. They are then copied into a fresh buffer whose
    /// checksum is obtained by combining the two checksums rather than re-hashing the data.
    ///
    /// Corruption in a side that already spans its whole buffer is not reported here. It is
    /// carried into the combined checksum, so a later [`validate`](Self::validate) still fails.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrityError::ChecksumMismatch`] if a side that was validated is corrupted.
    pub fn extend(&mut self, mut extend: ChecksummedBytes) -> Result<(), IntegrityError> {
        if extend.is_empty() {
            // Nothing to append, but still check that `extend` was valid.
            extend.validate()?;
            return Ok(());
        }
        if self.is_empty() {
            // Replace `self` with `extend`, but still check that `self` was valid.
            self.validate()?;
            *self = extend;
            return Ok(());
        }
        // Checksums can only be combined when each one covers exactly the bytes being
        // concatenated, so drop any bytes outside the visible ranges first.
        self.shrink_to_fit()?;
        assert_eq!(self.buffer.len(), self.len());
        extend.shrink_to_fit()?;
        assert_eq!(extend.buffer.len(), extend.len());
        let new_checksum = combine_checksums(self.checksum, extend.checksum, extend.len());
        let new_bytes = {
            let mut bytes_mut = BytesMut::with_capacity(self.len() + extend.len());
            bytes_mut.extend_from_slice(&self.buffer);
            bytes_mut.extend_from_slice(&extend.buffer);
            bytes_mut.freeze()
        };
        let new_range = 0..(new_bytes.len());
        *self = Self {
            buffer: new_bytes,
            range: new_range,
            checksum: new_checksum,
        };
        Ok(())
    }

    /// Checks the whole backing buffer, including any bytes outside the visible range, against
    /// the stored checksum.
    ///
    /// Always succeeds when integrity validation is disabled.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrityError::ChecksumMismatch`] with the stored (expected) and computed
    /// (actual) checksums on mismatch.
    pub fn validate(&self) -> Result<(), IntegrityError> {
        if is_integrity_validation_disabled() {
            return Ok(());
        }
        let checksum = crc32c::checksum(&self.buffer);
        if self.checksum != checksum {
            return Err(IntegrityError::ChecksumMismatch(self.checksum, checksum));
        }
        Ok(())
    }

    /// Consumes `self` and returns the backing buffer and its checksum, after shrinking the
    /// buffer to the visible bytes.
    ///
    /// When no shrinking is needed, the buffer is returned without validation.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrityError::ChecksumMismatch`] if shrinking was needed and the original
    /// buffer failed validation.
    pub fn into_inner(mut self) -> Result<(Bytes, Crc32c), IntegrityError> {
        self.shrink_to_fit()?;
        Ok((self.buffer, self.checksum))
    }

    /// Returns the visible bytes without validating them.
    fn buffer_slice(&self) -> Bytes {
        self.buffer.slice(self.range.clone())
    }
}
impl Default for ChecksummedBytes {
fn default() -> Self {
Self {
buffer: Default::default(),
range: Default::default(),
checksum: Crc32c::new(0),
}
}
}
impl From<Bytes> for ChecksummedBytes {
fn from(value: Bytes) -> Self {
Self::new(value)
}
}
impl TryFrom<ChecksummedBytes> for Bytes {
type Error = IntegrityError;
fn try_from(value: ChecksummedBytes) -> Result<Self, Self::Error> {
value.into_bytes()
}
}
/// Combines the checksum of a prefix with the checksum of a suffix of length `suffix_len`.
///
/// The result is the checksum of the concatenated data, computed without re-reading it.
/// When integrity validation is disabled, a zero checksum is returned instead.
pub fn combine_checksums(prefix_crc: Crc32c, suffix_crc: Crc32c, suffix_len: usize) -> Crc32c {
    let combined = if is_integrity_validation_disabled() {
        0
    } else {
        ::crc32c::crc32c_combine(prefix_crc.value(), suffix_crc.value(), suffix_len)
    };
    Crc32c::new(combined)
}
/// Error raised when data does not match its checksum.
#[derive(Debug, Error)]
pub enum IntegrityError {
    /// The stored checksum (first field, "expected") differs from the checksum computed over
    /// the data (second field, "actual").
    #[error("Checksum mismatch. expected: {0:?}, actual: {1:?}")]
    ChecksumMismatch(Crc32c, Crc32c),
}
/// A CRC32C checksum that serde serializes and deserializes as a base64-encoded string.
#[derive(Debug)]
pub struct Crc32cBase64(Crc32c);
impl Crc32cBase64 {
pub fn new(value: u32) -> Crc32cBase64 {
Crc32cBase64(Crc32c::new(value))
}
pub fn value(&self) -> Crc32c {
self.0
}
}
impl Serialize for Crc32cBase64 {
    /// Serializes the checksum as its base64 string representation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self(checksum) = self;
        serializer.serialize_str(&crc32c_to_base64(checksum))
    }
}
impl<'de> Deserialize<'de> for Crc32cBase64 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Crc32cVisitor;
impl<'de> Visitor<'de> for Crc32cVisitor {
type Value = Crc32cBase64;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a base64-encoded CRC32C string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
crc32c_from_base64(v).map(Crc32cBase64).map_err(E::custom)
}
}
deserializer.deserialize_str(Crc32cVisitor)
}
}
#[cfg(test)]
impl PartialEq for ChecksummedBytes {
    /// Compares the visible bytes of both sides.
    ///
    /// Test-only: after comparing, it panics with "should be valid" if either side (checked
    /// `self` first, then `other`) fails validation.
    fn eq(&self, other: &Self) -> bool {
        let same_contents = self.buffer_slice() == other.buffer_slice();
        for side in [self, other] {
            side.validate().expect("should be valid");
        }
        same_contents
    }
}
#[cfg(test)]
mod tests {
    use std::ops::{RangeFrom, RangeTo};

    use mountpoint_s3_client::checksums::crc32c;
    use test_case::test_case;

    use super::*;

    #[test]
    fn test_into_bytes() {
        let bytes = Bytes::from_static(b"some bytes");
        let expected = bytes.clone();
        let checksummed_bytes = ChecksummedBytes::new(bytes);
        let actual = checksummed_bytes.into_bytes().unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_into_bytes_integrity_error() {
        let bytes = Bytes::from_static(b"some bytes");
        let mut checksummed_bytes = ChecksummedBytes::new(bytes);
        // Simulate corruption: swap the buffer while keeping the original checksum.
        checksummed_bytes.buffer = Bytes::from_static(b"otherbytes");
        let actual = checksummed_bytes.into_bytes();
        assert!(matches!(actual, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_split_off() {
        let split_off_at = 4;
        let bytes = Bytes::from_static(b"some bytes");
        let expected = bytes.clone();
        let expected_checksum = crc32c::checksum(&expected);
        let mut checksummed_bytes = ChecksummedBytes::new(bytes);
        let mut expected_part1 = expected.clone();
        let expected_part2 = expected_part1.split_off(split_off_at);
        let new_checksummed_bytes = checksummed_bytes.split_off(split_off_at);
        // Both halves keep the full backing buffer and its checksum.
        assert_eq!(expected, checksummed_bytes.buffer);
        assert_eq!(expected, new_checksummed_bytes.buffer);
        assert_eq!(expected_part1, checksummed_bytes.buffer_slice());
        assert_eq!(expected_part2, new_checksummed_bytes.buffer_slice());
        assert_eq!(expected_checksum, checksummed_bytes.checksum);
        assert_eq!(expected_checksum, new_checksummed_bytes.checksum);
    }

    #[test]
    fn test_slice() {
        let range = 3..7;
        let bytes = Bytes::from_static(b"some bytes");
        let expected = bytes.clone();
        let expected_slice = bytes.slice(range.clone());
        let expected_checksum = crc32c::checksum(&expected);
        let original = ChecksummedBytes::new(bytes);
        let slice = original.slice(range);
        assert_eq!(expected, original.buffer);
        assert_eq!(expected, original.buffer_slice());
        assert_eq!(expected, slice.buffer);
        assert_eq!(expected_slice, slice.buffer_slice());
        assert_eq!(expected_checksum, original.checksum);
        assert_eq!(expected_checksum, slice.checksum);
    }

    /// Builds a valid `ChecksummedBytes` whose visible window is `range`.
    ///
    /// The backing buffer is `range.end` zero bytes long, so `range` always lies within it,
    /// and the checksum covers the whole buffer.
    fn create_checksummed_bytes_with_range(range: Range<usize>) -> ChecksummedBytes {
        // The buffer must reach `range.end`. Allocating only `range.len()` bytes would leave a
        // range such as `5..10` pointing past the end of a 5-byte buffer.
        let buffer = Bytes::copy_from_slice(&vec![0; range.end]);
        let checksum = crc32c::checksum(&buffer);
        ChecksummedBytes {
            buffer,
            range,
            checksum,
        }
    }

    #[test_case(0..10, 0..10, 0..10)]
    #[test_case(0..10, 5..6, 5..6)]
    #[test_case(5..10, 2..4, 7..9)]
    fn test_slice_range(original: Range<usize>, range: Range<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    // `4..2` is intentionally reversed to exercise the "start greater than end" panic.
    #[allow(clippy::reversed_empty_ranges)]
    #[should_panic]
    #[test_case(5..10, 4..2; "start greater than end")]
    #[test_case(5..10, 4..12; "out of bounds")]
    fn test_slice_range_fail(original: Range<usize>, range: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        _ = bytes.slice(range);
    }

    #[test_case(0..10, ..10, 0..10)]
    #[test_case(0..10, ..6, 0..6)]
    #[test_case(5..10, ..4, 5..9)]
    fn test_slice_range_to(original: Range<usize>, range: RangeTo<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    #[test_case(0..10, 0.., 0..10)]
    #[test_case(0..10, 4.., 4..10)]
    #[test_case(5..10, 2.., 7..10)]
    fn test_slice_range_from(original: Range<usize>, range: RangeFrom<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    #[test]
    fn test_shrink_to_fit() {
        let original = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        // Already spans the whole buffer: nothing changes.
        let mut unchanged = original.clone();
        unchanged.shrink_to_fit().unwrap();
        assert_eq!(original.buffer_slice(), unchanged.buffer_slice());
        assert_eq!(original.buffer, unchanged.buffer);
        assert_eq!(original.checksum, unchanged.checksum);
        // A sub-range gets its own buffer and checksum.
        let slice = original.clone().split_off(5);
        let mut shrunken = slice.clone();
        shrunken.shrink_to_fit().unwrap();
        assert_eq!(slice.buffer_slice(), shrunken.buffer_slice());
        assert_ne!(slice.buffer, shrunken.buffer);
        assert_ne!(slice.checksum, shrunken.checksum);
    }

    #[test]
    fn test_shrink_to_fit_corrupted() {
        let mut original = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        original.buffer = Bytes::from_static(b"otherbytes");
        assert!(matches!(
            original.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
        // No shrinking needed, so no validation happens and the corruption is preserved.
        let mut unchanged = original.clone();
        unchanged.shrink_to_fit().unwrap();
        assert_eq!(original.buffer_slice(), unchanged.buffer_slice());
        assert_eq!(original.buffer, unchanged.buffer);
        assert_eq!(original.checksum, unchanged.checksum);
        assert!(matches!(
            unchanged.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
        // Shrinking a sub-range validates first and reports the corruption.
        let mut slice = original.clone().split_off(5);
        assert!(matches!(slice.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));
        let result = slice.shrink_to_fit();
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_into_inner() {
        let original = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let (unchanged_bytes, unchanged_checksum) = original.clone().into_inner().unwrap();
        assert_eq!(original.buffer_slice(), unchanged_bytes);
        assert_eq!(original.buffer, unchanged_bytes);
        assert_eq!(original.checksum, unchanged_checksum);
        let slice = original.clone().split_off(5);
        let (shrunken_bytes, shrunken_checksum) = slice.clone().into_inner().unwrap();
        assert_eq!(slice.buffer_slice(), shrunken_bytes);
        assert_ne!(slice.buffer, shrunken_bytes);
        assert_ne!(slice.checksum, shrunken_checksum);
    }

    #[test]
    fn test_extend() {
        let expected = Bytes::from_static(b"some bytes extended");
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let extend_bytes = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        checksummed_bytes.extend(extend_bytes).unwrap();
        let actual = checksummed_bytes.buffer_slice();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_extend_after_split() {
        let expected = Bytes::from_static(b"some bytes extended");
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let mut extend = ChecksummedBytes::new(Bytes::from_static(b"bytes extended"));
        _ = checksummed_bytes.split_off(7);
        extend = extend.split_off(2);
        checksummed_bytes.extend(extend).unwrap();
        let actual = checksummed_bytes.buffer_slice();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_extend_self_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        checksummed_bytes.buffer = Bytes::from_static(b"otherbytes");
        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
        let extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        assert!(matches!(extend.validate(), Ok(())));
        // Full-buffer corruption is not detected by `extend`, but is carried forward.
        checksummed_bytes.extend(extend).unwrap();
        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
    }

    #[test]
    fn test_extend_after_split_self_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        checksummed_bytes.buffer = Bytes::from_static(b"otherbytes");
        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
        _ = checksummed_bytes.split_off(4);
        let extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        assert!(matches!(extend.validate(), Ok(())));
        let result = checksummed_bytes.extend(extend);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_split_off_self_corrupted() {
        let mut split_off = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        split_off.buffer = Bytes::from_static(b"otherbytes");
        split_off = split_off.split_off(4);
        assert!(matches!(
            split_off.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
        let extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        assert!(matches!(extend.validate(), Ok(())));
        let result = split_off.extend(extend);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_other_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        assert!(matches!(checksummed_bytes.validate(), Ok(())));
        let mut extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        extend.buffer = Bytes::from_static(b"corrupted");
        assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));
        // Full-buffer corruption is not detected by `extend`, but is carried forward.
        checksummed_bytes.extend(extend).unwrap();
        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
    }

    #[test]
    fn test_extend_after_split_other_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        assert!(matches!(checksummed_bytes.validate(), Ok(())));
        let mut extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        extend.buffer = Bytes::from_static(b"corrupted");
        assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));
        _ = extend.split_off(4);
        let result = checksummed_bytes.extend(extend);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_split_off_other_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        assert!(matches!(checksummed_bytes.validate(), Ok(())));
        let mut split_off = ChecksummedBytes::new(Bytes::from_static(b"bytes extended"));
        split_off.buffer = Bytes::from_static(b"bytescorrupted");
        split_off = split_off.split_off(5);
        assert!(matches!(
            split_off.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
        let result = checksummed_bytes.extend(split_off);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_combine_checksums() {
        let buf: &[u8] = b"123456789";
        let (buf1, buf2) = buf.split_at(4);
        let crc = crc32c::checksum(buf);
        let crc1 = crc32c::checksum(buf1);
        let crc2 = crc32c::checksum(buf2);
        let combined = combine_checksums(crc1, crc2, buf2.len());
        assert_eq!(combined, crc);
    }
}