offline_intelligence/memory_db/
api_keys_store.rs1use anyhow::{Context, Result};
10use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
11use chrono::{DateTime, Utc};
12use r2d2::Pool;
13use r2d2_sqlite::SqliteConnectionManager;
14use rusqlite::{params, OptionalExtension};
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use tracing::{info, warn};
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
25pub enum ApiKeyType {
26 HuggingFace,
27 OpenRouter,
28}
29
30impl ApiKeyType {
31 pub fn as_str(&self) -> &'static str {
32 match self {
33 ApiKeyType::HuggingFace => "huggingface",
34 ApiKeyType::OpenRouter => "openrouter",
35 }
36 }
37
38 pub fn from_str(s: &str) -> Option<Self> {
39 match s.to_lowercase().as_str() {
40 "huggingface" | "hf" => Some(ApiKeyType::HuggingFace),
41 "openrouter" | "or" => Some(ApiKeyType::OpenRouter),
42 _ => None,
43 }
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ApiKeyRecord {
51 pub id: i64,
52 pub key_type: String,
53 pub encrypted_value: String,
54 pub created_at: DateTime<Utc>,
55 pub last_used_at: Option<DateTime<Utc>>,
56 pub last_mode: Option<String>,
57 pub usage_count: i64,
58}
59
60pub struct ApiKeysStore {
66 pool: Arc<Pool<SqliteConnectionManager>>,
67}
68
69impl ApiKeysStore {
70 pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
71 Self { pool }
72 }
73
74 pub fn initialize_schema(&self) -> Result<()> {
78 let conn = self.pool.get()?;
79 conn.execute(
80 "CREATE TABLE IF NOT EXISTS api_keys (
81 id INTEGER PRIMARY KEY AUTOINCREMENT,
82 key_type TEXT NOT NULL UNIQUE,
83 encrypted_value TEXT NOT NULL,
84 created_at TEXT NOT NULL,
85 last_used_at TEXT,
86 last_mode TEXT,
87 usage_count INTEGER DEFAULT 0
88 )",
89 [],
90 )?;
91 info!("API keys table initialized");
92 Ok(())
93 }
94
95 pub fn save_key(&self, key_type: ApiKeyType, plaintext: &str) -> Result<()> {
102 let name = key_type.as_str();
103
104 let encrypted = Encryption::encrypt(plaintext);
106 info!("Encrypted API key for: {}", name);
107
108 let conn = self.pool.get()?;
109 let now = Utc::now().to_rfc3339();
110 conn.execute(
111 "INSERT INTO api_keys (key_type, encrypted_value, created_at, usage_count)
112 VALUES (?1, ?2, ?3, 0)
113 ON CONFLICT(key_type) DO UPDATE SET
114 encrypted_value = excluded.encrypted_value,
115 created_at = excluded.created_at",
116 params![name, encrypted, now],
117 )?;
118
119 info!("API key saved for: {}", name);
120 Ok(())
121 }
122
123 pub fn get_key_plaintext(&self, key_type: &ApiKeyType) -> Result<Option<String>> {
129 let record = match self.get_key_metadata(key_type)? {
130 Some(r) => r,
131 None => return Ok(None),
132 };
133
134 match Encryption::decrypt(&record.encrypted_value) {
136 Ok(plaintext) => Ok(Some(plaintext)),
137 Err(e) => {
138 warn!("Failed to decrypt API key '{}': {}", key_type.as_str(), e);
139 Err(e)
140 }
141 }
142 }
143
144 pub fn get_key_metadata(&self, key_type: &ApiKeyType) -> Result<Option<ApiKeyRecord>> {
146 let conn = self.pool.get()?;
147 let name = key_type.as_str();
148
149 let row = conn
150 .query_row(
151 "SELECT id, key_type, encrypted_value, created_at, last_used_at, last_mode, usage_count
152 FROM api_keys WHERE key_type = ?1",
153 params![name],
154 |row| {
155 Ok(ApiKeyRecord {
156 id: row.get(0)?,
157 key_type: row.get(1)?,
158 encrypted_value: row.get(2)?,
159 created_at: row
160 .get::<_, String>(3)?
161 .parse()
162 .unwrap_or_else(|_| Utc::now()),
163 last_used_at: row
164 .get::<_, Option<String>>(4)?
165 .and_then(|s| s.parse().ok()),
166 last_mode: row.get(5)?,
167 usage_count: row.get(6)?,
168 })
169 },
170 )
171 .optional()?;
172
173 Ok(row)
174 }
175
176 pub fn get_all_keys(&self) -> Result<Vec<ApiKeyRecord>> {
178 let conn = self.pool.get()?;
179 let mut stmt = conn.prepare(
180 "SELECT id, key_type, encrypted_value, created_at, last_used_at, last_mode, usage_count
181 FROM api_keys",
182 )?;
183
184 let rows = stmt.query_map([], |row| {
185 Ok(ApiKeyRecord {
186 id: row.get(0)?,
187 key_type: row.get(1)?,
188 encrypted_value: row.get(2)?,
189 created_at: row
190 .get::<_, String>(3)?
191 .parse()
192 .unwrap_or_else(|_| Utc::now()),
193 last_used_at: row
194 .get::<_, Option<String>>(4)?
195 .and_then(|s| s.parse().ok()),
196 last_mode: row.get(5)?,
197 usage_count: row.get(6)?,
198 })
199 })?;
200
201 let mut result = Vec::new();
202 for row in rows {
203 result.push(row?);
204 }
205 Ok(result)
206 }
207
208 pub fn get_all_keys_with_values(&self) -> Result<Vec<(ApiKeyRecord, String)>> {
210 let keys = self.get_all_keys()?;
211 let mut result = Vec::new();
212 for record in keys {
213 match Encryption::decrypt(&record.encrypted_value) {
214 Ok(value) => result.push((record, value)),
215 Err(e) => {
216 warn!("Failed to decrypt API key '{}': {}", record.key_type, e);
217 }
218 }
219 }
220 Ok(result)
221 }
222
223 pub fn key_exists(&self, key_type: &ApiKeyType) -> Result<bool> {
225 let conn = self.pool.get()?;
226 let name = key_type.as_str();
227 let count: i64 = conn.query_row(
228 "SELECT COUNT(*) FROM api_keys WHERE key_type = ?1",
229 params![name],
230 |row| row.get(0),
231 )?;
232 Ok(count > 0)
233 }
234
235 pub fn mark_used(&self, key_type: ApiKeyType, mode: &str) -> Result<()> {
239 let conn = self.pool.get()?;
240 let name = key_type.as_str();
241 let now = Utc::now().to_rfc3339();
242
243 conn.execute(
244 "UPDATE api_keys SET last_used_at = ?1, last_mode = ?2, usage_count = usage_count + 1
245 WHERE key_type = ?3",
246 params![now, mode, name],
247 )?;
248 Ok(())
249 }
250
251 pub fn delete_key(&self, key_type: ApiKeyType) -> Result<bool> {
255 let conn = self.pool.get()?;
256 let name = key_type.as_str();
257
258 let rows = conn.execute("DELETE FROM api_keys WHERE key_type = ?1", params![name])?;
259 info!("Deleted API key for: {}", name);
260 Ok(rows > 0)
261 }
262}
263
264pub struct Encryption;
272
273impl Encryption {
274 fn get_machine_key() -> Vec<u8> {
275 let machine_id = whoami::devicename();
276 let mut key: Vec<u8> = machine_id.bytes().collect();
277 while key.len() < 32 {
278 key.push((key.len() as u8).wrapping_mul(17));
279 }
280 key.truncate(32);
281 key
282 }
283
284 pub fn encrypt(plaintext: &str) -> String {
287 let key = Self::get_machine_key();
288 let encrypted: Vec<u8> = plaintext
289 .as_bytes()
290 .iter()
291 .enumerate()
292 .map(|(i, &b)| b ^ key[i % key.len()])
293 .collect();
294 BASE64.encode(&encrypted)
295 }
296
297 pub fn decrypt(ciphertext: &str) -> Result<String> {
299 let key = Self::get_machine_key();
300 let bytes = BASE64
301 .decode(ciphertext)
302 .context("Failed to decode base64")?;
303 let decrypted: Vec<u8> = bytes
304 .iter()
305 .enumerate()
306 .map(|(i, &b)| b ^ key[i % key.len()])
307 .collect();
308 Ok(String::from_utf8(decrypted)?)
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let original = "hf_1234567890abcdef";
        let encrypted = Encryption::encrypt(original);
        let decrypted = Encryption::decrypt(&encrypted).unwrap();
        assert_eq!(original, decrypted);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_edge_cases() {
        // Empty input, input longer than the 32-byte key (so the key repeats), and
        // multi-byte UTF-8.
        let long = "sk-or-v1-".to_string() + &"a".repeat(100);
        for original in ["", long.as_str(), "clé-🔑"] {
            let decrypted = Encryption::decrypt(&Encryption::encrypt(original)).unwrap();
            assert_eq!(original, decrypted);
        }
    }

    /// The scheme has no nonce, so identical plaintexts produce identical output.
    #[test]
    fn test_encrypt_is_deterministic() {
        let key1 = Encryption::encrypt("test_key");
        let key2 = Encryption::encrypt("test_key");
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_decrypt_rejects_invalid_base64() {
        assert!(Encryption::decrypt("not base64!").is_err());
    }

    #[test]
    fn test_api_key_type_parsing() {
        assert_eq!(ApiKeyType::from_str("HF"), Some(ApiKeyType::HuggingFace));
        assert_eq!(ApiKeyType::from_str("openrouter"), Some(ApiKeyType::OpenRouter));
        assert_eq!(ApiKeyType::from_str("unknown"), None);
        for kt in [ApiKeyType::HuggingFace, ApiKeyType::OpenRouter] {
            assert_eq!(ApiKeyType::from_str(kt.as_str()), Some(kt));
        }
    }
}