// heliosdb_proxy/auth/credentials.rs

//! Credential Providers
//!
//! Fetches database credentials from static configuration, environment
//! variables, or external secret stores such as HashiCorp Vault and AWS
//! Secrets Manager (the Vault and AWS providers are currently placeholders).

use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant};

use parking_lot::RwLock;
use thiserror::Error;

use super::config::CredentialConfig;

15/// Credential errors
16#[derive(Debug, Error)]
17pub enum CredentialError {
18    #[error("Credential not found: {0}")]
19    NotFound(String),
20
21    #[error("Provider unavailable: {0}")]
22    ProviderUnavailable(String),
23
24    #[error("Access denied: {0}")]
25    AccessDenied(String),
26
27    #[error("Invalid credential format: {0}")]
28    InvalidFormat(String),
29
30    #[error("Credential expired")]
31    Expired,
32
33    #[error("Network error: {0}")]
34    NetworkError(String),
35
36    #[error("Configuration error: {0}")]
37    ConfigurationError(String),
38}
39
40/// Database credential
41#[derive(Debug, Clone)]
42pub struct DatabaseCredential {
43    /// Username
44    pub username: String,
45
46    /// Password
47    pub password: String,
48
49    /// Database name (optional)
50    pub database: Option<String>,
51
52    /// Host (optional, for connection routing)
53    pub host: Option<String>,
54
55    /// Port (optional)
56    pub port: Option<u16>,
57
58    /// Additional connection options
59    pub options: HashMap<String, String>,
60
61    /// Credential expiration
62    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
63
64    /// Source provider
65    pub source: CredentialSource,
66}
67
68impl DatabaseCredential {
69    /// Check if credential is expired
70    pub fn is_expired(&self) -> bool {
71        self.expires_at
72            .map(|exp| chrono::Utc::now() > exp)
73            .unwrap_or(false)
74    }
75
76    /// Get time until expiration
77    pub fn time_until_expiration(&self) -> Option<Duration> {
78        self.expires_at.and_then(|exp| {
79            let now = chrono::Utc::now();
80            if exp > now {
81                Some((exp - now).to_std().unwrap_or(Duration::ZERO))
82            } else {
83                None
84            }
85        })
86    }
87
88    /// Build connection string
89    pub fn connection_string(&self) -> String {
90        let host = self.host.as_deref().unwrap_or("localhost");
91        let port = self.port.unwrap_or(5432);
92        let database = self.database.as_deref().unwrap_or("postgres");
93
94        format!(
95            "postgresql://{}:{}@{}:{}/{}",
96            self.username,
97            self.password,
98            host,
99            port,
100            database
101        )
102    }
103}
104
/// Identifies which backend produced a [`DatabaseCredential`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CredentialSource {
    /// Supplied directly in configuration.
    Static,
    /// Read from process environment variables.
    Environment,
    /// HashiCorp Vault.
    Vault,
    /// AWS Secrets Manager.
    AwsSecretsManager,
    /// Azure Key Vault.
    AzureKeyVault,
    /// GCP Secret Manager.
    GcpSecretManager,
    /// A Kubernetes secret.
    Kubernetes,
    /// Any other provider, identified by the contained name.
    Custom(String),
}

