1use crate::error::AllSourceError;
2use crate::infrastructure::security::auth::{AuthManager, Claims, Permission};
3use crate::infrastructure::security::rate_limit::RateLimiter;
4use axum::{
5 extract::{Request, State},
6 http::{HeaderMap, StatusCode},
7 middleware::Next,
8 response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11
12#[derive(Clone)]
14pub struct AuthState {
15 pub auth_manager: Arc<AuthManager>,
16}
17
18#[derive(Clone)]
20pub struct RateLimitState {
21 pub rate_limiter: Arc<RateLimiter>,
22}
23
24#[derive(Debug, Clone)]
26pub struct AuthContext {
27 pub claims: Claims,
28}
29
30impl AuthContext {
31 pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
33 if self.claims.has_permission(permission) {
34 Ok(())
35 } else {
36 Err(AllSourceError::ValidationError(
37 "Insufficient permissions".to_string(),
38 ))
39 }
40 }
41
42 pub fn tenant_id(&self) -> &str {
44 &self.claims.tenant_id
45 }
46
47 pub fn user_id(&self) -> &str {
49 &self.claims.sub
50 }
51}
52
53fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
55 let auth_header = headers
56 .get("authorization")
57 .ok_or_else(|| AllSourceError::ValidationError("Missing authorization header".to_string()))?
58 .to_str()
59 .map_err(|_| AllSourceError::ValidationError("Invalid authorization header".to_string()))?;
60
61 let token = if auth_header.starts_with("Bearer ") {
63 auth_header.trim_start_matches("Bearer ").trim()
64 } else if auth_header.starts_with("bearer ") {
65 auth_header.trim_start_matches("bearer ").trim()
66 } else {
67 auth_header.trim()
68 };
69
70 if token.is_empty() {
71 return Err(AllSourceError::ValidationError(
72 "Empty authorization token".to_string(),
73 ));
74 }
75
76 Ok(token.to_string())
77}
78
79pub async fn auth_middleware(
81 State(auth_state): State<AuthState>,
82 mut request: Request,
83 next: Next,
84) -> Result<Response, AuthError> {
85 let headers = request.headers();
86
87 let token = extract_token(headers)?;
89
90 let claims = if token.starts_with("ask_") {
91 auth_state.auth_manager.validate_api_key(&token)?
93 } else {
94 auth_state.auth_manager.validate_token(&token)?
96 };
97
98 request.extensions_mut().insert(AuthContext { claims });
100
101 Ok(next.run(request).await)
102}
103
104pub async fn optional_auth_middleware(
106 State(auth_state): State<AuthState>,
107 mut request: Request,
108 next: Next,
109) -> Response {
110 let headers = request.headers();
111
112 if let Ok(token) = extract_token(headers) {
113 let claims = if token.starts_with("ask_") {
115 auth_state.auth_manager.validate_api_key(&token).ok()
116 } else {
117 auth_state.auth_manager.validate_token(&token).ok()
118 };
119
120 if let Some(claims) = claims {
121 request.extensions_mut().insert(AuthContext { claims });
122 }
123 }
124
125 next.run(request).await
126}
127
128#[derive(Debug)]
130pub struct AuthError(AllSourceError);
131
132impl From<AllSourceError> for AuthError {
133 fn from(err: AllSourceError) -> Self {
134 AuthError(err)
135 }
136}
137
138impl IntoResponse for AuthError {
139 fn into_response(self) -> Response {
140 let (status, message) = match self.0 {
141 AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
142 _ => (
143 StatusCode::INTERNAL_SERVER_ERROR,
144 "Internal server error".to_string(),
145 ),
146 };
147
148 (status, message).into_response()
149 }
150}
151
152pub struct Authenticated(pub AuthContext);
154
155impl<S> axum::extract::FromRequestParts<S> for Authenticated
156where
157 S: Send + Sync,
158{
159 type Rejection = (StatusCode, &'static str);
160
161 async fn from_request_parts(
162 parts: &mut axum::http::request::Parts,
163 _state: &S,
164 ) -> Result<Self, Self::Rejection> {
165 parts
166 .extensions
167 .get::<AuthContext>()
168 .cloned()
169 .map(Authenticated)
170 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
171 }
172}
173
174pub struct OptionalAuth(pub Option<AuthContext>);
177
178impl<S> axum::extract::FromRequestParts<S> for OptionalAuth
179where
180 S: Send + Sync,
181{
182 type Rejection = std::convert::Infallible;
183
184 async fn from_request_parts(
185 parts: &mut axum::http::request::Parts,
186 _state: &S,
187 ) -> Result<Self, Self::Rejection> {
188 Ok(OptionalAuth(parts.extensions.get::<AuthContext>().cloned()))
189 }
190}
191
192pub struct Admin(pub AuthContext);
194
195impl<S> axum::extract::FromRequestParts<S> for Admin
196where
197 S: Send + Sync,
198{
199 type Rejection = (StatusCode, &'static str);
200
201 async fn from_request_parts(
202 parts: &mut axum::http::request::Parts,
203 _state: &S,
204 ) -> Result<Self, Self::Rejection> {
205 let auth_ctx = parts
206 .extensions
207 .get::<AuthContext>()
208 .cloned()
209 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
210
211 auth_ctx
212 .require_permission(Permission::Admin)
213 .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
214
215 Ok(Admin(auth_ctx))
216 }
217}
218
219pub async fn rate_limit_middleware(
222 State(rate_limit_state): State<RateLimitState>,
223 request: Request,
224 next: Next,
225) -> Result<Response, RateLimitError> {
226 let auth_ctx = request
228 .extensions()
229 .get::<AuthContext>()
230 .ok_or_else(|| RateLimitError::Unauthorized)?;
231
232 let result = rate_limit_state
234 .rate_limiter
235 .check_rate_limit(auth_ctx.tenant_id());
236
237 if !result.allowed {
238 return Err(RateLimitError::RateLimitExceeded {
239 retry_after: result.retry_after.unwrap_or_default().as_secs(),
240 limit: result.limit,
241 });
242 }
243
244 let mut response = next.run(request).await;
246 let headers = response.headers_mut();
247 headers.insert(
248 "X-RateLimit-Limit",
249 result.limit.to_string().parse().unwrap(),
250 );
251 headers.insert(
252 "X-RateLimit-Remaining",
253 result.remaining.to_string().parse().unwrap(),
254 );
255
256 Ok(response)
257}
258
259#[derive(Debug)]
261pub enum RateLimitError {
262 RateLimitExceeded { retry_after: u64, limit: u32 },
263 Unauthorized,
264}
265
266impl IntoResponse for RateLimitError {
267 fn into_response(self) -> Response {
268 match self {
269 RateLimitError::RateLimitExceeded { retry_after, limit } => {
270 let mut response = (
271 StatusCode::TOO_MANY_REQUESTS,
272 format!("Rate limit exceeded. Limit: {} requests/min", limit),
273 )
274 .into_response();
275
276 if retry_after > 0 {
277 response
278 .headers_mut()
279 .insert("Retry-After", retry_after.to_string().parse().unwrap());
280 }
281
282 response
283 }
284 RateLimitError::Unauthorized => (
285 StatusCode::UNAUTHORIZED,
286 "Authentication required for rate limiting",
287 )
288 .into_response(),
289 }
290 }
291}
292
/// Early-returns `403 Forbidden` unless the wrapped [`AuthContext`] grants `$perm`.
///
/// `$auth` must be an extractor whose `.0` is an `AuthContext` (e.g. `Authenticated`).
/// The expansion ends in `?`, so the calling function must return a `Result` whose error
/// type can be created from `(StatusCode, &'static str)`.
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {
        $auth.0.require_permission($perm).map_err(|_| {
            (
                // Absolute path so the expansion cannot be hijacked by a local `axum` item
                // at the call site.
                ::axum::http::StatusCode::FORBIDDEN,
                "Insufficient permissions",
            )
        })?
    };
}
305
306use crate::domain::entities::Tenant;
311use crate::domain::repositories::TenantRepository;
312use crate::domain::value_objects::TenantId;
313
314#[derive(Clone)]
316pub struct TenantState<R: TenantRepository> {
317 pub tenant_repository: Arc<R>,
318}
319
320#[derive(Debug, Clone)]
325pub struct TenantContext {
326 pub tenant: Tenant,
327}
328
329impl TenantContext {
330 pub fn tenant_id(&self) -> &TenantId {
332 self.tenant.id()
333 }
334
335 pub fn is_active(&self) -> bool {
337 self.tenant.is_active()
338 }
339}
340
341pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
355 State(tenant_state): State<TenantState<R>>,
356 mut request: Request,
357 next: Next,
358) -> Result<Response, TenantError> {
359 let auth_ctx = request
361 .extensions()
362 .get::<AuthContext>()
363 .ok_or(TenantError::Unauthorized)?
364 .clone();
365
366 let tenant_id =
368 TenantId::new(auth_ctx.tenant_id().to_string()).map_err(|_| TenantError::InvalidTenant)?;
369
370 let tenant = tenant_state
372 .tenant_repository
373 .find_by_id(&tenant_id)
374 .await
375 .map_err(|e| TenantError::RepositoryError(e.to_string()))?
376 .ok_or(TenantError::TenantNotFound)?;
377
378 if !tenant.is_active() {
380 return Err(TenantError::TenantInactive);
381 }
382
383 request.extensions_mut().insert(TenantContext { tenant });
385
386 Ok(next.run(request).await)
388}
389
390#[derive(Debug)]
392pub enum TenantError {
393 Unauthorized,
394 InvalidTenant,
395 TenantNotFound,
396 TenantInactive,
397 RepositoryError(String),
398}
399
400impl IntoResponse for TenantError {
401 fn into_response(self) -> Response {
402 let (status, message) = match self {
403 TenantError::Unauthorized => (
404 StatusCode::UNAUTHORIZED,
405 "Authentication required for tenant access",
406 ),
407 TenantError::InvalidTenant => (StatusCode::BAD_REQUEST, "Invalid tenant identifier"),
408 TenantError::TenantNotFound => (StatusCode::NOT_FOUND, "Tenant not found"),
409 TenantError::TenantInactive => (StatusCode::FORBIDDEN, "Tenant is inactive"),
410 TenantError::RepositoryError(_) => (
411 StatusCode::INTERNAL_SERVER_ERROR,
412 "Failed to validate tenant",
413 ),
414 };
415
416 (status, message).into_response()
417 }
418}
419
420use uuid::Uuid;
425
426#[derive(Debug, Clone)]
428pub struct RequestId(pub String);
429
430impl RequestId {
431 pub fn new() -> Self {
433 Self(Uuid::new_v4().to_string())
434 }
435
436 pub fn as_str(&self) -> &str {
438 &self.0
439 }
440}
441
442pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
457 let request_id = request
459 .headers()
460 .get("x-request-id")
461 .and_then(|v| v.to_str().ok())
462 .map(|s| RequestId(s.to_string()))
463 .unwrap_or_else(RequestId::new);
464
465 request.extensions_mut().insert(request_id.clone());
467
468 let mut response = next.run(request).await;
470
471 response
473 .headers_mut()
474 .insert("x-request-id", request_id.0.parse().unwrap());
475
476 response
477}
478
/// Settings for [`security_headers_middleware`].
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Emit `strict-transport-security`.
    pub enable_hsts: bool,
    /// HSTS `max-age`, in seconds.
    pub hsts_max_age: u32,
    /// Emit `x-frame-options`.
    pub enable_frame_options: bool,
    /// Policy written to `x-frame-options`.
    pub frame_options: FrameOptions,
    /// Emit `x-content-type-options: nosniff`.
    pub enable_content_type_options: bool,
    /// Emit `x-xss-protection: 1; mode=block`.
    pub enable_xss_protection: bool,
    /// `content-security-policy` value; `None` omits the header.
    pub csp: Option<String>,
    /// Origins allowed for CORS (`access-control-allow-origin`).
    pub cors_origins: Vec<String>,
    /// Methods listed in `access-control-allow-methods`.
    pub cors_methods: Vec<String>,
    /// Headers listed in `access-control-allow-headers`.
    pub cors_headers: Vec<String>,
    /// `access-control-max-age`, in seconds.
    pub cors_max_age: u32,
}

