1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use parking_lot::RwLock;
11use thiserror::Error;
12
13use super::config::{
14 AuthConfig, Identity, AuthMethod, JwtConfig, OAuthConfig,
15 LdapConfig, ApiKeyConfig,
16};
17use super::jwt::{JwtValidator, JwtError};
18
19#[derive(Debug, Error)]
21pub enum AuthError {
22 #[error("Authentication required")]
23 AuthenticationRequired,
24
25 #[error("Invalid credentials")]
26 InvalidCredentials,
27
28 #[error("Token expired")]
29 TokenExpired,
30
31 #[error("Insufficient permissions: {0}")]
32 InsufficientPermissions(String),
33
34 #[error("Rate limited: retry after {0} seconds")]
35 RateLimited(u64),
36
37 #[error("Authentication provider unavailable: {0}")]
38 ProviderUnavailable(String),
39
40 #[error("Invalid authentication method: {0}")]
41 InvalidMethod(String),
42
43 #[error("JWT error: {0}")]
44 Jwt(#[from] JwtError),
45
46 #[error("OAuth error: {0}")]
47 OAuth(String),
48
49 #[error("LDAP error: {0}")]
50 Ldap(String),
51
52 #[error("API key error: {0}")]
53 ApiKey(String),
54
55 #[error("Session error: {0}")]
56 Session(String),
57
58 #[error("Configuration error: {0}")]
59 Configuration(String),
60}
61
62#[derive(Debug, Clone)]
64pub struct AuthRequest {
65 pub headers: HashMap<String, String>,
67
68 pub username: Option<String>,
70
71 pub password: Option<String>,
73
74 pub client_ip: Option<std::net::IpAddr>,
76
77 pub database: Option<String>,
79
80 pub timestamp: chrono::DateTime<chrono::Utc>,
82}
83
84impl AuthRequest {
85 pub fn new() -> Self {
87 Self {
88 headers: HashMap::new(),
89 username: None,
90 password: None,
91 client_ip: None,
92 database: None,
93 timestamp: chrono::Utc::now(),
94 }
95 }
96
97 pub fn with_username(mut self, username: impl Into<String>) -> Self {
99 self.username = Some(username.into());
100 self
101 }
102
103 pub fn with_password(mut self, password: impl Into<String>) -> Self {
105 self.password = Some(password.into());
106 self
107 }
108
109 pub fn with_client_ip(mut self, ip: std::net::IpAddr) -> Self {
111 self.client_ip = Some(ip);
112 self
113 }
114
115 pub fn with_database(mut self, database: impl Into<String>) -> Self {
117 self.database = Some(database.into());
118 self
119 }
120
121 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
123 self.headers.insert(key.into(), value.into());
124 self
125 }
126
127 pub fn authorization_header(&self) -> Option<&str> {
129 self.headers.get("authorization")
130 .or_else(|| self.headers.get("Authorization"))
131 .map(|s| s.as_str())
132 }
133
134 pub fn bearer_token(&self) -> Option<&str> {
136 self.authorization_header()
137 .and_then(|h| h.strip_prefix("Bearer "))
138 .or_else(|| self.authorization_header()?.strip_prefix("bearer "))
139 }
140
141 pub fn api_key(&self, header_name: &str) -> Option<&str> {
143 self.headers.get(header_name)
144 .or_else(|| self.headers.get(&header_name.to_lowercase()))
145 .map(|s| s.as_str())
146 }
147}
148
149impl Default for AuthRequest {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155#[derive(Debug, Clone)]
157pub struct AuthResult {
158 pub identity: Identity,
160
161 pub session_token: Option<String>,
163
164 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
166
167 pub metadata: HashMap<String, String>,
169}
170
171impl AuthResult {
172 pub fn new(identity: Identity) -> Self {
174 Self {
175 identity,
176 session_token: None,
177 expires_at: None,
178 metadata: HashMap::new(),
179 }
180 }
181
182 pub fn with_session_token(mut self, token: String) -> Self {
184 self.session_token = Some(token);
185 self
186 }
187
188 pub fn with_expiration(mut self, expires_at: chrono::DateTime<chrono::Utc>) -> Self {
190 self.expires_at = Some(expires_at);
191 self
192 }
193
194 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
196 self.metadata.insert(key.into(), value.into());
197 self
198 }
199}
200
201pub struct AuthenticationHandler {
203 config: AuthConfig,
205
206 jwt_validator: Option<JwtValidator>,
208
209 oauth_enabled: bool,
211
212 ldap_enabled: bool,
214
215 api_keys: Arc<RwLock<HashMap<String, ApiKeyEntry>>>,
217
218 rate_limiter: Arc<RwLock<RateLimiterState>>,
220
221 auth_cache: Arc<RwLock<AuthCache>>,
223}
224
225#[derive(Debug, Clone)]
227struct ApiKeyEntry {
228 key_id: String,
230
231 key_hash: String,
233
234 identity: Identity,
236
237 created_at: chrono::DateTime<chrono::Utc>,
239
240 expires_at: Option<chrono::DateTime<chrono::Utc>>,
242
243 active: bool,
245
246 scopes: Vec<String>,
248
249 rate_limit: Option<u32>,
251}
252
253struct RateLimiterState {
255 by_ip: HashMap<std::net::IpAddr, RateLimitBucket>,
257
258 by_user: HashMap<String, RateLimitBucket>,
260
261 last_cleanup: Instant,
263}
264
/// Fixed-window request counter for a single rate-limit key.
struct RateLimitBucket {
    /// Requests counted in the current window.
    count: u32,