133/// Credential provider trait
134pub trait CredentialProvider: Send + Sync {
135    /// Get credential by key
136    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;
137
138    /// Refresh credential
139    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;
140
141    /// List available credentials
142    fn list_credentials(&self) -> Result<Vec<String>, CredentialError>;
143
144    /// Provider name
145    fn provider_name(&self) -> &str;
146}
147
148/// Credential manager that aggregates multiple providers
149pub struct CredentialManager {
150    /// Configuration
151    config: CredentialConfig,
152
153    /// Credential providers
154    providers: Vec<Box<dyn CredentialProvider>>,
155
156    /// Credential cache
157    cache: Arc<RwLock<CredentialCache>>,
158
159    /// Default provider index
160    default_provider: usize,
161}
162
163/// Credential cache
164struct CredentialCache {
165    entries: HashMap<String, CachedCredential>,
166    max_size: usize,
167    default_ttl: Duration,
168}
169
170struct CachedCredential {
171    credential: DatabaseCredential,
172    cached_at: Instant,
173    ttl: Duration,
174}
175
176impl CredentialCache {
177    fn new(max_size: usize, default_ttl: Duration) -> Self {
178        Self {
179            entries: HashMap::new(),
180            max_size,
181            default_ttl,
182        }
183    }
184
185    fn get(&self, key: &str) -> Option<&DatabaseCredential> {
186        self.entries.get(key).and_then(|cached| {
187            if cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired() {
188                Some(&cached.credential)
189            } else {
190                None
191            }
192        })
193    }
194
195    fn insert(&mut self, key: String, credential: DatabaseCredential, ttl: Option<Duration>) {
196        if self.entries.len() >= self.max_size {
197            self.evict_expired();
198        }
199
200        let ttl = ttl.unwrap_or(self.default_ttl);
201        self.entries.insert(key, CachedCredential {
202            credential,
203            cached_at: Instant::now(),
204            ttl,
205        });
206    }
207
208    fn evict_expired(&mut self) {
209        self.entries.retain(|_, cached| {
210            cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired()
211        });
212    }
213
214    fn invalidate(&mut self, key: &str) {
215        self.entries.remove(key);
216    }
217
218    fn clear(&mut self) {
219        self.entries.clear();
220    }
221}
222
223impl CredentialManager {
224    /// Create a new credential manager
225    pub fn new(config: CredentialConfig) -> Self {
226        let cache_ttl = config.cache_ttl;
227
228        Self {
229            config,
230            providers: Vec::new(),
231            cache: Arc::new(RwLock::new(CredentialCache::new(1000, cache_ttl))),
232            default_provider: 0,
233        }
234    }
235
236    /// Create a builder
237    pub fn builder() -> CredentialManagerBuilder {
238        CredentialManagerBuilder::new()
239    }
240
241    /// Add a provider
242    pub fn add_provider(&mut self, provider: Box<dyn CredentialProvider>) {
243        self.providers.push(provider);
244    }
245
246    /// Get credential by key
247    pub fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
248        // Check cache first
249        if let Some(cached) = self.cache.read().get(key) {
250            return Ok(cached.clone());
251        }
252
253        // Try each provider
254        for provider in &self.providers {
255            match provider.get_credential(key) {
256                Ok(credential) => {
257                    // Calculate cache TTL based on credential expiration
258                    let ttl = credential.time_until_expiration()
259                        .map(|d| d.min(self.config.cache_ttl))
260                        .or(Some(self.config.cache_ttl));
261
262                    // Cache and return
263                    self.cache.write().insert(key.to_string(), credential.clone(), ttl);
264                    return Ok(credential);
265                }
266                Err(CredentialError::NotFound(_)) => continue,
267                Err(e) => return Err(e),
268            }
269        }
270
271        Err(CredentialError::NotFound(key.to_string()))
272    }
273
274    /// Get credential with specific provider
275    pub fn get_credential_from(
276        &self,
277        key: &str,
278        provider_name: &str,
279    ) -> Result<DatabaseCredential, CredentialError> {
280        let provider = self.providers
281            .iter()
282            .find(|p| p.provider_name() == provider_name)
283            .ok_or_else(|| CredentialError::ProviderUnavailable(provider_name.to_string()))?;
284
285        provider.get_credential(key)
286    }
287
288    /// Refresh credential
289    pub fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
290        // Invalidate cache
291        self.cache.write().invalidate(key);
292
293        // Get fresh credential
294        for provider in &self.providers {
295            match provider.refresh_credential(key) {
296                Ok(credential) => {
297                    let ttl = credential.time_until_expiration()
298                        .map(|d| d.min(self.config.cache_ttl))
299                        .or(Some(self.config.cache_ttl));
300
301                    self.cache.write().insert(key.to_string(), credential.clone(), ttl);
302                    return Ok(credential);
303                }
304                Err(CredentialError::NotFound(_)) => continue,
305                Err(e) => return Err(e),
306            }
307        }
308
309        Err(CredentialError::NotFound(key.to_string()))
310    }
311
312    /// List all available credentials
313    pub fn list_credentials(&self) -> Vec<(String, String)> {
314        let mut result = Vec::new();
315
316        for provider in &self.providers {
317            if let Ok(keys) = provider.list_credentials() {
318                for key in keys {
319                    result.push((key, provider.provider_name().to_string()));
320                }
321            }
322        }
323
324        result
325    }
326
327    /// Invalidate cached credential
328    pub fn invalidate(&self, key: &str) {
329        self.cache.write().invalidate(key);
330    }
331
332    /// Clear credential cache
333    pub fn clear_cache(&self) {
334        self.cache.write().clear();
335    }
336
337    /// Get cache statistics
338    pub fn cache_stats(&self) -> CacheStats {
339        let cache = self.cache.read();
340        CacheStats {
341            entries: cache.entries.len(),
342            max_size: cache.max_size,
343        }
344    }
345}
346
/// Point-in-time snapshot of the credential cache's occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entries currently stored (may include entries that are no longer fresh).
    pub entries: usize,
    /// Capacity configured for the cache.
    pub max_size: usize,
}

