/// Caller-supplied caps that the parsers in this crate check declared sizes against.
///
/// Four independent `u64` axes are tracked. [`ParseLimits::unbounded`] (which is
/// also the [`Default`]) sets every axis to `u64::MAX`; the `with_*` builders
/// replace one axis at a time and the same-named getters read them back.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): all four fields share the `max_` prefix, which is presumably what
// this allow is silencing -- confirm before removing it.
#[allow(clippy::struct_field_names)]
pub struct ParseLimits {
/// Largest value `check_alloc` accepts for a single request.
max_single_alloc_bytes: u64,
/// Largest running total `Budget::charge_alloc` accepts.
max_total_bytes: u64,
/// Largest value `check_item_count` accepts.
max_item_count: u64,
/// Multiplier used by `check_decompression_ratio`: the uncompressed size may
/// not exceed this value times the compressed size. `u64::MAX` disables the check.
max_decompression_ratio: u64,
}
impl ParseLimits {
    /// Builds limits in which every axis is `u64::MAX`.
    ///
    /// With these values none of the `check_*` methods on this type can return
    /// an error.
    #[must_use]
    pub const fn unbounded() -> Self {
        const NO_CAP: u64 = u64::MAX;
        Self {
            max_single_alloc_bytes: NO_CAP,
            max_total_bytes: NO_CAP,
            max_item_count: NO_CAP,
            max_decompression_ratio: NO_CAP,
        }
    }

    /// Returns `self` with the single-allocation cap replaced by `bytes`.
    #[must_use]
    pub const fn with_max_single_alloc(self, bytes: u64) -> Self {
        Self {
            max_single_alloc_bytes: bytes,
            ..self
        }
    }

    /// Returns `self` with the cumulative byte cap replaced by `bytes`.
    #[must_use]
    pub const fn with_max_total_bytes(self, bytes: u64) -> Self {
        Self {
            max_total_bytes: bytes,
            ..self
        }
    }

    /// Returns `self` with the item-count cap replaced by `count`.
    #[must_use]
    pub const fn with_max_item_count(self, count: u64) -> Self {
        Self {
            max_item_count: count,
            ..self
        }
    }

    /// Returns `self` with the decompression-ratio cap replaced by `ratio`.
    ///
    /// Passing `u64::MAX` turns the ratio check off again.
    #[must_use]
    pub const fn with_max_decompression_ratio(self, ratio: u64) -> Self {
        Self {
            max_decompression_ratio: ratio,
            ..self
        }
    }

    /// Current single-allocation cap, in bytes.
    #[must_use]
    pub const fn max_single_alloc_bytes(&self) -> u64 {
        self.max_single_alloc_bytes
    }

    /// Current cumulative byte cap.
    #[must_use]
    pub const fn max_total_bytes(&self) -> u64 {
        self.max_total_bytes
    }

    /// Current item-count cap.
    #[must_use]
    pub const fn max_item_count(&self) -> u64 {
        self.max_item_count
    }

    /// Current decompression-ratio cap.
    #[must_use]
    pub const fn max_decompression_ratio(&self) -> u64 {
        self.max_decompression_ratio
    }

    /// Accepts `requested` when it does not exceed the single-allocation cap.
    ///
    /// `context` is only used to label the error message.
    ///
    /// # Errors
    ///
    /// Returns `AnamnesisError::Parse` when `requested` is strictly greater than
    /// the cap; a request exactly equal to the cap is accepted.
    pub(crate) fn check_alloc(&self, requested: u64, context: &str) -> crate::Result<()> {
        if requested <= self.max_single_alloc_bytes {
            return Ok(());
        }
        Err(crate::AnamnesisError::Parse {
            reason: format!(
                "requested allocation {requested} bytes exceeds caller \
                 ParseLimits max_single_alloc {} ({context})",
                self.max_single_alloc_bytes
            ),
        })
    }

