// allsource_core/middleware.rs

1use crate::auth::{AuthManager, Claims, Permission};
2use crate::error::AllSourceError;
3use crate::rate_limit::RateLimiter;
4use axum::{
5    extract::{Request, State},
6    http::{HeaderMap, StatusCode},
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11
12/// Authentication state shared across requests
13#[derive(Clone)]
14pub struct AuthState {
15    pub auth_manager: Arc<AuthManager>,
16}
17
18/// Rate limiting state
19#[derive(Clone)]
20pub struct RateLimitState {
21    pub rate_limiter: Arc<RateLimiter>,
22}
23
24/// Authenticated request context
25#[derive(Debug, Clone)]
26pub struct AuthContext {
27    pub claims: Claims,
28}
29
30impl AuthContext {
31    /// Check if user has required permission
32    pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
33        if self.claims.has_permission(permission) {
34            Ok(())
35        } else {
36            Err(AllSourceError::ValidationError(
37                "Insufficient permissions".to_string(),
38            ))
39        }
40    }
41
42    /// Get tenant ID from context
43    pub fn tenant_id(&self) -> &str {
44        &self.claims.tenant_id
45    }
46
47    /// Get user ID from context
48    pub fn user_id(&self) -> &str {
49        &self.claims.sub
50    }
51}
52
53/// Extract token from Authorization header
54fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
55    let auth_header = headers
56        .get("authorization")
57        .ok_or_else(|| AllSourceError::ValidationError("Missing authorization header".to_string()))?
58        .to_str()
59        .map_err(|_| AllSourceError::ValidationError("Invalid authorization header".to_string()))?;
60
61    // Support both "Bearer <token>" and "<token>" formats
62    let token = if auth_header.starts_with("Bearer ") {
63        auth_header.trim_start_matches("Bearer ").trim()
64    } else if auth_header.starts_with("bearer ") {
65        auth_header.trim_start_matches("bearer ").trim()
66    } else {
67        auth_header.trim()
68    };
69
70    if token.is_empty() {
71        return Err(AllSourceError::ValidationError(
72            "Empty authorization token".to_string(),
73        ));
74    }
75
76    Ok(token.to_string())
77}
78
79/// Authentication middleware
80pub async fn auth_middleware(
81    State(auth_state): State<AuthState>,
82    mut request: Request,
83    next: Next,
84) -> Result<Response, AuthError> {
85    let headers = request.headers();
86
87    // Extract and validate token (JWT or API key)
88    let token = extract_token(headers)?;
89
90    let claims = if token.starts_with("ask_") {
91        // API Key authentication
92        auth_state.auth_manager.validate_api_key(&token)?
93    } else {
94        // JWT authentication
95        auth_state.auth_manager.validate_token(&token)?
96    };
97
98    // Insert auth context into request extensions
99    request.extensions_mut().insert(AuthContext { claims });
100
101    Ok(next.run(request).await)
102}
103
104/// Optional authentication middleware (allows unauthenticated requests)
105pub async fn optional_auth_middleware(
106    State(auth_state): State<AuthState>,
107    mut request: Request,
108    next: Next,
109) -> Response {
110    let headers = request.headers();
111
112    if let Ok(token) = extract_token(headers) {
113        // Try to authenticate, but don't fail if invalid
114        let claims = if token.starts_with("ask_") {
115            auth_state.auth_manager.validate_api_key(&token).ok()
116        } else {
117            auth_state.auth_manager.validate_token(&token).ok()
118        };
119
120        if let Some(claims) = claims {
121            request.extensions_mut().insert(AuthContext { claims });
122        }
123    }
124
125    next.run(request).await
126}
127
128/// Error type for authentication failures
129#[derive(Debug)]
130pub struct AuthError(AllSourceError);
131
132impl From<AllSourceError> for AuthError {
133    fn from(err: AllSourceError) -> Self {
134        AuthError(err)
135    }
136}
137
138impl IntoResponse for AuthError {
139    fn into_response(self) -> Response {
140        let (status, message) = match self.0 {
141            AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
142            _ => (
143                StatusCode::INTERNAL_SERVER_ERROR,
144                "Internal server error".to_string(),
145            ),
146        };
147
148        (status, message).into_response()
149    }
150}
151
152/// Axum extractor for authenticated requests
153pub struct Authenticated(pub AuthContext);
154
155#[axum::async_trait]
156impl<S> axum::extract::FromRequestParts<S> for Authenticated
157where
158    S: Send + Sync,
159{
160    type Rejection = (StatusCode, &'static str);
161
162    async fn from_request_parts(
163        parts: &mut axum::http::request::Parts,
164        _state: &S,
165    ) -> Result<Self, Self::Rejection> {
166        parts
167            .extensions
168            .get::<AuthContext>()
169            .cloned()
170            .map(Authenticated)
171            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
172    }
173}
174
175/// Axum extractor for admin-only requests
176pub struct Admin(pub AuthContext);
177
178#[axum::async_trait]
179impl<S> axum::extract::FromRequestParts<S> for Admin
180where
181    S: Send + Sync,
182{
183    type Rejection = (StatusCode, &'static str);
184
185    async fn from_request_parts(
186        parts: &mut axum::http::request::Parts,
187        _state: &S,
188    ) -> Result<Self, Self::Rejection> {
189        let auth_ctx = parts
190            .extensions
191            .get::<AuthContext>()
192            .cloned()
193            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
194
195        auth_ctx
196            .require_permission(Permission::Admin)
197            .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
198
199        Ok(Admin(auth_ctx))
200    }
201}
202
203/// Rate limiting middleware
204/// Checks rate limits based on tenant_id from auth context
205pub async fn rate_limit_middleware(
206    State(rate_limit_state): State<RateLimitState>,
207    request: Request,
208    next: Next,
209) -> Result<Response, RateLimitError> {
210    // Extract auth context from request
211    let auth_ctx = request
212        .extensions()
213        .get::<AuthContext>()
214        .ok_or_else(|| RateLimitError::Unauthorized)?;
215
216    // Check rate limit for this tenant
217    let result = rate_limit_state
218        .rate_limiter
219        .check_rate_limit(auth_ctx.tenant_id());
220
221    if !result.allowed {
222        return Err(RateLimitError::RateLimitExceeded {
223            retry_after: result.retry_after.unwrap_or_default().as_secs(),
224            limit: result.limit,
225        });
226    }
227
228    // Add rate limit headers to response
229    let mut response = next.run(request).await;
230    let headers = response.headers_mut();
231    headers.insert("X-RateLimit-Limit", result.limit.to_string().parse().unwrap());
232    headers.insert("X-RateLimit-Remaining", result.remaining.to_string().parse().unwrap());
233
234    Ok(response)
235}
236
/// Rejection returned by `rate_limit_middleware`.
#[derive(Debug)]
pub enum RateLimitError {
    /// The tenant has used up its allowance.
    RateLimitExceeded {
        /// Seconds the client should wait before retrying. `0` means no hint is
        /// available, and no `Retry-After` header is sent.
        retry_after: u64,
        /// The tenant's request limit, reported to the client as requests/min.
        limit: u32,
    },
    /// No `AuthContext` was found, so the request cannot be tied to a tenant.
    Unauthorized,
}
243
244impl IntoResponse for RateLimitError {
245    fn into_response(self) -> Response {
246        match self {
247            RateLimitError::RateLimitExceeded { retry_after, limit } => {
248                let mut response = (
249                    StatusCode::TOO_MANY_REQUESTS,
250                    format!("Rate limit exceeded. Limit: {} requests/min", limit),
251                )
252                    .into_response();
253
254                if retry_after > 0 {
255                    response.headers_mut().insert(
256                        "Retry-After",
257                        retry_after.to_string().parse().unwrap(),
258                    );
259                }
260
261                response
262            }
263            RateLimitError::Unauthorized => (
264                StatusCode::UNAUTHORIZED,
265                "Authentication required for rate limiting",
266            )
267                .into_response(),
268        }
269    }
270}
271
/// Returns early from the enclosing handler unless the caller holds a permission.
///
/// The context is taken from an extractor's `AuthContext`, accessed as
/// `$auth.0` (e.g. an `Authenticated` value). If it does not grant `$perm`, the
/// handler returns `Err((StatusCode::FORBIDDEN, "Insufficient permissions"))`.
///
/// The enclosing function must return a `Result` whose error type can be built
/// from `(StatusCode, &'static str)` via `From` (the macro ends in `?`).
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {
        $auth.0.require_permission($perm).map_err(|_denied| {
            (axum::http::StatusCode::FORBIDDEN, "Insufficient permissions")
        })?
    };
}
282
// ============================================================================
// Tenant Isolation Middleware (Phase 5B)
// ============================================================================
286
287use crate::domain::entities::Tenant;
288use crate::domain::repositories::TenantRepository;
289use crate::domain::value_objects::TenantId;
290
/// Shared state for `tenant_isolation_middleware`.
///
/// `Clone` is implemented by hand rather than derived. `#[derive(Clone)]` would
/// add an `R: Clone` bound, which makes this state, and therefore
/// `State<TenantState<R>>`, unusable with repositories that are not themselves
/// `Clone`. Cloning only needs to bump the `Arc` reference count.
pub struct TenantState<R: ?Sized> {
    /// Repository used to load the authenticated tenant.
    pub tenant_repository: Arc<R>,
}

impl<R: ?Sized> Clone for TenantState<R> {
    fn clone(&self) -> Self {
        Self {
            tenant_repository: Arc::clone(&self.tenant_repository),
        }
    }
}
296
297/// Validated tenant context injected into requests
298///
299/// This context is created by the tenant_isolation_middleware after
300/// validating that the tenant exists and is active.
301#[derive(Debug, Clone)]
302pub struct TenantContext {
303    pub tenant: Tenant,
304}
305
306impl TenantContext {
307    /// Get the tenant ID
308    pub fn tenant_id(&self) -> &TenantId {
309        self.tenant.id()
310    }
311
312    /// Check if tenant is active
313    pub fn is_active(&self) -> bool {
314        self.tenant.is_active()
315    }
316}
317
318/// Tenant isolation middleware
319///
320/// Validates that the authenticated tenant exists and is active.
321/// Injects TenantContext into the request for use by handlers.
322///
323/// # Phase 5B: Tenant Isolation
324/// This middleware enforces tenant boundaries by:
325/// 1. Extracting tenant_id from AuthContext
326/// 2. Loading tenant from repository
327/// 3. Validating tenant is active
328/// 4. Injecting TenantContext into request extensions
329///
330/// Must be applied after auth_middleware.
331pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
332    State(tenant_state): State<TenantState<R>>,
333    mut request: Request,
334    next: Next,
335) -> Result<Response, TenantError> {
336    // Extract auth context (must be authenticated)
337    let auth_ctx = request
338        .extensions()
339        .get::<AuthContext>()
340        .ok_or(TenantError::Unauthorized)?
341        .clone();
342
343    // Parse tenant ID
344    let tenant_id = TenantId::new(auth_ctx.tenant_id().to_string())
345        .map_err(|_| TenantError::InvalidTenant)?;
346
347    // Load tenant from repository
348    let tenant = tenant_state
349        .tenant_repository
350        .find_by_id(&tenant_id)
351        .await
352        .map_err(|e| TenantError::RepositoryError(e.to_string()))?
353        .ok_or(TenantError::TenantNotFound)?;
354
355    // Validate tenant is active
356    if !tenant.is_active() {
357        return Err(TenantError::TenantInactive);
358    }
359
360    // Inject tenant context into request
361    request.extensions_mut().insert(TenantContext { tenant });
362
363    // Continue to next middleware/handler
364    Ok(next.run(request).await)
365}
366
/// Rejection returned by `tenant_isolation_middleware`.
#[derive(Debug)]
pub enum TenantError {
    /// No `AuthContext` was found on the request.
    Unauthorized,
    /// The tenant id from the claims could not be turned into a `TenantId`.
    InvalidTenant,
    /// No tenant exists with the requested id.
    TenantNotFound,
    /// The tenant exists but is not active.
    TenantInactive,
    /// The repository lookup failed; holds the error's string form.
    RepositoryError(String),
}
376
377impl IntoResponse for TenantError {
378    fn into_response(self) -> Response {
379        let (status, message) = match self {
380            TenantError::Unauthorized => (
381                StatusCode::UNAUTHORIZED,
382                "Authentication required for tenant access",
383            ),
384            TenantError::InvalidTenant => (
385                StatusCode::BAD_REQUEST,
386                "Invalid tenant identifier",
387            ),
388            TenantError::TenantNotFound => (
389                StatusCode::NOT_FOUND,
390                "Tenant not found",
391            ),
392            TenantError::TenantInactive => (
393                StatusCode::FORBIDDEN,
394                "Tenant is inactive",
395            ),
396            TenantError::RepositoryError(_) => (
397                StatusCode::INTERNAL_SERVER_ERROR,
398                "Failed to validate tenant",
399            ),
400        };
401
402        (status, message).into_response()
403    }
404}
405
// ============================================================================
// Request ID Middleware (Phase 5C)
// ============================================================================
409
410use uuid::Uuid;
411
/// Identifier attached to each request for tracing and log correlation.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
415
416impl RequestId {
417    /// Generate a new request ID
418    pub fn new() -> Self {
419        Self(Uuid::new_v4().to_string())
420    }
421
422    /// Get the request ID as a string
423    pub fn as_str(&self) -> &str {
424        &self.0
425    }
426}
427
428/// Request ID middleware
429///
430/// Generates a unique request ID for each request and injects it into:
431/// - Request extensions (for use in handlers/logging)
432/// - Response headers (X-Request-ID)
433///
434/// If the request already has an X-Request-ID header, it will be used instead.
435///
436/// # Phase 5C: Request Tracing
437/// This middleware enables distributed tracing by:
438/// 1. Generating unique IDs for each request
439/// 2. Propagating IDs through the request lifecycle
440/// 3. Returning IDs in response headers
441/// 4. Supporting client-provided request IDs
442pub async fn request_id_middleware(
443    mut request: Request,
444    next: Next,
445) -> Response {
446    // Check if request already has a request ID
447    let request_id = request
448        .headers()
449        .get("x-request-id")
450        .and_then(|v| v.to_str().ok())
451        .map(|s| RequestId(s.to_string()))
452        .unwrap_or_else(RequestId::new);
453
454    // Store request ID in extensions
455    request.extensions_mut().insert(request_id.clone());
456
457    // Process request
458    let mut response = next.run(request).await;
459
460    // Add request ID to response headers
461    response.headers_mut().insert(
462        "x-request-id",
463        request_id.0.parse().unwrap(),
464    );
465
466    response
467}
468
// ============================================================================
// Security Headers Middleware (Phase 5C)
// ============================================================================
472
473/// Security headers configuration
474#[derive(Debug, Clone)]
475pub struct SecurityConfig {
476    /// Enable HSTS (HTTP Strict Transport Security)
477    pub enable_hsts: bool,
478    /// HSTS max age in seconds
479    pub hsts_max_age: u32,
480    /// Enable X-Frame-Options
481    pub enable_frame_options: bool,
482    /// X-Frame-Options value
483    pub frame_options: FrameOptions,
484    /// Enable X-Content-Type-Options
485    pub enable_content_type_options: bool,
486    /// Enable X-XSS-Protection
487    pub enable_xss_protection: bool,
488    /// Content Security Policy
489    pub csp: Option<String>,
490    /// CORS allowed origins
491    pub cors_origins: Vec<String>,
492    /// CORS allowed methods
493    pub cors_methods: Vec<String>,
494    /// CORS allowed headers
495    pub cors_headers: Vec<String>,
496    /// CORS max age
497    pub cors_max_age: u32,
498}
499
/// Policy for the `X-Frame-Options` response header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    /// Forbid framing entirely (`DENY`).
    Deny,
    /// Allow framing only by same-origin pages (`SAMEORIGIN`).
    SameOrigin,
    /// Allow framing by the given origin. This is the legacy `ALLOW-FROM`
    /// directive; modern browsers ignore it in favour of CSP `frame-ancestors`.
    AllowFrom(String),
}
506
507impl Default for SecurityConfig {
508    fn default() -> Self {
509        Self {
510            enable_hsts: true,
511            hsts_max_age: 31536000, // 1 year
512            enable_frame_options: true,
513            frame_options: FrameOptions::Deny,
514            enable_content_type_options: true,
515            enable_xss_protection: true,
516            csp: Some("default-src 'self'".to_string()),
517            cors_origins: vec!["*".to_string()],
518            cors_methods: vec!["GET".to_string(), "POST".to_string(), "PUT".to_string(), "DELETE".to_string()],
519            cors_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
520            cors_max_age: 3600,
521        }
522    }
523}
524
525#[derive(Clone)]
526pub struct SecurityState {
527    pub config: SecurityConfig,
528}
529
530/// Security headers middleware
531///
532/// Adds security-related HTTP headers to all responses:
533/// - HSTS (Strict-Transport-Security)
534/// - X-Frame-Options
535/// - X-Content-Type-Options
536/// - X-XSS-Protection
537/// - Content-Security-Policy
538/// - CORS headers
539///
540/// # Phase 5C: Security Hardening
541/// This middleware provides defense-in-depth by:
542/// 1. Preventing clickjacking (X-Frame-Options)
543/// 2. Preventing MIME sniffing (X-Content-Type-Options)
544/// 3. Enforcing HTTPS (HSTS)
545/// 4. Preventing XSS (CSP, X-XSS-Protection)
546/// 5. Enabling CORS for controlled access
547pub async fn security_headers_middleware(
548    State(security_state): State<SecurityState>,
549    request: Request,
550    next: Next,
551) -> Response {
552    let mut response = next.run(request).await;
553    let headers = response.headers_mut();
554    let config = &security_state.config;
555
556    // HSTS
557    if config.enable_hsts {
558        headers.insert(
559            "strict-transport-security",
560            format!("max-age={}", config.hsts_max_age).parse().unwrap(),
561        );
562    }
563
564    // X-Frame-Options
565    if config.enable_frame_options {
566        let value = match &config.frame_options {
567            FrameOptions::Deny => "DENY",
568            FrameOptions::SameOrigin => "SAMEORIGIN",
569            FrameOptions::AllowFrom(origin) => origin,
570        };
571        headers.insert("x-frame-options", value.parse().unwrap());
572    }
573
574    // X-Content-Type-Options
575    if config.enable_content_type_options {
576        headers.insert("x-content-type-options", "nosniff".parse().unwrap());
577    }
578
579    // X-XSS-Protection
580    if config.enable_xss_protection {
581        headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
582    }
583
584    // Content-Security-Policy
585    if let Some(csp) = &config.csp {
586        headers.insert("content-security-policy", csp.parse().unwrap());
587    }
588
589    // CORS headers
590    headers.insert(
591        "access-control-allow-origin",
592        config.cors_origins.join(", ").parse().unwrap(),
593    );
594    headers.insert(
595        "access-control-allow-methods",
596        config.cors_methods.join(", ").parse().unwrap(),
597    );
598    headers.insert(
599        "access-control-allow-headers",
600        config.cors_headers.join(", ").parse().unwrap(),
601    );
602    headers.insert(
603        "access-control-max-age",
604        config.cors_max_age.to_string().parse().unwrap(),
605    );
606
607    response
608}
609
// ============================================================================
// IP Filtering Middleware (Phase 5C)
// ============================================================================
613
614use crate::infrastructure::security::IpFilter;
615use std::net::SocketAddr;
616
617#[derive(Clone)]
618pub struct IpFilterState {
619    pub ip_filter: Arc<IpFilter>,
620}
621
622/// IP filtering middleware
623///
624/// Blocks or allows requests based on IP address rules.
625/// Supports both global and per-tenant IP filtering.
626///
627/// # Phase 5C: Access Control
628/// This middleware provides IP-based access control by:
629/// 1. Extracting client IP from request
630/// 2. Checking against global and tenant-specific rules
631/// 3. Blocking requests from unauthorized IPs
632/// 4. Supporting both allowlists and blocklists
633pub async fn ip_filter_middleware(
634    State(ip_filter_state): State<IpFilterState>,
635    request: Request,
636    next: Next,
637) -> Result<Response, IpFilterError> {
638    // Extract client IP address
639    let client_ip = request
640        .extensions()
641        .get::<axum::extract::ConnectInfo<SocketAddr>>()
642        .map(|connect_info| connect_info.0.ip())
643        .ok_or(IpFilterError::NoIpAddress)?;
644
645    // Check if this is a tenant-scoped request
646    let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
647        // Tenant-specific filtering
648        ip_filter_state
649            .ip_filter
650            .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
651    } else {
652        // Global filtering only
653        ip_filter_state.ip_filter.is_allowed(&client_ip)
654    };
655
656    // Block if not allowed
657    if !result.allowed {
658        return Err(IpFilterError::Blocked { reason: result.reason });
659    }
660
661    // Allow request to proceed
662    Ok(next.run(request).await)
663}
664
/// Rejection returned by `ip_filter_middleware`.
#[derive(Debug)]
pub enum IpFilterError {
    /// No `ConnectInfo<SocketAddr>` extension was present, so the peer IP is unknown.
    NoIpAddress,
    /// The filter rejected the IP; `reason` is taken from the filter result.
    Blocked { reason: String },
}
671
672impl IntoResponse for IpFilterError {
673    fn into_response(self) -> Response {
674        match self {
675            IpFilterError::NoIpAddress => (
676                StatusCode::BAD_REQUEST,
677                "Unable to determine client IP address",
678            ).into_response(),
679            IpFilterError::Blocked { reason } => (
680                StatusCode::FORBIDDEN,
681                format!("Access denied: {}", reason),
682            ).into_response(),
683        }
684    }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::Role;

    #[test]
    fn test_extract_bearer_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer test_token_123".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_lowercase_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "bearer test_token_123".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_plain_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "test_token_123".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_token_trims_surrounding_whitespace() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer   tok  ".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "tok");
    }

    #[test]
    fn test_missing_auth_header() {
        let headers = HeaderMap::new();
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_empty_bearer_token_rejected() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_non_ascii_auth_header_rejected() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            axum::http::HeaderValue::from_bytes(b"Bearer \xfftoken").unwrap(),
        );
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_auth_context_permissions() {
        let claims = Claims::new(
            "user1".to_string(),
            "tenant1".to_string(),
            Role::Developer,
            chrono::Duration::hours(1),
        );

        let ctx = AuthContext { claims };

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_ok());
        assert!(ctx.require_permission(Permission::Admin).is_err());
    }
}