Skip to main content

allsource_core/infrastructure/security/
middleware.rs

1use crate::error::AllSourceError;
2use crate::infrastructure::security::auth::{AuthManager, Claims, Permission};
3use crate::infrastructure::security::rate_limit::RateLimiter;
4use axum::{
5    extract::{Request, State},
6    http::{HeaderMap, StatusCode},
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11
/// Authentication state shared across requests.
///
/// Holds the [`AuthManager`] that [`auth_middleware`] and
/// [`optional_auth_middleware`] use to validate API keys and JWTs.
/// Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub struct AuthState {
    pub auth_manager: Arc<AuthManager>,
}

/// Rate limiting state.
///
/// Holds the [`RateLimiter`] consulted by [`rate_limit_middleware`].
/// Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub struct RateLimitState {
    pub rate_limiter: Arc<RateLimiter>,
}
23
24/// Authenticated request context
25#[derive(Debug, Clone)]
26pub struct AuthContext {
27    pub claims: Claims,
28}
29
30impl AuthContext {
31    /// Check if user has required permission
32    pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
33        if self.claims.has_permission(permission) {
34            Ok(())
35        } else {
36            Err(AllSourceError::ValidationError(
37                "Insufficient permissions".to_string(),
38            ))
39        }
40    }
41
42    /// Get tenant ID from context
43    pub fn tenant_id(&self) -> &str {
44        &self.claims.tenant_id
45    }
46
47    /// Get user ID from context
48    pub fn user_id(&self) -> &str {
49        &self.claims.sub
50    }
51}
52
53/// Extract token from Authorization header
54fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
55    let auth_header = headers
56        .get("authorization")
57        .ok_or_else(|| AllSourceError::ValidationError("Missing authorization header".to_string()))?
58        .to_str()
59        .map_err(|_| AllSourceError::ValidationError("Invalid authorization header".to_string()))?;
60
61    // Support both "Bearer <token>" and "<token>" formats
62    let token = if auth_header.starts_with("Bearer ") {
63        auth_header.trim_start_matches("Bearer ").trim()
64    } else if auth_header.starts_with("bearer ") {
65        auth_header.trim_start_matches("bearer ").trim()
66    } else {
67        auth_header.trim()
68    };
69
70    if token.is_empty() {
71        return Err(AllSourceError::ValidationError(
72            "Empty authorization token".to_string(),
73        ));
74    }
75
76    Ok(token.to_string())
77}
78
79/// Authentication middleware
80pub async fn auth_middleware(
81    State(auth_state): State<AuthState>,
82    mut request: Request,
83    next: Next,
84) -> Result<Response, AuthError> {
85    let headers = request.headers();
86
87    // Extract and validate token (JWT or API key)
88    let token = extract_token(headers)?;
89
90    let claims = if token.starts_with("ask_") {
91        // API Key authentication
92        auth_state.auth_manager.validate_api_key(&token)?
93    } else {
94        // JWT authentication
95        auth_state.auth_manager.validate_token(&token)?
96    };
97
98    // Insert auth context into request extensions
99    request.extensions_mut().insert(AuthContext { claims });
100
101    Ok(next.run(request).await)
102}
103
104/// Optional authentication middleware (allows unauthenticated requests)
105pub async fn optional_auth_middleware(
106    State(auth_state): State<AuthState>,
107    mut request: Request,
108    next: Next,
109) -> Response {
110    let headers = request.headers();
111
112    if let Ok(token) = extract_token(headers) {
113        // Try to authenticate, but don't fail if invalid
114        let claims = if token.starts_with("ask_") {
115            auth_state.auth_manager.validate_api_key(&token).ok()
116        } else {
117            auth_state.auth_manager.validate_token(&token).ok()
118        };
119
120        if let Some(claims) = claims {
121            request.extensions_mut().insert(AuthContext { claims });
122        }
123    }
124
125    next.run(request).await
126}
127
128/// Error type for authentication failures
129#[derive(Debug)]
130pub struct AuthError(AllSourceError);
131
132impl From<AllSourceError> for AuthError {
133    fn from(err: AllSourceError) -> Self {
134        AuthError(err)
135    }
136}
137
138impl IntoResponse for AuthError {
139    fn into_response(self) -> Response {
140        let (status, message) = match self.0 {
141            AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
142            _ => (
143                StatusCode::INTERNAL_SERVER_ERROR,
144                "Internal server error".to_string(),
145            ),
146        };
147
148        (status, message).into_response()
149    }
150}
151
152/// Axum extractor for authenticated requests
153pub struct Authenticated(pub AuthContext);
154
155impl<S> axum::extract::FromRequestParts<S> for Authenticated
156where
157    S: Send + Sync,
158{
159    type Rejection = (StatusCode, &'static str);
160
161    async fn from_request_parts(
162        parts: &mut axum::http::request::Parts,
163        _state: &S,
164    ) -> Result<Self, Self::Rejection> {
165        parts
166            .extensions
167            .get::<AuthContext>()
168            .cloned()
169            .map(Authenticated)
170            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
171    }
172}
173
174/// Axum extractor for optional authentication (never rejects, returns Option)
175/// Use this for routes that work with or without authentication
176pub struct OptionalAuth(pub Option<AuthContext>);
177
178impl<S> axum::extract::FromRequestParts<S> for OptionalAuth
179where
180    S: Send + Sync,
181{
182    type Rejection = std::convert::Infallible;
183
184    async fn from_request_parts(
185        parts: &mut axum::http::request::Parts,
186        _state: &S,
187    ) -> Result<Self, Self::Rejection> {
188        Ok(OptionalAuth(parts.extensions.get::<AuthContext>().cloned()))
189    }
190}
191
192/// Axum extractor for admin-only requests
193pub struct Admin(pub AuthContext);
194
195impl<S> axum::extract::FromRequestParts<S> for Admin
196where
197    S: Send + Sync,
198{
199    type Rejection = (StatusCode, &'static str);
200
201    async fn from_request_parts(
202        parts: &mut axum::http::request::Parts,
203        _state: &S,
204    ) -> Result<Self, Self::Rejection> {
205        let auth_ctx = parts
206            .extensions
207            .get::<AuthContext>()
208            .cloned()
209            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
210
211        auth_ctx
212            .require_permission(Permission::Admin)
213            .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
214
215        Ok(Admin(auth_ctx))
216    }
217}
218
219/// Rate limiting middleware
220/// Checks rate limits based on tenant_id from auth context
221pub async fn rate_limit_middleware(
222    State(rate_limit_state): State<RateLimitState>,
223    request: Request,
224    next: Next,
225) -> Result<Response, RateLimitError> {
226    // Extract auth context from request
227    let auth_ctx = request
228        .extensions()
229        .get::<AuthContext>()
230        .ok_or(RateLimitError::Unauthorized)?;
231
232    // Check rate limit for this tenant
233    let result = rate_limit_state
234        .rate_limiter
235        .check_rate_limit(auth_ctx.tenant_id());
236
237    if !result.allowed {
238        return Err(RateLimitError::RateLimitExceeded {
239            retry_after: result.retry_after.unwrap_or_default().as_secs(),
240            limit: result.limit,
241        });
242    }
243
244    // Add rate limit headers to response
245    let mut response = next.run(request).await;
246    let headers = response.headers_mut();
247    headers.insert(
248        "X-RateLimit-Limit",
249        result.limit.to_string().parse().unwrap(),
250    );
251    headers.insert(
252        "X-RateLimit-Remaining",
253        result.remaining.to_string().parse().unwrap(),
254    );
255
256    Ok(response)
257}
258
259/// Error type for rate limiting failures
260#[derive(Debug)]
261pub enum RateLimitError {
262    RateLimitExceeded { retry_after: u64, limit: u32 },
263    Unauthorized,
264}
265
266impl IntoResponse for RateLimitError {
267    fn into_response(self) -> Response {
268        match self {
269            RateLimitError::RateLimitExceeded { retry_after, limit } => {
270                let mut response = (
271                    StatusCode::TOO_MANY_REQUESTS,
272                    format!("Rate limit exceeded. Limit: {limit} requests/min"),
273                )
274                    .into_response();
275
276                if retry_after > 0 {
277                    response
278                        .headers_mut()
279                        .insert("Retry-After", retry_after.to_string().parse().unwrap());
280                }
281
282                response
283            }
284            RateLimitError::Unauthorized => (
285                StatusCode::UNAUTHORIZED,
286                "Authentication required for rate limiting",
287            )
288                .into_response(),
289        }
290    }
291}
292
/// Helper macro that requires a specific permission inside a handler.
///
/// - `$auth` is an extractor whose `.0` field is an `AuthContext`, for
///   example `Authenticated` or `Admin`.
/// - `$perm` is the required `Permission`.
///
/// If the permission is missing, the macro applies `?` to the error
/// `(axum::http::StatusCode::FORBIDDEN, "Insufficient permissions")`. The
/// enclosing function must therefore return a `Result` whose error type can
/// be converted from that tuple. On success the macro evaluates to `()`.
///
/// ```ignore
/// async fn handler(auth: Authenticated) -> Result<String, (StatusCode, &'static str)> {
///     require_permission!(auth, Permission::Write);
///     Ok("ok".into())
/// }
/// ```
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {
        $auth.0.require_permission($perm).map_err(|_| {
            (
                axum::http::StatusCode::FORBIDDEN,
                "Insufficient permissions",
            )
        })?
    };
}
305
306// ============================================================================
307// Tenant Isolation Middleware (Phase 5B)
308// ============================================================================
309
310use crate::domain::entities::Tenant;
311use crate::domain::repositories::TenantRepository;
312use crate::domain::value_objects::TenantId;
313
314/// Tenant isolation state for middleware
315#[derive(Clone)]
316pub struct TenantState<R: TenantRepository> {
317    pub tenant_repository: Arc<R>,
318}
319
320/// Validated tenant context injected into requests
321///
322/// This context is created by the tenant_isolation_middleware after
323/// validating that the tenant exists and is active.
324#[derive(Debug, Clone)]
325pub struct TenantContext {
326    pub tenant: Tenant,
327}
328
329impl TenantContext {
330    /// Get the tenant ID
331    pub fn tenant_id(&self) -> &TenantId {
332        self.tenant.id()
333    }
334
335    /// Check if tenant is active
336    pub fn is_active(&self) -> bool {
337        self.tenant.is_active()
338    }
339}
340
341/// Tenant isolation middleware
342///
343/// Validates that the authenticated tenant exists and is active.
344/// Injects TenantContext into the request for use by handlers.
345///
346/// # Phase 5B: Tenant Isolation
347/// This middleware enforces tenant boundaries by:
348/// 1. Extracting tenant_id from AuthContext
349/// 2. Loading tenant from repository
350/// 3. Validating tenant is active
351/// 4. Injecting TenantContext into request extensions
352///
353/// Must be applied after auth_middleware.
354pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
355    State(tenant_state): State<TenantState<R>>,
356    mut request: Request,
357    next: Next,
358) -> Result<Response, TenantError> {
359    // Extract auth context (must be authenticated)
360    let auth_ctx = request
361        .extensions()
362        .get::<AuthContext>()
363        .ok_or(TenantError::Unauthorized)?
364        .clone();
365
366    // Parse tenant ID
367    let tenant_id =
368        TenantId::new(auth_ctx.tenant_id().to_string()).map_err(|_| TenantError::InvalidTenant)?;
369
370    // Load tenant from repository
371    let tenant = tenant_state
372        .tenant_repository
373        .find_by_id(&tenant_id)
374        .await
375        .map_err(|e| TenantError::RepositoryError(e.to_string()))?
376        .ok_or(TenantError::TenantNotFound)?;
377
378    // Validate tenant is active
379    if !tenant.is_active() {
380        return Err(TenantError::TenantInactive);
381    }
382
383    // Inject tenant context into request
384    request.extensions_mut().insert(TenantContext { tenant });
385
386    // Continue to next middleware/handler
387    Ok(next.run(request).await)
388}
389
390/// Error type for tenant isolation failures
391#[derive(Debug)]
392pub enum TenantError {
393    Unauthorized,
394    InvalidTenant,
395    TenantNotFound,
396    TenantInactive,
397    RepositoryError(String),
398}
399
400impl IntoResponse for TenantError {
401    fn into_response(self) -> Response {
402        let (status, message) = match self {
403            TenantError::Unauthorized => (
404                StatusCode::UNAUTHORIZED,
405                "Authentication required for tenant access",
406            ),
407            TenantError::InvalidTenant => (StatusCode::BAD_REQUEST, "Invalid tenant identifier"),
408            TenantError::TenantNotFound => (StatusCode::NOT_FOUND, "Tenant not found"),
409            TenantError::TenantInactive => (StatusCode::FORBIDDEN, "Tenant is inactive"),
410            TenantError::RepositoryError(_) => (
411                StatusCode::INTERNAL_SERVER_ERROR,
412                "Failed to validate tenant",
413            ),
414        };
415
416        (status, message).into_response()
417    }
418}
419
420// ============================================================================
421// Request ID Middleware (Phase 5C)
422// ============================================================================
423
424use uuid::Uuid;
425
426/// Request context with unique ID for tracing
427#[derive(Debug, Clone)]
428pub struct RequestId(pub String);
429
430impl Default for RequestId {
431    fn default() -> Self {
432        Self::new()
433    }
434}
435
436impl RequestId {
437    /// Generate a new request ID
438    pub fn new() -> Self {
439        Self(Uuid::new_v4().to_string())
440    }
441
442    /// Get the request ID as a string
443    pub fn as_str(&self) -> &str {
444        &self.0
445    }
446}
447
448/// Request ID middleware
449///
450/// Generates a unique request ID for each request and injects it into:
451/// - Request extensions (for use in handlers/logging)
452/// - Response headers (X-Request-ID)
453///
454/// If the request already has an X-Request-ID header, it will be used instead.
455///
456/// # Phase 5C: Request Tracing
457/// This middleware enables distributed tracing by:
458/// 1. Generating unique IDs for each request
459/// 2. Propagating IDs through the request lifecycle
460/// 3. Returning IDs in response headers
461/// 4. Supporting client-provided request IDs
462pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
463    // Check if request already has a request ID
464    let request_id = request
465        .headers()
466        .get("x-request-id")
467        .and_then(|v| v.to_str().ok())
468        .map(|s| RequestId(s.to_string()))
469        .unwrap_or_else(RequestId::new);
470
471    // Store request ID in extensions
472    request.extensions_mut().insert(request_id.clone());
473
474    // Process request
475    let mut response = next.run(request).await;
476
477    // Add request ID to response headers
478    response
479        .headers_mut()
480        .insert("x-request-id", request_id.0.parse().unwrap());
481
482    response
483}
484
485// ============================================================================
486// Security Headers Middleware (Phase 5C)
487// ============================================================================
488
/// Security headers configuration.
///
/// Controls which headers [`security_headers_middleware`] adds to each
/// response.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Emit `Strict-Transport-Security` (HSTS).
    pub enable_hsts: bool,
    /// HSTS `max-age` directive, in seconds.
    pub hsts_max_age: u32,
    /// Emit `X-Frame-Options`.
    pub enable_frame_options: bool,
    /// Which `X-Frame-Options` policy to send when enabled.
    pub frame_options: FrameOptions,
    /// Emit `X-Content-Type-Options: nosniff`.
    pub enable_content_type_options: bool,
    /// Emit the legacy `X-XSS-Protection: 1; mode=block` header.
    pub enable_xss_protection: bool,
    /// `Content-Security-Policy` value; `None` omits the header.
    pub csp: Option<String>,
    /// Origins permitted by CORS; `"*"` permits any origin.
    pub cors_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    pub cors_methods: Vec<String>,
    /// Request headers advertised in `Access-Control-Allow-Headers`.
    pub cors_headers: Vec<String>,
    /// `Access-Control-Max-Age` value, in seconds (how long a preflight
    /// result may be cached).
    pub cors_max_age: u32,
}
515
/// Policy for the `X-Frame-Options` response header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    /// Forbid framing entirely (`DENY`).
    Deny,
    /// Allow framing only by pages of the same origin (`SAMEORIGIN`).
    SameOrigin,
    /// Allow framing by the given origin, e.g. `https://example.com`.
    ///
    /// Note: modern browsers largely ignore origin-specific
    /// `X-Frame-Options`; CSP `frame-ancestors` is the current replacement.
    AllowFrom(String),
}
522
523impl Default for SecurityConfig {
524    fn default() -> Self {
525        Self {
526            enable_hsts: true,
527            hsts_max_age: 31536000, // 1 year
528            enable_frame_options: true,
529            frame_options: FrameOptions::Deny,
530            enable_content_type_options: true,
531            enable_xss_protection: true,
532            csp: Some("default-src 'self'".to_string()),
533            cors_origins: vec!["*".to_string()],
534            cors_methods: vec![
535                "GET".to_string(),
536                "POST".to_string(),
537                "PUT".to_string(),
538                "DELETE".to_string(),
539            ],
540            cors_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
541            cors_max_age: 3600,
542        }
543    }
544}
545
/// State for [`security_headers_middleware`]: the [`SecurityConfig`] to apply.
#[derive(Clone)]
pub struct SecurityState {
    pub config: SecurityConfig,
}
550
551/// Security headers middleware
552///
553/// Adds security-related HTTP headers to all responses:
554/// - HSTS (Strict-Transport-Security)
555/// - X-Frame-Options
556/// - X-Content-Type-Options
557/// - X-XSS-Protection
558/// - Content-Security-Policy
559/// - CORS headers
560///
561/// # Phase 5C: Security Hardening
562/// This middleware provides defense-in-depth by:
563/// 1. Preventing clickjacking (X-Frame-Options)
564/// 2. Preventing MIME sniffing (X-Content-Type-Options)
565/// 3. Enforcing HTTPS (HSTS)
566/// 4. Preventing XSS (CSP, X-XSS-Protection)
567/// 5. Enabling CORS for controlled access
568pub async fn security_headers_middleware(
569    State(security_state): State<SecurityState>,
570    request: Request,
571    next: Next,
572) -> Response {
573    let mut response = next.run(request).await;
574    let headers = response.headers_mut();
575    let config = &security_state.config;
576
577    // HSTS
578    if config.enable_hsts {
579        headers.insert(
580            "strict-transport-security",
581            format!("max-age={}", config.hsts_max_age).parse().unwrap(),
582        );
583    }
584
585    // X-Frame-Options
586    if config.enable_frame_options {
587        let value = match &config.frame_options {
588            FrameOptions::Deny => "DENY",
589            FrameOptions::SameOrigin => "SAMEORIGIN",
590            FrameOptions::AllowFrom(origin) => origin,
591        };
592        headers.insert("x-frame-options", value.parse().unwrap());
593    }
594
595    // X-Content-Type-Options
596    if config.enable_content_type_options {
597        headers.insert("x-content-type-options", "nosniff".parse().unwrap());
598    }
599
600    // X-XSS-Protection
601    if config.enable_xss_protection {
602        headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
603    }
604
605    // Content-Security-Policy
606    if let Some(csp) = &config.csp {
607        headers.insert("content-security-policy", csp.parse().unwrap());
608    }
609
610    // CORS headers
611    headers.insert(
612        "access-control-allow-origin",
613        config.cors_origins.join(", ").parse().unwrap(),
614    );
615    headers.insert(
616        "access-control-allow-methods",
617        config.cors_methods.join(", ").parse().unwrap(),
618    );
619    headers.insert(
620        "access-control-allow-headers",
621        config.cors_headers.join(", ").parse().unwrap(),
622    );
623    headers.insert(
624        "access-control-max-age",
625        config.cors_max_age.to_string().parse().unwrap(),
626    );
627
628    response
629}
630
631// ============================================================================
632// IP Filtering Middleware (Phase 5C)
633// ============================================================================
634
635use crate::infrastructure::security::IpFilter;
636use std::net::SocketAddr;
637
/// State for [`ip_filter_middleware`]: the shared [`IpFilter`] rule set.
///
/// Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub struct IpFilterState {
    pub ip_filter: Arc<IpFilter>,
}
642
643/// IP filtering middleware
644///
645/// Blocks or allows requests based on IP address rules.
646/// Supports both global and per-tenant IP filtering.
647///
648/// # Phase 5C: Access Control
649/// This middleware provides IP-based access control by:
650/// 1. Extracting client IP from request
651/// 2. Checking against global and tenant-specific rules
652/// 3. Blocking requests from unauthorized IPs
653/// 4. Supporting both allowlists and blocklists
654pub async fn ip_filter_middleware(
655    State(ip_filter_state): State<IpFilterState>,
656    request: Request,
657    next: Next,
658) -> Result<Response, IpFilterError> {
659    // Extract client IP address
660    let client_ip = request
661        .extensions()
662        .get::<axum::extract::ConnectInfo<SocketAddr>>()
663        .map(|connect_info| connect_info.0.ip())
664        .ok_or(IpFilterError::NoIpAddress)?;
665
666    // Check if this is a tenant-scoped request
667    let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
668        // Tenant-specific filtering
669        ip_filter_state
670            .ip_filter
671            .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
672    } else {
673        // Global filtering only
674        ip_filter_state.ip_filter.is_allowed(&client_ip)
675    };
676
677    // Block if not allowed
678    if !result.allowed {
679        return Err(IpFilterError::Blocked {
680            reason: result.reason,
681        });
682    }
683
684    // Allow request to proceed
685    Ok(next.run(request).await)
686}
687
688/// Error type for IP filtering failures
689#[derive(Debug)]
690pub enum IpFilterError {
691    NoIpAddress,
692    Blocked { reason: String },
693}
694
695impl IntoResponse for IpFilterError {
696    fn into_response(self) -> Response {
697        match self {
698            IpFilterError::NoIpAddress => (
699                StatusCode::BAD_REQUEST,
700                "Unable to determine client IP address",
701            )
702                .into_response(),
703            IpFilterError::Blocked { reason } => {
704                (StatusCode::FORBIDDEN, format!("Access denied: {reason}")).into_response()
705            }
706        }
707    }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::security::auth::Role;

    /// Header map carrying a single `authorization` header set to `value`.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", value.parse().unwrap());
        headers
    }

    /// Auth context for `user` in `tenant` with `role`, valid for one hour.
    fn context(user: &str, tenant: &str, role: Role) -> AuthContext {
        let claims = Claims::new(
            user.to_string(),
            tenant.to_string(),
            role,
            chrono::Duration::hours(1),
        );
        AuthContext { claims }
    }

    /// Asserts which of Read / Write / Admin the context is granted.
    fn assert_grants(ctx: &AuthContext, read: bool, write: bool, admin: bool) {
        assert_eq!(ctx.require_permission(Permission::Read).is_ok(), read);
        assert_eq!(ctx.require_permission(Permission::Write).is_ok(), write);
        assert_eq!(ctx.require_permission(Permission::Admin).is_ok(), admin);
    }

    #[test]
    fn test_extract_bearer_token() {
        let token = extract_token(&auth_headers("Bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_lowercase_bearer() {
        let token = extract_token(&auth_headers("bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_plain_token() {
        let token = extract_token(&auth_headers("test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_missing_auth_header() {
        assert!(extract_token(&HeaderMap::new()).is_err());
    }

    #[test]
    fn test_empty_auth_header() {
        assert!(extract_token(&auth_headers("")).is_err());
    }

    #[test]
    fn test_bearer_with_empty_token() {
        assert!(extract_token(&auth_headers("Bearer ")).is_err());
    }

    #[test]
    fn test_auth_context_permissions() {
        assert_grants(&context("user1", "tenant1", Role::Developer), true, true, false);
    }

    #[test]
    fn test_auth_context_admin_permissions() {
        assert_grants(&context("admin1", "tenant1", Role::Admin), true, true, true);
    }

    #[test]
    fn test_auth_context_readonly_permissions() {
        assert_grants(&context("readonly1", "tenant1", Role::ReadOnly), true, false, false);
    }

    #[test]
    fn test_auth_context_tenant_id() {
        let ctx = context("user1", "my-tenant", Role::Developer);
        assert_eq!(ctx.tenant_id(), "my-tenant");
    }

    #[test]
    fn test_auth_context_user_id() {
        let ctx = context("my-user", "tenant1", Role::Developer);
        assert_eq!(ctx.user_id(), "my-user");
    }

    #[test]
    fn test_request_id_new() {
        let (first, second) = (RequestId::new(), RequestId::new());
        // Fresh IDs must differ and be hyphenated UUIDs (36 characters).
        assert_ne!(first.as_str(), second.as_str());
        assert_eq!(first.as_str().len(), 36);
    }

    #[test]
    fn test_request_id_default() {
        assert_eq!(RequestId::default().as_str().len(), 36);
    }

    #[test]
    fn test_security_config_default() {
        let config = SecurityConfig::default();

        assert!(config.enable_hsts);
        assert_eq!(config.hsts_max_age, 31536000);
        assert!(config.enable_frame_options);
        assert!(config.enable_content_type_options);
        assert!(config.enable_xss_protection);
        assert!(config.csp.is_some());
    }

    #[test]
    fn test_frame_options_variants() {
        let cases = [
            (FrameOptions::Deny, "Deny"),
            (FrameOptions::SameOrigin, "SameOrigin"),
            (
                FrameOptions::AllowFrom("https://example.com".to_string()),
                "AllowFrom",
            ),
        ];

        // Variants are distinguishable through their Debug output.
        for (option, variant) in cases {
            assert!(format!("{option:?}").contains(variant));
        }
    }

    #[test]
    fn test_auth_error_from_validation_error() {
        let auth_error = AuthError::from(AllSourceError::ValidationError("test error".to_string()));
        assert!(format!("{auth_error:?}").contains("ValidationError"));
    }

    #[test]
    fn test_rate_limit_error_display() {
        let exceeded = RateLimitError::RateLimitExceeded {
            retry_after: 60,
            limit: 100,
        };
        assert!(format!("{exceeded:?}").contains("RateLimitExceeded"));

        let unauthorized = RateLimitError::Unauthorized;
        assert!(format!("{unauthorized:?}").contains("Unauthorized"));
    }

    #[test]
    fn test_tenant_error_variants() {
        let errors = [
            TenantError::Unauthorized,
            TenantError::InvalidTenant,
            TenantError::TenantNotFound,
            TenantError::TenantInactive,
            TenantError::RepositoryError("test".to_string()),
        ];

        // Every variant must be Debug-formattable.
        for error in errors {
            let _ = format!("{error:?}");
        }
    }

    #[test]
    fn test_ip_filter_error_variants() {
        let errors = [
            IpFilterError::NoIpAddress,
            IpFilterError::Blocked {
                reason: "blocked".to_string(),
            },
        ];

        for error in errors {
            let _ = format!("{error:?}");
        }
    }

    #[test]
    fn test_security_state_clone() {
        let config = SecurityConfig::default();
        let state = SecurityState {
            config: config.clone(),
        };
        assert_eq!(state.clone().config.hsts_max_age, config.hsts_max_age);
    }

    #[test]
    fn test_auth_state_clone() {
        let state = AuthState {
            auth_manager: Arc::new(AuthManager::new("test-secret")),
        };
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.auth_manager, &cloned.auth_manager));
    }

    #[test]
    fn test_rate_limit_state_clone() {
        use crate::infrastructure::security::rate_limit::RateLimitConfig;

        let state = RateLimitState {
            rate_limiter: Arc::new(RateLimiter::new(RateLimitConfig::free_tier())),
        };
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.rate_limiter, &cloned.rate_limiter));
    }
}