Skip to main content

allsource_core/infrastructure/security/
middleware.rs

1use crate::{
2    error::AllSourceError,
3    infrastructure::security::{
4        auth::{AuthManager, Claims, Permission, Role},
5        rate_limit::RateLimiter,
6    },
7};
8use axum::{
9    extract::{Request, State},
10    http::{HeaderMap, StatusCode},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use std::sync::{Arc, LazyLock};
15
/// Exact request paths that are served without authentication.
///
/// Matching is exact and case-sensitive, so e.g. `/health/` (trailing slash)
/// is *not* skipped.
pub const AUTH_SKIP_PATHS: &[&str] = &[
    "/health",
    "/metrics",
    "/api/v1/auth/register",
    "/api/v1/auth/login",
    "/api/v1/demo/seed",
];

/// Path prefixes whose requests bypass both authentication and rate limiting.
///
/// `/internal/` routes are called by the sentinel process during automated
/// failover (promote, repoint). Requiring API keys or rate-limiting them could
/// make failover time out or fail when credentials are unavailable.
pub const AUTH_SKIP_PREFIXES: &[&str] = &["/internal/"];

/// Returns `true` when `path` should bypass authentication and rate limiting.
///
/// A path is skipped when it is exactly one of [`AUTH_SKIP_PATHS`] or begins
/// with one of [`AUTH_SKIP_PREFIXES`].
#[inline]
pub fn should_skip_auth(path: &str) -> bool {
    let is_public = AUTH_SKIP_PATHS.iter().any(|&public| public == path);
    if is_public {
        return true;
    }
    AUTH_SKIP_PREFIXES
        .iter()
        .copied()
        .any(|prefix| path.starts_with(prefix))
}
37
38/// Check if development mode is enabled via environment variable.
39/// When enabled, authentication and rate limiting are bypassed for local development.
40///
41/// Set `ALLSOURCE_DEV_MODE=true` or `ALLSOURCE_DEV_MODE=1` to enable.
42///
43/// **WARNING**: Never enable this in production environments!
44static DEV_MODE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
45    let enabled = std::env::var("ALLSOURCE_DEV_MODE")
46        .map(|v| v == "true" || v == "1")
47        .unwrap_or(false);
48    if enabled {
49        tracing::warn!(
50            "⚠️  ALLSOURCE_DEV_MODE is enabled - authentication and rate limiting are DISABLED"
51        );
52        tracing::warn!("⚠️  This should NEVER be used in production!");
53    }
54    enabled
55});
56
57/// Check if dev mode is enabled
58#[inline]
59pub fn is_dev_mode() -> bool {
60    *DEV_MODE_ENABLED
61}
62
63/// Create a development-mode AuthContext with admin privileges
64fn dev_mode_auth_context() -> AuthContext {
65    AuthContext {
66        claims: Claims::new(
67            "dev-user".to_string(),
68            "dev-tenant".to_string(),
69            Role::Admin,
70            chrono::Duration::hours(24),
71        ),
72    }
73}
74
75/// Authentication state shared across requests
76#[derive(Clone)]
77pub struct AuthState {
78    pub auth_manager: Arc<AuthManager>,
79}
80
81/// Rate limiting state
82#[derive(Clone)]
83pub struct RateLimitState {
84    pub rate_limiter: Arc<RateLimiter>,
85}
86
87/// Authenticated request context
88#[derive(Debug, Clone)]
89pub struct AuthContext {
90    pub claims: Claims,
91}
92
93impl AuthContext {
94    /// Check if user has required permission
95    pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
96        if self.claims.has_permission(permission) {
97            Ok(())
98        } else {
99            Err(AllSourceError::ValidationError(
100                "Insufficient permissions".to_string(),
101            ))
102        }
103    }
104
105    /// Get tenant ID from context
106    pub fn tenant_id(&self) -> &str {
107        &self.claims.tenant_id
108    }
109
110    /// Get user ID from context
111    pub fn user_id(&self) -> &str {
112        &self.claims.sub
113    }
114}
115
116/// Extract token from Authorization header, with X-API-Key fallback for backwards compatibility.
117fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
118    // Primary: Authorization header (Bearer <token> or plain <token>)
119    // Fallback: X-API-Key header (legacy, deprecated)
120    let auth_header = if let Some(val) = headers.get("authorization") {
121        val.to_str()
122            .map_err(|_| {
123                AllSourceError::ValidationError("Invalid authorization header".to_string())
124            })?
125            .to_string()
126    } else if let Some(val) = headers.get("x-api-key") {
127        val.to_str()
128            .map_err(|_| AllSourceError::ValidationError("Invalid X-API-Key header".to_string()))?
129            .to_string()
130    } else {
131        return Err(AllSourceError::ValidationError(
132            "Missing authorization header".to_string(),
133        ));
134    };
135
136    // Support both "Bearer <token>" and "<token>" formats
137    let token = if auth_header.starts_with("Bearer ") {
138        auth_header.trim_start_matches("Bearer ").trim()
139    } else if auth_header.starts_with("bearer ") {
140        auth_header.trim_start_matches("bearer ").trim()
141    } else {
142        auth_header.trim()
143    };
144
145    if token.is_empty() {
146        return Err(AllSourceError::ValidationError(
147            "Empty authorization token".to_string(),
148        ));
149    }
150
151    Ok(token.to_string())
152}
153
154/// Authentication middleware
155pub async fn auth_middleware(
156    State(auth_state): State<AuthState>,
157    mut request: Request,
158    next: Next,
159) -> Result<Response, AuthError> {
160    // Skip authentication for public and internal paths
161    let path = request.uri().path();
162    if should_skip_auth(path) {
163        return Ok(next.run(request).await);
164    }
165
166    // Dev mode: if a valid token is present, authenticate normally so that
167    // /me returns the real tenant_id (not a hardcoded "dev-tenant").
168    // Fall back to the synthetic dev context only when no token is provided.
169    if is_dev_mode() {
170        let headers = request.headers();
171        let auth_ctx = match extract_token(headers) {
172            Ok(token) => {
173                let claims = if token.starts_with("ask_") {
174                    auth_state.auth_manager.validate_api_key(&token).ok()
175                } else {
176                    auth_state.auth_manager.validate_token(&token).ok()
177                };
178                claims.map_or_else(dev_mode_auth_context, |c| AuthContext { claims: c })
179            }
180            Err(_) => dev_mode_auth_context(),
181        };
182        request.extensions_mut().insert(auth_ctx);
183        return Ok(next.run(request).await);
184    }
185
186    let headers = request.headers();
187
188    // Extract and validate token (JWT or API key)
189    let token = extract_token(headers)?;
190
191    let claims = if token.starts_with("ask_") {
192        // API Key authentication
193        auth_state.auth_manager.validate_api_key(&token)?
194    } else {
195        // JWT authentication
196        auth_state.auth_manager.validate_token(&token)?
197    };
198
199    // Insert auth context into request extensions
200    request.extensions_mut().insert(AuthContext { claims });
201
202    Ok(next.run(request).await)
203}
204
205/// Optional authentication middleware (allows unauthenticated requests)
206pub async fn optional_auth_middleware(
207    State(auth_state): State<AuthState>,
208    mut request: Request,
209    next: Next,
210) -> Response {
211    let headers = request.headers();
212
213    if let Ok(token) = extract_token(headers) {
214        // Try to authenticate, but don't fail if invalid
215        let claims = if token.starts_with("ask_") {
216            auth_state.auth_manager.validate_api_key(&token).ok()
217        } else {
218            auth_state.auth_manager.validate_token(&token).ok()
219        };
220
221        if let Some(claims) = claims {
222            request.extensions_mut().insert(AuthContext { claims });
223        }
224    }
225
226    next.run(request).await
227}
228
229/// Error type for authentication failures
230#[derive(Debug)]
231pub struct AuthError(AllSourceError);
232
233impl From<AllSourceError> for AuthError {
234    fn from(err: AllSourceError) -> Self {
235        AuthError(err)
236    }
237}
238
239impl IntoResponse for AuthError {
240    fn into_response(self) -> Response {
241        let (status, message) = match self.0 {
242            AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
243            _ => (
244                StatusCode::INTERNAL_SERVER_ERROR,
245                "Internal server error".to_string(),
246            ),
247        };
248
249        (status, message).into_response()
250    }
251}
252
253/// Axum extractor for authenticated requests
254pub struct Authenticated(pub AuthContext);
255
256impl<S> axum::extract::FromRequestParts<S> for Authenticated
257where
258    S: Send + Sync,
259{
260    type Rejection = (StatusCode, &'static str);
261
262    async fn from_request_parts(
263        parts: &mut axum::http::request::Parts,
264        _state: &S,
265    ) -> Result<Self, Self::Rejection> {
266        parts
267            .extensions
268            .get::<AuthContext>()
269            .cloned()
270            .map(Authenticated)
271            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
272    }
273}
274
275/// Axum extractor for optional authentication (never rejects, returns Option)
276/// Use this for routes that work with or without authentication
277pub struct OptionalAuth(pub Option<AuthContext>);
278
279impl<S> axum::extract::FromRequestParts<S> for OptionalAuth
280where
281    S: Send + Sync,
282{
283    type Rejection = std::convert::Infallible;
284
285    async fn from_request_parts(
286        parts: &mut axum::http::request::Parts,
287        _state: &S,
288    ) -> Result<Self, Self::Rejection> {
289        Ok(OptionalAuth(parts.extensions.get::<AuthContext>().cloned()))
290    }
291}
292
293/// Axum extractor for admin-only requests
294pub struct Admin(pub AuthContext);
295
296impl<S> axum::extract::FromRequestParts<S> for Admin
297where
298    S: Send + Sync,
299{
300    type Rejection = (StatusCode, &'static str);
301
302    async fn from_request_parts(
303        parts: &mut axum::http::request::Parts,
304        _state: &S,
305    ) -> Result<Self, Self::Rejection> {
306        let auth_ctx = parts
307            .extensions
308            .get::<AuthContext>()
309            .cloned()
310            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
311
312        auth_ctx
313            .require_permission(Permission::Admin)
314            .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
315
316        Ok(Admin(auth_ctx))
317    }
318}
319
320/// Rate limiting middleware
321/// Checks rate limits based on tenant_id from auth context
322pub async fn rate_limit_middleware(
323    State(rate_limit_state): State<RateLimitState>,
324    request: Request,
325    next: Next,
326) -> Result<Response, RateLimitError> {
327    // Skip rate limiting for public and internal paths
328    let path = request.uri().path();
329    if should_skip_auth(path) {
330        return Ok(next.run(request).await);
331    }
332
333    // Dev mode: bypass rate limiting entirely
334    if is_dev_mode() {
335        return Ok(next.run(request).await);
336    }
337
338    // Extract auth context from request
339    let auth_ctx = request
340        .extensions()
341        .get::<AuthContext>()
342        .ok_or(RateLimitError::Unauthorized)?;
343
344    // Check rate limit for this tenant
345    let result = rate_limit_state
346        .rate_limiter
347        .check_rate_limit(auth_ctx.tenant_id());
348
349    if !result.allowed {
350        return Err(RateLimitError::RateLimitExceeded {
351            retry_after: result.retry_after.unwrap_or_default().as_secs(),
352            limit: result.limit,
353        });
354    }
355
356    // Add rate limit headers to response
357    let mut response = next.run(request).await;
358    let headers = response.headers_mut();
359    headers.insert(
360        "X-RateLimit-Limit",
361        result.limit.to_string().parse().unwrap(),
362    );
363    headers.insert(
364        "X-RateLimit-Remaining",
365        result.remaining.to_string().parse().unwrap(),
366    );
367
368    Ok(response)
369}
370
371/// Error type for rate limiting failures
372#[derive(Debug)]
373pub enum RateLimitError {
374    RateLimitExceeded { retry_after: u64, limit: u32 },
375    Unauthorized,
376}
377
378impl IntoResponse for RateLimitError {
379    fn into_response(self) -> Response {
380        match self {
381            RateLimitError::RateLimitExceeded { retry_after, limit } => {
382                let mut response = (
383                    StatusCode::TOO_MANY_REQUESTS,
384                    format!("Rate limit exceeded. Limit: {limit} requests/min"),
385                )
386                    .into_response();
387
388                if retry_after > 0 {
389                    response
390                        .headers_mut()
391                        .insert("Retry-After", retry_after.to_string().parse().unwrap());
392                }
393
394                response
395            }
396            RateLimitError::Unauthorized => (
397                StatusCode::UNAUTHORIZED,
398                "Authentication required for rate limiting",
399            )
400                .into_response(),
401        }
402    }
403}
404
/// Propagates `(StatusCode::FORBIDDEN, "Insufficient permissions")` with `?`
/// unless the extractor's auth context grants the given permission.
///
/// `$auth` is expected to be a tuple-struct extractor (such as `Authenticated`
/// or `Admin`) whose `.0` provides `require_permission`. `$perm` is the
/// required `Permission`. Because the rejection goes through `?`, the enclosing
/// function's error type must be convertible from
/// `(axum::http::StatusCode, &'static str)`.
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {{
        let forbidden = (
            axum::http::StatusCode::FORBIDDEN,
            "Insufficient permissions",
        );
        $auth.0.require_permission($perm).map_err(move |_| forbidden)?
    }};
}
417
418// ============================================================================
419// Tenant Isolation Middleware (Phase 5B)
420// ============================================================================
421
422use crate::domain::{entities::Tenant, repositories::TenantRepository, value_objects::TenantId};
423
424/// Tenant isolation state for middleware
425#[derive(Clone)]
426pub struct TenantState<R: TenantRepository> {
427    pub tenant_repository: Arc<R>,
428}
429
430/// Validated tenant context injected into requests
431///
432/// This context is created by the tenant_isolation_middleware after
433/// validating that the tenant exists and is active.
434#[derive(Debug, Clone)]
435pub struct TenantContext {
436    pub tenant: Tenant,
437}
438
439impl TenantContext {
440    /// Get the tenant ID
441    pub fn tenant_id(&self) -> &TenantId {
442        self.tenant.id()
443    }
444
445    /// Check if tenant is active
446    pub fn is_active(&self) -> bool {
447        self.tenant.is_active()
448    }
449}
450
451/// Tenant isolation middleware
452///
453/// Validates that the authenticated tenant exists and is active.
454/// Injects TenantContext into the request for use by handlers.
455///
456/// # Phase 5B: Tenant Isolation
457/// This middleware enforces tenant boundaries by:
458/// 1. Extracting tenant_id from AuthContext
459/// 2. Loading tenant from repository
460/// 3. Validating tenant is active
461/// 4. Injecting TenantContext into request extensions
462///
463/// Must be applied after auth_middleware.
464pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
465    State(tenant_state): State<TenantState<R>>,
466    mut request: Request,
467    next: Next,
468) -> Result<Response, TenantError> {
469    // Extract auth context (must be authenticated)
470    let auth_ctx = request
471        .extensions()
472        .get::<AuthContext>()
473        .ok_or(TenantError::Unauthorized)?
474        .clone();
475
476    // Parse tenant ID
477    let tenant_id =
478        TenantId::new(auth_ctx.tenant_id().to_string()).map_err(|_| TenantError::InvalidTenant)?;
479
480    // Load tenant from repository
481    let tenant = tenant_state
482        .tenant_repository
483        .find_by_id(&tenant_id)
484        .await
485        .map_err(|e| TenantError::RepositoryError(e.to_string()))?
486        .ok_or(TenantError::TenantNotFound)?;
487
488    // Validate tenant is active
489    if !tenant.is_active() {
490        return Err(TenantError::TenantInactive);
491    }
492
493    // Inject tenant context into request
494    request.extensions_mut().insert(TenantContext { tenant });
495
496    // Continue to next middleware/handler
497    Ok(next.run(request).await)
498}
499
500/// Error type for tenant isolation failures
501#[derive(Debug)]
502pub enum TenantError {
503    Unauthorized,
504    InvalidTenant,
505    TenantNotFound,
506    TenantInactive,
507    RepositoryError(String),
508}
509
510impl IntoResponse for TenantError {
511    fn into_response(self) -> Response {
512        let (status, message) = match self {
513            TenantError::Unauthorized => (
514                StatusCode::UNAUTHORIZED,
515                "Authentication required for tenant access",
516            ),
517            TenantError::InvalidTenant => (StatusCode::BAD_REQUEST, "Invalid tenant identifier"),
518            TenantError::TenantNotFound => (StatusCode::NOT_FOUND, "Tenant not found"),
519            TenantError::TenantInactive => (StatusCode::FORBIDDEN, "Tenant is inactive"),
520            TenantError::RepositoryError(_) => (
521                StatusCode::INTERNAL_SERVER_ERROR,
522                "Failed to validate tenant",
523            ),
524        };
525
526        (status, message).into_response()
527    }
528}
529
530// ============================================================================
531// Request ID Middleware (Phase 5C)
532// ============================================================================
533
534use uuid::Uuid;
535
536/// Request context with unique ID for tracing
537#[derive(Debug, Clone)]
538pub struct RequestId(pub String);
539
540impl Default for RequestId {
541    fn default() -> Self {
542        Self::new()
543    }
544}
545
546impl RequestId {
547    /// Generate a new request ID
548    pub fn new() -> Self {
549        Self(Uuid::new_v4().to_string())
550    }
551
552    /// Get the request ID as a string
553    pub fn as_str(&self) -> &str {
554        &self.0
555    }
556}
557
558/// Request ID middleware
559///
560/// Generates a unique request ID for each request and injects it into:
561/// - Request extensions (for use in handlers/logging)
562/// - Response headers (X-Request-ID)
563///
564/// If the request already has an X-Request-ID header, it will be used instead.
565///
566/// # Phase 5C: Request Tracing
567/// This middleware enables distributed tracing by:
568/// 1. Generating unique IDs for each request
569/// 2. Propagating IDs through the request lifecycle
570/// 3. Returning IDs in response headers
571/// 4. Supporting client-provided request IDs
572pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
573    // Check if request already has a request ID
574    let request_id = request
575        .headers()
576        .get("x-request-id")
577        .and_then(|v| v.to_str().ok())
578        .map_or_else(RequestId::new, |s| RequestId(s.to_string()));
579
580    // Store request ID in extensions
581    request.extensions_mut().insert(request_id.clone());
582
583    // Process request
584    let mut response = next.run(request).await;
585
586    // Add request ID to response headers
587    response
588        .headers_mut()
589        .insert("x-request-id", request_id.0.parse().unwrap());
590
591    response
592}
593
594// ============================================================================
595// Security Headers Middleware (Phase 5C)
596// ============================================================================
597
/// Header policy applied by [`security_headers_middleware`].
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Emit `Strict-Transport-Security` (HSTS).
    pub enable_hsts: bool,
    /// HSTS `max-age`, in seconds.
    pub hsts_max_age: u32,
    /// Emit `X-Frame-Options`.
    pub enable_frame_options: bool,
    /// Value used for `X-Frame-Options` when enabled.
    pub frame_options: FrameOptions,
    /// Emit `X-Content-Type-Options`.
    pub enable_content_type_options: bool,
    /// Emit `X-XSS-Protection`.
    pub enable_xss_protection: bool,
    /// `Content-Security-Policy` value; `None` omits the header.
    pub csp: Option<String>,
    /// Origins allowed by CORS (`"*"` means any origin).
    pub cors_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    pub cors_methods: Vec<String>,
    /// Request headers advertised in `Access-Control-Allow-Headers`.
    pub cors_headers: Vec<String>,
    /// `Access-Control-Max-Age`, in seconds.
    pub cors_max_age: u32,
}