354/// Credential manager builder
355pub struct CredentialManagerBuilder {
356    config: CredentialConfig,
357    providers: Vec<Box<dyn CredentialProvider>>,
358}
359
360impl CredentialManagerBuilder {
361    /// Create a new builder
362    pub fn new() -> Self {
363        Self {
364            config: CredentialConfig::default(),
365            providers: Vec::new(),
366        }
367    }
368
369    /// Set cache TTL
370    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
371        self.config.cache_ttl = ttl;
372        self
373    }
374
375    /// Add static provider
376    pub fn with_static_credentials(mut self, credentials: HashMap<String, DatabaseCredential>) -> Self {
377        self.providers.push(Box::new(StaticCredentialProvider::new(credentials)));
378        self
379    }
380
381    /// Add environment provider
382    pub fn with_environment(mut self, prefix: &str) -> Self {
383        self.providers.push(Box::new(EnvironmentCredentialProvider::new(prefix)));
384        self
385    }
386
387    /// Add Vault provider
388    pub fn with_vault(mut self, address: &str, token: &str, mount: &str) -> Self {
389        self.providers.push(Box::new(VaultCredentialProvider::new(address, token, mount)));
390        self
391    }
392
393    /// Add AWS Secrets Manager provider
394    pub fn with_aws_secrets_manager(mut self, region: &str) -> Self {
395        self.providers.push(Box::new(AwsSecretsManagerProvider::new(region)));
396        self
397    }
398
399    /// Add custom provider
400    pub fn with_provider(mut self, provider: Box<dyn CredentialProvider>) -> Self {
401        self.providers.push(provider);
402        self
403    }
404
405    /// Build the manager
406    pub fn build(self) -> CredentialManager {
407        let mut manager = CredentialManager::new(self.config);
408        for provider in self.providers {
409            manager.add_provider(provider);
410        }
411        manager
412    }
413}
414
415impl Default for CredentialManagerBuilder {
416    fn default() -> Self {
417        Self::new()
418    }
419}
420
421/// Static credential provider
422pub struct StaticCredentialProvider {
423    credentials: HashMap<String, DatabaseCredential>,
424}
425
426impl StaticCredentialProvider {
427    /// Create a new static provider
428    pub fn new(credentials: HashMap<String, DatabaseCredential>) -> Self {
429        Self { credentials }
430    }
431
432    /// Add a credential
433    pub fn add(&mut self, key: String, credential: DatabaseCredential) {
434        self.credentials.insert(key, credential);
435    }
436}
437
438impl CredentialProvider for StaticCredentialProvider {
439    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
440        self.credentials
441            .get(key)
442            .cloned()
443            .ok_or_else(|| CredentialError::NotFound(key.to_string()))
444    }
445
446    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
447        self.get_credential(key)
448    }
449
450    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
451        Ok(self.credentials.keys().cloned().collect())
452    }
453
454    fn provider_name(&self) -> &str {
455        "static"
456    }
457}
458
/// Reads credentials from process environment variables named
/// `{prefix}_{KEY}_{FIELD}`.
pub struct EnvironmentCredentialProvider {
    /// Leading component of the variable names.
    prefix: String,
}

