1use crate::error::AllSourceError;
2use crate::infrastructure::security::auth::{AuthManager, Claims, Permission};
3use crate::infrastructure::security::rate_limit::RateLimiter;
4use axum::{
5 extract::{Request, State},
6 http::{HeaderMap, StatusCode},
7 middleware::Next,
8 response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11
/// Shared state for [`auth_middleware`] and [`optional_auth_middleware`].
#[derive(Clone)]
pub struct AuthState {
    /// Validates API keys (`validate_api_key`) and tokens (`validate_token`).
    pub auth_manager: Arc<AuthManager>,
}
17
/// Shared state for [`rate_limit_middleware`].
#[derive(Clone)]
pub struct RateLimitState {
    /// Limiter consulted once per request, keyed by the caller's tenant id.
    pub rate_limiter: Arc<RateLimiter>,
}
23
24#[derive(Debug, Clone)]
26pub struct AuthContext {
27 pub claims: Claims,
28}
29
30impl AuthContext {
31 pub fn require_permission(&self, permission: Permission) -> Result<(), AllSourceError> {
33 if self.claims.has_permission(permission) {
34 Ok(())
35 } else {
36 Err(AllSourceError::ValidationError(
37 "Insufficient permissions".to_string(),
38 ))
39 }
40 }
41
42 pub fn tenant_id(&self) -> &str {
44 &self.claims.tenant_id
45 }
46
47 pub fn user_id(&self) -> &str {
49 &self.claims.sub
50 }
51}
52
53fn extract_token(headers: &HeaderMap) -> Result<String, AllSourceError> {
55 let auth_header = headers
56 .get("authorization")
57 .ok_or_else(|| AllSourceError::ValidationError("Missing authorization header".to_string()))?
58 .to_str()
59 .map_err(|_| AllSourceError::ValidationError("Invalid authorization header".to_string()))?;
60
61 let token = if auth_header.starts_with("Bearer ") {
63 auth_header.trim_start_matches("Bearer ").trim()
64 } else if auth_header.starts_with("bearer ") {
65 auth_header.trim_start_matches("bearer ").trim()
66 } else {
67 auth_header.trim()
68 };
69
70 if token.is_empty() {
71 return Err(AllSourceError::ValidationError(
72 "Empty authorization token".to_string(),
73 ));
74 }
75
76 Ok(token.to_string())
77}
78
79pub async fn auth_middleware(
81 State(auth_state): State<AuthState>,
82 mut request: Request,
83 next: Next,
84) -> Result<Response, AuthError> {
85 let headers = request.headers();
86
87 let token = extract_token(headers)?;
89
90 let claims = if token.starts_with("ask_") {
91 auth_state.auth_manager.validate_api_key(&token)?
93 } else {
94 auth_state.auth_manager.validate_token(&token)?
96 };
97
98 request.extensions_mut().insert(AuthContext { claims });
100
101 Ok(next.run(request).await)
102}
103
104pub async fn optional_auth_middleware(
106 State(auth_state): State<AuthState>,
107 mut request: Request,
108 next: Next,
109) -> Response {
110 let headers = request.headers();
111
112 if let Ok(token) = extract_token(headers) {
113 let claims = if token.starts_with("ask_") {
115 auth_state.auth_manager.validate_api_key(&token).ok()
116 } else {
117 auth_state.auth_manager.validate_token(&token).ok()
118 };
119
120 if let Some(claims) = claims {
121 request.extensions_mut().insert(AuthContext { claims });
122 }
123 }
124
125 next.run(request).await
126}
127
128#[derive(Debug)]
130pub struct AuthError(AllSourceError);
131
132impl From<AllSourceError> for AuthError {
133 fn from(err: AllSourceError) -> Self {
134 AuthError(err)
135 }
136}
137
138impl IntoResponse for AuthError {
139 fn into_response(self) -> Response {
140 let (status, message) = match self.0 {
141 AllSourceError::ValidationError(msg) => (StatusCode::UNAUTHORIZED, msg),
142 _ => (
143 StatusCode::INTERNAL_SERVER_ERROR,
144 "Internal server error".to_string(),
145 ),
146 };
147
148 (status, message).into_response()
149 }
150}
151
152pub struct Authenticated(pub AuthContext);
154
155impl<S> axum::extract::FromRequestParts<S> for Authenticated
156where
157 S: Send + Sync,
158{
159 type Rejection = (StatusCode, &'static str);
160
161 async fn from_request_parts(
162 parts: &mut axum::http::request::Parts,
163 _state: &S,
164 ) -> Result<Self, Self::Rejection> {
165 parts
166 .extensions
167 .get::<AuthContext>()
168 .cloned()
169 .map(Authenticated)
170 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
171 }
172}
173
174pub struct OptionalAuth(pub Option<AuthContext>);
177
178impl<S> axum::extract::FromRequestParts<S> for OptionalAuth
179where
180 S: Send + Sync,
181{
182 type Rejection = std::convert::Infallible;
183
184 async fn from_request_parts(
185 parts: &mut axum::http::request::Parts,
186 _state: &S,
187 ) -> Result<Self, Self::Rejection> {
188 Ok(OptionalAuth(parts.extensions.get::<AuthContext>().cloned()))
189 }
190}
191
192pub struct Admin(pub AuthContext);
194
195impl<S> axum::extract::FromRequestParts<S> for Admin
196where
197 S: Send + Sync,
198{
199 type Rejection = (StatusCode, &'static str);
200
201 async fn from_request_parts(
202 parts: &mut axum::http::request::Parts,
203 _state: &S,
204 ) -> Result<Self, Self::Rejection> {
205 let auth_ctx = parts
206 .extensions
207 .get::<AuthContext>()
208 .cloned()
209 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
210
211 auth_ctx
212 .require_permission(Permission::Admin)
213 .map_err(|_| (StatusCode::FORBIDDEN, "Admin permission required"))?;
214
215 Ok(Admin(auth_ctx))
216 }
217}
218
219pub async fn rate_limit_middleware(
222 State(rate_limit_state): State<RateLimitState>,
223 request: Request,
224 next: Next,
225) -> Result<Response, RateLimitError> {
226 let auth_ctx = request
228 .extensions()
229 .get::<AuthContext>()
230 .ok_or(RateLimitError::Unauthorized)?;
231
232 let result = rate_limit_state
234 .rate_limiter
235 .check_rate_limit(auth_ctx.tenant_id());
236
237 if !result.allowed {
238 return Err(RateLimitError::RateLimitExceeded {
239 retry_after: result.retry_after.unwrap_or_default().as_secs(),
240 limit: result.limit,
241 });
242 }
243
244 let mut response = next.run(request).await;
246 let headers = response.headers_mut();
247 headers.insert(
248 "X-RateLimit-Limit",
249 result.limit.to_string().parse().unwrap(),
250 );
251 headers.insert(
252 "X-RateLimit-Remaining",
253 result.remaining.to_string().parse().unwrap(),
254 );
255
256 Ok(response)
257}
258
259#[derive(Debug)]
261pub enum RateLimitError {
262 RateLimitExceeded { retry_after: u64, limit: u32 },
263 Unauthorized,
264}
265
266impl IntoResponse for RateLimitError {
267 fn into_response(self) -> Response {
268 match self {
269 RateLimitError::RateLimitExceeded { retry_after, limit } => {
270 let mut response = (
271 StatusCode::TOO_MANY_REQUESTS,
272 format!("Rate limit exceeded. Limit: {limit} requests/min"),
273 )
274 .into_response();
275
276 if retry_after > 0 {
277 response
278 .headers_mut()
279 .insert("Retry-After", retry_after.to_string().parse().unwrap());
280 }
281
282 response
283 }
284 RateLimitError::Unauthorized => (
285 StatusCode::UNAUTHORIZED,
286 "Authentication required for rate limiting",
287 )
288 .into_response(),
289 }
290 }
291}
292
/// Early-returns `403 Forbidden` ("Insufficient permissions") from the
/// enclosing handler when `$auth.0` lacks `$perm`.
///
/// `$auth` is an extractor wrapper whose `.0` exposes `require_permission`,
/// e.g. `Authenticated` or `Admin`, both of which wrap an `AuthContext`. The
/// enclosing function must return a `Result` whose error type the `?` can
/// produce from `(StatusCode, &'static str)`.
#[macro_export]
macro_rules! require_permission {
    ($auth:expr, $perm:expr) => {
        $auth.0.require_permission($perm).map_err(|_| {
            (
                axum::http::StatusCode::FORBIDDEN,
                "Insufficient permissions",
            )
        })?
    };
}
305
306use crate::domain::entities::Tenant;
311use crate::domain::repositories::TenantRepository;
312use crate::domain::value_objects::TenantId;
313
314#[derive(Clone)]
316pub struct TenantState<R: TenantRepository> {
317 pub tenant_repository: Arc<R>,
318}
319
320#[derive(Debug, Clone)]
325pub struct TenantContext {
326 pub tenant: Tenant,
327}
328
329impl TenantContext {
330 pub fn tenant_id(&self) -> &TenantId {
332 self.tenant.id()
333 }
334
335 pub fn is_active(&self) -> bool {
337 self.tenant.is_active()
338 }
339}
340
341pub async fn tenant_isolation_middleware<R: TenantRepository + 'static>(
355 State(tenant_state): State<TenantState<R>>,
356 mut request: Request,
357 next: Next,
358) -> Result<Response, TenantError> {
359 let auth_ctx = request
361 .extensions()
362 .get::<AuthContext>()
363 .ok_or(TenantError::Unauthorized)?
364 .clone();
365
366 let tenant_id =
368 TenantId::new(auth_ctx.tenant_id().to_string()).map_err(|_| TenantError::InvalidTenant)?;
369
370 let tenant = tenant_state
372 .tenant_repository
373 .find_by_id(&tenant_id)
374 .await
375 .map_err(|e| TenantError::RepositoryError(e.to_string()))?
376 .ok_or(TenantError::TenantNotFound)?;
377
378 if !tenant.is_active() {
380 return Err(TenantError::TenantInactive);
381 }
382
383 request.extensions_mut().insert(TenantContext { tenant });
385
386 Ok(next.run(request).await)
388}
389
390#[derive(Debug)]
392pub enum TenantError {
393 Unauthorized,
394 InvalidTenant,
395 TenantNotFound,
396 TenantInactive,
397 RepositoryError(String),
398}
399
400impl IntoResponse for TenantError {
401 fn into_response(self) -> Response {
402 let (status, message) = match self {
403 TenantError::Unauthorized => (
404 StatusCode::UNAUTHORIZED,
405 "Authentication required for tenant access",
406 ),
407 TenantError::InvalidTenant => (StatusCode::BAD_REQUEST, "Invalid tenant identifier"),
408 TenantError::TenantNotFound => (StatusCode::NOT_FOUND, "Tenant not found"),
409 TenantError::TenantInactive => (StatusCode::FORBIDDEN, "Tenant is inactive"),
410 TenantError::RepositoryError(_) => (
411 StatusCode::INTERNAL_SERVER_ERROR,
412 "Failed to validate tenant",
413 ),
414 };
415
416 (status, message).into_response()
417 }
418}
419
420use uuid::Uuid;
425
426#[derive(Debug, Clone)]
428pub struct RequestId(pub String);
429
430impl Default for RequestId {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
436impl RequestId {
437 pub fn new() -> Self {
439 Self(Uuid::new_v4().to_string())
440 }
441
442 pub fn as_str(&self) -> &str {
444 &self.0
445 }
446}
447
448pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
463 let request_id = request
465 .headers()
466 .get("x-request-id")
467 .and_then(|v| v.to_str().ok())
468 .map(|s| RequestId(s.to_string()))
469 .unwrap_or_else(RequestId::new);
470
471 request.extensions_mut().insert(request_id.clone());
473
474 let mut response = next.run(request).await;
476
477 response
479 .headers_mut()
480 .insert("x-request-id", request_id.0.parse().unwrap());
481
482 response
483}
484
/// Configuration for [`security_headers_middleware`].
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Emit `Strict-Transport-Security`.
    pub enable_hsts: bool,
    /// HSTS `max-age` directive, in seconds.
    pub hsts_max_age: u32,
    /// Emit `X-Frame-Options`.
    pub enable_frame_options: bool,
    /// Which `X-Frame-Options` policy to send when enabled.
    pub frame_options: FrameOptions,
    /// Emit `X-Content-Type-Options: nosniff`.
    pub enable_content_type_options: bool,
    /// Emit `X-XSS-Protection: 1; mode=block`.
    /// NOTE(review): this header is deprecated in modern browsers; consider CSP instead.
    pub enable_xss_protection: bool,
    /// `Content-Security-Policy` value; `None` omits the header.
    pub csp: Option<String>,
    /// Allowed CORS origins, used to build `Access-Control-Allow-Origin`.
    pub cors_origins: Vec<String>,
    /// Joined with `", "` into `Access-Control-Allow-Methods`.
    pub cors_methods: Vec<String>,
    /// Joined with `", "` into `Access-Control-Allow-Headers`.
    pub cors_headers: Vec<String>,
    /// `Access-Control-Max-Age` value, in seconds.
    pub cors_max_age: u32,
}
515
/// Policy for the `X-Frame-Options` response header.
#[derive(Debug, Clone)]
pub enum FrameOptions {
    /// Never allow framing (`DENY`).
    Deny,
    /// Allow framing only by same-origin pages (`SAMEORIGIN`).
    SameOrigin,
    /// Allow framing by the given origin (the `ALLOW-FROM` directive).
    /// NOTE(review): `ALLOW-FROM` is obsolete and ignored by most modern browsers.
    AllowFrom(String),
}
522
523impl Default for SecurityConfig {
524 fn default() -> Self {
525 Self {
526 enable_hsts: true,
527 hsts_max_age: 31536000, enable_frame_options: true,
529 frame_options: FrameOptions::Deny,
530 enable_content_type_options: true,
531 enable_xss_protection: true,
532 csp: Some("default-src 'self'".to_string()),
533 cors_origins: vec!["*".to_string()],
534 cors_methods: vec![
535 "GET".to_string(),
536 "POST".to_string(),
537 "PUT".to_string(),
538 "DELETE".to_string(),
539 ],
540 cors_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
541 cors_max_age: 3600,
542 }
543 }
544}
545
/// Shared state for [`security_headers_middleware`].
#[derive(Clone)]
pub struct SecurityState {
    /// Which security and CORS headers to emit, and their values.
    pub config: SecurityConfig,
}
550
551pub async fn security_headers_middleware(
569 State(security_state): State<SecurityState>,
570 request: Request,
571 next: Next,
572) -> Response {
573 let mut response = next.run(request).await;
574 let headers = response.headers_mut();
575 let config = &security_state.config;
576
577 if config.enable_hsts {
579 headers.insert(
580 "strict-transport-security",
581 format!("max-age={}", config.hsts_max_age).parse().unwrap(),
582 );
583 }
584
585 if config.enable_frame_options {
587 let value = match &config.frame_options {
588 FrameOptions::Deny => "DENY",
589 FrameOptions::SameOrigin => "SAMEORIGIN",
590 FrameOptions::AllowFrom(origin) => origin,
591 };
592 headers.insert("x-frame-options", value.parse().unwrap());
593 }
594
595 if config.enable_content_type_options {
597 headers.insert("x-content-type-options", "nosniff".parse().unwrap());
598 }
599
600 if config.enable_xss_protection {
602 headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
603 }
604
605 if let Some(csp) = &config.csp {
607 headers.insert("content-security-policy", csp.parse().unwrap());
608 }
609
610 headers.insert(
612 "access-control-allow-origin",
613 config.cors_origins.join(", ").parse().unwrap(),
614 );
615 headers.insert(
616 "access-control-allow-methods",
617 config.cors_methods.join(", ").parse().unwrap(),
618 );
619 headers.insert(
620 "access-control-allow-headers",
621 config.cors_headers.join(", ").parse().unwrap(),
622 );
623 headers.insert(
624 "access-control-max-age",
625 config.cors_max_age.to_string().parse().unwrap(),
626 );
627
628 response
629}
630
631use crate::infrastructure::security::IpFilter;
636use std::net::SocketAddr;
637
/// Shared state for [`ip_filter_middleware`].
#[derive(Clone)]
pub struct IpFilterState {
    /// Filter consulted with the peer IP (per tenant when a `TenantContext` exists).
    pub ip_filter: Arc<IpFilter>,
}
642
643pub async fn ip_filter_middleware(
655 State(ip_filter_state): State<IpFilterState>,
656 request: Request,
657 next: Next,
658) -> Result<Response, IpFilterError> {
659 let client_ip = request
661 .extensions()
662 .get::<axum::extract::ConnectInfo<SocketAddr>>()
663 .map(|connect_info| connect_info.0.ip())
664 .ok_or(IpFilterError::NoIpAddress)?;
665
666 let result = if let Some(tenant_ctx) = request.extensions().get::<TenantContext>() {
668 ip_filter_state
670 .ip_filter
671 .is_allowed_for_tenant(tenant_ctx.tenant_id(), &client_ip)
672 } else {
673 ip_filter_state.ip_filter.is_allowed(&client_ip)
675 };
676
677 if !result.allowed {
679 return Err(IpFilterError::Blocked {
680 reason: result.reason,
681 });
682 }
683
684 Ok(next.run(request).await)
686}
687
688#[derive(Debug)]
690pub enum IpFilterError {
691 NoIpAddress,
692 Blocked { reason: String },
693}
694
695impl IntoResponse for IpFilterError {
696 fn into_response(self) -> Response {
697 match self {
698 IpFilterError::NoIpAddress => (
699 StatusCode::BAD_REQUEST,
700 "Unable to determine client IP address",
701 )
702 .into_response(),
703 IpFilterError::Blocked { reason } => {
704 (StatusCode::FORBIDDEN, format!("Access denied: {reason}")).into_response()
705 }
706 }
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::security::auth::Role;

    /// Builds an `AuthContext` for `user` in `tenant` with `role`, valid for one hour.
    fn ctx_for(user: &str, tenant: &str, role: Role) -> AuthContext {
        let claims = Claims::new(
            user.to_string(),
            tenant.to_string(),
            role,
            chrono::Duration::hours(1),
        );
        AuthContext { claims }
    }

    #[test]
    fn test_extract_bearer_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer test_token_123".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_lowercase_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "bearer test_token_123".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_extract_plain_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "test_token_123".parse().unwrap());

        let token = extract_token(&headers).unwrap();
        assert_eq!(token, "test_token_123");
    }

    #[test]
    fn test_missing_auth_header() {
        let headers = HeaderMap::new();
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_empty_auth_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "".parse().unwrap());
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_bearer_with_empty_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        assert!(extract_token(&headers).is_err());
    }

    #[test]
    fn test_auth_context_permissions() {
        let ctx = ctx_for("user1", "tenant1", Role::Developer);

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_ok());
        assert!(ctx.require_permission(Permission::Admin).is_err());
    }

    #[test]
    fn test_auth_context_admin_permissions() {
        let ctx = ctx_for("admin1", "tenant1", Role::Admin);

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_ok());
        assert!(ctx.require_permission(Permission::Admin).is_ok());
    }

    #[test]
    fn test_auth_context_readonly_permissions() {
        let ctx = ctx_for("readonly1", "tenant1", Role::ReadOnly);

        assert!(ctx.require_permission(Permission::Read).is_ok());
        assert!(ctx.require_permission(Permission::Write).is_err());
        assert!(ctx.require_permission(Permission::Admin).is_err());
    }

    #[test]
    fn test_auth_context_tenant_id() {
        let ctx = ctx_for("user1", "my-tenant", Role::Developer);
        assert_eq!(ctx.tenant_id(), "my-tenant");
    }

    #[test]
    fn test_auth_context_user_id() {
        let ctx = ctx_for("my-user", "tenant1", Role::Developer);
        assert_eq!(ctx.user_id(), "my-user");
    }

    #[test]
    fn test_request_id_new() {
        let id1 = RequestId::new();
        let id2 = RequestId::new();

        assert_ne!(id1.as_str(), id2.as_str());
        // Hyphenated UUID string form.
        assert_eq!(id1.as_str().len(), 36);
    }

    #[test]
    fn test_request_id_default() {
        let id = RequestId::default();
        assert_eq!(id.as_str().len(), 36);
    }

    #[test]
    fn test_security_config_default() {
        let config = SecurityConfig::default();

        assert!(config.enable_hsts);
        assert_eq!(config.hsts_max_age, 31536000);
        assert!(config.enable_frame_options);
        assert!(config.enable_content_type_options);
        assert!(config.enable_xss_protection);
        assert!(config.csp.is_some());
    }

    #[test]
    fn test_frame_options_variants() {
        let deny = FrameOptions::Deny;
        let same_origin = FrameOptions::SameOrigin;
        let allow_from = FrameOptions::AllowFrom("https://example.com".to_string());

        assert!(format!("{:?}", deny).contains("Deny"));
        assert!(format!("{:?}", same_origin).contains("SameOrigin"));
        assert!(format!("{:?}", allow_from).contains("AllowFrom"));
    }

    #[test]
    fn test_auth_error_from_validation_error() {
        let error = AllSourceError::ValidationError("test error".to_string());
        let auth_error = AuthError::from(error);
        assert!(format!("{:?}", auth_error).contains("ValidationError"));
    }

    #[test]
    fn test_auth_error_validation_maps_to_unauthorized() {
        let error = AllSourceError::ValidationError("test error".to_string());
        let response = AuthError::from(error).into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let error = RateLimitError::RateLimitExceeded {
            retry_after: 60,
            limit: 100,
        };
        assert!(format!("{:?}", error).contains("RateLimitExceeded"));

        let unauth_error = RateLimitError::Unauthorized;
        assert!(format!("{:?}", unauth_error).contains("Unauthorized"));
    }

    #[test]
    fn test_rate_limit_error_responses() {
        let response = RateLimitError::RateLimitExceeded {
            retry_after: 60,
            limit: 100,
        }
        .into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        let retry = response
            .headers()
            .get("Retry-After")
            .expect("Retry-After must be set when retry_after > 0");
        assert_eq!(retry.to_str().unwrap(), "60");

        // A zero wait means "unknown": no Retry-After header.
        let response = RateLimitError::RateLimitExceeded {
            retry_after: 0,
            limit: 100,
        }
        .into_response();
        assert!(response.headers().get("Retry-After").is_none());

        let response = RateLimitError::Unauthorized.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_tenant_error_variants() {
        let cases = vec![
            (TenantError::Unauthorized, StatusCode::UNAUTHORIZED),
            (TenantError::InvalidTenant, StatusCode::BAD_REQUEST),
            (TenantError::TenantNotFound, StatusCode::NOT_FOUND),
            (TenantError::TenantInactive, StatusCode::FORBIDDEN),
            (
                TenantError::RepositoryError("test".to_string()),
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
        ];

        for (error, expected) in cases {
            let label = format!("{:?}", error);
            assert_eq!(error.into_response().status(), expected, "{label}");
        }
    }

    #[test]
    fn test_ip_filter_error_variants() {
        let cases = vec![
            (IpFilterError::NoIpAddress, StatusCode::BAD_REQUEST),
            (
                IpFilterError::Blocked {
                    reason: "blocked".to_string(),
                },
                StatusCode::FORBIDDEN,
            ),
        ];

        for (error, expected) in cases {
            let label = format!("{:?}", error);
            assert_eq!(error.into_response().status(), expected, "{label}");
        }
    }

    #[test]
    fn test_security_state_clone() {
        let config = SecurityConfig::default();
        let state = SecurityState {
            config: config.clone(),
        };
        let cloned = state.clone();
        assert_eq!(cloned.config.hsts_max_age, config.hsts_max_age);
    }

    #[test]
    fn test_auth_state_clone() {
        let auth_manager = Arc::new(AuthManager::new("test-secret"));
        let state = AuthState { auth_manager };
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.auth_manager, &cloned.auth_manager));
    }

    #[test]
    fn test_rate_limit_state_clone() {
        use crate::infrastructure::security::rate_limit::RateLimitConfig;
        let rate_limiter = Arc::new(RateLimiter::new(RateLimitConfig::free_tier()));
        let state = RateLimitState { rate_limiter };
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.rate_limiter, &cloned.rate_limiter));
    }
}