1use base64::Engine;
33use blueprint_std::rand::{CryptoRng, RngCore};
34use prost::Message;
35use std::collections::BTreeMap;
36use std::time::{SystemTime, UNIX_EPOCH};
37
38use crate::{
39 Error,
40 api_tokens::CUSTOM_ENGINE,
41 db::{RocksDb, cf},
42 types::ServiceId,
43};
44
/// Persisted record for an API key, protobuf-encoded with `prost` and stored in RocksDB
/// (keyed by `key_id` in `API_KEYS_CF`, with an `id -> key_id` index in `API_KEYS_BY_ID_CF`).
///
/// No field declares an explicit `tag`, so prost numbers them by declaration order (1..=11).
/// Because records are persisted, reordering fields or inserting one in the middle changes the
/// wire format of already-stored data; append new fields at the end instead.
#[derive(prost::Message, Clone)]
pub struct ApiKeyModel {
    /// Sequential id taken from the `api_keys` counter in `SEQ_CF` on first save;
    /// `0` means the record has not been saved yet.
    #[prost(uint64)]
    pub id: u64,
    /// Public key identifier; also the RocksDB key under `API_KEYS_CF`.
    #[prost(string)]
    pub key_id: String,
    /// `CUSTOM_ENGINE` encoding of the Keccak-256 digest of the full key; the plaintext key
    /// itself is never stored.
    #[prost(string)]
    pub key_hash: String,
    /// Service id; combined with `sub_service_id` by `service_id()`.
    #[prost(uint64)]
    pub service_id: u64,
    /// Sub-service id; combined with `service_id` by `service_id()`.
    #[prost(uint64)]
    pub sub_service_id: u64,
    /// Creation time in seconds since the Unix epoch.
    #[prost(uint64)]
    pub created_at: u64,
    /// Last-use time in seconds since the Unix epoch, set by `update_last_used`;
    /// `0` until the key is first used.
    #[prost(uint64)]
    pub last_used: u64,
    /// Expiry time in seconds since the Unix epoch; `is_expired` reports true once the
    /// current time reaches this value.
    #[prost(uint64)]
    pub expires_at: u64,
    /// Whether the key is enabled.
    /// NOTE(review): neither `validates_key` nor `is_expired` reads this flag; callers must check it.
    #[prost(bool)]
    pub is_enabled: bool,
    /// JSON-encoded `BTreeMap<String, String>`; access via `get_default_headers` /
    /// `set_default_headers` rather than directly.
    #[prost(bytes)]
    pub default_headers: Vec<u8>,
    /// Free-form description; empty for models built from a `GeneratedApiKey`.
    #[prost(string)]
    pub description: String,
}
82
83#[derive(Debug, Clone)]
85pub struct GeneratedApiKey {
86 pub key_id: String,
88 pub full_key: String,
90 pub service_id: ServiceId,
92 pub expires_at: u64,
94 pub default_headers: BTreeMap<String, String>,
96}
97
/// Produces new random API keys whose ids start with a configurable prefix.
#[derive(Debug, Clone)]
pub struct ApiKeyGenerator {
    /// Prepended to every generated `key_id` (`"ak_"` for `ApiKeyGenerator::new`).
    prefix: String,
}
102
103impl Default for ApiKeyGenerator {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl ApiKeyGenerator {
110 pub fn new() -> Self {
112 Self {
113 prefix: "ak_".to_string(),
114 }
115 }
116
117 pub fn with_prefix(prefix: &str) -> Self {
119 Self {
120 prefix: prefix.to_string(),
121 }
122 }
123
124 pub fn generate_key<R: RngCore + CryptoRng>(
126 &self,
127 service_id: ServiceId,
128 expires_at: u64,
129 default_headers: BTreeMap<String, String>,
130 rng: &mut R,
131 ) -> GeneratedApiKey {
132 let mut secret_bytes = vec![0u8; 32];
134 rng.fill_bytes(&mut secret_bytes);
135
136 let key_id = format!(
138 "{}{}",
139 self.prefix,
140 CUSTOM_ENGINE.encode(&secret_bytes[..8])
141 );
142
143 let secret_part = CUSTOM_ENGINE.encode(&secret_bytes[8..]);
145 let full_key = format!("{key_id}.{secret_part}");
146
147 GeneratedApiKey {
148 key_id,
149 full_key,
150 service_id,
151 expires_at,
152 default_headers,
153 }
154 }
155}
156
157impl GeneratedApiKey {
158 pub fn key_id(&self) -> &str {
160 &self.key_id
161 }
162
163 pub fn full_key(&self) -> &str {
165 &self.full_key
166 }
167
168 pub fn expires_at(&self) -> u64 {
170 self.expires_at
171 }
172
173 pub fn default_headers(&self) -> &BTreeMap<String, String> {
175 &self.default_headers
176 }
177}
178
179impl ApiKeyModel {
180 pub fn find_by_key_id(key_id: &str, db: &RocksDb) -> Result<Option<Self>, Error> {
182 let cf = db
183 .cf_handle(cf::API_KEYS_CF)
184 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
185
186 if let Some(key_bytes) = db.get_pinned_cf(&cf, key_id.as_bytes())? {
187 let model = Self::decode(key_bytes.as_ref())?;
188 Ok(Some(model))
189 } else {
190 Ok(None)
191 }
192 }
193
194 pub fn find_by_id(id: u64, db: &RocksDb) -> Result<Option<Self>, Error> {
196 let cf = db
197 .cf_handle(cf::API_KEYS_BY_ID_CF)
198 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
199
200 if let Some(key_id_bytes) = db.get_pinned_cf(&cf, id.to_be_bytes())? {
201 let key_id =
202 String::from_utf8(key_id_bytes.to_vec()).map_err(|_| Error::UnknownKeyType)?; Self::find_by_key_id(&key_id, db)
204 } else {
205 Ok(None)
206 }
207 }
208
209 pub fn validates_key(&self, full_key: &str) -> bool {
211 if let Some((key_id_part, _)) = full_key.split_once('.') {
213 if key_id_part != self.key_id {
214 return false;
215 }
216
217 use tiny_keccak::Hasher;
219 let mut hasher = tiny_keccak::Keccak::v256();
220 hasher.update(full_key.as_bytes());
221 let mut output = [0u8; 32];
222 hasher.finalize(&mut output);
223 let computed_hash = CUSTOM_ENGINE.encode(output);
224
225 self.key_hash == computed_hash
226 } else {
227 false
228 }
229 }
230
231 pub fn is_expired(&self) -> bool {
233 let now = SystemTime::now()
234 .duration_since(UNIX_EPOCH)
235 .unwrap_or_default()
236 .as_secs();
237 now >= self.expires_at
238 }
239
240 pub fn update_last_used(&mut self, db: &RocksDb) -> Result<(), Error> {
242 let now = SystemTime::now()
243 .duration_since(UNIX_EPOCH)
244 .unwrap_or_default()
245 .as_secs();
246 self.last_used = now;
247 self.save(db)
248 }
249
250 pub fn get_default_headers(&self) -> BTreeMap<String, String> {
252 if self.default_headers.is_empty() {
253 BTreeMap::new()
254 } else {
255 serde_json::from_slice(&self.default_headers).unwrap_or_default()
256 }
257 }
258
259 pub fn set_default_headers(&mut self, headers: &BTreeMap<String, String>) {
261 self.default_headers = serde_json::to_vec(headers).unwrap_or_default();
262 }
263
264 pub fn service_id(&self) -> ServiceId {
266 ServiceId::new(self.service_id).with_subservice(self.sub_service_id)
267 }
268
269 pub fn save(&mut self, db: &RocksDb) -> Result<(), Error> {
271 let keys_cf = db
272 .cf_handle(cf::API_KEYS_CF)
273 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
274 let ids_cf = db
275 .cf_handle(cf::API_KEYS_BY_ID_CF)
276 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
277
278 if self.id == 0 {
279 self.create(db)
280 } else {
281 let key_bytes = self.encode_to_vec();
283 db.put_cf(&keys_cf, self.key_id.as_bytes(), key_bytes)?;
284 db.put_cf(&ids_cf, self.id.to_be_bytes(), self.key_id.as_bytes())?;
285 Ok(())
286 }
287 }
288
289 fn create(&mut self, db: &RocksDb) -> Result<(), Error> {
290 let keys_cf = db
291 .cf_handle(cf::API_KEYS_CF)
292 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
293 let ids_cf = db
294 .cf_handle(cf::API_KEYS_BY_ID_CF)
295 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
296 let seq_cf = db
297 .cf_handle(cf::SEQ_CF)
298 .ok_or(Error::UnknownColumnFamily(cf::SEQ_CF))?;
299
300 let txn = db.transaction();
301
302 let mut retry_count = 0;
304 let max_retries = 10;
305 loop {
306 let result = txn.merge_cf(&seq_cf, b"api_keys", 1u64.to_be_bytes());
307 match result {
308 Ok(()) => break,
309 Err(e)
310 if matches!(
311 e.kind(),
312 rocksdb::ErrorKind::Busy | rocksdb::ErrorKind::TryAgain
313 ) =>
314 {
315 retry_count += 1;
316 if retry_count >= max_retries {
317 return Err(Error::RocksDB(e));
318 }
319 }
320 Err(e) => return Err(Error::RocksDB(e)),
321 }
322 }
323
324 let next_id = txn
325 .get_cf(&seq_cf, b"api_keys")?
326 .map(|v| {
327 let mut id = [0u8; 8];
328 id.copy_from_slice(&v);
329 u64::from_be_bytes(id)
330 })
331 .unwrap_or(1u64);
332
333 self.id = next_id;
334 let key_bytes = self.encode_to_vec();
335 txn.put_cf(&keys_cf, self.key_id.as_bytes(), key_bytes)?;
336 txn.put_cf(&ids_cf, next_id.to_be_bytes(), self.key_id.as_bytes())?;
337
338 txn.commit()?;
339 Ok(())
340 }
341
342 pub fn delete(&self, db: &RocksDb) -> Result<(), Error> {
344 let keys_cf = db
345 .cf_handle(cf::API_KEYS_CF)
346 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_CF))?;
347 let ids_cf = db
348 .cf_handle(cf::API_KEYS_BY_ID_CF)
349 .ok_or(Error::UnknownColumnFamily(cf::API_KEYS_BY_ID_CF))?;
350
351 let txn = db.transaction();
352 txn.delete_cf(&keys_cf, self.key_id.as_bytes())?;
353 txn.delete_cf(&ids_cf, self.id.to_be_bytes())?;
354 txn.commit()?;
355 Ok(())
356 }
357}
358
359impl From<&GeneratedApiKey> for ApiKeyModel {
360 fn from(key: &GeneratedApiKey) -> Self {
361 use tiny_keccak::Hasher;
362
363 let now = SystemTime::now()
364 .duration_since(UNIX_EPOCH)
365 .unwrap_or_default()
366 .as_secs();
367
368 let mut hasher = tiny_keccak::Keccak::v256();
370 hasher.update(key.full_key.as_bytes());
371 let mut output = [0u8; 32];
372 hasher.finalize(&mut output);
373 let key_hash = CUSTOM_ENGINE.encode(output);
374
375 let mut model = Self {
376 id: 0,
377 key_id: key.key_id.clone(),
378 key_hash,
379 service_id: key.service_id.0,
380 sub_service_id: key.service_id.1,
381 created_at: now,
382 last_used: 0,
383 expires_at: key.expires_at,
384 is_enabled: true,
385 default_headers: Vec::new(),
386 description: String::new(),
387 };
388
389 model.set_default_headers(&key.default_headers);
390 model
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ServiceId;
    use tempfile::tempdir;

    /// In-memory model for tests that never touch the database.
    fn model_fixture(created_at: u64, expires_at: u64) -> ApiKeyModel {
        ApiKeyModel {
            id: 1,
            key_id: "ak_test".to_string(),
            key_hash: "hash".to_string(),
            service_id: 1,
            sub_service_id: 0,
            created_at,
            last_used: 0,
            expires_at,
            is_enabled: true,
            default_headers: Vec::new(),
            description: "Test".to_string(),
        }
    }

    #[test]
    fn test_api_key_generation() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let expiry = 1234567890;
        let mut expected_headers = BTreeMap::new();
        expected_headers.insert("X-Tenant-Id".to_string(), "tenant123".to_string());

        let generated = ApiKeyGenerator::new().generate_key(
            ServiceId::new(1),
            expiry,
            expected_headers.clone(),
            &mut rng,
        );

        assert!(generated.key_id().starts_with("ak_"));
        assert!(generated.full_key().contains('.'));
        assert_eq!(generated.expires_at(), expiry);
        assert_eq!(generated.default_headers(), &expected_headers);

        // The full key is exactly `<key_id>.<secret>`.
        let segments: Vec<&str> = generated.full_key().split('.').collect();
        assert_eq!(segments.len(), 2);
        assert_eq!(segments[0], generated.key_id());
    }

    #[test]
    fn test_api_key_validation() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let generated = ApiKeyGenerator::new().generate_key(
            ServiceId::new(1),
            1234567890,
            BTreeMap::new(),
            &mut rng,
        );
        let model = ApiKeyModel::from(&generated);

        // The exact key validates...
        assert!(model.validates_key(generated.full_key()));

        // ...while a tampered copy, a string without a separator, and a bare id do not.
        let tampered = generated.full_key().replace('a', "b");
        for rejected in [tampered.as_str(), "invalid", "ak_test"] {
            assert!(!model.validates_key(rejected));
        }
    }

    #[test]
    fn test_api_key_database_operations() {
        let tmp_dir = tempdir().unwrap();
        let config = crate::db::RocksDbConfig::default();
        let db = RocksDb::open(tmp_dir.path(), &config).unwrap();
        let mut rng = blueprint_std::BlueprintRng::new();

        let generated = ApiKeyGenerator::new().generate_key(
            ServiceId::new(1),
            1234567890,
            BTreeMap::new(),
            &mut rng,
        );
        let mut model = ApiKeyModel::from(&generated);

        model.save(&db).unwrap();
        assert_ne!(model.id, 0, "saving a new model must assign an id");

        let by_key_id = ApiKeyModel::find_by_key_id(generated.key_id(), &db)
            .unwrap()
            .unwrap();
        assert_eq!(by_key_id.key_id, model.key_id);
        assert_eq!(by_key_id.id, model.id);

        let by_id = ApiKeyModel::find_by_id(model.id, &db).unwrap().unwrap();
        assert_eq!(by_id.key_id, model.key_id);

        assert!(by_key_id.validates_key(generated.full_key()));

        model.delete(&db).unwrap();
        let after_delete = ApiKeyModel::find_by_key_id(generated.key_id(), &db).unwrap();
        assert!(after_delete.is_none());
    }

    #[test]
    fn test_expiration() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let mut model = model_fixture(now, now - 1);
        assert!(model.is_expired());

        model.expires_at = now + 3600;
        assert!(!model.is_expired());
    }

    #[test]
    fn test_headers_serialization() {
        let mut model = model_fixture(0, 0);

        let mut headers = BTreeMap::new();
        headers.insert("X-Test".to_string(), "value".to_string());

        model.set_default_headers(&headers);
        assert_eq!(model.get_default_headers(), headers);
    }
}