1use super::clustering::TaskCluster;
6use super::techniques::PromptingTechnique;
7use super::temperature::TemperaturePerformance;
8use anyhow::{Context, Result};
9use rusqlite::{Connection, params};
10use serde_json;
11use std::path::Path;
12
13pub struct ClusterStorage {
15 conn: Connection,
16}
17
18impl ClusterStorage {
19 pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
21 let conn = Connection::open(db_path)?;
22
23 conn.execute("PRAGMA foreign_keys = ON", [])?;
25
26 Self::create_tables(&conn)?;
28
29 Ok(Self { conn })
30 }
31
32 fn create_tables(conn: &Connection) -> Result<()> {
34 conn.execute(
36 "CREATE TABLE IF NOT EXISTS clusters (
37 id TEXT PRIMARY KEY,
38 description TEXT NOT NULL,
39 embedding BLOB NOT NULL,
40 techniques TEXT NOT NULL,
41 example_tasks TEXT NOT NULL,
42 created_at INTEGER NOT NULL,
43 updated_at INTEGER NOT NULL
44 )",
45 [],
46 )?;
47
48 conn.execute(
50 "CREATE TABLE IF NOT EXISTS technique_performance (
51 cluster_id TEXT NOT NULL,
52 technique TEXT NOT NULL,
53 success_count INTEGER NOT NULL DEFAULT 0,
54 failure_count INTEGER NOT NULL DEFAULT 0,
55 avg_iterations REAL NOT NULL DEFAULT 0.0,
56 avg_quality REAL NOT NULL DEFAULT 0.0,
57 updated_at INTEGER NOT NULL,
58 PRIMARY KEY (cluster_id, technique),
59 FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE
60 )",
61 [],
62 )?;
63
64 conn.execute(
66 "CREATE TABLE IF NOT EXISTS temperature_performance (
67 cluster_id TEXT NOT NULL,
68 temperature_key INTEGER NOT NULL,
69 success_rate REAL NOT NULL DEFAULT 0.5,
70 avg_quality REAL NOT NULL DEFAULT 0.5,
71 sample_count INTEGER NOT NULL DEFAULT 0,
72 last_updated INTEGER NOT NULL,
73 PRIMARY KEY (cluster_id, temperature_key),
74 FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE
75 )",
76 [],
77 )?;
78
79 conn.execute(
81 "CREATE INDEX IF NOT EXISTS idx_clusters_updated
82 ON clusters(updated_at DESC)",
83 [],
84 )?;
85
86 conn.execute(
87 "CREATE INDEX IF NOT EXISTS idx_technique_perf_cluster
88 ON technique_performance(cluster_id)",
89 [],
90 )?;
91
92 conn.execute(
93 "CREATE INDEX IF NOT EXISTS idx_temp_perf_cluster
94 ON temperature_performance(cluster_id)",
95 [],
96 )?;
97
98 Ok(())
99 }
100
101 pub fn save_cluster(&mut self, cluster: &TaskCluster) -> Result<()> {
103 let embedding_bytes =
104 bincode::serde::encode_to_vec(&cluster.embedding, bincode::config::standard())
105 .context("Failed to serialize embedding")?;
106
107 let techniques_json =
108 serde_json::to_string(&cluster.techniques).context("Failed to serialize techniques")?;
109
110 let tasks_json = serde_json::to_string(&cluster.example_tasks)
111 .context("Failed to serialize example tasks")?;
112
113 let timestamp = chrono::Utc::now().timestamp();
114
115 self.conn.execute(
116 "INSERT OR REPLACE INTO clusters
117 (id, description, embedding, techniques, example_tasks, created_at, updated_at)
118 VALUES (?1, ?2, ?3, ?4, ?5,
119 COALESCE((SELECT created_at FROM clusters WHERE id = ?1), ?6),
120 ?6)",
121 params![
122 cluster.id,
123 cluster.description,
124 embedding_bytes,
125 techniques_json,
126 tasks_json,
127 timestamp,
128 ],
129 )?;
130
131 Ok(())
132 }
133
134 pub fn load_clusters(&self) -> Result<Vec<TaskCluster>> {
136 let mut stmt = self.conn.prepare(
137 "SELECT id, description, embedding, techniques, example_tasks
138 FROM clusters
139 ORDER BY updated_at DESC",
140 )?;
141
142 let clusters = stmt
143 .query_map([], |row| {
144 let id: String = row.get(0)?;
145 let description: String = row.get(1)?;
146 let embedding_bytes: Vec<u8> = row.get(2)?;
147 let techniques_json: String = row.get(3)?;
148 let tasks_json: String = row.get(4)?;
149
150 let (embedding, _): (Vec<f32>, _) = bincode::serde::decode_from_slice(
152 &embedding_bytes,
153 bincode::config::standard(),
154 )
155 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
156
157 let techniques: Vec<PromptingTechnique> = serde_json::from_str(&techniques_json)
159 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
160
161 let example_tasks: Vec<String> = serde_json::from_str(&tasks_json)
163 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
164
165 Ok(TaskCluster {
166 id,
167 description,
168 embedding,
169 techniques,
170 example_tasks,
171 seal_query_cores: Vec::new(), avg_seal_quality: 0.5, recommended_complexity: super::techniques::ComplexityLevel::Moderate,
174 })
175 })?
176 .collect::<Result<Vec<_>, _>>()?;
177
178 Ok(clusters)
179 }
180
181 pub fn load_cluster(&self, cluster_id: &str) -> Result<Option<TaskCluster>> {
183 let mut stmt = self.conn.prepare(
184 "SELECT id, description, embedding, techniques, example_tasks
185 FROM clusters
186 WHERE id = ?1",
187 )?;
188
189 let mut rows = stmt.query([cluster_id])?;
190
191 if let Some(row) = rows.next()? {
192 let id: String = row.get(0)?;
193 let description: String = row.get(1)?;
194 let embedding_bytes: Vec<u8> = row.get(2)?;
195 let techniques_json: String = row.get(3)?;
196 let tasks_json: String = row.get(4)?;
197
198 let (embedding, _): (Vec<f32>, _) =
199 bincode::serde::decode_from_slice(&embedding_bytes, bincode::config::standard())?;
200 let techniques = serde_json::from_str(&techniques_json)?;
201 let example_tasks = serde_json::from_str(&tasks_json)?;
202
203 Ok(Some(TaskCluster {
204 id,
205 description,
206 embedding,
207 techniques,
208 example_tasks,
209 seal_query_cores: Vec::new(),
210 avg_seal_quality: 0.5,
211 recommended_complexity: super::techniques::ComplexityLevel::Moderate,
212 }))
213 } else {
214 Ok(None)
215 }
216 }
217
218 pub fn delete_cluster(&mut self, cluster_id: &str) -> Result<()> {
220 self.conn
221 .execute("DELETE FROM clusters WHERE id = ?1", [cluster_id])?;
222 Ok(())
223 }
224
225 pub fn save_temperature_performance(
227 &mut self,
228 cluster_id: &str,
229 temperature_key: i32,
230 perf: &TemperaturePerformance,
231 ) -> Result<()> {
232 self.conn.execute(
233 "INSERT OR REPLACE INTO temperature_performance
234 (cluster_id, temperature_key, success_rate, avg_quality, sample_count, last_updated)
235 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
236 params![
237 cluster_id,
238 temperature_key,
239 perf.success_rate,
240 perf.avg_quality,
241 perf.sample_count,
242 perf.last_updated,
243 ],
244 )?;
245 Ok(())
246 }
247
248 pub fn load_temperature_performance(
250 &self,
251 cluster_id: &str,
252 ) -> Result<Vec<(i32, TemperaturePerformance)>> {
253 let mut stmt = self.conn.prepare(
254 "SELECT temperature_key, success_rate, avg_quality, sample_count, last_updated
255 FROM temperature_performance
256 WHERE cluster_id = ?1",
257 )?;
258
259 let perfs = stmt
260 .query_map([cluster_id], |row| {
261 let temp_key: i32 = row.get(0)?;
262 let perf = TemperaturePerformance {
263 success_rate: row.get(1)?,
264 avg_quality: row.get(2)?,
265 sample_count: row.get(3)?,
266 last_updated: row.get(4)?,
267 };
268 Ok((temp_key, perf))
269 })?
270 .collect::<Result<Vec<_>, _>>()?;
271
272 Ok(perfs)
273 }
274
275 pub fn get_stats(&self) -> Result<StorageStats> {
277 let cluster_count: u32 =
278 self.conn
279 .query_row("SELECT COUNT(*) FROM clusters", [], |row| row.get(0))?;
280
281 let technique_perf_count: u32 =
282 self.conn
283 .query_row("SELECT COUNT(*) FROM technique_performance", [], |row| {
284 row.get(0)
285 })?;
286
287 let temp_perf_count: u32 =
288 self.conn
289 .query_row("SELECT COUNT(*) FROM temperature_performance", [], |row| {
290 row.get(0)
291 })?;
292
293 let db_size_bytes = std::fs::metadata(self.conn.path().unwrap_or_default())
294 .map(|m| m.len())
295 .unwrap_or(0);
296
297 Ok(StorageStats {
298 cluster_count,
299 technique_perf_count,
300 temp_perf_count,
301 db_size_bytes,
302 })
303 }
304
305 pub fn vacuum(&mut self) -> Result<()> {
307 self.conn.execute("VACUUM", [])?;
308 Ok(())
309 }
310}
311
/// Snapshot of table sizes and on-disk footprint, as produced by
/// `ClusterStorage::get_stats`.
#[derive(Debug, Clone)]
pub struct StorageStats {
    /// Number of rows in the `clusters` table.
    pub cluster_count: u32,
    /// Number of rows in the `technique_performance` table.
    pub technique_perf_count: u32,
    /// Number of rows in the `temperature_performance` table.
    pub temp_perf_count: u32,
    /// Size of the database file in bytes; `0` when it could not be determined.
    pub db_size_bytes: u64,
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::techniques::ComplexityLevel;

