1use std::{
2 fmt,
3 ops::{Bound, Range, RangeBounds},
4};
5
6use bytes::{Bytes, BytesMut};
7use serde::{
8 Deserialize, Deserializer, Serialize, Serializer,
9 de::{self, Visitor},
10};
11use thiserror::Error;
12
13use mountpoint_s3_client::checksums::{
14 crc32c::{self, Crc32c},
15 crc32c_from_base64, crc32c_to_base64,
16};
17
/// Reports whether download integrity validation has been switched off.
///
/// Validation is disabled when the `EXPERIMENTAL_MOUNTPOINT_NO_DOWNLOAD_INTEGRITY_VALIDATION`
/// environment variable is set (to any valid Unicode value). The environment is consulted
/// afresh on every call; nothing is cached.
fn is_integrity_validation_disabled() -> bool {
    const DISABLE_VALIDATION_VAR: &str = "EXPERIMENTAL_MOUNTPOINT_NO_DOWNLOAD_INTEGRITY_VALIDATION";
    let lookup = std::env::var(DISABLE_VALIDATION_VAR);
    lookup.is_ok()
}
22
23#[derive(Clone, Debug)]
28#[must_use]
29pub struct ChecksummedBytes {
30 buffer: Bytes,
32 range: Range<usize>,
34 checksum: Crc32c,
36}
37
38impl ChecksummedBytes {
39 pub fn new_from_inner_data(bytes: Bytes, checksum: Crc32c) -> Self {
42 let full_range = 0..bytes.len();
43 Self {
44 buffer: bytes,
45 range: full_range,
46 checksum,
47 }
48 }
49
50 pub fn new(bytes: Bytes) -> Self {
52 let checksum = if is_integrity_validation_disabled() {
53 Crc32c::new(0) } else {
55 crc32c::checksum(&bytes)
56 };
57 Self::new_from_inner_data(bytes, checksum)
58 }
59
60 pub fn into_bytes(self) -> Result<Bytes, IntegrityError> {
64 self.validate()?;
65 Ok(self.buffer_slice())
66 }
67
68 pub fn len(&self) -> usize {
70 self.range.len()
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.range.is_empty()
76 }
77
78 pub fn split_off(&mut self, at: usize) -> ChecksummedBytes {
85 assert!(at < self.len());
86
87 let start = self.range.start;
88 let prefix_range = start..(start + at);
89 let suffix_range = (start + at)..self.range.end;
90
91 self.range = prefix_range;
92 Self {
93 buffer: self.buffer.clone(),
94 range: suffix_range,
95 checksum: self.checksum,
96 }
97 }
98
99 pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
104 let sliced_range = {
105 let original_len = self.len();
106 let original_start = self.range.start;
107
108 let slice_start_offset = match range.start_bound() {
109 Bound::Included(&n) => n,
110 Bound::Excluded(&n) => n.checked_add(1).expect("range start greater than maximum usize"),
111 Bound::Unbounded => 0,
112 };
113
114 let slice_end_offset = match range.end_bound() {
115 Bound::Included(&n) => n.checked_add(1).expect("range end greater than maximum usize"),
116 Bound::Excluded(&n) => n,
117 Bound::Unbounded => original_len,
118 };
119
120 assert!(
121 slice_start_offset <= slice_end_offset,
122 "range start must not be greater than end: {slice_start_offset:?} <= {slice_end_offset:?}",
123 );
124 assert!(
125 slice_end_offset <= original_len,
126 "range end out of bounds: {slice_end_offset:?} <= {original_len:?}",
127 );
128
129 (original_start + slice_start_offset)..(original_start + slice_end_offset)
130 };
131
132 Self {
133 buffer: self.buffer.clone(),
134 range: sliced_range,
135 checksum: self.checksum,
136 }
137 }
138
139 pub fn shrink_to_fit(&mut self) -> Result<(), IntegrityError> {
144 if self.len() == self.buffer.len() {
145 return Ok(());
146 }
147
148 let bytes = self.buffer_slice();
150 let checksum = crc32c::checksum(&bytes);
151
152 self.validate()?;
154
155 *self = Self {
156 buffer: bytes,
157 range: 0..self.len(),
158 checksum,
159 };
160 Ok(())
161 }
162
163 pub fn extend(&mut self, mut extend: ChecksummedBytes) -> Result<(), IntegrityError> {
168 if extend.is_empty() {
169 extend.validate()?;
171 return Ok(());
172 }
173
174 if self.is_empty() {
175 self.validate()?;
177 *self = extend;
178 return Ok(());
179 }
180
181 self.shrink_to_fit()?;
187 assert_eq!(self.buffer.len(), self.len());
188 extend.shrink_to_fit()?;
189 assert_eq!(extend.buffer.len(), extend.len());
190
191 let new_checksum = combine_checksums(self.checksum, extend.checksum, extend.len());
193
194 let new_bytes = {
196 let mut bytes_mut = BytesMut::with_capacity(self.len() + extend.len());
197 bytes_mut.extend_from_slice(&self.buffer);
198 bytes_mut.extend_from_slice(&extend.buffer);
199 bytes_mut.freeze()
200 };
201
202 let new_range = 0..(new_bytes.len());
203 *self = Self {
204 buffer: new_bytes,
205 range: new_range,
206 checksum: new_checksum,
207 };
208 Ok(())
209 }
210
211 pub fn validate(&self) -> Result<(), IntegrityError> {
215 if is_integrity_validation_disabled() {
216 return Ok(()); }
218
219 let checksum = crc32c::checksum(&self.buffer);
220 if self.checksum != checksum {
221 return Err(IntegrityError::ChecksumMismatch(self.checksum, checksum));
222 }
223 Ok(())
224 }
225
226 pub fn into_inner(mut self) -> Result<(Bytes, Crc32c), IntegrityError> {
232 self.shrink_to_fit()?;
233 Ok((self.buffer, self.checksum))
234 }
235
236 fn buffer_slice(&self) -> Bytes {
240 self.buffer.slice(self.range.clone())
241 }
242}
243
244impl Default for ChecksummedBytes {
245 fn default() -> Self {
246 Self {
247 buffer: Default::default(),
248 range: Default::default(),
249 checksum: Crc32c::new(0),
250 }
251 }
252}
253
254impl From<Bytes> for ChecksummedBytes {
255 fn from(value: Bytes) -> Self {
256 Self::new(value)
257 }
258}
259
260impl TryFrom<ChecksummedBytes> for Bytes {
261 type Error = IntegrityError;
262
263 fn try_from(value: ChecksummedBytes) -> Result<Self, Self::Error> {
264 value.into_bytes()
265 }
266}
267
268pub fn combine_checksums(prefix_crc: Crc32c, suffix_crc: Crc32c, suffix_len: usize) -> Crc32c {
271 if is_integrity_validation_disabled() {
272 return Crc32c::new(0); }
274
275 let combined = ::crc32c::crc32c_combine(prefix_crc.value(), suffix_crc.value(), suffix_len);
276 Crc32c::new(combined)
277}
278
279#[derive(Debug, Error)]
280pub enum IntegrityError {
281 #[error("Checksum mismatch. expected: {0:?}, actual: {1:?}")]
282 ChecksumMismatch(Crc32c, Crc32c),
283}
284
285#[derive(Debug)]
289pub struct Crc32cBase64(Crc32c);
290
291impl Crc32cBase64 {
292 pub fn new(value: u32) -> Crc32cBase64 {
294 Crc32cBase64(Crc32c::new(value))
295 }
296
297 pub fn value(&self) -> Crc32c {
299 self.0
300 }
301}
302
303impl Serialize for Crc32cBase64 {
304 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
305 where
306 S: Serializer,
307 {
308 let encoded = crc32c_to_base64(&self.0);
309 serializer.serialize_str(&encoded)
310 }
311}
312
313impl<'de> Deserialize<'de> for Crc32cBase64 {
314 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
315 where
316 D: Deserializer<'de>,
317 {
318 struct Crc32cVisitor;
319
320 impl<'de> Visitor<'de> for Crc32cVisitor {
321 type Value = Crc32cBase64;
322
323 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
324 formatter.write_str("a base64-encoded CRC32C string")
325 }
326
327 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
328 where
329 E: de::Error,
330 {
331 crc32c_from_base64(v).map(Crc32cBase64).map_err(E::custom)
332 }
333 }
334
335 deserializer.deserialize_str(Crc32cVisitor)
336 }
337}
338
/// Test-only equality: compares the visible bytes, then asserts that both sides still
/// validate, so a test comparing corrupted data fails loudly.
#[cfg(test)]
impl PartialEq for ChecksummedBytes {
    fn eq(&self, other: &Self) -> bool {
        let same_contents = self.buffer_slice() == other.buffer_slice();
        for side in [self, other] {
            side.validate().expect("should be valid");
        }
        same_contents
    }
}
349
#[cfg(test)]
mod tests {
    use std::ops::{RangeFrom, RangeInclusive, RangeTo};