/// Value for the `x-frame-options` header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    Deny,
    SameOrigin,
    AllowFrom(String),
}

impl Default for SecurityConfig {
    /// All protections enabled, a `'self'`-only CSP, and permissive (`*`) CORS.
    fn default() -> Self {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }

        Self {
            enable_hsts: true,
            // One year.
            hsts_max_age: 365 * 24 * 60 * 60,
            enable_frame_options: true,
            frame_options: FrameOptions::Deny,
            enable_content_type_options: true,
            enable_xss_protection: true,
            csp: Some(String::from("default-src 'self'")),
            cors_origins: owned(&["*"]),
            cors_methods: owned(&["GET", "POST", "PUT", "DELETE"]),
            cors_headers: owned(&["Content-Type", "Authorization"]),
            // One hour.
            cors_max_age: 60 * 60,
        }
    }
}
539
540#[derive(Clone)]
541pub struct SecurityState {
542 pub config: SecurityConfig,
543}
544
545pub async fn security_headers_middleware(
563 State(security_state): State<SecurityState>,
564 request: Request,
565 next: Next,
566) -> Response {
567 let mut response = next.run(request).await;
568 let headers = response.headers_mut();
569 let config = &security_state.config;
570
571 if config.enable_hsts {
573 headers.insert(
574 "strict-transport-security",
575 format!("max-age={}", config.hsts_max_age).parse().unwrap(),
576 );
577 }
578
579 if config.enable_frame_options {
581 let value = match &config.frame_options {
582 FrameOptions::Deny => "DENY",
583 FrameOptions::SameOrigin => "SAMEORIGIN",
584 FrameOptions::AllowFrom(origin) => origin,
585 };
586 headers.insert("x-frame-options", value.parse().unwrap());
587 }
588
589 if config.enable_content_type_options {
591 headers.insert("x-content-type-options", "nosniff".parse().unwrap());
592 }
593
594 if config.enable_xss_protection {
596 headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
597 }
598
599 if let Some(csp) = &config.csp {
601 headers.insert("content-security-policy", csp.parse().unwrap());
602 }
603
604 headers.insert(
606 "access-control-allow-origin",
607 config.cors_origins.join(", ").parse().unwrap(),
608 );
609 headers.insert(
610 "access-control-allow-methods",
611 config.cors_methods.join(", ").parse().unwrap(),
612 );
613 headers.insert(
614 "access-control-allow-headers",
615 config.cors_headers.join(", ").parse().unwrap(),
616 );
617 headers.insert(
618 "access-control-max-age",
619 config.cors_max_age.to_string().parse().unwrap(),
620 );
621
622 response
623}
624
625use crate::infrastructure::security::IpFilter;
630use std::net::SocketAddr;
631
632#[derive(Clone)]
633pub struct IpFilterState {
634 pub ip_filter: Arc<IpFilter>,
635}
636
637pub async fn ip_filter_middleware(
649 State(ip_filter_state): State<IpFilterState>,
650 request: Request,
651 next: Next,
652) -> Result<Response, IpFilterError> {
653 let client_ip = request
655 .extensions()
656 .get::<axum::extract::ConnectInfo<SocketAddr>>()
657 .map(|connect_info| connect_info.0.ip())
658 .ok_or(IpFilterError::NoIpAddress)?;
659
660 let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
662 ip_filter_state
664 .ip_filter
665 .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
666 } else {
667 ip_filter_state.ip_filter.is_allowed(&client_ip)
669 };
670
671 if !result.allowed {
673 return Err(IpFilterError::Blocked {
674 reason: result.reason,
675 });
676 }
677
678 Ok(next.run(request).await)
680}
681
682#[derive(Debug)]
684pub enum IpFilterError {
685 NoIpAddress,
686 Blocked { reason: String },
687}
688
689impl IntoResponse for IpFilterError {
690 fn into_response(self) -> Response {
691 match self {
692 IpFilterError::NoIpAddress => (
693 StatusCode::BAD_REQUEST,
694 "Unable to determine client IP address",
695 )
696 .into_response(),
697 IpFilterError::Blocked { reason } => {
698 (StatusCode::FORBIDDEN, format!("Access denied: {}", reason)).into_response()
699 }
700 }
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::security::auth::Role;

    /// Builds a header map containing a single `authorization` value.
    fn headers_with_authorization(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", value.parse().unwrap());
        headers
    }

    #[test]
    fn test_extract_bearer_token() {
        let token = extract_token(&headers_with_authorization("Bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_lowercase_bearer() {
        let token = extract_token(&headers_with_authorization("bearer test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_plain_token() {
        let token = extract_token(&headers_with_authorization("test_token_123")).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_missing_auth_header() {
        let headers = HeaderMap::new();
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_blank_bearer_token_is_rejected() {
        assert!(extract_token(&headers_with_authorization("Bearer ")).is_err());
    }

    #[test]
    fn test_non_ascii_auth_header_is_rejected() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            axum::http::HeaderValue::from_bytes(b"Bearer \xff").unwrap(),
        );
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_auth_context_permissions() {
        let claims = Claims::new(
            "user1".to_string(),
            "tenant1".to_string(),
            Role::Developer,
            chrono::Duration::hours(1),
        );

        let ctx = AuthContext { claims };

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_ok());
        assert!(ctx.require_permission(Permission::Admin).is_err());
    }
}