Skip to main content

heliosdb_proxy/auth/
credentials.rs

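//! Credential Providers
//!
//! Fetches database credentials from static configuration or from external sources
//! such as HashiCorp Vault, AWS Secrets Manager, or environment variables, and caches
//! resolved credentials in memory. The Vault and AWS Secrets Manager providers are
//! currently placeholders that return synthetic credentials.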
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use parking_lot::RwLock;
11use thiserror::Error;
12
13use super::config::CredentialConfig;
14
15/// Credential errors
16#[derive(Debug, Error)]
17pub enum CredentialError {
18    #[error("Credential not found: {0}")]
19    NotFound(String),
20
21    #[error("Provider unavailable: {0}")]
22    ProviderUnavailable(String),
23
24    #[error("Access denied: {0}")]
25    AccessDenied(String),
26
27    #[error("Invalid credential format: {0}")]
28    InvalidFormat(String),
29
30    #[error("Credential expired")]
31    Expired,
32
33    #[error("Network error: {0}")]
34    NetworkError(String),
35
36    #[error("Configuration error: {0}")]
37    ConfigurationError(String),
38}
39
40/// Database credential
41#[derive(Debug, Clone)]
42pub struct DatabaseCredential {
43    /// Username
44    pub username: String,
45
46    /// Password
47    pub password: String,
48
49    /// Database name (optional)
50    pub database: Option<String>,
51
52    /// Host (optional, for connection routing)
53    pub host: Option<String>,
54
55    /// Port (optional)
56    pub port: Option<u16>,
57
58    /// Additional connection options
59    pub options: HashMap<String, String>,
60
61    /// Credential expiration
62    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
63
64    /// Source provider
65    pub source: CredentialSource,
66}
67
68impl DatabaseCredential {
69    /// Check if credential is expired
70    pub fn is_expired(&self) -> bool {
71        self.expires_at
72            .map(|exp| chrono::Utc::now() > exp)
73            .unwrap_or(false)
74    }
75
76    /// Get time until expiration
77    pub fn time_until_expiration(&self) -> Option<Duration> {
78        self.expires_at.and_then(|exp| {
79            let now = chrono::Utc::now();
80            if exp > now {
81                Some((exp - now).to_std().unwrap_or(Duration::ZERO))
82            } else {
83                None
84            }
85        })
86    }
87
88    /// Build connection string
89    pub fn connection_string(&self) -> String {
90        let host = self.host.as_deref().unwrap_or("localhost");
91        let port = self.port.unwrap_or(5432);
92        let database = self.database.as_deref().unwrap_or("postgres");
93
94        format!(
95            "postgresql://{}:{}@{}:{}/{}",
96            self.username, self.password, host, port, database
97        )
98    }
99}
100
/// Identifies which kind of provider produced a [`DatabaseCredential`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CredentialSource {
    /// Supplied directly in static configuration.
    Static,

    /// Read from process environment variables.
    Environment,

    /// Issued by HashiCorp Vault.
    Vault,

    /// Stored in AWS Secrets Manager.
    AwsSecretsManager,

    /// Stored in Azure Key Vault.
    AzureKeyVault,

    /// Stored in GCP Secret Manager.
    GcpSecretManager,

    /// Mounted from a Kubernetes secret.
    Kubernetes,

    /// Any other provider, identified by a free-form name.
    Custom(String),
}
128
129/// Credential provider trait
130pub trait CredentialProvider: Send + Sync {
131    /// Get credential by key
132    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;
133
134    /// Refresh credential
135    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError>;
136
137    /// List available credentials
138    fn list_credentials(&self) -> Result<Vec<String>, CredentialError>;
139
140    /// Provider name
141    fn provider_name(&self) -> &str;
142}
143
144/// Credential manager that aggregates multiple providers
145pub struct CredentialManager {
146    /// Configuration
147    config: CredentialConfig,
148
149    /// Credential providers
150    providers: Vec<Box<dyn CredentialProvider>>,
151
152    /// Credential cache
153    cache: Arc<RwLock<CredentialCache>>,
154
155    /// Default provider index
156    #[allow(dead_code)]
157    default_provider: usize,
158}
159
160/// Credential cache
161struct CredentialCache {
162    entries: HashMap<String, CachedCredential>,
163    max_size: usize,
164    default_ttl: Duration,
165}
166
167struct CachedCredential {
168    credential: DatabaseCredential,
169    cached_at: Instant,
170    ttl: Duration,
171}
172
173impl CredentialCache {
174    fn new(max_size: usize, default_ttl: Duration) -> Self {
175        Self {
176            entries: HashMap::new(),
177            max_size,
178            default_ttl,
179        }
180    }
181
182    fn get(&self, key: &str) -> Option<&DatabaseCredential> {
183        self.entries.get(key).and_then(|cached| {
184            if cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired() {
185                Some(&cached.credential)
186            } else {
187                None
188            }
189        })
190    }
191
192    fn insert(&mut self, key: String, credential: DatabaseCredential, ttl: Option<Duration>) {
193        if self.entries.len() >= self.max_size {
194            self.evict_expired();
195        }
196
197        let ttl = ttl.unwrap_or(self.default_ttl);
198        self.entries.insert(
199            key,
200            CachedCredential {
201                credential,
202                cached_at: Instant::now(),
203                ttl,
204            },
205        );
206    }
207
208    fn evict_expired(&mut self) {
209        self.entries.retain(|_, cached| {
210            cached.cached_at.elapsed() < cached.ttl && !cached.credential.is_expired()
211        });
212    }
213
214    fn invalidate(&mut self, key: &str) {
215        self.entries.remove(key);
216    }
217
218    fn clear(&mut self) {
219        self.entries.clear();
220    }
221}
222
223impl CredentialManager {
224    /// Create a new credential manager
225    pub fn new(config: CredentialConfig) -> Self {
226        let cache_ttl = config.cache_ttl;
227
228        Self {
229            config,
230            providers: Vec::new(),
231            cache: Arc::new(RwLock::new(CredentialCache::new(1000, cache_ttl))),
232            default_provider: 0,
233        }
234    }
235
236    /// Create a builder
237    pub fn builder() -> CredentialManagerBuilder {
238        CredentialManagerBuilder::new()
239    }
240
241    /// Add a provider
242    pub fn add_provider(&mut self, provider: Box<dyn CredentialProvider>) {
243        self.providers.push(provider);
244    }
245
246    /// Get credential by key
247    pub fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
248        // Check cache first
249        if let Some(cached) = self.cache.read().get(key) {
250            return Ok(cached.clone());
251        }
252
253        // Try each provider
254        for provider in &self.providers {
255            match provider.get_credential(key) {
256                Ok(credential) => {
257                    // Calculate cache TTL based on credential expiration
258                    let ttl = credential
259                        .time_until_expiration()
260                        .map(|d| d.min(self.config.cache_ttl))
261                        .or(Some(self.config.cache_ttl));
262
263                    // Cache and return
264                    self.cache
265                        .write()
266                        .insert(key.to_string(), credential.clone(), ttl);
267                    return Ok(credential);
268                }
269                Err(CredentialError::NotFound(_)) => continue,
270                Err(e) => return Err(e),
271            }
272        }
273
274        Err(CredentialError::NotFound(key.to_string()))
275    }
276
277    /// Get credential with specific provider
278    pub fn get_credential_from(
279        &self,
280        key: &str,
281        provider_name: &str,
282    ) -> Result<DatabaseCredential, CredentialError> {
283        let provider = self
284            .providers
285            .iter()
286            .find(|p| p.provider_name() == provider_name)
287            .ok_or_else(|| CredentialError::ProviderUnavailable(provider_name.to_string()))?;
288
289        provider.get_credential(key)
290    }
291
292    /// Refresh credential
293    pub fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
294        // Invalidate cache
295        self.cache.write().invalidate(key);
296
297        // Get fresh credential
298        for provider in &self.providers {
299            match provider.refresh_credential(key) {
300                Ok(credential) => {
301                    let ttl = credential
302                        .time_until_expiration()
303                        .map(|d| d.min(self.config.cache_ttl))
304                        .or(Some(self.config.cache_ttl));
305
306                    self.cache
307                        .write()
308                        .insert(key.to_string(), credential.clone(), ttl);
309                    return Ok(credential);
310                }
311                Err(CredentialError::NotFound(_)) => continue,
312                Err(e) => return Err(e),
313            }
314        }
315
316        Err(CredentialError::NotFound(key.to_string()))
317    }
318
319    /// List all available credentials
320    pub fn list_credentials(&self) -> Vec<(String, String)> {
321        let mut result = Vec::new();
322
323        for provider in &self.providers {
324            if let Ok(keys) = provider.list_credentials() {
325                for key in keys {
326                    result.push((key, provider.provider_name().to_string()));
327                }
328            }
329        }
330
331        result
332    }
333
334    /// Invalidate cached credential
335    pub fn invalidate(&self, key: &str) {
336        self.cache.write().invalidate(key);
337    }
338
339    /// Clear credential cache
340    pub fn clear_cache(&self) {
341        self.cache.write().clear();
342    }
343
344    /// Get cache statistics
345    pub fn cache_stats(&self) -> CacheStats {
346        let cache = self.cache.read();
347        CacheStats {
348            entries: cache.entries.len(),
349            max_size: cache.max_size,
350        }
351    }
352}
353
/// Point-in-time view of the credential cache, as returned by
/// `CredentialManager::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored, including stale ones not yet evicted.
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}
360
361/// Credential manager builder
362pub struct CredentialManagerBuilder {
363    config: CredentialConfig,
364    providers: Vec<Box<dyn CredentialProvider>>,
365}
366
367impl CredentialManagerBuilder {
368    /// Create a new builder
369    pub fn new() -> Self {
370        Self {
371            config: CredentialConfig::default(),
372            providers: Vec::new(),
373        }
374    }
375
376    /// Set cache TTL
377    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
378        self.config.cache_ttl = ttl;
379        self
380    }
381
382    /// Add static provider
383    pub fn with_static_credentials(
384        mut self,
385        credentials: HashMap<String, DatabaseCredential>,
386    ) -> Self {
387        self.providers
388            .push(Box::new(StaticCredentialProvider::new(credentials)));
389        self
390    }
391
392    /// Add environment provider
393    pub fn with_environment(mut self, prefix: &str) -> Self {
394        self.providers
395            .push(Box::new(EnvironmentCredentialProvider::new(prefix)));
396        self
397    }
398
399    /// Add Vault provider
400    pub fn with_vault(mut self, address: &str, token: &str, mount: &str) -> Self {
401        self.providers.push(Box::new(VaultCredentialProvider::new(
402            address, token, mount,
403        )));
404        self
405    }
406
407    /// Add AWS Secrets Manager provider
408    pub fn with_aws_secrets_manager(mut self, region: &str) -> Self {
409        self.providers
410            .push(Box::new(AwsSecretsManagerProvider::new(region)));
411        self
412    }
413
414    /// Add custom provider
415    pub fn with_provider(mut self, provider: Box<dyn CredentialProvider>) -> Self {
416        self.providers.push(provider);
417        self
418    }
419
420    /// Build the manager
421    pub fn build(self) -> CredentialManager {
422        let mut manager = CredentialManager::new(self.config);
423        for provider in self.providers {
424            manager.add_provider(provider);
425        }
426        manager
427    }
428}
429
430impl Default for CredentialManagerBuilder {
431    fn default() -> Self {
432        Self::new()
433    }
434}
435
436/// Static credential provider
437pub struct StaticCredentialProvider {
438    credentials: HashMap<String, DatabaseCredential>,
439}
440
441impl StaticCredentialProvider {
442    /// Create a new static provider
443    pub fn new(credentials: HashMap<String, DatabaseCredential>) -> Self {
444        Self { credentials }
445    }
446
447    /// Add a credential
448    pub fn add(&mut self, key: String, credential: DatabaseCredential) {
449        self.credentials.insert(key, credential);
450    }
451}
452
453impl CredentialProvider for StaticCredentialProvider {
454    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
455        self.credentials
456            .get(key)
457            .cloned()
458            .ok_or_else(|| CredentialError::NotFound(key.to_string()))
459    }
460
461    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
462        self.get_credential(key)
463    }
464
465    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
466        Ok(self.credentials.keys().cloned().collect())
467    }
468
469    fn provider_name(&self) -> &str {
470        "static"
471    }
472}
473
474/// Environment variable credential provider
475pub struct EnvironmentCredentialProvider {
476    prefix: String,
477}
478
479impl EnvironmentCredentialProvider {
480    /// Create a new environment provider
481    pub fn new(prefix: &str) -> Self {
482        Self {
483            prefix: prefix.to_string(),
484        }
485    }
486
487    fn var_name(&self, key: &str, suffix: &str) -> String {
488        format!("{}_{}{}", self.prefix, key.to_uppercase(), suffix)
489    }
490}
491
492impl CredentialProvider for EnvironmentCredentialProvider {
493    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
494        let username = std::env::var(self.var_name(key, "_USERNAME"))
495            .or_else(|_| std::env::var(self.var_name(key, "_USER")))
496            .map_err(|_| CredentialError::NotFound(key.to_string()))?;
497
498        let password = std::env::var(self.var_name(key, "_PASSWORD"))
499            .or_else(|_| std::env::var(self.var_name(key, "_PASS")))
500            .map_err(|_| CredentialError::NotFound(format!("{}_PASSWORD", key)))?;
501
502        let database = std::env::var(self.var_name(key, "_DATABASE")).ok();
503        let host = std::env::var(self.var_name(key, "_HOST")).ok();
504        let port = std::env::var(self.var_name(key, "_PORT"))
505            .ok()
506            .and_then(|p| p.parse().ok());
507
508        Ok(DatabaseCredential {
509            username,
510            password,
511            database,
512            host,
513            port,
514            options: HashMap::new(),
515            expires_at: None,
516            source: CredentialSource::Environment,
517        })
518    }
519
520    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
521        self.get_credential(key)
522    }
523
524    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
525        // Scan environment for matching credentials
526        let mut keys = Vec::new();
527        let prefix_upper = format!("{}_", self.prefix.to_uppercase());
528
529        for (key, _) in std::env::vars() {
530            if key.starts_with(&prefix_upper) && key.ends_with("_USERNAME") {
531                let name = key
532                    .strip_prefix(&prefix_upper)
533                    .and_then(|s| s.strip_suffix("_USERNAME"))
534                    .map(|s| s.to_lowercase());
535                if let Some(name) = name {
536                    keys.push(name);
537                }
538            }
539        }
540
541        Ok(keys)
542    }
543
544    fn provider_name(&self) -> &str {
545        "environment"
546    }
547}
548
549/// HashiCorp Vault credential provider
550pub struct VaultCredentialProvider {
551    address: String,
552    token: String,
553    mount: String,
554}
555
556impl VaultCredentialProvider {
557    /// Create a new Vault provider
558    pub fn new(address: &str, token: &str, mount: &str) -> Self {
559        Self {
560            address: address.to_string(),
561            token: token.to_string(),
562            mount: mount.to_string(),
563        }
564    }
565}
566
567impl CredentialProvider for VaultCredentialProvider {
568    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
569        // In a real implementation, this would make an HTTP request to Vault
570        // For demonstration, we return a placeholder
571        //
572        // Real implementation would:
573        // 1. POST to {address}/v1/{mount}/creds/{key}
574        // 2. Parse response for username/password
575        // 3. Handle lease renewal
576
577        let _ = (key, &self.address, &self.token, &self.mount);
578
579        Ok(DatabaseCredential {
580            username: format!("vault_user_{}", key),
581            password: "vault_generated_password".to_string(),
582            database: None,
583            host: None,
584            port: None,
585            options: HashMap::new(),
586            expires_at: Some(chrono::Utc::now() + chrono::Duration::hours(1)),
587            source: CredentialSource::Vault,
588        })
589    }
590
591    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
592        // In Vault, this would renew the lease or get a new credential
593        self.get_credential(key)
594    }
595
596    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
597        // In a real implementation, this would list roles from Vault
598        Ok(Vec::new())
599    }
600
601    fn provider_name(&self) -> &str {
602        "vault"
603    }
604}
605
606/// AWS Secrets Manager credential provider
607pub struct AwsSecretsManagerProvider {
608    region: String,
609}
610
611impl AwsSecretsManagerProvider {
612    /// Create a new AWS Secrets Manager provider
613    pub fn new(region: &str) -> Self {
614        Self {
615            region: region.to_string(),
616        }
617    }
618}
619
620impl CredentialProvider for AwsSecretsManagerProvider {
621    fn get_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
622        // In a real implementation, this would use the AWS SDK
623        // For demonstration, we return a placeholder
624        //
625        // Real implementation would:
626        // 1. Use aws_sdk_secretsmanager to get secret value
627        // 2. Parse JSON for username/password
628        // 3. Handle rotation
629
630        let _ = (key, &self.region);
631
632        Ok(DatabaseCredential {
633            username: format!("aws_user_{}", key),
634            password: "aws_managed_password".to_string(),
635            database: None,
636            host: None,
637            port: None,
638            options: HashMap::new(),
639            expires_at: None,
640            source: CredentialSource::AwsSecretsManager,
641        })
642    }
643
644    fn refresh_credential(&self, key: &str) -> Result<DatabaseCredential, CredentialError> {
645        self.get_credential(key)
646    }
647
648    fn list_credentials(&self) -> Result<Vec<String>, CredentialError> {
649        // In a real implementation, this would list secrets from AWS
650        Ok(Vec::new())
651    }
652
653    fn provider_name(&self) -> &str {
654        "aws_secrets_manager"
655    }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-expiring static fixture for `testuser` at `localhost:5432/testdb`.
    fn test_credential() -> DatabaseCredential {
        DatabaseCredential {
            username: String::from("testuser"),
            password: String::from("testpass"),
            database: Some(String::from("testdb")),
            host: Some(String::from("localhost")),
            port: Some(5432),
            options: HashMap::new(),
            expires_at: None,
            source: CredentialSource::Static,
        }
    }

    /// Maps every key in `keys` to a fresh copy of [`test_credential`].
    fn fixtures(keys: &[&str]) -> HashMap<String, DatabaseCredential> {
        keys.iter()
            .map(|key| (key.to_string(), test_credential()))
            .collect()
    }

    #[test]
    fn test_static_provider() {
        let provider = StaticCredentialProvider::new(fixtures(&["db1"]));

        assert_eq!(provider.get_credential("db1").unwrap().username, "testuser");
        assert!(provider.get_credential("db2").is_err());
    }

    #[test]
    fn test_connection_string() {
        let conn_str = test_credential().connection_string();

        for fragment in ["testuser", "testpass", "localhost", "5432", "testdb"] {
            assert!(
                conn_str.contains(fragment),
                "connection string is missing {:?}",
                fragment
            );
        }
    }

    #[test]
    fn test_credential_expiration() {
        let mut cred = test_credential();

        // One hour in the future: still valid.
        cred.expires_at = Some(chrono::Utc::now() + chrono::Duration::hours(1));
        assert!(!cred.is_expired());
        assert!(cred.time_until_expiration().is_some());

        // One hour in the past: expired.
        cred.expires_at = Some(chrono::Utc::now() - chrono::Duration::hours(1));
        assert!(cred.is_expired());
        assert!(cred.time_until_expiration().is_none());
    }

    #[test]
    fn test_credential_manager() {
        let manager = CredentialManager::builder()
            .cache_ttl(Duration::from_secs(60))
            .with_static_credentials(fixtures(&["primary"]))
            .build();

        // First lookup goes to the provider, the second should come from cache.
        for _ in 0..2 {
            assert_eq!(manager.get_credential("primary").unwrap().username, "testuser");
        }

        assert_eq!(manager.cache_stats().entries, 1);
    }

    #[test]
    fn test_list_credentials() {
        let manager = CredentialManager::builder()
            .with_static_credentials(fixtures(&["db1", "db2"]))
            .build();

        assert_eq!(manager.list_credentials().len(), 2);
    }

    #[test]
    fn test_cache_invalidation() {
        let manager = CredentialManager::builder()
            .with_static_credentials(fixtures(&["db1"]))
            .build();

        manager.get_credential("db1").unwrap();
        assert_eq!(manager.cache_stats().entries, 1);

        manager.invalidate("db1");
        assert_eq!(manager.cache_stats().entries, 0);
    }

    #[test]
    fn test_credential_source() {
        assert_eq!(CredentialSource::Static, CredentialSource::Static);
        assert_ne!(CredentialSource::Vault, CredentialSource::Environment);
    }
}