1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use parking_lot::RwLock;
11use thiserror::Error;
12
13use super::config::CredentialConfig;
14
/// Errors produced while looking up or refreshing database credentials.
#[derive(Debug, Error)]
pub enum CredentialError {
    /// No credential is available for the requested key.
    ///
    /// `CredentialManager` treats this variant as "try the next provider"
    /// rather than as a hard failure.
    #[error("Credential not found: {0}")]
    NotFound(String),

    /// A provider could not be used. `CredentialManager::get_credential_from`
    /// returns this when no registered provider has the requested name.
    #[error("Provider unavailable: {0}")]
    ProviderUnavailable(String),

    /// Access to the credential was denied; the payload carries the detail text.
    #[error("Access denied: {0}")]
    AccessDenied(String),

    /// The credential data was present but malformed.
    #[error("Invalid credential format: {0}")]
    InvalidFormat(String),

    /// The credential has passed its expiry time.
    #[error("Credential expired")]
    Expired,

    /// A network-level failure; the payload carries the detail text.
    #[error("Network error: {0}")]
    NetworkError(String),

    /// The credential subsystem is misconfigured.
    #[error("Configuration error: {0}")]
    ConfigurationError(String),
}
39
40#[derive(Debug, Clone)]
42pub struct DatabaseCredential {
43 pub username: String,
45
46 pub password: String,
48
49 pub database: Option<String>,
51
52 pub host: Option<String>,
54
55 pub port: Option<u16>,
57
58 pub options: HashMap<String, String>,
60
61 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
63
64 pub source: CredentialSource,
66}
67
68impl DatabaseCredential {
69 pub fn is_expired(&self) -> bool {
71 self.expires_at
72 .map(|exp| chrono::Utc::now() > exp)
73 .unwrap_or(false)
74 }
75
76 pub fn time_until_expiration(&self) -> Option<Duration> {
78 self.expires_at.and_then(|exp| {
79 let now = chrono::Utc::now();
80 if exp > now {
81 Some((exp - now).to_std().unwrap_or(Duration::ZERO))
82 } else {
83 None
84 }
85 })
86 }
87
88 pub fn connection_string(&self) -> String {
90 let host = self.host.as_deref().unwrap_or("localhost");
91 let port = self.port.unwrap_or(5432);
92 let database = self.database.as_deref().unwrap_or("postgres");
93
94 format!(
95 "postgresql://{}:{}@{}:{}/{}",
96 self.username, self.password, host, port, database
97 )
98 }
99}
100
/// Identifies the kind of backend a `DatabaseCredential` came from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CredentialSource {
    /// Statically configured credential.
    Static,

    /// Read from process environment variables.
    Environment,

    /// Issued by the Vault provider.
    Vault,

    /// Issued by the AWS Secrets Manager provider.
    AwsSecretsManager,

    /// Azure Key Vault (no provider for this source is visible in this file).
    AzureKeyVault,

    /// GCP Secret Manager (no provider for this source is visible in this file).
    GcpSecretManager,

    /// Kubernetes (no provider for this source is visible in this file).
    Kubernetes,

    /// Any other source, identified by a free-form name.
    Custom(String),
}
128
/// A backend that can supply database credentials by key.
///
/// Implementations must be `Send + Sync` so a manager holding them can be
/// shared across threads.
pub trait CredentialProvider: Send + Sync {
    /// Looks up the credential stored under `key`.
    ///
    /// Return `CredentialError::NotFound` when the key is unknown;
    /// `CredentialManager` then moves on to the next provider.
    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;

    /// Fetches the credential for `key` again. `CredentialManager` calls this
    /// after dropping its own cached copy.
    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;

    /// Lists the keys this provider can report.
    fn list_credentials(&self) -> Result<Vec<String>, CredentialError>;

    /// Name of this provider; `CredentialManager::get_credential_from`
    /// matches against it to select a provider.
    fn provider_name(&self) -> &str;
}
143
/// Resolves credentials through an ordered list of providers, keeping
/// successful lookups in a shared cache.
pub struct CredentialManager {
    /// Settings for the manager; `cache_ttl` is read from here.
    config: CredentialConfig,

