use crate::auth::{AuthManager, Claims, Permission};
use crate::error::AllSourceError;
use crate::rate_limit::RateLimiter;
use axum::{
    extract::{Request, State},
    http::{HeaderMap, HeaderValue, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use std::sync::Arc;
11
12#[derive(Clone)]
14pub struct AuthState {
15 pub auth_manager: Arc<AuthManager>,
16}
17
18#[derive(Clone)]
20pub struct RateLimitState {
21 pub rate_limiter: Arc<RateLimiter>,
22}
23
24#[derive(Debug, Clone)]
26pub struct AuthContext {
27 pub claims: Claims,
28}
29
30impl AuthContext {
31 pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
33 if self.claims.has_permission(permission) {
34 Ok(())
35 } else {
36 Err(AllSourceError::ValidationError(
37 "Insufficient permissions".to_string(),
38 ))
39 }
40 }
41
42 pub fn tenant_id(&self) -> &str {
44 &self.claims.tenant_id
45 }
46
47 pub fn user_id(&self) -> &str {
49 &self.claims.sub
50 }
51}
52
53fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
55 let auth_header = headers
56 .get("authorization")
57 .ok_or_else(|| AllSourceError::ValidationError("Missing authorization header".to_string()))?
58 .to_str()
59 .map_err(|_| AllSourceError::ValidationError("Invalid authorization header".to_string()))?;
60
61 let token = if auth_header.starts_with("Bearer ") {
63 auth_header.trim_start_matches("Bearer ").trim()
64 } else if auth_header.starts_with("bearer ") {
65 auth_header.trim_start_matches("bearer ").trim()
66 } else {
67 auth_header.trim()
68 };
69
70 if token.is_empty() {
71 return Err(AllSourceError::ValidationError(
72 "Empty authorization token".to_string(),
73 ));
74 }
75
76 Ok(token.to_string())
77}
78
79pub async fn auth_middleware(
81 State(auth_state): State<AuthState>,
82 mut request: Request,
83 next: Next,
84) -> Result<Response, AuthError> {
85 let headers = request.headers();
86
87 let token = extract_token(headers)?;
89
90 let claims = if token.starts_with("ask_") {
91 auth_state.auth_manager.validate_api_key(&token)?
93 } else {
94 auth_state.auth_manager.validate_token(&token)?
96 };
97
98 request.extensions_mut().insert(AuthContext { claims });
100
101 Ok(next.run(request).await)
102}
103
104pub async fn optional_auth_middleware(
106 State(auth_state): State<AuthState>,
107 mut request: Request,
108 next: Next,
109) -> Response {
110 let headers = request.headers();
111
112 if let Ok(token) = extract_token(headers) {
113 let claims = if token.starts_with("ask_") {
115 auth_state.auth_manager.validate_api_key(&token).ok()
116 } else {
117 auth_state.auth_manager.validate_token(&token).ok()
118 };
119
120 if let Some(claims) = claims {
121 request.extensions_mut().insert(AuthContext { claims });
122 }
123 }
124
125 next.run(request).await
126}
127
128#[derive(Debug)]
130pub struct AuthError(AllSourceError);
131
132impl From<AllSourceError> for AuthError {
133 fn from(err: AllSourceError) -> Self {
134 AuthError(err)
135 }
136}
137
138impl IntoResponse for AuthError {
139 fn into_response(self) -> Response {
140 let (status, message) = match self.0 {
141 AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
142 _ => (
143 StatusCode::INTERNAL_SERVER_ERROR,
144 "Internal server error".to_string(),
145 ),
146 };
147
148 (status, message).into_response()
149 }
150}
151
152pub struct Authenticated(pub AuthContext);
154
155#[axum::async_trait]
156impl<S> axum::extract::FromRequestParts<S> for Authenticated
157where
158 S: Send + Sync,
159{
160 type Rejection = (StatusCode, &'static str);
161
162 async fn from_request_parts(
163 parts: &mut axum::http::request::Parts,
164 _state: &S,
165 ) -> Result<Self, Self::Rejection> {
166 parts
167 .extensions
168 .get::<AuthContext>()
169 .cloned()
170 .map(Authenticated)
171 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
172 }
173}
174
175pub struct Admin(pub AuthContext);
177
178#[axum::async_trait]
179impl<S> axum::extract::FromRequestParts<S> for Admin
180where
181 S: Send + Sync,
182{
183 type Rejection = (StatusCode, &'static str);
184
185 async fn from_request_parts(
186 parts: &mut axum::http::request::Parts,
187 _state: &S,
188 ) -> Result<Self, Self::Rejection> {
189 let auth_ctx = parts
190 .extensions
191 .get::<AuthContext>()
192 .cloned()
193 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
194
195 auth_ctx
196 .require_permission(Permission::Admin)
197 .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
198
199 Ok(Admin(auth_ctx))
200 }
201}
202
203pub async fn rate_limit_middleware(
206 State(rate_limit_state): State<RateLimitState>,
207 request: Request,
208 next: Next,
209) -> Result<Response, RateLimitError> {
210 let auth_ctx = request
212 .extensions()
213 .get::<AuthContext>()
214 .ok_or_else(|| RateLimitError::Unauthorized)?;
215
216 let result = rate_limit_state
218 .rate_limiter
219 .check_rate_limit(auth_ctx.tenant_id());
220
221 if !result.allowed {
222 return Err(RateLimitError::RateLimitExceeded {
223 retry_after: result.retry_after.unwrap_or_default().as_secs(),
224 limit: result.limit,
225 });
226 }
227
228 let mut response = next.run(request).await;
230 let headers = response.headers_mut();
231 headers.insert("X-RateLimit-Limit", result.limit.to_string().parse().unwrap());
232 headers.insert("X-RateLimit-Remaining", result.remaining.to_string().parse().unwrap());
233
234 Ok(response)
235}
236
237#[derive(Debug)]
239pub enum RateLimitError {
240 RateLimitExceeded { retry_after: u64, limit: u32 },
241 Unauthorized,
242}
243
244impl IntoResponse for RateLimitError {
245 fn into_response(self) -> Response {
246 match self {
247 RateLimitError::RateLimitExceeded { retry_after, limit } => {
248 let mut response = (
249 StatusCode::TOO_MANY_REQUESTS,
250 format!("Rate limit exceeded. Limit: {} requests/min", limit),
251 )
252 .into_response();
253
254 if retry_after > 0 {
255 response.headers_mut().insert(
256 "Retry-After",
257 retry_after.to_string().parse().unwrap(),
258 );
259 }
260
261 response
262 }
263 RateLimitError::Unauthorized => (
264 StatusCode::UNAUTHORIZED,
265 "Authentication required for rate limiting",
266 )
267 .into_response(),
268 }
269 }
270}
271
/// Returns early with `403 Forbidden` unless the wrapped [`AuthContext`] grants a permission.
///
/// `$auth` must be a tuple-struct value whose field `.0` is an `AuthContext` (for example
/// the `Authenticated` or `Admin` extractors). The enclosing function must return a `Result`
/// whose error type can be produced by `?` from `(StatusCode, &'static str)`.
///
/// ```ignore
/// require_permission!(auth, Permission::Write);
/// ```
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {{
        let check = $auth.0.require_permission($perm);
        check.map_err(|_| (axum::http::StatusCode::FORBIDDEN, "Insufficient permissions"))?
    }};
}
282
283use crate::domain::entities::Tenant;
288use crate::domain::repositories::TenantRepository;
289use crate::domain::value_objects::TenantId;
290
291#[derive(Clone)]
293pub struct TenantState<R: TenantRepository> {
294 pub tenant_repository: Arc<R>,
295}
296
297#[derive(Debug, Clone)]
302pub struct TenantContext {
303 pub tenant: Tenant,
304}
305
306impl TenantContext {
307 pub fn tenant_id(&self) -> &TenantId {
309 self.tenant.id()
310 }
311
312 pub fn is_active(&self) -> bool {
314 self.tenant.is_active()
315 }
316}
317
318pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
332 State(tenant_state): State<TenantState<R>>,
333 mut request: Request,
334 next: Next,
335) -> Result<Response, TenantError> {
336 let auth_ctx = request
338 .extensions()
339 .get::<AuthContext>()
340 .ok_or(TenantError::Unauthorized)?
341 .clone();
342
343 let tenant_id = TenantId::new(auth_ctx.tenant_id().to_string())
345 .map_err(|_| TenantError::InvalidTenant)?;
346
347 let tenant = tenant_state
349 .tenant_repository
350 .find_by_id(&tenant_id)
351 .await
352 .map_err(|e| TenantError::RepositoryError(e.to_string()))?
353 .ok_or(TenantError::TenantNotFound)?;
354
355 if !tenant.is_active() {
357 return Err(TenantError::TenantInactive);
358 }
359
360 request.extensions_mut().insert(TenantContext { tenant });
362
363 Ok(next.run(request).await)
365}
366
367#[derive(Debug)]
369pub enum TenantError {
370 Unauthorized,
371 InvalidTenant,
372 TenantNotFound,
373 TenantInactive,
374 RepositoryError(String),
375}
376
377impl IntoResponse for TenantError {
378 fn into_response(self) -> Response {
379 let (status, message) = match self {
380 TenantError::Unauthorized => (
381 StatusCode::UNAUTHORIZED,
382 "Authentication required for tenant access",
383 ),
384 TenantError::InvalidTenant => (
385 StatusCode::BAD_REQUEST,
386 "Invalid tenant identifier",
387 ),
388 TenantError::TenantNotFound => (
389 StatusCode::NOT_FOUND,
390 "Tenant not found",
391 ),
392 TenantError::TenantInactive => (
393 StatusCode::FORBIDDEN,
394 "Tenant is inactive",
395 ),
396 TenantError::RepositoryError(_) => (
397 StatusCode::INTERNAL_SERVER_ERROR,
398 "Failed to validate tenant",
399 ),
400 };
401
402 (status, message).into_response()
403 }
404}
405
406use uuid::Uuid;
411
412#[derive(Debug, Clone)]
414pub struct RequestId(pub String);
415
416impl RequestId {
417 pub fn new() -> Self {
419 Self(Uuid::new_v4().to_string())
420 }
421
422 pub fn as_str(&self) -> &str {
424 &self.0
425 }
426}
427
428pub async fn request_id_middleware(
443 mut request: Request,
444 next: Next,
445) -> Response {
446 let request_id = request
448 .headers()
449 .get("x-request-id")
450 .and_then(|v| v.to_str().ok())
451 .map(|s| RequestId(s.to_string()))
452 .unwrap_or_else(RequestId::new);
453
454 request.extensions_mut().insert(request_id.clone());
456
457 let mut response = next.run(request).await;
459
460 response.headers_mut().insert(
462 "x-request-id",
463 request_id.0.parse().unwrap(),
464 );
465
466 response
467}
468
/// Settings for the security-headers middleware.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Emit `Strict-Transport-Security`.
    pub enable_hsts: bool,
    /// `max-age`, in seconds, written into the HSTS header.
    pub hsts_max_age: u32,
    /// Emit `X-Frame-Options`.
    pub enable_frame_options: bool,
    /// Framing policy used when `enable_frame_options` is set.
    pub frame_options: FrameOptions,
    /// Emit `X-Content-Type-Options: nosniff`.
    pub enable_content_type_options: bool,
    /// Emit `X-XSS-Protection: 1; mode=block`.
    pub enable_xss_protection: bool,
    /// `Content-Security-Policy` value; `None` omits the header.
    pub csp: Option<String>,
    /// Origins permitted for CORS (`"*"` allows any origin).
    pub cors_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    pub cors_methods: Vec<String>,
    /// Request headers advertised in `Access-Control-Allow-Headers`.
    pub cors_headers: Vec<String>,
    /// `Access-Control-Max-Age`, in seconds.
    pub cors_max_age: u32,
}