464impl EnvironmentCredentialProvider {
465    /// Create a new environment provider
466    pub fn new(prefix: &str) -> Self {
467        Self {
468            prefix: prefix.to_string(),
469        }
470    }
471
472    fn var_name(&self, key: &str, suffix: &str) -> String {
473        format!("{}_{}{}", self.prefix, key.to_uppercase(), suffix)
474    }
475}
476
477impl CredentialProvider for EnvironmentCredentialProvider {
478    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
479        let username = std::env::var(self.var_name(key, "_USERNAME"))
480            .or_else(|_| std::env::var(self.var_name(key, "_USER")))
481            .map_err(|_| CredentialError::NotFound(key.to_string()))?;
482
483        let password = std::env::var(self.var_name(key, "_PASSWORD"))
484            .or_else(|_| std::env::var(self.var_name(key, "_PASS")))
485            .map_err(|_| CredentialError::NotFound(format!("{}_PASSWORD", key)))?;
486
487        let database = std::env::var(self.var_name(key, "_DATABASE")).ok();
488        let host = std::env::var(self.var_name(key, "_HOST")).ok();
489        let port = std::env::var(self.var_name(key, "_PORT"))
490            .ok()
491            .and_then(|p| p.parse().ok());
492
493        Ok(DatabaseCredential {
494            username,
495            password,
496            database,
497            host,
498            port,
499            options: HashMap::new(),
500            expires_at: None,
501            source: CredentialSource::Environment,
502        })
503    }
504
505    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
506        self.get_credential(key)
507    }
508
509    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
510        // Scan environment for matching credentials
511        let mut keys = Vec::new();
512        let prefix_upper = format!("{}_", self.prefix.to_uppercase());
513
514        for (key, _) in std::env::vars() {
515            if key.starts_with(&prefix_upper) && key.ends_with("_USERNAME") {
516                let name = key
517                    .strip_prefix(&prefix_upper)
518                    .and_then(|s| s.strip_suffix("_USERNAME"))
519                    .map(|s| s.to_lowercase());
520                if let Some(name) = name {
521                    keys.push(name);
522                }
523            }
524        }
525
526        Ok(keys)
527    }
528
529    fn provider_name(&self) -> &str {
530        "environment"
531    }
532}
533
/// Credential provider intended to be backed by HashiCorp Vault.
///
/// NOTE(review): its `CredentialProvider` impl is currently a placeholder
/// that does not contact Vault.
pub struct VaultCredentialProvider {
    /// Vault server address (base of `{address}/v1/{mount}/creds/{key}`).
    address: String,
    /// Vault token, presumably for request authentication.
    token: String,
    /// Mount path segment used when building request URLs.
    mount: String,
}

541impl VaultCredentialProvider {
542    /// Create a new Vault provider
543    pub fn new(address: &str, token: &str, mount: &str) -> Self {
544        Self {
545            address: address.to_string(),
546            token: token.to_string(),
547            mount: mount.to_string(),
548        }
549    }
550}
551
552impl CredentialProvider for VaultCredentialProvider {
553    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
554        // In a real implementation, this would make an HTTP request to Vault
555        // For demonstration, we return a placeholder
556        //
557        // Real implementation would:
558        // 1. POST to {address}/v1/{mount}/creds/{key}
559        // 2. Parse response for username/password
560        // 3. Handle lease renewal
561
562        let _ = (key, &self.address, &self.token, &self.mount);
563
564        Ok(DatabaseCredential {
565            username: format!("vault_user_{}", key),
566            password: "vault_generated_password".to_string(),
567            database: None,
568            host: None,
569            port: None,
570            options: HashMap::new(),
571            expires_at: Some(chrono::Utc::now() + chrono::Duration::hours(1)),
572            source: CredentialSource::Vault,
573        })
574    }
575
576    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
577        // In Vault, this would renew the lease or get a new credential
578        self.get_credential(key)
579    }
580
581    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
582        // In a real implementation, this would list roles from Vault
583        Ok(Vec::new())
584    }
585
586    fn provider_name(&self) -> &str {
587        "vault"
588    }
589}
590
/// Credential provider intended to be backed by AWS Secrets Manager.
///
/// NOTE(review): its `CredentialProvider` impl is currently a placeholder
/// that does not call AWS.
pub struct AwsSecretsManagerProvider {
    /// AWS region name.
    region: String,
}