    /// Registered providers, consulted in insertion order.
    providers: Vec<Box<dyn CredentialProvider>>,

    /// Cache of previously resolved credentials, guarded by a read/write lock.
    cache: Arc<RwLock<CredentialCache>>,

    /// Index of the default provider. It is initialised to `0` by `new` and
    /// not read anywhere in this file, hence the lint suppression below.
    #[allow(dead_code)]
    default_provider: usize,
}
159
/// In-memory store of cached credentials keyed by credential name.
struct CredentialCache {
    /// Cached entries by key.
    entries: HashMap<String, CachedCredential>,
    /// Entry count at which `insert` starts trying to make room.
    max_size: usize,
    /// TTL applied when `insert` is called without an explicit one.
    default_ttl: Duration,
}
166
/// A single cache entry: the credential plus the bookkeeping needed to
/// decide when it goes stale.
struct CachedCredential {
    /// The cached credential itself.
    credential: DatabaseCredential,
    /// When the entry was stored.
    cached_at: Instant,
    /// How long after `cached_at` the entry may still be served.
    ttl: Duration,
}
172
173impl CredentialCache {
174 fn new(max_size: usize, default_ttl: Duration) -> Self {
175 Self {
176 entries: HashMap::new(),
177 max_size,
178 default_ttl,
179 }
180 }
181
182 fn get(&self, key: &str) -> Option<&DatabaseCredential> {
183 self.entries.get(key).and_then(|cached| {
184 if cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired() {
185 Some(&cached.credential)
186 } else {
187 None
188 }
189 })
190 }
191
192 fn insert(&mut self, key: String, credential: DatabaseCredential, ttl: Option<Duration>) {
193 if self.entries.len() >= self.max_size {
194 self.evict_expired();
195 }
196
197 let ttl = ttl.unwrap_or(self.default_ttl);
198 self.entries.insert(
199 key,
200 CachedCredential {
201 credential,
202 cached_at: Instant::now(),
203 ttl,
204 },
205 );
206 }
207
208 fn evict_expired(&mut self) {
209 self.entries.retain(|_, cached| {
210 cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired()
211 });
212 }
213
214 fn invalidate(&mut self, key: &str) {
215 self.entries.remove(key);
216 }
217
218 fn clear(&mut self) {
219 self.entries.clear();
220 }
221}
222
223impl CredentialManager {
224 pub fn new(config: CredentialConfig) -> Self {
226 let cache_ttl = config.cache_ttl;
227
228 Self {
229 config,
230 providers: Vec::new(),
231 cache: Arc::new(RwLock::new(CredentialCache::new(1000, cache_ttl))),
232 default_provider: 0,
233 }
234 }
235
236 pub fn builder() -> CredentialManagerBuilder {
238 CredentialManagerBuilder::new()
239 }
240
241 pub fn add_provider(&mut self, provider: Box<dyn CredentialProvider>) {
243 self.providers.push(provider);
244 }
245
246 pub fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
248 if let Some(cached) = self.cache.read().get(key) {
250 return Ok(cached.clone());
251 }
252
253 for provider in &self.providers {
255 match provider.get_credential(key) {
256 Ok(credential) => {
257 let ttl = credential
259 .time_until_expiration()
260 .map(|d| d.min(self.config.cache_ttl))
261 .or(Some(self.config.cache_ttl));
262
263 self.cache
265 .write()
266 .insert(key.to_string(), credential.clone(), ttl);
267 return Ok(credential);
268 }
269 Err(CredentialError::NotFound(_)) => continue,
270 Err(e) => return Err(e),
271 }
272 }
273
274 Err(CredentialError::NotFound(key.to_string()))
275 }
276
277 pub fn get_credential_from(
279 &self,
280 key: &str,
281 provider_name: &str,
282 ) -> Result<DatabaseCredential, CredentialError> {
283 let provider = self
284 .providers
285 .iter()
286 .find(|p| p.provider_name() == provider_name)
287 .ok_or_else(|| CredentialError::ProviderUnavailable(provider_name.to_string()))?;
288
289 provider.get_credential(key)
290 }
291
292 pub fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
294 self.cache.write().invalidate(key);
296
297 for provider in &self.providers {
299 match provider.refresh_credential(key) {
300 Ok(credential) => {
301 let ttl = credential
302 .time_until_expiration()
303 .map(|d| d.min(self.config.cache_ttl))
304 .or(Some(self.config.cache_ttl));
305
306 self.cache
307 .write()
308 .insert(key.to_string(), credential.clone(), ttl);
309 return Ok(credential);
310 }
311 Err(CredentialError::NotFound(_)) => continue,
312 Err(e) => return Err(e),
313 }
314 }
315
316 Err(CredentialError::NotFound(key.to_string()))
317 }
318
319 pub fn list_credentials(&self) -> Vec<(String, String)> {
321 let mut result = Vec::new();
322
323 for provider in &self.providers {
324 if let Ok(keys) = provider.list_credentials() {
325 for key in keys {
326 result.push((key, provider.provider_name().to_string()));
327 }
328 }
329 }
330
331 result
332 }
333
334 pub fn invalidate(&self, key: &str) {
336 self.cache.write().invalidate(key);
337 }
338
339 pub fn clear_cache(&self) {
341 self.cache.write().clear();
342 }
343
344 pub fn cache_stats(&self) -> CacheStats {
346 let cache = self.cache.read();
347 CacheStats {
348 entries: cache.entries.len(),
349 max_size: cache.max_size,
350 }
351 }
352}
353
/// Snapshot of cache occupancy returned by `CredentialManager::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored (fresh or stale).
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}
360
/// Fluent builder that collects configuration and providers for a
/// `CredentialManager`.
pub struct CredentialManagerBuilder {
    /// Configuration handed to the manager on `build`.
    config: CredentialConfig,
    /// Providers to register, in the order they were added.
    providers: Vec<Box<dyn CredentialProvider>>,
}
366
367impl CredentialManagerBuilder {
368 pub fn new() -> Self {
370 Self {
371 config: CredentialConfig::default(),
372 providers: Vec::new(),
373 }
374 }
375
376 pub fn cache_ttl(mut self, ttl: Duration) -> Self {
378 self.config.cache_ttl = ttl;
379 self
380 }
381
382 pub fn with_static_credentials(
384 mut self,
385 credentials: HashMap<String, DatabaseCredential>,
386 ) -> Self {
387 self.providers
388 .push(Box::new(StaticCredentialProvider::new(credentials)));
389 self
390 }
391
392 pub fn with_environment(mut self, prefix: &str) -> Self {
394 self.providers
395 .push(Box::new(EnvironmentCredentialProvider::new(prefix)));
396 self
397 }
398
399 pub fn with_vault(mut self, address: &str, token: &str, mount: &str) -> Self {
401 self.providers.push(Box::new(VaultCredentialProvider::new(
402 address, token, mount,
403 )));
404 self
405 }
406
407 pub fn with_aws_secrets_manager(mut self, region: &str) -> Self {
409 self.providers
410 .push(Box::new(AwsSecretsManagerProvider::new(region)));
411 self
412 }
413
414 pub fn with_provider(mut self, provider: Box<dyn CredentialProvider>) -> Self {
416 self.providers.push(provider);
417 self
418 }
419
420 pub fn build(self) -> CredentialManager {
422 let mut manager = CredentialManager::new(self.config);
423 for provider in self.providers {
424 manager.add_provider(provider);
425 }
426 manager
427 }
428}
429
430impl Default for CredentialManagerBuilder {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
/// Provider backed by an in-memory map of credentials.
pub struct StaticCredentialProvider {
    /// Credentials by key.
    credentials: HashMap<String, DatabaseCredential>,
}
440
441impl StaticCredentialProvider {
442 pub fn new(credentials: HashMap<String, DatabaseCredential>) -> Self {
444 Self { credentials }
445 }
446
447 pub fn add(&mut self, key: String, credential: DatabaseCredential) {
449 self.credentials.insert(key, credential);
450 }
451}
452
453impl CredentialProvider for StaticCredentialProvider {
454 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
455 self.credentials
456 .get(key)
457 .cloned()
458 .ok_or_else(|| CredentialError::NotFound(key.to_string()))
459 }
460
461 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
462 self.get_credential(key)
463 }
464
465 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
466 Ok(self.credentials.keys().cloned().collect())
467 }
468
469 fn provider_name(&self) -> &str {
470 "static"
471 }
472}
473
/// Provider that reads credentials from process environment variables.
pub struct EnvironmentCredentialProvider {
    /// Leading part of every variable name this provider reads.
    prefix: String,
}
478
479impl EnvironmentCredentialProvider {
480 pub fn new(prefix: &str) -> Self {
482 Self {
483 prefix: prefix.to_string(),
484 }
485 }
486
487 fn var_name(&self, key: &str, suffix: &str) -> String {
488 format!("{}_{}{}", self.prefix, key.to_uppercase(), suffix)
489 }
490}
491
492impl CredentialProvider for EnvironmentCredentialProvider {
493 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
494 let username = std::env::var(self.var_name(key, "_USERNAME"))
495 .or_else(|_| std::env::var(self.var_name(key, "_USER")))
496 .map_err(|_| CredentialError::NotFound(key.to_string()))?;
497
498 let password = std::env::var(self.var_name(key, "_PASSWORD"))
499 .or_else(|_| std::env::var(self.var_name(key, "_PASS")))
500 .map_err(|_| CredentialError::NotFound(format!("{}_PASSWORD", key)))?;
501
502 let database = std::env::var(self.var_name(key, "_DATABASE")).ok();
503 let host = std::env::var(self.var_name(key, "_HOST")).ok();
504 let port = std::env::var(self.var_name(key, "_PORT"))
505 .ok()
506 .and_then(|p| p.parse().ok());
507
508 Ok(DatabaseCredential {
509 username,
510 password,
511 database,
512 host,
513 port,
514 options: HashMap::new(),
515 expires_at: None,
516 source: CredentialSource::Environment,
517 })
518 }
519
520 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
521 self.get_credential(key)
522 }
523
524 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
525 let mut keys = Vec::new();
527 let prefix_upper = format!("{}_", self.prefix.to_uppercase());
528
529 for (key, _) in std::env::vars() {
530 if key.starts_with(&prefix_upper) && key.ends_with("_USERNAME") {
531 let name = key
532 .strip_prefix(&prefix_upper)
533 .and_then(|s| s.strip_suffix("_USERNAME"))
534 .map(|s| s.to_lowercase());
535 if let Some(name) = name {
536 keys.push(name);
537 }
538 }
539 }
540
541 Ok(keys)
542 }
543
544 fn provider_name(&self) -> &str {
545 "environment"
546 }
547}
548
/// Provider intended to obtain credentials from Vault.
pub struct VaultCredentialProvider {
    /// Vault server address, stored as given to `new`.
    address: String,
    /// Vault token, stored as given to `new`.
    token: String,
    /// Vault mount, stored as given to `new`.
    mount: String,
}
555
556impl VaultCredentialProvider {
557 pub fn new(address: &str, token: &str, mount: &str) -> Self {
559 Self {
560 address: address.to_string(),
561 token: token.to_string(),
562 mount: mount.to_string(),
563 }
564 }
565}
566
567impl CredentialProvider for VaultCredentialProvider {
568 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
569 let _ = (key, &self.address, &self.token, &self.mount);
578
579 Ok(DatabaseCredential {
580 username: format!("vault_user_{}", key),
581 password: "vault_generated_password".to_string(),
582 database: None,
583 host: None,
584 port: None,
585 options: HashMap::new(),
586 expires_at: Some(chrono::Utc::now() + chrono::Duration::hours(1)),
587 source: CredentialSource::Vault,
588 })
589 }
590
591 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
592 self.get_credential(key)
594 }
595
596 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
597 Ok(Vec::new())
599 }
600
601 fn provider_name(&self) -> &str {
602 "vault"
603 }
604}
605
/// Provider intended to obtain credentials from AWS Secrets Manager.
pub struct AwsSecretsManagerProvider {
    /// AWS region, stored as given to `new`.
    region: String,
}
610
611impl AwsSecretsManagerProvider {
612 pub fn new(region: &str) -> Self {
614 Self {
615 region: region.to_string(),
616 }
617 }
618}
619
620impl CredentialProvider for AwsSecretsManagerProvider {
621 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
622 let _ = (key, &self.region);
631
632 Ok(DatabaseCredential {
633 username: format!("aws_user_{}", key),
634 password: "aws_managed_password".to_string(),
635 database: None,
636 host: None,
637 port: None,
638 options: HashMap::new(),
639 expires_at: None,
640 source: CredentialSource::AwsSecretsManager,
641 })
642 }
643
644 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
645 self.get_credential(key)
646 }
647
648 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
649 Ok(Vec::new())
651 }
652
653 fn provider_name(&self) -> &str {
654 "aws_secrets_manager"
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-expiring static credential shared by the tests below.
    fn test_credential() -> DatabaseCredential {
        DatabaseCredential {
            username: String::from("testuser"),
            password: String::from("testpass"),
            database: Some(String::from("testdb")),
            host: Some(String::from("localhost")),
            port: Some(5432),
            options: HashMap::new(),
            expires_at: None,
            source: CredentialSource::Static,
        }
    }

    /// Maps every name in `keys` to a copy of `test_credential()`.
    fn static_map(keys: &[&str]) -> HashMap<String, DatabaseCredential> {
        keys.iter()
            .map(|name| (name.to_string(), test_credential()))
            .collect()
    }

    #[test]
    fn test_static_provider() {
        let provider = StaticCredentialProvider::new(static_map(&["db1"]));

        // A known key resolves to the stored credential.
        let found = provider.get_credential("db1").unwrap();
        assert_eq!(found.username, "testuser");

        // An unknown key is an error.
        assert!(provider.get_credential("db2").is_err());
    }

    #[test]
    fn test_connection_string() {
        let conn_str = test_credential().connection_string();

        for fragment in &["testuser", "testpass", "localhost", "5432", "testdb"] {
            assert!(conn_str.contains(*fragment));
        }
    }

    #[test]
    fn test_credential_expiration() {
        let one_hour = chrono::Duration::hours(1);
        let mut cred = test_credential();

        // Expiry in the future: still valid, with time remaining.
        cred.expires_at = Some(chrono::Utc::now() + one_hour);
        assert!(!cred.is_expired());
        assert!(cred.time_until_expiration().is_some());

        // Expiry in the past: expired, with no time remaining.
        cred.expires_at = Some(chrono::Utc::now() - one_hour);
        assert!(cred.is_expired());
        assert!(cred.time_until_expiration().is_none());
    }

    #[test]
    fn test_credential_manager() {
        let manager = CredentialManager::builder()
            .cache_ttl(Duration::from_secs(60))
            .with_static_credentials(static_map(&["primary"]))
            .build();

        // First lookup goes to the provider.
        let first = manager.get_credential("primary").unwrap();
        assert_eq!(first.username, "testuser");

        // Second lookup is served from the cache.
        let second = manager.get_credential("primary").unwrap();
        assert_eq!(second.username, "testuser");

        // Exactly one entry was cached.
        let stats = manager.cache_stats();
        assert_eq!(stats.entries, 1);
    }

    #[test]
    fn test_list_credentials() {
        let manager = CredentialManager::builder()
            .with_static_credentials(static_map(&["db1", "db2"]))
            .build();

        let listed = manager.list_credentials();
        assert_eq!(listed.len(), 2);
    }

    #[test]
    fn test_cache_invalidation() {
        let manager = CredentialManager::builder()
            .with_static_credentials(static_map(&["db1"]))
            .build();

        // Populate the cache.
        let _ = manager.get_credential("db1").unwrap();
        assert_eq!(manager.cache_stats().entries, 1);

        // Invalidation removes the entry.
        manager.invalidate("db1");
        assert_eq!(manager.cache_stats().entries, 0);
    }

    #[test]
    fn test_credential_source() {
        assert_eq!(CredentialSource::Static, CredentialSource::Static);
        assert_ne!(CredentialSource::Vault, CredentialSource::Environment);
    }
}