/// Framing policy rendered into the `X-Frame-Options` header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    /// Forbid all framing (`DENY`).
    Deny,
    /// Allow framing by same-origin pages only (`SAMEORIGIN`).
    SameOrigin,
    /// Allow framing by the given origin.
    AllowFrom(String),
}

impl Default for SecurityConfig {
    /// One-year HSTS, `DENY` framing, nosniff, the XSS filter, a `default-src 'self'` CSP,
    /// and wildcard CORS for GET/POST/PUT/DELETE with `Content-Type` and `Authorization`
    /// headers, cacheable for one hour.
    fn default() -> Self {
        /// Turns a slice of literals into owned strings.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }

        const ONE_YEAR_SECS: u32 = 365 * 24 * 60 * 60;
        const ONE_HOUR_SECS: u32 = 60 * 60;

        Self {
            enable_hsts: true,
            hsts_max_age: ONE_YEAR_SECS,
            enable_frame_options: true,
            frame_options: FrameOptions::Deny,
            enable_content_type_options: true,
            enable_xss_protection: true,
            csp: Some(String::from("default-src 'self'")),
            cors_origins: owned(&["*"]),
            cors_methods: owned(&["GET", "POST", "PUT", "DELETE"]),
            cors_headers: owned(&["Content-Type", "Authorization"]),
            cors_max_age: ONE_HOUR_SECS,
        }
    }
}