    use mountpoint_s3_client::checksums::crc32c;
    use test_case::test_case;

    use super::*;

    #[test]
    fn test_into_bytes() {
        let bytes = Bytes::from_static(b"some bytes");
        let expected = bytes.clone();
        let checksummed_bytes = ChecksummedBytes::new(bytes);

        let actual = checksummed_bytes.into_bytes().unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_into_bytes_integrity_error() {
        let bytes = Bytes::from_static(b"some bytes");
        let mut checksummed_bytes = ChecksummedBytes::new(bytes);

        // Corrupt the underlying buffer without updating the checksum.
        checksummed_bytes.buffer = Bytes::from_static(b"otherbytes");

        let actual = checksummed_bytes.into_bytes();
        assert!(matches!(actual, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_split_off() {
        let split_off_at = 4;
        let bytes = Bytes::from_static(b"some bytes");
        let expected = bytes.clone();
        let expected_checksum = crc32c::checksum(&expected);
        let mut checksummed_bytes = ChecksummedBytes::new(bytes);

        let mut expected_part1 = expected.clone();
        let expected_part2 = expected_part1.split_off(split_off_at);
        let new_checksummed_bytes = checksummed_bytes.split_off(split_off_at);

        assert_eq!(expected, checksummed_bytes.buffer);
        assert_eq!(expected, new_checksummed_bytes.buffer);
        assert_eq!(expected_part1, checksummed_bytes.buffer_slice());
        assert_eq!(expected_part2, new_checksummed_bytes.buffer_slice());
        assert_eq!(expected_checksum, checksummed_bytes.checksum);
        assert_eq!(expected_checksum, new_checksummed_bytes.checksum);
    }

    #[test]
    fn test_slice() {
        let range = 3..7;
        let bytes = Bytes::from_static(b"some bytes");
        let expected = bytes.clone();
        let expected_slice = bytes.slice(range.clone());
        let expected_checksum = crc32c::checksum(&expected);
        let original = ChecksummedBytes::new(bytes);
        let slice = original.slice(range);

        assert_eq!(expected, original.buffer);
        assert_eq!(expected, original.buffer_slice());
        assert_eq!(expected, slice.buffer);
        assert_eq!(expected_slice, slice.buffer_slice());
        assert_eq!(expected_checksum, original.checksum);
        assert_eq!(expected_checksum, slice.checksum);
    }

    /// Builds a value whose `range` is exactly `range`, for testing range arithmetic only.
    ///
    /// The buffer has length `range.len()`, so `range` may lie outside it; tests using this
    /// helper must not read the data.
    fn create_checksummed_bytes_with_range(range: Range<usize>) -> ChecksummedBytes {
        let buffer = Bytes::copy_from_slice(&vec![0; range.len()]);
        let checksum = crc32c::checksum(&buffer);
        ChecksummedBytes {
            buffer,
            range,
            checksum,
        }
    }

    #[test_case(0..10, 0..10, 0..10)]
    #[test_case(0..10, 5..6, 5..6)]
    #[test_case(5..10, 2..4, 7..9)]
    fn test_slice_range(original: Range<usize>, range: Range<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    // `4..2` is deliberately reversed to exercise the "start greater than end" assertion.
    #[allow(clippy::reversed_empty_ranges)]
    #[should_panic]
    #[test_case(5..10, 4..2; "start greater than end")]
    #[test_case(5..10, 4..12; "out of bounds")]
    fn test_slice_range_fail(original: Range<usize>, range: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        _ = bytes.slice(range);
    }

    #[test_case(0..10, ..10, 0..10)]
    #[test_case(0..10, ..6, 0..6)]
    #[test_case(5..10, ..4, 5..9)]
    fn test_slice_range_to(original: Range<usize>, range: RangeTo<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    #[test_case(0..10, 0.., 0..10)]
    #[test_case(0..10, 4.., 4..10)]
    #[test_case(5..10, 2.., 7..10)]
    fn test_slice_range_from(original: Range<usize>, range: RangeFrom<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    #[test_case(0..10, 2..=4, 2..5)]
    #[test_case(0..10, 0..=9, 0..10)]
    #[test_case(5..10, 0..=4, 5..10)]
    fn test_slice_range_inclusive(original: Range<usize>, range: RangeInclusive<usize>, expected: Range<usize>) {
        let bytes = create_checksummed_bytes_with_range(original);
        let slice = bytes.slice(range);
        assert_eq!(slice.range, expected);
    }

    #[test]
    fn test_shrink_to_fit() {
        let original = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let mut unchanged = original.clone();
        unchanged.shrink_to_fit().unwrap();
        assert_eq!(original.buffer_slice(), unchanged.buffer_slice());
        assert_eq!(original.buffer, unchanged.buffer);
        assert_eq!(original.checksum, unchanged.checksum);

        let slice = original.clone().split_off(5);
        let mut shrunken = slice.clone();
        shrunken.shrink_to_fit().unwrap();
        assert_eq!(slice.buffer_slice(), shrunken.buffer_slice());
        assert_ne!(slice.buffer, shrunken.buffer);
        assert_ne!(slice.checksum, shrunken.checksum);
    }

    #[test]
    fn test_shrink_to_fit_corrupted() {
        let mut original = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));

        // Corrupt the underlying buffer without updating the checksum.
        original.buffer = Bytes::from_static(b"otherbytes");

        assert!(matches!(
            original.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));

        // A full-range value is left untouched (and unvalidated) by `shrink_to_fit`.
        let mut unchanged = original.clone();
        unchanged.shrink_to_fit().unwrap();
        assert_eq!(original.buffer_slice(), unchanged.buffer_slice());
        assert_eq!(original.buffer, unchanged.buffer);
        assert_eq!(original.checksum, unchanged.checksum);
        assert!(matches!(
            unchanged.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));

        // A partial range forces validation, which must catch the corruption.
        let mut slice = original.clone().split_off(5);
        assert!(matches!(slice.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));

        let result = slice.shrink_to_fit();
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_into_inner() {
        let original = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let (unchanged_bytes, unchanged_checksum) = original.clone().into_inner().unwrap();
        assert_eq!(original.buffer_slice(), unchanged_bytes);
        assert_eq!(original.buffer, unchanged_bytes);
        assert_eq!(original.checksum, unchanged_checksum);

        let slice = original.clone().split_off(5);
        let (shrunken_bytes, shrunken_checksum) = slice.clone().into_inner().unwrap();
        assert_eq!(slice.buffer_slice(), shrunken_bytes);
        assert_ne!(slice.buffer, shrunken_bytes);
        assert_ne!(slice.checksum, shrunken_checksum);
    }

    #[test]
    fn test_extend() {
        let expected = Bytes::from_static(b"some bytes extended");
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let extend_bytes = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        checksummed_bytes.extend(extend_bytes).unwrap();
        let actual = checksummed_bytes.buffer_slice();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_extend_after_split() {
        let expected = Bytes::from_static(b"some bytes extended");
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        let mut extend = ChecksummedBytes::new(Bytes::from_static(b"bytes extended"));
        _ = checksummed_bytes.split_off(7);
        extend = extend.split_off(2);
        checksummed_bytes.extend(extend).unwrap();
        let actual = checksummed_bytes.buffer_slice();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_extend_with_empty() {
        let expected = Bytes::from_static(b"some bytes");
        let mut checksummed_bytes = ChecksummedBytes::new(expected.clone());
        checksummed_bytes.extend(ChecksummedBytes::default()).unwrap();
        assert_eq!(expected, checksummed_bytes.into_bytes().unwrap());
    }

    #[test]
    fn test_extend_empty_self() {
        let expected = Bytes::from_static(b"some bytes");
        let mut checksummed_bytes = ChecksummedBytes::default();
        checksummed_bytes
            .extend(ChecksummedBytes::new(expected.clone()))
            .unwrap();
        assert_eq!(expected, checksummed_bytes.into_bytes().unwrap());
    }

    #[test]
    fn test_extend_with_empty_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));

        let mut corrupted = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        // Corrupt the underlying buffer without updating the checksum.
        corrupted.buffer = Bytes::from_static(b"corrupted");
        let empty_corrupted = corrupted.slice(0..0);
        assert!(empty_corrupted.is_empty());

        // Even an empty extension is validated against its (corrupted) buffer.
        let result = checksummed_bytes.extend(empty_corrupted);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_self_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));

        // Corrupt the underlying buffer without updating the checksum.
        checksummed_bytes.buffer = Bytes::from_static(b"otherbytes");

        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));

        let extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        assert!(matches!(extend.validate(), Ok(())));

        // Full-range self is not validated by `extend`, but the corruption is carried
        // into the combined checksum.
        checksummed_bytes.extend(extend).unwrap();
        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
    }

    #[test]
    fn test_extend_after_split_self_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));

        // Corrupt the underlying buffer without updating the checksum.
        checksummed_bytes.buffer = Bytes::from_static(b"otherbytes");

        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));

        _ = checksummed_bytes.split_off(4);

        let extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        assert!(matches!(extend.validate(), Ok(())));

        let result = checksummed_bytes.extend(extend);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_split_off_self_corrupted() {
        let mut split_off = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));

        // Corrupt the underlying buffer without updating the checksum.
        split_off.buffer = Bytes::from_static(b"otherbytes");

        split_off = split_off.split_off(4);

        assert!(matches!(
            split_off.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));

        let extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));
        assert!(matches!(extend.validate(), Ok(())));

        let result = split_off.extend(extend);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_other_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        assert!(matches!(checksummed_bytes.validate(), Ok(())));

        let mut extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));

        // Corrupt the underlying buffer without updating the checksum.
        extend.buffer = Bytes::from_static(b"corrupted");

        assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));

        // Full-range `extend` is not validated, but the corruption is carried into the
        // combined checksum.
        checksummed_bytes.extend(extend).unwrap();
        assert!(matches!(
            checksummed_bytes.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));
    }

    #[test]
    fn test_extend_after_split_other_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        assert!(matches!(checksummed_bytes.validate(), Ok(())));

        let mut extend = ChecksummedBytes::new(Bytes::from_static(b" extended"));

        // Corrupt the underlying buffer without updating the checksum.
        extend.buffer = Bytes::from_static(b"corrupted");

        assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));

        _ = extend.split_off(4);

        let result = checksummed_bytes.extend(extend);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_extend_split_off_other_corrupted() {
        let mut checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"some bytes"));
        assert!(matches!(checksummed_bytes.validate(), Ok(())));

        let mut split_off = ChecksummedBytes::new(Bytes::from_static(b"bytes extended"));

        // Corrupt the underlying buffer without updating the checksum.
        split_off.buffer = Bytes::from_static(b"bytescorrupted");

        split_off = split_off.split_off(5);
        assert!(matches!(
            split_off.validate(),
            Err(IntegrityError::ChecksumMismatch(_, _))
        ));

        let result = checksummed_bytes.extend(split_off);
        assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
    }

    #[test]
    fn test_default_is_valid_and_empty() {
        let default = ChecksummedBytes::default();
        assert!(default.is_empty());
        assert_eq!(0, default.len());
        assert!(matches!(default.validate(), Ok(())));
        assert_eq!(Bytes::new(), default.into_bytes().unwrap());
    }

    #[test]
    fn test_combine_checksums() {
        let buf: &[u8] = b"123456789";
        let (buf1, buf2) = buf.split_at(4);
        let crc = crc32c::checksum(buf);
        let crc1 = crc32c::checksum(buf1);
        let crc2 = crc32c::checksum(buf2);
        let combined = combine_checksums(crc1, crc2, buf2.len());
        assert_eq!(combined, crc);
    }
}