ipfrs_interface/
middleware.rs

//! HTTP Middleware for IPFRS Gateway
//!
//! Provides:
//! - Authentication and authorization middleware using JWT tokens and API keys
//! - CORS middleware for cross-origin requests
//! - Rate limiting middleware for DoS prevention
//! - Compression middleware for bandwidth optimization
//! - Caching middleware for HTTP caching headers
//! - Request validation middleware and helpers (body size, CID format, content type)
9
10use crate::auth::{AuthError, AuthState, Claims, Permission};
11use axum::{
12    body::Body,
13    extract::{Request, State},
14    http::{header, HeaderMap, HeaderValue, Method, StatusCode},
15    middleware::Next,
16    response::{IntoResponse, Response},
17};
18use std::collections::HashSet;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use uuid::Uuid;
23
24// ============================================================================
25// CORS Configuration
26// ============================================================================
27
28/// CORS configuration
29#[derive(Debug, Clone)]
30pub struct CorsConfig {
31    /// Allowed origins (use "*" for any origin)
32    pub allowed_origins: HashSet<String>,
33    /// Allowed HTTP methods
34    pub allowed_methods: HashSet<Method>,
35    /// Allowed headers
36    pub allowed_headers: HashSet<String>,
37    /// Headers to expose to the client
38    pub exposed_headers: HashSet<String>,
39    /// Allow credentials (cookies, authorization headers)
40    pub allow_credentials: bool,
41    /// Max age for preflight cache (seconds)
42    pub max_age: u64,
43}
44
45impl Default for CorsConfig {
46    fn default() -> Self {
47        let mut methods = HashSet::new();
48        methods.insert(Method::GET);
49        methods.insert(Method::POST);
50        methods.insert(Method::PUT);
51        methods.insert(Method::DELETE);
52        methods.insert(Method::OPTIONS);
53        methods.insert(Method::HEAD);
54
55        let mut headers = HashSet::new();
56        headers.insert("content-type".to_string());
57        headers.insert("authorization".to_string());
58        headers.insert("accept".to_string());
59        headers.insert("origin".to_string());
60        headers.insert("x-requested-with".to_string());
61
62        Self {
63            allowed_origins: HashSet::new(), // Empty = allow all
64            allowed_methods: methods,
65            allowed_headers: headers,
66            exposed_headers: HashSet::new(),
67            allow_credentials: false,
68            max_age: 86400, // 24 hours
69        }
70    }
71}
72
73impl CorsConfig {
74    /// Create a permissive CORS config (allows all origins)
75    pub fn permissive() -> Self {
76        let mut config = Self::default();
77        config.allowed_origins.insert("*".to_string());
78        config
79    }
80
81    /// Allow specific origin
82    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
83        self.allowed_origins.insert(origin.into());
84        self
85    }
86
87    /// Allow credentials
88    pub fn allow_credentials(mut self, allow: bool) -> Self {
89        self.allow_credentials = allow;
90        self
91    }
92
93    /// Check if origin is allowed
94    fn is_origin_allowed(&self, origin: &str) -> bool {
95        if self.allowed_origins.is_empty() || self.allowed_origins.contains("*") {
96            true
97        } else {
98            self.allowed_origins.contains(origin)
99        }
100    }
101
102    /// Get allowed methods as comma-separated string
103    fn methods_string(&self) -> String {
104        self.allowed_methods
105            .iter()
106            .map(|m| m.as_str())
107            .collect::<Vec<_>>()
108            .join(", ")
109    }
110
111    /// Get allowed headers as comma-separated string
112    fn headers_string(&self) -> String {
113        self.allowed_headers
114            .iter()
115            .cloned()
116            .collect::<Vec<_>>()
117            .join(", ")
118    }
119}
120
121/// CORS middleware state
122#[derive(Clone)]
123pub struct CorsState {
124    pub config: CorsConfig,
125}
126
127/// CORS middleware
128///
129/// Handles preflight requests and adds CORS headers to responses.
130pub async fn cors_middleware(
131    State(cors_state): State<CorsState>,
132    req: Request,
133    next: Next,
134) -> Response {
135    let origin = req
136        .headers()
137        .get(header::ORIGIN)
138        .and_then(|h| h.to_str().ok())
139        .map(|s| s.to_string());
140
141    // Handle preflight (OPTIONS) requests
142    if req.method() == Method::OPTIONS {
143        return build_preflight_response(&cors_state.config, origin.as_deref());
144    }
145
146    // Process the request
147    let mut response = next.run(req).await;
148
149    // Add CORS headers to response
150    add_cors_headers(
151        response.headers_mut(),
152        &cors_state.config,
153        origin.as_deref(),
154    );
155
156    response
157}
158
159/// Build preflight response for OPTIONS requests
160fn build_preflight_response(config: &CorsConfig, origin: Option<&str>) -> Response {
161    let mut response = Response::builder()
162        .status(StatusCode::NO_CONTENT)
163        .body(Body::empty())
164        .unwrap();
165
166    add_cors_headers(response.headers_mut(), config, origin);
167
168    // Add preflight-specific headers
169    if let Ok(value) = HeaderValue::from_str(&config.methods_string()) {
170        response
171            .headers_mut()
172            .insert(header::ACCESS_CONTROL_ALLOW_METHODS, value);
173    }
174    if let Ok(value) = HeaderValue::from_str(&config.headers_string()) {
175        response
176            .headers_mut()
177            .insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value);
178    }
179    if let Ok(value) = HeaderValue::from_str(&config.max_age.to_string()) {
180        response
181            .headers_mut()
182            .insert(header::ACCESS_CONTROL_MAX_AGE, value);
183    }
184
185    response
186}
187
188/// Add CORS headers to a response
189fn add_cors_headers(headers: &mut HeaderMap, config: &CorsConfig, origin: Option<&str>) {
190    // Access-Control-Allow-Origin
191    let origin_value = if let Some(origin) = origin {
192        if config.is_origin_allowed(origin) {
193            if config.allowed_origins.contains("*") && !config.allow_credentials {
194                "*"
195            } else {
196                origin
197            }
198        } else {
199            return; // Origin not allowed, don't add CORS headers
200        }
201    } else if config.allowed_origins.contains("*") {
202        "*"
203    } else {
204        return;
205    };
206
207    if let Ok(value) = HeaderValue::from_str(origin_value) {
208        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
209    }
210
211    // Access-Control-Allow-Credentials
212    if config.allow_credentials {
213        headers.insert(
214            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
215            HeaderValue::from_static("true"),
216        );
217    }
218
219    // Access-Control-Expose-Headers
220    if !config.exposed_headers.is_empty() {
221        let exposed = config
222            .exposed_headers
223            .iter()
224            .cloned()
225            .collect::<Vec<_>>()
226            .join(", ");
227        if let Ok(value) = HeaderValue::from_str(&exposed) {
228            headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, value);
229        }
230    }
231}
232
233// ============================================================================
234// Rate Limiting
235// ============================================================================
236
/// Rate limiter configuration.
///
/// Each client gets a token bucket that holds at most `burst_capacity` tokens. The
/// bucket refills at `max_requests / window` tokens per second.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per window (sustained rate)
    pub max_requests: u32,
    /// Time window over which `max_requests` are allowed
    pub window: Duration,
    /// Burst capacity (token bucket max tokens)
    pub burst_capacity: u32,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60 seconds, with bursts of up to 10.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window: Duration::from_secs(60),
            burst_capacity: 10,
        }
    }
}