    /// Accepts `count` when it does not exceed the item-count cap.
    ///
    /// `context` is only used to label the error message.
    ///
    /// # Errors
    ///
    /// Returns `AnamnesisError::Parse` when `count` is strictly greater than the
    /// cap; a count exactly equal to the cap is accepted.
    #[cfg_attr(
        not(any(feature = "npz", feature = "pth", feature = "gguf")),
        allow(dead_code)
    )]
    pub(crate) fn check_item_count(&self, count: u64, context: &str) -> crate::Result<()> {
        if count <= self.max_item_count {
            return Ok(());
        }
        Err(crate::AnamnesisError::Parse {
            reason: format!(
                "declared item count {count} exceeds caller ParseLimits \
                 max_item_count {} ({context})",
                self.max_item_count
            ),
        })
    }

    /// Accepts the pair when `uncompressed <= max_decompression_ratio * compressed`.
    ///
    /// The comparison is done by multiplication, so no division (and no
    /// divide-by-zero on `compressed == 0`) is involved. `context` is only used
    /// to label the error message.
    ///
    /// # Errors
    ///
    /// Returns `AnamnesisError::Parse` when the product fits in a `u64` and
    /// `uncompressed` is strictly greater than it.
    #[cfg_attr(not(feature = "npz"), allow(dead_code))]
    pub(crate) fn check_decompression_ratio(
        &self,
        uncompressed: u64,
        compressed: u64,
        context: &str,
    ) -> crate::Result<()> {
        let ratio = self.max_decompression_ratio;
        // `u64::MAX` is the "unbounded" value: skip the check entirely.
        if ratio == u64::MAX {
            return Ok(());
        }
        // If `ratio * compressed` overflows `u64`, no `u64` size can exceed it,
        // so an overflowing product counts as within the limit.
        let over_limit = matches!(
            ratio.checked_mul(compressed),
            Some(allowed) if uncompressed > allowed
        );
        if !over_limit {
            return Ok(());
        }
        Err(crate::AnamnesisError::Parse {
            reason: format!(
                "decompression ratio (uncompressed {uncompressed} / compressed \
                 {compressed}) exceeds caller ParseLimits max_decompression_ratio \
                 {} ({context})",
                self.max_decompression_ratio
            ),
        })
    }
}
impl Default for ParseLimits {
    /// Returns [`ParseLimits::unbounded`]: every axis is `u64::MAX`.
    ///
    /// Written by hand because a derived `Default` would set every `u64` field
    /// to zero instead.
    fn default() -> Self {
        Self::unbounded()
    }
}
/// Running byte total that is checked against a private copy of a [`ParseLimits`].
///
/// Each accepted `charge_alloc` call adds to the total; rejected calls leave it
/// unchanged.
pub(crate) struct Budget {
/// Limits cloned from the reference passed to `Budget::new`.
limits: ParseLimits,
/// Sum of every size accepted so far by `charge_alloc`; starts at zero.
total_bytes: u64,
}
impl Budget {
    /// Starts a budget at zero bytes, holding its own clone of `limits`.
    pub(crate) fn new(limits: &ParseLimits) -> Self {
        let limits = limits.clone();
        Self {
            limits,
            total_bytes: 0,
        }
    }

    /// Starts a budget backed by [`ParseLimits::unbounded`].
    #[cfg_attr(not(feature = "npz"), allow(dead_code))]
    pub(crate) fn unbounded() -> Self {
        let no_caps = ParseLimits::unbounded();
        Self::new(&no_caps)
    }

