1use crate::auth::{AuthError, AuthState, Claims, Permission};
11use axum::{
12 body::Body,
13 extract::{Request, State},
14 http::{header, HeaderMap, HeaderValue, Method, StatusCode},
15 middleware::Next,
16 response::{IntoResponse, Response},
17};
18use std::collections::HashSet;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use uuid::Uuid;
23
24#[derive(Debug, Clone)]
30pub struct CorsConfig {
31 pub allowed_origins: HashSet<String>,
33 pub allowed_methods: HashSet<Method>,
35 pub allowed_headers: HashSet<String>,
37 pub exposed_headers: HashSet<String>,
39 pub allow_credentials: bool,
41 pub max_age: u64,
43}
44
45impl Default for CorsConfig {
46 fn default() -> Self {
47 let mut methods = HashSet::new();
48 methods.insert(Method::GET);
49 methods.insert(Method::POST);
50 methods.insert(Method::PUT);
51 methods.insert(Method::DELETE);
52 methods.insert(Method::OPTIONS);
53 methods.insert(Method::HEAD);
54
55 let mut headers = HashSet::new();
56 headers.insert("content-type".to_string());
57 headers.insert("authorization".to_string());
58 headers.insert("accept".to_string());
59 headers.insert("origin".to_string());
60 headers.insert("x-requested-with".to_string());
61
62 Self {
63 allowed_origins: HashSet::new(), allowed_methods: methods,
65 allowed_headers: headers,
66 exposed_headers: HashSet::new(),
67 allow_credentials: false,
68 max_age: 86400, }
70 }
71}
72
73impl CorsConfig {
74 pub fn permissive() -> Self {
76 let mut config = Self::default();
77 config.allowed_origins.insert("*".to_string());
78 config
79 }
80
81 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
83 self.allowed_origins.insert(origin.into());
84 self
85 }
86
87 pub fn allow_credentials(mut self, allow: bool) -> Self {
89 self.allow_credentials = allow;
90 self
91 }
92
93 fn is_origin_allowed(&self, origin: &str) -> bool {
95 if self.allowed_origins.is_empty() || self.allowed_origins.contains("*") {
96 true
97 } else {
98 self.allowed_origins.contains(origin)
99 }
100 }
101
102 fn methods_string(&self) -> String {
104 self.allowed_methods
105 .iter()
106 .map(|m| m.as_str())
107 .collect::<Vec<_>>()
108 .join(", ")
109 }
110
111 fn headers_string(&self) -> String {
113 self.allowed_headers
114 .iter()
115 .cloned()
116 .collect::<Vec<_>>()
117 .join(", ")
118 }
119}
120
121#[derive(Clone)]
123pub struct CorsState {
124 pub config: CorsConfig,
125}
126
127pub async fn cors_middleware(
131 State(cors_state): State<CorsState>,
132 req: Request,
133 next: Next,
134) -> Response {
135 let origin = req
136 .headers()
137 .get(header::ORIGIN)
138 .and_then(|h| h.to_str().ok())
139 .map(|s| s.to_string());
140
141 if req.method() == Method::OPTIONS {
143 return build_preflight_response(&cors_state.config, origin.as_deref());
144 }
145
146 let mut response = next.run(req).await;
148
149 add_cors_headers(
151 response.headers_mut(),
152 &cors_state.config,
153 origin.as_deref(),
154 );
155
156 response
157}
158
159fn build_preflight_response(config: &CorsConfig, origin: Option<&str>) -> Response {
161 let mut response = Response::builder()
162 .status(StatusCode::NO_CONTENT)
163 .body(Body::empty())
164 .unwrap();
165
166 add_cors_headers(response.headers_mut(), config, origin);
167
168 if let Ok(value) = HeaderValue::from_str(&config.methods_string()) {
170 response
171 .headers_mut()
172 .insert(header::ACCESS_CONTROL_ALLOW_METHODS, value);
173 }
174 if let Ok(value) = HeaderValue::from_str(&config.headers_string()) {
175 response
176 .headers_mut()
177 .insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value);
178 }
179 if let Ok(value) = HeaderValue::from_str(&config.max_age.to_string()) {
180 response
181 .headers_mut()
182 .insert(header::ACCESS_CONTROL_MAX_AGE, value);
183 }
184
185 response
186}
187
188fn add_cors_headers(headers: &mut HeaderMap, config: &CorsConfig, origin: Option<&str>) {
190 let origin_value = if let Some(origin) = origin {
192 if config.is_origin_allowed(origin) {
193 if config.allowed_origins.contains("*") && !config.allow_credentials {
194 "*"
195 } else {
196 origin
197 }
198 } else {
199 return; }
201 } else if config.allowed_origins.contains("*") {
202 "*"
203 } else {
204 return;
205 };
206
207 if let Ok(value) = HeaderValue::from_str(origin_value) {
208 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
209 }
210
211 if config.allow_credentials {
213 headers.insert(
214 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
215 HeaderValue::from_static("true"),
216 );
217 }
218
219 if !config.exposed_headers.is_empty() {
221 let exposed = config
222 .exposed_headers
223 .iter()
224 .cloned()
225 .collect::<Vec<_>>()
226 .join(", ");
227 if let Ok(value) = HeaderValue::from_str(&exposed) {
228 headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, value);
229 }
230 }
231}
232
/// Parameters of the per-client token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Number of requests replenished over one `window`; together they
    /// define the refill rate (`max_requests / window`).
    pub max_requests: u32,
    /// Length of the replenishment window. Sub-second windows are valid.
    pub window: Duration,
    /// Capacity of each bucket, i.e. how many requests may be served back to back.
    pub burst_capacity: u32,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60 seconds with a burst of 10.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window: Duration::from_secs(60),
            burst_capacity: 10,
        }
    }
}