    /// Opens a storage backed by `test.db` inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the storage so the directory
    /// stays alive for as long as the caller holds it.
    fn temp_storage() -> (tempfile::TempDir, ClusterStorage) {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let storage = ClusterStorage::new(&db_path).unwrap();
        (temp_dir, storage)
    }

    /// Builds a cluster with id `"test"`, no techniques and no example tasks.
    fn bare_cluster() -> TaskCluster {
        TaskCluster::new(
            "test".to_string(),
            "Test".to_string(),
            vec![0.5; 4],
            vec![],
            vec![],
        )
    }

    /// A saved cluster can be read back by id with its persisted fields intact.
    #[test]
    fn test_create_and_load_cluster() {
        let (_temp_dir, mut storage) = temp_storage();

        let original = TaskCluster {
            id: "test_cluster".to_string(),
            description: "Test cluster description".to_string(),
            embedding: vec![0.1, 0.2, 0.3, 0.4],
            techniques: vec![
                PromptingTechnique::ChainOfThought,
                PromptingTechnique::RolePlaying,
            ],
            example_tasks: vec!["Task 1".to_string(), "Task 2".to_string()],
            seal_query_cores: vec![],
            avg_seal_quality: 0.8,
            recommended_complexity: ComplexityLevel::Moderate,
        };
        storage.save_cluster(&original).unwrap();

        let loaded = storage.load_cluster("test_cluster").unwrap().unwrap();

        assert_eq!(loaded.id, "test_cluster");
        assert_eq!(loaded.description, "Test cluster description");
        assert_eq!(loaded.embedding, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(loaded.techniques.len(), 2);
        assert_eq!(loaded.example_tasks.len(), 2);
    }

    /// Every saved cluster is returned by `load_clusters`.
    #[test]
    fn test_load_all_clusters() {
        let (_temp_dir, mut storage) = temp_storage();

        for i in 0..3 {
            let cluster = TaskCluster {
                id: format!("cluster_{}", i),
                description: format!("Cluster {}", i),
                embedding: vec![i as f32; 4],
                techniques: vec![PromptingTechnique::ChainOfThought],
                example_tasks: vec![format!("Task {}", i)],
                seal_query_cores: vec![],
                avg_seal_quality: 0.5,
                recommended_complexity: ComplexityLevel::Simple,
            };
            storage.save_cluster(&cluster).unwrap();
        }

        let all = storage.load_clusters().unwrap();
        assert_eq!(all.len(), 3);
    }

    /// A deleted cluster can no longer be loaded.
    #[test]
    fn test_delete_cluster() {
        let (_temp_dir, mut storage) = temp_storage();

        let cluster = TaskCluster::new(
            "test".to_string(),
            "Test".to_string(),
            vec![0.5; 4],
            vec![PromptingTechnique::RolePlaying],
            vec!["Example".to_string()],
        );
        storage.save_cluster(&cluster).unwrap();
        assert!(storage.load_cluster("test").unwrap().is_some());

        storage.delete_cluster("test").unwrap();
        assert!(storage.load_cluster("test").unwrap().is_none());
    }

    /// A temperature-performance record round-trips through the database.
    #[test]
    fn test_temperature_performance_storage() {
        let (_temp_dir, mut storage) = temp_storage();

        // The parent cluster is saved first; the record references it by id.
        storage.save_cluster(&bare_cluster()).unwrap();

        let perf = TemperaturePerformance {
            success_rate: 0.85,
            avg_quality: 0.9,
            sample_count: 10,
            last_updated: 12345,
        };
        storage
            .save_temperature_performance("test", 0, &perf)
            .unwrap();

        let loaded = storage.load_temperature_performance("test").unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, 0);
        assert_eq!(loaded[0].1.sample_count, 10);
        assert!((loaded[0].1.success_rate - 0.85).abs() < 0.01);
    }

    /// Stats reflect the number of stored clusters and a non-empty database file.
    #[test]
    fn test_storage_stats() {
        let (_temp_dir, mut storage) = temp_storage();

        let before = storage.get_stats().unwrap();
        assert_eq!(before.cluster_count, 0);

        storage.save_cluster(&bare_cluster()).unwrap();

        let after = storage.get_stats().unwrap();
        assert_eq!(after.cluster_count, 1);
        assert!(after.db_size_bytes > 0);
    }
}