use std::io::Read;
use zip::result::ZipError;
use crate::PackageError;
/// Caller-facing setting for how many inflated bytes may be read in total.
///
/// Build one with [`ExtractionLimits::bounded`] or
/// [`ExtractionLimits::unbounded`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ExtractionLimits {
/// Cap on inflated bytes; `None` means no cap is applied.
max_inflated_bytes: Option<u64>,
}
impl ExtractionLimits {
    /// Creates limits that allow at most `max_inflated_bytes` inflated bytes.
    ///
    /// The allowance is shared by every entry read through a budget derived
    /// from these limits, not granted to each entry separately.
    #[must_use]
    pub const fn bounded(max_inflated_bytes: u64) -> Self {
        let cap = Some(max_inflated_bytes);
        Self {
            max_inflated_bytes: cap,
        }
    }

    /// Creates limits that place no cap on inflated bytes.
    #[must_use]
    pub const fn unbounded() -> Self {
        let cap = None;
        Self {
            max_inflated_bytes: cap,
        }
    }

    /// Starts a fresh budget whose whole allowance is still available.
    pub(crate) const fn budget(self) -> ExtractionBudget {
        // `if let` rather than an `Option` combinator: closures are not usable
        // in a `const fn`.
        if let Some(cap) = self.max_inflated_bytes {
            ExtractionBudget::Bounded {
                limit: cap,
                remaining: cap,
            }
        } else {
            ExtractionBudget::Unbounded
        }
    }
}
/// Running allowance of inflated bytes, charged by
/// [`ExtractionBudget::read_entry`] each time an entry is read.
#[derive(Debug)]
pub(crate) enum ExtractionBudget {
/// No cap: entries are read to the end and nothing is tracked.
Unbounded,
/// A cap shared by every entry read through this budget.
Bounded {
/// The configured cap; reported in `InflatedSizeExceeded` errors.
limit: u64,
/// Bytes still available; reduced after each successful read.
remaining: u64,
},
}
impl ExtractionBudget {
pub(crate) fn read_entry<R>(&mut self, reader: &mut R) -> Result<Vec<u8>, PackageError>
where
R: Read,
{
let mut bytes = Vec::new();
match self {
Self::Unbounded => {
reader
.read_to_end(&mut bytes)
.map_err(|source| PackageError::ArchiveRead(ZipError::Io(source)))?;
}
Self::Bounded { limit, remaining } => {
let probe = remaining.saturating_add(1);
let mut taken = reader.take(probe);
taken
.read_to_end(&mut bytes)
.map_err(|source| PackageError::ArchiveRead(ZipError::Io(source)))?;
let inflated = probe - taken.limit();
if inflated > *remaining {
return Err(PackageError::InflatedSizeExceeded { limit: *limit });
}
*remaining -= inflated;
}
}
Ok(bytes)
}
}
#[cfg(test)]
mod tests {
    use super::ExtractionLimits;
    use crate::PackageError;

    /// The allowance is shared by successive entries rather than reset per entry.
    #[test]
    fn bounded_budget_charges_across_entries() -> Result<(), PackageError> {
        let mut budget = ExtractionLimits::bounded(10).budget();

        let mut six_bytes: &[u8] = &[0; 6];
        assert_eq!(budget.read_entry(&mut six_bytes)?.len(), 6);

        let mut four_bytes: &[u8] = &[0; 4];
        assert_eq!(budget.read_entry(&mut four_bytes)?.len(), 4);

        // The first two entries used the whole allowance, so one more byte is
        // one too many.
        let mut one_byte: &[u8] = &[0; 1];
        let outcome = budget.read_entry(&mut one_byte);
        assert!(matches!(
            outcome,
            Err(PackageError::InflatedSizeExceeded { limit: 10 })
        ));
        Ok(())
    }

    /// An entry larger than the whole allowance fails and names the configured cap.
    #[test]
    fn single_entry_past_budget_is_refused_reporting_the_limit() {
        let mut budget = ExtractionLimits::bounded(4).budget();

        let mut five_bytes: &[u8] = &[0; 5];
        let outcome = budget.read_entry(&mut five_bytes);
        assert!(matches!(
            outcome,
            Err(PackageError::InflatedSizeExceeded { limit: 4 })
        ));
    }

    /// Without a cap the whole entry comes back.
    #[test]
    fn unbounded_budget_reads_everything() -> Result<(), PackageError> {
        let mut budget = ExtractionLimits::unbounded().budget();

        let mut large_entry: &[u8] = &[0; 4096];
        let contents = budget.read_entry(&mut large_entry)?;
        assert_eq!(contents.len(), 4096);
        Ok(())
    }
}