use crate::auth::{AuthError, AuthState, Claims, Permission};
use axum::{
body::Body,
extract::{Request, State},
http::{header, HeaderMap, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use uuid::Uuid;
/// Cross-origin resource sharing (CORS) policy consumed by the CORS middleware.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins permitted to make cross-origin requests. `"*"` permits every origin, and an
    /// empty set is treated the same way by `is_origin_allowed`.
    pub allowed_origins: HashSet<String>,
    /// Methods advertised in the preflight `Access-Control-Allow-Methods` header.
    pub allowed_methods: HashSet<Method>,
    /// Request headers advertised in the preflight `Access-Control-Allow-Headers` header.
    pub allowed_headers: HashSet<String>,
    /// Response headers listed in `Access-Control-Expose-Headers`.
    pub exposed_headers: HashSet<String>,
    /// Whether `Access-Control-Allow-Credentials: true` is sent.
    pub allow_credentials: bool,
    /// Preflight cache lifetime in seconds (`Access-Control-Max-Age`).
    pub max_age: u64,
}

impl Default for CorsConfig {
    /// No configured origins, the common REST verbs plus `OPTIONS`/`HEAD`, a small set of
    /// standard request headers, no credentials, and a one-day preflight cache.
    fn default() -> Self {
        let allowed_methods: HashSet<Method> = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::OPTIONS,
            Method::HEAD,
        ]
        .iter()
        .cloned()
        .collect();

        let allowed_headers: HashSet<String> = [
            "content-type",
            "authorization",
            "accept",
            "origin",
            "x-requested-with",
        ]
        .iter()
        .map(|name| name.to_string())
        .collect();

        Self {
            allowed_origins: HashSet::new(),
            allowed_methods,
            allowed_headers,
            exposed_headers: HashSet::new(),
            allow_credentials: false,
            max_age: 86400, // one day
        }
    }
}
impl CorsConfig {
pub fn permissive() -> Self {
let mut config = Self::default();
config.allowed_origins.insert("*".to_string());
config
}
pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
self.allowed_origins.insert(origin.into());
self
}
pub fn allow_credentials(mut self, allow: bool) -> Self {
self.allow_credentials = allow;
self
}
fn is_origin_allowed(&self, origin: &str) -> bool {
if self.allowed_origins.is_empty() || self.allowed_origins.contains("*") {
true
} else {
self.allowed_origins.contains(origin)
}
}
fn methods_string(&self) -> String {
self.allowed_methods
.iter()
.map(|m| m.as_str())
.collect::<Vec<_>>()
.join(", ")
}
fn headers_string(&self) -> String {
self.allowed_headers
.iter()
.cloned()
.collect::<Vec<_>>()
.join(", ")
}
}
/// Router state carrying the CORS policy used by [`cors_middleware`].
#[derive(Clone)]
pub struct CorsState {
    /// Policy applied to every request passing through the middleware.
    pub config: CorsConfig,
}

/// Applies the CORS policy to a request.
///
/// Every `OPTIONS` request is answered directly with a preflight response and never reaches
/// the inner handler. All other requests are forwarded, and CORS headers are added to the
/// handler's response.
///
/// NOTE(review): a real (non-preflight) `OPTIONS` route behind this layer is unreachable;
/// preflights are normally recognised by an `Access-Control-Request-Method` header.
pub async fn cors_middleware(
    State(cors_state): State<CorsState>,
    req: Request,
    next: Next,
) -> Response {
    let config = &cors_state.config;

    // Copy the Origin out of the request before it is moved into the next handler.
    let request_origin: Option<String> = req
        .headers()
        .get(header::ORIGIN)
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);

    if req.method() == Method::OPTIONS {
        return build_preflight_response(config, request_origin.as_deref());
    }

    let mut response = next.run(req).await;
    add_cors_headers(response.headers_mut(), config, request_origin.as_deref());
    response
}
/// Builds the `204 No Content` answer to a CORS preflight request.
///
/// The response carries the per-request CORS headers from `add_cors_headers` plus the
/// preflight-only `Access-Control-Allow-Methods`, `Access-Control-Allow-Headers` and
/// `Access-Control-Max-Age` headers. A value that cannot be encoded as a header is skipped.
fn build_preflight_response(config: &CorsConfig, origin: Option<&str>) -> Response {
    let mut response = Response::new(Body::empty());
    *response.status_mut() = StatusCode::NO_CONTENT;

    add_cors_headers(response.headers_mut(), config, origin);

    let preflight_headers = [
        (header::ACCESS_CONTROL_ALLOW_METHODS, config.methods_string()),
        (header::ACCESS_CONTROL_ALLOW_HEADERS, config.headers_string()),
        (header::ACCESS_CONTROL_MAX_AGE, config.max_age.to_string()),
    ];
    let headers = response.headers_mut();
    for (name, text) in preflight_headers {
        if let Ok(value) = HeaderValue::from_str(&text) {
            headers.insert(name, value);
        }
    }
    response
}
/// Adds the per-response CORS headers to `headers`.
///
/// The value of `Access-Control-Allow-Origin` is chosen as follows:
/// * A wildcard policy without credentials sends `*`.
/// * Otherwise an allowed request origin is echoed back verbatim. Wildcard plus credentials
///   also echoes, because browsers reject `*` on credentialed requests.
/// * A disallowed origin, or a request without `Origin` under a non-wildcard policy, gets no
///   Allow-Origin header, so the browser blocks the cross-origin read.
///
/// When the outcome depends on the request's `Origin`, `Vary: origin` is appended. This stops
/// shared caches from replaying one site's CORS response to another site.
fn add_cors_headers(headers: &mut HeaderMap, config: &CorsConfig, origin: Option<&str>) {
    let wildcard = config.allowed_origins.contains("*");

    // Only "wildcard without credentials" yields a response that is independent of Origin.
    // `append`, not `insert`, keeps any Vary values the handler already set.
    if !wildcard || config.allow_credentials {
        headers.append(header::VARY, HeaderValue::from_static("origin"));
    }

    let origin_value = match origin {
        Some(origin) if config.is_origin_allowed(origin) => {
            if wildcard && !config.allow_credentials {
                "*"
            } else {
                origin
            }
        }
        Some(_) => return,
        None if wildcard => "*",
        None => return,
    };

    if let Ok(value) = HeaderValue::from_str(origin_value) {
        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
    }

    if config.allow_credentials {
        headers.insert(
            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
            HeaderValue::from_static("true"),
        );
    }

    if !config.exposed_headers.is_empty() {
        // Sorted so the header value does not depend on HashSet iteration order.
        let mut exposed: Vec<&str> = config.exposed_headers.iter().map(String::as_str).collect();
        exposed.sort_unstable();
        if let Ok(value) = HeaderValue::from_str(&exposed.join(", ")) {
            headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, value);
        }
    }
}
/// Parameters of the per-client token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Tokens replenished per `window`. This sets the sustained request rate.
    pub max_requests: u32,
    /// Period over which `max_requests` tokens are replenished.
    pub window: Duration,
    /// Bucket capacity, i.e. the largest burst a client can issue back-to-back.
    pub burst_capacity: u32,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute with bursts of up to 10.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window: Duration::from_secs(60),
            burst_capacity: 10,
        }
    }
}