    /// Start of the current window.
    window_start: Instant,
}

impl RateLimitBucket {
    /// Creates an empty bucket whose window starts now.
    fn new() -> Self {
        Self {
            count: 0,
            window_start: Instant::now(),
        }
    }

    /// Records one request and returns the count for the current window.
    ///
    /// When more than `window` has elapsed since the window started, a new
    /// window begins and the count restarts at 1. The counter saturates at
    /// `u32::MAX` instead of overflowing: wrapping back to 0 would silently
    /// lift the limit for the rest of the window.
    fn increment(&mut self, window: Duration) -> u32 {
        if self.window_start.elapsed() > window {
            self.count = 1;
            self.window_start = Instant::now();
        } else {
            self.count = self.count.saturating_add(1);
        }
        self.count
    }
}
292
293struct AuthCache {
295 entries: HashMap<String, CachedAuth>,
297
298 max_size: usize,
300
301 ttl: Duration,
303}
304
305struct CachedAuth {
307 result: AuthResult,
309
310 cached_at: Instant,
312}
313
314impl AuthCache {
315 fn new(max_size: usize, ttl: Duration) -> Self {
316 Self {
317 entries: HashMap::new(),
318 max_size,
319 ttl,
320 }
321 }
322
323 fn get(&self, key: &str) -> Option<&AuthResult> {
324 self.entries.get(key).and_then(|cached| {
325 if cached.cached_at.elapsed() < self.ttl {
326 Some(&cached.result)
327 } else {
328 None
329 }
330 })
331 }
332
333 fn insert(&mut self, key: String, result: AuthResult) {
334 if self.entries.len() >= self.max_size {
335 self.evict_expired();
336 }
337 self.entries.insert(key, CachedAuth {
338 result,
339 cached_at: Instant::now(),
340 });
341 }
342
343 fn evict_expired(&mut self) {
344 self.entries.retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
345 }
346}
347
348impl AuthenticationHandler {
349 pub fn new(config: AuthConfig) -> Self {
351 let jwt_validator = config.jwt.as_ref().map(|jwt_config| {
352 JwtValidator::new(jwt_config.clone())
353 });
354
355 let oauth_enabled = config.oauth.is_some();
356 let ldap_enabled = config.ldap.is_some();
357
358 Self {
359 config,
360 jwt_validator,
361 oauth_enabled,
362 ldap_enabled,
363 api_keys: Arc::new(RwLock::new(HashMap::new())),
364 rate_limiter: Arc::new(RwLock::new(RateLimiterState {
365 by_ip: HashMap::new(),
366 by_user: HashMap::new(),
367 last_cleanup: Instant::now(),
368 })),
369 auth_cache: Arc::new(RwLock::new(AuthCache::new(
370 1000,
371 Duration::from_secs(60),
372 ))),
373 }
374 }
375
376 pub fn builder() -> AuthenticationHandlerBuilder {
378 AuthenticationHandlerBuilder::new()
379 }
380
381 pub async fn authenticate(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
383 if !self.config.enabled {
385 return Ok(AuthResult::new(Identity::anonymous()));
387 }
388
389 self.check_rate_limit(request)?;
391
392 let methods = &self.config.auth_methods;
394
395 for method in methods {
396 match self.try_authenticate(request, method).await {
397 Ok(result) => return Ok(result),
398 Err(AuthError::AuthenticationRequired) => continue,
399 Err(e) => return Err(e),
400 }
401 }
402
403 Err(AuthError::AuthenticationRequired)
405 }
406
407 async fn try_authenticate(
409 &self,
410 request: &AuthRequest,
411 method: &AuthMethod,
412 ) -> Result<AuthResult, AuthError> {
413 match method {
414 AuthMethod::Jwt => self.authenticate_jwt(request).await,
415 AuthMethod::OAuth => self.authenticate_oauth(request).await,
416 AuthMethod::Ldap => self.authenticate_ldap(request).await,
417 AuthMethod::ApiKey => self.authenticate_api_key(request).await,
418 AuthMethod::Basic => self.authenticate_basic(request).await,
419 AuthMethod::Trust => self.authenticate_trust(request),
420 AuthMethod::AgentToken | AuthMethod::Session | AuthMethod::Anonymous => {
421 self.authenticate_trust(request)
422 }
423 }
424 }
425
426 async fn authenticate_jwt(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
428 let validator = self.jwt_validator.as_ref()
429 .ok_or(AuthError::Configuration("JWT not configured".to_string()))?;
430
431 let token = request.bearer_token()
432 .ok_or(AuthError::AuthenticationRequired)?;
433
434 if let Some(cached) = self.auth_cache.read().get(token) {
436 return Ok(cached.clone());
437 }
438
439 let identity = validator.validate_to_identity(token)?;
441 let result = AuthResult::new(identity);
442
443 self.auth_cache.write().insert(token.to_string(), result.clone());
445
446 Ok(result)
447 }
448
449 async fn authenticate_oauth(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
451 if !self.oauth_enabled {
452 return Err(AuthError::Configuration("OAuth not configured".to_string()));
453 }
454
455 let token = request.bearer_token()
456 .ok_or(AuthError::AuthenticationRequired)?;
457
458 if let Some(cached) = self.auth_cache.read().get(token) {
460 return Ok(cached.clone());
461 }
462
463 let identity = Identity {
466 user_id: "oauth_user".to_string(),
467 name: Some("OAuth User".to_string()),
468 email: None,
469 roles: vec!["user".to_string()],
470 groups: Vec::new(),
471 tenant_id: None,
472 claims: HashMap::new(),
473 auth_method: "oauth".to_string(),
474 authenticated_at: chrono::Utc::now(),
475 };
476
477 let result = AuthResult::new(identity);
478 self.auth_cache.write().insert(token.to_string(), result.clone());
479
480 Ok(result)
481 }
482
483 async fn authenticate_ldap(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
485 if !self.ldap_enabled {
486 return Err(AuthError::Configuration("LDAP not configured".to_string()));
487 }
488
489 let username = request.username.as_ref()
490 .ok_or(AuthError::AuthenticationRequired)?;
491 let password = request.password.as_ref()
492 .ok_or(AuthError::AuthenticationRequired)?;
493
494 if password.is_empty() {
497 return Err(AuthError::InvalidCredentials);
498 }
499
500 let identity = Identity {
501 user_id: username.clone(),
502 name: Some(username.clone()),
503 email: None,
504 roles: vec!["user".to_string()],
505 groups: Vec::new(),
506 tenant_id: None,
507 claims: HashMap::new(),
508 auth_method: "ldap".to_string(),
509 authenticated_at: chrono::Utc::now(),
510 };
511
512 Ok(AuthResult::new(identity))
513 }
514
515 async fn authenticate_api_key(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
517 let api_key_config = self.config.api_keys.as_ref()
518 .ok_or(AuthError::Configuration("API keys not configured".to_string()))?;
519
520 let header_name = &api_key_config.header_name;
521 let key = request.api_key(header_name)
522 .ok_or(AuthError::AuthenticationRequired)?;
523
524 if let Some(cached) = self.auth_cache.read().get(key) {
526 return Ok(cached.clone());
527 }
528
529 let api_keys = self.api_keys.read();
531 let entry = api_keys.values()
532 .find(|e| self.verify_api_key(key, &e.key_hash) && e.active)
533 .ok_or(AuthError::InvalidCredentials)?;
534
535 if let Some(expires_at) = entry.expires_at {
537 if chrono::Utc::now() > expires_at {
538 return Err(AuthError::TokenExpired);
539 }
540 }
541
542 let result = AuthResult::new(entry.identity.clone());
543 self.auth_cache.write().insert(key.to_string(), result.clone());
544
545 Ok(result)
546 }
547
548 async fn authenticate_basic(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
550 let auth_header = request.authorization_header()
551 .ok_or(AuthError::AuthenticationRequired)?;
552
553 if !auth_header.starts_with("Basic ") {
554 return Err(AuthError::AuthenticationRequired);
555 }
556
557 let encoded = &auth_header[6..];
558 let decoded = base64_decode(encoded)
559 .map_err(|_| AuthError::InvalidCredentials)?;
560 let credentials = String::from_utf8(decoded)
561 .map_err(|_| AuthError::InvalidCredentials)?;
562
563 let parts: Vec<&str> = credentials.splitn(2, ':').collect();
564 if parts.len() != 2 {
565 return Err(AuthError::InvalidCredentials);
566 }
567
568 let username = parts[0];
569 let password = parts[1];
570
571 if password.is_empty() {
574 return Err(AuthError::InvalidCredentials);
575 }
576
577 let identity = Identity {
578 user_id: username.to_string(),
579 name: Some(username.to_string()),
580 email: None,
581 roles: vec!["user".to_string()],
582 groups: Vec::new(),
583 tenant_id: None,
584 claims: HashMap::new(),
585 auth_method: "basic".to_string(),
586 authenticated_at: chrono::Utc::now(),
587 };
588
589 Ok(AuthResult::new(identity))
590 }
591
592 fn authenticate_trust(&self, request: &AuthRequest) -> Result<AuthResult, AuthError> {
594 let username = request.username.as_ref()
596 .unwrap_or(&"anonymous".to_string())
597 .clone();
598
599 let identity = Identity {
600 user_id: username.clone(),
601 name: Some(username),
602 email: None,
603 roles: vec!["trusted".to_string()],
604 groups: Vec::new(),
605 tenant_id: None,
606 claims: HashMap::new(),
607 auth_method: "trust".to_string(),
608 authenticated_at: chrono::Utc::now(),
609 };
610
611 Ok(AuthResult::new(identity))
612 }
613
614 fn check_rate_limit(&self, request: &AuthRequest) -> Result<(), AuthError> {
616 let config = &self.config.rate_limit;
617 if !config.enabled {
618 return Ok(());
619 }
620
621 let mut limiter = self.rate_limiter.write();
622
623 if limiter.last_cleanup.elapsed() > Duration::from_secs(60) {
625 let window = Duration::from_secs(config.window_seconds);
626 limiter.by_ip.retain(|_, b| b.window_start.elapsed() < window);
627 limiter.by_user.retain(|_, b| b.window_start.elapsed() < window);
628 limiter.last_cleanup = Instant::now();
629 }
630
631 let window = Duration::from_secs(config.window_seconds);
632
633 if let Some(ip) = request.client_ip {
635 let bucket = limiter.by_ip.entry(ip).or_insert_with(RateLimitBucket::new);
636 let count = bucket.increment(window);
637 if count > config.max_requests_per_ip {
638 let retry_after = window.as_secs().saturating_sub(bucket.window_start.elapsed().as_secs());
639 return Err(AuthError::RateLimited(retry_after));
640 }
641 }
642
643 if let Some(username) = &request.username {
645 let bucket = limiter.by_user.entry(username.clone()).or_insert_with(RateLimitBucket::new);
646 let count = bucket.increment(window);
647 if count > config.max_requests_per_user {
648 let retry_after = window.as_secs().saturating_sub(bucket.window_start.elapsed().as_secs());
649 return Err(AuthError::RateLimited(retry_after));
650 }
651 }
652
653 Ok(())
654 }
655
656 fn verify_api_key(&self, key: &str, hash: &str) -> bool {
658 let key_hash = self.hash_api_key(key);
661 key_hash == hash
662 }
663
664 fn hash_api_key(&self, key: &str) -> String {
666 use std::hash::{Hash, Hasher};
668 let mut hasher = std::collections::hash_map::DefaultHasher::new();
669 key.hash(&mut hasher);
670 format!("{:x}", hasher.finish())
671 }
672
673 pub fn register_api_key(
675 &self,
676 key_id: String,
677 key_value: String,
678 identity: Identity,
679 expires_at: Option<chrono::DateTime<chrono::Utc>>,
680 scopes: Vec<String>,
681 ) {
682 let entry = ApiKeyEntry {
683 key_id: key_id.clone(),
684 key_hash: self.hash_api_key(&key_value),
685 identity,
686 created_at: chrono::Utc::now(),
687 expires_at,
688 active: true,
689 scopes,
690 rate_limit: None,
691 };
692
693 self.api_keys.write().insert(key_id, entry);
694 }
695
696 pub fn revoke_api_key(&self, key_id: &str) -> bool {
698 if let Some(entry) = self.api_keys.write().get_mut(key_id) {
699 entry.active = false;
700 true
701 } else {
702 false
703 }
704 }
705
706 pub async fn refresh_jwks_if_needed(&self) -> Result<(), AuthError> {
708 if let Some(validator) = &self.jwt_validator {
709 if validator.needs_refresh() {
710 validator.refresh_jwks().await?;
711 }
712 }
713 Ok(())
714 }
715
716 pub fn clear_cache(&self) {
718 self.auth_cache.write().entries.clear();
719 }
720
721 pub fn cache_stats(&self) -> CacheStats {
723 let cache = self.auth_cache.read();
724 CacheStats {
725 entries: cache.entries.len(),
726 max_size: cache.max_size,
727 ttl_seconds: cache.ttl.as_secs(),
728 }
729 }
730
731 pub fn is_enabled(&self) -> bool {
733 self.config.enabled
734 }
735}
736
/// Point-in-time view of the authentication result cache.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub entries: usize,

    /// Configured capacity of the cache.
    pub max_size: usize,

    /// Entry lifetime in whole seconds.
    pub ttl_seconds: u64,
}
749
750pub struct AuthenticationHandlerBuilder {
752 config: AuthConfig,
753}
754
755impl AuthenticationHandlerBuilder {
756 pub fn new() -> Self {
758 Self {
759 config: AuthConfig::default(),
760 }
761 }
762
763 pub fn enabled(mut self, enabled: bool) -> Self {
765 self.config.enabled = enabled;
766 self
767 }
768
769 pub fn with_jwt(mut self, config: JwtConfig) -> Self {
771 self.config.jwt = Some(config);
772 self.config.auth_methods.push(AuthMethod::Jwt);
773 self
774 }
775
776 pub fn with_oauth(mut self, config: OAuthConfig) -> Self {
778 self.config.oauth = Some(config);
779 self.config.auth_methods.push(AuthMethod::OAuth);
780 self
781 }
782
783 pub fn with_ldap(mut self, config: LdapConfig) -> Self {
785 self.config.ldap = Some(config);
786 self.config.auth_methods.push(AuthMethod::Ldap);
787 self
788 }
789
790 pub fn with_api_keys(mut self, config: ApiKeyConfig) -> Self {
792 self.config.api_keys = Some(config);
793 self.config.auth_methods.push(AuthMethod::ApiKey);
794 self
795 }
796
797 pub fn with_basic_auth(mut self) -> Self {
799 self.config.auth_methods.push(AuthMethod::Basic);
800 self
801 }
802
803 pub fn with_trust_auth(mut self) -> Self {
805 self.config.auth_methods.push(AuthMethod::Trust);
806 self
807 }
808
809 pub fn default_role(mut self, role: impl Into<String>) -> Self {
811 self.config.default_role = Some(role.into());
812 self
813 }
814
815 pub fn build(self) -> AuthenticationHandler {
817 AuthenticationHandler::new(self.config)
818 }
819}
820
821impl Default for AuthenticationHandlerBuilder {
822 fn default() -> Self {
823 Self::new()
824 }
825}
826
827fn base64_decode(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
829 use base64::{engine::general_purpose::STANDARD, Engine};
830 STANDARD.decode(input)
831}
832
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled configuration whose only method is Trust.
    fn test_config() -> AuthConfig {
        let mut cfg = AuthConfig::default();
        cfg.enabled = true;
        cfg.auth_methods = vec![AuthMethod::Trust];
        cfg
    }

    /// With authentication disabled every request is anonymous.
    #[tokio::test]
    async fn test_authentication_disabled() {
        let mut cfg = AuthConfig::default();
        cfg.enabled = false;
        let handler = AuthenticationHandler::new(cfg);

        let outcome = handler.authenticate(&AuthRequest::new()).await.unwrap();

        assert_eq!(outcome.identity.auth_method, "anonymous");
    }

    /// Trust authentication echoes the supplied user name.
    #[tokio::test]
    async fn test_trust_authentication() {
        let handler = AuthenticationHandler::new(test_config());
        let request = AuthRequest::new().with_username("testuser");

        let outcome = handler.authenticate(&request).await.unwrap();

        assert_eq!(outcome.identity.user_id, "testuser");
        assert_eq!(outcome.identity.auth_method, "trust");
    }

    /// Builder-style setters populate the matching fields.
    #[test]
    fn test_auth_request_builder() {
        let request = AuthRequest::new()
            .with_username("user")
            .with_password("pass")
            .with_database("mydb")
            .with_header("Authorization", "Bearer token123");

        assert_eq!(request.username.as_deref(), Some("user"));
        assert_eq!(request.password.as_deref(), Some("pass"));
        assert_eq!(request.database.as_deref(), Some("mydb"));
        assert_eq!(request.bearer_token(), Some("token123"));
    }

    /// The bearer token is the part after `Bearer `.
    #[test]
    fn test_bearer_token_extraction() {
        let request = AuthRequest::new().with_header("Authorization", "Bearer my-jwt-token");

        assert_eq!(request.bearer_token(), Some("my-jwt-token"));
    }

    /// The API key is read from the named header.
    #[test]
    fn test_api_key_extraction() {
        let request = AuthRequest::new().with_header("X-API-Key", "secret-key-123");

        assert_eq!(request.api_key("X-API-Key"), Some("secret-key-123"));
    }

    /// A registered key authenticates as the identity it was registered with.
    #[tokio::test]
    async fn test_api_key_registration_and_auth() {
        let mut cfg = AuthConfig::default();
        cfg.enabled = true;
        cfg.api_keys = Some(ApiKeyConfig {
            header_name: "X-API-Key".to_string(),
            query_param: None,
            prefix: None,
            hash_algorithm: "sha256".to_string(),
        });
        cfg.auth_methods = vec![AuthMethod::ApiKey];

        let handler = AuthenticationHandler::new(cfg);

        let api_identity = Identity {
            user_id: "api_user".to_string(),
            name: Some("API User".to_string()),
            email: None,
            roles: vec!["api".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "api_key".to_string(),
            authenticated_at: chrono::Utc::now(),
        };

        handler.register_api_key(
            "key1".to_string(),
            "secret123".to_string(),
            api_identity,
            None,
            vec!["read".to_string()],
        );

        let request = AuthRequest::new().with_header("X-API-Key", "secret123");
        let outcome = handler.authenticate(&request).await.unwrap();

        assert_eq!(outcome.identity.user_id, "api_user");
    }

    /// A fresh handler has an empty cache with the default capacity.
    #[test]
    fn test_cache_stats() {
        let stats = AuthenticationHandler::new(test_config()).cache_stats();

        assert_eq!(stats.entries, 0);
        assert_eq!(stats.max_size, 1000);
    }

    /// The builder produces an enabled handler when asked to.
    #[test]
    fn test_handler_builder() {
        let handler = AuthenticationHandler::builder()
            .enabled(true)
            .with_trust_auth()
            .default_role("user")
            .build();

        assert!(handler.is_enabled());
    }
}