Skip to main content

heliosdb_proxy/auth/
api_keys.rs

//! API Key Management
//!
//! Provides in-memory API key generation, validation, and lifecycle management.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{ApiKeyConfig, Identity};
13
14/// API key errors
15#[derive(Debug, Error)]
16pub enum ApiKeyError {
17    #[error("API key not found")]
18    NotFound,
19
20    #[error("API key expired")]
21    Expired,
22
23    #[error("API key revoked")]
24    Revoked,
25
26    #[error("API key rate limited")]
27    RateLimited,
28
29    #[error("Invalid API key format")]
30    InvalidFormat,
31
32    #[error("Insufficient scope: {0}")]
33    InsufficientScope(String),
34
35    #[error("Key generation failed: {0}")]
36    GenerationFailed(String),
37
38    #[error("Storage error: {0}")]
39    StorageError(String),
40}
41
42/// API key entry
43#[derive(Debug, Clone)]
44pub struct ApiKey {
45    /// Unique key ID
46    pub id: String,
47
48    /// Key prefix (visible part, e.g., "hdb_live_")
49    pub prefix: String,
50
51    /// Hashed key value
52    pub key_hash: String,
53
54    /// Associated user identity
55    pub identity: Identity,
56
57    /// Key name/description
58    pub name: String,
59
60    /// Creation timestamp
61    pub created_at: chrono::DateTime<chrono::Utc>,
62
63    /// Expiration timestamp
64    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
65
66    /// Last used timestamp
67    pub last_used_at: Option<chrono::DateTime<chrono::Utc>>,
68
69    /// Whether key is active
70    pub active: bool,
71
72    /// Allowed scopes
73    pub scopes: Vec<String>,
74
75    /// Rate limit (requests per minute)
76    pub rate_limit: Option<u32>,
77
78    /// Allowed IP addresses (empty = all allowed)
79    pub allowed_ips: Vec<std::net::IpAddr>,
80
81    /// Metadata
82    pub metadata: HashMap<String, String>,
83}
84
85impl ApiKey {
86    /// Check if the key is valid
87    pub fn is_valid(&self) -> bool {
88        if !self.active {
89            return false;
90        }
91
92        if let Some(expires_at) = self.expires_at {
93            if chrono::Utc::now() > expires_at {
94                return false;
95            }
96        }
97
98        true
99    }
100
101    /// Check if key has a specific scope
102    pub fn has_scope(&self, scope: &str) -> bool {
103        self.scopes.iter().any(|s| s == scope || s == "*")
104    }
105
106    /// Check if IP is allowed
107    pub fn is_ip_allowed(&self, ip: &std::net::IpAddr) -> bool {
108        if self.allowed_ips.is_empty() {
109            return true;
110        }
111        self.allowed_ips.contains(ip)
112    }
113}
114
115/// API key manager
116pub struct ApiKeyManager {
117    /// Configuration
118    #[allow(dead_code)]
119    config: ApiKeyConfig,
120
121    /// Key store by ID
122    keys_by_id: Arc<RwLock<HashMap<String, ApiKey>>>,
123
124    /// Key lookup by hash
125    keys_by_hash: Arc<RwLock<HashMap<String, String>>>,
126
127    /// Rate limit state
128    rate_limits: Arc<RwLock<HashMap<String, RateLimitState>>>,
129
130    /// Key prefix
131    key_prefix: String,
132}
133
/// Fixed-window request counter for a single key.
///
/// A window is 60 seconds long. A new window begins at construction, or at
/// the first request after the previous window has lapsed.
struct RateLimitState {
    /// Requests admitted in the current window.
    count: u32,

    /// When the current window began.
    window_start: Instant,
}

impl RateLimitState {
    /// Length of one rate-limit window.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Create a state whose (empty) window starts now.
    fn new() -> Self {
        Self {
            count: 0,
            window_start: Instant::now(),
        }
    }