impl RateLimitConfig {
    /// Validate configuration.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `max_requests`, `window` or
    /// `burst_capacity` is zero, or when `burst_capacity` exceeds `max_requests`.
    pub fn validate(&self) -> Result<(), String> {
        if self.max_requests == 0 {
            return Err("Maximum requests must be greater than 0".to_string());
        }

        // Check the whole duration, not just its whole-second part: sub-second windows
        // such as 500 ms are valid. Only a truly zero window is rejected; the refill
        // rate divides by the window length.
        if self.window.is_zero() {
            return Err("Time window must be greater than 0".to_string());
        }

        if self.burst_capacity == 0 {
            return Err("Burst capacity must be greater than 0".to_string());
        }

        if self.burst_capacity > self.max_requests {
            return Err(format!(
                "Burst capacity ({}) cannot exceed max requests ({})",
                self.burst_capacity, self.max_requests
            ));
        }

        Ok(())
    }
}
283
/// Token bucket used for per-client rate limiting.
///
/// The bucket starts full. It refills continuously at `refill_rate` tokens per second,
/// never exceeding `capacity`. Each admitted request consumes one whole token.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    last_update: Instant,
    capacity: f64,
    refill_rate: f64, // tokens per second
}

impl TokenBucket {
    /// Create a full bucket holding `capacity` tokens.
    fn new(capacity: u32, refill_rate: f64) -> Self {
        let full = f64::from(capacity);
        TokenBucket {
            tokens: full,
            last_update: Instant::now(),
            capacity: full,
            refill_rate,
        }
    }