impl RateLimitConfig {
    /// Checks the configuration for values the limiter cannot work with.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `max_requests` is zero, the
    /// window has zero length, `burst_capacity` is zero, or the burst
    /// capacity exceeds `max_requests`.
    pub fn validate(&self) -> Result<(), String> {
        if self.max_requests == 0 {
            return Err("Maximum requests must be greater than 0".to_string());
        }

        // `is_zero` rather than `as_secs() == 0`: a 500 ms window is greater
        // than zero and must not be rejected.
        if self.window.is_zero() {
            return Err("Time window must be greater than 0".to_string());
        }

        if self.burst_capacity == 0 {
            return Err("Burst capacity must be greater than 0".to_string());
        }

        if self.burst_capacity > self.max_requests {
            return Err(format!(
                "Burst capacity ({}) cannot exceed max requests ({})",
                self.burst_capacity, self.max_requests
            ));
        }

        Ok(())
    }
}
283
/// A single token bucket: holds up to `capacity` tokens and gains
/// `refill_rate` tokens per elapsed second.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Instant of the most recent refill.
    last_update: Instant,
    /// Upper bound for `tokens`.
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: u32, refill_rate: f64) -> Self {
        let capacity = f64::from(capacity);
        Self {
            tokens: capacity,
            last_update: Instant::now(),
            capacity,
            refill_rate,
        }
    }

    /// Refills the bucket, then takes one token if at least one whole token
    /// is available. Returns whether a token was taken.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        let granted = self.tokens >= 1.0;
        if granted {
            self.tokens -= 1.0;
        }
        granted
    }

    /// Credits tokens for the time elapsed since the last refill, capped at
    /// `capacity`, and records the current instant.
    fn refill(&mut self) {
        let now = Instant::now();
        let idle_secs = now.duration_since(self.last_update).as_secs_f64();
        let replenished = self.tokens + idle_secs * self.refill_rate;
        self.tokens = replenished.min(self.capacity);
        self.last_update = now;
    }

    /// Whole tokens currently available (fraction truncated).
    fn tokens_remaining(&self) -> u32 {
        let whole = self.tokens as u32;
        whole
    }
}
324
325#[derive(Clone)]
327pub struct RateLimitState {
328 config: RateLimitConfig,
329 buckets: Arc<Mutex<std::collections::HashMap<String, TokenBucket>>>,
330}
331
332impl RateLimitState {
333 pub fn new(config: RateLimitConfig) -> Self {
335 Self {
336 config,
337 buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
338 }
339 }
340
341 async fn get_bucket(&self, ip: &str) -> (bool, u32) {
343 let mut buckets = self.buckets.lock().await;
344
345 let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
346
347 let bucket = buckets
348 .entry(ip.to_string())
349 .or_insert_with(|| TokenBucket::new(self.config.burst_capacity, refill_rate));
350
351 let allowed = bucket.try_acquire();
352 let remaining = bucket.tokens_remaining();
353
354 (allowed, remaining)
355 }
356}
357
358pub async fn rate_limit_middleware(
362 State(rate_state): State<RateLimitState>,
363 req: Request,
364 next: Next,
365) -> Result<Response, RateLimitError> {
366 let ip = extract_client_ip(&req);
368
369 let (allowed, remaining) = rate_state.get_bucket(&ip).await;
370
371 if !allowed {
372 return Err(RateLimitError::TooManyRequests);
373 }
374
375 let mut response = next.run(req).await;
376
377 let headers = response.headers_mut();
379 if let Ok(value) = HeaderValue::from_str(&rate_state.config.max_requests.to_string()) {
380 headers.insert("X-RateLimit-Limit", value);
381 }
382 if let Ok(value) = HeaderValue::from_str(&remaining.to_string()) {
383 headers.insert("X-RateLimit-Remaining", value);
384 }
385
386 Ok(response)
387}
388
389fn extract_client_ip(req: &Request) -> String {
391 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
393 if let Ok(s) = forwarded.to_str() {
394 if let Some(ip) = s.split(',').next() {
395 return ip.trim().to_string();
396 }
397 }
398 }
399
400 if let Some(real_ip) = req.headers().get("x-real-ip") {
402 if let Ok(s) = real_ip.to_str() {
403 return s.to_string();
404 }
405 }
406
407 "unknown".to_string()
409}
410
411#[derive(Debug)]
413pub enum RateLimitError {
414 TooManyRequests,
415}
416
417impl IntoResponse for RateLimitError {
418 fn into_response(self) -> Response {
419 let (status, message) = match self {
420 RateLimitError::TooManyRequests => (
421 StatusCode::TOO_MANY_REQUESTS,
422 "Rate limit exceeded. Please retry later.",
423 ),
424 };
425
426 let mut response = (status, message).into_response();
427
428 response
430 .headers_mut()
431 .insert(header::RETRY_AFTER, HeaderValue::from_static("60"));
432
433 response
434 }
435}
436
/// Compression effort presets, plus an explicit numeric level.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionLevel {
    /// Lowest effort.
    Fastest,
    /// Middle setting; this is the default.
    #[default]
    Balanced,
    /// Highest effort.
    Best,
    /// Caller-chosen level; clamped to each scale's maximum on conversion.
    Custom(u32),
}

