Skip to main content

engram/auth/
tokens.rs

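//! API key and token management.
//!
//! Raw keys are `eng_`-prefixed random strings. Only their SHA-256 hash is stored, and
//! validation looks keys up by that hash.
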
use crate::auth::{PermissionSet, UserId};
use crate::error::{EngramError, Result};
use chrono::{DateTime, NaiveDateTime, TimeZone, Utc};
use rand::Rng;
use rusqlite::{params, Connection, OptionalExtension, Row};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use uuid::Uuid;

12/// API key with prefix for easy identification
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ApiKey {
15    pub id: String,
16    pub user_id: UserId,
17    pub name: String,
18    pub key_prefix: String,
19    pub permissions: PermissionSet,
20    pub namespace: Option<String>,
21    pub expires_at: Option<DateTime<Utc>>,
22    pub last_used_at: Option<DateTime<Utc>>,
23    pub is_active: bool,
24    pub created_at: DateTime<Utc>,
25}
26
27/// Token claims for validation
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TokenClaims {
30    pub user_id: UserId,
31    pub key_id: String,
32    pub permissions: PermissionSet,
33    pub namespace: Option<String>,
34    pub issued_at: DateTime<Utc>,
35    pub expires_at: Option<DateTime<Utc>>,
36}
37
38impl TokenClaims {
39    /// Check if the token is expired
40    pub fn is_expired(&self) -> bool {
41        if let Some(exp) = self.expires_at {
42            return Utc::now() > exp;
43        }
44        false
45    }
46}
47
48/// API key manager
49pub struct ApiKeyManager<'a> {
50    conn: &'a Connection,
51}
52
53impl<'a> ApiKeyManager<'a> {
54    /// Create a new API key manager
55    pub fn new(conn: &'a Connection) -> Self {
56        Self { conn }
57    }
58
59    /// Generate a new API key
60    /// Returns (ApiKey, raw_key) - raw_key should only be shown once
61    pub fn create_api_key(
62        &self,
63        user_id: &UserId,
64        name: &str,
65        permissions: PermissionSet,
66        namespace: Option<String>,
67        expires_in_days: Option<i64>,
68    ) -> Result<(ApiKey, String)> {
69        let id = Uuid::new_v4().to_string();
70        let raw_key = generate_api_key();
71        let key_hash = hash_key(&raw_key);
72        let key_prefix = &raw_key[..12]; // Show first 12 chars for identification
73
74        let expires_at = expires_in_days.map(|days| Utc::now() + chrono::Duration::days(days));
75
76        let permissions_json = serde_json::to_string(&permissions)?;
77
78        self.conn.execute(
79            r#"
80            INSERT INTO api_keys (id, user_id, key_hash, key_prefix, name, permissions, namespace, expires_at, is_active, created_at)
81            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, datetime('now'))
82            "#,
83            params![
84                id,
85                user_id.as_str(),
86                key_hash,
87                key_prefix,
88                name,
89                permissions_json,
90                namespace,
91                expires_at.map(|dt| dt.to_rfc3339()),
92            ],
93        )?;
94
95        let api_key = ApiKey {
96            id,
97            user_id: user_id.clone(),
98            name: name.to_string(),
99            key_prefix: key_prefix.to_string(),
100            permissions,
101            namespace,
102            expires_at,
103            last_used_at: None,
104            is_active: true,
105            created_at: Utc::now(),
106        };
107
108        Ok((api_key, raw_key))
109    }
110
111    /// Validate an API key and return claims
112    pub fn validate_key(&self, raw_key: &str) -> Result<Option<TokenClaims>> {
113        let key_hash = hash_key(raw_key);
114
115        let result: Option<(String, String, String, Option<String>, Option<String>, bool)> = self
116            .conn
117            .query_row(
118                r#"
119                SELECT ak.id, ak.user_id, ak.permissions, ak.namespace, ak.expires_at, u.is_active as user_active
120                FROM api_keys ak
121                JOIN users u ON ak.user_id = u.id
122                WHERE ak.key_hash = ?1 AND ak.is_active = 1
123                "#,
124                params![key_hash],
125                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?, row.get(5)?)),
126            )
127            .optional()?;
128
129        if let Some((key_id, user_id, permissions_json, namespace, expires_at_str, user_active)) =
130            result
131        {
132            if !user_active {
133                return Ok(None);
134            }
135
136            let expires_at = expires_at_str
137                .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
138                .map(|dt| dt.with_timezone(&Utc));
139
140            // Check expiration
141            if let Some(exp) = expires_at {
142                if Utc::now() > exp {
143                    return Ok(None);
144                }
145            }
146
147            // Update last used
148            self.conn.execute(
149                "UPDATE api_keys SET last_used_at = datetime('now') WHERE id = ?1",
150                params![key_id],
151            )?;
152
153            let permissions: PermissionSet = serde_json::from_str(&permissions_json)?;
154
155            Ok(Some(TokenClaims {
156                user_id: UserId::from_string(user_id),
157                key_id,
158                permissions,
159                namespace,
160                issued_at: Utc::now(),
161                expires_at,
162            }))
163        } else {
164            Ok(None)
165        }
166    }
167
168    /// Get API key by ID (without the raw key)
169    pub fn get_key(&self, id: &str) -> Result<Option<ApiKey>> {
170        self.conn
171            .query_row(
172                r#"
173                SELECT id, user_id, key_prefix, name, permissions, namespace, expires_at, last_used_at, is_active, created_at
174                FROM api_keys WHERE id = ?1
175                "#,
176                params![id],
177                |row| {
178                    let permissions_json: String = row.get(4)?;
179                    Ok(ApiKey {
180                        id: row.get(0)?,
181                        user_id: UserId::from_string(row.get::<_, String>(1)?),
182                        key_prefix: row.get(2)?,
183                        name: row.get(3)?,
184                        permissions: serde_json::from_str(&permissions_json).unwrap_or_default(),
185                        namespace: row.get(5)?,
186                        expires_at: row.get::<_, Option<String>>(6)?
187                            .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
188                            .map(|dt| dt.with_timezone(&Utc)),
189                        last_used_at: row.get::<_, Option<String>>(7)?
190                            .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
191                            .map(|dt| dt.with_timezone(&Utc)),
192                        is_active: row.get(8)?,
193                        created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
194                            .map(|dt| dt.with_timezone(&Utc))
195                            .unwrap_or_else(|_| Utc::now()),
196                    })
197                },
198            )
199            .optional()
200            .map_err(EngramError::from)
201    }
202
203    /// List API keys for a user
204    pub fn list_keys(&self, user_id: &UserId) -> Result<Vec<ApiKey>> {
205        let mut stmt = self.conn.prepare(
206            r#"
207            SELECT id, user_id, key_prefix, name, permissions, namespace, expires_at, last_used_at, is_active, created_at
208            FROM api_keys WHERE user_id = ?1 ORDER BY created_at DESC
209            "#,
210        )?;
211
212        let keys = stmt
213            .query_map(params![user_id.as_str()], |row| {
214                let permissions_json: String = row.get(4)?;
215                Ok(ApiKey {
216                    id: row.get(0)?,
217                    user_id: UserId::from_string(row.get::<_, String>(1)?),
218                    key_prefix: row.get(2)?,
219                    name: row.get(3)?,
220                    permissions: serde_json::from_str(&permissions_json).unwrap_or_default(),
221                    namespace: row.get(5)?,
222                    expires_at: row
223                        .get::<_, Option<String>>(6)?
224                        .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
225                        .map(|dt| dt.with_timezone(&Utc)),
226                    last_used_at: row
227                        .get::<_, Option<String>>(7)?
228                        .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
229                        .map(|dt| dt.with_timezone(&Utc)),
230                    is_active: row.get(8)?,
231                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
232                        .map(|dt| dt.with_timezone(&Utc))
233                        .unwrap_or_else(|_| Utc::now()),
234                })
235            })?
236            .collect::<std::result::Result<Vec<_>, _>>()?;
237
238        Ok(keys)
239    }
240
241    /// Revoke an API key
242    pub fn revoke_key(&self, id: &str) -> Result<bool> {
243        let updated = self.conn.execute(
244            "UPDATE api_keys SET is_active = 0 WHERE id = ?1",
245            params![id],
246        )?;
247        Ok(updated > 0)
248    }
249
250    /// Delete an API key
251    pub fn delete_key(&self, id: &str) -> Result<bool> {
252        let deleted = self
253            .conn
254            .execute("DELETE FROM api_keys WHERE id = ?1", params![id])?;
255        Ok(deleted > 0)
256    }
257}
258
259/// Generate a secure API key
260fn generate_api_key() -> String {
261    let mut rng = rand::thread_rng();
262    let bytes: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
263    format!("eng_{}", hex::encode(bytes))
264}
265
266/// Hash an API key for storage
267fn hash_key(key: &str) -> String {
268    let mut hasher = Sha256::new();
269    hasher.update(key.as_bytes());
270    hex::encode(hasher.finalize())
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{init_auth_tables, Permission, ResourceType, User, UserManager};

    /// Open an in-memory database with the auth schema installed.
    fn setup_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        init_auth_tables(&db).unwrap();
        db
    }

    /// Create and persist a user named "testuser" to own the keys under test.
    fn seed_user(conn: &Connection) -> User {
        let owner = User::new("testuser");
        UserManager::new(conn).create_user(&owner, None).unwrap();
        owner
    }

    #[test]
    fn test_create_and_validate_api_key() {
        let conn = setup_db();
        let owner = seed_user(&conn);
        let manager = ApiKeyManager::new(&conn);

        let (created, secret) = manager
            .create_api_key(&owner.id, "Test Key", PermissionSet::standard_user(), None, None)
            .unwrap();
        assert!(secret.starts_with("eng_"));
        assert_eq!(created.name, "Test Key");

        // The raw key resolves to its owner and keeps the granted permissions.
        let claims = manager
            .validate_key(&secret)
            .unwrap()
            .expect("a freshly created key should validate");
        assert_eq!(claims.user_id, owner.id);
        assert!(claims
            .permissions
            .has_permission(Permission::Read, ResourceType::Memory));
    }

    #[test]
    fn test_validate_invalid_key() {
        let conn = setup_db();
        let outcome = ApiKeyManager::new(&conn)
            .validate_key("eng_invalid_key_here")
            .unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_revoke_key() {
        let conn = setup_db();
        let owner = seed_user(&conn);
        let manager = ApiKeyManager::new(&conn);

        let (created, secret) = manager
            .create_api_key(&owner.id, "Revoke Test", PermissionSet::read_only(), None, None)
            .unwrap();

        assert!(
            manager.validate_key(&secret).unwrap().is_some(),
            "key should validate before revocation"
        );

        manager.revoke_key(&created.id).unwrap();

        assert!(
            manager.validate_key(&secret).unwrap().is_none(),
            "revoked key must not validate"
        );
    }

    #[test]
    fn test_expired_key() {
        let conn = setup_db();
        let owner = seed_user(&conn);
        let manager = ApiKeyManager::new(&conn);

        // A lifetime of -1 days puts the expiry one day in the past.
        let (_created, secret) = manager
            .create_api_key(
                &owner.id,
                "Expiring Key",
                PermissionSet::read_only(),
                None,
                Some(-1),
            )
            .unwrap();

        assert!(manager.validate_key(&secret).unwrap().is_none());
    }

    #[test]
    fn test_list_keys() {
        let conn = setup_db();
        let owner = seed_user(&conn);
        let manager = ApiKeyManager::new(&conn);

        manager
            .create_api_key(&owner.id, "Key 1", PermissionSet::read_only(), None, None)
            .unwrap();
        manager
            .create_api_key(&owner.id, "Key 2", PermissionSet::standard_user(), None, None)
            .unwrap();

        let listed = manager.list_keys(&owner.id).unwrap();
        assert_eq!(listed.len(), 2);
    }
}