596impl AwsSecretsManagerProvider {
597    /// Create a new AWS Secrets Manager provider
598    pub fn new(region: &str) -> Self {
599        Self {
600            region: region.to_string(),
601        }
602    }
603}
604
605impl CredentialProvider for AwsSecretsManagerProvider {
606    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
607        // In a real implementation, this would use the AWS SDK
608        // For demonstration, we return a placeholder
609        //
610        // Real implementation would:
611        // 1. Use aws_sdk_secretsmanager to get secret value
612        // 2. Parse JSON for username/password
613        // 3. Handle rotation
614
615        let _ = (key, &self.region);
616
617        Ok(DatabaseCredential {
618            username: format!("aws_user_{}", key),
619            password: "aws_managed_password".to_string(),
620            database: None,
621            host: None,
622            port: None,
623            options: HashMap::new(),
624            expires_at: None,
625            source: CredentialSource::AwsSecretsManager,
626        })
627    }
628
629    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
630        self.get_credential(key)
631    }
632
633    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
634        // In a real implementation, this would list secrets from AWS
635        Ok(Vec::new())
636    }
637
638    fn provider_name(&self) -> &str {
639        "aws_secrets_manager"
640    }
641}
642
#[cfg(test)]
mod tests {
    use super::*;

    /// Fully populated static credential shared by the tests below.
    fn test_credential() -> DatabaseCredential {
        DatabaseCredential {
            username: "testuser".into(),
            password: "testpass".into(),
            database: Some("testdb".into()),
            host: Some("localhost".into()),
            port: Some(5432),
            options: HashMap::new(),
            expires_at: None,
            source: CredentialSource::Static,
        }
    }

    /// Maps each of `keys` to a copy of the fixture credential.
    fn credentials_for(keys: &[&str]) -> HashMap<String, DatabaseCredential> {
        keys.iter()
            .map(|key| (key.to_string(), test_credential()))
            .collect()
    }

    #[test]
    fn test_static_provider() {
        let provider = StaticCredentialProvider::new(credentials_for(&["db1"]));

        assert_eq!(provider.get_credential("db1").unwrap().username, "testuser");
        assert!(provider.get_credential("db2").is_err());
    }

    #[test]
    fn test_connection_string() {
        let conn_str = test_credential().connection_string();

        for fragment in &["testuser", "testpass", "localhost", "5432", "testdb"] {
            assert!(conn_str.contains(fragment), "missing fragment {}", fragment);
        }
    }

    #[test]
    fn test_credential_expiration() {
        let mut cred = test_credential();

        cred.expires_at = Some(chrono::Utc::now() + chrono::Duration::hours(1));
        assert!(!cred.is_expired(), "expires in an hour, so still valid");
        assert!(cred.time_until_expiration().is_some());

        cred.expires_at = Some(chrono::Utc::now() - chrono::Duration::hours(1));
        assert!(cred.is_expired(), "expired an hour ago");
        assert!(cred.time_until_expiration().is_none());
    }

    #[test]
    fn test_credential_manager() {
        let manager = CredentialManager::builder()
            .cache_ttl(Duration::from_secs(60))
            .with_static_credentials(credentials_for(&["primary"]))
            .build();

        // First lookup hits the provider; the second is served from the cache.
        for _ in 0..2 {
            assert_eq!(manager.get_credential("primary").unwrap().username, "testuser");
        }

        assert_eq!(manager.cache_stats().entries, 1);
    }

    #[test]
    fn test_list_credentials() {
        let manager = CredentialManager::builder()
            .with_static_credentials(credentials_for(&["db1", "db2"]))
            .build();

        assert_eq!(manager.list_credentials().len(), 2);
    }

    #[test]
    fn test_cache_invalidation() {
        let manager = CredentialManager::builder()
            .with_static_credentials(credentials_for(&["db1"]))
            .build();

        manager.get_credential("db1").unwrap();
        assert_eq!(manager.cache_stats().entries, 1);

        manager.invalidate("db1");
        assert_eq!(manager.cache_stats().entries, 0);
    }

    #[test]
    fn test_credential_source() {
        assert_eq!(CredentialSource::Static, CredentialSource::Static);
        assert_ne!(CredentialSource::Vault, CredentialSource::Environment);
    }
}