impl CompressionLevel {
    /// Maps the preset onto a 0..=9 scale (1 / 5 / 9 for the named presets);
    /// custom levels above 9 are clamped to 9.
    pub fn to_level(self) -> u32 {
        const CEILING: u32 = 9;
        let requested = match self {
            CompressionLevel::Fastest => 1,
            CompressionLevel::Balanced => 5,
            CompressionLevel::Best => 9,
            CompressionLevel::Custom(level) => level,
        };
        requested.min(CEILING)
    }

    /// Maps the preset onto a 0..=11 scale (1 / 6 / 11 for the named
    /// presets); custom levels above 11 are clamped to 11.
    pub fn to_brotli_quality(self) -> u32 {
        const CEILING: u32 = 11;
        let requested = match self {
            CompressionLevel::Fastest => 1,
            CompressionLevel::Balanced => 6,
            CompressionLevel::Best => 11,
            CompressionLevel::Custom(level) => level,
        };
        requested.min(CEILING)
    }
}
476
477#[derive(Debug, Clone)]
479pub struct CompressionConfig {
480 pub enable_gzip: bool,
482 pub enable_brotli: bool,
484 pub enable_deflate: bool,
486 pub level: CompressionLevel,
488 pub min_size: usize,
490}
491
492impl Default for CompressionConfig {
493 fn default() -> Self {
494 Self {
495 enable_gzip: true,
496 enable_brotli: true,
497 enable_deflate: true,
498 level: CompressionLevel::Balanced,
499 min_size: 1024, }
501 }
502}
503
504impl CompressionConfig {
505 pub fn fast() -> Self {
507 Self {
508 level: CompressionLevel::Fastest,
509 ..Default::default()
510 }
511 }
512
513 pub fn best() -> Self {
515 Self {
516 level: CompressionLevel::Best,
517 ..Default::default()
518 }
519 }
520
521 pub fn with_level(mut self, level: CompressionLevel) -> Self {
523 self.level = level;
524 self
525 }
526
527 pub fn with_min_size(mut self, min_size: usize) -> Self {
529 self.min_size = min_size;
530 self
531 }
532
533 pub fn with_algorithms(mut self, gzip: bool, brotli: bool, deflate: bool) -> Self {
535 self.enable_gzip = gzip;
536 self.enable_brotli = brotli;
537 self.enable_deflate = deflate;
538 self
539 }
540
541 pub fn validate(&self) -> Result<(), String> {
543 if !self.enable_gzip && !self.enable_brotli && !self.enable_deflate {
545 return Err("At least one compression algorithm must be enabled".to_string());
546 }
547
548 if self.min_size > 100 * 1024 * 1024 {
550 return Err(format!(
551 "Minimum compression size {} is too large (max: 100MB)",
552 self.min_size
553 ));
554 }
555
556 Ok(())
557 }
558}
559
/// Settings used to build the `Cache-Control` header for content responses.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Value of the `max-age` directive.
    pub default_max_age: u64,
    /// `true` emits `public`, `false` emits `private`.
    pub public: bool,
    /// When `true`, the `immutable` directive is appended.
    pub immutable_cids: bool,
}

impl Default for CacheConfig {
    /// Public, immutable, `max-age` of 3600.
    fn default() -> Self {
        Self {
            default_max_age: 3600,
            public: true,
            immutable_cids: true,
        }
    }
}

