Skip to main content

aegis_server/
middleware.rs

1//! Aegis Middleware
2//!
3//! HTTP middleware for cross-cutting concerns including request ID generation,
4//! authentication, rate limiting, and request logging.
5//!
6//! @version 0.1.0
7//! @author AutomataNexus Development Team
8
9use crate::state::AppState;
10use axum::{
11    body::Body,
12    extract::{ConnectInfo, State},
13    http::{HeaderValue, Request, Response, StatusCode},
14    middleware::Next,
15    response::IntoResponse,
16    Json,
17};
18use parking_lot::RwLock;
19use std::collections::HashMap;
20use std::net::SocketAddr;
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use uuid::Uuid;
24
25// =============================================================================
26// Rate Limiter
27// =============================================================================
28
/// Per-client token bucket state tracked by the rate limiter.
#[derive(Clone, Debug)]
struct RateLimitEntry {
    /// Tokens currently available to the client; fractional amounts are kept
    /// so partial refills are not lost between checks.
    tokens: f64,
    /// Moment at which `tokens` was last brought up to date.
    last_update: Instant,
}
35
36/// Shared rate limiter state.
37#[derive(Debug, Clone)]
38pub struct RateLimiter {
39    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
40    max_requests: u32,
41    window_secs: u64,
42}
43
44impl RateLimiter {
45    /// Create a new rate limiter with the specified requests per minute.
46    pub fn new(requests_per_minute: u32) -> Self {
47        Self {
48            entries: Arc::new(RwLock::new(HashMap::new())),
49            max_requests: requests_per_minute,
50            window_secs: 60,
51        }
52    }
53
54    /// Check if a request from the given key should be allowed.
55    /// Returns true if allowed, false if rate limited.
56    pub fn check(&self, key: &str) -> bool {
57        let mut entries = self.entries.write();
58        let now = Instant::now();
59
60        let entry = entries
61            .entry(key.to_string())
62            .or_insert_with(|| RateLimitEntry {
63                tokens: self.max_requests as f64,
64                last_update: now,
65            });
66
67        // Refill tokens based on elapsed time (token bucket algorithm)
68        let elapsed = now.duration_since(entry.last_update);
69        let refill_rate = self.max_requests as f64 / self.window_secs as f64;
70        let refill = elapsed.as_secs_f64() * refill_rate;
71        entry.tokens = (entry.tokens + refill).min(self.max_requests as f64);
72        entry.last_update = now;
73
74        // Check if we have tokens available
75        if entry.tokens >= 1.0 {
76            entry.tokens -= 1.0;
77            true
78        } else {
79            false
80        }
81    }
82
83    /// Clean up old entries to prevent memory growth.
84    pub fn cleanup(&self) {
85        let mut entries = self.entries.write();
86        let now = Instant::now();
87        let max_age = Duration::from_secs(self.window_secs * 2);
88
89        entries.retain(|_, entry| now.duration_since(entry.last_update) < max_age);
90    }
91}
92
93impl Default for RateLimiter {
94    fn default() -> Self {
95        Self::new(100) // Default: 100 requests per minute
96    }
97}
98
99// =============================================================================
100// Request ID Middleware
101// =============================================================================
102
103/// Add a unique request ID to each request.
104pub async fn request_id(mut request: Request<Body>, next: Next) -> Response<Body> {
105    let request_id = Uuid::new_v4().to_string();
106
107    request.headers_mut().insert(
108        "x-request-id",
109        HeaderValue::from_str(&request_id).unwrap_or_else(|_| HeaderValue::from_static("unknown")),
110    );
111
112    let mut response = next.run(request).await;
113
114    response.headers_mut().insert(
115        "x-request-id",
116        HeaderValue::from_str(&request_id).unwrap_or_else(|_| HeaderValue::from_static("unknown")),
117    );
118
119    response
120}
121
122// =============================================================================
123// Shield Middleware
124// =============================================================================
125
126/// Security shield check — runs before all other middleware.
127/// Analyzes requests for threats and blocks malicious traffic.
128pub async fn shield_check(
129    State(state): State<AppState>,
130    request: Request<Body>,
131    next: Next,
132) -> Result<Response<Body>, impl IntoResponse> {
133    let source_ip = request
134        .headers()
135        .get("x-forwarded-for")
136        .and_then(|h| h.to_str().ok())
137        .unwrap_or("127.0.0.1")
138        .split(',')
139        .next()
140        .unwrap_or("127.0.0.1")
141        .trim()
142        .to_string();
143
144    let ctx = aegis_shield::RequestContext {
145        source_ip: source_ip.clone(),
146        path: request.uri().path().to_string(),
147        method: request.method().to_string(),
148        user_agent: request
149            .headers()
150            .get("user-agent")
151            .and_then(|h| h.to_str().ok())
152            .map(|s| s.to_string()),
153        auth_user: None,
154        body_size: 0,
155        headers: std::collections::HashMap::new(),
156    };
157
158    match state.shield.analyze_request(&ctx) {
159        aegis_shield::ShieldVerdict::Allow => Ok(next.run(request).await),
160        aegis_shield::ShieldVerdict::Block {
161            reason,
162            threat_level,
163        } => {
164            tracing::warn!(
165                ip = %source_ip,
166                level = ?threat_level,
167                "Shield blocked request: {}",
168                reason
169            );
170            Err((
171                StatusCode::FORBIDDEN,
172                Json(serde_json::json!({
173                    "error": "Request blocked by security shield",
174                    "reason": reason,
175                })),
176            ))
177        }
178        aegis_shield::ShieldVerdict::RateLimit { delay_ms } => {
179            // For rate-limited requests, add a delay header but allow through
180            let mut response = next.run(request).await;
181            if let Ok(val) = HeaderValue::from_str(&delay_ms.to_string()) {
182                response.headers_mut().insert("x-ratelimit-delay-ms", val);
183            }
184            Ok(response)
185        }
186    }
187}
188
189// =============================================================================
190// Authentication Middleware
191// =============================================================================
192
193/// Require authentication for protected routes.
194/// Returns 401 Unauthorized if no valid session token is provided.
195pub async fn require_auth(
196    State(state): State<AppState>,
197    request: Request<Body>,
198    next: Next,
199) -> Result<Response<Body>, impl IntoResponse> {
200    // If no users exist, auth cannot be enforced — allow open access for bootstrap.
201    // Log a warning on every request so operators notice the insecure state.
202    if state.auth.list_users().is_empty() {
203        tracing::warn!(
204            path = %request.uri().path(),
205            "SECURITY: No admin user configured — all endpoints are unauthenticated. \
206             Create an admin user via POST /api/v1/auth/login or set \
207             AEGIS_ADMIN_USERNAME/AEGIS_ADMIN_PASSWORD to secure the server."
208        );
209        return Ok(next.run(request).await);
210    }
211
212    // Extract token from Authorization header
213    let auth_header = request
214        .headers()
215        .get("authorization")
216        .and_then(|h| h.to_str().ok());
217
218    let token = match auth_header {
219        Some(header) if header.starts_with("Bearer ") => &header[7..],
220        _ => {
221            return Err((
222                StatusCode::UNAUTHORIZED,
223                Json(serde_json::json!({
224                    "error": "Missing or invalid Authorization header",
225                    "message": "Provide a valid Bearer token in the Authorization header"
226                })),
227            ));
228        }
229    };
230
231    // Validate the session token
232    match state.auth.validate_session(token) {
233        Some(_user) => {
234            // Token is valid, proceed with the request
235            Ok(next.run(request).await)
236        }
237        None => Err((
238            StatusCode::UNAUTHORIZED,
239            Json(serde_json::json!({
240                "error": "Invalid or expired session token",
241                "message": "Please log in again to obtain a new token"
242            })),
243        )),
244    }
245}
246
247// =============================================================================
248// Rate Limiting Middleware
249// =============================================================================
250
251/// Extract client IP from request, checking X-Forwarded-For header first.
252fn get_client_ip(request: &Request<Body>) -> String {
253    // Check X-Forwarded-For header (from reverse proxies)
254    if let Some(forwarded) = request
255        .headers()
256        .get("x-forwarded-for")
257        .and_then(|h| h.to_str().ok())
258    {
259        // Take the first IP in the chain (original client)
260        if let Some(first_ip) = forwarded.split(',').next() {
261            return first_ip.trim().to_string();
262        }
263    }
264
265    // Check X-Real-IP header
266    if let Some(real_ip) = request
267        .headers()
268        .get("x-real-ip")
269        .and_then(|h| h.to_str().ok())
270    {
271        return real_ip.to_string();
272    }
273
274    // Fall back to socket address from extensions (if available via ConnectInfo)
275    if let Some(connect_info) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
276        return connect_info.0.ip().to_string();
277    }
278
279    // Ultimate fallback
280    "unknown".to_string()
281}
282
283/// Rate limiting middleware for general API requests.
284/// Returns 429 Too Many Requests if the client exceeds the rate limit.
285pub async fn rate_limit(
286    State(state): State<AppState>,
287    request: Request<Body>,
288    next: Next,
289) -> Result<Response<Body>, impl IntoResponse> {
290    let client_ip = get_client_ip(&request);
291    let rate_limit = state.config.rate_limit_per_minute;
292
293    // Skip rate limiting if disabled (rate_limit = 0)
294    if rate_limit == 0 {
295        return Ok(next.run(request).await);
296    }
297
298    // Use the rate limiter from AppState
299    if state.rate_limiter.check(&client_ip) {
300        Ok(next.run(request).await)
301    } else {
302        Err((
303            StatusCode::TOO_MANY_REQUESTS,
304            Json(serde_json::json!({
305                "error": "Rate limit exceeded",
306                "message": format!("Too many requests. Please try again later. Limit: {} requests per minute.", rate_limit),
307                "retry_after_seconds": 60
308            })),
309        ))
310    }
311}
312
313/// Rate limiting middleware specifically for login attempts.
314/// Uses a stricter limit to prevent brute force attacks.
315pub async fn login_rate_limit(
316    State(state): State<AppState>,
317    request: Request<Body>,
318    next: Next,
319) -> Result<Response<Body>, impl IntoResponse> {
320    let client_ip = get_client_ip(&request);
321    let rate_limit = state.config.login_rate_limit_per_minute;
322
323    // Skip rate limiting if disabled (rate_limit = 0)
324    if rate_limit == 0 {
325        return Ok(next.run(request).await);
326    }
327
328    // Use the login rate limiter from AppState
329    if state
330        .login_rate_limiter
331        .check(&format!("login:{}", client_ip))
332    {
333        Ok(next.run(request).await)
334    } else {
335        Err((
336            StatusCode::TOO_MANY_REQUESTS,
337            Json(serde_json::json!({
338                "error": "Too many login attempts",
339                "message": format!("Too many login attempts. Please try again later. Limit: {} attempts per minute.", rate_limit),
340                "retry_after_seconds": 60
341            })),
342        ))
343    }
344}
345
346// =============================================================================
347// Security Headers Middleware
348// =============================================================================
349
350/// Add HTTP security headers to all responses.
351/// Includes Content-Security-Policy, X-Content-Type-Options, X-Frame-Options,
352/// X-XSS-Protection, Referrer-Policy, and optionally Strict-Transport-Security
353/// when TLS is enabled.
354pub async fn security_headers(
355    State(state): State<AppState>,
356    request: Request<Body>,
357    next: Next,
358) -> Response<Body> {
359    let mut response = next.run(request).await;
360    let headers = response.headers_mut();
361
362    // Content-Security-Policy: Restrict resource loading to same origin
363    headers.insert(
364        "content-security-policy",
365        HeaderValue::from_static("default-src 'self'"),
366    );
367
368    // X-Content-Type-Options: Prevent MIME type sniffing
369    headers.insert(
370        "x-content-type-options",
371        HeaderValue::from_static("nosniff"),
372    );
373
374    // X-Frame-Options: Prevent clickjacking by disabling framing
375    headers.insert("x-frame-options", HeaderValue::from_static("DENY"));
376
377    // X-XSS-Protection: Enable browser XSS filtering
378    headers.insert(
379        "x-xss-protection",
380        HeaderValue::from_static("1; mode=block"),
381    );
382
383    // Referrer-Policy: Control referrer information sent with requests
384    headers.insert(
385        "referrer-policy",
386        HeaderValue::from_static("strict-origin-when-cross-origin"),
387    );
388
389    // Strict-Transport-Security: Only add when TLS is enabled
390    if state.config.tls.is_some() {
391        headers.insert(
392            "strict-transport-security",
393            HeaderValue::from_static("max-age=31536000; includeSubDomains"),
394        );
395    }
396
397    response
398}
399
400// =============================================================================
401// Tests
402// =============================================================================
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ServerConfig;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::{routing::get, Router};
    use tower::util::ServiceExt;

    /// Minimal handler sitting behind the middleware under test.
    async fn handler() -> &'static str {
        "ok"
    }

    /// Builds a bodiless request for `/`, optionally with an `Authorization` header.
    fn root_request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        builder
            .body(Body::empty())
            .expect("failed to build request")
    }

    /// Runs a single request through `app` and returns the response.
    async fn send(app: Router, request: Request<Body>) -> axum::http::Response<Body> {
        app.oneshot(request)
            .await
            .expect("failed to execute request")
    }

    /// Router whose only route is protected by `require_auth`.
    fn auth_guarded_app(state: AppState) -> Router {
        Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                require_auth,
            ))
            .with_state(state)
    }

    /// Router whose only route is wrapped by `security_headers`.
    fn security_headers_app(state: AppState) -> Router {
        Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                security_headers,
            ))
            .with_state(state)
    }

    /// Reads a response header as text; panics if the value is not visible ASCII.
    fn header_text<'a>(response: &'a axum::http::Response<Body>, name: &str) -> Option<&'a str> {
        response.headers().get(name).map(|v| v.to_str().unwrap())
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(request_id));

        let response = send(app, root_request(None)).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("x-request-id"));
    }

    #[tokio::test]
    async fn test_auth_middleware_no_token() {
        let state = AppState::new(ServerConfig::default());
        // A user must exist, otherwise the middleware runs in open bootstrap mode.
        let _ = state
            .auth
            .create_user("testuser", "test@test.local", "TestPass123!", "admin");

        let response = send(auth_guarded_app(state), root_request(None)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_token() {
        let state = AppState::new(ServerConfig::default());
        // A user must exist, otherwise the middleware runs in open bootstrap mode.
        let _ = state
            .auth
            .create_user("testuser", "test@test.local", "TestPass123!", "admin");

        let response = send(
            auth_guarded_app(state),
            root_request(Some("Bearer invalid_token")),
        )
        .await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_token() {
        let state = AppState::new(ServerConfig::default());

        // Register a user and log in to obtain a genuine session token.
        state
            .auth
            .create_user("authtest", "auth@test.com", "TestPassword123!", "admin")
            .expect("failed to create test user");
        let login_response = state.auth.login("authtest", "TestPassword123!");
        let token = login_response.token.expect("login should return token");
        let authorization = format!("Bearer {}", token);

        let response = send(
            auth_guarded_app(state),
            root_request(Some(&authorization)),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_rate_limiter_allows_requests() {
        // Capacity of 10 requests per minute.
        let limiter = RateLimiter::new(10);

        // The bucket starts full, so the first 10 requests pass...
        assert!((0..10).all(|_| limiter.check("test_client")));

        // ...and the 11th is rejected.
        assert!(!limiter.check("test_client"));
    }

    #[test]
    fn test_rate_limiter_different_clients() {
        let limiter = RateLimiter::new(5);

        // Interleaved traffic: each key draws from its own bucket.
        for _ in 0..5 {
            for client in ["client_a", "client_b"] {
                assert!(limiter.check(client));
            }
        }

        // Both buckets are now exhausted.
        for client in ["client_a", "client_b"] {
            assert!(!limiter.check(client));
        }
    }

    #[test]
    fn test_rate_limiter_cleanup() {
        let limiter = RateLimiter::new(10);

        // Populate two buckets.
        limiter.check("client_1");
        limiter.check("client_2");

        // Cleanup must not panic on fresh entries.
        limiter.cleanup();

        // The limiter keeps working afterwards.
        assert!(limiter.check("client_1"));
    }

    #[tokio::test]
    async fn test_security_headers_without_tls() {
        let state = AppState::new(ServerConfig::default());

        let response = send(security_headers_app(state), root_request(None)).await;

        assert_eq!(response.status(), StatusCode::OK);

        // Every unconditional security header carries its expected value.
        let expected_headers = [
            ("content-security-policy", "default-src 'self'"),
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("x-xss-protection", "1; mode=block"),
            ("referrer-policy", "strict-origin-when-cross-origin"),
        ];
        for (name, expected) in expected_headers {
            assert_eq!(header_text(&response, name), Some(expected));
        }

        // HSTS must NOT be present without TLS.
        assert!(response
            .headers()
            .get("strict-transport-security")
            .is_none());
    }

    #[tokio::test]
    async fn test_security_headers_with_tls() {
        let config = ServerConfig::default().with_tls("/path/to/cert.pem", "/path/to/key.pem");
        let state = AppState::new(config);

        let response = send(security_headers_app(state), root_request(None)).await;

        assert_eq!(response.status(), StatusCode::OK);

        // HSTS must be present once TLS is configured.
        assert_eq!(
            header_text(&response, "strict-transport-security"),
            Some("max-age=31536000; includeSubDomains")
        );
    }
}