    /// Refill, then take one token if a whole token is available.
    ///
    /// Returns `true` when the request is admitted.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        let admitted = self.tokens >= 1.0;
        if admitted {
            self.tokens -= 1.0;
        }
        admitted
    }

    /// Credit the tokens earned since the last update, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        self.tokens = self.capacity.min(self.tokens + earned);
        self.last_update = now;
    }

    /// Whole tokens currently in the bucket (fractional part truncated).
    fn tokens_remaining(&self) -> u32 {
        let current = self.tokens;
        current as u32
    }
}
324
325/// Rate limiter state (per-IP buckets)
326#[derive(Clone)]
327pub struct RateLimitState {
328    config: RateLimitConfig,
329    buckets: Arc<Mutex<std::collections::HashMap<String, TokenBucket>>>,
330}
331
332impl RateLimitState {
333    /// Create a new rate limiter state
334    pub fn new(config: RateLimitConfig) -> Self {
335        Self {
336            config,
337            buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
338        }
339    }
340
341    /// Get or create a token bucket for an IP
342    async fn get_bucket(&self, ip: &str) -> (bool, u32) {
343        let mut buckets = self.buckets.lock().await;
344
345        let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
346
347        let bucket = buckets
348            .entry(ip.to_string())
349            .or_insert_with(|| TokenBucket::new(self.config.burst_capacity, refill_rate));
350
351        let allowed = bucket.try_acquire();
352        let remaining = bucket.tokens_remaining();
353
354        (allowed, remaining)
355    }
356}
357
358/// Rate limiting middleware
359///
360/// Limits requests per IP using token bucket algorithm.
361pub async fn rate_limit_middleware(
362    State(rate_state): State<RateLimitState>,
363    req: Request,
364    next: Next,
365) -> Result<Response, RateLimitError> {
366    // Extract client IP from headers or connection
367    let ip = extract_client_ip(&req);
368
369    let (allowed, remaining) = rate_state.get_bucket(&ip).await;
370
371    if !allowed {
372        return Err(RateLimitError::TooManyRequests);
373    }
374
375    let mut response = next.run(req).await;
376
377    // Add rate limit headers
378    let headers = response.headers_mut();
379    if let Ok(value) = HeaderValue::from_str(&rate_state.config.max_requests.to_string()) {
380        headers.insert("X-RateLimit-Limit", value);
381    }
382    if let Ok(value) = HeaderValue::from_str(&remaining.to_string()) {
383        headers.insert("X-RateLimit-Remaining", value);
384    }
385
386    Ok(response)
387}
388
389/// Extract client IP from request
390fn extract_client_ip(req: &Request) -> String {
391    // Check X-Forwarded-For first (for proxied requests)
392    if let Some(forwarded) = req.headers().get("x-forwarded-for") {
393        if let Ok(s) = forwarded.to_str() {
394            if let Some(ip) = s.split(',').next() {
395                return ip.trim().to_string();
396            }
397        }
398    }
399
400    // Check X-Real-IP
401    if let Some(real_ip) = req.headers().get("x-real-ip") {
402        if let Ok(s) = real_ip.to_str() {
403            return s.to_string();
404        }
405    }
406
407    // Fallback to unknown
408    "unknown".to_string()
409}
410
411/// Rate limit error
412#[derive(Debug)]
413pub enum RateLimitError {
414    TooManyRequests,
415}
416
417impl IntoResponse for RateLimitError {
418    fn into_response(self) -> Response {
419        let (status, message) = match self {
420            RateLimitError::TooManyRequests => (
421                StatusCode::TOO_MANY_REQUESTS,
422                "Rate limit exceeded. Please retry later.",
423            ),
424        };
425
426        let mut response = (status, message).into_response();
427
428        // Add Retry-After header (60 seconds)
429        response
430            .headers_mut()
431            .insert(header::RETRY_AFTER, HeaderValue::from_static("60"));
432
433        response
434    }
435}
436
437// ============================================================================
438// Compression Configuration
439// ============================================================================
440
/// Compression level, trading CPU time against output size.
///
/// Each variant maps to a gzip/deflate level through [`CompressionLevel::to_level`]
/// (0-9). It maps to a brotli quality through [`CompressionLevel::to_brotli_quality`]
/// (0-11).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionLevel {
    /// Fastest compression, larger output (gzip 1, brotli 1)
    Fastest,
    /// Balanced compression (gzip 5, brotli 6)
    #[default]
    Balanced,
    /// Best compression, slowest (gzip 9, brotli 11)
    Best,
    /// Explicit level; clamped to 9 for gzip/deflate and to 11 for brotli
    Custom(u32),
}

impl CompressionLevel {
    /// Highest level accepted by gzip/deflate.
    const MAX_GZIP_LEVEL: u32 = 9;
    /// Highest quality accepted by brotli.
    const MAX_BROTLI_QUALITY: u32 = 11;

    /// Numeric gzip/deflate level in `0..=9`.
    pub fn to_level(self) -> u32 {
        match self {
            Self::Fastest => 1,
            Self::Balanced => 5,
            Self::Best => Self::MAX_GZIP_LEVEL,
            Self::Custom(level) => std::cmp::min(level, Self::MAX_GZIP_LEVEL),
        }
    }

    /// Brotli quality in `0..=11`.
    pub fn to_brotli_quality(self) -> u32 {
        match self {
            Self::Fastest => 1,
            Self::Balanced => 6,
            Self::Best => Self::MAX_BROTLI_QUALITY,
            Self::Custom(level) => std::cmp::min(level, Self::MAX_BROTLI_QUALITY),
        }
    }
}

/// Response compression settings.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Enable gzip compression
    pub enable_gzip: bool,
    /// Enable brotli compression
    pub enable_brotli: bool,
    /// Enable deflate compression
    pub enable_deflate: bool,
    /// Compression level (speed vs size trade-off)
    pub level: CompressionLevel,
    /// Minimum body size in bytes worth compressing; smaller bodies are sent as-is
    pub min_size: usize,
}

impl Default for CompressionConfig {
    /// All algorithms enabled, [`CompressionLevel::Balanced`], 1 KiB threshold.
    fn default() -> Self {
        CompressionConfig {
            enable_gzip: true,
            enable_brotli: true,
            enable_deflate: true,
            level: CompressionLevel::default(), // Balanced
            min_size: 1024,                     // Don't compress bodies smaller than 1KB
        }
    }
}

impl CompressionConfig {
    /// Defaults, but tuned for speed ([`CompressionLevel::Fastest`]).
    pub fn fast() -> Self {
        Self::default().with_level(CompressionLevel::Fastest)
    }

    /// Defaults, but tuned for size ([`CompressionLevel::Best`]).
    pub fn best() -> Self {
        Self::default().with_level(CompressionLevel::Best)
    }

    /// Replace the compression level (builder style).
    pub fn with_level(self, level: CompressionLevel) -> Self {
        Self { level, ..self }
    }

    /// Replace the minimum size threshold (builder style).
    pub fn with_min_size(self, min_size: usize) -> Self {
        Self { min_size, ..self }
    }

    /// Enable or disable each algorithm individually (builder style).
    pub fn with_algorithms(self, gzip: bool, brotli: bool, deflate: bool) -> Self {
        Self {
            enable_gzip: gzip,
            enable_brotli: brotli,
            enable_deflate: deflate,
            ..self
        }
    }