impl RateLimitConfig {
    /// Checks that the configuration is usable by the limiter.
    ///
    /// # Errors
    /// Returns a human-readable message when `max_requests`, `window` or `burst_capacity` is
    /// zero, or when `burst_capacity` exceeds `max_requests`.
    pub fn validate(&self) -> Result<(), String> {
        if self.max_requests == 0 {
            return Err("Maximum requests must be greater than 0".to_string());
        }
        // `is_zero` rather than `as_secs() == 0`: sub-second windows such as 500 ms are valid
        // (the refill rate is derived with `as_secs_f64`) and must not be rejected.
        if self.window.is_zero() {
            return Err("Time window must be greater than 0".to_string());
        }
        if self.burst_capacity == 0 {
            return Err("Burst capacity must be greater than 0".to_string());
        }
        if self.burst_capacity > self.max_requests {
            return Err(format!(
                "Burst capacity ({}) cannot exceed max requests ({})",
                self.burst_capacity, self.max_requests
            ));
        }
        Ok(())
    }
}
/// A single client's token bucket. It refills continuously and is capped at `capacity`.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional, so partial refills accumulate).
    tokens: f64,
    /// Instant of the most recent refill.
    last_update: Instant,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens.
    fn new(capacity: u32, refill_rate: f64) -> Self {
        let full = f64::from(capacity);
        Self {
            tokens: full,
            last_update: Instant::now(),
            capacity: full,
            refill_rate,
        }
    }

    /// Refills, then consumes one token if available. Returns whether a token was taken.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }

    /// Credits the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        self.tokens = self.capacity.min(self.tokens + earned);
        self.last_update = now;
    }

    /// Whole tokens left; the fractional part is truncated.
    fn tokens_remaining(&self) -> u32 {
        self.tokens as u32
    }
}
/// Shared, cloneable rate-limiter state: the configuration plus one token bucket per client
/// key, guarded by an async mutex.
#[derive(Clone)]
pub struct RateLimitState {
    config: RateLimitConfig,
    buckets: Arc<Mutex<std::collections::HashMap<String, TokenBucket>>>,
}

impl RateLimitState {
    /// Number of tracked clients at or above which idle buckets are swept before a new
    /// client's bucket is inserted.
    const PRUNE_THRESHOLD: usize = 10_000;

    /// Creates a limiter with no tracked clients.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
        }
    }

    /// Tries to spend one token from the bucket for `ip`, creating a full bucket on first use.
    ///
    /// Returns `(allowed, whole_tokens_remaining)`.
    ///
    /// Without pruning, the map would grow by one entry per distinct key forever. Keys come
    /// from client-supplied headers, so this is an easy memory-exhaustion vector. A bucket that
    /// has refilled to capacity behaves exactly like a freshly created one, so dropping it is
    /// lossless. The sweep is O(n) and runs only when the map is large and a new key arrives.
    async fn get_bucket(&self, ip: &str) -> (bool, u32) {
        let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
        let mut buckets = self.buckets.lock().await;

        if buckets.len() >= Self::PRUNE_THRESHOLD && !buckets.contains_key(ip) {
            buckets.retain(|_, bucket| {
                bucket.refill();
                bucket.tokens < bucket.capacity
            });
        }

        let bucket = buckets
            .entry(ip.to_string())
            .or_insert_with(|| TokenBucket::new(self.config.burst_capacity, refill_rate));
        let allowed = bucket.try_acquire();
        (allowed, bucket.tokens_remaining())
    }
}
/// Rejects a request with [`RateLimitError::TooManyRequests`] when its client key has no
/// tokens left.
///
/// Otherwise the request is forwarded, and the response is annotated with
/// `X-RateLimit-Limit` (the configured `max_requests`) and `X-RateLimit-Remaining` (whole
/// tokens left after this request).
pub async fn rate_limit_middleware(
    State(rate_state): State<RateLimitState>,
    req: Request,
    next: Next,
) -> Result<Response, RateLimitError> {
    let client_key = extract_client_ip(&req);
    let (allowed, remaining) = rate_state.get_bucket(&client_key).await;
    if !allowed {
        return Err(RateLimitError::TooManyRequests);
    }

    let mut response = next.run(req).await;
    let headers = response.headers_mut();
    // Integer header values are always valid, so the infallible `From` conversion is used.
    headers.insert(
        "X-RateLimit-Limit",
        HeaderValue::from(rate_state.config.max_requests),
    );
    headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
    Ok(response)
}
fn extract_client_ip(req: &Request) -> String {
if let Some(forwarded) = req.headers().get("x-forwarded-for") {
if let Ok(s) = forwarded.to_str() {
if let Some(ip) = s.split(',').next() {
return ip.trim().to_string();
}
}
}
if let Some(real_ip) = req.headers().get("x-real-ip") {
if let Ok(s) = real_ip.to_str() {
return s.to_string();
}
}
"unknown".to_string()
}
/// Errors produced by [`rate_limit_middleware`].
#[derive(Debug)]
pub enum RateLimitError {
    /// The client's token bucket was empty.
    TooManyRequests,
}