/// Policy for the `X-Frame-Options` header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    /// `DENY`: the response may never be framed.
    Deny,
    /// `SAMEORIGIN`: only same-origin pages may frame the response.
    SameOrigin,
    /// Allow framing only from the given origin (the legacy `ALLOW-FROM`
    /// directive; prefer CSP `frame-ancestors` for modern browsers).
    AllowFrom(String),
}

impl Default for SecurityConfig {
    /// Defaults:
    /// - HSTS for one year;
    /// - `DENY` framing;
    /// - `nosniff` and XSS protection enabled;
    /// - CSP `default-src 'self'`;
    /// - permissive CORS: any origin; `GET`, `POST`, `PUT`, `DELETE`;
    ///   `Content-Type` and `Authorization` headers; one-hour preflight max-age.
    fn default() -> Self {
        const ONE_YEAR_SECS: u32 = 365 * 24 * 60 * 60;
        const ONE_HOUR_SECS: u32 = 60 * 60;

        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|&item| String::from(item)).collect()
        }

        Self {
            enable_hsts: true,
            hsts_max_age: ONE_YEAR_SECS,
            enable_frame_options: true,
            frame_options: FrameOptions::Deny,
            enable_content_type_options: true,
            enable_xss_protection: true,
            csp: Some(String::from("default-src 'self'")),
            cors_origins: owned(&["*"]),
            cors_methods: owned(&["GET", "POST", "PUT", "DELETE"]),
            cors_headers: owned(&["Content-Type", "Authorization"]),
            cors_max_age: ONE_HOUR_SECS,
        }
    }
}

