// allsource_core/infrastructure/security/middleware.rs

1use crate::error::AllSourceError;
2use crate::infrastructure::security::auth::{AuthManager, Claims, Permission};
3use crate::infrastructure::security::rate_limit::RateLimiter;
4use axum::{
5    extract::{Request, State},
6    http::{HeaderMap, StatusCode},
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11
12/// Authentication state shared across requests
13#[derive(Clone)]
14pub struct AuthState {
15    pub auth_manager: Arc<AuthManager>,
16}
17
18/// Rate limiting state
19#[derive(Clone)]
20pub struct RateLimitState {
21    pub rate_limiter: Arc<RateLimiter>,
22}
23
24/// Authenticated request context
25#[derive(Debug, Clone)]
26pub struct AuthContext {
27    pub claims: Claims,
28}
29
30impl AuthContext {
31    /// Check if user has required permission
32    pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
33        if self.claims.has_permission(permission) {
34            Ok(())
35        } else {
36            Err(AllSourceError::ValidationError(
37                "Insufficient permissions".to_string(),
38            ))
39        }
40    }
41
42    /// Get tenant ID from context
43    pub fn tenant_id(&self) -> &str {
44        &self.claims.tenant_id
45    }
46
47    /// Get user ID from context
48    pub fn user_id(&self) -> &str {
49        &self.claims.sub
50    }
51}
52
53/// Extract token from Authorization header
54fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
55    let auth_header = headers
56        .get("authorization")
57        .ok_or_else(|| AllSourceError::ValidationError("Missing authorization header".to_string()))?
58        .to_str()
59        .map_err(|_| AllSourceError::ValidationError("Invalid authorization header".to_string()))?;
60
61    // Support both "Bearer <token>" and "<token>" formats
62    let token = if auth_header.starts_with("Bearer ") {
63        auth_header.trim_start_matches("Bearer ").trim()
64    } else if auth_header.starts_with("bearer ") {
65        auth_header.trim_start_matches("bearer ").trim()
66    } else {
67        auth_header.trim()
68    };
69
70    if token.is_empty() {
71        return Err(AllSourceError::ValidationError(
72            "Empty authorization token".to_string(),
73        ));
74    }
75
76    Ok(token.to_string())
77}
78
79/// Authentication middleware
80pub async fn auth_middleware(
81    State(auth_state): State<AuthState>,
82    mut request: Request,
83    next: Next,
84) -> Result<Response, AuthError> {
85    let headers = request.headers();
86
87    // Extract and validate token (JWT or API key)
88    let token = extract_token(headers)?;
89
90    let claims = if token.starts_with("ask_") {
91        // API Key authentication
92        auth_state.auth_manager.validate_api_key(&token)?
93    } else {
94        // JWT authentication
95        auth_state.auth_manager.validate_token(&token)?
96    };
97
98    // Insert auth context into request extensions
99    request.extensions_mut().insert(AuthContext { claims });
100
101    Ok(next.run(request).await)
102}
103
104/// Optional authentication middleware (allows unauthenticated requests)
105pub async fn optional_auth_middleware(
106    State(auth_state): State<AuthState>,
107    mut request: Request,
108    next: Next,
109) -> Response {
110    let headers = request.headers();
111
112    if let Ok(token) = extract_token(headers) {
113        // Try to authenticate, but don't fail if invalid
114        let claims = if token.starts_with("ask_") {
115            auth_state.auth_manager.validate_api_key(&token).ok()
116        } else {
117            auth_state.auth_manager.validate_token(&token).ok()
118        };
119
120        if let Some(claims) = claims {
121            request.extensions_mut().insert(AuthContext { claims });
122        }
123    }
124
125    next.run(request).await
126}
127
128/// Error type for authentication failures
129#[derive(Debug)]
130pub struct AuthError(AllSourceError);
131
132impl From<AllSourceError> for AuthError {
133    fn from(err: AllSourceError) -> Self {
134        AuthError(err)
135    }
136}
137
138impl IntoResponse for AuthError {
139    fn into_response(self) -> Response {
140        let (status, message) = match self.0 {
141            AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
142            _ => (
143                StatusCode::INTERNAL_SERVER_ERROR,
144                "Internal server error".to_string(),
145            ),
146        };
147
148        (status, message).into_response()
149    }
150}
151
152/// Axum extractor for authenticated requests
153pub struct Authenticated(pub AuthContext);
154
155impl<S> axum::extract::FromRequestParts<S> for Authenticated
156where
157    S: Send + Sync,
158{
159    type Rejection = (StatusCode, &'static str);
160
161    async fn from_request_parts(
162        parts: &mut axum::http::request::Parts,
163        _state: &S,
164    ) -> Result<Self, Self::Rejection> {
165        parts
166            .extensions
167            .get::<AuthContext>()
168            .cloned()
169            .map(Authenticated)
170            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
171    }
172}
173
174/// Axum extractor for optional authentication (never rejects, returns Option)
175/// Use this for routes that work with or without authentication
176pub struct OptionalAuth(pub Option<AuthContext>);
177
178impl<S> axum::extract::FromRequestParts<S> for OptionalAuth
179where
180    S: Send + Sync,
181{
182    type Rejection = std::convert::Infallible;
183
184    async fn from_request_parts(
185        parts: &mut axum::http::request::Parts,
186        _state: &S,
187    ) -> Result<Self, Self::Rejection> {
188        Ok(OptionalAuth(parts.extensions.get::<AuthContext>().cloned()))
189    }
190}
191
192/// Axum extractor for admin-only requests
193pub struct Admin(pub AuthContext);
194
195impl<S> axum::extract::FromRequestParts<S> for Admin
196where
197    S: Send + Sync,
198{
199    type Rejection = (StatusCode, &'static str);
200
201    async fn from_request_parts(
202        parts: &mut axum::http::request::Parts,
203        _state: &S,
204    ) -> Result<Self, Self::Rejection> {
205        let auth_ctx = parts
206            .extensions
207            .get::<AuthContext>()
208            .cloned()
209            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
210
211        auth_ctx
212            .require_permission(Permission::Admin)
213            .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
214
215        Ok(Admin(auth_ctx))
216    }
217}
218
219/// Rate limiting middleware
220/// Checks rate limits based on tenant_id from auth context
221pub async fn rate_limit_middleware(
222    State(rate_limit_state): State<RateLimitState>,
223    request: Request,
224    next: Next,
225) -> Result<Response, RateLimitError> {
226    // Extract auth context from request
227    let auth_ctx = request
228        .extensions()
229        .get::<AuthContext>()
230        .ok_or_else(|| RateLimitError::Unauthorized)?;
231
232    // Check rate limit for this tenant
233    let result = rate_limit_state
234        .rate_limiter
235        .check_rate_limit(auth_ctx.tenant_id());
236
237    if !result.allowed {
238        return Err(RateLimitError::RateLimitExceeded {
239            retry_after: result.retry_after.unwrap_or_default().as_secs(),
240            limit: result.limit,
241        });
242    }
243
244    // Add rate limit headers to response
245    let mut response = next.run(request).await;
246    let headers = response.headers_mut();
247    headers.insert(
248        "X-RateLimit-Limit",
249        result.limit.to_string().parse().unwrap(),
250    );
251    headers.insert(
252        "X-RateLimit-Remaining",
253        result.remaining.to_string().parse().unwrap(),
254    );
255
256    Ok(response)
257}
258
259/// Error type for rate limiting failures
260#[derive(Debug)]
261pub enum RateLimitError {
262    RateLimitExceeded { retry_after: u64, limit: u32 },
263    Unauthorized,
264}
265
266impl IntoResponse for RateLimitError {
267    fn into_response(self) -> Response {
268        match self {
269            RateLimitError::RateLimitExceeded { retry_after, limit } => {
270                let mut response = (
271                    StatusCode::TOO_MANY_REQUESTS,
272                    format!("Rate limit exceeded. Limit: {} requests/min", limit),
273                )
274                    .into_response();
275
276                if retry_after > 0 {
277                    response
278                        .headers_mut()
279                        .insert("Retry-After", retry_after.to_string().parse().unwrap());
280                }
281
282                response
283            }
284            RateLimitError::Unauthorized => (
285                StatusCode::UNAUTHORIZED,
286                "Authentication required for rate limiting",
287            )
288                .into_response(),
289        }
290    }
291}
292
/// Require a permission inside a handler, returning early when it is missing.
///
/// `$auth` is a tuple-struct extractor wrapping an `AuthContext` (for example
/// `Authenticated` or `Admin`), and `$perm` is the required `Permission`. On
/// failure the enclosing function returns, via `?`, the error
/// `(StatusCode::FORBIDDEN, "Insufficient permissions")`.
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {
        $auth
            .0
            .require_permission($perm)
            .or(::core::result::Result::Err((
                axum::http::StatusCode::FORBIDDEN,
                "Insufficient permissions",
            )))?
    };
}
305
306// ============================================================================
307// Tenant Isolation Middleware (Phase 5B)
308// ============================================================================
309
310use crate::domain::entities::Tenant;
311use crate::domain::repositories::TenantRepository;
312use crate::domain::value_objects::TenantId;
313
314/// Tenant isolation state for middleware
315#[derive(Clone)]
316pub struct TenantState<R: TenantRepository> {
317    pub tenant_repository: Arc<R>,
318}
319
320/// Validated tenant context injected into requests
321///
322/// This context is created by the tenant_isolation_middleware after
323/// validating that the tenant exists and is active.
324#[derive(Debug, Clone)]
325pub struct TenantContext {
326    pub tenant: Tenant,
327}
328
329impl TenantContext {
330    /// Get the tenant ID
331    pub fn tenant_id(&self) -> &TenantId {
332        self.tenant.id()
333    }
334
335    /// Check if tenant is active
336    pub fn is_active(&self) -> bool {
337        self.tenant.is_active()
338    }
339}
340
341/// Tenant isolation middleware
342///
343/// Validates that the authenticated tenant exists and is active.
344/// Injects TenantContext into the request for use by handlers.
345///
346/// # Phase 5B: Tenant Isolation
347/// This middleware enforces tenant boundaries by:
348/// 1. Extracting tenant_id from AuthContext
349/// 2. Loading tenant from repository
350/// 3. Validating tenant is active
351/// 4. Injecting TenantContext into request extensions
352///
353/// Must be applied after auth_middleware.
354pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
355    State(tenant_state): State<TenantState<R>>,
356    mut request: Request,
357    next: Next,
358) -> Result<Response, TenantError> {
359    // Extract auth context (must be authenticated)
360    let auth_ctx = request
361        .extensions()
362        .get::<AuthContext>()
363        .ok_or(TenantError::Unauthorized)?
364        .clone();
365
366    // Parse tenant ID
367    let tenant_id =
368        TenantId::new(auth_ctx.tenant_id().to_string()).map_err(|_| TenantError::InvalidTenant)?;
369
370    // Load tenant from repository
371    let tenant = tenant_state
372        .tenant_repository
373        .find_by_id(&tenant_id)
374        .await
375        .map_err(|e| TenantError::RepositoryError(e.to_string()))?
376        .ok_or(TenantError::TenantNotFound)?;
377
378    // Validate tenant is active
379    if !tenant.is_active() {
380        return Err(TenantError::TenantInactive);
381    }
382
383    // Inject tenant context into request
384    request.extensions_mut().insert(TenantContext { tenant });
385
386    // Continue to next middleware/handler
387    Ok(next.run(request).await)
388}
389
390/// Error type for tenant isolation failures
391#[derive(Debug)]
392pub enum TenantError {
393    Unauthorized,
394    InvalidTenant,
395    TenantNotFound,
396    TenantInactive,
397    RepositoryError(String),
398}
399
400impl IntoResponse for TenantError {
401    fn into_response(self) -> Response {
402        let (status, message) = match self {
403            TenantError::Unauthorized => (
404                StatusCode::UNAUTHORIZED,
405                "Authentication required for tenant access",
406            ),
407            TenantError::InvalidTenant => (StatusCode::BAD_REQUEST, "Invalid tenant identifier"),
408            TenantError::TenantNotFound => (StatusCode::NOT_FOUND, "Tenant not found"),
409            TenantError::TenantInactive => (StatusCode::FORBIDDEN, "Tenant is inactive"),
410            TenantError::RepositoryError(_) => (
411                StatusCode::INTERNAL_SERVER_ERROR,
412                "Failed to validate tenant",
413            ),
414        };
415
416        (status, message).into_response()
417    }
418}
419
420// ============================================================================
421// Request ID Middleware (Phase 5C)
422// ============================================================================
423
424use uuid::Uuid;
425
426/// Request context with unique ID for tracing
427#[derive(Debug, Clone)]
428pub struct RequestId(pub String);
429
430impl RequestId {
431    /// Generate a new request ID
432    pub fn new() -> Self {
433        Self(Uuid::new_v4().to_string())
434    }
435
436    /// Get the request ID as a string
437    pub fn as_str(&self) -> &str {
438        &self.0
439    }
440}
441
442/// Request ID middleware
443///
444/// Generates a unique request ID for each request and injects it into:
445/// - Request extensions (for use in handlers/logging)
446/// - Response headers (X-Request-ID)
447///
448/// If the request already has an X-Request-ID header, it will be used instead.
449///
450/// # Phase 5C: Request Tracing
451/// This middleware enables distributed tracing by:
452/// 1. Generating unique IDs for each request
453/// 2. Propagating IDs through the request lifecycle
454/// 3. Returning IDs in response headers
455/// 4. Supporting client-provided request IDs
456pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
457    // Check if request already has a request ID
458    let request_id = request
459        .headers()
460        .get("x-request-id")
461        .and_then(|v| v.to_str().ok())
462        .map(|s| RequestId(s.to_string()))
463        .unwrap_or_else(RequestId::new);
464
465    // Store request ID in extensions
466    request.extensions_mut().insert(request_id.clone());
467
468    // Process request
469    let mut response = next.run(request).await;
470
471    // Add request ID to response headers
472    response
473        .headers_mut()
474        .insert("x-request-id", request_id.0.parse().unwrap());
475
476    response
477}
478
479// ============================================================================
480// Security Headers Middleware (Phase 5C)
481// ============================================================================
482
/// Configuration for [`security_headers_middleware`].
///
/// Each `enable_*` flag toggles one response header; `csp: None` omits the
/// `Content-Security-Policy` header. The `cors_*` fields feed the CORS
/// response headers.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Emit `Strict-Transport-Security`.
    pub enable_hsts: bool,
    /// `max-age` of the HSTS header, in seconds.
    pub hsts_max_age: u32,
    /// Emit `X-Frame-Options`.
    pub enable_frame_options: bool,
    /// Value used for `X-Frame-Options` when enabled.
    pub frame_options: FrameOptions,
    /// Emit `X-Content-Type-Options: nosniff`.
    pub enable_content_type_options: bool,
    /// Emit `X-XSS-Protection: 1; mode=block`.
    pub enable_xss_protection: bool,
    /// `Content-Security-Policy` value; `None` disables the header.
    pub csp: Option<String>,
    /// Origins allowed for CORS.
    pub cors_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    pub cors_methods: Vec<String>,
    /// Headers advertised in `Access-Control-Allow-Headers`.
    pub cors_headers: Vec<String>,
    /// `Access-Control-Max-Age`, in seconds.
    pub cors_max_age: u32,
}