impl IntoResponse for RateLimitError {
    /// Produces `429 Too Many Requests` with a plain-text body and `Retry-After: 60`.
    fn into_response(self) -> Response {
        let mut response = match self {
            RateLimitError::TooManyRequests => (
                StatusCode::TOO_MANY_REQUESTS,
                "Rate limit exceeded. Please retry later.",
            )
                .into_response(),
        };
        // NOTE(review): fixed value; it does not follow RateLimitConfig::window or refill rate.
        response
            .headers_mut()
            .insert(header::RETRY_AFTER, HeaderValue::from_static("60"));
        response
    }
}
/// Compression effort preset, mapped to numeric levels by the `to_*` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionLevel {
    /// Lowest effort: level 1 / brotli quality 1.
    Fastest,
    /// Default middle ground: level 5 / brotli quality 6.
    #[default]
    Balanced,
    /// Maximum effort: level 9 / brotli quality 11.
    Best,
    /// Explicit level, clamped to 9 by `to_level` and to 11 by `to_brotli_quality`.
    Custom(u32),
}

impl CompressionLevel {
    /// Numeric level capped at 9 (presumably the gzip/deflate 0–9 scale).
    pub fn to_level(self) -> u32 {
        const MAX_LEVEL: u32 = 9;
        match self {
            Self::Fastest => 1,
            Self::Balanced => 5,
            Self::Best => MAX_LEVEL,
            Self::Custom(level) => std::cmp::min(level, MAX_LEVEL),
        }
    }

    /// Brotli quality capped at 11.
    pub fn to_brotli_quality(self) -> u32 {
        const MAX_QUALITY: u32 = 11;
        match self {
            Self::Fastest => 1,
            Self::Balanced => 6,
            Self::Best => MAX_QUALITY,
            Self::Custom(level) => std::cmp::min(level, MAX_QUALITY),
        }
    }
}

/// Which encodings are enabled, how hard to compress, and the size threshold for compressing.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Enables gzip.
    pub enable_gzip: bool,
    /// Enables brotli.
    pub enable_brotli: bool,
    /// Enables deflate.
    pub enable_deflate: bool,
    /// Compression effort preset.
    pub level: CompressionLevel,
    /// Minimum size threshold (presumably bytes; `validate` caps it at 100 MiB).
    pub min_size: usize,
}

impl Default for CompressionConfig {
    /// All three encodings, `Balanced` effort, and a 1024 threshold.
    fn default() -> Self {
        Self {
            enable_gzip: true,
            enable_brotli: true,
            enable_deflate: true,
            level: CompressionLevel::default(),
            min_size: 1024,
        }
    }
}

impl CompressionConfig {
    /// Default configuration with the `Fastest` level.
    pub fn fast() -> Self {
        Self::default().with_level(CompressionLevel::Fastest)
    }

    /// Default configuration with the `Best` level.
    pub fn best() -> Self {
        Self::default().with_level(CompressionLevel::Best)
    }

    /// Replaces the compression level.
    pub fn with_level(self, level: CompressionLevel) -> Self {
        Self { level, ..self }
    }

    /// Replaces the minimum size threshold.
    pub fn with_min_size(self, min_size: usize) -> Self {
        Self { min_size, ..self }
    }

    /// Enables or disables each encoding individually.
    pub fn with_algorithms(self, gzip: bool, brotli: bool, deflate: bool) -> Self {
        Self {
            enable_gzip: gzip,
            enable_brotli: brotli,
            enable_deflate: deflate,
            ..self
        }
    }

    /// Checks that at least one encoding is on and that `min_size` is at most 100 MiB.
    ///
    /// # Errors
    /// Returns a human-readable message describing the first violated rule.
    pub fn validate(&self) -> Result<(), String> {
        const MAX_MIN_SIZE: usize = 100 * 1024 * 1024;

        let any_enabled = self.enable_gzip || self.enable_brotli || self.enable_deflate;
        if !any_enabled {
            return Err("At least one compression algorithm must be enabled".to_string());
        }
        if self.min_size > MAX_MIN_SIZE {
            return Err(format!(
                "Minimum compression size {} is too large (max: 100MB)",
                self.min_size
            ));
        }
        Ok(())
    }
}
/// Settings for the `Cache-Control` header emitted for content-addressed responses.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// `max-age` value in seconds.
    pub default_max_age: u64,
    /// Emit `public` when true, `private` otherwise.
    pub public: bool,
    /// Append the `immutable` directive.
    pub immutable_cids: bool,
}

impl Default for CacheConfig {
    /// Public, one-hour, immutable caching.
    fn default() -> Self {
        Self {
            default_max_age: 60 * 60,
            public: true,
            immutable_cids: true,
        }
    }
}

