1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use parking_lot::RwLock;
11use thiserror::Error;
12
13use super::config::CredentialConfig;
14
15#[derive(Debug, Error)]
17pub enum CredentialError {
18 #[error("Credential not found: {0}")]
19 NotFound(String),
20
21 #[error("Provider unavailable: {0}")]
22 ProviderUnavailable(String),
23
24 #[error("Access denied: {0}")]
25 AccessDenied(String),
26
27 #[error("Invalid credential format: {0}")]
28 InvalidFormat(String),
29
30 #[error("Credential expired")]
31 Expired,
32
33 #[error("Network error: {0}")]
34 NetworkError(String),
35
36 #[error("Configuration error: {0}")]
37 ConfigurationError(String),
38}
39
40#[derive(Debug, Clone)]
42pub struct DatabaseCredential {
43 pub username: String,
45
46 pub password: String,
48
49 pub database: Option<String>,
51
52 pub host: Option<String>,
54
55 pub port: Option<u16>,
57
58 pub options: HashMap<String, String>,
60
61 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
63
64 pub source: CredentialSource,
66}
67
68impl DatabaseCredential {
69 pub fn is_expired(&self) -> bool {
71 self.expires_at
72 .map(|exp| chrono::Utc::now() > exp)
73 .unwrap_or(false)
74 }
75
76 pub fn time_until_expiration(&self) -> Option<Duration> {
78 self.expires_at.and_then(|exp| {
79 let now = chrono::Utc::now();
80 if exp > now {
81 Some((exp - now).to_std().unwrap_or(Duration::ZERO))
82 } else {
83 None
84 }
85 })
86 }
87
88 pub fn connection_string(&self) -> String {
90 let host = self.host.as_deref().unwrap_or("localhost");
91 let port = self.port.unwrap_or(5432);
92 let database = self.database.as_deref().unwrap_or("postgres");
93
94 format!(
95 "postgresql://{}:{}@{}:{}/{}",
96 self.username,
97 self.password,
98 host,
99 port,
100 database
101 )
102 }
103}
104
/// Identifies the kind of provider a [`DatabaseCredential`] was obtained from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CredentialSource {
    /// Supplied directly by the application.
    Static,
    /// Read from process environment variables.
    Environment,
    /// HashiCorp Vault.
    Vault,
    /// AWS Secrets Manager.
    AwsSecretsManager,
    /// Azure Key Vault.
    AzureKeyVault,
    /// Google Cloud Secret Manager.
    GcpSecretManager,
    /// Kubernetes secrets.
    Kubernetes,
    /// Any other provider, identified by a free-form name.
    Custom(String),
}
132
133pub trait CredentialProvider: Send + Sync {
135 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;
137
138 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;
140
141 fn list_credentials(&self) -> Result<Vec<String>, CredentialError>;
143
144 fn provider_name(&self) -> &str;
146}
147
148pub struct CredentialManager {
150 config: CredentialConfig,
152
153 providers: Vec<Box<dyn CredentialProvider>>,
155
156 cache: Arc<RwLock<CredentialCache>>,
158
159 default_provider: usize,
161}
162
163struct CredentialCache {
165 entries: HashMap<String, CachedCredential>,
166 max_size: usize,
167 default_ttl: Duration,
168}
169
170struct CachedCredential {
171 credential: DatabaseCredential,
172 cached_at: Instant,
173 ttl: Duration,
174}
175
176impl CredentialCache {
177 fn new(max_size: usize, default_ttl: Duration) -> Self {
178 Self {
179 entries: HashMap::new(),
180 max_size,
181 default_ttl,
182 }
183 }
184
185 fn get(&self, key: &str) -> Option<&DatabaseCredential> {
186 self.entries.get(key).and_then(|cached| {
187 if cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired() {
188 Some(&cached.credential)
189 } else {
190 None
191 }
192 })
193 }
194
195 fn insert(&mut self, key: String, credential: DatabaseCredential, ttl: Option<Duration>) {
196 if self.entries.len() >= self.max_size {
197 self.evict_expired();
198 }
199
200 let ttl = ttl.unwrap_or(self.default_ttl);
201 self.entries.insert(key, CachedCredential {
202 credential,
203 cached_at: Instant::now(),
204 ttl,
205 });
206 }
207
208 fn evict_expired(&mut self) {
209 self.entries.retain(|_, cached| {
210 cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired()
211 });
212 }
213
214 fn invalidate(&mut self, key: &str) {
215 self.entries.remove(key);
216 }
217
218 fn clear(&mut self) {
219 self.entries.clear();
220 }
221}
222
223impl CredentialManager {
224 pub fn new(config: CredentialConfig) -> Self {
226 let cache_ttl = config.cache_ttl;
227
228 Self {
229 config,
230 providers: Vec::new(),
231 cache: Arc::new(RwLock::new(CredentialCache::new(1000, cache_ttl))),
232 default_provider: 0,
233 }
234 }
235
236 pub fn builder() -> CredentialManagerBuilder {
238 CredentialManagerBuilder::new()
239 }
240
241 pub fn add_provider(&mut self, provider: Box<dyn CredentialProvider>) {
243 self.providers.push(provider);
244 }
245
246 pub fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
248 if let Some(cached) = self.cache.read().get(key) {
250 return Ok(cached.clone());
251 }
252
253 for provider in &self.providers {
255 match provider.get_credential(key) {
256 Ok(credential) => {
257 let ttl = credential.time_until_expiration()
259 .map(|d| d.min(self.config.cache_ttl))
260 .or(Some(self.config.cache_ttl));
261
262 self.cache.write().insert(key.to_string(), credential.clone(), ttl);
264 return Ok(credential);
265 }
266 Err(CredentialError::NotFound(_)) => continue,
267 Err(e) => return Err(e),
268 }
269 }
270
271 Err(CredentialError::NotFound(key.to_string()))
272 }
273
274 pub fn get_credential_from(
276 &self,
277 key: &str,
278 provider_name: &str,
279 ) -> Result<DatabaseCredential, CredentialError> {
280 let provider = self.providers
281 .iter()
282 .find(|p| p.provider_name() == provider_name)
283 .ok_or_else(|| CredentialError::ProviderUnavailable(provider_name.to_string()))?;
284
285 provider.get_credential(key)
286 }
287
288 pub fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
290 self.cache.write().invalidate(key);
292
293 for provider in &self.providers {
295 match provider.refresh_credential(key) {
296 Ok(credential) => {
297 let ttl = credential.time_until_expiration()
298 .map(|d| d.min(self.config.cache_ttl))
299 .or(Some(self.config.cache_ttl));
300
301 self.cache.write().insert(key.to_string(), credential.clone(), ttl);
302 return Ok(credential);
303 }
304 Err(CredentialError::NotFound(_)) => continue,
305 Err(e) => return Err(e),
306 }
307 }
308
309 Err(CredentialError::NotFound(key.to_string()))
310 }
311
312 pub fn list_credentials(&self) -> Vec<(String, String)> {
314 let mut result = Vec::new();
315
316 for provider in &self.providers {
317 if let Ok(keys) = provider.list_credentials() {
318 for key in keys {
319 result.push((key, provider.provider_name().to_string()));
320 }
321 }
322 }
323
324 result
325 }
326
327 pub fn invalidate(&self, key: &str) {
329 self.cache.write().invalidate(key);
330 }
331
332 pub fn clear_cache(&self) {
334 self.cache.write().clear();
335 }
336
337 pub fn cache_stats(&self) -> CacheStats {
339 let cache = self.cache.read();
340 CacheStats {
341 entries: cache.entries.len(),
342 max_size: cache.max_size,
343 }
344 }
345}
346
/// Point-in-time view of the credential cache's occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored (fresh or stale).
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}
353
354pub struct CredentialManagerBuilder {
356 config: CredentialConfig,
357 providers: Vec<Box<dyn CredentialProvider>>,
358}
359
360impl CredentialManagerBuilder {
361 pub fn new() -> Self {
363 Self {
364 config: CredentialConfig::default(),
365 providers: Vec::new(),
366 }
367 }
368
369 pub fn cache_ttl(mut self, ttl: Duration) -> Self {
371 self.config.cache_ttl = ttl;
372 self
373 }
374
375 pub fn with_static_credentials(mut self, credentials: HashMap<String, DatabaseCredential>) -> Self {
377 self.providers.push(Box::new(StaticCredentialProvider::new(credentials)));
378 self
379 }
380
381 pub fn with_environment(mut self, prefix: &str) -> Self {
383 self.providers.push(Box::new(EnvironmentCredentialProvider::new(prefix)));
384 self
385 }
386
387 pub fn with_vault(mut self, address: &str, token: &str, mount: &str) -> Self {
389 self.providers.push(Box::new(VaultCredentialProvider::new(address, token, mount)));
390 self
391 }
392
393 pub fn with_aws_secrets_manager(mut self, region: &str) -> Self {
395 self.providers.push(Box::new(AwsSecretsManagerProvider::new(region)));
396 self
397 }
398
399 pub fn with_provider(mut self, provider: Box<dyn CredentialProvider>) -> Self {
401 self.providers.push(provider);
402 self
403 }
404
405 pub fn build(self) -> CredentialManager {
407 let mut manager = CredentialManager::new(self.config);
408 for provider in self.providers {
409 manager.add_provider(provider);
410 }
411 manager
412 }
413}
414
415impl Default for CredentialManagerBuilder {
416 fn default() -> Self {
417 Self::new()
418 }
419}
420
421pub struct StaticCredentialProvider {
423 credentials: HashMap<String, DatabaseCredential>,
424}
425
426impl StaticCredentialProvider {
427 pub fn new(credentials: HashMap<String, DatabaseCredential>) -> Self {
429 Self { credentials }
430 }
431
432 pub fn add(&mut self, key: String, credential: DatabaseCredential) {
434 self.credentials.insert(key, credential);
435 }
436}
437
438impl CredentialProvider for StaticCredentialProvider {
439 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
440 self.credentials
441 .get(key)
442 .cloned()
443 .ok_or_else(|| CredentialError::NotFound(key.to_string()))
444 }
445
446 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
447 self.get_credential(key)
448 }
449
450 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
451 Ok(self.credentials.keys().cloned().collect())
452 }
453
454 fn provider_name(&self) -> &str {
455 "static"
456 }
457}
458
/// Reads credentials from process environment variables named `<prefix>_<KEY>_<FIELD>`.
pub struct EnvironmentCredentialProvider {
    /// Leading component of every variable name this provider reads.
    prefix: String,
}
463
464impl EnvironmentCredentialProvider {
465 pub fn new(prefix: &str) -> Self {
467 Self {
468 prefix: prefix.to_string(),
469 }
470 }
471
472 fn var_name(&self, key: &str, suffix: &str) -> String {
473 format!("{}_{}{}", self.prefix, key.to_uppercase(), suffix)
474 }
475}
476
477impl CredentialProvider for EnvironmentCredentialProvider {
478 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
479 let username = std::env::var(self.var_name(key, "_USERNAME"))
480 .or_else(|_| std::env::var(self.var_name(key, "_USER")))
481 .map_err(|_| CredentialError::NotFound(key.to_string()))?;
482
483 let password = std::env::var(self.var_name(key, "_PASSWORD"))
484 .or_else(|_| std::env::var(self.var_name(key, "_PASS")))
485 .map_err(|_| CredentialError::NotFound(format!("{}_PASSWORD", key)))?;
486
487 let database = std::env::var(self.var_name(key, "_DATABASE")).ok();
488 let host = std::env::var(self.var_name(key, "_HOST")).ok();
489 let port = std::env::var(self.var_name(key, "_PORT"))
490 .ok()
491 .and_then(|p| p.parse().ok());
492
493 Ok(DatabaseCredential {
494 username,
495 password,
496 database,
497 host,
498 port,
499 options: HashMap::new(),
500 expires_at: None,
501 source: CredentialSource::Environment,
502 })
503 }
504
505 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
506 self.get_credential(key)
507 }
508
509 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
510 let mut keys = Vec::new();
512 let prefix_upper = format!("{}_", self.prefix.to_uppercase());
513
514 for (key, _) in std::env::vars() {
515 if key.starts_with(&prefix_upper) && key.ends_with("_USERNAME") {
516 let name = key
517 .strip_prefix(&prefix_upper)
518 .and_then(|s| s.strip_suffix("_USERNAME"))
519 .map(|s| s.to_lowercase());
520 if let Some(name) = name {
521 keys.push(name);
522 }
523 }
524 }
525
526 Ok(keys)
527 }
528
529 fn provider_name(&self) -> &str {
530 "environment"
531 }
532}
533
/// Connection settings for a HashiCorp Vault credential backend.
pub struct VaultCredentialProvider {
    /// Vault server address.
    address: String,
    /// Authentication token presented to Vault.
    token: String,
    /// Secrets engine mount path.
    mount: String,
}
540
541impl VaultCredentialProvider {
542 pub fn new(address: &str, token: &str, mount: &str) -> Self {
544 Self {
545 address: address.to_string(),
546 token: token.to_string(),
547 mount: mount.to_string(),
548 }
549 }
550}
551
552impl CredentialProvider for VaultCredentialProvider {
553 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
554 let _ = (key, &self.address, &self.token, &self.mount);
563
564 Ok(DatabaseCredential {
565 username: format!("vault_user_{}", key),
566 password: "vault_generated_password".to_string(),
567 database: None,
568 host: None,
569 port: None,
570 options: HashMap::new(),
571 expires_at: Some(chrono::Utc::now() + chrono::Duration::hours(1)),
572 source: CredentialSource::Vault,
573 })
574 }
575
576 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
577 self.get_credential(key)
579 }
580
581 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
582 Ok(Vec::new())
584 }
585
586 fn provider_name(&self) -> &str {
587 "vault"
588 }
589}
590
/// Settings for an AWS Secrets Manager credential backend.
pub struct AwsSecretsManagerProvider {
    /// AWS region identifier.
    region: String,
}
595
596impl AwsSecretsManagerProvider {
597 pub fn new(region: &str) -> Self {
599 Self {
600 region: region.to_string(),
601 }
602 }
603}
604
605impl CredentialProvider for AwsSecretsManagerProvider {
606 fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
607 let _ = (key, &self.region);
616
617 Ok(DatabaseCredential {
618 username: format!("aws_user_{}", key),
619 password: "aws_managed_password".to_string(),
620 database: None,
621 host: None,
622 port: None,
623 options: HashMap::new(),
624 expires_at: None,
625 source: CredentialSource::AwsSecretsManager,
626 })
627 }
628
629 fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
630 self.get_credential(key)
631 }
632
633 fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
634 Ok(Vec::new())
636 }
637
638 fn provider_name(&self) -> &str {
639 "aws_secrets_manager"
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-expiring static credential used as the fixture throughout these tests.
    fn test_credential() -> DatabaseCredential {
        DatabaseCredential {
            username: String::from("testuser"),
            password: String::from("testpass"),
            database: Some(String::from("testdb")),
            host: Some(String::from("localhost")),
            port: Some(5432),
            options: HashMap::new(),
            expires_at: None,
            source: CredentialSource::Static,
        }
    }

    /// Map from each of `keys` to a copy of the fixture credential.
    fn fixtures(keys: &[&str]) -> HashMap<String, DatabaseCredential> {
        keys.iter()
            .map(|key| (key.to_string(), test_credential()))
            .collect()
    }

    #[test]
    fn test_static_provider() {
        let provider = StaticCredentialProvider::new(fixtures(&["db1"]));

        assert_eq!(provider.get_credential("db1").unwrap().username, "testuser");
        assert!(provider.get_credential("db2").is_err());
    }

    #[test]
    fn test_connection_string() {
        let conn_str = test_credential().connection_string();

        for needle in ["testuser", "testpass", "localhost", "5432", "testdb"] {
            assert!(conn_str.contains(needle), "missing {:?} in {}", needle, conn_str);
        }
    }

    #[test]
    fn test_credential_expiration() {
        let mut cred = test_credential();
        let hour = chrono::Duration::hours(1);

        cred.expires_at = Some(chrono::Utc::now() + hour);
        assert!(!cred.is_expired());
        assert!(cred.time_until_expiration().is_some());

        cred.expires_at = Some(chrono::Utc::now() - hour);
        assert!(cred.is_expired());
        assert!(cred.time_until_expiration().is_none());
    }

    #[test]
    fn test_credential_manager() {
        let manager = CredentialManager::builder()
            .cache_ttl(Duration::from_secs(60))
            .with_static_credentials(fixtures(&["primary"]))
            .build();

        // The first call populates the cache; the second is served from it.
        for _ in 0..2 {
            assert_eq!(manager.get_credential("primary").unwrap().username, "testuser");
        }

        assert_eq!(manager.cache_stats().entries, 1);
    }

    #[test]
    fn test_list_credentials() {
        let manager = CredentialManager::builder()
            .with_static_credentials(fixtures(&["db1", "db2"]))
            .build();

        assert_eq!(manager.list_credentials().len(), 2);
    }

    #[test]
    fn test_cache_invalidation() {
        let manager = CredentialManager::builder()
            .with_static_credentials(fixtures(&["db1"]))
            .build();

        let _ = manager.get_credential("db1").unwrap();
        assert_eq!(manager.cache_stats().entries, 1);

        manager.invalidate("db1");
        assert_eq!(manager.cache_stats().entries, 0);
    }

    #[test]
    fn test_credential_source() {
        assert_eq!(CredentialSource::Static, CredentialSource::Static);
        assert_ne!(CredentialSource::Vault, CredentialSource::Environment);
    }
}