Skip to main content

perfgate_server/storage/
key_store.rs

1//! Persistent key store trait and implementations.
2//!
//! Provides the [`KeyStore`] trait for managing API keys, with a SQLite-backed
//! implementation for persistence and an in-memory implementation for testing
//! and development.
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use perfgate_auth::Role;
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use tokio::sync::RwLock;
13
14use crate::error::StoreError;
15
/// A persistent API key record stored in the database.
///
/// Only a hash of the plaintext key is kept; the record has no field for the
/// plaintext itself.
#[derive(Debug, Clone)]
pub struct KeyRecord {
    /// Unique identifier (the `id` primary key in the SQLite backend)
    pub id: String,
    /// SHA-256 hash of the plaintext key as lowercase hex, as produced by [`hash_key`]
    pub key_hash: String,
    /// Display prefix of the key as produced by [`key_prefix`] (e.g. `pg_live_abcd...***`)
    pub key_prefix: String,
    /// Assigned role
    pub role: Role,
    /// Project scope
    pub project: String,
    /// Optional benchmark glob pattern
    pub pattern: Option<String>,
    /// Human-readable description
    pub description: String,
    /// Creation timestamp (the `list_keys` implementations in this file sort on it, newest first)
    pub created_at: DateTime<Utc>,
    /// Expiration timestamp; `None` means the key never expires
    pub expires_at: Option<DateTime<Utc>>,
    /// Revocation timestamp (`None` if the key has not been revoked)
    pub revoked_at: Option<DateTime<Utc>>,
}
40
41impl KeyRecord {
42    /// Returns true if this key has been revoked.
43    pub fn is_revoked(&self) -> bool {
44        self.revoked_at.is_some()
45    }
46
47    /// Returns true if this key has expired.
48    pub fn is_expired(&self) -> bool {
49        self.expires_at.is_some_and(|exp| exp < Utc::now())
50    }
51
52    /// Returns true if this key is currently valid (not revoked and not expired).
53    pub fn is_active(&self) -> bool {
54        !self.is_revoked() && !self.is_expired()
55    }
56}
57
58/// Hashes a plaintext API key for storage.
59pub fn hash_key(key: &str) -> String {
60    let mut hasher = Sha256::new();
61    hasher.update(key.as_bytes());
62    format!("{:x}", hasher.finalize())
63}
64
/// Extracts a display prefix from a plaintext key.
///
/// Keeps the first 12 characters (not bytes) of `key` and appends `...***`.
/// Keys shorter than 12 characters are shown in full before the suffix.
/// Counting characters guarantees the cut lands on a UTF-8 boundary, so
/// non-ASCII input cannot cause a slicing panic.
pub fn key_prefix(key: &str) -> String {
    const PREFIX_CHARS: usize = 12;
    // Byte offset of the 13th character, or the whole string if it is shorter.
    let end = key
        .char_indices()
        .nth(PREFIX_CHARS)
        .map_or(key.len(), |(idx, _)| idx);
    format!("{}...***", &key[..end])
}
70
/// Trait for persistent API key storage.
///
/// This file provides [`InMemoryKeyStore`] and [`SqliteKeyStore`] as implementations.
#[async_trait]
pub trait KeyStore: Send + Sync {
    /// Stores a new key record. The `key_hash` field must already be set, and it
    /// must equal [`hash_key`] of the plaintext, because `validate_key` in the
    /// implementations in this file looks records up by that hash.
    ///
    /// The implementations in this file fail with [`StoreError::AlreadyExists`]
    /// when a record with the same `id` already exists.
    async fn create_key(&self, record: &KeyRecord) -> Result<(), StoreError>;

    /// Lists all keys (including revoked, for admin view).
    ///
    /// The implementations in this file return keys newest first by `created_at`.
    async fn list_keys(&self) -> Result<Vec<KeyRecord>, StoreError>;

    /// Revokes a key by its ID. Returns the revocation timestamp.
    ///
    /// In the implementations in this file, `Ok(None)` means no key has this ID.
    /// Revoking an already-revoked key leaves it unchanged and returns its
    /// original revocation timestamp.
    async fn revoke_key(&self, id: &str) -> Result<Option<DateTime<Utc>>, StoreError>;