impl CacheConfig {
    /// Ensures `default_max_age` does not exceed one (365-day) year.
    ///
    /// # Errors
    /// Returns a human-readable message when the limit is exceeded.
    pub fn validate(&self) -> Result<(), String> {
        const ONE_YEAR_SECS: u64 = 365 * 24 * 3600;
        match self.default_max_age {
            age if age > ONE_YEAR_SECS => Err(format!(
                "Max age {} exceeds maximum {} (1 year)",
                age, ONE_YEAR_SECS
            )),
            _ => Ok(()),
        }
    }
}
/// Sets `ETag` to the quoted CID and builds `Cache-Control` from `config`.
///
/// The `Cache-Control` value has the form `public|private, max-age=N[, immutable]`. A value
/// that cannot be encoded as a header is silently skipped.
pub fn add_caching_headers(headers: &mut HeaderMap, cid: &str, config: &CacheConfig) {
    let quoted_cid = format!("\"{}\"", cid);
    if let Ok(etag) = HeaderValue::from_str(&quoted_cid) {
        headers.insert(header::ETAG, etag);
    }

    let scope = if config.public { "public" } else { "private" };
    let immutable = if config.immutable_cids { ", immutable" } else { "" };
    let directives = format!("{}, max-age={}{}", scope, config.default_max_age, immutable);
    if let Ok(value) = HeaderValue::from_str(&directives) {
        headers.insert(header::CACHE_CONTROL, value);
    }
}
/// Returns `true` when the request's `If-None-Match` matches the ETag for `cid`. In that
/// case the caller can answer `304 Not Modified`.
///
/// The header is handled per RFC 9110 §13.1.2:
/// * it may hold a comma-separated list and may be repeated across several header lines;
/// * weak validators (`W/"…"`) take part in the comparison, because If-None-Match uses
///   weak comparison;
/// * `*` matches any existing representation.
///
/// Surrounding quotes are stripped leniently, so an unquoted tag also matches.
pub fn check_etag_match(headers: &HeaderMap, cid: &str) -> bool {
    headers
        .get_all(header::IF_NONE_MATCH)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|list| list.split(','))
        .map(str::trim)
        .any(|tag| {
            tag == "*" || tag.strip_prefix("W/").unwrap_or(tag).trim_matches('"') == cid
        })
}
/// Builds an empty `304 Not Modified` response carrying the same `ETag`/`Cache-Control`
/// headers that a full response for `cid` would have.
pub fn not_modified_response(cid: &str, config: &CacheConfig) -> Response {
    let mut response = Response::new(Body::empty());
    *response.status_mut() = StatusCode::NOT_MODIFIED;
    add_caching_headers(response.headers_mut(), cid, config);
    response
}
/// Identity that [`auth_middleware`] attaches to the request extensions after successful
/// authentication.
#[derive(Debug, Clone)]
pub struct AuthUser {
    /// Identifier of the authenticated user.
    pub user_id: Uuid,
    /// Username of the authenticated user.
    pub username: String,
    /// Validated JWT claims for bearer-token logins; `None` for API-key logins.
    pub claims: Option<Claims>,
}

/// Resolves the caller from the `Authorization` header.
///
/// Two credential forms are accepted:
/// * `Bearer <jwt>`: validated by the JWT manager, and the user is then looked up by the
///   token's username;
/// * `ipfrs_<key>`: the whole header value is checked by the API-key store, and the user is
///   then looked up by id.
///
/// # Errors
/// Returns `AuthError::InvalidToken` when the header is missing, is not valid text, or uses
/// neither form. Errors from the token/key validators and the user store are propagated.
fn authenticate_user(req: &Request, auth_state: &AuthState) -> Result<AuthUser, AuthError> {
    let credential = match req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
    {
        Some(credential) => credential,
        None => {
            return Err(AuthError::InvalidToken(
                "Missing Authorization header".to_string(),
            ))
        }
    };

    // JWT bearer token.
    if let Some(token) = credential.strip_prefix("Bearer ") {
        let claims = auth_state.jwt_manager.validate_token(token)?;
        let user = auth_state.user_store.get_user(&claims.username)?;
        return Ok(AuthUser {
            user_id: user.id,
            username: user.username,
            claims: Some(claims),
        });
    }

    // Anything else must be an API key.
    if !credential.starts_with("ipfrs_") {
        return Err(AuthError::InvalidToken(
            "Authorization header must be either 'Bearer <token>' or 'ipfrs_<key>'".to_string(),
        ));
    }
    let (_api_key, user_id) = auth_state.api_key_store.authenticate(credential)?;
    let user = auth_state.user_store.get_by_id(&user_id)?;
    Ok(AuthUser {
        user_id: user.id,
        username: user.username,
        claims: None,
    })
}
pub async fn auth_middleware(
State(auth_state): State<AuthState>,
mut req: Request,
next: Next,
) -> Result<Response, AuthMiddlewareError> {
let auth_user = authenticate_user(&req, &auth_state)?;
req.extensions_mut().insert(auth_user);
Ok(next.run(req).await)
}
/// Boxed future returned by the permission-check middleware closure.
type PermissionCheckFuture = std::pin::Pin<
    Box<dyn std::future::Future<Output = Result<Response, AuthMiddlewareError>> + Send>,
>;

