blueprint_auth/
api_keys.rs

1//! Long-lived API key management
2//!
3//! API keys are the primary authentication mechanism for services.
//! They follow the format `<key_id>.<secret>` where:
//!
//! - `key_id` is the public identifier: a prefix (`ak_` by default, configurable on the
//!   generator) followed by 8 random bytes, base64-encoded (see `CUSTOM_ENGINE`)
//! - `secret` is the secret part: 24 further random bytes, base64-encoded
9//!
10//! # Storage
11//!
12//! Keys are stored in RocksDB with:
13//! - The full key is never stored, only a hash
14//! - Key metadata includes service ID, creation time, etc.
15//! - Supports expiration and enabled/disabled states
16//!
17//! # Usage
18//!
19//! API keys are exchanged for short-lived access tokens:
20//!
21//! ```text
22//! // Client sends API key
23//! Authorization: Bearer ak_2n4f8.w9x7y6z5a4b3c2d1
24//!
25//! // Server validates and returns access token
26//! {
27//!   "access_token": "v4.local.xxxxx",
28//!   "expires_in": 900
29//! }
30//! ```
31
32use base64::Engine;
33use blueprint_std::rand::{CryptoRng, RngCore};
34use prost::Message;
35use std::collections::BTreeMap;
36use std::time::{SystemTime, UNIX_EPOCH};
37
38use crate::{
39    Error,
40    api_tokens::CUSTOM_ENGINE,
41    db::{RocksDb, cf},
42    types::ServiceId,
43};
44
/// Long-lived API key record as persisted in the database.
///
/// Encoded with prost (protobuf) and stored in the `cf::API_KEYS_CF` column family under
/// its `key_id`; a secondary `cf::API_KEYS_BY_ID_CF` entry maps `id` back to `key_id`.
///
/// NOTE(review): the `#[prost(...)]` attributes carry no explicit `tag`, so prost
/// presumably numbers fields by declaration order — do not reorder or insert fields
/// without pinning tags first, or previously stored records may decode incorrectly.
#[derive(prost::Message, Clone)]
pub struct ApiKeyModel {
    /// Unique database ID (0 until the record is first saved, then assigned from the
    /// `api_keys` sequence)
    #[prost(uint64)]
    pub id: u64,
    /// Public API key identifier, including its prefix (e.g., "ak_2n4f8...")
    #[prost(string)]
    pub key_id: String,
    /// Keccak-256 hash of the full API key (`key_id.secret`), base64-encoded;
    /// the plaintext secret is never stored
    #[prost(string)]
    pub key_hash: String,
    /// Service ID this key belongs to
    #[prost(uint64)]
    pub service_id: u64,
    /// Sub-service ID (zero means no sub-service)
    #[prost(uint64)]
    pub sub_service_id: u64,
    /// When this key was created (seconds since the Unix epoch)
    #[prost(uint64)]
    pub created_at: u64,
    /// When this key was last used (seconds since the Unix epoch; 0 if never used)
    #[prost(uint64)]
    pub last_used: u64,
    /// When this key expires (seconds since the Unix epoch); the key counts as expired
    /// once the current time reaches this value
    #[prost(uint64)]
    pub expires_at: u64,
    /// Whether this key is enabled
    #[prost(bool)]
    pub is_enabled: bool,
    /// Default headers to include in access tokens, stored as a JSON-encoded
    /// `BTreeMap<String, String>` (empty bytes mean no headers)
    #[prost(bytes)]
    pub default_headers: Vec<u8>,
    /// Human-readable description
    #[prost(string)]
    pub description: String,
}
82
83/// Generated API Key containing both public and secret parts
84#[derive(Debug, Clone)]
85pub struct GeneratedApiKey {
86    /// Public key identifier
87    pub key_id: String,
88    /// Full API key (key_id + secret)
89    pub full_key: String,
90    /// Service ID this key is for
91    pub service_id: ServiceId,
92    /// Expiration timestamp
93    pub expires_at: u64,
94    /// Default headers
95    pub default_headers: BTreeMap<String, String>,
96}
97
/// Generator for new API keys of the form `<prefix><encoded id bytes>.<encoded secret>`.
///
/// Holds only the public prefix, so deriving `Debug` exposes no secret material.
#[derive(Debug, Clone)]
pub struct ApiKeyGenerator {
    /// Prefix prepended to every generated key id (`"ak_"` for the default generator)
    prefix: String,
}
102
103impl Default for ApiKeyGenerator {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl ApiKeyGenerator {
110    /// Create new generator with default prefix "ak_"
111    pub fn new() -> Self {
112        Self {
113            prefix: "ak_".to_string(),
114        }
115    }
116
117    /// Create generator with custom prefix
118    pub fn with_prefix(prefix: &str) -> Self {
119        Self {
120            prefix: prefix.to_string(),
121        }
122    }
123
124    /// Generate a new API key
125    pub fn generate_key<R: RngCore + CryptoRng>(
126        &self,
127        service_id: ServiceId,
128        expires_at: u64,
129        default_headers: BTreeMap<String, String>,
130        rng: &mut R,
131    ) -> GeneratedApiKey {
132        // Generate 32 bytes of randomness
133        let mut secret_bytes = vec![0u8; 32];
134        rng.fill_bytes(&mut secret_bytes);
135
136        // Create key identifier (first 8 bytes, base64 encoded)
137        let key_id = format!(
138            "{}{}",
139            self.prefix,
140            CUSTOM_ENGINE.encode(&secret_bytes[..8])
141        );
142
143        // Full key is key_id + "." + remaining 24 bytes base64 encoded
144        let secret_part = CUSTOM_ENGINE.encode(&secret_bytes[8..]);
145        let full_key = format!("{key_id}.{secret_part}");
146
147        GeneratedApiKey {
148            key_id,
149            full_key,
150            service_id,
151            expires_at,
152            default_headers,
153        }
154    }
155}
156
157impl GeneratedApiKey {
158    /// Get the public key identifier
159    pub fn key_id(&self) -> &str {
160        &self.key_id
161    }
162
163    /// Get the full API key (to be shared with client)
164    pub fn full_key(&self) -> &str {
165        &self.full_key
166    }
167
168    /// Get expiration time
169    pub fn expires_at(&self) -> u64 {
170        self.expires_at
171    }
172
173    /// Get default headers
174    pub fn default_headers(&self) -> &BTreeMap<String, String> {
175        &self.default_headers
176    }
177}
178
179impl ApiKeyModel {
180    /// Find API key by key_id
181    pub fn find_by_key_id(key_id: &str, db: &RocksDb) -> Result<Option<Self>, Error> {
182        let cf = db
183            .cf_handle(cf::API_KEYS_CF)
184            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
185
186        if let Some(key_bytes) = db.get_pinned_cf(&cf, key_id.as_bytes())? {
187            let model = Self::decode(key_bytes.as_ref())?;
188            Ok(Some(model))
189        } else {
190            Ok(None)
191        }
192    }
193
194    /// Find API key by database ID
195    pub fn find_by_id(id: u64, db: &RocksDb) -> Result<Option<Self>, Error> {
196        let cf = db
197            .cf_handle(cf::API_KEYS_BY_ID_CF)
198            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
199
200        if let Some(key_id_bytes) = db.get_pinned_cf(&cf, id.to_be_bytes())? {
201            let key_id =
202                String::from_utf8(key_id_bytes.to_vec()).map_err(|_| Error::UnknownKeyType)?; // Reusing error type
203            Self::find_by_key_id(&key_id, db)
204        } else {
205            Ok(None)
206        }
207    }
208
209    /// Validate if the given full key matches this stored key
210    pub fn validates_key(&self, full_key: &str) -> bool {
211        // Parse full key: "ak_xxxx.yyyy"
212        if let Some((key_id_part, _)) = full_key.split_once('.') {
213            if key_id_part != self.key_id {
214                return false;
215            }
216
217            // Hash the full key and compare
218            use tiny_keccak::Hasher;
219            let mut hasher = tiny_keccak::Keccak::v256();
220            hasher.update(full_key.as_bytes());
221            let mut output = [0u8; 32];
222            hasher.finalize(&mut output);
223            let computed_hash = CUSTOM_ENGINE.encode(output);
224
225            self.key_hash == computed_hash
226        } else {
227            false
228        }
229    }
230
231    /// Check if key is expired
232    pub fn is_expired(&self) -> bool {
233        let now = SystemTime::now()
234            .duration_since(UNIX_EPOCH)
235            .unwrap_or_default()
236            .as_secs();
237        now >= self.expires_at
238    }
239
240    /// Update last used timestamp
241    pub fn update_last_used(&mut self, db: &RocksDb) -> Result<(), Error> {
242        let now = SystemTime::now()
243            .duration_since(UNIX_EPOCH)
244            .unwrap_or_default()
245            .as_secs();
246        self.last_used = now;
247        self.save(db)
248    }
249
250    /// Get default headers as BTreeMap
251    pub fn get_default_headers(&self) -> BTreeMap<String, String> {
252        if self.default_headers.is_empty() {
253            BTreeMap::new()
254        } else {
255            serde_json::from_slice(&self.default_headers).unwrap_or_default()
256        }
257    }
258
259    /// Set default headers from BTreeMap
260    pub fn set_default_headers(&mut self, headers: &BTreeMap<String, String>) {
261        self.default_headers = serde_json::to_vec(headers).unwrap_or_default();
262    }
263
264    /// Get service ID
265    pub fn service_id(&self) -> ServiceId {
266        ServiceId::new(self.service_id).with_subservice(self.sub_service_id)
267    }
268
269    /// Save to database
270    pub fn save(&mut self, db: &RocksDb) -> Result<(), Error> {
271        let keys_cf = db
272            .cf_handle(cf::API_KEYS_CF)
273            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
274        let ids_cf = db
275            .cf_handle(cf::API_KEYS_BY_ID_CF)
276            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
277
278        if self.id == 0 {
279            self.create(db)
280        } else {
281            // Update existing
282            let key_bytes = self.encode_to_vec();
283            db.put_cf(&keys_cf, self.key_id.as_bytes(), key_bytes)?;
284            db.put_cf(&ids_cf, self.id.to_be_bytes(), self.key_id.as_bytes())?;
285            Ok(())
286        }
287    }
288
289    fn create(&mut self, db: &RocksDb) -> Result<(), Error> {
290        let keys_cf = db
291            .cf_handle(cf::API_KEYS_CF)
292            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
293        let ids_cf = db
294            .cf_handle(cf::API_KEYS_BY_ID_CF)
295            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
296        let seq_cf = db
297            .cf_handle(cf::SEQ_CF)
298            .ok_or(Error::UnknownColumnFamily(cf::SEQ_CF))?;
299
300        let txn = db.transaction();
301
302        // Increment sequence
303        let mut retry_count = 0;
304        let max_retries = 10;
305        loop {
306            let result = txn.merge_cf(&seq_cf, b"api_keys", 1u64.to_be_bytes());
307            match result {
308                Ok(()) => break,
309                Err(e)
310                    if matches!(
311                        e.kind(),
312                        rocksdb::ErrorKind::Busy | rocksdb::ErrorKind::TryAgain
313                    ) =>
314                {
315                    retry_count += 1;
316                    if retry_count >= max_retries {
317                        return Err(Error::RocksDB(e));
318                    }
319                }
320                Err(e) => return Err(Error::RocksDB(e)),
321            }
322        }
323
324        let next_id = txn
325            .get_cf(&seq_cf, b"api_keys")?
326            .map(|v| {
327                let mut id = [0u8; 8];
328                id.copy_from_slice(&v);
329                u64::from_be_bytes(id)
330            })
331            .unwrap_or(1u64);
332
333        self.id = next_id;
334        let key_bytes = self.encode_to_vec();
335        txn.put_cf(&keys_cf, self.key_id.as_bytes(), key_bytes)?;
336        txn.put_cf(&ids_cf, next_id.to_be_bytes(), self.key_id.as_bytes())?;
337
338        txn.commit()?;
339        Ok(())
340    }
341
342    /// Delete from database
343    pub fn delete(&self, db: &RocksDb) -> Result<(), Error> {
344        let keys_cf = db
345            .cf_handle(cf::API_KEYS_CF)
346            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
347        let ids_cf = db
348            .cf_handle(cf::API_KEYS_BY_ID_CF)
349            .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
350
351        let txn = db.transaction();
352        txn.delete_cf(&keys_cf, self.key_id.as_bytes())?;
353        txn.delete_cf(&ids_cf, self.id.to_be_bytes())?;
354        txn.commit()?;
355        Ok(())
356    }
357}
358
359impl From<&GeneratedApiKey> for ApiKeyModel {
360    fn from(key: &GeneratedApiKey) -> Self {
361        use tiny_keccak::Hasher;
362
363        let now = SystemTime::now()
364            .duration_since(UNIX_EPOCH)
365            .unwrap_or_default()
366            .as_secs();
367
368        // Hash the full key
369        let mut hasher = tiny_keccak::Keccak::v256();
370        hasher.update(key.full_key.as_bytes());
371        let mut output = [0u8; 32];
372        hasher.finalize(&mut output);
373        let key_hash = CUSTOM_ENGINE.encode(output);
374
375        let mut model = Self {
376            id: 0,
377            key_id: key.key_id.clone(),
378            key_hash,
379            service_id: key.service_id.0,
380            sub_service_id: key.service_id.1,
381            created_at: now,
382            last_used: 0,
383            expires_at: key.expires_at,
384            is_enabled: true,
385            default_headers: Vec::new(),
386            description: String::new(),
387        };
388
389        model.set_default_headers(&key.default_headers);
390        model
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ServiceId;
    use tempfile::tempdir;

    #[test]
    fn test_api_key_generation() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let generator = ApiKeyGenerator::new();
        let service_id = ServiceId::new(1);
        let expires_at = 1234567890;
        let mut headers = BTreeMap::new();
        headers.insert("X-Tenant-Id".to_string(), "tenant123".to_string());

        let key = generator.generate_key(service_id, expires_at, headers.clone(), &mut rng);

        assert!(key.key_id().starts_with("ak_"));
        assert!(key.full_key().contains('.'));
        assert_eq!(key.expires_at(), expires_at);
        assert_eq!(key.default_headers(), &headers);

        // Should have format: ak_xxxx.yyyy
        let parts: Vec<&str> = key.full_key().split('.').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0], key.key_id());
    }

    #[test]
    fn test_api_key_validation() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let generator = ApiKeyGenerator::new();

        let key = generator.generate_key(ServiceId::new(1), 1234567890, BTreeMap::new(), &mut rng);

        let model = ApiKeyModel::from(&key);

        // Should validate the exact key it was built from
        assert!(model.validates_key(key.full_key()));

        // A different key id must be rejected ("ak_" always contains an 'a', so this
        // replacement always changes the key id part)
        let wrong_key = key.full_key().replace('a', "b");
        assert!(!model.validates_key(&wrong_key));

        // Same key id but a tampered secret must be rejected by the hash comparison
        let (key_id_part, secret) = key.full_key().split_once('.').unwrap();
        let mut secret_chars: Vec<char> = secret.chars().collect();
        let last = secret_chars.last_mut().unwrap();
        *last = if *last == 'A' { 'B' } else { 'A' };
        let tampered_secret: String = secret_chars.into_iter().collect();
        let tampered_key = format!("{key_id_part}.{tampered_secret}");
        assert_ne!(tampered_key, key.full_key());
        assert!(!model.validates_key(&tampered_key));

        // Should not validate malformed keys
        assert!(!model.validates_key(""));
        assert!(!model.validates_key("invalid"));
        assert!(!model.validates_key("ak_test"));
        assert!(!model.validates_key(key.key_id()));
    }

    #[test]
    fn test_api_key_database_operations() {
        let tmp_dir = tempdir().unwrap();
        let db_config = crate::db::RocksDbConfig::default();
        let db = RocksDb::open(tmp_dir.path(), &db_config).unwrap();
        let mut rng = blueprint_std::BlueprintRng::new();
        let generator = ApiKeyGenerator::new();

        let key = generator.generate_key(ServiceId::new(1), 1234567890, BTreeMap::new(), &mut rng);

        let mut model = ApiKeyModel::from(&key);

        // Save should assign ID
        model.save(&db).unwrap();
        assert_ne!(model.id, 0);

        // Should be able to find by key_id
        let found = ApiKeyModel::find_by_key_id(key.key_id(), &db)
            .unwrap()
            .unwrap();
        assert_eq!(found.key_id, model.key_id);
        assert_eq!(found.id, model.id);

        // Should be able to find by ID
        let found_by_id = ApiKeyModel::find_by_id(model.id, &db).unwrap().unwrap();
        assert_eq!(found_by_id.key_id, model.key_id);

        // Should validate the key
        assert!(found.validates_key(key.full_key()));

        // Saving again updates in place: same id, changes persisted
        let original_id = model.id;
        model.description = "updated".to_string();
        model.save(&db).unwrap();
        assert_eq!(model.id, original_id);
        let reloaded = ApiKeyModel::find_by_id(model.id, &db).unwrap().unwrap();
        assert_eq!(reloaded.description, "updated");

        // A second key receives a distinct id from the sequence
        let other_key =
            generator.generate_key(ServiceId::new(2), 1234567890, BTreeMap::new(), &mut rng);
        let mut other = ApiKeyModel::from(&other_key);
        other.save(&db).unwrap();
        assert_ne!(other.id, 0);
        assert_ne!(other.id, model.id);

        // Delete removes both the record and its id index entry
        model.delete(&db).unwrap();
        assert!(
            ApiKeyModel::find_by_key_id(key.key_id(), &db)
                .unwrap()
                .is_none()
        );
        assert!(ApiKeyModel::find_by_id(model.id, &db).unwrap().is_none());

        // Other keys are unaffected by the delete
        assert!(
            ApiKeyModel::find_by_key_id(other_key.key_id(), &db)
                .unwrap()
                .is_some()
        );
    }

    #[test]
    fn test_expiration() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let mut model = ApiKeyModel {
            id: 1,
            key_id: "ak_test".to_string(),
            key_hash: "hash".to_string(),
            service_id: 1,
            sub_service_id: 0,
            created_at: now,
            last_used: 0,
            expires_at: now - 1, // Expired
            is_enabled: true,
            default_headers: Vec::new(),
            description: "Test".to_string(),
        };

        assert!(model.is_expired());

        model.expires_at = now + 3600; // Not expired
        assert!(!model.is_expired());
    }

    #[test]
    fn test_headers_serialization() {
        let mut model = ApiKeyModel {
            id: 1,
            key_id: "ak_test".to_string(),
            key_hash: "hash".to_string(),
            service_id: 1,
            sub_service_id: 0,
            created_at: 0,
            last_used: 0,
            expires_at: 0,
            is_enabled: true,
            default_headers: Vec::new(),
            description: "Test".to_string(),
        };

        // No stored bytes means no headers
        assert!(model.get_default_headers().is_empty());

        let mut headers = BTreeMap::new();
        headers.insert("X-Test".to_string(), "value".to_string());

        model.set_default_headers(&headers);
        let retrieved = model.get_default_headers();

        assert_eq!(retrieved, headers);

        // Empty maps round-trip to empty maps
        model.set_default_headers(&BTreeMap::new());
        assert!(model.get_default_headers().is_empty());

        // Undecodable bytes are treated as "no headers" rather than an error
        model.default_headers = b"not json".to_vec();
        assert!(model.get_default_headers().is_empty());
    }
}