    /// Validate configuration.
    ///
    /// # Errors
    ///
    /// Returns a message when no algorithm is enabled, or when `min_size` is above
    /// 100 MiB.
    pub fn validate(&self) -> Result<(), String> {
        const MAX_MIN_SIZE: usize = 100 * 1024 * 1024;

        let any_enabled = self.enable_gzip || self.enable_brotli || self.enable_deflate;
        if !any_enabled {
            return Err("At least one compression algorithm must be enabled".to_string());
        }

        if self.min_size > MAX_MIN_SIZE {
            return Err(format!(
                "Minimum compression size {} is too large (max: 100MB)",
                self.min_size
            ));
        }

        Ok(())
    }
}
559
560// ============================================================================
561// HTTP Caching
562// ============================================================================
563
/// HTTP caching policy for content-addressed (CID) responses.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// `max-age` directive for cacheable responses (seconds)
    pub default_max_age: u64,
    /// `public` (shared caches/CDNs may store) when true, `private` otherwise
    pub public: bool,
    /// Whether to add the `immutable` directive to CID responses
    pub immutable_cids: bool,
}

impl Default for CacheConfig {
    /// Public, one-hour `max-age`, immutable CID responses.
    fn default() -> Self {
        CacheConfig {
            default_max_age: 60 * 60, // 1 hour
            public: true,
            immutable_cids: true, // CID content is immutable by definition
        }
    }
}

impl CacheConfig {
    /// Upper bound accepted for `default_max_age`: one (365-day) year in seconds.
    const MAX_AGE_LIMIT: u64 = 365 * 24 * 3600;

    /// Validate configuration.
    ///
    /// # Errors
    ///
    /// Returns a message when `default_max_age` exceeds one year.
    pub fn validate(&self) -> Result<(), String> {
        match self.default_max_age {
            age if age > Self::MAX_AGE_LIMIT => Err(format!(
                "Max age {} exceeds maximum {} (1 year)",
                age,
                Self::MAX_AGE_LIMIT
            )),
            _ => Ok(()),
        }
    }
}
601
602/// Add caching headers to a response for a given CID
603pub fn add_caching_headers(headers: &mut HeaderMap, cid: &str, config: &CacheConfig) {
604    // ETag based on CID (content-addressed = perfect ETag)
605    if let Ok(etag) = HeaderValue::from_str(&format!("\"{}\"", cid)) {
606        headers.insert(header::ETAG, etag);
607    }
608
609    // Cache-Control
610    let mut cache_control = String::new();
611    if config.public {
612        cache_control.push_str("public, ");
613    } else {
614        cache_control.push_str("private, ");
615    }
616    cache_control.push_str(&format!("max-age={}", config.default_max_age));
617
618    // CID content is immutable - it will never change
619    if config.immutable_cids {
620        cache_control.push_str(", immutable");
621    }
622
623    if let Ok(value) = HeaderValue::from_str(&cache_control) {
624        headers.insert(header::CACHE_CONTROL, value);
625    }
626}
627
628/// Check if request has a matching ETag (for conditional requests)
629pub fn check_etag_match(headers: &HeaderMap, cid: &str) -> bool {
630    if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH) {
631        if let Ok(value) = if_none_match.to_str() {
632            // Remove quotes and compare
633            let etag = value.trim().trim_matches('"');
634            return etag == cid;
635        }
636    }
637    false
638}
639
640/// Build a 304 Not Modified response
641pub fn not_modified_response(cid: &str, config: &CacheConfig) -> Response {
642    let mut response = Response::builder()
643        .status(StatusCode::NOT_MODIFIED)
644        .body(Body::empty())
645        .unwrap();
646
647    add_caching_headers(response.headers_mut(), cid, config);
648
649    response
650}
651
652/// Authenticated user context
653#[derive(Debug, Clone)]
654pub struct AuthUser {
655    pub user_id: Uuid,
656    pub username: String,
657    pub claims: Option<Claims>,
658}
659
660/// Authenticate user from Authorization header
661///
662/// Supports both JWT tokens (Bearer <token>) and API keys (ipfrs_...)
663fn authenticate_user(req: &Request, auth_state: &AuthState) -> Result<AuthUser, AuthError> {
664    let auth_header = req
665        .headers()
666        .get(header::AUTHORIZATION)
667        .and_then(|h| h.to_str().ok())
668        .ok_or(AuthError::InvalidToken(
669            "Missing Authorization header".to_string(),
670        ))?;
671
672    // Try JWT token first (Bearer <token>)
673    if let Some(token) = auth_header.strip_prefix("Bearer ") {
674        let claims = auth_state.jwt_manager.validate_token(token)?;
675        let user = auth_state.user_store.get_user(&claims.username)?;
676
677        return Ok(AuthUser {
678            user_id: user.id,
679            username: user.username,
680            claims: Some(claims),
681        });
682    }
683
684    // Try API key (ipfrs_...)
685    if auth_header.starts_with("ipfrs_") {
686        let (_api_key, user_id) = auth_state.api_key_store.authenticate(auth_header)?;
687        let user = auth_state.user_store.get_by_id(&user_id)?;
688
689        return Ok(AuthUser {
690            user_id: user.id,
691            username: user.username,
692            claims: None,
693        });
694    }
695
696    Err(AuthError::InvalidToken(
697        "Authorization header must be either 'Bearer <token>' or 'ipfrs_<key>'".to_string(),
698    ))
699}
700
701/// Authentication middleware
702///
703/// Validates JWT token or API key from Authorization header and injects authenticated user into request extensions.
704pub async fn auth_middleware(
705    State(auth_state): State<AuthState>,
706    mut req: Request,
707    next: Next,
708) -> Result<Response, AuthMiddlewareError> {
709    // Authenticate user (JWT or API key)
710    let auth_user = authenticate_user(&req, &auth_state)?;
711
712    // Inject authenticated user into request extensions
713    req.extensions_mut().insert(auth_user);
714
715    Ok(next.run(req).await)
716}
717
718/// Type alias for the permission check middleware future
719type PermissionCheckFuture = std::pin::Pin<
720    Box<dyn std::future::Future<Output = Result<Response, AuthMiddlewareError>> + Send>,
721>;
722
723/// Authorization middleware factory
724///
725/// Creates middleware that checks if the authenticated user has required permissions.
726pub fn require_permission(
727    required: Permission,
728) -> impl Fn(State<AuthState>, Request, Next) -> PermissionCheckFuture + Clone {
729    move |State(auth_state): State<AuthState>, req: Request, next: Next| {
730        let required = required;
731        Box::pin(async move {
732            // Get authenticated user from extensions
733            let auth_user = req
734                .extensions()
735                .get::<AuthUser>()
736                .ok_or_else(|| AuthError::InvalidToken("User not authenticated".to_string()))?;
737
738            // Get user from store to check permissions
739            let user = auth_state.user_store.get_by_id(&auth_user.user_id)?;
740
741            // Check if user has required permission
742            if !user.has_permission(required) {
743                return Err(AuthMiddlewareError::from(
744                    AuthError::InsufficientPermissions,
745                ));
746            }
747
748            Ok(next.run(req).await)
749        })
750    }
751}
752
753/// Middleware error wrapper
754#[derive(Debug)]
755pub struct AuthMiddlewareError {
756    error: AuthError,
757}
758
759impl From<AuthError> for AuthMiddlewareError {
760    fn from(error: AuthError) -> Self {
761        Self { error }
762    }
763}
764
765impl IntoResponse for AuthMiddlewareError {
766    fn into_response(self) -> Response {
767        let (status, message) = match self.error {
768            AuthError::InvalidToken(_) | AuthError::TokenExpired => {
769                (StatusCode::UNAUTHORIZED, "Authentication required")
770            }
771            AuthError::InsufficientPermissions => {
772                (StatusCode::FORBIDDEN, "Insufficient permissions")
773            }
774            AuthError::UserNotFound | AuthError::InvalidCredentials => {
775                (StatusCode::UNAUTHORIZED, "Invalid credentials")
776            }
777            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
778        };
779
780        (status, message).into_response()
781    }
782}
783
784// ============================================================================
785// Request Validation Middleware
786// ============================================================================
787
/// Limits and switches used by the request-validation helpers and middleware.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Maximum request body size (bytes)
    pub max_body_size: usize,
    /// Maximum CID length; enforced by `validate_cid` only when `validate_cid_format` is set
    pub max_cid_length: usize,
    /// Validate CID format
    pub validate_cid_format: bool,
    /// Whether `validate_content_type` checks the `Content-Type` header at all
    pub content_type_validation: bool,
    /// Maximum batch size for batch operations
    pub max_batch_size: usize,
}