/// Builds a middleware closure that only forwards requests whose authenticated user holds
/// the `required` permission.
///
/// The closure must run after [`auth_middleware`], because it reads the [`AuthUser`] that
/// `auth_middleware` stores in the request extensions. It then re-reads the user from the
/// store, so the permission check uses current data rather than the login-time snapshot.
///
/// # Errors
/// * `InvalidToken` when no `AuthUser` is present in the extensions;
/// * any user-store lookup error;
/// * `InsufficientPermissions` when the user lacks `required`.
pub fn require_permission(
    required: Permission,
) -> impl Fn(State<AuthState>, Request, Next) -> PermissionCheckFuture + Clone {
    move |State(auth_state): State<AuthState>, req: Request, next: Next| {
        Box::pin(async move {
            let caller_id = req
                .extensions()
                .get::<AuthUser>()
                .map(|user| user.user_id)
                .ok_or_else(|| {
                    AuthMiddlewareError::from(AuthError::InvalidToken(
                        "User not authenticated".to_string(),
                    ))
                })?;

            let user = auth_state.user_store.get_by_id(&caller_id)?;
            if user.has_permission(required) {
                Ok(next.run(req).await)
            } else {
                Err(AuthMiddlewareError::from(
                    AuthError::InsufficientPermissions,
                ))
            }
        })
    }
}
/// HTTP-facing wrapper around an [`AuthError`] raised by the authentication middleware.
#[derive(Debug)]
pub struct AuthMiddlewareError {
    error: AuthError,
}

impl From<AuthError> for AuthMiddlewareError {
    fn from(error: AuthError) -> Self {
        AuthMiddlewareError { error }
    }
}

impl IntoResponse for AuthMiddlewareError {
    /// Maps the error to a status code with a deliberately generic plain-text message.
    /// Token details are not echoed back to the client.
    fn into_response(self) -> Response {
        let status_and_message = match self.error {
            AuthError::InsufficientPermissions => {
                (StatusCode::FORBIDDEN, "Insufficient permissions")
            }
            AuthError::InvalidToken(_) | AuthError::TokenExpired => {
                (StatusCode::UNAUTHORIZED, "Authentication required")
            }
            AuthError::UserNotFound | AuthError::InvalidCredentials => {
                (StatusCode::UNAUTHORIZED, "Invalid credentials")
            }
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
        };
        status_and_message.into_response()
    }
}
/// Limits and toggles for the request-validation helpers.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Maximum request body size (presumably bytes, matching `BodyTooLarge`'s message).
    pub max_body_size: usize,
    /// Maximum accepted CID length, checked by `validate_cid` when format validation is on.
    pub max_cid_length: usize,
    /// Whether `validate_cid` checks length and format beyond non-emptiness.
    pub validate_cid_format: bool,
    /// Whether `validate_content_type` performs any check.
    pub content_type_validation: bool,
    /// Maximum number of items accepted by `validate_batch_size`.
    pub max_batch_size: usize,
}

impl Default for ValidationConfig {
    /// 100 MiB bodies, CIDs up to 100 characters, all checks on, batches up to 1000.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            max_body_size: 100 * MIB,
            max_cid_length: 100,
            validate_cid_format: true,
            content_type_validation: true,
            max_batch_size: 1000,
        }
    }
}

impl ValidationConfig {
    /// Tighter limits: 10 MiB bodies, 64-character CIDs, batches up to 100.
    pub fn strict() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            max_body_size: 10 * MIB,
            max_cid_length: 64,
            validate_cid_format: true,
            content_type_validation: true,
            max_batch_size: 100,
        }
    }

    /// Loose limits with CID-format and content-type checks disabled.
    pub fn permissive() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            max_body_size: 1024 * MIB,
            max_cid_length: 200,
            validate_cid_format: false,
            content_type_validation: false,
            max_batch_size: 10000,
        }
    }
}

/// A rejected request; converted to a JSON error response by its `IntoResponse` impl.
#[derive(Debug)]
pub enum ValidationError {
    /// The body exceeds the configured limit.
    BodyTooLarge { size: usize, max: usize },
    /// The CID is empty, too long or malformed; the payload explains why.
    InvalidCid(String),
    /// The `Content-Type` header does not match what the endpoint expects.
    InvalidContentType { expected: String, actual: String },
    /// A required parameter is absent.
    MissingParameter(String),
    /// The batch has more items than allowed.
    BatchTooLarge { size: usize, max: usize },
    /// A parameter is present but unacceptable.
    InvalidParameter { name: String, reason: String },
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BodyTooLarge { size, max } => {
                write!(f, "Request body too large: {size} bytes (max: {max} bytes)")
            }
            Self::InvalidCid(cid) => write!(f, "Invalid CID format: {cid}"),
            Self::InvalidContentType { expected, actual } => {
                write!(f, "Invalid content type: expected {expected}, got {actual}")
            }
            Self::MissingParameter(param) => write!(f, "Missing required parameter: {param}"),
            Self::BatchTooLarge { size, max } => {
                write!(f, "Batch size too large: {size} items (max: {max} items)")
            }
            Self::InvalidParameter { name, reason } => {
                write!(f, "Invalid parameter '{name}': {reason}")
            }
        }
    }
}