impl CacheConfig {
    /// Checks that `default_max_age` does not exceed one year
    /// (365 * 24 * 3600).
    ///
    /// # Errors
    ///
    /// Returns a message naming the offending value and the limit.
    pub fn validate(&self) -> Result<(), String> {
        const MAX_AGE_LIMIT: u64 = 365 * 24 * 3600;

        match self.default_max_age {
            age if age > MAX_AGE_LIMIT => Err(format!(
                "Max age {} exceeds maximum {} (1 year)",
                age, MAX_AGE_LIMIT
            )),
            _ => Ok(()),
        }
    }
}
601
602pub fn add_caching_headers(headers: &mut HeaderMap, cid: &str, config: &CacheConfig) {
604 if let Ok(etag) = HeaderValue::from_str(&format!("\"{}\"", cid)) {
606 headers.insert(header::ETAG, etag);
607 }
608
609 let mut cache_control = String::new();
611 if config.public {
612 cache_control.push_str("public, ");
613 } else {
614 cache_control.push_str("private, ");
615 }
616 cache_control.push_str(&format!("max-age={}", config.default_max_age));
617
618 if config.immutable_cids {
620 cache_control.push_str(", immutable");
621 }
622
623 if let Ok(value) = HeaderValue::from_str(&cache_control) {
624 headers.insert(header::CACHE_CONTROL, value);
625 }
626}
627
628pub fn check_etag_match(headers: &HeaderMap, cid: &str) -> bool {
630 if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH) {
631 if let Ok(value) = if_none_match.to_str() {
632 let etag = value.trim().trim_matches('"');
634 return etag == cid;
635 }
636 }
637 false
638}
639
640pub fn not_modified_response(cid: &str, config: &CacheConfig) -> Response {
642 let mut response = Response::builder()
643 .status(StatusCode::NOT_MODIFIED)
644 .body(Body::empty())
645 .unwrap();
646
647 add_caching_headers(response.headers_mut(), cid, config);
648
649 response
650}
651
652#[derive(Debug, Clone)]
654pub struct AuthUser {
655 pub user_id: Uuid,
656 pub username: String,
657 pub claims: Option<Claims>,
658}
659
660fn authenticate_user(req: &Request, auth_state: &AuthState) -> Result<AuthUser, AuthError> {
664 let auth_header = req
665 .headers()
666 .get(header::AUTHORIZATION)
667 .and_then(|h| h.to_str().ok())
668 .ok_or(AuthError::InvalidToken(
669 "Missing Authorization header".to_string(),
670 ))?;
671
672 if let Some(token) = auth_header.strip_prefix("Bearer ") {
674 let claims = auth_state.jwt_manager.validate_token(token)?;
675 let user = auth_state.user_store.get_user(&claims.username)?;
676
677 return Ok(AuthUser {
678 user_id: user.id,
679 username: user.username,
680 claims: Some(claims),
681 });
682 }
683
684 if auth_header.starts_with("ipfrs_") {
686 let (_api_key, user_id) = auth_state.api_key_store.authenticate(auth_header)?;
687 let user = auth_state.user_store.get_by_id(&user_id)?;
688
689 return Ok(AuthUser {
690 user_id: user.id,
691 username: user.username,
692 claims: None,
693 });
694 }
695
696 Err(AuthError::InvalidToken(
697 "Authorization header must be either 'Bearer <token>' or 'ipfrs_<key>'".to_string(),
698 ))
699}
700
701pub async fn auth_middleware(
705 State(auth_state): State<AuthState>,
706 mut req: Request,
707 next: Next,
708) -> Result<Response, AuthMiddlewareError> {
709 let auth_user = authenticate_user(&req, &auth_state)?;
711
712 req.extensions_mut().insert(auth_user);
714
715 Ok(next.run(req).await)
716}
717
718type PermissionCheckFuture = std::pin::Pin<
720 Box<dyn std::future::Future<Output = Result<Response, AuthMiddlewareError>> + Send>,
721>;
722
723pub fn require_permission(
727 required: Permission,
728) -> impl Fn(State<AuthState>, Request, Next) -> PermissionCheckFuture + Clone {
729 move |State(auth_state): State<AuthState>, req: Request, next: Next| {
730 let required = required;
731 Box::pin(async move {
732 let auth_user = req
734 .extensions()
735 .get::<AuthUser>()
736 .ok_or_else(|| AuthError::InvalidToken("User not authenticated".to_string()))?;
737
738 let user = auth_state.user_store.get_by_id(&auth_user.user_id)?;
740
741 if !user.has_permission(required) {
743 return Err(AuthMiddlewareError::from(
744 AuthError::InsufficientPermissions,
745 ));
746 }
747
748 Ok(next.run(req).await)
749 })
750 }
751}
752
753#[derive(Debug)]
755pub struct AuthMiddlewareError {
756 error: AuthError,
757}
758
759impl From<AuthError> for AuthMiddlewareError {
760 fn from(error: AuthError) -> Self {
761 Self { error }
762 }
763}
764
765impl IntoResponse for AuthMiddlewareError {
766 fn into_response(self) -> Response {
767 let (status, message) = match self.error {
768 AuthError::InvalidToken(_) | AuthError::TokenExpired => {
769 (StatusCode::UNAUTHORIZED, "Authentication required")
770 }
771 AuthError::InsufficientPermissions => {
772 (StatusCode::FORBIDDEN, "Insufficient permissions")
773 }
774 AuthError::UserNotFound | AuthError::InvalidCredentials => {
775 (StatusCode::UNAUTHORIZED, "Invalid credentials")
776 }
777 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
778 };
779
780 (status, message).into_response()
781 }
782}
783
/// Limits and switches for request validation.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Largest accepted body size. NOTE(review): not read by the code
    /// visible in this part of the file — confirm where it is enforced.
    pub max_body_size: usize,
    /// Longest accepted CID, in bytes of the string.
    pub max_cid_length: usize,
    /// Enables the length and format checks in `validate_cid`.
    pub validate_cid_format: bool,
    /// Enables the check in `validate_content_type`.
    pub content_type_validation: bool,
    /// Largest accepted number of items in a batch.
    pub max_batch_size: usize,
}