/// Value for the `X-Frame-Options` header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    /// `DENY`.
    Deny,
    /// `SAMEORIGIN`.
    SameOrigin,
    /// Allow framing by the given origin.
    AllowFrom(String),
}

impl Default for SecurityConfig {
    /// The defaults are:
    /// - one-year HSTS;
    /// - `X-Frame-Options: DENY`, `nosniff`, and XSS protection enabled;
    /// - CSP `default-src 'self'`;
    /// - CORS allowing `*` with GET/POST/PUT/DELETE, `Content-Type` and
    ///   `Authorization` headers, and a one-hour max age.
    fn default() -> Self {
        const ONE_YEAR_SECS: u32 = 365 * 24 * 60 * 60;
        const ONE_HOUR_SECS: u32 = 60 * 60;

        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|&s| String::from(s)).collect()
        }

        Self {
            enable_hsts: true,
            hsts_max_age: ONE_YEAR_SECS,
            enable_frame_options: true,
            frame_options: FrameOptions::Deny,
            enable_content_type_options: true,
            enable_xss_protection: true,
            csp: Some(String::from("default-src 'self'")),
            cors_origins: owned(&["*"]),
            cors_methods: owned(&["GET", "POST", "PUT", "DELETE"]),
            cors_headers: owned(&["Content-Type", "Authorization"]),
            cors_max_age: ONE_HOUR_SECS,
        }
    }
}