impl std::error::Error for ValidationError {}
impl IntoResponse for ValidationError {
fn into_response(self) -> Response {
let request_id = Uuid::new_v4();
let error_message = self.to_string();
let (status, code) = match self {
ValidationError::BodyTooLarge { .. } => {
(StatusCode::PAYLOAD_TOO_LARGE, "BODY_TOO_LARGE")
}
ValidationError::InvalidCid(_) => (StatusCode::BAD_REQUEST, "INVALID_CID"),
ValidationError::InvalidContentType { .. } => {
(StatusCode::UNSUPPORTED_MEDIA_TYPE, "INVALID_CONTENT_TYPE")
}
ValidationError::MissingParameter(_) => (StatusCode::BAD_REQUEST, "MISSING_PARAMETER"),
ValidationError::BatchTooLarge { .. } => (StatusCode::BAD_REQUEST, "BATCH_TOO_LARGE"),
ValidationError::InvalidParameter { .. } => {
(StatusCode::BAD_REQUEST, "INVALID_PARAMETER")
}
};
let body = serde_json::json!({
"error": error_message,
"code": code,
"request_id": request_id.to_string(),
});
(status, serde_json::to_string(&body).unwrap()).into_response()
}
}
/// Validates a content identifier string.
///
/// Empty input is always rejected. When `config.validate_cid_format` is set, the CID must
/// also satisfy two further rules:
/// * it is at most `config.max_cid_length` bytes long;
/// * it is either a CIDv0 or a CIDv1 in one of three multibase encodings:
///   * CIDv0: 46 characters, starting with `Qm`, all base58btc;
///   * `b` followed by lowercase base32 (`a-z2-7`);
///   * `z` followed by base58btc;
///   * `f` followed by lowercase hex.
///
/// The alphabet check rejects strings such as `"b"` or `"b../etc"`, which only looked like
/// CIDs because of their first character.
///
/// # Errors
/// `ValidationError::InvalidCid` describing the first rule that failed.
pub fn validate_cid(cid: &str, config: &ValidationConfig) -> Result<(), ValidationError> {
    const BASE58_BTC: &str = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";
    const BASE32_LOWER: &str = "abcdefghijklmnopqrstuvwxyz234567";
    const BASE16_LOWER: &str = "0123456789abcdef";

    if cid.is_empty() {
        return Err(ValidationError::InvalidCid(
            "CID cannot be empty".to_string(),
        ));
    }
    if !config.validate_cid_format {
        return Ok(());
    }
    if cid.len() > config.max_cid_length {
        return Err(ValidationError::InvalidCid(format!(
            "CID too long: {} chars (max: {})",
            cid.len(),
            config.max_cid_length
        )));
    }

    // A payload must be non-empty and use only characters from its multibase alphabet.
    let encoded_with =
        |payload: &str, alphabet: &str| !payload.is_empty() && payload.chars().all(|c| alphabet.contains(c));

    let well_formed = if cid.len() == 46 && cid.starts_with("Qm") {
        encoded_with(cid, BASE58_BTC)
    } else if let Some(payload) = cid.strip_prefix('b') {
        encoded_with(payload, BASE32_LOWER)
    } else if let Some(payload) = cid.strip_prefix('z') {
        encoded_with(payload, BASE58_BTC)
    } else if let Some(payload) = cid.strip_prefix('f') {
        encoded_with(payload, BASE16_LOWER)
    } else {
        false
    };

    if well_formed {
        Ok(())
    } else {
        Err(ValidationError::InvalidCid(
            "Invalid CID format: must be CIDv0 (Qm...) or CIDv1 (b..., z..., f...)".to_string(),
        ))
    }
}
/// Checks that a batch has between 1 and `config.max_batch_size` items, inclusive.
///
/// # Errors
/// * `InvalidParameter` (name `"batch"`) for an empty batch;
/// * `BatchTooLarge` when the limit is exceeded.
pub fn validate_batch_size(size: usize, config: &ValidationConfig) -> Result<(), ValidationError> {
    match size {
        0 => Err(ValidationError::InvalidParameter {
            name: "batch".to_string(),
            reason: "Batch cannot be empty".to_string(),
        }),
        n if n > config.max_batch_size => Err(ValidationError::BatchTooLarge {
            size: n,
            max: config.max_batch_size,
        }),
        _ => Ok(()),
    }
}
/// Checks that the request's `Content-Type` matches `expected`.
///
/// The comparison is ASCII case-insensitive, because media types are case-insensitive.
/// `expected` must also end on a token boundary. In the header it may be followed only by
/// nothing, a `;` parameter list, or whitespace. As a result `"application/json"` accepts
/// `"application/json; charset=utf-8"` but not `"application/jsonp"`. A caller-supplied
/// prefix that itself ends in `/` or `;` (e.g. `"text/"`) still matches any continuation.
///
/// Returns `Ok(())` without inspecting anything when `config.content_type_validation` is off.
///
/// # Errors
/// `ValidationError::InvalidContentType` when the header is missing, is not valid text, or
/// does not match.
pub fn validate_content_type(
    headers: &HeaderMap,
    expected: &str,
    config: &ValidationConfig,
) -> Result<(), ValidationError> {
    if !config.content_type_validation {
        return Ok(());
    }
    let content_type = headers
        .get(header::CONTENT_TYPE)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("");

    // `get` returns None when `expected` is longer than the header or does not end on a char
    // boundary, so the slice below cannot panic.
    let matches = match content_type.get(..expected.len()) {
        Some(head) if head.eq_ignore_ascii_case(expected) => {
            let tail = &content_type[expected.len()..];
            expected.is_empty()
                || expected.ends_with(|c: char| c == '/' || c == ';')
                || tail.is_empty()
                || tail.starts_with(|c: char| c == ';' || c.is_ascii_whitespace())
        }
        _ => false,
    };

    if !matches {
        return Err(ValidationError::InvalidContentType {
            expected: expected.to_string(),
            actual: content_type.to_string(),
        });
    }
    Ok(())
}
/// Router state for [`validation_middleware`].
#[derive(Clone)]
pub struct ValidationState {
    /// Validation limits. This middleware does not currently read them.
    pub config: ValidationConfig,
}