    /// Validates a plaintext key and returns its record if active.
    ///
    /// In the implementations in this file, the key is hashed with [`hash_key`]
    /// before lookup. `Ok(None)` is returned when no record matches, or when
    /// the matching record is revoked or expired (see [`KeyRecord::is_active`]).
    async fn validate_key(&self, raw_key: &str) -> Result<Option<KeyRecord>, StoreError>;
}
86
87// ── In-Memory Implementation ──────────────────────────────────────────
88
89/// In-memory key store for testing and development.
90#[derive(Debug, Default)]
91pub struct InMemoryKeyStore {
92    /// Records keyed by ID
93    records: Arc<RwLock<HashMap<String, KeyRecord>>>,
94}
95
96impl InMemoryKeyStore {
97    /// Creates a new empty in-memory key store.
98    pub fn new() -> Self {
99        Self {
100            records: Arc::new(RwLock::new(HashMap::new())),
101        }
102    }
103}
104
105#[async_trait]
106impl KeyStore for InMemoryKeyStore {
107    async fn create_key(&self, record: &KeyRecord) -> Result<(), StoreError> {
108        let mut records = self.records.write().await;
109        if records.contains_key(&record.id) {
110            return Err(StoreError::AlreadyExists(format!("key id={}", record.id)));
111        }
112        records.insert(record.id.clone(), record.clone());
113        Ok(())
114    }
115
116    async fn list_keys(&self) -> Result<Vec<KeyRecord>, StoreError> {
117        let records = self.records.read().await;
118        let mut keys: Vec<_> = records.values().cloned().collect();
119        keys.sort_by(|a, b| b.created_at.cmp(&a.created_at));
120        Ok(keys)
121    }
122
123    async fn revoke_key(&self, id: &str) -> Result<Option<DateTime<Utc>>, StoreError> {
124        let mut records = self.records.write().await;
125        if let Some(record) = records.get_mut(id) {
126            if record.revoked_at.is_some() {
127                return Ok(record.revoked_at);
128            }
129            let now = Utc::now();
130            record.revoked_at = Some(now);
131            Ok(Some(now))
132        } else {
133            Ok(None)
134        }
135    }
136
137    async fn validate_key(&self, raw_key: &str) -> Result<Option<KeyRecord>, StoreError> {
138        let hash = hash_key(raw_key);
139        let records = self.records.read().await;
140        let record = records.values().find(|r| r.key_hash == hash).cloned();
141        match record {
142            Some(r) if r.is_active() => Ok(Some(r)),
143            _ => Ok(None),
144        }
145    }
146}
147
148// ── SQLite Implementation ─────────────────────────────────────────────
149
150/// SQLite-backed persistent key store.
151#[derive(Debug)]
152pub struct SqliteKeyStore {
153    conn: Arc<Mutex<rusqlite::Connection>>,
154}
155
156impl SqliteKeyStore {
157    /// Opens or creates a key store backed by the given SQLite connection.
158    /// The `api_keys` table is created if it does not exist.
159    pub fn new(conn: Arc<Mutex<rusqlite::Connection>>) -> Result<Self, StoreError> {
160        {
161            let c = conn
162                .lock()
163                .map_err(|e| StoreError::LockError(e.to_string()))?;
164            c.execute_batch(
165                r#"
166                CREATE TABLE IF NOT EXISTS api_keys (
167                    id TEXT PRIMARY KEY,
168                    key_hash TEXT NOT NULL UNIQUE,
169                    key_prefix TEXT NOT NULL,
170                    role TEXT NOT NULL,
171                    project TEXT NOT NULL,
172                    pattern TEXT,
173                    description TEXT NOT NULL,
174                    created_at TEXT NOT NULL,
175                    expires_at TEXT,
176                    revoked_at TEXT
177                );
178                CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
179                "#,
180            )?;
181        }
182        Ok(Self { conn })
183    }
184
185    /// Creates an in-memory SQLite key store (for testing).
186    pub fn in_memory() -> Result<Self, StoreError> {
187        let conn = rusqlite::Connection::open_in_memory()?;
188        Self::new(Arc::new(Mutex::new(conn)))
189    }
190
191    fn row_to_record(row: &rusqlite::Row) -> Result<KeyRecord, rusqlite::Error> {
192        let role_str: String = row.get(3)?;
193        let role = match role_str.as_str() {
194            "admin" => Role::Admin,
195            "promoter" => Role::Promoter,
196            "contributor" => Role::Contributor,
197            _ => Role::Viewer,
198        };
199
200        let created_at_str: String = row.get(7)?;
201        let expires_at_str: Option<String> = row.get(8)?;
202        let revoked_at_str: Option<String> = row.get(9)?;
203
204        Ok(KeyRecord {
205            id: row.get(0)?,
206            key_hash: row.get(1)?,
207            key_prefix: row.get(2)?,
208            role,
209            project: row.get(4)?,
210            pattern: row.get(5)?,
211            description: row.get(6)?,
212            created_at: parse_dt(&created_at_str),
213            expires_at: expires_at_str.as_deref().map(parse_dt),
214            revoked_at: revoked_at_str.as_deref().map(parse_dt),
215        })
216    }
217}
218
219fn parse_dt(s: &str) -> DateTime<Utc> {
220    chrono::DateTime::parse_from_rfc3339(s)
221        .map(|dt| dt.with_timezone(&Utc))
222        .unwrap_or_else(|_| Utc::now())
223}
224
225fn role_str(role: &Role) -> &'static str {
226    match role {
227        Role::Admin => "admin",
228        Role::Promoter => "promoter",
229        Role::Contributor => "contributor",
230        Role::Viewer => "viewer",
231    }
232}
233
234#[async_trait]
235impl KeyStore for SqliteKeyStore {
236    async fn create_key(&self, record: &KeyRecord) -> Result<(), StoreError> {
237        let conn = self
238            .conn
239            .lock()
240            .map_err(|e| StoreError::LockError(e.to_string()))?;
241        conn.execute(
242            r#"
243            INSERT INTO api_keys (id, key_hash, key_prefix, role, project, pattern, description, created_at, expires_at, revoked_at)
244            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
245            "#,
246            rusqlite::params![
247                record.id,
248                record.key_hash,
249                record.key_prefix,
250                role_str(&record.role),
251                record.project,
252                record.pattern,
253                record.description,
254                record.created_at.to_rfc3339(),
255                record.expires_at.map(|t| t.to_rfc3339()),
256                record.revoked_at.map(|t| t.to_rfc3339()),
257            ],
258        )
259        .map_err(|e| match &e {
260            rusqlite::Error::SqliteFailure(err, _)
261                if err.code == rusqlite::ErrorCode::ConstraintViolation =>
262            {
263                StoreError::AlreadyExists(format!("key id={}", record.id))
264            }
265            _ => StoreError::SqliteError(e),
266        })?;
267        Ok(())
268    }
269
270    async fn list_keys(&self) -> Result<Vec<KeyRecord>, StoreError> {
271        let conn = self
272            .conn
273            .lock()
274            .map_err(|e| StoreError::LockError(e.to_string()))?;
275        let mut stmt = conn.prepare("SELECT * FROM api_keys ORDER BY created_at DESC")?;
276        let rows = stmt
277            .query_map([], Self::row_to_record)?
278            .collect::<Result<Vec<_>, _>>()?;
279        Ok(rows)
280    }
281
282    async fn revoke_key(&self, id: &str) -> Result<Option<DateTime<Utc>>, StoreError> {
283        let conn = self
284            .conn
285            .lock()
286            .map_err(|e| StoreError::LockError(e.to_string()))?;
287        let now = Utc::now();
288        let n = conn.execute(
289            "UPDATE api_keys SET revoked_at = ?1 WHERE id = ?2 AND revoked_at IS NULL",
290            rusqlite::params![now.to_rfc3339(), id],
291        )?;
292        if n > 0 {
293            Ok(Some(now))
294        } else {
295            // Check if the key exists at all
296            let exists: bool = conn.query_row(
297                "SELECT COUNT(*) > 0 FROM api_keys WHERE id = ?1",
298                rusqlite::params![id],
299                |row| row.get(0),
300            )?;
301            if exists {
302                // Already revoked — return existing revoked_at
303                let revoked_at: Option<String> = conn.query_row(
304                    "SELECT revoked_at FROM api_keys WHERE id = ?1",
305                    rusqlite::params![id],
306                    |row| row.get(0),
307                )?;
308                Ok(revoked_at.as_deref().map(parse_dt))
309            } else {
310                Ok(None)
311            }
312        }
313    }
314
315    async fn validate_key(&self, raw_key: &str) -> Result<Option<KeyRecord>, StoreError> {
316        let hash = hash_key(raw_key);
317        let conn = self
318            .conn
319            .lock()
320            .map_err(|e| StoreError::LockError(e.to_string()))?;
321        let result = conn
322            .query_row(
323                "SELECT * FROM api_keys WHERE key_hash = ?1",
324                rusqlite::params![hash],
325                Self::row_to_record,
326            )
327            .optional()?;
328        match result {
329            Some(r) if r.is_active() => Ok(Some(r)),
330            _ => Ok(None),
331        }
332    }
333}
334
335use rusqlite::OptionalExtension;
336
#[cfg(test)]
mod tests {
    use super::*;
    use perfgate_auth::generate_api_key;

