use std::time::{Duration, Instant};
/// Configurable ceilings for response size, decompression, parse time and DOM shape.
///
/// Each `check_*` method on this type compares a caller-supplied measurement against the
/// matching field and rejects values strictly greater than it.
#[derive(Debug, Clone)]
pub struct ResourceLimits {
    /// Largest accepted response size (presumably bytes — confirm with callers).
    pub max_response_size: u64,
    /// Longest accepted parse duration, in milliseconds.
    pub max_parse_time_ms: u64,
    /// Largest accepted size after decompression (presumably bytes — confirm with callers).
    pub max_decompressed_size: u64,
    /// Largest accepted `decompressed / compressed` size ratio.
    pub max_compression_ratio: f64,
    /// Deepest accepted DOM nesting level.
    pub max_dom_depth: usize,
    /// Largest accepted number of DOM elements.
    pub max_dom_elements: usize,
}

impl Default for ResourceLimits {
    /// 50 MiB response, 100 MiB decompressed, 100x ratio, 30 s parse, depth 100, 100k elements.
    fn default() -> Self {
        // 1024 * 1024, i.e. one MiB if the size fields are byte counts.
        const MIB: u64 = 1024 * 1024;

        ResourceLimits {
            max_response_size: 50 * MIB,
            max_decompressed_size: 100 * MIB,
            max_compression_ratio: 100.0,
            max_parse_time_ms: 30_000,
            max_dom_depth: 100,
            max_dom_elements: 100_000,
        }
    }
}
impl ResourceLimits {
pub fn new() -> Self {
Self::default()
}
pub fn check_response_size(&self, size: u64) -> crate::types::error::Result<()> {
if size > self.max_response_size {
Err(crate::types::error::Error::SizeExceeded {
max: self.max_response_size,
actual: size,
})
} else {
Ok(())
}
}
pub fn check_response_size_limit(&self, size: u64) -> Result<(), LimitError> {
if size > self.max_response_size {
Err(LimitError::ResponseSizeExceeded {
size,
max: self.max_response_size,
})
} else {
Ok(())
}
}
pub fn check_decompressed_size(&self, size: u64) -> Result<(), LimitError> {
if size > self.max_decompressed_size {
Err(LimitError::DecompressedSizeExceeded {
size,
max: self.max_decompressed_size,
})
} else {
Ok(())
}
}
pub fn check_compression_ratio(&self, compressed: u64, decompressed: u64) -> Result<(), LimitError> {
if compressed == 0 {
return Ok(());
}
let ratio = decompressed as f64 / compressed as f64;
if ratio > self.max_compression_ratio {
Err(LimitError::CompressionRatioExceeded {
ratio,
max: self.max_compression_ratio,
})
} else {
Ok(())
}
}
pub fn check_parse_time(&self, elapsed_ms: u64) -> crate::types::error::Result<()> {
if elapsed_ms > self.max_parse_time_ms {
Err(crate::types::error::Error::ResourceLimit(format!(
"Parse time {}ms exceeds limit {}ms",
elapsed_ms, self.max_parse_time_ms
)))
} else {
Ok(())
}
}
pub fn check_dom_depth(&self, depth: usize) -> Result<(), LimitError> {
if depth > self.max_dom_depth {
Err(LimitError::DomDepthExceeded {
depth,
max: self.max_dom_depth,
})
} else {
Ok(())
}
}
pub fn check_dom_elements(&self, count: usize) -> Result<(), LimitError> {
if count > self.max_dom_elements {
Err(LimitError::DomElementsExceeded {
count,
max: self.max_dom_elements,
})
} else {
Ok(())
}
}
}
/// A measured quantity exceeded one of the configured resource limits.
///
/// Every variant carries the observed value alongside the limit it crossed.
#[derive(Debug, Clone)]
pub enum LimitError {
    /// A response size was above its limit.
    ResponseSizeExceeded { size: u64, max: u64 },
    /// A decompressed size was above its limit.
    DecompressedSizeExceeded { size: u64, max: u64 },
    /// A decompressed/compressed ratio was above its limit.
    CompressionRatioExceeded { ratio: f64, max: f64 },
    /// Parsing ran longer than its budget, both values in milliseconds.
    ParseTimeExceeded { elapsed_ms: u64, max_ms: u64 },
    /// DOM nesting was deeper than allowed.
    DomDepthExceeded { depth: usize, max: usize },
    /// The DOM held more elements than allowed.
    DomElementsExceeded { count: usize, max: usize },
}

impl std::fmt::Display for LimitError {
    /// Renders a one-line "<what> <value> exceeds limit <max>" message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::ResponseSizeExceeded { size, max } => {
                write!(f, "Response size {size} exceeds limit {max}")
            }
            Self::DecompressedSizeExceeded { size, max } => {
                write!(f, "Decompressed size {size} exceeds limit {max}")
            }
            Self::CompressionRatioExceeded { ratio, max } => {
                // Ratios are shown with two decimal places.
                write!(f, "Compression ratio {ratio:.2} exceeds limit {max:.2}")
            }
            Self::ParseTimeExceeded { elapsed_ms, max_ms } => {
                write!(f, "Parse time {elapsed_ms}ms exceeds limit {max_ms}ms")
            }
            Self::DomDepthExceeded { depth, max } => {
                write!(f, "DOM depth {depth} exceeds limit {max}")
            }
            Self::DomElementsExceeded { count, max } => {
                write!(f, "DOM elements {count} exceeds limit {max}")
            }
        }
    }
}

impl std::error::Error for LimitError {}
/// Wall-clock budget tracker for a single parse operation.
///
/// The clock starts when the timer is constructed with [`ParseTimer::new`]; the other
/// methods measure against that instant.
#[derive(Debug, Clone)]
pub struct ParseTimer {
    started_at: Instant,
    max_duration: Duration,
}

impl ParseTimer {
    /// Starts a timer now with a budget of `max_ms` milliseconds.
    pub fn new(max_ms: u64) -> Self {
        Self {
            started_at: Instant::now(),
            max_duration: Duration::from_millis(max_ms),
        }
    }

    /// Checks whether the budget has been overrun.
    ///
    /// Passes while elapsed time is at most the budget (the comparison is strict).
    ///
    /// # Errors
    /// [`LimitError::ParseTimeExceeded`] with the elapsed time and the budget, both in
    /// whole milliseconds, once elapsed time is strictly greater than the budget.
    pub fn check(&self) -> Result<(), LimitError> {
        let elapsed = self.started_at.elapsed();
        if elapsed > self.max_duration {
            return Err(LimitError::ParseTimeExceeded {
                elapsed_ms: Self::millis_saturating(elapsed),
                max_ms: Self::millis_saturating(self.max_duration),
            });
        }
        Ok(())
    }

    /// Whole milliseconds elapsed since the timer was created.
    pub fn elapsed_ms(&self) -> u64 {
        Self::millis_saturating(self.started_at.elapsed())
    }

    /// Whole milliseconds left in the budget; 0 once the budget is used up.
    pub fn remaining_ms(&self) -> u64 {
        // `saturating_sub` yields `Duration::ZERO` when elapsed >= budget.
        Self::millis_saturating(self.max_duration.saturating_sub(self.started_at.elapsed()))
    }

    /// Converts a duration to whole milliseconds, clamping at `u64::MAX`.
    fn millis_saturating(d: Duration) -> u64 {
        // `as_millis` is u128; clamp first so the narrowing cast can never silently truncate.
        d.as_millis().min(u128::from(u64::MAX)) as u64
    }
}