impl Default for ValidationConfig {
    /// 100 MiB bodies, CIDs up to 100 bytes, batches up to 1000 items, all
    /// checks enabled.
    fn default() -> Self {
        Self {
            max_body_size: 100 * Self::MIB,
            max_cid_length: 100,
            validate_cid_format: true,
            content_type_validation: true,
            max_batch_size: 1000,
        }
    }
}

impl ValidationConfig {
    /// Bytes in one mebibyte.
    const MIB: usize = 1024 * 1024;

    /// Tight limits: 10 MiB bodies, CIDs up to 64 bytes, batches up to 100
    /// items, all checks enabled.
    pub fn strict() -> Self {
        Self {
            max_body_size: 10 * Self::MIB,
            max_cid_length: 64,
            max_batch_size: 100,
            validate_cid_format: true,
            content_type_validation: true,
        }
    }

    /// Loose limits: 1 GiB bodies, CIDs up to 200 bytes, batches up to 10000
    /// items, CID-format and content-type checks disabled.
    pub fn permissive() -> Self {
        Self {
            max_body_size: 1024 * Self::MIB,
            max_cid_length: 200,
            max_batch_size: 10000,
            validate_cid_format: false,
            content_type_validation: false,
        }
    }
}
838
/// Reasons a request can fail validation.
#[derive(Debug)]
pub enum ValidationError {
    /// The body is larger than the configured maximum.
    BodyTooLarge { size: usize, max: usize },
    /// The CID is unacceptable; the payload explains why.
    InvalidCid(String),
    /// The `Content-Type` does not match what the endpoint expects.
    InvalidContentType { expected: String, actual: String },
    /// A required parameter (named by the payload) is absent.
    MissingParameter(String),
    /// The batch holds more items than the configured maximum.
    BatchTooLarge { size: usize, max: usize },
    /// A parameter is present but unusable.
    InvalidParameter { name: String, reason: String },
}

impl std::fmt::Display for ValidationError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BodyTooLarge { size, max } => {
                write!(f, "Request body too large: {size} bytes (max: {max} bytes)")
            }
            Self::InvalidCid(detail) => write!(f, "Invalid CID format: {detail}"),
            Self::InvalidContentType { expected, actual } => {
                write!(f, "Invalid content type: expected {expected}, got {actual}")
            }
            Self::MissingParameter(name) => {
                write!(f, "Missing required parameter: {name}")
            }
            Self::BatchTooLarge { size, max } => {
                write!(f, "Batch size too large: {size} items (max: {max} items)")
            }
            Self::InvalidParameter { name, reason } => {
                write!(f, "Invalid parameter '{name}': {reason}")
            }
        }
    }
}