impl Default for ValidationConfig {
    /// 100 MiB bodies, 100-char CIDs, all checks on, batches of up to 1000.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        ValidationConfig {
            max_body_size: 100 * MIB,
            max_cid_length: 100,
            validate_cid_format: true,
            content_type_validation: true,
            max_batch_size: 1000,
        }
    }
}

impl ValidationConfig {
    /// Tighter limits: 10 MiB bodies, 64-char CIDs, batches of up to 100; all checks on.
    pub fn strict() -> Self {
        Self {
            max_body_size: 10 * 1024 * 1024,
            max_cid_length: 64,
            max_batch_size: 100,
            ..Self::default()
        }
    }

    /// Looser limits: 1 GiB bodies, 200-char CIDs, batches of up to 10 000; format
    /// and content-type checks off.
    pub fn permissive() -> Self {
        Self {
            max_body_size: 1024 * 1024 * 1024,
            max_cid_length: 200,
            validate_cid_format: false,
            content_type_validation: false,
            max_batch_size: 10_000,
        }
    }
}
838
/// Reasons a request can fail validation.
#[derive(Debug)]
pub enum ValidationError {
    /// Request body too large
    BodyTooLarge { size: usize, max: usize },
    /// Invalid CID format (payload is a human-readable reason)
    InvalidCid(String),
    /// Invalid content type
    InvalidContentType { expected: String, actual: String },
    /// Missing required parameter
    MissingParameter(String),
    /// Batch size exceeds limit
    BatchTooLarge { size: usize, max: usize },
    /// Invalid parameter value
    InvalidParameter { name: String, reason: String },
}

impl std::fmt::Display for ValidationError {
    /// Human-readable message; also used as the `error` field of the JSON response.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BodyTooLarge { size, max } => write!(
                f,
                "Request body too large: {size} bytes (max: {max} bytes)"
            ),
            Self::InvalidCid(cid) => write!(f, "Invalid CID format: {cid}"),
            Self::InvalidContentType { expected, actual } => write!(
                f,
                "Invalid content type: expected {expected}, got {actual}"
            ),
            Self::MissingParameter(param) => write!(f, "Missing required parameter: {param}"),
            Self::BatchTooLarge { size, max } => write!(
                f,
                "Batch size too large: {size} items (max: {max} items)"
            ),
            Self::InvalidParameter { name, reason } => {
                write!(f, "Invalid parameter '{name}': {reason}")
            }
        }
    }
}