/// State for [`security_headers_middleware`].
#[derive(Clone)]
pub struct SecurityState {
    /// Header policy applied to every response.
    pub config: SecurityConfig,
}
544
545/// Security headers middleware
546///
547/// Adds security-related HTTP headers to all responses:
548/// - HSTS (Strict-Transport-Security)
549/// - X-Frame-Options
550/// - X-Content-Type-Options
551/// - X-XSS-Protection
552/// - Content-Security-Policy
553/// - CORS headers
554///
555/// # Phase 5C: Security Hardening
556/// This middleware provides defense-in-depth by:
557/// 1. Preventing clickjacking (X-Frame-Options)
558/// 2. Preventing MIME sniffing (X-Content-Type-Options)
559/// 3. Enforcing HTTPS (HSTS)
560/// 4. Preventing XSS (CSP, X-XSS-Protection)
561/// 5. Enabling CORS for controlled access
562pub async fn security_headers_middleware(
563    State(security_state): State<SecurityState>,
564    request: Request,
565    next: Next,
566) -> Response {
567    let mut response = next.run(request).await;
568    let headers = response.headers_mut();
569    let config = &security_state.config;
570
571    // HSTS
572    if config.enable_hsts {
573        headers.insert(
574            "strict-transport-security",
575            format!("max-age={}", config.hsts_max_age).parse().unwrap(),
576        );
577    }
578
579    // X-Frame-Options
580    if config.enable_frame_options {
581        let value = match &config.frame_options {
582            FrameOptions::Deny => "DENY",
583            FrameOptions::SameOrigin => "SAMEORIGIN",
584            FrameOptions::AllowFrom(origin) => origin,
585        };
586        headers.insert("x-frame-options", value.parse().unwrap());
587    }
588
589    // X-Content-Type-Options
590    if config.enable_content_type_options {
591        headers.insert("x-content-type-options", "nosniff".parse().unwrap());
592    }
593
594    // X-XSS-Protection
595    if config.enable_xss_protection {
596        headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
597    }
598
599    // Content-Security-Policy
600    if let Some(csp) = &config.csp {
601        headers.insert("content-security-policy", csp.parse().unwrap());
602    }
603
604    // CORS headers
605    headers.insert(
606        "access-control-allow-origin",
607        config.cors_origins.join(", ").parse().unwrap(),
608    );
609    headers.insert(
610        "access-control-allow-methods",
611        config.cors_methods.join(", ").parse().unwrap(),
612    );
613    headers.insert(
614        "access-control-allow-headers",
615        config.cors_headers.join(", ").parse().unwrap(),
616    );
617    headers.insert(
618        "access-control-max-age",
619        config.cors_max_age.to_string().parse().unwrap(),
620    );
621
622    response
623}
624
625// ============================================================================
626// IP Filtering Middleware (Phase 5C)
627// ============================================================================
628
629use crate::infrastructure::security::IpFilter;
630use std::net::SocketAddr;
631
632#[derive(Clone)]
633pub struct IpFilterState {
634    pub ip_filter: Arc<IpFilter>,
635}
636
637/// IP filtering middleware
638///
639/// Blocks or allows requests based on IP address rules.
640/// Supports both global and per-tenant IP filtering.
641///
642/// # Phase 5C: Access Control
643/// This middleware provides IP-based access control by:
644/// 1. Extracting client IP from request
645/// 2. Checking against global and tenant-specific rules
646/// 3. Blocking requests from unauthorized IPs
647/// 4. Supporting both allowlists and blocklists
648pub async fn ip_filter_middleware(
649    State(ip_filter_state): State<IpFilterState>,
650    request: Request,
651    next: Next,
652) -> Result<Response, IpFilterError> {
653    // Extract client IP address
654    let client_ip = request
655        .extensions()
656        .get::<axum::extract::ConnectInfo<SocketAddr>>()
657        .map(|connect_info| connect_info.0.ip())
658        .ok_or(IpFilterError::NoIpAddress)?;
659
660    // Check if this is a tenant-scoped request
661    let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
662        // Tenant-specific filtering
663        ip_filter_state
664            .ip_filter
665            .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
666    } else {
667        // Global filtering only
668        ip_filter_state.ip_filter.is_allowed(&client_ip)
669    };
670
671    // Block if not allowed
672    if !result.allowed {
673        return Err(IpFilterError::Blocked {
674            reason: result.reason,
675        });
676    }
677
678    // Allow request to proceed
679    Ok(next.run(request).await)
680}
681
682/// Error type for IP filtering failures
683#[derive(Debug)]
684pub enum IpFilterError {
685    NoIpAddress,
686    Blocked { reason: String },
687}
688
689impl IntoResponse for IpFilterError {
690    fn into_response(self) -> Response {
691        match self {
692            IpFilterError::NoIpAddress => (
693                StatusCode::BAD_REQUEST,
694                "Unable to determine client IP address",
695            )
696                .into_response(),
697            IpFilterError::Blocked { reason } => {
698                (StatusCode::FORBIDDEN, format!("Access denied: {}", reason)).into_response()
699            }
700        }
701    }
702}
703
#[cfg(test)]
mod tests {
    use super::*;
    // `User` was imported here previously but never used (unused-import warning).
    use crate::infrastructure::security::auth::Role;

    /// Header map containing a single `authorization` header set to `value`.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", value.parse().unwrap());
        headers
    }

    #[test]
    fn test_extract_bearer_token() {
        let token = extract_token(&auth_headers("Bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_lowercase_bearer() {
        let token = extract_token(&auth_headers("bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_plain_token() {
        let token = extract_token(&auth_headers("test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_missing_auth_header() {
        assert!(extract_token(&HeaderMap::new()).is_err());
    }

    #[test]
    fn test_empty_bearer_token_is_rejected() {
        // Only whitespace after the scheme must not yield an empty credential.
        assert!(extract_token(&auth_headers("Bearer   ")).is_err());
    }

    #[test]
    fn test_non_ascii_auth_header_is_rejected() {
        // obs-text bytes are legal in a HeaderValue but fail `to_str`.
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            axum::http::HeaderValue::from_bytes(b"Bearer \xfftoken").unwrap(),
        );
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_auth_context_permissions() {
        let claims = Claims::new(
            "user1".to_string(),
            "tenant1".to_string(),
            Role::Developer,
            chrono::Duration::hours(1),
        );
        let ctx = AuthContext { claims };

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_ok());
        assert!(ctx.require_permission(Permission::Admin).is_err());
    }
}