impl std::error::Error for ValidationError {}
894
895impl IntoResponse for ValidationError {
896 fn into_response(self) -> Response {
897 let request_id = Uuid::new_v4();
898 let error_message = self.to_string();
899
900 let (status, code) = match self {
901 ValidationError::BodyTooLarge { .. } => {
902 (StatusCode::PAYLOAD_TOO_LARGE, "BODY_TOO_LARGE")
903 }
904 ValidationError::InvalidCid(_) => (StatusCode::BAD_REQUEST, "INVALID_CID"),
905 ValidationError::InvalidContentType { .. } => {
906 (StatusCode::UNSUPPORTED_MEDIA_TYPE, "INVALID_CONTENT_TYPE")
907 }
908 ValidationError::MissingParameter(_) => (StatusCode::BAD_REQUEST, "MISSING_PARAMETER"),
909 ValidationError::BatchTooLarge { .. } => (StatusCode::BAD_REQUEST, "BATCH_TOO_LARGE"),
910 ValidationError::InvalidParameter { .. } => {
911 (StatusCode::BAD_REQUEST, "INVALID_PARAMETER")
912 }
913 };
914
915 let body = serde_json::json!({
916 "error": error_message,
917 "code": code,
918 "request_id": request_id.to_string(),
919 });
920
921 (status, serde_json::to_string(&body).unwrap()).into_response()
922 }
923}
924
925pub fn validate_cid(cid: &str, config: &ValidationConfig) -> Result<(), ValidationError> {
929 if cid.is_empty() {
931 return Err(ValidationError::InvalidCid(
932 "CID cannot be empty".to_string(),
933 ));
934 }
935
936 if !config.validate_cid_format {
937 return Ok(());
938 }
939
940 if cid.len() > config.max_cid_length {
941 return Err(ValidationError::InvalidCid(format!(
942 "CID too long: {} chars (max: {})",
943 cid.len(),
944 config.max_cid_length
945 )));
946 }
947
948 if cid.starts_with("Qm") && cid.len() == 46 {
950 Ok(())
952 } else if cid.starts_with("b") || cid.starts_with("z") || cid.starts_with("f") {
953 Ok(())
955 } else {
956 Err(ValidationError::InvalidCid(
957 "Invalid CID format: must be CIDv0 (Qm...) or CIDv1 (b..., z..., f...)".to_string(),
958 ))
959 }
960}
961
962pub fn validate_batch_size(size: usize, config: &ValidationConfig) -> Result<(), ValidationError> {
964 if size == 0 {
965 return Err(ValidationError::InvalidParameter {
966 name: "batch".to_string(),
967 reason: "Batch cannot be empty".to_string(),
968 });
969 }
970
971 if size > config.max_batch_size {
972 return Err(ValidationError::BatchTooLarge {
973 size,
974 max: config.max_batch_size,
975 });
976 }
977
978 Ok(())
979}
980
981pub fn validate_content_type(
983 headers: &HeaderMap,
984 expected: &str,
985 config: &ValidationConfig,
986) -> Result<(), ValidationError> {
987 if !config.content_type_validation {
988 return Ok(());
989 }
990
991 let content_type = headers
992 .get(header::CONTENT_TYPE)
993 .and_then(|h| h.to_str().ok())
994 .unwrap_or("");
995
996 if !content_type.starts_with(expected) {
997 return Err(ValidationError::InvalidContentType {
998 expected: expected.to_string(),
999 actual: content_type.to_string(),
1000 });
1001 }
1002
1003 Ok(())
1004}
1005
1006#[derive(Clone)]
1008pub struct ValidationState {
1009 pub config: ValidationConfig,
1010}
1011
1012pub async fn validation_middleware(
1016 State(_validation_state): State<ValidationState>,
1017 req: Request,
1018 next: Next,
1019) -> Result<Response, ValidationError> {
1020 let (parts, body) = req.into_parts();
1021
1022 if parts.method == Method::POST || parts.method == Method::PUT {
1024 if let Some(content_type) = parts.headers.get(header::CONTENT_TYPE) {
1026 if let Ok(ct_str) = content_type.to_str() {
1027 if ct_str.contains("multipart/form-data") {
1028 let req = Request::from_parts(parts, body);
1030 return Ok(next.run(req).await);
1031 }
1032 }
1033 }
1034 }
1035
1036 let req = Request::from_parts(parts, body);
1038 Ok(next.run(req).await)
1039}
1040
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map holding only the given `Content-Type`.
    fn headers_with_content_type(value: &'static str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(header::CONTENT_TYPE, HeaderValue::from_static(value));
        map
    }

    // ---- CORS -----------------------------------------------------------

    #[test]
    fn test_cors_config_default() {
        let cors = CorsConfig::default();
        assert!(cors.allowed_origins.is_empty());
        for method in [Method::GET, Method::POST] {
            assert!(cors.allowed_methods.contains(&method));
        }
        assert!(!cors.allow_credentials);
        assert_eq!(cors.max_age, 86400);
    }

    #[test]
    fn test_cors_config_permissive() {
        let cors = CorsConfig::permissive();
        assert!(cors.allowed_origins.contains("*"));
        for origin in ["https://example.com", "http://localhost:3000"] {
            assert!(cors.is_origin_allowed(origin));
        }
    }

    #[test]
    fn test_cors_config_allow_origin() {
        let listed = ["https://example.com", "https://api.example.com"];
        let cors = listed
            .iter()
            .fold(CorsConfig::default(), |cfg, origin| cfg.allow_origin(*origin));

        for origin in listed {
            assert!(cors.is_origin_allowed(origin));
        }
        assert!(!cors.is_origin_allowed("https://other.com"));
    }

    // ---- Rate limiting --------------------------------------------------

    #[test]
    fn test_rate_limit_config_default() {
        let limits = RateLimitConfig::default();
        assert_eq!(limits.max_requests, 100);
        assert_eq!(limits.window, Duration::from_secs(60));
        assert_eq!(limits.burst_capacity, 10);
    }

    // ---- Caching --------------------------------------------------------

    #[test]
    fn test_cache_config_default() {
        let cache = CacheConfig::default();
        assert_eq!(cache.default_max_age, 3600);
        assert!(cache.public);
        assert!(cache.immutable_cids);
    }

    #[test]
    fn test_add_caching_headers() {
        let mut map = HeaderMap::new();
        add_caching_headers(&mut map, "QmTest123", &CacheConfig::default());

        assert!(map.contains_key(header::ETAG));
        assert!(map.contains_key(header::CACHE_CONTROL));

        let etag = map.get(header::ETAG).unwrap().to_str().unwrap();
        assert_eq!(etag, "\"QmTest123\"");

        let cache_control = map.get(header::CACHE_CONTROL).unwrap().to_str().unwrap();
        for directive in ["public", "max-age=3600", "immutable"] {
            assert!(cache_control.contains(directive));
        }
    }

    #[test]
    fn test_check_etag_match() {
        let mut map = HeaderMap::new();

        // No If-None-Match header: never a match.
        assert!(!check_etag_match(&map, "QmTest123"));

        map.insert(
            header::IF_NONE_MATCH,
            HeaderValue::from_static("\"QmTest123\""),
        );
        // Same CID matches, a different one does not.
        assert!(check_etag_match(&map, "QmTest123"));
        assert!(!check_etag_match(&map, "QmOther456"));
    }

    #[tokio::test]
    async fn test_rate_limit_state() {
        let state = RateLimitState::new(RateLimitConfig {
            max_requests: 5,
            window: Duration::from_secs(1),
            burst_capacity: 3,
        });

        // The bucket starts with `burst_capacity` tokens, so the first three
        // requests must pass.
        for _ in 0..3 {
            let (allowed, _remaining) = state.get_bucket("127.0.0.1").await;
            assert!(allowed);
        }
    }

    // ---- Compression ----------------------------------------------------

    #[test]
    fn test_compression_level_to_level() {
        let cases = [
            (CompressionLevel::Fastest, 1),
            (CompressionLevel::Balanced, 5),
            (CompressionLevel::Best, 9),
            (CompressionLevel::Custom(7), 7),
            (CompressionLevel::Custom(15), 9),
        ];
        for (level, expected) in cases {
            assert_eq!(level.to_level(), expected);
        }
    }

    #[test]
    fn test_compression_level_to_brotli_quality() {
        let cases = [
            (CompressionLevel::Fastest, 1),
            (CompressionLevel::Balanced, 6),
            (CompressionLevel::Best, 11),
            (CompressionLevel::Custom(8), 8),
            (CompressionLevel::Custom(15), 11),
        ];
        for (level, expected) in cases {
            assert_eq!(level.to_brotli_quality(), expected);
        }
    }

    #[test]
    fn test_compression_config_default() {
        let compression = CompressionConfig::default();
        assert!(compression.enable_gzip);
        assert!(compression.enable_brotli);
        assert!(compression.enable_deflate);
        assert_eq!(compression.level, CompressionLevel::Balanced);
        assert_eq!(compression.min_size, 1024);
    }

    #[test]
    fn test_compression_config_fast() {
        let compression = CompressionConfig::fast();
        assert_eq!(compression.level, CompressionLevel::Fastest);
        assert!(compression.enable_gzip);
    }

    #[test]
    fn test_compression_config_best() {
        let compression = CompressionConfig::best();
        assert_eq!(compression.level, CompressionLevel::Best);
        assert!(compression.enable_brotli);
    }

    #[test]
    fn test_compression_config_builder() {
        let compression = CompressionConfig::default()
            .with_level(CompressionLevel::Custom(7))
            .with_min_size(2048)
            .with_algorithms(true, false, false);

        assert_eq!(compression.level, CompressionLevel::Custom(7));
        assert_eq!(compression.min_size, 2048);
        assert!(compression.enable_gzip);
        assert!(!compression.enable_brotli);
        assert!(!compression.enable_deflate);
    }

    #[test]
    fn test_compression_config_validation_valid() {
        let candidates = [
            CompressionConfig::default(),
            CompressionConfig::default().with_algorithms(true, false, false),
        ];
        for candidate in candidates {
            assert!(candidate.validate().is_ok());
        }
    }

    #[test]
    fn test_compression_config_validation_invalid() {
        let candidates = [
            // No algorithm enabled.
            CompressionConfig::default().with_algorithms(false, false, false),
            // Threshold above the 100 MiB ceiling.
            CompressionConfig::default().with_min_size(200 * 1024 * 1024),
        ];
        for candidate in candidates {
            assert!(candidate.validate().is_err());
        }
    }

    // ---- Config validation ----------------------------------------------

    #[test]
    fn test_rate_limit_config_validation_valid() {
        let candidates = [
            RateLimitConfig::default(),
            RateLimitConfig {
                max_requests: 100,
                window: Duration::from_secs(60),
                burst_capacity: 50,
            },
        ];
        for candidate in candidates {
            assert!(candidate.validate().is_ok());
        }
    }

    #[test]
    fn test_rate_limit_config_validation_invalid() {
        let candidates = [
            // Zero requests.
            RateLimitConfig {
                max_requests: 0,
                window: Duration::from_secs(60),
                burst_capacity: 10,
            },
            // Zero-length window.
            RateLimitConfig {
                max_requests: 100,
                window: Duration::from_secs(0),
                burst_capacity: 10,
            },
            // Burst larger than the request budget.
            RateLimitConfig {
                max_requests: 100,
                window: Duration::from_secs(60),
                burst_capacity: 200,
            },
        ];
        for candidate in candidates {
            assert!(candidate.validate().is_err());
        }
    }

    #[test]
    fn test_cache_config_validation_valid() {
        let candidates = [
            CacheConfig::default(),
            CacheConfig {
                default_max_age: 86400,
                public: true,
                immutable_cids: true,
            },
        ];
        for candidate in candidates {
            assert!(candidate.validate().is_ok());
        }
    }

    #[test]
    fn test_cache_config_validation_invalid() {
        // 400 days exceeds the one-year ceiling.
        let too_long = CacheConfig {
            default_max_age: 400 * 24 * 3600,
            public: true,
            immutable_cids: true,
        };
        assert!(too_long.validate().is_err());
    }

    // ---- Request validation ---------------------------------------------

    #[test]
    fn test_validation_config_default() {
        let validation = ValidationConfig::default();
        assert_eq!(validation.max_body_size, 100 * 1024 * 1024);
        assert_eq!(validation.max_cid_length, 100);
        assert!(validation.validate_cid_format);
        assert!(validation.content_type_validation);
        assert_eq!(validation.max_batch_size, 1000);
    }

    #[test]
    fn test_validation_config_strict() {
        let validation = ValidationConfig::strict();
        assert_eq!(validation.max_body_size, 10 * 1024 * 1024);
        assert_eq!(validation.max_cid_length, 64);
        assert_eq!(validation.max_batch_size, 100);
    }

    #[test]
    fn test_validation_config_permissive() {
        let validation = ValidationConfig::permissive();
        assert_eq!(validation.max_body_size, 1024 * 1024 * 1024);
        assert_eq!(validation.max_cid_length, 200);
        assert!(!validation.validate_cid_format);
        assert!(!validation.content_type_validation);
        assert_eq!(validation.max_batch_size, 10000);
    }

    #[test]
    fn test_validate_cid_v0() {
        let validation = ValidationConfig::default();

        // A well-formed 46-character CIDv0.
        assert!(
            validate_cid("QmXoypizjW3WknFiJnKLwHCnL72vedxjQkDDP1mXWo6uco", &validation).is_ok()
        );

        // Too short for CIDv0, and the empty string.
        for rejected in ["QmShort", ""] {
            assert!(validate_cid(rejected, &validation).is_err());
        }
    }

    #[test]
    fn test_validate_cid_v1() {
        let validation = ValidationConfig::default();

        let accepted = [
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "zb2rhk6GMPQF8p1kqXvhYnCMp3hGGUQVvqp6qjdvNLKqCqKCo",
        ];
        for cid in accepted {
            assert!(validate_cid(cid, &validation).is_ok());
        }

        assert!(validate_cid("invalid_cid_format", &validation).is_err());
    }

    #[test]
    fn test_validate_cid_disabled() {
        let validation = ValidationConfig {
            validate_cid_format: false,
            ..Default::default()
        };

        // Format checks are off, but emptiness is still rejected.
        assert!(validate_cid("invalid_format", &validation).is_ok());
        assert!(validate_cid("", &validation).is_err());
    }

    #[test]
    fn test_validate_batch_size_valid() {
        let validation = ValidationConfig::default();
        for size in [1, 100, 1000] {
            assert!(validate_batch_size(size, &validation).is_ok());
        }
    }

    #[test]
    fn test_validate_batch_size_invalid() {
        let validation = ValidationConfig::default();
        // Empty batch, then two sizes above the limit of 1000.
        for size in [0, 1001, 10000] {
            assert!(validate_batch_size(size, &validation).is_err());
        }
    }

    #[test]
    fn test_validate_content_type_valid() {
        let validation = ValidationConfig::default();
        let map = headers_with_content_type("application/json");

        assert!(validate_content_type(&map, "application/json", &validation).is_ok());
    }

    #[test]
    fn test_validate_content_type_invalid() {
        let validation = ValidationConfig::default();
        let map = headers_with_content_type("text/plain");

        assert!(validate_content_type(&map, "application/json", &validation).is_err());
    }

    #[test]
    fn test_validate_content_type_disabled() {
        let validation = ValidationConfig {
            content_type_validation: false,
            ..Default::default()
        };
        let map = headers_with_content_type("text/plain");

        // With the check switched off, a mismatching type is accepted.
        assert!(validate_content_type(&map, "application/json", &validation).is_ok());
    }

    #[test]
    fn test_validation_error_display() {
        let invalid_cid = ValidationError::InvalidCid("test".to_string());
        assert_eq!(invalid_cid.to_string(), "Invalid CID format: test");

        let body_message = ValidationError::BodyTooLarge {
            size: 200,
            max: 100,
        }
        .to_string();
        for fragment in ["200 bytes", "100 bytes"] {
            assert!(body_message.contains(fragment));
        }

        let batch_message = ValidationError::BatchTooLarge {
            size: 2000,
            max: 1000,
        }
        .to_string();
        for fragment in ["2000 items", "1000 items"] {
            assert!(batch_message.contains(fragment));
        }
    }
}