impl std::error::Error for ValidationError {}
894
895impl IntoResponse for ValidationError {
896    fn into_response(self) -> Response {
897        let request_id = Uuid::new_v4();
898        let error_message = self.to_string();
899
900        let (status, code) = match self {
901            ValidationError::BodyTooLarge { .. } => {
902                (StatusCode::PAYLOAD_TOO_LARGE, "BODY_TOO_LARGE")
903            }
904            ValidationError::InvalidCid(_) => (StatusCode::BAD_REQUEST, "INVALID_CID"),
905            ValidationError::InvalidContentType { .. } => {
906                (StatusCode::UNSUPPORTED_MEDIA_TYPE, "INVALID_CONTENT_TYPE")
907            }
908            ValidationError::MissingParameter(_) => (StatusCode::BAD_REQUEST, "MISSING_PARAMETER"),
909            ValidationError::BatchTooLarge { .. } => (StatusCode::BAD_REQUEST, "BATCH_TOO_LARGE"),
910            ValidationError::InvalidParameter { .. } => {
911                (StatusCode::BAD_REQUEST, "INVALID_PARAMETER")
912            }
913        };
914
915        let body = serde_json::json!({
916            "error": error_message,
917            "code": code,
918            "request_id": request_id.to_string(),
919        });
920
921        (status, serde_json::to_string(&body).unwrap()).into_response()
922    }
923}
924
925/// Validate CID format
926///
927/// Basic validation: CIDv0 starts with "Qm" and is 46 chars, CIDv1 is base32/base58
928pub fn validate_cid(cid: &str, config: &ValidationConfig) -> Result<(), ValidationError> {
929    // Empty CID is always invalid, regardless of validation settings
930    if cid.is_empty() {
931        return Err(ValidationError::InvalidCid(
932            "CID cannot be empty".to_string(),
933        ));
934    }
935
936    if !config.validate_cid_format {
937        return Ok(());
938    }
939
940    if cid.len() > config.max_cid_length {
941        return Err(ValidationError::InvalidCid(format!(
942            "CID too long: {} chars (max: {})",
943            cid.len(),
944            config.max_cid_length
945        )));
946    }
947
948    // Basic format check: CIDv0 or CIDv1
949    if cid.starts_with("Qm") && cid.len() == 46 {
950        // CIDv0 (base58btc encoded SHA-256 hash)
951        Ok(())
952    } else if cid.starts_with("b") || cid.starts_with("z") || cid.starts_with("f") {
953        // CIDv1 (multibase prefix)
954        Ok(())
955    } else {
956        Err(ValidationError::InvalidCid(
957            "Invalid CID format: must be CIDv0 (Qm...) or CIDv1 (b..., z..., f...)".to_string(),
958        ))
959    }
960}
961
962/// Validate batch size
963pub fn validate_batch_size(size: usize, config: &ValidationConfig) -> Result<(), ValidationError> {
964    if size == 0 {
965        return Err(ValidationError::InvalidParameter {
966            name: "batch".to_string(),
967            reason: "Batch cannot be empty".to_string(),
968        });
969    }
970
971    if size > config.max_batch_size {
972        return Err(ValidationError::BatchTooLarge {
973            size,
974            max: config.max_batch_size,
975        });
976    }
977
978    Ok(())
979}
980
981/// Validate content type
982pub fn validate_content_type(
983    headers: &HeaderMap,
984    expected: &str,
985    config: &ValidationConfig,
986) -> Result<(), ValidationError> {
987    if !config.content_type_validation {
988        return Ok(());
989    }
990
991    let content_type = headers
992        .get(header::CONTENT_TYPE)
993        .and_then(|h| h.to_str().ok())
994        .unwrap_or("");
995
996    if !content_type.starts_with(expected) {
997        return Err(ValidationError::InvalidContentType {
998            expected: expected.to_string(),
999            actual: content_type.to_string(),
1000        });
1001    }
1002
1003    Ok(())
1004}
1005
/// Validation middleware state
///
/// Router state extracted by [`validation_middleware`] through axum's `State`
/// extractor. It wraps the [`ValidationConfig`] that holds the validation
/// limits and toggles: body size, CID length/format, content-type checking and
/// batch size. `Clone` is required for it to be used as axum router state.
#[derive(Clone)]
pub struct ValidationState {
    /// Validation limits and toggles.
    pub config: ValidationConfig,
}
1011
1012/// Request validation middleware
1013///
1014/// Validates request size and basic parameters before processing
1015pub async fn validation_middleware(
1016    State(_validation_state): State<ValidationState>,
1017    req: Request,
1018    next: Next,
1019) -> Result<Response, ValidationError> {
1020    let (parts, body) = req.into_parts();
1021
1022    // Validate content-type for POST/PUT requests
1023    if parts.method == Method::POST || parts.method == Method::PUT {
1024        // Skip validation for multipart/form-data (handled by body parser)
1025        if let Some(content_type) = parts.headers.get(header::CONTENT_TYPE) {
1026            if let Ok(ct_str) = content_type.to_str() {
1027                if ct_str.contains("multipart/form-data") {
1028                    // Skip body size validation for multipart
1029                    let req = Request::from_parts(parts, body);
1030                    return Ok(next.run(req).await);
1031                }
1032            }
1033        }
1034    }
1035
1036    // Reconstruct request and continue
1037    let req = Request::from_parts(parts, body);
1038    Ok(next.run(req).await)
1039}
1040
#[cfg(test)]
mod tests {
    // Unit tests for the CORS, rate-limit, cache and compression configuration
    // types, plus the request-validation helpers defined in this module.
    use super::*;

