Skip to main content

heliosdb_proxy/auth/
api_keys.rs

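//! API key management.
//!
//! Provides API key generation, validation, and lifecycle management
//! (revocation, rotation, deletion, and expiry cleanup). Keys are held in
//! memory and looked up by a hash of the secret; the secret itself is handed
//! to the caller when a key is generated or rotated.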
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, ApiKeyConfig};
13
14/// API key errors
15#[derive(Debug, Error)]
16pub enum ApiKeyError {
17    #[error("API key not found")]
18    NotFound,
19
20    #[error("API key expired")]
21    Expired,
22
23    #[error("API key revoked")]
24    Revoked,
25
26    #[error("API key rate limited")]
27    RateLimited,
28
29    #[error("Invalid API key format")]
30    InvalidFormat,
31
32    #[error("Insufficient scope: {0}")]
33    InsufficientScope(String),
34
35    #[error("Key generation failed: {0}")]
36    GenerationFailed(String),
37
38    #[error("Storage error: {0}")]
39    StorageError(String),
40}
41
42/// API key entry
43#[derive(Debug, Clone)]
44pub struct ApiKey {
45    /// Unique key ID
46    pub id: String,
47
48    /// Key prefix (visible part, e.g., "hdb_live_")
49    pub prefix: String,
50
51    /// Hashed key value
52    pub key_hash: String,
53
54    /// Associated user identity
55    pub identity: Identity,
56
57    /// Key name/description
58    pub name: String,
59
60    /// Creation timestamp
61    pub created_at: chrono::DateTime<chrono::Utc>,
62
63    /// Expiration timestamp
64    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
65
66    /// Last used timestamp
67    pub last_used_at: Option<chrono::DateTime<chrono::Utc>>,
68
69    /// Whether key is active
70    pub active: bool,
71
72    /// Allowed scopes
73    pub scopes: Vec<String>,
74
75    /// Rate limit (requests per minute)
76    pub rate_limit: Option<u32>,
77
78    /// Allowed IP addresses (empty = all allowed)
79    pub allowed_ips: Vec<std::net::IpAddr>,
80
81    /// Metadata
82    pub metadata: HashMap<String, String>,
83}
84
85impl ApiKey {
86    /// Check if the key is valid
87    pub fn is_valid(&self) -> bool {
88        if !self.active {
89            return false;
90        }
91
92        if let Some(expires_at) = self.expires_at {
93            if chrono::Utc::now() > expires_at {
94                return false;
95            }
96        }
97
98        true
99    }
100
101    /// Check if key has a specific scope
102    pub fn has_scope(&self, scope: &str) -> bool {
103        self.scopes.iter().any(|s| s == scope || s == "*")
104    }
105
106    /// Check if IP is allowed
107    pub fn is_ip_allowed(&self, ip: &std::net::IpAddr) -> bool {
108        if self.allowed_ips.is_empty() {
109            return true;
110        }
111        self.allowed_ips.contains(ip)
112    }
113}
114
115/// API key manager
116pub struct ApiKeyManager {
117    /// Configuration
118    config: ApiKeyConfig,
119
120    /// Key store by ID
121    keys_by_id: Arc<RwLock<HashMap<String, ApiKey>>>,
122
123    /// Key lookup by hash
124    keys_by_hash: Arc<RwLock<HashMap<String, String>>>,
125
126    /// Rate limit state
127    rate_limits: Arc<RwLock<HashMap<String, RateLimitState>>>,
128
129    /// Key prefix
130    key_prefix: String,
131}
132
/// Fixed-window request counter for a single key.
struct RateLimitState {
    /// Requests admitted in the current window.
    count: u32,

    /// When the current window opened.
    window_start: Instant,
}

impl RateLimitState {
    /// Length of one rate-limit window (limits are expressed per minute).
    const WINDOW: Duration = Duration::from_secs(60);

    /// Starts an empty window at the current instant.
    fn new() -> Self {
        Self {
            count: 0,
            window_start: Instant::now(),
        }
    }

    /// Records one request against a budget of `limit` requests per window.
    ///
    /// Returns `true` if the request is admitted (and counted), `false` if the
    /// window's budget is exhausted. A `limit` of 0 rejects every request.
    fn check_and_increment(&mut self, limit: u32) -> bool {
        if self.window_start.elapsed() > Self::WINDOW {
            // Open a fresh, empty window and let the request go through the
            // same budget check as any other; admitting it unconditionally
            // here would let `limit == 0` through once per window.
            self.count = 0;
            self.window_start = Instant::now();
        }

        if self.count < limit {
            self.count += 1;
            true
        } else {
            false
        }
    }
}
165
166impl ApiKeyManager {
167    /// Create a new API key manager
168    pub fn new(config: ApiKeyConfig) -> Self {
169        let key_prefix = config.prefix.clone().unwrap_or_else(|| "hdb_".to_string());
170
171        Self {
172            config,
173            keys_by_id: Arc::new(RwLock::new(HashMap::new())),
174            keys_by_hash: Arc::new(RwLock::new(HashMap::new())),
175            rate_limits: Arc::new(RwLock::new(HashMap::new())),
176            key_prefix,
177        }
178    }
179
180    /// Generate a new API key
181    pub fn generate_key(
182        &self,
183        identity: Identity,
184        name: String,
185        scopes: Vec<String>,
186        expires_in: Option<Duration>,
187        rate_limit: Option<u32>,
188    ) -> Result<(ApiKey, String), ApiKeyError> {
189        // Generate random key value
190        let key_value = self.generate_random_key();
191        let full_key = format!("{}{}", self.key_prefix, key_value);
192
193        // Hash the key
194        let key_hash = self.hash_key(&full_key);
195
196        // Generate key ID
197        let key_id = self.generate_key_id();
198
199        let expires_at = expires_in.map(|d| chrono::Utc::now() + chrono::Duration::from_std(d).unwrap());
200
201        let api_key = ApiKey {
202            id: key_id.clone(),
203            prefix: self.key_prefix.clone(),
204            key_hash: key_hash.clone(),
205            identity,
206            name,
207            created_at: chrono::Utc::now(),
208            expires_at,
209            last_used_at: None,
210            active: true,
211            scopes,
212            rate_limit,
213            allowed_ips: Vec::new(),
214            metadata: HashMap::new(),
215        };
216
217        // Store the key
218        self.keys_by_id.write().insert(key_id.clone(), api_key.clone());
219        self.keys_by_hash.write().insert(key_hash, key_id);
220
221        Ok((api_key, full_key))
222    }
223
224    /// Validate an API key
225    pub fn validate(&self, key: &str) -> Result<ApiKey, ApiKeyError> {
226        // Check format
227        if !key.starts_with(&self.key_prefix) {
228            return Err(ApiKeyError::InvalidFormat);
229        }
230
231        let key_hash = self.hash_key(key);
232
233        // Look up by hash
234        let key_id = self.keys_by_hash.read()
235            .get(&key_hash)
236            .cloned()
237            .ok_or(ApiKeyError::NotFound)?;
238
239        let mut keys = self.keys_by_id.write();
240        let api_key = keys.get_mut(&key_id)
241            .ok_or(ApiKeyError::NotFound)?;
242
243        // Check if active
244        if !api_key.active {
245            return Err(ApiKeyError::Revoked);
246        }
247
248        // Check expiration
249        if let Some(expires_at) = api_key.expires_at {
250            if chrono::Utc::now() > expires_at {
251                return Err(ApiKeyError::Expired);
252            }
253        }
254
255        // Check rate limit
256        if let Some(limit) = api_key.rate_limit {
257            if !self.check_rate_limit(&key_id, limit) {
258                return Err(ApiKeyError::RateLimited);
259            }
260        }
261
262        // Update last used
263        api_key.last_used_at = Some(chrono::Utc::now());
264
265        Ok(api_key.clone())
266    }
267
268    /// Validate key and convert to identity
269    pub fn validate_to_identity(&self, key: &str) -> Result<Identity, ApiKeyError> {
270        let api_key = self.validate(key)?;
271        Ok(api_key.identity)
272    }
273
274    /// Validate key with required scopes
275    pub fn validate_with_scopes(
276        &self,
277        key: &str,
278        required_scopes: &[&str],
279    ) -> Result<ApiKey, ApiKeyError> {
280        let api_key = self.validate(key)?;
281
282        for scope in required_scopes {
283            if !api_key.has_scope(scope) {
284                return Err(ApiKeyError::InsufficientScope((*scope).to_string()));
285            }
286        }
287
288        Ok(api_key)
289    }
290
291    /// Validate key with IP check
292    pub fn validate_with_ip(
293        &self,
294        key: &str,
295        client_ip: &std::net::IpAddr,
296    ) -> Result<ApiKey, ApiKeyError> {
297        let api_key = self.validate(key)?;
298
299        if !api_key.is_ip_allowed(client_ip) {
300            return Err(ApiKeyError::InsufficientScope("IP not allowed".to_string()));
301        }
302
303        Ok(api_key)
304    }
305
306    /// Revoke an API key
307    pub fn revoke(&self, key_id: &str) -> Result<(), ApiKeyError> {
308        let mut keys = self.keys_by_id.write();
309        let api_key = keys.get_mut(key_id)
310            .ok_or(ApiKeyError::NotFound)?;
311
312        api_key.active = false;
313        Ok(())
314    }
315
316    /// Delete an API key
317    pub fn delete(&self, key_id: &str) -> Result<(), ApiKeyError> {
318        let api_key = self.keys_by_id.write().remove(key_id)
319            .ok_or(ApiKeyError::NotFound)?;
320
321        self.keys_by_hash.write().remove(&api_key.key_hash);
322        self.rate_limits.write().remove(key_id);
323
324        Ok(())
325    }
326
327    /// Get an API key by ID
328    pub fn get(&self, key_id: &str) -> Option<ApiKey> {
329        self.keys_by_id.read().get(key_id).cloned()
330    }
331
332    /// List all API keys for a user
333    pub fn list_by_user(&self, user_id: &str) -> Vec<ApiKey> {
334        self.keys_by_id.read()
335            .values()
336            .filter(|k| k.identity.user_id == user_id)
337            .cloned()
338            .collect()
339    }
340
341    /// List all active API keys
342    pub fn list_active(&self) -> Vec<ApiKey> {
343        self.keys_by_id.read()
344            .values()
345            .filter(|k| k.is_valid())
346            .cloned()
347            .collect()
348    }
349
350    /// Update API key metadata
351    pub fn update_metadata(
352        &self,
353        key_id: &str,
354        metadata: HashMap<String, String>,
355    ) -> Result<(), ApiKeyError> {
356        let mut keys = self.keys_by_id.write();
357        let api_key = keys.get_mut(key_id)
358            .ok_or(ApiKeyError::NotFound)?;
359
360        api_key.metadata.extend(metadata);
361        Ok(())
362    }
363
364    /// Update API key scopes
365    pub fn update_scopes(&self, key_id: &str, scopes: Vec<String>) -> Result<(), ApiKeyError> {
366        let mut keys = self.keys_by_id.write();
367        let api_key = keys.get_mut(key_id)
368            .ok_or(ApiKeyError::NotFound)?;
369
370        api_key.scopes = scopes;
371        Ok(())
372    }
373
374    /// Update API key allowed IPs
375    pub fn update_allowed_ips(
376        &self,
377        key_id: &str,
378        ips: Vec<std::net::IpAddr>,
379    ) -> Result<(), ApiKeyError> {
380        let mut keys = self.keys_by_id.write();
381        let api_key = keys.get_mut(key_id)
382            .ok_or(ApiKeyError::NotFound)?;
383
384        api_key.allowed_ips = ips;
385        Ok(())
386    }
387
388    /// Rotate an API key (generate new key value, same ID)
389    pub fn rotate(&self, key_id: &str) -> Result<String, ApiKeyError> {
390        let old_hash = {
391            let keys = self.keys_by_id.read();
392            let api_key = keys.get(key_id).ok_or(ApiKeyError::NotFound)?;
393            api_key.key_hash.clone()
394        };
395
396        // Generate new key value
397        let key_value = self.generate_random_key();
398        let full_key = format!("{}{}", self.key_prefix, key_value);
399        let new_hash = self.hash_key(&full_key);
400
401        // Update key
402        {
403            let mut keys = self.keys_by_id.write();
404            let api_key = keys.get_mut(key_id).ok_or(ApiKeyError::NotFound)?;
405            api_key.key_hash = new_hash.clone();
406        }
407
408        // Update hash lookup
409        {
410            let mut hashes = self.keys_by_hash.write();
411            hashes.remove(&old_hash);
412            hashes.insert(new_hash, key_id.to_string());
413        }
414
415        Ok(full_key)
416    }
417
418    /// Get key statistics
419    pub fn stats(&self) -> ApiKeyStats {
420        let keys = self.keys_by_id.read();
421        let total = keys.len();
422        let active = keys.values().filter(|k| k.active).count();
423        let expired = keys.values().filter(|k| {
424            k.expires_at.map(|e| chrono::Utc::now() > e).unwrap_or(false)
425        }).count();
426
427        ApiKeyStats {
428            total,
429            active,
430            expired,
431            revoked: total - active - expired,
432        }
433    }
434
435    /// Cleanup expired keys
436    pub fn cleanup_expired(&self) {
437        let expired_ids: Vec<String> = self.keys_by_id.read()
438            .iter()
439            .filter(|(_, k)| {
440                k.expires_at.map(|e| chrono::Utc::now() > e).unwrap_or(false)
441            })
442            .map(|(id, _)| id.clone())
443            .collect();
444
445        for id in expired_ids {
446            let _ = self.delete(&id);
447        }
448    }
449
450    /// Check rate limit for a key
451    fn check_rate_limit(&self, key_id: &str, limit: u32) -> bool {
452        let mut limits = self.rate_limits.write();
453        let state = limits.entry(key_id.to_string())
454            .or_insert_with(RateLimitState::new);
455        state.check_and_increment(limit)
456    }
457
458    /// Generate a random key value
459    fn generate_random_key(&self) -> String {
460        use std::collections::hash_map::RandomState;
461        use std::hash::{BuildHasher, Hasher};
462
463        let mut hasher = RandomState::new().build_hasher();
464        hasher.write_u128(std::time::SystemTime::now()
465            .duration_since(std::time::UNIX_EPOCH)
466            .unwrap()
467            .as_nanos());
468        hasher.write_usize(std::process::id() as usize);
469
470        let hash1 = hasher.finish();
471
472        hasher.write_u64(hash1);
473        let hash2 = hasher.finish();
474
475        format!("{:016x}{:016x}", hash1, hash2)
476    }
477
478    /// Generate a key ID
479    fn generate_key_id(&self) -> String {
480        use std::collections::hash_map::RandomState;
481        use std::hash::{BuildHasher, Hasher};
482
483        let mut hasher = RandomState::new().build_hasher();
484        hasher.write_u128(std::time::SystemTime::now()
485            .duration_since(std::time::UNIX_EPOCH)
486            .unwrap()
487            .as_nanos());
488
489        format!("key_{:016x}", hasher.finish())
490    }
491
492    /// Hash a key value
493    fn hash_key(&self, key: &str) -> String {
494        use std::hash::{Hash, Hasher};
495        let mut hasher = std::collections::hash_map::DefaultHasher::new();
496        key.hash(&mut hasher);
497
498        // In production, use a cryptographic hash like SHA-256
499        format!("{:016x}", hasher.finish())
500    }
501
502    /// Get the key prefix
503    pub fn key_prefix(&self) -> &str {
504        &self.key_prefix
505    }
506}
507
/// Point-in-time counts of the keys held by an API key manager.
#[derive(Debug, Clone)]
pub struct ApiKeyStats {
    /// Number of keys currently stored, in any state.
    pub total: usize,

    /// Number of keys reported as active.
    pub active: usize,

    /// Number of keys reported as expired (past their `expires_at`).
    pub expired: usize,

    /// Number of keys reported as revoked.
    pub revoked: usize,
}
523
524/// API key builder
525pub struct ApiKeyBuilder {
526    identity: Identity,
527    name: String,
528    scopes: Vec<String>,
529    expires_in: Option<Duration>,
530    rate_limit: Option<u32>,
531    allowed_ips: Vec<std::net::IpAddr>,
532    metadata: HashMap<String, String>,
533}
534
535impl ApiKeyBuilder {
536    /// Create a new builder
537    pub fn new(identity: Identity, name: impl Into<String>) -> Self {
538        Self {
539            identity,
540            name: name.into(),
541            scopes: Vec::new(),
542            expires_in: None,
543            rate_limit: None,
544            allowed_ips: Vec::new(),
545            metadata: HashMap::new(),
546        }
547    }
548
549    /// Add a scope
550    pub fn scope(mut self, scope: impl Into<String>) -> Self {
551        self.scopes.push(scope.into());
552        self
553    }
554
555    /// Add multiple scopes
556    pub fn scopes(mut self, scopes: Vec<String>) -> Self {
557        self.scopes.extend(scopes);
558        self
559    }
560
561    /// Set expiration
562    pub fn expires_in(mut self, duration: Duration) -> Self {
563        self.expires_in = Some(duration);
564        self
565    }
566
567    /// Set rate limit
568    pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
569        self.rate_limit = Some(requests_per_minute);
570        self
571    }
572
573    /// Add allowed IP
574    pub fn allow_ip(mut self, ip: std::net::IpAddr) -> Self {
575        self.allowed_ips.push(ip);
576        self
577    }
578
579    /// Add metadata
580    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
581        self.metadata.insert(key.into(), value.into());
582        self
583    }
584
585    /// Build the API key using the manager
586    pub fn build(self, manager: &ApiKeyManager) -> Result<(ApiKey, String), ApiKeyError> {
587        let (mut api_key, key_value) = manager.generate_key(
588            self.identity,
589            self.name,
590            self.scopes,
591            self.expires_in,
592            self.rate_limit,
593        )?;
594
595        api_key.allowed_ips = self.allowed_ips;
596        api_key.metadata = self.metadata;
597
598        // Update the stored key
599        manager.keys_by_id.write().insert(api_key.id.clone(), api_key.clone());
600
601        Ok((api_key, key_value))
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager configuration shared by every test (secret prefix `hdb_test_`).
    fn test_config() -> ApiKeyConfig {
        ApiKeyConfig {
            header_name: "X-API-Key".to_string(),
            query_param: Some("api_key".to_string()),
            prefix: Some("hdb_test_".to_string()),
            hash_algorithm: "sha256".to_string(),
        }
    }

    /// A fixed identity for `user123`.
    fn test_identity() -> Identity {
        Identity {
            user_id: "user123".to_string(),
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "api_key".to_string(),
            authenticated_at: chrono::Utc::now(),
        }
    }

    /// Issues a key for `test_identity()` with no rate limit.
    fn issue(
        manager: &ApiKeyManager,
        name: &str,
        scopes: &[&str],
        expires_in: Option<Duration>,
    ) -> (ApiKey, String) {
        let scopes = scopes.iter().map(|s| s.to_string()).collect();
        manager
            .generate_key(test_identity(), name.to_string(), scopes, expires_in, None)
            .unwrap()
    }

    #[test]
    fn test_generate_key() {
        let manager = ApiKeyManager::new(test_config());
        let (record, secret) = issue(&manager, "Test Key", &["read"], None);

        assert!(secret.starts_with("hdb_test_"));
        assert!(record.active);
        assert!(record.has_scope("read"));
    }

    #[test]
    fn test_validate_key() {
        let manager = ApiKeyManager::new(test_config());
        let (_, secret) = issue(&manager, "Test Key", &["read"], None);

        let validated = manager.validate(&secret).unwrap();
        assert_eq!(validated.identity.user_id, "user123");
    }

    #[test]
    fn test_validate_invalid_key() {
        let manager = ApiKeyManager::new(test_config());
        assert!(matches!(
            manager.validate("hdb_test_invalid"),
            Err(ApiKeyError::NotFound)
        ));
    }

    #[test]
    fn test_revoke_key() {
        let manager = ApiKeyManager::new(test_config());
        let (record, secret) = issue(&manager, "Test Key", &["read"], None);

        manager.revoke(&record.id).unwrap();

        assert!(matches!(manager.validate(&secret), Err(ApiKeyError::Revoked)));
    }

    #[test]
    fn test_key_expiration() {
        let manager = ApiKeyManager::new(test_config());
        // A zero TTL expires at the moment of creation.
        let (_, secret) = issue(&manager, "Test Key", &["read"], Some(Duration::from_secs(0)));

        // Let the clock move past the expiry instant.
        std::thread::sleep(Duration::from_millis(10));

        assert!(matches!(manager.validate(&secret), Err(ApiKeyError::Expired)));
    }

    #[test]
    fn test_scope_validation() {
        let manager = ApiKeyManager::new(test_config());
        let (_, secret) = issue(&manager, "Test Key", &["read"], None);

        // The granted scope passes...
        assert!(manager.validate_with_scopes(&secret, &["read"]).is_ok());

        // ...an ungranted one does not.
        assert!(matches!(
            manager.validate_with_scopes(&secret, &["write"]),
            Err(ApiKeyError::InsufficientScope(_))
        ));
    }

    #[test]
    fn test_list_by_user() {
        let manager = ApiKeyManager::new(test_config());

        let mine = test_identity();
        let mut other = test_identity();
        other.user_id = "user456".to_string();

        manager.generate_key(mine, "Key 1".to_string(), vec![], None, None).unwrap();
        manager.generate_key(other, "Key 2".to_string(), vec![], None, None).unwrap();

        assert_eq!(manager.list_by_user("user123").len(), 1);
    }

    #[test]
    fn test_key_stats() {
        let manager = ApiKeyManager::new(test_config());

        let (first, _) = issue(&manager, "Key 1", &[], None);
        issue(&manager, "Key 2", &[], None);

        manager.revoke(&first.id).unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_rotate_key() {
        let manager = ApiKeyManager::new(test_config());
        let (record, old_secret) = issue(&manager, "Test Key", &["read"], None);

        let new_secret = manager.rotate(&record.id).unwrap();

        assert!(manager.validate(&old_secret).is_err(), "superseded secret must be rejected");
        assert!(manager.validate(&new_secret).is_ok(), "rotated secret must be accepted");
    }

    #[test]
    fn test_api_key_builder() {
        let manager = ApiKeyManager::new(test_config());

        let (record, secret) = ApiKeyBuilder::new(test_identity(), "Builder Key")
            .scope("read")
            .scope("write")
            .rate_limit(100)
            .expires_in(Duration::from_secs(3600))
            .metadata("env", "test")
            .build(&manager)
            .unwrap();

        assert!(secret.starts_with("hdb_test_"));
        assert!(record.has_scope("read"));
        assert!(record.has_scope("write"));
        assert_eq!(record.rate_limit, Some(100));
        assert_eq!(record.metadata.get("env").map(String::as_str), Some("test"));
    }
}
808}