1use crate::state::AppState;
10use axum::{
11 body::Body,
12 extract::{ConnectInfo, State},
13 http::{HeaderValue, Request, Response, StatusCode},
14 middleware::Next,
15 response::IntoResponse,
16 Json,
17};
18use parking_lot::RwLock;
19use std::collections::HashMap;
20use std::net::SocketAddr;
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use uuid::Uuid;
24
/// Token-bucket state for a single key tracked by [`RateLimiter`].
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Tokens currently available. Fractional because refill is proportional to
    /// elapsed time rather than happening in whole steps.
    tokens: f64,
    /// Time of the last refill, which `RateLimiter::check` updates on every call
    /// for this key.
    last_update: Instant,
}
35
/// In-memory token-bucket rate limiter keyed by an arbitrary string (the
/// middlewares in this module key it by client IP).
///
/// Each key's bucket holds up to `max_requests` tokens. It refills continuously at
/// `max_requests / window_secs` tokens per second.
///
/// Clones share the same buckets because the map lives behind an `Arc`.
///
/// Buckets are only removed by [`RateLimiter::cleanup`], which this type never
/// calls on its own. NOTE(review): confirm a periodic task invokes it, otherwise
/// the map grows with every distinct key seen.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Bucket state per key, shared between clones.
    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
    /// Bucket capacity. This is also the number of tokens regained over one window.
    max_requests: u32,
    /// Length of the refill window in seconds (`RateLimiter::new` sets 60).
    window_secs: u64,
}
43
44impl RateLimiter {
45 pub fn new(requests_per_minute: u32) -> Self {
47 Self {
48 entries: Arc::new(RwLock::new(HashMap::new())),
49 max_requests: requests_per_minute,
50 window_secs: 60,
51 }
52 }
53
54 pub fn check(&self, key: &str) -> bool {
57 let mut entries = self.entries.write();
58 let now = Instant::now();
59
60 let entry = entries
61 .entry(key.to_string())
62 .or_insert_with(|| RateLimitEntry {
63 tokens: self.max_requests as f64,
64 last_update: now,
65 });
66
67 let elapsed = now.duration_since(entry.last_update);
69 let refill_rate = self.max_requests as f64 / self.window_secs as f64;
70 let refill = elapsed.as_secs_f64() * refill_rate;
71 entry.tokens = (entry.tokens + refill).min(self.max_requests as f64);
72 entry.last_update = now;
73
74 if entry.tokens >= 1.0 {
76 entry.tokens -= 1.0;
77 true
78 } else {
79 false
80 }
81 }
82
83 pub fn cleanup(&self) {
85 let mut entries = self.entries.write();
86 let now = Instant::now();
87 let max_age = Duration::from_secs(self.window_secs * 2);
88
89 entries.retain(|_, entry| now.duration_since(entry.last_update) < max_age);
90 }
91}
92
93impl Default for RateLimiter {
94 fn default() -> Self {
95 Self::new(100) }
97}
98
99pub async fn request_id(mut request: Request<Body>, next: Next) -> Response<Body> {
105 let request_id = Uuid::new_v4().to_string();
106
107 request.headers_mut().insert(
108 "x-request-id",
109 HeaderValue::from_str(&request_id).unwrap_or_else(|_| HeaderValue::from_static("unknown")),
110 );
111
112 let mut response = next.run(request).await;
113
114 response.headers_mut().insert(
115 "x-request-id",
116 HeaderValue::from_str(&request_id).unwrap_or_else(|_| HeaderValue::from_static("unknown")),
117 );
118
119 response
120}
121
122pub async fn shield_check(
129 State(state): State<AppState>,
130 request: Request<Body>,
131 next: Next,
132) -> Result<Response<Body>, impl IntoResponse> {
133 let source_ip = request
134 .headers()
135 .get("x-forwarded-for")
136 .and_then(|h| h.to_str().ok())
137 .unwrap_or("127.0.0.1")
138 .split(',')
139 .next()
140 .unwrap_or("127.0.0.1")
141 .trim()
142 .to_string();
143
144 let ctx = aegis_shield::RequestContext {
145 source_ip: source_ip.clone(),
146 path: request.uri().path().to_string(),
147 method: request.method().to_string(),
148 user_agent: request
149 .headers()
150 .get("user-agent")
151 .and_then(|h| h.to_str().ok())
152 .map(|s| s.to_string()),
153 auth_user: None,
154 body_size: 0,
155 headers: std::collections::HashMap::new(),
156 };
157
158 match state.shield.analyze_request(&ctx) {
159 aegis_shield::ShieldVerdict::Allow => Ok(next.run(request).await),
160 aegis_shield::ShieldVerdict::Block {
161 reason,
162 threat_level,
163 } => {
164 tracing::warn!(
165 ip = %source_ip,
166 level = ?threat_level,
167 "Shield blocked request: {}",
168 reason
169 );
170 Err((
171 StatusCode::FORBIDDEN,
172 Json(serde_json::json!({
173 "error": "Request blocked by security shield",
174 "reason": reason,
175 })),
176 ))
177 }
178 aegis_shield::ShieldVerdict::RateLimit { delay_ms } => {
179 let mut response = next.run(request).await;
181 if let Ok(val) = HeaderValue::from_str(&delay_ms.to_string()) {
182 response.headers_mut().insert("x-ratelimit-delay-ms", val);
183 }
184 Ok(response)
185 }
186 }
187}
188
189pub async fn require_auth(
196 State(state): State<AppState>,
197 request: Request<Body>,
198 next: Next,
199) -> Result<Response<Body>, impl IntoResponse> {
200 if state.auth.list_users().is_empty() {
203 tracing::warn!(
204 path = %request.uri().path(),
205 "SECURITY: No admin user configured — all endpoints are unauthenticated. \
206 Create an admin user via POST /api/v1/auth/login or set \
207 AEGIS_ADMIN_USERNAME/AEGIS_ADMIN_PASSWORD to secure the server."
208 );
209 return Ok(next.run(request).await);
210 }
211
212 let auth_header = request
214 .headers()
215 .get("authorization")
216 .and_then(|h| h.to_str().ok());
217
218 let token = match auth_header {
219 Some(header) if header.starts_with("Bearer ") => &header[7..],
220 _ => {
221 return Err((
222 StatusCode::UNAUTHORIZED,
223 Json(serde_json::json!({
224 "error": "Missing or invalid Authorization header",
225 "message": "Provide a valid Bearer token in the Authorization header"
226 })),
227 ));
228 }
229 };
230
231 match state.auth.validate_session(token) {
233 Some(_user) => {
234 Ok(next.run(request).await)
236 }
237 None => Err((
238 StatusCode::UNAUTHORIZED,
239 Json(serde_json::json!({
240 "error": "Invalid or expired session token",
241 "message": "Please log in again to obtain a new token"
242 })),
243 )),
244 }
245}
246
247fn get_client_ip(request: &Request<Body>) -> String {
253 if let Some(forwarded) = request
255 .headers()
256 .get("x-forwarded-for")
257 .and_then(|h| h.to_str().ok())
258 {
259 if let Some(first_ip) = forwarded.split(',').next() {
261 return first_ip.trim().to_string();
262 }
263 }
264
265 if let Some(real_ip) = request
267 .headers()
268 .get("x-real-ip")
269 .and_then(|h| h.to_str().ok())
270 {
271 return real_ip.to_string();
272 }
273
274 if let Some(connect_info) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
276 return connect_info.0.ip().to_string();
277 }
278
279 "unknown".to_string()
281}
282
283pub async fn rate_limit(
286 State(state): State<AppState>,
287 request: Request<Body>,
288 next: Next,
289) -> Result<Response<Body>, impl IntoResponse> {
290 let client_ip = get_client_ip(&request);
291 let rate_limit = state.config.rate_limit_per_minute;
292
293 if rate_limit == 0 {
295 return Ok(next.run(request).await);
296 }
297
298 if state.rate_limiter.check(&client_ip) {
300 Ok(next.run(request).await)
301 } else {
302 Err((
303 StatusCode::TOO_MANY_REQUESTS,
304 Json(serde_json::json!({
305 "error": "Rate limit exceeded",
306 "message": format!("Too many requests. Please try again later. Limit: {} requests per minute.", rate_limit),
307 "retry_after_seconds": 60
308 })),
309 ))
310 }
311}
312
313pub async fn login_rate_limit(
316 State(state): State<AppState>,
317 request: Request<Body>,
318 next: Next,
319) -> Result<Response<Body>, impl IntoResponse> {
320 let client_ip = get_client_ip(&request);
321 let rate_limit = state.config.login_rate_limit_per_minute;
322
323 if rate_limit == 0 {
325 return Ok(next.run(request).await);
326 }
327
328 if state
330 .login_rate_limiter
331 .check(&format!("login:{}", client_ip))
332 {
333 Ok(next.run(request).await)
334 } else {
335 Err((
336 StatusCode::TOO_MANY_REQUESTS,
337 Json(serde_json::json!({
338 "error": "Too many login attempts",
339 "message": format!("Too many login attempts. Please try again later. Limit: {} attempts per minute.", rate_limit),
340 "retry_after_seconds": 60
341 })),
342 ))
343 }
344}
345
346pub async fn security_headers(
355 State(state): State<AppState>,
356 request: Request<Body>,
357 next: Next,
358) -> Response<Body> {
359 let mut response = next.run(request).await;
360 let headers = response.headers_mut();
361
362 headers.insert(
364 "content-security-policy",
365 HeaderValue::from_static("default-src 'self'"),
366 );
367
368 headers.insert(
370 "x-content-type-options",
371 HeaderValue::from_static("nosniff"),
372 );
373
374 headers.insert("x-frame-options", HeaderValue::from_static("DENY"));
376
377 headers.insert(
379 "x-xss-protection",
380 HeaderValue::from_static("1; mode=block"),
381 );
382
383 headers.insert(
385 "referrer-policy",
386 HeaderValue::from_static("strict-origin-when-cross-origin"),
387 );
388
389 if state.config.tls.is_some() {
391 headers.insert(
392 "strict-transport-security",
393 HeaderValue::from_static("max-age=31536000; includeSubDomains"),
394 );
395 }
396
397 response
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ServerConfig;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::{routing::get, Router};
    use tower::util::ServiceExt;

    async fn handler() -> &'static str {
        "ok"
    }

    /// Pretends `key`'s bucket was last touched `age` ago, so refill and eviction
    /// can be tested without sleeping.
    ///
    /// Returns `false`, changing nothing, when the monotonic clock cannot
    /// represent an instant that far back.
    fn backdate(limiter: &RateLimiter, key: &str, age: Duration) -> bool {
        match Instant::now().checked_sub(age) {
            Some(past) => {
                limiter
                    .entries
                    .write()
                    .get_mut(key)
                    .expect("bucket should exist for key")
                    .last_update = past;
                true
            }
            None => false,
        }
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(request_id));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .body(Body::empty())
                    .expect("failed to build request"),
            )
            .await
            .expect("failed to execute request");

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("x-request-id"));
    }

    #[tokio::test]
    async fn test_auth_middleware_no_token() {
        let state = AppState::new(ServerConfig::default());
        let _ = state
            .auth
            .create_user("testuser", "test@test.local", "TestPass123!", "admin");

        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                require_auth,
            ))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .body(Body::empty())
                    .expect("failed to build request"),
            )
            .await
            .expect("failed to execute request");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_token() {
        let state = AppState::new(ServerConfig::default());
        let _ = state
            .auth
            .create_user("testuser", "test@test.local", "TestPass123!", "admin");

        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                require_auth,
            ))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header("Authorization", "Bearer invalid_token")
                    .body(Body::empty())
                    .expect("failed to build request"),
            )
            .await
            .expect("failed to execute request");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_token() {
        let state = AppState::new(ServerConfig::default());

        state
            .auth
            .create_user("authtest", "auth@test.com", "TestPassword123!", "admin")
            .expect("failed to create test user");
        let login_response = state.auth.login("authtest", "TestPassword123!");
        let token = login_response.token.expect("login should return token");

        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                require_auth,
            ))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header("Authorization", format!("Bearer {}", token))
                    .body(Body::empty())
                    .expect("failed to build request"),
            )
            .await
            .expect("failed to execute request");

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_rate_limiter_allows_requests() {
        let limiter = RateLimiter::new(10);

        for _ in 0..10 {
            assert!(limiter.check("test_client"));
        }

        assert!(!limiter.check("test_client"));
    }

    #[test]
    fn test_rate_limiter_different_clients() {
        let limiter = RateLimiter::new(5);

        for _ in 0..5 {
            assert!(limiter.check("client_a"));
            assert!(limiter.check("client_b"));
        }

        assert!(!limiter.check("client_a"));
        assert!(!limiter.check("client_b"));
    }

    #[test]
    fn test_rate_limiter_refills_over_time() {
        // 60 requests per 60 s window refills one token per second.
        let limiter = RateLimiter::new(60);

        for _ in 0..60 {
            assert!(limiter.check("refill_client"));
        }
        assert!(!limiter.check("refill_client"));

        if !backdate(&limiter, "refill_client", Duration::from_secs(2)) {
            return;
        }
        assert!(limiter.check("refill_client"));
    }

    #[test]
    fn test_rate_limiter_cleanup() {
        let limiter = RateLimiter::new(10);

        limiter.check("client_1");
        limiter.check("client_2");

        // Idle for longer than two windows: client_1 must be evicted, while the
        // freshly used client_2 must survive.
        let stale_age = Duration::from_secs(limiter.window_secs * 3);
        if !backdate(&limiter, "client_1", stale_age) {
            return;
        }
        limiter.cleanup();

        {
            let entries = limiter.entries.read();
            assert!(!entries.contains_key("client_1"));
            assert!(entries.contains_key("client_2"));
        }

        // An evicted client starts over with a full bucket.
        assert!(limiter.check("client_1"));
    }

    #[test]
    fn test_get_client_ip_resolution_order() {
        let forwarded = Request::builder()
            .header("x-forwarded-for", "203.0.113.5, 10.0.0.1")
            .header("x-real-ip", "198.51.100.7")
            .body(Body::empty())
            .expect("failed to build request");
        assert_eq!(get_client_ip(&forwarded), "203.0.113.5");

        let real_ip = Request::builder()
            .header("x-real-ip", "198.51.100.7")
            .body(Body::empty())
            .expect("failed to build request");
        assert_eq!(get_client_ip(&real_ip), "198.51.100.7");

        let mut direct = Request::builder()
            .body(Body::empty())
            .expect("failed to build request");
        direct
            .extensions_mut()
            .insert(ConnectInfo(SocketAddr::from(([192, 0, 2, 1], 8080))));
        assert_eq!(get_client_ip(&direct), "192.0.2.1");

        let anonymous = Request::builder()
            .body(Body::empty())
            .expect("failed to build request");
        assert_eq!(get_client_ip(&anonymous), "unknown");
    }

    #[tokio::test]
    async fn test_security_headers_without_tls() {
        let state = AppState::new(ServerConfig::default());

        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                security_headers,
            ))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .body(Body::empty())
                    .expect("failed to build request"),
            )
            .await
            .expect("failed to execute request");

        assert_eq!(response.status(), StatusCode::OK);

        assert_eq!(
            response
                .headers()
                .get("content-security-policy")
                .map(|v| v.to_str().unwrap()),
            Some("default-src 'self'")
        );
        assert_eq!(
            response
                .headers()
                .get("x-content-type-options")
                .map(|v| v.to_str().unwrap()),
            Some("nosniff")
        );
        assert_eq!(
            response
                .headers()
                .get("x-frame-options")
                .map(|v| v.to_str().unwrap()),
            Some("DENY")
        );
        assert_eq!(
            response
                .headers()
                .get("x-xss-protection")
                .map(|v| v.to_str().unwrap()),
            Some("1; mode=block")
        );
        assert_eq!(
            response
                .headers()
                .get("referrer-policy")
                .map(|v| v.to_str().unwrap()),
            Some("strict-origin-when-cross-origin")
        );

        assert!(response
            .headers()
            .get("strict-transport-security")
            .is_none());
    }

    #[tokio::test]
    async fn test_security_headers_with_tls() {
        let config = ServerConfig::default().with_tls("/path/to/cert.pem", "/path/to/key.pem");
        let state = AppState::new(config);

        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                security_headers,
            ))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .body(Body::empty())
                    .expect("failed to build request"),
            )
            .await
            .expect("failed to execute request");

        assert_eq!(response.status(), StatusCode::OK);

        assert_eq!(
            response
                .headers()
                .get("strict-transport-security")
                .map(|v| v.to_str().unwrap()),
            Some("max-age=31536000; includeSubDomains")
        );
    }
}