/// Shared state handed to the security-headers middleware.
#[derive(Clone)]
pub struct SecurityState {
    /// Header policy applied to every response.
    pub config: SecurityConfig,
}
529
530pub async fn security_headers_middleware(
548 State(security_state): State<SecurityState>,
549 request: Request,
550 next: Next,
551) -> Response {
552 let mut response = next.run(request).await;
553 let headers = response.headers_mut();
554 let config = &security_state.config;
555
556 if config.enable_hsts {
558 headers.insert(
559 "strict-transport-security",
560 format!("max-age={}", config.hsts_max_age).parse().unwrap(),
561 );
562 }
563
564 if config.enable_frame_options {
566 let value = match &config.frame_options {
567 FrameOptions::Deny => "DENY",
568 FrameOptions::SameOrigin => "SAMEORIGIN",
569 FrameOptions::AllowFrom(origin) => origin,
570 };
571 headers.insert("x-frame-options", value.parse().unwrap());
572 }
573
574 if config.enable_content_type_options {
576 headers.insert("x-content-type-options", "nosniff".parse().unwrap());
577 }
578
579 if config.enable_xss_protection {
581 headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
582 }
583
584 if let Some(csp) = &config.csp {
586 headers.insert("content-security-policy", csp.parse().unwrap());
587 }
588
589 headers.insert(
591 "access-control-allow-origin",
592 config.cors_origins.join(", ").parse().unwrap(),
593 );
594 headers.insert(
595 "access-control-allow-methods",
596 config.cors_methods.join(", ").parse().unwrap(),
597 );
598 headers.insert(
599 "access-control-allow-headers",
600 config.cors_headers.join(", ").parse().unwrap(),
601 );
602 headers.insert(
603 "access-control-max-age",
604 config.cors_max_age.to_string().parse().unwrap(),
605 );
606
607 response
608}
609
610use crate::infrastructure::security::IpFilter;
615use std::net::SocketAddr;
616
617#[derive(Clone)]
618pub struct IpFilterState {
619 pub ip_filter: Arc<IpFilter>,
620}
621
622pub async fn ip_filter_middleware(
634 State(ip_filter_state): State<IpFilterState>,
635 request: Request,
636 next: Next,
637) -> Result<Response, IpFilterError> {
638 let client_ip = request
640 .extensions()
641 .get::<axum::extract::ConnectInfo<SocketAddr>>()
642 .map(|connect_info| connect_info.0.ip())
643 .ok_or(IpFilterError::NoIpAddress)?;
644
645 let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
647 ip_filter_state
649 .ip_filter
650 .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
651 } else {
652 ip_filter_state.ip_filter.is_allowed(&client_ip)
654 };
655
656 if !result.allowed {
658 return Err(IpFilterError::Blocked { reason: result.reason });
659 }
660
661 Ok(next.run(request).await)
663}
664
665#[derive(Debug)]
667pub enum IpFilterError {
668 NoIpAddress,
669 Blocked { reason: String },
670}
671
672impl IntoResponse for IpFilterError {
673 fn into_response(self) -> Response {
674 match self {
675 IpFilterError::NoIpAddress => (
676 StatusCode::BAD_REQUEST,
677 "Unable to determine client IP address",
678 ).into_response(),
679 IpFilterError::Blocked { reason } => (
680 StatusCode::FORBIDDEN,
681 format!("Access denied: {}", reason),
682 ).into_response(),
683 }
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::Role;

    /// Builds a header map holding a single `authorization` header with `value`.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", value.parse().unwrap());
        headers
    }

    #[test]
    fn test_extract_bearer_token() {
        let token = extract_token(&auth_headers("Bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_lowercase_bearer() {
        let token = extract_token(&auth_headers("bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_plain_token() {
        let token = extract_token(&auth_headers("test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_missing_auth_header() {
        let headers = HeaderMap::new();
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_empty_token_rejected() {
        // A scheme with nothing after it, and a whitespace-only header, carry no token.
        assert!(extract_token(&auth_headers("Bearer ")).is_err());
        assert!(extract_token(&auth_headers("   ")).is_err());
    }

    #[test]
    fn test_token_whitespace_trimmed() {
        let token = extract_token(&auth_headers("Bearer   test_token_123  ")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_auth_context_permissions() {
        let claims = Claims::new(
            "user1".to_string(),
            "tenant1".to_string(),
            Role::Developer,
            chrono::Duration::hours(1),
        );

        let ctx = AuthContext { claims };

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_ok());
        assert!(ctx.require_permission(Permission::Admin).is_err());
    }
}