use std::fmt;
use std::time::Duration;
/// Limits a server enforces on incoming requests and connections.
///
/// Every field is public and may be set directly; [`ServeLimits::default`]
/// supplies the baseline values. The type derives `PartialEq`/`Eq` so two
/// configurations can be compared with `==` or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServeLimits {
    /// Largest accepted request body, in bytes.
    pub max_request_size: usize,
    /// Largest accepted number of headers on a single request.
    pub max_headers: usize,
    /// Largest accepted header size, in bytes.
    pub max_header_size: usize,
    /// Keep-alive timeout. Not read by any code in this file; presumably how
    /// long an idle persistent connection is kept open (confirm against the
    /// connection handler).
    pub keep_alive_timeout: Duration,
    /// Client timeout. Not read by any code in this file; presumably how long
    /// the server waits on a slow client (confirm against the connection
    /// handler).
    pub client_timeout: Duration,
    /// Largest accepted number of pipelined requests.
    pub max_pipelined: usize,
    /// Connection count at which further connections are refused.
    pub max_connections: usize,
}
impl Default for ServeLimits {
fn default() -> Self {
Self {
max_request_size: 2 * 1024 * 1024, max_headers: 100,
max_header_size: 8 * 1024, keep_alive_timeout: Duration::from_secs(5),
client_timeout: Duration::from_secs(5),
max_pipelined: 16,
max_connections: 1024,
}
}
}
impl ServeLimits {
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_max_request_size(mut self, size: usize) -> Self {
self.max_request_size = size;
self
}
#[must_use]
pub fn with_max_headers(mut self, count: usize) -> Self {
self.max_headers = count;
self
}
#[must_use]
pub fn with_max_connections(mut self, count: usize) -> Self {
self.max_connections = count;
self
}
pub fn validate_request(
&self,
headers_count: usize,
body_size: usize,
) -> Result<(), LimitError> {
if headers_count > self.max_headers {
return Err(LimitError::TooManyHeaders { count: headers_count, max: self.max_headers });
}
if body_size > self.max_request_size {
return Err(LimitError::BodyTooLarge { size: body_size, max: self.max_request_size });
}
Ok(())
}
pub fn validate_header_size(&self, size: usize) -> Result<(), LimitError> {
if size > self.max_header_size {
return Err(LimitError::HeaderTooLarge { size, max: self.max_header_size });
}
Ok(())
}
pub fn validate_pipelined(&self, count: usize) -> Result<(), LimitError> {
if count > self.max_pipelined {
return Err(LimitError::TooManyPipelined { count, max: self.max_pipelined });
}
Ok(())
}
pub fn validate_connections(&self, current: usize) -> Result<(), LimitError> {
if current >= self.max_connections {
return Err(LimitError::ConnectionLimitReached { current, max: self.max_connections });
}
Ok(())
}
}
/// Reason a request or connection was rejected for breaking a limit.
///
/// Every variant carries the offending value together with the limit it was
/// compared against, so callers can report both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LimitError {
    /// The request carried more headers than allowed.
    TooManyHeaders {
        /// Number of headers on the request.
        count: usize,
        /// Configured maximum header count.
        max: usize,
    },
    /// The request body was larger than allowed.
    BodyTooLarge {
        /// Body size in bytes.
        size: usize,
        /// Configured maximum body size in bytes.
        max: usize,
    },
    /// A header was larger than allowed.
    HeaderTooLarge {
        /// Header size in bytes.
        size: usize,
        /// Configured maximum header size in bytes.
        max: usize,
    },
    /// More requests were pipelined than allowed.
    TooManyPipelined {
        /// Number of pipelined requests.
        count: usize,
        /// Configured maximum pipelined requests.
        max: usize,
    },
    /// The connection count has reached the configured maximum.
    ConnectionLimitReached {
        /// Connection count at the time of the check.
        current: usize,
        /// Configured maximum connection count.
        max: usize,
    },
}
impl fmt::Display for LimitError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LimitError::TooManyHeaders { count, max } => {
write!(f, "too many headers: {} (max {})", count, max)
}
LimitError::BodyTooLarge { size, max } => {
write!(f, "body too large: {} bytes (max {})", size, max)
}
LimitError::HeaderTooLarge { size, max } => {
write!(f, "header too large: {} bytes (max {})", size, max)
}
LimitError::TooManyPipelined { count, max } => {
write!(f, "too many pipelined requests: {} (max {})", count, max)
}
LimitError::ConnectionLimitReached { current, max } => {
write!(f, "connection limit reached: {} (max {})", current, max)
}
}
}
}
/// Lets `LimitError` be used wherever a `std::error::Error` is expected, for
/// example behind `Box<dyn std::error::Error>`. The body is empty, so every
/// trait method keeps its default implementation.
impl std::error::Error for LimitError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default is pinned, including both timeouts.
    #[test]
    fn test_serve_limits_default() {
        let limits = ServeLimits::default();
        assert_eq!(limits.max_request_size, 2 * 1024 * 1024);
        assert_eq!(limits.max_headers, 100);
        assert_eq!(limits.max_header_size, 8 * 1024);
        assert_eq!(limits.keep_alive_timeout, Duration::from_secs(5));
        assert_eq!(limits.client_timeout, Duration::from_secs(5));
        assert_eq!(limits.max_pipelined, 16);
        assert_eq!(limits.max_connections, 1024);
    }

    /// Builder setters change only the field they name.
    #[test]
    fn test_serve_limits_builder() {
        let limits = ServeLimits::new()
            .with_max_request_size(1024)
            .with_max_headers(50)
            .with_max_connections(100);
        assert_eq!(limits.max_request_size, 1024);
        assert_eq!(limits.max_headers, 50);
        assert_eq!(limits.max_connections, 100);
        // Fields without a setter call keep their defaults.
        assert_eq!(limits.max_header_size, 8 * 1024);
        assert_eq!(limits.max_pipelined, 16);
    }

    #[test]
    fn test_validate_request_ok() {
        let limits = ServeLimits::default();
        assert_eq!(limits.validate_request(50, 1024), Ok(()));
    }

    #[test]
    fn test_validate_request_too_many_headers() {
        let limits = ServeLimits::default();
        assert_eq!(
            limits.validate_request(200, 1024),
            Err(LimitError::TooManyHeaders { count: 200, max: 100 })
        );
    }

    #[test]
    fn test_validate_request_body_too_large() {
        let limits = ServeLimits::default();
        assert_eq!(
            limits.validate_request(50, 10 * 1024 * 1024),
            Err(LimitError::BodyTooLarge { size: 10 * 1024 * 1024, max: 2 * 1024 * 1024 })
        );
    }

    /// When both limits are broken, the header count is reported first.
    #[test]
    fn test_validate_request_reports_headers_first() {
        let limits = ServeLimits::default();
        assert_eq!(
            limits.validate_request(200, 10 * 1024 * 1024),
            Err(LimitError::TooManyHeaders { count: 200, max: 100 })
        );
    }

    #[test]
    fn test_validate_header_size_ok() {
        let limits = ServeLimits::default();
        assert_eq!(limits.validate_header_size(1024), Ok(()));
    }

    #[test]
    fn test_validate_header_size_too_large() {
        let limits = ServeLimits::default();
        assert_eq!(
            limits.validate_header_size(16 * 1024),
            Err(LimitError::HeaderTooLarge { size: 16 * 1024, max: 8 * 1024 })
        );
    }

    #[test]
    fn test_validate_pipelined_ok() {
        let limits = ServeLimits::default();
        assert_eq!(limits.validate_pipelined(10), Ok(()));
    }

    #[test]
    fn test_validate_pipelined_too_many() {
        let limits = ServeLimits::default();
        assert_eq!(
            limits.validate_pipelined(20),
            Err(LimitError::TooManyPipelined { count: 20, max: 16 })
        );
    }

    #[test]
    fn test_validate_connections_ok() {
        let limits = ServeLimits::default();
        assert_eq!(limits.validate_connections(500), Ok(()));
    }

    #[test]
    fn test_validate_connections_limit_reached() {
        let limits = ServeLimits::default();
        assert_eq!(
            limits.validate_connections(1024),
            Err(LimitError::ConnectionLimitReached { current: 1024, max: 1024 })
        );
    }

    /// Display output is pinned for every variant.
    #[test]
    fn test_limit_error_display() {
        let err = LimitError::TooManyHeaders { count: 150, max: 100 };
        assert_eq!(err.to_string(), "too many headers: 150 (max 100)");

        let err = LimitError::BodyTooLarge { size: 5000000, max: 2097152 };
        assert_eq!(err.to_string(), "body too large: 5000000 bytes (max 2097152)");

        let err = LimitError::HeaderTooLarge { size: 16384, max: 8192 };
        assert_eq!(err.to_string(), "header too large: 16384 bytes (max 8192)");

        let err = LimitError::TooManyPipelined { count: 20, max: 16 };
        assert_eq!(err.to_string(), "too many pipelined requests: 20 (max 16)");

        let err = LimitError::ConnectionLimitReached { current: 1024, max: 1024 };
        assert_eq!(err.to_string(), "connection limit reached: 1024 (max 1024)");
    }

    /// The error can be boxed as a trait object and keeps its message.
    #[test]
    fn test_limit_error_as_dyn_error() {
        let boxed: Box<dyn std::error::Error> =
            Box::new(LimitError::TooManyPipelined { count: 3, max: 2 });
        assert_eq!(boxed.to_string(), "too many pipelined requests: 3 (max 2)");
        assert!(boxed.source().is_none());
    }

    #[test]
    fn test_limit_error_eq() {
        let err1 = LimitError::TooManyHeaders { count: 150, max: 100 };
        let err2 = LimitError::TooManyHeaders { count: 150, max: 100 };
        let err3 = LimitError::TooManyHeaders { count: 200, max: 100 };
        assert_eq!(err1, err2);
        assert_ne!(err1, err3);
    }

    /// A value equal to a limit passes every check except the connection
    /// check, which rejects at the limit itself.
    #[test]
    fn test_falsify_exact_boundaries() {
        let mut limits = ServeLimits::new()
            .with_max_headers(100)
            .with_max_request_size(1000)
            .with_max_connections(10);
        limits.max_header_size = 64;
        limits.max_pipelined = 4;

        assert!(limits.validate_request(100, 0).is_ok());
        assert!(limits.validate_request(101, 0).is_err());
        assert!(limits.validate_request(0, 1000).is_ok());
        assert!(limits.validate_request(0, 1001).is_err());
        assert!(limits.validate_header_size(64).is_ok());
        assert!(limits.validate_header_size(65).is_err());
        assert!(limits.validate_pipelined(4).is_ok());
        assert!(limits.validate_pipelined(5).is_err());
        assert!(limits.validate_connections(9).is_ok());
        assert!(limits.validate_connections(10).is_err());
    }

    #[test]
    fn test_falsify_zero_values_pass() {
        let limits = ServeLimits::default();
        assert!(
            limits.validate_request(0, 0).is_ok(),
            "FALSIFICATION FAILED: Zero values should always pass"
        );
        assert!(
            limits.validate_header_size(0).is_ok(),
            "FALSIFICATION FAILED: Zero header size should pass"
        );
        assert!(
            limits.validate_pipelined(0).is_ok(),
            "FALSIFICATION FAILED: Zero pipelined should pass"
        );
        assert!(
            limits.validate_connections(0).is_ok(),
            "FALSIFICATION FAILED: Zero connections should pass"
        );
    }
}