/// Shared state for [`security_headers_middleware`].
#[derive(Clone)]
pub struct SecurityState {
    /// Header policy applied to every response.
    pub config: SecurityConfig,
}
659
660/// Security headers middleware
661///
662/// Adds security-related HTTP headers to all responses:
663/// - HSTS (Strict-Transport-Security)
664/// - X-Frame-Options
665/// - X-Content-Type-Options
666/// - X-XSS-Protection
667/// - Content-Security-Policy
668/// - CORS headers
669///
670/// # Phase 5C: Security Hardening
671/// This middleware provides defense-in-depth by:
672/// 1. Preventing clickjacking (X-Frame-Options)
673/// 2. Preventing MIME sniffing (X-Content-Type-Options)
674/// 3. Enforcing HTTPS (HSTS)
675/// 4. Preventing XSS (CSP, X-XSS-Protection)
676/// 5. Enabling CORS for controlled access
677pub async fn security_headers_middleware(
678    State(security_state): State<SecurityState>,
679    request: Request,
680    next: Next,
681) -> Response {
682    let mut response = next.run(request).await;
683    let headers = response.headers_mut();
684    let config = &security_state.config;
685
686    // HSTS
687    if config.enable_hsts {
688        headers.insert(
689            "strict-transport-security",
690            format!("max-age={}", config.hsts_max_age).parse().unwrap(),
691        );
692    }
693
694    // X-Frame-Options
695    if config.enable_frame_options {
696        let value = match &config.frame_options {
697            FrameOptions::Deny => "DENY",
698            FrameOptions::SameOrigin => "SAMEORIGIN",
699            FrameOptions::AllowFrom(origin) => origin,
700        };
701        headers.insert("x-frame-options", value.parse().unwrap());
702    }
703
704    // X-Content-Type-Options
705    if config.enable_content_type_options {
706        headers.insert("x-content-type-options", "nosniff".parse().unwrap());
707    }
708
709    // X-XSS-Protection
710    if config.enable_xss_protection {
711        headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
712    }
713
714    // Content-Security-Policy
715    if let Some(csp) = &config.csp {
716        headers.insert("content-security-policy", csp.parse().unwrap());
717    }
718
719    // CORS headers
720    headers.insert(
721        "access-control-allow-origin",
722        config.cors_origins.join(", ").parse().unwrap(),
723    );
724    headers.insert(
725        "access-control-allow-methods",
726        config.cors_methods.join(", ").parse().unwrap(),
727    );
728    headers.insert(
729        "access-control-allow-headers",
730        config.cors_headers.join(", ").parse().unwrap(),
731    );
732    headers.insert(
733        "access-control-max-age",
734        config.cors_max_age.to_string().parse().unwrap(),
735    );
736
737    response
738}
739
740// ============================================================================
741// IP Filtering Middleware (Phase 5C)
742// ============================================================================
743
744use crate::infrastructure::security::IpFilter;
745use std::net::SocketAddr;
746
747#[derive(Clone)]
748pub struct IpFilterState {
749    pub ip_filter: Arc<IpFilter>,
750}
751
752/// IP filtering middleware
753///
754/// Blocks or allows requests based on IP address rules.
755/// Supports both global and per-tenant IP filtering.
756///
757/// # Phase 5C: Access Control
758/// This middleware provides IP-based access control by:
759/// 1. Extracting client IP from request
760/// 2. Checking against global and tenant-specific rules
761/// 3. Blocking requests from unauthorized IPs
762/// 4. Supporting both allowlists and blocklists
763pub async fn ip_filter_middleware(
764    State(ip_filter_state): State<IpFilterState>,
765    request: Request,
766    next: Next,
767) -> Result<Response, IpFilterError> {
768    // Extract client IP address
769    let client_ip = request
770        .extensions()
771        .get::<axum::extract::ConnectInfo<SocketAddr>>()
772        .map(|connect_info| connect_info.0.ip())
773        .ok_or(IpFilterError::NoIpAddress)?;
774
775    // Check if this is a tenant-scoped request
776    let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
777        // Tenant-specific filtering
778        ip_filter_state
779            .ip_filter
780            .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
781    } else {
782        // Global filtering only
783        ip_filter_state.ip_filter.is_allowed(&client_ip)
784    };
785
786    // Block if not allowed
787    if !result.allowed {
788        return Err(IpFilterError::Blocked {
789            reason: result.reason,
790        });
791    }
792
793    // Allow request to proceed
794    Ok(next.run(request).await)
795}
796
797/// Error type for IP filtering failures
798#[derive(Debug)]
799pub enum IpFilterError {
800    NoIpAddress,
801    Blocked { reason: String },
802}
803
804impl IntoResponse for IpFilterError {
805    fn into_response(self) -> Response {
806        match self {
807            IpFilterError::NoIpAddress => (
808                StatusCode::BAD_REQUEST,
809                "Unable to determine client IP address",
810            )
811                .into_response(),
812            IpFilterError::Blocked { reason } => {
813                (StatusCode::FORBIDDEN, format!("Access denied: {reason}")).into_response()
814            }
815        }
816    }
817}
818
819#[cfg(test)]
820mod tests {
821    use super::*;
822    use crate::infrastructure::security::auth::Role;
823
824    #[test]
825    fn test_extract_bearer_token() {
826        let mut headers = HeaderMap::new();
827        headers.insert("authorization", "Bearer test_token_123".parse().unwrap());
828
829        let token = extract_token(&headers).unwrap();
830        assert_eq!(token, "test_token_123");
831    }
832
833    #[test]
834    fn test_extract_lowercase_bearer() {
835        let mut headers = HeaderMap::new();
836        headers.insert("authorization", "bearer test_token_123".parse().unwrap());
837
838        let token = extract_token(&headers).unwrap();
839        assert_eq!(token, "test_token_123");
840    }
841
842    #[test]
843    fn test_extract_plain_token() {
844        let mut headers = HeaderMap::new();
845        headers.insert("authorization", "test_token_123".parse().unwrap());
846
847        let token = extract_token(&headers).unwrap();
848        assert_eq!(token, "test_token_123");
849    }
850
851    #[test]
852    fn test_missing_auth_header() {
853        let headers = HeaderMap::new();
854        assert!(extract_token(&headers).is_err());
855    }
856
857    #[test]
858    fn test_empty_auth_header() {
859        let mut headers = HeaderMap::new();
860        headers.insert("authorization", "".parse().unwrap());
861        assert!(extract_token(&headers).is_err());
862    }
863
864    #[test]
865    fn test_bearer_with_empty_token() {
866        let mut headers = HeaderMap::new();
867        headers.insert("authorization", "Bearer ".parse().unwrap());
868        assert!(extract_token(&headers).is_err());
869    }
870
871    #[test]
872    fn test_auth_context_permissions() {
873        let claims = Claims::new(
874            "user1".to_string(),
875            "tenant1".to_string(),
876            Role::Developer,
877            chrono::Duration::hours(1),
878        );
879
880        let ctx = AuthContext { claims };
881
882        assert!(ctx.require_permission(Permission::Read).is_ok());
883        assert!(ctx.require_permission(Permission::Write).is_ok());
884        assert!(ctx.require_permission(Permission::Admin).is_err());
885    }
886
887    #[test]
888    fn test_auth_context_admin_permissions() {
889        let claims = Claims::new(
890            "admin1".to_string(),
891            "tenant1".to_string(),
892            Role::Admin,
893            chrono::Duration::hours(1),
894        );
895
896        let ctx = AuthContext { claims };
897
898        assert!(ctx.require_permission(Permission::Read).is_ok());
899        assert!(ctx.require_permission(Permission::Write).is_ok());
900        assert!(ctx.require_permission(Permission::Admin).is_ok());
901    }
902
903    #[test]
904    fn test_auth_context_readonly_permissions() {
905        let claims = Claims::new(
906            "readonly1".to_string(),
907            "tenant1".to_string(),
908            Role::ReadOnly,
909            chrono::Duration::hours(1),
910        );
911
912        let ctx = AuthContext { claims };
913
914        assert!(ctx.require_permission(Permission::Read).is_ok());
915        assert!(ctx.require_permission(Permission::Write).is_err());
916        assert!(ctx.require_permission(Permission::Admin).is_err());
917    }
918
919    #[test]
920    fn test_auth_context_tenant_id() {
921        let claims = Claims::new(
922            "user1".to_string(),
923            "my-tenant".to_string(),
924            Role::Developer,
925            chrono::Duration::hours(1),
926        );
927
928        let ctx = AuthContext { claims };
929        assert_eq!(ctx.tenant_id(), "my-tenant");
930    }
931
932    #[test]
933    fn test_auth_context_user_id() {
934        let claims = Claims::new(
935            "my-user".to_string(),
936            "tenant1".to_string(),
937            Role::Developer,
938            chrono::Duration::hours(1),
939        );
940
941        let ctx = AuthContext { claims };
942        assert_eq!(ctx.user_id(), "my-user");
943    }
944
945    #[test]
946    fn test_request_id_new() {
947        let id1 = RequestId::new();
948        let id2 = RequestId::new();
949
950        // IDs should be unique
951        assert_ne!(id1.as_str(), id2.as_str());
952        // IDs should be valid UUIDs (36 chars with hyphens)
953        assert_eq!(id1.as_str().len(), 36);
954    }
955
956    #[test]
957    fn test_request_id_default() {
958        let id = RequestId::default();
959        assert_eq!(id.as_str().len(), 36);
960    }
961
962    #[test]
963    fn test_security_config_default() {
964        let config = SecurityConfig::default();
965
966        assert!(config.enable_hsts);
967        assert_eq!(config.hsts_max_age, 31536000);
968        assert!(config.enable_frame_options);
969        assert!(config.enable_content_type_options);
970        assert!(config.enable_xss_protection);
971        assert!(config.csp.is_some());
972    }
973
974    #[test]
975    fn test_frame_options_variants() {
976        let deny = FrameOptions::Deny;
977        let same_origin = FrameOptions::SameOrigin;
978        let allow_from = FrameOptions::AllowFrom("https://example.com".to_string());
979
980        // Check that variants are distinct via debug formatting
981        assert!(format!("{deny:?}").contains("Deny"));
982        assert!(format!("{same_origin:?}").contains("SameOrigin"));
983        assert!(format!("{allow_from:?}").contains("AllowFrom"));
984    }
985
986    #[test]
987    fn test_auth_error_from_validation_error() {
988        let error = AllSourceError::ValidationError("test error".to_string());
989        let auth_error = AuthError::from(error);
990        assert!(format!("{auth_error:?}").contains("ValidationError"));
991    }
992
993    #[test]
994    fn test_rate_limit_error_display() {
995        let error = RateLimitError::RateLimitExceeded {
996            retry_after: 60,
997            limit: 100,
998        };
999        assert!(format!("{error:?}").contains("RateLimitExceeded"));
1000
1001        let unauth_error = RateLimitError::Unauthorized;
1002        assert!(format!("{unauth_error:?}").contains("Unauthorized"));
1003    }
1004
1005    #[test]
1006    fn test_tenant_error_variants() {
1007        let errors = vec![
1008            TenantError::Unauthorized,
1009            TenantError::InvalidTenant,
1010            TenantError::TenantNotFound,
1011            TenantError::TenantInactive,
1012            TenantError::RepositoryError("test".to_string()),
1013        ];
1014
1015        for error in errors {
1016            // Ensure each variant can be debug-formatted
1017            let _ = format!("{error:?}");
1018        }
1019    }
1020
1021    #[test]
1022    fn test_ip_filter_error_variants() {
1023        let errors = vec![
1024            IpFilterError::NoIpAddress,
1025            IpFilterError::Blocked {
1026                reason: "blocked".to_string(),
1027            },
1028        ];
1029
1030        for error in errors {
1031            let _ = format!("{error:?}");
1032        }
1033    }
1034
1035    #[test]
1036    fn test_security_state_clone() {
1037        let config = SecurityConfig::default();
1038        let state = SecurityState {
1039            config: config.clone(),
1040        };
1041        let cloned = state.clone();
1042        assert_eq!(cloned.config.hsts_max_age, config.hsts_max_age);
1043    }
1044
1045    #[test]
1046    fn test_auth_state_clone() {
1047        let auth_manager = Arc::new(AuthManager::new("test-secret"));
1048        let state = AuthState { auth_manager };
1049        let cloned = state.clone();
1050        assert!(Arc::ptr_eq(&state.auth_manager, &cloned.auth_manager));
1051    }
1052
1053    #[test]
1054    fn test_rate_limit_state_clone() {
1055        use crate::infrastructure::security::rate_limit::RateLimitConfig;
1056        let rate_limiter = Arc::new(RateLimiter::new(RateLimitConfig::free_tier()));
1057        let state = RateLimitState { rate_limiter };
1058        let cloned = state.clone();
1059        assert!(Arc::ptr_eq(&state.rate_limiter, &cloned.rate_limiter));
1060    }
1061
1062    #[test]
1063    fn test_auth_skip_paths_contains_expected() {
1064        // Verify public paths are configured for auth/rate-limit skipping
1065        assert!(should_skip_auth("/health"));
1066        assert!(should_skip_auth("/metrics"));
1067        assert!(should_skip_auth("/api/v1/auth/register"));
1068        assert!(should_skip_auth("/api/v1/auth/login"));
1069        assert!(should_skip_auth("/api/v1/demo/seed"));
1070
1071        // Verify internal endpoints bypass auth (sentinel failover)
1072        assert!(should_skip_auth("/internal/promote"));
1073        assert!(should_skip_auth("/internal/repoint"));
1074        assert!(should_skip_auth("/internal/anything"));
1075
1076        // Verify protected paths are NOT in skip list
1077        assert!(!should_skip_auth("/api/v1/events"));
1078        assert!(!should_skip_auth("/api/v1/auth/me"));
1079        assert!(!should_skip_auth("/api/v1/tenants"));
1080    }
1081
1082    #[test]
1083    fn test_dev_mode_auth_context() {
1084        let ctx = dev_mode_auth_context();
1085
1086        // Dev user should have admin privileges
1087        assert_eq!(ctx.tenant_id(), "dev-tenant");
1088        assert_eq!(ctx.user_id(), "dev-user");
1089        assert!(ctx.require_permission(Permission::Admin).is_ok());
1090        assert!(ctx.require_permission(Permission::Read).is_ok());
1091        assert!(ctx.require_permission(Permission::Write).is_ok());
1092    }
1093
1094    #[test]
1095    fn test_dev_mode_disabled_by_default() {
1096        // Dev mode should be disabled by default (env var not set in tests)
1097        // Note: This test may fail if ALLSOURCE_DEV_MODE is set in the test environment
1098        // In a clean environment, dev mode is disabled
1099        let env_value = std::env::var("ALLSOURCE_DEV_MODE").unwrap_or_default();
1100        if env_value.is_empty() {
1101            assert!(!is_dev_mode());
1102        }
1103    }
1104}