    /// Record one request against a budget of `limit` requests per window.
    ///
    /// Returns `true` if the request is admitted and counted. Returns `false`
    /// if the current window's budget is already spent. A `limit` of 0 admits
    /// nothing.
    fn check_and_increment(&mut self, limit: u32) -> bool {
        // Roll over to a fresh window, then judge the request like any other.
        // Admitting unconditionally on rollover would let `limit == 0` keys
        // through once per window.
        if self.window_start.elapsed() > Self::WINDOW {
            self.count = 0;
            self.window_start = Instant::now();
        }

        if self.count < limit {
            self.count += 1;
            true
        } else {
            false
        }
    }
}
166
167impl ApiKeyManager {
168    /// Create a new API key manager
169    pub fn new(config: ApiKeyConfig) -> Self {
170        let key_prefix = config.prefix.clone().unwrap_or_else(|| "hdb_".to_string());
171
172        Self {
173            config,
174            keys_by_id: Arc::new(RwLock::new(HashMap::new())),
175            keys_by_hash: Arc::new(RwLock::new(HashMap::new())),
176            rate_limits: Arc::new(RwLock::new(HashMap::new())),
177            key_prefix,
178        }
179    }
180
181    /// Generate a new API key
182    pub fn generate_key(
183        &self,
184        identity: Identity,
185        name: String,
186        scopes: Vec<String>,
187        expires_in: Option<Duration>,
188        rate_limit: Option<u32>,
189    ) -> Result<(ApiKey, String), ApiKeyError> {
190        // Generate random key value
191        let key_value = self.generate_random_key();
192        let full_key = format!("{}{}", self.key_prefix, key_value);
193
194        // Hash the key
195        let key_hash = self.hash_key(&full_key);
196
197        // Generate key ID
198        let key_id = self.generate_key_id();
199
200        let expires_at =
201            expires_in.map(|d| chrono::Utc::now() + chrono::Duration::from_std(d).unwrap());
202
203        let api_key = ApiKey {
204            id: key_id.clone(),
205            prefix: self.key_prefix.clone(),
206            key_hash: key_hash.clone(),
207            identity,
208            name,
209            created_at: chrono::Utc::now(),
210            expires_at,
211            last_used_at: None,
212            active: true,
213            scopes,
214            rate_limit,
215            allowed_ips: Vec::new(),
216            metadata: HashMap::new(),
217        };
218
219        // Store the key
220        self.keys_by_id
221            .write()
222            .insert(key_id.clone(), api_key.clone());
223        self.keys_by_hash.write().insert(key_hash, key_id);
224
225        Ok((api_key, full_key))
226    }
227
228    /// Validate an API key
229    pub fn validate(&self, key: &str) -> Result<ApiKey, ApiKeyError> {
230        // Check format
231        if !key.starts_with(&self.key_prefix) {
232            return Err(ApiKeyError::InvalidFormat);
233        }
234
235        let key_hash = self.hash_key(key);
236
237        // Look up by hash
238        let key_id = self
239            .keys_by_hash
240            .read()
241            .get(&key_hash)
242            .cloned()
243            .ok_or(ApiKeyError::NotFound)?;
244
245        let mut keys = self.keys_by_id.write();
246        let api_key = keys.get_mut(&key_id).ok_or(ApiKeyError::NotFound)?;
247
248        // Check if active
249        if !api_key.active {
250            return Err(ApiKeyError::Revoked);
251        }
252
253        // Check expiration
254        if let Some(expires_at) = api_key.expires_at {
255            if chrono::Utc::now() > expires_at {
256                return Err(ApiKeyError::Expired);
257            }
258        }
259
260        // Check rate limit
261        if let Some(limit) = api_key.rate_limit {
262            if !self.check_rate_limit(&key_id, limit) {
263                return Err(ApiKeyError::RateLimited);
264            }
265        }
266
267        // Update last used
268        api_key.last_used_at = Some(chrono::Utc::now());
269
270        Ok(api_key.clone())
271    }
272
273    /// Validate key and convert to identity
274    pub fn validate_to_identity(&self, key: &str) -> Result<Identity, ApiKeyError> {
275        let api_key = self.validate(key)?;
276        Ok(api_key.identity)
277    }
278
279    /// Validate key with required scopes
280    pub fn validate_with_scopes(
281        &self,
282        key: &str,
283        required_scopes: &[&str],
284    ) -> Result<ApiKey, ApiKeyError> {
285        let api_key = self.validate(key)?;
286
287        for scope in required_scopes {
288            if !api_key.has_scope(scope) {
289                return Err(ApiKeyError::InsufficientScope((*scope).to_string()));
290            }
291        }
292
293        Ok(api_key)
294    }
295
296    /// Validate key with IP check
297    pub fn validate_with_ip(
298        &self,
299        key: &str,
300        client_ip: &std::net::IpAddr,
301    ) -> Result<ApiKey, ApiKeyError> {
302        let api_key = self.validate(key)?;
303
304        if !api_key.is_ip_allowed(client_ip) {
305            return Err(ApiKeyError::InsufficientScope("IP not allowed".to_string()));
306        }
307
308        Ok(api_key)
309    }
310
311    /// Revoke an API key
312    pub fn revoke(&self, key_id: &str) -> Result<(), ApiKeyError> {
313        let mut keys = self.keys_by_id.write();
314        let api_key = keys.get_mut(key_id).ok_or(ApiKeyError::NotFound)?;
315
316        api_key.active = false;
317        Ok(())
318    }
319
320    /// Delete an API key
321    pub fn delete(&self, key_id: &str) -> Result<(), ApiKeyError> {
322        let api_key = self
323            .keys_by_id
324            .write()
325            .remove(key_id)
326            .ok_or(ApiKeyError::NotFound)?;
327
328        self.keys_by_hash.write().remove(&api_key.key_hash);
329        self.rate_limits.write().remove(key_id);
330
331        Ok(())
332    }
333
334    /// Get an API key by ID
335    pub fn get(&self, key_id: &str) -> Option<ApiKey> {
336        self.keys_by_id.read().get(key_id).cloned()
337    }
338
339    /// List all API keys for a user
340    pub fn list_by_user(&self, user_id: &str) -> Vec<ApiKey> {
341        self.keys_by_id
342            .read()
343            .values()
344            .filter(|k| k.identity.user_id == user_id)
345            .cloned()
346            .collect()
347    }
348
349    /// List all active API keys
350    pub fn list_active(&self) -> Vec<ApiKey> {
351        self.keys_by_id
352            .read()
353            .values()
354            .filter(|k| k.is_valid())
355            .cloned()
356            .collect()
357    }
358
359    /// Update API key metadata
360    pub fn update_metadata(
361        &self,
362        key_id: &str,
363        metadata: HashMap<String, String>,
364    ) -> Result<(), ApiKeyError> {
365        let mut keys = self.keys_by_id.write();
366        let api_key = keys.get_mut(key_id).ok_or(ApiKeyError::NotFound)?;
367
368        api_key.metadata.extend(metadata);
369        Ok(())
370    }
371
372    /// Update API key scopes
373    pub fn update_scopes(&self, key_id: &str, scopes: Vec<String>) -> Result<(), ApiKeyError> {
374        let mut keys = self.keys_by_id.write();
375        let api_key = keys.get_mut(key_id).ok_or(ApiKeyError::NotFound)?;
376
377        api_key.scopes = scopes;
378        Ok(())
379    }
380
381    /// Update API key allowed IPs
382    pub fn update_allowed_ips(
383        &self,
384        key_id: &str,
385        ips: Vec<std::net::IpAddr>,
386    ) -> Result<(), ApiKeyError> {
387        let mut keys = self.keys_by_id.write();
388        let api_key = keys.get_mut(key_id).ok_or(ApiKeyError::NotFound)?;
389
390        api_key.allowed_ips = ips;
391        Ok(())
392    }
393
394    /// Rotate an API key (generate new key value, same ID)
395    pub fn rotate(&self, key_id: &str) -> Result<String, ApiKeyError> {
396        let old_hash = {
397            let keys = self.keys_by_id.read();
398            let api_key = keys.get(key_id).ok_or(ApiKeyError::NotFound)?;
399            api_key.key_hash.clone()
400        };
401
402        // Generate new key value
403        let key_value = self.generate_random_key();
404        let full_key = format!("{}{}", self.key_prefix, key_value);
405        let new_hash = self.hash_key(&full_key);
406
407        // Update key
408        {
409            let mut keys = self.keys_by_id.write();
410            let api_key = keys.get_mut(key_id).ok_or(ApiKeyError::NotFound)?;
411            api_key.key_hash = new_hash.clone();
412        }
413
414        // Update hash lookup
415        {
416            let mut hashes = self.keys_by_hash.write();
417            hashes.remove(&old_hash);
418            hashes.insert(new_hash, key_id.to_string());
419        }
420
421        Ok(full_key)
422    }
423
424    /// Get key statistics
425    pub fn stats(&self) -> ApiKeyStats {
426        let keys = self.keys_by_id.read();
427        let total = keys.len();
428        let active = keys.values().filter(|k| k.active).count();
429        let expired = keys
430            .values()
431            .filter(|k| {
432                k.expires_at
433                    .map(|e| chrono::Utc::now() > e)
434                    .unwrap_or(false)
435            })
436            .count();
437
438        ApiKeyStats {
439            total,
440            active,
441            expired,
442            revoked: total - active - expired,
443        }
444    }
445
446    /// Cleanup expired keys
447    pub fn cleanup_expired(&self) {
448        let expired_ids: Vec<String> = self
449            .keys_by_id
450            .read()
451            .iter()
452            .filter(|(_, k)| {
453                k.expires_at
454                    .map(|e| chrono::Utc::now() > e)
455                    .unwrap_or(false)
456            })
457            .map(|(id, _)| id.clone())
458            .collect();
459
460        for id in expired_ids {
461            let _ = self.delete(&id);
462        }
463    }
464
465    /// Check rate limit for a key
466    fn check_rate_limit(&self, key_id: &str, limit: u32) -> bool {
467        let mut limits = self.rate_limits.write();
468        let state = limits
469            .entry(key_id.to_string())
470            .or_insert_with(RateLimitState::new);
471        state.check_and_increment(limit)
472    }
473
474    /// Generate a random key value
475    fn generate_random_key(&self) -> String {
476        use std::collections::hash_map::RandomState;
477        use std::hash::{BuildHasher, Hasher};
478
479        let mut hasher = RandomState::new().build_hasher();
480        hasher.write_u128(
481            std::time::SystemTime::now()
482                .duration_since(std::time::UNIX_EPOCH)
483                .unwrap()
484                .as_nanos(),
485        );
486        hasher.write_usize(std::process::id() as usize);
487
488        let hash1 = hasher.finish();
489
490        hasher.write_u64(hash1);
491        let hash2 = hasher.finish();
492
493        format!("{:016x}{:016x}", hash1, hash2)
494    }
495
496    /// Generate a key ID
497    fn generate_key_id(&self) -> String {
498        use std::collections::hash_map::RandomState;
499        use std::hash::{BuildHasher, Hasher};
500
501        let mut hasher = RandomState::new().build_hasher();
502        hasher.write_u128(
503            std::time::SystemTime::now()
504                .duration_since(std::time::UNIX_EPOCH)
505                .unwrap()
506                .as_nanos(),
507        );
508
509        format!("key_{:016x}", hasher.finish())
510    }
511
512    /// Hash a key value
513    fn hash_key(&self, key: &str) -> String {
514        use std::hash::{Hash, Hasher};
515        let mut hasher = std::collections::hash_map::DefaultHasher::new();
516        key.hash(&mut hasher);
517
518        // In production, use a cryptographic hash like SHA-256
519        format!("{:016x}", hasher.finish())
520    }
521
522    /// Get the key prefix
523    pub fn key_prefix(&self) -> &str {
524        &self.key_prefix
525    }
526}
527
/// Aggregate key counts, as returned by `ApiKeyManager::stats`.
///
/// Derives `Default`, `PartialEq` and `Eq` so snapshots can be compared
/// (e.g. in tests or change detection) and built incrementally.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ApiKeyStats {
    /// Total number of keys
    pub total: usize,

    /// Number of active keys
    pub active: usize,

    /// Number of expired keys
    pub expired: usize,

    /// Number of revoked keys
    pub revoked: usize,
}
543
544/// API key builder
545pub struct ApiKeyBuilder {
546    identity: Identity,
547    name: String,
548    scopes: Vec<String>,
549    expires_in: Option<Duration>,
550    rate_limit: Option<u32>,
551    allowed_ips: Vec<std::net::IpAddr>,
552    metadata: HashMap<String, String>,
553}
554
555impl ApiKeyBuilder {
556    /// Create a new builder
557    pub fn new(identity: Identity, name: impl Into<String>) -> Self {
558        Self {
559            identity,
560            name: name.into(),
561            scopes: Vec::new(),
562            expires_in: None,
563            rate_limit: None,
564            allowed_ips: Vec::new(),
565            metadata: HashMap::new(),
566        }
567    }
568
569    /// Add a scope
570    pub fn scope(mut self, scope: impl Into<String>) -> Self {
571        self.scopes.push(scope.into());
572        self
573    }
574
575    /// Add multiple scopes
576    pub fn scopes(mut self, scopes: Vec<String>) -> Self {
577        self.scopes.extend(scopes);
578        self
579    }
580
581    /// Set expiration
582    pub fn expires_in(mut self, duration: Duration) -> Self {
583        self.expires_in = Some(duration);
584        self
585    }
586
587    /// Set rate limit
588    pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
589        self.rate_limit = Some(requests_per_minute);
590        self
591    }
592
593    /// Add allowed IP
594    pub fn allow_ip(mut self, ip: std::net::IpAddr) -> Self {
595        self.allowed_ips.push(ip);
596        self
597    }
598
599    /// Add metadata
600    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
601        self.metadata.insert(key.into(), value.into());
602        self
603    }
604
605    /// Build the API key using the manager
606    pub fn build(self, manager: &ApiKeyManager) -> Result<(ApiKey, String), ApiKeyError> {
607        let (mut api_key, key_value) = manager.generate_key(
608            self.identity,
609            self.name,
610            self.scopes,
611            self.expires_in,
612            self.rate_limit,
613        )?;
614
615        api_key.allowed_ips = self.allowed_ips;
616        api_key.metadata = self.metadata;
617
618        // Update the stored key
619        manager
620            .keys_by_id
621            .write()
622            .insert(api_key.id.clone(), api_key.clone());
623
624        Ok((api_key, key_value))
625    }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager configuration shared by all tests; keys start with `hdb_test_`.
    fn test_config() -> ApiKeyConfig {
        ApiKeyConfig {
            header_name: "X-API-Key".to_string(),
            query_param: Some("api_key".to_string()),
            prefix: Some("hdb_test_".to_string()),
            hash_algorithm: "sha256".to_string(),
        }
    }

    /// Identity for `user123`, authenticated via API key.
    fn test_identity() -> Identity {
        Identity {
            user_id: "user123".to_string(),
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "api_key".to_string(),
            authenticated_at: chrono::Utc::now(),
        }
    }

    /// Issue a key without a rate limit; panics if generation fails.
    fn issue(
        manager: &ApiKeyManager,
        identity: Identity,
        name: &str,
        scopes: &[&str],
        expires_in: Option<Duration>,
    ) -> (ApiKey, String) {
        let owned_scopes: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
        manager
            .generate_key(identity, name.to_string(), owned_scopes, expires_in, None)
            .unwrap()
    }

    #[test]
    fn test_generate_key() {
        let manager = ApiKeyManager::new(test_config());
        let (api_key, key_value) = issue(&manager, test_identity(), "Test Key", &["read"], None);

        assert!(key_value.starts_with("hdb_test_"));
        assert!(api_key.active);
        assert!(api_key.has_scope("read"));
    }

    #[test]
    fn test_validate_key() {
        let manager = ApiKeyManager::new(test_config());
        let (_, key_value) = issue(&manager, test_identity(), "Test Key", &["read"], None);

        let validated = manager.validate(&key_value).unwrap();
        assert_eq!(validated.identity.user_id, "user123");
    }

    #[test]
    fn test_validate_invalid_key() {
        let manager = ApiKeyManager::new(test_config());
        let outcome = manager.validate("hdb_test_invalid");
        assert!(matches!(outcome, Err(ApiKeyError::NotFound)));
    }

    #[test]
    fn test_revoke_key() {
        let manager = ApiKeyManager::new(test_config());
        let (api_key, key_value) = issue(&manager, test_identity(), "Test Key", &["read"], None);

        manager.revoke(&api_key.id).unwrap();

        let outcome = manager.validate(&key_value);
        assert!(matches!(outcome, Err(ApiKeyError::Revoked)));
    }

    #[test]
    fn test_key_expiration() {
        let manager = ApiKeyManager::new(test_config());
        // A zero-length lifetime: the key is expired as soon as time moves on.
        let (_, key_value) = issue(
            &manager,
            test_identity(),
            "Test Key",
            &["read"],
            Some(Duration::from_secs(0)),
        );

        // Let the clock advance past the expiry instant.
        std::thread::sleep(Duration::from_millis(10));

        let outcome = manager.validate(&key_value);
        assert!(matches!(outcome, Err(ApiKeyError::Expired)));
    }

    #[test]
    fn test_scope_validation() {
        let manager = ApiKeyManager::new(test_config());
        let (_, key_value) = issue(&manager, test_identity(), "Test Key", &["read"], None);

        // Granted scope passes...
        assert!(manager.validate_with_scopes(&key_value, &["read"]).is_ok());

        // ...an ungranted one is rejected.
        assert!(matches!(
            manager.validate_with_scopes(&key_value, &["write"]),
            Err(ApiKeyError::InsufficientScope(_))
        ));
    }

    #[test]
    fn test_list_by_user() {
        let manager = ApiKeyManager::new(test_config());

        let first_owner = test_identity();
        let mut second_owner = test_identity();
        second_owner.user_id = "user456".to_string();

        let _ = issue(&manager, first_owner, "Key 1", &[], None);
        let _ = issue(&manager, second_owner, "Key 2", &[], None);

        let owned = manager.list_by_user("user123");
        assert_eq!(owned.len(), 1);
    }

    #[test]
    fn test_key_stats() {
        let manager = ApiKeyManager::new(test_config());

        let (revoked_key, _) = issue(&manager, test_identity(), "Key 1", &[], None);
        let _ = issue(&manager, test_identity(), "Key 2", &[], None);

        manager.revoke(&revoked_key.id).unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_rotate_key() {
        let manager = ApiKeyManager::new(test_config());
        let (api_key, old_key) = issue(&manager, test_identity(), "Test Key", &["read"], None);

        let new_key = manager.rotate(&api_key.id).unwrap();

        // The previous secret is dead; the new one works.
        assert!(manager.validate(&old_key).is_err());
        assert!(manager.validate(&new_key).is_ok());
    }

    #[test]
    fn test_api_key_builder() {
        let manager = ApiKeyManager::new(test_config());

        let (api_key, key_value) = ApiKeyBuilder::new(test_identity(), "Builder Key")
            .scope("read")
            .scope("write")
            .rate_limit(100)
            .expires_in(Duration::from_secs(3600))
            .metadata("env", "test")
            .build(&manager)
            .unwrap();

        assert!(key_value.starts_with("hdb_test_"));
        assert!(api_key.has_scope("read"));
        assert!(api_key.has_scope("write"));
        assert_eq!(api_key.rate_limit, Some(100));
        assert_eq!(api_key.metadata.get("env"), Some(&"test".to_string()));
    }
}