    /// Builds an active, non-expiring record for `raw_key` with a random id.
    fn make_record(raw_key: &str, role: Role) -> KeyRecord {
        KeyRecord {
            id: uuid::Uuid::new_v4().to_string(),
            key_hash: hash_key(raw_key),
            key_prefix: key_prefix(raw_key),
            role,
            project: "default".to_string(),
            pattern: None,
            description: "test key".to_string(),
            created_at: Utc::now(),
            expires_at: None,
            revoked_at: None,
        }
    }

    #[tokio::test]
    async fn test_inmemory_crud() {
        let store = InMemoryKeyStore::new();
        let raw = generate_api_key(false);
        let rec = make_record(&raw, Role::Contributor);
        let id = rec.id.clone();

        store.create_key(&rec).await.unwrap();

        let keys = store.list_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].id, id);

        let found = store.validate_key(&raw).await.unwrap();
        assert!(found.is_some());

        let revoked = store.revoke_key(&id).await.unwrap();
        assert!(revoked.is_some());

        let found = store.validate_key(&raw).await.unwrap();
        assert!(found.is_none(), "revoked key should not validate");
    }

    #[tokio::test]
    async fn test_inmemory_expiration() {
        let store = InMemoryKeyStore::new();
        let raw = generate_api_key(false);
        let mut rec = make_record(&raw, Role::Viewer);
        rec.expires_at = Some(Utc::now() - chrono::Duration::hours(1));

        store.create_key(&rec).await.unwrap();
        let found = store.validate_key(&raw).await.unwrap();
        assert!(found.is_none(), "expired key should not validate");
    }

    #[tokio::test]
    async fn test_inmemory_revoke_is_idempotent() {
        let store = InMemoryKeyStore::new();
        let rec = make_record(&generate_api_key(false), Role::Contributor);
        let id = rec.id.clone();
        store.create_key(&rec).await.unwrap();

        let first = store.revoke_key(&id).await.unwrap();
        let second = store.revoke_key(&id).await.unwrap();
        assert!(first.is_some());
        assert_eq!(first, second, "re-revoking must keep the original timestamp");
    }

    #[tokio::test]
    async fn test_inmemory_revoke_nonexistent() {
        let store = InMemoryKeyStore::new();
        let result = store.revoke_key("nonexistent-id").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_inmemory_duplicate_id_rejected() {
        let store = InMemoryKeyStore::new();
        let rec = make_record(&generate_api_key(false), Role::Viewer);
        store.create_key(&rec).await.unwrap();

        let err = store.create_key(&rec).await.unwrap_err();
        assert!(matches!(err, StoreError::AlreadyExists(_)));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_sqlite_crud() {
        let store = SqliteKeyStore::in_memory().unwrap();
        let raw = generate_api_key(false);
        let rec = make_record(&raw, Role::Admin);
        let id = rec.id.clone();

        store.create_key(&rec).await.unwrap();

        let keys = store.list_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].role, Role::Admin);

        let found = store.validate_key(&raw).await.unwrap();
        assert!(found.is_some());

        let revoked = store.revoke_key(&id).await.unwrap();
        assert!(revoked.is_some());

        let found = store.validate_key(&raw).await.unwrap();
        assert!(found.is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_sqlite_expiration() {
        let store = SqliteKeyStore::in_memory().unwrap();
        let raw = generate_api_key(false);
        let mut rec = make_record(&raw, Role::Viewer);
        rec.expires_at = Some(Utc::now() - chrono::Duration::hours(1));

        store.create_key(&rec).await.unwrap();
        let found = store.validate_key(&raw).await.unwrap();
        assert!(found.is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_sqlite_revoke_nonexistent() {
        let store = SqliteKeyStore::in_memory().unwrap();
        let result = store.revoke_key("nonexistent-id").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_sqlite_revoke_is_idempotent() {
        let store = SqliteKeyStore::in_memory().unwrap();
        let rec = make_record(&generate_api_key(false), Role::Contributor);
        let id = rec.id.clone();
        store.create_key(&rec).await.unwrap();

        let first = store.revoke_key(&id).await.unwrap();
        let second = store.revoke_key(&id).await.unwrap();
        assert!(first.is_some());
        // The stored RFC 3339 text keeps all non-zero sub-second digits, so the
        // round-tripped timestamp equals the one returned by the first call.
        assert_eq!(first, second, "re-revoking must keep the original timestamp");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_sqlite_duplicate_id_rejected() {
        let store = SqliteKeyStore::in_memory().unwrap();
        let rec = make_record(&generate_api_key(false), Role::Viewer);
        store.create_key(&rec).await.unwrap();

        let err = store.create_key(&rec).await.unwrap_err();
        assert!(matches!(err, StoreError::AlreadyExists(_)));
    }

    #[test]
    fn test_hash_key_deterministic() {
        let h1 = hash_key("pg_live_test123456789012345678901234567890");
        let h2 = hash_key("pg_live_test123456789012345678901234567890");
        assert_eq!(h1, h2);

        let h3 = hash_key("pg_live_different1234567890123456789012");
        assert_ne!(h1, h3);
    }

    #[test]
    fn test_key_prefix_display() {
        let prefix = key_prefix("pg_live_abcdefghijklmnopqrstuvwxyz123456");
        assert_eq!(prefix, "pg_live_abcd...***");
    }

    #[test]
    fn test_key_record_active() {
        let raw = generate_api_key(false);
        let rec = make_record(&raw, Role::Viewer);
        assert!(rec.is_active());
        assert!(!rec.is_revoked());
        assert!(!rec.is_expired());
    }
}