    // --- CORS configuration ---------------------------------------------------

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(config.allowed_origins.is_empty());
        assert!(config.allowed_methods.contains(&Method::GET));
        assert!(config.allowed_methods.contains(&Method::POST));
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age, 86400);
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig::permissive();
        assert!(config.allowed_origins.contains("*"));
        assert!(config.is_origin_allowed("https://example.com"));
        assert!(config.is_origin_allowed("http://localhost:3000"));
    }

    #[test]
    fn test_cors_config_allow_origin() {
        let config = CorsConfig::default()
            .allow_origin("https://example.com")
            .allow_origin("https://api.example.com");

        assert!(config.is_origin_allowed("https://example.com"));
        assert!(config.is_origin_allowed("https://api.example.com"));
        assert!(!config.is_origin_allowed("https://other.com"));
    }

    // --- Rate limiting and caching defaults ----------------------------------

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.max_requests, 100);
        assert_eq!(config.window, Duration::from_secs(60));
        assert_eq!(config.burst_capacity, 10);
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.default_max_age, 3600);
        assert!(config.public);
        assert!(config.immutable_cids);
    }

    #[test]
    fn test_add_caching_headers() {
        let mut headers = HeaderMap::new();
        let config = CacheConfig::default();

        add_caching_headers(&mut headers, "QmTest123", &config);

        assert!(headers.contains_key(header::ETAG));
        assert!(headers.contains_key(header::CACHE_CONTROL));

        // The ETag is the CID wrapped in double quotes.
        let etag = headers.get(header::ETAG).unwrap().to_str().unwrap();
        assert_eq!(etag, "\"QmTest123\"");

        let cache_control = headers
            .get(header::CACHE_CONTROL)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(cache_control.contains("public"));
        assert!(cache_control.contains("max-age=3600"));
        assert!(cache_control.contains("immutable"));
    }

    #[test]
    fn test_check_etag_match() {
        let mut headers = HeaderMap::new();

        // No If-None-Match header
        assert!(!check_etag_match(&headers, "QmTest123"));

        // With matching ETag
        headers.insert(
            header::IF_NONE_MATCH,
            HeaderValue::from_static("\"QmTest123\""),
        );
        assert!(check_etag_match(&headers, "QmTest123"));

        // With non-matching ETag
        assert!(!check_etag_match(&headers, "QmOther456"));
    }

    #[tokio::test]
    async fn test_rate_limit_state() {
        let config = RateLimitConfig {
            max_requests: 5,
            window: Duration::from_secs(1),
            burst_capacity: 3,
        };
        let state = RateLimitState::new(config);

        // First 3 requests should succeed (burst capacity)
        // Only the burst allowance is asserted; requests beyond it are not exercised.
        for _ in 0..3 {
            let (allowed, _) = state.get_bucket("127.0.0.1").await;
            assert!(allowed);
        }
    }

    // --- Compression configuration --------------------------------------------

    #[test]
    fn test_compression_level_to_level() {
        assert_eq!(CompressionLevel::Fastest.to_level(), 1);
        assert_eq!(CompressionLevel::Balanced.to_level(), 5);
        assert_eq!(CompressionLevel::Best.to_level(), 9);
        assert_eq!(CompressionLevel::Custom(7).to_level(), 7);
        assert_eq!(CompressionLevel::Custom(15).to_level(), 9); // Capped at 9
    }

    #[test]
    fn test_compression_level_to_brotli_quality() {
        assert_eq!(CompressionLevel::Fastest.to_brotli_quality(), 1);
        assert_eq!(CompressionLevel::Balanced.to_brotli_quality(), 6);
        assert_eq!(CompressionLevel::Best.to_brotli_quality(), 11);
        assert_eq!(CompressionLevel::Custom(8).to_brotli_quality(), 8);
        assert_eq!(CompressionLevel::Custom(15).to_brotli_quality(), 11); // Capped at 11
    }

    #[test]
    fn test_compression_config_default() {
        let config = CompressionConfig::default();
        assert!(config.enable_gzip);
        assert!(config.enable_brotli);
        assert!(config.enable_deflate);
        assert_eq!(config.level, CompressionLevel::Balanced);
        assert_eq!(config.min_size, 1024);
    }

    #[test]
    fn test_compression_config_fast() {
        let config = CompressionConfig::fast();
        assert_eq!(config.level, CompressionLevel::Fastest);
        assert!(config.enable_gzip);
    }

    #[test]
    fn test_compression_config_best() {
        let config = CompressionConfig::best();
        assert_eq!(config.level, CompressionLevel::Best);
        assert!(config.enable_brotli);
    }

    #[test]
    fn test_compression_config_builder() {
        let config = CompressionConfig::default()
            .with_level(CompressionLevel::Custom(7))
            .with_min_size(2048)
            .with_algorithms(true, false, false);

        assert_eq!(config.level, CompressionLevel::Custom(7));
        assert_eq!(config.min_size, 2048);
        assert!(config.enable_gzip);
        assert!(!config.enable_brotli);
        assert!(!config.enable_deflate);
    }

    // --- Config `validate()` checks -------------------------------------------

    #[test]
    fn test_compression_config_validation_valid() {
        let config = CompressionConfig::default();
        assert!(config.validate().is_ok());

        let config = CompressionConfig::default().with_algorithms(true, false, false);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_compression_config_validation_invalid() {
        // No algorithms enabled
        let config = CompressionConfig::default().with_algorithms(false, false, false);
        assert!(config.validate().is_err());

        // Min size too large
        let config = CompressionConfig::default().with_min_size(200 * 1024 * 1024);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_rate_limit_config_validation_valid() {
        let config = RateLimitConfig::default();
        assert!(config.validate().is_ok());

        let config = RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(60),
            burst_capacity: 50,
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_rate_limit_config_validation_invalid() {
        // Zero max requests
        let config = RateLimitConfig {
            max_requests: 0,
            window: Duration::from_secs(60),
            burst_capacity: 10,
        };
        assert!(config.validate().is_err());

        // Zero window
        let config = RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(0),
            burst_capacity: 10,
        };
        assert!(config.validate().is_err());

        // Burst exceeds max
        let config = RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(60),
            burst_capacity: 200,
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_cache_config_validation_valid() {
        let config = CacheConfig::default();
        assert!(config.validate().is_ok());

        let config = CacheConfig {
            default_max_age: 86400, // 1 day
            public: true,
            immutable_cids: true,
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cache_config_validation_invalid() {
        // Max age too large (more than 1 year)
        let config = CacheConfig {
            default_max_age: 400 * 24 * 3600, // More than 1 year
            public: true,
            immutable_cids: true,
        };
        assert!(config.validate().is_err());
    }

    // Validation middleware tests

    #[test]
    fn test_validation_config_default() {
        let config = ValidationConfig::default();
        assert_eq!(config.max_body_size, 100 * 1024 * 1024);
        assert_eq!(config.max_cid_length, 100);
        assert!(config.validate_cid_format);
        assert!(config.content_type_validation);
        assert_eq!(config.max_batch_size, 1000);
    }

    #[test]
    fn test_validation_config_strict() {
        let config = ValidationConfig::strict();
        assert_eq!(config.max_body_size, 10 * 1024 * 1024);
        assert_eq!(config.max_cid_length, 64);
        assert_eq!(config.max_batch_size, 100);
    }

    #[test]
    fn test_validation_config_permissive() {
        let config = ValidationConfig::permissive();
        assert_eq!(config.max_body_size, 1024 * 1024 * 1024);
        assert_eq!(config.max_cid_length, 200);
        assert!(!config.validate_cid_format);
        assert!(!config.content_type_validation);
        assert_eq!(config.max_batch_size, 10000);
    }

    #[test]
    fn test_validate_cid_v0() {
        let config = ValidationConfig::default();

        // Valid CIDv0
        assert!(validate_cid("QmXoypizjW3WknFiJnKLwHCnL72vedxjQkDDP1mXWo6uco", &config).is_ok());

        // Invalid CIDv0 (wrong length)
        assert!(validate_cid("QmShort", &config).is_err());

        // Empty CID
        assert!(validate_cid("", &config).is_err());
    }

    #[test]
    fn test_validate_cid_v1() {
        let config = ValidationConfig::default();

        // Valid CIDv1 prefixes
        assert!(validate_cid(
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            &config
        )
        .is_ok());
        assert!(validate_cid("zb2rhk6GMPQF8p1kqXvhYnCMp3hGGUQVvqp6qjdvNLKqCqKCo", &config).is_ok());

        // Invalid format
        assert!(validate_cid("invalid_cid_format", &config).is_err());
    }

    #[test]
    fn test_validate_cid_disabled() {
        let config = ValidationConfig {
            validate_cid_format: false,
            ..Default::default()
        };

        // Should accept any string when validation is disabled
        assert!(validate_cid("invalid_format", &config).is_ok());
        assert!(validate_cid("", &config).is_err()); // Empty still fails
    }

    #[test]
    fn test_validate_batch_size_valid() {
        let config = ValidationConfig::default();

        assert!(validate_batch_size(1, &config).is_ok());
        assert!(validate_batch_size(100, &config).is_ok());
        assert!(validate_batch_size(1000, &config).is_ok());
    }

    #[test]
    fn test_validate_batch_size_invalid() {
        let config = ValidationConfig::default();

        // Empty batch
        assert!(validate_batch_size(0, &config).is_err());

        // Too large
        assert!(validate_batch_size(1001, &config).is_err());
        assert!(validate_batch_size(10000, &config).is_err());
    }

    #[test]
    fn test_validate_content_type_valid() {
        let config = ValidationConfig::default();
        let mut headers = HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );

        assert!(validate_content_type(&headers, "application/json", &config).is_ok());
    }

    #[test]
    fn test_validate_content_type_invalid() {
        let config = ValidationConfig::default();
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));

        assert!(validate_content_type(&headers, "application/json", &config).is_err());
    }

    #[test]
    fn test_validate_content_type_disabled() {
        let config = ValidationConfig {
            content_type_validation: false,
            ..Default::default()
        };

        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));

        // Should accept any content type when validation is disabled
        assert!(validate_content_type(&headers, "application/json", &config).is_ok());
    }

    // Display strings of `ValidationError` variants.
    #[test]
    fn test_validation_error_display() {
        let err = ValidationError::InvalidCid("test".to_string());
        assert_eq!(err.to_string(), "Invalid CID format: test");

        let err = ValidationError::BodyTooLarge {
            size: 200,
            max: 100,
        };
        assert!(err.to_string().contains("200 bytes"));
        assert!(err.to_string().contains("100 bytes"));

        let err = ValidationError::BatchTooLarge {
            size: 2000,
            max: 1000,
        };
        assert!(err.to_string().contains("2000 items"));
        assert!(err.to_string().contains("1000 items"));
    }
}