/// Request-validation middleware hook; currently a pass-through.
///
/// The earlier body split every request into parts and rebuilt it unchanged on both the
/// multipart and non-multipart paths, which did nothing, so the request is now forwarded
/// directly.
///
/// NOTE(review): `ValidationState::config` (including `max_body_size`) is not enforced here;
/// body-size limits must be applied elsewhere (e.g. a body-limit layer) — TODO confirm the
/// intended behaviour.
pub async fn validation_middleware(
    State(_validation_state): State<ValidationState>,
    req: Request,
    next: Next,
) -> Result<Response, ValidationError> {
    Ok(next.run(req).await)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map containing only the given `Content-Type`.
    fn content_type_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn test_cors_config_default() {
        let cfg = CorsConfig::default();
        assert!(cfg.allowed_origins.is_empty());
        for method in [Method::GET, Method::POST].iter() {
            assert!(cfg.allowed_methods.contains(method));
        }
        assert!(!cfg.allow_credentials);
        assert_eq!(cfg.max_age, 86400);
    }

    #[test]
    fn test_cors_config_permissive() {
        let cfg = CorsConfig::permissive();
        assert!(cfg.allowed_origins.contains("*"));
        for origin in ["https://example.com", "http://localhost:3000"].iter() {
            assert!(cfg.is_origin_allowed(origin));
        }
    }

    #[test]
    fn test_cors_config_allow_origin() {
        let cfg = CorsConfig::default()
            .allow_origin("https://example.com")
            .allow_origin("https://api.example.com");
        assert!(cfg.is_origin_allowed("https://example.com"));
        assert!(cfg.is_origin_allowed("https://api.example.com"));
        assert!(!cfg.is_origin_allowed("https://other.com"));
    }

    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            max_requests,
            window,
            burst_capacity,
        } = RateLimitConfig::default();
        assert_eq!(max_requests, 100);
        assert_eq!(window, Duration::from_secs(60));
        assert_eq!(burst_capacity, 10);
    }

    #[test]
    fn test_cache_config_default() {
        let CacheConfig {
            default_max_age,
            public,
            immutable_cids,
        } = CacheConfig::default();
        assert_eq!(default_max_age, 3600);
        assert!(public);
        assert!(immutable_cids);
    }

    #[test]
    fn test_add_caching_headers() {
        let mut headers = HeaderMap::new();
        add_caching_headers(&mut headers, "QmTest123", &CacheConfig::default());
        assert!(headers.contains_key(header::ETAG));
        assert!(headers.contains_key(header::CACHE_CONTROL));

        let etag = headers
            .get(header::ETAG)
            .and_then(|v| v.to_str().ok())
            .unwrap();
        assert_eq!(etag, "\"QmTest123\"");

        let cache_control = headers
            .get(header::CACHE_CONTROL)
            .and_then(|v| v.to_str().ok())
            .unwrap();
        for directive in ["public", "max-age=3600", "immutable"].iter() {
            assert!(cache_control.contains(*directive));
        }
    }

    #[test]
    fn test_check_etag_match() {
        let mut headers = HeaderMap::new();
        assert!(!check_etag_match(&headers, "QmTest123"));
        headers.insert(
            header::IF_NONE_MATCH,
            HeaderValue::from_static("\"QmTest123\""),
        );
        assert!(check_etag_match(&headers, "QmTest123"));
        assert!(!check_etag_match(&headers, "QmOther456"));
    }

    #[tokio::test]
    async fn test_rate_limit_state() {
        let state = RateLimitState::new(RateLimitConfig {
            max_requests: 5,
            window: Duration::from_secs(1),
            burst_capacity: 3,
        });
        for attempt in 0..3 {
            let (allowed, _) = state.get_bucket("127.0.0.1").await;
            assert!(allowed, "request {} within the burst should pass", attempt);
        }
    }

    #[test]
    fn test_compression_level_to_level() {
        let cases = [
            (CompressionLevel::Fastest, 1),
            (CompressionLevel::Balanced, 5),
            (CompressionLevel::Best, 9),
            (CompressionLevel::Custom(7), 7),
            (CompressionLevel::Custom(15), 9), // clamped
        ];
        for (level, expected) in cases.iter() {
            assert_eq!(level.to_level(), *expected);
        }
    }

    #[test]
    fn test_compression_level_to_brotli_quality() {
        let cases = [
            (CompressionLevel::Fastest, 1),
            (CompressionLevel::Balanced, 6),
            (CompressionLevel::Best, 11),
            (CompressionLevel::Custom(8), 8),
            (CompressionLevel::Custom(15), 11), // clamped
        ];
        for (level, expected) in cases.iter() {
            assert_eq!(level.to_brotli_quality(), *expected);
        }
    }

    #[test]
    fn test_compression_config_default() {
        let cfg = CompressionConfig::default();
        assert!(cfg.enable_gzip && cfg.enable_brotli && cfg.enable_deflate);
        assert_eq!(cfg.level, CompressionLevel::Balanced);
        assert_eq!(cfg.min_size, 1024);
    }

    #[test]
    fn test_compression_config_fast() {
        let cfg = CompressionConfig::fast();
        assert_eq!(cfg.level, CompressionLevel::Fastest);
        assert!(cfg.enable_gzip);
    }

    #[test]
    fn test_compression_config_best() {
        let cfg = CompressionConfig::best();
        assert_eq!(cfg.level, CompressionLevel::Best);
        assert!(cfg.enable_brotli);
    }

    #[test]
    fn test_compression_config_builder() {
        let cfg = CompressionConfig::default()
            .with_level(CompressionLevel::Custom(7))
            .with_min_size(2048)
            .with_algorithms(true, false, false);
        assert_eq!(cfg.level, CompressionLevel::Custom(7));
        assert_eq!(cfg.min_size, 2048);
        assert!(cfg.enable_gzip);
        assert!(!cfg.enable_brotli);
        assert!(!cfg.enable_deflate);
    }

    #[test]
    fn test_compression_config_validation_valid() {
        assert!(CompressionConfig::default().validate().is_ok());
        assert!(CompressionConfig::default()
            .with_algorithms(true, false, false)
            .validate()
            .is_ok());
    }

    #[test]
    fn test_compression_config_validation_invalid() {
        assert!(CompressionConfig::default()
            .with_algorithms(false, false, false)
            .validate()
            .is_err());
        assert!(CompressionConfig::default()
            .with_min_size(200 * 1024 * 1024)
            .validate()
            .is_err());
    }

    #[test]
    fn test_rate_limit_config_validation_valid() {
        assert!(RateLimitConfig::default().validate().is_ok());
        let roomy_burst = RateLimitConfig {
            burst_capacity: 50,
            ..RateLimitConfig::default()
        };
        assert!(roomy_burst.validate().is_ok());
    }

    #[test]
    fn test_rate_limit_config_validation_invalid() {
        let base = RateLimitConfig::default(); // 100 requests / 60 s / burst 10
        let invalid = [
            RateLimitConfig {
                max_requests: 0,
                ..base.clone()
            },
            RateLimitConfig {
                window: Duration::from_secs(0),
                ..base.clone()
            },
            RateLimitConfig {
                burst_capacity: 200,
                ..base.clone()
            },
        ];
        for cfg in invalid.iter() {
            assert!(cfg.validate().is_err(), "{:?} should be rejected", cfg);
        }
    }

    #[test]
    fn test_cache_config_validation_valid() {
        assert!(CacheConfig::default().validate().is_ok());
        let one_day = CacheConfig {
            default_max_age: 86400,
            ..CacheConfig::default()
        };
        assert!(one_day.validate().is_ok());
    }

    #[test]
    fn test_cache_config_validation_invalid() {
        let too_long = CacheConfig {
            default_max_age: 400 * 24 * 3600,
            ..CacheConfig::default()
        };
        assert!(too_long.validate().is_err());
    }

    #[test]
    fn test_validation_config_default() {
        let cfg = ValidationConfig::default();
        assert_eq!(cfg.max_body_size, 100 * 1024 * 1024);
        assert_eq!(cfg.max_cid_length, 100);
        assert!(cfg.validate_cid_format);
        assert!(cfg.content_type_validation);
        assert_eq!(cfg.max_batch_size, 1000);
    }

    #[test]
    fn test_validation_config_strict() {
        let cfg = ValidationConfig::strict();
        assert_eq!(cfg.max_body_size, 10 * 1024 * 1024);
        assert_eq!(cfg.max_cid_length, 64);
        assert_eq!(cfg.max_batch_size, 100);
    }

    #[test]
    fn test_validation_config_permissive() {
        let cfg = ValidationConfig::permissive();
        assert_eq!(cfg.max_body_size, 1024 * 1024 * 1024);
        assert_eq!(cfg.max_cid_length, 200);
        assert!(!cfg.validate_cid_format);
        assert!(!cfg.content_type_validation);
        assert_eq!(cfg.max_batch_size, 10000);
    }

    #[test]
    fn test_validate_cid_v0() {
        let cfg = ValidationConfig::default();
        assert!(validate_cid("QmXoypizjW3WknFiJnKLwHCnL72vedxjQkDDP1mXWo6uco", &cfg).is_ok());
        for bad in ["QmShort", ""].iter() {
            assert!(validate_cid(bad, &cfg).is_err());
        }
    }

    #[test]
    fn test_validate_cid_v1() {
        let cfg = ValidationConfig::default();
        let valid = [
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "zb2rhk6GMPQF8p1kqXvhYnCMp3hGGUQVvqp6qjdvNLKqCqKCo",
        ];
        for cid in valid.iter() {
            assert!(validate_cid(cid, &cfg).is_ok(), "{} should be accepted", cid);
        }
        assert!(validate_cid("invalid_cid_format", &cfg).is_err());
    }

    #[test]
    fn test_validate_cid_disabled() {
        let cfg = ValidationConfig {
            validate_cid_format: false,
            ..Default::default()
        };
        assert!(validate_cid("invalid_format", &cfg).is_ok());
        // Emptiness is rejected even when format validation is off.
        assert!(validate_cid("", &cfg).is_err());
    }

    #[test]
    fn test_validate_batch_size_valid() {
        let cfg = ValidationConfig::default();
        for size in [1usize, 100, 1000].iter() {
            assert!(validate_batch_size(*size, &cfg).is_ok());
        }
    }

    #[test]
    fn test_validate_batch_size_invalid() {
        let cfg = ValidationConfig::default();
        for size in [0usize, 1001, 10000].iter() {
            assert!(validate_batch_size(*size, &cfg).is_err());
        }
    }

    #[test]
    fn test_validate_content_type_valid() {
        let cfg = ValidationConfig::default();
        let headers = content_type_headers("application/json");
        assert!(validate_content_type(&headers, "application/json", &cfg).is_ok());
    }

    #[test]
    fn test_validate_content_type_invalid() {
        let cfg = ValidationConfig::default();
        let headers = content_type_headers("text/plain");
        assert!(validate_content_type(&headers, "application/json", &cfg).is_err());
    }

    #[test]
    fn test_validate_content_type_disabled() {
        let cfg = ValidationConfig {
            content_type_validation: false,
            ..Default::default()
        };
        let headers = content_type_headers("text/plain");
        assert!(validate_content_type(&headers, "application/json", &cfg).is_ok());
    }

    #[test]
    fn test_validation_error_display() {
        let invalid_cid = ValidationError::InvalidCid("test".to_string());
        assert_eq!(invalid_cid.to_string(), "Invalid CID format: test");

        let body = ValidationError::BodyTooLarge {
            size: 200,
            max: 100,
        }
        .to_string();
        assert!(body.contains("200 bytes"));
        assert!(body.contains("100 bytes"));

        let batch = ValidationError::BatchTooLarge {
            size: 2000,
            max: 1000,
        }
        .to_string();
        assert!(batch.contains("2000 items"));
        assert!(batch.contains("1000 items"));
    }
}