Skip to main content

brainwires_prompting/
storage.rs

1//! Persistence Layer
2//!
3//! This module provides SQLite storage for task clusters and technique performance.
4
5use super::clustering::TaskCluster;
6use super::techniques::PromptingTechnique;
7use super::temperature::TemperaturePerformance;
8use anyhow::{Context, Result};
9use rusqlite::{Connection, params};
10use serde_json;
11use std::path::Path;
12
13/// Manages persistent storage of task clusters and performance data
14pub struct ClusterStorage {
15    conn: Connection,
16}
17
18impl ClusterStorage {
19    /// Create a new cluster storage at the specified database path
20    pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
21        let conn = Connection::open(db_path)?;
22
23        // Enable foreign keys
24        conn.execute("PRAGMA foreign_keys = ON", [])?;
25
26        // Create tables if they don't exist
27        Self::create_tables(&conn)?;
28
29        Ok(Self { conn })
30    }
31
32    /// Create all required tables
33    fn create_tables(conn: &Connection) -> Result<()> {
34        // Clusters table
35        conn.execute(
36            "CREATE TABLE IF NOT EXISTS clusters (
37                id TEXT PRIMARY KEY,
38                description TEXT NOT NULL,
39                embedding BLOB NOT NULL,
40                techniques TEXT NOT NULL,
41                example_tasks TEXT NOT NULL,
42                created_at INTEGER NOT NULL,
43                updated_at INTEGER NOT NULL
44            )",
45            [],
46        )?;
47
48        // Technique performance table
49        conn.execute(
50            "CREATE TABLE IF NOT EXISTS technique_performance (
51                cluster_id TEXT NOT NULL,
52                technique TEXT NOT NULL,
53                success_count INTEGER NOT NULL DEFAULT 0,
54                failure_count INTEGER NOT NULL DEFAULT 0,
55                avg_iterations REAL NOT NULL DEFAULT 0.0,
56                avg_quality REAL NOT NULL DEFAULT 0.0,
57                updated_at INTEGER NOT NULL,
58                PRIMARY KEY (cluster_id, technique),
59                FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE
60            )",
61            [],
62        )?;
63
64        // Temperature performance table
65        conn.execute(
66            "CREATE TABLE IF NOT EXISTS temperature_performance (
67                cluster_id TEXT NOT NULL,
68                temperature_key INTEGER NOT NULL,
69                success_rate REAL NOT NULL DEFAULT 0.5,
70                avg_quality REAL NOT NULL DEFAULT 0.5,
71                sample_count INTEGER NOT NULL DEFAULT 0,
72                last_updated INTEGER NOT NULL,
73                PRIMARY KEY (cluster_id, temperature_key),
74                FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE
75            )",
76            [],
77        )?;
78
79        // Create indexes for common queries
80        conn.execute(
81            "CREATE INDEX IF NOT EXISTS idx_clusters_updated
82             ON clusters(updated_at DESC)",
83            [],
84        )?;
85
86        conn.execute(
87            "CREATE INDEX IF NOT EXISTS idx_technique_perf_cluster
88             ON technique_performance(cluster_id)",
89            [],
90        )?;
91
92        conn.execute(
93            "CREATE INDEX IF NOT EXISTS idx_temp_perf_cluster
94             ON temperature_performance(cluster_id)",
95            [],
96        )?;
97
98        Ok(())
99    }
100
101    /// Save a task cluster to the database
102    pub fn save_cluster(&mut self, cluster: &TaskCluster) -> Result<()> {
103        let embedding_bytes =
104            bincode::serde::encode_to_vec(&cluster.embedding, bincode::config::standard())
105                .context("Failed to serialize embedding")?;
106
107        let techniques_json =
108            serde_json::to_string(&cluster.techniques).context("Failed to serialize techniques")?;
109
110        let tasks_json = serde_json::to_string(&cluster.example_tasks)
111            .context("Failed to serialize example tasks")?;
112
113        let timestamp = chrono::Utc::now().timestamp();
114
115        self.conn.execute(
116            "INSERT OR REPLACE INTO clusters
117             (id, description, embedding, techniques, example_tasks, created_at, updated_at)
118             VALUES (?1, ?2, ?3, ?4, ?5,
119                     COALESCE((SELECT created_at FROM clusters WHERE id = ?1), ?6),
120                     ?6)",
121            params![
122                cluster.id,
123                cluster.description,
124                embedding_bytes,
125                techniques_json,
126                tasks_json,
127                timestamp,
128            ],
129        )?;
130
131        Ok(())
132    }
133
134    /// Load all clusters from the database
135    pub fn load_clusters(&self) -> Result<Vec<TaskCluster>> {
136        let mut stmt = self.conn.prepare(
137            "SELECT id, description, embedding, techniques, example_tasks
138             FROM clusters
139             ORDER BY updated_at DESC",
140        )?;
141
142        let clusters = stmt
143            .query_map([], |row| {
144                let id: String = row.get(0)?;
145                let description: String = row.get(1)?;
146                let embedding_bytes: Vec<u8> = row.get(2)?;
147                let techniques_json: String = row.get(3)?;
148                let tasks_json: String = row.get(4)?;
149
150                // Deserialize embedding
151                let (embedding, _): (Vec<f32>, _) = bincode::serde::decode_from_slice(
152                    &embedding_bytes,
153                    bincode::config::standard(),
154                )
155                .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
156
157                // Deserialize techniques
158                let techniques: Vec<PromptingTechnique> = serde_json::from_str(&techniques_json)
159                    .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
160
161                // Deserialize example tasks
162                let example_tasks: Vec<String> = serde_json::from_str(&tasks_json)
163                    .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
164
165                Ok(TaskCluster {
166                    id,
167                    description,
168                    embedding,
169                    techniques,
170                    example_tasks,
171                    seal_query_cores: Vec::new(), // Not stored currently
172                    avg_seal_quality: 0.5,        // Not stored currently
173                    recommended_complexity: super::techniques::ComplexityLevel::Moderate,
174                })
175            })?
176            .collect::<Result<Vec<_>, _>>()?;
177
178        Ok(clusters)
179    }
180
181    /// Load a specific cluster by ID
182    pub fn load_cluster(&self, cluster_id: &str) -> Result<Option<TaskCluster>> {
183        let mut stmt = self.conn.prepare(
184            "SELECT id, description, embedding, techniques, example_tasks
185             FROM clusters
186             WHERE id = ?1",
187        )?;
188
189        let mut rows = stmt.query([cluster_id])?;
190
191        if let Some(row) = rows.next()? {
192            let id: String = row.get(0)?;
193            let description: String = row.get(1)?;
194            let embedding_bytes: Vec<u8> = row.get(2)?;
195            let techniques_json: String = row.get(3)?;
196            let tasks_json: String = row.get(4)?;
197
198            let (embedding, _): (Vec<f32>, _) =
199                bincode::serde::decode_from_slice(&embedding_bytes, bincode::config::standard())?;
200            let techniques = serde_json::from_str(&techniques_json)?;
201            let example_tasks = serde_json::from_str(&tasks_json)?;
202
203            Ok(Some(TaskCluster {
204                id,
205                description,
206                embedding,
207                techniques,
208                example_tasks,
209                seal_query_cores: Vec::new(),
210                avg_seal_quality: 0.5,
211                recommended_complexity: super::techniques::ComplexityLevel::Moderate,
212            }))
213        } else {
214            Ok(None)
215        }
216    }
217
218    /// Delete a cluster and all its associated performance data
219    pub fn delete_cluster(&mut self, cluster_id: &str) -> Result<()> {
220        self.conn
221            .execute("DELETE FROM clusters WHERE id = ?1", [cluster_id])?;
222        Ok(())
223    }
224
225    /// Save temperature performance for a cluster
226    pub fn save_temperature_performance(
227        &mut self,
228        cluster_id: &str,
229        temperature_key: i32,
230        perf: &TemperaturePerformance,
231    ) -> Result<()> {
232        self.conn.execute(
233            "INSERT OR REPLACE INTO temperature_performance
234             (cluster_id, temperature_key, success_rate, avg_quality, sample_count, last_updated)
235             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
236            params![
237                cluster_id,
238                temperature_key,
239                perf.success_rate,
240                perf.avg_quality,
241                perf.sample_count,
242                perf.last_updated,
243            ],
244        )?;
245        Ok(())
246    }
247
248    /// Load temperature performance for a cluster
249    pub fn load_temperature_performance(
250        &self,
251        cluster_id: &str,
252    ) -> Result<Vec<(i32, TemperaturePerformance)>> {
253        let mut stmt = self.conn.prepare(
254            "SELECT temperature_key, success_rate, avg_quality, sample_count, last_updated
255             FROM temperature_performance
256             WHERE cluster_id = ?1",
257        )?;
258
259        let perfs = stmt
260            .query_map([cluster_id], |row| {
261                let temp_key: i32 = row.get(0)?;
262                let perf = TemperaturePerformance {
263                    success_rate: row.get(1)?,
264                    avg_quality: row.get(2)?,
265                    sample_count: row.get(3)?,
266                    last_updated: row.get(4)?,
267                };
268                Ok((temp_key, perf))
269            })?
270            .collect::<Result<Vec<_>, _>>()?;
271
272        Ok(perfs)
273    }
274
275    /// Get statistics about stored data
276    pub fn get_stats(&self) -> Result<StorageStats> {
277        let cluster_count: u32 =
278            self.conn
279                .query_row("SELECT COUNT(*) FROM clusters", [], |row| row.get(0))?;
280
281        let technique_perf_count: u32 =
282            self.conn
283                .query_row("SELECT COUNT(*) FROM technique_performance", [], |row| {
284                    row.get(0)
285                })?;
286
287        let temp_perf_count: u32 =
288            self.conn
289                .query_row("SELECT COUNT(*) FROM temperature_performance", [], |row| {
290                    row.get(0)
291                })?;
292
293        let db_size_bytes = std::fs::metadata(self.conn.path().unwrap_or_default())
294            .map(|m| m.len())
295            .unwrap_or(0);
296
297        Ok(StorageStats {
298            cluster_count,
299            technique_perf_count,
300            temp_perf_count,
301            db_size_bytes,
302        })
303    }
304
305    /// Vacuum the database to reclaim space
306    pub fn vacuum(&mut self) -> Result<()> {
307        self.conn.execute("VACUUM", [])?;
308        Ok(())
309    }
310}
311
/// Statistics about stored data.
///
/// A plain value snapshot: every field is an integer, so the type can be
/// compared for equality and has an all-zero [`Default`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StorageStats {
    /// Number of stored task clusters.
    pub cluster_count: u32,
    /// Number of technique performance records.
    pub technique_perf_count: u32,
    /// Number of temperature performance records.
    pub temp_perf_count: u32,
    /// Total database size in bytes.
    pub db_size_bytes: u64,
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::techniques::ComplexityLevel;

    /// Open a fresh storage backed by `test.db` inside a new temporary
    /// directory. The directory guard is returned so the caller keeps it alive
    /// for the duration of the test.
    fn open_temp_storage() -> (tempfile::TempDir, ClusterStorage) {
        let dir = tempfile::tempdir().unwrap();
        let storage = ClusterStorage::new(dir.path().join("test.db")).unwrap();
        (dir, storage)
    }

    /// Cluster with id `"test"`, a four-element embedding and no techniques or
    /// example tasks.
    fn bare_cluster() -> TaskCluster {
        TaskCluster::new(
            "test".to_string(),
            "Test".to_string(),
            vec![0.5; 4],
            vec![],
            vec![],
        )
    }

    #[test]
    fn test_create_and_load_cluster() {
        let (_dir, mut storage) = open_temp_storage();

        let original = TaskCluster {
            id: "test_cluster".to_string(),
            description: "Test cluster description".to_string(),
            embedding: vec![0.1, 0.2, 0.3, 0.4],
            techniques: vec![
                PromptingTechnique::ChainOfThought,
                PromptingTechnique::RolePlaying,
            ],
            example_tasks: vec!["Task 1".to_string(), "Task 2".to_string()],
            seal_query_cores: vec![],
            avg_seal_quality: 0.8,
            recommended_complexity: ComplexityLevel::Moderate,
        };
        storage.save_cluster(&original).unwrap();

        // Round-trip through the database and compare the persisted fields.
        let restored = storage.load_cluster("test_cluster").unwrap().unwrap();
        assert_eq!(restored.id, "test_cluster");
        assert_eq!(restored.description, "Test cluster description");
        assert_eq!(restored.embedding, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(restored.techniques.len(), 2);
        assert_eq!(restored.example_tasks.len(), 2);
    }

    #[test]
    fn test_load_all_clusters() {
        let (_dir, mut storage) = open_temp_storage();

        // Three distinct clusters, one save each.
        for n in 0..3 {
            storage
                .save_cluster(&TaskCluster {
                    id: format!("cluster_{}", n),
                    description: format!("Cluster {}", n),
                    embedding: vec![n as f32; 4],
                    techniques: vec![PromptingTechnique::ChainOfThought],
                    example_tasks: vec![format!("Task {}", n)],
                    seal_query_cores: vec![],
                    avg_seal_quality: 0.5,
                    recommended_complexity: ComplexityLevel::Simple,
                })
                .unwrap();
        }

        let all = storage.load_clusters().unwrap();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_delete_cluster() {
        let (_dir, mut storage) = open_temp_storage();

        let cluster = TaskCluster::new(
            "test".to_string(),
            "Test".to_string(),
            vec![0.5; 4],
            vec![PromptingTechnique::RolePlaying],
            vec!["Example".to_string()],
        );
        storage.save_cluster(&cluster).unwrap();
        assert!(storage.load_cluster("test").unwrap().is_some());

        storage.delete_cluster("test").unwrap();
        assert!(storage.load_cluster("test").unwrap().is_none());
    }

    #[test]
    fn test_temperature_performance_storage() {
        let (_dir, mut storage) = open_temp_storage();

        // The parent cluster has to exist before a performance row can
        // reference it.
        storage.save_cluster(&bare_cluster()).unwrap();

        let recorded = TemperaturePerformance {
            success_rate: 0.85,
            avg_quality: 0.9,
            sample_count: 10,
            last_updated: 12345,
        };
        storage
            .save_temperature_performance("test", 0, &recorded)
            .unwrap();

        let rows = storage.load_temperature_performance("test").unwrap();
        assert_eq!(rows.len(), 1);

        let (key, perf) = &rows[0];
        assert_eq!(*key, 0); // temperature_key
        assert_eq!(perf.sample_count, 10);
        assert!((perf.success_rate - 0.85).abs() < 0.01);
    }

    #[test]
    fn test_storage_stats() {
        let (_dir, mut storage) = open_temp_storage();

        let before = storage.get_stats().unwrap();
        assert_eq!(before.cluster_count, 0);

        storage.save_cluster(&bare_cluster()).unwrap();

        let after = storage.get_stats().unwrap();
        assert_eq!(after.cluster_count, 1);
        assert!(after.db_size_bytes > 0);
    }
}