    /// Records `bytes` against the budget if every cap allows it.
    ///
    /// The checks run in this order: the per-request cap (`check_alloc`), then
    /// `u64` overflow of the running total, then the cumulative cap. The running
    /// total is only updated when all three pass. `context` is only used to
    /// label error messages.
    ///
    /// # Errors
    ///
    /// Returns whatever `ParseLimits::check_alloc` returns for `bytes`, or
    /// `AnamnesisError::Parse` when adding `bytes` overflows `u64` or pushes the
    /// total strictly above `max_total_bytes`.
    pub(crate) fn charge_alloc(&mut self, bytes: u64, context: &str) -> crate::Result<()> {
        self.limits.check_alloc(bytes, context)?;

        let overflow = || crate::AnamnesisError::Parse {
            reason: format!("aggregate byte total overflow charging {bytes} ({context})"),
        };
        let new_total = self.total_bytes.checked_add(bytes).ok_or_else(overflow)?;

        if new_total <= self.limits.max_total_bytes {
            // Commit only after every check has passed.
            self.total_bytes = new_total;
            return Ok(());
        }
        Err(crate::AnamnesisError::Parse {
            reason: format!(
                "cumulative declared bytes {new_total} exceeds caller ParseLimits \
                 max_total_bytes {} ({context})",
                self.limits.max_total_bytes
            ),
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{Budget, ParseLimits};

    /// `Default` equals `unbounded()`, and every axis reads back as `u64::MAX`.
    #[test]
    fn default_is_unbounded() {
        let defaults = ParseLimits::default();
        assert_eq!(defaults, ParseLimits::unbounded());
        assert_eq!(defaults.max_single_alloc_bytes(), u64::MAX);
        assert_eq!(defaults.max_total_bytes(), u64::MAX);
        assert_eq!(defaults.max_item_count(), u64::MAX);
        assert_eq!(defaults.max_decompression_ratio(), u64::MAX);
    }

    /// Exercises the ratio check at, above and around its edge cases.
    #[test]
    fn check_decompression_ratio_boundary() {
        let capped = ParseLimits::default().with_max_decompression_ratio(100);

        // Exactly 100x is accepted; one byte more is rejected.
        assert!(capped.check_decompression_ratio(1000, 10, "ctx").is_ok());
        let err = capped
            .check_decompression_ratio(1001, 10, "ctx")
            .unwrap_err();
        assert!(
            matches!(err, crate::AnamnesisError::Parse { ref reason } if reason.contains("max_decompression_ratio")),
            "expected ratio error, got: {err}"
        );

        // A zero compressed size allows only a zero uncompressed size.
        assert!(capped.check_decompression_ratio(0, 0, "ctx").is_ok());
        assert!(capped.check_decompression_ratio(1, 0, "ctx").is_err());

        // `(u64::MAX / 2) * 4` overflows, which is treated as "within the limit".
        let nearly_unbounded = ParseLimits::default().with_max_decompression_ratio(u64::MAX / 2);
        assert!(nearly_unbounded
            .check_decompression_ratio(10, 4, "ctx")
            .is_ok());

        // The default (unbounded) ratio accepts anything.
        assert!(ParseLimits::default()
            .check_decompression_ratio(u64::MAX, 1, "ctx")
            .is_ok());
    }

    /// Each builder changes its own axis and leaves the other three at `u64::MAX`.
    #[test]
    fn builders_set_only_their_axis() {
        let alloc_only = ParseLimits::default().with_max_single_alloc(1024);
        assert_eq!(alloc_only.max_single_alloc_bytes(), 1024);
        assert_eq!(alloc_only.max_total_bytes(), u64::MAX);
        assert_eq!(alloc_only.max_item_count(), u64::MAX);
        assert_eq!(alloc_only.max_decompression_ratio(), u64::MAX);

        let total_only = ParseLimits::default().with_max_total_bytes(4096);
        assert_eq!(total_only.max_total_bytes(), 4096);
        assert_eq!(total_only.max_single_alloc_bytes(), u64::MAX);
        assert_eq!(total_only.max_item_count(), u64::MAX);
        assert_eq!(total_only.max_decompression_ratio(), u64::MAX);

        let count_only = ParseLimits::default().with_max_item_count(8);
        assert_eq!(count_only.max_item_count(), 8);
        assert_eq!(count_only.max_single_alloc_bytes(), u64::MAX);
        assert_eq!(count_only.max_total_bytes(), u64::MAX);
        assert_eq!(count_only.max_decompression_ratio(), u64::MAX);

        let ratio_only = ParseLimits::default().with_max_decompression_ratio(1000);
        assert_eq!(ratio_only.max_decompression_ratio(), 1000);
        assert_eq!(ratio_only.max_single_alloc_bytes(), u64::MAX);
        assert_eq!(ratio_only.max_total_bytes(), u64::MAX);
        assert_eq!(ratio_only.max_item_count(), u64::MAX);
    }

    /// Three charges that each pass the per-request cap still trip the total cap.
    #[test]
    fn budget_aggregate_catches_what_per_item_misses() {
        let caps = ParseLimits::default()
            .with_max_single_alloc(1000)
            .with_max_total_bytes(250);
        let mut budget = Budget::new(&caps);

        // 100 + 100 = 200 stays under 250.
        assert!(budget.charge_alloc(100, "a").is_ok());
        assert!(budget.charge_alloc(100, "b").is_ok());

        // The third charge would reach 300.
        let err = budget.charge_alloc(100, "c").unwrap_err();
        assert!(
            matches!(err, crate::AnamnesisError::Parse { ref reason } if reason.contains("max_total_bytes")),
            "expected aggregate error, got: {err}"
        );
    }

    /// The per-request cap is enforced even when the total cap is unbounded.
    #[test]
    fn budget_per_item_cap_still_applies() {
        let caps = ParseLimits::default()
            .with_max_single_alloc(50)
            .with_max_total_bytes(u64::MAX);
        let mut budget = Budget::new(&caps);

        let err = budget.charge_alloc(51, "x").unwrap_err();
        assert!(
            matches!(err, crate::AnamnesisError::Parse { ref reason } if reason.contains("max_single_alloc")),
            "expected single-alloc error, got: {err}"
        );
    }

    /// With no caps the only possible failure is `u64` overflow of the total.
    #[test]
    fn budget_unbounded_charges_until_overflow() {
        let mut budget = Budget::unbounded();
        assert!(budget.charge_alloc(u64::MAX, "x").is_ok());

        let err = budget.charge_alloc(1, "y").unwrap_err();
        assert!(
            matches!(err, crate::AnamnesisError::Parse { ref reason } if reason.contains("overflow")),
            "expected overflow error, got: {err}"
        );
    }

    /// The single-allocation cap is inclusive.
    #[test]
    fn check_alloc_boundary() {
        let capped = ParseLimits::default().with_max_single_alloc(1024);
        assert!(capped.check_alloc(1024, "ctx").is_ok());
        assert!(capped.check_alloc(1025, "ctx").is_err());

        let open = ParseLimits::default();
        assert!(open.check_alloc(u64::MAX, "ctx").is_ok());
    }

    /// The item-count cap is inclusive.
    #[test]
    fn check_item_count_boundary() {
        let capped = ParseLimits::default().with_max_item_count(4);
        assert!(capped.check_item_count(4, "ctx").is_ok());
        assert!(capped.check_item_count(5, "ctx").is_err());

        let open = ParseLimits::default();
        assert!(open.check_item_count(u64::MAX, "ctx").is_ok());
    }
}